web-dev-qa-db-fra.com

Traitement des jeux de données non équilibrés dans Spark MLlib

Je travaille sur un problème particulier de classification binaire avec un jeu de données très déséquilibré, et je me demandais si quelqu'un avait essayé d'implémenter des techniques spécifiques pour traiter les jeux de données non équilibrés (tels que SMOTE ) dans les problèmes de classification à l'aide de MLlib de Spark .

J'utilise l'implémentation Random Forest de MLLib et j'ai déjà essayé l'approche la plus simple consistant à sous-échantillonner de manière aléatoire la classe la plus large, mais cela n'a pas fonctionné aussi bien que prévu.

Je vous serais reconnaissant de tout commentaire concernant votre expérience avec des problèmes similaires.

Merci,

20
dbakr

J'ai utilisé la solution de @Serendipity, mais nous pouvons optimiser la fonction balanceDataset pour éviter d'utiliser un fichier udf. J'ai également ajouté la possibilité de changer la colonne d'étiquette utilisée. Voici la version de la fonction avec laquelle j'ai fini:

def balanceDataset(dataset: DataFrame, label: String = "label"): DataFrame = {
  // Re-balancing (weighting) of records to be used in the logistic loss objective function
  val (datasetSize, positives) = dataset.select(count("*"), sum(dataset(label))).as[(Long, Double)].collect.head
  val balancingRatio = positives / datasetSize

  val weightedDataset = {
    dataset.withColumn("classWeightCol", when(dataset(label) === 0.0, balancingRatio).otherwise(1.0 - balancingRatio))
  }
  weightedDataset
}

Nous créons le classificateur comme il l'a déclaré:

new LogisticRegression().setWeightCol("classWeightCol").setLabelCol("label").setFeaturesCol("features")
1
kanielc

@dbakr Avez-vous obtenu une réponse pour votre prédiction biaisée sur votre jeu de données déséquilibré?

Bien que je ne sois pas sûr que ce soit votre plan initial, notez que si vous sous-échantillonnez d'abord la classe majoritaire de votre jeu de données selon un ratio r, vous pouvez alors obtenir les prédictions non biaisées de la régression logistique de Spark: - utilisez la commande rawPrediction fournie par la fonction transform() et ajustez l'interception avec log(r) - ou vous pouvez entraîner votre régression avec des poids à l'aide de .setWeightCol("classWeightCol") (voir l'article cité ici pour déterminer la valeur qui doit être mis dans les poids).

0
PSAfrance