Skip to content

Commit 04f0f41

Browse files
Andrew Zhao LuoAndrew Zhao Luo
authored andcommitted
Merge branch 'main' into andrewzhaoluo-add-cumprod
* main: [AutoScheduler] Add function name in message (apache#7703) [Vulkan] Workaround for zero size allocation (apache#7691) Change behavior of onnx importer to throw when user provides an input no in the graph. (apache#7699) Free TensorRT engine and context (apache#7702) [TFLite] Cast operator adapted for MLIR-based convertor (apache#7639) [CPP_RPC] allow user supplied work dir (apache#7670) Default value for graph_runtime Init lookup_linked_param_func (apache#7676)
2 parents 78ee787 + aa494cf commit 04f0f41

16 files changed

Lines changed: 162 additions & 71 deletions

File tree

apps/cpp_rpc/main.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ static const string kUsage =
5555
"--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n"
5656
"--key - The key used to identify the device type in tracker. Default=\"\"\n"
5757
"--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n"
58+
"--work-dir - Custom work directory. Default=\"\"\n"
5859
"--silent - Whether to run in silent mode. Default=False\n"
5960
"\n"
6061
" Example\n"
@@ -70,6 +71,7 @@ static const string kUsage =
7071
* \arg tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
7172
* \arg key The key used to identify the device type in tracker. Default=""
7273
* \arg custom_addr Custom IP Address to Report to RPC Tracker. Default=""
74+
* \arg work_dir Custom work directory. Default=""
7375
* \arg silent Whether run in silent mode. Default=False
7476
*/
7577
struct RpcServerArgs {
@@ -79,6 +81,7 @@ struct RpcServerArgs {
7981
string tracker;
8082
string key;
8183
string custom_addr;
84+
string work_dir;
8285
bool silent = false;
8386
#if defined(WIN32)
8487
std::string mmap_path;
@@ -96,6 +99,7 @@ void PrintArgs(const RpcServerArgs& args) {
9699
LOG(INFO) << "tracker = " << args.tracker;
97100
LOG(INFO) << "key = " << args.key;
98101
LOG(INFO) << "custom_addr = " << args.custom_addr;
102+
LOG(INFO) << "work_dir = " << args.work_dir;
99103
LOG(INFO) << "silent = " << ((args.silent) ? ("True") : ("False"));
100104
}
101105

@@ -238,6 +242,10 @@ void ParseCmdArgs(int argc, char* argv[], struct RpcServerArgs& args) {
238242
dmlc::InitLogging("--minloglevel=0");
239243
}
240244
#endif
245+
const string work_dir = GetCmdOption(argc, argv, "--work-dir=");
246+
if (!work_dir.empty()) {
247+
args.work_dir = work_dir;
248+
}
241249
}
242250

243251
/*!
@@ -274,7 +282,7 @@ int RpcServer(int argc, char* argv[]) {
274282
#endif
275283

276284
RPCServerCreate(args.host, args.port, args.port_end, args.tracker, args.key, args.custom_addr,
277-
args.silent);
285+
args.work_dir, args.silent);
278286
return 0;
279287
}
280288

apps/cpp_rpc/rpc_env.cc

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ int mkdir(const char* path, int /* ignored */) { return _mkdir(path); }
3939
#include <iostream>
4040
#include <string>
4141
#include <vector>
42-
4342
#include "../../src/support/utils.h"
4443
#include "rpc_env.h"
4544

@@ -85,25 +84,31 @@ void CleanDir(const std::string& dirname);
8584
*/
8685
std::string BuildSharedLibrary(std::string file_in);
8786

88-
RPCEnv::RPCEnv() {
87+
RPCEnv::RPCEnv(const std::string& wd) {
88+
if (wd != "") {
89+
base_ = wd + "/.cache";
90+
mkdir(wd.c_str(), 0777);
91+
mkdir(base_.c_str(), 0777);
92+
} else {
8993
#if defined(ANDROID) || defined(__ANDROID__)
90-
char cwd[PATH_MAX];
91-
auto cmdline = fopen("/proc/self/cmdline", "r");
92-
fread(cwd, 1, sizeof(cwd), cmdline);
93-
fclose(cmdline);
94-
base_ = "/data/data/" + std::string(cwd) + "/cache/rpc";
94+
char cwd[PATH_MAX];
95+
auto cmdline = fopen("/proc/self/cmdline", "r");
96+
fread(cwd, 1, sizeof(cwd), cmdline);
97+
fclose(cmdline);
98+
base_ = "/data/data/" + std::string(cwd) + "/cache/rpc";
9599
#elif !defined(_WIN32)
96-
char cwd[PATH_MAX];
97-
if (getcwd(cwd, sizeof(cwd))) {
98-
base_ = std::string(cwd) + "/rpc";
99-
} else {
100-
base_ = "./rpc";
101-
}
100+
char cwd[PATH_MAX];
101+
if (getcwd(cwd, sizeof(cwd))) {
102+
base_ = std::string(cwd) + "/rpc";
103+
} else {
104+
base_ = "./rpc";
105+
}
102106
#else
103-
base_ = "./rpc";
107+
base_ = "./rpc";
104108
#endif
109+
mkdir(base_.c_str(), 0777);
110+
}
105111

106-
mkdir(base_.c_str(), 0777);
107112
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([this](TVMArgs args, TVMRetValue* rv) {
108113
*rv = this->GetPath(args[0]);
109114
});

apps/cpp_rpc/rpc_env.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ struct RPCEnv {
3939
/*!
4040
* \brief Constructor Init The RPC Environment initialize function
4141
*/
42-
RPCEnv();
42+
RPCEnv(const std::string& word_dir = "");
4343
/*!
4444
* \brief GetPath To get the workpath from packed function
4545
* \param name The file name

apps/cpp_rpc/rpc_server.cc

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,15 @@ class RPCServer {
9898
* \brief Constructor.
9999
*/
100100
RPCServer(std::string host, int port, int port_end, std::string tracker_addr, std::string key,
101-
std::string custom_addr)
101+
std::string custom_addr, std::string work_dir)
102102
: host_(std::move(host)),
103103
port_(port),
104104
my_port_(0),
105105
port_end_(port_end),
106106
tracker_addr_(std::move(tracker_addr)),
107107
key_(std::move(key)),
108-
custom_addr_(std::move(custom_addr)) {}
108+
custom_addr_(std::move(custom_addr)),
109+
work_dir_(std::move(work_dir)) {}
109110

110111
/*!
111112
* \brief Destructor.
@@ -174,7 +175,7 @@ class RPCServer {
174175
const pid_t worker_pid = fork();
175176
if (worker_pid == 0) {
176177
// Worker process
177-
ServerLoopProc(conn, addr);
178+
ServerLoopProc(conn, addr, work_dir_);
178179
_exit(0);
179180
}
180181

@@ -201,7 +202,7 @@ class RPCServer {
201202
} else {
202203
auto pid = fork();
203204
if (pid == 0) {
204-
ServerLoopProc(conn, addr);
205+
ServerLoopProc(conn, addr, work_dir_);
205206
exit(0);
206207
}
207208
// Wait for the result
@@ -308,9 +309,10 @@ class RPCServer {
308309
* \param sock The socket information
309310
* \param addr The socket address information
310311
*/
311-
static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr) {
312+
static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr,
313+
std::string work_dir) {
312314
// Server loop
313-
const auto env = RPCEnv();
315+
const auto env = RPCEnv(work_dir);
314316
RPCServerLoop(int(sock.sockfd));
315317
LOG(INFO) << "Finish serving " << addr.AsString();
316318
env.CleanUp();
@@ -339,6 +341,7 @@ class RPCServer {
339341
std::string tracker_addr_;
340342
std::string key_;
341343
std::string custom_addr_;
344+
std::string work_dir_;
342345
support::TCPSocket listen_sock_;
343346
support::TCPSocket tracker_sock_;
344347
};
@@ -370,19 +373,19 @@ void ServerLoopFromChild(SOCKET socket) {
370373
* silent mode. Default=True
371374
*/
372375
void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr,
373-
std::string key, std::string custom_addr, bool silent) {
376+
std::string key, std::string custom_addr, std::string work_dir, bool silent) {
374377
if (silent) {
375378
// Only errors and fatal is logged
376379
dmlc::InitLogging("--minloglevel=2");
377380
}
378381
// Start the rpc server
379382
RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key),
380-
std::move(custom_addr));
383+
std::move(custom_addr), std::move(work_dir));
381384
rpc.Start();
382385
}
383386

