"""ADR-0027 T6: End-to-end 2-layer MLP with tensor parallelism (TP).

Phase 1: fails at imports. Phase 2 lands the TP package + D7 bench pattern,
and these pass with numerical-correctness checks.

Each test builds ``ColumnParallelLinear -> RowParallelLinear`` on every rank
inside ``ctx.multiprocessing.spawn`` and records per-rank results keyed by
rank. After ``spawn`` returns, the tests also check that every rank in
``0..world_size-1`` recorded a result, so a spawn that skips workers cannot
make the per-rank assertions pass vacuously.
"""

from __future__ import annotations

import numpy as np
import pytest


def _make_ctx(topology):
    """Build a ``RuntimeContext`` over a data-enabled ``GraphEngine``.

    The context targets ``DeviceSelector("all")`` and reuses the spec of
    ``topology.topology_obj``. *topology* is the suite's fixture (defined
    outside this module).
    """
    from kernbench.runtime_api.context import RuntimeContext
    from kernbench.runtime_api.types import DeviceSelector
    from kernbench.sim_engine.engine import GraphEngine

    engine = GraphEngine(topology.topology_obj, enable_data=True)
    return RuntimeContext(
        engine=engine,
        target_device=DeviceSelector("all"),
        correlation_id="test_t6",
        spec=topology.topology_obj.spec,
    )


def _replicate_dp():
    """Return a ``DPPolicy`` with cube and PE both ``"replicate"`` (1 cube, 1 PE)."""
    from kernbench.policy.placement.dp import DPPolicy

    return DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)


def _copy_from_host(dst, values: np.ndarray, name: str) -> None:
    """Copy host array *values* into device tensor *dst* as fp16.

    Wraps *values* in a host-side ``Tensor`` (through its ``_host_buffer``)
    and calls ``dst.copy_`` on it. The host tensor is labelled *name*.
    """
    from kernbench.runtime_api.tensor import Tensor

    host = np.ascontiguousarray(values, dtype=np.float16)
    src = Tensor(shape=host.shape, dtype="f16", name=name)
    src._host_buffer = host
    dst.copy_(src)


def _assert_all_ranks_ran(results: dict[int, object] | set[int], ws: int) -> None:
    """Fail unless *results* holds exactly the ranks ``0..ws-1``.

    Guards against ``spawn`` returning without running some or all workers.
    In that case a loop over the per-rank results would check nothing.
    """
    ranks = sorted(results)
    assert ranks == list(range(ws)), (
        f"expected results from ranks 0..{ws - 1}; got ranks {ranks}"
    )


# ── T6.a: zero-weight smoke ──────────────────────────────────────────
def test_mlp_zero_weight_produces_zero_output(topology):
    """T6.a: zero weights → output ≈ 0 for every rank.

    Both weight shards are zeroed explicitly instead of relying on the
    layers' default initialization.
    """
    import kernbench.tp as tp

    with _make_ctx(topology) as ctx:
        ctx.distributed.init_process_group(backend="ahbm")
        ws = ctx.distributed.get_world_size()
        tp.initialize_model_parallel(ws)

        B, D_in, D_hidden, D_out = 1, 32, 32 * ws, 32
        d_local = D_hidden // ws  # per-rank width of the sharded hidden dim
        outputs: dict[int, np.ndarray] = {}

        def _worker(rank: int):
            ctx.ahbm.set_device(rank)
            fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=ctx)
            fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=ctx)
            # Per-rank shards: fc1 column slice (D_in, d_local), fc2 row slice (d_local, D_out).
            _copy_from_host(fc1.weight, np.zeros((D_in, d_local), dtype=np.float16), "host_w1")
            _copy_from_host(fc2.weight, np.zeros((d_local, D_out), dtype=np.float16), "host_w2")

            x = ctx.zeros((B, D_in), dtype="f16", dp=_replicate_dp(), name=f"t6a_x_r{rank}")
            _copy_from_host(x, np.full((B, D_in), 0.1, dtype=np.float16), "host_x")

            h = fc1.forward(x)
            y = fc2.forward(h)
            outputs[rank] = y.numpy()

        ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)

        _assert_all_ranks_ran(outputs, ws)
        for r, out in outputs.items():
            assert np.allclose(out, 0.0, atol=1e-2), (
                f"rank {r}: zero-weight output should be ~0; got mean={out.mean()}"
            )


