11import time
2+ import logging
23from functools import partial
34from collections import defaultdict
45from typing import Dict , List , Optional
1213from renormalizer .model .basis import BasisSet
1314
1415from tencirchem .dynamic .transform import qubit_encode_op , qubit_encode_basis , get_init_circuit
15- from tencirchem .dynamic .time_derivative import get_circuit , get_ansatz , get_jacobian_func , get_deriv , get_pvqd_loss_func
16+ from tencirchem .dynamic .time_derivative import get_circuit , get_ansatz , get_jacobian_func , get_deriv , get_pvqd_loss_func , one_trotter_step
1617
18+ logger = logging .getLogger (__name__ )
1719
1820def evolve_exact (evals : np .ndarray , evecs : np .ndarray , init : np .ndarray , t : float ):
1921 return evecs @ (np .diag (np .exp (- 1j * t * evals )) @ (evecs .T @ init ))
2022
21-
2223class TimeEvolution :
2324 def __init__ (
2425 self ,
@@ -30,6 +31,7 @@ def __init__(
3031 eps : float = 1e-5 ,
3132 property_op_dict : Dict = None ,
3233 ref_only : bool = False ,
34+ ivp_config = None ,
3335 ):
3436 # handling defaults
3537 if init_condition is None :
@@ -61,9 +63,14 @@ def __init__(
6163 self .ansatz = get_ansatz (self .model .ham_terms , self .model .basis , self .n_layers , self .init_circuit )
6264 self .jacobian_func = get_jacobian_func (self .ansatz )
6365
66+ self .current_circuit = self .init_circuit
6467 # setup runtime components
6568 self .eps = eps
6669 self .include_phase = False
70+ if ivp_config is None :
71+ self .ivp_config = Ivp_Config ()
72+ else :
73+ self .ivp_config = ivp_config
6774
6875 def scipy_deriv (t , _params ):
6976 return get_deriv (self .ansatz , self .jacobian_func , _params , self .h , self .eps , self .include_phase )
@@ -102,26 +109,36 @@ def solve_pvqd(_params, delta_t):
102109
103110 self .wall_time_list = []
104111
105- def kernel (self , tau , pvqd = False ):
112+ def kernel (self , tau , algo = "vqd" ):
106113 # one step of time evolution
107114 if self .ref_only :
108115 return self .kernel_ref_only (tau )
109116 time0 = time .time ()
110- if not pvqd :
111- scipy_sol = solve_ivp (self .scipy_deriv , [self .t , self .t + tau ], self .params )
112- new_params = scipy_sol .y [:, - 1 ]
117+ if algo == "vqd" or algo == "pvqd" :
118+ if algo == "vqd" :
119+ method , rtol , atol = self .ivp_config .method , self .ivp_config .rtol , self .ivp_config .atol
120+ scipy_sol = solve_ivp (self .scipy_deriv , [self .t , self .t + tau ],
121+ self .params , method = method , rtol = rtol , atol = atol )
122+ new_params = scipy_sol .y [:, - 1 ]
123+ else :
124+ scipy_sol = self .solve_pvqd (self .params , tau )
125+ new_params = self .params + scipy_sol .x
126+
127+ self .params_list .append (new_params )
128+ state = self .ansatz (self .params )
129+ self .scipy_sol_list .append (scipy_sol )
113130 else :
114- scipy_sol = self .solve_pvqd (self .params , tau )
115- new_params = self .params + scipy_sol .x
131+ assert algo == "trotter"
132+ self .current_circuit = one_trotter_step (self .model .ham_terms , self .model .basis ,
133+ self .current_circuit , tau )
134+ state = self .current_circuit .state ()
135+
116136 time1 = time .time ()
117137 self .t_list .append (self .t + tau )
118- self .params_list .append (new_params )
119138 # t and params already updated
120- state = self .ansatz (self .params )
121139 self .state_list .append (state )
122140 state_ref = evolve_exact (self .evals_ref , self .evecs_ref , self .init_ref , self .t )
123141 self .state_ref_list .append (state_ref )
124- self .scipy_sol_list .append (scipy_sol )
125142 # calculate properties
126143 self .update_property_dict (state , state_ref )
127144
@@ -203,3 +220,13 @@ def property_dict(self):
203220 @property
204221 def wall_time (self ):
205222 return self .wall_time_list [- 1 ]
223+
224+ class Ivp_Config :
225+ def __init__ (self ,
226+ method = "RK45" ,
227+ rtol = 1e-3 ,
228+ atol = 1e-6 ):
229+ self .method = method
230+ self .rtol = rtol
231+ self .atol = atol
232+
0 commit comments