cuTile Python DSL kernel implementation patterns, CtKernel runtime wrapper, suitability gate, and cuTile-specific pitfalls. Use when: (1) creating or modifying a cuTile Python DSL kernel version, (2) implementing an optimization that still fits within cuTile's exposed control surface, (3) deciding whether cuTile is still the right DSL, (4) reviewing cuTile-specific runtime patterns. Always also load /design-kernel for shared naming, versioning, and workflow.
## Install

```bash
npx skill4agent add pepperu96/hyper-mla design-cutile-dsl-kernel
```

## Scope and suitability gate

This skill covers the cuTile Python DSL (`cutile-dsl`) side only; /design-kernel owns shared naming, versioning, and workflow, so always load both.

cuTile exposes a deliberately small control surface: block indexing via `ct.bid()` and `ct.num_blocks()`, launch hints (`num_ctas`, `occupancy`, `latency`, `allow_tma`), and scheduling wrappers (`KernelPipeline`, `ConcurrentKernels`). If an optimization needs control below that surface (anything finer-grained than `num_ctas` and the hint set), cuTile is no longer the right DSL; switch to /design-cute-dsl-kernel. See docs/kernels/flash-mla.md for a worked example of that decision.

## File layout

The language tag is `cutile-dsl`; the directory name is `cutile`. Kernel versions live at:

```
src/mla_var3/kernel/cutile/<layer>/<design>/<design>[_vN]/<design>[_vN].py
```

## CtKernel runtime wrapper

Each version's plan builds a `CtKernel` (a `CtKernel(Kernel)` subclass) and wires the tiling config into the launch through two closures, `grid_fn` and `args_fn`:

```python
import math
from dataclasses import dataclass, field

import torch
# `ct` is the cuTile Python DSL module; its import path depends on your
# cuTile install and is omitted here.

from mla_var3.runtime import CtKernel, KernelPlan, Tiling, ConstInt, ConstBool

@ct.kernel
def my_kernel(In, Out, Bm: ConstInt, Bn: ConstInt, EVEN_N: ConstBool):
    bid_x = ct.bid(0)
    # block-level operations: ct.load, ct.mma, ct.store, ct.reshape, ...
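    # What typically follows is sketched below; the exact ct.load / ct.mma /
    # ct.store signatures are assumptions, see /cutile-dsl-ref for the real API:
    #
    #   q = ct.load(...)                      # one (Bm x d) query tile for this block
    #   loop over the KV range in Bn steps:
    #       acc = ct.mma(q_tile, k_tile, acc) # tile-level matmul-accumulate
    #       if not EVEN_N: mask the ragged tail tile
    #   ct.store(...)                         # write this block's output tile
    #
    # EVEN_N is a ct.Constant[bool], so the tail-masking branch is resolved
    # at compile time and disappears entirely when t % Bn == 0.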

@dataclass
class CtMyKernel(KernelPlan):
    b: int = 64    # batch size
    s: int = 1     # query sequence length (1 = decode)
    t: int = 4096  # KV sequence length
    h: int = 128   # head count (default illustrative); used by grid_fn and validate()
    tiling: MyTiling = field(default_factory=MyTiling)

    def plan(self, *inputs) -> CtKernel:
        Out = torch.empty_like(inputs[0])

        def grid_fn(cfg):
            return (math.ceil(self.s / cfg.Bm), math.ceil(self.h / cfg.Bh), self.b)

        def args_fn(cfg):
            return (inputs[0], Out, cfg.Bm, cfg.Bn, (self.t % cfg.Bn) == 0)

        return CtKernel(
            input_tensors=inputs,
            output_tensors=(Out,),
            kernel_fn=my_kernel,
            grid_fn=grid_fn,
            args_fn=args_fn,
            tiling=self.tiling,
            autotune_configs=self._autotune_configs(),
            algorithmic_flops_bytes_fn=self._algorithmic_flops_bytes,
        )
```

`CtKernel` fields:

| Field | Type | Purpose |
|---|---|---|
| `kernel_fn` | callable | The cuTile kernel function |
| `grid_fn` | `cfg -> (int, int, int)` | Maps tiling config to 3D grid dimensions |
| `args_fn` | `cfg -> tuple` | Maps tiling config to kernel arguments |
| `input_tensors` | tuple of tensors | For autotuning cache key |
| `output_tensors` | tuple of tensors | Returned by the launch |
| `tiling` | `Tiling` | Current tiling config |
| `autotune_configs` | sequence of `Tiling` | Search space for autotuning |
| `algorithmic_flops_bytes_fn` | callable | For roofline analysis |
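
To see the wiring end to end, here is a minimal driver sketch. Only `plan()` and `autotune_launch()` (described below) come from this document; the shapes, dtype, and the 576 head dimension are illustrative placeholders.

```python
pd = CtMyKernel(b=8, t=8192)            # s stays at its default of 1 (decode)
x = torch.randn(pd.b, pd.s, pd.h, 576,  # stand-in input, not the real MLA layout
                device="cuda", dtype=torch.bfloat16)
kernel = pd.plan(x)                     # builds the CtKernel wrapper
kernel.autotune_launch()                # sweeps autotune_configs, launches the best
```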
At runtime, `CtKernel` owns the launch path: `autotune_launch()` sweeps the autotune configs, applying per-config launch hints via `cuda.tile_experimental_apply_hints()`. `CtKernel.compile()` wraps `compile_tile()`, exposing the compiled artifact's `.bytecode` and `.mlir` for inspection (e.g. with `cuda-tile-translate`).
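
`plan()` references a `self._autotune_configs()` helper that is not shown above. A plausible sketch, assuming it simply enumerates `MyTiling` candidates (defined in the next section) and keeps the ones that pass `validate()`:

```python
from itertools import product

def _autotune_configs(self):
    # Hypothetical implementation: cross tile-size candidates, drop invalid ones.
    candidates = (
        MyTiling(Bm=bm, Bn=bn, Bh=bh)
        for bm, bn, bh in product((16, 32), (64, 128), (4, 8))
    )
    return [cfg for cfg in candidates if cfg.validate(self)]
```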
## Tiling

Each design defines a `Tiling` subclass that carries the tunable tile sizes plus the cuTile launch hints:

```python
@dataclass
class MyTiling(Tiling):
    Bm: int = 16                   # query tile size
    Bn: int = 64                   # KV tile size
    Bh: int = 8                    # heads per block
    num_ctas: int | None = None    # CGA size (None = auto)
    occupancy: int | None = None   # occupancy hint (None = auto)

    def validate(self, pd: "CtMyKernel") -> bool:
        return self.Bm <= pd.s and self.Bn <= pd.t and self.Bh <= pd.h
```

## Compile-time constants

Kernel-visible constants are re-exported by the runtime module:

```python
from mla_var3.runtime import ConstInt, ConstBool, ConstFloat, INV_LOG_2
```

The `Const*` names alias cuTile's compile-time constant types:

```python
ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]
ConstFloat = ct.Constant[float]
```

## References

- /cutile-dsl-ref: API reference for the operations available inside `@ct.kernel`
- docs/devices/ and src/mla_var3/conf/devices.json: per-device tables (e.g. valid `num_ctas` values)
- docs/knowledge/optimizations/, docs/knowledge/anti-patterns/, docs/knowledge/languages/cutile-dsl/: shared knowledge notes
- /optimization-catalog-cutile-dsl: the cuTile optimization catalog
- /design-kernel: shared naming, versioning, and workflow (always load alongside this skill)
- docs/kernels/flash-mla.md: reference kernel write-up
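
One closing note on `INV_LOG_2` from the constants import above: flash-attention-style kernels typically replace `exp(x)` with the cheaper `exp2`, and this constant carries the base change. A sketch, assuming `INV_LOG_2` is 1/ln 2 (i.e. log2 e, approximately 1.4426950408889634):

```python
import math

INV_LOG_2 = 1.0 / math.log(2.0)  # assumption about the runtime module's value

x = -0.7
# exp(x) == 2 ** (x * log2(e)); the identity behind the exp -> exp2 rewrite
assert abs(math.exp(x) - 2.0 ** (x * INV_LOG_2)) < 1e-12
```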