web-dev-qa-db-fra.com

Remplacement de tf.placeholder et feed_dict par l'API tf.data

J'ai un modèle TensorFlow existant qui utilisait un titulaire tf.placeholder pour l'entrée de modèle et le paramètre feed_dict de tf.Session (). Run pour alimenter des données. Auparavant, l'intégralité du jeu de données était lue en mémoire et transmise de cette manière. 

Je souhaite utiliser un ensemble de données beaucoup plus volumineux et tirer parti des améliorations apportées aux performances de l'API tf.data. J'ai défini un tf.data.TextLineDataset et un itérateur one-shot, mais j'ai du mal à comprendre comment intégrer les données dans le modèle pour les former.

Au début, j'ai essayé de définir simplement feed_dict en tant que dictionnaire de l'espace réservé à iterator.get_next (), mais cela m'a donné une erreur en disant que la valeur d'un flux ne peut être un objet tf.Tensor. Plus au fond, j’ai compris que c’est parce que l’objet retourné par iterator.get_next () fait déjà partie du graphique, contrairement à ce que vous donneriez à feed_dict - et que je ne devrais pas essayer d’utiliser Feed_dict du tout pour raisons de performance. 

Alors maintenant, je me suis débarrassé de l'entrée tf.placeholder et l'ai remplacée par un paramètre du constructeur de la classe qui définit mon modèle; lors de la construction du modèle dans mon code d'apprentissage, je passe la sortie de iterator.get_next () à ce paramètre. Cela semble déjà un peu maladroit car cela brise la séparation entre la définition du modèle et la procédure de formation/jeux de données. Et je reçois maintenant une erreur en disant que le Tenseur représentant (je crois) les entrées de mon modèle doit provenir du même graphique que le Tenseur de iterator.get_next ().

Suis-je sur la bonne voie avec cette approche et fais juste quelque chose de mal avec la façon dont j'ai configuré le graphique et la session, ou quelque chose du genre? (Les jeux de données et le modèle sont tous deux initialisés en dehors d'une session et l'erreur se produit avant que j'essaie d'en créer une.) 

Ou suis-je totalement hors de propos avec cela et dois-je faire quelque chose de différent, comme utiliser l'API Estimator et tout définir dans une fonction d'entrée?

Voici un code montrant un exemple minimal:

import tensorflow as tf
import numpy as np

class Network:
    def __init__(self, x_in, input_size):
        self.input_size = input_size
        # self.x_in = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size))  # Original
        self.x_in = x_in
        self.output_size = 3

        tf.reset_default_graph()  # This turned out to be the problem

        self.layer = tf.layers.dense(self.x_in, self.output_size, activation=tf.nn.relu)
        self.loss = tf.reduce_sum(tf.square(self.layer - tf.constant(0, dtype=tf.float32, shape=[self.output_size])))

data_array = np.random.standard_normal([4, 10]).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data_array).batch(2)

model = Network(x_in=dataset.make_one_shot_iterator().get_next(), input_size=dataset.output_shapes[-1])
13
erobertc

La ligne tf.reset_default_graph() dans le constructeur du modèle à partir du code original qui m'a été donné en était la cause. Enlever ça corrige ça.

1
erobertc

Il m'a fallu un peu de temps pour comprendre. Vous êtes sur la bonne voie. La définition complète du jeu de données ne constitue qu'une partie du graphique. Je le crée généralement en tant que classe différente de celle de ma classe Model et passe l'ensemble de données à la classe Model. Je spécifie la classe de jeu de données que je souhaite charger sur la ligne de commande, puis je la charge de manière dynamique, ce qui permet de découpler le jeu de données et le graphique de manière modulaire.

Notez que vous pouvez (et devriez) nommer tous les tenseurs du jeu de données. Cela facilite vraiment la compréhension des choses lorsque vous transmettez des données à travers les différentes transformations dont vous avez besoin.

Vous pouvez écrire des scénarios de test simples qui extraient des échantillons de la iterator.get_next() et les affichent. Vous aurez alors quelque chose comme sess.run(next_element_tensor), pas feed_dict, comme vous l'avez correctement noté.

Une fois que vous aurez compris, vous commencerez probablement à aimer le pipeline d’entrée de jeux de données. Cela vous oblige à bien modulariser votre code et à le transformer en une structure facile à tester. 

Assurez-vous de lire le guide du développeur, il y a des tonnes d'exemples:

https://www.tensorflow.org/programmers_guide/datasets

Une autre chose que je noterai est combien il est facile de travailler avec un train et de tester des données avec ce pipeline. Cela est important car vous effectuez souvent une augmentation des données sur le jeu de données d'apprentissage que vous n'effectuez pas sur le jeu de données test. from_string_handle vous permet de le faire et est clairement décrit dans le guide ci-dessus.

5
David Parks