Use an efficient representation for merged components of operations #7484
…s in the merge methods. Currently `merge_operations` represents merged components as a `CircuitOperation`. This means that in a merge of n operations, n-1 `CircuitOperation`s are created, with a complexity of O(n^2). We use a disjoint-set data structure to reduce the complexity to O(n) for `merge_k_qubit_unitaries_to_circuit_op`. `merge_operations` itself can't be improved because it uses a `merge_func` that requires creating a `CircuitOperation` at every step.
I didn't change the algorithm in `merge_operations`, but switched from `CircuitOperation` to `Component` for intermediate merges. I found this inefficiency while investigating why … Still, this change improves transpilation. On my laptop, I measured an average improvement of 5.09%, from 86.89 seconds to 82.46 seconds, for 7 qubits. I also measured the improvement, if …

This is the script I used to measure the transpiler improvement:

```python
"""Tests performance of quantum Shannon decomposition.
"""

import sys
import time

from scipy.stats import unitary_group

import cirq

if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")


def main():
    measurements = []
    n_qubits = 7
    for i in range(10):
        print(f'Iteration {i}')
        U = unitary_group.rvs(2**n_qubits)
        qubits = [cirq.NamedQubit(f'q{i}') for i in range(n_qubits)]
        start_time = time.time()
        circuit = cirq.Circuit(cirq.quantum_shannon_decomposition(qubits, U))
        elapsed_time = time.time() - start_time
        print(f'Time taken: {elapsed_time:.2f} seconds')
        # Run cirq transpiler
        start_time = time.time()
        circuit = cirq.optimize_for_target_gateset(
            circuit,
            gateset=cirq.CZTargetGateset(preserve_moment_structure=False),
            max_num_passes=None,
        )
        elapsed_time = time.time() - start_time
        measurements.append(elapsed_time)
        print(f'Cirq depth for {n_qubits} qubits (with transpiler): {len(circuit)} ({elapsed_time:.2f} sec)')
    print(f'Average time taken for transpiling: {sum(measurements) / len(measurements):.2f} seconds')


if __name__ == '__main__':
    main()
```
Overall, I think the PR greatly improved the performance and the readability!
The unit tests look solid and confirm the merge behavior is working as intended. I've focused my review on the overall design in this iteration.
For performance, IIUC, given m connected ops,
- before: construct CircuitOperation m-1 times.
- after: construct CircuitOperation 1 time + extra on-average-constant disjoint-set union and find time.
The bottleneck is the CircuitOperation construction time in the previous implementation. Correct me if I am wrong.
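The before/after comparison can be made concrete with a toy cost model. This is only a sketch: it assumes that building a wrapper costs time proportional to the number of operations it contains, and the function names `cost_rewrapping` and `cost_union_find` are hypothetical, not part of Cirq.

```python
def cost_rewrapping(m: int) -> int:
    """Cost when every merge rebuilds a wrapper holding all ops merged so far.

    Assumption: rebuilding a wrapper of k ops costs k units.
    """
    total = 0
    size = 1
    for _ in range(m - 1):
        size += 1      # the merged wrapper now holds one more op
        total += size  # rebuilding it touches every op it contains
    return total


def cost_union_find(m: int) -> int:
    """Cost when merges are cheap unions and the wrapper is built once at the end.

    Assumption: each union costs 1 unit (amortized near-constant time).
    """
    unions = m - 1   # one union per merge step
    final_build = m  # a single wrapper construction over all m ops
    return unions + final_build


for m in (4, 1000):
    print(m, cost_rewrapping(m), cost_union_find(m))
```

Under these assumptions the first cost grows as m(m+1)/2 - 1, which is quadratic, while the second grows as 2m - 1, which is linear.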
Since this is a big design change, we should be careful here for the compatibility and extensibility. I've put a couple of comments for discussion.
do we do online queries about the connected components?
The components are used throughout …
Thank you a lot for the review @babacry and @NoureldinYosri! I made changes following your comments, please have a look!
Looks good to me with a couple of nits.
Over to Nour @NoureldinYosri
containing measurement operations with the same key.
ckey_indexes: Mapping from measurement keys to (sorted) list of component moments
containing classically controlled operations controlled on the same key.
components_by_index: List of circuit moments containing components. We use a dictionary
Not from your pr, but the naming of the var is strange: should it be the opposite as something like index_by_component_list? see go/python-tips/054
and could you update the explanation of the attribute with the map details? The key of the map is component, what does the value of the map represent?
The name is correct, but confusing. The data structure is:
list of dictionary {key: component, value: 0}, where the position in the list is given by the moment id
So the dictionaries are in a list, with the index the moment id. For a given moment, a dictionary is used to preserve the insertion order of the components in the moment. The value has no meaning, it's only there to take advantage of the dict order preserving feature.
I added more details in the docstring, please have a look!
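As a small illustration of the structure described above, here is a sketch of a dictionary used as an insertion-ordered set. Strings stand in for `Component` objects, and the variable names are hypothetical.

```python
# One dictionary per moment; keys are the components, values are unused.
components_in_moment: dict[str, int] = {}

# Adding the same component twice keeps its original position.
for name in ["c2", "c0", "c1", "c0"]:
    components_in_moment[name] = 0

ordered = list(components_in_moment)  # iteration follows insertion order

# Removal is O(1) on average and does not disturb the order of the rest.
del components_in_moment["c0"]
remaining = list(components_in_moment)

print(ordered)
print(remaining)
```

A plain `set` would also give O(1) insertion and removal, but it makes no ordering guarantee, which is why the dictionary is used here.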
Ah, I see, thanks for the clarification! It's a bit tricky, I rephrased a bit:
components_by_index: List of components indexed by moment id. For a moment id, we use a dictionary to keep track of the components in the moment. The dictionary instead of a set is used to preserve insertion order and the dictionary's values are intentionally unused.
Thanks, I updated the description.
merged_circuit_op_tag: tag to use for CircuitOperations
Returns:
the circuit with merged components as a CircuitOperation
nit: keep the docstring format uniform.
Please, can you give me more details about what's wrong here? I don't find the issue.
Nit: uppercase at the beginning and period at the end.
The circuit with merged components as a CircuitOperation.
Done.
circuit: CIRCUIT_TYPE,
factory: ComponentFactory,
*,
merged_circuit_op_tag: str = "Merged connected component",
Is it good practice to put a space in a tag? @NoureldinYosri I am not sure, though the existing code before this PR has a mixed style of tags.
Maybe "merged_connected_component"? Or "_merged_connected_component" (starting with _ to separate a default tag named by the cirq package from a user-specified tag)?
I copied the tag value from line 621 in merge_operations_to_circuit_op. I think the two should be in sync.
@codrut3 sorry about the late reply, but I was wondering whether, instead of manually implementing union-find, we can get the required functionality using https://networkx.org/documentation/stable/_modules/networkx/utils/union_find.html
This would split the logic of merging moments into two parts, of which we would need to maintain only one:
`def merge(self, a, b): self._merge_moments(a, b); self.union_find.union(a, b)`
Is it possible to do that, or will it break the logic?
I think it works to use a separate data structure to do the union-find. This structure would have to be initialized by the user and passed as an argument to the component factory. However, I would prefer to use `scipy.cluster.hierarchy.DisjointSet` because it allows me to add new singletons on the fly, while https://networkx.org/documentation/stable/_modules/networkx/utils/union_find.html requires all elements to be known from the beginning. If you are OK with adding a dependency on the scipy package, I will rewrite the code. Do I need to add the package version as a requirement somewhere? Let me know how to proceed! @NoureldinYosri
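To show what "adding new singletons on the fly" means in practice, here is a minimal pure-Python sketch of a disjoint set that registers elements lazily. The class `LazyDisjointSet` is hypothetical and is not the code in this PR; its `add`, `merge` and `connected` methods only mirror the names of the corresponding `scipy.cluster.hierarchy.DisjointSet` methods.

```python
class LazyDisjointSet:
    """Union-find with path compression and union by size (illustrative only)."""

    def __init__(self) -> None:
        self._parent: dict = {}
        self._size: dict = {}

    def add(self, x) -> None:
        """Registers x as a singleton subset if it is not known yet."""
        if x not in self._parent:
            self._parent[x] = x
            self._size[x] = 1

    def find(self, x):
        """Returns the representative of x's subset, adding x if needed."""
        self.add(x)
        root = x
        while self._parent[root] != root:
            root = self._parent[root]
        # Path compression: point every node on the path directly at the root.
        while self._parent[x] != root:
            next_node = self._parent[x]
            self._parent[x] = root
            x = next_node
        return root

    def merge(self, a, b) -> bool:
        """Merges the subsets of a and b; returns False if already merged."""
        root_a, root_b = self.find(a), self.find(b)
        if root_a == root_b:
            return False
        # Union by size: attach the smaller tree under the larger one.
        if self._size[root_a] < self._size[root_b]:
            root_a, root_b = root_b, root_a
        self._parent[root_b] = root_a
        self._size[root_a] += self._size[root_b]
        return True

    def connected(self, a, b) -> bool:
        return self.find(a) == self.find(b)


ds = LazyDisjointSet()
ds.merge("op0", "op1")  # elements are created on first use
ds.merge("op2", "op3")
ds.add("op4")           # a new singleton added after merging has started
ds.merge("op1", "op2")
print(ds.connected("op0", "op3"), ds.connected("op0", "op4"))
```

With this design the caller never has to enumerate the elements up front, which matches the use case where components are created while the circuit is being traversed.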
@codrut3 that sounds good, Cirq already depends on scipy (Cirq/cirq-core/requirements.txt, line 10 in fe0dc21).
Sorry for the extra work.
LGTM, thanks! @codrut3
Single comments are added in the reviews.
No worries! I changed the implementation to use … I took this opportunity to rewrite the connected component classes. I moved the …
If memory consumption ever becomes an issue, we can replace with a no-op trivial operation. Also clarify comment for similar action for ComponentWithOps.
The Component class uses default Python object hash and equality functions. If retrieved from sets or dictionaries, they must be identical.
Make them throw NotImplementedError which is ignored in coverage.
These operate on a general list of sorted integers and do not need any class data. Replaced with `_insort_last` and `_remove_last`.
Iteration over Circuit yields Moments.
Indicate these are internal functions not intended for use outside of cirq.
default_factory=lambda: defaultdict(lambda: [-1])
)
ops_by_index: list[dict[cirq.Operation, int]] = dataclasses.field(default_factory=list)
components_by_index: list[dict[Component, int]] = dataclasses.field(default_factory=list)
Nit - if dictionary values are unused, it is perhaps better to declare it as `dict[Component, NoneType]` and use `None` instead of `0` to add to the dictionary.
Done, thanks!
LGTM after addressing some minor issues - mainly renaming the new module `connected_component` to `_connected_component` to mark it as for internal use only.
Rather than writing a lot of comments I pushed the suggestions directly to the PR, these should be easy to review commit-by-commit.
Please let me know if this looks OK to you, LGTM on my side (with one tiny inline suggestion).
Codecov Report
✅ All modified and coverable lines are covered by tests.

Coverage Diff:

| | main | #7484 | +/- |
|---|---|---|---|
| Coverage | 99.38% | 99.38% | |
| Files | 1089 | 1091 | +2 |
| Lines | 97551 | 97813 | +262 |
| Hits | 96950 | 97212 | +262 |
| Misses | 601 | 601 | |
Optimized for the majority case when the value is the last in the list.
"""
indices.reverse()
Isn't reverse O(n)?
I think it would be better to check the last element with an `if` before trying reverse and remove.
I made some changes in my commit, please have a look!
Thanks for catching that! I was under the impression that `list.reverse` is O(1), but it is indeed O(n). I debug-printed the removed positions when they are not last, and they are close to the end of the list. I think we can use bisection to optimize the removal a bit more; here is my take on it in a735c01.
LMK if this looks OK to you.
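For readers following the thread, the idea of a fast path for the last element with a bisection fallback can be sketched as below. This is an illustration under stated assumptions, not the code from the commit: the helper names `insort_last` and `remove_last` are hypothetical, and the list is assumed to be sorted in ascending order.

```python
import bisect


def insort_last(indices: list[int], value: int) -> None:
    """Inserts value into a sorted list, optimized for appending at the end."""
    if not indices or indices[-1] <= value:
        indices.append(value)          # fast path: O(1)
    else:
        bisect.insort(indices, value)  # fallback: binary search + shift


def remove_last(indices: list[int], value: int) -> None:
    """Removes the last occurrence of value from a sorted list."""
    if indices and indices[-1] == value:
        indices.pop()                  # fast path: O(1)
        return
    # Binary search for the position just after the last occurrence of value.
    pos = bisect.bisect_right(indices, value) - 1
    if pos < 0 or indices[pos] != value:
        raise ValueError(f"{value} is not in the list")
    del indices[pos]                   # shifts only the elements after pos


indices = [1, 3, 5, 7]
remove_last(indices, 7)  # fast path
remove_last(indices, 3)  # bisection path
insort_last(indices, 4)  # bisection path
insort_last(indices, 9)  # fast path
print(indices)
```

Neither helper reverses or scans the whole list. When the removed position is near the end, as observed above, the `del` shifts only a few elements.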
Thank you a lot for your review and changes! I addressed the nit and made another change, PTAL.
Good to go on my side.
As of quantumlib#7590 typing_extensions is not in cirq-core requirements. mypy can detect mismatched `merge` arguments w/r to the base class without the `@typing_extensions.override` decorator. It does not seem worth reintroducing this dependency, here we drop the decorator instead. NB: Python 3.12 added a similar `typing.override` decorator, but we need to support 3.11.
One last tweak in 7637dde - typing_extensions is not in the cirq-core requirements anymore so I removed its import.
Nice, looks good to me, thank you a lot for the changes!
Thanks a lot for contributing this!