# ADR-0027: Megatron-style Tensor Parallelism API ## Status Proposed ## Context ### 목표 SIP 간 tensor parallelism(TP)을 **Megatron-LM 스타일의 명시적 parallel layer** API로 지원한다. DTensor 같은 선언적 추상화는 별도 ADR(0028) future work. Megatron-style을 선택한 이유 (사용자 지시): - TP는 일반적으로 **model의 특정 layer 경계**에서 발생. 명시적 primitive가 mental model에 자연스러움. - NVIDIA Megatron / DeepSpeed가 확립한 인더스트리 표준 패턴. - DTensor는 선언적이라 **디자인 공간이 더 크다** → 단계적으로 접근. ### 현재 상태 - KernBench는 TP가 없음. 기존 `DPPolicy.sip="column_wise"` 경로가 "SIP 간 column sharding"을 흉내 냈으나 DP와 TP가 섞인 상태 (ADR-0026에서 정리). - ADR-0024가 launcher 인프라 (rank = SIP, `set_device`, greenlet-local) 제공. - 이 인프라 위에 TP primitive를 얹는다. ### TP primitive 스펙 (Megatron-LM 참조) - **ColumnParallelLinear**: weight의 **column** 축을 TP ranks에 분산. 입력 full-replicated, 출력 column-sharded. 후속에 row-parallel이 올 때 all-reduce 없음. - **RowParallelLinear**: weight의 **row** 축을 TP ranks에 분산. 입력이 이미 column-sharded (ColumnParallel의 출력). forward 끝에 **all-reduce** 필요. - **VocabParallelEmbedding**: embedding을 vocab 축에 분산. forward 끝에 all-reduce. - **`copy_to_tp_region`**, **`reduce_from_tp_region`**, **`scatter_to_tp_region`**, **`gather_from_tp_region`** — 기본 primitive (`identity` forward, `all-reduce` backward 등). ### 풀어야 할 문제 1. **Per-rank weight 분산 표현**: 각 worker(rank)가 weight tensor의 자기 slice를 소유. ADR-0024의 `set_device(rank)` + ADR-0026의 intra-device DPPolicy로 자연스러운 표현. 2. **Forward / backward activation 흐름**: 현재 KernBench는 backward가 없음 (simulation 목적). 본 ADR은 **forward만** 우선 지원. Training simulation이 추가되면 확장. 3. **Collective 호출 지점**: RowParallelLinear가 forward 끝에 `all_reduce`를 호출. ADR-0024의 multi-greenlet 구조에서 자연스럽게 동작 (각 rank가 동시에 호출). 4. **TP group 개념**: Megatron은 일반적으로 data_parallel × tensor_parallel × pipeline_parallel group을 교차 사용. 초기 scope는 **TP group = 전체 SIP**로 단순화. Mixed DP+TP는 future. --- ## Decision ### D1. 새 패키지 `kernbench.tp` ``` src/kernbench/tp/ __init__.py — public API re-exports parallel_state.py — TP group 관리 (현재 단일 global group) layers.py — ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding primitives.py — copy/reduce/scatter/gather_to/from_tp_region mappings.py — identity/all_reduce forward, all_reduce/identity backward (stub) ``` ### D2. `parallel_state` — TP group ```python # parallel_state.py _TP_WORLD_SIZE = None _TP_RANK = None # greenlet-local via dist.get_rank() def initialize_model_parallel(tensor_model_parallel_size: int) -> None: """Initialize TP group. Must be called after dist.init_process_group.""" global _TP_WORLD_SIZE from kernbench.runtime_api.distributed import get_dist dist = get_dist() total = dist.get_world_size() if tensor_model_parallel_size != total: raise NotImplementedError( "Only TP == world_size supported in initial scope" ) _TP_WORLD_SIZE = tensor_model_parallel_size def get_tensor_model_parallel_world_size() -> int: return _TP_WORLD_SIZE def get_tensor_model_parallel_rank() -> int: from kernbench.runtime_api.distributed import get_dist return get_dist().get_rank() # ADR-0024의 greenlet-local rank ``` 초기 scope: **TP 사이즈 = world_size = topology SIP 수**. Pure TP 모델. ### D3. `ColumnParallelLinear` ```python # layers.py class ColumnParallelLinear: """Weight의 K(out_features) 축을 TP rank에 분산. forward(x): x: (M, N) — full-replicated across ranks W_k: (N, K / world_size) — rank-local slice y_k = x @ W_k → (M, K / world_size) — rank-local output 출력은 sharded. 후속 RowParallelLinear가 기대하는 입력 형태. """ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype: str = "f16", torch=None): ws = get_tensor_model_parallel_world_size() assert out_features % ws == 0 k_local = out_features // ws # 각 rank가 자기 slice 소유 (ADR-0024 set_device + ADR-0026 DPPolicy) self.weight = torch.zeros( (in_features, k_local), dtype=dtype, dp=DPPolicy(cube="column_wise", pe="column_wise"), name="col_parallel_w", ) # init with something sensible — TODO if bias: self.bias = torch.zeros((k_local,), ...) else: self.bias = None def forward(self, x): # x는 full-replicated (caller 보장). 단순 local matmul. y = torch.matmul(x, self.weight) if self.bias is not None: y = y + self.bias return y ``` ### D4. `RowParallelLinear` ```python class RowParallelLinear: """Weight의 N(in_features) 축을 TP rank에 분산. forward(x): x: (M, N / world_size) — rank-local slice (ColumnParallel의 출력) W_k: (N / world_size, K) — rank-local slice y_k = x @ W_k → (M, K) — partial sum on each rank y = all_reduce(y_k, op="sum") → (M, K) on every rank """ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype: str = "f16", torch=None): ws = get_tensor_model_parallel_world_size() assert in_features % ws == 0 n_local = in_features // ws self.weight = torch.zeros( (n_local, out_features), dtype=dtype, dp=DPPolicy(cube="column_wise", pe="column_wise"), name="row_parallel_w", ) # bias는 rank 0에만 (Megatron convention) self.bias = torch.zeros(...) if bias else None self._torch = torch def forward(self, x): y_partial = torch.matmul(x, self.weight) # Final all-reduce sums partial products across ranks self._torch.distributed.all_reduce(y_partial, op="sum") if self.bias is not None: # bias는 reduce 이후에만 추가 (rank 0 보유) rank = get_tensor_model_parallel_rank() if rank == 0: y_partial = y_partial + self.bias return y_partial ``` ### D5. Primitive 함수 ```python # primitives.py def copy_to_tp_region(x): """Forward: identity. Backward: all-reduce. (Training 추가 시 구현).""" return x def reduce_from_tp_region(x): """Forward: all-reduce. Backward: identity.""" torch.distributed.all_reduce(x, op="sum") return x def scatter_to_tp_region(x): """x를 K 축으로 scatter. Forward: split. Backward: all-gather.""" # 초기 scope에서는 사용자가 이미 sharded tensor를 만들었다고 가정 → # no-op 또는 metadata 추가 raise NotImplementedError("Phase 2 feature") def gather_from_tp_region(x): """x를 K 축으로 all-gather. Forward: all-gather. Backward: split.""" raise NotImplementedError("all-gather kernel이 먼저 필요 (future)") ``` ### D6. 샘플 bench — 2-layer MLP with TP ```python # benches/tp_mlp.py (새 파일) def worker(rank, world_size, torch): torch.cuda.set_device(rank) tp.initialize_model_parallel(world_size) B, D_in, D_hidden, D_out = 1, 512, 2048, 512 fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=torch) fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=torch) x = torch.zeros((B, D_in), dtype="f16", dp=DPPolicy(cube="column_wise", pe="column_wise"), name="x") # ... init x ... h = fc1.forward(x) # column-sharded output y = fc2.forward(h) # all-reduced output, full on every rank # verify... def run(torch): torch.distributed.init_process_group(backend="ahbm") torch.distributed.spawn(worker, nprocs=torch.distributed.get_world_size(), args=(...,)) ``` ### D7. Non-functional — training 미지원 본 ADR은 **inference/forward only**. Backward / gradient / optimizer는 future. 기존 KernBench가 training이 아니므로 자연스러움. ### D8. 초기 scope 제약 - TP 사이즈 = world_size (mixed DP+TP 없음) - `scatter_to_tp_region`, `gather_from_tp_region`은 unimplemented (별도 kernel 필요) - Weight init은 단순 zero (적절한 init은 future) - Pipeline parallelism은 scope 밖 ### D9. `distributed.all_reduce` 기반 RowParallelLinear의 모든 collective는 ADR-0024의 `dist.all_reduce`를 사용. 별도 TP-전용 collective 엔진 불필요. --- ## Dependencies - **ADR-0024** (launcher): rank = SIP, greenlet-local rank, `dist.all_reduce`, `torch.cuda.set_device(rank)`, `spawn_workers` 제공. - **ADR-0026** (DPPolicy intra-device): weight tensor의 per-rank slice 표현. - **ADR-0023 / ADR-0025** (IPCQ): `dist.all_reduce` 구현의 기반. --- ## Non-goals - **Backward pass / training**: inference only. Training simulation은 별도 ADR. - **Mixed parallelism (DP + TP + PP)**: 초기엔 pure TP만. - **Weight init schemes**: 단순 zero / debug pattern. 실제 training init는 future. - **Fused ops**: Megatron의 fused matmul+bias+gelu 등은 KernBench kernel 수준 문제. 본 ADR은 host-side API만. - **DTensor 통합**: ADR-0028 future. --- ## Open questions - **`initialize_model_parallel` 위치**: `kernbench.tp.initialize_model_parallel` vs `torch.distributed.init_tp(...)` 확장. real PyTorch는 `torch.distributed. init_device_mesh` 등을 권장. 우리는 당분간 TP-전용 모듈. - **Weight의 DP 전략**: 본 ADR은 `DPPolicy(cube="column_wise", pe="column_wise")` 를 가정. Intra-SIP DP를 다르게 주면? 성능 벤치마크로 결정. - **Bias 배치 정책**: Megatron은 bias를 split하지 않음. RowParallelLinear는 rank 0에만. 이게 항상 맞는가? 대안: replicate across ranks. - **`VocabParallelEmbedding`**: 처음 몇 벤치엔 불필요할 수도. 샘플 구현은 넣되 scope에서 제외할 수도. --- ## Test strategy ### T1. Unit — `tests/test_tp_layers.py` (신규) - `ColumnParallelLinear` forward: rank별 weight slice, 출력이 `(M, K / ws)`. - `RowParallelLinear` forward: 입력이 sharded, all_reduce 후 `(M, K)` 일치. - `VocabParallelEmbedding` forward (if implemented). - `parallel_state` 초기화 / rank 조회. ### T2. E2E — `tests/test_tp_mlp.py` (신규) - 2-layer MLP (ColumnParallel → RowParallel) forward가 single-driver reference 와 일치 (numerical check, rtol/atol). - ws = SIP count (current: 2). ### T3. 회귀 - ADR-0024의 `test_ccl_allreduce_matrix` 그대로 통과 (TP가 호출하는 `dist.all_reduce`의 기반). --- ## Consequences ### Positive - **Megatron 코드 이식 용이**: real training code와 API 일치. - **TP 벤치마크 가능**: scaling, communication-compute overlap 등 HW 특성 연구. - **DPPolicy 의미 명확화** (ADR-0026과 시너지). ### Negative - 새 모듈 (`kernbench.tp`) 유지보수 비용. - 초기 scope가 제한적 (pure TP only). ### Neutral - ADR-0024/0026 기반 위에 순수한 상위 레이어 추가. Hardware simulation stack에 영향 없음. --- ## Affected files | File | Change | |------|--------| | `src/kernbench/tp/__init__.py` | 신규: public API re-export | | `src/kernbench/tp/parallel_state.py` | 신규: D2 | | `src/kernbench/tp/layers.py` | 신규: D3/D4 | | `src/kernbench/tp/primitives.py` | 신규: D5 | | `src/kernbench/tp/mappings.py` | 신규 stub (backward TODO) | | `benches/tp_mlp.py` | 신규: D6 샘플 | | `tests/test_tp_layers.py` | 신규: T1 | | `tests/test_tp_mlp.py` | 신규: T2 | | `docs/tp-author-guide.md` | 신규 (선택): 사용자 가이드 |