web-dev-qa-db-fra.com

Remodelage / Pivotement des données dans Spark RDD et / ou Spark DataFrames

J'ai des données au format suivant (RDD ou Spark DataFrame):

from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)

 rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

# convert to a Spark DataFrame                    
schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlContext.createDataFrame(rdd, schema)

Ce que je voudrais faire, c'est "remodeler" les données, convertir certaines lignes du pays (en particulier les États-Unis, le Royaume-Uni et la Californie) en colonnes:

ID    Age  US  UK  CA  
'X01'  41  3   1   2  
'X02'  72  4   6   7   

Essentiellement, j'ai besoin de quelque chose dans le sens du workflow pivot de Python:

categories = ['US', 'UK', 'CA']
new_df = df[df['Country'].isin(categories)].pivot(index = 'ID', 
                                                  columns = 'Country',
                                                  values = 'Score')

Mon ensemble de données est assez volumineux, donc je ne peux pas vraiment collect() et ingérer les données en mémoire pour effectuer le remodelage dans Python lui-même. Existe-t-il un moyen de convertir .pivot() dans une fonction invocable lors du mappage d'un RDD ou d'un Spark DataFrame? Toute aide serait appréciée!

23
Jason

Depuis Spark 1.6 vous pouvez utiliser la fonction pivot sur GroupedData et fournir une expression agrégée.

pivoted = (df
    .groupBy("ID", "Age")
    .pivot(
        "Country",
        ['US', 'UK', 'CA'])  # Optional list of levels
    .sum("Score"))  # alternatively you can use .agg(expr))
pivoted.show()

## +---+---+---+---+---+
## | ID|Age| US| UK| CA|
## +---+---+---+---+---+
## |X01| 41|  3|  1|  2|
## |X02| 72|  4|  6|  7|
## +---+---+---+---+---+

Les niveaux peuvent être omis, mais s'ils sont fournis, ils peuvent à la fois améliorer les performances et servir de filtre interne.

Cette méthode est encore relativement lente mais bat certainement les données de passage manuel manuellement entre JVM et Python.

18
zero323

Tout d'abord, ce n'est probablement pas une bonne idée, car vous n'obtenez aucune information supplémentaire, mais vous vous liez avec un schéma fixe (c'est-à-dire que vous devez savoir combien de pays vous attendez, et bien sûr, un pays supplémentaire signifie changement de code)

Cela dit, il s'agit d'un problème SQL, qui est illustré ci-dessous. Mais au cas où vous supposeriez que ce n'est pas trop "logiciel comme" (sérieusement, j'ai entendu cela !!), alors vous pouvez vous référer à la première solution.

Solution 1:

def reshape(t):
    out = []
    out.append(t[0])
    out.append(t[1])
    for v in brc.value:
        if t[2] == v:
            out.append(t[3])
        else:
            out.append(0)
    return (out[0],out[1]),(out[2],out[3],out[4],out[5])
def cntryFilter(t):
    if t[2] in brc.value:
        return t
    else:
        pass

def addtup(t1,t2):
    j=()
    for k,v in enumerate(t1):
        j=j+(t1[k]+t2[k],)
    return j

def seq(tIntrm,tNext):
    return addtup(tIntrm,tNext)

def comb(tP,tF):
    return addtup(tP,tF)


countries = ['CA', 'UK', 'US', 'XX']
brc = sc.broadcast(countries)
reshaped = calls.filter(cntryFilter).map(reshape)
pivot = reshaped.aggregateByKey((0,0,0,0),seq,comb,1)
for i in pivot.collect():
    print i

Maintenant, Solution 2: Bien sûr, mieux car SQL est le bon outil pour cela

callRow = calls.map(lambda t:   

Row(userid=t[0],age=int(t[1]),country=t[2],nbrCalls=t[3]))
callsDF = ssc.createDataFrame(callRow)
callsDF.printSchema()
callsDF.registerTempTable("calls")
res = ssc.sql("select userid,age,max(ca),max(uk),max(us),max(xx)\
                    from (select userid,age,\
                                  case when country='CA' then nbrCalls else 0 end ca,\
                                  case when country='UK' then nbrCalls else 0 end uk,\
                                  case when country='US' then nbrCalls else 0 end us,\
                                  case when country='XX' then nbrCalls else 0 end xx \
                             from calls) x \
                     group by userid,age")
res.show()

configuration des données:

data=[('X01',41,'US',3),('X01',41,'UK',1),('X01',41,'CA',2),('X02',72,'US',4),('X02',72,'UK',6),('X02',72,'CA',7),('X02',72,'XX',8)]
 calls = sc.parallelize(data,1)
countries = ['CA', 'UK', 'US', 'XX']

Résultat:

De la 1ère solution

(('X02', 72), (7, 6, 4, 8)) 
(('X01', 41), (2, 1, 3, 0))

De la 2ème solution:

root  |-- age: long (nullable = true)  
      |-- country: string (nullable = true)  
      |-- nbrCalls: long (nullable = true)  
      |-- userid: string (nullable = true)

userid age ca uk us xx 
 X02    72  7  6  4  8  
 X01    41  2  1  3  0

Veuillez me faire savoir si cela fonctionne ou non :)

Meilleur Ayan

7
ayan guha

