diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 85268de858d1..e336c7de7071 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -189,6 +189,19 @@ def _reshape(inputs, attrs): new_attrs['shape'] = _required_attr(attrs, 'shape') return _get_nnvm_op(op_name)(*inputs, **new_attrs) +def _slice(inputs, attrs): + begin = attrs.get('begin', None) + end = attrs.get('end', None) + stride = attrs.get('step', None) + if begin is None or end is None: + raise RuntimeError('begin and end are required params') + if 'None' in begin or 'None' in end: + raise RuntimeError('None in begin or end not supported yet...') + new_attrs = {'begin': begin, 'end': end} + if stride is not None: + new_attrs['stride'] = stride + return _get_nnvm_op('strided_slice')(inputs[0], **new_attrs) + def _split(inputs, attrs): op_name, new_attrs = 'split', {} axis = attrs.get('axis', 1) @@ -337,6 +350,7 @@ def _argmin(inputs, attrs): 'Pooling' : _pooling, 'Pooling_v1' : _pooling, 'Reshape' : _reshape, + 'slice' : _slice, 'SliceChannel' : _split, 'split' : _split, 'Softmax' : _rename('softmax'), diff --git a/nnvm/tests/python/frontend/mxnet/test_forward.py b/nnvm/tests/python/frontend/mxnet/test_forward.py index e9225a4c7c50..97ffa20b3edc 100644 --- a/nnvm/tests/python/frontend/mxnet/test_forward.py +++ b/nnvm/tests/python/frontend/mxnet/test_forward.py @@ -220,6 +220,14 @@ def test_forward_where(): tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy() tvm.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5) +def test_forward_slice(): + data = mx.sym.var('data') + mx_sym = mx.sym.slice(data, begin=(0, 1), end=(2, 4)) + verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 3)) + mx_sym = mx.sym.slice(data, begin=(-1, 1), end=(-3, 4), step=(-1, 2)) + verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 2)) + + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -242,4 +250,5 @@ def test_forward_where(): test_forward_argmax() test_forward_argmin() test_forward_where() + test_forward_slice() diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index c48a116a9d0e..9ef5f626393a 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -172,6 +172,21 @@ def _mx_batch_norm(inputs, attrs): return _op.nn.batch_norm(*inputs, **new_attrs) +def _mx_slice(inputs, attrs): + new_attrs = {} + begin = attrs.get_int_tuple('begin', None) + end = attrs.get_int_tuple('end', None) + stride = attrs.get_int_tuple('step', None) + if begin is None or end is None: + raise RuntimeError("begin and end are required parameters.") + if None in begin or None in end: + raise RuntimeError("None in begin or end is not supported yet.") + new_attrs = {'begin': begin, 'end': end} + if stride is not None: + new_attrs['strides'] = stride + return _op.strided_slice(inputs[0], **new_attrs) + + def _mx_split(inputs, attrs): axis = attrs.get_int("axis", 1) new_attrs = {} @@ -368,6 +383,7 @@ def _mx_roi_align(inputs, attrs): "BatchNorm" : _mx_batch_norm, "BatchNorm_v1" : _mx_batch_norm, "LRN" : _mx_lrn, + "slice" : _mx_slice, "SliceChannel" : _mx_split, "split" : _mx_split, "expand_dims" : _mx_expand_dims, diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index ca1bdbbbefc9..671316079308 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -190,6 +190,13 @@ def test_forward_argmin(): mx_sym = mx.sym.argmin(data, axis=0) verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,)) +def test_forward_slice(): + data = mx.sym.var('data') + mx_sym = mx.sym.slice(data, begin=(0, 1), end=(2, 4)) + verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 3)) + mx_sym = mx.sym.slice(data, begin=(-1, 1), end=(-3, 4), step=(-1, 2)) + verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 2)) + def test_forward_where(): cond = mx.sym.var('cond') x = mx.sym.var('x')