web-dev-qa-db-fra.com

Keras attention layer on LSTM

J'utilise keras 1.0.1 J'essaie d'ajouter une couche d'attention au-dessus d'un LSTM. C'est ce que j'ai jusqu'à présent, mais ça ne marche pas.

input_ = Input(shape=(input_length, input_dim))
lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_)
att = TimeDistributed(Dense(1)(lstm))
att = Reshape((-1, input_length))(att)
att = Activation(activation="softmax")(att)
att = RepeatVector(self.HID_DIM)(att)
merge = Merge([att, lstm], "mul")
hid = Merge("sum")(merge)

last = Dense(self.HID_DIM, activation="relu")(hid)

Le réseau doit appliquer un LSTM sur la séquence d'entrée. Ensuite, chaque état masqué du LSTM doit être entré dans une couche entièrement connectée, sur laquelle un Softmax est appliqué. Le softmax est répliqué pour chaque dimension cachée et multiplié par les états cachés LSTM élément par élément. Ensuite, le vecteur résultant doit être moyenné.

EDIT: Ceci compile, mais je ne suis pas sûr si ça fait ce que je pense qu'il devrait faire.

input_ = Input(shape=(input_length, input_dim))
lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_)
att = TimeDistributed(Dense(1))(lstm)
att = Flatten()(att)
att = Activation(activation="softmax")(att)
att = RepeatVector(self.HID_DIM)(att)
att = Permute((2,1))(att)
mer = merge([att, lstm], "mul")
hid = AveragePooling1D(pool_length=input_length)(mer)
hid = Flatten()(hid)
7
siamii

Ici est une implémentation de Attention LSTM avec Keras, et un exemple de son instanciation . Je n'ai pas essayé moi-même, cependant.

2
mossaab