384387
TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) {
385-
RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
388+
RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]);
386389
});
387390
} // namespace runtime
388391
} // namespace tvm

apps/cpp_rpc/rpc_server.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,12 @@ void ServerLoopFromChild(SOCKET socket);
4848
* \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
4949
* \param key The key used to identify the device type in tracker. Default=""
5050
* \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
51+
* \param work_dir Custom work directory. Default=""
5152
* \param silent Whether run in silent mode. Default=True
5253
*/
5354
void RPCServerCreate(std::string host = "", int port = 9090, int port_end = 9099,
5455
std::string tracker_addr = "", std::string key = "",
55-
std::string custom_addr = "", bool silent = true);
56+
std::string custom_addr = "", std::string work_dir = "", bool silent = true);
5657
} // namespace runtime
5758
} // namespace tvm
5859
#endif // TVM_APPS_CPP_RPC_SERVER_H_

python/tvm/auto_scheduler/dispatcher.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class DispatchContext(object):
5050
def __init__(self):
5151
self._old_ctx = DispatchContext.current
5252

53-
def query(self, target, workload_key, has_complex_op, dag):
53+
def query(self, target, workload_key, has_complex_op, dag, func_name):
5454
"""
5555
Query the context to get the specific config for a workload.
5656
If cannot find the result inside this context, this function will query it
@@ -66,15 +66,17 @@ def query(self, target, workload_key, has_complex_op, dag):
6666
Whether this workload has at least one complex op.
6767
dag: ComputeDAG
6868
The ComputeDAG of the workload.
69+
func_name: str
70+
The function name of this workload.
6971
7072
Returns
7173
-------
7274
state : StateObject
7375
The state that stores schedule configuration for the workload
7476
"""
75-
ret = self._query_inside(target, workload_key)
77+
ret = self._query_inside(target, workload_key, func_name)
7678
if ret is None:
77-
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
79+
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag, func_name)
7880
return ret
7981

