web-dev-qa-db-fra.com

Scikit Learn GridSearchCV sans validation croisée (apprentissage non supervisé)

Est-il possible d'utiliser GridSearchCV sans validation croisée? J'essaie d'optimiser le nombre de clusters dans le clustering KMeans via la recherche de grille, et donc je n'ai pas besoin ou je ne veux pas de validation croisée.

Le documentation me déroute également car sous la méthode fit (), il a une option pour l'apprentissage non supervisé (dit d'utiliser None pour l'apprentissage non supervisé). Mais si vous voulez faire un apprentissage non supervisé, vous devez le faire sans validation croisée et il ne semble pas y avoir d'option pour se débarrasser de la validation croisée.

17
DataMan

Après beaucoup de recherches, j'ai pu trouver ce fil . Il semble que vous pouvez vous débarrasser de la validation croisée dans GridSearchCV si vous utilisez:

cv=[(slice(None), slice(None))]

J'ai testé cela par rapport à ma propre version codée de la recherche de grille sans validation croisée et j'obtiens les mêmes résultats des deux méthodes. Je poste cette réponse à ma propre question au cas où d'autres auraient le même problème.

Edit: pour répondre à la question de jjrr dans les commentaires, voici un exemple d'utilisation:

from sklearn.metrics import silhouette_score as sc

def cv_silhouette_scorer(estimator, X):
    estimator.fit(X)
    cluster_labels = estimator.labels_
    num_labels = len(set(cluster_labels))
    num_samples = len(X.index)
    if num_labels == 1 or num_labels == num_samples:
        return -1
    else:
        return sc(X, cluster_labels)

cv = [(slice(None), slice(None))]
gs = GridSearchCV(estimator=sklearn.cluster.MeanShift(), param_grid=param_dict, 
                  scoring=cv_silhouette_scorer, cv=cv, n_jobs=-1)
gs.fit(df[cols_of_interest])
16
DataMan

Je vais répondre à votre question car il semble qu'elle soit toujours sans réponse. En utilisant la méthode de parallélisme avec la boucle for, vous pouvez utiliser le module multiprocessing.

from multiprocessing.dummy import Pool
from sklearn.cluster import KMeans
import functools

kmeans = KMeans()

# define your custom function for passing into each thread
def find_cluster(n_clusters, kmeans, X):
    from sklearn.metrics import silhouette_score  # you want to import in the scorer in your function

    kmeans.set_params(n_clusters=n_clusters)  # set n_cluster
    labels = kmeans.fit_predict(X)  # fit & predict
    score = silhouette_score(X, labels)  # get the score

    return score

# Now's the parallel implementation
clusters = [3, 4, 5]
pool = Pool()
results = pool.map(functools.partial(find_cluster, kmeans=kmeans, X=X), clusters)
pool.close()
pool.join()

# print the results
print(results)  # will print a list of scores that corresponds to the clusters list
6
Scratch'N'Purr

Je suis récemment sorti avec le validateur croisé personnalisé suivant, basé sur cette réponse . Je l'ai transmis à GridSearchCV et il a correctement désactivé la validation croisée pour moi:

import numpy as np

class DisabledCV:
    def __init__(self):
        self.n_splits = 1

    def split(self, X, y, groups=None):
        yield (np.arange(len(X)), np.arange(len(y)))

    def get_n_splits(self, X, y, groups=None):
        return self.n_splits

J'espère que ça peut aider.

5
MrD

Je pense que l'utilisation de cv = ShuffleSplit (test_size = 0.20, n_splits = 1) avec n_splits = 1 est une meilleure solution comme celle-ci post suggéré

4
ihebiheb