web-dev-qa-db-fra.com

Comprendre PyTorch einsum

Je sais comment einsum fonctionne dans NumPy. PyTorch offre également une fonctionnalité similaire: torch.einsum () . Quelles sont les similitudes et les différences, en termes de fonctionnalités ou de performances? Les informations disponibles dans la documentation de PyTorch sont plutôt rares et ne fournissent aucun aperçu à ce sujet.

8
kmario23

Étant donné que la description d'einsum est maigre dans la documentation de la torche, j'ai décidé d'écrire ce message pour documenter, comparer et contraster la façon dont torch.einsum() se comporte par rapport à numpy.einsum() .

Différences:

  • NumPy autorise les minuscules et les majuscules [a-zA-Z] Pour la " chaîne d'indice" tandis que PyTorch n'autorise que les minuscules [a-z].

  • NumPy accepte nd-tableaux, simples Python listes (ou tuples), liste de listes (ou Tuple de tuples, liste de tuples, Tuple de listes) ou même tenseurs PyTorch comme opérandes (c'est-à-dire entrées). Cela est dû au fait que les opérandes doivent seulement être array_like et pas strictement des nd-tableaux NumPy. Au contraire, PyTorch s'attend à ce que les opérandes (c'est-à-dire les entrées) soient strictement des tenseurs PyTorch. Il lancera un TypeError si vous passez soit un simple Python listes/tuples (ou ses combinaisons) soit des nd-tableaux NumPy.

  • NumPy prend en charge de nombreux arguments de mots clés (par exemple optimize) en plus de nd-arrays Tandis que PyTorch n'offre pas encore une telle flexibilité.

Voici les implémentations de quelques exemples à la fois dans PyTorch et NumPy:

# input tensors to work with

In [16]: vec
Out[16]: tensor([0, 1, 2, 3])

In [17]: aten
Out[17]: 
tensor([[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]])

In [18]: bten
Out[18]: 
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]])

1) Multiplication matricielle
PyTorch: torch.matmul(aten, bten); aten.mm(bten)
NumPy: np.einsum("ij, jk -> ik", arr1, arr2)

In [19]: torch.einsum('ij, jk -> ik', aten, bten)
Out[19]: 
tensor([[130, 130, 130, 130],
        [230, 230, 230, 230],
        [330, 330, 330, 330],
        [430, 430, 430, 430]])

2) Extraire les éléments le long de la diagonale principale
PyTorch: torch.diag(aten)
NumPy: np.einsum("ii -> i", arr)

In [28]: torch.einsum('ii -> i', aten)
Out[28]: tensor([11, 22, 33, 44])

3) Produit Hadamard (c'est-à-dire produit élément par élément de deux tenseurs)
PyTorch: aten * bten
NumPy: np.einsum("ij, ij -> ij", arr1, arr2)

In [34]: torch.einsum('ij, ij -> ij', aten, bten)
Out[34]: 
tensor([[ 11,  12,  13,  14],
        [ 42,  44,  46,  48],
        [ 93,  96,  99, 102],
        [164, 168, 172, 176]])

4) Équerrage au niveau des éléments
PyTorch: aten ** 2
NumPy: np.einsum("ij, ij -> ij", arr, arr)

In [37]: torch.einsum('ij, ij -> ij', aten, aten)
Out[37]: 
tensor([[ 121,  144,  169,  196],
        [ 441,  484,  529,  576],
        [ 961, 1024, 1089, 1156],
        [1681, 1764, 1849, 1936]])

Général: la puissance par élément nth peut être implémentée en répétant la chaîne d'indice et le tenseur n fois. Par exemple, le calcul de la 4e puissance par élément d'un tenseur peut être effectué en utilisant:

# NumPy: np.einsum('ij, ij, ij, ij -> ij', arr, arr, arr, arr)
In [38]: torch.einsum('ij, ij, ij, ij -> ij', aten, aten, aten, aten)
Out[38]: 
tensor([[  14641,   20736,   28561,   38416],
        [ 194481,  234256,  279841,  331776],
        [ 923521, 1048576, 1185921, 1336336],
        [2825761, 3111696, 3418801, 3748096]])

