diff --git a/python/tvm/contrib/msc/core/transform/pattern.py b/python/tvm/contrib/msc/core/transform/pattern.py
index fdc6a628310d..135bac64ae80 100644
--- a/python/tvm/contrib/msc/core/transform/pattern.py
+++ b/python/tvm/contrib/msc/core/transform/pattern.py
@@ -330,7 +330,8 @@ def make_relax_attention_pattern() -> (
     q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q)
     k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k)
     v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v)
-    out = relax_pattern.is_op("relax.nn.attention")(q_trans, k_trans, v_trans)
+    attention = relax_pattern.is_op("relax.nn.attention")(q_trans, k_trans, v_trans)
+    out = relax_pattern.is_op("relax.permute_dims")(attention)
     annotations = {
         "weight_q": weight_q,
         "weight_k": weight_k,
@@ -338,7 +339,8 @@ def make_relax_attention_pattern() -> (
         "q_trans": q_trans,
         "k_trans": k_trans,
         "v_trans": v_trans,
-        "attention": out,
+        "attention": attention,
+        "out": out,
     }
     return out, annotations
 
@@ -378,7 +380,8 @@ def make_relax_mask_attention_pattern() -> (
     q_trans = relax_pattern.is_op("relax.permute_dims")(weight_q)
     k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k)
     v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v)
-    out = relax_pattern.is_op("relax.nn.attention_bias")(q_trans, k_trans, v_trans, mask)
+    attention = relax_pattern.is_op("relax.nn.attention_bias")(q_trans, k_trans, v_trans, mask)
+    out = relax_pattern.is_op("relax.permute_dims")(attention)
     annotations = {
         "weight_q": weight_q,
         "weight_k": weight_k,
@@ -387,7 +390,8 @@ def make_relax_mask_attention_pattern() -> (
         "q_trans": q_trans,
         "k_trans": k_trans,
         "v_trans": v_trans,
-        "attention": out,
+        "attention": attention,
+        "out": out,
     }
     return out, annotations
 
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 983bce0255d9..27da69dbb182 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1015,7 +1015,9 @@ def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
             assert "float" in attn_mask.struct_info.dtype, msg
 
         return self.block_builder.emit(
-            relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask)
+            transpose_S_H(
+                relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask)
+            )
         )
 
     def _unbind(self, node: fx.Node) -> relax.Var:
diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc
index 1913e8ecda8e..73722f987701 100644
--- a/src/contrib/msc/framework/tvm/relax_opcode.cc
+++ b/src/contrib/msc/framework/tvm/relax_opcode.cc
@@ -107,6 +107,7 @@ class RelaxAttentionCodeGen : public RelaxOpCode {
           .op_list_arg(axes_key, "axes");
     }
     stack_.op_call().op_inputs_arg(false).op_arg("scale").op_str_arg("causal_mask");
+    stack_.op_call("relax.op.permute_dims").op_output_arg().op_list_arg("axes_3", "axes");
   }
 };
 
diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py
index 60c8a73dcc67..7fa71df20b45 100644
--- a/tests/python/contrib/test_msc/test_graph_build.py
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -2362,12 +2362,7 @@ def forward(self, q_data, k_data, v_data):
             {"name": "inp_2", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ACBD"},
         ],
         "outputs": [
-            {
-                "name": "attention",
-                "shape": [1, seq, 8, 64],
-                "dtype": "float32",
-                "layout": "ABCD",
-            }
+            {"name": "attention", "shape": [1, 8, seq, 64], "dtype": "float32", "layout": "ABCD"}
         ],
         "nodes": {"total": 4, "input": 3, "msc.attention": 1},
     }
@@ -2396,7 +2391,7 @@ def forward(self, q_data, k_data, v_data, mask):
         "outputs": [
             {
                 "name": "attention_bias",
-                "shape": [1, seq, 8, 64],
+                "shape": [1, 8, seq, 64],
                 "dtype": "float32",
                 "layout": "ABCD",
             }
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 78fc7abdf748..e7e1e991c534 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3825,7 +3825,7 @@ def main(
             inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
             inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
             inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
-        ) -> R.Tensor((32, 128, 8, 64), dtype="float32"):
+        ) -> R.Tensor((32, 8, 128, 64), dtype="float32"):
             with R.dataflow():
                 lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
                     inp_0, axes=[0, 2, 1, 3]
@@ -3839,7 +3839,10 @@ def main(
                 lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
                     lv, lv1, lv2, scale=None
                 )
-                gv: R.Tensor((32, 128, 8, 64), dtype="float32") = lv3
+                lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
+                    lv3, axes=[0, 2, 1, 3]
+                )
+                gv: R.Tensor((32, 8, 128, 64), dtype="float32") = lv4
                 R.output(gv)
             return gv
 
@@ -3851,7 +3854,7 @@ def main(
             inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
             inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
             inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
             inp_3: R.Tensor((32, 8, 128, 128), dtype="float32"),
-        ) -> R.Tensor((32, 128, 8, 64), dtype="float32"):
+        ) -> R.Tensor((32, 8, 128, 64), dtype="float32"):
             with R.dataflow():
                 lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
                     inp_0, axes=[0, 2, 1, 3]
@@ -3865,7 +3868,10 @@ def main(
                 lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
                     lv, lv1, lv2, inp_3, scale=None
                 )
-                gv: R.Tensor((32, 128, 8, 64), dtype="float32") = lv3
+                lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
+                    lv3, axes=[0, 2, 1, 3]
+                )
+                gv: R.Tensor((32, 8, 128, 64), dtype="float32") = lv4
                 R.output(gv)
             return gv
 
@@ -3876,7 +3882,7 @@ def main(
             inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
             inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
             inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
-        ) -> R.Tensor((32, 128, 8, 64), dtype="float32"):
+        ) -> R.Tensor((32, 8, 128, 64), dtype="float32"):
             with R.dataflow():
                 lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
                     inp_0, axes=[0, 2, 1, 3]
@@ -3890,7 +3896,10 @@ def main(
                 lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
                     lv, lv1, lv2, scale=None, causal_mask="TopLeft"
                 )
-                gv: R.Tensor((32, 128, 8, 64), dtype="float32") = lv3
+                lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
+                    lv3, axes=[0, 2, 1, 3]
+                )
+                gv: R.Tensor((32, 8, 128, 64), dtype="float32") = lv4
                 R.output(gv)
             return gv