ADR-0027: Megatron TP API + worker-wait generalization + mp.spawn

Implements ADR-0027 Phase 2 end-to-end. All 559 tests pass (was 523 +
1 xfail; ring_default_ws strict-xfail is now resolved).

D0 — Worker-wait generalization (context.py):
- _pending_worker_waits queue on RuntimeContext.
- ctx.wait(h) in worker context defers to main via g.parent.switch().
  Fast-path for already-completed handles.
- Worker API is unchanged: tensor deploy, launch, etc. still look
  synchronous; they're transparently cooperatively scheduled.
- Solves ADR-0024 Phase B kernel-greenlet orphan bug (env.run now
  only ever drives from main; kernel _parent is always main).

D0.5 — Host-read barrier (tensor.py):
- Explicit _HOST_READ_BARRIERS registry (T5.g closed-set via code
  review, not reflection-magic).
- numpy/data/__getitem__/__repr__ drain pending worker-waits before
  host-observable read.
- copy_: source-side barrier via source.numpy(). Target-side write barrier
  is intentionally NOT applied: a barrier keyed on global pending state
  would prematurely drain cross-rank collectives and deadlock.
- Collective pending is excluded from barrier drain condition
  (collective is cross-rank; its own yield in all_reduce covers the
  invariant naturally).

D1 — torch.multiprocessing.spawn (runtime_api/multiprocessing.py):
- API signature parity with real PyTorch spawn; execution is
  cooperative greenlet scheduler (process isolation etc. are explicit
  non-goals per D1.0).
- _drain_pending drains worker-waits then collectives in one barrier,
  loop-until-empty.
- Round-based exception handling with SystemExit sibling abort +
  SpawnException(errors) wrapping root-cause ranks.
- RuntimeContext attaches ctx.multiprocessing in __post_init__.
- benches/ccl_allreduce.py hand-rolled loop collapses to one
  torch.multiprocessing.spawn call.

D2–D6 — kernbench.tp package:
- parallel_state: initialize_model_parallel, get_*_rank,
  get_*_world_size, with weak active-ctx registry in context.py.
- layers: ColumnParallelLinear, RowParallelLinear (shape-only
  primitives — fp16 gemm via tl.load + tl.dot + tl.store).
- kernels: _gemm_kernel used by TP layers (self-contained; no bench
  dependency).
- primitives / mappings stubs per D6/D8.

Data-path fixes (surfaced by TP gemm + all_reduce sequence):
- sim_engine/op_log.py: dma_write snapshot is skipped for TCM sources.
  PE scratch is repopulated by Phase 2 math/gemm replay, so a
  Phase-1-time snapshot could capture STALE data left by a prior
  kernel's output aliased at the same scratch addr; the later kernel's
  dma_write would then overwrite the Phase 2 result with that stale
  value.
- sim_engine/op_log.py + sim_engine/data_executor.py: per-operand
  space recorded on GemmCmd and composite gemm records so HBM-resident
  operands (tl.load output) don't default to TCM during replay.
- runtime_api/context.py: ctx.zeros writes zero-init to MemoryStore
  at VA keys so kernels reading via VA see deterministic init even
  without explicit copy_().

Tests (Phase 1 + Phase 2):
- test_worker_wait_drain (T3): orphan invariant + resume + multi-rank
  drain + idempotency + exception propagation.
- test_mp_spawn (T4): spawn shape + bind + SpawnException scope.
- test_host_read_barrier (T5): barrier contract per entry-point +
  closed-set registry check.
- test_tp_parallel_state (T1): initialize + rank lookup.
- test_tp_layers (T2): shape + deterministic numerical correctness
  (concat-matmul equality for RowParallel, not mean-only).
- test_tp_mlp (T6): full 2-layer MLP with deterministic weight
  numerical match + rank-consistency post all-reduce.
- test_ccl_allreduce_matrix: ring_default_ws xfail removed (T7).

Regression: 523 pre + 35 new + 1 ex-xfail = 559 passed, 1 intentional
skip (T3.e historical failure documentation).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 16:31:13 -07:00
parent e7f376ebaa
commit 105f1dc09e
19 changed files with 1962 additions and 64 deletions
+57
@@ -42,6 +42,21 @@ def _numpy_to_dtype_str(np_dtype) -> str:
raise ValueError(f"unsupported numpy dtype: {np_dtype!r}")
# ADR-0027 D3: weak registry of the currently-active RuntimeContext so
# module-level helpers (e.g. ``kernbench.tp.parallel_state``) can resolve
# the ctx without threading it through every call.
import weakref as _weakref
_ACTIVE_CTX_REF: _weakref.ref | None = None
def _get_active_context():
"""Return the most-recently-entered RuntimeContext, or None."""
if _ACTIVE_CTX_REF is None:
return None
return _ACTIVE_CTX_REF()
class _AhbmNamespace:
"""torch.ahbm — per-greenlet SIP device binding (ADR-0024 D10).
@@ -89,6 +104,10 @@ class RuntimeContext:
_handles: list[RequestHandle] = field(default_factory=list, init=False)
_completed: set[RequestHandle] = field(default_factory=set, init=False)
# ADR-0027 D0.1: worker-deferred wait queue. When a worker greenlet
# calls ctx.wait(h), the handle is appended here and control yields to
# main. Main's scheduler drain consumes this list.
_pending_worker_waits: list[RequestHandle] = field(default_factory=list, init=False)
_allocators: dict[tuple[int, int, int], Any] = field(default_factory=dict, init=False)
_va_allocator: Any = field(default=None, init=False)
_tensor_counter: int = field(default=0, init=False)
@@ -109,6 +128,9 @@ class RuntimeContext:
# (PyTorch 2.x portable) namespaces for per-greenlet device binding.
self.ahbm = _AhbmNamespace()
self.accelerator = _AcceleratorNamespace(self.ahbm)
# ADR-0027 D1.3: torch.multiprocessing.spawn namespace.
from kernbench.runtime_api.multiprocessing import _MultiprocessingNamespace
self.multiprocessing = _MultiprocessingNamespace(self)
def install_ipcq(
self,
@@ -160,10 +182,16 @@ class RuntimeContext:
return plan
def __enter__(self):
global _ACTIVE_CTX_REF
_ACTIVE_CTX_REF = _weakref.ref(self)
return self
def __exit__(self, *exc):
global _ACTIVE_CTX_REF
self.cleanup()
# Clear active-context registry if we are it.
if _ACTIVE_CTX_REF is not None and _ACTIVE_CTX_REF() is self:
_ACTIVE_CTX_REF = None
return False
def submit(self, request: Any) -> RequestHandle:
@@ -178,10 +206,24 @@ class RuntimeContext:
return handle in self._completed
def wait(self, handle: RequestHandle, *, _meta: dict | None = None) -> Completion:
# ADR-0027 D0.2: fast-path for already-completed handles (avoid
# redundant worker→main→worker round-trip).
if handle in self._completed:
completion, trace = self.engine.get_completion(handle)
return completion
# ADR-0027 D0.2: if called from a worker greenlet (parent is main,
# not dead), defer the wait to the main scheduler — enqueue and
# yield. Main drains env.run, then switches back. On resume the
# handle must be in _completed (D0.3 resume invariant).
from greenlet import getcurrent
g = getcurrent()
if g.parent is not None and not g.parent.dead:
self._pending_worker_waits.append(handle)
g.parent.switch()
# Resume: main drained. Fall through to completion/trace assembly.
# Main context (or single-driver): drive engine directly.
wait_fn = getattr(self.engine, "wait", None)
if wait_fn is not None:
wait_fn(handle) # type: ignore[misc]
@@ -543,6 +585,21 @@ class RuntimeContext:
"sip": shard.sip, "cube": shard.cube, "pe": shard.pe,
"nbytes": shard.nbytes,
})
# ADR-0027: also populate MemoryStore at VA keys so kernels
# reading via VA (the common ``tl.load`` path) see the init
# data. Phase 1 MemoryWriteMsg writes via PA; kernels read via
# VA; Phase 2 DataExecutor reads via the addresses captured in
# op_log (VA for tl.load). Without this, zero-init tensors are
# invisible to kernels in Phase 2.
store = getattr(self.engine, "_memory_store", None)
if store is not None and pattern == "zero" and handle.va_base:
import numpy as np
from kernbench.runtime_api.tensor import _numpy_dtype
np_dtype = _numpy_dtype(dtype)
for shard in handle.shards:
count = shard.nbytes // itemsize
addr = handle.va_base + shard.offset_bytes
store.write("hbm", addr, np.zeros(count, dtype=np_dtype))
return t
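The weak active-context registry added above (the `__enter__` / `__exit__` changes plus `_get_active_context`) can be illustrated in isolation. This is a minimal sketch, not the project's code: `DemoContext`, `get_active_context` and `current_world_size` are hypothetical stand-ins, and only the standard-library `weakref` module is used.

```python
import weakref
from typing import Optional

# Single-slot registry holding a weak reference to the active context.
_ACTIVE_CTX_REF: Optional[weakref.ref] = None


class DemoContext:
    """Hypothetical stand-in for RuntimeContext (illustration only)."""

    def __init__(self, world_size: int) -> None:
        self.world_size = world_size

    def __enter__(self) -> "DemoContext":
        global _ACTIVE_CTX_REF
        _ACTIVE_CTX_REF = weakref.ref(self)
        return self

    def __exit__(self, *exc) -> bool:
        global _ACTIVE_CTX_REF
        # Clear the slot only if this context is the registered one.
        if _ACTIVE_CTX_REF is not None and _ACTIVE_CTX_REF() is self:
            _ACTIVE_CTX_REF = None
        return False


def get_active_context() -> Optional[DemoContext]:
    """Return the most recently entered context, or None."""
    return None if _ACTIVE_CTX_REF is None else _ACTIVE_CTX_REF()


def current_world_size() -> int:
    """Module-level helper that resolves the context without an argument."""
    ctx = get_active_context()
    if ctx is None:
        raise RuntimeError("no active context")
    return ctx.world_size
```

One consequence of a single-slot registry, under the assumptions of this sketch: when contexts are nested, leaving the inner one clears the slot instead of restoring the outer one, so module-level helpers stop resolving until a context is entered again.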
@@ -0,0 +1,152 @@
"""``torch.multiprocessing.spawn``-compatible namespace (ADR-0027 D1).

Real-PyTorch API *signature* parity only: the execution model is a cooperative
greenlet scheduler in a single Python process (D1.0). Non-goals: process
isolation, independent address space, failure isolation, OS-level scheduler
fairness, mp.Queue/Lock.

Attached to ``RuntimeContext`` as ``ctx.multiprocessing`` in
``__post_init__`` (D1.3).
"""
from __future__ import annotations

from typing import Any, Callable


class SpawnException(RuntimeError):
    """Raised from ``_MultiprocessingNamespace.spawn`` on worker failure.

    ``errors`` contains only root-cause ranks, i.e. the rank(s) whose body
    raised. Sibling greenlets terminated via ``throw(SystemExit)`` during
    cleanup are NOT recorded (SystemExit does not satisfy ``except
    Exception`` in the entry wrapper).
    """

    def __init__(self, errors: dict[int, Exception]):
        self.errors = errors
        first = next(iter(errors.items()), None)
        msg = (
            f"spawn failed on ranks {sorted(errors.keys())}"
            + (
                f": rank {first[0]} raised {first[1]!r}"
                if first is not None
                else ""
            )
        )
        super().__init__(msg)


def _drain_pending(ctx: Any) -> None:
    """Drain worker-wait + collective-pending queues in main context (D0.4/D0.5).

    Loop-until-empty: runs until both queues are simultaneously empty. Safe
    under the current model where main-context ``ctx.wait`` never re-enqueues
    (D0.5 main-context non-reentrance invariant); also safe under future
    extensions where drain can add sub-handles (SimPy causality gives finite
    depth).
    """
    distributed = getattr(ctx, "distributed", None)
    backend = getattr(distributed, "_backend", None) if distributed else None

    def _collective_nonempty() -> bool:
        if backend is None:
            return False
        pending = getattr(backend, "_pending_collective_handles", None)
        return bool(pending)

    while ctx._pending_worker_waits or _collective_nonempty():
        # (a) Worker-driven waits (D0.1). FIFO.
        while ctx._pending_worker_waits:
            h = ctx._pending_worker_waits.pop(0)
            if h not in ctx._completed:
                wait_fn = getattr(ctx.engine, "wait", None)
                if wait_fn is not None:
                    wait_fn(h)
                # Populate _completed so fast-path in ctx.wait short-circuits
                # on the return leg.
                ctx._completed.add(h)
        # (b) Collective backend queue (ADR-0024 D7 + D0.4-(2)).
        if backend is not None:
            pending_list = getattr(backend, "_pending_collective_handles", None)
            if pending_list is not None:
                while pending_list:
                    h, _sip_id, meta = pending_list.pop(0)
                    # Main context: ctx.wait drives engine directly and does
                    # NOT re-enqueue (D0.5 invariant).
                    ctx.wait(h, _meta=meta)


class _MultiprocessingNamespace:
    """torch.multiprocessing-compat facade bound to a RuntimeContext."""

    def __init__(self, ctx: Any) -> None:
        self._ctx = ctx

    def spawn(
        self,
        fn: Callable,
        args: tuple = (),
        nprocs: int = 1,
        join: bool = True,
    ) -> None:
        """Spawn ``nprocs`` worker greenlets, each calling ``fn(rank, *args)``.

        Mirrors the ``torch.multiprocessing.spawn`` signature (minus ``daemon``
        and ``start_method``). Runs the D0.4 round-robin scheduler loop until
        all workers finish, draining pending queues between rounds.
        """
        from greenlet import greenlet

        ctx = self._ctx
        dist = ctx.distributed
        gs: list = []
        errors: dict[int, Exception] = {}
        for rank in range(nprocs):
            def _entry(r: int = rank) -> None:
                try:
                    fn(r, *args)
                except Exception as e:
                    errors[r] = e
                    raise

            g = greenlet(_entry)
            if dist is not None and hasattr(dist, "_bind_rank"):
                dist._bind_rank(g, rank)
            gs.append(g)
        try:
            while True:
                alive = [g for g in gs if not g.dead]
                if not alive:
                    break
                for g in alive:
                    if not g.dead:
                        g.switch()
                _drain_pending(ctx)
        except Exception as outer:
            # D0.4-(4) sibling cleanup. Abort live greenlets, clear state.
            for other in gs:
                if not other.dead:
                    try:
                        other.throw(SystemExit)
                    except BaseException:
                        # SystemExit inherits BaseException; greenlet.throw
                        # re-raises in caller if target doesn't catch it.
                        # Silent: we're already in cleanup.
                        pass
            backend = getattr(dist, "_backend", None)
            if backend is not None:
                if hasattr(backend, "_barrier") and hasattr(backend._barrier, "reset"):
                    try:
                        backend._barrier.reset()
                    except Exception:
                        pass
                pending_collective = getattr(
                    backend, "_pending_collective_handles", None,
                )
                if pending_collective is not None:
                    pending_collective.clear()
            ctx._pending_worker_waits.clear()
            raise SpawnException(errors) from outer
        # ``join`` is accepted for signature parity only: the loop above always
        # runs every worker to completion, so join=False behaves like join=True.
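The round-based scheduling in `spawn` (switch into each live worker once, then drain pending waits in main, repeat) can be sketched without the `greenlet` dependency. The sketch below swaps in plain Python generators for greenlets: a worker `yield`s a handle where the real worker would call `ctx.wait(h)` and switch to its parent. All names here (`SpawnError`, `demo_spawn`, the workers) are hypothetical, and handles are just strings.

```python
from typing import Callable, Dict, Iterator, List, Set


class SpawnError(RuntimeError):
    """Hypothetical analogue of SpawnException: maps root-cause rank to error."""

    def __init__(self, errors: Dict[int, Exception]) -> None:
        self.errors = errors
        super().__init__(f"spawn failed on ranks {sorted(errors)}")


def demo_spawn(fn: Callable[..., Iterator[str]], args: tuple = (),
               nprocs: int = 1) -> Set[str]:
    """Run nprocs generator workers round-robin, draining between rounds."""
    completed: Set[str] = set()
    pending: List[str] = []            # analogue of _pending_worker_waits
    workers = {rank: fn(rank, *args) for rank in range(nprocs)}
    alive = dict(workers)
    while alive:
        for rank, gen in list(alive.items()):
            try:
                # Resume the worker until it defers its next wait to main.
                pending.append(next(gen))
            except StopIteration:
                del alive[rank]
            except Exception as exc:
                # Sibling cleanup: abort the other workers, clear state.
                for other in workers.values():
                    other.close()
                pending.clear()
                raise SpawnError({rank: exc}) from exc
        # Drain in main: every deferred handle completes before any resume.
        while pending:
            completed.add(pending.pop(0))
    return completed


def ok_worker(rank: int, log: List[str]) -> Iterator[str]:
    log.append(f"rank{rank}:submit")
    yield f"h{rank}"                   # defer the wait to main
    log.append(f"rank{rank}:resumed")


def failing_worker(rank: int, log: List[str]) -> Iterator[str]:
    yield f"h{rank}"
    if rank == 1:
        raise ValueError("boom")
    yield f"h{rank}-second"
```

In this sketch only the rank whose body raised appears in `errors`; siblings closed during cleanup are not recorded, which matches the root-cause-only scope described in the `SpawnException` docstring.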
+65
@@ -66,6 +66,57 @@ def _numpy_dtype(dtype: str) -> np.dtype:
return np.dtype(_NUMPY_DTYPE.get(dtype, np.float16))
# ADR-0027 T5.g: closed-set registry of host-read barrier entry-points.
# Any new Tensor API with host-observable read semantics must be added here
# AND implement the barrier call. Code review + this registry keep the set
# consistent (Python introspection-based auto-detection is a non-goal).
# Note on ``copy_``: the source read is barriered via ``source.numpy()``.
# A target-side write barrier was specified in an earlier revision of
# ADR-0027 D0.5 but is intentionally not applied (global-pending target
# barrier can prematurely drain cross-rank collectives → deadlock).
_HOST_READ_BARRIERS: frozenset[str] = frozenset({
"numpy",
"data",
"__getitem__",
"__repr__",
"copy_", # source-side via source.numpy(); target-side not barriered
})
def _host_read_barrier(tensor: "Tensor") -> None:
"""ADR-0027 D0.5: drain pending worker-wait queue before a host-observable
read (target-side writes are deliberately not barriered; see ``copy_``).
Scope: the barrier yields to main when ``ctx._pending_worker_waits`` is
non-empty AND the caller is a worker greenlet. Collective pending
(``backend._pending_collective_handles``) is **deliberately excluded**
from this check — collective handles represent cross-rank protocol that
must be drained only at scheduler synchronisation points (all workers
yielded). A collective's own yield (inside ``all_reduce``) already
ensures that once the collective call returns to the worker, post-drain
values are visible, so subsequent host reads see materialised data
without needing to trigger drain themselves. Including collective
pending here would cause an unrelated rank's barrier to prematurely
request drain of a cross-rank operation → deadlock.
No-op when called from main context or when the worker-wait queue is
empty (fast-path avoids needless context switches).
"""
ctx = None
if tensor._ctx_ref is not None:
ctx = tensor._ctx_ref()
if ctx is None:
return
worker_pending = getattr(ctx, "_pending_worker_waits", None)
if not worker_pending:
return # fast-path
from greenlet import getcurrent
g = getcurrent()
if g.parent is None or g.parent.dead:
return # main context: caller drains directly when needed
g.parent.switch()
def deploy_tensor(
*,
name: str,
@@ -217,7 +268,9 @@ class Tensor:
"""Read a shard-aligned slice. Returns a numpy array.
Mirrors ``torch.Tensor.__getitem__`` for the shard-aligned case.
ADR-0027 D0.5: host-read barrier.
"""
_host_read_barrier(self)
start, stop = self._resolve_shard_index(key)
shard = self._shard_for_range(start, stop)
if self._memory_store is None:
@@ -272,6 +325,8 @@ class Tensor:
def __repr__(self) -> str:
parts = [f"tensor(name={self.name}, shape={self.shape}, dtype={self.dtype}"]
if self._memory_store is not None and self._handle is not None:
# ADR-0027 D0.5: barrier on data-containing repr path.
_host_read_barrier(self)
arr = self.data
parts.append(f", mean={float(arr.mean()):.4g}, norm={float(np.linalg.norm(arr)):.4g}")
else:
@@ -308,7 +363,11 @@ class Tensor:
Mirrors ``torch.Tensor.numpy()``. In kernbench, sharded tensors are
gathered into a single full-shape ndarray according to each shard's
``offset_bytes`` / ``nbytes`` range.
ADR-0027 D0.5: acts as a host-read barrier: drains pending worker-waits
before reading, ensuring post-drain values. Pending collective handles
are deliberately excluded (see ``_host_read_barrier``).
"""
_host_read_barrier(self)
np_dtype = _numpy_dtype(self.dtype)
# Host-side tensor (created via torch.from_numpy) has no shards.
if self._host_buffer is not None:
@@ -340,6 +399,12 @@ class Tensor:
re-scattered into self's shard layout.
Shapes must match. Returns self.
ADR-0027 D0.5: source-side read barrier is triggered inside
``source.numpy()``. Target-side write barrier is not applied here
because it would require cross-rank coordination when other ranks
have pending collectives (see _host_read_barrier docstring on
collective pending being cross-rank).
"""
if self._handle is None or self._memory_store is None:
raise RuntimeError(
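The closed-set registry idea (every host-read entry-point is listed by name, and `copy_` barriers only its source) can be sketched as follows. This is an illustrative toy, not the project's test: `DemoTensor`, `missing_entry_points` and the call counter are hypothetical, and the real T5 registry check may work differently.

```python
from typing import FrozenSet, List

# Same entry-point names as the registry in the diff above.
HOST_READ_BARRIERS: FrozenSet[str] = frozenset(
    {"numpy", "data", "__getitem__", "__repr__", "copy_"}
)


class DemoTensor:
    """Hypothetical tensor that counts how often its read barrier fires."""

    def __init__(self, values: List[int]) -> None:
        self._values = list(values)
        self.barrier_calls = 0

    def _barrier(self) -> None:
        self.barrier_calls += 1

    def numpy(self) -> List[int]:
        self._barrier()
        return list(self._values)

    @property
    def data(self) -> List[int]:
        return self.numpy()

    def __getitem__(self, index: int) -> int:
        self._barrier()
        return self._values[index]

    def __repr__(self) -> str:
        self._barrier()
        return f"DemoTensor({self._values})"

    def copy_(self, source: "DemoTensor") -> "DemoTensor":
        # Source-side barrier only: it fires inside source.numpy().
        self._values = source.numpy()
        return self


def missing_entry_points(cls: type, registry: FrozenSet[str]) -> List[str]:
    """Names listed in the registry that the class does not define itself."""
    return sorted(name for name in registry if name not in vars(cls))
```

The check uses `vars(cls)` instead of `hasattr` so that inherited defaults such as `object.__repr__` do not count as an implementation.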
+11 -4
@@ -101,12 +101,19 @@ class DataExecutor:
p = op.params
if "src_a_addr" not in p:
return # composite record without full params
-space = p.get("addr_space", "tcm")
+default_space = p.get("addr_space", "tcm")
+# ADR-0027: per-operand + output spaces (fall back to single space
+# for legacy records without explicit space keys).
+src_a_space = p.get("src_a_space", default_space)
+src_b_space = p.get("src_b_space", default_space)
+dst_space = p.get("dst_space", default_space)
 dtype_in = p.get("dtype_in", "f16")
 dtype_out = p.get("dtype_out", dtype_in)
-a = self.store.read(space, p["src_a_addr"], shape=p.get("shape_a"), dtype=dtype_in)
-b = self.store.read(space, p["src_b_addr"], shape=p.get("shape_b"), dtype=dtype_in)
+a = self.store.read(src_a_space, p["src_a_addr"],
+shape=p.get("shape_a"), dtype=dtype_in)
+b = self.store.read(src_b_space, p["src_b_addr"],
+shape=p.get("shape_b"), dtype=dtype_in)
 # Compute in higher precision if specified
 dtype_acc = p.get("dtype_acc", "f32")
@@ -114,7 +121,7 @@ class DataExecutor:
 b_f = b.astype(_resolve_dtype(dtype_acc))
 result = np.matmul(a_f, b_f).astype(_resolve_dtype(dtype_out))
-self.store.write(space, p["dst_addr"], result)
+self.store.write(dst_space, p["dst_addr"], result)
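The per-operand space fallback above can be exercised with a toy store. This is a sketch under assumptions: `DemoStore` is a hypothetical `(space, addr)`-keyed dictionary standing in for `MemoryStore`, matrices are plain nested lists instead of numpy arrays, and dtype handling is omitted.

```python
from typing import Any, Dict, List, Tuple

Matrix = List[List[int]]


class DemoStore:
    """Hypothetical (space, addr)-keyed store; stands in for MemoryStore."""

    def __init__(self) -> None:
        self._cells: Dict[Tuple[str, int], Matrix] = {}

    def write(self, space: str, addr: int, value: Matrix) -> None:
        self._cells[(space, addr)] = value

    def read(self, space: str, addr: int) -> Matrix:
        return self._cells[(space, addr)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def replay_gemm(store: DemoStore, p: Dict[str, Any]) -> None:
    """Replay one gemm record, resolving each operand's space separately."""
    default_space = p.get("addr_space", "tcm")
    # Per-operand keys win; legacy records fall back to the single space.
    a = store.read(p.get("src_a_space", default_space), p["src_a_addr"])
    b = store.read(p.get("src_b_space", default_space), p["src_b_addr"])
    store.write(p.get("dst_space", default_space), p["dst_addr"], matmul(a, b))


def make_store() -> DemoStore:
    """Operands live in HBM, as they would after a tl.load."""
    store = DemoStore()
    store.write("hbm", 0x100, [[1, 2], [3, 4]])
    store.write("hbm", 0x200, [[5, 6], [7, 8]])
    return store
```

With operands resident in HBM, a record that carries `src_a_space` / `src_b_space` replays correctly, while a legacy record that only relies on the `"tcm"` default looks in the wrong space and fails the lookup in this toy model.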
def _execute_math(self, op: OpRecord) -> None:
"""Execute math op: unary, binary, or reduction."""
+43 -11
@@ -79,16 +79,24 @@ class OpLogger:
snaps.append(None)
params["input_snapshots"] = snaps
 elif op_name == "dma_write":
-try:
-arr = self._memory_store.read(
-params["src_space"], params["src_addr"],
-shape=params.get("shape"), dtype=params.get("dtype"),
-)
-params["snapshot"] = (
-arr.copy() if hasattr(arr, "copy") else arr
-)
-except Exception:
-params["snapshot"] = None
+# ADR-0027 fix: only snapshot HBM sources. TCM (PE scratch)
+# sources are repopulated by Phase 2 math/gemm replay;
+# capturing a Phase-1-time snapshot here would pick up stale
+# data from a PRIOR kernel's Phase 2 output that aliased the
+# same scratch address, causing the later kernel's replay
+# to write that stale value instead of the fresh math
+# result. See ADR-0027 postmortem (TP gemm → all_reduce).
+if params.get("src_space") == "hbm":
+try:
+arr = self._memory_store.read(
+params["src_space"], params["src_addr"],
+shape=params.get("shape"), dtype=params.get("dtype"),
+)
+params["snapshot"] = (
+arr.copy() if hasattr(arr, "copy") else arr
+)
+except Exception:
+params["snapshot"] = None
self._records.append(OpRecord(
t_start=pending["t_start"],
t_end=t,
@@ -167,6 +175,13 @@ def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
"dtype_in": msg.a.dtype,
"dtype_out": msg.out.dtype,
"m": msg.m, "k": msg.k, "n": msg.n,
# ADR-0027: preserve per-operand + output MemoryStore spaces so
# Phase 2 replay can resolve HBM-resident operands (e.g. tl.load
# results keep space="hbm"). Absent → DataExecutor falls back
# to the legacy single-space mode via ``addr_space``.
"src_a_space": getattr(msg.a, "space", "tcm"),
"src_b_space": getattr(msg.b, "space", "tcm"),
"dst_space": getattr(msg.out, "space", "tcm"),
}
if isinstance(msg, MathCmd):
return "math", msg.op, {
@@ -181,10 +196,27 @@ def _extract_op_info(msg: Any) -> tuple[str, str, dict[str, Any]]:
"axis": msg.axis,
}
if isinstance(msg, CompositeCmd):
-return "gemm" if msg.op == "gemm" else "math", f"composite_{msg.op}", {
+params: dict[str, Any] = {
 "op": msg.op,
 "out_addr": msg.out_addr,
 "out_nbytes": msg.out_nbytes,
 }
+# ADR-0027: preserve operand info so Phase 2 DataExecutor can replay
+# the composite's numerical effect (treat it like a GemmCmd).
+if msg.op == "gemm" and msg.a is not None and msg.b is not None:
+params.update({
+"src_a_addr": msg.a.addr,
+"src_b_addr": msg.b.addr,
+"shape_a": msg.a.shape,
+"shape_b": msg.b.shape,
+"dtype_in": msg.a.dtype,
+"dtype_out": msg.a.dtype,
+"src_a_space": getattr(msg.a, "space", "hbm"),
+"src_b_space": getattr(msg.b, "space", "hbm"),
+"dst_space": "hbm",
+# dst_addr alias so DataExecutor._execute_gemm picks it up.
+"dst_addr": msg.out_addr,
+})
+return "gemm" if msg.op == "gemm" else "math", f"composite_{msg.op}", params
# Fallback for unknown data_op messages
return "unknown", type(msg).__name__, {}
+21
@@ -0,0 +1,21 @@
"""kernbench.tp — Megatron-style Tensor Parallelism (ADR-0027).

Public API re-exports.
"""
from kernbench.tp.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
)
from kernbench.tp.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    initialize_model_parallel,
)
__all__ = [
    "ColumnParallelLinear",
    "RowParallelLinear",
    "get_tensor_model_parallel_rank",
    "get_tensor_model_parallel_world_size",
    "initialize_model_parallel",
]
+23
@@ -0,0 +1,23 @@
"""Kernel used by ``kernbench.tp`` layers (ADR-0027 D4/D5).

Intentionally self-contained inside the ``tp`` package: the ``tp`` package
must not import from ``benches/``. Future work: move to a shared
``kernbench.kernels`` module so benches and TP can share.
"""
from __future__ import annotations


def _gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE: str = "f16") -> None:
    """Single-PE GEMM: out = a @ b via load → dot → store.

    Uses the ``tl.load + tl.dot + tl.store`` path. Unlike ``tl.composite``
    (which is absorbed by the PE scheduler into TileTokens that don't reach
    the op_log), this path emits explicit ``DmaReadCmd`` / ``GemmCmd`` /
    ``DmaWriteCmd`` records, which DataExecutor replays numerically in
    Phase 2.
    """
    M, K, N = int(M), int(K), int(N)
    a = tl.load(int(a_ptr), shape=(M, K), dtype=DTYPE)
    b = tl.load(int(b_ptr), shape=(K, N), dtype=DTYPE)
    out = tl.dot(a, b)
    tl.store(int(out_ptr), out)
+150
@@ -0,0 +1,150 @@
"""Megatron-style parallel layers (ADR-0027 D4/D5).

- ``ColumnParallelLinear``: weight's out_features axis split across TP ranks.
  forward(x) is local gemm; no collective.
- ``RowParallelLinear``: weight's in_features axis split across TP ranks.
  forward(x) ends with ``dist.all_reduce`` to sum partial products.

Both layers use the intra-device ``DPPolicy`` (ADR-0026). TP shard
ownership is determined by ``torch.ahbm.set_device(rank)`` (ADR-0024 D10).

Yield-safety contract (ADR-0027 D4/D5): every forward path contains at
least one ``ctx.wait`` (via ``torch.launch``) or one collective; this
keeps the scheduler loop making progress.
"""
from __future__ import annotations

from typing import Any

from kernbench.policy.placement.dp import DPPolicy
from kernbench.tp.kernels import _gemm_kernel
from kernbench.tp.parallel_state import (
    get_tensor_model_parallel_world_size,
)


class ColumnParallelLinear:
    """Weight's K (out_features) axis distributed across TP ranks.

    forward(x):
        x: (M, N) — full-replicated across ranks
        W_k: (N, K / world_size) — this rank's slice (on its SIP)
        y_k = x @ W_k → (M, K / world_size)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        dtype: str = "f16",
        torch: Any = None,
    ) -> None:
        if torch is None:
            raise TypeError("ColumnParallelLinear requires torch=<RuntimeContext>")
        ws = get_tensor_model_parallel_world_size()
        if out_features % ws != 0:
            raise ValueError(
                f"out_features ({out_features}) must be divisible by TP world "
                f"size ({ws})"
            )
        self.in_features = in_features
        self.out_features = out_features
        self.k_local = out_features // ws
        self.dtype = dtype
        self._torch = torch
        # Per-rank weight slice. ``set_device(rank)`` (ADR-0024 D10) places
        # it on SIP ``rank``. Intra-SIP layout comes from DPPolicy (ADR-0026).
        self.weight = torch.zeros(
            (in_features, self.k_local),
            dtype=dtype,
            dp=DPPolicy(cube="replicate", pe="replicate",
                        num_cubes=1, num_pes=1),
            name="col_parallel_w",
        )
        # Bias omitted in initial scope (ADR-0027 D9).
        self.bias = None
        if bias:
            raise NotImplementedError(
                "bias=True is deferred (ADR-0027 D9 initial scope)"
            )

    def forward(self, x):
        M = int(x.shape[0])
        out = self._torch.empty(
            (M, self.k_local),
            dtype=x.dtype,
            dp=DPPolicy(cube="replicate", pe="replicate",
                        num_cubes=1, num_pes=1),
            name="col_parallel_out",
        )
        # Kernel args are (M, K, N) = (rows, contraction dim, output cols);
        # note the kernel's K/N naming differs from the class docstring.
        self._torch.launch(
            "col_parallel_gemm",
            _gemm_kernel,
            x, self.weight, out,
            M, self.in_features, self.k_local,
        )
        return out


class RowParallelLinear:
    """Weight's N (in_features) axis distributed across TP ranks.

    forward(x):
        x: (M, N / world_size) — rank-local slice (ColumnParallel output)
        W_k: (N / world_size, K) — this rank's slice
        y_k = x @ W_k → (M, K) — partial sum
        y = all_reduce(y_k, op="sum") → (M, K) on every rank
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        dtype: str = "f16",
        torch: Any = None,
    ) -> None:
        if torch is None:
            raise TypeError("RowParallelLinear requires torch=<RuntimeContext>")
        ws = get_tensor_model_parallel_world_size()
        if in_features % ws != 0:
            raise ValueError(
                f"in_features ({in_features}) must be divisible by TP world "
                f"size ({ws})"
            )
        self.in_features = in_features
        self.out_features = out_features
        self.n_local = in_features // ws
        self.dtype = dtype
        self._torch = torch
        self.weight = torch.zeros(
            (self.n_local, out_features),
            dtype=dtype,
            dp=DPPolicy(cube="replicate", pe="replicate",
                        num_cubes=1, num_pes=1),
            name="row_parallel_w",
        )
        self.bias = None
        if bias:
            raise NotImplementedError(
                "bias=True is deferred (ADR-0027 D9 initial scope)"
            )

    def forward(self, x):
        M = int(x.shape[0])
        y_partial = self._torch.empty(
            (M, self.out_features),
            dtype=x.dtype,
            dp=DPPolicy(cube="replicate", pe="replicate",
                        num_cubes=1, num_pes=1),
            name="row_parallel_partial",
        )
        # Kernel args are (M, K, N); the contraction dim is the local slice.
        self._torch.launch(
            "row_parallel_gemm",
            _gemm_kernel,
            x, self.weight, y_partial,
            M, self.n_local, self.out_features,
        )
        self._torch.distributed.all_reduce(y_partial, op="sum")
        return y_partial
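The numerical identity behind chaining `ColumnParallelLinear` into `RowParallelLinear` is that splitting the first weight by columns and the second by matching rows, then summing the per-rank partial products, reproduces the unsharded product. The following is a plain-Python sketch of that identity only; it does not use kernbench, the helper names are hypothetical, and the "all-reduce" is simulated by an in-process sum over ranks.

```python
from typing import List

Matrix = List[List[int]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def split_columns(w: Matrix, parts: int) -> List[Matrix]:
    """Column-parallel sharding: each rank gets a slice of out_features."""
    step = len(w[0]) // parts
    return [[row[i * step:(i + 1) * step] for row in w] for i in range(parts)]


def split_rows(w: Matrix, parts: int) -> List[Matrix]:
    """Row-parallel sharding: each rank gets a slice of in_features."""
    step = len(w) // parts
    return [w[i * step:(i + 1) * step] for i in range(parts)]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def tp_two_layer(x: Matrix, w1: Matrix, w2: Matrix, world_size: int) -> Matrix:
    """Simulate column-parallel then row-parallel layers across ranks."""
    partials = []
    for w1_k, w2_k in zip(split_columns(w1, world_size),
                          split_rows(w2, world_size)):
        h_k = matmul(x, w1_k)               # column-parallel: local gemm only
        partials.append(matmul(h_k, w2_k))  # row-parallel: partial product
    # Simulated all_reduce(op="sum"): every rank would end with this total.
    total = partials[0]
    for partial in partials[1:]:
        total = add(total, partial)
    return total
```

Integer inputs keep the comparison exact; with fp16 data, as in the layers above, the summation order may introduce small rounding differences between the sharded and unsharded results.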
+5
@@ -0,0 +1,5 @@
"""Forward/backward mappings stub (ADR-0027 — future backward work).
Inference-only initial scope. Backward hooks land when training simulation
arrives.
"""
+83
@@ -0,0 +1,83 @@
"""TP group state (ADR-0027 D3).

Single global TP group. Initial scope: TP size == world_size (pure TP;
mixed DP+TP is future work).
"""
from __future__ import annotations

_TP_WORLD_SIZE: int | None = None


def initialize_model_parallel(tensor_model_parallel_size: int) -> None:
    """Initialize the TP process group.

    Must be called after ``torch.distributed.init_process_group``.
    Only ``tensor_model_parallel_size == world_size`` is supported in the
    initial scope.
    """
    global _TP_WORLD_SIZE
    # The ctx is resolved lazily (the import lives inside _current_ctx) to
    # avoid an import cycle when tp is imported before a ctx exists.
    _ws = _current_world_size()
    if tensor_model_parallel_size != _ws:
        raise NotImplementedError(
            f"Only TP == world_size supported; got TP={tensor_model_parallel_size}, "
            f"world_size={_ws}"
        )
    _TP_WORLD_SIZE = tensor_model_parallel_size


def get_tensor_model_parallel_world_size() -> int:
    """Return the TP group's world size.

    Raises if not initialised — callers must call
    :func:`initialize_model_parallel` first.
    """
    if _TP_WORLD_SIZE is None:
        raise RuntimeError(
            "TP group not initialised; call initialize_model_parallel() first"
        )
    return _TP_WORLD_SIZE


def get_tensor_model_parallel_rank() -> int:
    """Return this worker's rank within the TP group.

    Delegates to the greenlet-local rank registered by the spawn launcher
    (ADR-0024 D9 via ``torch.distributed.get_rank``).
    """
    # Resolve via the global torch.distributed facade on the active ctx.
    return _current_rank()


def _reset_for_tests() -> None:
    """Clear _TP_WORLD_SIZE so ordering-sensitive tests can re-init."""
    global _TP_WORLD_SIZE
    _TP_WORLD_SIZE = None


# ── helpers (resolve current ctx) ────────────────────────────────────
def _current_ctx():
    """Best-effort resolution of the currently-active RuntimeContext.

    In KernBench, the ``ctx`` is passed as the ``torch`` positional in
    bench/worker code. Since parallel_state is a module-global helper,
    we look it up via a weak registry maintained by RuntimeContext.
    """
    from kernbench.runtime_api.context import _get_active_context

    ctx = _get_active_context()
    if ctx is None:
        raise RuntimeError(
            "No active RuntimeContext; kernbench.tp requires one "
            "(call init_process_group / spawn under a live ctx)"
        )
    return ctx


def _current_world_size() -> int:
    return _current_ctx().distributed.get_world_size()


def _current_rank() -> int:
    return _current_ctx().distributed.get_rank()
+34
@@ -0,0 +1,34 @@
"""TP primitive ops (ADR-0027 D6).

``copy_to_tp_region`` / ``reduce_from_tp_region`` are forward-only in the
initial scope (backward pass is future work). ``scatter`` / ``gather`` are
not implemented — they require an all-gather kernel that is not yet
available in KernBench (see ADR-0027 D9).
"""
from __future__ import annotations

from typing import Any


def copy_to_tp_region(x: Any) -> Any:
    """Forward: identity. Backward: all-reduce. (Training is future.)"""
    return x


def reduce_from_tp_region(x: Any, torch: Any) -> Any:
    """Forward: all-reduce. Backward: identity."""
    torch.distributed.all_reduce(x, op="sum")
    return x


def scatter_to_tp_region(x: Any) -> Any:
    raise NotImplementedError(
        "scatter_to_tp_region deferred — caller should create the sharded "
        "tensor directly (ADR-0027 D9)"
    )


def gather_from_tp_region(x: Any) -> Any:
    raise NotImplementedError(
        "gather_from_tp_region deferred — requires all-gather kernel (ADR-0027 D9)"
    )