
JAX - Autograd and XLA (Accelerated Linear Algebra)


JAX is a framework that combines a NumPy-like API with a powerful system of composable function transformations: `grad` (differentiation), `jit` (compilation), `vmap` (vectorization), and `pmap` (parallelization).

When to Use


  • High-performance scientific simulations requiring GPU/TPU acceleration.
  • Custom machine learning research where PyTorch/TF abstractions are too restrictive.
  • Calculating higher-order derivatives (Hessians, Jacobians) for optimization.
  • Physics-informed machine learning and differentiable simulations.
  • Automatic vectorization of functions (no more manual batching).
  • Running the same code on CPU, GPU, and TPU without changes.

Reference Documentation


Official docs: https://jax.readthedocs.io/
GitHub: https://github.com/google/jax
Search patterns: `jax.numpy`, `jax.jit`, `jax.grad`, `jax.vmap`, `jax.random`

Core Principles


Pure Functions (Immutability)


JAX is built on functional programming. All functions must be pure: they should not have side effects (like modifying a global variable) and must return the same output for the same input. JAX arrays are immutable.
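A minimal sketch of the distinction (the function names here are illustrative, not from the JAX docs):

```python
import jax
import jax.numpy as jnp

scale = 2.0

def impure(x):
    # Reads global state: under jit, `scale` is baked in at trace time,
    # so later changes to the global are silently ignored
    return x * scale

def pure(x, scale):
    # Everything the function reads arrives through its arguments
    return x * scale

result = jax.jit(pure)(jnp.array(3.0), 2.0)  # → 6.0
```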

Just-In-Time Compilation via XLA


JAX uses XLA to compile and optimize Python/NumPy code into efficient machine code for specific hardware.

Manual PRNG Handling


Unlike NumPy, JAX requires explicit management of random state (keys) to ensure reproducibility in parallel environments.
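A small sketch of why key discipline matters: reusing a key reproduces the same draws, while splitting yields an independent sample.

```python
from jax import random

key = random.key(0)
a = random.normal(key, (3,))
b = random.normal(key, (3,))  # same key reused: identical "random" draws

key, subkey = random.split(key)  # split first, then draw
c = random.normal(subkey, (3,))  # a fresh, independent sample
```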

Quick Reference


Installation


```bash
# CPU
pip install jax jaxlib

# GPU: check the documentation for the wheel matching your CUDA version
```

Standard Imports


```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap, random
```

Basic Pattern - Differentiate and JIT


```python
import jax.numpy as jnp
from jax import grad, jit

# 1. Define a pure function
def f(x):
    return jnp.sin(x) + x**2

# 2. Transform: create a gradient function
df_dx = grad(f)

# 3. Transform: compile for speed
f_fast = jit(f)

# 4. Use
val = f_fast(2.0)
slope = df_dx(2.0)
```

Critical Rules


✅ DO


  • Use `jax.numpy` (`jnp`) - it mirrors NumPy but supports JAX transformations.
  • Write pure functions - ensure functions depend only on their inputs and don't modify external state.
  • Handle PRNG keys manually - use `key, subkey = random.split(key)` for every random operation.
  • Use `vmap` for batching - write code for a single sample and let JAX handle the batch dimension.
  • Set `static_argnums` in JIT - if a JIT'ed function takes a non-array argument (like a string or an integer used in a loop), mark it as static.
  • In-place updates via `.at` - since arrays are immutable, use `x = x.at[idx].set(val)`.

❌ DON'T


  • Use in-place updates - `x[idx] = val` will raise an error.
  • Use standard NumPy (`np`) - standard NumPy arrays don't support JAX transformations.
  • Use side effects - don't call `print()` or modify global variables inside JIT-compiled functions.
  • Forget `block_until_ready()` - JAX uses asynchronous execution; for accurate timing, use `result.block_until_ready()`.

Anti-Patterns (NEVER)


```python
import numpy as np
import jax.numpy as jnp
from jax import jit, random

# ❌ BAD: Modifying a global variable inside a function
counter = 0

@jit
def bad_func(x):
    global counter
    counter += 1  # ❌ Side effect! Will only run once, during compilation
    return x * 2

# ❌ BAD: Standard NumPy random (not reproducible/parallel-safe)
val = np.random.randn(5)

# ✅ GOOD: JAX PRNG
key = random.key(42)
val = random.normal(key, (5,))

# ❌ BAD: In-place assignment
# x[0] = 1.0  # raises an error on a JAX array

# ✅ GOOD: Functional update
x = jnp.zeros(5)
x = x.at[0].set(1.0)
```

Function Transformations


Grad (Differentiation)


```python
def loss(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y)**2)

# Gradient of loss with respect to the 1st argument (params)
grads = grad(loss)(params, x, y)

# Higher-order: Hessian
hessian = jax.hessian(loss)(params, x, y)
```

Jit (Just-In-Time Compilation)


```python
@jit
def complex_math(x):
    # This whole block is compiled into one XLA kernel
    y = jnp.exp(x)
    return jnp.sin(y) / jnp.sqrt(x)

# First call: compiles (slow)
# Subsequent calls: super fast
```
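The compile-then-cache behavior can be observed with a rough timing sketch; `block_until_ready()` is needed because dispatch is asynchronous:

```python
import time
import jax.numpy as jnp
from jax import jit

@jit
def complex_math(x):
    y = jnp.exp(x)
    return jnp.sin(y) / jnp.sqrt(x)

x = jnp.linspace(0.1, 1.0, 1_000_000)

t0 = time.perf_counter()
complex_math(x).block_until_ready()  # first call: traces and compiles
t1 = time.perf_counter()
complex_math(x).block_until_ready()  # second call: reuses the cached kernel
t2 = time.perf_counter()
# Typically (t1 - t0) is far larger than (t2 - t1)
```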

Vmap (Automatic Vectorization)


```python
def model(params, x):
    return jnp.dot(params, x)

# model works on 1D x. How to apply it to a 2D batch X?
# in_axes=(None, 0): don't map params, map the 0th axis of x
batch_model = vmap(model, in_axes=(None, 0))
batch_preds = batch_model(params, X_batch)
```
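`vmap` also composes with `grad`; a common pattern is per-example gradients. A sketch with a toy squared-error loss (names are illustrative):

```python
import jax.numpy as jnp
from jax import grad, vmap

def example_loss(w, x, y):
    return (jnp.dot(x, w) - y) ** 2

# Differentiate w.r.t. w, then vectorize over the batch axes of x and y
per_example_grads = vmap(grad(example_loss), in_axes=(None, 0, 0))

w = jnp.ones(3)
X = jnp.ones((4, 3))
y = jnp.zeros(4)
g = per_example_grads(w, X, y)  # one gradient per example: shape (4, 3)
```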

Random Numbers (jax.random)


State Management

```python
key = random.key(0)

# Never reuse the same key!
key, subkey = random.split(key)
data = random.normal(subkey, (10,))

key, subkey = random.split(key)
noise = random.uniform(subkey, (10,))
```

Working with PyTrees


Handling complex data structures (Dicts, Lists, Tuples)


JAX transformations work on "PyTrees" — nested containers of arrays.

```python
params = {'weights': jnp.ones((5,)), 'bias': 0.0}

def predict(p, x):
    return jnp.dot(x, p['weights']) + p['bias']

# grad and jit handle the dictionary automatically
grads = grad(predict)(params, x)
```
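A common companion pattern, sketched here with `jax.tree_util.tree_map`, applies an update leaf-by-leaf across matching PyTrees (the gradient values below are made up for illustration):

```python
import jax
import jax.numpy as jnp

params = {'weights': jnp.ones((5,)), 'bias': 0.0}
grads = {'weights': jnp.full((5,), 0.1), 'bias': 0.5}

# SGD step applied to every leaf of the parameter tree
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
```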

Practical Workflows


1. Differentiable Physics: Solving a Simple ODE


```python
def system_dynamics(state, t):
    # Simple harmonic oscillator
    x, v = state
    dxdt = v
    dvdt = -0.5 * x
    return jnp.array([dxdt, dvdt])

def loss_fn(initial_state, target_x):
    # Simulate for 10 steps using simple Euler integration
    state = initial_state
    dt = 0.1
    for i in range(10):
        state += system_dynamics(state, i*dt) * dt
    return (state[0] - target_x)**2

# We can take the gradient of the simulation with respect to the initial state!
optimize_initial_state = grad(loss_fn)
```
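A gradient function like this can then drive an optimizer. Here is a toy gradient-descent sketch on a simpler quadratic loss (names are illustrative, not part of the workflow above):

```python
import jax.numpy as jnp
from jax import grad, jit

def toy_loss(theta, target):
    return (theta - target) ** 2

step = jit(grad(toy_loss))

theta = 0.0
for _ in range(200):
    theta = theta - 0.1 * step(theta, 3.0)
# theta has converged close to the target 3.0
```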

2. Parameter Sweep with vmap


```python
def simulation(param):
    # Some complex computation
    return jnp.sum(jnp.linspace(0, param, 100))

# Parallelize the simulation over a range of parameters
params = jnp.linspace(1, 10, 100)
results = vmap(simulation)(params)
```

3. Distributed Training with pmap


```python
# pmap replicates the function across multiple GPUs
# (assuming 8 GPUs are available)
x = jnp.zeros((8, 1024))
results = pmap(jnp.sin)(x)
```

Performance Optimization


Static Arguments in JIT


If your function uses a loop whose length depends on an input value, that value must be static.

```python
from functools import partial

@partial(jit, static_argnums=(1,))
def power_loop(x, n):
    for i in range(n):
        x = x * x
    return x
```

Avoid Python Control Flow


Prefer JAX control flow (`cond`, `while_loop`, `fori_loop`) for better XLA optimization.

```python
from jax.lax import cond

def safe_divide(x, y):
    return cond(y == 0, lambda _: 0.0, lambda _: x / y, operand=None)
```
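A `fori_loop` sketch for comparison; the body function receives the loop index and a carried value, and returns the new carry:

```python
from jax.lax import fori_loop

def sum_of_squares(n):
    # body(i, carry) -> new carry; runs for i in [0, n)
    return fori_loop(0, n, lambda i, acc: acc + i * i, 0)

total = sum_of_squares(5)  # 0 + 1 + 4 + 9 + 16 = 30
```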

Common Pitfalls and Solutions


The "Tracer" Error


Inside JIT, JAX doesn't see actual numbers; it sees "Tracers".

```python
# ❌ Problem:
@jit
def func(x):
    if x > 0:  # Error! JAX doesn't know x's value during compilation
        return x

# ✅ Solution: use jax.lax.cond, or mark x as a static argument
```
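One possible fix, sketched with `lax.cond` (the branch functions receive the operand; the example function is illustrative):

```python
from jax import jit
from jax.lax import cond

@jit
def func(x):
    # Data-dependent branching expressed as lax.cond instead of Python `if`
    return cond(x > 0, lambda v: v, lambda v: -v, x)

result = func(-2.0)  # → 2.0
```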

NaN Gradients


If your function has singularities (like `sqrt(0)`), gradients will be NaN or infinite.

```python
# ✅ Solution: add a small epsilon
def safe_sqrt(x):
    return jnp.sqrt(x + 1e-8)
```
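A quick self-contained check of the difference: the gradient of the raw `sqrt` at 0 is non-finite, while the epsilon-padded version stays finite.

```python
import jax.numpy as jnp
from jax import grad

raw = grad(jnp.sqrt)(0.0)                       # non-finite: 0.5 / sqrt(0)
safe = grad(lambda x: jnp.sqrt(x + 1e-8))(0.0)  # finite
```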

GPU Memory Pre-allocation


JAX pre-allocates most of the GPU memory at startup by default (90% in older releases, 75% in recent ones).

```python
# ✅ Solution: set this environment variable before importing jax
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
```

Best Practices


  1. Always use pure functions - no side effects, deterministic outputs
  2. Manage PRNG keys explicitly - split keys for every random operation
  3. Use JIT for hot loops - compile functions that are called repeatedly
  4. Leverage vmap for batching - write single-sample code, let JAX handle batches
  5. Mark static arguments - use `static_argnums` for non-array parameters in JIT
  6. Use functional updates - always use `.at` methods for array modifications
  7. Profile before optimizing - use `jax.profiler` to find bottlenecks
  8. Handle device placement - use `jax.device_put()` to control where arrays live
  9. Test on CPU first - debug on CPU, then scale to GPU/TPU
  10. Understand compilation costs - the first JIT call is slow, subsequent calls are fast
JAX is the ultimate playground for differentiable science. By treating math as a series of functional transformations, it unlocks speeds and complexities that were previously impossible in Python.