ywkang 105f1dc09e ADR-0027: Megatron TP API + worker-wait generalization + mp.spawn
Implements ADR-0027 Phase 2 end-to-end. All 559 tests pass (was 523 +
1 xfail; ring_default_ws strict-xfail is now resolved).

D0 — Worker-wait generalization (context.py):
- _pending_worker_waits queue on RuntimeContext.
- ctx.wait(h) in worker context defers to main via g.parent.switch().
  Fast-path for already-completed handles.
- Worker API is unchanged: tensor deploy, launch, etc. still look
  synchronous; they're transparently cooperatively scheduled.
- Solves ADR-0024 Phase B kernel-greenlet orphan bug (env.run now
  only ever drives from main; kernel _parent is always main).
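
The deferral pattern in D0 can be sketched with plain Python generators
standing in for greenlets. This is a swapped-in technique so the example
needs only the standard library; the commit itself uses
g.parent.switch(). Handle, worker, and run_main are hypothetical names,
not the kernbench API.

```python
from collections import deque


class Handle:
    """Hypothetical completion handle; `done` flips once main has driven it."""

    def __init__(self, name):
        self.name = name
        self.done = False


def worker(rank, log):
    # The worker "waits" by yielding the handle to the main loop, so the
    # body still reads top to bottom like synchronous code.
    h = Handle(f"deploy-{rank}")
    yield h
    log.append((rank, "deployed"))

    h2 = Handle(f"launch-{rank}")
    h2.done = True            # pretend this one completed immediately
    if not h2.done:           # fast path: skip the yield for completed handles
        yield h2
    log.append((rank, "launched"))


def run_main(workers):
    """Drive every worker from the main loop only (no worker drives another)."""
    log = []
    pending = deque()         # stands in for _pending_worker_waits
    for rank, make in enumerate(workers):
        gen = make(rank, log)
        try:
            pending.append((gen, next(gen)))
        except StopIteration:
            pass
    while pending:
        gen, handle = pending.popleft()
        handle.done = True    # main completes the handle, then resumes the worker
        try:
            pending.append((gen, gen.send(None)))
        except StopIteration:
            pass
    return log
```

Because only run_main ever resumes a worker, every suspended worker has
the main loop as its resumer, which is the invariant D0 relies on.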

D0.5 — Host-read barrier (tensor.py):
- Explicit _HOST_READ_BARRIERS registry (T5.g closed-set via code
  review, not reflection-magic).
- numpy/data/__getitem__/__repr__ drain pending worker-waits before
  host-observable read.
- copy_: source-side barrier via source.numpy(). A target-side write
  barrier is intentionally NOT applied: a global pending-target barrier
  would prematurely drain cross-rank collectives and deadlock.
- Collective pending is excluded from barrier drain condition
  (collective is cross-rank; its own yield in all_reduce covers the
  invariant naturally).

D1 — torch.multiprocessing.spawn (runtime_api/multiprocessing.py):
- API signature parity with real PyTorch spawn; execution uses a
  cooperative greenlet scheduler (process isolation etc. are explicit
  non-goals per D1.0).
- _drain_pending drains worker-waits then collectives in one barrier,
  loop-until-empty.
- Round-based exception handling with SystemExit sibling abort +
  SpawnException(errors) wrapping root-cause ranks.
- RuntimeContext attaches ctx.multiprocessing in __post_init__.
- benches/ccl_allreduce.py hand-rolled loop collapses to one
  torch.multiprocessing.spawn call.
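
A rough sketch of the spawn calling convention, assuming only what the
real torch.multiprocessing.spawn guarantees: fn is called as
fn(rank, *args) for each rank in range(nprocs). This version runs ranks
one after another instead of under a greenlet scheduler, and the shape
of SpawnException here (a dict of rank to exception) is a guess, not
the kernbench class.

```python
class SpawnException(Exception):
    """Hypothetical wrapper that carries the failing ranks and their errors."""

    def __init__(self, errors):
        self.errors = dict(errors)
        ranks = ", ".join(str(r) for r in sorted(self.errors))
        super().__init__(f"rank(s) {ranks} failed")


def spawn(fn, args=(), nprocs=1):
    """Call fn(rank, *args) for every rank, then report all failures at once."""
    errors = {}
    for rank in range(nprocs):
        try:
            fn(rank, *args)
        except Exception as exc:   # record the root cause per rank
            errors[rank] = exc
    if errors:
        raise SpawnException(errors)
```

With this shape a bench entry point reduces to one call such as
spawn(worker_fn, args=(world_size,), nprocs=world_size), where
worker_fn is whatever per-rank function the bench defines.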

D2–D6 — kernbench.tp package:
- parallel_state: initialize_model_parallel, get_*_rank,
  get_*_world_size, with weak active-ctx registry in context.py.
- layers: ColumnParallelLinear, RowParallelLinear (shape-only
  primitives — fp16 gemm via tl.load + tl.dot + tl.store).
- kernels: _gemm_kernel used by TP layers (self-contained; no bench
  dependency).
- primitives / mappings stubs per D6/D8.
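
The sharding arithmetic behind the two TP layers can be checked with a
small pure-Python sketch. It illustrates the standard Megatron-style
split only; it does not use the kernbench.tp classes, and the helper
names are hypothetical.

```python
def matmul(a, b):
    """Plain list-of-lists matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def split_cols(m, parts):
    step = len(m[0]) // parts
    return [[row[i * step:(i + 1) * step] for row in m] for i in range(parts)]


def split_rows(m, parts):
    step = len(m) // parts
    return [m[i * step:(i + 1) * step] for i in range(parts)]


def column_parallel(x, w, world_size):
    # Each rank holds a column slice of W; outputs are concatenated.
    shards = [matmul(x, w_r) for w_r in split_cols(w, world_size)]
    return [sum((shard[i] for shard in shards), []) for i in range(len(x))]


def row_parallel(x, w, world_size):
    # Each rank holds a row slice of W and the matching column slice of X;
    # the partial products are summed, which is what all-reduce provides.
    partials = [
        matmul(x_r, w_r)
        for x_r, w_r in zip(split_cols(x, world_size), split_rows(w, world_size))
    ]
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
```

Both functions should agree with an unsharded matmul(x, w) whenever the
split dimension is divisible by world_size.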

Data-path fixes (surfaced by TP gemm + all_reduce sequence):
- sim_engine/op_log.py: dma_write snapshot is skipped for TCM
  sources. PE scratch is repopulated by Phase 2 math/gemm replay, so a
  Phase-1-time snapshot captured STALE data from a prior kernel's
  output aliased at the same scratch addr, and the later kernel's
  dma_write then overwrote the Phase 2 result with that stale value.
- sim_engine/op_log.py + sim_engine/data_executor.py: per-operand
  space recorded on GemmCmd and composite gemm records so HBM-resident
  operands (tl.load output) don't default to TCM during replay.
- runtime_api/context.py: ctx.zeros writes zero-init to MemoryStore
  at VA keys so kernels reading via VA see deterministic init even
  without explicit copy_().

Tests (Phase 1 + Phase 2):
- test_worker_wait_drain (T3): orphan invariant + resume + multi-rank
  drain + idempotency + exception propagation.
- test_mp_spawn (T4): spawn shape + bind + SpawnException scope.
- test_host_read_barrier (T5): barrier contract per entry-point +
  closed-set registry check.
- test_tp_parallel_state (T1): initialize + rank lookup.
- test_tp_layers (T2): shape + deterministic numerical correctness
  (concat-matmul equality for RowParallel, not mean-only).
- test_tp_mlp (T6): full 2-layer MLP with deterministic weight
  numerical match + rank-consistency post all-reduce.
- test_ccl_allreduce_matrix: ring_default_ws xfail removed (T7).

Regression: 523 pre + 35 new + 1 ex-xfail = 559 passed, 1 intentional
skip (T3.e historical failure documentation).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-14 16:31:13 -07:00

kernbench

A discrete-event simulator for AI accelerator hardware, built on SimPy. It models the full data path — from host PCIe injection through IO chiplet, NOC mesh, crossbar, and HBM — to measure end-to-end latency with contention and queueing.

Architecture

Host (CLI)
  |
  +-- kernbench run     -> run a benchmark (QKV GEMM, AllReduce, ...)
  +-- kernbench probe   -> latency/BW analysis for predefined traffic patterns
  |
  v
+---------------------------------------------------+
|  Runtime API          (runtime_api/)              |
|  MemoryWriteMsg, MemoryReadMsg, PeDmaMsg,         |
|  KernelLaunchMsg                                  |
+---------------------------------------------------+
|  Simulation Engine    (sim_engine/)               |
|  SimPy processes, wire model, BW occupancy        |
+---------------------------------------------------+
|  Components           (components/)               |
|  pcie_ep, io_cpu, m_cpu, noc, xbar, hbm_ctrl,     |
|  pe_cpu, pe_dma, pe_gemm, pe_math, pe_tcm, ...    |
+---------------------------------------------------+
|  Topology             (topology/)                 |
|  YAML-driven graph: 4x4 cube mesh, UCIe links,    |
|  IO chiplet with NOC, HBM slices                  |
+---------------------------------------------------+

Prerequisites

  • Python 3.10+
  • Dependencies: simpy, pyyaml, pytest

Installation

# Create virtual environment
python -m venv .venv

# Activate (Windows)
.venv\Scripts\activate

# Activate (Linux/macOS)
source .venv/bin/activate

# Install in editable mode
pip install -e ".[dev]"

Usage

Probe — Latency and Bandwidth Analysis

The probe command runs predefined traffic patterns (H2D write, D2H read, PE DMA) and reports latency breakdown, bottleneck bandwidth, and utilization.

# Run all probe cases
kernbench probe --topology topology.yaml

# Run a specific case
kernbench probe --topology topology.yaml --case pe-local-hbm

Output includes:

  • Summary tables — actual latency, overhead/drain/wire breakdown, effective BW, utilization
  • BW saturation sweep — utilization at 4KB through 1MB to show saturation threshold
  • Per-hop route traces — cumulative timestamps at every node along the path
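
Why utilization rises across the sweep can be seen from a simple
fixed-overhead model. This is a sketch under stated assumptions, not
the probe's actual formula: the 100 ns overhead is a made-up value, the
peak of 204.8 GB/s is the efficiency-adjusted HBM figure from the
topology table, and the power-of-four step between sizes is a guess.

```python
def utilization(size_bytes, peak_gbs=204.8, overhead_ns=100.0):
    """Fraction of peak bandwidth achieved by one transfer of size_bytes."""
    drain_ns = size_bytes / peak_gbs          # 1 GB/s moves 1 byte per ns
    return drain_ns / (overhead_ns + drain_ns)


# 4KB, 16KB, 64KB, 256KB, 1MB (assumed sweep points)
SWEEP_SIZES = [4 * 1024 * 4 ** i for i in range(5)]


def sweep(sizes=SWEEP_SIZES):
    return [(size, utilization(size)) for size in sizes]


if __name__ == "__main__":
    for size, util in sweep():
        print(f"{size // 1024:>5} KB  utilization={util:.1%}")
```

Under this model a 4 KB transfer drains in about 20 ns, so the fixed
overhead dominates; at 1 MB the drain time dominates and utilization
approaches 1.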

Run — Execute a Benchmark

# Run a benchmark on all devices
kernbench run --topology topology.yaml --bench qkv_gemm

# Run on a specific device
kernbench run --topology topology.yaml --bench qkv_gemm --device sip:0

Available benchmarks (in benches/):

  • qkv_gemm — single-PE QKV GEMM
  • qkv_gemm_multi_pe — multi-PE QKV GEMM
  • ipcq_allreduce — IPCQ AllReduce

Tests

# Run all tests (559 tests)
pytest

# Run a specific test file
pytest tests/test_probe.py -v

# Run a single test
pytest tests/test_probe.py::test_h2d_latency_monotonic -v

# Run with output shown
pytest -s tests/test_probe.py

Key test files:

| File                      | Coverage                                                       |
|---------------------------|----------------------------------------------------------------|
| test_probe.py             | Probe latency invariants, monotonicity, determinism, BW sweep  |
| test_engine.py            | SimPy engine: submit/wait/complete, routing, multi-SIP         |
| test_bw_occupancy.py      | Wire BW contention, HOL blocking, back-to-back serialization   |
| test_iochiplet_noc_d2h.py | IO chiplet NOC topology, H2D/D2H data paths                    |
| test_noc_mesh.py          | 2D mesh NOC routing, Manhattan distance                        |
| test_pe_components.py     | PE-internal components: cpu, scheduler, dma, gemm              |
| test_routing.py           | XY routing, address resolution, path finding                   |
| test_topology_compile.py  | YAML topology compilation, node/edge validation                |
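
For readers unfamiliar with the XY routing and Manhattan distance that
test_routing.py and test_noc_mesh.py cover, here is a minimal sketch of
dimension-ordered routing on a 2D mesh. It assumes X is traversed before
Y and uses hypothetical function names; the policy/ package may differ.

```python
def xy_route(src, dst):
    """Return the (x, y) hops from src to dst, moving along X first, then Y."""
    x, y = src
    path = [(x, y)]
    while x != dst[0]:
        x += 1 if dst[0] > x else -1
        path.append((x, y))
    while y != dst[1]:
        y += 1 if dst[1] > y else -1
        path.append((x, y))
    return path


def manhattan(src, dst):
    """Hop count of any minimal route on a mesh without wraparound links."""
    return abs(src[0] - dst[0]) + abs(src[1] - dst[1])
```

On a 4x4 mesh, xy_route((0, 0), (3, 2)) takes three X hops followed by
two Y hops, and the number of hops equals manhattan((0, 0), (3, 2)).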

Topology Configuration

The system is configured via topology.yaml. Key parameters:

| Parameter           | Default | Description                                         |
|---------------------|---------|-----------------------------------------------------|
| ns_per_mm           | 0.01    | Wire propagation delay (10 ps/mm)                   |
| cube_mesh           | 4x4     | Cube grid dimensions per SIP                        |
| ucie.overhead_ns    | 8.0     | UCIe protocol overhead per port (16ns per crossing) |
| hbm_ctrl.efficiency | 0.8     | HBM effective BW factor (256 to 204.8 GB/s)         |
| xbar.overhead_ns    | 2.0     | Crossbar arbitration delay                          |
| xbar_to_hbm_bw_gbs  | 256.0   | Raw HBM bandwidth per slice                         |
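
The derived figures quoted in the descriptions follow from the defaults
by simple arithmetic. The sketch below reproduces them; the function and
key names are hypothetical and are not read from topology.yaml.

```python
def derived(ns_per_mm=0.01, ucie_overhead_ns=8.0,
            hbm_efficiency=0.8, raw_hbm_gbs=256.0):
    """Recompute the values quoted in the parameter descriptions."""
    return {
        "wire_ps_per_mm": ns_per_mm * 1000.0,            # 0.01 ns/mm = 10 ps/mm
        "ucie_crossing_ns": 2 * ucie_overhead_ns,        # one port on each side
        "hbm_effective_gbs": raw_hbm_gbs * hbm_efficiency,
    }


def wire_delay_ns(length_mm, ns_per_mm=0.01):
    """Propagation delay of a wire of the given length."""
    return length_mm * ns_per_mm
```

For example, a 5 mm wire contributes about 0.05 ns, which is small next
to the 16 ns cost of a single UCIe crossing.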

Project Structure

kernbench/
+-- src/kernbench/
|   +-- cli/            # CLI entry points (main, probe, report)
|   +-- common/         # Shared types (Completion, RequestHandle, Trace)
|   +-- components/     # Hardware component models (SimPy processes)
|   +-- di/             # Dependency injection
|   +-- policy/         # Routing (XY), address decoding (PhysAddr)
|   +-- runtime_api/    # Host-facing API (messages, bench runner)
|   +-- sim_engine/     # Discrete-event engine, transaction, wire model
|   +-- topology/       # YAML builder, mesh generator, graph types
|   +-- triton_emu/     # Triton kernel emulation
+-- benches/            # Benchmark implementations
+-- tests/              # pytest test suite (559 tests)
+-- docs/               # ADRs, latency model docs, diagrams
+-- topology.yaml       # System topology configuration
+-- CHANGES.md          # Changelog

Documentation
