# ADR-0022: 2D 그리드 program_id 시맨틱 ## Status Accepted ## Context Triton 커널은 `tl.program_id(axis)`를 사용해 launch 그리드 내 자신의 위치를 식별한다. 본 하드웨어는 2단계 계층을 갖는다: **큐브**가 **PE**를 포함한다. 이전 구현은 `axis` 파라미터를 무시하고 항상 평탄화된 PE 인덱스를 반환했기 때문에, 커널이 큐브 내부 위치와 큐브 식별자를 구분할 수 없었다. ## Decision `tl.program_id`와 `tl.num_programs`를 2D 하드웨어 그리드에 매핑한다: | Call | Returns | Description | |------|---------|-------------| | `tl.program_id(axis=0)` | `local_pe_id` | 큐브 내 PE 인덱스 | | `tl.program_id(axis=1)` | `cube_id` | 큐브 인덱스 | | `tl.num_programs(axis=0)` | `num_pes_per_cube` | 큐브당 PE 개수 | | `tl.num_programs(axis=1)` | `num_cubes` | 전체 큐브 개수 | 전역 PID는 다음과 같이 도출된다: ```python global_pid = tl.program_id(axis=1) * tl.num_programs(axis=0) + tl.program_id(axis=0) ``` ### 축 매핑 근거 - **axis=0 = PE (최내부)**: 큐브 내부 PE들은 HBM을 공유하고 로컬 NoC 메시를 통해 통신한다. 빠르고 강하게 결합된 차원이다 — 블록 내부의 스레드와 유사하다. - **axis=1 = 큐브 (외부)**: 큐브 간 통신은 더 높은 레이턴시의 UCIe를 통한다. 더 거친 스케줄링 차원이다 — 그리드 내의 블록과 유사하다. ## Implementation ### TLContext (`triton_emu/tl_context.py`) `cube_id`와 `num_cubes` 생성자 파라미터를 추가했다. `program_id()`와 `num_programs()`가 `axis`로 디스패치한다: ```python def program_id(self, axis: int = 0) -> int: if axis == 1: return self._cube_id return self._pe_id def num_programs(self, axis: int = 0) -> int: if axis == 1: return self._num_cubes return self._num_programs ``` ### PE_CPU (`components/builtin/pe_cpu.py`) - `ctx.spec["system"]["sips"]["cubes_per_sip"]`에서 `num_cubes`를 추출한다. - `cube_id`(이미 `self._cube_idx`로 사용 가능)와 `num_cubes`를 TLContext에 전달한다. ### KernelRunner (`triton_emu/kernel_runner.py`) - PE_CPU로부터 `num_cubes`를 수신한다. - greenlet 모드에서 `cube_id`와 `num_cubes`를 TLContext에 전달한다. ## Backward Compatibility - `tl.program_id(0)` 또는 `tl.program_id()`를 사용하는 기존 코드는 변경되지 않는다 — 이전과 동일한 PE 인덱스를 반환한다. - `cube_id`와 `num_cubes`는 기본값이 `0`과 `1`이므로, 이를 제공하지 않는 호출자(예: 유닛 테스트)도 계속 동작한다. ## Usage Example ```python def sharded_gemm_kernel(a_ptr, b_ptr, out_ptr, M, K, N, tl): local_pid = tl.program_id(axis=0) # PE within cube cube_id = tl.program_id(axis=1) # which cube global_pid = cube_id * tl.num_programs(axis=0) + local_pid # 전역 PID에 걸친 column-wise 샤딩 n_per_pid = N // (tl.num_programs(axis=1) * tl.num_programs(axis=0)) col_start = global_pid * n_per_pid a = tl.load(a_ptr, shape=(M, K), dtype="f16") b = tl.ref(b_ptr + col_start * K * 2, shape=(K, n_per_pid), dtype="f16") h = tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr + col_start * M * 2) tl.wait(h) ``` ## Consequences - 벤치마크가 토폴로지 차원을 하드코딩하지 않고 큐브 인식 샤딩과 주소 지정을 표현할 수 있다. - 필요 시 axis=2(SIP 레벨)를 동일한 패턴을 따라 향후 추가할 수 있다.