# Metal Kernel Writing Guide
This skill guides you through implementing Metal kernels for PyTorch operators on Apple Silicon.
**Important:** The goal of this skill is to use native Metal capabilities via the `c10/metal/` infrastructure, NOT MPSGraph. Native Metal kernels provide better control, performance, and maintainability.
## Overview
There are two workflows covered by this skill:

- **Adding new MPS support** - Implementing a new operator from scratch
- **Migrating from MPSGraph** - Converting existing MPSGraph-based operators to native Metal

Both workflows involve:

1. Update dispatch in `aten/src/ATen/native/native_functions.yaml`
2. Write the Metal kernel in `aten/src/ATen/native/mps/kernels/`
3. Implement the host-side stub in `aten/src/ATen/native/mps/operations/`
## Step 1: Update native_functions.yaml

Location: `aten/src/ATen/native/native_functions.yaml`

### For New Operators

Find the operator entry and add MPS dispatch:
```yaml
# Simple MPS-specific implementation
- func: my_op(Tensor self) -> Tensor
  dispatch:
    CPU: my_op_cpu
    CUDA: my_op_cuda
    MPS: my_op_mps
```
```yaml
# Shared implementation across devices (preferred for structured kernels)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA, MPS: my_op_out

# Structured kernel (preferred for new ops)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: my_op_out
```
### For Migrating from MPSGraph

When migrating an existing operator from MPSGraph to native Metal, consolidate the dispatch entry:
```yaml
# BEFORE (MPSGraph-based, separate dispatch)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: atan2_out
    MPS: atan2_out_mps  # Separate MPS implementation

# AFTER (native Metal, shared dispatch via stub)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: atan2_out  # MPS now uses the same stub mechanism
```
**Key change:** Replace `MPS: my_op_out_mps` by adding `MPS` to the shared dispatch line (e.g., `CPU, CUDA, MPS: my_op_out`).

**Dispatch naming conventions:**

- `MPS: function_name_mps` - MPS-specific implementation (old MPSGraph pattern)
- `CPU, CUDA, MPS: function_name` - Shared stub implementation (native Metal pattern)
## Step 2: Implement Metal Kernel
Location: `aten/src/ATen/native/mps/kernels/`

### Unary Kernel Pattern
```metal
// MyKernel.metal
#include <c10/metal/indexing.h>
#include <c10/metal/utils.h>
#include <metal_stdlib>

using namespace metal;
using namespace c10::metal;

// Define operation functor
struct my_op_functor {
  template <typename T>
  inline T operator()(const T x) {
    return /* your operation */;
  }
};

// Register for supported types
REGISTER_UNARY_OP(my_op, float, float);
REGISTER_UNARY_OP(my_op, half, half);
REGISTER_UNARY_OP(my_op, bfloat, bfloat);
```
### Binary Kernel Pattern
```metal
struct my_binary_functor {
  template <typename T>
  inline T operator()(const T a, const T b) {
    return /* your operation */;
  }
};

REGISTER_BINARY_OP(my_binary, float, float);
REGISTER_BINARY_OP(my_binary, half, half);
```
### Binary Kernel Type Registration Macros
For binary operations, use the convenience macros defined in `BinaryKernel.metal`:
```metal
// Floating-point types only (float, half, bfloat)
REGISTER_FLOAT_BINARY_OP(my_op);

// Integral types with float output (for math ops like atan2, copysign)
// Registers: long->float, int->float, short->float, uchar->float, char->float, bool->float
REGISTER_INT2FLOAT_BINARY_OP(my_op);

// Integral types with same-type output (for bitwise/logical ops)
// Registers: long, int, short, uchar, char, bool
REGISTER_INTEGER_BINARY_OP(my_op);

// Floating-point with opmath precision (for ops needing higher precision)
REGISTER_OPMATH_FLOAT_BINARY_OP(my_op);
```

Common patterns:
- Math functions (atan2, copysign, logaddexp): Use both `REGISTER_FLOAT_BINARY_OP` and `REGISTER_INT2FLOAT_BINARY_OP`
- Comparison/logical ops (maximum, minimum): Use both `REGISTER_FLOAT_BINARY_OP` and `REGISTER_INTEGER_BINARY_OP`
- Arithmetic ops (add, sub, mul): Use both `REGISTER_FLOAT_BINARY_OP` and `REGISTER_INTEGER_BINARY_OP`
Example for atan2 (supports both float and int inputs):

