Skip to content

Commit abe8567

Browse files
committed
[TIR.Constant] U1 usecase
Constants are now aggregated into one struct and initialized in default_lib0.c file Change-Id: I34d61f8139c8a92c06944fe990ba892a660476fd
1 parent 45e39be commit abe8567

File tree

22 files changed

+651
-106
lines changed

22 files changed

+651
-106
lines changed

include/tvm/tir/stmt.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,9 @@ class AllocateConst : public Stmt {
673673
* create AllocateConstNode with irmod_storage_idx or data
674674
*/
675675
TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
676-
ObjectRef data_or_idx, Stmt body, Span span = Span());
676+
ObjectRef data_or_idx, Stmt body,
677+
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
678+
Span span = Span());
677679
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
678680
};
679681

include/tvm/tir/usmp/algorithms.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,17 @@ Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_inf
5353
*/
5454
Map<BufferInfo, PoolAllocation> GreedyByConflicts(const Array<BufferInfo>& buffer_info_arr,
5555
const Integer& memory_pressure);
56+
/*!
57+
*\brief The Hill-Climb algoritm to plan memory
58+
*
59+
* This will perform an attempt to utilize probabalistic approach to memory
60+
* allocation. Typically better than greedy family, but quite slow due to large
61+
* number of iterations.
62+
*
63+
* \return A Map of BufferInfo objects and their associated PoolAllocation
64+
*/
65+
Map<BufferInfo, PoolAllocation> HillClimb(const Array<BufferInfo>& buffer_info_arr,
66+
const Integer& memory_pressure);
5667

