web-dev-qa-db-fra.com

L'état initial de RNN est-il réinitialisé pour les mini-lots suivants?

Quelqu'un pourrait-il préciser si l'état initial du RNN dans TF est réinitialisé pour les mini-lots suivants, ou si le dernier état du mini-lot précédent est utilisé comme mentionné dans Ilya Sutskever et al., ICLR 2015 =?

17
VM_AI

Les opérations tf.nn.dynamic_rnn() ou tf.nn.rnn() permettent de spécifier l'état initial du RNN à l'aide du paramètre initial_state. Si vous ne spécifiez pas ce paramètre, les états masqués seront initialisés à zéro vecteurs au début de chaque lot d'apprentissage.

Dans TensorFlow, vous pouvez encapsuler les tenseurs dans tf.Variable() pour conserver leurs valeurs dans le graphique entre plusieurs exécutions de session. Assurez-vous simplement de les marquer comme non entraînables car les optimiseurs règlent toutes les variables entraînables par défaut.

data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))

cell = tf.nn.rnn_cell.GRUCell(256)
state = tf.Variable(cell.zero_states(batch_size, tf.float32), trainable=False)
output, new_state = tf.nn.dynamic_rnn(cell, data, initial_state=state)

with tf.control_dependencies([state.assign(new_state)]):
    output = tf.identity(output)

sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(output, {data: ...})

Je n'ai pas testé ce code mais il devrait vous donner un indice dans la bonne direction. Il y a aussi un tf.nn.state_saving_rnn() auquel vous pouvez fournir un objet économiseur d'état, mais je ne l'ai pas encore utilisé.

20
danijar

En plus de la réponse de danijar, voici le code d'un LSTM, dont l'état est un Tuple (state_is_Tuple=True). Il prend également en charge plusieurs couches.

Nous définissons deux fonctions - une pour obtenir les variables d'état avec un état zéro initial et une fonction pour renvoyer une opération, que nous pouvons transmettre à session.run afin de mettre à jour les variables d'état avec le dernier état caché du LSTM.

def get_state_variables(batch_size, cell):
    # For each layer, get the initial state and make a variable out of it
    # to enable updating its value.
    state_variables = []
    for state_c, state_h in cell.zero_state(batch_size, tf.float32):
        state_variables.append(tf.contrib.rnn.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a Tuple, so that it can be fed to dynamic_rnn as an initial state
    return Tuple(state_variables)


def get_state_update_op(state_variables, new_states):
    # Add an operation to update the train states with the last state tensors
    update_ops = []
    for state_variable, new_state in Zip(state_variables, new_states):
        # Assign the new state to the state variables on this layer
        update_ops.extend([state_variable[0].assign(new_state[0]),
                           state_variable[1].assign(new_state[1])])
    # Return a Tuple in order to combine all update_ops into a single operation.
    # The Tuple's actual value should not be used.
    return tf.Tuple(update_ops)

Semblable à la réponse de danijar, nous pouvons l'utiliser pour mettre à jour l'état du LSTM après chaque lot:

data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
cells = [tf.contrib.rnn.GRUCell(256) for _ in range(num_layers)]
cell = tf.contrib.rnn.MultiRNNCell(cells)

# For each layer, get the initial state. states will be a Tuple of LSTMStateTuples.
states = get_state_variables(batch_size, cell)

# Unroll the LSTM
outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)

# Add an operation to update the train states with the last state tensors.
update_op = get_state_update_op(states, new_states)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([outputs, update_op], {data: ...})

La principale différence est que state_is_Tuple=True fait de l'état du LSTM un LSTMStateTuple contenant deux variables (état de cellule et état caché) au lieu d'une seule variable. L'utilisation de plusieurs couches fait alors de l'état du LSTM un tuple de LSTMStateTuples - un par couche.

9
Kilian Batzner