# ── T6.b: deterministic weight + numerical check ─────────────────────
def test_mlp_deterministic_weight_matches_reference(topology):
    """T6.b: non-zero deterministic weights → output matches numpy reference."""
    import kernbench.tp as tp

    with _make_ctx(topology) as ctx:
        ctx.distributed.init_process_group(backend="ahbm")
        ws = ctx.distributed.get_world_size()
        tp.initialize_model_parallel(ws)

        B, D_in, D_hidden, D_out = 1, 16, 16 * ws, 16
        d_local = D_hidden // ws
        # W1 (D_in, D_hidden) — column-sharded; per rank: (D_in, d_local)
        # W2 (D_hidden, D_out) — row-sharded;   per rank: (d_local, D_out)
        # Dyadic constants: every product and partial sum is exact in fp16
        # (h = 16·0.5·0.25 = 2.0, y = 16·ws·2.0·0.125 = 4·ws). The result is
        # therefore independent of accumulation order, and the expected output
        # is large relative to the tolerance, so a wrong result cannot hide
        # inside atol.
        X_VAL, W1_VAL, W2_VAL = 0.5, 0.25, 0.125
        outputs: dict[int, np.ndarray] = {}

        def _worker(rank: int):
            ctx.ahbm.set_device(rank)
            fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=ctx)
            fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=ctx)

            _copy_from_host(
                fc1.weight, np.full((D_in, d_local), W1_VAL, dtype=np.float16), "host_w1"
            )
            _copy_from_host(
                fc2.weight, np.full((d_local, D_out), W2_VAL, dtype=np.float16), "host_w2"
            )

            # Input x (full-replicated constant)
            x = ctx.zeros((B, D_in), dtype="f16", dp=_replicate_dp(), name=f"t6b_x_r{rank}")
            _copy_from_host(x, np.full((B, D_in), X_VAL, dtype=np.float16), "host_x")

            h = fc1.forward(x)
            y = fc2.forward(h)
            outputs[rank] = y.numpy().astype(np.float32)

        ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)

        # Host reference: y = x @ W1_full @ W2_full
        w1_full = np.full((D_in, D_hidden), W1_VAL, dtype=np.float32)
        w2_full = np.full((D_hidden, D_out), W2_VAL, dtype=np.float32)
        x_full = np.full((B, D_in), X_VAL, dtype=np.float32)
        expected = x_full @ w1_full @ w2_full

        _assert_all_ranks_ran(outputs, ws)
        for r, out in outputs.items():
            assert out.shape == (B, D_out)
            assert np.allclose(out, expected, rtol=1e-2, atol=1e-2), (
                f"rank {r}: MLP output != reference "
                f"(got mean={out.mean():.4f}, expected={expected.mean():.4f})"
            )