Voici une approche native Spark qui ne câble pas les noms des colonnes. Elle est basée sur aggregateByKey, et utilise un dictionnaire pour collecter les colonnes qui apparaissent pour chaque clé. Ensuite, nous rassemblons tous les noms de colonnes pour créer la trame de données finale. [La version précédente utilisait jsonRDD après avoir émis un dictionnaire pour chaque enregistrement, mais c'est plus efficace.] Restreindre à une liste spécifique de colonnes, ou exclure celles comme XX serait une modification facile.

Les performances semblent bonnes même sur des tables assez grandes. J'utilise une variation qui compte le nombre de fois que chacun d'un nombre variable d'événements se produit pour chaque ID, générant une colonne par type d'événement. Le code est fondamentalement le même sauf qu'il utilise un collections.Counter au lieu d'un dict dans le seqFn pour compter les occurrences.

from pyspark.sql.types import *

rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlCtx.createDataFrame(rdd, schema)

def seqPivot(u, v):
    if not u:
        u = {}
    u[v.Country] = v.Score
    return u

def cmbPivot(u1, u2):
    u1.update(u2)
    return u1

pivot = (
    df
    .rdd
    .keyBy(lambda row: row.ID)
    .aggregateByKey(None, seqPivot, cmbPivot)
)
columns = (
    pivot
    .values()
    .map(lambda u: set(u.keys()))
    .reduce(lambda s,t: s.union(t))
)
result = sqlCtx.createDataFrame(
    pivot
    .map(lambda (k, u): [k] + [u.get(c) for c in columns]),
    schema=StructType(
        [StructField('ID', StringType())] + 
        [StructField(c, IntegerType()) for c in columns]
    )
)
result.show()

Produit:

ID  CA UK US XX  
X02 7  6  4  8   
X01 2  1  3  null
5
patricksurry

Donc, tout d'abord, j'ai dû apporter cette correction à votre RDD (qui correspond à votre sortie réelle):

rdd = sc.parallelize([('X01',41,'US',3),
                      ('X01',41,'UK',1),
                      ('X01',41,'CA',2),
                      ('X02',72,'US',4),
                      ('X02',72,'UK',6),
                      ('X02',72,'CA',7),
                      ('X02',72,'XX',8)])

Une fois que j'ai fait cette correction, cela a fait l'affaire:

df.select($"ID", $"Age").groupBy($"ID").agg($"ID", first($"Age") as "Age")
.join(
    df.select($"ID" as "usID", $"Country" as "C1",$"Score" as "US"),
    $"ID" === $"usID" and $"C1" === "US"
)
.join(
    df.select($"ID" as "ukID", $"Country" as "C2",$"Score" as "UK"),
    $"ID" === $"ukID" and $"C2" === "UK"
)
.join(
    df.select($"ID" as "caID", $"Country" as "C3",$"Score" as "CA"), 
    $"ID" === $"caID" and $"C3" === "CA"
)
.select($"ID",$"Age",$"US",$"UK",$"CA")

Certainement pas aussi élégant que votre pivot.

1
David Griffin

Juste quelques commentaires sur la réponse très utile de patricksurry:

  • la colonne Age est manquante, il suffit donc d'ajouter u ["Age"] = v.Age à la fonction seqPivot
  • il s'est avéré que les deux boucles sur les éléments des colonnes ont donné les éléments dans un ordre différent. Les valeurs des colonnes étaient correctes, mais pas leurs noms. Pour éviter ce comportement, commandez simplement la liste des colonnes.

Voici le code légèrement modifié:

from pyspark.sql.types import *

rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlCtx.createDataFrame(rdd, schema)

# u is a dictionarie
# v is a Row
def seqPivot(u, v):
    if not u:
        u = {}
    u[v.Country] = v.Score
    # In the original posting the Age column was not specified
    u["Age"] = v.Age
    return u

# u1
# u2
def cmbPivot(u1, u2):
    u1.update(u2)
    return u1

pivot = (
    rdd
    .map(lambda row: Row(ID=row[0], Age=row[1], Country=row[2],  Score=row[3]))
    .keyBy(lambda row: row.ID)
    .aggregateByKey(None, seqPivot, cmbPivot)
)

columns = (
    pivot
    .values()
    .map(lambda u: set(u.keys()))
    .reduce(lambda s,t: s.union(t))
)

columns_ord = sorted(columns)

result = sqlCtx.createDataFrame(
    pivot
    .map(lambda (k, u): [k] + [u.get(c, None) for c in columns_ord]),
        schema=StructType(
            [StructField('ID', StringType())] + 
            [StructField(c, IntegerType()) for c in columns_ord]
        )
    )

print result.show()

Enfin, la sortie doit être

+---+---+---+---+---+----+
| ID|Age| CA| UK| US|  XX|
+---+---+---+---+---+----+
|X02| 72|  7|  6|  4|   8|
|X01| 41|  2|  1|  3|null|
+---+---+---+---+---+----+
1
rolpat

Il y a un JIRA dans Hive pour PIVOT pour le faire en natif, sans une énorme instruction CASE pour chaque valeur:

https://issues.Apache.org/jira/browse/Hive-3776

Veuillez voter pour cette JIRA afin qu'elle soit mise en œuvre plus tôt. Une fois qu'il est dans Hive SQL, Spark ne manque généralement pas trop derrière et il sera finalement implémenté dans Spark également).

0
Tagar