Add 2D grid program_id semantics (ADR-0022)

tl.program_id(axis=0) returns local PE id within cube,
tl.program_id(axis=1) returns cube id. Enables cube-aware
sharding in benchmark kernels.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-09 16:49:56 -07:00
parent dc3fb02aed
commit ff2c677a9c
5 changed files with 138 additions and 3 deletions
+90
View File
@@ -0,0 +1,90 @@
# ADR-0022: 2D Grid program_id Semantics
- **Status**: Accepted
- **Date**: 2026-04-09
- **Context**: Triton-style kernel addressing for multi-cube PE topology
## Problem
Triton kernels use `tl.program_id(axis)` to identify their position in a launch grid.
Our hardware has a 2-level hierarchy: **cubes** contain **PEs**.
The previous implementation ignored the `axis` parameter and always returned a flat PE index,
making it impossible for kernels to distinguish their cube-local position from their cube identity.
## Decision
Map `tl.program_id` and `tl.num_programs` to the 2D hardware grid:
| Call | Returns | Description |
|------|---------|-------------|
| `tl.program_id(axis=0)` | `local_pe_id` | PE index within cube |
| `tl.program_id(axis=1)` | `cube_id` | Cube index |
| `tl.num_programs(axis=0)` | `num_pes_per_cube` | PEs per cube |
| `tl.num_programs(axis=1)` | `num_cubes` | Total cubes |
Global PID is derived as:
```python
global_pid = tl.program_id(axis=1) * tl.num_programs(axis=0) + tl.program_id(axis=0)
```
### Axis mapping rationale
- **axis=0 = PE (innermost)**: PEs within a cube share HBM and communicate via local NOC mesh. This is the fast, tightly-coupled dimension — analogous to threads within a block.
- **axis=1 = Cube (outer)**: Cross-cube communication goes through UCIe with higher latency. This is the coarser scheduling dimension — analogous to blocks in a grid.
## Implementation
### TLContext (`triton_emu/tl_context.py`)
Added `cube_id` and `num_cubes` constructor parameters. `program_id()` and `num_programs()` dispatch on `axis`:
```python
def program_id(self, axis: int = 0) -> int:
if axis == 1:
return self._cube_id
return self._pe_id
def num_programs(self, axis: int = 0) -> int:
if axis == 1:
return self._num_cubes
return self._num_programs
```
### PE_CPU (`components/builtin/pe_cpu.py`)
- Extracts `num_cubes` from `ctx.spec["system"]["sips"]["cubes_per_sip"]`
- Passes `cube_id` (already available as `self._cube_idx`) and `num_cubes` to TLContext
### KernelRunner (`triton_emu/kernel_runner.py`)
- Receives `num_cubes` from PE_CPU
- Passes `cube_id` and `num_cubes` to TLContext in greenlet mode
## Backward Compatibility
- Existing code using `tl.program_id(0)` or `tl.program_id()` is unchanged — returns the same PE index as before.
- `cube_id` and `num_cubes` default to `0` and `1`, so callers that don't provide them (e.g. unit tests) continue to work.
## Usage Example
```python
def sharded_gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl):
local_pid = tl.program_id(axis=0) # PE within cube
cube_id = tl.program_id(axis=1) # which cube
global_pid = cube_id * tl.num_programs(axis=0) + local_pid
# Column-wise sharding across global PID
n_per_pid = N // (tl.num_programs(axis=1) * tl.num_programs(axis=0))
col_start = global_pid * n_per_pid
a = tl.load(a_ptr, shape=(M, K), dtype="f16")
b = tl.ref(b_ptr + col_start * K * 2, shape=(K, n_per_pid), dtype="f16")
h = tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr + col_start * M * 2)
tl.wait(h)
```
## Consequences
- Benchmarks can now express cube-aware sharding and addressing without hardcoding topology dimensions.
- Future axis=2 (SIP-level) can be added following the same pattern if needed.
+9 -1
View File
@@ -42,6 +42,9 @@ class PeCpuComponent(ComponentBase):
self._cube_idx = int(parts[1].replace("cube", "")) self._cube_idx = int(parts[1].replace("cube", ""))
except (IndexError, ValueError): except (IndexError, ValueError):
self._cube_idx = 0 self._cube_idx = 0
# num_cubes from spec (for tl.program_id(axis=1))
spec = ctx.spec if ctx else {}
self._num_cubes = spec.get("system", {}).get("sips", {}).get("cubes_per_sip", 1)
def _find_shard(self, shards: tuple) -> Any: def _find_shard(self, shards: tuple) -> Any:
"""Find shard matching this PE's (sip, cube, pe). Fallback to positional index.""" """Find shard matching this PE's (sip, cube, pe). Fallback to positional index."""
@@ -139,6 +142,7 @@ class PeCpuComponent(ComponentBase):
pe_idx=self._pe_idx, pe_idx=self._pe_idx,
sip_idx=self._sip_idx, sip_idx=self._sip_idx,
cube_idx=self._cube_idx, cube_idx=self._cube_idx,
num_cubes=self._num_cubes,
scheduler_id=scheduler_id, scheduler_id=scheduler_id,
out_ports=self.out_ports, out_ports=self.out_ports,
store=store, store=store,
@@ -155,7 +159,11 @@ class PeCpuComponent(ComponentBase):
) )
from kernbench.triton_emu.tl_context import TLContext, run_kernel from kernbench.triton_emu.tl_context import TLContext, run_kernel
tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0) tl = TLContext(
pe_id=self._pe_idx, num_programs=num_programs,
cube_id=self._cube_idx, num_cubes=self._num_cubes,
dispatch_cycles=0,
)
run_kernel(kernel_fn, tl, *kernel_args) run_kernel(kernel_fn, tl, *kernel_args)
commands = tl.commands commands = tl.commands
@@ -50,11 +50,13 @@ class KernelRunner:
scheduler_id: str, scheduler_id: str,
out_ports: dict[str, simpy.Store], out_ports: dict[str, simpy.Store],
store: MemoryStore | None = None, store: MemoryStore | None = None,
num_cubes: int = 1,
) -> None: ) -> None:
self._pe_prefix = pe_prefix self._pe_prefix = pe_prefix
self._pe_idx = pe_idx self._pe_idx = pe_idx
self._sip_idx = sip_idx self._sip_idx = sip_idx
self._cube_idx = cube_idx self._cube_idx = cube_idx
self._num_cubes = num_cubes
self._scheduler_id = scheduler_id self._scheduler_id = scheduler_id
self._out_ports = out_ports self._out_ports = out_ports
self._store = store self._store = store
@@ -83,6 +85,8 @@ class KernelRunner:
tl = TLContext( tl = TLContext(
pe_id=self._pe_idx, pe_id=self._pe_idx,
num_programs=num_programs, num_programs=num_programs,
cube_id=self._cube_idx,
num_cubes=self._num_cubes,
dispatch_cycles=0, dispatch_cycles=0,
runner=self, runner=self,
) )
+18 -2
View File
@@ -53,9 +53,13 @@ class TLContext:
num_programs: int = 1, num_programs: int = 1,
dispatch_cycles: int = 1, dispatch_cycles: int = 1,
runner: Any = None, runner: Any = None,
cube_id: int = 0,
num_cubes: int = 1,
) -> None: ) -> None:
self._pe_id = pe_id self._pe_id = pe_id
self._num_programs = num_programs self._num_programs = num_programs
self._cube_id = cube_id
self._num_cubes = num_cubes
self._dispatch_cycles = dispatch_cycles self._dispatch_cycles = dispatch_cycles
self._commands: list[PeCommand] = [] self._commands: list[PeCommand] = []
self._handle_counter = 0 self._handle_counter = 0
@@ -234,11 +238,23 @@ class TLContext:
# ── Index / Scalar (PE_CPU, no engine) ──────────────────────── # ── Index / Scalar (PE_CPU, no engine) ────────────────────────
def program_id(self, axis: int = 0) -> int: def program_id(self, axis: int = 0) -> int:
"""Return program instance index.""" """Return program instance index.
axis=0: local PE id within cube.
axis=1: cube id.
"""
if axis == 1:
return self._cube_id
return self._pe_id return self._pe_id
def num_programs(self, axis: int = 0) -> int: def num_programs(self, axis: int = 0) -> int:
"""Return total number of program instances.""" """Return total number of program instances.
axis=0: num PEs per cube.
axis=1: num cubes.
"""
if axis == 1:
return self._num_cubes
return self._num_programs return self._num_programs
def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle: def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle:
+17
View File
@@ -196,6 +196,23 @@ def test_tl_program_id():
assert tl.num_programs(0) == 8 assert tl.num_programs(0) == 8
def test_tl_program_id_axis1():
"""axis=1 returns cube_id and num_cubes."""
tl = TLContext(pe_id=3, num_programs=8, cube_id=7, num_cubes=16)
assert tl.program_id(0) == 3
assert tl.program_id(1) == 7
assert tl.num_programs(0) == 8
assert tl.num_programs(1) == 16
def test_tl_program_id_global():
"""global_pid = cube_id * num_pes_per_cube + local_pe_id."""
pe_id, cube_id, num_pes = 5, 3, 8
tl = TLContext(pe_id=pe_id, num_programs=num_pes, cube_id=cube_id, num_cubes=16)
global_pid = tl.program_id(1) * tl.num_programs(0) + tl.program_id(0)
assert global_pid == cube_id * num_pes + pe_id
# ── 12. tl.arange, tl.zeros, tl.full ───────────────────────────── # ── 12. tl.arange, tl.zeros, tl.full ─────────────────────────────