From ff2c677a9cf3e6ca118a2bbef626f750b6db5f86 Mon Sep 17 00:00:00 2001 From: Yangwook Kang Date: Thu, 9 Apr 2026 16:49:56 -0700 Subject: [PATCH] Add 2D grid program_id semantics (ADR-0022) tl.program_id(axis=0) returns local PE id within cube, tl.program_id(axis=1) returns cube id. Enables cube-aware sharding in benchmark kernels. Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/adr/ADR-0022-program-id-2d-grid.md | 90 ++++++++++++++++++++++ src/kernbench/components/builtin/pe_cpu.py | 10 ++- src/kernbench/triton_emu/kernel_runner.py | 4 + src/kernbench/triton_emu/tl_context.py | 20 ++++- tests/test_triton_emu.py | 17 ++++ 5 files changed, 138 insertions(+), 3 deletions(-) create mode 100644 docs/adr/ADR-0022-program-id-2d-grid.md diff --git a/docs/adr/ADR-0022-program-id-2d-grid.md b/docs/adr/ADR-0022-program-id-2d-grid.md new file mode 100644 index 0000000..9bf7966 --- /dev/null +++ b/docs/adr/ADR-0022-program-id-2d-grid.md @@ -0,0 +1,90 @@ +# ADR-0022: 2D Grid program_id Semantics + +- **Status**: Accepted +- **Date**: 2026-04-09 +- **Context**: Triton-style kernel addressing for multi-cube PE topology + +## Problem + +Triton kernels use `tl.program_id(axis)` to identify their position in a launch grid. +Our hardware has a 2-level hierarchy: **cubes** contain **PEs**. +The previous implementation ignored the `axis` parameter and always returned a flat PE index, +making it impossible for kernels to distinguish their cube-local position from their cube identity. + +## Decision + +Map `tl.program_id` and `tl.num_programs` to the 2D hardware grid: + +| Call | Returns | Description | +|------|---------|-------------| +| `tl.program_id(axis=0)` | `local_pe_id` | PE index within cube | +| `tl.program_id(axis=1)` | `cube_id` | Cube index | +| `tl.num_programs(axis=0)` | `num_pes_per_cube` | PEs per cube | +| `tl.num_programs(axis=1)` | `num_cubes` | Total cubes | + +Global PID is derived as: + +```python +global_pid = tl.program_id(axis=1) * tl.num_programs(axis=0) + tl.program_id(axis=0) +``` + +### Axis mapping rationale + +- **axis=0 = PE (innermost)**: PEs within a cube share HBM and communicate via local NOC mesh. This is the fast, tightly-coupled dimension — analogous to threads within a block. +- **axis=1 = Cube (outer)**: Cross-cube communication goes through UCIe with higher latency. This is the coarser scheduling dimension — analogous to blocks in a grid. + +## Implementation + +### TLContext (`triton_emu/tl_context.py`) + +Added `cube_id` and `num_cubes` constructor parameters. `program_id()` and `num_programs()` dispatch on `axis`: + +```python +def program_id(self, axis: int = 0) -> int: + if axis == 1: + return self._cube_id + return self._pe_id + +def num_programs(self, axis: int = 0) -> int: + if axis == 1: + return self._num_cubes + return self._num_programs +``` + +### PE_CPU (`components/builtin/pe_cpu.py`) + +- Extracts `num_cubes` from `ctx.spec["system"]["sips"]["cubes_per_sip"]` +- Passes `cube_id` (already available as `self._cube_idx`) and `num_cubes` to TLContext + +### KernelRunner (`triton_emu/kernel_runner.py`) + +- Receives `num_cubes` from PE_CPU +- Passes `cube_id` and `num_cubes` to TLContext in greenlet mode + +## Backward Compatibility + +- Existing code using `tl.program_id(0)` or `tl.program_id()` is unchanged — returns the same PE index as before. +- `cube_id` and `num_cubes` default to `0` and `1`, so callers that don't provide them (e.g. unit tests) continue to work. + +## Usage Example + +```python +def sharded_gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl): + local_pid = tl.program_id(axis=0) # PE within cube + cube_id = tl.program_id(axis=1) # which cube + global_pid = cube_id * tl.num_programs(axis=0) + local_pid + + # Column-wise sharding across global PID + n_per_pid = N // (tl.num_programs(axis=1) * tl.num_programs(axis=0)) + col_start = global_pid * n_per_pid + + a = tl.load(a_ptr, shape=(M, K), dtype="f16") + b = tl.ref(b_ptr + col_start * K * 2, shape=(K, n_per_pid), dtype="f16") + h = tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr + col_start * M * 2) + tl.wait(h) +``` + +## Consequences + +- Benchmarks can now express cube-aware sharding and addressing without hardcoding topology dimensions. +- Future axis=2 (SIP-level) can be added following the same pattern if needed. diff --git a/src/kernbench/components/builtin/pe_cpu.py b/src/kernbench/components/builtin/pe_cpu.py index 4947b9d..455b7de 100644 --- a/src/kernbench/components/builtin/pe_cpu.py +++ b/src/kernbench/components/builtin/pe_cpu.py @@ -42,6 +42,9 @@ class PeCpuComponent(ComponentBase): self._cube_idx = int(parts[1].replace("cube", "")) except (IndexError, ValueError): self._cube_idx = 0 + # num_cubes from spec (for tl.program_id(axis=1)) + spec = ctx.spec if ctx else {} + self._num_cubes = spec.get("system", {}).get("sips", {}).get("cubes_per_sip", 1) def _find_shard(self, shards: tuple) -> Any: """Find shard matching this PE's (sip, cube, pe). Fallback to positional index.""" @@ -139,6 +142,7 @@ class PeCpuComponent(ComponentBase): pe_idx=self._pe_idx, sip_idx=self._sip_idx, cube_idx=self._cube_idx, + num_cubes=self._num_cubes, scheduler_id=scheduler_id, out_ports=self.out_ports, store=store, @@ -155,7 +159,11 @@ class PeCpuComponent(ComponentBase): ) from kernbench.triton_emu.tl_context import TLContext, run_kernel - tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0) + tl = TLContext( + pe_id=self._pe_idx, num_programs=num_programs, + cube_id=self._cube_idx, num_cubes=self._num_cubes, + dispatch_cycles=0, + ) run_kernel(kernel_fn, tl, *kernel_args) commands = tl.commands diff --git a/src/kernbench/triton_emu/kernel_runner.py b/src/kernbench/triton_emu/kernel_runner.py index afc75d3..593733f 100644 --- a/src/kernbench/triton_emu/kernel_runner.py +++ b/src/kernbench/triton_emu/kernel_runner.py @@ -50,11 +50,13 @@ class KernelRunner: scheduler_id: str, out_ports: dict[str, simpy.Store], store: MemoryStore | None = None, + num_cubes: int = 1, ) -> None: self._pe_prefix = pe_prefix self._pe_idx = pe_idx self._sip_idx = sip_idx self._cube_idx = cube_idx + self._num_cubes = num_cubes self._scheduler_id = scheduler_id self._out_ports = out_ports self._store = store @@ -83,6 +85,8 @@ class KernelRunner: tl = TLContext( pe_id=self._pe_idx, num_programs=num_programs, + cube_id=self._cube_idx, + num_cubes=self._num_cubes, dispatch_cycles=0, runner=self, ) diff --git a/src/kernbench/triton_emu/tl_context.py b/src/kernbench/triton_emu/tl_context.py index 4f9732d..3498a84 100644 --- a/src/kernbench/triton_emu/tl_context.py +++ b/src/kernbench/triton_emu/tl_context.py @@ -53,9 +53,13 @@ class TLContext: num_programs: int = 1, dispatch_cycles: int = 1, runner: Any = None, + cube_id: int = 0, + num_cubes: int = 1, ) -> None: self._pe_id = pe_id self._num_programs = num_programs + self._cube_id = cube_id + self._num_cubes = num_cubes self._dispatch_cycles = dispatch_cycles self._commands: list[PeCommand] = [] self._handle_counter = 0 @@ -234,11 +238,23 @@ class TLContext: # ── Index / Scalar (PE_CPU, no engine) ──────────────────────── def program_id(self, axis: int = 0) -> int: - """Return program instance index.""" + """Return program instance index. + + axis=0: local PE id within cube. + axis=1: cube id. + """ + if axis == 1: + return self._cube_id return self._pe_id def num_programs(self, axis: int = 0) -> int: - """Return total number of program instances.""" + """Return total number of program instances. + + axis=0: num PEs per cube. + axis=1: num cubes. + """ + if axis == 1: + return self._num_cubes return self._num_programs def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle: diff --git a/tests/test_triton_emu.py b/tests/test_triton_emu.py index e144c80..77b4568 100644 --- a/tests/test_triton_emu.py +++ b/tests/test_triton_emu.py @@ -196,6 +196,23 @@ def test_tl_program_id(): assert tl.num_programs(0) == 8 +def test_tl_program_id_axis1(): + """axis=1 returns cube_id and num_cubes.""" + tl = TLContext(pe_id=3, num_programs=8, cube_id=7, num_cubes=16) + assert tl.program_id(0) == 3 + assert tl.program_id(1) == 7 + assert tl.num_programs(0) == 8 + assert tl.num_programs(1) == 16 + + +def test_tl_program_id_global(): + """global_pid = cube_id * num_pes_per_cube + local_pe_id.""" + pe_id, cube_id, num_pes = 5, 3, 8 + tl = TLContext(pe_id=pe_id, num_programs=num_pes, cube_id=cube_id, num_cubes=16) + global_pid = tl.program_id(1) * tl.num_programs(0) + tl.program_id(0) + assert global_pid == cube_id * num_pes + pe_id + + # ── 12. tl.arange, tl.zeros, tl.full ─────────────────────────────