"""Megatron-style parallel layers (ADR-0027 D4/D5). - ``ColumnParallelLinear``: weight's out_features axis split across TP ranks. forward(x) is local gemm; no collective. - ``RowParallelLinear``: weight's in_features axis split across TP ranks. forward(x) ends with ``dist.all_reduce`` to sum partial products. Both layers use the intra-device ``DPPolicy`` (ADR-0026). TP shard ownership is determined by ``torch.ahbm.set_device(rank)`` (ADR-0024 D10). Yield-safety contract (ADR-0027 D4/D5): every forward path contains at least one ``ctx.wait`` (via ``torch.launch``) or one collective; this keeps the scheduler loop making progress. """ from __future__ import annotations from typing import Any from kernbench.policy.placement.dp import DPPolicy from kernbench.tp.kernels import _gemm_kernel from kernbench.tp.parallel_state import ( get_tensor_model_parallel_world_size, ) class ColumnParallelLinear: """Weight's K (out_features) axis distributed across TP ranks. forward(x): x: (M, N) — full-replicated across ranks W_k: (N, K / world_size) — this rank's slice (on its SIP) y_k = x @ W_k → (M, K / world_size) """ def __init__( self, in_features: int, out_features: int, bias: bool = False, dtype: str = "f16", torch: Any = None, ) -> None: if torch is None: raise TypeError("ColumnParallelLinear requires torch=") ws = get_tensor_model_parallel_world_size() if out_features % ws != 0: raise ValueError( f"out_features ({out_features}) must be divisible by TP world " f"size ({ws})" ) self.in_features = in_features self.out_features = out_features self.k_local = out_features // ws self.dtype = dtype self._torch = torch # Per-rank weight slice. ``set_device(rank)`` (ADR-0024 D10) places # it on SIP ``rank``. Intra-SIP layout comes from DPPolicy (ADR-0026). self.weight = torch.zeros( (in_features, self.k_local), dtype=dtype, dp=DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1), name="col_parallel_w", ) # Bias omitted in initial scope (ADR-0027 D9). self.bias = None if bias: raise NotImplementedError( "bias=True is deferred (ADR-0027 D9 initial scope)" ) def forward(self, x): M = int(x.shape[0]) out = self._torch.empty( (M, self.k_local), dtype=x.dtype, dp=DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1), name="col_parallel_out", ) self._torch.launch( "col_parallel_gemm", _gemm_kernel, x, self.weight, out, M, self.in_features, self.k_local, ) return out class RowParallelLinear: """Weight's N (in_features) axis distributed across TP ranks. forward(x): x: (M, N / world_size) — rank-local slice (ColumnParallel output) W_k: (N / world_size, K) — this rank's slice y_k = x @ W_k → (M, K) — partial sum y = all_reduce(y_k, op="sum") → (M, K) on every rank """ def __init__( self, in_features: int, out_features: int, bias: bool = False, dtype: str = "f16", torch: Any = None, ) -> None: if torch is None: raise TypeError("RowParallelLinear requires torch=") ws = get_tensor_model_parallel_world_size() if in_features % ws != 0: raise ValueError( f"in_features ({in_features}) must be divisible by TP world " f"size ({ws})" ) self.in_features = in_features self.out_features = out_features self.n_local = in_features // ws self.dtype = dtype self._torch = torch self.weight = torch.zeros( (self.n_local, out_features), dtype=dtype, dp=DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1), name="row_parallel_w", ) self.bias = None if bias: raise NotImplementedError( "bias=True is deferred (ADR-0027 D9 initial scope)" ) def forward(self, x): M = int(x.shape[0]) y_partial = self._torch.empty( (M, self.out_features), dtype=x.dtype, dp=DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1), name="row_parallel_partial", ) self._torch.launch( "row_parallel_gemm", _gemm_kernel, x, self.weight, y_partial, M, self.n_local, self.out_features, ) self._torch.distributed.all_reduce(y_partial, op="sum") return y_partial