metal-kernel


Metal Kernel Writing Guide

This skill guides you through implementing Metal kernels for PyTorch operators on Apple Silicon.

Important: The goal of this skill is to use native Metal capabilities via the `c10/metal/` infrastructure, NOT MPSGraph. Native Metal kernels provide better control, performance, and maintainability.

Overview

There are two workflows covered by this skill:
  1. Adding new MPS support - implementing a new operator from scratch
  2. Migrating from MPSGraph - converting existing MPSGraph-based operators to native Metal

Both workflows involve:
  1. Updating the dispatch in `aten/src/ATen/native/native_functions.yaml`
  2. Writing the Metal kernel in `aten/src/ATen/native/mps/kernels/`
  3. Implementing the host-side stub in `aten/src/ATen/native/mps/operations/`

Step 1: Update native_functions.yaml

Location: `aten/src/ATen/native/native_functions.yaml`

For New Operators

Find the operator entry and add MPS dispatch:

Simple MPS-specific implementation

```yaml
- func: my_op(Tensor self) -> Tensor
  dispatch:
    CPU: my_op_cpu
    CUDA: my_op_cuda
    MPS: my_op_mps
```

Shared implementation across devices (preferred for structured kernels)

```yaml
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA, MPS: my_op_out
```

Structured kernel (preferred for new ops)

```yaml
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: my_op_out
```

For Migrating from MPSGraph


When migrating an existing operator from MPSGraph to native Metal, consolidate the dispatch entry:

BEFORE (MPSGraph-based, separate dispatch)

```yaml
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: atan2_out
    MPS: atan2_out_mps  # Separate MPS implementation
```

AFTER (native Metal, shared dispatch via stub)

```yaml
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: atan2_out  # MPS now uses the same stub mechanism
```

**Key change:** Remove the separate `MPS: my_op_out_mps` entry and add `MPS` to the shared dispatch line (e.g., `CPU, CUDA, MPS: my_op_out`).

**Dispatch naming conventions:**
- `MPS: function_name_mps` - MPS-specific implementation (old MPSGraph pattern)
- `CPU, CUDA, MPS: function_name` - Shared stub implementation (native Metal pattern)

Step 2: Implement Metal Kernel

Location: `aten/src/ATen/native/mps/kernels/`

Unary Kernel Pattern

```metal
// MyKernel.metal
#include <c10/metal/indexing.h>
#include <c10/metal/utils.h>
#include <metal_stdlib>

using namespace metal;
using namespace c10::metal;

// Define the operation functor
struct my_op_functor {
  template <typename T>
  inline T operator()(const T x) {
    return /* your operation */;
  }
};

// Register for supported types
REGISTER_UNARY_OP(my_op, float, float);
REGISTER_UNARY_OP(my_op, half, half);
REGISTER_UNARY_OP(my_op, bfloat, bfloat);
```

Binary Kernel Pattern

```metal
struct my_binary_functor {
  template <typename T>
  inline T operator()(const T a, const T b) {
    return /* your operation */;
  }
};

REGISTER_BINARY_OP(my_binary, float, float);
REGISTER_BINARY_OP(my_binary, half, half);
```

Binary Kernel Type Registration Macros

For binary operations, use the convenience macros defined in `BinaryKernel.metal`:

```metal
// Floating-point types only (float, half, bfloat)
REGISTER_FLOAT_BINARY_OP(my_op);

// Integral types with float output (for math ops like atan2, copysign)
// Registers: long->float, int->float, short->float, uchar->float, char->float, bool->float
REGISTER_INT2FLOAT_BINARY_OP(my_op);

// Integral types with same-type output (for bitwise/logical ops)
// Registers: long, int, short, uchar, char, bool
REGISTER_INTEGER_BINARY_OP(my_op);

// Floating-point with opmath precision (for ops needing higher precision)
REGISTER_OPMATH_FLOAT_BINARY_OP(my_op);
```
Common patterns:
  • Math functions (atan2, copysign, logaddexp): use both `REGISTER_FLOAT_BINARY_OP` and `REGISTER_INT2FLOAT_BINARY_OP`
  • Comparison/logical ops (maximum, minimum): use both `REGISTER_FLOAT_BINARY_OP` and `REGISTER_INTEGER_BINARY_OP`
  • Arithmetic ops (add, sub, mul): use both `REGISTER_FLOAT_BINARY_OP` and `REGISTER_INTEGER_BINARY_OP`
Example for atan2 (supports both float and int inputs):

```metal
struct atan2_functor {
  template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
  inline T operator()(const T a, const T b) {
    return static_cast<T>(precise::atan2(float(a), float(b)));
  }
  template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
  inline float operator()(const T a, const T b) {
    return precise::atan2(float(a), float(b));
  }
};

REGISTER_FLOAT_BINARY_OP(atan2);
REGISTER_INT2FLOAT_BINARY_OP(atan2);
```
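As a plain-Python sketch (illustrative only, not PyTorch code) of what the two registrations mean for dtypes: floating inputs keep their dtype, while integral and bool inputs produce a float32 output:

```python
# Hypothetical model of the two atan2 registrations above:
# REGISTER_FLOAT_BINARY_OP keeps the input dtype, while
# REGISTER_INT2FLOAT_BINARY_OP promotes integral/bool inputs to float32.
FLOAT_DTYPES = {"float32", "float16", "bfloat16"}
INT_DTYPES = {"int64", "int32", "int16", "uint8", "int8", "bool"}

def atan2_out_dtype(in_dtype):
    if in_dtype in FLOAT_DTYPES:
        return in_dtype    # float registration: same-type output
    if in_dtype in INT_DTYPES:
        return "float32"   # int2float registration: float output
    raise TypeError(f"no kernel registered for {in_dtype}")

assert atan2_out_dtype("float16") == "float16"
assert atan2_out_dtype("int32") == "float32"
```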

With Scalar Parameter

```metal
struct my_alpha_functor {
  template <typename T>
  inline T operator()(const T a, const T b, const T alpha) {
    return a + c10::metal::mul(alpha, b);
  }
};

REGISTER_UNARY_ALPHA_OP(my_alpha, float, float, float);
REGISTER_UNARY_ALPHA_OP(my_alpha, half, half, half);
```
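In plain-Python terms (a sketch, not PyTorch code), the functor computes the familiar fused form `out = a + alpha * b`, the shape used by ops like `torch.add(self, other, alpha=...)`:

```python
# Reference model of the alpha functor's math: out[i] = a[i] + alpha * b[i].
def alpha_op(a, b, alpha):
    return [x + alpha * y for x, y in zip(a, b)]

assert alpha_op([1.0, 2.0], [10.0, 20.0], 0.5) == [6.0, 12.0]
```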

Type-Specialized Functor

```metal
struct special_functor {
  // Floating point types
  template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
  inline T operator()(const T x) {
    return precise::exp(x);  // Use precise math
  }

  // Integral types
  template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
  inline float operator()(const T x) {
    return precise::exp(float(x));
  }

  // Complex types (float2 for cfloat, half2 for chalf)
  template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
  inline T operator()(const T x) {
    // x.x = real, x.y = imaginary
    return T(/* real */, /* imag */);
  }
};
```
Note on complex types: Complex numbers in Metal are represented as vector types:
  • `c10::complex<float>` maps to `float2` (x = real, y = imaginary)
  • `c10::complex<half>` maps to `half2`

Use `is_complex_v<T>` to specialize for complex types in functors.
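To illustrate the layout (plain Python, not Metal), a complex value can be modeled as a (real, imag) pair the way `float2` carries it:

```python
import cmath

# Model of a complex-specialized functor: the value arrives as a 2-vector,
# x.x = real part, x.y = imaginary part, and is returned the same way.
def complex_exp_vec(x):
    re, im = x                      # unpack the float2-style pair
    r = cmath.exp(complex(re, im))  # compute exp on the complex value
    return (r.real, r.imag)         # repack as (real, imag)

out = complex_exp_vec((0.0, cmath.pi))  # exp(i*pi) = -1
assert abs(out[0] + 1.0) < 1e-12 and abs(out[1]) < 1e-12
```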

Available c10/metal Utilities

`utils.h`:
  • `opmath_t<T>` - operation math type (e.g., half -> float)
  • `accum_t<T>` - accumulation type for reductions
  • `max()`, `min()` with NaN propagation

`special_math.h`:
  • `precise::exp()`, `precise::log()`, `precise::sqrt()`
  • `precise::sin()`, `precise::cos()`, `precise::tan()`
  • `erf()`, `erfc()`, `erfinv()`

`indexing.h`:
  • `REGISTER_UNARY_OP(name, in_type, out_type)`
  • `REGISTER_BINARY_OP(name, in_type, out_type)`
  • `REGISTER_UNARY_ALPHA_OP(name, in_type, alpha_type, out_type)`

Step 3: Implement Host-Side Stub

Location: `aten/src/ATen/native/mps/operations/`

Choose or create an appropriate file based on operation type:
  • `UnaryKernel.mm` - single-input operations via stub dispatch
  • `BinaryKernel.mm` - two-input operations via stub dispatch
  • `UnaryOps.mm` / `BinaryOps.mm` - legacy MPSGraph implementations (for reference)
  • `ReduceOps.mm` - reductions (sum, mean, max, etc.)
  • Create a new file for distinct operation categories

Stub Registration Pattern (Preferred for Native Metal)

For structured kernels that use the TensorIterator pattern:

```objc
// In BinaryKernel.mm (or the appropriate file)

static void my_op_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_binary_kernel(iter, "my_op");  // "my_op" matches the functor name in the .metal file
}

// Register the MPS stub - this connects to the dispatch system
REGISTER_DISPATCH(my_op_stub, &my_op_mps_kernel)
```

For unary operations:

```objc
static void my_unary_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_unary_kernel(iter, "my_unary");
}

REGISTER_DISPATCH(my_unary_stub, &my_unary_mps_kernel)
```
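Conceptually, each `*_stub` is a table of per-device kernels, and `REGISTER_DISPATCH` fills in the MPS slot so the shared operator implementation can call it. A toy Python analogy (names illustrative, not the real PyTorch machinery):

```python
# Toy model of the dispatch-stub mechanism.
class DispatchStub:
    def __init__(self, name):
        self.name = name
        self.kernels = {}            # device type -> kernel function

    def register(self, device, fn):  # plays the role of REGISTER_DISPATCH
        self.kernels[device] = fn

    def __call__(self, device, *args):
        return self.kernels[device](*args)

# The shared op (e.g., atan2_out) calls the stub; each backend registers its kernel.
my_op_stub = DispatchStub("my_op_stub")
my_op_stub.register("cpu", lambda x: x * 2)  # stand-in for the CPU kernel
my_op_stub.register("mps", lambda x: x * 2)  # stand-in for my_op_mps_kernel

assert my_op_stub("mps", 21) == 42
```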

Migration: Removing Old MPSGraph Implementation

When migrating from MPSGraph, also remove the old implementation:
  1. Remove from `BinaryOps.mm` (or `UnaryOps.mm`):
    • Delete the `TORCH_IMPL_FUNC(my_op_out_mps)` implementation
    • Remove the corresponding `#include <ATen/ops/my_op_native.h>` header
  2. Add to `BinaryKernel.mm` (or `UnaryKernel.mm`):
    • Add the static kernel function
    • Add the `REGISTER_DISPATCH` call

Step 4: Compile

After making changes, compile to verify everything builds correctly:

```bash
cd build && ninja torch_cpu
```

Testing

Basic operator support is already covered by `test_output_match` in `test/test_mps.py`. After implementing an operator, enable testing by removing its expected failures:

1. Remove from common_mps.py

Location: `torch/testing/_internal/common_mps.py`

Find and remove the operator from the skip/xfail lists:

Remove entries like:

```python
MPS_XFAILLIST = {
    "my_op": ...,  # Remove this line
}
MPS_SKIPLIST = {
    "my_op": ...,  # Remove this line
}
```

2. Remove from OpInfo decorators

Location: `torch/testing/_internal/common_methods_invocations.py` (or related files)

Remove MPS-specific decorators from the OpInfo:

```python
OpInfo(
    "my_op",
    # Remove decorators like:
    # decorators=[skipMPS, expectedFailureMPS("reason")],
    ...
)
```

3. Run tests to verify

Run the specific operator test:

```bash
python test/test_mps.py -k test_output_match_my_op
```

Or run the full MPS test suite:

```bash
python test/test_mps.py
```
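At its core, `test_output_match` runs the operator on both CPU and MPS and compares the results within per-dtype tolerances (roughly the check `torch.testing.assert_close` performs). A simplified sketch of that comparison, with tolerance values shown as an assumption:

```python
# Simplified version of the rtol/atol closeness check used when comparing
# MPS results against the CPU reference:
#   |actual - expected| <= atol + rtol * |expected|
# (the rtol/atol defaults here are illustrative float32-level values)
def allclose(actual, expected, rtol=1.3e-6, atol=1e-5):
    return all(
        abs(a - e) <= atol + rtol * abs(e)
        for a, e in zip(actual, expected)
    )

cpu_out = [0.0, 0.5, 1.0]
mps_out = [0.0, 0.5000001, 1.0]    # tiny float32-level deviation is fine
assert allclose(mps_out, cpu_out)
assert not allclose([0.1], [0.0])  # a real mismatch is caught
```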

Checklist

  • Added MPS dispatch to `native_functions.yaml`
  • Implemented the Metal kernel in `kernels/`
  • Implemented the host-side operator in `operations/`
  • Handles empty tensors
  • Handles non-contiguous tensors
  • Supports the required dtypes (float32, float16, bfloat16, and often complex types via float2/half2)
  • Removed expected failures from `torch/testing/_internal/common_mps.py`
  • Removed skip/xfail decorators from OpInfo (if applicable)
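On the "handles non-contiguous tensors" point: the c10/metal indexing helpers (and TensorIterator on the host side) map each linear thread index to a storage offset through the tensor's strides. A plain-Python sketch of that arithmetic (illustrative, not the actual helper code):

```python
# Decompose a linear element index into per-dimension coordinates, then map
# the coordinates to a storage offset using the tensor's strides.
def offset_for(linear_idx, sizes, strides):
    offset = 0
    for size, stride in reversed(list(zip(sizes, strides))):
        coord = linear_idx % size   # coordinate in this (fastest-varying) dim
        linear_idx //= size
        offset += coord * stride
    return offset

# A 2x3 view stored transposed (non-contiguous: strides [1, 2] instead of [3, 1]).
sizes, strides = [2, 3], [1, 2]
offsets = [offset_for(i, sizes, strides) for i in range(6)]
assert offsets == [0, 2, 4, 1, 3, 5]
```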