@@ -321,7 +321,6 @@ def __init__(
321321 else :
322322 raise ValueError (f"Unknown initialization method: { init_method } " )
323323
324-
325324 # circuit related
326325 self ._init_state = None
327326 self .ex_ops = None
@@ -1334,7 +1333,6 @@ def param_ids(self, v):
13341333
13351334
13361335def compute_fe_t2 (no , nv , int1e , int2e ):
1337-
13381336 n_orb = no + nv
13391337
13401338 def translate_o (n ):
@@ -1349,8 +1347,8 @@ def translate_v(n):
13491347 else :
13501348 return n // 2 + no
13511349
1352- t2 = np .zeros ((2 * no , 2 * no , 2 * nv , 2 * nv ))
1353- for i , j , k , l in product (range (2 * no ), range (2 * no ), range (2 * nv ), range (2 * nv )):
1350+ t2 = np .zeros ((2 * no , 2 * no , 2 * nv , 2 * nv ))
1351+ for i , j , k , l in product (range (2 * no ), range (2 * no ), range (2 * nv ), range (2 * nv )):
13541352 # spin not conserved
13551353 if i % 2 != k % 2 or j % 2 != l % 2 :
13561354 continue
@@ -1398,23 +1396,32 @@ def _compute_e_diff(r, s, b, a, int1e, int2e, n_orb, no):
13981396 new_a .append (i % n_orb )
13991397
14001398 diag1e = np .diag (int1e )
1401- diagj = np .einsum (' iijj->ij' , int2e )
1402- diagk = np .einsum (' ijji->ij' , int2e )
1399+ diagj = np .einsum (" iijj->ij" , int2e )
1400+ diagk = np .einsum (" ijji->ij" , int2e )
14031401
14041402 e_diff_1e = diag1e [new_a ].sum () + diag1e [new_b ].sum () - diag1e [old_a ].sum () - diag1e [old_b ].sum ()
1405- e_diff_j = _compute_j_outer (diagj , inert_a , inert_b , new_a , new_b ) - _compute_j_outer (diagj , inert_a , inert_b , old_a , old_b )
1406- e_diff_k = _compute_k_outer (diagk , inert_a , inert_b , new_a , new_b ) - _compute_k_outer (diagk , inert_a , inert_b , old_a , old_b )
1403+ # fmt: off
1404+ e_diff_j = _compute_j_outer (diagj , inert_a , inert_b , new_a , new_b ) \
1405+ - _compute_j_outer (diagj , inert_a , inert_b , old_a , old_b )
1406+ e_diff_k = _compute_k_outer (diagk , inert_a , inert_b , new_a , new_b ) \
1407+ - _compute_k_outer (diagk , inert_a , inert_b , old_a , old_b )
1408+ # fmt: on
14071409 return e_diff_1e + 1 / 2 * (e_diff_j - e_diff_k )
14081410
14091411
14101412def _compute_j_outer (diagj , inert_a , inert_b , outer_a , outer_b ):
1413+ # fmt: off
14111414 v = diagj [inert_a ][:, outer_a ].sum () + diagj [outer_a ][:, inert_a ].sum () + diagj [outer_a ][:, outer_a ].sum () \
1412- + diagj [inert_a ][:, outer_b ].sum () + diagj [outer_a ][:, inert_b ].sum () + diagj [outer_a ][:, outer_b ].sum () \
1413- + diagj [inert_b ][:, outer_a ].sum () + diagj [outer_b ][:, inert_a ].sum () + diagj [outer_b ][:, outer_a ].sum () \
1414- + diagj [inert_b ][:, outer_b ].sum () + diagj [outer_b ][:, inert_b ].sum () + diagj [outer_b ][:, outer_b ].sum ()
1415+ + diagj [inert_a ][:, outer_b ].sum () + diagj [outer_a ][:, inert_b ].sum () + diagj [outer_a ][:, outer_b ].sum () \
1416+ + diagj [inert_b ][:, outer_a ].sum () + diagj [outer_b ][:, inert_a ].sum () + diagj [outer_b ][:, outer_a ].sum () \
1417+ + diagj [inert_b ][:, outer_b ].sum () + diagj [outer_b ][:, inert_b ].sum () + diagj [outer_b ][:, outer_b ].sum ()
1418+ # fmt: on
14151419 return v
14161420
1421+
14171422def _compute_k_outer (diagk , inert_a , inert_b , outer_a , outer_b ):
1423+ # fmt: off
14181424 v = diagk [inert_a ][:, outer_a ].sum () + diagk [outer_a ][:, inert_a ].sum () + diagk [outer_a ][:, outer_a ].sum () \
1419- + diagk [inert_b ][:, outer_b ].sum () + diagk [outer_b ][:, inert_b ].sum () + diagk [outer_b ][:, outer_b ].sum ()
1425+ + diagk [inert_b ][:, outer_b ].sum () + diagk [outer_b ][:, inert_b ].sum () + diagk [outer_b ][:, outer_b ].sum ()
1426+ # fmt: on
14201427 return v
0 commit comments