web-dev-qa-db-fra.com

Indexation d'un tenseur multidimensionnel avec un tenseur dans PyTorch

J'ai le code suivant:

a = torch.randint(0,10,[3,3,3,3])
b = torch.LongTensor([1,1,1,1])

J'ai un index multidimensionnel b et je veux l'utiliser pour sélectionner une seule cellule dans a. Si b n'était pas un tenseur, je pourrais faire:

a[1,1,1,1]

Qui renvoie la bonne cellule, mais:

a[b]

Ne fonctionne pas, car il sélectionne simplement a[1] quatre fois.

Comment puis-je faire ceci? Merci

8

Une solution plus élégante (et plus simple) pourrait être de simplement convertir b en Tuple:

a[Tuple(b)]
Out[10]: tensor(5.)

J'étais curieux de voir comment cela fonctionne avec numpy "ordinaire", et j'ai trouvé un article connexe expliquant cela assez bien ici .

5
dennlinger

Vous pouvez diviser b en 4 à l'aide de chunk , puis utiliser le fragmenté b pour indexer l'élément spécifique que vous souhaitez:

>> a = torch.arange(3*3*3*3).view(3,3,3,3)
>> b = torch.LongTensor([[1,1,1,1], [2,2,2,2], [0, 0, 0, 0]]).t()
>> a[b.chunk(chunks=4, dim=0)]   # here's the trick!
Out[24]: tensor([[40, 80,  0]])

Ce qui est bien, c'est qu'il peut être facilement généralisé à n'importe quelle dimension de a, il vous suffit de faire en sorte que le nombre de mandrins soit égal à la dimension de a.

5
Shai