```metal
struct atan2_functor {
  template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
  inline T operator()(const T a, const T b) {
    return static_cast<T>(precise::atan2(float(a), float(b)));
  }
  template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
  inline float operator()(const T a, const T b) {
    return precise::atan2(float(a), float(b));
  }
};

REGISTER_FLOAT_BINARY_OP(atan2);
REGISTER_INT2FLOAT_BINARY_OP(atan2);
```
### With Scalar Parameter
```metal
struct my_alpha_functor {
  template <typename T>
  inline T operator()(const T a, const T b, const T alpha) {
    return a + c10::metal::mul(alpha, b);
  }
};

REGISTER_UNARY_ALPHA_OP(my_alpha, float, float, float);
REGISTER_UNARY_ALPHA_OP(my_alpha, half, half, half);
```
### Type-Specialized Functor
```metal
struct special_functor {
  // Floating point types
  template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
  inline T operator()(const T x) {
    return precise::exp(x); // Use precise math
  }
  // Integral types
  template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
  inline float operator()(const T x) {
    return precise::exp(float(x));
  }
  // Complex types (float2 for cfloat, half2 for chalf)
  template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
  inline T operator()(const T x) {
    // x.x = real, x.y = imaginary
    return T(/* real */, /* imag */);
  }
};
```

Note on complex types: Complex numbers in Metal are represented as vector types:

- `c10::complex<float>` maps to `float2` (x = real, y = imaginary)
- `c10::complex<half>` maps to `half2`

Use `is_complex_v<T>` to specialize for complex types in functors.
### Available c10/metal Utilities
`utils.h`:

- `opmath_t<T>` - Operation math type (half->float)
- `accum_t<T>` - Accumulation type for reductions
- `max()`, `min()` with NaN propagation

`special_math.h`:

- `precise::exp()`, `precise::log()`, `precise::sqrt()`
- `precise::sin()`, `precise::cos()`, `precise::tan()`
- `erf()`, `erfc()`, `erfinv()`

`indexing.h`:

- `REGISTER_UNARY_OP(name, in_type, out_type)`
- `REGISTER_BINARY_OP(name, in_type, out_type)`
- `REGISTER_UNARY_ALPHA_OP(name, in_type, alpha_type, out_type)`
## Step 3: Implement Host-Side Stub

Location: `aten/src/ATen/native/mps/operations/`

Choose or create an appropriate file based on operation type:

- `UnaryKernel.mm` - Single-input operations via stub dispatch
- `BinaryKernel.mm` - Two-input operations via stub dispatch
- `UnaryOps.mm` / `BinaryOps.mm` - Legacy MPSGraph implementations (for reference)
- `ReduceOps.mm` - Reductions (sum, mean, max, etc.)
- Create a new file for distinct operation categories
### Stub Registration Pattern (Preferred for Native Metal)
For structured kernels that use the TensorIterator pattern:

```objc
// In BinaryKernel.mm (or appropriate file)
static void my_op_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_binary_kernel(iter, "my_op"); // "my_op" matches the functor name in .metal
}

// Register the MPS stub - this connects to the dispatch system
REGISTER_DISPATCH(my_op_stub, &my_op_mps_kernel)
```

For unary operations:

```objc
static void my_unary_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_unary_kernel(iter, "my_unary");
}
REGISTER_DISPATCH(my_unary_stub, &my_unary_mps_kernel)
```
### Migration: Removing Old MPSGraph Implementation
When migrating from MPSGraph, also remove the old implementation:

1. **Remove from BinaryOps.mm (or UnaryOps.mm):**
   - Delete the `TORCH_IMPL_FUNC(my_op_out_mps)` implementation
   - Remove the corresponding `#include <ATen/ops/my_op_native.h>` header
   - Delete the …
2. **Add to BinaryKernel.mm (or UnaryKernel.mm):**
   - Add the static kernel function
   - Add the `REGISTER_DISPATCH` call
## Step 4: Compile

After making changes, compile to verify everything builds correctly:

```bash
cd build && ninja torch_cpu
```

## Testing
Basic operator support is already tested by `test_output_match` in `test/test_mps.py`. After implementing an operator, enable testing by removing expected failures:

### 1. Remove from common_mps.py

Location: `torch/testing/_internal/common_mps.py`

Find and remove the operator from skip/xfail lists:
```python
# Remove entries like:
MPS_XFAILLIST = {
    "my_op": ...,  # Remove this line
}
MPS_SKIPLIST = {
    "my_op": ...,  # Remove this line
}
```
### 2. Remove from OpInfo decorators
Location: `torch/testing/_internal/common_methods_invocations.py` (or related files)

Remove MPS-specific decorators from the OpInfo:

```python
OpInfo(
    "my_op",
    # Remove decorators like:
    # decorators=[skipMPS, expectedFailureMPS("reason")],
    ...
)
```
### 3. Run tests to verify
```bash
# Run the specific operator test
python test/test_mps.py -k test_output_match_my_op

# Or run full MPS test suite
python test/test_mps.py
```

## Checklist
- Added MPS dispatch to `native_functions.yaml`
- Implemented Metal kernel in `kernels/`
- Implemented host-side operator in `operations/`
- Handles empty tensors
- Handles non-contiguous tensors
- Supports required dtypes (float32, float16, bfloat16, and often complex types via float2/half2)
- Removed expected failures from `torch/testing/_internal/common_mps.py`
- Removed skip/xfail decorators from OpInfo (if applicable)