|
1 | 1 | #include <iostream> |
| 2 | +#include <optional> |
2 | 3 | #include <tvm/ffi/function.h> |
3 | 4 | #include <tvm/runtime/vm/vm.h> |
4 | 5 |
|
5 | | -using tvm::ffi::Any; |
6 | | -using tvm::ffi::Function; |
7 | | -using tvm::runtime::Module; |
8 | | -using tvm::runtime::NDArray; |
9 | | -using tvm::runtime::memory::AllocatorType; |
10 | | - |
11 | 6 | int main() |
12 | 7 | { |
13 | 8 | std::string path = "./compiled_artifact.so"; |
14 | 9 |
|
15 | 10 | // Load the shared object |
16 | | - Module m = Module::LoadFromFile(path); |
17 | | - std::cout << m << std::endl; |
| 11 | + tvm::ffi::Module m = tvm::ffi::Module::LoadFromFile(path); |
18 | 12 |
|
19 | | - Function vm_load_executable = m.GetFunction("vm_load_executable"); |
| 13 | + tvm::ffi::Optional<tvm::ffi::Function> vm_load_executable = m->GetFunction("vm_load_executable"); |
20 | 14 | CHECK(vm_load_executable != nullptr) |
21 | 15 | << "Error: `vm_load_executable` does not exist in file `" << path << "`"; |
22 | 16 | std::cout << "Found vm_load_executable()" << std::endl; |
23 | 17 |
|
24 | 18 | // Create a VM from the Executable |
25 | | - Module mod = vm_load_executable().cast<Module>(); |
26 | | - Function vm_initialization = mod.GetFunction("vm_initialization"); |
| 19 | + tvm::ffi::Module mod = (*vm_load_executable)().cast<tvm::ffi::Module>(); |
| 20 | + tvm::ffi::Optional<tvm::ffi::Function> vm_initialization = mod->GetFunction("vm_initialization"); |
27 | 21 | CHECK(vm_initialization != nullptr) |
28 | 22 | << "Error: `vm_initialization` does not exist in file `" << path << "`"; |
29 | 23 | std::cout << "Found vm_initialization()" << std::endl; |
30 | 24 |
|
31 | 25 | // Initialize the VM |
32 | 26 | tvm::Device device{kDLCPU, 0}; |
33 | | - vm_initialization(static_cast<int>(device.device_type), static_cast<int>(device.device_id), |
34 | | - static_cast<int>(AllocatorType::kPooled), static_cast<int>(kDLCPU), 0, |
35 | | - static_cast<int>(AllocatorType::kPooled)); |
| 27 | + (*vm_initialization)(static_cast<int>(device.device_type), static_cast<int>(device.device_id), |
| 28 | + static_cast<int>(tvm::runtime::memory::AllocatorType::kPooled), static_cast<int>(kDLCPU), 0, |
| 29 | + static_cast<int>(tvm::runtime::memory::AllocatorType::kPooled)); |
36 | 30 | std::cout << "vm initialized" << std::endl; |
37 | 31 |
|
38 | | - Function main = mod.GetFunction("main"); |
| 32 | + tvm::ffi::Optional<tvm::ffi::Function> main = mod->GetFunction("main"); |
39 | 33 | CHECK(main != nullptr) |
40 | 34 | << "Error: Entry function does not exist in file `" << path << "`"; |
41 | 35 | std::cout << "Found main()" << std::endl; |
42 | 36 |
|
43 | 37 | // Create and initialize the input array |
44 | 38 | auto i32 = tvm::runtime::DataType::Int(32); |
45 | | - NDArray input = NDArray::Empty({3, 3}, i32, device); |
| 39 | + tvm::runtime::NDArray input = tvm::runtime::NDArray::Empty({3, 3}, i32, device); |
46 | 40 | int numel = input.Shape()->Product(); |
47 | 41 | for (int i = 0; i < numel; ++i) |
48 | 42 | static_cast<int *>(input->data)[i] = i; |
49 | 43 | std::cout << "Input array initialized" << std::endl; |
50 | 44 |
|
51 | 45 | // Run the main function |
52 | | - NDArray output = main(input).cast<NDArray>(); |
| 46 | + tvm::runtime::NDArray output = (*main)(input).cast<tvm::runtime::NDArray>(); |
53 | 47 | std::cout << "output: " << std::endl; |
54 | 48 | for (int i = 0; i < numel; ++i) |
55 | 49 | std::cout << " " << static_cast<int *>(output->data)[i] << std::endl; |
|
0 commit comments