|
39 | 39 |
|
40 | 40 |
|
41 | 41 | # The data types that are linkable. |
42 | | -LINKABLE_DTYPES = ( |
43 | | - [f"uint{b}" for b in (8, 16, 32, 64)] |
44 | | - + [f"int{b}" for b in (8, 16, 32, 64)] |
45 | | - + ["float32", "float64"] |
| 42 | +linkable_dtype = tvm.testing.parameter( |
| 43 | + *([f"uint{b}" for b in (8, 16, 32, 64)] |
| 44 | + + [f"int{b}" for b in (8, 16, 32, 64)] |
| 45 | + + ["float32", "float64"]) |
46 | 46 | ) |
47 | 47 |
|
48 | 48 |
|
49 | 49 | def dtype_info(dtype): |
50 | | - """Lookup numpy type info for the given string dtype (of LINKABLE_DTYPES above).""" |
| 50 | + """Lookup numpy type info for the given string dtype (of linkable_dtype params above).""" |
51 | 51 | if "int" in dtype: |
52 | 52 | return np.iinfo(getattr(np, dtype)) |
53 | 53 | else: |
@@ -181,54 +181,53 @@ def _add_decl(name, dtype): |
181 | 181 |
|
182 | 182 |
|
183 | 183 | @tvm.testing.requires_llvm |
184 | | -def test_llvm_link_params(): |
185 | | - for dtype in LINKABLE_DTYPES: |
186 | | - ir_mod, param_init = _make_mod_and_params(dtype) |
187 | | - rand_input = _make_random_tensor(dtype, INPUT_SHAPE) |
188 | | - main_func = ir_mod["main"] |
189 | | - target = "llvm --runtime=c --system-lib --link-params" |
190 | | - with tvm.transform.PassContext(opt_level=3): |
191 | | - lib = tvm.relay.build(ir_mod, target, params=param_init) |
192 | | - |
193 | | - # NOTE: Need to export_library() and load_library() to link all the Module(llvm, ...) |
194 | | - # against one another. |
195 | | - temp_dir = tempfile.mkdtemp() |
196 | | - export_file = os.path.join(temp_dir, "lib.so") |
197 | | - lib.lib.export_library(export_file) |
198 | | - mod = tvm.runtime.load_module(export_file) |
199 | | - assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded |
200 | | - assert mod.get_function("TVMSystemLibEntryPoint") != None |
201 | | - |
202 | | - graph = json.loads(lib.graph_json) |
203 | | - for p in lib.params: |
204 | | - _verify_linked_param(dtype, lib, mod, graph, p) or found_one |
205 | | - |
206 | | - # Wrap in function to explicitly deallocate the runtime. |
207 | | - def _run_linked(lib, mod): |
208 | | - graph_json, _, _ = lib |
209 | | - graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0)) |
210 | | - graph_rt.set_input("rand_input", rand_input) # NOTE: params not required. |
211 | | - graph_rt.run() |
212 | | - return graph_rt.get_output(0) |
213 | | - |
214 | | - linked_output = _run_linked(lib, mod) |
215 | | - |
216 | | - with tvm.transform.PassContext(opt_level=3): |
217 | | - lib = tvm.relay.build(ir_mod, "llvm --system-lib", params=param_init) |
218 | | - |
219 | | - def _run_unlinked(lib): |
220 | | - graph_json, mod, lowered_params = lib |
221 | | - graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0)) |
222 | | - graph_rt.set_input("rand_input", rand_input, **lowered_params) |
223 | | - graph_rt.run() |
224 | | - return graph_rt.get_output(0) |
225 | | - |
226 | | - unlinked_output = _run_unlinked(lib) |
227 | | - |
228 | | - if "int" in dtype: |
229 | | - np.testing.assert_equal(unlinked_output.numpy(), linked_output.numpy()) |
230 | | - else: |
231 | | - np.testing.assert_allclose(unlinked_output.numpy(), linked_output.numpy()) |
| 184 | +def test_llvm_link_params(linkable_dtype): |
| 185 | + ir_mod, param_init = _make_mod_and_params(linkable_dtype) |
| 186 | + rand_input = _make_random_tensor(linkable_dtype, INPUT_SHAPE) |
| 187 | + main_func = ir_mod["main"] |
| 188 | + target = "llvm --runtime=c --system-lib --link-params" |
| 189 | + with tvm.transform.PassContext(opt_level=3): |
| 190 | + lib = tvm.relay.build(ir_mod, target, params=param_init) |
| 191 | + |
| 192 | + # NOTE: Need to export_library() and load_library() to link all the Module(llvm, ...) |
| 193 | + # against one another. |
| 194 | + temp_dir = tempfile.mkdtemp() |
| 195 | + export_file = os.path.join(temp_dir, "lib.so") |
| 196 | + lib.lib.export_library(export_file) |
| 197 | + mod = tvm.runtime.load_module(export_file) |
| 198 | + assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded |
| 199 | + assert mod.get_function("TVMSystemLibEntryPoint") != None |
| 200 | + |
| 201 | + graph = json.loads(lib.graph_json) |
| 202 | + for p in lib.params: |
| 203 | + _verify_linked_param(linkable_dtype, lib, mod, graph, p) or found_one |
| 204 | + |
| 205 | + # Wrap in function to explicitly deallocate the runtime. |
| 206 | + def _run_linked(lib, mod): |
| 207 | + graph_json, _, _ = lib |
| 208 | + graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0)) |
| 209 | + graph_rt.set_input("rand_input", rand_input) # NOTE: params not required. |
| 210 | + graph_rt.run() |
| 211 | + return graph_rt.get_output(0) |
| 212 | + |
| 213 | + linked_output = _run_linked(lib, mod) |
| 214 | + |
| 215 | + with tvm.transform.PassContext(opt_level=3): |
| 216 | + lib = tvm.relay.build(ir_mod, "llvm --system-lib", params=param_init) |
| 217 | + |
| 218 | + def _run_unlinked(lib): |
| 219 | + graph_json, mod, lowered_params = lib |
| 220 | + graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0)) |
| 221 | + graph_rt.set_input("rand_input", rand_input, **lowered_params) |
| 222 | + graph_rt.run() |
| 223 | + return graph_rt.get_output(0) |
| 224 | + |
| 225 | + unlinked_output = _run_unlinked(lib) |
| 226 | + |
| 227 | + if "int" in linkable_dtype: |
| 228 | + np.testing.assert_equal(unlinked_output.numpy(), linked_output.numpy()) |
| 229 | + else: |
| 230 | + np.testing.assert_allclose(unlinked_output.numpy(), linked_output.numpy()) |
232 | 231 |
|
233 | 232 |
|
234 | 233 | def _get_c_datatype(dtype): |
@@ -263,137 +262,134 @@ def _format_c_value(dtype, width, x): |
263 | 262 | HEX_NUM_RE = re.compile(r"[+\-]?(?:(?:0x[0-9A-Fa-f.p+-]+)|(?:INFINITY)|(?:NAN))") |
264 | 263 |
|
265 | 264 |
|
266 | | -def test_c_link_params(): |
| 265 | +def test_c_link_params(linkable_dtype): |
267 | 266 | temp_dir = utils.tempdir() |
268 | | - for dtype in LINKABLE_DTYPES: |
269 | | - mod, param_init = _make_mod_and_params(dtype) |
270 | | - rand_input = _make_random_tensor(dtype, INPUT_SHAPE) |
271 | | - main_func = mod["main"] |
272 | | - target = "c --link-params" |
273 | | - with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): |
274 | | - lib = tvm.relay.build(mod, target, params=param_init) |
275 | | - assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded |
276 | | - |
277 | | - src = lib.lib.get_source() |
278 | | - lib.lib.save(temp_dir.relpath("test.c"), "c") |
279 | | - c_dtype = _get_c_datatype(dtype) |
280 | | - src_lines = src.split("\n") |
281 | | - param = lib.params["p0"].numpy().reshape(np.prod(KERNEL_SHAPE)) |
282 | | - param_def = f"static const {c_dtype} __tvm_param__p0[{np.prod(param.shape)}] = {{" |
283 | | - for i, line in enumerate(src_lines): |
284 | | - if line == param_def: |
285 | | - i += 1 |
286 | | - break |
287 | | - else: |
288 | | - assert False, f'did not find parameter definition "{param_def}":\n{src}' |
289 | | - |
290 | | - cursor = 0 |
291 | | - width = dtype_info(dtype).bits // 4 + 2 |
292 | | - if dtype.startswith("int"): |
293 | | - width += 1 # Account for sign |
294 | | - |
295 | | - while "};" not in src_lines[i]: |
296 | | - for match in HEX_NUM_RE.finditer(src_lines[i]): |
297 | | - assert match.group() == _format_c_value(dtype, width, param[cursor]), ( |
298 | | - f'p0 byte {cursor}: want "{_format_c_value(dtype, width, param[cursor])}" got ' |
299 | | - f'"{match.group(0)}"; full p0 follows:\n{src}' |
300 | | - ) |
301 | | - cursor += 1 |
| 267 | + mod, param_init = _make_mod_and_params(linkable_dtype) |
| 268 | + rand_input = _make_random_tensor(linkable_dtype, INPUT_SHAPE) |
| 269 | + main_func = mod["main"] |
| 270 | + target = "c --link-params" |
| 271 | + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): |
| 272 | + lib = tvm.relay.build(mod, target, params=param_init) |
| 273 | + assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded |
| 274 | + |
| 275 | + src = lib.lib.get_source() |
| 276 | + lib.lib.save(temp_dir.relpath("test.c"), "c") |
| 277 | + c_dtype = _get_c_datatype(linkable_dtype) |
| 278 | + src_lines = src.split("\n") |
| 279 | + param = lib.params["p0"].numpy().reshape(np.prod(KERNEL_SHAPE)) |
| 280 | + param_def = f"static const {c_dtype} __tvm_param__p0[{np.prod(param.shape)}] = {{" |
| 281 | + for i, line in enumerate(src_lines): |
| 282 | + if line == param_def: |
302 | 283 | i += 1 |
303 | | - |
304 | | - assert cursor == np.prod(param.shape) |
305 | | - |
306 | | - # Need a unique name per library to avoid dlopen caching the lib load. |
307 | | - lib_path = temp_dir.relpath(f"test-{dtype}-linked.so") |
308 | | - lib["remove_params"]().export_library(lib_path) |
309 | | - lib_mod = tvm.runtime.load_module(lib_path) |
310 | | - |
311 | | - # lib_mod = lib_factory['default']() |
312 | | - graph = json.loads(lib.graph_json) |
313 | | - for p in lib.params: |
314 | | - _verify_linked_param(dtype, lib, lib_mod, graph, p) |
315 | | - |
316 | | - # Wrap in function to explicitly deallocate the runtime. |
317 | | - def _run_linked(lib_mod): |
318 | | - graph_rt = tvm.contrib.graph_executor.GraphModule(lib_mod["default"](tvm.cpu(0))) |
319 | | - graph_rt.set_input("rand_input", rand_input) # NOTE: params not required. |
320 | | - graph_rt.run() |
321 | | - |
322 | | - return graph_rt.get_output(0) |
323 | | - |
324 | | - linked_output = _run_linked(lib_mod) |
325 | | - |
326 | | - linked_params = lib.params |
327 | | - with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): |
328 | | - lib = tvm.relay.build(mod, "c", params=param_init) |
329 | | - _, _, params = lib |
330 | | - # Need a unique name per library to avoid dlopen caching the lib load. |
331 | | - lib_path = temp_dir.relpath(f"test-{dtype}-unlinked.so") |
332 | | - lib.export_library(lib_path) |
333 | | - lib_mod = tvm.runtime.load_module(lib_path) |
334 | | - |
335 | | - def _run_unlinked(lib_mod): |
336 | | - graph_rt = tvm.contrib.graph_executor.GraphModule(lib_mod["default"](tvm.cpu(0))) |
337 | | - graph_rt.set_input("rand_input", rand_input, **params) |
338 | | - graph_rt.run() |
339 | | - return graph_rt.get_output(0) |
340 | | - |
341 | | - unlinked_output = _run_unlinked(lib_mod) |
342 | | - |
343 | | - if "int" in dtype: |
344 | | - np.testing.assert_equal(unlinked_output.numpy(), linked_output.numpy()) |
| 284 | + break |
345 | 285 | else: |
346 | | - np.testing.assert_allclose(unlinked_output.numpy(), linked_output.numpy()) |
| 286 | + assert False, f'did not find parameter definition "{param_def}":\n{src}' |
| 287 | + |
| 288 | + cursor = 0 |
| 289 | + width = dtype_info(linkable_dtype).bits // 4 + 2 |
| 290 | + if linkable_dtype.startswith("int"): |
| 291 | + width += 1 # Account for sign |
| 292 | + |
| 293 | + while "};" not in src_lines[i]: |
| 294 | + for match in HEX_NUM_RE.finditer(src_lines[i]): |
| 295 | + assert match.group() == _format_c_value(linkable_dtype, width, param[cursor]), ( |
| 296 | + f'p0 byte {cursor}: want "{_format_c_value(linkable_dtype, width, param[cursor])}" got ' |
| 297 | + f'"{match.group(0)}"; full p0 follows:\n{src}' |
| 298 | + ) |
| 299 | + cursor += 1 |
| 300 | + i += 1 |
| 301 | + |
| 302 | + assert cursor == np.prod(param.shape) |
| 303 | + |
| 304 | + # Need a unique name per library to avoid dlopen caching the lib load. |
| 305 | + lib_path = temp_dir.relpath(f"test-{linkable_dtype}-linked.so") |
| 306 | + lib["remove_params"]().export_library(lib_path) |
| 307 | + lib_mod = tvm.runtime.load_module(lib_path) |
| 308 | + |
| 309 | + # lib_mod = lib_factory['default']() |
| 310 | + graph = json.loads(lib.graph_json) |
| 311 | + for p in lib.params: |
| 312 | + _verify_linked_param(linkable_dtype, lib, lib_mod, graph, p) |
| 313 | + |
| 314 | + # Wrap in function to explicitly deallocate the runtime. |
| 315 | + def _run_linked(lib_mod): |
| 316 | + graph_rt = tvm.contrib.graph_executor.GraphModule(lib_mod["default"](tvm.cpu(0))) |
| 317 | + graph_rt.set_input("rand_input", rand_input) # NOTE: params not required. |
| 318 | + graph_rt.run() |
| 319 | + |
| 320 | + return graph_rt.get_output(0) |
| 321 | + |
| 322 | + linked_output = _run_linked(lib_mod) |
| 323 | + |
| 324 | + linked_params = lib.params |
| 325 | + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): |
| 326 | + lib = tvm.relay.build(mod, "c", params=param_init) |
| 327 | + _, _, params = lib |
| 328 | + # Need a unique name per library to avoid dlopen caching the lib load. |
| 329 | + lib_path = temp_dir.relpath(f"test-{linkable_dtype}-unlinked.so") |
| 330 | + lib.export_library(lib_path) |
| 331 | + lib_mod = tvm.runtime.load_module(lib_path) |
| 332 | + |
| 333 | + def _run_unlinked(lib_mod): |
| 334 | + graph_rt = tvm.contrib.graph_executor.GraphModule(lib_mod["default"](tvm.cpu(0))) |
| 335 | + graph_rt.set_input("rand_input", rand_input, **params) |
| 336 | + graph_rt.run() |
| 337 | + return graph_rt.get_output(0) |
| 338 | + |
| 339 | + unlinked_output = _run_unlinked(lib_mod) |
| 340 | + |
| 341 | + if "int" in linkable_dtype: |
| 342 | + np.testing.assert_equal(unlinked_output.numpy(), linked_output.numpy()) |
| 343 | + else: |
| 344 | + np.testing.assert_allclose(unlinked_output.numpy(), linked_output.numpy()) |
347 | 345 |
|
348 | 346 |
|
349 | 347 | @tvm.testing.requires_micro |
350 | | -def test_crt_link_params(): |
| 348 | +def test_crt_link_params(linkable_dtype): |
351 | 349 | from tvm import micro |
352 | | - |
353 | | - for dtype in LINKABLE_DTYPES: |
354 | | - mod, param_init = _make_mod_and_params(dtype) |
355 | | - rand_input = _make_random_tensor(dtype, INPUT_SHAPE) |
356 | | - main_func = mod["main"] |
357 | | - target = "c --system-lib --runtime=c --link-params" |
358 | | - with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): |
359 | | - factory = tvm.relay.build(mod, target, params=param_init) |
360 | | - assert set(factory.get_params().keys()) == {"p0", "p1"} # NOTE: op folded |
361 | | - |
362 | | - temp_dir = tvm.contrib.utils.tempdir() |
363 | | - template_project_dir = os.path.join( |
364 | | - tvm.micro.get_standalone_crt_dir(), "template", "host" |
365 | | - ) |
366 | | - project = tvm.micro.generate_project( |
367 | | - template_project_dir, factory, temp_dir / "project", {"verbose": 1} |
| 350 | + mod, param_init = _make_mod_and_params(linkable_dtype) |
| 351 | + rand_input = _make_random_tensor(linkable_dtype, INPUT_SHAPE) |
| 352 | + main_func = mod["main"] |
| 353 | + target = "c --system-lib --runtime=c --link-params" |
| 354 | + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): |
| 355 | + factory = tvm.relay.build(mod, target, params=param_init) |
| 356 | + assert set(factory.get_params().keys()) == {"p0", "p1"} # NOTE: op folded |
| 357 | + |
| 358 | + temp_dir = tvm.contrib.utils.tempdir() |
| 359 | + template_project_dir = os.path.join( |
| 360 | + tvm.micro.get_standalone_crt_dir(), "template", "host" |
| 361 | + ) |
| 362 | + project = tvm.micro.generate_project( |
| 363 | + template_project_dir, factory, temp_dir / "project", {"verbose": 1} |
| 364 | + ) |
| 365 | + project.build() |
| 366 | + project.flash() |
| 367 | + with tvm.micro.Session(project.transport()) as sess: |
| 368 | + graph_rt = tvm.micro.session.create_local_graph_executor( |
| 369 | + factory.get_graph_json(), sess.get_system_lib(), sess.device |
368 | 370 | ) |
369 | | - project.build() |
370 | | - project.flash() |
371 | | - with tvm.micro.Session(project.transport()) as sess: |
372 | | - graph_rt = tvm.micro.session.create_local_graph_executor( |
373 | | - factory.get_graph_json(), sess.get_system_lib(), sess.device |
374 | | - ) |
375 | 371 |
|
376 | | - # NOTE: not setting params here. |
377 | | - graph_rt.set_input("rand_input", rand_input) |
378 | | - graph_rt.run() |
379 | | - linked_output = graph_rt.get_output(0).numpy() |
| 372 | + # NOTE: not setting params here. |
| 373 | + graph_rt.set_input("rand_input", rand_input) |
| 374 | + graph_rt.run() |
| 375 | + linked_output = graph_rt.get_output(0).numpy() |
380 | 376 |
|
381 | | - with tvm.transform.PassContext(opt_level=3): |
382 | | - lib = tvm.relay.build(mod, "llvm --system-lib", params=param_init) |
| 377 | + with tvm.transform.PassContext(opt_level=3): |
| 378 | + lib = tvm.relay.build(mod, "llvm --system-lib", params=param_init) |
383 | 379 |
|
384 | | - def _run_unlinked(lib): |
385 | | - graph_json, mod, lowered_params = lib |
386 | | - graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0)) |
387 | | - graph_rt.set_input("rand_input", rand_input, **lowered_params) |
388 | | - graph_rt.run() |
389 | | - return graph_rt.get_output(0).numpy() |
| 380 | + def _run_unlinked(lib): |
| 381 | + graph_json, mod, lowered_params = lib |
| 382 | + graph_rt = tvm.contrib.graph_executor.create(graph_json, mod, tvm.cpu(0)) |
| 383 | + graph_rt.set_input("rand_input", rand_input, **lowered_params) |
| 384 | + graph_rt.run() |
| 385 | + return graph_rt.get_output(0).numpy() |
390 | 386 |
|
391 | | - unlinked_output = _run_unlinked(lib) |
| 387 | + unlinked_output = _run_unlinked(lib) |
392 | 388 |
|
393 | | - if "int" in dtype: |
394 | | - np.testing.assert_equal(unlinked_output, linked_output) |
395 | | - else: |
396 | | - np.testing.assert_allclose(unlinked_output, linked_output) |
| 389 | + if "int" in linkable_dtype: |
| 390 | + np.testing.assert_equal(unlinked_output, linked_output) |
| 391 | + else: |
| 392 | + np.testing.assert_allclose(unlinked_output, linked_output) |
397 | 393 |
|
398 | 394 |
|
399 | 395 | if __name__ == "__main__": |
|
0 commit comments