Commit e473342

[mlir][Vector] Improve vector.transferx store-to-load-forwarding (#171840)
This patch changes the transfer_write -> transfer_read store-to-load forwarding canonicalization pattern to work on permutation maps rather than ad-hoc logic. The old logic could not canonicalize a simple unit-dim broadcast through a transfer_write/transfer_read pair; that case is added as a test in this patch. The patch also documents in more detail what would be needed to support the cases that are not yet implemented.
1 parent 0deee8c commit e473342
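
For illustration, a minimal sketch of the IR this pattern now folds, mirroring the test added in this patch (%pad stands for the padding value operand):

  %w = vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0]
         : vector<4x8xf32>, tensor<1x1x4x8xf32>
  %r = vector.transfer_read %w[%c0, %c0, %c0, %c0], %pad
         : tensor<1x1x4x8xf32>, vector<1x1x4x8xf32>

After canonicalization the read is forwarded directly from the written vector:

  %r = vector.broadcast %vec : vector<4x8xf32> to vector<1x1x4x8xf32>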

2 files changed: +83 −51 lines

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 64 additions & 51 deletions
@@ -5118,6 +5118,22 @@ Speculation::Speculatability TransferReadOp::getSpeculatability() {
   return Speculation::NotSpeculatable;
 }
 
+/// Given a projected permutation, inverse an affine map, making the unused dims
+/// 0 in the result.
+static AffineMap inverseWithUnusedDims(AffineMap map) {
+  assert(map.isProjectedPermutation() &&
+         "expected a projected permutation map");
+  SmallVector<AffineExpr> results(map.getNumInputs(),
+                                  getAffineConstantExpr(0, map.getContext()));
+  for (auto [idx, result] : llvm::enumerate(map.getResults())) {
+    // We should only have dim exprs because this is a projected permutation.
+    int64_t pos = cast<AffineDimExpr>(result).getPosition();
+    results[pos] = getAffineDimExpr(idx, map.getContext());
+  }
+  return AffineMap::get(/*dimCount=*/map.getNumResults(), /*symbolCount=*/0,
+                        results, map.getContext());
+}
+
 namespace {
 /// Store to load forwarding for transfer operations with permuation maps.
 /// Even if the permutation maps are different we can still propagate the store
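
A worked example of the new helper, on hypothetical maps not taken from this patch:

  writeMap                        = affine_map<(d0, d1, d2) -> (d2, d0)>
  inverseWithUnusedDims(writeMap) = affine_map<(d0, d1) -> (d1, 0, d0)>

The unused tensor dim d1 is pinned to the constant 0 and the remaining dims are inverted, so the result maps the written vector's dims back to tensor dims.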
@@ -5153,6 +5169,13 @@ struct TransferReadAfterWriteToBroadcast
     // Bail if we need an alias analysis.
     if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
       return failure();
+    // Bail in the masked case (too complex atm and needed to properly account
+    // for padding).
+    if (readOp.getMask() || defWrite.getMask())
+      return failure();
+    // If indices are not the same a shift may be required, bail.
+    if (readOp.getIndices() != defWrite.getIndices())
+      return failure();
     // Bail if we need a bounds analysis.
     if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
       return failure();
@@ -5161,60 +5184,50 @@ struct TransferReadAfterWriteToBroadcast
     if (readOp.getTransferChunkAccessed() !=
         defWrite.getTransferChunkAccessed())
       return failure();
-    // TODO: Support cases where a dim is explicitly written but implicitly
-    // read (i.e., a unit dim that is rank reduced).
-    if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
-        getUnusedDimsBitVector({defWrite.getPermutationMap()}))
-      return failure();
-    // This pattern should only catch the broadcast case, the non-broadcast case
-    // should be done separately to keep application conditions clean and
-    // separate.
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
-    bool bcast = !readMap.getBroadcastDims().empty() ||
-                 !writeMap.getBroadcastDims().empty();
-    if (!bcast)
-      return failure();
-    // At this point, we know we have a bcast.
-    // Bail in the masked case (too complex atm and needed to properly account
-    // for padding).
-    if (readOp.getMask() || defWrite.getMask())
-      return failure();
-    // If indices are not the same a shift may be required, bail.
-    if (readOp.getIndices() != defWrite.getIndices())
+    // WriteMap: tensor -> w_vec
+    // ReadMap: tensor -> r_vec
+    //
+    // inv(WriteMap): w_vec -> tensor
+    // inv(WriteMap) o ReadMap: w_vec -> r_vec
+    AffineMap readMap = readOp.getPermutationMap();
+    AffineMap writeMap = defWrite.getPermutationMap();
+    AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
+    AffineMap composedMap = readMap.compose(invWriteMap);
+    // If there are any unused dims in the composedMap, we have to drop some
+    // unit dims from the written vector before we can do transpose(broadcast).
+    // TODO: Support this case.
+    if (getUnusedDimsBitVector(composedMap).any())
       return failure();
-
+    // readVec = transpose(broadcast(writeVec))
+    //
+    // Build a transpose permutation for the above transpose operation.
+    //
+    // Treat the composed map as having extra leading dimensions which are
+    // the broadcasted dimensions, and treat the zeros as these new broadcasted
+    // dimensions.
+    SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
+    int64_t numBroadcastedDims = broadcastedDims.size();
+    auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
+    invPerm.resize(composedMap.getNumResults());
+    for (auto [idx, expr] : llvm::enumerate(composedMap.getResults())) {
+      if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
+        int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
+        invPerm[effectiveDim] = idx;
+      }
+    }
+    // Applying the inverse permutation on the readVecTy will give us the
+    // broadcast result type.
+    VectorType readVecTy = readOp.getVectorType();
+    SmallVector<int64_t> permutation = invertPermutationVector(invPerm);
+    auto broadcastedVecTy =
+        VectorType::get(applyPermutation(readVecTy.getShape(), invPerm),
+                        readVecTy.getElementType(),
+                        applyPermutation(readVecTy.getScalableDims(), invPerm));
+    // Build the transpose(broadcast) transformation.
     Value vec = defWrite.getVector();
-    // TODO: loop through the chain of transfer_write if we can prove that they
-    // don't overlap with the transfer_read. This requires improving
-    // `isDisjointTransferIndices` helper.
-    AffineMap map = readMap.compose(writeMap);
-    if (map.getNumResults() == 0)
-      return failure();
-    // Calculate the permutation to apply to go from the vector stored to the
-    // vector read.
-    SmallVector<unsigned> permutation;
-    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
-      return failure();
-
     Location loc = readOp.getLoc();
-    // Calculate the broadcast shape by applying the reverse permutation to the
-    // final shape we want.
-    ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
-    SmallVector<int64_t> broadcastShape(destShape.size());
-    SmallVector<bool> broadcastScalableFlags(destShape.size());
-    for (const auto &pos : llvm::enumerate(permutation)) {
-      broadcastShape[pos.value()] = destShape[pos.index()];
-      broadcastScalableFlags[pos.value()] =
-          readOp.getVectorType().getScalableDims()[pos.index()];
-    }
-    VectorType broadcastedType = VectorType::get(
-        broadcastShape, defWrite.getVectorType().getElementType(),
-        broadcastScalableFlags);
-    vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
-    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
-    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
-                                                     transposePerm);
+    vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
+    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec, permutation);
     return success();
   }
 };
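
Tracing the new logic on the unit-dim broadcast test added below (intermediate values derived by hand, so treat this as a sketch):

  writeMap    = affine_map<(d0, d1, d2, d3) -> (d2, d3)>          // tensor -> w_vec
  invWriteMap = affine_map<(d0, d1) -> (0, 0, d0, d1)>            // w_vec -> tensor
  readMap     = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>  // tensor -> r_vec
  composedMap = readMap.compose(invWriteMap)
              = affine_map<(d0, d1) -> (0, 0, d0, d1)>            // w_vec -> r_vec

The composed map has two leading broadcast dims and no unused dims, and invPerm works out to [0, 1, 2, 3], so the rewrite emits a vector.broadcast from vector<4x8xf32> to vector<1x1x4x8xf32> followed by an identity vector.transpose, which subsequently folds away.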

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 19 additions & 0 deletions
@@ -1914,6 +1914,25 @@ func.func @store_to_load_tensor_perm_broadcast(%arg0 : tensor<4x4x4xf32>,
 
 // -----
 
+// CHECK-LABEL: func @store_to_load_tensor_forwarding_unit_dim_broadcast
+//  CHECK-SAME: (%[[V0:.*]]: vector<4x8xf32>, %[[MEM:.*]]: tensor<1x1x4x8xf32>)
+//   CHECK-NOT: vector.transfer_write
+//   CHECK-NOT: vector.transfer_read
+//       CHECK: %[[RET:.+]] = vector.broadcast %[[V0]] : vector<4x8xf32> to vector<1x1x4x8xf32>
+//       CHECK: return %[[RET]]
+func.func @store_to_load_tensor_forwarding_unit_dim_broadcast(
+    %vec: vector<4x8xf32>,
+    %mem : tensor<1x1x4x8xf32>
+  ) -> vector<1x1x4x8xf32> {
+  %c0 = arith.constant 0 : index
+  %cst_0 = arith.constant 0.0 : f32
+  %write = vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] : vector<4x8xf32>, tensor<1x1x4x8xf32>
+  %read = vector.transfer_read %write[%c0, %c0, %c0, %c0], %cst_0 : tensor<1x1x4x8xf32>, vector<1x1x4x8xf32>
+  return %read : vector<1x1x4x8xf32>
+}
+
+// -----
+
 
 // CHECK-LABEL: func @dead_store_tensor
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
