Skip to content
Merged
11 changes: 7 additions & 4 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,10 +955,13 @@ def _convert_concat(
if input_shape is None:
input_shape = keras_layer.input_shape

if data_layout == "NHWC" or len(input_shape[0]) < 4:
axis = -1
else:
axis = 1
axis = keras_layer.axis
dims = len(input_shape[0])
if data_layout == "NCHW": # need_transpose
if axis == -1:
axis = 1
else:
axis = axis + 1 if axis < dims else 1
return _op.concatenate(_as_list(inexpr), axis=axis)


Expand Down
19 changes: 19 additions & 0 deletions tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,24 @@ def test_forward_merge(self, keras_mod):
keras_model = keras_mod.models.Model(data, out)
verify_keras_frontend(keras_model)

def test_forward_concatenate(self, keras_mod):
"""test_forward_concatenate"""
data1 = keras_mod.layers.Input(shape=(1, 2, 2))
data2 = keras_mod.layers.Input(shape=(1, 1, 2))
merge_func = keras_mod.layers.Concatenate(axis=2)
out = merge_func([data1, data2])
keras_model = keras_mod.models.Model([data1, data2], out)
verify_keras_frontend(keras_model, layout="NHWC")
verify_keras_frontend(keras_model, layout="NCHW")
# test default axis (e.g., -1)
data1 = keras_mod.layers.Input(shape=(1, 2, 2))
data2 = keras_mod.layers.Input(shape=(1, 2, 3))
merge_func = keras_mod.layers.Concatenate()
out = merge_func([data1, data2])
keras_model = keras_mod.models.Model([data1, data2], out)
verify_keras_frontend(keras_model, layout="NHWC")
verify_keras_frontend(keras_model, layout="NCHW")

def test_forward_merge_dot(self, keras_mod):
"""test_forward_merge_dot"""
data1 = keras_mod.layers.Input(shape=(2, 2))
Expand Down Expand Up @@ -793,6 +811,7 @@ def test_forward_time_distributed(self, keras_mod):
if __name__ == "__main__":
for k in [keras, tf_keras]:
sut = TestKeras()
sut.test_forward_concatenate(keras_mod=k)
sut.test_forward_merge_dot(keras_mod=k)
sut.test_forward_merge(keras_mod=k)
sut.test_forward_activations(keras_mod=k)
Expand Down