web-dev-qa-db-fra.com

SPARK SQL - met à jour la table MySql à l'aide de DataFrames et JDBC

J'essaie d'insérer et de mettre à jour des données sur MySql à l'aide de Spark SQL DataFrames et de la connexion JDBC.

J'ai réussi à insérer de nouvelles données à l'aide de SaveMode.Append. Existe-t-il un moyen de mettre à jour les données déjà existantes dans MySql Table à partir de Spark SQL?

Mon code à insérer est:

myDataFrame.write.mode(SaveMode.Append).jdbc(JDBCurl,mySqlTable,connectionProperties)

Si je change en SaveMode.Overwrite supprime la table entière et en crée une nouvelle, je cherche quelque chose comme "ON DUPLICATE KEY UPDATE" disponible dans MySql

16
nicola

Ce n'est pas possible. Pour l'instant (Spark 1.6.0/2.2.0 INSTANTANÉ) Spark DataFrameWriter prend en charge que quatre modes d'écriture:

  • SaveMode.Overwrite: écraser les données existantes.
  • SaveMode.Append: ajoute les données.
  • SaveMode.Ignore: ignore l'opération (c.-à-d. no-op).
  • SaveMode.ErrorIfExists: option par défaut, émet une exception lors de l'exécution.

Vous pouvez insérer manuellement, par exemple, à l'aide de mapPartitions (puisque vous souhaitez qu'une opération UPSERT soit idempotente et facile à implémenter), écrivez dans une table temporaire et exécutez upsert manuellement ou utilisez des déclencheurs.

En règle générale, il est loin d’être trivial d’obtenir un comportement d’extraction pour les opérations par lots et de conserver des performances décentes. N'oubliez pas qu'en général, il y aura plusieurs transactions simultanées en place (une par partition), vous devez donc vous assurer qu'il n'y aura pas de conflit d'écriture (généralement en utilisant un partitionnement spécifique à l'application) ou fournir les procédures de récupération appropriées. En pratique, il peut être préférable d’effectuer et d’écrire par lots sur une table temporaire et de résoudre le composant upsert directement dans la base de données.

18
zero323

la réponse de zero323 est correcte, je voulais juste ajouter que vous pouvez utiliser le package JayDeBeApi pour résoudre ce problème: https://pypi.python.org/pypi/JayDeBeApi/

mettre à jour les données dans votre table mysql. C'est peut-être un fruit facile à vivre puisque vous avez déjà le pilote mysql jdbc installé.

Le module JayDeBeApi vous permet de vous connecter du code Python à bases de données utilisant Java JDBC. Il fournit un DB-API v2.0 Python à cela base de données.

Nous utilisons la distribution Anaconda de Python et le package JayDeBeApi python est livré en standard.

Voir les exemples dans ce lien ci-dessus.

0
Tagar

Dommage qu'il n'y ait pas de mode SaveMode.Upsert dans Spark pour des cas aussi courants que celui d'uperting.

zero322 est correct en général, mais je pense qu’il devrait être possible (avec des compromis de performances) d’offrir une telle fonctionnalité de remplacement.

Je souhaitais également fournir du code Java pour ce cas ..__ Bien sûr, il n’est pas aussi performant que celui intégré à spark, mais il devrait constituer une bonne base pour vos besoins. Il suffit de le modifier en fonction de vos besoins: 

myDF.repartition(20); //one connection per partition, see below

myDF.foreachPartition((Iterator<Row> t) -> {
            Connection conn = DriverManager.getConnection(
                    Constants.DB_JDBC_CONN,
                    Constants.DB_JDBC_USER,
                    Constants.DB_JDBC_PASS);

            conn.setAutoCommit(true);
            Statement statement = conn.createStatement();

            final int batchSize = 100000;
            int i = 0;
            while (t.hasNext()) {
                Row row = t.next();
                try {
                    // better than REPLACE INTO, less cycles
                    statement.addBatch(("INSERT INTO mytable " + "VALUES ("
                            + "'" + row.getAs("_id") + "', 
                            + "'" + row.getStruct(1).get(0) + "'
                            + "')  ON DUPLICATE KEY UPDATE _id='" + row.getAs("_id") + "';"));
                    //conn.commit();

                    if (++i % batchSize == 0) {
                        statement.executeBatch();
                    }
                } catch (SQLIntegrityConstraintViolationException e) {
                    //should not occur, nevertheless
                    //conn.commit();
                } catch (SQLException e) {
                    e.printStackTrace();
                } finally {
                    //conn.commit();
                    statement.executeBatch();
                }
            }
            int[] ret = statement.executeBatch();

            System.out.println("Ret val: " + Arrays.toString(ret));
            System.out.println("Update count: " + statement.getUpdateCount());
            conn.commit();

            statement.close();
            conn.close();
0
Aydin K.

écraser org.Apache.spark.sql.execution.datasources.jdbc JdbcUtils.scala "insérer dans" pour "remplacer dans"

import Java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, SQLException}

import scala.collection.JavaConverters._
import scala.util.control.NonFatal
import com.typesafe.scalalogging.Logger
import org.Apache.spark.sql.catalyst.InternalRow
import org.Apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, DriverWrapper, JDBCOptions}
import org.Apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.Apache.spark.sql.types._
import org.Apache.spark.sql.{DataFrame, Row}

