
PySpark DataFrame UDF on a text column

I'm trying to do some NLP text cleaning on a few Unicode columns in a PySpark DataFrame. I've tried Spark 1.3, 1.5 and 1.6 and can't get this to work for the life of me. I've also tried both Python 2.7 and Python 3.4.

I've created an extremely simple UDF, shown below, that should simply return a string for each record in a new column. Other functions will then manipulate the text and return the modified text in a new column.

import pyspark
from pyspark.sql import SQLContext
from pyspark.sql.types import *
from pyspark.sql.functions import udf

def dummy_function(data_str):
    cleaned_str = 'dummyData' 
    return cleaned_str

dummy_function_udf = udf(dummy_function, StringType())

Some sample data can be downloaded from here.

Here is the code I use to import the data and then apply the UDF to it.

# Load a text file and convert each line to a Row.
lines = sc.textFile("classified_tweets.txt")
parts = lines.map(lambda l: l.split("\t"))
training = parts.map(lambda p: (p[0], p[1]))

# Create dataframe
training_df = sqlContext.createDataFrame(training, ["Tweet", "classification"])

training_df.show(5)
+--------------------+--------------+
|               Tweet|classification|
+--------------------+--------------+
|rt @jiffyclub: wi...|        python|
|rt @arnicas: ipyt...|        python|
|rt @treycausey: i...|        python|
|what's my best op...|        python|
|rt @raymondh: #py...|        python|
+--------------------+--------------+

# Apply UDF function
df = training_df.withColumn("dummy", dummy_function_udf(training_df['Tweet']))
df.show(5)

When I run df.show(5), I get the error below. I understand the problem probably doesn't come from show() itself, but the traceback doesn't help me much.

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-19-0b21c233c724> in <module>()
      1 df = training_df.withColumn("dummy", dummy_function_udf(training_df['Tweet']))
----> 2 df.show(5)
/Users/dreyco676/spark-1.6.0-bin-hadoop2.6/python/pyspark/sql/dataframe.py in show(self, n, truncate)
    255         +---+-----+
    256         """
--> 257         print(self._jdf.showString(n, truncate))
    258 
    259     def __repr__(self):
/Users/dreyco676/spark-1.6.0-bin-hadoop2.6/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
    811         answer = self.gateway_client.send_command(command)
    812         return_value = get_return_value(
--> 813             answer, self.gateway_client, self.target_id, self.name)
    814 
    815         for temp_arg in temp_args:
/Users/dreyco676/spark-1.6.0-bin-hadoop2.6/python/pyspark/sql/utils.py in deco(*a, **kw)
     43     def deco(*a, **kw):
     44         try:
---> 45             return f(*a, **kw)
     46         except py4j.protocol.Py4JJavaError as e:
     47             s = e.java_exception.toString()
/Users/dreyco676/spark-1.6.0-bin-hadoop2.6/python/lib/py4j-0.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    306                 raise Py4JJavaError(
    307                     "An error occurred while calling {0}{1}{2}.\n".
--> 308                     format(target_id, ".", name), value)
    309             else:
    310                 raise Py4JError(
Py4JJavaError: An error occurred while calling o474.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 10.0 failed 1 times, most recent failure: Lost task 0.0 in stage 10.0 (TID 10, localhost): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/dreyco676/spark-1.6.0-bin-hadoop2.6/python/lib/pyspark.zip/pyspark/worker.py", line 111, in main
    process()
  File "/Users/dreyco676/spark-1.6.0-bin-hadoop2.6/python/lib/pyspark.zip/pyspark/worker.py", line 106, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/Users/dreyco676/spark-1.6.0-bin-hadoop2.6/python/lib/pyspark.zip/pyspark/serializers.py", line 263, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "<ipython-input-12-4bc30395aac5>", line 4, in <lambda>
IndexError: list index out of range

    at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRDD.scala:166)
    at org.apache.spark.api.python.PythonRunner$$anon$1.next(PythonRDD.scala:129)
    at org.apache.spark.api.python.PythonRunner$$anon$1.next(PythonRDD.scala:125)
    at org.apache.spark.InterruptibleIterator.next(InterruptibleIterator.scala:43)
    at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:371)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327)
    at scala.collection.Iterator$GroupedIterator.takeDestructively(Iterator.scala:913)
    at scala.collection.Iterator$GroupedIterator.go(Iterator.scala:929)
    at scala.collection.Iterator$GroupedIterator.fill(Iterator.scala:968)
    at scala.collection.Iterator$GroupedIterator.hasNext(Iterator.scala:972)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327)
    at scala.collection.Iterator$class.foreach(Iterator.scala:727)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
    at org.apache.spark.api.python.PythonRDD$.writeIteratorToStream(PythonRDD.scala:452)
    at org.apache.spark.api.python.PythonRunner$WriterThread$$anonfun$run$3.apply(PythonRDD.scala:280)
    at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1741)
    at org.apache.spark.api.python.PythonRunner$WriterThread.run(PythonRDD.scala:239)

Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1431)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1419)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1418)
    at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1418)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:799)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:799)
    at scala.Option.foreach(Option.scala:236)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:799)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1640)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1599)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1588)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:620)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:1832)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:1845)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:1858)
    at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:212)
    at org.apache.spark.sql.execution.Limit.executeCollect(basicOperators.scala:165)
    at org.apache.spark.sql.execution.SparkPlan.executeCollectPublic(SparkPlan.scala:174)
    at org.apache.spark.sql.DataFrame$$anonfun$org$apache$spark$sql$DataFrame$$execute$1$1.apply(DataFrame.scala:1538)
    at org.apache.spark.sql.DataFrame$$anonfun$org$apache$spark$sql$DataFrame$$execute$1$1.apply(DataFrame.scala:1538)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:56)
    at org.apache.spark.sql.DataFrame.withNewExecutionId(DataFrame.scala:2125)
    at org.apache.spark.sql.DataFrame.org$apache$spark$sql$DataFrame$$execute$1(DataFrame.scala:1537)
    at org.apache.spark.sql.DataFrame.org$apache$spark$sql$DataFrame$$collect(DataFrame.scala:1544)
    at org.apache.spark.sql.DataFrame$$anonfun$head$1.apply(DataFrame.scala:1414)
    at org.apache.spark.sql.DataFrame$$anonfun$head$1.apply(DataFrame.scala:1413)
    at org.apache.spark.sql.DataFrame.withCallback(DataFrame.scala:2138)
    at org.apache.spark.sql.DataFrame.head(DataFrame.scala:1413)
    at org.apache.spark.sql.DataFrame.take(DataFrame.scala:1495)
    at org.apache.spark.sql.DataFrame.showString(DataFrame.scala:171)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:497)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:231)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:381)
    at py4j.Gateway.invoke(Gateway.java:259)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:209)
    at java.lang.Thread.run(Thread.java:745)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/dreyco676/spark-1.6.0-bin-hadoop2.6/python/lib/pyspark.zip/pyspark/worker.py", line 111, in main
    process()
  File "/Users/dreyco676/spark-1.6.0-bin-hadoop2.6/python/lib/pyspark.zip/pyspark/worker.py", line 106, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/Users/dreyco676/spark-1.6.0-bin-hadoop2.6/python/lib/pyspark.zip/pyspark/serializers.py", line 263, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "<ipython-input-12-4bc30395aac5>", line 4, in <lambda>
IndexError: list index out of range

    at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRDD.scala:166)
    at org.apache.spark.api.python.PythonRunner$$anon$1.next(PythonRDD.scala:129)
    at org.apache.spark.api.python.PythonRunner$$anon$1.next(PythonRDD.scala:125)
    at org.apache.spark.InterruptibleIterator.next(InterruptibleIterator.scala:43)
    at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:371)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327)
    at scala.collection.Iterator$GroupedIterator.takeDestructively(Iterator.scala:913)
    at scala.collection.Iterator$GroupedIterator.go(Iterator.scala:929)
    at scala.collection.Iterator$GroupedIterator.fill(Iterator.scala:968)
    at scala.collection.Iterator$GroupedIterator.hasNext(Iterator.scala:972)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:327)
    at scala.collection.Iterator$class.foreach(Iterator.scala:727)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
    at org.apache.spark.api.python.PythonRDD$.writeIteratorToStream(PythonRDD.scala:452)
    at org.apache.spark.api.python.PythonRunner$WriterThread$$anonfun$run$3.apply(PythonRDD.scala:280)
    at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1741)
    at org.apache.spark.api.python.PythonRunner$WriterThread.run(PythonRDD.scala:239)

The actual function I'm trying to use:

from nltk import pos_tag

def tag_and_remove(data_str):
    cleaned_str = ' '
    # noun tags
    nn_tags = ['NN', 'NNP', 'NNPS', 'NNS']
    # adjectives
    jj_tags = ['JJ', 'JJR', 'JJS']
    # verbs
    vb_tags = ['VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ']
    nltk_tags = nn_tags + jj_tags + vb_tags

    # break string into 'words'
    text = data_str.split()

    # tag the text and keep only the words with the right tags
    tagged_text = pos_tag(text)
    for tagged_word in tagged_text:
        if tagged_word[1] in nltk_tags:
            cleaned_str += tagged_word[0] + ' '

    return cleaned_str


tag_and_remove_udf = udf(tag_and_remove, StringType())
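
The intended usage is along these lines (a sketch; it assumes nltk and its POS-tagger data are installed on every Spark worker, not just the driver):

# Sketch of the intended application of the real UDF.
# Assumes nltk (and its tagger data) are available on all workers.
cleaned_df = training_df.withColumn("cleaned", tag_and_remove_udf(training_df['Tweet']))
cleaned_df.show(5)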
dreyco676

Your dataset isn't clean. 985 of the lines split('\t') into only a single value:

>>> from operator import add
>>> lines = sc.textFile("classified_tweets.txt")
>>> parts = lines.map(lambda l: l.split("\t"))
>>> parts.map(lambda l: (len(l), 1)).reduceByKey(add).collect()
[(2, 149195), (1, 985)]
>>> parts.filter(lambda l: len(l) == 1).take(5)
[['"show me the money!”  at what point do you start trying to monetize your #startup? Tweet us with #startuplife.'],
 ['a good pitch can mean money in the bank for your #startup. see how body language plays a key role:  (via: ajalumnify)'],
 ['100+ apps in five years? @2359media did it using Microsoft #Azure:  #azureapps'],
 ['does buying better coffee make you a better leader? little things can make a big difference:  (via: @jmbrandonbb)'],
 ['.@msftventures graduates pitched\xa0#homeautomation #startups to #vcs! check out how they celebrated: ']]

So change your code to:

>>> training = parts.filter(lambda l: len(l) == 2).map(lambda p: (p[0], p[1].strip()))
>>> training_df = sqlContext.createDataFrame(training, ["Tweet", "classification"])
>>> df = training_df.withColumn("dummy", dummy_function_udf(training_df['Tweet']))
>>> df.show(5)
+--------------------+--------------+---------+
|               Tweet|classification|    dummy|
+--------------------+--------------+---------+
|rt @jiffyclub: wi...|        python|dummyData|
|rt @arnicas: ipyt...|        python|dummyData|
|rt @treycausey: i...|        python|dummyData|
|what's my best op...|        python|dummyData|
|rt @raymondh: #py...|        python|dummyData|
+--------------------+--------------+---------+
only showing top 5 rows
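
If you would rather keep the malformed lines than drop them, a small variation (a sketch, not part of the fix above) is to give them an empty classification instead:

# Hypothetical alternative: keep one-field lines, with an empty classification
training = parts.map(lambda p: (p[0], p[1].strip() if len(p) > 1 else ""))
training_df = sqlContext.createDataFrame(training, ["Tweet", "classification"])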
AChampion

I think you're mis-defining the problem, and perhaps simplifying your lambda for the purposes of this question in a way that hides the real issue.

Your stack trace reads:

File "<ipython-input-12-4bc30395aac5>", line 4, in <lambda>
IndexError: list index out of range

When I run this code, it works fine:

import pyspark
from pyspark.sql import SQLContext
from pyspark.sql.types import *
from pyspark.sql.functions import udf

training_df = sqlContext.sql("select 'foo' as Tweet, 'bar' as classification")

def dummy_function(data_str):
     cleaned_str = 'dummyData'
     return cleaned_str

dummy_function_udf = udf(dummy_function, StringType())
df = training_df.withColumn("dummy", dummy_function_udf(training_df['Tweet']))
df.show()

+-----+--------------+---------+
|Tweet|classification|    dummy|
+-----+--------------+---------+
|  foo|           bar|dummyData|
+-----+--------------+---------+

Are you sure there isn't some other bug in your dummy_function_udf? What is the "real" UDF you're using, apart from this sample version?
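
One cheap way to rule that out (a sketch, not a required step) is to call the plain Python function directly, outside of Spark, so any bug shows up as a normal traceback instead of a serialized worker error:

# Local sanity check: exercise the functions without Spark
sample = "rt @jiffyclub: this is an example tweet"   # any example string
print(dummy_function(sample))    # expected: 'dummyData'
print(tag_and_remove(sample))    # should not raise if nltk and its tagger data are installed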

Kirk Broadhurst

The following works with Spark 2:

import hashlib
import uuid
import datetime
from pyspark.sql.types import StringType

def customencoding(s):
    m = hashlib.md5()
    m.update(s.encode('utf-8'))
    d = m.hexdigest()
    return d

spark.udf.register("udf_customhashing32adadf", customencoding, StringType())

spark.sql("SELECT udf_customhashing32adadf('test') as rowid").show(10, False)

You can implement yours in the same way.
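
For completeness, the same function can also be used through the DataFrame API instead of Spark SQL (a sketch, assuming Spark 2 and a DataFrame df with a string column named value):

from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType

# Wrap the same Python function as a DataFrame UDF; 'value' is a hypothetical column name
customencoding_udf = udf(customencoding, StringType())
df.withColumn("rowid", customencoding_udf(col("value"))).show(10, False)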

Vicky