ADR-0027: Megatron-style Tensor Parallelism API
Status
Accepted
Context
Goal
Support inter-SIP tensor parallelism (TP) via a Megatron-LM style explicit parallel layer API. Declarative abstractions like DTensor are future work in a separate ADR (0028).
Why Megatron-style was chosen:
- TP arises at specific layer boundaries of a model. Explicit primitives are natural to the mental model.
- The de-facto industry standard established by NVIDIA Megatron / DeepSpeed.
- DTensor is declarative, so its design space is larger → phased approach.
TP primitive spec (Megatron-LM reference)
- ColumnParallelLinear: shards the weight's column (out_features) axis across TP ranks. Input is full-replicated, output is column-sharded. When a RowParallelLinear follows, no forward all-reduce is required.
- RowParallelLinear: shards the weight's row (in_features) axis across TP ranks. Input is already column-sharded (the output of ColumnParallel). Requires an all-reduce at the end of forward.
- VocabParallelEmbedding: shards the embedding along the vocab axis. all-reduce at the end of forward. (A stub in the initial scope; full implementation requires an all-gather kernel as a prerequisite.)
- copy_to_tp_region, reduce_from_tp_region, scatter_to_tp_region, gather_from_tp_region: basic primitives.
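The column-then-row pairing described above can be checked numerically. The following is a minimal NumPy sketch, not KernBench code: the shapes, the seed, and the helper name `tp_mlp_reference` are illustrative assumptions, and the all-reduce is modeled as a plain sum over per-rank partial results.

```python
import numpy as np

def tp_mlp_reference(x, w1, w2, world_size):
    """Emulate ColumnParallel -> RowParallel on one host (hypothetical helper)."""
    # ColumnParallelLinear: shard w1 along its column (out_features) axis.
    w1_shards = np.split(w1, world_size, axis=1)
    # RowParallelLinear: shard w2 along its row (in_features) axis.
    w2_shards = np.split(w2, world_size, axis=0)
    # Each rank feeds its own column-sharded activation into its row shard;
    # no communication is needed between the two layers.
    partials = [(x @ a) @ b for a, b in zip(w1_shards, w2_shards)]
    # The all-reduce at the end of the row-parallel forward is a sum of partials.
    return sum(partials)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))
w1 = rng.standard_normal((8, 16))
w2 = rng.standard_normal((16, 4))
y_tp = tp_mlp_reference(x, w1, w2, world_size=4)
y_full = (x @ w1) @ w2   # unsharded reference
```

Comparing `y_tp` with `y_full` via `np.allclose` shows why a single all-reduce at the end of the row-parallel layer is sufficient when no nonlinearity mixes columns in between.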
Problems to solve
- Worker-wait generalization (D0): extend the defer/yield/drain pattern of dist.all_reduce to every ctx.wait path. The biggest architectural decision of this ADR.
- Launcher API normalization (D1): current benches use a hand-rolled greenlet loop. Absorb it into torch.multiprocessing.spawn(fn, args, nprocs) to preserve the real-PyTorch API surface and to concentrate D0's scheduler drain in a single implementation site.
- Per-rank weight shard representation: each worker owns its own slice of the weight tensor. Naturally expressed via ADR-0024's set_device(rank) + ADR-0026's intra-device DPPolicy.
- Forward-only scope: KernBench currently has no backward (simulation purposes). This ADR prioritizes forward only. Training simulation is a separate ADR.
- Collective call site: RowParallelLinear calls all_reduce at the end of forward. Naturally works with ADR-0024's multi-greenlet structure + D0 generalization.
- TP group concept: Megatron crosses DP × TP × PP groups. The initial scope simplifies to TP group = all SIPs. Mixed DP+TP is future work.
Decision
D0. Worker-wait generalization — ctx.wait defers to main when in worker context
Restating the problem. kernel_runner.run captures the greenlet.getcurrent()
at spawn time as the kernel greenlet's _parent
(kernel_runner.py:94).
If env.run runs in the main context, parent=main is safe. If env.run runs
in a worker context, parent=worker, and the moment the worker yields/finishes
the kernel greenlet becomes an orphan → GreenletExit → failure of ADR-0024
Phase B's ring_default_ws.
Resolution. When a worker greenlet calls ctx.wait(h), instead of driving
env.run directly, yield to the main scheduler. main drives env.run and,
once the handle completes, control returns to the worker.
D0.1 RuntimeContext extension
# context.py
from dataclasses import dataclass, field

@dataclass
class RuntimeContext:
    ...
    _pending_worker_waits: list[RequestHandle] = field(default_factory=list, init=False)
D0.2 ctx.wait worker fork
def wait(self, handle, *, _meta=None):
    # Fast path: already completed, so skip enqueue + switch (consistent with
    # D0.4-(3) idempotency). Avoids a needless worker→main→worker round-trip
    # and prevents redundant _pending_worker_waits growth.
    if handle in self._completed:
        completion, _trace = self.engine.get_completion(handle)
        return completion

    from greenlet import getcurrent
    g = getcurrent()
    if g.parent is not None and not g.parent.dead:
        # Worker greenlet: defer to main. Push handle, yield to parent.
        # Parent (scheduler loop) drains env.run, then switches back.
        self._pending_worker_waits.append(handle)
        g.parent.switch()
        # On resume: handle must have completed (main drained the list).
        # Fall through to the status-quo completion/trace assembly.

    # Main context (or single-driver): drive engine directly.
    wait_fn = getattr(self.engine, "wait", None)
    if wait_fn is not None:
        wait_fn(handle)
    completion, trace = self.engine.get_completion(handle)
    self._completed.add(handle)
    if _meta is not None and trace is not None:
        entry = dict(trace) if isinstance(trace, dict) else {"raw": trace}
        entry.update(_meta)
        self._traces.append(entry)
    return completion
D0.3 ctx.wait worker-context semantic contract (normative)
This ADR explicitly changes the semantics of ctx.wait in worker context.
- Submit-vs-complete separation: when called from a worker, ctx.wait(h) no longer guarantees "immediate completion" but instead guarantees "completion after the next scheduler drain". The point at which the worker returns from wait() = the point at which main has finished engine.wait for that handle. Main-context calls remain immediate-synchronous as before (status quo).
- Resume invariant (normative): at the point a worker resumes from a worker-deferred ctx.wait(h) (when g.parent.switch() returns), h in ctx._completed must be True. If this invariant breaks, the worker proceeds in a stale state, so whichever of _drain_pending / the scheduler loop / ctx.wait is modified, this invariant must be preserved. T3.b directly asserts this invariant.
- Observable change: the pattern h = ctx.submit(msg); ctx.wait(h); read(handle_result) inside a worker still holds, but the semantic spec now includes the fact that a main-drain is automatically inserted between wait() and read.
- Direct host-object reads: see D0.5. The contract for calling tensor.numpy() without ctx.wait is specified separately in D0.5.
D0.4 Main scheduler drain — protocol (normative)
(The internal implementation of D1's multiprocessing.spawn. Below is the
semantic definition.)
while alive:
    for g in alive:        # (1) round-based worker switch
        g.switch()
    _drain_pending(ctx)    # (2) drain in main context
(The actual definition of _drain_pending is in D0.5 — an outer while-loop
that drains until both queues are empty.)
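The round structure above can be modeled without greenlets. The sketch below is an illustration under stated assumptions, not the KernBench scheduler: Python generators stand in for worker greenlets, a `yield` stands in for a worker-deferred ctx.wait, and the names `run_rounds`, `worker`, and `engine_wait` are hypothetical.

```python
from collections import deque

def run_rounds(workers, engine_wait):
    """Round-based drive: give every live worker one turn, then drain in main."""
    pending = deque()      # stands in for ctx._pending_worker_waits
    completed = set()      # stands in for ctx._completed
    alive = list(workers)
    rounds = 0
    while alive:
        still_alive = []
        for w in alive:                 # (1) one switch turn per live worker
            try:
                handle = next(w)        # worker runs until it yields a handle
            except StopIteration:
                continue                # worker finished during its turn
            pending.append(handle)
            still_alive.append(w)
        while pending:                  # (2) FIFO drain in main context
            h = pending.popleft()
            if h not in completed:
                engine_wait(h)
                completed.add(h)
        alive = still_alive
        rounds += 1
    return completed, rounds

def worker(rank, steps):
    for i in range(steps):
        yield (rank, i)    # models ctx.wait(h): defer to main, resume after drain

log = []
done, rounds = run_rounds([worker(0, 2), worker(1, 3)], log.append)
```

In this model `log` records handles in submission order per round, which mirrors the rule that drain order follows submission order rather than completion order.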
Rules:
1. Round-based cooperative scheduling & yield obligation (worker contract). g.switch() does not return until the worker voluntarily yields (cooperative greenlet semantics). Therefore:
   - If a worker runs a pure-compute loop like while True: do_compute() without yielding, g.switch() never returns and the scheduler loop itself hard-blocks (other workers cannot get a switch turn, no drain occurs). This is not starvation but scheduler non-progress (deadlock equivalent), and this ADR classifies it as unsupported.
   - Workers must call one of ctx.wait(h), dist.all_reduce, or a host-read barrier (D0.5) within a finite number of steps. The forward of a TP layer includes a launch→wait pair at the end of every layer, so this condition is naturally met. CCL kernels also yield inside dist.all_reduce.
   - Implementations need not detect this (timeouts/steps-since-yield counters, etc.). It is a user contract; the symptom on violation is "simulation hang".
   - Future extension: if non-collective long compute paths become common, an explicit torch.distributed.cooperative_yield() primitive (no-op yield) could be introduced. Out of scope for this ADR. Not a breaking change; it can be added if needed.
   - Within a round, every alive worker receives one switch turn. Even if a single worker calls wait multiple times within one round, the calls are enqueued sequentially within that turn and processed in a single scheduler drain batch (FIFO).
2. Drain order = submission order (FIFO). _pending_worker_waits is strict FIFO via list append/pop(0). Drain occurs in submission order, not completion order, and SimPy's scheduler itself guarantees a causally correct completion order, so submission-order drain is safe. Do not confuse completion order with drain order.

   Two-queue ordering (worker waits → collectives): _drain_pending drains the worker wait queue first, then the collective queue. Rationale for this ordering:
   - The two queues are different dependency sources: worker waits are handles produced by a worker's own submit + wait pair (tensor deploy, MmuMap, etc.). The collective queue holds kernel-launch handles that dist.all_reduce enqueues internally, which the worker never directly waits on (see the two-queue drain model in D0.5).
   - Independent in correctness terms: from the worker's perspective, a collective is "already submitted, then yielded". Its completion timing only needs to precede the worker's next action. There is no ordering dependency with the worker wait queue.
   - Both finish within a single drain barrier: per D0.5's loop-until-empty rule, a single barrier invocation drains worker → collective → (repeat if new ones appeared) in that order. By the time the worker resumes, both sides are drained.
   - The alternative (collective first) is also valid: this ADR fixes worker-first only for current implementation simplicity; semantically they are equivalent. Revisit if a performance-profile difference is observed.
3. Duplicate enqueue: correctness via idempotent drain; dedup not guaranteed. ctx.wait(h) returns immediately if h in ctx._completed. _drain_pending uses the same guard. Even if the same handle is appended to _pending_worker_waits multiple times, engine.wait is invoked only once (idempotent).
   - Correctness: relies on idempotent drain → safe.
   - Memory/performance: this ADR does not guarantee dedup of _pending_worker_waits. If the same handle is enqueued N times, the queue retains N elements and drain performs N pops + in-set guards. Unless a single worker abnormally repeats waits on the same handle, N stays at the order of 1 to a few.
   - Implementation freedom: implementations may optionally dedup (e.g., hold a set as a side index, or check h not in pending_set before append). Classified as an optimization that does not change correctness.
4. Exception propagation + sibling cleanup. When a worker greenlet raises, g.switch() propagates the exception to main. The scheduler loop stops immediately and performs the following cleanup explicitly:

       try:
           while True:
               alive = [g for g in gs if not g.dead]
               if not alive:
                   break
               for g in alive:
                   if not g.dead:
                       g.switch()
               _drain_pending(ctx)
       except Exception as outer:
           # (a) Force-terminate surviving sibling worker greenlets.
           for other in gs:
               if not other.dead:
                   try:
                       other.throw(SystemExit)
                   except Exception:
                       pass  # silent: already in exceptional state
           # (b) Reset backend barrier / pending state
           #     (in preparation for future epoch barrier).
           backend = getattr(ctx.distributed, "_backend", None)
           if backend is not None and hasattr(backend, "_barrier"):
               backend._barrier.reset()
           backend_pending = getattr(backend, "_pending_collective_handles", None)
           if backend_pending is not None:
               backend_pending.clear()
           ctx._pending_worker_waits.clear()
           # (c) Wrap the originating exception in SpawnException.
           raise SpawnException(errors) from outer

   Protocol:
   - Sibling abort guarantee: when one worker raises, SystemExit is thrown into all sibling greenlets; the greenlets terminate immediately. No greenlet leaks.
   - Explicit pending-queue clear: both queues (worker-wait + collective-pending) are cleared. Prevents contamination on reuse.
   - SpawnException(errors) wrapping: errors: dict[int, Exception] contains the original exception per rank. Compatible with the failure pattern of real-PyTorch torch.multiprocessing.spawn.
     - Scope restriction: errors includes only ranks that raised from their own code (root cause). Ranks terminated via throw(SystemExit) during sibling cleanup do not appear in errors (SystemExit is not caught by D1.2's entry wrapper try/except Exception; this is intentional, since sibling termination is a cleanup signal, not a failure). Made explicit so readers do not expect "all failed ranks" to appear.
   - ctx._traces is the partial state up to the moment of exception. Trace completeness is not guaranteed (some launches/all_reduces may terminate without leaving an entry).
   - Allocator / MemoryStore remain in their pre-exception state. Reuse is a non-goal; creating a fresh RuntimeContext is recommended.
   - join=False / retry / partial recovery are non-goals for this ADR.

   SpawnException is defined in runtime_api/multiprocessing.py:

       class SpawnException(RuntimeError):
           def __init__(self, errors: dict[int, Exception]):
               self.errors = errors
               first = next(iter(errors.items()), None)
               msg = (f"spawn failed on ranks {sorted(errors.keys())}"
                      + (f": rank {first[0]} raised {first[1]!r}" if first else ""))
               super().__init__(msg)
5. Single-driver compatibility. In main-only execution where g.parent is None (legacy single-driver tests), D0.2's worker-fork condition is false → the existing immediate-synchronous path is preserved. _drain_pending is not invoked.
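The duplicate-enqueue rule above can be illustrated with a small stand-alone model. This is a sketch, not the KernBench drain: `drain_worker_waits` and its arguments are hypothetical names, plain lists and sets stand in for the context fields, and the sketch assumes the drain itself records completion in the set.

```python
def drain_worker_waits(pending, completed, engine_wait):
    """FIFO drain with the in-set guard; duplicates are popped but not re-waited."""
    pops = 0
    while pending:
        h = pending.pop(0)
        pops += 1
        if h not in completed:
            engine_wait(h)
            completed.add(h)   # assumption: the drain records completion here
    return pops

waited = []
completed = set()
pending = ["h1", "h2", "h1", "h1"]   # h1 enqueued three times, no dedup
pops = drain_worker_waits(pending, completed, waited.append)
```

The queue pays one pop per enqueued element, while the stand-in for engine.wait runs once per distinct handle. That is the split between the memory/performance cost and the correctness guarantee described in the rule.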
D0.5 Host-read barrier — decision (normative)
Inside a worker, host-observable reads such as tensor.numpy(),
tensor.__getitem__, and tensor.data are defined as automatic drain
barriers. Immediately before the call:
- If ctx._pending_worker_waits or backend._pending_collective_handles are non-empty, yield to main via g.parent.switch() → main runs _drain_pending → worker resumes after completion.
- If both queues are empty, read immediately.
Barrier iteration protocol (normative — re-entrance): _drain_pending
drains via a while-loop until both queues are completely empty, not in a
single pass:
def _drain_pending(ctx):
    while ctx._pending_worker_waits or (
        ctx.distributed._backend
        and ctx.distributed._backend._pending_collective_handles
    ):
        while ctx._pending_worker_waits:
            h = ctx._pending_worker_waits.pop(0)
            if h not in ctx._completed:
                ctx.engine.wait(h)
        backend = ctx.distributed._backend
        if backend is not None:
            while backend._pending_collective_handles:
                h, _sip_id, meta = backend._pending_collective_handles.pop(0)
                ctx.wait(h, _meta=meta)  # main context: safe; ctx.wait will
                                         # not push back to pending
Main-context ctx.wait non-recursion invariant (normative): the
ctx.wait(h, _meta=meta) call inside _drain_pending runs in the main
greenlet context. Because D0.2's worker-fork condition (g.parent is not None and not g.parent.dead) is False, it enters the immediate-synchronous
path → never enqueues to _pending_worker_waits. Thanks to this
invariant, the drain loop terminates without recursion / queue re-growth.
When implementing, it is important to maintain g.parent is None as the
single-main-greenlet guarantee.
Why a loop: ctx.wait(h, _meta=meta) is called in main context, so per
the D0.2 path it drives the engine directly (no additional enqueue — the
invariant above). In theory a single pass would suffice — but the protocol
is fixed at loop-until-empty. Reasons:
- Future-extension safety: a future implementation might enqueue new pending items mid-drain (e.g., tree-reduce collectives with sub-handles). The loop protocol preserves correctness in that case.
- Readability: the single sentence "the barrier drains until pending is empty" closes the semantics. No dependence on the non-trivial invariant that ctx.wait calls do not produce new enqueues.
- Barrier semantics are "all dependencies needed for this read are complete": in the current model all pending = all dependencies, so the two are identical. The user mental model is the former.
Termination guarantee: described under two regimes.
- Current implementation: when called in main context, ctx.wait drives the engine directly (D0.2) → does not enqueue new pending. Each iteration strictly shrinks pending size by pop(0) + engine.wait. The iteration count is bounded by the initial pending size itself → finite termination.
- Future extension (the bound that justifies the loop protocol): if an implementation that enqueues new pending mid-drain (e.g., tree-reduce sub-handles) is introduced, the initial-size bound breaks. However, SimPy causality guarantees that the dependency DAG of handles is finite, so nested depth is finite. The loop protocol automatically accommodates this case.
Both regimes guarantee that infinite loops are impossible. The single-pass bound of the current implementation is a reference value for aggressive optimization; the protocol is fixed at loop-until-empty.
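The future-extension regime can be sketched with a toy model in which completing a collective enqueues sub-handles. Everything here is hypothetical: `drain_until_empty`, the `children` mapping, and the handle names are illustrative, and no current KernBench collective is claimed to behave this way.

```python
def drain_until_empty(worker_q, collective_q, children, waited):
    """Outer loop repeats until both queues are empty (hypothetical model)."""
    passes = 0
    while worker_q or collective_q:
        passes += 1
        while worker_q:                       # worker waits first
            waited.append(worker_q.pop(0))
        while collective_q:                   # then collectives
            h = collective_q.pop(0)
            waited.append(h)
            # Hypothetical future behavior: finishing a collective enqueues
            # its sub-handles. `children` is a finite DAG, so the loop ends.
            worker_q.extend(children.get(h, []))
    return passes

children = {"ar0": ["ar0.a", "ar0.b"]}
waited = []
passes = drain_until_empty(["w0"], ["ar0"], children, waited)
```

A single pass would leave `ar0.a` and `ar0.b` pending in this model; the outer loop picks them up in a second pass and then exits.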
Why implicit drain at read is correct:
- In the original open question, the choice was between (a) implicit drain and (b) explicit barrier. (b) is clear but burdens TP layer users with the 3-step pattern out = fc1.forward(x); ctx.drain(); result = out.numpy() on every read. (a) is a single rule that "guarantees the read sees the reflected value", identical to CUDA's cudaDeviceSynchronize-before-host-copy pattern, which is not a hidden rule but the contract of a named entry point.
- This ADR adopts (a) but closes the entry-point list explicitly: Tensor.numpy(), Tensor.data (numpy alias), Tensor.__getitem__, Tensor.__repr__ (when data is included), and any other official host-read APIs are finalized via codebase search at the time of implementing this ADR. Any newly added host-read API must follow this contract (regression-guarded by tests).
- Even when calling numpy directly after only ctx.submit without wait, the drain barrier still operates (because the handle is in the pending queue). The invariant is restored at read time even if the user omits an explicit wait.
Tensor.copy_(source) — write barrier specification:
copy_ is semantically "write to target", but internally it calls
source.numpy() to fetch source data on the host then writes to each
shard via target._memory_store.write(...). Both directions are
barrier-handled:
- Source-side (read barrier): source.numpy() triggers the D0.5 read barrier (when source itself is a deployed tensor with pending).
- Target-side (write barrier, based on global pending): on copy_ entry, if ctx._pending_worker_waits or backend._pending_collective_handles are non-empty, drain via g.parent.switch() before writing. Not per-tensor / per-shard dependency tracking, but based on the global pending queue.
  - Why global: KernBench's handle representation does not retain the reverse-mapping information "this handle writes to which shard of which target". A safe conservative rule: "drain if any global pending exists". As a result, pending of an unrelated tensor can also block copy_; the drop-in invariant takes priority.
  - Explicit tradeoff: this rule can introduce unnecessary serialization between independent tensors. However, under the current single-queue execution model this cost is acceptable: guaranteeing cross-rank correctness and the "read sees latest" invariant via a simple rule takes precedence.
  - Practical impact: most pending of a single worker within a layer step is its own work. Extra context switches from over-barrier often coincide with the end-of-round scheduler drain point, so this is not a major issue.
  - Future refinement: per-tensor pending tracking could narrow this rule, but it is out of scope for this ADR.

Non-barrier:
- tensor.shape, tensor.dtype, tensor.name, and other metadata-only access does not drain. No data dependency.
- tensor.pa, tensor.va, and other raw address accessors also do not drain (address only, not content).
Official barrier entry-points (closed set):
| API | Kind | Rationale |
|---|---|---|
| Tensor.numpy() | read | host-observable copy |
| Tensor.data | read | numpy() alias |
| Tensor.__getitem__ | read | shard-aligned read |
| Tensor.__repr__ (when data is included) | read | debugging/log |
| Tensor.copy_(source) | read + write | source read + target write |
This contract is verified directly in T5/T6.
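One way to picture the closed-set rule is a decorator applied only to the listed entry points. The sketch below is an assumption-laden toy, not the KernBench Tensor: `ToyContext`, `ToyTensor`, and `host_read_barrier` are hypothetical names, the two pending queues are collapsed into one list, and `drain()` stands in for the yield to main followed by _drain_pending.

```python
import functools

class ToyContext:
    """Hypothetical stand-in for RuntimeContext: tracks pending handles only."""
    def __init__(self):
        self.pending = []
        self.drains = 0

    def drain(self):
        # In the ADR this is a yield to main followed by _drain_pending.
        self.pending.clear()
        self.drains += 1

def host_read_barrier(method):
    """Drain pending work before a host-observable read (sketch)."""
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if self.ctx.pending:       # the two-queue check, collapsed to one list
            self.ctx.drain()
        return method(self, *args, **kwargs)
    return wrapper

class ToyTensor:
    def __init__(self, ctx, values):
        self.ctx = ctx
        self._values = list(values)

    @property
    def shape(self):               # metadata only: no barrier
        return (len(self._values),)

    @host_read_barrier
    def numpy(self):               # closed-set entry point: barrier applies
        return list(self._values)

ctx = ToyContext()
t = ToyTensor(ctx, [1, 2, 3])
ctx.pending.append("h0")
shape_before = t.shape             # does not drain
drains_after_shape = ctx.drains
data = t.numpy()                   # drains, then reads
```

Keeping the barrier in one decorator makes the "any newly added host-read API must follow this contract" rule easy to regression-test: a test can enumerate the decorated methods and compare them against the table.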
D0.6 Why the worker function API is unchanged (informative)
- The inside of torch.zeros(...) is a self.submit(msg) + self.wait(h) pair. wait auto-defers to main per D0.2/D0.3: it appears synchronous from the outside but yields once.
- tensor.numpy() follows D0.5's host-read barrier → drain→read when pending exists, immediate read otherwise.
- dist.all_reduce continues to use the existing _defer_wait=True + _pending_collective_handles path. D0.4's drain processes both queues together.
D0.7 Invariants
- The kernel greenlet's _parent is always main: because env.run never runs in worker context. (Core assertion of T3.)
- Cross-rank synchronization point: drain occurs only after every worker has yielded → kernels of all ranks advance together within one round (a prerequisite for cross-rank IPCQ exchange).
- Single-driver compatibility: D0.4-(5).
D1. torch.multiprocessing.spawn(fn, args, nprocs)
Real-PyTorch API parity + a single implementation site for D0's scheduler loop.
D1.0 API parity only — not execution parity (normative)
The name torch.multiprocessing.spawn is restricted to API signature
parity. The actual execution model is a cooperative greenlet scheduler
(single Python process, single OS thread, round-robin drive per D0.4). The
following guarantees of real-PyTorch torch.multiprocessing.spawn are NOT
provided by this ADR and are explicit non-goals:
- Process isolation (independent OS process per rank).
- Independent address space (each rank with its own Python heap).
- Failure isolation (a hard crash in one rank not affecting others).
- OS-level scheduler fairness (preemptive time slicing between ranks).
- Inter-process primitives such as mp.Queue, mp.Lock.
Actual properties of this implementation:
- All ranks are greenlets inside the same Python process. Shared global state is visible as-is (intentional simulation convenience).
- Single-threaded under the GIL → not parallel execution. Only "logical concurrency" via SimPy event ordering is reproduced.
- Unhandled exception in any one worker → entire simulation aborts (D0.4-(4)).
Caller's obligation: when porting real-PyTorch multi-process samples to
KernBench, logic that relies on process isolation (e.g., os.getpid,
independent temp files, signal handling) must be removed. The namespace name
is preserved for code portability — semantics differ.
D1.1 Public surface
# runtime_api/multiprocessing.py (new)
class _MultiprocessingNamespace:
    def __init__(self, ctx):
        self._ctx = ctx

    def spawn(self, fn, args: tuple, nprocs: int, join: bool = True) -> None:
        """Spawn `nprocs` worker greenlets, each calling fn(rank, *args).

        Mirrors torch.multiprocessing.spawn signature (minus `daemon`).
        Drives the D0 scheduler loop until all workers finish.
        """
        ...
D1.2 Implementation
def spawn(self, fn, args, nprocs, join=True):
    from greenlet import greenlet
    ctx = self._ctx
    dist = ctx.distributed
    gs: list[greenlet] = []
    errors: dict[int, Exception] = {}

    for rank in range(nprocs):
        def _entry(r=rank):
            try:
                fn(r, *args)
            except Exception as e:
                errors[r] = e
                raise
        g = greenlet(_entry)
        dist._bind_rank(g, rank)
        gs.append(g)

    try:
        while True:
            alive = [g for g in gs if not g.dead]
            if not alive:
                break
            for g in alive:
                if not g.dead:
                    g.switch()
            _drain_pending(ctx)  # D0.5
    except Exception as outer:
        # Sibling cleanup per D0.4-(4)
        for other in gs:
            if not other.dead:
                try:
                    other.throw(SystemExit)
                except Exception:
                    pass
        backend = getattr(dist, "_backend", None)
        if backend is not None:
            if hasattr(backend, "_barrier"):
                backend._barrier.reset()
            if getattr(backend, "_pending_collective_handles", None) is not None:
                backend._pending_collective_handles.clear()
        ctx._pending_worker_waits.clear()
        raise SpawnException(errors) from outer
    # `join=True` semantics: we already wait for all workers.
D1.3 torch namespace attach
In runtime_api/context.py __post_init__:
self.multiprocessing = _MultiprocessingNamespace(self)
→ in bench code: torch.multiprocessing.spawn(worker, args=(ws,), nprocs=ws).
D1.4 Migration of existing benches
The hand-rolled loop in benches/ccl_allreduce.py collapses into a single
torch.multiprocessing.spawn line. Existing matrix regressions are
preserved. The currently xfail ring_default_ws is expected to flip to
PASS thanks to D0 (workers no longer orphan the kernel greenlet).
D2. New package kernbench.tp
src/kernbench/tp/
    __init__.py        # public API re-exports
    parallel_state.py  # TP group management (currently a single global group)
    layers.py          # ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
    primitives.py      # copy/reduce/scatter/gather_to/from_tp_region
    kernels.py         # gemm kernel launched by TP layers (reusable)
    mappings.py        # forward identity/all_reduce, backward stub
D3. parallel_state — TP group
# parallel_state.py
_TP_WORLD_SIZE = None

def initialize_model_parallel(tensor_model_parallel_size: int) -> None:
    """Initialize TP group. Must be called after dist.init_process_group."""
    global _TP_WORLD_SIZE
    from kernbench.runtime_api.distributed import get_dist  # or torch.distributed
    dist = get_dist()
    total = dist.get_world_size()
    if tensor_model_parallel_size != total:
        raise NotImplementedError(
            "Only TP == world_size supported in initial scope"
        )
    _TP_WORLD_SIZE = tensor_model_parallel_size

def get_tensor_model_parallel_world_size() -> int:
    return _TP_WORLD_SIZE

def get_tensor_model_parallel_rank() -> int:
    from kernbench.runtime_api.distributed import get_dist
    return get_dist().get_rank()  # ADR-0024 greenlet-local rank
Initial scope: TP size = world_size = topology SIP count. Pure TP model.
D4-pre. TP shard ownership vs DPPolicy — role separation (normative)
In the weight/output representation of TP layers, two concepts are clearly separated:
| Concept | Decided by | Scope |
|---|---|---|
| TP shard ownership (which rank owns which slice of the weight) | greenlet-local rank + torch.ahbm.set_device(rank) (ADR-0024 D2/D3) | cross-rank, cross-SIP |
| Intra-rank placement (how the owned slice is distributed across cube × PE inside the rank) | DPPolicy(cube=..., pe=...) (ADR-0026) | inside one rank (within SIP boundary) |
Thus when ColumnParallelLinear creates a weight of shape (in_features, out_features // ws) and assigns DPPolicy(cube="column_wise", pe="column_wise"):
- The slice owned by rank r = column-axis [r * k_local, (r+1) * k_local) of the weight — set_device(r) determines this (that rank resides on SIP r).
- Inside that slice, the cube × PE column-wise distribution — DPPolicy determines this.
The two axes are independent. If two ranks build their own slice with
the same DPPolicy, the slices themselves live on different SIPs but the
intra-SIP placement pattern is the same. Conversely, changing DPPolicy to
cube="replicate", pe="replicate" preserves TP shard ownership and only
changes intra-rank placement.
Mistakes that blur this boundary (forbidden by this ADR):
- The "SIP axis" reappearing in DPPolicy (removed in ADR-0026).
- TP layers expressing cross-rank sharding via DPPolicy alone without set_device → indistinguishable from a vertical split within a single rank.
The TP layers of this ADR always treat weight/output from the perspective of "rank = SIP = owns one slice + DPPolicy intra-SIP distribution" only.
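The rank-to-slice mapping stated above is plain arithmetic and can be written down independently of DPPolicy. The helper below is a hypothetical illustration, not part of kernbench.tp: the name `tp_column_range` and the example sizes are assumptions.

```python
def tp_column_range(rank, out_features, world_size):
    """Half-open column range [start, stop) owned by `rank` (hypothetical helper)."""
    if out_features % world_size != 0:
        raise ValueError("out_features must be divisible by world_size")
    k_local = out_features // world_size
    return rank * k_local, (rank + 1) * k_local

# Ownership for a 2048-column weight across 4 ranks (one rank per SIP).
ranges = [tp_column_range(r, out_features=2048, world_size=4) for r in range(4)]
```

The ranges tile the column axis without overlap. How each 512-column slice is then spread over cube × PE inside its SIP is a separate decision made by DPPolicy and does not affect these bounds.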
D4. ColumnParallelLinear
Important: no new host-side torch.matmul abstraction is introduced.
The layer's forward calls the existing gemm kernel via torch.launch("gemm", gemm_kernel, ...) — the pattern already used by KernBench benches
(benches/gemm_single_pe.py,
benches/gpt3_qkv.py).
# layers.py
from kernbench.policy.placement.dp import DPPolicy
from kernbench.tp.kernels import _gemm_kernel
from kernbench.tp.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)

class ColumnParallelLinear:
    """Shards the K(out_features) axis of the weight across TP ranks.

    forward(x):
        x:   (M, N), full-replicated across ranks
        W_k: (N, K / world_size), rank-local slice (placed on SIP r via set_device)
        y_k = x @ W_k → (M, K / world_size), rank-local output

    Output is column-sharded. The input form expected by RowParallelLinear.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 dtype: str = "f16", torch=None):
        ws = get_tensor_model_parallel_world_size()
        assert out_features % ws == 0
        self.in_features = in_features
        self.k_local = out_features // ws
        self._torch = torch
        # Each rank owns its own slice, placed on SIP r by set_device(rank).
        self.weight = torch.zeros(
            (in_features, self.k_local), dtype=dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="col_parallel_w",
        )
        self.bias = None
        if bias:
            self.bias = torch.zeros(
                (self.k_local,), dtype=dtype,
                dp=DPPolicy(cube="replicate", pe="replicate"),
                name="col_parallel_b",
            )

    def forward(self, x):
        # x is full-replicated (caller-guaranteed). Plain local gemm.
        M = x.shape[0]
        out = self._torch.empty(
            (M, self.k_local), dtype=x.dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="col_parallel_out",
        )
        self._torch.launch(
            "col_parallel_gemm", _gemm_kernel,
            x, self.weight, out, M, self.in_features, self.k_local,
        )
        # bias add as a separate kernel or as fused bias of a composite gemm.
        # Initial scope verifies bias=False sufficiently.
        return out
Yield-safety contract (normative): ColumnParallelLinear.forward
includes one torch.launch call containing a kernel launch → internal
ctx.wait pair. This automatically satisfies the "worker yields within a
finite number of steps" condition of D0.4-(1) — TP layer users do not need
to insert yield patterns manually.
D5. RowParallelLinear
class RowParallelLinear:
    """Shards the N(in_features) axis of the weight across TP ranks.

    forward(x):
        x:   (M, N / world_size), rank-local slice (output of ColumnParallel)
        W_k: (N / world_size, K), rank-local slice
        y_k = x @ W_k → (M, K), partial sum on each rank
        y = all_reduce(y_k, op="sum") → (M, K) on every rank
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 dtype: str = "f16", torch=None):
        ws = get_tensor_model_parallel_world_size()
        assert in_features % ws == 0
        self.n_local = in_features // ws
        self.out_features = out_features
        self._torch = torch
        self.weight = torch.zeros(
            (self.n_local, out_features), dtype=dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="row_parallel_w",
        )
        # bias only on rank 0 (Megatron convention). Omitted in initial scope.
        self.bias = None

    def forward(self, x):
        M = x.shape[0]
        y_partial = self._torch.empty(
            (M, self.out_features), dtype=x.dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="row_parallel_partial",
        )
        self._torch.launch(
            "row_parallel_gemm", _gemm_kernel,
            x, self.weight, y_partial, M, self.n_local, self.out_features,
        )
        # Cross-rank reduce. ADR-0024's dist.all_reduce works correctly
        # under D0 + mp.spawn (kernel parent = main is preserved).
        self._torch.distributed.all_reduce(y_partial, op="sum")
        return y_partial
Yield-safety contract (normative): RowParallelLinear.forward includes
launch → internal wait followed by all_reduce (defer + worker-yield
pattern), so at least 2 yields per forward are guaranteed. The
scheduler-progress condition of D0.4-(1) is automatically satisfied. All TP
layer forwards in this ADR maintain the invariant "yield-safe by containing
at least one wait or collective" — any future TP primitives (e.g.,
VocabParallelEmbedding) must keep the same contract.
D6. Primitive functions
# primitives.py

def copy_to_tp_region(x):
    """Forward: identity. Backward: all-reduce. (Implemented when training is added)."""
    return x


def reduce_from_tp_region(x, torch):
    """Forward: all-reduce. Backward: identity."""
    torch.distributed.all_reduce(x, op="sum")
    return x


def scatter_to_tp_region(x):
    raise NotImplementedError(
        "Phase 2: replaced by users creating already-sharded tensors"
    )


def gather_from_tp_region(x):
    raise NotImplementedError(
        "Phase 2: requires all-gather kernel as a prerequisite (future)"
    )
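For orientation, the two unimplemented primitives follow the Megatron-LM convention of splitting and concatenating along the last dimension. The sketch below shows that intended data movement on plain Python lists. `scatter_last_dim` and `gather_last_dim` are hypothetical names, and nothing here reflects how a future KernBench all-gather kernel would be written.

```python
def scatter_last_dim(x, rank, world_size):
    """Keep only this rank's contiguous chunk of each row (last-dim split)."""
    width = len(x[0])
    assert width % world_size == 0
    chunk = width // world_size
    return [row[rank * chunk:(rank + 1) * chunk] for row in x]


def gather_last_dim(shards):
    """Concatenate per-rank chunks, in rank order, along the last dimension."""
    n_rows = len(shards[0])
    return [sum((shard[i] for shard in shards), []) for i in range(n_rows)]


x = [[0, 1, 2, 3], [4, 5, 6, 7]]
shards = [scatter_last_dim(x, r, world_size=2) for r in range(2)]
restored = gather_last_dim(shards)
```

Scatter followed by gather is the identity, which is the property a later all-gather-based implementation would need to preserve.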
D7. Sample bench — 2-layer MLP with TP
# benches/tp_mlp.py (new)
from kernbench.policy.placement.dp import DPPolicy
import kernbench.tp as tp
import numpy as np


def worker(rank: int, world_size: int, torch):
    torch.ahbm.set_device(rank)
    tp.initialize_model_parallel(world_size)
    B, D_in, D_hidden, D_out = 1, 512, 2048, 512
    fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=torch)
    fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=torch)
    x = torch.zeros(
        (B, D_in), dtype="f16",
        dp=DPPolicy(cube="replicate", pe="replicate"),
        name="x",
    )
    # Initialize x with a deterministic pattern (here, a constant).
    x.copy_(torch.from_numpy(np.full((B, D_in), 0.1, dtype=np.float16)))
    h = fc1.forward(x)  # column-sharded (B, D_hidden / ws)
    y = fc2.forward(h)  # all-reduced (B, D_out) on every rank
    # Only rank 0 prints / verifies the result.
    if rank == 0:
        result = y.numpy()
        # With zero-init weights every value is 0; in the initial scope,
        # running to completion is itself the check.
        print(f" tp_mlp: shape={result.shape}, mean={float(result.mean()):.4f}")


def run(torch):
    torch.distributed.init_process_group(backend="ahbm")
    ws = torch.distributed.get_world_size()
    torch.multiprocessing.spawn(worker, args=(ws,), nprocs=ws)
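The per-rank tensor shapes in this bench can be derived by hand. The sketch below does that for the bench's dimensions. `tp_mlp_shapes` is a hypothetical helper, the world size of 4 is only an example value, and the (in_features, out_features / ws) layout for the ColumnParallelLinear weight is an assumption made by analogy with the (n_local, out_features) layout shown in D5.

```python
def tp_mlp_shapes(batch, d_in, d_hidden, d_out, world_size):
    """Per-rank shapes for ColumnParallelLinear followed by RowParallelLinear."""
    assert d_hidden % world_size == 0
    hidden_local = d_hidden // world_size
    return {
        "x": (batch, d_in),                   # replicated input
        "fc1_weight": (d_in, hidden_local),   # assumed (in, out / ws) layout
        "h": (batch, hidden_local),           # column-sharded activation
        "fc2_weight": (hidden_local, d_out),  # (n_local, out_features) as in D5
        "y": (batch, d_out),                  # all-reduced, same on every rank
    }


# Bench dimensions from worker(); world_size=4 is an example value.
shapes = tp_mlp_shapes(1, 512, 2048, 512, world_size=4)
```

The width of `h` matches the row count of `fc2_weight` on every rank, so the column-sharded output of fc1 feeds fc2 directly and the only cross-rank communication is the all-reduce inside fc2.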
D8. Non-functional — training not supported
This ADR covers inference (forward pass) only. Backward passes, gradients, and optimizers are future work. This scope is a natural fit because KernBench is not a training system.
D9. Initial-scope constraints
- TP size = world_size (no mixed DP+TP).
- scatter_to_tp_region and gather_from_tp_region are unimplemented.
- Default weight value is zero. Proper init schemes (Xavier, Kaiming, etc.) are future work. Tests inject deterministic non-zero patterns via tensor.copy_ to verify numerical correctness (T2/T6). That is, the operating rule is "production default = zero, verification = deterministic non-zero".
- Bias is omitted in the initial scope (Megatron's rank-0-only bias policy is future work).
- Pipeline parallelism is out of scope.
- VocabParallelEmbedding requires a prerequisite all-gather → stub only.
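As an illustration of the "verification = deterministic non-zero" rule, a test could derive every rank's weight slice from one global pattern, so that the sharded result can be compared against an unsharded reference. This is a sketch under that assumption only: `global_pattern` and `row_slice_for_rank` are hypothetical helpers, and the actual patterns used by T2/T6 are not shown in this section.

```python
def global_pattern(rows, cols):
    """Deterministic non-zero values: small integers in 1..7, never zero."""
    return [[(i * cols + j) % 7 + 1 for j in range(cols)] for i in range(rows)]


def row_slice_for_rank(full, rank, world_size):
    """Rows [rank * n_local, (rank + 1) * n_local) of the global weight."""
    assert len(full) % world_size == 0
    n_local = len(full) // world_size
    return full[rank * n_local:(rank + 1) * n_local]


# Global (in_features=4, out_features=3) weight, sharded over 2 ranks.
full = global_pattern(4, 3)
slices = [row_slice_for_rank(full, r, world_size=2) for r in range(2)]
```

Each rank would then load its slice into the zero-initialized weight through tensor.copy_, leaving the production default untouched.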
D10. Regression: ring_default_ws xfail removal — mandatory acceptance
Thanks to D0 (worker-wait generalization) + D0.5 (host-read barrier), every
worker-driven ctx.wait and host-read is routed through the main-drain path
→ the cause of the kernel-greenlet orphan in ADR-0024 Phase B disappears.
Flipping the existing matrix test's ring_default_ws strict-xfail case to
PASS after this ADR's implementation is included as a mandatory
regression criterion. Observable acceptance criteria are specified in
T7 (no deadlock, no GreenletExit, numerical tolerance, etc.).
Dependencies
- ADR-0024 (launcher): rank = SIP, greenlet-local rank, torch.ahbm.set_device(rank).
- ADR-0026 (DPPolicy intra-device): per-rank slice representation of weight tensors.
- ADR-0023 / ADR-0025 (IPCQ): foundation of the dist.all_reduce implementation.
Non-goals
- Backward pass / training: inference only. Training simulation is a separate ADR.
- Mixed parallelism (DP + TP + PP): pure TP only at the start.
- Weight init schemes: simple zero / debug pattern.
- Fused ops: Megatron's fused matmul+bias+gelu is a kernel-level concern.
- DTensor integration: ADR-0028 future.
- Host-side torch.matmul abstraction: TP layers call the existing gemm kernel via torch.launch(gemm_kernel, ...). No new matmul host-op is introduced.
Open questions
- Location of initialize_model_parallel: kernbench.tp.initialize_model_parallel (current decision) vs real-PyTorch's torch.distributed.init_device_mesh. Kept in the TP-only module.
- Weight init: the ADR uses zero. A debug pattern (e.g., identity) may be needed for valid verification; add it at Phase 1 test time if needed.
- Bias placement policy: Megatron places RowParallelLinear bias only on rank 0. Avoided in the initial scope via bias=False.
- GEMM kernel location: kernbench.tp.kernels._gemm_kernel vs importing from the existing benches/gemm_single_pe.py. TP must not depend on benches, so the kernel is duplicated inside tp. Migration to a shared kernbench.kernels package is possible later.
Resolved (previously open in earlier revisions):
- Drain timing on tensor.numpy() call → decided in D0.5: the official host-read entry points (numpy, data, __getitem__, data-containing __repr__) are automatic drain barriers. Metadata-only accessors are not barriers.
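The barrier rule can be pictured with a toy wrapper in which data reads drain pending work and metadata reads do not. This is a minimal sketch, not the D0.5 implementation: `HostTensor` and its `drain` callback are hypothetical, and `numpy()` returns a plain list here to keep the example dependency-free.

```python
class HostTensor:
    """Toy tensor whose data reads drain pending work first (hypothetical)."""

    def __init__(self, values, drain):
        self._values = list(values)
        self._drain = drain  # callable that runs queued device work

    @property
    def shape(self):
        # Metadata-only accessor: not a barrier, so no drain.
        return (len(self._values),)

    def numpy(self):
        self._drain()  # host read: barrier
        return list(self._values)

    @property
    def data(self):
        self._drain()  # host read: barrier
        return self._values

    def __getitem__(self, index):
        self._drain()  # host read: barrier
        return self._values[index]

    def __repr__(self):
        self._drain()  # data-containing repr: barrier
        return f"HostTensor({self._values})"


drain_calls = []
t = HostTensor([1, 2, 3], drain=lambda: drain_calls.append("drain"))
_ = t.shape
calls_after_shape = len(drain_calls)
_ = (t.numpy(), t.data, t[0], repr(t))
calls_after_reads = len(drain_calls)
```

Reading `shape` triggers no drain, while each of the four data reads triggers exactly one.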
Consequences
Positive
- Easy porting of Megatron code: API matches real training code.
- TP benchmarking enabled: research on scaling, communication-compute overlap, and other HW characteristics.
- ring_default_ws xfail removal: as a byproduct of D0, the ADR-0024 Phase B blocker is resolved.
- Scheduler-loop unification: introducing D1 (mp.spawn) removes the hand-rolled loop. Subsequent collective/TP benches reuse the same pattern.
- DPPolicy semantics clarified (synergy with ADR-0026): TP layers as a best-practice example of using intra-device DPPolicy only.
Negative
- Maintenance cost of a new module (kernbench.tp).
- Initial scope is limited (pure TP only, forward only).
- D0 generalization changes the semantics of ctx.wait; compatibility with single-driver tests must be explicitly verified (T7).
Neutral
- A pure upper layer added on top of ADR-0024/0026. No impact on the hardware-simulation stack (apart from D0).