web-dev-qa-db-fra.com

Ajuster une valeur unique dans le tenseur - TensorFlow

Cela me gêne de demander cela, mais comment ajuster une seule valeur dans un tenseur? Supposons que vous vouliez ajouter '1' à une seule valeur de votre tenseur?

Le faire par indexation ne fonctionne pas:

TypeError: 'Tensor' object does not support item assignment

Une approche serait de construire un tenseur de 0 de forme identique. Et puis en ajustant un 1 à la position que vous voulez. Ensuite, vous ajouteriez les deux tenseurs ensemble. Encore une fois, cela pose le même problème qu'avant.

J'ai lu plusieurs fois la documentation de l'API et n'arrive pas à comprendre comment faire. Merci d'avance!

55
LeavesBreathe

UPDATE: TensorFlow 1.0 inclut un opérateur tf.scatter_nd() , qui peut être utilisé pour créer delta ci-dessous sans créer un tf.SparseTensor.


C'est en fait étonnamment délicat avec les opérations existantes! Quelqu'un peut peut-être suggérer un moyen plus agréable de résumer ce qui suit, mais voici un moyen de le faire.

Disons que vous avez un tenseur tf.constant():

c = tf.constant([[0.0, 0.0, 0.0],
                 [0.0, 0.0, 0.0],
                 [0.0, 0.0, 0.0]])

... et vous voulez ajouter 1.0 à l'emplacement [1, 1]. Une façon de faire est de définir un tf.SparseTensor , delta, représentant le changement:

indices = [[1, 1]]  # A list of coordinates to update.

values = [1.0]  # A list of values corresponding to the respective
                # coordinate in indices.

shape = [3, 3]  # The shape of the corresponding dense tensor, same as `c`.

delta = tf.SparseTensor(indices, values, shape)

Ensuite, vous pouvez utiliser le tf.sparse_tensor_to_dense() op pour créer un tenseur dense à partir de delta et l’ajouter à c:

result = c + tf.sparse_tensor_to_dense(delta)

sess = tf.Session()
sess.run(result)
# ==> array([[ 0.,  0.,  0.],
#            [ 0.,  1.,  0.],
#            [ 0.,  0.,  0.]], dtype=float32)
66
mrry

Que diriez-vous de tf.scatter_update(ref, indices, updates) ou tf.scatter_add(ref, indices, updates)?

ref[indices[...], :] = updates
ref[indices[...], :] += updates

Voir this .

8
Liping Liu

tf.scatter_update n’a pas d’opérateur de descente de gradient attribué et génère une erreur lors de l’apprentissage avec au moins tf.train.GradientDescentOptimizer. Vous devez implémenter la manipulation de bits avec des fonctions de bas niveau.

1
johannes