# JAX
Composable transformations of Python+NumPy programs. Differentiate, vectorize, JIT-compile to GPU/TPU. Built for high-performance machine learning research and complex scientific simulations. Use for automatic differentiation, GPU/TPU acceleration, higher-order derivatives, physics-informed machine learning, differentiable simulations, and automatic vectorization.
npx skill4agent add tondevrel/scientific-agent-skills jax

Core APIs: jax.numpy, jax.jit, jax.grad, jax.vmap, jax.random

# CPU
pip install jax jaxlib
# GPU (Check documentation for specific CUDA versions)
pip install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap, random

import jax.numpy as jnp
from jax import grad, jit
# 1. Define a pure function
def f(x):
    """Smooth scalar test function: sin(x) + x**2."""
    return jnp.sin(x) + x**2


# 2. Transform: create a gradient function
df_dx = grad(f)

# 3. Transform: compile for speed
f_fast = jit(f)

# 4. Use
val = f_fast(2.0)
slope = df_dx(2.0)

# Essential idioms (detailed in the pitfalls below):
#   key, subkey = random.split(key)   # never reuse a PRNG key
#   x = x.at[idx].set(val)            # functional update — NOT x[idx] = val
#   print(...)                        # side effects run only at trace time under jit
#   result.block_until_ready()        # wait for async dispatch before timing
import jax.numpy as jnp
from jax import jit, random

# ❌ BAD: Modifying a global variable inside a jitted function
counter = 0

@jit
def bad_func(x):
    global counter
    counter += 1  # ❌ Side effect! Runs only once, while tracing/compiling
    return x * 2

# ❌ BAD: Standard NumPy random (not reproducible/parallel-safe)
# val = np.random.randn(5)
# ✅ GOOD: JAX's explicit, splittable PRNG
key = random.key(42)
val = random.normal(key, (5,))

# ❌ BAD: In-place assignment
# x[0] = 1.0
# ✅ GOOD: Functional update
x = jnp.zeros(5)
# ✅ GOOD: Functional update (zeros(5) inlined so this line stands alone)
x = jnp.zeros(5).at[0].set(1.0)


def loss(params, x, y):
    """Mean-squared error of the linear model pred = x @ params."""
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)


# Example data so the snippet runs end-to-end
params = jnp.ones((3,))
X = jnp.ones((4, 3))
y = jnp.zeros((4,))

# Gradient of loss with respect to the 1st argument (params)
grads = grad(loss)(params, X, y)

# Higher-order: Hessian
hessian = jax.hessian(loss)(params, X, y)
@jit
def complex_math(x):
    """Fused elementwise pipeline; jit compiles the whole body into one XLA kernel."""
    y = jnp.exp(x)
    return jnp.sin(y) / jnp.sqrt(x)

# First call: compiles (slow). Subsequent calls: super fast.
def model(params, x):
    """Linear model for a single 1-D example x."""
    return jnp.dot(params, x)

# model works on 1D x. To apply it to a 2D batch:
# in_axes=(None, 0): don't map params, map the 0th axis of x
batch_model = vmap(model, in_axes=(None, 0))

# Example data so the snippet runs end-to-end
params = jnp.ones((3,))
X_batch = jnp.ones((8, 3))
batch_preds = batch_model(params, X_batch)

key = random.key(0)
# Never reuse the same key! Split off a fresh subkey for every draw.
key = random.key(0)  # (re)initialization — idempotent if already set above

key, subkey = random.split(key)
data = random.normal(subkey, (10,))

key, subkey = random.split(key)
noise = random.uniform(subkey, (10,))
params = {'weights': jnp.ones((5,)), 'bias': 0.0}

def predict(p, x):
    """Linear model whose parameters live in a pytree (a plain dict)."""
    return jnp.dot(x, p['weights']) + p['bias']

# grad and jit handle the dictionary (pytree) automatically
x = jnp.ones((5,))  # example input so the snippet runs end-to-end
grads = grad(predict)(params, x)
def system_dynamics(state, t):
    """d(state)/dt for a simple harmonic oscillator; state = [position, velocity]."""
    x, v = state
    dxdt = v
    dvdt = -0.5 * x  # linear restoring force (spring constant 0.5)
    return jnp.array([dxdt, dvdt])


def loss_fn(initial_state, target_x):
    """Squared distance of final position to target after 10 explicit-Euler steps."""
    state = initial_state
    dt = 0.1
    for i in range(10):
        state = state + system_dynamics(state, i * dt) * dt
    return (state[0] - target_x) ** 2


# We can take the gradient of the whole simulation w.r.t. the initial state!
optimize_initial_state = grad(loss_fn)
def simulation(param):
    """Stand-in for an expensive computation driven by a single scalar `param`."""
    # sum of 100 evenly spaced points from 0 to param == 50 * param
    return jnp.sum(jnp.linspace(0, param, 100))


# Vectorize the simulation over a whole range of parameters at once
params = jnp.linspace(1, 10, 100)
results = vmap(simulation)(params)

# pmap replicates the function across multiple GPUs
# (assuming 8 GPUs are available)
# x = jnp.zeros((8, 1024))
# results = pmap(jnp.sin)(x)

from functools import partial
from functools import partial  # no-op if already imported above

@partial(jit, static_argnums=(1,))
def power_loop(x, n):
    """Square x repeatedly, n times. n is static, so the Python loop unrolls at trace time."""
    for _ in range(n):
        x = x * x
    return x

# Structured control flow inside jit: jax.lax.cond / while_loop / fori_loop
from jax.lax import cond
def safe_divide(x, y):
return cond(y == 0, lambda _: 0.0, lambda _: x / y, operand=None)# ❌ Problem:
# @jit
# def func(x):
#     if x > 0: return x  # Error! JAX doesn't know x's value at trace time
# ✅ Solution:
# Use jax.lax.cond, or mark x as static via static_argnums

# Pitfall: NaN gradients from sqrt(0)
# ✅ Solution: Add a small epsilon
def safe_sqrt(x):
    """sqrt with a tiny epsilon so the gradient at x == 0 is finite instead of NaN."""
    return jnp.sqrt(x + 1e-8)

# ✅ Solution (GPU memory preallocation): set an environment variable
import os

# Don't let XLA preallocate most of GPU memory up front.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

# Other frequently-needed tools: static_argnums, .at[] updates,
# jax.profiler, jax.device_put()