web-dev-qa-db-fra.com

Comment extraire des règles d'arbre de décision sklearn pour pandas conditions booléennes?

Il y a tellement de messages comme celui-ci sur la façon d'extraire les règles d'arbre de décision sklearn mais je n'ai trouvé aucun sur l'utilisation des pandas.

Prenez ces données et ce modèle par exemple, comme ci-dessous

# Create Decision Tree classifer object
clf = DecisionTreeClassifier(criterion="entropy", max_depth=3)

# Train Decision Tree Classifer
clf = clf.fit(X_train,y_train)

Le résultat:

enter image description here

Attendu:

Il y a 8 règles concernant cet exemple.

De gauche à droite, notez que la trame de données est df

r1 = (df['glucose']<=127.5) & (df['bmi']<=26.45) & (df['bmi']<=9.1)
……
r8 =  (df['glucose']>127.5) & (df['bmi']>28.15) & (df['glucose']>158.5)

Je ne suis pas un maître de l'extraction de règles d'arbre de décision sklearn. Obtenir les conditions booléennes pandas m'aidera à calculer des échantillons et d'autres mesures pour chaque règle. Je souhaite donc extraire chaque règle à une condition booléenne pandas).

17
Jack

Tout d'abord, utilisons le scikit documentation sur la structure de l'arbre de décision pour obtenir des informations sur l'arbre qui a été construit:

n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
feature = clf.tree_.feature
threshold = clf.tree_.threshold

Nous définissons ensuite deux fonctions récursives. Le premier trouvera le chemin depuis la racine de l'arbre pour créer un nœud spécifique (toutes les feuilles dans notre cas). Le second écrira les règles spécifiques utilisées pour créer un nœud en utilisant son chemin de création:

def find_path(node_numb, path, x):
        path.append(node_numb)
        if node_numb == x:
            return True
        left = False
        right = False
        if (children_left[node_numb] !=-1):
            left = find_path(children_left[node_numb], path, x)
        if (children_right[node_numb] !=-1):
            right = find_path(children_right[node_numb], path, x)
        if left or right :
            return True
        path.remove(node_numb)
        return False


def get_rule(path, column_names):
    mask = ''
    for index, node in enumerate(path):
        #We check if we are not in the leaf
        if index!=len(path)-1:
            # Do we go under or over the threshold ?
            if (children_left[node] == path[index+1]):
                mask += "(df['{}']<= {}) \t ".format(column_names[feature[node]], threshold[node])
            else:
                mask += "(df['{}']> {}) \t ".format(column_names[feature[node]], threshold[node])
    # We insert the & at the right places
    mask = mask.replace("\t", "&", mask.count("\t") - 1)
    mask = mask.replace("\t", "")
    return mask

Enfin, nous utilisons ces deux fonctions pour stocker d'abord le chemin de création de chaque feuille. Et puis pour stocker les règles utilisées pour créer chaque feuille:

# Leaves
leave_id = clf.apply(X_test)

paths ={}
for leaf in np.unique(leave_id):
    path_leaf = []
    find_path(0, path_leaf, leaf)
    paths[leaf] = np.unique(np.sort(path_leaf))

rules = {}
for key in paths:
    rules[key] = get_rule(paths[key], pima.columns)

Avec les données que vous avez fournies, la sortie est:

rules =
{3: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']<= 9.100000381469727)  ",
 4: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']> 9.100000381469727)  ",
 6: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']<= 27.5)  ",
 7: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']> 27.5)  ",
 10: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']<= 145.5)  ",
 11: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']> 145.5)  ",
 13: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']<= 158.5)  ",
 14: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']> 158.5)  "}

Comme les règles sont des chaînes, vous ne pouvez pas les appeler directement en utilisant df[rules[3]], Vous devez utiliser la fonction eval comme ceci df[eval(rules[3])]

17
vlemaistre

Vous pouvez maintenant utiliser export_text.

from sklearn.tree import export_text

r = export_text(loan_tree, feature_names=(list(X_train.columns)))
print(r)

Un exemple complet de sklearn

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text
iris = load_iris()
X = iris['data']
y = iris['target']
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(X, y)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)
0
kevin