web-dev-qa-db-fra.com

Récupérer les n premiers de chaque groupe d'un DataFrame dans pyspark

Il y a un DataFrame dans pyspark avec les données ci-dessous:

user_id object_id score
user_1  object_1  3
user_1  object_1  1
user_1  object_2  2
user_2  object_1  5
user_2  object_2  2
user_2  object_2  6

Ce que j'attends, c'est de renvoyer 2 enregistrements dans chaque groupe avec le même identifiant utilisateur, qui doivent avoir le score le plus élevé. Par conséquent, le résultat devrait ressembler à ceci:

user_id object_id score
user_1  object_1  3
user_1  object_2  2
user_2  object_2  6
user_2  object_1  5

Je suis vraiment nouveau sur pyspark. Quelqu'un pourrait-il me donner un extrait de code ou un portail vers la documentation associée à ce problème? Grand merci!

33
KAs

Je crois que vous devez utiliser fonctions de la fenêtre pour atteindre le rang de chaque ligne en fonction de user_id et score, puis filtrez vos résultats pour ne conserver que les deux premières valeurs.

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

df.select('*', rank().over(window).alias('rank')) 
  .filter(col('rank') <= 2) 
  .show() 
#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1|    3|   1|
#| user_1| object_2|    2|   2|
#| user_2| object_2|    6|   1|
#| user_2| object_1|    5|   2|
#+-------+---------+-----+----+

En général, le guide officiel guide de programmation est un bon point de départ pour apprendre Spark.

Les données

rdd = sc.parallelize([("user_1",  "object_1",  3), 
                      ("user_1",  "object_2",  2), 
                      ("user_2",  "object_1",  5), 
                      ("user_2",  "object_2",  2), 
                      ("user_2",  "object_2",  6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
51
mtoto

Top-n est plus précis si vous utilisez row_number Au lieu de rank lorsque vous obtenez l'égalité de rang:

val n = 5
df.select(col('*'), row_number().over(window).alias('row_number')) \
  .where(col('row_number') <= n) \
  .limit(20) \
  .toPandas()

Remarque limit(20).toPandas() astuce au lieu de show() pour les blocs-notes Jupyter pour une mise en forme plus agréable.

19
Martin Tapp

Je sais que la question est posée pour pyspark et je cherchais une réponse similaire dans Scala i.e.

Récupérer les n premières valeurs dans chaque groupe d'un DataFrame dans Scala

Voici la version scala de la réponse de @ mtoto.

import org.Apache.spark.sql.expressions.Window
import org.Apache.spark.sql.functions.rank
import org.Apache.spark.sql.functions.col

val window = Window.partitionBy("user_id").orderBy('score desc)
val rankByScore = rank().over(window)
df1.select('*, rankByScore as 'rank).filter(col("rank") <= 2).show() 
# you can change the value 2 to any number you want. Here 2 represents the top 2 values

Plus d'exemples peuvent être trouvés ici .

2
Abu Shoeb