at-dispatch-v2
Compare original and translation side by side
🇺🇸
Original
English🇨🇳
Translation
ChineseAT_DISPATCH to AT_DISPATCH_V2 Converter
AT_DISPATCH 转 AT_DISPATCH_V2 转换器
This skill helps convert PyTorch's legacy AT_DISPATCH macros to the new AT_DISPATCH_V2 format, as defined in .
aten/src/ATen/Dispatch_v2.h此技能可帮助将PyTorch的旧版AT_DISPATCH宏转换为新的AT_DISPATCH_V2格式,该格式定义于中。
aten/src/ATen/Dispatch_v2.hWhen to use this skill
何时使用此技能
Use this skill when:
- Converting AT_DISPATCH_* macros to AT_DISPATCH_V2
- Porting ATen kernels to use the new dispatch API
- Working with files in that use dispatch macros
aten/src/ATen/native/ - User mentions "AT_DISPATCH", "dispatch v2", "Dispatch_v2.h", or macro conversion
在以下场景使用此技能:
- 将AT_DISPATCH_*宏转换为AT_DISPATCH_V2
- 移植ATen内核以使用新的分发API
- 处理目录中使用分发宏的文件
aten/src/ATen/native/ - 用户提及"AT_DISPATCH"、"dispatch v2"、"Dispatch_v2.h"或宏转换
Quick reference
快速参考
Old format:
cpp
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, dtype, "kernel_name", [&]() {
// lambda body
});New format:
cpp
AT_DISPATCH_V2(dtype, "kernel_name", AT_WRAP([&]() {
// lambda body
}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool);旧格式:
cpp
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, dtype, "kernel_name", [&]() {
// lambda body
});新格式:
cpp
AT_DISPATCH_V2(dtype, "kernel_name", AT_WRAP([&]() {
// lambda body
}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool);Key transformations
核心转换要点
- Reorder arguments: and
scalar_typecome first, then lambda, then typesname - Wrap the lambda: Use to handle internal commas
AT_WRAP(lambda) - Expand type groups: Use instead of implicit expansion
AT_EXPAND(AT_ALL_TYPES) - List individual types: Add extra types (kHalf, kBFloat16, etc.) after expanded groups
- Add include: near other Dispatch includes
#include <ATen/Dispatch_v2.h>
- 调整参数顺序:和
scalar_type放在最前面,然后是lambda表达式,最后是类型name - 包装lambda表达式:使用处理内部逗号
AT_WRAP(lambda) - 展开类型组:用替代隐式展开
AT_EXPAND(AT_ALL_TYPES) - 列出独立类型:在展开的类型组后添加额外类型(如kHalf、kBFloat16等)
- 添加头文件引用:在其他Dispatch头文件附近添加
#include <ATen/Dispatch_v2.h>
Instructions
操作步骤
Step 1: Add the Dispatch_v2.h include
步骤1:添加Dispatch_v2.h头文件引用
Add the v2 header near the existing :
#include <ATen/Dispatch.h>cpp
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>Keep the old Dispatch.h include for now (other code may still need it).
在现有的附近添加v2版本的头文件:
#include <ATen/Dispatch.h>cpp
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>暂时保留旧的Dispatch.h引用(其他代码可能仍需使用)。
Step 2: Identify the old dispatch pattern
步骤2:识别旧的分发模式
Common patterns to convert:
AT_DISPATCH_ALL_TYPES_AND{2,3,4}(type1, type2, ..., scalar_type, name, lambda)AT_DISPATCH_FLOATING_TYPES_AND{2,3}(type1, type2, ..., scalar_type, name, lambda)AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3}(type1, ..., scalar_type, name, lambda)AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3}(type1, ..., scalar_type, name, lambda)
需要转换的常见模式:
AT_DISPATCH_ALL_TYPES_AND{2,3,4}(type1, type2, ..., scalar_type, name, lambda)AT_DISPATCH_FLOATING_TYPES_AND{2,3}(type1, type2, ..., scalar_type, name, lambda)AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3}(type1, ..., scalar_type, name, lambda)AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3}(type1, ..., scalar_type, name, lambda)
Step 3: Map the old macro to type groups
步骤3:将旧宏映射到类型组
Identify which type group macro corresponds to the base types:
| Old macro base | AT_DISPATCH_V2 type group |
|---|---|
| |
| |
| |
| |
| |
For combined patterns, use multiple entries:
AT_EXPAND()cpp
// Old: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(...)
// New: AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), type1, type2确定哪个类型组宏对应基础类型:
| 旧宏基础 | AT_DISPATCH_V2类型组 |
|---|---|
| |
| |
| |
| |
| |
对于组合模式,使用多个条目:
AT_EXPAND()cpp
// 旧版:AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(...)
// 新版:AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), type1, type2Step 4: Extract the individual types
步骤4:提取独立类型
From or , extract the individual types (type1, type2, etc.).
AT_DISPATCH_*_AND2(type1, type2, ...)AT_DISPATCH_*_AND3(type1, type2, type3, ...)These become the trailing arguments after the type group:
cpp
AT_DISPATCH_V2(..., AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool)
^^^^^^^^^^^^^^^^^^^^^^^^
Individual types from AND3从或中提取独立类型(type1、type2等)。
AT_DISPATCH_*_AND2(type1, type2, ...)AT_DISPATCH_*_AND3(type1, type2, type3, ...)这些类型将作为类型组之后的尾随参数:
cpp
AT_DISPATCH_V2(..., AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool)
^^^^^^^^^^^^^^^^^^^^^^^^
来自AND3的独立类型Step 5: Transform to AT_DISPATCH_V2
步骤5:转换为AT_DISPATCH_V2格式
Apply the transformation:
Pattern:
cpp
AT_DISPATCH_V2(
scalar_type, // 1st: The dtype expression
"name", // 2nd: The debug string
AT_WRAP(lambda), // 3rd: The lambda wrapped in AT_WRAP
type_groups, // 4th+: Type groups with AT_EXPAND()
individual_types // Last: Individual types
)Example transformation:
cpp
// BEFORE
AT_DISPATCH_ALL_TYPES_AND3(
kBFloat16, kHalf, kBool,
iter.dtype(),
"min_values_cuda",
[&]() {
min_values_kernel_cuda_impl<scalar_t>(iter);
}
);
// AFTER
AT_DISPATCH_V2(
iter.dtype(),
"min_values_cuda",
AT_WRAP([&]() {
min_values_kernel_cuda_impl<scalar_t>(iter);
}),
AT_EXPAND(AT_ALL_TYPES),
kBFloat16, kHalf, kBool
);应用转换规则:
模式:
cpp
AT_DISPATCH_V2(
scalar_type, // 第1个参数:dtype表达式
"name", // 第2个参数:调试字符串
AT_WRAP(lambda), // 第3个参数:用AT_WRAP包装的lambda表达式
type_groups, // 第4及以后参数:使用AT_EXPAND()的类型组
individual_types // 最后:独立类型
)转换示例:
cpp
// 转换前
AT_DISPATCH_ALL_TYPES_AND3(
kBFloat16, kHalf, kBool,
iter.dtype(),
"min_values_cuda",
[&]() {
min_values_kernel_cuda_impl<scalar_t>(iter);
}
);
// 转换后
AT_DISPATCH_V2(
iter.dtype(),
"min_values_cuda",
AT_WRAP([&]() {
min_values_kernel_cuda_impl<scalar_t>(iter);
}),
AT_EXPAND(AT_ALL_TYPES),
kBFloat16, kHalf, kBool
);Step 6: Handle multi-line lambdas
步骤6:处理多行lambda表达式
For lambdas with internal commas or complex expressions, AT_WRAP is essential:
cpp
AT_DISPATCH_V2(
dtype,
"complex_kernel",
AT_WRAP([&]() {
gpu_reduce_kernel<scalar_t, scalar_t>(
iter,
MinOps<scalar_t>{},
thrust::pair<scalar_t, int64_t>(upper_bound(), 0) // Commas inside!
);
}),
AT_EXPAND(AT_ALL_TYPES)
);对于包含内部逗号或复杂表达式的lambda,AT_WRAP是必不可少的:
cpp
AT_DISPATCH_V2(
dtype,
"complex_kernel",
AT_WRAP([&]() {
gpu_reduce_kernel<scalar_t, scalar_t>(
iter,
MinOps<scalar_t>{},
thrust::pair<scalar_t, int64_t>(upper_bound(), 0) // 内部包含逗号!
);
}),
AT_EXPAND(AT_ALL_TYPES)
);Step 7: Verify the conversion
步骤7:验证转换结果
Check that:
- wraps the entire lambda
AT_WRAP() - Type groups use
AT_EXPAND() - Individual types don't have (just
AT_EXPAND(), notkBFloat16)AT_EXPAND(kBFloat16) - Argument order is: scalar_type, name, lambda, types
- Include added:
#include <ATen/Dispatch_v2.h>
检查以下内容:
- 包裹了整个lambda表达式
AT_WRAP() - 类型组使用了
AT_EXPAND() - 独立类型没有使用(仅使用
AT_EXPAND(),而非kBFloat16)AT_EXPAND(kBFloat16) - 参数顺序为:scalar_type、name、lambda、类型
- 添加了头文件引用:
#include <ATen/Dispatch_v2.h>
Type group reference
类型组参考
Available type group macros (use with ):
AT_EXPAND()cpp
AT_INTEGRAL_TYPES // kByte, kChar, kInt, kLong, kShort
AT_FLOATING_TYPES // kDouble, kFloat
AT_COMPLEX_TYPES // kComplexDouble, kComplexFloat
AT_QINT_TYPES // kQInt8, kQUInt8, kQInt32
AT_ALL_TYPES // INTEGRAL_TYPES + FLOATING_TYPES
AT_ALL_TYPES_AND_COMPLEX // ALL_TYPES + COMPLEX_TYPES
AT_INTEGRAL_TYPES_V2 // INTEGRAL_TYPES + unsigned types
AT_BAREBONES_UNSIGNED_TYPES // kUInt16, kUInt32, kUInt64
AT_FLOAT8_TYPES // Float8 variants可用的类型组宏(需配合使用):
AT_EXPAND()cpp
AT_INTEGRAL_TYPES // kByte, kChar, kInt, kLong, kShort
AT_FLOATING_TYPES // kDouble, kFloat
AT_COMPLEX_TYPES // kComplexDouble, kComplexFloat
AT_QINT_TYPES // kQInt8, kQUInt8, kQInt32
AT_ALL_TYPES // INTEGRAL_TYPES + FLOATING_TYPES
AT_ALL_TYPES_AND_COMPLEX // ALL_TYPES + COMPLEX_TYPES
AT_INTEGRAL_TYPES_V2 // INTEGRAL_TYPES + 无符号类型
AT_BAREBONES_UNSIGNED_TYPES // kUInt16, kUInt32, kUInt64
AT_FLOAT8_TYPES // Float8变体Common patterns
常见模式
Pattern: AT_DISPATCH_ALL_TYPES_AND2
模式:AT_DISPATCH_ALL_TYPES_AND2
cpp
// Before
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
kernel<scalar_t>(data);
});
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>(data);
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);cpp
// 转换前
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
kernel<scalar_t>(data);
});
// 转换后
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>(data);
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);Pattern: AT_DISPATCH_FLOATING_TYPES_AND3
模式:AT_DISPATCH_FLOATING_TYPES_AND3
cpp
// Before
AT_DISPATCH_FLOATING_TYPES_AND3(kHalf, kBFloat16, kFloat8_e4m3fn,
tensor.scalar_type(), "float_op", [&] {
process<scalar_t>(tensor);
});
// After
AT_DISPATCH_V2(tensor.scalar_type(), "float_op", AT_WRAP([&] {
process<scalar_t>(tensor);
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn);cpp
// 转换前
AT_DISPATCH_FLOATING_TYPES_AND3(kHalf, kBFloat16, kFloat8_e4m3fn,
tensor.scalar_type(), "float_op", [&] {
process<scalar_t>(tensor);
});
// 转换后
AT_DISPATCH_V2(tensor.scalar_type(), "float_op", AT_WRAP([&] {
process<scalar_t>(tensor);
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn);Pattern: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2
模式:AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2
cpp
// Before
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kComplexHalf, kHalf,
self.scalar_type(),
"complex_op",
[&] {
result = compute<scalar_t>(self);
}
);
// After
AT_DISPATCH_V2(
self.scalar_type(),
"complex_op",
AT_WRAP([&] {
result = compute<scalar_t>(self);
}),
AT_EXPAND(AT_ALL_TYPES),
AT_EXPAND(AT_COMPLEX_TYPES),
kComplexHalf,
kHalf
);cpp
// 转换前
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kComplexHalf, kHalf,
self.scalar_type(),
"complex_op",
[&] {
result = compute<scalar_t>(self);
}
);
// 转换后
AT_DISPATCH_V2(
self.scalar_type(),
"complex_op",
AT_WRAP([&] {
result = compute<scalar_t>(self);
}),
AT_EXPAND(AT_ALL_TYPES),
AT_EXPAND(AT_COMPLEX_TYPES),
kComplexHalf,
kHalf
);Edge cases
边缘情况
Case 1: No extra types (rare)
情况1:无额外类型(罕见)
cpp
// Before
AT_DISPATCH_ALL_TYPES(dtype, "op", [&]() { kernel<scalar_t>(); });
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES));cpp
// 转换前
AT_DISPATCH_ALL_TYPES(dtype, "op", [&]() { kernel<scalar_t>(); });
// 转换后
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES));Case 2: Many individual types (AND4, AND5, etc.)
情况2:多个独立类型(AND4、AND5等)
cpp
// Before
AT_DISPATCH_FLOATING_TYPES_AND4(kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2,
dtype, "float8_op", [&]() { kernel<scalar_t>(); });
// After
AT_DISPATCH_V2(dtype, "float8_op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2);cpp
// 转换前
AT_DISPATCH_FLOATING_TYPES_AND4(kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2,
dtype, "float8_op", [&]() { kernel<scalar_t>(); });
// 转换后
AT_DISPATCH_V2(dtype, "float8_op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2);Case 3: Lambda with no captures
情况3:无捕获的lambda表达式
cpp
// Before
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, dtype, "op", []() {
static_kernel<scalar_t>();
});
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([]() {
static_kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBool);cpp
// 转换前
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, dtype, "op", []() {
static_kernel<scalar_t>();
});
// 转换后
AT_DISPATCH_V2(dtype, "op", AT_WRAP([]() {
static_kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBool);Benefits of AT_DISPATCH_V2
AT_DISPATCH_V2的优势
- No arity in macro name: Don't need different macros for AND2, AND3, AND4
- Composable type sets: Mix and match type groups with
AT_EXPAND() - Extensible: Easy to add more types without hitting macro limits
- Clearer: Type groups are explicit, not implicit in macro name
- 宏名称无参数限制:无需为AND2、AND3、AND4使用不同的宏
- 类型集可组合:使用混合搭配类型组
AT_EXPAND() - 可扩展性强:无需受宏限制即可轻松添加更多类型
- 更清晰直观:类型组显式可见,而非隐含在宏名称中
Important notes
重要注意事项
- Keep - other code may need it
#include <ATen/Dispatch.h> - The is mandatory - prevents comma parsing issues in the lambda
AT_WRAP() - Type groups need , individual types don't
AT_EXPAND() - The v2 API is in - refer to it for full docs
aten/src/ATen/Dispatch_v2.h - See the header file for the Python script to regenerate the macro implementation
- 保留- 其他代码可能仍需使用
#include <ATen/Dispatch.h> - 是必填项 - 避免lambda表达式中的逗号解析问题
AT_WRAP() - 类型组需要使用,独立类型则不需要
AT_EXPAND() - v2 API位于中 - 可参考该头文件获取完整文档
aten/src/ATen/Dispatch_v2.h - 可查看头文件中的Python脚本以重新生成宏实现
Workflow
工作流程
When asked to convert AT_DISPATCH macros:
- Read the file to identify all AT_DISPATCH uses
- Add if not present
#include <ATen/Dispatch_v2.h> - For each dispatch macro:
- Identify the pattern and extract components
- Map the base type group
- Extract individual types
- Construct the AT_DISPATCH_V2 call
- Apply with Edit tool
- Show the user the complete converted file
- Explain what was changed
Do NOT compile or test the code - focus on accurate conversion only.
当被要求转换AT_DISPATCH宏时:
- 读取文件以识别所有AT_DISPATCH的使用位置
- 如果尚未添加,添加
#include <ATen/Dispatch_v2.h> - 对每个分发宏:
- 识别模式并提取组件
- 映射基础类型组
- 提取独立类型
- 构造AT_DISPATCH_V2调用
- 使用编辑工具应用转换
- 向用户展示完整的转换后文件
- 解释所做的更改
请勿编译或测试代码 - 仅专注于准确转换。