web-dev-qa-db-fra.com

Comment croiser la validation du modèle RandomForest?

Je veux évaluer une forêt aléatoire en cours de formation sur certaines données. Existe-t-il un utilitaire dans Apache Spark pour faire de même ou dois-je effectuer une validation croisée manuellement?

21
ashishsjsu

ML fournit une classe CrossValidator qui peut être utilisée pour effectuer une validation croisée et une recherche de paramètres. En supposant que vos données sont déjà prétraitées, vous pouvez ajouter la validation croisée comme suit:

import org.Apache.spark.ml.Pipeline
import org.Apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.Apache.spark.ml.classification.RandomForestClassifier
import org.Apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// [label: double, features: vector]
trainingData org.Apache.spark.sql.DataFrame = ??? 
val nFolds: Int = ???
val numTrees: Int = ???
val metric: String = ???

val rf = new RandomForestClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setNumTrees(numTrees)

val pipeline = new Pipeline().setStages(Array(rf)) 

val paramGrid = new ParamGridBuilder().build() // No parameter search

val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  // "f1" (default), "weightedPrecision", "weightedRecall", "accuracy"
  .setMetricName(metric) 

val cv = new CrossValidator()
  // ml.Pipeline with ml.classification.RandomForestClassifier
  .setEstimator(pipeline)
  // ml.evaluation.MulticlassClassificationEvaluator
  .setEvaluator(evaluator) 
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(nFolds)

val model = cv.fit(trainingData) // trainingData: DataFrame

Utilisation de PySpark:

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

trainingData = ... # DataFrame[label: double, features: vector]
numFolds = ... # Integer

rf = RandomForestClassifier(labelCol="label", featuresCol="features")
evaluator = MulticlassClassificationEvaluator() # + other params as in Scala    

pipeline = Pipeline(stages=[rf])
paramGrid = (ParamGridBuilder. 
    .addGrid(rf.numTrees, [3, 10])
    .addGrid(...)  # Add other parameters
    .build())

crossval = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=numFolds)

model = crossval.fit(trainingData)
36
zero323

Pour tirer parti de l'excellente réponse de zero323 à l'aide de Random Forest Classifier, voici un exemple similaire pour Random Forest Regressor:

import org.Apache.spark.ml.Pipeline
import org.Apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.Apache.spark.ml.regression.RandomForestRegressor // CHANGED
import org.Apache.spark.ml.evaluation.RegressionEvaluator // CHANGED
import org.Apache.spark.ml.feature.{VectorAssembler, VectorIndexer}

val numFolds = ??? // Integer
val data = ??? // DataFrame

// Training (80%) and test data (20%)
val Array(train, test) = data.randomSplit(Array(0.8,0.2))
val featuresCols = data.columns
val va = new VectorAssembler()
va.setInputCols(featuresCols)
va.setOutputCol("rawFeatures")
val vi = new VectorIndexer()
vi.setInputCol("rawFeatures")
vi.setOutputCol("features")
vi.setMaxCategories(5)
val regressor = new RandomForestRegressor()
regressor.setLabelCol("events")

val metric = "rmse"
val evaluator = new RegressionEvaluator()
  .setLabelCol("events")
  .setPredictionCol("prediction")
  //     "rmse" (default): root mean squared error
  //     "mse": mean squared error
  //     "r2": R2 metric
  //     "mae": mean absolute error 
  .setMetricName(metric) 

val paramGrid = new ParamGridBuilder().build()
val cv = new CrossValidator()
  .setEstimator(regressor)
  .setEvaluator(evaluator) 
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(numFolds)

val model = cv.fit(train) // train: DataFrame
val predictions = model.transform(test)
predictions.show
val rmse = evaluator.evaluate(predictions)
println(rmse)

Source de mesure de l'évaluateur: https://spark.Apache.org/docs/latest/api/scala/#org.Apache.spark.ml.evaluation.RegressionEvaluator

2
Garren S