From a6c70711a99a5253f6184a070ab6179a41630d0f Mon Sep 17 00:00:00 2001 From: canesche Date: Mon, 17 Jun 2024 18:28:09 -0300 Subject: [PATCH] [MetaSchedule] Adding post optimization in MetaSchedule to Improve Scheduling --- python/tvm/meta_schedule/__init__.py | 2 + .../post_optimization/__init__.py | 24 ++ .../post_optimization/droplet.py | 134 +++++++++ .../post_optimization/post_opt.py | 76 +++++ .../meta_schedule/post_optimization/space.py | 259 ++++++++++++++++++ .../meta_schedule/post_optimization/utils.py | 112 ++++++++ python/tvm/meta_schedule/relay_integration.py | 4 + python/tvm/meta_schedule/tir_integration.py | 2 + python/tvm/meta_schedule/tune.py | 7 + .../test_meta_schedule_space_post_opt.py | 118 ++++++++ 10 files changed, 738 insertions(+) create mode 100644 python/tvm/meta_schedule/post_optimization/__init__.py create mode 100644 python/tvm/meta_schedule/post_optimization/droplet.py create mode 100644 python/tvm/meta_schedule/post_optimization/post_opt.py create mode 100644 python/tvm/meta_schedule/post_optimization/space.py create mode 100644 python/tvm/meta_schedule/post_optimization/utils.py create mode 100644 tests/python/meta_schedule/test_meta_schedule_space_post_opt.py diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 21a11ff9e84d..b44dbe45e0b7 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -33,6 +33,7 @@ space_generator, tir_integration, trace_apply, + post_optimization, ) from .builder import Builder from .cost_model import CostModel @@ -53,3 +54,4 @@ from .tune import tune_tasks from .tune_context import TuneContext from .utils import derived_object +from .post_optimization import post_opt diff --git a/python/tvm/meta_schedule/post_optimization/__init__.py b/python/tvm/meta_schedule/post_optimization/__init__.py new file mode 100644 index 000000000000..9acb525477e4 --- /dev/null +++ b/python/tvm/meta_schedule/post_optimization/__init__.py @@ -0,0 
+1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.post_optimization package. +Post optimization that refines the best schedules found during tuning +""" +from .post_opt import PostOpt +from .droplet import Droplet +from .space import Space +from .utils import write_file, get_time diff --git a/python/tvm/meta_schedule/post_optimization/droplet.py b/python/tvm/meta_schedule/post_optimization/droplet.py new file mode 100644 index 000000000000..b534c74adb77 --- /dev/null +++ b/python/tvm/meta_schedule/post_optimization/droplet.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Droplet algorithm """ + +import os +import numpy as np # type: ignore + +from .utils import write_file, get_time +from .space import Space + + +class Droplet: + """Tuner with droplet algorithm in Meta Schedule. + + Parameters + ---------- + json_file: str + json format file + target: + hardware target + log: str + path to save json file + trials: int + number of samples, the default is 100 + pvalue: float + statistical value to confidence level, the default is 0.05 + """ + + def __init__(self, json_file, workload_file, target, log, pvalue=0.05) -> None: + self.space = Space(json_file, workload_file, target) + self.final_log = write_file([json_file], log) + self.pvalue = pvalue + self.next = [(0, [0] * len(self.space.dims))] + best_avg, _ = get_time(log) + self.best_choice = [0, [0] * len(self.space.dims), best_avg] + self.count, self.execution, self.found_best_pos = 1, 1, True + self.total_execution = 1 + if len(self.space.dims) > 0: + self.total_execution = max(self.space.dims) + self.dims, self.step = self.space.dims, 1 + self.visited, self.batch = set([0]), max(os.cpu_count(), len(self.dims)) + + def next_batch(self, batch_size): + i, json_file_list = 0, [] + while i < len(self.next): + if batch_size > 0 and self.count >= self.trials: + break + json_file_list.append(self.space.template(values=self.next[i][1], create=False)) + i, self.count = i + 1, self.count + 1 + return self.space.run(json_file_list, self.final_log) + + def has_next(self): + return len(self.next) > 0 and self.found_best_pos + + def tune(self, n_trial=100): + self.trials = 
n_trial + self.speculation() + while self.has_next(): + res = self.next_batch(self.batch) + self.update(res) + + def num_to_bin(self, value, factor=1): + bin_format = str(0) * (len(self.dims) - len(bin(value)[2:])) + bin(value)[2:] + return [int(i) * factor for i in bin_format] + + def search_space(self, factor=1): + "create a search space" + search_space: list = [] + for i in range(0, len(self.space.dims)): + if len(search_space) > self.batch - len(self.next): + break + space = self.num_to_bin(2**i, factor) + idx = self.space.knob2point(space) + if idx not in self.visited: + search_space.append(space) + return search_space + + def next_pos(self, new_positions): + "returns the neighbors of the best solution" + next_set = [] + for p in new_positions: + new_p = [ + (x + y) % self.dims[i] if (x + y > 0) else 0 + for i, (x, y) in enumerate(zip(p, self.best_choice[1])) + ] + idx_p = self.space.knob2point(new_p) + if idx_p not in self.visited: + self.visited.add(idx_p) + next_set.append((idx_p, new_p)) + return next_set + + def speculation(self): + # Gradient descending direction prediction and search space filling + while len(self.next) < self.batch and self.execution < self.total_execution: + self.next += self.next_pos(self.search_space(self.execution)) + self.execution += self.step + + def update(self, results): + """Update the values""" + self.found_best_pos, count_valids = False, 0 + for i, res in enumerate(results): + if np.mean(self.best_choice[2]) > np.mean(res): + self.best_choice = [self.next[i][0], self.next[i][1], res] + self.found_best_pos = True + if np.mean(res) != 10000: + count_valids += 1 + + self.next = [] + + # stop, because all neighborhoods are invalid. 
+ if count_valids == 0: + self.speculation() + self.found_best_pos = True + return + + if self.found_best_pos: + self.next += self.next_pos(self.search_space()) + self.execution = 1 + self.speculation() diff --git a/python/tvm/meta_schedule/post_optimization/post_opt.py b/python/tvm/meta_schedule/post_optimization/post_opt.py new file mode 100644 index 000000000000..f1c3021a1f28 --- /dev/null +++ b/python/tvm/meta_schedule/post_optimization/post_opt.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Post optimization method""" + +import numpy as np # type: ignore +from tvm.target import Target + +from .droplet import Droplet +from .utils import read_cfg_file, get_time, write_file, clean_file + + +class PostOpt: + """PostOpt class + + Parameters + ---------- + work_dir : str + The working directory. 
+ target: Target data + Target device information + trials: integer value + Max number of trials to execute the optimization + """ + + def __init__(self, work_dir: str, target: Target, trials: int = 100) -> None: + self.work_dir = work_dir + self.target = target + self.trials = trials + + def run(self) -> None: + """Execute the post optimization""" + + tuning_file = self.work_dir + "/database_tuning_record.json" + workload_file = self.work_dir + "/database_workload.json" + + cfg = read_cfg_file(tuning_file, workload_file) + + print("id | time MS (s) | time DPMS (s) | speedup") + for idx, layer in enumerate(cfg): + + time, data, workload = cfg[layer] + ms_time = np.mean(time) + + temp_log = f"{self.work_dir}/opt_{idx}.log" + + # Run the exploitation by Droplet + droplet = Droplet(data, workload, self.target, temp_log) + droplet.tune(self.trials) + + dpms_time, dpm_sol = get_time(temp_log) + dpms_time = np.mean(dpms_time) + + speedup = ms_time / dpms_time + + # save the best solution + write_file([dpm_sol], tuning_file, mode="a") + + # show the performance + print(f"{idx:2d} | {ms_time:.10f} | {dpms_time:.10f} | {speedup:.2f}") + + # clean the temporary files + clean_file(temp_log) diff --git a/python/tvm/meta_schedule/post_optimization/space.py b/python/tvm/meta_schedule/post_optimization/space.py new file mode 100644 index 000000000000..3b1f5efa1e83 --- /dev/null +++ b/python/tvm/meta_schedule/post_optimization/space.py @@ -0,0 +1,259 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" The class of Space used to optimize the Meta parameters """ + +import json +import random +from copy import deepcopy +from typing import Dict, List, Any +import numpy as np # type: ignore + +from tvm import meta_schedule as ms +from tvm.target import Target +from tvm.tir import Schedule +from tvm.meta_schedule.database import Workload, TuningRecord +from tvm.meta_schedule.utils import remove_build_dir + +from .utils import write_file + + +class Space: + """Space class + + Parameters + ---------- + data: json data + A json file template + workload: json data + A json file workload + target: Target data + Target device information + """ + + def __init__(self, data: Any, workload: Any, target: Target): + self.cfg = deepcopy(data) + self._id = data[0] + self.workload = Workload.from_json(workload) + self.target = target + self.dev = self.get_device_type(target) + self.total_dims = 0 + self.dims: List[int] = [] + self.start: List[int] = [] + self.config_space: Dict[str, List[int]] = dict() + self.create_space() + + def __repr__(self) -> str: + """Print the config space""" + out = "" + for key in self.config_space: + out += f"{key}: dims={self.config_space[key]}\n" + out += f"Total dimensions: {self.total_dims}\n" + return out + + def __str__(self) -> str: + """Print the config space""" + out = "" + for key in self.config_space: + out += f"{key}: dims={self.config_space[key]}\n" + out += f"Total dimensions: {self.total_dims}\n" + return out + + def get_value(self, key, pos): + """Return the space""" + return self.config_space[key][pos] + + def 
add_space(self, space_list: list, element_list: list, limit=10000) -> List[int]: + """Return a list without repeat and with limited value""" + new_list = element_list + for elem in space_list: + if elem not in new_list and elem <= limit: + new_list.append(elem) + return new_list + + def knob2point(self, knob): + """Convert a array to point""" + point = 0 + for j, k in enumerate(knob): + point += int(np.prod(self.dims[:j])) * k + return point + + def point2knob(self, point): + """Convert point form (single integer) to knob (vector)""" + knob = [] + for dim in self.dims: + knob.append(point % dim) + point //= dim + return knob + + def power_of_two(self, min_value: int, max_value: int) -> list: + """Return power of two array in interval""" + return [1 << i for i in range(min_value, max_value + 1)] + + def get_index(self, array: list, value: int): + """returns an index if it finds the value""" + for i in range(len(array)): + if array[i][0] == value: + return i + return -1 + + def template(self, values=None, create=True): + """Generate the template from the values""" + idx = -1 + config = deepcopy(self.cfg[1]) + for counter, cfg in enumerate(config[0][0]): + opt = cfg[0] + if opt == "Annotate": + ann_key = cfg[2] + if ann_key == ["meta_schedule.parallel"]: + interval = self.power_of_two(5, 9) + elif ann_key == ["meta_schedule.vectorize"]: + interval = self.power_of_two(4, 8) + elif ann_key == ["pragma_auto_unroll_max_step"]: + interval = self.power_of_two(7, 11) + elif ann_key == ["meta_schedule.thread_extent_low_inclusive"]: + interval = self.power_of_two(5, 6) + elif ann_key == ["meta_schedule.thread_extent_high_inclusive"]: + interval = self.power_of_two(8, 12) + else: + continue + idx += 1 + key = f"ann_{idx}" + ann_value = cfg[1][1] + if create: + self.config_space[key] = self.add_space(interval, [ann_value]) + else: + cfg[1][1] = self.get_value(key, values[idx]) + elif opt == "SamplePerfectTile": + tile = config[0][1] + tile_idx = self.get_index(tile, counter) + 
tile_val = tile[tile_idx][1] + interval = self.power_of_two(1, 6) + for i in range(len(tile_val)): + idx += 1 + key = f"sp_{counter}_{idx}" + split = tile_val[i] + if create: + self.config_space[key] = self.add_space(interval, [split]) + else: + config[0][1][tile_idx][1][i] = self.get_value(key, values[idx]) + elif opt == "TransformLayout": + del config[0][0][counter] + if create: + return None + return config + + def create_space(self): + """Create the space using Meta's space""" + self.template(create=True) + # print(self.config_space) + self.dims = [] + for key in self.config_space: + self.dims.append(len(self.config_space[key])) + self.total_dims = 1 + if len(self.dims) > 0: + for dim in self.dims: + self.total_dims *= dim + + def get_device_type(self, target: Target) -> str: + """Get the device type string from a target. + + Parameters + ---------- + target : Target + The target to get the device type from. + + Returns + ------- + device_type : str + The device type string. + """ + if target.kind.name == "llvm": + return "cpu" + elif target.kind.name == "cuda": + return "cuda" + else: + raise RuntimeError(f"Unsupported target kind for device type: {target.kind.name}") + + def save_log( + self, + path: str, + record: ms.database.TuningRecord, + results: ms.runner.RunnerResult, + ) -> None: + """Save the log file""" + new_json = [self._id, record.as_json()] + new_json[1][1] = results + write_file([new_json], path, "a") + + def run( + self, + json_file_list, + final_log, + timeout=10, + number=2, + repeat=3, + min_repeat_ms=0, + cpu_cache=False, + ): + """Execute a log file and save""" + + builder = ms.builder.LocalBuilder(timeout_sec=timeout) + runner = ms.runner.LocalRunner( + evaluator_config=ms.runner.EvaluatorConfig( + number=number, + repeat=repeat, + min_repeat_ms=min_repeat_ms, + enable_cpu_cache_flush=cpu_cache, + ), + ) + + results = np.full(len(json_file_list), [10000], dtype=list) + records, mods = [], [] + for i, cfg in enumerate(json_file_list): + 
try: + record = TuningRecord.from_json(json.loads(json.dumps(cfg)), self.workload) + sch = Schedule(self.workload.mod) + # In some layers this is a heavy impact in time cost, so + # I applied this only 25% of the samples. + remove_postproc = random.random() > 0.75 + record.trace.apply_to_schedule(sch, remove_postproc=remove_postproc) + mods.append(sch.mod) + records.append(record) + except Exception: # pylint: disable=broad-except, invalid-name + continue + + builder_res = builder.build([ms.builder.BuilderInput(mod, self.target) for mod in mods]) + + for i, record in enumerate(records): + try: + inp = ms.runner.RunnerInput( + builder_res[i].artifact_path, + device_type=self.dev, + args_info=ms.arg_info.TensorInfo.from_prim_func(mods[i]["main"]), + ) + runner_res = runner.run([inp])[0].result() + results[i] = [v.value for v in runner_res.run_secs] # type: ignore + except Exception: # pylint: disable=broad-except, invalid-name + results[i] = [1e10] + continue + + # save the solution in json file + self.save_log(final_log, record, results[i]) + + # clean up + remove_build_dir(builder_res[i].artifact_path) + return results diff --git a/python/tvm/meta_schedule/post_optimization/utils.py b/python/tvm/meta_schedule/post_optimization/utils.py new file mode 100644 index 000000000000..6fd68c7b355c --- /dev/null +++ b/python/tvm/meta_schedule/post_optimization/utils.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utils file for exploitation schedule""" + +import os +import json +from typing import Dict +import numpy as np  # type: ignore + + +def write_file(json_list: list, log: str = "/tmp/file.json", mode: str = "w") -> str: + """Write the log file + + Parameters + ---------- + json_list: list + The list input json + log: Optional[str] + Destination path to save the log file + mode: Optional[str] + Mode save, "a" means append and "w" means write + + Returns + ------- + ret: str + log path file + """ + with open(log, mode, encoding="utf-8") as outfile: + for j in json_list: + outfile.write(json.dumps(j) + "\n") + return log + + +def clean_file(filename: str) -> None: + """Clean temporary files + + Parameters + ---------- + filename: str + The file path to remove from the system + """ + if os.path.isfile(filename): + os.remove(filename) + + +def get_time(log: str) -> list: + """Get the time from the log file + + Parameters + ---------- + log: str + log file + + Returns + ------- + ret: list + A list with the best time and the json data + """ + best_time = [1e10, None] + with open(log, "r", encoding="utf-8") as log_file: + for line in log_file.readlines(): + data = json.loads(line) + params = data[1] + time = params[1] + if np.mean(best_time[0]) > np.mean(time): + best_time = [time, data] + return best_time + + +def read_cfg_file(path_tuning_file: str, path_workload_file: str) -> Dict[int, list]: + """Collect the info from the meta logfile + + Parameters + ---------- + log: str + The input log path with the meta parameter + + Returns + ------- + ret: dict[layer, 
Union[time, dict]] + Returns the best time, total time, and data + """ + workload_list = [] + with open(path_workload_file, "r", encoding="utf-8") as log_file: + for line in log_file.readlines(): + workload_list.append(json.loads(line)) + + cfg: Dict[int, list] = dict() + with open(path_tuning_file, "r", encoding="utf-8") as log_file: + for line in log_file.readlines(): + data = json.loads(line) + layer = data[0] + params = data[1] + time = params[1] + + if layer not in cfg.keys() or np.mean(cfg[layer][0]) > np.mean(time): + cfg[layer] = [time, data, workload_list[layer]] + return cfg diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index d22696d9d4f0..1fd7b5d73e82 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -276,6 +276,7 @@ def tune_relay( opt_level: int = 3, disabled_pass: Optional[Union[List[str], Set[str], Tuple[str]]] = None, instruments: Optional[Sequence[PassInstrument]] = None, + post_optimization: Optional[bool] = False, ) -> Database: """Tune a Relay program. @@ -331,6 +332,8 @@ def tune_relay( The list of disabled passes during tasks extraction instruments : Optional[Sequence[PassInstrument]] The list of pass instrument implementations. + post_optimization : Optional[Bool] + Generate post-optimization using Droplet Search as exploitation space. 
Returns ------- @@ -367,6 +370,7 @@ def tune_relay( measure_callbacks=measure_callbacks, task_scheduler=task_scheduler, module_equality=module_equality, + post_optimization=post_optimization, ) diff --git a/python/tvm/meta_schedule/tir_integration.py b/python/tvm/meta_schedule/tir_integration.py index 201cc804d6c8..bffade49a072 100644 --- a/python/tvm/meta_schedule/tir_integration.py +++ b/python/tvm/meta_schedule/tir_integration.py @@ -60,6 +60,7 @@ def tune_tir( # pylint: disable=too-many-locals seed: Optional[int] = None, module_equality: str = "structural", special_space: Optional[Mapping[str, SpaceGenerator.SpaceGeneratorType]] = None, + post_optimization: Optional[bool] = False, ) -> Database: """Tune a TIR function or an IRModule of TIR functions. @@ -156,6 +157,7 @@ def tune_tir( # pylint: disable=too-many-locals measure_callbacks=measure_callbacks, task_scheduler=task_scheduler, module_equality=module_equality, + post_optimization=post_optimization, ) diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 887941ada0d2..78c05fed533e 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -24,6 +24,7 @@ from .runner import Runner from .task_scheduler import TaskScheduler from .tune_context import TuneContext +from .post_optimization import PostOpt def tune_tasks( @@ -41,6 +42,7 @@ def tune_tasks( measure_callbacks: MeasureCallback.CallbackListType = "default", task_scheduler: TaskScheduler.TaskSchedulerType = "gradient", module_equality: str = "structural", + post_optimization: Optional[bool] = False, ) -> Database: """Tune a list of tasks. Using a task scheduler. @@ -81,6 +83,8 @@ def tune_tasks( a given module. The "ignore-ndarray" varint is used for the extracted blocks or in case no anchor block is found. For the definition of the anchor block, see tir/analysis/analysis.py. + post_optimization : Optional[Bool] + Generate post-optimization using Droplet Search as exploitation space. 
Returns ------- @@ -127,4 +131,7 @@ def tune_tasks( database=database, cost_model=cost_model, ) + if post_optimization: + post_opt = PostOpt(work_dir, tasks[0].target) + post_opt.run() return database diff --git a/tests/python/meta_schedule/test_meta_schedule_space_post_opt.py b/tests/python/meta_schedule/test_meta_schedule_space_post_opt.py new file mode 100644 index 000000000000..4cb5ec59630c --- /dev/null +++ b/tests/python/meta_schedule/test_meta_schedule_space_post_opt.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring,no-member,invalid-name,unused-variable +import logging +import tempfile + +import numpy as np +import pytest + +import tvm +import tvm.testing +from tvm import meta_schedule as ms +from tvm.meta_schedule.runner.config import EvaluatorConfig +from tvm.script import tir as T +from tvm.target import Target + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) + + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@pytest.mark.skip("Integration test") +@tvm.testing.requires_llvm +def test_tune_matmul_cpu(): + with tempfile.TemporaryDirectory() as work_dir: + target = Target("llvm --num-cores=16") + database = ms.tir_integration.tune_tir( + mod=matmul, + target=target, + work_dir=work_dir, + max_trials_global=32, + num_trials_per_iter=16, + post_optimization=True, + ) + trials = 32 + database = ms.tune_tir( + mod=matmul, + target=target, + max_trials_global=trials, + num_trials_per_iter=64, + work_dir=work_dir, + runner=ms.runner.LocalRunner( + evaluator_config=EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=100, + ) + ), + cost_model=ms.cost_model.XGBModel( + extractor=ms.feature_extractor.PerStoreFeature(), + adaptive_training=False, + ), + strategy=ms.search_strategy.EvolutionarySearch(), + post_optimization=True, # testing post optmization + ) + # +1 because of post optmization + assert len(database) == trials + 1 + + +@pytest.mark.skip("Integration test") +@tvm.testing.requires_cuda +def test_tune_matmul_cuda(): + with tempfile.TemporaryDirectory() as work_dir: + target = Target("nvidia/geforce-rtx-3070") + trials = 32 + database = ms.tune_tir( + mod=matmul, + 
target=target, + max_trials_global=trials, + num_trials_per_iter=64, + work_dir=work_dir, + runner=ms.runner.LocalRunner( + evaluator_config=EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=100, + ) + ), + cost_model=ms.cost_model.XGBModel( + extractor=ms.feature_extractor.PerStoreFeature(), + adaptive_training=False, + ), + strategy=ms.search_strategy.EvolutionarySearch(), + post_optimization=True,  # testing post optimization + ) + # +1 because of post optimization + assert len(database) == trials + 1 + + +if __name__ == """__main__""": + test_tune_matmul_cpu() + test_tune_matmul_cuda()