Skip to content

Commit 64639b9

Browse files
author
Andrew Reusch
committed
Parameterize test_link_params.
1 parent 2dc58be commit 64639b9

1 file changed

Lines changed: 167 additions & 171 deletions

File tree

tests/python/unittest/test_link_params.py

Lines changed: 167 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@
3939

4040

4141
# The data types that are linkable.
42-
LINKABLE_DTYPES = (
43-
[f"uint{b}" for b in (8, 16, 32, 64)]
44-
+ [f"int{b}" for b in (8, 16, 32, 64)]
45-
+ ["float32", "float64"]
42+
linkable_dtype = tvm.testing.parameter(
43+
*([f"uint{b}" for b in (8, 16, 32, 64)]
44+
+ [f"int{b}" for b in (8, 16, 32, 64)]
45+
+ ["float32", "float64"])
4646
)
4747

4848

4949
def dtype_info(dtype):
50-
"""Lookup numpy type info for the given string dtype (of LINKABLE_DTYPES above)."""
50+
"""Lookup numpy type info for the given string dtype (of linkable_dtype params above)."""
5151
if "int" in dtype:
5252
return np.iinfo(getattr(np, dtype))
5353
else:
@@ -181,54 +181,53 @@ def _add_decl(name, dtype):
181181

182182

183183
@tvm.testing.requires_llvm
184-
def test_llvm_link_params():
185-
for dtype in LINKABLE_DTYPES:
186-
ir_mod, param_init = _make_mod_and_params(dtype)
187-
rand_input = _make_random_tensor(dtype, INPUT_SHAPE)
188-
main_func = ir_mod["main"]
189-
target = "llvm --runtime=c --system-lib --link-params"
190-
with tvm.transform.PassContext(opt_level=3):
191-
lib = tvm.relay.build(ir_mod, target, params=param_init)
192-
193-
# NOTE: Need to export_library() and load_library() to link all the Module(llvm, ...)
194-
# against one another.
195-
temp_dir = tempfile.mkdtemp()
196-
export_file = os.path.join(temp_dir, "lib.so")
197-
lib.lib.export_library(export_file)
198-
mod = tvm.runtime.load_module(export_file)
199-
assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded
200-
assert mod.get_function("TVMSystemLibEntryPoint") != None
201-
202-
graph = json.loads(lib.graph_json)
203-
for p in lib.params:
204-
_verify_linked_param(dtype, lib, mod, graph, p) or found_one
205-
206-
# Wrap in function to explicitly deallocate the runtime.
207-
def _run_linked(lib, mod):
208-
graph_json, _, _ = lib
209-
graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
210-
graph_rt.set_input("rand_input", rand_input) # NOTE: params not required.
211-
graph_rt.run()
212-
return graph_rt.get_output(0)
213-
214-
linked_output = _run_linked(lib, mod)
215-
216-
with tvm.transform.PassContext(opt_level=3):
217-
lib = tvm.relay.build(ir_mod, "llvm --system-lib", params=param_init)
218-
219-
def _run_unlinked(lib):
220-
graph_json, mod, lowered_params = lib
221-
graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
222-
graph_rt.set_input("rand_input", rand_input, **lowered_params)
223-
graph_rt.run()
224-
return graph_rt.get_output(0)
225-
226-
unlinked_output = _run_unlinked(lib)
227-
228-
if "int" in dtype:
229-
np.testing.assert_equal(unlinked_output.numpy(), linked_output.numpy())
230-
else:
231-
np.testing.assert_allclose(unlinked_output.numpy(), linked_output.numpy())
184+
def test_llvm_link_params(linkable_dtype):
185+
ir_mod, param_init = _make_mod_and_params(linkable_dtype)
186+
rand_input = _make_random_tensor(linkable_dtype, INPUT_SHAPE)
187+
main_func = ir_mod["main"]
188+
target = "llvm --runtime=c --system-lib --link-params"
189+
with tvm.transform.PassContext(opt_level=3):
190+
lib = tvm.relay.build(ir_mod, target, params=param_init)
191+
192+
# NOTE: Need to export_library() and load_library() to link all the Module(llvm, ...)
193+
# against one another.
194+
temp_dir = tempfile.mkdtemp()
195+
export_file = os.path.join(temp_dir, "lib.so")
196+
lib.lib.export_library(export_file)
197+
mod = tvm.runtime.load_module(export_file)
198+
assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded
199+
assert mod.get_function("TVMSystemLibEntryPoint") != None
200+
201+
graph = json.loads(lib.graph_json)
202+
for p in lib.params:
203+
_verify_linked_param(linkable_dtype, lib, mod, graph, p) or found_one
204+
205+
# Wrap in function to explicitly deallocate the runtime.
206+
def _run_linked(lib, mod):
207+
graph_json, _, _ = lib
208+
graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
209+
graph_rt.set_input("rand_input", rand_input) # NOTE: params not required.
210+
graph_rt.run()
211+
return graph_rt.get_output(0)
212+
213+
linked_output = _run_linked(lib, mod)
214+
215+
with tvm.transform.PassContext(opt_level=3):
216+
lib = tvm.relay.build(ir_mod, "llvm --system-lib", params=param_init)
217+
218+
def _run_unlinked(lib):
219+
graph_json, mod, lowered_params = lib
220+
graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
221+
graph_rt.set_input("rand_input", rand_input, **lowered_params)
222+
graph_rt.run()
223+
return graph_rt.get_output(0)
224+
225+
unlinked_output = _run_unlinked(lib)
226+
227+
if "int" in linkable_dtype:
228+
np.testing.assert_equal(unlinked_output.numpy(), linked_output.numpy())
229+
else:
230+
np.testing.assert_allclose(unlinked_output.numpy(), linked_output.numpy())
232231

233232

234233
def _get_c_datatype(dtype):
@@ -263,137 +262,134 @@ def _format_c_value(dtype, width, x):
263262
HEX_NUM_RE = re.compile(r"[+\-]?(?:(?:0x[0-9A-Fa-f.p+-]+)|(?:INFINITY)|(?:NAN))")
264263

265264

266-
def test_c_link_params():
265+
def test_c_link_params(linkable_dtype):
267266
temp_dir = utils.tempdir()
268-
for dtype in LINKABLE_DTYPES:
269-
mod, param_init = _make_mod_and_params(dtype)
270-
rand_input = _make_random_tensor(dtype, INPUT_SHAPE)
271-
main_func = mod["main"]
272-
target = "c --link-params"
273-
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
274-
lib = tvm.relay.build(mod, target, params=param_init)
275-
assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded
276-
277-
src = lib.lib.get_source()
278-
lib.lib.save(temp_dir.relpath("test.c"), "c")
279-
c_dtype = _get_c_datatype(dtype)
280-
src_lines = src.split("\n")
281-
param = lib.params["p0"].numpy().reshape(np.prod(KERNEL_SHAPE))
282-
param_def = f"static const {c_dtype} __tvm_param__p0[{np.prod(param.shape)}] = {{"
283-
for i, line in enumerate(src_lines):
284-
if line == param_def:
285-
i += 1
286-
break
287-
else:
288-
assert False, f'did not find parameter definition "{param_def}":\n{src}'
289-
290-
cursor = 0
291-
width = dtype_info(dtype).bits // 4 + 2
292-
if dtype.startswith("int"):
293-
width += 1 # Account for sign
294-
295-
while "};" not in src_lines[i]:
296-
for match in HEX_NUM_RE.finditer(src_lines[i]):
297-
assert match.group() == _format_c_value(dtype, width, param[cursor]), (
298-
f'p0 byte {cursor}: want "{_format_c_value(dtype, width, param[cursor])}" got '
299-
f'"{match.group(0)}"; full p0 follows:\n{src}'
300-
)
301-
cursor += 1
267+
mod, param_init = _make_mod_and_params(linkable_dtype)
268+
rand_input = _make_random_tensor(linkable_dtype, INPUT_SHAPE)
269+
main_func = mod["main"]
270+
target = "c --link-params"
271+
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
272+
lib = tvm.relay.build(mod, target, params=param_init)
273+
assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded
274+
275+
src = lib.lib.get_source()
276+
lib.lib.save(temp_dir.relpath("test.c"), "c")
277+
c_dtype = _get_c_datatype(linkable_dtype)
278+
src_lines = src.split("\n")
279+
param = lib.params["p0"].numpy().reshape(np.prod(KERNEL_SHAPE))
280+
param_def = f"static const {c_dtype} __tvm_param__p0[{np.prod(param.shape)}] = {{"
281+
for i, line in enumerate(src_lines):
282+
if line == param_def:
302283
i += 1
303-
304-
assert cursor == np.prod(param.shape)
305-
306-
# Need a unique name per library to avoid dlopen caching the lib load.
307-
lib_path = temp_dir.relpath(f"test-{dtype}-linked.so")
308-
lib["remove_params"]().export_library(lib_path)
309-
lib_mod = tvm.runtime.load_module(lib_path)
310-
311-
# lib_mod = lib_factory['default']()
312-
graph = json.loads(lib.graph_json)
313-
for p in lib.params:
314-
_verify_linked_param(dtype, lib, lib_mod, graph, p)
315-
316-
# Wrap in function to explicitly deallocate the runtime.
317-
def _run_linked(lib_mod):
318-
graph_rt = tvm.contrib.graph_executor.GraphModule(lib_mod["default"](tvm.cpu(0)))
319-
graph_rt.set_input("rand_input", rand_input) # NOTE: params not required.
320-
graph_rt.run()
321-
322-
return graph_rt.get_output(0)
323-
324-
linked_output = _run_linked(lib_mod)
325-
326-
linked_params = lib.params
327-
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
328-
lib = tvm.relay.build(mod, "c", params=param_init)
329-
_, _, params = lib
330-
# Need a unique name per library to avoid dlopen caching the lib load.
331-
lib_path = temp_dir.relpath(f"test-{dtype}-unlinked.so")
332-
lib.export_library(lib_path)
333-
lib_mod = tvm.runtime.load_module(lib_path)
334-
335-
def _run_unlinked(lib_mod):
336-
graph_rt = tvm.contrib.graph_executor.GraphModule(lib_mod["default"](tvm.cpu(0)))
337-
graph_rt.set_input("rand_input", rand_input, **params)
338-
graph_rt.run()
339-
return graph_rt.get_output(0)
340-
341-
unlinked_output = _run_unlinked(lib_mod)
342-
343-
if "int" in dtype:
344-
np.testing.assert_equal(unlinked_output.numpy(), linked_output.numpy())
284+
break
345285
else:
346-
np.testing.assert_allclose(unlinked_output.numpy(), linked_output.numpy())
286+
assert False, f'did not find parameter definition "{param_def}":\n{src}'
287+
288+
cursor = 0
289+
width = dtype_info(linkable_dtype).bits // 4 + 2
290+
if linkable_dtype.startswith("int"):
291+
width += 1 # Account for sign
292+
293+
while "};" not in src_lines[i]:
294+
for match in HEX_NUM_RE.finditer(src_lines[i]):
295+
assert match.group() == _format_c_value(linkable_dtype, width, param[cursor]), (
296+
f'p0 byte {cursor}: want "{_format_c_value(linkable_dtype, width, param[cursor])}" got '
297+
f'"{match.group(0)}"; full p0 follows:\n{src}'
298+
)
299+
cursor += 1
300+
i += 1
301+
302+
assert cursor == np.prod(param.shape)
303+
304+
# Need a unique name per library to avoid dlopen caching the lib load.
305+
lib_path = temp_dir.relpath(f"test-{linkable_dtype}-linked.so")
306+
lib["remove_params"]().export_library(lib_path)
307+
lib_mod = tvm.runtime.load_module(lib_path)
308+
309+
# lib_mod = lib_factory['default']()
310+
graph = json.loads(lib.graph_json)
311+
for p in lib.params:
312+
_verify_linked_param(linkable_dtype, lib, lib_mod, graph, p)
313+
314+
# Wrap in function to explicitly deallocate the runtime.
315+
def _run_linked(lib_mod):
316+
graph_rt = tvm.contrib.graph_executor.GraphModule(lib_mod["default"](tvm.cpu(0)))
317+
graph_rt.set_input("rand_input", rand_input) # NOTE: params not required.
318+
graph_rt.run()
319+
320+
return graph_rt.get_output(0)
321+
322+
linked_output = _run_linked(lib_mod)
323+
324+
linked_params = lib.params
325+
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
326+
lib = tvm.relay.build(mod, "c", params=param_init)
327+
_, _, params = lib
328+
# Need a unique name per library to avoid dlopen caching the lib load.
329+
lib_path = temp_dir.relpath(f"test-{linkable_dtype}-unlinked.so")
330+
lib.export_library(lib_path)
331+
lib_mod = tvm.runtime.load_module(lib_path)
332+
333+
def _run_unlinked(lib_mod):
334+
graph_rt = tvm.contrib.graph_executor.GraphModule(lib_mod["default"](tvm.cpu(0)))
335+
graph_rt.set_input("rand_input", rand_input, **params)
336+
graph_rt.run()
337+
return graph_rt.get_output(0)
338+
339+
unlinked_output = _run_unlinked(lib_mod)
340+
341+
if "int" in linkable_dtype:
342+
np.testing.assert_equal(unlinked_output.numpy(), linked_output.numpy())
343+
else:
344+
np.testing.assert_allclose(unlinked_output.numpy(), linked_output.numpy())
347345

348346

349347
@tvm.testing.requires_micro
350-
def test_crt_link_params():
348+
def test_crt_link_params(linkable_dtype):
351349
from tvm import micro
352-
353-
for dtype in LINKABLE_DTYPES:
354-
mod, param_init = _make_mod_and_params(dtype)
355-
rand_input = _make_random_tensor(dtype, INPUT_SHAPE)
356-
main_func = mod["main"]
357-
target = "c --system-lib --runtime=c --link-params"
358-
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
359-
factory = tvm.relay.build(mod, target, params=param_init)
360-
assert set(factory.get_params().keys()) == {"p0", "p1"} # NOTE: op folded
361-
362-
temp_dir = tvm.contrib.utils.tempdir()
363-
template_project_dir = os.path.join(
364-
tvm.micro.get_standalone_crt_dir(), "template", "host"
365-
)
366-
project = tvm.micro.generate_project(
367-
template_project_dir, factory, temp_dir / "project", {"verbose": 1}
350+
mod, param_init = _make_mod_and_params(linkable_dtype)
351+
rand_input = _make_random_tensor(linkable_dtype, INPUT_SHAPE)
352+
main_func = mod["main"]
353+
target = "c --system-lib --runtime=c --link-params"
354+
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
355+
factory = tvm.relay.build(mod, target, params=param_init)
356+
assert set(factory.get_params().keys()) == {"p0", "p1"} # NOTE: op folded
357+
358+
temp_dir = tvm.contrib.utils.tempdir()
359+
template_project_dir = os.path.join(
360+
tvm.micro.get_standalone_crt_dir(), "template", "host"
361+
)
362+
project = tvm.micro.generate_project(
363+
template_project_dir, factory, temp_dir / "project", {"verbose": 1}
364+
)
365+
project.build()
366+
project.flash()
367+
with tvm.micro.Session(project.transport()) as sess:
368+
graph_rt = tvm.micro.session.create_local_graph_executor(
369+
factory.get_graph_json(), sess.get_system_lib(), sess.device
368370
)
369-
project.build()
370-
project.flash()
371-
with tvm.micro.Session(project.transport()) as sess:
372-
graph_rt = tvm.micro.session.create_local_graph_executor(
373-
factory.get_graph_json(), sess.get_system_lib(), sess.device
374-
)
375371

376-
# NOTE: not setting params here.
377-
graph_rt.set_input("rand_input", rand_input)
378-
graph_rt.run()
379-
linked_output = graph_rt.get_output(0).numpy()
372+
# NOTE: not setting params here.
373+
graph_rt.set_input("rand_input", rand_input)
374+
graph_rt.run()
375+
linked_output = graph_rt.get_output(0).numpy()
380376

381-
with tvm.transform.PassContext(opt_level=3):
382-
lib = tvm.relay.build(mod, "llvm --system-lib", params=param_init)
377+
with tvm.transform.PassContext(opt_level=3):
378+
lib = tvm.relay.build(mod, "llvm --system-lib", params=param_init)
383379

384-
def _run_unlinked(lib):
385-
graph_json, mod, lowered_params = lib
386-
graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
387-
graph_rt.set_input("rand_input", rand_input, **lowered_params)
388-
graph_rt.run()
389-
return graph_rt.get_output(0).numpy()
380+
def _run_unlinked(lib):
381+
graph_json, mod, lowered_params = lib
382+
graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
383+
graph_rt.set_input("rand_input", rand_input, **lowered_params)
384+
graph_rt.run()
385+
return graph_rt.get_output(0).numpy()
390386

391-
unlinked_output = _run_unlinked(lib)
387+
unlinked_output = _run_unlinked(lib)
392388

393-
if "int" in dtype:
394-
np.testing.assert_equal(unlinked_output, linked_output)
395-
else:
396-
np.testing.assert_allclose(unlinked_output, linked_output)
389+
if "int" in linkable_dtype:
390+
np.testing.assert_equal(unlinked_output, linked_output)
391+
else:
392+
np.testing.assert_allclose(unlinked_output, linked_output)
397393

398394

399395
if __name__ == "__main__":

0 commit comments

Comments
 (0)