web-dev-qa-db-fra.com

RuntimeWarning: valeur non valide rencontrée en plus grand

J'ai essayé d'implémenter soft-max avec le code suivant (out_vec Est un vecteur numpy de floats):

numerator = np.exp(out_vec)
denominator = np.sum(np.exp(out_vec))
out_vec = numerator/denominator

Cependant, j'ai eu une erreur de débordement à cause de np.exp(out_vec). Par conséquent, j'ai vérifié (manuellement) quelle est la limite supérieure de np.exp() et trouve que np.exp(709) est un nombre, mais np.exp(710) est considéré comme étant np.inf. Ainsi, pour éviter l'erreur de débordement, j'ai modifié mon code comme suit:

out_vec[out_vec > 709] = 709 #prevent np.exp overflow
numerator = np.exp(out_vec)
denominator = np.sum(np.exp(out_vec))
out_vec = numerator/denominator

Maintenant, je reçois une erreur différente:

RuntimeWarning: invalid value encountered in greater out_vec[out_vec > 709] = 709

Quel est le problème avec la ligne que j'ai ajoutée? J'ai recherché cette erreur spécifique et tout ce que j'ai trouvé est le conseil des gens sur la façon de l'ignorer. Ignorer simplement l'erreur ne m'aidera pas, car chaque fois que mon code rencontre cette erreur, il ne donne pas les résultats habituels.

19
Cheshie

Votre problème est dû aux éléments NaN ou Inf de votre out_vec tableau. Vous pouvez utiliser le code suivant pour éviter ce problème:

if np.isnan(np.sum(out_vec)):
    out_vec = out_vec[~numpy.isnan(out_vec)] # just remove nan elements from vector
out_vec[out_vec > 709] = 709
...

ou vous pouvez utiliser le code suivant pour laisser les valeurs NaN dans votre tableau:

out_vec[ np.array([e > 709 if ~np.isnan(e) else False for e in out_vec], dtype=bool) ] = 709
27
kvorobiev

Dans mon cas, l'avertissement n'apparaissait pas lorsque vous appelez cela avant la comparaison (les valeurs de NaN étaient comparées)

np.warnings.filterwarnings('ignore')
12
juerg

Le meilleur moyen à l’OMI serait d’utiliser une implémentation plus stable numériquement de la somme des exponentielles.

from scipy.misc import logsumexp
out_vec = np.exp(out_vec - logsumexp(out_vec))
6
Ramin Barati