Skip to content

Commit 689beed

Browse files
comaniac and trevor-m
authored and committed
[AutoScheduler] Add function name in message (apache#7703)
* [AutoScheduler] Add function name in message
* fix
1 parent dd5dbd9 commit 689beed

3 files changed

Lines changed: 35 additions & 23 deletions

File tree

python/tvm/auto_scheduler/dispatcher.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class DispatchContext(object):
5050
def __init__(self):
5151
self._old_ctx = DispatchContext.current
5252

53-
def query(self, target, workload_key, has_complex_op, dag):
53+
def query(self, target, workload_key, has_complex_op, dag, func_name):
5454
"""
5555
Query the context to get the specific config for a workload.
5656
If cannot find the result inside this context, this function will query it
@@ -66,15 +66,17 @@ def query(self, target, workload_key, has_complex_op, dag):
6666
Whether this workload has at least one complex op.
6767
dag: ComputeDAG
6868
The ComputeDAG of the workload.
69+
func_name: str
70+
The function name of this workload.
6971
7072
Returns
7173
-------
7274
state : StateObject
7375
The state that stores schedule configuration for the workload
7476
"""
75-
ret = self._query_inside(target, workload_key)
77+
ret = self._query_inside(target, workload_key, func_name)
7678
if ret is None:
77-
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
79+
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag, func_name)
7880
return ret
7981

8082
def update(self, target, workload_key, state):
@@ -92,7 +94,7 @@ def update(self, target, workload_key, state):
9294
"""
9395
raise NotImplementedError()
9496

95-
def _query_inside(self, target, workload_key):
97+
def _query_inside(self, target, workload_key, func_name):
9698
"""
9799
Query the context to get the specific config for a workload.
98100
This function only query config inside this context.
@@ -103,6 +105,8 @@ def _query_inside(self, target, workload_key):
103105
The current target
104106
workload_key : str
105107
The current workload_key.
108+
func_name: str
109+
The function name of this workload.
106110
107111
Returns
108112
-------
@@ -241,7 +245,7 @@ def load(self, records, n_lines=None):
241245

242246
logger.debug("Finish loading %d records", counter)
243247

244-
def _query_inside(self, target, workload_key):
248+
def _query_inside(self, target, workload_key, func_name):
245249
if target is None:
246250
raise RuntimeError(
247251
"Need a target context to find the history best. "
@@ -343,18 +347,20 @@ def __init__(
343347
records, n_lines=None, include_compatible=True
344348
)
345349

346-
def query(self, target, workload_key, has_complex_op, dag):
350+
def query(self, target, workload_key, has_complex_op, dag, func_name):
347351
if has_complex_op or self.sample_simple_workloads:
348-
ret = self._query_inside(target, workload_key)
352+
ret = self._query_inside(target, workload_key, func_name)
349353
else:
350-
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
354+
ret = super(ApplyHistoryBestOrSample, self)._query_inside(
355+
target, workload_key, func_name
356+
)
351357

352358
if ret is None:
353-
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
359+
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag, func_name)
354360
return ret
355361

356-
def _query_inside(self, target, workload_key):
357-
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
362+
def _query_inside(self, target, workload_key, func_name):
363+
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key, func_name)
358364
if ret is not None:
359365
return ret
360366

@@ -386,7 +392,9 @@ def _query_inside(self, target, workload_key):
386392

387393
# Load the sampled records and query again.
388394
self.load(log_file)
389-
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
395+
ret = super(ApplyHistoryBestOrSample, self)._query_inside(
396+
target, workload_key, func_name
397+
)
390398

391399
del measure_ctx
392400
return ret
@@ -411,18 +419,19 @@ def __init__(self):
411419
# a set to prevent print duplicated message
412420
self.messages = set()
413421

414-
def query(self, target, workload_key, has_complex_op, dag):
422+
def query(self, target, workload_key, has_complex_op, dag, func_name):
415423
key = (str(target), workload_key)
416424
if key in self.memory:
417425
return self.memory[key]
418426

419427
if self.verbose == 2 or (has_complex_op and self.verbose == 1):
420428
msg = (
421-
"-----------------------------------\n"
422-
"Cannot find tuned schedules for target=%s, workload_key=%s. "
423-
"A fallback TOPI schedule is used, "
424-
"which may bring great performance regression or even compilation failure. "
425-
"Compute DAG info:\n%s" % (target, workload_key, dag)
429+
f"-----------------------------------\n"
430+
f"{func_name}\n"
431+
f"Cannot find tuned schedules for target={target}, workload_key={workload_key}. "
432+
f"A fallback TOPI schedule is used, "
433+
f"which may bring great performance regression or even compilation failure. "
434+
f"Compute DAG info:\n{dag}"
426435
)
427436
if msg not in self.messages:
428437
self.messages.add(msg)
@@ -434,8 +443,8 @@ def query(self, target, workload_key, has_complex_op, dag):
434443
self.memory[key] = state
435444
return state
436445

437-
def _query_inside(self, target, workload_key):
438-
_ = target = workload_key
446+
def _query_inside(self, target, workload_key, func_name):
447+
_ = target = workload_key = func_name
439448
raise RuntimeError("This function should never be called")
440449

441450
def update(self, target, workload_key, state):

python/tvm/auto_scheduler/relay_integration.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,14 +256,17 @@ def traverse(t):
256256

257257

258258
@tvm._ffi.register_func("auto_scheduler.relay_integration.auto_schedule_topi_compute")
259-
def auto_schedule_topi(outs):
259+
def auto_schedule_topi(func_name, outs):
260260
"""Use auto-scheduler to schedule any topi compute function.
261261
262262
Note: This is used internally for relay integration. Do
263263
not use this as a general user-facing API.
264264
265265
Parameters
266266
----------
267+
func_name: str
268+
The name of the function being scheduled.
269+
267270
outs: List[Tensor]
268271
The output tensors of topi compute functions
269272
@@ -289,7 +292,7 @@ def auto_schedule_topi(outs):
289292
target = tvm.target.Target.current()
290293

291294
dispatch_ctx = DispatchContext.current
292-
state = dispatch_ctx.query(target, key, has_complex_op, dag)
295+
state = dispatch_ctx.query(target, key, has_complex_op, dag, func_name)
293296
schedule = None
294297

295298
env = TracingEnvironment.current

src/relay/backend/compile_engine.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
157157
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
158158
ICHECK(fauto_schedule != nullptr)
159159
<< "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
160-
ObjectRef obj = (*fauto_schedule)(tensor_outs);
160+
ObjectRef obj = (*fauto_schedule)(String(cache_node->func_name), tensor_outs);
161161
if (obj.defined()) {
162162
schedule = Downcast<te::Schedule>(obj);
163163
}

0 commit comments

Comments (0)