web-dev-qa-db-fra.com

Tensorflow Différentes façons d'exporter et d'exécuter un graphe en C ++

Pour importer votre réseau formé vers le C++, vous devez exporter votre réseau pour pouvoir le faire. Après avoir beaucoup cherché et trouvé presque aucune information à ce sujet, il a été précisé que nous devrions utiliser freeze_graph () pour pouvoir le faire.

Grâce à la nouvelle version 0.7 de Tensorflow, ils en ont ajouté documentation .

Après avoir examiné les documentations, j'ai constaté qu'il existe peu de méthodes similaires, pouvez-vous dire quelle est la différence entre freeze_graph() et: tf.train.export_meta_graph Car il a des paramètres similaires, mais il semble qu'il puisse également être utilisé pour importer des modèles en C++ (je suppose que la différence est que pour utiliser la sortie du fichier par cette méthode, vous ne pouvez utiliser que import_graph_def() ou c'est autre chose?)

Une autre question sur l'utilisation de write_graph(): dans les documentations, graph_def Est donnée par sess.graph_def Mais dans les exemples de freeze_graph() c'est sess.graph.as_graph_def(). Quelle est la différence entre ces deux?

Cette question est liée à ce problème.

Merci!

26
Hamed MP

Voici ma solution utilisant les points de contrôle V2 introduits dans TF 0.12.

Il n'est pas nécessaire de convertir toutes les variables en constantes ou geler le graphique .

Pour plus de clarté, un point de contrôle V2 ressemble à ceci dans mon répertoire models:

checkpoint  # some information on the name of the files in the checkpoint
my-model.data-00000-of-00001  # the saved weights
my-model.index  # probably definition of data layout in the previous file
my-model.meta  # protobuf of the graph (nodes and topology info)

Partie Python (sauvegarde)

with tf.Session() as sess:
    tf.train.Saver(tf.trainable_variables()).save(sess, 'models/my-model')

Si vous créez le Saver avec tf.trainable_variables(), vous pouvez vous épargner des maux de tête et de l'espace de stockage. Mais peut-être que certains modèles plus compliqués nécessitent la sauvegarde de toutes les données, puis supprimez cet argument dans Saver, assurez-vous simplement de créer le Saver après votre graphique est créé. Il est également très judicieux de donner à toutes les variables/couches des noms uniques, sinon vous pouvez exécuter différents problèmes .

Partie Python (inférence)

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('models/my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('models/'))
    outputTensors = sess.run(outputOps, feed_dict=feedDict)

Partie C++ (inférence)

Notez que checkpointPath n'est un chemin vers aucun des fichiers existants, juste leur préfixe commun. Si vous avez mis par erreur le chemin vers le fichier .index, TF ne vous dira pas que c'était faux, mais il mourra pendant l'inférence en raison de variables non initialisées.

#include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>

using namespace std;
using namespace tensorflow;

...
// set up your input paths
const string pathToGraph = "models/my-model.meta"
const string checkpointPath = "models/my-model";
...

auto session = NewSession(SessionOptions());
if (session == nullptr) {
    throw runtime_error("Could not create Tensorflow session.");
}

Status status;

// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok()) {
    throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}

// Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok()) {
    throw runtime_error("Error creating graph: " + status.ToString());
}

// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath;
status = session->Run(
        {{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
        {},
        {graph_def.saver_def().restore_op_name()},
        nullptr);
if (!status.ok()) {
    throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}

// and run the inference to your liking
auto feedDict = ...
auto outputOps = ...
std::vector<tensorflow::Tensor> outputTensors;
status = session->Run(feedDict, outputOps, {}, &outputTensors);
33
Martin Pecka

Pour prédire (et toutes les autres opérations), vous pouvez faire quelque chose comme ceci:

Tout d'abord dans python vous devez nom vos variables ou opération pour une utilisation future

self.init = tf.initialize_variables(tf.all_variables(), name="nInit")

Après l'entraînement, les calculs de so .. lorsque vos variables sont affectées, parcourez-les toutes et enregistrez-les sous forme de constantes dans votre graphique. (presque la même chose peut être faite avec cet outil de gel, mais je le fais habituellement par moi-même, vérifiez "nom = nPoids" dans py et cpp ci-dessous)

def save(self, filename):
    for variable in tf.trainable_variables():
        tensor = tf.constant(variable.eval())
        tf.assign(variable, tensor, name="nWeights")

    tf.train.write_graph(self.sess.graph_def, 'graph/', 'my_graph.pb', as_text=False)

Maintenant, allez en c ++ et chargez notre graphique et chargez les variables à partir des constantes enregistrées:

void load(std::string my_model) {
        auto load_graph_status =
                ReadBinaryProto(tensorflow::Env::Default(), my_model, &graph_def);

        auto session_status = session->Create(graph_def);

        std::vector<tensorflow::Tensor> out;
        std::vector<string> vNames;

        int node_count = graph_def.node_size();
        for (int i = 0; i < node_count; i++) {
            auto n = graph_def.node(i);

            if (n.name().find("nWeights") != std::string::npos) {
                vNames.Push_back(n.name());
            }
        }

        session->Run({}, vNames, {}, &out);

Vous avez maintenant tous vos poids nets neuronaux ou d'autres variables chargés.

De même, vous pouvez effectuer d'autres opérations (vous souvenez-vous des noms?); créer des tenseurs d'entrée et de sortie de taille appropriée, remplir le tenseur d'entrée avec des données et exécuter la session comme suit:

auto operationStatus = session->Run(input, {"put_your_operation_here"}, {}, &out);
18
Alex Joz