diff --git a/pom.xml b/pom.xml
index d847e35a1..83e9e8b3f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -272,7 +272,7 @@
3.2.9
3.0.0
3.0.3
- 0.5.1
+ 0.5.2
@@ -287,7 +287,7 @@
3.2.9
3.0.0
3.1.3
- 0.5.1
+ 0.5.2
@@ -302,7 +302,7 @@
3.2.9
3.0.0
3.2.4
- 0.5.1
+ 0.5.2
@@ -317,7 +317,7 @@
3.2.9
3.0.0
3.3.4
- 0.5.1
+ 0.5.2
@@ -332,7 +332,7 @@
3.2.9
3.0.0
3.4.4
- 0.5.1
+ 0.5.2
@@ -347,7 +347,7 @@
3.2.9
3.0.0
3.5.3
- 0.5.1
+ 0.5.2
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleReader.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleReader.scala
index 020f06d1c..2d4bfb7ec 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleReader.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/celeborn/BlazeCelebornShuffleReader.scala
@@ -21,10 +21,9 @@ import java.util
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference
-
import scala.collection.JavaConverters._
-
import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups
import org.apache.celeborn.client.read.CelebornInputStream
import org.apache.celeborn.client.read.MetricsCallback
import org.apache.celeborn.common.CelebornConf
@@ -115,7 +114,16 @@ class BlazeCelebornShuffleReader[K, C](
val localHostAddress = Utils.localHostName(conf)
val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
// startPartition is irrelevant
- val fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
+ var fileGroups: ReduceFileGroups = null
+ try {
+ // startPartition is irrelevant
+ fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
+ } catch {
+ case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
+ handleFetchExceptions(shuffleId, 0, ce)
+ case e: Throwable => throw e
+ }
+
// host-port -> (TransportClient, PartitionLocation Array, PbOpenStreamList)
val workerRequestMap = new util.HashMap[
String,
@@ -129,23 +137,34 @@ class BlazeCelebornShuffleReader[K, C](
partCnt += 1
val hostPort = location.hostAndFetchPort
if (!workerRequestMap.containsKey(hostPort)) {
- val client = shuffleClient
- .getDataClientFactory()
- .createClient(location.getHost, location.getFetchPort)
- val pbOpenStreamList = PbOpenStreamList.newBuilder()
- pbOpenStreamList.setShuffleKey(shuffleKey)
- workerRequestMap
- .put(hostPort, (client, new util.ArrayList[PartitionLocation], pbOpenStreamList))
+ try {
+ val client = shuffleClient.getDataClientFactory().createClient(
+ location.getHost,
+ location.getFetchPort)
+ val pbOpenStreamList = PbOpenStreamList.newBuilder()
+ pbOpenStreamList.setShuffleKey(shuffleKey)
+ workerRequestMap.put(
+ hostPort,
+ (client, new util.ArrayList[PartitionLocation], pbOpenStreamList))
+ } catch {
+ case ex: Exception =>
+ shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort, ex)
+ logWarning(
+ s"Failed to create client for $shuffleKey-$partitionId from host: ${location.hostAndFetchPort}. " +
+ s"Shuffle reader will try its replica if exists.")
+ }
+ }
+ workerRequestMap.get(hostPort) match {
+ case (_, locArr, pbOpenStreamListBuilder) =>
+ locArr.add(location)
+ pbOpenStreamListBuilder.addFileName(location.getFileName)
+ .addStartIndex(startMapIndex.getOrElse(0))
+ .addEndIndex(endMapIndex.getOrElse(Int.MaxValue))
+ pbOpenStreamListBuilder.addReadLocalShuffle(
+ localFetchEnabled && location.getHost.equals(localHostAddress))
+ case _ =>
+ logDebug(s"Empty client for host ${hostPort}")
}
- val (_, locArr, pbOpenStreamListBuilder) = workerRequestMap.get(hostPort)
-
- locArr.add(location)
- pbOpenStreamListBuilder
- .addFileName(location.getFileName)
- .addStartIndex(startMapIndex.getOrElse(0))
- .addEndIndex(endMapIndex.getOrElse(Int.MaxValue))
- pbOpenStreamListBuilder.addReadLocalShuffle(
- localFetchEnabled && location.getHost.equals(localHostAddress))
}
}
}
@@ -266,18 +285,7 @@ class BlazeCelebornShuffleReader[K, C](
if (exceptionRef.get() != null) {
exceptionRef.get() match {
case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
- if (throwsFetchFailure &&
- shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
- throw new FetchFailedException(
- null,
- handle.shuffleId,
- -1,
- -1,
- partitionId,
- SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + shuffleId,
- ce)
- } else
- throw ce
+ handleFetchExceptions(shuffleId, partitionId, ce) // Celeborn-side shuffle id
case e => throw e
}
}
@@ -312,4 +320,20 @@ class BlazeCelebornShuffleReader[K, C](
recordIter.map(block => (null, block._2)), // blockId is not used
() => context.taskMetrics().mergeShuffleReadMetrics())
}
+
+  /** Always throws: a FetchFailedException carrying `ce` if the failure is reported, else `ce`. */
+  private def handleFetchExceptions(shuffleId: Int, partitionId: Int, ce: Throwable) = {
+    // `&&` short-circuits, so the client is only asked to report the failure when
+    // throwsFetchFailure is set.
+    val reported =
+      throwsFetchFailure && shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)
+    if (!reported) {
+      // Not reported: propagate the original exception unchanged.
+      throw ce
+    }
+    logWarning(s"Handle fetch exceptions for ${shuffleId}-${partitionId}", ce)
+    val message = SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + shuffleId
+    // Reported: surface the failure as a FetchFailedException built around `ce`.
+    throw new FetchFailedException(null, handle.shuffleId, -1, -1, partitionId, message, ce)
+  }
}