Skip to content

Commit 2fd5aa7

Browse files
committed
minor update
1 parent cb78ad4 commit 2fd5aa7

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

lectures/jax_intro.md

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,9 @@ a, a_new
172172
The designers of JAX chose to make arrays immutable because JAX uses a
173173
functional programming style. More on this below.
174174

175-
However, JAX provides a functionally pure equivalent of in-place array modifications.
175+
However, JAX provides a functionally pure equivalent of in-place array modification
176+
using the [`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html).
176177

177-
To assign a new value to an element of a JAX array, we can use the `at` method
178178

179179
```{code-cell} ipython3
180180
a = jnp.linspace(0, 1, 3)
@@ -185,16 +185,15 @@ id(a)
185185
a
186186
```
187187

188-
We can see that the array `a` is changed by using the
189-
[`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html).
190-
191-
It returns a new copy of `a` with the specified element changed.
188+
Applying `at[0].set(1)`, we can see that a new copy of `a` with the first element
189+
set to 1 is returned
192190

193191
```{code-cell} ipython3
194192
a = a.at[0].set(1)
193+
a
195194
```
196195

197-
Inspecting the identifier of `a` shows that it has changed
196+
Inspecting the identifier of `a` shows that it has been reassigned
198197

199198
```{code-cell} ipython3
200199
id(a)

0 commit comments

Comments
 (0)