Files
kernbench2/docs/adr/ADR-0027-megatron-tp.md
T
ywkang e1084800ab docs: add ADRs 0024–0031 for SIP-TP launcher stack
ADR-0024 (SIP-level TP launcher): rank = SIP abstraction, engine-routed
  install, mp.spawn parity, epoch barrier, ShardSpec structural coords.
ADR-0025 (IPCQ direction addressing): address-based matching for meta
  arrival and credit return; fixes 2-rank bidirectional ring deadlock.
ADR-0026 (DPPolicy intra-device only): remove sip/num_sips fields;
  ShardSpec uses structural (sip, cube, pe); pe_index property removed.
ADR-0027 (Megatron-style TP API): ColumnParallelLinear / RowParallelLinear
  on top of ADR-0024 launcher. Backlog until 0024/0025/0026 land.
ADR-0028 (DTensor support): stub / future work.
ADR-0029 (Hierarchical all-reduce): 3-level reduce using all_pes mapper
  and multi_pe_sip_local validator from ADR-0024. Backlog.
ADR-0030 (IPCQ PhysAddr integration): blocked on ADR-0031.
ADR-0031 (PhysAddr PE-resource extension): stub; local_offset range-based
  partition approach; specific ranges TBD.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 00:38:27 -07:00

12 KiB
Raw Blame History

ADR-0027: Megatron-style Tensor Parallelism API

Status

Proposed

Context

목표

SIP 간 tensor parallelism(TP)을 Megatron-LM 스타일의 명시적 parallel layer API로 지원한다. DTensor 같은 선언적 추상화는 별도 ADR(0028) future work.

Megatron-style을 선택한 이유 (사용자 지시):

  • TP는 일반적으로 model의 특정 layer 경계에서 발생. 명시적 primitive가 mental model에 자연스러움.
  • NVIDIA Megatron / DeepSpeed가 확립한 인더스트리 표준 패턴.
  • DTensor는 선언적이라 디자인 공간이 더 크다 → 단계적으로 접근.

현재 상태

  • KernBench는 TP가 없음. 기존 DPPolicy.sip="column_wise" 경로가 "SIP 간 column sharding"을 흉내 냈으나 DP와 TP가 섞인 상태 (ADR-0026에서 정리).
  • ADR-0024가 launcher 인프라 (rank = SIP, set_device, greenlet-local) 제공.
  • 이 인프라 위에 TP primitive를 얹는다.

TP primitive 스펙 (Megatron-LM 참조)

  • ColumnParallelLinear: weight의 column 축을 TP ranks에 분산. 입력 full-replicated, 출력 column-sharded. 후속에 row-parallel이 올 때 all-reduce 없음.
  • RowParallelLinear: weight의 row 축을 TP ranks에 분산. 입력이 이미 column-sharded (ColumnParallel의 출력). forward 끝에 all-reduce 필요.
  • VocabParallelEmbedding: embedding을 vocab 축에 분산. forward 끝에 all-reduce.
  • copy_to_tp_region, reduce_from_tp_region, scatter_to_tp_region, gather_from_tp_region — 기본 primitive (identity forward, all-reduce backward 등).

풀어야 할 문제

  1. Per-rank weight 분산 표현: 각 worker(rank)가 weight tensor의 자기 slice를 소유. ADR-0024의 set_device(rank) + ADR-0026의 intra-device DPPolicy로 자연스러운 표현.

  2. Forward / backward activation 흐름: 현재 KernBench는 backward가 없음 (simulation 목적). 본 ADR은 forward만 우선 지원. Training simulation이 추가되면 확장.

  3. Collective 호출 지점: RowParallelLinear가 forward 끝에 all_reduce를 호출. ADR-0024의 multi-greenlet 구조에서 자연스럽게 동작 (각 rank가 동시에 호출).

  4. TP group 개념: Megatron은 일반적으로 data_parallel × tensor_parallel × pipeline_parallel group을 교차 사용. 초기 scope는 TP group = 전체 SIP로 단순화. Mixed DP+TP는 future.


Decision

D1. 새 패키지 kernbench.tp

src/kernbench/tp/
    __init__.py          — public API re-exports
    parallel_state.py    — TP group 관리 (현재 단일 global group)
    layers.py            — ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
    primitives.py        — copy/reduce/scatter/gather_to/from_tp_region
    mappings.py          — identity/all_reduce forward, all_reduce/identity backward (stub)

D2. parallel_state — TP group

# parallel_state.py
_TP_WORLD_SIZE = None
_TP_RANK = None   # greenlet-local via dist.get_rank()

