Latency model: HBM PC striping + chunk-loop drain (ADR-0033)

Previous model double-counted slow-upstream paths (e.g., 64KB via UCIe
128 GB/s was ~2x pessimistic). HBM CTRL now distributes bursts across
8 pseudo-channels via global round-robin, with per-chunk commit timing
that pipelines correctly against the bottleneck link's data arrival.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-14 21:59:07 -07:00
parent f6d262e359
commit 5fdb6f8797
11 changed files with 1192 additions and 52 deletions
+73 -38
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
from collections.abc import Generator
from math import ceil
from typing import TYPE_CHECKING, Any
import simpy
@@ -14,68 +15,106 @@ if TYPE_CHECKING:
class HbmCtrlComponent(ComponentBase):
"""HBM controller: terminal component that models HBM access latency.
"""HBM controller with per-pseudo-channel (PC) striping (ADR-0019 D1, ADR-0033).
Dual-channel model: separate read and write resources (each capacity=1)
allowing concurrent read/write like PE_DMA. Multiple reads or multiple
writes still serialize within their respective channel.
Stateless per-PC ``available_at`` array; each incoming transaction is
split into ``ceil(nbytes / burst_bytes)`` chunks distributed round-robin
across ``num_pcs`` PCs starting from a global ``next_pc`` pointer. Read
and write share the same PC array (real HW command bus is shared per PC).
On completion, creates a ResponseMsg and sends it back on the reverse path
so that response latency is modeled through the fabric.
Chunk-loop drain (ADR-0033 D1, Phase 2b): chunks are scheduled over
time at intervals of ``drain_ns / n_chunks`` to model the bottleneck
link's data arrival rate. Each chunk's PC commit starts at its arrival
time. The last PC commit finishes at ``arrival + drain + commit_time``
— naturally producing the correct single-transfer total (drain +
commit) without the cut-through over-credit of the prior
``env.now - drain_ns`` subtraction.
Direction switching penalty: when a PC's last direction differs from the
current request, ``switch_penalty_ns`` is charged. Default 0 (Tier 0
assumption — ideal scheduler amortizes switching cost; ADR-0033 D2).
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._read: simpy.Resource | None = None
self._write: simpy.Resource | None = None
self._num_pcs: int = 0
self._pc_bw_gbs: float = 0.0
self._burst_bytes: int = 256
self._switch_penalty_ns: float = 0.0
self._pc_avail: list[float] = []
self._pc_last_dir: list[str | None] = []
self._next_pc: int = 0
def start(self, env: simpy.Environment) -> None:
capacity = int(self.node.attrs.get("capacity", 1))
self._read = simpy.Resource(env, capacity=capacity)
self._write = simpy.Resource(env, capacity=capacity)
attrs = self.node.attrs
self._num_pcs = int(attrs.get("num_pcs", 8))
self._pc_bw_gbs = float(attrs.get("pc_bw_gbs", 32.0))
self._burst_bytes = int(attrs.get("burst_bytes", 256))
self._switch_penalty_ns = float(attrs.get("switch_penalty_ns", 0.0))
self._pc_avail = [0.0] * self._num_pcs
self._pc_last_dir = [None] * self._num_pcs
self._next_pc = 0
super().start(env)
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
yield env.timeout(overhead_ns)
def _select_channel(self, txn: Any) -> simpy.Resource:
"""Select channel based on request type: write requests → write, else → read."""
def _is_write(self, txn: Any) -> bool:
from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg
assert self._read is not None and self._write is not None
req = txn.request
if isinstance(req, MemoryWriteMsg):
return self._write
return True
if isinstance(req, PeDmaMsg) and req.is_write:
return self._write
return self._read
return True
return False
def _worker(self, env: simpy.Environment) -> Generator:
"""Dispatch each incoming txn to a concurrent process for channel-level parallelism."""
while True:
txn: Any = yield self._inbox.get()
env.process(self._handle_txn(env, txn))
def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator:
"""Acquire channel, run, apply drain, send response."""
channel = self._select_channel(txn)
with channel.request() as req:
yield req
yield from self.run(env, txn.nbytes)
drain = getattr(txn, "drain_ns", 0.0)
if drain > 0:
yield env.timeout(drain)
is_write = self._is_write(txn)
new_dir = "W" if is_write else "R"
chunk_time = (
self._burst_bytes / self._pc_bw_gbs if self._pc_bw_gbs > 0 else 0.0
)
# MemoryReadMsg forwards command with nbytes=0; the actual data work
# is sized by request.nbytes (data returns via reverse-path response).
work_bytes = txn.nbytes if txn.nbytes > 0 else int(getattr(txn.request, "nbytes", 0) or 0)
n_chunks = max(1, ceil(work_bytes / self._burst_bytes)) if work_bytes > 0 else 0
drain = float(getattr(txn, "drain_ns", 0.0))
chunk_interval = (drain / n_chunks) if (n_chunks > 0 and drain > 0) else 0.0
yield from self.run(env, txn.nbytes)
last_finish = env.now
for i in range(n_chunks):
if chunk_interval > 0:
yield env.timeout(chunk_interval)
pc = (self._next_pc + i) % self._num_pcs
switch_cost = 0.0
if self._pc_last_dir[pc] is not None and self._pc_last_dir[pc] != new_dir:
switch_cost = self._switch_penalty_ns
start = max(env.now, self._pc_avail[pc]) + switch_cost
finish = start + chunk_time
self._pc_avail[pc] = finish
self._pc_last_dir[pc] = new_dir
if finish > last_finish:
last_finish = finish
if n_chunks > 0:
self._next_pc = (self._next_pc + n_chunks) % self._num_pcs
wait = last_finish - env.now
if wait > 0:
yield env.timeout(wait)
yield from self._send_response(env, txn)
def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
"""Route completion based on path type.
- PeDmaMsg: succeed done directly (probe).
- Bypass path (no m_cpu): MemoryWrite succeeds done; MemoryRead sends
data back on reverse path with original done event.
- M_CPU DMA path: send ResponseMsg for m_cpu/io_cpu aggregation.
"""
from kernbench.runtime_api.kernel import MemoryReadMsg, PeDmaMsg
if isinstance(txn.request, PeDmaMsg):
@@ -90,11 +129,9 @@ class HbmCtrlComponent(ComponentBase):
txn.done.succeed()
return
# Bypass path: no m_cpu in the transaction path
is_bypass = not any("m_cpu" in n for n in txn.path)
if is_bypass:
if isinstance(txn.request, MemoryReadMsg):
# D2H: send data back on reverse path to pcie_ep
reverse_path = list(reversed(txn.path))
if len(reverse_path) >= 2:
resp_txn = Transaction(
@@ -103,18 +140,16 @@ class HbmCtrlComponent(ComponentBase):
)
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
return
# MemoryWrite bypass or short path: done
txn.done.succeed()
return
# M_CPU DMA path: send ResponseMsg for aggregation
reverse_path = list(reversed(txn.path))
if len(reverse_path) >= 2 and self.ctx:
from kernbench.runtime_api.kernel import ResponseMsg
parts = self.node.id.split(".")
cube_id = int(parts[1].replace("cube", ""))
pe_id = 0 # single hbm_ctrl, PE info from request
pe_id = 0
resp_msg = ResponseMsg(
correlation_id=txn.request.correlation_id,
request_id=txn.request.request_id,
+44
View File
@@ -1,5 +1,6 @@
from __future__ import annotations
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Any
@@ -47,3 +48,46 @@ class Transaction:
is_response=self.is_response,
result_data=self.result_data,
)
def into_flits(self, flit_bytes: int) -> Iterator[Flit]:
"""Decompose this Transaction's payload into Flits (ADR-0033 D1).
Yields one Flit per ``flit_bytes`` of payload. The final flit may
carry fewer bytes when ``nbytes`` is not a multiple of ``flit_bytes``;
that flit has ``is_last=True``. Transactions with ``nbytes <= 0``
yield no flits.
All yielded Flits share a reference to this Transaction.
"""
if self.nbytes <= 0 or flit_bytes <= 0:
return
n_full = self.nbytes // flit_bytes
remainder = self.nbytes % flit_bytes
n_total = n_full + (1 if remainder else 0)
for i in range(n_total):
size = flit_bytes if i < n_full else remainder
yield Flit(
txn=self,
flit_index=i,
flit_nbytes=size,
is_last=(i == n_total - 1),
)
@dataclass
class Flit:
"""Atomic wire transport unit (ADR-0033 D1).
Carries a slice of a parent Transaction's payload. The wire
(``engine._wire``) decomposes Transactions into Flits on first
transport; downstream wires pass Flits through with their own
``bw_gbs`` delay.
Phase 2 constraint: ``flit_bytes`` MUST be a multiple of HBM
``burst_bytes`` (default they are equal). See ADR-0033 D1.
"""
txn: Transaction # parent transaction reference
flit_index: int # 0..n_flits-1
flit_nbytes: int # bytes carried (usually flit_bytes; last may be smaller)
is_last: bool # True for the terminating flit
+7 -2
View File
@@ -404,13 +404,18 @@ def _instantiate_cube(
label=name.upper().replace("_", " "),
)
# ── HBM controller (single node, ADR-0019 D1) ──
# ── HBM controller (single node, ADR-0019 D1, ADR-0033) ──
hbm_spec = cube["components"]["hbm_ctrl"]
hbm_lx, hbm_ly = local_pos["hbm_ctrl"]
hbm_id = f"{cp}.hbm_ctrl"
hbm_attrs = dict(hbm_spec["attrs"])
_hbm_total_bw = float(cube["links"].get("hbm_to_router_bw_gbs", 256.0))
_num_pcs = int(hbm_attrs.get("num_pcs", 8))
hbm_attrs["num_pcs"] = _num_pcs
hbm_attrs["pc_bw_gbs"] = _hbm_total_bw / _num_pcs
nodes[hbm_id] = Node(
id=hbm_id, kind=hbm_spec["kind"], impl=hbm_spec["impl"],
attrs=hbm_spec["attrs"], pos_mm=(ox + hbm_lx, oy + hbm_ly),
attrs=hbm_attrs, pos_mm=(ox + hbm_lx, oy + hbm_ly),
label="HBM CTRL",
)