Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-0.14.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,9 @@
* Adding the measurement type into the MLIR assembly format for `qec.ppm` and `qec.select.ppm`
[(#2347)](https://github.com/PennyLaneAI/catalyst/pull/2347)

* Remove duplicate code for canonicalization and verification of Pauli Product Rotation operations.
[(#2313)](https://github.com/PennyLaneAI/catalyst/pull/2313)

<h3>Documentation 📝</h3>

* A new statevector simulator ``lightning.amdgpu`` has been added for optimized performance on AMD GPUs.
Expand Down
64 changes: 25 additions & 39 deletions mlir/lib/QEC/IR/QECOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,37 +35,45 @@ using namespace catalyst::qec;
#include "QEC/IR/QECOps.cpp.inc"

//===----------------------------------------------------------------------===//
// QEC op verifiers.
// QEC op canonicalizers/verifiers helper methods.
//===----------------------------------------------------------------------===//

LogicalResult PPRotationOp::verify()
template <typename OpType> LogicalResult canonicalizePPROp(OpType op, PatternRewriter &rewriter)
{
size_t numPauliProduct = getPauliProduct().size();

if (numPauliProduct == 0) {
return emitOpError("Pauli string must be non-empty");
}
bool allIdentity = llvm::all_of(op.getPauliProduct(), [](mlir::Attribute attr) {
auto pauliStr = llvm::cast<mlir::StringAttr>(attr);
return pauliStr.getValue() == "I";
});

if (numPauliProduct != getInQubits().size()) {
return emitOpError("Number of qubits must match number of pauli operators");
if (allIdentity) {
rewriter.replaceOp(op, op.getInQubits());
return mlir::success();
}
return mlir::success();
return mlir::failure();
}

LogicalResult PPRotationArbitraryOp::verify()
template <typename OpType> LogicalResult verifyPPROp(OpType op)
{
size_t numPauliProduct = getPauliProduct().size();
size_t numPauliProduct = op.getPauliProduct().size();

if (numPauliProduct == 0) {
return emitOpError("Pauli string must be non-empty");
return op.emitOpError("Pauli string must be non-empty");
}

if (numPauliProduct != getInQubits().size()) {
return emitOpError("Number of qubits must match number of pauli operators");
if (numPauliProduct != op.getInQubits().size()) {
return op.emitOpError("Number of qubits must match number of pauli operators");
}
return mlir::success();
}

//===----------------------------------------------------------------------===//
// QEC op verifiers.
//===----------------------------------------------------------------------===//

LogicalResult PPRotationOp::verify() { return verifyPPROp(*this); }

LogicalResult PPRotationArbitraryOp::verify() { return verifyPPROp(*this); }

LogicalResult PPMeasurementOp::verify()
{
if (getInQubits().size() != getPauliProduct().size()) {
Expand Down Expand Up @@ -105,35 +113,13 @@ LogicalResult FabricateOp::verify()

LogicalResult PPRotationOp::canonicalize(PPRotationOp op, PatternRewriter &rewriter)
{
auto pauliProduct = op.getPauliProduct();

bool allIdentity = llvm::all_of(pauliProduct, [](mlir::Attribute attr) {
auto pauliStr = llvm::cast<mlir::StringAttr>(attr);
return pauliStr.getValue() == "I";
});

if (allIdentity) {
rewriter.replaceOp(op, op.getInQubits());
return mlir::success();
}
return mlir::failure();
return canonicalizePPROp(op, rewriter);
}

LogicalResult PPRotationArbitraryOp::canonicalize(PPRotationArbitraryOp op,
PatternRewriter &rewriter)
{
auto pauliProduct = op.getPauliProduct();

bool allIdentity = llvm::all_of(pauliProduct, [](mlir::Attribute attr) {
auto pauliStr = llvm::cast<mlir::StringAttr>(attr);
return pauliStr.getValue() == "I";
});

if (allIdentity) {
rewriter.replaceOp(op, op.getInQubits());
return mlir::success();
}
return mlir::failure();
return canonicalizePPROp(op, rewriter);
}

void LayerOp::build(OpBuilder &builder, OperationState &result, ValueRange inValues,
Expand Down