Latency model: HBM PC striping + chunk-loop drain (ADR-0033)
Previous model double-counted slow-upstream paths (e.g., 64KB via UCIe 128 GB/s was ~2x pessimistic). HBM CTRL now distributes bursts across 8 pseudo-channels via global round-robin, with per-chunk commit timing that pipelines correctly against the bottleneck link's data arrival. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from math import ceil
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
@@ -14,68 +15,106 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class HbmCtrlComponent(ComponentBase):
|
||||
"""HBM controller: terminal component that models HBM access latency.
|
||||
"""HBM controller with per-pseudo-channel (PC) striping (ADR-0019 D1, ADR-0033).
|
||||
|
||||
Dual-channel model: separate read and write resources (each capacity=1)
|
||||
allowing concurrent read/write like PE_DMA. Multiple reads or multiple
|
||||
writes still serialize within their respective channel.
|
||||
Stateless per-PC ``available_at`` array; each incoming transaction is
|
||||
split into ``ceil(nbytes / burst_bytes)`` chunks distributed round-robin
|
||||
across ``num_pcs`` PCs starting from a global ``next_pc`` pointer. Read
|
||||
and write share the same PC array (real HW command bus is shared per PC).
|
||||
|
||||
On completion, creates a ResponseMsg and sends it back on the reverse path
|
||||
so that response latency is modeled through the fabric.
|
||||
Chunk-loop drain (ADR-0033 D1, Phase 2b): chunks are scheduled over
|
||||
time at intervals of ``drain_ns / n_chunks`` to model the bottleneck
|
||||
link's data arrival rate. Each chunk's PC commit starts at its arrival
|
||||
time. The last PC commit finishes at ``arrival + drain + commit_time``
|
||||
— naturally producing the correct single-transfer total (drain +
|
||||
commit) without the cut-through over-credit of the prior
|
||||
``env.now - drain_ns`` subtraction.
|
||||
|
||||
Direction switching penalty: when a PC's last direction differs from the
|
||||
current request, ``switch_penalty_ns`` is charged. Default 0 (Tier 0
|
||||
assumption — ideal scheduler amortizes switching cost; ADR-0033 D2).
|
||||
"""
|
||||
|
||||
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
|
||||
super().__init__(node, ctx)
|
||||
self._read: simpy.Resource | None = None
|
||||
self._write: simpy.Resource | None = None
|
||||
self._num_pcs: int = 0
|
||||
self._pc_bw_gbs: float = 0.0
|
||||
self._burst_bytes: int = 256
|
||||
self._switch_penalty_ns: float = 0.0
|
||||
self._pc_avail: list[float] = []
|
||||
self._pc_last_dir: list[str | None] = []
|
||||
self._next_pc: int = 0
|
||||
|
||||
def start(self, env: simpy.Environment) -> None:
|
||||
capacity = int(self.node.attrs.get("capacity", 1))
|
||||
self._read = simpy.Resource(env, capacity=capacity)
|
||||
self._write = simpy.Resource(env, capacity=capacity)
|
||||
attrs = self.node.attrs
|
||||
self._num_pcs = int(attrs.get("num_pcs", 8))
|
||||
self._pc_bw_gbs = float(attrs.get("pc_bw_gbs", 32.0))
|
||||
self._burst_bytes = int(attrs.get("burst_bytes", 256))
|
||||
self._switch_penalty_ns = float(attrs.get("switch_penalty_ns", 0.0))
|
||||
self._pc_avail = [0.0] * self._num_pcs
|
||||
self._pc_last_dir = [None] * self._num_pcs
|
||||
self._next_pc = 0
|
||||
super().start(env)
|
||||
|
||||
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
|
||||
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
|
||||
yield env.timeout(overhead_ns)
|
||||
|
||||
def _select_channel(self, txn: Any) -> simpy.Resource:
|
||||
"""Select channel based on request type: write requests → write, else → read."""
|
||||
def _is_write(self, txn: Any) -> bool:
|
||||
from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg
|
||||
|
||||
assert self._read is not None and self._write is not None
|
||||
req = txn.request
|
||||
if isinstance(req, MemoryWriteMsg):
|
||||
return self._write
|
||||
return True
|
||||
if isinstance(req, PeDmaMsg) and req.is_write:
|
||||
return self._write
|
||||
return self._read
|
||||
return True
|
||||
return False
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
|
||||
"""Dispatch each incoming txn to a concurrent process for channel-level parallelism."""
|
||||
while True:
|
||||
txn: Any = yield self._inbox.get()
|
||||
env.process(self._handle_txn(env, txn))
|
||||
|
||||
def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||
"""Acquire channel, run, apply drain, send response."""
|
||||
channel = self._select_channel(txn)
|
||||
with channel.request() as req:
|
||||
yield req
|
||||
yield from self.run(env, txn.nbytes)
|
||||
drain = getattr(txn, "drain_ns", 0.0)
|
||||
if drain > 0:
|
||||
yield env.timeout(drain)
|
||||
is_write = self._is_write(txn)
|
||||
new_dir = "W" if is_write else "R"
|
||||
chunk_time = (
|
||||
self._burst_bytes / self._pc_bw_gbs if self._pc_bw_gbs > 0 else 0.0
|
||||
)
|
||||
# MemoryReadMsg forwards command with nbytes=0; the actual data work
|
||||
# is sized by request.nbytes (data returns via reverse-path response).
|
||||
work_bytes = txn.nbytes if txn.nbytes > 0 else int(getattr(txn.request, "nbytes", 0) or 0)
|
||||
n_chunks = max(1, ceil(work_bytes / self._burst_bytes)) if work_bytes > 0 else 0
|
||||
|
||||
drain = float(getattr(txn, "drain_ns", 0.0))
|
||||
chunk_interval = (drain / n_chunks) if (n_chunks > 0 and drain > 0) else 0.0
|
||||
|
||||
yield from self.run(env, txn.nbytes)
|
||||
|
||||
last_finish = env.now
|
||||
for i in range(n_chunks):
|
||||
if chunk_interval > 0:
|
||||
yield env.timeout(chunk_interval)
|
||||
pc = (self._next_pc + i) % self._num_pcs
|
||||
switch_cost = 0.0
|
||||
if self._pc_last_dir[pc] is not None and self._pc_last_dir[pc] != new_dir:
|
||||
switch_cost = self._switch_penalty_ns
|
||||
start = max(env.now, self._pc_avail[pc]) + switch_cost
|
||||
finish = start + chunk_time
|
||||
self._pc_avail[pc] = finish
|
||||
self._pc_last_dir[pc] = new_dir
|
||||
if finish > last_finish:
|
||||
last_finish = finish
|
||||
if n_chunks > 0:
|
||||
self._next_pc = (self._next_pc + n_chunks) % self._num_pcs
|
||||
|
||||
wait = last_finish - env.now
|
||||
if wait > 0:
|
||||
yield env.timeout(wait)
|
||||
|
||||
yield from self._send_response(env, txn)
|
||||
|
||||
def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||
"""Route completion based on path type.
|
||||
|
||||
- PeDmaMsg: succeed done directly (probe).
|
||||
- Bypass path (no m_cpu): MemoryWrite succeeds done; MemoryRead sends
|
||||
data back on reverse path with original done event.
|
||||
- M_CPU DMA path: send ResponseMsg for m_cpu/io_cpu aggregation.
|
||||
"""
|
||||
from kernbench.runtime_api.kernel import MemoryReadMsg, PeDmaMsg
|
||||
|
||||
if isinstance(txn.request, PeDmaMsg):
|
||||
@@ -90,11 +129,9 @@ class HbmCtrlComponent(ComponentBase):
|
||||
txn.done.succeed()
|
||||
return
|
||||
|
||||
# Bypass path: no m_cpu in the transaction path
|
||||
is_bypass = not any("m_cpu" in n for n in txn.path)
|
||||
if is_bypass:
|
||||
if isinstance(txn.request, MemoryReadMsg):
|
||||
# D2H: send data back on reverse path to pcie_ep
|
||||
reverse_path = list(reversed(txn.path))
|
||||
if len(reverse_path) >= 2:
|
||||
resp_txn = Transaction(
|
||||
@@ -103,18 +140,16 @@ class HbmCtrlComponent(ComponentBase):
|
||||
)
|
||||
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
|
||||
return
|
||||
# MemoryWrite bypass or short path: done
|
||||
txn.done.succeed()
|
||||
return
|
||||
|
||||
# M_CPU DMA path: send ResponseMsg for aggregation
|
||||
reverse_path = list(reversed(txn.path))
|
||||
if len(reverse_path) >= 2 and self.ctx:
|
||||
from kernbench.runtime_api.kernel import ResponseMsg
|
||||
|
||||
parts = self.node.id.split(".")
|
||||
cube_id = int(parts[1].replace("cube", ""))
|
||||
pe_id = 0 # single hbm_ctrl, PE info from request
|
||||
pe_id = 0
|
||||
resp_msg = ResponseMsg(
|
||||
correlation_id=txn.request.correlation_id,
|
||||
request_id=txn.request.request_id,
|
||||
|
||||
Reference in New Issue
Block a user