web-dev-qa-db-fra.com

Groupe DataFrame / Dataset Par comportement / optimisation

Supposons que nous ayons DataFrame df composé des colonnes suivantes:

Nom, prénom, taille, largeur, longueur, poids

Maintenant, nous voulons effectuer quelques opérations, par exemple, nous voulons créer quelques DataFrames contenant des données sur la taille et la largeur.

val df1 = df.groupBy("surname").agg( sum("size") )
val df2 = df.groupBy("surname").agg( sum("width") )

comme vous pouvez le constater, d'autres colonnes, comme la longueur, ne sont utilisées nulle part. Spark est-il assez intelligent pour supprimer les colonnes redondantes avant la phase de brassage ou sont-elles transportées? Wil s'exécute:

val dfBasic = df.select("surname", "size", "width")

avant de regrouper en quelque sorte affecter les performances?

30
Niemand

Oui, c'est " assez intelligent ". groupBy effectuée sur un DataFrame n'est pas la même opération que groupBy effectuée sur un RDD ordinaire. Dans un scénario que vous avez décrit, il n'est pas du tout nécessaire de déplacer des données brutes. Créons un petit exemple pour illustrer cela:

val df = sc.parallelize(Seq(
   ("a", "foo", 1), ("a", "foo", 3), ("b", "bar", 5), ("b", "bar", 1)
)).toDF("x", "y", "z")

df.groupBy("x").agg(sum($"z")).explain

// == Physical Plan ==
// *HashAggregate(keys=[x#148], functions=[sum(cast(z#150 as bigint))])
// +- Exchange hashpartitioning(x#148, 200)
//    +- *HashAggregate(keys=[x#148], functions=[partial_sum(cast(z#150 as bigint))])
//       +- *Project [_1#144 AS x#148, _3#146 AS z#150]
//          +- *SerializeFromObject [staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._1, true, false) AS _1#144, staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._2, true, false) AS _2#145, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#146]
//             +- Scan ExternalRDDScan[obj#143]

Comme vous le pouvez, la première phase est une projection où seules les colonnes requises sont conservées. Les données suivantes sont agrégées localement et finalement transférées et agrégées globalement. Vous obtiendrez un résultat de réponse légèrement différent si vous utilisez Spark <= 1.4 mais la structure générale doit être exactement la même.

Enfin, une visualisation DAG montrant que la description ci-dessus décrit le travail réel:

group by and agg DAG

De même, Dataset.groupByKey suivi de reduceGroups, contient les deux côtés de la carte (ObjectHashAggregate avec partial_reduceaggregator) et côté réduit (ObjectHashAggregate avec reduceaggregator réduction):

case class Foo(x: String, y: String, z: Int)

val ds = df.as[Foo]
ds.groupByKey(_.x).reduceGroups((x, y) => x.copy(z = x.z + y.z)).explain

// == Physical Plan ==
// ObjectHashAggregate(keys=[value#126], functions=[reduceaggregator(org.Apache.spark.sql.expressions.ReduceAggregator@54d90261, Some(newInstance(class $line40.$read$$iw$$iw$Foo)), Some(class $line40.$read$$iw$$iw$Foo), Some(StructType(StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false))), input[0, scala.Tuple2, true]._1 AS value#128, if ((isnull(input[0, scala.Tuple2, true]._2) || None.equals)) null else named_struct(x, staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).x, true, false) AS x#25, y, staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).y, true, false) AS y#26, z, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).z AS z#27) AS _2#129, newInstance(class scala.Tuple2), staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).x, true, false) AS x#25, staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).y, true, false) AS y#26, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).z AS z#27, StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false), true, 0, 0)])
// +- Exchange hashpartitioning(value#126, 200)
//    +- ObjectHashAggregate(keys=[value#126], functions=[partial_reduceaggregator(org.Apache.spark.sql.expressions.ReduceAggregator@54d90261, Some(newInstance(class $line40.$read$$iw$$iw$Foo)), Some(class $line40.$read$$iw$$iw$Foo), Some(StructType(StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false))), input[0, scala.Tuple2, true]._1 AS value#128, if ((isnull(input[0, scala.Tuple2, true]._2) || None.equals)) null else named_struct(x, staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).x, true, false) AS x#25, y, staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).y, true, false) AS y#26, z, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]._2)).z AS z#27) AS _2#129, newInstance(class scala.Tuple2), staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).x, true, false) AS x#25, staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).y, true, false) AS y#26, assertnotnull(assertnotnull(input[0, $line40.$read$$iw$$iw$Foo, true])).z AS z#27, StructField(x,StringType,true), StructField(y,StringType,true), StructField(z,IntegerType,false), true, 0, 0)])
//       +- AppendColumns <function1>, newInstance(class $line40.$read$$iw$$iw$Foo), [staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, Java.lang.String, true], true, false) AS value#126]
//          +- *Project [_1#4 AS x#8, _2#5 AS y#9, _3#6 AS z#10]
//             +- *SerializeFromObject [staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._1, true, false) AS _1#4, staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._2, true, false) AS _2#5, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#6]
//                +- Scan ExternalRDDScan[obj#3]

groupByKey + reduceGroups

Cependant, d'autres méthodes de KeyValueGroupedDataset peuvent fonctionner de manière similaire à RDD.groupByKey. Par exemple, mapGroups (ou flatMapGroups) n'utilise pas d'agrégation partielle.

ds.groupByKey(_.x)
  .mapGroups((_, iter) => iter.reduce((x, y) => x.copy(z = x.z + y.z)))
  .explain

//== Physical Plan ==
//*SerializeFromObject [staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line15.$read$$iw$$iw$Foo, true]).x, true, false) AS x#37, staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line15.$read$$iw$$iw$Foo, true]).y, true, false) AS y#38, assertnotnull(input[0, $line15.$read$$iw$$iw$Foo, true]).z AS z#39]
//+- MapGroups <function2>, value#32.toString, newInstance(class $line15.$read$$iw$$iw$Foo), [value#32], [x#8, y#9, z#10], obj#36: $line15.$read$$iw$$iw$Foo
//   +- *Sort [value#32 ASC NULLS FIRST], false, 0
//      +- Exchange hashpartitioning(value#32, 200)
//         +- AppendColumns <function1>, newInstance(class $line15.$read$$iw$$iw$Foo), [staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, Java.lang.String, true], true, false) AS value#32]
//            +- *Project [_1#4 AS x#8, _2#5 AS y#9, _3#6 AS z#10]
//               +- *SerializeFromObject [staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._1, true, false) AS _1#4, staticinvoke(class org.Apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple3, true])._2, true, false) AS _2#5, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#6]
//                  +- Scan ExternalRDDScan[obj#3]

groupByKey + mapGroups

38
zero323