from __future__ import annotations

import logging
from collections.abc import Generator
from typing import TYPE_CHECKING, Any

import simpy

from kernbench.components.base import PeEngineBase
from kernbench.sim_engine.transaction import Transaction

if TYPE_CHECKING:
    from kernbench.common.pe_commands import PeInternalTxn
    from kernbench.components.context import ComponentContext
    from kernbench.topology.types import Node

logger = logging.getLogger(__name__)


class PeDmaComponent(PeEngineBase):
    """PE_DMA: dual-channel DMA engine with READ and WRITE resources.

    Each channel has capacity=1 (ADR-0014 D4):
      - DMA_READ and DMA_WRITE may execute concurrently.
      - Multiple READs cannot overlap; multiple WRITEs cannot overlap.
    PE-internal and pipeline commands hold their channel only while issuing
    the command; the transfer itself completes after the channel is released.

    Inbox message types (dispatched by ``_worker``):
      - PeInternalTxn: PE-internal commands from PE_SCHEDULER
        (DmaReadCmd → HBM read, DmaWriteCmd → HBM write)
      - TileToken: pipeline-mode DMA stages
      - IpcqDmaToken: outbound IPCQ data from the local PE_IPCQ
      - Transaction: external fabric messages (PeDmaMsg probes, M_CPU DMA,
        and inbound Transactions carrying an IpcqDmaToken)
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Channels need a simpy.Environment; they are created in init_resources().
        self._dma_read: simpy.Resource | None = None
        self._dma_write: simpy.Resource | None = None
        self._mmu = None  # PeMMU instance, set by engine wiring

    def init_resources(self, env: simpy.Environment) -> None:
        """Create the READ and WRITE channels (capacity=1 each, ADR-0014 D4)."""
        self._dma_read = simpy.Resource(env, capacity=1)
        self._dma_write = simpy.Resource(env, capacity=1)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Zero-duration placeholder; ``nbytes`` is ignored.

        Actual transfers are driven by ``handle_command`` and the ``_worker``
        dispatch paths.
        """
        yield env.timeout(0)

    def _send_sub_txn(
        self,
        env: simpy.Environment,
        request: Any,
        path: Any,
        nbytes: int,
        drain_ns: float,
    ) -> Generator[Any, Any, simpy.Event]:
        """Wrap *request* in a sub-Transaction and send it to ``path[1]``.

        ``path[0]`` is this PE_DMA and ``path[1]`` is the next hop. Returns
        the sub-Transaction's ``done`` event for the caller to wait on.

        When the path has no next hop the transfer terminates here: the drain
        is paid locally and the event is succeeded, mirroring the terminal
        branch of ``_forward_txn``. Without this, the caller would wait on an
        event that nothing ever triggers.
        """
        sub_done = env.event()
        sub_txn = Transaction(
            request=request,
            path=path,
            step=0,
            nbytes=nbytes,
            done=sub_done,
            drain_ns=drain_ns,
        )
        if len(path) > 1:
            yield self.out_ports[path[1]].put(sub_txn.advance())
        else:
            if drain_ns > 0:
                yield env.timeout(drain_ns)
            sub_done.succeed()
        return sub_done

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Handle a PE-internal DMA command: resolve PA → HBM path → transfer.

        DmaReadCmd uses the READ channel and ``cmd.src_addr``; DmaWriteCmd
        uses the WRITE channel and ``cmd.dst_addr``. Any other command type
        completes ``pe_txn.done`` immediately. Otherwise ``pe_txn.done`` fires
        after the sub-Transaction to the HBM target completes.
        """
        from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
        from kernbench.policy.address.phyaddr import PhysAddr
        from kernbench.runtime_api.kernel import PeDmaMsg

        cmd = pe_txn.command
        assert self._dma_read is not None and self._dma_write is not None

        # Determine direction and (virtual or physical) target address.
        if isinstance(cmd, DmaReadCmd):
            dma_res, raw_addr, is_write = self._dma_read, cmd.src_addr, False
        elif isinstance(cmd, DmaWriteCmd):
            dma_res, raw_addr, is_write = self._dma_write, cmd.dst_addr, True
        else:
            pe_txn.done.succeed()
            return

        # Translate VA → PA via the MMU when one is wired. A PageFault means
        # "no mapping", and the address is then used as a PA directly
        # (backward-compatible with PA-only mode).
        if self._mmu is not None:
            from kernbench.policy.address.pe_mmu import PageFault

            try:
                target_pa = self._mmu.translate(raw_addr)
                if self._mmu.overhead_ns > 0:
                    yield env.timeout(self._mmu.overhead_ns)
            except PageFault:
                target_pa = raw_addr
        else:
            target_pa = raw_addr  # fallback: treat as PA directly

        pa = PhysAddr.decode(target_pa)
        dst_node = self.ctx.resolver.resolve(pa)
        path = self.ctx.router.find_path(self._pe_prefix, dst_node)
        drain_ns = self.ctx.compute_drain_ns(path, cmd.nbytes)

        # The channel serializes command issue only; it is released before
        # waiting for the transfer to finish.
        with dma_res.request() as req:
            yield req
            # HbmCtrl handles PeDmaMsg directly.
            sub_request = PeDmaMsg(
                correlation_id="pe_internal",
                request_id=f"dma_{id(pe_txn)}",
                src_sip=0,
                src_cube=0,
                src_pe=0,
                dst_pa=target_pa,
                nbytes=cmd.nbytes,
                is_write=is_write,
            )
            sub_done = yield from self._send_sub_txn(
                env, sub_request, path, cmd.nbytes, drain_ns
            )

        # Wait for HBM transfer completion.
        yield sub_done
        pe_txn.done.succeed()

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch inbox messages to a dedicated SimPy process each.

        TileToken → pipeline; PeInternalTxn → ``_handle_with_hooks`` (legacy);
        IpcqDmaToken → outbound IPCQ; a Transaction whose ``request`` is an
        IpcqDmaToken → inbound IPCQ; anything else → ``_forward_txn``.
        """
        from kernbench.common.ipcq_types import IpcqDmaToken
        from kernbench.common.pe_commands import PeInternalTxn
        from kernbench.components.builtin.pe_types import TileToken

        while True:
            msg: Any = yield self._inbox.get()
            if isinstance(msg, IpcqDmaToken):
                # Outbound: IPCQ token from local PE_IPCQ → forward via fabric.
                env.process(self._handle_ipcq_outbound(env, msg))
            elif isinstance(msg, TileToken):
                env.process(self._pipeline_process(env, msg))
            elif isinstance(msg, PeInternalTxn):
                env.process(self._handle_with_hooks(env, msg))
            elif isinstance(getattr(msg, "request", None), IpcqDmaToken):
                # Inbound: fabric Transaction carrying an IPCQ token.
                env.process(self._handle_ipcq_inbound(env, msg))
            else:
                # Other Transaction (or unknown message).
                env.process(self._forward_txn(env, msg))

    # ── IPCQ outbound (PE_IPCQ → PE_DMA → fabric) ───────────────────

    def _handle_ipcq_outbound(self, env: simpy.Environment, token: Any) -> Generator:
        """Forward an IpcqDmaToken from the local PE_IPCQ to the peer PE_DMA.

        ADR-0023 D8 (vc_comm channel). This is fire-and-forget: the sub-
        Transaction's ``done`` event is never awaited here. If no route or
        next hop exists, the token is dropped and a warning is logged.
        """
        if self.ctx is None:
            return  # nothing to do

        peer = token.dst_endpoint
        peer_pe_dma = f"sip{peer.sip}.cube{peer.cube}.pe{peer.pe}.pe_dma"

        # Snapshot the source data at send time (D9 in-flight semantics).
        # Without this, the receiver could read stale or future data if the
        # sender mutates src_addr between send issue and DMA arrival.
        store = getattr(self.ctx, "memory_store", None)
        if store is not None and token.data is None:
            try:
                snap = store.read(
                    token.src_space,
                    token.src_addr,
                    shape=token.shape,
                    dtype=token.dtype,
                )
                # Copy so later mutations to src_addr don't affect the snapshot.
                token.data = snap.copy() if hasattr(snap, "copy") else snap
            except Exception:
                # Best-effort: the inbound side re-reads src_addr when
                # token.data is None.
                logger.debug(
                    "%s: IPCQ send-time snapshot failed", self.node.id, exc_info=True
                )
                token.data = None

        # Note: ipcq_copy is recorded at INBOUND time (in _handle_ipcq_inbound),
        # not here. Outbound time is too early — it precedes fabric propagation,
        # so in Phase 2 a later round's copy can sort before the receiver's
        # math for an earlier round, causing slot data corruption.
        # The secondary sort in DataExecutor (memory ops before math at the
        # same t_start) ensures the inbound copy runs before the local math
        # that reads the slot.

        try:
            path = self.ctx.router.find_path(self._pe_prefix, peer_pe_dma)
        except Exception:
            logger.warning(
                "%s: no route to %s; dropping IPCQ token",
                self.node.id,
                peer_pe_dma,
                exc_info=True,
            )
            return

        drain_ns = self.ctx.compute_drain_ns(path, token.nbytes)
        sub_txn = Transaction(
            request=token,
            path=path,
            step=0,
            nbytes=token.nbytes,
            done=env.event(),
            drain_ns=drain_ns,
        )

        next_hop = path[1] if len(path) > 1 else None
        if next_hop is None or next_hop not in self.out_ports:
            logger.warning(
                "%s: next hop %r toward %s is not wired; dropping IPCQ token",
                self.node.id,
                next_hop,
                peer_pe_dma,
            )
            return
        yield self.out_ports[next_hop].put(sub_txn.advance())
        # Note: don't wait on the done event here — fire-and-forget for
        # vc_comm. IPCQ slot bookkeeping (peer_head) was already updated by
        # PE_IPCQ; backpressure is via credit return, not via this DMA's
        # completion.

    # ── IPCQ inbound (fabric → PE_DMA → MemoryStore + PE_IPCQ) ──────

    def _handle_ipcq_inbound(self, env: simpy.Environment, txn: Any) -> Generator:
        """At the destination PE_DMA: pay terminal drain, then atomically write
        data and forward metadata.

        ADR-0023 D9 (drain at inbound terminal): the Transaction carries
        ``drain_ns = nbytes / bottleneck_bw_on_path`` stamped by the sender
        PE_DMA. Like every other Transaction terminal in the simulator (see
        ``ComponentBase._forward_txn``), this drain must be paid when the
        Transaction reaches its destination.

        SRC-side ``tl.send`` is fire-and-forget — it never yields on the
        sub-Transaction's done event — so paying the drain here does NOT delay
        the sender. What it DOES delay is the IpcqMetaArrival forwarded below:
        that delay is the only signal ``tl.recv`` on DST blocks on, which is
        exactly the desired semantics — "send dispatches and returns; recv
        waits until the bytes have actually landed in its inbox".

        The drain MUST be paid before the atomic block — inserting a yield
        inside would break invariant I6.

        I6 (MUST): no SimPy yield between MemoryStore.write and the
        IpcqMetaArrival put into PE_IPCQ.
        """
        from kernbench.common.ipcq_types import IpcqMetaArrival

        # Pay terminal BW drain before the atomic write/metadata forward.
        # Without this, IPCQ effectively got fabric bandwidth for free at
        # the terminal (only intermediate-hop overhead_ns was charged),
        # making IPCQ lower than raw DMA at large sizes in benchmarks.
        drain = getattr(txn, "drain_ns", 0.0)
        if drain > 0:
            yield env.timeout(drain)

        token = txn.request

        # ── ATOMIC: do not introduce yield between these two operations ──
        # 1. Move data via MemoryStore (single-hop DMA write).
        #    Prefer the in-flight snapshot stashed by the sender PE_DMA;
        #    fall back to a fresh read of src_addr if no snapshot is present
        #    (e.g. control-only token).
        store = getattr(self.ctx, "memory_store", None) if self.ctx else None
        if store is not None:
            try:
                data = token.data
                if data is None:
                    data = store.read(
                        token.src_space,
                        token.src_addr,
                        shape=token.shape,
                        dtype=token.dtype,
                    )
                store.write(token.dst_endpoint.buffer_kind, token.dst_addr, data)
            except Exception:
                # Best-effort data movement; logging never yields, so I6 holds.
                logger.debug(
                    "%s: IPCQ inbound MemoryStore write failed",
                    self.node.id,
                    exc_info=True,
                )

        # Record the IPCQ copy at INBOUND time with embedded data snapshot.
        # The snapshot (token.data) was captured by the sender's outbound
        # PE_DMA at send time. Phase 2 writes the snapshot directly to
        # dst — it does NOT re-read from MemoryStore[src_addr], which may
        # have been mutated by a different PE's Phase 2 ops by that point.
        # DataExecutor's secondary sort (memory before math at same
        # t_start) ensures the write completes before the local math
        # that reads the slot.
        if self._op_logger is not None:
            try:
                self._op_logger.record_copy(
                    t_start=float(env.now),
                    t_end=float(env.now),
                    component_id=self.node.id,
                    src_space=token.src_space,
                    src_addr=token.src_addr,
                    dst_space=token.dst_endpoint.buffer_kind,
                    dst_addr=token.dst_addr,
                    shape=token.shape,
                    dtype=token.dtype,
                    nbytes=token.nbytes,
                    snapshot=token.data,
                )
            except Exception:
                logger.debug(
                    "%s: IPCQ record_copy failed", self.node.id, exc_info=True
                )

        # 2. Forward IpcqMetaArrival to local PE_IPCQ
        ipcq_id = f"{self._pe_prefix}.pe_ipcq"
        if ipcq_id in self.out_ports:
            yield self.out_ports[ipcq_id].put(IpcqMetaArrival(token=token))
        # ─────────────────────────────────────────────────────────────────

        if not txn.done.triggered:
            txn.done.succeed()

    def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
        """Pipeline mode: DMA read/write via fabric, then self-route.

        Consecutive stages that target this same component are executed here
        in a loop; the token is then sent to the next stage's component, or
        the tile is completed when no stage remains.
        """
        self._on_process_start(env, token)
        yield from self._do_pipeline_dma(env, token)
        self._on_process_end(env, token)

        # Self-routing (handle same-component consecutive stages)
        next_stage = token.advance()
        while next_stage is not None and next_stage.component == self.node.id:
            self._on_process_start(env, token)
            yield from self._do_pipeline_dma(env, token)
            self._on_process_end(env, token)
            next_stage = token.advance()

        if next_stage is not None:
            yield self.out_ports[next_stage.component].put(token)
        else:
            token.pipeline_ctx.complete_tile()

    def _do_pipeline_dma(self, env: simpy.Environment, token: Any) -> Generator:
        """Core DMA logic for pipeline mode.

        Reads ``dst_addr`` (DMA_WRITE stage) or ``src_addr`` (other stages)
        and ``nbytes`` from ``token.params``. Nothing happens when ``nbytes``
        is not positive or no context is wired.

        NOTE: the address is decoded as a PA directly; unlike
        ``handle_command`` no MMU translation is applied here.
        """
        from kernbench.components.builtin.pe_types import StageType
        from kernbench.policy.address.phyaddr import PhysAddr
        from kernbench.runtime_api.kernel import PeDmaMsg

        params = token.params
        is_write = token.current_stage.stage_type == StageType.DMA_WRITE
        addr = params.get("dst_addr" if is_write else "src_addr", 0)
        nbytes = params.get("nbytes", 0)
        if not (nbytes > 0 and self.ctx):
            return

        dma_res = self._dma_write if is_write else self._dma_read
        assert dma_res is not None
        pa = PhysAddr.decode(addr)
        dst_node = self.ctx.resolver.resolve(pa)
        path = self.ctx.router.find_path(self._pe_prefix, dst_node)
        drain_ns = self.ctx.compute_drain_ns(path, nbytes)

        # As in handle_command, the channel is held only for command issue.
        with dma_res.request() as req:
            yield req
            sub_request = PeDmaMsg(
                correlation_id="pipeline",
                request_id=f"tile_{token.tile_id}",
                src_sip=0,
                src_cube=0,
                src_pe=0,
                dst_pa=addr,
                nbytes=nbytes,
                is_write=is_write,
            )
            sub_done = yield from self._send_sub_txn(
                env, sub_request, path, nbytes, drain_ns
            )
        yield sub_done

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Handle an external Transaction (PeDmaMsg probe, M_CPU DMA).

        Responses (``is_response``) are forwarded without taking a channel.
        Other Transactions acquire the channel chosen by ``_select_channel``
        and are forwarded to their next hop; at the terminal hop the drain is
        paid (still holding the channel) and ``txn.done`` is succeeded.
        """
        # Response transactions bypass DMA channel (no outbound resource needed)
        if getattr(txn, "is_response", False):
            next_hop = txn.next_hop
            if next_hop:
                yield self.out_ports[next_hop].put(txn.advance())
            else:
                txn.done.succeed()
            return

        dma_res = self._select_channel(txn)
        with dma_res.request() as req:
            yield req
            next_hop = txn.next_hop
            if next_hop:
                yield self.out_ports[next_hop].put(txn.advance())
            else:
                drain = getattr(txn, "drain_ns", 0.0)
                if drain > 0:
                    yield env.timeout(drain)
                txn.done.succeed()

    def _select_channel(self, txn: Any) -> simpy.Resource:
        """Select the DMA channel from the request's direction.

        MemoryWriteMsg and PeDmaMsg with ``is_write=True`` use the WRITE
        channel; everything else uses the READ channel. Routing PeDmaMsg
        writes to WRITE keeps READ/WRITE concurrency (ADR-0014 D4) intact for
        write probes.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg

        assert self._dma_read is not None and self._dma_write is not None
        request = txn.request
        if isinstance(request, MemoryWriteMsg):
            return self._dma_write
        if isinstance(request, PeDmaMsg) and request.is_write:
            return self._dma_write
        return self._dma_read