web-dev-qa-db-fra.com

comment itérer plusieurs fois un jeu de données à l'aide de l'API du jeu de données tensorflow

Comment sortir la valeur dans un jeu de données plusieurs fois? (l'ensemble de données est créé par l'API de l'ensemble de données de tensorflow) 

import tensorflow as tf

dataset = tf.contrib.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
Epoch = 10

for i in range(Epoch):
   for j in range(100):
      value = sess.run(next_element)
      assert j == value
      print(j)

Message d'erreur:

tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]]

Comment faire ce travail?

5
void

Tout d’abord, je vous conseille de lire Data Set Guide . Il est décrit tous les détails de l’API DataSet.

Votre question concerne la répétition des données plusieurs fois. Voici deux solutions pour cela:

  1. Itération de toutes les époques à la fois, aucune information sur la fin des époques
import tensorflow as tf

Epoch   = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.repeat(Epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()

num_batch = 0
j = 0
while True:
    try:
        value = sess.run(next_element)
        assert j == value
        j += 1
        num_batch += 1
        if j > 99: # new Epoch
            j = 0
    except tf.errors.OutOfRangeError:
        break

print ("Num Batch: ", num_batch)
  1. La deuxième option vous informe de la fin de chaque époque, vous pouvez donc ex. vérifier la perte de validation:
import tensorflow as tf

Epoch = 10
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess = tf.Session()

num_batch = 0

for e in range(Epoch):
    print ("Epoch: ", e)
    j = 0
    sess.run(iterator.initializer)
    while True:
        try:
            value = sess.run(next_element)
            assert j == value
            j += 1
            num_batch += 1
        except tf.errors.OutOfRangeError:
            break

print ("Num Batch: ", num_batch)
16
melgor89

Si votre version de tensorflow est 1.3+, je recommande l'API de haut niveau tf.train.MonitoredTrainingSession. La sess créée par cette API peut détecter automatiquement tf.errors.OutOfRangeError avec sess.should_stop(). Pour la plupart des situations de formation, vous devez mélanger les données et obtenir un lot à chaque étape. Je les ai ajoutées dans le code suivant.

import tensorflow as tf

Epoch = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.shuffle(buffer_size=100) # comment this line if you don't want to shuffle data
dataset = dataset.batch(batch_size=32)     # batch_size=1 if you want to get only one element per step
dataset = dataset.repeat(Epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

num_batch = 0
with tf.train.MonitoredTrainingSession() as sess:
    while not sess.should_stop():
        value = sess.run(next_element)
        num_batch += 1
        print("Num Batch: ", num_batch)
3
Tom

Essaye ça

while True:
  try:
    print(sess.run(value))
  except tf.errors.OutOfRangeError:
    break

Chaque fois que l'itérateur de jeu de données atteint la fin des données, il lève tf.errors.OutOfRangeError, vous pouvez le récupérer avec except et le démarrer à partir du début.

2
Grigor Carran