web-dev-qa-db-fra.com

Puis-je découper des tenseurs avec une indexation logique ou des listes d'indices?

J'essaie de découper un tenseur PyTorch en utilisant un index logique sur les colonnes. Je veux les colonnes qui correspondent à une valeur 1 dans le vecteur d'index. Le découpage en tranches et l'indexation logique sont possibles, mais sont-ils possibles ensemble? Si c'est le cas, comment? Ma tentative continue de renvoyer l'erreur inutile

TypeError: indexation d'un tenseur avec un objet de type ByteTensor. Les seuls types pris en charge sont les entiers, les tranches, les scalaires numpy et torch.LongTensor ou torch.ByteTensor comme seul argument.

MCVE

Sortie désirée

C = torch.LongTensor([[1, 3], [4, 6]])
# 1 3
# 4 6

Indexation logique sur les colonnes uniquement

import torch
A_log = torch.ByteTensor([1, 0, 1]) # the logical index
B = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
C = B[:, A_log] # Throws error

J'ai aussi essayé d'utiliser une liste d'indices

import torch
A_idx = torch.LongTensor([0, 2]) # the index vector
B = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
C = B[:, A_idx] # Throws error

Si les vecteurs sont de la même taille, l'indexation logique fonctionne

import torch
A_log = torch.ByteTensor([1, 0, 1]) # the logical index
B = torch.LongTensor([1, 2, 3])
C = B[A_log]

Si j'utilise des plages d'indices contigus, le découpage fonctionne

import torch
B = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
C = B[:, 1:2]

Je peux obtenir le résultat souhaité en répétant l'index logique afin qu'il ait la même taille que le tenseur que j'indexe, mais je dois également remodeler la sortie.

import torch
A_log = torch.ByteTensor([1, 0, 1]) # the logical index
B = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
C = B[A_log.repeat(2, 1)] # [torch.LongTensor of size 4]
C = C.resize_(2, 2)
9
Cecilia

Je pense que cela est mis en œuvre comme index_select fonction, vous pouvez essayer

import torch
A_idx = torch.LongTensor([0, 2]) # the index vector
B = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
C = B.index_select(1, A_idx)
# 1 3
# 4 6
6
dontloo