diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala
index 26c4a2a73..f49034b1a 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala
@@ -79,7 +79,10 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging {
   // initialize all nondeterministic children exprs
   expr.foreach {
     case nondeterministic: Nondeterministic =>
-      nondeterministic.initialize(TaskContext.get.partitionId())
+      nondeterministic.initialize(TaskContext.get match {
+        case tc: TaskContext => tc.partitionId()
+        case null => 0
+      })
     case _ =>
   }
 
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala
index d078528fb..d4fe5186c 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala
@@ -46,7 +46,10 @@ case class SparkUDFWrapperContext(serialized: ByteBuffer) extends Logging {
   // initialize all nondeterministic children exprs
   expr.foreach {
     case nondeterministic: Nondeterministic =>
-      nondeterministic.initialize(TaskContext.get.partitionId())
+      nondeterministic.initialize(TaskContext.get match {
+        case tc: TaskContext => tc.partitionId()
+        case null => 0
+      })
     case _ =>
   }
 
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala
index 24ad65fc1..c9003313b 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala
@@ -49,7 +49,10 @@ case class SparkUDTFWrapperContext(serialized: ByteBuffer) extends Logging {
   // initialize all nondeterministic children exprs
   expr.foreach {
     case nondeterministic: Nondeterministic =>
-      nondeterministic.initialize(TaskContext.get.partitionId())
+      nondeterministic.initialize(TaskContext.get match {
+        case tc: TaskContext => tc.partitionId()
+        case null => 0
+      })
     case _ =>
   }
 
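Note (not part of the patch): each hunk guards against TaskContext.get returning null, which happens when the wrapper is constructed outside a running Spark task, and falls back to partition 0 in that case. A minimal sketch of an equivalent null-safe lookup, with a hypothetical helper name (currentPartitionId):

    import org.apache.spark.TaskContext

    // Null-safe partition id lookup; 0 is used when no task context exists,
    // mirroring the fallback chosen in the patch above.
    def currentPartitionId(): Int =
      Option(TaskContext.get).map(_.partitionId()).getOrElse(0)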