web-dev-qa-db-fra.com

Comment passer avec élégance les meilleurs paramètres de GridseachCV de Sklearn à un autre modèle?

J'ai trouvé un ensemble de meilleurs hyperparamètres pour mon estimateur KNN avec Grid Search CV:

>>> knn_gridsearch_model.best_params_
{'algorithm': 'auto', 'metric': 'manhattan', 'n_neighbors': 3}

Jusqu'ici tout va bien. Je veux former mon estimateur final avec ces nouveaux paramètres. Existe-t-il un moyen de lui transmettre directement le dict hyperparamètre ci-dessus? J'ai essayé ceci:

>>> new_knn_model = KNeighborsClassifier(knn_gridsearch_model.best_params_)

mais à la place le résultat espéré new_knn_model vient de recevoir le dict entier comme premier paramètre du modèle et a laissé les autres par défaut:

>>> knn_model
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=1,
           n_neighbors={'n_neighbors': 3, 'metric': 'manhattan', 'algorithm': 'auto'},
           p=2, weights='uniform')

Décevant en effet.

14
Hendrik

Vous pouvez le faire comme suit:

new_knn_model = KNeighborsClassifier()
new_knn_model.set_params(**knn_gridsearch_model.best_params_)

Ou décompressez directement comme @taras l'a suggéré:

new_knn_model = KNeighborsClassifier(**knn_gridsearch_model.best_params_)

Soit dit en passant, après avoir terminé l'exécution de la recherche dans la grille, l'objet de recherche dans la grille conserve (par défaut) les meilleurs paramètres, de sorte que vous pouvez utiliser l'objet lui-même. Alternativement, vous pouvez également accéder au classificateur avec les meilleurs paramètres via

gs.best_estimator_
27
Miriam Farber

Je veux juste souligner que l'utilisation du grid.best_parameters et passez-les à un nouveau modèle par unpacking comme:

my_model = KNeighborsClassifier(**grid.best_params_)

est bon et tout et personnellement je l'ai beaucoup utilisé.
Cependant, comme vous pouvez le voir dans la documentation ici , si votre objectif est de prédire quelque chose en utilisant ces meilleurs_paramètres, vous pouvez directement utiliser le grid.predict méthode qui utilisera ces meilleurs paramètres pour vous par défaut.

exemple:

y_pred = grid.predict(X_test)

J'espère que cela vous a été utile.

1
Rayhane Mama