web-dev-qa-db-fra.com

Comment extraire les règles de décision de l'arbre de décision scikit-learn?

Puis-je extraire les règles de décision sous-jacentes (ou "chemins de décision") d'un arbre formé dans un arbre de décision sous forme de liste textuelle?

Quelque chose comme: 

if A>0.4 then if B<0.2 then if C>0.8 then class='X' 

Merci de votre aide.

108
Dror Hilman

Je crois que cette réponse est plus correcte que les autres réponses ici:

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])

    recurse(0, 1)

Cela affiche une fonction Python valide. Voici un exemple de sortie pour une arborescence qui tente de renvoyer son entrée, un nombre compris entre 0 et 10.

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]

Voici quelques points d'achoppement que je vois dans d'autres réponses:

  1. Utiliser tree_.threshold == -2 pour décider si un nœud est une feuille n'est pas une bonne idée. Et si c'est un vrai noeud de décision avec un seuil de -2? Au lieu de cela, vous devriez regarder tree.feature ou tree.children_*.
  2. La ligne features = [feature_names[i] for i in tree_.feature] plante avec ma version de sklearn, car certaines valeurs de tree.tree_.feature sont -2 (spécifiquement pour les nœuds terminaux).
  3. Il n'est pas nécessaire d'avoir plusieurs instructions if si dans la fonction récursive, une seule suffit.
91
paulkernfeld

J'ai créé ma propre fonction pour extraire les règles des arbres de décision créés par sklearn:

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)

Cette fonction commence par les nœuds (identifiés par -1 dans les tableaux enfants), puis recherche les parents de manière récursive. J'appelle cela une "lignée" de nœud. En cours de route, je récupère les valeurs nécessaires pour créer la logique if/then/else SAS:

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]

     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'

          lineage.append((parent, split, threshold[parent], features[parent]))

          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)

     for child in idx:
          for node in recurse(left, right, child):
               print node

Les ensembles de nuplets ci-dessous contiennent tout ce dont j'ai besoin pour créer SAS instructions if/then/else. Je n'aime pas utiliser les blocs do dans SAS, c'est pourquoi je crée une logique décrivant le chemin complet d'un nœud. L'entier unique après les nuplets est l'ID du nœud terminal dans un chemin. Tous les tuples précédents se combinent pour créer ce nœud.

In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6

GraphViz output of example tree

45
Zelazny7

J'ai modifié le code soumis par Zelazny7 pour imprimer un pseudocode:

def get_code(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value

        def recurse(left, right, threshold, features, node):
                if (threshold[node] != -2):
                        print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                        if left[node] != -1:
                                recurse (left, right, threshold, features,left[node])
                        print "} else {"
                        if right[node] != -1:
                                recurse (left, right, threshold, features,right[node])
                        print "}"
                else:
                        print "return " + str(value[node])

        recurse(left, right, threshold, features, 0)

si vous appelez get_code(dt, df.columns) sur le même exemple, vous obtiendrez:

if ( col1 <= 0.5 ) {
return [[ 1.  0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0.  1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1.  0.]]
} else {
return [[ 0.  1.]]
}
}
}
33
Daniele

Il existe une nouvelle méthode DecisionTreeClassifier , decision_path, dans la version 0.18.0 . Les développeurs fournissent une vaste (bien documentée) procédure pas à pas .

La première section de code de la procédure pas à pas qui imprime la structure arborescente semble être OK. Cependant, j'ai modifié le code dans la deuxième section pour interroger un échantillon. Mes changements notés avec # <--

Edit Les modifications marquées par # <-- dans le code ci-dessous ont depuis été mises à jour dans le lien pas à pas après que les erreurs ont été signalées dans les demandes d'extraction # 8653 et # 10951 . C'est beaucoup plus facile à suivre maintenant. 

sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--

    else: # < -- added else to iterate through decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

Rules used to predict sample 0: 
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here

Changez le sample_id pour voir les chemins de décision pour les autres exemples. Je n'ai pas interrogé les développeurs à propos de ces modifications, je me suis simplement montré plus intuitif lorsque j'ai suivi l'exemple.

12
Kevin
from StringIO import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
print out.getvalue()

Vous pouvez voir un arbre de digraphe. Alors, clf.tree_.feature et clf.tree_.value sont des tableaux de noeuds séparant les entités et des valeurs de noeuds Vous pouvez vous référer à plus de détails à partir de cette source github

10
lennon310

Les codes ci-dessous sont mon approche sous anaconda python 2.7 plus un nom de paquetage "pydot-ng" pour créer un fichier PDF avec des règles de décision. J'espère que c'est utile.

from sklearn import tree

clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)

feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)

