web-dev-qa-db-fra.com

Keras load_model avec des objets personnalisés ne fonctionne pas correctement

Réglage

Comme déjà mentionné dans le titre, j'ai eu un problème avec ma fonction de perte personnalisée lors de la tentative de chargement du modèle enregistré. Ma perte se présente comme suit:

def weighted_cross_entropy(weights):

    weights = K.variable(weights)

    def loss(y_true, y_pred):
        y_pred = K.clip(y_pred, K.epsilon(), 1-K.epsilon())

        loss = y_true * K.log(y_pred) * weights
        loss = -K.sum(loss, -1)
        return loss

    return loss

weighted_loss = weighted_cross_entropy([0.1,0.9])

J'ai donc utilisé le weighted_loss fonctionne comme fonction de perte et tout fonctionnait bien. Une fois la formation terminée, j'enregistre le modèle sous .h5fichier avec la norme model.save fonction de l'API keras.

Problème

Lorsque j'essaie de charger le modèle via

model = load_model(path,custom_objects={"weighted_loss":weighted_loss})

Je reçois un ValueError me disant que la perte est inconnue.

Erreur

Le message d'erreur se présente comme suit:

File "...\predict.py", line 29, in my_script
"weighted_loss": weighted_loss})
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 419, in load_model
model = _deserialize_model(f, custom_objects, compile)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 312, in _deserialize_model
sample_weight_mode=sample_weight_mode)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\training.py", line 139, in compile
loss_function = losses.get(loss)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 133, in get
return deserialize(identifier)
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 114, in deserialize
printable_module_name='loss function')
File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\utils\generic_utils.py", line 165, in deserialize_keras_object
':' + function_name)
ValueError: Unknown loss function:loss

Questions

Comment puis-je résoudre ce problème? Est-il possible que la raison en soit ma définition de perte enveloppée? Donc, keras ne sait pas comment gérer la variable weights?

4
pafi

Le nom de votre fonction de perte est loss (c'est-à-dire def loss(y_true, y_pred):). Par conséquent, lors du chargement du modèle, vous devez spécifier 'loss' comme son nom:

model = load_model(path, custom_objects={'loss': weighted_loss})
4
today