Skip to content

Commit a69338b

Browse files
richoxzhangli20
andauthored
init RSS framework (#551)
Co-authored-by: zhangli20 <zhangli20@kuaishou.com>
1 parent 565c025 commit a69338b

7 files changed

Lines changed: 301 additions & 7 deletions

File tree

native-engine/blaze-jni-bridge/src/jni_bridge.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1251,7 +1251,7 @@ impl<'a> BlazeRssPartitionWriterBase<'_> {
12511251
Ok(BlazeRssPartitionWriterBase {
12521252
class,
12531253
method_write: env
1254-
.get_method_id(class, "write", "(ILjava/nio/ByteBuffer;I)V")
1254+
.get_method_id(class, "write", "(ILjava/nio/ByteBuffer;)V")
12551255
.unwrap(),
12561256
method_write_ret: ReturnType::Primitive(Primitive::Void),
12571257
method_flush: env.get_method_id(class, "flush", "()V").unwrap(),

native-engine/datafusion-ext-plans/src/shuffle/rss.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ impl Write for RssWriter {
3737
let buf = jni_new_direct_byte_buffer!(&buf)?;
3838
jni_call!(
3939
BlazeRssPartitionWriterBase(self.rss_partition_writer.as_obj())
40-
.write(self.partition_id as i32, buf.as_obj(), buf_len as i32) -> ()
40+
.write(self.partition_id as i32, buf.as_obj()) -> ()
4141
)?;
4242
Ok(buf_len)
4343
}
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/*
2+
* Copyright 2022 The Blaze Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.apache.spark.sql.execution.blaze.shuffle
17+
18+
import org.apache.spark.ShuffleDependency
19+
import org.apache.spark.SparkConf
20+
import org.apache.spark.TaskContext
21+
import org.apache.spark.internal.Logging
22+
import org.apache.spark.shuffle._
23+
import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleDependency.isArrowShuffle
24+
25+
import com.thoughtworks.enableIf
26+
27+
abstract class BlazeRssShuffleManagerBase(conf: SparkConf) extends ShuffleManager with Logging {
28+
override def registerShuffle[K, V, C](
29+
shuffleId: Int,
30+
dependency: ShuffleDependency[K, V, C]): ShuffleHandle
31+
32+
override def unregisterShuffle(shuffleId: Int): Boolean
33+
34+
def getBlazeRssShuffleReader[K, C](
35+
handle: ShuffleHandle,
36+
startPartition: Int,
37+
endPartition: Int,
38+
context: TaskContext,
39+
metrics: ShuffleReadMetricsReporter): BlazeRssShuffleReaderBase[K, C]
40+
41+
def getBlazeRssShuffleReader[K, C](
42+
handle: ShuffleHandle,
43+
startMapIndex: Int,
44+
endMapIndex: Int,
45+
startPartition: Int,
46+
endPartition: Int,
47+
context: TaskContext,
48+
metrics: ShuffleReadMetricsReporter): BlazeRssShuffleReaderBase[K, C]
49+
50+
def getRssShuffleReader[K, C](
51+
handle: ShuffleHandle,
52+
startPartition: Int,
53+
endPartition: Int,
54+
context: TaskContext,
55+
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C]
56+
57+
def getRssShuffleReader[K, C](
58+
handle: ShuffleHandle,
59+
startMapIndex: Int,
60+
endMapIndex: Int,
61+
startPartition: Int,
62+
endPartition: Int,
63+
context: TaskContext,
64+
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C]
65+
66+
def getBlazeRssShuffleWriter[K, V](
67+
handle: ShuffleHandle,
68+
mapId: Long,
69+
context: TaskContext,
70+
metrics: ShuffleWriteMetricsReporter): BlazeRssShuffleWriterBase[K, V]
71+
72+
def getRssShuffleWriter[K, V](
73+
handle: ShuffleHandle,
74+
mapId: Long,
75+
context: TaskContext,
76+
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V]
77+
78+
@enableIf(
79+
Seq("spark320", "spark324", "spark333", "spark351").contains(
80+
System.getProperty("blaze.shim")))
81+
override def getReader[K, C](
82+
handle: ShuffleHandle,
83+
startMapIndex: Int,
84+
endMapIndex: Int,
85+
startPartition: Int,
86+
endPartition: Int,
87+
context: TaskContext,
88+
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
89+
90+
if (isArrowShuffle(handle)) {
91+
getBlazeRssShuffleReader(
92+
handle,
93+
startMapIndex,
94+
endMapIndex,
95+
startPartition,
96+
endPartition,
97+
context,
98+
metrics)
99+
} else {
100+
getRssShuffleReader(
101+
handle,
102+
startMapIndex,
103+
endMapIndex,
104+
startPartition,
105+
endPartition,
106+
context,
107+
metrics)
108+
}
109+
}
110+
111+
@enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
112+
override def getReader[K, C](
113+
handle: ShuffleHandle,
114+
startPartition: Int,
115+
endPartition: Int,
116+
context: TaskContext,
117+
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
118+
119+
if (isArrowShuffle(handle)) {
120+
getBlazeRssShuffleReader(handle, startPartition, endPartition, context, metrics)
121+
} else {
122+
getRssShuffleReader(handle, startPartition, endPartition, context, metrics)
123+
}
124+
}
125+
126+
@enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
127+
override def getReaderForRange[K, C](
128+
handle: ShuffleHandle,
129+
startMapIndex: Int,
130+
endMapIndex: Int,
131+
startPartition: Int,
132+
endPartition: Int,
133+
context: TaskContext,
134+
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
135+
136+
if (isArrowShuffle(handle)) {
137+
getBlazeRssShuffleReader(
138+
handle,
139+
startMapIndex,
140+
endMapIndex,
141+
startPartition,
142+
endPartition,
143+
context,
144+
metrics)
145+
} else {
146+
getRssShuffleReader(
147+
handle,
148+
startMapIndex,
149+
endMapIndex,
150+
startPartition,
151+
endPartition,
152+
context,
153+
metrics)
154+
}
155+
}
156+
157+
override def getWriter[K, V](
158+
handle: ShuffleHandle,
159+
mapId: Long,
160+
context: TaskContext,
161+
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
162+
163+
if (isArrowShuffle(handle)) {
164+
getBlazeRssShuffleWriter(handle, mapId, context, metrics)
165+
} else {
166+
getRssShuffleWriter(handle, mapId, context, metrics)
167+
}
168+
}
169+
}

spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,11 @@ object NativeConverters extends Logging {
569569
}
570570
val resultType = (lhs.dataType, rhs.dataType) match {
571571
case (lhsType: DecimalType, rhsType: DecimalType) =>
572-
resultDecimalType(lhsType.precision, lhsType.scale, rhsType.precision, rhsType.scale)
572+
resultDecimalType(
573+
lhsType.precision,
574+
lhsType.scale,
575+
rhsType.precision,
576+
rhsType.scale)
573577
}
574578

575579
buildExprNode {
@@ -606,7 +610,11 @@ object NativeConverters extends Logging {
606610
}
607611
val resultType = (lhs.dataType, rhs.dataType) match {
608612
case (lhsType: DecimalType, rhsType: DecimalType) =>
609-
resultDecimalType(lhsType.precision, lhsType.scale, rhsType.precision, rhsType.scale)
613+
resultDecimalType(
614+
lhsType.precision,
615+
lhsType.scale,
616+
rhsType.precision,
617+
rhsType.scale)
610618
}
611619

612620
buildExprNode {
@@ -642,7 +650,11 @@ object NativeConverters extends Logging {
642650
}
643651
val resultType = (lhs.dataType, rhs.dataType) match {
644652
case (lhsType: DecimalType, rhsType: DecimalType) =>
645-
resultDecimalType(lhsType.precision, lhsType.scale, rhsType.precision, rhsType.scale)
653+
resultDecimalType(
654+
lhsType.precision,
655+
lhsType.scale,
656+
rhsType.precision,
657+
rhsType.scale)
646658
}
647659

648660
buildExprNode {
@@ -686,7 +698,11 @@ object NativeConverters extends Logging {
686698
}
687699
val resultType = (lhs.dataType, rhs.dataType) match {
688700
case (lhsType: DecimalType, rhsType: DecimalType) =>
689-
resultDecimalType(lhsType.precision, lhsType.scale, rhsType.precision, rhsType.scale)
701+
resultDecimalType(
702+
lhsType.precision,
703+
lhsType.scale,
704+
rhsType.precision,
705+
rhsType.scale)
690706
}
691707

692708
buildExprNode {
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright 2022 The Blaze Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.apache.spark.sql.execution.blaze.shuffle
17+
18+
import org.apache.spark.TaskContext
19+
import org.apache.spark.shuffle.BaseShuffleHandle
20+
21+
abstract class BlazeRssShuffleReaderBase[K, C](
22+
handle: BaseShuffleHandle[K, _, C],
23+
context: TaskContext)
24+
extends BlazeBlockStoreShuffleReaderBase[K, C](handle, context) {}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright 2022 The Blaze Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.apache.spark.sql.execution.blaze.shuffle
17+
18+
import java.util.UUID
19+
20+
import org.apache.spark.SparkEnv
21+
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
22+
import org.apache.spark.sql.blaze.JniBridge
23+
import org.apache.spark.sql.blaze.NativeHelper
24+
import org.apache.spark.sql.blaze.NativeRDD
25+
import org.apache.spark.sql.blaze.Shims
26+
import org.apache.spark.Partition
27+
import org.apache.spark.ShuffleDependency
28+
import org.apache.spark.TaskContext
29+
import org.apache.spark.scheduler.MapStatus
30+
import org.apache.spark.shuffle.ShuffleHandle
31+
import org.blaze.protobuf.PhysicalPlanNode
32+
import org.blaze.protobuf.RssShuffleWriterExecNode
33+
34+
abstract class BlazeRssShuffleWriterBase[K, V](metrics: ShuffleWriteMetricsReporter)
35+
extends BlazeShuffleWriterBase[K, V](metrics) {
36+
37+
def getRssPartitionWriter(
38+
handle: ShuffleHandle,
39+
mapId: Int,
40+
metrics: ShuffleWriteMetricsReporter,
41+
numPartitions: Int): RssPartitionWriterBase
42+
43+
def nativeRssShuffleWrite(
44+
nativeShuffleRDD: NativeRDD,
45+
dep: ShuffleDependency[_, _, _],
46+
mapId: Int,
47+
context: TaskContext,
48+
partition: Partition,
49+
numPartitions: Int): MapStatus = {
50+
51+
val rssShuffleWriterObject =
52+
getRssPartitionWriter(dep.shuffleHandle, mapId, metrics, numPartitions)
53+
if (rssShuffleWriterObject == null) {
54+
throw new RuntimeException("cannot get RssPartitionWriter")
55+
}
56+
57+
try {
58+
val jniResourceId = s"RssPartitionWriter:${UUID.randomUUID().toString}"
59+
JniBridge.resourcesMap.put(jniResourceId, rssShuffleWriterObject)
60+
val nativeRssShuffleWriterExec = PhysicalPlanNode
61+
.newBuilder()
62+
.setRssShuffleWriter(
63+
RssShuffleWriterExecNode
64+
.newBuilder(nativeShuffleRDD.nativePlan(partition, context).getRssShuffleWriter)
65+
.setRssPartitionWriterResourceId(jniResourceId)
66+
.build())
67+
.build()
68+
69+
val iterator = NativeHelper.executeNativePlan(
70+
nativeRssShuffleWriterExec,
71+
nativeShuffleRDD.metrics,
72+
partition,
73+
Some(context))
74+
assert(iterator.toArray.isEmpty)
75+
} finally {
76+
rssShuffleWriterObject.close()
77+
}
78+
79+
val mapStatus = Shims.get.getMapStatus(
80+
SparkEnv.get.blockManager.shuffleServerId,
81+
rssShuffleWriterObject.getPartitionLengthMap,
82+
mapId)
83+
mapStatus
84+
}
85+
}

spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/RssPartitionWriterBase.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.blaze.shuffle
1818
import java.nio.ByteBuffer
1919

2020
trait RssPartitionWriterBase {
21-
def write(partitionId: Int, buffer: ByteBuffer, length: Int): Unit
21+
def write(partitionId: Int, buffer: ByteBuffer): Unit
2222
def flush(): Unit
2323
def close(): Unit
2424
def getPartitionLengthMap: Array[Long]

0 commit comments

Comments
 (0)