@@ -1224,16 +1224,15 @@ layout while a return value of `1` will denote that accumulator matrices are `B`
12241224layout.
12251225
12261226``` llvm
1227- declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgMatrixMulOp .[MatTyC].[MatTyA].[MatTyB](
1227+ declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgMatrixMultiply .[MatTyC].[MatTyA].[MatTyB](
12281228 immarg i32, ; opcode
12291229 %dx.types.LinAlgMatrix<mangling>, ; matrix A
12301230 %dx.types.LinAlgMatrix<mangling> ; matrix B
12311231 )
12321232```
12331233
1234- Two opcodes are available for this operation class:
1235- * Matrix Matrix Multiply: ` C = A * B `
1236- * Matrix Matrix Multiply with Accumulation: ` C += A * B `
1234+ This operation multiplies an A matrix and B matrix into new accumulator matrix
1235+ following the form ` C = A * B ` .
12371236
12381237Validation rules will enforce that:
12391238* argument A is an ` A ` matrix
@@ -1269,6 +1268,32 @@ Validation rules will enforce that:
12691268
12701269Must be called from wave-uniform control flow.
12711270
1271+ ``` llvm
1272+ declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgMatrixMultiplyAccumulate.[MatTyR].[MatTyA].[MatTyB].[MatTyC](
1273+ immarg i32, ; opcode
1274+ %dx.types.LinAlgMatrix<mangling>, ; matrix A
1275+ %dx.types.LinAlgMatrix<mangling> ; matrix B
1276+ %dx.types.LinAlgMatrix<mangling> ; matrix C
1277+ )
1278+ ```
1279+
1280+ This operation multiplies an A matrix and B matrix and accumlates it into an
1281+ accumulator matrix following the form ` R = C + (A * B) ` .
1282+
1283+ Validation rules will enforce that:
1284+ * argument A is an ` A ` matrix
1285+ * argument B is a ` B ` matrix
1286+ * argument C is an ` Accumulator ` matrix
1287+ * return value (R) is an ` Accumulator ` matrix
1288+ * All four matrices have the same scope (Wave or ThreadGroup)
1289+ * Matrix A's dimensions shall be M x K
1290+ * Matrix B's dimensions shall be K x N
1291+ * Matrix C's dimensions shall be M x N
1292+ * Matrix R's dimensions shall be M x N
1293+ * The element types are compatible
1294+
1295+ Must be called from wave-uniform control flow.
1296+
12721297``` llvm
12731298declare <[NUMo] x [TYo]> @dx.op.linAlgMatVecMul.v[NUMo][TYo].[MatTy].v[NUMi][TYi](
12741299 immarg i32, ; opcode
0 commit comments