# Unxt

With `quax` we can do units in JAX!

In [1]:
import jax
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Float

Why doesn't Astropy work out of the box?

In [2]:
import astropy.units as u

In [3]:
x = u.Quantity([1, 2, 3], "m")

In [4]:
jnp.sqrt(x)

Array([1.       , 1.4142135, 1.7320508], dtype=float32)

We lost the units!

An Astropy `Quantity` is a subclass of `numpy.ndarray`, so JAX first converts it to a plain NumPy array (dropping the unit) and then to a JAX array.
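The same thing happens with any function that first coerces its input to a plain number: only the magnitude survives. A minimal stdlib sketch of that mechanism, using a hypothetical `ToyQuantity` class (not Astropy's implementation):

```python
import math


class ToyQuantity:
    """Hypothetical unit-carrying number, for illustration only."""

    def __init__(self, value: float, unit: str) -> None:
        self.value = value
        self.unit = unit

    def __float__(self) -> float:
        # Coercion to a plain float can only return the magnitude,
        # so the unit is silently dropped here.
        return float(self.value)


q = ToyQuantity(4.0, "m")
result = math.sqrt(q)  # math.sqrt coerces its argument with float(q)
print(result)  # 2.0: a bare float, no unit attached
```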

Let's try the `quaxed` version.

In [5]:
import quaxed.numpy as qnp

In [6]:
qnp.sqrt(x)

Array([1.       , 1.4142135, 1.7320508], dtype=float32)

Darn. Same problem.

Introducing...

## Quantity

In [7]:
from unxt import Quantity
from unxt import uconvert, ustrip  # alt functional API

In [8]:
x = Quantity([1, 2, 3], 'm')
x

In [9]:
try:
    jnp.sqrt(x)
except TypeError as e:
    print(e)

sqrt requires ndarray or scalar arguments, got <class 'unxt._src.quantity.quantity.Quantity[PhysicalType('length')]'> at position 0.


In [10]:
from quax import quaxify

quaxify(jnp.sqrt)(x)

In [11]:
# Alternatively, use the `quaxed` version.
qnp.sqrt(x)

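Now let's return to the Astropy Quantity. (`x` has since been rebound to a `unxt.Quantity`, so we rebuild the Astropy one.)

In [12]:
qnp.sqrt(u.Quantity([1, 2, 3], "m"))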

It worked?!? What happened?

Well, `unxt` hooks into `quax` and registers that it knows how to handle Astropy Quantities.
That's some nice interoperability.
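`quax` handles this with multiple dispatch (it builds on `plum`): a library can register behaviour for a type it does not own, without editing that type or the dispatching function. As a rough single-dispatch analogy only (stdlib `functools.singledispatch`, not `quax`'s actual machinery, with a hypothetical `ForeignQuantity` type):

```python
import math
from functools import singledispatch


class ForeignQuantity:
    """Hypothetical quantity type defined by some other package."""

    def __init__(self, value: float, unit: str) -> None:
        self.value = value
        self.unit = unit


@singledispatch
def unit_sqrt(x):
    # Default behaviour: treat the input as a plain number.
    return math.sqrt(x)


# Registered from "outside": neither the default implementation of
# unit_sqrt nor ForeignQuantity had to change to support this pairing.
@unit_sqrt.register
def _(x: ForeignQuantity):
    return ForeignQuantity(math.sqrt(x.value), f"({x.unit})**0.5")


print(unit_sqrt(9.0))  # 3.0
r = unit_sqrt(ForeignQuantity(9.0, "m"))
print(r.value, r.unit)  # 3.0 (m)**0.5
```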

### Astropy-like API

In [13]:
x.to(u.km)

In [14]:
x.to_value(u.km)

Array([0.001, 0.002, 0.003], dtype=float32, weak_type=True)

In [15]:
x.decompose([u.imperial.furlong, u.Gyr])

### Parametric

`unxt.Quantity` can be parametrized by its dimensions!
The dimension can be inferred at run time from the unit, or set explicitly and used for verification.

In [16]:
Quantity["frequency"](1, 1/u.s)  # works

In [17]:
try:
    Quantity["frequency"](1, "kpc")
except Exception as e:
    print(e)

Physical type mismatch.
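Conceptually, the check compares dimensions: each unit reduces to exponents of base dimensions, and a physical type is just such a set of exponents. A toy sketch with a hand-written, hypothetical unit table (not `unxt`'s implementation, which relies on Astropy's physical types):

```python
# Hypothetical unit table: base-dimension exponents for a few units.
UNIT_DIMENSIONS = {
    "m": {"length": 1},
    "kpc": {"length": 1},
    "1/s": {"time": -1},
    "km/s": {"length": 1, "time": -1},
}

# Physical types expressed as the same exponent dictionaries.
PHYSICAL_TYPES = {
    "length": {"length": 1},
    "frequency": {"time": -1},
    "speed": {"length": 1, "time": -1},
}


def check_physical_type(unit: str, physical_type: str) -> None:
    """Raise ValueError if `unit` does not have the dimensions of `physical_type`."""
    if UNIT_DIMENSIONS[unit] != PHYSICAL_TYPES[physical_type]:
        raise ValueError("Physical type mismatch.")


check_physical_type("1/s", "frequency")  # passes silently
try:
    check_physical_type("kpc", "frequency")
except ValueError as e:
    print(e)  # Physical type mismatch.
```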


In [18]:
@beartype
def func(x: Quantity["length"], v: Quantity["speed"]) -> Quantity["time"]:
    return x / v


In [19]:
func(Quantity(1, "kpc"), Quantity(100, "km/s"))

In [20]:
try:
    func(Quantity(1, "kpc"), Quantity(100, "1/s"))
except Exception as e:
    print(e)

Function __main__.func() parameter v="Quantity(Array(100, dtype=int32, weak_type=True), unit='1 / s')" violates type hint <class 'unxt._src.quantity.quantity.Quantity[PhysicalType({'speed', 'velocity'})]'>, as <class "unxt._src.quantity.quantity.Quantity[PhysicalType('frequency')]"> "Quantity(Array(100, dtype=int32, weak_type=True), unit='1 / s')" not instance of <class "unxt._src.quantity.quantity.Quantity[PhysicalType({'speed', 'velocity'})]">.


### Constructor

`unxt` offers a flexible way to construct a `Quantity`: the generalized constructor class-method `Quantity.from_`, which dispatches on the types of its arguments.

In [21]:
# normal construction
Quantity.from_([1, 2, 3], 'm')

In [22]:
# From a Quantity
Quantity.from_(x)

In [23]:
# Also changing the unit
Quantity.from_(x, "km")

In [24]:
# Inspect the registered dispatch methods (uses the private `_f` attribute)
Quantity.from_._f.methods

In [25]:
# Using a dictionary
Quantity.from_({"value": x, "unit": "km"})
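Under the hood this is dispatch on argument types. A toy sketch of the same pattern with stdlib `functools.singledispatch`, covering the three cases above (the `ToyQuantity` and `from_` names and the two-entry unit table are hypothetical; this is not `unxt`'s code):

```python
from dataclasses import dataclass
from functools import singledispatch

# Hypothetical scale factors to metres for the units used below.
TO_METRES = {"m": 1.0, "km": 1000.0}


@dataclass
class ToyQuantity:
    value: list
    unit: str


@singledispatch
def from_(value, unit=None):
    # Fallback for argument types with no registered constructor.
    raise TypeError(f"cannot build a ToyQuantity from {type(value).__name__}")


@from_.register(list)
def _(value, unit):
    # Array-like values plus a unit.
    return ToyQuantity([float(v) for v in value], unit)


@from_.register(ToyQuantity)
def _(value, unit=None):
    # An existing quantity, optionally converted to a new unit.
    if unit is None:
        return value
    scale = TO_METRES[value.unit] / TO_METRES[unit]
    return ToyQuantity([v * scale for v in value.value], unit)


@from_.register(dict)
def _(value, unit=None):
    # A mapping with "value" and "unit" keys: re-dispatch on its contents.
    return from_(value["value"], value["unit"])


x = from_([1, 2, 3], "m")
print(from_(x, "km"))  # same lengths, expressed in km
print(from_({"value": x, "unit": "km"}).unit)  # km
```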

In [26]:
Quantity.from_?

Signature: Quantity.from_(*args: Any, **kwargs: Any) -> 'AbstractQuantity'
Docstring:
--------------------------------------------------------------------------------

from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: Union[ArrayLike, list[Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, bool, int, float, complex]], tuple[Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, bool, int, float, complex], ...]], unit: Any, /, *, dtype: Any = None) -> unxt._src.quantity.base.AbstractQuantity

Construct a `unxt.Quantity` from an ar

## Non-Parametric

Parametric types can be slightly slower (the overhead is marginal).
If that overhead becomes a limiting factor in a hot path,
`unxt` offers a non-parametric quantity, `BareQuantity`.

In [27]:
from unxt.quantity import BareQuantity

In [28]:
x = BareQuantity(100, "km")
x

In [29]:
qnp.cbrt(x)

## Distance / Parallax

This is similar to Astropy's `astropy.coordinates.Distance`.

In [30]:
from coordinax import Distance

In [31]:
d = Distance([1, 2, 3], 'm')
d

We can use it like any other quantity:

In [32]:
qnp.sqrt(d)

Like Astropy's `Distance`, it has properties for converting between distance measures:

In [33]:
d.distance

In [34]:
d.parallax

In [35]:
d.distance_modulus

There are convenient constructors, for example from a parallax angle:

In [36]:
Distance.from_(Quantity([1.], "rad"))

See also the `Parallax` class:

In [37]:
from coordinax.angle import Parallax

In [38]:
p = Parallax([1., 2, 3], 'mas')
p

In [39]:
p.distance

In [40]:
p.parallax

In [41]:
p.distance_modulus
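These conversions follow the standard relations d [pc] = 1 / p [arcsec] and mu = 5 log10(d / 10 pc). A stdlib sketch of the same arithmetic, with hypothetical helper names (not the `coordinax` API); that `coordinax` uses exactly this small-angle form is an assumption here:

```python
import math


def parallax_mas_to_distance_pc(parallax_mas: float) -> float:
    """Distance in parsec from a parallax in milliarcseconds: d = 1 / p[arcsec]."""
    return 1000.0 / parallax_mas


def distance_modulus(distance_pc: float) -> float:
    """Distance modulus in magnitudes: mu = 5 log10(d / 10 pc)."""
    return 5.0 * math.log10(distance_pc / 10.0)


def distance_from_modulus(mu: float) -> float:
    """Inverse of distance_modulus: d = 10 ** (mu / 5 + 1) pc."""
    return 10.0 ** (mu / 5.0 + 1.0)


d = parallax_mas_to_distance_pc(1.0)  # 1 mas -> 1000 pc
mu = distance_modulus(d)  # 5 * log10(100) = 10 mag
print(d, mu, distance_from_modulus(mu))
```

The inverse function is one possible starting point for the exercise below.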

<div class="alert alert-block alert-success">
<b>Exercise:</b> Build a DistanceModulus class
</div>

Hint: check out `coordinax.distance.DistanceModulus`.

## Unit Systems

If you've used `gala.units`, then you're already familiar with Astropy-compatible unit systems!

In [42]:
from unxt.unitsystems import unitsystem, galactic, solarsystem

In [43]:
unitsystem?

Signature:       unitsystem(usys: unxt._src.unitsystems.base.AbstractUnitSystem, /) -> unxt._src.unitsystems.base.AbstractUnitSystem
Call signature:  unitsystem(*args, **kw_args)
Type:            Function
String form:     <multiple-dispatch function unitsystem (with 9 registered and 0 pending method(s))>
File:            ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/unxt/_src/unitsystems/core.py
Docstring:
Convert a UnitSystem to a UnitSystem.

Examples
--------
>>> from unxt.unitsystems import unitsystem
>>> usys = unitsystem("kpc", "Myr", "Msun", "radian")
>>> usys
unitsystem(kpc, Myr, solMass, rad)

>>> unitsystem(usys) is usys
True

-----------------------------------------------------------------------------------------------------------------------------------------------------------

unitsystem(seq: collections.abc.Sequence[typ

In [44]:
usys = unitsystem("kpc", "Gyr", "solMass", "deg")
usys

unitsystem(kpc, Gyr, solMass, deg)

There are also predefined unit systems:

In [45]:
galactic

unitsystem(kpc, Myr, solMass, rad)

In [46]:
solarsystem

unitsystem(AU, yr, solMass, rad)

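Unit systems are useful when working with quantities, for example to decompose a quantity into the system's base units or to look up the system's unit for a given physical type.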

In [47]:
print(x)
x.decompose(usys)

BareQuantity(weak_i32[], unit='km')


In [48]:
x = 1e4 * Quantity([1, 2, 3], 'lyr')
x.uconvert(usys["length"])

In [49]:
Quantity([1, 2, 3], 'lyr / mas').decompose(usys)
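Decomposing into a unit system amounts to expressing each base dimension in the system's unit and multiplying the scale factors together. A hand-rolled stdlib sketch for the `galactic`-style base units kpc and Myr; the constants are typed in by hand here (IAU astronomical unit and parsec, Julian year), whereas `unxt` does this bookkeeping for you:

```python
import math

# Hand-coded SI scale factors (assumed values; unxt handles these internally).
AU_M = 149_597_870_700.0  # astronomical unit in metres
KPC_M = 1e3 * (648_000.0 / math.pi) * AU_M  # kiloparsec in metres
MYR_S = 1e6 * 365.25 * 86_400.0  # Julian megayear in seconds
LYR_M = 299_792_458.0 * 365.25 * 86_400.0  # light-year in metres


def length_lyr_to_kpc(value_lyr: float) -> float:
    """Express a length in light-years in the system's length unit (kpc)."""
    return value_lyr * LYR_M / KPC_M


def speed_km_s_to_kpc_myr(value_km_s: float) -> float:
    """Express a speed in km/s in the system's base units (kpc / Myr)."""
    return value_km_s * 1e3 * MYR_S / KPC_M


print(length_lyr_to_kpc(1e4))  # about 3.066 kpc
print(speed_km_s_to_kpc_myr(1.0))  # about 1.0227e-3 kpc/Myr
```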

## Astropy Compatibility

`unxt.Quantity` can be made from an Astropy Quantity:

In [50]:
x = Quantity.from_(u.Quantity(1, "kpc"))
x

And it can be converted back to an Astropy Quantity with `plum.convert`:

In [51]:
from plum import convert

convert(x, u.Quantity)

<Quantity 1. kpc>

## Differentiation

`unxt` fully integrates with JAX's automatic differentiation.

In [52]:
def polynomial(
    x: Float[Quantity["area"], "N"],
    a0: Float[Quantity["volume"], ""],
    a1: Float[Quantity["length"], ""],
    a2: Float[Quantity["1/m"], ""],
) -> Float[Quantity["volume"], "N"]:
    return a0 + a1 * x ** 1 + a2 * x ** 2

polynomial

<function __main__.polynomial(x: jaxtyping.Float[Quantity[PhysicalType('area')], 'N'], a0: jaxtyping.Float[Quantity[PhysicalType('volume')], ''], a1: jaxtyping.Float[Quantity[PhysicalType('length')], ''], a2: jaxtyping.Float[Quantity[PhysicalType('wavenumber')], '']) -> jaxtyping.Float[Quantity[PhysicalType('volume')], 'N']>

In [53]:
x = Quantity(1.0, "km2")
a0, a1, a2 = Quantity(1, "km3"), Quantity(2, "km"), Quantity(3, "1/km")

In [54]:
polynomial(x, a0, a1, a2)

In [55]:
grad_fn = quaxify(jax.grad(polynomial))
grad_fn(x, a0, a1, a2)

In [56]:
# Alternatively:

from quaxed import grad as qgrad

qgrad(polynomial)(x, a0, a1, a2)
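As a sanity check on both value and units: by hand, dp/dx = a1 + 2 a2 x, which for the inputs above is 2 km + 2 × (3 km⁻¹) × (1 km²) = 8 km, a length (volume per area). That is the value and dimension one would expect the gradient to carry. A unit-free stdlib sketch that checks the magnitude with a central finite difference (the helper names are hypothetical):

```python
# Magnitudes of the inputs above, in km-based units:
# x in km^2, a0 in km^3, a1 in km, a2 in 1/km.
A0, A1, A2 = 1.0, 2.0, 3.0


def polynomial_km(x_km2: float) -> float:
    """Magnitude (in km^3) of a0 + a1 * x + a2 * x**2."""
    return A0 + A1 * x_km2 + A2 * x_km2**2


def central_difference(f, x: float, h: float = 1e-4) -> float:
    """Central finite-difference estimate of df/dx."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


x_km2 = 1.0
numeric = central_difference(polynomial_km, x_km2)  # units: km^3 / km^2 = km
analytic = A1 + 2.0 * A2 * x_km2  # 2 + 6 = 8 (km)
print(numeric, analytic)
```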