Add unsigned integer (uint) type support to PyTorch operators by updating AT_DISPATCH macros. Use this skill when adding uint16, uint32, or uint64 support to operators or kernels, or when the user mentions enabling unsigned types, barebones unsigned types, or uint support.

Install: npx skill4agent add pytorch/pytorch add-uint-support

// Before
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES));
// After (method 1: add unsigned types explicitly)
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
// After (method 2: use AT_INTEGRAL_TYPES_V2 in place of the integral group; AT_ALL_TYPES = AT_INTEGRAL_TYPES + AT_FLOATING_TYPES)
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
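}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));

Type groups involved (`AT_BAREBONES_UNSIGNED_TYPES` adds the unsigned types; `AT_INTEGRAL_TYPES_V2` bundles them with the existing integral types):
AT_INTEGRAL_TYPES            // kByte, kChar, kInt, kLong, kShort
AT_BAREBONES_UNSIGNED_TYPES  // kUInt16, kUInt32, kUInt64 (uint8 is kByte, already in AT_INTEGRAL_TYPES)
AT_INTEGRAL_TYPES_V2         // INTEGRAL_TYPES + BAREBONES_UNSIGNED_TYPES

First, locate the type list of the existing dispatch:
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {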
// body
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
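    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    Current type coverage

Then pick the method:
- If the list contains `AT_EXPAND(AT_ALL_TYPES)` (which covers `AT_EXPAND(AT_INTEGRAL_TYPES)` and `AT_EXPAND(AT_FLOATING_TYPES)`), append `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)`.
- If the list contains `AT_EXPAND(AT_INTEGRAL_TYPES)`, replace it with `AT_EXPAND(AT_INTEGRAL_TYPES_V2)`.

// Before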
AT_DISPATCH_V2(
dtype,
"min_values_cuda",
AT_WRAP([&]() {
kernel_impl<scalar_t>(iter);
}),
AT_EXPAND(AT_ALL_TYPES),
kBFloat16, kHalf, kBool
);
// After (add unsigned types)
AT_DISPATCH_V2(
dtype,
"min_values_cuda",
AT_WRAP([&]() {
kernel_impl<scalar_t>(iter);
}),
AT_EXPAND(AT_ALL_TYPES),
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
kBFloat16, kHalf, kBool
);

// Before
AT_DISPATCH_V2(
dtype,
"integral_op",
AT_WRAP([&]() {
kernel<scalar_t>();
}),
AT_EXPAND(AT_INTEGRAL_TYPES)
);
// After (substitute with V2)
AT_DISPATCH_V2(
dtype,
"integral_op",
AT_WRAP([&]() {
kernel<scalar_t>();
}),
AT_EXPAND(AT_INTEGRAL_TYPES_V2)
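);

If the list uses `AT_EXPAND(AT_ALL_TYPES)`, remember that `AT_ALL_TYPES` already combines `AT_INTEGRAL_TYPES` and `AT_FLOATING_TYPES`, so append `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)` to it (method 1). If the integral and floating groups are listed separately, prefer method 2:

// Before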
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
// After (Method 2 preferred)
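AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES)

Always wrap type-group macros in `AT_EXPAND()`. Individual scalar types such as kHalf and kBFloat16 are listed directly.

// Before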
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);

// Before
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));

// Before (needs v2 conversion first)
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
kernel<scalar_t>();
});
// After v2 conversion
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
// After adding uint support
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
  kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);

// Example: the CUDA min_values reduction after adding uint support.
// The only change is to append AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES); every
// type the dispatch already listed is kept. The pre-change "min_values_cuda"
// dispatch lists kBFloat16, kHalf and kBool. Dropping kBool here would
// silently remove Bool dispatch while adding uint support.
void min_values_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() {
    impl<scalar_t>(iter);
  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
  //                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  //                           Added uint support
}
// Example: uint support added to the CUDA min reduction launcher by
// appending AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) to its dispatch list.
// Arguments are written one per line so the added type group stands out.
void min_launch_kernel(TensorIterator &iter) {
  AT_DISPATCH_V2(
      iter.input_dtype(),
      "min_cuda",
      AT_WRAP([&]() {
        gpu_reduce_kernel<scalar_t>(iter);
      }),
      AT_EXPAND(AT_ALL_TYPES),
      AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),  // Added uint support here too
      kBFloat16,
      kHalf);
}

Is the file using AT_DISPATCH_V2?
├─ No → Use at-dispatch-v2 skill first, then continue
└─ Yes
└─ Does it use AT_EXPAND(AT_INTEGRAL_TYPES)?
├─ Yes → Replace with AT_EXPAND(AT_INTEGRAL_TYPES_V2)
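└─ No → Add AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) to type list

If the type list already contains `AT_INTEGRAL_TYPES_V2` or `AT_BAREBONES_UNSIGNED_TYPES`, uint is already covered. Adding the group again lists the same dtypes twice, which produces duplicate `case` labels in the generated switch and fails to compile. Operators that dispatch only floating-point types need no change:

// Leave as-is - floating point only operator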
AT_DISPATCH_V2(dtype, "float_op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf);

// Multiple type groups: append the unsigned group and keep the other groups unchanged
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES),
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
AT_EXPAND(AT_COMPLEX_TYPES),
kHalf, kBFloat16);

`AT_INTEGRAL_TYPES_V2` and `AT_BAREBONES_UNSIGNED_TYPES` are defined alongside `AT_DISPATCH_V2` in ATen's `Dispatch_v2.h`.