5768
/*!
5869
* \brief The Hill-Climb algorithm to plan memory

include/tvm/tir/usmp/utils.h

Lines changed: 141 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,126 @@ constexpr const char* kUSMPAlgorithmOption = "tir.usmp.algorithm";
4444

4545
namespace tir {
4646
namespace usmp {
47+
/*
48+
* \brief The ConstantInfoNode contains numeric literal in RO pool
49+
*/
50+
struct ConstantInfoNode : public Object {
51+
String name_hint;
52+
Integer byte_alignment;
53+
Integer byte_offset;
54+
runtime::NDArray data;
55+
56+
void VisitAttrs(tvm::AttrVisitor* v) {
57+
v->Visit("constant_names", &name_hint);
58+
v->Visit("constant_alignment", &byte_alignment);
59+
v->Visit("constant_offsets", &byte_offset);
60+
v->Visit("constant_data", &data);
61+
}
62+
63+
bool SEqualReduce(const ConstantInfoNode* other, SEqualReducer equal) const {
64+
return equal(name_hint, other->name_hint) && equal(byte_alignment, other->byte_alignment) &&
65+
equal(byte_offset, other->byte_offset) && equal(data, other->data);
66+
}
67+
68+
void SHashReduce(SHashReducer hash_reduce) const {
69+
hash_reduce(name_hint);
70+
hash_reduce(byte_alignment);
71+
hash_reduce(byte_offset);
72+
hash_reduce(data);
73+
}
74+
75+
static constexpr const char* _type_key = "tir.usmp.ConstantInfo";
76+
static constexpr bool _type_has_method_sequal_reduce = true;
77+
static constexpr bool _type_has_method_shash_reduce = true;
78+
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantInfoNode, Object);
79+
};
80+
81+
class ConstantInfo : public ObjectRef {
82+
public:
83+
TVM_DLL ConstantInfo(String name, Integer byte_alignment, Integer byte_offset,
84+
runtime::NDArray data);
85+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ConstantInfo, ObjectRef, ConstantInfoNode);
86+
};
87+
88+
#if 0
89+
struct PoolInfoNode : public Object {
90+
/*! \brief The name of the memory pool */
91+
String pool_name;
92+
/*! \brief The expected size hint to be used by the allocator.
93+
* The size_hint_bytes is set to kUnrestrictedPoolSizeHint
94+
* to indicate the pool is not size restricted.
95+
*/
96+
Integer size_hint_bytes;
97+
/*! \brief The accessibility from each Target */
98+
Map<Target, String> target_access; // 'rw' or 'ro'
99+
/*! \brief The clock frequency of the memory in Hz */
100+
Integer clock_frequency_hz;
101+
/*! \brief The read bandwidth in bytes/cycle */
102+
Integer read_bandwidth_bytes_per_cycle;
103+
/*! \brief The write bandwidth in bytes/cycle */
104+
Integer write_bandwidth_bytes_per_cycle;
105+
/*! \brief The read latency in cycles */
106+
Integer read_latency_cycles;
107+
/*! \brief The write latency in cycles */
108+
Integer write_latency_cycles;
109+
/*! \brief The burst length in bytes for each Target */
110+
Map<Target, Integer> target_burst_bytes;
111+
/*! \brief Whether pool is internally generated.
112+
* The internal pools will be generated as part of
113+
* the entry point code generation of the executor
114+
*/
115+
bool is_internal = false;
116+
117+
Array<ConstantInfo> constant_info_arr;
118+
119+
void VisitAttrs(tvm::AttrVisitor* v) {
120+
v->Visit("pool_name", &pool_name);
121+
v->Visit("size_hint_bytes", &size_hint_bytes);
122+
v->Visit("target_access", &target_access);
123+
v->Visit("clock_frequency_hz", &clock_frequency_hz);
124+
v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle);
125+
v->Visit("write_bandwidth_bytes_per_cycle", &write_bandwidth_bytes_per_cycle);
126+
v->Visit("read_latency_cycles", &read_latency_cycles);
127+
v->Visit("write_latency_cycles", &write_latency_cycles);
128+
v->Visit("target_burst_bytes", &target_burst_bytes);
129+
v->Visit("is_internal", &is_internal);
130+
v->Visit("constant_info_arr", &constant_info_arr);
131+
}
132+
133+
bool SEqualReduce(const PoolInfoNode* other, SEqualReducer equal) const {
134+
return equal(pool_name, other->pool_name) && equal(size_hint_bytes, other->size_hint_bytes) &&
135+
equal(target_access, other->target_access) &&
136+
equal(target_access, other->target_access) &&
137+
equal(clock_frequency_hz, other->clock_frequency_hz) &&
138+
equal(read_bandwidth_bytes_per_cycle, other->read_bandwidth_bytes_per_cycle) &&
139+
equal(write_bandwidth_bytes_per_cycle, other->write_bandwidth_bytes_per_cycle) &&
140+
equal(read_latency_cycles, other->read_latency_cycles) &&
141+
equal(write_latency_cycles, other->write_latency_cycles) &&
142+
equal(target_burst_bytes, other->target_burst_bytes) &&
143+
equal(is_internal, other->is_internal) &&
144+
equal(constant_info_arr, other->constant_info_arr);
145+
}
146+
147+
void SHashReduce(SHashReducer hash_reduce) const {
148+
hash_reduce(pool_name);
149+
hash_reduce(size_hint_bytes);
150+
hash_reduce(target_access);
151+
hash_reduce(clock_frequency_hz);
152+
hash_reduce(read_bandwidth_bytes_per_cycle);
153+
hash_reduce(write_bandwidth_bytes_per_cycle);
154+
hash_reduce(read_latency_cycles);
155+
hash_reduce(write_latency_cycles);
156+
hash_reduce(target_burst_bytes);
157+
hash_reduce(is_internal);
158+
hash_reduce(constant_info_arr);
159+
}
160+
161+
static constexpr const char* _type_key = "tir.usmp.PoolInfo";
162+
static constexpr bool _type_has_method_sequal_reduce = true;
163+
static constexpr bool _type_has_method_shash_reduce = true;
164+
TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object);
165+
};
166+
#endif
47167

48168
/*!
49169
* \brief Describes an abstract memory buffer that will get allocated inside a pool.
@@ -150,20 +270,25 @@ class BufferInfoAnalysis : public ObjectRef {
150270
struct PoolAllocationNode : public Object {
151271
/*! \brief The assigned PoolInfo object */
152272
PoolInfo pool_info;
273+
/*! \brief The byte alignment where the tensor is supposed to be placed within the pool*/
274+
Integer byte_alignment;
153275
/*! \brief The byte offset where the tensor is supposed to be placed within the pool*/
154276
Integer byte_offset;
155277

156278
void VisitAttrs(tvm::AttrVisitor* v) {
157279
v->Visit("pool_info", &pool_info);
280+
v->Visit("byte_alignment", &byte_alignment);
158281
v->Visit("byte_offset", &byte_offset);
159282
}
160283

161284
bool SEqualReduce(const PoolAllocationNode* other, SEqualReducer equal) const {
162-
return equal(pool_info, other->pool_info) && equal(byte_offset, other->byte_offset);
285+
return equal(pool_info, other->pool_info) && equal(byte_alignment, other->byte_alignment) &&
286+
equal(byte_offset, other->byte_offset);
163287
}
164288

