at-dispatch-v2

AT_DISPATCH to AT_DISPATCH_V2 Converter

This skill helps convert PyTorch's legacy AT_DISPATCH macros to the new AT_DISPATCH_V2 format, as defined in aten/src/ATen/Dispatch_v2.h.

When to use this skill

Use this skill when:
  • Converting AT_DISPATCH_* macros to AT_DISPATCH_V2
  • Porting ATen kernels to use the new dispatch API
  • Working with files in aten/src/ATen/native/ that use dispatch macros
  • User mentions "AT_DISPATCH", "dispatch v2", "Dispatch_v2.h", or macro conversion

Quick reference

Old format:
cpp
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, dtype, "kernel_name", [&]() {
  // lambda body
});
New format:
cpp
AT_DISPATCH_V2(dtype, "kernel_name", AT_WRAP([&]() {
  // lambda body
}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool);

Key transformations

  1. Reorder arguments: scalar_type and name come first, then lambda, then types
  2. Wrap the lambda: Use AT_WRAP(lambda) to handle internal commas
  3. Expand type groups: Use AT_EXPAND(AT_ALL_TYPES) instead of implicit expansion
  4. List individual types: Add extra types (kHalf, kBFloat16, etc.) after expanded groups
  5. Add include: #include <ATen/Dispatch_v2.h> near other Dispatch includes

Instructions

Step 1: Add the Dispatch_v2.h include

Add the v2 header near the existing #include <ATen/Dispatch.h>:
cpp
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
Keep the old Dispatch.h include for now (other code may still need it).

Step 2: Identify the old dispatch pattern

Common patterns to convert:
  • AT_DISPATCH_ALL_TYPES_AND{2,3,4}(type1, type2, ..., scalar_type, name, lambda)
  • AT_DISPATCH_FLOATING_TYPES_AND{2,3}(type1, type2, ..., scalar_type, name, lambda)
  • AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3}(type1, ..., scalar_type, name, lambda)
  • AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3}(type1, ..., scalar_type, name, lambda)

Step 3: Map the old macro to type groups

Identify which type group macro corresponds to the base types:

Old macro base           AT_DISPATCH_V2 type group
ALL_TYPES                AT_EXPAND(AT_ALL_TYPES)
FLOATING_TYPES           AT_EXPAND(AT_FLOATING_TYPES)
INTEGRAL_TYPES           AT_EXPAND(AT_INTEGRAL_TYPES)
COMPLEX_TYPES            AT_EXPAND(AT_COMPLEX_TYPES)
ALL_TYPES_AND_COMPLEX    AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX)

For combined patterns, use multiple AT_EXPAND() entries:
cpp
// Old: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(...)
// New: AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), type1, type2
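The mapping above can be captured in a small lookup. The sketch below is only an illustration: the function name v2_type_groups is hypothetical, it works on plain strings rather than on the real ATen macros, and the two-group result for FLOATING_AND_COMPLEX_TYPES is an assumption derived from the combined-pattern rule, since this document lists no single group for that base.

```cpp
#include <cassert>
#include <string>

// Sketch: map the "base" part of an old macro name to V2 type-group arguments.
// Returns an empty string for bases this sketch does not know about.
std::string v2_type_groups(const std::string& base) {
  if (base == "ALL_TYPES") return "AT_EXPAND(AT_ALL_TYPES)";
  if (base == "FLOATING_TYPES") return "AT_EXPAND(AT_FLOATING_TYPES)";
  if (base == "INTEGRAL_TYPES") return "AT_EXPAND(AT_INTEGRAL_TYPES)";
  if (base == "COMPLEX_TYPES") return "AT_EXPAND(AT_COMPLEX_TYPES)";
  if (base == "ALL_TYPES_AND_COMPLEX") return "AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX)";
  // Assumption: no single group is listed for this base, so combine two groups.
  if (base == "FLOATING_AND_COMPLEX_TYPES") {
    return "AT_EXPAND(AT_FLOATING_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)";
  }
  return "";
}
```

An empty result is a signal to stop and inspect the macro by hand rather than guess a group.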

Step 4: Extract the individual types

From AT_DISPATCH_*_AND2(type1, type2, ...) or AT_DISPATCH_*_AND3(type1, type2, type3, ...), extract the individual types (type1, type2, etc.).
These become the trailing arguments after the type group:
cpp
AT_DISPATCH_V2(..., AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool)
                                             ^^^^^^^^^^^^^^^^^^^^^^^^
                                             Individual types from AND3
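How many leading arguments are individual types follows from the macro name itself. The sketch below shows one way to read that count off the name; parse_old_macro and OldMacro are hypothetical names, and the treatment of a bare _AND suffix as "one extra type" is an assumption about the single-type variants of the old macros.

```cpp
#include <cassert>
#include <cctype>
#include <string>

struct OldMacro {
  std::string base;  // e.g. "ALL_TYPES"
  int n_extra = 0;   // number of individual types leading the argument list
};

// Sketch: split an old macro name into its base and its AND<N> arity.
// Assumes the name starts with "AT_DISPATCH_".
OldMacro parse_old_macro(const std::string& name) {
  const std::string prefix = "AT_DISPATCH_";
  assert(name.compare(0, prefix.size(), prefix) == 0);
  std::string rest = name.substr(prefix.size());

  // Peel trailing digits: "..._AND3" leaves "..._AND" and digits "3".
  std::string digits;
  while (!rest.empty() && std::isdigit(static_cast<unsigned char>(rest.back()))) {
    digits.insert(digits.begin(), rest.back());
    rest.pop_back();
  }

  OldMacro out;
  const std::string suffix = "_AND";
  const bool has_and = rest.size() > suffix.size() &&
      rest.compare(rest.size() - suffix.size(), suffix.size(), suffix) == 0;
  if (has_and) {
    out.base = rest.substr(0, rest.size() - suffix.size());
    out.n_extra = digits.empty() ? 1 : std::stoi(digits);  // bare "_AND": one type
  } else {
    out.base = rest + digits;  // no AND suffix, so no individual types
  }
  return out;
}
```

Only a trailing _AND counts, so the _AND inside ALL_TYPES_AND_COMPLEX stays part of the base.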

Step 5: Transform to AT_DISPATCH_V2

Apply the transformation:
Pattern:
cpp
AT_DISPATCH_V2(
  scalar_type,           // 1st: The dtype expression
  "name",                // 2nd: The debug string
  AT_WRAP(lambda),       // 3rd: The lambda wrapped in AT_WRAP
  type_groups,           // 4th+: Type groups with AT_EXPAND()
  individual_types       // Last: Individual types
)
Example transformation:
cpp
// BEFORE
AT_DISPATCH_ALL_TYPES_AND3(
    kBFloat16, kHalf, kBool,
    iter.dtype(),
    "min_values_cuda",
    [&]() {
      min_values_kernel_cuda_impl<scalar_t>(iter);
    }
);

// AFTER
AT_DISPATCH_V2(
    iter.dtype(),
    "min_values_cuda",
    AT_WRAP([&]() {
      min_values_kernel_cuda_impl<scalar_t>(iter);
    }),
    AT_EXPAND(AT_ALL_TYPES),
    kBFloat16, kHalf, kBool
);
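The reordering itself is mechanical, so it can be sketched as a string transformation. The helpers below (trim, split_leading_args, to_dispatch_v2) are hypothetical and not part of ATen. The sketch assumes the caller passes the text between the outer parentheses of the old call, the number of individual types, and the type-group text. It only splits off the leading arguments and treats everything after them as the lambda, so commas inside the lambda body never need to be parsed.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

std::string trim(const std::string& s) {
  const char* ws = " \t\n\r";
  const std::size_t b = s.find_first_not_of(ws);
  if (b == std::string::npos) return "";
  const std::size_t e = s.find_last_not_of(ws);
  return s.substr(b, e - b + 1);
}

// Split off the first `count` comma-separated arguments, honoring parentheses
// and string literals. The untouched remainder becomes the last element.
std::vector<std::string> split_leading_args(const std::string& args, std::size_t count) {
  std::vector<std::string> out;
  std::size_t start = 0;
  int depth = 0;
  bool in_string = false;
  for (std::size_t i = 0; i < args.size() && out.size() < count; ++i) {
    const char c = args[i];
    if (in_string) {
      if (c == '\\') ++i;  // skip the escaped character
      else if (c == '"') in_string = false;
    } else if (c == '"') {
      in_string = true;
    } else if (c == '(') {
      ++depth;
    } else if (c == ')') {
      --depth;
    } else if (c == ',' && depth == 0) {
      out.push_back(trim(args.substr(start, i - start)));
      start = i + 1;
    }
  }
  out.push_back(trim(args.substr(start)));
  return out;
}

// Old order: types..., scalar_type, name, lambda.
// New order: scalar_type, name, AT_WRAP(lambda), groups, types...
std::string to_dispatch_v2(const std::string& args, std::size_t n_extra,
                           const std::string& groups) {
  const std::vector<std::string> parts = split_leading_args(args, n_extra + 2);
  assert(parts.size() == n_extra + 3);
  std::string out = "AT_DISPATCH_V2(" + parts[n_extra] + ", " + parts[n_extra + 1] +
                    ", AT_WRAP(" + parts.back() + "), " + groups;
  for (std::size_t i = 0; i < n_extra; ++i) out += ", " + parts[i];
  return out + ")";
}
```

The output keeps the lambda text verbatim, which matches the goal of changing only the argument order and the wrappers.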

Step 6: Handle multi-line lambdas

For lambdas with internal commas or complex expressions, AT_WRAP is essential:
cpp
AT_DISPATCH_V2(
    dtype,
    "complex_kernel",
    AT_WRAP([&]() {
      gpu_reduce_kernel<scalar_t, scalar_t>(
        iter,
        MinOps<scalar_t>{},
        thrust::pair<scalar_t, int64_t>(upper_bound(), 0)  // Commas inside!
      );
    }),
    AT_EXPAND(AT_ALL_TYPES)
);
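The comma problem can be reproduced without ATen. The sketch below uses toy stand-ins (TOY_WRAP, TOY_DISPATCH, toy_run, toy_count are all hypothetical and are not the real macros) to show a callable placed in the middle of a macro argument list, followed by a variadic tail. It illustrates the general preprocessor behavior, not the actual implementation in Dispatch_v2.h.

```cpp
#include <cassert>
#include <utility>

// Toy stand-ins for illustration only; these are NOT the ATen macros.
#define TOY_WRAP(...) __VA_ARGS__

template <typename... Ts>
constexpr int toy_count(Ts...) { return static_cast<int>(sizeof...(Ts)); }

// The callable sits in the middle; a variadic "type list" follows it.
#define TOY_DISPATCH(name, body, ...) \
  toy_run((name), body, toy_count(__VA_ARGS__))

template <typename F>
long toy_run(const char* /*name*/, F&& body, int n_types) {
  return body() + n_types;
}

long toy_demo() {
  // The comma in std::pair<int, long> is not inside parentheses, so without
  // TOY_WRAP the preprocessor would end the `body` argument right there and
  // push the rest of the lambda into the variadic tail.
  return TOY_DISPATCH("demo", TOY_WRAP([&]() {
    std::pair<int, long> p(1, 2);
    return p.first + p.second;
  }), 10, 20, 30);
}
```

Removing TOY_WRAP from toy_demo makes the call fail to compile, which is the same class of error that AT_WRAP is meant to prevent.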

Step 7: Verify the conversion

Check that:
  • AT_WRAP() wraps the entire lambda
  • Type groups use AT_EXPAND()
  • Individual types don't have AT_EXPAND() (just kBFloat16, not AT_EXPAND(kBFloat16))
  • Argument order is: scalar_type, name, lambda, types
  • Include added: #include <ATen/Dispatch_v2.h>
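Part of this checklist can be approximated with plain text checks. The sketch below is a heuristic only: lint_v2_call is a hypothetical helper, it assumes individual types are spelled with a leading k followed by an uppercase letter (as in kHalf or kBFloat16), and it does not verify argument order or the include.

```cpp
#include <cassert>
#include <regex>
#include <string>
#include <vector>

// Heuristic text checks for one converted call; not an ATen API.
std::vector<std::string> lint_v2_call(const std::string& call) {
  std::vector<std::string> problems;
  if (call.find("AT_DISPATCH_V2(") == std::string::npos) {
    problems.push_back("not an AT_DISPATCH_V2 call");
  }
  if (call.find("AT_WRAP(") == std::string::npos) {
    problems.push_back("lambda is not wrapped in AT_WRAP()");
  }
  // Individual types such as kBFloat16 must not be wrapped in AT_EXPAND().
  static const std::regex expanded_scalar(R"(AT_EXPAND\(\s*(?:\w+::)*k[A-Z])");
  if (std::regex_search(call, expanded_scalar)) {
    problems.push_back("individual type wrapped in AT_EXPAND()");
  }
  return problems;
}
```

An empty result means none of these specific mistakes were detected; it does not prove the conversion is correct.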

Type group reference

Available type group macros (use with AT_EXPAND()):
cpp
AT_INTEGRAL_TYPES            // kByte, kChar, kInt, kLong, kShort
AT_FLOATING_TYPES            // kDouble, kFloat
AT_COMPLEX_TYPES             // kComplexDouble, kComplexFloat
AT_QINT_TYPES                // kQInt8, kQUInt8, kQInt32
AT_ALL_TYPES                 // INTEGRAL_TYPES + FLOATING_TYPES
AT_ALL_TYPES_AND_COMPLEX     // ALL_TYPES + COMPLEX_TYPES
AT_INTEGRAL_TYPES_V2         // INTEGRAL_TYPES + unsigned types
AT_BAREBONES_UNSIGNED_TYPES  // kUInt16, kUInt32, kUInt64
AT_FLOAT8_TYPES              // Float8 variants

Common patterns

Pattern: AT_DISPATCH_ALL_TYPES_AND2

cpp
// Before
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
  kernel<scalar_t>(data);
});

// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>(data);
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);

Pattern: AT_DISPATCH_FLOATING_TYPES_AND3

cpp
// Before
AT_DISPATCH_FLOATING_TYPES_AND3(kHalf, kBFloat16, kFloat8_e4m3fn,
    tensor.scalar_type(), "float_op", [&] {
  process<scalar_t>(tensor);
});

// After
AT_DISPATCH_V2(tensor.scalar_type(), "float_op", AT_WRAP([&] {
  process<scalar_t>(tensor);
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn);

Pattern: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2

cpp
// Before
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
    kComplexHalf, kHalf,
    self.scalar_type(),
    "complex_op",
    [&] {
      result = compute<scalar_t>(self);
    }
);

// After
AT_DISPATCH_V2(
    self.scalar_type(),
    "complex_op",
    AT_WRAP([&] {
      result = compute<scalar_t>(self);
    }),
    AT_EXPAND(AT_ALL_TYPES),
    AT_EXPAND(AT_COMPLEX_TYPES),
    kComplexHalf,
    kHalf
);

Edge cases

Case 1: No extra types (rare)

cpp
// Before
AT_DISPATCH_ALL_TYPES(dtype, "op", [&]() { kernel<scalar_t>(); });

// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES));

Case 2: Many individual types (AND4, AND5, etc.)

cpp
// Before
AT_DISPATCH_FLOATING_TYPES_AND4(kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2,
    dtype, "float8_op", [&]() { kernel<scalar_t>(); });

// After
AT_DISPATCH_V2(dtype, "float8_op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2);

Case 3: Lambda with no captures

cpp
// Before
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, dtype, "op", []() {
  static_kernel<scalar_t>();
});

// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([]() {
  static_kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBool);

Benefits of AT_DISPATCH_V2

  1. No arity in macro name: No need for different macros for AND2, AND3, AND4
  2. Composable type sets: Mix and match type groups with AT_EXPAND()
  3. Extensible: Easy to add more types without hitting macro limits
  4. Clearer: Type groups are explicit, not implicit in the macro name

Important notes

  • Keep #include <ATen/Dispatch.h> - other code may need it
  • AT_WRAP() is mandatory - it prevents comma parsing issues in the lambda
  • Type groups need AT_EXPAND(), individual types don't
  • The v2 API is in aten/src/ATen/Dispatch_v2.h - refer to it for full docs
  • See the header file for the Python script to regenerate the macro implementation

Workflow

When asked to convert AT_DISPATCH macros:
  1. Read the file to identify all AT_DISPATCH uses
  2. Add #include <ATen/Dispatch_v2.h> if not present
  3. For each dispatch macro:
    • Identify the pattern and extract components
    • Map the base type group
    • Extract individual types
    • Construct the AT_DISPATCH_V2 call
    • Apply with Edit tool
  4. Show the user the complete converted file
  5. Explain what was changed
Do NOT compile or test the code - focus on accurate conversion only.