web-dev-qa-db-fra.com

Opération Pytorch pour détecter les NaN

Existe-t-il une procédure interne à Pytorch pour détecter les NaN dans les tenseurs? Tensorflow a les opérations tf.is_nan Et tf.check_numerics ... Pytorch a quelque chose de similaire quelque part? Je n'ai pas pu trouver quelque chose comme ça dans les documents ...

Je recherche spécifiquement une routine interne Pytorch, car j'aimerais que cela se produise sur le GPU ainsi que sur le CPU. Cela exclut les solutions basées sur numpy (comme np.isnan(sometensor.numpy()).any()) ...

23
cleros

Vous pouvez toujours tirer parti du fait que nan != nan:

>>> x = torch.tensor([1, 2, np.nan])
tensor([  1.,   2., nan.])
>>> x != x
tensor([ 0,  0,  1], dtype=torch.uint8)

Avec pytorch 0.4, il y a aussi torch.isnan :

>>> torch.isnan(x)
tensor([ 0,  0,  1], dtype=torch.uint8)
37
nemo

À partir de PyTorch 0.4.1, il y a le gestionnaire de contexte detect_anomaly , qui insère automatiquement des assertions équivalentes à assert not torch.isnan(grad).any() entre toutes les étapes de propagation en arrière. Il est très utile lorsque des problèmes surviennent lors d'un passage en arrière.

22
Jatentaki