web-dev-qa-db-fra.com

TensorFlow freeze_graph.py: Le nom 'save/Const: 0' fait référence à un tenseur qui n'existe pas

J'essaie actuellement d'exporter un modèle TensorFlow formé en tant que fichier ProtoBuf pour l'utiliser avec l'API TensorFlow C++ sur Android. Par conséquent, j'utilise le script freeze_graph.py .

J'ai exporté mon modèle en utilisant tf.train.write_graph:

tf.train.write_graph(graph_def, FLAGS.save_path, out_name, as_text=True)

et j'utilise un point de contrôle enregistré avec tf.train.Saver.

J'appelle freeze_graph.py comme décrit en haut du script. Après avoir compilé, je cours

bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=<path_to_protobuf_file> \
--input_checkpoint=<model_name>.ckpt-10000 \
--output_graph=<output_protobuf_file_path> \
--output_node_names=dropout/mul_1

Cela me donne le message d'erreur suivant:

TypeError: Cannot interpret feed_dict key as Tensor: The name 'save/Const:0' refers to a Tensor which does not exist. The operation, 'save/Const', does not exist in the graph.

Comme l’erreur indique, je n’ai pas de tenseur save/Const:0 dans mon modèle exporté. Cependant, le code freeze_graph.py indique que l’on peut spécifier ce nom de tenseur par le drapeau filename_tensor_name. Malheureusement, je ne trouve aucune information sur ce que ce tenseur devrait être et comment le régler correctement pour mon modèle.

Quelqu'un peut-il me dire comment produire un tenseur save/Const:0 dans mon modèle ProtoBuf exporté ou comment définir le drapeau filename_tensor_name correctement?

12
mackcmillion

L’indicateur --filename_tensor_name permet de spécifier le nom d’un tenseur d’espace réservé créé lors de la construction de tf.train.Saver pour votre modèle. *

Dans votre programme d'origine, vous pouvez imprimer la valeur de saver.saver_def.filename_tensor_name pour obtenir la valeur que vous devez transmettre pour cet indicateur. Vous pouvez également vouloir imprimer la valeur de saver.saver_def.restore_op_name pour obtenir une valeur pour le drapeau --restore_op_name (car je suppose que la valeur par défaut ne sera pas correcte pour votre graphique).

Sinon, le tampon de protocole tf.train.SaverDef inclut toutes les informations nécessaires à la reconstruction des informations pertinentes pour ces indicateurs. Si vous préférez, vous pouvez écrire saver.saver_def dans un fichier et transmettre le nom de ce fichier sous la forme de l'indicateur --input_saver à freeze_graph.py.


* L'étendue du nom par défaut pour tf.train.Saver est "save/" et l'espace réservé est en réalité un tf.constant() dont le nom par défaut est "Const:0", ce qui explique pourquoi l'indicateur par défaut est "save/Const:0".

6
mrry

J'ai remarqué qu'une erreur m'était arrivée quand j'avais arrangé le code comme ceci:

sess = tf.Session()
tf.train.write_graph(sess.graph_def, '', '/tmp/train.pbtxt')
init = tf.initialize_all_variables()
saver = tf.train.Saver()
sess.run(init)

Cela a fonctionné après avoir changé la disposition du code comme ceci:

# Add ops to save and restore all the variables.
saver = tf.train.Saver()    
init = tf.initialize_all_variables()
sess = tf.Session()
tf.train.write_graph(sess.graph_def, '', '/tmp/train.pbtxt')
sess.run(init)

Je ne sais pas trop pourquoi. @mrry pourriez-vous expliquer un peu plus?

2
Drag0