web-dev-qa-db-fra.com

Quelle est la différence entre tf.truncated_normal et tf.random_normal?

tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None) renvoie des valeurs aléatoires à partir d'une distribution normale.

tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None) renvoie des valeurs aléatoires à partir d'une distribution normale tronquée.

J'ai essayé de googler "distribution normale tronquée". Mais je n'ai pas compris grand chose.

44
Tarun Wadhwa

La documentation dit tout: pour une distribution normale tronquée:

Les valeurs générées suivent une distribution normale avec une moyenne et un écart type spécifiés, à l'exception des valeurs dont l'amplitude est supérieure à 2 écarts types par rapport à la moyenne, sont supprimées et recalculées.

Le plus probablement, il est facile de comprendre la différence en traçant le graphique pour vous-même (% magic est parce que j'utilise jupyter notebook):

import tensorflow as tf
import matplotlib.pyplot as plt

%matplotlib inline  

n = 500000
A = tf.truncated_normal((n,))
B = tf.random_normal((n,))
with tf.Session() as sess:
    a, b = sess.run([A, B])

Et maintenant

plt.hist(a, 100, (-4.2, 4.2));
plt.hist(b, 100, (-4.2, 4.2));

enter image description here


Le but de l’utilisation de la normale tronquée est de surmonter la saturation des fonctions du tome comme sigmoïde (où si la valeur est trop grande/petite, le neurone cesse d’apprendre).

65
Salvador Dali

tf.truncated_normal() sélectionne des nombres aléatoires dans une distribution normale dont la moyenne est proche de 0 et les valeurs sont proches de 0. Par exemple, de -0,1 à 0,1. C'est ce qu'on appelle tronqué parce que vous coupez la queue d'une distribution normale.

tf.random_normal() sélectionne des nombres aléatoires dans une distribution normale dont la moyenne est proche de 0, mais les valeurs peuvent être un peu plus éloignées. Par exemple, de -2 à 2.

En apprentissage automatique, en pratique, vous voulez généralement que vos poids soient proches de 0.

23
ksooklall

La documentation de l’API pour tf.truncated_normal () décrit la fonction comme suit:

Renvoie des valeurs aléatoires à partir d'une distribution normale tronquée.

Les valeurs générées suivent une distribution normale avec une moyenne et un écart type spécifiés, à l'exception des valeurs dont l'amplitude est supérieure à 2 écarts types par rapport à la moyenne, sont supprimées et recalculées.

8
Martin Svedin