web-dev-qa-db-fra.com

Comment utiliser correctement tf.metrics.accuracy?

J'ai quelques difficultés à utiliser la fonction accuracy de tf.metrics pour un problème de classification multiple avec des logits en entrée.

Ma sortie de modèle ressemble à:

logits = [[0.1, 0.5, 0.4],
          [0.8, 0.1, 0.1],
          [0.6, 0.3, 0.2]]

Et mes étiquettes sont des vecteurs encodés à chaud:

labels = [[0, 1, 0],
          [1, 0, 0],
          [0, 0, 1]]

Lorsque j'essaie de faire quelque chose comme tf.metrics.accuracy(labels, logits), cela ne donne jamais le résultat correct. Je fais évidemment quelque chose de mal, mais je ne peux pas comprendre ce que c'est.

22
Thomas Reynaud

TL; DR

La fonction de précision tf.metrics.accuracy calcule la fréquence à laquelle les prédictions correspondent aux étiquettes en fonction de deux variables locales créées: total et count, utilisées pour calculer la fréquence avec laquelle logits correspond à labels

acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1), 
                                  predictions=tf.argmax(logits,1))

print(sess.run([acc, acc_op]))
print(sess.run([acc]))
# Output
#[0.0, 0.66666669]
#[0.66666669]
  • acc (exactitude): renvoie simplement les métriques en utilisant total et count, ne met pas à jour les métriques.
  • acc_op (update up): met à jour les métriques.

Pour comprendre pourquoi acc renvoie 0.0, consultez les détails ci-dessous.


Détails à l'aide d'un exemple simple:

logits = tf.placeholder(tf.int64, [2,3])
labels = tf.Variable([[0, 1, 0], [1, 0, 1]])

acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1),   
                                  predictions=tf.argmax(logits,1))

Initialise les variables:

Puisque metrics.accuracy crée deux variables locales total et count, nous devons appeler local_variables_initializer() pour les initialiser.

sess = tf.Session()

sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())

stream_vars = [i for i in tf.local_variables()]
print(stream_vars)

#[<tf.Variable 'accuracy/total:0' shape=() dtype=float32_ref>,
# <tf.Variable 'accuracy/count:0' shape=() dtype=float32_ref>]

Comprendre les opérations de mise à jour et le calcul de la précision:

print('acc:',sess.run(acc, {logits:[[0,1,0],[1,0,1]]}))
#acc: 0.0

print('[total, count]:',sess.run(stream_vars)) 
#[total, count]: [0.0, 0.0]

Les valeurs ci-dessus renvoient 0,0 pour l'exactitude, car total et count sont des zéros, malgré le fait de donner des entrées correspondantes.

print('ops:', sess.run(acc_op, {logits:[[0,1,0],[1,0,1]]})) 
#ops: 1.0

print('[total, count]:',sess.run(stream_vars)) 
#[total, count]: [2.0, 2.0]

Avec les nouvelles entrées, la précision est calculée lorsque l’opération de mise à jour est appelée. Remarque: étant donné que tous les logits et étiquettes correspondent, nous obtenons une précision de 1,0 et les variables locales total et count donnent en fait total correctly predicted et le total comparisons made.

Nous appelons maintenant accuracy avec les nouvelles entrées (pas les opérations de mise à jour):

print('acc:', sess.run(acc,{logits:[[1,0,0],[0,1,0]]}))
#acc: 1.0

L'appel d'exactitude ne met pas à jour les métriques avec les nouvelles entrées, il renvoie simplement la valeur en utilisant les deux variables locales. Remarque: les logits et les étiquettes ne correspondent pas dans ce cas. Maintenant, appelez à nouveau les opérations de mise à jour:

print('op:',sess.run(acc_op,{logits:[[0,1,0],[0,1,0]]}))
#op: 0.75 
print('[total, count]:',sess.run(stream_vars)) 
#[total, count]: [3.0, 4.0]

Les métriques sont mises à jour pour les nouvelles entrées


Pour plus d'informations sur l'utilisation des métriques lors de la formation et sur leur réinitialisation lors de la validation, reportez-vous à ici .

47
vijay m

Appliqué sur un CNN, vous pouvez écrire:

x_len=24*24
y_len=2

x = tf.placeholder(tf.float32, shape=[None, x_len], name='input')

fc1 = ... # cnn's fully connected layer
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
layer_fc_dropout = tf.nn.dropout(fc1, keep_prob, name='dropout')

y_pred = tf.nn.softmax(fc1, name='output')
logits = tf.argmax(y_pred, axis=1)

y_true = tf.placeholder(tf.float32, shape=[None, y_len], name='y_true')
acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(y_true, axis=1), predictions=tf.argmax(y_pred, 1))


sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

def print_accuracy(x_data, y_data, dropout=1.0):
    accuracy = sess.run(acc_op, feed_dict = {y_true: y_data, x: x_data, keep_prob: dropout})
    print('Accuracy: ', accuracy)
0
Tobias Ernst