Skip to content

Commit 5361636

Browse files
authored
[mlir][LLVM] refactor FailOnUnsupportedFP (#172054)
Enable `FailOnUnsupportedFP` for `ConvertToLLVMPattern` and set it to `true` for all `math-to-llvm` patterns. This fixes various invalid lowerings of `math` ops on `fp8`/`fp4` types.
1 parent 560fe76 commit 5361636

File tree

6 files changed

+113
-57
lines changed

6 files changed

+113
-57
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic,
5454
const LLVMTypeConverter &typeConverter,
5555
RewriterBase &rewriter);
5656

57+
/// Return "true" if the given type is an unsupported floating point type.
58+
/// In case of a vector type, return "true" if the element type is an
59+
/// unsupported floating point type.
60+
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
61+
Type type);
62+
/// Return "true" if the given op has any unsupported floating point
63+
/// types (either operands or results).
64+
bool opHasUnsupportedFloatingPointTypes(Operation *op,
65+
const TypeConverter &typeConverter);
5766
} // namespace detail
5867

5968
/// Decomposes a `src` value into a set of values of type `dstType` through
@@ -203,7 +212,7 @@ class ConvertToLLVMPattern : public ConversionPattern {
203212

204213
/// Utility class for operation conversions targeting the LLVM dialect that
205214
/// match exactly one source operation.
206-
template <typename SourceOp>
215+
template <typename SourceOp, bool FailOnUnsupportedFP = false>
207216
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
208217
public:
209218
using OpAdaptor = typename SourceOp::Adaptor;
@@ -220,12 +229,24 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
220229
LogicalResult
221230
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
222231
ConversionPatternRewriter &rewriter) const final {
232+
// Bail on unsupported floating point types. (These are type-converted to
233+
// integer types.)
234+
if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
235+
op, *this->typeConverter)) {
236+
return rewriter.notifyMatchFailure(op, "unsupported floating point type");
237+
}
223238
auto sourceOp = cast<SourceOp>(op);
224239
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
225240
}
226241
LogicalResult
227242
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
228243
ConversionPatternRewriter &rewriter) const final {
244+
// Bail on unsupported floating point types. (These are type-converted to
245+
// integer types.)
246+
if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
247+
op, *this->typeConverter)) {
248+
return rewriter.notifyMatchFailure(op, "unsupported floating point type");
249+
}
229250
auto sourceOp = cast<SourceOp>(op);
230251
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
231252
rewriter);

mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,6 @@ LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
6060
Attribute propertiesAttr,
6161
const LLVMTypeConverter &typeConverter,
6262
ConversionPatternRewriter &rewriter);
63-
64-
/// Return "true" if the given type is an unsupported floating point type. In
65-
/// case of a vector type, return "true" if the element type is an unsupported
66-
/// floating point type.
67-
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
68-
Type type);
6963
} // namespace detail
7064
} // namespace LLVM
7165

@@ -98,9 +92,11 @@ template <typename SourceOp, typename TargetOp,
9892
template <typename, typename> typename AttrConvert =
9993
AttrConvertPassThrough,
10094
bool FailOnUnsupportedFP = false>
101-
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
95+
class VectorConvertToLLVMPattern
96+
: public ConvertOpToLLVMPattern<SourceOp, FailOnUnsupportedFP> {
10297
public:
103-
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
98+
using ConvertOpToLLVMPattern<SourceOp,
99+
FailOnUnsupportedFP>::ConvertOpToLLVMPattern;
104100
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
105101

106102
LogicalResult
@@ -112,16 +108,9 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
112108

113109
// Bail on unsupported floating point types. (These are type-converted to
114110
// integer types.)
115-
if (FailOnUnsupportedFP) {
116-
for (Value operand : op->getOperands())
117-
if (LLVM::detail::isUnsupportedFloatingPointType(
118-
*this->getTypeConverter(), operand.getType()))
119-
return rewriter.notifyMatchFailure(op,
120-
"unsupported floating point type");
121-
if (LLVM::detail::isUnsupportedFloatingPointType(
122-
*this->getTypeConverter(), op->getResult(0).getType()))
123-
return rewriter.notifyMatchFailure(op,
124-
"unsupported floating point type");
111+
if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
112+
op, *this->typeConverter)) {
113+
return rewriter.notifyMatchFailure(op, "unsupported floating point type");
125114
}
126115

