Fix cross-SIP PE_TCM access by scoping deploy to target_device SIP

RuntimeContext._ensure_allocators() now limits SIP range to
target_device (single SIP or all). Prevents cross-SIP tensor
deployment that caused PE_TCM routing errors.
Also accept 'sip0' format (without colon) in DeviceSelector.

331 passed, 8 skipped

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-04 18:03:11 -07:00
parent 624161f52f
commit 08256c1326
5 changed files with 19 additions and 11 deletions
+14 -3
View File
@@ -173,7 +173,7 @@ class RuntimeContext:
pe_comps = pe_template.get("components", {})
tcm_cfg = pe_comps.get("pe_tcm", {}).get("attrs", {})
sip_count = system.get("sips", {}).get("count", 1)
total_sip_count = system.get("sips", {}).get("count", 1)
cubes_per_sip = system.get("sips", {}).get("cubes_per_sip", 16)
pes_per_cube = (
cube.get("pe_layout", {}).get("pe_per_corner", 2)
@@ -183,6 +183,17 @@ class RuntimeContext:
hbm_slices = mm.get("hbm_slices_per_cube", 8)
tcm_mb = tcm_cfg.get("size_mb", 16)
# Scope to target_device: single SIP or all SIPs
from kernbench.runtime_api.types import DeviceSelector, resolve_device
td = self.target_device if isinstance(self.target_device, DeviceSelector) else resolve_device(str(self.target_device))
if td.is_all:
sip_range = range(total_sip_count)
sip_count = total_sip_count
else:
sip_idx = td.sip_index
sip_range = range(sip_idx, sip_idx + 1)
sip_count = 1
cfg = AddressConfig(
sip_count=sip_count,
cubes_per_sip=cubes_per_sip,
@@ -193,13 +204,13 @@ class RuntimeContext:
tcm_scheduler_reserved_bytes=4 * (1 << 20),
sram_bytes_per_cube=32 * (1 << 20),
)
# Create allocators for all SIPs × cubes × PEs
# Create allocators scoped to target SIP(s) only
# Flat index: sip_id * cubes_per_sip * pes_per_cube + cube_id * pes_per_cube + pe_id
self._pes_per_cube = pes_per_cube
self._num_cubes = cubes_per_sip
self._num_sips = sip_count
cubes_x_pes = cubes_per_sip * pes_per_cube
for sip_id in range(sip_count):
for sip_id in sip_range:
for cube_id in range(cubes_per_sip):
for pe_id in range(pes_per_cube):
flat_idx = sip_id * cubes_x_pes + cube_id * pes_per_cube + pe_id
+3 -2
View File
@@ -41,7 +41,7 @@ class DeviceSelector:
def sip_index(self) -> int:
if self.is_all:
raise ValueError("DeviceSelector is 'all'; no single sip_index.")
m = re.fullmatch(r"sip:(\d+)", self.raw)
m = re.fullmatch(r"sip:?(\d+)", self.raw)
if not m:
raise ValueError(
f"Invalid device '{self.raw}'. Expected 'all' or 'sip:<N>' (e.g., sip:0)."
@@ -64,8 +64,9 @@ def resolve_device(raw: str | None) -> DeviceSelector:
if raw == "all":
return DeviceSelector(raw="all")
m = re.fullmatch(r"sip:(\d+)", raw)
m = re.fullmatch(r"sip:?(\d+)", raw)
if not m:
raise ValueError(f"Invalid device '{raw}'. Expected 'all' or 'sip:<N>' (e.g., sip:0).")
raw = f"sip:{m.group(1)}" # normalize to sip:N format
return DeviceSelector(raw=raw)