diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala index 53866d8d8..c99874d82 100644 --- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala +++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.blaze import java.io.File import java.util.UUID import org.apache.commons.lang3.reflect.FieldUtils +import org.apache.hadoop.fs.Path import org.apache.spark.OneToOneDependency import org.apache.spark.ShuffleDependency import org.apache.spark.SparkEnv @@ -33,6 +34,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.SparkSession import org.apache.spark.sql.blaze.BlazeConverters.ForceNativeExecutionWrapperBase import org.apache.spark.sql.blaze.NativeConverters.NativeExprWrapperBase +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.Expression @@ -97,6 +99,7 @@ import org.apache.spark.sql.execution.blaze.plan._ import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase import org.apache.spark.sql.execution.blaze.shuffle.celeborn.BlazeCelebornShuffleManager import org.apache.spark.sql.execution.blaze.shuffle.BlazeBlockStoreShuffleReaderBase +import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec} import org.apache.spark.sql.execution.joins.blaze.plan.NativeBroadcastJoinExec import org.apache.spark.sql.execution.joins.blaze.plan.NativeShuffledHashJoinExecProvider @@ -818,6 +821,37 @@ class ShimsImpl extends Shims with Logging { NativeExprWrapper(nativeExpr, dataType, nullable) } + @enableIf( + Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains( + System.getProperty("blaze.shim"))) + override def getPartitionedFile( + partitionValues: InternalRow, + filePath: String, + offset: Long, + size: Long): PartitionedFile = + PartitionedFile(partitionValues, filePath, offset, size) + + @enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim"))) + override def getPartitionedFile( + partitionValues: InternalRow, + filePath: String, + offset: Long, + size: Long): PartitionedFile = { + import org.apache.spark.paths.SparkPath + PartitionedFile(partitionValues, SparkPath.fromPath(new Path(filePath)), offset, size) + } + + @enableIf( + Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains( + System.getProperty("blaze.shim"))) + override def getMinPartitionNum(sparkSession: SparkSession): Int = + sparkSession.sessionState.conf.filesMinPartitionNum + .getOrElse(sparkSession.sparkContext.defaultParallelism) + + @enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim"))) + override def getMinPartitionNum(sparkSession: SparkSession): Int = + sparkSession.sparkContext.defaultParallelism + @enableIf( Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains( System.getProperty("blaze.shim"))) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala index e620e1981..53195c324 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala @@ -46,7 +46,9 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.hive.execution.InsertIntoHiveTable import org.apache.spark.sql.types.DataType import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.storage.BlockManagerId import org.apache.spark.storage.FileSegment @@ -238,6 +240,14 @@ abstract class Shims { dataType: DataType, nullable: Boolean): Expression + def getPartitionedFile( + partitionValues: InternalRow, + filePath: String, + offset: Long, + size: Long): PartitionedFile + + def getMinPartitionNum(sparkSession: SparkSession): Int + def postTransform(plan: SparkPlan, sc: SparkContext): Unit = {} } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/hive/execution/blaze/plan/NativePaimonTableScanExec.scala b/spark-extension/src/main/scala/org/apache/spark/sql/hive/execution/blaze/plan/NativePaimonTableScanExec.scala index ec26379d0..22842a975 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/hive/execution/blaze/plan/NativePaimonTableScanExec.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/hive/execution/blaze/plan/NativePaimonTableScanExec.scala @@ -216,10 +216,10 @@ case class NativePaimonTableScanExec(basedHiveScan: HiveTableScanExec) (0L until dataFileMeta.fileSize() by maxSplitBytes).map { offset => val remaining = dataFileMeta.fileSize() - offset val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining - PartitionedFile(partitionValues, filePath, offset, size) + Shims.get.getPartitionedFile(partitionValues, filePath, offset, size) } } else { - Seq(PartitionedFile(partitionValues, filePath, 0, dataFileMeta.fileSize())) + Seq(Shims.get.getPartitionedFile(partitionValues, filePath, 0, dataFileMeta.fileSize())) } } @@ -229,8 +229,7 @@ case class NativePaimonTableScanExec(basedHiveScan: HiveTableScanExec) selectedSplits: Seq[DataSplit]): Long = { val defaultMaxSplitBytes = sparkSession.sessionState.conf.filesMaxPartitionBytes val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes - val minPartitionNum = sparkSession.sessionState.conf.filesMinPartitionNum - .getOrElse(sparkSession.sparkContext.defaultParallelism) + val minPartitionNum = Shims.get.getMinPartitionNum(sparkSession) val totalBytes = selectedSplits .flatMap(_.dataFiles().asScala.map(_.fileSize() + openCostInBytes)) .sum