Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@
<scalaTestVersion>3.2.9</scalaTestVersion>
<scalafmtVersion>3.0.0</scalafmtVersion>
<sparkVersion>3.0.3</sparkVersion>
<celebornVersion>0.5.1</celebornVersion>
<celebornVersion>0.5.2</celebornVersion>
</properties>
</profile>

Expand All @@ -287,7 +287,7 @@
<scalaTestVersion>3.2.9</scalaTestVersion>
<scalafmtVersion>3.0.0</scalafmtVersion>
<sparkVersion>3.1.3</sparkVersion>
<celebornVersion>0.5.1</celebornVersion>
<celebornVersion>0.5.2</celebornVersion>
</properties>
</profile>

Expand All @@ -302,7 +302,7 @@
<scalaTestVersion>3.2.9</scalaTestVersion>
<scalafmtVersion>3.0.0</scalafmtVersion>
<sparkVersion>3.2.4</sparkVersion>
<celebornVersion>0.5.1</celebornVersion>
<celebornVersion>0.5.2</celebornVersion>
</properties>
</profile>

Expand All @@ -317,7 +317,7 @@
<scalaTestVersion>3.2.9</scalaTestVersion>
<scalafmtVersion>3.0.0</scalafmtVersion>
<sparkVersion>3.3.4</sparkVersion>
<celebornVersion>0.5.1</celebornVersion>
<celebornVersion>0.5.2</celebornVersion>
</properties>
</profile>

Expand All @@ -332,7 +332,7 @@
<scalaTestVersion>3.2.9</scalaTestVersion>
<scalafmtVersion>3.0.0</scalafmtVersion>
<sparkVersion>3.4.4</sparkVersion>
<celebornVersion>0.5.1</celebornVersion>
<celebornVersion>0.5.2</celebornVersion>
</properties>
</profile>

Expand All @@ -347,7 +347,7 @@
<scalaTestVersion>3.2.9</scalaTestVersion>
<scalafmtVersion>3.0.0</scalafmtVersion>
<sparkVersion>3.5.3</sparkVersion>
<celebornVersion>0.5.1</celebornVersion>
<celebornVersion>0.5.2</celebornVersion>
</properties>
</profile>
</profiles>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@ import java.util
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference

import scala.collection.JavaConverters._

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups
import org.apache.celeborn.client.read.CelebornInputStream
import org.apache.celeborn.client.read.MetricsCallback
import org.apache.celeborn.common.CelebornConf
Expand Down Expand Up @@ -115,7 +114,16 @@ class BlazeCelebornShuffleReader[K, C](
val localHostAddress = Utils.localHostName(conf)
val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
// startPartition is irrelevant
val fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
var fileGroups: ReduceFileGroups = null
try {
// startPartition is irrelevant
fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
} catch {
case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
handleFetchExceptions(shuffleId, 0, ce)
case e: Throwable => throw e
}

// host-port -> (TransportClient, PartitionLocation Array, PbOpenStreamList)
val workerRequestMap = new util.HashMap[
String,
Expand All @@ -129,23 +137,34 @@ class BlazeCelebornShuffleReader[K, C](
partCnt += 1
val hostPort = location.hostAndFetchPort
if (!workerRequestMap.containsKey(hostPort)) {
val client = shuffleClient
.getDataClientFactory()
.createClient(location.getHost, location.getFetchPort)
val pbOpenStreamList = PbOpenStreamList.newBuilder()
pbOpenStreamList.setShuffleKey(shuffleKey)
workerRequestMap
.put(hostPort, (client, new util.ArrayList[PartitionLocation], pbOpenStreamList))
try {
val client = shuffleClient.getDataClientFactory().createClient(
location.getHost,
location.getFetchPort)
val pbOpenStreamList = PbOpenStreamList.newBuilder()
pbOpenStreamList.setShuffleKey(shuffleKey)
workerRequestMap.put(
hostPort,
(client, new util.ArrayList[PartitionLocation], pbOpenStreamList))
} catch {
case ex: Exception =>
shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort, ex)
logWarning(
s"Failed to create client for $shuffleKey-$partitionId from host: ${location.hostAndFetchPort}. " +
s"Shuffle reader will try its replica if exists.")
}
}
workerRequestMap.get(hostPort) match {
case (_, locArr, pbOpenStreamListBuilder) =>
locArr.add(location)
pbOpenStreamListBuilder.addFileName(location.getFileName)
.addStartIndex(startMapIndex.getOrElse(0))
.addEndIndex(endMapIndex.getOrElse(Int.MaxValue))
pbOpenStreamListBuilder.addReadLocalShuffle(
localFetchEnabled && location.getHost.equals(localHostAddress))
case _ =>
logDebug(s"Empty client for host ${hostPort}")
}
val (_, locArr, pbOpenStreamListBuilder) = workerRequestMap.get(hostPort)

locArr.add(location)
pbOpenStreamListBuilder
.addFileName(location.getFileName)
.addStartIndex(startMapIndex.getOrElse(0))
.addEndIndex(endMapIndex.getOrElse(Int.MaxValue))
pbOpenStreamListBuilder.addReadLocalShuffle(
localFetchEnabled && location.getHost.equals(localHostAddress))
}
}
}
Expand Down Expand Up @@ -266,18 +285,7 @@ class BlazeCelebornShuffleReader[K, C](
if (exceptionRef.get() != null) {
exceptionRef.get() match {
case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
if (throwsFetchFailure &&
shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
throw new FetchFailedException(
null,
handle.shuffleId,
-1,
-1,
partitionId,
SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + shuffleId,
ce)
} else
throw ce
handleFetchExceptions(handle.shuffleId, partitionId, ce)
case e => throw e
}
}
Expand Down Expand Up @@ -312,4 +320,20 @@ class BlazeCelebornShuffleReader[K, C](
recordIter.map(block => (null, block._2)), // blockId is not used
() => context.taskMetrics().mergeShuffleReadMetrics())
}

private def handleFetchExceptions(shuffleId: Int, partitionId: Int, ce: Throwable) = {
if (throwsFetchFailure &&
shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
logWarning(s"Handle fetch exceptions for ${shuffleId}-${partitionId}", ce)
throw new FetchFailedException(
null,
handle.shuffleId,
-1,
-1,
partitionId,
SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + shuffleId,
ce)
} else
throw ce
}
}