Commit 8295961

tidying up asset pricing lecture (#96)

* misc
* misc

1 parent 50e5bfb, commit 8295961
File tree: 1 file changed

lectures/markov_asset.md (24 additions, 55 deletions)

@@ -925,13 +925,13 @@ And now let's see without compile time.
 ```{code-cell} ipython3
 qe.tic()
 v_jax = sv_pd_ratio_jax(sv_model_jax, shapes).block_until_ready()
-jnp_time_1 = qe.toc()
+jnp_time = qe.toc()
 ```

 Here's the ratio of times:

 ```{code-cell} ipython3
-jnp_time_1 / np_time
+jnp_time / np_time
 ```

 Let's check that the NumPy and JAX versions realize the same solution.
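
A note on the timing pattern recorded in this hunk: JAX dispatches work asynchronously, so `block_until_ready()` is needed before `qe.toc()` for the measurement to cover the whole computation, and the first call to a jitted function also pays compile time. Here is a minimal, self-contained sketch of that idea, with a made-up function `f` standing in for `sv_pd_ratio_jax`:

```python
import jax
import jax.numpy as jnp
import quantecon as qe

@jax.jit
def f(x):                        # hypothetical stand-in for sv_pd_ratio_jax
    return jnp.tanh(x @ x.T)

x = jnp.ones((2_000, 2_000))

qe.tic()
f(x).block_until_ready()         # first call: includes compilation
time_with_compile = qe.toc()

qe.tic()
f(x).block_until_ready()         # second call: runtime only
time_without_compile = qe.toc()

print(time_with_compile / time_without_compile)
```

Dropping `block_until_ready()` would time only the dispatch of the work, not the computation itself.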
@@ -959,7 +959,7 @@ $$
 $$

 ```{code-cell} ipython3
-def H_operator(g, sv_model, shapes):
+def H(g, sv_model, shapes):
     # Set up
     P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
     I, J, K = shapes
@@ -978,93 +978,62 @@ def H_operator(g, sv_model, shapes):
     return Hg
 ```

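The body of the operator is unchanged by this commit and is elided from the diff. Purely for orientation, an operator of this general flavor, a discounted expectation of `g` over three independent Markov chains with a state-dependent payoff, can be written as a single `einsum`; the names, the payoff array `A`, and the matrices below are illustrative placeholders, not the lecture's actual formula:

```python
import jax.numpy as jnp

def toy_H(g, P, Q, R, A, β=0.98):
    # Hg[i, j, k] = β Σ_{a,b,c} P[i, a] Q[j, b] R[k, c] A[a, b, c] g[a, b, c]
    return β * jnp.einsum('ia,jb,kc,abc,abc->ijk', P, Q, R, A, g)

# Tiny example with uniform transition matrices and a flat payoff
I, J, K = 4, 5, 6
P = jnp.ones((I, I)) / I
Q = jnp.ones((J, J)) / J
R = jnp.ones((K, K)) / K
A = jnp.ones((I, J, K))
g = jnp.ones((I, J, K))
print(toy_H(g, P, Q, R, A).shape)   # (4, 5, 6)
```
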
-The next function modifies our earlier `power_iteration_sr` function so that it
-can act on linear operators rather than matrices,
-
-also the spectral radius of the transition matrix less than one ensures the convergence of our calculations in the model.
-
-```{code-cell} ipython3
-def update_g(H_operator, sv_model, shapes, num_iterations=20):
-    g_k = jnp.ones(shapes)
-    for _ in range(num_iterations):
-        g_k1 = H_operator(g_k, sv_model, shapes)
-        sr = jnp.sum(g_k1 * g_k) / jnp.sum(g_k * g_k)
-        g_k1_norm = jnp.linalg.norm(g_k1)
-        g_k = g_k1 / g_k1_norm
-
-    return sr
-```
-
-Let's check the output
-
-```{code-cell} ipython3
-qe.tic()
-sr = update_g(H_operator, sv_model, shapes)
-qe.toc()
-print(sr)
-```
-
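The `update_g` routine removed above estimates the spectral radius of the operator by power iteration; that spectral radius being below one is what guarantees the Neumann-series solution for the price-dividend ratio converges. A self-contained sketch of the same idea, with a small matrix standing in for the operator (names and values are illustrative):

```python
import jax.numpy as jnp

def spectral_radius_estimate(apply_A, x0, num_iterations=50):
    # Power iteration: repeatedly apply the operator and renormalize,
    # tracking a Rayleigh-quotient style estimate of the dominant eigenvalue
    x = x0 / jnp.linalg.norm(x0)
    r = 0.0
    for _ in range(num_iterations):
        y = apply_A(x)
        r = jnp.vdot(x, y) / jnp.vdot(x, x)
        x = y / jnp.linalg.norm(y)
    return r

# Toy operator: multiplication by a matrix whose eigenvalues are 0.6 and 0.3
A = jnp.array([[0.5, 0.2],
               [0.1, 0.4]])
print(spectral_radius_estimate(lambda v: A @ v, jnp.ones(2)))   # ≈ 0.6
```
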
 Now we write a version of the solution function for the price-dividend ratio
-that acts directly on the linear operator `H_operator`.
+that acts directly on the linear operator `H`.

 ```{code-cell} ipython3
-def sv_pd_ratio_jax_multi(sv_model, shapes):
+def sv_pd_ratio_linop(sv_model, shapes):

     # Setp up
     P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model
     I, J, K = shapes

-    # Compute v
     ones_array = np.ones((I, J, K))
-    # Set up the operator g -> g
-    H = lambda g: H_operator(g, sv_model, shapes)
     # Set up the operator g -> (I - H) g
-    J = lambda g: g - H(g)
+    J = lambda g: g - H(g, sv_model, shapes)
     # Solve v = (I - H)^{-1} H 1
-    H1 = H_operator(ones_array, sv_model, shapes)
+    H1 = H(ones_array, sv_model, shapes)
+    # Apply an iterative solver that works for linear operators
     v = jax.scipy.sparse.linalg.bicgstab(J, H1)[0]
     return v
 ```
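
A point worth spelling out about the cell above: `jax.scipy.sparse.linalg.bicgstab` only needs a function that applies `g ↦ (I − H)g`, never an explicit matrix, which is what keeps memory use manageable while still solving `v = (I − H)^{-1} H 1`. A minimal sketch of that pattern, with a small dense matrix standing in for `H` (all names and values here are illustrative):

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

# Stand-in for H with spectral radius below one, so (I - H) is invertible
H_mat = 0.9 * jnp.array([[0.3, 0.2, 0.1],
                         [0.1, 0.4, 0.2],
                         [0.2, 0.1, 0.3]])

apply_H = lambda g: H_mat @ g           # operator form of H
apply_J = lambda g: g - apply_H(g)      # operator form of (I - H)

b = apply_H(jnp.ones(3))                # right-hand side H @ 1
v, info = bicgstab(apply_J, b)          # matrix-free solve of (I - H) v = H @ 1

# Cross-check against the dense solve
v_dense = jnp.linalg.solve(jnp.eye(3) - H_mat, b)
print(jnp.allclose(v, v_dense, atol=1e-5))
```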

 Let's target these functions for JIT compilation.

 ```{code-cell} ipython3
-H_operator = jax.jit(H_operator, static_argnums=(2,))
-sv_pd_ratio_jax_multi = jax.jit(sv_pd_ratio_jax_multi, static_argnums=(1,))
+H = jax.jit(H, static_argnums=(2,))
+sv_pd_ratio_linop = jax.jit(sv_pd_ratio_linop, static_argnums=(1,))
 ```
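
On the `static_argnums` arguments above: `shapes` is a plain Python tuple that is unpacked and used to build arrays, so it must be a compile-time constant rather than a traced value; marking it static makes JAX recompile when the tuple changes instead of failing during tracing. A small sketch of the pattern (the function here is made up):

```python
import jax
import jax.numpy as jnp

def make_grid(scale, shapes):
    # shapes is used for array construction, so it cannot be a traced value
    I, J, K = shapes
    return scale * jnp.ones((I, J, K))

# Mark argument 1 (shapes) as static; a new tuple triggers recompilation
make_grid = jax.jit(make_grid, static_argnums=(1,))

print(make_grid(2.0, (3, 4, 5)).shape)   # (3, 4, 5)
print(make_grid(2.0, (6, 2, 2)).shape)   # (6, 2, 2), recompiled
```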

 Let's time the solution with compile time included.

 ```{code-cell} ipython3
 qe.tic()
-v_jax_multi = sv_pd_ratio_jax_multi(sv_model, shapes).block_until_ready()
-jnp_time_multi_0 = qe.toc()
+v_jax_linop = sv_pd_ratio_linop(sv_model, shapes).block_until_ready()
+jnp_time_linop_0 = qe.toc()
 ```

 And now let's see without compile time.

 ```{code-cell} ipython3
 qe.tic()
-v_jax_multi = sv_pd_ratio_jax_multi(sv_model, shapes).block_until_ready()
-jnp_time_multi_1 = qe.toc()
+v_jax_linop = sv_pd_ratio_linop(sv_model, shapes).block_until_ready()
+jnp_linop_time = qe.toc()
 ```

 Let's verify the solution again:

 ```{code-cell} ipython3
-print(jnp.allclose(v, v_jax_multi))
+print(jnp.allclose(v, v_jax_linop))
 ```

 Here's the ratio of times between memory-efficient and direct version:

 ```{code-cell} ipython3
-jnp_time_multi_1 / jnp_time_1
+jnp_linop_time / jnp_time
 ```

-The speed is somewhat faster. In addition,
-
-1. now we can work with much larger grids, and
-2. the memory efficient version will be significantly faster with larger grids.
+The speed is somewhat faster and, moreover, we can now work with much larger grids.

 Here's a moderately large example, where the state space has 15,625 elements.

@@ -1073,14 +1042,14 @@ sv_model = create_sv_model(I=25, J=25, K=25)
 sv_model_jax = create_sv_model_jax(sv_model)
 P, hc_grid, Q, hd_grid, R, z_grid, β, γ, bar_σ, μ_c, μ_d = sv_model_jax
 shapes = len(hc_grid), len(hd_grid), len(z_grid)
-```

-```{code-cell} ipython3
 qe.tic()
-v_jax_multi = sv_pd_ratio_jax_multi(sv_model, shapes).block_until_ready()
-jnp_time_multi_2 = qe.toc()
+_ = sv_pd_ratio_linop(sv_model, shapes).block_until_ready()
+qe.toc()
 ```

-```{code-cell} ipython3
-jnp_time_multi_1 / jnp_time_1
-```
+The solution is computed relatively quickly and without memory issues.
+
+Readers will find that they can push these numbers further, although we refrain
+from doing so here.
