Add PE-level IPCQ collective infra + unified ccl_allreduce bench (ADR-0023)
Major changes:
PE-level IPCQ infrastructure:
- New PE_IPCQ component: ring-buffer control plane with 4-direction
neighbor mapping, head/tail pointers, backpressure (poll/sleep).
- PE_DMA extended with vc_comm channel for IPCQ outbound/inbound DMA,
including in-flight data snapshot (D9) and op_log recording at
outbound time for Phase 2 replay correctness.
- IpcqDmaToken piggyback model: data + metadata travel together,
atomic visibility at receiver (invariant I6).
- Credit return fast path: latency is derived from the bottleneck bandwidth on the path to the peer; credits bypass the fabric vc_comm channel.
Phase 2 data execution (ADR-0020 integration):
- op_log extended: DmaWriteCmd now captures src_space/src_addr for
Phase 2 dma_write copy; ipcq_copy ops recorded at outbound time.
- DataExecutor replays dma_write + ipcq_copy in t_start order.
- Engine._flush_data_phase: incremental cursor-based replay after
each engine.wait() so host reads see post-Phase-2 data.
- KernelRunner Phase 1 writes disabled when op_log is active to
prevent stale data from corrupting the MemoryStore snapshot.
TLContext / kernel API:
- tl.send(dir, src=TensorHandle), tl.recv(dir, shape, dtype),
tl.recv_async, tl.wait(RecvFuture), copy_to_dst mode.
- TensorHandle operator overloading (add/sub/mul/div) via thread-local
active TLContext → MathCmd dispatch through PE_MATH.
- PE-local scratch allocator for math output handles.
- tl.load returns space="hbm" handles for correct Phase 2 addressing.
- Additional math functions: maximum, minimum, fma, clamp, softmax, cdiv.
Unified ccl_allreduce bench (PyTorch-compat host code):
- Single benches/ccl_allreduce.py with run() + worker(rank, ws, torch)
split matching real PyTorch DDP worker pattern.
- torch.distributed facade: init_process_group, get_world_size,
get_rank, get_backend, all_reduce, barrier — only real PyTorch names.
- AhbmCCLBackend: eager install_ipcq at init, all_reduce dispatches
kernel via tensor shard metadata (n_elem from shards[0].nbytes).
- world_size derived from topology spec (sips × cubes × pes_per_cube)
with optional algorithm-level override in ccl.yaml.
Tensor API (PyTorch-compat surface):
- Tensor.numpy(): gather-aware (all shards via VA-based addressing).
- Tensor.copy_(source): scatter from host tensor into sharded target.
- RuntimeContext.from_numpy(arr): host-side staging tensor.
- Tensor.data property fixed to use numpy() (was shards[0]-only).
Algorithm modules moved to src/kernbench/ccl/algorithms/:
- ring_allreduce, mesh_allreduce, tree_allreduce, hello_send.
- Each module exports kernel_args(world_size, n_elem) helper.
- ccl.yaml module paths updated to kernbench.ccl.algorithms.*.
Dead code removed:
- 7 per-variant bench files (ccl_allreduce_{tcm,hbm,sram}, etc.).
- _run_ccl_bench greenlet-per-SIP scheduler.
- benches.loader.is_ccl_bench + run_rank detection.
- benches/ccl/ directory.
Tests:
- New test_ccl_allreduce_matrix.py: 7 parametrized cases
(ring×3 buffers, ring 8/16, mesh 4, tree 7).
- New test_runtime_api_tensor.py: copy_/numpy/from_numpy unit tests.
- Existing tests updated for new import paths + world_size_override.
Docs:
- Korean ccl-author-guide.md and ADR-0023 paths updated.
- New English versions: ccl-author-guide.en.md, ADR-0023.en.md.
502 tests pass.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -42,9 +42,30 @@ class PeCpuComponent(ComponentBase):
|
||||
self._cube_idx = int(parts[1].replace("cube", ""))
|
||||
except (IndexError, ValueError):
|
||||
self._cube_idx = 0
|
||||
# num_cubes from spec (for tl.program_id(axis=1))
|
||||
# num_cubes from spec (for tl.program_id(axis=1) — ADR-0022)
|
||||
spec = ctx.spec if ctx else {}
|
||||
self._num_cubes = spec.get("system", {}).get("sips", {}).get("cubes_per_sip", 1)
|
||||
cube_mesh = spec.get("sip", {}).get("cube_mesh", {})
|
||||
if cube_mesh:
|
||||
self._num_cubes = int(cube_mesh.get("w", 1)) * int(cube_mesh.get("h", 1))
|
||||
else:
|
||||
self._num_cubes = (
|
||||
spec.get("system", {}).get("sips", {}).get("cubes_per_sip", 1)
|
||||
)
|
||||
# PE-local scratch for kernel math output handles (ADR-0020 D3
|
||||
# extension; reserved portion of TCM addressed via a synthetic
|
||||
# MemoryStore key, not the real PA encoder).
|
||||
pe_template = spec.get("cube", {}).get("pe_template", {})
|
||||
tcm_attrs = pe_template.get("components", {}).get("pe_tcm", {}).get("attrs", {})
|
||||
scratch_mb = float(tcm_attrs.get("kernel_scratch_mb", 1))
|
||||
self._tl_scratch_size = int(scratch_mb * (1 << 20))
|
||||
# PE-unique base address — high bit pattern to avoid collision with
|
||||
# IPCQ ring buffers (which use bit 60).
|
||||
self._tl_scratch_base = (
|
||||
(1 << 61)
|
||||
| (self._sip_idx << 40)
|
||||
| (self._cube_idx << 32)
|
||||
| (self._pe_idx << 24)
|
||||
)
|
||||
|
||||
def _find_shard(self, shards: tuple) -> Any:
|
||||
"""Find shard matching this PE's (sip, cube, pe). Fallback to positional index."""
|
||||
@@ -146,6 +167,8 @@ class PeCpuComponent(ComponentBase):
|
||||
scheduler_id=scheduler_id,
|
||||
out_ports=self.out_ports,
|
||||
store=store,
|
||||
scratch_base=self._tl_scratch_base,
|
||||
scratch_size=self._tl_scratch_size,
|
||||
)
|
||||
yield from runner.run(env, kernel_fn, kernel_args, num_programs)
|
||||
return getattr(runner, "_composite_results", [])
|
||||
|
||||
@@ -106,18 +106,131 @@ class PeDmaComponent(PeEngineBase):
|
||||
pe_txn.done.succeed()
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
    """Dispatch every inbox message to the handler matching its type.

    Routing, checked in this order:
      * ``IpcqDmaToken``  -> ``_handle_ipcq_outbound`` (token coming from
        the local PE_IPCQ, to be forwarded via the fabric)
      * ``TileToken``     -> ``_pipeline_process``
      * ``PeInternalTxn`` -> ``_handle_with_hooks`` (legacy)
      * any other message whose ``request`` attribute is an
        ``IpcqDmaToken`` -> ``_handle_ipcq_inbound``
      * everything else   -> ``_forward_txn``

    Each handler is started as its own SimPy process, so the loop returns
    to waiting on the inbox immediately.
    """
    from kernbench.common.ipcq_types import IpcqDmaToken
    from kernbench.common.pe_commands import PeInternalTxn
    from kernbench.components.builtin.pe_types import TileToken

    while True:
        msg: Any = yield self._inbox.get()
        if isinstance(msg, IpcqDmaToken):
            # Outbound: IPCQ token from local PE_IPCQ -> forward via fabric.
            env.process(self._handle_ipcq_outbound(env, msg))
        elif isinstance(msg, TileToken):
            env.process(self._pipeline_process(env, msg))
        elif isinstance(msg, PeInternalTxn):
            env.process(self._handle_with_hooks(env, msg))
        else:
            # Transaction (or unknown). It may carry an inbound IpcqDmaToken
            # as its ``request``.
            req = getattr(msg, "request", None)
            if isinstance(req, IpcqDmaToken):
                env.process(self._handle_ipcq_inbound(env, msg))
            else:
                env.process(self._forward_txn(env, msg))
