Fix cross-SIP PE_TCM access by scoping deploy to target_device SIP
RuntimeContext._ensure_allocators() now limits the SIP range to target_device (a single SIP or all SIPs). This prevents the cross-SIP tensor deployment that caused PE_TCM routing errors. DeviceSelector also accepts the 'sip0' format (without a colon). Tests: 331 passed, 8 skipped. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -173,7 +173,7 @@ class RuntimeContext:
|
||||
pe_comps = pe_template.get("components", {})
|
||||
tcm_cfg = pe_comps.get("pe_tcm", {}).get("attrs", {})
|
||||
|
||||
sip_count = system.get("sips", {}).get("count", 1)
|
||||
total_sip_count = system.get("sips", {}).get("count", 1)
|
||||
cubes_per_sip = system.get("sips", {}).get("cubes_per_sip", 16)
|
||||
pes_per_cube = (
|
||||
cube.get("pe_layout", {}).get("pe_per_corner", 2)
|
||||
@@ -183,6 +183,17 @@ class RuntimeContext:
|
||||
hbm_slices = mm.get("hbm_slices_per_cube", 8)
|
||||
tcm_mb = tcm_cfg.get("size_mb", 16)
|
||||
|
||||
# Scope to target_device: single SIP or all SIPs
|
||||
from kernbench.runtime_api.types import DeviceSelector, resolve_device
|
||||
td = self.target_device if isinstance(self.target_device, DeviceSelector) else resolve_device(str(self.target_device))
|
||||
if td.is_all:
|
||||
sip_range = range(total_sip_count)
|
||||
sip_count = total_sip_count
|
||||
else:
|
||||
sip_idx = td.sip_index
|
||||
sip_range = range(sip_idx, sip_idx + 1)
|
||||
sip_count = 1
|
||||
|
||||
cfg = AddressConfig(
|
||||
sip_count=sip_count,
|
||||
cubes_per_sip=cubes_per_sip,
|
||||
@@ -193,13 +204,13 @@ class RuntimeContext:
|
||||
tcm_scheduler_reserved_bytes=4 * (1 << 20),
|
||||
sram_bytes_per_cube=32 * (1 << 20),
|
||||
)
|
||||
# Create allocators for all SIPs × cubes × PEs
|
||||
# Create allocators scoped to target SIP(s) only
|
||||
# Flat index: sip_id * cubes_per_sip * pes_per_cube + cube_id * pes_per_cube + pe_id
|
||||
self._pes_per_cube = pes_per_cube
|
||||
self._num_cubes = cubes_per_sip
|
||||
self._num_sips = sip_count
|
||||
cubes_x_pes = cubes_per_sip * pes_per_cube
|
||||
for sip_id in range(sip_count):
|
||||
for sip_id in sip_range:
|
||||
for cube_id in range(cubes_per_sip):
|
||||
for pe_id in range(pes_per_cube):
|
||||
flat_idx = sip_id * cubes_x_pes + cube_id * pes_per_cube + pe_id
|
||||
|
||||
@@ -41,7 +41,7 @@ class DeviceSelector:
|
||||
def sip_index(self) -> int:
|
||||
if self.is_all:
|
||||
raise ValueError("DeviceSelector is 'all'; no single sip_index.")
|
||||
m = re.fullmatch(r"sip:(\d+)", self.raw)
|
||||
m = re.fullmatch(r"sip:?(\d+)", self.raw)
|
||||
if not m:
|
||||
raise ValueError(
|
||||
f"Invalid device '{self.raw}'. Expected 'all' or 'sip:<N>' (e.g., sip:0)."
|
||||
@@ -64,8 +64,9 @@ def resolve_device(raw: str | None) -> DeviceSelector:
|
||||
if raw == "all":
|
||||
return DeviceSelector(raw="all")
|
||||
|
||||
m = re.fullmatch(r"sip:(\d+)", raw)
|
||||
m = re.fullmatch(r"sip:?(\d+)", raw)
|
||||
if not m:
|
||||
raise ValueError(f"Invalid device '{raw}'. Expected 'all' or 'sip:<N>' (e.g., sip:0).")
|
||||
raw = f"sip:{m.group(1)}" # normalize to sip:N format
|
||||
|
||||
return DeviceSelector(raw=raw)
|
||||
|
||||
Reference in New Issue
Block a user