How to call Scala UDFs from PySpark
Motivation
Suppose we decided to speed up our PySpark job. One possible way to do this is to write a Scala UDF. If we implement the function in Scala instead of Python, the Spark workers will execute the computation themselves rather than hand it off to the Python process, and won't need to serialize/deserialize the data across the JVM/Python boundary. Double win!
So, we can write a simple Scala object with a single function in it, then wrap that function in udf. (I sincerely apologize for the imperative style.)
package com.example.spark.udfs
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
object Primes {
  def isPrime(x: Int): Boolean = {
    if (x < 2) return false
    var i = 2
    while (i * i <= x) {
      if (x % i == 0) {
        return false
      }
      i += 1
    }
    return true
  }

  def isPrimeUDF: UserDefinedFunction = udf { x: Int => isPrime(x) }
}
Save it as src/com/example/spark/udfs/Primes.scala or under any other convenient name.
Compiling Scala
We don’t need heavyweight IDEs or sbt to compile such a simple function. Let’s do it from the command line! We need to use Scala 2.11 when compiling (as of Spark 2.0.1/2.1.0). Download Scala 2.11 and unpack it somewhere (I unpacked it as ~/Downloads/Distribs/scala-2.11.8).
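For example, downloading and unpacking from the command line could look roughly like this (the exact URL is an assumption on my part; grab the current link from scala-lang.org if it has moved):
$ mkdir -p ~/Downloads/Distribs && cd ~/Downloads/Distribs
$ curl -LO https://downloads.lightbend.com/scala/2.11.8/scala-2.11.8.tgz
$ tar xzf scala-2.11.8.tgz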
Set up the environment variables:
SCALA_HOME=$HOME/Downloads/Distribs/scala-2.11.8
PATH=$PATH:$SCALA_HOME/bin
Make sure we also have $SPARK_HOME set:
$ echo $SPARK_HOME
/Users/sserebryakov/spark-2.0.1-bin-hadoop2.7
Check that we can invoke scalac:
$ scalac -version
Scala compiler version 2.11.8 -- Copyright 2002-2016, LAMP/EPFL
Go to our project folder and build a JAR. Note that the wildcard is just a star, not *.jar:
$ scalac -classpath "$SPARK_HOME/jars/*" src/com/example/spark/udfs/Primes.scala -d primes_udf.jar
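Optionally, we can double-check that the compiled classes actually ended up inside the JAR (the jar tool ships with the JDK):
$ jar tf primes_udf.jar
The listing should contain com/example/spark/udfs/Primes.class and Primes$.class, plus a few compiler-generated helper classes.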
Running PySpark
Now we can provide the JAR in the classpath for pyspark:
$ pyspark --jars primes_udf.jar
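If the code lives in a standalone script instead of the interactive shell, the same flag works with spark-submit (my_script.py here is just a placeholder name):
$ spark-submit --jars primes_udf.jar my_script.py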
Now we can use the Scala function like this. The prerequisite is that the data should already be in a DataFrame. Be careful with the syntax: the py4j exceptions are rarely helpful.
from pyspark.sql import Row
from pyspark.sql.column import Column, _to_java_column, _to_seq

# Build a toy DataFrame with a single 'number' column.
a = range(10)
df = sc.parallelize(a).map(lambda x: Row(number=x)).toDF()

# Fetch the Scala UDF object from the JVM through the py4j gateway,
# then wrap it so it can be applied to a PySpark Column.
scala_udf_is_prime = sc._jvm.com.example.spark.udfs.Primes.isPrimeUDF()
is_prime_column = lambda col: Column(scala_udf_is_prime.apply(_to_seq(sc, [col], _to_java_column)))

df.withColumn('is_prime', is_prime_column(df['number'])).show()
This will print:
+------+--------+
|number|is_prime|
+------+--------+
| 0| false|
| 1| false|
| 2| true|
| 3| true|
| 4| false|
| 5| true|
| 6| false|
| 7| true|
| 8| false|
| 9| false|
+------+--------+
Tested with Spark 2.0.1, but should work for 2.1.0 (the latest as of this writing) as well.
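By the way, since this wrapping dance is identical for every Scala UDF, it can be pulled into a tiny helper. This is only a sketch reusing the names already in scope above (sc, df, Column, _to_seq, _to_java_column); wrap_scala_udf is a name I made up:
# Hypothetical helper: wraps a JVM-side UserDefinedFunction into a function
# that takes PySpark Columns and returns a Column.
def wrap_scala_udf(scala_udf):
    return lambda *cols: Column(scala_udf.apply(_to_seq(sc, list(cols), _to_java_column)))

# Same result as before, with the boilerplate hidden in the helper.
is_prime_column = wrap_scala_udf(sc._jvm.com.example.spark.udfs.Primes.isPrimeUDF())
df.withColumn('is_prime', is_prime_column(df['number'])).show()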
Currying
If we want to initialize the Scala class with some arguments once, and call the function repeatedly, we can use a case class. Example:
package com.example.spark.udfs
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
case class DistanceComputer(originX: Double, originY: Double) {
  def distanceUDF: UserDefinedFunction = udf { (x: Double, y: Double) =>
    math.sqrt(math.pow(x - originX, 2) + math.pow(y - originY, 2))
  }
}
Registering with PySpark would look like:
scala_udf_distance = sc._jvm.com.example.spark.udfs.DistanceComputer(0.0, 0.0).distanceUDF()
Make sure to conform exactly to the argument types when calling the constructor: a constructor with Double arguments should be called with 0.0, not 0.
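Applying the curried UDF then follows the same wrapping pattern as before, just with two columns passed in. A sketch with made-up sample points:
# Wrap the JVM-side UDF so it accepts a list of PySpark Columns.
distance_column = lambda cols: Column(scala_udf_distance.apply(_to_seq(sc, cols, _to_java_column)))

# Toy DataFrame with two Double columns, x and y.
points = sc.parallelize([(3.0, 4.0), (1.0, 1.0)]).map(lambda p: Row(x=p[0], y=p[1])).toDF()
points.withColumn('distance', distance_column([points['x'], points['y']])).show()
For these two points, the distances from the origin come out as 5.0 and roughly 1.414.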