web-dev-qa-db-fra.com

Erreur "TypeError: l'objet" Tenseur "n'est pas itérable" avec l'estimateur tensorflow

J'ai une source de données (infinie) générée de manière procédurale et j'essaie de l'utiliser comme entrée dans le Tensorflow de haut niveau Estimator pour former un détecteur d'objet 3D basé sur une image.

J'ai configuré le jeu de données comme dans Tensorflor Estimator Quickstart , et mon dataset_input_fn renvoie un Tuple de fonctionnalités et d'étiquettes Tensor, tout comme les Estimator.train spécifie la fonction, et comment cela le tutoriel montre , mais je reçois une erreur en essayant d'appeler la fonction train:

TypeError: 'Tensor' object is not iterable.

Qu'est-ce que je fais mal?


    def data_generator():
        """
        Generator for image (features) and ground truth object positions (labels)

        Sample an image and object positions from a procedurally generated data source
        """
        while True:
            source.step()  # generate next data point

            object_ground_truth = source.get_ground_truth() # list of 9 floats
            cam_img = source.get_cam_frame()  # image (224, 224, 3) 
            yield (cam_img, object_ground_truth)

    def dataset_input_fn():
        """
        Tensorflow `Dataset` object from generator
        """

        dataset = tf.data.Dataset.from_generator(data_generator, (tf.uint8, tf.float32), \
            (tf.TensorShape([224, 224, 3]), tf.TensorShape([9])))
        dataset = dataset.batch(16)

        iterator = dataset.make_one_shot_iterator()

        features, labels = iterator.get_next()
        return features, labels

    def main():
        """
        Estimator [from Keras model](https://www.tensorflow.org/programmers_guide/estimators#creating_estimators_from_keras_models) 

        Try to call `est_vgg.train()` leads to the error
        """
        ....
        est_vgg16 = tf.keras.estimator.model_to_estimator(keras_model=keras_vgg16)
        est_vgg16.train(input_fn=dataset_input_fn, steps=10)
        ....

Voici le code complet

(note: les choses sont nommées différemment de cette question)

Voici la trace de la pile:

Traceback (most recent call last):
  File "./rock_detector.py", line 155, in <module>
    main()
  File "./rock_detector.py", line 117, in main
    est_vgg16.train(input_fn=dataset_input_fn, steps=10)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 302, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 711, in _train_model
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 694, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 145, in model_fn
    labels)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 92, in _clone_and_build_model
    keras_model, features)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 58, in _create_ordered_io
    for key in estimator_io_dict:
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 505, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.
7
matwilso

Faites en sorte que votre fonction d'entrée renvoie un dictionnaire de fonctionnalités comme celle-ci:

def dataset_input_fn():
  ...
  features, labels = iterator.get_next()
  return {'image': features}, labels
5
Maxim