Skip to content

Commit 0d270f5

Browse files
committed
Fix up the macro-related registration
1 parent a16643b commit 0d270f5

4 files changed

Lines changed: 23 additions & 11 deletions

File tree

src/runtime/contrib/sort/sort.cc

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,7 @@ struct float16 {
8080
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
8181
// and sort axis is dk. sort_num should have dimension of
8282
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
83-
TVM_FFI_STATIC_INIT_BLOCK({
83+
void RegisterArgsortNMS() {
8484
namespace refl = tvm::ffi::reflection;
8585
refl::GlobalDef().def_packed(
8686
"tvm.contrib.sort.argsort_nms", [](ffi::PackedArgs args, ffi::Any* ret) {
@@ -157,7 +157,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
157157
}
158158
}
159159
});
160-
});
160+
}
161161

162162
template <typename DataType, typename OutType>
163163
void sort_impl(
@@ -222,7 +222,7 @@ void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) {
222222
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
223223
// and sort axis is dk. sort_num should have dimension of
224224
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
225-
TVM_FFI_STATIC_INIT_BLOCK({
225+
void RegisterArgsort() {
226226
namespace refl = tvm::ffi::reflection;
227227
refl::GlobalDef().def_packed("tvm.contrib.sort.argsort", [](ffi::PackedArgs args, ffi::Any* ret) {
228228
auto input = args[0].cast<DLTensor*>();
@@ -311,7 +311,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
311311
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
312312
}
313313
});
314-
});
314+
}
315315

316316
// Sort implemented C library sort.
317317
// Return sorted tensor.
@@ -320,7 +320,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
320320
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
321321
// and sort axis is dk. sort_num should have dimension of
322322
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
323-
TVM_FFI_STATIC_INIT_BLOCK({
323+
void RegisterSort() {
324324
namespace refl = tvm::ffi::reflection;
325325
refl::GlobalDef().def_packed("tvm.contrib.sort.sort", [](ffi::PackedArgs args, ffi::Any* ret) {
326326
auto input = args[0].cast<DLTensor*>();
@@ -357,7 +357,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
357357
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
358358
}
359359
});
360-
});
360+
}
361361

362362
template <typename DataType, typename IndicesType>
363363
void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, int axis,
@@ -452,7 +452,7 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i
452452
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
453453
// and sort axis is dk. sort_num should have dimension of
454454
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
455-
TVM_FFI_STATIC_INIT_BLOCK({
455+
void RegisterTopk() {
456456
namespace refl = tvm::ffi::reflection;
457457
refl::GlobalDef().def_packed("tvm.contrib.sort.topk", [](ffi::PackedArgs args, ffi::Any* ret) {
458458
auto input = args[0].cast<DLTensor*>();
@@ -574,6 +574,13 @@ TVM_FFI_STATIC_INIT_BLOCK({
574574
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
575575
}
576576
});
577+
}
578+
579+
TVM_FFI_STATIC_INIT_BLOCK({
580+
RegisterArgsortNMS();
581+
RegisterArgsort();
582+
RegisterSort();
583+
RegisterTopk();
577584
});
578585

579586
} // namespace contrib

src/target/llvm/codegen_llvm.cc

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2351,7 +2351,7 @@ llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm)
23512351
return nullptr;
23522352
}
23532353

2354-
TVM_FFI_STATIC_INIT_BLOCK({
2354+
static void CodegenLLVMRegisterReflection() {
23552355
namespace refl = tvm::ffi::reflection;
23562356
refl::GlobalDef()
23572357
.def("tvm.codegen.llvm.GetDefaultTargetTriple",
@@ -2385,7 +2385,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
23852385
LOG(WARNING) << "Current version of LLVM does not support feature detection on your CPU";
23862386
return {};
23872387
});
2388-
});
2388+
}
2389+
2390+
TVM_FFI_STATIC_INIT_BLOCK({ CodegenLLVMRegisterReflection(); });
23892391

23902392
} // namespace codegen
23912393
} // namespace tvm

src/target/llvm/llvm_module.cc

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -621,7 +621,7 @@ void* LLVMModuleNode::GetFunctionAddr(const std::string& name,
621621
return nullptr;
622622
}
623623

624-
TVM_FFI_STATIC_INIT_BLOCK({
624+
static void LLVMReflectionRegister() {
625625
namespace refl = tvm::ffi::reflection;
626626
refl::GlobalDef()
627627
.def("target.build.llvm",
@@ -787,7 +787,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
787787
n->SetJITEngine(llvm_target->GetJITEngine());
788788
return runtime::Module(n);
789789
});
790-
});
790+
}
791+
792+
TVM_FFI_STATIC_INIT_BLOCK({ LLVMReflectionRegister(); });
791793

792794
} // namespace codegen
793795
} // namespace tvm

tests/python/relax/test_vm_callback_function.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,7 @@ def relax_func(
100100
)
101101
vm = tvm.relax.VirtualMachine(ex, dev)
102102

103+
# custom callback that raises an error in python
103104
def custom_callback():
104105
local_var = 42
105106
raise RuntimeError("Error thrown from callback")

0 commit comments

Comments
 (0)