From 4bd6097106857114dd1fd95faef24cf900961198 Mon Sep 17 00:00:00 2001
From: zhangli20
Date: Mon, 30 Jun 2025 11:19:58 +0800
Subject: [PATCH] fix NPE when initializing non-deterministic UDF wrapper in
 driver side

---
 .../org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala | 5 ++++-
 .../org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala  | 5 ++++-
 .../org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala | 5 ++++-
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala
index 26c4a2a73..f49034b1a 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDAFWrapperContext.scala
@@ -79,7 +79,10 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging {
     // initialize all nondeterministic children exprs
     expr.foreach {
       case nondeterministic: Nondeterministic =>
-        nondeterministic.initialize(TaskContext.get.partitionId())
+        nondeterministic.initialize(TaskContext.get match {
+          case tc: TaskContext => tc.partitionId()
+          case null => 0
+        })
       case _ =>
     }
 
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala
index d078528fb..d4fe5186c 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala
@@ -46,7 +46,10 @@ case class SparkUDFWrapperContext(serialized: ByteBuffer) extends Logging {
     // initialize all nondeterministic children exprs
     expr.foreach {
      case nondeterministic: Nondeterministic =>
-        nondeterministic.initialize(TaskContext.get.partitionId())
+        nondeterministic.initialize(TaskContext.get match {
+          case tc: TaskContext => tc.partitionId()
+          case null => 0
+        })
       case _ =>
     }
 
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala
index 24ad65fc1..c9003313b 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDTFWrapperContext.scala
@@ -49,7 +49,10 @@ case class SparkUDTFWrapperContext(serialized: ByteBuffer) extends Logging {
     // initialize all nondeterministic children exprs
     expr.foreach {
       case nondeterministic: Nondeterministic =>
-        nondeterministic.initialize(TaskContext.get.partitionId())
+        nondeterministic.initialize(TaskContext.get match {
+          case tc: TaskContext => tc.partitionId()
+          case null => 0
+        })
       case _ =>
     }
 