web-dev-qa-db-fra.com

Enregistrer le modèle toutes les 10 époques tensorflow.keras v2

J'utilise des keras définis comme sous-module dans tensorflow v2. J'entraîne mon modèle en utilisant la méthode fit_generator(). Je veux enregistrer mon modèle toutes les 10 époques. Comment puis-je atteindre cet objectif?

En Keras (pas en tant que sous-module de tf), je peux donner ModelCheckpoint(model_savepath,period=10). Mais dans tf v2, ils ont changé cela en ModelCheckpoint(model_savepath, save_freq)save_freq peut être 'Epoch' auquel cas le modèle est sauvegardé à chaque époque. Si save_freq est un entier, le modèle est enregistré après que tant d'échantillons ont été traités. Mais je veux que ce soit après 10 époques. Comment puis-je atteindre cet objectif?

10
Nagabhushan S N

En utilisant tf.keras.callbacks.ModelCheckpoint utilisation save_freq='Epoch' et passez un argument supplémentaire period=10.

Bien que cela ne soit pas documenté dans le documents officiels , c'est la façon de le faire (notez qu'il est documenté que vous pouvez passer period, n'explique tout simplement pas ce qu'il fait).

3
bluesummers