Unxt

With quax, we can do units in JAX!

[1]:
import jax
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Float

Why doesn’t Astropy work with JAX out of the box?

[2]:
import astropy.units as u
[3]:
x = u.Quantity([1, 2, 3], "m")
[4]:
jnp.sqrt(x)
[4]:
Array([1.       , 1.4142135, 1.7320508], dtype=float32)

We lost the units!

An Astropy Quantity is a subclass of numpy.ndarray, so JAX converts it like any other NumPy array: the values become a JAX array and the unit is dropped along the way.

Let’s try the quaxed version, whose functions are jax.numpy functions wrapped with quax’s quaxify.

[5]:
import quaxed.numpy as qnp
[6]:
qnp.sqrt(x)
[6]:
Array([1.       , 1.4142135, 1.7320508], dtype=float32)

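Darn, same problem. quaxify only changes how quax array types are handled, and an Astropy Quantity is not one of them, so it is still converted to a plain array.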

Introducing…

Quantity

[7]:
from unxt import Quantity
from unxt import uconvert, ustrip  # alt functional API
[8]:
x = Quantity([1, 2, 3], 'm')
x
[8]:
[1, 2, 3] * Unit("m")
[9]:
try:
    jnp.sqrt(x)
except TypeError as e:
    print(e)
sqrt requires ndarray or scalar arguments, got <class 'unxt._src.quantity.quantity.Quantity[PhysicalType('length')]'> at position 0.
[10]:
from quax import quaxify

quaxify(jnp.sqrt)(x)
[10]:
[1. , 1.41421354, 1.73205078] * Unit("m(1/2)")
[11]:
# Alternatively, use the `quaxed` version.
qnp.sqrt(x)
[11]:
[1. , 1.41421354, 1.73205078] * Unit("m(1/2)")

Now let’s return to the Astropy Quantity

[12]:
qnp.sqrt(x)
[12]:
[1. , 1.41421354, 1.73205078] * Unit("m(1/2)")

It worked?!? What happened?

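Look closely: x was reassigned to a unxt Quantity in cell [8], so this cell applies qnp.sqrt to the unxt Quantity again, and that is why the units survive. To use the Astropy Quantity from cell [3], convert it explicitly, for example with Quantity.from_ (see Astropy Compatibility below).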

Astropy-like API

[13]:
x.to(u.km)
[13]:
[0.001, 0.002, 0.003] * Unit("km")
[14]:
x.to_value(u.km)
[14]:
Array([0.001, 0.002, 0.003], dtype=float32, weak_type=True)
[15]:
x.decompose([u.imperial.furlong, u.Gyr])
[15]:
[0.00497097, 0.00994194, 0.01491291] * Unit("fur")

Parametric

unxt.Quantity can be parametrized by its physical type (its dimensions)! The parameter is either inferred at run time from the unit, or set explicitly and then used for verification.

[16]:
Quantity["frequency"](1, 1/u.s)  # works
[16]:
1 * Unit("1 / s")
[17]:
try:
    Quantity["frequency"](1, "kpc")
except Exception as e:
    print(e)
Physical type mismatch.
[18]:
@beartype
def func(x: Quantity["length"], v: Quantity["speed"]) -> Quantity["time"]:
    return x / v

[19]:
func(Quantity(1, "kpc"), Quantity(100, "km/s"))
[19]:
0.01 * Unit("kpc s / km")
[20]:
try:
    func(Quantity(1, "kpc"), Quantity(100, "1/s"))
except Exception as e:
    print(e)
Function __main__.func() parameter v="Quantity(Array(100, dtype=int32, weak_type=True), unit='1 / s')" violates type hint <class 'unxt._src.quantity.quantity.Quantity[PhysicalType({'speed', 'velocity'})]'>, as <class "unxt._src.quantity.quantity.Quantity[PhysicalType('frequency')]"> "Quantity(Array(100, dtype=int32, weak_type=True), unit='1 / s')" not instance of <class "unxt._src.quantity.quantity.Quantity[PhysicalType({'speed', 'velocity'})]">.

Constructor

unxt offers a flexible way to construct a Quantity through a generalized, multiple-dispatch constructor class method, Quantity.from_.

[21]:
# normal construction
Quantity.from_([1, 2, 3], 'm')
[21]:
[1, 2, 3] * Unit("m")
[22]:
# From a Quantity
Quantity.from_(x)
[22]:
[1, 2, 3] * Unit("m")
[23]:
# Also changing the unit
Quantity.from_(x, "km")
[23]:
[0.001, 0.002, 0.003] * Unit("km")
[24]:
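# List the methods registered on this multiple-dispatch constructor (via plum's private _f attribute)
Quantity.from_._f.methods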
[24]:
List of 9 method(s):
    [0] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: typing.Union[ArrayLike, 
    list[typing.Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, 
    bool, int, float, complex]], tuple[typing.Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''],
    numpy.bool, numpy.number, bool, int, float, complex], ...]], unit: Any, *, dtype) ->
    unxt._src.quantity.base.AbstractQuantity
        <function from_ at 0x116d8f6a0> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and 
    Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py:799
    [1] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: typing.Union[ArrayLike, 
    list[typing.Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, 
    bool, int, float, complex]], tuple[typing.Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''],
    numpy.bool, numpy.number, bool, int, float, complex], ...]], *, unit, dtype) ->
    unxt._src.quantity.base.AbstractQuantity
        <function from_ at 0x116d8f9c0> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and 
    Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py:838
    [2] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], *, value, unit, dtype) ->
    unxt._src.quantity.base.AbstractQuantity
        <function from_ at 0x116d8fce0> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and 
    Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py:863
    [3] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], mapping: collections.abc.Mapping[str, 
    typing.Any]) -> unxt._src.quantity.base.AbstractQuantity
        <function from_ at 0x116dac220> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and 
    Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py:883
    [4] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value:
    unxt._src.quantity.base.AbstractQuantity, unit: Any, *, dtype) -> unxt._src.quantity.base.AbstractQuantity
        <function from_ at 0x116dac360> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and 
    Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py:909
    [5] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value:
    unxt._src.quantity.base.AbstractQuantity, unit: NoneType, *, dtype) ->
    unxt._src.quantity.base.AbstractQuantity
        <function from_ at 0x116dac680> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and 
    Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py:935
    [6] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value:
    unxt._src.quantity.base.AbstractQuantity, *, unit, dtype) -> unxt._src.quantity.base.AbstractQuantity
        <function from_ at 0x116dac9a0> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and 
    Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py:962
    [7] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: astropy.units.quantity.Quantity, *,
    **kwargs) -> unxt._src.quantity.base.AbstractQuantity
        <function from_ at 0x116ff7a60> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and 
    Flatiron/Flatiron 
    Presentation/.venv/lib/python3.12/site-packages/unxt/_interop/unxt_interop_astropy/quantity.py:53
    [8] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: astropy.units.quantity.Quantity, u:
    Any, *, **kwargs) -> unxt._src.quantity.base.AbstractQuantity
        <function from_ at 0x116fc2980> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and 
    Flatiron/Flatiron 
    Presentation/.venv/lib/python3.12/site-packages/unxt/_interop/unxt_interop_astropy/quantity.py:75
[25]:
# Using a dictionary
Quantity.from_({"value": x, "unit": "km"})
[25]:
[0.001, 0.002, 0.003] * Unit("km")
[26]:
Quantity.from_?
Signature: Quantity.from_(*args: Any, **kwargs: Any) -> 'AbstractQuantity'
Docstring:
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: Union[ArrayLike, list[Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, bool, int, float, complex]], tuple[Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, bool, int, float, complex], ...]], unit: Any, /, *, dtype: Any = None) -> unxt._src.quantity.base.AbstractQuantity

Construct a `unxt.Quantity` from an array-like value and a unit.

:param value: The array-like value.
:param unit: The unit of the value.
:param dtype: The data type of the array (keyword-only).

Examples
--------
For this example we'll use the `Quantity` class. The same applies to
any subclass of `AbstractQuantity`.

>>> import jax.numpy as jnp
>>> import unxt as u

>>> x = jnp.array([1.0, 2, 3])
>>> u.Quantity.from_(x, "m")
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')

>>> u.Quantity.from_([1.0, 2, 3], "m")
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')

>>> u.Quantity.from_((1.0, 2, 3), "m")
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: Union[ArrayLike, list[Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, bool, int, float, complex]], tuple[Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, bool, int, float, complex], ...]], /, *, unit: Any, dtype: Any = None) -> unxt._src.quantity.base.AbstractQuantity

Make a `unxt.AbstractQuantity` from an array-like value and a unit kwarg.

Examples
--------
For this example we'll use the `unxt.Quantity` class. The same applies
to any subclass of `unxt.AbstractQuantity`.

>>> import unxt as u
>>> u.Quantity.from_([1.0, 2, 3], unit="m")
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

from_(cls: type[unxt._src.quantity.base.AbstractQuantity], *, value: Any, unit: Any, dtype: Any = None) -> unxt._src.quantity.base.AbstractQuantity

Construct a `AbstractQuantity` from value and unit kwargs.

Examples
--------
For this example we'll use the `Quantity` class. The same applies to
any subclass of `AbstractQuantity`.

>>> import unxt as u
>>> u.Quantity.from_(value=[1.0, 2, 3], unit="m")
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

from_(cls: type[unxt._src.quantity.base.AbstractQuantity], mapping: collections.abc.Mapping[str, typing.Any]) -> unxt._src.quantity.base.AbstractQuantity

Construct a `Quantity` from a Mapping.

Examples
--------
For this example we'll use the `Quantity` class. The same applies to
any subclass of `AbstractQuantity`.

>>> import jax.numpy as jnp
>>> import unxt as u

>>> x = jnp.array([1.0, 2, 3])
>>> q = u.Quantity.from_({"value": x, "unit": "m"})
>>> q
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')

>>> u.Quantity.from_({"value": q, "unit": "km"})
Quantity(Array([0.001, 0.002, 0.003], dtype=float32), unit='km')

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: unxt._src.quantity.base.AbstractQuantity, unit: Any, /, *, dtype: Any = None) -> unxt._src.quantity.base.AbstractQuantity

Construct a `Quantity` from another `Quantity`.

The `value` is converted to the new `unit`.

Examples
--------
>>> import unxt as u

>>> q = u.Quantity(1, "m")
>>> u.Quantity.from_(q, "cm")
Quantity(Array(100., dtype=float32, ...), unit='cm')

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: unxt._src.quantity.base.AbstractQuantity, unit: NoneType, /, *, dtype: Any = None) -> unxt._src.quantity.base.AbstractQuantity

Construct a `Quantity` from another `Quantity`.

The `value` is converted to the new `unit`.

Examples
--------
>>> import unxt as u

>>> q = u.Quantity(1, "m")
>>> u.Quantity.from_(q, None)
Quantity(Array(1, dtype=int32, ...), unit='m')

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: unxt._src.quantity.base.AbstractQuantity, /, *, unit: typing.Any | None = None, dtype: Any = None) -> unxt._src.quantity.base.AbstractQuantity

Construct a `Quantity` from another `Quantity`, with no unit change.

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: astropy.units.quantity.Quantity, /, **kwargs: Any) -> unxt._src.quantity.base.AbstractQuantity

Construct a `Quantity` from another `Quantity`.

The `value` is converted to the new `unit`.

Examples
--------
>>> import unxt as u
>>> import astropy.units as apyu

>>> u.Quantity.from_(apyu.Quantity(1, "m"))
Quantity(Array(1., dtype=float32), unit='m')

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: astropy.units.quantity.Quantity, u: Any, /, **kwargs: Any) -> unxt._src.quantity.base.AbstractQuantity

Construct a `Quantity` from another `Quantity`.

The `value` is converted to the new `unit`.

Examples
--------
>>> import unxt as u
>>> import astropy.units as apyu

>>> u.Quantity.from_(apyu.Quantity(1, "m"), "cm")
Quantity(Array(100., dtype=float32), unit='cm')
File:      ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/unxt/_src/quantity/base.py
Type:      method

Non-Parametric

Parametric types add some overhead (usually marginal). If that overhead becomes a limiting factor in a hot path, unxt offers a non-parametric quantity, BareQuantity.

[27]:
from unxt.quantity import BareQuantity
[28]:
x = BareQuantity(100, "km")
x
[28]:
100 * Unit("km")
[29]:
qnp.cbrt(x)
[29]:
4.641589 * Unit("km(1/3)")

Distance / Parallax

coordinax provides a Distance class similar to Astropy’s astropy.coordinates.Distance.

[30]:
from coordinax import Distance
[31]:
d = Distance([1, 2, 3], 'm')
d
[31]:
[1, 2, 3] * Unit("m")

It works with quaxed functions like any other Quantity:

[32]:
qnp.sqrt(d)
[32]:
[1. , 1.41421354, 1.73205078] * Unit("m(1/2)")

Like Astropy’s Distance, it has properties for converting between distance measures:

[33]:
d.distance
[33]:
[1, 2, 3] * Unit("m")
[34]:
d.parallax
[34]:
[1.57079637, 1.57079637, 1.57079637] * Unit("rad")
[35]:
d.distance_modulus
[35]:
[-87.44674683, -85.94160461, -85.0611496] * Unit("mag")

There are convenient constructors, for example from a parallax angle:

[36]:
Distance.from_(Quantity([1.], "rad"))
[36]:
[0.64209259] * Unit("AU")

coordinax also provides a dedicated Parallax class:

[37]:
from coordinax.angle import Parallax
[38]:
p = Parallax([1., 2, 3], 'mas')
p
[38]:
[1., 2., 3.] * Unit("mas")
[39]:
p.distance
[39]:
[2.0626480e+08, 1.0313240e+08, 6.8754936e+07] * Unit("AU")
[40]:
p.parallax
[40]:
[1., 2., 3.] * Unit("mas")
[41]:
p.distance_modulus
[41]:
[10. , 8.49485016, 7.61439323] * Unit("mag")
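
The printed values in this section are consistent with three textbook relations: parallax p = arctan(1 AU / d), its inverse d = 1 AU / tan(p), and distance modulus mu = 5 log10(d / 10 pc). The NumPy sketch below reproduces the outputs; the relations are inferred from the numbers, not taken from coordinax's source, and the helper names are hypothetical.

```python
import numpy as np

AU_M = 1.495978707e11                        # astronomical unit in metres (exact)
PC_M = AU_M * 648000.0 / np.pi               # parsec: 1 pc = 648000/pi AU
MAS_RAD = np.pi / (180.0 * 3600.0 * 1000.0)  # one milliarcsecond in radians


def parallax_from_distance(d_m):
    """Parallax angle (rad) subtended by 1 AU at distance d (m)."""
    return np.arctan2(AU_M, d_m)


def distance_from_parallax(p_rad):
    """Distance (m) from a parallax angle (rad)."""
    return AU_M / np.tan(p_rad)


def distance_modulus(d_m):
    """Distance modulus (mag): 5 log10(d / 10 pc)."""
    return 5.0 * np.log10(d_m / (10.0 * PC_M))


d = np.array([1.0, 2.0, 3.0])            # metres, as in cell [31]
p = np.array([1.0, 2.0, 3.0]) * MAS_RAD  # 1, 2, 3 mas, as in cell [38]

print(parallax_from_distance(d))                       # ~pi/2 each, since 1 AU >> 1 m
print(distance_modulus(d))                             # cell [35]
print(distance_from_parallax(np.array([1.0])) / AU_M)  # cell [36], in AU
print(distance_modulus(distance_from_parallax(p)))     # cell [41]
```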

Exercise: build a DistanceModulus class.

Hint: check out coordinax.distance.DistanceModulus.
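
One possible starting point for the exercise: a plain-Python stand-in that stores the modulus and derives the other measures by inverting mu = 5 log10(d / 10 pc). DistanceModulusSketch is a hypothetical name; it works with bare numbers (mag, metres, radians) and does not use coordinax’s class hierarchy, so it only sketches the arithmetic.

```python
from dataclasses import dataclass

import jax.numpy as jnp

AU_M = 1.495978707e11            # astronomical unit in metres
PC_M = AU_M * 648000.0 / jnp.pi  # parsec in metres


@dataclass(frozen=True)
class DistanceModulusSketch:
    """Hypothetical stand-in for a DistanceModulus class (values in mag)."""

    mu: jnp.ndarray

    @property
    def distance_modulus(self):
        return self.mu

    @property
    def distance(self):
        # Invert mu = 5 log10(d / 10 pc): d = 10 pc * 10**(mu / 5), in metres.
        return 10.0 * PC_M * 10.0 ** (self.mu / 5.0)

    @property
    def parallax(self):
        # Angle (rad) subtended by 1 AU at that distance.
        return jnp.arctan2(AU_M, self.distance)


dm = DistanceModulusSketch(jnp.array([10.0, 15.0]))
print(dm.distance / PC_M)  # distances in pc
print(dm.parallax)         # parallaxes in rad
```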

Unit Systems

If you’ve used gala.units, then you’re already familiar with Astropy-compatible unit systems!

[42]:
from unxt.unitsystems import unitsystem, galactic, solarsystem
[43]:
unitsystem?
Signature:       unitsystem(usys: unxt._src.unitsystems.base.AbstractUnitSystem, /) -> unxt._src.unitsystems.base.AbstractUnitSystem
Call signature:  unitsystem(*args, **kw_args)
Type:            Function
String form:     <multiple-dispatch function unitsystem (with 9 registered and 0 pending method(s))>
File:            ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/unxt/_src/unitsystems/core.py
Docstring:
Convert a UnitSystem to a UnitSystem.

Examples
--------
>>> from unxt.unitsystems import unitsystem
>>> usys = unitsystem("kpc", "Myr", "Msun", "radian")
>>> usys
unitsystem(kpc, Myr, solMass, rad)

>>> unitsystem(usys) is usys
True

-----------------------------------------------------------------------------------------------------------------------------------------------------------

unitsystem(seq: collections.abc.Sequence[typing.Any], /) -> unxt._src.unitsystems.base.AbstractUnitSystem

Convert a UnitSystem or tuple of arguments to a UnitSystem.

Examples
--------
>>> import unxt as u

>>> u.unitsystem(())
DimensionlessUnitSystem()

>>> u.unitsystem(("kpc", "Myr", "Msun", "radian"))
unitsystem(kpc, Myr, solMass, rad)

>>> u.unitsystem(["kpc", "Myr", "Msun", "radian"])
unitsystem(kpc, Myr, solMass, rad)

-----------------------------------------------------------------------------------------------------------------------------------------------------------

unitsystem(_: NoneType, /) -> unxt._src.unitsystems.builtin.DimensionlessUnitSystem

Dimensionless unit system from None.

Examples
--------
>>> from unxt.unitsystems import unitsystem
>>> unitsystem(None)
DimensionlessUnitSystem()

-----------------------------------------------------------------------------------------------------------------------------------------------------------

unitsystem(*args: Any) -> unxt._src.unitsystems.base.AbstractUnitSystem

Convert a set of arguments to a UnitSystem.

Examples
--------
>>> from unxt.unitsystems import unitsystem

>>> unitsystem("kpc", "Myr", "Msun", "radian")
unitsystem(kpc, Myr, solMass, rad)

-----------------------------------------------------------------------------------------------------------------------------------------------------------

unitsystem(name: str, /) -> unxt._src.unitsystems.base.AbstractUnitSystem

Return unit system from name.

Examples
--------
>>> from unxt.unitsystems import unitsystem
>>> unitsystem("galactic")
unitsystem(kpc, Myr, solMass, rad)

>>> unitsystem("solarsystem")
unitsystem(AU, yr, solMass, rad)

>>> unitsystem("dimensionless")
DimensionlessUnitSystem()

-----------------------------------------------------------------------------------------------------------------------------------------------------------

unitsystem(usys: unxt._src.unitsystems.base.AbstractUnitSystem, *args: Any) -> unxt._src.unitsystems.base.AbstractUnitSystem

Create a unit system from an existing unit system and additional units.

Examples
--------
We can add a new unit definition to an existing unit system:

>>> from unxt.unitsystems import unitsystem
>>> usys = unitsystem("galactic")
>>> unitsystem(usys, "km/s")
LengthTimeMassAngleSpeedUnitSystem(length=Unit("kpc"), time=Unit("Myr"), mass=Unit("solMass"), angle=Unit("rad"), speed=Unit("km / s"))

We can also override the base unit of an existing unit system:

>>> new_usys = unitsystem(usys, "pc")
>>> new_usys
TimeMassAngleLengthUnitSystem(time=Unit("Myr"), mass=Unit("solMass"), angle=Unit("rad"), length=Unit("pc"))

-----------------------------------------------------------------------------------------------------------------------------------------------------------

unitsystem(flag: type[unxt._src.unitsystems.flags.AbstractUSysFlag], *_: Any) -> unxt._src.unitsystems.base.AbstractUnitSystem

Raise an exception since the flag is abstract.

-----------------------------------------------------------------------------------------------------------------------------------------------------------

unitsystem(flag: type[unxt._src.unitsystems.flags.StandardUSysFlag], *args: Any) -> unxt._src.unitsystems.base.AbstractUnitSystem

Create a standard unit system using the inputted units.

Examples
--------
>>> from unxt import unitsystem, unitsystems
>>> unitsystem(unitsystems.StandardUSysFlag, "kpc", "Myr", "Msun")
LengthTimeMassUnitSystem(length=Unit("kpc"), time=Unit("Myr"), mass=Unit("solMass"))

-----------------------------------------------------------------------------------------------------------------------------------------------------------

unitsystem(flag: type[unxt._src.unitsystems.flags.DynamicalSimUSysFlag], *args: Any, G: float | int = 1.0) -> unxt._src.unitsystems.base.AbstractUnitSystem

Make a dynamical unit system.

Examples
--------
>>> from unxt.unitsystems import unitsystem, DynamicalSimUSysFlag

>>> unitsystem(DynamicalSimUSysFlag, "m", "kg")
LengthMassTimeUnitSystem(length=Unit("m"), mass=Unit("kg"), time=Unit("122404 s"))
Class docstring:
A function.

Args:
    f (function): Function that is wrapped.
    owner (str, optional): Name of the class that owns the function.
    warn_redefinition (bool, optional): Throw a warning whenever a method is
        redefined. Defaults to `False`.
[44]:
usys = unitsystem("kpc", "Gyr", "solMass", "deg")
usys
[44]:
unitsystem(kpc, Gyr, solMass, deg)

There are also predefined unit systems:

[45]:
galactic
[45]:
unitsystem(kpc, Myr, solMass, rad)
[46]:
solarsystem
[46]:
unitsystem(AU, yr, solMass, rad)

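Unit systems are useful when working with Quantities, for example to decompose a quantity into the system’s base units or to look up the unit for a given dimension: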

[47]:
print(x)
x.decompose(usys)
BareQuantity(weak_i32[], unit='km')
[47]:
3.2407792e-15 * Unit("kpc")
[48]:
x = 1e4 * Quantity([1, 2, 3], 'lyr')
x.uconvert(usys["length"])
[48]:
[3.06601405, 6.1320281, 9.19804192] * Unit("kpc")
[49]:
Quantity([1, 2, 3], 'lyr / mas').decompose(usys)
[49]:
[1103.76501465, 2207.5300293, 3311.29492188] * Unit("kpc / deg")

Astropy Compatibility

unxt.Quantity can be made from an Astropy Quantity:

[50]:
x = Quantity.from_(u.Quantity(1, "kpc"))
x
[50]:
1. * Unit("kpc")

And it can be converted back to an Astropy Quantity with plum’s convert:

[51]:
from plum import convert

convert(x, u.Quantity)
[51]:
$1 \; \mathrm{kpc}$

Differentiation

unxt works with JAX’s automatic differentiation: wrap jax.grad with quaxify (or use quaxed.grad), and the gradient carries the correct units.

[52]:
def polynomial(
    x: Float[Quantity["area"], "N"],
    a0: Float[Quantity["volume"], ""],
    a1: Float[Quantity["length"], ""],
    a2: Float[Quantity["1/m"], ""],
) -> Float[Quantity["volume"], "N"]:
    return a0 + a1 * x ** 1 + a2 * x ** 2

polynomial
[52]:
<function __main__.polynomial(x: jaxtyping.Float[Quantity[PhysicalType('area')], 'N'], a0: jaxtyping.Float[Quantity[PhysicalType('volume')], ''], a1: jaxtyping.Float[Quantity[PhysicalType('length')], ''], a2: jaxtyping.Float[Quantity[PhysicalType('wavenumber')], '']) -> jaxtyping.Float[Quantity[PhysicalType('volume')], 'N']>
[53]:
x = Quantity(1.0, "km2")
a0, a1, a2 = Quantity(1, "km3"), Quantity(2, "km"), Quantity(3, "1/km")
[54]:
polynomial(x, a0, a1, a2)
[54]:
6. * Unit("km3")
[55]:
grad_fn = quaxify(jax.grad(polynomial))
grad_fn(x, a0, a1, a2)
[55]:
8. * Unit("km")
[56]:
# Alternatively:

from quaxed import grad as qgrad

qgrad(polynomial)(x, a0, a1, a2)
[56]:
8. * Unit("km")
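
The result can be checked by hand: d/dx (a0 + a1 x + a2 x²) = a1 + 2 a2 x = 2 km + 2 (3 / km)(1 km²) = 8 km, and the unit is volume over area, i.e. length. A unitless jax.grad on the same magnitudes reproduces the number:

```python
import jax


def polynomial_plain(x, a0=1.0, a1=2.0, a2=3.0):
    """Same polynomial with the km-based magnitudes and no units."""
    return a0 + a1 * x + a2 * x**2


slope = jax.grad(polynomial_plain)(1.0)
print(slope)
```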