dspy-rag-pipeline

Compare original and translation side by side

🇺🇸

Original

English
🇨🇳

Translation

Chinese

DSPy RAG Pipeline

DSPy RAG 管道

Goal

目标

Build retrieval-augmented generation pipelines with ColBERTv2 that can be systematically optimized.
构建可系统优化的、基于ColBERTv2的检索增强生成(RAG)管道。

When to Use

适用场景

  • Questions require external knowledge
  • You have a document corpus to search
  • Need grounded, factual responses
  • Want to optimize retrieval + generation jointly
  • 问题需要外部知识支持
  • 您拥有可搜索的文档语料库
  • 需要生成有依据的事实性回答
  • 希望联合优化检索与生成过程

Related Skills

相关Skill

  • Optimize this pipeline: dspy-miprov2-optimizer, dspy-bootstrap-fewshot
  • Evaluate results: dspy-evaluation-suite
  • Design signatures: dspy-signature-designer
  • 优化本管道:dspy-miprov2-optimizer, dspy-bootstrap-fewshot
  • 评估结果:dspy-evaluation-suite
  • 设计签名:dspy-signature-designer

Inputs

输入

InputTypeDescription
question
str
User query
k
int
Number of passages to retrieve
rm
dspy.Retrieve
Retrieval model (ColBERTv2)
输入类型描述
question
str
用户查询
k
int
要检索的段落数量
rm
dspy.Retrieve
检索模型(ColBERTv2)

Outputs

输出

OutputTypeDescription
context
list[str]
Retrieved passages
answer
str
Generated response
输出类型描述
context
list[str]
检索到的段落
answer
str
生成的回答

Workflow

工作流程

Phase 1: Configure Retrieval

阶段1:配置检索

python
import dspy
python
import dspy

Configure LM and retriever

Configure LM and retriever

colbert = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts') dspy.configure( lm=dspy.LM("openai/gpt-4o-mini"), rm=colbert )
undefined
colbert = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts') dspy.configure( lm=dspy.LM("openai/gpt-4o-mini"), rm=colbert )
undefined

Phase 2: Define Signature

阶段2:定义Signature

python
class GenerateAnswer(dspy.Signature):
    """Answer questions with short factoid answers."""
    context: str = dspy.InputField(desc="May contain relevant facts")
    question: str = dspy.InputField()
    answer: str = dspy.OutputField(desc="Often between 1 and 5 words")
python
class GenerateAnswer(dspy.Signature):
    """Answer questions with short factoid answers."""
    context: str = dspy.InputField(desc="May contain relevant facts")
    question: str = dspy.InputField()
    answer: str = dspy.OutputField(desc="Often between 1 and 5 words")

Phase 3: Build RAG Module

阶段3:构建RAG模块

