Add 2D grid program_id semantics (ADR-0022)
tl.program_id(axis=0) returns local PE id within cube, tl.program_id(axis=1) returns cube id. Enables cube-aware sharding in benchmark kernels. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -50,11 +50,13 @@ class KernelRunner:
|
||||
scheduler_id: str,
|
||||
out_ports: dict[str, simpy.Store],
|
||||
store: MemoryStore | None = None,
|
||||
num_cubes: int = 1,
|
||||
) -> None:
|
||||
self._pe_prefix = pe_prefix
|
||||
self._pe_idx = pe_idx
|
||||
self._sip_idx = sip_idx
|
||||
self._cube_idx = cube_idx
|
||||
self._num_cubes = num_cubes
|
||||
self._scheduler_id = scheduler_id
|
||||
self._out_ports = out_ports
|
||||
self._store = store
|
||||
@@ -83,6 +85,8 @@ class KernelRunner:
|
||||
tl = TLContext(
|
||||
pe_id=self._pe_idx,
|
||||
num_programs=num_programs,
|
||||
cube_id=self._cube_idx,
|
||||
num_cubes=self._num_cubes,
|
||||
dispatch_cycles=0,
|
||||
runner=self,
|
||||
)
|
||||
|
||||
@@ -53,9 +53,13 @@ class TLContext:
|
||||
num_programs: int = 1,
|
||||
dispatch_cycles: int = 1,
|
||||
runner: Any = None,
|
||||
cube_id: int = 0,
|
||||
num_cubes: int = 1,
|
||||
) -> None:
|
||||
self._pe_id = pe_id
|
||||
self._num_programs = num_programs
|
||||
self._cube_id = cube_id
|
||||
self._num_cubes = num_cubes
|
||||
self._dispatch_cycles = dispatch_cycles
|
||||
self._commands: list[PeCommand] = []
|
||||
self._handle_counter = 0
|
||||
@@ -234,11 +238,23 @@ class TLContext:
|
||||
# ── Index / Scalar (PE_CPU, no engine) ────────────────────────
|
||||
|
||||
def program_id(self, axis: int = 0) -> int:
|
||||
"""Return program instance index."""
|
||||
"""Return program instance index.
|
||||
|
||||
axis=0: local PE id within cube.
|
||||
axis=1: cube id.
|
||||
"""
|
||||
if axis == 1:
|
||||
return self._cube_id
|
||||
return self._pe_id
|
||||
|
||||
def num_programs(self, axis: int = 0) -> int:
|
||||
"""Return total number of program instances."""
|
||||
"""Return total number of program instances.
|
||||
|
||||
axis=0: num PEs per cube.
|
||||
axis=1: num cubes.
|
||||
"""
|
||||
if axis == 1:
|
||||
return self._num_cubes
|
||||
return self._num_programs
|
||||
|
||||
def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle:
|
||||
|
||||
Reference in New Issue
Block a user