web-dev-qa-db-fra.com

Keras model.fit () avec tf.dataset API + validation_data

Donc, mon modèle keras fonctionne avec un fichier tf.Dataset via le code suivant:

# Initialize batch generators(returns tf.Dataset)
batch_train = build_features.get_train_batches(batch_size=batch_size)

# Create TensorFlow Iterator object
iterator = batch_train.make_one_shot_iterator()
dataset_inputs, dataset_labels = iterator.get_next()

# Create Model
logits = .....(some layers)
keras.models.Model(inputs=dataset_inputs, outputs=logits)

# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit(epochs=epochs, steps_per_Epoch=num_batches, callbacks=callbacks, verbose=1)

cependant, lorsque j'essaie de transmettre le paramètre validation_data au modèle. Cela me dit que je ne peux pas l’utiliser avec le générateur. Existe-t-il un moyen d'utiliser la validation lors de l'utilisation de tf.Dataset

par exemple dans tensorflow je pourrais faire ce qui suit :

# initialize batch generators
batch_train = build_features.get_train_batches(batch_size=batch_size)
batch_valid = build_features.get_valid_batches(batch_size=batch_size)

# create TensorFlow Iterator object
iterator = tf.data.Iterator.from_structure(batch_train.output_types,
                                           batch_train.output_shapes)

# create two initialization ops to switch between the datasets
init_op_train = iterator.make_initializer(batch_train)
init_op_valid = iterator.make_initializer(batch_valid)

puis utilisez simplement sess.run(init_op_train) et sess.run(init_op_valid) pour basculer entre les jeux de données

J'ai essayé d'implémenter un rappel qui ne fait que cela (passer à l'ensemble de validation, prévoir et revenir), mais il me dit que je ne peux pas utiliser model.predict dans un rappel.

quelqu'un peut-il m'aider à obtenir une validation en travaillant avec Keras + Tf.Dataset

edit: incorporer la réponse dans le code

Donc, finalement, ce qui a fonctionné pour moi, grâce à la réponse sélectionnée est:

# Initialize batch generators(returns tf.Dataset)
batch_train = # returns tf.Dataset
batch_valid = # returns tf.Dataset

# Create TensorFlow Iterator object and wrap it in a generator
itr_train = make_iterator(batch_train)
itr_valid = make_iterator(batch_train)

# Create Model
logits = # the keras model
keras.models.Model(inputs=dataset_inputs, outputs=logits)

# Train network
model.compile(optimizer=train_opt, loss=model_loss, target_tensors=[dataset_labels])
model.fit_generator(
    generator=itr_train, validation_data=itr_valid, validation_steps=batch_size,
    epochs=epochs, steps_per_Epoch=num_batches, callbacks=cbs, verbose=1, workers=0)

def make_iterator(dataset):
    iterator = dataset.make_one_shot_iterator()
    next_val = iterator.get_next()

    with K.get_session().as_default() as sess:
        while True:
            *inputs, labels = sess.run(next_val)
            yield inputs, labels

Cela n'introduit aucun frais généraux

5
Mark Rofail

J'ai résolu le problème en utilisant fit_genertor. J'ai trouvé la solution ici . J'ai appliqué la solution @ Dat-Nguyen.

Vous devez simplement créer deux itérateurs, un pour la formation et un pour la validation, puis créer votre propre générateur où vous extrairez des lots du jeu de données et fournissez les données sous la forme (batch_data, batch_labels). Enfin, dans model.fit_generator, vous passerez les train_generator et validation_generator. 

2
W. Sam

Pour connecter un itérateur réinitialisable à un modèle de Keras, vous devez connecter un itérateur qui renvoie les valeurs x et y simultanément:

sess = tf.Session()
keras.backend.set_session(sess) 

x = np.random.random((5, 2))
y = np.array([0, 1] * 3 + [1, 0] * 2).reshape(5, 2) # One hot encoded
input_dataset = tf.data.Dataset.from_tensor_slices((c, d))

# Create your reinitializable_iterator and initializer
reinitializable_iterator = tf.data.Iterator.from_structure(input_dataset.output_types, input_dataset.output_shapes)
init_op = reinitializable_iterator.make_initializer(input_dataset)

#run the initializer
sess.run(init_op) # feed_dict if you're using placeholders as input

# build keras model and plug in the iterator
model = keras.Model.model(...)
model.compile(...)
model.fit(reinitializable_iterator,...)

Si vous disposez également d'un ensemble de données de validation, le plus simple consiste à créer un itérateur séparé et à l'insérer dans le paramètre validation_data. Assurez-vous de définir vos étapes step_per_Epoch et validation_steps car elles ne peuvent pas être déduites. 

0
Razorocean