cute-dsl-ref


CuTe Python DSL Reference


Execution Model


CuTe Python DSL is the Python surface for NVIDIA's CuTe layout algebra. Unlike cuTile's block-level abstraction, CuTe DSL exposes explicit thread/warp/warpgroup control, TMA pipelines, barrier choreography, and shared memory management.

Two-Level Host/Device Pattern


Every CuTe DSL kernel has two functions:
  1. `@cute.jit` host function — runs on CPU; sets up TMA descriptors, computes the grid, allocates shared memory, and launches the kernel
  2. `@cute.kernel` device function — runs on GPU; contains the actual computation
```python
import cutlass
import cutlass.cute as cute

@cute.kernel
def my_kernel(tiled_mma: cute.TiledMma, ...):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, bidy, _ = cute.arch.block_idx()
    # ... GPU code

@cute.jit
def host_fn(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor):
    # Setup TMA descriptors, compute grid, allocate SMEM
    my_kernel(...).launch(grid=grid_shape, block=block_shape, smem=smem_bytes)
```

Compilation Pipeline


Three-stage JIT compilation:
  1. Pre-Staging: Python AST rewriting
  2. Meta-Stage: Python interpreter executes meta-programs (compile-time constants resolved)
  3. Object-Stage: Compiler backend generates PTX → SASS
JIT artifacts are cached automatically. Control caching with `CUTE_DSL_CACHE_DIR` and `CUTE_DSL_DISABLE_FILE_CACHING`.
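Both cache controls are plain environment variables, so a shell profile or CI script can set them; the directory path below is illustrative, and treating `1` as the disable value is an assumption:

```shell
# Redirect the JIT cache to a project-local directory (path is illustrative)
export CUTE_DSL_CACHE_DIR="$HOME/.cache/my-project/cute-dsl"

# Disable on-disk caching entirely, e.g. while bisecting a codegen issue
export CUTE_DSL_DISABLE_FILE_CACHING=1
```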

Key Difference from cuTile


| Aspect | cuTile | CuTe DSL |
|---|---|---|
| Abstraction level | Block-level (no thread identity) | Thread/warp/warpgroup level |
| TMA | Implicit (`allow_tma=True`) | Explicit descriptor setup in host |
| Shared memory | Compiler-managed | Explicit allocation and layout |
| Synchronization | None exposed | Barriers, arrive/wait, named barriers |
| Pipeline | None exposed | Explicit multi-stage async pipelines |
| MMA | `ct.mma(A, B, C)` on tiles | `cute.gemm(tiled_mma, d, a, b, c)` on partitioned fragments |
| Compilation | cuTile bytecode → MLIR | CuTe DSL JIT → PTX → SASS |


Core API Table


Decorators and Compilation


| API | Purpose | Key params |
|---|---|---|
| `@cute.kernel` | Device kernel decorator | function name must match module filename |
| `@cute.jit` | Host JIT function decorator | |
| `kernel(...).launch(grid, block, smem, cluster)` | Launch kernel | `grid=(x,y,z)`, `block=(threads,1,1)`, `smem=bytes`, `cluster=(cx,cy,cz)` |
| `cute.compile(fn, *args, options="...")` | Ahead-of-time compile | `"opt-level=3"`, `"keep-ptx"`, etc. |

Thread/Block/Grid Indexing


| API | Purpose |
|---|---|
| `cute.arch.thread_idx()` | Returns `(tidx, tidy, tidz)` |
| `cute.arch.block_idx()` | Returns `(bidx, bidy, bidz)` |
| `cute.arch.block_dim()` | Block dimensions |
| `cute.arch.grid_dim()` | Grid dimensions |
| `cute.arch.warp_idx()` | Warp index within block |
| `cute.arch.block_idx_in_cluster()` | Block index within cluster (Hopper+) |

Layout and Tensor Construction


| API | Purpose | Notes |
|---|---|---|
| `cute.make_layout(shape, stride)` | Create layout | Core abstraction: maps coordinates → indices |
| `cute.make_ordered_layout(shape, order)` | Create with stride order | e.g., `order=(1,0)` for column-major |
| `cute.make_identity_layout(shape)` | Identity layout | For predication |
| `cute.make_tensor(ptr, layout)` | Create tensor view | Pairs pointer + layout |
| `cute.make_fragment(layout, dtype)` | Allocate register tensor | For accumulators |
| `cute.make_fragment_like(src, dtype)` | Register tensor matching source | |
| `cute.make_ptr(dtype, addr, memspace)` | Tagged pointer | `cute.AddressSpace.smem/gmem` |
| `cute.from_dlpack(tensor)` | Convert PyTorch/JAX tensor | Returns `cute.Tensor`; mark dynamic dims with `.mark_layout_dynamic()` |
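The "maps coordinates → indices" phrasing can be made concrete with a plain-Python model of the arithmetic a layout performs (this is a sketch of the semantics, not the CuTe API): a coordinate is mapped to an index by the inner product with the strides.

```python
def layout_index(coord, stride):
    """Plain-Python model of a CuTe layout's mapping: index = dot(coord, stride)."""
    return sum(c * s for c, s in zip(coord, stride))

# Row-major 4x8 layout, like cute.make_layout((4, 8), stride=(8, 1)):
# coordinate (2, 3) lands at 2*8 + 3*1 = 19
assert layout_index((2, 3), (8, 1)) == 19

# Column-major 4x8 layout, like cute.make_ordered_layout((4, 8), order=(1, 0)),
# which gives stride (1, 4): coordinate (2, 3) lands at 2*1 + 3*4 = 14
assert layout_index((2, 3), (1, 4)) == 14
```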

Layout Algebra


| API | Purpose |
|---|---|
| `cute.size(t, mode=None)` | Total size or modal size |
| `cute.shape(t)` | Get shape tuple |
| `cute.stride(t)` | Get stride tuple |
| `cute.rank(t)` | Number of modes |
| `cute.local_tile(tensor, tiler, coord, proj)` | Extract block's tile |
| `cute.logical_divide(tensor, divisor)` | Divide into tile + rest |
| `cute.composition(layout1, layout2)` | Compose layouts |
| `cute.complement(layout)` | Complement layout |
| `cute.coalesce(layout)` | Simplify layout |
| `cute.flatten(t)` | Flatten tensor/layout |
| `cute.group_modes(layout, start, end)` | Group modes together |
| `cute.tile_to_shape(tile, shape)` | Tile to target shape |
| `cute.ceil_div(a, b)` | Ceiling division |
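Ceiling division is what sizes a grid when the problem shape does not divide evenly by the tile shape; a plain-Python equivalent of `cute.ceil_div` and a typical grid computation (tile sizes here are illustrative):

```python
def ceil_div(a, b):
    # Plain-Python equivalent of cute.ceil_div: tiles of size b needed to cover a
    return -(-a // b)

M, N = 1000, 512          # problem shape
tile_m, tile_n = 128, 128 # per-block tile shape (illustrative)
grid = (ceil_div(M, tile_m), ceil_div(N, tile_n))
assert grid == (8, 4)     # 1000/128 rounds up to 8; 512/128 is exactly 4
```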

MMA Operations


| API | Purpose | Notes |
|---|---|---|
| `cute.make_tiled_mma(op)` | Create tiled MMA from operation | e.g., `tcgen05.MmaF16BF16Op(...)` |
| `tiled_mma.get_slice(thread_idx)` | Get thread's partition | Returns `ThrMma` |
| `thr_mma.partition_A(tensor)` | Partition A operand | Thread-level view |
| `thr_mma.partition_B(tensor)` | Partition B operand | Thread-level view |
| `thr_mma.partition_C(tensor)` | Partition C operand | Thread-level view |
| `tiled_mma.partition_shape_C(shape)` | Shape of C partition | For fragment allocation |
| `tiled_mma.make_fragment_A(tensor)` | Make A fragment | Register allocation |
| `tiled_mma.make_fragment_B(tensor)` | Make B fragment | Register allocation |
| `tiled_mma.make_fragment_C(shape)` | Make C fragment | Register allocation |
| `cute.gemm(tiled_mma, d, a, b, c)` | Execute MMA: d = a @ b + c | Dispatches to hardware MMA |

Copy and TMA Operations


| API | Purpose | Notes |
|---|---|---|
| `cute.make_tiled_copy(atom)` | Create tiled copy | For bulk data movement |
| `cute.copy(atom, src, dst, **kw)` | Execute copy | `tma_bar_ptr=`, `mcast_mask=` |
| `cute.basic_copy(src, dst)` | Element-wise copy | No atom needed |
| `cute.prefetch(atom, src)` | Prefetch TMA descriptor | |
| `cute.nvgpu.make_tiled_tma_atom_A(op, tensor, smem_layout, tile, mma)` | Create TMA atom for A | Host-side only |
| `cute.nvgpu.make_tiled_tma_atom_B(op, tensor, smem_layout, tile, mma)` | Create TMA atom for B | Host-side only |

Shared Memory


| API | Purpose |
|---|---|
| `cutlass.utils.SmemAllocator()` | Create SMEM allocator |
| `smem.allocate_tensor(dtype, layout, align, swizzle)` | Allocate tensor in SMEM |
| `smem.allocate(struct_type)` | Allocate struct in SMEM |

Tensor Memory (SM100 only)


| API | Purpose |
|---|---|
| `cutlass.utils.TmemAllocator(...)` | Create TMEM allocator |
| `tmem.allocate(num_cols)` | Allocate columns |
| `tmem.wait_for_alloc()` | Wait for allocation |
| `tmem.retrieve_ptr(dtype)` | Get pointer |
| `tmem.free(ptr)` | Free memory |

Synchronization


| API | Purpose |
|---|---|
| `cute.arch.mbar.init(barrier_ptr, count)` | Initialize barrier |
| `cute.arch.mbar.arrive(barrier_ptr)` | Arrive at barrier |
| `cute.arch.mbar.wait(barrier_ptr, phase)` | Wait at barrier |
| `cute.arch.mbar.arrive_and_expect_tx(barrier_ptr, bytes)` | Arrive with expected TX bytes |
| `cute.arch.mbar.try_wait(barrier_ptr, phase)` | Non-blocking wait |
| Pipeline classes (see references) | Multi-stage async pipelines |
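The arrive/wait pairing with a phase bit can be modeled in plain Python (a conceptual sketch of mbarrier semantics, not the `cute.arch.mbar` API): the barrier flips its phase once the expected number of arrivals lands, and a waiter compares against the phase value it captured before waiting.

```python
class ToyMbar:
    """Conceptual model of an mbarrier: init(count), arrive(), phase-based wait."""
    def __init__(self, count):
        self.count = count    # arrivals needed to complete the barrier
        self.pending = count
        self.phase = 0        # flips 0 -> 1 -> 0 ... on each completion

    def arrive(self):
        self.pending -= 1
        if self.pending == 0:
            self.phase ^= 1           # barrier completes: flip the phase
            self.pending = self.count # re-arm for the next round

    def ready(self, phase):
        # try_wait-style check: has the barrier moved past the captured phase?
        return self.phase != phase

bar = ToyMbar(count=2)
p = bar.phase              # capture the phase BEFORE waiting
bar.arrive()
assert not bar.ready(p)    # only 1 of 2 arrivals: a waiter would still block
bar.arrive()
assert bar.ready(p)        # both arrived: phase flipped, wait(p) would return
```

This is why `wait` takes an explicit `phase` argument: it distinguishes "this round already completed" from "a later round completed".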

Math (Element-wise on TensorSSA)


| API | Purpose | API | Purpose |
|---|---|---|---|
| `cute.exp(x)` | e^x | `cute.exp2(x)` | 2^x |
| `cute.log(x)` | ln(x) | `cute.log2(x)` | log2(x) |
| `cute.sqrt(x)` | sqrt(x) | `cute.rsqrt(x)` | 1/sqrt(x) |
| `cute.sin(x)` | sin(x) | `cute.cos(x)` | cos(x) |
| `cute.tanh(x)` | tanh(x) | `cute.erf(x)` | erf(x) |

All math functions accept `fastmath=True` for approximate hardware intrinsics.

Tensor Operations


| API | Purpose |
|---|---|
| `cute.full(shape, val, dtype)` | Fill with value |
| `cute.zeros_like(tensor)` | Zero tensor matching shape |
| `cute.where(cond, x, y)` | Conditional select |
| `cute.any_(tensor)` | Logical any |
| `cute.all_(tensor)` | Logical all |
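The element-wise semantics of `where`/`any_`/`all_` can be modeled with plain Python lists (a sketch of the semantics only, not the TensorSSA types):

```python
def where(cond, x, y):
    # Element-wise select: take from x where cond holds, else from y
    return [xi if c else yi for c, xi, yi in zip(cond, x, y)]

mask = [True, False, True]
assert where(mask, [1, 2, 3], [9, 9, 9]) == [1, 9, 3]
assert any(mask) is True    # cute.any_: logical OR-reduction
assert all(mask) is False   # cute.all_: logical AND-reduction
```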

Debugging


| API | Purpose |
|---|---|
| `cute.printf(fmt, ...)` | Device-side printf |
| `cute.print_tensor(tensor)` | Print tensor contents |
| `print(...)` | Compile-time print (meta-stage only) |


Key Constraints


  1. `@cute.kernel` function name must match module filename — same rule as cuTile
  2. TMA descriptors are host-side only — create in `@cute.jit`, pass to `@cute.kernel`
  3. Register budget: 255 max per thread — validate against the `docs/devices/` specs
  4. SMEM limits vary by device and SM version — SM100 (B200/B300): 228 KB/SM, 227 KB/block opt-in; SM120 (RTX 5090): 96 KB
  5. Pipeline stages consume SMEM — each stage needs its own buffer; validate the total
  6. Barrier sync errors are silent — they cause incorrect results, not crashes; always test with `--check`
  7. Cluster dimensions must be compatible with the grid — the cluster shape must evenly divide the grid
  8. `print()` is compile-time only — use `cute.printf()` for device-side output
  9. No early-exit breaks in loops — use predication instead
  10. 32-bit layout algebra — shapes/strides are limited to 32-bit integers
  11. Architecture support: Ampere (SM80) and above; SM100 (B200/B300) for full features including `tcgen05` and TMEM
  12. Architecture-specific kernels: some kernels target a specific SM version (e.g., `tcgen05` MMA ops require SM100, not just "Blackwell"). The "Blackwell" marketing name spans multiple SM versions with different instruction sets:
    • SM100 (B200, B300) — datacenter GPUs; supports `tcgen05` MMA, TMEM, full cluster features
    • SM120 (GeForce RTX 5090, RTX 5080) — consumer Blackwell; uses `sm_120a`, does not support `tcgen05` ops
    Always check the device before running, compiling, or profiling: `nvidia-smi --query-gpu=name --format=csv,noheader`. Then match against the kernel's target SM version — the GPU name alone is not sufficient; you must know which SM version it maps to. If the SM version does not match, skip the run/profiling step rather than attempting execution that will fail or produce misleading results.
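The device check in constraint 12 can be scripted. The name-to-SM mapping below covers only the examples named above; treating everything else as unknown (and the exact substrings matched) is an assumption, so extend the `case` for your own fleet:

```shell
# Query the GPU name and map it to an SM version before picking a kernel
gpu_name=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -n1)

case "$gpu_name" in
  *B200*|*B300*)            sm=100 ;;  # datacenter Blackwell: tcgen05 + TMEM
  *"RTX 509"*|*"RTX 508"*)  sm=120 ;;  # consumer Blackwell: no tcgen05
  *)                        sm=unknown ;;  # look up the SM version explicitly
esac

# Skip rather than run a kernel whose target SM does not match
if [ "$sm" != "100" ]; then
  echo "skipping tcgen05 kernel: $gpu_name (SM $sm)" >&2
fi
```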

Control Flow


| Construct | Behavior | Notes |
|---|---|---|
| `for i in range(N)` | Unrolled if N is static | Standard Python range |
| `for i in cutlass.range(N)` | Runtime loop | Generates MLIR loop |
| `for i in cutlass.range_constexpr(N)` | Compile-time unroll | N must be static |
| `if cutlass.const_expr(cond)` | Compile-time branch | Eliminated at compile time |
| `if cond` | Runtime branch | Both branches must type-check |
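Why a static-bound loop unrolls follows from the meta-stage of the compilation pipeline: the Python interpreter executes the loop while tracing, so the generated kernel sees N straight-line copies of the body, not a loop. A plain-Python model of that behavior (the op-recording here is hypothetical, purely to illustrate):

```python
# Conceptual model of compile-time unrolling: the "compiler" is the Python
# interpreter running the loop at trace time and recording ops.
trace = []
N = 4  # static bound, known at compile time
for i in range(N):            # behaves like cutlass.range_constexpr(N)
    trace.append(("mma", i))  # each iteration emits one straight-line op

# Four unrolled ops; no loop construct survives into the trace
assert trace == [("mma", 0), ("mma", 1), ("mma", 2), ("mma", 3)]
```

With a runtime bound, the interpreter cannot run the loop at trace time, which is what `cutlass.range(N)` exists for: it emits a single MLIR loop instead.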


Blackwell Datacenter MMA Operations (tcgen05) — SM100 only


SM100 required. `tcgen05` ops are available on B200/B300 (SM100) only. They are not available on consumer Blackwell GPUs like the RTX 5090 (SM120/`sm_120a`).

```python
from cutlass.cute.nvgpu import tcgen05
```

Create MMA operation for SM100 (B200/B300) tensor cores


```python
op = tcgen05.MmaF16BF16Op(
    dtype=cutlass.Float16,           # Input type
    acc_dtype=cutlass.Float32,       # Accumulator type
    shape=(128, 128, 64),            # M x N x K tile shape
    cta_group=tcgen05.CtaGroup.ONE,  # ONE or TWO CTAs
)
tiled_mma = cute.make_tiled_mma(op)
```

| Operation | Input types | Shapes |
|-----------|-------------|--------|
| `MmaF16BF16Op` | FP16/BF16 → FP32 | Various M×N×K |
| `MmaF8F6F4Op` | FP8/FP6/FP4 → FP32 | Narrow precision |
| Block-scaled variants | With scale factors | See examples |

---

Common Patterns


GEMM with TMA Pipeline (Simplified)


```python
@cute.kernel
def gemm_kernel(tiled_mma, tma_a, tma_b, smem_a, smem_b, gmem_c, ...):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, bidy, _ = cute.arch.block_idx()

    # Get this thread's MMA partition
    thr_mma = tiled_mma.get_slice(tidx)

    # Allocate the accumulator in registers (shape first, then the fragment)
    acc_shape = tiled_mma.partition_shape_C(tile_shape)
    tCrC = cute.make_fragment(acc_shape, cutlass.Float32)

    # K-loop with TMA pipeline
    for k_tile in range(num_k_tiles):
        # TMA copy: global → shared (async)
        cute.copy(tma_a, gA_tile, sA_tile, tma_bar_ptr=barrier)
        cute.copy(tma_b, gB_tile, sB_tile, tma_bar_ptr=barrier)

        # Wait for TMA to complete
        cute.arch.mbar.wait(barrier, phase)

        # Partition shared memory for this thread
        tCsA = thr_mma.partition_A(sA)
        tCsB = thr_mma.partition_B(sB)

        # MMA: accumulate into registers
        cute.gemm(tiled_mma, tCrC, tCsA, tCsB, tCrC)

    # Epilogue: write back to global memory
    cute.basic_copy(tCrC, gC_partition)
```

Warpgroup Specialization (Producer/Consumer)


```python
warp_idx = cute.arch.warp_idx()
is_producer = warp_idx < num_producer_warps

if is_producer:
    # Issue TMA copies for upcoming pipeline stages
    for stage in range(num_stages):
        cute.copy(tma_atom, src, dst, tma_bar_ptr=barriers[stage])
        cute.arch.mbar.arrive_and_expect_tx(barriers[stage], bytes_per_stage)
else:
    # Consume data from shared memory, execute MMA
    for stage in range(num_stages):
        cute.arch.mbar.wait(barriers[stage], phase)
        cute.gemm(tiled_mma, acc, sA_stage, sB_stage, acc)
```

Persistent Kernel Loop


```python
@cute.kernel
def persistent_kernel(tiled_mma, num_tiles, ...):
    bidx, _, _ = cute.arch.block_idx()
    num_blocks = cute.arch.grid_dim()[0]

    tile_id = bidx
    while tile_id < num_tiles:
        # Process tile
        # ... TMA copy, MMA, epilogue ...
        tile_id += num_blocks
```


Detailed References


  • Architecture operations (thread indexing, TMA, MMA atoms, barriers, SMEM/TMEM): see references/architecture-ops.md
  • Official documentation index: see references/docs-index.md
  • Example kernels: see the /learn-cute-dsl skill for a categorized example source index