aaa1cbfaf6
Replaces global round-robin with deterministic address-derived PC
striping:
pc_shift = log2(burst_bytes)
pc_mask = num_pcs - 1
pc = (flit.address >> pc_shift) & pc_mask
Each Transaction carries base_address (HBM byte offset of the first
chunk); each Flit derives its own address as base + i*flit_bytes.
HBM CTRL routes flits to PCs via this formula, replacing the
arrival-order RR pointer. Also splits the is_last wait into an
asynchronous _finalize_txn process so the worker isn't blocked on
PC commit, exposing true PC parallelism for disjoint addresses.
phyaddr.py documents the canonical bit layout (bits [10:8] for the
default burst=256, num_pcs=8 case). ADR-0033 D6 records the
derivation and the workload scenarios where address-striping
matters (strided streams, offset-disjoint parallel transfers).
Adds tests/test_hbm_address_based_pc.py: canonical bit mapping,
strided 8-way load distribution, same-address PC-0 serialization,
PC-aligned 2KB pair collision, dynamic pc_shift from burst_bytes,
and power-of-2 attr validation. Integration tests inspect the
_pc_avail ledger directly: at the default config, UCIe's 8 ns per-txn
overhead exactly matches chunk_time (256 B / 32 GB/s = 8 ns), which
masks PC contention at the makespan level even though the ledger
correctly distinguishes the cases.
Full suite: 631 passed, 1 skipped.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
235 lines
9.8 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Generator
|
|
from math import ceil
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import simpy
|
|
|
|
from kernbench.components.base import ComponentBase
|
|
from kernbench.sim_engine.transaction import Flit, Transaction
|
|
|
|
if TYPE_CHECKING:
|
|
from kernbench.components.context import ComponentContext
|
|
from kernbench.topology.types import Node
|
|
|
|
|
|
class HbmCtrlComponent(ComponentBase):
    """HBM controller with per-pseudo-channel (PC) striping (ADR-0019 D1, ADR-0033).

    Each PC keeps an ``available_at`` timestamp (``_pc_avail``) and the
    direction of its most recent access (``_pc_last_dir``). Reads and writes
    share the same per-PC ledger (real HW command bus is shared per PC).

    PC selection is address-based (ADR-0033 D6)::

        pc = (address >> log2(burst_bytes)) & (num_pcs - 1)

    which is why ``num_pcs`` and ``burst_bytes`` must both be powers of two.
    With the defaults (``burst_bytes=256``, ``num_pcs=8``) this selects
    address bits [10:8].

    Two ingress paths:

    * ``Flit`` messages (ADR-0033 Phase 2c-3) are handled serially by the
      worker in arrival order. Each flit reserves one burst on the PC chosen
      from ``flit.address``. The transaction's response is sent from a
      separate process once its latest PC commit finishes, so the worker does
      not sit idle while that commit drains.
    * Any other message (a ``Transaction``, e.g. a zero-byte read command) is
      handled in its own process via the chunk-loop drain (ADR-0033 D1,
      Phase 2b). ``ceil(work_bytes / burst_bytes)`` chunks at addresses
      ``base_address + i * burst_bytes`` are issued at intervals of
      ``drain_ns / n_chunks`` to model the bottleneck link's data arrival
      rate. Each chunk's PC commit starts at its arrival time. Absent PC
      contention, the last commit therefore finishes at
      ``arrival + drain + commit_time``: the correct single-transfer total,
      without the cut-through over-credit of the prior
      ``env.now - drain_ns`` subtraction.

    Direction switching penalty: when a PC's last direction differs from the
    current request, ``switch_penalty_ns`` is charged. Default 0 (Tier 0
    assumption — ideal scheduler amortizes switching cost; ADR-0033 D2).
    """

    # Class-level opt-in flag. NOTE(review): presumably consulted by
    # ComponentBase / upstream senders to deliver Flit messages here — confirm.
    _FLIT_AWARE = True

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Configuration; the effective values are read from node.attrs in start().
        self._num_pcs: int = 0
        self._pc_bw_gbs: float = 0.0
        self._burst_bytes: int = 256
        self._switch_penalty_ns: float = 0.0
        # Per-PC ledger: time each PC becomes free, and the direction
        # ("R"/"W") of its last access (None before the first access).
        self._pc_avail: list[float] = []
        self._pc_last_dir: list[str | None] = []
        # Address-based PC selection (ADR-0033 D6):
        #   pc = (address >> _pc_shift) & _pc_mask
        self._pc_shift: int = 0
        self._pc_mask: int = 0
        # Per-txn flit accumulation state (ADR-0033 Phase 2c-3), keyed by
        # id(txn). Each entry tracks the latest PC finish time seen so far
        # for that txn and is removed when its ``is_last`` flit arrives.
        self._txn_state: dict[int, dict[str, Any]] = {}

    def start(self, env: simpy.Environment) -> None:
        """Load and validate config from ``node.attrs``, reset the PC ledger, start the base component.

        Attrs (defaults): ``num_pcs`` (8), ``pc_bw_gbs`` (32.0),
        ``burst_bytes`` (256), ``switch_penalty_ns`` (0.0).

        Raises:
            ValueError: if ``num_pcs`` or ``burst_bytes`` is not a positive
                power of two, or if ``pc_bw_gbs`` / ``switch_penalty_ns``
                is negative.
        """
        attrs = self.node.attrs
        self._num_pcs = int(attrs.get("num_pcs", 8))
        self._pc_bw_gbs = float(attrs.get("pc_bw_gbs", 32.0))
        self._burst_bytes = int(attrs.get("burst_bytes", 256))
        self._switch_penalty_ns = float(attrs.get("switch_penalty_ns", 0.0))

        # Power-of-2 is required so PC selection reduces to shift + mask.
        if self._num_pcs <= 0 or (self._num_pcs & (self._num_pcs - 1)) != 0:
            raise ValueError(f"num_pcs must be a positive power of 2, got {self._num_pcs}")
        if self._burst_bytes <= 0 or (self._burst_bytes & (self._burst_bytes - 1)) != 0:
            raise ValueError(f"burst_bytes must be a positive power of 2, got {self._burst_bytes}")
        # A negative bandwidth used to be silently treated like 0 (zero PC
        # time), and a negative penalty would move commits backwards in time.
        if self._pc_bw_gbs < 0:
            raise ValueError(f"pc_bw_gbs must be >= 0, got {self._pc_bw_gbs}")
        if self._switch_penalty_ns < 0:
            raise ValueError(f"switch_penalty_ns must be >= 0, got {self._switch_penalty_ns}")

        # log2(burst_bytes) for a power of two.
        self._pc_shift = self._burst_bytes.bit_length() - 1
        self._pc_mask = self._num_pcs - 1
        self._pc_avail = [0.0] * self._num_pcs
        self._pc_last_dir = [None] * self._num_pcs
        super().start(env)

    def _pc_for_address(self, address: int) -> int:
        """Map an HBM byte address to its PC index: ``(address >> _pc_shift) & _pc_mask``."""
        return (int(address) >> self._pc_shift) & self._pc_mask

    def _chunk_time_ns(self) -> float:
        """PC occupancy of one burst: ``burst_bytes / pc_bw_gbs`` (0 when bandwidth is 0).

        Bytes divided by GB/s gives ns (e.g. 256 B / 32 GB/s = 8 ns).
        """
        if self._pc_bw_gbs > 0:
            return self._burst_bytes / self._pc_bw_gbs
        return 0.0

    def _commit_to_pc(self, now: float, pc: int, new_dir: str) -> float:
        """Reserve one burst on PC ``pc`` no earlier than ``now`` and return its finish time.

        Charges ``switch_penalty_ns`` when the PC's previous access ran in the
        other direction, then advances the PC's ledger entries. Shared by the
        flit path and the chunk-loop path so both account identically.
        """
        last_dir = self._pc_last_dir[pc]
        switch_cost = (
            self._switch_penalty_ns
            if last_dir is not None and last_dir != new_dir
            else 0.0
        )
        start = max(now, self._pc_avail[pc]) + switch_cost
        finish = start + self._chunk_time_ns()
        self._pc_avail[pc] = finish
        self._pc_last_dir[pc] = new_dir
        return finish

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Charge the fixed per-transaction ``overhead_ns`` attr (default 0).

        ``nbytes`` is not used by this implementation.
        """
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        # Yielded even when 0: a zero-delay timeout still passes through the
        # event queue, so event ordering is unchanged versus earlier versions.
        yield env.timeout(overhead_ns)

    def _is_write(self, txn: Any) -> bool:
        """Return True if ``txn.request`` is a MemoryWriteMsg or a PeDmaMsg with ``is_write`` set."""
        from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg

        req = txn.request
        return isinstance(req, MemoryWriteMsg) or (
            isinstance(req, PeDmaMsg) and bool(req.is_write)
        )

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch inbox messages: flits inline (serially), everything else in its own process."""
        while True:
            msg: Any = yield self._inbox.get()
            if isinstance(msg, Flit):
                # ADR-0033 Phase 2c-3: flits are handled serially in arrival
                # order, so each flit's PC reservation sees every earlier one.
                # Only the final response wait is offloaded (_finalize_txn).
                yield from self._handle_flit(env, msg)
            else:
                # Transaction (e.g., zero-byte read command) — keep the
                # legacy chunk-loop drain path for PC read time modeling.
                env.process(self._handle_txn(env, msg))

    def _handle_flit(self, env: simpy.Environment, flit: Flit) -> Generator:
        """Commit one flit to its address-derived PC (ADR-0033 D6).

        The first flit of a txn pays the ``run`` overhead. The txn's latest PC
        finish time is tracked across flits. On ``is_last``, a separate
        ``_finalize_txn`` process waits for that finish time and sends the
        response, so the worker can move on to the next flit immediately.
        """
        txn = flit.txn
        tid = id(txn)
        new_dir = "W" if self._is_write(txn) else "R"

        state = self._txn_state.get(tid)
        if state is None:
            yield from self.run(env, txn.nbytes)
            state = {"last_finish": env.now}
            self._txn_state[tid] = state

        # NOTE(review): each flit is charged one full burst (chunk_time), so
        # this assumes flits carry at most burst_bytes each — confirm that the
        # upstream flit size matches burst_bytes.
        pc = self._pc_for_address(flit.address)
        finish = self._commit_to_pc(env.now, pc, new_dir)
        if finish > state["last_finish"]:
            state["last_finish"] = finish

        if flit.is_last:
            del self._txn_state[tid]
            # Finalize asynchronously so the worker can pick up the next
            # flit while this txn's last PC commit drains. Waiting inline
            # would serialize concurrent single-flit txns at chunk_time even
            # when they hit distinct PCs, hiding address-based PC
            # parallelism (ADR-0033 D6).
            env.process(self._finalize_txn(env, txn, state["last_finish"]))

    def _finalize_txn(
        self, env: simpy.Environment, txn: Any, last_finish: float,
    ) -> Generator:
        """Wait until ``last_finish`` (if still in the future), then send ``txn``'s response."""
        wait = last_finish - env.now
        if wait > 0:
            yield env.timeout(wait)
        yield from self._send_response(env, txn)

    def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Chunk-loop drain for non-flit transactions (ADR-0033 D1, Phase 2b).

        Splits the work into ``burst_bytes`` chunks and issues them at
        ``drain_ns / n_chunks`` spacing. Each chunk goes to the PC derived
        from ``base_address + i * burst_bytes``. The response is sent once
        the latest PC commit finishes.
        """
        new_dir = "W" if self._is_write(txn) else "R"
        # MemoryReadMsg forwards the command with nbytes=0; the actual data
        # work is sized by request.nbytes (data returns via the reverse-path
        # response).
        work_bytes = txn.nbytes if txn.nbytes > 0 else int(getattr(txn.request, "nbytes", 0) or 0)
        n_chunks = ceil(work_bytes / self._burst_bytes) if work_bytes > 0 else 0

        drain = float(getattr(txn, "drain_ns", 0.0) or 0.0)
        chunk_interval = (drain / n_chunks) if (n_chunks > 0 and drain > 0) else 0.0

        yield from self.run(env, txn.nbytes)

        base_addr = int(getattr(txn, "base_address", 0) or 0)
        last_finish = env.now
        for i in range(n_chunks):
            if chunk_interval > 0:
                # Model the chunk's arrival over the bottleneck link.
                yield env.timeout(chunk_interval)
            pc = self._pc_for_address(base_addr + i * self._burst_bytes)
            last_finish = max(last_finish, self._commit_to_pc(env.now, pc, new_dir))

        yield from self._finalize_txn(env, txn, last_finish)

    def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Complete ``txn`` by sending a response back along the reversed path, or by succeeding ``done``.

        * PeDmaMsg: zero-byte response carrying ``txn.done``.
        * Bypass (no ``m_cpu`` node on the path): MemoryReadMsg gets a data
          response of ``request.nbytes`` carrying ``txn.done``; anything else
          just succeeds ``txn.done``.
        * Otherwise: a ResponseMsg transaction is sent when a context exists.

        Whenever there is no previous hop (path shorter than 2), or no
        context in the last case, ``txn.done`` is succeeded directly.
        """
        from kernbench.runtime_api.kernel import MemoryReadMsg, PeDmaMsg, ResponseMsg

        reverse_path = list(reversed(txn.path))
        has_prev_hop = len(reverse_path) >= 2

        if isinstance(txn.request, PeDmaMsg):
            if has_prev_hop:
                resp_txn = Transaction(
                    request=txn.request, path=reverse_path, step=0,
                    nbytes=0, done=txn.done, is_response=True,
                )
                yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
                return
            txn.done.succeed()
            return

        is_bypass = not any("m_cpu" in n for n in txn.path)
        if is_bypass:
            if isinstance(txn.request, MemoryReadMsg) and has_prev_hop:
                resp_txn = Transaction(
                    request=txn.request, path=reverse_path, step=0,
                    nbytes=txn.request.nbytes, done=txn.done,
                )
                yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
                return
            txn.done.succeed()
            return

        if has_prev_hop and self.ctx:
            # NOTE(review): assumes node ids look like "<prefix>.cube<N>...";
            # the source PE is hard-coded to 0 here.
            parts = self.node.id.split(".")
            cube_id = int(parts[1].replace("cube", ""))
            pe_id = 0
            resp_msg = ResponseMsg(
                correlation_id=txn.request.correlation_id,
                request_id=txn.request.request_id,
                src_cube=cube_id, src_pe=pe_id, success=True,
            )
            resp_txn = Transaction(
                request=resp_msg, path=reverse_path, step=0,
                nbytes=0, done=env.event(), is_response=True,
            )
            yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
        else:
            txn.done.succeed()