ml-pipeline-automation


ML Pipeline Automation


Orchestrate end-to-end machine learning workflows from data ingestion to production deployment with production-tested Airflow, Kubeflow, and MLflow patterns.

When to Use This Skill


Load this skill when:
  • Building ML Pipelines: Orchestrating data → train → deploy workflows
  • Scheduling Retraining: Setting up automated model retraining schedules
  • Experiment Tracking: Tracking experiments, parameters, metrics across runs
  • MLOps Implementation: Building reproducible, monitored ML infrastructure
  • Workflow Orchestration: Managing complex multi-step ML workflows
  • Model Registry: Managing model versions and deployment lifecycle

Quick Start: ML Pipeline in 5 Steps



1. Install Airflow and MLflow (check for latest versions at time of use)


pip install apache-airflow==3.1.5 mlflow==3.7.0

Note: These versions are current as of December 2025


Check PyPI for latest stable releases: https://pypi.org/project/apache-airflow/


2. Initialize Airflow database


airflow db migrate

3. Create DAG file: dags/ml_training_pipeline.py


cat > dags/ml_training_pipeline.py << 'EOF'
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta

default_args = {
    'owner': 'ml-team',
    'retries': 2,
    'retry_delay': timedelta(minutes=5)
}

dag = DAG(
    'ml_training_pipeline',
    default_args=default_args,
    schedule='@daily',
    start_date=datetime(2025, 1, 1)
)

def train_model(**context):
    import mlflow
    import mlflow.sklearn
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split

    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

    mlflow.set_tracking_uri('http://localhost:5000')
    mlflow.set_experiment('iris-training')

    with mlflow.start_run():
        model = RandomForestClassifier(n_estimators=100)
        model.fit(X_train, y_train)

        accuracy = model.score(X_test, y_test)
        mlflow.log_metric('accuracy', accuracy)
        mlflow.sklearn.log_model(model, 'model')

train = PythonOperator(
    task_id='train_model',
    python_callable=train_model,
    dag=dag
)
EOF

4. Start Airflow scheduler and webserver


airflow scheduler & airflow api-server --port 8080 &

5. Trigger pipeline


airflow dags trigger ml_training_pipeline

Open the UI at: http://localhost:8080


**Result**: Working ML pipeline with experiment tracking in under 5 minutes.


Core Concepts


Pipeline Stages


  1. Data Collection → Fetch raw data from sources
  2. Data Validation → Check schema, quality, distributions
  3. Feature Engineering → Transform raw data to features
  4. Model Training → Train with hyperparameter tuning
  5. Model Evaluation → Validate performance on test set
  6. Model Deployment → Push to production if metrics pass
  7. Monitoring → Track drift, performance in production
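The seven stages above can be sketched, framework-free, as plain functions chained in order. Every function name and value here is illustrative, not part of any library; the monitoring stage, which runs continuously after deployment, is omitted.

```python
def collect_data():
    # stage 1: fetch raw rows (hardcoded here for illustration)
    return [{"feature": 1.0, "target": 0}, {"feature": 2.0, "target": 1}]

def validate_data(rows):
    # stage 2: fail fast on empty or malformed input
    assert len(rows) > 0, "empty dataset"
    return rows

def engineer_features(rows):
    # stage 3: raw rows -> (feature, label) pairs
    return [(r["feature"] * 2.0, r["target"]) for r in rows]

def train_model(features):
    # stage 4: stand-in for real training (a mean-threshold "model")
    return {"threshold": sum(x for x, _ in features) / len(features)}

def evaluate_model(model, features):
    # stage 5: accuracy check (reusing training data only for brevity)
    preds = [int(x > model["threshold"]) for x, _ in features]
    return sum(p == y for p, (_, y) in zip(preds, features)) / len(features)

def run_pipeline(min_accuracy=0.9):
    rows = validate_data(collect_data())     # stages 1-2
    feats = engineer_features(rows)          # stage 3
    model = train_model(feats)               # stage 4
    accuracy = evaluate_model(model, feats)  # stage 5
    return accuracy >= min_accuracy          # stage 6: deploy gate
```

A real pipeline replaces each body with I/O, training, and deployment calls; the point is the linear data contract between stages, which orchestrators like Airflow make explicit as task dependencies.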

Orchestration Tools Comparison


| Tool | Best For | Strengths |
|------|----------|-----------|
| Airflow | General ML workflows | Mature, flexible, Python-native |
| Kubeflow | Kubernetes-native ML | Container-based, scalable |
| MLflow | Experiment tracking | Model registry, versioning |
| Prefect | Modern Python workflows | Dynamic DAGs, native caching |
| Dagster | Asset-oriented pipelines | Data-aware, testable |

Basic Airflow DAG


python
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta
import logging

logger = logging.getLogger(__name__)

default_args = {
    'owner': 'ml-team',
    'depends_on_past': False,
    'email': ['alerts@example.com'],
    'email_on_failure': True,
    'retries': 2,
    'retry_delay': timedelta(minutes=5)
}

dag = DAG(
    'ml_training_pipeline',
    default_args=default_args,
    description='End-to-end ML training pipeline',
    schedule='@daily',
    start_date=datetime(2025, 1, 1),
    catchup=False
)

def validate_data(**context):
    """Validate input data quality."""
    import pandas as pd

    data_path = "/data/raw/latest.csv"
    df = pd.read_csv(data_path)

    # Validation checks
    assert len(df) > 1000, f"Insufficient data: {len(df)} rows"
    assert df.isnull().sum().sum() < len(df) * 0.1, "Too many nulls"

    context['ti'].xcom_push(key='data_path', value=data_path)
    logger.info(f"Data validation passed: {len(df)} rows")

def train_model(**context):
    """Train ML model with MLflow tracking."""
    import mlflow
    import mlflow.sklearn
    from sklearn.ensemble import RandomForestClassifier

    data_path = context['ti'].xcom_pull(key='data_path', task_ids='validate_data')

    mlflow.set_tracking_uri('http://mlflow:5000')
    mlflow.set_experiment('production-training')

    with mlflow.start_run():
        # Training logic here
        model = RandomForestClassifier(n_estimators=100)
        # model.fit(X, y) ...

        mlflow.log_param('n_estimators', 100)
        mlflow.sklearn.log_model(model, 'model')

validate = PythonOperator(
    task_id='validate_data',
    python_callable=validate_data,
    dag=dag
)

train = PythonOperator(
    task_id='train_model',
    python_callable=train_model,
    dag=dag
)

validate >> train

Preventing Known Issues


1. Task Failures Without Alerts


Problem: Pipeline fails silently, no one notices until users complain.
Solution: Configure email/Slack alerts on failure:
python
default_args = {
    'email': ['ml-team@example.com'],
    'email_on_failure': True,
    'email_on_retry': False
}

def on_failure_callback(context):
    """Send Slack alert on failure."""
    from airflow.providers.slack.operators.slack_webhook import SlackWebhookOperator

    slack_msg = f"""
    :red_circle: Task Failed: {context['task_instance'].task_id}
    DAG: {context['task_instance'].dag_id}
    Execution Date: {context['ds']}
    Error: {context.get('exception')}
    """

    SlackWebhookOperator(
        task_id='slack_alert',
        slack_webhook_conn_id='slack_webhook',
        message=slack_msg
    ).execute(context)

task = PythonOperator(
    task_id='critical_task',
    python_callable=my_function,
    on_failure_callback=on_failure_callback,
    dag=dag
)

2. Missing XCom Data Between Tasks


Problem: Task expects XCom value from previous task, gets None, crashes.
Solution: Always validate XCom pulls:
python
def process_data(**context):
    data_path = context['ti'].xcom_pull(
        key='data_path',
        task_ids='upstream_task'
    )

    if data_path is None:
        raise ValueError("No data_path from upstream_task - check XCom push")

    # Process data...

3. DAG Not Appearing in UI


Problem: DAG file exists in dags/ but doesn't show in Airflow UI.
Solution: Check DAG parsing errors:
bash
# Check for syntax errors
python dags/my_dag.py

# View DAG import errors in the UI:
# Navigate to: Browse → DAG Import Errors

Common fixes:
  1. Ensure a DAG object is defined in the file
  2. Check for circular imports
  3. Verify all dependencies are installed
  4. Fix syntax errors

4. Hardcoded Paths Break in Production


Problem: Paths like /Users/myname/data/ work locally but fail in production.
Solution: Use Airflow Variables or environment variables:
python
from airflow.models import Variable

def load_data(**context):
    # ❌ Bad: Hardcoded path
    # data_path = "/Users/myname/data/train.csv"

    # ✅ Good: Use Airflow Variable
    data_dir = Variable.get("data_directory", "/data")
    data_path = f"{data_dir}/train.csv"

    # Or use environment variable
    import os
    data_path = os.getenv("DATA_PATH", "/data/train.csv")

5. Stuck Tasks Consume Resources


Problem: Task hangs indefinitely, blocks worker slot, wastes resources.
Solution: Set execution_timeout on tasks:
python
from datetime import timedelta

task = PythonOperator(
    task_id='long_running_task',
    python_callable=my_function,
    execution_timeout=timedelta(hours=2),  # Kill after 2 hours
    dag=dag
)

6. No Data Validation = Bad Model Training


Problem: Train on corrupted/incomplete data, model performs poorly in production.
Solution: Add data quality validation tasks:
python
def validate_data_quality(**context):
    """Comprehensive data validation."""
    import pandas as pd

    # path produced by an upstream task (see issue 2: validate the pull)
    data_path = context['ti'].xcom_pull(key='data_path')
    df = pd.read_csv(data_path)

    # Schema validation
    required_cols = ['user_id', 'timestamp', 'feature_a', 'target']
    missing_cols = set(required_cols) - set(df.columns)
    if missing_cols:
        raise ValueError(f"Missing columns: {missing_cols}")

    # Statistical validation
    if df['target'].isnull().sum() > 0:
        raise ValueError("Target column contains nulls")

    if len(df) < 1000:
        raise ValueError(f"Insufficient data: {len(df)} rows")

    logger.info("✅ Data quality validation passed")

7. Untracked Experiments = Lost Knowledge


Problem: Can't reproduce results, don't know which hyperparameters worked.
Solution: Use MLflow for all experiments:
python
import mlflow

mlflow.set_tracking_uri('http://mlflow:5000')
mlflow.set_experiment('model-experiments')

with mlflow.start_run(run_name='rf_v1'):
    # Log ALL hyperparameters
    mlflow.log_params({
        'model_type': 'random_forest',
        'n_estimators': 100,
        'max_depth': 10,
        'random_state': 42
    })

    # Log ALL metrics
    mlflow.log_metrics({
        'train_accuracy': 0.95,
        'test_accuracy': 0.87,
        'f1_score': 0.89
    })

    # Log model
    mlflow.sklearn.log_model(model, 'model')

When to Load References


Load reference files for detailed production implementations:
  • Airflow DAG Patterns: Load references/airflow-patterns.md when building complex DAGs with error handling, dynamic generation, sensors, task groups, or retry logic. Contains complete production DAG examples.
  • Kubeflow & MLflow Integration: Load references/kubeflow-mlflow.md when using Kubeflow Pipelines for container-native orchestration, integrating MLflow tracking, building KFP components, or managing the model registry.
  • Pipeline Monitoring: Load references/pipeline-monitoring.md when implementing data quality checks, drift detection, alert configuration, or pipeline health monitoring with Prometheus.
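As a taste of what the monitoring reference covers, a drift check can be as small as comparing a feature's current mean against its training-time baseline. This is a hedged sketch: the function name and the 3-standard-deviation threshold are illustrative choices, not values taken from the reference.

```python
import statistics

def drift_score(baseline, current):
    """Shift of the current mean, in units of baseline standard deviation."""
    mu = statistics.mean(baseline)
    sigma = statistics.stdev(baseline)
    return abs(statistics.mean(current) - mu) / sigma

baseline = [0.90, 1.00, 1.10, 1.00, 0.95, 1.05]  # feature values at training time
current = [1.40, 1.55, 1.50, 1.45, 1.60, 1.52]   # values observed in production

if drift_score(baseline, current) > 3.0:  # illustrative threshold
    print("drift detected: consider triggering retraining")
```

In a pipeline, a check like this would sit in a monitoring task whose failure (or branch) triggers the retraining DAG.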

Best Practices


  1. Idempotent Tasks: Tasks should produce same result when re-run
  2. Atomic Operations: Each task does one thing well
  3. Version Everything: Data, code, models, dependencies
  4. Comprehensive Logging: Log all important events with context
  5. Error Handling: Fail fast with clear error messages
  6. Monitoring: Track pipeline health, data quality, model drift
  7. Testing: Test tasks independently before integrating
  8. Documentation: Document DAG purpose, task dependencies
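Practice 1 in concrete form: one common way to make a task's output idempotent is to write to a temp file and atomically rename it, so a retried run overwrites cleanly instead of appending or leaving a half-written file. The path and payload below are illustrative.

```python
import json
import os
import tempfile

def write_metrics_atomically(path, metrics):
    """Same inputs -> same file contents, no matter how often the task re-runs."""
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    with os.fdopen(fd, "w") as f:
        json.dump(metrics, f, sort_keys=True)  # sorted keys: stable output
    os.replace(tmp_path, path)  # atomic on POSIX: readers never see a partial file

write_metrics_atomically("metrics.json", {"accuracy": 0.91, "f1": 0.89})
write_metrics_atomically("metrics.json", {"accuracy": 0.91, "f1": 0.89})  # safe retry
```

The temp file must live in the same directory as the target, since `os.replace` is only atomic within a filesystem.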

Common Patterns


Conditional Execution


python
from airflow.operators.python import BranchPythonOperator

def choose_branch(**context):
    accuracy = context['ti'].xcom_pull(key='accuracy', task_ids='evaluate')

    if accuracy > 0.9:
        return 'deploy_to_production'
    else:
        return 'retrain_with_more_data'

branch = BranchPythonOperator(
    task_id='check_accuracy',
    python_callable=choose_branch,
    dag=dag
)

train >> evaluate >> branch >> [deploy, retrain]

Parallel Training


python
from airflow.utils.task_group import TaskGroup

with TaskGroup('train_models', dag=dag) as train_group:
    train_rf = PythonOperator(task_id='train_rf', ...)
    train_lr = PythonOperator(task_id='train_lr', ...)
    train_xgb = PythonOperator(task_id='train_xgb', ...)

# All models train in parallel
preprocess >> train_group >> select_best
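The select_best step that follows the group might simply keep the run with the best metric. The task ids below match the group above; the accuracy values (and the idea that they arrive as a dict) are made up for illustration — in Airflow they would come from `xcom_pull` over the group's task ids.

```python
# metric per parallel training task (illustrative values)
results = {
    "train_rf":  {"test_accuracy": 0.91},
    "train_lr":  {"test_accuracy": 0.86},
    "train_xgb": {"test_accuracy": 0.93},
}

best_task = max(results, key=lambda t: results[t]["test_accuracy"])
print(best_task)  # -> train_xgb
```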

Waiting for Data


python
from airflow.sensors.filesystem import FileSensor

wait_for_data = FileSensor(
    task_id='wait_for_data',
    filepath='/data/input/{{ ds }}.csv',
    poke_interval=60,  # Check every 60 seconds
    timeout=3600,  # Timeout after 1 hour
    mode='reschedule',  # Don't block worker
    dag=dag
)

wait_for_data >> process_data