add-uint-support

Add Unsigned Integer (uint) Support to Operators

This skill helps add support for unsigned integer types (uint16, uint32, uint64) to PyTorch operators by updating their AT_DISPATCH macros.

When to use this skill

Use this skill when:
  • Adding uint16, uint32, or uint64 support to an operator
  • User mentions "unsigned types", "uint support", "barebones unsigned types"
  • Enabling support for kUInt16, kUInt32, kUInt64 in kernels
  • Working with operator implementations that need expanded type coverage

Quick reference

Add unsigned types to existing dispatch:
cpp
// Before
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES));

// After (method 1: add unsigned types explicitly)
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));

// After (method 2: when AT_INTEGRAL_TYPES is listed, or AT_ALL_TYPES is written as integral + floating, use AT_INTEGRAL_TYPES_V2)
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
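To make the effect of the extra type group concrete, here is a minimal sketch that does not use ATen at all. `ToyDtype`, `dispatch`, `types_before`, and `types_after` are hypothetical stand-ins invented for illustration; the real macro also binds `scalar_t` and runs the kernel, which this toy skips. It only mirrors the idea that a dtype missing from the list is rejected at run time.

```cpp
#include <algorithm>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for the real scalar type enum (not the ATen one).
enum class ToyDtype { Int32, Float32, UInt16, UInt32, UInt64 };

// Toy dispatcher: accepts a dtype only if it appears in the allowed list,
// loosely mirroring how a dispatch macro rejects unlisted dtypes.
inline std::string dispatch(ToyDtype dtype, const std::vector<ToyDtype>& allowed) {
  if (std::find(allowed.begin(), allowed.end(), dtype) == allowed.end()) {
    throw std::runtime_error("op not implemented for this dtype");
  }
  return "kernel ran";
}

// "Before": no barebones unsigned types in the list.
inline std::vector<ToyDtype> types_before() {
  return {ToyDtype::Int32, ToyDtype::Float32};
}

// "After": the same list with the three unsigned types appended.
inline std::vector<ToyDtype> types_after() {
  std::vector<ToyDtype> types = types_before();
  types.insert(types.end(),
               {ToyDtype::UInt16, ToyDtype::UInt32, ToyDtype::UInt64});
  return types;
}
```

Calling `dispatch(ToyDtype::UInt16, types_before())` throws, while the same call with `types_after()` succeeds, which is the behavioral change the "After" variants aim for.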

Type group reference

Unsigned type groups:
  • AT_BAREBONES_UNSIGNED_TYPES: kUInt16, kUInt32, kUInt64
  • AT_INTEGRAL_TYPES_V2: AT_INTEGRAL_TYPES + AT_BAREBONES_UNSIGNED_TYPES
Relationship:
cpp
AT_INTEGRAL_TYPES          // kByte, kChar, kInt, kLong, kShort
AT_BAREBONES_UNSIGNED_TYPES  // kUInt16, kUInt32, kUInt64
AT_INTEGRAL_TYPES_V2       // INTEGRAL_TYPES + BAREBONES_UNSIGNED_TYPES
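The relationship can be modeled in a few lines of plain C++. This is only an illustrative sketch: the groups are represented as lists of strings rather than the real macros, and the function names are hypothetical.

```cpp
#include <string>
#include <vector>

// Group contents as listed in the reference above, modeled as plain strings
// for illustration. These are not the real ATen macros.
inline std::vector<std::string> integral_types() {
  return {"kByte", "kChar", "kInt", "kLong", "kShort"};
}

inline std::vector<std::string> barebones_unsigned_types() {
  return {"kUInt16", "kUInt32", "kUInt64"};
}

// AT_INTEGRAL_TYPES_V2 is modeled as the concatenation of the two groups.
inline std::vector<std::string> integral_types_v2() {
  std::vector<std::string> all = integral_types();
  const std::vector<std::string> extra = barebones_unsigned_types();
  all.insert(all.end(), extra.begin(), extra.end());
  return all;
}
```

Under this model the V2 group has eight entries: the five original integral types followed by the three barebones unsigned types.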

Instructions

Step 1: Determine if conversion to V2 is needed

Check whether the file already uses AT_DISPATCH_V2.
If it uses an older AT_DISPATCH_* macro (for example AT_DISPATCH_ALL_TYPES_AND2):
  • First convert to AT_DISPATCH_V2 using the at-dispatch-v2 skill
  • Then proceed with adding uint support
If already using AT_DISPATCH_V2:
  • Proceed directly to Step 2
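The check in this step can be approximated with a plain substring scan. The helper below is a hypothetical sketch, not part of PyTorch or of the at-dispatch-v2 skill; it assumes any `AT_DISPATCH_*` name other than `AT_DISPATCH_V2` is an old-style macro, and it does not skip comments or string literals.

```cpp
#include <string>

// Returns true when the source contains an AT_DISPATCH_* macro that is not
// AT_DISPATCH_V2, i.e. an old-style dispatch that needs conversion first.
// Plain substring scan: comments and string literals are not excluded.
inline bool needs_v2_conversion(const std::string& source) {
  const std::string prefix = "AT_DISPATCH_";
  const std::string v2 = "AT_DISPATCH_V2";
  std::string::size_type pos = source.find(prefix);
  while (pos != std::string::npos) {
    if (source.compare(pos, v2.size(), v2) != 0) {
      return true;  // some other AT_DISPATCH_* macro was found
    }
    pos = source.find(prefix, pos + prefix.size());
  }
  return false;
}
```

A file for which this returns `true` would go through the V2 conversion before any uint changes are made.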

Step 2: Analyze the current dispatch macro

Identify what type groups are currently in use:
cpp
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  // body
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
//  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//  Current type coverage
Common patterns:
  • AT_EXPAND(AT_ALL_TYPES) → includes AT_INTEGRAL_TYPES + AT_FLOATING_TYPES
  • AT_EXPAND(AT_INTEGRAL_TYPES) → kByte (uint8) plus the signed integer types; no uint16, uint32, or uint64
  • AT_EXPAND(AT_FLOATING_TYPES) → floating point types

Step 3: Choose the uint addition method

Two approaches:
Method 1: Add AT_BAREBONES_UNSIGNED_TYPES explicitly
  • Use when: You want to be explicit about adding uint support
  • Add AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) to the type list
Method 2: Substitute AT_INTEGRAL_TYPES with AT_INTEGRAL_TYPES_V2
  • Use when: The dispatch already uses AT_EXPAND(AT_INTEGRAL_TYPES)
  • More concise: replaces one type group with its superset
  • Only applicable if AT_INTEGRAL_TYPES is present

Step 4: Apply the transformation

Method 1 example:
cpp
// Before
AT_DISPATCH_V2(
    dtype,
    "min_values_cuda",
    AT_WRAP([&]() {
      kernel_impl<scalar_t>(iter);
    }),
    AT_EXPAND(AT_ALL_TYPES),
    kBFloat16, kHalf, kBool
);

// After (add unsigned types)
AT_DISPATCH_V2(
    dtype,
    "min_values_cuda",
    AT_WRAP([&]() {
      kernel_impl<scalar_t>(iter);
    }),
    AT_EXPAND(AT_ALL_TYPES),
    AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
    kBFloat16, kHalf, kBool
);
Method 2 example:
cpp
// Before
AT_DISPATCH_V2(
    dtype,
    "integral_op",
    AT_WRAP([&]() {
      kernel<scalar_t>();
    }),
    AT_EXPAND(AT_INTEGRAL_TYPES)
);

// After (substitute with V2)
AT_DISPATCH_V2(
    dtype,
    "integral_op",
    AT_WRAP([&]() {
      kernel<scalar_t>();
    }),
    AT_EXPAND(AT_INTEGRAL_TYPES_V2)
);

Step 5: Handle AT_ALL_TYPES vs individual type groups

If the dispatch uses AT_EXPAND(AT_ALL_TYPES):
  • AT_ALL_TYPES = AT_INTEGRAL_TYPES + AT_FLOATING_TYPES
  • To add uint: add AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) to the list
If the dispatch separately lists INTEGRAL and FLOATING:
cpp
// Before
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)

// After (Method 2 preferred)
AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES)
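The two cases in this step can be expressed as a small text rewrite over the type-list portion of a dispatch call. The function below is a hypothetical sketch using only `std::string`; it is not a tool that ships with PyTorch, and it does not check whether uint support is already present.

```cpp
#include <string>

// Adds uint coverage to the type-list part of an AT_DISPATCH_V2 call.
// Prefers Method 2 when AT_EXPAND(AT_INTEGRAL_TYPES) is listed; otherwise
// falls back to Method 1 by inserting the unsigned group after AT_ALL_TYPES.
// Returns the input unchanged if neither group is found.
inline std::string add_uint_support(std::string type_list) {
  const std::string integral = "AT_EXPAND(AT_INTEGRAL_TYPES)";
  const std::string all_types = "AT_EXPAND(AT_ALL_TYPES)";

  std::string::size_type pos = type_list.find(integral);
  if (pos != std::string::npos) {
    // Method 2: swap the group for its superset.
    type_list.replace(pos, integral.size(), "AT_EXPAND(AT_INTEGRAL_TYPES_V2)");
    return type_list;
  }
  pos = type_list.find(all_types);
  if (pos != std::string::npos) {
    // Method 1: append the unsigned group right after AT_ALL_TYPES.
    type_list.insert(pos + all_types.size(),
                     ", AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)");
  }
  return type_list;
}
```

Because the search string for Method 2 includes the closing parenthesis, it does not match a list that already contains `AT_EXPAND(AT_INTEGRAL_TYPES_V2)`.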

Step 6: Verify all dispatch sites

Check the file for ALL dispatch macros that need uint support:
  • Some operators have multiple dispatch sites (CPU, CUDA, different functions)
  • Apply the transformation consistently across all sites
  • Ensure each gets the same type coverage updates
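A rough inventory of dispatch sites helps confirm that none was missed. The helper below is a hypothetical sketch that counts textual occurrences of `AT_DISPATCH_V2(`; it would also count occurrences inside comments, so the result is a starting point for review rather than a guarantee.

```cpp
#include <string>

// Counts occurrences of "AT_DISPATCH_V2(" in a source string, as a rough
// inventory of dispatch sites to review.
inline int count_dispatch_sites(const std::string& source) {
  const std::string needle = "AT_DISPATCH_V2(";
  int count = 0;
  std::string::size_type pos = source.find(needle);
  while (pos != std::string::npos) {
    ++count;
    pos = source.find(needle, pos + needle.size());
  }
  return count;
}
```

Comparing this count with the number of sites actually edited is one way to catch a forgotten dispatch in the same file.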

Step 7: Validate the changes

Check that:
  • AT_DISPATCH_V2 format is used (not old AT_DISPATCH)
  • Unsigned types are added via one of the two methods
  • All relevant dispatch sites in the file are updated
  • Type groups are wrapped in AT_EXPAND()
  • Arguments are properly formatted and comma-separated

Common patterns

Pattern 1: AT_ALL_TYPES + extras

cpp
// Before
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);

// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);

Pattern 2: Separate INTEGRAL + FLOATING

cpp
// Before
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES));

// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));

Pattern 3: Old dispatch needs conversion first

cpp
// Before (needs v2 conversion first)
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
  kernel<scalar_t>();
});

// After v2 conversion
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);

// After adding uint support
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);

Multiple dispatch sites example

For a file with multiple functions:
cpp
void min_values_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() {
    impl<scalar_t>(iter);
  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
  //                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  //                           Added uint support
}

void min_launch_kernel(TensorIterator &iter) {
  AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() {
    gpu_reduce_kernel<scalar_t>(iter);
  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
  //                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  //                           Added uint support here too
}

Decision tree

Use this decision tree to determine the approach:
Is the file using AT_DISPATCH_V2?
├─ No → Use at-dispatch-v2 skill first, then continue
└─ Yes
   └─ Does it use AT_EXPAND(AT_INTEGRAL_TYPES)?
      ├─ Yes → Replace with AT_EXPAND(AT_INTEGRAL_TYPES_V2)
      └─ No → Add AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) to type list
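The same tree can be written as a tiny function. The enum and function names here are hypothetical, chosen only to mirror the three leaves of the tree.

```cpp
// The three outcomes of the decision tree (hypothetical names).
enum class Action {
  ConvertToV2First,            // file still uses an old AT_DISPATCH_* macro
  ReplaceWithIntegralTypesV2,  // Method 2
  AddBarebonesUnsignedTypes,   // Method 1
};

// Mirrors the decision tree: first the macro generation, then whether
// AT_EXPAND(AT_INTEGRAL_TYPES) appears in the type list.
inline Action choose_action(bool uses_dispatch_v2, bool has_integral_types) {
  if (!uses_dispatch_v2) {
    return Action::ConvertToV2First;
  }
  return has_integral_types ? Action::ReplaceWithIntegralTypesV2
                            : Action::AddBarebonesUnsignedTypes;
}
```

Note that the second argument is ignored when the file has not been converted yet, matching the tree's ordering of the two questions.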

Edge cases

Case 1: Dispatch with only floating types

If the operator only supports floating point types, don't add uint support:
cpp
// Leave as-is - floating point only operator
AT_DISPATCH_V2(dtype, "float_op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf);

Case 2: Complex types present

Unsigned types work alongside complex types:
cpp
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES),
    AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
    AT_EXPAND(AT_COMPLEX_TYPES),
    kHalf, kBFloat16);

Case 3: Already has uint support

Check if uint types are already present:
  • If AT_INTEGRAL_TYPES_V2 is used → already has uint support
  • If AT_BAREBONES_UNSIGNED_TYPES is already in the list → already has uint support
  • Skip the file if uint support is already present
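This check amounts to looking for either group name in the type list. A minimal sketch, with a hypothetical function name and a plain substring search that does not parse the macro:

```cpp
#include <string>

// Returns true if the type list already mentions one of the two groups
// that bring in kUInt16, kUInt32, and kUInt64.
inline bool already_has_uint(const std::string& type_list) {
  return type_list.find("AT_INTEGRAL_TYPES_V2") != std::string::npos ||
         type_list.find("AT_BAREBONES_UNSIGNED_TYPES") != std::string::npos;
}
```

A list that names `kUInt16`, `kUInt32`, or `kUInt64` individually would not be detected by this sketch and would need a manual look.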

Workflow

When asked to add uint support:
  1. Read the target file
  2. Check if using AT_DISPATCH_V2:
    • If not → use at-dispatch-v2 skill first
  3. Identify all dispatch macro sites
  4. For each dispatch:
    • Analyze current type groups
    • Choose method (add AT_BAREBONES_UNSIGNED_TYPES or switch to AT_INTEGRAL_TYPES_V2)
    • Apply transformation with Edit tool
  5. Show the user the changes
  6. Explain what was modified

Important notes

  • Always check if v2 conversion is needed first
  • Apply changes consistently across all dispatch sites in the file
  • Method 2 (AT_INTEGRAL_TYPES_V2) is cleaner when applicable
  • Method 1 (explicit AT_BAREBONES_UNSIGNED_TYPES) is more explicit
  • The barebones unsigned types are kUInt16, kUInt32, and kUInt64; kByte (uint8) is not one of them because AT_INTEGRAL_TYPES already covers it
  • Some operators may not semantically support unsigned types - use judgment
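As one illustration of the last point, consider an operator whose kernel negates its input. The sketch below uses standard C++ fixed-width integers as stand-ins for `scalar_t`; `negate` is a hypothetical kernel body, not a PyTorch function.

```cpp
#include <cstdint>

// Stand-in for a kernel body that negates its input. For unsigned types the
// result wraps around modulo 2^N instead of becoming negative.
template <typename scalar_t>
scalar_t negate(scalar_t x) {
  return static_cast<scalar_t>(-x);
}
```

With this definition `negate<std::int32_t>(1)` is `-1`, while `negate<std::uint32_t>(1)` is `4294967295`. The code compiles for every type, so the dispatch change alone would not flag the problem; whether such an operator should accept unsigned inputs is the judgment call.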

Testing

After adding uint support, the operator should accept uint16, uint32, and uint64 tensors. The user is responsible for functional testing.
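A functional check usually runs the same operator once per newly enabled type. The sketch below shows that shape in plain C++ without PyTorch: `min_value` is a hypothetical stand-in for a reduction kernel, and the fixed-width integer types stand in for uint16, uint32, and uint64 tensors.

```cpp
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical stand-in for a reduction kernel such as min_values.
// Assumes a non-empty input.
template <typename scalar_t>
scalar_t min_value(const std::vector<scalar_t>& values) {
  return *std::min_element(values.begin(), values.end());
}

// Runs the stand-in kernel on a small input that includes the type's maximum,
// so a signed misinterpretation of the bits would pick the wrong minimum.
template <typename scalar_t>
bool check_min_for_type() {
  const scalar_t largest = std::numeric_limits<scalar_t>::max();
  const std::vector<scalar_t> values = {largest, 3, 7};
  return min_value(values) == static_cast<scalar_t>(3);
}
```

Calling `check_min_for_type<std::uint16_t>()`, `check_min_for_type<std::uint32_t>()`, and `check_min_for_type<std::uint64_t>()` covers the three types; a real test would do the equivalent against the actual operator.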