web-dev-qa-db-fra.com

Comment fonctionne python numpy.where ()?

Je joue avec numpy et fouille dans la documentation et j'ai découvert un peu de magie. À savoir je parle de numpy.where():

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))

Comment font-ils en interne pour que vous puissiez passer quelque chose comme x > 5 dans une méthode? Je suppose que cela a quelque chose à voir avec __gt__ mais je cherche une explication détaillée.

88
pajton

Comment réalisent-ils en interne que vous êtes capable de passer quelque chose comme x> 5 dans une méthode?

La réponse courte est qu'ils ne le font pas.

Toute sorte d'opération logique sur un tableau numpy renvoie un tableau booléen. (c'est-à-dire __gt__, __lt__, etc., tous renvoient des tableaux booléens où la condition donnée est vraie).

Par exemple. 

x = np.arange(9).reshape(3,3)
print x > 5

rendements:

array([[False, False, False],
       [False, False, False],
       [ True,  True,  True]], dtype=bool)

C'est la même raison pour laquelle quelque chose comme if x > 5: lève une ValueError si x est un tableau numpy. C'est un tableau de valeurs True/False, pas une valeur unique.

De plus, les tableaux numpy peuvent être indexés par des tableaux booléens. Par exemple. x[x>5] donne [6 7 8], dans ce cas.

Honnêtement, il est assez rare que vous ayez réellement besoin de numpy.where, mais il ne renvoie que les indicateurs où un tableau booléen est True. En général, vous pouvez faire ce dont vous avez besoin avec une simple indexation booléenne.

72
Joe Kington

Old Answer C'est un peu déroutant. Cela vous donne les LIEUX (tous) où votre affirmation est vraie.

alors:

>>> a = np.arange(100)
>>> np.where(a > 30)
(array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
       48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
       65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
       82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
       99]),)
>>> np.where(a == 90)
(array([90]),)

a = a*40
>>> np.where(a > 1000)
(array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
       60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
       77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
       94, 95, 96, 97, 98, 99]),)
>>> a[25]
1000
>>> a[26]
1040

Je l'utilise comme alternative à list.index (), mais il a également de nombreuses autres utilisations. Je ne l'ai jamais utilisé avec des tableaux 2D.

http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

Nouvelle réponse Il semble que la personne demandait quelque chose de plus fondamental.

La question était de savoir comment VOUS pourriez implémenter quelque chose qui permet à une fonction (telle que où) de savoir ce qui a été demandé.

Tout d’abord, notez que l’appel de l’un des opérateurs de comparaison est une chose intéressante.

a > 1000
array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True`,  True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)`

Ceci est fait en surchargeant la méthode "__gt__". Par exemple:

>>> class demo(object):
    def __gt__(self, item):
        print item


>>> a = demo()
>>> a > 4
4

Comme vous pouvez le constater, "a> 4" était un code valide.

Vous pouvez obtenir la liste complète et la documentation de toutes les fonctions surchargées ici: http://docs.python.org/reference/datamodel.html

Quelque chose d’incroyable est la simplicité de cette démarche. TOUTES les opérations en python sont effectuées de cette manière. Dire a> b est équivalent à a .gt (b)!

23
Garrett Berg

np.where renvoie un tuple de longueur égal à la dimension du numpy ndarray sur lequel il est appelé (en d'autres termes ndim) et chaque élément de Tuple est un numpy ndarray d'indices de toutes les valeurs du ndarray initial pour lesquelles la condition est True. . (S'il vous plaît ne confondez pas la dimension avec la forme)

Par exemple:

x=np.arange(9).reshape(3,3)
print(x)
array([[0, 1, 2],
      [3, 4, 5],
      [6, 7, 8]])
y = np.where(x>4)
print(y)
array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))


y est un tuple de longueur 2 car x.ndim est égal à 2. Le premier élément de Tuple contient les numéros de ligne de tous les éléments supérieurs à 4 et le deuxième élément contient les numéros de colonne de tous les éléments supérieurs à 4. Comme vous pouvez le constater, [1 , 2,2,2] correspond aux numéros de ligne 5,6,7,8 et [2,0,1,2] correspond aux numéros de colonne 5,6,7,8 Notez que le ndarray est parcouru le long de la première dimension (rangée).

De même,

x=np.arange(27).reshape(3,3,3)
np.where(x>4)


retournera un tuple de longueur 3 car x a 3 dimensions.

Mais attendez, il y a plus à np.where!

lorsque deux arguments supplémentaires sont ajoutés à np.where; il effectuera une opération de remplacement pour toutes les combinaisons paires-rangées-colonnes obtenues par le tuple ci-dessus. 

x=np.arange(9).reshape(3,3)
y = np.where(x>4, 1, 0)
print(y)
array([[0, 0, 0],
   [0, 0, 1],
   [1, 1, 1]])
0
Piyush Singh