web-dev-qa-db-fra.com

Comment sélectionner la première ligne de chaque groupe?

J'ai un DataFrame généré comme suit:

df.groupBy($"Hour", $"Category")
  .agg(sum($"value") as "TotalValue")
  .sort($"Hour".asc, $"TotalValue".desc))

Les résultats ressemblent à:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   0|   cat13|      22.1|
|   0|   cat95|      19.6|
|   0|  cat105|       1.3|
|   1|   cat67|      28.5|
|   1|    cat4|      26.8|
|   1|   cat13|      12.6|
|   1|   cat23|       5.3|
|   2|   cat56|      39.6|
|   2|   cat40|      29.7|
|   2|  cat187|      27.9|
|   2|   cat68|       9.8|
|   3|    cat8|      35.6|
| ...|    ....|      ....|
+----+--------+----------+

Comme vous pouvez le constater, le DataFrame est classé par Hour dans un ordre croissant, puis par TotalValue par ordre décroissant.

J'aimerais sélectionner la rangée du haut de chaque groupe, c'est-à-dire.

  • dans le groupe des heures == 0 sélectionner (0, cat26,30.9)
  • dans le groupe des heures == 1 sélectionnez (1, cat67,28.5)
  • dans le groupe des heures == 2 sélectionner (2, cat56,39.6)
  • etc

Donc, le résultat souhaité serait:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
| ...|     ...|       ...|
+----+--------+----------+

Il peut être utile de pouvoir également sélectionner les N premières lignes de chaque groupe.

Toute aide est grandement appréciée.

101
Rami

Fonctions de la fenêtre:

Quelque chose comme ça devrait faire l'affaire:

import org.Apache.spark.sql.functions.{row_number, max, broadcast}
import org.Apache.spark.sql.expressions.Window

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)

val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Cette méthode sera inefficace en cas de biais important des données.

Agrégation SQL simple suivie de join:

Sinon, vous pouvez rejoindre avec un bloc de données agrégé:

val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))

val dfTopByJoin = df.join(broadcast(dfMax),
    ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
  .drop("max_hour")
  .drop("max_value")

dfTopByJoin.show

// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Il conservera les valeurs en double (s'il y a plus d'une catégorie par heure avec la même valeur totale). Vous pouvez les supprimer comme suit:

dfTopByJoin
  .groupBy($"hour")
  .agg(
    first("category").alias("category"),
    first("TotalValue").alias("TotalValue"))

Utilisation de la commande sur structs:

Astuce soignée, mais pas très bien testée, ne nécessitant ni jointure ni fonctions de fenêtre:

val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
  .groupBy($"hour")
  .agg(max("vs").alias("vs"))
  .select($"Hour", $"vs.Category", $"vs.TotalValue")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Avec API DataSet (Spark 1.6+, 2.0+):

Spark 1.6:

case class Record(Hour: Integer, Category: String, TotalValue: Double)

df.as[Record]
  .groupBy($"hour")
  .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .show

// +---+--------------+
// | _1|            _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+

Spark 2.0 ou version ultérieure:

df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)

Les deux dernières méthodes peuvent tirer parti de la combinaison côté carte et ne nécessitent pas de lecture aléatoire. La plupart du temps, elles devraient donc offrir de meilleures performances que les fonctions de fenêtre et les jointures. Celles-ci peuvent également être utilisées avec le streaming structuré en mode de sortie completed.

N'utilisez pas:

df.orderBy(...).groupBy(...).agg(first(...), ...)

Cela peut sembler fonctionner (surtout dans le mode local) mais il n’est pas fiable ( SPARK-16207 ). Crédits à Tzach Zohar pour reliant le problème pertinent de JIRA .

La même note s'applique à 

df.orderBy(...).dropDuplicates(...)

qui utilise en interne un plan d'exécution équivalent.

169
zero323

Pour Spark 2.0.2 avec regroupement de plusieurs colonnes:

import org.Apache.spark.sql.functions.row_number
import org.Apache.spark.sql.expressions.Window

val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)

val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
11
Antonín Hoskovec

C'est exactement la même chose que zero323 's answer mais en mode requête SQL.

En supposant que le cadre de données soit créé et enregistré en tant que 

df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0   |cat26   |30.9      |
//|0   |cat13   |22.1      |
//|0   |cat95   |19.6      |
//|0   |cat105  |1.3       |
//|1   |cat67   |28.5      |
//|1   |cat4    |26.8      |
//|1   |cat13   |12.6      |
//|1   |cat23   |5.3       |
//|2   |cat56   |39.6      |
//|2   |cat40   |29.7      |
//|2   |cat187  |27.9      |
//|2   |cat68   |9.8       |
//|3   |cat8    |35.6      |
//+----+--------+----------+

Fonction de la fenêtre:

sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn  FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Agrégation SQL simple suivie d'une jointure:

sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
  "(select Hour, Category, TotalValue from table tmp1 " +
  "join " +
  "(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
  "on " +
  "tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
  "group by tmp3.Hour")
  .show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Utilisation de la commande sur les structures:

sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Méthode des ensembles de données et ne pas faire s sont identiques à ceux de la réponse d'origine

7
Ramesh Maharjan

La solution ci-dessous ne fait qu'un groupe et extrait les lignes de votre cadre de données contenant la valeur maxValue en une seule fois. Pas besoin d'autres jointures, ni de Windows.

import org.Apache.spark.sql.Row
import org.Apache.spark.sql.catalyst.encoders.RowEncoder
import org.Apache.spark.sql.DataFrame

//df is the dataframe with Day, Category, TotalValue

implicit val dfEnc = RowEncoder(df.schema)

val res: DataFrame = df.groupByKey{(r) => r.getInt(0)}.mapGroups[Row]{(day: Int, rows: Iterator[Row]) => i.maxBy{(r) => r.getDouble(2)}}
1
elghoto

Ici tu peux faire comme ça -

   val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour")

data.withColumnRenamed("_1","Hour").show
1
Shubham Agrawal

Si le dataframe doit être groupé en plusieurs colonnes, cela peut aider

val keys = List("Hour", "Category");
 val selectFirstValueOfNoneGroupedColumns = 
 df.columns
   .filterNot(keys.toSet)
   .map(_ -> "first").toMap
 val grouped = 
 df.groupBy(keys.head, keys.tail: _*)
   .agg(selectFirstValueOfNoneGroupedColumns)

J'espère que cela aide quelqu'un avec un problème similaire

0
NehaM

Le modèle est groupe par clés => faire quelque chose à chaque groupe, par exemple. réduire => retourner au dataframe

Je pensais que l'abstraction Dataframe était un peu lourde dans ce cas, alors j'ai utilisé la fonctionnalité RDD

 val rdd: RDD[Row] = originalDf
  .rdd
  .groupBy(row => row.getAs[String]("grouping_row"))
  .map(iterableTuple => {
    iterableTuple._2.reduce(reduceFunction)
  })

val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)
0
Rubber Duck

Une bonne façon de faire cela avec l’API Dataframe utilise la logique Argmax comme si

  val df = Seq(
    (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
    (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
    (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
    (3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")

  df.groupBy($"Hour")
    .agg(max(struct($"TotalValue", $"Category")).as("argmax"))
    .select($"Hour", $"argmax.*").show

 +----+----------+--------+
 |Hour|TotalValue|Category|
 +----+----------+--------+
 |   1|      28.5|   cat67|
 |   3|      35.6|    cat8|
 |   2|      39.6|   cat56|
 |   0|      30.9|   cat26|
 +----+----------+--------+
0
randal25