web-dev-qa-db-fra.com

Calculer la distance par paires dans un lot sans répliquer le tenseur dans Tensorflow?

Je souhaite calculer la distance carrée par paires d'un lot de fonctionnalités dans Tensorflow. J'ai une implémentation simple en utilisant + et * opérations de Mosaïque du tenseur d'origine:

def pairwise_l2_norm2(x, y, scope=None):
    with tf.op_scope([x, y], scope, 'pairwise_l2_norm2'):
        size_x = tf.shape(x)[0]
        size_y = tf.shape(y)[0]
        xx = tf.expand_dims(x, -1)
        xx = tf.tile(xx, tf.pack([1, 1, size_y]))

        yy = tf.expand_dims(y, -1)
        yy = tf.tile(yy, tf.pack([1, 1, size_x]))
        yy = tf.transpose(yy, perm=[2, 1, 0])

        diff = tf.sub(xx, yy)
        square_diff = tf.square(diff)

        square_dist = tf.reduce_sum(square_diff, 1)

        return square_dist

Cette fonction prend en entrée deux matrices de taille (m, d) et (n, d) et calcule la distance au carré entre chaque vecteur de rangée. La sortie est une matrice de taille (m, n) avec l'élément 'd_ij = dist (x_i, y_j)'.

Le problème, c’est que j’ai un lot important et que les fonctions de faible intensité «m, n, d 'répliquer le tenseur consomment beaucoup de mémoire. Je cherche un autre moyen de mettre en œuvre ceci sans augmenter l'utilisation de la mémoire et juste stocker le tenseur de distance final Sorte de double boucle sur le tenseur d'origine.

24
jrabary

Vous pouvez utiliser une algèbre linéaire pour la transformer en opérations matricielles. Notez que vous avez besoin de la matrice Da[i] est la ième ligne de votre matrice d'origine et 

D[i,j] = (a[i]-a[j])(a[i]-a[j])'

Vous pouvez réécrire cela dans 

D[i,j] = r[i] - 2 a[i]a[j]' + r[j]

r[i] est la norme au carré de ième ligne de la matrice d'origine.

Dans un système prenant en charge les règles broadcast standard , vous pouvez traiter r comme un vecteur de colonne et écrire D comme

D = r - 2 A A' + r'

Dans TensorFlow, vous pouvez écrire ceci comme

A = tf.constant([[1, 1], [2, 2], [3, 3]])
r = tf.reduce_sum(A*A, 1)

# turn r into column vector
r = tf.reshape(r, [-1, 1])
D = r - 2*tf.matmul(A, tf.transpose(A)) + tf.transpose(r)
sess = tf.Session()
sess.run(D)

résultat

array([[0, 2, 8],
       [2, 0, 2],
       [8, 2, 0]], dtype=int32)
47

Utiliser squared_difference:

def squared_dist(A): 
    expanded_a = tf.expand_dims(A, 1)
    expanded_b = tf.expand_dims(A, 0)
    distances = tf.reduce_sum(tf.squared_difference(expanded_a, expanded_b), 2)
    return distances

Une chose que j’ai remarquée est que cette solution utilisant tf.squared_difference me donne une mémoire insuffisante (MOO) pour les très gros vecteurs, contrairement à l’approche de @YaroslavBulatov. Donc, je pense que la décomposition de l'opération produit une empreinte mémoire plus petite (que je pensais que squared_difference gèrerait mieux sous le capot).

11
Yamaneko

Voici une solution plus générale pour deux tenseurs de coordonnées A et B:

def squared_dist(A, B):
  assert A.shape.as_list() == B.shape.as_list()

  row_norms_A = tf.reduce_sum(tf.square(A), axis=1)
  row_norms_A = tf.reshape(row_norms_A, [-1, 1])  # Column vector.

  row_norms_B = tf.reduce_sum(tf.square(B), axis=1)
  row_norms_B = tf.reshape(row_norms_B, [1, -1])  # Row vector.

  return row_norms_A - 2 * tf.matmul(A, tf.transpose(B)) + row_norms_B

Notez que ceci est la distance carrée. Si vous voulez changer cela en distance euclidienne, effectuez un tf.sqrt sur le résultat. Si vous voulez faire cela, n'oubliez pas d'ajouter une petite constante pour compenser les instabilités en virgule flottante: dist = tf.sqrt(squared_dist(A, B) + 1e-6).

3
Augustin

Si vous voulez calculer une autre méthode, changez l'ordre des modules tf.

def compute_euclidean_distance(x, y):
    size_x = x.shape.dims[0]
    size_y = y.shape.dims[0]
    for i in range(size_x):
        tile_one = tf.reshape(tf.tile(x[i], [size_y]), [size_y, -1])
        eu_one = tf.expand_dims(tf.sqrt(tf.reduce_sum(tf.pow(tf.subtract(tile_one, y), 2), axis=1)), axis=0)
        if i == 0:
            d = eu_one
        else:
            d = tf.concat([d, eu_one], axis=0)
return d
0
Hyunguk Choi