tl.program_id(axis=0) returns local PE id within cube, tl.program_id(axis=1) returns cube id. Enables cube-aware sharding in benchmark kernels. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
ADR-0022: 2D Grid program_id Semantics
- Status: Accepted
- Date: 2026-04-09
- Context: Triton-style kernel addressing for multi-cube PE topology
Problem
Triton kernels use tl.program_id(axis) to identify their position in a launch grid.
Our hardware has a 2-level hierarchy: cubes contain PEs.
The previous implementation ignored the axis parameter and always returned a flat PE index,
making it impossible for kernels to distinguish their cube-local position from their cube identity.
Decision
Map tl.program_id and tl.num_programs to the 2D hardware grid:
| Call | Returns | Description |
|---|---|---|
| `tl.program_id(axis=0)` | `local_pe_id` | PE index within cube |
| `tl.program_id(axis=1)` | `cube_id` | Cube index |
| `tl.num_programs(axis=0)` | `num_pes_per_cube` | PEs per cube |
| `tl.num_programs(axis=1)` | `num_cubes` | Total cubes |
Global PID is derived as:
```python
global_pid = tl.program_id(axis=1) * tl.num_programs(axis=0) + tl.program_id(axis=0)
```
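A minimal sketch of the same arithmetic in plain Python, assuming nothing beyond the formula above. `to_global_pid` and its inverse `from_global_pid` are hypothetical helpers (not part of triton_emu) that show the mapping is a row-major flattening with axis=0 innermost, so every (cube, PE) pair gets a unique global PID in `range(num_cubes * pes_per_cube)`.

```python
def to_global_pid(cube_id: int, local_pe_id: int, pes_per_cube: int) -> int:
    """Flatten (cube, PE) into a global PID; axis=0 (PE) is the innermost index."""
    return cube_id * pes_per_cube + local_pe_id


def from_global_pid(global_pid: int, pes_per_cube: int) -> tuple[int, int]:
    """Inverse mapping: recover (cube_id, local_pe_id) from a global PID."""
    return divmod(global_pid, pes_per_cube)


# Example grid: 2 cubes x 4 PEs per cube -> global PIDs 0..7
pes_per_cube, num_cubes = 4, 2
for cube in range(num_cubes):
    for pe in range(pes_per_cube):
        gid = to_global_pid(cube, pe, pes_per_cube)
        assert from_global_pid(gid, pes_per_cube) == (cube, pe)  # round trip holds

print(to_global_pid(1, 2, pes_per_cube))  # -> 6
```

Because `divmod` inverts the formula exactly, a kernel that only knows its global PID could recover both axes, but reading them directly from `tl.program_id` avoids the division.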
Axis mapping rationale
- axis=0 = PE (innermost): PEs within a cube share HBM and communicate via the local NOC mesh. This is the fast, tightly coupled dimension, analogous to threads within a block.
- axis=1 = Cube (outer): Cross-cube communication goes through UCIe, which has higher latency. This is the coarser scheduling dimension, analogous to blocks in a grid.
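To make the ordering concrete: with axis=0 innermost in the global-PID formula, consecutive global PIDs stay on the same cube except at cube boundaries. The sketch below assumes a hypothetical workload where global PID `g` exchanges data with `g + 1` (for example, adjacent column shards); `cross_cube_neighbor_pairs` is an illustrative helper, not part of the codebase, that lists which of those pairs would have to cross UCIe.

```python
def cross_cube_neighbor_pairs(num_cubes: int, pes_per_cube: int) -> list[tuple[int, int]]:
    """Return the (g, g + 1) global-PID pairs whose endpoints sit on different cubes."""
    total = num_cubes * pes_per_cube
    crossing = []
    for g in range(total - 1):
        # The cube index of a global PID is global_pid // pes_per_cube.
        if g // pes_per_cube != (g + 1) // pes_per_cube:
            crossing.append((g, g + 1))
    return crossing


# 2 cubes x 4 PEs: only the boundary pair (3, 4) crosses cubes; the other 6 pairs stay on the local NOC.
print(cross_cube_neighbor_pairs(2, 4))  # -> [(3, 4)]
```

In general only `num_cubes - 1` of the `num_cubes * pes_per_cube - 1` neighbor pairs cross a cube boundary. Making the cube index innermost instead would place every neighbor pair on different cubes whenever `num_cubes > 1`.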
Implementation
TLContext (triton_emu/tl_context.py)
Added cube_id and num_cubes constructor parameters. program_id() and num_programs() dispatch on axis:
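```python
class TLContext:
    # (constructor and other methods omitted)

    def program_id(self, axis: int = 0) -> int:
        # axis=1 -> cube index; axis=0 (and any other value) -> PE index within the cube
        if axis == 1:
            return self._cube_id
        return self._pe_id

    def num_programs(self, axis: int = 0) -> int:
        # axis=1 -> number of cubes; otherwise -> PEs per cube
        if axis == 1:
            return self._num_cubes
        return self._num_programs
```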
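A self-contained stand-in for the dispatch described above, assuming a constructor of the form `(pe_id, num_programs, cube_id=0, num_cubes=1)`. The class name `TLContextSketch` and the parameter order are hypothetical; only the `cube_id`/`num_cubes` defaults of 0 and 1 come from this ADR.

```python
class TLContextSketch:
    """Hypothetical stand-in for TLContext's program_id/num_programs dispatch."""

    def __init__(self, pe_id: int, num_programs: int, cube_id: int = 0, num_cubes: int = 1):
        self._pe_id = pe_id                # PE index within the cube (axis=0)
        self._num_programs = num_programs  # PEs per cube (axis=0 extent)
        self._cube_id = cube_id            # cube index (axis=1)
        self._num_cubes = num_cubes        # number of cubes (axis=1 extent)

    def program_id(self, axis: int = 0) -> int:
        return self._cube_id if axis == 1 else self._pe_id

    def num_programs(self, axis: int = 0) -> int:
        return self._num_cubes if axis == 1 else self._num_programs


# PE 3 on cube 1 of a 2-cube x 4-PE grid
tl = TLContextSketch(pe_id=3, num_programs=4, cube_id=1, num_cubes=2)
global_pid = tl.program_id(axis=1) * tl.num_programs(axis=0) + tl.program_id(axis=0)
print(global_pid)  # -> 7
```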
PE_CPU (components/builtin/pe_cpu.py)
- Extracts `num_cubes` from `ctx.spec["system"]["sips"]["cubes_per_sip"]`
- Passes `cube_id` (already available as `self._cube_idx`) and `num_cubes` to TLContext
KernelRunner (triton_emu/kernel_runner.py)
- Receives `num_cubes` from PE_CPU
- Passes `cube_id` and `num_cubes` to TLContext in greenlet mode
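The plumbing amounts to one dictionary lookup plus two extra constructor arguments. The sketch below assumes the spec is a plain nested dict shaped like the `ctx.spec[...]` key path listed for PE_CPU; `num_cubes_from_spec` and `build_tl_context_kwargs` are hypothetical helpers (not functions in PE_CPU or KernelRunner), the range check is added only for illustration, and the numbers in the example spec are made up.

```python
def num_cubes_from_spec(spec: dict) -> int:
    """Read the cube count from the path PE_CPU uses: system -> sips -> cubes_per_sip."""
    return int(spec["system"]["sips"]["cubes_per_sip"])


def build_tl_context_kwargs(pe_id: int, pes_per_cube: int, cube_idx: int, spec: dict) -> dict:
    """Collect the keyword arguments that would be forwarded to TLContext (hypothetical shape)."""
    num_cubes = num_cubes_from_spec(spec)
    if not 0 <= cube_idx < num_cubes:
        raise ValueError(f"cube_idx {cube_idx} out of range for {num_cubes} cubes")
    return {
        "pe_id": pe_id,
        "num_programs": pes_per_cube,
        "cube_id": cube_idx,  # PE_CPU already tracks this as self._cube_idx
        "num_cubes": num_cubes,
    }


example_spec = {"system": {"sips": {"cubes_per_sip": 4}}}  # illustrative values only
print(build_tl_context_kwargs(pe_id=2, pes_per_cube=8, cube_idx=1, spec=example_spec))
```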
Backward Compatibility
- Existing code using `tl.program_id(0)` or `tl.program_id()` is unchanged: it returns the same PE index as before.
- `cube_id` and `num_cubes` default to `0` and `1`, so callers that don't provide them (e.g. unit tests) continue to work.
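With the defaults, the global-PID formula reduces to the legacy flat index: `cube_id` is 0, so `0 * num_programs(0) + pe_id` is just `pe_id`. A quick check using a hypothetical `global_pid` helper that mirrors only the default-argument behavior:

```python
def global_pid(pe_id: int, pes_per_cube: int, cube_id: int = 0) -> int:
    # cube_id defaults to 0, mirroring the TLContext default for single-cube callers
    return cube_id * pes_per_cube + pe_id


# Single-cube (default) case: the global PID equals the flat PE index for every PE.
assert all(global_pid(pe, 8) == pe for pe in range(8))
print(global_pid(3, 8, cube_id=1))  # -> 11
```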
Usage Example
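```python
def sharded_gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl):
    local_pid = tl.program_id(axis=0)  # PE within cube
    cube_id = tl.program_id(axis=1)    # which cube
    pes_per_cube = tl.num_programs(axis=0)
    num_cubes = tl.num_programs(axis=1)
    global_pid = cube_id * pes_per_cube + local_pid

    # Column-wise sharding across global PID. Assumes N is divisible by the
    # total program count; any remainder columns are left unassigned.
    n_per_pid = N // (num_cubes * pes_per_cube)
    col_start = global_pid * n_per_pid

    # f16 = 2 bytes per element. The offsets imply B and the output are laid out
    # column by column (K and M contiguous elements per column, respectively).
    a = tl.load(a_ptr, shape=(M, K), dtype="f16")
    b = tl.ref(b_ptr + col_start * K * 2, shape=(K, n_per_pid), dtype="f16")
    h = tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr + col_start * M * 2)
    tl.wait(h)
```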
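The index arithmetic in the kernel can be checked without the emulator. This sketch uses a hypothetical helper, `column_shards`, in plain Python with no `tl` calls. It enumerates every (cube, PE) pair and reproduces `n_per_pid`, `col_start`, and the byte offsets passed to `tl.ref` and `tl.composite`, assuming 2 bytes per f16 element as the kernel does.

```python
F16_BYTES = 2  # f16 element size, matching the "* 2" factors in the kernel


def column_shards(M: int, K: int, N: int, num_cubes: int, pes_per_cube: int) -> list[dict]:
    """Replay the kernel's sharding math for every (cube, PE) program."""
    total = num_cubes * pes_per_cube
    n_per_pid = N // total  # floor division: leftover columns (N % total) get no owner
    shards = []
    for cube_id in range(num_cubes):
        for local_pid in range(pes_per_cube):
            global_pid = cube_id * pes_per_cube + local_pid
            col_start = global_pid * n_per_pid
            shards.append({
                "cube": cube_id,
                "pe": local_pid,
                "cols": (col_start, col_start + n_per_pid),  # half-open column range
                "b_offset": col_start * K * F16_BYTES,
                "out_offset": col_start * M * F16_BYTES,
            })
    return shards


shards = column_shards(M=64, K=128, N=1024, num_cubes=2, pes_per_cube=4)
print(shards[5])  # cube 1, PE 1 -> global PID 5, columns [640, 768)
```

With N = 1000 and 8 programs, `n_per_pid` is 125 and the shards still tile exactly. With N = 1001, the last column would be dropped, which is why the kernel as written assumes N is a multiple of the total program count.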
Consequences
- Benchmarks can now express cube-aware sharding and addressing without hardcoding topology dimensions.
- Future axis=2 (SIP-level) can be added following the same pattern if needed.