diff --git a/tools/clang/include/clang/Basic/DiagnosticGroups.td b/tools/clang/include/clang/Basic/DiagnosticGroups.td index ff21b34652..d1b14a7de9 100644 --- a/tools/clang/include/clang/Basic/DiagnosticGroups.td +++ b/tools/clang/include/clang/Basic/DiagnosticGroups.td @@ -807,4 +807,5 @@ def HLSLAvailability: DiagGroup<"hlsl-availability">; def HLSLAvailabilityConstant: DiagGroup<"hlsl-availability-constant">; def HLSLBarrier : DiagGroup<"hlsl-barrier">; def HLSLLegacyLiterals : DiagGroup<"hlsl-legacy-literal">; +def HLSLGroupshared202x : DiagGroup<"hlsl-groupshared-202x">; // HLSL Change Ends diff --git a/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td b/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td index ce6d6b17cc..d8dcd6c8f9 100644 --- a/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -1597,9 +1597,10 @@ def warn_hlsl_sometimes_uninit_out_param : Warning< "its declaration is reached|" "%3 is called}2">, InGroup; -def warn_hlsl_groupshared_202x: Warning< - "Support for groupshared parameter annotation not added until HLSL 202x">, - InGroup; +def warn_hlsl_groupshared_202x + : Warning<"Support for groupshared parameter annotation not added until " + "HLSL 202x">, + InGroup; def warn_hlsl_groupshared_inout: Warning< "Passing groupshared variable to a parameter annotated with inout. See " "'groupshared' parameter annotation added in 202x.">, diff --git a/tools/clang/lib/Headers/hlsl/dx/linalg.h b/tools/clang/lib/Headers/hlsl/dx/linalg.h index 9ed72fa485..230ac77cbe 100644 --- a/tools/clang/lib/Headers/hlsl/dx/linalg.h +++ b/tools/clang/lib/Headers/hlsl/dx/linalg.h @@ -4,10 +4,57 @@ (__SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR >= 10)) && \ (__HLSL_VERSION >= 2021) +#pragma dxc diagnostic push +#pragma dxc diagnostic ignored "-Whlsl-groupshared-202x" + namespace hlsl { #define SIZE_TYPE int +template struct is_arithmetic { + static const bool value = false; +}; + +#define __ARITHMETIC_TYPE(type) \ + template <> struct is_arithmetic { \ + static const bool value = true; \ + }; + +#if __HLSL_ENABLE_16_BIT +__ARITHMETIC_TYPE(uint16_t) +__ARITHMETIC_TYPE(int16_t) +#endif +__ARITHMETIC_TYPE(uint) +__ARITHMETIC_TYPE(int) +__ARITHMETIC_TYPE(uint64_t) +__ARITHMETIC_TYPE(int64_t) +__ARITHMETIC_TYPE(half) +__ARITHMETIC_TYPE(float) +__ARITHMETIC_TYPE(double) + +template struct is_signed { + static const bool value = true; +}; + +#define __UNSIGNED_TYPE(type) \ + template <> struct is_signed { \ + static const bool value = false; \ + }; + +#if __HLSL_ENABLE_16_BIT +__UNSIGNED_TYPE(uint16_t) +#endif +__UNSIGNED_TYPE(uint) +__UNSIGNED_TYPE(uint64_t) + +#undef __UNSIGNED_TYPE + +template struct enable_if {}; + +template struct enable_if { + using type = T; +}; + } // namespace hlsl namespace dxil { @@ -82,7 +129,11 @@ struct ComponentType { using ComponentEnum = ComponentType::ComponentEnum; struct MatrixUse { - enum MatrixUseEnum { A = 0, B = 1, Accumulator = 2 }; + enum MatrixUseEnum { + A = 0, + B = 1, + Accumulator = 2, + }; }; using MatrixUseEnum = MatrixUse::MatrixUseEnum; @@ -95,16 +146,394 @@ struct MatrixScope { }; using MatrixScopeEnum = MatrixScope::MatrixScopeEnum; +struct MatrixLayout { + enum MatrixLayoutEnum { + RowMajor = 0, + ColMajor = 1, + MulOptimal = 2, + OuterProductOptimal = 3, + }; +}; +using MatrixLayoutEnum = MatrixLayout::MatrixLayoutEnum; + +namespace __detail { +template struct ComponentTypeTraits { + using Type = uint; + static const bool IsNativeScalar = false; + static const uint ElementsPerScalar = 4; +}; + +#define __MATRIX_SCALAR_COMPONENT_MAPPING(enum_val, type) \ + template <> struct ComponentTypeTraits { \ + using Type = type; \ + static const bool IsNativeScalar = true; \ + static const uint ElementsPerScalar = 1; \ + }; + +#if __HLSL_ENABLE_16_BIT +__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::I16, int16_t) +__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::U16, uint16_t) +__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::F16, float16_t) +#endif + +__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::I32, int32_t) +__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::U32, uint32_t) +__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::F32, float) +__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::I64, int64_t) +__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::U64, uint64_t) +__MATRIX_SCALAR_COMPONENT_MAPPING(ComponentType::F64, double) + +} // namespace __detail + +template struct VectorRef { + ByteAddressBuffer Buf; + uint Offset; +}; + +template struct InterpretedVector { + vector Data; + static const ComponentEnum Interpretation = DT; + static const SIZE_TYPE Size = + __detail::ComponentTypeTraits
::ElementsPerScalar * N; +}; + +template +InterpretedVector MakeInterpretedVector(vector Vec) { + InterpretedVector IV = {Vec}; + return IV; +} + template class Matrix { + using ElementType = typename __detail::ComponentTypeTraits::Type; + // If this isn't a native scalar, we have an 8-bit type, so we have 4 elements + // packed in each scalar value. + static const uint ElementsPerScalar = + __detail::ComponentTypeTraits::ElementsPerScalar; + static const bool IsNativeScalar = + __detail::ComponentTypeTraits::IsNativeScalar; + using HandleT = __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentTy, M, N, Use, Scope)]]; HandleT __handle; + + template + Matrix Cast() { + Matrix Result; + __builtin_LinAlg_CopyConvertMatrix(Result.__handle, __handle, Transpose); + return Result; + } + + template + static typename hlsl::enable_if::value, Matrix>::type + Splat(T Val) { + Matrix Result; + __builtin_LinAlg_FillMatrix(Result.__handle, Val); + return Result; + } + + static Matrix Load(ByteAddressBuffer Res, uint StartOffset, uint Stride, + MatrixLayoutEnum Layout, + uint Align = sizeof(ElementType)) { + Matrix Result; + __builtin_LinAlg_MatrixLoadFromDescriptor(Result.__handle, Res, StartOffset, + Stride, Layout, Align); + return Result; + } + + static Matrix Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride, + MatrixLayoutEnum Layout, + uint Align = sizeof(ElementType)) { + Matrix Result; + __builtin_LinAlg_MatrixLoadFromDescriptor(Result.__handle, Res, StartOffset, + Stride, Layout, Align); + return Result; + } + + template + static typename hlsl::enable_if::value && + (M * N / ElementsPerScalar <= Size), + Matrix>::type + Load(groupshared T Arr[Size], uint StartIdx, uint Stride, + MatrixLayoutEnum Layout) { + Matrix Result; + __builtin_LinAlg_MatrixLoadFromMemory(Result.__handle, Arr, StartIdx, + Stride, Layout); + return Result; + } + + template + typename hlsl::enable_if::type + Length() { + return __builtin_LinAlg_MatrixLength(__handle); + } + + template + typename hlsl::enable_if::type + GetCoordinate(uint Index) { + return __builtin_LinAlg_MatrixGetCoordinate(__handle, Index); + } + + template + typename hlsl::enable_if::type + Get(uint Index) { + ElementType Result; + __builtin_LinAlg_MatrixGetElement(Result, __handle, Index); + return Result; + } + + template + typename hlsl::enable_if::type + Set(uint Index, ElementType Value) { + __builtin_LinAlg_MatrixSetElement(__handle, __handle, Index, Value); + } + + void Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride, + MatrixLayoutEnum Layout, uint Align = sizeof(ElementType)) { + __builtin_LinAlg_MatrixStoreToDescriptor(__handle, Res, StartOffset, Stride, + Layout, Align); + } + + template + typename hlsl::enable_if::value && + (M * N / ElementsPerScalar <= Size), + void>::type + Store(groupshared T Arr[Size], uint StartIdx, uint Stride, + MatrixLayoutEnum Layout) { + __builtin_LinAlg_MatrixStoreToMemory(__handle, Arr, StartIdx, Stride, + Layout); + } + + // Accumulate methods + template + typename hlsl::enable_if::type + InterlockedAccumulate(RWByteAddressBuffer Res, uint StartOffset, uint Stride, + MatrixLayoutEnum Layout, + uint Align = sizeof(ElementType)) { + __builtin_LinAlg_MatrixAccumulateToDescriptor(__handle, Res, StartOffset, + Stride, Layout, Align); + } + + template + typename hlsl::enable_if< + hlsl::is_arithmetic::value && Use == MatrixUse::Accumulator && + UseLocal == Use && (M * N / ElementsPerScalar <= Size), + void>::type + InterlockedAccumulate(groupshared T Arr[Size], uint StartIdx, uint Stride, + MatrixLayoutEnum Layout) { + __builtin_LinAlg_MatrixAccumulateToMemory(__handle, Arr, StartIdx, Stride, + Layout); + } + + template + typename hlsl::enable_if::type + Accumulate(const Matrix MatrixA) { + __builtin_LinAlg_MatrixAccumulate(__handle, __handle, MatrixA.__handle); + } + + template + typename hlsl::enable_if::type + Accumulate(const Matrix MatrixB) { + __builtin_LinAlg_MatrixAccumulate(__handle, __handle, MatrixB.__handle); + } + + template + typename hlsl::enable_if::type + MultiplyAccumulate(const Matrix MatrixA, + const Matrix MatrixB) { + __builtin_LinAlg_MatrixMatrixMultiplyAccumulate( + __handle, __handle, MatrixA.__handle, MatrixB.__handle); + } +}; + +// Thread-scope Matrices are read-only. Using a template partial +// specialization for this simplifies the SFINAE-foo above. +template +class Matrix { + using ElementType = typename __detail::ComponentTypeTraits::Type; + + using HandleT = __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes( + ComponentTy, M, N, Use, MatrixScope::Thread)]]; + HandleT __handle; + + template + static Matrix Load(ByteAddressBuffer Res, uint StartOffset, uint Stride, + uint Align = sizeof(ElementType)) { + Matrix Result; + __builtin_LinAlg_MatrixLoadFromDescriptor(Result.__handle, Res, StartOffset, + Stride, Layout, Align); + return Result; + } + + void InterlockedAccumulate(RWByteAddressBuffer Res, uint StartOffset, + uint Stride, MatrixLayoutEnum Layout, + uint Align = sizeof(ElementType)) { + __builtin_LinAlg_MatrixAccumulateToDescriptor(__handle, Res, StartOffset, + Stride, Layout, Align); + } }; +MatrixUseEnum AccumulatorLayout() { + return (MatrixUseEnum)(__builtin_LinAlg_MatrixQueryAccumulatorLayout()); +} + +template +Matrix +Multiply(const Matrix MatrixA, + const Matrix MatrixB) { + Matrix Result; + __builtin_LinAlg_MatrixMatrixMultiply(Result.__handle, MatrixA.__handle, + MatrixB.__handle); + return Result; +} + +template +Matrix +Multiply(const Matrix MatrixA, + const Matrix MatrixB) { + Matrix Result; + __builtin_LinAlg_MatrixMatrixMultiply(Result.__handle, MatrixA.__handle, + MatrixB.__handle); + return Result; +} + +template +Matrix Multiply( + const Matrix MatrixA, + const Matrix MatrixB) { + Matrix Result; + __builtin_LinAlg_MatrixMatrixMultiply(Result.__handle, MatrixA.__handle, + MatrixB.__handle); + return Result; +} + +template +Matrix Multiply( + const Matrix MatrixA, + const Matrix + MatrixB) { + Matrix Result; + __builtin_LinAlg_MatrixMatrixMultiply(Result.__handle, MatrixA.__handle, + MatrixB.__handle); + return Result; +} + +// Cooperative Vector Replacement API +// Cooperative Vector operates on per-thread vectors multiplying against B +// matrices with thread scope. + +template +// clang-format off +typename hlsl::enable_if::value, vector >::type +// clang-format on +Multiply(Matrix MatrixA, + vector Vec) { + vector Result; + __builtin_LinAlg_MatrixVectorMultiply(Result, MatrixA.__handle, + hlsl::is_signed::value, Vec, + MatrixDT); + return Result; +} + +template +// clang-format off +typename hlsl::enable_if::value, vector >::type +// clang-format on +MultiplyAdd(Matrix MatrixA, + vector Vec, vector Bias) { + vector Result; + __builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle, + hlsl::is_signed::value, + Vec, MatrixDT, Bias, MatrixDT); + return Result; +} + +template +// clang-format off +typename hlsl::enable_if< + InterpretedVector::Size == M, + vector >::type +// clang-format on +MultiplyAdd(Matrix MatrixA, + InterpretedVector InterpVec, + vector Bias) { + vector Result; + __builtin_LinAlg_MatrixVectorMultiplyAdd( + Result, MatrixA.__handle, hlsl::is_signed::value, + InterpVec.Data, InterpVec.Interpretation, Bias, MatrixDT); + return Result; +} + +template +// clang-format off +typename hlsl::enable_if::value, + vector >::type +// clang-format on +MultiplyAdd(Matrix MatrixA, + vector Vec, VectorRef BiasRef) { + using BiasVecTy = + vector::Type, K>; + BiasVecTy BiasVec = BiasRef.Buf.template Load(BiasRef.Offset); + vector Result; + __builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle, + hlsl::is_signed::value, + Vec, MatrixDT, BiasVec, BiasElTy); + return Result; +} + +template +// clang-format off +typename hlsl::enable_if< + InterpretedVector::Size == M, + vector >::type +// clang-format on +MultiplyAdd(Matrix MatrixA, + InterpretedVector InterpVec, + VectorRef BiasRef) { + using BiasVecTy = + vector::Type, K>; + BiasVecTy BiasVec = BiasRef.Buf.template Load(BiasRef.Offset); + vector Result; + __builtin_LinAlg_MatrixVectorMultiplyAdd( + Result, MatrixA.__handle, hlsl::is_signed::value, + InterpVec.Data, InterpVec.Interpretation, BiasVec, BiasElTy); + return Result; +} + +// Outer product functions +template +Matrix +OuterProduct(vector VecA, vector VecB) { + Matrix Result; + __builtin_LinAlg_MatrixOuterProduct(Result.__handle, VecA, VecB); + return Result; +} + } // namespace linalg } // namespace dx +#pragma dxc diagnostic pop + #endif // SM 6.10 check and HV version check diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index 30f27fcb6f..61e4f28318 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -7792,6 +7792,7 @@ unsigned HLSLExternalSource::GetNumElements(QualType anyType) { case AR_TOBJ_BASIC: case AR_TOBJ_OBJECT: case AR_TOBJ_STRING: + case AR_TOBJ_LINALG_MATRIX: return 1; case AR_TOBJ_COMPOUND: { // TODO: consider caching this value for perf @@ -7927,6 +7928,7 @@ QualType HLSLExternalSource::GetNthElementType(QualType type, unsigned index) { case AR_TOBJ_BASIC: case AR_TOBJ_OBJECT: case AR_TOBJ_STRING: + case AR_TOBJ_LINALG_MATRIX: return (index == 0) ? type : QualType(); case AR_TOBJ_COMPOUND: { // TODO: consider caching this value for perf diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/matrix-class.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/matrix-class.hlsl new file mode 100644 index 0000000000..a657d3b96f --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/matrix-class.hlsl @@ -0,0 +1,182 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -I %hlsl_headers -T cs_6_10 %s | FileCheck %s + +#include +using namespace dx::linalg; + +using MatrixATy = Matrix; +using MatrixBTy = Matrix; +using MatrixBTyInt = Matrix; +using MatrixAccumTy = Matrix; +using TSMatrixATy = Matrix; + +ByteAddressBuffer BAB : register(t0); +RWByteAddressBuffer RWBAB : register(u0); +groupshared float SharedArr[256]; + +[numthreads(4, 4, 4)] +void main(uint ID : SV_GroupID) +{ + +// CHECK: %[[GROUP_ID:.*]] = call i32 @dx.op.groupId.i32(i32 94, i32 0) ; GroupId(component) + +// Matrix::Splat +// +// CHECK: %[[MATA1:.*]] = call %dx.types.LinAlgMatrixC9M4N4U0S1 @dx.op.linAlgFillMatrix.mC9M4N4U0S1.f32( +// CHECK-SAME: i32 -2147483636, float 1.000000e+00) + MatrixATy MatA1 = MatrixATy::Splat(1.0f); + +// CHECK: %[[MATB1:.*]] = call %dx.types.LinAlgMatrixC9M4N4U1S1 @dx.op.linAlgFillMatrix.mC9M4N4U1S1.f32( +// CHECK-SAME: i32 -2147483636, float 2.000000e+00) + MatrixBTy MatB1; + MatB1 = MatrixBTy::Splat(2.0f); + +// Matrix::Cast +// +// CHECK: call %dx.types.LinAlgMatrixC4M4N4U1S1 @dx.op.linAlgCopyConvertMatrix.mC4M4N4U1S1.mC9M4N4U0S1( +// CHECK-SAME: i32 -2147483635, %dx.types.LinAlgMatrixC9M4N4U0S1 %[[MATA1]], i1 false) +// CHECK-SAME: ; LinAlgCopyConvertMatrix(srcMatrix,transpose) + MatrixBTyInt MatBInt1 = MatA1.Cast(); + +// CHECK: call %dx.types.LinAlgMatrixC4M4N4U1S1 @dx.op.linAlgCopyConvertMatrix.mC4M4N4U1S1.mC9M4N4U1S1( +// CHECK-SAME: i32 -2147483635, %dx.types.LinAlgMatrixC9M4N4U1S1 %[[MATB1]], i1 true) +// CHECK-SAME: ; LinAlgCopyConvertMatrix(srcMatrix,transpose) + MatrixBTyInt MatBInt2; + MatBInt2 = MatB1.Cast(); + +// Matrix::Load from ByteAddressBuffer +// +// CHECK: %[[MATA2:.*]] = call %dx.types.LinAlgMatrixC9M4N4U0S1 +// CHECK-SAME: @dx.op.linAlgMatrixLoadFromDescriptor.mC9M4N4U0S1(i32 -2147483634, +// CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 4) +// CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align) + MatrixATy MatA2 = MatrixATy::Load(BAB, 0, 16, MatrixLayoutEnum::ColMajor); + +// Matrix::Load from RWByteAddressBuffer +// +// CHECK: %[[MATB2:.*]] = call %dx.types.LinAlgMatrixC9M4N4U1S1 +// CHECK-SAME: @dx.op.linAlgMatrixLoadFromDescriptor.mC9M4N4U1S1(i32 -2147483634, +// CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 256, i32 16, i32 1, i32 4) +// CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align) + MatrixBTy MatB2; + MatB2 = MatrixBTy::Load(RWBAB, 256, 16, MatrixLayoutEnum::ColMajor); + +// Matrix::Load from groupshared memory +// +// CHECK: %[[MATB3:.*]] = call %dx.types.LinAlgMatrixC9M4N4U1S1 +// CHECK-SAME: @dx.op.linAlgMatrixLoadFromMemory.mC9M4N4U1S1.f32(i32 -2147483633, +// CHECK-SAME: float addrspace(3)* getelementptr inbounds ([256 x float], +// CHECK-SAME: [256 x float] addrspace(3)* @"\01?SharedArr@@3PAMA", i32 0, i32 0), +// CHECK-SAME: i32 0, i32 16, i32 1) ; LinAlgMatrixLoadFromMemory(memory,offset,stride,layout) + MatrixBTy MatB3 = MatrixBTy::Load(SharedArr, 0, 16, MatrixLayoutEnum::ColMajor); + +// Matrix::Length +// +// CHECK: call i32 @dx.op.linAlgMatrixLength.mC9M4N4U0S1(i32 -2147483632, +// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U0S1 %[[MATA1]]) ; LinAlgMatrixLength + uint len = MatA1.Length(); + +// Matrix::GetCoordinate +// +// CHECK: call <2 x i32> @dx.op.linAlgMatrixGetCoordinate.mC9M4N4U1S1(i32 -2147483631, +// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U1S1 %[[MATB1]], i32 %[[GROUP_ID]]) +// CHECK-SAME:; LinAlgMatrixGetCoordinate(matrix,threadLocalIndex) + uint2 coord = MatB1.GetCoordinate(ID); + +// Matrix::Get +// +// CHECK: %[[VAL:.*]] = call float @dx.op.linAlgMatrixGetElement.f32.mC9M4N4U0S1(i32 -2147483630, +// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U0S1 %[[MATA1]], i32 %[[GROUP_ID]]) +// CHECK-SAME:; LinAlgMatrixGetElement(matrix,threadLocalIndex) + float val = MatA1.Get(ID); + +// Matrix::Set +// +// CHECK: %[[MATB1_2:.*]] = call %dx.types.LinAlgMatrixC9M4N4U1S1 +// CHECK-SAME: @dx.op.linAlgMatrixSetElement.mC9M4N4U1S1.mC9M4N4U1S1.f32( +// CHECK-SAME: i32 -2147483629, %dx.types.LinAlgMatrixC9M4N4U1S1 %[[MATB1]], +// CHECK-SAME: i32 %[[GROUP_ID]], float %[[VAL]]) ; LinAlgMatrixSetElement(matrix,threadLocalIndex,value) + MatB1.Set(ID, val); + +// Matrix::Store to resource descriptor +// +// CHECK: call void @dx.op.linAlgMatrixStoreToDescriptor.mC9M4N4U1S1(i32 -2147483628, +// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U1S1 %[[MATB1_2]], %dx.types.Handle %{{[0-9]+}}, +// CHECK-SAME: i32 256, i32 16, i32 1, i32 4) ; +// CHECK-SAME: LinAlgMatrixStoreToDescriptor(matrix,handle,offset,stride,layout,align) + MatB1.Store(RWBAB, 256, 16, MatrixLayoutEnum::ColMajor); + +// Matrix::Store to groupshared memory +// +// CHECK: call void @dx.op.linAlgMatrixStoreToMemory.mC9M4N4U1S1.f32(i32 -2147483627, +// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U1S1 %[[MATB2]], float addrspace(3)* getelementptr inbounds + // CHECK-SAME: ([256 x float], [256 x float] addrspace(3)* @"\01?SharedArr@@3PAMA", i32 0, i32 0), +// CHECK-SAME: i32 0, i32 16, i32 1) ; LinAlgMatrixStoreToMemory(matrix,memory,offset,stride,layout) + MatB2.Store(SharedArr, 0, 16, MatrixLayoutEnum::ColMajor); + +// CHECK: %[[ACCUM0:.*]] = call %dx.types.LinAlgMatrixC9M4N4U2S1 @dx.op.linAlgFillMatrix.mC9M4N4U2S1.f32( +// CHECK-SAME: i32 -2147483636, float 1.400000e+01) ; LinAlgFillMatrix(value) + MatrixAccumTy AccMat1 = MatrixAccumTy::Splat(14.0f); + +// Matrix::InterlockedAccumulate to resource descriptor +// +// CHECK: call void @dx.op.linAlgMatrixAccumulateToDescriptor.mC9M4N4U2S1(i32 -2147483621, +// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U2S1 %[[ACCUM0]], %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 4) +// CHECK-SAME: ; LinAlgMatrixAccumulateToDescriptor(matrix,handle,offset,stride,layout,align) + AccMat1.InterlockedAccumulate(RWBAB, 0, 16, MatrixLayoutEnum::ColMajor); + +// Matrix::InterlockedAccumulate to groupshared memory +// +// CHECK: call void @dx.op.linAlgMatrixAccumulateToMemory.mC9M4N4U2S1.f32(i32 -2147483620, +// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U2S1 %18, +// CHECK-SAME: float addrspace(3)* getelementptr inbounds ([256 x float], +// CHECK-SAME: [256 x float] addrspace(3)* @"\01?SharedArr@@3PAMA", i32 0, i32 0), i32 0, i32 16, i32 1) +// CHECK-SAME: ; LinAlgMatrixAccumulateToMemory(matrix,memory,offset,stride,layout) + AccMat1.InterlockedAccumulate(SharedArr, 0, 16, MatrixLayoutEnum::ColMajor); + +// Matrix::Accumulate +// +// CHECK: %[[ACCUM1:.*]] = call %dx.types.LinAlgMatrixC9M4N4U2S1 @dx.op.linAlgFillMatrix.mC9M4N4U2S1.f32( +// CHECK-SAME: i32 -2147483636, float 0.000000e+00) ; LinAlgFillMatrix(value) + MatrixAccumTy AccMat2 = MatrixAccumTy::Splat(0.0f); + +// CHECK: %[[ACCUM2:.*]] = call %dx.types.LinAlgMatrixC9M4N4U2S1 +// CHECK-SAME: @dx.op.linAlgMatrixAccumulate.mC9M4N4U2S1.mC9M4N4U2S1.mC9M4N4U0S1(i32 -2147483624, +// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U2S1 %[[ACCUM1]], +// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U0S1 %[[MATA2]]) ; LinAlgMatrixAccumulate(matrixLHS,matrixRHS) + AccMat2.Accumulate(MatA2); + +// CHECK: %[[ACCUM3:.*]] = call %dx.types.LinAlgMatrixC9M4N4U2S1 +// CHECK-SAME: @dx.op.linAlgMatrixAccumulate.mC9M4N4U2S1.mC9M4N4U2S1.mC9M4N4U1S1(i32 -2147483624, +// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U2S1 %[[ACCUM2]], +// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U1S1 %[[MATB2]]) +// CHECK-SAME: ; LinAlgMatrixAccumulate(matrixLHS,matrixRHS) + AccMat2.Accumulate(MatB2); + +// Matrix::MultiplyAccumulate +// +// CHECK: %[[ACCUM4:.*]] = call %dx.types.LinAlgMatrixC9M4N4U2S1 +// CHECK-SAME: @dx.op.linAlgMatrixMultiplyAccumulate.mC9M4N4U2S1.mC9M4N4U2S1.mC9M4N4U0S1.mC9M4N4U1S1(i32 -2147483637, +// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U2S1 %[[ACCUM3]], +// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U0S1 %[[MATA1]], +// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U1S1 %[[MATB1_2]]) +// CHECK-SAME: ; LinAlgMatrixMultiplyAccumulate(matrixA,matrixB,matrixC) + AccMat2.MultiplyAccumulate(MatA1, MatB1); + +// Matrix::Load for thread-scope matrix +// +// CHECK: %[[TSMATA:.*]] = call %dx.types.LinAlgMatrixC9M4N4U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC9M4N4U0S0( +// CHECK-SAME: i32 -2147483634, %dx.types.Handle %24, i32 0, i32 16, i32 1, i32 4) +// CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align) + TSMatrixATy TSMatA = TSMatrixATy::Load(BAB, 0, 16); + +// Matrix::InterlockedAccumulate for thread-scope matrix +// +// CHECK: call void @dx.op.linAlgMatrixAccumulateToDescriptor.mC9M4N4U0S0(i32 -2147483621, +// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U0S0 %25, %dx.types.Handle %26, i32 0, i32 16, i32 1, i32 4) +// CHECK-SAME: ; LinAlgMatrixAccumulateToDescriptor(matrix,handle,offset,stride,layout,align) + TSMatA.InterlockedAccumulate(RWBAB, 0, 16, MatrixLayoutEnum::ColMajor); + +// CHECK: call i32 @dx.op.linAlgMatrixQueryAccumulatorLayout(i32 -2147483626) ; LinAlgMatrixQueryAccumulatorLayout() + MatrixUseEnum layout = AccumulatorLayout(); +} diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/matrix-multiply.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/matrix-multiply.hlsl new file mode 100644 index 0000000000..c7357d3e93 --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/matrix-multiply.hlsl @@ -0,0 +1,73 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -I %hlsl_headers -T cs_6_10 %s | FileCheck %s + +#include +using namespace dx::linalg; + +[numthreads(4, 4, 4)] +void main() +{ +// Multiply in Wave scope +// + using MatrixAF16WTy = Matrix; + using MatrixAI32WTy = Matrix; + using MatrixBI32WTy = Matrix; + using MatrixAccF32WTy = Matrix; + using MatrixAccI32WTy = Matrix; + +// CHECK: %[[MATA1:.*]] = call %dx.types.LinAlgMatrixC8M3N4U0S1 @dx.op.linAlgFillMatrix.mC8M3N4U0S1.f32( +// CHECK-SAME: i32 -2147483636, float 1.500000e+00) ; LinAlgFillMatrix(value) + MatrixAF16WTy MatA1 = MatrixAF16WTy::Splat(1.5f); + +// CHECK: %[[MATA2:.*]] = call %dx.types.LinAlgMatrixC4M3N4U0S1 @dx.op.linAlgFillMatrix.mC4M3N4U0S1.i32( +// CHECK-SAME: i32 -2147483636, i32 45) ; LinAlgFillMatrix(value) + MatrixAI32WTy MatA2 = MatrixAI32WTy::Splat(45); + +// CHECK: %[[MATB1:.*]] = call %dx.types.LinAlgMatrixC4M4N5U1S1 @dx.op.linAlgFillMatrix.mC4M4N5U1S1.i32( +// CHECK-SAME: i32 -2147483636, i32 13) ; LinAlgFillMatrix(value) + MatrixBI32WTy MatB1 = MatrixBI32WTy::Splat(13); + +// CHECK: %[[MATC1:.*]] = call %dx.types.LinAlgMatrixC9M3N5U2S1 +// CHECK-SAME: @dx.op.linAlgMatrixMultiply.mC9M3N5U2S1.mC8M3N4U0S1.mC4M4N5U1S1(i32 -2147483625, +// CHECK-SAME: %dx.types.LinAlgMatrixC8M3N4U0S1 %[[MATA1]], %dx.types.LinAlgMatrixC4M4N5U1S1 %[[MATB1]]) +// CHECK-SAME: ; LinAlgMatrixMultiply(matrixA,matrixB) + MatrixAccF32WTy MatCFlt1 = Multiply(MatA1, MatB1); + +// CHECK: %[[MATC2:.*]] = call %dx.types.LinAlgMatrixC4M3N5U2S1 +// CHECK-SAME: @dx.op.linAlgMatrixMultiply.mC4M3N5U2S1.mC4M3N4U0S1.mC4M4N5U1S1(i32 -2147483625, +// CHECK-SAME: %dx.types.LinAlgMatrixC4M3N4U0S1 %2, %dx.types.LinAlgMatrixC4M4N5U1S1 %3) +// CHECK-SAME: ; LinAlgMatrixMultiply(matrixA,matrixB) + MatrixAccI32WTy MatCInt1 = Multiply(MatA2, MatB1); + +// Multiply in ThreadGroup scope +// + using MatrixAF16TGTy = Matrix; + using MatrixAI32TGTy = Matrix; + using MatrixBI32TGTy = Matrix; + using MatrixAccF32TGTy = Matrix; + using MatrixAccI32TGTy = Matrix; + +// CHECK: %[[MATA3:.*]] = call %dx.types.LinAlgMatrixC8M3N4U0S2 @dx.op.linAlgFillMatrix.mC8M3N4U0S2.f32( +// CHECK-SAME: i32 -2147483636, float 2.500000e+00) ; LinAlgFillMatrix(value) + MatrixAF16TGTy MatA3 = MatrixAF16TGTy::Splat(2.5f); + +// CHECK: %[[MATA4:.*]] = call %dx.types.LinAlgMatrixC4M3N4U0S2 @dx.op.linAlgFillMatrix.mC4M3N4U0S2.i32( +// CHECK-SAME: i32 -2147483636, i32 23) ; LinAlgFillMatrix(value) + MatrixAI32TGTy MatA4 = MatrixAI32TGTy::Splat(23); + +// CHECK: %[[MATB3:.*]] = call %dx.types.LinAlgMatrixC4M4N5U1S2 @dx.op.linAlgFillMatrix.mC4M4N5U1S2.i32( +// CHECK-SAME: i32 -2147483636, i32 7) ; LinAlgFillMatrix(value) + MatrixBI32TGTy MatB3 = MatrixBI32TGTy::Splat(7); + +// CHECK: %[[MATC3:.*]] = call %dx.types.LinAlgMatrixC9M3N5U2S2 +// CHECK-SAME: @dx.op.linAlgMatrixMultiply.mC9M3N5U2S2.mC8M3N4U0S2.mC4M4N5U1S2(i32 -2147483625, +// CHECK-SAME: %dx.types.LinAlgMatrixC8M3N4U0S2 %[[MATA3]], %dx.types.LinAlgMatrixC4M4N5U1S2 %[[MATB3]]) +// CHECK-SAME: ; LinAlgMatrixMultiply(matrixA,matrixB) + MatrixAccF32TGTy MatCFlt2 = Multiply(MatA3, MatB3); + +// CHECK: %[[MATC4:.*]] = call %dx.types.LinAlgMatrixC4M3N5U2S2 +// CHECK-SAME: @dx.op.linAlgMatrixMultiply.mC4M3N5U2S2.mC4M3N4U0S2.mC4M4N5U1S2(i32 -2147483625, +// CHECK-SAME: %dx.types.LinAlgMatrixC4M3N4U0S2 %[[MATA4]], %dx.types.LinAlgMatrixC4M4N5U1S2 %[[MATB3]]) +// CHECK-SAME: ; LinAlgMatrixMultiply(matrixA,matrixB) + MatrixAccI32TGTy MatCInt2 = Multiply(MatA4, MatB3); +} diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl new file mode 100644 index 0000000000..0fc9a27c6c --- /dev/null +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl @@ -0,0 +1,66 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -I %hlsl_headers -enable-16bit-types -T cs_6_10 %s | FileCheck %s + +#include +using namespace dx::linalg; + +using MatrixATy = Matrix; +using MatrixAccumTy = Matrix; + +ByteAddressBuffer BAB : register(t0); + +[numthreads(4, 4, 4)] +void main(uint ID : SV_GroupID) { + +// CHECK: %[[MAT1:.*]] = call %dx.types.LinAlgMatrixC8M8N8U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M8N8U0S0( +// CHECK-SAME: i32 -2147483634, %dx.types.Handle %2, i32 0, i32 8, i32 1, i32 2) +// CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align) + MatrixATy Mat1 = MatrixATy::Load(BAB, 0, 8); + + vector vec1 = 10.3f; + +// CHECK: %[[VEC2:.*]] = call <8 x half> @dx.op.linAlgMatVecMul.v8f16.mC8M8N8U0S0.v8f16(i32 -2147483623, +// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %3, i1 true, <8 x half> , i32 8) +// CHECK-SAME: ; LinAlgMatVecMul(matrix,isOutputSigned,inputVector,interpretation) + vector vec2 = Multiply(Mat1, vec1); + +// CHECK: %[[VEC3:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8f16(i32 -2147483622, +// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> , i32 8, <8 x half> %[[VEC2]], i32 8) +// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) + vector vec3 = MultiplyAdd(Mat1, vec1, vec2); + +// CHECK: %[[VEC4:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8f16(i32 -2147483622, +// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> %[[VEC2]], i32 8, <8 x half> %[[VEC3]], i32 8) +// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) + InterpretedVector interpVec2 = MakeInterpretedVector(vec2); + vector vec4 = MultiplyAdd(Mat1, interpVec2, vec3); + + // CHECK: %[[RAWLOAD:.*]] = call %dx.types.ResRet.v8i16 @dx.op.rawBufferVectorLoad.v8i16(i32 303, + // CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 4096, i32 undef, i32 2) ; RawBufferVectorLoad(buf,index,elementOffset,alignment) + + // CHECK: %[[VEC_BIAS:.*]] = extractvalue %dx.types.ResRet.v8i16 %[[RAWLOAD]], 0 + + // CHECK: %[[VEC5:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8i16(i32 -2147483622, + // CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> %[[VEC3]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2) + // CHECK-SAME:; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) + VectorRef memBias = {BAB, 4096}; + vector vec5 = MultiplyAdd(Mat1, vec3, memBias); + + // CHECK: %[[RAWLOAD:.*]] = call %dx.types.ResRet.v8i16 @dx.op.rawBufferVectorLoad.v8i16(i32 303, + // CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 4096, i32 undef, i32 2) + // CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment) + + // CHECK: %[[VEC_BIAS:.*]] = extractvalue %dx.types.ResRet.v8i16 %[[RAWLOAD]], 0 + + // CHECK: %[[VEC6:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N8U0S0.v8f16.v8i16(i32 -2147483622, + // CHECK-SAME: %dx.types.LinAlgMatrixC8M8N8U0S0 %[[MAT1]], i1 true, <8 x half> %[[VEC2]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2) + // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) + vector vec6 = MultiplyAdd(Mat1, interpVec2, memBias); + + // CHECK: %[[ACCUM:.*]] = call %dx.types.LinAlgMatrixC8M8N8U2S0 + // CHECK-SAME: @dx.op.linAlgMatrixOuterProduct.mC8M8N8U2S0.v8f16.v8f16(i32 -2147483619, + // CHECK-SAME: <8 x half> %[[VEC5]], <8 x half> %[[VEC6]]) ; LinAlgMatrixOuterProduct(vectorA,vectorB) + MatrixAccumTy AccumMatrix = OuterProduct(vec5, vec6); +} diff --git a/utils/hct/gen_intrin_main.txt b/utils/hct/gen_intrin_main.txt index 8dc2950442..239f381614 100644 --- a/utils/hct/gen_intrin_main.txt +++ b/utils/hct/gen_intrin_main.txt @@ -397,13 +397,13 @@ void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromMemory(out LinAlgMatrix ret, uint [[min_sm=6.10]] __builtin_LinAlg_MatrixLength(in LinAlgMatrix matrix); uint<2> [[min_sm=6.10]] __builtin_LinAlg_MatrixGetCoordinate(in LinAlgMatrix matrix, in uint threadLocalIndex); void [[min_sm=6.10]] __builtin_LinAlg_MatrixGetElement(out numeric ret, in LinAlgMatrix matrix, in uint threadLocalIndex); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixSetElement(out LinAlgMatrix ret, in LinAlgMatrix matrix, in uint threadLocalIndex, in numeric value); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixSetElement(ref LinAlgMatrix ret, in LinAlgMatrix matrix, in uint threadLocalIndex, in numeric value); void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align); void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(in LinAlgMatrix matrix, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout); uint [[min_sm=6.10]] __builtin_LinAlg_MatrixQueryAccumulatorLayout(); void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(out LinAlgMatrix matrixR, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB, in LinAlgMatrix matrixC); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(out LinAlgMatrix matrixC, in LinAlgMatrix matrixLHS, in LinAlgMatrix matrixRHS); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(ref LinAlgMatrix matrixR, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB, in LinAlgMatrix matrixC); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(ref LinAlgMatrix matrixC, in LinAlgMatrix matrixLHS, in LinAlgMatrix matrixRHS); void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<> input, in uint inputInterp); void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<> input, in uint inputInterp, in numeric<> bias, in uint biasInterp); void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align);