Workshop on Quax 🦆

Author: Nathaniel Starkman (MIT, starkman@mit.edu)

This example is adapted from https://docs.kidger.site/quax/

In this example, we’ll see how to create a custom array-ish Quax type.

We’re going to implement a “dimensional” type, which annotates each array with a dimension like “length” or “time”. It will keep track of the dimensions as they propagate through the computation, and disallow things like adding a length-array to a time-array. (Which isn’t a thing you can do in physics!)

[1]:
from dataclasses import replace
from typing import Union
from functools import partial

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import ArrayLike, Shaped, Array

import quax
from quax import quaxify

As a first step for this example (unrelated to Quax), let’s define a toy dimension system. (This simple system only has dimensions such as “length”; it has no notion of units.)

[2]:
class Dimension:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name

    def __lt__(self, other):
        return False


mass = Dimension("mass")
length = Dimension("length")
time = Dimension("time")

Now let’s define our custom Quax type. It’ll wrap together an array and a set of dimensions.

[3]:
class Dimensional(quax.ArrayValue):
    array: ArrayLike
    dimensions: dict[Dimension, int] = eqx.field(static=True, converter=lambda x: {x: 1} if isinstance(x, Dimension) else x)

    def aval(self):
        return jax.core.ShapedArray(jnp.shape(self.array), jnp.result_type(self.array))

    def materialise(self):
        raise ValueError("Refusing to materialize Dimensional array.")

Example usage for this is Dimensional(array, length) to indicate that the array has dimensions of length, or Dimensional(array, {length: 1, time: -1}) to indicate that the array has dimensions of length per time (a velocity).

Now let’s define a few rules for how Dimensional arrays interact with each other.

[4]:
@quax.register(jax.lax.add_p)
def _(x: Dimensional, y: Dimensional):
    if x.dimensions != y.dimensions:
        raise ValueError(f"Cannot add two arrays with dimensions {x.dimensions} and {y.dimensions}.")
    return Dimensional(x.array + y.array, x.dimensions)


@quax.register(jax.lax.mul_p)
def _(x: Dimensional, y: Dimensional):
    dimensions = x.dimensions.copy()
    for k, v in y.dimensions.items():
        if k in dimensions:
            dimensions[k] += v
        else:
            dimensions[k] = v
    return Dimensional(x.array * y.array, dimensions)


@quax.register(jax.lax.mul_p)
def _(x: ArrayLike, y: Dimensional):
    return Dimensional(x * y.array, y.dimensions)


@quax.register(jax.lax.mul_p)
def _(x: Dimensional, y: ArrayLike):
    return Dimensional(x.array * y, x.dimensions)


@quax.register(jax.lax.integer_pow_p)
def _(x: Dimensional, *, y: int):
    dimensions = {k: v * y for k, v in x.dimensions.items()}
    # Raise the wrapped array to the power as well, not only its dimensions.
    return Dimensional(x.array**y, dimensions)

And now let’s go ahead and use these in practice!

As our example, we’ll consider computing the energy of a ball moving in Earth’s gravity.

[5]:
def kinetic_energy(mass, velocity):
    """Kinetic energy of a ball with `mass` moving with `velocity`."""
    return 0.5 * mass * velocity**2


def gravitational_potential_energy(mass, height, g):
    """Gravitational potential energy of a ball with `mass` at a distance `height`
    above the Earth's surface.
    """
    return g * mass * height


def compute_energy(mass, velocity, height, g):
    return kinetic_energy(mass, velocity) + gravitational_potential_energy(
        mass, height, g
    )
[6]:
m = Dimensional(jnp.array(3.0), mass)
v = Dimensional(jnp.array(2.2), {length: 1, time: -1})
h = Dimensional(jnp.array(1.0), length)
g = Dimensional(jnp.array(9.81), {length: 1, time: -2})

E = quaxify(compute_energy)(m, v, h, g)

print(f"The amount of energy is {E.array.item():.2f} with dimensions {E.dimensions}.")
The amount of energy is 36.69 with dimensions {mass: 1, length: 2, time: -2}.

Wonderful! That went perfectly.
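The dimension bookkeeping behind this result can be checked without JAX. The sketch below is a plain-Python illustration, not part of Quax: it uses strings in place of Dimension objects, and the helper names mul_dims and pow_dims are hypothetical. It mirrors the mul_p and integer_pow_p rules above and confirms that both energy terms carry the same dimensions, which is what the add_p rule requires.

```python
def mul_dims(a, b):
    """Dimensions of a product: exponents of matching dimensions add."""
    out = dict(a)
    for name, power in b.items():
        out[name] = out.get(name, 0) + power
    return out


def pow_dims(a, n):
    """Dimensions of an integer power: every exponent is scaled by n."""
    return {name: power * n for name, power in a.items()}


# Plain strings stand in for the Dimension objects (an assumption of this sketch).
mass_d = {"mass": 1}
velocity_d = {"length": 1, "time": -1}
height_d = {"length": 1}
g_d = {"length": 1, "time": -2}

# 0.5 * m * v**2 (the dimensionless factor 0.5 does not change the dimensions)
kinetic = mul_dims(mass_d, pow_dims(velocity_d, 2))
# g * m * h
potential = mul_dims(mul_dims(g_d, mass_d), height_d)

print(kinetic)               # {'mass': 1, 'length': 2, 'time': -2}
print(kinetic == potential)  # True: the two terms may be added
```

Dictionary equality ignores insertion order, so the comparison succeeds even though the two terms accumulate their dimensions in a different order.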

The key take-aways from this example are:

  • The basic usage of defining a custom type with its aval and materialise

  • How to define a rule that binds your custom type against itself, e.g.

        @quax.register(jax.lax.mul_p)
        def _(x: Dimensional, y: Dimensional): ...

  • How to define a rule that binds your custom type against a normal JAX arraylike type, e.g.

        @quax.register(jax.lax.mul_p)
        def _(x: ArrayLike, y: Dimensional): ...

    (An ArrayLike is all the things JAX is normally willing to have interact with arrays: bool/int/float/complex/NumPy scalars/NumPy arrays/JAX arrays. You can think of the purpose of Quax as being a way to extend what it means for an object to be arraylike.)
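To make the idea of registering one rule per combination of argument types concrete, here is a toy registry in plain Python. It is only a conceptual sketch and not how Quax is implemented: the names register, dispatch, and Tagged are hypothetical, and the lookup matches exact types only.

```python
# Toy registry keyed by (operation name, argument types). This illustrates
# type-based dispatch only; it is not Quax's implementation.
_rules = {}


def register(op, *types):
    """Decorator that stores a rule for `op` applied to these argument types."""
    def decorator(fn):
        _rules[(op, types)] = fn
        return fn
    return decorator


class Tagged:
    """Hypothetical stand-in for a custom array-ish type."""

    def __init__(self, value, tag):
        self.value = value
        self.tag = tag


@register("mul", Tagged, Tagged)
def _(x, y):
    # Custom type against itself.
    return Tagged(x.value * y.value, x.tag + "*" + y.tag)


@register("mul", float, Tagged)
def _(x, y):
    # Ordinary number against the custom type.
    return Tagged(x * y.value, y.tag)


def dispatch(op, *args):
    """Look up the rule matching the exact argument types and call it."""
    key = (op, tuple(type(a) for a in args))
    if key not in _rules:
        raise TypeError(f"No rule registered for {key}")
    return _rules[key](*args)


a = Tagged(2.0, "length")
b = Tagged(3.0, "time")
both = dispatch("mul", a, b)
scaled = dispatch("mul", 0.5, a)
print(both.value, both.tag)      # 6.0 length*time
print(scaled.value, scaled.tag)  # 1.0 length
```

Calling dispatch("mul", a, 0.5) raises a TypeError here because no (Tagged, float) rule was registered, which is why the example above registers both argument orders for mul_p.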

Differentiation

In true JAX fashion, we can take derivatives:

[7]:
quaxify(kinetic_energy)(m, v)
[7]:
Dimensional(array=f32[], dimensions={mass: 1, length: 2, time: -2})
[8]:
quaxify(jax.grad(kinetic_energy, argnums=0))(m, v)
[8]:
Dimensional(array=f32[], dimensions={length: 2, time: -2})
[9]:
quaxify(jax.grad(kinetic_energy, argnums=1))(m, v)
[9]:
Dimensional(array=f32[], dimensions={mass: 1, length: 1, time: -1})
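These dimensions are what dimensional analysis predicts: a derivative has the dimensions of the output divided by the dimensions of the variable being differentiated. The plain-Python sketch below checks this for the two gradients above. It is an illustration only; strings stand in for Dimension objects, and derivative_dims is a hypothetical helper.

```python
def derivative_dims(output, wrt):
    """Dimensions of d(output)/d(wrt): subtract exponents, drop zeros."""
    out = dict(output)
    for name, power in wrt.items():
        out[name] = out.get(name, 0) - power
    return {name: power for name, power in out.items() if power != 0}


energy = {"mass": 1, "length": 2, "time": -2}
mass_d = {"mass": 1}
velocity_d = {"length": 1, "time": -1}

d_energy_d_mass = derivative_dims(energy, mass_d)
d_energy_d_velocity = derivative_dims(energy, velocity_d)

print(d_energy_d_mass)      # {'length': 2, 'time': -2}
print(d_energy_d_velocity)  # {'mass': 1, 'length': 1, 'time': -1}
```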

Exercise: Use Gradient Descent to Find the Minimum of a Function

[10]:
# A more fully-featured version of the classes

from dimensional import Dimensional, Dimension
mass = Dimension("mass")
length = Dimension("length")
time = Dimension("time")
[11]:
def the_function(
    x: Shaped[Dimensional, "*shape"],
    *,
    scale: Shaped[Dimensional, ""]
) -> Shaped[Array, "*shape"]:
    y = x / scale
    return scale * (y**4 - 3*y**3 + 2)

Solution. No peeking!

# Compute the derivative of the function with respect to x
the_function_grad = jax.jit(quax.quaxify(jax.grad(the_function)))

# trigger JIT compile
the_function_grad(Dimensional(0.0, length), scale=Dimensional(1.0, length))

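# Initial guess for the position. Start away from x = 0: the gradient
# 4*y**3 - 9*y**2 vanishes there, so gradient descent would never move.
x_init = Dimensional(1.0, length)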

# Scale for the position
scale = Dimensional(1.0, length)

# Optimization step: use gradient descent to find the minimum
learning_rate = Dimensional(0.01, length)
num_steps = 1000

x = x_init
for _ in range(num_steps):
    x = x - learning_rate * the_function_grad(x, scale=scale)

print(f"Position of minimum: {x}")
print(f"Minimum: {the_function(x, scale=scale)}")
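As a cross-check that does not depend on the dimensional package, the same descent can be run on the dimensionless variable y = x / scale with a hand-written gradient. This is a minimal plain-Python sketch with scale fixed to 1; the starting point y = 1.0 is an assumption of the sketch, chosen because y = 0 is a stationary point of the function where the gradient is exactly zero.

```python
def f(y):
    """The dimensionless part of the function: y**4 - 3*y**3 + 2."""
    return y**4 - 3 * y**3 + 2


def df(y):
    """Analytical derivative of f."""
    return 4 * y**3 - 9 * y**2


y = 1.0  # start where the gradient is nonzero
learning_rate = 0.01
for _ in range(1000):
    y = y - learning_rate * df(y)

print(f"Position of minimum: {y:.4f}")  # 2.2500
print(f"Minimum: {f(y):.4f}")           # -6.5430
```

Setting df(y) = 0 gives y = 0 or y = 9/4, and f(9/4) = -6.54296875, which agrees with the values the loop converges to.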

Quaxed

quax is impressive!

However, one annoyance is having to wrap many operations with quaxify. Thankfully, in a complex calculation only the outermost function needs to be wrapped. But for scripting, or when wrapping only the outermost function is impractical, there’s a convenience library…

quaxed offers pre-quaxified jax.
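Conceptually, “pre-quaxified” means that every function in a namespace has already been passed through a wrapper, so callers do not have to apply it themselves. The sketch below shows that pattern in plain Python. It makes no claim about how quaxed is implemented: prewrap and counted are hypothetical names, the standard-library math module stands in for jax.numpy, and a call counter stands in for quaxify.

```python
import math
from types import SimpleNamespace


def counted(fn):
    """Hypothetical stand-in for quaxify: wraps fn and counts its calls."""
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return fn(*args, **kwargs)

    wrapper.calls = 0
    return wrapper


def prewrap(module, decorator):
    """Return a namespace whose public callables are already decorated."""
    wrapped = {
        name: decorator(obj)
        for name, obj in vars(module).items()
        if callable(obj) and not name.startswith("_")
    }
    return SimpleNamespace(**wrapped)


wrapped_math = prewrap(math, counted)
result = wrapped_math.sqrt(16.0)

print(result)                   # 4.0
print(wrapped_math.sqrt.calls)  # 1
```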

NumPy

[12]:
import quaxed.numpy as jnp

quaxed.numpy acts the same as jax.numpy.

[13]:
x = jnp.array([1, 2, 3])
isinstance(x, jax.Array)  # True
[13]:
True
[14]:
jnp.square(x)
[14]:
Array([1, 4, 9], dtype=int32)

But it also supports all the power of quax objects.

[15]:
from dimensional import Dimensional, Dimension
[16]:
length = Dimension('length')
y = Dimensional([1.0, 2, 3], length)
[17]:
jnp.square(y)
[17]:
Dimensional(array=[1. 4. 9.], dimensions={length: 2})

Operator

quaxed also supports a functional approach to operators (see https://docs.python.org/3/library/operator.html).

[18]:
import quaxed.operator as op
[19]:
op.mul(x, y)
[19]:
Dimensional(array=[1. 4. 9.], dimensions={length: 1})
[20]:
op.pow(y, 2)
[20]:
Dimensional(array=[1. 2. 3.], dimensions={length: 2})
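For comparison, the standard-library operator module that quaxed.operator is modelled on simply calls the corresponding Python operators, so any class that defines the right special methods works with it. The class Length below is a hypothetical toy that tracks a single exponent; it is not the Dimensional type used above.

```python
import operator


class Length:
    """Toy value that tracks the power of a single 'length' dimension."""

    def __init__(self, value, power=1):
        self.value = value
        self.power = power

    def __mul__(self, other):
        if isinstance(other, Length):
            return Length(self.value * other.value, self.power + other.power)
        return Length(self.value * other, self.power)

    # 2 * Length(...) falls back to this when int.__mul__ gives up.
    __rmul__ = __mul__

    def __pow__(self, n):
        # Both the value and the exponent change under a power.
        return Length(self.value**n, self.power * n)


y = Length(3.0)
product = operator.mul(2, y)  # equivalent to 2 * y
squared = operator.pow(y, 2)  # equivalent to y ** 2

print(product.value, product.power)  # 6.0 1
print(squared.value, squared.power)  # 9.0 2
```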

Recap

  1. quax

  2. quaxed
