Files
kernbench2/src/kernbench/policy/routing/router.py
T
2026-03-18 11:47:48 -07:00

185 lines
7.5 KiB
Python

from __future__ import annotations
import heapq
from collections import defaultdict
from kernbench.policy.address.phyaddr import PhysAddr, UnitType
from kernbench.topology.types import TopologyGraph
class RoutingError(Exception):
    """Raised when an address or node cannot be mapped onto the topology graph,
    or when no path exists between two nodes."""
class AddressResolver:
    """Map physical addresses and named components to topology node ids.

    ``resolve`` turns a ``PhysAddr`` into the id of the node it targets, and
    the ``find_*`` helpers return the ids of well-known per-SIP / per-cube
    components. Every node-id format string lives in this one class, so
    component code never assembles node ids itself and a naming change only
    has to be made here (ADR-0015 D4).

    Each lookup checks the candidate id against the graph's node set and
    raises ``RoutingError`` when it is absent.
    """

    def __init__(self, graph: TopologyGraph) -> None:
        # Snapshot of every node id, for constant-time existence checks.
        self._node_ids = set(graph.nodes)
        memory_map = graph.spec["cube"]["memory_map"]
        # HBM bytes covered by one slice: per-cube capacity (GiB -> bytes)
        # divided evenly across the cube's slices.
        cube_hbm_bytes = memory_map["hbm_total_gb_per_cube"] * (1 << 30)
        self._slice_size_bytes = cube_hbm_bytes // memory_map["hbm_slices_per_cube"]

    # ── Physical-address resolution ──────────────────────────────────
    def resolve(self, addr: PhysAddr) -> str:
        """Return the node id targeted by *addr*.

        ``"hbm"`` addresses map to ``sip{S}.cube{C}.hbm_ctrl.slice{N}``, with the
        slice index computed by ``PhysAddr.hbm_pe_id``. ``"pe_resource"``
        addresses map to the PE's TCM, the cube SRAM, or the cube M_CPU,
        depending on ``addr.unit_type``.

        Raises:
            RoutingError: unsupported address kind or unit type, or the
                computed node id does not exist in the graph.
        """
        candidate = self._candidate_node(addr)
        if candidate not in self._node_ids:
            raise RoutingError(f"node {candidate} not found in topology")
        return candidate

    def _candidate_node(self, addr: PhysAddr) -> str:
        """Build, without checking existence, the node id that *addr* points at."""
        cube_prefix = f"sip{addr.sip_id}.cube{addr.cube_id}"
        if addr.kind == "hbm":
            slice_no = PhysAddr.hbm_pe_id(addr.hbm_offset, self._slice_size_bytes)
            return f"{cube_prefix}.hbm_ctrl.slice{slice_no}"
        if addr.kind != "pe_resource":
            raise RoutingError(f"unsupported address kind: {addr.kind}")
        unit = addr.unit_type
        if unit == UnitType.PE:
            return f"{cube_prefix}.pe{addr.pe_id}.pe_tcm"
        if unit == UnitType.SRAM:
            return f"{cube_prefix}.sram"
        if unit == UnitType.MCPU:
            return f"{cube_prefix}.m_cpu"
        raise RoutingError(f"unsupported unit_type: {addr.unit_type}")

    # ── Named node lookups ───────────────────────────────────────────
    def _require(self, node_id: str, label: str) -> str:
        """Return *node_id* if the graph contains it; otherwise raise RoutingError."""
        if node_id in self._node_ids:
            return node_id
        raise RoutingError(f"{label} not found: {node_id}")

    def find_m_cpu(self, sip: int, cube: int) -> str:
        """Return the M_CPU node id of cube *cube* on SIP *sip*."""
        return self._require(f"sip{sip}.cube{cube}.m_cpu", "M_CPU")

    def find_pcie_ep(self, sip: int, io_id: str = "io0") -> str:
        """Return the PCIe endpoint node id of IO die *io_id* on SIP *sip*."""
        return self._require(f"sip{sip}.{io_id}.pcie_ep", "PCIE_EP")

    def find_io_cpu(self, sip: int, io_id: str = "io0") -> str:
        """Return the IO_CPU node id of IO die *io_id* on SIP *sip*."""
        return self._require(f"sip{sip}.{io_id}.io_cpu", "IO_CPU")

    def find_all_pcie_eps(self) -> list[str]:
        """Return every node id ending in ``.pcie_ep``, in sorted order."""
        endpoints = [nid for nid in self._node_ids if nid.endswith(".pcie_ep")]
        endpoints.sort()
        return endpoints
