Distributed Computing in JAX#

Author: Philip Mocz (CCA, 2025)

Flatiron Institute


Sharding#

JAX enables distributed computing through sharding: arrays/tensors are split across devices (e.g., GPUs), and in most cases the communication between devices needed to operate on them is inserted automatically.

[ ]:
import os
import jax
import jax.numpy as jnp

from jax.experimental import mesh_utils
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import Mesh, PartitionSpec, NamedSharding
[ ]:
# Set up distributed computing in JAX

USE_CPU = True

if USE_CPU:
  # create virtual devices on a CPU for testing/debugging
  flags = os.environ.get("XLA_FLAGS", "")
  flags += " --xla_force_host_platform_device_count=8"  # change to, e.g., 8 for testing sharding virtually
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
  os.environ["XLA_FLAGS"] = flags
  print("Using CPU mode with 8 virtual devices")
else:
  # initialize distributed computing
  jax.distributed.initialize()
  if jax.process_index() == 0:
    print("Using GPU distributed mode")
Using CPU mode with 8 virtual devices
[ ]:
# Print some info about environment/devices, and set up sharding

# Create mesh and sharding for distributed computation
n_devices = jax.device_count()
devices = mesh_utils.create_device_mesh((n_devices,))
mesh = Mesh(devices, axis_names=("gpus",))
sharding = NamedSharding(mesh, PartitionSpec(None, "gpus"))

if jax.process_index() == 0:
  for env_var in [
    "SLURM_JOB_ID",
    "SLURM_NTASKS",
    "SLURM_NODELIST",
    "SLURM_STEP_NODELIST",
    "SLURM_STEP_GPUS",
    "SLURM_GPUS",
  ]:
    print(f"{env_var}: {os.getenv(env_var, '')}")
  print("Total number of processes: ", jax.process_count())
  print("Total number of devices: ", jax.device_count())
  print("List of devices: ", jax.devices())
  print("Number of devices on this process: ", jax.local_device_count())
SLURM_JOB_ID:
SLURM_NTASKS:
SLURM_NODELIST:
SLURM_STEP_NODELIST:
SLURM_STEP_GPUS:
SLURM_GPUS:
Total number of processes:  1
Total number of devices:  8
List of devices:  [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
Number of devices on this process:  8
[ ]:
# Let's shard an array:

n = 8  # array size
xlin = jnp.linspace(0, n-1, n)

xx, yy = jnp.meshgrid(xlin, xlin, indexing="ij")
print("regular array:")
jax.debug.visualize_array_sharding(xx)
print(xx)
print("")

xx_sharded = jax.lax.with_sharding_constraint(xx, sharding)  # (you can also use jax.device_put())
print("sharded array:")
jax.debug.visualize_array_sharding(xx_sharded)
print(xx_sharded)
regular array:
          CPU 0
[[0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 1. 1.]
 [2. 2. 2. 2. 2. 2. 2. 2.]
 [3. 3. 3. 3. 3. 3. 3. 3.]
 [4. 4. 4. 4. 4. 4. 4. 4.]
 [5. 5. 5. 5. 5. 5. 5. 5.]
 [6. 6. 6. 6. 6. 6. 6. 6.]
 [7. 7. 7. 7. 7. 7. 7. 7.]]

sharded array:
  CPU 0    CPU 1    CPU 2    CPU 3    CPU 4    CPU 5    CPU 6    CPU 7
[[0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 1. 1.]
 [2. 2. 2. 2. 2. 2. 2. 2.]
 [3. 3. 3. 3. 3. 3. 3. 3.]
 [4. 4. 4. 4. 4. 4. 4. 4.]
 [5. 5. 5. 5. 5. 5. 5. 5.]
 [6. 6. 6. 6. 6. 6. 6. 6.]
 [7. 7. 7. 7. 7. 7. 7. 7.]]
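As the cell's comment notes, jax.device_put is an equivalent way to place an existing array according to a sharding; a minimal sketch reusing the xx and sharding defined above:

xx_put = jax.device_put(xx, sharding)       # place xx according to `sharding`
jax.debug.visualize_array_sharding(xx_put)  # same column layout as xx_sharded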
[ ]:
# What about compositions?
# JAX will automatically determine the sharding of the results

xx_sharded_sq = xx_sharded**2
print("xx_sharded**2:")
print(xx_sharded_sq.shape)
jax.debug.visualize_array_sharding(xx_sharded_sq)
print("")

xx_sum = xx + xx_sharded
print("xx + xx_sharded:")
print(xx_sum.shape)
jax.debug.visualize_array_sharding(xx_sum)
print("")

xx_mean_0 = jnp.mean(xx, axis=0)
print("jnp.mean(xx_sharded,axis=0)")
print(xx_mean_0.shape)
jax.debug.visualize_array_sharding(xx_mean_0)
print("")

xx_mean_1 = jnp.mean(xx, axis=1)
print("jnp.mean(xx_sharded,axis=1)")
print(xx_mean_1.shape)
jax.debug.visualize_array_sharding(xx_mean_1)

xx_sharded**2:
(8, 8)
  CPU 0    CPU 1    CPU 2    CPU 3    CPU 4    CPU 5    CPU 6    CPU 7

xx + xx_sharded:
(8, 8)
  CPU 0    CPU 1    CPU 2    CPU 3    CPU 4    CPU 5    CPU 6    CPU 7

jnp.mean(xx, axis=0)
(8,)
  CPU 0  
         

jnp.mean(xx, axis=1)
(8,)
  CPU 0  
         
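Besides the visualizer, every jax.Array exposes its placement via a .sharding attribute, which is useful for checking what JAX decided; a small sketch using the results above:

print(xx_sharded_sq.sharding)  # stays sharded along the "gpus" mesh axis
print(xx_sum.sharding)         # adopts the sharding of the sharded operand
print(xx_mean_0.sharding)      # the mean of the unsharded xx lives on a single device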
[ ]:
# Note: some operations, like this roll, return a result that is replicated
# across all devices rather than kept sharded:
xx_roll = jnp.roll(xx_sharded, (0,1))
print("jnp.roll(xx_sharded, (0,1))")
print(xx_roll.shape)
jax.debug.visualize_array_sharding(xx_roll)
jnp.roll(xx_sharded, (0,1))
(8, 8)
   CPU 0,1,2,3,4,5,6,7
[ ]:
# OK, but if I build a global array and then shard, can't I run out of memory?

# Yes! Here is a solution to avoid that:

# Make a distributed meshgrid function
def xmeshgrid(xlin):
  xx, yy = jnp.meshgrid(xlin, xlin, indexing="ij")
  return xx, yy


xmeshgrid_jit = jax.jit(xmeshgrid, in_shardings=None, out_shardings=sharding)

xx, yy = xmeshgrid_jit(xlin)

print("sharded array:")
jax.debug.visualize_array_sharding(xx)

sharded array:
  CPU 0    CPU 1    CPU 2    CPU 3    CPU 4    CPU 5    CPU 6    CPU 7
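The same trick works for any function that builds a large array: jit it with out_shardings and the output is created already sharded, so no single device ever holds the full array. A minimal sketch, assuming the n and sharding defined above:

make_zeros = jax.jit(lambda: jnp.zeros((n, n)), out_shardings=sharding)  # construct directly as a sharded array
zz = make_zeros()
jax.debug.visualize_array_sharding(zz)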
[ ]:

# Let's make the result of the roll that we saw earlier sharded

def xroll(xx):
  return jnp.roll(xx, (0, 1))


xroll_jit = jax.jit(xroll, in_shardings=sharding, out_shardings=sharding)

xx_roll = xroll_jit(xx_sharded)
print("xroll_jit(xx_sharded)")
print(xx_roll.shape)
jax.debug.visualize_array_sharding(xx_roll)
xroll_jit(xx_sharded)
(8, 8)
  CPU 0    CPU 1    CPU 2    CPU 3    CPU 4    CPU 5    CPU 6    CPU 7

Sharded FFT example#

[ ]:
import jax.numpy.fft as jfft
from typing import Callable

# What if you want to calculate a 3D FFT in a distributed manner? Can we simply apply the above jit wrapping to jfft.fftn?
# We could, but JAX does not optimize the communication: at some intermediate step it gathers the global array onto every device and runs out of memory.
# So we do the following instead:

# decompose the 3D FFT into a 1D FFT in Z and a 2D FFT in XY

def fft_partitioner(
  fft_func: Callable[[jax.Array], jax.Array],
  partition_spec: PartitionSpec,
):
  @custom_partitioning
  def func(x):
    return fft_func(x)

  def supported_sharding(sharding, shape):
    return NamedSharding(sharding.mesh, partition_spec)

  def partition(mesh, arg_shapes, result_shape):
    # result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    return (
      mesh,
      fft_func,
      supported_sharding(arg_shardings[0], arg_shapes[0]),
      (supported_sharding(arg_shardings[0], arg_shapes[0]),),
    )

  def infer_sharding_from_operands(mesh, arg_shapes, shape):
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    return supported_sharding(arg_shardings[0], arg_shapes[0])

  func.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition,
    sharding_rule="i j k -> i j k",
  )
  return func


def _fft_XY(x):
  return jfft.fftn(x, axes=[0, 1])


def _fft_Z(x):
  return jfft.fft(x, axis=2)


def _ifft_XY(x):
  return jfft.ifftn(x, axes=[0, 1])


def _ifft_Z(x):
  return jfft.ifft(x, axis=2)


# fft_XY/ifft_XY: operate on 2D slices (axes [0,1])
# fft_Z/ifft_Z: operate on 1D slices (axis 2)
fft_XY = fft_partitioner(_fft_XY, PartitionSpec(None, None, "gpus"))
fft_Z = fft_partitioner(_fft_Z, PartitionSpec(None, "gpus"))
ifft_XY = fft_partitioner(_ifft_XY, PartitionSpec(None, None, "gpus"))
ifft_Z = fft_partitioner(_ifft_Z, PartitionSpec(None, "gpus"))


def xfft3d(x):
  x = fft_Z(x)
  x = fft_XY(x)
  return x


def ixfft3d(x):
  x = ifft_XY(x)
  x = ifft_Z(x)
  return x


# set up xfft3d (distributed version of jfft.fftn)
xfft3d_jit = jax.jit(
  xfft3d,
  in_shardings=sharding,
  out_shardings=sharding,
)

ixfft3d_jit = jax.jit(
  ixfft3d,
  in_shardings=sharding,
  out_shardings=sharding,
)
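The jitted transforms above are defined but not exercised in this notebook; here is a usage sketch (the grid size nz and the test field are made up for illustration, and reuse the sharding defined earlier):

# round-trip a sharded 3D field through the distributed FFT
nz = 8  # hypothetical per-axis grid size
make_field = jax.jit(
  lambda: jnp.arange(nz**3, dtype=jnp.float32).reshape(nz, nz, nz),
  out_shardings=sharding,
)
field = make_field()

field_hat = xfft3d_jit(field)
field_rec = ixfft3d_jit(field_hat).real

print(field_hat.shape, field_hat.sharding)
print("round-trip error:", jnp.max(jnp.abs(field_rec - field)))  # should be small (float32 round-off)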
[ ]:
# An even more optimized option is the jaxdecomp library
# (it avoids a transpose & transpose back in k-space)
# and implements both slab- and pencil-based FFT decompositions.

!pip install jaxdecomp

import jaxdecomp as jd

my_fftn = jd.fft.pfft3d
my_ifftn = jd.fft.pifft3d



Collecting jaxdecomp
  Downloading jaxdecomp-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10.0 kB)
Collecting jaxtyping>=0.2.0 (from jaxdecomp)
  Downloading jaxtyping-0.3.3-py3-none-any.whl.metadata (7.8 kB)
Requirement already satisfied: jax>=0.4.35 in /usr/local/lib/python3.12/dist-packages (from jaxdecomp) (0.5.3)
Requirement already satisfied: jaxlib<=0.5.3,>=0.5.3 in /usr/local/lib/python3.12/dist-packages (from jax>=0.4.35->jaxdecomp) (0.5.3)
Requirement already satisfied: ml_dtypes>=0.4.0 in /usr/local/lib/python3.12/dist-packages (from jax>=0.4.35->jaxdecomp) (0.5.3)
Requirement already satisfied: numpy>=1.25 in /usr/local/lib/python3.12/dist-packages (from jax>=0.4.35->jaxdecomp) (2.0.2)
Requirement already satisfied: opt_einsum in /usr/local/lib/python3.12/dist-packages (from jax>=0.4.35->jaxdecomp) (3.4.0)
Requirement already satisfied: scipy>=1.11.1 in /usr/local/lib/python3.12/dist-packages (from jax>=0.4.35->jaxdecomp) (1.16.2)
Collecting wadler-lindig>=0.1.3 (from jaxtyping>=0.2.0->jaxdecomp)
  Downloading wadler_lindig-0.1.7-py3-none-any.whl.metadata (17 kB)
Downloading jaxdecomp-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (170 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 170.4/170.4 kB 2.8 MB/s eta 0:00:00
Downloading jaxtyping-0.3.3-py3-none-any.whl (55 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 55.9/55.9 kB 3.6 MB/s eta 0:00:00
Downloading wadler_lindig-0.1.7-py3-none-any.whl (20 kB)
Installing collected packages: wadler-lindig, jaxtyping, jaxdecomp
Successfully installed jaxdecomp-0.2.8 jaxtyping-0.3.3 wadler-lindig-0.1.7
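A rough usage sketch, not run here (it reuses the hypothetical field from the FFT sketch above; check the jaxdecomp documentation for the mesh setup it expects):

field_hat = my_fftn(field)            # distributed 3D FFT; output is left in a transposed k-space layout
field_rec = my_ifftn(field_hat).real  # inverse transform back to the original layout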

Notes:#

shard_map#

[ ]:
# shard_map is a single-program multiple-data (SPMD) multi-device parallelism API that maps a function over shards of data

# See: https://docs.jax.dev/en/latest/notebooks/shard_map.html