Skip to content

Commit cc3779b

Browse files
committed
Add tests on boradcast blocking over already blocked layout
1 parent c423ca4 commit cc3779b

File tree

1 file changed

+200
-0
lines changed

1 file changed

+200
-0
lines changed

tests/python/relay/test_pass_alter_op_layout.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1602,6 +1602,206 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
16021602
np.testing.assert_allclose(res.numpy(), res1.numpy())
16031603

16041604

1605+
def test_alter_layout_blocked_no_broadcast():
1606+
"""Test boradcast operators working on already blocked layout"""
1607+
1608+
def before():
1609+
dtype = "float32"
1610+
input_shape = (1, 8, 16, 16, 4)
1611+
filter_shape = (1, 8, 4, 4, 4, 4)
1612+
bias_shape = (1, 1, 1, 1, 4)
1613+
A = relay.var("data", shape=input_shape, dtype=dtype)
1614+
B = relay.var("weight", shape=filter_shape, dtype=dtype)
1615+
C = relay.var("bias", shape=bias_shape, dtype=dtype)
1616+
1617+
conv = relay.nn.conv2d(
1618+
A,
1619+
B,
1620+
data_layout="NCHW4c",
1621+
kernel_layout="OIHW4i4o",
1622+
padding=[3, 3, 0, 0],
1623+
strides=[2, 2],
1624+
out_dtype=dtype,
1625+
channels=4,
1626+
kernel_size=(4, 4),
1627+
)
1628+
bias = relay.op.add(conv, C)
1629+
bias = relay.Function(analysis.free_vars(bias), bias)
1630+
return bias
1631+
1632+
def expected():
1633+
return before()
1634+
1635+
def alter_conv2d(attrs, inputs, tinfos, out_type):
1636+
data, weight = inputs
1637+
new_attrs = dict(attrs)
1638+
new_attrs["data_layout"] = "NCHW4c"
1639+
new_attrs["kernel_layout"] = "OIHW4i4o"
1640+
return relay.nn.conv2d(data, weight, **new_attrs)
1641+
1642+
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
1643+
a = run_opt_pass(before(), transform.AlterOpLayout())
1644+
b = run_opt_pass(expected(), transform.InferType())
1645+
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
1646+
1647+
inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
1648+
weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
1649+
z = np.random.uniform(size=(1, 1, 1, 1, 4)).astype(np.float32)
1650+
mod = tvm.IRModule.from_expr(before())
1651+
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
1652+
with tvm.transform.PassContext(opt_level=4):
1653+
res = relay.build_module.create_executor(
1654+
"graph", mod, target="llvm", device=tvm.cpu()
1655+
).evaluate()(inp, weight, z)
1656+
with tvm.transform.PassContext(opt_level=0):
1657+
res1 = relay.build_module.create_executor(
1658+
"debug", mod, target="llvm", device=tvm.cpu()
1659+
).evaluate()(inp, weight, z)
1660+
np.testing.assert_allclose(res.numpy(), res1.numpy())
1661+
1662+
1663+
def test_alter_layout_blocked_broadcast():
1664+
"""Test boradcast operators working on already blocked layout"""
1665+
1666+
def before():
1667+
dtype = "float32"
1668+
input_shape = (1, 8, 16, 16, 4)
1669+
filter_shape = (1, 8, 4, 4, 4, 4)
1670+
bias_shape = (1, 1, 1, 1, 1)
1671+
A = relay.var("data", shape=input_shape, dtype=dtype)
1672+
B = relay.var("weight", shape=filter_shape, dtype=dtype)
1673+
C = relay.var("bias", shape=bias_shape, dtype=dtype)
1674+
1675+
conv = relay.nn.conv2d(
1676+
A,
1677+
B,
1678+
data_layout="NCHW4c",
1679+
kernel_layout="OIHW4i4o",
1680+
padding=[3, 3, 0, 0],
1681+
strides=[2, 2],
1682+
out_dtype=dtype,
1683+
channels=4,
1684+
kernel_size=(4, 4),
1685+
)
1686+
bias = relay.op.add(conv, C)
1687+
bias = relay.Function(analysis.free_vars(bias), bias)
1688+
return bias
1689+
1690+
def expected():
1691+
return before()
1692+
1693+
def alter_conv2d(attrs, inputs, tinfos, out_type):
1694+
data, weight = inputs
1695+
new_attrs = dict(attrs)
1696+
new_attrs["data_layout"] = "NCHW4c"
1697+
new_attrs["kernel_layout"] = "OIHW4i4o"
1698+
return relay.nn.conv2d(data, weight, **new_attrs)
1699+
1700+
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
1701+
a = run_opt_pass(before(), transform.AlterOpLayout())
1702+
b = run_opt_pass(expected(), transform.InferType())
1703+
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
1704+
1705+
inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
1706+
weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
1707+
z = np.random.uniform(size=(1, 1, 1, 1, 1)).astype(np.float32)
1708+
mod = tvm.IRModule.from_expr(before())
1709+
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
1710+
with tvm.transform.PassContext(opt_level=4):
1711+
res = relay.build_module.create_executor(
1712+
"graph", mod, target="llvm", device=tvm.cpu()
1713+
).evaluate()(inp, weight, z)
1714+
with tvm.transform.PassContext(opt_level=0):
1715+
res1 = relay.build_module.create_executor(
1716+
"debug", mod, target="llvm", device=tvm.cpu()
1717+
).evaluate()(inp, weight, z)
1718+
np.testing.assert_allclose(res.numpy(), res1.numpy())
1719+
1720+
1721+
def test_alter_layout_re_blocking_broadcast():
1722+
"""Test of re-blocking shapes with boradcast operators"""
1723+
1724+
def before():
1725+
dtype = "float32"
1726+
input_shape = (1, 8, 16, 16, 4)
1727+
filter_shape = (1, 8, 4, 4, 4, 4)
1728+
bias_shape = (1, 1, 1, 1, 4)
1729+
A = relay.var("data", shape=input_shape, dtype=dtype)
1730+
B = relay.var("weight", shape=filter_shape, dtype=dtype)
1731+
C = relay.var("bias", shape=bias_shape, dtype=dtype)
1732+
1733+
conv = relay.nn.conv2d(
1734+
A,
1735+
B,
1736+
data_layout="NCHW4c",
1737+
kernel_layout="OIHW4i4o",
1738+
padding=[3, 3, 0, 0],
1739+
strides=[2, 2],
1740+
out_dtype=dtype,
1741+
channels=4,
1742+
kernel_size=(4, 4),
1743+
)
1744+
bias = relay.op.add(conv, C)
1745+
bias = relay.Function(analysis.free_vars(bias), bias)
1746+
return bias
1747+
1748+
def expected():
1749+
dtype = "float32"
1750+
input_shape = (1, 8, 16, 16, 4)
1751+
filter_shape = (1, 8, 4, 4, 4, 4)
1752+
bias_shape = (1, 1, 1, 1, 4)
1753+
A = relay.var("data", shape=input_shape, dtype=dtype)
1754+
B = relay.var("weight", shape=filter_shape, dtype=dtype)
1755+
C = relay.var("bias", shape=bias_shape, dtype=dtype)
1756+
1757+
A = relay.layout_transform(A, src_layout="NCHW4c", dst_layout="NCHW2c")
1758+
B = relay.layout_transform(B, src_layout="OIHW4i4o", dst_layout="OIHW2i2o")
1759+
1760+
conv = relay.nn.conv2d(
1761+
A,
1762+
B,
1763+
data_layout="NCHW2c",
1764+
kernel_layout="OIHW2i2o",
1765+
padding=[3, 3, 0, 0],
1766+
strides=[2, 2],
1767+
out_dtype=dtype,
1768+
channels=4,
1769+
kernel_size=(4, 4),
1770+
)
1771+
C = relay.layout_transform(C, src_layout="NCHW4c", dst_layout="NCHW2c")
1772+
bias = relay.op.add(conv, C)
1773+
bias = relay.layout_transform(bias, src_layout="NCHW2c", dst_layout="NCHW4c")
1774+
bias = relay.Function(analysis.free_vars(bias), bias)
1775+
return bias
1776+
1777+
def alter_conv2d(attrs, inputs, tinfos, out_type):
1778+
data, weight = inputs
1779+
new_attrs = dict(attrs)
1780+
new_attrs["data_layout"] = "NCHW2c"
1781+
new_attrs["kernel_layout"] = "OIHW2i2o"
1782+
return relay.nn.conv2d(data, weight, **new_attrs)
1783+
1784+
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
1785+
a = run_opt_pass(before(), transform.AlterOpLayout())
1786+
b = run_opt_pass(expected(), transform.InferType())
1787+
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
1788+
1789+
inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
1790+
weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
1791+
z = np.random.uniform(size=(1, 1, 1, 1, 4)).astype(np.float32)
1792+
mod = tvm.IRModule.from_expr(before())
1793+
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
1794+
with tvm.transform.PassContext(opt_level=4):
1795+
res = relay.build_module.create_executor(
1796+
"graph", mod, target="llvm", device=tvm.cpu()
1797+
).evaluate()(inp, weight, z)
1798+
with tvm.transform.PassContext(opt_level=0):
1799+
res1 = relay.build_module.create_executor(
1800+
"debug", mod, target="llvm", device=tvm.cpu()
1801+
).evaluate()(inp, weight, z)
1802+
np.testing.assert_allclose(res.numpy(), res1.numpy(), rtol=1e-5, atol=1e-5)
1803+
1804+
16051805
def test_broadcast_non_adaptable():
16061806
"""NCHW4c + [x, x, 4] and NCHW4c is being altered to NCHW"""
16071807

0 commit comments

Comments
 (0)