from __future__ import annotations

from collections.abc import Generator
from typing import TYPE_CHECKING, Any

import simpy

from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction

if TYPE_CHECKING:
    from kernbench.components.context import ComponentContext
    from kernbench.topology.types import Node


class IoCpuComponent(ComponentBase):
    """IO_CPU component: multi-cube fan-out with response aggregation.

    Forward path:
      1. Applies overhead_ns processing overhead.
      2. Resolves target cube(s) from request.target_cubes.
      3. Fans out sub-Transactions to each target cube's M_CPU.

    Response path:
      Collects ResponseMsg from each M_CPU.
      When all cube responses are received, succeeds the parent txn.done.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Pending fan-out tracking: request_id → (expected, received, parent_txn_done)
        self._pending: dict[str, tuple[int, int, simpy.Event]] = {}

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Model the per-request processing overhead of this IO_CPU.

        Reads ``overhead_ns`` from the node attributes (default 0.0) and
        yields a timeout of that duration. ``nbytes`` is accepted for
        interface compatibility but is not used here.
        """
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Main loop: pull transactions from the inbox forever.

        Responses (``is_response`` truthy) are fed to the aggregator;
        everything else pays the processing overhead and is then fanned out
        in a separate process so the worker can keep draining the inbox.
        """
        while True:
            txn: Any = yield self._inbox.get()
            if getattr(txn, "is_response", False):
                self._collect_response(txn)
            else:
                yield from self.run(env, txn.nbytes)
                env.process(self._dispatch_to_m_cpus(env, txn))

    def _collect_response(self, resp_txn: Any) -> None:
        """Receive a cube response and increment the aggregation counter.

        Responses whose request_id is not being tracked are ignored. When
        the received count reaches the expected count the parent ``done``
        event is succeeded and the tracking entry is removed.
        """
        key = resp_txn.request.request_id
        if key not in self._pending:
            return
        expected, received, parent_done = self._pending[key]
        received += 1
        if received >= expected:
            parent_done.succeed()
            del self._pending[key]
        else:
            self._pending[key] = (expected, received, parent_done)

    def _resolve_m_cpu_routes(
        self, cube_targets: list[tuple[int, int]]
    ) -> list[tuple[int, int, Any, Any]]:
        """Return ``(sip, cube, m_cpu_id, path)`` for every routable target.

        Targets whose M_CPU or route lookup raises, or whose path has fewer
        than two nodes (no next hop), are skipped on a best-effort basis.
        """
        routes: list[tuple[int, int, Any, Any]] = []
        for sip, cube in cube_targets:
            try:
                m_cpu_id = self.ctx.resolver.find_m_cpu(sip, cube)
                path = self.ctx.router.find_node_path(self.node.id, m_cpu_id)
            except Exception:
                continue
            if len(path) < 2:
                continue
            routes.append((sip, cube, m_cpu_id, path))
        return routes

    def _max_launch_latency_ns(
        self, request: Any, routes: list[tuple[int, int, Any, Any]]
    ) -> float:
        """Return the max IO_CPU → PE_CPU path latency over all routed cubes.

        For each routed cube and each targeted PE, the latency is the
        IO_CPU → M_CPU leg plus the M_CPU → PE_CPU leg, minus the IO_CPU and
        M_CPU node overheads. PEs whose route cannot be resolved are
        skipped. Returns 0.0 when no PE path could be resolved.
        """
        io_overhead = self.ctx.node_overhead_ns.get(self.node.id, 0.0)
        pe_ids = self._resolve_pe_ids(getattr(request, "target_pe", "all"))
        max_latency = 0.0
        for sip, cube, m_cpu_id, io_to_m_path in routes:
            leg1 = self.ctx.compute_path_latency_ns(io_to_m_path, nbytes=0)
            m_overhead = self.ctx.node_overhead_ns.get(m_cpu_id, 0.0)
            for pe_id in pe_ids:
                pe_cpu_id = f"sip{sip}.cube{cube}.pe{pe_id}.pe_cpu"
                try:
                    m_to_pe_path = self.ctx.router.find_node_path(
                        m_cpu_id, pe_cpu_id,
                    )
                except Exception:
                    continue
                if len(m_to_pe_path) < 2:
                    continue
                leg2 = self.ctx.compute_path_latency_ns(m_to_pe_path, nbytes=0)
                latency = leg1 + leg2 - io_overhead - m_overhead
                if latency > max_latency:
                    max_latency = latency
        return max_latency

    def _dispatch_to_m_cpus(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out sub-Transactions to target cube M_CPUs, wait for responses.

        ADR-0009 D5 (extended): for KernelLaunchMsg, stamp a single global
        target_start_ns = env.now + max(IO_CPU → any target PE_CPU path
        latency across all target cubes). M_CPU passes this value through
        unchanged; every PE in every cube yields until the same sim-time
        before beginning kernel execution. Without this, cross-cube launches
        would have each cube's M_CPU compute its own per-cube barrier
        relative to its local env.now, leaving PEs on different cubes out
        of sync (the "h3/h4 dispatch-offset artifact").

        The parent ``txn.done`` is succeeded immediately when target
        resolution fails, yields no targets, or yields no routable target.
        Otherwise the number of expected responses equals the number of
        sub-transactions actually dispatched, so an unroutable cube cannot
        leave the parent waiting forever.
        """
        import dataclasses

        from kernbench.runtime_api.kernel import KernelLaunchMsg

        request = txn.request
        try:
            cube_targets = self._resolve_cube_targets(request)
        except Exception:
            # Best-effort: an unresolvable request completes as a no-op.
            txn.done.succeed()
            return
        if not cube_targets:
            txn.done.succeed()
            return

        # Resolve every route up front so the aggregation count matches the
        # sub-transactions that will really be sent.
        routes = self._resolve_m_cpu_routes(cube_targets)
        if not routes:
            txn.done.succeed()
            return

        # For KernelLaunchMsg, compute the global barrier once here so
        # every downstream PE_CPU uses the same target_start_ns.
        is_kernel_launch = isinstance(request, KernelLaunchMsg)
        if is_kernel_launch:
            global_max_latency = self._max_launch_latency_ns(request, routes)
            request = dataclasses.replace(
                request,
                target_start_ns=float(env.now) + global_max_latency,
            )

        # Setup aggregation (registered before the first put so that early
        # responses are counted).
        self._pending[request.request_id] = (len(routes), 0, txn.done)

        # Fan out to each target cube's M_CPU. Kernel-launch fanout
        # carries control metadata only; nbytes is forced to 0 for
        # KernelLaunchMsg so the launch sub-txns do not occupy data-fabric
        # BW (would otherwise serialize 16 cubes worth of fanout on the
        # shared first hop and break ADR-0009 D5's barrier prediction).
        for _sip, _cube, _m_cpu_id, path in routes:
            sub_txn = Transaction(
                request=request,
                path=path,
                step=0,
                nbytes=0 if is_kernel_launch else txn.nbytes,
                done=env.event(),
                result_data=txn.result_data,
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())

    def _resolve_pe_ids(self, target_pe: Any) -> list[int]:
        """Resolve target_pe → list of PE indices (mirrors M_CPU logic).

        An int selects a single PE, a tuple selects the listed PEs, and any
        other value (e.g. "all") selects ``range(hbm_slices_per_cube)`` as
        read from the spec's cube memory map, defaulting to 8.
        """
        if isinstance(target_pe, int):
            return [target_pe]
        if isinstance(target_pe, tuple):
            return list(target_pe)
        # "all": all PEs in a cube
        n_slices = 8
        if self.ctx and self.ctx.spec:
            mm = self.ctx.spec.get("cube", {}).get("memory_map", {})
            n_slices = mm.get("hbm_slices_per_cube", 8)
        return list(range(n_slices))

    def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]:
        """Return list of (sip, cube) pairs to fan out to.

        Memory writes/reads target the cube decoded from the physical
        address (or the explicit target_cubes) on the request's SIP.
        Kernel launches target the explicit cubes, or the cubes holding
        tensor shards on this IO_CPU's SIP. MMU map/unmap target the
        explicit cubes, or every cube of this SIP. Any other request type
        yields an empty list.
        """
        from kernbench.runtime_api.kernel import (
            KernelLaunchMsg,
            MemoryReadMsg,
            MemoryWriteMsg,
            MmuMapMsg,
            MmuUnmapMsg,
        )

        target_cubes = getattr(request, "target_cubes", "all")

        if isinstance(request, MemoryWriteMsg):
            sip = request.dst_sip
            if target_cubes == "all":
                cube = self._cube_from_pa(request.dst_pa, fallback=request.dst_cube)
                return [(sip, cube)]
            return [(sip, c) for c in target_cubes]

        if isinstance(request, MemoryReadMsg):
            sip = request.src_sip
            if target_cubes == "all":
                cube = self._cube_from_pa(request.src_pa, fallback=request.src_cube)
                return [(sip, cube)]
            return [(sip, c) for c in target_cubes]

        if isinstance(request, KernelLaunchMsg):
            my_sip = self._my_sip()
            if target_cubes != "all":
                return [(my_sip, c) for c in target_cubes]
            # "all": derive from tensor shards, filtered to this SIP
            seen: set[tuple[int, int]] = set()
            targets: list[tuple[int, int]] = []
            for arg in request.args:
                if arg.arg_kind != "tensor":
                    continue
                for shard in arg.shards:
                    if shard.sip != my_sip:
                        continue
                    key = (shard.sip, shard.cube)
                    if key not in seen:
                        seen.add(key)
                        targets.append(key)
            return targets

        if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
            my_sip = self._my_sip()
            if target_cubes == "all":
                n_cubes = 16
                if self.ctx and self.ctx.spec:
                    sips = self.ctx.spec.get("system", {}).get("sips", {})
                    n_cubes = sips.get("cubes_per_sip", 16)
                return [(my_sip, c) for c in range(n_cubes)]
            return [(my_sip, c) for c in target_cubes]

        return []

    def _cube_from_pa(self, pa_val: int, fallback: int) -> int:
        """Extract die_id from a physical address, with fallback.

        Returns ``fallback`` when decoding the address raises.
        """
        from kernbench.policy.address.phyaddr import PhysAddr

        try:
            return PhysAddr.decode(pa_val).die_id
        except Exception:
            return fallback

    def _my_sip(self) -> int:
        """Extract this IO_CPU's SIP ID from its node ID (e.g. 'sip0.io0.io_cpu' → 0)."""
        return int(self.node.id.split(".")[0].replace("sip", ""))