web-dev-qa-db-fra.com

TypeError: l'argument Fetch a un type non valide float32, doit être une chaîne ou un tenseur

Je forme un CNN assez similaire à celui de l'exemple this , pour la segmentation d'image. Les images sont 1500x1500x1 et les étiquettes sont de la même taille.

Après avoir défini la structure CNN et lancé la session comme dans cet exemple de code: (conv_net_test.py)

with tf.Session() as sess:
sess.run(init)
summ = tf.train.SummaryWriter('/tmp/logdir/', sess.graph_def)
step = 1
print ("import data, read from read_data_sets()...")

#Data defined by me, returns a DataSet object with testing and training images and labels for segmentation problem.
data = import_data_test.read_data_sets('Dataset')

# Keep training until reach max iterations
while step * batch_size < training_iters:
    batch_x, batch_y = data.train.next_batch(batch_size)
    print ("running backprop for step %d" % step)
    batch_x = batch_x.reshape(batch_size, n_input, n_input, n_channels)
    batch_y = batch_y.reshape(batch_size, n_input, n_input, n_channels)
    batch_y = np.int64(batch_y)
    sess.run(optimizer, feed_dict={x: batch_x, y: batch_y, keep_prob: dropout})
    if step % display_step == 0:
        # Calculate batch loss and accuracy
        #pdb.set_trace()
        loss, acc = sess.run([loss, accuracy], feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
    step += 1
print "Optimization Finished"

J'ai rencontré la TypeError suivante (stacktrace ci-dessous):

    conv_net_test.py in <module>()
    178             #pdb.set_trace()
--> 179             loss, acc = sess.run([loss, accuracy], feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
    180         step += 1
    181     print "Optimization Finished!"

tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    370     try:
    371       result = self._run(None, fetches, feed_dict, options_ptr,
--> 372                          run_metadata_ptr)
    373       if run_metadata:
    374         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
    582 
    583     # Validate and process fetches.
--> 584     processed_fetches = self._process_fetches(fetches)
    585     unique_fetches = processed_fetches[0]
    586     target_list = processed_fetches[1]

tensorflow/python/client/session.pyc in _process_fetches(self, fetches)
    538           raise TypeError('Fetch argument %r of %r has invalid type %r, '
    539                           'must be a string or Tensor. (%s)'
--> 540                           % (subfetch, fetch, type(subfetch), str(e)))

TypeError: Fetch argument 1.4415792e+2 of 1.4415792e+2 has invalid type <type 'numpy.float32'>, must be a string or Tensor. (Can not convert a float32 into a Tensor or Operation.)

Je suis perplexe à ce stade. C'est peut-être un cas simple de conversion du type, mais je ne sais pas comment/où. Aussi, pourquoi la perte doit-elle être une chaîne? (En supposant que la même erreur apparaîtra également pour la précision, une fois cela corrigé).

Toute aide appréciée!

16
mshiv

Lorsque vous utilisez loss = sess.run(loss), vous redéfinissez dans python la variable loss.

La première fois, il fonctionnera bien. La deuxième fois, vous essaierez de faire:

sess.run(1.4415792e+2)

Parce que loss est maintenant un flottant.


Vous devez utiliser des noms différents comme:

loss_val, acc = sess.run([loss, accuracy], feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
60
Olivier Moindrot