From afc066e52c16b64ffb4f3dcf920a04028f3b6eed Mon Sep 17 00:00:00 2001 From: zhangli20 Date: Tue, 25 Mar 2025 00:39:23 +0800 Subject: [PATCH] introduce spark version control with spark-version-annotation-macros, replacing enableIf --- pom.xml | 1 + spark-extension-shims-spark3/pom.xml | 6 - .../blaze/InterceptedValidateSparkPlan.scala | 13 +-- .../apache/spark/sql/blaze/ShimsImpl.scala | 83 +++++-------- .../blaze/plan/ConvertToNativeExec.scala | 9 +- .../execution/blaze/plan/NativeAggExec.scala | 23 ++-- .../plan/NativeBroadcastExchangeExec.scala | 9 +- .../blaze/plan/NativeExpandExec.scala | 9 +- .../blaze/plan/NativeFilterExec.scala | 9 +- .../blaze/plan/NativeGenerateExec.scala | 9 +- .../blaze/plan/NativeGlobalLimitExec.scala | 9 +- .../blaze/plan/NativeLocalLimitExec.scala | 9 +- ...NativeParquetInsertIntoHiveTableExec.scala | 27 ++--- .../blaze/plan/NativeParquetSinkExec.scala | 9 +- .../plan/NativePartialTakeOrderedExec.scala | 9 +- .../plan/NativeProjectExecProvider.scala | 13 +-- .../NativeRenameColumnsExecProvider.scala | 13 +-- .../plan/NativeShuffleExchangeExec.scala | 19 +-- .../execution/blaze/plan/NativeSortExec.scala | 9 +- .../blaze/plan/NativeTakeOrderedExec.scala | 9 +- .../blaze/plan/NativeUnionExec.scala | 9 +- .../blaze/plan/NativeWindowExec.scala | 9 +- .../BlazeBlockStoreShuffleReader.scala | 11 +- .../shuffle/BlazeRssShuffleManagerBase.scala | 11 +- .../blaze/shuffle/BlazeShuffleManager.scala | 18 ++- .../blaze/shuffle/BlazeShuffleWriter.scala | 7 +- .../celeborn/BlazeCelebornShuffleWriter.scala | 7 +- .../blaze/plan/NativeBroadcastJoinExec.scala | 31 ++--- .../NativeShuffledHashJoinExecProvider.scala | 11 +- .../NativeSortMergeJoinExecProvider.scala | 9 +- .../blaze/BlazeAdaptiveQueryExecSuite.scala | 4 +- spark-extension/pom.xml | 11 +- .../spark/sql/blaze/BlazeConverters.scala | 9 +- .../spark/sql/blaze/memory/SpillBuf.scala | 4 +- .../shuffle/BlazeRssShuffleWriterBase.scala | 14 +-- spark-version-annotation-macros/pom.xml | 27 +++++ 
.../src/main/scala/org/blaze/sparkver.scala | 110 ++++++++++++++++++ 37 files changed, 298 insertions(+), 301 deletions(-) create mode 100644 spark-version-annotation-macros/pom.xml create mode 100644 spark-version-annotation-macros/src/main/scala/org/blaze/sparkver.scala diff --git a/pom.xml b/pom.xml index f7e90886b..4af2d7955 100644 --- a/pom.xml +++ b/pom.xml @@ -7,6 +7,7 @@ pom + spark-version-annotation-macros spark-extension ${shimPkg} hadoop-shim diff --git a/spark-extension-shims-spark3/pom.xml b/spark-extension-shims-spark3/pom.xml index f60cb9d41..4dc42fee3 100644 --- a/spark-extension-shims-spark3/pom.xml +++ b/spark-extension-shims-spark3/pom.xml @@ -90,12 +90,6 @@ 1.12.10 - - com.thoughtworks.enableIf - enableif_${scalaVersion} - 1.2.0 - - org.apache.spark spark-core_${scalaVersion} diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/InterceptedValidateSparkPlan.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/InterceptedValidateSparkPlan.scala index 4d1c8ef69..2b7df301a 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/InterceptedValidateSparkPlan.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/InterceptedValidateSparkPlan.scala @@ -17,14 +17,11 @@ package org.apache.spark.sql.blaze import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.SparkPlan - -import com.thoughtworks.enableIf +import org.blaze.sparkver object InterceptedValidateSparkPlan extends Logging { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") def validate(plan: SparkPlan): Unit = { import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec import org.apache.spark.sql.execution.blaze.plan.NativeRenameColumnsBase @@ -72,14 +69,12 @@ object InterceptedValidateSparkPlan extends Logging { } } - @enableIf(Seq("spark-3.0", 
"spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") def validate(plan: SparkPlan): Unit = { throw new UnsupportedOperationException("validate is not supported in spark 3.0.3 or 3.1.3") } - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") private def errorOnInvalidBroadcastQueryStage(plan: SparkPlan): Unit = { import org.apache.spark.sql.execution.adaptive.InvalidAQEPlanException throw InvalidAQEPlanException("Invalid broadcast query stage", plan) diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala index e23310f86..fd96a0bcf 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.blaze import java.io.File import java.util.UUID + import org.apache.commons.lang3.reflect.FieldUtils import org.apache.spark.OneToOneDependency import org.apache.spark.ShuffleDependency @@ -111,27 +112,25 @@ import org.apache.spark.sql.types.StringType import org.apache.spark.storage.BlockManagerId import org.apache.spark.storage.FileSegment import org.blaze.{protobuf => pb} -import com.thoughtworks.enableIf import org.apache.spark.sql.execution.blaze.shuffle.uniffle.BlazeUniffleShuffleManager +import org.blaze.sparkver class ShimsImpl extends Shims with Logging { - @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0") override def shimVersion: String = "spark-3.0" - @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.1") override def shimVersion: String = "spark-3.1" - @enableIf(Seq("spark-3.2").contains(System.getProperty("blaze.shim"))) + @sparkver("3.2") override def 
shimVersion: String = "spark-3.2" - @enableIf(Seq("spark-3.3").contains(System.getProperty("blaze.shim"))) + @sparkver("3.3") override def shimVersion: String = "spark-3.3" - @enableIf(Seq("spark-3.4").contains(System.getProperty("blaze.shim"))) + @sparkver("3.4") override def shimVersion: String = "spark-3.4" - @enableIf(Seq("spark-3.5").contains(System.getProperty("blaze.shim"))) + @sparkver("3.5") override def shimVersion: String = "spark-3.5" - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override def initExtension(): Unit = { ValidateSparkPlanInjector.inject() @@ -146,7 +145,7 @@ class ShimsImpl extends Shims with Logging { } } - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def initExtension(): Unit = { if (BlazeConf.FORCE_SHUFFLED_HASH_JOIN.booleanConf()) { logWarning(s"${BlazeConf.FORCE_SHUFFLED_HASH_JOIN.key} is not supported in $shimVersion") @@ -374,9 +373,7 @@ class ShimsImpl extends Shims with Logging { length: Long, numRecords: Long): FileSegment = new FileSegment(file, offset, length) - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override def commit( dep: ShuffleDependency[_, _, _], shuffleBlockResolver: IndexShuffleBlockResolver, @@ -396,7 +393,7 @@ class ShimsImpl extends Shims with Logging { MapStatus.apply(SparkEnv.get.blockManager.shuffleServerId, partitionLengths, mapId) } - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def commit( dep: ShuffleDependency[_, _, _], shuffleBlockResolver: IndexShuffleBlockResolver, @@ -533,23 +530,19 @@ class ShimsImpl extends Shims with Logging { expr.asInstanceOf[AggregateExpression].filter } - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", 
"spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") private def isAQEShuffleRead(exec: SparkPlan): Boolean = { import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec exec.isInstanceOf[AQEShuffleReadExec] } - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") private def isAQEShuffleRead(exec: SparkPlan): Boolean = { import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec exec.isInstanceOf[CustomShuffleReaderExec] } - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") private def executeNativeAQEShuffleReader(exec: SparkPlan): NativeRDD = { import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec import org.apache.spark.sql.execution.CoalescedMapperPartitionSpec @@ -652,7 +645,7 @@ class ShimsImpl extends Shims with Logging { } } - @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.1") private def executeNativeAQEShuffleReader(exec: SparkPlan): NativeRDD = { import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec @@ -744,7 +737,7 @@ class ShimsImpl extends Shims with Logging { } } - @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0") private def executeNativeAQEShuffleReader(exec: SparkPlan): NativeRDD = { import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec @@ -846,13 +839,11 @@ class ShimsImpl extends Shims with Logging { } } - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override def getSqlContext(sparkPlan: SparkPlan): SQLContext = sparkPlan.session.sqlContext - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def getSqlContext(sparkPlan: 
SparkPlan): SQLContext = sparkPlan.sqlContext override def createNativeExprWrapper( @@ -862,9 +853,7 @@ class ShimsImpl extends Shims with Logging { NativeExprWrapper(nativeExpr, dataType, nullable) } - @enableIf( - Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1 / 3.2 / 3.3") override def getPartitionedFile( partitionValues: InternalRow, filePath: String, @@ -872,7 +861,7 @@ class ShimsImpl extends Shims with Logging { size: Long): PartitionedFile = PartitionedFile(partitionValues, filePath, offset, size) - @enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim"))) + @sparkver("3.4 / 3.5") override def getPartitionedFile( partitionValues: InternalRow, filePath: String, @@ -883,20 +872,16 @@ class ShimsImpl extends Shims with Logging { PartitionedFile(partitionValues, SparkPath.fromPath(new Path(filePath)), offset, size) } - @enableIf( - Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5") override def getMinPartitionNum(sparkSession: SparkSession): Int = sparkSession.sessionState.conf.filesMinPartitionNum .getOrElse(sparkSession.sparkContext.defaultParallelism) - @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0") override def getMinPartitionNum(sparkSession: SparkSession): Int = sparkSession.sparkContext.defaultParallelism - @enableIf( - Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1 / 3.2 / 3.3") private def convertPromotePrecision( e: Expression, isPruningExpr: Boolean, @@ -909,13 +894,13 @@ class ShimsImpl extends Shims with Logging { } } - @enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim"))) + @sparkver("3.4 / 3.5") private def convertPromotePrecision( e: Expression, isPruningExpr: Boolean, fallback: 
Expression => pb.PhysicalExprNode): Option[pb.PhysicalExprNode] = None - @enableIf(Seq("spark-3.3", "spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim"))) + @sparkver("3.3 / 3.4 / 3.5") private def convertBloomFilterAgg(agg: AggregateFunction): Option[pb.PhysicalAggExprNode] = { import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate agg match { @@ -940,10 +925,10 @@ class ShimsImpl extends Shims with Logging { } } - @enableIf(Seq("spark-3.0", "spark-3.1", "spark-3.2").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1 / 3.2") private def convertBloomFilterAgg(agg: AggregateFunction): Option[pb.PhysicalAggExprNode] = None - @enableIf(Seq("spark-3.3", "spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim"))) + @sparkver("3.3 / 3.4 / 3.5") private def convertBloomFilterMightContain( e: Expression, isPruningExpr: Boolean, @@ -966,7 +951,7 @@ class ShimsImpl extends Shims with Logging { } } - @enableIf(Seq("spark-3.0", "spark-3.1", "spark-3.2").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1 / 3.2") private def convertBloomFilterMightContain( e: Expression, isPruningExpr: Boolean, @@ -977,13 +962,11 @@ class ShimsImpl extends Shims with Logging { case class ForceNativeExecutionWrapper(override val child: SparkPlan) extends ForceNativeExecutionWrapperBase(child) { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) } @@ -994,8 +977,6 @@ case class NativeExprWrapper( override val nullable: Boolean) extends NativeExprWrapperBase(nativeExpr, dataType, nullable) { - @enableIf( - 
Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy() } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeExec.scala index 5de82883f..4edf5ba00 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeExec.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeExec.scala @@ -16,18 +16,15 @@ package org.apache.spark.sql.execution.blaze.plan import org.apache.spark.sql.execution.SparkPlan - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class ConvertToNativeExec(override val child: SparkPlan) extends ConvertToNativeBase(child) { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeAggExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeAggExec.scala index 8375a5fb0..51625490c 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeAggExec.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeAggExec.scala @@ -26,8 +26,7 @@ 
import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.blaze.plan.NativeAggBase.AggExecMode import org.apache.spark.sql.types.BinaryType - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class NativeAggExec( execMode: AggExecMode, @@ -47,13 +46,11 @@ case class NativeAggExec( child) with BaseAggregateExec { - @enableIf( - Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5") override val requiredChildDistributionExpressions: Option[Seq[Expression]] = theRequiredChildDistributionExpressions - @enableIf(Seq("spark-3.3", "spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim"))) + @sparkver("3.3 / 3.4 / 3.5") override val initialInputBufferOffset: Int = theInitialInputBufferOffset override def output: Seq[Attribute] = @@ -65,25 +62,19 @@ case class NativeAggExec( ExprId.apply(NativeAggBase.AGG_BUF_COLUMN_EXPR_ID)) } - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override def isStreaming: Boolean = false - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override def numShufflePartitions: Option[Int] = None override def resultExpressions: Seq[NamedExpression] = output - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) } diff 
--git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeExec.scala index 0a0213b80..4023a4999 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeExec.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeExec.scala @@ -22,8 +22,7 @@ import org.apache.spark.sql.blaze.NativeSupports import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.execution.SparkPlan - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class NativeBroadcastExchangeExec(mode: BroadcastMode, override val child: SparkPlan) extends NativeBroadcastExchangeBase(mode, child) @@ -42,13 +41,11 @@ case class NativeBroadcastExchangeExec(mode: BroadcastMode, override val child: relationFuturePromise.future } - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeExpandExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeExpandExec.scala index c22fb6ebc..e2516b4c7 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeExpandExec.scala +++ 
b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeExpandExec.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution.blaze.plan import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.SparkPlan - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class NativeExpandExec( projections: Seq[Seq[Expression]], @@ -27,13 +26,11 @@ case class NativeExpandExec( override val child: SparkPlan) extends NativeExpandBase(projections, output, child) { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeFilterExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeFilterExec.scala index ea0a17aa3..197985646 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeFilterExec.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeFilterExec.scala @@ -17,19 +17,16 @@ package org.apache.spark.sql.execution.blaze.plan import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.SparkPlan - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class NativeFilterExec(condition: Expression, override val child: SparkPlan) extends NativeFilterBase(condition, child) { - @enableIf( - Seq("spark-3.2", "spark-3.3", 
"spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateExec.scala index c372fe6db..8a1a2a9e6 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateExec.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateExec.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution.blaze.plan import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.Generator import org.apache.spark.sql.execution.SparkPlan - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class NativeGenerateExec( generator: Generator, @@ -29,13 +28,11 @@ case class NativeGenerateExec( override val child: SparkPlan) extends NativeGenerateBase(generator, requiredChildOutput, outer, generatorOutput, child) { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) } diff --git 
a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGlobalLimitExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGlobalLimitExec.scala index d8f18dd89..19f38d203 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGlobalLimitExec.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGlobalLimitExec.scala @@ -16,19 +16,16 @@ package org.apache.spark.sql.execution.blaze.plan import org.apache.spark.sql.execution.SparkPlan - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class NativeGlobalLimitExec(limit: Long, override val child: SparkPlan) extends NativeGlobalLimitBase(limit, child) { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeLocalLimitExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeLocalLimitExec.scala index bbd8b48ee..96124db5a 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeLocalLimitExec.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeLocalLimitExec.scala @@ -16,19 +16,16 @@ package org.apache.spark.sql.execution.blaze.plan import org.apache.spark.sql.execution.SparkPlan - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class 
NativeLocalLimitExec(limit: Long, override val child: SparkPlan) extends NativeLocalLimitBase(limit, child) { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetInsertIntoHiveTableExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetInsertIntoHiveTableExec.scala index 8ac2ab161..ef31c2700 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetInsertIntoHiveTableExec.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetInsertIntoHiveTableExec.scala @@ -23,17 +23,14 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class NativeParquetInsertIntoHiveTableExec( cmd: InsertIntoHiveTable, override val child: SparkPlan) extends NativeParquetInsertIntoHiveTableBase(cmd, child) { - @enableIf( - Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1 / 3.2 / 3.3") override protected def getInsertIntoHiveTableCommand( table: CatalogTable, partition: Map[String, Option[String]], @@ -52,7 +49,7 @@ case class NativeParquetInsertIntoHiveTableExec( metrics) } - @enableIf(Seq("spark-3.4", 
"spark-3.5").contains(System.getProperty("blaze.shim"))) + @sparkver("3.4 / 3.5") override protected def getInsertIntoHiveTableCommand( table: CatalogTable, partition: Map[String, Option[String]], @@ -71,19 +68,15 @@ case class NativeParquetInsertIntoHiveTableExec( metrics) } - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) - @enableIf( - Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1 / 3.2 / 3.3") class BlazeInsertIntoHiveTable30( table: CatalogTable, partition: Map[String, Option[String]], @@ -108,7 +101,7 @@ case class NativeParquetInsertIntoHiveTableExec( super.run(sparkSession, nativeParquetSink) } - @enableIf(Seq("spark-3.2", "spark-3.3").contains(System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3") override def basicWriteJobStatsTracker(hadoopConf: org.apache.hadoop.conf.Configuration) = { import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker @@ -134,7 +127,7 @@ case class NativeParquetInsertIntoHiveTableExec( } } - @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.1") override def basicWriteJobStatsTracker(hadoopConf: org.apache.hadoop.conf.Configuration) = { import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker @@ -172,7 +165,7 @@ case class NativeParquetInsertIntoHiveTableExec( } } - @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0") override def 
basicWriteJobStatsTracker(hadoopConf: org.apache.hadoop.conf.Configuration) = { import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker @@ -202,7 +195,7 @@ case class NativeParquetInsertIntoHiveTableExec( } } - @enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim"))) + @sparkver("3.4 / 3.5") class BlazeInsertIntoHiveTable34( table: CatalogTable, partition: Map[String, Option[String]], diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkExec.scala index e701921d5..54f404775 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkExec.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkExec.scala @@ -19,8 +19,7 @@ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.execution.metric.SQLMetric - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class NativeParquetSinkExec( sparkSession: SparkSession, @@ -30,13 +29,11 @@ case class NativeParquetSinkExec( override val metrics: Map[String, SQLMetric]) extends NativeParquetSinkBase(sparkSession, table, partition, child, metrics) { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) } 
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativePartialTakeOrderedExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativePartialTakeOrderedExec.scala index 91b8a8576..0a51a8307 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativePartialTakeOrderedExec.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativePartialTakeOrderedExec.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution.blaze.plan import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class NativePartialTakeOrderedExec( limit: Long, @@ -28,13 +27,11 @@ case class NativePartialTakeOrderedExec( override val metrics: Map[String, SQLMetric]) extends NativePartialTakeOrderedBase(limit, sortOrder, child, metrics) { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeProjectExecProvider.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeProjectExecProvider.scala index f3b8dd49d..7fe895260 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeProjectExecProvider.scala +++ 
b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeProjectExecProvider.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.execution.blaze.plan import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.execution.SparkPlan - -import com.thoughtworks.enableIf +import org.blaze.sparkver case object NativeProjectExecProvider { - @enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim"))) + @sparkver("3.4 / 3.5") def provide( projectList: Seq[NamedExpression], child: SparkPlan, @@ -49,7 +48,7 @@ case object NativeProjectExecProvider { NativeProjectExec(projectList, child, addTypeCast) } - @enableIf(Seq("spark-3.1", "spark-3.2", "spark-3.3").contains(System.getProperty("blaze.shim"))) + @sparkver("3.1 / 3.2 / 3.3") def provide( projectList: Seq[NamedExpression], child: SparkPlan, @@ -65,11 +64,11 @@ case object NativeProjectExecProvider { with AliasAwareOutputPartitioning with AliasAwareOutputOrdering { - @enableIf(Seq("spark-3.2", "spark-3.3").contains(System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) @@ -82,7 +81,7 @@ case object NativeProjectExecProvider { NativeProjectExec(projectList, child, addTypeCast) } - @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0") def provide( projectList: Seq[NamedExpression], child: SparkPlan, diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeRenameColumnsExecProvider.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeRenameColumnsExecProvider.scala index b57f00726..e559ab07b 100644 --- 
a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeRenameColumnsExecProvider.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeRenameColumnsExecProvider.scala @@ -16,11 +16,10 @@ package org.apache.spark.sql.execution.blaze.plan import org.apache.spark.sql.execution.SparkPlan - -import com.thoughtworks.enableIf +import org.blaze.sparkver case object NativeRenameColumnsExecProvider { - @enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim"))) + @sparkver("3.4 / 3.5") def provide(child: SparkPlan, renamedColumnNames: Seq[String]): NativeRenameColumnsBase = { import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.catalyst.expressions.SortOrder @@ -44,7 +43,7 @@ case object NativeRenameColumnsExecProvider { NativeRenameColumnsExec(child, renamedColumnNames) } - @enableIf(Seq("spark-3.1", "spark-3.2", "spark-3.3").contains(System.getProperty("blaze.shim"))) + @sparkver("3.1 / 3.2 / 3.3") def provide(child: SparkPlan, renamedColumnNames: Seq[String]): NativeRenameColumnsBase = { import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.catalyst.expressions.SortOrder @@ -58,11 +57,11 @@ case object NativeRenameColumnsExecProvider { with AliasAwareOutputPartitioning with AliasAwareOutputOrdering { - @enableIf(Seq("spark-3.2", "spark-3.3").contains(System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) @@ -73,7 +72,7 @@ case object NativeRenameColumnsExecProvider { NativeRenameColumnsExec(child, renamedColumnNames) } - 
@enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0") def provide(child: SparkPlan, renamedColumnNames: Seq[String]): NativeRenameColumnsBase = { case class NativeRenameColumnsExec( override val child: SparkPlan, diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala index a0d57ab10..d9750bd36 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeExec.scala @@ -25,7 +25,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.shuffle.ShuffleWriteProcessor -import org.apache.spark.shuffle.ShuffleWriter import org.apache.spark.sql.blaze.NativeHelper import org.apache.spark.sql.blaze.NativeRDD import org.apache.spark.sql.catalyst.InternalRow @@ -33,14 +32,12 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.blaze.shuffle.BlazeRssShuffleWriterBase -import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleWriter import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleWriterBase import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter import org.apache.spark.sql.execution.metric.SQLShuffleWriteMetricsReporter - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class NativeShuffleExchangeExec( override val outputPartitioning: Partitioning, @@ -165,30 
+162,26 @@ case class NativeShuffleExchangeExec( // for databricks testing val causedBroadcastJoinBuildOOM = false - @enableIf(Seq("spark-3.5").contains(System.getProperty("blaze.shim"))) + @sparkver("3.5") override def advisoryPartitionSize: Option[Long] = None // If users specify the num partitions via APIs like `repartition`, we shouldn't change it. // For `SinglePartition`, it requires exactly one partition and we can't change it either. - @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0") override def canChangeNumPartitions: Boolean = outputPartitioning != SinglePartition - @enableIf( - Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5") override def shuffleOrigin = { import org.apache.spark.sql.execution.exchange.ShuffleOrigin; _shuffleOrigin.get.asInstanceOf[ShuffleOrigin] } - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortExec.scala index 5eff2206d..d79c0304a 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortExec.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortExec.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.blaze.plan import 
org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.execution.SparkPlan - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class NativeSortExec( sortOrder: Seq[SortOrder], @@ -26,13 +25,11 @@ case class NativeSortExec( override val child: SparkPlan) extends NativeSortBase(sortOrder, global, child) { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeTakeOrderedExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeTakeOrderedExec.scala index d4cb0c9d5..2c98b82c0 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeTakeOrderedExec.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeTakeOrderedExec.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.blaze.plan import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.execution.SparkPlan - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class NativeTakeOrderedExec( limit: Long, @@ -26,13 +25,11 @@ case class NativeTakeOrderedExec( override val child: SparkPlan) extends NativeTakeOrderedBase(limit, sortOrder, child) { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = 
copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeUnionExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeUnionExec.scala index 6a9ce256f..29fcc0600 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeUnionExec.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeUnionExec.scala @@ -16,19 +16,16 @@ package org.apache.spark.sql.execution.blaze.plan import org.apache.spark.sql.execution.SparkPlan - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class NativeUnionExec(override val children: Seq[SparkPlan]) extends NativeUnionBase(children) { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = copy(children = newChildren) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(children = newChildren) } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeWindowExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeWindowExec.scala index c16223e9b..b3cbf8175 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeWindowExec.scala +++ 
b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeWindowExec.scala @@ -19,8 +19,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.execution.SparkPlan - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class NativeWindowExec( windowExpression: Seq[NamedExpression], @@ -29,13 +28,11 @@ case class NativeWindowExec( override val child: SparkPlan) extends NativeWindowBase(windowExpression, partitionSpec, orderSpec, child) { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(child = newChildren.head) } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala index e8dbf4c94..374217608 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala @@ -18,12 +18,11 @@ package org.apache.spark.sql.execution.blaze.shuffle import java.io.InputStream import org.apache.spark.{MapOutputTracker, SparkEnv, TaskContext} -import org.apache.spark.internal.{Logging, config} +import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.CompressionCodec import 
org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReadMetricsReporter} import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator} - -import com.thoughtworks.enableIf +import org.blaze.sparkver class BlazeBlockStoreShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], @@ -37,9 +36,7 @@ class BlazeBlockStoreShuffleReader[K, C]( with Logging { override def readBlocks(): Iterator[(BlockId, InputStream)] = { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") def fetchIterator = new ShuffleBlockFetcherIterator( context, blockManager.blockStoreClient, @@ -60,7 +57,7 @@ class BlazeBlockStoreShuffleReader[K, C]( readMetrics, fetchContinuousBlocksInBatch).toCompletionIterator - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") def fetchIterator = new ShuffleBlockFetcherIterator( context, blockManager.blockStoreClient, diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleManagerBase.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleManagerBase.scala index 2dab56776..f4574aa34 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleManagerBase.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleManagerBase.scala @@ -19,8 +19,7 @@ import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.shuffle._ import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleDependency.isArrowShuffle - -import com.thoughtworks.enableIf +import org.blaze.sparkver abstract class BlazeRssShuffleManagerBase(_conf: SparkConf) extends ShuffleManager with Logging { override 
def registerShuffle[K, V, C]( @@ -73,9 +72,7 @@ abstract class BlazeRssShuffleManagerBase(_conf: SparkConf) extends ShuffleManag context: TaskContext, metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] - @enableIf( - Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5") override def getReader[K, C]( handle: ShuffleHandle, startMapIndex: Int, @@ -106,7 +103,7 @@ abstract class BlazeRssShuffleManagerBase(_conf: SparkConf) extends ShuffleManag } } - @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0") override def getReader[K, C]( handle: ShuffleHandle, startPartition: Int, @@ -121,7 +118,7 @@ abstract class BlazeRssShuffleManagerBase(_conf: SparkConf) extends ShuffleManag } } - @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0") override def getReaderForRange[K, C]( handle: ShuffleHandle, startMapIndex: Int, diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala index bec87beb8..5772f3ea4 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala @@ -21,8 +21,7 @@ import org.apache.spark.shuffle._ import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleDependency.isArrowShuffle - -import com.thoughtworks.enableIf +import org.blaze.sparkver class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { val sortShuffleManager = new SortShuffleManager(conf) @@ 
-45,9 +44,7 @@ class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { sortShuffleManager.registerShuffle(shuffleId, dependency) } - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override def getReader[K, C]( handle: ShuffleHandle, startMapIndex: Int, @@ -60,10 +57,9 @@ class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { if (isArrowShuffle(handle)) { val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]] - @enableIf(Seq("spark-3.2").contains(System.getProperty("blaze.shim"))) + @sparkver("3.2") def shuffleMergeFinalized = baseShuffleHandle.dependency.shuffleMergeFinalized - @enableIf( - Seq("spark-3.3", "spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim"))) + @sparkver("3.3 / 3.4 / 3.5") def shuffleMergeFinalized = baseShuffleHandle.dependency.isShuffleMergeFinalizedMarked val (blocksByAddress, canEnableBatchFetch) = @@ -106,7 +102,7 @@ class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { } } - @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.1") override def getReader[K, C]( handle: ShuffleHandle, startMapIndex: Int, @@ -143,7 +139,7 @@ class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { } } - @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0") override def getReader[K, C]( handle: ShuffleHandle, startPartition: Int, @@ -170,7 +166,7 @@ class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { } } - @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0") override def getReaderForRange[K, C]( handle: ShuffleHandle, startMapIndex: Int, diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleWriter.scala 
b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleWriter.scala index ba796786e..f68867406 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleWriter.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleWriter.scala @@ -16,14 +16,11 @@ package org.apache.spark.sql.execution.blaze.shuffle import org.apache.spark.shuffle.ShuffleWriteMetricsReporter - -import com.thoughtworks.enableIf +import org.blaze.sparkver class BlazeShuffleWriter[K, V](metrics: ShuffleWriteMetricsReporter) extends BlazeShuffleWriterBase[K, V](metrics) { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override def getPartitionLengths(): Array[Long] = partitionLengths } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleWriter.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleWriter.scala index 4f311d0f5..16a34051c 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleWriter.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleWriter.scala @@ -25,8 +25,7 @@ import org.apache.spark.shuffle.celeborn.ExecutorShuffleIdTracker import org.apache.spark.shuffle.celeborn.SparkUtils import org.apache.spark.sql.execution.blaze.shuffle.BlazeRssShuffleWriterBase import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase - -import com.thoughtworks.enableIf +import org.blaze.sparkver class BlazeCelebornShuffleWriter[K, V]( celebornShuffleWriter: ShuffleWriter[K, V], @@ -57,9 +56,7 @@ class 
BlazeCelebornShuffleWriter[K, V]( metrics) } - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override def getPartitionLengths(): Array[Long] = celebornShuffleWriter.getPartitionLengths() override def rssStop(success: Boolean): Unit = { diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala index 5e51ae068..8cb7cd64e 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala @@ -24,8 +24,7 @@ import org.apache.spark.sql.execution.blaze.plan.BroadcastRight import org.apache.spark.sql.execution.blaze.plan.BroadcastSide import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastJoinBase import org.apache.spark.sql.execution.joins.HashJoin - -import com.thoughtworks.enableIf +import org.blaze.sparkver case class NativeBroadcastJoinExec( override val left: SparkPlan, @@ -47,24 +46,20 @@ case class NativeBroadcastJoinExec( override val condition: Option[Expression] = None - @enableIf( - Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5") override def buildSide: org.apache.spark.sql.catalyst.optimizer.BuildSide = broadcastSide match { case BroadcastLeft => org.apache.spark.sql.catalyst.optimizer.BuildLeft case BroadcastRight => org.apache.spark.sql.catalyst.optimizer.BuildRight } - @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0") override val buildSide: org.apache.spark.sql.execution.joins.BuildSide = broadcastSide 
match { case BroadcastLeft => org.apache.spark.sql.execution.joins.BuildLeft case BroadcastRight => org.apache.spark.sql.execution.joins.BuildRight } - @enableIf( - Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5") override def requiredChildDistribution = { import org.apache.spark.sql.catalyst.plans.physical.BroadcastDistribution import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution @@ -79,36 +74,28 @@ case class NativeBroadcastJoinExec( } } - @enableIf( - Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5") override def supportCodegen: Boolean = false - @enableIf( - Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5") override def inputRDDs() = { throw new NotImplementedError("NativeBroadcastJoin dose not support codegen") } - @enableIf( - Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5") override protected def prepareRelation( ctx: org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext) : org.apache.spark.sql.execution.joins.HashedRelationInfo = { throw new NotImplementedError("NativeBroadcastJoin dose not support codegen") } - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override protected def withNewChildrenInternal( newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = copy(left = newLeft, right = newRight) - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(left = 
newChildren(0), right = newChildren(1)) } diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeShuffledHashJoinExecProvider.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeShuffledHashJoinExecProvider.scala index 47c89a8f4..d359622ef 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeShuffledHashJoinExecProvider.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeShuffledHashJoinExecProvider.scala @@ -20,14 +20,11 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.blaze.plan.BuildSide import org.apache.spark.sql.execution.blaze.plan.NativeShuffledHashJoinBase - -import com.thoughtworks.enableIf +import org.blaze.sparkver case object NativeShuffledHashJoinExecProvider { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") def provide( left: SparkPlan, right: SparkPlan, @@ -74,7 +71,7 @@ case object NativeShuffledHashJoinExecProvider { NativeShuffledHashJoinExec(left, right, leftKeys, rightKeys, joinType, buildSide) } - @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.1") def provide( left: SparkPlan, right: SparkPlan, @@ -118,7 +115,7 @@ case object NativeShuffledHashJoinExecProvider { NativeShuffledHashJoinExec(left, right, leftKeys, rightKeys, joinType, buildSide) } - @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0") def provide( left: SparkPlan, right: SparkPlan, diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeSortMergeJoinExecProvider.scala 
b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeSortMergeJoinExecProvider.scala index d589fcacb..df3796530 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeSortMergeJoinExecProvider.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeSortMergeJoinExecProvider.scala @@ -19,14 +19,11 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.blaze.plan.NativeSortMergeJoinBase - -import com.thoughtworks.enableIf +import org.blaze.sparkver case object NativeSortMergeJoinExecProvider { - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") def provide( left: SparkPlan, right: SparkPlan, @@ -71,7 +68,7 @@ case object NativeSortMergeJoinExecProvider { NativeSortMergeJoinExec(left, right, leftKeys, rightKeys, joinType) } - @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0 / 3.1") def provide( left: SparkPlan, right: SparkPlan, diff --git a/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/blaze/BlazeAdaptiveQueryExecSuite.scala b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/blaze/BlazeAdaptiveQueryExecSuite.scala index d2ac0bf1b..607fffb56 100644 --- a/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/blaze/BlazeAdaptiveQueryExecSuite.scala +++ b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/blaze/BlazeAdaptiveQueryExecSuite.scala @@ -15,9 +15,9 @@ */ package org.apache.spark.sql.blaze -import com.thoughtworks.enableMembersIf +import org.blaze.sparkverEnableMembers -@enableMembersIf(Seq("spark-3.5").contains(System.getProperty("blaze.shim"))) 
+@sparkverEnableMembers("3.5") class BlazeAdaptiveQueryExecSuite extends org.apache.spark.sql.QueryTest with BaseBlazeSQLSuite diff --git a/spark-extension/pom.xml b/spark-extension/pom.xml index 26e46c903..2cee03c2e 100644 --- a/spark-extension/pom.xml +++ b/spark-extension/pom.xml @@ -13,6 +13,12 @@ jar + + org.blaze + spark-version-annotation-macros + ${revision} + compile + org.blaze proto @@ -69,11 +75,6 @@ scalatest_${scalaVersion} test - - com.thoughtworks.enableIf - enableif_${scalaVersion} - 1.2.0 - diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala index 4df634542..7d62694fb 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala @@ -15,8 +15,6 @@ */ package org.apache.spark.sql.blaze -import com.thoughtworks.enableIf - import scala.annotation.tailrec import scala.collection.mutable import org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat @@ -81,6 +79,7 @@ import org.apache.spark.sql.hive.blaze.BlazeHiveConverters import org.apache.spark.sql.hive.execution.InsertIntoHiveTable import org.apache.spark.sql.hive.execution.blaze.plan.NativeHiveTableScanBase import org.apache.spark.sql.types.LongType +import org.blaze.sparkver object BlazeConverters extends Logging { val enableScan: Boolean = @@ -296,12 +295,10 @@ object BlazeConverters extends Logging { getShuffleOrigin(exec)) } - @enableIf( - Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5") def getShuffleOrigin(exec: ShuffleExchangeExec): Option[Any] = Some(exec.shuffleOrigin) - @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim"))) + @sparkver("3.0") def getShuffleOrigin(exec: ShuffleExchangeExec): Option[Any] = None def 
convertFileSourceScanExec(exec: FileSourceScanExec): SparkPlan = { diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/SpillBuf.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/SpillBuf.scala index e016df0b3..e3aa3262e 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/SpillBuf.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/SpillBuf.scala @@ -145,7 +145,7 @@ class ReleasedSpillBuf(releasing: SpillBuf) extends SpillBuf { override val size: Long = releasing.size releasing.release() - + override def write(buf: ByteBuffer): Unit = throw new UnsupportedOperationException() @@ -153,4 +153,4 @@ class ReleasedSpillBuf(releasing: SpillBuf) extends SpillBuf { throw new UnsupportedOperationException() override def release(): Unit = {} -} \ No newline at end of file +} diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleWriterBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleWriterBase.scala index 374d3e749..236e5399a 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleWriterBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeRssShuffleWriterBase.scala @@ -23,13 +23,11 @@ import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteMetricsReporter} import org.apache.spark.shuffle.ShuffleWriter import org.apache.spark.sql.blaze.{JniBridge, NativeHelper, NativeRDD, Shims} import org.blaze.protobuf.{PhysicalPlanNode, RssShuffleWriterExecNode} - -import com.thoughtworks.enableIf +import org.blaze.sparkver abstract class BlazeRssShuffleWriterBase[K, V](metrics: ShuffleWriteMetricsReporter) extends ShuffleWriter[K, V] { - private var mapStatus: Option[MapStatus] = None private var rpw: RssPartitionWriterBase = _ private var mapId: Int = 0 @@ -70,19 +68,11 @@ abstract class 
BlazeRssShuffleWriterBase[K, V](metrics: ShuffleWriteMetricsRepor } finally { rpw.close() } - - mapStatus = Some( - Shims.get.getMapStatus( - SparkEnv.get.blockManager.shuffleServerId, - rpw.getPartitionLengthMap, - mapId)) } def rssStop(success: Boolean): Unit = {} - @enableIf( - Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( - System.getProperty("blaze.shim"))) + @sparkver("3.2 / 3.3 / 3.4 / 3.5") override def getPartitionLengths(): Array[Long] = rpw.getPartitionLengthMap override def write(records: Iterator[Product2[K, V]]): Unit = { diff --git a/spark-version-annotation-macros/pom.xml b/spark-version-annotation-macros/pom.xml new file mode 100644 index 000000000..f45d25ecf --- /dev/null +++ b/spark-version-annotation-macros/pom.xml @@ -0,0 +1,27 @@ + + + 4.0.0 + + + org.blaze + blaze-engine + ${revision} + ../ + + org.blaze + spark-version-annotation-macros + jar + + + + org.scala-lang + scala-library + provided + + + org.scalamacros + paradise_${scalaLongVersion} + 2.1.1 + + + diff --git a/spark-version-annotation-macros/src/main/scala/org/blaze/sparkver.scala b/spark-version-annotation-macros/src/main/scala/org/blaze/sparkver.scala new file mode 100644 index 000000000..537d1fb2b --- /dev/null +++ b/spark-version-annotation-macros/src/main/scala/org/blaze/sparkver.scala @@ -0,0 +1,110 @@ +/* + * Copyright 2022 The Blaze Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.blaze
+
+import scala.annotation.StaticAnnotation
+import scala.annotation.compileTimeOnly
+import scala.language.experimental._
+import scala.reflect.macros.whitebox
+
+object sparkver { // compile-time Spark-version gating keyed on the -Dblaze.shim system property
+  def matchVersion(vers: String): Boolean = { // true iff "blaze.shim" matches one of the "/"-separated versions in `vers`, e.g. "3.2 / 3.3"
+    val configuredVer = System.getProperty("blaze.shim") // e.g. "spark-3.5"; NOTE(review): null when unset, so every comparison below is false
+    for (ver <- vers.split("/")) {
+      val verStripped = ver.trim // entries are written padded ("3.2 / 3.3"), so strip before comparing
+      if (s"spark-$verStripped" == configuredVer) {
+        return true
+      }
+    }
+    false
+  }
+
+  object Macros {
+    def impl(c: whitebox.Context)(annottees: c.Expr[Any]*)( // shared expansion: keep annottees on version match, else expand to `disabled`
+        disabled: => c.Expr[Any]): c.Expr[Any] = {
+      import c.universe._
+
+      val versions = c.macroApplication match { // pull the annotation's version-list argument out of the application tree
+        case Apply(Select(Apply(_, List(vs)), _), _) => c.eval(c.Expr[String](q"$vs")) // NOTE(review): non-exhaustive — an unexpected tree shape raises MatchError at expansion time
+      }
+
+      if (matchVersion(versions)) {
+        return c.Expr[Any](q"..$annottees") // configured Spark version matches: emit the annotated definitions unchanged
+      }
+      disabled // no match: emit the caller-supplied replacement tree
+    }
+
+    def verEnable(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { // backs @sparkver: drop the definition entirely on mismatch
+      import c.universe._
+
+      impl(c)(annottees: _*) {
+        c.Expr(EmptyTree) // mismatch: expand to nothing
+      }
+    }
+
+    def verEnableMembers(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { // backs @sparkverEnableMembers: keep the class/object shell but empty its body on mismatch
+      import c.universe._
+
+      impl(c)(annottees: _*) {
+        val head = annottees.head.tree match { // NOTE(review): only ClassDef/ModuleDef handled; any other annottee raises MatchError
+          case ClassDef(mods, name, tparams, Template(parents, self, _body)) =>
+            ClassDef(mods, name, tparams, Template(parents, self, List(EmptyTree))) // same class, members replaced by an empty body
+          case ModuleDef(mods, name, Template(parents, self, _body)) =>
+            ModuleDef(mods, name, Template(parents, self, List(EmptyTree))) // same object, members replaced by an empty body
+        }
+        c.Expr(q"$head; ..${annottees.tail}") // re-emit the gutted head alongside any companion annottees
+      }
+    }
+
+    def verEnableOverride(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { // backs @sparkverEnableOverride: on mismatch, keep the member but strip its `override` flag
+      import c.universe._
+      import scala.reflect.internal.Flags // NOTE(review): compiler-internal API — needed for the raw OVERRIDE bit, but unversioned/unstable
+      impl(c)(annottees: _*) {
+        val head = annottees.head.tree match { // NOTE(review): only DefDef/ValDef handled; other annottees raise MatchError
+          case DefDef(mods, name, tparams, vparams, tpt, rhs) =>
+            val newMods = Modifiers(
+              (mods.flags.asInstanceOf[Long] & ~Flags.OVERRIDE).asInstanceOf[FlagSet], // clear OVERRIDE; FlagSet is a Long wrapper, hence the double cast
+              mods.privateWithin,
+              mods.annotations)
+            DefDef(newMods, name, tparams, vparams, tpt, rhs)
+
+          case ValDef(mods, name, tpt, rhs) =>
+            val newMods = Modifiers(
+              (mods.flags.asInstanceOf[Long] & ~Flags.OVERRIDE).asInstanceOf[FlagSet], // same flag surgery for vals
+              mods.privateWithin,
+              mods.annotations)
+            ValDef(newMods, name, tpt, rhs)
+        }
+        c.Expr(q"$head; ..${annottees.tail}")
+      }
+    }
+  }
+}
+
+@compileTimeOnly("enable macro paradise to expand macro annotations")
+final class sparkver(vers: String) extends StaticAnnotation { // keep the annotated definition only for the listed Spark versions
+  def macroTransform(annottees: Any*): Any = macro sparkver.Macros.verEnable
+}
+
+@compileTimeOnly("enable macro paradise to expand macro annotations")
+final class sparkverEnableMembers(vers: String) extends StaticAnnotation { // keep the class/object shell but drop its members on non-matching versions
+  def macroTransform(annottees: Any*): Any = macro sparkver.Macros.verEnableMembers
+}
+
+@compileTimeOnly("enable macro paradise to expand macro annotations")
+final class sparkverEnableOverride(vers: String) extends StaticAnnotation { // strip the `override` modifier on non-matching versions
+  def macroTransform(annottees: Any*): Any = macro sparkver.Macros.verEnableOverride
+}