/**
  * Util functions for JDBC tables.
  */
object UpdateJdbcUtils {

  val logger = Logger(this.getClass)

  /**
    * Returns a factory for creating connections to the given JDBC URL.
    *
    * @param options - JDBC options that contains url, table and other information.
    */
  def createConnectionFactory(options: JDBCOptions): () => Connection = {
    val driverClass: String = options.driverClass
    () => {
      DriverRegistry.register(driverClass)
      val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
        case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
        case d if d.getClass.getCanonicalName == driverClass => d
      }.getOrElse {
        throw new IllegalStateException(
          s"Did not find registered driver with class $driverClass")
      }
      driver.connect(options.url, options.asConnectionProperties)
    }
  }

  /**
    * Returns a PreparedStatement that inserts a row into table via conn.
    */
  def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect)
  : PreparedStatement = {
    val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
    val sql = s"REPLACE INTO $table ($columns) VALUES ($placeholders)"
    conn.prepareStatement(sql)
  }

  /**
    * Retrieve standard jdbc types.
    *
    * @param dt The datatype (e.g. [[org.Apache.spark.sql.types.StringType]])
    * @return The default JdbcType for this DataType
    */
  def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
    dt match {
      case IntegerType => Option(JdbcType("INTEGER", Java.sql.Types.INTEGER))
      case LongType => Option(JdbcType("BIGINT", Java.sql.Types.BIGINT))
      case DoubleType => Option(JdbcType("DOUBLE PRECISION", Java.sql.Types.DOUBLE))
      case FloatType => Option(JdbcType("REAL", Java.sql.Types.FLOAT))
      case ShortType => Option(JdbcType("INTEGER", Java.sql.Types.SMALLINT))
      case ByteType => Option(JdbcType("BYTE", Java.sql.Types.TINYINT))
      case BooleanType => Option(JdbcType("BIT(1)", Java.sql.Types.BIT))
      case StringType => Option(JdbcType("TEXT", Java.sql.Types.CLOB))
      case BinaryType => Option(JdbcType("BLOB", Java.sql.Types.BLOB))
      case TimestampType => Option(JdbcType("TIMESTAMP", Java.sql.Types.TIMESTAMP))
      case DateType => Option(JdbcType("DATE", Java.sql.Types.DATE))
      case t: DecimalType => Option(
        JdbcType(s"DECIMAL(${t.precision},${t.scale})", Java.sql.Types.DECIMAL))
      case _ => None
    }
  }

  private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
    dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
  }

  // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
  // for `MutableRow`. The last argument `Int` means the index for the value to be set in
  // the row and also used for the value in `ResultSet`.
  private type JDBCValueGetter = (ResultSet, InternalRow, Int) => Unit

  // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
  // `PreparedStatement`. The last argument `Int` means the index for the value to be set
  // in the SQL statement and also used for the value in `Row`.
  private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

  /**
    * Saves a partition of a DataFrame to the JDBC database.  This is done in
    * a single database transaction (unless isolation level is "NONE")
    * in order to avoid repeatedly inserting data as much as possible.
    *
    * It is still theoretically possible for rows in a DataFrame to be
    * inserted into the database more than once if a stage somehow fails after
    * the commit occurs but before the stage can return successfully.
    *
    * This is not a closure inside saveTable() because apparently cosmetic
    * implementation changes elsewhere might easily render such a closure
    * non-Serializable.  Instead, we explicitly close over all variables that
    * are used.
    */
  def savePartition(
                     getConnection: () => Connection,
                     table: String,
                     iterator: Iterator[Row],
                     rddSchema: StructType,
                     nullTypes: Array[Int],
                     batchSize: Int,
                     dialect: JdbcDialect,
                     isolationLevel: Int): Iterator[Byte] = {
    val conn = getConnection()
    var committed = false

    var finalIsolationLevel = Connection.TRANSACTION_NONE
    if (isolationLevel != Connection.TRANSACTION_NONE) {
      try {
        val metadata = conn.getMetaData
        if (metadata.supportsTransactions()) {
          // Update to at least use the default isolation, if any transaction level
          // has been chosen and transactions are supported
          val defaultIsolation = metadata.getDefaultTransactionIsolation
          finalIsolationLevel = defaultIsolation
          if (metadata.supportsTransactionIsolationLevel(isolationLevel)) {
            // Finally update to actually requested level if possible
            finalIsolationLevel = isolationLevel
          } else {
            logger.warn(s"Requested isolation level $isolationLevel is not supported; " +
              s"falling back to default isolation level $defaultIsolation")
          }
        } else {
          logger.warn(s"Requested isolation level $isolationLevel, but transactions are unsupported")
        }
      } catch {
        case NonFatal(e) => logger.warn("Exception while detecting transaction support", e)
      }
    }
    val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

    try {
      if (supportsTransactions) {
        conn.setAutoCommit(false) // Everything in the same db transaction.
        conn.setTransactionIsolation(finalIsolationLevel)
      }
      val stmt = insertStatement(conn, table, rddSchema, dialect)
      val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
        .map(makeSetter(conn, dialect, _))
      val numFields = rddSchema.fields.length

      try {
        var rowCount = 0
        while (iterator.hasNext) {
          val row = iterator.next()
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i)
            }
            i = i + 1
          }
          stmt.addBatch()
          rowCount += 1
          if (rowCount % batchSize == 0) {
            stmt.executeBatch()
            rowCount = 0
          }
        }
        if (rowCount > 0) {
          stmt.executeBatch()
        }
      } finally {
        stmt.close()
      }
      if (supportsTransactions) {
        conn.commit()
      }
      committed = true
      Iterator.empty
    } catch {
      case e: SQLException =>
        val cause = e.getNextException
        if (cause != null && e.getCause != cause) {
          if (e.getCause == null) {
            e.initCause(cause)
          } else {
            e.addSuppressed(cause)
          }
        }
        throw e
    } finally {
      if (!committed) {
        // The stage must fail.  We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        if (supportsTransactions) {
          conn.rollback()
        }
        conn.close()
      } else {
        // The stage must succeed.  We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: Exception => logger.warn("Transaction succeeded, but closing failed", e)
        }
      }
    }
  }

  /**
    * Saves the RDD to the database in a single transaction.
    */
  def saveTable(
                 df: DataFrame,
                 url: String,
                 table: String,
                 options: JDBCOptions) {
    val dialect = JdbcDialects.get(url)
    val nullTypes: Array[Int] = df.schema.fields.map { field =>
      getJdbcType(field.dataType, dialect).jdbcNullType
    }

    val rddSchema = df.schema
    val getConnection: () => Connection = createConnectionFactory(options)
    val batchSize = options.batchSize
    val isolationLevel = options.isolationLevel
    df.foreachPartition(iterator => savePartition(
      getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
    )
  }

  private def makeSetter(
                          conn: Connection,
                          dialect: JdbcDialect,
                          dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))

    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))

    case DoubleType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDouble(pos + 1, row.getDouble(pos))

    case FloatType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setFloat(pos + 1, row.getFloat(pos))

    case ShortType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getShort(pos))

    case ByteType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getByte(pos))

    case BooleanType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBoolean(pos + 1, row.getBoolean(pos))

    case StringType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setString(pos + 1, row.getString(pos))

    case BinaryType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))

    case TimestampType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setTimestamp(pos + 1, row.getAs[Java.sql.Timestamp](pos))

    case DateType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDate(pos + 1, row.getAs[Java.sql.Date](pos))

    case t: DecimalType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))

    case ArrayType(et, _) =>
      // remove type length parameters from end of type name
      val typeName = getJdbcType(et, dialect).databaseTypeDefinition
        .toLowerCase.split("\\(")(0)
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        val array = conn.createArrayOf(
          typeName,
          row.getSeq[AnyRef](pos).toArray)
        stmt.setArray(pos + 1, array)

    case _ =>
      (_: PreparedStatement, _: Row, pos: Int) =>
        throw new IllegalArgumentException(
          s"Can't translate non-null value for field $pos")
  }
}

usage:

val url = s"jdbc:mysql://$Host/$database?useUnicode=true&characterEncoding=UTF-8"

val parameters: Map[String, String] = Map(
  "url" -> url,
  "dbtable" -> table,
  "driver" -> "com.mysql.jdbc.Driver",
  "numPartitions" -> numPartitions.toString,
  "user" -> user,
  "password" -> password
)
val options = new JDBCOptions(parameters)

for (d <- data) {
  UpdateJdbcUtils.saveTable(d, url, table, options)
}

ps: faites attention à l'impasse, ne mettez pas à jour les données fréquemment, utilisez-le simplement pour le ré-exécuter en cas d'urgence, je pense que c'est pour cette raison que spark n'appuie pas ce responsable.

0
user1442346