web-dev-qa-db-fra.com

Pytorch - RuntimeError: Essayer de reculer dans le graphique une deuxième fois, mais les tampons ont déjà été libérés

Je continue à rencontrer cette erreur:

RuntimeError: Essayer de reculer dans le graphique une deuxième fois, mais les tampons ont déjà été libérés. Spécifiez retention_graph = True lors de votre premier appel en arrière.

J'avais effectué une recherche dans le forum Pytorch, mais je ne parviens toujours pas à découvrir ce que j'ai fait de mal dans ma fonction de perte personnalisée. Mon modèle est nn.GRU, et voici ma fonction de perte personnalisée:

def _loss(outputs, session, items):  # `items` is a dict() contains embedding of all items
    def f(output, target):
        pos = torch.from_numpy(np.array([items[target["click"]]])).float()
        neg = torch.from_numpy(np.array([items[idx] for idx in target["suggest_list"] if idx != target["click"]])).float()
        if USE_CUDA:
            pos, neg = pos.cuda(), neg.cuda()
        pos, neg = Variable(pos), Variable(neg)

        pos = F.cosine_similarity(output, pos)
        if neg.size()[0] == 0:
            return torch.mean(F.logsigmoid(pos))
        neg = F.cosine_similarity(output.expand_as(neg), neg)

        return torch.mean(F.logsigmoid(pos - neg))

    loss = map(f, outputs, session)
return -torch.mean(torch.cat(loss))

Code de formation:

    # zero the parameter gradients
    model.zero_grad()

    # forward + backward + optimize
    outputs, hidden = model(inputs, hidden)
    loss = _loss(outputs, session, items)
    acc_loss += loss.data[0]

    loss.backward()
    # Add parameters' gradients to their values, multiplied by learning rate
    for p in model.parameters():
        p.data.add_(-learning_rate, p.grad.data)
14
Viet Phan

Le problème vient de ma boucle d'entraînement: il ne détache ni ne reconditionne l'état caché entre les lots? Si c'est le cas, alors loss.backward() essaie de se propager en arrière jusqu'au début du temps, ce qui fonctionne pour le premier lot mais pas pour le second car le graphique du premier lot a été ignoré.

il y a deux solutions possibles.

1) détacher/reconditionner l'état caché entre les lots. Il y a (au moins) trois façons de procéder (et j'ai choisi cette solution):

 hidden.detach_()
 hidden = hidden.detach()

2) remplacez loss.backward () par loss.backward(retain_graph=True) mais sachez que chaque lot successif prendra plus de temps que le précédent car il devra se propager en retour jusqu'au début du premier lot.

Exemple

15
Viet Phan