web-dev-qa-db-fra.com

Tensorflow python: Accès à des éléments individuels dans un tenseur

Cette question concerne l'accès à des éléments individuels dans un tenseur, disons [[1,2,3]]. J'ai besoin d'accéder à l'élément interne [1,2,3] (ceci peut être effectué avec .eval () ou sess.run ()) mais cela prend plus de temps quand la taille du tenseur est énorme)

Y at-il une méthode pour faire la même chose plus rapidement?

Merci d'avance.

41
cipher42

Il existe deux méthodes principales pour accéder à des sous-ensembles d'éléments d'un tenseur. Chacune de ces méthodes devrait convenir à votre exemple.

  1. Utilisez l'opérateur d'indexation (basé sur tf.slice() ) pour extraire une tranche contiguë du tenseur.

    input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    
    output = input[0, :]
    print sess.run(output)  # ==> [1 2 3]
    

    L'opérateur d'indexation prend en charge bon nombre des mêmes spécifications de tranche que NumPy.

  2. Utilisez le tf.gather() op pour sélectionner une tranche non contiguë du tenseur.

    input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    
    output = tf.gather(input, 0)
    print sess.run(output)  # ==> [1 2 3]
    
    output = tf.gather(input, [0, 2])
    print sess.run(output)  # ==> [[1 2 3] [7 8 9]]
    

    Notez que tf.gather() ne vous permet de sélectionner que des tranches entières dans la 0ème dimension (lignes entières dans l'exemple d'une matrice). Il peut donc être nécessaire de tf.reshape() ou tf.transpose() votre entrée pour obtenir les éléments appropriés.

51
mrry

Je suppose que c'est le reste du calcul qui prend du temps, plutôt que d'accéder à un élément.

Le résultat peut également nécessiter une copie de la mémoire stockée. Par conséquent, si elle se trouve sur la carte graphique, elle devra être recopiée vers RAM d’abord, puis vous aurez accès à votre élément. Si Dans ce cas, vous pouvez l'ignorer en ajoutant une opération tensorflow pour prendre le premier élément et ne le renvoyer que.

1
Sorin

Vous ne pouvez simplement pas obtenir valeur du 0ème élément de [[1,2,3]] sans exécuter () - ning ou eval () - une opération qui l'obtiendrait. Parce qu'avant de "courir" ou "eval", vous avez seulement une description de la façon d'obtenir cet élément interne (parce que TF utilise des graphiques/calculs symboliques). Donc, même si vous utilisiez tf.gather/tf.slice, il vous faudrait quand même obtenir valeurs pour ces opérations via eval/run. Voir la réponse de @ mrry.

1