Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions .github/workflows/cache.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,30 @@ on:
schedule:
# Execute cache weekly at 3am on Monday
- cron: '0 3 * * 1'
workflow_dispatch:
jobs:
cache:
runs-on: quantecon-gpu
container:
image: ghcr.io/quantecon/lecture-python-container:cuda-12.6.0-anaconda-2024-10-py312-b
options: --gpus all
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=ubuntu24-gpu-x64/disk=large"
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Setup Anaconda
uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
auto-activate-base: true
miniconda-version: 'latest'
python-version: "3.12"
environment-file: environment.yml
activate-environment: quantecon
- name: Install jax (and install checks for GPU)
shell: bash -l {0}
run: |
pip install -U "jax[cuda12]"
python --version
python scripts/test-jax-install.py
nvidia-smi
- name: Build HTML
shell: bash -l {0}
run: |
Expand Down
33 changes: 26 additions & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,41 @@ name: Build Preview [using jupyter-book]
on: [pull_request]
jobs:
preview:
runs-on: quantecon-gpu
container:
image: ghcr.io/quantecon/lecture-python-container:cuda-12.6.0-anaconda-2024-10-py312-b
options: --gpus all
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=ubuntu24-gpu-x64/disk=large"
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Check nvidia drivers
- name: Setup Anaconda
uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
auto-activate-base: true
miniconda-version: 'latest'
python-version: "3.12"
environment-file: environment.yml
activate-environment: quantecon
- name: Install jax (and install checks for GPU)
shell: bash -l {0}
run: |
pip install -U "jax[cuda12]"
python --version
python scripts/test-jax-install.py
nvidia-smi
- name: Check python version
- name: Install latex dependencies
shell: bash -l {0}
run: |
python --version
sudo apt-get -qq update
sudo apt-get install -y \
texlive-latex-recommended \
texlive-latex-extra \
texlive-fonts-recommended \
texlive-fonts-extra \
texlive-xetex \
latexmk \
xindy \
dvipng \
cm-super
- name: Display Conda Environment Versions
shell: bash -l {0}
run: conda list
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/collab.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Build Project on Google Collab (Execution)
on: [pull_request]
jobs:
execution-checks:
runs-on: quantecon-gpu
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=ubuntu24-gpu-x64/disk=large"
container:
image: docker://us-docker.pkg.dev/colab-images/public/runtime
options: --gpus all
Expand Down
34 changes: 28 additions & 6 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,40 @@ on:
jobs:
publish:
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
runs-on: quantecon-gpu
container:
image: ghcr.io/quantecon/lecture-python-container:cuda-12.6.0-anaconda-2024-10-py312-b
options: --gpus all
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=ubuntu24-gpu-x64/disk=large"
steps:
- name: Checkout
uses: actions/checkout@v4
# Check nvidia-smi
- name: Check nvidia drivers
- name: Setup Anaconda
uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
auto-activate-base: true
miniconda-version: 'latest'
python-version: "3.12"
environment-file: environment.yml
activate-environment: quantecon
- name: Install jax (and install checks for GPU)
shell: bash -l {0}
run: |
pip install -U "jax[cuda12]"
python --version
python scripts/test-jax-install.py
nvidia-smi
- name: Install latex dependencies
shell: bash -l {0}
run: |
sudo apt-get -qq update
sudo apt-get install -y \
texlive-latex-recommended \
texlive-latex-extra \
texlive-fonts-recommended \
texlive-fonts-extra \
texlive-xetex \
latexmk \
xindy \
dvipng \
cm-super
- name: Display Conda Environment Versions
shell: bash -l {0}
run: conda list
Expand Down
21 changes: 21 additions & 0 deletions scripts/test-jax-install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import jax
import jax.numpy as jnp

devices = jax.devices()
print(f"The available devices are: {devices}")

@jax.jit
def matrix_multiply(a, b):
return jnp.dot(a, b)

# Example usage:
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000, 1000))
y = jax.random.normal(key, (1000, 1000))
z = matrix_multiply(x, y)

# Now the function is JIT compiled and will likely run on GPU (if available)
print(z)

devices = jax.devices()
print(f"The available devices are: {devices}")
Loading