kernbench2/docs/adr/ADR-0022-program-id-2d-grid.md
ywkang ff2c677a9c Add 2D grid program_id semantics (ADR-0022)
tl.program_id(axis=0) returns local PE id within cube,
tl.program_id(axis=1) returns cube id. Enables cube-aware
sharding in benchmark kernels.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-09 16:49:56 -07:00


ADR-0022: 2D Grid program_id Semantics

  • Status: Accepted
  • Date: 2026-04-09
  • Context: Triton-style kernel addressing for multi-cube PE topology

Problem

Triton kernels use tl.program_id(axis) to identify their position in a launch grid. Our hardware has a 2-level hierarchy: cubes contain PEs. The previous implementation ignored the axis parameter and always returned a flat PE index, making it impossible for kernels to distinguish their cube-local position from their cube identity.

Decision

Map tl.program_id and tl.num_programs to the 2D hardware grid:

Call                     Returns           Description
tl.program_id(axis=0)    local_pe_id       PE index within cube
tl.program_id(axis=1)    cube_id           Cube index
tl.num_programs(axis=0)  num_pes_per_cube  PEs per cube
tl.num_programs(axis=1)  num_cubes         Total cubes

Global PID is derived as:

global_pid = tl.program_id(axis=1) * tl.num_programs(axis=0) + tl.program_id(axis=0)
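The formula flattens the (cube, PE) grid with PE as the fast-varying index. A minimal sketch of the mapping and its inverse in plain Python; the helper names `to_global_pid` and `from_global_pid` are hypothetical and not part of TLContext:

```python
def to_global_pid(cube_id: int, local_pe_id: int, num_pes_per_cube: int) -> int:
    """Flatten (cube_id, local_pe_id) into a single global PID."""
    if not 0 <= local_pe_id < num_pes_per_cube:
        raise ValueError("local_pe_id out of range")
    return cube_id * num_pes_per_cube + local_pe_id


def from_global_pid(global_pid: int, num_pes_per_cube: int) -> tuple[int, int]:
    """Recover (cube_id, local_pe_id) from a global PID."""
    return divmod(global_pid, num_pes_per_cube)


# Round trip over a hypothetical 2-cube x 4-PE grid.
for cube in range(2):
    for pe in range(4):
        g = to_global_pid(cube, pe, 4)
        assert from_global_pid(g, 4) == (cube, pe)
```

For example, with 4 PEs per cube, PE 2 of cube 1 has global PID 1 * 4 + 2 = 6.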

Axis mapping rationale

  • axis=0 = PE (innermost): PEs within a cube share HBM and communicate over the local NoC mesh. This is the fast, tightly coupled dimension, analogous to threads within a block.
  • axis=1 = Cube (outer): Cross-cube communication goes over UCIe and has higher latency. This is the coarser scheduling dimension, analogous to blocks in a grid.
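Because PE is the innermost axis, consecutive global PIDs stay in the same cube until a cube boundary is crossed. The sketch below uses a hypothetical `link_between` helper to classify which interconnect a pair of programs would use under this mapping; it illustrates the rationale only and does not model actual routing:

```python
def link_between(pid_a: int, pid_b: int, num_pes_per_cube: int) -> str:
    """Return which interconnect two global PIDs would communicate over.

    PIDs in the same cube (same axis=1 index) use the local NoC mesh;
    PIDs in different cubes go over UCIe.
    """
    cube_a = pid_a // num_pes_per_cube
    cube_b = pid_b // num_pes_per_cube
    return "noc" if cube_a == cube_b else "ucie"


# With 4 PEs per cube, PIDs 0-3 share cube 0 and PIDs 4-7 share cube 1.
print(link_between(2, 3, 4))  # → noc  (neighbors inside cube 0)
print(link_between(3, 4, 4))  # → ucie (crosses the cube 0 / cube 1 boundary)
```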

Implementation

TLContext (triton_emu/tl_context.py)

Added cube_id and num_cubes constructor parameters. program_id() and num_programs() dispatch on axis:

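def program_id(self, axis: int = 0) -> int:
    # axis=1 selects the cube index; any other axis returns the cube-local PE index
    if axis == 1:
        return self._cube_id
    return self._pe_id

def num_programs(self, axis: int = 0) -> int:
    # axis=1 selects the cube count; any other axis returns the PEs-per-cube count
    if axis == 1:
        return self._num_cubes
    return self._num_programs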
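A self-contained sketch of a class around those two methods, useful for experimenting outside the emulator. It assumes constructor parameters named `pe_id` and `num_programs`; only `cube_id` and `num_cubes` (defaulting to 0 and 1) are stated in this ADR. Unlike the excerpt above, it rejects axes other than 0 and 1 instead of falling through to the PE index; that check is an assumption, not part of the described implementation.

```python
class TLContext:
    """Minimal stand-in for triton_emu's TLContext (hypothetical sketch)."""

    def __init__(self, pe_id: int, num_programs: int,
                 cube_id: int = 0, num_cubes: int = 1) -> None:
        self._pe_id = pe_id                # local PE index within the cube
        self._num_programs = num_programs  # PEs per cube
        self._cube_id = cube_id            # cube index
        self._num_cubes = num_cubes        # number of cubes

    @staticmethod
    def _check_axis(axis: int) -> None:
        # Assumed tightening: only axes 0 and 1 exist in the 2D grid.
        if axis not in (0, 1):
            raise ValueError(f"axis must be 0 or 1, got {axis}")

    def program_id(self, axis: int = 0) -> int:
        self._check_axis(axis)
        return self._cube_id if axis == 1 else self._pe_id

    def num_programs(self, axis: int = 0) -> int:
        self._check_axis(axis)
        return self._num_cubes if axis == 1 else self._num_programs


ctx = TLContext(pe_id=3, num_programs=8, cube_id=1, num_cubes=4)
print(ctx.program_id(0), ctx.program_id(1))      # → 3 1
print(ctx.num_programs(0), ctx.num_programs(1))  # → 8 4
```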

PE_CPU (components/builtin/pe_cpu.py)

  • Extracts num_cubes from ctx.spec["system"]["sips"]["cubes_per_sip"]
  • Passes cube_id (already available as self._cube_idx) and num_cubes to TLContext
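The lookup is a plain nested-dictionary access. A minimal sketch, assuming `ctx.spec` behaves like a nested `dict`; the function name and the example spec fragment are hypothetical:

```python
from typing import Any


def num_cubes_from_spec(spec: dict[str, Any]) -> int:
    """Read the cube count from the same key path PE_CPU uses.

    Raises KeyError if any level of the path is missing.
    """
    return int(spec["system"]["sips"]["cubes_per_sip"])


# Hypothetical spec fragment; real specs contain many more keys.
example_spec = {"system": {"sips": {"cubes_per_sip": 4}}}
print(num_cubes_from_spec(example_spec))  # → 4
```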

KernelRunner (triton_emu/kernel_runner.py)

  • Receives num_cubes from PE_CPU
  • Passes cube_id and num_cubes to TLContext in greenlet mode
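To see how the two new arguments reach each kernel instance, here is a simplified, sequential sketch of launching a kernel over the full grid with one context per (cube, PE) pair. It deliberately omits greenlets and any scheduling; `launch_grid` and the minimal `Ctx` class are hypothetical stand-ins, not the KernelRunner API.

```python
from typing import Callable


class Ctx:
    """Hypothetical minimal context exposing 2D program_id/num_programs."""

    def __init__(self, pe_id: int, num_pes: int, cube_id: int, num_cubes: int) -> None:
        self._ids = (pe_id, cube_id)       # index 0 = PE axis, index 1 = cube axis
        self._sizes = (num_pes, num_cubes)

    def program_id(self, axis: int = 0) -> int:
        return self._ids[axis]

    def num_programs(self, axis: int = 0) -> int:
        return self._sizes[axis]


def launch_grid(kernel: Callable[[Ctx], int], num_cubes: int, num_pes: int) -> list[int]:
    """Run `kernel` once per (cube, PE) pair, cube-major, and collect results."""
    results = []
    for cube_id in range(num_cubes):
        for pe_id in range(num_pes):
            results.append(kernel(Ctx(pe_id, num_pes, cube_id, num_cubes)))
    return results


def global_pid_kernel(tl: Ctx) -> int:
    # Same global PID formula as in this ADR.
    return tl.program_id(axis=1) * tl.num_programs(axis=0) + tl.program_id(axis=0)


print(launch_grid(global_pid_kernel, num_cubes=2, num_pes=3))  # → [0, 1, 2, 3, 4, 5]
```

Iterating cubes in the outer loop and PEs in the inner loop yields global PIDs in increasing order, which mirrors the PE-innermost axis choice.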

Backward Compatibility

  • Existing code using tl.program_id(0) or tl.program_id() is unchanged: it returns the same PE index as before.
  • cube_id and num_cubes default to 0 and 1, so callers that don't provide them (e.g. unit tests) continue to work.
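The defaults also keep the derived global PID consistent with single-cube code: with num_cubes = 1 the only cube index is 0, so the cube term of the global PID formula vanishes and the global PID equals tl.program_id(axis=0).

```latex
\begin{aligned}
\mathrm{global\_pid} &= \mathrm{cube\_id}\cdot\mathrm{num\_pes\_per\_cube} + \mathrm{local\_pe\_id} \\
&= 0\cdot\mathrm{num\_pes\_per\_cube} + \mathrm{local\_pe\_id} && (\mathrm{num\_cubes}=1 \Rightarrow \mathrm{cube\_id}=0) \\
&= \mathrm{local\_pe\_id}
\end{aligned}
```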

Usage Example

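def sharded_gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl):
    local_pid = tl.program_id(axis=0)      # PE within cube
    cube_id   = tl.program_id(axis=1)      # which cube
    global_pid = cube_id * tl.num_programs(axis=0) + local_pid

    # Column-wise sharding across global PID. Integer division: if N is not a
    # multiple of the total program count, the last N % total columns are not assigned.
    n_per_pid = N // (tl.num_programs(axis=1) * tl.num_programs(axis=0))
    col_start = global_pid * n_per_pid

    # Every PID loads all of A and works on its own column slice of B and out.
    # Offsets are in bytes (f16 = 2 bytes) and imply B and out are laid out
    # column by column (K and M elements per column, respectively).
    a = tl.load(a_ptr, shape=(M, K), dtype="f16")
    b = tl.ref(b_ptr + col_start * K * 2, shape=(K, n_per_pid), dtype="f16")
    h = tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr + col_start * M * 2)
    tl.wait(h)                             # wait on the composite GEMM handle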
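The example divides N columns evenly with integer division, so when N is not a multiple of the total program count, the trailing N % total columns are never assigned to any PID. The pure-Python sketch below reproduces only the index arithmetic (no tl calls) to make that visible; `column_shards` and `unassigned_columns` are hypothetical helpers.

```python
def column_shards(N: int, num_cubes: int, num_pes_per_cube: int) -> list[range]:
    """Column range owned by each global PID under the example's sharding."""
    total = num_cubes * num_pes_per_cube
    n_per_pid = N // total
    shards = []
    for cube_id in range(num_cubes):
        for local_pid in range(num_pes_per_cube):
            global_pid = cube_id * num_pes_per_cube + local_pid
            col_start = global_pid * n_per_pid
            shards.append(range(col_start, col_start + n_per_pid))
    return shards


def unassigned_columns(N: int, num_cubes: int, num_pes_per_cube: int) -> list[int]:
    """Columns in [0, N) that no PID covers."""
    covered = {c for r in column_shards(N, num_cubes, num_pes_per_cube) for c in r}
    return [c for c in range(N) if c not in covered]


print(unassigned_columns(64, 2, 4))  # → [] (8 columns per PID)
print(unassigned_columns(70, 2, 4))  # → [64, 65, 66, 67, 68, 69]
```

Kernels that must handle arbitrary N would need to give the remainder to some PID (for example, the last one); the example as written assumes divisibility.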

Consequences

  • Benchmarks can now express cube-aware sharding and addressing without hardcoding topology dimensions.
  • Future axis=2 (SIP-level) can be added following the same pattern if needed.
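One way the same pattern could extend to a third axis is a mixed-radix flattening with each new level as a more significant digit. This is a hypothetical sketch of a possible future design, not something this ADR decides; the innermost-to-outermost ordering (PE, cube, SIP) and the `flatten_pid` name are assumptions.

```python
from collections.abc import Sequence


def flatten_pid(ids: Sequence[int], sizes: Sequence[int]) -> int:
    """Mixed-radix flattening; ids[0]/sizes[0] is the innermost axis.

    2D case: ids = (local_pe_id, cube_id), sizes = (pes_per_cube, num_cubes).
    3D case (hypothetical): append (sip_id,) and (num_sips,).
    """
    if len(ids) != len(sizes):
        raise ValueError("ids and sizes must have the same length")
    pid = 0
    # Walk from the outermost axis inward: pid = pid * size + index.
    for i, n in reversed(list(zip(ids, sizes))):
        if not 0 <= i < n:
            raise ValueError(f"index {i} out of range for axis size {n}")
        pid = pid * n + i
    return pid


print(flatten_pid((2, 1), (4, 2)))        # → 6, same as the 2D formula
print(flatten_pid((2, 1, 1), (4, 2, 3)))  # → 14 = (1 * 2 + 1) * 4 + 2
```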