@@ -1485,6 +1485,8 @@ def main(
14851485def test_binary ():
14861486 input_info1 = [([1 , 3 , 10 , 10 ], "float32" ), ([1 , 3 , 10 , 10 ], "float32" )]
14871487 input_info2 = [([1 , 3 , 10 , 10 ], "float32" )]
1488+ input_info3 = [([1 , 3 , 10 , 10 ], "int32" ), ([1 , 3 , 10 , 10 ], "int32" )]
1489+ input_info4 = [([1 , 3 , 10 , 10 ], "int32" )]
14881490
14891491 # Add
14901492 class Add1 (Module ):
@@ -1962,6 +1964,211 @@ def main(
19621964 verify_model (Ne1 (), input_info1 , {}, expected23 )
19631965 verify_model (Ne2 (), input_info2 , {}, expected24 )
19641966
1967+ # Lshift
1968+ class LShift1 (Module ):
1969+ def forward (self , lhs , rhs ):
1970+ return lhs << rhs
1971+
1972+ @tvm .script .ir_module
1973+ class expected25 :
1974+ @R .function
1975+ def main (
1976+ lhs_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ),
1977+ rhs_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ),
1978+ ) -> R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ):
1979+ # block 0
1980+ with R .dataflow ():
1981+ lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = R .left_shift (lhs_1 , rhs_1 )
1982+ gv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = lv
1983+ R .output (gv )
1984+
1985+ return gv
1986+
1987+ class LShift2 (Module ):
1988+ def forward (self , lhs ):
1989+ return lhs << 1
1990+
1991+ @tvm .script .ir_module
1992+ class expected26 :
1993+ @R .function
1994+ def main (
1995+ lhs_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ),
1996+ ) -> R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ):
1997+ # block 0
1998+ with R .dataflow ():
1999+ lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = R .left_shift (lhs_1 , R .const (1 ))
2000+ gv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = lv
2001+ R .output (gv )
2002+
2003+ return gv
2004+
2005+ verify_model (LShift1 (), input_info3 , {}, expected25 )
2006+ verify_model (LShift2 (), input_info4 , {}, expected26 )
2007+
2008+ # Rshift
2009+ class RShift1 (Module ):
2010+ def forward (self , lhs , rhs ):
2011+ return lhs >> rhs
2012+
2013+ @tvm .script .ir_module
2014+ class expected27 :
2015+ @R .function
2016+ def main (
2017+ lhs_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ),
2018+ rhs_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ),
2019+ ) -> R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ):
2020+ # block 0
2021+ with R .dataflow ():
2022+ lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = R .right_shift (lhs_1 , rhs_1 )
2023+ gv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = lv
2024+ R .output (gv )
2025+
2026+ return gv
2027+
2028+ class RShift2 (Module ):
2029+ def forward (self , lhs ):
2030+ return lhs >> 1
2031+
2032+ @tvm .script .ir_module
2033+ class expected28 :
2034+ @R .function
2035+ def main (
2036+ lhs_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ),
2037+ ) -> R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ):
2038+ # block 0
2039+ with R .dataflow ():
2040+ lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = R .right_shift (lhs_1 , R .const (1 ))
2041+ gv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = lv
2042+ R .output (gv )
2043+
2044+ return gv
2045+
2046+ verify_model (RShift1 (), input_info3 , {}, expected27 )
2047+ verify_model (RShift2 (), input_info4 , {}, expected28 )
2048+
2049+ # Bitwise and
2050+ class BitwiseAnd1 (Module ):
2051+ def forward (self , lhs , rhs ):
2052+ return lhs & rhs
2053+
2054+ @tvm .script .ir_module
2055+ class expected29 :
2056+ @R .function
2057+ def main (
2058+ lhs_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ),
2059+ rhs_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ),
2060+ ) -> R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ):
2061+ # block 0
2062+ with R .dataflow ():
2063+ lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = R .bitwise_and (lhs_1 , rhs_1 )
2064+ gv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = lv
2065+ R .output (gv )
2066+
2067+ return gv
2068+
2069+ class BitwiseAnd2 (Module ):
2070+ def forward (self , lhs ):
2071+ return lhs & 1
2072+
2073+ @tvm .script .ir_module
2074+ class expected30 :
2075+ @R .function
2076+ def main (
2077+ lhs_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ),
2078+ ) -> R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ):
2079+ # block 0
2080+ with R .dataflow ():
2081+ lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = R .bitwise_and (lhs_1 , R .const (1 ))
2082+ gv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = lv
2083+ R .output (gv )
2084+
2085+ return gv
2086+
2087+ verify_model (BitwiseAnd1 (), input_info3 , {}, expected29 )
2088+ verify_model (BitwiseAnd2 (), input_info4 , {}, expected30 )
2089+
2090+ # Bitwise or
2091+ class BitwiseOr1 (Module ):
2092+ def forward (self , lhs , rhs ):
2093+ return lhs | rhs
2094+
2095+ @tvm .script .ir_module
2096+ class expected31 :
2097+ @R .function
2098+ def main (
2099+ lhs_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ),
2100+ rhs_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ),
2101+ ) -> R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ):
2102+ # block 0
2103+ with R .dataflow ():
2104+ lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = R .bitwise_or (lhs_1 , rhs_1 )
2105+ gv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = lv
2106+ R .output (gv )
2107+
2108+ return gv
2109+
2110+ class BitwiseOr2 (Module ):
2111+ def forward (self , lhs ):
2112+ return lhs | 1
2113+
2114+ @tvm .script .ir_module
2115+ class expected32 :
2116+ @R .function
2117+ def main (
2118+ lhs_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ),
2119+ ) -> R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ):
2120+ # block 0
2121+ with R .dataflow ():
2122+ lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = R .bitwise_or (lhs_1 , R .const (1 ))
2123+ gv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = lv
2124+ R .output (gv )
2125+
2126+ return gv
2127+
2128+ verify_model (BitwiseOr1 (), input_info3 , {}, expected31 )
2129+ verify_model (BitwiseOr2 (), input_info4 , {}, expected32 )
2130+
2131+ # Bitwise xor
2132+ class BitwiseXor1 (Module ):
2133+ def forward (self , lhs , rhs ):
2134+ return lhs ^ rhs
2135+
2136+ @tvm .script .ir_module
2137+ class expected33 :
2138+ @R .function
2139+ def main (
2140+ lhs_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ),
2141+ rhs_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ),
2142+ ) -> R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ):
2143+ # block 0
2144+ with R .dataflow ():
2145+ lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = R .bitwise_xor (lhs_1 , rhs_1 )
2146+ gv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = lv
2147+ R .output (gv )
2148+
2149+ return gv
2150+
2151+ class BitwiseXor2 (Module ):
2152+ def forward (self , lhs ):
2153+ return lhs ^ 1
2154+
2155+ @tvm .script .ir_module
2156+ class expected34 :
2157+ @R .function
2158+ def main (
2159+ lhs_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ),
2160+ ) -> R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ):
2161+ # block 0
2162+ with R .dataflow ():
2163+ lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = R .bitwise_xor (lhs_1 , R .const (1 ))
2164+ gv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "int32" ) = lv
2165+ R .output (gv )
2166+
2167+ return gv
2168+
2169+ verify_model (BitwiseXor1 (), input_info3 , {}, expected33 )
2170+ verify_model (BitwiseXor2 (), input_info4 , {}, expected34 )
2171+
19652172
19662173def test_size ():
19672174 input_info = [([1 , 3 , 10 , 10 ], "float32" )]
@@ -3745,6 +3952,27 @@ def main(
37453952 verify_model (Max (), [([256 , 256 ], "float32" ), ([256 , 256 ], "float32" )], {}, Expected1 )
37463953
37473954
3955+ def test_min ():
3956+ class Min (Module ):
3957+ def forward (self , x , y ):
3958+ return torch .min (x , y )
3959+
3960+ @I .ir_module
3961+ class Expected1 :
3962+ @R .function
3963+ def main (
3964+ inp_0 : R .Tensor ((256 , 256 ), dtype = "float32" ),
3965+ inp_1 : R .Tensor ((256 , 256 ), dtype = "float32" ),
3966+ ) -> R .Tensor ((256 , 256 ), dtype = "float32" ):
3967+ with R .dataflow ():
3968+ lv : R .Tensor ((256 , 256 ), dtype = "float32" ) = R .minimum (inp_0 , inp_1 )
3969+ gv : R .Tensor ((256 , 256 ), dtype = "float32" ) = lv
3970+ R .output (gv )
3971+ return gv
3972+
3973+ verify_model (Min (), [([256 , 256 ], "float32" ), ([256 , 256 ], "float32" )], {}, Expected1 )
3974+
3975+
37483976def test_attention ():
37493977 @I .ir_module
37503978 class Expected1 :
0 commit comments