def initialize_model_parallel(tensor_model_parallel_size: int) -> None:
    """Initialize TP group. Must be called after dist.init_process_group."""
    global _TP_WORLD_SIZE
    from kernbench.runtime_api.distributed import get_dist
    dist = get_dist()
    total = dist.get_world_size()
    if tensor_model_parallel_size != total:
        raise NotImplementedError(
            "Only TP == world_size supported in initial scope"
        )
    _TP_WORLD_SIZE = tensor_model_parallel_size

def get_tensor_model_parallel_world_size() -> int:
    return _TP_WORLD_SIZE

def get_tensor_model_parallel_rank() -> int:
    from kernbench.runtime_api.distributed import get_dist
    return get_dist().get_rank()  # ADR-0024의 greenlet-local rank

초기 scope: TP 사이즈 = world_size = topology SIP 수. Pure TP 모델.

D3. ColumnParallelLinear

# layers.py
class ColumnParallelLinear:
    """Weight의 K(out_features) 축을 TP rank에 분산.

    forward(x):
        x: (M, N) — full-replicated across ranks
        W_k: (N, K / world_size) — rank-local slice
        y_k = x @ W_k → (M, K / world_size) — rank-local output

    출력은 sharded. 후속 RowParallelLinear가 기대하는 입력 형태.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 dtype: str = "f16", torch=None):
        ws = get_tensor_model_parallel_world_size()
        assert out_features % ws == 0
        k_local = out_features // ws
        # 각 rank가 자기 slice 소유 (ADR-0024 set_device + ADR-0026 DPPolicy)
        self.weight = torch.zeros(
            (in_features, k_local), dtype=dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="col_parallel_w",
        )
        # init with something sensible — TODO
        if bias:
            self.bias = torch.zeros((k_local,), ...)
        else:
            self.bias = None

    def forward(self, x):
        # x는 full-replicated (caller 보장). 단순 local matmul.
        y = torch.matmul(x, self.weight)
        if self.bias is not None:
            y = y + self.bias
        return y

D4. RowParallelLinear

class RowParallelLinear:
    """Weight의 N(in_features) 축을 TP rank에 분산.

    forward(x):
        x: (M, N / world_size) — rank-local slice (ColumnParallel의 출력)
        W_k: (N / world_size, K) — rank-local slice
        y_k = x @ W_k → (M, K) — partial sum on each rank
        y = all_reduce(y_k, op="sum") → (M, K) on every rank
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 dtype: str = "f16", torch=None):
        ws = get_tensor_model_parallel_world_size()
        assert in_features % ws == 0
        n_local = in_features // ws
        self.weight = torch.zeros(
            (n_local, out_features), dtype=dtype,
            dp=DPPolicy(cube="column_wise", pe="column_wise"),
            name="row_parallel_w",
        )
        # bias는 rank 0에만 (Megatron convention)
        self.bias = torch.zeros(...) if bias else None
        self._torch = torch

    def forward(self, x):
        y_partial = torch.matmul(x, self.weight)
        # Final all-reduce sums partial products across ranks
        self._torch.distributed.all_reduce(y_partial, op="sum")
        if self.bias is not None:
            # bias는 reduce 이후에만 추가 (rank 0 보유)
            rank = get_tensor_model_parallel_rank()
            if rank == 0:
                y_partial = y_partial + self.bias
        return y_partial

D5. Primitive 함수

# primitives.py
def copy_to_tp_region(x):
    """Forward: identity. Backward: all-reduce. (Training 추가 시 구현)."""
    return x

def reduce_from_tp_region(x):
    """Forward: all-reduce. Backward: identity."""
    torch.distributed.all_reduce(x, op="sum")
    return x

def scatter_to_tp_region(x):
    """x를 K 축으로 scatter. Forward: split. Backward: all-gather."""
    # 초기 scope에서는 사용자가 이미 sharded tensor를 만들었다고 가정 →
    # no-op 또는 metadata 추가
    raise NotImplementedError("Phase 2 feature")

def gather_from_tp_region(x):
    """x를 K 축으로 all-gather. Forward: all-gather. Backward: split."""
    raise NotImplementedError("all-gather kernel이 먼저 필요 (future)")

D6. 샘플 bench — 2-layer MLP with TP

# benches/tp_mlp.py (새 파일)
def worker(rank, world_size, torch):
    torch.cuda.set_device(rank)
    tp.initialize_model_parallel(world_size)

    B, D_in, D_hidden, D_out = 1, 512, 2048, 512
    fc1 = tp.ColumnParallelLinear(D_in, D_hidden, torch=torch)
    fc2 = tp.RowParallelLinear(D_hidden, D_out, torch=torch)

    x = torch.zeros((B, D_in), dtype="f16",
                    dp=DPPolicy(cube="column_wise", pe="column_wise"),
                    name="x")
    # ... init x ...

    h = fc1.forward(x)      # column-sharded output
    y = fc2.forward(h)      # all-reduced output, full on every rank

    # verify...

