@@ -98,14 +98,15 @@ class RPCServer {
9898 * \brief Constructor.
9999 */
100100 RPCServer (std::string host, int port, int port_end, std::string tracker_addr, std::string key,
101- std::string custom_addr)
101+ std::string custom_addr, std::string work_dir )
102102 : host_(std::move(host)),
103103 port_ (port),
104104 my_port_(0 ),
105105 port_end_(port_end),
106106 tracker_addr_(std::move(tracker_addr)),
107107 key_(std::move(key)),
108- custom_addr_(std::move(custom_addr)) {}
108+ custom_addr_(std::move(custom_addr)),
109+ work_dir_(std::move(work_dir)) {}
109110
110111 /* !
111112 * \brief Destructor.
@@ -174,7 +175,7 @@ class RPCServer {
174175 const pid_t worker_pid = fork ();
175176 if (worker_pid == 0 ) {
176177 // Worker process
177- ServerLoopProc (conn, addr);
178+ ServerLoopProc (conn, addr, work_dir_ );
178179 _exit (0 );
179180 }
180181
@@ -201,7 +202,7 @@ class RPCServer {
201202 } else {
202203 auto pid = fork ();
203204 if (pid == 0 ) {
204- ServerLoopProc (conn, addr);
205+ ServerLoopProc (conn, addr, work_dir_ );
205206 exit (0 );
206207 }
207208 // Wait for the result
@@ -308,9 +309,10 @@ class RPCServer {
308309 * \param sock The socket information
309310 * \param addr The socket address information
310311 */
311- static void ServerLoopProc (support::TCPSocket sock, support::SockAddr addr) {
312+ static void ServerLoopProc (support::TCPSocket sock, support::SockAddr addr,
313+ std::string work_dir) {
312314 // Server loop
313- const auto env = RPCEnv ();
315+ const auto env = RPCEnv (work_dir );
314316 RPCServerLoop (int (sock.sockfd ));
315317 LOG (INFO) << " Finish serving " << addr.AsString ();
316318 env.CleanUp ();
@@ -339,6 +341,7 @@ class RPCServer {
339341 std::string tracker_addr_;
340342 std::string key_;
341343 std::string custom_addr_;
344+ std::string work_dir_;
342345 support::TCPSocket listen_sock_;
343346 support::TCPSocket tracker_sock_;
344347};
@@ -370,19 +373,19 @@ void ServerLoopFromChild(SOCKET socket) {
370373 * silent mode. Default=True
371374 */
372375void RPCServerCreate (std::string host, int port, int port_end, std::string tracker_addr,
373- std::string key, std::string custom_addr, bool silent) {
376+ std::string key, std::string custom_addr, std::string work_dir, bool silent) {
374377 if (silent) {
375378 // Only errors and fatal is logged
376379 dmlc::InitLogging (" --minloglevel=2" );
377380 }
378381 // Start the rpc server
379382 RPCServer rpc (std::move (host), port, port_end, std::move (tracker_addr), std::move (key),
380- std::move (custom_addr));
383+ std::move (custom_addr), std::move (work_dir) );
381384 rpc.Start ();
382385}
383386
384387TVM_REGISTER_GLOBAL (" rpc.ServerCreate" ).set_body([](TVMArgs args, TVMRetValue* rv) {
385- RPCServerCreate (args[0 ], args[1 ], args[2 ], args[3 ], args[4 ], args[5 ], args[6 ]);
388+ RPCServerCreate (args[0 ], args[1 ], args[2 ], args[3 ], args[4 ], args[5 ], args[6 ], args[ 7 ] );
386389});
387390} // namespace runtime
388391} // namespace tvm
0 commit comments