web-dev-qa-db-fra.com

Pourquoi s'embêter avec les réseaux de neurones récurrents pour les données structurées?

J'ai développé des réseaux de neurones à action directe (FNN) et des réseaux de neurones récurrents (RNN) à Keras avec des données structurées de la forme [instances, time, features], et les performances des FNN et des RNN sont les mêmes (sauf que les RNN nécessitent plus de temps de calcul).

J'ai également simulé des données tabulaires (code ci-dessous) où je m'attendais à ce qu'un RNN surpasse un FNN car la valeur suivante de la série dépend de la valeur précédente de la série; cependant, les deux architectures prédisent correctement.

Avec les données PNL, j'ai vu des RNN surpasser les FNN, mais pas avec des données tabulaires. En général, quand peut-on s'attendre à ce qu'un RNN surpasse un FNN avec des données tabulaires? Plus précisément, quelqu'un pourrait-il publier du code de simulation avec des données tabulaires démontrant un RNN surperformant un FNN?

Je vous remercie! Si mon code de simulation n'est pas idéal pour ma question, veuillez l'adapter ou en partager un plus idéal!

from keras import models
from keras import layers

from keras.layers import Dense, LSTM

import numpy as np
import matplotlib.pyplot as plt

Deux entités ont été simulées sur 10 pas de temps, où la valeur de la deuxième entité dépend de la valeur des deux entités à l'étape de temps précédente.

## Simulate data.

np.random.seed(20180825)

X = np.random.randint(50, 70, size = (11000, 1)) / 100

X = np.concatenate((X, X), axis = 1)

for i in range(10):

    X_next = np.random.randint(50, 70, size = (11000, 1)) / 100

    X = np.concatenate((X, X_next, (0.50 * X[:, -1].reshape(len(X), 1)) 
        + (0.50 * X[:, -2].reshape(len(X), 1))), axis = 1)

print(X.shape)

## Training and validation data.

split = 10000

Y_train = X[:split, -1:].reshape(split, 1)
Y_valid = X[split:, -1:].reshape(len(X) - split, 1)
X_train = X[:split, :-2]
X_valid = X[split:, :-2]

print(X_train.shape)
print(Y_train.shape)
print(X_valid.shape)
print(Y_valid.shape)

FNN:

## FNN model.

# Define model.

network_fnn = models.Sequential()
network_fnn.add(layers.Dense(64, activation = 'relu', input_shape = (X_train.shape[1],)))
network_fnn.add(Dense(1, activation = None))

# Compile model.

network_fnn.compile(optimizer = 'adam', loss = 'mean_squared_error')

# Fit model.

history_fnn = network_fnn.fit(X_train, Y_train, epochs = 10, batch_size = 32, verbose = False,
    validation_data = (X_valid, Y_valid))

plt.scatter(Y_train, network_fnn.predict(X_train), alpha = 0.1)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.show()

plt.scatter(Y_valid, network_fnn.predict(X_valid), alpha = 0.1)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.show()

LSTM:

## LSTM model.