# ── T6.c: rank-consistency after final all_reduce ────────────────────
def test_mlp_rank_consistency_after_all_reduce(topology):
    """T6.c: all ranks see elementwise-identical final output.

    fc2's weight shard depends on the rank, so each rank's pre-reduction
    partial sum is different. Only a working all-reduce makes the final
    outputs identical. With zero or uniform weights, every rank would agree
    trivially.
    """
    import kernbench.tp as tp

    with _make_ctx(topology) as ctx:
        ctx.distributed.init_process_group(backend="ahbm")
        ws = ctx.distributed.get_world_size()
        tp.initialize_model_parallel(ws)

        B, D_in, D_hidden, D_out = 1, 16, 16 * ws, 16
        d_local = D_hidden // ws
        X_VAL, W1_VAL = 0.5, 0.25

        def _w2_val(rank: int) -> float:
            # Dyadic, so every partial sum and the reduced sum are exact in
            # fp16: the reduction order cannot change the bits.
            return (rank + 1) / 64

        outputs: dict[int, np.ndarray] = {}

        def _worker(rank: int):
            ctx.ahbm.set_device(rank)
            fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=ctx)
            fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=ctx)

            _copy_from_host(
                fc1.weight, np.full((D_in, d_local), W1_VAL, dtype=np.float16), "host_w1"
            )
            _copy_from_host(
                fc2.weight,
                np.full((d_local, D_out), _w2_val(rank), dtype=np.float16),
                "host_w2",
            )

            x = ctx.zeros((B, D_in), dtype="f16", dp=_replicate_dp(), name=f"t6c_x_r{rank}")
            _copy_from_host(x, np.full((B, D_in), X_VAL, dtype=np.float16), "host_x")

            h = fc1.forward(x)
            y = fc2.forward(h)
            outputs[rank] = y.numpy()

        ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)

        _assert_all_ranks_ran(outputs, ws)
        ref = outputs[0]
        for r, out in outputs.items():
            assert np.array_equal(out, ref), (
                f"rank {r} output differs from rank 0 — all-reduce should "
                f"make every rank see the same final tensor"
            )

        # Reduced value = Σ_r h · d_local · w2(r), with h = D_in · X · W1.
        h_val = D_in * X_VAL * W1_VAL
        expected = sum(h_val * d_local * _w2_val(r) for r in range(ws))
        assert np.allclose(ref.astype(np.float32), expected, rtol=1e-2, atol=1e-2), (
            f"final output does not equal the sum of all ranks' partials "
            f"(got mean={float(ref.astype(np.float32).mean()):.4f}, expected={expected:.4f})"
        )


# ── T6.d: shape contract ─────────────────────────────────────────────
def test_mlp_shape_contract(topology):
    """T6.d: ColumnParallel → (B, D_hidden/ws); RowParallel → (B, D_out)."""
    import kernbench.tp as tp

    with _make_ctx(topology) as ctx:
        ctx.distributed.init_process_group(backend="ahbm")
        ws = ctx.distributed.get_world_size()
        tp.initialize_model_parallel(ws)

        B, D_in, D_hidden, D_out = 1, 16, 16 * ws, 16
        finished: set[int] = set()

        def _worker(rank: int):
            ctx.ahbm.set_device(rank)
            fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=ctx)
            fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=ctx)
            x = ctx.zeros((B, D_in), dtype="f16", dp=_replicate_dp(), name=f"t6d_x_r{rank}")
            h = fc1.forward(x)
            assert h.shape == (B, D_hidden // ws), (
                f"ColumnParallel output shape: {h.shape} != (B, D_hidden/ws)"
            )
            y = fc2.forward(h)
            assert y.shape == (B, D_out), (
                f"RowParallel output shape: {y.shape} != (B, D_out)"
            )
            # Reached only if both shape checks above passed for this rank.
            finished.add(rank)

        ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
        _assert_all_ranks_ran(finished, ws)


# ── liveness: no deadlock (checked indirectly via the pytest timeout) ──
def test_mlp_completes_without_deadlock(topology):
    """Structural: full E2E spawn returns within a reasonable wall-clock.

    Relies on the test suite's overall timeout harness. If this hangs beyond
    ~60s it would surface as a pytest timeout — a deadlock regression in the
    scheduler loop would manifest here. Every rank must also record that it
    finished its forward pass.
    """
    import kernbench.tp as tp

    with _make_ctx(topology) as ctx:
        ctx.distributed.init_process_group(backend="ahbm")
        ws = ctx.distributed.get_world_size()
        tp.initialize_model_parallel(ws)
        finished: set[int] = set()

        def _worker(rank: int):
            ctx.ahbm.set_device(rank)
            fc1 = tp.ColumnParallelLinear(16, 16 * ws, torch=ctx)
            fc2 = tp.RowParallelLinear(16 * ws, 16, torch=ctx)
            x = ctx.zeros((1, 16), dtype="f16", dp=_replicate_dp(), name=f"t6live_r{rank}")
            h = fc1.forward(x)
            y = fc2.forward(h)
            _ = y.numpy()
            finished.add(rank)

        ctx.multiprocessing.spawn(_worker, args=(), nprocs=ws)
        _assert_all_ranks_ran(finished, ws)