Skip to content

Commit f5a85c2

Browse files
Andrew Reuschareusch
authored andcommitted
Parameterize test_link_params.
1 parent f583a70 commit f5a85c2

1 file changed

Lines changed: 175 additions & 179 deletions

File tree

tests/python/unittest/test_link_params.py

Lines changed: 175 additions & 179 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import os
2121
import re
2222
import sys
23-
import tempfile
2423

2524
import numpy as np
2625
import pytest
@@ -39,15 +38,17 @@
3938

4039

4140
# The data types that are linkable.
42-
LINKABLE_DTYPES = (
43-
[f"uint{b}" for b in (8, 16, 32, 64)]
44-
+ [f"int{b}" for b in (8, 16, 32, 64)]
45-
+ ["float32", "float64"]
41+
linkable_dtype = tvm.testing.parameter(
42+
*(
43+
[f"uint{b}" for b in (8, 16, 32, 64)]
44+
+ [f"int{b}" for b in (8, 16, 32, 64)]
45+
+ ["float32", "float64"]
46+
)
4647
)
4748

4849

4950
def dtype_info(dtype):
50-
"""Lookup numpy type info for the given string dtype (of LINKABLE_DTYPES above)."""
51+
"""Lookup numpy type info for the given string dtype (of linkable_dtype params above)."""
5152
if "int" in dtype:
5253
return np.iinfo(getattr(np, dtype))
5354
else:
@@ -181,58 +182,56 @@ def _add_decl(name, dtype):
181182

182183

183184
@tvm.testing.requires_llvm
184-
def test_llvm_link_params():
185-
for dtype in LINKABLE_DTYPES:
186-
ir_mod, param_init = _make_mod_and_params(dtype)
187-
rand_input = _make_random_tensor(dtype, INPUT_SHAPE)
188-
target = "llvm"
189-
runtime = Runtime("crt", {"system-lib": True})
190-
executor = Executor("graph", {"link-params": True})
191-
with tvm.transform.PassContext(opt_level=3):
192-
lib = tvm.relay.build(
193-
ir_mod, target, runtime=runtime, executor=executor, params=param_init
194-
)
195-
196-
# NOTE: Need to export_library() and load_library() to link all the Module(llvm, ...)
197-
# against one another.
198-
temp_dir = tempfile.mkdtemp()
199-
export_file = os.path.join(temp_dir, "lib.so")
200-
lib.lib.export_library(export_file)
201-
mod = tvm.runtime.load_module(export_file)
202-
assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded
203-
assert mod.get_function("TVMSystemLibEntryPoint") != None
204-
205-
graph = json.loads(lib.graph_json)
206-
for p in lib.params:
207-
_verify_linked_param(dtype, lib, mod, graph, p) or found_one
208-
209-
# Wrap in function to explicitly deallocate the runtime.
210-
def _run_linked(lib, mod):
211-
graph_json, _, _ = lib
212-
graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
213-
graph_rt.set_input("rand_input", rand_input) # NOTE: params not required.
214-
graph_rt.run()
215-
return graph_rt.get_output(0)
216-
217-
linked_output = _run_linked(lib, mod)
218-
219-
runtime = Runtime("cpp", {"system-lib": True})
220-
with tvm.transform.PassContext(opt_level=3):
221-
lib = tvm.relay.build(ir_mod, "llvm", runtime=runtime, params=param_init)
222-
223-
def _run_unlinked(lib):
224-
graph_json, mod, lowered_params = lib
225-
graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
226-
graph_rt.set_input("rand_input", rand_input, **lowered_params)
227-
graph_rt.run()
228-
return graph_rt.get_output(0)
229-
230-
unlinked_output = _run_unlinked(lib)
231-
232-
if "int" in dtype:
233-
np.testing.assert_equal(unlinked_output.numpy(), linked_output.numpy())
234-
else:
235-
np.testing.assert_allclose(unlinked_output.numpy(), linked_output.numpy())
185+
def test_llvm_link_params(linkable_dtype):
186+
ir_mod, param_init = _make_mod_and_params(linkable_dtype)
187+
rand_input = _make_random_tensor(linkable_dtype, INPUT_SHAPE)
188+
main_func = ir_mod["main"]
189+
target = "llvm"
190+
runtime = Runtime("crt", {"system-lib": True})
191+
executor = Executor("graph", {"link-params": True})
192+
with tvm.transform.PassContext(opt_level=3):
193+
lib = tvm.relay.build(ir_mod, target, runtime=runtime, executor=executor, params=param_init)
194+
195+
# NOTE: Need to export_library() and load_library() to link all the Module(llvm, ...)
196+
# against one another.
197+
temp_dir = utils.TempDirectory()
198+
export_file = temp_dir / "lib.so"
199+
lib.lib.export_library(export_file)
200+
mod = tvm.runtime.load_module(export_file)
201+
assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded
202+
assert mod.get_function("TVMSystemLibEntryPoint") != None
203+
204+
graph = json.loads(lib.graph_json)
205+
for p in lib.params:
206+
_verify_linked_param(linkable_dtype, lib, mod, graph, p) or found_one
207+
208+
# Wrap in function to explicitly deallocate the runtime.
209+
def _run_linked(lib, mod):
210+
graph_json, _, _ = lib
211+
graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
212+
graph_rt.set_input("rand_input", rand_input) # NOTE: params not required.
213+
graph_rt.run()
214+
return graph_rt.get_output(0)
215+
216+
linked_output = _run_linked(lib, mod)
217+
218+
runtime = Runtime("cpp", {"system-lib": True})
219+
with tvm.transform.PassContext(opt_level=3):
220+
lib = tvm.relay.build(ir_mod, "llvm", runtime=runtime, params=param_init)
221+
222+
def _run_unlinked(lib):
223+
graph_json, mod, lowered_params = lib
224+
graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
225+
graph_rt.set_input("rand_input", rand_input, **lowered_params)
226+
graph_rt.run()
227+
return graph_rt.get_output(0)
228+
229+
unlinked_output = _run_unlinked(lib)
230+
231+
if "int" in linkable_dtype:
232+
np.testing.assert_equal(unlinked_output.numpy(), linked_output.numpy())
233+
else:
234+
np.testing.assert_allclose(unlinked_output.numpy(), linked_output.numpy())
236235

237236

238237
def _get_c_datatype(dtype):
@@ -267,141 +266,138 @@ def _format_c_value(dtype, width, x):
267266
HEX_NUM_RE = re.compile(r"[+\-]?(?:(?:0x[0-9A-Fa-f.p+-]+)|(?:INFINITY)|(?:NAN))")
268267

269268

270-
def test_c_link_params():
269+
def test_c_link_params(linkable_dtype):
271270
temp_dir = utils.tempdir()
272-
for dtype in LINKABLE_DTYPES:
273-
mod, param_init = _make_mod_and_params(dtype)
274-
rand_input = _make_random_tensor(dtype, INPUT_SHAPE)
275-
target = "c"
276-
executor = Executor("graph", {"link-params": True})
277-
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
278-
lib = tvm.relay.build(mod, target, executor=executor, params=param_init)
279-
assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded
280-
281-
src = lib.lib.get_source()
282-
lib.lib.save(temp_dir.relpath("test.c"), "c")
283-
c_dtype = _get_c_datatype(dtype)
284-
src_lines = src.split("\n")
285-
param = lib.params["p0"].numpy().reshape(np.prod(KERNEL_SHAPE))
286-
param_def = f'static const {c_dtype} __attribute__((section(".rodata.tvm"), aligned(16))) __tvm_param__p0[{np.prod(param.shape)}] = {{'
287-
for i, line in enumerate(src_lines):
288-
if line == param_def:
289-
i += 1
290-
break
291-
else:
292-
assert False, f'did not find parameter definition "{param_def}":\n{src}'
293-
294-
cursor = 0
295-
width = dtype_info(dtype).bits // 4 + 2
296-
if dtype.startswith("int"):
297-
width += 1 # Account for sign
298-
299-
while "};" not in src_lines[i]:
300-
for match in HEX_NUM_RE.finditer(src_lines[i]):
301-
assert match.group() == _format_c_value(dtype, width, param[cursor]), (
302-
f'p0 byte {cursor}: want "{_format_c_value(dtype, width, param[cursor])}" got '
303-
f'"{match.group(0)}"; full p0 follows:\n{src}'
304-
)
305-
cursor += 1
271+
mod, param_init = _make_mod_and_params(linkable_dtype)
272+
rand_input = _make_random_tensor(linkable_dtype, INPUT_SHAPE)
273+
main_func = mod["main"]
274+
target = "c"
275+
executor = Executor("graph", {"link-params": True})
276+
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
277+
lib = tvm.relay.build(mod, target, executor=executor, params=param_init)
278+
assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded
279+
280+
src = lib.lib.get_source()
281+
lib.lib.save(temp_dir.relpath("test.c"), "c")
282+
c_dtype = _get_c_datatype(linkable_dtype)
283+
src_lines = src.split("\n")
284+
param = lib.params["p0"].numpy().reshape(np.prod(KERNEL_SHAPE))
285+
param_def = f'static const {c_dtype} __attribute__((section(".rodata.tvm"), aligned(16))) __tvm_param__p0[{np.prod(param.shape)}] = {{'
286+
287+
for i, line in enumerate(src_lines):
288+
if line == param_def:
306289
i += 1
307-
308-
assert cursor == np.prod(param.shape)
309-
310-
# Need a unique name per library to avoid dlopen caching the lib load.
311-
lib_path = temp_dir.relpath(f"test-{dtype}-linked.so")
312-
lib["remove_params"]().export_library(lib_path)
313-
lib_mod = tvm.runtime.load_module(lib_path)
314-
315-
# lib_mod = lib_factory['default']()
316-
graph = json.loads(lib.graph_json)
317-
for p in lib.params:
318-
_verify_linked_param(dtype, lib, lib_mod, graph, p)
319-
320-
# Wrap in function to explicitly deallocate the runtime.
321-
def _run_linked(lib_mod):
322-
graph_rt = tvm.contrib.graph_executor.GraphModule(lib_mod["default"](tvm.cpu(0)))
323-
graph_rt.set_input("rand_input", rand_input) # NOTE: params not required.
324-
graph_rt.run()
325-
326-
return graph_rt.get_output(0)
327-
328-
linked_output = _run_linked(lib_mod)
329-
330-
linked_params = lib.params
331-
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
332-
lib = tvm.relay.build(mod, "c", params=param_init)
333-
_, _, params = lib
334-
# Need a unique name per library to avoid dlopen caching the lib load.
335-
lib_path = temp_dir.relpath(f"test-{dtype}-unlinked.so")
336-
lib.export_library(lib_path)
337-
lib_mod = tvm.runtime.load_module(lib_path)
338-
339-
def _run_unlinked(lib_mod):
340-
graph_rt = tvm.contrib.graph_executor.GraphModule(lib_mod["default"](tvm.cpu(0)))
341-
graph_rt.set_input("rand_input", rand_input, **params)
342-
graph_rt.run()
343-
return graph_rt.get_output(0)
344-
345-
unlinked_output = _run_unlinked(lib_mod)
346-
347-
if "int" in dtype:
348-
np.testing.assert_equal(unlinked_output.numpy(), linked_output.numpy())
290+
break
349291
else:
350-
np.testing.assert_allclose(unlinked_output.numpy(), linked_output.numpy())
292+
assert False, f'did not find parameter definition "{param_def}":\n{src}'
293+
294+
cursor = 0
295+
width = dtype_info(linkable_dtype).bits // 4 + 2
296+
if linkable_dtype.startswith("int"):
297+
width += 1 # Account for sign
298+
299+
while "};" not in src_lines[i]:
300+
for match in HEX_NUM_RE.finditer(src_lines[i]):
301+
assert match.group() == _format_c_value(linkable_dtype, width, param[cursor]), (
302+
f'p0 byte {cursor}: want "{_format_c_value(linkable_dtype, width, param[cursor])}" got '
303+
f'"{match.group(0)}"; full p0 follows:\n{src}'
304+
)
305+
cursor += 1
306+
i += 1
307+
308+
assert cursor == np.prod(param.shape)
309+
310+
# Need a unique name per library to avoid dlopen caching the lib load.
311+
lib_path = temp_dir.relpath(f"test-{linkable_dtype}-linked.so")
312+
lib["remove_params"]().export_library(lib_path)
313+
lib_mod = tvm.runtime.load_module(lib_path)
314+
315+
# lib_mod = lib_factory['default']()
316+
graph = json.loads(lib.graph_json)
317+
for p in lib.params:
318+
_verify_linked_param(linkable_dtype, lib, lib_mod, graph, p)
319+
320+
# Wrap in function to explicitly deallocate the runtime.
321+
def _run_linked(lib_mod):
322+
graph_rt = tvm.contrib.graph_executor.GraphModule(lib_mod["default"](tvm.cpu(0)))
323+
graph_rt.set_input("rand_input", rand_input) # NOTE: params not required.
324+
graph_rt.run()
325+
326+
return graph_rt.get_output(0)
327+
328+
linked_output = _run_linked(lib_mod)
329+
330+
linked_params = lib.params
331+
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
332+
lib = tvm.relay.build(mod, "c", params=param_init)
333+
_, _, params = lib
334+
# Need a unique name per library to avoid dlopen caching the lib load.
335+
lib_path = temp_dir.relpath(f"test-{linkable_dtype}-unlinked.so")
336+
lib.export_library(lib_path)
337+
lib_mod = tvm.runtime.load_module(lib_path)
338+
339+
def _run_unlinked(lib_mod):
340+
graph_rt = tvm.contrib.graph_executor.GraphModule(lib_mod["default"](tvm.cpu(0)))
341+
graph_rt.set_input("rand_input", rand_input, **params)
342+
graph_rt.run()
343+
return graph_rt.get_output(0)
344+
345+
unlinked_output = _run_unlinked(lib_mod)
346+
347+
if "int" in linkable_dtype:
348+
np.testing.assert_equal(unlinked_output.numpy(), linked_output.numpy())
349+
else:
350+
np.testing.assert_allclose(unlinked_output.numpy(), linked_output.numpy())
351351

352352

353353
@tvm.testing.requires_micro
354-
def test_crt_link_params():
354+
def test_crt_link_params(linkable_dtype):
355355
from tvm import micro
356356

357-
for dtype in LINKABLE_DTYPES:
358-
mod, param_init = _make_mod_and_params(dtype)
359-
rand_input = _make_random_tensor(dtype, INPUT_SHAPE)
360-
target = "c"
361-
runtime = Runtime("crt", {"system-lib": True})
362-
executor = Executor("graph", {"link-params": True})
363-
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
364-
factory = tvm.relay.build(
365-
mod, target, runtime=runtime, executor=executor, params=param_init
357+
mod, param_init = _make_mod_and_params(linkable_dtype)
358+
rand_input = _make_random_tensor(linkable_dtype, INPUT_SHAPE)
359+
main_func = mod["main"]
360+
target = "c"
361+
runtime = Runtime("crt", {"system-lib": True})
362+
executor = Executor("graph", {"link-params": True})
363+
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
364+
factory = tvm.relay.build(mod, target, runtime=runtime, executor=executor, params=param_init)
365+
assert set(factory.get_params().keys()) == {"p0", "p1"} # NOTE: op folded
366+
367+
temp_dir = tvm.contrib.utils.tempdir()
368+
template_project_dir = os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host")
369+
project = tvm.micro.generate_project(
370+
template_project_dir, factory, temp_dir / "project", {"verbose": 1}
371+
)
372+
project.build()
373+
project.flash()
374+
with tvm.micro.Session(project.transport()) as sess:
375+
graph_rt = tvm.micro.session.create_local_graph_executor(
376+
factory.get_graph_json(), sess.get_system_lib(), sess.device
366377
)
367-
assert set(factory.get_params().keys()) == {"p0", "p1"} # NOTE: op folded
368378

369-
temp_dir = tvm.contrib.utils.tempdir()
370-
template_project_dir = os.path.join(
371-
tvm.micro.get_standalone_crt_dir(), "template", "host"
372-
)
373-
project = tvm.micro.generate_project(
374-
template_project_dir, factory, temp_dir / "project", {"verbose": 1}
375-
)
376-
project.build()
377-
project.flash()
378-
with tvm.micro.Session(project.transport()) as sess:
379-
graph_rt = tvm.micro.session.create_local_graph_executor(
380-
factory.get_graph_json(), sess.get_system_lib(), sess.device
381-
)
382-
383-
# NOTE: not setting params here.
384-
graph_rt.set_input("rand_input", rand_input)
385-
graph_rt.run()
386-
linked_output = graph_rt.get_output(0).numpy()
379+
# NOTE: not setting params here.
380+
graph_rt.set_input("rand_input", rand_input)
381+
graph_rt.run()
382+
linked_output = graph_rt.get_output(0).numpy()
387383

388-
runtime = Runtime("cpp", {"system-lib": True})
389-
with tvm.transform.PassContext(opt_level=3):
390-
lib = tvm.relay.build(mod, "llvm", runtime=runtime, params=param_init)
384+
runtime = Runtime("cpp", {"system-lib": True})
385+
with tvm.transform.PassContext(opt_level=3):
386+
lib = tvm.relay.build(mod, "llvm", runtime=runtime, params=param_init)
391387

392-
def _run_unlinked(lib):
393-
graph_json, mod, lowered_params = lib
394-
graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
395-
graph_rt.set_input("rand_input", rand_input, **lowered_params)
396-
graph_rt.run()
397-
return graph_rt.get_output(0).numpy()
388+
def _run_unlinked(lib):
389+
graph_json, mod, lowered_params = lib
390+
graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
391+
graph_rt.set_input("rand_input", rand_input, **lowered_params)
392+
graph_rt.run()
393+
return graph_rt.get_output(0).numpy()
398394

399-
unlinked_output = _run_unlinked(lib)
395+
unlinked_output = _run_unlinked(lib)
400396

401-
if "int" in dtype:
402-
np.testing.assert_equal(unlinked_output, linked_output)
403-
else:
404-
np.testing.assert_allclose(unlinked_output, linked_output)
397+
if "int" in linkable_dtype:
398+
np.testing.assert_equal(unlinked_output, linked_output)
399+
else:
400+
np.testing.assert_allclose(unlinked_output, linked_output)
405401

406402

407403
if __name__ == "__main__":

0 commit comments

Comments
 (0)