Skip to content

Commit 1939bec

Browse files
committed
add converter for MXNet slice in Relay
1 parent 411c973 commit 1939bec

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

python/tvm/relay/frontend/mxnet.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,22 @@ def _mx_batch_norm(inputs, attrs):
172172
return _op.nn.batch_norm(*inputs, **new_attrs)
173173

174174

175+
def _mx_slice(inputs, attrs):
176+
new_attrs = {}
177+
begin = attrs.get_int_tuple('begin', None)
178+
end = attrs.get_int_tuple('end', None)
179+
stride = attrs.get_int_tuple('step', None)
180+
print(begin, end, stride)
181+
if begin is None or end is None:
182+
raise RuntimeError("begin and end are required parameters.")
183+
if None in begin or None in end:
184+
raise RuntimeError("None in begin or end is not supported yet.")
185+
new_attrs = {'begin': begin, 'end': end}
186+
if stride is not None:
187+
new_attrs['stride'] = stride
188+
return _op.strided_slice(inputs[0], **new_attrs)
189+
190+
175191
def _mx_split(inputs, attrs):
176192
axis = attrs.get_int("axis", 1)
177193
new_attrs = {}
@@ -368,6 +384,7 @@ def _mx_roi_align(inputs, attrs):
368384
"BatchNorm" : _mx_batch_norm,
369385
"BatchNorm_v1" : _mx_batch_norm,
370386
"LRN" : _mx_lrn,
387+
"slice" : _mx_slice,
371388
"SliceChannel" : _mx_split,
372389
"split" : _mx_split,
373390
"expand_dims" : _mx_expand_dims,

tests/python/frontend/mxnet/test_forward.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,13 @@ def test_forward_argmin():
190190
mx_sym = mx.sym.argmin(data, axis=0)
191191
verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))
192192

193+
def test_forward_slice():
194+
data = mx.sym.var('data')
195+
mx_sym = mx.sym.slice(data, begin=(0, 1), end=(2, 4))
196+
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 3))
197+
mx_sym = mx.sym.slice(data, begin=(-1, 1), end=(-3, 4), step=(-1, 2))
198+
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 2))
199+
193200
def test_forward_where():
194201
cond = mx.sym.var('cond')
195202
x = mx.sym.var('x')

0 commit comments

Comments
 (0)