torch-tensor-parallelism

Tensor Parallelism Implementation Guide

This skill provides guidance for implementing tensor parallelism patterns in PyTorch, specifically for ColumnParallelLinear and RowParallelLinear layers that distribute computation across multiple devices.

Core Concepts

Tensor Parallelism Overview

Tensor parallelism splits individual layers across multiple devices to parallelize computation within a single forward/backward pass. The two primary patterns are:
  1. ColumnParallelLinear: Shards weights along the output dimension (columns). Each device computes a portion of the output features, then results are concatenated via all-gather.
  2. RowParallelLinear: Shards weights along the input dimension (rows). Each device computes partial outputs using its shard of the input, then results are summed via all-reduce.
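Both patterns rest on a block-matrix identity: splitting W by output rows makes the full output the concatenation of per-shard outputs, while splitting W by input columns makes it the sum of per-shard partial products. A minimal single-process sketch with world_size=2 (all variable names here are illustrative):

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 4)             # (batch=3, in_features=4)
W = torch.randn(6, 4)             # (out_features=6, in_features=4)
y_ref = x @ W.T                   # baseline dense result, shape (3, 6)

# Column parallel: split W along the output dim, concatenate shard outputs.
W0, W1 = W.chunk(2, dim=0)
y_col = torch.cat([x @ W0.T, x @ W1.T], dim=-1)

# Row parallel: split W along the input dim, shard x to match, sum partials.
Wa, Wb = W.chunk(2, dim=1)
xa, xb = x.chunk(2, dim=-1)
y_row = xa @ Wa.T + xb @ Wb.T

assert torch.allclose(y_col, y_ref, atol=1e-6)
assert torch.allclose(y_row, y_ref, atol=1e-6)
```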

Critical Implementation Requirement

When implementing tensor parallelism (especially in simulation or testing contexts), the forward pass must actually perform the collective operations, not just compute local shards:
  • ColumnParallelLinear: Must concatenate outputs from all ranks (all-gather semantics)
  • RowParallelLinear: Must sum outputs from all ranks (all-reduce semantics)
A common mistake is returning only the local shard and expecting an external framework to handle collective operations. Unless explicitly specified otherwise, the implementation should produce the final, complete output.
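A toy column-parallel sketch of this mistake and its fix (the names `W_shard`, `y_local`, and the simulated gather are illustrative, not any framework's API):

```python
import torch

x = torch.randn(2, 4)
W = torch.randn(6, 4)                       # full weight, out_features=6
world_size, rank = 2, 0
W_shard = W.chunk(world_size, dim=0)[rank]  # this rank's column-parallel shard

# The common mistake: returning only the local shard, shape (2, 3).
y_local = x @ W_shard.T
assert y_local.shape == (2, 3)

# Correct: perform the all-gather (simulated here by concatenating every
# rank's local result) so the caller receives the full (2, 6) output.
y_full = torch.cat([x @ W_s.T for W_s in W.chunk(world_size, dim=0)], dim=-1)
assert y_full.shape == (2, 6)
```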

Implementation Approach

Step 1: Understand the Parallelism Pattern

Before implementing, clearly identify:
  1. Which dimension is being sharded (input features vs output features)
  2. What collective operation combines the results (all-gather vs all-reduce)
  3. Whether the implementation should simulate distributed execution or prepare for actual distributed execution
  4. How bias should be handled in the parallel context

Step 2: Weight Sharding

For weight matrix W of shape (out_features, in_features):
ColumnParallelLinear:
  • Shard W along dim=0 (output features)
  • Each rank gets W_shard of shape (out_features // world_size, in_features)
  • Output shape per rank: (batch, out_features // world_size)
RowParallelLinear:
  • Shard W along dim=1 (input features)
  • Each rank gets W_shard of shape (out_features, in_features // world_size)
  • Input to each rank should be the corresponding shard of the full input
  • Output shape per rank: (batch, out_features) - partial sum
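These sharding rules can be sketched with torch.chunk; `shard_weight` below is a hypothetical helper written for illustration, not a PyTorch API:

```python
import torch

def shard_weight(W: torch.Tensor, rank: int, world_size: int, mode: str) -> torch.Tensor:
    """Return one rank's shard of a (out_features, in_features) weight.

    mode="column" shards output features (dim=0);
    mode="row" shards input features (dim=1).
    Assumes the sharded dimension divides evenly by world_size.
    """
    dim = 0 if mode == "column" else 1
    assert W.size(dim) % world_size == 0, "sharded dim must divide evenly"
    return W.chunk(world_size, dim=dim)[rank]

W = torch.randn(8, 4)
assert shard_weight(W, rank=1, world_size=2, mode="column").shape == (4, 4)
assert shard_weight(W, rank=1, world_size=2, mode="row").shape == (8, 2)
```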

Step 3: Forward Pass Implementation

ColumnParallelLinear Forward:
1. Compute local output: y_local = x @ W_shard.T + bias_shard (if the bias is sharded per rank)
2. All-gather to concatenate: y = concat([y_0, y_1, ..., y_n], dim=-1)
3. Return complete output of shape (batch, out_features)
RowParallelLinear Forward:
1. Get input shard: x_shard = x[..., start:end] for this rank
2. Compute partial output: y_partial = x_shard @ W_shard.T
3. All-reduce to sum: y = sum([y_0, y_1, ..., y_n])
4. Add bias (only once, not per-rank): y = y + bias
5. Return complete output of shape (batch, out_features)
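The two forward passes above can be put together as single-process simulations. The function names are illustrative; a real distributed version would replace the list comprehensions with torch.distributed collectives:

```python
import torch

def column_parallel_forward(x, W, bias, world_size):
    """Simulated ColumnParallelLinear: each 'rank' computes its output
    shard locally, then shards are concatenated (all-gather semantics)."""
    shards = [x @ W_s.T + b_s                    # step 1: local compute
              for W_s, b_s in zip(W.chunk(world_size, dim=0),
                                  bias.chunk(world_size, dim=0))]
    return torch.cat(shards, dim=-1)             # step 2: all-gather

def row_parallel_forward(x, W, bias, world_size):
    """Simulated RowParallelLinear: each 'rank' computes a partial output
    from its input shard, then partials are summed (all-reduce semantics)."""
    partials = [x_s @ W_s.T                      # steps 1-2: shard, compute
                for x_s, W_s in zip(x.chunk(world_size, dim=-1),
                                    W.chunk(world_size, dim=1))]
    return sum(partials) + bias                  # steps 3-4: all-reduce, bias once

torch.manual_seed(0)
x, W, b = torch.randn(3, 8), torch.randn(6, 8), torch.randn(6)
y_ref = x @ W.T + b
assert torch.allclose(column_parallel_forward(x, W, b, 2), y_ref, atol=1e-5)
assert torch.allclose(row_parallel_forward(x, W, b, 2), y_ref, atol=1e-5)
```

Note that both functions return the full (batch, out_features) output, as the critical requirement above demands.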

Step 4: Bias Handling

ColumnParallelLinear:
  • Bias can be sharded along with output features
  • Each rank adds its bias shard to its output shard
  • After all-gather, the full bias has been applied
RowParallelLinear:
  • Bias must NOT be sharded or added per-rank (adding it on every rank would apply it world_size times)
  • Add bias only once, after the all-reduce operation
  • Typically either only rank 0 adds the bias before the reduce, or the bias is added after the sum
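The RowParallelLinear bias failure is easy to demonstrate in a single-process simulation (variable names are illustrative):

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 4)
W = torch.randn(3, 4)
b = torch.ones(3)
world_size = 2

x_shards = x.chunk(world_size, dim=-1)
W_shards = W.chunk(world_size, dim=1)

# WRONG: every rank adds the full bias before the all-reduce sum,
# so the final output contains world_size copies of the bias.
wrong = sum(x_s @ W_s.T + b for x_s, W_s in zip(x_shards, W_shards))

# RIGHT: sum the partial products first, then add the bias exactly once.
right = sum(x_s @ W_s.T for x_s, W_s in zip(x_shards, W_shards)) + b

# The two differ by exactly (world_size - 1) extra copies of the bias.
assert torch.allclose(wrong - right, (world_size - 1) * b.expand(2, 3), atol=1e-5)
```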

Verification Strategies

Mathematical Verification

When local testing is unavailable, verify implementation correctness through mathematical analysis:
  1. Simple example: Use a 2x4 weight matrix with world_size=2
  2. Trace computation: Manually compute what each rank produces
  3. Verify combination: Confirm all-gather/all-reduce produces correct final output
  4. Compare to baseline: Verify parallel output matches non-parallel computation
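The suggested 2x4 trace can be carried out directly; the values below are chosen so every intermediate result can be checked by hand:

```python
import torch

# 2x4 weight (out_features=2, in_features=4), world_size=2.
x = torch.tensor([[1., 2., 3., 4.]])
W = torch.tensor([[1., 0., 1., 0.],
                  [0., 1., 0., 1.]])
y_ref = x @ W.T                       # [[1+3, 2+4]] = [[4., 6.]]

# ColumnParallel: rank 0 holds row 0 of W, rank 1 holds row 1.
y0 = x @ W[0:1].T                     # [[4.]]
y1 = x @ W[1:2].T                     # [[6.]]
assert torch.equal(torch.cat([y0, y1], dim=-1), y_ref)

# RowParallel: rank 0 holds columns 0-1, rank 1 holds columns 2-3.
p0 = x[:, :2] @ W[:, :2].T            # [[1., 2.]]
p1 = x[:, 2:] @ W[:, 2:].T            # [[3., 4.]]
assert torch.equal(p0 + p1, y_ref)
```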

Shape Verification Checklist

  • Input shape matches expected (batch, in_features)
  • Weight shard shape matches expected partitioning
  • Local output shape is correct for the parallelism type
  • Final output shape matches (batch, out_features) - NOT the sharded dimension

Test Cases to Consider

  1. world_size=1: Trivial case, should match non-parallel implementation exactly
  2. world_size=2,4,8: Common parallel configurations
  3. Non-divisible dimensions: What happens when out_features % world_size != 0?
  4. Different batch sizes: Verify batch dimension is handled correctly
  5. With and without bias: Test both configurations
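Most of these cases can be swept against nn.Linear as a single-process oracle. The `simulate_*` helpers below are illustrative stand-ins for the parallel layers under test; the non-divisible and bias-free cases are left out of this sketch:

```python
import torch
import torch.nn as nn

def simulate_column(x, linear, world_size):
    W, b = linear.weight, linear.bias
    return torch.cat([x @ W_s.T + b_s
                      for W_s, b_s in zip(W.chunk(world_size, dim=0),
                                          b.chunk(world_size, dim=0))], dim=-1)

def simulate_row(x, linear, world_size):
    W, b = linear.weight, linear.bias
    return sum(x_s @ W_s.T
               for x_s, W_s in zip(x.chunk(world_size, dim=-1),
                                   W.chunk(world_size, dim=1))) + b

torch.manual_seed(0)
lin = nn.Linear(16, 8)                   # both dims divisible by 8
for world_size in (1, 2, 4, 8):          # includes the trivial case
    for batch in (1, 5):                 # different batch sizes
        x = torch.randn(batch, 16)
        y_ref = lin(x)
        assert torch.allclose(simulate_column(x, lin, world_size), y_ref, atol=1e-5)
        assert torch.allclose(simulate_row(x, lin, world_size), y_ref, atol=1e-5)
```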

Common Pitfalls

Pitfall 1: Returning Local Shards Only

Symptom: Output tensor size is (out_features / world_size) instead of (out_features)
Cause: Implementation computes local shard but doesn't perform all-gather
Fix: Implement the collective operation to combine results from all ranks

Pitfall 2: Incorrect Bias Handling in RowParallelLinear

Symptom: Output values are N times larger than expected (where N is world_size)
Cause: Each rank adds the full bias, then values are summed
Fix: Add bias only once after all-reduce, not per-rank

Pitfall 3: Misinterpreting "Simulation" Requirements

Symptom: Implementation works for world_size=1 but fails for larger world sizes
Cause: Assuming external framework handles collective operations
Fix: Read requirements carefully: "as if using all_gather" means the forward pass must perform the gather itself

Pitfall 4: Truncated File Writes

Symptom: Implementation has syntax errors or missing code
Cause: File write operation was truncated
Fix: Always read back the complete file after writing to verify integrity

Pitfall 5: Wrong Dimension for Sharding

Symptom: Shape mismatch errors during matrix multiplication
Cause: Sharding along wrong dimension (rows vs columns confusion)
Fix: ColumnParallel shards output features (dim=0 of weight), RowParallel shards input features (dim=1 of weight)

Pre-Implementation Checklist

Before writing code, confirm understanding of:
  • Which collective operation is needed (all-gather vs all-reduce)
  • What the final output shape should be
  • Whether simulation should actually perform collective ops or defer them
  • How bias should be handled for this parallelism type
  • What happens for edge cases (world_size=1, non-divisible dimensions)

Post-Implementation Checklist

After writing code:
  • Read back the complete implementation file to verify no truncation
  • Verify output shapes match expected dimensions for all world sizes
  • Trace through a simple example manually to verify correctness
  • Test trivial case (world_size=1) matches non-parallel baseline
  • Test at least one non-trivial case (world_size=2 or 4)