from __future__ import annotations

from collections.abc import Generator
from typing import TYPE_CHECKING, Any

import simpy

from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction

if TYPE_CHECKING:
    from kernbench.components.context import ComponentContext
    from kernbench.topology.types import Node


class HbmCtrlComponent(ComponentBase):
    """HBM controller: terminal component that models HBM access latency.

    Dual-channel model: separate read and write resources allow a read and a
    write to proceed concurrently (like PE_DMA). Each channel's capacity comes
    from the node's ``capacity`` attribute (default 1), so with the default,
    multiple reads or multiple writes serialize within their channel.

    On completion the controller either succeeds the transaction's ``done``
    event directly or sends a response transaction back along the reversed
    path so that response latency is modeled through the fabric; see
    ``_send_response`` for the exact routing rules.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        """Store the node/context; channels are created later in ``start``."""
        super().__init__(node, ctx)
        self._read: simpy.Resource | None = None
        self._write: simpy.Resource | None = None

    def start(self, env: simpy.Environment) -> None:
        """Create the read and write channels, then start the base component.

        Both channels share the same capacity, read from the node's
        ``capacity`` attribute (default 1).
        """
        capacity = int(self.node.attrs.get("capacity", 1))
        self._read = simpy.Resource(env, capacity=capacity)
        self._write = simpy.Resource(env, capacity=capacity)
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the fixed per-access overhead.

        The delay is the node's ``overhead_ns`` attribute (default 0.0).
        ``nbytes`` is part of the signature but is not used by this method.
        """
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _select_channel(self, txn: Any) -> simpy.Resource:
        """Select channel based on request type: write requests → write, else → read.

        A request counts as a write when it is a ``MemoryWriteMsg`` or a
        ``PeDmaMsg`` whose ``is_write`` is truthy.

        Raises:
            RuntimeError: if called before ``start`` created the channels.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg

        # Explicit check instead of ``assert`` so it is not stripped under -O.
        if self._read is None or self._write is None:
            raise RuntimeError(
                f"{type(self).__name__}: start() must be called before "
                "transactions are handled"
            )

        req = txn.request
        if isinstance(req, MemoryWriteMsg):
            return self._write
        if isinstance(req, PeDmaMsg) and req.is_write:
            return self._write
        return self._read

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch each incoming txn to a concurrent process for channel-level parallelism."""
        while True:
            txn: Any = yield self._inbox.get()
            # One process per transaction: a read and a write can overlap,
            # while same-channel transactions queue on the channel resource.
            env.process(self._handle_txn(env, txn))

    def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Acquire channel, run, apply drain, send response."""
        channel = self._select_channel(txn)
        # NOTE(review): the original indentation was lost. This keeps the
        # overhead and drain inside the channel hold and sends the response
        # after the channel is released — confirm against the intended model.
        with channel.request() as req:
            yield req
            yield from self.run(env, txn.nbytes)
            # ``or 0.0`` tolerates a drain_ns attribute that is present but None.
            drain = getattr(txn, "drain_ns", 0.0) or 0.0
            if drain > 0:
                yield env.timeout(drain)
        yield from self._send_response(env, txn)

    def _put_response(self, route: list, resp_txn: Any) -> Generator:
        """Put the advanced response transaction on the port of the next hop.

        ``route`` is the reversed request path; ``route[0]`` is the first
        entry of that path and ``route[1]`` is the hop whose out-port is used.
        """
        yield self.out_ports[route[1]].put(resp_txn.advance())

    def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Route completion based on request type and path.

        "Routable" below means the reversed path has at least two hops.

        - PeDmaMsg: when routable, send a response transaction carrying the
          original request and ``done`` event back on the reverse path;
          otherwise succeed ``done`` directly.
        - Bypass path (no hop name contains ``"m_cpu"``): a routable
          MemoryReadMsg sends the data (``request.nbytes``) back on the
          reverse path with the original ``done`` event; everything else
          (e.g. MemoryWrite, or a path that is too short) succeeds ``done``.
        - M_CPU DMA path: when routable and a context is set, send a
          ResponseMsg for m_cpu/io_cpu aggregation; otherwise succeed ``done``.
        """
        from kernbench.runtime_api.kernel import (
            MemoryReadMsg,
            PeDmaMsg,
            ResponseMsg,
        )

        request = txn.request
        route = list(reversed(txn.path))
        routable = len(route) >= 2

        if isinstance(request, PeDmaMsg):
            if routable:
                resp_txn = Transaction(
                    request=request,
                    path=route,
                    step=0,
                    nbytes=0,
                    done=txn.done,
                    is_response=True,
                )
                yield from self._put_response(route, resp_txn)
            else:
                txn.done.succeed()
            return

        # Bypass path: no m_cpu in the transaction path.
        if not any("m_cpu" in hop for hop in txn.path):
            if isinstance(request, MemoryReadMsg) and routable:
                # D2H: send data back on reverse path to pcie_ep.
                resp_txn = Transaction(
                    request=request,
                    path=route,
                    step=0,
                    nbytes=request.nbytes,
                    done=txn.done,
                )
                yield from self._put_response(route, resp_txn)
            else:
                # MemoryWrite bypass or short path: done.
                txn.done.succeed()
            return

        # M_CPU DMA path: send ResponseMsg for aggregation.
        if routable and self.ctx:
            # Assumes the node id has at least four dot-separated parts, with
            # "cube<N>" second and "slice<M>" fourth — TODO confirm with the
            # topology naming scheme.
            parts = self.node.id.split(".")
            cube_id = int(parts[1].replace("cube", ""))
            pe_id = int(parts[3].replace("slice", ""))
            resp_msg = ResponseMsg(
                correlation_id=request.correlation_id,
                request_id=request.request_id,
                src_cube=cube_id,
                src_pe=pe_id,
                success=True,
            )
            # The response carries a fresh event; txn.done is not triggered
            # here. Presumably the aggregator completes the original request —
            # verify against the m_cpu/io_cpu implementation.
            resp_txn = Transaction(
                request=resp_msg,
                path=route,
                step=0,
                nbytes=0,
                done=env.event(),
                is_response=True,
            )
            yield from self._put_response(route, resp_txn)
        else:
            txn.done.succeed()