Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,23 +119,34 @@ class CelebornShuffleReader[K, C](
partCnt += 1
val hostPort = location.hostAndFetchPort
if (!workerRequestMap.containsKey(hostPort)) {
val client = shuffleClient.getDataClientFactory().createClient(
location.getHost,
location.getFetchPort)
val pbOpenStreamList = PbOpenStreamList.newBuilder()
pbOpenStreamList.setShuffleKey(shuffleKey)
workerRequestMap.put(
hostPort,
(client, new util.ArrayList[PartitionLocation], pbOpenStreamList))
try {
val client = shuffleClient.getDataClientFactory().createClient(
location.getHost,
location.getFetchPort)
val pbOpenStreamList = PbOpenStreamList.newBuilder()
pbOpenStreamList.setShuffleKey(shuffleKey)
workerRequestMap.put(
hostPort,
(client, new util.ArrayList[PartitionLocation], pbOpenStreamList))
} catch {
case ex: Exception =>
shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort, ex)
logWarning(
s"Failed to create client for $shuffleKey-$partitionId from host: ${location.hostAndFetchPort}. " +
s"Shuffle reader will try its replica if exists.")
}
}
workerRequestMap.get(hostPort) match {
case (_, locArr, pbOpenStreamListBuilder) =>
locArr.add(location)
pbOpenStreamListBuilder.addFileName(location.getFileName)
.addStartIndex(startMapIndex)
.addEndIndex(endMapIndex)
pbOpenStreamListBuilder.addReadLocalShuffle(
localFetchEnabled && location.getHost.equals(localHostAddress))
case _ =>
logDebug(s"Empty client for host ${hostPort}")
}
val (_, locArr, pbOpenStreamListBuilder) = workerRequestMap.get(hostPort)

