"""ADR-0027 T2: TP layer shape + numerical correctness (D4/D5).

Phase 1: ``kernbench.tp.layers`` doesn't exist → import failure.
Phase 2 lands D4/D5 and T2 passes with deterministic non-zero weight
patterns.

Workers launched through ``ctx.multiprocessing.spawn`` record what they
observe into a dict owned by the enclosing test; the assertions run in the
test body after the spawn, once it has been checked that every rank actually
reported. A worker that never ran therefore cannot make a test pass
vacuously.
"""

from __future__ import annotations

import numpy as np
import pytest


def _make_ctx(topology):
    """Build a RuntimeContext over *topology* with data simulation enabled.

    The context targets all devices (``DeviceSelector("all")``) and is driven
    by a ``GraphEngine`` created with ``enable_data=True``.
    """
    from kernbench.runtime_api.context import RuntimeContext
    from kernbench.runtime_api.types import DeviceSelector
    from kernbench.sim_engine.engine import GraphEngine

    engine = GraphEngine(topology.topology_obj, enable_data=True)
    return RuntimeContext(
        engine=engine,
        target_device=DeviceSelector("all"),
        correlation_id="test_t2",
        spec=topology.topology_obj.spec,
    )


def _copy_from_host(dst, host_np, name):
    """Copy the float16 numpy array *host_np* into device tensor *dst*.

    A host-side ``Tensor`` named *name* is created with *host_np* attached as
    its ``_host_buffer`` and then passed to ``dst.copy_``.
    """
    from kernbench.runtime_api.tensor import Tensor

    src = Tensor(shape=host_np.shape, dtype="f16", name=name)
    src._host_buffer = host_np
    dst.copy_(src)


def _replicated_input(ctx, host_np, name):
    """Allocate an f16 tensor shaped like *host_np* and fill it from the host.

    The device tensor is created with ``ctx.zeros`` using the
    ``_replicate_dp()`` placement and is named *name*; the temporary host
    tensor used for the copy is named ``"host_x"``.
    """
    x = ctx.zeros(host_np.shape, dtype="f16", dp=_replicate_dp(), name=name)
    _copy_from_host(x, host_np, "host_x")
    return x


def _assert_all_ranks_reported(results, ws):
    """Fail unless every rank in ``range(ws)`` stored an entry in *results*.

    Guards against workers that never ran, which would otherwise leave the
    per-rank assertion loops with nothing to check.
    """
    missing = sorted(set(range(ws)) - set(results))
    assert not missing, f"ranks {missing} never reported a result"


# ── Shape / structural ───────────────────────────────────────────────


def test_column_parallel_weight_shape_per_rank(topology):
    """ColumnParallelLinear weight per rank is (in_features, out // ws)."""
    import kernbench.tp as tp

    with _make_ctx(topology) as ctx:
        ctx.distributed.init_process_group(backend="ahbm")
        ws = ctx.distributed.get_world_size()
        tp.initialize_model_parallel(ws)

        shapes = {}  # rank -> fc.weight.shape as seen by that rank

        def _worker(rank: int):
            ctx.ahbm.set_device(rank)
            fc = tp.ColumnParallelLinear(
                in_features=256, out_features=512, torch=ctx,
            )
            shapes[rank] = fc.weight.shape

        ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)

        _assert_all_ranks_reported(shapes, ws)
        for r, shape in shapes.items():
            assert shape == (256, 512 // ws), f"rank {r}: weight shape {shape}"


def test_row_parallel_weight_shape_per_rank(topology):
    """RowParallelLinear weight per rank is (in_features // ws, out_features)."""
    import kernbench.tp as tp

    with _make_ctx(topology) as ctx:
        ctx.distributed.init_process_group(backend="ahbm")
        ws = ctx.distributed.get_world_size()
        tp.initialize_model_parallel(ws)

        shapes = {}  # rank -> fc.weight.shape as seen by that rank

        def _worker(rank: int):
            ctx.ahbm.set_device(rank)
            fc = tp.RowParallelLinear(
                in_features=512, out_features=256, torch=ctx,
            )
            shapes[rank] = fc.weight.shape

        ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)

        _assert_all_ranks_reported(shapes, ws)
        for r, shape in shapes.items():
            assert shape == (512 // ws, 256), f"rank {r}: weight shape {shape}"


# ── T2.a: ColumnParallel deterministic numerical ─────────────────────


def test_column_parallel_forward_matches_matmul(topology):
    """T2.a: ColumnParallelLinear.forward output == x @ W_rank (rtol 1e-2)."""
    import kernbench.tp as tp

    with _make_ctx(topology) as ctx:
        ctx.distributed.init_process_group(backend="ahbm")
        ws = ctx.distributed.get_world_size()
        tp.initialize_model_parallel(ws)

        M = 4
        D_in, D_out = 32, 32 * ws
        k_local = D_out // ws
        # rank -> (device output as returned by y.numpy(), float32 reference)
        results: dict[int, tuple[np.ndarray, np.ndarray]] = {}

        def _worker(rank: int):
            ctx.ahbm.set_device(rank)
            fc = tp.ColumnParallelLinear(
                in_features=D_in, out_features=D_out, torch=ctx,
            )
            # Deterministic non-zero weight: rank-scaled constant.
            weight_np = np.full(
                (D_in, k_local), 0.01 * (rank + 1), dtype=np.float16,
            )
            _copy_from_host(fc.weight, weight_np, "host_w")

            # Input: full-replicated constant.
            x_np = np.full((M, D_in), 0.5, dtype=np.float16)
            x = _replicated_input(ctx, x_np, f"t2a_x_r{rank}")

            y = fc.forward(x)
            # Reference from the exact fp16 values the device received.
            expected = x_np.astype(np.float32) @ weight_np.astype(np.float32)
            results[rank] = (y.numpy(), expected)

        ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)

        _assert_all_ranks_reported(results, ws)
        for r, (out, expected) in results.items():
            assert out.shape == (M, k_local)
            assert np.allclose(
                out.astype(np.float32), expected, rtol=1e-2, atol=1e-2,
            ), f"rank {r}: output does not match x @ W_local"


