Skip to content

Commit 9c93f8c

Browse files
committed
[BLAZE-707][FOLLOWUP] NativePaimonTableScanExec should use shimmed PartitionedFile and min partition number
1 parent d5bf5a0 commit 9c93f8c

3 files changed

Lines changed: 46 additions & 4 deletions

File tree

spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ package org.apache.spark.sql.blaze
1818
import java.io.File
1919
import java.util.UUID
2020
import org.apache.commons.lang3.reflect.FieldUtils
21+
import org.apache.hadoop.fs.Path
2122
import org.apache.spark.OneToOneDependency
2223
import org.apache.spark.ShuffleDependency
2324
import org.apache.spark.SparkEnv
2425
import org.apache.spark.SparkException
2526
import org.apache.spark.TaskContext
2627
import org.apache.spark.internal.Logging
28+
import org.apache.spark.paths.SparkPath
2729
import org.apache.spark.rdd.RDD
2830
import org.apache.spark.scheduler.MapStatus
2931
import org.apache.spark.shuffle.IndexShuffleBlockResolver
@@ -33,6 +35,7 @@ import org.apache.spark.sql.SQLContext
3335
import org.apache.spark.sql.SparkSession
3436
import org.apache.spark.sql.blaze.BlazeConverters.ForceNativeExecutionWrapperBase
3537
import org.apache.spark.sql.blaze.NativeConverters.NativeExprWrapperBase
38+
import org.apache.spark.sql.catalyst.InternalRow
3639
import org.apache.spark.sql.catalyst.catalog.CatalogTable
3740
import org.apache.spark.sql.catalyst.expressions.Attribute
3841
import org.apache.spark.sql.catalyst.expressions.Expression
@@ -97,6 +100,7 @@ import org.apache.spark.sql.execution.blaze.plan._
97100
import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase
98101
import org.apache.spark.sql.execution.blaze.shuffle.celeborn.BlazeCelebornShuffleManager
99102
import org.apache.spark.sql.execution.blaze.shuffle.BlazeBlockStoreShuffleReaderBase
103+
import org.apache.spark.sql.execution.datasources.PartitionedFile
100104
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec}
101105
import org.apache.spark.sql.execution.joins.blaze.plan.NativeBroadcastJoinExec
102106
import org.apache.spark.sql.execution.joins.blaze.plan.NativeShuffledHashJoinExecProvider
@@ -818,6 +822,35 @@ class ShimsImpl extends Shims with Logging {
818822
NativeExprWrapper(nativeExpr, dataType, nullable)
819823
}
820824

825+
@enableIf(
826+
Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains(
827+
System.getProperty("blaze.shim")))
828+
override def getPartitionedFile(
829+
partitionValues: InternalRow,
830+
filePath: String,
831+
offset: Long,
832+
size: Long): PartitionedFile =
833+
PartitionedFile(partitionValues, filePath, offset, size)
834+
835+
@enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
836+
override def getPartitionedFile(
837+
partitionValues: InternalRow,
838+
filePath: String,
839+
offset: Long,
840+
size: Long): PartitionedFile =
841+
PartitionedFile(partitionValues, SparkPath.fromPath(new Path(filePath)), offset, size)
842+
843+
@enableIf(
844+
Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
845+
System.getProperty("blaze.shim")))
846+
override def getMinPartitionNum(sparkSession: SparkSession): Int =
847+
sparkSession.sessionState.conf.filesMinPartitionNum
848+
.getOrElse(sparkSession.sparkContext.defaultParallelism)
849+
850+
@enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
851+
override def getMinPartitionNum(sparkSession: SparkSession): Int =
852+
sparkSession.sparkContext.defaultParallelism
853+
821854
@enableIf(
822855
Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains(
823856
System.getProperty("blaze.shim")))

spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ import org.apache.spark.sql.execution.metric.SQLMetric
4646
import org.apache.spark.sql.hive.execution.InsertIntoHiveTable
4747
import org.apache.spark.sql.types.DataType
4848
import org.apache.spark.sql.SparkSession
49+
import org.apache.spark.sql.catalyst.InternalRow
4950
import org.apache.spark.sql.catalyst.catalog.CatalogTable
51+
import org.apache.spark.sql.execution.datasources.PartitionedFile
5052
import org.apache.spark.storage.BlockManagerId
5153
import org.apache.spark.storage.FileSegment
5254

@@ -238,6 +240,14 @@ abstract class Shims {
238240
dataType: DataType,
239241
nullable: Boolean): Expression
240242

243+
def getPartitionedFile(
244+
partitionValues: InternalRow,
245+
filePath: String,
246+
offset: Long,
247+
size: Long): PartitionedFile
248+
249+
def getMinPartitionNum(sparkSession: SparkSession): Int
250+
241251
def postTransform(plan: SparkPlan, sc: SparkContext): Unit = {}
242252
}
243253

spark-extension/src/main/scala/org/apache/spark/sql/hive/execution/blaze/plan/NativePaimonTableScanExec.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,10 @@ case class NativePaimonTableScanExec(basedHiveScan: HiveTableScanExec)
216216
(0L until dataFileMeta.fileSize() by maxSplitBytes).map { offset =>
217217
val remaining = dataFileMeta.fileSize() - offset
218218
val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining
219-
PartitionedFile(partitionValues, filePath, offset, size)
219+
Shims.get.getPartitionedFile(partitionValues, filePath, offset, size)
220220
}
221221
} else {
222-
Seq(PartitionedFile(partitionValues, filePath, 0, dataFileMeta.fileSize()))
222+
Seq(Shims.get.getPartitionedFile(partitionValues, filePath, 0, dataFileMeta.fileSize()))
223223
}
224224
}
225225

@@ -229,8 +229,7 @@ case class NativePaimonTableScanExec(basedHiveScan: HiveTableScanExec)
229229
selectedSplits: Seq[DataSplit]): Long = {
230230
val defaultMaxSplitBytes = sparkSession.sessionState.conf.filesMaxPartitionBytes
231231
val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes
232-
val minPartitionNum = sparkSession.sessionState.conf.filesMinPartitionNum
233-
.getOrElse(sparkSession.sparkContext.defaultParallelism)
232+
val minPartitionNum = Shims.get.getMinPartitionNum(sparkSession)
234233
val totalBytes = selectedSplits
235234
.flatMap(_.dataFiles().asScala.map(_.fileSize() + openCostInBytes))
236235
.sum

0 commit comments

Comments (0)