dspy-custom-module-design

Compare original and translation side by side

🇺🇸

Original

English
🇨🇳

Translation

Chinese

DSPy Custom Module Design

DSPy自定义模块设计

Goal

目标

Design production-quality custom DSPy modules with proper architecture, state management, serialization, and testing patterns.
设计具备合适架构、状态管理、序列化和测试模式的生产级自定义DSPy模块。

When to Use

使用场景

  • Building reusable DSPy components
  • Complex logic beyond built-in modules
  • Need custom state management
  • Sharing modules across projects
  • Production deployment requirements
  • 构建可复用DSPy组件
  • 实现内置模块之外的复杂逻辑
  • 需要自定义状态管理
  • 在项目间共享模块
  • 满足生产部署要求

Related Skills

相关技能

  • Module composition: dspy-advanced-module-composition
  • Signature design: dspy-signature-designer
  • Optimization: dspy-miprov2-optimizer
  • 模块组合:dspy-advanced-module-composition
  • 签名设计:dspy-signature-designer
  • 优化:dspy-miprov2-optimizer

Inputs

输入项

InputTypeDescription
task_description
str
What the module should do
components
list
Sub-modules or predictors
state
dict
Stateful attributes
输入项类型描述
task_description
str
模块需要实现的功能
components
list
子模块或预测器
state
dict
有状态属性

Outputs

输出项

OutputTypeDescription
custom_module
dspy.Module
Production-ready module
输出项类型描述
custom_module
dspy.Module
生产就绪的模块

Workflow

工作流程

Phase 1: Basic Module Structure

阶段1:基础模块结构

All custom modules inherit from
dspy.Module
:
python
import dspy

class BasicQA(dspy.Module):
    """Simple question answering module."""

    def __init__(self):
        super().__init__()
        self.predictor = dspy.Predict("question -> answer")

    def forward(self, question):
        """Entry point for module execution."""
        return self.predictor(question=question)
所有自定义模块都继承自
dspy.Module
python
import dspy

class BasicQA(dspy.Module):
    """Simple question answering module."""

    def __init__(self):
        super().__init__()
        self.predictor = dspy.Predict("question -> answer")

    def forward(self, question):
        """Entry point for module execution."""
        return self.predictor(question=question)

Usage

Usage

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")) qa = BasicQA() result = qa(question="What is Python?") print(result.answer)
undefined
dspy.configure(lm=dspy.LM("openai/gpt-4o-mini")) qa = BasicQA() result = qa(question="What is Python?") print(result.answer)
undefined

Phase 2: Stateful Modules

阶段2:有状态模块

Modules can maintain state across calls:
python
import dspy
import logging

logger = logging.getLogger(__name__)

class StatefulRAG(dspy.Module):
    """RAG with query caching."""

    def __init__(self, cache_size=100):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=3)
        self.generate = dspy.ChainOfThought("context, question -> answer")
        self.cache = {}
        self.cache_size = cache_size

    def forward(self, question):
        # Check cache
        if question in self.cache:
            return self.cache[question]

        # Retrieve and generate
        passages = self.retrieve(question).passages
        result = self.generate(context=passages, question=question)

        # Update cache with size limit
        if len(self.cache) >= self.cache_size:
            self.cache.pop(next(iter(self.cache)))
        self.cache[question] = result

        return result
模块可以在多次调用间维持状态:
python
import dspy
import logging

logger = logging.getLogger(__name__)

class StatefulRAG(dspy.Module):
    """RAG with query caching."""

    def __init__(self, cache_size=100):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=3)
        self.generate = dspy.ChainOfThought("context, question -> answer")
        self.cache = {}
        self.cache_size = cache_size

    def forward(self, question):
        # Check cache
        if question in self.cache:
            return self.cache[question]

        # Retrieve and generate
        passages = self.retrieve(question).passages
        result = self.generate(context=passages, question=question)

        # Update cache with size limit
        if len(self.cache) >= self.cache_size:
            self.cache.pop(next(iter(self.cache)))
        self.cache[question] = result

        return result

