Loading...
Loading...
Break a failing complex AI task into reliable subtasks. Use when your AI works on simple inputs but fails on complex ones, extraction misses items in long documents, accuracy degrades as input grows, AI conflates multiple things at once, results are inconsistent across input types, you need to chunk long text for processing, or you want to split one unreliable AI step into multiple reliable ones.
npx skill4agent add lebsral/dspy-programming-not-prompting-lms-skills ai-decomposing-tasks

| Failure mode | What you see | Root cause |
|---|---|---|
| Missed items | Extracts 3 of 7 line items | Input overwhelms the context — too much to track at once |
| Conflated fields | Mixes up sender/recipient addresses | Multiple similar things extracted simultaneously |
| Inconsistent results | Works on invoice A, fails on invoice B | Different input formats need different handling |
| Degraded accuracy | 95% on short text, 60% on long text | Input length exceeds what a single pass can reliably process |
/ai-improving-accuracy

What's going wrong?
|
+- Input is too long, AI loses focus
| → Chunk-then-process (Step 3)
|
+- AI conflates multiple similar things
| → Sequential extraction (Step 4)
|
+- AI misses items in variable-length lists
| → Identify-then-process (Step 5)
|
+- Different input types need different handling
| → Classify-then-specialize (see /ai-building-pipelines)

import dspy
from pydantic import BaseModel, Field
class ExtractedItem(BaseModel):
    """A single item extracted from one chunk of a document.

    `source_text` lets downstream code (and humans) verify the extraction
    against the exact passage it came from.
    """

    name: str
    value: str
    source_text: str = Field(description="The exact text this was extracted from")
class ExtractFromChunk(dspy.Signature):
    # NOTE: the docstring and field descs below are part of the LM prompt —
    # do not edit them casually.
    """Extract all relevant items from this section of the document."""

    chunk: str = dspy.InputField(desc="A section of the document")
    items: list[ExtractedItem] = dspy.OutputField(desc="All items found in this section")
class ChunkAndExtract(dspy.Module):
    """Chunk a long document, extract items per chunk, and merge the results.

    Decomposes one long unreliable extraction into many short reliable ones:
    each LM call only sees `chunk_size` words at a time, and overlapping
    chunks plus deduplication keep boundary-straddling items from being lost.
    """

    def __init__(self, chunk_size=2000, overlap=200):
        super().__init__()  # let dspy.Module register sub-modules
        if overlap >= chunk_size:
            # Otherwise `start` never advances in _chunk_text and the
            # chunking loop below never terminates.
            raise ValueError("overlap must be smaller than chunk_size")
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.extract = dspy.ChainOfThought(ExtractFromChunk)

    def _chunk_text(self, text: str) -> list[str]:
        """Split text into overlapping chunks of whitespace-separated words.

        Each chunk is at most `chunk_size` words, and consecutive chunks
        share `overlap` words.
        """
        words = text.split()
        chunks = []
        start = 0
        while start < len(words):
            end = start + self.chunk_size
            chunks.append(" ".join(words[start:end]))
            if end >= len(words):
                # The last chunk already reached the end of the document;
                # stepping back by `overlap` would emit a redundant tail chunk.
                break
            start = end - self.overlap
        return chunks

    def _deduplicate(self, all_items: list[ExtractedItem]) -> list[ExtractedItem]:
        """Remove duplicate extractions produced by overlapping chunks."""
        seen = set()
        unique = []
        for item in all_items:
            # Identity is the (name, value) pair, case- and whitespace-insensitive.
            key = (item.name.lower().strip(), item.value.lower().strip())
            if key not in seen:
                seen.add(key)
                unique.append(item)
        return unique

    def forward(self, document: str):
        """Extract items from every chunk of `document` and merge them."""
        chunks = self._chunk_text(document)
        all_items = []
        for chunk in chunks:
            result = self.extract(chunk=chunk)
            all_items.extend(result.items)
        unique_items = self._deduplicate(all_items)
        return dspy.Prediction(items=unique_items)


class IdentifyPanels(dspy.Signature):
    """Identify all lab test panels in the medical report."""

    report: str = dspy.InputField(desc="Medical lab report")
    panel_names: list[str] = dspy.OutputField(desc="Names of all test panels found")
class LabResult(BaseModel):
    """One test-result row from a lab report panel."""

    test_name: str
    value: str
    unit: str
    reference_range: str
    flag: str = Field(description="'normal', 'high', or 'low'")
class ExtractPanelResults(dspy.Signature):
    # NOTE: docstring and descs feed the LM prompt — keep them stable.
    """Extract all test results for a specific panel from the report."""

    report: str = dspy.InputField(desc="Medical lab report")
    panel_name: str = dspy.InputField(desc="The specific panel to extract results for")
    results: list[LabResult] = dspy.OutputField(desc="All test results for this panel")
class SequentialExtractor(dspy.Module):
    """Extract lab results one panel at a time instead of all at once.

    Step 1 identifies which panels exist (a cheap, low-load task); step 2
    extracts each panel's results in its own focused LM call, so similar
    panels are not conflated with each other.
    """

    def __init__(self):
        super().__init__()  # let dspy.Module register sub-modules
        self.identify = dspy.ChainOfThought(IdentifyPanels)
        self.extract = dspy.ChainOfThought(ExtractPanelResults)

    def forward(self, report: str):
        # Step 1: Identify what's in the report.
        panels = self.identify(report=report)
        # NOTE(review): dspy.Assert was removed in recent DSPy releases —
        # confirm the pinned DSPy version still provides it.
        dspy.Assert(
            len(panels.panel_names) > 0,
            "No test panels found — is this a valid lab report?"
        )
        # Step 2: Extract results per panel, one focused call each.
        all_results = {}
        for panel_name in panels.panel_names:
            result = self.extract(report=report, panel_name=panel_name)
            all_results[panel_name] = result.results
        return dspy.Prediction(
            panels=panels.panel_names,
            results=all_results,
        )


