Add PE-level IPCQ collective infra + unified ccl_allreduce bench (ADR-0023)

Major changes:

PE-level IPCQ infrastructure:
- New PE_IPCQ component: ring-buffer control plane with 4-direction
  neighbor mapping, head/tail pointers, backpressure (poll/sleep).
- PE_DMA extended with vc_comm channel for IPCQ outbound/inbound DMA,
  including in-flight data snapshot (D9) and op_log recording at
  outbound time for Phase 2 replay correctness.
- IpcqDmaToken piggyback model: data + metadata travel together,
  atomic visibility at receiver (invariant I6).
- Credit return fast path: bottleneck-BW latency, no fabric vc_comm.

Phase 2 data execution (ADR-0020 integration):
- op_log extended: DmaWriteCmd now captures src_space/src_addr for
  Phase 2 dma_write copy; ipcq_copy ops recorded at outbound time.
- DataExecutor replays dma_write + ipcq_copy in t_start order.
- Engine._flush_data_phase: incremental cursor-based replay after
  each engine.wait() so host reads see post-Phase-2 data.
- KernelRunner Phase 1 writes disabled when op_log is active to
  prevent stale data from corrupting the MemoryStore snapshot.

TLContext / kernel API:
- tl.send(dir, src=TensorHandle), tl.recv(dir, shape, dtype),
  tl.recv_async, tl.wait(RecvFuture), copy_to_dst mode.
- TensorHandle operator overloading (add/sub/mul/div) via thread-local
  active TLContext → MathCmd dispatch through PE_MATH.
- PE-local scratch allocator for math output handles.
- tl.load returns space="hbm" handles for correct Phase 2 addressing.
- Additional math functions: maximum, minimum, fma, clamp, softmax, cdiv.

Unified ccl_allreduce bench (PyTorch-compat host code):
- Single benches/ccl_allreduce.py with run() + worker(rank, ws, torch)
  split matching real PyTorch DDP worker pattern.
- torch.distributed facade: init_process_group, get_world_size,
  get_rank, get_backend, all_reduce, barrier — only real PyTorch names.
- AhbmCCLBackend: eager install_ipcq at init, all_reduce dispatches
  kernel via tensor shard metadata (n_elem from shards[0].nbytes).
- world_size derived from topology spec (sips × cubes × pes_per_cube)
  with optional algorithm-level override in ccl.yaml.

Tensor API (PyTorch-compat surface):
- Tensor.numpy(): gather-aware (all shards via VA-based addressing).
- Tensor.copy_(source): scatter from host tensor into sharded target.
- RuntimeContext.from_numpy(arr): host-side staging tensor.
- Tensor.data property fixed to use numpy() (was shards[0]-only).

Algorithm modules moved to src/kernbench/ccl/algorithms/:
- ring_allreduce, mesh_allreduce, tree_allreduce, hello_send.
- Each module exports kernel_args(world_size, n_elem) helper.
- ccl.yaml module paths updated to kernbench.ccl.algorithms.*.

Dead code removed:
- 7 per-variant bench files (ccl_allreduce_{tcm,hbm,sram}, etc.).
- _run_ccl_bench greenlet-per-SIP scheduler.
- benches.loader.is_ccl_bench + run_rank detection.
- benches/ccl/ directory.

Tests:
- New test_ccl_allreduce_matrix.py: 7 parametrized cases
  (ring×3 buffers, ring 8/16, mesh 4, tree 7).
- New test_runtime_api_tensor.py: copy_/numpy/from_numpy unit tests.
- Existing tests updated for new import paths + world_size_override.

Docs:
- Korean ccl-author-guide.md and ADR-0023 paths updated.
- New English versions: ccl-author-guide.en.md, ADR-0023.en.md.

