Add PE-level IPCQ collective infra + unified ccl_allreduce bench (ADR-0023)

Major changes:

PE-level IPCQ infrastructure:
- New PE_IPCQ component: ring-buffer control plane with 4-direction
  neighbor mapping, head/tail pointers, backpressure (poll/sleep).
- PE_DMA extended with vc_comm channel for IPCQ outbound/inbound DMA,
  including in-flight data snapshot (D9) and op_log recording at
  outbound time for Phase 2 replay correctness.
- IpcqDmaToken piggyback model: data + metadata travel together,
  atomic visibility at receiver (invariant I6).
- Credit return fast path: bottleneck-BW latency, no fabric vc_comm.

Phase 2 data execution (ADR-0020 integration):
- op_log extended: DmaWriteCmd now captures src_space/src_addr for
  Phase 2 dma_write copy; ipcq_copy ops recorded at outbound time.
- DataExecutor replays dma_write + ipcq_copy in t_start order.
- Engine._flush_data_phase: incremental cursor-based replay after
  each engine.wait() so host reads see post-Phase-2 data.
- KernelRunner Phase 1 writes disabled when op_log is active to
  prevent stale data from corrupting the MemoryStore snapshot.

TLContext / kernel API:
- tl.send(dir, src=TensorHandle), tl.recv(dir, shape, dtype),
  tl.recv_async, tl.wait(RecvFuture), copy_to_dst mode.
- TensorHandle operator overloading (add/sub/mul/div) via thread-local
  active TLContext → MathCmd dispatch through PE_MATH.
- PE-local scratch allocator for math output handles.
- tl.load returns space="hbm" handles for correct Phase 2 addressing.
- Additional math functions: maximum, minimum, fma, clamp, softmax, cdiv.

Unified ccl_allreduce bench (PyTorch-compat host code):
- Single benches/ccl_allreduce.py with run() + worker(rank, ws, torch)
  split matching real PyTorch DDP worker pattern.
- torch.distributed facade: init_process_group, get_world_size,
  get_rank, get_backend, all_reduce, barrier — only real PyTorch names.
- AhbmCCLBackend: eager install_ipcq at init, all_reduce dispatches
  kernel via tensor shard metadata (n_elem from shards[0].nbytes).
- world_size derived from topology spec (sips × cubes × pes_per_cube)
  with optional algorithm-level override in ccl.yaml.

Tensor API (PyTorch-compat surface):
- Tensor.numpy(): gather-aware (all shards via VA-based addressing).
- Tensor.copy_(source): scatter from host tensor into sharded target.
- RuntimeContext.from_numpy(arr): host-side staging tensor.
- Tensor.data property fixed to use numpy() (was shards[0]-only).

Algorithm modules moved to src/kernbench/ccl/algorithms/:
- ring_allreduce, mesh_allreduce, tree_allreduce, hello_send.
- Each module exports kernel_args(world_size, n_elem) helper.
- ccl.yaml module paths updated to kernbench.ccl.algorithms.*.

Dead code removed:
- 7 per-variant bench files (ccl_allreduce_{tcm,hbm,sram}, etc.).
- _run_ccl_bench greenlet-per-SIP scheduler.
- benches.loader.is_ccl_bench + run_rank detection.
- benches/ccl/ directory.

Tests:
- New test_ccl_allreduce_matrix.py: 7 parametrized cases
  (ring×3 buffers, ring 8/16, mesh 4, tree 7).
- New test_runtime_api_tensor.py: copy_/numpy/from_numpy unit tests.
- Existing tests updated for new import paths + world_size_override.

Docs:
- Korean ccl-author-guide.md and ADR-0023 paths updated.
- New English versions: ccl-author-guide.en.md, ADR-0023.en.md.

502 tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@@ -0,0 +1,9 @@
"""CCL (Collective Communication Library) framework for kernbench (ADR-0023).

This package provides:
- topologies: builtin neighbor topology generators (ring/mesh/tree)
- helpers: utilities for algorithm authors (chunked, ring_step, ...)
- testing: mock CCL runtime for fast unit tests of algorithm kernels

See docs/adr/ADR-0023-ipcq-pe-collective.md and docs/ccl-author-guide.md.
"""
@@ -0,0 +1,29 @@
"""Hello-world CCL kernel for the docs/ccl-author-guide.md walkthrough.

Each PE sends its tile to the E neighbor and receives one tile from W,
then stores the received tile back into its own HBM slice. The simplest
possible demonstration of ``tl.send`` / ``tl.recv``.
"""
from __future__ import annotations


def kernel_args(world_size: int, n_elem: int) -> tuple:
    """Return the positional kernel arguments for the ahbm backend."""
    return (n_elem,)


def kernel(t_ptr, n_elem, tl):
    local_pe = tl.program_id(axis=0)
    cube_id = tl.program_id(axis=1)
    pes_per_cube = tl.num_programs(axis=0)
    rank = cube_id * pes_per_cube + local_pe
    nbytes = n_elem * 2
    pe_addr = t_ptr + rank * nbytes

    # Send our local HBM tile to the E neighbor.
    src = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
    tl.send(dir="E", src=src)

    # Receive a tile from W and store it into our slice (overwrite).
    recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
    tl.store(pe_addr, recv)
@@ -0,0 +1,73 @@
"""2D-mesh all-reduce kernel (ADR-0023).

Two-phase reduce on a square mesh of side ``S`` (world_size = S*S):
  1. Row reduce: ring all-reduce along E/W within each row.
  2. Column reduce: ring all-reduce along N/S within each column.

After both phases, every rank holds the global sum.

Uses TensorHandle math (PE_MATH) for accumulation. Op_log captures the
data flow so Phase 2 produces correct final HBM contents. Math/recv
handles are passed directly to the next send, avoiding store→reload
which doesn't propagate correctly with timing-only Phase 1 math.
"""
from __future__ import annotations

import math


def kernel_args(world_size: int, n_elem: int) -> tuple:
    """Return the positional kernel arguments for the ahbm backend.

    Mesh all-reduce requires ``world_size`` to be a perfect square —
    the mesh side length is ``sqrt(world_size)``.
    """
    side = int(round(math.sqrt(world_size)))
    if side * side != world_size:
        raise ValueError(
            f"mesh_allreduce requires a square world_size; got {world_size}"
        )
    return (n_elem, side)


def kernel(t_ptr, n_elem, side, tl):
    """All-reduce on a square mesh.

    Args:
        t_ptr: HBM base address (column-sharded VA shared across ranks)
        n_elem: number of f16 elements per tile
        side: mesh side length (sqrt(world_size))
        tl: TLContext (ADR-0022).
    """
    local_pe = tl.program_id(axis=0)
    cube_id = tl.program_id(axis=1)
    pes_per_cube = tl.num_programs(axis=0)
    rank = cube_id * pes_per_cube + local_pe
    nbytes = n_elem * 2

    pe_addr = t_ptr + rank * nbytes
    acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
    current = acc

    # ── Phase 1: row ring (E direction) ──
    # Ring forwards each received tile (not the cumulative acc) so every
    # tile passes through every rank exactly once.
    for _ in range(side - 1):
        tl.send(dir="E", src=current)
        recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
        acc = acc + recv
        current = recv

    # Phase 2 column ring starts from the row-phase accumulator. We do NOT
    # store/reload here — the math handle's scratch addr is the source for
    # the first column send and Phase 2 ipcq_copy replays from there.
    current = acc

    # ── Phase 2: column ring (S direction) ──
    for _ in range(side - 1):
        tl.send(dir="S", src=current)
        recv = tl.recv(dir="N", shape=(n_elem,), dtype="f16")
        acc = acc + recv
        current = recv

    tl.store(pe_addr, acc)
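The row-then-column schedule of the mesh kernel can be checked on the host with plain NumPy. The sketch below is not part of the commit: `ring_pass` and `mesh_allreduce` are hypothetical names, and it assumes a row-major rank layout (`rank = row * side + col`), which the diff does not show. It uses float64 tiles to keep the arithmetic exact.

```python
import numpy as np


def ring_pass(tiles):
    """Forward-what-you-received ring; every member ends with the group sum."""
    n = len(tiles)
    acc = [t.copy() for t in tiles]
    current = [t.copy() for t in tiles]
    for _ in range(n - 1):
        # Member i receives whatever member i-1 sent this round.
        recv = [current[(i - 1) % n] for i in range(n)]
        acc = [a + r for a, r in zip(acc, recv)]
        current = recv  # forward the received tile, not the accumulator
    return acc


def mesh_allreduce(tiles, side):
    """Row rings first, then column rings seeded with the row sums."""
    grid = [tiles[r * side:(r + 1) * side] for r in range(side)]
    rows = [ring_pass(row) for row in grid]
    cols = [ring_pass([rows[r][c] for r in range(side)]) for c in range(side)]
    # Back to rank order (assumed row-major).
    return [cols[c][r] for r in range(side) for c in range(side)]


tiles = [np.full(4, float(rank + 1)) for rank in range(4)]
result = mesh_allreduce(tiles, side=2)
```

With four ranks holding 1, 2, 3 and 4, every entry of `result` is a tile of 10s, matching the docstring's claim that each rank ends with the global sum.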
@@ -0,0 +1,80 @@
"""Ring all-reduce kernel for IPCQ-based PE collective (ADR-0023).

Algorithm: 1D ring of N PEs, each PE starts with one tile of data.
After ``world_size - 1`` rounds, every PE's accumulator holds the sum
of all PE tiles.

Strategy
--------
Each PE starts with its own tile in HBM. The kernel:
  1. Loads the local tile into a TensorHandle (the accumulator).
  2. In each of ``world_size - 1`` rounds:
     - Sends ``current`` (own tile in round 1, last received tile after) to E.
     - Receives a tile from the W neighbor — the recv handle points
       into the per-direction TCM slot.
     - Adds the received tile to the accumulator using the TensorHandle
       operator overload, which dispatches to ``MathCmd`` (PE_MATH).
  3. Stores the final accumulator back to HBM via tl.store. The store is
     recorded in op_log with both src and dst, so Phase 2 will copy the
     replayed math result from PE-local scratch into HBM.

ADR-0020 D3 split: Phase 1 simulates timing only — math results are
not yet computed, so the accumulator data flowing through Phase 1 may
be stale. Phase 2's DataExecutor replays math + IPCQ copies + dma_write
in stable t_start order, producing correct final HBM contents.
"""
from __future__ import annotations


def kernel_args(world_size: int, n_elem: int) -> tuple:
    """Return the positional kernel arguments for the ahbm backend.

    Ring all-reduce takes (n_elem, world_size) after the tensor pointer.
    """
    return (n_elem, world_size)


def kernel(t_ptr, n_elem, world_size, tl):
    """Ring all-reduce.

    Args:
        t_ptr: HBM base address of the column-sharded tensor — all PEs
            share this base. The per-PE slice lives at
            ``t_ptr + global_rank * n_elem * 2``.
        n_elem: number of f16 elements per tile.
        world_size: total number of participating ranks (passed by host).
        tl: TLContext (auto-injected, ADR-0022). The kernel derives the
            global rank from ``program_id(axis=0)`` (local PE) and
            ``program_id(axis=1)`` (cube id):

                rank = cube_id * pes_per_cube + local_pe
    """
    local_pe = tl.program_id(axis=0)
    cube_id = tl.program_id(axis=1)
    pes_per_cube = tl.num_programs(axis=0)
    rank = cube_id * pes_per_cube + local_pe
    nbytes = n_elem * 2  # f16

    # Each PE reads from its own slice of the shared base address
    pe_addr = t_ptr + rank * nbytes

    # Load the local tile — handle points at HBM[pe_addr].
    acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")
    # The ring forwards each received tile to the next neighbor (NOT the
    # cumulative accumulator), so every rank's tile passes through every
    # rank exactly once. The accumulator sums the new arrival each round.
    current = acc

    for _step in range(world_size - 1):
        tl.send(dir="E", src=current)
        recv = tl.recv(dir="W", shape=(n_elem,), dtype="f16")
        # TensorHandle add → MathCmd → PE_MATH (timing in Phase 1, real
        # numpy in Phase 2 via DataExecutor). The result handle lives at
        # an auto-allocated PE-local scratch addr.
        acc = acc + recv
        current = recv  # forward W's tile to E next round

    # Final result back to this PE's HBM slice. Op_log captures the
    # source (scratch addr) and dst (HBM slice) so Phase 2 copies the
    # accumulated value into HBM for verification.
    tl.store(pe_addr, acc)
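The comment about forwarding the received tile rather than the accumulator is easy to verify numerically. The sketch below is a host-side illustration only, not code from the commit: `simulate_ring` is a hypothetical helper, and it assumes the E neighbor of rank `r` is `(r + 1) % n`, so each rank receives from `(r - 1) % n`.

```python
import numpy as np


def simulate_ring(tiles, forward_accumulator=False):
    """Run world_size - 1 ring rounds and return each rank's accumulator."""
    n = len(tiles)
    acc = [t.copy() for t in tiles]
    current = [t.copy() for t in tiles]
    for _ in range(n - 1):
        # Rank r receives what its W neighbor sent this round.
        recv = [current[(r - 1) % n] for r in range(n)]
        acc = [a + x for a, x in zip(acc, recv)]
        # The kernel forwards `recv`; forwarding `acc` is the buggy variant.
        current = acc if forward_accumulator else recv
    return acc


tiles = [np.full(4, float(rank + 1)) for rank in range(3)]
good = simulate_ring(tiles)
bad = simulate_ring(tiles, forward_accumulator=True)
```

For three ranks holding 1, 2 and 3, `good` is 6 everywhere. In `bad`, rank 0 ends with 9 because its W neighbor's tile is counted twice.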
@@ -0,0 +1,80 @@
"""Tree all-reduce kernel for IPCQ-based PE collective (ADR-0023).

Two-phase binary tree all-reduce:

Phase 1 (reduce up):
  - leaf nodes send their value to ``parent``
  - internal nodes recv from each child, sum, then send to ``parent``
  - root accumulates child contributions; final acc holds global sum

Phase 2 (broadcast down):
  - root sends acc to ``child_left`` and ``child_right`` (if present)
  - internal nodes recv from ``parent``, then forward to children
  - all ranks store the final acc to HBM

Uses TensorHandle math (PE_MATH) for accumulation. Op_log captures the
data flow so Phase 2 produces correct final HBM contents. The kernel
deliberately avoids the store→reload→send pattern: math/recv handles
are passed directly to the next send so PE_DMA snapshots a deterministic
source addr that Phase 2 can replay.
"""
from __future__ import annotations


def kernel_args(world_size: int, n_elem: int) -> tuple:
    """Return the positional kernel arguments for the ahbm backend."""
    return (n_elem, world_size)


def kernel(t_ptr, n_elem, world_size, tl):
    """Tree all-reduce.

    Args:
        t_ptr: HBM base address.
        n_elem: number of f16 elements per tile.
        world_size: total number of participating ranks (passed by host).
        tl: TLContext (ADR-0022). Global rank from program_id(0/1).
    """
    local_pe = tl.program_id(axis=0)
    cube_id = tl.program_id(axis=1)
    pes_per_cube = tl.num_programs(axis=0)
    rank = cube_id * pes_per_cube + local_pe
    nbytes = n_elem * 2

    pe_addr = t_ptr + rank * nbytes
    acc = tl.load(pe_addr, shape=(n_elem,), dtype="f16")

    # Compute children/parent existence (matches tree_binary topology generator)
    has_parent = rank > 0
    left = 2 * rank + 1
    right = 2 * rank + 2
    has_left = left < world_size
    has_right = right < world_size

    # ── Phase 1: reduce up ──
    if has_left:
        recv = tl.recv(dir="child_left", shape=(n_elem,), dtype="f16")
        acc = acc + recv
    if has_right:
        recv = tl.recv(dir="child_right", shape=(n_elem,), dtype="f16")
        acc = acc + recv

    if has_parent:
        # Send the math/load handle directly — its addr is either the
        # original HBM tile (leaf) or the PE-local scratch where the
        # accumulator lives. Phase 2 ipcq_copy replays from the same addr.
        tl.send(dir="parent", src=acc)

    # ── Phase 2: broadcast down ──
    if has_parent:
        # Replace acc with the value broadcast from the parent (the global
        # sum). The recv handle points at the parent-direction TCM slot.
        acc = tl.recv(dir="parent", shape=(n_elem,), dtype="f16")

    if has_left:
        tl.send(dir="child_left", src=acc)
    if has_right:
        tl.send(dir="child_right", src=acc)

    # Final store to HBM for the bench's verification path.
    tl.store(pe_addr, acc)
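The reduce-up and broadcast-down phases can be mimicked sequentially on the host using the same heap indexing as the kernel (children at `2 * rank + 1` and `2 * rank + 2`, parent at `(rank - 1) // 2`). This is an illustrative sketch with a hypothetical `tree_allreduce` function, not the simulator's execution order: it visits ranks in an order that respects the data dependencies instead of modelling concurrent PEs.

```python
import numpy as np


def tree_allreduce(tiles):
    """Binary-tree all-reduce rooted at rank 0, executed sequentially."""
    n = len(tiles)
    acc = [t.copy() for t in tiles]
    # Reduce up: highest rank first, so every child is complete before
    # it is folded into its parent.
    for rank in range(n - 1, 0, -1):
        parent = (rank - 1) // 2
        acc[parent] = acc[parent] + acc[rank]
    # Broadcast down: lowest rank first, so every parent already holds
    # the global sum when its children copy it.
    for rank in range(1, n):
        acc[rank] = acc[(rank - 1) // 2].copy()
    return acc


tiles = [np.full(4, float(rank + 1)) for rank in range(7)]
result = tree_allreduce(tiles)
```

With seven ranks holding 1 through 7, every rank ends with 28, which is the case covered by the `tree 7` test mentioned in the commit message.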
@@ -0,0 +1,127 @@
"""CCL diagnostics: trace + pointer dump + deadlock (ADR-0023 D14).

Trace
-----
Set ``KERNBENCH_CCL_TRACE=1`` (or any truthy value) to enable per-event
logging of CCL send/recv to stdout. Off by default.

Pointer dump
------------
``pointer_dump(engine)`` returns a multi-line string showing every PE_IPCQ's
ring buffer state (my_head, my_tail, peer_head_cache, peer_tail_cache).
Useful for diagnosing hangs.

Deadlock
--------
``IpcqDeadlock`` is raised by the engine when SimPy's schedule empties
while a request is still pending — typical of unmatched send/recv pairs.
The exception message includes the pointer dump.
"""
from __future__ import annotations

import os
from typing import Any


class IpcqDeadlock(RuntimeError):
    """Raised when the simulation cannot make further progress while a
    CCL request is still pending (D14 F3)."""


# ── Trace toggle ─────────────────────────────────────────────────────


_TRACE_ENABLED: bool = False


def reload_trace_setting() -> None:
    """Re-read the ``KERNBENCH_CCL_TRACE`` env var."""
    global _TRACE_ENABLED
    val = os.environ.get("KERNBENCH_CCL_TRACE", "")
    _TRACE_ENABLED = val.strip().lower() in {"1", "true", "yes", "on"}


def trace_enabled() -> bool:
    return _TRACE_ENABLED


# Initialise once at import time
reload_trace_setting()


# ── Trace event functions ────────────────────────────────────────────


def log_send(
    t_ns: float,
    sender: str,
    direction: str,
    nbytes: int,
    sender_seq: int,
) -> None:
    if not _TRACE_ENABLED:
        return
    print(
        f"[ccl t={t_ns:.1f} send] {sender} dir={direction} nbytes={nbytes} seq={sender_seq}",
        flush=True,
    )


def log_recv(
    t_ns: float,
    receiver: str,
    direction: str,
    nbytes: int,
) -> None:
    if not _TRACE_ENABLED:
        return
    print(
        f"[ccl t={t_ns:.1f} recv] {receiver} dir={direction} nbytes={nbytes}",
        flush=True,
    )


def log_credit_return(
    t_ns: float,
    sender: str,
    direction: str,
    consumer_seq: int,
) -> None:
    if not _TRACE_ENABLED:
        return
    print(
        f"[ccl t={t_ns:.1f} credit] {sender} dir={direction} seq={consumer_seq}",
        flush=True,
    )


# ── Pointer dump ─────────────────────────────────────────────────────


def pointer_dump(engine: Any) -> str:
    """Return a multi-line string of every PE_IPCQ's pointer state."""
    lines: list[str] = []
    components = getattr(engine, "_components", {})
    for node_id in sorted(components):
        if not node_id.endswith(".pe_ipcq"):
            continue
        comp = components[node_id]
        qps = getattr(comp, "queue_pairs", {})
        if not qps:
            continue
        lines.append(node_id)
        for d in sorted(qps):
            qp = qps[d]
            peer = qp["peer"]
            lines.append(
                f"  {d}: peer=sip{peer.sip}.cube{peer.cube}.pe{peer.pe} "
                f"my_head={qp['my_head']} my_tail={qp['my_tail']} "
                f"peer_head_cache={qp['peer_head_cache']} "
                f"peer_tail_cache={qp['peer_tail_cache']}"
            )
    return "\n".join(lines)


def print_pointer_dump(engine: Any) -> None:
    """Convenience: print pointer_dump(engine) to stdout."""
    print(pointer_dump(engine), flush=True)
@@ -0,0 +1,118 @@
"""Helpers for CCL algorithm authors (ADR-0023 D15).

These are pure utility functions usable from any kernel module:

    from kernbench.ccl.helpers import chunked, ring_step, tree_step

They keep algorithm code short and free of off-by-one bugs.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


_DTYPE_BYTES = {
    "f16": 2, "fp16": 2, "float16": 2, "bf16": 2,
    "f32": 4, "fp32": 4, "float32": 4,
    "i8": 1, "int8": 1,
    "i16": 2, "int16": 2,
    "i32": 4, "int32": 4,
}


def _itemsize(dtype: str) -> int:
    if dtype not in _DTYPE_BYTES:
        raise ValueError(f"Unsupported dtype: {dtype}")
    return _DTYPE_BYTES[dtype]


# ── chunked ──────────────────────────────────────────────────────────


@dataclass(frozen=True)
class Chunk:
    """One chunk of a tensor used by collective algorithms."""

    addr: int
    n_elem: int
    nbytes: int


def chunked(
    base_addr: int,
    n_chunks: int,
    n_elem: int,
    dtype: str = "f16",
) -> list[Chunk]:
    """Slice a 1D buffer into ``n_chunks`` equal Chunks.

    Args:
        base_addr: starting address of the buffer.
        n_chunks: number of equal chunks to produce.
        n_elem: total number of elements (must be divisible by n_chunks).
        dtype: element type for byte-size calculation.

    Returns:
        List of ``Chunk`` objects whose addresses are consecutive.

    Raises:
        ValueError: if n_elem is not divisible by n_chunks.
    """
    if n_elem % n_chunks != 0:
        raise ValueError(
            f"chunked: n_elem ({n_elem}) not divisible by n_chunks ({n_chunks})"
        )
    per_chunk_elem = n_elem // n_chunks
    isize = _itemsize(dtype)
    per_chunk_bytes = per_chunk_elem * isize
    return [
        Chunk(
            addr=base_addr + i * per_chunk_bytes,
            n_elem=per_chunk_elem,
            nbytes=per_chunk_bytes,
        )
        for i in range(n_chunks)
    ]


# ── ring_step ────────────────────────────────────────────────────────


def ring_step(rank: int, step: int, world_size: int) -> tuple[int, int]:
    """Return ``(send_chunk_idx, recv_chunk_idx)`` for a ring algorithm step.

    Standard reduce-scatter / all-gather ring schedule:
        at step s, rank r sends chunk (r - s) and receives chunk (r - s - 1)
    modulo world_size.

    Used by ring all-reduce kernels:

        for step in range(world_size - 1):
            send_idx, recv_idx = ring_step(rank, step, world_size)
            tl.send(dir="E", src=chunks[send_idx])
            chunks[recv_idx] += tl.recv(dir="W").data
    """
    send_idx = (rank - step) % world_size
    recv_idx = (rank - step - 1) % world_size
    return send_idx, recv_idx


# ── tree_step ────────────────────────────────────────────────────────


def tree_step(rank: int, world_size: int) -> dict[str, Any]:
    """Return parent/children for binary tree rooted at rank 0.

    Returns:
        ``{"parent": int|None, "children": list[int]}``
    """
    parent = (rank - 1) // 2 if rank > 0 else None
    children: list[int] = []
    left = 2 * rank + 1
    right = 2 * rank + 2
    if left < world_size:
        children.append(left)
    if right < world_size:
        children.append(right)
    return {"parent": parent, "children": children}
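A quick way to see that the `ring_step` schedule is self-consistent is to check that, at every step, the chunk a rank sends is the chunk its downstream neighbor expects to receive. The sketch below copies the two-line body of `ring_step` so it runs without the package. The assumption that rank `r` sends to `(r + 1) % world_size` is mine and is not stated in this file.

```python
def ring_step(rank, step, world_size):
    # Same arithmetic as the helper in the diff above.
    send_idx = (rank - step) % world_size
    recv_idx = (rank - step - 1) % world_size
    return send_idx, recv_idx


world_size = 4
schedule = {
    (rank, step): ring_step(rank, step, world_size)
    for rank in range(world_size)
    for step in range(world_size - 1)
}

for (rank, step), (send_idx, _recv_idx) in schedule.items():
    east = (rank + 1) % world_size  # assumed downstream neighbor
    _east_send, east_recv = schedule[(east, step)]
    # What rank sends at this step is what its E neighbor receives.
    assert send_idx == east_recv
```

For example, `schedule[(0, 0)]` is `(0, 3)`: at step 0, rank 0 sends chunk 0 and receives chunk 3.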
@@ -0,0 +1,266 @@
"""IPCQ install plan for AhbmCCLBackend (ADR-0023 D10/D11/D12).

Given a ccl.yaml config, the topology, and the engine, this module:

1. Loads ccl.yaml and resolves the chosen algorithm.
2. Maps each rank to a (sip, cube, pe) PE address using a linear scheme.
3. Allocates per-rank IPCQ ring buffer base addresses (synthetic but
   unique-per-PE; see notes below).
4. Builds neighbor tables via the algorithm's ``topology`` field plus the
   optional ``neighbors()`` override hook from the algorithm module.
5. Wires bidirectional credit-return SimPy Stores between every (PE, peer)
   pair.
6. Installs each PE_IPCQ component's neighbor table directly via its
   ``_install_neighbors`` sideband call (equivalent to fan-out IpcqInitMsg
   without going through fabric).

Address scheme
--------------
For the first implementation we use a synthetic address scheme that
guarantees uniqueness per (sip, cube, pe, direction) without going
through ``PEMemAllocator``. The address is encoded as:

    base = IPCQ_BASE | (sip << 40) | (cube << 32) | (pe << 24)
    rx_base[direction_idx] = base + direction_idx * (n_slots * slot_size)

The ``buffer_kind`` (tcm/hbm/sram) selects the *MemoryStore space* into
which data is written. Within a space, addresses are unique per PE so
the existing MemoryStore (``{space: {addr: ndarray}}``) handles them
naturally.

This bypasses the topology's address resolver / PhysAddr encoding and
treats IPCQ buffers as a separate, parallel address namespace. Real PA
encoding can be plugged in later without changing the rest of the design.
"""
from __future__ import annotations

from pathlib import Path
from typing import Any

import simpy
import yaml

from kernbench.ccl.topologies import resolve_topology
from kernbench.common.ipcq_types import (
    IpcqEndpoint,
    IpcqInitEntry,
)
from kernbench.runtime_api.kernel import IpcqInitMsg


# IPCQ synthetic address space top bit
_IPCQ_BASE = 1 << 60


def _ipcq_base_for_pe(sip: int, cube: int, pe: int) -> int:
    return _IPCQ_BASE | (sip << 40) | (cube << 32) | (pe << 24)


# ── ccl.yaml loading ─────────────────────────────────────────────────


def load_ccl_config(path: str | Path | None = None) -> dict:
    """Load and validate ccl.yaml. Searches cwd and project root."""
    if path is None:
        candidates = [
            Path.cwd() / "ccl.yaml",
            Path(__file__).resolve().parents[3] / "ccl.yaml",
        ]
        for p in candidates:
            if p.exists():
                path = p
                break
    if path is None:
        raise FileNotFoundError(
            "ccl.yaml not found. Place it at project root or cwd."
        )
    with open(path) as f:
        cfg = yaml.safe_load(f) or {}  # an empty file parses to None
    if "defaults" not in cfg:
        raise ValueError("ccl.yaml missing 'defaults' section")
    if "algorithms" not in cfg:
        raise ValueError("ccl.yaml missing 'algorithms' section")
    return cfg


def resolve_algorithm_config(cfg: dict, name: str | None = None) -> dict:
    """Merge defaults with the chosen algorithm's overrides.

    Returns a flat dict with at minimum: module, topology, buffer_kind,
    backpressure, n_slots, slot_size, ipcq_credit_size_bytes, world_size.
    """
    defaults = dict(cfg.get("defaults", {}))
    algo_name = name or defaults.get("algorithm")
    if algo_name is None:
        raise ValueError("ccl.yaml: defaults.algorithm not set")
    algos = cfg.get("algorithms", {})
    if algo_name not in algos:
        raise ValueError(
            f"ccl.yaml: algorithm '{algo_name}' not in algorithms section"
        )
    merged = defaults.copy()
    merged.update(algos[algo_name])
    merged["algorithm"] = algo_name
    return merged


# ── rank → PE mapping ────────────────────────────────────────────────


def linear_rank_to_pe(rank: int, spec: dict) -> tuple[int, int, int]:
    """Map a rank to (sip, cube, pe) using linear topology order."""
    sips = spec["system"]["sips"]["count"]
    cubes_per_sip = spec["sip"]["cube_mesh"]["w"] * spec["sip"]["cube_mesh"]["h"]
    pe_layout = spec["cube"]["pe_layout"]
    pes_per_cube = pe_layout["pe_per_corner"] * len(pe_layout["corners"])

    pes_per_sip = cubes_per_sip * pes_per_cube
    if rank >= sips * pes_per_sip:
        raise ValueError(
            f"rank {rank} exceeds total PE count {sips * pes_per_sip}"
        )
    sip = rank // pes_per_sip
    rem = rank % pes_per_sip
    cube = rem // pes_per_cube
    pe = rem % pes_per_cube
    return sip, cube, pe


# ── Install plan ─────────────────────────────────────────────────────


def install_ipcq(
    engine: Any,
    spec: dict,
    cfg: dict,
    algo_module: Any | None = None,
    rank_to_pe: list[tuple[int, int, int]] | None = None,
) -> dict[str, Any]:
    """Build neighbor tables and install them in every participating PE_IPCQ.

    Args:
        engine: GraphEngine with ``_components`` dict
        spec: topology spec dict
        cfg: merged algorithm config (from ``resolve_algorithm_config``)
        algo_module: optional algorithm Python module (for neighbors override)
        rank_to_pe: optional explicit rank → (sip, cube, pe) mapping. If
            None, the default linear mapping is used.

    Returns:
        A diagnostics dict with the install plan (rank → PE map, neighbor table).
    """
    if "world_size" in cfg:
        world_size = int(cfg["world_size"])
    else:
        # Topology-derived fallback (mirrors AhbmCCLBackend / RuntimeContext).
        sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
        cm = spec.get("sip", {}).get("cube_mesh", {})
        cubes_per_sip = int(cm.get("w", 1)) * int(cm.get("h", 1))
        pl = spec.get("cube", {}).get("pe_layout", {})
        corners = pl.get("corners", [])
        pe_per_corner = int(pl.get("pe_per_corner", 1))
        pes_per_cube = pe_per_corner * max(len(corners), 1)
        world_size = sips * cubes_per_sip * pes_per_cube
    buffer_kind = cfg["buffer_kind"]
    n_slots = int(cfg["n_slots"])
    slot_size = int(cfg["slot_size"])
    backpressure = cfg["backpressure"]
    credit_size_bytes = int(cfg.get("ipcq_credit_size_bytes", 16))

    # Step 1: rank → (sip, cube, pe)
    if rank_to_pe is not None:
        if len(rank_to_pe) != world_size:
            raise ValueError(
                f"rank_to_pe has {len(rank_to_pe)} entries but world_size={world_size}"
            )
        rank_pe = list(rank_to_pe)
    else:
        rank_pe: list[tuple[int, int, int]] = [
            linear_rank_to_pe(r, spec) for r in range(world_size)
        ]
    pe_to_rank = {(s, c, p): r for r, (s, c, p) in enumerate(rank_pe)}

    # Step 2: resolve topology fn (with optional override)
    topo_fn = resolve_topology(cfg["topology"], algo_module=algo_module)

    # Build per-rank neighbor map
    neighbor_table: dict[int, dict[str, int]] = {}
    for r in range(world_size):
        neighbor_table[r] = topo_fn(r, world_size)

    # Step 3: pull the live engine reference for each PE_IPCQ
    components = engine._components
    pe_ipcq_id = lambda s, c, p: f"sip{s}.cube{c}.pe{p}.pe_ipcq"

    # Step 4: per-PE rx_base address and per-PE credit_inbox
    direction_keys = sorted({d for nt in neighbor_table.values() for d in nt})
    direction_idx = {d: i for i, d in enumerate(direction_keys)}
    bytes_per_direction = n_slots * slot_size

    def rx_base(s: int, c: int, p: int, d: str) -> int:
        return _ipcq_base_for_pe(s, c, p) + direction_idx[d] * bytes_per_direction

    # Wire bidirectional credit stores: backend creates the SimPy Stores
    # by reading each rank's PE_IPCQ.credit_inbox property.
    rank_to_credit_inbox: dict[int, simpy.Store] = {}
    for r, (s, c, p) in enumerate(rank_pe):
        comp = components[pe_ipcq_id(s, c, p)]
        # Trigger lazy creation of credit_inbox if not yet started.
        # PE_IPCQ.start() creates it; we ensure it exists.
        if comp._credit_inbox is None:
            comp._credit_inbox = simpy.Store(engine._env)
        rank_to_credit_inbox[r] = comp.credit_inbox

    # Step 5: build IpcqInitMsg per rank and call _install_neighbors directly
    plan: dict[str, Any] = {
        "world_size": world_size,
        "rank_to_pe": rank_pe,
        "buffer_kind": buffer_kind,
        "neighbor_table": neighbor_table,
    }

    def reverse_direction(my_rank: int, peer_rank: int) -> str | None:
        """Find which direction in peer's neighbor table points back to my_rank."""
        for d, target in neighbor_table[peer_rank].items():
            if target == my_rank:
                return d
        return None

    for r, (s, c, p) in enumerate(rank_pe):
        my_pe_ipcq = components[pe_ipcq_id(s, c, p)]
        nbrs = neighbor_table[r]
        entries: list[IpcqInitEntry] = []
        for d, peer_rank in nbrs.items():
            if peer_rank is None:
                continue
            peer_s, peer_c, peer_p = rank_pe[peer_rank]
            peer_dir = reverse_direction(r, peer_rank)
            if peer_dir is None:
                # Peer doesn't have a reverse entry — skip (asymmetric topology)
                continue
            peer_endpoint = IpcqEndpoint(
                sip=peer_s, cube=peer_c, pe=peer_p,
                buffer_kind=buffer_kind,
                rx_base_pa=rx_base(peer_s, peer_c, peer_p, peer_dir),
                rx_base_va=0,
                n_slots=n_slots, slot_size=slot_size,
            )
            entries.append(IpcqInitEntry(
                direction=d,
                peer=peer_endpoint,
                my_rx_base_pa=rx_base(s, c, p, d),
                my_rx_base_va=0,
                n_slots=n_slots, slot_size=slot_size,
                peer_credit_store=rank_to_credit_inbox[peer_rank],
            ))
        msg = IpcqInitMsg(
            correlation_id="ccl_init", request_id=f"init_r{r}",
            target_sips=(s,), target_cubes=(c,), target_pe=p,
            entries=tuple(entries),
            backpressure_mode=backpressure,
            buffer_kind=buffer_kind,
            credit_size_bytes=credit_size_bytes,
        )
        my_pe_ipcq._install_neighbors(msg)

    return plan
@@ -0,0 +1,465 @@
"""Mock CCL runtime for fast unit tests of algorithm kernels (ADR-0023 D15).

Runs a kernel function once per rank with a minimal ``tl`` shim: no SimPy,
no PE_DMA, no fabric simulation. Just enough to verify *functional*
correctness of an IPCQ-based collective algorithm.

Cross-rank send/recv is implemented with greenlet cooperative scheduling
plus per-(rank, direction) FIFO queues. Backpressure is not modeled;
queues are unbounded.

Typical usage in a test::

    import numpy as np

    from kernbench.ccl.testing import run_kernel_in_mock
    from kernbench.ccl.algorithms.ring_allreduce import kernel

    # NumPy parses the string "f16" as a 16-byte float, so pass np.float16.
    inputs = [np.full(16, r + 1, dtype=np.float16) for r in range(4)]
    outputs = run_kernel_in_mock(
        kernel_fn=kernel, world_size=4, topology="ring_1d",
        inputs=inputs, kernel_args=(16,),
    )
    for r in range(4):
        assert np.allclose(outputs[r], sum(inputs))
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
from greenlet import greenlet
|
||||
|
||||
from kernbench.ccl.topologies import resolve_topology
|
||||
from kernbench.common.ipcq_types import IpcqInvalidDirection
|
||||
from kernbench.common.pe_commands import TensorHandle
|
||||
|
||||
|
||||
# ── Per-rank fake state ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class _MockRankState:
|
||||
"""Per-rank scratch holding HBM/recv slots and tl shim hooks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
neighbors: dict[str, int],
|
||||
input_arr: np.ndarray,
|
||||
) -> None:
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.neighbors = neighbors # direction → peer rank
|
||||
# HBM "memory": addr → ndarray. Per-rank, no cross-rank sharing.
|
||||
self._hbm: dict[int, np.ndarray] = {}
|
||||
self._tcm: dict[int, np.ndarray] = {}
|
||||
# ``t_ptr`` is the address the kernel sees. Real benches use a
|
||||
# column-sharded VA so each rank reads from ``t_ptr + rank*nbytes``.
|
||||
# Mirror that here: each rank's slice lives at the rank-specific addr.
|
||||
nbytes = int(input_arr.nbytes)
|
||||
self.t_ptr = 0 # base; per-rank offset is rank * nbytes
|
||||
self._slice_addr = rank * nbytes
|
||||
self._hbm[self._slice_addr] = input_arr.copy()
|
||||
# Inbound recv FIFOs: direction → deque[ndarray]
|
||||
self.recv_q: dict[str, deque[np.ndarray]] = {d: deque() for d in neighbors}
|
||||
# Output (set when kernel calls tl.store at slice address)
|
||||
self.output: np.ndarray | None = None
|
||||
# Greenlet for this rank — set later
|
||||
self.g: greenlet | None = None
|
||||
|
||||
|
||||
# ── Mock TLContext ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _MockTL:
|
||||
"""Drop-in tl shim for mock runtime.
|
||||
|
||||
Supports the subset of TLContext API that algorithm authors use:
|
||||
program_id, num_programs, load, store, send, recv, recv_async, wait,
|
||||
plus arithmetic operations on TensorHandle (eager numpy execution,
|
||||
no SimPy involved).
|
||||
"""
|
||||
|
||||
def __init__(self, state: _MockRankState, scheduler: "_MockScheduler") -> None:
|
||||
self._state = state
|
||||
self._scheduler = scheduler
|
||||
self._handle_counter = 0
|
||||
|
||||
def _next_id(self) -> str:
|
||||
self._handle_counter += 1
|
||||
return f"mt{self._handle_counter}"
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
return self._state.rank
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return self._state.world_size
|
||||
|
||||
# axis-aware
|
||||
def program_id(self, axis: int = 0) -> int:
|
||||
return self._state.rank if axis == 0 else 0
|
||||
|
||||
def num_programs(self, axis: int = 0) -> int:
|
||||
return self._state.world_size if axis == 0 else 1
|
||||
|
||||
# ── arithmetic ops (called by TensorHandle.__add__ etc.) ──
|
||||
|
||||
def _binary_math(self, op: str, a: TensorHandle, b: TensorHandle) -> TensorHandle:
|
||||
a_data = np.asarray(a.data) if a.data is not None else None
|
||||
b_data = np.asarray(b.data) if b.data is not None else None
|
||||
if a_data is None or b_data is None:
|
||||
result = None
|
||||
elif op == "add":
|
||||
result = a_data + b_data
|
||||
elif op == "sub":
|
||||
result = a_data - b_data
|
||||
elif op == "mul":
|
||||
result = a_data * b_data
|
||||
elif op == "div":
|
||||
result = a_data / b_data
|
||||
elif op == "maximum":
|
||||
result = np.maximum(a_data, b_data)
|
||||
elif op == "minimum":
|
||||
result = np.minimum(a_data, b_data)
|
||||
else:
|
||||
raise NotImplementedError(f"mock _binary_math: op {op!r} not implemented")
|
||||
return TensorHandle(
|
||||
id=self._next_id(),
|
||||
addr=0, shape=a.shape, dtype=a.dtype,
|
||||
nbytes=int(np.prod(a.shape)) * 2 if a.shape else 0,
|
||||
data=result, space="tcm",
|
||||
)
|
||||
|
||||
def maximum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
|
||||
return self._binary_math("maximum", a, b)
|
||||
|
||||
def minimum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
|
||||
return self._binary_math("minimum", a, b)
|
||||
|
||||
def fma(
|
||||
self, a: TensorHandle, b: TensorHandle, c: TensorHandle,
|
||||
) -> TensorHandle:
|
||||
a_data = np.asarray(a.data) if a.data is not None else None
|
||||
b_data = np.asarray(b.data) if b.data is not None else None
|
||||
c_data = np.asarray(c.data) if c.data is not None else None
|
||||
result = (
|
||||
a_data * b_data + c_data
|
||||
if (a_data is not None and b_data is not None and c_data is not None)
|
||||
else None
|
||||
)
|
||||
return TensorHandle(
|
||||
id=self._next_id(),
|
||||
addr=0, shape=a.shape, dtype=a.dtype,
|
||||
nbytes=int(np.prod(a.shape)) * 2 if a.shape else 0,
|
||||
data=result, space="tcm",
|
||||
)
|
||||
|
||||
def clamp(
|
||||
self,
|
||||
x: TensorHandle,
|
||||
min: TensorHandle,
|
||||
max: TensorHandle,
|
||||
) -> TensorHandle:
|
||||
x_data = np.asarray(x.data) if x.data is not None else None
|
||||
lo = np.asarray(min.data) if min.data is not None else None
|
||||
hi = np.asarray(max.data) if max.data is not None else None
|
||||
result = (
|
||||
np.minimum(np.maximum(x_data, lo), hi)
|
||||
if (x_data is not None and lo is not None and hi is not None)
|
||||
else None
|
||||
)
|
||||
return TensorHandle(
|
||||
id=self._next_id(),
|
||||
addr=0, shape=x.shape, dtype=x.dtype,
|
||||
nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0,
|
||||
data=result, space="tcm",
|
||||
)
|
||||
|
||||
def softmax(self, x: TensorHandle, axis: int = -1) -> TensorHandle:
|
||||
x_data = np.asarray(x.data) if x.data is not None else None
|
||||
if x_data is None:
|
||||
result = None
|
||||
else:
|
||||
x_max = np.max(x_data, axis=axis, keepdims=True)
|
||||
e = np.exp(x_data - x_max)
|
||||
s = np.sum(e, axis=axis, keepdims=True)
|
||||
result = e / s
|
||||
return TensorHandle(
|
||||
id=self._next_id(),
|
||||
addr=0, shape=x.shape, dtype=x.dtype,
|
||||
nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0,
|
||||
data=result, space="tcm",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def cdiv(a: int, b: int) -> int:
|
||||
return -(-int(a) // int(b))
|
||||
|
||||
def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
|
||||
x_data = np.asarray(x.data) if x.data is not None else None
|
||||
if x_data is None:
|
||||
result = None
|
||||
elif op == "exp":
|
||||
result = np.exp(x_data)
|
||||
elif op == "log":
|
||||
result = np.log(x_data)
|
||||
elif op == "sqrt":
|
||||
result = np.sqrt(x_data)
|
||||
elif op == "abs":
|
||||
result = np.abs(x_data)
|
||||
elif op == "sigmoid":
|
||||
result = 1.0 / (1.0 + np.exp(-x_data))
|
||||
elif op == "cos":
|
||||
result = np.cos(x_data)
|
||||
elif op == "sin":
|
||||
result = np.sin(x_data)
|
||||
else:
|
||||
raise NotImplementedError(f"mock _unary_math: op {op!r} not implemented")
|
||||
return TensorHandle(
|
||||
id=self._next_id(),
|
||||
addr=0, shape=x.shape, dtype=x.dtype,
|
||||
nbytes=int(np.prod(x.shape)) * 2 if x.shape else 0,
|
||||
data=result, space="tcm",
|
||||
)
|
||||
|
||||
def load(self, ptr: int, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
|
||||
data = self._state._hbm.get(ptr)
|
||||
if data is None:
|
||||
data = np.zeros(shape, dtype=np.float16)
|
||||
return TensorHandle(
|
||||
id=f"load_{ptr}", addr=ptr, shape=shape, dtype=dtype,
|
||||
nbytes=int(np.prod(shape)) * 2, data=data, space="hbm",
|
||||
)
|
||||
|
||||
def store(self, ptr: int, handle: TensorHandle) -> None:
|
||||
if handle.data is not None:
|
||||
self._state._hbm[ptr] = np.asarray(handle.data)
|
||||
if ptr == self._state._slice_addr:
|
||||
self._state.output = self._state._hbm[ptr]
|
||||
|
||||
# IPCQ
|
||||
def send(
|
||||
self,
|
||||
dir: str,
|
||||
src: TensorHandle | None = None,
|
||||
*,
|
||||
src_addr: int | None = None,
|
||||
nbytes: int | None = None,
|
||||
shape: tuple[int, ...] | None = None,
|
||||
dtype: str = "f16",
|
||||
space: str = "tcm",
|
||||
) -> None:
|
||||
if dir not in self._state.neighbors:
|
||||
raise IpcqInvalidDirection(
|
||||
f"mock tl.send: direction {dir!r} not in neighbors {list(self._state.neighbors)}"
|
||||
)
|
||||
if src is not None:
|
||||
if src.data is not None:
|
||||
data = np.asarray(src.data)
|
||||
else:
|
||||
# Resolve from this rank's local memory at src.addr
|
||||
space_dict = self._state._hbm if src.space == "hbm" else self._state._tcm
|
||||
stored = space_dict.get(src.addr)
|
||||
if stored is None:
|
||||
raise RuntimeError(
|
||||
f"mock tl.send: no data at {src.space}:0x{src.addr:x}"
|
||||
)
|
||||
data = np.asarray(stored)
|
||||
else:
|
||||
data = None
|
||||
if data is None:
|
||||
raise RuntimeError("mock tl.send: src is None")
|
||||
peer_rank = self._state.neighbors[dir]
|
||||
# Find the reverse direction in peer's neighbors that points back to me
|
||||
peer_state = self._scheduler.states[peer_rank]
|
||||
reverse_dir = None
|
||||
for d, target in peer_state.neighbors.items():
|
||||
if target == self._state.rank:
|
||||
reverse_dir = d
|
||||
break
|
||||
if reverse_dir is None:
|
||||
raise RuntimeError(
|
||||
f"mock tl.send: peer rank {peer_rank} has no reverse direction"
|
||||
)
|
||||
peer_state.recv_q[reverse_dir].append(data.copy())
|
||||
# After delivering, hand control back to scheduler so the receiver
|
||||
# can wake up.
|
||||
self._scheduler.yield_()
|
||||
|
||||
def recv_async(
|
||||
self,
|
||||
dir: str,
|
||||
shape: tuple[int, ...] = (),
|
||||
dtype: str = "f16",
|
||||
) -> dict:
|
||||
"""Non-blocking recv. Returns a future dict to pass to tl.wait."""
|
||||
if dir not in self._state.neighbors:
|
||||
raise IpcqInvalidDirection(
|
||||
f"mock tl.recv_async: direction {dir!r} not in neighbors"
|
||||
)
|
||||
return {"_kind": "recv_future", "dir": dir, "shape": shape, "dtype": dtype}
|
||||
|
||||
def wait(self, future: Any) -> TensorHandle:
|
||||
"""Block until the recv future has data."""
|
||||
if not isinstance(future, dict) or future.get("_kind") != "recv_future":
|
||||
raise TypeError("tl.wait: expected recv future from tl.recv_async")
|
||||
d = future["dir"]
|
||||
while not self._state.recv_q[d]:
|
||||
self._scheduler.yield_()
|
||||
data = self._state.recv_q[d].popleft()
|
||||
return self._make_handle(data, d, future["dtype"])
|
||||
|
||||
def recv(
|
||||
self,
|
||||
dir: str | None = None,
|
||||
shape: tuple[int, ...] = (),
|
||||
dtype: str = "f16",
|
||||
) -> TensorHandle:
|
||||
if dir is not None and dir not in self._state.neighbors:
|
||||
raise IpcqInvalidDirection(
|
||||
f"mock tl.recv: direction {dir!r} not in neighbors {list(self._state.neighbors)}"
|
||||
)
|
||||
# Wait for data
|
||||
while True:
|
||||
if dir is None:
|
||||
# round-robin over directions
|
||||
for d in self._state.neighbors:
|
||||
if self._state.recv_q[d]:
|
||||
data = self._state.recv_q[d].popleft()
|
||||
return self._make_handle(data, d, dtype)
|
||||
else:
|
||||
if self._state.recv_q[dir]:
|
||||
data = self._state.recv_q[dir].popleft()
|
||||
return self._make_handle(data, dir, dtype)
|
||||
# Yield to other ranks
|
||||
self._scheduler.yield_()
|
||||
|
||||
def _make_handle(self, data: np.ndarray, direction: str, dtype: str) -> TensorHandle:
|
||||
return TensorHandle(
|
||||
id=f"recv_{direction}",
|
||||
addr=0, shape=data.shape, dtype=dtype,
|
||||
nbytes=int(data.nbytes), data=data, space="tcm",
|
||||
)
|
||||
|
||||
|
||||
# ── Cooperative scheduler ────────────────────────────────────────────
|
||||
|
||||
|
||||
class _MockScheduler:
|
||||
"""Round-robin cooperative scheduler over rank greenlets."""
|
||||
|
||||
def __init__(self, states: list[_MockRankState]) -> None:
|
||||
self.states = states
|
||||
self._parent: greenlet | None = None
|
||||
self._cur_idx = 0
|
||||
|
||||
def yield_(self) -> None:
|
||||
"""Called from inside a rank greenlet to give other ranks a turn."""
|
||||
assert self._parent is not None
|
||||
self._parent.switch()
|
||||
|
||||
def run(self, kernel_fn: Callable, kernel_args: tuple) -> list[np.ndarray]:
|
||||
from kernbench.triton_emu.tl_context import TLContext
|
||||
|
||||
self._parent = greenlet.getcurrent()
|
||||
n = len(self.states)
|
||||
|
||||
# Per-rank tl shim
|
||||
tls: dict[int, _MockTL] = {}
|
||||
|
||||
def _spawn(rank_idx: int) -> greenlet:
|
||||
state = self.states[rank_idx]
|
||||
tl = _MockTL(state, self)
|
||||
tls[rank_idx] = tl
|
||||
|
||||
def _entry():
|
||||
# Activate this rank's tl for TensorHandle operator overloads
|
||||
TLContext._set_active(tl) # type: ignore[attr-defined]
|
||||
try:
|
||||
kernel_fn(state.t_ptr, *kernel_args, tl=tl)
|
||||
finally:
|
||||
TLContext._set_active(None) # type: ignore[attr-defined]
|
||||
|
||||
return greenlet(_entry)
|
||||
|
||||
for state in self.states:
|
||||
state.g = _spawn(state.rank)
|
||||
|
||||
# Drive each rank round-robin until all dead. Detect global deadlock.
|
||||
max_rounds = 10_000
|
||||
round_no = 0
|
||||
while True:
|
||||
alive = [s for s in self.states if s.g is not None and not s.g.dead]
|
||||
if not alive:
|
||||
break
|
||||
progressed = False
|
||||
for s in self.states:
|
||||
if s.g is None or s.g.dead:
|
||||
continue
|
||||
# Multi-rank greenlets share TLContext active state via the
|
||||
# module-level thread-local; restore this rank's tl before
|
||||
# resuming so TensorHandle operator overloads dispatch to
|
||||
# the right _MockTL.
|
||||
TLContext._set_active(tls[s.rank]) # type: ignore[attr-defined]
|
||||
s.g.switch()
|
||||
if s.g.dead:
|
||||
progressed = True
|
||||
TLContext._set_active(None) # type: ignore[attr-defined]
|
||||
# Loose progress check: every scheduling round is counted, and the run
# aborts once the total exceeds max_rounds in a round where no greenlet
# finished. Queue growth is not tracked.
|
||||
round_no += 1
|
||||
if round_no > max_rounds and not progressed:
|
||||
raise RuntimeError(
|
||||
"mock CCL runtime: deadlock detected (no progress for "
|
||||
f"{max_rounds} rounds)"
|
||||
)
|
||||
|
||||
return [
|
||||
s.output if s.output is not None else s._hbm.get(s._slice_addr)
|
||||
for s in self.states
|
||||
]
|
||||
|
||||
|
||||
# ── Public entry ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def run_kernel_in_mock(
|
||||
kernel_fn: Callable,
|
||||
world_size: int,
|
||||
topology: str,
|
||||
inputs: list[np.ndarray],
|
||||
kernel_args: tuple = (),
|
||||
algo_module: Any | None = None,
|
||||
) -> list[np.ndarray]:
|
||||
"""Run a CCL kernel under the mock runtime with no SimPy/fabric.
|
||||
|
||||
Args:
|
||||
kernel_fn: ``kernel(t_ptr, *kernel_args, tl=...)``
|
||||
world_size: number of ranks
|
||||
topology: builtin topology name (e.g. "ring_1d")
|
||||
inputs: per-rank input ndarrays. ``inputs[r]`` becomes rank r's
|
||||
local tile at HBM address ``r * inputs[r].nbytes`` (``t_ptr`` is 0).
|
||||
kernel_args: extra positional args after t_ptr
|
||||
algo_module: optional module providing ``neighbors()`` override
|
||||
|
||||
Returns:
|
||||
Per-rank output ndarrays — whatever the kernel wrote via tl.store
|
||||
(or the original input if the kernel didn't store).
|
||||
"""
|
||||
if len(inputs) != world_size:
|
||||
raise ValueError(f"len(inputs)={len(inputs)} != world_size={world_size}")
|
||||
|
||||
topo_fn = resolve_topology(topology, algo_module=algo_module)
|
||||
states = [
|
||||
_MockRankState(
|
||||
rank=r, world_size=world_size,
|
||||
neighbors=topo_fn(r, world_size),
|
||||
input_arr=inputs[r],
|
||||
)
|
||||
for r in range(world_size)
|
||||
]
|
||||
|
||||
sched = _MockScheduler(states)
|
||||
return sched.run(kernel_fn, kernel_args)
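
The scheduling idea behind the mock runtime (round-robin cooperative tasks plus per-(rank, direction) FIFO queues) can be sketched without greenlet. The block below is an illustration only: it swaps greenlets for plain Python generators, and `run_ring_sum`, `send`, `kernel`, and `step` are hypothetical names that are not part of this diff. It is not the `ring_allreduce` kernel from the repository; it is a naive ring pass-around used to show how send, yield, and recv interleave.

```python
from collections import deque


def ring_1d(rank, world_size):
    """Same E/W ring map as the builtin generator in this diff."""
    return {"E": (rank + 1) % world_size, "W": (rank - 1) % world_size}


def run_ring_sum(values, max_rounds=10_000):
    """Hypothetical helper: sum per-rank lists over a ring, cooperatively."""
    ws = len(values)
    nbrs = [ring_1d(r, ws) for r in range(ws)]
    # One unbounded FIFO per (rank, direction); backpressure is not modeled.
    recv_q = [{d: deque() for d in nbrs[r]} for r in range(ws)]
    out = [None] * ws

    def send(rank, direction, data):
        peer = nbrs[rank][direction]
        # First direction in the peer's map that points back at the sender.
        rev = next(d for d, t in nbrs[peer].items() if t == rank)
        # Append a snapshot so later mutation by the sender is not visible.
        recv_q[peer][rev].append(list(data))

    def kernel(rank):
        acc = list(values[rank])
        fwd = list(values[rank])
        for _ in range(ws - 1):
            send(rank, "E", fwd)
            yield  # give other ranks a turn after sending
            while not recv_q[rank]["W"]:
                yield  # nothing to receive yet
            fwd = recv_q[rank]["W"].popleft()
            acc = [a + b for a, b in zip(acc, fwd)]
        out[rank] = acc

    def step(task):
        """Advance one task; return False once it has finished."""
        try:
            next(task)
            return True
        except StopIteration:
            return False

    alive = [kernel(r) for r in range(ws)]
    rounds = 0
    while alive:
        rounds += 1
        if rounds > max_rounds:
            raise RuntimeError("sketch scheduler: no completion within max_rounds")
        alive = [t for t in alive if step(t)]
    return out


result = run_ring_sum([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]])
```

In this sketch a two-rank ring does not complete: `E` and `W` of each rank name the same peer, the first-match reverse lookup always picks the peer's `E` queue, and a kernel waiting on `W` spins until the round limit raises. `_MockTL.send` uses the same first-match lookup, so the same caveat may apply there.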
|
||||
@@ -0,0 +1,128 @@
"""Builtin neighbor topology generators for CCL backend (ADR-0023 D11).

Each generator takes ``(rank, world_size)`` and returns a
``dict[direction, peer_rank]`` for that rank. ``direction`` is one of
``"N" | "S" | "E" | "W"`` for ring/mesh, or
``"parent" | "child_left" | "child_right"`` for tree topologies.

Algorithm modules may override the generated map by defining a
``neighbors(rank, world_size, neighbor_map) -> dict | None`` function in
the same module (see D11 / D15). ``resolve_topology`` wires these together.
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
NeighborMap = dict[str, int]
|
||||
TopologyFn = Callable[[int, int], NeighborMap]
|
||||
|
||||
|
||||
# ── Builtin generators ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def ring_1d(rank: int, world_size: int) -> NeighborMap:
|
||||
"""1D bidirectional ring (E/W)."""
|
||||
return {
|
||||
"E": (rank + 1) % world_size,
|
||||
"W": (rank - 1) % world_size,
|
||||
}
|
||||
|
||||
|
||||
def ring_1d_unidir(rank: int, world_size: int) -> NeighborMap:
|
||||
"""1D unidirectional ring (E only)."""
|
||||
return {"E": (rank + 1) % world_size}
|
||||
|
||||
|
||||
def mesh_2d(rank: int, world_size: int) -> NeighborMap:
|
||||
"""Square 2D mesh (N/S/E/W).
|
||||
|
||||
Layout: rank = row * side + col, with side = sqrt(world_size).
|
||||
Wrap-around (torus) on all four edges.
|
||||
"""
|
||||
side = int(round(world_size ** 0.5))
|
||||
if side * side != world_size:
|
||||
raise ValueError(
|
||||
f"mesh_2d requires square world_size, got {world_size}"
|
||||
)
|
||||
r, c = divmod(rank, side)
|
||||
return {
|
||||
"N": ((r - 1) % side) * side + c,
|
||||
"S": ((r + 1) % side) * side + c,
|
||||
"W": r * side + (c - 1) % side,
|
||||
"E": r * side + (c + 1) % side,
|
||||
}
|
||||
|
||||
|
||||
def tree_binary(rank: int, world_size: int) -> NeighborMap:
|
||||
"""Binary tree rooted at rank 0.
|
||||
|
||||
Children of rank r are 2r+1 and 2r+2 (if within world_size).
|
||||
Parent of rank r > 0 is (r-1)//2.
|
||||
Returned keys (only those that exist):
|
||||
"parent", "child_left", "child_right"
|
||||
"""
|
||||
n: NeighborMap = {}
|
||||
if rank > 0:
|
||||
n["parent"] = (rank - 1) // 2
|
||||
left = 2 * rank + 1
|
||||
right = 2 * rank + 2
|
||||
if left < world_size:
|
||||
n["child_left"] = left
|
||||
if right < world_size:
|
||||
n["child_right"] = right
|
||||
return n
|
||||
|
||||
|
||||
def none(rank: int, world_size: int) -> NeighborMap:
|
||||
"""Empty map — algorithm's neighbors() must build from scratch."""
|
||||
return {}
|
||||
|
||||
|
||||
_BUILTIN: dict[str, TopologyFn] = {
|
||||
"ring_1d": ring_1d,
|
||||
"ring_1d_unidir": ring_1d_unidir,
|
||||
"mesh_2d": mesh_2d,
|
||||
"tree_binary": tree_binary,
|
||||
"none": none,
|
||||
}
|
||||
|
||||
|
||||
# ── Resolution ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def resolve_topology(
|
||||
name: str, algo_module: Any | None = None,
|
||||
) -> TopologyFn:
|
||||
"""Return a callable ``(rank, world_size) -> NeighborMap``.
|
||||
|
||||
Args:
|
||||
name: builtin topology name from ccl.yaml. Must be one of
|
||||
``ring_1d``, ``ring_1d_unidir``, ``mesh_2d``, ``tree_binary``,
|
||||
or ``none``.
|
||||
algo_module: optional algorithm module. If it defines
|
||||
``neighbors(rank, world_size, neighbor_map)``, that hook is
|
||||
invoked after the builtin to override the result.
|
||||
Returning None from neighbors() leaves the builtin map
|
||||
unchanged; returning a dict replaces it.
|
||||
|
||||
Raises:
|
||||
ValueError: if ``name`` is not a known builtin.
|
||||
"""
|
||||
if name not in _BUILTIN:
|
||||
raise ValueError(
|
||||
f"Unknown topology '{name}'. "
|
||||
f"Available builtins: {list(_BUILTIN)}"
|
||||
)
|
||||
builtin_fn = _BUILTIN[name]
|
||||
override_fn = getattr(algo_module, "neighbors", None) if algo_module else None
|
||||
if override_fn is None or not callable(override_fn):
|
||||
return builtin_fn
|
||||
|
||||
def _wrapped(rank: int, world_size: int) -> NeighborMap:
|
||||
base = builtin_fn(rank, world_size)
|
||||
result = override_fn(rank, world_size, base)
|
||||
if result is None:
|
||||
return base
|
||||
return result
|
||||
|
||||
return _wrapped
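
The neighbor maps and the reverse-direction lookup used at install time can be checked in isolation. The sketch below is a standalone illustration under stated assumptions: `ring_1d` and `mesh_2d` follow the definitions in this diff, while `build_table` and the three-argument `reverse_direction` are hypothetical helpers introduced only for this example (the diff's own `reverse_direction` closes over `neighbor_table`).

```python
from __future__ import annotations


def ring_1d(rank: int, world_size: int) -> dict[str, int]:
    """1D bidirectional ring (E/W)."""
    return {"E": (rank + 1) % world_size, "W": (rank - 1) % world_size}


def mesh_2d(rank: int, world_size: int) -> dict[str, int]:
    """Square 2D torus: rank = row * side + col."""
    side = int(round(world_size ** 0.5))
    if side * side != world_size:
        raise ValueError(f"mesh_2d requires square world_size, got {world_size}")
    r, c = divmod(rank, side)
    return {
        "N": ((r - 1) % side) * side + c,
        "S": ((r + 1) % side) * side + c,
        "W": r * side + (c - 1) % side,
        "E": r * side + (c + 1) % side,
    }


def build_table(topo, world_size: int) -> list[dict[str, int]]:
    """Hypothetical helper: one neighbor map per rank."""
    return [topo(r, world_size) for r in range(world_size)]


def reverse_direction(table, my_rank: int, peer_rank: int) -> str | None:
    """First direction in the peer's map that points back at my_rank."""
    for d, target in table[peer_rank].items():
        if target == my_rank:
            return d
    return None


ring = build_table(ring_1d, 4)
mesh = build_table(mesh_2d, 9)

ring_rank0 = ring[0]                              # neighbors of rank 0
ring_back = reverse_direction(ring, 0, 1)         # how rank 1 sees rank 0
mesh_center = mesh[4]                             # center of the 3x3 torus
mesh_back = reverse_direction(mesh, 4, 1)         # how rank 1 sees rank 4
```

For world sizes of three or more on the ring, and for the 3x3 torus, each direction has a distinct peer, so the first-match lookup returns the opposite direction (`E` pairs with `W`, `N` with `S`).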
|
||||
@@ -0,0 +1,234 @@
"""IPCQ schemas and exceptions (ADR-0023 D2.5, D12, D14 F1).

This module contains the data structures and exceptions used by the
PE-level IPCQ collective communication infrastructure. The host-facing
sideband fan-out message ``IpcqInitMsg`` lives in
``kernbench.runtime_api.kernel`` (alongside other fabric messages),
while all internal token / metadata / command schemas are kept here.

Layering:
    PE_CPU  --IpcqRequest(IpcqSendCmd|IpcqRecvCmd)--> PE_IPCQ
    PE_IPCQ --IpcqDmaToken--> PE_DMA (vc_comm)
    PE_DMA  --IpcqMetaArrival--> PE_IPCQ (atomic, D9)
    PE_IPCQ --IpcqCreditMetadata--> peer PE_IPCQ (fast path, D9)

See ADR-0023 for the full design.
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import simpy
|
||||
|
||||
|
||||
# ── D14 F1: invalid direction exception ──────────────────────────────
|
||||
|
||||
|
||||
class IpcqInvalidDirection(ValueError):
|
||||
"""Raised when a kernel calls tl.send/recv with a direction that
|
||||
has no neighbor installed for this PE."""
|
||||
|
||||
|
||||
# ── D2.5: IpcqEndpoint ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IpcqEndpoint:
"""Everything the sender needs to compute the peer's rx_buffer address (D2.5).
|
||||
|
||||
Sender PE_IPCQ uses this to compute the destination PA for its DMA
|
||||
write into the peer's rx ring buffer slot:
|
||||
|
||||
slot_idx = sender.my_head % peer.n_slots
|
||||
dst_pa = peer.rx_base_pa + slot_idx * peer.slot_size
|
||||
"""
|
||||
|
||||
sip: int # destination SIP
|
||||
cube: int # destination cube
|
||||
pe: int # destination PE (cube-local index)
|
||||
buffer_kind: str # "tcm" | "hbm" | "sram"
|
||||
rx_base_pa: int # peer rx_buffer base PA (PhysAddr.encode())
|
||||
rx_base_va: int # peer rx_buffer base VA (optional, MMU)
|
||||
n_slots: int # peer ring depth (wrap-around modulo)
|
||||
slot_size: int # peer slot size (offset multiplier)
|
||||
|
||||
|
||||
# ── D12: IpcqInitEntry (used by IpcqInitMsg in kernel.py) ────────────
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IpcqInitEntry:
|
||||
"""One direction's neighbor entry that backend installs into a PE_IPCQ
|
||||
via IpcqInitMsg (kernbench.runtime_api.kernel.IpcqInitMsg, D12).
|
||||
"""
|
||||
|
||||
direction: str # "N" | "S" | "E" | "W"
|
||||
peer: IpcqEndpoint # see D2.5
|
||||
my_rx_base_pa: int # this PE's own rx_buffer base
|
||||
my_rx_base_va: int # this PE's own rx_buffer base VA (optional)
|
||||
n_slots: int # this PE's ring depth
|
||||
slot_size: int # this PE's slot size
|
||||
# Credit fast path channel (D9).
|
||||
# Contract: must be a simpy.Store instance dedicated to receiving
|
||||
# IpcqCreditMetadata objects only. Backend wires it once at init time
|
||||
# and the receiving PE_IPCQ owns its consumer side; the sender (peer's
|
||||
# PE_IPCQ) puts IpcqCreditMetadata directly into this store via
|
||||
# _delayed_credit_send. Do not put any other object type.
|
||||
peer_credit_store: "simpy.Store"
|
||||
|
||||
|
||||
# ── D12: IpcqSendCmd (PE_CPU → PE_IPCQ) ──────────────────────────────
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IpcqSendCmd:
|
||||
"""tl.send command issued by the kernel to PE_IPCQ."""
|
||||
|
||||
direction: str # "N" | "S" | "E" | "W"
|
||||
src_addr: int # source data address (TCM/HBM/SRAM)
|
||||
src_space: str # "tcm" | "hbm" | "sram"
|
||||
nbytes: int
|
||||
shape: tuple[int, ...] # data shape (op_log + MemoryStore use)
|
||||
dtype: str
|
||||
handle_id: str # completion tracking
|
||||
data_op: bool = True # ADR-0020 op_log recording flag
|
||||
|
||||
|
||||
# ── D12: IpcqRecvCmd (PE_CPU → PE_IPCQ) ──────────────────────────────
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IpcqRecvCmd:
|
||||
"""tl.recv command issued by the kernel to PE_IPCQ.
|
||||
|
||||
Two modes (recv_mode):
|
||||
"return_slot" — return slot address as-is (default, zero-copy).
|
||||
Kernel uses the slot memory directly.
|
||||
"copy_to_dst" — copy slot data to dst_addr, then return.
|
||||
"""
|
||||
|
||||
direction: str | None # None → round-robin (weak fairness, D4)
|
||||
shape: tuple[int, ...]
|
||||
dtype: str
|
||||
handle_id: str
|
||||
recv_mode: str = "return_slot"
|
||||
dst_addr: int = 0 # used only when recv_mode == "copy_to_dst"
|
||||
dst_space: str = "" # used only when recv_mode == "copy_to_dst"
|
||||
blocking: bool = True
|
||||
data_op: bool = True
|
||||
|
||||
|
||||
# ── D12: IpcqDmaToken (PE_IPCQ → PE_DMA, vc_comm) ───────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class IpcqDmaToken:
|
||||
"""Token sent from PE_IPCQ to PE_DMA (vc_comm channel) carrying both
|
||||
the data move request and the piggyback metadata (ADR-0023 D9).
|
||||
|
||||
Receiving PE_DMA processes this atomically (I6 MUST):
|
||||
1. MemoryStore.write(dst_endpoint.buffer_kind, dst_addr, data)
|
||||
2. Forward IpcqMetaArrival(token=self) to peer PE_IPCQ
|
||||
No yield is allowed between the two steps.
|
||||
|
||||
The ``data`` field is a snapshot taken by the sender's PE_DMA at the
|
||||
moment the send is issued. This preserves "in-flight data" semantics:
|
||||
if the sender mutates its source memory after issuing the send but
|
||||
before arrival, the receiver still gets the snapshot. The snapshot is
|
||||
None for control-only tokens (e.g. credit-only updates).
|
||||
"""
|
||||
|
||||
# ── Data movement (single-hop DMA write) ──
|
||||
src_addr: int
|
||||
src_space: str
|
||||
dst_addr: int # already-computed peer rx slot PA
|
||||
dst_endpoint: IpcqEndpoint # routing target (sip/cube/pe) + buffer_kind
|
||||
nbytes: int
|
||||
handle_id: str # completion notify back to sender PE_IPCQ
|
||||
# Optional shape/dtype carried for op_log + MemoryStore convenience.
|
||||
shape: tuple[int, ...] = ()
|
||||
dtype: str = "f16"
|
||||
# In-flight data snapshot (sender PE_DMA captures this at send time).
|
||||
data: Any = None
|
||||
|
||||
# ── Piggyback metadata (D9) ──
|
||||
sender_seq: int = 0 # monotonic; receiver updates peer_head_cache
|
||||
src_sip: int = 0
|
||||
src_cube: int = 0
|
||||
src_pe: int = 0
|
||||
src_direction: str = "E" # sender-side direction; receiver maps to its own
|
||||
|
||||
data_op: bool = True
|
||||
|
||||
|
||||
# ── D12: IpcqMetaArrival (PE_DMA → PE_IPCQ, intra-PE wire) ──────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class IpcqMetaArrival:
|
||||
"""Posted by receiving PE_DMA into the destination PE's PE_IPCQ inbox
|
||||
in the same SimPy step as the MemoryStore.write (D9, I6 MUST).
|
||||
|
||||
The receiver PE_IPCQ uses ``token.sender_seq`` to update its
|
||||
peer_head_cache for the corresponding direction.
|
||||
"""
|
||||
|
||||
token: IpcqDmaToken
|
||||
|
||||
|
||||
# ── D12: IpcqCreditMetadata (PE_IPCQ → peer PE_IPCQ, fast path) ─────
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IpcqCreditMetadata:
|
||||
"""Credit return — recv-side → send-side fast path (D9).
|
||||
|
||||
Sent by ``PeIpcqComponent._delayed_credit_send`` after a
|
||||
bottleneck-BW based latency, putting the metadata directly into
|
||||
the peer's pre-wired credit store (no fabric routing).
|
||||
"""
|
||||
|
||||
consumer_seq: int # my_tail at recv side (new tail value)
|
||||
src_sip: int # which peer is sending the credit
|
||||
src_cube: int
|
||||
src_pe: int
|
||||
src_direction: str # sender-side direction (peer maps to its own)
|
||||
|
||||
|
||||
# ── Request wrapper (PE_CPU → PE_IPCQ) ───────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class IpcqRequest:
|
||||
"""Wrapper carrying an IpcqSendCmd or IpcqRecvCmd plus a SimPy completion
|
||||
event. Posted by PE_CPU into PE_IPCQ's inbox; PE_IPCQ calls
|
||||
``done.succeed()`` when the request is fully processed.
|
||||
|
||||
For recv requests, the result (slot address, direction, dtype, shape)
|
||||
is written into ``result_data`` so the caller can read it after wait.
|
||||
"""
|
||||
|
||||
command: "IpcqSendCmd | IpcqRecvCmd"
|
||||
done: "simpy.Event"
|
||||
result_data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ── RecvFuture (kernel ↔ runner handshake for tl.recv_async / tl.wait) ─
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecvFuture:
|
||||
"""Opaque future returned by ``tl.recv_async``.
|
||||
|
||||
The KernelRunner attaches a SimPy event and the IpcqRequest in the
|
||||
background; ``tl.wait(future)`` switches back to the runner which
|
||||
yields on the event and resolves the result into a TensorHandle.
|
||||
"""
|
||||
|
||||
cmd: "IpcqRecvCmd"
|
||||
request: Any = None # IpcqRequest (set by runner)
|
||||
event: Any = None # simpy.Event (set by runner)
|
||||
resolved: bool = False
|
||||
result: Any = None # cached TensorHandle after wait()
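
The ring-buffer addressing described in the `IpcqEndpoint` docstring, together with the per-direction base layout from `rx_base` in the install code, reduces to two small formulas. The sketch below is a minimal illustration: `Endpoint` is a hypothetical, trimmed stand-in for `IpcqEndpoint`, and `rx_base_for` and `slot_pa` are illustrative names, not functions from this diff. The numeric values are made up for the example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Endpoint:
    """Hypothetical subset of IpcqEndpoint: only the addressing fields."""
    rx_base_pa: int
    n_slots: int
    slot_size: int


def rx_base_for(pe_base: int, direction_idx: int, n_slots: int, slot_size: int) -> int:
    """Per-direction ring base: each direction owns n_slots * slot_size bytes."""
    bytes_per_direction = n_slots * slot_size
    return pe_base + direction_idx * bytes_per_direction


def slot_pa(peer: Endpoint, sender_head: int) -> int:
    """Destination PA of the slot the sender's head currently maps to."""
    slot_idx = sender_head % peer.n_slots  # wrap around the ring
    return peer.rx_base_pa + slot_idx * peer.slot_size


# Example values (assumed): 4 slots of 256 bytes, direction index 2.
peer = Endpoint(rx_base_pa=rx_base_for(0x1000, 2, 4, 256), n_slots=4, slot_size=256)
first_slot = slot_pa(peer, 0)
wrapped_slot = slot_pa(peer, 5)  # head 5 wraps to slot index 1
```

Because the sender indexes with its own head modulo the peer's `n_slots`, the sender must not advance more than `n_slots` ahead of the receiver's tail; in the design described above that bound is what the credit return path communicates.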
|
||||
@@ -33,6 +33,7 @@ class TensorHandle:
|
||||
dtype: str
|
||||
nbytes: int # total byte size
|
||||
data: object = None # reserved for validate mode
|
||||
space: str = "tcm" # MemoryStore space ("tcm" | "hbm" | "sram")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@@ -42,9 +42,30 @@ class PeCpuComponent(ComponentBase):
|
||||
self._cube_idx = int(parts[1].replace("cube", ""))
|
||||
except (IndexError, ValueError):
|
||||
self._cube_idx = 0
|
||||
-# num_cubes from spec (for tl.program_id(axis=1))
|
||||
# num_cubes from spec (for tl.program_id(axis=1) — ADR-0022)
|
||||
spec = ctx.spec if ctx else {}
|
||||
-self._num_cubes = spec.get("system", {}).get("sips", {}).get("cubes_per_sip", 1)
|
||||
cube_mesh = spec.get("sip", {}).get("cube_mesh", {})
|
||||
if cube_mesh:
|
||||
self._num_cubes = int(cube_mesh.get("w", 1)) * int(cube_mesh.get("h", 1))
|
||||
else:
|
||||
self._num_cubes = (
|
||||
spec.get("system", {}).get("sips", {}).get("cubes_per_sip", 1)
|
||||
)
|
||||
# PE-local scratch for kernel math output handles (ADR-0020 D3
|
||||
# extension; reserved portion of TCM addressed via a synthetic
|
||||
# MemoryStore key, not the real PA encoder).
|
||||
pe_template = spec.get("cube", {}).get("pe_template", {})
|
||||
tcm_attrs = pe_template.get("components", {}).get("pe_tcm", {}).get("attrs", {})
|
||||
scratch_mb = float(tcm_attrs.get("kernel_scratch_mb", 1))
|
||||
self._tl_scratch_size = int(scratch_mb * (1 << 20))
|
||||
# PE-unique base address — high bit pattern to avoid collision with
|
||||
# IPCQ ring buffers (which use bit 60).
|
||||
self._tl_scratch_base = (
|
||||
(1 << 61)
|
||||
| (self._sip_idx << 40)
|
||||
| (self._cube_idx << 32)
|
||||
| (self._pe_idx << 24)
|
||||
)
|
||||
|
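Review note: the scratch-base expression above reads as a bit-field layout. The following is a minimal sketch of that layout, assuming the field widths implied by the shifts (8 bits each for `pe` and `cube`, `sip` from bit 40 upward). The helper names `scratch_base` and `decode_scratch_base` are hypothetical and do not appear in the diff.

```python
SCRATCH_TAG = 1 << 61  # marks kernel scratch in this sketch
IPCQ_TAG = 1 << 60     # the diff's comment says IPCQ ring buffers use bit 60


def scratch_base(sip: int, cube: int, pe: int) -> int:
    """Compose a PE-unique synthetic base using the same shifts as the diff."""
    return SCRATCH_TAG | (sip << 40) | (cube << 32) | (pe << 24)


def decode_scratch_base(addr: int) -> tuple[int, int, int]:
    """Recover (sip, cube, pe), assuming cube and pe each fit in 8 bits."""
    pe = (addr >> 24) & 0xFF
    cube = (addr >> 32) & 0xFF
    sip = (addr >> 40) & 0xFFFFF  # assumption: sip occupies bits 40..59
    return sip, cube, pe


base = scratch_base(sip=1, cube=2, pe=3)
window_bytes = 1 << 24  # offsets below this stay inside one PE's window
```

Under these assumptions the default `kernel_scratch_mb` of 1 (1 MiB) sits well inside the 16 MiB window below the `pe` field; a scratch size above 16 MiB would run into the next PE's base.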
||||
def _find_shard(self, shards: tuple) -> Any:
|
||||
"""Find shard matching this PE's (sip, cube, pe). Fallback to positional index."""
|
||||
@@ -146,6 +167,8 @@ class PeCpuComponent(ComponentBase):
|
||||
scheduler_id=scheduler_id,
|
||||
out_ports=self.out_ports,
|
||||
store=store,
|
||||
scratch_base=self._tl_scratch_base,
|
||||
scratch_size=self._tl_scratch_size,
|
||||
)
|
||||
yield from runner.run(env, kernel_fn, kernel_args, num_programs)
|
||||
return getattr(runner, "_composite_results", [])
|
||||
|
||||
@@ -106,18 +106,131 @@ class PeDmaComponent(PeEngineBase):
|
||||
pe_txn.done.succeed()
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
|
||||
"""Handle TileToken (pipeline), PeInternalTxn (legacy), and Transaction (fabric)."""
|
||||
"""Handle TileToken (pipeline), PeInternalTxn (legacy), IpcqDmaToken,
|
||||
and Transaction (fabric)."""
|
||||
from kernbench.common.ipcq_types import IpcqDmaToken
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
from kernbench.components.builtin.pe_types import TileToken
|
||||
|
||||
while True:
|
||||
msg: Any = yield self._inbox.get()
|
||||
if isinstance(msg, TileToken):
|
||||
if isinstance(msg, IpcqDmaToken):
|
||||
# Outbound: IPCQ token from local PE_IPCQ → forward via fabric
|
||||
env.process(self._handle_ipcq_outbound(env, msg))
|
||||
elif isinstance(msg, TileToken):
|
||||
env.process(self._pipeline_process(env, msg))
|
||||
elif isinstance(msg, PeInternalTxn):
|
||||
env.process(self._handle_with_hooks(env, msg))
|
||||
else:
|
||||
env.process(self._forward_txn(env, msg))
|
||||
# Transaction (or unknown). May carry IpcqDmaToken inbound.
|
||||
req = getattr(msg, "request", None)
|
||||
if isinstance(req, IpcqDmaToken):
|
||||
env.process(self._handle_ipcq_inbound(env, msg))
|
||||
else:
|
||||
env.process(self._forward_txn(env, msg))
|
||||
|
||||
# ── IPCQ outbound (PE_IPCQ → PE_DMA → fabric) ───────────────────
|
||||
|
||||
def _handle_ipcq_outbound(self, env: simpy.Environment, token: Any) -> Generator:
|
||||
"""Forward IpcqDmaToken from local PE_IPCQ through the fabric to peer
|
||||
PE_DMA. ADR-0023 D8 (vc_comm channel)."""
|
||||
if self.ctx is None:
|
||||
return # nothing to do
|
||||
peer = token.dst_endpoint
|
||||
peer_pe_dma = f"sip{peer.sip}.cube{peer.cube}.pe{peer.pe}.pe_dma"
|
||||
|
||||
# Snapshot the source data at send time (D9 in-flight semantics).
|
||||
# Without this, the receiver could read stale or future data if the
|
||||
# sender mutates src_addr between send issue and DMA arrival.
|
||||
store = getattr(self.ctx, "memory_store", None)
|
||||
if store is not None and token.data is None:
|
||||
try:
|
||||
snap = store.read(
|
||||
token.src_space, token.src_addr,
|
||||
shape=token.shape, dtype=token.dtype,
|
||||
)
|
||||
# Copy so later mutations to src_addr don't affect the snapshot.
|
||||
token.data = snap.copy() if hasattr(snap, "copy") else snap
|
||||
except Exception:
|
||||
token.data = None
|
||||
|
||||
# Record the IPCQ copy in op_log at OUTBOUND time. ADR-0020 D6:
|
||||
# Phase 2 replays the copy in t_start order; using outbound time
|
||||
# (rather than inbound) ensures the copy executes before any later
|
||||
# local op at the sender that might overwrite token.src_addr (e.g.
|
||||
# a tl.store after a recv).
|
||||
if self._op_logger is not None:
|
||||
try:
|
||||
self._op_logger.record_copy(
|
||||
t_start=float(env.now), t_end=float(env.now),
|
||||
component_id=self.node.id,
|
||||
src_space=token.src_space, src_addr=token.src_addr,
|
||||
dst_space=peer.buffer_kind,
|
||||
dst_addr=token.dst_addr,
|
||||
shape=token.shape, dtype=token.dtype, nbytes=token.nbytes,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
path = self.ctx.router.find_path(self._pe_prefix, peer_pe_dma)
|
||||
except Exception:
|
||||
return
|
||||
drain_ns = self.ctx.compute_drain_ns(path, token.nbytes)
|
||||
|
||||
sub_done = env.event()
|
||||
sub_txn = Transaction(
|
||||
request=token, path=path, step=0,
|
||||
nbytes=token.nbytes, done=sub_done, drain_ns=drain_ns,
|
||||
)
|
||||
if len(path) > 1:
|
||||
next_hop = path[1]
|
||||
if next_hop in self.out_ports:
|
||||
yield self.out_ports[next_hop].put(sub_txn.advance())
|
||||
else:
|
||||
return
|
||||
# Note: don't wait on sub_done here — fire-and-forget for vc_comm.
|
||||
# IPCQ slot bookkeeping (peer_head) was already updated by PE_IPCQ;
|
||||
# backpressure is via credit return, not via this DMA's completion.
|
||||
|
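Review note: the D9 snapshot comment in `_handle_ipcq_outbound` can be illustrated without SimPy. The sketch below uses a plain dict as a hypothetical stand-in for `MemoryStore` and contrasts capturing the payload at send time with reading the source on arrival. None of these helper names come from the diff.

```python
import copy

# Hypothetical stand-in for MemoryStore: (space, addr) -> payload
store = {("tcm", 0x100): [1.0, 2.0, 3.0]}


def send_with_snapshot(space: str, addr: int) -> dict:
    """Capture the payload when the send is issued."""
    return {"data": copy.copy(store[(space, addr)])}


def send_by_reference(space: str, addr: int) -> dict:
    """Defer the read: the payload is whatever the source holds on arrival."""
    return {"src": (space, addr)}


snap_token = send_with_snapshot("tcm", 0x100)
ref_token = send_by_reference("tcm", 0x100)

# The sender overwrites its buffer while both tokens are still in flight.
store[("tcm", 0x100)][:] = [9.0, 9.0, 9.0]

delivered_snapshot = snap_token["data"]
delivered_by_ref = store[ref_token["src"]]
```

Only the snapshot token delivers the values that were in the buffer when the send was issued, which is the behavior the comment describes.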
||||
# ── IPCQ inbound (fabric → PE_DMA → MemoryStore + PE_IPCQ) ──────
|
||||
|
||||
def _handle_ipcq_inbound(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||
"""At destination PE_DMA: atomically write data and forward metadata.
|
||||
|
||||
I6 (MUST): no SimPy yield between MemoryStore.write and the
|
||||
IpcqMetaArrival put into PE_IPCQ.
|
||||
"""
|
||||
from kernbench.common.ipcq_types import IpcqMetaArrival
|
||||
|
||||
token = txn.request
|
||||
|
||||
# ── ATOMIC: do not introduce yield between these two operations ──
|
||||
# 1. Move data via MemoryStore (single-hop DMA write).
|
||||
# Prefer the in-flight snapshot stashed by the sender PE_DMA;
|
||||
# fall back to a fresh read of src_addr if no snapshot is present
|
||||
# (e.g. control-only token).
|
||||
store = getattr(self.ctx, "memory_store", None) if self.ctx else None
|
||||
if store is not None:
|
||||
try:
|
||||
data = token.data
|
||||
if data is None:
|
||||
data = store.read(
|
||||
token.src_space, token.src_addr,
|
||||
shape=token.shape, dtype=token.dtype,
|
||||
)
|
||||
store.write(token.dst_endpoint.buffer_kind, token.dst_addr, data)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. Forward IpcqMetaArrival to local PE_IPCQ
|
||||
ipcq_id = f"{self._pe_prefix}.pe_ipcq"
|
||||
if ipcq_id in self.out_ports:
|
||||
yield self.out_ports[ipcq_id].put(IpcqMetaArrival(token=token))
|
||||
# ─────────────────────────────────────────────────────────────────
|
||||
|
||||
if not txn.done.triggered:
|
||||
txn.done.succeed()
|
||||
|
||||
def _pipeline_process(self, env: simpy.Environment, token: Any) -> Generator:
|
||||
"""Pipeline mode: DMA read/write via fabric, then self-route."""
|
||||
|
||||
@@ -0,0 +1,455 @@
|
||||
"""PE_IPCQ component (ADR-0023): per-PE IPCQ control plane.
|
||||
|
||||
Responsibilities:
|
||||
- Hold per-direction queue pair state (my_head, my_tail,
|
||||
peer_head_cache, peer_tail_cache, ring buffer addresses)
|
||||
- Process IpcqInitMsg from backend to install neighbor table
|
||||
- Handle IpcqRequest(IpcqSendCmd) from PE_CPU:
|
||||
compute peer slot address, check backpressure, forward
|
||||
IpcqDmaToken to PE_DMA (vc_comm)
|
||||
- Handle IpcqRequest(IpcqRecvCmd) from PE_CPU:
|
||||
wait for data arrival, return slot address (or copy to dst),
|
||||
send fast-path credit return
|
||||
- Handle IpcqMetaArrival from PE_DMA: update peer_head_cache, wake recv
|
||||
- Handle IpcqCreditMetadata via own credit_inbox: update peer_tail_cache,
|
||||
wake send
|
||||
|
||||
PE_IPCQ does NOT move data — it forwards IpcqDmaToken to PE_DMA which
|
||||
performs the actual fabric DMA.
|
||||
|
||||
Credit return uses a fast path: PE_IPCQ creates a SimPy process with a
|
||||
bottleneck-BW based latency, then puts IpcqCreditMetadata directly into
|
||||
the peer's pre-wired credit_store.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.common.ipcq_types import (
|
||||
IpcqCreditMetadata,
|
||||
IpcqDmaToken,
|
||||
IpcqInvalidDirection,
|
||||
IpcqMetaArrival,
|
||||
IpcqRecvCmd,
|
||||
IpcqRequest,
|
||||
IpcqSendCmd,
|
||||
)
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.runtime_api.kernel import IpcqInitMsg
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
_DIR_ORDER: tuple[str, ...] = ("N", "S", "E", "W", "parent", "child_left", "child_right")
|
||||
|
||||
|
||||
class PeIpcqComponent(ComponentBase):
|
||||
"""PE_IPCQ: ring buffer pointer + neighbor management for CCL.
|
||||
|
||||
Owned by one PE; talks to PE_DMA via out_ports[<pe_dma_id>] and
|
||||
receives credit return metadata via the public ``credit_inbox``
|
||||
SimPy Store (wired by backend at IpcqInitMsg installation time).
|
||||
"""
|
||||
|
||||
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
|
||||
super().__init__(node, ctx)
|
||||
# Strict shape/dtype validation (D14 F2). Off by default.
|
||||
self._strict: bool = bool(node.attrs.get("strict_validation", False))
|
||||
# direction → list of received tokens (for strict-mode peek of next slot)
|
||||
self._arrived_tokens: dict[str, list] = {}
|
||||
# Parse self (sip, cube, pe) from node id, e.g. "sip0.cube0.pe0.pe_ipcq"
|
||||
self._pe_prefix: str = node.id.rsplit(".", 1)[0]
|
||||
parts = self._pe_prefix.split(".")
|
||||
try:
|
||||
self._self_sip = int(parts[0].replace("sip", ""))
|
||||
except (IndexError, ValueError):
|
||||
self._self_sip = 0
|
||||
try:
|
||||
self._self_cube = int(parts[1].replace("cube", ""))
|
||||
except (IndexError, ValueError):
|
||||
self._self_cube = 0
|
||||
try:
|
||||
self._self_pe = int(parts[2].replace("pe", ""))
|
||||
except (IndexError, ValueError):
|
||||
self._self_pe = 0
|
||||
|
||||
self._dma_node_id = f"{self._pe_prefix}.pe_dma"
|
||||
# direction → state dict (see _install_neighbors for shape)
|
||||
self._queue_pairs: dict[str, dict[str, Any]] = {}
|
||||
self._installed = False
|
||||
self._buffer_kind: str = "tcm"
|
||||
self._backpressure_mode: str = "sleep"
|
||||
self._credit_size_bytes: int = 16
|
||||
# waiters for recv (per direction) and any-direction (for round-robin)
|
||||
self._recv_waiters: dict[str, list[simpy.Event]] = {}
|
||||
self._any_recv_waiters: list[simpy.Event] = []
|
||||
# waiters for send backpressure (per direction)
|
||||
self._send_waiters: dict[str, list[simpy.Event]] = {}
|
||||
# round-robin cursor over installed directions
|
||||
self._rr_dirs: list[str] = []
|
||||
self._rr_cursor: int = 0
|
||||
# credit_inbox is created in start() once env is available
|
||||
self._credit_inbox: simpy.Store | None = None
|
||||
|
||||
# ── Public ──
|
||||
|
||||
@property
|
||||
def credit_inbox(self) -> simpy.Store:
|
||||
"""SimPy Store that backend wires as ``peer_credit_store`` on
|
||||
every remote sender targeting this PE. Used by D9 fast path."""
|
||||
assert self._credit_inbox is not None, "PE_IPCQ not started yet"
|
||||
return self._credit_inbox
|
||||
|
||||
@property
|
||||
def queue_pairs(self) -> dict[str, dict[str, Any]]:
|
||||
"""Test/debug accessor."""
|
||||
return self._queue_pairs
|
||||
|
||||
# ── Lifecycle ──
|
||||
|
||||
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
|
||||
yield env.timeout(0)
|
||||
|
||||
def start(self, env: simpy.Environment) -> None:
|
||||
# Create credit_inbox even if there are no in_ports yet
|
||||
if self._credit_inbox is None:
|
||||
self._credit_inbox = simpy.Store(env)
|
||||
# If no in_ports were wired (e.g. unit test), still spin up workers
|
||||
if not self.in_ports:
|
||||
self._inbox = simpy.Store(env)
|
||||
super().start(env)
|
||||
env.process(self._credit_worker(env))
|
||||
|
||||
# ── Worker (override of ComponentBase._worker) ──
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
|
||||
from kernbench.runtime_api.kernel import IpcqInitMsg
|
||||
|
||||
while True:
|
||||
msg: Any = yield self._inbox.get()
|
||||
|
||||
# IpcqInitMsg may arrive wrapped in a transaction (with .request)
|
||||
# or directly.
|
||||
request_obj = getattr(msg, "request", None)
|
||||
if isinstance(request_obj, IpcqInitMsg):
|
||||
self._install_neighbors(request_obj)
|
||||
done = getattr(msg, "done", None)
|
||||
if done is not None and not done.triggered:
|
||||
done.succeed()
|
||||
continue
|
||||
if isinstance(msg, IpcqInitMsg):
|
||||
self._install_neighbors(msg)
|
||||
continue
|
||||
|
||||
if isinstance(msg, IpcqMetaArrival):
|
||||
self._handle_meta_arrival(msg)
|
||||
continue
|
||||
|
||||
if isinstance(msg, IpcqRequest):
|
||||
env.process(self._handle_request(env, msg))
|
||||
continue
|
||||
|
||||
# Unknown message: forward via the base-class fallback
|
||||
env.process(self._forward_txn(env, msg))
|
||||
|
||||
# ── Init ──
|
||||
|
||||
def _install_neighbors(self, msg: IpcqInitMsg) -> None:
|
||||
self._installed = True
|
||||
self._buffer_kind = msg.buffer_kind
|
||||
self._backpressure_mode = msg.backpressure_mode
|
||||
self._credit_size_bytes = msg.credit_size_bytes
|
||||
for entry in msg.entries:
|
||||
self._queue_pairs[entry.direction] = {
|
||||
"peer": entry.peer,
|
||||
"my_rx_base_pa": entry.my_rx_base_pa,
|
||||
"my_rx_base_va": entry.my_rx_base_va,
|
||||
"n_slots": entry.n_slots,
|
||||
"slot_size": entry.slot_size,
|
||||
"peer_credit_store": entry.peer_credit_store,
|
||||
"my_head": 0,
|
||||
"my_tail": 0,
|
||||
"peer_head_cache": 0,
|
||||
"peer_tail_cache": 0,
|
||||
}
|
||||
self._recv_waiters.setdefault(entry.direction, [])
|
||||
self._send_waiters.setdefault(entry.direction, [])
|
||||
# Reset round-robin order to a stable canonical sequence
|
||||
self._rr_dirs = [d for d in _DIR_ORDER if d in self._queue_pairs]
|
||||
self._rr_cursor = 0
|
||||
|
||||
# ── Send ──
|
||||
|
||||
def _handle_request(self, env: simpy.Environment, req: IpcqRequest) -> Generator:
|
||||
cmd = req.command
|
||||
if isinstance(cmd, IpcqSendCmd):
|
||||
yield from self._handle_send(env, req, cmd)
|
||||
elif isinstance(cmd, IpcqRecvCmd):
|
||||
yield from self._handle_recv(env, req, cmd)
|
||||
|
||||
def _handle_send(
|
||||
self, env: simpy.Environment, req: IpcqRequest, cmd: IpcqSendCmd,
|
||||
) -> Generator:
|
||||
if cmd.direction not in self._queue_pairs:
|
||||
raise IpcqInvalidDirection(
|
||||
f"PE {self._pe_prefix}: direction {cmd.direction!r} not installed"
|
||||
)
|
||||
qp = self._queue_pairs[cmd.direction]
|
||||
peer = qp["peer"]
|
||||
|
||||
# Backpressure: wait while ring full
|
||||
while (qp["my_head"] - qp["peer_tail_cache"]) >= peer.n_slots:
|
||||
wait_event = env.event()
|
||||
self._send_waiters[cmd.direction].append(wait_event)
|
||||
yield wait_event
|
||||
|
||||
# Compute peer slot address
|
||||
slot_idx = qp["my_head"] % peer.n_slots
|
||||
dst_pa = peer.rx_base_pa + slot_idx * peer.slot_size
|
||||
|
||||
token = IpcqDmaToken(
|
||||
src_addr=cmd.src_addr,
|
||||
src_space=cmd.src_space,
|
||||
dst_addr=dst_pa,
|
||||
dst_endpoint=peer,
|
||||
nbytes=cmd.nbytes,
|
||||
handle_id=cmd.handle_id,
|
||||
shape=cmd.shape,
|
||||
dtype=cmd.dtype,
|
||||
sender_seq=qp["my_head"],
|
||||
src_sip=self._self_sip,
|
||||
src_cube=self._self_cube,
|
||||
src_pe=self._self_pe,
|
||||
src_direction=cmd.direction,
|
||||
)
|
||||
|
||||
# Forward to PE_DMA (vc_comm)
|
||||
yield self.out_ports[self._dma_node_id].put(token)
|
||||
qp["my_head"] += 1
|
||||
# Diagnostics trace (D14)
|
||||
from kernbench.ccl import diagnostics
|
||||
if diagnostics.trace_enabled():
|
||||
diagnostics.log_send(
|
||||
t_ns=float(env.now), sender=self._pe_prefix,
|
||||
direction=cmd.direction, nbytes=cmd.nbytes,
|
||||
sender_seq=qp["my_head"] - 1,
|
||||
)
|
||||
if not req.done.triggered:
|
||||
req.done.succeed()
|
||||
|
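Review note: the send path above combines a ring-full check, a slot-address computation, and a credit update. A minimal sketch of that bookkeeping follows, with the SimPy waiting replaced by a `None` return. The class `SendSide` and its method names are hypothetical; the field names mirror the queue-pair state in the diff.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SendSide:
    rx_base_pa: int           # base address of the peer's receive ring
    n_slots: int
    slot_size: int
    my_head: int = 0          # sends issued so far (monotonic counter)
    peer_tail_cache: int = 0  # last consumer_seq reported by a credit

    def is_full(self) -> bool:
        return (self.my_head - self.peer_tail_cache) >= self.n_slots

    def try_send(self) -> Optional[int]:
        """Return the destination slot address, or None when the ring is full."""
        if self.is_full():
            return None
        dst = self.rx_base_pa + (self.my_head % self.n_slots) * self.slot_size
        self.my_head += 1
        return dst

    def on_credit(self, consumer_seq: int) -> None:
        # max() keeps the cache monotonic even if credits arrive out of order
        self.peer_tail_cache = max(self.peer_tail_cache, consumer_seq)


qp = SendSide(rx_base_pa=0x1000, n_slots=2, slot_size=0x100)
addrs = [qp.try_send(), qp.try_send(), qp.try_send()]  # third send hits a full ring
qp.on_credit(1)                                        # receiver consumed one slot
after_credit = qp.try_send()
```

Because `my_head` and `peer_tail_cache` are monotonic counters rather than wrapped indices, the full check is a plain subtraction and the modulo is applied only when computing the slot address.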
||||
# ── Recv ──
|
||||
|
||||
def _handle_recv(
|
||||
self, env: simpy.Environment, req: IpcqRequest, cmd: IpcqRecvCmd,
|
||||
) -> Generator:
|
||||
if cmd.direction is None:
|
||||
direction = yield from self._wait_any_direction(env)
|
||||
else:
|
||||
if cmd.direction not in self._queue_pairs:
|
||||
raise IpcqInvalidDirection(
|
||||
f"PE {self._pe_prefix}: direction {cmd.direction!r} not installed"
|
||||
)
|
||||
direction = cmd.direction
|
||||
qp = self._queue_pairs[direction]
|
||||
while qp["peer_head_cache"] <= qp["my_tail"]:
|
||||
wait_event = env.event()
|
||||
self._recv_waiters[direction].append(wait_event)
|
||||
yield wait_event
|
||||
|
||||
qp = self._queue_pairs[direction]
|
||||
slot_idx = qp["my_tail"] % qp["n_slots"]
|
||||
slot_addr = qp["my_rx_base_pa"] + slot_idx * qp["slot_size"]
|
||||
|
||||
# Strict validation (D14 F2): peek the next-arrived token's metadata
|
||||
# against the recv command's expected shape/dtype/nbytes.
|
||||
arrived = self._arrived_tokens.get(direction, [])
|
||||
if arrived:
|
||||
front = arrived.pop(0)
|
||||
if self._strict:
|
||||
expected_nbytes = self._nbytes_for(cmd.shape, cmd.dtype)
|
||||
if front.dtype != cmd.dtype:
|
||||
raise ValueError(
|
||||
f"PE_IPCQ {self._pe_prefix} recv strict: dtype mismatch — "
|
||||
f"sender={front.dtype} recv={cmd.dtype}"
|
||||
)
|
||||
if front.shape != cmd.shape:
|
||||
raise ValueError(
|
||||
f"PE_IPCQ {self._pe_prefix} recv strict: shape mismatch — "
|
||||
f"sender={front.shape} recv={cmd.shape}"
|
||||
)
|
||||
if front.nbytes != expected_nbytes:
|
||||
raise ValueError(
|
||||
f"PE_IPCQ {self._pe_prefix} recv strict: nbytes mismatch — "
|
||||
f"sender={front.nbytes} recv={expected_nbytes}"
|
||||
)
|
||||
|
||||
req.result_data["src_space"] = self._buffer_kind
|
||||
req.result_data["src_addr"] = slot_addr
|
||||
req.result_data["direction"] = direction
|
||||
req.result_data["dtype"] = cmd.dtype
|
||||
req.result_data["shape"] = cmd.shape
|
||||
req.result_data["nbytes"] = self._nbytes_for(cmd.shape, cmd.dtype)
|
||||
|
||||
# copy_to_dst mode: rebind the result handle to (dst_space, dst_addr).
|
||||
# When op_log is disabled, we also do the actual data move now;
|
||||
# when op_log is enabled, Phase 2 replays the slot→dst copy from
|
||||
# the op_log entry below so we don't pollute the slot in Phase 1.
|
||||
if cmd.recv_mode == "copy_to_dst" and self.ctx is not None:
|
||||
req.result_data["src_space"] = cmd.dst_space
|
||||
req.result_data["src_addr"] = cmd.dst_addr
|
||||
store = getattr(self.ctx, "memory_store", None)
|
||||
if store is not None and self._op_logger is None:
|
||||
try:
|
||||
data = store.read(self._buffer_kind, slot_addr, shape=cmd.shape, dtype=cmd.dtype)
|
||||
store.write(cmd.dst_space, cmd.dst_addr, data)
|
||||
except Exception:
|
||||
pass
|
||||
if self._op_logger is not None:
|
||||
# Record slot → dst copy for Phase 2 replay (ADR-0023 D9.5).
|
||||
try:
|
||||
self._op_logger.record_copy(
|
||||
t_start=float(env.now), t_end=float(env.now),
|
||||
component_id=self.node.id,
|
||||
src_space=self._buffer_kind, src_addr=slot_addr,
|
||||
dst_space=cmd.dst_space, dst_addr=cmd.dst_addr,
|
||||
shape=cmd.shape, dtype=cmd.dtype,
|
||||
nbytes=self._nbytes_for(cmd.shape, cmd.dtype),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
qp["my_tail"] += 1
|
||||
|
||||
# Diagnostics trace (D14)
|
||||
from kernbench.ccl import diagnostics
|
||||
if diagnostics.trace_enabled():
|
||||
diagnostics.log_recv(
|
||||
t_ns=float(env.now), receiver=self._pe_prefix,
|
||||
direction=direction,
|
||||
nbytes=req.result_data.get("nbytes", 0),
|
||||
)
|
||||
|
||||
# Fast path credit return — bottleneck BW based latency
|
||||
env.process(
|
||||
self._delayed_credit_send(env, direction, qp["peer_credit_store"], qp["my_tail"])
|
||||
)
|
||||
|
||||
if not req.done.triggered:
|
||||
req.done.succeed()
|
||||
|
||||
def _wait_any_direction(self, env: simpy.Environment) -> Generator:
|
||||
"""Round-robin scan over installed directions; wait until at least one
|
||||
has data. Returns the chosen direction (str)."""
|
||||
if not self._rr_dirs:
|
||||
raise IpcqInvalidDirection(
|
||||
f"PE {self._pe_prefix}: no neighbors installed"
|
||||
)
|
||||
while True:
|
||||
n = len(self._rr_dirs)
|
||||
for i in range(n):
|
||||
idx = (self._rr_cursor + i) % n
|
||||
d = self._rr_dirs[idx]
|
||||
qp = self._queue_pairs[d]
|
||||
if qp["peer_head_cache"] > qp["my_tail"]:
|
||||
self._rr_cursor = (idx + 1) % n
|
||||
return d
|
||||
# Nothing available — wait until any arrival
|
||||
wait_event = env.event()
|
||||
self._any_recv_waiters.append(wait_event)
|
||||
yield wait_event
|
||||
|
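Review note: the any-direction scan can be sketched as a pure function, which makes the fairness property easier to see. `pick_direction` and the `pending` counts below are hypothetical; the cursor update matches `_wait_any_direction`.

```python
def pick_direction(order, ready, cursor):
    """Scan `order` starting at `cursor`.

    Returns (direction, new_cursor), or (None, cursor) if nothing is ready.
    """
    n = len(order)
    for i in range(n):
        idx = (cursor + i) % n
        if ready(order[idx]):
            return order[idx], (idx + 1) % n
    return None, cursor


order = ["N", "S", "E", "W"]
pending = {"N": 2, "S": 0, "E": 1, "W": 0}  # hypothetical unread message counts

picks, cursor = [], 0
for _ in range(3):
    d, cursor = pick_direction(order, lambda k: pending[k] > 0, cursor)
    pending[d] -= 1
    picks.append(d)
```

Advancing the cursor past the chosen direction means a busy neighbor cannot starve the others: here `E` is served before the second message from `N`.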
||||
# ── Metadata arrival from PE_DMA (D9) ──
|
||||
|
||||
def _handle_meta_arrival(self, msg: IpcqMetaArrival) -> None:
|
||||
token = msg.token
|
||||
sender_key = (token.src_sip, token.src_cube, token.src_pe)
|
||||
for d, qp in self._queue_pairs.items():
|
||||
p = qp["peer"]
|
||||
if (p.sip, p.cube, p.pe) == sender_key:
|
||||
qp["peer_head_cache"] = max(qp["peer_head_cache"], token.sender_seq + 1)
|
||||
# Track arrived token for strict-mode peek
|
||||
self._arrived_tokens.setdefault(d, []).append(token)
|
||||
# Wake any blocked recv on this direction
|
||||
waiters = self._recv_waiters.get(d, [])
|
||||
self._recv_waiters[d] = []
|
||||
for ev in waiters:
|
||||
if not ev.triggered:
|
||||
ev.succeed()
|
||||
# Wake any-direction waiters
|
||||
any_waiters = self._any_recv_waiters
|
||||
self._any_recv_waiters = []
|
||||
for ev in any_waiters:
|
||||
if not ev.triggered:
|
||||
ev.succeed()
|
||||
return
|
||||
# Unknown sender — silently drop (could log)
|
||||
|
||||
# ── Credit return (fast path) ──
|
||||
|
||||
def _credit_worker(self, env: simpy.Environment) -> Generator:
|
||||
"""Process IpcqCreditMetadata from credit_inbox."""
|
||||
assert self._credit_inbox is not None
|
||||
while True:
|
||||
credit: IpcqCreditMetadata = yield self._credit_inbox.get()
|
||||
sender_key = (credit.src_sip, credit.src_cube, credit.src_pe)
|
||||
for d, qp in self._queue_pairs.items():
|
||||
p = qp["peer"]
|
||||
if (p.sip, p.cube, p.pe) == sender_key:
|
||||
qp["peer_tail_cache"] = max(qp["peer_tail_cache"], credit.consumer_seq)
|
||||
# Wake any blocked send on this direction
|
||||
waiters = self._send_waiters.get(d, [])
|
||||
self._send_waiters[d] = []
|
||||
for ev in waiters:
|
||||
if not ev.triggered:
|
||||
ev.succeed()
|
||||
break
|
||||
|
||||
def _delayed_credit_send(
|
||||
self,
|
||||
env: simpy.Environment,
|
||||
direction: str,
|
||||
peer_credit_store: simpy.Store,
|
||||
new_tail: int,
|
||||
) -> Generator:
|
||||
"""Wait bottleneck-BW latency, then put IpcqCreditMetadata into peer
|
||||
credit store (D9 fast path)."""
|
||||
latency_ns = self._credit_latency_ns(direction)
|
||||
if latency_ns > 0:
|
||||
yield env.timeout(latency_ns)
|
||||
meta = IpcqCreditMetadata(
|
||||
consumer_seq=new_tail,
|
||||
src_sip=self._self_sip,
|
||||
src_cube=self._self_cube,
|
||||
src_pe=self._self_pe,
|
||||
src_direction=direction,
|
||||
)
|
||||
yield peer_credit_store.put(meta)
|
||||
|
||||
def _credit_latency_ns(self, direction: str) -> float:
|
||||
"""Compute credit fast path latency = credit_size / bottleneck_bw.
|
||||
|
||||
Falls back to 0 when ctx/router is unavailable (unit-test mode).
|
||||
"""
|
||||
if self.ctx is None:
|
||||
return 0.0
|
||||
qp = self._queue_pairs[direction]
|
||||
peer = qp["peer"]
|
||||
peer_pe_prefix = f"sip{peer.sip}.cube{peer.cube}.pe{peer.pe}"
|
||||
try:
|
||||
path = self.ctx.router.find_path(self._pe_prefix, peer_pe_prefix)
|
||||
return self.ctx.compute_drain_ns(path, self._credit_size_bytes)
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
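Review note: `_credit_latency_ns` delegates to `ctx.compute_drain_ns`, whose body is not part of this diff. The sketch below only illustrates the docstring's formula, credit_size / bottleneck_bw, under the assumption that the bottleneck is the minimum link bandwidth along the path. The function name and the bandwidth list are hypothetical.

```python
def credit_latency_ns(credit_size_bytes: int, link_bw_gb_per_s: list[float]) -> float:
    """credit_size / bottleneck_bw, with bandwidth in GB/s (1 GB/s == 1 byte/ns)."""
    if not link_bw_gb_per_s:
        return 0.0  # mirrors the diff's unit-test fallback when no path is known
    return credit_size_bytes / min(link_bw_gb_per_s)


# A 16-byte credit over a path whose slowest link runs at 32 GB/s.
latency = credit_latency_ns(16, [64.0, 32.0, 128.0])
```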
||||
# ── Helpers ──
|
||||
|
||||
@staticmethod
|
||||
def _nbytes_for(shape: tuple[int, ...], dtype: str) -> int:
|
||||
from math import prod
|
||||
bits = {"f16": 16, "bf16": 16, "f32": 32, "i8": 8, "i16": 16, "i32": 32, "u8": 8, "u16": 16, "u32": 32}.get(dtype, 16)
|
||||
return prod(shape) * (bits // 8) if shape else 0
|
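Review note: `_nbytes_for` falls back to 16 bits for any dtype missing from its table, so a mismatch can pass silently. A stricter variant is sketched below as one possible alternative. `nbytes_strict` is a hypothetical name, and the table is an illustrative subset rather than the helper's definitive dtype list.

```python
from math import prod

# Illustrative subset of dtype widths in bits
_DTYPE_BITS = {"f16": 16, "bf16": 16, "f32": 32, "i8": 8, "i16": 16, "i32": 32}


def nbytes_strict(shape: tuple[int, ...], dtype: str) -> int:
    """Byte size of a tensor; raises instead of guessing for unknown dtypes."""
    try:
        bits = _DTYPE_BITS[dtype]
    except KeyError:
        raise ValueError(f"unsupported dtype: {dtype!r}") from None
    # An empty shape yields 0, matching the helper in the diff.
    return prod(shape) * (bits // 8) if shape else 0


size = nbytes_strict((4, 8), "f32")
```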
||||
@@ -29,11 +29,10 @@ def run_bench(
|
||||
correlation_id: str = "bench0",
|
||||
completion_policy: CompletionPolicy = CompletionPolicy.LAST_SUBMITTED,
|
||||
) -> BenchResult:
|
||||
"""
|
||||
Minimal bench runner.
|
||||
"""Minimal bench runner.
|
||||
|
||||
- topology: compiled topology object (opaque to runtime here)
|
||||
- bench_fn: callable that receives RuntimeContext and submits requests
|
||||
- bench_fn: callable ``run(torch)`` receiving a RuntimeContext
|
||||
- device: DeviceSelector ("all" or "sip:<N>")
|
||||
- engine_factory: builds sim_engine for given topology & device
|
||||
- completion_policy: how to determine overall completion/result
|
||||
@@ -48,7 +47,6 @@ def run_bench(
|
||||
)
|
||||
|
||||
bench_fn(ctx)
|
||||
|
||||
ctx.wait_all()
|
||||
|
||||
collected_traces = ctx._traces or None
|
||||
|
||||
@@ -9,6 +9,39 @@ from kernbench.common.types import Completion, RequestHandle, SimEngine
|
||||
from .types import DeviceSelector
|
||||
|
||||
|
||||
def _world_size_from_spec(spec: dict | None) -> int:
|
||||
"""Derive world_size from topology spec: sips × cubes × pes_per_cube."""
|
||||
spec = spec or {}
|
||||
sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
|
||||
cm = spec.get("sip", {}).get("cube_mesh", {})
|
||||
cubes_per_sip = int(cm.get("w", 1)) * int(cm.get("h", 1))
|
||||
pl = spec.get("cube", {}).get("pe_layout", {})
|
||||
corners = pl.get("corners", [])
|
||||
pe_per_corner = int(pl.get("pe_per_corner", 1))
|
||||
pes_per_cube = pe_per_corner * max(len(corners), 1)
|
||||
return sips * cubes_per_sip * pes_per_cube
|
||||
|
||||
|
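Review note: a worked example of the world-size derivation. The function body restates `_world_size_from_spec` so the block runs on its own; the spec values and corner names are hypothetical.

```python
def world_size_from_spec(spec: dict) -> int:
    """sips x cubes_per_sip x pes_per_cube, each defaulting to 1 when absent."""
    sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
    mesh = spec.get("sip", {}).get("cube_mesh", {})
    cubes = int(mesh.get("w", 1)) * int(mesh.get("h", 1))
    layout = spec.get("cube", {}).get("pe_layout", {})
    pes = int(layout.get("pe_per_corner", 1)) * max(len(layout.get("corners", [])), 1)
    return sips * cubes * pes


spec = {
    "system": {"sips": {"count": 2}},
    "sip": {"cube_mesh": {"w": 2, "h": 2}},
    "cube": {"pe_layout": {"corners": ["nw", "ne", "sw", "se"], "pe_per_corner": 2}},
}
world = world_size_from_spec(spec)  # 2 SIPs x 4 cubes x 8 PEs
```

An empty or missing spec collapses every factor to 1, so a single-PE unit test gets a world size of 1.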
||||
def _numpy_to_dtype_str(np_dtype) -> str:
|
||||
"""Map numpy dtype → kernbench dtype string used by Tensor."""
|
||||
import numpy as np
|
||||
|
||||
kind_map = {
|
||||
np.float16: "f16",
|
||||
np.float32: "f32",
|
||||
np.int8: "i8",
|
||||
np.int16: "i16",
|
||||
np.int32: "i32",
|
||||
np.uint8: "u8",
|
||||
np.uint16: "u16",
|
||||
np.uint32: "u32",
|
||||
}
|
||||
for np_type, s in kind_map.items():
|
||||
if np.dtype(np_dtype) == np.dtype(np_type):
|
||||
return s
|
||||
raise ValueError(f"unsupported numpy dtype: {np_dtype!r}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeContext:
|
||||
engine: SimEngine
|
||||
@@ -23,6 +56,66 @@ class RuntimeContext:
|
||||
_tensor_counter: int = field(default=0, init=False)
|
||||
_traces: list[dict] = field(default_factory=list, init=False)
|
||||
_tensors: list[Any] = field(default_factory=list, init=False)
|
||||
distributed: Any = field(default=None, init=False) # DistributedContext for CCL benches
|
||||
_ipcq_plan: dict = field(default_factory=dict, init=False) # ADR-0023 install plan
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Eagerly attach a DistributedContext so bench code can do
|
||||
# ``dist = torch.distributed`` + ``dist.init_process_group(...)``
|
||||
# without needing a separate launcher to install it.
|
||||
from kernbench.runtime_api.distributed import DistributedContext
|
||||
dc = DistributedContext()
|
||||
dc._ctx_ref = self # back-reference for AhbmCCLBackend to reach ctx.launch etc.
|
||||
self.distributed = dc
|
||||
|
||||
def install_ipcq(
|
||||
self,
|
||||
algorithm: str | None = None,
|
||||
ccl_yaml: str | None = None,
|
||||
world_size_override: int | None = None,
|
||||
rank_to_pe: list[tuple[int, int, int]] | None = None,
|
||||
) -> dict:
|
||||
"""Install IPCQ neighbor tables on all participating PEs (ADR-0023 D10).
|
||||
|
||||
Loads ``ccl.yaml`` (or the path provided), resolves the chosen
|
||||
algorithm (or ``defaults.algorithm`` if None), and pushes per-PE
|
||||
IpcqInitMsg into every PE_IPCQ component via the engine.
|
||||
|
||||
Args:
|
||||
algorithm: name of the algorithm in ccl.yaml (or use defaults).
|
||||
ccl_yaml: optional path to ccl.yaml.
|
||||
world_size_override: if set, replace the algorithm's world_size.
|
||||
rank_to_pe: optional explicit rank → (sip, cube, pe) mapping.
|
||||
|
||||
Returns the install plan dict (rank → (sip,cube,pe), neighbor table).
|
||||
"""
|
||||
import importlib
|
||||
from kernbench.ccl.install import (
|
||||
install_ipcq as _install,
|
||||
load_ccl_config,
|
||||
resolve_algorithm_config,
|
||||
)
|
||||
|
||||
cfg = load_ccl_config(ccl_yaml)
|
||||
merged = resolve_algorithm_config(cfg, algorithm)
|
||||
if world_size_override is not None:
|
||||
merged["world_size"] = world_size_override
|
||||
elif "world_size" not in merged:
|
||||
# Derive from topology.yaml when neither the algorithm entry
|
||||
# nor ``defaults`` carries ``world_size`` (matches PyTorch DDP
|
||||
# where env vars determine ranks, not the ccl config file).
|
||||
merged["world_size"] = _world_size_from_spec(self.spec)
|
||||
algo_module = None
|
||||
try:
|
||||
algo_module = importlib.import_module(merged["module"])
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
plan = _install(
|
||||
self.engine, self.spec, merged,
|
||||
algo_module=algo_module, rank_to_pe=rank_to_pe,
|
||||
)
|
||||
self._ipcq_plan = plan
|
||||
self._ipcq_config = merged
|
||||
return plan
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
@@ -258,6 +351,24 @@ class RuntimeContext:
|
||||
"""Allocate a tensor in HBM without initialization (like torch.empty)."""
|
||||
return self._create_tensor(shape, dtype, name, pattern=None, dp=dp)
|
||||
|
||||
def from_numpy(self, arr: Any):
|
||||
"""Create a host-side tensor wrapping a numpy array.
|
||||
|
||||
Mirrors ``torch.from_numpy``. The returned tensor is NOT deployed
|
||||
to any PE — it lives in an in-memory host staging buffer. Use
|
||||
``target.copy_(host_tensor)`` to scatter its contents into a
|
||||
sharded, deployed tensor.
|
||||
"""
|
||||
import numpy as np
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
|
||||
arr_c = np.ascontiguousarray(arr)
|
||||
dtype_str = _numpy_to_dtype_str(arr_c.dtype)
|
||||
t = Tensor(shape=tuple(arr_c.shape), dtype=dtype_str, name="host")
|
||||
t._host_buffer = arr_c
|
||||
t._memory_store = getattr(self.engine, "_memory_store", None)
|
||||
return t
|
||||
|
||||
def _create_tensor(
|
||||
self,
|
||||
shape: tuple[int, ...],
|
||||
@@ -418,13 +529,12 @@ class RuntimeContext:
|
||||
TensorArgShard,
|
||||
)
|
||||
from kernbench.runtime_api.tensor import Tensor
|
||||
from kernbench.triton_emu.registry import register_kernel
|
||||
from kernbench.triton_emu.registry import _kernels, register_kernel
|
||||
|
||||
# Register kernel (idempotent)
|
||||
try:
|
||||
register_kernel(kernel_name, kernel_fn)
|
||||
except ValueError:
|
||||
pass
|
||||
# Register kernel (idempotent overwrite — last call wins).
|
||||
# Tests can re-register the same kernel_name with a different
|
||||
# function; the user's most recent launch must use the latest fn.
|
||||
_kernels[kernel_name] = kernel_fn
|
||||
|
||||
# Collect tensors and scalars
|
||||
tensor_args: list[Tensor] = []
|
||||
@@ -506,6 +616,7 @@ class RuntimeContext:
|
||||
|
||||
# Per-SIP kernel launch: each SIP gets TensorArgs with local va_base
|
||||
last_handle = None
|
||||
_pending_handles: list[tuple[Any, int]] = []
|
||||
for sip_id in sorted(sip_set):
|
||||
sip_kernel_args: list = []
|
||||
sip_cube_set: set[int] = set()
|
||||
@@ -566,10 +677,17 @@ class RuntimeContext:
|
||||
target_cubes=target_cubes,
|
||||
target_pe=target_pe,
|
||||
))
|
||||
# Defer wait until all SIPs are submitted (multi-SIP CCL needs
|
||||
# all participating PEs to be live concurrently — waiting
|
||||
# per-SIP would deadlock when ranks span SIP boundaries).
|
||||
_pending_handles.append((h, sip_id))
|
||||
last_handle = h
|
||||
|
||||
# Drain pending handles now that every SIP has a launch posted.
|
||||
for h, sip_id in _pending_handles:
|
||||
self.wait(h, _meta={
|
||||
"phase": "kernel", "name": kernel_name,
|
||||
"sip": sip_id, "target_pe": target_pe,
|
||||
})
|
||||
last_handle = h
|
||||
|
||||
return last_handle
|
||||
|
||||
@@ -0,0 +1,179 @@
"""PyTorch-compatible distributed communication shim (ADR-0023 D11).

Provides a ``torch.distributed``-like API whose public surface matches
real PyTorch so that bench code looks identical to a DDP training script.

Only the ``ahbm`` backend is implemented. It:

1. Reads ``ccl.yaml`` to decide which collective algorithm to run.
2. Derives world_size from the algorithm entry, the defaults section, or
   from the topology spec (``system.sips.count × sip.cube_mesh × pe_layout``).
3. At ``init_process_group`` time, eagerly installs the IPCQ neighbor
   table once (one-time comm setup, mirroring NCCL communicator creation).
4. On each ``all_reduce(tensor)`` call, reads per-shard metadata from the
   tensor handle and dispatches ``torch.launch`` with the registered
   kernel. The kernel performs inter-PE ring/tree/mesh CCL via IPCQ,
   and Phase 2 DataExecutor replays math + copies from op_log so
   MemoryStore is correct when ``all_reduce`` returns.

Host bench code uses only real-PyTorch names:
    dist.init_process_group, dist.is_initialized, dist.get_world_size,
    dist.get_rank, dist.get_backend, dist.all_reduce, dist.barrier
"""
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
|
||||
class AhbmCCLBackend:
|
||||
"""Ahbm CCL backend — drives kernel-level collectives via IPCQ."""
|
||||
|
||||
def __init__(self, torch_ctx: Any) -> None:
|
||||
from kernbench.ccl.install import (
|
||||
load_ccl_config,
|
||||
resolve_algorithm_config,
|
||||
)
|
||||
|
||||
self.ctx = torch_ctx
|
||||
self._cfg_all = load_ccl_config()
|
||||
self._merged = resolve_algorithm_config(self._cfg_all)
|
||||
self._algo_module = importlib.import_module(self._merged["module"])
|
||||
self._world_size = self._resolve_world_size()
|
||||
|
||||
# Eager IPCQ install — ``init_process_group`` time. Mirrors NCCL
|
||||
# communicator creation: done once, reused across every subsequent
|
||||
# collective call on the same process group.
|
||||
self.ctx.install_ipcq(
|
||||
algorithm=self._merged["algorithm"],
|
||||
world_size_override=self._world_size,
|
||||
)
|
||||
|
||||
def _resolve_world_size(self) -> int:
"""Derive world_size (priority: algorithm override > defaults > topology).

Topology derivation:
    sips × cubes_per_sip × pes_per_cube
"""
if "world_size" in self._merged:
|
||||
return int(self._merged["world_size"])
|
||||
defaults = self._cfg_all.get("defaults", {})
|
||||
if "world_size" in defaults:
|
||||
return int(defaults["world_size"])
|
||||
spec = self.ctx.spec or {}
|
||||
sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
|
||||
cm = spec.get("sip", {}).get("cube_mesh", {})
|
||||
cubes_per_sip = int(cm.get("w", 1)) * int(cm.get("h", 1))
|
||||
pl = spec.get("cube", {}).get("pe_layout", {})
|
||||
corners = pl.get("corners", [])
|
||||
pe_per_corner = int(pl.get("pe_per_corner", 1))
|
||||
pes_per_cube = pe_per_corner * max(len(corners), 1)
|
||||
return sips * cubes_per_sip * pes_per_cube
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return self._world_size
|
||||
|
||||
def all_reduce(self, tensor: Any, op: str = "sum") -> None:
"""Dispatch the configured CCL algorithm as a single kernel launch.

Raises if ``op != "sum"`` (current kernels only implement add
reduction), if the tensor is not deployed, or if the tensor's shard
count disagrees with the world_size that was installed into PE_IPCQ.
"""
if op != "sum":
|
||||
raise NotImplementedError(f"all_reduce op={op!r} not supported")
|
||||
if tensor._handle is None:
|
||||
raise RuntimeError(
|
||||
f"Tensor '{tensor.name}' is not deployed (call torch.zeros "
|
||||
"with a DPPolicy first)"
|
||||
)
|
||||
shards = tensor._handle.shards
|
||||
if len(shards) != self._world_size:
|
||||
raise RuntimeError(
|
||||
f"all_reduce tensor has {len(shards)} shards but the "
|
||||
f"ahbm backend was installed with world_size="
|
||||
f"{self._world_size}; adjust the tensor's DPPolicy or "
|
||||
"restart the process group"
|
||||
)
|
||||
n_elem = shards[0].nbytes // tensor.itemsize
|
||||
kernel_fn = self._algo_module.kernel
|
||||
kernel_args = self._algo_module.kernel_args(self._world_size, n_elem)
|
||||
self.ctx.launch(
|
||||
self._merged["algorithm"], kernel_fn, tensor, *kernel_args,
|
||||
)
|
||||
|
||||
def barrier(self) -> None:
|
||||
# Single-driver model → no cross-process sync needed. Keeping the
|
||||
# method so ``dist.barrier()`` is callable (pytorch-compat surface).
|
||||
return None
|
||||
|
||||
|
||||
class DistributedContext:
"""torch.distributed-compat facade.

Public surface matches real PyTorch so bench code reads identically
to a DDP training script. Single-driver semantics: ``get_rank()``
always returns 0 because kernbench runs as one Python process;
``get_world_size()`` returns the CCL group size (number of PEs
participating in the collective).
"""
|
||||
def __init__(self) -> None:
|
||||
self._backend: AhbmCCLBackend | None = None
|
||||
|
||||
def init_process_group(
|
||||
self,
|
||||
backend: str = "ahbm",
|
||||
world_size: int | None = None,
|
||||
rank: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
"""Create the default process group.

``world_size`` and ``rank`` are accepted for API parity with
``torch.distributed.init_process_group`` but ignored: the ahbm
backend derives world_size from ``ccl.yaml`` + topology automatically
(like reading ``RANK``/``WORLD_SIZE`` env vars in real DDP), and the
rank is always 0 under the single-driver model.
"""
if backend != "ahbm":
|
||||
raise ValueError(
|
||||
f"Unsupported backend '{backend}'. Only 'ahbm' is supported."
|
||||
)
|
||||
ctx = getattr(self, "_ctx_ref", None)
|
||||
if ctx is None:
|
||||
raise RuntimeError(
|
||||
"DistributedContext not bound to a RuntimeContext"
|
||||
)
|
||||
self._backend = AhbmCCLBackend(torch_ctx=ctx)
|
||||
|
||||
def is_initialized(self) -> bool:
|
||||
return self._backend is not None
|
||||
|
||||
def get_world_size(self) -> int:
|
||||
self._ensure_initialized()
|
||||
return self._backend.world_size
|
||||
|
||||
def get_rank(self) -> int:
|
||||
# Single-driver kernbench: there is only one host rank.
|
||||
self._ensure_initialized()
|
||||
return 0
|
||||
|
||||
def get_backend(self) -> str:
|
||||
self._ensure_initialized()
|
||||
return "ahbm"
|
||||
|
||||
def all_reduce(self, tensor: Any, op: str = "sum") -> None:
|
||||
self._ensure_initialized()
|
||||
self._backend.all_reduce(tensor, op=op)
|
||||
|
||||
def barrier(self) -> None:
|
||||
self._ensure_initialized()
|
||||
self._backend.barrier()
|
||||
|
||||
def _ensure_initialized(self) -> None:
|
||||
if self._backend is None:
|
||||
raise RuntimeError(
|
||||
"Default process group has not been initialized. "
|
||||
"Call init_process_group(backend='ahbm') first."
|
||||
)
|
||||
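The world_size priority used by `_resolve_world_size` (algorithm override, then defaults, then topology) can be sketched as a standalone function. This is a minimal illustration only: the free function `resolve_world_size` and the sample `spec` dictionary are hypothetical, and the real method reads `self._merged`, `self._cfg_all`, and `self.ctx.spec` instead of taking arguments.

```python
def resolve_world_size(merged: dict, cfg_all: dict, spec: dict) -> int:
    """Hypothetical standalone mirror of AhbmCCLBackend._resolve_world_size."""
    # 1. Per-algorithm override wins.
    if "world_size" in merged:
        return int(merged["world_size"])
    # 2. Then the defaults section of ccl.yaml.
    defaults = cfg_all.get("defaults", {})
    if "world_size" in defaults:
        return int(defaults["world_size"])
    # 3. Otherwise derive from topology: sips x cubes_per_sip x pes_per_cube.
    sips = int(spec.get("system", {}).get("sips", {}).get("count", 1))
    cm = spec.get("sip", {}).get("cube_mesh", {})
    cubes_per_sip = int(cm.get("w", 1)) * int(cm.get("h", 1))
    pl = spec.get("cube", {}).get("pe_layout", {})
    pes_per_cube = int(pl.get("pe_per_corner", 1)) * max(len(pl.get("corners", [])), 1)
    return sips * cubes_per_sip * pes_per_cube


# Illustrative topology (made-up values): 2 SIPs, 2x2 cubes, 4 corners x 2 PEs.
spec = {
    "system": {"sips": {"count": 2}},
    "sip": {"cube_mesh": {"w": 2, "h": 2}},
    "cube": {"pe_layout": {"corners": ["nw", "ne", "sw", "se"], "pe_per_corner": 2}},
}
from_topology = resolve_world_size({}, {}, spec)
from_defaults = resolve_world_size({}, {"defaults": {"world_size": 8}}, spec)
from_override = resolve_world_size({"world_size": 4}, {"defaults": {"world_size": 8}}, spec)
```

With these sample values the topology path multiplies 2 SIPs by 4 cubes by 8 PEs per cube, while either configured value short-circuits the topology lookup.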
@@ -152,3 +152,30 @@ class MmuUnmapMsg:
|
||||
target_cubes: tuple[int, ...] | Literal["all"] = "all"
|
||||
target_pe: int | Literal["all"] = "all"
|
||||
msg_type: Literal["mmu_unmap"] = "mmu_unmap"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IpcqInitMsg:
"""IPCQ neighbor table install (sideband fan-out, ADR-0023 D10/D12).

The backend issues this at ``init_process_group`` time to install per-PE
IPCQ neighbor tables. Each entry covers one direction (N/S/E/W) and
carries the peer's IpcqEndpoint plus this PE's own rx_buffer base
and a pre-wired SimPy Store for the credit-return fast path (D9).

Routing is similar to MmuMapMsg.
"""
|
||||
correlation_id: str
|
||||
request_id: str
|
||||
target_sips: tuple[int, ...] | Literal["all"] = "all"
|
||||
target_cubes: tuple[int, ...] | Literal["all"] = "all"
|
||||
target_pe: int | tuple[int, ...] | Literal["all"] = "all"
|
||||
# entries: tuple[IpcqInitEntry, ...] — kept as tuple of plain objects to
|
||||
# avoid a runtime import cycle (IpcqInitEntry lives in
|
||||
# kernbench.common.ipcq_types).
|
||||
entries: tuple = ()
|
||||
backpressure_mode: str = "sleep" # "poll" | "sleep"
|
||||
buffer_kind: str = "tcm" # "tcm" | "hbm" | "sram"
|
||||
credit_size_bytes: int = 16
|
||||
msg_type: Literal["ipcq_init"] = "ipcq_init"
|
||||
|
||||
@@ -146,6 +146,11 @@ class Tensor:
|
||||
self._handle: TensorHandle | None = None
|
||||
self._ctx_ref: weakref.ref | None = None # set by RuntimeContext
|
||||
self._memory_store = None # set by RuntimeContext when enable_data=True
|
||||
# Host-side staging buffer for torch.from_numpy() results. A tensor
|
||||
# with a non-None _host_buffer is NOT deployed to any PE — it lives
|
||||
# only on the host. Use `target.copy_(host_tensor)` to scatter the
|
||||
# data into a deployed, sharded target tensor.
|
||||
self._host_buffer: np.ndarray | None = None
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self._ctx_ref is None or self._handle is None:
|
||||
@@ -166,15 +171,85 @@ class Tensor:
|
||||
|
||||
@property
|
||||
def data(self) -> np.ndarray:
|
||||
"""Tensor data as numpy array. Returns actual values when enable_data=True,
|
||||
zeros placeholder otherwise (like an uninitialized tensor)."""
|
||||
if self._memory_store is not None and self._handle is not None:
|
||||
shard = self._handle.shards[0]
|
||||
"""Tensor data as numpy array.
|
||||
|
||||
Gathers all shards into a single full-shape array. Returns actual
|
||||
values when enable_data=True, zeros placeholder otherwise (like an
|
||||
uninitialized tensor). Alias of ``numpy()``.
|
||||
"""
|
||||
return self.numpy()
|
||||
|
||||
def _shard_store_addr(self, shard: TensorShard) -> int:
"""MemoryStore key for a shard.

Kernels read tensors via VA (translated to PA by PE_DMA's MMU when
a mapping exists, otherwise the addr is treated as a PA-equivalent
key). Tensor I/O therefore writes/reads at ``va_base + offset_bytes``
when ``va_base`` is set, falling back to ``shard.pa`` for the
VA-less mode used by some legacy paths.
"""
if self._handle and self._handle.va_base:
|
||||
return self._handle.va_base + shard.offset_bytes
|
||||
return shard.pa
|
||||
|
||||
def numpy(self) -> np.ndarray:
"""Return a single numpy array gathered from all shards.

Mirrors ``torch.Tensor.numpy()``. In kernbench, sharded tensors are
gathered into a single full-shape ndarray according to each shard's
``offset_bytes`` / ``nbytes`` range.
"""
np_dtype = _numpy_dtype(self.dtype)
|
||||
# Host-side tensor (created via torch.from_numpy) has no shards.
|
||||
if self._host_buffer is not None:
|
||||
return self._host_buffer.copy()
|
||||
if self._handle is None or self._memory_store is None:
|
||||
return np.zeros(self.shape, dtype=np_dtype)
|
||||
flat = np.zeros(math.prod(self.shape), dtype=np_dtype)
|
||||
for shard in self._handle.shards:
|
||||
start = shard.offset_bytes // self.itemsize
|
||||
count = shard.nbytes // self.itemsize
|
||||
try:
|
||||
return self._memory_store.read("hbm", shard.pa, shape=self.shape, dtype=self.dtype)
|
||||
piece = self._memory_store.read(
|
||||
"hbm", self._shard_store_addr(shard),
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
return np.zeros(self.shape, dtype=_numpy_dtype(self.dtype))
|
||||
continue
|
||||
flat[start : start + count] = (
|
||||
np.asarray(piece, dtype=np_dtype).reshape(-1)[:count]
|
||||
)
|
||||
return flat.reshape(self.shape)
|
||||
|
||||
def copy_(self, source: "Tensor") -> "Tensor":
"""In-place copy from another tensor into self.

Mirrors ``torch.Tensor.copy_()``. If ``source`` is a host tensor
(from ``torch.from_numpy``), its ndarray is split across self's
shards using each shard's byte range. If ``source`` is a deployed
(sharded) tensor, its contents are gathered first and then
re-scattered into self's shard layout.

Shapes must match. Returns self.
"""
if self._handle is None or self._memory_store is None:
|
||||
raise RuntimeError(
|
||||
f"Tensor '{self.name}' must be deployed before copy_()"
|
||||
)
|
||||
if source.shape != self.shape:
|
||||
raise ValueError(
|
||||
f"copy_ shape mismatch: self={self.shape} source={source.shape}"
|
||||
)
|
||||
np_dtype = _numpy_dtype(self.dtype)
|
||||
arr = source.numpy().astype(np_dtype, copy=False)
|
||||
flat = np.ascontiguousarray(arr).reshape(-1)
|
||||
for shard in self._handle.shards:
|
||||
start = shard.offset_bytes // self.itemsize
|
||||
count = shard.nbytes // self.itemsize
|
||||
piece = flat[start : start + count].copy()
|
||||
self._memory_store.write(
|
||||
"hbm", self._shard_store_addr(shard), piece,
|
||||
)
|
||||
return self
|
||||
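The byte-range scatter in `copy_()` and the matching gather in `numpy()` can be illustrated with a plain dictionary standing in for MemoryStore. This is a sketch under assumptions: the `Shard` dataclass, the `scatter`/`gather` helpers, and the addresses are hypothetical stand-ins, not the kernbench types.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class Shard:
    """Hypothetical stand-in for TensorShard: a byte range plus a store key."""
    offset_bytes: int
    nbytes: int
    addr: int


def scatter(store: dict, shards: list, arr: np.ndarray) -> None:
    """Split a full array across shards by byte range (copy_-style)."""
    flat = np.ascontiguousarray(arr).reshape(-1)
    itemsize = flat.dtype.itemsize
    for s in shards:
        start = s.offset_bytes // itemsize
        count = s.nbytes // itemsize
        store[s.addr] = flat[start:start + count].copy()


def gather(store: dict, shards: list, shape: tuple, dtype) -> np.ndarray:
    """Reassemble a full-shape array; shards missing from the store stay zero."""
    flat = np.zeros(int(np.prod(shape)), dtype=dtype)
    itemsize = flat.dtype.itemsize
    for s in shards:
        start = s.offset_bytes // itemsize
        count = s.nbytes // itemsize
        piece = store.get(s.addr)
        if piece is None:
            continue
        flat[start:start + count] = np.asarray(piece, dtype=dtype).reshape(-1)[:count]
    return flat.reshape(shape)


src = np.arange(8, dtype=np.float32).reshape(2, 4)
shards = [Shard(0, 16, 0x1000), Shard(16, 16, 0x2000)]  # two 4-element shards
store: dict = {}
scatter(store, shards, src)
round_trip = gather(store, shards, src.shape, np.float32)
```

Keying each shard by its own address is what lets a gather tolerate a shard that was never written: that slice simply remains zero.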
|
||||
@property
|
||||
def itemsize(self) -> int:
|
||||
|
||||
@@ -51,7 +51,42 @@ class DataExecutor:
|
||||
self._execute_math(op)
|
||||
|
||||
def _execute_memory(self, op: OpRecord) -> None:
|
||||
"""Memory ops are already handled by Phase 1 MemoryStore. Skip."""
"""Replay memory copy ops in Phase 2 (ADR-0020 + ADR-0023).

- dma_read: no-op (handle already references HBM source).
- dma_write: copy (src_space, src_addr) → (dst_space, dst_addr).
  Required because Phase 2 may have just produced new data at the
  source addr (e.g. PE_MATH scratch output).
- ipcq_copy: copy across PEs: sender's source → receiver's slot.
  Required because the source may be a Phase 2 math output, and
  a downstream math op on the receiver reads from the slot.

Legacy entries without src/dst metadata are silently skipped.
"""
p = op.params
|
||||
if op.op_name == "dma_write" or op.op_name == "ipcq_copy":
|
||||
src_space = p.get("src_space")
|
||||
src_addr = p.get("src_addr")
|
||||
dst_space = p.get("dst_space")
|
||||
dst_addr = p.get("dst_addr")
|
||||
if (src_space is None or src_addr is None
|
||||
or dst_space is None or dst_addr is None):
|
||||
return
|
||||
# Prefer the Phase-1-time snapshot (captured at record_end /
|
||||
# outbound) so we don't read from a source that has since been
|
||||
# mutated by another op. Fall back to MemoryStore for sources
|
||||
# that had no Phase 1 data (e.g. math scratch outputs that
|
||||
# only get populated by Phase 2's math replay).
|
||||
data = p.get("snapshot")
|
||||
if data is None:
|
||||
try:
|
||||
data = self.store.read(
|
||||
src_space, src_addr,
|
||||
shape=p.get("shape"), dtype=p.get("dtype"),
|
||||
)
|
||||
except KeyError:
|
||||
return
|
||||
self.store.write(dst_space, dst_addr, data)
|
||||
|
||||
def _execute_gemm(self, op: OpRecord) -> None:
|
||||
"""Execute GEMM: out = a @ b."""
|
||||
@@ -77,18 +112,35 @@ class DataExecutor:
|
||||
"""Execute math op: unary, binary, or reduction."""
|
||||
p = op.params
|
||||
math_op = p.get("op", op.op_name)
|
||||
space = p.get("addr_space", "tcm")
|
||||
dtype = p.get("dtype", "f32")
|
||||
input_addrs = p.get("input_addrs", [])
|
||||
input_shapes = p.get("input_shapes", [])
|
||||
# Per-input space/dtype (ADR-0023 CCL accumulation): math ops can
|
||||
# mix inputs from different MemoryStore spaces (e.g. acc in "hbm",
|
||||
# recv slot in "tcm"). Fall back to legacy single-space mode when
|
||||
# the per-input lists are absent.
|
||||
input_spaces = p.get("input_spaces") or [p.get("addr_space", "tcm")] * len(input_addrs)
|
||||
input_dtypes = p.get("input_dtypes") or [dtype] * len(input_addrs)
|
||||
# Per-input data snapshots (ADR-0020 D6): captured at op_log
|
||||
# record time. Phase 1 has correct values for slot/HBM addrs at
|
||||
# that moment, which lets Phase 2 sidestep the slot-wraparound
|
||||
# races where a later round overwrites a slot before this op
|
||||
# runs in t_start order.
|
||||
snapshots = p.get("input_snapshots") or [None] * len(input_addrs)
|
||||
dst_space = p.get("dst_space", p.get("addr_space", "tcm"))
|
||||
|
||||
inputs = []
|
||||
for addr, shape in zip(input_addrs, input_shapes):
|
||||
inputs.append(self.store.read(space, addr, shape=shape, dtype=dtype))
|
||||
for addr, shape, space, idtype, snap in zip(
|
||||
input_addrs, input_shapes, input_spaces, input_dtypes, snapshots
|
||||
):
|
||||
if snap is not None:
|
||||
inputs.append(snap)
|
||||
else:
|
||||
inputs.append(self.store.read(space, addr, shape=shape, dtype=idtype))
|
||||
|
||||
result = _compute_math(math_op, inputs, p.get("axis"))
|
||||
if result is not None:
|
||||
self.store.write(space, p["dst_addr"], result)
|
||||
self.store.write(dst_space, p["dst_addr"], result)
|
||||
|
||||
def verify(self, expected: dict[tuple[str, int], np.ndarray],
|
||||
rtol: float = 1e-3, atol: float = 1e-3) -> dict[str, bool]:
|
||||
@@ -146,6 +198,14 @@ def _compute_math(op: str, inputs: list[np.ndarray], axis: int | None) -> np.nda
|
||||
if op == "min":
|
||||
return np.min(x, axis=axis, keepdims=True)
|
||||
|
||||
# Softmax (numerically stable)
|
||||
if op == "softmax":
|
||||
ax = axis if axis is not None else -1
|
||||
x_max = np.max(x, axis=ax, keepdims=True)
|
||||
e = np.exp(x - x_max)
|
||||
s = np.sum(e, axis=ax, keepdims=True)
|
||||
return e / s
|
||||
|
||||
# Binary
|
||||
if len(inputs) >= 2:
|
||||
y = inputs[1]
|
||||
@@ -157,9 +217,18 @@ def _compute_math(op: str, inputs: list[np.ndarray], axis: int | None) -> np.nda
|
||||
return x * y
|
||||
if op == "div":
|
||||
return x / y
|
||||
if op == "maximum":
|
||||
return np.maximum(x, y)
|
||||
if op == "minimum":
|
||||
return np.minimum(x, y)
|
||||
|
||||
# Ternary
|
||||
if op == "where" and len(inputs) >= 3:
|
||||
return np.where(inputs[0], inputs[1], inputs[2])
|
||||
if len(inputs) >= 3:
|
||||
if op == "where":
|
||||
return np.where(inputs[0], inputs[1], inputs[2])
|
||||
if op == "fma":
|
||||
return inputs[0] * inputs[1] + inputs[2]
|
||||
if op == "clamp":
|
||||
return np.minimum(np.maximum(inputs[0], inputs[1]), inputs[2])
|
||||
|
||||
return None
|
||||
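The math ops added to `_compute_math` reduce to a few NumPy expressions. The sketch below restates them as standalone helpers (the function names are illustrative, not part of the PR) and shows why the softmax subtracts the row maximum: with inputs around 1000, a direct `np.exp(x)` would overflow.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax: shift by the max before exponentiating."""
    x_max = np.max(x, axis=axis, keepdims=True)
    e = np.exp(x - x_max)
    return e / np.sum(e, axis=axis, keepdims=True)


def fma(a, b, c):
    """Fused multiply-add semantics as replayed in Phase 2: a * b + c."""
    return a * b + c


def clamp(x, lo, hi):
    """Clamp x into [lo, hi] via maximum followed by minimum."""
    return np.minimum(np.maximum(x, lo), hi)


x = np.array([[1000.0, 1001.0, 1002.0]])
p = softmax(x)  # finite, because exp() only ever sees values <= 0
clamped = clamp(np.array([-2.0, 0.5, 3.0]), 0.0, 1.0)
fused = fma(2.0, 3.0, 4.0)
```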
|
||||
@@ -51,8 +51,12 @@ class GraphEngine:
|
||||
if enable_data:
|
||||
from kernbench.sim_engine.memory_store import MemoryStore
|
||||
from kernbench.sim_engine.op_log import OpLogger
|
||||
self._op_logger = OpLogger()
|
||||
self._memory_store = MemoryStore()
|
||||
self._op_logger = OpLogger(memory_store=self._memory_store)
|
||||
# Cursor for incremental Phase 2 replay (ADR-0020 D6).
|
||||
# SimPy env.now is monotonic so newly logged records always sort
|
||||
# to the tail; the cursor remains valid across waits.
|
||||
self._data_cursor = 0
|
||||
|
||||
ctx = ComponentContext(
|
||||
router=self._router,
|
||||
@@ -147,11 +151,60 @@ class GraphEngine:
|
||||
self._env.process(self._process(str(handle), request, event))
|
||||
return handle
|
||||
|
||||
def _flush_data_phase(self) -> None:
"""Replay newly recorded op_log entries through DataExecutor.

ADR-0020 D6 Phase 2: when data tracking is enabled, run DataExecutor
on records added since the last flush so that callers reading
MemoryStore between launches observe correct (compute-replayed)
tensor data.

Cursor-based incremental replay is necessary because Phase 2 is
NOT idempotent across full re-runs: a math op writes a TCM scratch
addr, a later dma_write copies that scratch into HBM[X], and an
even-later math op may then read HBM[X]. Re-running everything
from scratch would let the second pass's first math op read the
already-overwritten HBM[X] instead of the original input.
"""
if self._op_logger is None or self._memory_store is None:
|
||||
return
|
||||
records = self._op_logger.records # sorted by t_start (stable)
|
||||
if self._data_cursor >= len(records):
|
||||
return
|
||||
new_records = records[self._data_cursor:]
|
||||
from kernbench.sim_engine.data_executor import DataExecutor
|
||||
DataExecutor(new_records, self._memory_store).run()
|
||||
self._data_cursor = len(records)
|
||||
|
||||
def wait(self, handle: RequestHandle) -> None:
|
||||
key = str(handle)
|
||||
event = self._events[key]
|
||||
if not event.triggered:
|
||||
self._env.run(until=event)
|
||||
try:
|
||||
self._env.run(until=event)
|
||||
except (simpy.core.EmptySchedule, RuntimeError) as exc:
|
||||
# SimPy raises EmptySchedule directly OR (in newer simpy)
|
||||
# wraps it as a RuntimeError("No scheduled events left ...").
|
||||
# Either case while our event is still pending → IPCQ deadlock.
|
||||
msg = str(exc)
|
||||
is_deadlock = (
|
||||
isinstance(exc, simpy.core.EmptySchedule)
|
||||
or "No scheduled events left" in msg
|
||||
)
|
||||
if not is_deadlock:
|
||||
raise
|
||||
from kernbench.ccl.diagnostics import IpcqDeadlock, pointer_dump
|
||||
dump = pointer_dump(self)
|
||||
if dump.strip():
|
||||
raise IpcqDeadlock(
|
||||
"IPCQ deadlock: simulation schedule empty while "
|
||||
f"request {handle!r} is still pending.\n"
|
||||
f"Pointer state:\n{dump}"
|
||||
) from None
|
||||
raise
|
||||
# ADR-0020: replay newly logged ops so the caller observes
|
||||
# post-Phase-2 tensor state from MemoryStore.
|
||||
self._flush_data_phase()
|
||||
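Why the replay needs a cursor can be seen in a toy model. The sketch below is an assumption-laden illustration, not the GraphEngine code: `Replayer`, `add_one`, and `write_back` are hypothetical, and records are plain callables rather than OpRecord entries.

```python
class Replayer:
    """Toy cursor-based replayer: each flush runs only records not yet replayed."""

    def __init__(self, store: dict) -> None:
        self.store = store
        self.records: list = []
        self.cursor = 0

    def flush(self) -> int:
        new_records = self.records[self.cursor:]
        for op in new_records:
            op(self.store)
        self.cursor = len(self.records)
        return len(new_records)


def add_one(store: dict) -> None:
    # "math op": reads HBM[X], writes a scratch location.
    store["scratch"] = store["hbm_x"] + 1


def write_back(store: dict) -> None:
    # "dma_write": copies the scratch result back over HBM[X].
    store["hbm_x"] = store["scratch"]


store = {"hbm_x": 10}
replayer = Replayer(store)
replayer.records += [add_one, write_back]
first = replayer.flush()   # replays both records
second = replayer.flush()  # nothing new, so nothing is re-executed
```

If the second flush re-ran every record instead, `add_one` would read the already-updated `hbm_x` and the value would drift by one on every wait.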
|
||||
def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
|
||||
return self._results[str(handle)]
|
||||
|
||||
@@ -29,9 +29,13 @@ class OpLogger:
|
||||
Records are maintained in t_start stable ordering (insertion order).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, memory_store: Any | None = None) -> None:
|
||||
self._records: list[OpRecord] = []
|
||||
self._pending: dict[int, dict[str, Any]] = {} # msg id → partial record
|
||||
# Optional MemoryStore reference. When set, math op records capture
|
||||
# input data snapshots at record_end time so Phase 2 replay does
|
||||
# not depend on slot/scratch addrs surviving until math runs.
|
||||
self._memory_store = memory_store
|
||||
|
||||
@property
|
||||
def records(self) -> list[OpRecord]:
|
||||
@@ -53,6 +57,38 @@ class OpLogger:
|
||||
if pending is None:
|
||||
return
|
||||
op_kind, op_name, params = _extract_op_info(msg)
|
||||
# Snapshot data at record time so Phase 2 replay sidesteps
|
||||
# downstream mutations of source addrs (e.g. a tl.store that
|
||||
# overwrites HBM after a load handle was sent, or a slot that
|
||||
# gets reused on the next ring round).
|
||||
if self._memory_store is not None:
|
||||
if op_kind == "math":
|
||||
snaps: list[Any] = []
|
||||
for addr, shape, space, idtype in zip(
|
||||
params.get("input_addrs", []),
|
||||
params.get("input_shapes", []),
|
||||
params.get("input_spaces", []),
|
||||
params.get("input_dtypes", []),
|
||||
):
|
||||
try:
|
||||
arr = self._memory_store.read(
|
||||
space, addr, shape=shape, dtype=idtype,
|
||||
)
|
||||
snaps.append(arr.copy() if hasattr(arr, "copy") else arr)
|
||||
except Exception:
|
||||
snaps.append(None)
|
||||
params["input_snapshots"] = snaps
|
||||
elif op_name == "dma_write":
|
||||
try:
|
||||
arr = self._memory_store.read(
|
||||
params["src_space"], params["src_addr"],
|
||||
shape=params.get("shape"), dtype=params.get("dtype"),
|
||||
)
|
||||
params["snapshot"] = (
|
||||
arr.copy() if hasattr(arr, "copy") else arr
|
||||
)
|
||||
except Exception:
|
||||
params["snapshot"] = None
|
||||
self._records.append(OpRecord(
|
||||
t_start=pending["t_start"],
|
||||
t_end=t,
|
||||
@@ -62,6 +98,45 @@ class OpLogger:
|
||||
params=params,
|
||||
))
|
||||
|
||||
def record_copy(
|
||||
self, t_start: float, t_end: float, component_id: str,
|
||||
src_space: str, src_addr: int,
|
||||
dst_space: str, dst_addr: int,
|
||||
shape: tuple[int, ...], dtype: str, nbytes: int,
|
||||
) -> None:
"""Record a memory copy op for Phase 2 replay (ADR-0023 + ADR-0020).

Used by PE_DMA at outbound (sender) time: the snapshot captures
the source data at the moment the send was issued, so Phase 2
replay does not see later mutations of the source addr (e.g. a
tl.store that runs after the recv at the sender).

For sources whose data is not yet materialized in Phase 1 (math
scratch outputs), the snapshot is None and Phase 2 falls back to
reading from MemoryStore, by which point the corresponding math
op has been replayed and the scratch addr is populated.
"""
snap = None
|
||||
if self._memory_store is not None:
|
||||
try:
|
||||
arr = self._memory_store.read(
|
||||
src_space, src_addr, shape=shape, dtype=dtype,
|
||||
)
|
||||
snap = arr.copy() if hasattr(arr, "copy") else arr
|
||||
except Exception:
|
||||
snap = None
|
||||
self._records.append(OpRecord(
|
||||
t_start=t_start, t_end=t_end,
|
||||
component_id=component_id,
|
||||
op_kind="memory", op_name="ipcq_copy",
|
||||
params={
|
||||
"src_space": src_space, "src_addr": src_addr,
|
||||
"dst_space": dst_space, "dst_addr": dst_addr,
|
||||
"shape": shape, "dtype": dtype, "nbytes": nbytes,
|
||||
"snapshot": snap,
|
||||
},
|
||||
))
|
||||
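The snapshot-first, store-fallback rule shared by `record_copy` and `_execute_memory` can be sketched with a dictionary as the store. The record layout and the `replay_copy` helper here are simplified, hypothetical stand-ins for OpRecord params and DataExecutor.

```python
import numpy as np


def replay_copy(store: dict, rec: dict) -> bool:
    """Replay one copy record; returns False when no source data is available."""
    data = rec.get("snapshot")
    if data is None:
        # No record-time snapshot (e.g. a math scratch output that only
        # Phase 2 populates): fall back to whatever the store holds now.
        data = store.get(rec["src"])
        if data is None:
            return False
    store[rec["dst"]] = data
    return True


store = {"src": np.array([1.0, 2.0])}
# Snapshot taken at send time, as record_copy does via arr.copy().
rec = {"src": "src", "dst": "slot", "snapshot": store["src"].copy()}
store["src"][:] = 99.0  # the sender later overwrites its source buffer
replay_copy(store, rec)  # the receiver slot still gets the send-time values
```

Copying the array at record time is the important step: a reference to the live buffer would observe the later overwrite.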
|
||||
|
||||
def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
|
||||
"""Extract op_kind, op_name, params from a data_op message."""
|
||||
@@ -76,6 +151,11 @@ def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
|
||||
}
|
||||
if isinstance(msg, DmaWriteCmd):
|
||||
return "memory", "dma_write", {
|
||||
"src_space": getattr(msg.handle, "space", "tcm"),
|
||||
"src_addr": msg.handle.addr,
|
||||
"shape": msg.handle.shape,
|
||||
"dtype": msg.handle.dtype,
|
||||
"dst_space": "hbm",
|
||||
"dst_addr": msg.dst_addr,
|
||||
"nbytes": msg.nbytes,
|
||||
"handle_id": msg.handle.id,
|
||||
@@ -96,7 +176,10 @@ def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
|
||||
return "math", msg.op, {
|
||||
"input_addrs": [h.addr for h in msg.inputs],
|
||||
"input_shapes": [h.shape for h in msg.inputs],
|
||||
"input_spaces": [getattr(h, "space", "tcm") for h in msg.inputs],
|
||||
"input_dtypes": [h.dtype for h in msg.inputs],
|
||||
"dst_addr": msg.out.addr,
|
||||
"dst_space": getattr(msg.out, "space", "tcm"),
|
||||
"shape_out": msg.out.shape,
|
||||
"dtype": msg.out.dtype,
|
||||
"axis": msg.axis,
|
||||
|
||||
@@ -25,6 +25,7 @@ _PE_COMP_OFFSETS = {
|
||||
"pe_math": (0.0, 0.15),
|
||||
"pe_mmu": (0.15, -0.15),
|
||||
"pe_tcm": (0.3, 0.0),
|
||||
"pe_ipcq": (-0.15, 0.15),
|
||||
}
|
||||
|
||||
|
||||
@@ -698,6 +699,20 @@ def _add_pe_internal_edges(edges: list[Edge], pp: str, pe_links: dict) -> None:
|
||||
kind="pe_internal",
|
||||
))
|
||||
|
||||
# PE_IPCQ edges (ADR-0023 D1, D9, D10)
|
||||
ipcq_edges = [
|
||||
("pe_cpu", "pe_ipcq", "cpu_to_ipcq_mm"), # IpcqRequest
|
||||
("pe_ipcq", "pe_dma", "ipcq_to_dma_mm"), # IpcqDmaToken outbound
|
||||
("pe_dma", "pe_ipcq", "dma_to_ipcq_mm"), # IpcqMetaArrival inbound
|
||||
]
|
||||
for src_c, dst_c, mm_key in ipcq_edges:
|
||||
if mm_key in pe_links:
|
||||
edges.append(Edge(
|
||||
src=f"{pp}.{src_c}", dst=f"{pp}.{dst_c}",
|
||||
distance_mm=pe_links[mm_key],
|
||||
kind="pe_internal",
|
||||
))
|
||||
|
||||
|
||||
# ── Inter-cube / IO / system edges ──────────────────────────────────
|
||||
|
||||
@@ -765,7 +780,13 @@ def _add_io_to_cube_edges(
|
||||
def _add_system_to_io_edges(
|
||||
edges: list[Edge], sp: str, sip_spec: dict, system: dict,
|
||||
) -> None:
|
||||
"""Add fabric switch → IO chiplet PCIe edges."""
"""Add bidirectional fabric switch ↔ IO chiplet PCIe edges.

Both directions are needed:
    switch → pcie_ep for host→device traffic (memory writes, kernel launch)
    pcie_ep → switch for device-side outbound traffic (cross-SIP IPCQ
    send between PE_DMAs through the system switch).
"""
sw_id = "fabric.switch0"
|
||||
sys_link = system["links"]["io_ep_to_switch"]
|
||||
for inst in sip_spec["iochiplet"]["instances"]:
|
||||
@@ -776,6 +797,12 @@ def _add_system_to_io_edges(
|
||||
bw_gbs=sys_link["bw_gbs_per_ep"],
|
||||
kind="pcie",
|
||||
))
|
||||
edges.append(Edge(
|
||||
src=pcie_ep_id, dst=sw_id,
|
||||
distance_mm=sys_link["distance_mm"],
|
||||
bw_gbs=sys_link["bw_gbs_per_ep"],
|
||||
kind="pcie",
|
||||
))
|
||||
|
||||
|
||||
# ── View builders ────────────────────────────────────────────────────
|
||||
@@ -1113,13 +1140,14 @@ def _build_pe_view(spec: dict) -> ViewGraph:
|
||||
"pe_math": (7.0, 6.5),
|
||||
"pe_mmu": (4.0, 1.5),
|
||||
"pe_tcm": (10.0, 4.0),
|
||||
"pe_ipcq": (4.0, 6.5),
|
||||
}
|
||||
|
||||
nodes: dict[str, Node] = {}
|
||||
view_edges: list[Edge] = []
|
||||
|
||||
for comp_name, comp_spec in pe_tmpl["components"].items():
|
||||
px, py = positions[comp_name]
|
||||
px, py = positions.get(comp_name, (1.0, 1.0))
|
||||
nodes[comp_name] = Node(
|
||||
id=comp_name, kind=comp_spec["kind"], impl=comp_spec["impl"],
|
||||
attrs=comp_spec["attrs"], pos_mm=(px, py),
|
||||
|
||||
@@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Any
|
||||
import simpy
|
||||
from greenlet import greenlet
|
||||
|
||||
from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqRequest, IpcqSendCmd, RecvFuture
|
||||
from kernbench.common.pe_commands import (
|
||||
CompletionHandle,
|
||||
CompositeCmd,
|
||||
@@ -51,6 +52,9 @@ class KernelRunner:
|
||||
out_ports: dict[str, simpy.Store],
|
||||
store: MemoryStore | None = None,
|
||||
num_cubes: int = 1,
|
||||
ipcq_id: str | None = None,
|
||||
scratch_base: int = 0,
|
||||
scratch_size: int = 1 << 20,
|
||||
) -> None:
|
||||
self._pe_prefix = pe_prefix
|
||||
self._pe_idx = pe_idx
|
||||
@@ -61,6 +65,13 @@ class KernelRunner:
|
||||
self._out_ports = out_ports
|
||||
self._store = store
|
||||
self._parent: greenlet | None = None
|
||||
# IPCQ port id (ADR-0023). Defaults to "<pe_prefix>.pe_ipcq" when not
# given; IPCQ commands raise if that port is not wired into out_ports.
|
||||
self._ipcq_id = ipcq_id or f"{pe_prefix}.pe_ipcq"
|
||||
# PE-local scratch for compute output TensorHandles (ADR-0020 D3
|
||||
# extension). The TLContext allocates from this pool when math/dot
|
||||
# ops produce a result that may later be used as a send/store source.
|
||||
self._scratch_base = scratch_base
|
||||
self._scratch_size = scratch_size
|
||||
|
||||
def run(
|
||||
self,
|
||||
@@ -89,7 +100,10 @@ class KernelRunner:
|
||||
num_cubes=self._num_cubes,
|
||||
dispatch_cycles=0,
|
||||
runner=self,
|
||||
scratch_base=self._scratch_base,
|
||||
scratch_size=self._scratch_size,
|
||||
)
|
||||
self._tl = tl # exposed so switch_to_simpy can re-set on restore
|
||||
|
||||
def _kernel_entry():
|
||||
TLContext._set_active(tl) # type: ignore[attr-defined]
|
||||
@@ -103,13 +117,20 @@ class KernelRunner:
|
||||
pending: dict[str, simpy.Event] = {}
|
||||
composite_results: list[dict] = []
|
||||
|
||||
# Helper: set our tl as active just before resuming the kernel.
|
||||
# Multiple PE kernel runners share the same thread-local; without
|
||||
# this, another runner's kernel may have left a different context.
|
||||
def _switch_kernel(*args):
|
||||
TLContext._set_active(tl) # type: ignore[attr-defined]
|
||||
return g.switch(*args)
|
||||
|
||||
# Start kernel — first switch returns first command (or None if kernel is done)
|
||||
cmd = g.switch()
|
||||
cmd = _switch_kernel()
|
||||
|
||||
while cmd is not None:
|
||||
if isinstance(cmd, PeCpuOverheadCmd):
|
||||
yield env.timeout(cmd.cycles)
|
||||
cmd = g.switch()
|
||||
cmd = _switch_kernel()
|
||||
|
||||
elif isinstance(cmd, WaitCmd):
|
||||
if cmd.handle is not None:
|
||||
@@ -120,7 +141,7 @@ class KernelRunner:
|
||||
for evt in pending.values():
|
||||
yield evt
|
||||
pending.clear()
|
||||
cmd = g.switch()
|
||||
cmd = _switch_kernel()
|
||||
|
||||
elif isinstance(cmd, DmaReadCmd):
|
||||
# Dispatch DMA through SimPy components
|
||||
@@ -141,10 +162,12 @@ class KernelRunner:
                         )
                     except KeyError:
                         pass
-                cmd = g.switch(data)
+                cmd = _switch_kernel(data)

             elif isinstance(cmd, DmaWriteCmd):
-                # Write to MemoryStore first (visibility = issue, ADR-0020 D3)
+                # Write to MemoryStore first (visibility = issue, ADR-0020 D3).
+                # When data is None (e.g. timing-only TensorHandle math result),
+                # this is a no-op; Phase 2 dma_write replay handles those.
                 if self._store is not None and cmd.handle.data is not None:
                     self._store.write("hbm", cmd.dst_addr, cmd.handle.data)

@@ -154,7 +177,7 @@ class KernelRunner:
                 )
                 yield self._out_ports[self._scheduler_id].put(pe_txn)
                 yield done_evt
-                cmd = g.switch()
+                cmd = _switch_kernel()

             elif isinstance(cmd, CompositeCmd):
                 # Non-blocking composite
@@ -165,7 +188,7 @@ class KernelRunner:
                 composite_results.append(pe_txn.result_data)
                 yield self._out_ports[self._scheduler_id].put(pe_txn)
                 pending[cmd.completion.id] = done_evt
-                cmd = g.switch()
+                cmd = _switch_kernel()

             elif isinstance(cmd, (GemmCmd, MathCmd)):
                 # Blocking compute command
@@ -175,7 +198,90 @@ class KernelRunner:
                 )
                 yield self._out_ports[self._scheduler_id].put(pe_txn)
                 yield done_evt
-                cmd = g.switch()
+                cmd = _switch_kernel()
+
+            elif isinstance(cmd, IpcqSendCmd):
+                # Forward IpcqRequest to PE_IPCQ, wait for done
+                if self._ipcq_id not in self._out_ports:
+                    raise RuntimeError(
+                        f"PE_IPCQ port {self._ipcq_id!r} not wired to runner"
+                    )
+                done_evt = env.event()
+                req = IpcqRequest(command=cmd, done=done_evt)
+                yield self._out_ports[self._ipcq_id].put(req)
+                yield done_evt
+                cmd = _switch_kernel()
+
+            elif isinstance(cmd, IpcqRecvCmd):
+                if self._ipcq_id not in self._out_ports:
+                    raise RuntimeError(
+                        f"PE_IPCQ port {self._ipcq_id!r} not wired to runner"
+                    )
+                done_evt = env.event()
+                req = IpcqRequest(command=cmd, done=done_evt)
+                yield self._out_ports[self._ipcq_id].put(req)
+                yield done_evt
+                # Read actual data from MemoryStore at the slot address
+                data = None
+                src_space = req.result_data.get("src_space", "tcm")
+                src_addr = req.result_data.get("src_addr", 0)
+                if self._store is not None:
+                    try:
+                        data = self._store.read(
+                            src_space, src_addr,
+                            shape=cmd.shape, dtype=cmd.dtype,
+                        )
+                    except KeyError:
+                        pass
+                # Build result dict for tl.recv to wrap in TensorHandle
+                result = {
+                    "data": data,
+                    "src_space": src_space,
+                    "src_addr": src_addr,
+                    "direction": req.result_data.get("direction", cmd.direction),
+                    "dtype": cmd.dtype,
+                    "shape": cmd.shape,
+                    "nbytes": req.result_data.get("nbytes", 0),
+                }
+                cmd = _switch_kernel(result)
+
+            elif isinstance(cmd, tuple) and len(cmd) == 2 and cmd[0] == "recv_async":
+                # Non-blocking recv: post the IpcqRequest now, store the
+                # event in the future, return None to kernel.
+                future: RecvFuture = cmd[1]
+                done_evt = env.event()
+                req = IpcqRequest(command=future.cmd, done=done_evt)
+                future.request = req
+                future.event = done_evt
+                yield self._out_ports[self._ipcq_id].put(req)
+                cmd = _switch_kernel(None)
+
+            elif isinstance(cmd, tuple) and len(cmd) == 2 and cmd[0] == "recv_wait":
+                future = cmd[1]
+                if not future.event.triggered:
+                    yield future.event
+                req = future.request
+                src_space = req.result_data.get("src_space", "tcm")
+                src_addr = req.result_data.get("src_addr", 0)
+                data = None
+                if self._store is not None:
+                    try:
+                        data = self._store.read(
+                            src_space, src_addr,
+                            shape=future.cmd.shape, dtype=future.cmd.dtype,
+                        )
+                    except KeyError:
+                        pass
+                result = {
+                    "data": data,
+                    "src_space": src_space,
+                    "src_addr": src_addr,
+                    "direction": req.result_data.get("direction", future.cmd.direction),
+                    "dtype": future.cmd.dtype,
+                    "shape": future.cmd.shape,
+                    "nbytes": req.result_data.get("nbytes", 0),
+                }
+                cmd = _switch_kernel(result)

             else:
                 # Unknown command — pass through as blocking
@@ -185,7 +291,7 @@ class KernelRunner:
                 )
                 yield self._out_ports[self._scheduler_id].put(pe_txn)
                 yield done_evt
-                cmd = g.switch()
+                cmd = _switch_kernel()

         # Wait remaining pending composites
         for evt in pending.values():
@@ -17,6 +17,7 @@ from __future__ import annotations
 import math
 from typing import Literal

+from kernbench.common.ipcq_types import IpcqRecvCmd, IpcqSendCmd, RecvFuture
 from kernbench.common.pe_commands import (
     CompletionHandle,
     CompositeCmd,
@@ -55,6 +56,8 @@ class TLContext:
         runner: Any = None,
         cube_id: int = 0,
         num_cubes: int = 1,
+        scratch_base: int = 0,
+        scratch_size: int = 1 << 20,  # 1 MiB per kernel invocation
     ) -> None:
         self._pe_id = pe_id
         self._num_programs = num_programs
@@ -65,6 +68,33 @@ class TLContext:
         self._handle_counter = 0
         self._completion_counter = 0
         self._runner = runner  # KernelRunner for greenlet mode (ADR-0020 D3)
+        # PE-local scratch allocator for math/compute output handles.
+        # Each binary/unary/reduction op auto-allocates a unique addr from
+        # this pool so the resulting TensorHandle can be the source of a
+        # later tl.send / tl.store. Cursor resets on every kernel invocation.
+        self._scratch_base = scratch_base
+        self._scratch_size = scratch_size
+        self._scratch_cursor = 0
+
+    def _scratch_alloc(self, nbytes: int) -> int:
+        """Allocate a unique scratch address for an output TensorHandle.
+
+        Returns 0 if no scratch base was configured (e.g. command-list mode);
+        in that case the resulting handle has addr=0 and cannot be used as a
+        send/store source. Greenlet/runner mode always supplies a base.
+        """
+        if self._scratch_base == 0:
+            return 0
+        # 16-byte alignment
+        aligned = (nbytes + 15) & ~15
+        addr = self._scratch_base + self._scratch_cursor
+        self._scratch_cursor += aligned
+        if self._scratch_cursor > self._scratch_size:
+            raise RuntimeError(
+                f"TLContext scratch overflow: requested {nbytes}B, "
+                f"used {self._scratch_cursor}/{self._scratch_size}B"
+            )
+        return addr

     @property
     def commands(self) -> list[PeCommand]:
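The scratch allocator in this hunk is a bump allocator with 16-byte rounding. A standalone sketch of the same idea follows; the class name `ScratchAllocator` and its `reset` method are hypothetical, and unlike the patch this variant checks capacity before advancing the cursor, so a failed allocation leaves the cursor unchanged:

```python
class ScratchAllocator:
    """Bump allocator sketch: hands out aligned, non-overlapping addresses."""

    def __init__(self, base: int, size: int, align: int = 16) -> None:
        # align is assumed to be a power of two so the mask trick below works.
        self.base = base
        self.size = size
        self.align = align
        self.cursor = 0

    def alloc(self, nbytes: int) -> int:
        # Round the request up to the next multiple of `align`.
        aligned = (nbytes + self.align - 1) & ~(self.align - 1)
        if self.cursor + aligned > self.size:
            raise RuntimeError(
                f"scratch overflow: requested {nbytes}B, "
                f"used {self.cursor}/{self.size}B"
            )
        addr = self.base + self.cursor
        self.cursor += aligned
        return addr

    def reset(self) -> None:
        # Called once per kernel invocation in this sketch.
        self.cursor = 0


pool = ScratchAllocator(base=0x1000, size=64)
first = pool.alloc(10)   # 10 B rounds up to 16 B
second = pool.alloc(1)   # 1 B also occupies a full 16 B slot
third = pool.alloc(32)   # fills the remaining 32 B
```

Whether the real allocator should check before or after advancing is a design choice; checking first keeps the reported usage consistent if the caller catches the error.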
@@ -93,11 +123,30 @@ class TLContext:

     def _make_handle(
         self, addr: int, shape: tuple[int, ...], dtype: str,
+        space: str = "tcm",
     ) -> TensorHandle:
         return TensorHandle(
             id=self._next_handle_id(),
             addr=addr, shape=shape, dtype=dtype,
             nbytes=self._nbytes(shape, dtype),
+            space=space,
+        )
+
+    def _make_compute_out(
+        self, shape: tuple[int, ...], dtype: str,
+    ) -> TensorHandle:
+        """Allocate an output TensorHandle in PE-local scratch (TCM space).
+
+        Used by math/compute ops so the result has a real address that can
+        be the source of a later send/store. The data field stays None in
+        Phase 1 — Phase 2 DataExecutor fills the actual ndarray.
+        """
+        nbytes = self._nbytes(shape, dtype)
+        addr = self._scratch_alloc(nbytes)
+        return TensorHandle(
+            id=self._next_handle_id(),
+            addr=addr, shape=shape, dtype=dtype,
+            nbytes=nbytes, space="tcm",
         )

     # ── Reference (no DMA, metadata only) ────────────────────────
@@ -124,20 +173,26 @@ class TLContext:
     def load(
         self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
     ) -> TensorHandle:
-        """Load tensor from HBM to TCM. Returns TensorHandle.
+        """Load tensor from HBM. Returns TensorHandle pointing at HBM[ptr].

         In greenlet mode: returns TensorHandle with actual numpy data.
         In command-list mode: returns TensorHandle with data=None.
+
+        The returned handle's ``space`` is "hbm" so subsequent ops (math,
+        send, store) using this handle as a source resolve via MemoryStore
+        at ``(hbm, ptr)`` — which is where the load's underlying data
+        actually lives in Phase 2 storage.
         """
         self._emit_dispatch_overhead()
-        handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype)
+        handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype, space="hbm")
         cmd = DmaReadCmd(handle=handle, src_addr=ptr, nbytes=handle.nbytes)
         data = self._emit(cmd)
         if data is not None:
-            # Greenlet mode: attach real data to handle
+            # Greenlet mode: attach real data to handle (preserve space)
             return TensorHandle(
                 id=handle.id, addr=handle.addr, shape=handle.shape,
                 dtype=handle.dtype, nbytes=handle.nbytes, data=data,
+                space=handle.space,
             )
         return handle

@@ -162,7 +217,7 @@ class TLContext:
             raise ValueError(f"dot shape mismatch: a.K={k} != b.K={k2}")
         out_shape = (*a.shape[:-2], m, n)
         out_dtype = a.dtype
-        out = self._make_handle(addr=0, shape=out_shape, dtype=out_dtype)
+        out = self._make_compute_out(shape=out_shape, dtype=out_dtype)
         self._emit_dispatch_overhead()
         self._emit(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
         return out
@@ -170,7 +225,7 @@ class TLContext:
     # ── MATH Engine: unary (blocking) ─────────────────────────────

     def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
-        out = self._make_handle(addr=0, shape=x.shape, dtype=x.dtype)
+        out = self._make_compute_out(shape=x.shape, dtype=x.dtype)
         self._emit_dispatch_overhead()
         self._emit(MathCmd(op=op, inputs=(x,), out=out))
         return out
@@ -203,7 +258,7 @@ class TLContext:
     ) -> TensorHandle:
         out_shape = list(x.shape)
         out_shape[axis] = 1
-        out = self._make_handle(addr=0, shape=tuple(out_shape), dtype=x.dtype)
+        out = self._make_compute_out(shape=tuple(out_shape), dtype=x.dtype)
         self._emit_dispatch_overhead()
         self._emit(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
         return out
@@ -222,7 +277,7 @@ class TLContext:
     def _binary_math(
         self, op: str, a: TensorHandle, b: TensorHandle,
     ) -> TensorHandle:
-        out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
+        out = self._make_compute_out(shape=a.shape, dtype=a.dtype)
         self._emit_dispatch_overhead()
         self._emit(MathCmd(op=op, inputs=(a, b), out=out))
         return out
@@ -230,15 +285,67 @@ class TLContext:
     def where(
         self, cond: TensorHandle, a: TensorHandle, b: TensorHandle,
     ) -> TensorHandle:
-        out = self._make_handle(addr=0, shape=a.shape, dtype=a.dtype)
+        out = self._make_compute_out(shape=a.shape, dtype=a.dtype)
         self._emit_dispatch_overhead()
         self._emit(MathCmd(op="where", inputs=(cond, a, b), out=out))
         return out

+    def maximum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
+        """Element-wise max of two tensors (real Triton: tl.maximum)."""
+        return self._binary_math("maximum", a, b)
+
+    def minimum(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
+        """Element-wise min of two tensors (real Triton: tl.minimum)."""
+        return self._binary_math("minimum", a, b)
+
+    def fma(
+        self, a: TensorHandle, b: TensorHandle, c: TensorHandle,
+    ) -> TensorHandle:
+        """Fused multiply-add: a * b + c (real Triton: tl.fma)."""
+        out = self._make_compute_out(shape=a.shape, dtype=a.dtype)
+        self._emit_dispatch_overhead()
+        self._emit(MathCmd(op="fma", inputs=(a, b, c), out=out))
+        return out
+
+    def clamp(
+        self,
+        x: TensorHandle,
+        min: TensorHandle,
+        max: TensorHandle,
+    ) -> TensorHandle:
+        """Clamp x to [min, max] (real Triton: tl.clamp)."""
+        out = self._make_compute_out(shape=x.shape, dtype=x.dtype)
+        self._emit_dispatch_overhead()
+        self._emit(MathCmd(op="clamp", inputs=(x, min, max), out=out))
+        return out
+
+    def softmax(self, x: TensorHandle, axis: int = -1) -> TensorHandle:
+        """Numerically-stable softmax along ``axis`` (real Triton: tl.softmax).
+
+        Implemented as a single MathCmd (op="softmax") so timing accounts
+        for one MATH dispatch; Phase 2 DataExecutor expands it to the
+        canonical (x - max) → exp → sum → div sequence.
+        """
+        out = self._make_compute_out(shape=x.shape, dtype=x.dtype)
+        self._emit_dispatch_overhead()
+        self._emit(MathCmd(op="softmax", inputs=(x,), out=out, axis=axis))
+        return out
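The `softmax` docstring above names the canonical (x - max), exp, sum, div sequence. A minimal pure-Python sketch of that sequence on a single row is shown below; `stable_softmax` is a hypothetical helper for illustration, and the DataExecutor's actual Phase 2 expansion is not shown in this diff:

```python
import math


def stable_softmax(row):
    """Softmax over a 1-D list of floats using the max-subtraction trick."""
    m = max(row)                              # step 1: row maximum
    exps = [math.exp(v - m) for v in row]     # step 2: exp(x - max), never overflows
    total = sum(exps)                         # step 3: normalizer
    return [e / total for e in exps]          # step 4: divide


small = stable_softmax([1.0, 2.0, 3.0])
# A naive math.exp(1000.0) raises OverflowError; the shifted form does not.
large = stable_softmax([1000.0, 1000.0])
```

Subtracting the maximum leaves the result mathematically unchanged, since the common factor exp(-max) cancels between numerator and denominator.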
+
+    # ── Scalar helpers (real Triton: tl.cdiv etc.) ────────────────
+
+    @staticmethod
+    def cdiv(a: int, b: int) -> int:
+        """Ceiling division: (a + b - 1) // b (real Triton: tl.cdiv).
+
+        Used by host/kernel grid math; not a tensor op, so no MathCmd
+        is emitted. Mirrors triton.cdiv.
+        """
+        return -(-int(a) // int(b))
+
     # ── Index / Scalar (PE_CPU, no engine) ────────────────────────

     def program_id(self, axis: int = 0) -> int:
-        """Return program instance index.
+        """Return program instance index (ADR-0022).

         axis=0: local PE id within cube.
         axis=1: cube id.
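The `cdiv` helper added in this hunk is documented as `(a + b - 1) // b` but implemented as `-(-a // b)`. A short check that the two forms agree for positive operands, with a grid-size example; the element count and block size are made-up illustrative values:

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division via floor division of the negated numerator."""
    return -(-int(a) // int(b))


def cdiv_additive(a: int, b: int) -> int:
    """The form quoted in the docstring; equivalent for positive a and b."""
    return (a + b - 1) // b


# How many 256-element blocks are needed to cover 1000 elements?
n_blocks = cdiv(1000, 256)

# Exhaustive agreement check over a small positive range.
forms_agree = all(
    cdiv(a, b) == cdiv_additive(a, b)
    for a in range(1, 200)
    for b in range(1, 50)
)
```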
@@ -248,7 +355,7 @@ class TLContext:
         return self._pe_id

     def num_programs(self, axis: int = 0) -> int:
-        """Return total number of program instances.
+        """Return total number of program instances (ADR-0022).

         axis=0: num PEs per cube.
         axis=1: num cubes.
@@ -284,6 +391,119 @@ class TLContext:
             dtype=x.dtype, nbytes=x.nbytes, data=x.data,
         )

+    # ── IPCQ (CCL) collective primitives (ADR-0023 D4) ────────────
+
+    def send(
+        self,
+        dir: str,
+        src: TensorHandle | None = None,
+        *,
+        src_addr: int | None = None,
+        nbytes: int | None = None,
+        shape: tuple[int, ...] | None = None,
+        dtype: str = "f16",
+        space: str = "tcm",
+    ) -> None:
+        """Send tensor data to the peer in the given direction.
+
+        Two calling forms:
+            tl.send(dir, handle)  # use handle's metadata
+            tl.send(dir, src_addr=..., nbytes=..., shape=..., dtype=..., space=...)
+
+        Blocking: returns when PE_IPCQ has accepted the request and
+        forwarded the IpcqDmaToken to PE_DMA. Backpressure may apply.
+        """
+        if src is not None:
+            src_addr = src.addr
+            nbytes = src.nbytes
+            shape = src.shape
+            dtype = src.dtype
+            space = getattr(src, "space", space)
+        if src_addr is None or nbytes is None or shape is None:
+            raise ValueError("tl.send: provide either a TensorHandle or src_addr/nbytes/shape")
+        self._emit_dispatch_overhead()
+        cmd = IpcqSendCmd(
+            direction=dir,
+            src_addr=src_addr, src_space=space,
+            nbytes=nbytes, shape=shape, dtype=dtype,
+            handle_id=self._next_handle_id(),
+        )
+        self._emit(cmd)
+
+    def recv(
+        self,
+        dir: str | None = None,
+        shape: tuple[int, ...] = (),
+        dtype: str = "f16",
+        space: str = "tcm",
+        dst_addr: int | None = None,
+        dst_space: str | None = None,
+    ) -> TensorHandle:
+        """Receive tensor data from a peer.
+
+        Args:
+            dir: specific direction (e.g. "W"), or None for round-robin.
+            shape, dtype: expected tensor metadata.
+            dst_addr / dst_space: if both are provided, the slot data is
+                copied to (dst_space, dst_addr) before the handle is
+                returned ("copy_to_dst" mode). Otherwise the slot address
+                is returned directly ("return_slot" mode).
+
+        Returns:
+            TensorHandle pointing to the slot (or dst) where the data has
+            arrived. In greenlet/runner mode, ``handle.data`` carries the
+            actual ndarray; in command-list mode the handle is a placeholder.
+        """
+        self._emit_dispatch_overhead()
+        if dst_addr is not None and dst_space is not None:
+            cmd = IpcqRecvCmd(
+                direction=dir,
+                shape=shape, dtype=dtype,
+                handle_id=self._next_handle_id(),
+                recv_mode="copy_to_dst",
+                dst_addr=dst_addr, dst_space=dst_space,
+            )
+        else:
+            cmd = IpcqRecvCmd(
+                direction=dir,
+                shape=shape, dtype=dtype,
+                handle_id=self._next_handle_id(),
+            )
+        result = self._emit(cmd)
+        if isinstance(result, dict):
+            slot_addr = int(result.get("src_addr", 0))
+            slot_space = str(result.get("src_space", "tcm"))
+            data = result.get("data")
+            return TensorHandle(
+                id=self._next_handle_id(),
+                addr=slot_addr,
+                shape=shape,
+                dtype=dtype,
+                nbytes=self._nbytes(shape, dtype),
+                data=data,
+                space=slot_space,
+            )
+        return self._make_handle(addr=0, shape=shape, dtype=dtype)
+
+    def recv_async(
+        self,
+        dir: str,
+        shape: tuple[int, ...] = (),
+        dtype: str = "f16",
+    ) -> "RecvFuture":
+        """Non-blocking recv. Returns a future to pass into ``tl.wait``."""
+        self._emit_dispatch_overhead()
+        cmd = IpcqRecvCmd(
+            direction=dir,
+            shape=shape, dtype=dtype,
+            handle_id=self._next_handle_id(),
+            blocking=False,
+        )
+        future = RecvFuture(cmd=cmd)
+        if self._runner is not None:
+            self._runner.switch_to_simpy(("recv_async", future))
+        return future
+
     # ── Composite + Control ───────────────────────────────────────

     def composite(
@@ -316,9 +536,40 @@ class TLContext:
         ))
         return completion

-    def wait(self, handle: CompletionHandle | None = None) -> None:
-        """Wait for a specific composite or all pending composites."""
+    def wait(self, handle: "CompletionHandle | RecvFuture | None" = None) -> Any:
+        """Wait for a composite, a recv future, or all pending composites.
+
+        - ``CompletionHandle`` (or None): wait for composite completion.
+        - ``RecvFuture``: wait for a non-blocking ``recv_async`` to finish.
+          Returns the resolved ``TensorHandle``.
+        """
+        if isinstance(handle, RecvFuture):
+            if handle.resolved:
+                return handle.result
+            if self._runner is None:
+                raise RuntimeError(
+                    "tl.wait(RecvFuture) requires runner mode (greenlet)"
+                )
+            result_dict = self._runner.switch_to_simpy(("recv_wait", handle))
+            slot_addr = int(result_dict.get("src_addr", 0))
+            slot_space = str(result_dict.get("src_space", "tcm"))
+            data = result_dict.get("data")
+            th = TensorHandle(
+                id=self._next_handle_id(),
+                addr=slot_addr,
+                shape=handle.cmd.shape,
+                dtype=handle.cmd.dtype,
+                nbytes=self._nbytes(handle.cmd.shape, handle.cmd.dtype),
+                data=data,
+                space=slot_space,
+            )
+            handle.resolved = True
+            handle.result = th
+            return th
+
+        # Composite path (existing behaviour)
         self._emit(WaitCmd(handle=handle))
+        return None

     def cycles(self, n: int) -> None:
         """Declare PE_CPU scalar execution overhead (cycles)."""
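The `RecvFuture` branch of `wait` resolves a future at most once and serves later calls from the cached result. A minimal sketch of that resolve-once pattern follows; `FutureSketch` and `wait_once` are hypothetical stand-ins, since the real `RecvFuture` lives in `kernbench.common.ipcq_types` and is not shown in this diff:

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class FutureSketch:
    """Stand-in for a recv future: only the caching fields are modelled."""
    resolved: bool = False
    result: Optional[Any] = None


def wait_once(future: FutureSketch, resolve: Callable[[], Any]) -> Any:
    """Resolve the future on first call; return the cached result afterwards."""
    if future.resolved:
        return future.result
    future.result = resolve()   # in the patch this is the switch into SimPy
    future.resolved = True
    return future.result


calls = []


def fake_resolve():
    # Records each invocation so the caching behaviour is observable.
    calls.append(1)
    return {"addr": 0x2000, "space": "tcm"}


fut = FutureSketch()
first_result = wait_once(fut, fake_resolve)
second_result = wait_once(fut, fake_resolve)
```

Caching on the future itself means a kernel can call the wait twice on the same future without posting a second request.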