web-dev-qa-db-fra.com

Qu'est-ce qui doit être à l'intérieur tf.distribute.strategy.scope ()?

Je joue actuellement avec des stratégies de distribution à Tensorflow 2.0 comme décrit ici https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/distribute/strategy

Je me demande ce qui doit aller à l'intérieur d'un bloc with ...scope() bloquer et ce qui est "facultatif".

Spécifiquement les opérations suivantes. Dois-je mettre ... à l'intérieur d'une with ...scope() pour la distribution au travail?:

  • Création d'optimiseur
  • Création du jeu de données
  • DataSet expérimental_distribute_dataset
  • apply_gradients appelle
  • Itération de jeu de données pour boucle
  • expérimental_run_v2

J'ai joué un peu et mon code semble fonctionner même lorsque j'utilise n ° with ...scope Du tout. Je suis confus si cela a des effets secondaires, je ne vois tout simplement pas en ce moment.

Code sans scope:

strat = tf.distribute.MirroredStrategy()

BATCH_SIZE_PER_REPLICA = 5

print('Replicas: ', strat.num_replicas_in_sync)

global_batch_size = (BATCH_SIZE_PER_REPLICA * strat.num_replicas_in_sync)

dataset = tf.data.Dataset.from_tensors(tf.random.normal([100])).repeat(1000).batch(
    global_batch_size)

g = Model('m', 10, 10, 1, 3)

dist_dataset = strat.experimental_distribute_dataset(dataset)

@tf.function
def train_step(dist_inputs):
  def step_fn(inputs):
    print([(v.name, v.device) for v in g.trainable_variables])
    return g(inputs)

  out = strat.experimental_run_v2(step_fn, args=(dist_inputs,))

for inputs in dist_dataset:
    train_step(inputs)
    break

Code avec portée:

strat = tf.distribute.MirroredStrategy()

BATCH_SIZE_PER_REPLICA = 5

print('Replicas: ', strat.num_replicas_in_sync)

global_batch_size = (BATCH_SIZE_PER_REPLICA * strat.num_replicas_in_sync)

with strat.scope():
    dataset = tf.data.Dataset.from_tensors(tf.random.normal([100])).repeat(1000).batch(
        global_batch_size)

    g = Model('m', 10, 10, 1, 3)

    dist_dataset = strat.experimental_distribute_dataset(dataset)

    @tf.function
    def train_step(dist_inputs):
        def step_fn(inputs):
            print([(v.name, v.device) for v in g.trainable_variables])
            return g(inputs)

        out = strat.experimental_run_v2(step_fn, args=(dist_inputs,))

    for inputs in dist_dataset:
        train_step(inputs)
        break

EDIT: Il semble que strat.experimental_run_v2 Pénètre automatiquement dans la portée de strat. Alors, pourquoi with strat.scope() existe?

11
dparted

Vous n'avez pas besoin d'appeler strat.scope().

experimental_run_v2 Est un moyen simple de mettre votre calcul dans strat.scope().

Voir le code source ci-dessous pour experimental_run_v2, Il enveloppe votre fn dans la portée de vous.

https://github.com/tensorflow/tensorflow/blob/919DFC3D066E72EE02BAA11FBF7B035D9944DAA9/Tensorflow/python/Distribute/distribute_lib.py#l729

0
Ruolin