from __future__ import annotations

from collections.abc import Generator
from typing import TYPE_CHECKING, Any

import simpy

from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction

if TYPE_CHECKING:
    from kernbench.components.context import ComponentContext
    from kernbench.topology.types import Node


class MCpuComponent(ComponentBase):
    """M_CPU component: multi-PE DMA fan-out with response aggregation.

    Forward path (ADR-0015 D5):
        When a forward Transaction arrives at m_cpu (terminal hop), M_CPU fans
        out DMA sub-Transactions to target PEs' HBM slices. ``target_pe`` on
        the request controls fan-out: int → single PE, "all" → all PEs in the
        cube.

    Response path:
        A ResponseMsg from each hbm_ctrl arrives back at m_cpu. Once all PE
        responses are collected, m_cpu sends an aggregate ResponseMsg on the
        reverse command path back to io_cpu.

    Transit:
        When m_cpu is NOT the terminal hop (transit or response relay), the
        Transaction is forwarded normally to the next hop.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Pending fan-out tracking: request_id → (expected, received, all_done_event)
        self._pending: dict[str, tuple[int, int, simpy.Event]] = {}
        # Parent txn of each in-flight fan-out: request_id → parent_txn
        self._parent_txns: dict[str, Any] = {}
        # DMA engine resources (ADR-0015 D5, ADR-0014 D4): capacity=1 each.
        # Created in start() because a simpy.Resource needs the environment.
        self._dma_write: simpy.Resource | None = None
        self._dma_read: simpy.Resource | None = None

    def start(self, env: simpy.Environment) -> None:
        """Create the DMA engine resources, then start the base component."""
        self._dma_write = simpy.Resource(env, capacity=1)
        self._dma_read = simpy.Resource(env, capacity=1)
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Model the fixed per-transaction processing overhead.

        The delay is the node's ``overhead_ns`` attribute (0.0 if absent);
        ``nbytes`` is accepted for interface compatibility but not used here.
        """
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch forward txns, collect response txns."""
        # Local import, presumably to avoid an import cycle with runtime_api.
        from kernbench.runtime_api.kernel import KernelLaunchMsg, MmuMapMsg, MmuUnmapMsg

        while True:
            txn: Any = yield self._inbox.get()
            if getattr(txn, "is_response", False):
                self._collect_response(txn)
                continue

            yield from self.run(env, txn.nbytes)
            next_hop = txn.next_hop
            if next_hop:
                # Transit: forward to the next hop unchanged.
                yield self.out_ports[next_hop].put(txn.advance())
            elif self.ctx is not None and txn.request is not None:
                # Terminal hop: spawn a fan-out process so the worker keeps
                # draining the inbox (including responses for that fan-out).
                if isinstance(txn.request, KernelLaunchMsg):
                    env.process(self._kernel_launch_fanout(env, txn))
                elif isinstance(txn.request, (MmuMapMsg, MmuUnmapMsg)):
                    env.process(self._mmu_msg_fanout(env, txn))
                else:
                    env.process(self._dma_fanout(env, txn))
            else:
                txn.done.succeed()

    def _collect_response(self, resp_txn: Any) -> None:
        """Receive a PE response and increment the aggregation counter.

        Responses whose ``request_id`` has no pending fan-out are ignored.
        When the received count reaches the expected count, the fan-out's
        ``all_done`` event fires and the entry is removed.
        """
        key = resp_txn.request.request_id
        if key not in self._pending:
            return
        expected, received, all_done = self._pending[key]
        received += 1
        if received >= expected:
            all_done.succeed()
            del self._pending[key]
        else:
            self._pending[key] = (expected, received, all_done)

    # ------------------------------------------------------------------
    # Aggregation bookkeeping
    # ------------------------------------------------------------------

    def _begin_aggregation(
        self, env: simpy.Environment, txn: Any, upper_bound: int
    ) -> simpy.Event:
        """Register a fan-out BEFORE any sub-txn is dispatched.

        Registering first guarantees that no response can arrive for an
        unknown request_id (which ``_collect_response`` would silently drop).
        ``upper_bound`` is the number of candidate destinations; it is
        corrected to the real dispatch count by ``_settle_expected``.
        """
        request_id = txn.request.request_id
        all_done = env.event()
        self._pending[request_id] = (upper_bound, 0, all_done)
        self._parent_txns[request_id] = txn
        return all_done

    def _settle_expected(self, request_id: str, n_dispatched: int) -> None:
        """Set the expected response count to the number actually dispatched.

        Some candidates may have been skipped (no route). If enough responses
        have already arrived, ``all_done`` is fired now. Otherwise it would
        never fire and the fan-out process would stall forever.
        Requires ``n_dispatched > 0``.
        """
        entry = self._pending.get(request_id)
        if entry is None:
            # Already completed: every candidate was dispatched and answered.
            return
        _, received, all_done = entry
        if received >= n_dispatched:
            del self._pending[request_id]
            all_done.succeed()
        else:
            self._pending[request_id] = (n_dispatched, received, all_done)

    def _abandon_aggregation(self, request_id: str) -> None:
        """Drop bookkeeping for a fan-out that dispatched nothing."""
        self._pending.pop(request_id, None)
        self._parent_txns.pop(request_id, None)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _cube_prefix(self) -> str:
        """Return this node's id without its last component (e.g. "sip0.cube0")."""
        return self.node.id.rsplit(".", 1)[0]

    def _cube_id(self) -> int:
        """Parse the cube index from the second dotted component of the node id.

        For example, "sip0.cube3.m_cpu" gives 3.
        """
        return int(self.node.id.split(".")[1].replace("cube", ""))

    def _slices_per_cube(self) -> int:
        """Return ``cube.memory_map.hbm_slices_per_cube`` from the spec (default 8)."""
        if self.ctx and self.ctx.spec:
            mm = self.ctx.spec.get("cube", {}).get("memory_map", {})
            return mm.get("hbm_slices_per_cube", 8)
        return 8

    def _send_aggregate_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Send one success ResponseMsg for ``txn`` along its reversed path.

        If the reversed path has fewer than two hops there is nowhere to send
        the response, so ``txn.done`` is succeeded directly instead.
        """
        reverse_path = list(reversed(txn.path))
        if len(reverse_path) < 2:
            txn.done.succeed()
            return

        from kernbench.runtime_api.kernel import ResponseMsg

        request = txn.request
        resp_msg = ResponseMsg(
            correlation_id=request.correlation_id,
            request_id=request.request_id,
            src_cube=self._cube_id(),
            src_pe=-1,
            success=True,
        )
        resp_txn = Transaction(
            request=resp_msg,
            path=reverse_path,
            step=0,
            nbytes=0,
            done=env.event(),
            is_response=True,
        )
        yield self.out_ports[reverse_path[1]].put(resp_txn.advance())

    # ------------------------------------------------------------------
    # Fan-out processes
    # ------------------------------------------------------------------

    def _dma_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out DMA sub-Transactions to target PE(s) and aggregate the responses.

        After all responses arrive, an aggregate response is sent on the
        reverse command path.

        Each DMA transfer acquires the DMA resource (capacity=1 per ADR-0014
        D4), so a multi-PE fan-out is serialized through the DMA engine.
        Destinations with no route are skipped. If none can be reached, the
        parent txn completes immediately instead of waiting forever.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg

        request = txn.request
        target_pe = getattr(request, "target_pe", "all")
        dst_nodes = self._resolve_dma_destinations(request, target_pe)
        if not dst_nodes:
            txn.done.succeed()
            return

        all_done = self._begin_aggregation(env, txn, len(dst_nodes))

        # Writes and reads use separate DMA engines.
        dma_res = self._dma_write if isinstance(request, MemoryWriteMsg) else self._dma_read

        max_drain_ns = 0.0
        n_dispatched = 0
        for dst_node in dst_nodes:
            try:
                dma_path = self.ctx.router.find_mcpu_dma_path(self.node.id, dst_node)
            except Exception:
                # Best-effort: an unroutable destination is skipped; the
                # expected count is corrected below via _settle_expected.
                continue
            if len(dma_path) < 2:
                continue
            drain_ns = self.ctx.compute_drain_ns(dma_path, txn.nbytes)
            max_drain_ns = max(max_drain_ns, drain_ns)
            sub_txn = Transaction(
                request=request,
                path=dma_path,
                step=0,
                nbytes=txn.nbytes,
                done=env.event(),
                drain_ns=drain_ns,
            )
            with dma_res.request() as req:
                yield req
                yield self.out_ports[dma_path[1]].put(sub_txn.advance())
            n_dispatched += 1

        if n_dispatched == 0:
            self._abandon_aggregation(request.request_id)
            txn.done.succeed()
            return
        self._settle_expected(request.request_id, n_dispatched)

        # Wait for all PE responses (counted by _collect_response).
        yield all_done
        txn.result_data["xfer_ns"] = max_drain_ns
        self._parent_txns.pop(request.request_id, None)

        yield from self._send_aggregate_response(env, txn)

    def _kernel_launch_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out a KernelLaunchMsg to the target PE_CPU(s) via the NOC (ADR-0009 D3).

        Routing uses find_node_path (M_CPU → NOC → PE_CPU command edges).
        Each PE_CPU sends a ResponseMsg back via NOC → M_CPU on completion.
        The per-PE metrics ``pe_exec_ns``, ``dma_ns`` and ``compute_ns`` are
        then max-aggregated into ``txn.result_data``, and an aggregate
        ResponseMsg is sent back to IO_CPU on the reverse path.
        """
        request = txn.request
        target_pe = getattr(request, "target_pe", "all")
        cube_prefix = self._cube_prefix()
        pe_ids = self._resolve_pe_ids(target_pe)
        if not pe_ids:
            txn.done.succeed()
            return

        # Register before dispatch so early PE_CPU responses are not dropped.
        all_done = self._begin_aggregation(env, txn, len(pe_ids))

        sub_txns: list[Transaction] = []
        for pe_id in pe_ids:
            pe_cpu_id = f"{cube_prefix}.pe{pe_id}.pe_cpu"
            try:
                path = self.ctx.router.find_node_path(self.node.id, pe_cpu_id)
            except Exception:
                # Best-effort: unreachable PE_CPUs are skipped.
                continue
            if len(path) < 2:
                continue
            sub_txn = Transaction(
                request=request,
                path=path,
                step=0,
                nbytes=0,
                done=env.event(),
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())
            sub_txns.append(sub_txn)

        if not sub_txns:
            self._abandon_aggregation(request.request_id)
            txn.done.succeed()
            return
        self._settle_expected(request.request_id, len(sub_txns))

        # Wait for all PE_CPU responses via NOC.
        yield all_done
        self._parent_txns.pop(request.request_id, None)

        # Aggregate PE-internal metrics (max across PEs; missing → 0.0).
        # NOTE(review): assumes PE_CPU writes these keys into the sub-txn's
        # result_data — confirm against the PE_CPU component.
        for metric in ("pe_exec_ns", "dma_ns", "compute_ns"):
            txn.result_data[metric] = max(st.result_data.get(metric, 0.0) for st in sub_txns)

        yield from self._send_aggregate_response(env, txn)

    def _resolve_dma_destinations(self, request: Any, target_pe: int | str) -> list[str]:
        """Return the list of HBM destination node_ids for a DMA fan-out.

        - int ``target_pe``: that slice in the local cube.
        - otherwise: PA-based resolution (``dst_pa``, falling back to
          ``src_pa``). This enables cross-cube DMA routing when the PA
          points to a remote cube.
        - if there is no PA, or decoding/resolution fails: all slices in
          the local cube.
        """
        cube_prefix = self._cube_prefix()
        if isinstance(target_pe, int):
            return [f"{cube_prefix}.hbm_ctrl.slice{target_pe}"]

        # Explicit None checks: physical address 0 is a valid address and
        # must not fall through to src_pa / the "all slices" default.
        pa_val = getattr(request, "dst_pa", None)
        if pa_val is None:
            pa_val = getattr(request, "src_pa", None)
        if pa_val is not None:
            from kernbench.policy.address.phyaddr import PhysAddr

            try:
                pa = PhysAddr.decode(pa_val)
                return [self.ctx.resolver.resolve(pa)]
            except Exception:
                # Best-effort: fall back to the local-cube broadcast below.
                pass

        # "all" without a usable PA: all slices in the local cube.
        return [f"{cube_prefix}.hbm_ctrl.slice{i}" for i in range(self._slices_per_cube())]

    def _mmu_msg_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out an MmuMapMsg/MmuUnmapMsg to the target PE_MMU(s) via the NOC.

        Routing uses find_node_path (M_CPU → NOC → PE_MMU command edges).
        PE_MMU is a terminal node and completes each sub-transaction
        directly, so completion is tracked via the sub-txn ``done`` events
        rather than ResponseMsgs. An aggregate response is then sent on the
        reverse path.
        """
        request = txn.request
        target_pe = getattr(request, "target_pe", "all")
        cube_prefix = self._cube_prefix()
        pe_ids = self._resolve_pe_ids(target_pe)
        if not pe_ids:
            txn.done.succeed()
            return

        sub_dones: list[simpy.Event] = []
        for pe_id in pe_ids:
            pe_mmu_id = f"{cube_prefix}.pe{pe_id}.pe_mmu"
            try:
                path = self.ctx.router.find_node_path(self.node.id, pe_mmu_id)
            except Exception:
                # Best-effort: unreachable PE_MMUs are skipped.
                continue
            if len(path) < 2:
                continue
            sub_done = env.event()
            sub_txn = Transaction(
                request=request,
                path=path,
                step=0,
                nbytes=0,
                done=sub_done,
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())
            sub_dones.append(sub_done)

        # Wait for every dispatched PE_MMU to complete.
        for sd in sub_dones:
            yield sd

        yield from self._send_aggregate_response(env, txn)

    def _resolve_pe_ids(self, target_pe: int | tuple | list | str) -> list[int]:
        """Return the PE IDs to fan out to (used by the kernel-launch and MMU fan-outs).

        - int: that single PE.
        - tuple/list: those PEs, in order.
        - anything else (e.g. "all"): all PEs in the local cube.

        The PE count is taken from ``hbm_slices_per_cube`` (one PE per HBM
        slice is assumed).
        """
        if isinstance(target_pe, int):
            return [target_pe]
        if isinstance(target_pe, (tuple, list)):
            return list(target_pe)
        return list(range(self._slices_per_cube()))