diff --git a/pom.xml b/pom.xml
index f7e90886b..4af2d7955 100644
--- a/pom.xml
+++ b/pom.xml
@@ -7,6 +7,7 @@
pom
+ spark-version-annotation-macros
spark-extension
${shimPkg}
hadoop-shim
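
Note: the new `spark-version-annotation-macros` module provides the `org.blaze.sparkver` annotation used throughout the shims below, replacing the third-party `com.thoughtworks.enableIf` dependency that is removed from the shim pom. The module's implementation is not part of this diff; the following is only a rough sketch, assuming a Scala 2 macro annotation (macro paradise / `-Ymacro-annotations`) that reads the `blaze.shim` system property at compile time and keeps the annotated definition only when the active shim is listed. Names and details here are illustrative, not the repository's actual code.

package org.blaze

import scala.annotation.{compileTimeOnly, StaticAnnotation}
import scala.language.experimental.macros
import scala.reflect.macros.whitebox

// Sketch only: versions are written as "3.2 / 3.3 / 3.4 / 3.5"; the annotated
// definition is kept only when -Dblaze.shim matches one of them.
@compileTimeOnly("enable macro paradise (or -Ymacro-annotations) to expand @sparkver")
class sparkver(versions: String) extends StaticAnnotation {
  def macroTransform(annottees: Any*): Any = macro sparkver.impl
}

object sparkver {
  def impl(c: whitebox.Context)(annottees: c.Tree*): c.Tree = {
    import c.universe._

    // Pull the literal version list out of the annotation, e.g. "3.0 / 3.1".
    val enabledShims: Set[String] = c.prefix.tree match {
      case q"new sparkver(${Literal(Constant(s: String))})" =>
        s.split('/').map(v => s"spark-${v.trim}").toSet
      case _ =>
        c.abort(c.enclosingPosition, "@sparkver requires a literal version string")
    }

    // Same selection rule the old @enableIf calls spelled out explicitly.
    if (enabledShims.contains(System.getProperty("blaze.shim"))) {
      q"..$annottees" // keep the definition unchanged
    } else {
      q"()"           // elide it for the current shim
    }
  }
}

With something like this in place, the repeated `Seq("spark-3.2", ...).contains(System.getProperty("blaze.shim"))` conditions collapse to the bare version list, as the hunks below show.
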
diff --git a/spark-extension-shims-spark3/pom.xml b/spark-extension-shims-spark3/pom.xml
index f60cb9d41..4dc42fee3 100644
--- a/spark-extension-shims-spark3/pom.xml
+++ b/spark-extension-shims-spark3/pom.xml
@@ -90,12 +90,6 @@
1.12.10
-
- com.thoughtworks.enableIf
- enableif_${scalaVersion}
- 1.2.0
-
-
org.apache.spark
spark-core_${scalaVersion}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/InterceptedValidateSparkPlan.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/InterceptedValidateSparkPlan.scala
index 4d1c8ef69..2b7df301a 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/InterceptedValidateSparkPlan.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/InterceptedValidateSparkPlan.scala
@@ -17,14 +17,11 @@ package org.apache.spark.sql.blaze
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
object InterceptedValidateSparkPlan extends Logging {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
def validate(plan: SparkPlan): Unit = {
import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
import org.apache.spark.sql.execution.blaze.plan.NativeRenameColumnsBase
@@ -72,14 +69,12 @@ object InterceptedValidateSparkPlan extends Logging {
}
}
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
def validate(plan: SparkPlan): Unit = {
throw new UnsupportedOperationException("validate is not supported in spark 3.0.3 or 3.1.3")
}
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
private def errorOnInvalidBroadcastQueryStage(plan: SparkPlan): Unit = {
import org.apache.spark.sql.execution.adaptive.InvalidAQEPlanException
throw InvalidAQEPlanException("Invalid broadcast query stage", plan)
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
index e23310f86..fd96a0bcf 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
@@ -17,6 +17,7 @@ package org.apache.spark.sql.blaze
import java.io.File
import java.util.UUID
+
import org.apache.commons.lang3.reflect.FieldUtils
import org.apache.spark.OneToOneDependency
import org.apache.spark.ShuffleDependency
@@ -111,27 +112,25 @@ import org.apache.spark.sql.types.StringType
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.storage.FileSegment
import org.blaze.{protobuf => pb}
-import com.thoughtworks.enableIf
import org.apache.spark.sql.execution.blaze.shuffle.uniffle.BlazeUniffleShuffleManager
+import org.blaze.sparkver
class ShimsImpl extends Shims with Logging {
- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
override def shimVersion: String = "spark-3.0"
- @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.1")
override def shimVersion: String = "spark-3.1"
- @enableIf(Seq("spark-3.2").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.2")
override def shimVersion: String = "spark-3.2"
- @enableIf(Seq("spark-3.3").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.3")
override def shimVersion: String = "spark-3.3"
- @enableIf(Seq("spark-3.4").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.4")
override def shimVersion: String = "spark-3.4"
- @enableIf(Seq("spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.5")
override def shimVersion: String = "spark-3.5"
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def initExtension(): Unit = {
ValidateSparkPlanInjector.inject()
@@ -146,7 +145,7 @@ class ShimsImpl extends Shims with Logging {
}
}
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def initExtension(): Unit = {
if (BlazeConf.FORCE_SHUFFLED_HASH_JOIN.booleanConf()) {
logWarning(s"${BlazeConf.FORCE_SHUFFLED_HASH_JOIN.key} is not supported in $shimVersion")
@@ -374,9 +373,7 @@ class ShimsImpl extends Shims with Logging {
length: Long,
numRecords: Long): FileSegment = new FileSegment(file, offset, length)
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def commit(
dep: ShuffleDependency[_, _, _],
shuffleBlockResolver: IndexShuffleBlockResolver,
@@ -396,7 +393,7 @@ class ShimsImpl extends Shims with Logging {
MapStatus.apply(SparkEnv.get.blockManager.shuffleServerId, partitionLengths, mapId)
}
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def commit(
dep: ShuffleDependency[_, _, _],
shuffleBlockResolver: IndexShuffleBlockResolver,
@@ -533,23 +530,19 @@ class ShimsImpl extends Shims with Logging {
expr.asInstanceOf[AggregateExpression].filter
}
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
private def isAQEShuffleRead(exec: SparkPlan): Boolean = {
import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
exec.isInstanceOf[AQEShuffleReadExec]
}
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
private def isAQEShuffleRead(exec: SparkPlan): Boolean = {
import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec
exec.isInstanceOf[CustomShuffleReaderExec]
}
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
private def executeNativeAQEShuffleReader(exec: SparkPlan): NativeRDD = {
import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
import org.apache.spark.sql.execution.CoalescedMapperPartitionSpec
@@ -652,7 +645,7 @@ class ShimsImpl extends Shims with Logging {
}
}
- @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.1")
private def executeNativeAQEShuffleReader(exec: SparkPlan): NativeRDD = {
import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec
@@ -744,7 +737,7 @@ class ShimsImpl extends Shims with Logging {
}
}
- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
private def executeNativeAQEShuffleReader(exec: SparkPlan): NativeRDD = {
import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec
@@ -846,13 +839,11 @@ class ShimsImpl extends Shims with Logging {
}
}
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def getSqlContext(sparkPlan: SparkPlan): SQLContext =
sparkPlan.session.sqlContext
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def getSqlContext(sparkPlan: SparkPlan): SQLContext = sparkPlan.sqlContext
override def createNativeExprWrapper(
@@ -862,9 +853,7 @@ class ShimsImpl extends Shims with Logging {
NativeExprWrapper(nativeExpr, dataType, nullable)
}
- @enableIf(
- Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1 / 3.2 / 3.3")
override def getPartitionedFile(
partitionValues: InternalRow,
filePath: String,
@@ -872,7 +861,7 @@ class ShimsImpl extends Shims with Logging {
size: Long): PartitionedFile =
PartitionedFile(partitionValues, filePath, offset, size)
- @enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.4 / 3.5")
override def getPartitionedFile(
partitionValues: InternalRow,
filePath: String,
@@ -883,20 +872,16 @@ class ShimsImpl extends Shims with Logging {
PartitionedFile(partitionValues, SparkPath.fromPath(new Path(filePath)), offset, size)
}
- @enableIf(
- Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
override def getMinPartitionNum(sparkSession: SparkSession): Int =
sparkSession.sessionState.conf.filesMinPartitionNum
.getOrElse(sparkSession.sparkContext.defaultParallelism)
- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
override def getMinPartitionNum(sparkSession: SparkSession): Int =
sparkSession.sparkContext.defaultParallelism
- @enableIf(
- Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1 / 3.2 / 3.3")
private def convertPromotePrecision(
e: Expression,
isPruningExpr: Boolean,
@@ -909,13 +894,13 @@ class ShimsImpl extends Shims with Logging {
}
}
- @enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.4 / 3.5")
private def convertPromotePrecision(
e: Expression,
isPruningExpr: Boolean,
fallback: Expression => pb.PhysicalExprNode): Option[pb.PhysicalExprNode] = None
- @enableIf(Seq("spark-3.3", "spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.3 / 3.4 / 3.5")
private def convertBloomFilterAgg(agg: AggregateFunction): Option[pb.PhysicalAggExprNode] = {
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
agg match {
@@ -940,10 +925,10 @@ class ShimsImpl extends Shims with Logging {
}
}
- @enableIf(Seq("spark-3.0", "spark-3.1", "spark-3.2").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1 / 3.2")
private def convertBloomFilterAgg(agg: AggregateFunction): Option[pb.PhysicalAggExprNode] = None
- @enableIf(Seq("spark-3.3", "spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.3 / 3.4 / 3.5")
private def convertBloomFilterMightContain(
e: Expression,
isPruningExpr: Boolean,
@@ -966,7 +951,7 @@ class ShimsImpl extends Shims with Logging {
}
}
- @enableIf(Seq("spark-3.0", "spark-3.1", "spark-3.2").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1 / 3.2")
private def convertBloomFilterMightContain(
e: Expression,
isPruningExpr: Boolean,
@@ -977,13 +962,11 @@ class ShimsImpl extends Shims with Logging {
case class ForceNativeExecutionWrapper(override val child: SparkPlan)
extends ForceNativeExecutionWrapperBase(child) {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
@@ -994,8 +977,6 @@ case class NativeExprWrapper(
override val nullable: Boolean)
extends NativeExprWrapperBase(nativeExpr, dataType, nullable) {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy()
}
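
Because non-matching definitions are elided during macro expansion, ShimsImpl can carry one `shimVersion` (and one `initExtension`, `commit`, and so on) per shim range without the same-named members clashing. A stripped-down, hypothetical illustration of that pattern, not code from the repository:

import org.blaze.sparkver

// Hypothetical example: several same-named members, exactly one of which
// survives expansion for the shim selected at build time via -Dblaze.shim.
class VersionedGreeter {
  @sparkver("3.0 / 3.1")
  def greet(): String = "running on a pre-3.2 shim"

  @sparkver("3.2 / 3.3 / 3.4 / 3.5")
  def greet(): String = "running on a 3.2+ shim"
}
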
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeExec.scala
index 5de82883f..4edf5ba00 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeExec.scala
@@ -16,18 +16,15 @@
package org.apache.spark.sql.execution.blaze.plan
import org.apache.spark.sql.execution.SparkPlan
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class ConvertToNativeExec(override val child: SparkPlan) extends ConvertToNativeBase(child) {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeAggExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeAggExec.scala
index 8375a5fb0..51625490c 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeAggExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeAggExec.scala
@@ -26,8 +26,7 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.blaze.plan.NativeAggBase.AggExecMode
import org.apache.spark.sql.types.BinaryType
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativeAggExec(
execMode: AggExecMode,
@@ -47,13 +46,11 @@ case class NativeAggExec(
child)
with BaseAggregateExec {
- @enableIf(
- Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
override val requiredChildDistributionExpressions: Option[Seq[Expression]] =
theRequiredChildDistributionExpressions
- @enableIf(Seq("spark-3.3", "spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.3 / 3.4 / 3.5")
override val initialInputBufferOffset: Int = theInitialInputBufferOffset
override def output: Seq[Attribute] =
@@ -65,25 +62,19 @@ case class NativeAggExec(
ExprId.apply(NativeAggBase.AGG_BUF_COLUMN_EXPR_ID))
}
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def isStreaming: Boolean = false
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def numShufflePartitions: Option[Int] = None
override def resultExpressions: Seq[NamedExpression] = output
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeExec.scala
index 0a0213b80..4023a4999 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeExec.scala
@@ -22,8 +22,7 @@ import org.apache.spark.sql.blaze.NativeSupports
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.execution.SparkPlan
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativeBroadcastExchangeExec(mode: BroadcastMode, override val child: SparkPlan)
extends NativeBroadcastExchangeBase(mode, child)
@@ -42,13 +41,11 @@ case class NativeBroadcastExchangeExec(mode: BroadcastMode, override val child:
relationFuturePromise.future
}
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeExpandExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeExpandExec.scala
index c22fb6ebc..e2516b4c7 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeExpandExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeExpandExec.scala
@@ -18,8 +18,7 @@ package org.apache.spark.sql.execution.blaze.plan
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.SparkPlan
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativeExpandExec(
projections: Seq[Seq[Expression]],
@@ -27,13 +26,11 @@ case class NativeExpandExec(
override val child: SparkPlan)
extends NativeExpandBase(projections, output, child) {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeFilterExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeFilterExec.scala
index ea0a17aa3..197985646 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeFilterExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeFilterExec.scala
@@ -17,19 +17,16 @@ package org.apache.spark.sql.execution.blaze.plan
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.SparkPlan
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativeFilterExec(condition: Expression, override val child: SparkPlan)
extends NativeFilterBase(condition, child) {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateExec.scala
index c372fe6db..8a1a2a9e6 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateExec.scala
@@ -18,8 +18,7 @@ package org.apache.spark.sql.execution.blaze.plan
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.Generator
import org.apache.spark.sql.execution.SparkPlan
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativeGenerateExec(
generator: Generator,
@@ -29,13 +28,11 @@ case class NativeGenerateExec(
override val child: SparkPlan)
extends NativeGenerateBase(generator, requiredChildOutput, outer, generatorOutput, child) {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGlobalLimitExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGlobalLimitExec.scala
index d8f18dd89..19f38d203 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGlobalLimitExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGlobalLimitExec.scala
@@ -16,19 +16,16 @@
package org.apache.spark.sql.execution.blaze.plan
import org.apache.spark.sql.execution.SparkPlan
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativeGlobalLimitExec(limit: Long, override val child: SparkPlan)
extends NativeGlobalLimitBase(limit, child) {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeLocalLimitExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeLocalLimitExec.scala
index bbd8b48ee..96124db5a 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeLocalLimitExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeLocalLimitExec.scala
@@ -16,19 +16,16 @@
package org.apache.spark.sql.execution.blaze.plan
import org.apache.spark.sql.execution.SparkPlan
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativeLocalLimitExec(limit: Long, override val child: SparkPlan)
extends NativeLocalLimitBase(limit, child) {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetInsertIntoHiveTableExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetInsertIntoHiveTableExec.scala
index 8ac2ab161..ef31c2700 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetInsertIntoHiveTableExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetInsertIntoHiveTableExec.scala
@@ -23,17 +23,14 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativeParquetInsertIntoHiveTableExec(
cmd: InsertIntoHiveTable,
override val child: SparkPlan)
extends NativeParquetInsertIntoHiveTableBase(cmd, child) {
- @enableIf(
- Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1 / 3.2 / 3.3")
override protected def getInsertIntoHiveTableCommand(
table: CatalogTable,
partition: Map[String, Option[String]],
@@ -52,7 +49,7 @@ case class NativeParquetInsertIntoHiveTableExec(
metrics)
}
- @enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.4 / 3.5")
override protected def getInsertIntoHiveTableCommand(
table: CatalogTable,
partition: Map[String, Option[String]],
@@ -71,19 +68,15 @@ case class NativeParquetInsertIntoHiveTableExec(
metrics)
}
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
- @enableIf(
- Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1 / 3.2 / 3.3")
class BlazeInsertIntoHiveTable30(
table: CatalogTable,
partition: Map[String, Option[String]],
@@ -108,7 +101,7 @@ case class NativeParquetInsertIntoHiveTableExec(
super.run(sparkSession, nativeParquetSink)
}
- @enableIf(Seq("spark-3.2", "spark-3.3").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3")
override def basicWriteJobStatsTracker(hadoopConf: org.apache.hadoop.conf.Configuration) = {
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
@@ -134,7 +127,7 @@ case class NativeParquetInsertIntoHiveTableExec(
}
}
- @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.1")
override def basicWriteJobStatsTracker(hadoopConf: org.apache.hadoop.conf.Configuration) = {
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
@@ -172,7 +165,7 @@ case class NativeParquetInsertIntoHiveTableExec(
}
}
- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
override def basicWriteJobStatsTracker(hadoopConf: org.apache.hadoop.conf.Configuration) = {
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
@@ -202,7 +195,7 @@ case class NativeParquetInsertIntoHiveTableExec(
}
}
- @enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.4 / 3.5")
class BlazeInsertIntoHiveTable34(
table: CatalogTable,
partition: Map[String, Option[String]],
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkExec.scala
index e701921d5..54f404775 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkExec.scala
@@ -19,8 +19,7 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.execution.metric.SQLMetric
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativeParquetSinkExec(
sparkSession: SparkSession,
@@ -30,13 +29,11 @@ case class NativeParquetSinkExec(
override val metrics: Map[String, SQLMetric])
extends NativeParquetSinkBase(sparkSession, table, partition, child, metrics) {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativePartialTakeOrderedExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativePartialTakeOrderedExec.scala
index 91b8a8576..0a51a8307 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativePartialTakeOrderedExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativePartialTakeOrderedExec.scala
@@ -18,8 +18,7 @@ package org.apache.spark.sql.execution.blaze.plan
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetric
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativePartialTakeOrderedExec(
limit: Long,
@@ -28,13 +27,11 @@ case class NativePartialTakeOrderedExec(
override val metrics: Map[String, SQLMetric])
extends NativePartialTakeOrderedBase(limit, sortOrder, child, metrics) {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeProjectExecProvider.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeProjectExecProvider.scala
index f3b8dd49d..7fe895260 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeProjectExecProvider.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeProjectExecProvider.scala
@@ -17,11 +17,10 @@ package org.apache.spark.sql.execution.blaze.plan
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.execution.SparkPlan
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case object NativeProjectExecProvider {
- @enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.4 / 3.5")
def provide(
projectList: Seq[NamedExpression],
child: SparkPlan,
@@ -49,7 +48,7 @@ case object NativeProjectExecProvider {
NativeProjectExec(projectList, child, addTypeCast)
}
- @enableIf(Seq("spark-3.1", "spark-3.2", "spark-3.3").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.1 / 3.2 / 3.3")
def provide(
projectList: Seq[NamedExpression],
child: SparkPlan,
@@ -65,11 +64,11 @@ case object NativeProjectExecProvider {
with AliasAwareOutputPartitioning
with AliasAwareOutputOrdering {
- @enableIf(Seq("spark-3.2", "spark-3.3").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
@@ -82,7 +81,7 @@ case object NativeProjectExecProvider {
NativeProjectExec(projectList, child, addTypeCast)
}
- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
def provide(
projectList: Seq[NamedExpression],
child: SparkPlan,
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeRenameColumnsExecProvider.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeRenameColumnsExecProvider.scala
index b57f00726..e559ab07b 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeRenameColumnsExecProvider.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeRenameColumnsExecProvider.scala
@@ -16,11 +16,10 @@
package org.apache.spark.sql.execution.blaze.plan
import org.apache.spark.sql.execution.SparkPlan
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case object NativeRenameColumnsExecProvider {
- @enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.4 / 3.5")
def provide(child: SparkPlan, renamedColumnNames: Seq[String]): NativeRenameColumnsBase = {
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.expressions.SortOrder
@@ -44,7 +43,7 @@ case object NativeRenameColumnsExecProvider {
NativeRenameColumnsExec(child, renamedColumnNames)
}
- @enableIf(Seq("spark-3.1", "spark-3.2", "spark-3.3").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.1 / 3.2 / 3.3")
def provide(child: SparkPlan, renamedColumnNames: Seq[String]): NativeRenameColumnsBase = {
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.expressions.SortOrder
@@ -58,11 +57,11 @@ case object NativeRenameColumnsExecProvider {
with AliasAwareOutputPartitioning
with AliasAwareOutputOrdering {
- @enableIf(Seq("spark-3.2", "spark-3.3").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
@@ -73,7 +72,7 @@ case object NativeRenameColumnsExecProvider {
NativeRenameColumnsExec(child, renamedColumnNames)
}
- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
def provide(child: SparkPlan, renamedColumnNames: Seq[String]): NativeRenameColumnsBase = {
case class NativeRenameColumnsExec(
override val child: SparkPlan,
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
index a0d57ab10..d9750bd36 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala
@@ -25,7 +25,6 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.shuffle.ShuffleWriteProcessor
-import org.apache.spark.shuffle.ShuffleWriter
import org.apache.spark.sql.blaze.NativeHelper
import org.apache.spark.sql.blaze.NativeRDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -33,14 +32,12 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.blaze.shuffle.BlazeRssShuffleWriterBase
-import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleWriter
import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleWriterBase
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
import org.apache.spark.sql.execution.metric.SQLShuffleWriteMetricsReporter
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativeShuffleExchangeExec(
override val outputPartitioning: Partitioning,
@@ -165,30 +162,26 @@ case class NativeShuffleExchangeExec(
// for databricks testing
val causedBroadcastJoinBuildOOM = false
- @enableIf(Seq("spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.5")
override def advisoryPartitionSize: Option[Long] = None
// If users specify the num partitions via APIs like `repartition`, we shouldn't change it.
// For `SinglePartition`, it requires exactly one partition and we can't change it either.
- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
override def canChangeNumPartitions: Boolean =
outputPartitioning != SinglePartition
- @enableIf(
- Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
override def shuffleOrigin = {
import org.apache.spark.sql.execution.exchange.ShuffleOrigin;
_shuffleOrigin.get.asInstanceOf[ShuffleOrigin]
}
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortExec.scala
index 5eff2206d..d79c0304a 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortExec.scala
@@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.blaze.plan
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.execution.SparkPlan
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativeSortExec(
sortOrder: Seq[SortOrder],
@@ -26,13 +25,11 @@ case class NativeSortExec(
override val child: SparkPlan)
extends NativeSortBase(sortOrder, global, child) {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeTakeOrderedExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeTakeOrderedExec.scala
index d4cb0c9d5..2c98b82c0 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeTakeOrderedExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeTakeOrderedExec.scala
@@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.blaze.plan
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.execution.SparkPlan
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativeTakeOrderedExec(
limit: Long,
@@ -26,13 +25,11 @@ case class NativeTakeOrderedExec(
override val child: SparkPlan)
extends NativeTakeOrderedBase(limit, sortOrder, child) {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeUnionExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeUnionExec.scala
index 6a9ce256f..29fcc0600 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeUnionExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeUnionExec.scala
@@ -16,19 +16,16 @@
package org.apache.spark.sql.execution.blaze.plan
import org.apache.spark.sql.execution.SparkPlan
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativeUnionExec(override val children: Seq[SparkPlan])
extends NativeUnionBase(children) {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan =
copy(children = newChildren)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(children = newChildren)
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeWindowExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeWindowExec.scala
index c16223e9b..b3cbf8175 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeWindowExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeWindowExec.scala
@@ -19,8 +19,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.execution.SparkPlan
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativeWindowExec(
windowExpression: Seq[NamedExpression],
@@ -29,13 +28,11 @@ case class NativeWindowExec(
override val child: SparkPlan)
extends NativeWindowBase(windowExpression, partitionSpec, orderSpec, child) {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala
index e8dbf4c94..374217608 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala
@@ -18,12 +18,11 @@ package org.apache.spark.sql.execution.blaze.shuffle
import java.io.InputStream
import org.apache.spark.{MapOutputTracker, SparkEnv, TaskContext}
-import org.apache.spark.internal.{Logging, config}
+import org.apache.spark.internal.{config, Logging}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReadMetricsReporter}
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator}
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
class BlazeBlockStoreShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
@@ -37,9 +36,7 @@ class BlazeBlockStoreShuffleReader[K, C](
with Logging {
override def readBlocks(): Iterator[(BlockId, InputStream)] = {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
def fetchIterator = new ShuffleBlockFetcherIterator(
context,
blockManager.blockStoreClient,
@@ -60,7 +57,7 @@ class BlazeBlockStoreShuffleReader[K, C](
readMetrics,
fetchContinuousBlocksInBatch).toCompletionIterator
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
def fetchIterator = new ShuffleBlockFetcherIterator(
context,
blockManager.blockStoreClient,
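
The reader above applies the annotation to a local `def fetchIterator` inside `readBlocks`, choosing between the two `ShuffleBlockFetcherIterator` constructor shapes. Assuming the macro behaves as sketched earlier, the same trick works for any local helper whose signature differs across versions; a minimal, invented example:

import org.blaze.sparkver

object LocalHelperExample {
  def describeShim(): String = {
    // Only one of these local definitions survives expansion, so the
    // call below always resolves unambiguously. (Invented example.)
    @sparkver("3.2 / 3.3 / 3.4 / 3.5")
    def label: String = "uses the 3.2+ fetcher signature"

    @sparkver("3.0 / 3.1")
    def label: String = "uses the 3.0/3.1 fetcher signature"

    label
  }
}
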
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleManagerBase.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleManagerBase.scala
index 2dab56776..f4574aa34 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleManagerBase.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleManagerBase.scala
@@ -19,8 +19,7 @@ import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle._
import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleDependency.isArrowShuffle
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
abstract class BlazeRssShuffleManagerBase(_conf: SparkConf) extends ShuffleManager with Logging {
override def registerShuffle[K, V, C](
@@ -73,9 +72,7 @@ abstract class BlazeRssShuffleManagerBase(_conf: SparkConf) extends ShuffleManag
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V]
- @enableIf(
- Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
override def getReader[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
@@ -106,7 +103,7 @@ abstract class BlazeRssShuffleManagerBase(_conf: SparkConf) extends ShuffleManag
}
}
- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
override def getReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
@@ -121,7 +118,7 @@ abstract class BlazeRssShuffleManagerBase(_conf: SparkConf) extends ShuffleManag
}
}
- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
override def getReaderForRange[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala
index bec87beb8..5772f3ea4 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala
@@ -21,8 +21,7 @@ import org.apache.spark.shuffle._
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleDependency.isArrowShuffle
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
val sortShuffleManager = new SortShuffleManager(conf)
@@ -45,9 +44,7 @@ class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
sortShuffleManager.registerShuffle(shuffleId, dependency)
}
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def getReader[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
@@ -60,10 +57,9 @@ class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
if (isArrowShuffle(handle)) {
val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]]
- @enableIf(Seq("spark-3.2").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.2")
def shuffleMergeFinalized = baseShuffleHandle.dependency.shuffleMergeFinalized
- @enableIf(
- Seq("spark-3.3", "spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.3 / 3.4 / 3.5")
def shuffleMergeFinalized = baseShuffleHandle.dependency.isShuffleMergeFinalizedMarked
val (blocksByAddress, canEnableBatchFetch) =
@@ -106,7 +102,7 @@ class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
}
}
- @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.1")
override def getReader[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
@@ -143,7 +139,7 @@ class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
}
}
- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
override def getReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
@@ -170,7 +166,7 @@ class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
}
}
- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
override def getReaderForRange[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleWriter.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleWriter.scala
index ba796786e..f68867406 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleWriter.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleWriter.scala
@@ -16,14 +16,11 @@
package org.apache.spark.sql.execution.blaze.shuffle
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
class BlazeShuffleWriter[K, V](metrics: ShuffleWriteMetricsReporter)
extends BlazeShuffleWriterBase[K, V](metrics) {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def getPartitionLengths(): Array[Long] = partitionLengths
}
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleWriter.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleWriter.scala
index 4f311d0f5..16a34051c 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleWriter.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleWriter.scala
@@ -25,8 +25,7 @@ import org.apache.spark.shuffle.celeborn.ExecutorShuffleIdTracker
import org.apache.spark.shuffle.celeborn.SparkUtils
import org.apache.spark.sql.execution.blaze.shuffle.BlazeRssShuffleWriterBase
import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
class BlazeCelebornShuffleWriter[K, V](
celebornShuffleWriter: ShuffleWriter[K, V],
@@ -57,9 +56,7 @@ class BlazeCelebornShuffleWriter[K, V](
metrics)
}
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def getPartitionLengths(): Array[Long] = celebornShuffleWriter.getPartitionLengths()
override def rssStop(success: Boolean): Unit = {
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala
index 5e51ae068..8cb7cd64e 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala
@@ -24,8 +24,7 @@ import org.apache.spark.sql.execution.blaze.plan.BroadcastRight
import org.apache.spark.sql.execution.blaze.plan.BroadcastSide
import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastJoinBase
import org.apache.spark.sql.execution.joins.HashJoin
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case class NativeBroadcastJoinExec(
override val left: SparkPlan,
@@ -47,24 +46,20 @@ case class NativeBroadcastJoinExec(
override val condition: Option[Expression] = None
- @enableIf(
- Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
override def buildSide: org.apache.spark.sql.catalyst.optimizer.BuildSide =
broadcastSide match {
case BroadcastLeft => org.apache.spark.sql.catalyst.optimizer.BuildLeft
case BroadcastRight => org.apache.spark.sql.catalyst.optimizer.BuildRight
}
- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
override val buildSide: org.apache.spark.sql.execution.joins.BuildSide = broadcastSide match {
case BroadcastLeft => org.apache.spark.sql.execution.joins.BuildLeft
case BroadcastRight => org.apache.spark.sql.execution.joins.BuildRight
}
- @enableIf(
- Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
override def requiredChildDistribution = {
import org.apache.spark.sql.catalyst.plans.physical.BroadcastDistribution
import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
@@ -79,36 +74,28 @@ case class NativeBroadcastJoinExec(
}
}
- @enableIf(
- Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
override def supportCodegen: Boolean = false
- @enableIf(
- Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
override def inputRDDs() = {
throw new NotImplementedError("NativeBroadcastJoin does not support codegen")
}
- @enableIf(
- Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
override protected def prepareRelation(
ctx: org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext)
: org.apache.spark.sql.execution.joins.HashedRelationInfo = {
throw new NotImplementedError("NativeBroadcastJoin does not support codegen")
}
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildrenInternal(
newLeft: SparkPlan,
newRight: SparkPlan): SparkPlan =
copy(left = newLeft, right = newRight)
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(left = newChildren(0), right = newChildren(1))
}
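Worth noting in the hunk above: the class declares `buildSide` twice, once as the Spark 3.0 `val` returning `joins.BuildSide` and once as the 3.1+ `def` returning `optimizer.BuildSide`. This only compiles because macro paradise expands `@sparkver` before type checking, leaving exactly one of the two definitions (the other becomes an empty tree). A minimal sketch of that pattern, assuming the paradise plugin is on the compiler classpath as in the shim modules:

```scala
import org.blaze.sparkver

class VersionedMember {
  // Exactly one of these survives expansion, so the duplicate
  // name never reaches the typechecker.
  @sparkver("3.0")
  val side: String = "legacy API shape"

  @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
  def side: String = "current API shape"
}
```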
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeShuffledHashJoinExecProvider.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeShuffledHashJoinExecProvider.scala
index 47c89a8f4..d359622ef 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeShuffledHashJoinExecProvider.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeShuffledHashJoinExecProvider.scala
@@ -20,14 +20,11 @@ import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.blaze.plan.BuildSide
import org.apache.spark.sql.execution.blaze.plan.NativeShuffledHashJoinBase
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case object NativeShuffledHashJoinExecProvider {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
def provide(
left: SparkPlan,
right: SparkPlan,
@@ -74,7 +71,7 @@ case object NativeShuffledHashJoinExecProvider {
NativeShuffledHashJoinExec(left, right, leftKeys, rightKeys, joinType, buildSide)
}
- @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.1")
def provide(
left: SparkPlan,
right: SparkPlan,
@@ -118,7 +115,7 @@ case object NativeShuffledHashJoinExecProvider {
NativeShuffledHashJoinExec(left, right, leftKeys, rightKeys, joinType, buildSide)
}
- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
def provide(
left: SparkPlan,
right: SparkPlan,
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeSortMergeJoinExecProvider.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeSortMergeJoinExecProvider.scala
index d589fcacb..df3796530 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeSortMergeJoinExecProvider.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeSortMergeJoinExecProvider.scala
@@ -19,14 +19,11 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.blaze.plan.NativeSortMergeJoinBase
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
case object NativeSortMergeJoinExecProvider {
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
def provide(
left: SparkPlan,
right: SparkPlan,
@@ -71,7 +68,7 @@ case object NativeSortMergeJoinExecProvider {
NativeSortMergeJoinExec(left, right, leftKeys, rightKeys, joinType)
}
- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
def provide(
left: SparkPlan,
right: SparkPlan,
diff --git a/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/blaze/BlazeAdaptiveQueryExecSuite.scala b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/blaze/BlazeAdaptiveQueryExecSuite.scala
index d2ac0bf1b..607fffb56 100644
--- a/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/blaze/BlazeAdaptiveQueryExecSuite.scala
+++ b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/blaze/BlazeAdaptiveQueryExecSuite.scala
@@ -15,9 +15,9 @@
*/
package org.apache.spark.sql.blaze
-import com.thoughtworks.enableMembersIf
+import org.blaze.sparkverEnableMembers
-@enableMembersIf(Seq("spark-3.5").contains(System.getProperty("blaze.shim")))
+@sparkverEnableMembers("3.5")
class BlazeAdaptiveQueryExecSuite
extends org.apache.spark.sql.QueryTest
with BaseBlazeSQLSuite
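`@sparkverEnableMembers` (the replacement for `enableMembersIf`) gates a whole class body rather than a single member: when the version does not match, the macro rewrites the class template to an empty body, so the suite class still exists but none of its Spark 3.5-only code is typechecked. A sketch with a hypothetical suite name:

```scala
import org.blaze.sparkverEnableMembers

// On any shim other than spark-3.5 the body is emptied at expansion time,
// so members referencing 3.5-only APIs are never compiled.
@sparkverEnableMembers("3.5")
class ExampleVersionedSuite {
  def runsOnlyOn35(): Unit = println("blaze.shim=spark-3.5")
}
```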
diff --git a/spark-extension/pom.xml b/spark-extension/pom.xml
index 26e46c903..2cee03c2e 100644
--- a/spark-extension/pom.xml
+++ b/spark-extension/pom.xml
@@ -13,6 +13,12 @@
     <packaging>jar</packaging>

     <dependencies>
+        <dependency>
+            <groupId>org.blaze</groupId>
+            <artifactId>spark-version-annotation-macros</artifactId>
+            <version>${revision}</version>
+            <scope>compile</scope>
+        </dependency>
         <dependency>
             <groupId>org.blaze</groupId>
             <artifactId>proto</artifactId>
@@ -69,11 +75,6 @@
             <artifactId>scalatest_${scalaVersion}</artifactId>
             <scope>test</scope>
         </dependency>
-        <dependency>
-            <groupId>com.thoughtworks.enableIf</groupId>
-            <artifactId>enableif_${scalaVersion}</artifactId>
-            <version>1.2.0</version>
-        </dependency>
     </dependencies>
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
index 4df634542..7d62694fb 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
@@ -15,8 +15,6 @@
*/
package org.apache.spark.sql.blaze
-import com.thoughtworks.enableIf
-
import scala.annotation.tailrec
import scala.collection.mutable
import org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat
@@ -81,6 +79,7 @@ import org.apache.spark.sql.hive.blaze.BlazeHiveConverters
import org.apache.spark.sql.hive.execution.InsertIntoHiveTable
import org.apache.spark.sql.hive.execution.blaze.plan.NativeHiveTableScanBase
import org.apache.spark.sql.types.LongType
+import org.blaze.sparkver
object BlazeConverters extends Logging {
val enableScan: Boolean =
@@ -296,12 +295,10 @@ object BlazeConverters extends Logging {
getShuffleOrigin(exec))
}
- @enableIf(
- Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
def getShuffleOrigin(exec: ShuffleExchangeExec): Option[Any] = Some(exec.shuffleOrigin)
- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
def getShuffleOrigin(exec: ShuffleExchangeExec): Option[Any] = None
def convertFileSourceScanExec(exec: FileSourceScanExec): SparkPlan = {
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/SpillBuf.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/SpillBuf.scala
index e016df0b3..e3aa3262e 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/SpillBuf.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/SpillBuf.scala
@@ -145,7 +145,7 @@ class ReleasedSpillBuf(releasing: SpillBuf) extends SpillBuf {
override val size: Long = releasing.size
releasing.release()
-
+
override def write(buf: ByteBuffer): Unit =
throw new UnsupportedOperationException()
@@ -153,4 +153,4 @@ class ReleasedSpillBuf(releasing: SpillBuf) extends SpillBuf {
throw new UnsupportedOperationException()
override def release(): Unit = {}
-}
\ No newline at end of file
+}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleWriterBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleWriterBase.scala
index 374d3e749..236e5399a 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleWriterBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleWriterBase.scala
@@ -23,13 +23,11 @@ import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteMetricsReporter}
import org.apache.spark.shuffle.ShuffleWriter
import org.apache.spark.sql.blaze.{JniBridge, NativeHelper, NativeRDD, Shims}
import org.blaze.protobuf.{PhysicalPlanNode, RssShuffleWriterExecNode}
-
-import com.thoughtworks.enableIf
+import org.blaze.sparkver
abstract class BlazeRssShuffleWriterBase[K, V](metrics: ShuffleWriteMetricsReporter)
extends ShuffleWriter[K, V] {
- private var mapStatus: Option[MapStatus] = None
private var rpw: RssPartitionWriterBase = _
private var mapId: Int = 0
@@ -70,19 +68,11 @@ abstract class BlazeRssShuffleWriterBase[K, V](metrics: ShuffleWriteMetricsRepor
} finally {
rpw.close()
}
-
- mapStatus = Some(
- Shims.get.getMapStatus(
- SparkEnv.get.blockManager.shuffleServerId,
- rpw.getPartitionLengthMap,
- mapId))
}
def rssStop(success: Boolean): Unit = {}
- @enableIf(
- Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
- System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def getPartitionLengths(): Array[Long] = rpw.getPartitionLengthMap
override def write(records: Iterator[Product2[K, V]]): Unit = {
diff --git a/spark-version-annotation-macros/pom.xml b/spark-version-annotation-macros/pom.xml
new file mode 100644
index 000000000..f45d25ecf
--- /dev/null
+++ b/spark-version-annotation-macros/pom.xml
@@ -0,0 +1,27 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <parent>
+        <groupId>org.blaze</groupId>
+        <artifactId>blaze-engine</artifactId>
+        <version>${revision}</version>
+        <relativePath>../</relativePath>
+    </parent>
+    <groupId>org.blaze</groupId>
+    <artifactId>spark-version-annotation-macros</artifactId>
+    <packaging>jar</packaging>
+
+    <dependencies>
+        <dependency>
+            <groupId>org.scala-lang</groupId>
+            <artifactId>scala-library</artifactId>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.scalamacros</groupId>
+            <artifactId>paradise_${scalaLongVersion}</artifactId>
+            <version>2.1.1</version>
+        </dependency>
+    </dependencies>
+</project>
diff --git a/spark-version-annotation-macros/src/main/scala/org/blaze/sparkver.scala b/spark-version-annotation-macros/src/main/scala/org/blaze/sparkver.scala
new file mode 100644
index 000000000..537d1fb2b
--- /dev/null
+++ b/spark-version-annotation-macros/src/main/scala/org/blaze/sparkver.scala
@@ -0,0 +1,110 @@
+/*
+ * Copyright 2022 The Blaze Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.blaze
+
+import scala.annotation.StaticAnnotation
+import scala.annotation.compileTimeOnly
+import scala.language.experimental._
+import scala.reflect.macros.whitebox
+
+object sparkver {
+ def matchVersion(vers: String): Boolean = {
+ val configuredVer = System.getProperty("blaze.shim")
+ for (ver <- vers.split("/")) {
+ val verStripped = ver.trim
+ if (s"spark-$verStripped" == configuredVer) {
+ return true
+ }
+ }
+ false
+ }
+
+ object Macros {
+ def impl(c: whitebox.Context)(annottees: c.Expr[Any]*)(
+ disabled: => c.Expr[Any]): c.Expr[Any] = {
+ import c.universe._
+
+ val versions = c.macroApplication match {
+ case Apply(Select(Apply(_, List(vs)), _), _) => c.eval(c.Expr[String](q"$vs"))
+ }
+
+ if (matchVersion(versions)) {
+ return c.Expr[Any](q"..$annottees")
+ }
+ disabled
+ }
+
+ def verEnable(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
+ import c.universe._
+
+ impl(c)(annottees: _*) {
+ c.Expr(EmptyTree)
+ }
+ }
+
+ def verEnableMembers(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
+ import c.universe._
+
+ impl(c)(annottees: _*) {
+ val head = annottees.head.tree match {
+ case ClassDef(mods, name, tparams, Template(parents, self, _body)) =>
+ ClassDef(mods, name, tparams, Template(parents, self, List(EmptyTree)))
+ case ModuleDef(mods, name, Template(parents, self, _body)) =>
+ ModuleDef(mods, name, Template(parents, self, List(EmptyTree)))
+ }
+ c.Expr(q"$head; ..${annottees.tail}")
+ }
+ }
+
+ def verEnableOverride(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
+ import c.universe._
+ import scala.reflect.internal.Flags
+ impl(c)(annottees: _*) {
+ val head = annottees.head.tree match {
+ case DefDef(mods, name, tparams, vparams, tpt, rhs) =>
+ val newMods = Modifiers(
+ (mods.flags.asInstanceOf[Long] & ~Flags.OVERRIDE).asInstanceOf[FlagSet],
+ mods.privateWithin,
+ mods.annotations)
+ DefDef(newMods, name, tparams, vparams, tpt, rhs)
+
+ case ValDef(mods, name, tpt, rhs) =>
+ val newMods = Modifiers(
+ (mods.flags.asInstanceOf[Long] & ~Flags.OVERRIDE).asInstanceOf[FlagSet],
+ mods.privateWithin,
+ mods.annotations)
+ ValDef(newMods, name, tpt, rhs)
+ }
+ c.Expr(q"$head; ..${annottees.tail}")
+ }
+ }
+ }
+}
+
+@compileTimeOnly("enable macro paradise to expand macro annotations")
+final class sparkver(vers: String) extends StaticAnnotation {
+ def macroTransform(annottees: Any*): Any = macro sparkver.Macros.verEnable
+}
+
+@compileTimeOnly("enable macro paradise to expand macro annotations")
+final class sparkverEnableMembers(vers: String) extends StaticAnnotation {
+ def macroTransform(annottees: Any*): Any = macro sparkver.Macros.verEnableMembers
+}
+
+@compileTimeOnly("enable macro paradise to expand macro annotations")
+final class sparkverEnableOverride(vers: String) extends StaticAnnotation {
+ def macroTransform(annottees: Any*): Any = macro sparkver.Macros.verEnableOverride
+}
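For reference, `matchVersion` splits its argument on `/`, trims each token, and compares `s"spark-$ver"` against the `blaze.shim` system property, so it can be exercised directly outside the macro, e.g.:

```scala
object MatchVersionDemo extends App {
  System.setProperty("blaze.shim", "spark-3.3")
  println(org.blaze.sparkver.matchVersion("3.2 / 3.3 / 3.4 / 3.5")) // true
  println(org.blaze.sparkver.matchVersion("3.0 / 3.1"))             // false
}
```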