JAX - Autograd and XLA (Accelerated Linear Algebra)
JAX is a framework that combines a NumPy-like API with a powerful system of composable function transformations: Grad (differentiation), Jit (compilation), Vmap (vectorization), and Pmap (parallelization).
When to Use
- High-performance scientific simulations requiring GPU/TPU acceleration.
- Custom machine learning research where PyTorch/TF abstractions are too restrictive.
- Calculating higher-order derivatives (Hessians, Jacobians) for optimization.
- Physics-informed machine learning and differentiable simulations.
- Automatic vectorization of functions (no more manual batching).
- Running the same code on CPU, GPU, and TPU without changes.
Reference Documentation
Official docs: https://jax.readthedocs.io/
GitHub: https://github.com/google/jax
Search patterns: `jax.numpy`, `jax.jit`, `jax.grad`, `jax.vmap`, `jax.random`
Core Principles
Pure Functions (Immutability)
JAX is built on functional programming. All functions must be pure: they should not have side effects (like modifying a global variable) and must return the same output for the same input. JAX arrays are immutable.
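A minimal sketch of what immutability means in practice (using nothing beyond the public jax.numpy API):

```python
import numpy as np
import jax.numpy as jnp

# NumPy arrays can be mutated in place:
a = np.zeros(3)
a[0] = 1.0

# JAX arrays cannot; the functional .at syntax returns a new array instead:
b = jnp.zeros(3)
b = b.at[0].set(1.0)  # b[0] = 1.0 would raise a TypeError
```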
XLA (Just-In-Time Compilation)
JAX uses XLA to compile and optimize Python/NumPy code into efficient machine code for specific hardware.
Manual PRNG Handling
Unlike NumPy, JAX requires explicit management of random state (keys) to ensure reproducibility in parallel environments.
Quick Reference
Installation
```bash
# CPU
pip install jax jaxlib

# GPU (check documentation for specific CUDA versions)
pip install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
Standard Imports
```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap, random
```
Basic Pattern - Differentiate and JIT
```python
import jax.numpy as jnp
from jax import grad, jit

# 1. Define a pure function
def f(x):
    return jnp.sin(x) + x**2

# 2. Transform: Create a gradient function
df_dx = grad(f)

# 3. Transform: Compile for speed
f_fast = jit(f)

# 4. Use
val = f_fast(2.0)
slope = df_dx(2.0)
```
Critical Rules
✅ DO
- Use jax.numpy (jnp) - It mirrors NumPy but supports JAX transformations.
- Write Pure Functions - Ensure functions only depend on inputs and don't modify external state.
- Handle PRNG Keys Manually - Use `key, subkey = random.split(key)` for every random operation.
- Use vmap for Batching - Write code for a single sample and let JAX handle the batch dimension.
- Set static_argnums in JIT - If a JIT'ed function takes a non-array argument (like a string or integer used in a loop), mark it as static.
- In-place updates via .at - Since arrays are immutable, use `x = x.at[idx].set(val)`.
❌ DON'T
- Use in-place updates - `x[idx] = val` will raise an error.
- Use standard numpy (np) - Standard NumPy arrays don't support JAX transformations.
- Use Side Effects - Don't use `print()` or modify global variables inside JIT-compiled functions.
- Forget to block_until_ready() - JAX uses asynchronous execution. For accurate timing, use `result.block_until_ready()`.
Anti-Patterns (NEVER)

```python
import numpy as np
import jax.numpy as jnp
from jax import jit, random

# ❌ BAD: Modifying a global variable inside a function
counter = 0

@jit
def bad_func(x):
    global counter
    counter += 1  # ❌ Side effect! Will only run once, during compilation
    return x * 2

# ❌ BAD: Standard NumPy random (not reproducible/parallel-safe)
val = np.random.randn(5)

# ✅ GOOD: JAX PRNG
key = random.key(42)
val = random.normal(key, (5,))

# ❌ BAD: In-place assignment
# x[0] = 1.0   # raises an error on a JAX array

# ✅ GOOD: Functional update
x = jnp.zeros(5)
x = x.at[0].set(1.0)
```
Function Transformations
Grad (Differentiation)

```python
def loss(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y)**2)

# Gradient of loss with respect to the 1st argument (params)
grads = grad(loss)(params, x, y)

# Higher-order: Hessian
hessian = jax.hessian(loss)(params, x, y)
```
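Since the section also mentions Jacobians, here is a small sketch of the two Jacobian transformations (the function `f` is a made-up example):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Example vector-valued function R^3 -> R^2
    return jnp.array([x[0] * x[1], jnp.sum(x**2)])

x = jnp.array([1.0, 2.0, 3.0])
J_fwd = jax.jacfwd(f)(x)  # forward-mode: efficient for "tall" Jacobians
J_rev = jax.jacrev(f)(x)  # reverse-mode: efficient for "wide" Jacobians
# jax.hessian is implemented as jacfwd(jacrev(f)).
```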
Jit (Just-In-Time Compilation)
```python
@jit
def complex_math(x):
    # This whole block is compiled into one XLA kernel
    y = jnp.exp(x)
    return jnp.sin(y) / jnp.sqrt(x)

# First call: compiles (slow)
# Subsequent calls: super fast
```
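The compile-once behavior can be observed directly; a rough timing sketch (block_until_ready() forces JAX's asynchronous dispatch to finish before the clock stops):

```python
import time
import jax.numpy as jnp
from jax import jit

@jit
def complex_math(x):
    y = jnp.exp(x)
    return jnp.sin(y) / jnp.sqrt(x)

x = jnp.linspace(1.0, 2.0, 1_000_000)

t0 = time.perf_counter()
complex_math(x).block_until_ready()  # first call: tracing + XLA compilation
compile_time = time.perf_counter() - t0

t0 = time.perf_counter()
complex_math(x).block_until_ready()  # cached: just the compiled kernel
run_time = time.perf_counter() - t0
```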
Vmap (Automatic Vectorization)
```python
def model(params, x):
    return jnp.dot(params, x)

# model works on 1D x. How to apply to a 2D batch of X?
# in_axes=(None, 0): don't map params, map the 0th axis of x
batch_model = vmap(model, in_axes=(None, 0))
batch_preds = batch_model(params, X_batch)
```
Random Numbers (jax.random)
State Management

```python
key = random.key(0)

# Never reuse the same key!
key, subkey = random.split(key)
data = random.normal(subkey, (10,))

key, subkey = random.split(key)
noise = random.uniform(subkey, (10,))
```
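When many keys are needed at once, `random.split` can produce them in a single call; a small sketch:

```python
from jax import random

key = random.key(0)
keys = random.split(key, 4)  # four independent subkeys in one call
samples = [random.normal(k, (2,)) for k in keys]
# Each subkey gives a reproducible, independent draw.
```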
Working with PyTrees
Handling complex data structures (Dicts, Lists, Tuples)
JAX transformations work on "PyTrees" — nested containers of arrays.

```python
params = {'weights': jnp.ones((5,)), 'bias': 0.0}

def predict(p, x):
    return jnp.dot(x, p['weights']) + p['bias']

# grad and jit handle the dictionary automatically
grads = grad(predict)(params, x)
```
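Because gradients come back with the same tree structure as the parameters, a whole update step can be written with `jax.tree_util.tree_map`; a minimal sketch (the loss, data, and learning rate here are made up for illustration):

```python
import jax
import jax.numpy as jnp

params = {'weights': jnp.ones((5,)), 'bias': 0.0}

def loss(p, x, y):
    pred = jnp.dot(x, p['weights']) + p['bias']
    return jnp.mean((pred - y) ** 2)

x = jnp.ones((8, 5))
y = jnp.zeros((8,))
grads = jax.grad(loss)(params, x, y)  # same dict structure as params

# One SGD step applied to every leaf, regardless of nesting:
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```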
Practical Workflows
1. Differentiable Physics: Solving a Simple ODE

```python
def system_dynamics(state, t):
    # Simple harmonic oscillator
    x, v = state
    dxdt = v
    dvdt = -0.5 * x
    return jnp.array([dxdt, dvdt])

def loss_fn(initial_state, target_x):
    # Simulate for 10 steps using simple Euler integration
    state = initial_state
    dt = 0.1
    for i in range(10):
        state += system_dynamics(state, i*dt) * dt
    return (state[0] - target_x)**2

# We can take the gradient of the simulation with respect to the initial state!
optimize_initial_state = grad(loss_fn)
```
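A hedged usage sketch: plain gradient descent on the initial state (the dynamics and loss mirror the example above; the learning rate and step count are arbitrary choices):

```python
import jax.numpy as jnp
from jax import grad, jit

def system_dynamics(state, t):
    x, v = state
    return jnp.array([v, -0.5 * x])  # simple harmonic oscillator

def loss_fn(initial_state, target_x):
    state = initial_state
    dt = 0.1
    for i in range(10):  # 10 Euler steps
        state = state + system_dynamics(state, i * dt) * dt
    return (state[0] - target_x) ** 2

grad_fn = jit(grad(loss_fn))
state0 = jnp.array([1.0, 0.0])
state = state0
for _ in range(100):
    state = state - 0.1 * grad_fn(state, 2.0)  # descend toward target_x = 2.0
```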
2. Parameter Sweep with vmap
```python
def simulation(param):
    # Some complex computation
    return jnp.sum(jnp.linspace(0, param, 100))

# Parallelize simulation over a range of parameters
params = jnp.linspace(1, 10, 100)
results = vmap(simulation)(params)
```
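Transformations also compose: `vmap(grad(f))` gives per-sample gradients in one line. A small sketch with a made-up scalar loss:

```python
import jax.numpy as jnp
from jax import grad, vmap

def loss(w, x):
    return (w * x - 1.0) ** 2

# Gradient w.r.t. w, vectorized over a batch of x values:
per_sample_grads = vmap(grad(loss), in_axes=(None, 0))
g = per_sample_grads(2.0, jnp.array([0.5, 1.0, 2.0]))  # one dL/dw per sample
```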
3. Distributed Training with pmap
```python
# pmap replicates the function across multiple GPUs
# (assuming 8 GPUs are available)
x = jnp.zeros((8, 1024))
results = pmap(jnp.sin)(x)
```
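A portable variant sizes the leading axis from the actual device count, so the same snippet runs on a single CPU device or on eight GPUs:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # number of devices visible to this process
x = jnp.zeros((n, 1024))      # leading axis must equal the device count
results = jax.pmap(jnp.sin)(x)
```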
Performance Optimization
Static Arguments in JIT
If your function uses a Python loop whose length depends on an input value, that value must be marked static.

```python
from functools import partial

@partial(jit, static_argnums=(1,))
def power_loop(x, n):
    for i in range(n):
        x = x * x
    return x
```
Avoid Python Control Flow
Prefer JAX control flow (`cond`, `while_loop`, `fori_loop`) for better XLA optimization.

```python
from jax.lax import cond

def safe_divide(x, y):
    # Both branches are traced, not executed eagerly; the operand is positional
    return cond(y == 0, lambda _: 0.0, lambda _: x / y, None)
```
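A loop counterpart, sketched with `fori_loop` (the sum-of-squares function is a made-up example):

```python
import jax.numpy as jnp
from jax import lax

def sum_of_squares(n):
    # fori_loop(lower, upper, body, init) compiles to a single XLA While op
    return lax.fori_loop(0, n, lambda i, acc: acc + i * i, jnp.int32(0))
```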
Common Pitfalls and Solutions
The "Tracer" Error
Inside JIT, JAX doesn't see actual numbers; it sees "Tracers".

```python
# ❌ Problem:
@jit
def func(x):
    if x > 0: return x  # Error! JAX doesn't know x's value during compile

# ✅ Solution: use jax.lax.cond or mark x as a static_argnum
```
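For elementwise conditions, `jnp.where` is often the simplest fix, since it selects values instead of branching in Python; a small sketch:

```python
import jax.numpy as jnp
from jax import jit

@jit
def relu(x):
    # No Python-level `if` on a Tracer: where() selects elementwise
    return jnp.where(x > 0, x, 0.0)

out = relu(jnp.array([-1.0, 2.0]))
```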
NaN Gradients
If your function has singularities (like `sqrt(0)`), gradients will be NaN.

```python
# ✅ Solution: add a small epsilon
def safe_sqrt(x):
    return jnp.sqrt(x + 1e-8)
```
Memory Leaks on GPU
By default, JAX preallocates 90% of GPU memory.

```python
# ✅ Solution: set the environment variable before importing jax
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
```
Best Practices
- Always use pure functions - No side effects, deterministic outputs
- Manage PRNG keys explicitly - Split keys for every random operation
- Use JIT for hot loops - Compile functions that are called repeatedly
- Leverage vmap for batching - Write single-sample code, let JAX handle batches
- Mark static arguments - Use `static_argnums` for non-array parameters in JIT
- Use functional updates - Always use `.at` methods for array modifications
- Profile before optimizing - Use `jax.profiler` to find bottlenecks
- Handle device placement - Use `jax.device_put()` to control where arrays live
- Test on CPU first - Debug on CPU, then scale to GPU/TPU
- Understand compilation costs - First JIT call is slow, subsequent calls are fast

JAX is the ultimate playground for differentiable science. By treating math as a series of functional transformations, it unlocks speeds and complexities that were previously impossible in Python.