165289
void SHashReduce(SHashReducer hash_reduce) const {
166290
hash_reduce(pool_info);
291+
hash_reduce(byte_alignment);
167292
hash_reduce(byte_offset);
168293
}
169294

@@ -173,7 +298,7 @@ struct PoolAllocationNode : public Object {
173298

174299
class PoolAllocation : public ObjectRef {
175300
public:
176-
TVM_DLL PoolAllocation(PoolInfo pool_info, Integer byte_offset);
301+
TVM_DLL PoolAllocation(PoolInfo pool_info, Integer byte_alignment, Integer byte_offset);
177302
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolAllocation, ObjectRef, PoolAllocationNode);
178303
};
179304

@@ -187,22 +312,26 @@ struct AllocatedPoolInfoNode : public Object {
187312
Integer allocated_size;
188313
/*! \brief An optional associated pool Var*/
189314
Optional<Var> pool_var;
315+
/*! \brief pool initialization data */
316+
Array<ConstantInfo> constant_info_arr;
190317

191318
void VisitAttrs(tvm::AttrVisitor* v) {
192319
v->Visit("pool_info", &pool_info);
193320
v->Visit("allocated_size", &allocated_size);
194321
v->Visit("pool_var", &pool_var);
322+
v->Visit("constant_info_arr", &constant_info_arr);
195323
}
196324

197325
bool SEqualReduce(const AllocatedPoolInfoNode* other, SEqualReducer equal) const {
198326
return equal(pool_info, other->pool_info) && equal(allocated_size, other->allocated_size) &&
199-
equal(pool_var, other->pool_var);
327+
equal(pool_var, other->pool_var) && equal(constant_info_arr, other->constant_info_arr);
200328
}
201329

202330
void SHashReduce(SHashReducer hash_reduce) const {
203331
hash_reduce(pool_info);
204332
hash_reduce(allocated_size);
205333
hash_reduce(pool_var);
334+
hash_reduce(constant_info_arr);
206335
}
207336

208337
static constexpr const char* _type_key = "tir.usmp.AllocatedPoolInfo";
@@ -211,7 +340,8 @@ struct AllocatedPoolInfoNode : public Object {
211340

212341
class AllocatedPoolInfo : public ObjectRef {
213342
public:
214-
TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var = Var());
343+
TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var = Var(),
344+
Array<ConstantInfo> = {});
215345
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AllocatedPoolInfo, ObjectRef, AllocatedPoolInfoNode);
216346
};
217347

@@ -243,6 +373,13 @@ static constexpr const char* kPoolCandidatesAllocateAttr = "candidate_memory_poo
243373
*/
244374
Integer CalculateExtentsSize(const AllocateNode* op);
245375

376+
/*!
377+
* \brief Calculate the size of the extents in bytes
378+
*
379+
* \param op the allocate const node
380+
*/
381+
Integer CalculateExtentsSize(const AllocateConstNode* op);
382+
246383
/*!
247384
* \brief Joins the Stmt nodes with PoolAllocation objects
248385
*
@@ -268,7 +405,6 @@ static constexpr const char* kPoolArgs = "pool_args";
268405
* as an Array.
269406
*/
270407
static constexpr const char* kPoolInfoIRModuleAttr = "pool_infos";
271-
272408
} // namespace attr
273409

274410
} // namespace tvm

python/tvm/script/tir/scope_handler.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,20 @@ class AllocateConst(WithScopeHandler):
166166
"""
167167

168168
def __init__(self):
169-
def allocate_const(raw_data, dtype, shape, span=None):
169+
def allocate_const(raw_data, dtype, shape, annotations=None, span=None):
170170
list_data = []
171171
for i in raw_data:
172172
list_data.append(i.value)
173173
nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype))
174-
n = tvm.tir.AllocateConst(self.buffer_var, dtype, shape, nd_data, self.body, span=span)
174+
n = tvm.tir.AllocateConst(
175+
self.buffer_var,
176+
dtype,
177+
shape,
178+
nd_data,
179+
self.body,
180+
annotations=annotations,
181+
span=span,
182+
)
175183
return n
176184

177185
super().__init__(allocate_const, concise_scope=True, def_symbol=True)
@@ -199,7 +207,7 @@ def enter_scope(
199207
else:
200208
raise Exception("Internal Bug")
201209

202-
def setup_buffer_var(data, dtype, shape, span: Span = None):
210+
def setup_buffer_var(data, dtype, shape, annotations: dict = None, span: Span = None):
203211
"""Setup buffer var for a given type."""
204212
buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
205213
self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)

python/tvm/tir/stmt.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,13 +364,16 @@ class AllocateConst(Stmt):
364364
body : Stmt
365365
The body statement.
366366
367+
annotations : Optional[Map]
368+
Additional annotations about the allocation.
369+
367370
span : Optional[Span]
368371
The location of this itervar in the source code.
369372
"""
370373

