Skip to content

Commit a5cb532

Browse files
icemelonWei Chen
authored andcommitted
[Relay][Frontend] Add slice axis op in mxnet converter (apache#2706)
* Add slice axis op in mxnet converter * Fix lint
1 parent 25ba3ea commit a5cb532

2 files changed

Lines changed: 47 additions & 0 deletions

File tree

python/tvm/relay/frontend/mxnet.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,34 @@ def _mx_slice(inputs, attrs):
194194
return _op.strided_slice(inputs[0], **new_attrs)
195195

196196

197+
def _mx_slice_axis(inputs, attrs):
198+
assert len(inputs) == 1
199+
shape = ir_pass.infer_type(inputs[0]).checked_type.shape
200+
axis = attrs.get_int("axis")
201+
ax_beg = attrs.get_int("begin")
202+
ax_end = attrs.get_str("end")
203+
if ax_end == "None":
204+
ax_end = int(shape[axis])
205+
else:
206+
ax_end = int(ax_end)
207+
if ax_beg < 0:
208+
ax_beg += int(shape[axis])
209+
if ax_end < 0:
210+
ax_end += int(shape[axis])
211+
assert ax_beg >= 0 and ax_beg < int(shape[axis])
212+
assert ax_end > ax_beg and ax_end <= int(shape[axis])
213+
begin = []
214+
end = []
215+
for i, dim in enumerate(shape):
216+
if i != axis:
217+
begin.append(0)
218+
end.append(dim)
219+
else:
220+
begin.append(ax_beg)
221+
end.append(ax_end)
222+
return _op.strided_slice(inputs[0], begin, end)
223+
224+
197225
def _mx_split(inputs, attrs):
198226
axis = attrs.get_int("axis", 1)
199227
new_attrs = {}
@@ -423,6 +451,7 @@ def _mx_roi_align(inputs, attrs):
423451
"BatchNorm_v1" : _mx_batch_norm,
424452
"LRN" : _mx_lrn,
425453
"slice" : _mx_slice,
454+
"slice_axis" : _mx_slice_axis,
426455
"SliceChannel" : _mx_split,
427456
"split" : _mx_split,
428457
"expand_dims" : _mx_expand_dims,

tests/python/frontend/mxnet/test_forward.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,23 @@ def test_forward_scalar_ops():
337337
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
338338

339339

340+
def test_forward_slice_axis():
341+
def verify(shape, axis, begin, end):
342+
data_np = np.random.uniform(size=shape).astype("float32")
343+
ref_res = mx.nd.slice_axis(mx.nd.array(data_np), axis, begin, end)
344+
mx_sym = mx.sym.slice_axis(mx.sym.var("data"), axis, begin, end)
345+
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape})
346+
for target, ctx in ctx_list():
347+
for kind in ["graph", "debug"]:
348+
intrp = relay.create_executor(kind, ctx=ctx, target=target)
349+
op_res = intrp.evaluate(new_sym)(data_np)
350+
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
351+
verify((3, 4), 0, 1, 2)
352+
verify((3, 4), 0, 1, None)
353+
verify((3, 4), 1, 0, 2)
354+
verify((3, 4), 1, -3, -1)
355+
356+
340357
if __name__ == '__main__':
341358
test_forward_mlp()
342359
test_forward_vgg()
@@ -363,3 +380,4 @@ def test_forward_scalar_ops():
363380
test_forward_broadcast_ops()
364381
test_forward_elemwise_ops()
365382
test_forward_scalar_ops()
383+
test_forward_slice_axis()

0 commit comments

Comments
 (0)