web-dev-qa-db-fra.com

Différence entre Variable et get_variable dans TensorFlow

Autant que je sache, Variable est l'opération par défaut pour la création d'une variable, et get_variable est principalement utilisé pour le partage de poids.

D'un côté, certaines personnes suggèrent d'utiliser get_variable au lieu de l'opération primitive Variable chaque fois que vous avez besoin d'une variable. D'autre part, je ne vois qu'une utilisation de get_variable dans les documents officiels et les démos de TensorFlow.

Je souhaite donc connaître quelques règles de base sur la manière d’utiliser correctement ces deux mécanismes. Existe-t-il des principes "standard"?

99
Lifu Huang

Je recommanderais de toujours utiliser tf.get_variable(...) - il sera plus facile de refactoriser votre code si vous devez partager des variables à tout moment, par exemple dans un paramètre multi-gpu (voir l'exemple CIFAR multi-gpu). Il n'y a pas d'inconvénient à cela. 

tf.Variable pur est de niveau inférieur; tf.get_variable() n’existait pas à un moment donné, de sorte que certains codes utilisent toujours la méthode de bas niveau.

82
Lukasz Kaiser

tf.Variable est une classe et il existe plusieurs façons de créer tf.Variable, y compris tf.Variable .__ init__ et tf.get_variable. 

tf.Variable .__ init__: crée une nouvelle variable avec initial_value.

W = tf.Variable(<initial-value>, name=<optional-name>)

tf.get_variable: obtient une variable existante avec ces paramètres ou en crée une nouvelle. Vous pouvez également utiliser l'initialiseur.

W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
       regularizer=None, trainable=True, collections=None)

Il est très utile d'utiliser des initialiseurs tels que xavier_initializer:

W = tf.get_variable("W", shape=[784, 256],
       initializer=tf.contrib.layers.xavier_initializer())

Plus d'informations sur https://www.tensorflow.org/versions/r0.8/api_docs/python/state_ops.html#Variable .

61
Sung Kim

Je peux trouver deux différences principales entre l'un et l'autre:

  1. La première est que tf.Variable créera toujours une nouvelle variable, si tf.get_variable obtient du graphe une variable existante avec ces paramètres, et si elle n'existe pas, elle en crée une nouvelle.

  2. tf.Variable nécessite qu'une valeur initiale soit spécifiée.

Il est important de préciser que la fonction tf.get_variable préfixe le nom avec la portée de la variable actuelle pour effectuer des contrôles de réutilisation. Par exemple:

with tf.variable_scope("one"):
    a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
    b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse = True):
    c = tf.get_variable("v", [1]) #c.name == "one/v:0"

with tf.variable_scope("two"):
    d = tf.get_variable("v", [1]) #d.name == "two/v:0"
    e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"

assert(a is c)  #Assertion is true, they refer to the same object.
assert(a is d)  #AssertionError: they are different objects
assert(d is e)  #AssertionError: they are different objects

La dernière erreur d'assertion est intéressante: deux variables du même nom sous la même portée sont supposées être la même variable. Mais si vous testez les noms des variables d et e, vous réaliserez que Tensorflow a modifié le nom de la variable e:

d.name   #d.name == "two/v:0"
e.name   #e.name == "two/v_1:0"
38
Jadiel de Armas

Une autre différence réside dans le fait que l’un est dans la collection ('variable_store',) mais que l’autre ne l’est pas. 

S'il vous plaît voir la source code

def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store

Laissez-moi illustrer cela: 

import tensorflow as tf
from tensorflow.python.framework import ops

embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="Word_embeddings_1", dtype=tf.float32) 
embedding_2 = tf.get_variable("Word_embeddings_2", shape=[30522, 1024])

graph = tf.get_default_graph()
collections = graph.collections

for c in collections:
    stores = ops.get_collection(c)
    print('collection %s: ' % str(c))
    for k, store in enumerate(stores):
        try:
            print('\t%d: %s' % (k, str(store._vars)))
        except:
            print('\t%d: %s' % (k, str(store)))
    print('')

Le résultat: 

collection ('__variable_store',): 0: {'Word_embeddings_2': <tf.Variable 'Word_embeddings_2:0' shape=(30522, 1024) dtype=float32_ref>}

1
lerner