web-dev-qa-db-fra.com

L'indicateur à étiquettes multiples n'est pas pris en charge pour la matrice de confusion

multilabel-indicator is not supported Est le message d'erreur que j'obtiens lorsque j'essaie d'exécuter:

confusion_matrix(y_test, predictions)

y_test Est un DataFrame de forme:

Horse | Dog | Cat
1       0     0
0       1     0
0       1     0
...     ...   ...

predictions est un numpy array:

[[1, 0, 0],
 [0, 1, 0],
 [0, 1, 0]]

J'ai cherché un peu le message d'erreur, mais je n'ai pas vraiment trouvé quelque chose que je pourrais appliquer. Des indices?

16
Khaine775

Non, votre entrée pour confusion_matrix doit être une liste de prédictions, pas des OHE (un encodage à chaud). Appelez argmax sur votre y_test et y_pred, et vous devriez obtenir ce que vous attendez.

confusion_matrix(
    y_test.values.argmax(axis=1), predictions.argmax(axis=1))

array([[1, 0],
       [0, 2]])
33
cs95

La matrice de confusion prend un vecteur d'étiquettes (pas l'encodage à chaud). Tu devrais courir

confusion_matrix(y_test.values.argmax(axis=1), predictions.argmax(axis=1))
7
Joshua Howard