Skip to content

Commit 3dc399f

Browse files
hugolatendresse and ShiboXing
authored and committed
[Relax] Add support to ingest Tensor.expand_as() (apache#17724)
Support to ingest Tensor.expand_as(), with unit test for correctness
1 parent 50f5b2d commit 3dc399f

4 files changed

Lines changed: 57 additions & 0 deletions

File tree

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,14 @@ def _expand(self, node: fx.Node) -> relax.Var:
883883
broadcast_shape.append(i)
884884
return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
885885

886+
def _expand_as(self, node: fx.Node) -> relax.Var:
887+
args = self.retrieve_args(node)
888+
# args[0] is the 'self' tensor
889+
# args[1] is the 'other' tensor
890+
data = args[0]
891+
other_shape = self.shape_of(args[1]) # the shape of 'other'
892+
return self.block_builder.emit(relax.op.broadcast_to(data, other_shape))
893+
886894
def _flip(self, node: fx.Node) -> relax.Var:
887895
x = self.env[node.args[0]]
888896
dims = node.args[1] if len(node.args) > 1 else node.kwargs.get("dims", None)

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ def create_convert_map(
298298
"copy_.default": self._copy_,
299299
"cumsum.default": self._cumsum,
300300
"expand.default": self._expand,
301+
"expand_as.default": self._expand_as,
301302
"permute.default": self._permute,
302303
"repeat.default": self._repeat,
303304
"select.int": self._select,

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,7 @@ def create_convert_map(
749749
"contiguous": lambda node: self.env[node.args[0]],
750750
"cumsum": self._cumsum,
751751
"expand": self._expand,
752+
"expand_as.default": self._expand_as,
752753
"flatten": self._flatten,
753754
"flip": self._flip,
754755
"gather": self._gather,

tests/python/relax/test_from_exported_to_cuda.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,53 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
5656
np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
5757

5858

59+
@tvm.testing.parametrize_targets("cuda")
60+
def test_tensor_expand_as(target, dev):
61+
class ExpandAs0(torch.nn.Module):
62+
def __init__(self):
63+
super().__init__()
64+
self.template = torch.ones((1, 1, 1, 1))
65+
66+
def forward(self, x):
67+
return self.template.expand_as(x)
68+
69+
class ExpandAs1(torch.nn.Module):
70+
def __init__(self):
71+
super().__init__()
72+
self.template = torch.ones((2, 1, 4, 1))
73+
74+
def forward(self, x):
75+
return self.template.expand_as(x)
76+
77+
class ExpandAs2(torch.nn.Module):
78+
def __init__(self):
79+
super().__init__()
80+
self.template = torch.ones((2, 1, 1, 10))
81+
82+
def forward(self, x):
83+
return self.template.expand_as(x)
84+
85+
class ExpandAs3(torch.nn.Module):
86+
def __init__(self):
87+
super().__init__()
88+
self.template = torch.ones((2, 3, 1, 1))
89+
90+
def forward(self, x):
91+
return self.template.expand_as(x)
92+
93+
raw_data = np.random.randn(2, 3, 4, 10).astype(np.float32)
94+
95+
torch_module0 = ExpandAs0().eval()
96+
torch_module1 = ExpandAs1().eval()
97+
torch_module2 = ExpandAs2().eval()
98+
torch_module3 = ExpandAs3().eval()
99+
100+
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
101+
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1, target, dev)
102+
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev)
103+
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)
104+
105+
59106
@tvm.testing.parametrize_targets("cuda")
60107
def test_copy_(target, dev):
61108
class CopyTester(nn.Module):

0 commit comments

Comments
 (0)