Phase 3: Error Handling and Validation

阶段3:错误处理与验证

Production modules need robust error handling:
python
import dspy
from typing import Optional
import logging

logger = logging.getLogger(__name__)

class RobustClassifier(dspy.Module):
    """Classifier with validation."""

    def __init__(self, valid_labels: list[str]):
        super().__init__()
        self.valid_labels = set(valid_labels)
        self.classify = dspy.Predict("text -> label: str, confidence: float")

    def forward(self, text: str) -> dspy.Prediction:
        if not text or not text.strip():
            return dspy.Prediction(label="unknown", confidence=0.0, error="Empty input")

        try:
            result = self.classify(text=text)

            # Validate label
            if result.label not in self.valid_labels:
                result.label = "unknown"
                result.confidence = 0.0

            return result

        except Exception as e:
            logger.error(f"Classification failed: {e}")
            return dspy.Prediction(label="unknown", confidence=0.0, error=str(e))
生产级模块需要健壮的错误处理:
python
import dspy
from typing import Optional
import logging

logger = logging.getLogger(__name__)

class RobustClassifier(dspy.Module):
    """Classifier with validation."""

    def __init__(self, valid_labels: list[str]):
        super().__init__()
        self.valid_labels = set(valid_labels)
        self.classify = dspy.Predict("text -> label: str, confidence: float")

    def forward(self, text: str) -> dspy.Prediction:
        if not text or not text.strip():
            return dspy.Prediction(label="unknown", confidence=0.0, error="Empty input")

        try:
            result = self.classify(text=text)

            # Validate label
            if result.label not in self.valid_labels:
                result.label = "unknown"
                result.confidence = 0.0

            return result

        except Exception as e:
            logger.error(f"Classification failed: {e}")
            return dspy.Prediction(label="unknown", confidence=0.0, error=str(e))

Phase 4: Serialization

阶段4:序列化

Modules support save/load:
python
import dspy
模块支持保存/加载:
python
import dspy

Save module state

Save module state

module = MyCustomModule() module.save("my_module.json")
module = MyCustomModule() module.save("my_module.json")

Load requires creating instance first, then loading state

Load requires creating instance first, then loading state

loaded = MyCustomModule() loaded.load("my_module.json")
loaded = MyCustomModule() loaded.load("my_module.json")

For loading entire programs (dspy>=2.6.0)

For loading entire programs (dspy>=2.6.0)

module.save("./my_module/", save_program=True) loaded = dspy.load("./my_module/")
undefined
module.save("./my_module/", save_program=True) loaded = dspy.load("./my_module/")
undefined

Production Example

生产级示例

python
import dspy
from typing import List, Optional
import logging

logger = logging.getLogger(__name__)

class ProductionRAG(dspy.Module):
    """Production-ready RAG with all best practices."""

    def __init__(
        self,
        retriever_k: int = 5,
        cache_enabled: bool = True,
        cache_size: int = 1000
    ):
        super().__init__()

        # Configuration
        self.retriever_k = retriever_k
        self.cache_enabled = cache_enabled
        self.cache_size = cache_size

        # Components
        self.retrieve = dspy.Retrieve(k=retriever_k)
        self.generate = dspy.ChainOfThought("context, question -> answer")

        # State
        self.cache = {} if cache_enabled else None
        self.call_count = 0

    def forward(self, question: str) -> dspy.Prediction:
        """Execute RAG pipeline with caching."""
        self.call_count += 1

        # Validation
        if not question or not question.strip():
            return dspy.Prediction(
                answer="Please provide a valid question.",
                error="Invalid input"
            )

        # Cache check
        if self.cache_enabled and question in self.cache:
            logger.info(f"Cache hit (call #{self.call_count})")
            return self.cache[question]

        # Execute pipeline
        try:
            passages = self.retrieve(question).passages

            if not passages:
                logger.warning("No passages retrieved")
                return dspy.Prediction(
                    answer="No relevant information found.",
                    passages=[]
                )

            result = self.generate(context=passages, question=question)
            result.passages = passages

            # Update cache
            if self.cache_enabled:
                self._update_cache(question, result)

            return result

        except Exception as e:
            logger.error(f"RAG execution failed: {e}")
            return dspy.Prediction(
                answer="An error occurred while processing your question.",
                error=str(e)
            )

    def _update_cache(self, key: str, value: dspy.Prediction):
        """Manage cache with size limit."""
        if len(self.cache) >= self.cache_size:
            self.cache.pop(next(iter(self.cache)))
        self.cache[key] = value

    def clear_cache(self):
        """Clear cache."""
        if self.cache_enabled:
            self.cache.clear()
