4 changes: 4 additions & 0 deletions lectures/_config.yml
@@ -93,6 +93,10 @@ sphinx:
macros:
"argmax" : "arg\\,max"
"argmin" : "arg\\,min"
intersphinx_mapping:
intermediate:
- "https://python.quantecon.org/"
- null
mathjax_path: https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
rediraffe_redirects:
index_toc.md: intro.md
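
The `intersphinx_mapping` entry added in this hunk follows the Sphinx intersphinx convention of mapping a project name to a base URL and an inventory location. As a rough illustration, assuming the YAML is translated into an ordinary Sphinx configuration, the equivalent Python fragment might look like the sketch below. The `conf.py` form is an assumption made for illustration and is not part of this change.

```python
# Hypothetical conf.py fragment mirroring the YAML entry in this hunk.
intersphinx_mapping = {
    # prefix used in cross-references -> (base URL, inventory location)
    "intermediate": ("https://python.quantecon.org/", None),
}

# None asks Sphinx to look for the inventory file at the base URL itself.
base_url, inventory = intersphinx_mapping["intermediate"]
print(base_url, inventory)
```

The key `intermediate` is the prefix that appears in references such as `intermediate:kesten_processes` later in this diff.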
128 changes: 53 additions & 75 deletions lectures/kesten_processes.md
@@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.17.2
jupytext_version: 1.16.7
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -32,18 +32,17 @@ In addition to JAX and Anaconda, this lecture will need the following libraries:
```{code-cell} ipython3
:tags: [hide-output]

!pip install quantecon
!pip install --upgrade quantecon
```

## Overview

This lecture describes Kesten processes, which are an important class of
stochastic processes, and an application of firm dynamics.
stochastic processes, and their application to firm dynamics.

The lecture draws on [an earlier QuantEcon lecture](https://python.quantecon.org/kesten_processes.html),
which uses Numba to accelerate the computations.
The lecture draws on {doc}`intermediate:kesten_processes`.

In that earlier lecture you can find a more detailed discussion of the concepts involved.
In that earlier lecture, you can find a more detailed discussion of the concepts involved.

This lecture focuses on implementing the same computations in JAX.

@@ -55,13 +54,11 @@ import quantecon as qe
import jax
import jax.numpy as jnp
from jax import random
from jax import lax
from quantecon import tic, toc
from typing import NamedTuple
from functools import partial
```

Let's check the GPU we are running
Let's check the GPU we are running on

```{code-cell} ipython3
!nvidia-smi
@@ -85,19 +82,17 @@ sequences.

We are interested in the dynamics of $\{X_t\}_{t \geq 0}$ when $X_0$ is given.

We will focus on the nonnegative scalar case, where $X_t$ takes values in $\mathbb R_+$.
We will focus on the nonnegative scalar case, where $X_t$ takes values in $\mathbb{R}_+$.

In particular, we will assume that

* the initial condition $X_0$ is nonnegative,
* $\{a_t\}_{t \geq 1}$ is a nonnegative IID stochastic process and
* $\{a_t\}_{t \geq 1}$ is a nonnegative IID stochastic process, and
* $\{\eta_t\}_{t \geq 1}$ is another nonnegative IID stochastic process, independent of the first.


### Application: firm dynamics

In this section we apply Kesten process theory to the study of firm dynamics.

In this section, we apply Kesten process theory to the study of firm dynamics.

#### Gibrat's law

@@ -121,7 +116,7 @@ for some positive IID sequence $\{a_t\}$.
Subsequent empirical research has shown that this specification is not accurate,
particularly for small firms.

However, we can get close to the data by modifying {eq}`firm_dynam_gb` to
However, we can get closer to the data by modifying {eq}`firm_dynam_gb` to

```{math}
:label: firm_dynam
@@ -136,7 +131,7 @@ We now study the implications of this specification.

#### Heavy tails

If the conditions of the [Kesten--Goldie Theorem](https://python.quantecon.org/kesten_processes.html#the-kestengoldie-theorem)
If the conditions of the {doc}`intermediate:kesten_processes#the-kestengoldie-theorem`
Review comment (Contributor, Author):
@mmcky Is there any suggested approach to handle this?

are satisfied, then {eq}`firm_dynam` implies that the firm size distribution will have Pareto tails.

This matches empirical findings across many data sets.
@@ -154,10 +149,9 @@ In this setting, firm dynamics can be expressed as
(a_{t+1} s_t + b_{t+1}) \mathbb{1}\{s_t \geq \bar s\}
```

The motivation behind and interpretation of [](firm_dynam_ee) can be found in
[our earlier Kesten process lecture](https://python.quantecon.org/kesten_processes.html).
The motivation behind and interpretation of [](firm_dynam_ee) can be found in {doc}`intermediate:kesten_processes`.

What can we say about dynamics?
What can we say about the dynamics?

Although {eq}`firm_dynam_ee` is not a Kesten process, it does update in the
same way as a Kesten process when $s_t$ is large.
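
To make the update rule concrete, here is a minimal single-firm sketch in plain Python. It assumes lognormal shocks with the default parameter values from the `Firm` class in this diff, and it uses the standard library `random` module instead of JAX. The function name `update_firm` is hypothetical.

```python
import random


def update_firm(s, rng, mu_a=-0.5, sigma_a=0.1, mu_b=0.0, sigma_b=0.5,
                mu_e=0.0, sigma_e=0.5, s_bar=1.0):
    """One step of the entry/exit dynamics for a single firm of size s."""
    if s < s_bar:
        # The firm exits and is replaced by an entrant of size e
        return rng.lognormvariate(mu_e, sigma_e)
    # Otherwise the Kesten-style update a * s + b applies
    a = rng.lognormvariate(mu_a, sigma_a)
    b = rng.lognormvariate(mu_b, sigma_b)
    return a * s + b


rng = random.Random(123)
s = 1.0
path = [s]
for _ in range(50):
    s = update_firm(s, rng)
    path.append(s)

print(len(path), min(path) > 0)
```

The branch on `s < s_bar` plays the role of the indicator functions in the equation, and the second branch is exactly the Kesten update that governs large firms.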
@@ -168,8 +162,8 @@ We can investigate this question via simulation and rank-size plots.

The approach will be to

1. generate $M$ draws of $s_T$ when $M$ and $T$ are large and
1. plot the largest 1,000 of the resulting draws in a rank-size plot.
1. generate $M$ draws of $s_T$ when $M$ and $T$ are large, and
2. plot the largest 1,000 of the resulting draws in a rank-size plot.

(The distribution of $s_T$ will be close to the stationary distribution
when $T$ is large.)
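
The second step can be sketched without any plotting library. The helper below is a hypothetical stand-in that loosely mirrors the idea behind a rank-size computation (keep the largest fraction `c` of the draws and pair each with its rank). It is not the QuantEcon implementation, and the Pareto draws are only placeholder data.

```python
import random


def rank_size_sketch(data, c=0.01):
    """Return ranks and sizes for the largest fraction c of the observations."""
    ordered = sorted(data, reverse=True)      # largest first
    n_top = max(1, int(len(ordered) * c))     # number of observations to keep
    sizes = ordered[:n_top]
    ranks = list(range(1, n_top + 1))         # rank 1 is the largest
    return ranks, sizes


rng = random.Random(0)
# Pareto draws as a stand-in for simulated firm sizes
draws = [rng.paretovariate(1.5) for _ in range(100_000)]
ranks, sizes = rank_size_sketch(draws, c=0.01)
print(len(ranks), sizes[0] == max(draws))
```

Plotting `sizes` against `ranks` on log-log axes gives the rank-size plot. For data with a Pareto tail, the points should lie close to a straight line.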
@@ -180,12 +174,12 @@ Here's a class to store parameters:

```{code-cell} ipython3
class Firm(NamedTuple):
μ_a: float = -0.5
σ_a: float = 0.1
μ_b: float = 0.0
σ_b: float = 0.5
μ_e: float = 0.0
σ_e: float = 0.5
μ_a: float = -0.5
σ_a: float = 0.1
μ_b: float = 0.0
σ_b: float = 0.5
μ_e: float = 0.0
σ_e: float = 0.5
s_bar: float = 1.0
```

@@ -204,18 +198,16 @@ Now we write a for loop that repeatedly calls this function, to push a
cross-section of firms forward in time.

For sufficiently large `T`, the cross-section it returns (the cross-section at
time `T`) corresponds to firm size distribution in (approximate) equilibrium.
time `T`) corresponds to the firm size distribution in (approximate) equilibrium.

```{code-cell} ipython3
def generate_cross_section(
firm, M=500_000, T=500, s_init=1.0, seed=123
):
def generate_cross_section(firm, M=500_000, T=500, s_init=1.0, seed=123):

μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm
key = random.PRNGKey(seed)

# Initialize the cross-section to a common value
s = jnp.full((M, ), s_init)
s = jnp.full((M,), s_init)

# Perform updates on s for time t
for t in range(T):
@@ -235,17 +227,15 @@ Let's try running the code and generating a cross-section.

```{code-cell} ipython3
firm = Firm()
tic()
data = generate_cross_section(firm).block_until_ready()
toc()
with qe.Timer():
data = generate_cross_section(firm).block_until_ready()
```

We run the function again so we can see the speed without compile time.

```{code-cell} ipython3
tic()
data = generate_cross_section(firm).block_until_ready()
toc()
with qe.Timer():
data = generate_cross_section(firm).block_until_ready()
```
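
For readers unfamiliar with the `with ...:` timing pattern that replaces `tic()` and `toc()` in this diff, the following is a minimal sketch of how such a context manager can work. `SimpleTimer` is a hypothetical class written for illustration and is not the QuantEcon implementation.

```python
import time


class SimpleTimer:
    """Hypothetical stand-in for a timing context manager."""

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.elapsed = time.perf_counter() - self.start
        print(f"{self.elapsed:.4f} seconds elapsed")
        return False  # do not suppress exceptions raised in the block


with SimpleTimer() as timer:
    total = sum(i * i for i in range(100_000))
```

Because JAX dispatches work asynchronously, the timed block in the lecture calls `.block_until_ready()` so that the timer stops only after the computation has actually finished.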

Let's produce the rank-size plot and check the distribution:
@@ -254,7 +244,7 @@ Let's produce the rank-size plot and check the distribution:
fig, ax = plt.subplots()

rank_data, size_data = qe.rank_size(data, c=0.01)
ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5)
ax.loglog(rank_data, size_data, "o", markersize=3.0, alpha=0.5)
ax.set_xlabel("log rank")
ax.set_ylabel("log size")

@@ -263,36 +253,33 @@ plt.show()

The plot produces a straight line, consistent with a Pareto tail.


#### Alternative implementation with `lax.fori_loop`
#### Alternative implementation with `jax.lax.fori_loop`

Although we JIT-compiled some of the code above,
we did not JIT-compile the `for` loop.

Let's try squeezing out a bit more speed
by

* replacing the `for` loop with [`lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html) and
* replacing the `for` loop with [`jax.lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html), and
* JIT-compiling the whole function.

Here a the `lax.fori_loop` version:
Here is the `jax.lax.fori_loop` version:

```{code-cell} ipython3
@jax.jit
def generate_cross_section_lax(
firm, T=500, M=500_000, s_init=1.0, seed=123
):
def generate_cross_section_lax(firm, T=500, M=500_000, s_init=1.0, seed=123):

μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm
key = random.PRNGKey(seed)

# Initial cross section
s = jnp.full((M, ), s_init)
s = jnp.full((M,), s_init)

def update_cross_section(t, state):
s, key = state
key, *subkeys = jax.random.split(key, 4)
# Generate current random draws
# Generate current random draws
a = μ_a + σ_a * random.normal(subkeys[0], (M,))
b = μ_b + σ_b * random.normal(subkeys[1], (M,))
e = μ_e + σ_e * random.normal(subkeys[2], (M,))
@@ -303,26 +290,22 @@ def generate_cross_section_lax(
new_state = s, key
return new_state

# Use fori_loop
# Use fori_loop
initial_state = s, key
final_s, final_key = lax.fori_loop(
0, T, update_cross_section, initial_state
)
final_s, final_key = jax.lax.fori_loop(0, T, update_cross_section, initial_state)
return final_s
```
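
The semantics of `fori_loop` are easiest to see from a pure Python reference version. The sketch below is only an illustration of those semantics: the name `fori_loop_reference` is hypothetical, and the real JAX function traces and compiles the loop body instead of running a Python loop.

```python
def fori_loop_reference(lower, upper, body_fun, init_val):
    """Pure Python reference for the looping pattern."""
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)   # the state is threaded through each step
    return val


def body(i, state):
    # The state is a tuple, like (s, key) in the function above
    total, count = state
    return total + i, count + 1


total, count = fori_loop_reference(0, 5, body, (0, 0))
print(total, count)
```

This is why `update_cross_section` takes `(t, state)` and returns a new state with the same structure: the pair `(s, key)` is carried from one iteration to the next.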

Let's see if we get any speed gain

```{code-cell} ipython3
tic()
data = generate_cross_section_lax(firm).block_until_ready()
toc()
with qe.Timer():
data = generate_cross_section_lax(firm).block_until_ready()
```

```{code-cell} ipython3
tic()
data = generate_cross_section_lax(firm).block_until_ready()
toc()
with qe.Timer():
data = generate_cross_section_lax(firm).block_until_ready()
```

Here we produce the same rank-size plot:
@@ -331,12 +314,11 @@ Here we produce the same rank-size plot:
fig, ax = plt.subplots()

rank_data, size_data = qe.rank_size(data, c=0.01)
ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5)
ax.loglog(rank_data, size_data, "o", markersize=3.0, alpha=0.5)
ax.set_xlabel("log rank")
ax.set_ylabel("log size")

plt.show()

```

## Exercises
@@ -362,46 +344,42 @@ What are the pros and cons of this approach?

```{code-cell} ipython3
@jax.jit
def generate_cross_section_lax(
firm, T=500, M=500_000, s_init=1.0, seed=123
):
def generate_cross_section_lax(firm, T=500, M=500_000, s_init=1.0, seed=123):

μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm
key = random.PRNGKey(seed)
subkey_1, subkey_2, subkey_3 = random.split(key, 3)
# Generate entire sequence of random draws

# Generate entire sequence of random draws
a = μ_a + σ_a * random.normal(subkey_1, (T, M))
b = μ_b + σ_b * random.normal(subkey_2, (T, M))
e = μ_e + σ_e * random.normal(subkey_3, (T, M))
# Exponentiate them
a, b, e = jax.tree.map(jnp.exp, (a, b, e))
# Initial cross section
s = jnp.full((M, ), s_init)
s = jnp.full((M,), s_init)

def update_cross_section(t, s):
# Pull out the t-th cross-section of shocks
a_t, b_t, e_t = a[t], b[t], e[t]
s = jnp.where(s < s_bar, e_t, a_t * s + b_t)
return s

# Use lax.scan to perform the calculations on all states
s_final = lax.fori_loop(0, T, update_cross_section, s)
# Use lax.fori_loop to perform the calculations on all states
s_final = jax.lax.fori_loop(0, T, update_cross_section, s)
return s_final
```

Here are the run times.

```{code-cell} ipython3
tic()
data = generate_cross_section_lax(firm).block_until_ready()
toc()
with qe.Timer():
data = generate_cross_section_lax(firm).block_until_ready()
```

```{code-cell} ipython3
tic()
data = generate_cross_section_lax(firm).block_until_ready()
toc()
with qe.Timer():
data = generate_cross_section_lax(firm).block_until_ready()
```

This method might or might not be faster.
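
One cost of pre-generating all the shocks is memory. The back-of-the-envelope calculation below assumes 32-bit floats (the JAX default unless 64-bit mode is enabled) and counts only the three final shock arrays, so it should be read as a rough lower bound and not as a measurement.

```python
T, M = 500, 500_000
n_arrays = 3            # a, b and e
bytes_per_float = 4     # assumption: float32

# All shocks generated up front: three arrays of shape (T, M)
bytes_pregenerated = n_arrays * T * M * bytes_per_float

# Shocks generated inside the loop: three arrays of shape (M,) per step
bytes_per_step = n_arrays * M * bytes_per_float

print(f"pre-generated: {bytes_pregenerated / 1e9:.1f} GB")
print(f"per step:      {bytes_per_step / 1e6:.1f} MB")
```

Under these assumptions the pre-generated shocks occupy about 3 GB, compared with about 6 MB per step when the draws are made inside the loop, which is one reason the speed comparison can go either way depending on the hardware.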