X_lstm_train = X_train.reshape(X_train.shape[0], X_train.shape[1] // 2, 2)
X_lstm_valid = X_valid.reshape(X_valid.shape[0], X_valid.shape[1] // 2, 2)

# Define model.

network_lstm = models.Sequential()
network_lstm.add(layers.LSTM(64, activation = 'relu', input_shape = (X_lstm_train.shape[1], 2)))
network_lstm.add(layers.Dense(1, activation = None))

# Compile model.

network_lstm.compile(optimizer = 'adam', loss = 'mean_squared_error')

# Fit model.

history_lstm = network_lstm.fit(X_lstm_train, Y_train, epochs = 10, batch_size = 32, verbose = False,
    validation_data = (X_lstm_valid, Y_valid))

plt.scatter(Y_train, network_lstm.predict(X_lstm_train), alpha = 0.1)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.show()

plt.scatter(Y_valid, network_lstm.predict(X_lstm_valid), alpha = 0.1)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.show()
29

En pratique, même en PNL, vous voyez que les RNN et les CNN sont souvent compétitifs. Voici un article de revue de 2017 qui le montre plus en détail. En théorie, il se peut que les RNN puissent mieux gérer toute la complexité et la nature séquentielle du langage, mais dans la pratique, le plus gros obstacle est généralement de bien former le réseau et les RNN sont capricieux.

Un autre problème qui pourrait avoir une chance de fonctionner serait de regarder un problème comme le problème des parenthèses équilibrées (avec juste des parenthèses dans les chaînes ou des parenthèses avec d'autres caractères de distraction). Cela nécessite le traitement séquentiel des entrées et le suivi de certains états et pourrait être plus facile à apprendre avec un LSTM puis un FFN.

Mise à jour: certaines données qui semblent séquentielles peuvent ne pas avoir à être traitées séquentiellement. Par exemple, même si vous fournissez une séquence de nombres à ajouter puisque l'addition est commutative, un FFN fera aussi bien qu'un RNN. Cela pourrait également être vrai pour de nombreux problèmes de santé où les informations dominantes ne sont pas de nature séquentielle. Supposons que chaque année les habitudes de tabagisme d'un patient soient mesurées. D'un point de vue comportemental, la trajectoire est importante mais si vous prédisez si le patient développera un cancer du poumon, la prédiction sera dominée uniquement par le nombre d'années pendant lesquelles le patient a fumé (peut-être limité aux 10 dernières années pour le FFN).

Vous voulez donc rendre le problème du jouet plus complexe et nécessiter de prendre en compte l'ordre des données. Peut-être une sorte de série temporelle simulée, où vous voulez prédire s'il y avait un pic dans les données, mais vous ne vous souciez pas des valeurs absolues, juste de la nature relative du pic.

Update2

J'ai modifié votre code pour montrer un cas où les RNN fonctionnent mieux. L'astuce consistait à utiliser une logique conditionnelle plus complexe qui est plus naturellement modélisée dans les LSTM que dans les FFN. Le code est ci-dessous. Pour 8 colonnes, nous voyons que le FFN s'entraîne en 1 minute et atteint une perte de validation de 6,3. Le LSTM prend 3 fois plus de temps pour s'entraîner, mais sa perte de validation finale est 6 fois inférieure à 1,06.

Au fur et à mesure que nous augmentons le nombre de colonnes, le LSTM a un avantage de plus en plus important, surtout si nous ajoutons des conditions plus compliquées. Pour 16 colonnes, la perte de validation FFN est de 19 (et vous pouvez voir plus clairement la courbe d'apprentissage car le modèle n'est pas capable d'ajuster instantanément les données). En comparaison, le LSTM prend 11 fois plus de temps à s'entraîner mais a une perte de validation de 0,31, 30 fois plus petite que le FFN! Vous pouvez jouer avec des matrices encore plus grandes pour voir jusqu'où cette tendance va s'étendre.

from keras import models
from keras import layers

from keras.layers import Dense, LSTM

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import time

matplotlib.use('Agg')

np.random.seed(20180908)

rows = 20500
cols = 10

# Randomly generate Z
Z = 100*np.random.uniform(0.05, 1.0, size = (rows, cols))

larger = np.max(Z[:, :cols/2], axis=1).reshape((rows, 1))
larger2 = np.max(Z[:, cols/2:], axis=1).reshape((rows, 1))
smaller = np.min((larger, larger2), axis=0)
# Z is now the max of the first half of the array.
Z = np.append(Z, larger, axis=1)
# Z is now the min of the max of each half of the array.
# Z = np.append(Z, smaller, axis=1)

# Combine and shuffle.

#Z = np.concatenate((Z_sum, Z_avg), axis = 0)

np.random.shuffle(Z)

## Training and validation data.

split = 10000

X_train = Z[:split, :-1]
X_valid = Z[split:, :-1]
Y_train = Z[:split, -1:].reshape(split, 1)
Y_valid = Z[split:, -1:].reshape(rows - split, 1)

print(X_train.shape)
print(Y_train.shape)
print(X_valid.shape)
print(Y_valid.shape)

print("Now setting up the FNN")

## FNN model.

tick = time.time()

# Define model.

network_fnn = models.Sequential()
network_fnn.add(layers.Dense(32, activation = 'relu', input_shape = (X_train.shape[1],)))
network_fnn.add(Dense(1, activation = None))

# Compile model.

network_fnn.compile(optimizer = 'adam', loss = 'mean_squared_error')

# Fit model.

history_fnn = network_fnn.fit(X_train, Y_train, epochs = 500, batch_size = 128, verbose = False,
    validation_data = (X_valid, Y_valid))

tock = time.time()

print()
print(str('%.2f' % ((tock - tick) / 60)) + ' minutes.')

print("Now evaluating the FNN")

loss_fnn = history_fnn.history['loss']
val_loss_fnn = history_fnn.history['val_loss']
epochs_fnn = range(1, len(loss_fnn) + 1)
print("train loss: ", loss_fnn[-1])
print("validation loss: ", val_loss_fnn[-1])

plt.plot(epochs_fnn, loss_fnn, 'black', label = 'Training Loss')
plt.plot(epochs_fnn, val_loss_fnn, 'red', label = 'Validation Loss')
plt.title('FNN: Training and Validation Loss')
plt.legend()
plt.show()

plt.scatter(Y_train, network_fnn.predict(X_train), alpha = 0.1)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('training points')
plt.show()

plt.scatter(Y_valid, network_fnn.predict(X_valid), alpha = 0.1)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('valid points')
plt.show()

print("LSTM")

## LSTM model.

X_lstm_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_lstm_valid = X_valid.reshape(X_valid.shape[0], X_valid.shape[1], 1)

tick = time.time()

# Define model.

network_lstm = models.Sequential()
network_lstm.add(layers.LSTM(32, activation = 'relu', input_shape = (X_lstm_train.shape[1], 1)))
network_lstm.add(layers.Dense(1, activation = None))

# Compile model.

network_lstm.compile(optimizer = 'adam', loss = 'mean_squared_error')

# Fit model.

history_lstm = network_lstm.fit(X_lstm_train, Y_train, epochs = 500, batch_size = 128, verbose = False,
    validation_data = (X_lstm_valid, Y_valid))

tock = time.time()

print()
print(str('%.2f' % ((tock - tick) / 60)) + ' minutes.')

print("now eval")

loss_lstm = history_lstm.history['loss']
val_loss_lstm = history_lstm.history['val_loss']
epochs_lstm = range(1, len(loss_lstm) + 1)
print("train loss: ", loss_lstm[-1])
print("validation loss: ", val_loss_lstm[-1])

plt.plot(epochs_lstm, loss_lstm, 'black', label = 'Training Loss')
plt.plot(epochs_lstm, val_loss_lstm, 'red', label = 'Validation Loss')
plt.title('LSTM: Training and Validation Loss')
plt.legend()
plt.show()

plt.scatter(Y_train, network_lstm.predict(X_lstm_train), alpha = 0.1)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title('training')
plt.show()

plt.scatter(Y_valid, network_lstm.predict(X_lstm_valid), alpha = 0.1)
plt.xlabel('Actual')
plt.ylabel('Predicted')
plt.title("validation")
plt.show()
7
emschorsch