web-dev-qa-db-fra.com

Comment définir et utiliser une fonction d'agrégation définie par l'utilisateur dans Spark SQL?

Je sais comment écrire un UDF en Spark SQL:

def belowThreshold(power: Int): Boolean = {
        return power < -40
      }

sqlContext.udf.register("belowThreshold", belowThreshold _)

Puis-je faire quelque chose de similaire pour définir une fonction d'agrégation? Comment cela se fait-il?

Pour le contexte, je veux exécuter la requête SQL suivante:

val aggDF = sqlContext.sql("""SELECT span, belowThreshold(opticalReceivePower), timestamp
                                    FROM ifDF
                                    WHERE opticalReceivePower IS NOT null
                                    GROUP BY span, timestamp
                                    ORDER BY span""")

Il devrait renvoyer quelque chose comme

Row(span1, false, T0)

Je veux que la fonction d'agrégation me dise s'il y a des valeurs pour opticalReceivePower dans les groupes définis par span et timestamp qui sont en dessous du seuil. Dois-je écrire mon UDAF différemment de l'UDF que j'ai collé ci-dessus?

36
Rory Byrne

Méthodes prises en charge

Spark> = 2,3

Udf vectorisé (Python uniquement):

from pyspark.sql.functions import pandas_udf
from pyspark.sql.functions import PandasUDFType

from pyspark.sql.types import *
import pandas as pd

df = sc.parallelize([
    ("a", 0), ("a", 1), ("b", 30), ("b", -50)
]).toDF(["group", "power"])

def below_threshold(threshold, group="group", power="power"):
    @pandas_udf("struct<group: string, below_threshold: boolean>", PandasUDFType.GROUPED_MAP)
    def below_threshold_(df):
        df = pd.DataFrame(
           df.groupby(group).apply(lambda x: (x[power] < threshold).any()))
        df.reset_index(inplace=True, drop=False)
        return df

    return below_threshold_

Exemple d'utilisation:

df.groupBy("group").apply(below_threshold(-40)).show()

## +-----+---------------+
## |group|below_threshold|
## +-----+---------------+
## |    b|           true|
## |    a|          false|
## +-----+---------------+

Voir aussi Application d'UDF sur GroupedData dans PySpark (avec un fonctionnement python)

Spark> = 2.0 (éventuellement 1.6 mais avec une API légèrement différente):

Il est possible d'utiliser Aggregators sur le type Datasets:

import org.Apache.spark.sql.expressions.Aggregator
import org.Apache.spark.sql.{Encoder, Encoders}

class BelowThreshold[I](f: I => Boolean)  extends Aggregator[I, Boolean, Boolean]
    with Serializable {
  def zero = false
  def reduce(acc: Boolean, x: I) = acc | f(x)
  def merge(acc1: Boolean, acc2: Boolean) = acc1 | acc2
  def finish(acc: Boolean) = acc

  def bufferEncoder: Encoder[Boolean] = Encoders.scalaBoolean
  def outputEncoder: Encoder[Boolean] = Encoders.scalaBoolean
}

val belowThreshold = new BelowThreshold[(String, Int)](_._2 < - 40).toColumn
df.as[(String, Int)].groupByKey(_._1).agg(belowThreshold)

Spark> = 1,5 :

Dans Spark 1.5, vous pouvez créer UDAF comme ceci, bien qu'il s'agisse probablement d'une surpuissance:

import org.Apache.spark.sql.expressions._
import org.Apache.spark.sql.types._
import org.Apache.spark.sql.Row

object belowThreshold extends UserDefinedAggregateFunction {
    // Schema you get as an input
    def inputSchema = new StructType().add("power", IntegerType)
    // Schema of the row which is used for aggregation
    def bufferSchema = new StructType().add("ind", BooleanType)
    // Returned type
    def dataType = BooleanType
    // Self-explaining 
    def deterministic = true
    // zero value
    def initialize(buffer: MutableAggregationBuffer) = buffer.update(0, false)
    // Similar to seqOp in aggregate
    def update(buffer: MutableAggregationBuffer, input: Row) = {
        if (!input.isNullAt(0))
          buffer.update(0, buffer.getBoolean(0) | input.getInt(0) < -40)
    }
    // Similar to combOp in aggregate
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      buffer1.update(0, buffer1.getBoolean(0) | buffer2.getBoolean(0))    
    }
    // Called on exit to get return value
    def evaluate(buffer: Row) = buffer.getBoolean(0)
}

Exemple d'utilisation:

df
  .groupBy($"group")
  .agg(belowThreshold($"power").alias("belowThreshold"))
  .show

// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// |    a|         false|
// |    b|          true|
// +-----+--------------+

Solution de contournement Spark 1.4 :

Je ne suis pas sûr de bien comprendre vos besoins, mais pour autant que je sache, une agrégation simple devrait suffire ici:

val df = sc.parallelize(Seq(
    ("a", 0), ("a", 1), ("b", 30), ("b", -50))).toDF("group", "power")

df
  .withColumn("belowThreshold", ($"power".lt(-40)).cast(IntegerType))
  .groupBy($"group")
  .agg(sum($"belowThreshold").notEqual(0).alias("belowThreshold"))
  .show

// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// |    a|         false|
// |    b|          true|
// +-----+--------------+

Spark <= 1.4 :

Pour autant que je sache, en ce moment (Spark 1.4.1), il n'y a pas de support pour UDAF, autre que ceux de Hive. Cela devrait être possible avec Spark 1.5 (voir SPARK-3947 ).

Méthodes non prises en charge/internes

En interne Spark utilise un certain nombre de classes, y compris ImperativeAggregates et DeclarativeAggregates .

Ils sont destinés à un usage interne et peuvent changer sans préavis, donc ce n'est probablement pas quelque chose que vous souhaitez utiliser dans votre code de production, mais juste pour être complet, BelowThreshold avec DeclarativeAggregate pourrait être implémenté comme ceci (testé avec Spark 2.2-SNAPSHOT):

import org.Apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.Apache.spark.sql.catalyst.expressions._
import org.Apache.spark.sql.types._

case class BelowThreshold(child: Expression, threshold: Expression) 
    extends  DeclarativeAggregate  {
  override def children: Seq[Expression] = Seq(child, threshold)

  override def nullable: Boolean = false
  override def dataType: DataType = BooleanType

  private lazy val belowThreshold = AttributeReference(
    "belowThreshold", BooleanType, nullable = false
  )()

  // Used to derive schema
  override lazy val aggBufferAttributes = belowThreshold :: Nil

  override lazy val initialValues = Seq(
    Literal(false)
  )

  override lazy val updateExpressions = Seq(Or(
    belowThreshold,
    If(IsNull(child), Literal(false), LessThan(child, threshold))
  ))

  override lazy val mergeExpressions = Seq(
    Or(belowThreshold.left, belowThreshold.right)
  )

  override lazy val evaluateExpression = belowThreshold
  override def defaultResult: Option[Literal] = Option(Literal(false))
} 

Il devrait être en outre encapsulé avec un équivalent de withAggregateFunction .

74
zero323