def output_pdf(clf_, name):
    from sklearn import tree
    from sklearn.externals.six import StringIO
    import pydot_ng as pydot
    dot_data = StringIO()
    tree.export_graphviz(clf_, out_file=dot_data,
                         feature_names=feature_names,
                         class_names=class_name,
                         filled=True, rounded=True,
                         special_characters=True,
                          node_ids=1,)
    graph = pydot.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("%s.pdf"%name)

output_pdf(clf_, name='filename%s'%n)

un graphique d'arborescence montre ici

2
TED Zhao

Juste parce que tout le monde a été si utile, je vais juste ajouter une modification aux belles solutions de Zelazny7 et de Daniele. Celui-ci est destiné à Python 2.7, avec des onglets pour le rendre plus lisible:

def get_code(tree, feature_names, tabdepth=0):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
            if (threshold[node] != -2):
                    print '\t' * tabdepth,
                    print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "} else {"
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "}"
            else:
                    print '\t' * tabdepth,
                    print "return " + str(value[node])

    recurse(left, right, threshold, features, 0)
2
Ruslan

Ceci s'appuie sur la réponse de @paulkernfeld. Si vous avez un cadre de données X avec vos entités et un cadre de données cible y avec vos réponses et que vous souhaitez avoir une idée de la valeur y indiquée dans quel nœud (et également de le tracer en conséquence), vous pouvez procéder comme suit:

    def tree_to_code(tree, feature_names):
        codelines = []
        codelines.append('def get_cat(X_tmp):\n')
        codelines.append('   catout = []\n')
        codelines.append('   for codelines in range(0,X_tmp.shape[0]):\n')
        codelines.append('      Xin = X_tmp.iloc[codelines]\n')
        tree_ = tree.tree_
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        #print "def tree({}):".format(", ".join(feature_names))

        def recurse(node, depth):
            indent = "      " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                codelines.append ('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                codelines.append( '{}else:  # if Xin["{}"] > {}\n'.format(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                codelines.append( '{}mycat = {}\n'.format(indent, node))

        recurse(0, 1)
        codelines.append('      catout.append(mycat)\n')
        codelines.append('   return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n')
        codelines.append('node_ids = get_cat(X)\n')
        return codelines
    mycode = tree_to_code(clf,X.columns.values)

    # now execute the function and obtain the dataframe with all nodes
    exec(''.join(mycode))
    node_ids = [int(x[0]) for x in node_ids.values]
    node_ids2 = pd.DataFrame(node_ids)

    print('make plot')
    import matplotlib.cm as cm
    colors = cm.Rainbow(np.linspace(0, 1, 1+max( list(set(node_ids)))))
    #plt.figure(figsize=cm2inch(24, 21))
    for i in list(set(node_ids)):
        plt.plot(y[node_ids2.values==i],'o',color=colors[i], label=str(i))  
    mytitle = ['y colored by node']
    plt.title(mytitle ,fontsize=14)
    plt.xlabel('my xlabel')
    plt.ylabel(tagname)
    plt.xticks(rotation=70)       
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
    plt.tight_layout()
    plt.show()
    plt.close 

pas la version la plus élégante mais elle fait le travail ... 

2
horseshoe

Scikit learn a introduit une nouvelle méthode délicieuse appelée export_text dans la version 0.21 (mai 2019) pour extraire les règles d'un arbre. Documentation ici . Il n'est plus nécessaire de créer une fonction personnalisée.

Une fois que vous avez adapté votre modèle, il vous suffit de deux lignes de code. Commencez par importer export_text:

from sklearn.tree.export import export_text

Deuxièmement, créez un objet qui contiendra vos règles. Pour rendre les règles plus lisibles, utilisez l'argument feature_names et transmettez une liste de vos noms d'entités. Par exemple, si votre modèle s'appelle model et que vos entités sont nommées dans un cadre de données appelé X_train, vous pouvez créer un objet appelé tree_rules:

tree_rules = export_text(model, feature_names=list(X_train))

Ensuite, imprimez ou enregistrez tree_rules. Votre sortie ressemblera à ceci:

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1
1
yzerman

Je suis passé par là, mais j'avais besoin que les règles soient écrites dans ce format 

if A>0.4 then if B<0.2 then if C>0.8 then class='X' 

J'ai donc adapté la réponse de @paulkernfeld (merci) que vous pouvez personnaliser en fonction de vos besoins.

def tree_to_code(tree, feature_names, Y):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    pathto=dict()

    global k
    k = 0
    def recurse(node, depth, parent):
        global k
        indent = "  " * depth

        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            s= "{} <= {} ".format( name, threshold, node )
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s

            recurse(tree_.children_left[node], depth + 1, node)
            s="{} > {}".format( name, threshold)
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s
            recurse(tree_.children_right[node], depth + 1, node)
        else:
            k=k+1
            print(k,')',pathto[parent], tree_.value[node])
    recurse(0, 1, 0)
1
Ala Ham

Voici une fonction d'impression des règles d'un arbre de décision scikit-learn sous python 3 et avec des décalages pour les blocs conditionnels afin de rendre la structure plus lisible:

def print_decision_tree(tree, feature_names=None, offset_unit='    '):
    '''Plots textual representation of rules of a decision tree
    tree: scikit-learn representation of tree
    feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
    offset_unit: a string of offset of the conditional block'''

    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value
    if feature_names is None:
        features  = ['f%d'%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0):
            offset = offset_unit*depth
            if (threshold[node] != -2):
                    print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node],depth+1)
                    print(offset+"} else {")
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node],depth+1)
                    print(offset+"}")
            else:
                    print(offset+"return " + str(value[node]))

    recurse(left, right, threshold, features, 0,0)
