web-dev-qa-db-fra.com

Quelle est la différence entre "caché" et "sortie" dans PyTorch LSTM?

J'ai de la difficulté à comprendre la documentation du module LSTM de PyTorch (ainsi que RNN et GRU, qui sont similaires). En ce qui concerne les résultats, il est dit:

Sorties: sortie, (h_n, c_n)

  • output (seq_len, batch, hidden_size * num_directions): tenseur contenant les entités en sortie (h_t) de la dernière couche du RNN, pour chaque t. Si une torch.nn.utils.rnn.PackedSequence a été donnée en entrée, la sortie sera également une séquence compactée.
  • h_n (num_layers * num_directions, batch, hidden_size): tenseur contenant l'état masqué pour t = seq_len
  • c_n (num_layers * num_directions, batch, hidden_size): tenseur contenant l'état de la cellule pour t = seq_len

Il semble que les variables output et h_n donnent toutes les deux les valeurs de l'état masqué. h_n fournit-il de manière redondante le dernier pas de temps déjà inclus dans output, ou y a-t-il autre chose que cela?

24
Nathaniel

J'ai fait un diagramme. Les noms suivent les PyTorch docs , bien que j'ai renommé num_layers en w.

output comprend tous les états cachés dans la dernière couche ("dernière" en profondeur, et non en temps). (h_n, c_n) comprend les états cachés après le dernier pas temporel, t = n, afin que vous puissiez éventuellement les insérer dans un autre LSTM.

 LSTM diagram

La dimension du lot n'est pas incluse.

47
nnnmmm

L'état de sortie est le tenseur de tous les états cachés de chaque pas de temps dans le RNN (LSTM), et l'état caché renvoyé par le RNN (LSTM) est le dernier état caché depuis le dernier pas de temps de la séquence d'entrée. Vous pouvez vérifier cela en collectant tous les états cachés de chaque étape et en le comparant à l'état de sortie (à condition que vous n'utilisiez pas pack_padded_sequence).

0
Jibin Mathew

Cela dépend vraiment du modèle que vous utilisez et de la manière dont vous interpréterez le modèle. La sortie peut être:

  • un seul état caché de cellule LSTM
  • plusieurs états cachés de cellules LSTM
  • toutes les sorties d'états cachés

La sortie, n’est presque jamais interprétée directement. Si l'entrée est codée, il devrait y avoir une couche softmax pour décoder les résultats.

Remarque: dans la modélisation du langage, les états cachés permettent de définir la probabilité du prochain mot, p (wt + 1| w1, ..., wt) = softmax (Wht+ b).

0
prosti