CuTe Python DSL API reference and implementation patterns for NVIDIA GPU kernel programming. Provides execution model, core API table, key constraints, common patterns, and documentation index. Use when: (1) writing or modifying CuTe DSL kernel code, (2) looking up CuTe DSL API syntax, (3) implementing attention/GEMM/MLA patterns in CuTe DSL, (4) understanding CuTe DSL execution model and compilation pipeline, (5) checking what CuTe DSL can and cannot do.
Install the skill with:

```bash
npx skill4agent add pepperu96/hyper-mla cute-dsl-ref
```

Device code is decorated with `@cute.kernel`; the host-side entry point that sets up descriptors, computes the grid, and launches the kernel is decorated with `@cute.jit`:

```python
import cutlass
import cutlass.cute as cute

@cute.kernel
def my_kernel(tiled_mma: cute.TiledMma, ...):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, bidy, _ = cute.arch.block_idx()
    # ... GPU code

@cute.jit
def host_fn(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor):
    # Setup TMA descriptors, compute grid, allocate SMEM
    my_kernel(...).launch(grid=grid_shape, block=block_shape, smem=smem_bytes)
```

Compiled kernels are cached on disk: `CUTE_DSL_CACHE_DIR` overrides the cache location, and `CUTE_DSL_DISABLE_FILE_CACHING` disables file caching.

| Aspect | cuTile | CuTe DSL |
|---|---|---|
| Abstraction level | Block-level (no thread identity) | Thread/warp/warpgroup level |
| TMA | Implicit | Explicit descriptor setup in host |
| Shared memory | Compiler-managed | Explicit allocation and layout |
| Synchronization | None exposed | Barriers, arrive/wait, named barriers |
| Pipeline | None exposed | Explicit multi-stage async pipelines |
| MMA | Compiler-selected | Explicit MMA atom + tiled MMA setup |
| Compilation | cuTile bytecode → MLIR | CuTe DSL JIT → PTX → SASS |
| API | Purpose | Key params |
|---|---|---|
| `@cute.kernel` | Device kernel decorator | function name must match module filename |
| `@cute.jit` | Host JIT function decorator | — |
| `kernel(...).launch(...)` | Launch kernel | `grid`, `block`, `smem` |
| `cute.compile(fn, *args)` | Ahead-of-time compile | — |
| API | Purpose |
|---|---|
| `cute.arch.thread_idx()` | Returns `(tidx, tidy, tidz)` |
| `cute.arch.block_idx()` | Returns `(bidx, bidy, bidz)` |
| `cute.arch.block_dim()` | Block dimensions |
| `cute.arch.grid_dim()` | Grid dimensions |
| `cute.arch.warp_idx()` | Warp index within block |
| `cute.arch.block_idx_in_cluster()` | Block index within cluster (Hopper+) |
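These intrinsics return coordinate tuples; kernels combine them into flat indices the usual CUDA way. A plain-Python sketch of that arithmetic (an illustration, not DSL code):

```python
def global_index(bidx, block_dim, tidx):
    # Models the common pattern built from cute.arch.block_idx() and
    # cute.arch.thread_idx(): each block covers block_dim consecutive
    # elements, and each thread indexes within its block.
    return bidx * block_dim + tidx

# Block 2 of 128 threads, thread 5 -> element 261
idx = global_index(2, 128, 5)
```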
| API | Purpose | Notes |
|---|---|---|
| `cute.make_layout(shape, stride=...)` | Create layout | Core abstraction: maps coordinates → indices |
| `cute.make_ordered_layout(shape, order=...)` | Create with stride order | e.g., `order=(1, 0)` |
| `cute.make_identity_layout(shape)` | Identity layout | For predication |
| `cute.make_tensor(ptr, layout)` | Create tensor view | Pairs pointer + layout |
| `cute.make_fragment(shape_or_layout, dtype)` | Allocate register tensor | For accumulators |
| `cute.make_fragment_like(src)` | Register tensor matching source | — |
| `cute.make_ptr(dtype, addr, mem_space)` | Tagged pointer | — |
| `from_dlpack(t)` (in `cutlass.cute.runtime`) | Convert PyTorch/JAX tensor | Returns `cute.Tensor` |
| API | Purpose |
|---|---|
| `cute.size(x, mode=[...])` | Total size or modal size |
| `x.shape` | Get shape tuple |
| `x.stride` | Get stride tuple |
| `cute.rank(x)` | Number of modes |
| `cute.local_tile(tensor, tiler, coord)` | Extract block's tile |
| `cute.zipped_divide(tensor, tiler)` | Divide into tile + rest |
| `cute.composition(a, b)` | Compose layouts |
| `cute.complement(layout, size)` | Complement layout |
| `cute.coalesce(layout)` | Simplify layout |
| `cute.flatten(x)` | Flatten tensor/layout |
| `cute.group_modes(x, begin, end)` | Group modes together |
| `cute.tile_to_shape(layout, shape, order)` | Tile to target shape |
| `cute.ceil_div(a, b)` | Ceiling division |
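The core idea behind this algebra is that a layout is a (shape, stride) pair mapping logical coordinates to linear indices. A plain-Python model of that mapping and of the `ceil_div` arithmetic (an illustration, not the DSL implementation):

```python
def layout_index(coord, shape, stride):
    # A CuTe layout (shape, stride) maps a coordinate to a linear index
    # via an inner product with the strides: idx = sum(c_i * d_i).
    assert all(0 <= c < s for c, s in zip(coord, shape))
    return sum(c * d for c, d in zip(coord, stride))

def ceil_div(a, b):
    # Same arithmetic as cute.ceil_div for positive integers.
    return -(-a // b)

# Column-major 8x4 layout, written (8, 4):(1, 8) in CuTe notation:
idx = layout_index((3, 2), (8, 4), (1, 8))   # 3*1 + 2*8 = 19
tiles = ceil_div(100, 32)                     # 4 tiles cover 100 elements
```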
| API | Purpose | Notes |
|---|---|---|
| `cute.make_tiled_mma(op)` | Create tiled MMA from operation | e.g., from a `tcgen05` MMA op |
| `tiled_mma.get_slice(tidx)` | Get thread's partition | Returns thread-level slice |
| `thr_mma.partition_A(t)` | Partition A operand | Thread-level view |
| `thr_mma.partition_B(t)` | Partition B operand | Thread-level view |
| `thr_mma.partition_C(t)` | Partition C operand | Thread-level view |
| `tiled_mma.partition_shape_C(tile_shape)` | Shape of C partition | For fragment allocation |
| `thr_mma.make_fragment_A(t)` | Make A fragment | Register allocation |
| `thr_mma.make_fragment_B(t)` | Make B fragment | Register allocation |
| `thr_mma.make_fragment_C(t)` | Make C fragment | Register allocation |
| `cute.gemm(tiled_mma, d, a, b, c)` | Execute MMA: d = a @ b + c | Dispatches to hardware MMA |
| API | Purpose | Notes |
|---|---|---|
| `cute.make_tiled_copy_tv(atom, thr_layout, val_layout)` | Create tiled copy | For bulk data movement |
| `cute.copy(atom, src, dst)` | Execute copy | `tma_bar_ptr=` for TMA copies |
| `cute.basic_copy(src, dst)` | Element-wise copy | No atom needed |
| `cute.prefetch_descriptor(tma_atom)` | Prefetch TMA descriptor | — |
| `cute.nvgpu.make_tiled_tma_atom_A(...)` | Create TMA atom for A | Host-side only |
| `cute.nvgpu.make_tiled_tma_atom_B(...)` | Create TMA atom for B | Host-side only |
| API | Purpose |
|---|---|
| `cutlass.utils.SmemAllocator()` | Create SMEM allocator |
| `allocator.allocate_tensor(dtype, layout, byte_alignment)` | Allocate tensor in SMEM |
| `allocator.allocate(StructType)` | Allocate struct in SMEM |
TMEM (Blackwell tensor memory) is managed through an allocator API: create the allocator, allocate a number of columns, wait for the allocation to complete, retrieve the pointer, and free the memory when done. See the Blackwell examples in the references for the exact calls.
| API | Purpose |
|---|---|
| `cute.arch.mbar.init(ptr, arrive_count)` | Initialize barrier |
| `cute.arch.mbar.arrive(ptr)` | Arrive at barrier |
| `cute.arch.mbar.wait(ptr, phase)` | Wait at barrier |
| `cute.arch.mbar.arrive_and_expect_tx(ptr, bytes)` | Arrive with expected TX bytes |
| `cute.arch.mbar.try_wait(ptr, phase)` | Non-blocking wait |
| Pipeline classes (see references) | Multi-stage async pipelines |
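The arrive/wait protocol tracks an arrival count and a phase bit: the last expected arrival resets the count and flips the phase, releasing waiters parked on the old phase. A toy Python model of those semantics (not the hardware mbarrier API):

```python
class MBarrier:
    # Toy model of an mbarrier's arrival count and phase bit, to show
    # the protocol behind arrive/wait/try_wait. Not the real API.
    def __init__(self, expected):
        self.expected = expected   # arrivals needed to complete a phase
        self.count = 0
        self.phase = 0

    def arrive(self):
        self.count += 1
        if self.count == self.expected:
            self.count = 0         # last arrival: reset the count...
            self.phase ^= 1        # ...and flip the phase bit

    def try_wait(self, phase):
        # A waiter passes the phase it is waiting on; it proceeds once
        # the barrier has flipped past that phase.
        return self.phase != phase
```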
| API | Purpose | API | Purpose |
|---|---|---|---|
| `cute.math.exp(x)` | e^x | `cute.math.exp2(x)` | 2^x |
| `cute.math.log(x)` | ln(x) | `cute.math.log2(x)` | log2(x) |
| `cute.math.sqrt(x)` | sqrt | `cute.math.rsqrt(x)` | 1/sqrt(x) |
| `cute.math.sin(x)` | sin | `cute.math.cos(x)` | cos |
| `cute.math.tanh(x)` | tanh | `cute.math.erf(x)` | erf |
Most of these accept `fastmath=True` for faster, lower-precision variants.

| API | Purpose |
|---|---|
| `frag.fill(value)` | Fill with value |
| `cute.zeros_like(t)` | Zero tensor matching shape |
| `cute.where(cond, a, b)` | Conditional select |
| `cute.any_(x)` | Logical any |
| `cute.all_(x)` | Logical all |
| API | Purpose |
|---|---|
| `cute.printf(fmt, ...)` | Device-side printf |
| `cute.print_tensor(t)` | Print tensor contents |
| `print(...)` | Compile-time print (meta-stage only) |
Check the target GPU with `nvidia-smi --query-gpu=name --format=csv,noheader` before assuming `tcgen05` support.

| Construct | Behavior | Notes |
|---|---|---|
| `for i in range(N)` | Unrolled if N is static | Standard Python range |
| `for i in cutlass.range_dynamic(N)` | Runtime loop | Generates MLIR loop |
| `for i in cutlass.range_constexpr(N)` | Compile-time unroll | N must be static |
| `if cutlass.const_expr(cond):` | Compile-time branch | Eliminated at compile time |
| `if cond:` (dynamic predicate) | Runtime branch | Both branches must type-check |
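The difference between the static and runtime loop forms shows up in what reaches the generated IR: a static `range(N)` disappears during tracing, while a runtime loop survives as a single loop op. A toy model of that tracing behavior (assuming the runtime form is `cutlass.range_dynamic`; this is not the DSL's actual tracer):

```python
def trace_loop(n, static):
    # Models the two loop forms above: a static range(n) is unrolled at
    # trace time into n copies of the body (no loop left in the IR),
    # while a runtime loop is emitted as one loop op whose trip count
    # is a runtime value.
    if static:
        return [f"body({i})" for i in range(n)]  # n ops, loop eliminated
    return [f"loop(0..{n})"]                      # single loop op in the IR
```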
**SM100 required.** `tcgen05` ops are available on B200/B300 (SM100) only. They are not available on consumer Blackwell GPUs such as the RTX 5090 (SM120/`sm_120a`).
```python
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import tcgen05

# Create MMA operation for SM100 (B200/B300) tensor cores
op = tcgen05.MmaF16BF16Op(
    dtype=cutlass.Float16,            # Input type
    acc_dtype=cutlass.Float32,        # Accumulator type
    shape=(128, 128, 64),             # M x N x K tile shape
    cta_group=tcgen05.CtaGroup.ONE,   # ONE or TWO CTAs
)
tiled_mma = cute.make_tiled_mma(op)
```

| Operation | Input types | Shapes |
|---|---|---|
| `tcgen05.MmaF16BF16Op` | FP16/BF16 → FP32 | Various M×N×K |
| `tcgen05` FP8/FP6/FP4 op | FP8/FP6/FP4 → FP32 | Narrow precision |
| Block-scaled variants | With scale factors | See examples |
```python
@cute.kernel
def gemm_kernel(tiled_mma, tma_a, tma_b, smem_a, smem_b, gmem_c, ...):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, bidy, _ = cute.arch.block_idx()

    # Get thread's MMA partition
    thr_mma = tiled_mma.get_slice(tidx)

    # Allocate accumulator in registers
    acc = tiled_mma.partition_shape_C(tile_shape)
    tCrC = cute.make_fragment(acc, cutlass.Float32)

    # K-loop with TMA pipeline
    for k_tile in range(num_k_tiles):
        # TMA copy: global → shared (async)
        cute.copy(tma_a, gA_tile, sA_tile, tma_bar_ptr=barrier)
        cute.copy(tma_b, gB_tile, sB_tile, tma_bar_ptr=barrier)
        # Wait for TMA to complete
        cute.arch.mbar.wait(barrier, phase)
        # Partition shared memory for this thread
        tCsA = thr_mma.partition_A(sA)
        tCsB = thr_mma.partition_B(sB)
        # MMA: accumulate into registers
        cute.gemm(tiled_mma, tCrC, tCsA, tCsB, tCrC)

    # Epilogue: write back to global memory
    cute.basic_copy(tCrC, gC_partition)
```

Warp-specialized producer/consumer split:

```python
warp_idx = cute.arch.warp_idx()
is_producer = warp_idx < num_producer_warps

if is_producer:
    # Issue TMA copies for upcoming pipeline stages
    for stage in range(num_stages):
        cute.copy(tma_atom, src, dst, tma_bar_ptr=barriers[stage])
        cute.arch.mbar.arrive_and_expect_tx(barriers[stage], bytes_per_stage)
else:
    # Consume data from shared memory, execute MMA
    for stage in range(num_stages):
        cute.arch.mbar.wait(barriers[stage], phase)
        cute.gemm(tiled_mma, acc, sA_stage, sB_stage, acc)
```

Persistent kernel (one block loops over many tiles):

```python
@cute.kernel
def persistent_kernel(tiled_mma, num_tiles, ...):
    bidx, _, _ = cute.arch.block_idx()
    num_blocks = cute.arch.grid_dim()[0]

    tile_id = bidx
    while tile_id < num_tiles:
        # Process tile
        # ... TMA copy, MMA, epilogue ...
        tile_id += num_blocks
```
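The persistent schedule is a grid-stride loop over tiles: block `bidx` takes tiles `bidx`, `bidx + num_blocks`, `bidx + 2*num_blocks`, and so on, until the tile count is exhausted. A plain-Python model of the resulting assignment:

```python
def tiles_for_block(bidx, num_blocks, num_tiles):
    # Grid-stride tile schedule used by the persistent kernel above:
    # block bidx starts at tile bidx and strides by the grid size.
    out = []
    tile_id = bidx
    while tile_id < num_tiles:
        out.append(tile_id)
        tile_id += num_blocks
    return out

# 4 persistent blocks covering 10 tiles: every tile is claimed exactly once.
schedule = [tiles_for_block(b, 4, 10) for b in range(4)]
```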