web-dev-qa-db-fra.com

Comment puis-je trouver les attributs sur lesquels mon arbre se divise lorsque j'utilise scikit-learn?

J'ai exploré scikit-learn, pris des arbres de décision avec des critères de fractionnement d'entropie et de gini, et exploré les différences.

Ma question est de savoir comment "ouvrir le capot" et savoir exactement sur quels attributs les arbres se divisent à chaque niveau, ainsi que leurs valeurs d'informations associées, afin que je puisse voir où les deux critères font des choix différents?

Jusqu'à présent, j'ai exploré les 9 méthodes décrites dans la documentation. Ils ne semblent pas autoriser l'accès à ces informations. Mais cette information est sûrement accessible? J'imagine une liste ou un dict qui a des entrées pour le nœud et le gain.

Merci pour votre aide et mes excuses si j'ai raté quelque chose de complètement évident.

39
Matt O'Brien

Directement à partir de la documentation ( http://scikit-learn.org/0.12/modules/tree.html ):

from io import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)

Le module StringIO n'est plus pris en charge dans Python3, au lieu d'importer le module io.

Il y a aussi le tree_ attribut dans votre objet d'arbre de décision, qui permet l'accès direct à toute la structure.

Et vous pouvez simplement le lire

clf.tree_.children_left #array of left children
clf.tree_.children_right #array of right children
clf.tree_.feature #array of nodes splitting feature
clf.tree_.threshold #array of nodes splitting points
clf.tree_.value #array of nodes values

pour plus de détails, regardez le code source de la méthode d'exportation

En général, vous pouvez utiliser le module inspect

from inspect import getmembers
print( getmembers( clf.tree_ ) )

pour obtenir tous les éléments de l'objet

Decision tree visualization from sklearn docs

32
lejlot

Si vous voulez simplement voir rapidement ce qui se passe dans l'arborescence, essayez:

Zip(X.columns[clf.tree_.feature], clf.tree_.threshold, clf.tree_.children_left, clf.tree_.children_right)

où X est le bloc de données de variables indépendantes et clf est l'objet d'arbre de décision. Remarquerez que clf.tree_.children_left et clf.tree_.children_right contiennent ensemble l'ordre dans lequel les divisions ont été effectuées (chacune correspondrait à une flèche dans la visualisation graphique).

11
Daniel Gibson

Scikit learn a introduit une délicieuse nouvelle méthode appelée export_text dans la version 0.21 (mai 2019) pour afficher toutes les règles d'un arbre. Documentation ici .

Une fois que vous avez adapté votre modèle, il vous suffit de deux lignes de code. Tout d'abord, importez export_text:

from sklearn.tree.export import export_text

Deuxièmement, créez un objet qui contiendra vos règles. Pour rendre les règles plus lisibles, utilisez le feature_names argument et passez une liste de vos noms de fonction. Par exemple, si votre modèle est appelé model et vos fonctions sont nommées dans un cadre de données appelé X_train, vous pouvez créer un objet appelé tree_rules:

tree_rules = export_text(model, feature_names=list(X_train))

Imprimez ou enregistrez simplement tree_rules. Votre sortie ressemblera à ceci:

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1
0
yzerman