Commit 690765a

Authored by llvm-beanz, V-FEXrt, tex3d, and hekota
[0035] Rework specification to treat matrices as values (#769)
This updates the linalg spec to treat the AttributedMatrixRef objects as value objects in the SSA graph. This should address concerns about object lifetimes.

Fixes #756

Co-authored-by: Ashley Coleman <ascoleman@microsoft.com>
Co-authored-by: Tex Riddell <texr@microsoft.com>
Co-authored-by: Helena Kotas <hekotas@microsoft.com>
1 parent 5fdb63c commit 690765a

1 file changed

Lines changed: 109 additions & 112 deletions

proposals/0035-linalg-matrix.md

@@ -339,8 +339,8 @@ void OuterProdAccum() {
339339
### HLSL API Concepts
340340

341341
The new HLSL API introduces a new `linalg::Matrix` type which represents an
342-
opaque matrix object, and contains an intangible handle that refers to the
343-
allocated matrix.
342+
opaque matrix object, and contains an intangible value object that refers to the
343+
matrix.
344344

345345
The `linalg::Matrix` template type is parameterized based on the matrix
346346
component data type, dimensions, use, and scope. These parameters restrict where
@@ -443,7 +443,10 @@ In HLSL, matrix objects are intangible objects so they do not have defined size
443443
or memory layout. When in use, implementations are expected to distribute the
444444
storage of matrices across the thread-local storage for all threads in a SIMD
445445
unit. An implementation may also utilize caches or other memory regions as
446-
appropriate. At the DXIL level a matrix is represented as a handle object.
446+
appropriate. At the DXIL level a matrix is represented as a value object.
447+
Because LLVM 3.7 doesn't allow value objects of opaque types, the matrix object
448+
stores a pointer in the IR, but implementations will replace this with an
449+
implementation-defined object.
447450

448451
An A matrix is a collection of per-thread vectors representing matrix rows,
449452
while a B matrix is a collection of per-thread vectors representing matrix
@@ -1022,45 +1025,50 @@ enum class DXILComponentType {
10221025
}
10231026
```
10241027
1025-
This feature also adds a matrix ref that serves as an opaque type handle to the
1026-
implementation's representation of the matrix.
1027-
1028-
1029-
```llvm
1030-
%dx.types.MatrixRef = type { i8 * }
1031-
```
1032-
1033-
The compiler will also generate a permutation of typed matrix handles with names
1034-
of the format `%dx.types.AttributedMatrixRef<mangling>`. The mangling scheme for
1028+
The compiler will generate a permutation of typed matrix handles with names of
1029+
the format `%dx.types.LinAlgMatrix<mangling>`. The mangling scheme for
10351030
each type name will capture the type parameterization with the tokens `C`,
10361031
`M`, `N`, `U` and `S` denoting each encoded property.
10371032
10381033
```
10391034
; Matrix<ComponentType::F16, 16, 16, MatrixUse::A, MatrixScope::Wave>
1040-
%dx.types.AttributedMatrixRefC10M16N16U0S1 = type { i8 * }
1035+
%dx.types.LinAlgMatrixC10M16N16U0S1 = type { i8 * }
10411036
; Matrix<ComponentType::F16, 16, 16, MatrixUse::B, MatrixScope::Wave>
1042-
%dx.types.AttributedMatrixRefC10M16N16U1S1 = type { i8 * }
1037+
%dx.types.LinAlgMatrixC10M16N16U1S1 = type { i8 * }
10431038
; Matrix<ComponentType::F32, 16, 16, MatrixUse::Accumulator, MatrixScope::Wave>
1044-
%dx.types.AttributedMatrixRefC11M16N16U2S1 = type { i8 * }
1039+
%dx.types.LinAlgMatrixC11M16N16U2S1 = type { i8 * }
10451040
```
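As a rough illustration of the mangling scheme, the sketch below builds and parses `LinAlgMatrix` type names in Python. The `MatrixParams` tuple and the `mangle`/`demangle` helpers are hypothetical names introduced here, not part of the proposal; the token order `C`, `M`, `N`, `U`, `S` and the numeric values are taken only from the three examples above.

```python
import re
from typing import NamedTuple


class MatrixParams(NamedTuple):
    """Hypothetical container for the five encoded properties."""
    component: int  # C token (10 and 11 appear in the examples above)
    m: int          # M token
    n: int          # N token
    use: int        # U token (0, 1 and 2 appear in the examples above)
    scope: int      # S token (1 appears in the examples above)


PREFIX = "%dx.types.LinAlgMatrix"
_MANGLED = re.compile(r"C(\d+)M(\d+)N(\d+)U(\d+)S(\d+)")


def mangle(p: MatrixParams) -> str:
    """Build a type name in the C/M/N/U/S token order."""
    return f"{PREFIX}C{p.component}M{p.m}N{p.n}U{p.use}S{p.scope}"


def demangle(name: str) -> MatrixParams:
    """Recover the parameters from a mangled type name."""
    if not name.startswith(PREFIX):
        raise ValueError(f"not a LinAlgMatrix type name: {name}")
    match = _MANGLED.fullmatch(name[len(PREFIX):])
    if match is None:
        raise ValueError(f"unrecognized mangling: {name}")
    return MatrixParams(*(int(group) for group in match.groups()))


print(mangle(MatrixParams(10, 16, 16, 0, 1)))
print(demangle("%dx.types.LinAlgMatrixC11M16N16U2S1"))
```

A round trip through `mangle` and `demangle` should return the original parameters, which is a convenient sanity check for any tool that produces these names.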
10461041
1047-
DXIL validation will enforce that an `AttributedMatrixRef` of any type may be
1048-
bitcast to a `%dx.types.MatrixRef`, but the inverse cast will be disallowed.
1042+
DXIL validation will enforce that a `LinAlgMatrix` of any type may not
1043+
be bitcast to any other type.
1044+
1045+
#### Type Metadata
1046+
1047+
A new named metadata `dx.targetTypes` will be added to contain mappings of
1048+
attributed matrix types to their type parameters, avoiding the need to parse the
1049+
type mangling. For the given examples above metadata of the form below will be
1050+
generated:
10491051
1050-
### DXIL Operations
10511052
1052-
```llvm
1053-
declare %dx.types.AttributedMatrixRef<mangling> @dx.op.createMatrix<mangling>(
1054-
immarg i32 ; opcode
1055-
)
10561053
```
1054+
!dx.targetTypes = !{!1, !2, !3}
1055+
; Matrix<ComponentType::F16, 16, 16, MatrixUse::A, MatrixScope::Wave>
1056+
!1 = !{%dx.types.LinAlgMatrixC10M16N16U0S1 undef, i32 10, i32 16, i32 16, i32 0, i32 1 }
1057+
; Matrix<ComponentType::F16, 16, 16, MatrixUse::B, MatrixScope::Wave>
1058+
!2 = !{%dx.types.LinAlgMatrixC10M16N16U1S1 undef, i32 10, i32 16, i32 16, i32 1, i32 1 }
1059+
; Matrix<ComponentType::F32, 16, 16, MatrixUse::Accumulator, MatrixScope::Wave>
1060+
!3 = !{%dx.types.LinAlgMatrixC11M16N16U2S1 undef, i32 11, i32 16, i32 16, i32 2, i32 1 }
1061+
```
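To show how a consumer might read these entries without touching the type mangling, here is a minimal Python sketch that formats and parses one `dx.targetTypes` entry as text. The helper names `format_entry` and `parse_entry` are hypothetical, and the sketch assumes the textual layout shown in the example above; a real tool would presumably walk the metadata through the LLVM APIs rather than use a regular expression.

```python
import re
from typing import Sequence, Tuple

# Matches the operand list of one entry: a typed undef followed by five i32 values.
_ENTRY = re.compile(
    r"!\{\s*(%dx\.types\.LinAlgMatrix\w+) undef((?:,\s*i32 \d+){5})\s*\}"
)


def format_entry(index: int, type_name: str, params: Sequence[int]) -> str:
    """Render one entry in the textual form used in the example above."""
    fields = ", ".join(f"i32 {value}" for value in params)
    return f"!{index} = !{{{type_name} undef, {fields} }}"


def parse_entry(line: str) -> Tuple[str, Tuple[int, ...]]:
    """Return (type name, parameter values) without decoding the mangled name."""
    match = _ENTRY.search(line)
    if match is None:
        raise ValueError(f"not a dx.targetTypes entry: {line}")
    values = tuple(int(v) for v in re.findall(r"i32 (\d+)", match.group(2)))
    return match.group(1), values


line = format_entry(1, "%dx.types.LinAlgMatrixC10M16N16U0S1", (10, 16, 16, 0, 1))
print(line)
print(parse_entry(line))
```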
1062+
1063+
> Note: to ease compatibility with modern LLVM we want the metadata to avoid
1064+
> encoding pointers since modern LLVM will convert pointers to opaque pointers
1065+
> losing the type information.
10571066
1058-
Creates a new uninitialized matrix handle.
1067+
### DXIL Operations
10591068
10601069
```llvm
1061-
declare void @dx.op.fillMatrix.[TY](
1070+
declare %dx.types.LinAlgMatrix<mangling> @dx.op.fillMatrix.[MatTy].[TY](
10621071
immarg i32, ; opcode
1063-
%dx.types.MatrixRef, ; matrix
10641072
[Ty] ; fill value
10651073
)
10661074
```
@@ -1070,11 +1078,10 @@ matrix component's type, a type conversion is applied following the rules
10701078
documented in the [Conversions](#conversions) section.
10711079

10721080
```llvm
1073-
declare void @dx.op.copyConvertMatrix(
1074-
immarg i32, ; opcode
1075-
%dx.types.MatrixRef, ; matrix destination
1076-
%dx.types.MatrixRef, ; matrix source
1077-
immarg i1, ; transpose
1081+
declare %dx.types.LinAlgMatrix<mangling> @dx.op.copyConvertMatrix.[MatTy1].[MatTy2](
1082+
immarg i32, ; opcode
1083+
%dx.types.LinAlgMatrix<mangling>, ; matrix source
1084+
immarg i1, ; transpose
10781085
)
10791086
```
10801087

@@ -1084,9 +1091,8 @@ unmodified after this operation is applied. Validation shall enforce that both
10841091
matrices have the same scope and dimensions.
10851092

10861093
```llvm
1087-
declare void @dx.op.matrixLoadFromDescriptor(
1094+
declare %dx.types.LinAlgMatrix<mangling> @dx.op.matrixLoadFromDescriptor.[MatTy](
10881095
immarg i32, ; opcode
1089-
%dx.types.MatrixRef, ; matrix
10901096
%dx.types.Handle, ; ByteAddressBuffer
10911097
i32, ; Offset
10921098
i32, ; Stride
@@ -1106,9 +1112,8 @@ Validation rules will enforce that:
11061112
* `Stride` is `0` if the `Layout` is not `RowMajor` or `ColMajor`
11071113

11081114
```llvm
1109-
declare void @dx.op.matrixLoadFromMemory.p[Ty](
1115+
declare %dx.types.LinAlgMatrix<mangling> @dx.op.matrixLoadFromMemory.[MatTy].[Ty](
11101116
immarg i32, ; opcode
1111-
%dx.types.MatrixRef, ; matrix
11121117
[Ty] * addrspace(4), ; groupshared T[M * N]
11131118
i32, ; Offset
11141119
i32, ; Stride
@@ -1121,31 +1126,31 @@ between opaque matrices and groupshared memory are defined in the
11211126
[Conversions](#conversions) section below.
11221127

11231128
```llvm
1124-
declare i32 @dx.op.matrixLength(
1125-
immarg i32, ; opcode
1126-
%dx.types.MatrixRef ; matrix
1129+
declare i32 @dx.op.matrixLength.[MatTy](
1130+
immarg i32, ; opcode
1131+
%dx.types.LinAlgMatrix<mangling> ; matrix
11271132
)
11281133
```
11291134

11301135
Returns the number of elements stored in thread-local storage on the active
11311136
thread for the provided matrix.
11321137

11331138
```llvm
1134-
declare <2 x i32> @dx.op.matrixGetCoordinate(
1135-
immarg i32, ; opcode
1136-
%dx.types.MatrixRef, ; matrix
1137-
i32 ; thread-local index
1139+
declare <2 x i32> @dx.op.matrixGetCoordinate.[MatTy](
1140+
immarg i32, ; opcode
1141+
%dx.types.LinAlgMatrix<mangling>, ; matrix
1142+
i32 ; thread-local index
11381143
)
11391144
```
11401145

11411146
Returns a two element vector containing the column and row of the matrix that
11421147
the thread-local index corresponds to.
11431148

11441149
```llvm
1145-
declare [Ty] @dx.op.matrixGetElement.[Ty](
1146-
immarg i32, ; opcode
1147-
%dx.types.MatrixRef, ; matrix
1148-
i32 ; thread-local index
1150+
declare [Ty] @dx.op.matrixGetElement.[Ty].[MatTy](
1151+
immarg i32, ; opcode
1152+
%dx.types.LinAlgMatrix<mangling>, ; matrix
1153+
i32 ; thread-local index
11491154
)
11501155
```
11511156

@@ -1154,11 +1159,11 @@ If the index is out of range for the values stored in this thread the result is
11541159
0.
11551160

11561161
```llvm
1157-
declare void @dx.op.matrixSetElement.[Ty](
1158-
immarg i32, ; opcode
1159-
%dx.types.MatrixRef, ; matrix
1160-
i32, ; thread-local index
1161-
[Ty] ; value
1162+
declare %dx.types.LinAlgMatrix<mangling> @dx.op.matrixSetElement.[MatTy].[MatTy].[Ty](
1163+
immarg i32, ; opcode
1164+
%dx.types.LinAlgMatrix<mangling>, ; input matrix
1165+
i32, ; thread-local index
1166+
[Ty] ; value
11621167
)
11631168
```
11641169

@@ -1167,13 +1172,13 @@ to the value provided. If the index is out of range for the values stored in
11671172
this thread the result is a no-op.
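The value semantics of the element accessors can be modeled with a small Python sketch. It is only an illustration under stated assumptions: the elements held by one thread are modeled as an immutable tuple, `get_element` and `set_element` are hypothetical names, and the "no-op" for an out-of-range index is interpreted here as returning the input matrix unchanged.

```python
from typing import Tuple

# Assumption: the per-thread portion of a matrix is modeled as an immutable tuple.
ThreadElements = Tuple[float, ...]


def get_element(elems: ThreadElements, index: int) -> float:
    """Return the element at index, or 0 when the index is out of range."""
    if 0 <= index < len(elems):
        return elems[index]
    return 0.0


def set_element(elems: ThreadElements, index: int, value: float) -> ThreadElements:
    """Return a new value with one element replaced; the input is never mutated."""
    if not (0 <= index < len(elems)):
        return elems  # out of range: the result equals the input
    return elems[:index] + (value,) + elems[index + 1:]


before = (1.0, 2.0, 3.0)
after = set_element(before, 1, 9.0)
print(before, after)
```

Because each operation yields a new value rather than writing through a handle, the lifetime of a matrix follows directly from the uses of its SSA value.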
11681173

11691174
```llvm
1170-
declare void @dx.op.matrixStoreToDescriptor(
1171-
immarg i32, ; opcode
1172-
%dx.types.MatrixRef, ; matrix
1173-
%dx.types.Handle, ; ByteAddressBuffer
1174-
i32, ; Offset
1175-
i32, ; Stride
1176-
i32, ; matrix layout
1175+
declare void @dx.op.matrixStoreToDescriptor.[MatTy](
1176+
immarg i32, ; opcode
1177+
%dx.types.LinAlgMatrix<mangling>, ; matrix
1178+
%dx.types.Handle, ; ByteAddressBuffer
1179+
i32, ; Offset
1180+
i32, ; Stride
1181+
i32, ; matrix layout
11771182
)
11781183
```
11791184

@@ -1185,13 +1190,13 @@ Validation rules will enforce that:
11851190
* `Layout` is `RowMajor` or `ColMajor`
11861191

11871192
```llvm
1188-
declare void @dx.op.matrixStoreToMemory.p[Ty](
1189-
immarg i32, ; opcode
1190-
%dx.types.MatrixRef, ; matrix
1191-
[Ty] *, ; groupshared T[M * N]
1192-
i32, ; Offset
1193-
i32, ; Stride
1194-
i32, ; matrix layout
1193+
declare void @dx.op.matrixStoreToMemory.[MatTy].[Ty](
1194+
immarg i32, ; opcode
1195+
%dx.types.LinAlgMatrix<mangling>, ; matrix
1196+
[Ty] *, ; groupshared T[M * N]
1197+
i32, ; Offset
1198+
i32, ; Stride
1199+
i32, ; matrix layout
11951200
)
11961201
```
11971202

@@ -1215,11 +1220,10 @@ layout while a return value of `1` will denote that accumulator matrices are `B`
12151220
layout.
12161221

12171222
```llvm
1218-
declare void @dx.op.matrixMulOp(
1219-
immarg i32, ; opcode
1220-
%dx.types.MatrixRef, ; matrix A
1221-
%dx.types.MatrixRef, ; matrix B
1222-
%dx.types.MatrixRef ; matrix C
1223+
declare %dx.types.LinAlgMatrix<mangling> @dx.op.matrixMulOp.[MatTyC].[MatTyA].[MatTyB](
1224+
immarg i32, ; opcode
1225+
%dx.types.LinAlgMatrix<mangling>, ; matrix A
1226+
%dx.types.LinAlgMatrix<mangling> ; matrix B
12231227
)
12241228
```
12251229

@@ -1230,7 +1234,7 @@ Two opcodes are available for this operation class:
12301234
Validation rules will enforce that:
12311235
* argument A is an `A` matrix
12321236
* argument B is a `B` matrix
1233-
* argument C is an `Accumulator` matrix
1237+
* return value (C) is an `Accumulator` matrix
12341238
* All three matrices have the same scope (Wave or ThreadGroup)
12351239
* Matrix A's dimensions shall be M x K
12361240
* Matrix B's dimensions shall be K x N
@@ -1241,31 +1245,32 @@ Must be called from wave-uniform control flow.
12411245

12421246

12431247
```llvm
1244-
declare void @dx.op.matrixAccumulate(
1245-
immarg i32, ; opcode
1246-
%dx.types.MatrixRef, ; matrix RHS
1247-
%dx.types.MatrixRef, ; matrix LHS
1248+
declare %dx.types.LinAlgMatrix<mangling> @dx.op.matrixAccumulate.[MatTyC].[MatTyLHS].[MatTyRHS](
1249+
immarg i32, ; opcode
1250+
%dx.types.LinAlgMatrix<mangling>, ; matrix LHS
1251+
%dx.types.LinAlgMatrix<mangling>, ; matrix RHS
12481252
)
12491253
```
12501254

12511255
This operation accumulates an `A` or `B` matrix into an accumulator following
1252-
the form `LHS += RHS`.
1256+
the form `LHS = LHS + RHS`.
12531257

12541258
Validation rules will enforce that:
12551259
* Argument RHS is an `A` or `B` matrix
12561260
* Argument LHS is an `Accumulator` matrix
1261+
* Type of LHS is the same as the return type
12571262
* Both matrices have the same scope (Wave or ThreadGroup)
12581263
* Both matrices have the same dimensions
12591264
* The element types are compatible
12601265

12611266
Must be called from wave-uniform control flow.
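The validation rules for `matrixAccumulate` can be sketched as a simple checker. This is a hypothetical Python model, not the validator's implementation: `MatrixType` and `check_accumulate` are names invented here, the use values 0, 1 and 2 come from the mangling examples earlier in the document, and the element type compatibility rule is left out because its details are not defined in this section.

```python
from typing import List, NamedTuple

# Use values as they appear in the mangling examples (U0 = A, U1 = B, U2 = Accumulator).
USE_A, USE_B, USE_ACCUMULATOR = 0, 1, 2


class MatrixType(NamedTuple):
    """Hypothetical model of one LinAlgMatrix type parameterization."""
    component: int
    m: int
    n: int
    use: int
    scope: int


def check_accumulate(ret: MatrixType, lhs: MatrixType, rhs: MatrixType) -> List[str]:
    """Collect violations of the rules listed above.

    Element type compatibility is intentionally not checked in this sketch.
    """
    errors = []
    if rhs.use not in (USE_A, USE_B):
        errors.append("RHS must be an A or B matrix")
    if lhs.use != USE_ACCUMULATOR:
        errors.append("LHS must be an Accumulator matrix")
    if ret != lhs:
        errors.append("return type must match the LHS type")
    if lhs.scope != rhs.scope:
        errors.append("LHS and RHS must have the same scope")
    if (lhs.m, lhs.n) != (rhs.m, rhs.n):
        errors.append("LHS and RHS must have the same dimensions")
    return errors


acc = MatrixType(11, 16, 16, USE_ACCUMULATOR, 1)
a = MatrixType(10, 16, 16, USE_A, 1)
print(check_accumulate(acc, acc, a))
```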
12621267

12631268
```llvm
1264-
declare <[NUMo] x [TYo]> @dx.op.matvecmul.v[NUMo][TYo].v[NUMi][TYi](
1265-
immarg i32, ; opcode
1266-
%dx.types.MatrixRef, ; matrix A
1267-
<[NUMi] x [TYi]>, ; input vector
1268-
immarg i32 ; input interpretation type (DXILComponentType)
1269+
declare <[NUMo] x [TYo]> @dx.op.matvecmul.v[NUMo][TYo].[MatTy].v[NUMi][TYi](
1270+
immarg i32, ; opcode
1271+
%dx.types.LinAlgMatrix<mangling>, ; matrix A
1272+
<[NUMi] x [TYi]>, ; input vector
1273+
immarg i32 ; input interpretation type (DXILComponentType)
12691274
)
12701275
```
12711276

@@ -1277,13 +1282,13 @@ Validation will enforce that:
12771282
* The matrix A is an `A` matrix of `Thread` scope
12781283

12791284
```llvm
1280-
declare <[NUMo] x [TYo]> @dx.op.matvecmuladd.v[NUMo][TYo].v[NUMi][TYi].v[NUMo][TYb](
1281-
immarg i32, ; opcode
1282-
%dx.types.MatrixRef, ; matrix A
1283-
<[NUMi] x [TYi]>, ; input vector
1284-
immarg i32, ; input interpretation type (DXILComponentType)
1285-
<[NUMo] x [TYb]>, ; bias vector
1286-
immarg i32 ; bias interpretation type (DXILComponentType)
1285+
declare <[NUMo] x [TYo]> @dx.op.matvecmuladd.v[NUMo][TYo].[MatTy].v[NUMi][TYi].v[NUMo][TYb](
1286+
immarg i32, ; opcode
1287+
%dx.types.LinAlgMatrix<mangling>, ; matrix A
1288+
<[NUMi] x [TYi]>, ; input vector
1289+
immarg i32, ; input interpretation type (DXILComponentType)
1290+
<[NUMo] x [TYb]>, ; bias vector
1291+
immarg i32 ; bias interpretation type (DXILComponentType)
12871292
)
12881293
```
12891294

@@ -1296,13 +1301,13 @@ Validation will enforce that:
12961301
* The matrix A is an `A` matrix of `Thread` scope
12971302

12981303
```llvm
1299-
declare void @dx.op.matrixAccumulateToDescriptor(
1300-
immarg i32, ; opcode
1301-
%dx.types.MatrixRef, ; matrix
1302-
%dx.types.Handle, ; RWByteAddressBuffer
1303-
i32, ; Offset
1304-
i32, ; Stride
1305-
i32 ; matrix layout
1304+
declare void @dx.op.matrixAccumulateToDescriptor.[MatTy](
1305+
immarg i32, ; opcode
1306+
%dx.types.LinAlgMatrix<mangling>, ; matrix
1307+
%dx.types.Handle, ; RWByteAddressBuffer
1308+
i32, ; Offset
1309+
i32, ; Stride
1310+
i32 ; matrix layout
13061311
)
13071312
```
13081313

@@ -1320,13 +1325,13 @@ Validation rules will enforce that:
13201325
* `Stride` is `0` if the `Layout` is not `RowMajor` or `ColMajor`
13211326

13221327
```llvm
1323-
declare void @dx.op.matrixAccumulateToMemory.p[Ty](
1324-
immarg i32, ; opcode
1325-
%dx.types.MatrixRef, ; matrix
1326-
[Ty] *, ; groupshared T[M * N]
1327-
i32, ; Offset
1328-
i32, ; Stride
1329-
i32 ; matrix layout
1328+
declare void @dx.op.matrixAccumulateToMemory.[MatTy].p[Ty](
1329+
immarg i32, ; opcode
1330+
%dx.types.LinAlgMatrix<mangling>, ; matrix
1331+
[Ty] *, ; groupshared T[M * N]
1332+
i32, ; Offset
1333+
i32, ; Stride
1334+
i32 ; matrix layout
13301335
)
13311336
```
13321337

@@ -1339,9 +1344,8 @@ The validator will ensure that the group shared target memory is large enough
13391344
for the write.
13401345

13411346
```llvm
1342-
declare void @dx.op.matrixOuterProduct.v[M][TY].v[N][TY](
1347+
declare %dx.types.LinAlgMatrix<mangling> @dx.op.matrixOuterProduct.[MatTy].v[M][TY].v[N][TY](
13431348
immarg i32, ; opcode
1344-
%dx.types.MatrixRef, ; matrix
13451349
<[M] x [Ty]>, ; vector A
13461350
<[N] x [Ty]> ; vector B
13471351
)
@@ -1358,14 +1362,7 @@ Validation will ensure that:
13581362
* The element type of the matrix argument matches the element type of the input
13591363
vectors, or the input vectors are `i32` if the matrix uses types not directly
13601364
representable in DXIL.
1361-
1362-
#### DXIL Validation
1363-
1364-
Each use of a `MatrixRef` argument to a DXIL operation must be traceable to a
1365-
single `AttributedMatrixRef` object returned from a unique `createMatrix`
1366-
operation. This validation rule is similar to the rules enforced for local
1367-
resource objects, and allows trivial identification of the lifetime of any
1368-
matrix (from `createMatrix` to its last use).
1365+
* The element type of vector A and vector B must be the same.
13691366

13701367
#### Bounds Checking Behavior
13711368
