|
20 | 20 | import os |
21 | 21 | import re |
22 | 22 | import sys |
23 | | -import tempfile |
24 | 23 |
|
25 | 24 | import numpy as np |
26 | 25 | import pytest |
|
39 | 38 |
|
40 | 39 |
|
# The data types that are linkable.  Exposed through tvm.testing.parameter so that a
# test taking a `linkable_dtype` argument is parametrized over these dtype strings
# instead of looping over them inside a single test body.
linkable_dtype = tvm.testing.parameter(
    *(
        [f"uint{b}" for b in (8, 16, 32, 64)]
        + [f"int{b}" for b in (8, 16, 32, 64)]
        + ["float32", "float64"]
    )
)
47 | 48 |
|
48 | 49 |
|
49 | 50 | def dtype_info(dtype): |
50 | | - """Lookup numpy type info for the given string dtype (of LINKABLE_DTYPES above).""" |
| 51 | + """Lookup numpy type info for the given string dtype (of linkable_dtype params above).""" |
51 | 52 | if "int" in dtype: |
52 | 53 | return np.iinfo(getattr(np, dtype)) |
53 | 54 | else: |
@@ -181,58 +182,56 @@ def _add_decl(name, dtype): |
181 | 182 |
|
182 | 183 |
|
@tvm.testing.requires_llvm
def test_llvm_link_params(linkable_dtype):
    """Compare an LLVM build with linked params against an unlinked LLVM build.

    The module is built for the "llvm" target with the CRT runtime and the
    ``link-params`` executor option, exported to a shared library and reloaded.
    Each linked parameter is checked with ``_verify_linked_param``, then the
    graph is run without supplying params and its output is compared with a
    build where params are passed explicitly.

    Parameters
    ----------
    linkable_dtype : str
        Dtype string supplied by the ``linkable_dtype`` test parameter.
    """
    ir_mod, param_init = _make_mod_and_params(linkable_dtype)
    rand_input = _make_random_tensor(linkable_dtype, INPUT_SHAPE)
    target = "llvm"
    runtime = Runtime("crt", {"system-lib": True})
    executor = Executor("graph", {"link-params": True})
    with tvm.transform.PassContext(opt_level=3):
        lib = tvm.relay.build(
            ir_mod, target, runtime=runtime, executor=executor, params=param_init
        )

    # NOTE: Need to export_library() and load_library() to link all the Module(llvm, ...)
    # against one another.
    temp_dir = utils.TempDirectory()
    export_file = temp_dir / "lib.so"
    lib.lib.export_library(export_file)
    mod = tvm.runtime.load_module(export_file)
    assert set(lib.params.keys()) == {"p0", "p1"}  # NOTE: op folded
    assert mod.get_function("TVMSystemLibEntryPoint") is not None

    graph = json.loads(lib.graph_json)
    for p in lib.params:
        # The previous form `... or found_one` referenced an undefined name, so a falsy
        # result surfaced as a NameError.  Assert directly to get a readable failure.
        assert _verify_linked_param(
            linkable_dtype, lib, mod, graph, p
        ), f"linked parameter {p!r} failed verification"

    # Wrap in function to explicitly deallocate the runtime.
    def _run_linked(lib, mod):
        graph_json, _, _ = lib
        graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
        graph_rt.set_input("rand_input", rand_input)  # NOTE: params not required.
        graph_rt.run()
        return graph_rt.get_output(0)

    linked_output = _run_linked(lib, mod)

    # Reference build: same module without link-params, params passed at run time.
    runtime = Runtime("cpp", {"system-lib": True})
    with tvm.transform.PassContext(opt_level=3):
        lib = tvm.relay.build(ir_mod, "llvm", runtime=runtime, params=param_init)

    def _run_unlinked(lib):
        graph_json, mod, lowered_params = lib
        graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
        graph_rt.set_input("rand_input", rand_input, **lowered_params)
        graph_rt.run()
        return graph_rt.get_output(0)

    unlinked_output = _run_unlinked(lib)

    # Integer dtypes must match exactly; floating-point dtypes are compared with tolerance.
    if "int" in linkable_dtype:
        np.testing.assert_equal(unlinked_output.numpy(), linked_output.numpy())
    else:
        np.testing.assert_allclose(unlinked_output.numpy(), linked_output.numpy())
