web-dev-qa-db-fra.com

Fonction find () de style MATLAB dans Python

Dans MATLAB, il est facile de trouver les indices de valeurs qui répondent à une condition particulière:

>> a = [1,2,3,1,2,3,1,2,3];
>> find(a > 2)     % find the indecies where this condition is true
[3, 6, 9]          % (MATLAB uses 1-based indexing)
>> a(find(a > 2))  % get the values at those locations
[3, 3, 3]

Quelle serait la meilleure façon de le faire en Python?

Jusqu'à présent, j'ai trouvé ce qui suit. Pour obtenir simplement les valeurs:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> [val for val in a if val > 2]
[3, 3, 3]

Mais si je veux l'index de chacune de ces valeurs, c'est un peu plus compliqué:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> inds = [i for (i, val) in enumerate(a) if val > 2]
>>> inds
[2, 5, 8]
>>> [val for (i, val) in enumerate(a) if i in inds]
[3, 3, 3]

Existe-t-il une meilleure façon de le faire en Python, en particulier pour des conditions arbitraires (pas seulement 'val> 2')?

J'ai trouvé des fonctions équivalentes à MATLAB 'find' dans NumPy mais je n'ai actuellement pas accès à ces bibliothèques.

55
user344226

Vous pouvez créer une fonction qui prend un paramètre appelable qui sera utilisé dans la partie condition de votre compréhension de liste. Ensuite, vous pouvez utiliser un lambda ou un autre objet fonction pour passer votre condition arbitraire:

def indices(a, func):
    return [i for (i, val) in enumerate(a) if func(val)]

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

inds = indices(a, lambda x: x > 2)

>>> inds
[2, 5, 8]

C'est un peu plus proche de votre exemple Matlab, sans avoir à charger tout numpy.

26
John

dans numpy vous avez where:

>> import numpy as np
>> x = np.random.randint(0, 20, 10)
>> x
array([14, 13,  1, 15,  8,  0, 17, 11, 19, 13])
>> np.where(x > 10)
(array([0, 1, 3, 6, 7, 8, 9], dtype=int64),)
83
joaquin

Ou utilisez la fonction non nulle de numpy:

import numpy as np
a    = np.array([1,2,3,4,5])
inds = np.nonzero(a>2)
a[inds] 
array([3, 4, 5])
8
vincentv

Pourquoi ne pas simplement utiliser ceci:

[i for i in range(len(a)) if a[i] > 2]

ou pour des conditions arbitraires, définissez une fonction f pour votre condition et faites:

[i for i in range(len(a)) if f(a[i])]
5
JasonFruit

La routine numpy la plus couramment utilisée pour cette application est numpy.where() ; cependant, je crois que cela fonctionne de la même manière que numpy.nonzero() .

import numpy
a    = numpy.array([1,2,3,4,5])
inds = numpy.where(a>2)

Pour obtenir les valeurs, vous pouvez soit stocker les indices et les découper avec:

a[inds]

ou vous pouvez passer le tableau comme paramètre facultatif:

numpy.where(a>2, a)

ou plusieurs tableaux:

b = numpy.array([11,22,33,44,55])
numpy.where(a>2, a, b)
4
ryanjdillon

J'ai essayé de trouver un moyen rapide de faire exactement cette chose, et voici ce que je suis tombé sur (utilise numpy pour sa comparaison rapide de vecteur):

a_bool = numpy.array(a) > 2
inds = [i for (i, val) in enumerate(a_bool) if val]

Il s'avère que c'est beaucoup plus rapide que:

inds = [i for (i, val) in enumerate(a) if val > 2]

Il semble que Python est plus rapide à la comparaison lorsqu'il est fait dans un tableau numpy, et/ou plus rapide à faire des listes de compréhension lors de la vérification de la vérité plutôt que de la comparaison.

Modifier:

Je revoyais mon code et je suis tombé sur une façon peut-être moins gourmande en mémoire, un peu plus rapide et ultra-concise de le faire en une seule ligne:

inds = np.arange( len(a) )[ a < 2 ]
3
Nate

Pour obtenir des valeurs avec des conditions arbitraires, vous pouvez utiliser filter() avec une fonction lambda:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> filter(lambda x: x > 2, a)
[3, 3, 3]

Une façon possible d'obtenir les indices serait d'utiliser enumerate() pour construire un Tuple avec des indices et des valeurs, puis de filtrer cela:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> aind = Tuple(enumerate(a))
>>> print aind
((0, 1), (1, 2), (2, 3), (3, 1), (4, 2), (5, 3), (6, 1), (7, 2), (8, 3))
>>> filter(lambda x: x[1] > 2, aind)
((2, 3), (5, 3), (8, 3))
3
Blair

Je pense que j'ai peut-être trouvé un substitut rapide et simple. BTW J'ai senti que la fonction np.where () n'était pas très satisfaisante, en ce sens qu'elle contenait en quelque sorte une ligne ennuyeuse de zéro élément.

import matplotlib.mlab as mlab
a = np.random.randn(1,5)
print a

>> [[ 1.36406736  1.45217257 -0.06896245  0.98429727 -0.59281957]]

idx = mlab.find(a<0)
print idx
type(idx)

>> [2 4]
>> np.ndarray

Best, Da

2
DidasW

Le code de recherche de Matlab a deux arguments. Le code de John représente le premier argument mais pas le second. Par exemple, si vous voulez savoir où dans l'index la condition est satisfaite: la fonction de Mtlab serait:

find(x>2,1)

En utilisant le code de John, tout ce que vous avez à faire est d'ajouter un [x] à la fin de la fonction index, où x est le numéro d'index que vous recherchez.

def indices(a, func):
    return [i for (i, val) in enumerate(a) if func(val)]

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

inds = indices(a, lambda x: x > 2)[0] #[0] being the 2nd matlab argument

qui renvoie >>> 2, le premier indice à dépasser 2.

0
Clayton Pipkin