web-dev-qa-db-fra.com

collect_list en préservant l'ordre basé sur une autre variable

J'essaie de créer une nouvelle colonne de listes dans Pyspark en utilisant une agrégation groupby sur un ensemble existant de colonnes. Un exemple de trame de données d'entrée est fourni ci-dessous:

------------------------
id | date        | value
------------------------
1  |2014-01-03   | 10 
1  |2014-01-04   | 5
1  |2014-01-05   | 15
1  |2014-01-06   | 20
2  |2014-02-10   | 100   
2  |2014-03-11   | 500
2  |2014-04-15   | 1500

Le résultat attendu est:

id | value_list
------------------------
1  | [10, 5, 15, 20]
2  | [100, 500, 1500]

Les valeurs d'une liste sont triées par date.

J'ai essayé d'utiliser collect_list comme suit:

from pyspark.sql import functions as F
ordered_df = input_df.orderBy(['id','date'],ascending = True)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))

Mais collect_list ne garantit pas l'ordre même si je trie la trame de données d'entrée par date avant agrégation.

Quelqu'un pourrait-il m'aider à faire l'agrégation en préservant l'ordre basé sur une deuxième variable (date)?

21
Ravi

Si vous collectez à la fois des dates et des valeurs sous forme de liste, vous pouvez trier la colonne résultante en fonction de la date à l'aide de et udf, puis conserver uniquement les valeurs dans le résultat.

import operator
import pyspark.sql.functions as F

# create list column
grouped_df = input_df.groupby("id") \
               .agg(F.collect_list(F.struct("date", "value")) \
               .alias("list_col"))

# define udf
def sorter(l):
  res = sorted(l, key=operator.itemgetter(0))
  return [item[1] for item in res]

sort_udf = F.udf(sorter)

# test
grouped_df.select("id", sort_udf("list_col") \
  .alias("sorted_list")) \
  .show(truncate = False)
+---+----------------+
|id |sorted_list     |
+---+----------------+
|1  |[10, 5, 15, 20] |
|2  |[100, 500, 1500]|
+---+----------------+
30
mtoto
from pyspark.sql import functions as F
from pyspark.sql import Window

w = Window.partitionBy('id').orderBy('date')

sorted_list_df = input_df.withColumn(
            'sorted_list', F.collect_list('value').over(w)
        )\
        .groupBy('id')\
        .agg(F.max('sorted_list').alias('sorted_list'))

Window les exemples fournis par les utilisateurs n'expliquent souvent pas vraiment ce qui se passe alors laissez-moi le disséquer pour vous.

Comme vous le savez, en utilisant collect_list avec groupBy donnera une liste de valeurs non ordonnée. En effet, selon la façon dont vos données sont partitionnées, Spark ajoutera des valeurs à votre liste dès qu'il trouvera une ligne dans le groupe. L'ordre dépend ensuite de la façon dont Spark planifie votre agrégation sur les exécuteurs.

Une fonction Window vous permet de contrôler cette situation, en regroupant les lignes par une certaine valeur afin que vous puissiez effectuer une opération over chacun des groupes résultants:

w = Window.partitionBy('id').orderBy('date')
  • partitionBy - vous voulez des groupes/partitions de lignes avec le même id
  • orderBy - vous voulez que chaque ligne du groupe soit triée par date

Une fois que vous avez défini l'étendue de votre fenêtre - "lignes avec le même id, triées par date" -, vous pouvez l'utiliser pour effectuer une opération sur elle, dans ce cas, un collect_list:

F.collect_list('value').over(w)

À ce stade, vous avez créé une nouvelle colonne sorted_list avec une liste ordonnée de valeurs, triées par date, mais vous avez toujours des lignes dupliquées par id. Pour supprimer les lignes dupliquées que vous souhaitez groupByid et conserver la valeur max pour chaque groupe:

.groupBy('id')\
.agg(F.max('sorted_list').alias('sorted_list'))
43
TMichel

La question était pour PySpark mais pourrait être utile de l'avoir aussi pour Scala Spark.

