web-dev-qa-db-fra.com

TensorFlow: comment est défini dataset.train.next_batch?

J'essaie d'apprendre TensorFlow et d'étudier l'exemple sur: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb

J'ai ensuite quelques questions dans le code ci-dessous:

for Epoch in range(training_epochs):
    # Loop over all batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        # Run optimization op (backprop) and cost op (to get loss value)
        _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
    # Display logs per Epoch step
    if Epoch % display_step == 0:
        print("Epoch:", '%04d' % (Epoch+1),
              "cost=", "{:.9f}".format(c))

Puisque mnist n'est qu'un ensemble de données, que fait exactement mnist.train.next_batch signifier? Comment était le dataset.train.next_batch défini?

Merci!

13
Edamame

L'objet mnist est renvoyé par la fonction read_data_sets() définie dans le module tf.contrib.learn. La méthode mnist.train.next_batch(batch_size) est implémentée ici , et elle retourne un Tuple de deux tableaux, où le premier représente un lot d'images MNIST batch_size, Et le second représente un lot d'étiquettes batch-size correspondant à ces images.

Les images sont renvoyées sous forme de tableau NumPy 2D de taille [batch_size, 784] (Car il y a 784 pixels dans une image MNIST) et les étiquettes sont renvoyées sous forme de tableau NumPy 1D de taille [batch_size] (si read_data_sets() a été appelée avec one_hot=False) ou un tableau NumPy 2D de taille [batch_size, 10] (si read_data_sets() a été appelé avec one_hot=True).

25
mrry