dspy-custom-module-design
DSPy Custom Module Design
Goal
Design production-quality custom DSPy modules with proper architecture, state management, serialization, and testing patterns.
When to Use
- Building reusable DSPy components
- Complex logic beyond built-in modules
- Need custom state management
- Sharing modules across projects
- Production deployment requirements
Related Skills
- Module composition: dspy-advanced-module-composition
- Signature design: dspy-signature-designer
- Optimization: dspy-miprov2-optimizer
Inputs
| Input | Type | Description |
|---|---|---|
|  |  | What the module should do |
|  |  | Sub-modules or predictors |
|  |  | Stateful attributes |
Outputs
| Output | Type | Description |
|---|---|---|
|  |  | Production-ready module |
Workflow
Phase 1: Basic Module Structure
All custom modules inherit from `dspy.Module`:

```python
import dspy

class BasicQA(dspy.Module):
    """Simple question answering module."""

    def __init__(self):
        super().__init__()
        self.predictor = dspy.Predict("question -> answer")

    def forward(self, question):
        """Entry point for module execution."""
        return self.predictor(question=question)
```

Usage:

```python
dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))

qa = BasicQA()
result = qa(question="What is Python?")
print(result.answer)
```
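Note that usage calls `qa(question=...)` rather than `qa.forward(...)`: `dspy.Module.__call__` dispatches to `forward`, which is what lets DSPy add bookkeeping (tracing, callbacks) around every call. The convention, reduced to a dspy-free sketch (`Module` and `Echo` here are illustrative stand-ins, not dspy classes):

```python
class Module:
    """Minimal stand-in for the dspy.Module call convention."""

    def __call__(self, **kwargs):
        # dspy wraps this dispatch with tracing/callback hooks
        return self.forward(**kwargs)

class Echo(Module):
    def forward(self, question):
        return {"answer": f"You asked: {question}"}

qa = Echo()
print(qa(question="What is Python?")["answer"])  # You asked: What is Python?
```

This is why subclasses implement `forward` but callers invoke the module directly.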
Phase 2: Stateful Modules
Modules can maintain state across calls:

```python
import dspy

class StatefulRAG(dspy.Module):
    """RAG with query caching."""

    def __init__(self, cache_size=100):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=3)
        self.generate = dspy.ChainOfThought("context, question -> answer")
        self.cache = {}
        self.cache_size = cache_size

    def forward(self, question):
        # Check cache
        if question in self.cache:
            return self.cache[question]
        # Retrieve and generate
        passages = self.retrieve(question).passages
        result = self.generate(context=passages, question=question)
        # Update cache with size limit (evict the oldest entry)
        if len(self.cache) >= self.cache_size:
            self.cache.pop(next(iter(self.cache)))
        self.cache[question] = result
        return result
```
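The eviction above relies on dicts preserving insertion order (guaranteed since Python 3.7), so `next(iter(self.cache))` yields the oldest key, giving FIFO eviction. A standalone sketch of that policy (the helper name is illustrative):

```python
def put_with_limit(cache: dict, key, value, cache_size: int) -> None:
    """Insert key -> value, evicting the oldest entry when the cache is full."""
    if len(cache) >= cache_size:
        # Dicts iterate in insertion order, so this pops the oldest key
        cache.pop(next(iter(cache)))
    cache[key] = value

cache = {}
for i in range(5):
    put_with_limit(cache, f"q{i}", f"answer {i}", cache_size=3)

print(list(cache))  # ['q2', 'q3', 'q4']
```

For true LRU behavior (refreshing entries on read), `collections.OrderedDict.move_to_end` or `functools.lru_cache` are the usual stdlib tools.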
return resultPhase 3: Error Handling and Validation
阶段3:错误处理与验证
Production modules need robust error handling:

```python
import dspy
import logging

logger = logging.getLogger(__name__)

class RobustClassifier(dspy.Module):
    """Classifier with validation."""

    def __init__(self, valid_labels: list[str]):
        super().__init__()
        self.valid_labels = set(valid_labels)
        self.classify = dspy.Predict("text -> label: str, confidence: float")

    def forward(self, text: str) -> dspy.Prediction:
        if not text or not text.strip():
            return dspy.Prediction(label="unknown", confidence=0.0, error="Empty input")
        try:
            result = self.classify(text=text)
            # Coerce out-of-vocabulary labels to "unknown"
            if result.label not in self.valid_labels:
                result.label = "unknown"
                result.confidence = 0.0
            return result
        except Exception as e:
            logger.error(f"Classification failed: {e}")
            return dspy.Prediction(label="unknown", confidence=0.0, error=str(e))
```
Phase 4: Serialization
Modules support save/load:

```python
import dspy

# Save module state
module = MyCustomModule()
module.save("my_module.json")

# Loading requires creating an instance first, then loading state
loaded = MyCustomModule()
loaded.load("my_module.json")

# Save and load an entire program (dspy>=2.6.0)
module.save("./my_module/", save_program=True)
loaded = dspy.load("./my_module/")
```
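As noted under Limitations, serialization does not automatically save custom state: `save()` covers predictor state, not arbitrary attributes like the caches above. Custom state that must survive a restart should be persisted separately. A minimal dspy-free sketch of such a side channel (the file name and helper names are assumptions for illustration):

```python
import json
from pathlib import Path

def save_custom_state(path: str, state: dict) -> None:
    """Persist JSON-serializable custom module state (e.g., cache, counters)."""
    Path(path).write_text(json.dumps(state))

def load_custom_state(path: str) -> dict:
    """Restore previously saved custom state."""
    return json.loads(Path(path).read_text())

save_custom_state("my_module.state.json", {"call_count": 3, "cache": {"q": "a"}})
restored = load_custom_state("my_module.state.json")
print(restored["call_count"])  # 3
```

Calling such helpers alongside `module.save()`/`module.load()` keeps the learned state and the custom state in sync.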
Production Example
```python
import dspy
import logging

logger = logging.getLogger(__name__)

class ProductionRAG(dspy.Module):
    """Production-ready RAG with all best practices."""

    def __init__(
        self,
        retriever_k: int = 5,
        cache_enabled: bool = True,
        cache_size: int = 1000,
    ):
        super().__init__()
        # Configuration
        self.retriever_k = retriever_k
        self.cache_enabled = cache_enabled
        self.cache_size = cache_size
        # Components
        self.retrieve = dspy.Retrieve(k=retriever_k)
        self.generate = dspy.ChainOfThought("context, question -> answer")
        # State
        self.cache = {} if cache_enabled else None
        self.call_count = 0

    def forward(self, question: str) -> dspy.Prediction:
        """Execute the RAG pipeline with caching."""
        self.call_count += 1
        # Validation
        if not question or not question.strip():
            return dspy.Prediction(
                answer="Please provide a valid question.",
                error="Invalid input",
            )
        # Cache check
        if self.cache_enabled and question in self.cache:
            logger.info(f"Cache hit (call #{self.call_count})")
            return self.cache[question]
        # Execute pipeline
        try:
            passages = self.retrieve(question).passages
            if not passages:
                logger.warning("No passages retrieved")
                return dspy.Prediction(
                    answer="No relevant information found.",
                    passages=[],
                )
            result = self.generate(context=passages, question=question)
            result.passages = passages
            # Update cache
            if self.cache_enabled:
                self._update_cache(question, result)
            return result
        except Exception as e:
            logger.error(f"RAG execution failed: {e}")
            return dspy.Prediction(
                answer="An error occurred while processing your question.",
                error=str(e),
            )

    def _update_cache(self, key: str, value: dspy.Prediction):
        """Evict the oldest entry (FIFO) when the cache is full."""
        if len(self.cache) >= self.cache_size:
            self.cache.pop(next(iter(self.cache)))
        self.cache[key] = value

    def clear_cache(self):
        """Clear the cache."""
        if self.cache_enabled:
            self.cache.clear()
```

Best Practices
- Single responsibility - Each module does one thing well
- Validate inputs - Check for None, empty strings, invalid types
- Handle errors - Return Predictions with error fields, never raise
- Log important events - Cache hits, errors, validation failures
- Test independently - Unit test modules before composition
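Because module tests should not hit a real LM, "Test independently" in practice means stubbing a module's predictor attribute before calling it. A sketch of the pattern with `unittest.mock` (the `Classifier` stand-in mirrors Phase 3's `RobustClassifier` without importing dspy; `SimpleNamespace` fakes a prediction object):

```python
from types import SimpleNamespace
from unittest.mock import MagicMock

class Classifier:
    """dspy-free stand-in for RobustClassifier's validation logic."""

    def __init__(self, valid_labels):
        self.valid_labels = set(valid_labels)
        self.classify = None  # predictor, injected or mocked in tests

    def forward(self, text):
        result = self.classify(text=text)
        if result.label not in self.valid_labels:
            result.label = "unknown"
            result.confidence = 0.0
        return result

def test_out_of_vocabulary_label_is_coerced():
    clf = Classifier(valid_labels=["positive", "negative"])
    # Mock the predictor so no LM call is made
    clf.classify = MagicMock(return_value=SimpleNamespace(label="spam", confidence=0.9))
    result = clf.forward("great product")
    assert result.label == "unknown"
    assert result.confidence == 0.0

test_out_of_vocabulary_label_is_coerced()
```

The same attribute-replacement trick works on real dspy modules, since predictors are ordinary instance attributes.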
Limitations
- State increases memory usage (careful with large caches)
- Serialization doesn't automatically save custom state
- Module testing requires mocking LM calls
- Deep module hierarchies can be hard to debug
- Performance overhead from validation in hot paths
Official Documentation
- DSPy Documentation: https://dspy.ai/
- DSPy GitHub: https://github.com/stanfordnlp/dspy
- Custom Modules Guide: https://dspy.ai/tutorials/custom_module/
- Module API: https://dspy.ai/api/modules/