Skip to content

Commit 759ee12

Browse files
LeshengJin and sunggg authored
[Support] Add Interrupt Handling in Pipe (#16255)
Co-authored-by: Sunghyun Park <sunggg@umich.edu>
1 parent 943861c commit 759ee12

3 files changed

Lines changed: 117 additions & 59 deletions

File tree

src/support/errno_handling.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file errno_handling.h
22+
* \brief Common error number handling functions for socket.h and pipe.h
23+
*/
24+
#ifndef TVM_SUPPORT_ERRNO_HANDLING_H_
25+
#define TVM_SUPPORT_ERRNO_HANDLING_H_
26+
#include <errno.h>
27+
28+
#include "ssize.h"
29+
30+
namespace tvm {
31+
namespace support {
32+
/*!
33+
* \brief Call a function and retry if an EINTR error is encountered.
34+
*
35+
* Socket and pipe operations can return EINTR when an interrupt handler
36+
* is registered by the execution environment (e.g. python).
37+
* We should retry if there is no KeyboardInterrupt recorded in
38+
* the environment.
39+
*
40+
* \note This function is needed to avoid rare interrupt event
41+
* in long running server code.
42+
*
43+
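* \param func The function to retry; it signals failure by returning -1.
* \param fgeterrorcode Function returning the error code of the last failed call.
44+
* \return The value returned by func, or -1 if func failed with an error other than EINTR.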
45+
*/
46+
template <typename FuncType, typename GetErrorCodeFuncType>
47+
inline ssize_t RetryCallOnEINTR(FuncType func, GetErrorCodeFuncType fgeterrorcode) {
48+
ssize_t ret = func();
49+
// common path: the first call did not fail (-1 is the failure sentinel)
50+
if (ret != -1) return ret;
51+
// less common path: the call failed; keep retrying while the error is EINTR
52+
do {
53+
if (fgeterrorcode() == EINTR) {
54+
// The call was interrupted. Call into env check signals to see if there
55+
// are environment specific (e.g. python) signal exceptions.
56+
// That function is expected to throw an exception
57+
// if the process received a signal that requires TVM to return immediately (e.g. SIGINT).
// NOTE(review): this header only includes <errno.h> and "ssize.h"; the declaration of
// runtime::EnvCheckSignals presumably comes from the including file -- confirm.
58+
runtime::EnvCheckSignals();
59+
} else {
60+
// other errors: not retryable, hand the failing return value (-1) back to the caller
61+
return ret;
62+
}
// no pending signal exception was raised above, so issue the call again
63+
ret = func();
64+
} while (ret == -1);
// the retried call finally returned something other than -1
65+
return ret;
66+
}
67+
} // namespace support
68+
} // namespace tvm
69+
#endif // TVM_SUPPORT_ERRNO_HANDLING_H_

src/support/pipe.h

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include <cstdlib>
3737
#include <cstring>
3838
#endif
39+
#include "errno_handling.h"
3940

4041
namespace tvm {
4142
namespace support {
@@ -52,8 +53,21 @@ class Pipe : public dmlc::Stream {
5253
#endif
5354
/*! \brief destructor */
5455
~Pipe() { Flush(); }
56+
5557
using Stream::Read;
5658
using Stream::Write;
59+
60+
/*!
61+
* \return last error of pipe operation
62+
*/
63+
// Returns the platform-specific error code of the most recent failed call:
// GetLastError() when _WIN32 is defined, errno otherwise.
static int GetLastErrorCode() {
64+
#ifdef _WIN32
65+
// NOTE(review): this is a Win32 error code, not an errno value; RetryCallOnEINTR
// compares the result against EINTR -- confirm that comparison is meaningful on Windows.
return GetLastError();
66+
#else
67+
return errno;
68+
#endif
69+
}
70+
5771
/*!
5872
* \brief reads data from a file descriptor
5973
* \param ptr pointer to a memory buffer
@@ -63,12 +77,15 @@ class Pipe : public dmlc::Stream {
6377
size_t Read(void* ptr, size_t size) final {
6478
if (size == 0) return 0;
6579
#ifdef _WIN32
66-
DWORD nread;
67-
ICHECK(ReadFile(handle_, static_cast<TCHAR*>(ptr), size, &nread, nullptr))
68-
<< "Read Error: " << GetLastError();
80+
auto fread = [&]() {
81+
DWORD nread;
82+
if (!ReadFile(handle_, static_cast<TCHAR*>(ptr), size, &nread, nullptr)) return -1;
83+
return nread;
84+
};
85+
DWORD nread = static_cast<DWORD>(RetryCallOnEINTR(fread, GetLastErrorCode));
86+
ICHECK_EQ(static_cast<size_t>(nread), size) << "Read Error: " << GetLastError();
6987
#else
70-
ssize_t nread;
71-
nread = read(handle_, ptr, size);
88+
ssize_t nread = RetryCallOnEINTR([&]() { return read(handle_, ptr, size); }, GetLastErrorCode);
7289
ICHECK_GE(nread, 0) << "Write Error: " << strerror(errno);
7390
#endif
7491
return static_cast<size_t>(nread);
@@ -82,13 +99,16 @@ class Pipe : public dmlc::Stream {
8299
void Write(const void* ptr, size_t size) final {
83100
if (size == 0) return;
84101
#ifdef _WIN32
85-
DWORD nwrite;
86-
ICHECK(WriteFile(handle_, static_cast<const TCHAR*>(ptr), size, &nwrite, nullptr) &&
87-
static_cast<size_t>(nwrite) == size)
88-
<< "Write Error: " << GetLastError();
102+
auto fwrite = [&]() {
103+
DWORD nwrite;
104+
if (!WriteFile(handle_, static_cast<const TCHAR*>(ptr), size, &nwrite, nullptr)) return -1;
105+
return nwrite;
106+
};
107+
DWORD nwrite = static_cast<DWORD>(RetryCallOnEINTR(fwrite, GetLastErrorCode));
108+
ICHECK_EQ(static_cast<size_t>(nwrite), size) << "Write Error: " << GetLastError();
89109
#else
90-
ssize_t nwrite;
91-
nwrite = write(handle_, ptr, size);
110+
ssize_t nwrite =
111+
RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode);
92112
ICHECK_EQ(static_cast<size_t>(nwrite), size) << "Write Error: " << strerror(errno);
93113
#endif
94114
}

src/support/socket.h

Lines changed: 17 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
#endif
4040
#else
4141
#include <arpa/inet.h>
42-
#include <errno.h>
4342
#include <fcntl.h>
4443
#include <netdb.h>
4544
#include <netinet/in.h>
@@ -56,8 +55,9 @@
5655
#include <unordered_map>
5756
#include <vector>
5857

59-
#include "../support/ssize.h"
60-
#include "../support/utils.h"
58+
#include "errno_handling.h"
59+
#include "ssize.h"
60+
#include "utils.h"
6161

6262
#if defined(_WIN32)
6363
static inline int poll(struct pollfd* pfd, int nfds, int timeout) {
@@ -310,7 +310,7 @@ class Socket {
310310
/*!
311311
* \return last error of socket operation
312312
*/
313-
static int GetLastError() {
313+
static int GetLastErrorCode() {
314314
#ifdef _WIN32
315315
return WSAGetLastError();
316316
#else
@@ -319,7 +319,7 @@ class Socket {
319319
}
320320
/*! \return whether last error was would block */
321321
static bool LastErrorWouldBlock() {
322-
int errsv = GetLastError();
322+
int errsv = GetLastErrorCode();
323323
#ifdef _WIN32
324324
return errsv == WSAEWOULDBLOCK;
325325
#else
@@ -355,50 +355,14 @@ class Socket {
355355
* \param msg The error message.
356356
*/
357357
static void Error(const char* msg) {
358-
int errsv = GetLastError();
358+
int errsv = GetLastErrorCode();
359359
#ifdef _WIN32
360360
LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv;
361361
#else
362362
LOG(FATAL) << "Socket " << msg << " Error:" << strerror(errsv);
363363
#endif
364364
}
365365

366-
/*!
367-
* \brief Call a function and retry if an EINTR error is encountered.
368-
*
369-
* Socket operations can return EINTR when the interrupt handler
370-
* is registered by the execution environment(e.g. python).
371-
* We should retry if there is no KeyboardInterrupt recorded in
372-
* the environment.
373-
*
374-
* \note This function is needed to avoid rare interrupt event
375-
* in long running server code.
376-
*
377-
* \param func The function to retry.
378-
* \return The return code returned by function f or error_value on retry failure.
379-
*/
380-
template <typename FuncType>
381-
ssize_t RetryCallOnEINTR(FuncType func) {
382-
ssize_t ret = func();
383-
// common path
384-
if (ret != -1) return ret;
385-
// less common path
386-
do {
387-
if (GetLastError() == EINTR) {
388-
// Call into env check signals to see if there are
389-
// environment specific(e.g. python) signal exceptions.
390-
// This function will throw an exception if there is
391-
// if the process received a signal that requires TVM to return immediately (e.g. SIGINT).
392-
runtime::EnvCheckSignals();
393-
} else {
394-
// other errors
395-
return ret;
396-
}
397-
ret = func();
398-
} while (ret == -1);
399-
return ret;
400-
}
401-
402366
protected:
403367
explicit Socket(SockType sockfd) : sockfd(sockfd) {}
404368
};
@@ -445,7 +409,8 @@ class TCPSocket : public Socket {
445409
* \return The accepted socket connection.
446410
*/
447411
TCPSocket Accept() {
448-
SockType newfd = RetryCallOnEINTR([&]() { return accept(sockfd, nullptr, nullptr); });
412+
SockType newfd =
413+
RetryCallOnEINTR([&]() { return accept(sockfd, nullptr, nullptr); }, GetLastErrorCode);
449414
if (newfd == INVALID_SOCKET) {
450415
Socket::Error("Accept");
451416
}
@@ -459,7 +424,8 @@ class TCPSocket : public Socket {
459424
TCPSocket Accept(SockAddr* addr) {
460425
socklen_t addrlen = sizeof(addr->addr);
461426
SockType newfd = RetryCallOnEINTR(
462-
[&]() { return accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr), &addrlen); });
427+
[&]() { return accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr), &addrlen); },
428+
GetLastErrorCode);
463429
if (newfd == INVALID_SOCKET) {
464430
Socket::Error("Accept");
465431
}
@@ -500,7 +466,7 @@ class TCPSocket : public Socket {
500466
ssize_t Send(const void* buf_, size_t len, int flag = 0) {
501467
const char* buf = reinterpret_cast<const char*>(buf_);
502468
return RetryCallOnEINTR(
503-
[&]() { return send(sockfd, buf, static_cast<sock_size_t>(len), flag); });
469+
[&]() { return send(sockfd, buf, static_cast<sock_size_t>(len), flag); }, GetLastErrorCode);
504470
}
505471
/*!
506472
* \brief receive data using the socket
@@ -513,7 +479,8 @@ class TCPSocket : public Socket {
513479
ssize_t Recv(void* buf_, size_t len, int flags = 0) {
514480
char* buf = reinterpret_cast<char*>(buf_);
515481
return RetryCallOnEINTR(
516-
[&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len), flags); });
482+
[&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len), flags); },
483+
GetLastErrorCode);
517484
}
518485
/*!
519486
* \brief perform block write that will attempt to send all data out
@@ -527,7 +494,8 @@ class TCPSocket : public Socket {
527494
size_t ndone = 0;
528495
while (ndone < len) {
529496
ssize_t ret = RetryCallOnEINTR(
530-
[&]() { return send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0); });
497+
[&]() { return send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0); },
498+
GetLastErrorCode);
531499
if (ret == -1) {
532500
if (LastErrorWouldBlock()) return ndone;
533501
Socket::Error("SendAll");
@@ -549,7 +517,8 @@ class TCPSocket : public Socket {
549517
size_t ndone = 0;
550518
while (ndone < len) {
551519
ssize_t ret = RetryCallOnEINTR(
552-
[&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL); });
520+
[&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL); },
521+
GetLastErrorCode);
553522
if (ret == -1) {
554523
if (LastErrorWouldBlock()) {
555524
LOG(FATAL) << "would block";

0 commit comments

Comments
 (0)