
Commit da0bc0e
Author: Alex Gladkov
Add support for MXNet pad operator.
MXNet pad is described at: https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.pad

Add support for parameter 'None' in the MXNet slice operator. MXNet 'slice' is described at: https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.slice

Add support for MXNet cos, sin, and arctan. MXNet 'cos' is described at: https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.cos, 'sin' at: https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.sin, and 'arctan' is described at: https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.arctan

Add support for MXNet 1D Convolution and 1D Deconvolution. MXNet Convolution is described at: https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.Convolution and Deconvolution at: https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.Deconvolution
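A minimal end-to-end sketch (not part of the commit; the exact from_mxnet return value has varied across TVM versions) of the newly supported operators flowing through the Relay frontend:

import mxnet as mx
import tvm.relay as relay

x = mx.sym.var("x")
# 'edge' and 'reflect' pad modes are now forwarded to relay.nn.pad's pad_mode.
y = mx.sym.pad(x, mode="edge", pad_width=(0, 0, 0, 0, 1, 1, 1, 1))
y = mx.sym.sin(y) + mx.sym.cos(y) + mx.sym.arctan(y)
# slice now tolerates None entries in 'begin' (treated as index 0).
y = mx.sym.slice(y, begin=(None, 0, 1, 1), end=(1, 1, 5, 5))

func, params = relay.frontend.from_mxnet(y, shape={"x": (1, 1, 6, 6)})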
1 parent 4b431c6 commit da0bc0e

File tree: 23 files changed, +340 -30 lines


include/tvm/expr_operator.h

Lines changed: 1 addition & 0 deletions
@@ -521,6 +521,7 @@ TVM_DECLARE_INTRIN_UNARY(log);
 TVM_DECLARE_INTRIN_UNARY(popcount);
 TVM_DECLARE_INTRIN_UNARY(cos);
 TVM_DECLARE_INTRIN_UNARY(sin);
+TVM_DECLARE_INTRIN_UNARY(atan);

 // Implementation details after this
 inline bool is_const(const Expr& x) {

include/tvm/relay/attrs/nn.h

Lines changed: 6 additions & 1 deletion
@@ -405,13 +405,18 @@ struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
 struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
   double pad_value;
   Array<Array<IndexExpr> > pad_width;
+  std::string pad_mode;

   TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
     TVM_ATTR_FIELD(pad_value).set_default(0.0)
-        .describe("Specifies the strides of the convolution.");
+        .describe("The value used for padding when mode is 'constant'.");
     TVM_ATTR_FIELD(pad_width)
         .describe("Number of values padded to the edges of each axis, "
                   "in the format of ((before_1, after_1), ..., (before_N, after_N))");
+    TVM_ATTR_FIELD(pad_mode).set_default("constant")
+        .describe("Padding type to use. \"constant\" pads with constant_value, "
+                  "\"edge\" pads using the edge values of the input array, "
+                  "\"reflect\" pads by reflecting values with respect to the edges.");
   }
 };

python/tvm/intrin.py

Lines changed: 15 additions & 0 deletions
@@ -304,6 +304,21 @@ def sin(x):
     """
     return call_pure_intrin(x.dtype, "sin", x)

+def atan(x):
+    """Take atan of input x.
+
+    Parameters
+    ----------
+    x : Expr
+        Input argument.
+
+    Returns
+    -------
+    y : Expr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "atan", x)
+
 def sqrt(x):
     """Take square root of input x.

python/tvm/relay/frontend/common.py

Lines changed: 2 additions & 10 deletions
@@ -124,16 +124,8 @@ def get_int_tuple(self, key, default=RequiredAttr()):
         """
         if key in self.attrs:
             tshape = self.attrs[key]
-            ret = []
-            for x in tshape.strip('()[]').split(','):
-                x = x.strip()
-                if not x:
-                    continue
-                if x == "None":
-                    ret.append(None)
-                else:
-                    ret.append(int(x))
-            return tuple(ret)
+            return tuple(int(x) if x.strip("- ").isdigit() else None
+                         for x in tshape.strip('()[]').split(',') if x)
         if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default
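The rewritten expression is dense; a standalone sketch (hypothetical helper name, same parsing logic as the method above) shows why it handles both 'None' entries and negative integers:

def parse_int_tuple(tshape):
    # "None" fails the isdigit() test and maps to None; negative numbers
    # survive strip("- ") (which removes '-' and spaces from the ends)
    # and are parsed by int(), which tolerates surrounding whitespace.
    return tuple(int(x) if x.strip("- ").isdigit() else None
                 for x in tshape.strip('()[]').split(',') if x)

assert parse_int_tuple("(None, 0, -2)") == (None, 0, -2)
assert parse_int_tuple("[3, 4]") == (3, 4)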

python/tvm/relay/frontend/mxnet.py

Lines changed: 119 additions & 10 deletions
@@ -112,11 +112,55 @@ def _mx_zeros(inputs, attrs):
     return _op.zeros(shape=shape, dtype=dtype)


+def _mx_conv(inputs, attrs):
+    kernel_size = attrs.get_int_tuple("kernel")
+    if len(kernel_size) == 2:
+        return _mx_conv2d(inputs, attrs)
+    elif len(kernel_size) == 1:
+        return _mx_conv1d(inputs, attrs)
+    else:
+        raise tvm.error.OpAttributeInvalid(
+            'Only 1D or 2D kernels are supported for operator Convolution')
+
+
+def _mx_conv1d(inputs, attrs):
+    kernel_size = attrs.get_int_tuple("kernel")
+    if len(kernel_size) != 1:
+        raise tvm.error.OpAttributeInvalid(
+            'Non-1D kernels are not supported for operator Conv1D')
+    data_layout = attrs.get_str("layout", "NCW")
+    # MXNet Conv1D only supports "NCW" layout for now.
+    if data_layout != "NCW":
+        raise tvm.error.OpAttributeInvalid(
+            'Only "NCW" data layout is supported for 1D Convolution')
+    data_layout = "NCHW"
+    channel_axis = 1
+    kernel_layout = "OIHW"
+
+    new_attrs = {}
+    new_attrs["channels"] = attrs.get_int("num_filter")
+    new_attrs["kernel_size"] = (1,) + kernel_size
+    new_attrs["strides"] = (1,) + attrs.get_int_tuple("stride", (1,))
+    new_attrs["padding"] = (0,) + attrs.get_int_tuple("pad", (0,))
+    new_attrs["dilation"] = (1,) + attrs.get_int_tuple("dilate", (1,))
+    new_attrs["groups"] = attrs.get_int("num_group", 1)
+    new_attrs["data_layout"] = data_layout
+    new_attrs["kernel_layout"] = kernel_layout
+    use_bias = not attrs.get_bool("no_bias", False)
+    data = _op.expand_dims(inputs[0], axis=2)
+    kernel = _op.expand_dims(inputs[1], axis=2)
+    res = _op.nn.conv2d(data, kernel, **new_attrs)
+    if use_bias:
+        assert len(inputs) == 3
+        res = _op.nn.bias_add(res, inputs[2], axis=channel_axis)
+    res = _op.squeeze(res, axis=[2])
+    return res
+
+
 def _mx_conv2d(inputs, attrs):
     kernel_size = attrs.get_int_tuple("kernel")
     if len(kernel_size) != 2:
         raise tvm.error.OpAttributeInvalid(
-            'Non-2D kernels are not supported for operator Conv2D.')
+            'Only 1D or 2D kernels are supported for operator Convolution')
     data_layout = attrs.get_str("layout", "NCHW")
     channel_axis = _get_channel_axis(data_layout, "conv2d")
@@ -142,6 +186,51 @@ def _mx_conv2d(inputs, attrs):
     return res


+def _mx_conv_transpose(inputs, attrs):
+    kernel_size = attrs.get_int_tuple("kernel")
+    if len(kernel_size) == 2:
+        return _mx_conv2d_transpose(inputs, attrs)
+    elif len(kernel_size) == 1:
+        return _mx_conv1d_transpose(inputs, attrs)
+    else:
+        raise tvm.error.OpAttributeInvalid(
+            'Only 1D or 2D kernels are supported for operator Deconvolution')
+
+
+def _mx_conv1d_transpose(inputs, attrs):
+    if "target_shape" in attrs.attrs:
+        raise tvm.error.OpAttributeUnImplemented(
+            'Attribute "target_shape" is not supported for operator Conv1D-transpose.')
+    data_layout = attrs.get_str("layout", "NCW")
+    if data_layout != "NCW":
+        raise tvm.error.OpAttributeInvalid(
+            'Only "NCW" data layout is supported for 1D Deconvolution')
+    data_layout = "NCHW"
+    channel_axis = 1
+    kernel_layout = "OIHW"
+
+    new_attrs = {}
+    new_attrs["channels"] = attrs.get_int("num_filter")
+    new_attrs["kernel_size"] = (1,) + attrs.get_int_tuple("kernel")
+    new_attrs["strides"] = (1,) + attrs.get_int_tuple("stride", (1,))
+    new_attrs["output_padding"] = (0,) + attrs.get_int_tuple("adj", (0,))
+    new_attrs["padding"] = (0,) + attrs.get_int_tuple("pad", (0,))
+    new_attrs["dilation"] = (1,) + attrs.get_int_tuple("dilate", (1,))
+    new_attrs["groups"] = attrs.get_int("num_group", 1)
+    new_attrs["data_layout"] = data_layout
+    new_attrs["kernel_layout"] = kernel_layout
+    use_bias = not attrs.get_bool("no_bias", True)
+    data = _op.expand_dims(inputs[0], axis=2)
+    kernel = _op.expand_dims(inputs[1], axis=2)
+    res = _op.nn.conv2d_transpose(data, kernel, **new_attrs)
+
+    if use_bias:
+        assert len(inputs) == 3
+        res = _op.nn.bias_add(res, inputs[2], axis=channel_axis)
+    res = _op.squeeze(res, axis=[2])
+    return res
+
+
 def _mx_conv2d_transpose(inputs, attrs):
     if "target_shape" in attrs.attrs:
         raise tvm.error.OpAttributeUnImplemented(
@@ -257,13 +346,7 @@ def _mx_slice(inputs, attrs):
     if end is None:
         raise tvm.error.OpAttributeRequired(
             'Attribute "end" not found in operator Slice.')
-    if None in begin:
-        data_shape = _infer_type(inputs[0]).checked_type.shape
-        for i, beg in enumerate(begin):
-            if beg is None:
-                assert end[i] is None
-                begin[i] = 0
-                end[i] = data_shape[i]
+    begin = tuple(x if x is not None else 0 for x in begin)
     new_attrs = {'begin': begin, 'end': end}
     if stride is not None:
         new_attrs['strides'] = stride
@@ -373,6 +456,27 @@ def _mx_expand_dims(inputs, attrs):
     axis = attrs.get_int("axis")
     return _op.expand_dims(inputs[0], axis=axis)

+def _mx_pad(inputs, attrs):
+    pad_mode = attrs.get_str('mode', None)
+    if pad_mode is None:
+        raise tvm.error.OpAttributeRequired(
+            'Attribute "mode" not found in operator pad.')
+    if pad_mode not in ['constant', 'edge', 'reflect']:
+        raise tvm.error.OpAttributeInvalid(
+            'Value ' + pad_mode + ' in attribute "mode" is not valid')
+    pad_width = attrs.get_int_tuple('pad_width', None)
+    if pad_width is None:
+        raise tvm.error.OpAttributeRequired(
+            'Attribute "pad_width" not found in operator pad.')
+    if None in pad_width:
+        raise tvm.error.OpAttributeInvalid(
+            'Value None in attribute "pad_width" of operator pad is not valid.')
+    constant_value = attrs.get_float('constant_value', 0.0)
+    padding = tuple((b, a) for b, a in zip(pad_width[::2], pad_width[1::2]))
+    return _op.nn.pad(data=inputs[0],
+                      pad_width=padding,
+                      pad_value=constant_value,
+                      pad_mode=pad_mode)

 def _mx_leaky_relu(inputs, attrs):
     act_type = attrs.get_str("act_type")
@@ -931,6 +1035,8 @@ def _mx_one_hot(inputs, attrs):
     "ones_like",
     "where",
     "gather_nd",
+    "cos",
+    "sin",
 ]

 _convert_map = {
@@ -943,6 +1049,7 @@ def _mx_one_hot(inputs, attrs):
     "broadcast_mod" : _rename(_op.mod),
     "broadcast_maximum" : _rename(_op.maximum),
     "broadcast_minimum" : _rename(_op.minimum),
+    "arctan" : _rename(_op.atan),
     "broadcast_equal" : _mx_compare(_op.equal, _rename),
     "broadcast_not_equal" : _mx_compare(_op.not_equal, _rename),
     "broadcast_greater" : _mx_compare(_op.greater, _rename),
@@ -1018,9 +1125,9 @@ def _mx_one_hot(inputs, attrs):
     "_zeros" : _mx_zeros,
     "FullyConnected": _mx_fully_connected,
     "Activation" : _mx_activations,
-    "Convolution" : _mx_conv2d,
+    "Convolution" : _mx_conv,
     "Convolution_v1": _mx_conv2d,
-    "Deconvolution" : _mx_conv2d_transpose,
+    "Deconvolution" : _mx_conv_transpose,
     "Pooling" : _mx_pooling,
     "Pooling_v1" : _mx_pooling,
     "Dropout" : _mx_dropout,
@@ -1044,6 +1151,8 @@ def _mx_one_hot(inputs, attrs):
     "_full" : _mx_full,
     "repeat" : _mx_repeat,
     "tile" : _mx_tile,
+    "pad" : _mx_pad,
+    "Pad" : _mx_pad,
     "take" : _mx_take,
     "reverse" : _mx_reverse,
     "squeeze" : _mx_squeeze,

python/tvm/relay/frontend/tensorflow.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information

python/tvm/relay/op/_tensor.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
 register_schedule("log1p", schedule_broadcast)
 register_schedule("cos", schedule_broadcast)
 register_schedule("sin", schedule_broadcast)
+register_schedule("atan", schedule_broadcast)
 register_schedule("exp", schedule_broadcast)
 register_schedule("erf", schedule_broadcast)
 register_schedule("sqrt", schedule_broadcast)

python/tvm/relay/op/_tensor_grad.py

Lines changed: 6 additions & 0 deletions
@@ -60,6 +60,12 @@ def sin_grad(orig, grad):
     x = orig.args[0]
     return [grad * cos(x)]

+@register_gradient("atan")
+def atan_grad(orig, grad):
+    """Returns [grad * 1 / (1 + x ^ 2)]"""
+    x = orig.args[0]
+    a = const(2.0)
+    return [grad * ones_like(x) / (ones_like(x) + power(x, a))]

 @register_gradient("exp")
 def exp_grad(orig, grad):
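A quick finite-difference sanity check (plain Python, not part of the commit) of the registered rule d/dx atan(x) = 1 / (1 + x^2), which the code above expresses as ones_like(x) / (ones_like(x) + power(x, 2)):

import math

x, eps = 0.7, 1e-6
numeric = (math.atan(x + eps) - math.atan(x - eps)) / (2 * eps)
analytic = 1.0 / (1.0 + x * x)
assert abs(numeric - analytic) < 1e-8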

python/tvm/relay/op/nn/nn.py

Lines changed: 7 additions & 3 deletions
@@ -673,7 +673,8 @@ def prelu(data, alpha, axis=1):

 def pad(data,
         pad_width,
-        pad_value=0.0):
+        pad_value=0.0,
+        pad_mode='constant'):
     r"""Padding

     This operator takes in a tensor and pads each axis by the specified
@@ -688,13 +689,16 @@ def pad(data,
         of ((before_1, after_1), ..., (before_N, after_N))
     pad_value: float, optional, default=0.0
         The value used for padding
-
+    pad_mode: 'constant', 'edge', 'reflect'
+        'constant' pads with the constant value pad_value,
+        'edge' pads using the edge values of the input array,
+        'reflect' pads by reflecting values with respect to the edge.
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
-    return _make.pad(data, pad_width, pad_value)
+    return _make.pad(data, pad_width, pad_value, pad_mode)


 def mirror_pad(data,
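A minimal usage sketch of the extended Python API (shapes are assumed for illustration); with pad_mode='edge' or 'reflect' the pad_value argument is ignored:

import tvm.relay as relay

x = relay.var("x", shape=(1, 3, 6, 6))
# Pad H and W by one on each side, mirroring interior values at the border.
y = relay.nn.pad(x, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)),
                 pad_mode="reflect")   # result shape: (1, 3, 8, 8)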

python/tvm/relay/op/tensor.py

Lines changed: 15 additions & 0 deletions
@@ -76,6 +76,21 @@ def sin(data):
     """
     return _make.sin(data)

+def atan(data):
+    """Compute elementwise atan of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.atan(data)
+
 def exp(data):
     """Compute elementwise exp of data.
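The new unary op is then callable from Relay directly; a minimal sketch, assuming the usual import alias:

import tvm.relay as relay

x = relay.var("x", shape=(3,), dtype="float32")
f = relay.Function([x], relay.atan(x))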