|
||||
|
||||
# ── IPCQ outbound (PE_IPCQ → PE_DMA → fabric) ───────────────────
|
||||
|
||||
def _handle_ipcq_outbound(self, env: simpy.Environment, token: Any) -> Generator:
    """Send an IpcqDmaToken from the local PE_IPCQ toward the peer PE_DMA.

    Steps, in order:
      1. Snapshot the source data into ``token.data`` (only when the token
         carries no data yet and the context exposes a ``memory_store``).
      2. Record the copy in the op log, stamped with the current sim time.
      3. Resolve the route to the peer PE_DMA and put a ``Transaction``
         wrapping the token on the next hop's out port.

    The transfer is fire-and-forget: this process never waits on the
    transaction's ``done`` event.

    None of the failure cases raise. A failed snapshot leaves
    ``token.data`` as ``None``; a missing route or next hop drops the
    token. Every such case is logged, because a dropped token is never
    delivered to the receiver.

    Args:
        env: SimPy environment driving this process.
        token: The IpcqDmaToken to forward; ``token.data`` may be filled in.
    """
    import logging

    log = logging.getLogger(__name__)

    if self.ctx is None:
        return  # nothing to do

    peer = token.dst_endpoint
    peer_pe_dma = f"sip{peer.sip}.cube{peer.cube}.pe{peer.pe}.pe_dma"

    # Snapshot the source data at send time (D9 in-flight semantics).
    # Without this, the receiver could read stale or future data if the
    # sender mutates src_addr between send issue and DMA arrival.
    store = getattr(self.ctx, "memory_store", None)
    if store is not None and token.data is None:
        try:
            snap = store.read(
                token.src_space, token.src_addr,
                shape=token.shape, dtype=token.dtype,
            )
            # Copy so later mutations to src_addr don't affect the snapshot.
            token.data = snap.copy() if hasattr(snap, "copy") else snap
        except Exception:
            log.warning(
                "%s: IPCQ snapshot of %s@%#x failed; sending without data",
                self.node.id, token.src_space, token.src_addr, exc_info=True,
            )
            token.data = None

    # Record the IPCQ copy in op_log at OUTBOUND time (ADR-0020 D6):
    # Phase 2 replays copies in t_start order, so stamping the copy here
    # keeps it ahead of any later local op at the sender that might
    # overwrite token.src_addr.
    if self._op_logger is not None:
        try:
            self._op_logger.record_copy(
                t_start=float(env.now), t_end=float(env.now),
                component_id=self.node.id,
                src_space=token.src_space, src_addr=token.src_addr,
                dst_space=peer.buffer_kind,
                dst_addr=token.dst_addr,
                shape=token.shape, dtype=token.dtype, nbytes=token.nbytes,
            )
        except Exception:
            log.warning(
                "%s: op_log record_copy failed for IPCQ send", self.node.id,
                exc_info=True,
            )

    try:
        path = self.ctx.router.find_path(self._pe_prefix, peer_pe_dma)
    except Exception:
        log.warning(
            "%s: no route to %s; IPCQ token dropped", self.node.id, peer_pe_dma,
            exc_info=True,
        )
        return
    drain_ns = self.ctx.compute_drain_ns(path, token.nbytes)

    sub_done = env.event()
    sub_txn = Transaction(
        request=token, path=path, step=0,
        nbytes=token.nbytes, done=sub_done, drain_ns=drain_ns,
    )
    if len(path) > 1:
        next_hop = path[1]
        if next_hop in self.out_ports:
            yield self.out_ports[next_hop].put(sub_txn.advance())
        else:
            log.warning(
                "%s: next hop %s is not wired; IPCQ token dropped",
                self.node.id, next_hop,
            )
            return
    else:
        log.warning(
            "%s: path to %s has no next hop; IPCQ token not forwarded",
            self.node.id, peer_pe_dma,
        )
    # Note: don't wait on sub_done here — fire-and-forget for vc_comm.
    # IPCQ slot bookkeeping was already updated by PE_IPCQ; backpressure is
    # via credit return, not via this DMA's completion.
