web-dev-qa-db-fra.com

TensorFlow enregistre / charge un graphique à partir d'un fichier

D'après ce que j'ai recueilli jusqu'à présent, il existe différentes manières de transférer un graphique TensorFlow dans un fichier, puis de le charger dans un autre programme, mais je n'ai pas été en mesure de trouver des exemples/informations clairs sur leur fonctionnement. Ce que je sais déjà, c'est ceci:

  1. Enregistrez les variables du modèle dans un fichier de point de contrôle (.ckpt) à l'aide d'une tf.train.Saver() et restaurez-les plus tard ( source )
  2. Enregistrez un modèle dans un fichier .pb et rechargez-le à l’aide de tf.train.write_graph() et tf.import_graph_def() ( source ).
  3. Chargez un modèle à partir d'un fichier .pb, renouvelez-le et transférez-le dans un nouveau fichier .pb à l'aide de Bazel ( source )
  4. Geler le graphique pour enregistrer le graphique et les poids ensemble ( source )
  5. Utilisez as_graph_def() pour enregistrer le modèle et pour les pondérations/variables, mappez-les en constantes ( source )

Cependant, je n'ai pas pu éclaircir plusieurs questions concernant ces différentes méthodes:

  1. S'agissant des fichiers de point de contrôle, enregistrent-ils uniquement les poids formés d'un modèle? Les fichiers de point de contrôle peuvent-ils être chargés dans un nouveau programme et utilisés pour exécuter le modèle ou servent-ils simplement à enregistrer les poids dans un modèle à une heure/étape donnée?
  2. En ce qui concerne tf.train.write_graph(), les poids/variables sont-ils également enregistrés?
  3. En ce qui concerne Bazel, peut-il uniquement enregistrer dans/charger à partir des fichiers .pb pour le recyclage? Y at-il une simple commande Bazel juste pour vider un graphique dans un fichier .pb?
  4. En ce qui concerne le gel, un graphique gelé peut-il être chargé avec tf.import_graph_def()?
  5. La démo Android de TensorFlow est chargée dans le modèle Inception de Google à partir d'un fichier .pb. Si je voulais substituer mon propre fichier .pb, comment pourrais-je m'y prendre? Aurais-je besoin de changer n'importe quel code/méthodes natifs?
  6. En général, quelle est exactement la différence entre toutes ces méthodes? Ou plus généralement, quelle est la différence entre as_graph_def() /. Ckpt/.pb?

En bref, ce que je recherche, c’est une méthode pour enregistrer à la fois un graphique (comme dans, les différentes opérations et autres) et ses pondérations/variables dans un fichier, qui peut ensuite être utilisé pour charger le graphique et les pondérations dans un autre programme. , pour utilisation (pas nécessairement formation continue/recyclage).

La documentation sur ce sujet n'étant pas très simple, toute réponse ou information serait grandement appréciée.

87
Technicolor

Il existe de nombreuses façons d’aborder le problème de la sauvegarde d’un modèle dans TensorFlow, ce qui peut le rendre un peu déroutant. Prenant chacune de vos sous-questions à tour de rôle:

  1. Les fichiers de point de contrôle (produits par exemple en appelant saver.save() sur un tf.train.Saver objet) ne contient que les poids et toutes les autres variables définies dans le même programme. Pour les utiliser dans un autre programme, vous devez recréer la structure graphique associée (par exemple, en exécutant du code pour le reconstruire, ou en appelant tf.import_graph_def() ), qui indique à TensorFlow quoi faire avec ces poids. Notez que l'appel de saver.save() génère également un fichier contenant un MetaGraphDef , qui contient un graphique et des détails sur la manière d'associer les poids d'un point de contrôle à ce graphique. Voir le tutoriel pour plus de détails.

  2. tf.train.write_graph() écrit uniquement la structure du graphe; pas les poids.

  3. Bazel n'a aucun lien avec la lecture ou l'écriture de graphiques TensorFlow. (Je comprends peut-être mal votre question: n'hésitez pas à la clarifier dans un commentaire.)

  4. Un graphique figé peut être chargé avec tf.import_graph_def() . Dans ce cas, les poids sont (généralement) incorporés dans le graphique, vous n'avez donc pas besoin de charger un point de contrôle séparé.

  5. Le principal changement consisterait à mettre à jour les noms du ou des tenseurs introduits dans le modèle, ainsi que les noms du ou des tenseurs extraits du modèle. Dans la démo TensorFlow Android, cela correspondrait aux chaînes inputName et outputName qui sont transmises à TensorFlowClassifier.initializeTensorFlow() .

  6. GraphDef est la structure du programme, qui ne change généralement pas au cours du processus de formation. Le point de contrôle est un instantané de l'état d'un processus de formation, qui change généralement à chaque étape du processus de formation. En conséquence, TensorFlow utilise différents formats de stockage pour ces types de données, et l'API de bas niveau fournit différentes manières de les enregistrer et de les charger. Bibliothèques de niveau supérieur, telles que les bibliothèques MetaGraphDef , Keras et skflow s'appuie sur ces mécanismes pour fournir des moyens plus pratiques de sauvegarder et de restaurer un modèle entier.

74
mrry

Vous pouvez essayer le code suivant:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
1