Skip to content

Commit e515b70

Browse files
committed
support ffi.Module (apache/tvm#18213)
1 parent 17ebac1 commit e515b70

1 file changed

Lines changed: 11 additions & 17 deletions

File tree

main.cc

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,49 @@
11
#include <iostream>
2+
#include <optional>
23
#include <tvm/ffi/function.h>
34
#include <tvm/runtime/vm/vm.h>
45

5-
using tvm::ffi::Any;
6-
using tvm::ffi::Function;
7-
using tvm::runtime::Module;
8-
using tvm::runtime::NDArray;
9-
using tvm::runtime::memory::AllocatorType;
10-
116
int main()
127
{
138
std::string path = "./compiled_artifact.so";
149

1510
// Load the shared object
16-
Module m = Module::LoadFromFile(path);
17-
std::cout << m << std::endl;
11+
tvm::ffi::Module m = tvm::ffi::Module::LoadFromFile(path);
1812

19-
Function vm_load_executable = m.GetFunction("vm_load_executable");
13+
tvm::ffi::Optional<tvm::ffi::Function> vm_load_executable = m->GetFunction("vm_load_executable");
2014
CHECK(vm_load_executable != nullptr)
2115
<< "Error: `vm_load_executable` does not exist in file `" << path << "`";
2216
std::cout << "Found vm_load_executable()" << std::endl;
2317

2418
// Create a VM from the Executable
25-
Module mod = vm_load_executable().cast<Module>();
26-
Function vm_initialization = mod.GetFunction("vm_initialization");
19+
tvm::ffi::Module mod = (*vm_load_executable)().cast<tvm::ffi::Module>();
20+
tvm::ffi::Optional<tvm::ffi::Function> vm_initialization = mod->GetFunction("vm_initialization");
2721
CHECK(vm_initialization != nullptr)
2822
<< "Error: `vm_initialization` does not exist in file `" << path << "`";
2923
std::cout << "Found vm_initialization()" << std::endl;
3024

3125
// Initialize the VM
3226
tvm::Device device{kDLCPU, 0};
33-
vm_initialization(static_cast<int>(device.device_type), static_cast<int>(device.device_id),
34-
static_cast<int>(AllocatorType::kPooled), static_cast<int>(kDLCPU), 0,
35-
static_cast<int>(AllocatorType::kPooled));
27+
(*vm_initialization)(static_cast<int>(device.device_type), static_cast<int>(device.device_id),
28+
static_cast<int>(tvm::runtime::memory::AllocatorType::kPooled), static_cast<int>(kDLCPU), 0,
29+
static_cast<int>(tvm::runtime::memory::AllocatorType::kPooled));
3630
std::cout << "vm initialized" << std::endl;
3731

38-
Function main = mod.GetFunction("main");
32+
tvm::ffi::Optional<tvm::ffi::Function> main = mod->GetFunction("main");
3933
CHECK(main != nullptr)
4034
<< "Error: Entry function does not exist in file `" << path << "`";
4135
std::cout << "Found main()" << std::endl;
4236

4337
// Create and initialize the input array
4438
auto i32 = tvm::runtime::DataType::Int(32);
45-
NDArray input = NDArray::Empty({3, 3}, i32, device);
39+
tvm::runtime::NDArray input = tvm::runtime::NDArray::Empty({3, 3}, i32, device);
4640
int numel = input.Shape()->Product();
4741
for (int i = 0; i < numel; ++i)
4842
static_cast<int *>(input->data)[i] = i;
4943
std::cout << "Input array initialized" << std::endl;
5044

5145
// Run the main function
52-
NDArray output = main(input).cast<NDArray>();
46+
tvm::runtime::NDArray output = (*main)(input).cast<tvm::runtime::NDArray>();
5347
std::cout << "output: " << std::endl;
5448
for (int i = 0; i < numel; ++i)
5549
std::cout << " " << static_cast<int *>(output->data)[i] << std::endl;

0 commit comments

Comments
 (0)