Skip to content

Commit 596333b

Browse files
MasterJH5574junrushaozxybazhspectrometerHBHSiyuan Feng
authored
[MetaSchedule] Schedule Rule: Auto Inline (apache#9943)
Co-authored-by: Junru Shao <junrushao1994@gmail.com> Co-authored-by: Xiyou Zhou <xiyou@octoml.ai> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin <wuwei@apache.org> Co-authored-by: Junru Shao <junrushao1994@gmail.com> Co-authored-by: Xiyou Zhou <xiyou@octoml.ai> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin <wuwei@apache.org>
1 parent 3c8de42 commit 596333b

9 files changed

Lines changed: 795 additions & 2 deletions

File tree

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ class ScheduleRule : public runtime::ObjectRef {
115115
* \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions
116116
* \param into_producer If allows to inline a block into its producer
117117
* \param into_consumer If allows to inline a block into its consumer
118-
* \param into_cache_only If it only allows to inline into a block generated by cache_read/write
119118
* \param inline_const_tensor Always inline constant tensors
120119
* \param disallow_if_then_else Always disallow if-then-else-like constructs
121120
* \param require_ordered Always require the read-to-write mapping to be ordered
@@ -125,7 +124,6 @@ class ScheduleRule : public runtime::ObjectRef {
125124
*/
126125
TVM_DLL static ScheduleRule AutoInline(bool into_producer, //
127126
bool into_consumer, //
128-
bool into_cache_only, //
129127
bool inline_const_tensor, //
130128
bool disallow_if_then_else, //
131129
bool require_injective, //

python/tvm/meta_schedule/schedule_rule/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@
1616
Meta Schedule schedule rules are used for modification of
1717
blocks in a schedule. See also PostOrderApply.
1818
"""
19+
from .auto_inline import AutoInline
1920
from .schedule_rule import PyScheduleRule, ScheduleRule
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Auto-Inline. Rule that inlines spatial blocks if it satisfies some conditions"""
18+
from typing import List, Optional
19+
20+
from tvm._ffi import register_object
21+
22+
from .. import _ffi_api
23+
from .schedule_rule import ScheduleRule
24+
25+
26+
@register_object("meta_schedule.AutoInline")
class AutoInline(ScheduleRule):
    """Rule that inlines spatial blocks if it satisfies some conditions

    Parameters
    ----------
    into_producer : bool
        If allows to inline a block into its producer
    into_consumer : bool
        If allows to inline a block into its consumer
    inline_const_tensor : bool
        Always inline constant tensors
    disallow_if_then_else : bool
        Always disallow if-then-else-like constructs
    require_injective : bool
        Always require the read-to-write mapping to be injective
    require_ordered : bool
        Always require the read-to-write mapping to be ordered
    disallow_op : Optional[List[str]]
        The operators that are disallowed in auto inline, given by operator
        name (e.g. ``"tir.exp"``). ``None`` means no operator is disallowed.
    """

    def __init__(
        self,
        into_producer: bool,
        into_consumer: bool,
        inline_const_tensor: bool,
        disallow_if_then_else: bool,
        require_injective: bool,
        require_ordered: bool,
        disallow_op: Optional[List[str]] = None,
    ) -> None:
        # The positional order below must match the parameter order of the
        # function registered as "meta_schedule.ScheduleRuleAutoInline".
        self.__init_handle_by_constructor__(
            _ffi_api.ScheduleRuleAutoInline,  # type: ignore # pylint: disable=no-member
            into_producer,
            into_consumer,
            inline_const_tensor,
            disallow_if_then_else,
            require_injective,
            require_ordered,
            disallow_op,
        )
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Default schedule rules"""
18+
from tvm.meta_schedule.schedule_rule import (
19+
AutoInline,
20+
ScheduleRule,
21+
)
22+
from tvm.target import Target
23+
24+
25+
def auto_inline(target: Target) -> ScheduleRule:
    """Return the default AutoInline rule for the given target.

    Only the "llvm" and "cuda" target kinds are handled; any other kind
    raises NotImplementedError.
    """
    kind_name = target.kind.name
    # Per-target keyword arguments for AutoInline.
    presets = {
        "llvm": dict(
            into_producer=False,
            into_consumer=True,
            inline_const_tensor=True,
            disallow_if_then_else=True,
            require_injective=True,
            require_ordered=True,
            disallow_op=["tir.exp"],
        ),
        "cuda": dict(
            into_producer=True,
            into_consumer=True,
            inline_const_tensor=True,
            disallow_if_then_else=False,
            require_injective=False,
            require_ordered=False,
            disallow_op=None,
        ),
    }
    if kind_name not in presets:
        raise NotImplementedError(f"{kind_name} is not supported")
    return AutoInline(**presets[kind_name])
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include "../utils.h"
20+
21+
namespace tvm {
22+
namespace meta_schedule {
23+
24+
/*!
 * \brief The type of inline to be performed on a specific block.
 *
 * Produced by AutoInlineNode::CheckInline and consumed by AutoInlineNode::Apply,
 * which maps each value onto the corresponding schedule primitive.
 */
25+
enum class InlineType : int32_t {
26+
/*! \brief No inline opportunity; Apply leaves the schedule unchanged */
27+
kNoInline = 0,
28+
/*! \brief Inline the block into its consumer (Apply calls ComputeInline) */
29+
kInlineIntoConsumer = 1,
30+
/*! \brief Inline the block into its producer (Apply calls ReverseComputeInline) */
31+
kInlineIntoProducer = 2,
32+
};
33+
34+
/*! \brief The rule that inlines spatial blocks if it satisfies some conditions. */
35+
class AutoInlineNode : public ScheduleRuleNode {
36+
public:
37+
/*! \brief Checks if the specific block should be inlined */
38+
inline InlineType CheckInline(const tir::Schedule& sch, const tir::BlockRV& block_rv);
39+
40+
// Inherited from ScheduleRuleNode
41+
void InitializeWithTuneContext(const TuneContext& context) final {}
42+
43+
// Inherited from ScheduleRuleNode
44+
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
45+
InlineType inline_type = CheckInline(sch, block_rv);
46+
if (inline_type == InlineType::kInlineIntoConsumer) {
47+
sch->ComputeInline(block_rv);
48+
} else if (inline_type == InlineType::kInlineIntoProducer) {
49+
sch->ReverseComputeInline(block_rv);
50+
}
51+
return {sch};
52+
}
53+
54+
public:
55+
/*! \brief If allows to inline a block into its producer */
56+
bool into_producer;
57+
/*! \brief If allows to inline a block into its consumer */
58+
bool into_consumer;
59+
/*! \brief Always inline constant tensors */
60+
bool inline_const_tensor;
61+
/*! \brief Always disallow if-then-else-like constructs */
62+
bool disallow_if_then_else;
63+
/*! \brief Always require the read-to-write mapping to be injective to do auto inline */
64+
bool require_injective;
65+
/*! \brief Always require the read-to-write mapping to be ordered to do auto inline */
66+
bool require_ordered;
67+
/*! \brief The operators that are disallowed in auto inline */
68+
Array<Op> disallow_op;
69+
70+
void VisitAttrs(tvm::AttrVisitor* v) {
71+
v->Visit("into_producer", &into_producer);
72+
v->Visit("into_consumer", &into_consumer);
73+
v->Visit("inline_const_tensor", &inline_const_tensor);
74+
v->Visit("disallow_if_then_else", &disallow_if_then_else);
75+
v->Visit("require_injective", &require_injective);
76+
v->Visit("require_ordered", &require_ordered);
77+
v->Visit("disallow_op", &disallow_op);
78+
}
79+
80+
static constexpr const char* _type_key = "meta_schedule.AutoInline";
81+
TVM_DECLARE_FINAL_OBJECT_INFO(AutoInlineNode, ScheduleRuleNode);
82+
};
83+
84+
/*!
 * \brief Decide whether, and in which direction, the given block should be inlined.
 *
 * The conditions are checked in order:
 *  1. the block writes exactly one buffer;
 *  2. if `inline_const_tensor` is set and the block reads nothing, it is inlined into
 *     its consumers, provided it has at least one consumer and compute-inline is legal;
 *  3. the block contains none of the operators in `disallow_op`;
 *  4. if `disallow_if_then_else` is set, the block has no if-then-else-like construct;
 *  5. if requested, every read-to-write index mapping is injective and/or ordered;
 *  6. finally, inlining into the consumers is tried before inlining into the producer.
 *
 * \param sch The schedule the block belongs to
 * \param block_rv The block to be examined
 * \return The inline type to apply, or InlineType::kNoInline if no condition allows it
 */
inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
                                              const tir::BlockRV& block_rv) {
  using namespace tvm::tir;
  StmtSRef block_sref = sch->GetSRef(block_rv);
  ScheduleState state = sch->state();
  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
  BlockRealize realize = GetBlockRealize(state, block_sref);
  // Cond 1. The block has only one write buffer
  if (block->writes.size() != 1) {
    return InlineType::kNoInline;
  }
  // Cond 2. For a block that generates a constant tensor, ignore all other conditions.
  // The shortcut is only taken when the inline is actually applicable: a block without
  // consumers, or one that fails the compute-inline check, must not be reported as
  // inlinable because Apply() would then call ComputeInline on it unconditionally.
  // Otherwise the remaining conditions decide.
  if (inline_const_tensor && block->reads.empty()) {
    Array<StmtSRef> const_consumer_srefs = GetConsumers(state, block_sref);
    if (!const_consumer_srefs.empty() && CanComputeInline(state, block_sref)) {
      return InlineType::kInlineIntoConsumer;
    }
  }
  // Cond 3. The block doesn't contain any disallowed operators
  if (!disallow_op.empty() && HasOp(realize, disallow_op)) {
    return InlineType::kNoInline;
  }
  // Cond 4. The block doesn't have any if-then-else-like constructs
  if (disallow_if_then_else && HasIfThenElse(realize)) {
    return InlineType::kNoInline;
  }
  // Cond 5. The mapping from read indices to write indices is injective and ordered
  if (require_injective || require_ordered) {
    const BufferRegion& write_region = block->writes[0];
    for (const BufferRegion& read_region : block->reads) {
      bool injective = false;
      bool ordered = false;
      auto _ = std::ignore;
      std::tie(/*exists=*/_, /*surjective=*/_, injective, ordered, /*no_const_read=*/_,
               /*no_shift_read=*/_) = AnalyzeReadWritePattern(read_region, write_region);
      if (require_injective && !injective) {
        return InlineType::kNoInline;
      }
      if (require_ordered && !ordered) {
        return InlineType::kNoInline;
      }
    }
  }
  // Last cond: Check inline into the consumers or the spatial producer
  StmtSRef scope_block = GetScopeRoot(state, block_sref,  //
                                      /*require_stage_pipeline=*/false,  //
                                      /*require_subtree_compact_dataflow=*/false);
  if (into_consumer) {
    Array<StmtSRef> consumer_srefs = GetConsumers(state, block_sref);
    if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) {
      return InlineType::kInlineIntoConsumer;
    }
  }
  if (into_producer) {
    // Only a block with exactly one producer, which must be a complete block in the
    // enclosing scope, is a candidate for reverse inlining
    Array<StmtSRef> producer_srefs = GetProducers(state, block_sref);
    if (producer_srefs.size() == 1 &&
        IsCompleteBlock(state, producer_srefs[0], scope_block) &&
        CanReverseComputeInline(state, block_sref)) {
      return InlineType::kInlineIntoProducer;
    }
  }
  return InlineType::kNoInline;
}
143+
144+
/*!
 * \brief Build an AutoInline rule from its configuration flags.
 *
 * The flags are copied verbatim onto a new AutoInlineNode. When `disallow_op` is given,
 * each operator name in it is resolved through Op::Get and stored on the node; when it
 * is not given the node keeps an empty list.
 */
ScheduleRule ScheduleRule::AutoInline(bool into_producer,          //
                                      bool into_consumer,          //
                                      bool inline_const_tensor,    //
                                      bool disallow_if_then_else,  //
                                      bool require_injective,      //
                                      bool require_ordered,        //
                                      Optional<Array<String>> disallow_op) {
  ObjectPtr<AutoInlineNode> node = make_object<AutoInlineNode>();
  // Plain flags
  node->into_producer = into_producer;
  node->into_consumer = into_consumer;
  node->inline_const_tensor = inline_const_tensor;
  node->disallow_if_then_else = disallow_if_then_else;
  node->require_injective = require_injective;
  node->require_ordered = require_ordered;
  // Operator names -> operator objects
  node->disallow_op.clear();
  if (disallow_op.defined()) {
    Array<String> names = disallow_op.value();
    node->disallow_op.reserve(names.size());
    for (const String& name : names) {
      node->disallow_op.push_back(Op::Get(name));
    }
  }
  return ScheduleRule(node);
}
168+
169+
// Register the AutoInlineNode node type (macro defined elsewhere in the project).
TVM_REGISTER_NODE_TYPE(AutoInlineNode);
170+
// Register the ScheduleRule::AutoInline factory under the global name
// "meta_schedule.ScheduleRuleAutoInline".
// NOTE(review): presumably this is the function the Python AutoInline wrapper reaches
// as `_ffi_api.ScheduleRuleAutoInline` — confirm against the FFI module prefix.
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline")
171+
.set_body_typed(ScheduleRule::AutoInline);
172+
173+
} // namespace meta_schedule
174+
} // namespace tvm

src/tir/schedule/analysis.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#define TVM_TIR_SCHEDULE_ANALYSIS_H_
2121

2222
#include <tvm/arith/analyzer.h>
23+
#include <tvm/ir/op.h>
2324
#include <tvm/tir/schedule/state.h>
2425

2526
#include <tuple>
@@ -442,6 +443,50 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S
442443
bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
443444
const StmtSRef& loop_sref, bool preserve_unit_loops);
444445

446+
/*!
447+
* \brief Checks if the given AST contains the specific operators
448+
* \param stmt The AST statement to be checked
449+
* \param ops The list of operators to be checked
450+
* \return A boolean indicating whether the AST contains the specific operators
451+
*/
452+
bool HasOp(const Stmt& stmt, const Array<Op>& ops);
453+
454+
/*!
455+
* \brief Checks if the given AST statement contains if-then-else, including
456+
* 1) IfThenElse statement
457+
* 2) Select expression
458+
* 3) The operator `tir.if_then_else`
459+
* 4) non-constant-true Block predicates
460+
* \param stmt The AST statement to be checked
461+
* \return A boolean indicating whether the statement contains the if-then-else pattern
462+
*/
463+
bool HasIfThenElse(const Stmt& stmt);
464+
465+
/*!
466+
* \brief Given the read/write region, extract the pattern of their index correspondence
467+
* namely, the mapping from read index to the write index.
468+
* \param read_region The read region
469+
* \param write_region The write region
470+
* \return A tuple of booleans, the extracted pattern
471+
* 0) exists: if the pattern is found
472+
* 1) surjective: if the pattern is surjective, i.e. each write index is mapped at least once
473+
* e.g. A[i, j] = B[i, i, j]
474+
* 2) injective: if the pattern is injective, i.e. each write index is mapped at most once.
475+
* e.g. A[i, j] = B[i]
476+
* 3) ordered: if the mapping is ordered
477+
* 4) no_const_read: if there is no constant indexing in the read indices,
478+
* e.g. A[i, j] = B[0, i, j]
479+
* 5) no_shift_read: if there is no constant shift in the read indices,
480+
* e.g. A[i, j] = B[i + 1, j]
481+
*/
482+
std::tuple</*exists=*/bool,
483+
/*surjective=*/bool,
484+
/*injective=*/bool,
485+
/*ordered=*/bool,
486+
/*no_const_read=*/bool,
487+
/*no_shift_read=*/bool>
488+
AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region);
489+
445490
} // namespace tir
446491
} // namespace tvm
447492

0 commit comments

Comments
 (0)