web-dev-qa-db-fra.com

Fractionner un jeu de données créé par l'API Tensorflow pour former et tester?

Est-ce que quelqu'un sait comment fractionner un jeu de données créé par l'API de jeu de données (tf.data.Dataset) dans Tensorflow en test et entraînement?

10
Dani

En supposant que vous ayez la variable all_dataset de type tf.data.Dataset:

test_dataset = all_dataset.take(1000) 
train_dataset = all_dataset.skip(1000)

Le jeu de données test contient maintenant les 1000 premiers éléments et le reste est destiné à la formation.

9
apatsekin

Vous pouvez utiliser Dataset.take() et Dataset.skip():

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)

Pour plus de généralité, j'ai donné un exemple en utilisant un fractionnement train/val/test 70/15/15, mais si vous n'avez pas besoin d'un test ou d'un ensemble de valeurs, ignorez simplement les 2 dernières lignes.

Prenez :

Crée un jeu de données avec au plus un nombre d'éléments de ce jeu de données.

Sauter :

Crée un jeu de données qui ignore le nombre d'éléments de ce jeu de données.

Vous voudrez peut-être aussi vous pencher sur Dataset.shard() :

Crée un jeu de données qui comprend uniquement 1/num_shards de ce jeu de données.


Avertissement Je suis tombé par hasard sur cette question après avoir répondu celui-ci alors je pensais que je répandrais l'amour

2
ted

Vous pouvez utiliser shard:

dataset = dataset.shuffle()  # optional
trainset = dataset.shard(2, 0)
testset = dataset.shard(2, 1)

Voir: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard

0
Ben-Uri

Maintenant, Tensorflow ne contient aucun outil pour cela.
Vous pouvez utiliser sklearn.model_selection.train_test_split pour générer un ensemble de données train/eval/test, puis créer tf.data.Dataset 

0
Lunar_one