kernel-cute-writing


CuTe DSL


CuTe DSL is a Python-based domain-specific language for GPU kernel development, part of CUTLASS 4.x. It provides Python abstractions over CUTLASS C++ templates with JIT compilation to optimized CUDA kernels via MLIR and ptxas.

When to Use


Triggers:
  • Writing CUDA kernels in Python (element-wise, GEMM, custom ops)
  • Optimizing GPU memory access patterns (vectorized loads, TMA, shared memory)
  • Building tensor core (MMA) kernels for Ampere/Hopper/Blackwell
  • Integrating custom GPU kernels with PyTorch or JAX
  • Prototyping high-performance kernels without C++ metaprogramming
Symptoms (wrong tool otherwise):
  • Need shared memory coordination or tensor core MMA → use CuTe DSL (not Triton for complex patterns)
  • Need simple element-wise ops with no shared memory → CuTe DSL or Triton both work
  • Need to call existing CUTLASS C++ kernels → use CUTLASS C++ APIs instead
  • Need reductions, scans, or non-GEMM collective ops → consider CUB/Thrust
Keywords: cute, cutlass, cute.jit, cute.kernel, from_dlpack, zipped_divide, TiledMMA, TiledCopy, TMA, WGMMA, tcgen05, pipeline, mbarrier

Requirements


Requirement | Detail
Platform | Linux x86_64 only
Python | 3.10–3.13
GPU | NVIDIA Ampere+ (SM80, SM90, SM100)
CUDA Driver | ≥ 575.51.03 (Toolkit 12.9 compat)
Install | pip install nvidia-cutlass-dsl
Optional | apache-tvm-ffi, torch-c-dlpack-ext
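The platform and Python rows above can be checked before installing. The following is a minimal sketch using only the standard library; the helper name check_requirements is hypothetical, and it does not check the GPU or driver rows.

```python
import platform
import sys


def check_requirements(system, machine, py_version):
    """Return a list of problems; an empty list means the host matches the table."""
    problems = []
    # The table lists Linux x86_64 as the only supported platform.
    if system != "Linux" or machine != "x86_64":
        problems.append(f"unsupported platform: {system} {machine}")
    # The table lists Python 3.10 through 3.13.
    if not ((3, 10) <= tuple(py_version[:2]) <= (3, 13)):
        problems.append(f"unsupported Python: {py_version[0]}.{py_version[1]}")
    return problems


issues = check_requirements(platform.system(), platform.machine(), sys.version_info)
print("host looks OK" if not issues else "\n".join(issues))
```

Passing the values in as arguments, rather than reading them inside the function, keeps the check easy to exercise on a machine that does not meet the requirements.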
要求项详细说明
平台仅支持Linux x86_64
Python版本3.10–3.13
GPUNVIDIA Ampere及以上(SM80, SM90, SM100)
CUDA驱动≥ 575.51.03(兼容Toolkit 12.9)
安装方式
pip install nvidia-cutlass-dsl
可选依赖
apache-tvm-ffi
,
torch-c-dlpack-ext

Workflows


Workflow 0: Starting from Examples (Recommended)


For any non-trivial kernel (GEMM, attention, pipelined, fused ops), start by finding the most similar existing example to use as a starting point — study its structure, then rework it for your use case. Do not copy examples verbatim; they target specific dtypes, architectures, and problem shapes that likely differ.
  1. Pick the closest example from the index below. Prefer examples matching the target GPU architecture (check with torch.cuda.get_device_capability()) when the operation is similar.
    Fetch via web_fetch with base URL https://raw.githubusercontent.com/NVIDIA/cutlass/main/examples/python/CuTeDSL
    Operation | Arch | Example path (append to base URL)
    Element-wise add | SM80 | ampere/elementwise_add.py
    Element-wise + autotune | SM80 | ampere/elementwise_add_autotune.py
    Element-wise apply | SM80 | ampere/elementwise_apply.py
    SGEMM (scalar) | SM80 | ampere/sgemm.py
    Tensor-core GEMM | SM80 | ampere/tensorop_gemm.py
    Flash Attention v2 | SM80 | ampere/flash_attention_v2.py
    HSTU Attention | SM80 | ampere/hstu_attention.py
    Shared memory allocator | SM80 | ampere/smem_allocator.py
    CTA norm (LayerNorm) | SM90 | hopper/cta_norm.py
    Dense GEMM | SM90 | hopper/dense_gemm.py
    Dense GEMM persistent | SM90 | hopper/dense_gemm_persistent.py
    Flash MHA | SM90 | hopper/fmha.py
    Dense GEMM | SM100 | blackwell/dense_gemm.py
    Dense GEMM persistent | SM100 | blackwell/dense_gemm_persistent.py
    Dense GEMM + alpha/beta | SM100 | blackwell/dense_gemm_alpha_beta_persistent.py
    RMSNorm | SM100 | blackwell/rmsnorm.py
    Reduce | SM100 | blackwell/reduce.py
    Flash MHA | SM100 | blackwell/fmha.py
    Grouped GEMM | SM100 | blackwell/grouped_gemm.py
    Mamba2 SSD | SM100 | blackwell/mamba2_ssd/
    GEMM tutorial (notebook) | SM100 | notebooks/tour_to_sol_gemm.ipynb
    Example: To fetch the Hopper dense GEMM:
    bash
    web_fetch https://raw.githubusercontent.com/NVIDIA/cutlass/main/examples/python/CuTeDSL/hopper/dense_gemm.py
  2. Read reference materials first. Before diving into example code, read the relevant references/ docs to understand the patterns and APIs:
    • For GEMM: references/patterns-gemm.md (3-level tiling, epilogue fusion, cute.compile with mark_layout_dynamic, shared memory layouts)
    • For reductions: references/patterns-reduction.md (warp reductions, cute.compile cache pattern)
    • For element-wise: references/patterns-elementwise.md (variations A–E)
    • Always: references/api-arch.md (available APIs, arch-specific caveats)
    This gives you the conceptual foundation so you can rework the example intelligently rather than trying to copy-paste complex pipelines.
  3. Fetch and study the example source — read for structure, not to copy:
    • Identify: decorators, tiling strategy, shared memory usage, mainloop flow
    • Note which dtype/arch it targets (many examples are fp16/bf16-specific)
    • Check if it uses APIs tied to a specific arch (TMA → SM90+, tcgen05 → SM100)
  4. Rework for the user's workload (do not copy-paste):
    • Change shapes, data types, tile sizes to match requirements
    • Replace compute logic (epilogue, activation fusion) as needed
    • If dtype differs (e.g., example is fp16, need fp32), expect vectorization and layout changes; the scalar-loop patterns in references/ may be a better starting point than adapting a vectorized example
    • Runtime wrapper must be lightweight: kernel_fn() should only call from_dlpack() + the compiled kernel. Never allocate intermediate tensors, copy data, or re-compile per call; these belong in one-time setup
    • Apply optimizations from this skill's reference docs
    ⛔ Blackwell/Hopper GEMM + extra tensors — STOP: If the target GPU is SM90+ (Hopper/Blackwell) and the GEMM requires extra tensors beyond A, B, C in the epilogue (e.g., bias vector, activation inputs), do not attempt it. These examples use TMA descriptors for all data movement — adding tensors requires modifying TMA descriptor setup, which is prohibitively complex. Instead, tell the user this limitation and suggest a two-kernel approach: run the GEMM kernel as-is, then apply bias + activation in a separate element-wise kernel (Workflow 1). Plain GEMM (just A×B→C with scalar alpha/beta) on Hopper/Blackwell is fine.
  5. Validate and benchmark using companion scripts:
    bash
    python scripts/verify_kernel.py kernel.py --rtol 1e-3 --atol 1e-3
    python scripts/benchmark_kernel.py kernel.py
    The kernel file must export kernel_fn, reference_fn, and get_inputs().
When to skip examples: Pure element-wise operations (Workflow 1) have complete patterns in references/patterns-elementwise.md, so there is no need to fetch external examples.
Reduction kernels (softmax, layernorm, RMSNorm): Use references/patterns-reduction.md, which provides complete, proven patterns for float32 reductions using scalar loops + butterfly shuffle + shared memory.
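Step 1 of this workflow (match the example to the GPU, then append its path to the base URL) can be sketched as a small helper. The mapping from compute-capability major version to example directory is an assumption inferred from the SM80/SM90/SM100 rows of the index, and example_url is a hypothetical name, not part of CuTe DSL.

```python
BASE_URL = "https://raw.githubusercontent.com/NVIDIA/cutlass/main/examples/python/CuTeDSL"

# Assumed mapping: compute-capability major version -> example directory.
ARCH_DIRS = {8: "ampere", 9: "hopper", 10: "blackwell"}


def example_url(capability, filename):
    """Build the raw URL for an example, given a (major, minor) capability tuple."""
    major, _minor = capability
    if major not in ARCH_DIRS:
        raise ValueError(f"no example directory known for compute capability {major}.x")
    return f"{BASE_URL}/{ARCH_DIRS[major]}/{filename}"


# On a real system the tuple would come from torch.cuda.get_device_capability().
print(example_url((9, 0), "dense_gemm.py"))
```

An Ampere example can still be the right starting point on a newer GPU when no closer match exists, so treat the lookup as a default rather than a rule.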

Workflow 1: Element-wise Kernel


For unary/binary/in-place operations that map inputs to outputs 1:1.
  1. Determine kernel structure: inputs/outputs count, tensor rank, target arch
  2. Select pattern from references/patterns-elementwise.md (Variations A–E)
  3. Write kernel applying all four invariant principles:
    • P1: from_dlpack(tensor, assumed_align=16) for vector loads
    • P2: Derive vec_size from element_type.width
    • P3: cute.zipped_divide(mA, tiler) for coalesced access
    • P4: cutlass.dynamic_expr(thread_idx < total) for bounds
  4. Critical rules: No early return, no a * 2 (use a + a), no cute.math.sigmoid
  5. Pre-compile with cute.compile(): Always pre-compile the kernel once using cute.compile() so that kernel_fn calls the compiled object, not @cute.jit directly. Without pre-compilation, every call recompiles (~20-50ms overhead). Use .mark_layout_dynamic() so a single compiled kernel handles arbitrary input shapes without recompilation:
    python
    # Compile once with dynamic layouts — works for any shape
    fake_x = from_dlpack(torch.empty(1, 1, dtype=torch.float16, device="cuda"),
                          assumed_align=16).mark_layout_dynamic()
    fake_out = from_dlpack(torch.empty(1, 1, dtype=torch.float16, device="cuda"),
                            assumed_align=16).mark_layout_dynamic()
    compiled_kernel = cute.compile(host_fn, fake_x, fake_out)
    
    def kernel_fn(x):
        out = torch.empty_like(x)
        compiled_kernel(from_dlpack(x, assumed_align=16).mark_layout_dynamic(),
                        from_dlpack(out, assumed_align=16).mark_layout_dynamic())
        return out
  6. Verify correctness using companion script:
    bash
    python scripts/verify_kernel.py kernel.py --rtol 1e-3 --atol 1e-3
    The kernel file must export kernel_fn, reference_fn, and get_inputs().
  7. Benchmark using companion script:
    bash
    python scripts/benchmark_kernel.py kernel.py
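The arithmetic behind principles P2 to P4 can be worked through on the host in plain Python. This is a sketch under stated assumptions: a 128-bit (16-byte) vector access, one thread per vector tile, and a row length divisible by the vector size. The function names are illustrative only.

```python
def vec_size(element_width_bits, vector_bits=128):
    """P2: elements per 128-bit vector access (8 for fp16, 4 for fp32)."""
    return vector_bits // element_width_bits


def launch_config(m, n, element_width_bits, threads_per_block=256):
    """Return (vec, tiles, blocks) for an m x n tensor, one thread per tile."""
    vec = vec_size(element_width_bits)
    if n % vec != 0:
        raise ValueError(f"row length {n} is not a multiple of vec size {vec}")
    tiles = m * (n // vec)                                          # P3: (1, vec) tiles
    blocks = (tiles + threads_per_block - 1) // threads_per_block   # ceiling division
    return vec, tiles, blocks


def tile_coord(idx, tiles_per_row):
    """Map a flat thread index to a (row, column-tile) coordinate."""
    return idx // tiles_per_row, idx % tiles_per_row


vec, tiles, blocks = launch_config(1024, 512, element_width_bits=16)
print(vec, tiles, blocks)
# P4: threads with idx >= tiles must be masked out by the bounds check.
print(tile_coord(65, 512 // vec))
```

Because the block count is rounded up, the last block can contain threads past the end of the data, which is why the bounds predicate in P4 is needed even when the tensor shape looks regular.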

Workflow 2: GEMM Kernel


For matrix multiplication with tiling, shared memory, and tensor cores.
  1. Define problem: shapes (M, N, K), data types, target architecture
  2. Choose tiling: CTA tile (bM, bN, bK), pipeline stages, cluster shape
  3. Three-level partitioning (see references/patterns-gemm.md):
    • Level 1: CTA tiling with local_tile()
    • Level 2: Copy partitioning (global → shared) with TiledCopy
    • Level 3: Compute partitioning (shared → register) with TiledMMA
  4. Shared memory: Use swizzled layouts (make_smem_layout_atom) to avoid bank conflicts
  5. Mainloop: K-tile loop with copy → sync → MMA → sync
  6. Pipeline: Use PipelineTmaAsync (Hopper) or PipelineTmaUmma (Blackwell). ⚠️ TMA-based pipelines manage data movement via TMA descriptors; adding extra tensors (bias, activation inputs) to the epilogue requires modifying descriptor setup, which is prohibitively complex. See the stop condition in Workflow 0 step 4.
  7. Epilogue: Predicated store with alpha/beta scaling
  8. Pre-compile with cute.compile(): Always pre-compile the GEMM kernel so kernel_fn calls the compiled object, not @cute.jit directly. Without pre-compilation, every call recompiles (~20-50ms overhead).
  9. Autotune: Search over tile sizes, cluster shapes, pipeline depths
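The loop structure described in steps 2, 5, and 7 (CTA tiles, a mainloop over K tiles, then a predicated epilogue) can be illustrated with a pure-Python reference. This is only a sketch of the control flow on nested lists: it runs on the CPU, uses no CuTe DSL APIs, and the tile sizes and function name are arbitrary choices for the illustration.

```python
def tiled_matmul(A, B, bM=2, bN=2, bK=2, alpha=1.0, beta=0.0):
    """Compute alpha * A @ B + beta * C (C starts at zero) tile by tile."""
    M, K, N = len(A), len(A[0]), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for m0 in range(0, M, bM):                 # one "CTA" per (m0, n0) tile
        for n0 in range(0, N, bN):
            acc = [[0.0] * bN for _ in range(bM)]   # per-tile accumulators
            for k0 in range(0, K, bK):         # mainloop over K tiles
                for i in range(m0, min(m0 + bM, M)):
                    for j in range(n0, min(n0 + bN, N)):
                        for k in range(k0, min(k0 + bK, K)):
                            acc[i - m0][j - n0] += A[i][k] * B[k][j]
            # Epilogue: the min() bounds play the role of the store predicate.
            for i in range(m0, min(m0 + bM, M)):
                for j in range(n0, min(n0 + bN, N)):
                    C[i][j] = alpha * acc[i - m0][j - n0] + beta * C[i][j]
    return C


A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
B = [[1, 0], [0, 1], [1, 1]]
print(tiled_matmul(A, B))
```

The 3 x 3 input with 2 x 2 tiles deliberately leaves partial tiles on the edges, which is the case the predicated store in step 7 exists to handle.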

Workflow 3: Framework Integration


For wrapping CuTe DSL kernels as PyTorch/JAX custom operators.
  1. Write kernel using Workflow 1 or 2
  2. Create wrapper: Accept torch.Tensor, convert via from_dlpack, call host fn
  3. For production: Compile with TVM FFI for zero-overhead tensor passing:
    python
    compiled = cute.compile(host_fn, *fake_tensors, options="--enable-tvm-ffi")
    compiled(torch_a, torch_b)  # Direct torch.Tensor, no from_dlpack
  4. For deployment: Use AOT compilation → export to .o → load at runtime

Workflow 4: Debugging & Profiling


  1. Set environment: CUTE_DSL_PRINT_IR=1, CUTE_DSL_KEEP_PTX=1
  2. Use cute.printf() for runtime values (not Python print)
  3. Inspect generated code: compiled.__ptx__, compiled.__mlir__
  4. Profile: Enable CUTE_DSL_LINEINFO=1, use Nsight Compute/Systems
  5. Debug memory: Run with compute-sanitizer python script.py
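The environment variables in steps 1 and 4 can also be set from Python rather than the shell. The sketch below assumes the DSL reads them when it is imported or when it compiles, so it sets them before any cutlass import; that ordering is a precaution, not documented behavior. The helper name enable_debug_env is hypothetical.

```python
import os

# Variable names are taken from the steps above; "1" enables each one.
DEBUG_ENV = {
    "CUTE_DSL_PRINT_IR": "1",
    "CUTE_DSL_KEEP_PTX": "1",
    "CUTE_DSL_LINEINFO": "1",
}


def enable_debug_env(env=os.environ):
    """Set each debug variable unless the caller already chose a value."""
    for name, value in DEBUG_ENV.items():
        env.setdefault(name, value)
    return {name: env[name] for name in DEBUG_ENV}


print(enable_debug_env())
# import cutlass.cute as cute  # import only after the environment is configured
```

Using setdefault means a value exported in the shell still wins, so the script never silently overrides a deliberate setting.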

Output Formats


A typical CuTe DSL kernel project:
kernel_dir/
  kernel.py          # @cute.kernel + @cute.jit functions
  test_kernel.py     # Correctness test vs PyTorch reference
  bench_kernel.py    # Benchmark with cute.compile() setup
Success indicators:
  • Correctness test passes (torch.testing.assert_close)
  • Nsight shows vector loads (LDG.128/LDG.256), not scalar loads
  • For GEMM: tensor core utilization > 80% in Nsight Compute

Companion Script Contract


Kernel files used with scripts/verify_kernel.py and scripts/benchmark_kernel.py must export three names:
  • kernel_fn(*inputs): the CuTe DSL kernel wrapper (runs the kernel compiled once with cute.compile)
  • reference_fn(*inputs): PyTorch reference implementation (same signature)
  • get_inputs(): returns a list of CUDA tensors for testing

Example kernel.py contract


python
import torch
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

def kernel_fn(x):
    out = torch.empty_like(x)
    # ... call compiled cute kernel ...
    return out

def reference_fn(x):
    return torch.nn.functional.gelu(x)

def get_inputs():
    return [torch.randn(1024, 512, dtype=torch.float16, device="cuda")]
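The contract can be checked mechanically before running anything on a GPU. The companion scripts themselves are not shown here, so the following is a hypothetical sketch of the kind of check they might perform, not their actual implementation. It uses a stand-in module with plain floats so that it runs without torch.

```python
import types

REQUIRED_EXPORTS = ("kernel_fn", "reference_fn", "get_inputs")


def missing_exports(module):
    """Names from the contract that the module does not define as callables."""
    return [n for n in REQUIRED_EXPORTS if not callable(getattr(module, n, None))]


def run_contract(module, close):
    """Call kernel_fn and reference_fn on get_inputs() and compare the results."""
    missing = missing_exports(module)
    if missing:
        raise AttributeError(f"kernel file is missing: {', '.join(missing)}")
    inputs = module.get_inputs()
    return close(module.kernel_fn(*inputs), module.reference_fn(*inputs))


# Stand-in for an imported kernel.py; real files would return CUDA tensors.
demo = types.SimpleNamespace(
    kernel_fn=lambda x: x + x,
    reference_fn=lambda x: 2 * x,
    get_inputs=lambda: [1.5],
)
print(run_contract(demo, close=lambda a, b: abs(a - b) <= 1e-3))
```

With real tensors, the close argument could wrap torch.testing.assert_close, which the success indicators above name as the correctness check.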

Examples


Example: 2D Unary Element-wise (ReLU)


python
import torch, cutlass, cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.kernel
def relu_kernel(gA: cute.Tensor, gC: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()
    idx = bidx * bdim + tidx
    m, n = gA.shape[1]
    total = m * n
    if cutlass.dynamic_expr(idx < total):
        a = gA[(None, (idx // n, idx % n))].load()
        gC[(None, (idx // n, idx % n))] = cute.where(a > 0, a, 0)

@cute.jit
def relu_host(mA: cute.Tensor, mC: cute.Tensor):
    vec = 16 // (mA.element_type.width // 8)
    gA = cute.zipped_divide(mA, (1, vec))
    gC = cute.zipped_divide(mC, (1, vec))
    T = 256
    N = cute.size(gA.shape[1])
    relu_kernel(gA, gC).launch(grid=((N+T-1)//T,1,1), block=(T,1,1))

x = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)
relu_host(from_dlpack(x, assumed_align=16), from_dlpack(out, assumed_align=16))

Error Handling


Error | Cause | Fix
MLIR function requires a Context | Called @kernel from Python | Launch via @cute.jit host function
DSLAstPreprocessorError on return | Early return in @kernel | Use if cutlass.dynamic_expr(cond):
Type mismatch on store | a * 2 promotes FP16→FP32 | Use a + a or .to(cutlass.Float16)
could not get source code | Kernel in exec() context | Write to file and import
Scalar loads in Nsight | Missing alignment hint | Add assumed_align=16 to from_dlpack
Missing required argument | Not all @jit params passed | Pass ALL declared parameters
AttributeError: sigmoid | No cute.math.sigmoid | Use 1.0/(1.0+cute.math.exp(-x))
See references/troubleshooting.md for the full error table and limitations.
Debugging rule: Never delete kernel.py during debugging. Use backup_file to save a checkpoint, then edit_file to iterate. If stuck, revert_file to restore the backup. A partially-working kernel is always better than no kernel.
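The sigmoid workaround in the table is the standard closed form, which can be sanity-checked on the host with Python's math module before it is written into a kernel. This sketch uses math.exp as a stand-in for cute.math.exp and checks the result against the equivalent tanh identity.

```python
import math


def sigmoid(x):
    """Closed form from the error table: 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_via_tanh(x):
    """Equivalent identity, used here only as an independent cross-check."""
    return 0.5 * (1.0 + math.tanh(0.5 * x))


for value in (-4.0, 0.0, 2.0):
    assert abs(sigmoid(value) - sigmoid_via_tanh(value)) < 1e-12
print(sigmoid(0.0))
```

A host-side check like this is also a convenient source for reference_fn when the fused activation has no direct PyTorch equivalent.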

Finding More Information


Tier 1: This File (SKILL.md)


Workflows above cover element-wise kernels, GEMM, framework integration, and debugging. Search this file first for procedural questions.

Tier 2: references/ Directory


Grep for keywords across references/. Headers are grep-friendly.
File | Content
concepts-architecture.md | Core abstractions, terminology, compilation pipeline
concepts-layouts.md | Layout algebra: composition, complement, divide, swizzle
concepts-tensors.md | Tensor types, partitioning, tiling, predication
concepts-mma.md | MMA atoms, TiledMMA, per-architecture tensor core ops
patterns-getting-started.md | Installation, decorators, first kernel walkthrough
patterns-elementwise.md | Invariant principles, pattern variations, reference impl
patterns-gemm.md | 3-level tiling, shared memory, pipelining, autotuning
patterns-memory.md | from_dlpack, TMA, cp.async, TMEM, copy atoms
patterns-compilation.md | Control flow, JIT caching, TVM FFI, AOT compilation
patterns-pipeline.md | Producer-consumer, pipeline classes, barriers, warp specialization
api-core.md | cute module: layouts, tensors, math, copy, gemm, printing
api-arch.md | cute.arch: thread indexing, sync, atomics, memory ops
api-nvgpu.md | cute.nvgpu: warp/warpgroup/cpasync/tcgen05 MMA and copy
api-runtime-utils.md | Runtime: from_dlpack, fake tensors, utils, schedulers
troubleshooting.md | Debugging, env vars, common errors, limitations, FAQ
How to search: Grep for your keyword across references/. Read only the file and section that Grep points to.
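Where a grep tool is not available, the same keyword search can be approximated with the standard library. This is a minimal sketch that assumes the reference files are flat .md files in one directory. The function name grep_references is hypothetical.

```python
from pathlib import Path


def grep_references(root, keyword):
    """Return (file name, line number, line) for each case-insensitive match."""
    hits = []
    needle = keyword.lower()
    for path in sorted(Path(root).glob("*.md")):
        lines = path.read_text(encoding="utf-8").splitlines()
        for lineno, line in enumerate(lines, start=1):
            if needle in line.lower():
                hits.append((path.name, lineno, line.strip()))
    return hits


# Prints nothing if the directory does not exist or has no matches.
for name, lineno, line in grep_references("references", "TiledMMA"):
    print(f"{name}:{lineno}: {line}")
```

The file name and line number in each hit point at the one section worth reading, which matches the advice to read only what the search returns.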

Tier 3: Original Documentation


If Tiers 1–2 don't answer, consult the source: