pom.xml (1 addition, 0 deletions)
@@ -7,6 +7,7 @@
<packaging>pom</packaging>

<modules>
+ <module>spark-version-annotation-macros</module>
<module>spark-extension</module>
<module>${shimPkg}</module>
<module>hadoop-shim</module>
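Note: the new spark-version-annotation-macros module itself is not shown in this diff, but it presumably provides the org.blaze.sparkver annotation that the Scala files below switch to. The following is a minimal, hypothetical sketch of how such a Scala 2 macro annotation could look; apart from the org.blaze.sparkver name, the "3.x / 3.y" version-string format, and the blaze.shim property that the removed @enableIf guards already read, every detail here is an assumption rather than the module's actual code.

package org.blaze

import scala.annotation.{compileTimeOnly, StaticAnnotation}
import scala.language.experimental.macros
import scala.reflect.macros.whitebox

// Hypothetical sketch; requires -Ymacro-annotations (or the macro paradise plugin).
@compileTimeOnly("enable macro annotations to expand @sparkver")
class sparkver(versions: String) extends StaticAnnotation {
  def macroTransform(annottees: Any*): Any = macro sparkverMacro.impl
}

object sparkverMacro {
  def impl(c: whitebox.Context)(annottees: c.Tree*): c.Tree = {
    import c.universe._

    // Extract the literal version list from the annotation, e.g. "3.2 / 3.3 / 3.4 / 3.5".
    val versions: Seq[String] = c.prefix.tree match {
      case Apply(_, List(Literal(Constant(s: String)))) => s.split('/').map(_.trim).toSeq
      case _ => c.abort(c.enclosingPosition, "@sparkver expects a literal version string")
    }

    // Assumed build contract: the shim being compiled is passed as -Dblaze.shim=spark-3.5
    // (etc.), exactly the property the removed @enableIf guards checked.
    val current = System.getProperty("blaze.shim", "").stripPrefix("spark-")

    // Keep the annotated definition when it applies to the current Spark version;
    // otherwise replace it with a no-op so it never reaches the typechecker.
    if (versions.contains(current)) q"..$annottees" else q"()"
  }
}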
spark-extension-shims-spark3/pom.xml (0 additions, 6 deletions)
@@ -90,12 +90,6 @@
<version>1.12.10</version>
</dependency>

- <dependency>
-   <groupId>com.thoughtworks.enableIf</groupId>
-   <artifactId>enableif_${scalaVersion}</artifactId>
-   <version>1.2.0</version>
- </dependency>

<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scalaVersion}</artifactId>
InterceptedValidateSparkPlan.scala
@@ -17,14 +17,11 @@ package org.apache.spark.sql.blaze

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan

- import com.thoughtworks.enableIf
+ import org.blaze.sparkver

object InterceptedValidateSparkPlan extends Logging {

- @enableIf(
-   Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
def validate(plan: SparkPlan): Unit = {
import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
import org.apache.spark.sql.execution.blaze.plan.NativeRenameColumnsBase
@@ -72,14 +69,12 @@ object InterceptedValidateSparkPlan extends Logging {
}
}

- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
def validate(plan: SparkPlan): Unit = {
throw new UnsupportedOperationException("validate is not supported in spark 3.0.3 or 3.1.3")
}

- @enableIf(
-   Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
private def errorOnInvalidBroadcastQueryStage(plan: SparkPlan): Unit = {
import org.apache.spark.sql.execution.adaptive.InvalidAQEPlanException
throw InvalidAQEPlanException("Invalid broadcast query stage", plan)
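One thing this file illustrates (and it carries over to the other shim files below): both validate overloads share the same signature, yet the code compiles because the non-matching definition is removed during macro expansion, before typechecking, for whichever shim is being built. A toy illustration, building on the hypothetical sketch above and the assumed -Dblaze.shim property; SparkPlan is replaced with String just to keep the snippet dependency-free:

// Illustrative only; depends on the assumed @sparkver sketch shown earlier.
object ValidateDispatchSketch {
  import org.blaze.sparkver

  @sparkver("3.2 / 3.3 / 3.4 / 3.5")
  def validate(plan: String): Unit = println(s"running AQE-aware checks on $plan")

  @sparkver("3.0 / 3.1")
  def validate(plan: String): Unit =
    throw new UnsupportedOperationException("validate is not supported before Spark 3.2")
}
// Compiled with -Dblaze.shim=spark-3.5, only the first definition survives expansion;
// with -Dblaze.shim=spark-3.0 only the second does, so the duplicate signatures never clash.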
ShimsImpl.scala
@@ -17,6 +17,7 @@ package org.apache.spark.sql.blaze

import java.io.File
import java.util.UUID

import org.apache.commons.lang3.reflect.FieldUtils
import org.apache.spark.OneToOneDependency
import org.apache.spark.ShuffleDependency
@@ -111,27 +112,25 @@ import org.apache.spark.sql.types.StringType
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.storage.FileSegment
import org.blaze.{protobuf => pb}
- import com.thoughtworks.enableIf
import org.apache.spark.sql.execution.blaze.shuffle.uniffle.BlazeUniffleShuffleManager
+ import org.blaze.sparkver

class ShimsImpl extends Shims with Logging {

- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
override def shimVersion: String = "spark-3.0"
- @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.1")
override def shimVersion: String = "spark-3.1"
- @enableIf(Seq("spark-3.2").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.2")
override def shimVersion: String = "spark-3.2"
- @enableIf(Seq("spark-3.3").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.3")
override def shimVersion: String = "spark-3.3"
- @enableIf(Seq("spark-3.4").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.4")
override def shimVersion: String = "spark-3.4"
- @enableIf(Seq("spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.5")
override def shimVersion: String = "spark-3.5"

- @enableIf(
-   Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def initExtension(): Unit = {
ValidateSparkPlanInjector.inject()

@@ -146,7 +145,7 @@ class ShimsImpl extends Shims with Logging {
}
}

- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def initExtension(): Unit = {
if (BlazeConf.FORCE_SHUFFLED_HASH_JOIN.booleanConf()) {
logWarning(s"${BlazeConf.FORCE_SHUFFLED_HASH_JOIN.key} is not supported in $shimVersion")
@@ -374,9 +373,7 @@ class ShimsImpl extends Shims with Logging {
length: Long,
numRecords: Long): FileSegment = new FileSegment(file, offset, length)

- @enableIf(
-   Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def commit(
dep: ShuffleDependency[_, _, _],
shuffleBlockResolver: IndexShuffleBlockResolver,
@@ -396,7 +393,7 @@
MapStatus.apply(SparkEnv.get.blockManager.shuffleServerId, partitionLengths, mapId)
}

- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def commit(
dep: ShuffleDependency[_, _, _],
shuffleBlockResolver: IndexShuffleBlockResolver,
@@ -533,23 +530,19 @@ class ShimsImpl extends Shims with Logging {
expr.asInstanceOf[AggregateExpression].filter
}

- @enableIf(
-   Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
private def isAQEShuffleRead(exec: SparkPlan): Boolean = {
import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
exec.isInstanceOf[AQEShuffleReadExec]
}

- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
private def isAQEShuffleRead(exec: SparkPlan): Boolean = {
import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec
exec.isInstanceOf[CustomShuffleReaderExec]
}

- @enableIf(
-   Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
private def executeNativeAQEShuffleReader(exec: SparkPlan): NativeRDD = {
import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
import org.apache.spark.sql.execution.CoalescedMapperPartitionSpec
@@ -652,7 +645,7 @@ class ShimsImpl extends Shims with Logging {
}
}

- @enableIf(Seq("spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.1")
private def executeNativeAQEShuffleReader(exec: SparkPlan): NativeRDD = {
import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec

@@ -744,7 +737,7 @@ class ShimsImpl extends Shims with Logging {
}
}

- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
private def executeNativeAQEShuffleReader(exec: SparkPlan): NativeRDD = {
import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec

@@ -846,13 +839,11 @@
}
}

- @enableIf(
-   Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def getSqlContext(sparkPlan: SparkPlan): SQLContext =
sparkPlan.session.sqlContext

- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def getSqlContext(sparkPlan: SparkPlan): SQLContext = sparkPlan.sqlContext

override def createNativeExprWrapper(
@@ -862,17 +853,15 @@
NativeExprWrapper(nativeExpr, dataType, nullable)
}

- @enableIf(
-   Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1 / 3.2 / 3.3")
override def getPartitionedFile(
partitionValues: InternalRow,
filePath: String,
offset: Long,
size: Long): PartitionedFile =
PartitionedFile(partitionValues, filePath, offset, size)

- @enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.4 / 3.5")
override def getPartitionedFile(
partitionValues: InternalRow,
filePath: String,
@@ -883,20 +872,16 @@
PartitionedFile(partitionValues, SparkPath.fromPath(new Path(filePath)), offset, size)
}

- @enableIf(
-   Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
override def getMinPartitionNum(sparkSession: SparkSession): Int =
sparkSession.sessionState.conf.filesMinPartitionNum
.getOrElse(sparkSession.sparkContext.defaultParallelism)

- @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0")
override def getMinPartitionNum(sparkSession: SparkSession): Int =
sparkSession.sparkContext.defaultParallelism

- @enableIf(
-   Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1 / 3.2 / 3.3")
private def convertPromotePrecision(
e: Expression,
isPruningExpr: Boolean,
@@ -909,13 +894,13 @@
}
}

- @enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.4 / 3.5")
private def convertPromotePrecision(
e: Expression,
isPruningExpr: Boolean,
fallback: Expression => pb.PhysicalExprNode): Option[pb.PhysicalExprNode] = None

- @enableIf(Seq("spark-3.3", "spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.3 / 3.4 / 3.5")
private def convertBloomFilterAgg(agg: AggregateFunction): Option[pb.PhysicalAggExprNode] = {
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
agg match {
@@ -940,10 +925,10 @@
}
}

- @enableIf(Seq("spark-3.0", "spark-3.1", "spark-3.2").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1 / 3.2")
private def convertBloomFilterAgg(agg: AggregateFunction): Option[pb.PhysicalAggExprNode] = None

- @enableIf(Seq("spark-3.3", "spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.3 / 3.4 / 3.5")
private def convertBloomFilterMightContain(
e: Expression,
isPruningExpr: Boolean,
@@ -966,7 +951,7 @@
}
}

- @enableIf(Seq("spark-3.0", "spark-3.1", "spark-3.2").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1 / 3.2")
private def convertBloomFilterMightContain(
e: Expression,
isPruningExpr: Boolean,
@@ -977,13 +962,11 @@
case class ForceNativeExecutionWrapper(override val child: SparkPlan)
extends ForceNativeExecutionWrapperBase(child) {

- @enableIf(
-   Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)

- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
@@ -994,8 +977,6 @@ case class NativeExprWrapper(
override val nullable: Boolean)
extends NativeExprWrapperBase(nativeExpr, dataType, nullable) {

- @enableIf(
-   Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy()
}
ConvertToNativeExec.scala
@@ -16,18 +16,15 @@
package org.apache.spark.sql.execution.blaze.plan

import org.apache.spark.sql.execution.SparkPlan

- import com.thoughtworks.enableIf
+ import org.blaze.sparkver

case class ConvertToNativeExec(override val child: SparkPlan) extends ConvertToNativeBase(child) {

- @enableIf(
-   Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)

- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}
NativeAggExec.scala
@@ -26,8 +26,7 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.blaze.plan.NativeAggBase.AggExecMode
import org.apache.spark.sql.types.BinaryType

- import com.thoughtworks.enableIf
+ import org.blaze.sparkver

case class NativeAggExec(
execMode: AggExecMode,
@@ -47,13 +46,11 @@ case class NativeAggExec(
child)
with BaseAggregateExec {

- @enableIf(
-   Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
override val requiredChildDistributionExpressions: Option[Seq[Expression]] =
theRequiredChildDistributionExpressions

- @enableIf(Seq("spark-3.3", "spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.3 / 3.4 / 3.5")
override val initialInputBufferOffset: Int = theInitialInputBufferOffset

override def output: Seq[Attribute] =
@@ -65,25 +62,19 @@
ExprId.apply(NativeAggBase.AGG_BUF_COLUMN_EXPR_ID))
}

- @enableIf(
-   Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def isStreaming: Boolean = false

- @enableIf(
-   Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override def numShufflePartitions: Option[Int] = None

override def resultExpressions: Seq[NamedExpression] = output

- @enableIf(
-   Seq("spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
-     System.getProperty("blaze.shim")))
+ @sparkver("3.2 / 3.3 / 3.4 / 3.5")
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)

- @enableIf(Seq("spark-3.0", "spark-3.1").contains(System.getProperty("blaze.shim")))
+ @sparkver("3.0 / 3.1")
override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
copy(child = newChildren.head)
}