kernbench2/docs/adr/ADR-0027-par-megatron-tp.md
ywkang a796c1d2f7 ADR: bilingual structure — EN canonical in adr/, KO mirror in adr-ko/
Establish English as the canonical ADR language with Korean translations
held in a parallel docs/adr-ko/ tree as derived artifacts (1:1 mirror).
Promotion from adr-proposed/ to adr/ now writes English to adr/ and the
Korean to adr-ko/; bidirectional sync rule documented in CLAUDE.md.

- Migrate 30 ADRs in docs/adr/: 28 Korean-only translated to English,
  2 bilingual pairs (ADR-0020, ADR-0023) consolidated (.en.md suffix
  dropped). ADR-0023 EN regenerated against KO source which had newer
  HW Realization Notes (D16-D23) section.
- docs/adr-history/ left frozen by design (transitional state).
- CLAUDE.md (Part 2): update ADR Lifecycle for 4-folder layout, mark
  docs/adr-ko/ as a Derived Artifact, add ADR Translation Discipline
  section covering bidirectional sync, conflict resolution (EN wins),
  and proposed-language freedom.
- tools/verify_adr_lang_pairs.py: new verification tool checking pair
  completeness, filename mirroring, ADR-ID match, Status byte-equality.
  Pre-commit hook intentionally not added; run on demand or in CI.
- tests/test_verify_adr_lang_pairs.py: 11 cases including CRLF/LF
  normalization, em-dash title separator, underscore-slug edge case.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 01:38:44 -07:00

# ADR-0027: Megatron-style Tensor Parallelism API
## Status
Accepted
## Context
### Goal
Support inter-SIP tensor parallelism (TP) through a **Megatron-LM-style explicit
parallel layer** API. Declarative abstractions such as DTensor are deferred to
a separate ADR (ADR-0028).
Why Megatron-style was chosen:
- TP arises at specific layer boundaries of a model, so explicit primitives
match the natural mental model.
- It is the de facto industry standard, established by NVIDIA Megatron-LM and
used by DeepSpeed.
- DTensor is declarative and has a larger design space, so it is deferred to a
later phase.
### TP primitive spec (Megatron-LM reference)
- **ColumnParallelLinear**: shards the weight's **column (out_features)** axis
across TP ranks. The input is fully replicated and the output is column-sharded.
When a RowParallelLinear follows, no forward all-reduce is required in between.
- **RowParallelLinear**: shards the weight's **row (in_features)** axis across
TP ranks. The input is already column-sharded (the output of ColumnParallelLinear).
Requires an **all-reduce** at the end of forward.
- **VocabParallelEmbedding**: shards the embedding along the vocab axis and
requires an all-reduce at the end of forward. (A stub in the initial scope; a
full implementation requires an all-gather kernel as a prerequisite.)
- **`copy_to_tp_region`**, **`reduce_from_tp_region`**, **`scatter_to_tp_region`**,
**`gather_from_tp_region`** — basic primitives.
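The Column → Row composition can be checked numerically. The following is a minimal sketch in plain Python (no KernBench, no NumPy); `tp_mlp` and its helpers are hypothetical names. Each simulated rank multiplies the replicated input by its column slice of `w1`, applies an element-wise activation (ReLU here) locally, and multiplies the result by its row slice of `w2`. A final sum plays the role of the all-reduce.

```python
def matmul(a, b):
    """(m, n) @ (n, k) -> (m, k) for row-major nested lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def relu(m):
    return [[max(v, 0.0) for v in row] for row in m]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def tp_mlp(x, w1, w2, ws):
    """Column-parallel w1 -> element-wise ReLU -> row-parallel w2 over `ws` ranks."""
    hidden = len(w1[0])
    assert hidden % ws == 0, "hidden size must divide evenly across ranks"
    k_local = hidden // ws
    partials = []
    for r in range(ws):
        lo, hi = r * k_local, (r + 1) * k_local
        w1_r = [row[lo:hi] for row in w1]   # ColumnParallel: column slice
        h_r = relu(matmul(x, w1_r))         # local, no communication needed
        w2_r = w2[lo:hi]                    # RowParallel: matching row slice
        partials.append(matmul(h_r, w2_r))  # partial sum held by rank r
    # all_reduce(op="sum"): every rank ends up holding this total.
    total = partials[0]
    for p in partials[1:]:
        total = add(total, p)
    return total
```

Because the activation is element-wise, applying it to a column shard gives the same values as applying it to the full intermediate and then slicing. This is why no communication is needed between the two layers, and only one reduction at the end.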
### Problems to solve
1. **Worker-wait generalization (D0)**: extend the defer/yield/drain pattern of
`dist.all_reduce` to every `ctx.wait` path. **The biggest architectural
decision of this ADR.**
2. **Launcher API normalization (D1)**: current benches use a hand-rolled
greenlet loop. Absorb it into `torch.multiprocessing.spawn(fn, args, nprocs)`
to preserve the real-PyTorch API surface and to concentrate D0's scheduler
drain in a single implementation site.
3. **Per-rank weight shard representation**: each worker owns its own slice of
the weight tensor. Naturally expressed via ADR-0024's `set_device(rank)` +
ADR-0026's intra-device DPPolicy.
4. **Forward-only scope**: KernBench currently has no backward pass (it targets
simulation). This ADR prioritizes **forward only**; training simulation belongs
to a separate ADR.
5. **Collective call site**: RowParallelLinear calls `all_reduce` at the end of
forward. This works naturally with ADR-0024's multi-greenlet structure plus the
D0 generalization.
6. **TP group concept**: Megatron composes DP × TP × PP groups. The initial
scope simplifies this to **TP group = all SIPs**. Mixed DP+TP is future work.
---
## Decision
### D0. Worker-wait generalization — `ctx.wait` defers to main when in worker context
**Restating the problem.** `kernel_runner.run` captures `greenlet.getcurrent()`
at spawn time and uses it as the kernel greenlet's `_parent`
([kernel_runner.py:94](src/kernbench/triton_emu/kernel_runner.py#L94)).
If `env.run` runs in the main context, parent=main and this is safe. If `env.run`
runs in a worker context, parent=worker, and as soon as the worker yields or
finishes, the kernel greenlet is orphaned and receives `GreenletExit`, which is
why ADR-0024 Phase B's `ring_default_ws` fails.
**Resolution.** When a worker greenlet calls `ctx.wait(h)`, it does not drive
`env.run` itself; it **yields to the main scheduler** instead. Main drives
`env.run` and, once the handle has completed, switches control back to the worker.
#### D0.1 `RuntimeContext` extension
```python
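# context.py
from __future__ import annotations

from dataclasses import dataclass, field

@dataclass
class RuntimeContext:
    ...  # existing fields elided
    # Handles deferred by worker greenlets to the main scheduler (FIFO, D0.4-(2)).
    # RequestHandle is the existing KernBench handle type.
    _pending_worker_waits: list[RequestHandle] = field(default_factory=list, init=False)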
```
#### D0.2 `ctx.wait` worker fork
```python
def wait(self, handle, *, _meta=None):
    # Fast path: already completed. Skip enqueue + switch (consistent with
    # D0.4-(3) idempotency). Avoids a needless worker→main→worker round-trip
    # and prevents redundant _pending_worker_waits growth.
    if handle in self._completed:
        completion, _trace = self.engine.get_completion(handle)
        return completion
    from greenlet import getcurrent
    g = getcurrent()
    if g.parent is not None and not g.parent.dead:
        # Worker greenlet: defer to main. Push handle, yield to parent.
        # Parent (scheduler loop) drains env.run, then switches back.
        self._pending_worker_waits.append(handle)
        g.parent.switch()
        # On resume, main has already driven this handle (D0.3 resume
        # invariant), so the engine is never driven from worker context.
    else:
        # Main context (or single-driver): drive engine directly.
        wait_fn = getattr(self.engine, "wait", None)
        if wait_fn is not None:
            wait_fn(handle)
    # Status-quo completion/trace assembly, shared by both paths.
    completion, trace = self.engine.get_completion(handle)
    self._completed.add(handle)
    if _meta is not None and trace is not None:
        entry = dict(trace) if isinstance(trace, dict) else {"raw": trace}
        entry.update(_meta)
        self._traces.append(entry)
    return completion
```
#### D0.3 `ctx.wait` worker-context semantic contract (normative)
This ADR **explicitly changes** the semantics of `ctx.wait` in worker context.
- **Submit-vs-complete separation**: when called from a worker, `ctx.wait(h)`
no longer guarantees "immediate completion" but instead guarantees
"completion **after the next scheduler drain**". The point at which the
worker returns from `wait()` = the point at which main has finished
`engine.wait` for that handle. Main-context calls remain immediate-synchronous
as before (status quo).
- **Resume invariant (normative)**: at the point a worker resumes from a
worker-deferred `ctx.wait(h)` (when `g.parent.switch()` returns),
**`h in ctx._completed` must be True**. If this invariant breaks, the worker
proceeds on stale state, so any change to `_drain_pending`, the scheduler
loop, or `ctx.wait` must preserve it. T3.b asserts this invariant directly.
- **Observable change**: the pattern `h = ctx.submit(msg); ctx.wait(h);
read(handle_result)` inside a worker still holds, but its semantics now
include a main-context drain that is automatically inserted between
`wait()` and the read.
- **Direct host-object reads**: the contract for calling `tensor.numpy()`
without `ctx.wait` is specified separately in D0.5.
#### D0.4 Main scheduler drain — protocol (normative)
(The internal implementation of D1's `multiprocessing.spawn`. Below is the
semantic definition.)
```python
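alive = [g for g in gs if not g.dead]
while alive:
    for g in alive:                          # (1) round-based worker switch
        g.switch()
    _drain_pending(ctx)                      # (2) drain in main context
    alive = [g for g in gs if not g.dead]    # drop finished workers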
```
(The actual definition of `_drain_pending` is in D0.5 — an outer while-loop
that drains until both queues are empty.)
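The round/drain protocol above can be modeled without greenlets. In this sketch (an illustrative model, not the KernBench implementation) each worker is a Python generator, `yield` plays the role of `g.parent.switch()`, and the hypothetical `Ctx.engine_wait` only records completion order instead of running SimPy. `drain_pending` is a single-queue simplification of D0.5's two-queue version.

```python
class Handle:
    """Hypothetical stand-in for a KernBench RequestHandle."""
    def __init__(self, name):
        self.name = name


class Ctx:
    """Hypothetical stand-in for RuntimeContext: one FIFO queue, one completed set."""
    def __init__(self):
        self.pending = []        # models _pending_worker_waits
        self.completed = set()   # models _completed
        self.log = []            # order in which the fake engine completed handles

    def engine_wait(self, h):
        # Stands in for engine.wait(h) driving env.run in main context.
        self.log.append(h.name)
        self.completed.add(h)


def wait(ctx, h):
    """Worker side of D0.2; `yield` plays the role of g.parent.switch()."""
    if h in ctx.completed:       # fast path: no enqueue, no switch
        return
    ctx.pending.append(h)
    yield                        # hand control to the main scheduler
    assert h in ctx.completed    # D0.3 resume invariant


def drain_pending(ctx):
    while ctx.pending:
        h = ctx.pending.pop(0)   # FIFO = submission order
        if h not in ctx.completed:
            ctx.engine_wait(h)


def spawn(ctx, fn, nprocs):
    """Round-based loop of D0.4: one turn per alive worker, then one drain."""
    alive = [fn(rank, ctx) for rank in range(nprocs)]
    rounds = 0
    while alive:
        for g in list(alive):
            try:
                next(g)          # run the worker until it yields
            except StopIteration:
                alive.remove(g)
        drain_pending(ctx)       # runs only after every worker had its turn
        rounds += 1
    return rounds


def worker(rank, ctx):
    for step in range(2):
        yield from wait(ctx, Handle(f"r{rank}s{step}"))
```

Both ranks' first handles are drained in the same batch, after every worker has had its turn; this is the cross-rank synchronization point described in D0.7.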
**Rules**:
1. **Round-based cooperative scheduling & yield obligation (worker contract)**.
   `g.switch()` does not return until the worker **voluntarily yields**
   (cooperative greenlet semantics). Therefore:
   - If a worker runs a pure-compute loop such as `while True: do_compute()`
     without yielding, `g.switch()` never returns and **the scheduler loop
     itself hard-blocks** (no other worker gets a switch turn and no drain
     occurs). This is not starvation but **scheduler non-progress
     (equivalent to a deadlock)**, and this ADR classifies it as
     **unsupported**.
   - Workers **must** call one of `ctx.wait(h)`, `dist.all_reduce`, or a
     host-read barrier (D0.5) within a finite number of steps. The `forward`
     of every TP layer ends with a launch→wait pair, so this condition is
     met naturally. CCL kernels also yield inside `dist.all_reduce`.
   - Implementations need not **detect** violations (timeouts,
     steps-since-yield counters, etc.). This is a user contract; the
     symptom of a violation is a simulation hang.
   - **Future extension**: if non-collective long compute paths become
     common, an explicit `torch.distributed.cooperative_yield()` primitive
     (a no-op yield) could be introduced. It is out of scope for this ADR
     and, being additive, not a breaking change.
   - Within a round, every alive worker receives one `switch` turn. A
     deferred `ctx.wait` ends the worker's turn (D0.2 switches to the
     parent), so a worker enqueues at most one worker-wait handle per turn;
     everything the workers enqueue during a round is processed in that
     round's single scheduler drain batch, in FIFO order.
2. **Drain order = submission order (FIFO)**. `_pending_worker_waits` is a
   strict FIFO (list `append` / `pop(0)`). Drain proceeds in submission
   order, not completion order; because SimPy's scheduler itself guarantees
   a causally correct completion order, submission-order drain is safe. Do
   not confuse *completion order* with *drain order*.

   **Two-queue ordering (worker waits → collectives)**: `_drain_pending`
   drains the worker-wait queue first, then the collective queue, for the
   following reasons:
   - **The two queues are different dependency sources**: worker waits are
     handles produced by a worker's own `submit` + `wait` pair (tensor
     deploy, MmuMap, etc.). The collective queue holds kernel-launch handles
     that `dist.all_reduce` enqueues internally and that the worker never
     waits on directly (see the two-queue drain model in D0.5).
   - **Independent in correctness terms**: from the worker's perspective, a
     collective is "already submitted, then yielded". Its completion only
     needs to precede the worker's next action, so there is no ordering
     dependency on the worker-wait queue.
   - **Both finish within a single drain barrier**: per D0.5's
     loop-until-empty rule, one barrier invocation drains worker →
     collective → (repeat if new entries appeared). By the time the worker
     resumes, both queues are empty.
   - **The alternative (collectives first) is also valid**: this ADR fixes
     worker-first only for implementation simplicity; the two orders are
     semantically equivalent. Revisit if a performance-profile difference
     is observed.
3. **Duplicate enqueue: correctness via idempotent drain; dedup not
   guaranteed**. `ctx.wait(h)` returns immediately if `h in ctx._completed`,
   and `_drain_pending` uses the same guard, so even if the same handle is
   appended to `_pending_worker_waits` multiple times, `engine.wait` is
   invoked only once for it (idempotent).
   - **Correctness**: relies on the idempotent drain, so it is safe.
   - **Memory/performance**: this ADR **does not guarantee dedup** of
     `_pending_worker_waits`. If the same handle is enqueued N times, the
     queue holds N entries and the drain performs N pops plus N membership
     checks. Unless a single worker abnormally repeats waits on the same
     handle, N stays between 1 and a few.
   - **Implementation freedom**: implementations may optionally dedup (e.g.,
     keep a `set` as a side index and check `h not in pending_set` before
     appending). This is an optimization that does not change correctness.
4. **Exception propagation + sibling cleanup**.
When a worker greenlet raises, `g.switch()` propagates the exception to
main. The scheduler loop stops immediately and performs the following
cleanup **explicitly**:
```python
try:
    while True:
        alive = [g for g in gs if not g.dead]
        if not alive:
            break
        for g in alive:
            if not g.dead:
                g.switch()
        _drain_pending(ctx)
except Exception as outer:
    # (a) Force-terminate surviving sibling worker greenlets.
    for other in gs:
        if not other.dead:
            try:
                other.throw(SystemExit)
            except (SystemExit, Exception):
                # SystemExit is a BaseException: if the sibling does not catch
                # it, greenlet re-raises it in the parent (main), so it must be
                # swallowed here explicitly. Silent: already in an exceptional state.
                pass
    # (b) Reset backend barrier / pending state (in preparation for a future epoch barrier).
    backend = getattr(ctx.distributed, "_backend", None)
    if backend is not None and hasattr(backend, "_barrier"):
        backend._barrier.reset()
    backend_pending = getattr(backend, "_pending_collective_handles", None)
    if backend_pending is not None:
        backend_pending.clear()
    ctx._pending_worker_waits.clear()
    # (c) Wrap the originating exception in SpawnException.
    #     `errors` is filled by the per-rank entry wrapper (D1.2).
    raise SpawnException(errors) from outer
```
Protocol:
- **Sibling abort guarantee**: when one worker raises, `SystemExit` is
thrown into every surviving sibling greenlet, which terminates immediately.
No greenlet leaks.
- **Explicit pending-queue clear**: both queues (worker-wait and
collective-pending) are cleared, preventing contamination on reuse.
- **`SpawnException(errors)` wrapping**: `errors: dict[int, Exception]`
contains the original exception per rank. Modeled on the failure-reporting
pattern of real-PyTorch `torch.multiprocessing.spawn`.
  - **Scope restriction**: `errors` includes **only ranks that raised
    from their own code (root cause)**. Ranks terminated via
    `throw(SystemExit)` during sibling cleanup do not appear in `errors`,
    because `SystemExit` is not caught by D1.2's entry wrapper (`try/except
    Exception`). This is intentional: sibling termination is a cleanup
    signal, not a failure. It is stated explicitly so readers do not expect
    "all failed ranks" to appear.
- **`ctx._traces` holds the partial state up to the moment of the exception**.
Trace completeness is not guaranteed (some launches/all_reduces may
terminate without leaving an entry).
- **Allocator / MemoryStore** remain in their pre-exception state. Reuse is
a non-goal; creating a fresh `RuntimeContext` is recommended.
- **`join=False` / retry / partial recovery** are non-goals for this ADR.
`SpawnException` is defined in `runtime_api/multiprocessing.py`:
```python
class SpawnException(RuntimeError):
    def __init__(self, errors: dict[int, Exception]):
        self.errors = errors
        first = next(iter(errors.items()), None)
        msg = (f"spawn failed on ranks {sorted(errors.keys())}"
               + (f": rank {first[0]} raised {first[1]!r}" if first else ""))
        super().__init__(msg)
```
5. **Single-driver compatibility**. In main-only execution where `g.parent is
None` (legacy single-driver tests), D0.2's worker-fork condition is false
→ the existing immediate-synchronous path is preserved. `_drain_pending`
is not invoked.
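The cleanup rules of D0.4-(4) can be sketched with generators standing in for greenlets. In this model `gen.close()` plays the role of `throw(SystemExit)`: it raises `GeneratorExit`, which, like `SystemExit`, derives from `BaseException` and therefore bypasses the `except Exception` entry wrapper, so only the root-cause rank lands in `errors`. The queue arguments and `worker` are hypothetical; `SpawnException` mirrors the definition above.

```python
class SpawnException(RuntimeError):
    """Same shape as the SpawnException defined in D0.4-(4)."""
    def __init__(self, errors):
        self.errors = errors
        first = next(iter(errors.items()), None)
        msg = (f"spawn failed on ranks {sorted(errors.keys())}"
               + (f": rank {first[0]} raised {first[1]!r}" if first else ""))
        super().__init__(msg)


def spawn(fn, nprocs, pending_waits, pending_collectives):
    """Generator-based model of D1.2's spawn, including D0.4-(4) cleanup."""
    errors = {}

    def entry(rank):
        try:
            yield from fn(rank)
        except Exception as e:   # GeneratorExit is a BaseException: not recorded
            errors[rank] = e
            raise

    gens = [entry(rank) for rank in range(nprocs)]
    alive = list(gens)
    try:
        while alive:
            for g in list(alive):
                try:
                    next(g)
                except StopIteration:
                    alive.remove(g)
            pending_waits.clear()    # stand-in for a successful drain
    except Exception as outer:
        for other in gens:
            other.close()            # cleanup signal, like throw(SystemExit)
        pending_waits.clear()
        pending_collectives.clear()
        raise SpawnException(errors) from outer


def worker(rank):
    yield                            # round 1: every rank yields once
    if rank == 1:
        raise ValueError("bad shard")  # root cause on rank 1 in round 2
    yield
```

Ranks 0 and 2 are closed while suspended and never appear in `errors`, and the `SpawnException` chains the root-cause exception via `from outer`.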
#### D0.5 Host-read barrier — decision (normative)
Inside a worker, **host-observable reads** such as `tensor.numpy()`,
`tensor.__getitem__`, and `tensor.data` are defined as **automatic drain
barriers**. Immediately before the call:
1. If `ctx._pending_worker_waits` or `backend._pending_collective_handles`
are non-empty, yield to main via `g.parent.switch()` → main runs
`_drain_pending` → worker resumes after completion.
2. If both queues are empty, read immediately.
**Barrier iteration protocol (normative — re-entrance)**: `_drain_pending`
drains via a while-loop **until both queues are completely empty**, not in a
single pass:
```python
def _drain_pending(ctx):
    # Loop until BOTH queues are empty (re-entrance safe; see below).
    while ctx._pending_worker_waits or (
        ctx.distributed._backend
        and ctx.distributed._backend._pending_collective_handles
    ):
        while ctx._pending_worker_waits:
            h = ctx._pending_worker_waits.pop(0)
            if h not in ctx._completed:
                ctx.engine.wait(h)
                # Mark completed so the D0.3 resume invariant holds when the
                # deferring worker resumes, and duplicates hit the guard.
                ctx._completed.add(h)
        backend = ctx.distributed._backend
        if backend is not None:
            while backend._pending_collective_handles:
                h, _sip_id, meta = backend._pending_collective_handles.pop(0)
                # Main context: safe. ctx.wait drives the engine directly
                # and never pushes back onto a pending queue.
                ctx.wait(h, _meta=meta)
```
**Main-context ctx.wait non-recursion invariant (normative)**: the
`ctx.wait(h, _meta=meta)` call inside `_drain_pending` runs in the main
greenlet context. Because D0.2's worker-fork condition (`g.parent is not
None and not g.parent.dead`) is False, it enters the immediate-synchronous
path → **never enqueues to `_pending_worker_waits`**. Thanks to this
invariant, the drain loop terminates without recursion / queue re-growth.
When implementing, it is important to maintain `g.parent is None` as the
single-main-greenlet guarantee.
**Why a loop**: `ctx.wait(h, _meta=meta)` is called in main context, so per
the D0.2 path it **drives the engine directly** (no additional enqueue — the
invariant above). In theory a single pass would suffice — but the protocol
is fixed at **loop-until-empty**. Reasons:
1. **Future-extension safety**: a future implementation might enqueue new
pending items mid-drain (e.g., tree-reduce collectives with sub-handles).
The loop protocol preserves correctness in that case.
2. **Readability**: the single sentence "the barrier drains until pending
is empty" fully specifies the semantics, without depending on the non-trivial
invariant that `ctx.wait` calls do not produce new enqueues.
3. **Barrier semantics are "all dependencies needed for this read are
complete"**: in the current model all pending = all dependencies, so the
two are identical. The user mental model is the former.
**Termination guarantee**: described under two regimes.
- **Current implementation**: when called in main context, `ctx.wait`
drives the engine directly (D0.2) → does not enqueue new pending. Each
iteration strictly shrinks pending size by `pop(0)` + `engine.wait`. The
iteration count is bounded by **the initial pending size itself** →
finite termination.
- **Future extension (the bound that justifies the loop protocol)**: if an
implementation that enqueues new pending items mid-drain (e.g., tree-reduce
sub-handles) is introduced, the initial-size bound no longer holds. However,
SimPy causality guarantees that the dependency DAG of handles is finite,
so **the nesting depth is finite**, and the loop protocol accommodates this
case automatically.
Under both regimes an infinite loop is impossible. The single-pass bound of
the current implementation is only a reference value for aggressive
optimization; the protocol remains fixed at loop-until-empty.
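The loop-until-empty protocol, including the future-extension case, can be exercised with a hypothetical model: `Ctx.drive` stands in for `engine.wait` / main-context `ctx.wait`, and the `follow_ups` map simulates a handle (for example a tree-reduce root) that enqueues new work when it completes. The duplicate `"w1"` also exercises the D0.4-(3) in-set guard.

```python
class Backend:
    """Hypothetical backend holding (handle, sip_id, meta) collective entries."""
    def __init__(self):
        self.pending_collectives = []   # models _pending_collective_handles


class Ctx:
    """Hypothetical context for exercising the loop-until-empty drain."""
    def __init__(self, backend):
        self.pending_waits = []         # models _pending_worker_waits
        self.completed = set()          # models _completed
        self.backend = backend
        self.order = []                 # completion log
        self.follow_ups = {}            # hypothetical: handle -> handles it enqueues

    def drive(self, h):
        """Stand-in for engine.wait / main-context ctx.wait (idempotent)."""
        if h in self.completed:
            return
        self.order.append(h)
        self.completed.add(h)
        # Future-extension case: completing h enqueues new work mid-drain.
        self.pending_waits.extend(self.follow_ups.get(h, ()))


def drain_pending(ctx):
    """Drain worker waits, then collectives, until both queues are empty."""
    passes = 0
    while ctx.pending_waits or ctx.backend.pending_collectives:
        passes += 1
        while ctx.pending_waits:
            ctx.drive(ctx.pending_waits.pop(0))
        while ctx.backend.pending_collectives:
            h, _sip_id, _meta = ctx.backend.pending_collectives.pop(0)
            ctx.drive(h)
    return passes
```

A single pass would return with `"w2"` still queued; the outer loop is what keeps the barrier's meaning of "all dependencies are complete".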
**Why implicit drain at read is correct**:
- In the original open question, the choice was between (a) implicit drain
and (b) explicit barrier. (b) is explicit but forces TP layer users into
the 3-step pattern `out = fc1.forward(x); ctx.drain(); result =
out.numpy()` on every read. (a) is a single rule that guarantees "the
read sees the updated value", analogous to the CUDA pattern of synchronizing
(`cudaDeviceSynchronize`) before a host copy: not a hidden rule but the
**contract of a named entry point**.
- This ADR adopts (a) but **closes the entry-point list explicitly**:
`Tensor.numpy()`, `Tensor.data` (numpy alias), `Tensor.__getitem__`, and
`Tensor.__repr__` (when data is included). Any other official host-read
APIs are identified by a codebase search when this ADR is implemented.
Every newly added host-read API must follow this contract
(regression-guarded by tests).
- Even when `numpy()` is called right after `ctx.submit` without a `wait`,
the drain barrier still operates, because the handle is in the pending
queue. The invariant is restored at read time even if the user omits an
explicit wait.
**`Tensor.copy_(source)` — write barrier specification**:
`copy_` is semantically "write to target", but internally it calls
`source.numpy()` to fetch source data on the host then writes to each
shard via `target._memory_store.write(...)`. Both directions are
barrier-handled:
1. **Source-side (read barrier)**: `source.numpy()` triggers the D0.5 read
barrier (when source itself is a deployed tensor with pending).
2. **Target-side (write barrier based on global pending)**: on `copy_`
   entry, if `ctx._pending_worker_waits` or
   `backend._pending_collective_handles` is non-empty, drain via
   `g.parent.switch()` before writing. The rule is **based on the global
   pending queue, not on per-tensor / per-shard dependency tracking**.
   - Why global: KernBench's handle representation does not retain the
     reverse mapping "this handle writes to which shard of which target".
     A safe, conservative rule is "drain if any global pending exists". As
     a result, **pending work for an unrelated tensor can also block
     `copy_`**; the drop-in invariant takes priority.
   - **Explicit tradeoff**: this rule can introduce unnecessary
     serialization between independent tensors. Under the current
     single-queue execution model this cost is acceptable: guaranteeing
     cross-rank correctness and the "read sees latest" invariant through a
     simple rule takes precedence.
   - Practical impact: within a layer step, most of a single worker's
     pending work is its own, and the extra context switches caused by
     over-barriering often coincide with the end-of-round scheduler drain,
     so the impact is minor.
   - Future refinement: per-tensor pending tracking could narrow this
     rule, but it is out of scope for this ADR.
**Non-barrier**:
- `tensor.shape`, `tensor.dtype`, `tensor.name`, and other
**metadata-only** access does not drain. No data dependency.
- `tensor.pa`, `tensor.va`, and other raw address accessors also do not
drain (address only, not content).
**Official barrier entry-points (closed set)**:
| API | Kind | Rationale |
|---|---|---|
| `Tensor.numpy()` | read | host-observable copy |
| `Tensor.data` | read | `numpy()` alias |
| `Tensor.__getitem__` | read | shard-aligned read |
| `Tensor.__repr__` (when data is included) | read | debugging/log |
| `Tensor.copy_(source)` | read + write | source read + target write |
This contract is verified directly in T5/T6.
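The barrier entry points in the table can be sketched as follows. `Runtime` and `Tensor` are hypothetical stand-ins: `Runtime.barrier()` replaces the real `g.parent.switch()` plus `_drain_pending` with simply emptying both queues, so the example shows only *when* a drain happens, not how.

```python
class Runtime:
    """Hypothetical stand-in for the global pending state seen by a worker."""
    def __init__(self):
        self.pending_waits = []         # models ctx._pending_worker_waits
        self.pending_collectives = []   # models backend._pending_collective_handles
        self.drains = 0

    def barrier(self):
        # In KernBench: g.parent.switch(), then main runs _drain_pending.
        # Here the "drain" just empties both queues.
        if self.pending_waits or self.pending_collectives:
            self.pending_waits.clear()
            self.pending_collectives.clear()
            self.drains += 1


class Tensor:
    """Hypothetical host-visible tensor with D0.5-style barrier entry points."""
    def __init__(self, rt, shape, values):
        self._rt = rt
        self.shape = shape              # metadata only: never a barrier
        self._values = list(values)

    def numpy(self):
        self._rt.barrier()              # read barrier
        return list(self._values)

    @property
    def data(self):
        return self.numpy()             # alias inherits the read barrier

    def __getitem__(self, i):
        self._rt.barrier()              # read barrier
        return self._values[i]

    def copy_(self, source):
        # Source side: a Tensor source goes through its own read barrier.
        vals = source.numpy() if isinstance(source, Tensor) else list(source)
        # Target side: write barrier keyed on *global* pending, not per tensor.
        self._rt.barrier()
        self._values = list(vals)
        return self
```

With an unrelated handle pending, `b.copy_([7, 8, 9])` still drains, which is the conservative global rule described for `copy_`; metadata access such as `a.shape` never does.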
#### D0.6 Why the worker function API is unchanged (informative)
- Internally, `torch.zeros(...)` is a `self.submit(msg)` + `self.wait(h)`
pair. `wait` auto-defers to main per D0.2/D0.3, so it appears synchronous
from the outside but yields once.
- `tensor.numpy()` follows D0.5's host-read barrier → drain→read when
pending exists, immediate read otherwise.
- `dist.all_reduce` continues to use the existing `_defer_wait=True` +
`_pending_collective_handles` path. D0.4's drain processes both queues
together.
#### D0.7 Invariants
- **The kernel greenlet's `_parent` is always main**: because env.run never
runs in worker context. (Core assertion of T3.)
- **Cross-rank synchronization point**: drain occurs only after every
worker has yielded → kernels of all ranks advance together within one
round (a prerequisite for cross-rank IPCQ exchange).
- **Single-driver compatibility**: D0.4-(5).
### D1. `torch.multiprocessing.spawn(fn, args, nprocs)`
Real-PyTorch API parity + a single implementation site for D0's scheduler
loop.
#### D1.0 API parity only — not execution parity (normative)
The name `torch.multiprocessing.spawn` is restricted to **API signature
parity**. The actual execution model is a **cooperative greenlet scheduler**
(single Python process, single OS thread, round-robin drive per D0.4). The
following are **properties this ADR does NOT provide** — among the
guarantees of real-PyTorch `torch.multiprocessing.spawn`, explicitly
**non-goals**:
- Process isolation (independent OS process per rank).
- Independent address space (each rank with its own Python heap).
- Failure isolation (a hard crash in one rank not affecting others).
- OS-level scheduler fairness (preemptive time slicing between ranks).
- Inter-process primitives such as `mp.Queue`, `mp.Lock`.
Actual properties of this implementation:
- All ranks are greenlets inside the same Python process. Shared global
state is visible as-is (intentional simulation convenience).
- Single-threaded under the GIL → not parallel execution. Only "logical
concurrency" via SimPy event ordering is reproduced.
- Unhandled exception in any one worker → entire simulation aborts
(D0.4-(4)).
**Caller's obligation**: when porting real-PyTorch multi-process samples to
KernBench, logic that relies on process isolation (e.g., `os.getpid`,
independent temp files, signal handling) must be removed. The namespace name
is preserved for code portability — semantics differ.
#### D1.1 Public surface
```python
# runtime_api/multiprocessing.py (new)
class _MultiprocessingNamespace:
    def __init__(self, ctx):
        self._ctx = ctx

    def spawn(self, fn, args: tuple, nprocs: int, join: bool = True) -> None:
        """Spawn `nprocs` worker greenlets, each calling fn(rank, *args).

        Mirrors the torch.multiprocessing.spawn signature (minus `daemon`
        and `start_method`). Drives the D0 scheduler loop until all
        workers finish.
        """
        ...
```
#### D1.2 Implementation
```python
def spawn(self, fn, args, nprocs, join=True):
    from greenlet import greenlet
    ctx = self._ctx
    dist = ctx.distributed
    gs: list[greenlet] = []
    errors: dict[int, Exception] = {}
    for rank in range(nprocs):
        def _entry(r=rank):  # default arg binds the rank now, not at call time
            try:
                fn(r, *args)
            except Exception as e:  # SystemExit (sibling cleanup) is not recorded
                errors[r] = e
                raise
        g = greenlet(_entry)  # parent = main (the current greenlet)
        dist._bind_rank(g, rank)
        gs.append(g)
    try:
        while True:
            alive = [g for g in gs if not g.dead]
            if not alive:
                break
            for g in alive:
                if not g.dead:
                    g.switch()
            _drain_pending(ctx)  # D0.5
    except Exception as outer:
        # Sibling cleanup per D0.4-(4)
        for other in gs:
            if not other.dead:
                try:
                    other.throw(SystemExit)
                except (SystemExit, Exception):
                    pass  # an uncaught SystemExit re-raises in main; swallow it
        backend = getattr(dist, "_backend", None)
        if backend is not None:
            if hasattr(backend, "_barrier"):
                backend._barrier.reset()
            if getattr(backend, "_pending_collective_handles", None) is not None:
                backend._pending_collective_handles.clear()
        ctx._pending_worker_waits.clear()
        raise SpawnException(errors) from outer
    # `join=True` semantics: we already wait for all workers.
```
#### D1.3 `torch` namespace attach
In `runtime_api/context.py` `__post_init__`:
```python
self.multiprocessing = _MultiprocessingNamespace(self)
```
→ in bench code: `torch.multiprocessing.spawn(worker, args=(ws,), nprocs=ws)`.
#### D1.4 Migration of existing benches
The hand-rolled loop in `benches/ccl_allreduce.py` collapses into a single
`torch.multiprocessing.spawn` line. Existing matrix regressions are
preserved. The currently xfail `ring_default_ws` is expected to flip to
PASS thanks to D0 (workers no longer orphan the kernel greenlet).
### D2. New package `kernbench.tp`
```
src/kernbench/tp/
__init__.py — public API re-exports
parallel_state.py — TP group management (currently a single global group)
layers.py — ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
primitives.py — copy/reduce/scatter/gather_to/from_tp_region
kernels.py — gemm kernel launched by TP layers (reusable)
mappings.py — forward identity/all_reduce, backward stub
```
### D3. `parallel_state` — TP group
```python
# parallel_state.py
_TP_WORLD_SIZE = None  # set by initialize_model_parallel()


def initialize_model_parallel(tensor_model_parallel_size: int) -> None:
    """Initialize TP group. Must be called after dist.init_process_group."""
    global _TP_WORLD_SIZE
    from kernbench.runtime_api.distributed import get_dist  # or torch.distributed
    dist = get_dist()
    total = dist.get_world_size()
    if tensor_model_parallel_size != total:
        raise NotImplementedError(
            "Only TP == world_size supported in initial scope"
        )
    _TP_WORLD_SIZE = tensor_model_parallel_size


def get_tensor_model_parallel_world_size() -> int:
    if _TP_WORLD_SIZE is None:
        raise RuntimeError("initialize_model_parallel() has not been called")
    return _TP_WORLD_SIZE


def get_tensor_model_parallel_rank() -> int:
    from kernbench.runtime_api.distributed import get_dist
    return get_dist().get_rank()  # ADR-0024 greenlet-local rank
```
Initial scope: TP size = world_size = topology SIP count. Pure TP model.
### D4-pre. TP shard ownership vs DPPolicy — role separation (normative)
In the weight/output representation of TP layers, two concepts are clearly
separated:
| Concept | Decided by | Scope |
|---|---|---|
| **TP shard ownership** (which rank owns which slice of the weight) | greenlet-local rank + `torch.ahbm.set_device(rank)` (ADR-0024 D2/D3) | **cross-rank, cross-SIP** |
| **Intra-rank placement** (how the owned slice is distributed across cube × PE inside the rank) | `DPPolicy(cube=..., pe=...)` (ADR-0026) | **inside one rank (within SIP boundary)** |
Thus when `ColumnParallelLinear` creates a weight of shape `(in_features,
out_features // ws)` and assigns `DPPolicy(cube="column_wise",
pe="column_wise")`:
- The slice owned by **rank r** = column-axis [r * k_local, (r+1) *
k_local) of the weight — **set_device(r)** determines this (that rank
resides on SIP r).
- **Inside that slice**, the cube × PE column-wise distribution — **DPPolicy**
determines this.
The two axes are **independent**. If two ranks build their own slice with
the same DPPolicy, the slices themselves live on different SIPs but the
intra-SIP placement pattern is the same. Conversely, changing DPPolicy to
`cube="replicate", pe="replicate"` preserves TP shard ownership and only
changes intra-rank placement.
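The ownership arithmetic above is a function of the rank alone. A minimal sketch, with a hypothetical helper name `tp_column_slice`: no `DPPolicy` argument appears, which reflects the independence of the two axes. The same arithmetic applies to the row axis of `RowParallelLinear`, with `n_local` in place of `k_local`.

```python
def tp_column_slice(rank, world_size, out_features):
    """Half-open range [lo, hi) of weight columns owned by `rank` (hypothetical helper)."""
    if out_features % world_size != 0:
        raise ValueError("out_features must be divisible by world_size")
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} outside [0, {world_size})")
    k_local = out_features // world_size
    return rank * k_local, (rank + 1) * k_local
```

For `world_size=4` and `out_features=4096`, the ranks own `[0, 1024)`, `[1024, 2048)`, `[2048, 3072)` and `[3072, 4096)`: contiguous, disjoint slices that together cover the full axis.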
**Mistakes that blur this boundary** (forbidden by this ADR):
- The "SIP axis" reappearing in DPPolicy (removed in ADR-0026).
- TP layers expressing cross-rank sharding via `DPPolicy` alone without
`set_device` → indistinguishable from a vertical split within a single
rank.
The TP layers of this ADR always treat weight/output from the perspective
of "rank = SIP = owns one slice + DPPolicy intra-SIP distribution" only.
### D4. `ColumnParallelLinear`
**Important**: no new host-side `torch.matmul` abstraction is introduced.
The layer's forward calls the existing gemm kernel via `torch.launch("gemm",
gemm_kernel, ...)` — the pattern already used by KernBench benches
([benches/gemm_single_pe.py](benches/gemm_single_pe.py),
[benches/gpt3_qkv.py](benches/gpt3_qkv.py)).
```python
# layers.py
from kernbench.policy.placement.dp import DPPolicy
from kernbench.tp.kernels import _gemm_kernel
from kernbench.tp.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)


class ColumnParallelLinear:
    """Shards the K (out_features) axis of the weight across TP ranks.

    forward(x):
        x:   (M, N)               full-replicated across ranks
        W_k: (N, K / world_size)  rank-local slice (placed on SIP r via set_device)
        y_k = x @ W_k -> (M, K / world_size), rank-local output

    Output is column-sharded: the input form expected by RowParallelLinear.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 dtype: str = "f16", torch=None):
        # `torch` is the KernBench runtime namespace and must be supplied.
        ws = get_tensor_model_parallel_world_size()
        assert out_features % ws == 0, "out_features must be divisible by TP size"
        self.in_features = in_features
        self.k_local = out_features // ws
        self._torch = torch
        # Each rank owns its own slice, placed on SIP r by set_device(rank).
        self.weight = torch.zeros(
            (in_features, self.k_local), dtype=dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="col_parallel_w",
        )
        self.bias = None
        if bias:
            self.bias = torch.zeros(
                (self.k_local,), dtype=dtype,
                dp=DPPolicy(cube="replicate", pe="replicate"),
                name="col_parallel_b",
            )

    def forward(self, x):
        # x is full-replicated (caller-guaranteed). Plain local gemm.
        M = x.shape[0]
        out = self._torch.empty(
            (M, self.k_local), dtype=x.dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="col_parallel_out",
        )
        self._torch.launch(
            "col_parallel_gemm", _gemm_kernel,
            x, self.weight, out, M, self.in_features, self.k_local,
        )
        # Bias add as a separate kernel or as the fused bias of a composite gemm.
        # The initial scope is verified with bias=False.
        return out
```
**Yield-safety contract (normative)**: `ColumnParallelLinear.forward`
contains one `torch.launch` call, i.e. a kernel launch followed by an
internal `ctx.wait`. This automatically satisfies the "worker yields within
a finite number of steps" condition of D0.4-(1); TP layer users do not need
to insert yield points manually.
### D5. `RowParallelLinear`
```python
class RowParallelLinear:
    """Shards the N (in_features) axis of the weight across TP ranks.

    forward(x):
        x:   (M, N / world_size)  rank-local slice (output of ColumnParallel)
        W_k: (N / world_size, K)  rank-local slice
        y_k = x @ W_k -> (M, K), partial sum on each rank
        y = all_reduce(y_k, op="sum") -> (M, K) on every rank
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 dtype: str = "f16", torch=None):
        ws = get_tensor_model_parallel_world_size()
        assert in_features % ws == 0
        self.n_local = in_features // ws
        self.out_features = out_features
        self._torch = torch
        self.weight = torch.zeros(
            (self.n_local, out_features), dtype=dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="row_parallel_w",
        )
        # bias only on rank 0 (Megatron convention). Omitted in initial scope.
        self.bias = None

    def forward(self, x):
        M = x.shape[0]
        y_partial = self._torch.empty(
            (M, self.out_features), dtype=x.dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="row_parallel_partial",
        )
        self._torch.launch(
            "row_parallel_gemm", _gemm_kernel,
            x, self.weight, y_partial, M, self.n_local, self.out_features,
        )
        # Cross-rank reduce. ADR-0024's dist.all_reduce works correctly
        # under D0 + mp.spawn (kernel parent = main is preserved).
        self._torch.distributed.all_reduce(y_partial, op="sum")
        return y_partial
```
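The row split can be modeled the same way. In this pure-Python sketch (hypothetical helper names; `all_reduce_sum` stands in for `dist.all_reduce(op="sum")` and is not the kernbench implementation), each rank multiplies its slice of `x` by its rows of `W`, producing a full-shape partial sum; summing the partials over ranks recovers `x @ W`, and every rank receives the same result.

```python
# Pure-Python model of RowParallelLinear's math; helper names are hypothetical.
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Reference gemm: (M, N) @ (N, K) -> (M, K)."""
    n, k = len(b), len(b[0])
    return [[sum(row[t] * b[t][j] for t in range(n)) for j in range(k)] for row in a]


def input_slice(x: Matrix, rank: int, world_size: int) -> Matrix:
    """x[:, rank*N/ws : (rank+1)*N/ws], i.e. what ColumnParallelLinear left on this rank."""
    n_local = len(x[0]) // world_size
    return [row[rank * n_local:(rank + 1) * n_local] for row in x]


def row_slice(w: Matrix, rank: int, world_size: int) -> Matrix:
    """W[rank*N/ws : (rank+1)*N/ws, :], the rows owned by `rank`."""
    assert len(w) % world_size == 0, "N must divide evenly across ranks"
    n_local = len(w) // world_size
    return w[rank * n_local:(rank + 1) * n_local]


def all_reduce_sum(per_rank: List[Matrix]) -> List[Matrix]:
    """Element-wise sum over ranks; every rank receives its own copy of the total."""
    rows, cols = len(per_rank[0]), len(per_rank[0][0])
    total = [[sum(p[i][j] for p in per_rank) for j in range(cols)] for i in range(rows)]
    return [[list(r) for r in total] for _ in per_rank]


def row_parallel_forward(x: Matrix, w: Matrix, world_size: int) -> List[Matrix]:
    partials = [matmul(input_slice(x, r, world_size), row_slice(w, r, world_size))
                for r in range(world_size)]
    return all_reduce_sum(partials)


x = [[1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 0.0, 1.0]]      # (M=2, N=4)
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]  # (N=4, K=2)
outputs = row_parallel_forward(x, w, world_size=2)
assert all(y == matmul(x, w) for y in outputs)        # same (M, K) result on every rank
```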
**Yield-safety contract (normative)**: `RowParallelLinear.forward` performs a
launch with its internal wait, followed by an `all_reduce` (defer +
worker-yield pattern), so it yields **at least twice per forward**. This
automatically satisfies the scheduler-progress condition of D0.4-(1). Every
TP layer forward in this ADR maintains the invariant "contains at least one
wait or collective, and is therefore yield-safe"; any future TP primitive
(e.g., VocabParallelEmbedding) must keep the same contract.
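To make the contract concrete, here is a toy model in which each rank is a Python generator and every `yield` stands for a point where control returns to main (the internal `ctx.wait` after a launch, or the defer inside `all_reduce`). It is a hypothetical illustration, not kernbench's greenlet scheduler; `ToyCollective`, `launch_and_wait`, `toy_all_reduce`, and `drive` are invented names. Because the collective cannot complete until every rank has contributed, a rank that spun without yielding would starve the others; the row-parallel forward modeled here yields exactly twice per rank.

```python
# Toy generator model of the yield-safety contract; not kernbench's scheduler.
from typing import Generator, List, Tuple

Log = List[Tuple[int, str]]
Worker = Generator[None, None, None]


class ToyCollective:
    """Tracks how many ranks have reached the (single) all_reduce."""

    def __init__(self, world_size: int):
        self.world_size = world_size
        self.arrived = 0


def launch_and_wait(rank: int, log: Log) -> Worker:
    log.append((rank, "gemm"))
    yield  # internal ctx.wait: control returns to main


def toy_all_reduce(coll: ToyCollective, rank: int, log: Log) -> Worker:
    coll.arrived += 1
    yield  # defer: always hand control back at least once
    while coll.arrived < coll.world_size:
        yield  # other ranks have not contributed yet
    log.append((rank, "all_reduce"))


def row_parallel_worker(coll: ToyCollective, rank: int, log: Log) -> Worker:
    yield from launch_and_wait(rank, log)
    yield from toy_all_reduce(coll, rank, log)


def drive(workers: List[Worker], max_rounds: int = 100) -> List[int]:
    """Round-robin main loop; returns how many times each rank yielded."""
    yields = [0] * len(workers)
    live = list(range(len(workers)))
    for _ in range(max_rounds):
        if not live:
            break
        still_live = []
        for rank in live:
            try:
                next(workers[rank])
                yields[rank] += 1
                still_live.append(rank)
            except StopIteration:
                pass
        live = still_live
    if live:
        raise RuntimeError("no completion within max_rounds (treated as a deadlock)")
    return yields


log: Log = []
coll = ToyCollective(world_size=2)
counts = drive([row_parallel_worker(coll, r, log) for r in range(2)])
assert counts == [2, 2]
```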
### D6. Primitive functions
```python
# primitives.py
def copy_to_tp_region(x):
    """Forward: identity. Backward: all-reduce. (Implemented when training is added)."""
    return x


def reduce_from_tp_region(x, torch):
    """Forward: all-reduce. Backward: identity."""
    torch.distributed.all_reduce(x, op="sum")
    return x


def scatter_to_tp_region(x):
    raise NotImplementedError(
        "Phase 2: replaced by users creating already-sharded tensors"
    )


def gather_from_tp_region(x):
    raise NotImplementedError(
        "Phase 2: requires all-gather kernel as a prerequisite (future)"
    )
```
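The two implemented primitives form Megatron's conjugate pair (the f and g operators of the Megatron-LM paper): the forward of one is the backward of the other. The pure-Python sketch below models them over an explicit list of per-rank buffers; `identity`, `all_reduce_sum`, and the `CONJUGATES` table are illustrative names only, and the backward entries describe future training support rather than anything this ADR implements.

```python
# Per-rank model of the f/g conjugate pair; names are illustrative only.
from typing import Callable, Dict, List

Vector = List[float]
PerRank = List[Vector]  # list index = TP rank


def identity(per_rank: PerRank) -> PerRank:
    """Each rank keeps its own buffer (returned as fresh copies)."""
    return [list(v) for v in per_rank]


def all_reduce_sum(per_rank: PerRank) -> PerRank:
    """Element-wise sum over ranks; every rank ends up with the total."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


CONJUGATES: Dict[str, Dict[str, Callable[[PerRank], PerRank]]] = {
    "copy_to_tp_region": {"forward": identity, "backward": all_reduce_sum},
    "reduce_from_tp_region": {"forward": all_reduce_sum, "backward": identity},
}

buffers = [[1.0, 2.0], [3.0, 4.0]]  # rank 0 and rank 1
assert CONJUGATES["reduce_from_tp_region"]["forward"](buffers) == [[4.0, 6.0], [4.0, 6.0]]
assert CONJUGATES["copy_to_tp_region"]["forward"](buffers) == buffers
# Conjugacy: the forward of one primitive is the backward of the other.
assert CONJUGATES["copy_to_tp_region"]["backward"] is CONJUGATES["reduce_from_tp_region"]["forward"]
```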
### D7. Sample bench — 2-layer MLP with TP
```python
# benches/tp_mlp.py (new)
import numpy as np

import kernbench.tp as tp
from kernbench.policy.placement.dp import DPPolicy


def worker(rank: int, world_size: int, torch):
    torch.ahbm.set_device(rank)
    tp.initialize_model_parallel(world_size)

    B, D_in, D_hidden, D_out = 1, 512, 2048, 512
    fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=torch)
    fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=torch)

    x = torch.zeros(
        (B, D_in), dtype="f16",
        dp=DPPolicy(cube="replicate", pe="replicate"),
        name="x",
    )
    # Initialize x with a simple pattern (here, a constant).
    x.copy_(torch.from_numpy(np.full((B, D_in), 0.1, dtype=np.float16)))

    h = fc1.forward(x)  # column-sharded (B, D_hidden / ws)
    y = fc2.forward(h)  # all-reduced (B, D_out) on every rank

    # Only rank 0 prints / verifies the result.
    if rank == 0:
        result = y.numpy()
        # With zero-initialized weights every value is 0; in this scope the
        # check is that the run completes at all.
        print(f" tp_mlp: shape={result.shape}, mean={float(result.mean()):.4f}")


def run(torch):
    torch.distributed.init_process_group(backend="ahbm")
    ws = torch.distributed.get_world_size()
    torch.multiprocessing.spawn(worker, args=(ws,), nprocs=ws)
```
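Because `fc1`'s column-sharded output feeds `fc2` directly, the only communication in this MLP is the final `all_reduce`. A pure-Python reference such as the sketch below (hypothetical helper names; small dimensions instead of the bench's 512/2048/512) could compute the expected `y` once non-zero weights are injected. It simulates both layers rank by rank, checks the result against the unsharded `x @ W1 @ W2`, and confirms a closed form for constant weights. The sample MLP has no activation between the two layers, so the reference has none either.

```python
# Pure-Python reference for the 2-layer TP MLP; names and sizes are hypothetical.
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Reference gemm: (M, N) @ (N, K) -> (M, K)."""
    n, k = len(b), len(b[0])
    return [[sum(row[t] * b[t][j] for t in range(n)) for j in range(k)] for row in a]


def tp_mlp_forward(x: Matrix, w1: Matrix, w2: Matrix, world_size: int) -> Matrix:
    """ColumnParallelLinear (fc1) then RowParallelLinear (fc2), simulated per rank."""
    d_hidden = len(w1[0])
    assert d_hidden % world_size == 0
    h_local = d_hidden // world_size
    partials = []
    for r in range(world_size):
        lo, hi = r * h_local, (r + 1) * h_local
        w1_r = [row[lo:hi] for row in w1]   # fc1 slice: (D_in, D_hidden / ws)
        w2_r = w2[lo:hi]                     # fc2 slice: (D_hidden / ws, D_out)
        h_r = matmul(x, w1_r)                # stays on rank r; no communication
        partials.append(matmul(h_r, w2_r))   # partial sum: (B, D_out)
    # all_reduce(op="sum"): every rank ends up holding this same matrix.
    rows, cols = len(partials[0]), len(partials[0][0])
    return [[sum(p[i][j] for p in partials) for j in range(cols)] for i in range(rows)]


B, D_in, D_hidden, D_out, ws = 1, 4, 8, 2, 2
x = [[0.5] * D_in for _ in range(B)]
# Integer-valued weights keep float arithmetic exact, so == comparisons are safe.
w1 = [[float((i + j) % 3 - 1) for j in range(D_hidden)] for i in range(D_in)]
w2 = [[float((i * j) % 4 - 1) for j in range(D_out)] for i in range(D_hidden)]
assert tp_mlp_forward(x, w1, w2, ws) == matmul(matmul(x, w1), w2)

# Constant weights a, b give a closed form: y = D_hidden * b * (D_in * x * a).
a, b = 0.25, 2.0
w1c = [[a] * D_hidden for _ in range(D_in)]
w2c = [[b] * D_out for _ in range(D_hidden)]
assert tp_mlp_forward(x, w1c, w2c, ws) == [[D_hidden * b * (D_in * 0.5 * a)] * D_out]
```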
### D8. Non-functional — training not supported
This ADR covers **inference (forward) only**. Backward passes, gradients,
and optimizers are future work. This is a natural boundary, since KernBench
is not a training system.
### D9. Initial-scope constraints
- TP size = world_size (no mixed DP+TP).
- `scatter_to_tp_region`, `gather_from_tp_region` are unimplemented.
- **Weights default to zero**. Proper init schemes (Xavier, Kaiming, etc.)
  are future work. Tests inject deterministic non-zero patterns via
  `tensor.copy_` to verify numerical correctness (T2/T6). In short:
  production default = zero, verification = deterministic non-zero.
- Bias is omitted in the initial scope (Megatron's rank-0-only bias policy
is future).
- Pipeline parallelism is out of scope.
- VocabParallelEmbedding requires a prerequisite all-gather → stub only.
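For the "verification = deterministic non-zero" rule above, a test needs a global pattern that every rank can regenerate independently and then slice according to its layer type. A minimal sketch, assuming values drawn from {±0.25, ±0.5}, which are exactly representable in f16 so rounding does not blur the reference comparison; `debug_pattern`, `column_parallel_slice`, and `row_parallel_slice` are hypothetical names. Each rank would then load its slice with `tensor.copy_`, as D7 does for `x`.

```python
# Deterministic non-zero debug weights plus per-rank slicing; names are hypothetical.
from typing import List

Matrix = List[List[float]]

# Every level is exactly representable in f16 (and in f32/f64).
_LEVELS = (-0.5, -0.25, 0.25, 0.5)


def debug_pattern(rows: int, cols: int, seed: int = 0) -> Matrix:
    """Same values on every rank for the same (rows, cols, seed); never zero."""
    return [[_LEVELS[(i * cols + j + seed) % len(_LEVELS)] for j in range(cols)]
            for i in range(rows)]


def column_parallel_slice(w: Matrix, rank: int, world_size: int) -> Matrix:
    """ColumnParallelLinear weight on `rank`: shape (N, K / ws)."""
    k_local = len(w[0]) // world_size
    return [row[rank * k_local:(rank + 1) * k_local] for row in w]


def row_parallel_slice(w: Matrix, rank: int, world_size: int) -> Matrix:
    """RowParallelLinear weight on `rank`: shape (N / ws, K)."""
    n_local = len(w) // world_size
    return [list(row) for row in w[rank * n_local:(rank + 1) * n_local]]


w = debug_pattern(4, 6, seed=1)
ws = 2
# Column slices tile W left to right; row slices tile it top to bottom.
col_parts = [column_parallel_slice(w, r, ws) for r in range(ws)]
assert [col_parts[0][i] + col_parts[1][i] for i in range(4)] == w
assert row_parallel_slice(w, 0, ws) + row_parallel_slice(w, 1, ws) == w
assert all(v != 0.0 for row in w for v in row)
```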
### D10. Regression: `ring_default_ws` xfail removal — mandatory acceptance
D0 (worker-wait generalization) and D0.5 (host-read barrier) route every
worker-driven `ctx.wait` and every host read through the main-drain path,
which removes the cause of the kernel-greenlet orphan observed in ADR-0024
Phase B. Accordingly, flipping the existing matrix test's `ring_default_ws`
case from strict xfail to **PASS** once this ADR is implemented is a
**mandatory regression criterion**. The observable acceptance criteria (no
deadlock, no GreenletExit, numerical tolerance, etc.) are specified in
**T7**.

---
## Dependencies
- **ADR-0024** (launcher): rank = SIP, greenlet-local rank,
`torch.ahbm.set_device(rank)`.
- **ADR-0026** (DPPolicy intra-device): per-rank slice representation of
weight tensors.
- **ADR-0023 / ADR-0025** (IPCQ): foundation of `dist.all_reduce`
implementation.
---
## Non-goals
- **Backward pass / training**: inference only. Training simulation is a
separate ADR.
- **Mixed parallelism (DP + TP + PP)**: pure TP only at the start.
- **Weight init schemes**: only zero init and simple debug patterns.
- **Fused ops**: Megatron's fused matmul+bias+gelu is a kernel-level
concern.
- **DTensor integration**: deferred to a future ADR-0028.
- **Host-side `torch.matmul` abstraction**: TP layers call the existing
  gemm kernel via `torch.launch(name, _gemm_kernel, ...)`. No new matmul
  host-op is introduced.
---
## Open questions
- **Location of `initialize_model_parallel`**:
  `kernbench.tp.initialize_model_parallel` (current decision) vs. upstream
  PyTorch's `torch.distributed.init_device_mesh`. For now it stays in the
  TP-only module.
- **Weight init**: this ADR uses zero. A debug pattern (e.g., identity) may
  be needed for meaningful verification; add one when writing the Phase 1
  tests if needed.
- **Bias placement policy**: Megatron places RowParallelLinear bias only on
rank 0. Avoided in the initial scope via bias=False.
- **GEMM kernel location**: `kernbench.tp.kernels._gemm_kernel` vs
importing from existing `benches/gemm_single_pe.py`. TP must not depend
on benches, so the kernel is duplicated inside `tp`. Migration to a shared
`kernbench.kernels` package is possible later.
**Resolved (previously open in earlier revisions)**:
- ~~Drain timing on `tensor.numpy()` call~~ → **decided in D0.5**: the
official host-read entry points (`numpy`, `data`, `__getitem__`,
data-containing `__repr__`) are automatic drain barriers. Metadata-only
accessors are not barriers.
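The D0.5 rule can be pictured as a tensor wrapper whose data-returning accessors drain pending work first while metadata accessors do not. The sketch below is a hypothetical model (`PendingWork` and `BarrierTensor` are not kernbench classes): a queued write becomes visible only once a host read triggers the drain, and reading `shape` never drains.

```python
# Hypothetical model of D0.5 host-read barriers; not kernbench's tensor class.
from typing import Callable, List, Tuple


class PendingWork:
    """Queue of deferred device-side effects that main drains on demand."""

    def __init__(self) -> None:
        self.queue: List[Callable[[], None]] = []
        self.drains = 0

    def submit(self, fn: Callable[[], None]) -> None:
        self.queue.append(fn)

    def drain(self) -> None:
        self.drains += 1
        while self.queue:
            self.queue.pop(0)()


class BarrierTensor:
    def __init__(self, shape: Tuple[int, ...], data: List[float], work: PendingWork):
        self._shape, self._data, self._work = shape, data, work

    @property
    def shape(self) -> Tuple[int, ...]:
        return self._shape  # metadata only: no drain

    def numpy(self) -> List[float]:
        self._work.drain()  # host read: barrier
        return list(self._data)

    @property
    def data(self) -> List[float]:
        self._work.drain()
        return self._data

    def __getitem__(self, idx: int) -> float:
        self._work.drain()
        return self._data[idx]

    def __repr__(self) -> str:
        self._work.drain()  # repr shows values, so it is a host read too
        return f"BarrierTensor(shape={self._shape}, data={self._data})"


work = PendingWork()
t = BarrierTensor((2,), [0.0, 0.0], work)
work.submit(lambda: t._data.__setitem__(0, 42.0))  # a deferred kernel write
assert t.shape == (2,) and work.drains == 0        # metadata access: nothing drained
assert t.numpy() == [42.0, 0.0] and work.drains == 1
```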
---
## Consequences
### Positive
- **Easy porting of Megatron code**: API matches real training code.
- **TP benchmarking enabled**: supports research on scaling,
  communication/compute overlap, and other HW characteristics.
- **`ring_default_ws` xfail removal**: as a byproduct of D0, the ADR-0024
Phase B blocker is resolved.
- **Scheduler-loop unification**: introducing D1 (`mp.spawn`) removes the
hand-rolled loop. Subsequent collective/TP benches reuse the same
pattern.
- **DPPolicy semantics clarified** (synergy with ADR-0026): TP layers serve
  as a best-practice example that uses only intra-device DPPolicy.
### Negative
- Maintenance cost of a new module (`kernbench.tp`).
- Initial scope is limited (pure TP only, forward only).
- D0 generalization changes the semantics of `ctx.wait` — compatibility
with single-driver tests must be explicitly verified (T7).
### Neutral
- A pure upper layer added on top of ADR-0024/0026. No impact on the
hardware-simulation stack (apart from D0).