[MetaSchedule] Add Testing Script with ONNX Support (#11587)
This PR introduces two tuning scripts for MetaSchedule and AutoScheduler tuning with ONNX files, so ONNX model benchmarks can now be launched easily from the command line. A sample tuning call looks like the following.

For MetaSchedule ONNX tuning:

```
python3 -m tvm.meta_schedule.testing.tune_onnx_meta_schedule \
    --model-name "$MODEL_NAME" \
    --onnx-path "$ONNX_PATH" \
    --input-shape "$INPUT_SHAPE" \
    --target "$TARGET" \
    --num-trials $NUM_TRIALS \
    --rpc-host $RPC_HOST \
    --rpc-port $RPC_PORT \
    --rpc-key $RPC_KEY \
    --rpc-workers $RPC_WORKERS \
    --work-dir $WORK_DIR \
    |& tee "$WORK_DIR/$MODEL_NAME.log"
```

For AutoScheduler ONNX tuning:

```
python3 -m tvm.meta_schedule.testing.tune_onnx_auto_scheduler \
    --model-name "$MODEL_NAME" \
    --onnx-path "$ONNX_PATH" \
    --input-shape "$INPUT_SHAPE" \
    --target "$TARGET" \
    --num-trials $NUM_TRIALS \
    --rpc-host $RPC_HOST \
    --rpc-port $RPC_PORT \
    --rpc-key $RPC_KEY \
    --rpc-workers $RPC_WORKERS \
    --log-dir $WORK_DIR \
    |& tee "$WORK_DIR/$MODEL_NAME.log"
```
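The `--input-shape` argument is a JSON list with one entry per model input. A minimal sketch of building that string in Python, reusing the example spec from the script's `--input-shape` help text (the input name, dtype, and shape are illustrative, not tied to a particular model):

```python
import json

# Illustrative input spec, taken from the script's --input-shape help text;
# replace name/dtype/shape with your model's actual inputs.
input_spec = [{"name": "input1", "dtype": "int64", "shape": [1, 1, 8]}]
print(json.dumps(input_spec))  # pass this string as $INPUT_SHAPE
```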
1 parent 32a86f8 commit 1244089

File tree

3 files changed: +439 -2 lines changed

python/tvm/meta_schedule/testing/tune_onnx_auto_scheduler.py
Lines changed: 238 additions & 0 deletions
@@ -0,0 +1,238 @@

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
import argparse
import json
import os

import numpy as np  # type: ignore
import onnx  # type: ignore
import tvm
from tvm.relay.frontend import from_onnx
from tvm import auto_scheduler
from tvm import meta_schedule as ms
from tvm import relay
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc


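# All CLI flags below are required; `--input-shape` takes the JSON list shown
# in its help string, and the RPC options are folded into an ms.runner.RPCConfig
# after parsing.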
def _parse_args():
    args = argparse.ArgumentParser()
    args.add_argument(
        "--model-name",
        type=str,
        required=True,
    )
    args.add_argument(
        "--onnx-path",
        type=str,
        required=True,
    )
    args.add_argument(
        "--input-shape",
        type=str,
        required=True,
        help='example: `[{"name": "input1", "dtype": "int64", "shape": [1, 1, 8]}]`',
    )
    args.add_argument(
        "--target",
        type=str,
        required=True,
    )
    args.add_argument(
        "--num-trials",
        type=int,
        required=True,
    )
    args.add_argument(
        "--rpc-host",
        type=str,
        required=True,
    )
    args.add_argument(
        "--rpc-port",
        type=int,
        required=True,
    )
    args.add_argument(
        "--rpc-key",
        type=str,
        required=True,
    )
    args.add_argument(
        "--rpc-workers",
        type=int,
        required=True,
    )
    args.add_argument(
        "--work-dir",
        type=str,
        required=True,
    )
    parsed = args.parse_args()
    parsed.target = tvm.target.Target(parsed.target)
    parsed.input_shape = json.loads(parsed.input_shape)
    parsed.rpc_config = ms.runner.RPCConfig(
        tracker_host=parsed.rpc_host,
        tracker_port=parsed.rpc_port,
        tracker_key=parsed.rpc_key,
        session_timeout_sec=3600,
    )
    return parsed


ARGS = _parse_args()


def main():
    log_file = os.path.join(ARGS.work_dir, f"{ARGS.model_name}.json")

    runner = auto_scheduler.RPCRunner(
        key=ARGS.rpc_key,
        host=ARGS.rpc_host,
        port=ARGS.rpc_port,
        n_parallel=ARGS.rpc_workers,
        number=3,
        repeat=1,
        min_repeat_ms=100,  # TODO
        enable_cpu_cache_flush=False,  # TODO
    )

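    # Build AutoScheduler hardware parameters from the target: llvm targets read
    # the core count from the target's `num-cores` attribute; cuda targets read
    # shared-memory and thread limits from the target's attributes.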
    if ARGS.target.kind.name == "llvm":
        hardware_params = auto_scheduler.HardwareParams(
            num_cores=int(ARGS.target.attrs["num-cores"]),
            target=ARGS.target,
        )
    elif ARGS.target.kind.name == "cuda":
        hardware_params = auto_scheduler.HardwareParams(
            num_cores=-1,
            vector_unit_bytes=16,
            cache_line_bytes=64,
            max_shared_memory_per_block=int(ARGS.target.attrs["max_shared_memory_per_block"]),
            max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]),
            # The value `max_local_memory_per_block` is not used in AutoScheduler,
            # but is required by the API.
            max_local_memory_per_block=12345678,
            max_vthread_extent=8,
            warp_size=32,
        )
    else:
        raise NotImplementedError(f"Unsupported target {ARGS.target}")

    print(f"Workload: {ARGS.model_name}")
    onnx_model = onnx.load(ARGS.onnx_path)
    shape_dict = {}
    for item in ARGS.input_shape:
        print(f"  input_name: {item['name']}")
        print(f"  input_shape: {item['shape']}")
        print(f"  input_dtype: {item['dtype']}")
        shape_dict[item["name"]] = item["shape"]
    mod, params = from_onnx(onnx_model, shape_dict, freeze_params=True)
    tasks, task_weights = auto_scheduler.extract_tasks(
        mod["main"],
        params,
        target=ARGS.target,
        hardware_params=hardware_params,
    )
    for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
        print(f"==== Task {idx}: {task.desc} (weight {task_weight} key: {task.workload_key}) =====")
        print(task.compute_dag)

    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tuner.tune(
        auto_scheduler.TuningOptions(
            num_measure_trials=ARGS.num_trials,
            runner=runner,
            measure_callbacks=[
                auto_scheduler.RecordToFile(log_file),
            ],
        )
    )

    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(
            opt_level=3,
            config={"relay.backend.use_auto_scheduler": True},
        ):
            lib = relay.build(
                mod,
                target=ARGS.target,
                params=params,
            )
    graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
    input_data = {}
    for item in ARGS.input_shape:
        input_name, input_shape, input_dtype = item["name"], item["shape"], item["dtype"]
        if input_dtype.startswith("float"):
            input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype)
        else:
            input_data[input_name] = np.random.randint(
                low=0, high=10000, size=input_shape, dtype=input_dtype
            )

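    # f_timer: invoked by run_module_via_rpc with the module loaded on the remote
    # device; feeds the random inputs and reports end-to-end latency in milliseconds.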
    def f_timer(rt_mod, dev, input_data):
        # pylint: disable=import-outside-toplevel
        from tvm.contrib.graph_executor import GraphModule

        # pylint: enable=import-outside-toplevel

        mod = GraphModule(rt_mod["default"](dev))
        for input_name, input_value in input_data.items():
            mod.set_input(input_name, input_value)
        ftimer = mod.module.time_evaluator(
            "run",
            dev,
            min_repeat_ms=500,
            repeat=3,
        )
        results = list(np.array(ftimer().results) * 1000.0)  # type: ignore
        print("Running time in time_evaluator: ", results)

    run_module_via_rpc(
        rpc_config=ARGS.rpc_config,
        lib=lib,
        dev_type=ARGS.target.kind.name,
        args=input_data,
        continuation=f_timer,
    )

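    # f_per_layer: same RPC flow, but wraps the module in the debug executor to
    # report per-operator times, keyed by node names from the serialized graph JSON.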
    def f_per_layer(rt_mod, dev, input_data):
        # pylint: disable=import-outside-toplevel
        from tvm.contrib.debugger.debug_executor import create

        # pylint: enable=import-outside-toplevel
        mod = create(graph, rt_mod, dev)
        for input_name, input_value in input_data.items():
            mod.set_input(input_name, input_value)
        graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
        graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000)
        print("|graph_nodes| = ", len(graph_nodes))
        print("|graph_time| = ", len(graph_time))
        graph_nodes_time = {k: float(v) for k, v in zip(graph_nodes, graph_time)}
        for k, v in graph_nodes_time.items():
            print(f"{k} : {v:.3f}")

    run_module_via_rpc(
        rpc_config=ARGS.rpc_config,
        lib=rt_mod,
        dev_type=ARGS.target.kind.name,
        args=input_data,
        continuation=f_per_layer,
    )


if __name__ == "__main__":
    main()
```
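For a quick local sanity check without an RPC tracker, the timing step in `f_timer` can be reproduced on the local machine. A minimal sketch, assuming `lib` is the factory module returned by `relay.build` above, `input_data` is the dict of random NumPy inputs built the same way as in the script, and the target is the local CPU:

```python
import numpy as np
import tvm
from tvm.contrib.graph_executor import GraphModule

# Assumes `lib` and `input_data` were produced as in the script above.
dev = tvm.cpu(0)
mod = GraphModule(lib["default"](dev))
for input_name, input_value in input_data.items():
    mod.set_input(input_name, input_value)
# Same measurement settings as f_timer: at least 500 ms per repeat, 3 repeats.
ftimer = mod.module.time_evaluator("run", dev, min_repeat_ms=500, repeat=3)
print("Running time (ms):", list(np.array(ftimer().results) * 1000.0))
```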
