Skip to content
11 changes: 11 additions & 0 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1767,6 +1767,7 @@ class Circuit(AbstractCircuit):
* batch_remove
* batch_insert_into
* insert_at_frontier
* reverse

Circuits can also be iterated over,

Expand Down Expand Up @@ -2525,6 +2526,16 @@ def clear_operations_touching(
self._moments[k] = self._moments[k].without_operations_touching(qubits)
self._mutated()

def reverse(self) -> None:
"""Reverses the moments in the circuit, and the operations in the moments."""
# Work on a copy in case validation fails halfway through.
copy = self.copy()
backwards = []
for moment in copy[::-1]:
backwards.append(Moment(reversed(moment.operations)))
self._moments = backwards
self._mutated()

@property
def moments(self) -> Sequence[cirq.Moment]:
return self._moments
Expand Down
173 changes: 173 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1868,6 +1868,179 @@ def test_clear_operations_touching() -> None:
)


def test_reverse_empty_circuit():
circuit = cirq.Circuit()
circuit.reverse()
assert len(circuit) == 0
assert circuit == cirq.Circuit()


def test_reverse_single_moment_single_operation():
q = cirq.GridQubit(0, 0)
circuit = cirq.Circuit(cirq.X(q))
original_str = str(circuit)

circuit.reverse()

assert str(circuit) == original_str
assert len(circuit) == 1


def test_reverse_single_moment_multiple_operations():
"""Test reversing a circuit with one moment and multiple operations."""
q0, q1, q2 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1), cirq.GridQubit(0, 2)
original_ops = [cirq.X(q0), cirq.Y(q1), cirq.Z(q2)]
circuit = cirq.Circuit(cirq.Moment(original_ops))

circuit.reverse()

# Moment order unchanged (only one moment), but operations reversed
assert len(circuit) == 1
reversed_ops = list(circuit[0])
assert reversed_ops == list(reversed(original_ops))


def test_reverse_multiple_moments_single_operations():
"""Test reversing a circuit with multiple moments, each with single operations."""
q = cirq.GridQubit(0, 0)
circuit = cirq.Circuit([cirq.Moment([cirq.X(q)]), cirq.Moment([cirq.Y(q)]), cirq.Moment([cirq.Z(q)])])

original_moments = [str(moment) for moment in circuit]
circuit.reverse()

# Moments should be reversed
assert len(circuit) == 3
reversed_moments = [str(moment) for moment in circuit]
assert reversed_moments == list(reversed(original_moments))

def test_reverse_multiple_moments_multiple_operations():
"""Test reversing a circuit with multiple moments and multiple operations."""
q0, q1 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)
circuit = cirq.Circuit(
[
cirq.Moment([cirq.X(q0), cirq.Y(q1)]),
cirq.Moment([cirq.Z(q0), cirq.H(q1)]),
cirq.Moment([cirq.S(q0), cirq.T(q1)])
]
)

# Store original structure
original_structure = []
for moment in circuit:
original_structure.append(list(moment.operations))

circuit.reverse()

# Check that moments are reversed and operations within each moment are reversed
assert len(circuit) == 3

# First moment should be the reversed last moment
expected_first = list(reversed(original_structure[2]))
actual_first = list(circuit[0])
assert actual_first == expected_first

# Second moment should be the reversed middle moment
expected_second = list(reversed(original_structure[1]))
actual_second = list(circuit[1])
assert actual_second == expected_second

# Third moment should be the reversed first moment
expected_third = list(reversed(original_structure[0]))
actual_third = list(circuit[2])
assert actual_third == expected_third


def test_reverse_twice_returns_original():
"""Test that reversing twice returns the original circuit."""
q0, q1 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)
original_circuit = cirq.Circuit([
cirq.Moment([cirq.X(q0), cirq.Y(q1)]),
cirq.Moment([cirq.Z(q0)]),
cirq.Moment([cirq.H(q0), cirq.S(q1)])
]
)

# Make a copy to compare against
expected = original_circuit.copy()

# Reverse twice
original_circuit.reverse()
original_circuit.reverse()

# Should be back to original
assert original_circuit == expected


def test_reverse_with_measurements():
"""Test reversing a circuit with measurement operations."""
q0, q1 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)
circuit = cirq.Circuit(
[
cirq.Moment([cirq.X(q0), cirq.Y(q1)]),
cirq.Moment([cirq.measure(q0, key='a'), cirq.measure(q1, key='b')])
]
)

original_structure = []
for moment in circuit:
original_structure.append(list(moment.operations))

circuit.reverse()

# Check structure is properly reversed
assert len(circuit) == 2

# First moment should be reversed measurements
actual_first = list(circuit[0])
assert len(actual_first) == 2
assert all(isinstance(op.gate, cirq.MeasurementGate) for op in actual_first)

# Second moment should be reversed X, Y gates
actual_second = list(circuit[1])
assert len(actual_second) == 2


def test_reverse_with_two_qubit_gates():
"""Test reversing a circuit with two-qubit gates."""
q0, q1, q2 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1), cirq.GridQubit(0, 2)
circuit = cirq.Circuit(
[
cirq.Moment([cirq.CNOT(q0, q1), cirq.X(q2)]),
cirq.Moment([cirq.CZ(q1, q2)]),
cirq.Moment([cirq.SWAP(q0, q2), cirq.Y(q1)])
]
)

original_structure = []
for moment in circuit:
original_structure.append(list(moment.operations))

circuit.reverse()

# Verify the structure is correctly reversed
assert len(circuit) == 3

# Check that two-qubit gates are preserved correctly
for i, moment in enumerate(circuit):
expected_ops = list(reversed(original_structure[2 - i]))
actual_ops = list(moment.operations)
assert actual_ops == expected_ops


def test_reverse_modifies_original_circuit():
"""Test that reverse() modifies the original circuit in-place."""
q = cirq.GridQubit(0, 0)
circuit = cirq.Circuit([cirq.Moment([cirq.X(q)]), cirq.Moment([cirq.Y(q)])])

original_id = id(circuit)
circuit.reverse()

# Should be the same object
assert id(circuit) == original_id

# But content should be different
assert str(circuit[0]) != "X(q(0, 0))" # First moment is now Y

@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
def test_all_qubits(circuit_cls) -> None:
a = cirq.NamedQubit('a')
Expand Down
8 changes: 6 additions & 2 deletions cirq-core/cirq/transformers/stratify.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import itertools
from typing import Callable, Iterable, Sequence, TYPE_CHECKING, Union

import copy

from cirq import _import, circuits, ops, protocols
from cirq.transformers import transformer_api

Expand Down Expand Up @@ -69,7 +71,8 @@ def stratified_circuit(
# Try the algorithm with each permutation of the classifiers.
smallest_depth = protocols.num_qubits(circuit) * len(circuit) + 1
shortest_stratified_circuit = circuits.Circuit()
reversed_circuit = circuit[::-1]
reversed_circuit = copy.deepcopy(circuit)
reversed_circuit.reverse()
for ordered_classifiers in itertools.permutations(classifiers):
solution = _stratify_circuit(
circuit,
Expand All @@ -87,7 +90,8 @@ def stratified_circuit(
reversed_circuit,
classifiers=ordered_classifiers,
context=context or transformer_api.TransformerContext(),
)[::-1]
)
solution.reverse()
if len(solution) < smallest_depth:
shortest_stratified_circuit = solution
smallest_depth = len(solution)
Expand Down
Loading