8082
def update(self, target, workload_key, state):
@@ -92,7 +94,7 @@ def update(self, target, workload_key, state):
9294
"""
9395
raise NotImplementedError()
9496

95-
def _query_inside(self, target, workload_key):
97+
def _query_inside(self, target, workload_key, func_name):
9698
"""
9799
Query the context to get the specific config for a workload.
98100
This function only query config inside this context.
@@ -103,6 +105,8 @@ def _query_inside(self, target, workload_key):
103105
The current target
104106
workload_key : str
105107
The current workload_key.
108+
func_name: str
109+
The function name of this workload.
106110
107111
Returns
108112
-------
@@ -241,7 +245,7 @@ def load(self, records, n_lines=None):
241245

242246
logger.debug("Finish loading %d records", counter)
243247

244-
def _query_inside(self, target, workload_key):
248+
def _query_inside(self, target, workload_key, func_name):
245249
if target is None:
246250
raise RuntimeError(
247251
"Need a target context to find the history best. "
@@ -343,18 +347,20 @@ def __init__(
343347
records, n_lines=None, include_compatible=True
344348
)
345349

346-
def query(self, target, workload_key, has_complex_op, dag):
350+
def query(self, target, workload_key, has_complex_op, dag, func_name):
347351
if has_complex_op or self.sample_simple_workloads:
348-
ret = self._query_inside(target, workload_key)
352+
ret = self._query_inside(target, workload_key, func_name)
349353
else:
350-
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
354+
ret = super(ApplyHistoryBestOrSample, self)._query_inside(
355+
target, workload_key, func_name
356+
)
351357

352358
if ret is None:
353-
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
359+
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag, func_name)
354360
return ret
355361

356-
def _query_inside(self, target, workload_key):
357-
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
362+
def _query_inside(self, target, workload_key, func_name):
363+
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key, func_name)
358364
if ret is not None:
359365
return ret
360366

