@@ -925,13 +925,13 @@ And now let's see without compile time.
925925```{code-cell} ipython3
926926qe.tic()
927927v_jax = sv_pd_ratio_jax(sv_model_jax, shapes).block_until_ready()
928- jnp_time_1 = qe.toc()
928+ jnp_time = qe.toc()
929929```
930930
931931Here's the ratio of times:
932932
933933```{code-cell} ipython3
934- jnp_time_1 / np_time
934+ jnp_time / np_time
935935```
936936
937937Let's check that the NumPy and JAX versions realize the same solution.
959959$$
960960
961961```{code-cell} ipython3
962- def H_operator (g, sv_model, shapes):
962+ def H (g, sv_model, shapes):
963963 # Set up
964964 P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
965965 I, J, K = shapes
@@ -978,93 +978,62 @@ def H_operator(g, sv_model, shapes):
978978 return Hg
979979```
980980
981- The next function modifies our earlier `power_iteration_sr` function so that it
982- can act on linear operators rather than matrices,
983-
984- also the spectral radius of the transition matrix less than one ensures the convergence of our calculations in the model.
985-
986- ```{code-cell} ipython3
987- def update_g(H_operator, sv_model, shapes, num_iterations=20):
988- g_k = jnp.ones(shapes)
989- for _ in range(num_iterations):
990- g_k1 = H_operator(g_k, sv_model, shapes)
991- sr = jnp.sum(g_k1 * g_k) / jnp.sum(g_k * g_k)
992- g_k1_norm = jnp.linalg.norm(g_k1)
993- g_k = g_k1 / g_k1_norm
994-
995- return sr
996- ```
997-
998- Let's check the output
999-
1000- ```{code-cell} ipython3
1001- qe.tic()
1002- sr = update_g(H_operator, sv_model, shapes)
1003- qe.toc()
1004- print(sr)
1005- ```
1006-
1007981Now we write a version of the solution function for the price-dividend ratio
1008- that acts directly on the linear operator `H_operator `.
982+ that acts directly on the linear operator `H `.
1009983
1010984```{code-cell} ipython3
1011- def sv_pd_ratio_jax_multi (sv_model, shapes):
985+ def sv_pd_ratio_linop (sv_model, shapes):
1012986
1013987 # Setp up
1014988 P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
1015989 I, J, K = shapes
1016990
1017- # Compute v
1018991 ones_array = np.ones((I, J, K))
1019- # Set up the operator g -> g
1020- H = lambda g: H_operator(g, sv_model, shapes)
1021992 # Set up the operator g -> (I - H) g
1022- J = lambda g: g - H(g)
993+ J = lambda g: g - H(g, sv_model, shapes )
1023994 # Solve v = (I - H)^{-1} H 1
1024- H1 = H_operator(ones_array, sv_model, shapes)
995+ H1 = H(ones_array, sv_model, shapes)
996+ # Apply an iterative solver that works for linear operators
1025997 v = jax.scipy.sparse.linalg.bicgstab(J, H1)[0]
1026998 return v
1027999```
10281000
10291001Let's target these functions for JIT compilation.
10301002
10311003```{code-cell} ipython3
1032- H_operator = jax.jit(H_operator , static_argnums=(2,))
1033- sv_pd_ratio_jax_multi = jax.jit(sv_pd_ratio_jax_multi , static_argnums=(1,))
1004+ H = jax.jit(H , static_argnums=(2,))
1005+ sv_pd_ratio_linop = jax.jit(sv_pd_ratio_linop , static_argnums=(1,))
10341006```
10351007
10361008Let's time the solution with compile time included.
10371009
10381010```{code-cell} ipython3
10391011qe.tic()
1040- v_jax_multi = sv_pd_ratio_jax_multi (sv_model, shapes).block_until_ready()
1041- jnp_time_multi_0 = qe.toc()
1012+ v_jax_linop = sv_pd_ratio_linop (sv_model, shapes).block_until_ready()
1013+ jnp_time_linop_0 = qe.toc()
10421014```
10431015
10441016And now let’s see without compile time.
10451017
10461018```{code-cell} ipython3
10471019qe.tic()
1048- v_jax_multi = sv_pd_ratio_jax_multi (sv_model, shapes).block_until_ready()
1049- jnp_time_multi_1 = qe.toc()
1020+ v_jax_linop = sv_pd_ratio_linop (sv_model, shapes).block_until_ready()
1021+ jnp_linop_time = qe.toc()
10501022```
10511023
10521024Let's verify the solution again:
10531025
10541026```{code-cell} ipython3
1055- print(jnp.allclose(v, v_jax_multi ))
1027+ print(jnp.allclose(v, v_jax_linop ))
10561028```
10571029
10581030Here’s the ratio of times between memory-efficient and direct version:
10591031
10601032```{code-cell} ipython3
1061- jnp_time_multi_1 / jnp_time_1
1033+ jnp_linop_time / jnp_time
10621034```
10631035
1064- The speed is somewhat faster. In addition,
1065-
1066- 1. now we can work with much larger grids, and
1067- 2. the memory efficient version will be significantly faster with larger grids.
1036+ The speed is somewhat faster and, moreover, we can now work with much larger grids.
10681037
10691038Here's a moderately large example, where the state space has 15,625 elements.
10701039
@@ -1073,14 +1042,14 @@ sv_model = create_sv_model(I=25, J=25, K=25)
10731042sv_model_jax = create_sv_model_jax(sv_model)
10741043P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model_jax
10751044shapes = len(hc_grid), len(hd_grid), len(z_grid)
1076- ```
10771045
1078- ```{code-cell} ipython3
10791046qe.tic()
1080- v_jax_multi = sv_pd_ratio_jax_multi (sv_model, shapes).block_until_ready()
1081- jnp_time_multi_2 = qe.toc()
1047+ _ = sv_pd_ratio_linop (sv_model, shapes).block_until_ready()
1048+ qe.toc()
10821049```
10831050
1084- ```{code-cell} ipython3
1085- jnp_time_multi_1 / jnp_time_1
1086- ```
1051+ The solution is computed relatively quickly and without memory issues.
1052+
1053+ Readers will find that they can push these numbers further, although we refrain
1054+ from doing so here.
1055+
0 commit comments