Skip to content

Commit 62f9b1d

Browse files
authored
[Tensorflow] Fix conv2d_transpose for NHWC layout (#14546)
* [Tensorflow] Fix conv2d_transpose for NHWC layout If "data_format" == "NHWC", the kernel_layout should be "HWOI" rather than "HWIO". * remove deed code * add test cases * Update test_forward.py * Update test_forward.py * Update tensorflow_ops.py * Update tensorflow_ops.py
1 parent 670d128 commit 62f9b1d

2 files changed

Lines changed: 22 additions & 3 deletions

File tree

python/tvm/relay/frontend/tensorflow_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,8 +464,8 @@ def _impl(inputs, attr, params, mod):
464464
if opname == "conv":
465465
attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "OIHW"
466466
elif opname == "conv_transpose":
467-
# conv_transpose in TVM has weights be IOHW for NCHW
468-
attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "IOHW"
467+
# conv_transpose has weights be IOHW, because the attr["data_format"] always be NCHW
468+
attr["kernel_layout"] = "IOHW"
469469
else:
470470
attr["kernel_layout"] = "HWOI" if attr["data_format"] == "NHWC" else "OIHW"
471471

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,16 @@ def test_forward_convolution():
742742
"NCHW",
743743
[1, 1, 8, 8],
744744
)
745-
745+
_test_convolution(
746+
"conv_transpose",
747+
[4, 19, 8, 8],
748+
[2, 2, 66, 19],
749+
[1, 1],
750+
[2, 2],
751+
"VALID",
752+
"NCHW",
753+
[4, 66, 16, 16],
754+
)
746755
_test_convolution("conv", [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC")
747756
_test_convolution("conv", [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC")
748757
_test_convolution("conv", [4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC")
@@ -917,6 +926,16 @@ def test_forward_convolution():
917926
[4, 8, 8, 176],
918927
add_shapes_to_graph_def=False,
919928
)
929+
_test_convolution(
930+
"conv_transpose",
931+
[4, 8, 8, 19],
932+
[2, 2, 66, 19],
933+
[1, 1],
934+
[2, 2],
935+
"VALID",
936+
"NHWC",
937+
[4, 16, 16, 66],
938+
)
920939
# Explicit padding
921940
if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"):
922941
_test_convolution(

0 commit comments

Comments
 (0)