# Using PyTorch

PyTorch is a deep learning framework with dynamic computation graphs, strong GPU acceleration, and Pythonic design. This skill covers practical patterns for building production-quality neural networks.


## Core Concepts

### Tensors

```python
import torch

# Create tensors
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
x = torch.zeros(3, 4)
x = torch.randn(3, 4)  # Normal distribution

# Device management
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = x.to(device)

# Operations (all return new tensors)
y = x + 1
y = x @ x.T  # Matrix multiplication
y = x.view(2, 6)  # Reshape
```
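The "all return new tensors" point can be checked directly; a small CPU-only sketch:

```python
import torch

x = torch.zeros(3, 4)
y = x + 1  # out-of-place: x is untouched

assert x.sum().item() == 0.0
assert y.sum().item() == 12.0

# view() returns a new tensor object that shares storage with x
z = x.view(2, 6)
assert z.shape == (2, 6)
assert z.data_ptr() == x.data_ptr()
```

In-place variants (`x.add_(1)`, note the trailing underscore) exist but mutate their input, which interacts badly with autograd; prefer the out-of-place forms shown above.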

### Autograd

```python
# Enable gradient tracking
x = torch.randn(3, requires_grad=True)
y = x ** 2
loss = y.sum()

# Compute gradients
loss.backward()
print(x.grad)  # d(loss)/dx

# Disable gradients for inference
with torch.no_grad():
    pred = model(x)

# Or use inference mode (more efficient)
with torch.inference_mode():
    pred = model(x)
```
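As a sanity check on the math: for `loss = (x ** 2).sum()`, autograd should produce exactly `2 * x`:

```python
import torch

x = torch.randn(3, requires_grad=True)
loss = (x ** 2).sum()
loss.backward()

# d(sum(x^2))/dx = 2x
assert torch.allclose(x.grad, 2 * x.detach())

# Results computed under inference_mode never track gradients
with torch.inference_mode():
    y = x * 2
assert not y.requires_grad
```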

## Model Architecture

### nn.Module Pattern

```python
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
```
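A quick shape check of the pattern (the class is repeated so the snippet runs standalone; the dimensions are arbitrary):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

model = Model(input_dim=8, hidden_dim=16, output_dim=4)
out = model(torch.randn(32, 8))  # batch of 32 samples
assert out.shape == (32, 4)
```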

### Common Layers

```python
# Convolution
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

# Normalization
nn.BatchNorm2d(num_features)
nn.LayerNorm(normalized_shape)

# Attention
nn.MultiheadAttention(embed_dim, num_heads)

# Recurrent
nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
```

### Weight Initialization

```python
def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=0.02)

model.apply(init_weights)
```

## Training Loop

### Standard Pattern

```python
model = Model(input_dim, hidden_dim, output_dim).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    model.train()
    for batch in train_loader:
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        # Optional: gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

    scheduler.step()

    # Validation
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            ...  # validation logic
```
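To make the pattern concrete, here is a self-contained miniature run on synthetic data (the dimensions, dataset, and epoch count are illustrative assumptions, not part of the recipe):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Synthetic 3-class classification data (illustrative only)
inputs = torch.randn(64, 8)
targets = torch.randint(0, 3, (64,))
train_loader = DataLoader(TensorDataset(inputs, targets), batch_size=16, shuffle=True)

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
num_epochs = 3
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
    scheduler.step()

print(f"final batch loss: {loss.item():.4f}")
```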

### Mixed Precision Training

```python
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in train_loader:
    inputs, targets = batch
    inputs, targets = inputs.to(device), targets.to(device)

    optimizer.zero_grad()

    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

### Gradient Accumulation

```python
# Requires setup from Mixed Precision Training above:
# scaler = GradScaler(), model, criterion, optimizer, device
accumulation_steps = 4

for i, batch in enumerate(train_loader):
    inputs, targets = batch
    inputs, targets = inputs.to(device), targets.to(device)

    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets) / accumulation_steps

    scaler.scale(loss).backward()

    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```
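Why divide the loss by `accumulation_steps`: with a mean-reduced loss, the accumulated gradient then matches the full-batch gradient exactly. A CPU-only sketch illustrating the equivalence (AMP omitted for clarity):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
data, target = torch.randn(8, 4), torch.randn(8, 1)

# Gradient from the full batch
model.zero_grad()
criterion(model(data), target).backward()
full_grad = model.weight.grad.clone()

# Same gradient, accumulated over 2 equal micro-batches
accumulation_steps = 2
model.zero_grad()
for x, y in zip(data.chunk(accumulation_steps), target.chunk(accumulation_steps)):
    (criterion(model(x), y) / accumulation_steps).backward()

assert torch.allclose(model.weight.grad, full_grad, atol=1e-6)
```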

## Data Loading

### Dataset and DataLoader

```python
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        if self.transform:
            x = self.transform(x)
        return x, self.labels[idx]

train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,  # Faster GPU transfer
    drop_last=True,   # Consistent batch sizes
)
```

### Collate Functions

```python
def collate_fn(batch):
    """Custom batching for variable-length sequences."""
    inputs, targets = zip(*batch)
    inputs = nn.utils.rnn.pad_sequence(inputs, batch_first=True)
    targets = torch.stack(targets)
    return inputs, targets

loader = DataLoader(dataset, collate_fn=collate_fn)
```
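For example, collating two sequences of different lengths pads the batch to the longest one (the toy shapes here are illustrative):

```python
import torch
import torch.nn as nn

def collate_fn(batch):
    inputs, targets = zip(*batch)
    inputs = nn.utils.rnn.pad_sequence(inputs, batch_first=True)
    targets = torch.stack(targets)
    return inputs, targets

batch = [
    (torch.ones(2, 5), torch.tensor(0)),  # sequence of length 2
    (torch.ones(4, 5), torch.tensor(1)),  # sequence of length 4
]
inputs, targets = collate_fn(batch)
assert inputs.shape == (2, 4, 5)  # padded (with zeros) to length 4
assert targets.tolist() == [0, 1]
```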

## Performance Optimization

### torch.compile (PyTorch 2.0+)

```python
# Basic compilation
model = torch.compile(model)

# With options
model = torch.compile(
    model,
    mode="reduce-overhead",  # Options: default, reduce-overhead, max-autotune
    fullgraph=True,          # Enforce no graph breaks
)

# Compile specific functions
@torch.compile
def train_step(model, inputs, targets):
    outputs = model(inputs)
    return criterion(outputs, targets)
```

**Compilation modes:**
- `default`: Good balance of compile time and speedup
- `reduce-overhead`: Minimizes framework overhead, good for small models
- `max-autotune`: Maximum performance, longer compile time

### Memory Optimization

```python
# Activation checkpointing (trade compute for memory)
from torch.utils.checkpoint import checkpoint

class Model(nn.Module):
    def forward(self, x):
        # Recompute activations during backward
        x = checkpoint(self.expensive_layer, x, use_reentrant=False)
        return self.output_layer(x)

# Clear cache
torch.cuda.empty_cache()

# Monitor memory
print(torch.cuda.memory_allocated() / 1e9, "GB")
print(torch.cuda.max_memory_allocated() / 1e9, "GB")
```

## Distributed Training

### DistributedDataParallel (DDP)

```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):
    setup(rank, world_size)

    model = Model().to(rank)
    model = DDP(model, device_ids=[rank])

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, sampler=sampler)

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # Important for shuffling
        # ... training loop

    cleanup()
```

Launch with: `torchrun --nproc_per_node=4 train.py`

### FullyShardedDataParallel (FSDP)

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = FSDP(
    model,
    mixed_precision=mp_policy,
    use_orig_params=True,  # Required for torch.compile compatibility
)
```

## Saving and Loading

### Checkpoints

```python
# Save
torch.save({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
}, "checkpoint.pt")

# Load
checkpoint = torch.load("checkpoint.pt", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
```
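A minimal round trip through an in-memory buffer shows the state dict survives intact (a sketch; real checkpointing uses a file path as above):

```python
import io
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

buffer = io.BytesIO()
torch.save({"model_state_dict": model.state_dict()}, buffer)
buffer.seek(0)

# Fresh module with different random weights, then restore
restored = nn.Linear(4, 2)
checkpoint = torch.load(buffer)
restored.load_state_dict(checkpoint["model_state_dict"])

assert torch.equal(restored.weight, model.weight)
assert torch.equal(restored.bias, model.bias)
```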

### Export for Deployment

```python
# TorchScript
scripted = torch.jit.script(model)
scripted.save("model.pt")

# ONNX
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)
```

## Best Practices

1. **Always set model mode**: Use `model.train()` and `model.eval()` appropriately
2. **Use `inference_mode` over `no_grad`**: More efficient for inference
3. **Pin memory for GPU training**: Set `pin_memory=True` in DataLoader
4. **Profile before optimizing**: Use `torch.profiler` to find bottlenecks
5. **Prefer bfloat16 over float16**: Better numerical stability on modern GPUs
6. **Use `torch.compile`**: Significant speedups with minimal code changes
7. **Set deterministic mode for reproducibility**:
   ```python
   torch.manual_seed(42)
   torch.backends.cudnn.deterministic = True
   torch.backends.cudnn.benchmark = False
   ```
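Item 4 mentions `torch.profiler` without showing it; a minimal CPU profiling sketch (the model and sizes are placeholders):

```python
import torch
from torch.profiler import ProfilerActivity, profile

model = torch.nn.Linear(128, 128)
x = torch.randn(32, 128)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(10):
        model(x)

# Top operators by total CPU time
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```

Add `ProfilerActivity.CUDA` to the activities list to capture GPU kernels as well.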

## References

See `reference/` for detailed documentation:
- `training-patterns.md` - Advanced training techniques
- `debugging.md` - Debugging and profiling tools