Skip to content

Commit e08caef

Browse files
authored
[MetaSchedule] Enhance tune_tir to tune IRModule of TIR Collections (#14784)
This PR enhances the user-facing API `tune_tir` that makes it more convenient to feed in an IRModule consists of multiple TIRs and customize the search space for them. This is widely used in our MLC-LLM project where we wanted to customize the search space for some quantization-related operators. An example usecase: ```python from tvm import meta_schedule as ms ms.tir_integration.tune_tir( ... space="cuda", <========== by default, the space is "cuda" rather than "cuda-tensorcore" special_space={ "fused_decode1_fused_matmul2_add1_gelu": sch_fused_decode_gemv, "decode": sch_decode, }, ) ```
1 parent 01324ef commit e08caef

4 files changed

Lines changed: 51 additions & 22 deletions

File tree

python/tvm/contrib/torch/as_torch.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
import torch
3535
import torch.utils.dlpack
36+
3637
import tvm
3738
from tvm import meta_schedule as ms
3839
from tvm.target.target import Target
@@ -66,7 +67,6 @@ def tune(
6667
task_scheduler: ms.TaskScheduler.TaskSchedulerType = "round-robin",
6768
space: ms.SpaceGenerator.SpaceGeneratorType = "post-order-apply",
6869
strategy: ms.SearchStrategy.SearchStrategyType = "replay-trace",
69-
task_name: str = "main",
7070
num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
7171
seed: Optional[int] = None,
7272
) -> None:
@@ -99,7 +99,6 @@ def tune(
9999
task_scheduler=task_scheduler,
100100
space=space,
101101
strategy=strategy,
102-
task_name=task_name,
103102
num_tuning_cores=num_tuning_cores,
104103
seed=seed,
105104
)

python/tvm/meta_schedule/space_generator/space_generator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ def create_schedule_fn(
126126
return PostOrderApply(*args, **kwargs)
127127
if kind == "union":
128128
return SpaceGeneratorUnion(*args, **kwargs)
129+
if isinstance(kind, str):
130+
return PostOrderApply(sch_rules=kind, postprocs=kind, mutator_probs=kind)
129131
raise ValueError(f"Unknown SpaceGenerator: {kind}")
130132

131133

python/tvm/meta_schedule/testing/tune_te.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ def main():
135135
adaptive_training=ARGS.adaptive_training,
136136
),
137137
strategy=ms.search_strategy.EvolutionarySearch(),
138-
task_name=ARGS.workload,
139138
)
140139

141140
print("Tuning Time:")

