a796c1d2f7
Establish English as the canonical ADR language with Korean translations held in a parallel docs/adr-ko/ tree as derived artifacts (1:1 mirror). Promotion from adr-proposed/ to adr/ now writes English to adr/ and the Korean to adr-ko/; bidirectional sync rule documented in CLAUDE.md. - Migrate 30 ADRs in docs/adr/: 28 Korean-only translated to English, 2 bilingual pairs (ADR-0020, ADR-0023) consolidated (.en.md suffix dropped). ADR-0023 EN regenerated against KO source which had newer HW Realization Notes (D16-D23) section. - docs/adr-history/ left frozen by design (transitional state). - CLAUDE.md (Part 2): update ADR Lifecycle for 4-folder layout, mark docs/adr-ko/ as a Derived Artifact, add ADR Translation Discipline section covering bidirectional sync, conflict resolution (EN wins), and proposed-language freedom. - tools/verify_adr_lang_pairs.py: new verification tool checking pair completeness, filename mirroring, ADR-ID match, Status byte-equality. Pre-commit hook intentionally not added; run on demand or in CI. - tests/test_verify_adr_lang_pairs.py: 11 cases including CRLF/LF normalization, em-dash title separator, underscore-slug edge case. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
956 lines
40 KiB
Markdown
# ADR-0027: Megatron-style Tensor Parallelism API

## Status

Accepted

## Context

### Goal

Support inter-SIP tensor parallelism (TP) via a **Megatron-LM style explicit
parallel layer** API. Declarative abstractions such as DTensor are future work,
to be covered in a separate ADR (0028).

Why Megatron-style was chosen:

- TP arises at specific layer boundaries of a model, so explicit primitives
  match the natural mental model.
- It is the de-facto industry standard established by NVIDIA Megatron-LM and
  DeepSpeed.
- DTensor is declarative, so its design space is larger; a phased approach is
  therefore taken.

### TP primitive spec (Megatron-LM reference)

- **ColumnParallelLinear**: shards the weight's **column (out_features)** axis
  across TP ranks. The input is fully replicated and the output is
  column-sharded. When a RowParallelLinear follows, no all-reduce is required
  in forward.
- **RowParallelLinear**: shards the weight's **row (in_features)** axis across
  TP ranks. The input is already column-sharded (the output of
  ColumnParallelLinear). Requires an **all-reduce** at the end of forward.
- **VocabParallelEmbedding**: shards the embedding table along the vocab axis
  and requires an all-reduce at the end of forward. (A stub in the initial
  scope; a full implementation requires an all-gather kernel as a
  prerequisite.)
- **`copy_to_tp_region`**, **`reduce_from_tp_region`**,
  **`scatter_to_tp_region`**, **`gather_from_tp_region`**: the basic
  primitives.

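The Column → Row pairing needs only one all-reduce, at the end. The plain-Python check below uses small integer matrices to show why. Each rank multiplies by its column slice of `W1` and then by the matching row slice of `W2`. Summing the per-rank partial outputs (the job of the all-reduce) reproduces the unsharded product. The helper names (`matmul`, `shard_cols`, `shard_rows`, `all_reduce_sum`, `tp_mlp`) are illustrative only and are not KernBench APIs. Any nonlinearity between the two layers is omitted.

```python
def matmul(a, b):
    """Plain-Python product of list-of-lists matrices."""
    n_inner = len(b)
    assert all(len(row) == n_inner for row in a)
    return [[sum(row[k] * b[k][j] for k in range(n_inner))
             for j in range(len(b[0]))] for row in a]


def shard_cols(w, ws, r):
    """Rank r's column slice of w (ColumnParallelLinear weight)."""
    k_local = len(w[0]) // ws
    return [row[r * k_local:(r + 1) * k_local] for row in w]


def shard_rows(w, ws, r):
    """Rank r's row slice of w (RowParallelLinear weight)."""
    n_local = len(w) // ws
    return w[r * n_local:(r + 1) * n_local]


def all_reduce_sum(partials):
    """Element-wise sum of the per-rank partial outputs."""
    rows, cols = len(partials[0]), len(partials[0][0])
    return [[sum(p[i][j] for p in partials) for j in range(cols)]
            for i in range(rows)]


def tp_mlp(x, w1, w2, ws):
    partials = []
    for r in range(ws):
        h_r = matmul(x, shard_cols(w1, ws, r))                # column-sharded, no comm
        partials.append(matmul(h_r, shard_rows(w2, ws, r)))   # partial sum
    return all_reduce_sum(partials)                           # the single all-reduce


x = [[1, 2, 3], [4, 5, 6]]                        # (M=2, N=3), replicated
w1 = [[1, 0, 2, 1], [0, 1, 1, 3], [2, 1, 0, 1]]   # (N=3, K=4)
w2 = [[1, 2], [3, 0], [0, 1], [2, 2]]             # (K=4, N_out=2)
assert tp_mlp(x, w1, w2, ws=2) == matmul(matmul(x, w1), w2)
```
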
### Problems to solve

1. **Worker-wait generalization (D0)**: extend the defer/yield/drain pattern of
   `dist.all_reduce` to every `ctx.wait` path. **This is the biggest
   architectural decision of this ADR.**

2. **Launcher API normalization (D1)**: current benches use a hand-rolled
   greenlet loop. Absorb it into `torch.multiprocessing.spawn(fn, args, nprocs)`
   to preserve the real-PyTorch API surface and to concentrate D0's scheduler
   drain in a single implementation site.

3. **Per-rank weight shard representation**: each worker owns its own slice of
   the weight tensor. This is naturally expressed via ADR-0024's
   `set_device(rank)` plus ADR-0026's intra-device DPPolicy.

4. **Forward-only scope**: KernBench currently has no backward pass (it is used
   for simulation). This ADR covers **forward only**; training simulation is
   left to a separate ADR.

5. **Collective call site**: RowParallelLinear calls `all_reduce` at the end of
   forward. This works naturally with ADR-0024's multi-greenlet structure plus
   the D0 generalization.

6. **TP group concept**: Megatron composes DP × TP × PP groups. The initial
   scope simplifies this to **TP group = all SIPs**; mixed DP+TP is future work.

---

## Decision

### D0. Worker-wait generalization — `ctx.wait` defers to main when in worker context

**Restating the problem.** `kernel_runner.run` captures `greenlet.getcurrent()`
at spawn time as the kernel greenlet's `_parent`
([kernel_runner.py:94](src/kernbench/triton_emu/kernel_runner.py#L94)).
If `env.run` runs in the main context, parent=main, which is safe. If `env.run`
runs in a worker context, parent=worker. The moment that worker yields or
finishes, the kernel greenlet becomes an orphan → `GreenletExit` → the failure
of ADR-0024 Phase B's `ring_default_ws`.

**Resolution.** When a worker greenlet calls `ctx.wait(h)`, it does not drive
`env.run` directly; it **yields to the main scheduler** instead. Main drives
`env.run` and, once the handle completes, switches control back to the worker.

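
The control flow can be modeled with the standard library alone. In this sketch Python generators stand in for greenlets: a `yield` plays the role of `g.parent.switch()`, and main resuming the generator plays the role of switching back. This is a substitution for illustration only. Generators force the worker to write `yield from`, whereas the real design keeps the worker API synchronous. `FakeEngine`, `MiniCtx`, `worker`, and `run_workers` are hypothetical names. The checks cover the two properties the resolution relies on: the engine is driven only from the main loop, and a handle is complete by the time its worker resumes.

```python
class FakeEngine:
    """Stands in for the SimPy-driven engine; records the handles it drives."""

    def __init__(self):
        self.calls = []

    def wait(self, handle):
        self.calls.append(handle)


class MiniCtx:
    """Hypothetical stand-in for RuntimeContext with only the D0 fields."""

    def __init__(self, engine):
        self.engine = engine
        self._completed = set()
        self._pending_worker_waits = []

    def wait_from_worker(self, handle):
        """Worker side of ctx.wait: enqueue the handle, then yield to main."""
        if handle in self._completed:        # fast path: no enqueue, no switch
            return
        self._pending_worker_waits.append(handle)
        yield handle                          # plays the role of g.parent.switch()
        # Resume invariant (D0.3): main completed the handle before resuming us.
        assert handle in self._completed

    def drain(self):
        """Main side: drive the engine for every deferred handle, FIFO."""
        while self._pending_worker_waits:
            h = self._pending_worker_waits.pop(0)
            if h not in self._completed:
                self.engine.wait(h)
                self._completed.add(h)


def worker(ctx, rank, log):
    for step in range(2):
        yield from ctx.wait_from_worker(f"r{rank}-op{step}")
        log.append((rank, step))              # runs only after completion


def run_workers(ctx, nprocs):
    """Round-based loop: one turn per live worker, then a drain in main."""
    log = []
    alive = [worker(ctx, r, log) for r in range(nprocs)]
    while alive:
        for g in list(alive):
            try:
                next(g)                       # run the worker until it yields
            except StopIteration:
                alive.remove(g)
        ctx.drain()                           # the engine is only driven here
    return log


engine = FakeEngine()
ctx = MiniCtx(engine)
log = run_workers(ctx, nprocs=2)
```
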
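#### D0.1 `RuntimeContext` extension

```python
# context.py
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class RuntimeContext:
    ...  # existing fields elided
    # Handles deferred by worker greenlets, drained FIFO by main (D0.4/D0.5).
    _pending_worker_waits: list[RequestHandle] = field(default_factory=list, init=False)
```
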
#### D0.2 `ctx.wait` worker fork

```python
def wait(self, handle, *, _meta=None):
    # Fast path: already completed. Skip the enqueue + switch (consistent with
    # D0.4-(3) idempotency). Avoids a needless worker→main→worker round trip
    # and prevents redundant _pending_worker_waits growth.
    if handle in self._completed:
        completion, _trace = self.engine.get_completion(handle)
        return completion

    from greenlet import getcurrent
    g = getcurrent()
    if g.parent is not None and not g.parent.dead:
        # Worker greenlet: defer to main. Push the handle, yield to the parent.
        # The parent (scheduler loop) drains env.run, then switches back.
        self._pending_worker_waits.append(handle)
        g.parent.switch()
        # On resume the handle has completed (main drained the list, D0.3).
        # Fall through to the status-quo completion/trace assembly.
    else:
        # Main context (or single driver): drive the engine directly.
        wait_fn = getattr(self.engine, "wait", None)
        if wait_fn is not None:
            wait_fn(handle)

    completion, trace = self.engine.get_completion(handle)
    self._completed.add(handle)
    if _meta is not None and trace is not None:
        entry = dict(trace) if isinstance(trace, dict) else {"raw": trace}
        entry.update(_meta)
        self._traces.append(entry)
    return completion
```

#### D0.3 `ctx.wait` worker-context semantic contract (normative)

This ADR **explicitly changes** the semantics of `ctx.wait` in worker context.

- **Submit-vs-complete separation**: when called from a worker, `ctx.wait(h)`
  no longer guarantees "immediate completion". It guarantees "completion
  **after the next scheduler drain**". The worker returns from `wait()` at the
  point where main has finished `engine.wait` for that handle. Main-context
  calls remain immediately synchronous, as before (status quo).
- **Resume invariant (normative)**: when a worker resumes from a
  worker-deferred `ctx.wait(h)` (that is, when `g.parent.switch()` returns),
  **`h in ctx._completed` must be True**. If this invariant breaks, the worker
  proceeds on stale state. Any change to `_drain_pending`, the scheduler loop,
  or `ctx.wait` must therefore preserve it. T3.b asserts this invariant
  directly.
- **Observable change**: the pattern `h = ctx.submit(msg); ctx.wait(h); read(handle_result)`
  inside a worker still holds. The semantic spec now also states that a main
  drain takes place inside `wait()`, before `read` runs.
- **Direct host-object reads follow D0.5**: the contract for calling
  `tensor.numpy()` without `ctx.wait` is specified separately in D0.5.

#### D0.4 Main scheduler drain — protocol (normative)

(This is the internal implementation of D1's `multiprocessing.spawn`; below is
the semantic definition.)

```python
alive = [g for g in gs if not g.dead]
while alive:
    for g in alive:               # (1) round-based worker switch
        if not g.dead:
            g.switch()
    _drain_pending(ctx)           # (2) drain in main context
    alive = [g for g in gs if not g.dead]
```

(The actual definition of `_drain_pending` is in D0.5: an outer while-loop
that drains until both queues are empty.)

**Rules**:

1. **Round-based cooperative scheduling & yield obligation (worker contract)**.
   `g.switch()` does not return until the worker **voluntarily yields**
   (cooperative greenlet semantics). Therefore:
   - If a worker runs a pure-compute loop such as `while True: do_compute()`
     without yielding, `g.switch()` never returns and **the scheduler loop
     itself hard-blocks**: other workers get no switch turn and no drain
     occurs. This is not starvation but **scheduler non-progress (equivalent
     to a deadlock)**, and this ADR classifies it as **unsupported**.
   - Workers **must** call one of `ctx.wait(h)`, `dist.all_reduce`, or a
     host-read barrier (D0.5) within a finite number of steps. The `forward`
     of a TP layer ends every layer with a launch→wait pair, so this condition
     is met naturally. CCL kernels also yield inside `dist.all_reduce`.
   - Implementations need not **detect** violations (timeouts,
     steps-since-yield counters, etc.). This is a user contract; the symptom
     of a violation is a "simulation hang".
   - **Future extension**: if non-collective long compute paths become
     common, an explicit `torch.distributed.cooperative_yield()` primitive
     (a no-op yield) could be introduced. It is out of scope for this ADR and
     would not be a breaking change, so it can be added when needed.
   - Within a round, every alive worker receives one `switch` turn. If a
     single worker calls wait multiple times within one round, the calls are
     enqueued sequentially within that turn and processed in a single
     scheduler drain batch (FIFO).

2. **Drain order = submission order (FIFO)**. `_pending_worker_waits` is
   strict FIFO via list `append`/`pop(0)`. Drain occurs in submission order,
   not completion order. SimPy's scheduler itself guarantees a causally
   correct completion order, so draining in submission order is safe. Do not
   confuse *completion order* with *drain order*.

   **Two-queue ordering (worker waits → collectives)**: `_drain_pending`
   drains the worker-wait queue first, then the collective queue. The
   rationale for this ordering:
   - **The two queues are different dependency sources**: worker waits are
     handles produced by a worker's own `submit` + `wait` pair (tensor deploy,
     MmuMap, etc.). The collective queue holds kernel-launch handles that
     `dist.all_reduce` enqueues internally and that the worker never waits on
     directly (see the two-queue drain model in D0.5).
   - **Independent in correctness terms**: from the worker's perspective, a
     collective is "already submitted, then yielded"; its completion only
     needs to precede the worker's next action. There is no ordering
     dependency on the worker-wait queue.
   - **Both finish within a single drain barrier**: under D0.5's
     loop-until-empty rule, a single barrier invocation drains worker waits →
     collectives → (repeat if new items appeared), in that order. By the time
     the worker resumes, both queues are empty.
   - **The alternative (collectives first) is also valid**: this ADR fixes
     worker-first only for implementation simplicity; the two orders are
     semantically equivalent. Revisit if a performance-profile difference is
     observed.

3. **Duplicate enqueue: correctness via idempotent drain; dedup not
   guaranteed**. `ctx.wait(h)` returns immediately if `h in ctx._completed`,
   and `_drain_pending` uses the same guard. Even if the same handle is
   appended to `_pending_worker_waits` multiple times, `engine.wait` is
   invoked only once (idempotent).
   - **Correctness**: relies on the idempotent drain, so it is safe.
   - **Memory/performance**: this ADR **does not guarantee dedup** of
     `_pending_worker_waits`. If the same handle is enqueued N times, the
     queue holds N elements and the drain performs N pops plus membership
     guards. Unless a single worker abnormally repeats waits on the same
     handle, N stays between 1 and a few.
   - **Implementation freedom**: implementations may optionally dedup (e.g.,
     keep a `set` as a side index, or check `h not in pending_set` before
     appending). This is classified as an optimization that does not change
     correctness.

4. **Exception propagation + sibling cleanup**.
   When a worker greenlet raises, `g.switch()` propagates the exception to
   main. The scheduler loop stops immediately and performs the following
   cleanup **explicitly**:

   ```python
   try:
       while True:
           alive = [g for g in gs if not g.dead]
           if not alive:
               break
           for g in alive:
               if not g.dead:
                   g.switch()
           _drain_pending(ctx)
   except Exception as outer:
       # (a) Force-terminate surviving sibling worker greenlets.
       for other in gs:
           if not other.dead:
               try:
                   other.throw(SystemExit)
               except (SystemExit, Exception):
                   # SystemExit is not an Exception subclass: if the sibling
                   # does not catch it, greenlet re-raises it here in main.
                   pass  # silent: already in an exceptional state
       # (b) Reset backend barrier / pending state (in preparation for a
       #     future epoch barrier).
       backend = getattr(ctx.distributed, "_backend", None)
       if backend is not None and hasattr(backend, "_barrier"):
           backend._barrier.reset()
       backend_pending = getattr(backend, "_pending_collective_handles", None)
       if backend_pending is not None:
           backend_pending.clear()
       ctx._pending_worker_waits.clear()
       # (c) Wrap the originating exception in SpawnException
       #     (`errors` is filled by the D1.2 entry wrapper).
       raise SpawnException(errors) from outer
   ```

   Protocol:
   - **Sibling abort guarantee**: when one worker raises, `SystemExit` is
     thrown into every sibling greenlet, and they terminate immediately. No
     greenlet leaks.
   - **Explicit pending-queue clear**: both queues (worker-wait and
     collective-pending) are cleared, preventing contamination on reuse.
   - **`SpawnException(errors)` wrapping**: `errors: dict[int, Exception]`
     holds the original exception per rank, compatible with the failure
     pattern of real-PyTorch `torch.multiprocessing.spawn`.
   - **Scope restriction**: `errors` includes **only ranks that raised from
     their own code (root cause)**. Ranks terminated via `throw(SystemExit)`
     during sibling cleanup do not appear in `errors`, because `SystemExit` is
     not caught by D1.2's entry wrapper (`try/except Exception`). This is
     intentional: sibling termination is a cleanup signal, not a failure.
     It is stated explicitly so readers do not expect "all failed ranks" to
     appear.
   - **`ctx._traces` holds only the partial state up to the moment of the
     exception**. Trace completeness is not guaranteed (some launches or
     all_reduces may terminate without leaving an entry).
   - **Allocator / MemoryStore** remain in their pre-exception state. Reuse is
     a non-goal; creating a fresh `RuntimeContext` is recommended.
   - **`join=False`, retry, and partial recovery** are non-goals for this ADR.

   `SpawnException` is defined in `runtime_api/multiprocessing.py`:

   ```python
   class SpawnException(RuntimeError):
       def __init__(self, errors: dict[int, Exception]):
           self.errors = errors
           first = next(iter(errors.items()), None)
           msg = (f"spawn failed on ranks {sorted(errors.keys())}"
                  + (f": rank {first[0]} raised {first[1]!r}" if first else ""))
           super().__init__(msg)
   ```

5. **Single-driver compatibility**. In main-only execution where
   `g.parent is None` (legacy single-driver tests), D0.2's worker-fork
   condition is false, so the existing immediately synchronous path is
   preserved and `_drain_pending` is not invoked.

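
Rules 2 and 3 can be illustrated with a minimal stdlib model. Plain lists and a set replace the real `RuntimeContext`, engine, and backend. For brevity the collective queue lives on the same object rather than on `ctx.distributed._backend`. Every class and function name here is hypothetical. The drain records each handle in `_completed`, which the duplicate guard and the D0.3 resume invariant both rely on. In the scenario below, collective handles are enqueued before the worker waits, yet the drain still processes the worker waits first. Each queue is drained in FIFO order, and a handle enqueued twice reaches the engine only once.

```python
class CountingEngine:
    def __init__(self):
        self.calls = []                         # handles, in the order driven

    def wait(self, handle):
        self.calls.append(handle)


class MiniCtx:
    """Hypothetical stand-in: a context plus a collective-handle queue."""

    def __init__(self):
        self.engine = CountingEngine()
        self._completed = set()
        self._traces = []
        self._pending_worker_waits = []         # worker-wait queue (FIFO)
        self._pending_collective_handles = []   # (handle, sip_id, meta), FIFO

    def wait_main(self, handle, meta=None):
        """Main-context wait: idempotent and never re-enqueues."""
        if handle not in self._completed:
            self.engine.wait(handle)
            self._completed.add(handle)
        if meta is not None:
            self._traces.append({"handle": handle, **meta})


def drain_pending(ctx):
    """Worker waits first, then collectives, repeated until both are empty."""
    while ctx._pending_worker_waits or ctx._pending_collective_handles:
        while ctx._pending_worker_waits:
            h = ctx._pending_worker_waits.pop(0)
            if h not in ctx._completed:         # duplicate guard (rule 3)
                ctx.engine.wait(h)
                ctx._completed.add(h)
        while ctx._pending_collective_handles:
            h, _sip_id, meta = ctx._pending_collective_handles.pop(0)
            ctx.wait_main(h, meta)


ctx = MiniCtx()
# all_reduce on ranks 0 and 1 enqueued its kernel handles first ...
ctx._pending_collective_handles += [("ar-r0", 0, {"op": "all_reduce"}),
                                    ("ar-r1", 1, {"op": "all_reduce"})]
# ... then worker waits arrived, including a duplicate of "deploy-r0".
ctx._pending_worker_waits += ["deploy-r0", "deploy-r1", "deploy-r0"]
drain_pending(ctx)
```
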
#### D0.5 Host-read barrier — decision (normative)

Inside a worker, **host-observable reads** such as `tensor.numpy()`,
`tensor.__getitem__`, and `tensor.data` are defined as **automatic drain
barriers**. Immediately before the read:

1. If `ctx._pending_worker_waits` or `backend._pending_collective_handles` is
   non-empty, yield to main via `g.parent.switch()`. Main runs
   `_drain_pending`, and the worker resumes once it completes.
2. If both queues are empty, read immediately.

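
The two-step rule above reduces to a small decision function. In this sketch `host_read`, `fake_main_drain`, and the dictionary standing in for device memory are hypothetical. `yield_to_main` stands in for `g.parent.switch()` followed by main's `_drain_pending`. The scenario submits without waiting, so the first read must drain before it can see the written value. The second read finds both queues empty and reads immediately.

```python
def host_read(pending_worker_waits, pending_collectives, yield_to_main, read):
    """D0.5 read-barrier rule (hypothetical helper, not a KernBench API).

    yield_to_main stands in for g.parent.switch(): it returns only after
    main has drained both queues.
    """
    if pending_worker_waits or pending_collectives:   # step 1: drain first
        yield_to_main()
    return read()                                      # step 2: read


memory = {"x": 0}
worker_q, collective_q = [], []
switches = []


def fake_main_drain():
    switches.append(len(worker_q) + len(collective_q))
    while worker_q:
        handle = worker_q.pop(0)
        memory[handle] = 42          # the deferred device write lands here
    collective_q.clear()


worker_q.append("x")                 # ctx.submit(...) with no ctx.wait
first = host_read(worker_q, collective_q, fake_main_drain, lambda: memory["x"])
second = host_read(worker_q, collective_q, fake_main_drain, lambda: memory["x"])
```
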
**Barrier iteration protocol (normative, re-entrance)**: `_drain_pending`
drains in a while-loop **until both queues are completely empty**, not in a
single pass:

```python
def _drain_pending(ctx):
    while ctx._pending_worker_waits or (
        ctx.distributed._backend
        and ctx.distributed._backend._pending_collective_handles
    ):
        while ctx._pending_worker_waits:
            h = ctx._pending_worker_waits.pop(0)
            if h not in ctx._completed:
                ctx.engine.wait(h)
                # Record completion so the D0.3 resume invariant and the
                # D0.4-(3) duplicate guard hold.
                ctx._completed.add(h)
        backend = ctx.distributed._backend
        if backend is not None:
            while backend._pending_collective_handles:
                h, _sip_id, meta = backend._pending_collective_handles.pop(0)
                # Main context: safe; ctx.wait will not push back to pending.
                ctx.wait(h, _meta=meta)
```

**Main-context ctx.wait non-recursion invariant (normative)**: the
`ctx.wait(h, _meta=meta)` call inside `_drain_pending` runs in the main
greenlet context. D0.2's worker-fork condition (`g.parent is not None and
not g.parent.dead`) is False there, so the call takes the immediately
synchronous path and **never enqueues to `_pending_worker_waits`**. This
invariant lets the drain loop terminate without recursion or queue re-growth.
Implementations must keep main as the single root greenlet (`g.parent is
None`) so that the condition stays False during the drain.

**Why a loop**: `ctx.wait(h, _meta=meta)` is called in main context, so it
follows the D0.2 main path and **drives the engine directly** (no additional
enqueue, per the invariant above). In theory a single pass would suffice, but
the protocol is fixed at **loop-until-empty** for three reasons:

1. **Future-extension safety**: a future implementation might enqueue new
   pending items mid-drain (e.g., tree-reduce collectives with sub-handles).
   The loop protocol preserves correctness in that case.
2. **Readability**: the single sentence "the barrier drains until pending is
   empty" fully defines the semantics. It does not depend on the non-trivial
   invariant that `ctx.wait` calls never produce new enqueues.
3. **Barrier semantics are "all dependencies needed for this read are
   complete"**: in the current model, the pending items are exactly those
   dependencies, so the two notions coincide. The user's mental model is the
   former.

**Termination guarantee**, described under two regimes:

- **Current implementation**: when called in main context, `ctx.wait`
  drives the engine directly (D0.2) and enqueues nothing new. Each iteration
  strictly shrinks the pending size via `pop(0)` + `engine.wait`. The
  iteration count is therefore bounded by **the initial pending size**, and
  the drain terminates.
- **Future extension (the bound that justifies the loop protocol)**: an
  implementation might enqueue new pending items mid-drain (e.g., tree-reduce
  sub-handles). In that case the initial-size bound no longer holds. However,
  SimPy causality guarantees that the dependency DAG of handles is finite, so
  **the nesting depth is finite**. The loop protocol accommodates this case
  automatically.

Both regimes rule out infinite loops. The single-pass bound of the current
implementation is a reference value for aggressive optimization. The protocol
itself stays fixed at loop-until-empty.

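
The future-extension regime is where a single pass and loop-until-empty diverge. The sketch below is a simplified model under stated assumptions. `CHILDREN` is a hypothetical tree-reduce DAG in which completing a parent handle enqueues its sub-handles. The two drain functions are illustrative and are not KernBench code. A snapshot-based single pass leaves the newly enqueued sub-handles behind. The loop keeps going until the queue is empty and still terminates, because the DAG is finite.

```python
# Hypothetical tree-reduce DAG: completing a handle enqueues its children.
CHILDREN = {"ar": ["ar.0", "ar.1"], "ar.0": ["ar.0.0"], "ar.1": [], "ar.0.0": []}


def wait_and_expand(handle, queue, completed):
    completed.append(handle)
    queue.extend(CHILDREN.get(handle, []))   # new pending items mid-drain


def drain_single_pass(queue):
    completed = []
    for h in list(queue):                    # snapshot: ignores growth
        queue.remove(h)
        wait_and_expand(h, queue, completed)
    return completed


def drain_until_empty(queue):
    completed = []
    while queue:                             # re-checks after every item
        h = queue.pop(0)
        wait_and_expand(h, queue, completed)
    return completed


single = ["ar"]
looped = ["ar"]
single_done = drain_single_pass(single)
looped_done = drain_until_empty(looped)
```
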
**Why implicit drain at read is correct**:

- The original open question was whether to use (a) an implicit drain or
  (b) an explicit barrier. (b) is clear, but it forces TP layer users to write
  the three-step pattern `out = fc1.forward(x); ctx.drain(); result = out.numpy()`
  on every read. (a) is a single rule that guarantees the read sees the
  reflected value. It is analogous to the CUDA practice of synchronizing
  (`cudaDeviceSynchronize`) before a host copy. That is not a hidden rule but
  the **contract of a named entry point**.
- This ADR adopts (a) but **closes the entry-point list explicitly**:
  `Tensor.numpy()`, `Tensor.data` (numpy alias), `Tensor.__getitem__`, and
  `Tensor.__repr__` (when data is included). Any other official host-read APIs
  will be identified by a codebase search when this ADR is implemented. Any
  newly added host-read API must follow this contract (regression-guarded by
  tests).
- Even when `numpy()` is called directly after only `ctx.submit`, without a
  `wait`, the drain barrier still operates because the handle is in the
  pending queue. The invariant is restored at read time even if the user omits
  an explicit wait.

**`Tensor.copy_(source)`: write barrier specification**:

`copy_` is semantically "write to target". Internally, however, it calls
`source.numpy()` to fetch the source data on the host, then writes each shard
via `target._memory_store.write(...)`. Both directions are barrier-handled:

1. **Source side (read barrier)**: `source.numpy()` triggers the D0.5 read
   barrier (when the source itself is a deployed tensor with pending work).
2. **Target side (write barrier, based on global pending)**: on `copy_`
   entry, if `ctx._pending_worker_waits` or
   `backend._pending_collective_handles` is non-empty, drain via
   `g.parent.switch()` before writing. **The rule uses the global pending
   queue, not per-tensor or per-shard dependency tracking.**
   - Why global: KernBench's handle representation does not retain the
     reverse mapping "this handle writes to which shard of which target". The
     safe conservative rule is "drain if any global pending exists". As a
     result, **pending work on an unrelated tensor can also block `copy_`**;
     the drop-in invariant takes priority.
   - **Explicit tradeoff**: this rule can introduce unnecessary
     serialization between independent tensors. Under the current
     single-queue execution model this cost is acceptable. Guaranteeing
     cross-rank correctness and the "read sees latest" invariant through a
     simple rule takes precedence.
   - Practical impact: within a layer step, most of a single worker's pending
     work is its own. The extra context switches from over-barriering often
     coincide with the end-of-round scheduler drain point, so this is not
     expected to be a major issue.
   - Future refinement: per-tensor pending tracking could narrow this rule,
     but it is out of scope for this ADR.

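
The over-barrier tradeoff of the target-side rule shows up in a tiny model. Here `copy_into` is a hypothetical helper, a dict stands in for the memory store, and one list stands in for the union of both global queues. Pending work that belongs only to tensor A still forces a drain before a copy into tensor B, because the rule ignores which tensor each handle touches. Once nothing is pending, the next copy proceeds without a switch.

```python
def copy_into(store, target, source_values, global_pending, drain):
    """Target-side write barrier of D0.5 (hypothetical helper).

    Drains if *any* handle is pending, even one that belongs to an unrelated
    tensor, because handles carry no reverse map to the shards they write.
    """
    if global_pending:
        drain()                               # g.parent.switch() in the real design
    store[target] = list(source_values)       # stands in for the per-shard writes


store = {}
global_pending = ["deploy-A"]                 # pending work for tensor A only
drains = []


def drain():
    drains.append(list(global_pending))
    global_pending.clear()


copy_into(store, "B", [1, 2, 3], global_pending, drain)  # drains for A's work
copy_into(store, "B", [4, 5, 6], global_pending, drain)  # nothing pending now
```
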
**Non-barrier**:

- `tensor.shape`, `tensor.dtype`, `tensor.name`, and other **metadata-only**
  accesses do not drain, since they have no data dependency.
- `tensor.pa`, `tensor.va`, and other raw address accessors also do not drain
  (they expose the address, not the content).

**Official barrier entry-points (closed set)**:

| API | Kind | Rationale |
|---|---|---|
| `Tensor.numpy()` | read | host-observable copy |
| `Tensor.data` | read | `numpy()` alias |
| `Tensor.__getitem__` | read | shard-aligned read |
| `Tensor.__repr__` (when data is included) | read | debugging/log |
| `Tensor.copy_(source)` | read + write | source read + target write |

This contract is verified directly in T5/T6.

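
A decorator-populated registry is one possible way to keep the closed set regression-guarded, so that a test can compare it against the table. Everything here is a hypothetical sketch and not the KernBench implementation: `host_barrier`, `BARRIER_ENTRY_POINTS`, the toy `Tensor`, and its `_drain_if_pending` hook. The hook just counts and clears a shared pending list, at the point where the real barrier would call `g.parent.switch()`. Metadata such as `shape` is a plain attribute and never drains.

```python
import functools

# Registry of drain-barrier entry points: qualified name -> kind.
BARRIER_ENTRY_POINTS = {}


def host_barrier(kind):
    """Register a method as a D0.5 barrier and drain before running it."""
    def decorate(fn):
        BARRIER_ENTRY_POINTS[fn.__qualname__] = kind

        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            self._drain_if_pending()
            return fn(self, *args, **kwargs)
        return wrapper
    return decorate


class Tensor:
    """Toy tensor; `pending` is a shared list standing in for both queues."""

    def __init__(self, values, pending):
        self._values = list(values)
        self._pending = pending
        self.drains = 0
        self.shape = (len(self._values),)   # metadata: plain attribute, no drain

    def _drain_if_pending(self):
        # The real barrier would call g.parent.switch() here.
        if self._pending:
            self.drains += 1
            self._pending.clear()

    @host_barrier("read")
    def numpy(self):
        return list(self._values)

    @property
    @host_barrier("read")
    def data(self):
        return self.numpy()

    @host_barrier("read")
    def __getitem__(self, idx):
        return self._values[idx]

    @host_barrier("read")
    def __repr__(self):
        return f"Tensor({self._values})"

    @host_barrier("read + write")
    def copy_(self, source):
        self._values = source.numpy()
        return self


# The regression guard: the registry must equal the closed set in the table.
EXPECTED = {
    "Tensor.numpy": "read",
    "Tensor.data": "read",
    "Tensor.__getitem__": "read",
    "Tensor.__repr__": "read",
    "Tensor.copy_": "read + write",
}
assert BARRIER_ENTRY_POINTS == EXPECTED
```
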
#### D0.6 Why the worker function API is unchanged (informative)

- Internally, `torch.zeros(...)` is a `self.submit(msg)` + `self.wait(h)` pair.
  `wait` auto-defers to main per D0.2/D0.3, so it appears synchronous from the
  outside but yields once.
- `tensor.numpy()` follows D0.5's host-read barrier: drain then read when
  pending work exists, an immediate read otherwise.
- `dist.all_reduce` continues to use the existing `_defer_wait=True` +
  `_pending_collective_handles` path. D0.4's drain processes both queues
  together.

#### D0.7 Invariants

- **The kernel greenlet's `_parent` is always main**, because `env.run` never
  runs in worker context. (Core assertion of T3.)
- **Cross-rank synchronization point**: a drain occurs only after every worker
  has yielded, so the kernels of all ranks advance together within one round
  (a prerequisite for cross-rank IPCQ exchange).
- **Single-driver compatibility**: see D0.4-(5).

### D1. `torch.multiprocessing.spawn(fn, args, nprocs)`

Real-PyTorch API parity plus a single implementation site for D0's scheduler
loop.

#### D1.0 API parity only — not execution parity (normative)

The name `torch.multiprocessing.spawn` is restricted to **API signature
parity**. The actual execution model is a **cooperative greenlet scheduler**
(single Python process, single OS thread, round-robin drive per D0.4). Real
PyTorch's `torch.multiprocessing.spawn` provides the following guarantees.
This ADR does **not** provide them; they are explicitly **non-goals**:

- Process isolation (an independent OS process per rank).
- Independent address space (each rank with its own Python heap).
- Failure isolation (a hard crash in one rank not affecting the others).
- OS-level scheduler fairness (preemptive time slicing between ranks).
- Inter-process primitives such as `mp.Queue` and `mp.Lock`.

Actual properties of this implementation:

- All ranks are greenlets inside the same Python process. Shared global state
  is visible as-is (an intentional simulation convenience).
- Execution is single-threaded under the GIL, so nothing runs in parallel.
  Only the "logical concurrency" of SimPy event ordering is reproduced.
- An unhandled exception in any worker aborts the entire simulation
  (D0.4-(4)).

**Caller's obligation**: when porting real-PyTorch multi-process samples to
KernBench, remove logic that relies on process isolation (e.g., `os.getpid`,
independent temp files, signal handling). The namespace name is preserved for
code portability, but the semantics differ.

#### D1.1 Public surface

```python
# runtime_api/multiprocessing.py (new)
class _MultiprocessingNamespace:
    def __init__(self, ctx):
        self._ctx = ctx

    def spawn(self, fn, args: tuple, nprocs: int, join: bool = True) -> None:
        """Spawn `nprocs` worker greenlets, each calling fn(rank, *args).

        Mirrors the torch.multiprocessing.spawn signature (minus `daemon`
        and `start_method`). Drives the D0 scheduler loop until all workers
        finish.
        """
        ...
```

#### D1.2 Implementation

```python
def spawn(self, fn, args, nprocs, join=True):
    from greenlet import greenlet
    ctx = self._ctx
    dist = ctx.distributed
    gs: list[greenlet] = []
    errors: dict[int, Exception] = {}
    for rank in range(nprocs):
        def _entry(r=rank):
            try:
                fn(r, *args)
            except Exception as e:
                # Root-cause failures only: SystemExit from sibling cleanup is
                # a BaseException and is deliberately not recorded.
                errors[r] = e
                raise
        g = greenlet(_entry)
        dist._bind_rank(g, rank)
        gs.append(g)

    try:
        while True:
            alive = [g for g in gs if not g.dead]
            if not alive:
                break
            for g in alive:
                if not g.dead:
                    g.switch()
            _drain_pending(ctx)  # D0.5
    except Exception as outer:
        # Sibling cleanup per D0.4-(4)
        for other in gs:
            if not other.dead:
                try:
                    other.throw(SystemExit)
                except (SystemExit, Exception):
                    # An uncaught SystemExit propagates back to main; swallow it.
                    pass
        backend = getattr(dist, "_backend", None)
        if backend is not None:
            if hasattr(backend, "_barrier"):
                backend._barrier.reset()
            if getattr(backend, "_pending_collective_handles", None) is not None:
                backend._pending_collective_handles.clear()
        ctx._pending_worker_waits.clear()
        raise SpawnException(errors) from outer
    # `join=True` semantics: we already wait for all workers.
```

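
The cleanup contract of D0.4-(4) can be exercised without greenlets. In this stdlib-only sketch, generators stand in for worker greenlets. `generator.close()` stands in for `other.throw(SystemExit)`: both deliver a `BaseException` that the `except Exception` entry wrapper does not record. Clearing a list stands in for `_drain_pending` and for the queue reset. `mini_spawn` and `worker` are hypothetical names, and `SpawnException` is reproduced from D0.4-(4) so the block runs on its own. When rank 1 fails, ranks 0 and 2 are terminated but are absent from `errors`. The pending queue is cleared, and the root cause is chained as `__cause__`.

```python
import inspect


class SpawnException(RuntimeError):
    """Reproduced from D0.4-(4) so this block runs on its own."""

    def __init__(self, errors):
        self.errors = errors
        first = next(iter(errors.items()), None)
        msg = (f"spawn failed on ranks {sorted(errors.keys())}"
               + (f": rank {first[0]} raised {first[1]!r}" if first else ""))
        super().__init__(msg)


def mini_spawn(fn, args, nprocs, pending, closed):
    """Generator-based stand-in for D1.2's spawn loop (hypothetical)."""
    errors = {}

    def entry(r):
        try:
            yield from fn(r, *args)
        except Exception as e:          # GeneratorExit is not an Exception
            errors[r] = e
            raise

    gens = [entry(r) for r in range(nprocs)]
    alive = list(range(nprocs))
    try:
        while alive:
            for r in list(alive):
                try:
                    next(gens[r])       # one turn, until the worker yields
                except StopIteration:
                    alive.remove(r)
            pending.clear()             # stands in for _drain_pending(ctx)
    except Exception as outer:
        for r, g in enumerate(gens):    # sibling cleanup
            if inspect.getgeneratorstate(g) != inspect.GEN_CLOSED:
                g.close()               # ~ other.throw(SystemExit)
                closed.append(r)
        pending.clear()                 # ~ clearing both pending queues
        raise SpawnException(errors) from outer


def worker(rank, fail_rank, pending):
    pending.append(f"r{rank}-step0")
    yield
    if rank == fail_rank:
        raise ValueError(f"boom on rank {rank}")
    pending.append(f"r{rank}-step1")
    yield
```
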
#### D1.3 `torch` namespace attach

In `runtime_api/context.py` `__post_init__`:

```python
self.multiprocessing = _MultiprocessingNamespace(self)
```

Bench code then calls `torch.multiprocessing.spawn(worker, args=(ws,), nprocs=ws)`.

#### D1.4 Migration of existing benches

The hand-rolled loop in `benches/ccl_allreduce.py` collapses into a single
`torch.multiprocessing.spawn` call. Existing matrix regressions are preserved.
The `ring_default_ws` case, currently marked xfail, is expected to flip to
PASS thanks to D0 (workers no longer orphan the kernel greenlet).

### D2. New package `kernbench.tp`

```
src/kernbench/tp/
    __init__.py          # public API re-exports
    parallel_state.py    # TP group management (currently a single global group)
    layers.py            # ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
    primitives.py        # {copy_to, reduce_from, scatter_to, gather_from}_tp_region
    kernels.py           # gemm kernel launched by TP layers (reusable)
    mappings.py          # forward identity/all_reduce, backward stub
```

### D3. `parallel_state` — TP group

```python
# parallel_state.py
_TP_WORLD_SIZE = None


def initialize_model_parallel(tensor_model_parallel_size: int) -> None:
    """Initialize TP group. Must be called after dist.init_process_group."""
    global _TP_WORLD_SIZE
    from kernbench.runtime_api.distributed import get_dist  # or torch.distributed
    dist = get_dist()
    total = dist.get_world_size()
    if tensor_model_parallel_size != total:
        raise NotImplementedError(
            "Only TP == world_size supported in initial scope"
        )
    _TP_WORLD_SIZE = tensor_model_parallel_size


def get_tensor_model_parallel_world_size() -> int:
    return _TP_WORLD_SIZE


def get_tensor_model_parallel_rank() -> int:
    from kernbench.runtime_api.distributed import get_dist
    return get_dist().get_rank()  # ADR-0024 greenlet-local rank
```

Initial scope: TP size = world_size = topology SIP count. Pure TP model.

### D4-pre. TP shard ownership vs DPPolicy — role separation (normative)

The weight/output representation of TP layers keeps two concepts clearly
separated:

| Concept | Decided by | Scope |
|---|---|---|
| **TP shard ownership** (which rank owns which slice of the weight) | greenlet-local rank + `torch.ahbm.set_device(rank)` (ADR-0024 D2/D3) | **cross-rank, cross-SIP** |
| **Intra-rank placement** (how the owned slice is distributed across cube × PE inside the rank) | `DPPolicy(cube=..., pe=...)` (ADR-0026) | **inside one rank (within the SIP boundary)** |

Thus, consider `ColumnParallelLinear` creating a weight of shape
`(in_features, out_features // ws)` with
`DPPolicy(cube="column_wise", pe="column_wise")`:

- The slice owned by **rank r** is the column range
  `[r * k_local, (r + 1) * k_local)` of the full weight. **`set_device(r)`**
  determines this (that rank resides on SIP r).
- **Inside that slice**, **DPPolicy** determines the column-wise distribution
  across cubes × PEs.

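
The ownership arithmetic in the first bullet is simple enough to state as code. This sketch uses a hypothetical helper, `owned_columns`. It maps each rank to its half-open column range of the full `out_features` axis and checks that the ranges tile the axis without overlap. How each range is then spread over cubes and PEs is DPPolicy's concern and is not modeled here.

```python
def owned_columns(rank: int, out_features: int, world_size: int) -> range:
    """Columns [rank * k_local, (rank + 1) * k_local) of the full weight."""
    if out_features % world_size != 0:
        raise ValueError("out_features must be divisible by world_size")
    if not 0 <= rank < world_size:
        raise ValueError("rank out of range")
    k_local = out_features // world_size
    return range(rank * k_local, (rank + 1) * k_local)


# Example: out_features = 12 split over 4 SIPs gives k_local = 3.
slices = [owned_columns(r, 12, 4) for r in range(4)]
print([(s.start, s.stop) for s in slices])  # [(0, 3), (3, 6), (6, 9), (9, 12)]
```
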
The two axes are **independent**. If two ranks build their own slice with
|
||
the same DPPolicy, the slices themselves live on different SIPs but the
|
||
intra-SIP placement pattern is the same. Conversely, changing DPPolicy to
|
||
`cube="replicate", pe="replicate"` preserves TP shard ownership and only
|
||
changes intra-rank placement.
|
||
|
||
**Mistakes that blur this boundary** (forbidden by this ADR):

- Reintroducing a "SIP axis" into DPPolicy (it was removed in ADR-0026).
- Having TP layers express cross-rank sharding through `DPPolicy` alone, without `set_device`. The result is indistinguishable from a vertical split within a single rank.

The TP layers in this ADR always treat weights and outputs from a single perspective: rank = SIP, each rank owns one slice, and DPPolicy describes only the intra-SIP distribution of that slice.

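The ownership rule above can be checked with plain NumPy, independently of KernBench. This is a minimal sketch under stated assumptions: `owned_columns` and `shard_columns` are hypothetical helpers (not part of `kernbench.tp`), a NumPy array stands in for the logical full weight, and intra-rank DPPolicy placement is deliberately not modeled because it does not change which columns a rank owns.

```python
import numpy as np


def owned_columns(rank: int, world_size: int, out_features: int) -> range:
    """Hypothetical helper: column range of the full weight owned by `rank`.

    Mirrors [r * k_local, (r + 1) * k_local) with k_local = out_features // world_size.
    """
    if out_features % world_size != 0:
        raise ValueError("out_features must be divisible by world_size")
    k_local = out_features // world_size
    return range(rank * k_local, (rank + 1) * k_local)


def shard_columns(full_weight: np.ndarray, world_size: int) -> list:
    """Split a logical (in_features, out_features) weight into per-rank slices."""
    out_features = full_weight.shape[1]
    shards = []
    for r in range(world_size):
        cols = owned_columns(r, world_size, out_features)
        shards.append(full_weight[:, cols.start:cols.stop])
    return shards


W = np.arange(4 * 8, dtype=np.float32).reshape(4, 8)
shards = shard_columns(W, world_size=2)
print([s.shape for s in shards])                           # [(4, 4), (4, 4)]
print(np.array_equal(np.concatenate(shards, axis=1), W))   # True
```

The slices are disjoint and cover the full weight, so ownership is fully determined by the rank; DPPolicy only decides how each slice is laid out inside its SIP.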
### D4. `ColumnParallelLinear`

**Important**: no new host-side `torch.matmul` abstraction is introduced. The layer's forward calls the existing gemm kernel via `torch.launch("gemm", gemm_kernel, ...)`, the pattern already used by KernBench benches ([benches/gemm_single_pe.py](benches/gemm_single_pe.py), [benches/gpt3_qkv.py](benches/gpt3_qkv.py)).

```python
# layers.py
from kernbench.policy.placement.dp import DPPolicy
from kernbench.tp.kernels import _gemm_kernel
from kernbench.tp.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)


class ColumnParallelLinear:
    """Shards the K (out_features) axis of the weight across TP ranks.

    forward(x):
        x:   (M, N)               full-replicated across ranks
        W_k: (N, K / world_size)  rank-local slice (placed on SIP r via set_device)
        y_k = x @ W_k -> (M, K / world_size), the rank-local output

    The output is column-sharded, which is the input form RowParallelLinear expects.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 dtype: str = "f16", torch=None):
        if torch is None:
            raise ValueError("ColumnParallelLinear requires the KernBench `torch` handle")
        ws = get_tensor_model_parallel_world_size()
        assert out_features % ws == 0, "out_features must be divisible by the TP size"
        self.in_features = in_features
        self.k_local = out_features // ws
        self._torch = torch
        # Each rank owns its own slice; set_device(rank) places it on SIP r.
        self.weight = torch.zeros(
            (in_features, self.k_local), dtype=dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="col_parallel_w",
        )
        self.bias = None
        if bias:
            self.bias = torch.zeros(
                (self.k_local,), dtype=dtype,
                dp=DPPolicy(cube="replicate", pe="replicate"),
                name="col_parallel_b",
            )

    def forward(self, x):
        # x is full-replicated (guaranteed by the caller), so this is a plain local gemm.
        M = x.shape[0]
        out = self._torch.empty(
            (M, self.k_local), dtype=x.dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="col_parallel_out",
        )
        self._torch.launch(
            "col_parallel_gemm", _gemm_kernel,
            x, self.weight, out, M, self.in_features, self.k_local,
        )
        # Bias add would be a separate kernel or the fused bias of a composite gemm.
        # The initial scope is verified with bias=False only.
        return out
```

**Yield-safety contract (normative)**: `ColumnParallelLinear.forward` contains one `torch.launch` call, i.e., a kernel launch followed by an internal `ctx.wait`. This automatically satisfies the D0.4-(1) condition that a worker yields within a finite number of steps, so users of TP layers do not need to insert yield points manually.

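The column-parallel contract in the D4 docstring can be sanity-checked numerically. This NumPy sketch models the math only, not the KernBench code path: `column_parallel_forward` is a hypothetical stand-in for one rank's `_gemm_kernel` launch, and each rank's slice is cut directly from a dense reference weight. It checks that each rank-local output equals the matching column block of the dense product, so concatenating the rank outputs recovers `x @ W` with no communication.

```python
import numpy as np


def column_parallel_forward(x: np.ndarray, full_weight: np.ndarray,
                            rank: int, world_size: int) -> np.ndarray:
    """Hypothetical model of one rank's ColumnParallelLinear.forward.

    x: (M, N), replicated on every rank.
    full_weight: (N, K); rank r uses columns [r * k_local, (r + 1) * k_local).
    Returns the rank-local output of shape (M, K / world_size).
    """
    n, k = full_weight.shape
    assert x.shape[1] == n and k % world_size == 0
    k_local = k // world_size
    w_local = full_weight[:, rank * k_local:(rank + 1) * k_local]
    return x @ w_local


rng = np.random.default_rng(0)
M, N, K, WS = 2, 8, 12, 4
x = rng.standard_normal((M, N))
W = rng.standard_normal((N, K))

outs = [column_parallel_forward(x, W, r, WS) for r in range(WS)]
print(outs[0].shape)                                       # (2, 3)
print(np.allclose(np.concatenate(outs, axis=1), x @ W))    # True
```

Because each rank's output stays column-sharded, it can feed `RowParallelLinear` directly; no gather is needed between the two layers.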
### D5. `RowParallelLinear`

```python
class RowParallelLinear:
    """Shards the N (in_features) axis of the weight across TP ranks.

    forward(x):
        x:   (M, N / world_size)  rank-local slice (output of ColumnParallelLinear)
        W_k: (N / world_size, K)  rank-local slice
        y_k = x @ W_k -> (M, K), a partial sum on each rank
        y = all_reduce(y_k, op="sum") -> (M, K) on every rank
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 dtype: str = "f16", torch=None):
        if torch is None:
            raise ValueError("RowParallelLinear requires the KernBench `torch` handle")
        ws = get_tensor_model_parallel_world_size()
        assert in_features % ws == 0, "in_features must be divisible by the TP size"
        self.n_local = in_features // ws
        self.out_features = out_features
        self._torch = torch
        self.weight = torch.zeros(
            (self.n_local, out_features), dtype=dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="row_parallel_w",
        )
        # Bias lives only on rank 0 (Megatron convention). Omitted in the initial scope.
        self.bias = None

    def forward(self, x):
        M = x.shape[0]
        y_partial = self._torch.empty(
            (M, self.out_features), dtype=x.dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="row_parallel_partial",
        )
        self._torch.launch(
            "row_parallel_gemm", _gemm_kernel,
            x, self.weight, y_partial, M, self.n_local, self.out_features,
        )
        # Cross-rank reduction. ADR-0024's dist.all_reduce works correctly
        # under D0 + mp.spawn (the kernel parent remains main).
        self._torch.distributed.all_reduce(y_partial, op="sum")
        return y_partial
```

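The row-parallel docstring can be checked the same way. In this NumPy sketch, `row_parallel_partial` is a hypothetical name (not a KernBench API): each rank multiplies its column block of the input by its row block of the weight, and a plain Python `sum` stands in for `all_reduce(op="sum")`. The per-rank partials add up to the dense product, which is why one all-reduce per `RowParallelLinear.forward` is sufficient.

```python
import numpy as np


def row_parallel_partial(x_local: np.ndarray, full_weight: np.ndarray,
                         rank: int, world_size: int) -> np.ndarray:
    """Hypothetical model of RowParallelLinear.forward before the all-reduce.

    x_local: (M, N / world_size), the rank-local input slice.
    full_weight: (N, K); rank r uses rows [r * n_local, (r + 1) * n_local).
    Returns this rank's partial sum of shape (M, K).
    """
    n = full_weight.shape[0]
    assert n % world_size == 0
    n_local = n // world_size
    assert x_local.shape[1] == n_local
    w_local = full_weight[rank * n_local:(rank + 1) * n_local, :]
    return x_local @ w_local


rng = np.random.default_rng(1)
M, N, K, WS = 2, 8, 6, 4
x = rng.standard_normal((M, N))
W = rng.standard_normal((N, K))
n_local = N // WS

partials = [row_parallel_partial(x[:, r * n_local:(r + 1) * n_local], W, r, WS)
            for r in range(WS)]
y = sum(partials)  # what all_reduce(op="sum") would leave on every rank
print(partials[0].shape)       # (2, 6)
print(np.allclose(y, x @ W))   # True
```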
**Yield-safety contract (normative)**: `RowParallelLinear.forward` performs a launch with its internal wait, followed by an `all_reduce` (the defer + worker-yield pattern), so it is guaranteed to yield **at least twice per forward**. The scheduler-progress condition of D0.4-(1) is therefore satisfied automatically. Every TP layer forward in this ADR upholds the invariant "yield-safe because it contains at least one wait or collective", and any future TP primitive (e.g., VocabParallelEmbedding) must keep the same contract.

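The "at least two yields per forward" argument can be illustrated with a toy cooperative scheduler. This is only a model under explicit assumptions: Python generators stand in for the ADR-0024 greenlets, `yield` stands in for a switch back to the main scheduler, and the shared list `board` stands in for the all-reduce buffers. None of these names are KernBench APIs. Each worker yields once at the gemm wait and at least once inside the all-reduce, so the round-robin driver always regains control and every rank finishes with the same reduced value.

```python
from typing import Dict, Generator, List, Tuple


def row_parallel_forward_model(rank: int, world_size: int,
                               board: List[int]) -> Generator[str, None, int]:
    """Toy model of RowParallelLinear.forward as a cooperative task."""
    partial = rank + 1           # stand-in for the rank-local gemm result
    yield "gemm_wait"            # launch -> internal ctx.wait: switch to main
    board.append(partial)        # contribute to the collective
    yield "all_reduce_defer"     # the collective always defers at least once
    while len(board) < world_size:
        yield "all_reduce_wait"  # keep yielding until every rank has contributed
    return sum(board)


def run_round_robin(world_size: int) -> Tuple[Dict[int, int], Dict[int, int]]:
    """Drive all rank tasks round-robin; return (results, yield counts) per rank."""
    board: List[int] = []
    tasks = {r: row_parallel_forward_model(r, world_size, board)
             for r in range(world_size)}
    results: Dict[int, int] = {}
    yields = {r: 0 for r in range(world_size)}
    while tasks:
        for r in list(tasks):
            try:
                next(tasks[r])
                yields[r] += 1
            except StopIteration as done:
                results[r] = done.value
                del tasks[r]
    return results, yields


results, yields = run_round_robin(4)
print(results)  # {0: 10, 1: 10, 2: 10, 3: 10}
print(yields)   # {0: 2, 1: 2, 2: 2, 3: 2}
```

If the all-reduce model spun on `len(board) < world_size` without yielding, rank 0 would loop forever before rank 1 ever ran; that is the kind of progress failure D0.4-(1) rules out.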
### D6. Primitive functions

```python
# primitives.py
def copy_to_tp_region(x):
    """Forward: identity. Backward: all-reduce (implemented when training is added)."""
    return x


def reduce_from_tp_region(x, torch):
    """Forward: all-reduce. Backward: identity."""
    torch.distributed.all_reduce(x, op="sum")
    return x


def scatter_to_tp_region(x):
    raise NotImplementedError(
        "Phase 2: replaced by users creating already-sharded tensors"
    )


def gather_from_tp_region(x):
    raise NotImplementedError(
        "Phase 2: requires all-gather kernel as a prerequisite (future)"
    )
```

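The forward semantics of `reduce_from_tp_region` can be modeled on a list of per-rank arrays. In this sketch, `simulated_all_reduce` and `reduce_from_tp_region_model` are hypothetical NumPy stand-ins, not KernBench functions. The model assumes the all-reduce updates each rank's buffer in place, as the D5/D6 code relies on when it returns the buffer it passed to `all_reduce`. `copy_to_tp_region` is an identity in the forward direction, so it is not modeled.

```python
from typing import List

import numpy as np


def simulated_all_reduce(per_rank: List[np.ndarray]) -> None:
    """Hypothetical stand-in for all_reduce(op="sum") across TP ranks.

    Every rank's buffer is overwritten in place with the elementwise sum,
    so callers that hold a reference to their buffer see the reduced value.
    """
    total = np.sum(np.stack(per_rank), axis=0)
    for buf in per_rank:
        buf[...] = total


def reduce_from_tp_region_model(per_rank: List[np.ndarray]) -> List[np.ndarray]:
    """Forward of reduce_from_tp_region, applied to all ranks at once."""
    simulated_all_reduce(per_rank)
    return per_rank  # same objects, now holding the sum


bufs = [np.full((2, 3), float(r)) for r in range(4)]
before_ids = [id(b) for b in bufs]
out = reduce_from_tp_region_model(bufs)
print(out[2][0, 0])                        # 6.0
print([id(b) for b in out] == before_ids)  # True
```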
### D7. Sample bench — 2-layer MLP with TP

```python
# benches/tp_mlp.py (new)
from kernbench.policy.placement.dp import DPPolicy
import kernbench.tp as tp
import numpy as np


def worker(rank: int, world_size: int, torch):
    torch.ahbm.set_device(rank)
    tp.initialize_model_parallel(world_size)

    B, D_in, D_hidden, D_out = 1, 512, 2048, 512
    fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=torch)
    fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=torch)

    x = torch.zeros(
        (B, D_in), dtype="f16",
        dp=DPPolicy(cube="replicate", pe="replicate"),
        name="x",
    )
    # Initialize x with a simple pattern (here: a constant).
    x.copy_(torch.from_numpy(np.full((B, D_in), 0.1, dtype=np.float16)))

    h = fc1.forward(x)  # column-sharded: (B, D_hidden / ws)
    y = fc2.forward(h)  # all-reduced: (B, D_out) on every rank

    # Only rank 0 prints / verifies the result.
    if rank == 0:
        result = y.numpy()
        # With zero-initialized weights every value is 0; in the initial scope,
        # completing at all is the check.
        print(f" tp_mlp: shape={result.shape}, mean={float(result.mean()):.4f}")


def run(torch):
    torch.distributed.init_process_group(backend="ahbm")
    ws = torch.distributed.get_world_size()
    torch.multiprocessing.spawn(worker, args=(ws,), nprocs=ws)
```

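To size the bench, the per-rank footprint and all-reduce payload follow directly from the D7 shapes. The helper below is a back-of-the-envelope sketch, not part of KernBench; it assumes f16 is 2 bytes per element, bias=False, and an all-reduce payload of one full `(B, D_out)` output, and it ignores any padding or alignment that DPPolicy placement might add.

```python
from typing import Dict

F16_BYTES = 2  # assumption: dtype="f16" occupies 2 bytes per element


def tp_mlp_per_rank_bytes(ws: int, B: int = 1, D_in: int = 512,
                          D_hidden: int = 2048, D_out: int = 512) -> Dict[str, int]:
    """Hypothetical helper: per-rank tensor sizes for the D7 two-layer TP MLP."""
    assert D_hidden % ws == 0, "D_hidden must be divisible by the TP size"
    d_local = D_hidden // ws
    return {
        "fc1_weight": D_in * d_local * F16_BYTES,    # ColumnParallel (D_in, D_hidden / ws)
        "fc2_weight": d_local * D_out * F16_BYTES,   # RowParallel (D_hidden / ws, D_out)
        "h_local": B * d_local * F16_BYTES,          # column-sharded activation
        "all_reduce_payload": B * D_out * F16_BYTES, # full output, summed across ranks
    }


for ws in (1, 2, 4):
    print(ws, tp_mlp_per_rank_bytes(ws))
```

Under these assumptions the weight bytes per rank shrink as 1/ws while the all-reduce payload stays fixed at `B * D_out * 2` bytes, which is one of the scaling effects a TP bench can expose.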
### D8. Non-functional — training not supported

This ADR is **inference/forward only**. Backward passes, gradients, and optimizers are future work. This limitation is natural because KernBench is not a training system.

### D9. Initial-scope constraints

- TP size = world_size (no mixed DP+TP).
- `scatter_to_tp_region` and `gather_from_tp_region` are unimplemented.
- **Weights default to zero.** Proper initialization schemes (Xavier, Kaiming, etc.) are future work. Tests inject deterministic non-zero patterns via `tensor.copy_` to verify numerical correctness (T2/T6). In short: production default = zero, verification = deterministic non-zero.
- Bias is omitted in the initial scope (Megatron's rank-0-only bias policy is future work).
- Pipeline parallelism is out of scope.
- VocabParallelEmbedding requires an all-gather as a prerequisite, so only a stub is provided.

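The "production default = zero, verification = deterministic non-zero" rule matters because zero weights make a sharding bug invisible. The NumPy sketch below uses hypothetical helpers (`tp_mlp`, `deterministic`) and a dense reference in place of a real multi-rank run. It simulates the D7 ColumnParallel + RowParallel pair, optionally with a bug where every rank uses rank 0's slices. With zero weights the buggy output matches the reference; with a fixed non-zero pattern the mismatch is caught.

```python
import numpy as np


def tp_mlp(x: np.ndarray, w1: np.ndarray, w2: np.ndarray, ws: int,
           buggy: bool = False) -> np.ndarray:
    """Simulate the D7 MLP across `ws` ranks; `buggy` makes every rank use slice 0."""
    d_local = w1.shape[1] // ws
    y = np.zeros((x.shape[0], w2.shape[1]))
    for r in range(ws):
        s = 0 if buggy else r
        cols = slice(s * d_local, (s + 1) * d_local)
        h_local = x @ w1[:, cols]    # ColumnParallelLinear on rank r
        y += h_local @ w2[cols, :]   # RowParallel partial; the running sum plays all_reduce
    return y


def deterministic(shape) -> np.ndarray:
    """A fixed non-zero pattern (hypothetical choice; any reproducible pattern works)."""
    n = int(np.prod(shape))
    return ((np.arange(n) % 7) - 3).reshape(shape) / 8.0


x = np.full((1, 8), 0.1)
for name, make in (("zero", np.zeros), ("pattern", deterministic)):
    w1, w2 = make((8, 16)), make((16, 8))
    ref = x @ w1 @ w2
    ok = np.allclose(tp_mlp(x, w1, w2, ws=4), ref)
    bug_caught = not np.allclose(tp_mlp(x, w1, w2, ws=4, buggy=True), ref)
    print(name, ok, bug_caught)
# zero True False
# pattern True True
```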
### D10. Regression: `ring_default_ws` xfail removal — mandatory acceptance

Thanks to D0 (worker-wait generalization) and D0.5 (host-read barrier), every worker-driven `ctx.wait` and host read is routed through the main-drain path, which removes the cause of the kernel-greenlet orphan seen in ADR-0024 Phase B. Flipping the existing matrix test's strict-xfail `ring_default_ws` case to **PASS** once this ADR is implemented is therefore a **mandatory regression criterion**. The observable acceptance criteria are specified in **T7** (no deadlock, no GreenletExit, numerical tolerance, etc.).

---

## Dependencies

- **ADR-0024** (launcher): rank = SIP, greenlet-local rank, `torch.ahbm.set_device(rank)`.
- **ADR-0026** (intra-device DPPolicy): per-rank slice representation of weight tensors.
- **ADR-0023 / ADR-0025** (IPCQ): foundation of the `dist.all_reduce` implementation.

---

## Non-goals

- **Backward pass / training**: inference only. Training simulation belongs to a separate ADR.
- **Mixed parallelism (DP + TP + PP)**: pure TP only at the start.
- **Weight init schemes**: simple zero / debug patterns only.
- **Fused ops**: Megatron's fused matmul+bias+GELU is a kernel-level concern.
- **DTensor integration**: future work in ADR-0028.
- **Host-side `torch.matmul` abstraction**: TP layers call the existing gemm kernel via `torch.launch("gemm", gemm_kernel, ...)`. No new matmul host op is introduced.

---

## Open questions

- **Location of `initialize_model_parallel`**: `kernbench.tp.initialize_model_parallel` (current decision) vs. real PyTorch's `torch.distributed.init_device_mesh`. For now it stays in the TP-only module.
- **Weight init**: the ADR uses zero. A debug pattern (e.g., identity) may be needed for meaningful verification; add one during Phase 1 testing if needed.
- **Bias placement policy**: Megatron places the RowParallelLinear bias only on rank 0. The initial scope sidesteps this with bias=False.
- **GEMM kernel location**: `kernbench.tp.kernels._gemm_kernel` vs. importing from the existing `benches/gemm_single_pe.py`. Because TP must not depend on benches, the kernel is duplicated inside `tp`. It can later migrate to a shared `kernbench.kernels` package.

**Resolved (open in earlier revisions)**:

- ~~Drain timing on `tensor.numpy()` calls~~ → **decided in D0.5**: the official host-read entry points (`numpy`, `data`, `__getitem__`, and data-containing `__repr__`) are automatic drain barriers. Metadata-only accessors are not barriers.

---

## Consequences

### Positive

- **Easy porting of Megatron code**: the API matches real training code.
- **TP benchmarking enabled**: supports research on scaling, communication-compute overlap, and other HW characteristics.
- **`ring_default_ws` xfail removal**: as a byproduct of D0, the ADR-0024 Phase B blocker is resolved.
- **Scheduler-loop unification**: introducing D1 (`mp.spawn`) removes the hand-rolled loop, and subsequent collective/TP benches reuse the same pattern.
- **DPPolicy semantics clarified** (synergy with ADR-0026): the TP layers serve as a best-practice example of using DPPolicy only for intra-device placement.

### Negative

- Maintenance cost of a new module (`kernbench.tp`).
- Limited initial scope (pure TP only, forward only).
- D0's generalization changes the semantics of `ctx.wait`, so compatibility with single-driver tests must be explicitly verified (T7).

### Neutral

- A pure upper layer added on top of ADR-0024/0026, with no impact on the hardware-simulation stack (apart from D0).