Skip to content

Commit 9192543

Browse files
committed
WIP
1 parent 1ea9768 commit 9192543

4 files changed

Lines changed: 21 additions & 13 deletions

File tree

src/relay/backend/aot_executor_codegen.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ class AOTExecutorCodegen : public ExprVisitor {
330330
fi_node->tir_primfuncs.Set(primfunc_target, primfunc);
331331
fi_node->relay_primfuncs.Set(primfunc_target, relay_func);
332332
}
333-
function_metadata_.Set(cfunc->func_name, FunctionInfo(fi_node));
333+
function_metadata_.Set(cfunc->prim_fn_var->name_hint, FunctionInfo(fi_node));
334334
}
335335

336336
void VisitExpr_(const CallNode* op) override {
@@ -368,7 +368,7 @@ class AOTExecutorCodegen : public ExprVisitor {
368368
UpdateConstants(func, &params_);
369369

370370
// Generate the TIR function call
371-
CreateFuncCall(GetRef<Call>(op), ext_func->func_name);
371+
CreateFuncCall(GetRef<Call>(op), ext_func->prim_fn_var->name_hint);
372372
return;
373373
}
374374

@@ -403,7 +403,7 @@ class AOTExecutorCodegen : public ExprVisitor {
403403
UpdateFunctionMetadata(lowered_func, func, target);
404404

405405
// Generate the TIR function call
406-
CreateFuncCall(GetRef<Call>(op), lowered_func->func_name);
406+
CreateFuncCall(GetRef<Call>(op), lowered_func->prim_fn_var->name_hint);
407407
}
408408

409409
void VisitExpr_(const VarNode* op) override {

src/relay/backend/graph_executor_codegen.cc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
189189
targets_ = targets;
190190
}
191191

192-
/*!
192+
/*!
193193
* \brief Update the "main" control function's metadata
194194
*
195195
* \param func The main function that contains calls to relay primitive functions
@@ -273,6 +273,9 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
273273
}
274274

275275
function_metadata_.Set(String(runtime::symbol::tvm_module_main), FunctionInfo(fi_node));
276+
}
277+
278+
276279
StorageInfo GetStorageInfo(const Expr& e) {
277280
size_t count = memory_plan_->expr_to_storage_info.count(e);
278281
ICHECK_GT(count, 0) << "Expr is not existing in storage plan";
@@ -466,16 +469,18 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
466469
return fields;
467470
}
468471

469-
std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op, const std::string& op_name,
470-
const std::string& func_name, GraphAttrs attrs) {
472+
std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op, const std::string& func_name, GraphAttrs op_attrs) {
471473
std::vector<GraphNodeRef> inputs;
472474
for (auto arg : op->args) {
473475
auto res = VisitExpr(arg);
474476
for (auto nr : res) {
475477
inputs.push_back(nr);
476478
}
477479
}
478-
auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, attrs);
480+
481+
// Compute the operator name, because we used the get unique name when generating the kernel.
482+
auto op_name =_GetUniqueName(func_name);
483+
auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, op_attrs);
479484
return AddNode(node, GetRef<Expr>(op));
480485
}
481486

@@ -484,8 +489,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
484489
if (auto global_node = call->op.as<GlobalVarNode>()) {
485490
auto prim_fn_name = global_node->name_hint;
486491

487-
488-
return GraphAddCallNode(call_node, _GetUniqueName(prim_fn_name), prim_fn_name);
492+
// TODO(@jroesch): attach attributes somehow
493+
return GraphAddCallNode(call_node, prim_fn_name, GraphAttrs());
489494
} else {
490495
ICHECK(false) << "Non-primitive-call nodes should have been transformed away.\n"
491496
<< "The graph executor code generator expects all calls to have their callee "

src/relay/backend/graph_plan_memory.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,11 @@ namespace tvm {
3434
namespace relay {
3535

3636
namespace backend {
37-
StorageInfo::StorageInfo(std::vector<int64_t> storage_ids, std::vector<DLDeviceType> device_types) {
37+
StorageInfo::StorageInfo(std::vector<int64_t> storage_ids, std::vector<DLDeviceType> device_types, std::vector<int64_t> storage_sizes_in_bytes) {
3838
auto n = make_object<StorageInfoNode>();
3939
n->storage_ids = std::move(storage_ids);
4040
n->device_types = std::move(device_types);
41+
n->storage_sizes_in_bytes = std::move(storage_sizes_in_bytes);
4142
data_ = std::move(n);
4243
}
4344

@@ -233,7 +234,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
233234
for (const auto& kv : token_map_) {
234235
std::vector<int64_t> storage_ids;
235236
std::vector<DLDeviceType> device_types;
236-
std::vector<Integer> sid_sizes_byte
237+
std::vector<int64_t> sid_sizes_byte;
237238

238239
for (StorageToken* tok : kv.second) {
239240
if (tok->device_type) {
@@ -243,7 +244,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
243244
storage_ids.push_back(tok->storage_id);
244245
device_types.push_back(static_cast<DLDeviceType>(tok->device_type));
245246
}
246-
auto storage_info = backend::StorageInfo(storage_ids, device_types);
247+
auto storage_info = backend::StorageInfo(storage_ids, device_types, sid_sizes_byte);
247248
smap.Set(GetRef<Expr>(kv.first), storage_info);
248249
}
249250
// Either all or none of the nodes should be annotated.

src/relay/backend/utils.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ class StorageInfoNode : public Object {
9999
std::vector<int64_t> storage_ids;
100100
/* \brief The type of "virtual devices" these expressions are stored on. */
101101
std::vector<DLDeviceType> device_types;
102+
/* \brief The sizes of each storage element. */
103+
std::vector<int64_t> storage_sizes_in_bytes;
102104

103105
// TODO(@jroesch): expose the fields
104106
void VisitAttrs(AttrVisitor* v) {}
@@ -110,7 +112,7 @@ class StorageInfoNode : public Object {
110112
/*! \brief The storage information for a single expression. */
111113
class StorageInfo : public ObjectRef {
112114
public:
113-
StorageInfo(std::vector<int64_t> storage_ids, std::vector<DLDeviceType> device_types);
115+
StorageInfo(std::vector<int64_t> storage_ids, std::vector<DLDeviceType> device_types, std::vector<int64_t> storage_sizes_in_bytes);
114116
TVM_DEFINE_OBJECT_REF_METHODS(StorageInfo, ObjectRef, StorageInfoNode);
115117
};
116118

0 commit comments

Comments
 (0)