web-dev-qa-db-fra.com

Quelle est la différence entre xgb.train et xgb.XGBRegressor (ou xgb.XGBClassifier)?

Je sais déjà "xgboost.XGBRegressor est une interface Scikit-Learn Wrapper pour XGBoost. "

Mais ont-ils une autre différence?

15
Statham

xgboost.train est l'API de bas niveau pour former le modèle via la méthode de renforcement du gradient.

xgboost.XGBRegressor et xgboost.XGBClassifier sont les wrappers ( Scikit-Learn-like wrappers, comme ils l'appellent) qui préparent le DMatrix et transmettent la fonction objective et les paramètres correspondants. Au final, l'appel fit se résume simplement à:

self._Booster = train(params, dmatrix,
                      self.n_estimators, evals=evals,
                      early_stopping_rounds=early_stopping_rounds,
                      evals_result=evals_result, obj=obj, feval=feval,
                      verbose_eval=verbose)

Cela signifie que tout qui peut être fait avec XGBRegressor et XGBClassifier est réalisable via le sous-jacent xgboost.train fonction. Inversement, ce n'est évidemment pas vrai, par exemple, certains paramètres utiles de xgboost.train ne sont pas pris en charge dans l'API XGBModel. La liste des différences notables comprend:

  • xgboost.train permet de définir le callbacks appliqué à la fin de chaque itération.
  • xgboost.train permet la poursuite de la formation via xgb_model paramètre.
  • xgboost.train permet non seulement la minisation de la fonction eval, mais aussi la maximisation.
24
Maxim

@Maxim, à partir de xgboost 0.90 (ou bien avant), ces différences n'existent plus en ce que xgboost.XGBClassifier.fit :

  • a callbacks
  • permet la contiunation avec le paramètre xgb_model
  • et prend en charge les mêmes mesures d'évaluation intégrées ou fonctions d'évaluation personnalisées

Ce que je trouve différent, c'est evals_result, En ce sens qu'il doit être récupéré séparément après ajustement (clf.evals_result()) et le résultat dict est différent car il ne peut pas en tirer parti du nom des valeurs dans la liste de surveillance (watchlist = [(d_train, 'train'), (d_valid, 'valid')]).

1
paulperry