web-dev-qa-db-fra.com

Forêt aléatoire avec GridSearchCV - Erreur sur param_grid

Je tente de créer un modèle de forêt aléatoire avec GridSearchCV mais j'obtiens une erreur concernant param_grid: "ValueError: paramètre non valide max_features pour l'estimateur Pipeline. Vérifiez la liste des paramètres disponibles avec` estimator.get_params (). Keys () ". Je classe des documents, donc je pousse également le vectoriseur tf-idf dans le pipeline. Voici le code:

from sklearn import metrics
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, f1_score, accuracy_score, precision_score, confusion_matrix
from sklearn.pipeline import Pipeline

 #Classifier Pipeline
pipeline = Pipeline([
    ('tfidf', TfidfVectorizer()),
    ('classifier', RandomForestClassifier())
])
# Params for classifier
params = {"max_depth": [3, None],
              "max_features": [1, 3, 10],
              "min_samples_split": [1, 3, 10],
              "min_samples_leaf": [1, 3, 10],
              # "bootstrap": [True, False],
              "criterion": ["gini", "entropy"]}

# Grid Search Execute
rf_grid = GridSearchCV(estimator=pipeline , param_grid=params) #cv=10
rf_detector = rf_grid.fit(X_train, Y_train)
print(rf_grid.grid_scores_)

Je ne peux pas comprendre pourquoi l'erreur s'affiche. Le même problème se produit lorsque j'exécute un arbre de décision avec GridSearchCV. (Scikit-learn 0,17)

17
OAK

Vous devez affecter les paramètres à l'étape nommée dans le pipeline. Dans votre cas classifier. Essayez d'ajouter au début classifier__ au nom du paramètre. exemple de pipeline

params = {"classifier__max_depth": [3, None],
              "classifier__max_features": [1, 3, 10],
              "classifier__min_samples_split": [1, 3, 10],
              "classifier__min_samples_leaf": [1, 3, 10],
              # "bootstrap": [True, False],
              "classifier__criterion": ["gini", "entropy"]}
23
Kevin

Essayez d'exécuter get_params() sur votre dernier objet pipeline , pas seulement l'estimateur. De cette façon, il générerait tous les éléments de tuyau disponibles clés uniques pour les paramètres de la grille.

sorted(pipeline.get_params().keys())

['classifier', 'classifier__bootstrap', 'classifier__class_weight', 'classifier__criterion', 'classifier__max_depth', 'classifier__max_features', 'classifier__max_leaf_nodes', 'classifier__min_impurity_split_s', 'classifier__min_impurity_split_s', ' classifier__min_samples_split ',' classifier__min_weight_fraction_leaf ',' classifier__n_estimators ',' classifier__n_jobs ',' classifier__oob_score ',' classifier__random_state ',' classifier__verbose ',' classifier__warm_startf, tf ',' tf ',' tf , 'tfidf__dtype', 'tfidf__encoding', 'tfidf__input', 'tfidf__lowercase', 'tfidf__max_df', 'tfidf__max_features', 'tfidf__min_df', 'tfidf__ngram_rf_f_f_f_f,' tff tfidf__strip_accents ',' tfidf__sublinear_tf ',' tfidf__token_pattern ',' tfidf__tokenizer ',' tfidf__use_idf ',' tfidf__vocabulary ']

Ceci est particulièrement utile lorsque vous utilisez la syntaxe courte make_pipeline() pour Piplines, où vous ne vous souciez pas des étiquettes pour les éléments de tuyau:

pipeline = make_pipeline(TfidfVectorizer(), RandomForestClassifier())
sorted(pipeline.get_params().keys())

[ 'Randomforestclassifier', 'randomforestclassifier__bootstrap', 'randomforestclassifier__class_weight', 'randomforestclassifier__criterion', 'randomforestclassifier__max_depth', 'randomforestclassifier__max_features', 'randomforestclassifier__max_leaf_nodes', 'randomforestclassifier__min_impurity_split', 'randomforestclassifier__min_samples_leaf', ' randomforestclassifier__min_samples_split ", 'randomforestclassifier__min_weight_fraction_leaf', 'randomforestclassifier__n_estimators', 'randomforestclassifier__n_jobs', 'randomforestclassifier__oob_score', 'randomforestclassifier__random_state', 'randomforestclassifier__verbose', 'randomforestclassifier__warm_start', 'étapes', 'tfidfvectorizer', 'tfidfvectorizer__analyzer', 'tfidfvectorizer__binary', 'tfidfvectorizer__decode_error' , 'tfidfvectorizer__dtype', 'tfidfvectorizer__encoding', 'tfidfvectorizer__input', 'tfidfvectorizer__lowercase', 'tfidfvectorizer__max_df', 'tfidfvectorizer__max_features', 'tfidfvectorizer__min fidfvectorizer__ngram_range ", 'tfidfvectorizer__norm', 'tfidfvectorizer__preprocessor', 'tfidfvectorizer__smooth_idf', '', 'tfidfvectorizer__stop_words tfidfvectorizer__strip_accents', 'tfidfvectorizer__sublinear_tf', '', 'tfidfvectorizer__token_pattern tfidfvectorizer__tokenizer', 'tfidfvectorizer__use_idf', 'tfidfvectorizer__vocabulary']

7
mork