web-dev-qa-db-fra.com

Egalité de DataFrame dans Apache Spark

Supposons que df1 et df2 sont deux DataFrames dans Apache Spark, calculés à l'aide de deux mécanismes différents, par exemple Spark SQL et l'API Scala/Java/Python. 

Existe-t-il un moyen idiomatique de déterminer si les deux trames de données sont équivalentes (égales, isomorphes), l'équivalence étant déterminée par le fait que les données (noms de colonne et valeurs de colonne pour chaque ligne) sont identiques, à l'exception du classement des lignes et des colonnes?

La motivation de la question est qu’il existe souvent de nombreuses façons de calculer des résultats de données volumineuses, chacune avec ses propres compromis. Lorsque l'on explore ces compromis, il est important de maintenir l'exactitude et, partant, la nécessité de vérifier l'équivalence/l'égalité d'un ensemble de données de test significatif.

17
Sim

Il existe des méthodes standard dans les suites de tests Apache Spark, mais la plupart d'entre elles impliquent de collecter les données localement. Si vous souhaitez effectuer des tests d'égalité sur de grandes trames de données, cette solution n'est probablement pas adaptée.

En vérifiant d'abord le schéma, vous pouvez ensuite effectuer une intersection vers df3 et vérifier que le nombre de df1, df2 et df3 est égal (toutefois, cela ne fonctionne que s'il n'y a pas de lignes dupliquées, s'il existe différentes lignes dupliquées, cette méthode pourrait toujours retourne vrai).

Une autre option consisterait à extraire les RDD sous-jacents des deux DataFrames, à mapper sur (Ligne, 1), à effectuer une reduction pour Compter le nombre de chaque ligne, puis à co-regrouper les deux RDD résultants, puis à effectuer un agrégat régulier et à renvoyer false si aucun des itérateurs ne sont égaux.

9
Holden

Je ne connais pas l'idiome, mais je pense que vous pouvez obtenir un moyen robuste de comparer les DataFrames comme vous le décrivez comme suit. (J'utilise PySpark à titre d'illustration, mais l'approche utilise plusieurs langues.)

a = spark.range(5)
b = spark.range(5)

a_prime = a.groupBy(sorted(a.columns)).count()
b_prime = b.groupBy(sorted(b.columns)).count()

assert a_prime.subtract(b_prime).count() == b_prime.subtract(a_prime).count() == 0

Cette approche gère correctement les cas où les DataFrames peuvent avoir des lignes en double, des lignes dans différents ordres et/ou des colonnes dans des ordres différents.

Par exemple:

a = spark.createDataFrame([('nick', 30), ('bob', 40)], ['name', 'age'])
b = spark.createDataFrame([(40, 'bob'), (30, 'nick')], ['age', 'name'])
c = spark.createDataFrame([('nick', 30), ('bob', 40), ('nick', 30)], ['name', 'age'])

a_prime = a.groupBy(sorted(a.columns)).count()
b_prime = b.groupBy(sorted(b.columns)).count()
c_prime = c.groupBy(sorted(c.columns)).count()

assert a_prime.subtract(b_prime).count() == b_prime.subtract(a_prime).count() == 0
assert a_prime.subtract(c_prime).count() != 0

Cette approche est assez coûteuse, mais la plupart des dépenses sont inévitables étant donné la nécessité de réaliser un diff complet. Et cela devrait aller très bien car il n’est pas nécessaire de collecter quoi que ce soit localement. Si vous assouplissez la contrainte voulant que la comparaison prenne en compte les lignes en double, vous pouvez alors supprimer la fonction groupBy() et effectuer uniquement la fonction subtract(), ce qui accélérerait probablement considérablement les choses.

8
Nick Chammas

La bibliothèque spark-fast-tests dispose de deux méthodes pour effectuer des comparaisons DataFrame (je suis le créateur de la bibliothèque):

La méthode assertSmallDataFrameEquality collecte les DataFrames sur le nœud du pilote et effectue la comparaison.

