web-dev-qa-db-fra.com

Différence entre tensor.permute et tensor.view dans PyTorch?

Quelle est la différence entre tensor.permute() et tensor.view()?

Ils semblent faire la même chose.

9
samol

Contribution

In [12]: aten = torch.tensor([[1, 2, 3], [4, 5, 6]])

In [13]: aten
Out[13]: 
tensor([[ 1,  2,  3],
        [ 4,  5,  6]])

In [14]: aten.shape
Out[14]: torch.Size([2, 3])

torch.view() remodèle le tenseur en une forme différente mais compatible. Par exemple, notre tenseur d'entrée aten a la forme (2, 3). Cela peut être affiché comme tenseurs de formes (6, 1), (1, 6) Etc.,

# reshaping (or viewing) 2x3 matrix as a column vector of shape 6x1
In [15]: aten.view(6, -1)
Out[15]: 
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6]])

In [16]: aten.view(6, -1).shape
Out[16]: torch.Size([6, 1])

Alternativement, il peut également être remodelé ou voir ed comme un vecteur ligne de forme (1, 6) Comme dans:

In [19]: aten.view(-1, 6)
Out[19]: tensor([[ 1,  2,  3,  4,  5,  6]])

In [20]: aten.view(-1, 6).shape
Out[20]: torch.Size([1, 6])

Tandis que tensor.permute() est uniquement utilisé pour permuter les axes. L'exemple ci-dessous clarifiera les choses:

In [39]: aten
Out[39]: 
tensor([[ 1,  2,  3],
        [ 4,  5,  6]])

In [40]: aten.shape
Out[40]: torch.Size([2, 3])

# swapping the axes/dimensions 0 and 1
In [41]: aten.permute(1, 0)
Out[41]: 
tensor([[ 1,  4],
        [ 2,  5],
        [ 3,  6]])

# since we permute the axes/dims, the shape changed from (2, 3) => (3, 2)
In [42]: aten.permute(1, 0).shape
Out[42]: torch.Size([3, 2])

Vous pouvez également utiliser une indexation négative pour faire la même chose que dans:

In [45]: aten.permute(-1, 0)
Out[45]: 
tensor([[ 1,  4],
        [ 2,  5],
        [ 3,  6]])

In [46]: aten.permute(-1, 0).shape
Out[46]: torch.Size([3, 2])
4
kmario23

La vue change la façon dont le tenseur est représenté. Par exemple: un tenseur à 4 éléments peut être représenté comme 4X1 ou 2X2 ou 1X4 mais permute change les axes. Pendant la permutation, les données sont déplacées, mais avec la vue, les données ne sont pas déplacées mais simplement réinterprétées.

Les exemples de code ci-dessous peuvent vous aider. a est un tenseur/matrice 2x2. Avec l'utilisation de la vue, vous pouvez lire a comme un vecteur de colonne ou de ligne (tenseur). Mais vous ne pouvez pas le transposer. Pour transposer, vous avez besoin de permuter. La transposition est obtenue en permutant/permutant les axes.

In [7]: import torch

In [8]: a = torch.tensor([[1,2],[3,4]])

In [9]: a
Out[9]: 
tensor([[ 1,  2],
        [ 3,  4]])

In [11]: a.permute(1,0)
Out[11]: 
tensor([[ 1,  3],
        [ 2,  4]])

In [12]: a.view(4,1)
Out[12]: 
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4]])

In [13]: 

Bonus: Voir https://Twitter.com/karpathy/status/1013322763790999552

3
Umang Gupta