web-dev-qa-db-fra.com

Spark / Scala a répété des appels à withColumn () en utilisant la même fonction sur plusieurs colonnes

J'ai actuellement du code dans lequel j'applique à plusieurs reprises la même procédure à plusieurs colonnes DataFrame via plusieurs chaînes de .withColumn, et je souhaite créer une fonction pour rationaliser la procédure. Dans mon cas, je trouve des sommes cumulées sur des colonnes agrégées par clés:

val newDF = oldDF
  .withColumn("cumA", sum("A").over(Window.partitionBy("ID").orderBy("time")))
  .withColumn("cumB", sum("B").over(Window.partitionBy("ID").orderBy("time")))
  .withColumn("cumC", sum("C").over(Window.partitionBy("ID").orderBy("time")))
  //.withColumn(...)

Ce que je voudrais, c'est quelque chose comme:

def createCumulativeColums(cols: Array[String], df: DataFrame): DataFrame = {
  // Implement the above cumulative sums, partitioning, and ordering
}

ou mieux encore:

def withColumns(cols: Array[String], df: DataFrame, f: function): DataFrame = {
  // Implement a udf/arbitrary function on all the specified columns
}

Vous pouvez utiliser select avec des varargs incluant *:

import spark.implicits._

df.select($"*" +: Seq("A", "B", "C").map(c => 
  sum(c).over(Window.partitionBy("ID").orderBy("time")).alias(s"cum$c")
): _*)

Cette:

  • Mappe les noms des colonnes aux expressions de fenêtre avec Seq("A", ...).map(...)
  • Ajoute toutes les colonnes préexistantes à $"*" +: ....
  • Décompresse la séquence combinée avec ... : _*.

et peut être généralisé comme:

import org.Apache.spark.sql.{Column, DataFrame}

/**
 * @param cols a sequence of columns to transform
 * @param df an input DataFrame
 * @param f a function to be applied on each col in cols
 */
def withColumns(cols: Seq[String], df: DataFrame, f: String => Column) =
  df.select($"*" +: cols.map(c => f(c)): _*)

Si vous trouvez la syntaxe withColumn plus lisible, vous pouvez utiliser foldLeft:

Seq("A", "B", "C").foldLeft(df)((df, c) =>
  df.withColumn(s"cum$c",  sum(c).over(Window.partitionBy("ID").orderBy("time")))
)

qui peut être généralisé par exemple pour:

/**
 * @param cols a sequence of columns to transform
 * @param df an input DataFrame
 * @param f a function to be applied on each col in cols
 * @param name a function mapping from input to output name.
 */
def withColumns(cols: Seq[String], df: DataFrame, 
    f: String =>  Column, name: String => String = identity) =
  cols.foldLeft(df)((df, c) => df.withColumn(name(c), f(c)))
26
user6910411

La question est un peu ancienne, mais j'ai pensé qu'il serait utile (peut-être pour d'autres) de noter que le repliement de la liste des colonnes utilisant le DataFrame comme accumulateur et le mappage sur le DataFrame ont sensiblement différents résultats de performance lorsque le nombre de colonnes n'est pas trivial (voir ici pour l'explication complète). Pour faire court ... pour quelques colonnes, foldLeft est bien, sinon map est mieux.

5
Lorenzo