web-dev-qa-db-fra.com

Comment implémenter correctement un réseau LSTM à entrée par lots dans PyTorch?

Cette release de PyTorch semble fournir la variable PackedSequence pour des longueurs variables d’entrée pour un réseau neuronal récurrent. Cependant, j'ai trouvé qu'il était un peu difficile de l'utiliser correctement. 

En utilisant pad_packed_sequence pour récupérer une sortie d'une couche RNN alimentée par pack_padded_sequence, nous avons obtenu un T x B x N tenseur outputsT est le nombre maximal de pas de temps, B est la taille du lot et N est la taille cachée. J'ai trouvé que pour les séquences courtes dans le lot, la sortie suivante sera composée de zéros.

Voici mes questions.

  1. Pour une tâche de sortie unique dans laquelle l'un aurait besoin de la dernière sortie de toutes les séquences, le simple outputs[-1] donnera un résultat erroné, car ce tenseur contient beaucoup de zéros pour les séquences courtes. Il faudra construire des index par longueur de séquence pour extraire la dernière sortie individuelle de toutes les séquences. Y a-t-il un moyen plus simple de le faire?
  2. Pour une tâche de sortie multiple (par exemple, seq2seq), on ajoute généralement un calque linéaire N x O, remodèle le lot T x B x O en TB x O et calcule la perte d'entropie croisée avec les vraies cibles TB (généralement des entiers dans le modèle de langage). Dans cette situation, ces zéros dans la sortie par lots sont-ils importants?
13
Edityouprofile

Question 1 - Dernière heure

C'est le code que j'utilise pour obtenir la sortie du dernier timestep. Je ne sais pas s'il existe une solution plus simple. Si c'est le cas, j'aimerais le savoir. J'ai suivi cette discussion et ai saisi l'extrait de code relatif pour ma méthode last_timestep. Ceci est mon avant.

class BaselineRNN(nn.Module):
    def __init__(self, **kwargs):
        ...

    def last_timestep(self, unpacked, lengths):
        # Index of the last output for each sequence.
        idx = (lengths - 1).view(-1, 1).expand(unpacked.size(0),
                                               unpacked.size(2)).unsqueeze(1)
        return unpacked.gather(1, idx).squeeze()

    def forward(self, x, lengths):
        embs = self.embedding(x)

        # pack the batch
        packed = pack_padded_sequence(embs, list(lengths.data),
                                      batch_first=True)

        out_packed, (h, c) = self.rnn(packed)

        out_unpacked, _ = pad_packed_sequence(out_packed, batch_first=True)

        # get the outputs from the last *non-masked* timestep for each sentence
        last_outputs = self.last_timestep(out_unpacked, lengths)

        # project to the classes using a linear layer
        logits = self.linear(last_outputs)

        return logits

Question 2 - Perte d'entropie croisée masquée

Oui, par défaut, les timesteps à zéro remplissage (cibles) sont importants. Cependant, il est très facile de les masquer. Vous avez deux options, selon la version de PyTorch que vous utilisez.

  1. PyTorch 0.2.0 : pytorch prend désormais en charge le masquage directement dans CrossEntropyLoss , avec l'argument ignore_index. Par exemple, dans la modélisation du langage ou seq2seq, où j'ajoute un remplissage à zéro, je masque les mots remplis à zéro (cible) simplement comme ceci: 

    loss_function = nn.CrossEntropyLoss (ignore_index = 0)

  2. PyTorch 0.1.12 et versions antérieures: dans les versions antérieures de PyTorch, le masquage n'était pas pris en charge. Vous deviez donc implémenter votre propre solution de contournement. La solution que j’ai utilisée était masked_cross_entropy.py , par jihunchoi . Vous pouvez également être intéressé par cette discussion

8
Christos Baziotis

Il y a quelques jours, j'ai trouvé cette méthode qui utilise l'indexation pour accomplir la même tâche avec une ligne.

J'ai d'abord mon lot de données ([batch size, sequence length, features]), donc pour moi:

unpacked_out = unpacked_out[np.arange(unpacked_out.shape[0]), lengths - 1, :]

unpacked_out est la sortie de torch.nn.utils.rnn.pad_packed_sequence.

Je l'ai comparée à la méthode décrite ici , qui ressemble à la méthode last_timestep() que Christos Baziotis utilise ci-dessus (également recommandé ici ) et les résultats sont les mêmes dans mon cas.

0
n8henrie