from __future__ import annotations from collections.abc import Generator from typing import TYPE_CHECKING, Any import simpy from kernbench.components.base import PeEngineBase from kernbench.sim_engine.transaction import Transaction if TYPE_CHECKING: from kernbench.common.pe_commands import PeInternalTxn from kernbench.components.context import ComponentContext from kernbench.topology.types import Node class PeDmaComponent(PeEngineBase): """PE_DMA: dual-channel DMA engine with READ and WRITE resources. Compute channels (vc_compute) have capacity=1 each (ADR-0014 D4): - DMA_READ and DMA_WRITE may execute concurrently. - Multiple READs cannot overlap; multiple WRITEs cannot overlap. The orthogonal vc_comm channel for IPCQ traffic is defined in ADR-0023 D8. Handles two message types: - Transaction: external fabric messages (PeDmaMsg probes, M_CPU DMA) - PeInternalTxn: PE-internal commands from PE_SCHEDULER (DmaReadCmd → HBM read, DmaWriteCmd → HBM write) """ # Defer op_log record_start until AFTER the DMA channel is acquired so # t_start reflects the serve-start moment (post queueing) rather than # the queue-enter moment. ComponentBase._handle_with_hooks consults this # flag. _DEFER_RECORD_START = True def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: super().__init__(node, ctx) self._dma_read: simpy.Resource | None = None self._dma_write: simpy.Resource | None = None self._mmu = None # PeMMU instance, set by engine wiring def init_resources(self, env: simpy.Environment) -> None: self._dma_read = simpy.Resource(env, capacity=1) self._dma_write = simpy.Resource(env, capacity=1) def run(self, env: simpy.Environment, nbytes: int) -> Generator: yield env.timeout(0) def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator: """Handle PE-internal DMA command: resolve PA → HBM path → transfer.""" from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd from kernbench.policy.address.phyaddr import PhysAddr from kernbench.runtime_api.kernel import PeDmaMsg cmd = pe_txn.command assert self._dma_read is not None and self._dma_write is not None # Determine direction and target address (VA → PA via MMU) if isinstance(cmd, DmaReadCmd): dma_res = self._dma_read raw_addr = cmd.src_addr is_write = False elif isinstance(cmd, DmaWriteCmd): dma_res = self._dma_write raw_addr = cmd.dst_addr is_write = True else: pe_txn.done.succeed() return # Translate VA → PA via MMU (if available), then resolve HBM node # If MMU has no mapping for this address (PageFault), treat as PA directly # (backward-compatible with PA-only mode) if self._mmu is not None: from kernbench.policy.address.pe_mmu import PageFault try: target_pa = self._mmu.translate(raw_addr) if self._mmu.overhead_ns > 0: yield env.timeout(self._mmu.overhead_ns) except PageFault: target_pa = raw_addr else: target_pa = raw_addr # fallback: treat as PA directly pa = PhysAddr.decode(target_pa) dst_node = self.ctx.resolver.resolve(pa) path = self.ctx.router.find_path(self._pe_prefix, dst_node) drain_ns = self.ctx.compute_drain_ns(path, cmd.nbytes) # Acquire DMA channel — held through the entire round-trip so the # channel models "one DMA in flight per PE per direction" rather # than just issue-time serialization. This is what makes Option B # meaningful: t_start = serve-start covers the actual transfer. with dma_res.request() as req: yield req # Option B: record_start fires AFTER channel acquired, so t_start # = serve-start (excludes queue wait). _DEFER_RECORD_START=True # suppresses the auto-start in ComponentBase._handle_with_hooks. self._on_process_start(env, cmd) # Create sub-Transaction with PeDmaMsg (HbmCtrl handles it directly) sub_done = env.event() sub_request = PeDmaMsg( correlation_id="pe_internal", request_id=f"dma_{id(pe_txn)}", src_sip=0, src_cube=0, src_pe=0, dst_pa=target_pa, nbytes=cmd.nbytes, is_write=is_write, ) sub_txn = Transaction( request=sub_request, path=path, step=0, nbytes=cmd.nbytes, done=sub_done, drain_ns=drain_ns, ) # Send to next hop (path[0] is pe_dma itself, path[1] is router) if len(path) > 1: yield self.out_ports[path[1]].put(sub_txn.advance()) # Wait for HBM transfer completion BEFORE releasing the channel. yield sub_done pe_txn.done.succeed() def _worker(self, env: simpy.Environment) -> Generator: """Handle TileToken (pipeline), PeInternalTxn (legacy), IpcqDmaToken, and Transaction (fabric).""" from kernbench.common.ipcq_types import IpcqDmaToken from kernbench.common.pe_commands import PeInternalTxn from kernbench.components.builtin.pe_types import TileToken while True: msg: Any = yield self._inbox.get() if isinstance(msg, IpcqDmaToken): # Outbound: IPCQ token from local PE_IPCQ → forward via fabric env.process(self._handle_ipcq_outbound(env, msg)) elif isinstance(msg, TileToken): env.process(self._pipeline_process(env, msg)) elif isinstance(msg, PeInternalTxn): env.process(self._handle_with_hooks(env, msg)) else: # Transaction (or unknown). May carry IpcqDmaToken inbound. req = getattr(msg, "request", None) if isinstance(req, IpcqDmaToken): env.process(self._handle_ipcq_inbound(env, msg)) else: env.process(self._forward_txn(env, msg)) # ── IPCQ outbound (PE_IPCQ → PE_DMA → fabric) ─────────────────── def _handle_ipcq_outbound(self, env: simpy.Environment, token: Any) -> Generator: """Forward IpcqDmaToken from local PE_IPCQ through the fabric to peer PE_DMA. ADR-0023 D8 (vc_comm channel).""" if self.ctx is None: return # nothing to do peer = token.dst_endpoint peer_pe_dma = f"sip{peer.sip}.cube{peer.cube}.pe{peer.pe}.pe_dma" # Snapshot the source data at send time (D9 in-flight semantics). # Without this, the receiver could read stale or future data if the # sender mutates src_addr between send issue and DMA arrival. store = getattr(self.ctx, "memory_store", None) if store is not None and token.data is None: try: snap = store.read( token.src_space, token.src_addr, shape=token.shape, dtype=token.dtype, ) # Copy so later mutations to src_addr don't affect the snapshot. token.data = snap.copy() if hasattr(snap, "copy") else snap except Exception: token.data = None # Note: ipcq_copy is recorded at INBOUND time (in _handle_ipcq_inbound), # not here. Outbound time is too early — it precedes fabric propagation, # so in Phase 2 a later round's copy can sort before the receiver's # math for an earlier round, causing slot data corruption. # The secondary sort in DataExecutor (memory ops before math at the # same t_start) ensures the inbound copy runs before the local math # that reads the slot. try: path = self.ctx.router.find_path(self._pe_prefix, peer_pe_dma) except Exception: return drain_ns = self.ctx.compute_drain_ns(path, token.nbytes) sub_done = env.event() sub_txn = Transaction( request=token, path=path, step=0, nbytes=token.nbytes, done=sub_done, drain_ns=drain_ns, ) if len(path) > 1: next_hop = path[1] if next_hop in self.out_ports: yield self.out_ports[next_hop].put(sub_txn.advance()) else: return # Note: don't wait on sub_done here — fire-and-forget for vc_comm. # IPCQ slot bookkeeping (peer_head) was already updated by PE_IPCQ; # backpressure is via credit return, not via this DMA's completion. # ── IPCQ inbound (fabric → PE_DMA → MemoryStore + PE_IPCQ) ────── def _handle_ipcq_inbound(self, env: simpy.Environment, txn: Any) -> Generator: """At destination PE_DMA: pay terminal drain, then atomically write data and forward metadata. ADR-0023 D9 (drain at inbound terminal): the Transaction carries ``drain_ns = nbytes / bottleneck_bw_on_path`` stamped by the sender PE_DMA. Like every other Transaction terminal in the simulator (see ``ComponentBase._forward_txn``), this drain must be paid when the Transaction reaches its destination. SRC-side ``tl.send`` is fire-and-forget — it never yields on ``sub_done`` — so paying the drain here does NOT delay the sender. What it DOES delay is the IpcqMetaArrival forwarded below: that delay is the only signal ``tl.recv`` on DST blocks on, which is exactly the desired semantics — "send dispatches and returns; recv waits until the bytes have actually landed in its inbox". The drain MUST be paid before the atomic block — inserting a yield inside would break invariant I6. I6 (MUST): no SimPy yield between MemoryStore.write and the IpcqMetaArrival put into PE_IPCQ. """ from kernbench.common.ipcq_types import IpcqMetaArrival # Pay terminal BW drain before the atomic write/metadata forward. # Without this, IPCQ effectively got fabric bandwidth for free at # the terminal (only intermediate-hop overhead_ns was charged), # making IPCQ lower than raw DMA at large sizes in benchmarks. drain = getattr(txn, "drain_ns", 0.0) if drain > 0: yield env.timeout(drain) token = txn.request # ADR-0023 D9.7: charge IPCQ slot-WRITE latency against the # backing-memory tier (tcm/sram/hbm) before the atomic block. # Must come BEFORE the atomic write→IpcqMetaArrival pair (I6). # SRAM/HBM also pay a PE_DMA→bank fabric drain (slot lives on # the cube NoC); TCM is per-PE local and skips this hop. from kernbench.common.ipcq_types import slot_io_latency_ns buffer_kind = token.dst_endpoint.buffer_kind if buffer_kind in ("sram", "hbm") and self.ctx is not None: cube_prefix = self._pe_prefix.rsplit(".", 1)[0] bank_node = ( f"{cube_prefix}.sram" if buffer_kind == "sram" else f"{cube_prefix}.hbm_ctrl" ) try: path = self.ctx.router.find_path(self._pe_prefix, bank_node) bank_drain_ns = self.ctx.compute_drain_ns(path, token.nbytes) if bank_drain_ns > 0: yield env.timeout(bank_drain_ns) except Exception: pass slot_write_ns = slot_io_latency_ns(buffer_kind, token.nbytes) if slot_write_ns > 0: yield env.timeout(slot_write_ns) # ── ATOMIC: do not introduce yield between these two operations ── # 1. Move data via MemoryStore (single-hop DMA write). # Prefer the in-flight snapshot stashed by the sender PE_DMA; # fall back to a fresh read of src_addr if no snapshot is present # (e.g. control-only token). store = getattr(self.ctx, "memory_store", None) if self.ctx else None if store is not None: try: data = token.data if data is None: data = store.read( token.src_space, token.src_addr, shape=token.shape, dtype=token.dtype, ) store.write(token.dst_endpoint.buffer_kind, token.dst_addr, data) except Exception: pass # Record the IPCQ copy at INBOUND time with embedded data snapshot. # The snapshot (token.data) was captured by the sender's outbound # PE_DMA at send time. Phase 2 writes the snapshot directly to # dst — it does NOT re-read from MemoryStore[src_addr], which may # have been mutated by a different PE's Phase 2 ops by that point. # DataExecutor's secondary sort (memory before math at same # t_start) ensures the write completes before the local math # that reads the slot. if self._op_logger is not None: try: self._op_logger.record_copy( t_start=float(env.now), t_end=float(env.now), component_id=self.node.id, src_space=token.src_space, src_addr=token.src_addr, dst_space=token.dst_endpoint.buffer_kind, dst_addr=token.dst_addr, shape=token.shape, dtype=token.dtype, nbytes=token.nbytes, snapshot=token.data, ) except Exception: pass # 2. Forward IpcqMetaArrival to local PE_IPCQ ipcq_id = f"{self._pe_prefix}.pe_ipcq" if ipcq_id in self.out_ports: yield self.out_ports[ipcq_id].put(IpcqMetaArrival(token=token)) # ───────────────────────────────────────────────────────────────── if not txn.done.triggered: txn.done.succeed() def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator: """Pipeline mode: DMA read/write via fabric, then self-route. Option B: record_start is fired *inside* _do_pipeline_dma, after the DMA channel is acquired — record_end stays here. """ yield from self._do_pipeline_dma(env, token) self._on_process_end(env, token) # Self-routing (handle same-component consecutive stages) next_stage = token.advance() while next_stage is not None and next_stage.component == self.node.id: yield from self._do_pipeline_dma(env, token) self._on_process_end(env, token) next_stage = token.advance() if next_stage is not None: yield self.out_ports[next_stage.component].put(token) else: token.pipeline_ctx.complete_tile() def _do_pipeline_dma(self, env, token): """Core DMA logic for pipeline mode.""" from kernbench.policy.address.phyaddr import PhysAddr from kernbench.runtime_api.kernel import PeDmaMsg params = token.params from kernbench.components.builtin.pe_types import StageType is_write = token.current_stage.stage_type == StageType.DMA_WRITE addr = params.get("dst_addr" if is_write else "src_addr", 0) nbytes = params.get("nbytes", 0) if nbytes > 0 and self.ctx: dma_res = self._dma_write if is_write else self._dma_read assert dma_res is not None # Translate VA → PA via MMU (same logic as non-pipeline path) target_pa = addr if self._mmu is not None: from kernbench.policy.address.pe_mmu import PageFault try: target_pa = self._mmu.translate(addr) except PageFault: target_pa = addr # fallback: treat as PA directly pa = PhysAddr.decode(target_pa) dst_node = self.ctx.resolver.resolve(pa) path = self.ctx.router.find_path(self._pe_prefix, dst_node) drain_ns = self.ctx.compute_drain_ns(path, nbytes) # Hold dma_res through the full round-trip — one DMA in flight # per PE per direction — so Option B's t_start (post-acquire) # bounds the actual transfer interval. with dma_res.request() as req: yield req # Option B: t_start = post-acquire moment. self._on_process_start(env, token) sub_done = env.event() sub_request = PeDmaMsg( correlation_id="pipeline", request_id=f"tile_{token.tile_id}", src_sip=0, src_cube=0, src_pe=0, dst_pa=target_pa, nbytes=nbytes, is_write=is_write, ) sub_txn = Transaction( request=sub_request, path=path, step=0, nbytes=nbytes, done=sub_done, drain_ns=drain_ns, ) if len(path) > 1: yield self.out_ports[path[1]].put(sub_txn.advance()) yield sub_done else: # No-op (nbytes==0 or no ctx): no channel wait, but still record # so _on_process_end has a matching pending entry to finalise. self._on_process_start(env, token) def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator: """Handle external Transaction (PeDmaMsg probe, M_CPU DMA) with channel acquisition.""" # Response transactions bypass DMA channel (no outbound resource needed) if getattr(txn, "is_response", False): next_hop = txn.next_hop if next_hop: yield self.out_ports[next_hop].put(txn.advance()) else: txn.done.succeed() return dma_res = self._select_channel(txn) with dma_res.request() as req: yield req next_hop = txn.next_hop if next_hop: yield self.out_ports[next_hop].put(txn.advance()) else: drain = getattr(txn, "drain_ns", 0.0) if drain > 0: yield env.timeout(drain) txn.done.succeed() def _select_channel(self, txn: Any) -> simpy.Resource: """Select DMA channel based on request type.""" from kernbench.runtime_api.kernel import MemoryWriteMsg assert self._dma_read is not None and self._dma_write is not None if isinstance(txn.request, MemoryWriteMsg): return self._dma_write return self._dma_read