class PathRouter:
    """Find data paths from a source PE (or arbitrary node) to a destination node.

    Three directed adjacency maps (node id -> list of ``(neighbor, weight)``)
    are built once from ``graph.edges``:

      _adj          — excludes ``"command"`` edges; used by PE DMA routing
                      (``find_path`` / ``find_path_with_distance``).
      _adj_all      — includes every edge; used by component-to-component
                      routing (``find_node_path``), which needs the M_CPU↔NOC
                      links whose kind is ``"command"``.
      _adj_mcpu_dma — excludes the kinds in ``_MCPU_DMA_EXCLUDE`` but keeps
                      ``"command"`` edges; used for cross-cube M_CPU DMA
                      routing (``find_mcpu_dma_path``).

    An edge's weight is its ``routing_weight_mm`` when set, otherwise its
    ``distance_mm``. Routing queries only read these maps; they never add
    entries to them.
    """

    # Edge kinds excluded from M_CPU DMA adjacency: prevents routing through
    # PE-internal pipeline nodes when computing DMA paths. A frozenset so the
    # shared class-level constant cannot be mutated by accident.
    _MCPU_DMA_EXCLUDE = frozenset({"pe_internal", "pe_to_xbar"})

    def __init__(self, graph: TopologyGraph) -> None:
        self._adj: dict[str, list[tuple[str, float]]] = defaultdict(list)
        self._adj_all: dict[str, list[tuple[str, float]]] = defaultdict(list)
        self._adj_mcpu_dma: dict[str, list[tuple[str, float]]] = defaultdict(list)
        for e in graph.edges:
            # Prefer the explicit routing weight; fall back to physical distance.
            w = e.routing_weight_mm if e.routing_weight_mm is not None else e.distance_mm
            self._adj_all[e.src].append((e.dst, w))
            if e.kind != "command":
                self._adj[e.src].append((e.dst, w))
            if e.kind not in self._MCPU_DMA_EXCLUDE:
                self._adj_mcpu_dma[e.src].append((e.dst, w))

    def find_path(self, src_pe: str, dst_node: str) -> list[str]:
        """PE DMA routing: start at ``f"{src_pe}.pe_dma"`` and avoid command edges.

        Returns the node ids from the PE's DMA node to *dst_node*, inclusive.

        Raises:
            RoutingError: *dst_node* is unreachable.
        """
        start = f"{src_pe}.pe_dma"
        return self._run_dijkstra(self._adj, start, dst_node)

    def find_path_with_distance(self, src_pe: str, dst_node: str) -> tuple[list[str], float]:
        """Like :meth:`find_path`, but also return the summed edge weight."""
        start = f"{src_pe}.pe_dma"
        return self._run_dijkstra_with_dist(self._adj, start, dst_node)

    def find_mcpu_dma_path(self, m_cpu_id: str, dst_hbm_slice_id: str) -> list[str]:
        """M_CPU DMA path: never routes through PE-internal nodes (ADR-0015 D5).

        "Cube" here means the first two dot-separated components of a node id
        (e.g. ``sip0.cube1``).

        Same-cube: fixed path ``[m_cpu, noc, xbar.pe_i, hbm_ctrl.slice_i]``,
        where ``i`` is the integer after the last ``"slice"`` in
        *dst_hbm_slice_id*. NOTE(review): these hops are not checked against
        the graph.

        Cross-cube: Dijkstra over ``_adj_mcpu_dma`` (pe_internal / pe_to_xbar
        excluded), i.e. through NOC → UCIe → target cube NOC → xbar → HBM.

        Raises:
            RoutingError: same-cube destination without a ``slice<N>`` suffix,
                or no cross-cube path exists.
        """
        m_cube = ".".join(m_cpu_id.split(".")[:2])
        d_cube = ".".join(dst_hbm_slice_id.split(".")[:2])
        if m_cube != d_cube:
            return self._run_dijkstra(self._adj_mcpu_dma, m_cpu_id, dst_hbm_slice_id)
        _, sep, suffix = dst_hbm_slice_id.rpartition("slice")
        if not sep:
            raise RoutingError(f"not an HBM slice node: {dst_hbm_slice_id}")
        try:
            slice_idx = int(suffix)
        except ValueError as err:
            raise RoutingError(f"not an HBM slice node: {dst_hbm_slice_id}") from err
        return [
            m_cpu_id,
            f"{m_cube}.noc",
            f"{m_cube}.xbar.pe{slice_idx}",
            dst_hbm_slice_id,
        ]

    def find_node_path(self, src: str, dst: str) -> list[str]:
        """General routing between arbitrary nodes, including command edges.

        Used by components (IoCpuComponent, MCpuComponent) that route through
        M_CPU↔NOC command-kind links.
        """
        return self._run_dijkstra(self._adj_all, src, dst)

    def _run_dijkstra(
        self,
        adj: dict[str, list[tuple[str, float]]],
        start: str,
        goal: str,
    ) -> list[str]:
        """Return only the node list of the cheapest *start* → *goal* path in *adj*."""
        path, _ = self._run_dijkstra_with_dist(adj, start, goal)
        return path

    def _run_dijkstra_with_dist(
        self,
        adj: dict[str, list[tuple[str, float]]],
        start: str,
        goal: str,
    ) -> tuple[list[str], float]:
        """Return ``(path, cost)`` of the cheapest *start* → *goal* path in *adj*.

        Heap entries are ``(cost, node_id)`` tuples, so ties on cost are broken
        by node id and results are deterministic. Dijkstra assumes non-negative
        weights. NOTE(review): weights are not validated here.

        Raises:
            RoutingError: *goal* is unreachable from *start*.
        """
        if start == goal:
            return [start], 0.0
        best: dict[str, float] = {start: 0.0}
        prev: dict[str, str] = {}
        heap: list[tuple[float, str]] = [(0.0, start)]
        while heap:
            d, node = heapq.heappop(heap)
            if node == goal:
                # The first pop of goal carries its minimal cost; walk prev back.
                path: list[str] = []
                cur = goal
                while cur != start:
                    path.append(cur)
                    cur = prev[cur]
                path.append(start)
                path.reverse()
                return path, d
            if d > best.get(node, float("inf")):
                continue  # stale heap entry superseded by a cheaper one
            # .get() rather than [] so the defaultdict does not grow an empty
            # entry for every sink node or unknown start node we expand.
            for neighbor, edge_dist in adj.get(node, ()):
                new_d = d + edge_dist
                if new_d < best.get(neighbor, float("inf")):
                    best[neighbor] = new_d
                    prev[neighbor] = node
                    heapq.heappush(heap, (new_d, neighbor))
        raise RoutingError(f"no path from {start} to {goal}")

    # ── backward-compat shims (used by existing tests) ───────────────
    def _dijkstra(self, start: str, goal: str) -> list[str]:
        return self._run_dijkstra(self._adj, start, goal)

    def _dijkstra_with_dist(self, start: str, goal: str) -> tuple[list[str], float]:
        return self._run_dijkstra_with_dist(self._adj, start, goal)