|
||||
|
||||
# ── IPCQ inbound (fabric → PE_DMA → MemoryStore + PE_IPCQ) ──────
|
||||
|
||||
def _handle_ipcq_inbound(self, env: simpy.Environment, txn: Any) -> Generator:
    """At the destination PE_DMA: write the data, then forward the metadata.

    I6 (MUST): no SimPy yield between MemoryStore.write and the
    IpcqMetaArrival put into PE_IPCQ. The only yield in this method is on
    the put event itself, after ``put()`` has already been called.

    The payload is taken from ``token.data`` (the sender-side snapshot)
    when present, otherwise it is read from the token's source address.
    A failed data move is logged and does not raise; the metadata is still
    forwarded, as before. Finally the transaction's ``done`` event is
    triggered if it has not been already.

    Args:
        env: SimPy environment driving this process.
        txn: Transaction whose ``request`` is the arriving IpcqDmaToken.
    """
    import logging

    from kernbench.common.ipcq_types import IpcqMetaArrival

    log = logging.getLogger(__name__)
    token = txn.request

    # ── ATOMIC: do not introduce a yield between steps 1 and 2 ──
    # 1. Move data via MemoryStore. Prefer the in-flight snapshot stashed
    #    by the sender PE_DMA; fall back to a fresh read of src_addr if no
    #    snapshot is present.
    store = getattr(self.ctx, "memory_store", None) if self.ctx else None
    if store is not None:
        try:
            data = token.data
            if data is None:
                data = store.read(
                    token.src_space, token.src_addr,
                    shape=token.shape, dtype=token.dtype,
                )
            store.write(token.dst_endpoint.buffer_kind, token.dst_addr, data)
        except Exception:
            # Logging is synchronous, so it does not break the I6 invariant.
            log.warning(
                "%s: IPCQ inbound data write failed; metadata forwarded anyway",
                self._pe_prefix, exc_info=True,
            )

    # 2. Forward IpcqMetaArrival to the local PE_IPCQ.
    ipcq_id = f"{self._pe_prefix}.pe_ipcq"
    if ipcq_id in self.out_ports:
        yield self.out_ports[ipcq_id].put(IpcqMetaArrival(token=token))
    else:
        log.warning("%s: no out port to %s; IPCQ metadata dropped", self._pe_prefix, ipcq_id)
    # ─────────────────────────────────────────────────────────────

    if not txn.done.triggered:
        txn.done.succeed()
