web-dev-qa-db-fra.com

Choix aléatoire avec Pytorch?

J'ai un tenseur d'images et j'aimerais en sélectionner au hasard. Je cherche l'équivalent de np.random.choice().

import torch

pictures = torch.randint(0, 256, (1000, 28, 28, 3))

Disons que je veux 10 de ces photos.

6
Nicolas Gervais

torch n'a pas d'implémentation équivalente de np.random.choice(). La meilleure chose à faire est de choisir un index aléatoire parmi les choix.

choices[torch.randint(choices.shape[0], (1,))]

Cela génère un randint entre 0 et le nombre d'éléments dans le tenseur.

for i in range(5):
    print(choices[torch.randint(choices.shape[0], (1,))])
tensor([2])
tensor([6])
tensor([2])
tensor([6])
tensor([7])

Si vous souhaitez définir replacement = False, supprimez la valeur choisie à l'aide d'un masque:

for i in range(10):
    value = choices[torch.randint(choices.shape[0], (1,))]
    choices = choices[choices!=value]
    print(value)
tensor([2])
tensor([4])
tensor([6])
tensor([7])
0
Nicolas Gervais

Dans mon cas: values.shape = (386363948, 2), k = 190973, le code suivant fonctionne assez rapidement, 0,1 ~ 0,2 seconde.

indice = random.sample(range(386363948), 190973)
indice = torch.tensor(indice)
sampled_values = values[indice]

Cependant, utiliser torch.randperm coûterait plus de 20 secondes.

sampled_values = values[torch.randperm(386363948)[190973]]
1
刘致远

comme l'autre mentionné, la torche n'a pas le choix, vous pouvez utiliser randint ou permutation à la place

import torch
n = 4
choices = torch.Rand(4, 3)
choices_flat = choices.view(-1)
index = torch.randint(choices_flat.numel(), (n,))
# or if replace = False
index = torch.randperm(choices_flat.numel())[:n]
select = choices_flat[index]
0
Qianyi Zhang