127116
// Determine attributes for the target op

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,3 +516,34 @@ Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
516516
base, index, noWrapFlags)
517517
: base;
518518
}
519+
520+
/// Return the given type if it's a floating point type. If the given type is
521+
/// a vector type, return its element type if it's a floating point type.
522+
static FloatType getFloatingPointType(Type type) {
523+
if (auto floatType = dyn_cast<FloatType>(type))
524+
return floatType;
525+
if (auto vecType = dyn_cast<VectorType>(type))
526+
return dyn_cast<FloatType>(vecType.getElementType());
527+
return nullptr;
528+
}
529+
530+
bool LLVM::detail::isUnsupportedFloatingPointType(
531+
const TypeConverter &typeConverter, Type type) {
532+
FloatType floatType = getFloatingPointType(type);
533+
if (!floatType)
534+
return false;
535+
Type convertedType = typeConverter.convertType(floatType);
536+
if (!convertedType)
537+
return true;
538+
return !isa<FloatType>(convertedType);
539+
}
540+
541+
bool LLVM::detail::opHasUnsupportedFloatingPointTypes(
542+
Operation *op, const TypeConverter &typeConverter) {
543+
for (Value operand : op->getOperands())
544+
if (isUnsupportedFloatingPointType(typeConverter, operand.getType()))
545+
return true;
546+
return llvm::any_of(op->getResults(), [&typeConverter](OpResult r) {
547+
return isUnsupportedFloatingPointType(typeConverter, r.getType());
548+
});
549+
}

mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -130,24 +130,3 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
130130
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
131131
rewriter);
132132
}
133-
134-
/// Return the given type if it's a floating point type. If the given type is
135-
/// a vector type, return its element type if it's a floating point type.
136-
static FloatType getFloatingPointType(Type type) {
137-
if (auto floatType = dyn_cast<FloatType>(type))
138-
return floatType;
139-
if (auto vecType = dyn_cast<VectorType>(type))
140-
return dyn_cast<FloatType>(vecType.getElementType());
141-
return nullptr;
142-
}
143-
144-
bool LLVM::detail::isUnsupportedFloatingPointType(
145-
const TypeConverter &typeConverter, Type type) {
146-
FloatType floatType = getFloatingPointType(type);
147-
if (!floatType)
148-
return false;
149-
Type convertedType = typeConverter.convertType(floatType);
150-
if (!convertedType)
151-
return true;
152-
return !isa<FloatType>(convertedType);
153-
}

mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ namespace {
3232
template <typename SourceOp, typename TargetOp>
3333
using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
3434

35-
template <typename SourceOp, typename TargetOp>
35+
template <typename SourceOp, typename TargetOp, bool FailOnUnsupportedFP = true>
3636
using ConvertFMFMathToLLVMPattern =
37-
VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>;
37+
VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath,
38+
FailOnUnsupportedFP>;
3839

3940
using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
4041
using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
@@ -44,7 +45,9 @@ using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
4445
using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
4546
using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
4647
using CtPopFOpLowering =
47-
VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
48+
VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp,
49+
AttrConvertPassThrough,
50+
/*FailOnUnsupportedFP=*/true>;
4851
using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
4952
using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
5053
using FloorOpLowering =
@@ -76,8 +79,10 @@ using ATan2OpLowering =
7679
// TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
7780
// may be better to separate the patterns.
7881
template <typename MathOp, typename LLVMOp>
79-
struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
80-
using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
82+
struct IntOpWithFlagLowering
83+
: public ConvertOpToLLVMPattern<MathOp, /*FailOnUnsupportedFP=*/true> {
84+
using ConvertOpToLLVMPattern<
85+
MathOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
8186
using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
8287

8388
LogicalResult
@@ -122,8 +127,11 @@ using CountTrailingZerosOpLowering =
122127
using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
123128

124129
// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
125-
struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
126-
using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
130+
struct SincosOpLowering
131+
: public ConvertOpToLLVMPattern<math::SincosOp,
132+
/*FailOnUnsupportedFP=*/true> {
133+
using ConvertOpToLLVMPattern<
134+
math::SincosOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
127135

128136
LogicalResult
129137
matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
@@ -154,8 +162,11 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
154162
};
155163

156164
// A `expm1` is converted into `exp - 1`.
157-
struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
158-
using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
165+
struct ExpM1OpLowering
166+
: public ConvertOpToLLVMPattern<math::ExpM1Op,
167+
/*FailOnUnsupportedFP=*/true> {
168+
using ConvertOpToLLVMPattern<
169+
math::ExpM1Op, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
159170

160171
LogicalResult
161172
matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
@@ -216,8 +227,11 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
216227
};
217228

218229
// A `log1p` is converted into `log(1 + ...)`.
219-
struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
220-
using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
230+
struct Log1pOpLowering
231+
: public ConvertOpToLLVMPattern<math::Log1pOp,
232+
/*FailOnUnsupportedFP=*/true> {
233+
using ConvertOpToLLVMPattern<
234+
math::Log1pOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
221235

222236
LogicalResult
223237
matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
@@ -278,8 +292,11 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
278292
};
279293

280294
// A `rsqrt` is converted into `1 / sqrt`.
281-
struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
282-
using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
295+
struct RsqrtOpLowering
296+
: public ConvertOpToLLVMPattern<math::RsqrtOp,
297+
/*FailOnUnsupportedFP=*/true> {
298+
using ConvertOpToLLVMPattern<
299+
math::RsqrtOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
283300

284301
LogicalResult
285302
matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
@@ -339,8 +356,11 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
339356
}
340357
};
341358

342-
struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
343-
using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
359+
struct IsNaNOpLowering
360+
: public ConvertOpToLLVMPattern<math::IsNaNOp,
361+
/*FailOnUnsupportedFP=*/true> {
362+
using ConvertOpToLLVMPattern<
363+
math::IsNaNOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
344364

345365
LogicalResult
346366
matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
@@ -358,8 +378,11 @@ struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
358378
}
359379
};
360380

361-
struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
362-
using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
381+
struct IsFiniteOpLowering
382+
: public ConvertOpToLLVMPattern<math::IsFiniteOp,
383+
/*FailOnUnsupportedFP=*/true> {
384+
using ConvertOpToLLVMPattern<
385+
math::IsFiniteOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
363386

364387
LogicalResult
365388
matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,

mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,3 +628,16 @@ func.func @fastmath(%arg0 : f32, %arg1 : vector<4xf32>) {
628628
%3 = math.fma %arg0, %arg0, %arg0 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f32
629629
func.return
630630
}
631+
632+
// -----
633+
634+
// CHECK-LABEL: func @unsupported_fp_type
635+
// CHECK: math.absf {{.*}} : f4E2M1FN
636+
// CHECK: math.cos {{.*}} : f4E2M1FN
637+
// CHECK: math.fma {{.*}} : f4E2M1FN
638+
func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: f4E2M1FN, %arg2: f4E2M1FN) {
639+
%0 = math.absf %arg0 : f4E2M1FN
640+
%1 = math.cos %arg0 : f4E2M1FN
641+
%2 = math.fma %arg1, %arg1, %arg2 : f4E2M1FN
642+
return
643+
}

0 commit comments

Comments
 (0)