Skip to content

Commit 432ccfa

Browse files
authored
[Relax][PyTorch] Add support for and_, lshift, min, or_, rshift, xor ops (#17668)
* Update fx_translator.py * Update test_frontend_from_fx.py * Update test_frontend_from_fx.py
1 parent 25be3e2 commit 432ccfa

2 files changed

Lines changed: 234 additions & 0 deletions

File tree

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,23 +660,29 @@ def create_convert_map(
660660
"triu": self._tril_triu(relax.op.triu),
661661
# binary
662662
"add": self._binary_op(relax.op.add, operator.add),
663+
"and_": self._binary_op(relax.op.bitwise_and, operator.and_),
663664
"eq": self._binary_op(relax.op.equal, operator.eq),
664665
"floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv),
665666
"ge": self._binary_op(relax.op.greater_equal, operator.ge),
666667
"gt": self._binary_op(relax.op.greater, operator.gt),
667668
"iadd": self._binary_op(relax.op.add, operator.add),
668669
"le": self._binary_op(relax.op.less_equal, operator.le),
670+
"lshift": self._binary_op(relax.op.left_shift, operator.lshift),
669671
"lt": self._binary_op(relax.op.less, operator.lt),
670672
"matmul": self._binary_op(
671673
partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
672674
),
673675
"max": self._binary_op(relax.op.maximum, max),
676+
"min": self._binary_op(relax.op.minimum, min),
674677
"mod": self._binary_op(relax.op.mod, operator.mod),
675678
"mul": self._binary_op(relax.op.multiply, operator.mul),
676679
"ne": self._binary_op(relax.op.not_equal, operator.ne),
677680
"pow": self._binary_op(relax.op.power, operator.pow),
681+
"or_": self._binary_op(relax.op.bitwise_or, operator.or_),
682+
"rshift": self._binary_op(relax.op.right_shift, operator.rshift),
678683
"sub": self._binary_op(relax.op.subtract, operator.sub),
679684
"truediv": self._binary_op(relax.op.divide, operator.truediv),
685+
"xor": self._binary_op(relax.op.bitwise_xor, operator.xor),
680686
# neural network
681687
"adaptive_avg_pool2d": self._adaptive_avg_pool2d,
682688
"addmm": self._addmm,

tests/python/relax/test_frontend_from_fx.py

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1485,6 +1485,8 @@ def main(
14851485
def test_binary():
14861486
input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
14871487
input_info2 = [([1, 3, 10, 10], "float32")]
1488+
input_info3 = [([1, 3, 10, 10], "int32"), ([1, 3, 10, 10], "int32")]
1489+
input_info4 = [([1, 3, 10, 10], "int32")]
14881490

14891491
# Add
14901492
class Add1(Module):
@@ -1962,6 +1964,211 @@ def main(
19621964
verify_model(Ne1(), input_info1, {}, expected23)
19631965
verify_model(Ne2(), input_info2, {}, expected24)
19641966

1967+
# Lshift
1968+
class LShift1(Module):
1969+
def forward(self, lhs, rhs):
1970+
return lhs << rhs
1971+
1972+
@tvm.script.ir_module
1973+
class expected25:
1974+
@R.function
1975+
def main(
1976+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
1977+
rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
1978+
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
1979+
# block 0
1980+
with R.dataflow():
1981+
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.left_shift(lhs_1, rhs_1)
1982+
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
1983+
R.output(gv)
1984+
1985+
return gv
1986+
1987+
class LShift2(Module):
1988+
def forward(self, lhs):
1989+
return lhs << 1
1990+
1991+
@tvm.script.ir_module
1992+
class expected26:
1993+
@R.function
1994+
def main(
1995+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
1996+
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
1997+
# block 0
1998+
with R.dataflow():
1999+
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.left_shift(lhs_1, R.const(1))
2000+
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
2001+
R.output(gv)
2002+
2003+
return gv
2004+
2005+
verify_model(LShift1(), input_info3, {}, expected25)
2006+
verify_model(LShift2(), input_info4, {}, expected26)
2007+
2008+
# Rshift
2009+
class RShift1(Module):
2010+
def forward(self, lhs, rhs):
2011+
return lhs >> rhs
2012+
2013+
@tvm.script.ir_module
2014+
class expected27:
2015+
@R.function
2016+
def main(
2017+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
2018+
rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
2019+
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
2020+
# block 0
2021+
with R.dataflow():
2022+
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.right_shift(lhs_1, rhs_1)
2023+
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
2024+
R.output(gv)
2025+
2026+
return gv
2027+
2028+
class RShift2(Module):
2029+
def forward(self, lhs):
2030+
return lhs >> 1
2031+
2032+
@tvm.script.ir_module
2033+
class expected28:
2034+
@R.function
2035+
def main(
2036+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
2037+
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
2038+
# block 0
2039+
with R.dataflow():
2040+
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.right_shift(lhs_1, R.const(1))
2041+
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
2042+
R.output(gv)
2043+
2044+
return gv
2045+
2046+
verify_model(RShift1(), input_info3, {}, expected27)
2047+
verify_model(RShift2(), input_info4, {}, expected28)
2048+
2049+
# Bitwise and
2050+
class BitwiseAnd1(Module):
2051+
def forward(self, lhs, rhs):
2052+
return lhs & rhs
2053+
2054+
@tvm.script.ir_module
2055+
class expected29:
2056+
@R.function
2057+
def main(
2058+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
2059+
rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
2060+
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
2061+
# block 0
2062+
with R.dataflow():
2063+
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_and(lhs_1, rhs_1)
2064+
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
2065+
R.output(gv)
2066+
2067+
return gv
2068+
2069+
class BitwiseAnd2(Module):
2070+
def forward(self, lhs):
2071+
return lhs & 1
2072+
2073+
@tvm.script.ir_module
2074+
class expected30:
2075+
@R.function
2076+
def main(
2077+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
2078+
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
2079+
# block 0
2080+
with R.dataflow():
2081+
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_and(lhs_1, R.const(1))
2082+
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
2083+
R.output(gv)
2084+
2085+
return gv
2086+
2087+
verify_model(BitwiseAnd1(), input_info3, {}, expected29)
2088+
verify_model(BitwiseAnd2(), input_info4, {}, expected30)
2089+
2090+
# Bitwise or
2091+
class BitwiseOr1(Module):
2092+
def forward(self, lhs, rhs):
2093+
return lhs | rhs
2094+
2095+
@tvm.script.ir_module
2096+
class expected31:
2097+
@R.function
2098+
def main(
2099+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
2100+
rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
2101+
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
2102+
# block 0
2103+
with R.dataflow():
2104+
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_or(lhs_1, rhs_1)
2105+
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
2106+
R.output(gv)
2107+
2108+
return gv
2109+
2110+
class BitwiseOr2(Module):
2111+
def forward(self, lhs):
2112+
return lhs | 1
2113+
2114+
@tvm.script.ir_module
2115+
class expected32:
2116+
@R.function
2117+
def main(
2118+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
2119+
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
2120+
# block 0
2121+
with R.dataflow():
2122+
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_or(lhs_1, R.const(1))
2123+
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
2124+
R.output(gv)
2125+
2126+
return gv
2127+
2128+
verify_model(BitwiseOr1(), input_info3, {}, expected31)
2129+
verify_model(BitwiseOr2(), input_info4, {}, expected32)
2130+
2131+
# Bitwise xor
2132+
class BitwiseXor1(Module):
2133+
def forward(self, lhs, rhs):
2134+
return lhs ^ rhs
2135+
2136+
@tvm.script.ir_module
2137+
class expected33:
2138+
@R.function
2139+
def main(
2140+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
2141+
rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
2142+
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
2143+
# block 0
2144+
with R.dataflow():
2145+
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_xor(lhs_1, rhs_1)
2146+
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
2147+
R.output(gv)
2148+
2149+
return gv
2150+
2151+
class BitwiseXor2(Module):
2152+
def forward(self, lhs):
2153+
return lhs ^ 1
2154+
2155+
@tvm.script.ir_module
2156+
class expected34:
2157+
@R.function
2158+
def main(
2159+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
2160+
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
2161+
# block 0
2162+
with R.dataflow():
2163+
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_xor(lhs_1, R.const(1))
2164+
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
2165+
R.output(gv)
2166+
2167+
return gv
2168+
2169+
verify_model(BitwiseXor1(), input_info3, {}, expected33)
2170+
verify_model(BitwiseXor2(), input_info4, {}, expected34)
2171+
19652172

19662173
def test_size():
19672174
input_info = [([1, 3, 10, 10], "float32")]
@@ -3745,6 +3952,27 @@ def main(
37453952
verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1)
37463953

37473954

3955+
def test_min():
3956+
class Min(Module):
3957+
def forward(self, x, y):
3958+
return torch.min(x, y)
3959+
3960+
@I.ir_module
3961+
class Expected1:
3962+
@R.function
3963+
def main(
3964+
inp_0: R.Tensor((256, 256), dtype="float32"),
3965+
inp_1: R.Tensor((256, 256), dtype="float32"),
3966+
) -> R.Tensor((256, 256), dtype="float32"):
3967+
with R.dataflow():
3968+
lv: R.Tensor((256, 256), dtype="float32") = R.minimum(inp_0, inp_1)
3969+
gv: R.Tensor((256, 256), dtype="float32") = lv
3970+
R.output(gv)
3971+
return gv
3972+
3973+
verify_model(Min(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1)
3974+
3975+
37483976
def test_attention():
37493977
@I.ir_module
37503978
class Expected1:

0 commit comments

Comments
 (0)