Skip to content

Commit 81c77fe

Browse files
committed
optimize trotter time evolution
1 parent f7ca9b9 commit 81c77fe

File tree

3 files changed

+13
-6
lines changed

3 files changed

+13
-6
lines changed

tencirchem/dynamic/time_derivative.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,17 @@ def get_circuit(ham_terms, spin_basis, n_layers, init_state, params, param_ids=N
7373
return c
7474

7575

76-
def one_trotter_step(ham_terms, spin_basis, init_state, dt):
76+
def one_trotter_step(ham_terms, spin_basis, init_state, dt, inplace=False):
7777
"""
7878
one step first order trotter decompostion
7979
"""
8080
ansatz_op_list = construct_ansatz_op(ham_terms, spin_basis)
8181

8282
if isinstance(init_state, tc.Circuit):
83-
c = tc.Circuit.from_qir(init_state.to_qir(), circuit_params=init_state.circuit_param)
83+
if inplace:
84+
c = init_state
85+
else:
86+
c = tc.Circuit.from_qir(init_state.to_qir(), circuit_params=init_state.circuit_param)
8487
else:
8588
c = tc.Circuit(len(spin_basis), inputs=init_state)
8689

tencirchem/dynamic/time_evolution.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def __init__(
8585
self.ansatz = get_ansatz(self.model.ham_terms, self.model.basis, self.n_layers, self.init_circuit)
8686
self.jacobian_func = get_jacobian_func(self.ansatz)
8787

88-
self.current_circuit = self.init_circuit
8988
# setup runtime components
89+
self.current_circuit = self.init_circuit
9090
self.eps = eps
9191
self.include_phase = False
9292
if ivp_config is None:
@@ -152,8 +152,9 @@ def kernel(self, tau, algo="vqd"):
152152
self.scipy_sol_list.append(scipy_sol)
153153
else:
154154
assert algo == "trotter"
155-
self.current_circuit = one_trotter_step(self.model.ham_terms, self.model.basis, self.current_circuit, tau)
156-
state = self.current_circuit.state()
155+
self.current_circuit = one_trotter_step(self.model.ham_terms, self.model.basis, self.current_circuit, tau, inplace=True)
156+
shortcut = one_trotter_step(self.model.ham_terms, self.model.basis, self.state, tau)
157+
state = shortcut.state()
157158

158159
time1 = time.time()
159160
self.t_list.append(self.t + tau)

tests/static/test_hea.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ def test_hea(engine, backend_str, grad, reset_backend):
2929
hea = HEA.ry(uccsd.int1e, uccsd.int2e, uccsd.n_elec, uccsd.e_core, 3, engine=engine)
3030
hea.grad = grad
3131
e = hea.kernel()
32-
np.testing.assert_allclose(e, uccsd.e_fci, atol=0.1)
32+
atol = 0.1
33+
if engine == "tensornetwork-noise&shot" and grad == "free":
34+
atol *= 2
35+
np.testing.assert_allclose(e, uccsd.e_fci, atol=atol)
3336

3437

3538
def test_qiskit_circuit():

0 commit comments

Comments
 (0)