# Long Context: Extending Transformer Context Windows
## When to Use This Skill
Use Long Context techniques when you need to:
- Process long documents (32k, 64k, 128k+ tokens) with transformer models
- Extend context windows of pre-trained models (LLaMA, Mistral, etc.)
- Implement efficient positional encodings (RoPE, ALiBi)
- Train models with length extrapolation capabilities
- Deploy models that handle variable-length inputs efficiently
- Fine-tune existing models for longer contexts with minimal compute
Key Techniques: RoPE (Rotary Position Embeddings), YaRN, ALiBi (Attention with Linear Biases), Position Interpolation
Papers: RoFormer (arXiv 2104.09864), YaRN (arXiv 2309.00071), ALiBi (arXiv 2108.12409), Position Interpolation (arXiv 2306.15595)
## Installation
```bash
# HuggingFace Transformers (includes RoPE, YaRN support)
pip install transformers torch

# For custom implementations
pip install einops                  # Tensor operations
pip install rotary-embedding-torch  # Standalone RoPE

# Optional: FlashAttention for efficiency
pip install flash-attn --no-build-isolation
```
## Quick Start
### RoPE (Rotary Position Embeddings)
```python
import torch
import torch.nn as nn


class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings (RoPE)."""

    def __init__(self, dim, max_seq_len=8192, base=10000):
        super().__init__()
        # Compute inverse frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len

    def forward(self, seq_len, device):
        # Position indices
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        # Compute frequencies
        freqs = torch.outer(t, self.inv_freq)  # (seq_len, dim/2)
        # Compute sin and cos
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
        return emb.cos(), emb.sin()


def rotate_half(x):
    """Rotate half the hidden dimensions."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary embeddings to queries and keys."""
    # q, k shape: (batch, heads, seq_len, dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```
```python
# Usage
rope = RotaryEmbedding(dim=64, max_seq_len=8192)
cos, sin = rope(seq_len=2048, device='cuda')

# In attention layer
q_rotated, k_rotated = apply_rotary_pos_emb(query, key, cos, sin)
```
### ALiBi (Attention with Linear Biases)
```python
import math

import torch


def get_alibi_slopes(num_heads):
    """Get ALiBi slope values for each attention head."""
    def get_slopes_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]

    if math.log2(num_heads).is_integer():
        return get_slopes_power_of_2(num_heads)
    else:
        # Closest power of 2
        closest_power = 2 ** math.floor(math.log2(num_heads))
        slopes = get_slopes_power_of_2(closest_power)
        # Add extra slopes
        extra = get_slopes_power_of_2(2 * closest_power)
        slopes.extend(extra[0::2][:num_heads - closest_power])
        return slopes


def create_alibi_bias(seq_len, num_heads):
    """Create ALiBi attention bias."""
    # Distance matrix
    context_position = torch.arange(seq_len)
    memory_position = torch.arange(seq_len)
    relative_position = memory_position[None, :] - context_position[:, None]
    # Get slopes
    slopes = torch.tensor(get_alibi_slopes(num_heads))
    # Apply slopes to distances
    alibi = slopes[:, None, None] * relative_position[None, :, :]
    return alibi  # (num_heads, seq_len, seq_len)
```
```python
# Usage in attention
num_heads = 8
seq_len = 2048
alibi_bias = create_alibi_bias(seq_len, num_heads).to('cuda')

# Add bias to attention scores
# attn_scores shape: (batch, num_heads, seq_len, seq_len)
attn_scores = attn_scores + alibi_bias
attn_weights = torch.softmax(attn_scores, dim=-1)
```
### Position Interpolation for LLaMA
```python
from transformers import LlamaForCausalLM, LlamaTokenizer

# Original context: 2048 tokens
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Extend to 32k with position interpolation:
# configure RoPE scaling on the model config
model.config.rope_scaling = {
    "type": "linear",
    "factor": 16.0  # 2048 * 16 = 32768
}

# Or use dynamic scaling
model.config.rope_scaling = {
    "type": "dynamic",
    "factor": 16.0
}

# Fine-tune with long documents (minimal steps needed);
# position interpolation works out-of-the-box after this config change
```
## Core Concepts
### 1. RoPE (Rotary Position Embeddings)
How it works:
- Encodes absolute position via a rotation matrix
- Provides relative position dependency in attention
- Enables length extrapolation

Mathematical formulation:

```
q_m = (W_q * x_m) * e^(imθ)
k_n = (W_k * x_n) * e^(inθ)
where θ_j = base^(-2j/d) for j ∈ [0, d/2)
```

Advantages:
- Inter-token dependency decays with distance
- Compatible with linear attention
- Better extrapolation than absolute position encodings
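The relative-position property can be checked numerically: after rotating q and k by their absolute positions, the dot-product score depends only on the offset m − n. A minimal, self-contained sketch (the `rope_rotate` helper is illustrative, not a library function):

```python
import torch

def rope_rotate(x, pos, base=10000):
    """Apply a RoPE rotation to a single vector x (even dim) at position pos."""
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    angles = pos * inv_freq  # one angle per 2-D pair
    cos = torch.cat((angles.cos(), angles.cos()))
    sin = torch.cat((angles.sin(), angles.sin()))
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin

torch.manual_seed(0)
q, k = torch.randn(8), torch.randn(8)

# The q·k score is invariant to shifting both positions by the same amount:
score_a = rope_rotate(q, 5) @ rope_rotate(k, 2)      # offset 3
score_b = rope_rotate(q, 105) @ rope_rotate(k, 102)  # offset 3, shifted by 100
assert torch.allclose(score_a, score_b, atol=1e-4)
```

Each dimension pair (i, i + d/2) undergoes a plain 2-D rotation, and rotations preserve inner products up to the angle difference, which is why only m − n survives.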
### 2. YaRN (Yet another RoPE extensioN)
Key innovation:
- NTK-aware interpolation (Neural Tangent Kernel)
- Attention temperature scaling
- Efficient context extension (10× fewer tokens than baselines)

Parameters:
```python
# YaRN configuration
yarn_config = {
    "scale": 16,                    # Extension factor
    "original_max_position": 2048,  # Base context
    "extrapolation_factor": 1.0,    # NTK parameter
    "attn_factor": 1.0,             # Attention scaling
    "beta_fast": 32,                # High-frequency scale
    "beta_slow": 1,                 # Low-frequency scale
}
```

**Performance:**
- Extends LLaMA to 128k tokens
- 2.5× fewer training steps than baselines
- State-of-the-art context window extension
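The NTK-aware ingredient can be illustrated with the base-scaling rule from the NTK-aware RoPE line of work that YaRN builds on: raise the base so low frequencies are interpolated by roughly the full factor while the highest frequency barely changes. A sketch (the helper name is hypothetical; full YaRN additionally applies a per-dimension ramp between `beta_fast`/`beta_slow` and attention temperature scaling):

```python
def ntk_scaled_base(base, scale, dim):
    """NTK-aware RoPE scaling: b' = b * s^(d/(d-2)).

    The lowest-frequency pair's wavelength grows by exactly `scale`,
    while the highest-frequency pair is left untouched."""
    return base * scale ** (dim / (dim - 2))

new_base = ntk_scaled_base(10000, 16, 128)
# The base grows by slightly more than the raw scale factor.
assert new_base > 10000 * 16
```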
### 3. ALiBi (Attention with Linear Biases)
Core idea:
- No positional embeddings added to tokens
- Distance penalty applied directly to attention scores
- Bias proportional to key-query distance

Formula:

```
attention_bias[i, j] = -m * |i - j|
where m = slope for each attention head
```

Advantages:
- 11% faster training than sinusoidal embeddings
- 11% less memory usage
- Strong length extrapolation (train at 1k, test at 2k+)
- Inductive bias towards recency
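For a power-of-two head count, the per-head slopes m form a geometric sequence starting at 2^(−8/n); with 8 heads that is 2^−1, 2^−2, …, 2^−8. A standalone check of the closed form (equivalent to the `get_slopes_power_of_2` helper in the Quick Start code):

```python
def alibi_slopes_pow2(n):
    """ALiBi slopes for n attention heads, n a power of two."""
    start = 2 ** (-(8.0 / n))  # equals 2^(-(2^-(log2(n) - 3)))
    return [start ** (i + 1) for i in range(n)]

slopes = alibi_slopes_pow2(8)
assert slopes[0] == 0.5 and slopes[-1] == 2 ** -8
```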
### 4. Position Interpolation
Technique:
- Linearly down-scale position indices
- Interpolate within the trained range (rather than extrapolating beyond it)
- Minimal fine-tuning required

Formula:

```
Original: position indices [0, 1, 2, ..., L]
Extended: position indices [0, 0.5, 1.0, ..., L/2]  (for 2× extension)

scaled_position[i] = i / extension_factor
```

**Results:**
- LLaMA 7B-65B extended to 32k tokens
- 1000 fine-tuning steps sufficient
- 600× better stability than extrapolation
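The scaling rule amounts to a one-line transform over position indices; a minimal sketch (the helper name is illustrative):

```python
def interpolate_positions(seq_len, extension_factor):
    """Linearly down-scale positions [0, seq_len) into the trained range."""
    return [i / extension_factor for i in range(seq_len)]

# 2x extension: 4096 positions squeezed back into the trained [0, 2048) range
pos = interpolate_positions(4096, 2.0)
assert pos[:3] == [0.0, 0.5, 1.0] and max(pos) < 2048
```

Because every index stays inside the range seen during pre-training, the model interpolates between learned positions instead of extrapolating, which is what makes the subsequent fine-tuning so cheap.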
## Method Comparison
| Method | Max Context | Training Needed | Memory | Extrapolation | Best For |
|---|---|---|---|---|---|
| RoPE | 8k-32k | Full pre-training | Moderate | Good | New models |
| YaRN | 32k-128k | Minimal (10× efficient) | Moderate | Excellent | Extending existing models |
| ALiBi | Unlimited | Full pre-training | Low (-11%) | Excellent | Training from scratch |
| Position Interpolation | 32k+ | Minimal (1k steps) | Moderate | Poor (by design) | Quick extension |
## Implementation Patterns
### HuggingFace Transformers Integration
```python
from transformers import AutoModelForCausalLM, AutoConfig

# RoPE with YaRN scaling
config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
config.rope_scaling = {
    "type": "yarn",
    "factor": 8.0,
    "original_max_position_embeddings": 8192,
    "attention_factor": 1.0
}
model = AutoModelForCausalLM.from_config(config)

# Position interpolation (simpler)
config.rope_scaling = {
    "type": "linear",
    "factor": 4.0
}

# Dynamic scaling (adjusts based on input length)
config.rope_scaling = {
    "type": "dynamic",
    "factor": 8.0
}
```
### Custom RoPE Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Uses RotaryEmbedding and apply_rotary_pos_emb from the Quick Start above


class LongContextAttention(nn.Module):
    """Multi-head attention with RoPE."""

    def __init__(self, hidden_size, num_heads, max_seq_len=32768):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Q, K, V projections
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)
        # RoPE
        self.rotary_emb = RotaryEmbedding(
            dim=self.head_dim,
            max_seq_len=max_seq_len
        )

    def forward(self, hidden_states):
        batch_size, seq_len, _ = hidden_states.shape
        # Project to Q, K, V
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        # Reshape for multi-head attention
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Apply RoPE
        cos, sin = self.rotary_emb(seq_len, device=hidden_states.device)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        # Standard attention
        attn_output = F.scaled_dot_product_attention(q, k, v)
        # Reshape and project back
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, -1)
        return self.o_proj(attn_output)
```
## Fine-tuning for Long Context
### Minimal Fine-tuning (Position Interpolation)
```python
from transformers import Trainer, TrainingArguments

# Extend model config
model.config.max_position_embeddings = 32768
model.config.rope_scaling = {"type": "linear", "factor": 16.0}

# Training args (minimal steps needed)
training_args = TrainingArguments(
    output_dir="./llama-32k",
    num_train_epochs=1,
    max_steps=1000,  # Only 1000 steps!
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=2e-5,
    warmup_steps=100,
    logging_steps=10,
    save_steps=500,
)

# Train on long documents
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=long_document_dataset,  # 32k-token sequences
)
trainer.train()
```
### YaRN Fine-tuning
```bash
# Clone YaRN implementation
git clone https://github.com/jquesnelle/yarn
cd yarn

# Fine-tune LLaMA with YaRN
python scripts/train.py \
    --model meta-llama/Llama-2-7b-hf \
    --scale 16 \
    --rope_theta 10000 \
    --max_length 32768 \
    --batch_size 1 \
    --gradient_accumulation 16 \
    --steps 400 \
    --learning_rate 2e-5
```
undefinedBest Practices
最佳实践
### 1. Choose the Right Method
```python
# For NEW models (training from scratch)
use_method = "ALiBi"  # Best extrapolation, lowest memory

# For EXTENDING existing RoPE models
use_method = "YaRN"  # Most efficient extension (10× less data)

# For QUICK extension with minimal compute
use_method = "Position Interpolation"  # 1000 steps

# For MODERATE extension with good efficiency
use_method = "Linear RoPE Scaling"  # Built-in, simple
```
undefined2. Scaling Factor Selection
2. 选择缩放因子
```python
# Conservative (safer, better quality)
scaling_factor = 2.0   # 8k → 16k

# Moderate (good balance)
scaling_factor = 4.0   # 8k → 32k

# Aggressive (requires more fine-tuning)
scaling_factor = 8.0   # 8k → 64k
scaling_factor = 16.0  # 8k → 128k

# Rule of thumb: larger factors need more fine-tuning steps
steps_needed = 100 * scaling_factor  # Rough estimate
```
undefined3. Fine-tuning Data
3. 微调数据
```python
# ✅ Good: long documents matching the target length
train_data = [
    {"text": long_doc_32k_tokens},  # Full 32k
    {"text": long_doc_24k_tokens},  # Varied lengths
    {"text": long_doc_16k_tokens},
]

# ❌ Bad: short documents (the model won't learn long-context behavior)
train_data = [
    {"text": short_doc_2k_tokens},
]
```

Use datasets like:
- PG-19 (books, long texts)
- arXiv papers
- Long-form conversations
- GitHub repositories (concatenated files)

### 4. Avoid Common Pitfalls
```python
# ❌ Bad: applying position interpolation without fine-tuning
model.config.rope_scaling = {"type": "linear", "factor": 16.0}
# The model will perform poorly without fine-tuning!

# ✅ Good: fine-tune after scaling
model.config.rope_scaling = {"type": "linear", "factor": 16.0}
fine_tune(model, long_documents, steps=1000)

# ❌ Bad: overly aggressive scaling without data
scale_to_1M_tokens()  # Won't work without massive fine-tuning

# ✅ Good: incremental scaling
# 8k → 16k → 32k → 64k (fine-tune at each step)
```

## Production Deployment
### Inference with Long Context
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load a long-context model
model = AutoModelForCausalLM.from_pretrained(
    "togethercomputer/LLaMA-2-7B-32K",  # 32k context
    torch_dtype=torch.float16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("togethercomputer/LLaMA-2-7B-32K")

# Process a long document
long_text = "..." * 30000  # placeholder, ~30k tokens
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to('cuda')

# Generate
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    temperature=0.7,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
```

### Memory Optimization
```python
import torch
from transformers import AutoModelForCausalLM

# Use gradient checkpointing for fine-tuning
model.gradient_checkpointing_enable()

# Use FlashAttention 2
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",  # 2-3× faster
    torch_dtype=torch.float16
)

# Use paged attention (vLLM)
from vllm import LLM
llm = LLM(
    model="togethercomputer/LLaMA-2-7B-32K",
    max_model_len=32768,  # 32k context
    gpu_memory_utilization=0.9
)
```

## Resources
- RoPE Paper: https://arxiv.org/abs/2104.09864 (RoFormer)
- YaRN Paper: https://arxiv.org/abs/2309.00071
- ALiBi Paper: https://arxiv.org/abs/2108.12409 (Train Short, Test Long)
- Position Interpolation: https://arxiv.org/abs/2306.15595
- HuggingFace RoPE Utils: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py
- YaRN Implementation: https://github.com/jquesnelle/yarn
- Together AI Blog: https://www.together.ai/blog/llama-2-7b-32k
## See Also
- `references/rope.md` - Detailed RoPE implementation and theory
- `references/extension_methods.md` - YaRN, ALiBi, Position Interpolation comparisons
- `references/fine_tuning.md` - Complete fine-tuning guide for context extension