# ── T2.b: RowParallel observable equality ────────────────────────────


def test_row_parallel_forward_concat_matmul_equality(topology):
    """T2.b (primary): RowParallel output == concat(x) @ concat(W) (all-reduced).

    Rank r holds x_r with shape (M, n_local) and W_r with shape
    (n_local, D_out). Concatenating the shards along the reduction axis gives
    concat(x) @ concat(W) == sum_r (x_r @ W_r), which every rank must see
    after the all-reduce.
    """
    import kernbench.tp as tp

    with _make_ctx(topology) as ctx:
        ctx.distributed.init_process_group(backend="ahbm")
        ws = ctx.distributed.get_world_size()
        tp.initialize_model_parallel(ws)

        M = 4
        D_in, D_out = 32 * ws, 32  # D_in must be divisible by ws
        n_local = D_in // ws
        results: dict[int, np.ndarray] = {}

        # One generator per shard, shared by the workers and the host
        # reference, so both use the exact same fp16 values.
        def _shard_x(r: int) -> np.ndarray:
            # Input x_r = constant 0.1 * (r + 1) (pretending it was
            # column-sharded from upstream).
            return np.full((M, n_local), 0.1 * (r + 1), dtype=np.float16)

        def _shard_w(r: int) -> np.ndarray:
            # Per-rank W_r = constant 0.01 * (r + 1).
            return np.full((n_local, D_out), 0.01 * (r + 1), dtype=np.float16)

        def _worker(rank: int):
            ctx.ahbm.set_device(rank)
            fc = tp.RowParallelLinear(
                in_features=D_in, out_features=D_out, torch=ctx,
            )
            _copy_from_host(fc.weight, _shard_w(rank), "host_w")
            x = _replicated_input(ctx, _shard_x(rank), f"t2b_x_r{rank}")
            y = fc.forward(x)
            results[rank] = y.numpy().astype(np.float32)

        ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)

        # Host-side reference: sum_r (x_r @ W_r) = y (same on all ranks).
        expected = np.zeros((M, D_out), dtype=np.float32)
        for r in range(ws):
            expected += _shard_x(r).astype(np.float32) @ _shard_w(r).astype(
                np.float32
            )

        _assert_all_ranks_reported(results, ws)
        for r, out in results.items():
            assert np.allclose(out, expected, rtol=1e-2, atol=1e-2), (
                f"rank {r}: all-reduced output != expected partial sum"
            )


# ── T2.c: rank-consistency post all-reduce ───────────────────────────


def test_row_parallel_rank_identity_post_all_reduce(topology):
    """T2.c: after all_reduce, all ranks see elementwise-identical output.

    Each rank receives a rank-scaled weight shard, so the local partial
    products x @ W_r differ from rank to rank (by 0.016 per element between
    neighbouring ranks, above the comparison tolerance). Identical outputs
    across ranks can therefore only come from combining the partials.
    With rank-uniform weights the partials would already be identical, and
    the test would pass even without any all_reduce.
    """
    import kernbench.tp as tp

    with _make_ctx(topology) as ctx:
        ctx.distributed.init_process_group(backend="ahbm")
        ws = ctx.distributed.get_world_size()
        tp.initialize_model_parallel(ws)

        M = 2
        D_in, D_out = 16 * ws, 16
        n_local = D_in // ws
        results: dict[int, np.ndarray] = {}

        def _worker(rank: int):
            ctx.ahbm.set_device(rank)
            fc = tp.RowParallelLinear(
                in_features=D_in, out_features=D_out, torch=ctx,
            )
            # Rank-dependent so that pre-reduction partials differ per rank.
            weight_np = np.full(
                (n_local, D_out), 0.01 * (rank + 1), dtype=np.float16,
            )
            _copy_from_host(fc.weight, weight_np, "host_w")

            x_np = np.full((M, n_local), 0.1, dtype=np.float16)
            x = _replicated_input(ctx, x_np, f"t2c_x_r{rank}")

            y = fc.forward(x)
            results[rank] = y.numpy()

        ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)

        _assert_all_ranks_reported(results, ws)
        ref = results[0]
        for r, out in results.items():
            assert np.allclose(out, ref, rtol=1e-2, atol=1e-2), (
                f"rank {r} output differs from rank 0 — all_reduce failed to make "
                f"outputs elementwise identical"
            )


def _replicate_dp():
    """Return a DPPolicy with cube="replicate", pe="replicate" (1 cube, 1 PE)."""
    from kernbench.policy.placement.dp import DPPolicy

    return DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)