web-dev-qa-db-fra.com

Comment éviter le surapprentissage sur un simple réseau à action directe

En utilisant ensemble de données sur le diabète des indiens pima J'essaie de construire un modèle précis en utilisant Keras. J'ai écrit le code suivant:

# Visualize training history
from keras import callbacks
from keras.layers import Dropout

tb = callbacks.TensorBoard(log_dir='/.logs', histogram_freq=10, batch_size=32,
                           write_graph=True, write_grads=True, write_images=False,
                           embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None)
# Visualize training history
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt
import numpy

# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:, 0:8]
Y = dataset[:, 8]
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu', name='first_input'))
model.add(Dense(500, activation='tanh', name='first_hidden'))
model.add(Dropout(0.5, name='dropout_1'))
model.add(Dense(8, activation='relu', name='second_hidden'))
model.add(Dense(1, activation='sigmoid', name='output_layer'))

# Compile model
model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

# Fit the model
history = model.fit(X, Y, validation_split=0.33, epochs=1000, batch_size=10, verbose=0, callbacks=[tb])
# list all data in history
print(history.history.keys())
# summarize history for accuracy
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('Epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('Epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

Après plusieurs essais, j'ai ajouté des couches d'abandon afin d'éviter le sur-ajustement, mais sans chance. Le graphique suivant montre que la perte de validation et la perte de formation se séparent à un moment donné.

enter image description here

Que puis-je faire d'autre pour optimiser ce réseau?

MISE À JOUR: en fonction des commentaires que j'ai, j'ai peaufiné le code comme ceci:

model = Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', kernel_regularizer=regularizers.l2(0.01),
                activity_regularizer=regularizers.l1(0.01), activation='relu',
                name='first_input'))  # added regularizers
model.add(Dense(8, activation='relu', name='first_hidden'))  # reduced to 8 neurons
model.add(Dropout(0.5, name='dropout_1'))
model.add(Dense(5, activation='relu', name='second_hidden'))
model.add(Dense(1, activation='sigmoid', name='output_layer'))

Voici les graphiques pour 500 époques

enter image description hereenter image description here

11
Shlomi Schwartz

enter image description here

Le premier exemple a donné une précision de validation> 75% et le second a donné une précision <65% et si vous comparez les pertes pour les époques inférieures à 100, c'est moins de <0,5 pour la première et la seconde était> 0,6. Mais comment le deuxième cas est-il meilleur?.

Le second pour moi est un cas de under-fitting: le modèle n'a pas la capacité d'apprendre. Alors que le premier cas a un problème de over-fitting parce que son entraînement n'a pas été interrompu lors du surapprentissage (early stopping). Si la formation était arrêtée à environ 100 époques, ce serait un bien meilleur modèle que les deux.

L'objectif devrait être d'obtenir une petite erreur de prédiction dans les données invisibles et pour cela, vous augmentez la capacité du réseau jusqu'à un point au-delà duquel le sur-ajustement commence à se produire.

Alors, comment éviter over-fitting dans ce cas particulier? Adoptez early stopping.

MODIFICATIONS DU CODE: Pour inclure early stopping et input scaling.

 # input scaling
 scaler = StandardScaler()
 X = scaler.fit_transform(X)

 # Early stopping  
 early_stop = EarlyStopping(monitor='val_loss', min_delta=0, patience=3, verbose=1, mode='auto')

 # create model - almost the same code
 model = Sequential()
 model.add(Dense(12, input_dim=8, activation='relu', name='first_input'))
 model.add(Dense(500, activation='relu', name='first_hidden'))
 model.add(Dropout(0.5, name='dropout_1'))
 model.add(Dense(8, activation='relu', name='second_hidden'))
 model.add(Dense(1, activation='sigmoid', name='output_layer')))

 history = model.fit(X, Y, validation_split=0.33, epochs=1000, batch_size=10, verbose=0, callbacks=[tb, early_stop])

Les graphiques Accuracy et loss:

enter image description here

13
vijay m

Tout d'abord, essayez d'ajouter une régularisation ( https://keras.io/regularizers/ ) comme avec ce code:

model.add(Dense(12, input_dim=12,
            kernel_regularizer=regularizers.l2(0.01),
            activity_regularizer=regularizers.l1(0.01)))

Assurez-vous également de réduire la taille de votre réseau, c'est-à-dire que vous n'avez pas besoin d'une couche cachée de 500 neurones - essayez simplement de la retirer pour diminuer la puissance de représentation et peut-être même une autre couche si elle est toujours surajustée. En outre, utilisez uniquement l'activation relu. Essayez également d'augmenter votre taux d'abandon à quelque chose comme 0,75 (bien qu'il soit déjà élevé). Vous n'avez probablement pas non plus besoin de l'exécuter pendant tant d'époques - il commencera juste à s'ajuster après assez longtemps.

3
CoolPenguin

Pour un ensemble de données comme celui sur le diabète, vous pouvez utiliser un réseau beaucoup plus simple. Essayez de réduire les neurones de votre deuxième couche. (Y a-t-il une raison spécifique pour laquelle vous avez choisi tanh comme activation là-bas?).

De plus, vous pouvez simplement ajouter un rappel EarlyStopping à votre formation: https://keras.io/callbacks/

2
petezurich