@@ -1602,6 +1602,152 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
     np.testing.assert_allclose(res.numpy(), res1.numpy())


+def test_alter_layout_blocked_broadcast():
+    """Test broadcast operators working on an already blocked layout."""
+
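+    # The conv2d below is already expressed in blocked layouts (NCHW4c data,
+    # OIHW4i4o kernel), and the bias has the matching blocked shape
+    # (1, 1, 1, 1, 4), so it broadcasts directly against the NCHW4c output.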
+    def before():
+        dtype = "float32"
+        input_shape = (1, 8, 16, 16, 4)
+        filter_shape = (1, 8, 4, 4, 4, 4)
+        bias_shape = (1, 1, 1, 1, 4)
+        A = relay.var("data", shape=input_shape, dtype=dtype)
+        B = relay.var("weight", shape=filter_shape, dtype=dtype)
+        C = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+        conv = relay.nn.conv2d(
+            A,
+            B,
+            data_layout="NCHW4c",
+            kernel_layout="OIHW4i4o",
+            padding=[3, 3, 0, 0],
+            strides=[2, 2],
+            out_dtype=dtype,
+            channels=4,
+            kernel_size=(4, 4),
+        )
+        bias = relay.op.add(conv, C)
+        bias = relay.Function(analysis.free_vars(bias), bias)
+        return bias
+
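+    # Since the alter callback keeps the layouts unchanged, AlterOpLayout
+    # should be a no-op and the expected graph is simply the input graph.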
+    def expected():
+        return before()
+
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs["data_layout"] = "NCHW4c"
+        new_attrs["kernel_layout"] = "OIHW4i4o"
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        a = run_opt_pass(before(), transform.AlterOpLayout())
+        b = run_opt_pass(expected(), transform.InferType())
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
+
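+    # Cross-check numerics: the graph executor at opt_level=4 (which runs
+    # AlterOpLayout) must match a debug-executor run at opt_level=0.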
+    inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
+    weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
+    z = np.random.uniform(size=(1, 1, 1, 1, 4)).astype(np.float32)
+    mod = tvm.IRModule.from_expr(before())
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        with tvm.transform.PassContext(opt_level=4):
+            res = relay.build_module.create_executor(
+                "graph", mod, target="llvm", device=tvm.cpu()
+            ).evaluate()(inp, weight, z)
+        with tvm.transform.PassContext(opt_level=0):
+            res1 = relay.build_module.create_executor(
+                "debug", mod, target="llvm", device=tvm.cpu()
+            ).evaluate()(inp, weight, z)
+    np.testing.assert_allclose(res.numpy(), res1.numpy())
+
+
+def test_alter_layout_re_blocking_broadcast():
+    """Test re-blocking of shapes with broadcast operators."""
+
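+    # Here the alter callback re-blocks the conv2d from NCHW4c to NCHW2c, so
+    # AlterOpLayout has to insert layout_transforms on the data, the weight,
+    # and, crucially, the blocked bias that is broadcast-added to the output.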
+    def before():
+        dtype = "float32"
+        input_shape = (1, 8, 16, 16, 4)
+        filter_shape = (1, 8, 4, 4, 4, 4)
+        bias_shape = (1, 1, 1, 1, 4)
+        A = relay.var("data", shape=input_shape, dtype=dtype)
+        B = relay.var("weight", shape=filter_shape, dtype=dtype)
+        C = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+        conv = relay.nn.conv2d(
+            A,
+            B,
+            data_layout="NCHW4c",
+            kernel_layout="OIHW4i4o",
+            padding=[3, 3, 0, 0],
+            strides=[2, 2],
+            out_dtype=dtype,
+            channels=4,
+            kernel_size=(4, 4),
+        )
+        bias = relay.op.add(conv, C)
+        bias = relay.Function(analysis.free_vars(bias), bias)
+        return bias
+
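+    # The expected graph makes the re-blocking explicit: every input is
+    # transformed to the 2c/2i2o layouts, and the result is transformed back
+    # to NCHW4c so the function signature is preserved.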
+    def expected():
+        dtype = "float32"
+        input_shape = (1, 8, 16, 16, 4)
+        filter_shape = (1, 8, 4, 4, 4, 4)
+        bias_shape = (1, 1, 1, 1, 4)
+        A = relay.var("data", shape=input_shape, dtype=dtype)
+        B = relay.var("weight", shape=filter_shape, dtype=dtype)
+        C = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+        # Re-block data and weight from the 4c/4i4o to the 2c/2i2o blocking.
+        A = relay.layout_transform(A, src_layout="NCHW4c", dst_layout="NCHW2c")
+        B = relay.layout_transform(B, src_layout="OIHW4i4o", dst_layout="OIHW2i2o")
+
+        conv = relay.nn.conv2d(
+            A,
+            B,
+            data_layout="NCHW2c",
+            kernel_layout="OIHW2i2o",
+            padding=[3, 3, 0, 0],
+            strides=[2, 2],
+            out_dtype=dtype,
+            channels=4,
+            kernel_size=(4, 4),
+        )
+        # The broadcast bias must be re-blocked as well before the add, and
+        # the result is transformed back to the original NCHW4c layout.
+        C = relay.layout_transform(C, src_layout="NCHW4c", dst_layout="NCHW2c")
+        bias = relay.op.add(conv, C)
+        bias = relay.layout_transform(bias, src_layout="NCHW2c", dst_layout="NCHW4c")
+        bias = relay.Function(analysis.free_vars(bias), bias)
+        return bias
+
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs["data_layout"] = "NCHW2c"
+        new_attrs["kernel_layout"] = "OIHW2i2o"
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        a = run_opt_pass(before(), transform.AlterOpLayout())
+        b = run_opt_pass(expected(), transform.InferType())
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
+
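+    # Cross-check numerics again; the re-blocked path goes through extra
+    # layout_transforms, so allow a small floating-point tolerance.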
+    inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
+    weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
+    z = np.random.uniform(size=(1, 1, 1, 1, 4)).astype(np.float32)
+    mod = tvm.IRModule.from_expr(before())
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        with tvm.transform.PassContext(opt_level=4):
+            res = relay.build_module.create_executor(
+                "graph", mod, target="llvm", device=tvm.cpu()
+            ).evaluate()(inp, weight, z)
+        with tvm.transform.PassContext(opt_level=0):
+            res1 = relay.build_module.create_executor(
+                "debug", mod, target="llvm", device=tvm.cpu()
+            ).evaluate()(inp, weight, z)
+    np.testing.assert_allclose(res.numpy(), res1.numpy(), rtol=1e-5, atol=1e-5)
+
+
 def test_broadcast_non_adaptable():
     """NCHW4c + [x, x, 4] and NCHW4c is being altered to NCHW"""
