ml-failfast-validation
ML Fail-Fast Validation
POC validation patterns to catch issues before committing to long-running ML experiments.
When to Use This Skill
Use this skill when:
- Starting a new ML experiment that will run for hours
- Validating model architecture before full training
- Checking gradient flow and data pipeline integrity
- Implementing POC validation checklists
- Debugging prediction collapse or gradient explosion issues
1. Why Fail-Fast?
| Without Fail-Fast | With Fail-Fast |
|---|---|
| Discover crash 4 hours in | Catch in 30 seconds |
| Debug from cryptic error | Clear error message |
| Lose GPU time | Validate before commit |
| Silent data issues | Explicit schema checks |
Principle: Validate everything that can go wrong BEFORE the expensive computation.
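As a minimal sketch of the principle, a gate function can run a list of cheap checks and only start the expensive computation once they all pass. All names here (`fail_fast`, `check_config`) are illustrative, not part of any experiment code:

```python
import time


def fail_fast(checks, expensive_fn):
    """Run cheap validation callables before an expensive computation.

    Each check raises on failure, so a problem surfaces in seconds
    instead of hours into expensive_fn.
    """
    for check in checks:
        start = time.time()
        check()  # raises immediately on a problem
        print(f"  {check.__name__}: PASS ({time.time() - start:.2f}s)", flush=True)
    return expensive_fn()


def check_config():
    # Hypothetical check: a bad hyperparameter fails in <1s, not 4h in.
    config = {"lr": 0.0005, "epochs": 200}
    assert config["lr"] > 0, "Learning rate must be positive"


result = fail_fast([check_config], lambda: "experiment result")
```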
2. POC Validation Checklist
Minimum Viable POC (5 Checks)
```python
def run_poc_validation():
    """Fast validation before full experiment."""
    print("=" * 60)
    print("FAIL-FAST POC VALIDATION")
    print("=" * 60)

    # [1/5] Model instantiation
    print("\n[1/5] Model instantiation...")
    model = create_model(architecture, input_size=n_features)
    x = torch.randn(32, seq_len, n_features).to(device)
    out = model(x)
    assert out.shape == (32, 1), f"Output shape wrong: {out.shape}"
    print(f"  Input: (32, {seq_len}, {n_features}) -> Output: {out.shape}")
    print("  Status: PASS")

    # [2/5] Gradient flow
    print("\n[2/5] Gradient flow...")
    y = torch.randn(32, 1).to(device)
    loss = F.mse_loss(out, y)
    loss.backward()
    grad_norms = [p.grad.norm().item() for p in model.parameters() if p.grad is not None]
    assert len(grad_norms) > 0, "No gradients!"
    assert all(np.isfinite(g) for g in grad_norms), "NaN/Inf gradients!"
    print(f"  Max grad norm: {max(grad_norms):.4f}")
    print("  Status: PASS")

    # [3/5] NDJSON artifact validation
    print("\n[3/5] NDJSON artifact validation...")
    log_path = output_dir / "experiment.jsonl"
    with open(log_path, "a") as f:
        f.write(json.dumps({"phase": "poc_start", "timestamp": datetime.now().isoformat()}) + "\n")
    assert log_path.exists(), "Log file not created"
    print(f"  Log file: {log_path}")
    print("  Status: PASS")

    # [4/5] Epoch selector variation
    print("\n[4/5] Epoch selector variation...")
    epochs = []
    for seed in [1, 2, 3]:
        selector = create_selector()
        # Simulate different validation results
        for e in range(10, 201, 10):
            selector.record(epoch=e, sortino=np.random.randn() * 0.1, sparsity=np.random.rand())
        epochs.append(selector.select())
    print(f"  Selected epochs: {epochs}")
    # Different random inputs should produce different selections
    assert len(set(epochs)) > 1, "Selector not varying"
    print("  Status: PASS")

    # [5/5] Mini training (10 epochs)
    print("\n[5/5] Mini training (10 epochs)...")
    model = create_model(architecture, input_size=n_features).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
    initial_loss = None
    for epoch in range(10):
        loss = train_one_epoch(model, train_loader, optimizer)
        if initial_loss is None:
            initial_loss = loss
    print(f"  Initial loss: {initial_loss:.4f}")
    print(f"  Final loss: {loss:.4f}")
    print("  Status: PASS")

    print("\n" + "=" * 60)
    print("POC RESULT: ALL 5 CHECKS PASSED")
    print("=" * 60)
```
Extended POC (10 Checks)
Add these for comprehensive validation:
```python
# [6/10] Data loading
print("\n[6/10] Data loading...")
df = fetch_data(symbol, threshold)
assert len(df) > min_required_bars, f"Insufficient data: {len(df)} bars"
print(f"  Loaded: {len(df):,} bars")
print("  Status: PASS")

# [7/10] Schema validation
print("\n[7/10] Schema validation...")
validate_schema(df, required_columns, "raw_data")
print("  Status: PASS")

# [8/10] Feature computation
print("\n[8/10] Feature computation...")
df = compute_features(df)
validate_schema(df, feature_columns, "features")
print(f"  Features: {len(feature_columns)}")
print("  Status: PASS")

# [9/10] Prediction sanity
print("\n[9/10] Prediction sanity...")
preds = model(X_test).detach().cpu().numpy()
pred_std = preds.std()
target_std = y_test.std()
pred_ratio = pred_std / target_std
assert pred_ratio > 0.005, f"Predictions collapsed: ratio={pred_ratio:.4f}"
print(f"  Pred std ratio: {pred_ratio:.2%}")
print("  Status: PASS")

# [10/10] Checkpoint save/load
print("\n[10/10] Checkpoint save/load...")
torch.save(model.state_dict(), checkpoint_path)
model2 = create_model(architecture, input_size=n_features)
model2.load_state_dict(torch.load(checkpoint_path))
print("  Status: PASS")
```
---
3. Schema Validation Pattern
The Problem
```python
# BAD: Cryptic error 2 hours into experiment
KeyError: 'returns_vs'  # Which file? Which function? What columns exist?
```
The Solution
```python
def validate_schema(df, required: list[str], stage: str) -> None:
    """Fail-fast schema validation with actionable error messages."""
    # Handle both DataFrame columns and DatetimeIndex
    available = list(df.columns)
    if hasattr(df.index, 'name') and df.index.name:
        available.append(df.index.name)
    missing = [c for c in required if c not in available]
    if missing:
        raise ValueError(
            f"[{stage}] Missing columns: {missing}\n"
            f"Available: {sorted(available)}\n"
            f"DataFrame shape: {df.shape}"
        )
    print(f"  Schema validation PASSED ({stage}): {len(required)} columns", flush=True)
```
Usage at pipeline boundaries
```python
REQUIRED_RAW = ["open", "high", "low", "close", "volume"]
REQUIRED_FEATURES = ["returns_vs", "momentum_z", "atr_pct", "volume_z",
                     "rsi_14", "bb_pct_b", "vol_regime", "return_accel", "pv_divergence"]

df = fetch_data(symbol)
validate_schema(df, REQUIRED_RAW, "raw_data")
df = compute_features(df)
validate_schema(df, REQUIRED_FEATURES, "features")
```
---
4. Gradient Health Checks
Basic Gradient Check
```python
def check_gradient_health(model: nn.Module, sample_input: torch.Tensor) -> dict:
    """Verify gradients flow correctly through model."""
    model.train()
    out = model(sample_input)
    loss = out.sum()
    loss.backward()

    stats = {"total_params": 0, "params_with_grad": 0, "grad_norms": []}
    for name, param in model.named_parameters():
        stats["total_params"] += 1
        if param.grad is not None:
            stats["params_with_grad"] += 1
            norm = param.grad.norm().item()
            stats["grad_norms"].append(norm)
            # Check for issues
            if not np.isfinite(norm):
                raise ValueError(f"Non-finite gradient in {name}: {norm}")
            if norm > 100:
                print(f"  WARNING: Large gradient in {name}: {norm:.2f}")
    stats["max_grad"] = max(stats["grad_norms"]) if stats["grad_norms"] else 0
    stats["mean_grad"] = np.mean(stats["grad_norms"]) if stats["grad_norms"] else 0
    return stats
```
Architecture-Specific Checks
```python
def check_lstm_gradients(model: nn.Module) -> dict:
    """Check LSTM-specific gradient patterns."""
    stats = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        # Check forget gate bias (should not be too negative)
        if "bias_hh" in name or "bias_ih" in name:
            # LSTM bias layout: [i, f, g, o] gates
            hidden_size = param.shape[0] // 4
            forget_bias = param.grad[hidden_size:2 * hidden_size]
            stats["forget_bias_grad_mean"] = forget_bias.mean().item()
        # Check hidden-to-hidden weights
        if "weight_hh" in name:
            stats["hh_weight_grad_norm"] = param.grad.norm().item()
    return stats
```
5. Prediction Sanity Checks
Collapse Detection
```python
def check_prediction_sanity(preds: np.ndarray, targets: np.ndarray) -> dict:
    """Detect prediction collapse or explosion."""
    stats = {
        "pred_mean": preds.mean(),
        "pred_std": preds.std(),
        "pred_min": preds.min(),
        "pred_max": preds.max(),
        "target_std": targets.std(),
    }
    # Relative threshold (not absolute!)
    stats["pred_std_ratio"] = stats["pred_std"] / stats["target_std"]

    # Collapse detection
    if stats["pred_std_ratio"] < 0.005:  # < 0.5% of target variance
        raise ValueError(
            f"Predictions collapsed!\n"
            f"  pred_std: {stats['pred_std']:.6f}\n"
            f"  target_std: {stats['target_std']:.6f}\n"
            f"  ratio: {stats['pred_std_ratio']:.4%}"
        )

    # Explosion detection
    if stats["pred_std_ratio"] > 100:  # > 100x target variance
        raise ValueError(
            f"Predictions exploded!\n"
            f"  pred_std: {stats['pred_std']:.2f}\n"
            f"  target_std: {stats['target_std']:.6f}\n"
            f"  ratio: {stats['pred_std_ratio']:.1f}x"
        )

    # Unique value check
    stats["unique_values"] = len(np.unique(np.round(preds, 6)))
    if stats["unique_values"] < 10:
        print(f"  WARNING: Only {stats['unique_values']} unique prediction values")
    return stats
```
Correlation Check
```python
def check_prediction_correlation(preds: np.ndarray, targets: np.ndarray) -> float:
    """Check if predictions have any correlation with targets."""
    corr = np.corrcoef(preds.flatten(), targets.flatten())[0, 1]
    if not np.isfinite(corr):
        print("  WARNING: Correlation is NaN (likely collapsed predictions)")
        return 0.0
    # Note: negative correlation may still be useful (short signal)
    print(f"  Prediction-target correlation: {corr:.4f}")
    return corr
```
6. NDJSON Logging Validation
Required Event Types
```python
REQUIRED_EVENTS = {
    "experiment_start": ["architecture", "features", "config"],
    "fold_start": ["fold_id", "train_size", "val_size", "test_size"],
    "epoch_complete": ["epoch", "train_loss", "val_loss"],
    "fold_complete": ["fold_id", "test_sharpe", "test_sortino"],
    "experiment_complete": ["total_folds", "mean_sharpe", "elapsed_seconds"],
}

def validate_ndjson_schema(log_path: Path) -> None:
    """Validate NDJSON log has all required events and fields."""
    events = {}
    with open(log_path) as f:
        for line in f:
            event = json.loads(line)
            phase = event.get("phase", "unknown")
            if phase not in events:
                events[phase] = []
            events[phase].append(event)

    for phase, required_fields in REQUIRED_EVENTS.items():
        if phase not in events:
            raise ValueError(f"Missing event type: {phase}")
        sample = events[phase][0]
        missing = [f for f in required_fields if f not in sample]
        if missing:
            raise ValueError(f"Event '{phase}' missing fields: {missing}")
    print(f"  NDJSON schema valid: {len(events)} event types")
```
7. POC Timing Guide
| Check | Typical Time | Max Time | Action if Exceeded |
|---|---|---|---|
| Model instantiation | < 1s | 5s | Check device, reduce model size |
| Gradient flow | < 2s | 10s | Check batch size |
| Schema validation | < 0.1s | 1s | Check data loading |
| Mini training (10 epochs) | < 30s | 2min | Reduce batch, check data loader |
| Full POC (10 checks) | < 2min | 5min | Something is wrong |
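One way to enforce these budgets mechanically is a small wrapper that times each check and fails when its budget is exceeded. This is a sketch; `run_with_deadline` and the check names are hypothetical:

```python
import time


def run_with_deadline(name, max_seconds, fn):
    """Run one POC check and fail loudly if it exceeds its time budget.

    A blown budget usually signals a deeper problem (slow data loader,
    oversized model), so it is treated as a failure, not a warning.
    """
    start = time.time()
    result = fn()
    elapsed = time.time() - start
    if elapsed > max_seconds:
        raise TimeoutError(f"[{name}] took {elapsed:.1f}s (budget: {max_seconds}s)")
    print(f"  [{name}] {elapsed:.2f}s / {max_seconds}s budget", flush=True)
    return result


# Budgets from the table above, applied to stand-in check functions.
run_with_deadline("model_instantiation", 5, lambda: "model")
run_with_deadline("schema_validation", 1, lambda: True)
```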
8. Failure Response Guide
| Failure | Likely Cause | Fix |
|---|---|---|
| Shape mismatch | Wrong input_size or seq_len | Check feature count |
| NaN gradients | LR too high, bad init | Reduce LR, check init |
| Zero gradients | Dead layers, missing params | Check model architecture |
| Predictions collapsed | Normalizer issue, bad loss | Check sLSTM normalizer |
| Predictions exploded | Gradient explosion | Add/tighten gradient clipping |
| Schema missing columns | Wrong data source | Check fetch function |
| Checkpoint load fails | State dict key mismatch | Check model architecture match |
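For the "predictions exploded" row, the usual fix is norm-based gradient clipping between `backward()` and `step()`. A minimal sketch with a toy linear model standing in for the real architecture, showing where the call goes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(8, 1)  # toy stand-in for the real architecture
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)

x, y = torch.randn(32, 8), torch.randn(32, 1)
optimizer.zero_grad()
loss = F.mse_loss(model(x), y)
loss.backward()

# Clip the total gradient norm BEFORE the optimizer step. The return
# value is the pre-clip norm -- log it to confirm the explosion.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"  Pre-clip grad norm: {float(total_norm):.4f}", flush=True)
optimizer.step()
```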
9. Integration Example
```python
def main():
    # Parse args, setup output dir...

    # PHASE 1: Fail-fast POC
    print("=" * 60)
    print("FAIL-FAST POC VALIDATION")
    print("=" * 60)
    try:
        run_poc_validation()
    except Exception as e:
        print(f"\n{'=' * 60}")
        print(f"POC FAILED: {type(e).__name__}")
        print(f"{'=' * 60}")
        print(f"Error: {e}")
        print("\nFix the issue before running full experiment.")
        sys.exit(1)

    # PHASE 2: Full experiment (only if POC passes)
    print("\n" + "=" * 60)
    print("STARTING FULL EXPERIMENT")
    print("=" * 60)
    run_full_experiment()
```
10. Anti-Patterns to Avoid
DON'T: Skip validation to "save time"
```python
# BAD: "I'll just run it and see"
run_full_experiment()  # 4 hours later: crash
```
DON'T: Use absolute thresholds for relative quantities
```python
# BAD: Absolute threshold
assert pred_std > 1e-4  # Meaningless for returns ~0.001

# GOOD: Relative threshold
assert pred_std / target_std > 0.005  # 0.5% of target variance
```
DON'T: Catch all exceptions silently
```python
# BAD: Hides real issues
try:
    result = risky_operation()
except Exception:
    result = default_value  # What went wrong?

# GOOD: Catch specific exceptions
try:
    result = risky_operation()
except (ValueError, RuntimeError) as e:
    logger.error(f"Operation failed: {e}")
    raise
```
DON'T: Print without flush
```python
# BAD: Output buffered, can't see progress
print(f"Processing fold {i}...")

# GOOD: See output immediately
print(f"Processing fold {i}...", flush=True)
```
---
References
Troubleshooting
| Issue | Cause | Solution |
|---|---|---|
| NaN gradients in POC | Learning rate too high | Reduce LR by 10x, check weight initialization |
| Zero gradients | Dead layers or missing params | Check model architecture, verify requires_grad=True |
| Predictions collapsed | Normalizer issue or bad loss | Check target normalization, verify loss function |
| Predictions exploded | Gradient explosion | Add gradient clipping, reduce learning rate |
| Schema missing columns | Wrong data source or transform | Verify fetch function returns expected columns |
| Checkpoint load fails | State dict key mismatch | Ensure model architecture matches saved checkpoint |
| POC timeout (>5 min) | Data loading or model too large | Reduce batch size, check DataLoader num_workers |
| Mini training no progress | Learning rate too low or frozen | Increase LR, verify optimizer updates all parameters |
| NDJSON validation fails | Missing required event types | Check all phases emit expected fields |
| Shape mismatch error | Wrong input_size or seq_len | Verify feature count matches model input dimension |
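For the "zero gradients" and "mini training no progress" rows, a quick fail-fast check is to scan for frozen parameters before training starts. A sketch, assuming the PyTorch stack used elsewhere in this document; `assert_all_trainable` is not part of any library:

```python
import torch.nn as nn


def assert_all_trainable(model: nn.Module) -> None:
    """Fail fast if any parameter tensor has requires_grad=False."""
    frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
    if frozen:
        raise ValueError(f"Frozen parameters: {frozen}")
    n = sum(1 for _ in model.parameters())
    print(f"  All {n} parameter tensors trainable", flush=True)


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
assert_all_trainable(model)  # passes

model[0].weight.requires_grad = False  # simulate an accidental freeze
try:
    assert_all_trainable(model)
except ValueError as e:
    print(f"  Caught: {e}", flush=True)
```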