python
import dspy
from typing import List, Optional
import logging

logger = logging.getLogger(__name__)

class ProductionRAG(dspy.Module):
    """Production-ready RAG with all best practices."""

    def __init__(
        self,
        retriever_k: int = 5,
        cache_enabled: bool = True,
        cache_size: int = 1000
    ):
        super().__init__()

        # Configuration
        self.retriever_k = retriever_k
        self.cache_enabled = cache_enabled
        self.cache_size = cache_size

        # Components
        self.retrieve = dspy.Retrieve(k=retriever_k)
        self.generate = dspy.ChainOfThought("context, question -> answer")

        # State
        self.cache = {} if cache_enabled else None
        self.call_count = 0

    def forward(self, question: str) -> dspy.Prediction:
        """Execute RAG pipeline with caching."""
        self.call_count += 1

        # Validation
        if not question or not question.strip():
            return dspy.Prediction(
                answer="Please provide a valid question.",
                error="Invalid input"
            )

        # Cache check
        if self.cache_enabled and question in self.cache:
            logger.info(f"Cache hit (call #{self.call_count})")
            return self.cache[question]

        # Execute pipeline
        try:
            passages = self.retrieve(question).passages

            if not passages:
                logger.warning("No passages retrieved")
                return dspy.Prediction(
                    answer="No relevant information found.",
                    passages=[]
                )

            result = self.generate(context=passages, question=question)
            result.passages = passages

            # Update cache
            if self.cache_enabled:
                self._update_cache(question, result)

            return result

        except Exception as e:
            logger.error(f"RAG execution failed: {e}")
            return dspy.Prediction(
                answer="An error occurred while processing your question.",
                error=str(e)
            )

    def _update_cache(self, key: str, value: dspy.Prediction):
        """Manage cache with size limit."""
        if len(self.cache) >= self.cache_size:
            self.cache.pop(next(iter(self.cache)))
        self.cache[key] = value

    def clear_cache(self):
        """Clear cache."""
        if self.cache_enabled:
            self.cache.clear()

Best Practices

最佳实践

  1. Single responsibility - Each module does one thing well
  2. Validate inputs - Check for None, empty strings, invalid types
  3. Handle errors - Return Predictions with error fields, never raise
  4. Log important events - Cache hits, errors, validation failures
  5. Test independently - Unit test modules before composition
  1. 单一职责 - 每个模块专注完成一件事
  2. 输入验证 - 检查空值、空字符串、无效类型
  3. 错误处理 - 返回包含错误字段的Prediction对象,绝不抛出异常
  4. 日志记录重要事件 - 缓存命中、错误、验证失败等
  5. 独立测试 - 在组合前对模块进行单元测试

Limitations

局限性

  • State increases memory usage (careful with large caches)
  • Serialization doesn't automatically save custom state
  • Module testing requires mocking LM calls
  • Deep module hierarchies can be hard to debug
  • Performance overhead from validation in hot paths
  • 状态会增加内存占用(使用大型缓存时需注意)
  • 序列化不会自动保存自定义状态
  • 模块测试需要模拟LM调用
  • 深层模块层级可能难以调试
  • 热路径中的验证会带来性能开销

Official Documentation

官方文档