5) Trace (c'est-à-dire somme des éléments diagonaux principaux)
PyTorch: torch.trace(aten)
NumPy einsum: np.einsum("ii -> ", arr)

In [44]: torch.einsum('ii -> ', aten)
Out[44]: tensor(110)

6) Transposition matricielle
PyTorch: torch.transpose(aten, 1, 0)
NumPy einsum: np.einsum("ij -> ji", arr)

In [58]: torch.einsum('ij -> ji', aten)
Out[58]: 
tensor([[11, 21, 31, 41],
        [12, 22, 32, 42],
        [13, 23, 33, 43],
        [14, 24, 34, 44]])

7) Produit extérieur (de vecteurs)
PyTorch: torch.ger(vec, vec)
NumPy einsum: np.einsum("i, j -> ij", vec, vec)

In [73]: torch.einsum('i, j -> ij', vec, vec)
Out[73]: 
tensor([[0, 0, 0, 0],
        [0, 1, 2, 3],
        [0, 2, 4, 6],
        [0, 3, 6, 9]])

8) Produit intérieur (des vecteurs) PyTorch: torch.dot(vec1, vec2)
NumPy einsum: np.einsum("i, i -> ", vec1, vec2)

In [76]: torch.einsum('i, i -> ', vec, vec)
Out[76]: tensor(14)

9) Somme le long de l'axe 0
PyTorch: torch.sum(aten, 0)
NumPy einsum: np.einsum("ij -> j", arr)

In [85]: torch.einsum('ij -> j', aten)
Out[85]: tensor([104, 108, 112, 116])

10) Somme le long de l'axe 1
PyTorch: torch.sum(aten, 1)
NumPy einsum: np.einsum("ij -> i", arr)

In [86]: torch.einsum('ij -> i', aten)
Out[86]: tensor([ 50,  90, 130, 170])

11) Multiplication de matrice de lots
PyTorch: torch.bmm(batch_tensor_1, batch_tensor_2)
NumPy: np.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)

# input batch tensors to work with
In [13]: batch_tensor_1 = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
In [14]: batch_tensor_2 = torch.arange(2 * 3 * 4).reshape(2, 3, 4) 

In [15]: torch.bmm(batch_tensor_1, batch_tensor_2)  
Out[15]: 
tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])

# sanity check with the shapes
In [16]: torch.bmm(batch_tensor_1, batch_tensor_2).shape 
Out[16]: torch.Size([2, 4, 4])

# batch matrix multiply using einsum
In [17]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)
Out[17]: 
tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])

# sanity check with the shapes
In [18]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2).shape

12) Somme le long de l'axe 2
PyTorch: torch.sum(batch_ten, 2)
NumPy einsum: np.einsum("ijk -> ij", arr3D)

In [99]: torch.einsum("ijk -> ij", batch_ten)
Out[99]: 
tensor([[ 50,  90, 130, 170],
        [  4,   8,  12,  16]])

13) Additionner tous les éléments d'un tenseur nD
PyTorch: torch.sum(batch_ten)
NumPy einsum: np.einsum("ijk -> ", arr3D)

In [101]: torch.einsum("ijk -> ", batch_ten)
Out[101]: tensor(480)

14) Somme sur plusieurs axes (c'est-à-dire marginalisation)
PyTorch: torch.sum(arr, dim=(dim0, dim1, dim2, dim3, dim4, dim6, dim7))
NumPy: np.einsum("ijklmnop -> n", nDarr)

# 8D tensor
In [103]: nDten = torch.randn((3,5,4,6,8,2,7,9))
In [104]: nDten.shape
Out[104]: torch.Size([3, 5, 4, 6, 8, 2, 7, 9])

# marginalize out dimension 5 (i.e. "n" here)
In [111]: esum = torch.einsum("ijklmnop -> n", nDten)
In [112]: esum
Out[112]: tensor([  98.6921, -206.0575])

# marginalize out axis 5 (i.e. sum over rest of the axes)
In [113]: tsum = torch.sum(nDten, dim=(0, 1, 2, 3, 4, 6, 7))

In [115]: torch.allclose(tsum, esum)
Out[115]: True

15) Produits à double point/ produit intérieur Frobenius (identique à: torch.sum (produit hadamard) cf. 3)
PyTorch: torch.sum(aten * bten)
NumPy: np.einsum("ij, ij -> ", arr1, arr2)

In [120]: torch.einsum("ij, ij -> ", aten, bten)
Out[120]: tensor(1300)
25
kmario23