Skip to content

Commit f7c6eb5

Browse files
authored
[FIX] Fix descriptions on .at method (#218)
* update jax intro lecture * minor update * minor update * address comment
1 parent 5ae1a02 commit f7c6eb5

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

lectures/jax_intro.md

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,9 @@ a, a_new
172172
The designers of JAX chose to make arrays immutable because JAX uses a
173173
functional programming style. More on this below.
174174

175-
Note that, while mutation is discouraged, it is in fact possible with `at`, as in
175+
However, JAX provides a functionally pure equivalent of in-place array modification
176+
using the [`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html).
177+
176178

177179
```{code-cell} ipython3
178180
a = jnp.linspace(0, 1, 3)
@@ -183,11 +185,14 @@ id(a)
183185
a
184186
```
185187

188+
Applying `at[0].set(1)` returns a new copy of `a` with the first element set to 1
189+
186190
```{code-cell} ipython3
187-
a.at[0].set(1)
191+
a = a.at[0].set(1)
192+
a
188193
```
189194

190-
We can check that the array is mutated by verifying its identity is unchanged:
195+
Inspecting the identifier of `a` shows that it has been reassigned
191196

192197
```{code-cell} ipython3
193198
id(a)

0 commit comments

Comments
 (0)