Skip to content

Commit 0d53746

Browse files
authored
[HLSL][Matrix] Add support for ICK_HLSL_Matrix_Splat to add splat cast of scalars (#170885)
fixes #168960 Adds `ICK_HLSL_Matrix_Splat` and hooks it up to `PerformImplicitConversion` and `IsMatrixConversion`. Map these to `CK_HLSLAggregateSplatCast`.
1 parent 1307b77 commit 0d53746

File tree

8 files changed

+126
-3
lines changed

8 files changed

+126
-3
lines changed

clang/include/clang/Sema/Overload.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,9 @@ class Sema;
207207
// HLSL vector splat from scalar or boolean type.
208208
ICK_HLSL_Vector_Splat,
209209

210+
/// HLSL matrix splat from scalar or boolean type.
211+
ICK_HLSL_Matrix_Splat,
212+
210213
/// The number of conversion kinds
211214
ICK_Num_Conversion_Kinds,
212215
};

clang/include/clang/Sema/Sema.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7944,6 +7944,10 @@ class Sema final : public SemaBase {
79447944
/// implicit casts if necessary.
79457945
ExprResult prepareVectorSplat(QualType VectorTy, Expr *SplattedExpr);
79467946

7947+
/// Prepare `SplattedExpr` for a matrix splat operation, adding
7948+
/// implicit casts if necessary.
7949+
ExprResult prepareMatrixSplat(QualType MatrixTy, Expr *SplattedExpr);
7950+
79477951
// CheckExtVectorCast - check type constraints for extended vectors.
79487952
// Since vectors are an extension, there are no C standard reference for this.
79497953
// We allow casting between vectors and integer datatypes of the same size,

clang/lib/Sema/SemaExpr.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7834,6 +7834,26 @@ ExprResult Sema::prepareVectorSplat(QualType VectorTy, Expr *SplattedExpr) {
78347834
return ImpCastExprToType(SplattedExpr, DestElemTy, CK);
78357835
}
78367836

7837+
ExprResult Sema::prepareMatrixSplat(QualType MatrixTy, Expr *SplattedExpr) {
7838+
QualType DestElemTy = MatrixTy->castAs<MatrixType>()->getElementType();
7839+
7840+
if (DestElemTy == SplattedExpr->getType())
7841+
return SplattedExpr;
7842+
7843+
assert(DestElemTy->isFloatingType() ||
7844+
DestElemTy->isIntegralOrEnumerationType());
7845+
7846+
// TODO: Add support for boolean matrix once exposed
7847+
// https://github.com/llvm/llvm-project/issues/170920
7848+
ExprResult CastExprRes = SplattedExpr;
7849+
CastKind CK = PrepareScalarCast(CastExprRes, DestElemTy);
7850+
if (CastExprRes.isInvalid())
7851+
return ExprError();
7852+
SplattedExpr = CastExprRes.get();
7853+
7854+
return ImpCastExprToType(SplattedExpr, DestElemTy, CK);
7855+
}
7856+
78377857
ExprResult Sema::CheckExtVectorCast(SourceRange R, QualType DestTy,
78387858
Expr *CastExpr, CastKind &Kind) {
78397859
assert(DestTy->isExtVectorType() && "Not an extended vector type!");

clang/lib/Sema/SemaExprCXX.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5212,6 +5212,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
52125212
case ICK_HLSL_Vector_Truncation:
52135213
case ICK_HLSL_Matrix_Truncation:
52145214
case ICK_HLSL_Vector_Splat:
5215+
case ICK_HLSL_Matrix_Splat:
52155216
llvm_unreachable("Improper second standard conversion");
52165217
}
52175218

@@ -5231,6 +5232,15 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
52315232
.get();
52325233
break;
52335234
}
5235+
case ICK_HLSL_Matrix_Splat: {
5236+
// Matrix splat from any arithmetic type to a matrix.
5237+
Expr *Elem = prepareMatrixSplat(ToType, From).get();
5238+
From =
5239+
ImpCastExprToType(Elem, ToType, CK_HLSLAggregateSplatCast, VK_PRValue,
5240+
/*BasePath=*/nullptr, CCK)
5241+
.get();
5242+
break;
5243+
}
52345244
case ICK_HLSL_Vector_Truncation: {
52355245
// Note: HLSL built-in vectors are ExtVectors. Since this truncates a
52365246
// vector to a smaller vector or to a scalar, this can only operate on

clang/lib/Sema/SemaOverload.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ ImplicitConversionRank clang::GetConversionRank(ImplicitConversionKind Kind) {
165165
ICR_HLSL_Dimension_Reduction,
166166
ICR_Conversion,
167167
ICR_HLSL_Scalar_Widening,
168+
ICR_HLSL_Scalar_Widening,
168169
};
169170
static_assert(std::size(Rank) == (int)ICK_Num_Conversion_Kinds);
170171
return Rank[(int)Kind];
@@ -228,6 +229,7 @@ static const char *GetImplicitConversionName(ImplicitConversionKind Kind) {
228229
"HLSL matrix truncation",
229230
"Non-decaying array conversion",
230231
"HLSL vector splat",
232+
"HLSL matrix splat",
231233
};
232234
static_assert(std::size(Name) == (int)ICK_Num_Conversion_Kinds);
233235
return Name[Kind];
@@ -2145,6 +2147,15 @@ static bool IsMatrixConversion(Sema &S, QualType FromType, QualType ToType,
21452147
return true;
21462148
return IsVectorOrMatrixElementConversion(S, FromElTy, ToElTy, ICK, From);
21472149
}
2150+
2151+
// Matrix splat from any arithmetic type to a matrix.
2152+
if (ToMatrixType && FromType->isArithmeticType()) {
2153+
ElConv = ICK_HLSL_Matrix_Splat;
2154+
QualType ToElTy = ToMatrixType->getElementType();
2155+
return IsVectorOrMatrixElementConversion(S, FromType, ToElTy, ICK, From);
2156+
ICK = ICK_HLSL_Matrix_Splat;
2157+
return true;
2158+
}
21482159
if (FromMatrixType && !ToMatrixType) {
21492160
ElConv = ICK_HLSL_Matrix_Truncation;
21502161
QualType FromElTy = FromMatrixType->getElementType();
@@ -6301,6 +6312,7 @@ static bool CheckConvertedConstantConversions(Sema &S,
63016312
case ICK_SVE_Vector_Conversion:
63026313
case ICK_RVV_Vector_Conversion:
63036314
case ICK_HLSL_Vector_Splat:
6315+
case ICK_HLSL_Matrix_Splat:
63046316
case ICK_Vector_Splat:
63056317
case ICK_Complex_Real:
63066318
case ICK_Block_Pointer_Conversion:
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 6
2+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.7-library -disable-llvm-passes -emit-llvm -finclude-default-header -o - %s | FileCheck %s
3+
4+
// CHECK-LABEL: define hidden void @_Z13ConstantSplatv(
5+
// CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
6+
// CHECK-NEXT: [[ENTRY:.*:]]
7+
// CHECK-NEXT: [[M:%.*]] = alloca [16 x i32], align 4
8+
// CHECK-NEXT: store <16 x i32> splat (i32 1), ptr [[M]], align 4
9+
// CHECK-NEXT: ret void
10+
//
11+
void ConstantSplat() {
12+
int4x4 M = 1;
13+
}
14+
15+
// CHECK-LABEL: define hidden void @_Z18ConstantFloatSplatv(
16+
// CHECK-SAME: ) #[[ATTR0]] {
17+
// CHECK-NEXT: [[ENTRY:.*:]]
18+
// CHECK-NEXT: [[M:%.*]] = alloca [4 x float], align 4
19+
// CHECK-NEXT: store <4 x float> splat (float 3.250000e+00), ptr [[M]], align 4
20+
// CHECK-NEXT: ret void
21+
//
22+
void ConstantFloatSplat() {
23+
float2x2 M = 3.25;
24+
}
25+
26+
// CHECK-LABEL: define hidden void @_Z12DynamicSplatf(
27+
// CHECK-SAME: float noundef nofpclass(nan inf) [[VALUE:%.*]]) #[[ATTR0]] {
28+
// CHECK-NEXT: [[ENTRY:.*:]]
29+
// CHECK-NEXT: [[VALUE_ADDR:%.*]] = alloca float, align 4
30+
// CHECK-NEXT: [[M:%.*]] = alloca [9 x float], align 4
31+
// CHECK-NEXT: store float [[VALUE]], ptr [[VALUE_ADDR]], align 4
32+
// CHECK-NEXT: [[TMP0:%.*]] = load float, ptr [[VALUE_ADDR]], align 4
33+
// CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <9 x float> poison, float [[TMP0]], i64 0
34+
// CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <9 x float> [[SPLAT_SPLATINSERT]], <9 x float> poison, <9 x i32> zeroinitializer
35+
// CHECK-NEXT: store <9 x float> [[SPLAT_SPLAT]], ptr [[M]], align 4
36+
// CHECK-NEXT: ret void
37+
//
38+
void DynamicSplat(float Value) {
39+
float3x3 M = Value;
40+
}
41+
42+
// CHECK-LABEL: define hidden void @_Z13CastThenSplatDv4_f(
43+
// CHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[VALUE:%.*]]) #[[ATTR0]] {
44+
// CHECK-NEXT: [[ENTRY:.*:]]
45+
// CHECK-NEXT: [[VALUE_ADDR:%.*]] = alloca <4 x float>, align 16
46+
// CHECK-NEXT: [[M:%.*]] = alloca [9 x float], align 4
47+
// CHECK-NEXT: store <4 x float> [[VALUE]], ptr [[VALUE_ADDR]], align 16
48+
// CHECK-NEXT: [[TMP0:%.*]] = load <4 x float>, ptr [[VALUE_ADDR]], align 16
49+
// CHECK-NEXT: [[CAST_VTRUNC:%.*]] = extractelement <4 x float> [[TMP0]], i32 0
50+
// CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <9 x float> poison, float [[CAST_VTRUNC]], i64 0
51+
// CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <9 x float> [[SPLAT_SPLATINSERT]], <9 x float> poison, <9 x i32> zeroinitializer
52+
// CHECK-NEXT: store <9 x float> [[SPLAT_SPLAT]], ptr [[M]], align 4
53+
// CHECK-NEXT: ret void
54+
//
55+
void CastThenSplat(float4 Value) {
56+
float3x3 M = (float) Value;
57+
}

clang/test/SemaHLSL/MatrixElementOverloadResolution.hlsl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,12 +228,14 @@ void fn2x2(float2x2) {}
228228
void fn2x2IO(inout float2x2) {}
229229
void fnI2x2IO(inout int2x2) {}
230230

231-
void matOrVec(float4 F) {}
232-
void matOrVec(float2x2 F) {}
231+
void matOrVec(float4 F) {} // expected-note {{candidate function}}
232+
void matOrVec(float2x2 F) {} // expected-note {{candidate function}}
233233

234234
void matOrVec2(float3 F) {} // expected-note{{candidate function}}
235235
void matOrVec2(float2x3 F) {} // expected-note{{candidate function}}
236236

237+
void matOrVec3(float4x4 F) {}
238+
237239
export void Case8(float2x3 f23, float4x4 f44, float3x3 f33, float3x2 f32) {
238240
int2x2 i22 = f23;
239241
// expected-warning@-1{{implicit conversion truncates matrix: 'float2x3' (aka 'matrix<float, 2, 3>') to 'int2x2' (aka 'matrix<int, 2, 2>')}}
@@ -269,8 +271,12 @@ export void Case8(float2x3 f23, float4x4 f44, float3x3 f33, float3x2 f32) {
269271
//CHECK-NEXT: ImplicitCastExpr {{.*}} 'float4x4':'matrix<float, 4, 4>' <LValueToRValue>
270272

271273
#ifdef ERROR
272-
matOrVec(2.0); // TODO: See #168960 this should be ambiguous once we implement ICK_HLSL_Matrix_Splat.
274+
matOrVec(2.0); // expected-error {{call to 'matOrVec' is ambiguous}}
273275
#endif
276+
matOrVec3(3.14);
277+
//CHECK: ImplicitCastExpr {{.*}} 'float4x4':'matrix<float, 4, 4>' <HLSLAggregateSplatCast>
278+
//CHECK-NEXT: FloatingLiteral {{.*}} <col:13> 'float' 3.140000e+00
279+
274280
matOrVec2(f23);
275281
//CHECK: DeclRefExpr {{.*}} 'void (float2x3)' lvalue Function {{.*}} 'matOrVec2' 'void (float2x3)'
276282
//CHECK-NEXT: ImplicitCastExpr {{.*}} 'float2x3':'matrix<float, 2, 3>' <LValueToRValue>
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -verify %s
2+
3+
void SplatOfVectortoMat(int4 V){
4+
int2x2 M = V;
5+
// expected-error@-1 {{cannot initialize a variable of type 'int2x2' (aka 'matrix<int, 2, 2>') with an lvalue of type 'int4' (aka 'vector<int, 4>')}}
6+
}
7+
8+
void SplatOfMattoMat(int4x3 N){
9+
int4x4 M = N;
10+
// expected-error@-1 {{cannot initialize a variable of type 'matrix<[2 * ...], 4>' with an lvalue of type 'matrix<[2 * ...], 3>'}}
11+
}

0 commit comments

Comments
 (0)