web-dev-qa-db-fra.com

tensorflow: peut enregistrer le meilleur modèle uniquement avec val_acc disponible, saut

J'ai un problème avec tf.callbacks.ModelChekpoint. Comme vous pouvez le voir dans mon fichier journal, l'avertissement vient toujours avant la dernière itération où le val_acc est calculé. Par conséquent, Modelcheckpoint ne trouve jamais le val_acc

Epoch 1/30
1/8 [==>...........................] - ETA: 19s - loss: 1.4174 - accuracy: 0.3000
2/8 [======>.......................] - ETA: 8s - loss: 1.3363 - accuracy: 0.3500 
3/8 [==========>...................] - ETA: 4s - loss: 1.3994 - accuracy: 0.2667
4/8 [==============>...............] - ETA: 3s - loss: 1.3527 - accuracy: 0.3250
6/8 [=====================>........] - ETA: 1s - loss: 1.3042 - accuracy: 0.3333
WARNING:tensorflow:Can save best model only with val_acc available, skipping.
8/8 [==============================] - 4s 482ms/step - loss: 1.2846 - accuracy: 0.3375 - val_loss: 1.3512 - val_accuracy: 0.5000

Epoch 2/30
1/8 [==>...........................] - ETA: 0s - loss: 1.0098 - accuracy: 0.5000
3/8 [==========>...................] - ETA: 0s - loss: 0.8916 - accuracy: 0.5333
5/8 [=================>............] - ETA: 0s - loss: 0.9533 - accuracy: 0.5600
6/8 [=====================>........] - ETA: 0s - loss: 0.9523 - accuracy: 0.5667
7/8 [=========================>....] - ETA: 0s - loss: 0.9377 - accuracy: 0.5714
WARNING:tensorflow:Can save best model only with val_acc available, skipping.
8/8 [==============================] - 1s 98ms/step - loss: 0.9229 - accuracy: 0.5750 - val_loss: 1.2507 - val_accuracy: 0.5000

Ceci est mon code pour la formation du CNN.

    callbacks = [
        TensorBoard(log_dir=r'C:\Users\reda.elhail\Desktop\logs\{}'.format(Name),
                    histogram_freq=1),
        ModelCheckpoint(filepath=r"C:\Users\reda.elhail\Desktop\checkpoints\{}".format(Name), monitor='val_acc',
                        verbose=2, save_best_only=True, mode='max')]
    history = model.fit_generator(
        train_data_gen,
        steps_per_Epoch=total_train // batch_size,
        epochs=epochs,
        validation_data=val_data_gen,
        validation_steps=total_val // batch_size,
        callbacks=callbacks)```
5
Reda ElH

Je sais à quel point ces choses peuvent être parfois frustrantes ... mais tensorflow nécessite que vous écriviez explicitement le nom de la métrique que vous souhaitez calculer

Vous aurez besoin de dire "val_accuracy"

metric = 'val_accuracy'
ModelCheckpoint(filepath=r"C:\Users\reda.elhail\Desktop\checkpoints\{}".format(Name), monitor=metric,
                    verbose=2, save_best_only=True, mode='max')]

J'espère que cela aide =)

1