CuTe Python DSL Reference

Execution Model

CuTe Python DSL is the Python surface for NVIDIA's CuTe layout algebra. Unlike cuTile's block-level abstraction, CuTe DSL exposes explicit thread/warp/warpgroup control, TMA pipelines, barrier choreography, and shared memory management.

Two-Level Host/Device Pattern
Every CuTe DSL kernel has two functions:
- host function (`@cute.jit`) — runs on CPU, sets up TMA descriptors, computes the grid, allocates shared memory, and launches the kernel
- device function (`@cute.kernel`) — runs on GPU and contains the actual computation

```python
import cutlass
import cutlass.cute as cute

@cute.kernel
def my_kernel(tiled_mma: cute.TiledMma, ...):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, bidy, _ = cute.arch.block_idx()
    # ... GPU code

@cute.jit
def host_fn(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor):
    # Setup TMA descriptors, compute grid, allocate SMEM
    my_kernel(...).launch(grid=grid_shape, block=block_shape, smem=smem_bytes)
```

Compilation Pipeline
Three-stage JIT compilation:
- Pre-Staging: Python AST rewriting
- Meta-Stage: Python interpreter executes meta-programs (compile-time constants resolved)
- Object-Stage: Compiler backend generates PTX → SASS
JIT artifacts are cached automatically. Control caching with `CUTE_DSL_CACHE_DIR` and `CUTE_DSL_DISABLE_FILE_CACHING`.

Key Difference from cuTile
| Aspect | cuTile | CuTe DSL |
|---|---|---|
| Abstraction level | Block-level (no thread identity) | Thread/warp/warpgroup level |
| TMA | Implicit | Explicit descriptor setup in host |
| Shared memory | Compiler-managed | Explicit allocation and layout |
| Synchronization | None exposed | Barriers, arrive/wait, named barriers |
| Pipeline | None exposed | Explicit multi-stage async pipelines |
| MMA | Operates on tiles | Operates on partitioned fragments |
| Compilation | cuTile bytecode → MLIR | CuTe DSL JIT → PTX → SASS |
Core API Table
Decorators and Compilation

| API | Purpose | Key params |
|---|---|---|
| `@cute.kernel` | Device kernel decorator | Function name must match module filename |
| `@cute.jit` | Host JIT function decorator | — |
| `kernel(...).launch(...)` | Launch kernel | `grid`, `block`, `smem` |
| `cute.compile()` | Ahead-of-time compile | — |
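The host function's main job before `.launch(...)` is sizing the grid from the problem shape. A minimal plain-Python sketch of that arithmetic (the tile sizes and `make_grid` helper are hypothetical, not DSL API):

```python
# Host-side grid sizing for a tiled kernel, sketched in plain Python.
# ceil_div mirrors cute.ceil_div; tile sizes here are hypothetical.
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)

def make_grid(M: int, N: int, tile_m: int, tile_n: int):
    # One block per output tile; partial edge tiles still get a block.
    return (ceil_div(M, tile_m), ceil_div(N, tile_n), 1)

grid = make_grid(1000, 512, 128, 128)
assert grid == (8, 4, 1)  # edge tile in M rounds up from 7.8 to 8
```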
Thread/Block/Grid Indexing
| API | Purpose |
|---|---|
| `cute.arch.thread_idx()` | Returns `(tidx, tidy, tidz)` |
| `cute.arch.block_idx()` | Returns `(bidx, bidy, bidz)` |
| `cute.arch.block_dim()` | Block dimensions |
| `cute.arch.grid_dim()` | Grid dimensions |
| `cute.arch.warp_idx()` | Warp index within block |
| `cute.arch.block_idx_in_cluster()` | Block index within cluster (Hopper+) |
Layout and Tensor Construction
| API | Purpose | Notes |
|---|---|---|
| `cute.make_layout(shape, stride=...)` | Create layout | Core abstraction: maps coordinates → indices |
| `cute.make_ordered_layout(shape, order=...)` | Create layout with stride order | — |
| `cute.make_identity_layout(shape)` | Identity layout | For predication |
| `cute.make_tensor(ptr, layout)` | Create tensor view | Pairs pointer + layout |
| `cute.make_fragment(shape, dtype)` | Allocate register tensor | For accumulators |
| `cute.make_fragment_like(src)` | Register tensor matching source | — |
| `cute.make_ptr(...)` | Tagged pointer | — |
| `from_dlpack(t)` | Convert PyTorch/JAX tensor | Returns `cute.Tensor` |
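The "maps coordinates → indices" note is the whole idea of a CuTe layout: a shape plus strides, with the linear index computed as a dot product of coordinate and stride. A plain-Python model of that math (this is the concept, not the DSL API):

```python
# Plain-Python model of a CuTe layout: (shape, stride) maps a coordinate
# to a linear index via a dot product. NOT the DSL API, just the math.
def layout_index(coord, stride):
    return sum(c * s for c, s in zip(coord, stride))

# A (4, 8) column-major layout has strides (1, 4): rows are contiguous.
# A (4, 8) row-major layout has strides (8, 1): columns are contiguous.
assert layout_index((2, 3), (1, 4)) == 14  # 2*1 + 3*4
assert layout_index((2, 3), (8, 1)) == 19  # 2*8 + 3*1
```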
Layout Algebra
| API | Purpose |
|---|---|
| `cute.size(x, mode=[...])` | Total size or modal size |
| `x.shape` | Get shape tuple |
| `x.stride` | Get stride tuple |
| `cute.rank(x)` | Number of modes |
| `cute.local_tile(tensor, tiler, coord)` | Extract block's tile |
| `cute.zipped_divide(layout, tiler)` | Divide into tile + rest |
| `cute.composition(a, b)` | Compose layouts |
| `cute.complement(layout, size)` | Complement layout |
| `cute.coalesce(layout)` | Simplify layout |
| `cute.flatten(x)` | Flatten tensor/layout |
| `cute.group_modes(x, begin, end)` | Group modes together |
| `cute.tile_to_shape(layout, shape, order)` | Tile to target shape |
| `cute.ceil_div(a, b)` | Ceiling division |
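The "tile + rest" split from a tiled divide has a simple coordinate interpretation: each index decomposes into a position inside the tile and a tile-grid coordinate. A plain-Python sketch of that index arithmetic (illustrative, not the DSL API):

```python
# Conceptual sketch of a tiled divide: split an (M, N) index space into a
# tile-local coordinate plus a tile-grid ("rest") coordinate.
def divide_coord(i, j, tile_m, tile_n):
    tile_coord = (i % tile_m, j % tile_n)    # position inside the tile
    rest_coord = (i // tile_m, j // tile_n)  # which tile
    return tile_coord, rest_coord

tile, rest = divide_coord(130, 70, 128, 64)
assert tile == (2, 6)  # element (130, 70) sits 2 rows, 6 cols into its tile
assert rest == (1, 1)  # ... which is the second tile in each direction
```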
MMA Operations
| API | Purpose | Notes |
|---|---|---|
| `cute.make_tiled_mma(op)` | Create tiled MMA from operation | See tcgen05 ops below |
| `tiled_mma.get_slice(tidx)` | Get thread's partition | Returns thread-sliced MMA |
| `thr_mma.partition_A(tensor)` | Partition A operand | Thread-level view |
| `thr_mma.partition_B(tensor)` | Partition B operand | Thread-level view |
| `thr_mma.partition_C(tensor)` | Partition C operand | Thread-level view |
| `tiled_mma.partition_shape_C(tile_shape)` | Shape of C partition | For fragment allocation |
| `thr_mma.make_fragment_A(...)` | Make A fragment | Register allocation |
| `thr_mma.make_fragment_B(...)` | Make B fragment | Register allocation |
| `thr_mma.make_fragment_C(...)` | Make C fragment | Register allocation |
| `cute.gemm(tiled_mma, d, a, b, c)` | Execute MMA: d = a @ b + c | Dispatches to hardware MMA |
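The `d = a @ b + c` contract of an MMA can be stated as a tiny reference implementation. This plain-Python version is only the mathematical semantics; the DSL dispatches the same computation to tensor cores:

```python
# Reference semantics of one MMA tile: d = a @ b + c on nested lists.
def mma_ref(a, b, c):
    M, K, N = len(a), len(b), len(b[0])
    return [[c[i][j] + sum(a[i][k] * b[k][j] for k in range(K))
             for j in range(N)] for i in range(M)]

a = [[1, 2], [3, 4]]  # 2x2 A tile
b = [[5, 6], [7, 8]]  # 2x2 B tile
c = [[1, 0], [0, 1]]  # accumulator
assert mma_ref(a, b, c) == [[20, 22], [43, 51]]
```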
Copy and TMA Operations
| API | Purpose | Notes |
|---|---|---|
| `cute.make_tiled_copy(...)` | Create tiled copy | For bulk data movement |
| `cute.copy(atom, src, dst)` | Execute copy | — |
| `cute.basic_copy(src, dst)` | Element-wise copy | No atom needed |
| `cpasync.prefetch_descriptor(tma_atom)` | Prefetch TMA descriptor | — |
| `cpasync.make_tiled_tma_atom_A(...)` | Create TMA atom for A | Host-side only |
| `cpasync.make_tiled_tma_atom_B(...)` | Create TMA atom for B | Host-side only |
Shared Memory
| API | Purpose |
|---|---|
| `cutlass.utils.SmemAllocator()` | Create SMEM allocator |
| `allocator.allocate_tensor(dtype, layout, byte_alignment)` | Allocate tensor in SMEM |
| `allocator.allocate(struct)` | Allocate struct in SMEM |
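Because pipeline stages each need their own SMEM buffers (see Key Constraints below), it pays to validate the total on the host before launching. A plain-Python budget check, assuming FP16 A/B stage buffers and the SM100 227 KB per-block opt-in limit; tile shapes are hypothetical:

```python
# Sketch: validate that a multi-stage pipeline's SMEM buffers fit the
# per-block limit before launching. Tile shapes are hypothetical.
def smem_bytes(tile_m, tile_n, tile_k, dtype_bytes, stages):
    a_stage = tile_m * tile_k * dtype_bytes  # one A stage buffer
    b_stage = tile_n * tile_k * dtype_bytes  # one B stage buffer
    return (a_stage + b_stage) * stages

SM100_SMEM_PER_BLOCK = 227 * 1024  # opt-in per-block limit on SM100
need = smem_bytes(128, 128, 64, 2, 4)  # FP16, 4 pipeline stages
assert need == 131072                  # 128 KB total
assert need <= SM100_SMEM_PER_BLOCK    # fits; otherwise reduce stages
```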
Tensor Memory (SM100 only)
| API | Purpose |
|---|---|
| | Create TMEM allocator |
| | Allocate columns |
| | Wait for allocation |
| | Get pointer |
| | Free memory |
Synchronization
| API | Purpose |
|---|---|
| `cute.arch.mbar.init(bar, count)` | Initialize barrier |
| `cute.arch.mbar.arrive(bar)` | Arrive at barrier |
| `cute.arch.mbar.wait(bar, phase)` | Wait at barrier |
| `cute.arch.mbar.arrive_and_expect_tx(bar, bytes)` | Arrive with expected TX bytes |
| `cute.arch.mbar.try_wait(bar, phase)` | Non-blocking wait |
| Pipeline classes (see references) | Multi-stage async pipelines |
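The arrive/wait calls above follow mbarrier phase semantics: a waiter passes the phase it last observed, and proceeds once the barrier's phase bit has flipped past it. A plain-Python model of that protocol (illustrative only, not the hardware API):

```python
# Plain-Python model of mbarrier phase semantics: the phase bit flips each
# time the expected number of arrivals completes, releasing waiters.
class MbarModel:
    def __init__(self, arrive_count):
        self.expected = arrive_count
        self.pending = arrive_count
        self.phase = 0

    def arrive(self):
        self.pending -= 1
        if self.pending == 0:
            self.phase ^= 1               # flip phase: waiters may proceed
            self.pending = self.expected  # re-arm for the next round

    def try_wait(self, phase):
        # Non-blocking: True once the barrier has moved past `phase`.
        return self.phase != phase

bar = MbarModel(arrive_count=2)
assert not bar.try_wait(phase=0)  # nobody has arrived yet
bar.arrive()
assert not bar.try_wait(phase=0)  # only one of two arrivals
bar.arrive()
assert bar.try_wait(phase=0)      # phase flipped; wait(0) would return
```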
Math (Element-wise on TensorSSA)
| API | Purpose | API | Purpose |
|---|---|---|---|
| `cute.math.exp` | e^x | `cute.math.exp2` | 2^x |
| `cute.math.log` | ln(x) | `cute.math.log2` | log2(x) |
| `cute.math.sqrt` | sqrt(x) | `cute.math.rsqrt` | 1/sqrt(x) |
| `cute.math.sin` | sin(x) | `cute.math.cos` | cos(x) |
| `cute.math.tanh` | tanh(x) | `cute.math.erf` | erf(x) |

All math functions accept `fastmath=True` for approximate hardware intrinsics.

Tensor Operations
| API | Purpose |
|---|---|
| `tensor.fill(value)` | Fill with value |
| `cute.full_like(t, 0)` | Zero tensor matching shape |
| `cute.where(cond, a, b)` | Conditional select |
| `cute.any_(t)` | Logical any |
| `cute.all_(t)` | Logical all |
Debugging
| API | Purpose |
|---|---|
| `cute.printf(...)` | Device-side printf |
| `cute.print_tensor(t)` | Print tensor contents |
| `print()` | Compile-time print (meta-stage only) |
Key Constraints
- `@cute.kernel` function name must match module filename — same rule as cuTile
- TMA descriptors are host-side only — create in `@cute.jit`, pass to `@cute.kernel`
- Register budget: 255 max per thread — validate with `docs/devices/` specs
- SMEM limits vary by device and SM version — SM100 (B200/B300): 228 KB/SM, 227 KB/block opt-in; SM120 (RTX 5090): 96 KB
- Pipeline stages consume SMEM — each stage needs its own buffer; validate the total
- Barrier sync errors are silent — they cause incorrect results, not crashes; always test with `--check`
- Cluster dimensions must be compatible with the grid — the cluster shape must evenly divide the grid
- `print()` is compile-time only — use `cute.printf()` for device-side output
- No early-exit breaks in loops — use predication instead
- 32-bit layout algebra — shapes/strides are limited to 32-bit integers
- Architecture support: Ampere (SM80) and above; SM100 (B200/B300) for full features including `tcgen05` and TMEM
- Architecture-specific kernels: some kernels target a specific SM version (e.g., `tcgen05` MMA ops require SM100, not just "Blackwell"). The "Blackwell" marketing name spans multiple SM versions with different instruction sets:
  - SM100 (B200, B300) — datacenter GPUs, support `tcgen05` MMA, TMEM, full cluster features
  - SM120 (GeForce RTX 5090, RTX 5080) — consumer Blackwell, uses `sm_120a` ops, does not support `tcgen05`

  Always check the device before running, compiling, or profiling: `nvidia-smi --query-gpu=name --format=csv,noheader`. Then match against the kernel's target SM version — the GPU name alone is not sufficient; you must know which SM version it maps to. If the SM version does not match, skip the run/profiling step rather than attempting execution that will fail or produce misleading results.
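The device check above can be scripted. A sketch of the gating logic, where the name-to-SM table is a small illustrative subset (not an exhaustive or authoritative mapping), and the name would come from `nvidia-smi` in practice:

```python
# Sketch: map the reported GPU name to an SM version and gate tcgen05
# kernels on it. NAME_TO_SM is an illustrative subset, not exhaustive.
NAME_TO_SM = {
    "NVIDIA B200": 100,
    "NVIDIA B300": 100,
    "NVIDIA GeForce RTX 5090": 120,
    "NVIDIA GeForce RTX 5080": 120,
}

def can_run_tcgen05(gpu_name: str) -> bool:
    sm = NAME_TO_SM.get(gpu_name)  # in practice: parse nvidia-smi output
    return sm == 100               # tcgen05 requires SM100 specifically

assert can_run_tcgen05("NVIDIA B200")
assert not can_run_tcgen05("NVIDIA GeForce RTX 5090")  # SM120: skip, don't run
```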
Control Flow
| Construct | Behavior | Notes |
|---|---|---|
| `range(N)` | Unrolled if N is static | Standard Python range |
| `cutlass.range_dynamic(N)` | Runtime loop | Generates MLIR loop |
| `cutlass.range_constexpr(N)` | Compile-time unroll | N must be static |
| `if cutlass.const_expr(cond):` | Compile-time branch | Untaken branch eliminated at compile time |
| `if cutlass.dynamic_expr(cond):` | Runtime branch | Both branches must type-check |
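The static/dynamic distinction above is a meta-stage property: a loop over a bound known at compile time can be expanded into straight-line operations, while a dynamic bound must remain a loop. A plain-Python model of that unrolling intuition (illustrative only, not the DSL's MLIR lowering):

```python
# Meta-stage intuition: when n is known at meta-time, emit one op per
# iteration so no loop (and no branch) survives into the generated code.
def unroll_static(n):
    return [f"op_{i}" for i in range(n)]

ops = unroll_static(4)
assert ops == ["op_0", "op_1", "op_2", "op_3"]  # straight-line op sequence
```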
Blackwell Datacenter MMA Operations (tcgen05) — SM100 only
SM100 required. `tcgen05` ops are available on B200/B300 (SM100) only. They are not available on consumer Blackwell GPUs such as the RTX 5090 (SM120/`sm_120a`).

```python
from cutlass.cute.nvgpu import tcgen05

# Create an MMA operation for SM100 (B200/B300) tensor cores
op = tcgen05.MmaF16BF16Op(
    dtype=cutlass.Float16,           # Input type
    acc_dtype=cutlass.Float32,       # Accumulator type
    shape=(128, 128, 64),            # M x N x K tile shape
    cta_group=tcgen05.CtaGroup.ONE,  # ONE or TWO CTAs
)
tiled_mma = cute.make_tiled_mma(op)
```
| Operation | Input types | Shapes |
|-----------|-------------|--------|
| `MmaF16BF16Op` | FP16/BF16 → FP32 | Various M×N×K |
| `MmaF8F6F4Op` | FP8/FP6/FP4 → FP32 | Narrow precision |
| Block-scaled variants | With scale factors | See examples |
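Given the (128, 128, 64) tile of `MmaF16BF16Op` above, the work decomposition for a full GEMM is simple arithmetic: one output tile per (M, N) block and one K-loop step per K tile. A plain-Python sketch with hypothetical problem sizes:

```python
# How many output tiles and K steps a GEMM needs for a (128, 128, 64)
# MMA tile. Problem sizes are hypothetical.
def ceil_div(a, b):
    return -(-a // b)

M, N, K = 4096, 4096, 8192
tile_m, tile_n, tile_k = 128, 128, 64
output_tiles = ceil_div(M, tile_m) * ceil_div(N, tile_n)
k_steps = ceil_div(K, tile_k)
assert output_tiles == 1024  # 32 x 32 blocks in the grid
assert k_steps == 128        # K-loop iterations per output tile
```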
---

Common Patterns
GEMM with TMA Pipeline (Simplified)
```python
@cute.kernel
def gemm_kernel(tiled_mma, tma_a, tma_b, smem_a, smem_b, gmem_c, ...):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, bidy, _ = cute.arch.block_idx()
    # Get thread's MMA partition
    thr_mma = tiled_mma.get_slice(tidx)
    # Allocate accumulator in registers
    acc = tiled_mma.partition_shape_C(tile_shape)
    tCrC = cute.make_fragment(acc, cutlass.Float32)
    # K-loop with TMA pipeline
    for k_tile in range(num_k_tiles):
        # TMA copy: global → shared (async)
        cute.copy(tma_a, gA_tile, sA_tile, tma_bar_ptr=barrier)
        cute.copy(tma_b, gB_tile, sB_tile, tma_bar_ptr=barrier)
        # Wait for TMA to complete
        cute.arch.mbar.wait(barrier, phase)
        # Partition shared memory for this thread
        tCsA = thr_mma.partition_A(sA)
        tCsB = thr_mma.partition_B(sB)
        # MMA: accumulate into registers
        cute.gemm(tiled_mma, tCrC, tCsA, tCsB, tCrC)
    # Epilogue: write back to global memory
    cute.basic_copy(tCrC, gC_partition)
```

Warpgroup Specialization (Producer/Consumer)
```python
warp_idx = cute.arch.warp_idx()
is_producer = warp_idx < num_producer_warps

if is_producer:
    # Issue TMA copies for upcoming pipeline stages
    for stage in range(num_stages):
        cute.copy(tma_atom, src, dst, tma_bar_ptr=barriers[stage])
        cute.arch.mbar.arrive_and_expect_tx(barriers[stage], bytes_per_stage)
else:
    # Consume data from shared memory, execute MMA
    for stage in range(num_stages):
        cute.arch.mbar.wait(barriers[stage], phase)
        cute.gemm(tiled_mma, acc, sA_stage, sB_stage, acc)
```

Persistent Kernel Loop
```python
@cute.kernel
def persistent_kernel(tiled_mma, num_tiles, ...):
    bidx, _, _ = cute.arch.block_idx()
    num_blocks = cute.arch.grid_dim()[0]
    tile_id = bidx
    while tile_id < num_tiles:
        # Process tile
        # ... TMA copy, MMA, epilogue ...
        tile_id += num_blocks
```

Detailed References
- Architecture operations (thread indexing, TMA, MMA atoms, barriers, SMEM/TMEM): see references/architecture-ops.md
- Official documentation index: see references/docs-index.md
- Example kernels: see the `/learn-cute-dsl` skill for a categorized example source index