Intro to JAX#
Author: Nathaniel Starkman (MIT, starkman@mit.edu)
What JAX brings to the table:
- Python
- Familiar (numpy) API
- JIT = speed
- GPU support
- Sharding
- Auto-differentiation
For science we often want to work with float64, not float32. We can set this permanently with an environment variable (see the JAX docs) or with a configuration call at import time.
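For reference, the environment-variable route is set before JAX is first imported; here is a sketch, equivalent to the config call in the next cell:

# Sketch: enable float64 via the JAX_ENABLE_X64 environment variable.
# This must be set before `import jax` runs for the first time.
import os
os.environ["JAX_ENABLE_X64"] = "1"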
[1]:
import jax
jax.config.update("jax_enable_x64", True)
We are going to be explicit about dtypes and shapes. For this we will use the very popular jaxtyping library.
[2]:
from jaxtyping import Float, Array, Shaped
NumPy#
(adapted from Quickstart and Sharp Bits)
Mostly you need to replace numpy with:
[3]:
import jax.numpy as jnp
Now you can write familiar functions with JAX
[4]:
def sum_of_squares(
    x: Shaped[Array, "N"], /, *, axis: int | None = None,
) -> Shaped[Array, ""]:
    return jnp.sum(jnp.square(x), axis=axis)

sum_of_squares(jnp.array([1., 2, 3, 4, 5]))
[4]:
Array(55., dtype=float64)
There are well-known "Sharp Bits" to JAX. Let's highlight one here.
[5]:
jax_array = jnp.zeros((3, 3), dtype=jnp.float32)

# An in-place update of a JAX array will yield an error!
try:
    jax_array[1, :] = 1.0
except Exception as e:
    print(e)
JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html
JAX arrays are immutable!
From many perspectives this is actually really nice (it's safer for operation tracing and harder to make in-place update mistakes). However, two points against it are:
- As a NumPy user I'm not used to this!
- Aren't out-of-place updates slower?
Yes, the first point is true. The second question is more subtle: an out-of-place update can be slower the first time the program runs, but if the operation is inside a JIT then the out-of-place updates are fused into highly optimized in-place updates that are actually faster!
So if we build familiarity with JAX's out-of-place updates, it's kind of better in every way.
[6]:
new_array = jax_array.at[1, :].set(1.0)
print("new array:\n", new_array)
new array:
[[0. 0. 0.]
[1. 1. 1.]
[0. 0. 0.]]
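As a quick illustration of the fusion claim above, here is a minimal sketch (the function name set_row is just for illustration): wrapping the out-of-place .at update in jit lets XLA optimize it, typically avoiding an extra copy.

# Sketch: an out-of-place .at update wrapped in jit.
# Under jit, XLA can usually perform this without materializing a full copy of `a`.
@jax.jit
def set_row(a, value):
    return a.at[1, :].set(value)

set_row(jax_array, 1.0)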
Vectorization#
Adapted from automatic-vectorization
[7]:
def convolve(x: Float[Array, "N"], w: Float[Array, "3"]) -> Float[Array, "N-2"]:
    return jnp.array([jnp.dot(x[i-1:i+2], w)
                      for i in range(1, len(x)-1)])
[8]:
x = jnp.arange(6, dtype=float)
w = jnp.array([2., 3., 4.])
[9]:
convolve(x, w)
[9]:
Array([11., 20., 29., 38.], dtype=float64)
[10]:
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])
xs.shape
[10]:
(2, 6)
[11]:
convolve(xs, ws) # Not what we want!
[11]:
Array([], shape=(0,), dtype=float64)
How do we vectorize operations in JAX?
Some operations are already vectorized
[12]:
jnp.cos(xs)
[12]:
Array([[ 1. , 0.54030231, -0.41614684, -0.9899925 , -0.65364362,
0.28366219],
[ 1. , 0.54030231, -0.41614684, -0.9899925 , -0.65364362,
0.28366219]], dtype=float64)
But for those that aren't, there's…
[13]:
vmap_convolve = jax.vmap(convolve)
[14]:
vmap_convolve(xs, ws)
[14]:
Array([[11., 20., 29., 38.],
[11., 20., 29., 38.]], dtype=float64)
The only problem is that vmap only vectorizes over pre-specified axes.
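Concretely, the mapped axes are fixed when the vmapped function is built, via the in_axes argument (axis 0 of each input by default); a quick sketch:

# Sketch: vmap's mapped axes are chosen up front with in_axes.
# This maps over axis 0 of both arguments (the default behaviour).
vmap_convolve_axis0 = jax.vmap(convolve, in_axes=(0, 0))
vmap_convolve_axis0(xs, ws)  # same result as vmap_convolve(xs, ws)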
[15]:
vmap_convolve(xs[None], ws[None])
[15]:
Array([], shape=(1, 0), dtype=float64)
That's why there's jax.numpy.vectorize. This is like vmap but requires less a priori knowledge of array shapes. JAX doesn't really advertise jax.numpy.vectorize as a vmap alternative, but it's very convenient!
[16]:
vec_convolve = jnp.vectorize(convolve, signature="(n),(w)->(m)")
[17]:
vec_convolve(xs[None], ws[None])
[17]:
Array([[[11., 20., 29., 38.],
[11., 20., 29., 38.]]], dtype=float64)
JIT#
[18]:
%timeit convolve(x, w)
237 μs ± 1.93 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
[19]:
jit_convolve = jax.jit(convolve)
[20]:
jit_convolve(x, w) # trigger jit
%timeit jit_convolve(x, w).block_until_ready()
4.47 μs ± 55.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
A ~50x speedup! 😱
And that's on a CPU. GPU is even better.
What about the vectorized version?
[21]:
%timeit vec_convolve(x, w)
322 μs ± 8.86 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
[22]:
jit_vec_convolve = jax.jit(vec_convolve)
[23]:
jit_vec_convolve(x, w) # trigger jit
%timeit jit_vec_convolve(x, w).block_until_ready()
4.32 μs ± 33.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
🥳 We get highly general vectorization basically for free.
Differentiation#
[24]:
jax.grad, jax.jacfwd, jax.jacrev
[24]:
(<function jax._src.api.grad(fun: 'Callable', argnums: 'int | Sequence[int]' = 0, has_aux: 'bool' = False, holomorphic: 'bool' = False, allow_int: 'bool' = False, reduce_axes: 'Sequence[AxisName]' = ()) -> 'Callable'>,
<function jax._src.api.jacfwd(fun: 'Callable', argnums: 'int | Sequence[int]' = 0, has_aux: 'bool' = False, holomorphic: 'bool' = False) -> 'Callable'>,
<function jax._src.api.jacrev(fun: 'Callable', argnums: 'int | Sequence[int]' = 0, has_aux: 'bool' = False, holomorphic: 'bool' = False, allow_int: 'bool' = False) -> 'Callable'>)
Jacobians are easy. The output is a matrix of the derivatives of each output with respect to each input.
[25]:
jac_fn = jax.jacobian(convolve)
jac_fn(x, w)
[25]:
Array([[2., 3., 4., 0., 0., 0.],
[0., 2., 3., 4., 0., 0.],
[0., 0., 2., 3., 4., 0.],
[0., 0., 0., 2., 3., 4.]], dtype=float64)
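Both forward-mode (jax.jacfwd) and reverse-mode (jax.jacrev) Jacobians are available; for this small function they agree (a quick sanity check, not a performance comparison):

# Sketch: forward- and reverse-mode Jacobians of convolve match.
jnp.allclose(jax.jacfwd(convolve)(x, w), jax.jacrev(convolve)(x, w))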
Gradients are a little trickier, since jax.grad requires a scalar-valued function.
[26]:
try:
    jax.grad(convolve)(x, w)
except Exception as e:
    print(e)
Gradient only defined for scalar-output functions. Output had shape: (4,).
[27]:
# this is one not-great way
func = lambda x, w: jnp.asarray([
    jax.grad(
        lambda x, w: convolve(x, w)[0]  # scalar output
    )(x[i:i+3], w)
    for i in range(len(x) - 2)
])
func(x, w)
[27]:
Array([[2., 3., 4.],
[2., 3., 4.],
[2., 3., 4.],
[2., 3., 4.]], dtype=float64)
Note that this hack to produce scalar outputs basically just reproduced the non-zero band of each row of the Jacobian. grad is intended for scalar-valued functions, which this example was not.
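A more idiomatic pattern (just a sketch) is to reduce the output to a scalar and differentiate that; the gradient is then the column sums of the Jacobian computed above:

# Sketch: grad of a scalar reduction of the output.
# d(sum(convolve(x, w)))/dx equals jac_fn(x, w).sum(axis=0).
grad_of_sum = jax.grad(lambda x, w: jnp.sum(convolve(x, w)))
grad_of_sum(x, w)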
PyTrees#
Adapted from https://docs.kidger.site/equinox/all-of-equinox/
PyTrees are what JAX calls nested collections. JAX has built-in support for tuples, lists, and dicts, but can also support any custom type, if properly registered. PyTrees can be built out of pretty much anything: JAX/NumPy arrays, floats, functions, etc.
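As a sketch of what "properly registered" means (the Params class and its fields below are made up for illustration):

# Sketch: register a custom container so JAX treats it as a PyTree.
class Params:
    def __init__(self, weights, bias):
        self.weights = weights
        self.bias = bias

jax.tree_util.register_pytree_node(
    Params,
    lambda p: ((p.weights, p.bias), None),    # flatten: children + aux data
    lambda aux, children: Params(*children),  # unflatten from children
)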
Many JAX operations will accept either:
- arbitrary PyTrees;
- PyTrees with just JAX/NumPy arrays as the leaves;
- PyTrees without any JAX/NumPy arrays as the leaves.
Functions in jax.numpy just need to be tree_map'ed over.
[28]:
# Example pytree
pytree = {
    'a': jnp.array([0.0, jnp.pi / 2, jnp.pi]),
    'b': [jnp.array([jnp.pi / 4, jnp.pi / 3]), jnp.array([jnp.pi / 6])],
    'c': (jnp.array([2 * jnp.pi]), jnp.array([3 * jnp.pi])),
}

# Apply jnp.cos to each element in the pytree
cos_pytree = jax.tree.map(jnp.cos, pytree)
cos_pytree
[28]:
{'a': Array([ 1.000000e+00, 6.123234e-17, -1.000000e+00], dtype=float64),
'b': [Array([0.70710678, 0.5 ], dtype=float64),
Array([0.8660254], dtype=float64)],
'c': (Array([1.], dtype=float64), Array([-1.], dtype=float64))}
JaxTyping + BearType#
This is not strictly JAX, but I think it's worth showing. jaxtyping offers integrations with runtime type checkers. Normally runtime type checking is slow, but with JIT it can be fast!
We are going to do this very explicitly, but there are less verbose ways to apply it. See install_import_hook (https://docs.kidger.site/jaxtyping/api/runtime-type-checking/#jaxtyping.install_import_hook) to apply type-checking to an entire module, or the IPython extension to do this in notebooks.
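A sketch of the import-hook route (my_package is a hypothetical module name):

# Sketch: apply beartype-backed jaxtyping checks to a whole package.
from jaxtyping import install_import_hook

with install_import_hook("my_package", "beartype.beartype"):
    import my_package  # hypothetical; all annotated functions get checked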
[29]:
from jaxtyping import jaxtyped # explicit
from beartype import beartype as typechecker
[30]:
@jax.jit
@jaxtyped(typechecker=typechecker)  # explicit
def checked_convolve(x: Float[Array, "N"], w: Float[Array, "3"]) -> Float[Array, "N-2"]:
    return jnp.array([jnp.dot(x[i-1:i+2], w)
                      for i in range(1, len(x)-1)])
[31]:
try:
    checked_convolve(jnp.array([1, 2, 3, 4, 5]), jnp.array([2., 3, 4]))
except Exception as e:
    print(e)
Type-check error whilst checking the parameters of __main__.checked_convolve.
The problem arose whilst typechecking parameter 'x'.
Actual value: i64[5]
Expected type: <class 'Float[Array, 'N']'>.
----------------------
Called with parameters: {'x': i64[5], 'w': f64[3]}
Parameter annotations: (x: Float[Array, 'N'], w: Float[Array, '3']) -> Any.
[32]:
checked_convolve(jnp.array([1., 2, 3, 4, 5]), jnp.array([2., 3, 4]))
[32]:
Array([20., 29., 38.], dtype=float64)
[34]:
%timeit checked_convolve(x, w).block_until_ready()
4.47 μs ± 26 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Runtime type checking with (almost) no performance penalty? 😱
Warning: this can incur a large penalty at trace/compile time, so check your speeds. Runtime type checking is easy to turn off.
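For example, if your jaxtyping version supports it, the JAXTYPING_DISABLE environment variable switches the runtime checks off globally (a sketch; check the jaxtyping docs for your version):

# Sketch: disable jaxtyping's runtime checks via an environment variable.
# Set this before the decorated functions are imported/traced.
import os
os.environ["JAXTYPING_DISABLE"] = "1"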
Recap#
- jax.numpy
- vectorization
- jit
- differentiation
- PyTrees
- Runtime type-checking