web-dev-qa-db-fra.com

Tf 2.0: RuntimeError: GradientTape.gradient ne peut être appelé qu'une seule fois sur des bandes non persistantes

Dans tf 2.0 DC Gan dans guide tensorflow 2. , il y a deux bandes de dégradé. Voir ci-dessous.

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(Zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(Zip(gradients_of_discriminator, discriminator.trainable_variables))

Comme vous pouvez le voir clairement, il existe deux bandes de dégradé. Je me demandais quelle différence faisait l'utilisation d'une seule bande et je l'ai changé comme suit

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(Zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(Zip(gradients_of_discriminator, discriminator.trainable_variables))

Cela me donne l'erreur suivante:

RuntimeError: GradientTape.gradient can only be called once on non-persistent tapes.

Je voudrais savoir pourquoi deux bandes sont nécessaires. Pour l'instant, la documentation sur les API tf2.0 est insuffisante. Quelqu'un peut-il m'expliquer ou me diriger vers les bons documents/tutoriels?

10
Himaprasoon

De la documentation de GradientTape:

Par défaut, les ressources détenues par un GradientTape sont libérées dès que la méthode GradientTape.gradient () est appelée. Pour calculer plusieurs dégradés sur le même calcul, créez une bande de dégradé persistante. Cela permet plusieurs appels à la méthode gradient () lorsque des ressources sont libérées lorsque l'objet bande est récupéré.

5
Sparky05

La raison technique est que gradient est appelé deux fois, ce qui n'est pas autorisé sur les bandes (non persistantes).

Dans le cas présent, cependant, la raison sous-jacente est que l'apprentissage du GANS se fait généralement en alternant l'optimisation du générateur et du discriminateur. Chaque optimisation a son propre optimiseur qui fonctionne généralement sur différentes variables, et de nos jours même la perte qui est minimisée est différente (gen_loss et disc_loss dans votre code).

Vous vous retrouvez donc avec deux gradients parce que la formation des GAN consiste essentiellement à optimiser deux problèmes différents (contradictoires) de manière alternée.

0
P-Gn