web-dev-qa-db-fra.com

TensorFlow, pourquoi il y a 3 fichiers après avoir enregistré le modèle?

Après avoir lu les docs , j’ai sauvegardé un modèle dans TensorFlow, voici mon code de démonstration:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

mais après cela, j'ai trouvé il y a 3 fichiers

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

Et je ne peux pas restaurer le modèle en restaurant le fichier model.ckpt, car ce type de fichier n'existe pas. Voici mon code

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

Alors, pourquoi il y a 3 fichiers?

83
GoingMyWay

Essaye ça:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")

La méthode de sauvegarde TensorFlow enregistre trois types de fichiers car elle stocke les structure graphique séparément de la valeurs variables. Le fichier .meta décrit la structure du graphe enregistrée. Vous devez donc l'importer avant de restaurer le point de contrôle (sinon, il ne sait pas à quelles variables correspondent les valeurs de point de contrôle enregistrées).

Alternativement, vous pouvez faire ceci:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")

Même s'il n'y a pas de fichier nommé model.ckpt, vous vous référez quand même au point de contrôle enregistré sous ce nom lors de la restauration. À partir du saver.py code source

Les utilisateurs doivent uniquement interagir avec le préfixe spécifié par l'utilisateur ... à la place de tout chemin physique.

85
T.K. Bartel
  • meta file : décrit la structure de graphe enregistrée, comprend GraphDef, SaverDef, etc. puis appliquez tf.train.import_meta_graph('/tmp/model.ckpt.meta'), restaurera Saver et Graph.

  • fichier index : c'est une table chaîne-chaîne immuable (tensorflow :: table :: Table). Chaque clé est le nom d'un tenseur et sa valeur est un BundleEntryProto sérialisé. Chaque BundleEntryProto décrit les métadonnées d'un tenseur: lequel des fichiers "données" contient le contenu d'un tenseur, le décalage dans ce fichier, la somme de contrôle, certaines données auxiliaires, etc. 

  • fichier de données : il s'agit de la collection TensorBundle, enregistrez les valeurs de toutes les variables. 

41
Guangcong Liu

Je suis en train de restaurer des implémentations Word formées à partir de Word2Vec tutoriel.

Si vous avez créé plusieurs points de contrôle:

par exemple. les fichiers créés ressemblent à ceci

model.ckpt-55695.data-00000-of-00001

model.ckpt-55695.index

model.ckpt-55695.meta

essaye ça

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')

en appelant restore_session ():

def test_Word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")
2
Steven Wong

Si vous avez formé un CNN avec des abandons, par exemple, vous pourriez faire ceci:

def predict(image, model_name):
    """
    image -> single image, (width, height, channels)
    model_name -> model file that was saved without any extensions
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./' + model_name + '.meta')
        saver.restore(sess, './' + model_name)
        # Substitute 'logits' with your model
        prediction = tf.argmax(logits, 1)
        # 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
        return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})
0
Sashank Aryal