Unxt
With quax we can do units in JAX!
[1]:
import jax
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Float
Why doesn’t Astropy work out of the box?
[2]:
import astropy.units as u
[3]:
x = u.Quantity([1, 2, 3], "m")
[4]:
jnp.sqrt(x)
[4]:
Array([1. , 1.4142135, 1.7320508], dtype=float32)
We lost the units!
An Astropy Quantity is a subclass of NumPy's ndarray, so JAX converts it to a plain array (NumPy, then JAX) and the unit is discarded along the way.
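A minimal sketch of that mechanism, assuming only NumPy. UnitArray is a hypothetical toy ndarray subclass that carries a unit string; Astropy's real Quantity is far more elaborate. Converting it with np.asarray, as array libraries commonly do with their inputs, returns a base ndarray, and the unit attribute is gone.

```python
import numpy as np


class UnitArray(np.ndarray):
    """Hypothetical ndarray subclass that carries a unit string (toy only)."""

    def __new__(cls, values, unit):
        obj = np.asarray(values, dtype=float).view(cls)
        obj.unit = unit
        return obj

    def __array_finalize__(self, obj):
        # Propagate the unit to views and slices of an existing UnitArray.
        if obj is None:
            return
        self.unit = getattr(obj, "unit", None)


x = UnitArray([1, 2, 3], "m")
plain = np.asarray(x)  # subok=False by default: a base-class ndarray view

print(type(x).__name__, x.unit)                      # UnitArray m
print(type(plain).__name__, hasattr(plain, "unit"))  # ndarray False
```

The numbers survive the conversion; only the metadata that lived on the subclass is lost.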
Let’s try the quaxed version.
[5]:
import quaxed.numpy as qnp
[6]:
qnp.sqrt(x)
[6]:
Array([1. , 1.4142135, 1.7320508], dtype=float32)
Darn. Same problem.
Introducing…
Quantity
[7]:
from unxt import Quantity
from unxt import uconvert, ustrip # alt functional API
[8]:
x = Quantity([1, 2, 3], 'm')
x
[9]:
try:
    jnp.sqrt(x)
except TypeError as e:
    print(e)
sqrt requires ndarray or scalar arguments, got <class 'unxt._src.quantity.quantity.Quantity[PhysicalType('length')]'> at position 0.
[10]:
from quax import quaxify
quaxify(jnp.sqrt)(x)
[11]:
# Alternatively, use the `quaxed` version.
qnp.sqrt(x)
Now let’s return to the Astropy Quantity
[12]:
qnp.sqrt(x)
It worked?!? What happened?
Well, unxt hooked into quax and said it knows how to handle Astropy Quantities. That’s some nice interoperability.
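To see why a wrapped function can keep units while plain jnp.sqrt rejects them, here is a toy, pure-Python sketch of type-based dispatch. Q, RULES, register, ForeignQ, and dispatching are all hypothetical names, not quax's API. quax itself works at the level of JAX primitives rather than Python functions. The idea is that each argument type gets its own rule, and a rule for a foreign quantity type can declare interoperability by converting to the native type first.

```python
import math
from dataclasses import dataclass
from fractions import Fraction


@dataclass(frozen=True)
class Q:
    """Toy quantity: a float value plus unit exponents, e.g. (("m", Fraction(2)),)."""

    value: float
    unit: tuple


# Dispatch table keyed by (function name, argument type).
RULES = {}


def register(name, typ):
    """Decorator that registers a rule for calling `name` on a `typ` argument."""

    def deco(rule):
        RULES[(name, typ)] = rule
        return rule

    return deco


@register("sqrt", float)
def _sqrt_float(x):
    return math.sqrt(x)


@register("sqrt", Q)
def _sqrt_q(x):
    # Take the root of the value and halve every unit exponent.
    return Q(math.sqrt(x.value), tuple((u, p / 2) for u, p in x.unit))


class ForeignQ:
    """Hypothetical quantity type owned by some other library."""

    def __init__(self, value, unit_name):
        self.value, self.unit_name = value, unit_name


@register("sqrt", ForeignQ)
def _sqrt_foreign(x):
    # Interop rule: convert the foreign object to Q, then reuse Q's rule.
    return _sqrt_q(Q(x.value, ((x.unit_name, Fraction(1)),)))


def dispatching(name):
    """Toy stand-in for quaxify: pick a rule from the argument's type."""

    def wrapped(x):
        rule = RULES.get((name, type(x)))
        if rule is None:
            raise TypeError(f"{name} has no rule for {type(x).__name__}")
        return rule(x)

    return wrapped


sqrt = dispatching("sqrt")
print(sqrt(Q(4.0, (("m", Fraction(2)),))))  # m^2 -> m^1
print(sqrt(ForeignQ(9.0, "m")))             # m -> m^(1/2)
```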
Astropy-like API
[13]:
x.to(u.km)
[14]:
x.to_value(u.km)
[14]:
Array([0.001, 0.002, 0.003], dtype=float32, weak_type=True)
[15]:
x.decompose([u.imperial.furlong, u.Gyr])
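The arithmetic behind to, to_value, and decompose comes down to a ratio of scale factors. Below is a minimal pure-Python sketch that assumes a hand-written table of length units. SCALE_TO_M and to_value are hypothetical; unxt delegates the real bookkeeping to Astropy units. Converting metres to kilometres divides by 1000, and expressing a length in furlongs divides by 201.168 m, which is exactly one international furlong (220 yd × 0.9144 m).

```python
# Hypothetical table of length units, each expressed in metres.
SCALE_TO_M = {"m": 1.0, "km": 1000.0, "furlong": 201.168}


def to_value(values, from_unit, to_unit):
    """Convert a list of lengths between units via their metre scale factors."""
    factor = SCALE_TO_M[from_unit] / SCALE_TO_M[to_unit]
    return [v * factor for v in values]


print(to_value([1, 2, 3], "m", "km"))       # approximately [0.001, 0.002, 0.003]
print(to_value([201.168], "m", "furlong"))  # approximately [1.0]
```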
Parametric
unxt.Quantity can be parametrized by its physical type (its dimensions)! The parameter can be inferred at run time, or set explicitly and used for verification.
[16]:
Quantity["frequency"](1, 1/u.s) # works
[17]:
try:
    Quantity["frequency"](1, "kpc")
except Exception as e:
    print(e)
Physical type mismatch.
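One way to build a type parametrized by physical type is to put a class cache behind __class_getitem__. The following toy, pure-Python sketch makes that assumption. ToyQuantity, DIMENSIONS, and PHYSICAL_TYPES are hypothetical and much simpler than unxt's implementation. ToyQuantity["frequency"] returns a cached subclass, and that subclass's constructor checks the unit's dimensions and raises on a mismatch.

```python
# Hypothetical unit table: unit -> dimensions as sorted (base, exponent) pairs.
DIMENSIONS = {
    "m": (("length", 1),),
    "kpc": (("length", 1),),
    "s": (("time", 1),),
    "1/s": (("time", -1),),
    "km/s": (("length", 1), ("time", -1)),
}
# Hypothetical map from dimensions to a physical-type name.
PHYSICAL_TYPES = {
    (("length", 1),): "length",
    (("time", 1),): "time",
    (("time", -1),): "frequency",
    (("length", 1), ("time", -1)): "speed",
}


class ToyQuantity:
    physical_type = None  # None means "not parametrized"
    _cache = {}

    def __class_getitem__(cls, ptype):
        # Create one subclass per physical type and reuse it afterwards.
        if ptype not in cls._cache:
            cls._cache[ptype] = type(
                f"ToyQuantity[{ptype!r}]", (cls,), {"physical_type": ptype}
            )
        return cls._cache[ptype]

    def __init__(self, value, unit):
        inferred = PHYSICAL_TYPES[DIMENSIONS[unit]]
        if self.physical_type is not None and inferred != self.physical_type:
            raise ValueError("Physical type mismatch.")
        self.value, self.unit = value, unit


q = ToyQuantity["frequency"](1, "1/s")  # passes the check
try:
    ToyQuantity["frequency"](1, "kpc")
except ValueError as e:
    print(e)  # Physical type mismatch.
```

Caching the subclass means ToyQuantity["frequency"] is the same object every time, so isinstance checks against it stay cheap and consistent.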
[18]:
@beartype
def func(x: Quantity["length"], v: Quantity["speed"]) -> Quantity["time"]:
    return x / v
[19]:
func(Quantity(1, "kpc"), Quantity(100, "km/s"))
[20]:
try:
    func(Quantity(1, "kpc"), Quantity(100, "1/s"))
except Exception as e:
    print(e)
Function __main__.func() parameter v="Quantity(Array(100, dtype=int32, weak_type=True), unit='1 / s')" violates type hint <class 'unxt._src.quantity.quantity.Quantity[PhysicalType({'speed', 'velocity'})]'>, as <class "unxt._src.quantity.quantity.Quantity[PhysicalType('frequency')]"> "Quantity(Array(100, dtype=int32, weak_type=True), unit='1 / s')" not instance of <class "unxt._src.quantity.quantity.Quantity[PhysicalType({'speed', 'velocity'})]">.
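The return annotation of func holds for dimensional reasons. Dividing a length by a speed subtracts exponents and leaves a time. The minimal sketch below shows that bookkeeping, with dimensions stored as hypothetical dicts of base-dimension exponents (this is not unxt's internal representation). If the rejected call had been allowed, x / v would have dimensions of length × time, which is not a time.

```python
def divide_dims(num, den):
    """Dimensions of num / den: subtract exponents, dropping zeros."""
    out = dict(num)
    for base, power in den.items():
        out[base] = out.get(base, 0) - power
    return {b: p for b, p in out.items() if p != 0}


LENGTH = {"length": 1}
SPEED = {"length": 1, "time": -1}
FREQUENCY = {"time": -1}

print(divide_dims(LENGTH, SPEED))      # {'time': 1}
print(divide_dims(LENGTH, FREQUENCY))  # {'length': 1, 'time': 1}
```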
Constructor
unxt offers a flexible way to construct a Quantity through a generalized constructor class method, Quantity.from_, which dispatches on the types of its arguments.
[21]:
# normal construction
Quantity.from_([1, 2, 3], 'm')
[22]:
# From a Quantity
Quantity.from_(x)
[23]:
# Also changing the unit
Quantity.from_(x, "km")
[24]:
Quantity.from_._f.methods
[24]:
List of 9 method(s):
[0] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: typing.Union[ArrayLike, list[typing.Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, bool, int, float, complex]], tuple[typing.Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, bool, int, float, complex], ...]], unit: Any, *, dtype) -> unxt._src.quantity.base.AbstractQuantity <function from_ at 0x116d8f6a0> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py:799
[1] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: typing.Union[ArrayLike, list[typing.Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, bool, int, float, complex]], tuple[typing.Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, bool, int, float, complex], ...]], *, unit, dtype) -> unxt._src.quantity.base.AbstractQuantity <function from_ at 0x116d8f9c0> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py:838
[2] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], *, value, unit, dtype) -> unxt._src.quantity.base.AbstractQuantity <function from_ at 0x116d8fce0> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py:863
[3] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], mapping: collections.abc.Mapping[str, typing.Any]) -> unxt._src.quantity.base.AbstractQuantity <function from_ at 0x116dac220> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py:883
[4] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: unxt._src.quantity.base.AbstractQuantity, unit: Any, *, dtype) -> unxt._src.quantity.base.AbstractQuantity <function from_ at 0x116dac360> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py:909
[5] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: unxt._src.quantity.base.AbstractQuantity, unit: NoneType, *, dtype) -> unxt._src.quantity.base.AbstractQuantity <function from_ at 0x116dac680> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py:935
[6] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: unxt._src.quantity.base.AbstractQuantity, *, unit, dtype) -> unxt._src.quantity.base.AbstractQuantity <function from_ at 0x116dac9a0> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/jaxtyping/_decorator.py:962
[7] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: astropy.units.quantity.Quantity, *, **kwargs) -> unxt._src.quantity.base.AbstractQuantity <function from_ at 0x116ff7a60> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/unxt/_interop/unxt_interop_astropy/quantity.py:53
[8] from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: astropy.units.quantity.Quantity, u: Any, *, **kwargs) -> unxt._src.quantity.base.AbstractQuantity <function from_ at 0x116fc2980> @ ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/unxt/_interop/unxt_interop_astropy/quantity.py:75
[25]:
# Using a dictionary
Quantity.from_({"value": x, "unit": "km"})
[26]:
Quantity.from_?
Signature: Quantity.from_(*args: Any, **kwargs: Any) -> 'AbstractQuantity'
Docstring:
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: Union[ArrayLike, list[Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, bool, int, float, complex]], tuple[Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, bool, int, float, complex], ...]], unit: Any, /, *, dtype: Any = None) -> unxt._src.quantity.base.AbstractQuantity
Construct a `unxt.Quantity` from an array-like value and a unit.
:param value: The array-like value.
:param unit: The unit of the value.
:param dtype: The data type of the array (keyword-only).
Examples
--------
For this example we'll use the `Quantity` class. The same applies to
any subclass of `AbstractQuantity`.
>>> import jax.numpy as jnp
>>> import unxt as u
>>> x = jnp.array([1.0, 2, 3])
>>> u.Quantity.from_(x, "m")
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')
>>> u.Quantity.from_([1.0, 2, 3], "m")
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')
>>> u.Quantity.from_((1.0, 2, 3), "m")
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: Union[ArrayLike, list[Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, bool, int, float, complex]], tuple[Union[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[ndarray, ''], numpy.bool, numpy.number, bool, int, float, complex], ...]], /, *, unit: Any, dtype: Any = None) -> unxt._src.quantity.base.AbstractQuantity
Make a `unxt.AbstractQuantity` from an array-like value and a unit kwarg.
Examples
--------
For this example we'll use the `unxt.Quantity` class. The same applies
to any subclass of `unxt.AbstractQuantity`.
>>> import unxt as u
>>> u.Quantity.from_([1.0, 2, 3], unit="m")
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
from_(cls: type[unxt._src.quantity.base.AbstractQuantity], *, value: Any, unit: Any, dtype: Any = None) -> unxt._src.quantity.base.AbstractQuantity
Construct a `AbstractQuantity` from value and unit kwargs.
Examples
--------
For this example we'll use the `Quantity` class. The same applies to
any subclass of `AbstractQuantity`.
>>> import unxt as u
>>> u.Quantity.from_(value=[1.0, 2, 3], unit="m")
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
from_(cls: type[unxt._src.quantity.base.AbstractQuantity], mapping: collections.abc.Mapping[str, typing.Any]) -> unxt._src.quantity.base.AbstractQuantity
Construct a `Quantity` from a Mapping.
Examples
--------
For this example we'll use the `Quantity` class. The same applies to
any subclass of `AbstractQuantity`.
>>> import jax.numpy as jnp
>>> import unxt as u
>>> x = jnp.array([1.0, 2, 3])
>>> q = u.Quantity.from_({"value": x, "unit": "m"})
>>> q
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')
>>> u.Quantity.from_({"value": q, "unit": "km"})
Quantity(Array([0.001, 0.002, 0.003], dtype=float32), unit='km')
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: unxt._src.quantity.base.AbstractQuantity, unit: Any, /, *, dtype: Any = None) -> unxt._src.quantity.base.AbstractQuantity
Construct a `Quantity` from another `Quantity`.
The `value` is converted to the new `unit`.
Examples
--------
>>> import unxt as u
>>> q = u.Quantity(1, "m")
>>> u.Quantity.from_(q, "cm")
Quantity(Array(100., dtype=float32, ...), unit='cm')
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: unxt._src.quantity.base.AbstractQuantity, unit: NoneType, /, *, dtype: Any = None) -> unxt._src.quantity.base.AbstractQuantity
Construct a `Quantity` from another `Quantity`.
The `value` is converted to the new `unit`.
Examples
--------
>>> import unxt as u
>>> q = u.Quantity(1, "m")
>>> u.Quantity.from_(q, None)
Quantity(Array(1, dtype=int32, ...), unit='m')
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: unxt._src.quantity.base.AbstractQuantity, /, *, unit: typing.Any | None = None, dtype: Any = None) -> unxt._src.quantity.base.AbstractQuantity
Construct a `Quantity` from another `Quantity`, with no unit change.
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: astropy.units.quantity.Quantity, /, **kwargs: Any) -> unxt._src.quantity.base.AbstractQuantity
Construct a `Quantity` from another `Quantity`.
The `value` is converted to the new `unit`.
Examples
--------
>>> import unxt as u
>>> import astropy.units as apyu
>>> u.Quantity.from_(apyu.Quantity(1, "m"))
Quantity(Array(1., dtype=float32), unit='m')
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
from_(cls: type[unxt._src.quantity.base.AbstractQuantity], value: astropy.units.quantity.Quantity, u: Any, /, **kwargs: Any) -> unxt._src.quantity.base.AbstractQuantity
Construct a `Quantity` from another `Quantity`.
The `value` is converted to the new `unit`.
Examples
--------
>>> import unxt as u
>>> import astropy.units as apyu
>>> u.Quantity.from_(apyu.Quantity(1, "m"), "cm")
Quantity(Array(100., dtype=float32), unit='cm')
File: ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/unxt/_src/quantity/base.py
Type: method
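Quantity.from_ chooses an implementation based on the types of its arguments, as the method list above shows. Below is a toy, pure-Python sketch of the same idea for three of those cases: a bare value plus a unit, another quantity with an optional target unit, and a mapping. It uses isinstance checks in place of a multiple-dispatch library. MiniQuantity, its from_ method, and the TO_M table are hypothetical, and the sketch only converts between metre-based length units.

```python
from collections.abc import Mapping
from dataclasses import dataclass

# Hypothetical scale factors, in metres.
TO_M = {"m": 1.0, "cm": 0.01, "km": 1000.0}


@dataclass(frozen=True)
class MiniQuantity:
    value: float
    unit: str

    @classmethod
    def from_(cls, value, unit=None):
        # Mapping: unpack {"value": ..., "unit": ...} and recurse.
        if isinstance(value, Mapping):
            return cls.from_(value["value"], value.get("unit"))
        # Another quantity: return it unchanged, or convert to the new unit.
        if isinstance(value, MiniQuantity):
            if unit is None:
                return value
            return cls(value.value * TO_M[value.unit] / TO_M[unit], unit)
        # Bare number: a unit is required.
        if unit is None:
            raise TypeError("a unit is required for a bare value")
        return cls(float(value), unit)


q = MiniQuantity.from_(1, "m")
print(MiniQuantity.from_(q, "cm"))
print(MiniQuantity.from_({"value": q, "unit": "km"}))
```

A real dispatch library registers each signature as its own method, so new cases (like the Astropy interop methods above) can be added from another module without editing from_ itself. The isinstance chain here cannot be extended that way.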
Non-Parametric
Parametric types can be slower, though the difference is marginal. If it becomes a limiting factor in a hot path, unxt offers a non-parametric Quantity, BareQuantity.
[27]:
from unxt.quantity import BareQuantity
[28]:
x = BareQuantity(100, "km")
x
[29]:
qnp.cbrt(x)
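Parametric and non-parametric quantities follow the same unit rule for cbrt. The value is cube-rooted and each unit exponent is multiplied by 1/3, so 100 km becomes about 4.64 km^(1/3). Here is a pure-Python sketch of that rule, assuming units are stored as a hypothetical {unit name: Fraction exponent} dict.

```python
import math
from fractions import Fraction


def cbrt_with_unit(value, unit):
    """Cube root of value * unit, with unit given as {name: exponent}."""
    root = math.copysign(abs(value) ** (1 / 3), value)  # real cube root
    return root, {name: power / 3 for name, power in unit.items()}


value, unit = cbrt_with_unit(100.0, {"km": Fraction(1)})
print(round(value, 4), unit)  # 4.6416 {'km': Fraction(1, 3)}
```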
Distance / Parallax
This is similar to Astropy’s astropy.coordinates.Distance.
[30]:
from coordinax import Distance
[31]:
d = Distance([1, 2, 3], 'm')
d
It works with quaxed functions like any other quantity:
[32]:
qnp.sqrt(d)
Like Astropy’s Distance, it has properties for converting between distance measures:
[33]:
d.distance
[34]:
d.parallax
[35]:
d.distance_modulus
There are convenient constructors, for example from a parallax angle:
[36]:
Distance.from_(Quantity([1.], "rad"))
See also the Parallax class:
[37]:
from coordinax.angle import Parallax
[38]:
p = Parallax([1., 2, 3], 'mas')
p
[39]:
p.distance
[40]:
p.parallax
[41]:
p.distance_modulus
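The conversions behind distance, parallax, and distance_modulus are short formulas. The sketch below assumes the small-angle relation d [pc] = 1 / p [arcsec], equivalently d [pc] = 1000 / p [mas]; check coordinax's documentation for the exact convention it uses. It also uses the distance modulus μ = 5 log10(d / 10 pc). The function names are hypothetical. The loop applies these formulas to the parallaxes used above.

```python
import math


def parallax_mas_to_distance_pc(p_mas):
    """Small-angle parallax distance: d [pc] = 1000 / p [mas]."""
    return 1000.0 / p_mas


def distance_modulus(d_pc):
    """Distance modulus mu = 5 log10(d / 10 pc), in magnitudes."""
    return 5.0 * math.log10(d_pc / 10.0)


for p in [1.0, 2.0, 3.0]:
    d = parallax_mas_to_distance_pc(p)
    print(f"{p} mas -> {d:.1f} pc, mu = {distance_modulus(d):.3f} mag")
```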
Exercise: build a DistanceModulus class.
Hint: check out coordinax.distance.DistanceModulus.
Unit Systems
If you’ve used gala.units, then you’re familiar with Astropy-compatible unit systems!
[42]:
from unxt.unitsystems import unitsystem, galactic, solarsystem
[43]:
unitsystem?
Signature: unitsystem(usys: unxt._src.unitsystems.base.AbstractUnitSystem, /) -> unxt._src.unitsystems.base.AbstractUnitSystem
Call signature: unitsystem(*args, **kw_args)
Type: Function
String form: <multiple-dispatch function unitsystem (with 9 registered and 0 pending method(s))>
File: ~/Documents/Academia/Conferences & Presentations/2025-09-Thunch and Flatiron/Flatiron Presentation/.venv/lib/python3.12/site-packages/unxt/_src/unitsystems/core.py
Docstring:
Convert a UnitSystem to a UnitSystem.
Examples
--------
>>> from unxt.unitsystems import unitsystem
>>> usys = unitsystem("kpc", "Myr", "Msun", "radian")
>>> usys
unitsystem(kpc, Myr, solMass, rad)
>>> unitsystem(usys) is usys
True
-----------------------------------------------------------------------------------------------------------------------------------------------------------
unitsystem(seq: collections.abc.Sequence[typing.Any], /) -> unxt._src.unitsystems.base.AbstractUnitSystem
Convert a UnitSystem or tuple of arguments to a UnitSystem.
Examples
--------
>>> import unxt as u
>>> u.unitsystem(())
DimensionlessUnitSystem()
>>> u.unitsystem(("kpc", "Myr", "Msun", "radian"))
unitsystem(kpc, Myr, solMass, rad)
>>> u.unitsystem(["kpc", "Myr", "Msun", "radian"])
unitsystem(kpc, Myr, solMass, rad)
-----------------------------------------------------------------------------------------------------------------------------------------------------------
unitsystem(_: NoneType, /) -> unxt._src.unitsystems.builtin.DimensionlessUnitSystem
Dimensionless unit system from None.
Examples
--------
>>> from unxt.unitsystems import unitsystem
>>> unitsystem(None)
DimensionlessUnitSystem()
-----------------------------------------------------------------------------------------------------------------------------------------------------------
unitsystem(*args: Any) -> unxt._src.unitsystems.base.AbstractUnitSystem
Convert a set of arguments to a UnitSystem.
Examples
--------
>>> from unxt.unitsystems import unitsystem
>>> unitsystem("kpc", "Myr", "Msun", "radian")
unitsystem(kpc, Myr, solMass, rad)
-----------------------------------------------------------------------------------------------------------------------------------------------------------
unitsystem(name: str, /) -> unxt._src.unitsystems.base.AbstractUnitSystem
Return unit system from name.
Examples
--------
>>> from unxt.unitsystems import unitsystem
>>> unitsystem("galactic")
unitsystem(kpc, Myr, solMass, rad)
>>> unitsystem("solarsystem")
unitsystem(AU, yr, solMass, rad)
>>> unitsystem("dimensionless")
DimensionlessUnitSystem()
-----------------------------------------------------------------------------------------------------------------------------------------------------------
unitsystem(usys: unxt._src.unitsystems.base.AbstractUnitSystem, *args: Any) -> unxt._src.unitsystems.base.AbstractUnitSystem
Create a unit system from an existing unit system and additional units.
Examples
--------
We can add a new unit definition to an existing unit system:
>>> from unxt.unitsystems import unitsystem
>>> usys = unitsystem("galactic")
>>> unitsystem(usys, "km/s")
LengthTimeMassAngleSpeedUnitSystem(length=Unit("kpc"), time=Unit("Myr"), mass=Unit("solMass"), angle=Unit("rad"), speed=Unit("km / s"))
We can also override the base unit of an existing unit system:
>>> new_usys = unitsystem(usys, "pc")
>>> new_usys
TimeMassAngleLengthUnitSystem(time=Unit("Myr"), mass=Unit("solMass"), angle=Unit("rad"), length=Unit("pc"))
-----------------------------------------------------------------------------------------------------------------------------------------------------------
unitsystem(flag: type[unxt._src.unitsystems.flags.AbstractUSysFlag], *_: Any) -> unxt._src.unitsystems.base.AbstractUnitSystem
Raise an exception since the flag is abstract.
-----------------------------------------------------------------------------------------------------------------------------------------------------------
unitsystem(flag: type[unxt._src.unitsystems.flags.StandardUSysFlag], *args: Any) -> unxt._src.unitsystems.base.AbstractUnitSystem
Create a standard unit system using the inputted units.
Examples
--------
>>> from unxt import unitsystem, unitsystems
>>> unitsystem(unitsystems.StandardUSysFlag, "kpc", "Myr", "Msun")
LengthTimeMassUnitSystem(length=Unit("kpc"), time=Unit("Myr"), mass=Unit("solMass"))
-----------------------------------------------------------------------------------------------------------------------------------------------------------
unitsystem(flag: type[unxt._src.unitsystems.flags.DynamicalSimUSysFlag], *args: Any, G: float | int = 1.0) -> unxt._src.unitsystems.base.AbstractUnitSystem
Make a dynamical unit system.
Examples
--------
>>> from unxt.unitsystems import unitsystem, DynamicalSimUSysFlag
>>> unitsystem(DynamicalSimUSysFlag, "m", "kg")
LengthMassTimeUnitSystem(length=Unit("m"), mass=Unit("kg"), time=Unit("122404 s"))
Class docstring:
A function.
Args:
f (function): Function that is wrapped.
owner (str, optional): Name of the class that owns the function.
warn_redefinition (bool, optional): Throw a warning whenever a method is
redefined. Defaults to `False`.
[44]:
usys = unitsystem("kpc", "Gyr", "solMass", "deg")
usys
[44]:
unitsystem(kpc, Gyr, solMass, deg)
There are also pre-defined unit systems
[45]:
galactic
[45]:
unitsystem(kpc, Myr, solMass, rad)
[46]:
solarsystem
[46]:
unitsystem(AU, yr, solMass, rad)
The unit systems are useful when working with Quantities.
[47]:
print(x)
x.decompose(usys)
BareQuantity(weak_i32[], unit='km')
[48]:
x = 1e4 * Quantity([1, 2, 3], 'lyr')
x.uconvert(usys["length"])
[49]:
Quantity([1, 2, 3], 'lyr / mas').decompose(usys)
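A unit system assigns one unit to each base dimension. Decomposing a quantity rescales its value so that it is expressed in those units. Below is a pure-Python sketch that handles only length and time, with a hypothetical SI scale table and unit system. The parsec and light-year follow the IAU definitions (1 au = 149597870700 m, 1 pc = 648000/π au, light-year based on the Julian year). With these, 10^4 lyr comes out to about 3.066 kpc.

```python
import math

# SI scale factors for a few units (hypothetical table; IAU definitions).
AU = 149_597_870_700.0
SCALE_SI = {
    "kpc": 1e3 * AU * 648_000 / math.pi,
    "lyr": 299_792_458.0 * 365.25 * 86_400,
    "Gyr": 1e9 * 365.25 * 86_400,
    "s": 1.0,
}
# Hypothetical unit system: base dimension -> unit name.
USYS = {"length": "kpc", "time": "Gyr"}


def decompose(value, unit, dims):
    """Express value * unit in USYS; dims maps each dimension to its exponent."""
    target = 1.0
    for dim, power in dims.items():
        target *= SCALE_SI[USYS[dim]] ** power
    return value * SCALE_SI[unit] / target


print(decompose(1e4, "lyr", {"length": 1}))  # roughly 3.066 (kpc)
```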
Astropy Compatibility
unxt.Quantity can be made from an Astropy Quantity:
[50]:
x = Quantity.from_(u.Quantity(1, "kpc"))
x
And it can be converted back to an Astropy Quantity with plum’s convert:
[51]:
from plum import convert
convert(x, u.Quantity)
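plum's convert works by looking up a conversion method registered for a pair of types: the object's type and the target type. The toy, pure-Python registry below sketches that idea with an exact-type lookup. add_conversion, this convert, JaxSideQuantity, and AstropySideQuantity are hypothetical stand-ins and are not plum's implementation.

```python
from dataclasses import dataclass

_CONVERSIONS = {}


def add_conversion(type_from, type_to, fn):
    """Register fn as the way to turn a type_from object into type_to."""
    _CONVERSIONS[(type_from, type_to)] = fn


def convert(obj, type_to):
    """Convert obj to type_to using a registered method (toy, exact-type lookup)."""
    if isinstance(obj, type_to):
        return obj
    try:
        fn = _CONVERSIONS[(type(obj), type_to)]
    except KeyError:
        msg = f"cannot convert {type(obj).__name__} to {type_to.__name__}"
        raise TypeError(msg) from None
    return fn(obj)


@dataclass
class JaxSideQuantity:  # hypothetical
    value: float
    unit: str


@dataclass
class AstropySideQuantity:  # hypothetical
    value: float
    unit: str


# Register both directions so quantities can round-trip.
add_conversion(JaxSideQuantity, AstropySideQuantity, lambda q: AstropySideQuantity(q.value, q.unit))
add_conversion(AstropySideQuantity, JaxSideQuantity, lambda q: JaxSideQuantity(q.value, q.unit))

x = JaxSideQuantity(1.0, "kpc")
y = convert(x, AstropySideQuantity)
print(y)
```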
Differentiation
unxt integrates fully with JAX differentiation.
[52]:
def polynomial(
    x: Float[Quantity["area"], "N"],
    a0: Float[Quantity["volume"], ""],
    a1: Float[Quantity["length"], ""],
    a2: Float[Quantity["1/m"], ""],
) -> Float[Quantity["volume"], "N"]:
    return a0 + a1 * x ** 1 + a2 * x ** 2

polynomial
[52]:
<function __main__.polynomial(x: jaxtyping.Float[Quantity[PhysicalType('area')], 'N'], a0: jaxtyping.Float[Quantity[PhysicalType('volume')], ''], a1: jaxtyping.Float[Quantity[PhysicalType('length')], ''], a2: jaxtyping.Float[Quantity[PhysicalType('wavenumber')], '']) -> jaxtyping.Float[Quantity[PhysicalType('volume')], 'N']>
[53]:
x = Quantity(1.0, "km2")
a0, a1, a2 = Quantity(1, "km3"), Quantity(2, "km"), Quantity(3, "1/km")
[54]:
polynomial(x, a0, a1, a2)
[55]:
grad_fn = quaxify(jax.grad(polynomial))
grad_fn(x, a0, a1, a2)
[56]:
# Alternatively:
from quaxed import grad as qgrad
qgrad(polynomial)(x, a0, a1, a2)
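As a sanity check on what the gradient should be, take p(x) = a0 + a1 x + a2 x^2. Then dp/dx = a1 + 2 a2 x, which has units of volume / area = length. With x = 1 km^2, a1 = 2 km, and a2 = 3 / km, that gives 2 km + 6 km = 8 km, which is the value the gradient cells above are expected to return. The pure-Python sketch below confirms the number with a central finite difference. It works with plain floats in kilometre-based units, and the units are tracked by hand in comments rather than by unxt.

```python
def polynomial(x, a0, a1, a2):
    """Same polynomial as above, on plain floats (values in km-based units)."""
    return a0 + a1 * x + a2 * x**2


def analytic_grad(x, a1, a2):
    return a1 + 2 * a2 * x  # units: km^3 / km^2 = km


def central_difference(f, x, h=1e-6):
    return (f(x + h) - f(x - h)) / (2 * h)


x, a0, a1, a2 = 1.0, 1.0, 2.0, 3.0  # km^2, km^3, km, 1/km
exact = analytic_grad(x, a1, a2)
approx = central_difference(lambda t: polynomial(t, a0, a1, a2), x)
print(exact, round(approx, 6))  # 8.0 8.0
```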