web-dev-qa-db-fra.com

Comment ajouter une condition si dans un graphique TensorFlow?

Disons que j'ai le code suivant:

x = tf.placeholder("float32", shape=[None, ins_size**2*3], name = "x_input")
condition = tf.placeholder("int32", shape=[1, 1], name = "condition")
W = tf.Variable(tf.zeros([ins_size**2*3,label_option]), name = "weights")
b = tf.Variable(tf.zeros([label_option]), name = "bias")

if condition > 0:
    y = tf.nn.softmax(tf.matmul(x, W) + b)
else:
    y = tf.nn.softmax(tf.matmul(x, W) - b)  

L'instruction if fonctionnerait-elle dans le calcul (je ne le pense pas)? Sinon, comment puis-je ajouter une instruction if dans le graphique de calcul TensorFlow?

55
Yee Liu

Vous avez raison de dire que l'instruction if ne fonctionne pas ici, car la condition est évaluée au moment de la construction du graphique, alors que vous voulez probablement que la condition dépende de la valeur fournie à l'espace réservé au moment de l'exécution. (En fait, il prendra toujours la première branche, car condition > 0 Est évalué à Tensor, ce qui est "vérité" en Python .)

Pour prendre en charge le flux de contrôle conditionnel, TensorFlow fournit l'opérateur tf.cond() , qui évalue l'une des deux branches, en fonction d'une condition booléenne. Pour vous montrer comment l'utiliser, je vais réécrire votre programme pour que condition soit une valeur scalaire tf.int32 Par souci de simplicité:

x = tf.placeholder(tf.float32, shape=[None, ins_size**2*3], name="x_input")
condition = tf.placeholder(tf.int32, shape=[], name="condition")
W = tf.Variable(tf.zeros([ins_size**2 * 3, label_option]), name="weights")
b = tf.Variable(tf.zeros([label_option]), name="bias")

y = tf.cond(condition > 0, lambda: tf.matmul(x, W) + b, lambda: tf.matmul(x, W) - b)
86
mrry

TensorFlow 2.0

TF 2.0 introduit une fonctionnalité appelée AutoGraph qui vous permet de compiler JIT python code dans les exécutions de graphes. Cela signifie que vous pouvez utiliser python control instructions de flux (oui, cela inclut les instructions if).

AutoGraph prend en charge les instructions Python telles que while, for, if, break, continue et return, avec prise en charge de l'imbrication, ce qui signifie que vous pouvez utiliser les expressions Tensor dans la condition d'instructions while et if ou effectuer une itération sur un tenseur dans un for boucle.

Vous devrez définir une fonction implémentant votre logique et l'annoter avec tf.function . Voici un exemple modifié de la documentation:

import tensorflow as tf

@tf.function
def sum_even(items):
  s = 0
  for c in items:
    if tf.equal(c % 2, 0): 
        s += c
  return s

sum_even(tf.constant([10, 12, 15, 20]))
#  <tf.Tensor: id=1146, shape=(), dtype=int32, numpy=42>
1
cs95