diff --git a/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl b/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl index 876edcca46..268e814e7e 100644 --- a/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl +++ b/tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl @@ -5,7 +5,8 @@ using namespace dx::linalg; using MatrixATy = Matrix; -using MatrixAccumTy = Matrix; +using MatrixAccum_8_8_Ty = Matrix; +using MatrixAccum_8_4_Ty = Matrix; ByteAddressBuffer BAB : register(t0); @@ -13,14 +14,14 @@ ByteAddressBuffer BAB : register(t0); void main(uint ID : SV_GroupID) { // CHECK: %[[MAT1:.*]] = call %dx.types.LinAlgMatrixC8M8N4U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M8N4U0S0( -// CHECK-SAME: i32 -2147483634, %dx.types.Handle %2, i32 0, i32 8, i32 1, i32 2) +// CHECK-SAME: i32 -2147483634, %dx.types.Handle %{{[0-9]+}}, i32 0, i32 8, i32 1, i32 2) // CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align) MatrixATy Mat1 = MatrixATy::Load(BAB, 0, 8); vector vec1 = 10.3f; // CHECK: %[[VEC2:.*]] = call <8 x half> @dx.op.linAlgMatVecMul.v8f16.mC8M8N4U0S0.v4f16(i32 -2147483623, -// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %3, i1 true, <4 x half> , i32 8) ; LinAlgMatVecMul(matrix,isOutputSigned,inputVector,interpretation) vector vec2 = Multiply(Mat1, vec1); @@ -61,13 +62,24 @@ void main(uint ID : SV_GroupID) { // CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation) vector vec6 = MultiplyAdd(Mat1, interpVec2, memBias); - // CHECK: %[[ACCUM:.*]] = call %dx.types.LinAlgMatrixC8M8N8U2S0 + // CHECK: %[[ACCUM1:.*]] = call %dx.types.LinAlgMatrixC8M8N8U2S0 // CHECK-SAME: @dx.op.linAlgMatrixOuterProduct.mC8M8N8U2S0.v8f16.v8f16(i32 -2147483619, // CHECK-SAME: <8 x half> %[[VEC5]], <8 x half> %[[VEC6]]) ; LinAlgMatrixOuterProduct(vectorA,vectorB) - MatrixAccumTy AccumMatrix = OuterProduct(vec5, vec6); + MatrixAccum_8_8_Ty AccumMatrix1 = OuterProduct(vec5, vec6); + + 
// CHECK: %[[ACCUM2:.*]] = call %dx.types.LinAlgMatrixC8M8N4U2S0 @dx.op.linAlgMatrixOuterProduct.mC8M8N4U2S0.v8f16.v4f16( + // CHECK-SAME: i32 -2147483619, <8 x half> %[[VEC5]], <4 x half> %[[VEC20]]) ; LinAlgMatrixOuterProduct(vectorA,vectorB) + MatrixAccum_8_4_Ty AccumMatrix2 = OuterProduct(vec5, vec20); // CHECK: %[[CONV_VEC:.*]] = call <8 x float> @dx.op.linAlgConvert.v8f32.v8f16(i32 -2147483618, // CHECK-SAME: <8 x half> %[[VEC6]], i32 8, i32 9) ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation) InterpretedVector convertedVec; convertedVec = Convert(vec6); + + // CHECK: call <4 x i32> @dx.op.linAlgConvert.v4i32.v16f16(i32 -2147483618, <16 x half> %{{[0-9]+}}, i32 8, i32 21) + // CHECK-SAME: ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation) + typedef vector half16; + half16 srcF16 = BAB.Load(128); + InterpretedVector convertedPacked = Convert(srcF16); + } diff --git a/utils/hct/gen_intrin_main.txt b/utils/hct/gen_intrin_main.txt index 3075e8a4cf..657b3e942f 100644 --- a/utils/hct/gen_intrin_main.txt +++ b/utils/hct/gen_intrin_main.txt @@ -408,8 +408,8 @@ void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric ret, i void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric input, in uint inputInterp, in numeric bias, in uint biasInterp); void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align); void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric<> vecA, in numeric<> vecB); -void [[min_sm=6.10]] __builtin_LinAlg_Convert(out numeric<> ret, in numeric<> vec, in uint input_interp, in uint output_interp); +void [[min_sm=6.10]] 
__builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric vecA, in numeric vecB); +void [[min_sm=6.10]] __builtin_LinAlg_Convert(out numeric ret, in numeric vec, in uint input_interp, in uint output_interp); } namespace