web-dev-qa-db-fra.com

Est-il possible de rendre une variable entraînable non formable?

J'ai créé une variable trainable dans une portée. Plus tard, j'ai entré la même portée, défini la portée sur reuse_variables, et utilisé get_variable pour récupérer la même variable. Cependant, je ne peux pas définir la propriété trainable de la variable sur False. Ma get_variable la ligne est comme:

weight_var = tf.get_variable('weights', trainable = False)

Mais la variable 'weights' est toujours dans la sortie de tf.trainable_variables.

Puis-je définir l'indicateur trainable d'une variable partagée sur False en utilisant get_variable?

La raison pour laquelle je veux le faire est que j'essaie de réutiliser les filtres de bas niveau pré-formés à partir de VGG net dans mon modèle, et je veux construire le graphique comme avant, récupérer la variable de poids et attribuer des valeurs de filtre VGG à la variable de poids, puis maintenez-les fixes pendant l'étape de formation suivante.

33
Wei Liu

Après avoir regardé la documentation et le code, j'étais pas en mesure de trouver un moyen de supprimer une variable du TRAINABLE_VARIABLES.

Voici ce qui se passe:

  • La première fois que tf.get_variable('weights', trainable=True) est appelée, la variable est ajoutée à la liste de TRAINABLE_VARIABLES.
  • La deuxième fois que vous appelez tf.get_variable('weights', trainable=False), vous obtenez la même variable mais l'argument trainable=False N'a aucun effet car la variable est déjà présente dans la liste de TRAINABLE_VARIABLES (Et il y a aucun moyen de le supprimer de là)

Première solution

Lorsque vous appelez la méthode minimize de l'optimiseur (voir doc. ), vous pouvez passer un var_list=[...] Comme argument avec les variables que vous souhaitez optimiser.

Par exemple, si vous souhaitez figer toutes les couches de VGG à l'exception des deux dernières, vous pouvez transmettre les poids des deux dernières couches dans var_list.

Deuxième solution

Vous pouvez utiliser une tf.train.Saver() pour enregistrer les variables et les restaurer plus tard (voir ce tutoriel ).

  • Tout d'abord, vous entraînez l'ensemble de votre modèle VGG avec toutes les variables entraînables . Vous les enregistrez dans un fichier de point de contrôle en appelant saver.save(sess, "/path/to/dir/model.ckpt").
  • Ensuite (dans un autre fichier) vous entraînez la deuxième version avec des variables non entraînables . Vous chargez les variables précédemment stockées avec saver.restore(sess, "/path/to/dir/model.ckpt").

Facultativement, vous pouvez décider de ne sauvegarder que certaines des variables dans votre fichier de point de contrôle. Voir doc pour plus d'informations.

28
Olivier Moindrot

Lorsque vous souhaitez former ou optimiser uniquement certaines couches d'un réseau pré-formé, c'est ce que vous devez savoir.

La méthode minimize de TensorFlow prend un argument facultatif var_list, une liste de variables à ajuster par rétropropagation.

Si vous ne spécifiez pas var_list, toute variable TF du graphique peut être ajustée par l'optimiseur. Lorsque vous spécifiez certaines variables dans var_list, TF maintient toutes les autres variables constantes.

Voici un exemple de script que jonbruner et son collaborateur ont utilisé.

tvars = tf.trainable_variables()
g_vars = [var for var in tvars if 'g_' in var.name]
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)

Cela recherche toutes les variables définies précédemment qui ont "g_" dans le nom de la variable, les place dans une liste et exécute l'optimiseur ADAM sur elles.

Vous pouvez trouver les réponses correspondantes ici sur Quora

10
rocksyne

Afin de supprimer une variable de la liste des variables entraînables, vous pouvez d'abord accéder à la collection via: trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES) Là, trainable_collection Contient une référence à la collection de variables entraînables. Si vous pop des éléments de cette liste, en faisant par exemple trainable_collection.pop(0), vous allez supprimer la variable correspondante des variables entraînables, et donc cette variable ne sera pas entraînée.

Bien que cela fonctionne avec pop, j'ai encore du mal à trouver un moyen d'utiliser correctement remove avec l'argument correct, donc nous ne dépendons pas de l'index des variables.

EDIT: Étant donné que vous avez le nom des variables dans le graphique (vous pouvez l'obtenir en inspectant le protobuf du graphique ou, ce qui est plus facile, en utilisant Tensorboard), vous pouvez l'utiliser pour parcourir la liste des variables entraînables, puis supprimez les variables de la collection entraînable. Exemple: disons que je veux que les variables avec les noms "batch_normalization/gamma:0" Et "batch_normalization/beta:0" PAS soient entraînées, mais elles sont déjà ajoutées à la collection TRAINABLE_VARIABLES. Ce que je peux faire, c'est: `

#gets a reference to the list containing the trainable variables
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
variables_to_remove = list()
for vari in trainable_collection:
    #uses the attribute 'name' of the variable
    if vari.name=="batch_normalization/gamma:0" or vari.name=="batch_normalization/beta:0":
        variables_to_remove.append(vari)
for rem in variables_to_remove:
    trainable_collection.remove(rem)

`Cela supprimera avec succès les deux variables de la collection, et elles ne seront plus entraînées.

6
Elisio Quintino

Vous pouvez utiliser tf.get_collection_ref pour obtenir la référence de la collection plutôt que tf.get_collection

0
Yuki