Skip to content

Commit 39eb396

Browse files
committed
Add tests on boradcast blocking over already blocked layout
1 parent c423ca4 commit 39eb396

File tree

1 file changed

+146
-0
lines changed

1 file changed

+146
-0
lines changed

tests/python/relay/test_pass_alter_op_layout.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1602,6 +1602,152 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
16021602
np.testing.assert_allclose(res.numpy(), res1.numpy())
16031603

16041604

1605+
def test_alter_layout_blocked_broadcast():
1606+
"""Test boradcast operators working on already blocked layout"""
1607+
1608+
def before():
1609+
dtype = "float32"
1610+
input_shape = (1, 8, 16, 16, 4)
1611+
filter_shape = (1, 8, 4, 4, 4, 4)
1612+
bias_shape = (1, 1, 1, 1, 4)
1613+
A = relay.var("data", shape=input_shape, dtype=dtype)
1614+
B = relay.var("weight", shape=filter_shape, dtype=dtype)
1615+
C = relay.var("bias", shape=bias_shape, dtype=dtype)
1616+
1617+
conv = relay.nn.conv2d(
1618+
A,
1619+
B,
1620+
data_layout="NCHW4c",
1621+
kernel_layout="OIHW4i4o",
1622+
padding=[3, 3, 0, 0],
1623+
strides=[2, 2],
1624+
out_dtype=dtype,
1625+
channels=4,
1626+
kernel_size=(4, 4),
1627+
)
1628+
bias = relay.op.add(conv, C)
1629+
bias = relay.Function(analysis.free_vars(bias), bias)
1630+
return bias
1631+
1632+
def expected():
1633+
return before()
1634+
1635+
def alter_conv2d(attrs, inputs, tinfos, out_type):
1636+
data, weight = inputs
1637+
new_attrs = dict(attrs)
1638+
new_attrs["data_layout"] = "NCHW4c"
1639+
new_attrs["kernel_layout"] = "OIHW4i4o"
1640+
return relay.nn.conv2d(data, weight, **new_attrs)
1641+
1642+
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
1643+
# a = run_opt_pass(before(), transform.AlterOpLayout())
1644+
# b = run_opt_pass(before(), transform.AlterOpLayout())
1645+
a = run_opt_pass(before(), transform.InferType())
1646+
b = run_opt_pass(expected(), transform.InferType())
1647+
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
1648+
1649+
inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
1650+
weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
1651+
z = np.random.uniform(size=(1, 1, 1, 1, 4)).astype(np.float32)
1652+
mod = tvm.IRModule.from_expr(before())
1653+
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
1654+
with tvm.transform.PassContext(opt_level=4):
1655+
res = relay.build_module.create_executor(
1656+
"graph", mod, target="llvm", device=tvm.cpu()
1657+
).evaluate()(inp, weight, z)
1658+
with tvm.transform.PassContext(opt_level=0):
1659+
res1 = relay.build_module.create_executor(
1660+
"debug", mod, target="llvm", device=tvm.cpu()
1661+
).evaluate()(inp, weight, z)
1662+
np.testing.assert_allclose(res.numpy(), res1.numpy())
1663+
1664+
1665+
def test_alter_layout_re_blocking_broadcast():
1666+
"""Test of re-blocking shapes with boradcast operators"""
1667+
1668+
def before():
1669+
dtype = "float32"
1670+
input_shape = (1, 8, 16, 16, 4)
1671+
filter_shape = (1, 8, 4, 4, 4, 4)
1672+
bias_shape = (1, 1, 1, 1, 4)
1673+
A = relay.var("data", shape=input_shape, dtype=dtype)
1674+
B = relay.var("weight", shape=filter_shape, dtype=dtype)
1675+
C = relay.var("bias", shape=bias_shape, dtype=dtype)
1676+
1677+
conv = relay.nn.conv2d(
1678+
A,
1679+
B,
1680+
data_layout="NCHW4c",
1681+
kernel_layout="OIHW4i4o",
1682+
padding=[3, 3, 0, 0],
1683+
strides=[2, 2],
1684+
out_dtype=dtype,
1685+
channels=4,
1686+
kernel_size=(4, 4),
1687+
)
1688+
bias = relay.op.add(conv, C)
1689+
bias = relay.Function(analysis.free_vars(bias), bias)
1690+
return bias
1691+
1692+
def expected():
1693+
dtype = "float32"
1694+
input_shape = (1, 8, 16, 16, 4)
1695+
filter_shape = (1, 8, 4, 4, 4, 4)
1696+
bias_shape = (1, 1, 1, 1, 4)
1697+
A = relay.var("data", shape=input_shape, dtype=dtype)
1698+
B = relay.var("weight", shape=filter_shape, dtype=dtype)
1699+
C = relay.var("bias", shape=bias_shape, dtype=dtype)
1700+
1701+
A = relay.layout_transform(A, src_layout="NCHW4c", dst_layout="NCHW2c")
1702+
B = relay.layout_transform(B, src_layout="OIHW4i4o", dst_layout="OIHW2i2o")
1703+
1704+
conv = relay.nn.conv2d(
1705+
A,
1706+
B,
1707+
data_layout="NCHW2c",
1708+
kernel_layout="OIHW2i2o",
1709+
padding=[3, 3, 0, 0],
1710+
strides=[2, 2],
1711+
out_dtype=dtype,
1712+
channels=4,
1713+
kernel_size=(4, 4),
1714+
)
1715+
C = relay.layout_transform(C, src_layout="NCHW4c", dst_layout="NCHW2c")
1716+
bias = relay.op.add(conv, C)
1717+
bias = relay.layout_transform(bias, src_layout="NCHW2c", dst_layout="NCHW4c")
1718+
bias = relay.Function(analysis.free_vars(bias), bias)
1719+
return bias
1720+
1721+
def alter_conv2d(attrs, inputs, tinfos, out_type):
1722+
data, weight = inputs
1723+
new_attrs = dict(attrs)
1724+
new_attrs["data_layout"] = "NCHW2c"
1725+
new_attrs["kernel_layout"] = "OIHW2i2o"
1726+
return relay.nn.conv2d(data, weight, **new_attrs)
1727+
1728+
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
1729+
a = run_opt_pass(before(), transform.AlterOpLayout())
1730+
print(a)
1731+
b = run_opt_pass(expected(), transform.InferType())
1732+
print(b)
1733+
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
1734+
1735+
inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
1736+
weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
1737+
z = np.random.uniform(size=(1, 1, 1, 1, 4)).astype(np.float32)
1738+
mod = tvm.IRModule.from_expr(before())
1739+
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
1740+
with tvm.transform.PassContext(opt_level=4):
1741+
res = relay.build_module.create_executor(
1742+
"graph", mod, target="llvm", device=tvm.cpu()
1743+
).evaluate()(inp, weight, z)
1744+
with tvm.transform.PassContext(opt_level=0):
1745+
res1 = relay.build_module.create_executor(
1746+
"debug", mod, target="llvm", device=tvm.cpu()
1747+
).evaluate()(inp, weight, z)
1748+
np.testing.assert_allclose(res.numpy(), res1.numpy(), rtol=1e-5, atol=1e-5)
1749+
1750+
16051751
def test_broadcast_non_adaptable():
16061752
"""NCHW4c + [x, x, 4] and NCHW4c is being altered to NCHW"""
16071753

0 commit comments

Comments
 (0)