python/tvm/meta_schedule/tir_integration.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""MetaSchedule-TIR integration"""
18-
from typing import Optional, Union
18+
from typing import List, Mapping, Optional, Tuple, Union
1919

2020
# isort: off
2121
from typing_extensions import Literal
@@ -38,37 +38,41 @@
3838
from .utils import fork_seed
3939

4040

41-
def tune_tir(
41+
def tune_tir( # pylint: disable=too-many-locals
4242
mod: Union[ir.IRModule, tir.PrimFunc],
4343
target: Union[str, Target],
4444
work_dir: str,
4545
max_trials_global: int,
4646
*,
47+
max_trials_per_task: Optional[int] = None,
4748
num_trials_per_iter: int = 64,
4849
builder: Builder.BuilderType = "local",
4950
runner: Runner.RunnerType = "local",
5051
database: Database.DatabaseType = "json",
5152
cost_model: CostModel.CostModelType = "xgb",
5253
measure_callbacks: MeasureCallback.CallbackListType = "default",
53-
task_scheduler: TaskScheduler.TaskSchedulerType = "round-robin",
54+
task_scheduler: TaskScheduler.TaskSchedulerType = "gradient",
5455
space: SpaceGenerator.SpaceGeneratorType = "post-order-apply",
5556
strategy: SearchStrategy.SearchStrategyType = "evolutionary",
56-
task_name: str = "main",
5757
num_tuning_cores: Union[Literal["physical", "logical"], int] = "physical",
5858
seed: Optional[int] = None,
59+
module_equality: str = "structural",
60+
special_space: Optional[Mapping[str, SpaceGenerator.SpaceGeneratorType]] = None,
5961
) -> Database:
60-
"""Tune a TIR function.
62+
"""Tune a TIR function or an IRModule of TIR functions.
6163
6264
Parameters
6365
----------
6466
mod : Union[ir.IRModule, tir.PrimFunc]
65-
The TIR function to tune.
67+
The TIR IRModule to tune.
6668
target : Union[str, Target]
6769
The target to tune for.
6870
work_dir : str
6971
The working directory.
7072
max_trials_global : int
7173
The maximum number of trials to run globally.
74+
max_trials_per_task : Optional[int]
75+
The maximum number of trials to run per task.
7276
num_trials_per_iter : int
7377
The number of trials to run per iteration
7478
builder : Builder.BuilderType
@@ -87,44 +91,69 @@ def tune_tir(
8791
The space generator.
8892
strategy : SearchStrategy.SearchStrategyType
8993
The search strategy.
90-
task_name : str
91-
The name of the task.
9294
num_tuning_cores : Union[Literal["physical", "logical"], int]
9395
The number of CPU cores to use during tuning.
9496
seed : Optional[int]
9597
The seed for the random number generator.
98+
module_equality : Optional[str]
99+
A string to specify the module equality testing and hashing method.
100+
special_space : Optional[Mapping[str, SpaceGenerator.SpaceGeneratorType]]
101+
A mapping from task name to a special space generator for that task.
96102
97103
Returns
98104
-------
99105
database : Database
100106
The database with all tuning records
101107
"""
102-
(logger,) = get_loggers_from_work_dir(work_dir, [task_name])
103-
(seed,) = fork_seed(seed, n=1)
104-
return tune_tasks(
105-
tasks=[
108+
if isinstance(mod, tir.PrimFunc):
109+
mod = _normalize_mod(mod)
110+
111+
named_tasks: List[Tuple[str, tir.PrimFunc]] = []
112+
for gv, func in mod.functions.items(): # pylint: disable=invalid-name
113+
if isinstance(func, tir.PrimFunc):
114+
named_tasks.append((gv.name_hint, func))
115+
named_tasks.sort(key=lambda x: x[0])
116+
117+
task_names = [x for x, _ in named_tasks]
118+
tasks: List[TuneContext] = []
119+
for task_name, task_func, logger, rand_state in zip(
120+
task_names,
121+
[x for _, x in named_tasks],
122+
get_loggers_from_work_dir(work_dir, task_names),
123+
fork_seed(seed, n=len(named_tasks)),
124+
):
125+
if special_space and task_name in special_space:
126+
task_space = special_space[task_name]
127+
else:
128+
task_space = space
129+
if task_space is None:
130+
continue
131+
tasks.append(
106132
TuneContext(
107-
mod=mod,
133+
mod=task_func,
108134
target=target,
109-
space_generator=space,
135+
space_generator=task_space,
110136
search_strategy=strategy,
111137
task_name=task_name,
112-
logger=logger,
113-
rand_state=seed,
138+
rand_state=rand_state,
114139
num_threads=num_tuning_cores,
140+
logger=logger,
115141
).clone()
116-
],
117-
task_weights=[1.0],
142+
)
143+
return tune_tasks(
144+
tasks=tasks,
145+
task_weights=[1.0] * len(tasks),
118146
work_dir=work_dir,
119147
max_trials_global=max_trials_global,
120-
max_trials_per_task=max_trials_global,
148+
max_trials_per_task=max_trials_per_task,
121149
num_trials_per_iter=num_trials_per_iter,
122150
builder=builder,
123151
runner=runner,
124152
database=database,
125153
cost_model=cost_model,
126154
measure_callbacks=measure_callbacks,
127155
task_scheduler=task_scheduler,
156+
module_equality=module_equality,
128157
)
129158

130159

0 commit comments

Comments
 (0)