Intro to JAX#

Author: Nathaniel Starkman (MIT, starkman@mit.edu)

What JAX brings to the table:

  1. Python

  2. Familiar (numpy) API

  3. JIT = Speed

  4. GPU support

  5. Sharding

  6. Auto-differentiation

For science we often want to work with float64, not float32. We can set this permanently with an environment variable (see JAX docs) or with a configuration at import.

[1]:
import jax
jax.config.update("jax_enable_x64", True)
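
For reference, the environment-variable route looks like this (a sketch; the variable must be set before JAX is first imported):

[ ]:
# in the shell: export JAX_ENABLE_X64=True
# or from Python, before `import jax`:
import os
os.environ["JAX_ENABLE_X64"] = "True"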

We are going to be explicit about dtypes and shapes. For this we will use the very popular jaxtyping library.

[2]:
from jaxtyping import Float, Array, Shaped

NumPy#

(adapted from Quickstart and Sharp Bits)

Mostly you need to replace numpy with

[3]:
import jax.numpy as jnp

Now you can write familiar functions with JAX

[4]:
def sum_of_squares(
    x: Shaped[Array, "N"], /, *, axis: int | None = None,
) -> Shaped[Array, ""]:
    return jnp.sum(jnp.square(x), axis=axis)

sum_of_squares(jnp.array([1., 2, 3, 4, 5]))
[4]:
Array(55., dtype=float64)

There are well-known “Sharp Bits” to JAX. Let’s highlight one here.

[5]:
jax_array = jnp.zeros((3,3), dtype=jnp.float32)

# In place update of JAX's array will yield an error!
try:
    jax_array[1, :] = 1.0
except Exception as e:
    print(e)
JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html

JAX arrays are immutable!

From many perspectives this is actually really nice (it’s safer for operation tracing and it’s harder to make in-place update mistakes). However, two points against it are:

  1. As a NumPy user I’m not used to this!

  2. Aren’t out-of-place updates slower?

Yes, the first point is true. The second point is more subtle: out-of-place updates can be slower the first time a program runs eagerly, but inside a JIT-compiled function they are fused into highly optimized in-place updates that are actually faster!

So if we build familiarity with JAX’s out-of-place updates, then it’s kind of better in every way.

[6]:
new_array = jax_array.at[1, :].set(1.0)
print("new array:\n", new_array)
new array:
 [[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]
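
To make the JIT point above concrete, here is a minimal sketch (set_row is a hypothetical helper, not part of the original notebook). The update is out-of-place in the Python source, but under jax.jit the compiler is free to reuse the buffer rather than materialize a copy.

[ ]:
@jax.jit
def set_row(a, value):
    # written as an out-of-place update; XLA may perform it in place
    return a.at[1, :].set(value)

set_row(jax_array, 1.0)  # same result as new_array above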

Vectorization#

Adapted from automatic-vectorization

[7]:
def convolve(x: Float[Array, "N"], w: Float[Array, "3"]) -> Float[Array, "N-2"]:
  return jnp.array([jnp.dot(x[i-1:i+2], w)
                    for i in range(1, len(x)-1)])
[8]:
x = jnp.arange(6, dtype=float)
w = jnp.array([2., 3., 4.])
[9]:
convolve(x, w)
[9]:
Array([11., 20., 29., 38.], dtype=float64)
[10]:
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

xs.shape
[10]:
(2, 6)
[11]:
convolve(xs, ws)  # Not what we want!
[11]:
Array([], shape=(0,), dtype=float64)

The batched call silently returned an empty array because len(xs) is now the batch size (2), so the loop inside convolve never runs. How do we vectorize operations in JAX?

Some operations are already vectorized

[12]:
jnp.cos(xs)
[12]:
Array([[ 1.        ,  0.54030231, -0.41614684, -0.9899925 , -0.65364362,
         0.28366219],
       [ 1.        ,  0.54030231, -0.41614684, -0.9899925 , -0.65364362,
         0.28366219]], dtype=float64)

But for those that aren’t, there’s…

[13]:
vmap_convolve = jax.vmap(convolve)
[14]:
vmap_convolve(xs, ws)
[14]:
Array([[11., 20., 29., 38.],
       [11., 20., 29., 38.]], dtype=float64)

The only problem is that vmap only vectorizes over pre-specified axes.
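
For example, vmap’s in_axes argument specifies which axis of each argument to map over (None means “do not map this argument”). A quick sketch reusing the arrays from above:

[ ]:
# map over the rows of xs while reusing the same w for every row
row_convolve = jax.vmap(convolve, in_axes=(0, None))
row_convolve(xs, w)  # shape (2, 4)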

[15]:
vmap_convolve(xs[None], ws[None])
[15]:
Array([], shape=(1, 0), dtype=float64)

That’s why there’s jax.numpy.vectorize. This is like vmap but requires less a priori knowledge of array shapes.

JAX doesn’t really advertise jax.numpy.vectorize as a vmap alternative, but it’s very convenient!

[16]:
vec_convolve = jnp.vectorize(convolve, signature="(n),(w)->(m)")
[17]:
vec_convolve(xs[None], ws[None])
[17]:
Array([[[11., 20., 29., 38.],
        [11., 20., 29., 38.]]], dtype=float64)

JIT#

[18]:
%timeit convolve(x, w)
237 μs ± 1.93 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
[19]:
jit_convolve = jax.jit(convolve)
[20]:
jit_convolve(x, w)  # trigger jit

%timeit jit_convolve(x, w).block_until_ready()
4.47 μs ± 55.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

70x speedup! 😱

And that’s on a CPU. GPU is even better.
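
You can check which backend the timings actually ran on (jax.devices() lists the available devices; jax.default_backend() names the platform):

[ ]:
print(jax.devices())          # e.g. [CpuDevice(id=0)] on a CPU-only machine
print(jax.default_backend())  # 'cpu', 'gpu', or 'tpu'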

What about the vectorized version?

[21]:
%timeit vec_convolve(x, w)
322 μs ± 8.86 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
[22]:
jit_vec_convolve = jax.jit(vec_convolve)
[23]:
jit_vec_convolve(x, w)  # trigger jit

%timeit jit_vec_convolve(x, w).block_until_ready()
4.32 μs ± 33.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

🥳 We get highly general vectorization basically for free.


Differentiation#

[24]:
jax.grad, jax.jacfwd, jax.jacrev
[24]:
(<function jax._src.api.grad(fun: 'Callable', argnums: 'int | Sequence[int]' = 0, has_aux: 'bool' = False, holomorphic: 'bool' = False, allow_int: 'bool' = False, reduce_axes: 'Sequence[AxisName]' = ()) -> 'Callable'>,
 <function jax._src.api.jacfwd(fun: 'Callable', argnums: 'int | Sequence[int]' = 0, has_aux: 'bool' = False, holomorphic: 'bool' = False) -> 'Callable'>,
 <function jax._src.api.jacrev(fun: 'Callable', argnums: 'int | Sequence[int]' = 0, has_aux: 'bool' = False, holomorphic: 'bool' = False, allow_int: 'bool' = False) -> 'Callable'>)

Jacobians are easy. The output is a matrix of the derivatives of each output with respect to each input.

[25]:
jac_fn = jax.jacobian(convolve)
jac_fn(x, w)
[25]:
Array([[2., 3., 4., 0., 0., 0.],
       [0., 2., 3., 4., 0., 0.],
       [0., 0., 2., 3., 4., 0.],
       [0., 0., 0., 2., 3., 4.]], dtype=float64)

Gradients can be a little more tricky, since jax.grad requires the function to be scalar-valued.

[26]:
try:
    jax.grad(convolve)(x, w)
except Exception as e:
    print(e)
Gradient only defined for scalar-output functions. Output had shape: (4,).
[27]:
# this is one not-great way
func = lambda x, w: jnp.asarray([jax.grad(
    lambda x, w: convolve(x, w)[0]  # scalar output
)(x[i:i+3], w) for i in range(len(x) - 2)
])
func(x, w)
[27]:
Array([[2., 3., 4.],
       [2., 3., 4.],
       [2., 3., 4.],
       [2., 3., 4.]], dtype=float64)

Note that this hack to produce scalar outputs just recovered, row by row, the nonzero band of the Jacobian computed above. grad is really intended for scalar-valued functions, which this example was not.
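
For contrast, jax.grad is a natural fit for a genuinely scalar-valued function, such as the sum_of_squares defined earlier:

[ ]:
grad_sos = jax.grad(sum_of_squares)
grad_sos(jnp.array([1., 2., 3., 4., 5.]))  # d/dx sum(x**2) = 2x -> [2., 4., 6., 8., 10.]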


PyTrees#

Adapted from https://docs.kidger.site/equinox/all-of-equinox/

PyTrees are what JAX calls nested collections. JAX has built-in support for tuples, lists, and dicts, but can also support any custom type, if properly registered.

PyTrees can be built out of pretty much anything: JAX/NumPy arrays, floats, functions, etc.

Many JAX operations will accept either:

  • arbitrary PyTrees;

  • PyTrees with just JAX/NumPy arrays as the leaves;

  • PyTrees without any JAX/NumPy arrays as the leaves.

Functions in jax.numpy just need to be mapped over the tree with jax.tree.map.

[28]:
# Example pytree
pytree = {
    'a': jnp.array([0.0, jnp.pi / 2, jnp.pi]),
    'b': [jnp.array([jnp.pi / 4, jnp.pi / 3]), jnp.array([jnp.pi / 6])],
    'c': (jnp.array([2 * jnp.pi]), jnp.array([3 * jnp.pi]))
}

# Apply jnp.cos to each element in the pytree
cos_pytree = jax.tree.map(jnp.cos, pytree)
cos_pytree
[28]:
{'a': Array([ 1.000000e+00,  6.123234e-17, -1.000000e+00], dtype=float64),
 'b': [Array([0.70710678, 0.5       ], dtype=float64),
  Array([0.8660254], dtype=float64)],
 'c': (Array([1.], dtype=float64), Array([-1.], dtype=float64))}
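
As noted above, custom types can participate too once they are registered. A minimal sketch (the Point class is hypothetical, not part of the original notebook):

[ ]:
from dataclasses import dataclass

@jax.tree_util.register_pytree_node_class
@dataclass
class Point:
    x: jax.Array
    y: jax.Array

    def tree_flatten(self):
        # children (the array leaves) and auxiliary static data
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

p = Point(jnp.array([0.0, jnp.pi]), jnp.array([jnp.pi / 2]))
jax.tree.map(jnp.cos, p)  # Point(x=[1., -1.], y=[~0.])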

JaxTyping + BearType#

This is not strictly JAX, but I think it’s worth showing. jaxtyping offers integrations with run-time type checkers. Normally runtime type checking is slow, but with JIT, it can be fast!

We are going to do this very explicitly, but there are faster tools. See install_import_hook (https://docs.kidger.site/jaxtyping/api/runtime-type-checking/#jaxtyping.install_import_hook) to apply type-checking to an entire module, or the IPython extension to do the same in notebooks.
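
For reference, the module-level hook looks roughly like this (a sketch following the jaxtyping docs linked above; my_package is a hypothetical module name):

[ ]:
from jaxtyping import install_import_hook

# type-check everything defined in `my_package` (hypothetical) with beartype
with install_import_hook("my_package", "beartype.beartype"):
    import my_package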

[29]:
from jaxtyping import jaxtyped  # explicit
from beartype import beartype as typechecker
[30]:
@jax.jit
@jaxtyped(typechecker=typechecker)  # explicit
def checked_convolve(x: Float[Array, "N"], w: Float[Array, "3"]) -> Float[Array, "N-2"]:
  return jnp.array([jnp.dot(x[i-1:i+2], w)
                    for i in range(1, len(x)-1)])
[31]:
try:
    checked_convolve(jnp.array([1, 2, 3, 4, 5]), jnp.array([2., 3, 4]))
except Exception as e:
    print(e)
Type-check error whilst checking the parameters of __main__.checked_convolve.
The problem arose whilst typechecking parameter 'x'.
Actual value: i64[5]
Expected type: <class 'Float[Array, 'N']'>.
----------------------
Called with parameters: {'x': i64[5], 'w': f64[3]}
Parameter annotations: (x: Float[Array, 'N'], w: Float[Array, '3']) -> Any.

[32]:
checked_convolve(jnp.array([1., 2, 3, 4, 5]), jnp.array([2., 3, 4]))
[32]:
Array([20., 29., 38.], dtype=float64)
[34]:
%timeit checked_convolve(x, w).block_until_ready()
4.47 μs ± 26 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Runtime type checking with (almost) no performance penalty? 😱

Warning: this can incur a large penalty at trace/compile time, so check your speeds. Runtime type checking is also easy to turn off.
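
If the overhead matters, one way to disable checking globally is the environment variable that jaxtyping reads at import time (JAXTYPING_DISABLE; an assumption here, so verify the name against your installed version’s docs):

[ ]:
import os
os.environ["JAXTYPING_DISABLE"] = "1"  # must be set before jaxtyping is imported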


Recap#

  1. jax.numpy

  2. vectorization

  3. jit

  4. differentiation

  5. PyTrees

  6. Runtime type-checking