kernbench2/docs/adr/ADR-0027-par-megatron-tp.md
ywkang a796c1d2f7 ADR: bilingual structure — EN canonical in adr/, KO mirror in adr-ko/
Establish English as the canonical ADR language with Korean translations
held in a parallel docs/adr-ko/ tree as derived artifacts (1:1 mirror).
Promotion from adr-proposed/ to adr/ now writes English to adr/ and the
Korean to adr-ko/; bidirectional sync rule documented in CLAUDE.md.

- Migrate 30 ADRs in docs/adr/: 28 Korean-only translated to English,
  2 bilingual pairs (ADR-0020, ADR-0023) consolidated (.en.md suffix
  dropped). ADR-0023 EN regenerated against KO source which had newer
  HW Realization Notes (D16-D23) section.
- docs/adr-history/ left frozen by design (transitional state).
- CLAUDE.md (Part 2): update ADR Lifecycle for 4-folder layout, mark
  docs/adr-ko/ as a Derived Artifact, add ADR Translation Discipline
  section covering bidirectional sync, conflict resolution (EN wins),
  and proposed-language freedom.
- tools/verify_adr_lang_pairs.py: new verification tool checking pair
  completeness, filename mirroring, ADR-ID match, Status byte-equality.
  Pre-commit hook intentionally not added; run on demand or in CI.
- tests/test_verify_adr_lang_pairs.py: 11 cases including CRLF/LF
  normalization, em-dash title separator, underscore-slug edge case.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 01:38:44 -07:00


ADR-0027: Megatron-style Tensor Parallelism API

Status

Accepted

Context

Goal

Support inter-SIP tensor parallelism (TP) via a Megatron-LM style explicit parallel layer API. Declarative abstractions like DTensor are future work in a separate ADR (0028).

Why Megatron-style was chosen:

  • TP arises at specific layer boundaries of a model. Explicit primitives are natural to the mental model.
  • The de-facto industry standard established by NVIDIA Megatron / DeepSpeed.
  • DTensor is declarative, so its design space is larger → phased approach.

TP primitive spec (Megatron-LM reference)

  • ColumnParallelLinear: shards the weight's column (out_features) axis across TP ranks. Input is full-replicated, output is column-sharded. When a RowParallelLinear follows, no forward all-reduce is required.
  • RowParallelLinear: shards the weight's row (in_features) axis across TP ranks. Input is already column-sharded (the output of ColumnParallel). Requires an all-reduce at the end of forward.
  • VocabParallelEmbedding: shards the embedding along the vocab axis. all-reduce at the end of forward. (A stub in the initial scope; full implementation requires an all-gather kernel as a prerequisite.)
  • copy_to_tp_region, reduce_from_tp_region, scatter_to_tp_region, gather_from_tp_region — basic primitives.
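The Column → Row pairing described above can be checked numerically. The following is a minimal pure-Python sketch, not KernBench code: `matmul`, `split_cols`, `split_rows`, and `all_reduce_sum` are hypothetical helpers, and the "ranks" are simply list entries. It illustrates that the column-sharded intermediate needs no communication and that a single sum all-reduce after the row-parallel matmul reproduces the unsharded result.

```python
def matmul(a, b):
    """Plain-Python matrix product of two list-of-lists matrices."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def split_cols(m, ws):
    """Shard a matrix along its column axis into ws equal slices."""
    k = len(m[0]) // ws
    return [[row[r * k:(r + 1) * k] for row in m] for r in range(ws)]


def split_rows(m, ws):
    """Shard a matrix along its row axis into ws equal slices."""
    n = len(m) // ws
    return [m[r * n:(r + 1) * n] for r in range(ws)]


def all_reduce_sum(parts):
    """Element-wise sum of per-rank partial results (stand-in for all_reduce)."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*parts)]


WS = 2
x = [[1, 2], [3, 4]]                      # (M=2, N=2), replicated on every rank
w1 = [[1, 0, 2, 1], [0, 1, 1, 3]]         # (N=2, H=4), column-parallel weight
w2 = [[1, 2], [0, 1], [3, 0], [1, 1]]     # (H=4, K=2), row-parallel weight

w1_shards = split_cols(w1, WS)            # each rank: (N, H / WS)
w2_shards = split_rows(w2, WS)            # each rank: (H / WS, K)

# Column-parallel forward: local matmul only, output stays column-sharded.
hidden = [matmul(x, w1_r) for w1_r in w1_shards]
# Row-parallel forward: local matmul yields a partial sum on each rank.
partials = [matmul(h_r, w2_r) for h_r, w2_r in zip(hidden, w2_shards)]
# One all-reduce at the end recovers the full result on every rank.
y_tp = all_reduce_sum(partials)
y_ref = matmul(matmul(x, w1), w2)
```

Each entry of `partials` on its own differs from `y_ref`; only the reduced sum matches, which is why RowParallelLinear ends its forward with an all-reduce.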

Problems to solve

  1. Worker-wait generalization (D0): extend the defer/yield/drain pattern of dist.all_reduce to every ctx.wait path. The biggest architectural decision of this ADR.

  2. Launcher API normalization (D1): current benches use a hand-rolled greenlet loop. Absorb it into torch.multiprocessing.spawn(fn, args, nprocs) to preserve the real-PyTorch API surface + concentrate D0's scheduler drain in a single implementation site.

  3. Per-rank weight shard representation: each worker owns its own slice of the weight tensor. Naturally expressed via ADR-0024's set_device(rank) + ADR-0026's intra-device DPPolicy.

  4. Forward-only scope: KernBench currently has no backward (simulation purposes). This ADR prioritizes forward only. Training simulation is a separate ADR.

  5. Collective call site: RowParallelLinear calls all_reduce at the end of forward. Naturally works with ADR-0024's multi-greenlet structure + D0 generalization.

  6. TP group concept: Megatron crosses DP × TP × PP groups. The initial scope simplifies to TP group = all SIPs. Mixed DP+TP is future work.


Decision

D0. Worker-wait generalization — ctx.wait defers to main when in worker context

Restating the problem. kernel_runner.run captures the greenlet.getcurrent() at spawn time as the kernel greenlet's _parent (kernel_runner.py:94). If env.run runs in the main context, parent=main is safe. If env.run runs in a worker context, parent=worker, and the moment the worker yields/finishes the kernel greenlet becomes an orphan → GreenletExit → failure of ADR-0024 Phase B's ring_default_ws.

Resolution. When a worker greenlet calls ctx.wait(h), instead of driving env.run directly, yield to the main scheduler. main drives env.run and, once the handle completes, control returns to the worker.

D0.1 RuntimeContext extension

# context.py
@dataclass
class RuntimeContext:
    ...
    _pending_worker_waits: list[RequestHandle] = field(default_factory=list, init=False)

D0.2 ctx.wait worker fork

def wait(self, handle, *, _meta=None):
    # Fast-path: already completed — skip enqueue + switch (consistent with
    # D0.4-(3) idempotency). Avoids needless worker→main→worker round-trip
    # and prevents redundant _pending_worker_waits growth.
    if handle in self._completed:
        completion, _trace = self.engine.get_completion(handle)
        return completion

    from greenlet import getcurrent
    g = getcurrent()
    if g.parent is not None and not g.parent.dead:
        # Worker greenlet: defer to main. Push handle, yield to parent.
        # Parent (scheduler loop) drains env.run, then switches back.
        self._pending_worker_waits.append(handle)
        g.parent.switch()
        # On resume: handle must have completed (main drained the list).
        # Fall through to the status-quo completion/trace assembly.

    # Main context (or single-driver): drive engine directly.
    wait_fn = getattr(self.engine, "wait", None)
    if wait_fn is not None:
        wait_fn(handle)
    completion, trace = self.engine.get_completion(handle)
    self._completed.add(handle)
    if _meta is not None and trace is not None:
        entry = dict(trace) if isinstance(trace, dict) else {"raw": trace}
        entry.update(_meta)
        self._traces.append(entry)
    return completion

D0.3 ctx.wait worker-context semantic contract (normative)

This ADR explicitly changes the semantics of ctx.wait in worker context.

  • Submit-vs-complete separation: when called from a worker, ctx.wait(h) no longer guarantees "immediate completion" but instead guarantees "completion after the next scheduler drain". The point at which the worker returns from wait() = the point at which main has finished engine.wait for that handle. Main-context calls remain immediate-synchronous as before (status quo).
  • Resume invariant (normative): at the point a worker resumes from a worker-deferred ctx.wait(h) (when g.parent.switch() returns), h in ctx._completed must be True. If this invariant breaks, the worker proceeds in a stale state, so whichever of _drain_pending / the scheduler loop / ctx.wait is modified, this invariant must be preserved. T3.b directly asserts this invariant.
  • Observable change: the pattern h = ctx.submit(msg); ctx.wait(h); read(handle_result) inside a worker still holds — but the semantic spec now includes the fact that a main-drain is automatically inserted between wait() and read.
  • Direct host-object reads see D0.5: the contract for calling tensor.numpy() without ctx.wait is specified separately in D0.5.
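The worker-context contract can be illustrated without greenlets. The sketch below uses plain Python generators, where `yield` plays the role of `g.parent.switch()`; this substitution is an assumption made to keep the example dependency-free. `FakeEngine`, `Ctx`, `wait_from_worker`, `drain`, and `worker` are hypothetical names, not KernBench APIs. It shows that only the main loop drives the engine and that the resume invariant holds when the worker continues.

```python
class FakeEngine:
    """Hypothetical engine: a handle is complete once wait() was called on it."""
    def __init__(self):
        self.driver_log = []                  # records which context drove wait()

    def wait(self, handle, driver):
        self.driver_log.append((handle, driver))


class Ctx:
    def __init__(self):
        self.engine = FakeEngine()
        self._completed = set()
        self._pending_worker_waits = []

    def wait_from_worker(self, handle):
        """Generator form of the worker branch of ctx.wait."""
        if handle in self._completed:         # fast path: no enqueue, no yield
            return handle
        self._pending_worker_waits.append(handle)
        yield                                 # stands in for g.parent.switch()
        # Resume invariant: main must have completed this handle by now.
        assert handle in self._completed
        return handle


def drain(ctx):
    """Main-context drain in submission (FIFO) order."""
    while ctx._pending_worker_waits:
        h = ctx._pending_worker_waits.pop(0)
        if h not in ctx._completed:
            ctx.engine.wait(h, driver="main")
            ctx._completed.add(h)


def worker(ctx, rank, out):
    for step in range(2):
        h = f"r{rank}-h{step}"
        done = yield from ctx.wait_from_worker(h)
        out.append(done)


ctx = Ctx()
results = []
alive = [worker(ctx, r, results) for r in range(2)]
rounds = 0
while alive:
    still_alive = []
    for g in alive:                           # one switch turn per worker
        try:
            next(g)
            still_alive.append(g)
        except StopIteration:
            pass
    drain(ctx)                                # drain only after every worker yielded
    alive = still_alive
    rounds += 1
```

From the worker's point of view, `wait_from_worker` still reads as a blocking call; the main drain is inserted between the enqueue and the resume.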

D0.4 Main scheduler drain — protocol (normative)

(The internal implementation of D1's multiprocessing.spawn. Below is the semantic definition.)

while alive:
    for g in alive:              # (1) round-based worker switch
        g.switch()
    _drain_pending(ctx)           # (2) drain in main context

(The actual definition of _drain_pending is in D0.5 — an outer while-loop that drains until both queues are empty.)

Rules:

  1. Round-based cooperative scheduling & yield obligation (worker contract). g.switch() does not return until the worker voluntarily yields (cooperative greenlet semantics). Therefore:

    • If a worker runs a pure-compute loop like while True: do_compute() without yielding, g.switch() never returns and the scheduler loop itself hard-blocks (other workers cannot get a switch turn, no drain occurs). This is not starvation but scheduler non-progress (deadlock equivalent), and this ADR classifies it as unsupported.
    • Workers must call one of ctx.wait(h), dist.all_reduce, or a host-read barrier (D0.5) within a finite number of steps. The forward of a TP layer includes a launch→wait pair at the end of every layer, so this condition is naturally met. CCL kernels also yield inside dist.all_reduce.
    • Implementations need not detect this (timeouts/steps-since-yield counters, etc.). It is a user contract; the symptom on violation is "simulation hang".
    • Future extension: if non-collective long compute paths become common, an explicit torch.distributed.cooperative_yield() primitive (no-op yield) could be introduced. Out of scope for this ADR. Not a breaking change — can be added if needed.
    • Within a round, every alive worker receives one switch turn. Even if a single worker calls wait multiple times within one round, the calls are enqueued sequentially within that turn and processed in a single scheduler drain batch (FIFO).
  2. Drain order = submission order (FIFO). _pending_worker_waits is strict FIFO via list append/pop(0). Drain occurs in submission order, not completion order, and SimPy's scheduler itself guarantees a causally correct completion order, so submission-order drain is safe. Do not confuse completion order with drain order.

    Two-queue ordering (worker waits → collectives): _drain_pending drains the worker wait queue first, then the collective queue. Rationale for this ordering:

    • The two queues are different dependency sources: worker waits are handles produced by a worker's own submit + wait pair (tensor deploy, MmuMap, etc.). The collective queue holds kernel-launch handles that dist.all_reduce enqueues internally, which the worker never directly waits on (see the two-queue drain model in D0.5).
    • Independent in correctness terms: from the worker's perspective, a collective is "already submitted, then yielded". Its completion timing only needs to precede the worker's next action. There is no ordering dependency with the worker wait queue.
    • Both finish within a single drain barrier: per D0.5's loop-until-empty rule, a single barrier invocation drains worker → collective → (repeat if new ones appeared) in that order. By the time the worker resumes, both sides are drained.
    • The alternative (collective first) is also valid: this ADR fixes worker-first only for current implementation simplicity; semantically they are equivalent. Revisit if a performance-profile difference is observed.
  3. Duplicate enqueue — correctness via idempotent drain; dedup not guaranteed. ctx.wait(h) returns immediately if h in ctx._completed. _drain_pending uses the same guard. Even if the same handle is appended to _pending_worker_waits multiple times, engine.wait is invoked only once (idempotent).

    • Correctness: relies on idempotent drain → safe.
    • Memory/performance: this ADR does not guarantee dedup of _pending_worker_waits. If the same handle is enqueued N times, the queue retains N elements and drain performs N pops + in-set guards. Unless a single worker abnormally repeats waits on the same handle, N stays at the order of 1 to a few.
    • Implementation freedom: implementations may optionally dedup (e.g., hold a set as a side index, or check h not in pending_set before append). Classified as an optimization that does not change correctness.
  4. Exception propagation + sibling cleanup. When a worker greenlet raises, g.switch() propagates the exception to main. The scheduler loop stops immediately and performs the following cleanup explicitly:

    try:
        while True:
            alive = [g for g in gs if not g.dead]
            if not alive:
                break
            for g in alive:
                if not g.dead:
                    g.switch()
            _drain_pending(ctx)
    except Exception as outer:
        # (a) Force-terminate surviving sibling worker greenlets.
        for other in gs:
            if not other.dead:
                try:
                    other.throw(SystemExit)
                except (Exception, SystemExit):
                    # The sibling does not catch SystemExit, so it dies and the
                    # exception re-raises here in main; swallow it as well.
                    pass
        # (b) Reset backend barrier / pending state (in preparation for future epoch barrier).
        backend = getattr(ctx.distributed, "_backend", None)
        if backend is not None and hasattr(backend, "_barrier"):
            backend._barrier.reset()
        backend_pending = getattr(backend, "_pending_collective_handles", None)
        if backend_pending is not None:
            backend_pending.clear()
        ctx._pending_worker_waits.clear()
        # (c) Wrap the originating exception in SpawnException.
        raise SpawnException(errors) from outer
    

    Protocol:

    • Sibling abort guarantee: when one worker raises, SystemExit is thrown into all sibling greenlets — greenlets terminate immediately. No greenlet leaks.
    • Explicit pending-queue clear: both queues (worker-wait + collective-pending) are cleared. Prevents contamination on reuse.
    • SpawnException(errors) wrapping: errors: dict[int, Exception] contains the original exception per rank. Compatible with the failure pattern of real-PyTorch torch.multiprocessing.spawn.
      • Scope restriction: errors includes only ranks that raised from their own code (root cause). Ranks terminated via throw(SystemExit) during sibling cleanup do not appear in errors (SystemExit is not caught by D1.2's entry wrapper try/except Exception — intentional design: sibling termination is a cleanup signal, not a failure). Made explicit so readers do not expect "all failed ranks" to appear.
    • ctx._traces is the partial state up to the moment of exception. Trace completeness is not guaranteed (some launches/all_reduces may terminate without leaving an entry).
    • Allocator / MemoryStore remain in their pre-exception state. Reuse is a non-goal; creating a fresh RuntimeContext is recommended.
    • join=False / retry / partial recovery are non-goals for this ADR.

    SpawnException is defined in runtime_api/multiprocessing.py:

    class SpawnException(RuntimeError):
        def __init__(self, errors: dict[int, Exception]):
            self.errors = errors
            first = next(iter(errors.items()), None)
            msg = (f"spawn failed on ranks {sorted(errors.keys())}"
                   + (f": rank {first[0]} raised {first[1]!r}" if first else ""))
            super().__init__(msg)
    
  5. Single-driver compatibility. In main-only execution where g.parent is None (legacy single-driver tests), D0.2's worker-fork condition is false → the existing immediate-synchronous path is preserved. _drain_pending is not invoked.
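Rule 4 (exception propagation and sibling cleanup) can be sketched with generators in place of greenlets. This is an illustration under stated assumptions, not the KernBench implementation: `generator.close()` stands in for throwing a termination signal into a sibling, `pending` is a single hypothetical list standing in for both pending queues, and `SpawnException` is reduced to its `errors` attribute. Unlike a greenlet `throw`, `close()` swallows the `GeneratorExit` it injects, so no extra handling is needed in the cleanup loop here.

```python
class SpawnException(RuntimeError):
    """Reduced form of the class above: carries per-rank root-cause errors."""
    def __init__(self, errors):
        self.errors = errors
        super().__init__(f"spawn failed on ranks {sorted(errors)}")


def spawn(fn, nprocs, pending):
    """Round-based driver; generators stand in for worker greenlets."""
    errors = {}

    def entry(rank):
        try:
            yield from fn(rank, pending)
        except Exception as e:          # GeneratorExit is not an Exception subclass
            errors[rank] = e
            raise

    gs = [entry(r) for r in range(nprocs)]
    try:
        alive = list(gs)
        while alive:
            still_alive = []
            for g in alive:
                try:
                    next(g)             # analogue of g.switch()
                    still_alive.append(g)
                except StopIteration:
                    pass                # worker finished normally
            pending.clear()             # stands in for _drain_pending(ctx)
            alive = still_alive
    except Exception as outer:
        for g in gs:
            g.close()                   # sibling abort; no-op if already finished
        pending.clear()                 # explicit pending-queue clear
        raise SpawnException(errors) from outer


log = []


def fn(rank, pending):
    try:
        yield                           # round 1: every worker yields once
        pending.append(f"h{rank}")
        if rank == 1:
            raise ValueError("boom on rank 1")
        yield
        log.append(("finished", rank))
    finally:
        log.append(("exit", rank))


pending = []
try:
    spawn(fn, 3, pending)
    raised = None
except SpawnException as exc:
    raised = exc
```

Only rank 1 appears in `raised.errors`. Ranks 0 and 2 were terminated by the cleanup path, so they run their `finally` blocks but are not reported as failures, matching the scope restriction above.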

D0.5 Host-read barrier — decision (normative)

Inside a worker, host-observable reads such as tensor.numpy(), tensor.__getitem__, and tensor.data are defined as automatic drain barriers. Immediately before the call:

  1. If ctx._pending_worker_waits or backend._pending_collective_handles are non-empty, yield to main via g.parent.switch() → main runs _drain_pending → worker resumes after completion.
  2. If both queues are empty, read immediately.

Barrier iteration protocol (normative — re-entrance): _drain_pending drains via a while-loop until both queues are completely empty, not in a single pass:

def _drain_pending(ctx):
    while ctx._pending_worker_waits or (
        ctx.distributed._backend
        and ctx.distributed._backend._pending_collective_handles
    ):
        while ctx._pending_worker_waits:
            h = ctx._pending_worker_waits.pop(0)
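            if h not in ctx._completed:
                ctx.engine.wait(h)
                # Record completion so the D0.3 resume invariant holds and
                # duplicate entries of h are skipped by the guard above.
                ctx._completed.add(h)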
        backend = ctx.distributed._backend
        if backend is not None:
            while backend._pending_collective_handles:
                h, _sip_id, meta = backend._pending_collective_handles.pop(0)
                ctx.wait(h, _meta=meta)  # main context: safe; ctx.wait will
                                          # not push back to pending

Main-context ctx.wait non-recursion invariant (normative): the ctx.wait(h, _meta=meta) call inside _drain_pending runs in the main greenlet context. Because D0.2's worker-fork condition (g.parent is not None and not g.parent.dead) is False, it enters the immediate-synchronous path → never enqueues to _pending_worker_waits. Thanks to this invariant, the drain loop terminates without recursion / queue re-growth. When implementing, it is important to maintain g.parent is None as the single-main-greenlet guarantee.

Why a loop: ctx.wait(h, _meta=meta) is called in main context, so per the D0.2 path it drives the engine directly (no additional enqueue — the invariant above). In theory a single pass would suffice — but the protocol is fixed at loop-until-empty. Reasons:

  1. Future-extension safety: a future implementation might enqueue new pending items mid-drain (e.g., tree-reduce collectives with sub-handles). The loop protocol preserves correctness in that case.
  2. Readability: the single sentence "the barrier drains until pending is empty" closes the semantics. No dependence on the non-trivial invariant that ctx.wait calls do not produce new enqueues.
  3. Barrier semantics are "all dependencies needed for this read are complete": in the current model all pending = all dependencies, so the two are identical. The user mental model is the former.

Termination guarantee: described under two regimes.

  • Current implementation: when called in main context, ctx.wait drives the engine directly (D0.2) → does not enqueue new pending. Each iteration strictly shrinks pending size by pop(0) + engine.wait. The iteration count is bounded by the initial pending size itself → finite termination.
  • Future extension (the bound that justifies the loop protocol): if an implementation that enqueues new pending items mid-drain (e.g., tree-reduce sub-handles) is introduced, the initial-size bound breaks. However, SimPy causality guarantees that the dependency DAG of handles is finite, so the number of handles that can ever be enqueued, and therefore the number of loop iterations, is finite. The loop protocol automatically accommodates this case.

Both regimes guarantee that infinite loops are impossible. The single-pass bound of the current implementation is a reference value for aggressive optimization; the protocol is fixed at loop-until-empty.
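The loop-until-empty protocol and the idempotent duplicate guard can be exercised in isolation. The sketch below is a simplified model, not the KernBench `_drain_pending`: it uses `collections.deque` instead of `list.pop(0)`, takes a hypothetical `wait` callback in place of `ctx.engine.wait`, and simulates the hypothetical future case in which waiting on a tree-reduce root enqueues new handles mid-drain.

```python
from collections import deque


def drain_until_empty(worker_q, collective_q, wait, completed):
    """Drain both FIFO queues (worker waits first) until both are empty."""
    passes = 0
    while worker_q or collective_q:
        passes += 1
        while worker_q:
            h = worker_q.popleft()
            if h not in completed:      # idempotent guard for duplicate enqueues
                wait(h)
                completed.add(h)
        while collective_q:
            h = collective_q.popleft()
            if h not in completed:
                wait(h)
                completed.add(h)
    return passes


worker_q = deque(["w0", "w0", "w1"])    # "w0" was enqueued twice
collective_q = deque(["tree-root"])
completed = set()
wait_calls = []


def wait(h):
    wait_calls.append(h)
    # Hypothetical future behaviour: the root handle spawns follow-up work.
    if h == "tree-root":
        collective_q.extend(["tree-left", "tree-right"])
        worker_q.append("w2")


passes = drain_until_empty(worker_q, collective_q, wait, completed)
```

A single pass over the initial contents would have left `w2` undrained; the outer loop picks it up in a second pass. The duplicate `w0` costs one extra pop but no extra `wait` call.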

Why implicit drain at read is correct:

  • In the original open question, the choice was between (a) implicit drain and (b) explicit barrier. (b) is clear but burdens TP layer users with the 3-step pattern out = fc1.forward(x); ctx.drain(); result = out.numpy() on every read. (a) is a single rule that "guarantees the read sees the reflected value". It is analogous to the CUDA pattern of calling cudaDeviceSynchronize before a host-side read: not a hidden rule, but the contract of a named entry point.
  • This ADR adopts (a) but closes the entry-point list explicitly: Tensor.numpy(), Tensor.data (numpy alias), Tensor.__getitem__, Tensor.__repr__ (when data is included), and any other official host-read APIs are finalized via codebase search at the time of implementing this ADR. Any newly added host-read API must follow this contract (regression-guarded by tests).
  • Even when calling numpy directly after only ctx.submit without wait, the drain barrier still operates (because the handle is in the pending queue). The invariant is restored at read time even if the user omits an explicit wait.

Tensor.copy_(source) — write barrier specification:

copy_ is semantically "write to target", but internally it calls source.numpy() to fetch source data on the host then writes to each shard via target._memory_store.write(...). Both directions are barrier-handled:

  1. Source-side (read barrier): source.numpy() triggers the D0.5 read barrier (when source itself is a deployed tensor with pending).
  2. Target-side (write barrier — based on global pending): on copy_ entry, if ctx._pending_worker_waits or backend._pending_collective_handles are non-empty, drain via g.parent.switch() before writing. Not per-tensor / per-shard dependency tracking, but based on the global pending queue.
    • Why global: KernBench's handle representation does not retain the reverse-mapping information "this handle writes to which shard of which target". A safe conservative rule: "drain if any global pending exists". As a result, pending of an unrelated tensor can also block copy_ — drop-in invariant takes priority.
    • Explicit tradeoff: this rule can introduce unnecessary serialization between independent tensors. However, under the current single-queue execution model this cost is acceptable — guaranteeing cross-rank correctness and the "read sees latest" invariant via a simple rule takes precedence.
    • Practical impact: most pending of a single worker within a layer step is its own work — extra context switches from over-barrier often coincide with the end-of-round scheduler drain point, so no major issue.
    • Future refinement: per-tensor pending tracking could narrow this rule, but it is out of scope for this ADR.

Non-barrier:

  • tensor.shape, tensor.dtype, tensor.name, and other metadata-only access does not drain. No data dependency.
  • tensor.pa, tensor.va, and other raw address accessors also do not drain (address only, not content).

Official barrier entry-points (closed set):

| API | Kind | Rationale |
|---|---|---|
| Tensor.numpy() | read | host-observable copy |
| Tensor.data | read | numpy() alias |
| Tensor.__getitem__ | read | shard-aligned read |
| Tensor.__repr__ (when data is included) | read | debugging/log |
| Tensor.copy_(source) | read + write | source read + target write |

This contract is verified directly in T5/T6.
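The barrier / non-barrier split can be modelled with a small stand-in class. `FakeCtx` and `FakeTensor` below are hypothetical and only mimic the contract: `barrier()` stands in for "yield to main, run `_drain_pending`, resume", and a single `pending` list stands in for both queues.

```python
class FakeCtx:
    def __init__(self):
        self.pending = []
        self.barrier_count = 0

    def barrier(self):
        """Drain only when something is pending; otherwise return immediately."""
        if self.pending:
            self.barrier_count += 1
            self.pending.clear()


class FakeTensor:
    def __init__(self, ctx, shape, values):
        self._ctx = ctx
        self._shape = shape
        self._values = list(values)

    @property
    def shape(self):                    # metadata-only access: no drain
        return self._shape

    def numpy(self):                    # host-observable read: barrier first
        self._ctx.barrier()
        return list(self._values)

    @property
    def data(self):                     # alias of numpy()
        return self.numpy()

    def __getitem__(self, idx):         # host-observable read: barrier first
        self._ctx.barrier()
        return self._values[idx]

    def copy_(self, source):
        src = source.numpy()            # source-side read barrier
        self._ctx.barrier()             # target-side write barrier (global pending)
        self._values = list(src)
        return self


ctx = FakeCtx()
t = FakeTensor(ctx, (3,), [1, 2, 3])
ctx.pending.append("h0")

_ = t.shape                                             # must not drain
after_shape = (len(ctx.pending), ctx.barrier_count)
first = t[0]                                            # drains, then reads
after_read = (len(ctx.pending), ctx.barrier_count)
_ = t.numpy()                                           # nothing pending: immediate
after_second_read = ctx.barrier_count
```

Keeping the check inside each entry point, rather than in callers, is what allows the closed-set rule to be regression-tested per API.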

D0.6 Why the worker function API is unchanged (informative)

  • The inside of torch.zeros(...) is a self.submit(msg) + self.wait(h) pair. wait auto-defers to main per D0.2/D0.3 — appears synchronous from the outside but yields once.
  • tensor.numpy() follows D0.5's host-read barrier → drain→read when pending exists, immediate read otherwise.
  • dist.all_reduce continues to use the existing _defer_wait=True + _pending_collective_handles path. D0.4's drain processes both queues together.

D0.7 Invariants

  • The kernel greenlet's _parent is always main: because env.run never runs in worker context. (Core assertion of T3.)
  • Cross-rank synchronization point: drain occurs only after every worker has yielded → kernels of all ranks advance together within one round (a prerequisite for cross-rank IPCQ exchange).
  • Single-driver compatibility: D0.4-(5).

D1. torch.multiprocessing.spawn(fn, args, nprocs)

Real-PyTorch API parity + a single implementation site for D0's scheduler loop.

D1.0 API parity only — not execution parity (normative)

The name torch.multiprocessing.spawn implies API signature parity only. The actual execution model is a cooperative greenlet scheduler (single Python process, single OS thread, round-robin drive per D0.4). The following guarantees of real-PyTorch torch.multiprocessing.spawn are explicitly non-goals; this ADR does NOT provide them:

  • Process isolation (independent OS process per rank).
  • Independent address space (each rank with its own Python heap).
  • Failure isolation (a hard crash in one rank not affecting others).
  • OS-level scheduler fairness (preemptive time slicing between ranks).
  • Inter-process primitives such as mp.Queue, mp.Lock.

Actual properties of this implementation:

  • All ranks are greenlets inside the same Python process. Shared global state is visible as-is (intentional simulation convenience).
  • Single-threaded under the GIL → not parallel execution. Only "logical concurrency" via SimPy event ordering is reproduced.
  • Unhandled exception in any one worker → entire simulation aborts (D0.4-(4)).

Caller's obligation: when porting real-PyTorch multi-process samples to KernBench, logic that relies on process isolation (e.g., os.getpid, independent temp files, signal handling) must be removed. The namespace name is preserved for code portability — semantics differ.

D1.1 Public surface

# runtime_api/multiprocessing.py (new)
class _MultiprocessingNamespace:
    def __init__(self, ctx):
        self._ctx = ctx

    def spawn(self, fn, args: tuple, nprocs: int, join: bool = True) -> None:
        """Spawn `nprocs` worker greenlets, each calling fn(rank, *args).

        Mirrors torch.multiprocessing.spawn signature (minus `daemon` and `start_method`).
        Drives the D0 scheduler loop until all workers finish.
        """
        ...

D1.2 Implementation

def spawn(self, fn, args, nprocs, join=True):
    from greenlet import greenlet
    ctx = self._ctx
    dist = ctx.distributed
    gs: list[greenlet] = []
    errors: dict[int, Exception] = {}
    for rank in range(nprocs):
        def _entry(r=rank):
            try:
                fn(r, *args)
            except Exception as e:
                errors[r] = e
                raise
        g = greenlet(_entry)
        dist._bind_rank(g, rank)
        gs.append(g)

    try:
        while True:
            alive = [g for g in gs if not g.dead]
            if not alive:
                break
            for g in alive:
                if not g.dead:
                    g.switch()
            _drain_pending(ctx)       # D0.5
    except Exception as outer:
        # Sibling cleanup per D0.4-(4)
        for other in gs:
            if not other.dead:
                try:
                    other.throw(SystemExit)
                except (Exception, SystemExit):
                    pass  # SystemExit re-raises in main once the sibling dies
        backend = getattr(dist, "_backend", None)
        if backend is not None:
            if hasattr(backend, "_barrier"):
                backend._barrier.reset()
            if getattr(backend, "_pending_collective_handles", None) is not None:
                backend._pending_collective_handles.clear()
        ctx._pending_worker_waits.clear()
        raise SpawnException(errors) from outer
    # `join=True` semantics: we already wait for all workers.

D1.3 torch namespace attach

In runtime_api/context.py __post_init__:

self.multiprocessing = _MultiprocessingNamespace(self)

→ in bench code: torch.multiprocessing.spawn(worker, args=(ws,), nprocs=ws).

D1.4 Migration of existing benches

The hand-rolled loop in benches/ccl_allreduce.py collapses into a single torch.multiprocessing.spawn line. Existing matrix regressions are preserved. The currently xfail ring_default_ws is expected to flip to PASS thanks to D0 (workers no longer orphan the kernel greenlet).

D2. New package kernbench.tp

src/kernbench/tp/
    __init__.py          — public API re-exports
    parallel_state.py    — TP group management (currently a single global group)
    layers.py            — ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
    primitives.py        — copy/reduce/scatter/gather_to/from_tp_region
    kernels.py           — gemm kernel launched by TP layers (reusable)
    mappings.py          — forward identity/all_reduce, backward stub

D3. parallel_state — TP group

# parallel_state.py
_TP_WORLD_SIZE = None

def initialize_model_parallel(tensor_model_parallel_size: int) -> None:
    """Initialize TP group. Must be called after dist.init_process_group."""
    global _TP_WORLD_SIZE
    from kernbench.runtime_api.distributed import get_dist  # or torch.distributed
    dist = get_dist()
    total = dist.get_world_size()
    if tensor_model_parallel_size != total:
        raise NotImplementedError(
            "Only TP == world_size supported in initial scope"
        )
    _TP_WORLD_SIZE = tensor_model_parallel_size

def get_tensor_model_parallel_world_size() -> int:
    if _TP_WORLD_SIZE is None:
        raise RuntimeError("initialize_model_parallel() has not been called")
    return _TP_WORLD_SIZE

def get_tensor_model_parallel_rank() -> int:
    from kernbench.runtime_api.distributed import get_dist
    return get_dist().get_rank()         # ADR-0024 greenlet-local rank

Initial scope: TP size = world_size = topology SIP count. Pure TP model.

D4-pre. TP shard ownership vs DPPolicy — role separation (normative)

In the weight/output representation of TP layers, two concepts are clearly separated:

| Concept | Decided by | Scope |
|---|---|---|
| TP shard ownership (which rank owns which slice of the weight) | greenlet-local rank + torch.ahbm.set_device(rank) (ADR-0024 D2/D3) | cross-rank, cross-SIP |
| Intra-rank placement (how the owned slice is distributed across cube × PE inside the rank) | DPPolicy(cube=..., pe=...) (ADR-0026) | inside one rank (within SIP boundary) |

Thus when ColumnParallelLinear creates a weight of shape (in_features, out_features // ws) and assigns DPPolicy(cube="column_wise", pe="column_wise"):

  • The slice owned by rank r = column-axis [r * k_local, (r+1) * k_local) of the weight — set_device(r) determines this (that rank resides on SIP r).
  • Inside that slice, the cube × PE column-wise distribution — DPPolicy determines this.

The two axes are independent. If two ranks build their own slice with the same DPPolicy, the slices themselves live on different SIPs but the intra-SIP placement pattern is the same. Conversely, changing DPPolicy to cube="replicate", pe="replicate" preserves TP shard ownership and only changes intra-rank placement.

Mistakes that blur this boundary (forbidden by this ADR):

  • The "SIP axis" reappearing in DPPolicy (removed in ADR-0026).
  • TP layers expressing cross-rank sharding via DPPolicy alone without set_device → indistinguishable from a vertical split within a single rank.

The TP layers of this ADR always treat weight/output from the perspective of "rank = SIP = owns one slice + DPPolicy intra-SIP distribution" only.
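The independence of the two axes can be made concrete with index arithmetic. In the sketch below, `tp_column_range` follows the `[r * k_local, (r+1) * k_local)` rule stated above. `intra_rank_columns` is a hypothetical stand-in for intra-rank placement that assumes an even split across cubes; the real placement semantics are defined by DPPolicy in ADR-0026, not by this function.

```python
def tp_column_range(rank, out_features, world_size):
    """Column range [lo, hi) of the full weight owned by `rank` (TP ownership)."""
    if out_features % world_size != 0:
        raise ValueError("out_features must be divisible by world_size")
    k_local = out_features // world_size
    return rank * k_local, (rank + 1) * k_local


def intra_rank_columns(k_local, n_cubes, policy):
    """Hypothetical intra-rank placement of a rank-local slice across cubes."""
    if policy == "replicate":
        return [(0, k_local)] * n_cubes          # every cube holds the whole slice
    if policy == "column_wise":
        step = k_local // n_cubes                # assumes an even split
        return [(c * step, (c + 1) * step) for c in range(n_cubes)]
    raise ValueError(f"unknown policy: {policy}")


OUT_FEATURES, WS, N_CUBES = 8, 4, 2

# Cross-rank ownership depends only on rank and world size.
ownership = [tp_column_range(r, OUT_FEATURES, WS) for r in range(WS)]

# Intra-rank placement depends only on the local slice width and the policy.
k_local = OUT_FEATURES // WS
placement_cw = intra_rank_columns(k_local, N_CUBES, "column_wise")
placement_rep = intra_rank_columns(k_local, N_CUBES, "replicate")
```

Switching the policy from `"column_wise"` to `"replicate"` changes `placement_*` but leaves `ownership` untouched, which is the separation this section requires.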

D4. ColumnParallelLinear

Important: no new host-side torch.matmul abstraction is introduced. The layer's forward calls the existing gemm kernel via torch.launch("gemm", gemm_kernel, ...) — the pattern already used by KernBench benches (benches/gemm_single_pe.py, benches/gpt3_qkv.py).

# layers.py
from kernbench.policy.placement.dp import DPPolicy
from kernbench.tp.kernels import _gemm_kernel
from kernbench.tp.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)

class ColumnParallelLinear:
    """Shards the K(out_features) axis of the weight across TP ranks.

    forward(x):
        x: (M, N) — full-replicated across ranks
        W_k: (N, K / world_size) — rank-local slice (placed on SIP r via set_device)
        y_k = x @ W_k → (M, K / world_size) — rank-local output

    Output is column-sharded. The input form expected by RowParallelLinear.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 dtype: str = "f16", torch=None):
        ws = get_tensor_model_parallel_world_size()
        assert out_features % ws == 0
        self.in_features = in_features
        self.k_local = out_features // ws
        self._torch = torch
        # Each rank owns its own slice — placed on SIP r by set_device(rank).
        self.weight = torch.zeros(
            (in_features, self.k_local), dtype=dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="col_parallel_w",
        )
        self.bias = None
        if bias:
            self.bias = torch.zeros(
                (self.k_local,), dtype=dtype,
                dp=DPPolicy(cube="replicate", pe="replicate"),
                name="col_parallel_b",
            )

    def forward(self, x):
        # x is full-replicated (caller-guaranteed). Plain local gemm.
        M = x.shape[0]
        out = self._torch.empty(
            (M, self.k_local), dtype=x.dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="col_parallel_out",
        )
        self._torch.launch(
            "col_parallel_gemm", _gemm_kernel,
            x, self.weight, out, M, self.in_features, self.k_local,
        )
        # Bias add would be a separate kernel or the fused bias of a composite gemm.
        # For the initial scope, verifying bias=False is sufficient.
        return out

Yield-safety contract (normative): ColumnParallelLinear.forward includes one torch.launch call, which performs a kernel launch followed by an internal ctx.wait. This automatically satisfies the D0.4-(1) condition that a worker yields within a finite number of steps, so TP layer users do not need to insert yield patterns manually.
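
The sharding arithmetic behind ColumnParallelLinear can be checked on the host without the simulator. The following is a minimal NumPy sketch, not the kernbench implementation: the sizes, the `world_size` value, and the list comprehension standing in for real ranks are all illustrative assumptions.

```python
import numpy as np

# Illustrative sizes (assumptions, not taken from the ADR).
M, N, K, world_size = 4, 8, 12, 4
assert K % world_size == 0
k_local = K // world_size

rng = np.random.default_rng(0)
x = rng.standard_normal((M, N))   # replicated on every rank
W = rng.standard_normal((N, K))   # full weight, kept only as a reference

# Rank r owns columns [r * k_local, (r + 1) * k_local) of W.
shards = [W[:, r * k_local:(r + 1) * k_local] for r in range(world_size)]

# Each rank runs a plain local gemm; no communication is needed.
y_local = [x @ W_r for W_r in shards]

# Concatenating the rank-local outputs along the last axis recovers x @ W.
y_col = np.concatenate(y_local, axis=1)
col_ok = np.allclose(y_col, x @ W)
print(col_ok)
```

Each rank-local output has shape (M, K / world_size), matching the docstring of ColumnParallelLinear.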

D5. RowParallelLinear

class RowParallelLinear:
    """Shards the N(in_features) axis of the weight across TP ranks.

    forward(x):
        x: (M, N / world_size) — rank-local slice (output of ColumnParallel)
        W_k: (N / world_size, K) — rank-local slice
        y_k = x @ W_k → (M, K) — partial sum on each rank
        y = all_reduce(y_k, op="sum") → (M, K) on every rank
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 dtype: str = "f16", torch=None):
        ws = get_tensor_model_parallel_world_size()
        assert in_features % ws == 0
        self.n_local = in_features // ws
        self.out_features = out_features
        self._torch = torch
        self.weight = torch.zeros(
            (self.n_local, out_features), dtype=dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="row_parallel_w",
        )
        # bias only on rank 0 (Megatron convention). Omitted in initial scope.
        self.bias = None

    def forward(self, x):
        M = x.shape[0]
        y_partial = self._torch.empty(
            (M, self.out_features), dtype=x.dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="row_parallel_partial",
        )
        self._torch.launch(
            "row_parallel_gemm", _gemm_kernel,
            x, self.weight, y_partial, M, self.n_local, self.out_features,
        )
        # Cross-rank reduce. ADR-0024's dist.all_reduce works correctly
        # under D0 + mp.spawn (kernel parent = main is preserved).
        self._torch.distributed.all_reduce(y_partial, op="sum")
        return y_partial

Yield-safety contract (normative): RowParallelLinear.forward includes launch → internal wait followed by all_reduce (defer + worker-yield pattern), so at least 2 yields per forward are guaranteed. The scheduler-progress condition of D0.4-(1) is automatically satisfied. All TP layer forwards in this ADR maintain the invariant "yield-safe by containing at least one wait or collective" — any future TP primitives (e.g., VocabParallelEmbedding) must keep the same contract.
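
The reason RowParallelLinear needs an all_reduce can be seen from the same kind of host-side check. This is a hedged NumPy sketch under assumed sizes; summing over a Python list emulates `all_reduce(op="sum")` and is not how kernbench implements the collective.

```python
import numpy as np

# Illustrative sizes (assumptions, not taken from the ADR).
M, N, K, world_size = 4, 8, 6, 4
assert N % world_size == 0
n_local = N // world_size

rng = np.random.default_rng(1)
x = rng.standard_normal((M, N))
W = rng.standard_normal((N, K))

partials = []
for r in range(world_size):
    rows = slice(r * n_local, (r + 1) * n_local)
    x_r = x[:, rows]             # column slice of the activation
    W_r = W[rows, :]             # matching row slice of the weight
    partials.append(x_r @ W_r)   # (M, K) partial sum on rank r

# Emulated all_reduce(op="sum"): every rank ends up with the same total.
y = np.sum(partials, axis=0)
row_ok = np.allclose(y, x @ W)
print(row_ok)
```

No single partial equals the full product; only the sum over ranks does, which is why the collective is part of forward.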

D6. Primitive functions

# primitives.py
def copy_to_tp_region(x):
    """Forward: identity. Backward: all-reduce. (Implemented when training is added)."""
    return x

def reduce_from_tp_region(x, torch):
    """Forward: all-reduce. Backward: identity."""
    torch.distributed.all_reduce(x, op="sum")
    return x

def scatter_to_tp_region(x):
    raise NotImplementedError(
        "Phase 2: replaced by users creating already-sharded tensors"
    )

def gather_from_tp_region(x):
    raise NotImplementedError(
        "Phase 2: requires all-gather kernel as a prerequisite (future)"
    )

D7. Sample bench — 2-layer MLP with TP

# benches/tp_mlp.py (new)
from kernbench.policy.placement.dp import DPPolicy
import kernbench.tp as tp
import numpy as np


def worker(rank: int, world_size: int, torch):
    torch.ahbm.set_device(rank)
    tp.initialize_model_parallel(world_size)

    B, D_in, D_hidden, D_out = 1, 512, 2048, 512
    fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=torch)
    fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=torch)

    x = torch.zeros(
        (B, D_in), dtype="f16",
        dp=DPPolicy(cube="replicate", pe="replicate"),
        name="x",
    )
    # init x with some pattern (e.g., constant)
    x.copy_(torch.from_numpy(np.full((B, D_in), 0.1, dtype=np.float16)))

    h = fc1.forward(x)      # column-sharded (B, D_hidden / ws)
    y = fc2.forward(h)      # all-reduced (B, D_out) on every rank

    # Only rank 0 prints / verifies the result
    if rank == 0:
        result = y.numpy()
        # With zero-init weights every value is 0; in the initial scope,
        # running to completion is itself the check.
        print(f"  tp_mlp: shape={result.shape}, mean={float(result.mean()):.4f}")


def run(torch):
    torch.distributed.init_process_group(backend="ahbm")
    ws = torch.distributed.get_world_size()
    torch.multiprocessing.spawn(worker, args=(ws,), nprocs=ws)
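
A host-side reference for this bench can be useful when checking its output. The sketch below emulates fc1 followed by fc2 rank by rank in NumPy, using the shapes from the worker above. The world size of 4, the random weights, and the helper name `tp_mlp_reference` are assumptions for illustration; nothing here touches kernbench.

```python
import numpy as np

B, D_in, D_hidden, D_out = 1, 512, 2048, 512   # shapes from the bench
ws = 4                                         # assumed world size


def tp_mlp_reference(x, W1, W2, ws):
    """Emulate fc1 (column-parallel) then fc2 (row-parallel), rank by rank."""
    h_local = W1.shape[1] // ws
    y = np.zeros((x.shape[0], W2.shape[1]), dtype=np.float64)
    for r in range(ws):
        cols = slice(r * h_local, (r + 1) * h_local)
        h_r = x @ W1[:, cols]     # rank-local hidden slice
        y += h_r @ W2[cols, :]    # accumulation stands in for all_reduce(sum)
    return y


x = np.full((B, D_in), 0.1)
rng = np.random.default_rng(0)
W1 = rng.standard_normal((D_in, D_hidden)) / np.sqrt(D_in)
W2 = rng.standard_normal((D_hidden, D_out)) / np.sqrt(D_hidden)

y_tp = tp_mlp_reference(x, W1, W2, ws)
y_full = (x @ W1) @ W2            # unsharded reference
mlp_ok = np.allclose(y_tp, y_full)

# Zero weights reproduce the bench's default: every output element is 0.
y_zero = tp_mlp_reference(x, np.zeros_like(W1), np.zeros_like(W2), ws)
print(mlp_ok, y_tp.shape, float(y_zero.mean()))
```

The column-sharded output of fc1 feeds fc2 directly, so the only cross-rank step in the whole MLP is the final sum.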

D8. Non-functional — training not supported

This ADR covers inference (forward) only. Backward passes, gradients, and optimizers are future work. This is a natural boundary because KernBench is not a training system.

D9. Initial-scope constraints

  • TP size = world_size (no mixed DP+TP).
  • scatter_to_tp_region and gather_from_tp_region are unimplemented (they raise NotImplementedError).
  • Default weight value is zero. Proper init schemes (Xavier, Kaiming, etc.) are future work. Tests inject deterministic non-zero patterns via tensor.copy_ to verify numerical correctness (T2/T6). That is, the production default is zero and verification uses deterministic non-zero values.
  • Bias is omitted in the initial scope (Megatron's rank-0-only bias policy is future).
  • Pipeline parallelism is out of scope.
  • VocabParallelEmbedding depends on all-gather, which is not yet available, so it is a stub only.
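
The "deterministic non-zero pattern" idea can be sketched as follows. The helper `det_pattern`, its constants, and the tolerances are hypothetical choices made for this example; the actual patterns and tolerances used by T2/T6 are not specified here. The f32 accumulation is also an assumption, chosen only to keep the host-side check simple.

```python
import numpy as np


def det_pattern(shape, mult=7, mod=11):
    """Deterministic non-zero f16 pattern; values are multiples of 1/16."""
    idx = np.arange(int(np.prod(shape)), dtype=np.int64).reshape(shape)
    vals = ((idx * mult) % mod - mod // 2) / 16.0
    vals[vals == 0] = 1 / 16.0    # keep every entry non-zero
    return vals.astype(np.float16)


M, N, K, ws = 2, 64, 8, 4         # illustrative sizes
x = det_pattern((M, N))
W = det_pattern((N, K), mult=5, mod=13)

# Float64 reference computed from the same f16 inputs.
ref = x.astype(np.float64) @ W.astype(np.float64)

# Row-parallel emulation with f32 accumulation.
n_local = N // ws
y = np.zeros((M, K), dtype=np.float32)
for r in range(ws):
    rows = slice(r * n_local, (r + 1) * n_local)
    y += x[:, rows].astype(np.float32) @ W[rows, :].astype(np.float32)

pattern_ok = np.allclose(y, ref, rtol=1e-3, atol=1e-3)
print(pattern_ok)
```

Unlike zero weights, a pattern like this makes a wrong slice or a missing reduce visible in the output values.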

D10. Regression: ring_default_ws xfail removal — mandatory acceptance

D0 (worker-wait generalization) and D0.5 (host-read barrier) route every worker-driven ctx.wait and host read through the main-drain path, which removes the cause of the kernel-greenlet orphan in ADR-0024 Phase B. Once this ADR is implemented, the ring_default_ws strict-xfail case in the existing matrix test must flip to PASS; this is a mandatory regression criterion. The observable acceptance criteria (no deadlock, no GreenletExit, numerical tolerance, etc.) are specified in T7.


Dependencies

  • ADR-0024 (launcher): rank = SIP, greenlet-local rank, torch.ahbm.set_device(rank).
  • ADR-0026 (DPPolicy intra-device): per-rank slice representation of weight tensors.
  • ADR-0023 / ADR-0025 (IPCQ): foundation of dist.all_reduce implementation.

Non-goals

  • Backward pass / training: inference only. Training simulation is a separate ADR.
  • Mixed parallelism (DP + TP + PP): pure TP only at the start.
  • Weight init schemes: simple zero / debug pattern.
  • Fused ops: Megatron's fused matmul+bias+gelu is a kernel-level concern.
  • DTensor integration: ADR-0028 future.
  • Host-side torch.matmul abstraction: TP layers call the existing gemm kernel via torch.launch(gemm_kernel, ...). No new matmul host-op is introduced.

Open questions

  • Location of initialize_model_parallel: kernbench.tp.initialize_model_parallel (current decision) vs real-PyTorch's torch.distributed.init_device_mesh. Kept in the TP-only module.
  • Weight init: the ADR uses zero. A debug pattern (e.g., identity) may be needed for valid verification — add at Phase 1 test time if needed.
  • Bias placement policy: Megatron places RowParallelLinear bias only on rank 0. Avoided in the initial scope via bias=False.
  • GEMM kernel location: kernbench.tp.kernels._gemm_kernel vs importing from existing benches/gemm_single_pe.py. TP must not depend on benches, so the kernel is duplicated inside tp. Migration to a shared kernbench.kernels package is possible later.

Resolved (previously open in earlier revisions):

  • Drain timing on tensor.numpy() call: decided in D0.5. The official host-read entry points (numpy, data, __getitem__, data-containing __repr__) are automatic drain barriers. Metadata-only accessors are not barriers.
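
The barrier rule above can be illustrated with a toy class. This is a sketch of the rule only: `LazyTensor`, `defer`, and `tolist` are hypothetical names invented for the example and do not reflect the kernbench tensor API or the D0.5 implementation.

```python
class LazyTensor:
    """Toy tensor whose data reads drain pending work first (hypothetical)."""

    def __init__(self, data, shape):
        self._data = list(data)
        self._shape = tuple(shape)
        self._pending = []        # deferred ops, e.g. queued kernel results
        self.drains = 0           # counts how often a barrier fired

    def defer(self, fn):
        self._pending.append(fn)

    def _drain(self):
        # Run every deferred op before data becomes host-visible.
        self.drains += 1
        while self._pending:
            self._pending.pop(0)(self._data)

    @property
    def shape(self):              # metadata only: not a barrier
        return self._shape

    def tolist(self):             # stands in for numpy() / data
        self._drain()
        return list(self._data)

    def __getitem__(self, i):     # host read: barrier
        self._drain()
        return self._data[i]

    def __repr__(self):           # data-containing repr: barrier
        self._drain()
        return f"LazyTensor({self._data})"


t = LazyTensor([0, 0, 0], shape=(3,))
t.defer(lambda buf: buf.__setitem__(0, 7))
shape_before = t.shape            # does not drain
drains_after_shape = t.drains
first = t[0]                      # drains, then reads the updated value
print(shape_before, drains_after_shape, first)
```

Reading the shape leaves the deferred write pending, while the indexed read forces it to complete before returning.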

Consequences

Positive

  • Easy porting of Megatron code: API matches real training code.
  • TP benchmarking enabled: research on scaling, communication-compute overlap, and other HW characteristics.
  • ring_default_ws xfail removal: as a byproduct of D0, the ADR-0024 Phase B blocker is resolved.
  • Scheduler-loop unification: introducing D1 (mp.spawn) removes the hand-rolled loop. Subsequent collective/TP benches reuse the same pattern.
  • DPPolicy semantics clarified (synergy with ADR-0026): TP layers serve as a best-practice example of using DPPolicy for intra-device placement only.

Negative

  • Maintenance cost of a new module (kernbench.tp).
  • Initial scope is limited (pure TP only, forward only).
  • D0 generalization changes the semantics of ctx.wait — compatibility with single-driver tests must be explicitly verified (T7).

Neutral

  • A pure upper layer added on top of ADR-0024/0026. No impact on the hardware-simulation stack (apart from D0).