diff --git a/include/mqt-core/CircuitOptimizer.hpp b/include/mqt-core/CircuitOptimizer.hpp index 1d748ebbc..aaaac070c 100644 --- a/include/mqt-core/CircuitOptimizer.hpp +++ b/include/mqt-core/CircuitOptimizer.hpp @@ -51,7 +51,8 @@ class CircuitOptimizer { static void reorderOperations(QuantumComputation& qc); - static void flattenOperations(QuantumComputation& qc); + static void flattenOperations(QuantumComputation& qc, + bool customGatesOnly = false); static void cancelCNOTs(QuantumComputation& qc); diff --git a/include/mqt-core/operations/CompoundOperation.hpp b/include/mqt-core/operations/CompoundOperation.hpp index 2a300e036..5cf259b63 100644 --- a/include/mqt-core/operations/CompoundOperation.hpp +++ b/include/mqt-core/operations/CompoundOperation.hpp @@ -17,12 +17,14 @@ namespace qc { class CompoundOperation final : public Operation { private: std::vector> ops; + bool customGate; public: - explicit CompoundOperation(); + explicit CompoundOperation(bool isCustom = false); explicit CompoundOperation( - std::vector>&& operations); + std::vector>&& operations, + bool isCustom = false); CompoundOperation(const CompoundOperation& co); @@ -36,6 +38,8 @@ class CompoundOperation final : public Operation { [[nodiscard]] inline bool isSymbolicOperation() const override; + [[nodiscard]] bool isCustomGate() const; + void addControl(Control c) override; void clearControls() override; diff --git a/src/CircuitOptimizer.cpp b/src/CircuitOptimizer.cpp index fdf4c497e..6b43347b1 100644 --- a/src/CircuitOptimizer.cpp +++ b/src/CircuitOptimizer.cpp @@ -1197,11 +1197,17 @@ Iterator flattenCompoundOperation(std::vector>& ops, return it; } -void CircuitOptimizer::flattenOperations(QuantumComputation& qc) { +void CircuitOptimizer::flattenOperations(QuantumComputation& qc, + bool customGatesOnly) { auto it = qc.begin(); while (it != qc.end()) { if ((*it)->isCompoundOperation()) { - it = flattenCompoundOperation(qc.ops, it); + auto& op = dynamic_cast(**it); + if (!customGatesOnly || op.isCustomGate()) { + it = flattenCompoundOperation(qc.ops, it); + } else { + ++it; + } } else { ++it; } diff --git a/src/operations/CompoundOperation.cpp b/src/operations/CompoundOperation.cpp index 3bc2ab7c3..a5b1e9f1f 100644 --- a/src/operations/CompoundOperation.cpp +++ b/src/operations/CompoundOperation.cpp @@ -19,20 +19,20 @@ #include namespace qc { -CompoundOperation::CompoundOperation() { +CompoundOperation::CompoundOperation(bool isCustom) : customGate(isCustom) { name = "Compound operation:"; type = Compound; } CompoundOperation::CompoundOperation( - std::vector>&& operations) - : CompoundOperation() { + std::vector>&& operations, bool isCustom) + : CompoundOperation(isCustom) { // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) ops = std::move(operations); } CompoundOperation::CompoundOperation(const CompoundOperation& co) - : Operation(co), ops(co.ops.size()) { + : Operation(co), ops(co.ops.size()), customGate(co.customGate) { for (std::size_t i = 0; i < co.ops.size(); ++i) { ops[i] = co.ops[i]->clone(); } @@ -45,6 +45,7 @@ CompoundOperation& CompoundOperation::operator=(const CompoundOperation& co) { for (std::size_t i = 0; i < co.ops.size(); ++i) { ops[i] = co.ops[i]->clone(); } + customGate = co.customGate; } return *this; } @@ -61,6 +62,8 @@ bool CompoundOperation::isNonUnitaryOperation() const { bool CompoundOperation::isCompoundOperation() const { return true; } +bool CompoundOperation::isCustomGate() const { return customGate; } + bool CompoundOperation::isSymbolicOperation() const { return std::any_of(ops.begin(), ops.end(), [](const auto& op) { return op->isSymbolicOperation(); }); diff --git a/src/parsers/QASM3Parser.cpp b/src/parsers/QASM3Parser.cpp index 0b3a3cb62..e1d6502a2 100644 --- a/src/parsers/QASM3Parser.cpp +++ b/src/parsers/QASM3Parser.cpp @@ -645,7 +645,7 @@ class OpenQasm3Parser final : public InstVisitor { index++; } - auto op = std::make_unique(); + auto op = std::make_unique(true); for (const auto& nestedGate : compoundGate->body) { if (auto barrierStatement = std::dynamic_pointer_cast(nestedGate); diff --git a/test/unittests/test_qfr_functionality.cpp b/test/unittests/test_qfr_functionality.cpp index d8e22326e..7e5994693 100644 --- a/test/unittests/test_qfr_functionality.cpp +++ b/test/unittests/test_qfr_functionality.cpp @@ -8,6 +8,7 @@ #include "operations/Expression.hpp" #include "operations/NonUnitaryOperation.hpp" #include "operations/OpType.hpp" +#include "operations/StandardOperation.hpp" #include #include @@ -1192,6 +1193,51 @@ TEST_F(QFRFunctionality, FlattenRecursive) { EXPECT_TRUE(gate2->getControls().empty()); } +TEST_F(QFRFunctionality, FlattenCustomOnly) { + const std::size_t nqubits = 1U; + + // create a nested compound operation + QuantumComputation op(nqubits); + op.x(0); + op.z(0); + QuantumComputation op2(nqubits); + op2.emplace_back(op.asCompoundOperation()); + QuantumComputation qc(nqubits); + qc.emplace_back(op2.asCompoundOperation()); + std::cout << qc << "\n"; + + qc::CircuitOptimizer::flattenOperations(qc, true); + std::cout << qc << "\n"; + + ASSERT_EQ(qc.getNops(), 1U); + auto& gate = qc.at(0); + EXPECT_EQ(gate->getType(), qc::Compound); + + std::vector> opsCompound; + opsCompound.push_back(std::make_unique(0, qc::X)); + opsCompound.push_back(std::make_unique(0, qc::Z)); + QuantumComputation qc2(nqubits); + qc2.emplace_back(std::move(opsCompound), true); + std::cout << qc2 << "\n"; + + qc::CircuitOptimizer::flattenOperations(qc2, true); + std::cout << qc2 << "\n"; + + for (const auto& g : qc2) { + EXPECT_FALSE(g->isCompoundOperation()); + } + + ASSERT_EQ(qc2.getNops(), 2U); + auto& gate3 = qc2.at(0); + EXPECT_EQ(gate3->getType(), qc::X); + EXPECT_EQ(gate3->getTargets().at(0), 0U); + EXPECT_TRUE(gate3->getControls().empty()); + auto& gate4 = qc2.at(1); + EXPECT_EQ(gate4->getType(), qc::Z); + EXPECT_EQ(gate4->getTargets().at(0), 0U); + EXPECT_TRUE(gate4->getControls().empty()); +} + TEST_F(QFRFunctionality, OperationEquality) { const auto x = StandardOperation(0, qc::X); const auto z = StandardOperation(0, qc::Z);