371-
def __init__(self, buffer_var, dtype, extents, condition, body, span=None):
374+
def __init__(self, buffer_var, dtype, extents, condition, body, annotations=None, span=None):
372375
self.__init_handle_by_constructor__(
373-
_ffi_api.AllocateConst, buffer_var, dtype, extents, condition, body, span
376+
_ffi_api.AllocateConst, buffer_var, dtype, extents, condition, body, annotations, span
374377
)
375378

376379

python/tvm/tir/usmp/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,18 @@ class PoolAllocation(Object):
8383
pool_info : PoolInfo
8484
The PoolInfo to which this allocation corresponds to
8585
86+
byte_alignment : int
87+
The alignment in the pool where the allocate node should be placed
88+
8689
byte_offset : int
8790
The offset in the pool where the allocate node should be placed
8891
8992
"""
9093

91-
def __init__(self, pool_info: PoolInfo, byte_offset: int):
94+
def __init__(self, pool_info: PoolInfo, byte_alignment: int, byte_offset: int):
9295
self.__init_handle_by_constructor__(
9396
_ffi_api.PoolAllocation, # type: ignore # pylint: disable=no-member
9497
pool_info,
98+
byte_alignment,
9599
byte_offset,
96100
)

src/relay/backend/aot_executor_codegen.cc

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -706,9 +706,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
706706
* brief Run USMP to plan memory for lowered IRModule
707707
*/
708708
IRModule PlanMemoryWithUSMP(const IRModule& mod) {
709-
Executor executor_config = mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
710-
Integer workspace_byte_alignment =
711-
executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
709+
Integer workspace_byte_alignment = getModuleAlignment(mod);
712710
IRModule lowered_mod = mod->ShallowCopy();
713711
lowered_mod = tir::transform::UnifiedStaticMemoryPlanner()(lowered_mod);
714712
// Update workspace size based on the pool allocations.
@@ -748,9 +746,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
748746
* brief Run StorageRewrite to plan memory for lowered IRModule
749747
*/
750748
IRModule PlanMemoryWithStorageRewrite(const IRModule& mod) {
751-
Executor executor_config = mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
752-
Integer workspace_byte_alignment =
753-
executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
749+
Integer workspace_byte_alignment = getModuleAlignment(mod);
754750
IRModule lowered_mod = mod->ShallowCopy();
755751
// Running StorageRewrite just on the main function
756752
tir::PrimFunc tir_main_func =
@@ -773,6 +769,11 @@ class AOTExecutorCodegen : public MixedModeVisitor {
773769
return lowered_mod;
774770
}
775771

772+
Integer getModuleAlignment(const IRModule& mod) {
773+
Executor executor_config = mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
774+
return executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
775+
}
776+
776777
protected:
777778
/*! \brief mod */
778779
runtime::Module* mod_;
@@ -837,10 +838,10 @@ class AOTExecutorCodegen : public MixedModeVisitor {
837838
ICHECK(target_host_.defined()) << "require a target_host to be given for AOT codegen";
838839
VLOG(1) << "target host: " << target_host_->ToDebugString();
839840

841+
Integer workspace_byte_alignment = getModuleAlignment(mod);
842+
840843
Executor executor_config = mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
841844
String interface_api = executor_config->GetAttr<String>("interface-api").value_or("packed");
842-
Integer workspace_byte_alignment =
843-
executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
844845
use_unpacked_api_ = executor_config->GetAttr<Bool>("unpacked-api").value_or(Bool(false));
845846

846847
// TODO(mbs): Plumb from compiler config

src/target/source/codegen_params.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream&
238238
}
239239

240240
default:
241-
CHECK(false) << "Data type not supported";
241+
CHECK(false) << "Data type '" << arr_type << "' not supported";
242242
}
243243

244244
os.flags(old_fmtflags);

0 commit comments

Comments
 (0)