web-dev-qa-db-fra.com

Comment utiliser numba sur une fonction membre d'une classe?

J'utilise la version stable de Numba 0.30.1.

Je peux le faire:

import numba as nb
@nb.jit("void(f8[:])",nopython=True)                             
def complicated(x):                                  
    for a in x:
        b = a**2.+a**3.

comme test et l'accélération est énorme. Mais je ne sais pas comment procéder si j'ai besoin d'accélérer une fonction à l'intérieur d'une classe.

import numba as nb
def myClass(object):
    def __init__(self):
        self.k = 1
    #@nb.jit(???,nopython=True)                             
    def complicated(self,x):                                  
        for a in x:
            b = a**2.+a**3.+self.k

Quel type numba dois-je utiliser pour l'objet self? J'ai besoin d'avoir cette fonction dans une classe car elle doit accéder à une variable membre.

17
dbrane

J'étais dans une situation très similaire et j'ai trouvé un moyen d'utiliser une fonction Numba-JITed à l'intérieur d'une classe.

L'astuce consiste à utiliser une méthode statique, car ce type de méthodes n'est pas appelé en ajoutant l'ajout de l'instance d'objet à la liste des arguments. L'inconvénient de ne pas avoir accès à self est que vous ne pouvez pas utiliser de variables définies en dehors de la méthode. Vous devez donc les transmettre à la méthode statique à partir d'une méthode d'appel ayant accès à self. Dans mon cas, je n'avais pas besoin de définir une méthode wrapper. Je devais juste diviser la méthode que je voulais compiler en JIT en deux méthodes.

Dans le cas de votre exemple, la solution serait:

from numba import jit

class MyClass:
    def __init__(self):
        self.k = 1

    def calculation(self):
        k = self.k
        return self.complicated([1,2,3],k)

    @staticmethod
    @jit(nopython=True)                             
    def complicated(x,k):                                  
        for a in x:
            b = a**2 .+ a**3 .+ k
23
Marduk

Vous avez plusieurs options:

Utilisez un jitclass ( http://numba.pydata.org/numba-doc/0.30.1/user/jitclass.html ) pour "numba-ize" le tout.

Ou faites de la fonction membre un wrapper et passez les variables membres à travers:

import numba as nb

@nb.jit
def _complicated(x, k):
    for a in x:
        b = a**2.+a**3.+k

class myClass(object):
    def __init__(self):
        self.k = 1

    def complicated(self,x):                                  
        _complicated(x, self.k)
9
JoshAdel