add-uint-support
Add Unsigned Integer (uint) Support to Operators
This skill helps add support for unsigned integer types (uint16, uint32, uint64) to PyTorch operators by updating their AT_DISPATCH macros.
When to use this skill
Use this skill when:
- Adding uint16, uint32, or uint64 support to an operator
- User mentions "unsigned types", "uint support", "barebones unsigned types"
- Enabling support for kUInt16, kUInt32, kUInt64 in kernels
- Working with operator implementations that need expanded type coverage
Quick reference
Add unsigned types to existing dispatch:
```cpp
// Before
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES));

// After (method 1: add unsigned types explicitly)
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));

// After (method 2: use V2 integral types if AT_INTEGRAL_TYPES present)
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
```
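To see why extending the type list is enough, it can help to think of a dispatch macro as, roughly, a switch over the runtime dtype that picks a template instantiation of the kernel. The sketch below is a standalone toy model of that idea, not PyTorch's actual macro expansion; `ToyDtype`, `toy_kernel`, `toy_dispatch_before`, and `toy_dispatch_after` are hypothetical names invented for illustration.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in for a runtime dtype tag (not c10::ScalarType).
enum class ToyDtype { Int32, Float32, UInt16, UInt32, UInt64 };

// Toy kernel templated on scalar_t, like the kernels in the examples above.
// It just reports the element size so the chosen instantiation is observable.
template <typename scalar_t>
std::size_t toy_kernel() {
  return sizeof(scalar_t);
}

// Before: unsigned dtypes have no case, so they are rejected at runtime.
inline std::size_t toy_dispatch_before(ToyDtype dtype) {
  switch (dtype) {
    case ToyDtype::Int32:   return toy_kernel<std::int32_t>();
    case ToyDtype::Float32: return toy_kernel<float>();
    default:
      throw std::runtime_error("toy_op not implemented for this dtype");
  }
}

// After: three extra cases, one per barebones unsigned type. The kernel
// itself is unchanged; only the set of instantiations grows.
inline std::size_t toy_dispatch_after(ToyDtype dtype) {
  switch (dtype) {
    case ToyDtype::UInt16: return toy_kernel<std::uint16_t>();
    case ToyDtype::UInt32: return toy_kernel<std::uint32_t>();
    case ToyDtype::UInt64: return toy_kernel<std::uint64_t>();
    default:               return toy_dispatch_before(dtype);
  }
}
```

In this model, `toy_dispatch_before(ToyDtype::UInt32)` throws while `toy_dispatch_after(ToyDtype::UInt32)` runs the `std::uint32_t` instantiation, which mirrors the effect of adding the unsigned type group to the dispatch list.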
Type group reference
Unsigned type groups:
- AT_BAREBONES_UNSIGNED_TYPES: kUInt16, kUInt32, kUInt64
- AT_INTEGRAL_TYPES_V2: AT_INTEGRAL_TYPES + AT_BAREBONES_UNSIGNED_TYPES
Relationship:
```cpp
AT_INTEGRAL_TYPES            // kByte, kChar, kInt, kLong, kShort
AT_BAREBONES_UNSIGNED_TYPES  // kUInt16, kUInt32, kUInt64
AT_INTEGRAL_TYPES_V2         // INTEGRAL_TYPES + BAREBONES_UNSIGNED_TYPES
```
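The relationship between the groups can be checked with a small standalone model. This is only a sketch: it represents each group as a set of dtype names taken from the table above and does not include or expand the real ATen macros. The function names are hypothetical.

```cpp
#include <set>
#include <string>

using TypeGroup = std::set<std::string>;

// Members of AT_INTEGRAL_TYPES as listed above.
inline TypeGroup integral_types() {
  return {"kByte", "kChar", "kInt", "kLong", "kShort"};
}

// Members of AT_BAREBONES_UNSIGNED_TYPES as listed above.
inline TypeGroup barebones_unsigned_types() {
  return {"kUInt16", "kUInt32", "kUInt64"};
}

// AT_INTEGRAL_TYPES_V2 modeled as the union of the two groups.
inline TypeGroup integral_types_v2() {
  TypeGroup result = integral_types();
  const TypeGroup extra = barebones_unsigned_types();
  result.insert(extra.begin(), extra.end());
  return result;
}
```

The union has eight members, and the two source groups are disjoint, so substituting the V2 group for AT_INTEGRAL_TYPES adds exactly the three unsigned types and removes nothing.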
Instructions
Step 1: Determine if conversion to V2 is needed
Check if the file uses AT_DISPATCH_V2:
If using old AT_DISPATCH:
- First convert to AT_DISPATCH_V2 using the at-dispatch-v2 skill
- Then proceed with adding uint support
If already using AT_DISPATCH_V2:
- Proceed directly to Step 2
Step 2: Analyze the current dispatch macro
Identify what type groups are currently in use:
```cpp
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  // body
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
//  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//  Current type coverage
```
Common patterns:
- AT_EXPAND(AT_ALL_TYPES) → includes AT_INTEGRAL_TYPES + AT_FLOATING_TYPES
- AT_EXPAND(AT_INTEGRAL_TYPES) → kByte (uint8) plus the signed integer types (kChar, kShort, kInt, kLong); no uint16, uint32, or uint64
- AT_EXPAND(AT_FLOATING_TYPES) → floating point types
Step 3: Choose the uint addition method
Two approaches:
Method 1: Add AT_BAREBONES_UNSIGNED_TYPES explicitly
- Use when: You want to be explicit about adding uint support
- Add AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) to the type list
Method 2: Substitute AT_INTEGRAL_TYPES with AT_INTEGRAL_TYPES_V2
- Use when: The dispatch already uses AT_EXPAND(AT_INTEGRAL_TYPES)
- More concise: replaces one type group with its superset
- Only applicable if AT_INTEGRAL_TYPES is present
Step 4: Apply the transformation
Method 1 example:
```cpp
// Before
AT_DISPATCH_V2(
  dtype,
  "min_values_cuda",
  AT_WRAP([&]() {
    kernel_impl<scalar_t>(iter);
  }),
  AT_EXPAND(AT_ALL_TYPES),
  kBFloat16, kHalf, kBool
);

// After (add unsigned types)
AT_DISPATCH_V2(
  dtype,
  "min_values_cuda",
  AT_WRAP([&]() {
    kernel_impl<scalar_t>(iter);
  }),
  AT_EXPAND(AT_ALL_TYPES),
  AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
  kBFloat16, kHalf, kBool
);
```
Method 2 example:
```cpp
// Before
AT_DISPATCH_V2(
  dtype,
  "integral_op",
  AT_WRAP([&]() {
    kernel<scalar_t>();
  }),
  AT_EXPAND(AT_INTEGRAL_TYPES)
);

// After (substitute with V2)
AT_DISPATCH_V2(
  dtype,
  "integral_op",
  AT_WRAP([&]() {
    kernel<scalar_t>();
  }),
  AT_EXPAND(AT_INTEGRAL_TYPES_V2)
);
```
Step 5: Handle AT_ALL_TYPES vs individual type groups
If the dispatch uses AT_EXPAND(AT_ALL_TYPES):
- AT_ALL_TYPES = AT_INTEGRAL_TYPES + AT_FLOATING_TYPES
- To add uint: add AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) to the list
If the dispatch separately lists INTEGRAL and FLOATING:
```cpp
// Before
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)

// After (Method 2 preferred)
AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES)
```
Step 6: Verify all dispatch sites
Check the file for ALL dispatch macros that need uint support:
- Some operators have multiple dispatch sites (CPU, CUDA, different functions)
- Apply the transformation consistently across all sites
- Ensure each gets the same type coverage updates
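A quick way to avoid missing a site is to scan the file text for dispatch calls that mention neither unsigned group. The helper below is a rough heuristic sketch, not part of PyTorch or of this skill's tooling: `sites_missing_uint` is a hypothetical name, and it treats everything from one `AT_DISPATCH_V2(` to the next as a single site, so comments or unrelated code in between can produce false results.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Returns the byte offsets of AT_DISPATCH_V2 call sites whose text (up to the
// next call site or the end of the input) names neither unsigned type group.
inline std::vector<std::size_t> sites_missing_uint(const std::string& source) {
  const std::string macro = "AT_DISPATCH_V2(";
  std::vector<std::size_t> missing;
  std::size_t pos = source.find(macro);
  while (pos != std::string::npos) {
    const std::size_t next = source.find(macro, pos + macro.size());
    const std::string site = source.substr(
        pos, next == std::string::npos ? std::string::npos : next - pos);
    const bool has_uint =
        site.find("AT_BAREBONES_UNSIGNED_TYPES") != std::string::npos ||
        site.find("AT_INTEGRAL_TYPES_V2") != std::string::npos;
    if (!has_uint) {
      missing.push_back(pos);
    }
    pos = next;
  }
  return missing;
}
```

An empty result suggests every V2 dispatch in the file already names an unsigned group. Sites that intentionally stay floating-point only will still be reported and need to be reviewed by hand.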
Step 7: Validate the changes
Check that:
- AT_DISPATCH_V2 format is used (not old AT_DISPATCH)
- Unsigned types are added via one of the two methods
- All relevant dispatch sites in the file are updated
- Type groups use AT_EXPAND()
- Arguments are properly formatted and comma-separated
Common patterns
Pattern 1: AT_ALL_TYPES + extras
```cpp
// Before
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);

// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
```
Pattern 2: Separate INTEGRAL + FLOATING
```cpp
// Before
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES));

// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
```
Pattern 3: Old dispatch needs conversion first
```cpp
// Before (needs v2 conversion first)
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
  kernel<scalar_t>();
});

// After v2 conversion
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);

// After adding uint support
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
```
Multiple dispatch sites example
For a file with multiple functions:
```cpp
void min_values_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() {
    impl<scalar_t>(iter);
  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
  //                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  //                           Added uint support
}

void min_launch_kernel(TensorIterator &iter) {
  AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() {
    gpu_reduce_kernel<scalar_t>(iter);
  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
  //                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  //                           Added uint support here too
}
```
Decision tree
Use this decision tree to determine the approach:
```
Is the file using AT_DISPATCH_V2?
├─ No → Use at-dispatch-v2 skill first, then continue
└─ Yes
   └─ Does it use AT_EXPAND(AT_INTEGRAL_TYPES)?
      ├─ Yes → Replace with AT_EXPAND(AT_INTEGRAL_TYPES_V2)
      └─ No → Add AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) to type list
```
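The same decision tree can be written as a small function over the text of one dispatch call. This is a minimal sketch that relies on plain substring matching; `UintAction` and `choose_uint_action` are hypothetical names, and the function does not cover the edge cases discussed later (floating-point-only operators, or dispatches that already have uint support).

```cpp
#include <string>

// The three outcomes of the decision tree.
enum class UintAction {
  ConvertToV2First,       // old AT_DISPATCH_* macro: convert before anything else
  ReplaceIntegralWithV2,  // Method 2
  AddBarebonesUnsigned    // Method 1
};

inline UintAction choose_uint_action(const std::string& dispatch_text) {
  // Branch 1: is the call already in AT_DISPATCH_V2 form?
  if (dispatch_text.find("AT_DISPATCH_V2(") == std::string::npos) {
    return UintAction::ConvertToV2First;
  }
  // Branch 2: the closing parenthesis in the pattern keeps
  // AT_EXPAND(AT_INTEGRAL_TYPES_V2) from matching here.
  if (dispatch_text.find("AT_EXPAND(AT_INTEGRAL_TYPES)") != std::string::npos) {
    return UintAction::ReplaceIntegralWithV2;
  }
  return UintAction::AddBarebonesUnsigned;
}
```

For example, a call whose type list is `AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16` maps to `AddBarebonesUnsigned`, while one listing `AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)` maps to `ReplaceIntegralWithV2`.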
Edge cases
Case 1: Dispatch with only floating types
If the operator only supports floating point types, don't add uint support:
```cpp
// Leave as-is - floating point only operator
AT_DISPATCH_V2(dtype, "float_op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf);
```
Case 2: Complex types present
Unsigned types work alongside complex types:
```cpp
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES),
    AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
    AT_EXPAND(AT_COMPLEX_TYPES),
    kHalf, kBFloat16);
```
Case 3: Already has uint support
Check if uint types are already present:
- If AT_INTEGRAL_TYPES_V2 is used → already has uint support
- If AT_BAREBONES_UNSIGNED_TYPES is already in list → already has uint support
- Skip the file if uint support is already present
Workflow
When asked to add uint support:
- Read the target file
- Check if using AT_DISPATCH_V2:
  - If not → use at-dispatch-v2 skill first
- Identify all dispatch macro sites
- For each dispatch:
  - Analyze current type groups
  - Choose method (add BAREBONES_UNSIGNED or upgrade to V2)
  - Apply transformation with Edit tool
- Show the user the changes
- Explain what was modified
Important notes
- Always check if v2 conversion is needed first
- Apply changes consistently across all dispatch sites in the file
- Method 2 (AT_INTEGRAL_TYPES_V2) is cleaner when applicable
- Method 1 (explicit AT_BAREBONES_UNSIGNED_TYPES) is more explicit
- Unsigned types are: kUInt16, kUInt32, kUInt64 (not kByte which is uint8)
- Some operators may not semantically support unsigned types - use judgment
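As one illustration of the last note, an operator such as negation compiles for unsigned types but wraps around modulo 2^N instead of producing a negative value. The sketch below is standalone C++ with a hypothetical `neg_kernel`; it is not taken from any PyTorch kernel and only shows the arithmetic behavior that makes the judgment call necessary.

```cpp
#include <cstdint>

// Hypothetical elementwise negation, templated on scalar_t like a real kernel.
template <typename scalar_t>
scalar_t neg_kernel(scalar_t x) {
  // For unsigned scalar_t the result is reduced modulo 2^N, where N is the
  // bit width of scalar_t, so it is never negative.
  return static_cast<scalar_t>(0 - x);
}
```

With a signed 32-bit input of 5 the result is -5, but `neg_kernel<std::uint16_t>(5)` yields 65531 and `neg_kernel<std::uint32_t>(5)` yields 4294967291. Whether that is acceptable depends on the operator's intended semantics.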
Testing
After adding uint support, the operator should accept uint16, uint32, and uint64 tensors. The user is responsible for functional testing.
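When planning functional tests, values above the signed maximum of the same width are worth including, since they are the inputs that would be misordered if a kernel accidentally treated the data as signed. The following is a standalone sketch of that kind of check using a hypothetical `toy_min` reduction; it does not call into PyTorch, and a real test would exercise the operator through tensors of the corresponding dtype.

```cpp
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical min reduction templated on scalar_t.
template <typename scalar_t>
scalar_t toy_min(const std::vector<scalar_t>& values) {
  // Start from the largest representable value, the identity for min.
  // An empty input therefore returns that maximum.
  scalar_t result = std::numeric_limits<scalar_t>::max();
  for (scalar_t v : values) {
    result = std::min(result, v);
  }
  return result;
}
```

For instance, the `std::uint16_t` inputs 40000 and 65535 both exceed the int16 maximum of 32767, and the correct minimum is 40000. The same pattern applies to `std::uint32_t` values above 2147483647 and to `std::uint64_t` values above 9223372036854775807.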