from __future__ import annotations from collections.abc import Generator from math import ceil from typing import TYPE_CHECKING, Any import simpy from kernbench.components.base import ComponentBase from kernbench.sim_engine.transaction import Flit, Transaction if TYPE_CHECKING: from kernbench.components.context import ComponentContext from kernbench.topology.types import Node class HbmCtrlComponent(ComponentBase): """HBM controller with per-pseudo-channel (PC) striping (ADR-0019 D1, ADR-0033). Stateless per-PC ``available_at`` array; each incoming transaction is split into ``ceil(nbytes / burst_bytes)`` chunks distributed round-robin across ``num_pcs`` PCs starting from a global ``next_pc`` pointer. Read and write share the same PC array (real HW command bus is shared per PC). Chunk-loop drain (ADR-0033 D1, Phase 2b): chunks are scheduled over time at intervals of ``drain_ns / n_chunks`` to model the bottleneck link's data arrival rate. Each chunk's PC commit starts at its arrival time. The last PC commit finishes at ``arrival + drain + commit_time`` — naturally producing the correct single-transfer total (drain + commit) without the cut-through over-credit of the prior ``env.now - drain_ns`` subtraction. Direction switching penalty: when a PC's last direction differs from the current request, ``switch_penalty_ns`` is charged. Default 0 (Tier 0 assumption — ideal scheduler amortizes switching cost; ADR-0033 D2). """ _FLIT_AWARE = True def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: super().__init__(node, ctx) self._num_pcs: int = 0 self._pc_bw_gbs: float = 0.0 self._burst_bytes: int = 256 self._switch_penalty_ns: float = 0.0 self._pc_avail: list[float] = [] self._pc_last_dir: list[str | None] = [] # Address-based PC selection (ADR-0033 D6): # pc = (address >> _pc_shift) & _pc_mask self._pc_shift: int = 0 self._pc_mask: int = 0 # Per-txn flit accumulation state (ADR-0033 Phase 2c-3). self._txn_state: dict[int, dict[str, Any]] = {} def start(self, env: simpy.Environment) -> None: attrs = self.node.attrs self._num_pcs = int(attrs.get("num_pcs", 8)) self._pc_bw_gbs = float(attrs.get("pc_bw_gbs", 32.0)) self._burst_bytes = int(attrs.get("burst_bytes", 256)) self._switch_penalty_ns = float(attrs.get("switch_penalty_ns", 0.0)) if self._num_pcs <= 0 or (self._num_pcs & (self._num_pcs - 1)) != 0: raise ValueError(f"num_pcs must be a positive power of 2, got {self._num_pcs}") if self._burst_bytes <= 0 or (self._burst_bytes & (self._burst_bytes - 1)) != 0: raise ValueError(f"burst_bytes must be a positive power of 2, got {self._burst_bytes}") self._pc_shift = self._burst_bytes.bit_length() - 1 self._pc_mask = self._num_pcs - 1 self._pc_avail = [0.0] * self._num_pcs self._pc_last_dir = [None] * self._num_pcs super().start(env) def _pc_for_address(self, address: int) -> int: return (int(address) >> self._pc_shift) & self._pc_mask def run(self, env: simpy.Environment, nbytes: int) -> Generator: overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0)) yield env.timeout(overhead_ns) def _is_write(self, txn: Any) -> bool: from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg req = txn.request if isinstance(req, MemoryWriteMsg): return True if isinstance(req, PeDmaMsg) and req.is_write: return True return False def _worker(self, env: simpy.Environment) -> Generator: while True: msg: Any = yield self._inbox.get() if isinstance(msg, Flit): # ADR-0033 Phase 2c-3: serial flit handling (preserve # arrival order, in particular ``is_last`` only after # all preceding flits have committed). yield from self._handle_flit(env, msg) else: # Transaction (e.g., zero-byte read command) — keep # legacy chunk-loop drain path for PC read time modeling. env.process(self._handle_txn(env, msg)) def _handle_flit(self, env: simpy.Environment, flit: Flit) -> Generator: """Per-flit PC commit. On first flit of a txn, apply overhead. PC is derived from the flit's address (ADR-0033 D6 address-based striping). On ``is_last``, wait for last PC commit to finish, then send the response.""" txn = flit.txn tid = id(txn) chunk_time = ( self._burst_bytes / self._pc_bw_gbs if self._pc_bw_gbs > 0 else 0.0 ) new_dir = "W" if self._is_write(txn) else "R" if tid not in self._txn_state: yield from self.run(env, txn.nbytes) self._txn_state[tid] = { "last_finish": env.now, } state = self._txn_state[tid] pc = self._pc_for_address(flit.address) switch_cost = 0.0 if self._pc_last_dir[pc] is not None and self._pc_last_dir[pc] != new_dir: switch_cost = self._switch_penalty_ns start = max(env.now, self._pc_avail[pc]) + switch_cost finish = start + chunk_time self._pc_avail[pc] = finish self._pc_last_dir[pc] = new_dir if finish > state["last_finish"]: state["last_finish"] = finish if flit.is_last: del self._txn_state[tid] # Finalize asynchronously so the worker can pick up the next # flit while this txn's last PC commit drains. Without this # split, the worker's ``yield env.timeout(wait)`` would # serialize concurrent single-flit txns at chunk_time even # when they hit distinct PCs, hiding address-based PC # parallelism (ADR-0033 D6). env.process(self._finalize_txn(env, txn, state["last_finish"])) def _finalize_txn( self, env: simpy.Environment, txn: Any, last_finish: float, ) -> Generator: wait = last_finish - env.now if wait > 0: yield env.timeout(wait) yield from self._send_response(env, txn) def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator: is_write = self._is_write(txn) new_dir = "W" if is_write else "R" chunk_time = ( self._burst_bytes / self._pc_bw_gbs if self._pc_bw_gbs > 0 else 0.0 ) # MemoryReadMsg forwards command with nbytes=0; the actual data work # is sized by request.nbytes (data returns via reverse-path response). work_bytes = txn.nbytes if txn.nbytes > 0 else int(getattr(txn.request, "nbytes", 0) or 0) n_chunks = max(1, ceil(work_bytes / self._burst_bytes)) if work_bytes > 0 else 0 drain = float(getattr(txn, "drain_ns", 0.0)) chunk_interval = (drain / n_chunks) if (n_chunks > 0 and drain > 0) else 0.0 yield from self.run(env, txn.nbytes) base_addr = int(getattr(txn, "base_address", 0)) last_finish = env.now for i in range(n_chunks): if chunk_interval > 0: yield env.timeout(chunk_interval) pc = self._pc_for_address(base_addr + i * self._burst_bytes) switch_cost = 0.0 if self._pc_last_dir[pc] is not None and self._pc_last_dir[pc] != new_dir: switch_cost = self._switch_penalty_ns start = max(env.now, self._pc_avail[pc]) + switch_cost finish = start + chunk_time self._pc_avail[pc] = finish self._pc_last_dir[pc] = new_dir if finish > last_finish: last_finish = finish wait = last_finish - env.now if wait > 0: yield env.timeout(wait) yield from self._send_response(env, txn) def _send_response(self, env: simpy.Environment, txn: Any) -> Generator: from kernbench.runtime_api.kernel import MemoryReadMsg, PeDmaMsg if isinstance(txn.request, PeDmaMsg): reverse_path = list(reversed(txn.path)) if len(reverse_path) >= 2: resp_txn = Transaction( request=txn.request, path=reverse_path, step=0, nbytes=0, done=txn.done, is_response=True, ) yield self.out_ports[reverse_path[1]].put(resp_txn.advance()) return txn.done.succeed() return is_bypass = not any("m_cpu" in n for n in txn.path) if is_bypass: if isinstance(txn.request, MemoryReadMsg): reverse_path = list(reversed(txn.path)) if len(reverse_path) >= 2: resp_txn = Transaction( request=txn.request, path=reverse_path, step=0, nbytes=txn.request.nbytes, done=txn.done, ) yield self.out_ports[reverse_path[1]].put(resp_txn.advance()) return txn.done.succeed() return reverse_path = list(reversed(txn.path)) if len(reverse_path) >= 2 and self.ctx: from kernbench.runtime_api.kernel import ResponseMsg parts = self.node.id.split(".") cube_id = int(parts[1].replace("cube", "")) pe_id = 0 resp_msg = ResponseMsg( correlation_id=txn.request.correlation_id, request_id=txn.request.request_id, src_cube=cube_id, src_pe=pe_id, success=True, ) resp_txn = Transaction( request=resp_msg, path=reverse_path, step=0, nbytes=0, done=env.event(), is_response=True, ) yield self.out_ports[reverse_path[1]].put(resp_txn.advance()) else: txn.done.succeed()