@@ -386,7 +392,9 @@ def _query_inside(self, target, workload_key):
386392

387393
# Load the sampled records and query again.
388394
self.load(log_file)
389-
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
395+
ret = super(ApplyHistoryBestOrSample, self)._query_inside(
396+
target, workload_key, func_name
397+
)
390398

391399
del measure_ctx
392400
return ret
@@ -411,18 +419,19 @@ def __init__(self):
411419
# a set to prevent print duplicated message
412420
self.messages = set()
413421

414-
def query(self, target, workload_key, has_complex_op, dag):
422+
def query(self, target, workload_key, has_complex_op, dag, func_name):
415423
key = (str(target), workload_key)
416424
if key in self.memory:
417425
return self.memory[key]
418426

419427
if self.verbose == 2 or (has_complex_op and self.verbose == 1):
420428
msg = (
421-
"-----------------------------------\n"
422-
"Cannot find tuned schedules for target=%s, workload_key=%s. "
423-
"A fallback TOPI schedule is used, "
424-
"which may bring great performance regression or even compilation failure. "
425-
"Compute DAG info:\n%s" % (target, workload_key, dag)
429+
f"-----------------------------------\n"
430+
f"{func_name}\n"
431+
f"Cannot find tuned schedules for target={target}, workload_key={workload_key}. "
432+
f"A fallback TOPI schedule is used, "
433+
f"which may bring great performance regression or even compilation failure. "
434+
f"Compute DAG info:\n{dag}"
426435
)
427436
if msg not in self.messages:
428437
self.messages.add(msg)
@@ -434,8 +443,8 @@ def query(self, target, workload_key, has_complex_op, dag):
434443
self.memory[key] = state
435444
return state
436445

437-
def _query_inside(self, target, workload_key):
438-
_ = target = workload_key
446+
def _query_inside(self, target, workload_key, func_name):
447+
_ = target = workload_key = func_name
439448
raise RuntimeError("This function should never be called")
440449

441450
def update(self, target, workload_key, state):

python/tvm/auto_scheduler/relay_integration.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,14 +256,17 @@ def traverse(t):
256256

257257

258258
@tvm._ffi.register_func("auto_scheduler.relay_integration.auto_schedule_topi_compute")
259-
def auto_schedule_topi(outs):
259+
def auto_schedule_topi(func_name, outs):
260260
"""Use auto-scheduler to schedule any topi compute function.
261261
262262
Note: This is used internally for relay integration. Do
263263
not use this as a general user-facing API.
264264
265265
Parameters
266266
----------
267+
func_name: str
268+
The name of the function being scheduled.
269+
267270
outs: List[Tensor]
268271
The output tensors of topi compute functions
269272
@@ -289,7 +292,7 @@ def auto_schedule_topi(outs):
289292
target = tvm.target.Target.current()
290293

291294
dispatch_ctx = DispatchContext.current
292-
state = dispatch_ctx.query(target, key, has_complex_op, dag)
295+
state = dispatch_ctx.query(target, key, has_complex_op, dag, func_name)
293296
schedule = None
294297

295298
env = TracingEnvironment.current

python/tvm/relay/frontend/onnx.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2914,7 +2914,7 @@ def from_onnx(self, graph, opset, get_output_expr=False):
29142914
else:
29152915
self._num_input += 1
29162916
if i_name in self._shape:
2917-
i_shape = self._shape[i_name]
2917+
i_shape = self._shape.pop(i_name)
29182918
else:
29192919
if "?" in str(i_shape):
29202920
warning_msg = (
@@ -2929,6 +2929,11 @@ def from_onnx(self, graph, opset, get_output_expr=False):
29292929
dtype = d_type
29302930
self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype)
29312931
self._inputs[i_name] = self._nodes[i_name]
2932+
assert (
2933+
len(self._shape) == 0
2934+
), "User specified the shape for inputs that weren't found in the graph: " + str(
2935+
self._shape
2936+
)
29322937
# get list of unsupported ops
29332938
convert_map = _get_convert_map(opset)
29342939
unsupported_ops = set()

0 commit comments

Comments
 (0)