web-dev-qa-db-fra.com

Arrêter tôt avec tf.estimator, comment?

J'utilise tf.estimator dans TensorFlow 1.4 et tf.estimator.train_and_evaluate c'est super mais j'ai besoin d'arrêter tôt. Quelle est la façon préférée d'ajouter cela?

Je suppose qu'il y a tf.train.SessionRunHook quelque part pour ça. J'ai vu qu'il y avait un ancien paquet contrib avec un ValidationMonitor qui semblait avoir un arrêt précoce, mais il ne semble plus exister en 1.4. Ou est-ce que la manière préférée à l'avenir sera de compter sur tf.keras (avec lequel un arrêt précoce est vraiment facile) au lieu de tf.estimator/tf.layers/tf.data, peut-être?

19
Carl Thomé

Bonnes nouvelles! tf.estimator a maintenant un arrêt précoce de la prise en charge sur master et il semble que ce sera en 1.10.

estimator = tf.estimator.Estimator(model_fn, model_dir)

os.makedirs(estimator.eval_dir())  # TODO This should not be expected IMO.

early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    min_steps=100)

tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
    eval_spec=tf.estimator.EvalSpec(eval_input_fn))
25
Carl Thomé

Oui il y a tf.train.StopAtStepHook :

Cette demande de hook s'arrête après qu'un certain nombre d'étapes ont été exécutées ou qu'une dernière étape a été atteinte. Une seule des deux options peut être spécifiée.

Vous pouvez également l'étendre et mettre en œuvre votre propre stratégie d'arrêt en fonction des résultats de l'étape.

class MyHook(session_run_hook.SessionRunHook):
  ...
  def after_run(self, run_context, run_values):
    if condition:
      run_context.request_stop()
2
Maxim

Tout d'abord, vous devez nommer la perte pour la rendre disponible à l'appel d'arrêt anticipé. Si votre variable de perte est nommée "perte" dans l'estimateur, la ligne

copyloss = tf.identity(loss, name="loss")

juste en dessous cela fonctionnera.

Ensuite, créez un crochet avec ce code.

class EarlyStopping(tf.train.SessionRunHook):
    def __init__(self,smoothing=.997,tolerance=.03):
        self.lowestloss=float("inf")
        self.currentsmoothedloss=-1
        self.tolerance=tolerance
        self.smoothing=smoothing
    def before_run(self, run_context):
        graph = ops.get_default_graph()
        #print(graph)
        self.lossop=graph.get_operation_by_name("loss")
        #print(self.lossop)
        #print(self.lossop.outputs)
        self.element = self.lossop.outputs[0]
        #print(self.element)
        return tf.train.SessionRunArgs([self.element])
    def after_run(self, run_context, run_values):
        loss=run_values.results[0]
        #print("loss "+str(loss))
        #print("running average "+str(self.currentsmoothedloss))
        #print("")
        if(self.currentsmoothedloss<0):
            self.currentsmoothedloss=loss*1.5
        self.currentsmoothedloss=self.currentsmoothedloss*self.smoothing+loss*(1-self.smoothing)
        if(self.currentsmoothedloss<self.lowestloss):
            self.lowestloss=self.currentsmoothedloss
        if(self.currentsmoothedloss>self.lowestloss+self.tolerance):
            run_context.request_stop()
            print("REQUESTED_STOP")
            raise ValueError('Model Stopping because loss is increasing from EarlyStopping hook')

cela compare une validation de perte lissée exponentiellement avec sa valeur la plus basse, et si elle est supérieure par tolérance, elle arrête la formation. S'il s'arrête trop tôt, l'augmentation de la tolérance et du lissage l'arrêtera plus tard. Continuez à lisser en dessous de un, sinon il ne s'arrêtera jamais.

Vous pouvez remplacer la logique dans after_run par quelque chose d'autre si vous souhaitez arrêter en fonction d'une condition différente.

Maintenant, ajoutez ce crochet à la spécification d'évaluation. Votre code devrait ressembler à ceci:

eval_spec=tf.estimator.EvalSpec(input_fn=lambda:eval_input_fn(batchsize),steps=100,hooks=[EarlyStopping()])#

Remarque importante: la fonction run_context.request_stop () est interrompue dans l'appel train_and_evaluate et n'arrête pas la formation. J'ai donc soulevé une erreur de valeur pour arrêter la formation. Vous devez donc encapsuler l'appel train_and_evaluate dans un bloc try catch comme celui-ci:

try:
    tf.estimator.train_and_evaluate(classifier,train_spec,eval_spec)
except ValueError as e:
    print("training stopped")

si vous ne le faites pas, le code plantera avec une erreur lorsque la formation s'arrêtera.

2
user3806120

Une autre option qui n'utilise pas de crochets consiste à créer un tf.contrib.learn.Experiment (qui semble, même en contrib, soutenir également le nouveau tf.estimator.Estimator).

Entraînez-vous ensuite par la méthode (apparemment expérimentale) continuous_train_and_eval avec une personnalisation appropriée continuous_eval_predicate_fn.

Selon le docu tensorflow, le continuous_eval_predicate_fn est

Une fonction de prédicat déterminant s'il faut continuer l'évaluation après chaque itération.

et appelé avec le eval_results de la dernière analyse. Pour un arrêt précoce, utilisez une fonction personnalisée qui conserve comme état le meilleur résultat actuel et un compteur et renvoie False lorsque la condition pour un arrêt précoce est atteinte.

Note ajoutée: cette approche consisterait à utiliser des méthodes obsolètes avec tensorflow 1.7 (tout tf.contrib.learn est obsolète à partir de cette version: https://www.tensorflow.org/api_docs/python/tf/contrib/apprendre )

1
skb