"""PE_IPCQ component (ADR-0023): per-PE IPCQ control plane. Responsibilities: - Hold per-direction queue pair state (my_head, my_tail, peer_head_cache, peer_tail_cache, ring buffer addresses) - Process IpcqInitMsg from backend to install neighbor table - Handle IpcqRequest(IpcqSendCmd) from PE_CPU: compute peer slot address, check backpressure, forward IpcqDmaToken to PE_DMA (vc_comm) - Handle IpcqRequest(IpcqRecvCmd) from PE_CPU: wait for data arrival, return slot address (or copy to dst), send fast-path credit return - Handle IpcqMetaArrival from PE_DMA: update peer_head_cache, wake recv - Handle IpcqCreditMetadata via own credit_inbox: update peer_tail_cache, wake send PE_IPCQ does NOT move data — it forwards IpcqDmaToken to PE_DMA which performs the actual fabric DMA. Credit return uses a fast path: PE_IPCQ creates a SimPy process with a bottleneck-BW based latency, then puts IpcqCreditMetadata directly into the peer's pre-wired credit_store. """ from __future__ import annotations from collections.abc import Generator from typing import TYPE_CHECKING, Any import simpy from kernbench.common.ipcq_types import ( IpcqCreditMetadata, IpcqDmaToken, IpcqInvalidDirection, IpcqMetaArrival, IpcqRecvCmd, IpcqRequest, IpcqSendCmd, ) from kernbench.components.base import ComponentBase if TYPE_CHECKING: from kernbench.components.context import ComponentContext from kernbench.runtime_api.kernel import IpcqInitMsg from kernbench.topology.types import Node _DIR_ORDER: tuple[str, ...] = ("N", "S", "E", "W", "parent", "child_left", "child_right") class PeIpcqComponent(ComponentBase): """PE_IPCQ: ring buffer pointer + neighbor management for CCL. Owned by one PE; talks to PE_DMA via out_ports[] and receives credit return metadata via the public ``credit_inbox`` SimPy Store (wired by backend at IpcqInitMsg installation time). """ def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None: super().__init__(node, ctx) # Strict shape/dtype validation (D14 F2). Off by default. self._strict: bool = bool(node.attrs.get("strict_validation", False)) # direction → list of received tokens (for strict-mode peek of next slot) self._arrived_tokens: dict[str, list] = {} # Parse self (sip, cube, pe) from node id, e.g. "sip0.cube0.pe0.pe_ipcq" self._pe_prefix: str = node.id.rsplit(".", 1)[0] parts = self._pe_prefix.split(".") try: self._self_sip = int(parts[0].replace("sip", "")) except (IndexError, ValueError): self._self_sip = 0 try: self._self_cube = int(parts[1].replace("cube", "")) except (IndexError, ValueError): self._self_cube = 0 try: self._self_pe = int(parts[2].replace("pe", "")) except (IndexError, ValueError): self._self_pe = 0 self._dma_node_id = f"{self._pe_prefix}.pe_dma" # direction → state dict (see _install_neighbors for shape) self._queue_pairs: dict[str, dict[str, Any]] = {} self._installed = False self._buffer_kind: str = "tcm" self._backpressure_mode: str = "sleep" self._credit_size_bytes: int = 16 # waiters for recv (per direction) and any-direction (for round-robin) self._recv_waiters: dict[str, list[simpy.Event]] = {} self._any_recv_waiters: list[simpy.Event] = [] # waiters for send backpressure (per direction) self._send_waiters: dict[str, list[simpy.Event]] = {} # round-robin cursor over installed directions self._rr_dirs: list[str] = [] self._rr_cursor: int = 0 # credit_inbox is created in start() once env is available self._credit_inbox: simpy.Store | None = None # ── Public ── @property def credit_inbox(self) -> simpy.Store: """SimPy Store that backend wires as ``peer_credit_store`` on every remote sender targeting this PE. Used by D9 fast path.""" assert self._credit_inbox is not None, "PE_IPCQ not started yet" return self._credit_inbox @property def queue_pairs(self) -> dict[str, dict[str, Any]]: """Test/debug accessor.""" return self._queue_pairs # ── Lifecycle ── def run(self, env: simpy.Environment, nbytes: int) -> Generator: yield env.timeout(0) def start(self, env: simpy.Environment) -> None: # Create credit_inbox even if there are no in_ports yet if self._credit_inbox is None: self._credit_inbox = simpy.Store(env) # If no in_ports were wired (e.g. unit test), still spin up workers if not self.in_ports: self._inbox = simpy.Store(env) super().start(env) env.process(self._credit_worker(env)) # ── Worker (override of ComponentBase._worker) ── def _worker(self, env: simpy.Environment) -> Generator: from kernbench.runtime_api.kernel import IpcqInitMsg while True: msg: Any = yield self._inbox.get() # IpcqInitMsg may arrive wrapped in a transaction (with .request) # or directly. request_obj = getattr(msg, "request", None) if isinstance(request_obj, IpcqInitMsg): self._install_neighbors(request_obj) done = getattr(msg, "done", None) if done is not None and not done.triggered: done.succeed() continue if isinstance(msg, IpcqInitMsg): self._install_neighbors(msg) continue if isinstance(msg, IpcqMetaArrival): self._handle_meta_arrival(msg) continue if isinstance(msg, IpcqRequest): env.process(self._handle_request(env, msg)) continue # Unknown message — drop or forward via base class fallback env.process(self._forward_txn(env, msg)) # ── Init ── def _install_neighbors(self, msg: IpcqInitMsg) -> None: self._installed = True self._buffer_kind = msg.buffer_kind self._backpressure_mode = msg.backpressure_mode self._credit_size_bytes = msg.credit_size_bytes for entry in msg.entries: self._queue_pairs[entry.direction] = { "peer": entry.peer, "my_rx_base_pa": entry.my_rx_base_pa, "my_rx_base_va": entry.my_rx_base_va, "n_slots": entry.n_slots, "slot_size": entry.slot_size, "peer_credit_store": entry.peer_credit_store, "my_head": 0, "my_tail": 0, "peer_head_cache": 0, "peer_tail_cache": 0, } self._recv_waiters.setdefault(entry.direction, []) self._send_waiters.setdefault(entry.direction, []) # Reset round-robin order to a stable canonical sequence self._rr_dirs = [d for d in _DIR_ORDER if d in self._queue_pairs] self._rr_cursor = 0 # ── Send ── def _handle_request(self, env: simpy.Environment, req: IpcqRequest) -> Generator: cmd = req.command if isinstance(cmd, IpcqSendCmd): yield from self._handle_send(env, req, cmd) elif isinstance(cmd, IpcqRecvCmd): yield from self._handle_recv(env, req, cmd) def _handle_send( self, env: simpy.Environment, req: IpcqRequest, cmd: IpcqSendCmd, ) -> Generator: if cmd.direction not in self._queue_pairs: raise IpcqInvalidDirection( f"PE {self._pe_prefix}: direction {cmd.direction!r} not installed" ) qp = self._queue_pairs[cmd.direction] peer = qp["peer"] # Backpressure: wait while ring full while (qp["my_head"] - qp["peer_tail_cache"]) >= peer.n_slots: wait_event = env.event() self._send_waiters[cmd.direction].append(wait_event) yield wait_event # Compute peer slot address slot_idx = qp["my_head"] % peer.n_slots dst_pa = peer.rx_base_pa + slot_idx * peer.slot_size token = IpcqDmaToken( src_addr=cmd.src_addr, src_space=cmd.src_space, dst_addr=dst_pa, dst_endpoint=peer, nbytes=cmd.nbytes, handle_id=cmd.handle_id, shape=cmd.shape, dtype=cmd.dtype, # Carry the handle's recv-time data snapshot so the outbound # PE_DMA doesn't need to re-read from MemoryStore (which may # have been overwritten by a later inbound in the meantime). data=getattr(cmd, "data", None), sender_seq=qp["my_head"], src_sip=self._self_sip, src_cube=self._self_cube, src_pe=self._self_pe, src_direction=cmd.direction, ) # Forward to PE_DMA (vc_comm) yield self.out_ports[self._dma_node_id].put(token) qp["my_head"] += 1 # Diagnostics trace (D14) from kernbench.ccl import diagnostics if diagnostics.trace_enabled(): diagnostics.log_send( t_ns=float(env.now), sender=self._pe_prefix, direction=cmd.direction, nbytes=cmd.nbytes, sender_seq=qp["my_head"] - 1, ) if not req.done.triggered: req.done.succeed() # ── Recv ── def _handle_recv( self, env: simpy.Environment, req: IpcqRequest, cmd: IpcqRecvCmd, ) -> Generator: if cmd.direction is None: direction = yield from self._wait_any_direction(env) else: if cmd.direction not in self._queue_pairs: raise IpcqInvalidDirection( f"PE {self._pe_prefix}: direction {cmd.direction!r} not installed" ) direction = cmd.direction qp = self._queue_pairs[direction] while qp["peer_head_cache"] <= qp["my_tail"]: wait_event = env.event() self._recv_waiters[direction].append(wait_event) yield wait_event qp = self._queue_pairs[direction] slot_idx = qp["my_tail"] % qp["n_slots"] slot_addr = qp["my_rx_base_pa"] + slot_idx * qp["slot_size"] # Strict validation (D14 F2): peek the next-arrived token's metadata # against the recv command's expected shape/dtype/nbytes. arrived = self._arrived_tokens.get(direction, []) if arrived: front = arrived.pop(0) if self._strict: expected_nbytes = self._nbytes_for(cmd.shape, cmd.dtype) if front.dtype != cmd.dtype: raise ValueError( f"PE_IPCQ {self._pe_prefix} recv strict: dtype mismatch — " f"sender={front.dtype} recv={cmd.dtype}" ) if front.shape != cmd.shape: raise ValueError( f"PE_IPCQ {self._pe_prefix} recv strict: shape mismatch — " f"sender={front.shape} recv={cmd.shape}" ) if front.nbytes != expected_nbytes: raise ValueError( f"PE_IPCQ {self._pe_prefix} recv strict: nbytes mismatch — " f"sender={front.nbytes} recv={expected_nbytes}" ) req.result_data["src_space"] = self._buffer_kind req.result_data["src_addr"] = slot_addr req.result_data["direction"] = direction req.result_data["dtype"] = cmd.dtype req.result_data["shape"] = cmd.shape req.result_data["nbytes"] = self._nbytes_for(cmd.shape, cmd.dtype) # copy_to_dst mode: rebind the result handle to (dst_space, dst_addr). # When op_log is disabled, we also do the actual data move now; # when op_log is enabled, Phase 2 replays the slot→dst copy from # the op_log entry below so we don't pollute the slot in Phase 1. if cmd.recv_mode == "copy_to_dst" and self.ctx is not None: req.result_data["src_space"] = cmd.dst_space req.result_data["src_addr"] = cmd.dst_addr store = getattr(self.ctx, "memory_store", None) if store is not None and self._op_logger is None: try: data = store.read(self._buffer_kind, slot_addr, shape=cmd.shape, dtype=cmd.dtype) store.write(cmd.dst_space, cmd.dst_addr, data) except Exception: pass if self._op_logger is not None: # Record slot → dst copy for Phase 2 replay (ADR-0023 D9.5). try: self._op_logger.record_copy( t_start=float(env.now), t_end=float(env.now), component_id=self.node.id, src_space=self._buffer_kind, src_addr=slot_addr, dst_space=cmd.dst_space, dst_addr=cmd.dst_addr, shape=cmd.shape, dtype=cmd.dtype, nbytes=self._nbytes_for(cmd.shape, cmd.dtype), ) except Exception: pass qp["my_tail"] += 1 # Diagnostics trace (D14) from kernbench.ccl import diagnostics if diagnostics.trace_enabled(): diagnostics.log_recv( t_ns=float(env.now), receiver=self._pe_prefix, direction=direction, nbytes=req.result_data.get("nbytes", 0), ) # Fast path credit return — bottleneck BW based latency env.process( self._delayed_credit_send(env, direction, qp["peer_credit_store"], qp["my_tail"]) ) if not req.done.triggered: req.done.succeed() def _wait_any_direction(self, env: simpy.Environment) -> Generator: """Round-robin scan over installed directions; wait until at least one has data. Returns the chosen direction (str).""" if not self._rr_dirs: raise IpcqInvalidDirection( f"PE {self._pe_prefix}: no neighbors installed" ) while True: n = len(self._rr_dirs) for i in range(n): idx = (self._rr_cursor + i) % n d = self._rr_dirs[idx] qp = self._queue_pairs[d] if qp["peer_head_cache"] > qp["my_tail"]: self._rr_cursor = (idx + 1) % n return d # Nothing available — wait until any arrival wait_event = env.event() self._any_recv_waiters.append(wait_event) yield wait_event # ── Metadata arrival from PE_DMA (D9) ── def _handle_meta_arrival(self, msg: IpcqMetaArrival) -> None: """Match arrival to the correct direction by dst_addr range (ADR-0025 D2). Each direction has a unique rx buffer address range ([my_rx_base_pa, my_rx_base_pa + n_slots * slot_size)). The token's dst_addr (set by the sender's IPCQ when computing the peer slot address) falls within exactly one such range. Address-based matching is unambiguous even when multiple directions share the same peer (2-rank bidirectional ring). """ token = msg.token dst_addr = token.dst_addr for d, qp in self._queue_pairs.items(): base = qp["my_rx_base_pa"] size = qp["n_slots"] * qp["slot_size"] if base <= dst_addr < base + size: qp["peer_head_cache"] = max(qp["peer_head_cache"], token.sender_seq + 1) # Track arrived token for strict-mode peek self._arrived_tokens.setdefault(d, []).append(token) # Wake any blocked recv on this direction waiters = self._recv_waiters.get(d, []) self._recv_waiters[d] = [] for ev in waiters: if not ev.triggered: ev.succeed() # Wake any-direction waiters any_waiters = self._any_recv_waiters self._any_recv_waiters = [] for ev in any_waiters: if not ev.triggered: ev.succeed() return # Unknown dst_addr — silently drop (could log) # ── Credit return (fast path) ── def _credit_worker(self, env: simpy.Environment) -> Generator: """Process IpcqCreditMetadata from credit_inbox. Matches credit to the correct direction by `credit.dst_rx_base_pa == qp.peer.rx_base_pa` (ADR-0025 D3). This is unambiguous even when multiple directions share the same peer (2-rank bidirectional ring). """ assert self._credit_inbox is not None while True: credit: IpcqCreditMetadata = yield self._credit_inbox.get() for d, qp in self._queue_pairs.items(): if qp["peer"].rx_base_pa == credit.dst_rx_base_pa: qp["peer_tail_cache"] = max(qp["peer_tail_cache"], credit.consumer_seq) # Wake any blocked send on this direction waiters = self._send_waiters.get(d, []) self._send_waiters[d] = [] for ev in waiters: if not ev.triggered: ev.succeed() break def _delayed_credit_send( self, env: simpy.Environment, direction: str, peer_credit_store: simpy.Store, new_tail: int, ) -> Generator: """Wait bottleneck-BW latency, then put IpcqCreditMetadata into peer credit store (D9 fast path). Carries ``dst_rx_base_pa`` = this PE's my_rx_base_pa for the consumed direction. The peer (original sender) matches this against qp.peer.rx_base_pa to identify the correct qp (ADR-0025 D3). """ latency_ns = self._credit_latency_ns(direction) if latency_ns > 0: yield env.timeout(latency_ns) qp = self._queue_pairs[direction] meta = IpcqCreditMetadata( consumer_seq=new_tail, dst_rx_base_pa=qp["my_rx_base_pa"], src_sip=self._self_sip, src_cube=self._self_cube, src_pe=self._self_pe, src_direction=direction, ) yield peer_credit_store.put(meta) def _credit_latency_ns(self, direction: str) -> float: """Compute credit fast path latency = credit_size / bottleneck_bw. Falls back to 0 when ctx/router is unavailable (unit-test mode). """ if self.ctx is None: return 0.0 qp = self._queue_pairs[direction] peer = qp["peer"] peer_pe_prefix = f"sip{peer.sip}.cube{peer.cube}.pe{peer.pe}" try: path = self.ctx.router.find_path(self._pe_prefix, peer_pe_prefix) return self.ctx.compute_drain_ns(path, self._credit_size_bytes) except Exception: return 0.0 # ── Helpers ── @staticmethod def _nbytes_for(shape: tuple[int, ...], dtype: str) -> int: from math import prod bits = {"f16": 16, "bf16": 16, "f32": 32, "i8": 8, "i16": 16, "i32": 32}.get(dtype, 16) return prod(shape) * (bits // 8) if shape else 0