Skip to content

Commit dc043fe

Browse files
authored
[Relax][PyTorch] Support one_hot, empty_like ops for ExportedProgram importer (#17751)
* Update exported_program_translator.py * Update test_frontend_from_exported_program.py * Update test_frontend_from_exported_program.py * Update exported_program_translator.py
1 parent f6236ce commit dc043fe

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,23 @@ def _slice(self, node: fx.Node) -> relax.Var:
174174
stride = [node.args[4] if len(node.args) > 4 else 1]
175175
return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
176176

177+
########## Creation ##########
178+
179+
def _one_hot(self, node: fx.Node) -> relax.Var:
180+
x = self.env[node.args[0]]
181+
num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes")
182+
if num_classes is None:
183+
raise ValueError("num_classes not found in node.args or node.kwargs")
184+
185+
on_value = node.args[2] if len(node.args) > 2 else node.kwargs.get("on_value", 1)
186+
off_value = node.args[3] if len(node.args) > 3 else node.kwargs.get("off_value", 0)
187+
axis = node.args[4] if len(node.args) > 4 else node.kwargs.get("axis", -1)
188+
189+
on_value = relax.PrimValue(on_value)
190+
off_value = relax.PrimValue(off_value)
191+
192+
return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis))
193+
177194
########## Others ##########
178195

179196
def create_convert_map(
@@ -331,8 +348,10 @@ def create_convert_map(
331348
"contiguous.default": lambda node: self.env[node.args[0]], # no-op
332349
"clone.default": lambda node: self.env[node.args[0]],
333350
"empty.memory_format": self._empty,
351+
"empty_like.default": self._empty_like,
334352
"fill.Scalar": self._fill,
335353
"new_ones.default": self._new_ones,
354+
"one_hot.default": self._one_hot,
336355
# other
337356
"getitem": self._getitem,
338357
}

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3425,6 +3425,74 @@ def main(
34253425
tvm.ir.assert_structural_equal(mod, Expected)
34263426

34273427

3428+
def test_empty_like():
3429+
class EmptyLike(Module):
3430+
def forward(self, data):
3431+
return torch.empty_like(data)
3432+
3433+
@tvm.script.ir_module
3434+
class Expected:
3435+
@R.function
3436+
def main(
3437+
inp_0: R.Tensor((5,), dtype="float32"),
3438+
) -> R.Tuple(R.Tensor((5,), dtype="float32")):
3439+
with R.dataflow():
3440+
lv: R.Tensor((5,), dtype="float32") = R.zeros_like(inp_0, dtype="void")
3441+
gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,)
3442+
R.output(gv)
3443+
return gv
3444+
3445+
example_args = (torch.randn(5, dtype=torch.float32),)
3446+
3447+
verify_model(EmptyLike(), example_args, {}, Expected)
3448+
3449+
3450+
def test_one_hot():
3451+
class OneHot(Module):
3452+
def forward(self, indices):
3453+
return torch.nn.functional.one_hot(indices, num_classes=10)
3454+
3455+
@tvm.script.ir_module
3456+
class Expected:
3457+
@R.function
3458+
def main(
3459+
inp_0: R.Tensor((5,), dtype="int64"),
3460+
) -> R.Tuple(R.Tensor((5, 10), dtype="int64")):
3461+
with R.dataflow():
3462+
lv: R.Tensor((5, 10), dtype="int64") = R.one_hot(
3463+
inp_0, R.prim_value(1), R.prim_value(0), depth=10, axis=-1
3464+
)
3465+
gv: R.Tuple(R.Tensor((5, 10), dtype="int64")) = (lv,)
3466+
R.output(gv)
3467+
return gv
3468+
3469+
example_args = (torch.randint(0, 10, (5,), dtype=torch.int64),)
3470+
3471+
verify_model(OneHot(), example_args, {}, Expected)
3472+
3473+
3474+
def test_select():
3475+
class Select(Module):
3476+
def forward(self, input):
3477+
return torch.select(input, 0, 1)
3478+
3479+
@tvm.script.ir_module
3480+
class Expected:
3481+
@R.function
3482+
def main(
3483+
inp_0: R.Tensor((2, 3), dtype="float32"),
3484+
) -> R.Tuple(R.Tensor((3,), dtype="float32")):
3485+
with R.dataflow():
3486+
lv: R.Tensor((3,), dtype="float32") = R.take(inp_0, R.const(1, "int64"), axis=0)
3487+
gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv,)
3488+
R.output(gv)
3489+
return gv
3490+
3491+
example_args = (torch.randn(2, 3, dtype=torch.float32),)
3492+
3493+
verify_model(Select(), example_args, {}, Expected)
3494+
3495+
34283496
def test_gather():
34293497
class Gather0(Module):
34303498
def forward(self, data, indices):

0 commit comments

Comments
 (0)