Préparons le cadre de données de test:

import org.Apache.spark.sql.functions._
import org.Apache.spark.sql.{DataFrame, Row, SparkSession}
import org.Apache.spark.sql.expressions.{ Window, UserDefinedFunction}

import Java.sql.Date
import Java.time.LocalDate

val spark: SparkSession = ...

// Out test data set
val data: Seq[(Int, Date, Int)] = Seq(
  (1, Date.valueOf(LocalDate.parse("2014-01-03")), 10),
  (1, Date.valueOf(LocalDate.parse("2014-01-04")), 5),
  (1, Date.valueOf(LocalDate.parse("2014-01-05")), 15),
  (1, Date.valueOf(LocalDate.parse("2014-01-06")), 20),
  (2, Date.valueOf(LocalDate.parse("2014-02-10")), 100),
  (2, Date.valueOf(LocalDate.parse("2014-02-11")), 500),
  (2, Date.valueOf(LocalDate.parse("2014-02-15")), 1500)
)

// Create dataframe
val df: DataFrame = spark.createDataFrame(data)
  .toDF("id", "date", "value")
df.show()
//+---+----------+-----+
//| id|      date|value|
//+---+----------+-----+
//|  1|2014-01-03|   10|
//|  1|2014-01-04|    5|
//|  1|2014-01-05|   15|
//|  1|2014-01-06|   20|
//|  2|2014-02-10|  100|
//|  2|2014-02-11|  500|
//|  2|2014-02-15| 1500|
//+---+----------+-----+

Utiliser UDF

// Group by id and aggregate date and value to new column date_value
val grouped = df.groupBy(col("id"))
  .agg(collect_list(struct("date", "value")) as "date_value")
grouped.show()
grouped.printSchema()
// +---+--------------------+
// | id|          date_value|
// +---+--------------------+
// |  1|[[2014-01-03,10],...|
// |  2|[[2014-02-10,100]...|
// +---+--------------------+

// udf to extract data from Row, sort by needed column (date) and return value
val sortUdf: UserDefinedFunction = udf((rows: Seq[Row]) => {
  rows.map { case Row(date: Date, value: Int) => (date, value) }
    .sortBy { case (date, value) => date }
    .map { case (date, value) => value }
})

// Select id and value_list
val r1 = grouped.select(col("id"), sortUdf(col("date_value")).alias("value_list"))
r1.show()
// +---+----------------+
// | id|      value_list|
// +---+----------------+
// |  1| [10, 5, 15, 20]|
// |  2|[100, 500, 1500]|
// +---+----------------+

Utiliser la fenêtre

val window = Window.partitionBy(col("id")).orderBy(col("date"))
val sortedDf = df.withColumn("values_sorted_by_date", collect_list("value").over(window))
sortedDf.show()
//+---+----------+-----+---------------------+
//| id|      date|value|values_sorted_by_date|
//+---+----------+-----+---------------------+
//|  1|2014-01-03|   10|                 [10]|
//|  1|2014-01-04|    5|              [10, 5]|
//|  1|2014-01-05|   15|          [10, 5, 15]|
//|  1|2014-01-06|   20|      [10, 5, 15, 20]|
//|  2|2014-02-10|  100|                [100]|
//|  2|2014-02-11|  500|           [100, 500]|
//|  2|2014-02-15| 1500|     [100, 500, 1500]|
//+---+----------+-----+---------------------+

val r2 = sortedDf.groupBy(col("id"))
  .agg(max("values_sorted_by_date").as("value_list")) 
r2.show()
//+---+----------------+
//| id|      value_list|
//+---+----------------+
//|  1| [10, 5, 15, 20]|
//|  2|[100, 500, 1500]|
//+---+----------------+
8
Artavazd Balayan

Pour nous assurer que le tri est effectué pour chaque identifiant, nous pouvons utiliser sortWithinPartitions:

from pyspark.sql import functions as F
ordered_df = (
    input_df
        .repartition(input_df.id)
        .sortWithinPartitions(['date'])


)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))
4
ShadyStego