Skip to content

Commit e7bcf17

Browse files
authored
[Relax][PyTorch] Support MatrixMultiply op for ExportedProgram importer (#18343)
This pr supports `mm.default` for ExportedProgram importer. Resolves the issue #18339.
1 parent a21e0df commit e7bcf17

2 files changed

Lines changed: 29 additions & 0 deletions

File tree

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,9 @@ def create_convert_map(
434434
"matmul.default": self._binary_op(
435435
partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
436436
),
437+
"mm.default": self._binary_op(
438+
partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
439+
),
437440
"max.other": self._binary_op(relax.op.maximum, max),
438441
"min.other": self._binary_op(relax.op.minimum, min),
439442
"max.default": self._unary_op(relax.op.max),

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5914,6 +5914,32 @@ def main(
59145914
verify_model(Model(), example_args, {}, Expected)
59155915

59165916

5917+
def test_mm():
5918+
class MatrixMultiply(Module):
5919+
def forward(self, a, b):
5920+
return torch.mm(a, b)
5921+
5922+
example_args = (
5923+
torch.randn(2, 3, dtype=torch.float32),
5924+
torch.randn(3, 4, dtype=torch.float32),
5925+
)
5926+
5927+
@tvm.script.ir_module
5928+
class Expected:
5929+
@R.function
5930+
def main(
5931+
a: R.Tensor((2, 3), dtype="float32"),
5932+
b: R.Tensor((3, 4), dtype="float32"),
5933+
) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
5934+
with R.dataflow():
5935+
lv: R.Tensor((2, 4), dtype="float32") = R.matmul(a, b, out_dtype="float32")
5936+
gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (lv,)
5937+
R.output(gv)
5938+
return gv
5939+
5940+
verify_model(MatrixMultiply(), example_args, {}, Expected)
5941+
5942+
59175943
if __name__ == "__main__":
59185944
tvm.testing.main()
59195945
1

0 commit comments

Comments
 (0)