Skip to content

Commit 3021391

Browse files
jstacmmcky
andauthored
Kesten edits (#219)
* misc * @mmcky edits * enable dropdown * upgrade CUDANN to 9.10.2 --------- Co-authored-by: mmcky <mamckay@gmail.com>
1 parent 8ff32f3 commit 3021391

File tree

2 files changed

+44
-24
lines changed

2 files changed

+44
-24
lines changed

.github/workflows/ci.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ jobs:
77
- uses: actions/checkout@v4
88
with:
99
ref: ${{ github.event.pull_request.head.sha }}
10+
- name: Upgrade CUDANN
11+
shell: bash -l {0}
12+
run: |
13+
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
14+
sudo dpkg -i cuda-keyring_1.1-1_all.deb
15+
sudo apt-get update
16+
sudo apt-get -y install cudnn-cuda-12
1017
- name: Setup Anaconda
1118
uses: conda-incubator/setup-miniconda@v3
1219
with:

lectures/kesten_processes.md

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,8 @@ In addition to JAX and Anaconda, this lecture will need the following libraries:
4040
This lecture describes Kesten processes, which are an important class of
4141
stochastic processes, and an application of firm dynamics.
4242

43-
The lecture draws on [an earlier QuantEcon
44-
lecture](https://python.quantecon.org/kesten_processes.html), which uses Numba
45-
to accelerate the computations.
43+
The lecture draws on [an earlier QuantEcon lecture](https://python.quantecon.org/kesten_processes.html),
44+
which uses Numba to accelerate the computations.
4645

4746
In that earlier lecture you can find a more detailed discussion of the concepts involved.
4847

@@ -137,10 +136,8 @@ We now study the implications of this specification.
137136

138137
#### Heavy tails
139138

140-
If the conditions of the [Kesten--Goldie
141-
Theorem](https://python.quantecon.org/kesten_processes.html#the-kestengoldie-theorem)
142-
are satisfied, then {eq}`firm_dynam` implies that the firm size distribution
143-
will have Pareto tails.
139+
If the conditions of the [Kesten--Goldie Theorem](https://python.quantecon.org/kesten_processes.html#the-kestengoldie-theorem)
140+
are satisfied, then {eq}`firm_dynam` implies that the firm size distribution will have Pareto tails.
144141

145142
This matches empirical findings across many data sets.
146143

@@ -190,12 +187,11 @@ class Firm(NamedTuple):
190187
μ_e: float = 0.0
191188
σ_e: float = 0.5
192189
s_bar: float = 1.0
193-
194-
#
195-
# Here's code to update a cross-section of firms according to the dynamics in
196-
# [](firm_dynam_ee).
197190
```
198191

192+
Here's code to update a cross-section of firms according to the dynamics in
193+
[](firm_dynam_ee).
194+
199195
```{code-cell} ipython3
200196
@jax.jit
201197
def update_cross_section(s, a, b, e, firm):
@@ -250,7 +246,6 @@ data = generate_cross_section(firm).block_until_ready()
250246
toc()
251247
```
252248

253-
254249
Let's produce the rank-size plot and check the distribution:
255250

256251
```{code-cell} ipython3
@@ -271,7 +266,7 @@ The plot produces a straight line, consistent with a Pareto tail.
271266

272267
We did not JIT-compile the `for` loop above because
273268
acceleration of outer loops makes relatively little difference terms of
274-
compute time.
269+
compute time.
275270

276271
However, to maximize performance, let's try squeezing out a bit more speed
277272
by replacing the `for` loop with
@@ -311,10 +306,10 @@ def generate_cross_section_lax(
311306
0, T, update_cross_section, initial_state
312307
)
313308
return final_s
314-
315-
# Let's see if we got any speed gain
316309
```
317310

311+
Let's see if we get any speed gain
312+
318313
```{code-cell} ipython3
319314
tic()
320315
data = generate_cross_section_lax(firm).block_until_ready()
@@ -339,14 +334,27 @@ ax.set_ylabel("log size")
339334
340335
plt.show()
341336
342-
#
343-
# If the time horizon is not too large, we can also try generating all shocks at
344-
# once.
345-
#
346-
# Note, however, that this approach consumes more memory, as we need to have to
347-
# store large matrices of random draws
348-
#
349-
# Hence the code below will fail due to out-of-memory errors when `T` and `M` are large.
337+
```
338+
339+
## Exercises
340+
341+
```{exercise-start}
342+
:label: kp_ex1
343+
```
344+
345+
Try writing an alternative version of `generate_cross_section_lax()` where the entire sequence of random draws is generated at once, so that all of `a`, `b`, and `e` are of shape `(T, M)`.
346+
347+
(The `update_cross_section()` function should not generate any random numbers.)
348+
349+
Does it improve the runtime?
350+
351+
What are the pros and cons of this approach.
352+
353+
```{exercise-end}
354+
```
355+
356+
```{solution-start} kp_ex1
357+
:class: dropdown
350358
```
351359

352360
```{code-cell} ipython3
@@ -393,6 +401,11 @@ data = generate_cross_section_lax(firm).block_until_ready()
393401
toc()
394402
```
395403

396-
This second method might be slightly faster in some cases but in general the
404+
This method might be faster in some cases but in general the
397405
relative speed will depend on the size of the cross-section and the length of
398406
the simulation paths.
407+
408+
Also, this method is far more memory intensive.
409+
410+
```{solution-end}
411+
```

0 commit comments

Comments
 (0)