adding-cutile-kernel
Adding a cuTile Kernel to TileGym
End-to-end workflow for adding a new operator (e.g., `my_op`) with a cuTile backend.

Execution Rules
MUST follow these rules strictly:
- Use TodoWrite to create the checklist below BEFORE writing any code
- Execute steps in order; do NOT skip ahead or combine steps
- Mark each todo as `in_progress` when starting and `completed` after finishing
- If a step is not applicable (e.g., no cuTile impl), mark it `completed` with a note; do NOT silently skip
- Each step MUST result in a file write or an explicit skip decision; no silent omissions
Instructions
MUST copy this checklist to TodoWrite at the start:
- [ ] Step 1: Register dispatch interface in ops.py
- [ ] Step 2: Implement cuTile backend
- [ ] Step 3: Register in __init__.py (cutile)
- [ ] Step 4: Add tests
- [ ] Step 5: Add benchmark to tests/benchmark
- [ ] Step 6: Verify (run pytest + lint)

Step 1: Register dispatch interface
File: `src/tilegym/ops/ops.py`
Add a `@dispatch` function. This is the single entry point for all backends.
```python
@dispatch(
    "my_op",
)
def my_op(
    input: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    **kwargs: Any,
):
    """
    Description of my_op.

    Args:
        input: Input tensor
        out: Optional preallocated output tensor
        **kwargs: Additional arguments for backend-specific configurations

    Returns:
        torch.Tensor
    """
    raise NotImplementedError(f"my_op is not implemented for {get_current_backend()}")
```
Key rules:
- Function body only raises `NotImplementedError`
- Include `**kwargs` for backend-specific parameters
Reference: See existing ops in `src/tilegym/ops/ops.py` (e.g., `silu_and_mul`, `softmax`)
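To make the role of the stub concrete, the following is a minimal, self-contained sketch of how a dispatch decorator and a backend registry could cooperate. It is not TileGym's implementation: the `_IMPLS` registry, the module-level `_current_backend` variable, and the `my_op_cutile` stand-in are assumptions made for illustration, and the real `@dispatch` may behave differently.

```python
import functools
from typing import Any, Callable, Dict, Tuple

# Hypothetical registry: (op name, backend name) -> implementation.
_IMPLS: Dict[Tuple[str, str], Callable[..., Any]] = {}
_current_backend = "cutile"  # hypothetical global backend selection


def get_current_backend() -> str:
    return _current_backend


def register_impl(name: str, backend: str) -> Callable:
    """Record the decorated function as the implementation of `name` for `backend`."""
    def decorator(fn: Callable) -> Callable:
        _IMPLS[(name, backend)] = fn
        return fn
    return decorator


def dispatch(name: str) -> Callable:
    """Route calls to the implementation registered for the current backend."""
    def decorator(stub: Callable) -> Callable:
        @functools.wraps(stub)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            impl = _IMPLS.get((name, get_current_backend()))
            if impl is None:
                # Nothing registered: fall through to the stub, which raises.
                return stub(*args, **kwargs)
            return impl(*args, **kwargs)
        return wrapper
    return decorator


@dispatch("my_op")
def my_op(input, out=None, **kwargs):
    raise NotImplementedError(f"my_op is not implemented for {get_current_backend()}")


# Before any backend registers, the call reaches the stub and raises.
try:
    my_op([1.0, 2.0])
    stub_error = None
except NotImplementedError as exc:
    stub_error = str(exc)


@register_impl("my_op", backend="cutile")
def my_op_cutile(input, out=None, **kwargs):
    return [2.0 * v for v in input]  # stand-in computation, not a real kernel


doubled = my_op([1.0, 2.0])  # now routed to my_op_cutile
```

Under these assumptions, the stub body only ever runs when no backend has registered an implementation, which is why it contains nothing but the `NotImplementedError`.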
Step 2: Implement cuTile backend
File: `src/tilegym/ops/cutile/my_op.py`
The file structure follows this template:
```python
from typing import Optional

import torch
import cuda.tile as ct

from tilegym.backend import register_impl

DEFAULT_BLOCK_SIZE = 1024  # elements handled per block


@ct.kernel
def my_op_kernel_ct(x, output, n_elements: ct.Constant[int], BLOCK_SIZE: ct.Constant[int]):
    bid = ct.bid(0)
    indices = bid * BLOCK_SIZE + ct.arange(0, BLOCK_SIZE)
    x_val = ct.gather(x, indices)
    # ... compute ...
    result = x_val  # placeholder: replace with the actual computation
    ct.scatter(output, indices, result)


@register_impl("my_op", backend="cutile")
def my_op(input: torch.Tensor, out: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
    n = input.numel()
    if out is None:
        out = torch.empty_like(input)
    # Ceiling division: one block per DEFAULT_BLOCK_SIZE elements.
    grid = ((n + DEFAULT_BLOCK_SIZE - 1) // DEFAULT_BLOCK_SIZE,)
    # Arguments follow the kernel signature; assumes the current PyTorch CUDA stream.
    ct.launch(torch.cuda.current_stream(), grid, my_op_kernel_ct, (input, out, n, DEFAULT_BLOCK_SIZE))
    return out
```
Reference: `src/tilegym/ops/cutile/silu_and_mul.py`
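The grid expression in the template is a ceiling division, and the kernel's index expression can reach past the end of the tensor in the last block. The plain-Python sketch below works through that arithmetic without a GPU. The helper names are hypothetical, and how cuTile's `gather` and `scatter` treat out-of-range indices is not covered here, so consult the reference implementation for that.

```python
def num_blocks(n_elements: int, block_size: int) -> int:
    """Ceiling division: smallest block count whose coverage is >= n_elements."""
    return (n_elements + block_size - 1) // block_size


def block_indices(bid: int, block_size: int) -> range:
    """Element indices handled by block `bid`, mirroring bid * BLOCK_SIZE + arange."""
    start = bid * block_size
    return range(start, start + block_size)


def out_of_range_count(n_elements: int, block_size: int) -> int:
    """Number of indices in the final block that fall past the end of the tensor."""
    return num_blocks(n_elements, block_size) * block_size - n_elements


# A tensor with 1025 elements needs two blocks of 1024; the second block
# covers indices 1024..2047, of which only index 1024 is valid.
blocks = num_blocks(1025, 1024)
last = block_indices(blocks - 1, 1024)
overrun = out_of_range_count(1025, 1024)
```

Whenever `n` is not a multiple of the block size, `overrun` is nonzero, so the kernel has to tolerate or guard those trailing indices.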
Step 3: Register in `__init__.py` (CRITICAL)
Missing this step means the cuTile backend implementation never gets loaded.
File: `src/tilegym/ops/cutile/__init__.py`
Add inside the `if is_backend_available("cutile"):` block (alphabetically):
```python
from . import my_op
```
And in the function import section:
```python
from .my_op import my_op
```
And add `"my_op"` to `__all__`.
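The reason a missing import leaves the backend unloaded is that a registration decorator only runs when the module body is executed. The sketch below demonstrates this with a hypothetical `REGISTRY` dictionary and a toy `register_impl`. Neither is TileGym's code, and `exec` stands in for what Python does on the first import of a module.

```python
import types

REGISTRY = {}  # hypothetical stand-in for the backend registry


def register_impl(name, backend):
    """Toy decorator: record the function under (name, backend)."""
    def decorator(fn):
        REGISTRY[(name, backend)] = fn
        return fn
    return decorator


# Source of a hypothetical backend module. Defining this string registers nothing.
BACKEND_SOURCE = '''
@register_impl("my_op", backend="cutile")
def my_op(x):
    return x
'''

registered_before = ("my_op", "cutile") in REGISTRY

# Executing the module body is what happens the first time the module is imported.
module = types.ModuleType("my_op_backend")
module.register_impl = register_impl
exec(BACKEND_SOURCE, module.__dict__)

registered_after = ("my_op", "cutile") in REGISTRY
```

In this toy setup the registry stays empty until the module body runs, which mirrors why the import line in `__init__.py` cannot be omitted.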
Step 4: Add tests
File: `tests/ops/test_my_op.py`
CRITICAL: Always import from `tilegym.ops`, NEVER from `tilegym.ops.cutile.my_op`.
```python
import pytest
import torch

from tilegym.backend import is_backend_available, set_backend

from .. import common

_backends = ["cutile"]


class Test_MY_OP(common.PyTestCase):
    @staticmethod
    def reference(input):
        """Reference implementation using PyTorch."""
        # Placeholder: substitute the PyTorch op that defines the expected result.
        return torch.some_reference(input)

    @pytest.mark.parametrize("shape, dtype", [
        ((1024,), torch.float16),
        ((1024, 512), torch.float32),
        ((64, 64, 64), torch.bfloat16),
    ])
    @pytest.mark.parametrize("backend", _backends)
    def test_op(self, shape, dtype, backend, arch):
        if backend == "cutile" and not is_backend_available("cutile"):
            pytest.skip("Cutile backend not available")
        try:
            set_backend(backend)
        except Exception as e:
            pytest.skip(f"Backend is not supported: {e}")
        self.setUp()

        # Import through the dispatch layer, not the backend module.
        from tilegym.ops import my_op

        A = torch.randn(*shape, dtype=dtype, device="cuda")
        self.assertCorrectness(
            my_op, self.reference, {"input": A},
            atol=1e-3, rtol=1e-3,
        )
```
Key patterns:
- `_backends = ["cutile"]`
- `test_op`: use `set_backend(backend)` with try-except, call `self.setUp()`
Reference: `tests/ops/test_silu_and_mul.py`
Common errors:
1. Missing `_backends` list (inside class)
2. `test_op` / `test_op_xxx`: missing `@pytest.mark.parametrize("backend", _backends)`, the `backend` parameter, and the `tilegym.is_backend_available` / `tilegym.set_backend` pattern
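When choosing `atol` and `rtol` for a new test, it helps to see how the two tolerances combine. The sketch below applies the element-wise criterion documented for `torch.testing.assert_close`, namely `|actual - expected| <= atol + rtol * |expected|`, to plain floats. Whether `assertCorrectness` uses exactly this rule is an assumption, since its implementation is not shown here, and `is_close` is a hypothetical helper.

```python
def is_close(actual: float, expected: float, atol: float, rtol: float) -> bool:
    """Combined absolute/relative tolerance check for a single element."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


ATOL = RTOL = 1e-3  # the tolerances used in the test template

# Near 1.0 the allowed error is about 0.002 (0.001 absolute + 0.001 relative).
within = is_close(1.0015, 1.0, ATOL, RTOL)
outside = is_close(1.0025, 1.0, ATOL, RTOL)

# For large values the relative term dominates: an error of 0.5 passes at 1000.
large_value = is_close(1000.5, 1000.0, ATOL, RTOL)

# At zero only the absolute term remains.
near_zero = is_close(0.002, 0.0, ATOL, RTOL)
```

Low-precision dtypes such as `float16` and `bfloat16` may need looser tolerances than `float32`, so tolerances are usually picked per dtype.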
Step 5: Add benchmark to tests/benchmark
File: `tests/benchmark/bench_my_op.py`
Key rules from benchmark_rules.md:
- Call the op via `tilegym.ops.my_op(a, b, ..., backend=backend)`; do not use `set_backend`.
- Define `ALL_BACKENDS` (include at least `cutile` and `torch`), filter with `get_supported_backends()`.
- Implement `reference_my_op(...)` and register it: `register_impl("my_op", "torch")(reference_my_op)`.
- Use `create_benchmark_config()` to build `triton.testing.Benchmark` configs (e.g. by shape/dtype).
- Use `@triton.testing.perf_report([...])` on `bench_my_op(...)`; inside the bench function: correctness check with `torch.testing.assert_close(fn(), ref(), ...)`, then `ms = triton.testing.do_bench(fn)` (or `do_bench_cudagraph`), compute GB/s or TFLOPS, and return the metric.
- Entry point: `if __name__ == "__main__": bench_my_op.run(print_data=True)`.
Template structure:
```python
import torch
import triton
import triton.testing

import tilegym
from tilegym.backend import is_backend_available, register_impl

ALL_BACKENDS = [
    ("cutile", "cuTile", ("orange", "-")) if is_backend_available("cutile") else None,
    ("torch", "PyTorch", ("green", "-")),
]


def get_supported_backends():
    return [p for p in ALL_BACKENDS if p is not None]


def reference_my_op(input: torch.Tensor, out: torch.Tensor = None, **kwargs):
    """Reference implementation using PyTorch."""
    ...


register_impl("my_op", "torch")(reference_my_op)


def create_benchmark_config(datatype, ...):
    available_backends = get_supported_backends()
    if not available_backends:
        return None
    backends, names, styles = zip(*available_backends)
    return triton.testing.Benchmark(
        x_names=["M"],  # or other dimension names
        x_vals=[...],
        line_arg="backend",
        line_vals=list(backends),
        line_names=list(names),
        styles=list(styles),
        ylabel="GB/s",  # or TFLOPS
        plot_name="my-op-...-GBps",  # must end in -GBps or -TFLOPS
        args={"datatype": datatype, ...},
    )


@triton.testing.perf_report([
    create_benchmark_config(datatype, ...)
    for datatype in [torch.float16, torch.float32]
    for ... in [...]
])
def bench_my_op(M, backend, datatype, ..., device="cuda"):
    x = torch.randn(..., dtype=datatype, device=device)
    fn = lambda: tilegym.ops.my_op(x, backend=backend)
    ref = lambda: reference_my_op(x)
    torch.testing.assert_close(fn(), ref(), rtol=1e-2, atol=1e-2)
    ms = triton.testing.do_bench(fn)  # or do_bench_cudagraph(fn)
    # Compute metric (e.g. GB/s or TFLOPS) from ms and problem size
    return metric


if __name__ == "__main__":
    bench_my_op.run(print_data=True)
```
Benchmark Plot Names: Must include a `-TFLOPS` or `-GBps` suffix
- Example: `plot_name=f"persistent-layer-norm-M{num_rows}-{dtype_name}-GBps"`
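The template leaves the metric computation as a comment. As a rough sketch, the helpers below convert a measured time in milliseconds into GB/s or TFLOPS. The function names are hypothetical, and the default of two tensors touched assumes an elementwise op that reads its input once and writes its output once, so adjust the byte and FLOP counts to the actual operator.

```python
def gbps(num_elements: int, bytes_per_element: int, ms: float, tensors_touched: int = 2) -> float:
    """Effective bandwidth in GB/s for an op moving `tensors_touched` full tensors."""
    total_bytes = tensors_touched * num_elements * bytes_per_element
    seconds = ms * 1e-3
    return total_bytes / seconds / 1e9


def tflops(flop_count: float, ms: float) -> float:
    """Throughput in TFLOPS given the total floating-point operation count."""
    seconds = ms * 1e-3
    return flop_count / seconds / 1e12


# One million float16 elements (2 bytes each), read once and written once, in 1 ms:
# 4e6 bytes / 1e-3 s = 4e9 B/s, i.e. about 4 GB/s.
example_gbps = gbps(1_000_000, 2, 1.0)

# 2e12 floating-point operations in 1000 ms is about 2 TFLOPS.
example_tflops = tflops(2e12, 1000.0)
```

Inside the bench function, `x.numel()` and `x.element_size()` would supply the element count and byte width, and the returned value should match the unit named in `ylabel` and the plot-name suffix.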
Step 6: Verify
```bash
# Run tests
pytest tests/ops/test_my_op.py -v

# Run benchmark (optional)
python tests/benchmark/bench_my_op.py

# Lint
pre-commit run -a
```