Skip to content

Commit 2d67deb

Browse files
ConvolutedDogShiboXing
authored andcommitted
[Fix][ONNX] Fix CumSum conversion when loading ONNX model (apache#18137)
* Fix onnx cumsum * Fix onnx cumsum
1 parent 2c015b9 commit 2d67deb

2 files changed

Lines changed: 8 additions & 1 deletion

File tree

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,9 +1338,16 @@ def _impl_v14(cls, bb, inputs, attr, params):
13381338
axis = int(axis.data.numpy())
13391339
elif isinstance(axis, relax.Var):
13401340
axis = 0
1341+
1342+
if attr.get("reverse", 0) != 0:
1343+
data = bb.emit_te(topi.flip, data, axis=axis if axis else 0)
1344+
13411345
data = relax.op.cumsum(data, axis)
1346+
data = bb.normalize(data)
1347+
13421348
if attr.get("reverse", 0) != 0:
13431349
data = bb.emit_te(topi.flip, data, axis=axis if axis else 0)
1350+
13441351
return data
13451352

13461353

tests/python/relax/test_frontend_onnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1099,7 +1099,7 @@ def test_pow():
10991099
verify_binary("Pow", [32, 32], [32, 32], [32, 32])
11001100

11011101

1102-
@pytest.mark.parametrize("reverse", [False])
1102+
@pytest.mark.parametrize("reverse", [True, False])
11031103
@pytest.mark.parametrize("exclusive", [False])
11041104
def test_cumsum(reverse, exclusive):
11051105
cumsum_node = helper.make_node(

0 commit comments

Comments
 (0)