def assertSmallDataFrameEquality(actualDF: DataFrame, expectedDF: DataFrame): Unit = {
  if (!actualDF.schema.equals(expectedDF.schema)) {
    throw new DataFrameSchemaMismatch(schemaMismatchMessage(actualDF, expectedDF))
  }
  if (!actualDF.collect().sameElements(expectedDF.collect())) {
    throw new DataFrameContentMismatch(contentMismatchMessage(actualDF, expectedDF))
  }
}

La méthode assertLargeDataFrameEquality compare les DataFrames répartis sur plusieurs machines (le code est essentiellement copié à partir de spark-testing-base )

def assertLargeDataFrameEquality(actualDF: DataFrame, expectedDF: DataFrame): Unit = {
  if (!actualDF.schema.equals(expectedDF.schema)) {
    throw new DataFrameSchemaMismatch(schemaMismatchMessage(actualDF, expectedDF))
  }
  try {
    actualDF.rdd.cache
    expectedDF.rdd.cache

    val actualCount = actualDF.rdd.count
    val expectedCount = expectedDF.rdd.count
    if (actualCount != expectedCount) {
      throw new DataFrameContentMismatch(countMismatchMessage(actualCount, expectedCount))
    }

    val expectedIndexValue = zipWithIndex(actualDF.rdd)
    val resultIndexValue = zipWithIndex(expectedDF.rdd)

    val unequalRDD = expectedIndexValue
      .join(resultIndexValue)
      .filter {
        case (idx, (r1, r2)) =>
          !(r1.equals(r2) || RowComparer.areRowsEqual(r1, r2, 0.0))
      }

    val maxUnequalRowsToShow = 10
    assertEmpty(unequalRDD.take(maxUnequalRowsToShow))

  } finally {
    actualDF.rdd.unpersist()
    expectedDF.rdd.unpersist()
  }
}

assertSmallDataFrameEquality est plus rapide pour les petites comparaisons DataFrame et je l'ai trouvé suffisant pour mes suites de tests.

4
Powers

Java:

assert resultDs.union(answerDs).distinct().count() == resultDs.intersect(answerDs).count();
2
user1442346

Vous pouvez le faire en utilisant un peu de déduplication en combinaison avec une jointure externe complète. L’avantage de cette approche est qu’elle ne nécessite pas la collecte de résultats auprès du pilote et évite d’exécuter plusieurs tâches.

import org.Apache.spark.sql._
import org.Apache.spark.sql.functions._

// Generate some random data.
def random(n: Int, s: Long) = {
  spark.range(n).select(
    (Rand(s) * 10000).cast("int").as("a"),
    (Rand(s + 5) * 1000).cast("int").as("b"))
}
val df1 = random(10000000, 34)
val df2 = random(10000000, 17)

// Move all the keys into a struct (to make handling nulls easy), deduplicate the given dataset
// and count the rows per key.
def dedup(df: Dataset[Row]): Dataset[Row] = {
  df.select(struct(df.columns.map(col): _*).as("key"))
    .groupBy($"key")
    .agg(count(lit(1)).as("row_count"))
}

// Deduplicate the inputs and join them using a full outer join. The result can contain
// the following things:
// 1. Both keys are not null (and thus equal), and the row counts are the same. The dataset
//    is the same for the given key.
// 2. Both keys are not null (and thus equal), and the row counts are not the same. The dataset
//    contains the same keys.
// 3. Only the right key is not null.
// 4. Only the left key is not null.
val joined = dedup(df1).as("l").join(dedup(df2).as("r"), $"l.key" === $"r.key", "full")

// Summarize the differences.
val summary = joined.select(
  count(when($"l.key".isNotNull && $"r.key".isNotNull && $"r.row_count" === $"l.row_count", 1)).as("left_right_same_rc"),
  count(when($"l.key".isNotNull && $"r.key".isNotNull && $"r.row_count" =!= $"l.row_count", 1)).as("left_right_different_rc"),
  count(when($"l.key".isNotNull && $"r.key".isNull, 1)).as("left_only"),
  count(when($"l.key".isNull && $"r.key".isNotNull, 1)).as("right_only"))
summary.show()
0