236 | 235 |
|
237 | 236 |
|
238 | 237 | def _get_c_datatype(dtype): |
@@ -267,141 +266,138 @@ def _format_c_value(dtype, width, x): |
# Matches one numeric literal as it is searched for in the generated C source: an
# optional sign followed by a 0x-prefixed hex literal (hex digits plus the '.', 'p',
# '+', '-' characters), or the tokens INFINITY / NAN.
HEX_NUM_RE = re.compile(r"[+\-]?(?:(?:0x[0-9A-Fa-f.p+-]+)|(?:INFINITY)|(?:NAN))")
268 | 267 |
|
269 | 268 |
|
def test_c_link_params(linkable_dtype):
    """Check params linked into generated C source and compare with an unlinked build.

    The module is built for the "c" target with the ``link-params`` executor
    option.  The generated source is searched for the ``__tvm_param__p0``
    array definition and each literal in it is compared against the matching
    element of ``p0``.  The linked library is then exported, loaded and run
    without supplying params, and its output is compared with an unlinked
    "c" build that receives params explicitly.

    Parameters
    ----------
    linkable_dtype : str
        Dtype string supplied by the ``linkable_dtype`` test parameter.
    """
    temp_dir = utils.tempdir()
    mod, param_init = _make_mod_and_params(linkable_dtype)
    rand_input = _make_random_tensor(linkable_dtype, INPUT_SHAPE)
    target = "c"
    executor = Executor("graph", {"link-params": True})
    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
        lib = tvm.relay.build(mod, target, executor=executor, params=param_init)
    assert set(lib.params.keys()) == {"p0", "p1"}  # NOTE: op folded

    src = lib.lib.get_source()
    lib.lib.save(temp_dir.relpath("test.c"), "c")
    c_dtype = _get_c_datatype(linkable_dtype)
    src_lines = src.split("\n")
    param = lib.params["p0"].numpy().reshape(np.prod(KERNEL_SHAPE))
    param_def = f'static const {c_dtype} __attribute__((section(".rodata.tvm"), aligned(16))) __tvm_param__p0[{np.prod(param.shape)}] = {{'

    # Locate the first line after the parameter definition header.
    try:
        i = src_lines.index(param_def) + 1
    except ValueError:
        raise AssertionError(
            f'did not find parameter definition "{param_def}":\n{src}'
        ) from None

    cursor = 0
    width = dtype_info(linkable_dtype).bits // 4 + 2
    if linkable_dtype.startswith("int"):
        width += 1  # Account for sign

    # Walk the initializer lines up to the closing "};", checking every literal in order.
    while "};" not in src_lines[i]:
        for match in HEX_NUM_RE.finditer(src_lines[i]):
            want = _format_c_value(linkable_dtype, width, param[cursor])
            got = match.group()
            assert got == want, (
                f'p0 byte {cursor}: want "{want}" got ' f'"{got}"; full p0 follows:\n{src}'
            )
            cursor += 1
        i += 1

    # Every element of p0 must have been seen in the source.
    assert cursor == np.prod(param.shape)

    # Need a unique name per library to avoid dlopen caching the lib load.
    lib_path = temp_dir.relpath(f"test-{linkable_dtype}-linked.so")
    lib["remove_params"]().export_library(lib_path)
    lib_mod = tvm.runtime.load_module(lib_path)

    graph = json.loads(lib.graph_json)
    for p in lib.params:
        _verify_linked_param(linkable_dtype, lib, lib_mod, graph, p)

    # Wrap in function to explicitly deallocate the runtime.
    def _run_linked(lib_mod):
        graph_rt = tvm.contrib.graph_executor.GraphModule(lib_mod["default"](tvm.cpu(0)))
        graph_rt.set_input("rand_input", rand_input)  # NOTE: params not required.
        graph_rt.run()
        return graph_rt.get_output(0)

    linked_output = _run_linked(lib_mod)

    # Reference build: same module without link-params, params passed at run time.
    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
        lib = tvm.relay.build(mod, "c", params=param_init)
        _, _, params = lib
        # Need a unique name per library to avoid dlopen caching the lib load.
        lib_path = temp_dir.relpath(f"test-{linkable_dtype}-unlinked.so")
        lib.export_library(lib_path)
        lib_mod = tvm.runtime.load_module(lib_path)

    def _run_unlinked(lib_mod):
        graph_rt = tvm.contrib.graph_executor.GraphModule(lib_mod["default"](tvm.cpu(0)))
        graph_rt.set_input("rand_input", rand_input, **params)
        graph_rt.run()
        return graph_rt.get_output(0)

    unlinked_output = _run_unlinked(lib_mod)

    # Integer dtypes must match exactly; floating-point dtypes are compared with tolerance.
    if "int" in linkable_dtype:
        np.testing.assert_equal(unlinked_output.numpy(), linked_output.numpy())
    else:
        np.testing.assert_allclose(unlinked_output.numpy(), linked_output.numpy())