locArr.add(location)
pbOpenStreamListBuilder.addFileName(location.getFileName)
.addStartIndex(startMapIndex)
.addEndIndex(endMapIndex)
pbOpenStreamListBuilder.addReadLocalShuffle(
localFetchEnabled && location.getHost.equals(localHostAddress))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,4 +285,6 @@ public abstract int getShuffleId(
/**
 * Reports a barrier task failure for the given application shuffle.
 *
 * <p>NOTE(review): the meaning of the boolean result is not visible from this declaration;
 * confirm against the implementations before relying on it.
 *
 * @param appShuffleId id of the application shuffle
 * @param appShuffleIdentifier identifier string of the application shuffle
 */
public abstract boolean reportBarrierTaskFailure(int appShuffleId, String appShuffleIdentifier);

/**
 * Returns the {@link TransportClientFactory} of this shuffle client. Callers use it to create
 * transport clients for a given host and port.
 */
public abstract TransportClientFactory getDataClientFactory();

/**
 * Reports that an operation against the given fetch location failed with {@code e}, so that the
 * implementation may exclude that location. Whether the location is actually excluded is decided
 * by the implementation.
 *
 * @param hostAndFetchPort key identifying the fetch location (as produced by
 *     {@code PartitionLocation.hostAndFetchPort()})
 * @param e the exception that caused the failure
 */
public abstract void excludeFailedFetchLocation(String hostAndFetchPort, Exception e);
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ public class ShuffleClientImpl extends ShuffleClient {
private final Set<String> pushExcludedWorkers = ConcurrentHashMap.newKeySet();
private final ConcurrentHashMap<String, Long> fetchExcludedWorkers =
JavaUtils.newConcurrentHashMap();
private boolean pushReplicateEnabled;
private boolean fetchExcludeWorkerOnFailureEnabled;

private final ExecutorService pushDataRetryPool;

Expand Down Expand Up @@ -180,6 +182,8 @@ public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier u
pushBufferMaxSize = conf.clientPushBufferMaxSize();
pushExcludeWorkerOnFailureEnabled = conf.clientPushExcludeWorkerOnFailureEnabled();
shuffleCompressionEnabled = !conf.shuffleCompressionCodec().equals(CompressionCodec.NONE);
pushReplicateEnabled = conf.clientPushReplicateEnabled();
fetchExcludeWorkerOnFailureEnabled = conf.clientFetchExcludeWorkerOnFailureEnabled();
if (conf.clientPushReplicateEnabled()) {
pushDataTimeout = conf.pushDataTimeoutMs() * 2;
} else {
Expand Down Expand Up @@ -1904,4 +1908,12 @@ private StatusCode getPushDataFailCause(String message) {
/** Returns the {@code dataClientFactory} held by this client. */
public TransportClientFactory getDataClientFactory() {
return dataClientFactory;
}

/**
 * Records the given fetch location in {@code fetchExcludedWorkers} after a failure.
 *
 * <p>The location is recorded only when push replication and fetch-side worker exclusion are both
 * enabled and {@code Utils.isCriticalCauseForFetch(e)} classifies the failure as critical;
 * otherwise this method does nothing.
 *
 * @param hostAndFetchPort key identifying the failed fetch location
 * @param e the exception that caused the failure
 */
@Override
public void excludeFailedFetchLocation(String hostAndFetchPort, Exception e) {
  boolean exclusionEnabled = pushReplicateEnabled && fetchExcludeWorkerOnFailureEnabled;
  if (exclusionEnabled && Utils.isCriticalCauseForFetch(e)) {
    // The stored value is the wall-clock time at which the location was excluded.
    fetchExcludedWorkers.put(hostAndFetchPort, System.currentTimeMillis());
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,6 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
private final boolean enabledReadLocalShuffle;
private final String localHostAddress;

private boolean pushReplicateEnabled;
private boolean fetchExcludeWorkerOnFailureEnabled;
private boolean shuffleCompressionEnabled;
private long fetchExcludedWorkerExpireTimeout;
private ConcurrentHashMap<String, Long> fetchExcludedWorkers;
Expand Down Expand Up @@ -205,8 +203,6 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
this.rangeReadFilter = conf.shuffleRangeReadFilterEnabled();
this.enabledReadLocalShuffle = conf.enableReadLocalShuffleFile();
this.localHostAddress = Utils.localHostName(conf);
this.pushReplicateEnabled = conf.clientPushReplicateEnabled();
this.fetchExcludeWorkerOnFailureEnabled = conf.clientFetchExcludeWorkerOnFailureEnabled();
this.shuffleCompressionEnabled =
!conf.shuffleCompressionCodec().equals(CompressionCodec.NONE);
this.fetchExcludedWorkerExpireTimeout = conf.clientFetchExcludedWorkerExpireTimeout();
Expand Down Expand Up @@ -299,12 +295,6 @@ private void moveToNextReader(boolean fetchChunk) throws IOException {
}
}

/**
 * Marks the worker serving {@code location} as excluded for fetching when the failure {@code e}
 * is a critical cause and both push replication and fetch-side exclusion are enabled.
 */
private void excludeFailedLocation(PartitionLocation location, Exception e) {
  boolean exclusionActive = pushReplicateEnabled && fetchExcludeWorkerOnFailureEnabled;
  if (!exclusionActive || !isCriticalCause(e)) {
    return;
  }
  // Key by host and fetch port; the value is the time of exclusion.
  long excludedAt = System.currentTimeMillis();
  fetchExcludedWorkers.put(location.hostAndFetchPort(), excludedAt);
}

private boolean isExcluded(PartitionLocation location) {
Long timestamp = fetchExcludedWorkers.get(location.hostAndFetchPort());
if (timestamp == null) {
Expand Down Expand Up @@ -354,7 +344,7 @@ private PartitionReader createReaderWithRetry(
return createReader(location, pbStreamHandler, fetchChunkRetryCnt, fetchChunkMaxRetry);
} catch (Exception e) {
lastException = e;
excludeFailedLocation(location, e);
shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort(), e);
fetchChunkRetryCnt++;
if (location.hasPeer()) {
// fetchChunkRetryCnt % 2 == 0 means both replicas have been tried,
Expand Down Expand Up @@ -392,7 +382,8 @@ private ByteBuf getNextChunk() throws IOException {
}
return currentReader.next();
} catch (Exception e) {
excludeFailedLocation(currentReader.getLocation(), e);
shuffleClient.excludeFailedFetchLocation(
currentReader.getLocation().hostAndFetchPort(), e);
fetchChunkRetryCnt++;
currentReader.close();
if (fetchChunkRetryCnt == fetchChunkMaxRetry) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ public TransportClientFactory getDataClientFactory() {
return null;
}

/** No-op implementation: the reported fetch failure is ignored. */
@Override
public void excludeFailedFetchLocation(String hostAndFetchPort, Exception e) {}

public void initReducePartitionMap(int shuffleId, int numPartitions, int workerNum) {
ConcurrentHashMap<Integer, PartitionLocation> map = JavaUtils.newConcurrentHashMap();
String host = "host";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import org.roaringbitmap.RoaringBitmap

import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.CelebornConf.PORT_MAX_RETRY
import org.apache.celeborn.common.exception.CelebornException
import org.apache.celeborn.common.exception.{CelebornException, CelebornIOException}
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{DiskStatus, WorkerInfo}
import org.apache.celeborn.common.network.protocol.TransportMessage
Expand Down Expand Up @@ -1343,4 +1343,16 @@ object Utils extends Logging {
throw e
}
}

/**
 * Decides whether a fetch failure counts as critical.
 *
 * An exception is critical when at least one of the following holds:
 *  - it is an `IOException` whose cause is a `TimeoutException`;
 *  - it is a `CelebornIOException` whose message starts with "Connecting to" or "Failed to";
 *  - it is a `CelebornIOException` whose cause is an `IOException`.
 *
 * @param e the exception raised by the failed fetch; `null` yields `false`
 * @return `true` if the failure matches one of the critical patterns above
 */
def isCriticalCauseForFetch(e: Exception): Boolean = {
  // `isInstanceOf` is false for null, so a missing cause needs no separate null check.
  val rpcTimeout = e match {
    case io: IOException => io.getCause.isInstanceOf[TimeoutException]
    case _ => false
  }
  val connectException = e match {
    case celebornIo: CelebornIOException =>
      Option(celebornIo.getMessage).exists { msg =>
        msg.startsWith("Connecting to") || msg.startsWith("Failed to")
      }
    case _ => false
  }
  val fetchChunkTimeout = e match {
    case celebornIo: CelebornIOException => celebornIo.getCause.isInstanceOf[IOException]
    case _ => false
  }
  connectException || rpcTimeout || fetchChunkTimeout
}

}