J'ai un tenseur d'images et j'aimerais en sélectionner au hasard. Je cherche l'équivalent de np.random.choice()
.
import torch
pictures = torch.randint(0, 256, (1000, 28, 28, 3))
Disons que je veux 10 de ces photos.
torch
n'a pas d'implémentation équivalente de np.random.choice()
. La meilleure chose à faire est de choisir un index aléatoire parmi les choix.
choices[torch.randint(choices.shape[0], (1,))]
Cela génère un randint
entre 0 et le nombre d'éléments dans le tenseur.
for i in range(5):
print(choices[torch.randint(choices.shape[0], (1,))])
tensor([2])
tensor([6])
tensor([2])
tensor([6])
tensor([7])
Si vous souhaitez définir replacement = False
, supprimez la valeur choisie à l'aide d'un masque:
for i in range(10):
value = choices[torch.randint(choices.shape[0], (1,))]
choices = choices[choices!=value]
print(value)
tensor([2])
tensor([4])
tensor([6])
tensor([7])
Dans mon cas: values.shape = (386363948, 2), k = 190973, le code suivant fonctionne assez rapidement, 0,1 ~ 0,2 seconde.
indice = random.sample(range(386363948), 190973)
indice = torch.tensor(indice)
sampled_values = values[indice]
Cependant, utiliser torch.randperm coûterait plus de 20 secondes.
sampled_values = values[torch.randperm(386363948)[190973]]
comme l'autre mentionné, la torche n'a pas le choix, vous pouvez utiliser randint ou permutation à la place
import torch
n = 4
choices = torch.Rand(4, 3)
choices_flat = choices.view(-1)
index = torch.randint(choices_flat.numel(), (n,))
# or if replace = False
index = torch.randperm(choices_flat.numel())[:n]
select = choices_flat[index]