@@ -44,6 +44,126 @@ constexpr const char* kUSMPAlgorithmOption = "tir.usmp.algorithm";
4444
4545namespace tir {
4646namespace usmp {
47+ /*
48+ * \brief The ConstantInfoNode contains numeric literal in RO pool
49+ */
50+ struct ConstantInfoNode : public Object {
51+ String name_hint;
52+ Integer byte_alignment;
53+ Integer byte_offset;
54+ runtime::NDArray data;
55+
56+ void VisitAttrs (tvm::AttrVisitor* v) {
57+ v->Visit (" constant_names" , &name_hint);
58+ v->Visit (" constant_alignment" , &byte_alignment);
59+ v->Visit (" constant_offsets" , &byte_offset);
60+ v->Visit (" constant_data" , &data);
61+ }
62+
63+ bool SEqualReduce (const ConstantInfoNode* other, SEqualReducer equal) const {
64+ return equal (name_hint, other->name_hint ) && equal (byte_alignment, other->byte_alignment ) &&
65+ equal (byte_offset, other->byte_offset ) && equal (data, other->data );
66+ }
67+
68+ void SHashReduce (SHashReducer hash_reduce) const {
69+ hash_reduce (name_hint);
70+ hash_reduce (byte_alignment);
71+ hash_reduce (byte_offset);
72+ hash_reduce (data);
73+ }
74+
75+ static constexpr const char * _type_key = " tir.usmp.ConstantInfo" ;
76+ static constexpr bool _type_has_method_sequal_reduce = true ;
77+ static constexpr bool _type_has_method_shash_reduce = true ;
78+ TVM_DECLARE_FINAL_OBJECT_INFO (ConstantInfoNode, Object);
79+ };
80+
81+ class ConstantInfo : public ObjectRef {
82+ public:
83+ TVM_DLL ConstantInfo (String name, Integer byte_alignment, Integer byte_offset,
84+ runtime::NDArray data);
85+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS (ConstantInfo, ObjectRef, ConstantInfoNode);
86+ };
87+
88+ #if 0
89+ struct PoolInfoNode : public Object {
90+ /*! \brief The name of the memory pool */
91+ String pool_name;
92+ /*! \brief The expected size hint to be used by the allocator.
93+ * The size_hint_bytes is set to kUnrestrictedPoolSizeHint
94+ * to indicate the pool is not size restricted.
95+ */
96+ Integer size_hint_bytes;
97+ /*! \brief The accessibility from each Target */
98+ Map<Target, String> target_access; // 'rw' or 'ro'
99+ /*! \brief The clock frequency of the memory in Hz */
100+ Integer clock_frequency_hz;
101+ /*! \brief The read bandwidth in bytes/cycle */
102+ Integer read_bandwidth_bytes_per_cycle;
103+ /*! \brief The write bandwidth in bytes/cycle */
104+ Integer write_bandwidth_bytes_per_cycle;
105+ /*! \brief The read latency in cycles */
106+ Integer read_latency_cycles;
107+ /*! \brief The write latency in cycles */
108+ Integer write_latency_cycles;
109+ /*! \brief The burst length in bytes for each Target */
110+ Map<Target, Integer> target_burst_bytes;
111+ /*! \brief Whether pool is internally generated.
112+ * The internal pools will be generated as part of
113+ * the entry point code generation of the executor
114+ */
115+ bool is_internal = false;
116+
117+ Array<ConstantInfo> constant_info_arr;
118+
119+ void VisitAttrs(tvm::AttrVisitor* v) {
120+ v->Visit("pool_name", &pool_name);
121+ v->Visit("size_hint_bytes", &size_hint_bytes);
122+ v->Visit("target_access", &target_access);
123+ v->Visit("clock_frequency_hz", &clock_frequency_hz);
124+ v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle);
125+ v->Visit("write_bandwidth_bytes_per_cycle", &write_bandwidth_bytes_per_cycle);
126+ v->Visit("read_latency_cycles", &read_latency_cycles);
127+ v->Visit("write_latency_cycles", &write_latency_cycles);
128+ v->Visit("target_burst_bytes", &target_burst_bytes);
129+ v->Visit("is_internal", &is_internal);
130+ v->Visit("constant_info_arr", &constant_info_arr);
131+ }
132+
133+ bool SEqualReduce(const PoolInfoNode* other, SEqualReducer equal) const {
134+ return equal(pool_name, other->pool_name) && equal(size_hint_bytes, other->size_hint_bytes) &&
135+ equal(target_access, other->target_access) &&
136+ equal(target_access, other->target_access) &&
137+ equal(clock_frequency_hz, other->clock_frequency_hz) &&
138+ equal(read_bandwidth_bytes_per_cycle, other->read_bandwidth_bytes_per_cycle) &&
139+ equal(write_bandwidth_bytes_per_cycle, other->write_bandwidth_bytes_per_cycle) &&
140+ equal(read_latency_cycles, other->read_latency_cycles) &&
141+ equal(write_latency_cycles, other->write_latency_cycles) &&
142+ equal(target_burst_bytes, other->target_burst_bytes) &&
143+ equal(is_internal, other->is_internal) &&
144+ equal(constant_info_arr, other->constant_info_arr);
145+ }
146+
147+ void SHashReduce(SHashReducer hash_reduce) const {
148+ hash_reduce(pool_name);
149+ hash_reduce(size_hint_bytes);
150+ hash_reduce(target_access);
151+ hash_reduce(clock_frequency_hz);
152+ hash_reduce(read_bandwidth_bytes_per_cycle);
153+ hash_reduce(write_bandwidth_bytes_per_cycle);
154+ hash_reduce(read_latency_cycles);
155+ hash_reduce(write_latency_cycles);
156+ hash_reduce(target_burst_bytes);
157+ hash_reduce(is_internal);
158+ hash_reduce(constant_info_arr);
159+ }
160+
161+ static constexpr const char* _type_key = "tir.usmp.PoolInfo";
162+ static constexpr bool _type_has_method_sequal_reduce = true;
163+ static constexpr bool _type_has_method_shash_reduce = true;
164+ TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object);
165+ };
166+ #endif
47167
48168/* !
49169 * \brief Describes an abstract memory buffer that will get allocated inside a pool.
@@ -150,20 +270,25 @@ class BufferInfoAnalysis : public ObjectRef {
150270struct PoolAllocationNode : public Object {
151271 /* ! \brief The assigned PoolInfo object */
152272 PoolInfo pool_info;
273+ /* ! \brief The byte alignment where the tensor is supposed to be placed within the pool*/
274+ Integer byte_alignment;
153275 /* ! \brief The byte offset where the tensor is supposed to be placed within the pool*/
154276 Integer byte_offset;
155277
156278 void VisitAttrs (tvm::AttrVisitor* v) {
157279 v->Visit (" pool_info" , &pool_info);
280+ v->Visit (" byte_alignment" , &byte_alignment);
158281 v->Visit (" byte_offset" , &byte_offset);
159282 }
160283
161284 bool SEqualReduce (const PoolAllocationNode* other, SEqualReducer equal) const {
162- return equal (pool_info, other->pool_info ) && equal (byte_offset, other->byte_offset );
285+ return equal (pool_info, other->pool_info ) && equal (byte_alignment, other->byte_alignment ) &&
286+ equal (byte_offset, other->byte_offset );
163287 }
164288
165289 void SHashReduce (SHashReducer hash_reduce) const {
166290 hash_reduce (pool_info);
291+ hash_reduce (byte_alignment);
167292 hash_reduce (byte_offset);
168293 }
169294
@@ -173,7 +298,7 @@ struct PoolAllocationNode : public Object {
173298
174299class PoolAllocation : public ObjectRef {
175300 public:
176- TVM_DLL PoolAllocation (PoolInfo pool_info, Integer byte_offset);
301+ TVM_DLL PoolAllocation (PoolInfo pool_info, Integer byte_alignment, Integer byte_offset);
177302 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS (PoolAllocation, ObjectRef, PoolAllocationNode);
178303};
179304
@@ -187,22 +312,26 @@ struct AllocatedPoolInfoNode : public Object {
187312 Integer allocated_size;
188313 /* ! \brief An optional associated pool Var*/
189314 Optional<Var> pool_var;
315+ /* ! \brief pool initialization data */
316+ Array<ConstantInfo> constant_info_arr;
190317
191318 void VisitAttrs (tvm::AttrVisitor* v) {
192319 v->Visit (" pool_info" , &pool_info);
193320 v->Visit (" allocated_size" , &allocated_size);
194321 v->Visit (" pool_var" , &pool_var);
322+ v->Visit (" constant_info_arr" , &constant_info_arr);
195323 }
196324
197325 bool SEqualReduce (const AllocatedPoolInfoNode* other, SEqualReducer equal) const {
198326 return equal (pool_info, other->pool_info ) && equal (allocated_size, other->allocated_size ) &&
199- equal (pool_var, other->pool_var );
327+ equal (pool_var, other->pool_var ) && equal (constant_info_arr, other-> constant_info_arr ) ;
200328 }
201329
202330 void SHashReduce (SHashReducer hash_reduce) const {
203331 hash_reduce (pool_info);
204332 hash_reduce (allocated_size);
205333 hash_reduce (pool_var);
334+ hash_reduce (constant_info_arr);
206335 }
207336
208337 static constexpr const char * _type_key = " tir.usmp.AllocatedPoolInfo" ;
@@ -211,7 +340,8 @@ struct AllocatedPoolInfoNode : public Object {
211340
212341class AllocatedPoolInfo : public ObjectRef {
213342 public:
214- TVM_DLL AllocatedPoolInfo (PoolInfo pool_info, Integer allocated_size, Var pool_var = Var());
343+ TVM_DLL AllocatedPoolInfo (PoolInfo pool_info, Integer allocated_size, Var pool_var = Var(),
344+ Array<ConstantInfo> = {});
215345 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS (AllocatedPoolInfo, ObjectRef, AllocatedPoolInfoNode);
216346};
217347
@@ -243,6 +373,13 @@ static constexpr const char* kPoolCandidatesAllocateAttr = "candidate_memory_poo
243373 */
244374Integer CalculateExtentsSize (const AllocateNode* op);
245375
376+ /* !
377+ * \brief Calculate the size of the extents in bytes
378+ *
379+ * \param op the allocate const node
380+ */
381+ Integer CalculateExtentsSize (const AllocateConstNode* op);
382+
246383/* !
247384 * \brief Joins the Stmt nodes with PoolAllocation objects
248385 *
@@ -268,7 +405,6 @@ static constexpr const char* kPoolArgs = "pool_args";
268405 * as an Array.
269406 */
270407static constexpr const char * kPoolInfoIRModuleAttr = " pool_infos" ;
271-
272408} // namespace attr
273409
274410} // namespace tvm
0 commit comments