web-dev-qa-db-fra.com

Charge de modèle extrêmement lente avec keras

J'ai un ensemble de modèles Keras (30) que j'ai formés et enregistrés en utilisant:

 model.save('model{0}.h5'.format(n_model))

Lorsque j'essaie de les charger, en utilisant load_model, le temps requis pour chaque modèle est assez grand et incrémental. Le chargement se fait comme:

models = {}
for i in range(30):
    start = time.time()
    models[i] = load_model('model{0}.h5'.format(ix)) 
    end = time.time()
    print "Model {0}: seconds {1}".format(ix, end - start)

Et la sortie est:

...
Model 9: seconds 7.38966012001
Model 10: seconds 9.99283003807
Model 11: seconds 9.7262301445
Model 12: seconds 9.17000102997
Model 13: seconds 10.1657290459
Model 14: seconds 12.5914049149
Model 15: seconds 11.652477026
Model 16: seconds 12.0126030445
Model 17: seconds 14.3402299881
Model 18: seconds 14.3761711121
...

Chaque modèle est vraiment simple: 2 couches cachées avec 10 neurones chacune (taille ~ 50Kb). Pourquoi le chargement prend-il autant et pourquoi le temps augmente-t-il? Suis-je en train de manquer quelque chose (par exemple une fonction de fermeture pour le modèle?)

SOLUTION

J'ai découvert que pour accélérer le chargement du modèle il vaut mieux stocker la structure des réseaux et les poids dans deux fichiers distincts: La partie sauvegarde:

model.save_weights('model.h5')
model_json = model.to_json()
with open('model.json', "w") as json_file:
    json_file.write(model_json)
json_file.close()

La partie chargement:

from keras.models import model_from_json
json_file = open("model.json", 'r')
loaded_model_json = json_file.read()
json_file.close()
model = model_from_json(loaded_model_json)
model.load_weights("model.h5")
17
Titus Pullo

J'ai résolu le problème en effaçant la session de keras avant chaque chargement

from keras import backend as K
for i in range(...):
  K.clear_session()
  model = load_model(...)
10
GearLux

J'ai essayé avec K.clear_session(), et cela augmente le temps de chargement à chaque fois.
Cependant, mes modèles chargés de cette manière ne peuvent pas utiliser la fonction model.predict En raison de l'erreur suivante:
ValueError: Tensor Tensor("Sigmoid_2:0", shape=(?, 17), dtype=float32) is not an element of this graph.
Github # 2397 fournit une discussion détaillée à ce sujet. La meilleure solution pour l'instant est de prédire les données juste après le chargement du modèle, au lieu de charger des dizaines de modèles en même temps. Après avoir prédit chaque fois, vous pouvez utiliser K.clear_session() pour libérer le GPU, afin que le prochain chargement ne prenne pas plus de temps.

3
Wentai Chen