pyhealth


PyHealth: Healthcare AI Toolkit


Overview


PyHealth is a comprehensive Python library for healthcare AI that provides specialized tools, models, and datasets for clinical machine learning. Use this skill when developing healthcare prediction models, processing clinical data, working with medical coding systems, or deploying AI solutions in healthcare settings.

When to Use This Skill


Invoke this skill when:
  • Working with healthcare datasets: MIMIC-III, MIMIC-IV, eICU, OMOP, sleep EEG data, medical images
  • Clinical prediction tasks: Mortality prediction, hospital readmission, length of stay, drug recommendation
  • Medical coding: Translating between ICD-9/10, NDC, RxNorm, ATC coding systems
  • Processing clinical data: Sequential events, physiological signals, clinical text, medical images
  • Implementing healthcare models: RETAIN, SafeDrug, GAMENet, StageNet, Transformer for EHR
  • Evaluating clinical models: Fairness metrics, calibration, interpretability, uncertainty quantification

Core Capabilities


PyHealth operates through a modular 5-stage pipeline optimized for healthcare AI:
  1. Data Loading: Access 10+ healthcare datasets with standardized interfaces
  2. Task Definition: Apply 20+ predefined clinical prediction tasks or create custom tasks
  3. Model Selection: Choose from 33+ models (baselines, deep learning, healthcare-specific)
  4. Training: Train with automatic checkpointing, monitoring, and evaluation
  5. Deployment: Calibrate, interpret, and validate for clinical use
Performance: the project reports healthcare data processing roughly 3x faster than equivalent pandas pipelines

Quick Start Workflow


```python
from pyhealth.datasets import MIMIC4Dataset, split_by_patient, get_dataloader
from pyhealth.tasks import mortality_prediction_mimic4_fn
from pyhealth.models import Transformer
from pyhealth.trainer import Trainer

# 1. Load dataset and set task
dataset = MIMIC4Dataset(root="/path/to/data")
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)

# 2. Split data
train, val, test = split_by_patient(sample_dataset, [0.7, 0.1, 0.2])

# 3. Create data loaders
train_loader = get_dataloader(train, batch_size=64, shuffle=True)
val_loader = get_dataloader(val, batch_size=64, shuffle=False)
test_loader = get_dataloader(test, batch_size=64, shuffle=False)

# 4. Initialize and train model
model = Transformer(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "medications"],
    mode="binary",
    embedding_dim=128,
)
trainer = Trainer(model=model, device="cuda")
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=50,
    monitor="pr_auc_score",
)

# 5. Evaluate
results = trainer.evaluate(test_loader)
```

Detailed Documentation


This skill includes comprehensive reference documentation organized by functionality. Read specific reference files as needed:

1. Datasets and Data Structures


File: references/datasets.md
Read when:
  • Loading healthcare datasets (MIMIC, eICU, OMOP, sleep EEG, etc.)
  • Understanding Event, Patient, Visit data structures
  • Processing different data types (EHR, signals, images, text)
  • Splitting data for training/validation/testing
  • Working with SampleDataset for task-specific formatting
Key Topics:
  • Core data structures (Event, Patient, Visit)
  • 10+ available datasets (EHR, physiological signals, imaging, text)
  • Data loading and iteration
  • Train/val/test splitting strategies
  • Performance optimization for large datasets
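The Event → Visit → Patient nesting can be pictured with plain dataclasses. This is an illustrative sketch, not PyHealth's actual classes; the field names here are assumptions for the example:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Event:
    code: str          # e.g. an ICD-9 diagnosis code
    vocabulary: str    # coding system the code belongs to
    timestamp: str     # when the event was recorded

@dataclass
class Visit:
    visit_id: str
    events: List[Event] = field(default_factory=list)

@dataclass
class Patient:
    patient_id: str
    visits: List[Visit] = field(default_factory=list)

# Build a toy patient with one visit holding two coded events
visit = Visit("v1", [Event("428.0", "ICD9CM", "2019-03-01"),
                     Event("250.00", "ICD9CM", "2019-03-01")])
patient = Patient("p1", [visit])
print(len(patient.visits), len(patient.visits[0].events))  # 1 2
```

Iterating a dataset then amounts to walking patients, their visits, and each visit's events in order.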

2. Medical Coding Translation


File: references/medical_coding.md
Read when:
  • Translating between medical coding systems
  • Working with diagnosis codes (ICD-9-CM, ICD-10-CM, CCS)
  • Processing medication codes (NDC, RxNorm, ATC)
  • Standardizing procedure codes (ICD-9-PROC, ICD-10-PROC)
  • Grouping codes into clinical categories
  • Handling hierarchical drug classifications
Key Topics:
  • InnerMap for within-system lookups
  • CrossMap for cross-system translation
  • Supported coding systems (ICD, NDC, ATC, CCS, RxNorm)
  • Code standardization and hierarchy traversal
  • Medication classification by therapeutic class
  • Integration with datasets
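To make the CrossMap idea concrete, here is a self-contained stand-in built on a hand-written two-entry table. The real library loads full mapping files; the target categories below are illustrative values only:

```python
# Tiny stand-in for CrossMap: maps source-system codes to lists of target codes.
# The two entries below are illustrative; PyHealth ships complete mapping tables.
ICD9_TO_CCS = {
    "428.0": ["108"],    # heart failure -> a CCS category (example value)
    "250.00": ["49"],    # diabetes without complication (example value)
}

def map_code(code: str, table: dict) -> list:
    """Translate one code; return an empty list when no mapping exists."""
    return table.get(code, [])

print(map_code("428.0", ICD9_TO_CCS))   # ['108']
print(map_code("999.9", ICD9_TO_CCS))   # []
```

In PyHealth itself the equivalent lookup goes through the pyhealth.medcode module, which also provides InnerMap for within-system metadata and hierarchy traversal.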

3. Clinical Prediction Tasks


File: references/tasks.md
Read when:
  • Defining clinical prediction objectives
  • Using predefined tasks (mortality, readmission, drug recommendation)
  • Working with EHR, signal, imaging, or text-based tasks
  • Creating custom prediction tasks
  • Setting up input/output schemas for models
  • Applying task-specific filtering logic
Key Topics:
  • 20+ predefined clinical tasks
  • EHR tasks (mortality, readmission, length of stay, drug recommendation)
  • Signal tasks (sleep staging, EEG analysis, seizure detection)
  • Imaging tasks (COVID-19 chest X-ray classification)
  • Text tasks (medical coding, specialty classification)
  • Custom task creation patterns
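Custom tasks follow a simple pattern: a function from one patient record to a list of sample dicts. A minimal sketch of that pattern, using a plain dict in place of the library's Patient object (the keys and label logic here are invented for illustration):

```python
def mortality_task_sketch(patient: dict) -> list:
    """Emit one sample per visit, labeled by the outcome of the next visit."""
    samples = []
    visits = patient["visits"]
    # Skip the last visit: its label would need follow-up data we don't have
    for i in range(len(visits) - 1):
        samples.append({
            "patient_id": patient["patient_id"],
            "diagnoses": visits[i]["diagnoses"],
            "label": int(visits[i + 1]["died"]),
        })
    return samples

toy_patient = {
    "patient_id": "p1",
    "visits": [
        {"diagnoses": ["428.0"], "died": False},
        {"diagnoses": ["428.0", "584.9"], "died": True},
    ],
}
print(mortality_task_sketch(toy_patient))
# one sample: features from visit 0, label from visit 1
```

Task-specific filtering (e.g. dropping patients with a single visit) naturally lives inside the same function, since returning an empty list excludes the patient.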

4. Models and Architectures


File: references/models.md
Read when:
  • Selecting models for clinical prediction
  • Understanding model architectures and capabilities
  • Choosing between general-purpose and healthcare-specific models
  • Implementing interpretable models (RETAIN, AdaCare)
  • Working with medication recommendation (SafeDrug, GAMENet)
  • Using graph neural networks for healthcare
  • Configuring model hyperparameters
Key Topics:
  • 33+ available models
  • General-purpose: Logistic Regression, MLP, CNN, RNN, Transformer, GNN
  • Healthcare-specific: RETAIN, SafeDrug, GAMENet, StageNet, AdaCare
  • Model selection by task type and data type
  • Interpretability considerations
  • Computational requirements
  • Hyperparameter tuning guidelines

5. Data Preprocessing


File: references/preprocessing.md
Read when:
  • Preprocessing clinical data for models
  • Handling sequential events and time-series data
  • Processing physiological signals (EEG, ECG)
  • Normalizing lab values and vital signs
  • Preparing labels for different task types
  • Building feature vocabularies
  • Managing missing data and outliers
Key Topics:
  • 15+ processor types
  • Sequence processing (padding, truncation)
  • Signal processing (filtering, segmentation)
  • Feature extraction and encoding
  • Label processors (binary, multi-class, multi-label, regression)
  • Text and image preprocessing
  • Common preprocessing workflows
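Sequence padding and truncation, the most common of these steps, can be sketched in a few lines (a generic illustration; PyHealth's processors handle this internally):

```python
def pad_or_truncate(seq, max_len, pad_token=0):
    """Right-pad short sequences; keep the most recent events for long ones."""
    if len(seq) >= max_len:
        return seq[-max_len:]          # truncate from the front (oldest events)
    return seq + [pad_token] * (max_len - len(seq))

print(pad_or_truncate([5, 7, 9], 5))           # [5, 7, 9, 0, 0]
print(pad_or_truncate([1, 2, 3, 4, 5, 6], 4))  # [3, 4, 5, 6]
```

Truncating from the front is a common choice for clinical event streams, since the most recent visits usually carry the most predictive signal.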

6. Training and Evaluation


File: references/training_evaluation.md
Read when:
  • Training models with the Trainer class
  • Evaluating model performance
  • Computing clinical metrics
  • Assessing model fairness across demographics
  • Calibrating predictions for reliability
  • Quantifying prediction uncertainty
  • Interpreting model predictions
  • Preparing models for clinical deployment
Key Topics:
  • Trainer class (train, evaluate, inference)
  • Metrics for binary, multi-class, multi-label, regression tasks
  • Fairness metrics for bias assessment
  • Calibration methods (Platt scaling, temperature scaling)
  • Uncertainty quantification (conformal prediction, MC dropout)
  • Interpretability tools (attention visualization, SHAP, ChEFER)
  • Complete training pipeline example

Installation


```bash
uv pip install pyhealth
```
Requirements:
  • Python ≥ 3.7
  • PyTorch ≥ 1.8
  • NumPy, pandas, scikit-learn

Common Use Cases


Use Case 1: ICU Mortality Prediction


Objective: Predict patient mortality in intensive care unit
Approach:
  1. Load MIMIC-IV dataset → Read references/datasets.md
  2. Apply mortality prediction task → Read references/tasks.md
  3. Select interpretable model (RETAIN) → Read references/models.md
  4. Train and evaluate → Read references/training_evaluation.md
  5. Interpret predictions for clinical use → Read references/training_evaluation.md

Use Case 2: Safe Medication Recommendation


Objective: Recommend medications while avoiding drug-drug interactions
Approach:
  1. Load EHR dataset (MIMIC-IV or OMOP) → Read references/datasets.md
  2. Apply drug recommendation task → Read references/tasks.md
  3. Use SafeDrug model with DDI constraints → Read references/models.md
  4. Preprocess medication codes → Read references/medical_coding.md
  5. Evaluate with multi-label metrics → Read references/training_evaluation.md

Use Case 3: Hospital Readmission Prediction


Objective: Identify patients at risk of 30-day readmission
Approach:
  1. Load multi-site EHR data (eICU or OMOP) → Read references/datasets.md
  2. Apply readmission prediction task → Read references/tasks.md
  3. Handle class imbalance in preprocessing → Read references/preprocessing.md
  4. Train Transformer model → Read references/models.md
  5. Calibrate predictions and assess fairness → Read references/training_evaluation.md

Use Case 4: Sleep Disorder Diagnosis


Objective: Classify sleep stages from EEG signals
Approach:
  1. Load sleep EEG dataset (SleepEDF, SHHS) → Read references/datasets.md
  2. Apply sleep staging task → Read references/tasks.md
  3. Preprocess EEG signals (filtering, segmentation) → Read references/preprocessing.md
  4. Train CNN or RNN model → Read references/models.md
  5. Evaluate per-stage performance → Read references/training_evaluation.md

Use Case 5: Medical Code Translation


Objective: Standardize diagnoses across different coding systems
Approach:
  1. Read references/medical_coding.md for comprehensive guidance
  2. Use CrossMap to translate between ICD-9, ICD-10, CCS
  3. Group codes into clinically meaningful categories
  4. Integrate with dataset processing

Use Case 6: Clinical Text to ICD Coding


Objective: Automatically assign ICD codes from clinical notes
Approach:
  1. Load MIMIC-III with clinical text → Read references/datasets.md
  2. Apply ICD coding task → Read references/tasks.md
  3. Preprocess clinical text → Read references/preprocessing.md
  4. Use TransformersModel (ClinicalBERT) → Read references/models.md
  5. Evaluate with multi-label metrics → Read references/training_evaluation.md

Best Practices


Data Handling


  1. Always split by patient: Prevent data leakage by ensuring no patient appears in multiple splits
     ```python
     from pyhealth.datasets import split_by_patient
     train, val, test = split_by_patient(dataset, [0.7, 0.1, 0.2])
     ```
  2. Check dataset statistics: Understand your data before modeling
     ```python
     print(dataset.stats())  # Patients, visits, events, code distributions
     ```
  3. Use appropriate preprocessing: Match processors to data types (see references/preprocessing.md)
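What split_by_patient guards against can be seen in a stdlib sketch that groups sample indices by patient before splitting (illustrative logic, not the library's implementation):

```python
import random

def split_by_patient_sketch(sample_patient_ids, ratios, seed=42):
    """Split whole patients, so no patient's samples straddle two splits.

    sample_patient_ids: the patient id of each sample, in dataset order.
    Returns three lists of sample indices (train/val/test).
    """
    by_patient = {}
    for idx, pid in enumerate(sample_patient_ids):
        by_patient.setdefault(pid, []).append(idx)
    patients = sorted(by_patient)
    random.Random(seed).shuffle(patients)
    n = len(patients)
    n_train = int(ratios[0] * n)
    n_val = int(ratios[1] * n)
    groups = (patients[:n_train],
              patients[n_train:n_train + n_val],
              patients[n_train + n_val:])
    return [[i for p in grp for i in by_patient[p]] for grp in groups]

ids = ["p1", "p1", "p2", "p3", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10"]
train, val, test = split_by_patient_sketch(ids, [0.7, 0.1, 0.2])
print(len(train), len(val), len(test))  # every sample lands in exactly one split
```

A naive row-level split would let p1's first visit train a model that is then tested on p1's second visit, inflating the measured performance.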

Model Development


  1. Start with baselines: Establish baseline performance with simple models
    • Logistic Regression for binary/multi-class tasks
    • MLP for initial deep learning baseline
  2. Choose task-appropriate models:
    • Interpretability needed → RETAIN, AdaCare
    • Drug recommendation → SafeDrug, GAMENet
    • Long sequences → Transformer
    • Graph relationships → GNN
  3. Monitor validation metrics: Use appropriate metrics for task and handle class imbalance
    • Binary classification: AUROC, AUPRC (especially for rare events)
    • Multi-class: macro-F1 (for imbalanced), weighted-F1
    • Multi-label: Jaccard, example-F1
    • Regression: MAE, RMSE
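For intuition on why AUPRC matters for rare events, average precision (the usual AUPRC estimator) fits in a few lines of stdlib Python. A teaching sketch; in practice use the library's metric functions:

```python
def average_precision(labels, scores):
    """Area under the precision-recall curve; rewards ranking positives early."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    tp, ap, n_pos = 0, 0.0, sum(labels)
    for rank, i in enumerate(order, start=1):
        if labels[i] == 1:
            tp += 1
            ap += tp / rank          # precision at each positive's rank
    return ap / n_pos

# One rare positive: ranking it first vs. last changes AP drastically,
# while accuracy on such an imbalanced set would barely move.
print(average_precision([1, 0, 0, 0], [0.9, 0.8, 0.2, 0.1]))  # 1.0
print(average_precision([1, 0, 0, 0], [0.1, 0.8, 0.2, 0.9]))  # 0.25
```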

Clinical Deployment


  1. Calibrate predictions: Ensure probabilities are reliable (see references/training_evaluation.md)
  2. Assess fairness: Evaluate across demographic groups to detect bias
  3. Quantify uncertainty: Provide confidence estimates for predictions
  4. Interpret predictions: Use attention weights, SHAP, or ChEFER for clinical trust
  5. Validate thoroughly: Use held-out test sets from different time periods or sites
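Temperature scaling, one of the calibration methods referenced above, can be sketched as a grid search for the temperature minimizing validation NLL. A toy illustration with invented logits; the library's calibration utilities do this properly:

```python
import math

def nll(logits, labels, T):
    """Binary negative log-likelihood after dividing logits by temperature T."""
    total = 0.0
    for z, y in zip(logits, labels):
        p = 1.0 / (1.0 + math.exp(-z / T))
        p = min(max(p, 1e-12), 1 - 1e-12)      # clip to avoid log(0)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(logits)

def fit_temperature(logits, labels):
    """Grid-search the temperature that minimizes validation NLL."""
    grid = [t / 10 for t in range(1, 51)]      # T in 0.1 .. 5.0
    return min(grid, key=lambda T: nll(logits, labels, T))

# Overconfident toy logits: large magnitudes, but labels only ~75% consistent
val_logits = [4.0, 3.5, -4.0, 3.8, -3.6, 4.2, -3.9, 3.7]
val_labels = [1, 1, 0, 0, 0, 1, 1, 1]
T = fit_temperature(val_logits, val_labels)
print(T > 1.0)  # a temperature above 1 softens overconfident probabilities
```

Crucially, dividing all logits by the same T leaves the ranking (and therefore AUROC/AUPRC) unchanged; only the probability values move.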

Limitations and Considerations


Data Requirements


  • Large datasets: Deep learning models require sufficient data (thousands of patients)
  • Data quality: Missing data and coding errors impact performance
  • Temporal consistency: Ensure train/test split respects temporal ordering when needed

Clinical Validation


  • External validation: Test on data from different hospitals/systems
  • Prospective evaluation: Validate in real clinical settings before deployment
  • Clinical review: Have clinicians review predictions and interpretations
  • Ethical considerations: Address privacy (HIPAA/GDPR), fairness, and safety

Computational Resources


  • GPU recommended: For training deep learning models efficiently
  • Memory requirements: Large datasets may require 16GB+ RAM
  • Storage: Healthcare datasets can be 10s-100s of GB

Troubleshooting


Common Issues


ImportError for dataset:
  • Ensure dataset files are downloaded and path is correct
  • Check PyHealth version compatibility
Out of memory:
  • Reduce batch size
  • Reduce sequence length (max_seq_length)
  • Use gradient accumulation
  • Process data in chunks
Poor performance:
  • Check class imbalance and use appropriate metrics (AUPRC vs AUROC)
  • Verify preprocessing (normalization, missing data handling)
  • Increase model capacity or training epochs
  • Check for data leakage in train/test split
Slow training:
  • Use GPU (device="cuda")
  • Increase batch size (if memory allows)
  • Reduce sequence length
  • Use a more efficient model (CNN vs Transformer)
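The gradient-accumulation remedy for out-of-memory errors can be illustrated numerically: summing averaged micro-batch gradients reproduces the full-batch update while only one micro-batch is in memory at a time. A toy scalar example, not PyHealth/PyTorch API:

```python
def grad(w, batch):
    """Gradient of the mean squared loss 0.5*(w - x)^2 over a micro-batch."""
    return sum(w - x for x in batch) / len(batch)

data = [1.0, 2.0, 3.0, 4.0]
w, lr, accum_steps = 0.0, 0.1, 2
micro_batches = [data[:2], data[2:]]       # two micro-batches of size 2

g_accum = 0.0
for mb in micro_batches:
    g_accum += grad(w, mb) / accum_steps   # average across micro-batches
w -= lr * g_accum                          # single update, full-batch gradient

print(round(w, 6))  # 0.25, same as one step on the full batch of 4
```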


Example: Complete Workflow


```python
# Complete mortality prediction pipeline
from pyhealth.datasets import MIMIC4Dataset, split_by_patient, get_dataloader
from pyhealth.tasks import mortality_prediction_mimic4_fn
from pyhealth.models import RETAIN
from pyhealth.trainer import Trainer

# 1. Load dataset
print("Loading MIMIC-IV dataset...")
dataset = MIMIC4Dataset(root="/data/mimic4")
print(dataset.stats())

# 2. Define task
print("Setting mortality prediction task...")
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)
print(f"Generated {len(sample_dataset)} samples")

# 3. Split data (by patient to prevent leakage)
print("Splitting data...")
train_ds, val_ds, test_ds = split_by_patient(
    sample_dataset, ratios=[0.7, 0.1, 0.2], seed=42
)

# 4. Create data loaders
train_loader = get_dataloader(train_ds, batch_size=64, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=64)
test_loader = get_dataloader(test_ds, batch_size=64)

# 5. Initialize interpretable model
print("Initializing RETAIN model...")
model = RETAIN(
    dataset=sample_dataset,
    feature_keys=["diagnoses", "procedures", "medications"],
    mode="binary",
    embedding_dim=128,
    hidden_dim=128,
)

# 6. Train model
print("Training model...")
trainer = Trainer(model=model, device="cuda")
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=50,
    optimizer="Adam",
    learning_rate=1e-3,
    weight_decay=1e-5,
    monitor="pr_auc_score",  # Use AUPRC for imbalanced data
    monitor_criterion="max",
    save_path="./checkpoints/mortality_retain",
)

# 7. Evaluate on test set
print("Evaluating on test set...")
test_results = trainer.evaluate(
    test_loader,
    metrics=["accuracy", "precision", "recall", "f1_score",
             "roc_auc_score", "pr_auc_score"],
)
print("\nTest Results:")
for metric, value in test_results.items():
    print(f"  {metric}: {value:.4f}")

# 8. Get predictions with attention for interpretation
predictions = trainer.inference(
    test_loader,
    additional_outputs=["visit_attention", "feature_attention"],
    return_patient_ids=True,
)

# 9. Analyze a high-risk patient
high_risk_idx = predictions["y_pred"].argmax()
patient_id = predictions["patient_ids"][high_risk_idx]
visit_attn = predictions["visit_attention"][high_risk_idx]
feature_attn = predictions["feature_attention"][high_risk_idx]
print(f"\nHigh-risk patient: {patient_id}")
print(f"Risk score: {predictions['y_pred'][high_risk_idx]:.3f}")
print(f"Most influential visit: {visit_attn.argmax()}")
print(f"Most important features: {feature_attn[visit_attn.argmax()].argsort()[-5:]}")

# 10. Save model for deployment
trainer.save("./models/mortality_retain_final.pt")
print("\nModel saved successfully!")
```

Resources


For detailed information on each component, refer to the comprehensive reference files in the references/ directory:
  • datasets.md: Data structures, loading, and splitting (4,500 words)
  • medical_coding.md: Code translation and standardization (3,800 words)
  • tasks.md: Clinical prediction tasks and custom task creation (4,200 words)
  • models.md: Model architectures and selection guidelines (5,100 words)
  • preprocessing.md: Data processors and preprocessing workflows (4,600 words)
  • training_evaluation.md: Training, metrics, calibration, interpretability (5,900 words)
Total comprehensive documentation: ~28,000 words across modular reference files.