Skip to content

Commit d5fb6cf

Browse files
PhilippvKIvy
authored andcommitted
Add runtime.ModuleGetFormat method enabling export of BYOC generated sources which require a .cpp/.cc file extension (apache#9243)
* Allow export of C++ kernels using correct file extension * [WIP] Set module_key=c for CSourceCrtMetadataModuleNode to temporarily fix failing tests I realized that the module format `cc` is currently already used by the `CSourceCrtMetadataModuleNode` declared in `src/target/source/source_module.cc`. This needs to be discussed first to decide if either the module_key should be changed or the test cases expecting the systemlib kernel (e.g. `default_lib0.c`) to have a `.c` extension. * Update Makefiles used by tests/python/relay/aot/ to support C++ file extensions AOT: Add c++ support to aot_test.mk AOT: Add c++ support to corstone300.mk * Add missing definition of GetFormat to cmsisnn and ethosn codegens (WIP) * Resolve PR comments * lint python/tvm/runtime/module.py * fix EthosUModuleNode for CI * Fix: detect empty module.format * Add error message to assertion * Lint python/tvm/runtime/module.py
1 parent 9c51ed8 commit d5fb6cf

8 files changed

Lines changed: 64 additions & 14 deletions

File tree

include/tvm/runtime/module.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ class TVM_DLL ModuleNode : public Object {
156156
* \return Possible source code when available.
157157
*/
158158
virtual std::string GetSource(const std::string& format = "");
159+
/*!
160+
* \brief Get the format of the module, when available.
161+
* \return Possible format when available.
162+
*/
163+
virtual std::string GetFormat();
159164
/*!
160165
* \brief Get packed function from current module by name.
161166
*

python/tvm/micro/model_library_format.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,12 @@ def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None):
8686

8787
for dso_mod in dso_modules:
8888
if dso_mod.type_key == "c":
89+
assert dso_mod.format in ["c", "cc", "cpp"]
90+
ext = dso_mod.format
8991
index = mod_indices["src"]
9092
mod_indices["src"] += 1
9193
parent_dir = os.path.join(host_codegen_dir, "src")
92-
file_name = os.path.join(parent_dir, f"{lib_name}{index}.c")
94+
file_name = os.path.join(parent_dir, f"{lib_name}{index}.{ext}")
9395
elif dso_mod.type_key == "llvm":
9496
index = mod_indices["lib"]
9597
mod_indices["lib"] += 1

python/tvm/runtime/module.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ def type_key(self):
185185
"""Get type key of the module."""
186186
return _ffi_api.ModuleGetTypeKey(self)
187187

188+
@property
189+
def format(self):
190+
"""Get the format of the module."""
191+
return _ffi_api.ModuleGetFormat(self)
192+
188193
def get_source(self, fmt=""):
189194
"""Get source code from module, if available.
190195
@@ -402,7 +407,12 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
402407
for index, module in enumerate(modules):
403408
if fcompile is not None and hasattr(fcompile, "object_format"):
404409
if module.type_key == "c":
405-
object_format = "c"
410+
assert module.format in [
411+
"c",
412+
"cc",
413+
"cpp",
414+
], "The module.format needs to be either c, cc or cpp"
415+
object_format = module.format
406416
has_c_module = True
407417
else:
408418
object_format = fcompile.object_format
@@ -411,7 +421,15 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
411421
object_format = "o"
412422
else:
413423
assert module.type_key == "c"
414-
object_format = "c"
424+
if len(module.format) > 0:
425+
assert module.format in [
426+
"c",
427+
"cc",
428+
"cpp",
429+
], "The module.format needs to be either c, cc or cpp"
430+
object_format = module.format
431+
else:
432+
object_format = "c"
415433
if "cc" in kwargs:
416434
if kwargs["cc"] == "nvcc":
417435
object_format = "cu"

src/relay/backend/contrib/ethosu/source_module.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ class EthosUModuleNode : public ModuleNode {
8989

9090
std::string GetSource(const std::string& format) final { return c_source; }
9191

92+
std::string GetFormat() { return "c"; }
93+
9294
Array<CompilationArtifact> GetArtifacts() { return compilation_artifacts_; }
9395

9496
/*!

src/runtime/module.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,11 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
127127
}
128128
}
129129

130+
std::string ModuleNode::GetFormat() {
131+
LOG(FATAL) << "Module[" << type_key() << "] does not support GetFormat";
132+
return "";
133+
}
134+
130135
bool RuntimeEnabled(const std::string& target) {
131136
std::string f_name;
132137
if (target == "cpu") {
@@ -179,6 +184,10 @@ TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) {
179184
return std::string(mod->type_key());
180185
});
181186

187+
TVM_REGISTER_GLOBAL("runtime.ModuleGetFormat").set_body_typed([](Module mod) {
188+
return mod->GetFormat();
189+
});
190+
182191
TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile);
183192

184193
TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")

src/target/source/source_module.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ class SourceModuleNode : public runtime::ModuleNode {
6363

6464
std::string GetSource(const std::string& format) final { return code_; }
6565

66+
std::string GetFormat() { return fmt_; }
67+
6668
protected:
6769
std::string code_;
6870
std::string fmt_;
@@ -102,10 +104,12 @@ class CSourceModuleNode : public runtime::ModuleNode {
102104

103105
std::string GetSource(const std::string& format) final { return code_; }
104106

107+
std::string GetFormat() { return fmt_; }
108+
105109
void SaveToFile(const std::string& file_name, const std::string& format) final {
106110
std::string fmt = GetFileFormat(file_name, format);
107111
std::string meta_file = GetMetaFilePath(file_name);
108-
if (fmt == "c" || fmt == "cu") {
112+
if (fmt == "c" || fmt == "cc" || fmt == "cpp" || fmt == "cu") {
109113
ICHECK_NE(code_.length(), 0);
110114
SaveBinaryToFile(file_name, code_);
111115
} else {
@@ -160,14 +164,15 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
160164

161165
std::string GetSource(const std::string& format) final { return code_.str(); }
162166

167+
std::string GetFormat() { return fmt_; }
163168
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
164169
return PackedFunc(nullptr);
165170
}
166171

167172
void SaveToFile(const std::string& file_name, const std::string& format) final {
168173
std::string fmt = GetFileFormat(file_name, format);
169174
std::string meta_file = GetMetaFilePath(file_name);
170-
if (fmt == "c") {
175+
if (fmt == "c" || fmt == "cc" || fmt == "cpp") {
171176
auto code_str = code_.str();
172177
ICHECK_NE(code_str.length(), 0);
173178
SaveBinaryToFile(file_name, code_str);
@@ -509,7 +514,7 @@ runtime::Module CreateCSourceCrtMetadataModule(const Array<runtime::Module>& mod
509514
}
510515
}
511516
}
512-
auto n = make_object<CSourceCrtMetadataModuleNode>(func_names, "cc", target, runtime, metadata);
517+
auto n = make_object<CSourceCrtMetadataModuleNode>(func_names, "c", target, runtime, metadata);
513518
auto csrc_metadata_module = runtime::Module(n);
514519
for (const auto& mod : modules) {
515520
csrc_metadata_module.Import(mod);

tests/python/relay/aot/corstone300.mk

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,10 @@ QUIET ?= @
6969
$(endif)
7070

7171
CRT_SRCS = $(shell find $(CRT_ROOT))
72-
CODEGEN_SRCS = $(shell find $(abspath $(CODEGEN_ROOT)/host/src/*.c))
73-
CODEGEN_OBJS = $(subst .c,.o,$(CODEGEN_SRCS))
72+
C_CODEGEN_SRCS = $(shell find $(abspath $(CODEGEN_ROOT)/host/src/*.c))
73+
CC_CODEGEN_SRCS = $(shell find $(abspath $(CODEGEN_ROOT)/host/src/*.cc))
74+
C_CODEGEN_OBJS = $(subst .c,.o,$(C_CODEGEN_SRCS))
75+
CC_CODEGEN_OBJS = $(subst .cc,.o,$(CC_CODEGEN_SRCS))
7476
CMSIS_STARTUP_SRCS = $(shell find ${CMSIS_PATH}/Device/ARM/${ARM_CPU}/Source/*.c)
7577
UART_SRCS = $(shell find ${PLATFORM_PATH}/*.c)
7678

@@ -96,9 +98,9 @@ $(build_dir)/tvm_ethosu_runtime.o: $(TVM_ROOT)/src/runtime/contrib/ethosu/bare_m
9698
$(QUIET)mkdir -p $(@D)
9799
$(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^
98100

99-
$(build_dir)/libcodegen.a: $(CODEGEN_SRCS)
100-
$(QUIET)cd $(abspath $(CODEGEN_ROOT)/host/src) && $(CC) -c $(PKG_CFLAGS) $(CODEGEN_SRCS)
101-
$(QUIET)$(AR) -cr $(abspath $(build_dir)/libcodegen.a) $(CODEGEN_OBJS)
101+
$(build_dir)/libcodegen.a: $(C_CODEGEN_SRCS) $(CC_CODEGEN_SRCS)
102+
$(QUIET)cd $(abspath $(CODEGEN_ROOT)/host/src) && $(CC) -c $(PKG_CFLAGS) $(C_CODEGEN_SRCS) $(CC_CODEGEN_SRCS)
103+
$(QUIET)$(AR) -cr $(abspath $(build_dir)/libcodegen.a) $(C_CODEGEN_OBJS) $(CC_CODEGEN_OBJS)
102104
$(QUIET)$(RANLIB) $(abspath $(build_dir)/libcodegen.a)
103105

104106
${build_dir}/libcmsis_startup.a: $(CMSIS_STARTUP_SRCS)

tests/python/relay/aot/default.mk

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ ENABLE_TVM_PLATFORM_ABORT_BACKTRACE = 0
2222
DMLC_CORE=$(TVM_ROOT)/3rdparty/dmlc-core
2323
PKG_COMPILE_OPTS = -g
2424
CC = gcc
25+
#CC = g++
2526
AR = ar
2627
RANLIB = ranlib
2728
CC_OPTS = CC=$(CC) AR=$(AR) RANLIB=$(RANLIB)
@@ -39,17 +40,23 @@ $(endif)
3940

4041
aot_test_runner: $(build_dir)/aot_test_runner
4142

42-
source_libs= $(wildcard $(build_dir)/../codegen/host/src/*.c)
43-
lib_objs =$(source_libs:.c=.o)
43+
c_source_libs= $(wildcard $(build_dir)/../codegen/host/src/*.c)
44+
cc_source_libs= $(wildcard $(build_dir)/../codegen/host/src/*.cc)
45+
c_lib_objs =$(c_source_libs:.c=.o)
46+
cc_lib_objs =$(cc_source_libs:.cc=.o)
4447

45-
$(build_dir)/aot_test_runner: $(build_dir)/test.c $(source_libs) $(build_dir)/stack_allocator.o $(build_dir)/crt_backend_api.o
48+
$(build_dir)/aot_test_runner: $(build_dir)/test.c $(c_source_libs) $(cc_source_libs) $(build_dir)/stack_allocator.o $(build_dir)/crt_backend_api.o
4649
$(QUIET)mkdir -p $(@D)
4750
$(QUIET)$(CC) $(CFLAGS) $(PKG_CFLAGS) -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS) -lm
4851

4952
$(build_dir)/%.o: $(build_dir)/../codegen/host/src/%.c
5053
$(QUIET)mkdir -p $(@D)
5154
$(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS)
5255

56+
$(build_dir)/%.o: $(build_dir)/../codegen/host/src/%.cc
57+
$(QUIET)mkdir -p $(@D)
58+
$(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS)
59+
5360
$(build_dir)/stack_allocator.o: $(STANDALONE_CRT_DIR)/src/runtime/crt/memory/stack_allocator.c
5461
$(QUIET)mkdir -p $(@D)
5562
$(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS)

0 commit comments

Comments
 (0)