@@ -339,8 +339,8 @@ void OuterProdAccum() {
339339### HLSL API Concepts
340340
341341The new HLSL API introduces a new ` linalg::Matrix ` type which represents an
342- opaque matrix object, and contains an intangible handle that refers to the
343- allocated matrix.
342+ opaque matrix object, and contains an intangible value object that refers to the
343+ matrix.
344344
345345The ` linalg::Matrix ` template type is parameterized based on the matrix
346346component data type, dimensions, use, and scope. These parameters restrict where
@@ -443,7 +443,10 @@ In HLSL, matrix objects are intangible objects so they do not have defined size
443443or memory layout. When in use, implementations are expected to distribute the
444444storage of matrices across the thread-local storage for all threads in a SIMD
445445unit. An implementation may also utilize caches or other memory regions as
446- appropriate. At the DXIL level a matrix is represented as a handle object.
446+ appropriate. At the DXIL level a matrix is represented as a value object.
447+ Because LLVM 3.7 doesn't allow value objects of opaque types, the matrix object
448+ stores a pointer in the IR, but implementations will replace this with an
449+ implementation-defined object.
447450
448451An A matrix is a collection of per-thread vectors representing matrix rows,
449452while a B matrix is a collection of per-thread vectors representing matrix
@@ -1022,45 +1025,50 @@ enum class DXILComponentType {
10221025}
10231026```
10241027
1025- This feature also adds a matrix ref that serves as an opaque type handle to the
1026- implementation's representation of the matrix.
1027-
1028-
1029- ```llvm
1030- %dx.types.MatrixRef = type { i8 * }
1031- ```
1032-
1033- The compiler will also generate a permutation of typed matrix handles with names
1034- of the format ` %dx.types.AttributedMatrixRef<mangling> ` . The mangling scheme for
1028+ The compiler will generate a permutation of typed matrix handles with names of
1029+ the format `%dx.types.LinAlgMatrix<mangling>`. The mangling scheme for
10351030each type name will capture the type parameterization with the tokens `C`,
10361031`M`, `N`, `U` and `S` denoting each encoded property.
10371032
10381033```
10391034 ; Matrix<ComponentType::F16, 16, 16, MatrixUse::A, MatrixScope::Wave>
1040- %dx.types.AttributedMatrixRefC10M16N16U0S1 = type { i8 * }
1035+ %dx.types.LinAlgMatrixC10M16N16U0S1 = type { i8 * }
10411036 ; Matrix<ComponentType::F16, 16, 16, MatrixUse::B, MatrixScope::Wave>
1042- %dx.types.AttributedMatrixRefC10M16N16U1S1 = type { i8 * }
1037+ %dx.types.LinAlgMatrixC10M16N16U1S1 = type { i8 * }
10431038 ; Matrix<ComponentType::F32, 16, 16, MatrixUse::Accumulator, MatrixScope::Wave>
1044- %dx.types.AttributedMatrixRefC11M16N16U2S1 = type { i8 * }
1039+ %dx.types.LinAlgMatrixC11M16N16U2S1 = type { i8 * }
10451040```
10461041
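The mangling shown above can be sketched as a small helper. This is an illustrative sketch only, not the compiler's implementation: the function name `mangle_matrix_type` is hypothetical, and the numeric encodings (F16 as 10, F32 as 11, `A`/`B`/`Accumulator` as 0/1/2, `Wave` as 1) are read off the examples rather than taken from an authoritative enum listing.

```python
def mangle_matrix_type(component: int, m: int, n: int, use: int, scope: int) -> str:
    """Build a name of the form %dx.types.LinAlgMatrixC<c>M<m>N<n>U<u>S<s>.

    Hypothetical helper: the arguments are the raw integer encodings of the
    component type, dimensions, use, and scope.
    """
    fields = (("component", component), ("M", m), ("N", n), ("use", use), ("scope", scope))
    for label, value in fields:
        if value < 0:
            raise ValueError(f"{label} must be non-negative, got {value}")
    return f"%dx.types.LinAlgMatrixC{component}M{m}N{n}U{use}S{scope}"


# Encodings assumed from the examples: F16 = 10, MatrixUse::A = 0, Wave = 1.
print(mangle_matrix_type(10, 16, 16, 0, 1))
```

Because every parameter appears in the name, two matrices with different parameterizations always get distinct IR types.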
1047- DXIL validation will enforce that an ` AttributedMatrixRef ` of any type may be
1048- bitcast to a ` %dx.types.MatrixRef ` , but the inverse cast will be disallowed.
1042+ DXIL validation will enforce that a `LinAlgMatrix` of any type may not
1043+ be bitcast to any other type.
1044+
1045+ #### Type Metadata
1046+
1047+ A new named metadata node, `dx.targetTypes`, will be added to map
1048+ attributed matrix types to their type parameters, so that consumers do not need to parse the
1049+ type mangling. For the examples above, metadata of the following form will be
1050+ generated:
10491051
1050- ### DXIL Operations
10511052
1052- ``` llvm
1053- declare %dx.types.AttributedMatrixRef<mangling> @dx.op.createMatrix<mangling>(
1054- immarg i32 ; opcode
1055- )
10561053```
1054+ !dx.targetTypes = !{!1, !2, !3}
1055+ ; Matrix<ComponentType::F16, 16, 16, MatrixUse::A, MatrixScope::Wave>
1056+ !1 = !{%dx.types.LinAlgMatrixC10M16N16U0S1 undef, i32 10, i32 16, i32 16, i32 0, i32 1 }
1057+ ; Matrix<ComponentType::F16, 16, 16, MatrixUse::B, MatrixScope::Wave>
1058+ !2 = !{%dx.types.LinAlgMatrixC10M16N16U1S1 undef, i32 10, i32 16, i32 16, i32 1, i32 1 }
1059+ ; Matrix<ComponentType::F32, 16, 16, MatrixUse::Accumulator, MatrixScope::Wave>
1060+ !3 = !{%dx.types.LinAlgMatrixC11M16N16U2S1 undef, i32 11, i32 16, i32 16, i32 2, i32 1 }
1061+ ```
1062+
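To illustrate what the metadata spares a consumer from doing, the sketch below parses the `C`, `M`, `N`, `U` and `S` tokens out of a type name and prints a metadata node carrying the same integers. It is a hedged illustration: `parse_mangled` and `metadata_node` are hypothetical names, and the regular expression assumes the tokens always appear in exactly that order with decimal values.

```python
import re

# Assumed name shape: %dx.types.LinAlgMatrixC<c>M<m>N<n>U<u>S<s>
_PATTERN = re.compile(r"^%dx\.types\.LinAlgMatrixC(\d+)M(\d+)N(\d+)U(\d+)S(\d+)$")


def parse_mangled(name: str) -> tuple:
    """Return (component, M, N, use, scope) decoded from a mangled type name."""
    match = _PATTERN.match(name)
    if match is None:
        raise ValueError(f"not a LinAlgMatrix type name: {name!r}")
    return tuple(int(group) for group in match.groups())


def metadata_node(index: int, name: str) -> str:
    """Format a metadata node whose operands mirror the mangled parameters."""
    operands = ", ".join(f"i32 {value}" for value in parse_mangled(name))
    return f"!{index} = !{{{name} undef, {operands}}}"


print(metadata_node(1, "%dx.types.LinAlgMatrixC10M16N16U0S1"))
```

With `dx.targetTypes` present, a consumer can read the operands directly and skip the string parsing step entirely.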
1063+ > Note: to ease compatibility with modern LLVM, the metadata should avoid
1064+ > encoding pointers, since modern LLVM converts pointers to opaque pointers
1065+ > and loses the type information.
10571066
1058- Creates a new uninitialized matrix handle.
1067+ ### DXIL Operations
10591068
10601069```llvm
1061- declare void @dx.op.fillMatrix.[TY](
1070+ declare %dx.types.LinAlgMatrix<mangling> @dx.op.fillMatrix.[MatTy].[TY](
10621071 immarg i32, ; opcode
1063- %dx.types.MatrixRef, ; matrix
10641072 [Ty] ; fill value
10651073 )
10661074```
@@ -1070,11 +1078,10 @@ matrix component's type, a type conversion is applied following the rules
10701078documented in the [ Conversions] ( #conversions ) section.
10711079
10721080``` llvm
1073- declare void @dx.op.copyConvertMatrix(
1074- immarg i32, ; opcode
1075- %dx.types.MatrixRef, ; matrix destination
1076- %dx.types.MatrixRef, ; matrix source
1077- immarg i1, ; transpose
1081+ declare %dx.types.LinAlgMatrix<mangling> @dx.op.copyConvertMatrix.[MatTy1].[MatTy2](
1082+ immarg i32, ; opcode
1083+ %dx.types.LinAlgMatrix<mangling>, ; matrix source
1084+ immarg i1 ; transpose
10781085 )
10791086```
10801087
@@ -1084,9 +1091,8 @@ unmodified after this operation is applied. Validation shall enforce that both
10841091matrices have the same scope and dimensions.
10851092
10861093``` llvm
1087- declare void @dx.op.matrixLoadFromDescriptor(
1094+ declare %dx.types.LinAlgMatrix<mangling> @dx.op.matrixLoadFromDescriptor.[MatTy](
10881095 immarg i32, ; opcode
1089- %dx.types.MatrixRef, ; matrix
10901096 %dx.types.Handle, ; ByteAddressBuffer
10911097 i32, ; Offset
10921098 i32, ; Stride
@@ -1106,9 +1112,8 @@ Validation rules will enforce that:
11061112* ` Stride ` is ` 0 ` if the ` Layout ` is not ` RowMajor ` or ` ColMajor `
11071113
11081114``` llvm
1109- declare void @dx.op.matrixLoadFromMemory.p [Ty](
1115+ declare %dx.types.LinAlgMatrix<mangling> @dx.op.matrixLoadFromMemory.[MatTy].[Ty](
11101116 immarg i32, ; opcode
1111- %dx.types.MatrixRef, ; matrix
11121117 [Ty] * addrspace(4), ; groupshared T[M * N]
11131118 i32, ; Offset
11141119 i32, ; Stride
@@ -1121,31 +1126,31 @@ between opaque matrices and groupshared memory are defined in the
11211126[ Conversions] ( #conversions ) section below.
11221127
11231128``` llvm
1124- declare i32 @dx.op.matrixLength(
1125- immarg i32, ; opcode
1126- %dx.types.MatrixRef ; matrix
1129+ declare i32 @dx.op.matrixLength.[MatTy](
1130+ immarg i32, ; opcode
1131+ %dx.types.LinAlgMatrix<mangling> ; matrix
11271132 )
11281133```
11291134
11301135Returns the number of elements stored in thread-local storage on the active
11311136thread for the provided matrix.
11321137
11331138``` llvm
1134- declare <2 x i32> @dx.op.matrixGetCoordinate(
1135- immarg i32, ; opcode
1136- %dx.types.MatrixRef , ; matrix
1137- i32 ; thread-local index
1139+ declare <2 x i32> @dx.op.matrixGetCoordinate.[MatTy](
1140+ immarg i32, ; opcode
1141+ %dx.types.LinAlgMatrix<mangling>, ; matrix
1142+ i32 ; thread-local index
11381143 )
11391144```
11401145
11411146Returns a two-element vector containing the column and row of the matrix that
11421147the thread-local index corresponds to.
11431148
11441149``` llvm
1145- declare [Ty] @dx.op.matrixGetElement.[Ty](
1146- immarg i32, ; opcode
1147- %dx.types.MatrixRef , ; matrix
1148- i32 ; thread-local index
1150+ declare [Ty] @dx.op.matrixGetElement.[Ty].[MatTy](
1151+ immarg i32, ; opcode
1152+ %dx.types.LinAlgMatrix<mangling>, ; matrix
1153+ i32 ; thread-local index
11491154 )
11501155```
11511156
@@ -1154,11 +1159,11 @@ If the index is out of range for the values stored in this thread the result is
115411590 .
11551160
11561161``` llvm
1157- declare void @dx.op.matrixSetElement.[Ty](
1158- immarg i32, ; opcode
1159- %dx.types.MatrixRef , ; matrix
1160- i32, ; thread-local index
1161- [Ty] ; value
1162+ declare %dx.types.LinAlgMatrix<mangling> @dx.op.matrixSetElement.[MatTy].[MatTy].[Ty](
1163+ immarg i32, ; opcode
1164+ %dx.types.LinAlgMatrix<mangling>, ; input matrix
1165+ i32, ; thread-local index
1166+ [Ty] ; value
11621167 )
11631168```
11641169
@@ -1167,13 +1172,13 @@ to the value provided. If the index is out of range for the values stored in
11671172this thread the result is a no-op.
11681173
11691174``` llvm
1170- declare void @dx.op.matrixStoreToDescriptor(
1171- immarg i32, ; opcode
1172- %dx.types.MatrixRef , ; matrix
1173- %dx.types.Handle, ; ByteAddressBuffer
1174- i32, ; Offset
1175- i32, ; Stride
1176- i32, ; matrix layout
1175+ declare void @dx.op.matrixStoreToDescriptor.[MatTy](
1176+ immarg i32, ; opcode
1177+ %dx.types.LinAlgMatrix<mangling>, ; matrix
1178+ %dx.types.Handle, ; ByteAddressBuffer
1179+ i32, ; Offset
1180+ i32, ; Stride
1181+ i32 ; matrix layout
11771182 )
11781183```
11791184
@@ -1185,13 +1190,13 @@ Validation rules will enforce that:
11851190* ` Layout ` is ` RowMajor ` or ` ColMajor `
11861191
11871192``` llvm
1188- declare void @dx.op.matrixStoreToMemory.p [Ty](
1189- immarg i32, ; opcode
1190- %dx.types.MatrixRef , ; matrix
1191- [Ty] *, ; groupshared T[M * N]
1192- i32, ; Offset
1193- i32, ; Stride
1194- i32, ; matrix layout
1193+ declare void @dx.op.matrixStoreToMemory.[MatTy].[Ty](
1194+ immarg i32, ; opcode
1195+ %dx.types.LinAlgMatrix<mangling>, ; matrix
1196+ [Ty] *, ; groupshared T[M * N]
1197+ i32, ; Offset
1198+ i32, ; Stride
1199+ i32 ; matrix layout
11951200 )
11961201```
11971202
@@ -1215,11 +1220,10 @@ layout while a return value of `1` will denote that accumulator matrices are `B`
12151220layout.
12161221
12171222``` llvm
1218- declare void @dx.op.matrixMulOp(
1219- immarg i32, ; opcode
1220- %dx.types.MatrixRef, ; matrix A
1221- %dx.types.MatrixRef, ; matrix B
1222- %dx.types.MatrixRef ; matrix C
1223+ declare %dx.types.LinAlgMatrix<mangling> @dx.op.matrixMulOp.[MatTyC].[MatTyA].[MatTyB](
1224+ immarg i32, ; opcode
1225+ %dx.types.LinAlgMatrix<mangling>, ; matrix A
1226+ %dx.types.LinAlgMatrix<mangling> ; matrix B
12231227 )
12241228```
12251229
@@ -1230,7 +1234,7 @@ Two opcodes are available for this operation class:
12301234Validation rules will enforce that:
12311235* argument A is an ` A ` matrix
12321236* argument B is a ` B ` matrix
1233- * argument C is an ` Accumulator ` matrix
1237+ * return value (C) is an ` Accumulator ` matrix
12341238* All three matrices have the same scope (Wave or ThreadGroup)
12351239* Matrix A's dimensions shall be M x K
12361240* Matrix B's dimensions shall be K x N
@@ -1241,31 +1245,32 @@ Must be called from wave-uniform control flow.
12411245
12421246
12431247``` llvm
1244- declare void @dx.op.matrixAccumulate(
1245- immarg i32, ; opcode
1246- %dx.types.MatrixRef , ; matrix RHS
1247- %dx.types.MatrixRef , ; matrix LHS
1248+ declare %dx.types.LinAlgMatrix<mangling> @dx.op.matrixAccumulate.[MatTyC].[MatTyLHS].[MatTyRHS](
1249+ immarg i32, ; opcode
1250+ %dx.types.LinAlgMatrix<mangling>, ; matrix LHS
1251+ %dx.types.LinAlgMatrix<mangling> ; matrix RHS
12481252 )
12491253```
12501254
12511255This operation accumulates an ` A ` or ` B ` matrix into an accumulator following
1252- the form ` LHS += RHS ` .
1256+ the form ` LHS = LHS + RHS ` .
12531257
12541258Validation rules will enforce that:
12551259* Argument RHS is an ` A ` or ` B ` matrix
12561260* Argument LHS is an ` Accumulator ` matrix
1261+ * Type of LHS is the same as the return type
12571262* Both matrices have the same scope (Wave or ThreadGroup)
12581263* Both matrices have the same dimensions
12591264* The element types are compatible
12601265
12611266Must be called from wave-uniform control flow.
12621267
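The `matrixAccumulate` validation rules listed above can be read as a set of checks over the type parameters of the return value, LHS, and RHS. The following is a rough sketch under stated assumptions, not the validator's actual logic: `MatrixType` and `accumulate_errors` are hypothetical names, the use encodings (0, 1, 2) are inferred from the mangled type examples, and the element type compatibility rule is left out because its definition is not given in this section.

```python
from dataclasses import dataclass

# Use encodings assumed from the mangled names (U0 = A, U1 = B, U2 = Accumulator).
USE_A, USE_B, USE_ACCUMULATOR = 0, 1, 2


@dataclass(frozen=True)
class MatrixType:
    """Hypothetical mirror of the LinAlgMatrix type parameters."""
    component: int
    rows: int
    cols: int
    use: int
    scope: int


def accumulate_errors(result: MatrixType, lhs: MatrixType, rhs: MatrixType) -> list:
    """Return a list of violated rules; an empty list means the call passes."""
    errors = []
    if rhs.use not in (USE_A, USE_B):
        errors.append("RHS must be an A or B matrix")
    if lhs.use != USE_ACCUMULATOR:
        errors.append("LHS must be an Accumulator matrix")
    if result != lhs:
        errors.append("return type must match the LHS type")
    if lhs.scope != rhs.scope:
        errors.append("LHS and RHS must have the same scope")
    if (lhs.rows, lhs.cols) != (rhs.rows, rhs.cols):
        errors.append("LHS and RHS must have the same dimensions")
    return errors


acc = MatrixType(component=11, rows=16, cols=16, use=USE_ACCUMULATOR, scope=1)
a = MatrixType(component=10, rows=16, cols=16, use=USE_A, scope=1)
print(accumulate_errors(acc, acc, a))
```

Collecting every violation, rather than stopping at the first one, lets a single diagnostic report all of the problems with a call.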
12631268``` llvm
1264- declare <[NUMo] x [TYo]> @dx.op.matvecmul.v[NUMo][TYo].v[NUMi][TYi](
1265- immarg i32, ; opcode
1266- %dx.types.MatrixRef , ; matrix A
1267- <[NUMi] x [TYi]>, ; input vector
1268- immarg i32 ; input interpretation type (DXILComponentType)
1269+ declare <[NUMo] x [TYo]> @dx.op.matvecmul.v[NUMo][TYo].[MatTy].v[NUMi][TYi](
1270+ immarg i32, ; opcode
1271+ %dx.types.LinAlgMatrix<mangling>, ; matrix A
1272+ <[NUMi] x [TYi]>, ; input vector
1273+ immarg i32 ; input interpretation type (DXILComponentType)
12691274)
12701275```
12711276
@@ -1277,13 +1282,13 @@ Validation will enforce that:
12771282* The matrix A is an ` A ` matrix of ` Thread ` scope
12781283
12791284``` llvm
1280- declare <[NUMo] x [TYo]> @dx.op.matvecmuladd.v[NUMo][TYo].v[NUMi][TYi].v[NUMo][TYb](
1281- immarg i32, ; opcode
1282- %dx.types.MatrixRef , ; matrix A
1283- <[NUMi] x [TYi]>, ; input vector
1284- immarg i32, ; input interpretation type (DXILComponentType)
1285- <[NUMo] x [TYb]>, ; bias vector
1286- immarg i32 ; bias interpretation type (DXILComponentType)
1285+ declare <[NUMo] x [TYo]> @dx.op.matvecmuladd.v[NUMo][TYo].[MatTy].v[NUMi][TYi].v[NUMo][TYb](
1286+ immarg i32, ; opcode
1287+ %dx.types.LinAlgMatrix<mangling>, ; matrix A
1288+ <[NUMi] x [TYi]>, ; input vector
1289+ immarg i32, ; input interpretation type (DXILComponentType)
1290+ <[NUMo] x [TYb]>, ; bias vector
1291+ immarg i32 ; bias interpretation type (DXILComponentType)
12871292)
12881293```
12891294
@@ -1296,13 +1301,13 @@ Validation will enforce that:
12961301* The matrix A is an ` A ` matrix of ` Thread ` scope
12971302
12981303``` llvm
1299- declare void @dx.op.matrixAccumulateToDescriptor(
1300- immarg i32, ; opcode
1301- %dx.types.MatrixRef , ; matrix
1302- %dx.types.Handle, ; RWByteAddressBuffer
1303- i32, ; Offset
1304- i32, ; Stride
1305- i32 ; matrix layout
1304+ declare void @dx.op.matrixAccumulateToDescriptor.[MatTy](
1305+ immarg i32, ; opcode
1306+ %dx.types.LinAlgMatrix<mangling>, ; matrix
1307+ %dx.types.Handle, ; RWByteAddressBuffer
1308+ i32, ; Offset
1309+ i32, ; Stride
1310+ i32 ; matrix layout
13061311 )
13071312```
13081313
@@ -1320,13 +1325,13 @@ Validation rules will enforce that:
13201325* ` Stride ` is ` 0 ` if the ` Layout ` is not ` RowMajor ` or ` ColMajor `
13211326
13221327``` llvm
1323- declare void @dx.op.matrixAccumulateToMemory.p[Ty](
1324- immarg i32, ; opcode
1325- %dx.types.MatrixRef , ; matrix
1326- [Ty] *, ; groupshared T[M * N]
1327- i32, ; Offset
1328- i32, ; Stride
1329- i32 ; matrix layout
1328+ declare void @dx.op.matrixAccumulateToMemory.[MatTy].p[Ty](
1329+ immarg i32, ; opcode
1330+ %dx.types.LinAlgMatrix<mangling>, ; matrix
1331+ [Ty] *, ; groupshared T[M * N]
1332+ i32, ; Offset
1333+ i32, ; Stride
1334+ i32 ; matrix layout
13301335 )
13311336```
13321337
@@ -1339,9 +1344,8 @@ The validator will ensure that the group shared target memory is large enough
13391344for the write.
13401345
13411346``` llvm
1342- declare void @dx.op.matrixOuterProduct.v[M][TY].v[N][TY](
1347+ declare %dx.types.LinAlgMatrix<mangling> @dx.op.matrixOuterProduct.[MatTy].v[M][TY].v[N][TY](
13431348 immarg i32, ; opcode
1344- %dx.types.MatrixRef, ; matrix
13451349 <[M] x [Ty]>, ; vector A
13461350 <[N] x [Ty]> ; vector B
13471351 )
@@ -1358,14 +1362,7 @@ Validation will ensure that:
13581362* The element type of the matrix argument matches the element type of the input
13591363 vectors, or the input vectors are ` i32 ` if the matrix uses types not directly
13601364 representable in DXIL.
1361-
1362- #### DXIL Validation
1363-
1364- Each use of a ` MatrixRef ` argument to a DXIL operation must be tracable to a
1365- single ` AttributedMatrixRef ` object returned from a unique ` createMatrix `
1366- operation. This validation rule is similar to the rules enforced for local
1367- resource objects, and allows trivial identification of the lifetime of any
1368- matrix (from ` createMatrix ` to its last use).
1365+ * The element type of vector A and vector B must be the same.
13691366
13701367#### Bounds Checking Behavior
13711368