502 tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-12 19:36:59 -07:00
parent ff2c677a9c
commit 998cc85762
60 changed files with 9196 additions and 80 deletions
+9
View File
@@ -0,0 +1,9 @@
"""CCL (Collective Communication Library) framework for kernbench (ADR-0023).
This package provides:
- topologies: builtin neighbor topology generators (ring/mesh/tree)
- helpers: utilities for algorithm authors (chunked, ring_step, ...)
- testing: mock CCL runtime for fast unit tests of algorithm kernels
See docs/adr/ADR-0023-ipcq-pe-collective.md and docs/ccl-author-guide.md.
"""
@@ -0,0 +1,29 @@
"""Hello-world CCL kernel for the docs/ccl-author-guide.md walkthrough.
Each PE sends its tile to the E neighbor and receives one tile from W,
then stores the received tile back into its own HBM slice. The simplest
possible demonstration of ``tl.send`` / ``tl.recv``.
"""
from __future__ import annotations
def kernel_args(world_size: int, n_elem: int) -> tuple:
"""Return the positional kernel arguments for the ahbm backend."""
return (n_elem,)
def kernel(t_ptr, n_elem, tl):
local_pe = tl.program_id(axis=0)
cube_id = tl.program_id(axis=1)
pes_per_cube = tl.num_programs(axis=0)
rank = cube_id * pes_per_cube + local_pe
nbytes = n_elem * 2
pe_addr = t_ptr + rank * nbytes
# Send our local HBM tile to the E neighbor.
src = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
tl.send(dir="E", src=src)
# Receive a tile from W and store it into our slice (overwrite).
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
tl.store(pe_addr, recv)
@@ -0,0 +1,73 @@
"""2D-mesh all-reduce kernel (ADR-0023).
Two-phase reduce on a square mesh of side ``S`` (world_size = S*S):
1. Row reduce: ring all-reduce along E/W within each row.
2. Column reduce: ring all-reduce along N/S within each column.
After both phases, every rank holds the global sum.
Uses TensorHandle math (PE_MATH) for accumulation. Op_log captures the
data flow so Phase 2 produces correct final HBM contents. Math/recv
handles are passed directly to the next send, avoiding store→reload
which doesn't propagate correctly with timing-only Phase 1 math.
"""
from __future__ import annotations
import math
def kernel_args(world_size: int, n_elem: int) -> tuple:
"""Return the positional kernel arguments for the ahbm backend.
Mesh all-reduce requires ``world_size`` to be a perfect square —
the mesh side length is ``sqrt(world_size)``.
"""
side = int(round(math.sqrt(world_size)))
if side * side != world_size:
raise ValueError(
f"mesh_allreduce requires a square world_size; got {world_size}"
)
return (n_elem, side)
def kernel(t_ptr, n_elem, side, tl):
"""All-reduce on a square mesh.
Args:
t_ptr: HBM base address (column-sharded VA shared across ranks)
n_elem: number of f16 elements per tile
side: mesh side length (sqrt(world_size))
tl: TLContext (ADR-0022).
"""
local_pe = tl.program_id(axis=0)
cube_id = tl.program_id(axis=1)
pes_per_cube = tl.num_programs(axis=0)
rank = cube_id * pes_per_cube + local_pe
nbytes = n_elem * 2
pe_addr = t_ptr + rank * nbytes
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
current = acc
# ── Phase 1: row ring (E direction) ──
# Ring forwards each received tile (not the cumulative acc) so every
# tile passes through every rank exactly once.
for _ in range(side - 1):
tl.send(dir="E", src=current)
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
acc = acc + recv
current = recv
# Phase 2 column ring starts from the row-phase accumulator. We do NOT
# store/reload here — the math handle's scratch addr is the source for
# the first column send and Phase 2 ipcq_copy replays from there.
current = acc
# ── Phase 2: column ring (S direction) ──
for _ in range(side - 1):
tl.send(dir="S", src=current)
recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
acc = acc + recv
current = recv
tl.store(pe_addr, acc)
@@ -0,0 +1,80 @@
"""Ring all-reduce kernel for IPCQ-based PE collective (ADR-0023).
Algorithm: 1D ring of N PEs, each PE starts with one tile of data.
After ``world_size - 1`` rounds, every PE's accumulator holds the sum
of all PE tiles.
Strategy
--------
Each PE starts with its own tile in HBM. The kernel:
1. Loads the local tile into a TensorHandle (the accumulator).
2. In each of ``world_size - 1`` rounds:
- Sends the current accumulator/recv slot to the E neighbor.
- Receives a tile from the W neighbor — the recv handle points
into the per-direction TCM slot.
- Adds the received tile to the accumulator using the TensorHandle
operator overload, which dispatches to ``MathCmd`` (PE_MATH).
3. Stores the final accumulator back to HBM via tl.store. The store is
recorded in op_log with both src and dst, so Phase 2 will copy the
replayed math result from PE-local scratch into HBM.
ADR-0020 D3 split: Phase 1 simulates timing only — math results are
not yet computed, so the accumulator data flowing through Phase 1 may
be stale. Phase 2's DataExecutor replays math + IPCQ copies + dma_write
in stable t_start order, producing correct final HBM contents.
"""
from __future__ import annotations
def kernel_args(world_size: int, n_elem: int) -> tuple:
"""Return the positional kernel arguments for the ahbm backend.
Ring all-reduce takes (n_elem, world_size) after the tensor pointer.
"""
return (n_elem, world_size)
def kernel(t_ptr, n_elem, world_size, tl):
"""Ring all-reduce.
Args:
t_ptr: HBM base address of the column-sharded tensor — all PEs
share this base. The per-PE slice lives at
``t_ptr + global_rank * n_elem * 2``.
n_elem: number of f16 elements per tile.
world_size: total number of participating ranks (passed by host).
tl: TLContext (auto-injected, ADR-0022). The kernel derives the
global rank from ``program_id(axis=0)`` (local PE) and
``program_id(axis=1)`` (cube id):
rank = cube_id * pes_per_cube + local_pe
"""
local_pe = tl.program_id(axis=0)
cube_id = tl.program_id(axis=1)
pes_per_cube = tl.num_programs(axis=0)
rank = cube_id * pes_per_cube + local_pe
nbytes = n_elem * 2 # f16
# Each PE reads from its own slice of the shared base address
pe_addr = t_ptr + rank * nbytes
# Load the local tile — handle points at HBM[pe_addr].
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
# The ring forwards each received tile to the next neighbor (NOT the
# cumulative accumulator), so every rank's tile passes through every
# rank exactly once. The accumulator sums the new arrival each round.
current = acc
for _step in range(world_size - 1):
tl.send(dir="E", src=current)
recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
# TensorHandle add → MathCmd → PE_MATH (timing in Phase 1, real
# numpy in Phase 2 via DataExecutor). The result handle lives at
# an auto-allocated PE-local scratch addr.
acc = acc + recv
current = recv # forward W's tile to E next round
# Final result back to this PE's HBM slice. Op_log captures the
# source (scratch addr) and dst (HBM slice) so Phase 2 copies the
# accumulated value into HBM for verification.
tl.store(pe_addr, acc)
@@ -0,0 +1,80 @@
"""Tree all-reduce kernel for IPCQ-based PE collective (ADR-0023).
Two-phase binary tree all-reduce:
Phase 1 (reduce up):
- leaf nodes send their value to ``parent``
- internal nodes recv from each child, sum, then send to ``parent``
- root accumulates child contributions; final acc holds global sum
Phase 2 (broadcast down):
- root sends acc to ``child_left`` and ``child_right`` (if present)
- internal nodes recv from ``parent``, then forward to children
- all ranks store the final acc to HBM
Uses TensorHandle math (PE_MATH) for accumulation. Op_log captures the
data flow so Phase 2 produces correct final HBM contents. The kernel
deliberately avoids the store→reload→send pattern: math/recv handles
are passed directly to the next send so PE_DMA snapshots a deterministic
source addr that Phase 2 can replay.
"""
from __future__ import annotations
def kernel_args(world_size: int, n_elem: int) -> tuple:
"""Return the positional kernel arguments for the ahbm backend."""
return (n_elem, world_size)
def kernel(t_ptr, n_elem, world_size, tl):
"""Tree all-reduce.
Args:
t_ptr: HBM base address.
n_elem: number of f16 elements per tile.
world_size: total number of participating ranks (passed by host).
tl: TLContext (ADR-0022). Global rank from program_id(0/1).
"""
local_pe = tl.program_id(axis=0)
cube_id = tl.program_id(axis=1)
pes_per_cube = tl.num_programs(axis=0)
rank = cube_id * pes_per_cube + local_pe
nbytes = n_elem * 2
pe_addr = t_ptr + rank * nbytes
acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
# Compute children/parent existence (matches tree_binary topology generator)
has_parent = rank > 0
left = 2 * rank + 1
right = 2 * rank + 2
has_left = left < world_size
has_right = right < world_size
# ── Phase 1: reduce up ──
if has_left:
recv = tl.recv(dir="child_left", shape=(n_elem,), dtype="f16")
acc = acc + recv
if has_right:
recv = tl.recv(dir="child_right", shape=(n_elem,), dtype="f16")
acc = acc + recv
if has_parent:
# Send the math/load handle directly — its addr is either the
# original HBM tile (leaf) or the PE-local scratch where the
# accumulator lives. Phase 2 ipcq_copy replays from the same addr.
tl.send(dir="parent", src=acc)
# ── Phase 2: broadcast down ──
if has_parent:
# Replace acc with the value broadcast from the parent (the global
# sum). The recv handle points at the parent-direction TCM slot.
acc = tl.recv(dir="parent", shape=(n_elem,), dtype="f16")
if has_left:
tl.send(dir="child_left", src=acc)
if has_right:
tl.send(dir="child_right", src=acc)
# Final store to HBM for the bench's verification path.
tl.store(pe_addr, acc)
+127
View File
@@ -0,0 +1,127 @@
"""CCL diagnostics: trace + pointer dump + deadlock (ADR-0023 D14).
Trace
-----
Set ``KERNBENCH_CCL_TRACE=1`` (or any truthy value) to enable per-event
logging of CCL send/recv to stdout. Off by default.
Pointer dump
------------
``pointer_dump(engine)`` returns a multi-line string showing every PE_IPCQ's
ring buffer state (my_head, my_tail, peer_head_cache, peer_tail_cache).
Useful for diagnosing hangs.
Deadlock
--------
``IpcqDeadlock`` is raised by the engine when SimPy's schedule empties
while a request is still pending — typical of unmatched send/recv pairs.
The exception message includes the pointer dump.
"""
from __future__ import annotations
import os
from typing import Any
class IpcqDeadlock(RuntimeError):
"""Raised when the simulation cannot make further progress while a
CCL request is still pending (D14 F3)."""
# ── Trace toggle ─────────────────────────────────────────────────────
_TRACE_ENABLED: bool = False
def reload_trace_setting() -> None:
"""Re-read the ``KERNBENCH_CCL_TRACE`` env var."""
global _TRACE_ENABLED
val = os.environ.get("KERNBENCH_CCL_TRACE", "")
_TRACE_ENABLED = val.strip().lower() in {"1", "true", "yes", "on"}
def trace_enabled() -> bool:
return _TRACE_ENABLED
# Initialise once at import time
reload_trace_setting()
# ── Trace event functions ────────────────────────────────────────────
def log_send(
t_ns: float,
sender: str,
direction: str,
nbytes: int,
sender_seq: int,
) -> None:
if not _TRACE_ENABLED:
return
print(
f"[ccl t={t_ns:.1f} send] {sender} dir={direction} nbytes={nbytes} seq={sender_seq}",
flush=True,
)
def log_recv(
t_ns: float,
receiver: str,
direction: str,
nbytes: int,
) -> None:
if not _TRACE_ENABLED:
return
print(
f"[ccl t={t_ns:.1f} recv] {receiver} dir={direction} nbytes={nbytes}",
flush=True,
)
def log_credit_return(
t_ns: float,
sender: str,
direction: str,
consumer_seq: int,
) -> None:
if not _TRACE_ENABLED:
return
print(
f"[ccl t={t_ns:.1f} credit] {sender} dir={direction} seq={consumer_seq}",
flush=True,
)
# ── Pointer dump ─────────────────────────────────────────────────────
def pointer_dump(engine: Any) -> str:
"""Return a multi-line string of every PE_IPCQ's pointer state."""
lines: list[str] = []
components = getattr(engine, "_components", {})
for node_id in sorted(components):
if not node_id.endswith(".pe_ipcq"):
continue
comp = components[node_id]
qps = getattr(comp, "queue_pairs", {})
if not qps:
continue
lines.append(node_id)
for d in sorted(qps):
qp = qps[d]
peer = qp["peer"]
lines.append(
f" {d}: peer=sip{peer.sip}.cube{peer.cube}.pe{peer.pe} "
f"my_head={qp['my_head']} my_tail={qp['my_tail']} "
f"peer_head_cache={qp['peer_head_cache']} "
f"peer_tail_cache={qp['peer_tail_cache']}"
)
return "\n".join(lines)
def print_pointer_dump(engine: Any) -> None:
"""Convenience: print pointer_dump(engine) to stdout."""
print(pointer_dump(engine), flush=True)
+118
View File
@@ -0,0 +1,118 @@
"""Helpers for CCL algorithm authors (ADR-0023 D15).
These are pure utility functions usable from any kernel module:
from kernbench.ccl.helpers import chunked, ring_step, tree_step
They keep algorithm code short and free of off-by-one bugs.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
_DTYPE_BYTES = {
"f16": 2, "fp16": 2, "float16": 2, "bf16": 2,
"f32": 4, "fp32": 4, "float32": 4,
"i8": 1, "int8": 1,
"i16": 2, "int16": 2,
"i32": 4, "int32": 4,
}
def _itemsize(dtype: str) -> int:
if dtype not in _DTYPE_BYTES:
raise ValueError(f"Unsupported dtype: {dtype}")
return _DTYPE_BYTES[dtype]
# ── chunked ──────────────────────────────────────────────────────────
@dataclass(frozen=True)
class Chunk:
"""One chunk of a tensor used by collective algorithms."""
addr: int
n_elem: int
nbytes: int
def chunked(
base_addr: int,
n_chunks: int,
n_elem: int,
dtype: str = "f16",
) -> list[Chunk]:
"""Slice a 1D buffer into ``n_chunks`` equal Chunks.
Args:
base_addr: starting address of the buffer.
n_chunks: number of equal chunks to produce.
n_elem: total number of elements (must be divisible by n_chunks).
dtype: element type for byte-size calculation.
Returns:
List of ``Chunk`` objects whose addresses are consecutive.
Raises:
ValueError: if n_elem is not divisible by n_chunks.
"""
if n_elem % n_chunks != 0:
raise ValueError(
f"chunked: n_elem ({n_elem}) not divisible by n_chunks ({n_chunks})"
)
per_chunk_elem = n_elem // n_chunks
isize = _itemsize(dtype)
per_chunk_bytes = per_chunk_elem * isize
return [
Chunk(
addr=base_addr + i * per_chunk_bytes,
n_elem=per_chunk_elem,
nbytes=per_chunk_bytes,
)
for i in range(n_chunks)
]
# ── ring_step ────────────────────────────────────────────────────────
def ring_step(rank: int, step: int, world_size: int) -> tuple[int, int]:
"""Return ``(send_chunk_idx, recv_chunk_idx)`` for a ring algorithm step.
Standard reduce-scatter / all-gather ring schedule:
at step s, rank r sends chunk (r - s) and receives chunk (r - s - 1)
modulo world_size.
Used by ring all-reduce kernels:
for step in range(world_size - 1):
send_idx, recv_idx = ring_step(rank, step, world_size)
tl.send(dir="E", src=chunks[send_idx])
chunks[recv_idx] += tl.recv(dir="W").data
"""
send_idx = (rank - step) % world_size
recv_idx = (rank - step - 1) % world_size
return send_idx, recv_idx
# ── tree_step ────────────────────────────────────────────────────────
def tree_step(rank: int, world_size: int) -> dict[str, Any]:
"""Return parent/children for binary tree rooted at rank 0.
Returns:
``{"parent": int|None, "children": list[int]}``
"""
parent = (rank - 1) // 2 if rank > 0 else None
children: list[int] = []
left = 2 * rank + 1
right = 2 * rank + 2
if left < world_size:
children.append(left)
if right < world_size:
children.append(right)
return {"parent": parent, "children": children}
+266
View File
@@ -0,0 +1,266 @@
"""IPCQ install plan for AhbmCCLBackend (ADR-0023 D10/D11/D12).
Given a ccl.yaml config, the topology, and the engine, this module:
1. Loads ccl.yaml and resolves the chosen algorithm.
2. Maps each rank to a (sip, cube, pe) PE address using a linear scheme.
3. Allocates per-rank IPCQ ring buffer base addresses (synthetic but
unique-per-PE; see notes below).
4. Builds neighbor tables via the algorithm's ``topology`` field plus the
optional ``neighbors()`` override hook from the algorithm module.
5. Wires bidirectional credit-return SimPy Stores between every (PE, peer)
pair.
6. Installs each PE_IPCQ component's neighbor table directly via its
``_install_neighbors`` sideband call (equivalent to fan-out IpcqInitMsg
without going through fabric).
Address scheme
--------------
For the first implementation we use a synthetic address scheme that
guarantees uniqueness per (sip, cube, pe, direction) without going
through ``PEMemAllocator``. The address is encoded as:
base = IPCQ_BASE | (sip << 40) | (cube << 32) | (pe << 24)
rx_base[direction_idx] = base + direction_idx * (n_slots * slot_size)
The ``buffer_kind`` (tcm/hbm/sram) selects the *MemoryStore space* into
which data is written. Within a space, addresses are unique per PE so
the existing MemoryStore (``{space: {addr: ndarray}}``) handles them
naturally.
This bypasses the topology's address resolver / PhysAddr encoding and
treats IPCQ buffers as a separate, parallel address namespace. Real PA
encoding can be plugged in later without changing the rest of the design.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import simpy
import yaml
from kernbench.ccl.topologies import resolve_topology
from kernbench.common.ipcq_types import (
IpcqEndpoint,
IpcqInitEntry,
)
from kernbench.runtime_api.kernel import IpcqInitMsg
# IPCQ synthetic address space top bit
_IPCQ_BASE = 1 << 60
def _ipcq_base_for_pe(sip: int, cube: int, pe: int) -> int:
return _IPCQ_BASE | (sip << 40) | (cube << 32) | (pe << 24)
# ── ccl.yaml loading ─────────────────────────────────────────────────
def load_ccl_config(path: str | Path | None = None) -> dict:
"""Load and validate ccl.yaml. Searches cwd and project root."""
if path is None:
candidates = [
Path.cwd() / "ccl.yaml",
Path(__file__).resolve().parents[3] / "ccl.yaml",
]
for p in candidates:
if p.exists():
path = p
break
if path is None:
raise FileNotFoundError(
"ccl.yaml not found. Place it at project root or cwd."
)
with open(path) as f:
cfg = yaml.safe_load(f)
if "defaults" not in cfg:
raise ValueError("ccl.yaml missing 'defaults' section")
if "algorithms" not in cfg:
raise ValueError("ccl.yaml missing 'algorithms' section")
return cfg
def resolve_algorithm_config(cfg: dict, name: str | None = None) -> dict:
"""Merge defaults with the chosen algorithm's overrides.
Returns a flat dict with at minimum: module, topology, buffer_kind,
backpressure, n_slots, slot_size, ipcq_credit_size_bytes, world_size.
"""
defaults = dict(cfg.get("defaults", {}))
algo_name = name or defaults.get("algorithm")
if algo_name is None:
raise ValueError("ccl.yaml: defaults.algorithm not set")
algos = cfg.get("algorithms", {})
if algo_name not in algos:
raise ValueError(
f"ccl.yaml: algorithm '{algo_name}' not in algorithms section"
)
merged = defaults.copy()
merged.update(algos[algo_name])
merged["algorithm"] = algo_name
return merged
# ── rank → PE mapping ────────────────────────────────────────────────
def linear_rank_to_pe(rank: int, spec: dict) -> tuple[int, int, int]:
"""Map a rank to (sip, cube, pe) using linear topology order."""
sips = spec["system"]["sips"]["count"]
cubes_per_sip = spec["sip"]["cube_mesh"]["w"] * spec["sip"]["cube_mesh"]["h"]
pe_layout = spec["cube"]["pe_layout"]
pes_per_cube = pe_layout["pe_per_corner"] * len(pe_layout["corners"])
pes_per_sip = cubes_per_sip * pes_per_cube
if rank >= sips * pes_per_sip:
raise ValueError(
f"rank {rank} exceeds total PE count {sips * pes_per_sip}"
)
sip = rank // pes_per_sip
rem = rank % pes_per_sip
cube = rem // pes_per_cube
pe = rem % pes_per_cube
return sip, cube, pe
# ── Install plan ─────────────────────────────────────────────────────
def install_ipcq(
engine: Any,
spec: dict,
cfg: dict,
algo_module: Any | None = None,
rank_to_pe: list[tuple[int, int, int]] | None = None,
) -> dict[str, Any]:
"""Build neighbor tables and install them in every participating PE_IPCQ.
Args:
engine: GraphEngine with ``_components`` dict
spec: topology spec dict
cfg: merged algorithm config (from ``resolve_algorithm_config``)
algo_module: optional algorithm Python module (for neighbors override)
rank_to_pe: optional explicit rank → (sip, cube, pe) mapping. If
None, the default linear mapping is used.
Returns:
A diagnostics dict with the install plan (rank → PE map, neighbor table).
"""
if "world_size" in cfg:
world_size = int(cfg["world_size"])
else:
# Topology-derived fallback (mirrors AhbmCCLBackend / RuntimeContext).
sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
cm = spec.get("sip", {}).get("cube_mesh", {})
cubes_per_sip = int(cm.get("w", 1)) * int(cm.get("h", 1))
pl = spec.get("cube", {}).get("pe_layout", {})
corners = pl.get("corners", [])
pe_per_corner = int(pl.get("pe_per_corner", 1))
pes_per_cube = pe_per_corner * max(len(corners), 1)
world_size = sips * cubes_per_sip * pes_per_cube
buffer_kind = cfg["buffer_kind"]
n_slots = int(cfg["n_slots"])
slot_size = int(cfg["slot_size"])
backpressure = cfg["backpressure"]
credit_size_bytes = int(cfg.get("ipcq_credit_size_bytes", 16))
# Step 1: rank → (sip, cube, pe)
if rank_to_pe is not None:
if len(rank_to_pe) != world_size:
raise ValueError(
f"rank_to_pe has {len(rank_to_pe)} entries but world_size={world_size}"
)
rank_pe = list(rank_to_pe)
else:
rank_pe: list[tuple[int, int, int]] = [
linear_rank_to_pe(r, spec) for r in range(world_size)
]
pe_to_rank = {(s, c, p): r for r, (s, c, p) in enumerate(rank_pe)}
# Step 2: resolve topology fn (with optional override)
topo_fn = resolve_topology(cfg["topology"], algo_module=algo_module)
# Build per-rank neighbor map
neighbor_table: dict[int, dict[str, int]] = {}
for r in range(world_size):
neighbor_table[r] = topo_fn(r, world_size)
# Step 3: pull the live engine reference for each PE_IPCQ
components = engine._components
pe_ipcq_id = lambda s, c, p: f"sip{s}.cube{c}.pe{p}.pe_ipcq"
# Step 4: per-PE rx_base address and per-PE credit_inbox
direction_keys = sorted({d for nt in neighbor_table.values() for d in nt})
direction_idx = {d: i for i, d in enumerate(direction_keys)}
bytes_per_direction = n_slots * slot_size
def rx_base(s: int, c: int, p: int, d: str) -> int:
return _ipcq_base_for_pe(s, c, p) + direction_idx[d] * bytes_per_direction
# Wire bidirectional credit stores: backend creates the SimPy Stores
# by reading each rank's PE_IPCQ.credit_inbox property.
rank_to_credit_inbox: dict[int, simpy.Store] = {}
for r, (s, c, p) in enumerate(rank_pe):
comp = components[pe_ipcq_id(s, c, p)]
# Trigger lazy creation of credit_inbox if not yet started.
# PE_IPCQ.start() creates it; we ensure it exists.
if comp._credit_inbox is None:
comp._credit_inbox = simpy.Store(engine._env)
rank_to_credit_inbox[r] = comp.credit_inbox
# Step 5: build IpcqInitMsg per rank and call _install_neighbors directly
plan: dict[str, Any] = {
"world_size": world_size,
"rank_to_pe": rank_pe,
"buffer_kind": buffer_kind,
"neighbor_table": neighbor_table,
}
def reverse_direction(my_rank: int, peer_rank: int) -> str | None:
"""Find which direction in peer's neighbor table points back to my_rank."""
for d, target in neighbor_table[peer_rank].items():
if target == my_rank:
return d
return None
for r, (s, c, p) in enumerate(rank_pe):
my_pe_ipcq = components[pe_ipcq_id(s, c, p)]
nbrs = neighbor_table[r]
entries: list[IpcqInitEntry] = []
for d, peer_rank in nbrs.items():
if peer_rank is None:
continue
peer_s, peer_c, peer_p = rank_pe[peer_rank]
peer_dir = reverse_direction(r, peer_rank)
if peer_dir is None:
# Peer doesn't have a reverse entry — skip (asymmetric topology)
continue
peer_endpoint = IpcqEndpoint(
sip=peer_s, cube=peer_c, pe=peer_p,
buffer_kind=buffer_kind,
rx_base_pa=rx_base(peer_s, peer_c, peer_p, peer_dir),
rx_base_va=0,
n_slots=n_slots, slot_size=slot_size,
)
entries.append(IpcqInitEntry(
direction=d,
peer=peer_endpoint,
my_rx_base_pa=rx_base(s, c, p, d),
my_rx_base_va=0,
n_slots=n_slots, slot_size=slot_size,
peer_credit_store=rank_to_credit_inbox[peer_rank],
))
msg = IpcqInitMsg(
correlation_id="ccl_init", request_id=f"init_r{r}",
target_sips=(s,), target_cubes=(c,), target_pe=p,
entries=tuple(entries),
backpressure_mode=backpressure,
buffer_kind=buffer_kind,
credit_size_bytes=credit_size_bytes,
)
my_pe_ipcq._install_neighbors(msg)
return plan
+465
View File
@@ -0,0 +1,465 @@
"""Mock CCL runtime for fast unit tests of algorithm kernels (ADR-0023 D15).
Runs a kernel function once per rank with a minimal ``tl`` shim — no SimPy,
no PE_DMA, no fabric simulation. Just enough to verify *functional*
correctness of an IPCQ-based collective algorithm.
Cross-rank send/recv is implemented with greenlet cooperative scheduling
plus per-(rank, direction) FIFO queues. Backpressure is not modeled —
queues are unbounded.
Typical usage in a test::
from kernbench.ccl.testing import run_kernel_in_mock
from kernbench.ccl.algorithms.ring_allreduce import kernel
inputs = [np.full(16, r + 1, dtype="f16") for r in range(4)]
outputs = run_kernel_in_mock(
kernel_fn=kernel, world_size=4, topology="ring_1d",
inputs=inputs, kernel_args=(16,),
)
for r in range(4):
assert np.allclose(outputs[r], sum(inputs))
"""
from __future__ import annotations
from collections import deque
from typing import Any, Callable
import numpy as np
from greenlet import greenlet
from kernbench.ccl.topologies import resolve_topology
from kernbench.common.ipcq_types import IpcqInvalidDirection
from kernbench.common.pe_commands import TensorHandle
# ── Per-rank fake state ──────────────────────────────────────────────
class _MockRankState:
"""Per-rank scratch holding HBM/recv slots and tl shim hooks."""
def __init__(
self,
rank: int,
world_size: int,
neighbors: dict[str, int],
input_arr: np.ndarray,
) -> None:
self.rank = rank
self.world_size = world_size
self.neighbors = neighbors # direction → peer rank
# HBM "memory": addr → ndarray. Per-rank, no cross-rank sharing.
self._hbm: dict[int, np.ndarray] = {}
self._tcm: dict[int, np.ndarray] = {}
# ``t_ptr`` is the address the kernel sees. Real benches use a
# column-sharded VA so each rank reads from ``t_ptr + rank*nbytes``.
# Mirror that here: each rank's slice lives at the rank-specific addr.
nbytes = int(input_arr.nbytes)
self.t_ptr = 0 # base; per-rank offset is rank * nbytes
self._slice_addr = rank * nbytes
self._hbm[self._slice_addr] = input_arr.copy()
# Inbound recv FIFOs: direction → deque[ndarray]
self.recv_q: dict[str, deque[np.ndarray]] = {d: deque() for d in neighbors}
# Output (set when kernel calls tl.store at slice address)
self.output: np.ndarray | None = None
# Greenlet for this rank — set later
self.g: greenlet | None = None
# ── Mock TLContext ───────────────────────────────────────────────────
class _MockTL:
"""Drop-in tl shim for mock runtime.
Supports the subset of TLContext API that algorithm authors use:
program_id, num_programs, load, store, send, recv, recv_async, wait,
plus arithmetic operations on TensorHandle (eager numpy execution,
no SimPy involved).
"""
def __init__(self, state: _MockRankState, scheduler: "_MockScheduler") -> None:
self._state = state
self._scheduler = scheduler
self._handle_counter = 0
def _next_id(self) -> str:
self._handle_counter += 1
return f"mt{self._handle_counter}"
@property
def rank(self) -> int:
return self._state.rank
@property
def world_size(self) -> int:
return self._state.world_size
# axis-aware
def program_id(self, axis: int = 0) -> int:
return self._state.rank if axis == 0 else 0
def num_programs(self, axis: int = 0) -> int:
return self._state.world_size if axis == 0 else 1
# ── arithmetic ops (called by TensorHandle.__add__ etc.) ──
def _binary_math(self, op: str, a: TensorHandle, b: TensorHandle) -> TensorHandle:
a_data = np.asarray(a.data) if a.data is not None else None
b_data = np.asarray(b.data) if b.data is not None else None
if a_data is None or b_data is None:
result = None
elif op == "add":
result = a_data + b_data
elif op == "sub":
result = a_data - b_data
elif op == "mul":
result = a_data * b_data
elif op == "div":
result = a_data / b_data
elif op == "maximum":
result = np.maximum(a_data, b_data)
elif op == "minimum":
result = np.minimum(a_data, b_data)
else:
raise NotImplementedError(f"mock _binary_math: op {op!r} not implemented")
return TensorHandle(
id=self._next_id(),
addr=0, shape=a.shape, dtype=a.dtype,
nbytes=int(np.prod(a.shape)) * 2 if a.shape else 0,
data=result, space="tcm",
)
def maximum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
return self._binary_math("maximum", a, b)
def minimum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
return self._binary_math("minimum", a, b)
def fma(
self, a: TensorHandle, b: TensorHandle, c: TensorHandle,
) -> TensorHandle:
a_data = np.asarray(a.data) if a.data is not None else None
b_data = np.asarray(b.data) if b.data is not None else None
c_data = np.asarray(c.data) if c.data is not None else None
result = (
a_data * b_data + c_data
if (a_data is not None and b_data is not None and c_data is not None)
else None
)
return TensorHandle(
id=self._next_id(),
addr=0, shape=a.shape, dtype=a.dtype,
nbytes=int(np.prod(a.shape)) * 2 if a.shape else 0,
data=result, space="tcm",
)
def clamp(
self,
x: TensorHandle,
min: TensorHandle,
max: TensorHandle,
) -> TensorHandle:
x_data = np.asarray(x.data) if x.data is not None else None
lo = np.asarray(min.data) if min.data is not None else None
hi = np.asarray(max.data) if max.data is not None else None
result = (
np.minimum(np.maximum(x_data, lo), hi)
if (x_data is not None and lo is not None and hi is not None)
else None
)
return TensorHandle(
id=self._next_id(),
addr=0, shape=x.shape, dtype=x.dtype,
nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0,
data=result, space="tcm",
)
def softmax(self, x: TensorHandle, axis: int = -1) -> TensorHandle:
x_data = np.asarray(x.data) if x.data is not None else None
if x_data is None:
result = None
else:
x_max = np.max(x_data, axis=axis, keepdims=True)
e = np.exp(x_data - x_max)
s = np.sum(e, axis=axis, keepdims=True)
result = e / s
return TensorHandle(
id=self._next_id(),
addr=0, shape=x.shape, dtype=x.dtype,
nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0,
data=result, space="tcm",
)
@staticmethod
def cdiv(a: int, b: int) -> int:
return -(-int(a) // int(b))
def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
x_data = np.asarray(x.data) if x.data is not None else None
if x_data is None:
result = None
elif op == "exp":
result = np.exp(x_data)
elif op == "log":
result = np.log(x_data)
elif op == "sqrt":
result = np.sqrt(x_data)
elif op == "abs":
result = np.abs(x_data)
elif op == "sigmoid":
result = 1.0 / (1.0 + np.exp(-x_data))
elif op == "cos":
result = np.cos(x_data)
elif op == "sin":
result = np.sin(x_data)
else:
raise NotImplementedError(f"mock _unary_math: op {op!r} not implemented")
return TensorHandle(
id=self._next_id(),
addr=0, shape=x.shape, dtype=x.dtype,
nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0,
data=result, space="tcm",
)
def load(self, ptr: int, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
data = self._state._hbm.get(ptr)
if data is None:
data = np.zeros(shape, dtype=np.float16)
return TensorHandle(
id=f"load_{ptr}", addr=ptr, shape=shape, dtype=dtype,
nbytes=int(np.prod(shape)) * 2, data=data, space="hbm",
)
def store(self, ptr: int, handle: TensorHandle) -> None:
if handle.data is not None:
self._state._hbm[ptr] = np.asarray(handle.data)
if ptr == self._state._slice_addr:
self._state.output = self._state._hbm[ptr]
# IPCQ
def send(
self,
dir: str,
src: TensorHandle | None = None,
*,
src_addr: int | None = None,
nbytes: int | None = None,
shape: tuple[int, ...] | None = None,
dtype: str = "f16",
space: str = "tcm",
) -> None:
if dir not in self._state.neighbors:
raise IpcqInvalidDirection(
f"mock tl.send: direction {dir!r} not in neighbors {list(self._state.neighbors)}"
)
if src is not None:
if src.data is not None:
data = np.asarray(src.data)
else:
# Resolve from this rank's local memory at src.addr
space_dict = self._state._hbm if src.space == "hbm" else self._state._tcm
stored = space_dict.get(src.addr)
if stored is None:
raise RuntimeError(
f"mock tl.send: no data at {src.space}:0x{src.addr:x}"
)
data = np.asarray(stored)
else:
data = None
if data is None:
raise RuntimeError("mock tl.send: src is None")
peer_rank = self._state.neighbors[dir]
# Find the reverse direction in peer's neighbors that points back to me
peer_state = self._scheduler.states[peer_rank]
reverse_dir = None
for d, target in peer_state.neighbors.items():
if target == self._state.rank:
reverse_dir = d
break
if reverse_dir is None:
raise RuntimeError(
f"mock tl.send: peer rank {peer_rank} has no reverse direction"
)
peer_state.recv_q[reverse_dir].append(data.copy())
# After delivering, hand control back to scheduler so the receiver
# can wake up.
self._scheduler.yield_()
def recv_async(
self,
dir: str,
shape: tuple[int, ...] = (),
dtype: str = "f16",
) -> dict:
"""Non-blocking recv. Returns a future dict to pass to tl.wait."""
if dir not in self._state.neighbors:
raise IpcqInvalidDirection(
f"mock tl.recv_async: direction {dir!r} not in neighbors"
)
return {"_kind": "recv_future", "dir": dir, "shape": shape, "dtype": dtype}
def wait(self, future: Any) -> TensorHandle:
"""Block until the recv future has data."""
if not isinstance(future, dict) or future.get("_kind") != "recv_future":
raise TypeError("tl.wait: expected recv future from tl.recv_async")
d = future["dir"]
while not self._state.recv_q[d]:
self._scheduler.yield_()
data = self._state.recv_q[d].popleft()
return self._make_handle(data, d, future["dtype"])
def recv(
self,
dir: str | None = None,
shape: tuple[int, ...] = (),
dtype: str = "f16",
) -> TensorHandle:
if dir is not None and dir not in self._state.neighbors:
raise IpcqInvalidDirection(
f"mock tl.recv: direction {dir!r} not in neighbors {list(self._state.neighbors)}"
)
# Wait for data
while True:
if dir is None:
# round-robin over directions
for d in self._state.neighbors:
if self._state.recv_q[d]:
data = self._state.recv_q[d].popleft()
return self._make_handle(data, d, dtype)
else:
if self._state.recv_q[dir]:
data = self._state.recv_q[dir].popleft()
return self._make_handle(data, dir, dtype)
# Yield to other ranks
self._scheduler.yield_()
def _make_handle(self, data: np.ndarray, direction: str, dtype: str) -> TensorHandle:
return TensorHandle(
id=f"recv_{direction}",
addr=0, shape=data.shape, dtype=dtype,
nbytes=int(data.nbytes), data=data, space="tcm",
)
# ── Cooperative scheduler ────────────────────────────────────────────
class _MockScheduler:
"""Round-robin cooperative scheduler over rank greenlets."""
def __init__(self, states: list[_MockRankState]) -> None:
self.states = states
self._parent: greenlet | None = None
self._cur_idx = 0
def yield_(self) -> None:
"""Called from inside a rank greenlet to give other ranks a turn."""
assert self._parent is not None
self._parent.switch()
def run(self, kernel_fn: Callable, kernel_args: tuple) -> list[np.ndarray]:
from kernbench.triton_emu.tl_context import TLContext
self._parent = greenlet.getcurrent()
n = len(self.states)
# Per-rank tl shim
tls: dict[int, _MockTL] = {}
def _spawn(rank_idx: int) -> greenlet:
state = self.states[rank_idx]
tl = _MockTL(state, self)
tls[rank_idx] = tl
def _entry():
# Activate this rank's tl for TensorHandle operator overloads
TLContext._set_active(tl) # type: ignore[attr-defined]
try:
kernel_fn(state.t_ptr, *kernel_args, tl=tl)
finally:
TLContext._set_active(None) # type: ignore[attr-defined]
return greenlet(_entry)
for state in self.states:
state.g = _spawn(state.rank)
# Drive each rank round-robin until all dead. Detect global deadlock.
max_rounds = 10_000
round_no = 0
while True:
alive = [s for s in self.states if s.g is not None and not s.g.dead]
if not alive:
break
progressed = False
for s in self.states:
if s.g is None or s.g.dead:
continue
# Multi-rank greenlets share TLContext active state via the
# module-level thread-local; restore this rank's tl before
# resuming so TensorHandle operator overloads dispatch to
# the right _MockTL.
TLContext._set_active(tls[s.rank]) # type: ignore[attr-defined]
s.g.switch()
if s.g.dead:
progressed = True
TLContext._set_active(None) # type: ignore[attr-defined]
# Loose progress check: if no greenlet died and queues didn't grow,
# advance round counter; abort after too many idle rounds.
round_no += 1
if round_no > max_rounds and not progressed:
raise RuntimeError(
"mock CCL runtime: deadlock detected (no progress for "
f"{max_rounds} rounds)"
)
return [
s.output if s.output is not None else s._hbm.get(s._slice_addr)
for s in self.states
]
# ── Public entry ────────────────────────────────────────────────────
def run_kernel_in_mock(
kernel_fn: Callable,
world_size: int,
topology: str,
inputs: list[np.ndarray],
kernel_args: tuple = (),
algo_module: Any | None = None,
) -> list[np.ndarray]:
"""Run a CCL kernel under the mock runtime with no SimPy/fabric.
Args:
kernel_fn: ``kernel(t_ptr, *kernel_args, tl=...)``
world_size: number of ranks
topology: builtin topology name (e.g. "ring_1d")
inputs: per-rank input ndarrays. ``inputs[r]`` becomes rank r's
local tile at HBM address 0.
kernel_args: extra positional args after t_ptr
algo_module: optional module providing ``neighbors()`` override
Returns:
Per-rank output ndarrays — whatever the kernel wrote via tl.store
(or the original input if the kernel didn't store).
"""
if len(inputs) != world_size:
raise ValueError(f"len(inputs)={len(inputs)} != world_size={world_size}")
topo_fn = resolve_topology(topology, algo_module=algo_module)
states = [
_MockRankState(
rank=r, world_size=world_size,
neighbors=topo_fn(r, world_size),
input_arr=inputs[r],
)
for r in range(world_size)
]
sched = _MockScheduler(states)
return sched.run(kernel_fn, kernel_args)
+128
View File
@@ -0,0 +1,128 @@
"""Builtin neighbor topology generators for CCL backend (ADR-0023 D11).
Each generator takes ``(rank, world_size)`` and returns a
``dict[direction, peer_rank]`` for that rank. ``direction`` is one of
``"N" | "S" | "E" | "W"`` for ring/mesh, or
``"parent" | "child_left" | "child_right"`` for tree topologies.
Algorithm modules may override the generated map by defining a
``neighbors(rank, world_size, neighbor_map) -> dict | None`` function in
the same module (see D11 / D15). ``resolve_topology`` wires these together.
"""
from __future__ import annotations
from typing import Any, Callable
NeighborMap = dict[str, int]
TopologyFn = Callable[[int, int], NeighborMap]
# ── Builtin generators ───────────────────────────────────────────────
def ring_1d(rank: int, world_size: int) -> NeighborMap:
"""1D bidirectional ring (E/W)."""
return {
"E": (rank + 1) % world_size,
"W": (rank - 1) % world_size,
}
def ring_1d_unidir(rank: int, world_size: int) -> NeighborMap:
"""1D unidirectional ring (E only)."""
return {"E": (rank + 1) % world_size}
def mesh_2d(rank: int, world_size: int) -> NeighborMap:
"""Square 2D mesh (N/S/E/W).
Layout: rank = row * side + col, with side = sqrt(world_size).
Wrap-around (torus) on all four edges.
"""
side = int(round(world_size ** 0.5))
if side * side != world_size:
raise ValueError(
f"mesh_2d requires square world_size, got {world_size}"
)
r, c = divmod(rank, side)
return {
"N": ((r - 1) % side) * side + c,
"S": ((r + 1) % side) * side + c,
"W": r * side + (c - 1) % side,
"E": r * side + (c + 1) % side,
}
def tree_binary(rank: int, world_size: int) -> NeighborMap:
"""Binary tree rooted at rank 0.
Children of rank r are 2r+1 and 2r+2 (if within world_size).
Parent of rank r > 0 is (r-1)//2.
Returned keys (only those that exist):
"parent", "child_left", "child_right"
"""
n: NeighborMap = {}
if rank > 0:
n["parent"] = (rank - 1) // 2
left = 2 * rank + 1
right = 2 * rank + 2
if left < world_size:
n["child_left"] = left
if right < world_size:
n["child_right"] = right
return n
def none(rank: int, world_size: int) -> NeighborMap:
"""Empty map — algorithm's neighbors() must build from scratch."""
return {}
_BUILTIN: dict[str, TopologyFn] = {
"ring_1d": ring_1d,
"ring_1d_unidir": ring_1d_unidir,
"mesh_2d": mesh_2d,
"tree_binary": tree_binary,
"none": none,
}
# ── Resolution ───────────────────────────────────────────────────────
def resolve_topology(
name: str, algo_module: Any | None = None,
) -> TopologyFn:
"""Return a callable ``(rank, world_size) -> NeighborMap``.
Args:
name: builtin topology name from ccl.yaml. Must be one of
``ring_1d``, ``ring_1d_unidir``, ``mesh_2d``, ``tree_binary``,
or ``none``.
algo_module: optional algorithm module. If it defines
``neighbors(rank, world_size, neighbor_map)``, that hook is
invoked after the builtin to override the result.
Returning None from neighbors() leaves the builtin map
unchanged; returning a dict replaces it.
Raises:
ValueError: if ``name`` is not a known builtin.
"""
if name not in _BUILTIN:
raise ValueError(
f"Unknown topology '{name}'. "
f"Available builtins: {list(_BUILTIN)}"
)
builtin_fn = _BUILTIN[name]
override_fn = getattr(algo_module, "neighbors", None) if algo_module else None
if override_fn is None or not callable(override_fn):
return builtin_fn
def _wrapped(rank: int, world_size: int) -> NeighborMap:
base = builtin_fn(rank, world_size)
result = override_fn(rank, world_size, base)
if result is None:
return base
return result
return _wrapped
+234
View File
@@ -0,0 +1,234 @@
"""IPCQ schemas and exceptions (ADR-0023 D2.5, D12, D14 F1).
This module contains the data structures and exceptions used by the
PE-level IPCQ collective communication infrastructure. The host-facing
sideband fan-out message ``IpcqInitMsg`` lives in
``kernbench.runtime_api.kernel`` (alongside other fabric messages),
while all internal token / metadata / command schemas are kept here.
Layering:
PE_CPU --IpcqRequest(IpcqSendCmd|IpcqRecvCmd)--> PE_IPCQ
PE_IPCQ --IpcqDmaToken--> PE_DMA (vc_comm)
PE_DMA --IpcqMetaArrival--> PE_IPCQ (atomic, D9)
PE_IPCQ --IpcqCreditMetadata--> peer PE_IPCQ (fast path, D9)
See ADR-0023 for the full design.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Union
if TYPE_CHECKING:
import simpy
# ── D14 F1: invalid direction exception ──────────────────────────────
class IpcqInvalidDirection(ValueError):
"""Raised when a kernel calls tl.send/recv with a direction that
has no neighbor installed for this PE."""
# ── D2.5: IpcqEndpoint ───────────────────────────────────────────────
@dataclass(frozen=True)
class IpcqEndpoint:
"""송신 측이 peer's rx_buffer 주소를 계산하기 위해 필요한 모든 정보 (D2.5).
Sender PE_IPCQ uses this to compute the destination PA for its DMA
write into the peer's rx ring buffer slot:
slot_idx = sender.my_head % peer.n_slots
dst_pa = peer.rx_base_pa + slot_idx * peer.slot_size
"""
sip: int # destination SIP
cube: int # destination cube
pe: int # destination PE (cube-local index)
buffer_kind: str # "tcm" | "hbm" | "sram"
rx_base_pa: int # peer rx_buffer base PA (PhysAddr.encode())
rx_base_va: int # peer rx_buffer base VA (optional, MMU)
n_slots: int # peer ring depth (wrap-around modulo)
slot_size: int # peer slot size (offset multiplier)
# ── D12: IpcqInitEntry (used by IpcqInitMsg in kernel.py) ────────────
@dataclass(frozen=True)
class IpcqInitEntry:
"""One direction's neighbor entry that backend installs into a PE_IPCQ
via IpcqInitMsg (kernbench.runtime_api.kernel.IpcqInitMsg, D12).
"""
direction: str # "N" | "S" | "E" | "W"
peer: IpcqEndpoint # see D2.5
my_rx_base_pa: int # this PE's own rx_buffer base
my_rx_base_va: int # this PE's own rx_buffer base VA (optional)
n_slots: int # this PE's ring depth
slot_size: int # this PE's slot size
# Credit fast path channel (D9).
# Contract: must be a simpy.Store instance dedicated to receiving
# IpcqCreditMetadata objects only. Backend wires it once at init time
# and the receiving PE_IPCQ owns its consumer side; the sender (peer's
# PE_IPCQ) puts IpcqCreditMetadata directly into this store via
# _delayed_credit_send. Do not put any other object type.
peer_credit_store: "simpy.Store"
# ── D12: IpcqSendCmd (PE_CPU → PE_IPCQ) ──────────────────────────────
@dataclass(frozen=True)
class IpcqSendCmd:
"""tl.send command issued by the kernel to PE_IPCQ."""
direction: str # "N" | "S" | "E" | "W"
src_addr: int # source data address (TCM/HBM/SRAM)
src_space: str # "tcm" | "hbm" | "sram"
nbytes: int
shape: tuple[int, ...] # data shape (op_log + MemoryStore use)
dtype: str
handle_id: str # completion tracking
data_op: bool = True # ADR-0020 op_log recording flag
# ── D12: IpcqRecvCmd (PE_CPU → PE_IPCQ) ──────────────────────────────
@dataclass(frozen=True)
class IpcqRecvCmd:
"""tl.recv command issued by the kernel to PE_IPCQ.
Two modes (recv_mode):
"return_slot" — return slot address as-is (default, zero-copy).
Kernel uses the slot memory directly.
"copy_to_dst" — copy slot data to dst_addr, then return.
"""
direction: str | None # None → round-robin (weak fairness, D4)
shape: tuple[int, ...]
dtype: str
handle_id: str
recv_mode: str = "return_slot"
dst_addr: int = 0 # used only when recv_mode == "copy_to_dst"
dst_space: str = "" # used only when recv_mode == "copy_to_dst"
blocking: bool = True
data_op: bool = True
# ── D12: IpcqDmaToken (PE_IPCQ → PE_DMA, vc_comm) ───────────────────
@dataclass
class IpcqDmaToken:
"""Token sent from PE_IPCQ to PE_DMA (vc_comm channel) carrying both
the data move request and the piggyback metadata (ADR-0023 D9).
Receiving PE_DMA processes this atomically (I6 MUST):
1. MemoryStore.write(dst_endpoint.buffer_kind, dst_addr, data)
2. Forward IpcqMetaArrival(token=self) to peer PE_IPCQ
No yield is allowed between the two steps.
The ``data`` field is a snapshot taken by the sender's PE_DMA at the
moment the send is issued. This preserves "in-flight data" semantics:
if the sender mutates its source memory after issuing the send but
before arrival, the receiver still gets the snapshot. The snapshot is
None for control-only tokens (e.g. credit-only updates).
"""
# ── Data movement (single-hop DMA write) ──
src_addr: int
src_space: str
dst_addr: int # already-computed peer rx slot PA
dst_endpoint: IpcqEndpoint # routing target (sip/cube/pe) + buffer_kind
nbytes: int
handle_id: str # completion notify back to sender PE_IPCQ
# Optional shape/dtype carried for op_log + MemoryStore convenience.
shape: tuple[int, ...] = ()
dtype: str = "f16"
# In-flight data snapshot (sender PE_DMA captures this at send time).
data: Any = None
# ── Piggyback metadata (D9) ──
sender_seq: int = 0 # monotonic; receiver updates peer_head_cache
src_sip: int = 0
src_cube: int = 0
src_pe: int = 0
src_direction: str = "E" # sender-side direction; receiver maps to its own
data_op: bool = True
# ── D12: IpcqMetaArrival (PE_DMA → PE_IPCQ, intra-PE wire) ──────────
@dataclass
class IpcqMetaArrival:
"""Posted by receiving PE_DMA into the destination PE's PE_IPCQ inbox
in the same SimPy step as the MemoryStore.write (D9, I6 MUST).
The receiver PE_IPCQ uses ``token.sender_seq`` to update its
peer_head_cache for the corresponding direction.
"""
token: IpcqDmaToken
# ── D12: IpcqCreditMetadata (PE_IPCQ → peer PE_IPCQ, fast path) ─────
@dataclass(frozen=True)
class IpcqCreditMetadata:
"""Credit return — recv-side → send-side fast path (D9).
Sent by ``PeIpcqComponent._delayed_credit_send`` after a
bottleneck-BW based latency, putting the metadata directly into
the peer's pre-wired credit store (no fabric routing).
"""
consumer_seq: int # my_tail at recv side (new tail value)
src_sip: int # which peer is sending the credit
src_cube: int
src_pe: int
src_direction: str # sender-side direction (peer maps to its own)
# ── Request wrapper (PE_CPU → PE_IPCQ) ───────────────────────────────
@dataclass
class IpcqRequest:
"""Wrapper carrying an IpcqSendCmd or IpcqRecvCmd plus a SimPy completion
event. Posted by PE_CPU into PE_IPCQ's inbox; PE_IPCQ calls
``done.succeed()`` when the request is fully processed.
For recv requests, the result (slot address, direction, dtype, shape)
is written into ``result_data`` so the caller can read it after wait.
"""
command: "IpcqSendCmd | IpcqRecvCmd"
done: "simpy.Event"
result_data: dict[str, Any] = field(default_factory=dict)
# ── RecvFuture (kernel ↔ runner handshake for tl.recv_async / tl.wait) ─
@dataclass
class RecvFuture:
"""Opaque future returned by ``tl.recv_async``.
The KernelRunner attaches a SimPy event and the IpcqRequest in the
background; ``tl.wait(future)`` switches back to the runner which
yields on the event and resolves the result into a TensorHandle.
"""
cmd: "IpcqRecvCmd"
request: Any = None # IpcqRequest (set by runner)
event: Any = None # simpy.Event (set by runner)
resolved: bool = False
result: Any = None # cached TensorHandle after wait()
+1
View File
@@ -33,6 +33,7 @@ class TensorHandle:
dtype: str
nbytes: int # total byte size
data: object = None # reserved for validate mode
space: str = "tcm" # MemoryStore space ("tcm" | "hbm" | "sram")
@dataclass(frozen=True)
+25 -2
View File
@@ -42,9 +42,30 @@ class PeCpuComponent(ComponentBase):
self._cube_idx = int(parts[1].replace("cube", ""))
except (IndexError, ValueError):
self._cube_idx = 0
# num_cubes from spec (for tl.program_id(axis=1))
# num_cubes from spec (for tl.program_id(axis=1) — ADR-0022)
spec = ctx.spec if ctx else {}
self._num_cubes = spec.get("system", {}).get("sips", {}).get("cubes_per_sip", 1)
cube_mesh = spec.get("sip", {}).get("cube_mesh", {})
if cube_mesh:
self._num_cubes = int(cube_mesh.get("w", 1)) * int(cube_mesh.get("h", 1))
else:
self._num_cubes = (
spec.get("system", {}).get("sips", {}).get("cubes_per_sip", 1)
)
# PE-local scratch for kernel math output handles (ADR-0020 D3
# extension; reserved portion of TCM addressed via a synthetic
# MemoryStore key, not the real PA encoder).
pe_template = spec.get("cube", {}).get("pe_template", {})
tcm_attrs = pe_template.get("components", {}).get("pe_tcm", {}).get("attrs", {})
scratch_mb = float(tcm_attrs.get("kernel_scratch_mb", 1))
self._tl_scratch_size = int(scratch_mb * (1 << 20))
# PE-unique base address — high bit pattern to avoid collision with
# IPCQ ring buffers (which use bit 60).
self._tl_scratch_base = (
(1 << 61)
| (self._sip_idx << 40)
| (self._cube_idx << 32)
| (self._pe_idx << 24)
)
def _find_shard(self, shards: tuple) -> Any:
"""Find shard matching this PE's (sip, cube, pe). Fallback to positional index."""
@@ -146,6 +167,8 @@ class PeCpuComponent(ComponentBase):
scheduler_id=scheduler_id,
out_ports=self.out_ports,
store=store,
scratch_base=self._tl_scratch_base,
scratch_size=self._tl_scratch_size,
)
yield from runner.run(env, kernel_fn, kernel_args, num_programs)
return getattr(runner, "_composite_results", [])
+116 -3
View File
@@ -106,18 +106,131 @@ class PeDmaComponent(PeEngineBase):
pe_txn.done.succeed()
def _worker(self, env: simpy.Environment) -> Generator:
"""Handle TileToken (pipeline), PeInternalTxn (legacy), and Transaction (fabric)."""
"""Handle TileToken (pipeline), PeInternalTxn (legacy), IpcqDmaToken,
and Transaction (fabric)."""
from kernbench.common.ipcq_types import IpcqDmaToken
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.builtin.pe_types import TileToken
while True:
msg: Any = yield self._inbox.get()
if isinstance(msg, TileToken):
if isinstance(msg, IpcqDmaToken):
# Outbound: IPCQ token from local PE_IPCQ → forward via fabric
env.process(self._handle_ipcq_outbound(env, msg))
elif isinstance(msg, TileToken):
env.process(self._pipeline_process(env, msg))
elif isinstance(msg, PeInternalTxn):
env.process(self._handle_with_hooks(env, msg))
else:
env.process(self._forward_txn(env, msg))
# Transaction (or unknown). May carry IpcqDmaToken inbound.
req = getattr(msg, "request", None)
if isinstance(req, IpcqDmaToken):
env.process(self._handle_ipcq_inbound(env, msg))
else:
env.process(self._forward_txn(env, msg))
# ── IPCQ outbound (PE_IPCQ → PE_DMA → fabric) ───────────────────
def _handle_ipcq_outbound(self, env: simpy.Environment, token: Any) -> Generator:
"""Forward IpcqDmaToken from local PE_IPCQ through the fabric to peer
PE_DMA. ADR-0023 D8 (vc_comm channel)."""
if self.ctx is None:
return # nothing to do
peer = token.dst_endpoint
peer_pe_dma = f"sip{peer.sip}.cube{peer.cube}.pe{peer.pe}.pe_dma"
# Snapshot the source data at send time (D9 in-flight semantics).
# Without this, the receiver could read stale or future data if the
# sender mutates src_addr between send issue and DMA arrival.
store = getattr(self.ctx, "memory_store", None)
if store is not None and token.data is None:
try:
snap = store.read(
token.src_space, token.src_addr,
shape=token.shape, dtype=token.dtype,
)
# Copy so later mutations to src_addr don't affect the snapshot.
token.data = snap.copy() if hasattr(snap, "copy") else snap
except Exception:
token.data = None
# Record the IPCQ copy in op_log at OUTBOUND time. ADR-0020 D6:
# Phase 2 replays the copy in t_start order; using outbound time
# (rather than inbound) ensures the copy executes before any later
# local op at the sender that might overwrite token.src_addr (e.g.
# a tl.store after a recv).
if self._op_logger is not None:
try:
self._op_logger.record_copy(
t_start=float(env.now), t_end=float(env.now),
component_id=self.node.id,
src_space=token.src_space, src_addr=token.src_addr,
dst_space=peer.buffer_kind,
dst_addr=token.dst_addr,
shape=token.shape, dtype=token.dtype, nbytes=token.nbytes,
)
except Exception:
pass
try:
path = self.ctx.router.find_path(self._pe_prefix, peer_pe_dma)
except Exception:
return
drain_ns = self.ctx.compute_drain_ns(path, token.nbytes)
sub_done = env.event()
sub_txn = Transaction(
request=token, path=path, step=0,
nbytes=token.nbytes, done=sub_done, drain_ns=drain_ns,
)
if len(path) > 1:
next_hop = path[1]
if next_hop in self.out_ports:
yield self.out_ports[next_hop].put(sub_txn.advance())
else:
return
# Note: don't wait on sub_done here — fire-and-forget for vc_comm.
# IPCQ slot bookkeeping (peer_head) was already updated by PE_IPCQ;
# backpressure is via credit return, not via this DMA's completion.
# ── IPCQ inbound (fabric → PE_DMA → MemoryStore + PE_IPCQ) ──────
def _handle_ipcq_inbound(self, env: simpy.Environment, txn: Any) -> Generator:
"""At destination PE_DMA: atomically write data and forward metadata.
I6 (MUST): no SimPy yield between MemoryStore.write and the
IpcqMetaArrival put into PE_IPCQ.
"""
from kernbench.common.ipcq_types import IpcqMetaArrival
token = txn.request
# ── ATOMIC: do not introduce yield between these two operations ──
# 1. Move data via MemoryStore (single-hop DMA write).
# Prefer the in-flight snapshot stashed by the sender PE_DMA;
# fall back to a fresh read of src_addr if no snapshot is present
# (e.g. control-only token).
store = getattr(self.ctx, "memory_store", None) if self.ctx else None
if store is not None:
try:
data = token.data
if data is None:
data = store.read(
token.src_space, token.src_addr,
shape=token.shape, dtype=token.dtype,
)
store.write(token.dst_endpoint.buffer_kind, token.dst_addr, data)
except Exception:
pass
# 2. Forward IpcqMetaArrival to local PE_IPCQ
ipcq_id = f"{self._pe_prefix}.pe_ipcq"
if ipcq_id in self.out_ports:
yield self.out_ports[ipcq_id].put(IpcqMetaArrival(token=token))
# ─────────────────────────────────────────────────────────────────
if not txn.done.triggered:
txn.done.succeed()
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
"""Pipeline mode: DMA read/write via fabric, then self-route."""
+455
View File
@@ -0,0 +1,455 @@
"""PE_IPCQ component (ADR-0023): per-PE IPCQ control plane.
Responsibilities:
- Hold per-direction queue pair state (my_head, my_tail,
peer_head_cache, peer_tail_cache, ring buffer addresses)
- Process IpcqInitMsg from backend to install neighbor table
- Handle IpcqRequest(IpcqSendCmd) from PE_CPU:
compute peer slot address, check backpressure, forward
IpcqDmaToken to PE_DMA (vc_comm)
- Handle IpcqRequest(IpcqRecvCmd) from PE_CPU:
wait for data arrival, return slot address (or copy to dst),
send fast-path credit return
- Handle IpcqMetaArrival from PE_DMA: update peer_head_cache, wake recv
- Handle IpcqCreditMetadata via own credit_inbox: update peer_tail_cache,
wake send
PE_IPCQ does NOT move data — it forwards IpcqDmaToken to PE_DMA which
performs the actual fabric DMA.
Credit return uses a fast path: PE_IPCQ creates a SimPy process with a
bottleneck-BW based latency, then puts IpcqCreditMetadata directly into
the peer's pre-wired credit_store.
"""
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.common.ipcq_types import (
IpcqCreditMetadata,
IpcqDmaToken,
IpcqInvalidDirection,
IpcqMetaArrival,
IpcqRecvCmd,
IpcqRequest,
IpcqSendCmd,
)
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.runtime_api.kernel import IpcqInitMsg
from kernbench.topology.types import Node
_DIR_ORDER: tuple[str, ...] = ("N", "S", "E", "W", "parent", "child_left", "child_right")
class PeIpcqComponent(ComponentBase):
"""PE_IPCQ: ring buffer pointer + neighbor management for CCL.
Owned by one PE; talks to PE_DMA via out_ports[<pe_dma_id>] and
receives credit return metadata via the public ``credit_inbox``
SimPy Store (wired by backend at IpcqInitMsg installation time).
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
# Strict shape/dtype validation (D14 F2). Off by default.
self._strict: bool = bool(node.attrs.get("strict_validation", False))
# direction → list of received tokens (for strict-mode peek of next slot)
self._arrived_tokens: dict[str, list] = {}
# Parse self (sip, cube, pe) from node id, e.g. "sip0.cube0.pe0.pe_ipcq"
self._pe_prefix: str = node.id.rsplit(".", 1)[0]
parts = self._pe_prefix.split(".")
try:
self._self_sip = int(parts[0].replace("sip", ""))
except (IndexError, ValueError):
self._self_sip = 0
try:
self._self_cube = int(parts[1].replace("cube", ""))
except (IndexError, ValueError):
self._self_cube = 0
try:
self._self_pe = int(parts[2].replace("pe", ""))
except (IndexError, ValueError):
self._self_pe = 0
self._dma_node_id = f"{self._pe_prefix}.pe_dma"
# direction → state dict (see _install_neighbors for shape)
self._queue_pairs: dict[str, dict[str, Any]] = {}
self._installed = False
self._buffer_kind: str = "tcm"
self._backpressure_mode: str = "sleep"
self._credit_size_bytes: int = 16
# waiters for recv (per direction) and any-direction (for round-robin)
self._recv_waiters: dict[str, list[simpy.Event]] = {}
self._any_recv_waiters: list[simpy.Event] = []
# waiters for send backpressure (per direction)
self._send_waiters: dict[str, list[simpy.Event]] = {}
# round-robin cursor over installed directions
self._rr_dirs: list[str] = []
self._rr_cursor: int = 0
# credit_inbox is created in start() once env is available
self._credit_inbox: simpy.Store | None = None
# ── Public ──
@property
def credit_inbox(self) -> simpy.Store:
"""SimPy Store that backend wires as ``peer_credit_store`` on
every remote sender targeting this PE. Used by D9 fast path."""
assert self._credit_inbox is not None, "PE_IPCQ not started yet"
return self._credit_inbox
@property
def queue_pairs(self) -> dict[str, dict[str, Any]]:
"""Test/debug accessor."""
return self._queue_pairs
# ── Lifecycle ──
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
yield env.timeout(0)
def start(self, env: simpy.Environment) -> None:
# Create credit_inbox even if there are no in_ports yet
if self._credit_inbox is None:
self._credit_inbox = simpy.Store(env)
# If no in_ports were wired (e.g. unit test), still spin up workers
if not self.in_ports:
self._inbox = simpy.Store(env)
super().start(env)
env.process(self._credit_worker(env))
# ── Worker (override of ComponentBase._worker) ──
def _worker(self, env: simpy.Environment) -> Generator:
from kernbench.runtime_api.kernel import IpcqInitMsg
while True:
msg: Any = yield self._inbox.get()
# IpcqInitMsg may arrive wrapped in a transaction (with .request)
# or directly.
request_obj = getattr(msg, "request", None)
if isinstance(request_obj, IpcqInitMsg):
self._install_neighbors(request_obj)
done = getattr(msg, "done", None)
if done is not None and not done.triggered:
done.succeed()
continue
if isinstance(msg, IpcqInitMsg):
self._install_neighbors(msg)
continue
if isinstance(msg, IpcqMetaArrival):
self._handle_meta_arrival(msg)
continue
if isinstance(msg, IpcqRequest):
env.process(self._handle_request(env, msg))
continue
# Unknown message — drop or forward via base class fallback
env.process(self._forward_txn(env, msg))
# ── Init ──
def _install_neighbors(self, msg: IpcqInitMsg) -> None:
self._installed = True
self._buffer_kind = msg.buffer_kind
self._backpressure_mode = msg.backpressure_mode
self._credit_size_bytes = msg.credit_size_bytes
for entry in msg.entries:
self._queue_pairs[entry.direction] = {
"peer": entry.peer,
"my_rx_base_pa": entry.my_rx_base_pa,
"my_rx_base_va": entry.my_rx_base_va,
"n_slots": entry.n_slots,
"slot_size": entry.slot_size,
"peer_credit_store": entry.peer_credit_store,
"my_head": 0,
"my_tail": 0,
"peer_head_cache": 0,
"peer_tail_cache": 0,
}
self._recv_waiters.setdefault(entry.direction, [])
self._send_waiters.setdefault(entry.direction, [])
# Reset round-robin order to a stable canonical sequence
self._rr_dirs = [d for d in _DIR_ORDER if d in self._queue_pairs]
self._rr_cursor = 0
# ── Send ──
def _handle_request(self, env: simpy.Environment, req: IpcqRequest) -> Generator:
cmd = req.command
if isinstance(cmd, IpcqSendCmd):
yield from self._handle_send(env, req, cmd)
elif isinstance(cmd, IpcqRecvCmd):
yield from self._handle_recv(env, req, cmd)
def _handle_send(
self, env: simpy.Environment, req: IpcqRequest, cmd: IpcqSendCmd,
) -> Generator:
if cmd.direction not in self._queue_pairs:
raise IpcqInvalidDirection(
f"PE {self._pe_prefix}: direction {cmd.direction!r} not installed"
)
qp = self._queue_pairs[cmd.direction]
peer = qp["peer"]
# Backpressure: wait while ring full
while (qp["my_head"] - qp["peer_tail_cache"]) >= peer.n_slots:
wait_event = env.event()
self._send_waiters[cmd.direction].append(wait_event)
yield wait_event
# Compute peer slot address
slot_idx = qp["my_head"] % peer.n_slots
dst_pa = peer.rx_base_pa + slot_idx * peer.slot_size
token = IpcqDmaToken(
src_addr=cmd.src_addr,
src_space=cmd.src_space,
dst_addr=dst_pa,
dst_endpoint=peer,
nbytes=cmd.nbytes,
handle_id=cmd.handle_id,
shape=cmd.shape,
dtype=cmd.dtype,
sender_seq=qp["my_head"],
src_sip=self._self_sip,
src_cube=self._self_cube,
src_pe=self._self_pe,
src_direction=cmd.direction,
)
# Forward to PE_DMA (vc_comm)
yield self.out_ports[self._dma_node_id].put(token)
qp["my_head"] += 1
# Diagnostics trace (D14)
from kernbench.ccl import diagnostics
if diagnostics.trace_enabled():
diagnostics.log_send(
t_ns=float(env.now), sender=self._pe_prefix,
direction=cmd.direction, nbytes=cmd.nbytes,
sender_seq=qp["my_head"] - 1,
)
if not req.done.triggered:
req.done.succeed()
# ── Recv ──
def _handle_recv(
self, env: simpy.Environment, req: IpcqRequest, cmd: IpcqRecvCmd,
) -> Generator:
if cmd.direction is None:
direction = yield from self._wait_any_direction(env)
else:
if cmd.direction not in self._queue_pairs:
raise IpcqInvalidDirection(
f"PE {self._pe_prefix}: direction {cmd.direction!r} not installed"
)
direction = cmd.direction
qp = self._queue_pairs[direction]
while qp["peer_head_cache"] <= qp["my_tail"]:
wait_event = env.event()
self._recv_waiters[direction].append(wait_event)
yield wait_event
qp = self._queue_pairs[direction]
slot_idx = qp["my_tail"] % qp["n_slots"]
slot_addr = qp["my_rx_base_pa"] + slot_idx * qp["slot_size"]
# Strict validation (D14 F2): peek the next-arrived token's metadata
# against the recv command's expected shape/dtype/nbytes.
arrived = self._arrived_tokens.get(direction, [])
if arrived:
front = arrived.pop(0)
if self._strict:
expected_nbytes = self._nbytes_for(cmd.shape, cmd.dtype)
if front.dtype != cmd.dtype:
raise ValueError(
f"PE_IPCQ {self._pe_prefix} recv strict: dtype mismatch — "
f"sender={front.dtype} recv={cmd.dtype}"
)
if front.shape != cmd.shape:
raise ValueError(
f"PE_IPCQ {self._pe_prefix} recv strict: shape mismatch — "
f"sender={front.shape} recv={cmd.shape}"
)
if front.nbytes != expected_nbytes:
raise ValueError(
f"PE_IPCQ {self._pe_prefix} recv strict: nbytes mismatch — "
f"sender={front.nbytes} recv={expected_nbytes}"
)
req.result_data["src_space"] = self._buffer_kind
req.result_data["src_addr"] = slot_addr
req.result_data["direction"] = direction
req.result_data["dtype"] = cmd.dtype
req.result_data["shape"] = cmd.shape
req.result_data["nbytes"] = self._nbytes_for(cmd.shape, cmd.dtype)
# copy_to_dst mode: rebind the result handle to (dst_space, dst_addr).
# When op_log is disabled, we also do the actual data move now;
# when op_log is enabled, Phase 2 replays the slot→dst copy from
# the op_log entry below so we don't pollute the slot in Phase 1.
if cmd.recv_mode == "copy_to_dst" and self.ctx is not None:
req.result_data["src_space"] = cmd.dst_space
req.result_data["src_addr"] = cmd.dst_addr
store = getattr(self.ctx, "memory_store", None)
if store is not None and self._op_logger is None:
try:
data = store.read(self._buffer_kind, slot_addr, shape=cmd.shape, dtype=cmd.dtype)
store.write(cmd.dst_space, cmd.dst_addr, data)
except Exception:
pass
if self._op_logger is not None:
# Record slot → dst copy for Phase 2 replay (ADR-0023 D9.5).
try:
self._op_logger.record_copy(
t_start=float(env.now), t_end=float(env.now),
component_id=self.node.id,
src_space=self._buffer_kind, src_addr=slot_addr,
dst_space=cmd.dst_space, dst_addr=cmd.dst_addr,
shape=cmd.shape, dtype=cmd.dtype,
nbytes=self._nbytes_for(cmd.shape, cmd.dtype),
)
except Exception:
pass
qp["my_tail"] += 1
# Diagnostics trace (D14)
from kernbench.ccl import diagnostics
if diagnostics.trace_enabled():
diagnostics.log_recv(
t_ns=float(env.now), receiver=self._pe_prefix,
direction=direction,
nbytes=req.result_data.get("nbytes", 0),
)
# Fast path credit return — bottleneck BW based latency
env.process(
self._delayed_credit_send(env, direction, qp["peer_credit_store"], qp["my_tail"])
)
if not req.done.triggered:
req.done.succeed()
def _wait_any_direction(self, env: simpy.Environment) -> Generator:
"""Round-robin scan over installed directions; wait until at least one
has data. Returns the chosen direction (str)."""
if not self._rr_dirs:
raise IpcqInvalidDirection(
f"PE {self._pe_prefix}: no neighbors installed"
)
while True:
n = len(self._rr_dirs)
for i in range(n):
idx = (self._rr_cursor + i) % n
d = self._rr_dirs[idx]
qp = self._queue_pairs[d]
if qp["peer_head_cache"] > qp["my_tail"]:
self._rr_cursor = (idx + 1) % n
return d
# Nothing available — wait until any arrival
wait_event = env.event()
self._any_recv_waiters.append(wait_event)
yield wait_event
# ── Metadata arrival from PE_DMA (D9) ──
def _handle_meta_arrival(self, msg: IpcqMetaArrival) -> None:
token = msg.token
sender_key = (token.src_sip, token.src_cube, token.src_pe)
for d, qp in self._queue_pairs.items():
p = qp["peer"]
if (p.sip, p.cube, p.pe) == sender_key:
qp["peer_head_cache"] = max(qp["peer_head_cache"], token.sender_seq + 1)
# Track arrived token for strict-mode peek
self._arrived_tokens.setdefault(d, []).append(token)
# Wake any blocked recv on this direction
waiters = self._recv_waiters.get(d, [])
self._recv_waiters[d] = []
for ev in waiters:
if not ev.triggered:
ev.succeed()
# Wake any-direction waiters
any_waiters = self._any_recv_waiters
self._any_recv_waiters = []
for ev in any_waiters:
if not ev.triggered:
ev.succeed()
return
# Unknown sender — silently drop (could log)
# ── Credit return (fast path) ──
def _credit_worker(self, env: simpy.Environment) -> Generator:
"""Process IpcqCreditMetadata from credit_inbox."""
assert self._credit_inbox is not None
while True:
credit: IpcqCreditMetadata = yield self._credit_inbox.get()
sender_key = (credit.src_sip, credit.src_cube, credit.src_pe)
for d, qp in self._queue_pairs.items():
p = qp["peer"]
if (p.sip, p.cube, p.pe) == sender_key:
qp["peer_tail_cache"] = max(qp["peer_tail_cache"], credit.consumer_seq)
# Wake any blocked send on this direction
waiters = self._send_waiters.get(d, [])
self._send_waiters[d] = []
for ev in waiters:
if not ev.triggered:
ev.succeed()
break
def _delayed_credit_send(
self,
env: simpy.Environment,
direction: str,
peer_credit_store: simpy.Store,
new_tail: int,
) -> Generator:
"""Wait bottleneck-BW latency, then put IpcqCreditMetadata into peer
credit store (D9 fast path)."""
latency_ns = self._credit_latency_ns(direction)
if latency_ns > 0:
yield env.timeout(latency_ns)
meta = IpcqCreditMetadata(
consumer_seq=new_tail,
src_sip=self._self_sip,
src_cube=self._self_cube,
src_pe=self._self_pe,
src_direction=direction,
)
yield peer_credit_store.put(meta)
def _credit_latency_ns(self, direction: str) -> float:
"""Compute credit fast path latency = credit_size / bottleneck_bw.
Falls back to 0 when ctx/router is unavailable (unit-test mode).
"""
if self.ctx is None:
return 0.0
qp = self._queue_pairs[direction]
peer = qp["peer"]
peer_pe_prefix = f"sip{peer.sip}.cube{peer.cube}.pe{peer.pe}"
try:
path = self.ctx.router.find_path(self._pe_prefix, peer_pe_prefix)
return self.ctx.compute_drain_ns(path, self._credit_size_bytes)
except Exception:
return 0.0
# ── Helpers ──
@staticmethod
def _nbytes_for(shape: tuple[int, ...], dtype: str) -> int:
from math import prod
bits = {"f16": 16, "bf16": 16, "f32": 32, "i8": 8, "i16": 16, "i32": 32}.get(dtype, 16)
return prod(shape) * (bits // 8) if shape else 0
+2 -4
View File
@@ -29,11 +29,10 @@ def run_bench(
correlation_id: str = "bench0",
completion_policy: CompletionPolicy = CompletionPolicy.LAST_SUBMITTED,
) -> BenchResult:
"""
Minimal bench runner.
"""Minimal bench runner.
- topology: compiled topology object (opaque to runtime here)
- bench_fn: callable that receives RuntimeContext and submits requests
- bench_fn: callable ``run(torch)`` receiving a RuntimeContext
- device: DeviceSelector ("all" or "sip:<N>")
- engine_factory: builds sim_engine for given topology & device
- completion_policy: how to determine overall completion/result
@@ -48,7 +47,6 @@ def run_bench(
)
bench_fn(ctx)
ctx.wait_all()
collected_traces = ctx._traces or None
+125 -7
View File
@@ -9,6 +9,39 @@ from kernbench.common.types import Completion, RequestHandle, SimEngine
from .types import DeviceSelector
def _world_size_from_spec(spec: dict | None) -> int:
"""Derive world_size from topology spec: sips × cubes × pes_per_cube."""
spec = spec or {}
sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
cm = spec.get("sip", {}).get("cube_mesh", {})
cubes_per_sip = int(cm.get("w", 1)) * int(cm.get("h", 1))
pl = spec.get("cube", {}).get("pe_layout", {})
corners = pl.get("corners", [])
pe_per_corner = int(pl.get("pe_per_corner", 1))
pes_per_cube = pe_per_corner * max(len(corners), 1)
return sips * cubes_per_sip * pes_per_cube
def _numpy_to_dtype_str(np_dtype) -> str:
"""Map numpy dtype → kernbench dtype string used by Tensor."""
import numpy as np
kind_map = {
np.float16: "f16",
np.float32: "f32",
np.int8: "i8",
np.int16: "i16",
np.int32: "i32",
np.uint8: "u8",
np.uint16: "u16",
np.uint32: "u32",
}
for np_type, s in kind_map.items():
if np.dtype(np_dtype) == np.dtype(np_type):
return s
raise ValueError(f"unsupported numpy dtype: {np_dtype!r}")
@dataclass
class RuntimeContext:
engine: SimEngine
@@ -23,6 +56,66 @@ class RuntimeContext:
_tensor_counter: int = field(default=0, init=False)
_traces: list[dict] = field(default_factory=list, init=False)
_tensors: list[Any] = field(default_factory=list, init=False)
distributed: Any = field(default=None, init=False) # DistributedContext for CCL benches
_ipcq_plan: dict = field(default_factory=dict, init=False) # ADR-0023 install plan
def __post_init__(self) -> None:
# Eagerly attach a DistributedContext so bench code can do
# ``dist = torch.distributed`` + ``dist.init_process_group(...)``
# without needing a separate launcher to install it.
from kernbench.runtime_api.distributed import DistributedContext
dc = DistributedContext()
dc._ctx_ref = self # back-reference for AhbmCCLBackend to reach ctx.launch etc.
self.distributed = dc
def install_ipcq(
self,
algorithm: str | None = None,
ccl_yaml: str | None = None,
world_size_override: int | None = None,
rank_to_pe: list[tuple[int, int, int]] | None = None,
) -> dict:
"""Install IPCQ neighbor tables on all participating PEs (ADR-0023 D10).
Loads ``ccl.yaml`` (or the path provided), resolves the chosen
algorithm (or ``defaults.algorithm`` if None), and pushes per-PE
IpcqInitMsg into every PE_IPCQ component via the engine.
Args:
algorithm: name of the algorithm in ccl.yaml (or use defaults).
ccl_yaml: optional path to ccl.yaml.
world_size_override: if set, replace the algorithm's world_size.
Returns the install plan dict (rank → (sip,cube,pe), neighbor table).
"""
import importlib
from kernbench.ccl.install import (
install_ipcq as _install,
load_ccl_config,
resolve_algorithm_config,
)
cfg = load_ccl_config(ccl_yaml)
merged = resolve_algorithm_config(cfg, algorithm)
if world_size_override is not None:
merged["world_size"] = world_size_override
elif "world_size" not in merged:
# Derive from topology.yaml when neither the algorithm entry
# nor ``defaults`` carries ``world_size`` (matches pytorch DDP
# where env vars determine ranks, not the ccl config file).
merged["world_size"] = _world_size_from_spec(self.spec)
algo_module = None
try:
algo_module = importlib.import_module(merged["module"])
except ModuleNotFoundError:
pass
plan = _install(
self.engine, self.spec, merged,
algo_module=algo_module, rank_to_pe=rank_to_pe,
)
self._ipcq_plan = plan
self._ipcq_config = merged
return plan
def __enter__(self):
return self
@@ -258,6 +351,24 @@ class RuntimeContext:
"""Allocate a tensor in HBM without initialization (like torch.empty)."""
return self._create_tensor(shape, dtype, name, pattern=None, dp=dp)
def from_numpy(self, arr: Any):
"""Create a host-side tensor wrapping a numpy array.
Mirrors ``torch.from_numpy``. The returned tensor is NOT deployed
to any PE — it lives in an in-memory host staging buffer. Use
``target.copy_(host_tensor)`` to scatter its contents into a
sharded, deployed tensor.
"""
import numpy as np
from kernbench.runtime_api.tensor import Tensor
arr_c = np.ascontiguousarray(arr)
dtype_str = _numpy_to_dtype_str(arr_c.dtype)
t = Tensor(shape=tuple(arr_c.shape), dtype=dtype_str, name="host")
t._host_buffer = arr_c
t._memory_store = getattr(self.engine, "_memory_store", None)
return t
def _create_tensor(
self,
shape: tuple[int, ...],
@@ -418,13 +529,12 @@ class RuntimeContext:
TensorArgShard,
)
from kernbench.runtime_api.tensor import Tensor
from kernbench.triton_emu.registry import register_kernel
from kernbench.triton_emu.registry import _kernels, register_kernel
# Register kernel (idempotent)
try:
register_kernel(kernel_name, kernel_fn)
except ValueError:
pass
# Register kernel (idempotent overwrite — last call wins).
# Tests can re-register the same kernel_name with a different
# function; the user's most recent launch must use the latest fn.
_kernels[kernel_name] = kernel_fn
# Collect tensors and scalars
tensor_args: list[Tensor] = []
@@ -506,6 +616,7 @@ class RuntimeContext:
# Per-SIP kernel launch: each SIP gets TensorArgs with local va_base
last_handle = None
_pending_handles: list[tuple[Any, int]] = []
for sip_id in sorted(sip_set):
sip_kernel_args: list = []
sip_cube_set: set[int] = set()
@@ -566,10 +677,17 @@ class RuntimeContext:
target_cubes=target_cubes,
target_pe=target_pe,
))
# Defer wait until all SIPs are submitted (multi-SIP CCL needs
# all participating PEs to be live concurrently — waiting
# per-SIP would deadlock when ranks span SIP boundaries).
_pending_handles.append((h, sip_id))
last_handle = h
# Drain pending handles now that every SIP has a launch posted.
for h, sip_id in _pending_handles:
self.wait(h, _meta={
"phase": "kernel", "name": kernel_name,
"sip": sip_id, "target_pe": target_pe,
})
last_handle = h
return last_handle
+179
View File
@@ -0,0 +1,179 @@
"""PyTorch-compatible distributed communication shim (ADR-0023 D11).
Provides a ``torch.distributed``-like API whose public surface matches
real PyTorch so that bench code looks identical to a DDP training script.
Only the ``ahbm`` backend is implemented. It:
1. Reads ``ccl.yaml`` to decide which collective algorithm to run.
2. Derives world_size from the algorithm entry, the defaults section, or
from the topology spec (``system.sips.count × sip.cube_mesh × pe_layout``).
3. At ``init_process_group`` time, eagerly installs the IPCQ neighbor
table once (one-time comm setup — mirrors NCCL communicator creation).
4. On each ``all_reduce(tensor)`` call, reads per-shard metadata from the
tensor handle and dispatches ``torch.launch`` with the registered
kernel. The kernel performs intra-PE ring/tree/mesh CCL via IPCQ,
and Phase 2 DataExecutor replays math + copies from op_log so
MemoryStore is correct when ``all_reduce`` returns.
Host bench code uses only real-PyTorch names:
dist.init_process_group, dist.is_initialized, dist.get_world_size,
dist.get_rank, dist.get_backend, dist.all_reduce, dist.barrier
"""
from __future__ import annotations
import importlib
from typing import Any
class AhbmCCLBackend:
"""Ahbm CCL backend — drives kernel-level collectives via IPCQ."""
def __init__(self, torch_ctx: Any) -> None:
from kernbench.ccl.install import (
load_ccl_config,
resolve_algorithm_config,
)
self.ctx = torch_ctx
self._cfg_all = load_ccl_config()
self._merged = resolve_algorithm_config(self._cfg_all)
self._algo_module = importlib.import_module(self._merged["module"])
self._world_size = self._resolve_world_size()
# Eager IPCQ install — ``init_process_group`` time. Mirrors NCCL
# communicator creation: done once, reused across every subsequent
# collective call on the same process group.
self.ctx.install_ipcq(
algorithm=self._merged["algorithm"],
world_size_override=self._world_size,
)
def _resolve_world_size(self) -> int:
"""Derive world_size (priority: algorithm override > defaults > topology).
Topology derivation:
sips × cubes_per_sip × pes_per_cube
"""
if "world_size" in self._merged:
return int(self._merged["world_size"])
defaults = self._cfg_all.get("defaults", {})
if "world_size" in defaults:
return int(defaults["world_size"])
spec = self.ctx.spec or {}
sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
cm = spec.get("sip", {}).get("cube_mesh", {})
cubes_per_sip = int(cm.get("w", 1)) * int(cm.get("h", 1))
pl = spec.get("cube", {}).get("pe_layout", {})
corners = pl.get("corners", [])
pe_per_corner = int(pl.get("pe_per_corner", 1))
pes_per_cube = pe_per_corner * max(len(corners), 1)
return sips * cubes_per_sip * pes_per_cube
@property
def world_size(self) -> int:
return self._world_size
def all_reduce(self, tensor: Any, op: str = "sum") -> None:
"""Dispatch the configured CCL algorithm as a single kernel launch.
Raises if ``op != "sum"`` (current kernels only implement add
reduction) or if the tensor's shard count disagrees with the
world_size that was installed into PE_IPCQ.
"""
if op != "sum":
raise NotImplementedError(f"all_reduce op={op!r} not supported")
if tensor._handle is None:
raise RuntimeError(
f"Tensor '{tensor.name}' is not deployed (call torch.zeros "
"with a DPPolicy first)"
)
shards = tensor._handle.shards
if len(shards) != self._world_size:
raise RuntimeError(
f"all_reduce tensor has {len(shards)} shards but the "
f"ahbm backend was installed with world_size="
f"{self._world_size}; adjust the tensor's DPPolicy or "
"restart the process group"
)
n_elem = shards[0].nbytes // tensor.itemsize
kernel_fn = self._algo_module.kernel
kernel_args = self._algo_module.kernel_args(self._world_size, n_elem)
self.ctx.launch(
self._merged["algorithm"], kernel_fn, tensor, *kernel_args,
)
def barrier(self) -> None:
# Single-driver model → no cross-process sync needed. Keeping the
# method so ``dist.barrier()`` is callable (pytorch-compat surface).
return None
class DistributedContext:
"""torch.distributed-compat facade.
Public surface matches real PyTorch so bench code reads identically
to a DDP training script. Single-driver semantics: ``get_rank()``
always returns 0 because kernbench runs as one Python process;
``get_world_size()`` returns the CCL group size (number of PEs
participating in the collective).
"""
def __init__(self) -> None:
self._backend: AhbmCCLBackend | None = None
def init_process_group(
self,
backend: str = "ahbm",
world_size: int | None = None,
rank: int | None = None,
**kwargs: Any,
) -> None:
"""Create the default process group.
``world_size`` and ``rank`` are accepted for API parity with
``torch.distributed.init_process_group`` but ignored — the ahbm
backend derives both from ``ccl.yaml`` + topology automatically
(like reading ``RANK``/``WORLD_SIZE`` env vars in real DDP).
"""
if backend != "ahbm":
raise ValueError(
f"Unsupported backend '{backend}'. Only 'ahbm' is supported."
)
ctx = getattr(self, "_ctx_ref", None)
if ctx is None:
raise RuntimeError(
"DistributedContext not bound to a RuntimeContext"
)
self._backend = AhbmCCLBackend(torch_ctx=ctx)
def is_initialized(self) -> bool:
return self._backend is not None
def get_world_size(self) -> int:
self._ensure_initialized()
return self._backend.world_size
def get_rank(self) -> int:
# Single-driver kernbench: there is only one host rank.
self._ensure_initialized()
return 0
def get_backend(self) -> str:
self._ensure_initialized()
return "ahbm"
def all_reduce(self, tensor: Any, op: str = "sum") -> None:
self._ensure_initialized()
self._backend.all_reduce(tensor, op=op)
def barrier(self) -> None:
self._ensure_initialized()
self._backend.barrier()
def _ensure_initialized(self) -> None:
if self._backend is None:
raise RuntimeError(
"Default process group has not been initialized. "
"Call init_process_group(backend='ahbm') first."
)
+27
View File
@@ -152,3 +152,30 @@ class MmuUnmapMsg:
target_cubes: tuple[int, ...] | Literal["all"] = "all"
target_pe: int | Literal["all"] = "all"
msg_type: Literal["mmu_unmap"] = "mmu_unmap"
@dataclass(frozen=True)
class IpcqInitMsg:
"""IPCQ neighbor table install (sideband fan-out, ADR-0023 D10/D12).
Backend issues this at ``init_process_group`` time to install per-PE
IPCQ neighbor tables. Each entry covers one direction (N/S/E/W) and
carries the peer's IpcqEndpoint plus this PE's own rx_buffer base
and a pre-wired SimPy Store for credit return fast path (D9).
Routing is similar to MmuMapMsg.
"""
correlation_id: str
request_id: str
target_sips: tuple[int, ...] | Literal["all"] = "all"
target_cubes: tuple[int, ...] | Literal["all"] = "all"
target_pe: int | tuple[int, ...] | Literal["all"] = "all"
# entries: tuple[IpcqInitEntry, ...] — kept as tuple of plain objects to
# avoid a runtime import cycle (IpcqInitEntry lives in
# kernbench.common.ipcq_types).
entries: tuple = ()
backpressure_mode: str = "sleep" # "poll" | "sleep"
buffer_kind: str = "tcm" # "tcm" | "hbm" | "sram"
credit_size_bytes: int = 16
msg_type: Literal["ipcq_init"] = "ipcq_init"
+82 -7
View File
@@ -146,6 +146,11 @@ class Tensor:
self._handle: TensorHandle | None = None
self._ctx_ref: weakref.ref | None = None # set by RuntimeContext
self._memory_store = None # set by RuntimeContext when enable_data=True
# Host-side staging buffer for torch.from_numpy() results. A tensor
# with a non-None _host_buffer is NOT deployed to any PE — it lives
# only on the host. Use `target.copy_(host_tensor)` to scatter the
# data into a deployed, sharded target tensor.
self._host_buffer: np.ndarray | None = None
def __del__(self) -> None:
if self._ctx_ref is None or self._handle is None:
@@ -166,15 +171,85 @@ class Tensor:
@property
def data(self) -> np.ndarray:
"""Tensor data as numpy array. Returns actual values when enable_data=True,
zeros placeholder otherwise (like an uninitialized tensor)."""
if self._memory_store is not None and self._handle is not None:
shard = self._handle.shards[0]
"""Tensor data as numpy array.
Gathers all shards into a single full-shape array. Returns actual
values when enable_data=True, zeros placeholder otherwise (like an
uninitialized tensor). Alias of ``numpy()``.
"""
return self.numpy()
def _shard_store_addr(self, shard: TensorShard) -> int:
"""MemoryStore key for a shard.
Kernels read tensors via VA (translated to PA by PE_DMA's MMU when
a mapping exists, otherwise the addr is treated as a PA-equivalent
key). Tensor I/O therefore writes/reads at ``va_base + offset_bytes``
when ``va_base`` is set, falling back to ``shard.pa`` for the
VA-less mode used by some legacy paths.
"""
if self._handle and self._handle.va_base:
return self._handle.va_base + shard.offset_bytes
return shard.pa
def numpy(self) -> np.ndarray:
"""Return a single numpy array gathered from all shards.
Mirrors ``torch.Tensor.numpy()``. In kernbench, sharded tensors are
gathered into a single full-shape ndarray according to each shard's
``offset_bytes`` / ``nbytes`` range.
"""
np_dtype = _numpy_dtype(self.dtype)
# Host-side tensor (created via torch.from_numpy) has no shards.
if self._host_buffer is not None:
return self._host_buffer.copy()
if self._handle is None or self._memory_store is None:
return np.zeros(self.shape, dtype=np_dtype)
flat = np.zeros(math.prod(self.shape), dtype=np_dtype)
for shard in self._handle.shards:
start = shard.offset_bytes // self.itemsize
count = shard.nbytes // self.itemsize
try:
return self._memory_store.read("hbm", shard.pa, shape=self.shape, dtype=self.dtype)
piece = self._memory_store.read(
"hbm", self._shard_store_addr(shard),
)
except KeyError:
pass
return np.zeros(self.shape, dtype=_numpy_dtype(self.dtype))
continue
flat[start : start + count] = (
np.asarray(piece, dtype=np_dtype).reshape(-1)[:count]
)
return flat.reshape(self.shape)
def copy_(self, source: "Tensor") -> "Tensor":
"""In-place copy from another tensor into self.
Mirrors ``torch.Tensor.copy_()``. If ``source`` is a host tensor
(from ``torch.from_numpy``), its ndarray is split across self's
shards using each shard's byte range. If ``source`` is a deployed
(sharded) tensor, its contents are gathered first and then
re-scattered into self's shard layout.
Shapes must match. Returns self.
"""
if self._handle is None or self._memory_store is None:
raise RuntimeError(
f"Tensor '{self.name}' must be deployed before copy_()"
)
if source.shape != self.shape:
raise ValueError(
f"copy_ shape mismatch: self={self.shape} source={source.shape}"
)
np_dtype = _numpy_dtype(self.dtype)
arr = source.numpy().astype(np_dtype, copy=False)
flat = np.ascontiguousarray(arr).reshape(-1)
for shard in self._handle.shards:
start = shard.offset_bytes // self.itemsize
count = shard.nbytes // self.itemsize
piece = flat[start : start + count].copy()
self._memory_store.write(
"hbm", self._shard_store_addr(shard), piece,
)
return self
@property
def itemsize(self) -> int:
+76 -7
View File
@@ -51,7 +51,42 @@ class DataExecutor:
self._execute_math(op)
def _execute_memory(self, op: OpRecord) -> None:
"""Memory ops are already handled by Phase 1 MemoryStore. Skip."""
"""Replay memory copy ops in Phase 2 (ADR-0020 + ADR-0023).
- dma_read: no-op (handle already references HBM source).
- dma_write: copy (src_space, src_addr) → (dst_space, dst_addr).
Required because Phase 2 may have just produced new data at the
source addr (e.g. PE_MATH scratch output).
- ipcq_copy: copy across PEs — sender's source → receiver's slot.
Required because the source may be a Phase 2 math output, and
a downstream math op on the receiver reads from the slot.
Legacy entries without src/dst metadata are silently skipped.
"""
p = op.params
if op.op_name == "dma_write" or op.op_name == "ipcq_copy":
src_space = p.get("src_space")
src_addr = p.get("src_addr")
dst_space = p.get("dst_space")
dst_addr = p.get("dst_addr")
if (src_space is None or src_addr is None
or dst_space is None or dst_addr is None):
return
# Prefer the Phase-1-time snapshot (captured at record_end /
# outbound) so we don't read from a source that has since been
# mutated by another op. Fall back to MemoryStore for sources
# that had no Phase 1 data (e.g. math scratch outputs that
# only get populated by Phase 2's math replay).
data = p.get("snapshot")
if data is None:
try:
data = self.store.read(
src_space, src_addr,
shape=p.get("shape"), dtype=p.get("dtype"),
)
except KeyError:
return
self.store.write(dst_space, dst_addr, data)
def _execute_gemm(self, op: OpRecord) -> None:
"""Execute GEMM: out = a @ b."""
@@ -77,18 +112,35 @@ class DataExecutor:
"""Execute math op: unary, binary, or reduction."""
p = op.params
math_op = p.get("op", op.op_name)
space = p.get("addr_space", "tcm")
dtype = p.get("dtype", "f32")
input_addrs = p.get("input_addrs", [])
input_shapes = p.get("input_shapes", [])
# Per-input space/dtype (ADR-0023 CCL accumulation): math ops can
# mix inputs from different MemoryStore spaces (e.g. acc in "hbm",
# recv slot in "tcm"). Fall back to legacy single-space mode when
# the per-input lists are absent.
input_spaces = p.get("input_spaces") or [p.get("addr_space", "tcm")] * len(input_addrs)
input_dtypes = p.get("input_dtypes") or [dtype] * len(input_addrs)
# Per-input data snapshots (ADR-0020 D6): captured at op_log
# record time. Phase 1 has correct values for slot/HBM addrs at
# that moment, which lets Phase 2 sidestep the slot-wraparound
# races where a later round overwrites a slot before this op
# runs in t_start order.
snapshots = p.get("input_snapshots") or [None] * len(input_addrs)
dst_space = p.get("dst_space", p.get("addr_space", "tcm"))
inputs = []
for addr, shape in zip(input_addrs, input_shapes):
inputs.append(self.store.read(space, addr, shape=shape, dtype=dtype))
for addr, shape, space, idtype, snap in zip(
input_addrs, input_shapes, input_spaces, input_dtypes, snapshots
):
if snap is not None:
inputs.append(snap)
else:
inputs.append(self.store.read(space, addr, shape=shape, dtype=idtype))
result = _compute_math(math_op, inputs, p.get("axis"))
if result is not None:
self.store.write(space, p["dst_addr"], result)
self.store.write(dst_space, p["dst_addr"], result)
def verify(self, expected: dict[tuple[str, int], np.ndarray],
rtol: float = 1e-3, atol: float = 1e-3) -> dict[str, bool]:
@@ -146,6 +198,14 @@ def _compute_math(op: str, inputs: list[np.ndarray], axis: int | None) -> np.nda
if op == "min":
return np.min(x, axis=axis, keepdims=True)
# Softmax (numerically stable)
if op == "softmax":
ax = axis if axis is not None else -1
x_max = np.max(x, axis=ax, keepdims=True)
e = np.exp(x - x_max)
s = np.sum(e, axis=ax, keepdims=True)
return e / s
# Binary
if len(inputs) >= 2:
y = inputs[1]
@@ -157,9 +217,18 @@ def _compute_math(op: str, inputs: list[np.ndarray], axis: int | None) -> np.nda
return x * y
if op == "div":
return x / y
if op == "maximum":
return np.maximum(x, y)
if op == "minimum":
return np.minimum(x, y)
# Ternary
if op == "where" and len(inputs) >= 3:
return np.where(inputs[0], inputs[1], inputs[2])
if len(inputs) >= 3:
if op == "where":
return np.where(inputs[0], inputs[1], inputs[2])
if op == "fma":
return inputs[0] * inputs[1] + inputs[2]
if op == "clamp":
return np.minimum(np.maximum(inputs[0], inputs[1]), inputs[2])
return None
+55 -2
View File
@@ -51,8 +51,12 @@ class GraphEngine:
if enable_data:
from kernbench.sim_engine.memory_store import MemoryStore
from kernbench.sim_engine.op_log import OpLogger
self._op_logger = OpLogger()
self._memory_store = MemoryStore()
self._op_logger = OpLogger(memory_store=self._memory_store)
# Cursor for incremental Phase 2 replay (ADR-0020 D6).
# SimPy env.now is monotonic so newly logged records always sort
# to the tail; the cursor remains valid across waits.
self._data_cursor = 0
ctx = ComponentContext(
router=self._router,
@@ -147,11 +151,60 @@ class GraphEngine:
self._env.process(self._process(str(handle), request, event))
return handle
def _flush_data_phase(self) -> None:
"""Replay newly recorded op_log entries through DataExecutor.
ADR-0020 D6 Phase 2: when data tracking is enabled, run DataExecutor
on records added since the last flush so that callers reading
MemoryStore between launches observe correct (compute-replayed)
tensor data.
Cursor-based incremental replay is necessary because Phase 2 is
NOT idempotent across full re-runs: a math op writes a TCM scratch
addr, a later dma_write copies that scratch into HBM[X], and an
even-later math op may then read HBM[X]. Re-running everything
from scratch would let the second pass's first math op read the
already-overwritten HBM[X] instead of the original input.
"""
if self._op_logger is None or self._memory_store is None:
return
records = self._op_logger.records # sorted by t_start (stable)
if self._data_cursor >= len(records):
return
new_records = records[self._data_cursor:]
from kernbench.sim_engine.data_executor import DataExecutor
DataExecutor(new_records, self._memory_store).run()
self._data_cursor = len(records)
def wait(self, handle: RequestHandle) -> None:
key = str(handle)
event = self._events[key]
if not event.triggered:
self._env.run(until=event)
try:
self._env.run(until=event)
except (simpy.core.EmptySchedule, RuntimeError) as exc:
# SimPy raises EmptySchedule directly OR (in newer simpy)
# wraps it as a RuntimeError("No scheduled events left ...").
# Either case while our event is still pending → IPCQ deadlock.
msg = str(exc)
is_deadlock = (
isinstance(exc, simpy.core.EmptySchedule)
or "No scheduled events left" in msg
)
if not is_deadlock:
raise
from kernbench.ccl.diagnostics import IpcqDeadlock, pointer_dump
dump = pointer_dump(self)
if dump.strip():
raise IpcqDeadlock(
"IPCQ deadlock: simulation schedule empty while "
f"request {handle!r} is still pending.\n"
f"Pointer state:\n{dump}"
) from None
raise
# ADR-0020: replay newly logged ops so the caller observes
# post-Phase-2 tensor state from MemoryStore.
self._flush_data_phase()
def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
return self._results[str(handle)]
+84 -1
View File
@@ -29,9 +29,13 @@ class OpLogger:
Records are maintained in t_start stable ordering (insertion order).
"""
def __init__(self) -> None:
def __init__(self, memory_store: Any | None = None) -> None:
self._records: list[OpRecord] = []
self._pending: dict[int, dict[str, Any]] = {} # msg id → partial record
# Optional MemoryStore reference. When set, math op records capture
# input data snapshots at record_end time so Phase 2 replay does
# not depend on slot/scratch addrs surviving until math runs.
self._memory_store = memory_store
@property
def records(self) -> list[OpRecord]:
@@ -53,6 +57,38 @@ class OpLogger:
if pending is None:
return
op_kind, op_name, params = _extract_op_info(msg)
# Snapshot data at record time so Phase 2 replay sidesteps
# downstream mutations of source addrs (e.g. a tl.store that
# overwrites HBM after a load handle was sent, or a slot that
# gets reused on the next ring round).
if self._memory_store is not None:
if op_kind == "math":
snaps: list[Any] = []
for addr, shape, space, idtype in zip(
params.get("input_addrs", []),
params.get("input_shapes", []),
params.get("input_spaces", []),
params.get("input_dtypes", []),
):
try:
arr = self._memory_store.read(
space, addr, shape=shape, dtype=idtype,
)
snaps.append(arr.copy() if hasattr(arr, "copy") else arr)
except Exception:
snaps.append(None)
params["input_snapshots"] = snaps
elif op_name == "dma_write":
try:
arr = self._memory_store.read(
params["src_space"], params["src_addr"],
shape=params.get("shape"), dtype=params.get("dtype"),
)
params["snapshot"] = (
arr.copy() if hasattr(arr, "copy") else arr
)
except Exception:
params["snapshot"] = None
self._records.append(OpRecord(
t_start=pending["t_start"],
t_end=t,
@@ -62,6 +98,45 @@ class OpLogger:
params=params,
))
def record_copy(
self, t_start: float, t_end: float, component_id: str,
src_space: str, src_addr: int,
dst_space: str, dst_addr: int,
shape: tuple[int, ...], dtype: str, nbytes: int,
) -> None:
"""Record a memory copy op for Phase 2 replay (ADR-0023 + ADR-0020).
Used by PE_DMA at outbound (sender) time: the snapshot captures
the source data at the moment the send was issued, so Phase 2
replay does not see later mutations of the source addr (e.g. a
tl.store that runs after the recv at the sender).
For sources whose data is not yet materialized in Phase 1 (math
scratch outputs), the snapshot is None and Phase 2 falls back to
reading from MemoryStore — by which point the corresponding math
op has been replayed and the scratch addr is populated.
"""
snap = None
if self._memory_store is not None:
try:
arr = self._memory_store.read(
src_space, src_addr, shape=shape, dtype=dtype,
)
snap = arr.copy() if hasattr(arr, "copy") else arr
except Exception:
snap = None
self._records.append(OpRecord(
t_start=t_start, t_end=t_end,
component_id=component_id,
op_kind="memory", op_name="ipcq_copy",
params={
"src_space": src_space, "src_addr": src_addr,
"dst_space": dst_space, "dst_addr": dst_addr,
"shape": shape, "dtype": dtype, "nbytes": nbytes,
"snapshot": snap,
},
))
def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
"""Extract op_kind, op_name, params from a data_op message."""
@@ -76,6 +151,11 @@ def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
}
if isinstance(msg, DmaWriteCmd):
return "memory", "dma_write", {
"src_space": getattr(msg.handle, "space", "tcm"),
"src_addr": msg.handle.addr,
"shape": msg.handle.shape,
"dtype": msg.handle.dtype,
"dst_space": "hbm",
"dst_addr": msg.dst_addr,
"nbytes": msg.nbytes,
"handle_id": msg.handle.id,
@@ -96,7 +176,10 @@ def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
return "math", msg.op, {
"input_addrs": [h.addr for h in msg.inputs],
"input_shapes": [h.shape for h in msg.inputs],
"input_spaces": [getattr(h, "space", "tcm") for h in msg.inputs],
"input_dtypes": [h.dtype for h in msg.inputs],
"dst_addr": msg.out.addr,
"dst_space": getattr(msg.out, "space", "tcm"),
"shape_out": msg.out.shape,
"dtype": msg.out.dtype,
"axis": msg.axis,
+30 -2
View File
@@ -25,6 +25,7 @@ _PE_COMP_OFFSETS = {
"pe_math": (0.0, 0.15),
"pe_mmu": (0.15, -0.15),
"pe_tcm": (0.3, 0.0),
"pe_ipcq": (-0.15, 0.15),
}
@@ -698,6 +699,20 @@ def _add_pe_internal_edges(edges: list[Edge], pp: str, pe_links: dict) -> None:
kind="pe_internal",
))
# PE_IPCQ edges (ADR-0023 D1, D9 D10)
ipcq_edges = [
("pe_cpu", "pe_ipcq", "cpu_to_ipcq_mm"), # IpcqRequest
("pe_ipcq", "pe_dma", "ipcq_to_dma_mm"), # IpcqDmaToken outbound
("pe_dma", "pe_ipcq", "dma_to_ipcq_mm"), # IpcqMetaArrival inbound
]
for src_c, dst_c, mm_key in ipcq_edges:
if mm_key in pe_links:
edges.append(Edge(
src=f"{pp}.{src_c}", dst=f"{pp}.{dst_c}",
distance_mm=pe_links[mm_key],
kind="pe_internal",
))
# ── Inter-cube / IO / system edges ──────────────────────────────────
@@ -765,7 +780,13 @@ def _add_io_to_cube_edges(
def _add_system_to_io_edges(
edges: list[Edge], sp: str, sip_spec: dict, system: dict,
) -> None:
"""Add fabric switch IO chiplet PCIe edges."""
"""Add bidirectional fabric switch IO chiplet PCIe edges.
Both directions are needed:
switch → pcie_ep for host→device traffic (memory writes, kernel launch)
pcie_ep → switch for device-side outbound traffic (cross-SIP IPCQ
send between PE_DMAs through the system switch).
"""
sw_id = "fabric.switch0"
sys_link = system["links"]["io_ep_to_switch"]
for inst in sip_spec["iochiplet"]["instances"]:
@@ -776,6 +797,12 @@ def _add_system_to_io_edges(
bw_gbs=sys_link["bw_gbs_per_ep"],
kind="pcie",
))
edges.append(Edge(
src=pcie_ep_id, dst=sw_id,
distance_mm=sys_link["distance_mm"],
bw_gbs=sys_link["bw_gbs_per_ep"],
kind="pcie",
))
# ── View builders ────────────────────────────────────────────────────
@@ -1113,13 +1140,14 @@ def _build_pe_view(spec: dict) -> ViewGraph:
"pe_math": (7.0, 6.5),
"pe_mmu": (4.0, 1.5),
"pe_tcm": (10.0, 4.0),
"pe_ipcq": (4.0, 6.5),
}
nodes: dict[str, Node] = {}
view_edges: list[Edge] = []
for comp_name, comp_spec in pe_tmpl["components"].items():
px, py = positions[comp_name]
px, py = positions.get(comp_name, (1.0, 1.0))
nodes[comp_name] = Node(
id=comp_name, kind=comp_spec["kind"], impl=comp_spec["impl"],
attrs=comp_spec["attrs"], pos_mm=(px, py),
+115 -9
View File
@@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Any
import simpy
from greenlet import greenlet
from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqRequest, IpcqSendCmd, RecvFuture
from kernbench.common.pe_commands import (
CompletionHandle,
CompositeCmd,
@@ -51,6 +52,9 @@ class KernelRunner:
out_ports: dict[str, simpy.Store],
store: MemoryStore | None = None,
num_cubes: int = 1,
ipcq_id: str | None = None,
scratch_base: int = 0,
scratch_size: int = 1 << 20,
) -> None:
self._pe_prefix = pe_prefix
self._pe_idx = pe_idx
@@ -61,6 +65,13 @@ class KernelRunner:
self._out_ports = out_ports
self._store = store
self._parent: greenlet | None = None
# Optional IPCQ port (ADR-0023). If None, IPCQ commands raise.
self._ipcq_id = ipcq_id or f"{pe_prefix}.pe_ipcq"
# PE-local scratch for compute output TensorHandles (ADR-0020 D3
# extension). The TLContext allocates from this pool when math/dot
# ops produce a result that may later be used as a send/store source.
self._scratch_base = scratch_base
self._scratch_size = scratch_size
def run(
self,
@@ -89,7 +100,10 @@ class KernelRunner:
num_cubes=self._num_cubes,
dispatch_cycles=0,
runner=self,
scratch_base=self._scratch_base,
scratch_size=self._scratch_size,
)
self._tl = tl # exposed so switch_to_simpy can re-set on restore
def _kernel_entry():
TLContext._set_active(tl) # type: ignore[attr-defined]
@@ -103,13 +117,20 @@ class KernelRunner:
pending: dict[str, simpy.Event] = {}
composite_results: list[dict] = []
# Helper: set our tl as active just before resuming the kernel.
# Multiple PE kernel runners share the same thread-local; without
# this, another runner's kernel may have left a different context.
def _switch_kernel(*args):
TLContext._set_active(tl) # type: ignore[attr-defined]
return g.switch(*args)
# Start kernel — first switch returns first command (or None if kernel is done)
cmd = g.switch()
cmd = _switch_kernel()
while cmd is not None:
if isinstance(cmd, PeCpuOverheadCmd):
yield env.timeout(cmd.cycles)
cmd = g.switch()
cmd = _switch_kernel()
elif isinstance(cmd, WaitCmd):
if cmd.handle is not None:
@@ -120,7 +141,7 @@ class KernelRunner:
for evt in pending.values():
yield evt
pending.clear()
cmd = g.switch()
cmd = _switch_kernel()
elif isinstance(cmd, DmaReadCmd):
# Dispatch DMA through SimPy components
@@ -141,10 +162,12 @@ class KernelRunner:
)
except KeyError:
pass
cmd = g.switch(data)
cmd = _switch_kernel(data)
elif isinstance(cmd, DmaWriteCmd):
# Write to MemoryStore first (visibility = issue, ADR-0020 D3)
# Write to MemoryStore first (visibility = issue, ADR-0020 D3).
# When data is None (e.g. timing-only TensorHandle math result),
# this is a no-op; Phase 2 dma_write replay handles those.
if self._store is not None and cmd.handle.data is not None:
self._store.write("hbm", cmd.dst_addr, cmd.handle.data)
@@ -154,7 +177,7 @@ class KernelRunner:
)
yield self._out_ports[self._scheduler_id].put(pe_txn)
yield done_evt
cmd = g.switch()
cmd = _switch_kernel()
elif isinstance(cmd, CompositeCmd):
# Non-blocking composite
@@ -165,7 +188,7 @@ class KernelRunner:
composite_results.append(pe_txn.result_data)
yield self._out_ports[self._scheduler_id].put(pe_txn)
pending[cmd.completion.id] = done_evt
cmd = g.switch()
cmd = _switch_kernel()
elif isinstance(cmd, (GemmCmd, MathCmd)):
# Blocking compute command
@@ -175,7 +198,90 @@ class KernelRunner:
)
yield self._out_ports[self._scheduler_id].put(pe_txn)
yield done_evt
cmd = g.switch()
cmd = _switch_kernel()
elif isinstance(cmd, IpcqSendCmd):
# Forward IpcqRequest to PE_IPCQ, wait for done
if self._ipcq_id not in self._out_ports:
raise RuntimeError(
f"PE_IPCQ port {self._ipcq_id!r} not wired to runner"
)
done_evt = env.event()
req = IpcqRequest(command=cmd, done=done_evt)
yield self._out_ports[self._ipcq_id].put(req)
yield done_evt
cmd = _switch_kernel()
elif isinstance(cmd, IpcqRecvCmd):
if self._ipcq_id not in self._out_ports:
raise RuntimeError(
f"PE_IPCQ port {self._ipcq_id!r} not wired to runner"
)
done_evt = env.event()
req = IpcqRequest(command=cmd, done=done_evt)
yield self._out_ports[self._ipcq_id].put(req)
yield done_evt
# Read actual data from MemoryStore at the slot address
data = None
src_space = req.result_data.get("src_space", "tcm")
src_addr = req.result_data.get("src_addr", 0)
if self._store is not None:
try:
data = self._store.read(
src_space, src_addr,
shape=cmd.shape, dtype=cmd.dtype,
)
except KeyError:
pass
# Build result dict for tl.recv to wrap in TensorHandle
result = {
"data": data,
"src_space": src_space,
"src_addr": src_addr,
"direction": req.result_data.get("direction", cmd.direction),
"dtype": cmd.dtype,
"shape": cmd.shape,
"nbytes": req.result_data.get("nbytes", 0),
}
cmd = _switch_kernel(result)
elif isinstance(cmd, tuple) and len(cmd) == 2 and cmd[0] == "recv_async":
# Non-blocking recv: post the IpcqRequest now, store the
# event in the future, return None to kernel.
future: RecvFuture = cmd[1]
done_evt = env.event()
req = IpcqRequest(command=future.cmd, done=done_evt)
future.request = req
future.event = done_evt
yield self._out_ports[self._ipcq_id].put(req)
cmd = _switch_kernel(None)
elif isinstance(cmd, tuple) and len(cmd) == 2 and cmd[0] == "recv_wait":
future = cmd[1]
if not future.event.triggered:
yield future.event
req = future.request
src_space = req.result_data.get("src_space", "tcm")
src_addr = req.result_data.get("src_addr", 0)
data = None
if self._store is not None:
try:
data = self._store.read(
src_space, src_addr,
shape=future.cmd.shape, dtype=future.cmd.dtype,
)
except KeyError:
pass
result = {
"data": data,
"src_space": src_space,
"src_addr": src_addr,
"direction": req.result_data.get("direction", future.cmd.direction),
"dtype": future.cmd.dtype,
"shape": future.cmd.shape,
"nbytes": req.result_data.get("nbytes", 0),
}
cmd = _switch_kernel(result)
else:
# Unknown command — pass through as blocking
@@ -185,7 +291,7 @@ class KernelRunner:
)
yield self._out_ports[self._scheduler_id].put(pe_txn)
yield done_evt
cmd = g.switch()
cmd = _switch_kernel()
# Wait remaining pending composites
for evt in pending.values():
+263 -12
View File
@@ -17,6 +17,7 @@ from __future__ import annotations
import math
from typing import Literal
from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd, RecvFuture
from kernbench.common.pe_commands import (
CompletionHandle,
CompositeCmd,
@@ -55,6 +56,8 @@ class TLContext:
runner: Any = None,
cube_id: int = 0,
num_cubes: int = 1,
scratch_base: int = 0,
scratch_size: int = 1 << 20, # 1 MiB per kernel invocation
) -> None:
self._pe_id = pe_id
self._num_programs = num_programs
@@ -65,6 +68,33 @@ class TLContext:
self._handle_counter = 0
self._completion_counter = 0
self._runner = runner # KernelRunner for greenlet mode (ADR-0020 D3)
# PE-local scratch allocator for math/compute output handles.
# Each binary/unary/reduction op auto-allocates a unique addr from
# this pool so the resulting TensorHandle can be the source of a
# later tl.send / tl.store. Cursor resets on every kernel invocation.
self._scratch_base = scratch_base
self._scratch_size = scratch_size
self._scratch_cursor = 0
def _scratch_alloc(self, nbytes: int) -> int:
"""Allocate a unique scratch address for an output TensorHandle.
Returns 0 if no scratch base was configured (e.g. command-list mode);
in that case the resulting handle has addr=0 and cannot be used as a
send/store source. Greenlet/runner mode always supplies a base.
"""
if self._scratch_base == 0:
return 0
# 16-byte alignment
aligned = (nbytes + 15) & ~15
addr = self._scratch_base + self._scratch_cursor
self._scratch_cursor += aligned
if self._scratch_cursor > self._scratch_size:
raise RuntimeError(
f"TLContext scratch overflow: requested {nbytes}B, "
f"used {self._scratch_cursor}/{self._scratch_size}B"
)
return addr
@property
def commands(self) -> list[PeCommand]:
@@ -93,11 +123,30 @@ class TLContext:
def _make_handle(
self, addr: int, shape: tuple[int, ...], dtype: str,
space: str = "tcm",
) -> TensorHandle:
return TensorHandle(
id=self._next_handle_id(),
addr=addr, shape=shape, dtype=dtype,
nbytes=self._nbytes(shape, dtype),
space=space,
)
def _make_compute_out(
self, shape: tuple[int, ...], dtype: str,
) -> TensorHandle:
"""Allocate an output TensorHandle in PE-local scratch (TCM space).
Used by math/compute ops so the result has a real address that can
be the source of a later send/store. The data field stays None in
Phase 1 — Phase 2 DataExecutor fills the actual ndarray.
"""
nbytes = self._nbytes(shape, dtype)
addr = self._scratch_alloc(nbytes)
return TensorHandle(
id=self._next_handle_id(),
addr=addr, shape=shape, dtype=dtype,
nbytes=nbytes, space="tcm",
)
# ── Reference (no DMA, metadata only) ────────────────────────
@@ -124,20 +173,26 @@ class TLContext:
def load(
self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
) -> TensorHandle:
"""Load tensor from HBM to TCM. Returns TensorHandle.
"""Load tensor from HBM. Returns TensorHandle pointing at HBM[ptr].
In greenlet mode: returns TensorHandle with actual numpy data.
In command-list mode: returns TensorHandle with data=None.
The returned handle's ``space`` is "hbm" so subsequent ops (math,
send, store) using this handle as a source resolve via MemoryStore
at ``(hbm, ptr)`` — which is where the load's underlying data
actually lives in Phase 2 storage.
"""
self._emit_dispatch_overhead()
handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype)
handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype, space="hbm")
cmd = DmaReadCmd(handle=handle, src_addr=ptr, nbytes=handle.nbytes)
data = self._emit(cmd)
if data is not None:
# Greenlet mode: attach real data to handle
# Greenlet mode: attach real data to handle (preserve space)
return TensorHandle(
id=handle.id, addr=handle.addr, shape=handle.shape,
dtype=handle.dtype, nbytes=handle.nbytes, data=data,
space=handle.space,
)
return handle
@@ -162,7 +217,7 @@ class TLContext:
raise ValueError(f"dot shape mismatch: a.K={k} != b.K={k2}")
out_shape = (*a.shape[:-2], m, n)
out_dtype = a.dtype
out = self._make_handle(addr=0, shape=out_shape, dtype=out_dtype)
out = self._make_compute_out(shape=out_shape, dtype=out_dtype)
self._emit_dispatch_overhead()
self._emit(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
return out
@@ -170,7 +225,7 @@ class TLContext:
# ── MATH Engine: unary (blocking) ─────────────────────────────
def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
out = self._make_handle(addr=0, shape=x.shape, dtype=x.dtype)
out = self._make_compute_out(shape=x.shape, dtype=x.dtype)
self._emit_dispatch_overhead()
self._emit(MathCmd(op=op, inputs=(x,), out=out))
return out
@@ -203,7 +258,7 @@ class TLContext:
) -> TensorHandle:
out_shape = list(x.shape)
out_shape[axis] = 1
out = self._make_handle(addr=0, shape=tuple(out_shape), dtype=x.dtype)
out = self._make_compute_out(shape=tuple(out_shape), dtype=x.dtype)
self._emit_dispatch_overhead()
self._emit(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
return out
@@ -222,7 +277,7 @@ class TLContext:
def _binary_math(
self, op: str, a: TensorHandle, b: TensorHandle,
) -> TensorHandle:
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
out = self._make_compute_out(shape=a.shape, dtype=a.dtype)
self._emit_dispatch_overhead()
self._emit(MathCmd(op=op, inputs=(a, b), out=out))
return out
@@ -230,15 +285,67 @@ class TLContext:
def where(
self, cond: TensorHandle, a: TensorHandle, b: TensorHandle,
) -> TensorHandle:
out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
out = self._make_compute_out(shape=a.shape, dtype=a.dtype)
self._emit_dispatch_overhead()
self._emit(MathCmd(op="where", inputs=(cond, a, b), out=out))
return out
def maximum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
"""Element-wise max of two tensors (real Triton: tl.maximum)."""
return self._binary_math("maximum", a, b)
def minimum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
"""Element-wise min of two tensors (real Triton: tl.minimum)."""
return self._binary_math("minimum", a, b)
def fma(
self, a: TensorHandle, b: TensorHandle, c: TensorHandle,
) -> TensorHandle:
"""Fused multiply-add: a * b + c (real Triton: tl.fma)."""
out = self._make_compute_out(shape=a.shape, dtype=a.dtype)
self._emit_dispatch_overhead()
self._emit(MathCmd(op="fma", inputs=(a, b, c), out=out))
return out
def clamp(
self,
x: TensorHandle,
min: TensorHandle,
max: TensorHandle,
) -> TensorHandle:
"""Clamp x to [min, max] (real Triton: tl.clamp)."""
out = self._make_compute_out(shape=x.shape, dtype=x.dtype)
self._emit_dispatch_overhead()
self._emit(MathCmd(op="clamp", inputs=(x, min, max), out=out))
return out
def softmax(self, x: TensorHandle, axis: int = -1) -> TensorHandle:
"""Numerically-stable softmax along ``axis`` (real Triton: tl.softmax).
Implemented as a single MathCmd (op="softmax") so timing accounts
for one MATH dispatch; Phase 2 DataExecutor expands it to the
canonical (x - max) → exp → sum → div sequence.
"""
out = self._make_compute_out(shape=x.shape, dtype=x.dtype)
self._emit_dispatch_overhead()
self._emit(MathCmd(op="softmax", inputs=(x,), out=out, axis=axis))
return out
# ── Scalar helpers (real Triton: tl.cdiv etc.) ────────────────
@staticmethod
def cdiv(a: int, b: int) -> int:
"""Ceiling division: (a + b - 1) // b (real Triton: tl.cdiv).
Used by host/kernel grid math; not a tensor op, so no MathCmd
is emitted. Mirrors triton.cdiv.
"""
return -(-int(a) // int(b))
# ── Index / Scalar (PE_CPU, no engine) ────────────────────────
def program_id(self, axis: int = 0) -> int:
"""Return program instance index.
"""Return program instance index (ADR-0022).
axis=0: local PE id within cube.
axis=1: cube id.
@@ -248,7 +355,7 @@ class TLContext:
return self._pe_id
def num_programs(self, axis: int = 0) -> int:
"""Return total number of program instances.
"""Return total number of program instances (ADR-0022).
axis=0: num PEs per cube.
axis=1: num cubes.
@@ -284,6 +391,119 @@ class TLContext:
dtype=x.dtype, nbytes=x.nbytes, data=x.data,
)
# ── IPCQ (CCL) collective primitives (ADR-0023 D4) ────────────
def send(
self,
dir: str,
src: TensorHandle | None = None,
*,
src_addr: int | None = None,
nbytes: int | None = None,
shape: tuple[int, ...] | None = None,
dtype: str = "f16",
space: str = "tcm",
) -> None:
"""Send tensor data to the peer in the given direction.
Two calling forms:
tl.send(dir, handle) # use handle's metadata
tl.send(dir, src_addr=..., nbytes=..., shape=..., dtype=..., space=...)
Blocking: returns when PE_IPCQ has accepted the request and
forwarded the IpcqDmaToken to PE_DMA. Backpressure may apply.
"""
if src is not None:
src_addr = src.addr
nbytes = src.nbytes
shape = src.shape
dtype = src.dtype
space = getattr(src, "space", space)
if src_addr is None or nbytes is None or shape is None:
raise ValueError("tl.send: provide either a TensorHandle or src_addr/nbytes/shape")
self._emit_dispatch_overhead()
cmd = IpcqSendCmd(
direction=dir,
src_addr=src_addr, src_space=space,
nbytes=nbytes, shape=shape, dtype=dtype,
handle_id=self._next_handle_id(),
)
self._emit(cmd)
def recv(
self,
dir: str | None = None,
shape: tuple[int, ...] = (),
dtype: str = "f16",
space: str = "tcm",
dst_addr: int | None = None,
dst_space: str | None = None,
) -> TensorHandle:
"""Receive tensor data from a peer.
Args:
dir: specific direction (e.g. "W"), or None for round-robin.
shape, dtype: expected tensor metadata.
dst_addr / dst_space: if both are provided, the slot data is
copied to (dst_space, dst_addr) before the handle is
returned ("copy_to_dst" mode). Otherwise the slot address
is returned directly ("return_slot" mode).
Returns:
TensorHandle pointing to the slot (or dst) where the data has
arrived. In greenlet/runner mode, ``handle.data`` carries the
actual ndarray; in command-list mode the handle is a placeholder.
"""
self._emit_dispatch_overhead()
if dst_addr is not None and dst_space is not None:
cmd = IpcqRecvCmd(
direction=dir,
shape=shape, dtype=dtype,
handle_id=self._next_handle_id(),
recv_mode="copy_to_dst",
dst_addr=dst_addr, dst_space=dst_space,
)
else:
cmd = IpcqRecvCmd(
direction=dir,
shape=shape, dtype=dtype,
handle_id=self._next_handle_id(),
)
result = self._emit(cmd)
if isinstance(result, dict):
slot_addr = int(result.get("src_addr", 0))
slot_space = str(result.get("src_space", "tcm"))
data = result.get("data")
return TensorHandle(
id=self._next_handle_id(),
addr=slot_addr,
shape=shape,
dtype=dtype,
nbytes=self._nbytes(shape, dtype),
data=data,
space=slot_space,
)
return self._make_handle(addr=0, shape=shape, dtype=dtype)
def recv_async(
self,
dir: str,
shape: tuple[int, ...] = (),
dtype: str = "f16",
) -> "RecvFuture":
"""Non-blocking recv. Returns a future to pass into ``tl.wait``."""
self._emit_dispatch_overhead()
cmd = IpcqRecvCmd(
direction=dir,
shape=shape, dtype=dtype,
handle_id=self._next_handle_id(),
blocking=False,
)
future = RecvFuture(cmd=cmd)
if self._runner is not None:
self._runner.switch_to_simpy(("recv_async", future))
return future
# ── Composite + Control ───────────────────────────────────────
def composite(
@@ -316,9 +536,40 @@ class TLContext:
))
return completion
def wait(self, handle: CompletionHandle | None = None) -> None:
"""Wait for a specific composite or all pending composites."""
def wait(self, handle: "CompletionHandle | RecvFuture | None" = None) -> Any:
"""Wait for a composite, a recv future, or all pending composites.
- ``CompletionHandle`` (or None): wait for composite completion.
- ``RecvFuture``: wait for a non-blocking ``recv_async`` to finish.
Returns the resolved ``TensorHandle``.
"""
if isinstance(handle, RecvFuture):
if handle.resolved:
return handle.result
if self._runner is None:
raise RuntimeError(
"tl.wait(RecvFuture) requires runner mode (greenlet)"
)
result_dict = self._runner.switch_to_simpy(("recv_wait", handle))
slot_addr = int(result_dict.get("src_addr", 0))
slot_space = str(result_dict.get("src_space", "tcm"))
data = result_dict.get("data")
th = TensorHandle(
id=self._next_handle_id(),
addr=slot_addr,
shape=handle.cmd.shape,
dtype=handle.cmd.dtype,
nbytes=self._nbytes(handle.cmd.shape, handle.cmd.dtype),
data=data,
space=slot_space,
)
handle.resolved = True
handle.result = th
return th
# Composite path (existing behaviour)
self._emit(WaitCmd(handle=handle))
return None
def cycles(self, n: int) -> None:
"""Declare PE_CPU scalar execution overhead (cycles)."""