|
||||
|
||||
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
|
||||
"""Pipeline mode: DMA read/write via fabric, then self-route."""
|
||||
|
||||
@@ -0,0 +1,455 @@
|
||||
"""PE_IPCQ component (ADR-0023): per-PE IPCQ control plane.
|
||||
|
||||
Responsibilities:
|
||||
- Hold per-direction queue pair state (my_head, my_tail,
|
||||
peer_head_cache, peer_tail_cache, ring buffer addresses)
|
||||
- Process IpcqInitMsg from backend to install neighbor table
|
||||
- Handle IpcqRequest(IpcqSendCmd) from PE_CPU:
|
||||
compute peer slot address, check backpressure, forward
|
||||
IpcqDmaToken to PE_DMA (vc_comm)
|
||||
- Handle IpcqRequest(IpcqRecvCmd) from PE_CPU:
|
||||
wait for data arrival, return slot address (or copy to dst),
|
||||
send fast-path credit return
|
||||
- Handle IpcqMetaArrival from PE_DMA: update peer_head_cache, wake recv
|
||||
- Handle IpcqCreditMetadata via own credit_inbox: update peer_tail_cache,
|
||||
wake send
|
||||
|
||||
PE_IPCQ does NOT move data — it forwards IpcqDmaToken to PE_DMA which
|
||||
performs the actual fabric DMA.
|
||||
|
||||
Credit return uses a fast path: PE_IPCQ creates a SimPy process with a
|
||||
bottleneck-BW based latency, then puts IpcqCreditMetadata directly into
|
||||
the peer's pre-wired credit_store.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.common.ipcq_types import (
|
||||
IpcqCreditMetadata,
|
||||
IpcqDmaToken,
|
||||
IpcqInvalidDirection,
|
||||
IpcqMetaArrival,
|
||||
IpcqRecvCmd,
|
||||
IpcqRequest,
|
||||
IpcqSendCmd,
|
||||
)
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.runtime_api.kernel import IpcqInitMsg
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
_DIR_ORDER: tuple[str, ...] = ("N", "S", "E", "W", "parent", "child_left", "child_right")
|
||||
|
||||
|
||||
class PeIpcqComponent(ComponentBase):
|
||||
"""PE_IPCQ: ring buffer pointer + neighbor management for CCL.
|
||||
|
||||
Owned by one PE; talks to PE_DMA via out_ports[<pe_dma_id>] and
|
||||
receives credit return metadata via the public ``credit_inbox``
|
||||
SimPy Store (wired by backend at IpcqInitMsg installation time).
|
||||
"""
|
||||
|
||||
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
    """Set up empty IPCQ state; neighbors arrive later via IpcqInitMsg.

    The owning PE's (sip, cube, pe) indices are parsed from the node id
    (e.g. ``"sip0.cube0.pe0.pe_ipcq"``). Any segment that is missing or
    not numeric falls back to 0.

    Args:
        node: Topology node for this component; ``node.attrs`` may carry
            ``strict_validation``.
        ctx: Optional component context, passed through to the base class.
    """
    super().__init__(node, ctx)

    # Strict shape/dtype validation (D14 F2). Off by default.
    self._strict: bool = bool(node.attrs.get("strict_validation", False))
    # direction -> tokens received but not yet consumed by a recv
    self._arrived_tokens: dict[str, list] = {}

    # Everything before the last dot is the PE prefix ("sipX.cubeY.peZ").
    self._pe_prefix: str = node.id.rsplit(".", 1)[0]
    segments = self._pe_prefix.split(".")

    def _coord(position: int, tag: str) -> int:
        # Strip the textual tag and parse what is left; default to 0.
        try:
            return int(segments[position].replace(tag, ""))
        except (IndexError, ValueError):
            return 0

    self._self_sip = _coord(0, "sip")
    self._self_cube = _coord(1, "cube")
    self._self_pe = _coord(2, "pe")

    self._dma_node_id = f"{self._pe_prefix}.pe_dma"

    # direction -> per-queue-pair state dict (filled by _install_neighbors)
    self._queue_pairs: dict[str, dict[str, Any]] = {}
    self._installed = False
    self._buffer_kind: str = "tcm"
    self._backpressure_mode: str = "sleep"
    self._credit_size_bytes: int = 16

    # Events of recv requests blocked on one direction / on any direction.
    self._recv_waiters: dict[str, list[simpy.Event]] = {}
    self._any_recv_waiters: list[simpy.Event] = []
    # Events of send requests blocked on a full ring, per direction.
    self._send_waiters: dict[str, list[simpy.Event]] = {}

    # Round-robin scan order and cursor over installed directions.
    self._rr_dirs: list[str] = []
    self._rr_cursor: int = 0

    # Created in start(), once a SimPy environment is available.
    self._credit_inbox: simpy.Store | None = None
|
||||
|
||||
# ── Public ──
|
||||
|
||||
@property
|
||||
def credit_inbox(self) -> simpy.Store:
|
||||
"""SimPy Store that backend wires as ``peer_credit_store`` on
|
||||
every remote sender targeting this PE. Used by D9 fast path."""
|
||||
assert self._credit_inbox is not None, "PE_IPCQ not started yet"
|
||||
return self._credit_inbox
|
||||
|
||||
@property
|
||||
def queue_pairs(self) -> dict[str, dict[str, Any]]:
|
||||
"""Test/debug accessor."""
|
||||
return self._queue_pairs
|
||||
|
||||
# ── Lifecycle ──
|
||||
|
||||
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
    """Yield a single zero-length timeout; ``nbytes`` is not used."""
    zero_delay = env.timeout(0)
    yield zero_delay
|
||||
|
||||
def start(self, env: simpy.Environment) -> None:
    """Create the stores this component needs, then launch its processes.

    Runs the base-class ``start`` and additionally spawns the credit
    worker that drains ``credit_inbox``.
    """
    # With no in_ports wired (e.g. a unit test) give the component its own
    # inbox before the base-class start runs.
    if not self.in_ports:
        self._inbox = simpy.Store(env)

    # The credit inbox is independent of in_ports; create it once.
    if self._credit_inbox is None:
        self._credit_inbox = simpy.Store(env)

    super().start(env)
    env.process(self._credit_worker(env))
|
||||
|
||||
# ── Worker (override of ComponentBase._worker) ──
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
    """Main inbox loop: classify each message and act on it.

    * ``IpcqInitMsg`` (bare, or carried as ``msg.request``): install the
      neighbor table; for the wrapped form also trigger ``msg.done`` when
      it exists and has not fired yet.
    * ``IpcqMetaArrival``: handled synchronously.
    * ``IpcqRequest``: handled in a new SimPy process.
    * anything else: passed to ``_forward_txn`` in a new SimPy process.
    """
    from kernbench.runtime_api.kernel import IpcqInitMsg

    while True:
        msg: Any = yield self._inbox.get()

        wrapped = getattr(msg, "request", None)
        if isinstance(wrapped, IpcqInitMsg):
            # Init message wrapped in a transaction-like object.
            self._install_neighbors(wrapped)
            completion = getattr(msg, "done", None)
            if not (completion is None or completion.triggered):
                completion.succeed()
        elif isinstance(msg, IpcqInitMsg):
            # Init message delivered directly.
            self._install_neighbors(msg)
        elif isinstance(msg, IpcqMetaArrival):
            self._handle_meta_arrival(msg)
        elif isinstance(msg, IpcqRequest):
            env.process(self._handle_request(env, msg))
        else:
            # Unknown message — hand it to the base-class forwarder.
            env.process(self._forward_txn(env, msg))
|
||||
|
||||
# ── Init ──
|
||||
|
||||
def _install_neighbors(self, msg: IpcqInitMsg) -> None:
    """Install (or re-install) the neighbor table carried by ``msg``.

    Copies the global settings from the message, creates a fresh
    queue-pair state dict per entry with all four pointers at 0, makes
    sure waiter lists exist for each direction, and rebuilds the
    round-robin order following ``_DIR_ORDER``.
    """
    self._installed = True
    self._buffer_kind = msg.buffer_kind
    self._backpressure_mode = msg.backpressure_mode
    self._credit_size_bytes = msg.credit_size_bytes

    copied_fields = (
        "peer",
        "my_rx_base_pa",
        "my_rx_base_va",
        "n_slots",
        "slot_size",
        "peer_credit_store",
    )
    for entry in msg.entries:
        direction = entry.direction
        state: dict[str, Any] = {name: getattr(entry, name) for name in copied_fields}
        # Ring pointers and the cached view of the peer's pointers start at 0.
        state.update(my_head=0, my_tail=0, peer_head_cache=0, peer_tail_cache=0)
        self._queue_pairs[direction] = state
        # Keep any existing waiter lists; only create missing ones.
        self._recv_waiters.setdefault(direction, [])
        self._send_waiters.setdefault(direction, [])

    # Stable canonical scan order, restricted to installed directions.
    self._rr_dirs = [name for name in _DIR_ORDER if name in self._queue_pairs]
    self._rr_cursor = 0
|
||||
|
||||
# ── Send ──
|
||||
|
||||
def _handle_request(self, env: simpy.Environment, req: IpcqRequest) -> Generator:
    """Delegate a request to the send or recv handler based on its command.

    Commands that are neither ``IpcqSendCmd`` nor ``IpcqRecvCmd`` are
    ignored.
    """
    command = req.command
    if isinstance(command, IpcqSendCmd):
        handler = self._handle_send
    elif isinstance(command, IpcqRecvCmd):
        handler = self._handle_recv
    else:
        return
    yield from handler(env, req, command)
|
||||
|
||||
def _handle_send(
    self, env: simpy.Environment, req: IpcqRequest, cmd: IpcqSendCmd,
) -> Generator:
    """Claim the next slot in the peer's ring and hand a DMA token to PE_DMA.

    Blocks while the ring toward ``cmd.direction`` is full
    (``my_head - peer_tail_cache >= peer.n_slots``), then builds an
    ``IpcqDmaToken`` aimed at the peer slot and puts it on the PE_DMA out
    port. ``req.done`` is triggered once the token has been handed over.

    The sequence number is reserved (``my_head`` advanced) *before* the
    yield on the out-port put. Each request runs as its own SimPy
    process, so a second send on the same direction can run during that
    yield. With the increment after the yield, both sends could pass the
    backpressure check and take the same slot and ``sender_seq``.

    Raises:
        IpcqInvalidDirection: ``cmd.direction`` has no installed neighbor.
        KeyError: the PE_DMA out port is not wired (raised before any
            slot is reserved).
    """
    from kernbench.ccl import diagnostics

    if cmd.direction not in self._queue_pairs:
        raise IpcqInvalidDirection(
            f"PE {self._pe_prefix}: direction {cmd.direction!r} not installed"
        )
    qp = self._queue_pairs[cmd.direction]
    peer = qp["peer"]

    # Backpressure: wait while the ring is full. The condition is
    # re-checked after every wake-up because one credit wakes all waiters.
    while (qp["my_head"] - qp["peer_tail_cache"]) >= peer.n_slots:
        wait_event = env.event()
        self._send_waiters[cmd.direction].append(wait_event)
        yield wait_event

    # Resolve the port first so a wiring error cannot leak a reserved slot.
    dma_port = self.out_ports[self._dma_node_id]

    # Reserve the slot atomically with the check above (no yield between).
    seq = qp["my_head"]
    qp["my_head"] = seq + 1

    # Compute the peer slot address for the reserved sequence number.
    slot_idx = seq % peer.n_slots
    dst_pa = peer.rx_base_pa + slot_idx * peer.slot_size

    token = IpcqDmaToken(
        src_addr=cmd.src_addr,
        src_space=cmd.src_space,
        dst_addr=dst_pa,
        dst_endpoint=peer,
        nbytes=cmd.nbytes,
        handle_id=cmd.handle_id,
        shape=cmd.shape,
        dtype=cmd.dtype,
        sender_seq=seq,
        src_sip=self._self_sip,
        src_cube=self._self_cube,
        src_pe=self._self_pe,
        src_direction=cmd.direction,
    )

    # Forward to PE_DMA (vc_comm).
    yield dma_port.put(token)

    # Diagnostics trace (D14).
    if diagnostics.trace_enabled():
        diagnostics.log_send(
            t_ns=float(env.now), sender=self._pe_prefix,
            direction=cmd.direction, nbytes=cmd.nbytes,
            sender_seq=seq,
        )
    if not req.done.triggered:
        req.done.succeed()
|
||||
|
||||
# ── Recv ──
|
||||
|
||||
def _handle_recv(
    self, env: simpy.Environment, req: IpcqRequest, cmd: IpcqRecvCmd,
) -> Generator:
    """Consume the next arrived slot and report where its data lives.

    With ``cmd.direction`` set to ``None`` the direction is picked by
    ``_wait_any_direction``; otherwise the request blocks until the named
    direction has an unconsumed arrival. The result is written into
    ``req.result_data`` (src_space, src_addr, direction, dtype, shape,
    nbytes). In ``copy_to_dst`` mode the result points at the destination
    instead of the slot. After the tail pointer advances, a credit return
    is scheduled and ``req.done`` is triggered.

    Raises:
        IpcqInvalidDirection: the named direction is not installed.
        ValueError: strict validation is on and the arrived token's
            dtype, shape or nbytes differ from the recv command.
    """
    if cmd.direction is None:
        direction = yield from self._wait_any_direction(env)
    else:
        if cmd.direction not in self._queue_pairs:
            raise IpcqInvalidDirection(
                f"PE {self._pe_prefix}: direction {cmd.direction!r} not installed"
            )
        direction = cmd.direction
        pending = self._queue_pairs[direction]
        # Block until the producer's head is ahead of our tail.
        while pending["peer_head_cache"] <= pending["my_tail"]:
            arrival = env.event()
            self._recv_waiters[direction].append(arrival)
            yield arrival

    qp = self._queue_pairs[direction]
    slot_addr = qp["my_rx_base_pa"] + (qp["my_tail"] % qp["n_slots"]) * qp["slot_size"]

    # Strict validation (D14 F2): compare the oldest arrived token's
    # metadata with what the recv command expects. The token is removed
    # from the list whether or not strict mode is on.
    backlog = self._arrived_tokens.get(direction, [])
    if backlog:
        front = backlog.pop(0)
        if self._strict:
            expected_nbytes = self._nbytes_for(cmd.shape, cmd.dtype)
            if front.dtype != cmd.dtype:
                raise ValueError(
                    f"PE_IPCQ {self._pe_prefix} recv strict: dtype mismatch — "
                    f"sender={front.dtype} recv={cmd.dtype}"
                )
            if front.shape != cmd.shape:
                raise ValueError(
                    f"PE_IPCQ {self._pe_prefix} recv strict: shape mismatch — "
                    f"sender={front.shape} recv={cmd.shape}"
                )
            if front.nbytes != expected_nbytes:
                raise ValueError(
                    f"PE_IPCQ {self._pe_prefix} recv strict: nbytes mismatch — "
                    f"sender={front.nbytes} recv={expected_nbytes}"
                )

    nbytes = self._nbytes_for(cmd.shape, cmd.dtype)
    result = req.result_data
    result["src_space"] = self._buffer_kind
    result["src_addr"] = slot_addr
    result["direction"] = direction
    result["dtype"] = cmd.dtype
    result["shape"] = cmd.shape
    result["nbytes"] = nbytes

    # copy_to_dst mode: rebind the result to (dst_space, dst_addr).
    # Without an op logger the data is moved right here; with one, only a
    # slot -> dst copy record is written.
    # NOTE(review): the record_copy branch is read as nested inside the
    # copy_to_dst branch — confirm against the original indentation.
    if cmd.recv_mode == "copy_to_dst" and self.ctx is not None:
        result["src_space"] = cmd.dst_space
        result["src_addr"] = cmd.dst_addr
        store = getattr(self.ctx, "memory_store", None)
        logger = self._op_logger
        if store is not None and logger is None:
            try:
                payload = store.read(
                    self._buffer_kind, slot_addr, shape=cmd.shape, dtype=cmd.dtype,
                )
                store.write(cmd.dst_space, cmd.dst_addr, payload)
            except Exception:
                pass
        if logger is not None:
            # Record slot -> dst copy for Phase 2 replay (ADR-0023 D9.5).
            try:
                logger.record_copy(
                    t_start=float(env.now), t_end=float(env.now),
                    component_id=self.node.id,
                    src_space=self._buffer_kind, src_addr=slot_addr,
                    dst_space=cmd.dst_space, dst_addr=cmd.dst_addr,
                    shape=cmd.shape, dtype=cmd.dtype,
                    nbytes=nbytes,
                )
            except Exception:
                pass

    qp["my_tail"] += 1

    # Diagnostics trace (D14).
    from kernbench.ccl import diagnostics
    if diagnostics.trace_enabled():
        diagnostics.log_recv(
            t_ns=float(env.now), receiver=self._pe_prefix,
            direction=direction,
            nbytes=result.get("nbytes", 0),
        )

    # Fast-path credit return, carrying the tail value just reached.
    credit_proc = self._delayed_credit_send(
        env, direction, qp["peer_credit_store"], qp["my_tail"],
    )
    env.process(credit_proc)

    if not req.done.triggered:
        req.done.succeed()
|
||||
|
||||
def _wait_any_direction(self, env: simpy.Environment) -> Generator:
    """Return the first direction, in round-robin order, that has data.

    The scan starts at the round-robin cursor. On a hit the cursor moves
    to just after the chosen direction. When no direction has an
    unconsumed arrival the process waits for the next arrival and scans
    again.

    Raises:
        IpcqInvalidDirection: no neighbors are installed.
    """
    if not self._rr_dirs:
        raise IpcqInvalidDirection(
            f"PE {self._pe_prefix}: no neighbors installed"
        )
    while True:
        count = len(self._rr_dirs)
        start = self._rr_cursor
        for pos in ((start + step) % count for step in range(count)):
            candidate = self._rr_dirs[pos]
            state = self._queue_pairs[candidate]
            if state["peer_head_cache"] > state["my_tail"]:
                self._rr_cursor = (pos + 1) % count
                return candidate
        # Nothing available — park until any arrival wakes us.
        wake = env.event()
        self._any_recv_waiters.append(wake)
        yield wake
|
||||
|
||||
# ── Metadata arrival from PE_DMA (D9) ──
|
||||
|
||||
def _handle_meta_arrival(self, msg: IpcqMetaArrival) -> None:
    """Record an arrival from PE_DMA and wake blocked receivers.

    The sending PE is matched against the installed peers by
    (sip, cube, pe). Only the first matching direction is updated: its
    ``peer_head_cache`` is raised to ``sender_seq + 1`` (it never moves
    backwards), the token is queued for the next recv, and both the
    per-direction and the any-direction recv waiters are released.
    Arrivals from an unknown sender are silently dropped.
    """
    token = msg.token
    origin = (token.src_sip, token.src_cube, token.src_pe)

    def _release(events: list) -> None:
        # Trigger every event that has not fired yet.
        for event in events:
            if not event.triggered:
                event.succeed()

    for direction, state in self._queue_pairs.items():
        peer = state["peer"]
        if (peer.sip, peer.cube, peer.pe) != origin:
            continue

        state["peer_head_cache"] = max(state["peer_head_cache"], token.sender_seq + 1)
        # Keep the token so the next recv on this direction can inspect it.
        self._arrived_tokens.setdefault(direction, []).append(token)

        # Detach the waiter lists before firing the events.
        blocked = self._recv_waiters.get(direction, [])
        self._recv_waiters[direction] = []
        _release(blocked)

        blocked_any = self._any_recv_waiters
        self._any_recv_waiters = []
        _release(blocked_any)
        return
    # Unknown sender — silently drop (could log)
|
||||
|
||||
# ── Credit return (fast path) ──
|
||||
|
||||
def _credit_worker(self, env: simpy.Environment) -> Generator:
    """Drain ``credit_inbox`` forever, applying each credit as it arrives.

    A credit is matched to the first queue pair whose peer has the same
    (sip, cube, pe) as the credit's source. That pair's
    ``peer_tail_cache`` is raised to ``consumer_seq`` (never lowered) and
    every send blocked on that direction is woken. Credits from an
    unknown source are ignored.
    """
    assert self._credit_inbox is not None
    while True:
        credit: IpcqCreditMetadata = yield self._credit_inbox.get()
        origin = (credit.src_sip, credit.src_cube, credit.src_pe)

        hit = next(
            (
                (direction, state)
                for direction, state in self._queue_pairs.items()
                if (state["peer"].sip, state["peer"].cube, state["peer"].pe) == origin
            ),
            None,
        )
        if hit is None:
            continue

        direction, state = hit
        state["peer_tail_cache"] = max(state["peer_tail_cache"], credit.consumer_seq)

        # Detach the waiter list, then wake every blocked send.
        blocked = self._send_waiters.get(direction, [])
        self._send_waiters[direction] = []
        for event in blocked:
            if not event.triggered:
                event.succeed()
|
||||
|
||||
def _delayed_credit_send(
    self,
    env: simpy.Environment,
    direction: str,
    peer_credit_store: simpy.Store,
    new_tail: int,
) -> Generator:
    """Deliver a credit to the peer after the fast-path latency (D9).

    Waits ``_credit_latency_ns(direction)`` when that is positive, then
    puts an ``IpcqCreditMetadata`` carrying ``new_tail`` and this PE's
    coordinates into ``peer_credit_store``.
    """
    delay = self._credit_latency_ns(direction)
    if delay > 0:
        yield env.timeout(delay)

    credit = IpcqCreditMetadata(
        consumer_seq=new_tail,
        src_sip=self._self_sip,
        src_cube=self._self_cube,
        src_pe=self._self_pe,
        src_direction=direction,
    )
    yield peer_credit_store.put(credit)
|
||||
|
||||
def _credit_latency_ns(self, direction: str) -> float:
|
||||
"""Compute credit fast path latency = credit_size / bottleneck_bw.
|
||||
|
||||
Falls back to 0 when ctx/router is unavailable (unit-test mode).
|
||||
"""
|
||||
if self.ctx is None:
|
||||
return 0.0
|
||||
qp = self._queue_pairs[direction]
|
||||
peer = qp["peer"]
|
||||
peer_pe_prefix = f"sip{peer.sip}.cube{peer.cube}.pe{peer.pe}"
|
||||
try:
|
||||
path = self.ctx.router.find_path(self._pe_prefix, peer_pe_prefix)
|
||||
return self.ctx.compute_drain_ns(path, self._credit_size_bytes)
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
# ── Helpers ──
|
||||
|
||||
@staticmethod
|
||||
def _nbytes_for(shape: tuple[int, ...], dtype: str) -> int:
|
||||
from math import prod
|
||||
bits = {"f16": 16, "bf16": 16, "f32": 32, "i8": 8, "i16": 16, "i32": 32}.get(dtype, 16)
|
||||
return prod(shape) * (bits // 8) if shape else 0
|
||||
Reference in New Issue
Block a user