web-dev-qa-db-fra.com

TensorFlow: tf.train.batch charge-t-il automatiquement le prochain lot lorsque celui-ci est terminé?

Par exemple, après avoir créé mes opérations, alimenté les données du lot et exécuté l'opération, tf.train.batch alimente-t-il automatiquement un autre lot de données dans la session?

Je demande ceci parce que tf.train.batch a un attribut de allow_smaller_final_batch qui permet au lot final d’être chargé avec une taille inférieure à la taille de lot indiquée. Est-ce que cela signifie que même sans boucle, le prochain lot pourrait être automatiquement alimenté? D'après les codes du didacticiel, je suis plutôt confus. Lorsque je charge un seul lot, j'obtiens littéralement une taille de lot unique de forme [batch_size, height, width, num_channels], mais le documentation le dit Creates batches of tensors in tensors. De même, lorsque je lis le code du didacticiel dans le tf- tutoriel pas à pas , où il existe une fonction appelée load_batch, seuls 3 tenseurs sont renvoyés: images, images_raw, labels. Où se trouvent les «lots» de données, comme expliqué dans la documentation?

Merci de votre aide.

13
kwotsin

... tf.train.batch alimente-t-il automatiquement un autre lot de données vers la session?

Non, rien ne se passe automatiquement. Vous devez rappeler sess.run(...) pour charger un nouveau lot.

Est-ce que cela signifie que même sans boucle, le prochain lot pourrait être automatiquement alimenté?

Non. tf.train.batch(..) chargera toujours batch_size tenseurs. Si vous avez par exemple 100 images et un batch_size=30, alors vous aurez 3 * 30 lots car vous pouvez appeler sess.run(batch) trois fois avant que la file d’entrée ne commence à partir du début (ou s’arrête si Epoch=1). Cela signifie que vous manquez des échantillons 100-3*30=10 lors de la formation. Si vous ne voulez pas les manquer, vous pouvez exécuter tf.train.batch(..., allow_smaller_final_batch=True). Vous disposez alors de 3x 30 lots d'échantillons et d'un lot de 10 échantillons avant le redémarrage de la file d'attente.

Laissez-moi également élaborer avec un exemple de code:

queue = tf.train.string_input_producer(filenames,
        num_epochs=1) # only iterate through all samples in dataset once

reader = tf.TFRecordReader() # or any reader you need
_, example = reader.read(queue)

image, label = your_conversion_fn(example)

# batch will now load up to 100 image-label-pairs on sess.run(...)
# most tf ops are tuned to work on batches
# this is faster and also gives better result on e.g. gradient calculation
batch = tf.train.batch([image, label], batch_size=100)

with tf.Session() as sess:
    # "boilerplate" code
    sess.run([
        tf.local_variables_initializer(),
        tf.global_variables_initializer(),
    ])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        # in most cases coord.should_stop() will return True
        # when there are no more samples to read
        # if num_epochs=0 then it will run for ever
        while not coord.should_stop():
            # will start reading, working data from input queue
            # and "fetch" the results of the computation graph
            # into raw_images and raw_labels
            raw_images, raw_labels = sess.run([images, labels])
    finally:
        coord.request_stop()
        coord.join(threads)
16
bodokaiser

Vous devez appeler sess.run et lui transmettre le lot chaque fois que vous souhaitez charger le prochain lot. Voir le code ci-dessous.

img = [0,1,2,3,4,5,6,7,8]
lbl = [0,1,2,3,4,5,6,7,8]
images = tf.convert_to_tensor(img)
labels = tf.convert_to_tensor(lbl)
input_queue = tf.train.slice_input_producer([images,labels])
sliced_img = input_queue[0]
sliced_lbl = input_queue[1]

img_batch, lbl_batch = tf.train.batch([sliced_img,sliced_lbl], batch_size=3)
with tf.Session() as sess:
    coord   = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(0,3): #batch size
        image_batch,label_batch = sess.run([img_batch,lbl_batch ])
        print(image_batch, label_batch)

    coord.request_stop()
    coord.join(threads)

la réponse serait quelque chose comme ceci: 

[4,1,8] [4,1,8]

[2,3,7] [2,3,7]

[2,6,8] [2,6,8]

0
Abdul Bari