351 | 351 |
|
352 | 352 |
|
@tvm.testing.requires_micro
def test_crt_link_params(linkable_dtype):
    """Run a link-params build through a host CRT project and compare with LLVM.

    The module is built for the "c" target with the CRT runtime and the
    ``link-params`` executor option, turned into a project from the standalone
    CRT "host" template, built, flashed and run inside a ``tvm.micro.Session``
    without supplying params.  The output is compared with an unlinked "llvm"
    build that receives params explicitly.

    Parameters
    ----------
    linkable_dtype : str
        Dtype string supplied by the ``linkable_dtype`` test parameter.
    """
    # Imported here, under the requires_micro guard; the body reaches it as `tvm.micro`.
    from tvm import micro

    mod, param_init = _make_mod_and_params(linkable_dtype)
    rand_input = _make_random_tensor(linkable_dtype, INPUT_SHAPE)
    target = "c"
    runtime = Runtime("crt", {"system-lib": True})
    executor = Executor("graph", {"link-params": True})
    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
        factory = tvm.relay.build(
            mod, target, runtime=runtime, executor=executor, params=param_init
        )
    assert set(factory.get_params().keys()) == {"p0", "p1"}  # NOTE: op folded

    temp_dir = tvm.contrib.utils.tempdir()
    template_project_dir = os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host")
    project = tvm.micro.generate_project(
        template_project_dir, factory, temp_dir / "project", {"verbose": 1}
    )
    project.build()
    project.flash()
    with tvm.micro.Session(project.transport()) as sess:
        graph_rt = tvm.micro.session.create_local_graph_executor(
            factory.get_graph_json(), sess.get_system_lib(), sess.device
        )

        # NOTE: not setting params here.
        graph_rt.set_input("rand_input", rand_input)
        graph_rt.run()
        # Copy the result out as a numpy array while the session is still open.
        linked_output = graph_rt.get_output(0).numpy()

    # Reference build: same module without link-params, params passed at run time.
    runtime = Runtime("cpp", {"system-lib": True})
    with tvm.transform.PassContext(opt_level=3):
        lib = tvm.relay.build(mod, "llvm", runtime=runtime, params=param_init)

    def _run_unlinked(lib):
        graph_json, mod, lowered_params = lib
        graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0))
        graph_rt.set_input("rand_input", rand_input, **lowered_params)
        graph_rt.run()
        return graph_rt.get_output(0).numpy()

    unlinked_output = _run_unlinked(lib)

    # Integer dtypes must match exactly; floating-point dtypes are compared with tolerance.
    if "int" in linkable_dtype:
        np.testing.assert_equal(unlinked_output, linked_output)
    else:
        np.testing.assert_allclose(unlinked_output, linked_output)
405 | 401 |
|
406 | 402 |
|
407 | 403 | if __name__ == "__main__": |
|
0 commit comments