Skip to content

Commit f7286ba

Browse files
committed
[Relay][AlterOp] NHWC to NCHWc support for Pool, pad, concatenate, sum.
1 parent d703fb4 commit f7286ba

File tree

6 files changed

+321
-35
lines changed

6 files changed

+321
-35
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -748,10 +748,12 @@ def convert_conv(self, op, conv_type):
748748
elif padding == Padding.SAME:
749749
pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h)
750750
pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
751-
in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
752-
(pad_top, pad_bottom),
753-
(pad_left, pad_right),
754-
(0, 0)))
751+
do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0)
752+
if do_pad:
753+
in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
754+
(pad_top, pad_bottom),
755+
(pad_left, pad_right),
756+
(0, 0)))
755757
else:
756758
raise tvm.error.OpAttributeUnImplemented(
757759
'Padding format {} is not supported for operator Conv.'.format(padding))

src/relay/op/nn/pad.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
* to you under the Apache License, Version 2.0 (the
77
* "License"); you may not use this file except in compliance
88
* with the License. You may obtain a copy of the License at
9-
*
9+
*
1010
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
11+
*
1212
* Unless required by applicable law or agreed to in writing,
1313
* software distributed under the License is distributed on an
1414
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY

src/relay/op/nn/pooling.cc

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,9 @@ Array<Array<Layout> > Pool2DInferCorrectLayout(
4747
T *params = const_cast<T*>(attrs.as<T>());
4848

4949
if (new_in_layouts.defined()) {
50+
// Follow the new input layout: rewrite the pool's layout attribute to match it.
5051
CHECK_EQ(new_in_layouts.size(), 1);
51-
52-
Layout raw_layout(params->layout);
53-
Layout input = new_in_layouts[0];
54-
if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
55-
input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
56-
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
57-
params->layout = input.name(); // modify self to follow the input layout
58-
}
52+
params->layout = new_in_layouts[0].name();
5953
}
6054

6155
Layout inferred_layout(params->layout);

src/relay/op/tensor/reduce.cc

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
* to you under the Apache License, Version 2.0 (the
77
* "License"); you may not use this file except in compliance
88
* with the License. You may obtain a copy of the License at
9-
*
9+
*
1010
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
11+
*
1212
* Unless required by applicable law or agreed to in writing,
1313
* software distributed under the License is distributed on an
1414
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -119,6 +119,59 @@ Array<Integer> GetExcludeAxes(size_t indim,
119119
return r_axes;
120120
}
121121

122+
// Return the modified layout for AlterOpLayout pass.
123+
Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,
124+
const Array<Layout>& new_in_layouts,
125+
const Array<Layout>& old_in_layouts,
126+
const Array<Array<IndexExpr>>& old_in_shapes) {
127+
// NOTE: Discard "const" qualifier here.
128+
ReduceAttrs* params = const_cast<ReduceAttrs*>(attrs.as<ReduceAttrs>());
129+
130+
// Get the reduce axes.
131+
uint32_t indim = old_in_shapes[0].size();
132+
auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);
133+
134+
Layout ret = Layout::Undef();
135+
if (new_in_layouts.defined() && r_axes.size()) {
136+
// Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
137+
// modified layout axes.
138+
CHECK_EQ(new_in_layouts.size(), 1);
139+
CHECK_EQ(old_in_layouts.size(), 1);
140+
141+
// 1) Collect the original axes
142+
std::unordered_set<std::string> old_r_dims;
143+
for (auto r_axis : r_axes) {
144+
old_r_dims.emplace(old_in_layouts[0][r_axis].name());
145+
}
146+
147+
// 2) Collect the new axes by walking new_layout.
148+
tvm::Array<tvm::Integer> new_r_axes;
149+
std::string new_layout_string = "";
150+
int axis_index = 0;
151+
for (auto iter_var : new_in_layouts[0]->axes) {
152+
const auto& layout_axis = LayoutAxis::Get(iter_var);
153+
const std::string& layout_dim = layout_axis.name();
154+
if (old_r_dims.count(layout_dim)) {
155+
new_r_axes.push_back(tvm::Integer(axis_index));
156+
}
157+
// Collect only the primal axis.
158+
if (layout_axis.IsPrimal()) {
159+
new_layout_string += layout_dim;
160+
axis_index++;
161+
}
162+
}
163+
164+
// 3) Set the new axis and layout.
165+
ret = Layout(new_layout_string);
166+
params->axis = new_r_axes;
167+
} else if (old_in_layouts.defined()) {
168+
// If the new layout is undefined, set the old layout as the inferred layout.
169+
CHECK_EQ(old_in_layouts.size(), 1);
170+
ret = old_in_layouts[0];
171+
}
172+
173+
return Array<Array<Layout>>{{ret}, {ret}};
174+
}
122175

123176
template<typename F>
124177
Array<Tensor> ReduceCompute(const Attrs& attrs,
@@ -325,6 +378,7 @@ Example::
325378
.set_attrs_type_key("relay.attrs.ReduceAttrs")
326379
.set_support_level(4)
327380
.add_type_rel("Reduce", ReduceRel)
381+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout)
328382
.set_attr<FTVMCompute>("FTVMCompute", SumCompute)
329383
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
330384

src/relay/op/tensor/transform.cc

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -283,22 +283,34 @@ Array<Array<Layout>> ConcatenateLayout(
283283
const Array<Layout>& new_in_layouts,
284284
const Array<Layout>& old_in_layouts,
285285
const Array<Array<IndexExpr>> &old_in_shapes) {
286-
const ConcatenateAttrs* param = attrs.as<ConcatenateAttrs>();
286+
ConcatenateAttrs* param = const_cast<ConcatenateAttrs*>(attrs.as<ConcatenateAttrs>());
287287

288288
size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() :
289289
static_cast<size_t>(param->axis);
290290

291291
Layout ret;
292+
bool is_new_layout_selected = false;
292293
if (new_in_layouts.defined()) { // this function is called after some operators are altered.
294+
// If all the new input layouts are the same, the new input layout gets selected. For the axis, the new
295+
// axis in the new layout is identified. The param->axis is then modified on the fly to conform
296+
// to the new input layout.
293297
const auto& concate_dim = old_in_layouts[0][axis];
294-
for (size_t i = 0; i < new_in_layouts.size(); ++i) {
295-
if (new_in_layouts[i].ndim() > axis &&
296-
new_in_layouts[i][axis] == concate_dim) {
297-
ret = new_in_layouts[i];
298-
break;
298+
bool all_input_layouts_same = true;
299+
for (auto new_layout : new_in_layouts) {
300+
if (!new_layout.Equals(new_in_layouts[0])) {
301+
all_input_layouts_same = false;
299302
}
300303
}
301-
} else { // this function is called on the original correct relay ir
304+
if (all_input_layouts_same) {
305+
auto new_index = new_in_layouts[0].IndexOf(concate_dim);
306+
ret = new_in_layouts[0];
307+
param->axis = new_index;
308+
is_new_layout_selected = true;
309+
}
310+
}
311+
312+
if (!is_new_layout_selected) {
313+
// this function is called on the original correct relay ir
302314
for (size_t i = 0; i < old_in_layouts.size(); ++i) {
303315
if (old_in_layouts[i].defined()) {
304316
ret = old_in_layouts[i];

0 commit comments

Comments
 (0)