web-dev-qa-db-fra.com

Comment faire le produit de matrices dans PyTorch

Numpy je peux faire une simple multiplication matricielle comme ceci:

a = numpy.arange(2*3).reshape(3,2)
b = numpy.arange(2).reshape(2,1)
print(a)
print(b)
print(a.dot(b))

Cependant, lorsque j'essaie avec PyTorch Tensors, cela ne fonctionne pas:

a = torch.Tensor([[1, 2, 3], [1, 2, 3]]).view(-1, 2)
b = torch.Tensor([[2, 1]]).view(2, -1)
print(a)
print(a.size())

print(b)
print(b.size())

print(torch.dot(a, b))

Ce code lève l'erreur suivante:

RuntimeError: taille de tenseur incohérente dans /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:503

Des idées sur la façon dont la multiplication matricielle peut être réalisée dans PyTorch?

35
blckbird

Vous cherchez

torch.mm(a,b)

Notez que torch.dot() se comporte différemment de np.dot(). Il y a eu des discussions sur ce qui serait souhaitable ici . Plus précisément, torch.dot() traite les deux a et b comme des vecteurs 1D (quelle que soit leur forme d'origine) et calcule leur produit intérieur. L'erreur est renvoyée, car ce comportement rend votre a un vecteur de longueur 6 et votre b un vecteur de longueur 2; par conséquent, leur produit intérieur ne peut pas être calculé. Pour la multiplication de matrice dans PyTorch, utilisez torch.mm(). np.dot() de Numpy's, en revanche, est plus flexible; il calcule le produit interne pour les tableaux 1D et effectue la multiplication de matrice pour les tableaux 2D.

52
mexmex

Si vous voulez faire une multiplication matricielle (tenseur de rang 2), vous pouvez le faire de quatre manières équivalentes:

AB = A.mm(B) # computes A.B (matrix multiplication)
# or
AB = torch.mm(A, B)
# or
AB = torch.matmul(A, B)
# or, even simpler
AB = A @ B # Python 3.5+

Il y a quelques subtilités. De la documentation PyTorch :

torch.mm ne diffuse pas. Pour les produits de diffusion matriciels, voir torch.matmul ().

Par exemple, vous ne pouvez pas multiplier deux vecteurs à une dimension avec torch.mm, ni multiplier des matrices groupées (rang 3). À cette fin, vous devriez utiliser le plus polyvalent torch.matmul. Pour une liste complète des comportements de diffusion de torch.matmul, voir le documentation .

Pour multiplier les éléments, vous pouvez simplement faire (si A et B ont la même forme)

A * B # element-wise matrix multiplication (Hadamard product)
26
BiBi

Utilisez torch.mm(a, b) ou torch.matmul(a, b)
Les deux sont identiques.

>>> torch.mm
<built-in method mm of type object at 0x11712a870>
>>> torch.matmul
<built-in method matmul of type object at 0x11712a870>

Il y a une autre option qu'il serait bon de connaître. C'est l'opérateur @. @Simon H.

>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> a@b
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.mm(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.matmul(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])    

Les trois donnent les mêmes résultats.

Liens connexes:
Opérateur de multiplication matricielle
PEP 465 - Un opérateur infixe dédié à la multiplication matricielle

4
David Jung