@@ -50,7 +50,7 @@ class DispatchContext(object):
5050 def __init__ (self ):
5151 self ._old_ctx = DispatchContext .current
5252
53- def query (self , target , workload_key , has_complex_op , dag ):
53+ def query (self , target , workload_key , has_complex_op , dag , func_name ):
5454 """
5555 Query the context to get the specific config for a workload.
5656 If cannot find the result inside this context, this function will query it
@@ -66,15 +66,17 @@ def query(self, target, workload_key, has_complex_op, dag):
6666 Whether this workload has at least one complex op.
6767 dag: ComputeDAG
6868 The ComputeDAG of the workload.
69+ func_name: str
70+ The function name of this workload.
6971
7072 Returns
7173 -------
7274 state : StateObject
7375 The state that stores schedule configuration for the workload
7476 """
75- ret = self ._query_inside (target , workload_key )
77+ ret = self ._query_inside (target , workload_key , func_name )
7678 if ret is None :
77- ret = self ._old_ctx .query (target , workload_key , has_complex_op , dag )
79+ ret = self ._old_ctx .query (target , workload_key , has_complex_op , dag , func_name )
7880 return ret
7981
8082 def update (self , target , workload_key , state ):
@@ -92,7 +94,7 @@ def update(self, target, workload_key, state):
9294 """
9395 raise NotImplementedError ()
9496
95- def _query_inside (self , target , workload_key ):
97+ def _query_inside (self , target , workload_key , func_name ):
9698 """
9799 Query the context to get the specific config for a workload.
98100 This function only query config inside this context.
@@ -103,6 +105,8 @@ def _query_inside(self, target, workload_key):
103105 The current target
104106 workload_key : str
105107 The current workload_key.
108+ func_name: str
109+ The function name of this workload.
106110
107111 Returns
108112 -------
@@ -241,7 +245,7 @@ def load(self, records, n_lines=None):
241245
242246 logger .debug ("Finish loading %d records" , counter )
243247
244- def _query_inside (self , target , workload_key ):
248+ def _query_inside (self , target , workload_key , func_name ):
245249 if target is None :
246250 raise RuntimeError (
247251 "Need a target context to find the history best. "
@@ -343,18 +347,20 @@ def __init__(
343347 records , n_lines = None , include_compatible = True
344348 )
345349
346- def query (self , target , workload_key , has_complex_op , dag ):
350+ def query (self , target , workload_key , has_complex_op , dag , func_name ):
347351 if has_complex_op or self .sample_simple_workloads :
348- ret = self ._query_inside (target , workload_key )
352+ ret = self ._query_inside (target , workload_key , func_name )
349353 else :
350- ret = super (ApplyHistoryBestOrSample , self )._query_inside (target , workload_key )
354+ ret = super (ApplyHistoryBestOrSample , self )._query_inside (
355+ target , workload_key , func_name
356+ )
351357
352358 if ret is None :
353- ret = self ._old_ctx .query (target , workload_key , has_complex_op , dag )
359+ ret = self ._old_ctx .query (target , workload_key , has_complex_op , dag , func_name )
354360 return ret
355361
356- def _query_inside (self , target , workload_key ):
357- ret = super (ApplyHistoryBestOrSample , self )._query_inside (target , workload_key )
362+ def _query_inside (self , target , workload_key , func_name ):
363+ ret = super (ApplyHistoryBestOrSample , self )._query_inside (target , workload_key , func_name )
358364 if ret is not None :
359365 return ret
360366
@@ -386,7 +392,9 @@ def _query_inside(self, target, workload_key):
386392
387393 # Load the sampled records and query again.
388394 self .load (log_file )
389- ret = super (ApplyHistoryBestOrSample , self )._query_inside (target , workload_key )
395+ ret = super (ApplyHistoryBestOrSample , self )._query_inside (
396+ target , workload_key , func_name
397+ )
390398
391399 del measure_ctx
392400 return ret
@@ -411,18 +419,19 @@ def __init__(self):
411419 # a set to prevent print duplicated message
412420 self .messages = set ()
413421
414- def query (self , target , workload_key , has_complex_op , dag ):
422+ def query (self , target , workload_key , has_complex_op , dag , func_name ):
415423 key = (str (target ), workload_key )
416424 if key in self .memory :
417425 return self .memory [key ]
418426
419427 if self .verbose == 2 or (has_complex_op and self .verbose == 1 ):
420428 msg = (
421- "-----------------------------------\n "
422- "Cannot find tuned schedules for target=%s, workload_key=%s. "
423- "A fallback TOPI schedule is used, "
424- "which may bring great performance regression or even compilation failure. "
425- "Compute DAG info:\n %s" % (target , workload_key , dag )
429+ f"-----------------------------------\n "
430+ f"{ func_name } \n "
431+ f"Cannot find tuned schedules for target={ target } , workload_key={ workload_key } . "
432+ f"A fallback TOPI schedule is used, "
433+ f"which may bring great performance regression or even compilation failure. "
434+ f"Compute DAG info:\n { dag } "
426435 )
427436 if msg not in self .messages :
428437 self .messages .add (msg )
@@ -434,8 +443,8 @@ def query(self, target, workload_key, has_complex_op, dag):
434443 self .memory [key ] = state
435444 return state
436445
437- def _query_inside (self , target , workload_key ):
438- _ = target = workload_key
446+ def _query_inside (self , target , workload_key , func_name ):
447+ _ = target = workload_key = func_name
439448 raise RuntimeError ("This function should never be called" )
440449
441450 def update (self , target , workload_key , state ):
0 commit comments