@@ -294,30 +294,128 @@ get_value = jax.jit(get_value, static_argnums=(2,))
We use successive approximation for VFI.

```{code-cell} ipython3
-:load: _static/lecture_specific/successive_approx.py
+def successive_approx_jax(T,                     # Operator (callable)
+                          x_0,                   # Initial condition
+                          tol=1e-6,              # Error tolerance
+                          max_iter=10_000):      # Max iteration bound
+    def body_fun(k_x_err):
+        k, x, error = k_x_err
+        x_new = T(x)
+        error = jnp.max(jnp.abs(x_new - x))
+        return k + 1, x_new, error
+
+    def cond_fun(k_x_err):
+        k, x, error = k_x_err
+        return jnp.logical_and(error > tol, k < max_iter)
+
+    k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tol + 1))
+    return x
+
+successive_approx_jax = jax.jit(successive_approx_jax, static_argnums=(0,))
+```
+
+For OPI we'll add a compiled routine that computes $T_σ^m v$.
+
+```{code-cell} ipython3
+def iterate_policy_operator(σ, v, m, params, sizes, arrays):
+
+    def update(i, v):
+        v = T_σ(v, σ, params, sizes, arrays)
+        return v
+
+    v = jax.lax.fori_loop(0, m, update, v)
+    return v
+
+iterate_policy_operator = jax.jit(iterate_policy_operator,
+                                  static_argnums=(4,))
```

Finally, we introduce the solvers that implement VFI, HPI and OPI.

```{code-cell} ipython3
-:load: _static/lecture_specific/vfi.py
+def value_function_iteration(model, tol=1e-5):
+    """
+    Implements value function iteration.
+    """
+    params, sizes, arrays = model
+    vz = jnp.zeros(sizes)
+    _T = lambda v: T(v, params, sizes, arrays)
+    v_star = successive_approx_jax(_T, vz, tol=tol)
+    return get_greedy(v_star, params, sizes, arrays)
```

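Before moving on, here is a minimal sketch (illustrative only, not part of the lecture code) showing how `successive_approx_jax` behaves on a simple contraction mapping with a known fixed point; the map `T_demo` and the starting array are hypothetical:

```python
import jax.numpy as jnp

# Illustrative contraction: T(x) = 0.5 x + 1 has unique fixed point x* = 2
T_demo = lambda x: 0.5 * x + 1.0

# Assumes successive_approx_jax has been defined and jitted as above
x_star = successive_approx_jax(T_demo, jnp.zeros(3))
print(x_star)   # should print values close to [2. 2. 2.]
```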
+For OPI we will use a compiled JAX `lax.while_loop` operation to speed execution.
+
+
```{code-cell} ipython3
-:load: _static/lecture_specific/hpi.py
+def opi_loop(params, sizes, arrays, m, tol, max_iter):
+    """
+    Implements optimistic policy iteration (see dp.quantecon.org) with
+    step size m.
+
+    """
+    v_init = jnp.zeros(sizes)
+
+    def condition_function(inputs):
+        i, v, error = inputs
+        return jnp.logical_and(error > tol, i < max_iter)
+
+    def update(inputs):
+        i, v, error = inputs
+        last_v = v
+        σ = get_greedy(v, params, sizes, arrays)
+        v = iterate_policy_operator(σ, v, m, params, sizes, arrays)
+        error = jnp.max(jnp.abs(v - last_v))
+        i += 1
+        return i, v, error
+
+    num_iter, v, error = jax.lax.while_loop(condition_function,
+                                            update,
+                                            (0, v_init, tol + 1))
+
+    return get_greedy(v, params, sizes, arrays)
+
+opi_loop = jax.jit(opi_loop, static_argnums=(1,))
```

+Here's a friendly interface to OPI.
+
```{code-cell} ipython3
-:load: _static/lecture_specific/opi.py
+def optimistic_policy_iteration(model, m=10, tol=1e-5, max_iter=10_000):
+    params, sizes, arrays = model
+    σ_star = opi_loop(params, sizes, arrays, m, tol, max_iter)
+    return σ_star
```

+Here's HPI.
+
+
+```{code-cell} ipython3
+def howard_policy_iteration(model, maxiter=250):
+    """
+    Implements Howard policy iteration (see dp.quantecon.org)
+    """
+    params, sizes, arrays = model
+    σ = jnp.zeros(sizes, dtype=int)
+    i, error = 0, 1.0
+    while error > 0 and i < maxiter:
+        v_σ = get_value(σ, params, sizes, arrays)
+        σ_new = get_greedy(v_σ, params, sizes, arrays)
+        error = jnp.max(jnp.abs(σ_new - σ))
+        σ = σ_new
+        i = i + 1
+        print(f"Concluded loop {i} with error {error}.")
+    return σ
+```
+
+
```{code-cell} ipython3
:tags: [hide-output]

model = create_investment_model()
print("Starting HPI.")
qe.tic()
-out = policy_iteration(model)
+out = howard_policy_iteration(model)
elapsed = qe.toc()
print(out)
print(f"HPI completed in {elapsed} seconds.")
@@ -328,7 +426,7 @@ print(f"HPI completed in {elapsed} seconds.")

print("Starting VFI.")
qe.tic()
-out = value_iteration(model)
+out = value_function_iteration(model)
elapsed = qe.toc()
print(out)
print(f"VFI completed in {elapsed} seconds.")
@@ -356,7 +454,7 @@ y_grid, z_grid, Q = arrays
```

```{code-cell} ipython3
-σ_star = policy_iteration(model)
+σ_star = howard_policy_iteration(model)

fig, ax = plt.subplots(figsize=(9, 5))
ax.plot(y_grid, y_grid, "k--", label="45")
@@ -376,15 +474,15 @@ m_vals = range(5, 600, 40)
model = create_investment_model()
print("Running Howard policy iteration.")
qe.tic()
-σ_pi = policy_iteration(model)
+σ_pi = howard_policy_iteration(model)
pi_time = qe.toc()
```

```{code-cell} ipython3
print(f"PI completed in {pi_time} seconds.")
print("Running value function iteration.")
qe.tic()
-σ_vfi = value_iteration(model, tol=1e-5)
+σ_vfi = value_function_iteration(model, tol=1e-5)
vfi_time = qe.toc()
print(f"VFI completed in {vfi_time} seconds.")
```
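The `m_vals` grid introduced just above suggests timing OPI across a range of step sizes; a minimal sketch of such a loop (the variable `opi_times` and the exact loop body are assumptions, not taken from the commit) might look like this:

```python
# Time optimistic policy iteration for each candidate step size m,
# assuming model, m_vals and optimistic_policy_iteration are defined as above
opi_times = []
for m in m_vals:
    print(f"Running OPI with m = {m}.")
    qe.tic()
    σ_opi = optimistic_policy_iteration(model, m=m, tol=1e-5)
    opi_times.append(qe.toc())
```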