Distributed Computing in JAX#
Author: Philip Mocz (CCA, 2025), Flatiron Institute
See also: examples of parallel JAX code
Sharding#
JAX supports distributed computing through sharding: arrays/tensors are split across devices (e.g., GPUs), and the communication needed between devices is, in most cases, inserted automatically.
[ ]:
import os
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import Mesh, PartitionSpec, NamedSharding
[ ]:
# Set up distributed computing in JAX
USE_CPU = True

if USE_CPU:
    # create virtual devices on a CPU for testing/debugging
    flags = os.environ.get("XLA_FLAGS", "")
    flags += " --xla_force_host_platform_device_count=8"  # change to, e.g., 8 for testing sharding virtually
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    os.environ["XLA_FLAGS"] = flags
    print("Using CPU mode with 8 virtual devices")
else:
    # initialize distributed computing
    jax.distributed.initialize()
    if jax.process_index() == 0:
        print("Using GPU distributed mode")
Using CPU mode with 8 virtual devices
[ ]:
# Print some info about environment/devices, and set up sharding
# Create mesh and sharding for distributed computation
n_devices = jax.device_count()
devices = mesh_utils.create_device_mesh((n_devices,))
mesh = Mesh(devices, axis_names=("gpus",))
sharding = NamedSharding(mesh, PartitionSpec(None, "gpus"))

if jax.process_index() == 0:
    for env_var in [
        "SLURM_JOB_ID",
        "SLURM_NTASKS",
        "SLURM_NODELIST",
        "SLURM_STEP_NODELIST",
        "SLURM_STEP_GPUS",
        "SLURM_GPUS",
    ]:
        print(f"{env_var}: {os.getenv(env_var, '')}")
    print("Total number of processes: ", jax.process_count())
    print("Total number of devices: ", jax.device_count())
    print("List of devices: ", jax.devices())
    print("Number of devices on this process: ", jax.local_device_count())
SLURM_JOB_ID:
SLURM_NTASKS:
SLURM_NODELIST:
SLURM_STEP_NODELIST:
SLURM_STEP_GPUS:
SLURM_GPUS:
Total number of processes: 1
Total number of devices: 8
List of devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
Number of devices on this process: 8
[ ]:
# Let's shard an array:
n = 8 # array size
xlin = jnp.linspace(0, n-1, n)
xx, yy = jnp.meshgrid(xlin, xlin, indexing="ij")
print("regular array:")
jax.debug.visualize_array_sharding(xx)
print(xx)
print("")
xx_sharded = jax.lax.with_sharding_constraint(xx, sharding) # (you can also use jax.device_put())
print("sharded array:")
jax.debug.visualize_array_sharding(xx_sharded)
print(xx_sharded)
regular array:
CPU 0
[[0. 0. 0. 0. 0. 0. 0. 0.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[2. 2. 2. 2. 2. 2. 2. 2.]
[3. 3. 3. 3. 3. 3. 3. 3.]
[4. 4. 4. 4. 4. 4. 4. 4.]
[5. 5. 5. 5. 5. 5. 5. 5.]
[6. 6. 6. 6. 6. 6. 6. 6.]
[7. 7. 7. 7. 7. 7. 7. 7.]]
sharded array:
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
[[0. 0. 0. 0. 0. 0. 0. 0.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[2. 2. 2. 2. 2. 2. 2. 2.]
[3. 3. 3. 3. 3. 3. 3. 3.]
[4. 4. 4. 4. 4. 4. 4. 4.]
[5. 5. 5. 5. 5. 5. 5. 5.]
[6. 6. 6. 6. 6. 6. 6. 6.]
[7. 7. 7. 7. 7. 7. 7. 7.]]
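As the comment in the cell above notes, jax.device_put is an alternative to jax.lax.with_sharding_constraint for placing an array with a given sharding. Below is a minimal sketch of that variant, reusing xx and sharding from above (the name xx_sharded_alt is illustrative):
[ ]:
# Alternative placement with jax.device_put instead of with_sharding_constraint
xx_sharded_alt = jax.device_put(xx, sharding)
jax.debug.visualize_array_sharding(xx_sharded_alt)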
[ ]:
# What about compositions?
# JAX will automatically determine the sharding of the results
xx_sharded_sq = xx_sharded**2
print("xx_sharded**2:")
print(xx_sharded_sq.shape)
jax.debug.visualize_array_sharding(xx_sharded_sq)
print("")
xx_sum = xx + xx_sharded
print("xx + xx_sharded:")
print(xx_sum.shape)
jax.debug.visualize_array_sharding(xx_sum)
print("")
xx_mean_0 = jnp.mean(xx, axis=0)
print("jnp.mean(xx_sharded,axis=0)")
print(xx_mean_0.shape)
jax.debug.visualize_array_sharding(xx_mean_0)
print("")
xx_mean_1 = jnp.mean(xx, axis=1)
print("jnp.mean(xx_sharded,axis=1)")
print(xx_mean_1.shape)
jax.debug.visualize_array_sharding(xx_mean_1)
xx_sharded**2:
(8, 8)
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
xx + xx_sharded:
(8, 8)
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
jnp.mean(xx, axis=0)
(8,)
CPU 0
jnp.mean(xx, axis=1)
(8,)
CPU 0
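Besides jax.debug.visualize_array_sharding, every jax.Array exposes a .sharding attribute, which is convenient for checking how JAX propagated sharding through these compositions. A quick sketch using the arrays defined above:
[ ]:
# Inspect the (propagated) shardings programmatically
print(xx_sharded.sharding)     # the NamedSharding we applied explicitly
print(xx_sharded_sq.sharding)  # sharding propagated through the elementwise op
print(xx_mean_1.sharding)      # the unsharded input's reduction lives on a single device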
[ ]:
xx_roll = jnp.roll(xx_sharded, (0,1))
print("jnp.roll(xx_sharded, (0,1))")
print(xx_roll.shape)
jax.debug.visualize_array_sharding(xx_roll)
jnp.roll(xx_sharded, (0,1))
(8, 8)
CPU 0,1,2,3,4,5,6,7
[ ]:
# OK, but if I build a global array and then shard, can't I run out of memory?
# Yes! Here is a solution to avoid that:
# Make a distributed meshgrid function

def xmeshgrid(xlin):
    xx, yy = jnp.meshgrid(xlin, xlin, indexing="ij")
    return xx, yy

xmeshgrid_jit = jax.jit(xmeshgrid, in_shardings=None, out_shardings=sharding)

xx, yy = xmeshgrid_jit(xlin)
print("sharded array:")
jax.debug.visualize_array_sharding(xx)
sharded array:
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
[ ]:
# Let's make the result of the roll that we saw earlier sharded

def xroll(xx):
    return jnp.roll(xx, (0, 1))

xroll_jit = jax.jit(xroll, in_shardings=sharding, out_shardings=sharding)

xx_roll = xroll_jit(xx_sharded)
print("xroll_jit(xx_sharded)")
print(xx_roll.shape)
jax.debug.visualize_array_sharding(xx_roll)
xroll_jit(xx_sharded)
(8, 8)
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
Sharded FFT example#
[ ]:
import jax.numpy.fft as jfft
from typing import Callable

# What if you want to calculate a 3D FFT in a distributed manner? Can we simply apply the above wrapping to jfft.fftn?
# We could, but JAX is not smart enough to optimize the communication, and at some intermediate step it will end up creating a global array on each device and OOM.
# So we do the following instead:
# decompose the 3D FFT into a 1D FFT in Z and a 2D FFT in XY

def fft_partitioner(
    fft_func: Callable[[jax.Array], jax.Array],
    partition_spec: PartitionSpec,
):
    @custom_partitioning
    def func(x):
        return fft_func(x)

    def supported_sharding(sharding, shape):
        return NamedSharding(sharding.mesh, partition_spec)

    def partition(mesh, arg_shapes, result_shape):
        # result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
        arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
        return (
            mesh,
            fft_func,
            supported_sharding(arg_shardings[0], arg_shapes[0]),
            (supported_sharding(arg_shardings[0], arg_shapes[0]),),
        )

    def infer_sharding_from_operands(mesh, arg_shapes, shape):
        arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
        return supported_sharding(arg_shardings[0], arg_shapes[0])

    func.def_partition(
        infer_sharding_from_operands=infer_sharding_from_operands,
        partition=partition,
        sharding_rule="i j k -> i j k",
    )
    return func

def _fft_XY(x):
    return jfft.fftn(x, axes=[0, 1])

def _fft_Z(x):
    return jfft.fft(x, axis=2)

def _ifft_XY(x):
    return jfft.ifftn(x, axes=[0, 1])

def _ifft_Z(x):
    return jfft.ifft(x, axis=2)

# fft_XY/ifft_XY: operate on 2D slices (axes [0, 1]), so the array must be sharded only in Z
# fft_Z/ifft_Z: operate on 1D slices (axis 2), so the array must not be sharded in Z
fft_XY = fft_partitioner(_fft_XY, PartitionSpec(None, None, "gpus"))
fft_Z = fft_partitioner(_fft_Z, PartitionSpec(None, "gpus"))
ifft_XY = fft_partitioner(_ifft_XY, PartitionSpec(None, None, "gpus"))
ifft_Z = fft_partitioner(_ifft_Z, PartitionSpec(None, "gpus"))

def xfft3d(x):
    x = fft_Z(x)
    x = fft_XY(x)
    return x

def ixfft3d(x):
    x = ifft_XY(x)
    x = ifft_Z(x)
    return x

# set up xfft3d / ixfft3d (distributed versions of jfft.fftn / jfft.ifftn)
xfft3d_jit = jax.jit(
    xfft3d,
    in_shardings=sharding,
    out_shardings=sharding,
)
ixfft3d_jit = jax.jit(
    ixfft3d,
    in_shardings=sharding,
    out_shardings=sharding,
)
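The cell above only builds the distributed FFT; the sketch below round-trips a sharded 3D array through it. The array name field and the size n = 8 are illustrative (not from the original); the input is placed with the same sharding used for in_shardings, i.e., split along the second axis.
[ ]:
# Round-trip test of the distributed 3D FFT (illustrative sketch)
n = 8
field = jax.device_put(
    jnp.arange(n**3, dtype=jnp.float32).reshape(n, n, n) / n**3, sharding
)  # 3D array sharded along axis 1, matching PartitionSpec(None, "gpus")
field_k = xfft3d_jit(field)        # forward: 1D FFT in Z, then 2D FFT in XY
field_back = ixfft3d_jit(field_k)  # inverse
print(jnp.allclose(field_back.real, field, atol=1e-4))  # should print True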
[ ]:
# An even more optimized option is the jaxdecomp library
# (it avoids a transpose and transpose-back in k-space, and
# implements both slab- and pencil-based FFT decompositions)
!pip install jaxdecomp
import jaxdecomp as jd

my_fftn = jd.fft.pfft3d
my_ifftn = jd.fft.pifft3d
Collecting jaxdecomp
Downloading jaxdecomp-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10.0 kB)
Collecting jaxtyping>=0.2.0 (from jaxdecomp)
Downloading jaxtyping-0.3.3-py3-none-any.whl.metadata (7.8 kB)
Requirement already satisfied: jax>=0.4.35 in /usr/local/lib/python3.12/dist-packages (from jaxdecomp) (0.5.3)
Requirement already satisfied: jaxlib<=0.5.3,>=0.5.3 in /usr/local/lib/python3.12/dist-packages (from jax>=0.4.35->jaxdecomp) (0.5.3)
Requirement already satisfied: ml_dtypes>=0.4.0 in /usr/local/lib/python3.12/dist-packages (from jax>=0.4.35->jaxdecomp) (0.5.3)
Requirement already satisfied: numpy>=1.25 in /usr/local/lib/python3.12/dist-packages (from jax>=0.4.35->jaxdecomp) (2.0.2)
Requirement already satisfied: opt_einsum in /usr/local/lib/python3.12/dist-packages (from jax>=0.4.35->jaxdecomp) (3.4.0)
Requirement already satisfied: scipy>=1.11.1 in /usr/local/lib/python3.12/dist-packages (from jax>=0.4.35->jaxdecomp) (1.16.2)
Collecting wadler-lindig>=0.1.3 (from jaxtyping>=0.2.0->jaxdecomp)
Downloading wadler_lindig-0.1.7-py3-none-any.whl.metadata (17 kB)
Downloading jaxdecomp-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (170 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 170.4/170.4 kB 2.8 MB/s eta 0:00:00
Downloading jaxtyping-0.3.3-py3-none-any.whl (55 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 55.9/55.9 kB 3.6 MB/s eta 0:00:00
Downloading wadler_lindig-0.1.7-py3-none-any.whl (20 kB)
Installing collected packages: wadler-lindig, jaxtyping, jaxdecomp
Successfully installed jaxdecomp-0.2.8 jaxtyping-0.3.3 wadler-lindig-0.1.7
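A minimal usage sketch for jaxdecomp, reusing the sharded field array from the earlier FFT sketch (jaxdecomp has its own requirements on how the global array is distributed across the mesh, so consult its documentation before reusing this literally). Since the transpose back is skipped, pfft3d returns k-space in a transposed layout, which pifft3d undoes:
[ ]:
# Distributed 3D FFT round trip with jaxdecomp (illustrative sketch)
field_k = my_fftn(field)             # forward FFT; k-space comes back transposed
field_back = my_ifftn(field_k).real  # inverse FFT returns to the original layout
print(jnp.allclose(field_back, field, atol=1e-4))  # should print True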
Notes#
Concepts like pjit and pmap are gone as of JAX 0.8.0: https://docs.jax.dev/en/latest/migrate_pmap.html
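The jit + sharding pattern used throughout this notebook covers the common data-parallel use case that pmap served. As a final hedged sketch (the function f and the names below are illustrative, not from the original), a pmap-style "apply this function to every device's shard" computation now looks like:
[ ]:
# pmap-style data parallelism expressed with jit + sharding
def f(x):
    return jnp.sin(x) ** 2  # purely elementwise, so each shard is processed locally

f_distributed = jax.jit(f, in_shardings=sharding, out_shardings=sharding)
y = f_distributed(xx_sharded)  # no cross-device communication needed for elementwise ops
jax.debug.visualize_array_sharding(y)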