def run(torch):
    torch.distributed.init_process_group(backend="ahbm")
    torch.distributed.spawn(worker, nprocs=torch.distributed.get_world_size(),
                            args=(...,))

D7. Non-functional — training 미지원

본 ADR은 inference/forward only. Backward / gradient / optimizer는 future. 기존 KernBench가 training이 아니므로 자연스러움.

D8. 초기 scope 제약

  • TP 사이즈 = world_size (mixed DP+TP 없음)
  • scatter_to_tp_region, gather_from_tp_region은 unimplemented (별도 kernel 필요)
  • Weight init은 단순 zero (적절한 init은 future)
  • Pipeline parallelism은 scope 밖

D9. distributed.all_reduce 기반

RowParallelLinear의 모든 collective는 ADR-0024의 dist.all_reduce를 사용. 별도 TP-전용 collective 엔진 불필요.


Dependencies

  • ADR-0024 (launcher): rank = SIP, greenlet-local rank, dist.all_reduce, torch.cuda.set_device(rank), spawn_workers 제공.
  • ADR-0026 (DPPolicy intra-device): weight tensor의 per-rank slice 표현.
  • ADR-0023 / ADR-0025 (IPCQ): dist.all_reduce 구현의 기반.

Non-goals

  • Backward pass / training: inference only. Training simulation은 별도 ADR.
  • Mixed parallelism (DP + TP + PP): 초기엔 pure TP만.
  • Weight init schemes: 단순 zero / debug pattern. 실제 training init는 future.
  • Fused ops: Megatron의 fused matmul+bias+gelu 등은 KernBench kernel 수준 문제. 본 ADR은 host-side API만.
  • DTensor 통합: ADR-0028 future.

Open questions

  • initialize_model_parallel 위치: kernbench.tp.initialize_model_parallel vs torch.distributed.init_tp(...) 확장. real PyTorch는 torch.distributed. init_device_mesh 등을 권장. 우리는 당분간 TP-전용 모듈.
  • Weight의 DP 전략: 본 ADR은 DPPolicy(cube="column_wise", pe="column_wise") 를 가정. Intra-SIP DP를 다르게 주면? 성능 벤치마크로 결정.
  • Bias 배치 정책: Megatron은 bias를 split하지 않음. RowParallelLinear는 rank 0에만. 이게 항상 맞는가? 대안: replicate across ranks.
  • VocabParallelEmbedding: 처음 몇 벤치엔 불필요할 수도. 샘플 구현은 넣되 scope에서 제외할 수도.

Test strategy

T1. Unit — tests/test_tp_layers.py (신규)

  • ColumnParallelLinear forward: rank별 weight slice, 출력이 (M, K / ws).
  • RowParallelLinear forward: 입력이 sharded, all_reduce 후 (M, K) 일치.
  • VocabParallelEmbedding forward (if implemented).
  • parallel_state 초기화 / rank 조회.

T2. E2E — tests/test_tp_mlp.py (신규)

  • 2-layer MLP (ColumnParallel → RowParallel) forward가 single-driver reference 와 일치 (numerical check, rtol/atol).
  • ws = SIP count (current: 2).

T3. 회귀

  • ADR-0024의 test_ccl_allreduce_matrix 그대로 통과 (TP가 호출하는 dist.all_reduce의 기반).

Consequences

Positive

  • Megatron 코드 이식 용이: real training code와 API 일치.
  • TP 벤치마크 가능: scaling, communication-compute overlap 등 HW 특성 연구.
  • DPPolicy 의미 명확화 (ADR-0026과 시너지).

Negative

  • 새 모듈 (kernbench.tp) 유지보수 비용.
  • 초기 scope가 제한적 (pure TP only).

Neutral

  • ADR-0024/0026 기반 위에 순수한 상위 레이어 추가. Hardware simulation stack에 영향 없음.

Affected files

File Change
src/kernbench/tp/__init__.py 신규: public API re-export
src/kernbench/tp/parallel_state.py 신규: D2
src/kernbench/tp/layers.py 신규: D3/D4
src/kernbench/tp/primitives.py 신규: D5
src/kernbench/tp/mappings.py 신규 stub (backward TODO)
benches/tp_mlp.py 신규: D6 샘플
tests/test_tp_layers.py 신규: T1
tests/test_tp_mlp.py 신규: T2
docs/tp-author-guide.md 신규 (선택): 사용자 가이드