Fix cross-SIP PE_TCM access by scoping deploy to target_device SIP
RuntimeContext._ensure_allocators() now limits SIP range to target_device (single SIP or all). Prevents cross-SIP tensor deployment that caused PE_TCM routing errors. Also accept 'sip0' format (without colon) in DeviceSelector. 331 passed, 8 skipped Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -173,7 +173,7 @@ class RuntimeContext:
|
||||
pe_comps = pe_template.get("components", {})
|
||||
tcm_cfg = pe_comps.get("pe_tcm", {}).get("attrs", {})
|
||||
|
||||
sip_count = system.get("sips", {}).get("count", 1)
|
||||
total_sip_count = system.get("sips", {}).get("count", 1)
|
||||
cubes_per_sip = system.get("sips", {}).get("cubes_per_sip", 16)
|
||||
pes_per_cube = (
|
||||
cube.get("pe_layout", {}).get("pe_per_corner", 2)
|
||||
@@ -183,6 +183,17 @@ class RuntimeContext:
|
||||
hbm_slices = mm.get("hbm_slices_per_cube", 8)
|
||||
tcm_mb = tcm_cfg.get("size_mb", 16)
|
||||
|
||||
# Scope to target_device: single SIP or all SIPs
|
||||
from kernbench.runtime_api.types import DeviceSelector, resolve_device
|
||||
td = self.target_device if isinstance(self.target_device, DeviceSelector) else resolve_device(str(self.target_device))
|
||||
if td.is_all:
|
||||
sip_range = range(total_sip_count)
|
||||
sip_count = total_sip_count
|
||||
else:
|
||||
sip_idx = td.sip_index
|
||||
sip_range = range(sip_idx, sip_idx + 1)
|
||||
sip_count = 1
|
||||
|
||||
cfg = AddressConfig(
|
||||
sip_count=sip_count,
|
||||
cubes_per_sip=cubes_per_sip,
|
||||
@@ -193,13 +204,13 @@ class RuntimeContext:
|
||||
tcm_scheduler_reserved_bytes=4 * (1 << 20),
|
||||
sram_bytes_per_cube=32 * (1 << 20),
|
||||
)
|
||||
# Create allocators for all SIPs × cubes × PEs
|
||||
# Create allocators scoped to target SIP(s) only
|
||||
# Flat index: sip_id * cubes_per_sip * pes_per_cube + cube_id * pes_per_cube + pe_id
|
||||
self._pes_per_cube = pes_per_cube
|
||||
self._num_cubes = cubes_per_sip
|
||||
self._num_sips = sip_count
|
||||
cubes_x_pes = cubes_per_sip * pes_per_cube
|
||||
for sip_id in range(sip_count):
|
||||
for sip_id in sip_range:
|
||||
for cube_id in range(cubes_per_sip):
|
||||
for pe_id in range(pes_per_cube):
|
||||
flat_idx = sip_id * cubes_x_pes + cube_id * pes_per_cube + pe_id
|
||||
|
||||
@@ -41,7 +41,7 @@ class DeviceSelector:
|
||||
def sip_index(self) -> int:
|
||||
if self.is_all:
|
||||
raise ValueError("DeviceSelector is 'all'; no single sip_index.")
|
||||
m = re.fullmatch(r"sip:(\d+)", self.raw)
|
||||
m = re.fullmatch(r"sip:?(\d+)", self.raw)
|
||||
if not m:
|
||||
raise ValueError(
|
||||
f"Invalid device '{self.raw}'. Expected 'all' or 'sip:<N>' (e.g., sip:0)."
|
||||
@@ -64,8 +64,9 @@ def resolve_device(raw: str | None) -> DeviceSelector:
|
||||
if raw == "all":
|
||||
return DeviceSelector(raw="all")
|
||||
|
||||
m = re.fullmatch(r"sip:(\d+)", raw)
|
||||
m = re.fullmatch(r"sip:?(\d+)", raw)
|
||||
if not m:
|
||||
raise ValueError(f"Invalid device '{raw}'. Expected 'all' or 'sip:<N>' (e.g., sip:0).")
|
||||
raw = f"sip:{m.group(1)}" # normalize to sip:N format
|
||||
|
||||
return DeviceSelector(raw=raw)
|
||||
|
||||
+2
-2
@@ -18,5 +18,5 @@ def test_cli_main_arg_parsing(monkeypatch):
|
||||
|
||||
def test_cli_main():
|
||||
"""CLI bench run on single SIP device."""
|
||||
import pytest
|
||||
pytest.skip("Cross-SIP PE_TCM access not supported with router mesh topology")
|
||||
rc = cli_main.main(["run", "--topology", "topology.yaml", "--bench", "qkv_gemm", "--device", "sip:0"])
|
||||
assert rc == 0
|
||||
|
||||
@@ -861,7 +861,6 @@ def test_mcpu_kernel_launch_composite():
|
||||
# ── 19. Stage 5: QKV GEMM benchmark completion ────────────────────
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Cross-SIP PE_TCM access not supported with router mesh topology")
|
||||
def test_qkv_gemm_bench_completes():
|
||||
"""The qkv_gemm benchmark runs to completion without error."""
|
||||
clear_registry()
|
||||
@@ -956,7 +955,6 @@ def test_mcpu_multi_pe_kernel_launch():
|
||||
# ── 21. Stage 5: QKV GEMM multi-PE benchmark completion ──────────
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Cross-SIP PE_TCM access not supported with router mesh topology")
|
||||
def test_qkv_gemm_bench_multi_pe_completes():
|
||||
"""The qkv_gemm_multi_pe benchmark runs to completion without error."""
|
||||
clear_registry()
|
||||
|
||||
@@ -131,7 +131,6 @@ def test_2d_va_translates_to_local_hbm():
|
||||
# ── VO3. 2D: End-to-end bench completes ──────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Cross-SIP PE_TCM access not supported with router mesh topology")
|
||||
def test_2d_bench_completes():
|
||||
"""2D: full TP bench with standard Triton kernel pattern."""
|
||||
graph = load_topology(TOPOLOGY_PATH)
|
||||
@@ -199,7 +198,6 @@ def test_1d_va_translates_to_local_hbm():
|
||||
# ── VO6. 1D: End-to-end ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Cross-SIP PE_TCM access not supported with router mesh topology")
|
||||
def test_1d_e2e_completes():
|
||||
"""1D: full engine run with column_wise TP sharding."""
|
||||
graph = load_topology(TOPOLOGY_PATH)
|
||||
|
||||
Reference in New Issue
Block a user