Bayesian statistical modeling with PyMC v5+. Use when building probabilistic models, specifying priors, running MCMC inference, diagnosing convergence, or comparing models. Covers PyMC, ArviZ, pymc-bart, pymc-extras, nutpie, and JAX/NumPyro backends. Triggers on tasks involving: Bayesian inference, posterior sampling, hierarchical/multilevel models, GLMs, time series, Gaussian processes, BART, mixture models, prior/posterior predictive checks, MCMC diagnostics, LOO-CV, WAIC, model comparison, or causal inference with do/observe.
Install:

```
npx skill4agent add pymc-labs/python-analytics-skills pymc-modeling
```

Core workflow:

```python
import numpy as np
import pymc as pm
import arviz as az

with pm.Model(coords=coords) as model:
    # Data containers (for out-of-sample prediction)
    x = pm.Data("x", x_obs, dims=("obs", "features"))
    # Priors
    beta = pm.Normal("beta", mu=0, sigma=1, dims="features")
    sigma = pm.HalfNormal("sigma", sigma=1)
    # Likelihood
    mu = pm.math.dot(x, beta)
    y = pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs, dims="obs")
    # Inference
    idata = pm.sample(nuts_sampler="nutpie", random_seed=42)
```

Define named dimensions once with coords; the labels flow through to the posterior and every ArviZ plot:

```python
coords = {
    "obs": np.arange(n_obs),
    "features": ["intercept", "age", "income"],
    "group": group_labels,
}
```
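Because the model is built on `pm.Data` containers, new data can be swapped in for out-of-sample prediction. A minimal sketch, assuming a hypothetical new design matrix `x_new` with the same feature columns:

```python
with model:
    # Replace the data and resize the "obs" coordinate to match
    pm.set_data({"x": x_new}, coords={"obs": np.arange(len(x_new))})
    # Store the new draws in a separate "predictions" group
    preds = pm.sample_posterior_predictive(idata, predictions=True)
```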
Two equivalent parameterizations for group-level effects:

```python
# Non-centered (better for divergences)
offset = pm.Normal("offset", 0, 1, dims="group")
alpha = mu_alpha + sigma_alpha * offset

# Centered (better with strong data)
alpha = pm.Normal("alpha", mu_alpha, sigma_alpha, dims="group")
```

Sampling:

```python
with model:
    idata = pm.sample(
        draws=1000, tune=1000, chains=4,
        nuts_sampler="nutpie",
        random_seed=42,
    )
idata.to_netcdf("results.nc")  # Save immediately after sampling
```

Sampler notes:

- `nuts_sampler="nutpie"` is the fastest NUTS backend (`pip install nutpie`); `nuts_sampler="numpyro"` is a JAX-based alternative, and the plain default `idata = pm.sample(draws=1000, tune=1000, chains=4, random_seed=42)` works everywhere.
- nutpie does not store the pointwise log-likelihood automatically: pass `idata_kwargs={"log_likelihood": True}` to `pm.sample`, or call `pm.compute_log_likelihood(idata, model=model)` afterwards, before computing LOO/WAIC (see the sketch below).
- Use the `ordered` transform when a model needs sorted parameters, e.g. the cutpoints of `OrderedLogistic` / `OrderedProbit`.
- Bracket every fit with `pm.sample_prior_predictive` before and `pm.sample_posterior_predictive` after, and persist results with `idata.to_netcdf(...)`.
Convergence diagnostics:

```python
# 1. Check for divergences (must be 0 or near 0)
n_div = idata.sample_stats["diverging"].sum().item()
print(f"Divergences: {n_div}")

# 2. Summary with convergence diagnostics
summary = az.summary(idata, var_names=["~offset"])  # exclude auxiliary variables
print(summary[["mean", "sd", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]])

# 3. Visual convergence check
az.plot_trace(idata, compact=True)
az.plot_rank(idata, var_names=["beta", "sigma"])
```

Targets: `r_hat < 1.01`, `ess_bulk > 400`, and `ess_tail > 400` for every parameter.
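A small sketch that turns those thresholds into hard checks so a pipeline fails loudly instead of silently reporting a bad fit:

```python
summary = az.summary(idata)
assert (summary["r_hat"] < 1.01).all(), "r_hat >= 1.01: chains have not converged"
assert (summary["ess_bulk"] > 400).all(), "bulk ESS too low: draw more samples"
assert (summary["ess_tail"] > 400).all(), "tail ESS too low: intervals unreliable"
```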
Further sampler health checks:

```python
# ESS evolution (should grow linearly)
az.plot_ess(idata, kind="evolution")
# Energy diagnostic (HMC health)
az.plot_energy(idata)
# Autocorrelation (should decay rapidly)
az.plot_autocorr(idata, var_names=["beta"])
```

Posterior predictive checks:

```python
# Generate posterior predictive
with model:
    pm.sample_posterior_predictive(idata, extend_inferencedata=True)

# Does the model capture the data?
az.plot_ppc(idata, kind="cumulative")
# Calibration check
az.plot_loo_pit(idata, y="y")
```
Posterior visualization:

```python
# Posterior summaries
az.plot_posterior(idata, var_names=["beta"], ref_val=0)
# Forest plots for hierarchical parameters
az.plot_forest(idata, var_names=["alpha"], combined=True)
# Parameter correlations (identify non-identifiability)
az.plot_pair(idata, var_names=["alpha", "beta", "sigma"])
```

Prior predictive checks (run before `pm.sample()`):

```python
with model:
    prior_pred = pm.sample_prior_predictive(draws=500)

az.plot_ppc(prior_pred, group="prior", kind="cumulative")
prior_y = prior_pred.prior_predictive["y"].values.flatten()
print(f"Prior predictive range: [{prior_y.min():.1f}, {prior_y.max():.1f}]")
```

After sampling, mirror this with the posterior predictive checks shown above (`pm.sample_posterior_predictive`, `az.plot_ppc`, `az.plot_loo_pit`).

Debugging helpers: `model.debug()` validates the model at its initial point, `model.point_logps()` prints each variable's log-probability there, `print(model)` lists the model's variables, and `pm.model_to_graphviz(model)` renders the graph (usage sketch after the table below).

Troubleshooting:

| Symptom | Likely Cause | Fix |
|---|---|---|
| Shape mismatch errors | Parameter vs observation dimensions | Use index vectors: `alpha[group_idx]` |
| Initial evaluation failed | Data outside distribution support | Check bounds; use an appropriate distribution or transform |
| Slow sampling, poor mixing | Unscaled predictors or flat priors | Standardize features; use weakly informative priors |
| High divergence count | Funnel geometry | Non-centered parameterization |
| NaN or -inf logp | Invalid parameter combinations | Check parameter constraints, add bounds |
| -inf logp on observed variables | Observations outside likelihood support | Verify data matches distribution domain |
| Slow discrete sampling | NUTS incompatible with discrete | Marginalize discrete variables |
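A usage sketch for the debugging helpers named above:

```python
model.debug()                 # validate logp at the initial point; reports offending variables
model.point_logps()           # per-variable logp at the initial point; look for -inf or nan
print(model)                  # free, deterministic, and observed variables at a glance
pm.model_to_graphviz(model)   # draw the model DAG (requires graphviz installed)
```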
To see where divergences concentrate in parameter space, use `az.plot_pair(idata, divergences=True)`.

LOO cross-validation:

```python
# Compute LOO with pointwise diagnostics
loo = az.loo(idata, pointwise=True)
print(f"ELPD: {loo.elpd_loo:.1f} ± {loo.se:.1f}")

# Check Pareto k values (must be < 0.7 for reliable LOO)
print(f"Bad k (>0.7): {(loo.pareto_k > 0.7).sum().item()}")
az.plot_khat(loo)  # pass the LOO result, not the raw InferenceData
```
Model comparison:

```python
# If using nutpie, compute log-likelihood first (nutpie doesn't store it automatically)
pm.compute_log_likelihood(idata_a, model=model_a)
pm.compute_log_likelihood(idata_b, model=model_b)

comparison = az.compare({
    "model_a": idata_a,
    "model_b": idata_b,
}, ic="loo")
print(comparison[["rank", "elpd_loo", "elpd_diff", "weight"]])
az.plot_compare(comparison)
```
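As a rough reading guide (a heuristic, not an ArviZ rule): a model is only meaningfully better when `elpd_diff` is large relative to its standard error, which `az.compare` reports in the `dse` column:

```python
# Flag models whose ELPD gap to the best model exceeds twice its standard error
meaningful = comparison["elpd_diff"] > 2 * comparison["dse"]
print(comparison[["elpd_diff", "dse"]].assign(meaningful=meaningful))
```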
idata.to_netcdf("results/model_v1.nc")
# Load
idata = az.from_netcdf("results/model_v1.nc")with model:
idata = pm.sample(nuts_sampler="nutpie")
idata.to_netcdf("results.nc") # Save before any post-processing!
with model:
pm.sample_posterior_predictive(idata, extend_inferencedata=True)
idata.to_netcdf("results.nc") # Update with posterior predictivewith pm.Model(coords={"group": groups, "obs": obs_idx}) as hierarchical:
# Hyperpriors
mu_alpha = pm.Normal("mu_alpha", 0, 1)
sigma_alpha = pm.HalfNormal("sigma_alpha", 1)
# Group-level (non-centered)
alpha_offset = pm.Normal("alpha_offset", 0, 1, dims="group")
alpha = pm.Deterministic("alpha", mu_alpha + sigma_alpha * alpha_offset, dims="group")
# Likelihood
y = pm.Normal("y", alpha[group_idx], sigma, observed=y_obs, dims="obs")# Logistic regression
Generalized linear models:

```python
# Logistic regression
with pm.Model(coords=coords) as logistic:
    alpha = pm.Normal("alpha", 0, 2.5)
    beta = pm.Normal("beta", 0, 2.5, dims="features")
    p = pm.math.sigmoid(alpha + pm.math.dot(X, beta))
    y = pm.Bernoulli("y", p=p, observed=y_obs)

# Poisson regression
with pm.Model(coords=coords) as poisson:
    beta = pm.Normal("beta", 0, 1, dims="features")
    y = pm.Poisson("y", mu=pm.math.exp(pm.math.dot(X, beta)), observed=y_obs)
```
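To keep the fitted probabilities in the trace for later plotting, wrap them in a `pm.Deterministic`. A sketch; the variable name `p` and model name `logistic_tracked` are arbitrary:

```python
with pm.Model(coords=coords) as logistic_tracked:
    alpha = pm.Normal("alpha", 0, 2.5)
    beta = pm.Normal("beta", 0, 2.5, dims="features")
    # Tracked deterministic: P(y=1) per observation appears in idata.posterior["p"]
    p = pm.Deterministic("p", pm.math.sigmoid(alpha + pm.math.dot(X, beta)), dims="obs")
    y = pm.Bernoulli("y", p=p, observed=y_obs, dims="obs")
```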
Gaussian processes: `pm.gp.Marginal` gives exact inference with a Gaussian likelihood; for larger datasets use the HSGP basis approximation:

```python
with pm.Model() as gp_model:
    # Hyperparameters
    ell = pm.InverseGamma("ell", alpha=5, beta=5)
    eta = pm.HalfNormal("eta", sigma=2)
    sigma = pm.HalfNormal("sigma", sigma=0.5)
    # Covariance function (Matern52 recommended)
    cov = eta**2 * pm.gp.cov.Matern52(1, ls=ell)
    # HSGP approximation
    gp = pm.gp.HSGP(m=[20], c=1.5, cov_func=cov)
    f = gp.prior("f", X=X[:, None])  # X must be 2D
    # Likelihood
    y = pm.Normal("y", mu=f, sigma=sigma, observed=y_obs)
```

Related GP implementations: `pm.gp.HSGPPeriodic` for periodic covariances, `pm.gp.Marginal` for exact GPs with Gaussian noise, and `pm.gp.Latent` for non-Gaussian likelihoods.
AR(1) time series:

```python
with pm.Model(coords={"time": range(T)}) as ar_model:
    const = pm.Normal("const", 0, 1)
    rho = pm.Uniform("rho", -1, 1)
    sigma = pm.HalfNormal("sigma", sigma=1)
    # With constant=True, the first element of rho is the constant term
    y = pm.AR("y", rho=[const, rho], sigma=sigma, constant=True,
              observed=y_obs, dims="time")
```
BART (Bayesian Additive Regression Trees) via pymc-bart:

```python
import pymc_bart as pmb

with pm.Model() as bart_model:
    mu = pmb.BART("mu", X=X, Y=y_obs, m=50)
    sigma = pm.HalfNormal("sigma", 1)
    y = pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)
```
coords = {"component": range(K)}
with pm.Model(coords=coords) as gmm:
# Mixture weights
w = pm.Dirichlet("w", a=np.ones(K), dims="component")
# Component parameters (with ordering to avoid label switching)
mu = pm.Normal("mu", mu=0, sigma=10, dims="component",
transform=pm.distributions.transforms.ordered,
initval=np.linspace(y_obs.min(), y_obs.max(), K))
sigma = pm.HalfNormal("sigma", sigma=2, dims="component")
# Mixture likelihood
y = pm.NormalMixture("y", w=w, mu=mu, sigma=sigma, observed=y_obs)target_accept=0.9initvaltarget_accept=0.95# Zero-Inflated Poisson (excess zeros)
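A sketch of the stricter sampling settings:

```python
with gmm:
    # Higher target_accept forces smaller step sizes, reducing divergences
    idata = pm.sample(target_accept=0.95, random_seed=42)
```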
Special likelihoods:

```python
# Zero-Inflated Poisson (excess zeros)
with pm.Model() as zip_model:
    psi = pm.Beta("psi", alpha=2, beta=2)  # psi = P(Poisson process); 1 - psi = P(structural zero)
    mu = pm.Exponential("mu", lam=1)
    y = pm.ZeroInflatedPoisson("y", psi=psi, mu=mu, observed=y_obs)

# Censored data (e.g., right-censored survival)
with pm.Model() as censored_model:
    mu = pm.Normal("mu", mu=0, sigma=10)
    sigma = pm.HalfNormal("sigma", sigma=5)
    y = pm.Censored("y", dist=pm.Normal.dist(mu=mu, sigma=sigma),
                    lower=None, upper=censoring_time, observed=y_obs)

# Ordinal regression
with pm.Model(coords=coords) as ordinal:
    beta = pm.Normal("beta", mu=0, sigma=2, dims="features")
    cutpoints = pm.Normal("cutpoints", mu=0, sigma=2,
                          transform=pm.distributions.transforms.ordered,
                          shape=n_categories - 1,
                          initval=np.linspace(-2, 2, n_categories - 1))
    y = pm.OrderedLogistic("y", eta=pm.math.dot(X, beta),
                           cutpoints=cutpoints, observed=y_obs)
```

The `"cutpoints"` variable must be strictly increasing; the ordered transform enforces this during sampling, but `"cutpoints"` also needs an increasing `initval` (as above) so the starting point is valid.
Causal inference with `pm.do` and `pm.observe`:

```python
# pm.do — intervene (breaks incoming edges)
with pm.do(causal_model, {"x": 2}) as intervention_model:
    idata = pm.sample_prior_predictive()  # P(y, z | do(x=2))

# pm.observe — condition (preserves causal structure)
with pm.observe(causal_model, {"y": 1}) as conditioned_model:
    idata = pm.sample(nuts_sampler="nutpie")  # P(x, z | y=1)

# Combine: P(y | do(x=2), z=0)
with pm.do(causal_model, {"x": 2}) as m1:
    with pm.observe(m1, {"z": 0}) as m2:
        idata = pm.sample(nuts_sampler="nutpie")
```

pymc-extras (`import pymc_extras as pmx`) adds useful extensions: `pmx.marginalize(model, ["discrete_var"])` marginalizes discrete variables so NUTS can sample the remainder, `pmx.R2D2M2CP(...)` provides shrinkage priors, and `pmx.fit_laplace(model)` gives a fast Laplace approximation (marginalization sketch below).
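A minimal marginalization sketch, assuming `pmx.marginalize` accepts variable names as in the call above and returns a sampleable model; `y_obs` stands in for your data:

```python
import pymc as pm
import pymc_extras as pmx

with pm.Model() as model:
    p = pm.Beta("p", 1, 1)
    state = pm.Bernoulli("state", p=p)  # discrete latent variable
    y = pm.Normal("y", mu=2.0 * state, sigma=1.0, observed=y_obs)

# Sum the discrete variable out of the logp so pure NUTS can sample the rest
marginal_model = pmx.marginalize(model, ["state"])
with marginal_model:
    idata = pm.sample()
```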
Soft constraints can be added as extra log-probability terms via `pm.Potential`:

```python
# Soft constraints via Potential (inside a model context)
import pytensor.tensor as pt

pm.Potential("sum_to_zero", -100 * pt.sqr(alpha.sum()))
```

For custom likelihoods, prefer `pm.CustomDist` (it supersedes `pm.DensityDist`); use `pm.Potential` for arbitrary additive log-probability terms, and `pm.Simulator` for simulation-based inference when the likelihood is intractable.