web-dev-qa-db-fra.com

Tensorflow et multitraitement: Sessions en cours

J'ai récemment travaillé sur un projet utilisant un réseau de neurones pour le contrôle de robots virtuels. J'ai utilisé tensorflow pour le coder et tout se passe bien. Jusqu'à présent, j'ai utilisé des simulations séquentielles pour évaluer la qualité du réseau de neurones. Cependant, je souhaite exécuter plusieurs simulations en parallèle afin de réduire le temps nécessaire à l'obtention des données.

Pour ce faire, j'importe le package multiprocessing de python. Au départ, je passais la variable sess (sess=tf.Session()) à une fonction qui exécuterait la simulation. Cependant, une fois que j'arrive à une instruction qui utilise cette variable sess, le processus se ferme sans avertissement. Après une recherche un peu, j’ai trouvé ces deux articles: Tensorflow: passer une session à un multiprocessus python Et Exécuter simultanément plusieurs sessions tensorflow

Bien qu’ils soient très liés, je n’ai pas pu comprendre comment le faire fonctionner. J'ai essayé de créer une session pour chaque processus individuel et d'assigner les poids du réseau neuronal à ses paramètres pouvant être entraînés sans succès. J'ai également essayé de sauvegarder la session dans un fichier, puis de le charger dans un processus, mais sans succès non plus.

Est-ce que quelqu'un a réussi à passer une session (ou des clones de sessions) à plusieurs processus?

Merci.

10
MrRed

J'utilise keras comme enveloppe avec tensorflow comme support, mais le même principe général devrait s'appliquer.

Si vous essayez quelque chose comme ça:

import keras
from functools import partial
from multiprocessing import Pool

def ModelFunc(i,SomeData):
    YourModel = Here
    return(ModelScore)

pool = Pool(processes = 4)
for i,Score in enumerate(pool.imap(partial(ModelFunc,SomeData),range(4))):
    print(Score)

Ça va échouer. Cependant, si vous essayez quelque chose comme ceci: 

from functools import partial
from multiprocessing import Pool

def ModelFunc(i,SomeData):
    import keras
    YourModel = Here
    return(ModelScore)

pool = Pool(processes = 4)
for i,Score in enumerate(pool.imap(partial(ModelFunc,SomeData),range(4))):
    print(Score)

Ça devrait marcher. Essayez d'appeler tensorflow séparément pour chaque processus.

3
June Skeeter

Vous ne pouvez pas utiliser le multitraitement Python pour transmettre une variable TensorFlow Session à un multiprocessing.Pool de manière directe car l'objet Session ne peut pas être décapé (il est fondamentalement non sérialisable car il peut gérer la mémoire GPU et son état de la sorte).

Je suggérerais de paralléliser le code en utilisant acteurs , qui sont essentiellement les analogues analogiques des "objets" et leur utilisation sont utilisés pour gérer les états dans un environnement distribué.

Ray est un bon cadre pour le faire. Vous pouvez définir une classe Python qui gère TensorFlow Session et expose une méthode pour exécuter votre simulation.

import ray
import tensorflow as tf

ray.init()

@ray.remote
class Simulator(object):
    def __init__(self):
        self.sess = tf.Session()
        self.simple_model = tf.constant([1.0])

    def simulate(self):
        return self.sess.run(self.simple_model)

# Create two actors.
simulators = [Simulator.remote() for _ in range(2)]

# Run two simulations in parallel.
results = ray.get([s.simulate.remote() for s in simulators])

Voici quelques exemples supplémentaires de parallélisant TensorFlow avec Ray .

Voir la documentation Ray . Notez que je suis l'un des développeurs de Ray.

2
Robert Nishihara