@@ -5118,6 +5118,22 @@ Speculation::Speculatability TransferReadOp::getSpeculatability() {
   return Speculation::NotSpeculatable;
 }
 
+/// Given a projected permutation, invert the affine map, making the unused
+/// dims 0 in the result.
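+/// For example, (d0, d1, d2) -> (d2, d1) inverts to (d0, d1) -> (0, d1, d0),
+/// where the constant 0 stands for the unused dim d0.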
+static AffineMap inverseWithUnusedDims(AffineMap map) {
+  assert(map.isProjectedPermutation() &&
+         "expected a projected permutation map");
+  SmallVector<AffineExpr> results(map.getNumInputs(),
+                                  getAffineConstantExpr(0, map.getContext()));
+  for (auto [idx, result] : llvm::enumerate(map.getResults())) {
+    // We should only have dim exprs because this is a projected permutation.
+    int64_t pos = cast<AffineDimExpr>(result).getPosition();
+    results[pos] = getAffineDimExpr(idx, map.getContext());
+  }
+  return AffineMap::get(/*dimCount=*/map.getNumResults(), /*symbolCount=*/0,
+                        results, map.getContext());
+}
+
 namespace {
 /// Store to load forwarding for transfer operations with permutation maps.
 /// Even if the permutation maps are different we can still propagate the store
@@ -5153,6 +5169,13 @@ struct TransferReadAfterWriteToBroadcast
     // Bail if we need an alias analysis.
     if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
       return failure();
+    // Bail in the masked case (too complex for now; the mask would need to
+    // properly account for padding).
+    if (readOp.getMask() || defWrite.getMask())
+      return failure();
+    // If the indices are not the same, a shift may be required; bail.
+    if (readOp.getIndices() != defWrite.getIndices())
+      return failure();
     // Bail if we need a bounds analysis.
     if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
       return failure();
@@ -5161,60 +5184,50 @@ struct TransferReadAfterWriteToBroadcast
     if (readOp.getTransferChunkAccessed() !=
         defWrite.getTransferChunkAccessed())
       return failure();
-    // TODO: Support cases where a dim is explicitly written but implicitly
-    // read (i.e., a unit dim that is rank reduced).
-    if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
-        getUnusedDimsBitVector({defWrite.getPermutationMap()}))
-      return failure();
-    // This pattern should only catch the broadcast case, the non-broadcast case
-    // should be done separately to keep application conditions clean and
-    // separate.
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
-    bool bcast = !readMap.getBroadcastDims().empty() ||
-                 !writeMap.getBroadcastDims().empty();
-    if (!bcast)
-      return failure();
-    // At this point, we know we have a bcast.
-    // Bail in the masked case (too complex atm and needed to properly account
-    // for padding).
-    if (readOp.getMask() || defWrite.getMask())
-      return failure();
-    // If indices are not the same a shift may be required, bail.
-    if (readOp.getIndices() != defWrite.getIndices())
+    // WriteMap: tensor -> w_vec
+    // ReadMap: tensor -> r_vec
+    //
+    // inv(WriteMap): w_vec -> tensor
+    // inv(WriteMap) o ReadMap: w_vec -> r_vec
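+    //
+    // For example, with WriteMap = (d0, d1, d2) -> (d2, d1) and
+    // ReadMap = (d0, d1, d2) -> (d1, 0, d2), we get
+    // inv(WriteMap) = (d0, d1) -> (0, d1, d0) and the composed map
+    // (d0, d1) -> (d1, 0, d0), i.e. r_vec dims expressed in w_vec dims.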
+    AffineMap readMap = readOp.getPermutationMap();
+    AffineMap writeMap = defWrite.getPermutationMap();
+    AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
+    AffineMap composedMap = readMap.compose(invWriteMap);
+    // If there are any unused dims in the composedMap, we have to drop some
+    // unit dims from the written vector before we can do transpose(broadcast).
+    // TODO: Support this case.
+    if (getUnusedDimsBitVector(composedMap).any())
       return failure();
-
+    // readVec = transpose(broadcast(writeVec))
+    //
+    // Build a transpose permutation for the above transpose operation.
+    //
+    // Treat the composed map as having extra leading dimensions which are
+    // the broadcasted dimensions, and treat the zeros as these new broadcasted
+    // dimensions.
+    SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
+    int64_t numBroadcastedDims = broadcastedDims.size();
+    auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
+    invPerm.resize(composedMap.getNumResults());
+    for (auto [idx, expr] : llvm::enumerate(composedMap.getResults())) {
+      if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
+        int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
+        invPerm[effectiveDim] = idx;
+      }
+    }
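+    // In the example above, broadcastedDims = {1} and invPerm = [1, 2, 0]:
+    // broadcast-vector dim 0 is r_vec dim 1 (the broadcast), dim 1 is r_vec
+    // dim 2, and dim 2 is r_vec dim 0.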
+    // Applying the inverse permutation to the readVecTy gives us the
+    // broadcast result type.
+    VectorType readVecTy = readOp.getVectorType();
+    SmallVector<int64_t> permutation = invertPermutationVector(invPerm);
+    auto broadcastedVecTy =
+        VectorType::get(applyPermutation(readVecTy.getShape(), invPerm),
+                        readVecTy.getElementType(),
+                        applyPermutation(readVecTy.getScalableDims(), invPerm));
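+    // E.g., if readVecTy is vector<8x2x4xf32> (with invPerm = [1, 2, 0] as
+    // above), broadcastedVecTy is vector<2x4x8xf32> and the final transpose
+    // permutation is [2, 0, 1].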
+    // Build the transpose(broadcast) transformation.
     Value vec = defWrite.getVector();
-    // TODO: loop through the chain of transfer_write if we can prove that they
-    // don't overlap with the transfer_read. This requires improving
-    // `isDisjointTransferIndices` helper.
-    AffineMap map = readMap.compose(writeMap);
-    if (map.getNumResults() == 0)
-      return failure();
-    // Calculate the permutation to apply to go from the vector stored to the
-    // vector read.
-    SmallVector<unsigned> permutation;
-    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
-      return failure();
-
     Location loc = readOp.getLoc();
-    // Calculate the broadcast shape by applying the reverse permutation to the
-    // final shape we want.
-    ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
-    SmallVector<int64_t> broadcastShape(destShape.size());
-    SmallVector<bool> broadcastScalableFlags(destShape.size());
-    for (const auto &pos : llvm::enumerate(permutation)) {
-      broadcastShape[pos.value()] = destShape[pos.index()];
-      broadcastScalableFlags[pos.value()] =
-          readOp.getVectorType().getScalableDims()[pos.index()];
-    }
-    VectorType broadcastedType = VectorType::get(
-        broadcastShape, defWrite.getVectorType().getElementType(),
-        broadcastScalableFlags);
-    vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
-    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
-    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
-                                                     transposePerm);
+    vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
+    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec, permutation);
     return success();
   }
 };