Files
kernbench2/tests/test_phase_a_components.py
T
ywkang 63669f82cb Add SIP-level tensor parallelism, component registry YAML, VA offset verification
- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise)
- PE_CPU: auto num_programs from cube shard count
- context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape
- deploy_tensor: removed mmus param, MMU mapping is context-only responsibility
- ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename
- VA offset bench + tests: 2D/1D, standard Triton kernel pattern

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 01:13:17 -07:00

271 lines
8.7 KiB
Python

"""Phase A component infrastructure tests (ADR-0015).
Verifies:
- TransitComponent, IoCpuComponent apply overhead_ns via run()
- HbmCtrlComponent and SramComponent act as terminal nodes (they succeed txn.done)
- MCpuComponent forwards when not terminal; completes when terminal + no ctx
- ComponentRegistry resolves impl strings to correct concrete classes
- GraphEngine passes ComponentContext to every component
- ComponentContext.router and .resolver are correctly populated
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import pytest
import simpy
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.context import ComponentContext
from kernbench.components.builtin import (
HbmCtrlComponent,
IoCpuComponent,
MCpuComponent,
PcieEpComponent,
PositionAwareXbarComponent,
SramComponent,
TransitComponent,
)
from kernbench.sim_engine.engine import GraphEngine
from kernbench.sim_engine.transaction import Transaction
from kernbench.topology.builder import load_topology
from kernbench.topology.types import Node
# Topology file loaded by the GraphEngine tests in this module: "topology.yaml"
# located in the parent of the directory that contains this test file.
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
def _node(impl: str, attrs: dict | None = None) -> Node:
    """Build a throwaway ``Node`` for component tests.

    The node always carries id ``"test.node"``, kind ``"test"`` and no
    position; only ``impl`` and ``attrs`` vary between tests.

    Args:
        impl: Implementation string stored on the node.
        attrs: Node attributes; an empty dict is used when this is omitted
            (or otherwise falsy).

    Returns:
        The constructed ``Node``.
    """
    node_attrs = attrs if attrs else {}
    return Node(
        id="test.node",
        kind="test",
        impl=impl,
        attrs=node_attrs,
        pos_mm=None,
    )
def _run_worker(comp: ComponentBase, env: simpy.Environment, txn: Transaction) -> None:
    """Drive a single transaction through ``comp`` and run until it is done.

    A fresh store is attached as the component's ``"src"`` in-port, the
    component is started, ``txn`` is put on that store by a helper process,
    and the environment runs until ``txn.done`` fires.

    Args:
        comp: Component under test.
        env: Simulation environment the component runs in.
        txn: Transaction to inject; its ``done`` event ends the run.
    """
    source_port: simpy.Store = simpy.Store(env)
    comp.in_ports["src"] = source_port
    comp.start(env)

    feeder = _inject(source_port, txn)
    env.process(feeder)
    env.run(until=txn.done)
def _inject(store: simpy.Store, txn: Transaction):
    """Process generator: put ``txn`` on ``store`` and yield the put event."""
    put_request = store.put(txn)
    yield put_request
# ── 1. run() latency: TransitComponent ───────────────────────────────
def test_transit_component_run_overhead_ns():
    """Running a TransitComponent advances simulated time by overhead_ns."""
    overhead = 7.5
    env = simpy.Environment()
    transit = TransitComponent(_node("forwarding_v1", {"overhead_ns": overhead}))

    def drive():
        # Delegate to the component so its delay is charged to this process.
        yield from transit.run(env, nbytes=1024)

    env.process(drive())
    env.run()

    assert env.now == pytest.approx(overhead)
def test_transit_component_run_zero_overhead_ns():
    """A TransitComponent with overhead_ns=0 finishes run() at time zero."""
    env = simpy.Environment()
    transit = TransitComponent(_node("noc_v1", {"overhead_ns": 0.0}))
    completions = []

    def drive():
        yield from transit.run(env, nbytes=512)
        # Reached only once run() has fully returned.
        completions.append(True)

    env.process(drive())
    env.run()

    assert completions == [True]
    assert env.now == pytest.approx(0.0)
# ── 2. run() latency: IoCpuComponent ────────────────────────────────
def test_io_cpu_component_run_overhead_ns():
    """Running an IoCpuComponent advances simulated time by overhead_ns."""
    overhead = 10.0
    env = simpy.Environment()
    io_cpu = IoCpuComponent(_node("io_cpu_v1", {"overhead_ns": overhead}))

    def drive():
        yield from io_cpu.run(env, nbytes=2048)

    env.process(drive())
    env.run()

    assert env.now == pytest.approx(overhead)
# ── 3. Terminal: HbmCtrlComponent succeeds done ──────────────────────
def test_hbm_ctrl_terminal_succeeds_done():
    """HbmCtrlComponent, as the last hop of the path, triggers txn.done."""
    env = simpy.Environment()
    hbm = HbmCtrlComponent(
        _node("hbm_ctrl_v1", {"overhead_ns": 0.0, "capacity": 1})
    )

    finished = env.event()
    # Single-hop path: the component under test is the final node.
    request = Transaction(
        request=None,
        path=["test.node"],
        step=0,
        nbytes=256,
        done=finished,
    )

    _run_worker(hbm, env, request)

    assert finished.triggered
def test_hbm_ctrl_resource_serializes_requests():
    """Two requests to a capacity=1 HbmCtrlComponent are handled one at a time."""
    env = simpy.Environment()
    hbm = HbmCtrlComponent(
        _node("hbm_ctrl_v1", {"overhead_ns": 5.0, "capacity": 1})
    )

    inbox: simpy.Store = simpy.Store(env)
    hbm.in_ports["src"] = inbox
    hbm.start(env)

    first_done, second_done = env.event(), env.event()
    first_txn = Transaction(
        request=None, path=["test.node"], step=0, nbytes=0, done=first_done
    )
    second_txn = Transaction(
        request=None, path=["test.node"], step=0, nbytes=0, done=second_done
    )

    def feed():
        # Queue both transactions back to back.
        for queued in (first_txn, second_txn):
            yield inbox.put(queued)

    env.process(feed())
    env.run(until=second_done)

    # Serialized handling is expected to finish the requests at t=5 and t=10.
    assert first_done.triggered
    assert second_done.triggered
    assert env.now == pytest.approx(10.0)
# ── 4. Terminal: SramComponent succeeds done ─────────────────────────
def test_sram_terminal_succeeds_done():
    """SramComponent, as the last hop, triggers txn.done after overhead_ns."""
    overhead = 2.0
    env = simpy.Environment()
    sram = SramComponent(_node("sram_v1", {"overhead_ns": overhead}))

    finished = env.event()
    request = Transaction(
        request=None,
        path=["test.node"],
        step=0,
        nbytes=512,
        done=finished,
    )

    _run_worker(sram, env, request)

    assert finished.triggered
    assert env.now == pytest.approx(overhead)
# ── 5. MCpuComponent: forward when not terminal ──────────────────────
def test_m_cpu_forwards_when_not_terminal():
    """MCpuComponent hands the Transaction to the next hop when hops remain."""
    env = simpy.Environment()
    m_cpu = MCpuComponent(_node("m_cpu_v1", {"overhead_ns": 5.0}))

    # Ports for the middle node of the path [src, test.node, next].
    inbox: simpy.Store = simpy.Store(env)
    outbox: simpy.Store = simpy.Store(env)
    m_cpu.in_ports["src"] = inbox
    m_cpu.out_ports["next"] = outbox
    m_cpu.start(env)

    completed = env.event()
    hop_txn = Transaction(
        request=None,
        path=["src", "test.node", "next"],
        step=1,  # positioned at test.node, so the following hop is "next"
        nbytes=128,
        done=completed,
    )
    received: list[Any] = []

    def downstream():
        # Stand in for the "next" node: record the message and complete it.
        arrived = yield outbox.get()
        received.append(arrived)
        arrived.done.succeed()

    def upstream():
        yield inbox.put(hop_txn)

    env.process(downstream())
    env.process(upstream())
    env.run(until=completed)

    assert len(received) == 1
    assert received[0].step == 2  # step was advanced before forwarding
    assert env.now == pytest.approx(5.0)
# ── 6. MCpuComponent: terminal with no ctx just completes ────────────
def test_m_cpu_terminal_no_ctx_completes():
    """An MCpuComponent built with ctx=None triggers txn.done as the last hop."""
    env = simpy.Environment()
    m_cpu = MCpuComponent(_node("m_cpu_v1", {"overhead_ns": 0.0}), ctx=None)

    finished = env.event()
    request = Transaction(
        request=None,
        path=["test.node"],
        step=0,
        nbytes=64,
        done=finished,
    )

    _run_worker(m_cpu, env, request)

    assert finished.triggered
# ── 7. ComponentRegistry resolves impl strings ───────────────────────
@pytest.mark.parametrize(
    "impl,expected_cls",
    [
        ("forwarding_v1", TransitComponent),
        ("noc_v1", TransitComponent),
        ("ucie_v1", TransitComponent),
        ("xbar_v1", PositionAwareXbarComponent),
        ("pcie_ep_v1", PcieEpComponent),
        ("io_cpu_v1", IoCpuComponent),
        ("m_cpu_v1", MCpuComponent),
        ("hbm_ctrl_v1", HbmCtrlComponent),
        ("sram_v1", SramComponent),
    ],
)
def test_registry_resolves_impl(impl, expected_cls):
    """ComponentRegistry.create() builds an instance of the class mapped to impl."""
    created = ComponentRegistry.create(_node(impl, {"overhead_ns": 0.0}))
    assert isinstance(created, expected_cls)
# ── 8. GraphEngine passes ComponentContext to components ─────────────
def test_engine_passes_ctx_to_components():
    """GraphEngine injects a non-None ComponentContext into every component.

    Loads the topology at ``TOPOLOGY_PATH``, builds a ``GraphEngine`` from it
    and checks the ``ctx`` attribute of each entry in ``engine._components``.
    """
    graph = load_topology(TOPOLOGY_PATH)
    engine = GraphEngine(graph)

    # Guard against a vacuous pass: with no components the loop below would
    # never execute and the test would succeed without checking anything.
    assert engine._components, "GraphEngine built no components"

    for node_id, comp in engine._components.items():
        assert comp.ctx is not None, f"{node_id}: ctx is None"
        assert isinstance(comp.ctx, ComponentContext), f"{node_id}: ctx wrong type"
def test_engine_ctx_router_and_resolver_populated():
    """The context's router and resolver are PathRouter / AddressResolver objects."""
    from kernbench.policy.routing.router import AddressResolver, PathRouter

    engine = GraphEngine(load_topology(TOPOLOGY_PATH))

    # Only the first component is inspected; this is a spot check.
    sample_ctx = next(iter(engine._components.values())).ctx

    assert isinstance(sample_ctx.router, PathRouter)
    assert isinstance(sample_ctx.resolver, AddressResolver)