Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,23 @@
using namespace dx::linalg;

using MatrixATy = Matrix<ComponentType::F16, 8, 4, MatrixUse::A, MatrixScope::Thread>;
using MatrixAccumTy = Matrix<ComponentType::F16, 8, 8, MatrixUse::Accumulator, MatrixScope::Thread>;
using MatrixAccum_8_8_Ty = Matrix<ComponentType::F16, 8, 8, MatrixUse::Accumulator, MatrixScope::Thread>;
using MatrixAccum_8_4_Ty = Matrix<ComponentType::F16, 8, 4, MatrixUse::Accumulator, MatrixScope::Thread>;

ByteAddressBuffer BAB : register(t0);

[numthreads(4, 4, 4)]
void main(uint ID : SV_GroupID) {

// CHECK: %[[MAT1:.*]] = call %dx.types.LinAlgMatrixC8M8N4U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M8N4U0S0(
// CHECK-SAME: i32 -2147483634, %dx.types.Handle %2, i32 0, i32 8, i32 1, i32 2)
// CHECK-SAME: i32 -2147483634, %dx.types.Handle %{{[0-9]+}}, i32 0, i32 8, i32 1, i32 2)
// CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
MatrixATy Mat1 = MatrixATy::Load<MatrixLayoutEnum::ColMajor>(BAB, 0, 8);

vector<half, 4> vec1 = 10.3f;

// CHECK: %[[VEC2:.*]] = call <8 x half> @dx.op.linAlgMatVecMul.v8f16.mC8M8N4U0S0.v4f16(i32 -2147483623,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %3, i1 true, <4 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
// CHECK-SAME: half 0xH4926>, i32 8) ; LinAlgMatVecMul(matrix,isOutputSigned,inputVector,interpretation)
vector<half, 8> vec2 = Multiply<half>(Mat1, vec1);

Expand Down Expand Up @@ -61,13 +62,24 @@ void main(uint ID : SV_GroupID) {
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 8> vec6 = MultiplyAdd<half>(Mat1, interpVec2, memBias);

// CHECK: %[[ACCUM:.*]] = call %dx.types.LinAlgMatrixC8M8N8U2S0
// CHECK: %[[ACCUM1:.*]] = call %dx.types.LinAlgMatrixC8M8N8U2S0
// CHECK-SAME: @dx.op.linAlgMatrixOuterProduct.mC8M8N8U2S0.v8f16.v8f16(i32 -2147483619,
// CHECK-SAME: <8 x half> %[[VEC5]], <8 x half> %[[VEC6]]) ; LinAlgMatrixOuterProduct(vectorA,vectorB)
MatrixAccumTy AccumMatrix = OuterProduct<ComponentType::F16>(vec5, vec6);
MatrixAccum_8_8_Ty AccumMatrix1 = OuterProduct<ComponentType::F16>(vec5, vec6);

// CHECK: %[[ACCUM2:.*]] = call %dx.types.LinAlgMatrixC8M8N4U2S0 @dx.op.linAlgMatrixOuterProduct.mC8M8N4U2S0.v8f16.v4f16(
// CHECK-SAME: i32 -2147483619, <8 x half> %[[VEC5]], <4 x half> %[[VEC20]]) ; LinAlgMatrixOuterProduct(vectorA,vectorB)
MatrixAccum_8_4_Ty AccumMatrix2 = OuterProduct<ComponentType::F16>(vec5, vec20);

// CHECK: %[[CONV_VEC:.*]] = call <8 x float> @dx.op.linAlgConvert.v8f32.v8f16(i32 -2147483618,
// CHECK-SAME: <8 x half> %[[VEC6]], i32 8, i32 9) ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation)
InterpretedVector<float, 8, ComponentType::F32> convertedVec;
convertedVec = Convert<ComponentType::F32, ComponentType::F16>(vec6);

// CHECK: call <4 x i32> @dx.op.linAlgConvert.v4i32.v16f16(i32 -2147483618, <16 x half> %{{[0-9]+}}, i32 8, i32 21)
// CHECK-SAME: ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation)
typedef vector<half, 16> half16;
half16 srcF16 = BAB.Load<half16>(128);
InterpretedVector<uint, 4, ComponentEnum::F8_E4M3FN> convertedPacked = Convert<ComponentEnum::F8_E4M3FN, ComponentEnum::F16>(srcF16);

}
4 changes: 2 additions & 2 deletions utils/hct/gen_intrin_main.txt
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,8 @@ void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<c> ret, i
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<c> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<c2> input, in uint inputInterp, in numeric<c> bias, in uint biasInterp);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric<> vecA, in numeric<> vecB);
void [[min_sm=6.10]] __builtin_LinAlg_Convert(out numeric<> ret, in numeric<> vec, in uint input_interp, in uint output_interp);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric<c> vecA, in numeric<c2> vecB);
void [[min_sm=6.10]] __builtin_LinAlg_Convert(out numeric<c> ret, in numeric<c2> vec, in uint input_interp, in uint output_interp);

} namespace

Expand Down
Loading