Add 2D grid program_id semantics (ADR-0022)

tl.program_id(axis=0) returns local PE id within cube,
tl.program_id(axis=1) returns cube id. Enables cube-aware
sharding in benchmark kernels.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-09 16:49:56 -07:00
parent dc3fb02aed
commit ff2c677a9c
5 changed files with 138 additions and 3 deletions
+9 -1
View File
@@ -42,6 +42,9 @@ class PeCpuComponent(ComponentBase):
self._cube_idx = int(parts[1].replace("cube", ""))
except (IndexError, ValueError):
self._cube_idx = 0
# num_cubes from spec (for tl.program_id(axis=1))
spec = ctx.spec if ctx else {}
self._num_cubes = spec.get("system", {}).get("sips", {}).get("cubes_per_sip", 1)
def _find_shard(self, shards: tuple) -> Any:
"""Find shard matching this PE's (sip, cube, pe). Fallback to positional index."""
@@ -139,6 +142,7 @@ class PeCpuComponent(ComponentBase):
pe_idx=self._pe_idx,
sip_idx=self._sip_idx,
cube_idx=self._cube_idx,
num_cubes=self._num_cubes,
scheduler_id=scheduler_id,
out_ports=self.out_ports,
store=store,
@@ -155,7 +159,11 @@ class PeCpuComponent(ComponentBase):
)
from kernbench.triton_emu.tl_context import TLContext, run_kernel
tl = TLContext(pe_id=self._pe_idx, num_programs=num_programs, dispatch_cycles=0)
tl = TLContext(
pe_id=self._pe_idx, num_programs=num_programs,
cube_id=self._cube_idx, num_cubes=self._num_cubes,
dispatch_cycles=0,
)
run_kernel(kernel_fn, tl, *kernel_args)
commands = tl.commands
@@ -50,11 +50,13 @@ class KernelRunner:
scheduler_id: str,
out_ports: dict[str, simpy.Store],
store: MemoryStore | None = None,
num_cubes: int = 1,
) -> None:
self._pe_prefix = pe_prefix
self._pe_idx = pe_idx
self._sip_idx = sip_idx
self._cube_idx = cube_idx
self._num_cubes = num_cubes
self._scheduler_id = scheduler_id
self._out_ports = out_ports
self._store = store
@@ -83,6 +85,8 @@ class KernelRunner:
tl = TLContext(
pe_id=self._pe_idx,
num_programs=num_programs,
cube_id=self._cube_idx,
num_cubes=self._num_cubes,
dispatch_cycles=0,
runner=self,
)
+18 -2
View File
@@ -53,9 +53,13 @@ class TLContext:
num_programs: int = 1,
dispatch_cycles: int = 1,
runner: Any = None,
cube_id: int = 0,
num_cubes: int = 1,
) -> None:
self._pe_id = pe_id
self._num_programs = num_programs
self._cube_id = cube_id
self._num_cubes = num_cubes
self._dispatch_cycles = dispatch_cycles
self._commands: list[PeCommand] = []
self._handle_counter = 0
@@ -234,11 +238,23 @@ class TLContext:
# ── Index / Scalar (PE_CPU, no engine) ────────────────────────
def program_id(self, axis: int = 0) -> int:
"""Return program instance index."""
"""Return program instance index.
axis=0: local PE id within cube.
axis=1: cube id.
"""
if axis == 1:
return self._cube_id
return self._pe_id
def num_programs(self, axis: int = 0) -> int:
"""Return total number of program instances."""
"""Return total number of program instances.
axis=0: num PEs per cube.
axis=1: num cubes.
"""
if axis == 1:
return self._num_cubes
return self._num_programs
def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle: