Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.apache.spark.sql.blaze
import java.io.File
import java.util.UUID
import org.apache.commons.lang3.reflect.FieldUtils
import org.apache.hadoop.fs.Path
import org.apache.spark.OneToOneDependency
import org.apache.spark.ShuffleDependency
import org.apache.spark.SparkEnv
Expand All @@ -33,6 +34,7 @@ import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.blaze.BlazeConverters.ForceNativeExecutionWrapperBase
import org.apache.spark.sql.blaze.NativeConverters.NativeExprWrapperBase
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.Expression
Expand Down Expand Up @@ -97,6 +99,7 @@ import org.apache.spark.sql.execution.blaze.plan._
import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase
import org.apache.spark.sql.execution.blaze.shuffle.celeborn.BlazeCelebornShuffleManager
import org.apache.spark.sql.execution.blaze.shuffle.BlazeBlockStoreShuffleReaderBase
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec}
import org.apache.spark.sql.execution.joins.blaze.plan.NativeBroadcastJoinExec
import org.apache.spark.sql.execution.joins.blaze.plan.NativeShuffledHashJoinExecProvider
Expand Down Expand Up @@ -818,6 +821,37 @@ class ShimsImpl extends Shims with Logging {
NativeExprWrapper(nativeExpr, dataType, nullable)
}

@enableIf(
Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains(
System.getProperty("blaze.shim")))
override def getPartitionedFile(
partitionValues: InternalRow,
filePath: String,
offset: Long,
size: Long): PartitionedFile =
PartitionedFile(partitionValues, filePath, offset, size)

@enableIf(Seq("spark-3.4", "spark-3.5").contains(System.getProperty("blaze.shim")))
override def getPartitionedFile(
partitionValues: InternalRow,
filePath: String,
offset: Long,
size: Long): PartitionedFile = {
import org.apache.spark.paths.SparkPath
PartitionedFile(partitionValues, SparkPath.fromPath(new Path(filePath)), offset, size)
}

@enableIf(
Seq("spark-3.1", "spark-3.2", "spark-3.3", "spark-3.4", "spark-3.5").contains(
System.getProperty("blaze.shim")))
override def getMinPartitionNum(sparkSession: SparkSession): Int =
sparkSession.sessionState.conf.filesMinPartitionNum
.getOrElse(sparkSession.sparkContext.defaultParallelism)

@enableIf(Seq("spark-3.0").contains(System.getProperty("blaze.shim")))
override def getMinPartitionNum(sparkSession: SparkSession): Int =
sparkSession.sparkContext.defaultParallelism

@enableIf(
Seq("spark-3.0", "spark-3.1", "spark-3.2", "spark-3.3").contains(
System.getProperty("blaze.shim")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.hive.execution.InsertIntoHiveTable
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.storage.FileSegment

Expand Down Expand Up @@ -238,6 +240,14 @@ abstract class Shims {
dataType: DataType,
nullable: Boolean): Expression

def getPartitionedFile(
partitionValues: InternalRow,
filePath: String,
offset: Long,
size: Long): PartitionedFile

def getMinPartitionNum(sparkSession: SparkSession): Int

def postTransform(plan: SparkPlan, sc: SparkContext): Unit = {}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,10 @@ case class NativePaimonTableScanExec(basedHiveScan: HiveTableScanExec)
(0L until dataFileMeta.fileSize() by maxSplitBytes).map { offset =>
val remaining = dataFileMeta.fileSize() - offset
val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining
PartitionedFile(partitionValues, filePath, offset, size)
Shims.get.getPartitionedFile(partitionValues, filePath, offset, size)
}
} else {
Seq(PartitionedFile(partitionValues, filePath, 0, dataFileMeta.fileSize()))
Seq(Shims.get.getPartitionedFile(partitionValues, filePath, 0, dataFileMeta.fileSize()))
}
}

Expand All @@ -229,8 +229,7 @@ case class NativePaimonTableScanExec(basedHiveScan: HiveTableScanExec)
selectedSplits: Seq[DataSplit]): Long = {
val defaultMaxSplitBytes = sparkSession.sessionState.conf.filesMaxPartitionBytes
val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes
val minPartitionNum = sparkSession.sessionState.conf.filesMinPartitionNum
.getOrElse(sparkSession.sparkContext.defaultParallelism)
val minPartitionNum = Shims.get.getMinPartitionNum(sparkSession)
val totalBytes = selectedSplits
.flatMap(_.dataFiles().asScala.map(_.fileSize() + openCostInBytes))
.sum
Expand Down