Skip to content

Commit 5896507

Browse files
authored
[0035] Fix LinAlg MatrixMulAccumulate (#791)
Fixes #781
1 parent e3cc6e9 commit 5896507

1 file changed

Lines changed: 29 additions & 4 deletions

File tree

proposals/0035-linalg-matrix.md

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,16 +1224,15 @@ layout while a return value of `1` will denote that accumulator matrices are `B`
12241224
layout.
12251225

12261226
```llvm
1227-
declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgMatrixMulOp.[MatTyC].[MatTyA].[MatTyB](
1227+
declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgMatrixMultiply.[MatTyC].[MatTyA].[MatTyB](
12281228
immarg i32, ; opcode
12291229
%dx.types.LinAlgMatrix<mangling>, ; matrix A
12301230
%dx.types.LinAlgMatrix<mangling> ; matrix B
12311231
)
12321232
```
12331233

1234-
Two opcodes are available for this operation class:
1235-
* Matrix Matrix Multiply: `C = A * B`
1236-
* Matrix Matrix Multiply with Accumulation: `C += A * B`
1234+
This operation multiplies an A matrix and B matrix into new accumulator matrix
1235+
following the form `C = A * B`.
12371236

12381237
Validation rules will enforce that:
12391238
* argument A is an `A` matrix
@@ -1269,6 +1268,32 @@ Validation rules will enforce that:
12691268

12701269
Must be called from wave-uniform control flow.
12711270

1271+
```llvm
1272+
declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgMatrixMultiplyAccumulate.[MatTyR].[MatTyA].[MatTyB].[MatTyC](
1273+
immarg i32, ; opcode
1274+
%dx.types.LinAlgMatrix<mangling>, ; matrix A
1275+
%dx.types.LinAlgMatrix<mangling> ; matrix B
1276+
%dx.types.LinAlgMatrix<mangling> ; matrix C
1277+
)
1278+
```
1279+
1280+
This operation multiplies an A matrix and B matrix and accumlates it into an
1281+
accumulator matrix following the form `R = C + (A * B)`.
1282+
1283+
Validation rules will enforce that:
1284+
* argument A is an `A` matrix
1285+
* argument B is a `B` matrix
1286+
* argument C is an `Accumulator` matrix
1287+
* return value (R) is an `Accumulator` matrix
1288+
* All four matrices have the same scope (Wave or ThreadGroup)
1289+
* Matrix A's dimensions shall be M x K
1290+
* Matrix B's dimensions shall be K x N
1291+
* Matrix C's dimensions shall be M x N
1292+
* Matrix R's dimensions shall be M x N
1293+
* The element types are compatible
1294+
1295+
Must be called from wave-uniform control flow.
1296+
12721297
``` llvm
12731298
declare <[NUMo] x [TYo]> @dx.op.linAlgMatVecMul.v[NUMo][TYo].[MatTy].v[NUMi][TYi](
12741299
immarg i32, ; opcode

0 commit comments

Comments
 (0)