kernel-triton-writing

ONLY for OpenAI Triton (@triton.jit) kernel development. NEVER use for CUDA C++ kernels, TileIR, or profiling tools (ncu, nsys). The user's request must involve Triton explicitly. Covers Triton-specific patterns: fused elementwise, reductions (softmax, LayerNorm, RMSNorm), tiled GEMM with triton.autotune, and flash attention. Workflow: design, write, verify (with fast-path for explicit requests).

NPX Install

npx skill4agent add nvidia/skills kernel-triton-writing

Triton Kernel Writing

Principles

Correctness First

  1. Never benchmark before verification passes.
  2. Always mask loads and stores for non-divisible shapes.
  3. Include kernel_fn, reference_fn, and get_inputs() exports for companion scripts.
  4. Always run scripts/verify_kernel.py to validate against the reference.

FP16/BF16 Precision Rules (LOW FREEDOM -- follow exactly)

Transcendental functions (tl.exp, tl.log, tl.math.erf, tl.math.tanh) require fp32 inputs.
python
# WRONG -- compilation error or wrong results with fp16/bf16:
result = tl.exp(x)

# CORRECT -- cast to fp32, compute, cast back:
x_fp32 = x.to(tl.float32)
result = tl.exp(x_fp32).to(x.dtype)
Rule: any math function beyond basic arithmetic (+, -, *, /) requires fp32 cast in, original dtype cast out.
Additional precision constraints:
  • tl.sigmoid() is unavailable in some Triton versions. Use 1.0 / (1.0 + tl.exp(-x_fp32)).
  • Always cast back to x.dtype before tl.store; mismatches cause "Type mismatch, store Float32 to Float16".
  • Unlike PyTorch, Triton does NOT auto-promote fp16/bf16 to fp32 for accumulation. Always use tl.float32 accumulators for tl.dot.
  • TF32 for matmul: On Ampere and newer GPUs (including Hopper), tl.dot uses TF32 by default for fp32 inputs. Do NOT add input_precision="ieee": it is 3-8x slower. TF32 is the correct default. If verification fails due to TF32 precision (~0.01-0.1 abs diff), ensure reference_fn also uses TF32: plain torch.mm with TF32 enabled. Recent PyTorch releases default torch.backends.cuda.matmul.allow_tf32 to False, so the reference may need to enable it explicitly, and it must never set allow_tf32=False.
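The accumulator rule can be illustrated without a GPU. The sketch below is a pure-Python emulation, not Triton code: to_fp16 and accumulate are hypothetical helpers, the struct "e" format rounds to IEEE half precision, and Python's float64 stands in for the tl.float32 accumulator.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def accumulate(values, round_each_step: bool) -> float:
    """Sum values, optionally rounding the running total to fp16 every step."""
    acc = 0.0
    for v in values:
        acc += v
        if round_each_step:
            acc = to_fp16(acc)  # emulates an fp16 accumulator
    return acc


values = [to_fp16(0.01)] * 4096

exact = sum(values)                            # wide reference sum
fp16_acc = accumulate(values, True)            # narrow accumulator
wide_acc = to_fp16(accumulate(values, False))  # wide accumulator, one final cast

# The fp16 accumulator stalls once the spacing between representable
# values exceeds the addend; the wide accumulator does not.
print("exact:", exact)
print("fp16 accumulator:", fp16_acc)
print("wide accumulator, cast at end:", wide_acc)
```

The single cast at the end mirrors the pattern of accumulating in tl.float32 and converting only at the store.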

CPU-GPU Sync Avoidance (LOW FREEDOM)

Never call .item() in kernel wrappers. It forces a CPU-GPU sync (~50-100us per call).
| Pitfall | Fix |
| --- | --- |
| tensor.item() for seed | x.data_ptr() % (2**31) |
| torch.randint(...).item() | Use tensor metadata for pseudo-random seed |
| Allocating output every call | Accept pre-allocated output as parameter |
| Python loops calling kernel | Batch operations |

C Integer Division Semantics (CRITICAL)

Triton uses C semantics (round toward zero) for // and %, NOT Python semantics (round toward negative infinity). This only matters when operands can be negative.
| Expression | Python | Triton/C |
| --- | --- | --- |
| -7 // 2 | -4 | -3 |
| -7 % 2 | 1 | -1 |
Safe pattern: Ensure all index/offset values are non-negative. If negative values are possible, use (idx % BLOCK + BLOCK) % BLOCK.
See references/concepts-semantics.md for full rules and scalar-only exception.
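To make the difference concrete, here is a small pure-Python sketch that emulates C-style truncating division. The helpers c_div, c_mod, and safe_mod are hypothetical names introduced only for this illustration; they are not Triton APIs.

```python
def c_div(a: int, b: int) -> int:
    """Integer division that truncates toward zero, as C does."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q


def c_mod(a: int, b: int) -> int:
    """Remainder matching C: the result takes the sign of the dividend."""
    return a - b * c_div(a, b)


def safe_mod(idx: int, block: int) -> int:
    """The (idx % BLOCK + BLOCK) % BLOCK pattern, evaluated with C semantics."""
    return c_mod(c_mod(idx, block) + block, block)


print(-7 // 2, -7 % 2)              # Python semantics: -4 1
print(c_div(-7, 2), c_mod(-7, 2))   # C semantics: -3 -1
print(safe_mod(-7, 4), -7 % 4)      # both 1: the safe pattern recovers Python's result
```

For a positive BLOCK, the inner remainder lies in (-BLOCK, BLOCK), so adding BLOCK makes it positive before the second remainder is taken.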

Kernel Design Mental Model

  • Parallelization axis: Element-wise kernels parallelize over flattened elements. Row-wise kernels (LayerNorm, softmax) parallelize over rows. Matmul kernels tile in 2D (M, N).
  • Block size: Power-of-2 only (256, 512, 1024, 2048). Start with 1024 for H100, 512 for V100.
  • Memory coalescing: Adjacent threads must access adjacent memory addresses. The compiler handles this automatically from block-level pointer arithmetic.
  • Grid: Use triton.cdiv(n_elements, BLOCK_SIZE). With autotune, grid must be a lambda: lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),).
  • Decorator order: @triton.autotune (outermost) -> @triton.heuristics -> @triton.jit (innermost).
  • reset_to_zero: Required for autotune on kernels that accumulate (e.g., matmul output). Without it, later configs see leftover values from earlier trials.
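The grid arithmetic from the list above can be checked in plain Python. This is a sketch under stated assumptions: cdiv is a local stand-in for triton.cdiv so the snippet runs without Triton, and the meta dictionary is built by hand here, whereas the autotuner supplies it in a real launch.

```python
def cdiv(n: int, d: int) -> int:
    """Ceiling division; a stand-in for triton.cdiv so this runs without Triton."""
    return (n + d - 1) // d


n_elements = 1_000_003  # deliberately not a multiple of any block size

# Same shape as an autotune grid: a callable that returns a 1-tuple.
# Note the trailing comma: (value,) is a tuple, (value) is just an int.
grid = lambda meta: (cdiv(n_elements, meta["BLOCK_SIZE"]),)

for block_size in (256, 512, 1024, 2048):
    (num_programs,) = grid({"BLOCK_SIZE": block_size})
    masked_lanes = num_programs * block_size - n_elements
    print(f"BLOCK_SIZE={block_size}: {num_programs} programs, "
          f"{masked_lanes} masked lanes in the last block")
```

Every block size leaves some lanes past the end of the data, which is why the load and store masks are needed even when the grid itself is computed correctly.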

Workflow

Fast path: If the user explicitly requests a Triton kernel (e.g., "Write a Triton kernel for X", "Implement softmax in Triton"), start at Phase 2. Only use Phase 0-1 when the request is ambiguous about whether Triton is appropriate.

Phase 0: Route the Operator (only for ambiguous requests)

Skip this phase if the user explicitly asks for a Triton kernel. Only use when the request is ambiguous (e.g., "optimize this operation").
Triton wins when 2+ operations can share registers instead of writing/reading global memory. Quick rules:
| Pattern | Decision |
| --- | --- |
| Single element-wise op (relu, sigmoid) | SKIP: PyTorch already optimal |
| Standalone matmul | SKIP: cuBLAS is optimized |
| Standard attention | SKIP: Use FlashAttention |
| Element-wise chain (2+ ops), reduction, matmul + epilogue | USE TRITON |
If SKIP, recommend the alternative and STOP. See references/operator-routing.md for edge cases.

Phase 1: Analyze the Operator (only for ambiguous requests)

From the user's request, identify: (1) operation type, (2) parallelization strategy, (3) input shapes and dtypes.

Phase 2: Design the Kernel

Pick the skeleton below that matches your operation. These skeletons are sufficient for element-wise, reduction, matmul, and fusion kernels; do NOT read reference files for these common patterns. Only consult references/ when implementing uncommon patterns (grouped GEMM, TMA, extern functions) or debugging issues.
Element-wise skeleton (GELU, dropout, fused ops on flat tensors):
python
@triton.jit
def kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # ... compute ...
    tl.store(out_ptr + offsets, result, mask=mask)
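To see what each program instance in the element-wise skeleton does, the following pure-Python emulation walks through the same offset and mask arithmetic on a tiny input. It is an illustration only: emulate_program is a hypothetical helper, the Python loop stands in for the parallel lanes, and GELU is used as an example computation.

```python
import math


def emulate_program(pid, x, out, n_elements, block_size):
    """Emulate one program instance of the element-wise skeleton."""
    offsets = [pid * block_size + i for i in range(block_size)]
    mask = [off < n_elements for off in offsets]
    for off, keep in zip(offsets, mask):
        if keep:  # masked-off lanes neither load nor store
            v = x[off]
            out[off] = 0.5 * v * (1.0 + math.erf(v * 0.7071067811865476))


n_elements, BLOCK_SIZE = 10, 4
x = [0.25 * i - 1.0 for i in range(n_elements)]
out = [None] * n_elements

num_programs = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE
for pid in range(num_programs):
    emulate_program(pid, x, out, n_elements, BLOCK_SIZE)

print(num_programs, out)
```

The last program covers offsets 8 to 11 while only 8 and 9 are valid; without the mask, the emulation would raise IndexError, and the GPU kernel would touch memory out of bounds.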
Row-wise skeleton (softmax, LayerNorm, RMSNorm — one program per row):
python
@triton.jit
def kernel(x_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    x = tl.load(x_ptr + row_idx * n_cols + col_offsets, mask=mask, other=0.0)
    # ... reduce / normalize ...
    tl.store(out_ptr + row_idx * n_cols + col_offsets, result, mask=mask)
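One detail worth checking when adapting the row-wise skeleton is the padding value passed as other. Padding with 0.0 is harmless for sums, but for a softmax the padded lanes take part in the max and the exp-sum. The pure-Python emulation below suggests using -inf for that case; masked_load and softmax_block are hypothetical helpers, not Triton calls.

```python
import math


def masked_load(row, block_size, other):
    """Emulate tl.load(..., mask=col < n_cols, other=other) for one row."""
    return [row[i] if i < len(row) else other for i in range(block_size)]


def softmax_block(x):
    """Numerically stable softmax over a full block, padding lanes included."""
    m = max(x)
    e = [math.exp(v - m) for v in x]
    s = sum(e)
    return [v / s for v in e]


row = [-3.0, -1.0, -2.0, -5.0, -4.0]  # n_cols = 5
BLOCK_SIZE = 8                        # next power of two >= n_cols

# Keep only the lanes the masked store would write.
good = softmax_block(masked_load(row, BLOCK_SIZE, float("-inf")))[: len(row)]
bad = softmax_block(masked_load(row, BLOCK_SIZE, 0.0))[: len(row)]

print(sum(good))  # sums to 1 up to rounding: padding lanes contribute exp(-inf) = 0
print(sum(bad))   # well below 1: each padding lane added exp(0) to the denominator
```

The masked store hides the padding lanes in the output, so this kind of error shows up only as wrong values rather than as a crash.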
Tiled matmul skeleton (GEMM with 2D tiling, grouped ordering, and autotune):
python
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    pid = tl.program_id(0)
    num_m_blocks = tl.cdiv(M, BLOCK_M)
    num_n_blocks = tl.cdiv(N, BLOCK_N)
    # Grouped ordering for L2 cache locality
    num_pid_in_group = GROUP_SIZE_M * num_n_blocks
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_m_blocks - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
        b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
        offs_k += BLOCK_K

    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
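The grouped ordering in the matmul skeleton is easier to follow when the index arithmetic is run on a small grid. The sketch below copies that arithmetic into plain Python; grouped_tile is a hypothetical helper and the grid sizes are arbitrary illustration values.

```python
def grouped_tile(pid, num_m_blocks, num_n_blocks, group_size_m):
    """Map a linear program id to (pid_m, pid_n) using grouped ordering."""
    num_pid_in_group = group_size_m * num_n_blocks
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    # The last group may hold fewer rows when num_m_blocks is not a multiple.
    rows_in_group = min(num_m_blocks - first_pid_m, group_size_m)
    pid_m = first_pid_m + (pid % num_pid_in_group) % rows_in_group
    pid_n = (pid % num_pid_in_group) // rows_in_group
    return pid_m, pid_n


num_m_blocks, num_n_blocks, GROUP_SIZE_M = 5, 3, 2
tiles = [
    grouped_tile(pid, num_m_blocks, num_n_blocks, GROUP_SIZE_M)
    for pid in range(num_m_blocks * num_n_blocks)
]

for pid, tile in enumerate(tiles):
    print(pid, tile)
```

Consecutive program ids walk down a short column of GROUP_SIZE_M row-blocks before moving to the next column, so programs running at about the same time reuse the same few tiles of A and B. Every output tile is still visited exactly once.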

Phase 3: Write the Kernel

Create an output directory, then write the kernel file to {output_dir}/kernel.py.
The kernel file MUST include:
  • @triton.jit decorated kernel function
  • @triton.autotune for production kernels (see references/api-core.md)
  • Python wrapper function (descriptive name for external import)
  • Fixed contract exports (companion scripts rely on these exact names):
    • kernel_fn: alias to the wrapper function
    • reference_fn(*args): PyTorch reference with identical signature
    • get_inputs(): returns list of fresh CUDA tensors for testing/benchmarking
Concise example (fused GELU + dropout):
python
import triton
import triton.language as tl
import torch

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 2048}, num_warps=8),
    ],
    key=['n_elements'],
)
@triton.jit
def fused_gelu_dropout_kernel(
    x_ptr, out_ptr, n_elements, p, seed,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

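    x = tl.load(x_ptr + offsets, mask=mask)
    # Compute in fp32 (erf needs fp32 inputs) and cast back once, before the store
    x_fp32 = x.to(tl.float32)
    y = 0.5 * x_fp32 * (1.0 + tl.math.erf(x_fp32 * 0.7071067811865476))

    # Dropout: keep each element with probability 1 - p and rescale the survivors
    random = tl.rand(seed, offsets)
    y = tl.where(random > p, y / (1.0 - p), 0.0)

    tl.store(out_ptr + offsets, y.to(x.dtype), mask=mask)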


def fused_gelu_dropout_triton(x: torch.Tensor, p: float = 0.1) -> torch.Tensor:
    n_elements = x.numel()
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    seed = (x.data_ptr() % (2**31)) ^ n_elements  # sync-free seed
    fused_gelu_dropout_kernel[grid](x, out, n_elements, p, seed)
    return out


# --- Fixed contract (companion scripts rely on these names) ---
kernel_fn = fused_gelu_dropout_triton

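def reference_fn(x, p=0.1):
    # Note: PyTorch's dropout RNG is not tl.rand, so the dropout mask will not
    # match the kernel's element for element, even with the same seed value.
    torch.manual_seed((x.data_ptr() % (2**31)) ^ x.numel())
    return torch.nn.functional.dropout(
        torch.nn.functional.gelu(x), p, training=True
    )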

def get_inputs():
    return [torch.randn(128 * 1024 * 1024, device="cuda")]
For more patterns (SiLU+mul, RMSNorm, linear+GELU, add+LayerNorm), see references/patterns-fusion.md. For GEMM patterns, see references/patterns-gemm.md.

Phase 4: Verify Correctness

Run the companion verification script:
bash
python scripts/verify_kernel.py {output_dir}/kernel.py --rtol 1e-3 --atol 1e-3
Output:
json
{"correct": true, "max_abs_diff": 1.2e-7, "max_rel_diff": 3.4e-6, "details": "..."}
Stop if correct: false. Fix the kernel before benchmarking.
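A wrapper that automates this gate might look like the sketch below. It assumes the verification script prints a single JSON object shaped like the sample output; should_benchmark is a hypothetical helper, not part of the companion scripts.

```python
import json


def should_benchmark(verify_stdout: str) -> bool:
    """Return True only if the verification report marks the kernel as correct."""
    report = json.loads(verify_stdout)
    return report.get("correct") is True


passing = '{"correct": true, "max_abs_diff": 1.2e-7, "max_rel_diff": 3.4e-6, "details": "..."}'
failing = '{"correct": false, "max_abs_diff": 0.5, "max_rel_diff": 1.0, "details": "..."}'

print(should_benchmark(passing))  # True
print(should_benchmark(failing))  # False
```

Comparing with `is True` rather than relying on truthiness means a missing or malformed "correct" field is treated as a failure.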
Tolerance guide:
| Dtype | rtol | atol | Notes |
| --- | --- | --- | --- |
| float16 | 1e-3 | 1e-3 | |
| bfloat16 | 1e-2 | 1e-2 | |
| float32 | 1e-5 | 1e-5 | Element-wise ops |
| float32 (matmul) | 1e-2 | 1e-1 | TF32 accumulation order differs between Triton tiles and cuBLAS |
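As a rough guide to how rtol and atol combine, the sketch below applies the torch.allclose convention, |actual - expected| <= atol + rtol * |expected|, to scalar values. Whether scripts/verify_kernel.py uses exactly this rule is an assumption; TOLERANCES and within_tolerance are hypothetical names that only mirror the table above.

```python
# (rtol, atol) pairs copied from the tolerance guide
TOLERANCES = {
    "float16": (1e-3, 1e-3),
    "bfloat16": (1e-2, 1e-2),
    "float32": (1e-5, 1e-5),
    "float32_matmul": (1e-2, 1e-1),
}


def within_tolerance(actual: float, expected: float, rtol: float, atol: float) -> bool:
    """Scalar form of the torch.allclose criterion."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


actual, expected = 1.0005, 1.0
for dtype, (rtol, atol) in TOLERANCES.items():
    print(dtype, within_tolerance(actual, expected, rtol, atol))
```

Under this convention atol dominates near zero and rtol dominates for large magnitudes, which fits the looser absolute bound given for matmul outputs.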

Phase 5: Benchmark Performance (optional)

Only benchmark if the user explicitly requests performance numbers. Skip this phase for correctness-focused requests.
bash
python scripts/benchmark_kernel.py {output_dir}/kernel.py
Output:
json
{"kernel_time_ms": 0.45, "reference_time_ms": 1.23, "speedup": 2.73, "warmup_iters": 10, "benchmark_iters": 40}

References (consult only when stuck)

The skeletons and principles above cover element-wise, reduction, matmul, and fusion kernels. Do NOT read reference files for these common patterns.
Only consult references/ when:
  • Implementing uncommon patterns (grouped GEMM, TMA, persistent matmul, extern functions)
  • Debugging a compile error or incorrect result not covered by the error table below
  • Needing API details for an unfamiliar tl.* operation
How to search: Grep for your keyword across references/. Read only the file Grep points to.
| File | When to use |
| --- | --- |
| references/api-core.md | Unfamiliar triton.autotune / triton.Config options |
| references/api-language.md | Unfamiliar tl.* operations |
| references/patterns-gemm.md | Grouped GEMM, persistent matmul, TMA, MX formats |
| references/patterns-advanced.md | Flash attention details, backward passes, libdevice |
| references/troubleshooting.md | Debug ops, interpreter mode, env vars |

Error Handling and Troubleshooting

Common Errors

| Error / Symptom | Cause | Fix |
| --- | --- | --- |
| "Type mismatch, store Float32 to Float16" | Missing .to(x.dtype) before store | Cast fp32 result back |
| BLOCK_SIZE is not a constexpr | Block size passed as runtime value | Add : tl.constexpr annotation |
| shape mismatch in binary op | Tensor shapes don't broadcast | Check with tl.static_print; use [:, None] / [None, :] |
| Large diffs everywhere | Wrong dtype in tl.load | Check load dtype matches input |
| Matmul 3-8x slower than expected | input_precision="ieee" on tl.dot | Remove it; use TF32 default. Ensure reference_fn also uses TF32 |
| Matmul ~0.01-0.1 abs diff vs reference | TF32 vs IEEE mismatch | Use same precision in both kernel and reference (TF32 for both) |
| Diffs at boundaries | Missing mask | Add mask to all load/store ops |
| Random diffs | Race condition | Check atomics and ordering |
| NaN/Inf | Division by zero or fp16 overflow | Guard with epsilon; use tl.float32 accumulator |
| grid must be a tuple | Grid lambda returns int, not tuple | Return (value,) with trailing comma |
| expected constexpr in tl.arange | Non-constexpr argument | Both args of tl.arange(start, end) must be constexpr |
| triton.OutOfResources | Register/shared memory pressure | Reduce BLOCK_SIZE or num_stages |
| Kernel not updating after edit | Stale compilation cache | rm -rf ~/.triton/cache/ |
| Mismatched results vs PyTorch | C integer division semantics | Triton uses truncation; see references/concepts-semantics.md |
For extended error table, interpreter mode issues, and environment variables, see references/troubleshooting.md.

When to Abort

Stop and report failure if:
  1. Not a good fit -- Pure matmul or complex control flow (Phase 0 should catch this).
  2. Verification fails after 3 attempts -- Numerical issues too severe to fix.
  3. No speedup -- Reference is already well-optimized (cuBLAS, cuDNN).
  4. Hardware mismatch -- Target GPU not available for testing.