web-dev-qa-db-fra.com

Enregistrer et charger l'état de l'optimiseur de modèle

J'ai un ensemble de modèles assez compliqués que je m'entraîne et je cherche un moyen de sauvegarder et de charger les états de l'optimiseur de modèle. Les "modèles de formateur" consistent en différentes combinaisons de plusieurs autres "modèles de poids", dont certains ont des poids partagés, certains ont des poids figés selon le formateur, etc. C'est un exemple un peu trop compliqué à partager, mais en bref , Je ne peux pas utiliser model.save('model_file.h5') et keras.models.load_model('model_file.h5') lors de l'arrêt et du démarrage de mon entraînement.

L'utilisation de model.load_weights('weight_file.h5') fonctionne très bien pour tester mon modèle si la formation est terminée, mais si j'essaie de continuer à former le modèle en utilisant cette méthode, la perte ne revient même pas près de revenir à son dernier emplacement. J'ai lu que c'est parce que l'état de l'optimiseur n'est pas enregistré en utilisant cette méthode qui a du sens. Cependant, j'ai besoin d'une méthode pour enregistrer et charger les états des optimiseurs de mes modèles d'entraînement. Il semble que les keras aient déjà eu une model.optimizer.get_sate() et model.optimizer.set_sate() qui accomplirait ce que je recherche, mais cela ne semble plus être le cas (du moins pour l'optimiseur Adam). Existe-t-il d'autres solutions avec le Keras actuel?

9
Starnetter

Vous pouvez extraire les lignes importantes du load_model et save_model les fonctions.

Pour enregistrer les états de l'optimiseur, dans save_model:

# Save optimizer weights.
symbolic_weights = getattr(model.optimizer, 'weights')
if symbolic_weights:
    optimizer_weights_group = f.create_group('optimizer_weights')
    weight_values = K.batch_get_value(symbolic_weights)

Pour charger les états de l'optimiseur, dans load_model:

# Set optimizer weights.
if 'optimizer_weights' in f:
    # Build train function (to get weight updates).
    if isinstance(model, Sequential):
        model.model._make_train_function()
    else:
        model._make_train_function()

    # ...

    try:
        model.optimizer.set_weights(optimizer_weight_values)

En combinant les lignes ci-dessus, voici un exemple:

  1. Ajustez d'abord le modèle pour 5 époques.
X, y = np.random.Rand(100, 50), np.random.randint(2, size=100)
x = Input((50,))
out = Dense(1, activation='sigmoid')(x)
model = Model(x, out)
model.compile(optimizer='adam', loss='binary_crossentropy')
model.fit(X, y, epochs=5)

Epoch 1/5
100/100 [==============================] - 0s 4ms/step - loss: 0.7716
Epoch 2/5
100/100 [==============================] - 0s 64us/step - loss: 0.7678
Epoch 3/5
100/100 [==============================] - 0s 82us/step - loss: 0.7665
Epoch 4/5
100/100 [==============================] - 0s 56us/step - loss: 0.7647
Epoch 5/5
100/100 [==============================] - 0s 76us/step - loss: 0.7638
  1. Enregistrez maintenant les poids et les états de l'optimiseur.
model.save_weights('weights.h5')
symbolic_weights = getattr(model.optimizer, 'weights')
weight_values = K.batch_get_value(symbolic_weights)
with open('optimizer.pkl', 'wb') as f:
    pickle.dump(weight_values, f)
  1. Reconstruisez le modèle dans une autre session python et chargez les poids.
x = Input((50,))
out = Dense(1, activation='sigmoid')(x)
model = Model(x, out)
model.compile(optimizer='adam', loss='binary_crossentropy')

model.load_weights('weights.h5')
model._make_train_function()
with open('optimizer.pkl', 'rb') as f:
    weight_values = pickle.load(f)
model.optimizer.set_weights(weight_values)
  1. Continuez la formation du modèle.
model.fit(X, y, epochs=5)

Epoch 1/5
100/100 [==============================] - 0s 674us/step - loss: 0.7629
Epoch 2/5
100/100 [==============================] - 0s 49us/step - loss: 0.7617
Epoch 3/5
100/100 [==============================] - 0s 49us/step - loss: 0.7611
Epoch 4/5
100/100 [==============================] - 0s 55us/step - loss: 0.7601
Epoch 5/5
100/100 [==============================] - 0s 49us/step - loss: 0.7594
14
Yu-Yang

la mise à niveau de Keras vers 2.2.4 et l'utilisation de cornichons ont résolu ce problème pour moi. avec la version 2.2.3 de keras, les modèles Keras peuvent désormais être décapés en toute sécurité.

2
ismail