web-dev-qa-db-fra.com

Comment fonctionne le paramètre class_weight dans scikit-learn?

J'ai beaucoup de mal à comprendre le fonctionnement du paramètre class_weight dans la régression logistique de scikit-learn.

La situation

Je souhaite utiliser la régression logistique pour effectuer une classification binaire sur un ensemble de données très déséquilibré. Les classes sont étiquetées 0 (négatif) et 1 (positif) et les données observées sont dans un rapport d'environ 19: 1 avec la majorité des échantillons ayant un résultat négatif.

Première tentative: préparation manuelle des données de formation

J'ai divisé les données que j'avais avec des ensembles disjoints pour la formation et les tests (environ 80/20). Ensuite, j'ai échantillonné au hasard les données d'entraînement à la main pour obtenir des données d'entraînement dans des proportions différentes de 19: 1; à partir de 2: 1 -> 16: 1.

J'ai ensuite appris la régression logistique sur ces différents sous-ensembles de données d'apprentissage et le rappel tracé (= TP/(TP + FN)) en fonction des différentes proportions d'entraînement. Bien entendu, le rappel a été calculé sur les échantillons TEST disjoints présentant les proportions observées de 19: 1. Remarque: bien que j'ai entraîné les différents modèles sur différentes données d'apprentissage, j'ai calculé le rappel de chacun d'entre eux sur les mêmes données de test (disjointes).

Les résultats ont été conformes aux attentes: le rappel était d’environ 60% à des proportions d’entraînement de 2: 1 et a chuté assez rapidement au moment où il est passé à 16: 1. Il y avait plusieurs proportions 2: 1 -> 6: 1 où le rappel était décemment supérieur à 5%.

Deuxième tentative: Recherche dans la grille

Ensuite, je voulais tester différents paramètres de régularisation et j'ai donc utilisé GridSearchCV et créé une grille de plusieurs valeurs du paramètre C ainsi que du paramètre class_weight. Pour traduire mes n: m proportions d’échantillons d’entraînement négatif: positif dans le dictionnaire de class_weight, j’ai pensé que je ne devais spécifier que plusieurs dictionnaires comme suit:

{ 0:0.67, 1:0.33 } #expected 2:1
{ 0:0.75, 1:0.25 } #expected 3:1
{ 0:0.8, 1:0.2 }   #expected 4:1

et j'ai également inclus None et auto.

Cette fois, les résultats ont été totalement déformés. Tous mes rappels sont sortis minuscules (<0,05) pour chaque valeur de class_weight sauf auto. Je ne peux donc que supposer que ma compréhension de la définition du dictionnaire class_weight est erronée. Fait intéressant, la valeur class_weight de 'auto' dans la recherche sur la grille était d'environ 59% pour toutes les valeurs de C, et j'ai supposé qu'elle se situe à 1: 1?

Mes questions

1) Comment utilisez-vous correctement class_weight pour obtenir différents équilibres dans les données d'apprentissage par rapport à ce que vous lui donnez réellement? Plus précisément, quel dictionnaire dois-je transmettre à class_weight pour utiliser n: m proportions d’échantillons d’entraînement négatif: positif?

2) Si vous transmettez divers dictionnaires class_weight à GridSearchCV, lors de la validation croisée, les données du pli d’apprentissage seront-elles rééquilibrées, mais les proportions exactes de l’échantillon seront utilisées pour calculer ma fonction de scoring sur le pli de test? Ceci est essentiel car toute métrique ne m'est utile que si elle provient de données dans les proportions observées.

3) Que fait la auto valeur de class_weight en termes de proportions? J'ai lu la documentation et je suppose que "l'équilibre des données est inversement proportionnel à leur fréquence" signifie simplement que le résultat est de 1: 1. Est-ce correct? Si non, quelqu'un peut-il clarifier?

Merci beaucoup, toute clarification serait grandement appréciée!

90
kilgoretrout

Tout d'abord, il ne serait peut-être pas bon de simplement s'en tenir à un rappel. Vous pouvez simplement obtenir un rappel de 100% en classant tout dans la classe des positifs. Je suggère généralement d'utiliser l'AUC pour sélectionner les paramètres, puis de trouver un seuil pour le point de fonctionnement (disons un niveau de précision donné) qui vous intéresse.

Pour savoir comment class_weight fonctionne: Cela pénalise les erreurs dans les échantillons de class[i] avec class_weight[i] au lieu de 1. Un poids de classe plus élevé signifie donc que vous souhaitez mettre davantage l'accent sur une classe. D'après ce que vous dites, il semble que la classe 0 soit 19 fois plus fréquente que la classe 1. Vous devez donc augmenter le class_weight de la classe 1 par rapport à la classe 0, par exemple {0: .1, 1: .9}. Si la somme de class_weight ne correspond pas à 1, le paramètre de régularisation sera fondamentalement modifié.

Pour savoir comment fonctionne class_weight="auto", vous pouvez consulter cette discussion . Dans la version dev, vous pouvez utiliser class_weight="balanced", ce qui est plus facile à comprendre: cela signifie essentiellement de répliquer la classe la plus petite jusqu'à obtenir autant d'échantillons que dans la plus grande, mais de manière implicite.

101
Andreas Mueller