1
Apogentus

Apparemment, il y a longtemps, quelqu'un a déjà décidé d'essayer d'ajouter la fonction suivante aux fonctions d'exportation de l'arbre officiel de scikit (qui ne prend en charge que export_graphviz)

def export_dict(tree, feature_names=None, max_depth=None) :
    """Export a decision tree in dict format.

Voici son engagement complet:

https://github.com/scikit-learn/scikit-learn/blob/79bdc8f711d0af225ed6be9fdb708cea9f98a910/sklearn/tree/export.py

Pas tout à fait sûr de ce qui est arrivé à ce commentaire. Mais vous pouvez aussi essayer d'utiliser cette fonction.

Je pense que cela justifie une demande sérieuse de documentation auprès des personnes compétentes de scikit-learn afin de documenter correctement l'API sklearn.tree.Tree, qui est la structure arborescente sous-jacente que DecisionTreeClassifier expose sous son attribut tree_.

0
Aris Koning

Vous pouvez également le rendre plus informatif en le distinguant à quelle classe il appartient ou même en mentionnant sa valeur de sortie.

def print_decision_tree(tree, feature_names, offset_unit='    '):    
left      = tree.tree_.children_left
right     = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
    features  = ['f%d'%i for i in tree.tree_.feature]
else:
    features  = [feature_names[i] for i in tree.tree_.feature]        

def recurse(left, right, threshold, features, node, depth=0):
        offset = offset_unit*depth
        if (threshold[node] != -2):
                print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                if left[node] != -1:
                        recurse (left, right, threshold, features,left[node],depth+1)
                print(offset+"} else {")
                if right[node] != -1:
                        recurse (left, right, threshold, features,right[node],depth+1)
                print(offset+"}")
        else:
                #print(offset,value[node]) 

                #To remove values from node
                temp=str(value[node])
                mid=len(temp)//2
                tempx=[]
                tempy=[]
                cnt=0
                for i in temp:
                    if cnt<=mid:
                        tempx.append(i)
                        cnt+=1
                    else:
                        tempy.append(i)
                        cnt+=1
                val_yes=[]
                val_no=[]
                res=[]
                for j in tempx:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_no.append(j)
                for j in tempy:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_yes.append(j)
                val_yes = int("".join(map(str, val_yes)))
                val_no = int("".join(map(str, val_no)))

                if val_yes>val_no:
                    print(offset,'\033[1m',"YES")
                    print('\033[0m')
                Elif val_no>val_yes:
                    print(offset,'\033[1m',"NO")
                    print('\033[0m')
                else:
                    print(offset,'\033[1m',"Tie")
                    print('\033[0m')

recurse(left, right, threshold, features, 0,0)

enter image description here

0
Amit Rautray

Code modifié de Zelazny7 pour extraire le code SQL de l’arbre de décision.

# SQL from decision tree

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]
     le='<='               
     g ='>'
     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'
          lineage.append((parent, split, threshold[parent], features[parent]))
          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)
     print 'case '
     for j,child in enumerate(idx):
        clause=' when '
        for node in recurse(left, right, child):
            if len(str(node))<3:
                continue
            i=node
            if i[1]=='l':  sign=le 
            else: sign=g
            clause=clause+i[3]+sign+str(i[2])+' and '
        clause=clause[:-4]+' then '+str(j)
        print clause
     print 'else 99 end as clusters'
0
Arslán

Voici un moyen de traduire tout l'arbre en une seule expression python (pas nécessairement trop lisible par l'homme) à l'aide de la bibliothèque SKompiler :

from skompiler import skompile
skompile(dtree.predict).to('python/code')
0
KT.

Utilisez juste la fonction de sklearn.tree comme ceci

from sklearn.tree import export_graphviz
    export_graphviz(tree,
                out_file = "tree.dot",
                feature_names = tree.columns) //or just ["petal length", "petal width"]

Ensuite, recherchez dans le dossier de votre projet le fichier tree.dot, copiez-le TOUT le contenu et collez-le ici http://www.webgraphviz.com/ et générez votre graphique :)

0
chainstair