5 changes: 2 additions & 3 deletions src/meta_schedule/schedule_rule/auto_inline.cc
@@ -121,9 +121,8 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
}
}
// Last cond: Check inline into the consumers or the spatial producer
-  tir::StmtSRef scope_block = tir::GetScopeRoot(sch->state(), block_sref,  //
-                                                /*require_stage_pipeline=*/false,  //
-                                                /*require_subtree_compact_dataflow=*/false);
+  tir::StmtSRef scope_block = tir::GetScopeRoot(sch->state(), block_sref,
+                                                /*require_stage_pipeline=*/false);
if (into_consumer) {
Array<tir::StmtSRef> consumer_srefs = GetConsumers(state, block_sref);
if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) {
5 changes: 2 additions & 3 deletions src/meta_schedule/schedule_rule/random_compute_location.cc
@@ -67,9 +67,8 @@ class RandomComputeLocationNode : public ScheduleRuleNode {
return false;
}
// Cond 2. The block should be the direct child block of the root block.
-  if (GetScopeRoot(sch->state(), block_sref,  //
-                   /*require_stage_pipeline=*/false,  //
-                   /*require_subtree_compact_dataflow=*/false)
+  if (GetScopeRoot(sch->state(), block_sref,
+                   /*require_stage_pipeline=*/false)
->parent != nullptr) {
return false;
}
23 changes: 14 additions & 9 deletions src/tir/schedule/analysis.h
@@ -76,20 +76,12 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref);
* \param self The schedule state
* \param sref The sref whose scope is to be checked
* \param require_stage_pipeline A boolean indicating whether to check stage pipeline
- * \param require_subtree_compact_dataflow A boolean indicating whether to check
- *        subtree compact dataflow property. The scope root may have one or more subtrees rooted at
- *        its direct children, and this property requires all the blocks of the subtree
- *        that the specified sref is in to be complete block or reduction block.
 * \throw ScheduleError if
 * 1) the sref has been the root of the AST (so it has no scope root), or
 * 2) require_stage_pipeline = true, but its scope root is not a stage pipeline
- * 3) require_subtree_compact_dataflow = true, but the subtree that the sref is in doesn't satisfy
- *    the compact dataflow condition, i.e. a block in the subtree is neither complete block nor
- *    reduction block
 * \return The block sref to the scope root
 */
-StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline,
-                      bool require_subtree_compact_dataflow);
+StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline);

/*!
* \brief The information of a block scope, including the leaf blocks,
@@ -173,6 +165,19 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref);

+/*!
+ * \brief Check the subtree compact dataflow property. The scope root may have one or more subtrees
+ *        rooted at its direct children, and this property requires all the blocks of the subtree
+ *        to be complete blocks or reduction blocks.
+ * \param self The schedule state
+ * \param subtree_root The sref of the subtree root to be checked
+ * \param scope_root_sref The sref of the scope root of the subtree
+ * \throw ScheduleError If the subtree rooted at `subtree_root` doesn't satisfy the compact
+ *        dataflow condition, i.e. some block in the subtree is neither a complete block nor a
+ *        reduction block
+ */
+void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root,
+                                 const StmtSRef& scope_root_sref);
/*!
* \brief Check if the block is an output block, i.e. the block writes to at least a buffer that is
* not allocated under the current scope
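As an illustration of the decoupled API, here is a minimal sketch of how a caller composes the two checks after this split. The wrapper function `CheckScopeAndDataflow` is hypothetical and only for illustration; the two calls inside it are the declarations above:

// Hypothetical wrapper (a sketch, not part of this change) showing the new two-step pattern.
// `self` is the schedule state; `sref` points to the loop or block subtree being scheduled.
void CheckScopeAndDataflow(const ScheduleState& self, const StmtSRef& sref) {
  // Throws RootBlockError if `sref` is the AST root, and NotStagePipelineError if the
  // scope root is not a stage pipeline (because require_stage_pipeline is true here).
  StmtSRef scope_root_sref = GetScopeRoot(self, sref, /*require_stage_pipeline=*/true);
  // Explicit opt-in: throws NotCompactDataFlowError if any child block of the subtree
  // is neither a complete block nor a reduction block.
  CheckSubtreeCompactDataflow(self, sref, scope_root_sref);
}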
85 changes: 42 additions & 43 deletions src/tir/schedule/analysis/analysis.cc
@@ -47,9 +47,8 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl

/******** Scope ********/

-StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref,  //
-                      bool require_stage_pipeline,  //
-                      bool require_subtree_compact_dataflow) {
+StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref,
+                      bool require_stage_pipeline) {
class RootBlockError : public ScheduleError {
public:
explicit RootBlockError(IRModule mod) : mod_(mod) {}
@@ -85,31 +84,6 @@ Definition of a scope that is a stage pipeline:
Block block_;
};

-  class NotCompactDataFlowError : public ScheduleError {
-   public:
-    explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block)
-        : mod_(std::move(mod)),
-          subtree_root_(std::move(subtree_root)),
-          violate_block_(std::move(violate_block)) {
-      ICHECK(subtree_root_->IsInstance<BlockNode>() || subtree_root_->IsInstance<ForNode>());
-    }
-    String FastErrorString() const final {
-      return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, "
-             "because some of its child block on SRef tree is neither a complete block nor a "
-             "reduction block";
-    }
-    String DetailRenderTemplate() const final {
-      return "The queried subtree root {0} in SRef tree does not have compact dataflow, because "
-             "its child block {1} on SRef tree is neither a complete block nor a reduction block";
-    }
-    IRModule mod() const final { return mod_; }
-    Array<ObjectRef> LocationsOfInterest() const final { return {subtree_root_, violate_block_}; }
-
-    IRModule mod_;
-    Stmt subtree_root_;
-    Block violate_block_;
-  };
-
StmtSRef scope_root_sref{nullptr};
StmtSRef scope_root_subtree{nullptr};
// Step 1. Find the scope root and the subtree that the given sref is in
@@ -135,18 +109,6 @@ Definition of a scope that is a stage pipeline:
throw NotStagePipelineError(self->mod, GetRef<Block>(block));
}
}
-  // Step 3. Handle `require_subtree_compact_dataflow`
-  if (require_subtree_compact_dataflow) {
-    Array<StmtSRef> child_block_srefs = GetChildBlockSRefOnSRefTree(self, scope_root_subtree);
-    for (const StmtSRef& block_sref : child_block_srefs) {
-      if (!IsCompleteBlock(self, block_sref, scope_root_sref) &&
-          !IsReductionBlock(self, block_sref, scope_root_sref)) {
-        const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
-        throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(scope_root_subtree->stmt),
-                                      GetRef<Block>(block));
-      }
-    }
-  }
return scope_root_sref;
}

@@ -401,6 +363,44 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl
reduction_block_error_code);
}

+void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root,
+                                 const StmtSRef& scope_root_sref) {
+  class NotCompactDataFlowError : public ScheduleError {
+   public:
+    explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block)
+        : mod_(std::move(mod)),
+          subtree_root_(std::move(subtree_root)),
+          violate_block_(std::move(violate_block)) {
+      ICHECK(subtree_root_->IsInstance<BlockNode>() || subtree_root_->IsInstance<ForNode>());
+    }
+    String FastErrorString() const final {
+      return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, "
+             "because some of its child block on SRef tree is neither a complete block nor a "
+             "reduction block";
+    }
+    String DetailRenderTemplate() const final {
+      return "The queried subtree root {0} in SRef tree does not have compact dataflow, because "
+             "its child block {1} on SRef tree is neither a complete block nor a reduction block";
+    }
+    IRModule mod() const final { return mod_; }
+    Array<ObjectRef> LocationsOfInterest() const final { return {subtree_root_, violate_block_}; }
+
+    IRModule mod_;
+    Stmt subtree_root_;
+    Block violate_block_;
+  };
+
+  Array<StmtSRef> child_block_srefs = GetChildBlockSRefOnSRefTree(self, subtree_root);
+  for (const StmtSRef& block_sref : child_block_srefs) {
+    if (!IsCompleteBlock(self, block_sref, scope_root_sref) &&
+        !IsReductionBlock(self, block_sref, scope_root_sref)) {
+      const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+      throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(subtree_root->stmt),
+                                    GetRef<Block>(block));
+    }
+  }
+}
+
bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref) {
const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref);
@@ -1843,9 +1843,8 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
}

// Cond 2. The block is a reduction block and has trivial binding.
-  const StmtSRef& scope_sref = GetScopeRoot(self, block_sref,  //
-                                            /*require_stage_pipeline=*/false,  //
-                                            /*require_subtree_compact_dataflow=*/false);
+  const StmtSRef& scope_sref = GetScopeRoot(self, block_sref,
+                                            /*require_stage_pipeline=*/false);
if (!IsReductionBlock(self, block_sref, scope_sref) //
|| !IsTrivialBinding(self, block_sref) //
|| HasBeenMultiLevelTiled(block_sref)) {
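Note: `NotCompactDataFlowError` and the loop over `GetChildBlockSRefOnSRefTree` are carried over essentially unchanged from `GetScopeRoot`; the only difference is that the subtree to check is now supplied by the caller as `subtree_root` rather than derived internally from the queried sref.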
6 changes: 2 additions & 4 deletions src/tir/schedule/primitive/blockize_tensorize.cc
@@ -515,8 +515,7 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) {

// Step 8: Update the cached flags
StmtSRef outer_block_sref = self->stmt2ref.at(outer_realize->block.get());
-  StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref, /*require_stage_pipeline=*/false,
-                                          /*require_subtree_compact_dataflow=*/false);
+  StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref, /*require_stage_pipeline=*/false);
bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root);
self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root));
self->block_info[scope_root].affine_binding = scope_block_affine_binding;
@@ -629,8 +628,7 @@ void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref,
self->Replace(block_sref, new_block, {{block_realize->block, new_block}});

// Step 6: Update the cached flags.
-  StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false,
-                                          /*require_subtree_compact_dataflow=*/false);
+  StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
self->UpdateScopeBlockInfo(static_cast<const BlockNode*>(scope_root->stmt)->body);
}

6 changes: 2 additions & 4 deletions src/tir/schedule/primitive/cache_read_write.cc
@@ -631,8 +631,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
Buffer read_buffer =
GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, /*is_write=*/false);
-  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true,
-                                     /*require_subtree_compact_dataflow=*/false);
+  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);

// Step 2. Create CacheStageInfo
@@ -703,8 +702,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
Buffer write_buffer =
GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, /*is_write=*/true);
-  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true,
-                                     /*require_subtree_compact_dataflow=*/false);
+  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);

// Step 2. Creating CacheStageInfo
CacheStageInfo info;
7 changes: 4 additions & 3 deletions src/tir/schedule/primitive/compute_at.cc
@@ -456,14 +456,15 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
// Step 1. Bunch of checks
-  // Check condition 1) and 2): stage pipeline and subtree compact dataflow
+  // Check condition 1): scope stage pipeline
   StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
-                                          /*require_stage_pipeline=*/true,
-                                          /*require_subtree_compact_dataflow=*/true);
+                                          /*require_stage_pipeline=*/true);
Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
BlockScope scope = self->GetBlockScope(scope_root_sref);
Array<StmtSRef> producer_srefs = GetProducers(block_sref, scope);
Array<StmtSRef> consumer_srefs = GetConsumers(block_sref, scope);
+  // Check condition 2): `block` is a complete or reduction block
+  CheckCompleteOrReductionBlock(self, block_sref, scope_root_sref);
// Check condition 3): `block` and `loop` are under the same scope,
// and `loop` is not the ancestor of `block`
NotInSameScopeError::CheckAndBindLoopDomain(self, block_sref, loop_sref, scope_root_sref,
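Note the substitution here: rather than requiring compact dataflow for the whole subtree, compute-at/reverse-compute-at now check only that the block being moved is itself a complete or reduction block (condition 2 above).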
8 changes: 3 additions & 5 deletions src/tir/schedule/primitive/compute_inline.cc
@@ -548,9 +548,8 @@ void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
Block producer_block = GetRef<Block>(_producer_block);
Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block);
// Step 1. Get the scope block
-  StmtSRef scope_root_sref = GetScopeRoot(self, producer_block_sref,  //
-                                          /*require_stage_pipeline=*/true,
-                                          /*require_subtree_compact_dataflow=*/false);
+  StmtSRef scope_root_sref = GetScopeRoot(self, producer_block_sref,
+                                          /*require_stage_pipeline=*/true);
// Step 2. Check completeness
CheckNotOutputBlock(self, producer_block_sref, scope_root_sref);
CheckCompleteBlock(self, producer_block_sref, scope_root_sref);
@@ -593,8 +592,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
Block consumer_block = GetRef<Block>(_consumer_block);
// Step 1. Get the scope block
StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, //
-                                          /*require_stage_pipeline=*/true,
-                                          /*require_subtree_compact_dataflow=*/false);
+                                          /*require_stage_pipeline=*/true);
Buffer inlined_buffer =
NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block, scope_root_sref);
// Step 2. Check completeness
6 changes: 3 additions & 3 deletions src/tir/schedule/primitive/for_kind.cc
@@ -157,9 +157,9 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref
* parallelized/vectorized/bound.
*/
// Step 1. Check whether the subtree rooted from the `loop` in sref tree has compact data flow.
-  GetScopeRoot(self, loop_sref,  //
-               /*require_stage_pipeline=*/true,
-               /*require_subtree_compact_dataflow=*/true);
+  StmtSRef scope_root_sref = GetScopeRoot(self, loop_sref,
+                                          /*require_stage_pipeline=*/true);
+  CheckSubtreeCompactDataflow(self, loop_sref, scope_root_sref);

// Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each
// underlying block.
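`ParallelizeComputation` is the one call site in this change that still needs the subtree-wide property, so it performs the two steps explicitly, following the same pattern as the sketch under `analysis.h` above.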
6 changes: 2 additions & 4 deletions src/tir/schedule/primitive/get_block_loop.cc
@@ -78,8 +78,7 @@ Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent
}

Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref) {
-  StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false,
-                                     /*require_stage_pipeline=*/false);
+  StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
Array<Dependency> edges = self->GetBlockScope(scope_root)->GetDepsByDst(block_sref);
Array<StmtSRef> results;
results.reserve(edges.size());
@@ -92,8 +91,7 @@ Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sr
}

Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) {
-  StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false,
-                                     /*require_stage_pipeline=*/false);
+  StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
Array<Dependency> edges = self->GetBlockScope(scope_root)->GetDepsBySrc(block_sref);
Array<StmtSRef> results;
results.reserve(edges.size());
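Incidentally, the deleted lines in both `GetProducers` and `GetConsumers` carried a mislabeled inline comment: the second argument was annotated `/*require_stage_pipeline=*/` even though the parameter it filled was `require_subtree_compact_dataflow`. Dropping the parameter removes the mislabeled argument along with it.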
6 changes: 2 additions & 4 deletions src/tir/schedule/primitive/reduction.cc
@@ -203,8 +203,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
}
// Cond 1. Check block is reduction
StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
-                                          /*require_stage_pipeline=*/false,
-                                          /*require_subtree_compact_dataflow=*/false);
+                                          /*require_stage_pipeline=*/false);
CheckReductionBlock(self, block_sref, scope_root_sref);
// Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
@@ -1009,8 +1008,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
const StmtSRef& block_sref = self->stmt2ref.at(block_realize->block.get());
const Block& block = block_realize->block;
StmtSRef scope_root = GetScopeRoot(self, block_sref, //
-                                     /*require_stage_pipeline=*/true,
-                                     /*require_subtree_compact_dataflow=*/false);
+                                     /*require_stage_pipeline=*/true);
CheckReductionBlock(self, block_sref, scope_root);
const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop, rf_loop_sref);
if (rf_loop->kind != ForKind::kSerial) {