python
class RAG(dspy.Module):
    def __init__(self, num_passages=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate = dspy.ChainOfThought(GenerateAnswer)
    
    def forward(self, question):
        context = self.retrieve(question).passages
        pred = self.generate(context=context, question=question)
        return dspy.Prediction(context=context, answer=pred.answer)
python
class RAG(dspy.Module):
    def __init__(self, num_passages=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate = dspy.ChainOfThought(GenerateAnswer)
    
    def forward(self, question):
        context = self.retrieve(question).passages
        pred = self.generate(context=context, question=question)
        return dspy.Prediction(context=context, answer=pred.answer)

Phase 4: Use

阶段4:使用

python
rag = RAG(num_passages=3)
result = rag(question="What is the capital of France?")
print(result.answer)  # Paris
python
rag = RAG(num_passages=3)
result = rag(question="What is the capital of France?")
print(result.answer)  # Paris

Production Example

生产环境示例

python
import dspy
from dspy.teleprompt import BootstrapFewShot
from dspy.evaluate import Evaluate
import logging

logger = logging.getLogger(__name__)

class GenerateAnswer(dspy.Signature):
    """Answer questions using the provided context."""
    context: list[str] = dspy.InputField(desc="Retrieved passages")
    question: str = dspy.InputField()
    answer: str = dspy.OutputField(desc="Concise factual answer")

class ProductionRAG(dspy.Module):
    def __init__(self, num_passages=5):
        super().__init__()
        self.num_passages = num_passages
        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate = dspy.ChainOfThought(GenerateAnswer)
    
    def forward(self, question: str):
        try:
            # Retrieve
            retrieval_result = self.retrieve(question)
            context = retrieval_result.passages
            
            if not context:
                logger.warning(f"No passages retrieved for: {question}")
                return dspy.Prediction(
                    context=[],
                    answer="I couldn't find relevant information."
                )
            
            # Generate
            pred = self.generate(context=context, question=question)
            
            return dspy.Prediction(
                context=context,
                answer=pred.answer,
                reasoning=getattr(pred, 'reasoning', None)
            )
            
        except Exception as e:
            logger.error(f"RAG failed: {e}")
            return dspy.Prediction(
                context=[],
                answer="An error occurred while processing your question."
            )

def validate_answer(example, pred, trace=None):
    """Check if answer is grounded and correct."""
    if not pred.answer or not pred.context:
        return 0.0
    
    # Check correctness
    correct = example.answer.lower() in pred.answer.lower()
    
    # Check grounding (answer should relate to context)
    context_text = " ".join(pred.context).lower()
    grounded = any(word in context_text for word in pred.answer.lower().split())
    
    return float(correct and grounded)

def build_optimized_rag(trainset, devset):
    """Build and optimize a RAG pipeline."""
    
    # Configure
    colbert = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')
    dspy.configure(
        lm=dspy.LM("openai/gpt-4o-mini"),
        rm=colbert
    )
    
    # Build
    rag = ProductionRAG(num_passages=5)
    
    # Evaluate baseline
    evaluator = Evaluate(devset=devset, metric=validate_answer, num_threads=8)
    baseline = evaluator(rag)
    logger.info(f"Baseline: {baseline:.2%}")
    
    # Optimize
    optimizer = BootstrapFewShot(
        metric=validate_answer,
        max_bootstrapped_demos=4,
        max_labeled_demos=4
    )
    compiled = optimizer.compile(rag, trainset=trainset)
    
    optimized = evaluator(compiled)
    logger.info(f"Optimized: {optimized:.2%}")
    
    compiled.save("rag_optimized.json")
    return compiled
python
import dspy
from dspy.teleprompt import BootstrapFewShot
from dspy.evaluate import Evaluate
import logging

logger = logging.getLogger(__name__)

class GenerateAnswer(dspy.Signature):
    """Answer questions using the provided context."""
    context: list[str] = dspy.InputField(desc="Retrieved passages")
    question: str = dspy.InputField()
    answer: str = dspy.OutputField(desc="Concise factual answer")

class ProductionRAG(dspy.Module):
    def __init__(self, num_passages=5):
        super().__init__()
        self.num_passages = num_passages
        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate = dspy.ChainOfThought(GenerateAnswer)
    
    def forward(self, question: str):
        try:
            # Retrieve
            retrieval_result = self.retrieve(question)
            context = retrieval_result.passages
            
            if not context:
                logger.warning(f"No passages retrieved for: {question}")
                return dspy.Prediction(
                    context=[],
                    answer="I couldn't find relevant information."
                )
            
            # Generate
            pred = self.generate(context=context, question=question)
            
            return dspy.Prediction(
                context=context,
                answer=pred.answer,
                reasoning=getattr(pred, 'reasoning', None)
            )
            
        except Exception as e:
            logger.error(f"RAG failed: {e}")
            return dspy.Prediction(
                context=[],
                answer="An error occurred while processing your question."
            )

def validate_answer(example, pred, trace=None):
    """Check if answer is grounded and correct."""
    if not pred.answer or not pred.context:
        return 0.0
    
    # Check correctness
    correct = example.answer.lower() in pred.answer.lower()
    
    # Check grounding (answer should relate to context)
    context_text = " ".join(pred.context).lower()
    grounded = any(word in context_text for word in pred.answer.lower().split())
    
    return float(correct and grounded)

def build_optimized_rag(trainset, devset):
    """Build and optimize a RAG pipeline."""
    
    # Configure
    colbert = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')
    dspy.configure(
        lm=dspy.LM("openai/gpt-4o-mini"),
        rm=colbert
    )
    
    # Build
    rag = ProductionRAG(num_passages=5)
    
    # Evaluate baseline
    evaluator = Evaluate(devset=devset, metric=validate_answer, num_threads=8)
    baseline = evaluator(rag)
    logger.info(f"Baseline: {baseline:.2%}")
    
    # Optimize
    optimizer = BootstrapFewShot(
        metric=validate_answer,
        max_bootstrapped_demos=4,
        max_labeled_demos=4
    )
    compiled = optimizer.compile(rag, trainset=trainset)
    
    optimized = evaluator(compiled)
    logger.info(f"Optimized: {optimized:.2%}")
    
    compiled.save("rag_optimized.json")
    return compiled

Multi-Hop RAG

多跳RAG

python
class MultiHopRAG(dspy.Module):
    """RAG with iterative retrieval for complex questions."""
    
    def __init__(self, num_hops=2, passages_per_hop=3):
        super().__init__()
        self.num_hops = num_hops
        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        self.generate_query = dspy.ChainOfThought("context, question -> search_query")
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
    
    def forward(self, question):
        context = []
        
        for hop in range(self.num_hops):
            # First hop: use original question
            # Later hops: generate refined query
            if hop == 0:
                query = question
            else:
                query = self.generate_query(
                    context=context,
                    question=question
                ).search_query
            
            # Retrieve and accumulate
            new_passages = self.retrieve(query).passages
            context.extend(new_passages)
        
        # Generate final answer
        pred = self.generate_answer(context=context, question=question)
        return dspy.Prediction(context=context, answer=pred.answer)
python
class MultiHopRAG(dspy.Module):
    """RAG with iterative retrieval for complex questions."""
    
    def __init__(self, num_hops=2, passages_per_hop=3):
        super().__init__()
        self.num_hops = num_hops
        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        self.generate_query = dspy.ChainOfThought("context, question -> search_query")
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
    
    def forward(self, question):
        context = []
        
        for hop in range(self.num_hops):
            # First hop: use original question
            # Later hops: generate refined query
            if hop == 0:
                query = question
            else:
                query = self.generate_query(
                    context=context,
                    question=question
                ).search_query
            
            # Retrieve and accumulate
            new_passages = self.retrieve(query).passages
            context.extend(new_passages)
        
        # Generate final answer
        pred = self.generate_answer(context=context, question=question)
        return dspy.Prediction(context=context, answer=pred.answer)

Best Practices

最佳实践

  1. Tune k carefully - More passages = more context but also noise
  2. Signature descriptions matter - Guide the model with field descriptions
  3. Validate grounding - Ensure answers come from retrieved context
  4. Consider multi-hop - Complex questions may need iterative retrieval
  1. 谨慎调整k值 - 段落越多,上下文越丰富,但也会引入更多噪声
  2. Signature描述很重要 - 用字段描述引导模型
  3. 验证依据性 - 确保回答来自检索到的上下文
  4. 考虑多跳检索 - 复杂问题可能需要迭代检索

Limitations

局限性

  • Retrieval quality bounds generation quality
  • ColBERTv2 requires hosted index
  • Context length limits affect passage count
  • Latency increases with more passages
  • 检索质量决定了生成质量的上限
  • ColBERTv2需要托管的索引
  • 上下文长度限制会影响可使用的段落数量
  • 段落数量越多,延迟越高

Official Documentation

官方文档