web-dev-qa-db-fra.com

Que fait data.norm () <1000 dans PyTorch?

Je suis le tutoriel PyTorch ici . Il dit que

x = torch.randn(3, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
    y = y * 2

print(y)

Out:    
tensor([-590.4467,   97.6760,  921.0221])

Quelqu'un pourrait-il expliquer ce que data.norm () fait ici? Lorsque je remplace .randn par .ones, sa sortie est tensor([ 1024., 1024., 1024.]).

7
voo_doo

C'est simplement la norme L2 (norme euclidienne) du tenseur. Ci-dessous une illustration:

In [15]: x = torch.randn(3, requires_grad=True)

In [16]: y = x * 2

In [17]: y.data
Out[17]: tensor([-1.2510, -0.6302,  1.2898])

In [18]: y.data.norm()
Out[18]: tensor(1.9041)

# computing the norm using elementary operations
In [19]: torch.sqrt(torch.sum(torch.pow(y, 2)))
Out[19]: tensor(1.9041)

Tout d'abord, il place tous les éléments dans tenseur y, puis les additionne et prend finalement une racine carrée. Ces opérations calculent la norme dite L2.

3
kmario23

En s'appuyant sur ce que @ kmario23 dit, il multiplie par 2 les éléments d'un vecteur jusqu'à ce que la distance/magnitude euclidienne du vecteur soit d'au moins 1000.

Avec l'exemple du vecteur avec (1,1,1): il passe à (512, 512, 512), où la norme l2 est d'environ 886. Cela est inférieur à 1000, il est donc multiplié par 2 et devient ( 1024, 1024, 1024). Cela a une magnitude supérieure à 1000, donc ça s'arrête.

1
Jonathan
y.data.norm() 

est équivalent à

torch.sqrt(torch.sum(torch.pow(y, 2)))
0
aimuch