Skip to content

Commit b1e4906

Browse files
committed
add converter for MXNet slice in nnvm
1 parent 4ba3047 commit b1e4906

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

nnvm/python/nnvm/frontend/mxnet.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,19 @@ def _reshape(inputs, attrs):
189189
new_attrs['shape'] = _required_attr(attrs, 'shape')
190190
return _get_nnvm_op(op_name)(*inputs, **new_attrs)
191191

192+
def _slice(inputs, attrs):
193+
begin = attrs.get('begin', None)
194+
end = attrs.get('end', None)
195+
stride = attrs.get('step', None)
196+
if begin is None or end is None:
197+
raise RuntimeError('begin and end are required params')
198+
if 'None' in begin or 'None' in end:
199+
raise RuntimeError('None in begin or end not supported yet...')
200+
new_attrs = {'begin': begin, 'end': end}
201+
if stride is not None:
202+
new_attrs['stride'] = stride
203+
return _get_nnvm_op('strided_slice')(inputs[0], **new_attrs)
204+
192205
def _split(inputs, attrs):
193206
op_name, new_attrs = 'split', {}
194207
axis = attrs.get('axis', 1)
@@ -337,6 +350,7 @@ def _argmin(inputs, attrs):
337350
'Pooling' : _pooling,
338351
'Pooling_v1' : _pooling,
339352
'Reshape' : _reshape,
353+
'slice' : _slice,
340354
'SliceChannel' : _split,
341355
'split' : _split,
342356
'Softmax' : _rename('softmax'),

nnvm/tests/python/frontend/mxnet/test_forward.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,14 @@ def test_forward_where():
220220
tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
221221
tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
222222

223+
def test_forward_slice():
224+
data = mx.sym.var('data')
225+
mx_sym = mx.sym.slice(data, begin=(0, 1), end=(2, 4))
226+
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 3))
227+
mx_sym = mx.sym.slice(data, begin=(-1, 1), end=(-3, 4), step=(-1, 2))
228+
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 2))
229+
230+
223231
if __name__ == '__main__':
224232
test_forward_mlp()
225233
test_forward_vgg()
@@ -242,4 +250,5 @@ def test_forward_where():
242250
test_forward_argmax()
243251
test_forward_argmin()
244252
test_forward_where()
253+
test_forward_slice()
245254

0 commit comments

Comments
 (0)