Skip to content

Commit ae20376

Browse files
authored
Merge pull request #1 from jjren/trotter
add trotter algorithm for dynamics
2 parents 568f8bc + 1e803f2 commit ae20376

File tree

4 files changed

+90
-23
lines changed

4 files changed

+90
-23
lines changed

tencirchem/dynamic/time_derivative.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@
1010

1111
logger = logging.getLogger(__name__)
1212

13+
def construct_ansatz_op(ham_terms, spin_basis):
1314

14-
def get_circuit(ham_terms, spin_basis, n_layers, init_state, params, param_ids=None, compile_evolution=False):
15-
if param_ids is None:
16-
param_ids = list(range(len(ham_terms)))
1715
dof_idx_dict = {b.dof: i for i, b in enumerate(spin_basis)}
1816
ansatz_op_list = []
1917

@@ -35,17 +33,25 @@ def get_circuit(ham_terms, spin_basis, n_layers, init_state, params, param_ids=N
3533
op_mat = np.kron(op_mat, tc.gates._z_matrix)
3634
name += "Z"
3735
qubit_idx_list = [dof_idx_dict[dof] for dof in op.dofs]
38-
ansatz_op_list.append((op_mat, name, qubit_idx_list))
36+
ansatz_op_list.append((op_mat, op.factor, name, qubit_idx_list))
37+
38+
return ansatz_op_list
39+
40+
def get_circuit(ham_terms, spin_basis, n_layers, init_state, params, param_ids=None, compile_evolution=False):
41+
if param_ids is None:
42+
param_ids = list(range(len(ham_terms)))
3943

4044
params = tc.backend.reshape(params, [n_layers, max(param_ids) + 1])
4145

46+
ansatz_op_list = construct_ansatz_op(ham_terms, spin_basis)
47+
4248
if isinstance(init_state, tc.Circuit):
4349
c = tc.Circuit.from_qir(init_state.to_qir(), circuit_params=init_state.circuit_param)
4450
else:
4551
c = tc.Circuit(len(spin_basis), inputs=init_state)
4652

4753
for i in range(0, n_layers):
48-
for j, (ansatz_op, name, qubit_idx_list) in enumerate(ansatz_op_list):
54+
for j, (ansatz_op, op_factor, name, qubit_idx_list) in enumerate(ansatz_op_list):
4955
param_id = np.abs(param_ids[j])
5056
# +0.1 is to avoid np.sign(0) problem
5157
sign = np.sign(param_ids[j] + 0.1)
@@ -59,6 +65,20 @@ def get_circuit(ham_terms, spin_basis, n_layers, init_state, params, param_ids=N
5965
c = evolve_pauli(c, pauli_string, theta=theta)
6066
return c
6167

68+
def one_trotter_step(ham_terms, spin_basis, init_state, dt):
69+
"""
70+
one step first order trotter decompostion
71+
"""
72+
ansatz_op_list = construct_ansatz_op(ham_terms, spin_basis)
73+
74+
if isinstance(init_state, tc.Circuit):
75+
c = tc.Circuit.from_qir(init_state.to_qir(), circuit_params=init_state.circuit_param)
76+
else:
77+
c = tc.Circuit(len(spin_basis), inputs=init_state)
78+
79+
for (ansatz_op, op_factor, name, qubit_idx_list) in ansatz_op_list:
80+
c.exp1(*qubit_idx_list, unitary=ansatz_op, theta=dt*op_factor, name=name)
81+
return c
6282

6383
def get_ansatz(ham_terms, spin_basis, n_layers, init_state, param_ids=None):
6484
@jit

tencirchem/dynamic/time_evolution.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import time
2+
import logging
23
from functools import partial
34
from collections import defaultdict
45
from typing import Dict, List, Optional
@@ -12,13 +13,13 @@
1213
from renormalizer.model.basis import BasisSet
1314

1415
from tencirchem.dynamic.transform import qubit_encode_op, qubit_encode_basis, get_init_circuit
15-
from tencirchem.dynamic.time_derivative import get_circuit, get_ansatz, get_jacobian_func, get_deriv, get_pvqd_loss_func
16+
from tencirchem.dynamic.time_derivative import get_circuit, get_ansatz, get_jacobian_func, get_deriv, get_pvqd_loss_func, one_trotter_step
1617

18+
logger = logging.getLogger(__name__)
1719

1820
def evolve_exact(evals: np.ndarray, evecs: np.ndarray, init: np.ndarray, t: float):
1921
return evecs @ (np.diag(np.exp(-1j * t * evals)) @ (evecs.T @ init))
2022

21-
2223
class TimeEvolution:
2324
def __init__(
2425
self,
@@ -30,6 +31,7 @@ def __init__(
3031
eps: float = 1e-5,
3132
property_op_dict: Dict = None,
3233
ref_only: bool = False,
34+
ivp_config = None,
3335
):
3436
# handling defaults
3537
if init_condition is None:
@@ -61,9 +63,14 @@ def __init__(
6163
self.ansatz = get_ansatz(self.model.ham_terms, self.model.basis, self.n_layers, self.init_circuit)
6264
self.jacobian_func = get_jacobian_func(self.ansatz)
6365

66+
self.current_circuit = self.init_circuit
6467
# setup runtime components
6568
self.eps = eps
6669
self.include_phase = False
70+
if ivp_config is None:
71+
self.ivp_config = Ivp_Config()
72+
else:
73+
self.ivp_config = ivp_config
6774

6875
def scipy_deriv(t, _params):
6976
return get_deriv(self.ansatz, self.jacobian_func, _params, self.h, self.eps, self.include_phase)
@@ -102,26 +109,36 @@ def solve_pvqd(_params, delta_t):
102109

103110
self.wall_time_list = []
104111

105-
def kernel(self, tau, pvqd=False):
112+
def kernel(self, tau, algo="vqd"):
106113
# one step of time evolution
107114
if self.ref_only:
108115
return self.kernel_ref_only(tau)
109116
time0 = time.time()
110-
if not pvqd:
111-
scipy_sol = solve_ivp(self.scipy_deriv, [self.t, self.t + tau], self.params)
112-
new_params = scipy_sol.y[:, -1]
117+
if algo == "vqd" or algo == "pvqd":
118+
if algo == "vqd":
119+
method, rtol, atol = self.ivp_config.method, self.ivp_config.rtol, self.ivp_config.atol
120+
scipy_sol = solve_ivp(self.scipy_deriv, [self.t, self.t + tau],
121+
self.params, method=method, rtol=rtol, atol=atol)
122+
new_params = scipy_sol.y[:, -1]
123+
else:
124+
scipy_sol = self.solve_pvqd(self.params, tau)
125+
new_params = self.params + scipy_sol.x
126+
127+
self.params_list.append(new_params)
128+
state = self.ansatz(self.params)
129+
self.scipy_sol_list.append(scipy_sol)
113130
else:
114-
scipy_sol = self.solve_pvqd(self.params, tau)
115-
new_params = self.params + scipy_sol.x
131+
assert algo == "trotter"
132+
self.current_circuit = one_trotter_step(self.model.ham_terms, self.model.basis,
133+
self.current_circuit, tau)
134+
state = self.current_circuit.state()
135+
116136
time1 = time.time()
117137
self.t_list.append(self.t + tau)
118-
self.params_list.append(new_params)
119138
# t and params already updated
120-
state = self.ansatz(self.params)
121139
self.state_list.append(state)
122140
state_ref = evolve_exact(self.evals_ref, self.evecs_ref, self.init_ref, self.t)
123141
self.state_ref_list.append(state_ref)
124-
self.scipy_sol_list.append(scipy_sol)
125142
# calculate properties
126143
self.update_property_dict(state, state_ref)
127144

@@ -203,3 +220,13 @@ def property_dict(self):
203220
@property
204221
def wall_time(self):
205222
return self.wall_time_list[-1]
223+
224+
class Ivp_Config:
225+
def __init__(self,
226+
method="RK45",
227+
rtol=1e-3,
228+
atol=1e-6):
229+
self.method = method
230+
self.rtol = rtol
231+
self.atol = atol
232+

tencirchem/dynamic/transform.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,20 +248,32 @@ def get_encoding(m, boson_encoding):
248248

249249

250250
def get_init_circuit(model_ref, model, boson_encoding, init_condition):
251-
for v in init_condition.values():
251+
for k, v in init_condition.items():
252+
basis = model_ref.dof_to_basis[k]
252253
if not isinstance(v, int):
253-
return get_init_circuit_general(model_ref, model, boson_encoding, init_condition)
254+
if isinstance(basis, BasisHalfSpin) and v.shape == (2,2) and \
255+
np.allclose(np.eye(2), v @ v.T.conj) and np.allclose(np.eye(2), v.T.conj @ v):
256+
continue
257+
else:
258+
return get_init_circuit_general(model_ref, model, boson_encoding, init_condition)
259+
254260
# replace the dof_name key to site_index key
255261
circuit = tc.Circuit(len(model.basis))
256262
for k, v in init_condition.items():
257263
basis = model_ref.dof_to_basis[k]
258264
if isinstance(basis, BasisHalfSpin):
259265
if v == 1:
260266
circuit.X(model.dof_to_siteidx[k])
267+
elif v.shape == (2,2):
268+
circuit.ANY(idx, unitary=v)
269+
else:
270+
assert v == 0
261271
elif isinstance(basis, BasisMultiElectron):
262272
if v == 1:
263273
idx = model.dof_to_siteidx[basis.dofs]
264274
circuit.X(idx)
275+
else:
276+
assert v == 0
265277
else:
266278
assert basis.is_phonon
267279
if boson_encoding is None:

tests/dynamics/test_dynamics.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from tencirchem.dynamic import sbm, qubit_encode_op, qubit_encode_basis, TimeEvolution
77

88

9-
@pytest.mark.parametrize("algorithm", ["vanilla", "include_phase", "p-VQD"])
9+
@pytest.mark.parametrize("algorithm", ["vanilla", "include_phase", "p-VQD", "trotter"])
1010
def test_sbm(reset_backend, algorithm):
1111
set_backend("jax")
1212
epsilon = 0
@@ -30,11 +30,19 @@ def test_sbm(reset_backend, algorithm):
3030
eps=1e-5,
3131
)
3232
te.include_phase = algorithm == "include_phase"
33-
34-
pvqd = algorithm == "p-VQD"
35-
tau = 0.1
33+
34+
if algorithm in ["vanilla", "include_phase"]:
35+
algo = "vqd"
36+
tau = 0.1
37+
elif algorithm == "p-VQD":
38+
algo = "pvqd"
39+
tau = 0.1
40+
else:
41+
algo = "trotter"
42+
tau = 0.02
43+
3644
for _ in range(50):
37-
te.kernel(tau, pvqd=pvqd)
45+
te.kernel(tau, algo=algo)
3846
z = te.property_dict["Z"]
3947
np.testing.assert_allclose(z[:, 0], z[:, 1], atol=1e-2)
4048

0 commit comments

Comments
 (0)