@@ -1602,6 +1602,206 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
         np.testing.assert_allclose(res.numpy(), res1.numpy())
 
 
+def test_alter_layout_blocked_no_broadcast():
+    """Test broadcast operators working on an already blocked layout."""
+
+    def before():
+        dtype = "float32"
+        input_shape = (1, 8, 16, 16, 4)
+        filter_shape = (1, 8, 4, 4, 4, 4)
+        bias_shape = (1, 1, 1, 1, 4)
+        A = relay.var("data", shape=input_shape, dtype=dtype)
+        B = relay.var("weight", shape=filter_shape, dtype=dtype)
+        C = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+        conv = relay.nn.conv2d(
+            A,
+            B,
+            data_layout="NCHW4c",
+            kernel_layout="OIHW4i4o",
+            padding=[3, 3, 0, 0],
+            strides=[2, 2],
+            out_dtype=dtype,
+            channels=4,
+            kernel_size=(4, 4),
+        )
+        bias = relay.op.add(conv, C)
+        bias = relay.Function(analysis.free_vars(bias), bias)
+        return bias
+
+    def expected():
+        return before()
+
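+    # The alter callback re-asserts the layouts the graph already uses
+    # (NCHW4c / OIHW4i4o), so AlterOpLayout should leave the graph
+    # untouched and the expected graph is just before().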
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs["data_layout"] = "NCHW4c"
+        new_attrs["kernel_layout"] = "OIHW4i4o"
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        a = run_opt_pass(before(), transform.AlterOpLayout())
+        b = run_opt_pass(expected(), transform.InferType())
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
+
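+    # Cross-check numerics: run the graph executor at opt_level=4 (where
+    # AlterOpLayout is applied) against the debug executor at opt_level=0
+    # on the same random inputs; the results must match.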
+    inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
+    weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
+    z = np.random.uniform(size=(1, 1, 1, 1, 4)).astype(np.float32)
+    mod = tvm.IRModule.from_expr(before())
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        with tvm.transform.PassContext(opt_level=4):
+            res = relay.build_module.create_executor(
+                "graph", mod, target="llvm", device=tvm.cpu()
+            ).evaluate()(inp, weight, z)
+        with tvm.transform.PassContext(opt_level=0):
+            res1 = relay.build_module.create_executor(
+                "debug", mod, target="llvm", device=tvm.cpu()
+            ).evaluate()(inp, weight, z)
+        np.testing.assert_allclose(res.numpy(), res1.numpy())
+
+
+def test_alter_layout_blocked_broadcast():
+    """Test broadcast operators broadcasting a scalar-shaped bias onto an already blocked layout."""
+
+    def before():
+        dtype = "float32"
+        input_shape = (1, 8, 16, 16, 4)
+        filter_shape = (1, 8, 4, 4, 4, 4)
+        bias_shape = (1, 1, 1, 1, 1)
+        A = relay.var("data", shape=input_shape, dtype=dtype)
+        B = relay.var("weight", shape=filter_shape, dtype=dtype)
+        C = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+        conv = relay.nn.conv2d(
+            A,
+            B,
+            data_layout="NCHW4c",
+            kernel_layout="OIHW4i4o",
+            padding=[3, 3, 0, 0],
+            strides=[2, 2],
+            out_dtype=dtype,
+            channels=4,
+            kernel_size=(4, 4),
+        )
+        bias = relay.op.add(conv, C)
+        bias = relay.Function(analysis.free_vars(bias), bias)
+        return bias
+
+    def expected():
+        return before()
+
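+    # The callback again requests the layouts already in use, so the pass
+    # must not rewrite the graph even though the add broadcasts a
+    # (1, 1, 1, 1, 1) bias against the blocked conv output.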
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs["data_layout"] = "NCHW4c"
+        new_attrs["kernel_layout"] = "OIHW4i4o"
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        a = run_opt_pass(before(), transform.AlterOpLayout())
+        b = run_opt_pass(expected(), transform.InferType())
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
+
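+    # Compare the opt_level=4 graph-executor result (AlterOpLayout applied)
+    # against the opt_level=0 debug-executor reference.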
+    inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
+    weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
+    z = np.random.uniform(size=(1, 1, 1, 1, 1)).astype(np.float32)
+    mod = tvm.IRModule.from_expr(before())
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        with tvm.transform.PassContext(opt_level=4):
+            res = relay.build_module.create_executor(
+                "graph", mod, target="llvm", device=tvm.cpu()
+            ).evaluate()(inp, weight, z)
+        with tvm.transform.PassContext(opt_level=0):
+            res1 = relay.build_module.create_executor(
+                "debug", mod, target="llvm", device=tvm.cpu()
+            ).evaluate()(inp, weight, z)
+        np.testing.assert_allclose(res.numpy(), res1.numpy())
+
+
+def test_alter_layout_re_blocking_broadcast():
+    """Test re-blocking of shapes with broadcast operators."""
+
+    def before():
+        dtype = "float32"
+        input_shape = (1, 8, 16, 16, 4)
+        filter_shape = (1, 8, 4, 4, 4, 4)
+        bias_shape = (1, 1, 1, 1, 4)
+        A = relay.var("data", shape=input_shape, dtype=dtype)
+        B = relay.var("weight", shape=filter_shape, dtype=dtype)
+        C = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+        conv = relay.nn.conv2d(
+            A,
+            B,
+            data_layout="NCHW4c",
+            kernel_layout="OIHW4i4o",
+            padding=[3, 3, 0, 0],
+            strides=[2, 2],
+            out_dtype=dtype,
+            channels=4,
+            kernel_size=(4, 4),
+        )
+        bias = relay.op.add(conv, C)
+        bias = relay.Function(analysis.free_vars(bias), bias)
+        return bias
+
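+    # Here alter_conv2d (below) switches the conv to NCHW2c / OIHW2i2o, so
+    # the expected graph re-blocks data, weight, and bias into the 2c
+    # layouts and transforms the result back to NCHW4c after the add.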
+    def expected():
+        dtype = "float32"
+        input_shape = (1, 8, 16, 16, 4)
+        filter_shape = (1, 8, 4, 4, 4, 4)
+        bias_shape = (1, 1, 1, 1, 4)
+        A = relay.var("data", shape=input_shape, dtype=dtype)
+        B = relay.var("weight", shape=filter_shape, dtype=dtype)
+        C = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+        A = relay.layout_transform(A, src_layout="NCHW4c", dst_layout="NCHW2c")
+        B = relay.layout_transform(B, src_layout="OIHW4i4o", dst_layout="OIHW2i2o")
+
+        conv = relay.nn.conv2d(
+            A,
+            B,
+            data_layout="NCHW2c",
+            kernel_layout="OIHW2i2o",
+            padding=[3, 3, 0, 0],
+            strides=[2, 2],
+            out_dtype=dtype,
+            channels=4,
+            kernel_size=(4, 4),
+        )
+        C = relay.layout_transform(C, src_layout="NCHW4c", dst_layout="NCHW2c")
+        bias = relay.op.add(conv, C)
+        bias = relay.layout_transform(bias, src_layout="NCHW2c", dst_layout="NCHW4c")
+        bias = relay.Function(analysis.free_vars(bias), bias)
+        return bias
+
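+    # Request a different blocking factor than the graph was written with,
+    # forcing layout_transform nodes to be inserted.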
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs["data_layout"] = "NCHW2c"
+        new_attrs["kernel_layout"] = "OIHW2i2o"
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        a = run_opt_pass(before(), transform.AlterOpLayout())
+        b = run_opt_pass(expected(), transform.InferType())
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
+
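+    # Allow a small floating-point tolerance here: the re-blocked layout
+    # can change the order of accumulation relative to the reference.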
+    inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
+    weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
+    z = np.random.uniform(size=(1, 1, 1, 1, 4)).astype(np.float32)
+    mod = tvm.IRModule.from_expr(before())
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        with tvm.transform.PassContext(opt_level=4):
+            res = relay.build_module.create_executor(
+                "graph", mod, target="llvm", device=tvm.cpu()
+            ).evaluate()(inp, weight, z)
+        with tvm.transform.PassContext(opt_level=0):
+            res1 = relay.build_module.create_executor(
+                "debug", mod, target="llvm", device=tvm.cpu()
+            ).evaluate()(inp, weight, z)
+        np.testing.assert_allclose(res.numpy(), res1.numpy(), rtol=1e-5, atol=1e-5)
+
+
 def test_broadcast_non_adaptable():
     """NCHW4c + [x, x, 4] and NCHW4c is being altered to NCHW"""
 