JAX: The Missing Manual (for Research)
This guide serves as a practical, "cheatsheet-style" reference for using JAX. It focuses on the mental models required to use it effectively, specifically addressing common confusion points like dimensions, vmap, and state management.
1. The Core Mental Model
JAX is not just "NumPy with GPU support." It is a compiler for function transformations.
- Pure Functions: JAX assumes your functions are pure: same input, same output; no side effects (no printing*, no global state mutation).
- Static vs. Traced: JAX traces your code once with "abstract tracers" (placeholders) to build a computation graph, then compiles it (XLA). Python control flow (
if,for) happens during tracing. If it depends on data, it vanishes unless you use JAX primitives (jax.lax.cond,jax.lax.scan).
2. Managing State (The State Pattern)
equinox.Module is a PyTree.
What is a PyTree?
A PyTree is just a container of arrays (lists, tuples, dicts, custom classes). JAX functions like jit, grad, and vmap operate on the leaves (the arrays inside) while preserving the structure.
Your State Object
import jax.numpy as jnp
import equinox as eqx
class State(eqx.Module):
# These are the leaves (arrays)
pos: jnp.ndarray
vel: jnp.ndarray
# This is static metadata (ignored by JAX derivatives/batching)
step_count: int = eqx.field(static=True)
# Initialization
state = State(
pos=jnp.zeros(2),
vel=jnp.zeros(2),
step_count=0
)
The "Functional Update" Rule
In JAX, you never mutate objects in place. You always return a new object.
❌ BAD (Mutating - Won't work in JIT):
✅ GOOD (Functional - Returns new PyTree):
import equinox as eqx
@jax.jit
def update(state):
# 'eqx.tree_at' is a helper to update PyTrees functionally
# It says: "Return a copy of 'state' where 'pos' is replaced by 'new_pos'"
new_pos = state.pos + 0.1
state = eqx.tree_at(lambda s: s.pos, state, new_pos)
return state
3. vmap and in_axes
jax.vmap (Vectorizing Map) transforms a function meant for a single sample into a function that handles a batch of samples.
The Golden Rule: vmap pushes a dimension down. If your input has shape (10, 3) and you vmap over it, inside the function, JAX sees shape (3).
Understanding in_axes
in_axes tells JAX which dimension is the batch dimension for each argument.
Imagine a function:
We want to run this on a batch of 100 states.
Scenario A: Batch of States, Shared dt
We have 100 states, but dt is the same for all of them.
# Batch dimension is 0 for 'state', but 'dt' is NOT batched (None).
batched_step = jax.vmap(step, in_axes=(0, None))
# Input shapes:
# batch_states: State(pos=(100, 2), vel=(100, 2))
# dt: scalar
result = batched_step(batch_states, 0.1)
# Result shape: (100, 2)
Scenario B: Batch of States, Batch of dts
Each state has its own unique time delta.
# Batch dimension is 0 for 'state', AND 0 for 'dt'.
batched_step = jax.vmap(step, in_axes=(0, 0))
# Input shapes:
# batch_states: State(pos=(100, 2), vel=(100, 2))
# dts: (100,)
result = batched_step(batch_states, dts)
Scenario C: Complex PyTree Structures
If your State is a complex nested tree, in_axes=0 automatically means "the 0-th dimension of every array leaf inside this object".
If you have a tuple (x, y) and you want x batched but y shared:
# Function: f(pair) -> where pair is (x, y)
# in_axes must match the structure of the arguments.
# in_axes = (0, None) <-- WRONG, acts on the tuple itself.
# in_axes = ((0, None),) <-- CORRECT, acts on the contents of the first arg.
Note: Usually, simpler signatures like f(x, y) are easier to manage than f((x, y)).
Visualizing in_axes vs out_axes
| Parameter | Value | Meaning |
|---|---|---|
in_axes |
0 |
"Slice this input along the 0-th dimension. Pass one slice at a time." |
in_axes |
1 |
"Slice along the 1st dimension (e.g., columns)." |
in_axes |
None |
"Don't slice. Pass this entire object to every instance." |
out_axes |
0 |
"Stack the results along the 0-th dimension." |
4. Randomness (The "Key" Pattern)
JAX PRNG (Pseudo-Random Number Generation) is explicit. There is no global seed.
The Pattern: Always Split, Then Use.
import jax.random as jr
key = jr.PRNGKey(42)
def step(state, key):
# 1. SPLIT the key into as many as you need, plus one for the next step.
step_key, noise_key, sub_key = jr.split(key, 3)
# 2. USE the sub-keys
noise = jr.normal(noise_key, shape=state.pos.shape)
# 3. RETURN the 'step_key' for the next iteration
return state, step_key
Common Pitfall: Reusing the same key results in identical "random" numbers.
5. Control Flow: scan and cond
Loops: scan vs Python Loops
If you write a Python for loop, JAX unrolls it during compilation.
- Short loop (10 steps): Fine. Fast.
- Long loop (10,000 steps): Compile time explosion. Huge binary.
Use jax.lax.scan for long loops. It compiles into a single "loop" instruction in XLA.
Template for scan (The "Carry" Pattern):
def one_step(carry, x):
# 'carry' is your State
# 'x' is input for this specific time step (optional)
new_state = update(carry)
output = compute_metric(new_state)
return new_state, output
# Run for 1000 steps
final_state, history = jax.lax.scan(one_step, init_state, None, length=1000)
Conditionals: cond vs Python if
- Python
if: Evaluated at trace time. The condition must be known before JAX sees any data (e.g., a hyperparameter flag). jax.lax.cond: Evaluated at runtime (on the GPU/TPU). Both branches are traced.
Template for cond:
# jax.lax.cond(pred, true_fun, false_fun, *operands)
def do_reset(state):
return State.reset()
def keep_going(state):
return state
# If 'done' is true (at runtime), call do_reset(state), else keep_going(state)
next_state = jax.lax.cond(done, do_reset, keep_going, state)
Note: Both branches must return PyTrees with the exact same structure and dtype.
6. JIT Compilation (@jax.jit)
JIT compiles your Python function into optimized machine code (XLA).
- Static Arguments: If a function argument changes the shape of arrays or the structure of the graph (e.g., a boolean flag for "is_training"), it must be marked static.
from functools import partial
@partial(jax.jit, static_argnames=("is_training",))
def forward(params, x, is_training: bool):
if is_training: # This Python 'if' runs once during tracing
return dropout(x)
return x
- Don't JIT everything: Usually, you only JIT the top-level update step. JIT-ing tiny functions inside a loop is useless (and often harmful if the loop itself isn't JIT-ed).
7. Cheatsheet: Common Dimension Operations
Assume x.shape = (B, H, W, C) (Batch, Height, Width, Channel).
| Operation | Code | Result Shape | Note |
|---|---|---|---|
| Add Dim | x[None, ...] |
(1, B, H, W, C) |
Prepend dim |
| Add Dim | x[..., None] |
(B, H, W, C, 1) |
Append dim |
| Flatten | x.reshape(B, -1) |
(B, H*W*C) |
Flatten all but batch |
| Transpose | jnp.transpose(x, (0, 3, 1, 2)) |
(B, C, H, W) |
NHWC -> NCHW |
| Einsum | jnp.einsum('bij,bjk->bik', A, B) |
(B, I, K) |
Batch MatMul |
8. Debugging JAX
Since you can't just print() inside a JIT-ed function (it only prints once during tracing), use:
jax.debug.print:jax.disable_jit(): Temporarily turns off JIT so you can use standard Python debuggers (pdb).
9. Useful APIs / Cheatsheet
Creation & Initialization
| Function | Description | Example |
|---|---|---|
jnp.zeros, jnp.ones |
Create arrays of 0s or 1s | jnp.zeros((2, 3)) |
jnp.full |
Create array with constant value | jnp.full((2,), 3.14) |
jnp.eye |
Identity matrix | jnp.eye(3) |
jnp.arange |
Range of numbers | jnp.arange(0, 10, 2) |
jnp.linspace |
Evenly spaced numbers | jnp.linspace(0, 1, 11) |
jr.normal |
Random normal (Gaussian) | jr.normal(key, (2, 2)) |
jr.uniform |
Random uniform [0, 1) | jr.uniform(key, (5,)) |
jr.randint |
Random integers | jr.randint(key, (5,), 0, 10) |
Array Manipulation
| Function | Description | Example |
|---|---|---|
jnp.stack |
Join arrays along new axis | jnp.stack([x, y]) |
jnp.concatenate |
Join along existing axis | jnp.concatenate([x, y], axis=0) |
jnp.where |
Conditional selection | jnp.where(x > 0, x, 0) |
jnp.squeeze |
Remove dims of size 1 | x.squeeze() |
jnp.expand_dims |
Add dim of size 1 | jnp.expand_dims(x, axis=0) |
jnp.ravel |
Flatten to 1D | x.ravel() |
Math & Logic
| Function | Description | Example |
|---|---|---|
jnp.dot / @ |
Matrix multiplication | A @ B |
jnp.sum, jnp.mean |
Sum/Mean (supports axis) | x.mean(axis=-1) |
jnp.max, jnp.argmax |
Max value / Index of max | x.argmax() |
jnp.clip |
Clip values to range | jnp.clip(x, 0, 1) |
jnp.allclose |
Check if arrays are close | jnp.allclose(x, y) |
jax.nn.softmax |
Softmax function | jax.nn.softmax(logits) |
jax.nn.relu |
ReLU activation | jax.nn.relu(x) |
Functional Transformations
| Function | Description |
|---|---|
jax.jit |
Just-In-Time compilation (speed). |
jax.grad |
Computes gradient (single scalar output). |
jax.value_and_grad |
Returns (value, gradient). |
jax.vmap |
Vectorizes a function (adds batch dim). |
jax.lax.scan |
Efficient loop with carry. |
jax.lax.cond |
Efficient conditional branching. |
10. Einops (Tensor Magic)
Einops is not built-in, but it is the idiomatic way to handle tensor reshaping in JAX research code. It is clearer and safer than reshape/transpose.
APIs: rearrange, reduce, repeat
from einops import rearrange, reduce, repeat
# Assume images: (Batch, Height, Width, Channel)
x = jnp.zeros((32, 64, 64, 3))
1. Rearrange (Permute, Flatten, Split)
Replaces transpose, reshape, and squeeze.
# HWC to CHW (PyTorch style)
x_chw = rearrange(x, 'b h w c -> b c h w')
# Flatten images to vectors
x_flat = rearrange(x, 'b h w c -> b (h w c)')
# Shape: (32, 12288)
# Split heads (for Attention)
# Assume x is (Batch, SeqLen, Dim) -> (Batch, SeqLen, Heads, HeadDim)
q = rearrange(q, 'b s (h d) -> b s h d', h=8)
2. Reduce (Pool, Mean, Max)
Replaces mean, sum, max with explicit axes.
# Global Average Pooling (spatial mean)
# "Reduce over h and w, keep b and c"
y = reduce(x, 'b h w c -> b c', 'mean')
# Shape: (32, 3)
# Max over channels
y = reduce(x, 'b h w c -> b h w', 'max')
3. Repeat (Broadcast)
Replaces tile or expand_dims.
# Repeat a class token (1, D) for every item in batch
cls_token = jnp.zeros((1, 512))
cls_tokens = repeat(cls_token, '1 d -> b d', b=32)
# Shape: (32, 512)
11. PyTrees & JAX Tree Functions
A PyTree is a container of leaf elements and/or more PyTrees. Containers include lists, tuples, dicts, None, and custom classes registered with JAX. Leaves are typically JAX arrays, but can be anything JAX treats as an opaque value.
Key Concepts
- Leaf: An element in a PyTree that is not a PyTree itself (usually a JAX array or scalar).
- Structure (Treedef): The shape/container-type of the PyTree, disregarding the actual leaf values.
Common jax.tree Functions with Examples
JAX provides powerful utilities for working with PyTrees under jax.tree (formerly jax.tree_util).
1. jax.tree.map (Mapping over leaves)
Applies a function to every leaf in the PyTree, preserving the structure.
import jax
import jax.numpy as jnp
tree = {'a': jnp.array([1, 2]), 'b': (jnp.array([3]), jnp.array([4, 5]))}
# Double every element in the tree
doubled_tree = jax.tree.map(lambda x: x * 2, tree)
# Result: {'a': Array([2, 4]), 'b': (Array([6]), Array([8, 10]))}
# Add two trees of the exact same structure
tree2 = {'a': jnp.array([10, 20]), 'b': (jnp.array([30]), jnp.array([40, 50]))}
sum_tree = jax.tree.map(lambda x, y: x + y, tree, tree2)
2. jax.tree.leaves (Extracting leaves)
Extracts all leaves from a PyTree into a flat list. Useful for operations like computing a global metric across all parameters.
tree = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0])}
leaves = jax.tree.leaves(tree)
# leaves: [Array([1., 2.]), Array([3.])]
# Compute the L2 norm of all parameters in a model
global_norm = jnp.sqrt(sum(jnp.sum(x ** 2) for x in jax.tree.leaves(tree)))
3. jax.tree.reduce (Reducing leaves)
Reduces all leaves to a single value.
tree = {'a': jnp.array([1, 2]), 'b': jnp.array([3, 4])}
# Sum of all elements across all arrays
total_sum = jax.tree.reduce(lambda acc, x: acc + jnp.sum(x), tree, 0.0)
4. jax.tree.flatten and jax.tree.unflatten
Useful when you need to interface with APIs that only accept flat lists of arrays (e.g., scipy.optimize).
tree = {'weights': jnp.ones((2, 2)), 'bias': jnp.zeros(2)}
# Flatten
flat_leaves, treedef = jax.tree.flatten(tree)
# ... do some flat operations ...
new_flat_leaves = [leaf * 2 for leaf in flat_leaves]
# Unflatten back to the original structure
new_tree = jax.tree.unflatten(treedef, new_flat_leaves)
Custom PyTrees
If you don't use equinox.Module or flax.struct.dataclass, you must register your custom classes with JAX so it knows how to flatten and unflatten them. Otherwise, JAX treats the object itself as a single opaque leaf.
from jax.tree_util import register_pytree_node_class
@register_pytree_node_class
class MyTree:
def __init__(self, a, b):
self.a = a # Leaf
self.b = b # Leaf
def tree_flatten(self):
# Return a tuple of (children, auxiliary_data)
children = (self.a, self.b)
aux_data = None
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
# Reconstruct the object
return cls(*children)
# Now jax.tree.map will go inside MyTree to modify 'a' and 'b'
obj = MyTree(jnp.array([1]), jnp.array([2]))
new_obj = jax.tree.map(lambda x: x * 10, obj)
12. Equinox Tree Operations
Equinox provides additional PyTree utilities that are essential for deep learning, specifically for managing state and model parameters.
1. eqx.tree_at (Functional Update)
The most common way to "mutate" an immutable PyTree. JAX arrays and Equinox modules are immutable; tree_at returns a new copy with the specified leaf modified.
import equinox as eqx
# Replace a leaf
# Syntax: eqx.tree_at(where_fn, pytree, replace_with)
new_state = eqx.tree_at(lambda s: s.weights, state, new_weights)
# Apply a function to a leaf
new_state = eqx.tree_at(lambda s: s.step, state, lambda x: x + 1)
2. eqx.partition and eqx.combine
Used to separate a PyTree into parts based on a filter function. This is critical for functions like jax.jit or jax.grad which require all inputs to be JAX arrays (unless marked as static).
# Partition into JAX-compatible arrays and "static" metadata (strings, bools, etc.)
arrays, static = eqx.partition(model, eqx.is_array)
# Now 'arrays' can be safely passed into JAX transformations
# 'static' is usually handled by closures or marked as static in JIT
# Reconstruct the original model structure
model = eqx.combine(arrays, static)
3. eqx.tree_serialise_leaves and eqx.tree_deserialise_leaves
The standard way to save and load Equinox models. It uses safetensors under the hood if available, or numpy.
# Save model to disk
eqx.tree_serialise_leaves("model.eqx", model)
# Load from disk
# Requires a 'template' model that has the exact same structure as the saved one.
model = eqx.tree_deserialise_leaves("model.eqx", model)
4. Filtering Utilities
Equinox provides predicates to help with partitioning and mapping.
| Predicate | Description |
|---|---|
eqx.is_array |
Matches any JAX array or scalar. |
eqx.is_inexact_array |
Matches floating point JAX arrays (useful for trainable params). |
eqx.is_array_like |
Matches arrays or things that can be converted to arrays. |