# ADR-0027: Megatron-style Tensor Parallelism API

## Status

Accepted

## Context

### Goal

Support inter-SIP tensor parallelism (TP) via a **Megatron-LM style explicit parallel layer** API. Declarative abstractions like DTensor are future work in a separate ADR (0028).

Why Megatron-style was chosen:

- TP arises at specific layer boundaries of a model. Explicit primitives are natural to the mental model.
- The de-facto industry standard established by NVIDIA Megatron / DeepSpeed.
- DTensor is declarative, so its design space is larger → phased approach.

### TP primitive spec (Megatron-LM reference)

- **ColumnParallelLinear**: shards the weight's **column (out_features)** axis across TP ranks. Input is full-replicated, output is column-sharded. When a RowParallelLinear follows, no forward all-reduce is required.
- **RowParallelLinear**: shards the weight's **row (in_features)** axis across TP ranks. Input is already column-sharded (the output of ColumnParallel). Requires an **all-reduce** at the end of forward.
- **VocabParallelEmbedding**: shards the embedding along the vocab axis. all-reduce at the end of forward. (A stub in the initial scope; full implementation requires an all-gather kernel as a prerequisite.)
- **`copy_to_tp_region`**, **`reduce_from_tp_region`**, **`scatter_to_tp_region`**, **`gather_from_tp_region`**: basic primitives.

### Problems to solve

1. **Worker-wait generalization (D0)**: extend the defer/yield/drain pattern of `dist.all_reduce` to every `ctx.wait` path. **The biggest architectural decision of this ADR.**
2. **Launcher API normalization (D1)**: current benches use a hand-rolled greenlet loop. Absorb it into `torch.multiprocessing.spawn(fn, args, nprocs)` to preserve the real-PyTorch API surface + concentrate D0's scheduler drain in a single implementation site.
3. **Per-rank weight shard representation**: each worker owns its own slice of the weight tensor.
   Naturally expressed via ADR-0024's `set_device(rank)` + ADR-0026's intra-device DPPolicy.
4. **Forward-only scope**: KernBench currently has no backward (simulation purposes). This ADR prioritizes **forward only**. Training simulation is a separate ADR.
5. **Collective call site**: RowParallelLinear calls `all_reduce` at the end of forward. Naturally works with ADR-0024's multi-greenlet structure + D0 generalization.
6. **TP group concept**: Megatron crosses DP × TP × PP groups. The initial scope simplifies to **TP group = all SIPs**. Mixed DP+TP is future work.

---

## Decision

### D0. Worker-wait generalization: `ctx.wait` defers to main when in worker context

**Restating the problem.** `kernel_runner.run` captures the `greenlet.getcurrent()` at spawn time as the kernel greenlet's `_parent` ([kernel_runner.py:94](src/kernbench/triton_emu/kernel_runner.py#L94)). If `env.run` runs in the main context, parent=main is safe. If `env.run` runs in a worker context, parent=worker, and the moment the worker yields/finishes the kernel greenlet becomes an orphan → `GreenletExit` → failure of ADR-0024 Phase B's `ring_default_ws`.

**Resolution.** When a worker greenlet calls `ctx.wait(h)`, instead of driving `env.run` directly, **yield to the main scheduler**. main drives env.run and, once the handle completes, control returns to the worker.

#### D0.1 `RuntimeContext` extension

```python
# context.py
@dataclass
class RuntimeContext:
    ...
    _pending_worker_waits: list[RequestHandle] = field(default_factory=list, init=False)
```

#### D0.2 `ctx.wait` worker fork

```python
def wait(self, handle, *, _meta=None):
    # Fast-path: already completed; skip enqueue + switch (consistent with
    # D0.4-(3) idempotency). Avoids needless worker→main→worker round-trip
    # and prevents redundant _pending_worker_waits growth.
    if handle in self._completed:
        completion, _trace = self.engine.get_completion(handle)
        return completion

    from greenlet import getcurrent
    g = getcurrent()
    if g.parent is not None and not g.parent.dead:
        # Worker greenlet: defer to main. Push handle, yield to parent.
        # Parent (scheduler loop) drains env.run, then switches back.
        self._pending_worker_waits.append(handle)
        g.parent.switch()
        # On resume: handle must have completed (main drained the list).
        # Fall through to the status-quo completion/trace assembly.

    # Main context (or single-driver): drive engine directly.
    wait_fn = getattr(self.engine, "wait", None)
    if wait_fn is not None:
        wait_fn(handle)

    completion, trace = self.engine.get_completion(handle)
    self._completed.add(handle)
    if _meta is not None and trace is not None:
        entry = dict(trace) if isinstance(trace, dict) else {"raw": trace}
        entry.update(_meta)
        self._traces.append(entry)
    return completion
```

#### D0.3 `ctx.wait` worker-context semantic contract (normative)

This ADR **explicitly changes** the semantics of `ctx.wait` in worker context.

- **Submit-vs-complete separation**: when called from a worker, `ctx.wait(h)` no longer guarantees "immediate completion" but instead guarantees "completion **after the next scheduler drain**". The point at which the worker returns from `wait()` = the point at which main has finished `engine.wait` for that handle. Main-context calls remain immediate-synchronous as before (status quo).
- **Resume invariant (normative)**: at the point a worker resumes from a worker-deferred `ctx.wait(h)` (when `g.parent.switch()` returns), **`h in ctx._completed` must be True**. If this invariant breaks, the worker proceeds in a stale state, so whichever of `_drain_pending` / the scheduler loop / `ctx.wait` is modified, this invariant must be preserved. T3.b directly asserts this invariant.
- **Observable change**: the pattern `h = ctx.submit(msg); ctx.wait(h); read(handle_result)` inside a worker still holds, but the semantic spec now includes the fact that a main-drain is automatically inserted between `wait()` and `read`.
- **Direct host-object reads see D0.5**: the contract for calling `tensor.numpy()` without `ctx.wait` is specified separately in D0.5.

#### D0.4 Main scheduler drain: protocol (normative)

(The internal implementation of D1's `multiprocessing.spawn`. Below is the semantic definition.)

```python
while alive:
    for g in alive:        # (1) round-based worker switch
        g.switch()
    _drain_pending(ctx)    # (2) drain in main context
```

(The actual definition of `_drain_pending` is in D0.5: an outer while-loop that drains until both queues are empty.)

**Rules**:

1. **Round-based cooperative scheduling & yield obligation (worker contract)**. `g.switch()` does not return until the worker **voluntarily yields** (cooperative greenlet semantics). Therefore:
   - If a worker runs a pure-compute loop like `while True: do_compute()` without yielding, `g.switch()` never returns and **the scheduler loop itself hard-blocks** (other workers cannot get a switch turn, no drain occurs). This is not starvation but **scheduler non-progress (deadlock equivalent)**, and this ADR classifies it as **unsupported**.
   - Workers **must** call one of `ctx.wait(h)`, `dist.all_reduce`, or a host-read barrier (D0.5) within a finite number of steps. The `forward` of a TP layer includes a launch→wait pair at the end of every layer, so this condition is naturally met. CCL kernels also yield inside `dist.all_reduce`.
   - Implementations need not **detect** this (timeouts/steps-since-yield counters, etc.). It is a user contract; the symptom on violation is "simulation hang".
   - **Future extension**: if non-collective long compute paths become common, an explicit `torch.distributed.cooperative_yield()` primitive (no-op yield) could be introduced. Out of scope for this ADR.
     Not a breaking change; can be added if needed.
   - Within a round, every alive worker receives one `switch` turn. Even if a single worker calls wait multiple times within one round, the calls are enqueued sequentially within that turn and processed in a single scheduler drain batch (FIFO).
2. **Drain order = submission order (FIFO)**. `_pending_worker_waits` is strict FIFO via list append/pop(0). Drain occurs in submission order, not completion order, and SimPy's scheduler itself guarantees a causally correct completion order, so submission-order drain is safe. Do not confuse `completion order` with `drain order`.

   **Two-queue ordering (worker waits → collectives)**: `_drain_pending` drains the worker wait queue first, then the collective queue. Rationale for this ordering:
   - **The two queues are different dependency sources**: worker waits are handles produced by a worker's own `submit + wait` pair (tensor deploy, MmuMap, etc.). The collective queue holds kernel-launch handles that `dist.all_reduce` enqueues internally, which the worker never directly waits on (see the two-queue drain model in D0.5).
   - **Independent in correctness terms**: from the worker's perspective, a collective is "already submitted, then yielded". Its completion timing only needs to precede the worker's next action. There is no ordering dependency with the worker wait queue.
   - **Both finish within a single drain barrier**: per D0.5's loop-until-empty rule, a single barrier invocation drains worker → collective → (repeat if new ones appeared) in that order. By the time the worker resumes, both sides are drained.
   - **The alternative (collective first) is also valid**: this ADR fixes worker-first only for current implementation simplicity; semantically they are equivalent. Revisit if a performance-profile difference is observed.
3. **Duplicate enqueue: correctness via idempotent drain; dedup not guaranteed**. `ctx.wait(h)` returns immediately if `h in ctx._completed`.
   `_drain_pending` uses the same guard. Even if the same handle is appended to `_pending_worker_waits` multiple times, `engine.wait` is invoked only once (idempotent).
   - **Correctness**: relies on idempotent drain → safe.
   - **Memory/performance**: this ADR **does not guarantee dedup** of `_pending_worker_waits`. If the same handle is enqueued N times, the queue retains N elements and drain performs N pops + in-set guards. Unless a single worker abnormally repeats waits on the same handle, N stays at the order of 1 to a few.
   - **Implementation freedom**: implementations may optionally dedup (e.g., hold a `set` as a side index, or check `h not in pending_set` before append). Classified as an optimization that does not change correctness.
4. **Exception propagation + sibling cleanup**. When a worker greenlet raises, `g.switch()` propagates the exception to main. The scheduler loop stops immediately and performs the following cleanup **explicitly**:

   ```python
   try:
       while True:
           alive = [g for g in gs if not g.dead]
           if not alive:
               break
           for g in alive:
               if not g.dead:
                   g.switch()
           _drain_pending(ctx)
   except Exception as outer:
       # (a) Force-terminate surviving sibling worker greenlets.
       for other in gs:
           if not other.dead:
               try:
                   other.throw(SystemExit)
               except Exception:
                   pass  # silent; already in exceptional state
       # (b) Reset backend barrier / pending state (in preparation for future epoch barrier).
       backend = getattr(ctx.distributed, "_backend", None)
       if backend is not None and hasattr(backend, "_barrier"):
           backend._barrier.reset()
       backend_pending = getattr(backend, "_pending_collective_handles", None)
       if backend_pending is not None:
           backend_pending.clear()
       ctx._pending_worker_waits.clear()
       # (c) Wrap the originating exception in SpawnException.
       raise SpawnException(errors) from outer
   ```

   Protocol:
   - **Sibling abort guarantee**: when one worker raises, `SystemExit` is thrown into all sibling greenlets, so they terminate immediately. No greenlet leaks.
   - **Explicit pending-queue clear**: both queues (worker-wait + collective-pending) are cleared. Prevents contamination on reuse.
   - **`SpawnException(errors)` wrapping**: `errors: dict[int, Exception]` contains the original exception per rank. Compatible with the failure pattern of real-PyTorch `torch.multiprocessing.spawn`.
     - **Scope restriction**: `errors` includes **only ranks that raised from their own code (root cause)**. Ranks terminated via `throw(SystemExit)` during sibling cleanup do not appear in `errors` (SystemExit is not caught by D1.2's entry wrapper `try/except Exception`; this is intentional: sibling termination is a cleanup signal, not a failure). Made explicit so readers do not expect "all failed ranks" to appear.
   - **`ctx._traces` is the partial state up to the moment of exception**. Trace completeness is not guaranteed (some launches/all_reduces may terminate without leaving an entry).
   - **Allocator / MemoryStore** remain in their pre-exception state. Reuse is a non-goal; creating a fresh `RuntimeContext` is recommended.
   - **`join=False` / retry / partial recovery** are non-goals for this ADR.

   `SpawnException` is defined in `runtime_api/multiprocessing.py`:

   ```python
   class SpawnException(RuntimeError):
       def __init__(self, errors: dict[int, Exception]):
           self.errors = errors
           first = next(iter(errors.items()), None)
           msg = (f"spawn failed on ranks {sorted(errors.keys())}"
                  + (f": rank {first[0]} raised {first[1]!r}" if first else ""))
           super().__init__(msg)
   ```

5. **Single-driver compatibility**. In main-only execution where `g.parent is None` (legacy single-driver tests), D0.2's worker-fork condition is false → the existing immediate-synchronous path is preserved. `_drain_pending` is not invoked.

#### D0.5 Host-read barrier: decision (normative)

Inside a worker, **host-observable reads** such as `tensor.numpy()`, `tensor.__getitem__`, and `tensor.data` are defined as **automatic drain barriers**. Immediately before the call:

1. If `ctx._pending_worker_waits` or `backend._pending_collective_handles` are non-empty, yield to main via `g.parent.switch()` → main runs `_drain_pending` → worker resumes after completion.
2. If both queues are empty, read immediately.

**Barrier iteration protocol (normative, re-entrance)**: `_drain_pending` drains via a while-loop **until both queues are completely empty**, not in a single pass:

```python
def _drain_pending(ctx):
    while ctx._pending_worker_waits or (
        ctx.distributed._backend
        and ctx.distributed._backend._pending_collective_handles
    ):
        while ctx._pending_worker_waits:
            h = ctx._pending_worker_waits.pop(0)
            if h not in ctx._completed:
                ctx.engine.wait(h)
                # Mark completion here so the D0.3 resume invariant
                # (`h in ctx._completed` on worker resume) and the
                # D0.4-(3) duplicate guard both hold.
                ctx._completed.add(h)
        backend = ctx.distributed._backend
        if backend is not None:
            while backend._pending_collective_handles:
                h, _sip_id, meta = backend._pending_collective_handles.pop(0)
                ctx.wait(h, _meta=meta)  # main context: safe; ctx.wait will
                                         # not push back to pending
```

**Main-context ctx.wait non-recursion invariant (normative)**: the `ctx.wait(h, _meta=meta)` call inside `_drain_pending` runs in the main greenlet context. Because D0.2's worker-fork condition (`g.parent is not None and not g.parent.dead`) is False, it enters the immediate-synchronous path → **never enqueues to `_pending_worker_waits`**. Thanks to this invariant, the drain loop terminates without recursion / queue re-growth. When implementing, it is important to maintain `g.parent is None` as the single-main-greenlet guarantee.

**Why a loop**: `ctx.wait(h, _meta=meta)` is called in main context, so per the D0.2 path it **drives the engine directly** (no additional enqueue, per the invariant above). In theory a single pass would suffice, but the protocol is fixed at **loop-until-empty**. Reasons:

1. **Future-extension safety**: a future implementation might enqueue new pending items mid-drain (e.g., tree-reduce collectives with sub-handles). The loop protocol preserves correctness in that case.
2. **Readability**: the single sentence "the barrier drains until pending is empty" closes the semantics. No dependence on the non-trivial invariant that `ctx.wait` calls do not produce new enqueues.
3. **Barrier semantics are "all dependencies needed for this read are complete"**: in the current model all pending = all dependencies, so the two are identical. The user mental model is the former.

**Termination guarantee**: described under two regimes.

- **Current implementation**: when called in main context, `ctx.wait` drives the engine directly (D0.2) → does not enqueue new pending. Each iteration strictly shrinks pending size by `pop(0)` + `engine.wait`. The iteration count is bounded by **the initial pending size itself** → finite termination.
- **Future extension (the bound that justifies the loop protocol)**: if an implementation that enqueues new pending items mid-drain (e.g., tree-reduce sub-handles) is introduced, the initial-size bound breaks. However, SimPy causality guarantees that the dependency DAG of handles is finite, so **nested depth is finite**. The loop protocol automatically accommodates this case.

Both regimes guarantee that infinite loops are impossible. The single-pass bound of the current implementation is a reference value for aggressive optimization; the protocol is fixed at loop-until-empty.

**Why implicit drain at read is correct**:

- In the original open question, the choice was between (a) implicit drain and (b) explicit barrier. (b) is clear but burdens TP layer users with the 3-step pattern `out = fc1.forward(x); ctx.drain(); result = out.numpy()` on every read. (a) is a single rule that "guarantees the read sees the reflected value". This is identical to CUDA's `cudaDeviceSynchronize`-before-host-copy pattern, which is not a hidden rule but the **contract of a named entry point**.
- This ADR adopts (a) but **closes the entry-point list explicitly**: `Tensor.numpy()`, `Tensor.data` (numpy alias), `Tensor.__getitem__`, `Tensor.__repr__` (when data is included), and any other official host-read APIs are finalized via codebase search at the time of implementing this ADR. Any newly added host-read API must follow this contract (regression-guarded by tests).
- Even when calling `numpy` directly after only `ctx.submit` without `wait`, the drain barrier still operates (because the handle is in the pending queue). The invariant is restored at read time even if the user omits an explicit wait.

**`Tensor.copy_(source)`: write barrier specification**: `copy_` is semantically "write to target", but internally it calls `source.numpy()` to fetch source data on the host then writes to each shard via `target._memory_store.write(...)`. Both directions are barrier-handled:

1. **Source-side (read barrier)**: `source.numpy()` triggers the D0.5 read barrier (when source itself is a deployed tensor with pending).
2. **Target-side (write barrier, based on global pending)**: on `copy_` entry, if `ctx._pending_worker_waits` or `backend._pending_collective_handles` are non-empty, drain via `g.parent.switch()` before writing. **Not per-tensor / per-shard dependency tracking, but based on the global pending queue**.
   - Why global: KernBench's handle representation does not retain the reverse-mapping information "this handle writes to which shard of which target". A safe conservative rule: "drain if any global pending exists". As a result, **pending of an unrelated tensor can also block copy_**; the drop-in invariant takes priority.
   - **Explicit tradeoff**: this rule can introduce unnecessary serialization between independent tensors. However, under the current single-queue execution model this cost is acceptable: guaranteeing cross-rank correctness and the "read sees latest" invariant via a simple rule takes precedence.
   - Practical impact: most pending of a single worker within a layer step is its own work. Extra context switches from over-barrier often coincide with the end-of-round scheduler drain point, so no major issue.
   - Future refinement: per-tensor pending tracking could narrow this rule, but it is out of scope for this ADR.

**Non-barrier**:

- `tensor.shape`, `tensor.dtype`, `tensor.name`, and other **metadata-only** access does not drain. No data dependency.
- `tensor.pa`, `tensor.va`, and other raw address accessors also do not drain (address only, not content).

**Official barrier entry-points (closed set)**:

| API | Kind | Rationale |
|---|---|---|
| `Tensor.numpy()` | read | host-observable copy |
| `Tensor.data` | read | `numpy()` alias |
| `Tensor.__getitem__` | read | shard-aligned read |
| `Tensor.__repr__` (when data is included) | read | debugging/log |
| `Tensor.copy_(source)` | read + write | source read + target write |

This contract is verified directly in T5/T6.

#### D0.6 Why the worker function API is unchanged (informative)

- The inside of `torch.zeros(...)` is a `self.submit(msg)` + `self.wait(h)` pair. `wait` auto-defers to main per D0.2/D0.3: it appears synchronous from the outside but yields once.
- `tensor.numpy()` follows D0.5's host-read barrier → drain→read when pending exists, immediate read otherwise.
- `dist.all_reduce` continues to use the existing `_defer_wait=True` + `_pending_collective_handles` path. D0.4's drain processes both queues together.

#### D0.7 Invariants

- **The kernel greenlet's `_parent` is always main**: because env.run never runs in worker context. (Core assertion of T3.)
- **Cross-rank synchronization point**: drain occurs only after every worker has yielded → kernels of all ranks advance together within one round (a prerequisite for cross-rank IPCQ exchange).
- **Single-driver compatibility**: D0.4-(5).

### D1. `torch.multiprocessing.spawn(fn, args, nprocs)`

Real-PyTorch API parity + a single implementation site for D0's scheduler loop.

#### D1.0 API parity only, not execution parity (normative)

The name `torch.multiprocessing.spawn` is restricted to **API signature parity**. The actual execution model is a **cooperative greenlet scheduler** (single Python process, single OS thread, round-robin drive per D0.4).

The following are **properties this ADR does NOT provide**. Among the guarantees of real-PyTorch `torch.multiprocessing.spawn`, they are explicitly **non-goals**:

- Process isolation (independent OS process per rank).
- Independent address space (each rank with its own Python heap).
- Failure isolation (a hard crash in one rank not affecting others).
- OS-level scheduler fairness (preemptive time slicing between ranks).
- Inter-process primitives such as `mp.Queue`, `mp.Lock`.

Actual properties of this implementation:

- All ranks are greenlets inside the same Python process. Shared global state is visible as-is (intentional simulation convenience).
- Single-threaded under the GIL → not parallel execution. Only "logical concurrency" via SimPy event ordering is reproduced.
- Unhandled exception in any one worker → entire simulation aborts (D0.4-(4)).

**Caller's obligation**: when porting real-PyTorch multi-process samples to KernBench, logic that relies on process isolation (e.g., `os.getpid`, independent temp files, signal handling) must be removed. The namespace name is preserved for code portability; the semantics differ.

#### D1.1 Public surface

```python
# runtime_api/multiprocessing.py (new)
class _MultiprocessingNamespace:
    def __init__(self, ctx):
        self._ctx = ctx

    def spawn(self, fn, args: tuple, nprocs: int, join: bool = True) -> None:
        """Spawn `nprocs` worker greenlets, each calling fn(rank, *args).

        Mirrors torch.multiprocessing.spawn signature (minus `daemon`).
        Drives the D0 scheduler loop until all workers finish.
        """
        ...
```

#### D1.2 Implementation

```python
def spawn(self, fn, args, nprocs, join=True):
    from greenlet import greenlet
    ctx = self._ctx
    dist = ctx.distributed
    gs: list[greenlet] = []
    errors: dict[int, Exception] = {}

    for rank in range(nprocs):
        def _entry(r=rank):
            try:
                fn(r, *args)
            except Exception as e:
                errors[r] = e
                raise
        g = greenlet(_entry)
        dist._bind_rank(g, rank)
        gs.append(g)

    try:
        while True:
            alive = [g for g in gs if not g.dead]
            if not alive:
                break
            for g in alive:
                if not g.dead:
                    g.switch()
            _drain_pending(ctx)  # D0.5
    except Exception as outer:
        # Sibling cleanup per D0.4-(4)
        for other in gs:
            if not other.dead:
                try:
                    other.throw(SystemExit)
                except Exception:
                    pass
        backend = getattr(dist, "_backend", None)
        if backend is not None:
            if hasattr(backend, "_barrier"):
                backend._barrier.reset()
            if getattr(backend, "_pending_collective_handles", None) is not None:
                backend._pending_collective_handles.clear()
        ctx._pending_worker_waits.clear()
        raise SpawnException(errors) from outer
    # `join=True` semantics: we already wait for all workers.
```

#### D1.3 `torch` namespace attach

In `runtime_api/context.py` `__post_init__`:

```python
self.multiprocessing = _MultiprocessingNamespace(self)
```

→ in bench code: `torch.multiprocessing.spawn(worker, args=(ws,), nprocs=ws)`.

#### D1.4 Migration of existing benches

The hand-rolled loop in `benches/ccl_allreduce.py` collapses into a single `torch.multiprocessing.spawn` line. Existing matrix regressions are preserved. The currently xfail `ring_default_ws` is expected to flip to PASS thanks to D0 (workers no longer orphan the kernel greenlet).

### D2. New package `kernbench.tp`

```
src/kernbench/tp/
  __init__.py          - public API re-exports
  parallel_state.py    - TP group management (currently a single global group)
  layers.py            - ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
  primitives.py        - copy/reduce/scatter/gather_to/from_tp_region
  kernels.py           - gemm kernel launched by TP layers (reusable)
  mappings.py          - forward identity/all_reduce, backward stub
```

### D3. `parallel_state`: TP group

```python
# parallel_state.py
_TP_WORLD_SIZE = None

def initialize_model_parallel(tensor_model_parallel_size: int) -> None:
    """Initialize TP group. Must be called after dist.init_process_group."""
    global _TP_WORLD_SIZE
    from kernbench.runtime_api.distributed import get_dist  # or torch.distributed
    dist = get_dist()
    total = dist.get_world_size()
    if tensor_model_parallel_size != total:
        raise NotImplementedError(
            "Only TP == world_size supported in initial scope"
        )
    _TP_WORLD_SIZE = tensor_model_parallel_size

def get_tensor_model_parallel_world_size() -> int:
    return _TP_WORLD_SIZE

def get_tensor_model_parallel_rank() -> int:
    from kernbench.runtime_api.distributed import get_dist
    return get_dist().get_rank()  # ADR-0024 greenlet-local rank
```

Initial scope: TP size = world_size = topology SIP count. Pure TP model.

### D4-pre. TP shard ownership vs DPPolicy: role separation (normative)

In the weight/output representation of TP layers, two concepts are clearly separated:

| Concept | Decided by | Scope |
|---|---|---|
| **TP shard ownership** (which rank owns which slice of the weight) | greenlet-local rank + `torch.ahbm.set_device(rank)` (ADR-0024 D2/D3) | **cross-rank, cross-SIP** |
| **Intra-rank placement** (how the owned slice is distributed across cube × PE inside the rank) | `DPPolicy(cube=..., pe=...)` (ADR-0026) | **inside one rank (within SIP boundary)** |

Thus when `ColumnParallelLinear` creates a weight of shape `(in_features, out_features // ws)` and assigns `DPPolicy(cube="column_wise", pe="column_wise")`:

- The slice owned by **rank r** = column-axis [r * k_local, (r+1) * k_local) of the weight. **set_device(r)** determines this (that rank resides on SIP r).
- **Inside that slice**, the cube × PE column-wise distribution is determined by **DPPolicy**.

The two axes are **independent**. If two ranks build their own slice with the same DPPolicy, the slices themselves live on different SIPs but the intra-SIP placement pattern is the same. Conversely, changing DPPolicy to `cube="replicate", pe="replicate"` preserves TP shard ownership and only changes intra-rank placement.

**Mistakes that blur this boundary** (forbidden by this ADR):

- The "SIP axis" reappearing in DPPolicy (removed in ADR-0026).
- TP layers expressing cross-rank sharding via `DPPolicy` alone without `set_device` → indistinguishable from a vertical split within a single rank.

The TP layers of this ADR always treat weight/output from the perspective of "rank = SIP = owns one slice + DPPolicy intra-SIP distribution" only.

### D4. `ColumnParallelLinear`

**Important**: no new host-side `torch.matmul` abstraction is introduced.
The layer's forward calls the existing gemm kernel via `torch.launch("gemm", gemm_kernel, ...)`, the pattern already used by KernBench benches ([benches/gemm_single_pe.py](benches/gemm_single_pe.py), [benches/gpt3_qkv.py](benches/gpt3_qkv.py)).

```python
# layers.py
from kernbench.policy.placement.dp import DPPolicy
from kernbench.tp.kernels import _gemm_kernel
from kernbench.tp.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)

class ColumnParallelLinear:
    """Shards the K(out_features) axis of the weight across TP ranks.

    forward(x):
        x: (M, N) - full-replicated across ranks
        W_k: (N, K / world_size) - rank-local slice (placed on SIP r via set_device)
        y_k = x @ W_k → (M, K / world_size) - rank-local output

    Output is column-sharded. The input form expected by RowParallelLinear.
    """

    def __init__(self, in_features: int, out_features: int,
                 bias: bool = False, dtype: str = "f16", torch=None):
        ws = get_tensor_model_parallel_world_size()
        assert out_features % ws == 0
        self.in_features = in_features
        self.k_local = out_features // ws
        self._torch = torch
        # Each rank owns its own slice, placed on SIP r by set_device(rank).
        self.weight = torch.zeros(
            (in_features, self.k_local),
            dtype=dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="col_parallel_w",
        )
        self.bias = None
        if bias:
            self.bias = torch.zeros(
                (self.k_local,),
                dtype=dtype,
                dp=DPPolicy(cube="replicate", pe="replicate"),
                name="col_parallel_b",
            )

    def forward(self, x):
        # x is full-replicated (caller-guaranteed). Plain local gemm.
        M = x.shape[0]
        out = self._torch.empty(
            (M, self.k_local),
            dtype=x.dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="col_parallel_out",
        )
        self._torch.launch(
            "col_parallel_gemm", _gemm_kernel,
            x, self.weight, out,
            M, self.in_features, self.k_local,
        )
        # bias add as a separate kernel or as fused bias of a composite gemm.
        # Initial scope verifies bias=False sufficiently.
        return out
```

**Yield-safety contract (normative)**: `ColumnParallelLinear.forward` includes one `torch.launch` call containing a kernel launch → internal `ctx.wait` pair. This automatically satisfies the "worker yields within a finite number of steps" condition of D0.4-(1), so TP layer users do not need to insert yield patterns manually.

### D5. `RowParallelLinear`

```python
class RowParallelLinear:
    """Shards the N(in_features) axis of the weight across TP ranks.

    forward(x):
        x: (M, N / world_size) - rank-local slice (output of ColumnParallel)
        W_k: (N / world_size, K) - rank-local slice
        y_k = x @ W_k → (M, K) - partial sum on each rank
        y = all_reduce(y_k, op="sum") → (M, K) on every rank
    """

    def __init__(self, in_features: int, out_features: int,
                 bias: bool = False, dtype: str = "f16", torch=None):
        ws = get_tensor_model_parallel_world_size()
        assert in_features % ws == 0
        self.n_local = in_features // ws
        self.out_features = out_features
        self._torch = torch
        self.weight = torch.zeros(
            (self.n_local, out_features),
            dtype=dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="row_parallel_w",
        )
        # bias only on rank 0 (Megatron convention). Omitted in initial scope.
        self.bias = None

    def forward(self, x):
        M = x.shape[0]
        y_partial = self._torch.empty(
            (M, self.out_features),
            dtype=x.dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="row_parallel_partial",
        )
        self._torch.launch(
            "row_parallel_gemm", _gemm_kernel,
            x, self.weight, y_partial,
            M, self.n_local, self.out_features,
        )
        # Cross-rank reduce. ADR-0024's dist.all_reduce works correctly
        # under D0 + mp.spawn (kernel parent = main is preserved).
        self._torch.distributed.all_reduce(y_partial, op="sum")
        return y_partial
```

**Yield-safety contract (normative)**: `RowParallelLinear.forward` includes launch → internal wait followed by `all_reduce` (defer + worker-yield pattern), so **at least 2 yields per forward** are guaranteed. The scheduler-progress condition of D0.4-(1) is automatically satisfied.
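
The sharding arithmetic that D4 and D5 rely on can be checked on the host without any KernBench API. The sketch below is illustrative only: it uses plain nested lists, and the helper names (`matmul`, `col_slice`, `row_slice`, `elementwise_sum`) are hypothetical stand-ins, with `elementwise_sum` playing the role of `all_reduce(op="sum")`. It shows that a column-sharded first layer followed by a row-sharded second layer and a sum over ranks reproduces the unsharded product.

```python
# Host-side check of the Column -> Row sharding arithmetic.
# Assumption: no KernBench APIs are used; all helpers here are illustrative.

def matmul(a, b):
    """Plain nested-list matrix product: (m x n) @ (n x k) -> (m x k)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

def col_slice(w, rank, ws):
    """Columns owned by `rank` when the column axis is split into `ws` equal parts."""
    k_local = len(w[0]) // ws
    return [row[rank * k_local:(rank + 1) * k_local] for row in w]

def row_slice(w, rank, ws):
    """Rows owned by `rank` when the row axis is split into `ws` equal parts."""
    n_local = len(w) // ws
    return w[rank * n_local:(rank + 1) * n_local]

def elementwise_sum(mats):
    """Stand-in for all_reduce(op="sum") over the per-rank partial results."""
    rows, cols = len(mats[0]), len(mats[0][0])
    return [[sum(m[i][j] for m in mats) for j in range(cols)] for i in range(rows)]

ws = 2
x = [[1, 2], [3, 4]]                      # (M=2, D_in=2), replicated on every rank
w1 = [[1, 0, 2, 1], [0, 1, 1, 3]]         # (D_in=2, D_hidden=4)
w2 = [[1, 2], [0, 1], [2, 0], [1, 1]]     # (D_hidden=4, D_out=2)

# Unsharded reference: x @ w1 @ w2.
reference = matmul(matmul(x, w1), w2)

partials = []
for rank in range(ws):
    # Column-parallel step: rank-local (M, D_hidden / ws) output, no communication.
    h_local = matmul(x, col_slice(w1, rank, ws))
    # Row-parallel step: rank-local partial sum of shape (M, D_out).
    partials.append(matmul(h_local, row_slice(w2, rank, ws)))

tp_result = elementwise_sum(partials)
assert tp_result == reference
```

Integer inputs are used so the comparison is exact; with `f16` data a tolerance-based comparison would be needed instead.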
All TP layer forwards in this ADR maintain the invariant "yield-safe by containing at least one wait or collective"; any future TP primitives (e.g., VocabParallelEmbedding) must keep the same contract.

### D6. Primitive functions

```python
# primitives.py
def copy_to_tp_region(x):
    """Forward: identity. Backward: all-reduce. (Implemented when training is added)."""
    return x

def reduce_from_tp_region(x, torch):
    """Forward: all-reduce. Backward: identity."""
    torch.distributed.all_reduce(x, op="sum")
    return x

def scatter_to_tp_region(x):
    raise NotImplementedError(
        "Phase 2: replaced by users creating already-sharded tensors"
    )

def gather_from_tp_region(x):
    raise NotImplementedError(
        "Phase 2: requires all-gather kernel as a prerequisite (future)"
    )
```

### D7. Sample bench: 2-layer MLP with TP

```python
# benches/tp_mlp.py (new)
from kernbench.policy.placement.dp import DPPolicy
import kernbench.tp as tp
import numpy as np

def worker(rank: int, world_size: int, torch):
    torch.ahbm.set_device(rank)
    tp.initialize_model_parallel(world_size)

    B, D_in, D_hidden, D_out = 1, 512, 2048, 512
    fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=torch)
    fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=torch)

    x = torch.zeros(
        (B, D_in),
        dtype="f16",
        dp=DPPolicy(cube="replicate", pe="replicate"),
        name="x",
    )
    # init x with some pattern (e.g., constant)
    x.copy_(torch.from_numpy(np.full((B, D_in), 0.1, dtype=np.float16)))

    h = fc1.forward(x)  # column-sharded (B, D_hidden / ws)
    y = fc2.forward(h)  # all-reduced (B, D_out) on every rank

    # Only rank 0 prints / verifies the result
    if rank == 0:
        result = y.numpy()
        # With zero-init weights, all values are 0; within scope "completion itself" is the check
        print(f" tp_mlp: shape={result.shape}, mean={float(result.mean()):.4f}")

def run(torch):
    torch.distributed.init_process_group(backend="ahbm")
    ws = torch.distributed.get_world_size()
    torch.multiprocessing.spawn(worker, args=(ws,), nprocs=ws)
```

### D8. Non-functional: training not supported

This ADR is **inference/forward only**. Backward / gradient / optimizer is future work. Natural because KernBench is not a training system.

### D9. Initial-scope constraints

- TP size = world_size (no mixed DP+TP).
- `scatter_to_tp_region`, `gather_from_tp_region` are unimplemented.
- **Default weight value is zero**. Proper init schemes (Xavier, Kaiming, etc.) are future. Tests inject deterministic non-zero patterns via `tensor.copy_` to verify numerical correctness (T2/T6). I.e., operate as "production default = zero, verification = deterministic non-zero".
- Bias is omitted in the initial scope (Megatron's rank-0-only bias policy is future).
- Pipeline parallelism is out of scope.
- VocabParallelEmbedding requires a prerequisite all-gather → stub only.

### D10. Regression: `ring_default_ws` xfail removal (mandatory acceptance)

Thanks to D0 (worker-wait generalization) + D0.5 (host-read barrier), every worker-driven `ctx.wait` and host-read is routed through the main-drain path → the cause of the kernel-greenlet orphan in ADR-0024 Phase B disappears. Flipping the existing matrix test's `ring_default_ws` strict-xfail case to **PASS** after this ADR's implementation is included as a **mandatory regression criterion**.

Observable acceptance criteria are specified in **T7** (no deadlock, no GreenletExit, numerical tolerance, etc.).

---

## Dependencies

- **ADR-0024** (launcher): rank = SIP, greenlet-local rank, `torch.ahbm.set_device(rank)`.
- **ADR-0026** (DPPolicy intra-device): per-rank slice representation of weight tensors.
- **ADR-0023 / ADR-0025** (IPCQ): foundation of `dist.all_reduce` implementation.

---

## Non-goals

- **Backward pass / training**: inference only. Training simulation is a separate ADR.
- **Mixed parallelism (DP + TP + PP)**: pure TP only at the start.
- **Weight init schemes**: simple zero / debug pattern.
- **Fused ops**: Megatron's fused matmul+bias+gelu is a kernel-level concern.
- **DTensor integration**: deferred to ADR-0028.
- **Host-side `torch.matmul` abstraction**: TP layers call the existing gemm kernel via `torch.launch(gemm_kernel, ...)`. No new matmul host-op is introduced.

---

## Open questions

- **Location of `initialize_model_parallel`**: `kernbench.tp.initialize_model_parallel` (current decision) vs real-PyTorch's `torch.distributed.init_device_mesh`. Kept in the TP-only module.
- **Weight init**: the ADR uses zero. A debug pattern (e.g., identity) may be needed for valid verification; add it at Phase 1 test time if needed.
- **Bias placement policy**: Megatron places RowParallelLinear bias only on rank 0. Avoided in the initial scope via `bias=False`.
- **GEMM kernel location**: `kernbench.tp.kernels._gemm_kernel` vs importing from the existing `benches/gemm_single_pe.py`. TP must not depend on benches, so the kernel is duplicated inside `tp`. Migration to a shared `kernbench.kernels` package is possible later.

**Resolved (previously open in earlier revisions)**:

- ~~Drain timing on `tensor.numpy()` call~~ → **decided in D0.5**: the official host-read entry points (`numpy`, `data`, `__getitem__`, data-containing `__repr__`) are automatic drain barriers. Metadata-only accessors are not barriers.

---

## Consequences

### Positive

- **Easy porting of Megatron code**: the API matches real training code.
- **TP benchmarking enabled**: research on scaling, communication-compute overlap, and other HW characteristics.
- **`ring_default_ws` xfail removal**: as a byproduct of D0, the ADR-0024 Phase B blocker is resolved.
- **Scheduler-loop unification**: introducing D1 (`mp.spawn`) removes the hand-rolled loop. Subsequent collective/TP benches reuse the same pattern.
- **DPPolicy semantics clarified** (synergy with ADR-0026): TP layers serve as a best-practice example of using intra-device DPPolicy only.

### Negative

- Maintenance cost of a new module (`kernbench.tp`).
- Initial scope is limited (pure TP only, forward only).
- D0 generalization changes the semantics of `ctx.wait`; compatibility with single-driver tests must be explicitly verified (T7).

### Neutral

- A pure upper layer added on top of ADR-0024/0026. No impact on the hardware-simulation stack (apart from D0).