Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
aff3c1d
TST: runs-on integration
mmcky Apr 25, 2025
c978ea6
fix syntax error
mmcky Apr 25, 2025
93c0dc1
switch to non-container workflow
mmcky Apr 25, 2025
129c32d
add latex
mmcky Apr 25, 2025
0b817d3
add numpyro and jax
mmcky Apr 25, 2025
f11f96e
create more space on Ubuntu
mmcky Apr 25, 2025
c579e6d
try larger runner disk size setting
mmcky Apr 25, 2025
f492bbf
change 80 to large?
mmcky Apr 25, 2025
7cf5b01
specify larger disk on ci.yml
mmcky Apr 25, 2025
e3f132f
enable sudo pip
mmcky Apr 25, 2025
01f0ac8
Install jax via conda
mmcky Apr 26, 2025
af626ca
accept install -y
mmcky Apr 26, 2025
5278cad
revert to pip
mmcky Apr 26, 2025
336fbd1
use local CUDA library
mmcky Apr 27, 2025
a07cb48
test jax install
mmcky Apr 27, 2025
3881c0d
install using anaconda/pip
mmcky Apr 27, 2025
05fcb5d
use anaconda pip
mmcky Apr 27, 2025
540b3a1
ensure anaconda python is being used
mmcky Apr 27, 2025
aae09d5
update anaconda path
mmcky Apr 27, 2025
dd969b2
upgrade install of jax
mmcky Apr 27, 2025
65e76b4
tmp: disable cache
mmcky Apr 27, 2025
831592d
update JAX install of pip
mmcky Apr 28, 2025
2c2b83a
ensure path to software is set
mmcky Apr 28, 2025
9bbcd78
add more debug testing
mmcky Apr 28, 2025
dd15278
fix syntax issue
mmcky Apr 28, 2025
8dc00f9
restore cache for quicker debug
mmcky Apr 28, 2025
007187b
change status.md
mmcky Apr 28, 2025
712ef7f
tst: disable cache and add further jax test
mmcky Apr 28, 2025
d98813a
execution test on g4dn.2xlarge
mmcky Apr 28, 2025
74ba7c8
full build with all elements
mmcky Apr 29, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/runs-on.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
runners:
larger-disk:
disk: large
54 changes: 39 additions & 15 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,56 @@ name: Build Preview [using jupyter-book]
on: [pull_request]
jobs:
preview:
runs-on: quantecon-gpu
container:
image: ghcr.io/quantecon/lecture-python-container:cuda-12.6.0-anaconda-2024-10-py312-b
options: --gpus all
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=ubuntu24-gpu-x64/disk=large"
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Check nvidia drivers
- name: Setup Anaconda
uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
auto-activate-base: true
miniconda-version: 'latest'
python-version: "3.12"
environment-file: environment.yml
activate-environment: quantecon
- name: Install jax and jaxlib
shell: bash -l {0}
run: |
nvidia-smi
- name: Check python version
pip install -U "jax[cuda12]"
python --version
python test-jax-install.py
- name: Install latex dependencies
run: |
sudo apt-get -qq update
sudo apt-get install -y \
texlive-latex-recommended \
texlive-latex-extra \
texlive-fonts-recommended \
texlive-fonts-extra \
texlive-xetex \
latexmk \
xindy \
dvipng \
cm-super
- name: Check nvidia drivers
shell: bash -l {0}
run: |
python --version
nvidia-smi
- name: Display Conda Environment Versions
shell: bash -l {0}
run: conda list
- name: Display Pip Versions
shell: bash -l {0}
run: pip list
- name: Download "build" folder (cache)
uses: dawidd6/action-download-artifact@v3
with:
workflow: cache.yml
branch: main
name: build-cache
path: _build
# - name: Download "build" folder (cache)
# uses: dawidd6/action-download-artifact@v3
# with:
# workflow: cache.yml
# branch: main
# name: build-cache
# path: _build
# Build Assets (Download Notebooks and PDF via LaTeX)
- name: Build Download Notebooks (sphinx-tojupyter)
shell: bash -l {0}
Expand All @@ -48,6 +69,9 @@ jobs:
- name: Build HTML
shell: bash -l {0}
run: |
echo $PATH
python --version
python test-jax-install.py
jb build lectures --path-output ./ -n -W --keep-going
- name: Upload build folder
uses: actions/upload-artifact@v4
Expand Down
3 changes: 1 addition & 2 deletions lectures/status.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ You can check the backend used by JAX using:

```{code-cell} ipython3
import jax
# Check if JAX is using GPU
print(f"JAX backend: {jax.devices()[0].platform}")
print(f"JAX backend: {jax.devices()[0].platform}") # Check if JAX is using GPU
```

and the hardware we are running on:
Expand Down
21 changes: 21 additions & 0 deletions test-jax-install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import jax
import jax.numpy as jnp

devices = jax.devices()
print(f"The available devices are: {devices}")

@jax.jit
def matrix_multiply(a, b):
return jnp.dot(a, b)

# Example usage:
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000, 1000))
y = jax.random.normal(key, (1000, 1000))
z = matrix_multiply(x, y)

# Now the function is JIT compiled and will likely run on GPU (if available)
print(z)

devices = jax.devices()
print(f"The available devices are: {devices}")
Loading