web-dev-qa-db-fra.com

Pourquoi avons-nous besoin d'appeler zero_grad () dans PyTorch?

La méthode zero_grad() doit être appelée pendant la formation. Mais le documentation n'est pas très utile

|  zero_grad(self)
|      Sets gradients of all model parameters to zero.

Pourquoi avons-nous besoin d'appeler cette méthode?

27
user1424739

Dans PyTorch , nous devons définir les gradients sur zéro avant de commencer à effectuer une backpropragation car PyTorch accumule les gradients lors des passages en arrière ultérieurs. Ceci est pratique lors de la formation RNN. Ainsi, l’action par défaut consiste à accumuler (c'est-à-dire résumer) les gradients à chaque appel loss.backward().

Pour cette raison, lorsque vous commencez votre boucle d’entraînement, vous devriez idéalement zero out the gradients afin que le paramètre soit mis à jour correctement. Sinon, le gradient indiquerait une direction autre que la direction prévue vers minimum (ou maximum, dans le cas d'objectifs de maximisation).

Voici un exemple simple:

import torch
from torch.autograd import Variable
import torch.optim as optim

def linear_model(x, W, b):
    return torch.matmul(x, W) + b

data, targets = ...

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

optimizer = optim.Adam([W, b])

for sample, target in Zip(data, targets):
    # clear out the gradients of all Variables 
    # in this optimizer (i.e. W, b)
    optimizer.zero_grad()
    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()
    optimizer.step()

Alternativement, si vous faites un descente de gradient vanille, alors:

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

for sample, target in Zip(data, targets):
    # clear out the gradients of Variables 
    # (i.e. W, b)
    W.grad.data.zero_()
    b.grad.data.zero_()

    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()

    W -= learning_rate * W.grad.data
    b -= learning_rate * b.grad.data

Note : Le accumulation (c'est-à-dire somme) des gradients se produisent lorsque - .backward() est appelée sur le tenseur loss .

43
kmario23