class IdentifyLineItems(dspy.Signature):
    """Identify all line items in the invoice. List every item, even small ones."""

    invoice_text: str = dspy.InputField(desc="Raw invoice text")
    item_descriptions: list[str] = dspy.OutputField(
        desc="Brief description of each line item, in order they appear"
    )
class LineItemDetail(BaseModel):
    """Structured details for one invoice line item."""

    description: str
    quantity: int
    unit_price: float
    total: float
class ExtractLineItem(dspy.Signature):
    # NOTE: docstring and descs feed the LM prompt — keep them stable.
    """Extract the details for a specific line item from the invoice."""

    invoice_text: str = dspy.InputField(desc="Raw invoice text")
    item_description: str = dspy.InputField(desc="The specific item to extract details for")
    details: LineItemDetail = dspy.OutputField()
class IdentifyThenExtract(dspy.Module):
    """Two-pass invoice extraction: list the items first, then detail each.

    Listing just the item names is a low-cognitive-load task the LM rarely
    gets wrong; the per-item detail pass then only has to focus on one item
    at a time.
    """

    def __init__(self):
        super().__init__()  # let dspy.Module register sub-modules
        self.identify = dspy.ChainOfThought(IdentifyLineItems)
        self.extract_item = dspy.ChainOfThought(ExtractLineItem)

    def forward(self, invoice_text: str):
        # Step 1: Identify all items (just names — low cognitive load).
        items = self.identify(invoice_text=invoice_text)
        # NOTE(review): dspy.Assert was removed in recent DSPy releases —
        # confirm the pinned DSPy version still provides it.
        dspy.Assert(
            len(items.item_descriptions) > 0,
            "No line items found in invoice"
        )
        # Step 2: Extract details per item, one focused call each.
        line_items = []
        for desc in items.item_descriptions:
            result = self.extract_item(
                invoice_text=invoice_text,
                item_description=desc,
            )
            line_items.append(result.details)
        return dspy.Prediction(line_items=line_items)


from dspy.evaluate import Evaluate
# Build both versions.
# NOTE(review): ExtractAllItems and devset are defined elsewhere in this guide.
single_step = dspy.ChainOfThought(ExtractAllItems)  # Original single-step
decomposed = IdentifyThenExtract()                  # Decomposed version


def extraction_metric(example, prediction, trace=None):
    """Measure recall — what fraction of gold items were extracted.

    Matching is case-insensitive on item names; an example with no gold
    items trivially scores 1.0.
    """
    gold_items = set(item.lower() for item in example.item_names)
    pred_items = set(item.description.lower() for item in prediction.line_items)
    if not gold_items:
        return 1.0
    return len(gold_items & pred_items) / len(gold_items)


evaluator = Evaluate(devset=devset, metric=extraction_metric, num_threads=4, display_table=5)

# Compare the single-step baseline with the decomposed program.
single_score = evaluator(single_step)
decomposed_score = evaluator(decomposed)
print(f"Single-step: {single_score:.1f}%")
print(f"Decomposed: {decomposed_score:.1f}%")

# Examples with few items count as "simple"; the rest are "complex".
simple_devset = [ex for ex in devset if len(ex.item_names) <= 3]
complex_devset = [ex for ex in devset if len(ex.item_names) > 3]

# Evaluate both programs separately on simple vs. complex inputs to see
# where decomposition actually pays off.
simple_evaluator = Evaluate(devset=simple_devset, metric=extraction_metric)
complex_evaluator = Evaluate(devset=complex_devset, metric=extraction_metric)

print("Simple inputs:")
print(f" Single-step: {simple_evaluator(single_step):.1f}%")
print(f" Decomposed: {simple_evaluator(decomposed):.1f}%")
print("Complex inputs:")
print(f" Single-step: {complex_evaluator(single_step):.1f}%")
print(f" Decomposed: {complex_evaluator(decomposed):.1f}%")

# Optimize the decomposed program's prompts against the same metric.
optimizer = dspy.MIPROv2(metric=extraction_metric, auto="medium")
# NOTE(review): trainset is defined elsewhere in this guide.
optimized = optimizer.compile(decomposed, trainset=trainset)

# Verify improvement over the unoptimized decomposed version.
optimized_score = evaluator(optimized)
print(f"Decomposed (unoptimized): {decomposed_score:.1f}%")
print(f"Decomposed (optimized): {optimized_score:.1f}%")

# Route each sub-step to an appropriately priced model.
cheap_lm = dspy.LM("openai/gpt-4o-mini")
quality_lm = dspy.LM("openai/gpt-4o")

# Cheap model for the easy identify step, quality model for the hard
# per-item extraction step.
decomposed.identify.set_lm(cheap_lm)        # Cheap for listing
decomposed.extract_item.set_lm(quality_lm)  # Quality for extraction

# Related skills: /ai-cutting-costs /ai-building-pipelines
# /ai-improving-accuracy /ai-parsing-data /ai-improving-accuracy