
Commit c1e993c

Lily Orth-Smith authored and Matthew Brookhart committed
Dynamic ONNX importer: Upsampling and Pad (#2)
fix lint
fix Call reference
fix a type issue with expand
fix a bad test
refactor
respond to review comments, fix batch matmul tests
1 parent b28e7e2 commit c1e993c

5 files changed: 117 additions & 74 deletions


include/tvm/relay/transform.h

Lines changed: 1 addition & 1 deletion
@@ -213,7 +213,7 @@ TVM_DLL Pass FastMath();
  *
  * Searches the graph for dynamic ops. If the dynamic inputs to those ops are constants, it replaces
  * them with static ops and re-performs type inference and constant folding. The pass repeats
- * istself until the graph stops changing or we run too many iterations.
+ * itself until the graph stops changing or we run too many iterations.
  *
  * \return The pass.
  */
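
The pass documented here is exposed to Python as relay.transform.DynamicToStatic(). A minimal usage sketch, assuming a module `mod` already imported from ONNX (not part of this commit):

    import tvm
    from tvm import relay

    def make_static(mod):
        # Rewrite dynamic ops whose inputs turn out to be constant into
        # static ops; the pass itself iterates to a fixed point internally.
        seq = tvm.transform.Sequential([
            relay.transform.InferType(),
            relay.transform.DynamicToStatic(),
        ])
        with tvm.transform.PassContext(opt_level=3):
            return seq(mod)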

python/tvm/relay/frontend/onnx.py

Lines changed: 77 additions & 50 deletions
@@ -28,17 +28,9 @@
 from .. import op as _op
 from .. import vision as _vision

-from ..function import Function
-from ..expr import Call, Let
-from ..expr import If, Tuple, TupleGetItem
-from ..expr import RefCreate, RefRead, RefWrite
-from ..expr_functor import ExprFunctor
-from ..adt import Match, Clause
-from ..op.tensor import minimum as _minimum, maximum as _maximum
-
 from .common import AttrCvt, Renamer
 from .common import get_relay_op, new_var, infer_shape, infer_channels
-from .common import infer_type, get_name, infer_value, infer_value_simulated
+from .common import infer_type, get_name, infer_value_simulated

 __all__ = ['from_onnx']

@@ -642,26 +634,22 @@ def _impl_v2(cls, inputs, attr, params):

     @classmethod
     def _impl_v11(cls, inputs, attr, params):
-        pad_width = []
-        pads = infer_value_simulated(inputs[1], params).asnumpy()
+        pads = inputs[1]
         if len(inputs) == 3:
-            value = infer_value_simulated(inputs[2], params).asnumpy().item()
+            value = _op.take(inputs[2], _op.const(0))
         else:
             value = 0
-        attr["pad_value"] = value
-        dims = int(len(pads) / 2)
-        for i in range(dims):
-            pad_width.append((pads[i], pads[i + dims]))
-        attr['pad_width'] = pad_width
+
+        pads_shape = infer_shape(pads)
+        dims = int(pads_shape[0] / 2)
+        pad_width_expr = _op.transpose(_op.reshape(pads, (2, dims)))
         pad_mode = attr.get('mode', b'constant').decode('utf-8')
-        if pad_mode in ['constant', 'edge', 'reflect']:
-            attr['pad_mode'] = pad_mode
-            attr.pop('mode', None)
-        else:
+
+        if not pad_mode in ['constant', 'edge', 'reflect']:
             raise tvm.error.OpAttributeInvalid('Value ' + pad_mode +
                                                ' in attribute "mode" is invalid for operator Pad.')

-        return AttrCvt('pad')(inputs[:1], attr, params)
+        return _op.nn.pad(inputs[0], pad_width_expr, value, pad_mode=pad_mode)


 class ParametricSoftPlus(OnnxOpConverter):
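
The rearrangement above is the heart of the new dynamic Pad: ONNX stores pads as [x1_begin, x2_begin, ..., x1_end, x2_end, ...], so reshaping to (2, dims) and transposing yields the (before, after) pairs that nn.pad expects, entirely inside the graph instead of at import time. A NumPy sketch of the same index shuffle, for illustration only:

    import numpy as np

    # ONNX 'pads' for a 2-D tensor: [top, left, bottom, right]
    pads = np.array([1, 2, 3, 4])
    dims = pads.shape[0] // 2
    pad_width = np.transpose(pads.reshape(2, dims))
    print(pad_width)  # [[1 3], [2 4]] -> (before, after) per axis
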
@@ -869,17 +857,24 @@ class Upsample(OnnxOpConverter):
     @classmethod
     def _impl_v9(cls, inputs, attr, params):
         scales = attr.get('scales')
+
+        input_shape = infer_shape(inputs[0])
+        dims = len(input_shape)
+
         if not scales:
             #Here we are going to higher OPSET version.
-            assert len(inputs) == 2, "Upsample op take 2 inputs, {} given".format(len(inputs))
+            assert len(inputs) == 2, "Upsample op takes 2 inputs, {} given".format(len(inputs))
+
             if get_name(inputs[1]) in params:
                 scales = params[inputs[1].name_hint].asnumpy()
-            else:
+            elif dims == 5:
                 scales = infer_value_simulated(inputs[1], params).asnumpy()
-        inputs = inputs[:1]
-        assert scales[0] == 1.0 and scales[1] == 1.0
-        input_shape = infer_shape(inputs[0])
-        dims = len(input_shape)
+            else:
+                scales = inputs[1]
+
+        if not isinstance(scales, _expr.Call):
+            assert scales[0] == 1.0 and scales[1] == 1.0
+
         mode = attr.get('mode')
         if mode == b'nearest':
             method = "nearest_neighbor"
@@ -888,21 +883,41 @@ def _impl_v9(cls, inputs, attr, params):
         else:
             raise tvm.error.OpAttributeInvalid(
                 'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))
-        attr = {'scale_h': scales[-2], 'scale_w': scales[-1], 'method': method}
+
+        if method == 'nearest_neighbor':
+            align_corners = False
+        else:
+            align_corners = True
+        # in 3d case, we use the purely static op
         if dims == 5:
-            assert len(scales) == 5
-            attr['scale_d'] = scales[-3]
-            attr['layout'] = 'NCDHW'
-            op_name = 'upsampling3d'
+            scale_h = scales[-2]
+            scale_w = scales[-1]
+            scale_d = scales[-3]
+            layout = 'NCDHW'
+            out = _op.nn.upsampling3d(inputs[0],
+                                      scale_d,
+                                      scale_h,
+                                      scale_w,
+                                      layout=layout,
+                                      method=method)
+        # in 2d case, use dynamic op
         else:
-            assert len(scales) == 4
-            attr['layout'] = 'NCHW'
-            if method == 'nearest_neighbor':
-                attr['align_corners'] = False
+            if isinstance(scales, _expr.Call):
+                scale_h = _op.take(scales, _op.const(3))
+                scale_w = _op.take(scales, _op.const(4))
             else:
-                attr['align_corners'] = True
-            op_name = 'upsampling'
-        return AttrCvt(op_name)(inputs, attr)
+                assert len(scales) == 4
+                scale_h = scales[-2]
+                scale_w = scales[-1]
+            layout = 'NCHW'
+
+            out = _op.nn.upsampling(inputs[0],
+                                    scale_h,
+                                    scale_w,
+                                    layout=layout,
+                                    method=method,
+                                    align_corners=align_corners)
+        return out


 class Shape(OnnxOpConverter):
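
In the dynamic 2d branch, scale_h and scale_w are sliced out of the scales expression with take, and nn.upsampling is expected to dispatch to its dynamic variant when handed relay expressions instead of floats, as this commit relies on. A hedged sketch of that pattern in isolation (variable names and the 4-element NCHW scales layout are illustrative):

    from tvm import relay

    data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
    scales = relay.var("scales", shape=(4,), dtype="float32")  # [1, 1, h, w]
    scale_h = relay.take(scales, relay.const(2))
    scale_w = relay.take(scales, relay.const(3))
    out = relay.nn.upsampling(data, scale_h, scale_w, layout="NCHW",
                              method="nearest_neighbor", align_corners=False)
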
@@ -1422,7 +1437,8 @@ class Expand(OnnxOpConverter):
     """
     @classmethod
     def _impl_v8(cls, inputs, attr, params):
-        in_shape = _op.shape_of(inputs[0])
+        dtype = infer_type(inputs[1]).checked_type.dtype
+        in_shape = _op.shape_of(inputs[0], dtype=dtype)
         shape = inputs[1]

         # Currently 'op.broadcast_to' expect the rank of the given 'shape'
@@ -1441,14 +1457,11 @@ def expand_shape(in_shape, shape):
             in_dims = infer_shape(in_shape)[0]
             new_dims = infer_shape(shape)[0]
             if in_dims < new_dims:
-                in_shape = _op.concatenate([_expr.const([
-                    1,
-                ] * (new_dims - in_dims)), in_shape],
-                                           axis=0)
+                in_shape = _op.concatenate([_expr.const([1, ] * (new_dims - in_dims), dtype=dtype),
+                                            in_shape], axis=0)
             elif new_dims > in_dims:
-                shape = _op.concatenate([_expr.const([
-                    1,
-                ] * (in_dims - new_dims)), shape], axis=0)
+                shape = _op.concatenate([_expr.const([1, ] * (in_dims - new_dims), dtype=dtype),
+                                         shape], axis=0)
             new_shape = _op.maximum(in_shape, shape)
             return new_shape
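
The dtype threading matters because ONNX Expand supplies its target shape as int64 while shape_of previously defaulted to int32, so concatenating the two shape tensors failed type checking. A hedged sketch of the fixed pattern (illustrative names; assumes broadcast_to accepts a relay expression as its shape here):

    from tvm import relay

    data = relay.var("data", shape=(3, 1), dtype="float32")
    target = relay.var("target_shape", shape=(3,), dtype="int64")
    # Match shape_of's dtype to the incoming shape tensor before mixing them.
    in_shape = relay.shape_of(data, dtype="int64")
    padded = relay.concatenate([relay.const([1], dtype="int64"), in_shape], axis=0)
    out = relay.broadcast_to(data, relay.maximum(padded, target))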

@@ -2058,6 +2071,13 @@ def from_onnx(self, graph, opset, freeze_params=False):

         opset : opset version

+        freeze_params: bool
+            If this parameter is true, the importer will take any provided
+            onnx input values (weights, shapes, etc) and embed them into the relay model
+            as Constants instead of variables. This allows more aggressive optimizations
+            at compile time and helps in making models static if certain inputs represent
+            attributes relay would traditionally consider compile-time constants.
+
         Returns
         -------
         mod : tvm.IRModule
@@ -2156,12 +2176,12 @@ def from_onnx(self, graph, opset, freeze_params=False):
         ## Maintain the order of inputs and parametersfrom the ONNX graph, but only include
         ## those parameters that are needed to execute the relay graph
         free_vars = analysis.free_vars(outputs)
-        nodes = {v:k for k,v in self._nodes.items()}
+        nodes = {v: k for k, v in self._nodes.items()}
         free_vars = [nodes[var] for var in free_vars]
         for i_name in self._params:
             if i_name in free_vars and i_name not in self._inputs:
                 self._inputs[i_name] = self._nodes[i_name]
-        func = _function.Function([v for k,v in self._inputs.items()], outputs)
+        func = _function.Function([v for k, v in self._inputs.items()], outputs)
         if freeze_params:
             func, params = self.freeze(func, self._params)
         return IRModule.from_expr(func), params
@@ -2282,6 +2302,13 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals
        Override to autodetected opset.
        This can be helpful for some testing.

+    freeze_params: bool
+        If this parameter is true, the importer will take any provided
+        onnx input values (weights, shapes, etc) and embed them into the relay model
+        as Constants instead of variables. This allows more aggressive optimizations
+        at compile time and helps in making models static if certain inputs represent
+        attributes relay would traditionally consider compile-time constants.
+
     Returns
     -------
     mod : tvm.IRModule
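
A usage sketch of the documented flag (the model path and input name are illustrative):

    import onnx
    from tvm import relay

    onnx_model = onnx.load("model.onnx")  # illustrative path
    shape_dict = {"input": (1, 3, 224, 224)}
    # freeze_params=True embeds ONNX initializers as relay Constants, which
    # lets DynamicToStatic and constant folding remove the dynamic ops.
    mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)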

python/tvm/topi/x86/batch_matmul.py

Lines changed: 4 additions & 0 deletions
@@ -50,6 +50,10 @@ def batch_matmul(cfg, x, y, out_shape=None):
     assert XK == YK, "shapes of x and y is inconsistant"
     B = XB
     K = XK
+    if out_shape is not None:
+        assert out_shape[0] == B, "got invalid output shape"
+        assert out_shape[1] == M, "got invalid output shape"
+        assert out_shape[2] == N, "got invalid output shape"
     if cfg.is_fallback:
         _default_batch_matmul_config(cfg, M, N, K)
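
For reference, this TOPI kernel takes y in (batch, N, K) layout, so the result the new assertions guard has shape (B, M, N), matching x @ swapaxes(y, 1, 2). A NumPy sketch of that shape contract (illustration only):

    import numpy as np

    B, M, K, N = 2, 3, 4, 5
    x = np.random.rand(B, M, K).astype("float32")
    y = np.random.rand(B, N, K).astype("float32")  # note the (B, N, K) layout
    out = np.matmul(x, np.swapaxes(y, 1, 2))
    assert out.shape == (B, M, N)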

tests/python/frontend/onnx/test_forward.py

Lines changed: 34 additions & 22 deletions
@@ -48,7 +48,7 @@ def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None, freez
     """ Generic function to execute and get tvm output with vm executor"""
     if not isinstance(input_data, list):
         input_data = [input_data]
-    input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data)
+    _, shape_dict = get_input_data_shape_dict(graph_def, input_data)

     mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset, freeze_params=freeze_params)

@@ -167,15 +167,26 @@ def test_reshape():
 # @tvm.testing.uses_gpu
 def test_expand():

-    def _test_expand(name, data, shape, ref_data):
+    def _test_expand(name, data, shape, ref_data, dtype="int32"):
         shape_array = np.array(shape)
-        shape_node = onnx.helper.make_node('Constant',
-                                           inputs=[],
-                                           outputs=['shape'],
-                                           value=onnx.helper.make_tensor(name = 'const_tensor',
-                                                                         data_type = onnx.TensorProto.INT32,
-                                                                         dims = shape_array.shape,
-                                                                         vals = shape_array.flatten().astype('int32')))
+        if dtype == "int32":
+            shape_node = onnx.helper.make_node('Constant',
+                                               inputs=[],
+                                               outputs=['shape'],
+                                               value=onnx.helper.make_tensor(name = 'const_tensor',
+                                                                             data_type = onnx.TensorProto.INT32,
+                                                                             dims = shape_array.shape,
+                                                                             vals = shape_array.flatten().astype('int32')))
+        elif dtype == "int64":
+            shape_node = onnx.helper.make_node('Constant',
+                                               inputs=[],
+                                               outputs=['shape'],
+                                               value=onnx.helper.make_tensor(name = 'const_tensor',
+                                                                             data_type = onnx.TensorProto.INT64,
+                                                                             dims = shape_array.shape,
+                                                                             vals = shape_array.flatten().astype('int64')))
+        else:
+            raise "Invalid dtype"
         expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])

         graph = helper.make_graph([shape_node, expand_node],
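
The two make_node branches above differ only in the ONNX tensor dtype. A hedged sketch of an equivalent parameterized helper (illustrative only, not part of the commit); note that Python 3 would reject the bare `raise "Invalid dtype"` above with a TypeError, so the sketch raises an exception class instead:

    import onnx

    DTYPE_MAP = {"int32": onnx.TensorProto.INT32, "int64": onnx.TensorProto.INT64}

    def make_shape_node(shape_array, dtype):
        if dtype not in DTYPE_MAP:
            raise ValueError("Invalid dtype")
        return onnx.helper.make_node(
            'Constant', inputs=[], outputs=['shape'],
            value=onnx.helper.make_tensor(name='const_tensor',
                                          data_type=DTYPE_MAP[dtype],
                                          dims=shape_array.shape,
                                          vals=shape_array.flatten().astype(dtype)))
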
@@ -196,13 +207,15 @@ def _test_expand(name, data, shape, ref_data):
     shape = (3, 4)
     data = np.random.uniform(size=in_shape).astype(np.float32)
     ref_data = np.tile(data, 4)
-    _test_expand('expand_with_dim_unchanged_test', data, shape, ref_data)
+    _test_expand('expand_with_dim_unchanged_test', data, shape, ref_data, "int32")
+    _test_expand('expand_with_dim_unchanged_test', data, shape, ref_data, "int64")

     in_shape = (3, 1)
     shape = (2, 1, 6)
     data = np.random.uniform(size=in_shape).astype(np.float32)
     ref_data = data * np.ones(shape, dtype=np.float32)
-    _test_expand('expand_with_dim_changed_test', data, shape, ref_data)
+    _test_expand('expand_with_dim_changed_test', data, shape, ref_data, "int32")
+    _test_expand('expand_with_dim_changed_test', data, shape, ref_data, "int64")


 def verify_depth_to_space(inshape, outshape, mode, blockSize):
@@ -822,8 +835,8 @@ def verify_batch_matmul(a_shape, b_shape):
             model, [a_array, b_array], target, ctx)
         tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)

-# TODO(mbrookhart): enable once VM supports heterogenous execution
-# @tvm.testing.uses_gpu
+# TODO(mbrookhart): enable cuda once VM supports heterogenous execution
+@tvm.testing.parametrize_targets("llvm")
 def test_batch_matmul():
     verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4))
     verify_batch_matmul((2, 4, 3), (3, 4))
@@ -1024,11 +1037,9 @@ def _test_upsample_bilinear_opset9():
         graph, producer_name='upsample_bilinear_opset9_test')

     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, in_array, target, ctx, out_shape, 'float32')
+        tvm_out = get_tvm_output_with_vm(model, [in_array], target, ctx, opset=9, freeze_params=True)
         tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)

-
 def _test_upsample3d_trilinear():
     scale = 2
     in_shape = (1, 1, 3, 3, 3)
@@ -1062,7 +1073,8 @@ def _test_upsample3d_trilinear():
             model, in_array, target, ctx, out_shape, 'float32')
         tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)

-@tvm.testing.uses_gpu
+# TODO(mbrookhart): enable once VM supports heterogenous execution
+# @tvm.testing.uses_gpu
 def test_upsample():
     _test_upsample_nearest()
     _test_upsample_bilinear()
@@ -1455,7 +1467,7 @@ def verify_pad_v11(indata, pads, mode='constant', value=0.0):
                 outputs=[helper.make_tensor_value_info("output",
                                                        TensorProto.FLOAT, list(outdata.shape))])
     else:
-        inputs = [indata, pads, np.array([value])]
+        inputs = [indata, pads, np.array([value]).astype("float32")]
         outdata = np.pad(indata, pad_width=np_pads,
                          mode='constant', constant_values=value)
         node = helper.make_node(
@@ -1471,7 +1483,7 @@ def verify_pad_v11(indata, pads, mode='constant', value=0.0):
                 helper.make_tensor_value_info("pads",
                                               TensorProto.INT64,(len(pads),)),
                 helper.make_tensor_value_info("constant_value",
-                                              TensorProto.INT64,(1,)),
+                                              TensorProto.FLOAT,(1,)),
             ],
             initializer=[helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads),
                          helper.make_tensor("constant_value", TensorProto.FLOAT, (1,), [value])],
@@ -1480,12 +1492,12 @@ def verify_pad_v11(indata, pads, mode='constant', value=0.0):
     model = helper.make_model(graph, producer_name='pad_test')
     # tvm result
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(
-            model, inputs, target, ctx, outdata.shape, 'float32', opset=11)
+        tvm_out = get_tvm_output_with_vm(model, inputs, target, ctx, opset=11, freeze_params=False)
         tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)


-@tvm.testing.uses_gpu
+# TODO(mbrookhart): enable once VM supports heterogenous execution
+# @tvm.testing.uses_gpu
 def test_pad():
     verify_pad(np.random.randn(2, 2).astype(
         np.float32), [0, 1, 0, 0], 'constant', 0.0)

tests/python/relay/test_op_level10.py

Lines changed: 1 addition & 1 deletion
@@ -358,7 +358,7 @@ def verify_dynamic_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"):
     y_np = np.random.uniform(size=y_shape).astype(dtype)
     z_np = tvm.topi.testing.batch_matmul(x_np, y_np)

-    for target, ctx in ctx_list():
+    for target, ctx in tvm.testing.enabled_targets():
         for kind in ["vm", "debug"]:
             mod = tvm.ir.IRModule.from_expr(func)
             intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
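
For context, tvm.testing.enabled_targets() supersedes the older ctx_list() helper. A hedged sketch of the surrounding test pattern, assuming a relay function with a dynamic batch dimension (shapes and names illustrative):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(relay.Any(), 4, 3), dtype="float32")
    y = relay.var("y", shape=(relay.Any(), 5, 3), dtype="float32")
    func = relay.Function([x, y], relay.nn.batch_matmul(x, y))

    for target, ctx in tvm.testing.enabled_targets():
        mod = tvm.ir.IRModule.from_expr(func)
        intrp = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
        x_np = np.random.rand(2, 4, 3).astype("float32")
        y_np = np.random.rand(2, 5, 3).astype("float32")
        z = intrp.evaluate()(x_np, y_np).asnumpy()
        assert z.shape == (2, 4, 5)  # (batch, M, N)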
