commit - release 1

This commit is contained in:
2026-03-18 11:47:48 -07:00
commit 6f43807900
109 changed files with 14909 additions and 0 deletions
+22
View File
@@ -0,0 +1,22 @@
import kernbench.cli.main as cli_main
def test_cli_main_arg_parsing(monkeypatch):
def fake_cmd_run(args) -> int:
assert args.cmd == "run"
assert args.topology == "topology.yaml"
assert args.bench == "qkv_gemm"
assert args.device == None
return 0
# monkey patch the handler to test arg parsing without running the actual bench
monkeypatch.setattr(cli_main, "cmd_run", fake_cmd_run)
rc = cli_main.main(["run", "--topology", "topology.yaml", "--bench", "qkv_gemm"])
assert rc == 0
def test_cli_main():
rc = cli_main.main(["run", "--topology", "topology.yaml", "--bench", "qkv_gemm"])
assert rc == 0
+187
View File
@@ -0,0 +1,187 @@
"""Tests for the SimPy component model and DI registry (ADR-0007 D3).
Phase 1 verification: all tests FAIL until Phase 2 implements production code.
Latency invariant after refactor:
total_ns = Σ(wire propagation) + Σ(component.run() overhead_ns) + nbytes / bottleneck_bw
This is identical to the current formula for Phase 0 (no contention).
"""
import pytest
import simpy
from pathlib import Path
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.impls.forwarding import TransitComponent
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.runtime_api.kernel import MemoryReadMsg
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import load_topology
from kernbench.topology.types import Node
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
def _graph():
return load_topology(TOPOLOGY_PATH)
def _hbm_pa(pe_id: int = 0) -> int:
slice_bytes = 48 * (1 << 30) // 8
pa = PhysAddr.pe_hbm_addr(
rack_id=0, sip_id=0, cube_id=0, pe_id=pe_id,
pe_local_hbm_offset=0x1000, slice_size_bytes=slice_bytes,
)
return pa.encode()
def _node(impl: str, overhead_ns: float = 0.0) -> Node:
return Node(id="test", kind="xbar", impl=impl, attrs={"overhead_ns": overhead_ns}, pos_mm=None)
# ── 1. unknown impl → error ──────────────────────────────────────────
def test_registry_unknown_impl_raises_error():
"""Unregistered impl raises ValueError (no fallback)."""
node = _node("totally_unknown_v99", overhead_ns=5.0)
with pytest.raises(ValueError, match="No component registered"):
ComponentRegistry.create(node)
# ── 2. TransitComponent yields exactly overhead_ns via simpy timeout ──
def test_transit_component_yields_overhead_ns():
"""TransitComponent.run() yields exactly node.attrs['overhead_ns'] ns."""
node = _node("xbar_v1", overhead_ns=3.0)
comp = TransitComponent(node)
env = simpy.Environment()
def proc():
yield from comp.run(env, nbytes=4096)
env.process(proc())
env.run()
assert env.now == pytest.approx(3.0)
def test_transit_component_zero_overhead_ns():
"""TransitComponent with overhead_ns=0 still yields (no infinite loop)."""
node = _node("noc_v1", overhead_ns=0.0)
comp = TransitComponent(node)
env = simpy.Environment()
done = []
def proc():
yield from comp.run(env, nbytes=1024)
done.append(True)
env.process(proc())
env.run()
assert done == [True]
assert env.now == pytest.approx(0.0)
# ── 3. DI override: custom component is invoked by engine ────────────
def test_engine_component_override_is_called():
"""Custom component injected via component_overrides is invoked during simulation."""
class SpyXbar(ComponentBase):
calls = 0
def run(self, env, nbytes):
SpyXbar.calls += 1
yield env.timeout(0)
SpyXbar.calls = 0
graph = _graph()
engine = GraphEngine(graph, component_overrides={"xbar_v1": SpyXbar})
msg = MemoryReadMsg(
correlation_id="c", request_id="r",
src_sip=0, src_cube=0, src_pe=0,
src_pa=_hbm_pa(pe_id=0), nbytes=4096,
)
h = engine.submit(msg)
engine.wait(h)
# PE0→slice0 path passes through xbar.pe0 (impl=xbar_v1)
assert SpyXbar.calls > 0
# ── 4. behavior unchanged: total_ns matches existing formula ─────────
def test_engine_component_model_same_latency_as_before():
"""Phase B component model total_ns for PE0→slice0 local HBM (4096B).
Cut-through (wormhole) wire model: wires apply propagation only.
Serialization (drain) is computed per-path and applied once at the terminal.
Forward path:
Path 1: pcie_ep(5.0) + wire(1.0mm=0.01) + io_cpu(10.0)
Path 2: wire(3.5mm=0.035) + ucie-N(1.0)
+ 2DMeshNOC(ucie-N→m_cpu: Manhattan 10.9mm=0.109) + m_cpu(5.0)
Path 3 DMA (m_cpu→noc→xbar.pe0→hbm_ctrl.slice0):
+ 2DMeshNOC(m_cpu→xbar.pe0: Manhattan 15.0mm=0.15)
+ xbar.pe0(2.0) + wire(2.5mm=0.025) + hbm_ctrl(0.0)
+ drain_ns(4096/128 = 32.0, bottleneck = noc_to_xbar 128 GB/s)
Response path (reverse, nbytes=0, drain=0):
DMA response: hbm_ctrl→xbar.pe0→noc→m_cpu (propagation + xbar overhead_ns)
Command response: m_cpu→noc→ucie-N→io_cpu (propagation + ucie overhead_ns)
Total: ~58.648 ns
"""
graph = _graph()
engine = GraphEngine(graph)
msg = MemoryReadMsg(
correlation_id="c", request_id="r",
src_sip=0, src_cube=0, src_pe=0,
src_pa=_hbm_pa(pe_id=0), nbytes=4096,
)
h = engine.submit(msg)
engine.wait(h)
_, trace = engine.get_completion(h)
assert trace["total_ns"] == pytest.approx(58.648, rel=1e-4)
# ── 5. override is scoped: only targeted impl is replaced ────────────
def test_engine_override_is_scoped_to_impl():
"""xbar_v1 override (ZeroXbar, no overhead_ns) reduces total_ns by exactly 4.0 ns.
xbar.pe0 has overhead_ns=2.0. It is traversed on both the forward DMA path
and the reverse response path, so replacing it with a zero-latency impl
removes 2.0 ns × 2 = 4.0 ns; all other components are unchanged.
"""
class ZeroXbar(ComponentBase):
def run(self, env, nbytes):
yield env.timeout(0)
graph = _graph()
engine_default = GraphEngine(graph)
engine_override = GraphEngine(graph, component_overrides={"xbar_v1": ZeroXbar})
msg = MemoryReadMsg(
correlation_id="c", request_id="r",
src_sip=0, src_cube=0, src_pe=0,
src_pa=_hbm_pa(pe_id=0), nbytes=4096,
)
h_d = engine_default.submit(msg)
engine_default.wait(h_d)
_, t_default = engine_default.get_completion(h_d)
h_o = engine_override.submit(msg)
engine_override.wait(h_o)
_, t_override = engine_override.get_completion(h_o)
# ZeroXbar removes overhead_ns=2.0 from xbar.pe0 on forward + response = 4.0 ns faster
assert t_override["total_ns"] < t_default["total_ns"]
assert t_default["total_ns"] - t_override["total_ns"] == pytest.approx(4.0, rel=1e-6)
+405
View File
@@ -0,0 +1,405 @@
import pytest
from pathlib import Path
from kernbench.common.types import Completion, RequestHandle
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.runtime_api.kernel import (
KernelLaunchMsg,
KernelRef,
MemoryReadMsg,
MemoryWriteMsg,
ScalarArg,
TensorArg,
TensorArgShard,
)
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import load_topology
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
def _engine():
graph = load_topology(TOPOLOGY_PATH)
return GraphEngine(graph)
def _hbm_pa(sip: int = 0, cube: int = 0, pe_id: int = 0) -> int:
"""Create an HBM physical address targeting a specific PE's HBM slice."""
# 48 GB / 8 slices = 6 GB per slice
slice_bytes = 48 * (1 << 30) // 8
pa = PhysAddr.pe_hbm_addr(
rack_id=0, sip_id=sip, cube_id=cube, pe_id=pe_id,
pe_local_hbm_offset=0x1000, slice_size_bytes=slice_bytes,
)
return pa.encode()
def _sram_pa(sip: int = 0, cube: int = 0) -> int:
"""Create an SRAM physical address."""
pa = PhysAddr.cube_sram_addr(rack_id=0, sip_id=sip, cube_id=cube, sram_offset=0x800)
return pa.encode()
# ── 1. submit returns handle ────────────────────────────────────────
def test_engine_submit_returns_handle():
"""submit() must return a RequestHandle (non-empty string)."""
engine = _engine()
msg = MemoryWriteMsg(
correlation_id="c0", request_id="r0",
dst_sip=0, dst_cube=0, dst_pe=0,
dst_pa=_hbm_pa(), nbytes=4096, pattern="zero",
)
handle = engine.submit(msg)
assert isinstance(handle, str)
assert len(handle) > 0
# ── 2. memory write completion ──────────────────────────────────────
def test_engine_memory_write_completion():
"""MemoryWrite must complete with ok=True."""
engine = _engine()
msg = MemoryWriteMsg(
correlation_id="c0", request_id="r1",
dst_sip=0, dst_cube=0, dst_pe=0,
dst_pa=_hbm_pa(), nbytes=4096, pattern="zero",
)
h = engine.submit(msg)
engine.wait(h)
comp, trace = engine.get_completion(h)
assert comp.ok is True
# ── 3. memory read completion ───────────────────────────────────────
def test_engine_memory_read_completion():
"""MemoryRead must complete with ok=True."""
engine = _engine()
msg = MemoryReadMsg(
correlation_id="c0", request_id="r2",
src_sip=0, src_cube=0, src_pe=0,
src_pa=_hbm_pa(), nbytes=4096,
)
h = engine.submit(msg)
engine.wait(h)
comp, trace = engine.get_completion(h)
assert comp.ok is True
# ── 4. latency positive ────────────────────────────────────────────
def test_engine_latency_positive():
"""Trace total_ns must be > 0 (ADR-0002 D4)."""
engine = _engine()
msg = MemoryWriteMsg(
correlation_id="c0", request_id="r3",
dst_sip=0, dst_cube=0, dst_pe=0,
dst_pa=_hbm_pa(), nbytes=4096, pattern="zero",
)
h = engine.submit(msg)
engine.wait(h)
_, trace = engine.get_completion(h)
assert trace["total_ns"] > 0
# ── 5. trace has total_ns and nbytes ───────────────────────────────
def test_engine_trace_has_total_ns_and_nbytes():
"""Trace must contain 'total_ns' and 'nbytes'."""
engine = _engine()
msg = MemoryWriteMsg(
correlation_id="c0", request_id="r4",
dst_sip=0, dst_cube=0, dst_pe=0,
dst_pa=_hbm_pa(), nbytes=4096, pattern="zero",
)
h = engine.submit(msg)
engine.wait(h)
_, trace = engine.get_completion(h)
assert "total_ns" in trace
assert "nbytes" in trace
assert trace["nbytes"] == 4096
# ── 6. latency includes node overhead_ns ────────────────────────────
def test_engine_latency_includes_node_overhead_ns():
"""Path traverses components with overhead_ns > 0, so total >= some minimum."""
engine = _engine()
msg = MemoryWriteMsg(
correlation_id="c0", request_id="r7",
dst_sip=0, dst_cube=0, dst_pe=0,
dst_pa=_hbm_pa(), nbytes=4096, pattern="zero",
)
h = engine.submit(msg)
engine.wait(h)
_, trace = engine.get_completion(h)
# pcie_ep (5.0) + io_cpu (10.0) + m_cpu (5.0) = at least 20 ns
assert trace["total_ns"] >= 20.0
# ── 7. concurrent requests ─────────────────────────────────────────
def test_engine_concurrent_requests():
"""Two requests submitted before wait must both complete with traces."""
engine = _engine()
msg1 = MemoryWriteMsg(
correlation_id="c0", request_id="r9a",
dst_sip=0, dst_cube=0, dst_pe=0,
dst_pa=_hbm_pa(), nbytes=4096, pattern="zero",
)
msg2 = MemoryWriteMsg(
correlation_id="c0", request_id="r9b",
dst_sip=0, dst_cube=0, dst_pe=1,
dst_pa=_hbm_pa(pe_id=1), nbytes=4096, pattern="zero",
)
h1 = engine.submit(msg1)
h2 = engine.submit(msg2)
engine.wait(h1)
engine.wait(h2)
comp1, trace1 = engine.get_completion(h1)
comp2, trace2 = engine.get_completion(h2)
assert comp1.ok is True
assert comp2.ok is True
assert trace1["total_ns"] > 0
assert trace2["total_ns"] > 0
# ── 8. kernel launch ───────────────────────────────────────────────
def test_engine_kernel_launch_simplified():
"""KernelLaunch returns latency > 0."""
from kernbench.triton_emu.registry import clear_registry, register_kernel
clear_registry()
hbm_pa = _hbm_pa(pe_id=0)
def gemm_kernel(a_ptr, tl):
a = tl.load(a_ptr, shape=(4, 4), dtype="f16")
tl.store(a_ptr, a)
register_kernel("gemm", gemm_kernel)
engine = _engine()
shard0 = TensorArgShard(
sip=0, cube=0, pe=0,
pa=_hbm_pa(pe_id=0), nbytes=4096, offset_bytes=0,
)
shard1 = TensorArgShard(
sip=0, cube=0, pe=1,
pa=_hbm_pa(pe_id=1), nbytes=4096, offset_bytes=4096,
)
msg = KernelLaunchMsg(
correlation_id="c0", request_id="r10",
kernel_ref=KernelRef(name="gemm", kind="builtin"),
args=(TensorArg(shards=(shard0, shard1)),),
)
h = engine.submit(msg)
engine.wait(h)
comp, trace = engine.get_completion(h)
assert comp.ok is True
assert trace["total_ns"] > 0
clear_registry()
# ── 9. deterministic ───────────────────────────────────────────────
def test_engine_deterministic():
"""Same request on two engines must produce identical latency."""
msg = MemoryWriteMsg(
correlation_id="c0", request_id="r11",
dst_sip=0, dst_cube=0, dst_pe=0,
dst_pa=_hbm_pa(), nbytes=4096, pattern="zero",
)
e1 = _engine()
h1 = e1.submit(msg)
e1.wait(h1)
_, t1 = e1.get_completion(h1)
e2 = _engine()
h2 = e2.submit(msg)
e2.wait(h2)
_, t2 = e2.get_completion(h2)
assert t1["total_ns"] == t2["total_ns"]
# ── 10. remote cube access succeeds with higher latency ────────────
def test_dma_capacity_serializes_concurrent():
"""Two concurrent DMA writes to the same cube must contend at DMA capacity=1.
When two MemoryWrite requests target the same cube's M_CPU simultaneously,
the DMA engine (capacity=1) serializes them. The slower request must take
longer than a single isolated request (ADR-0014 D4, ADR-0015 D5).
"""
# Single isolated write baseline
engine_single = _engine()
msg_single = MemoryWriteMsg(
correlation_id="c0", request_id="single",
dst_sip=0, dst_cube=0, dst_pe=0,
dst_pa=_hbm_pa(sip=0, cube=0, pe_id=0), nbytes=4096,
pattern="zero", target_pe=0,
)
h1 = engine_single.submit(msg_single)
engine_single.wait(h1)
_, t1 = engine_single.get_completion(h1)
single_ns = t1["total_ns"]
# Two concurrent writes to same cube (different PEs) → DMA contention
engine_conc = _engine()
msg_a = MemoryWriteMsg(
correlation_id="c0", request_id="conc-a",
dst_sip=0, dst_cube=0, dst_pe=0,
dst_pa=_hbm_pa(sip=0, cube=0, pe_id=0), nbytes=4096,
pattern="zero", target_pe=0,
)
msg_b = MemoryWriteMsg(
correlation_id="c0", request_id="conc-b",
dst_sip=0, dst_cube=0, dst_pe=1,
dst_pa=_hbm_pa(sip=0, cube=0, pe_id=1), nbytes=4096,
pattern="zero", target_pe=1,
)
ha = engine_conc.submit(msg_a)
hb = engine_conc.submit(msg_b)
engine_conc.wait(ha)
engine_conc.wait(hb)
_, ta = engine_conc.get_completion(ha)
_, tb = engine_conc.get_completion(hb)
# At least one must be delayed by DMA contention
max_ns = max(ta["total_ns"], tb["total_ns"])
assert max_ns > single_ns, (
f"concurrent max ({max_ns:.2f}ns) must > single ({single_ns:.2f}ns) "
f"due to DMA capacity=1 contention"
)
# ── 11. formula latency lower bound ──────────────────────────────
def test_formula_latency_lower_bound():
"""_formula_latency must be <= actual latency (ADR-0015 D7).
Uses PE DMA path which is fully known at engine level.
"""
from kernbench.policy.address.phyaddr import PhysAddr as PA
from kernbench.policy.routing.router import AddressResolver, PathRouter
from kernbench.topology.builder import load_topology as lt
graph = lt(TOPOLOGY_PATH)
engine = GraphEngine(graph)
resolver = AddressResolver(graph)
router = PathRouter(graph)
pa = _hbm_pa(sip=0, cube=0, pe_id=1)
pa_obj = PA.decode(pa)
dst_node = resolver.resolve(pa_obj)
pe_ref = "sip0.cube0.pe0"
path = router.find_path(pe_ref, dst_node)
formula = engine._formula_latency(path, 4096)
# Run actual simulation
msg = MemoryReadMsg(
correlation_id="c0", request_id="formula-lb",
src_sip=0, src_cube=0, src_pe=0,
src_pa=pa, nbytes=4096, target_pe=1,
)
h = engine.submit(msg)
engine.wait(h)
_, trace = engine.get_completion(h)
actual = trace["total_ns"]
assert formula <= actual, (
f"formula ({formula:.2f}) must <= actual ({actual:.2f})"
)
assert formula > 0, "formula must be > 0"
def test_formula_latency_exact_no_contention():
"""With no contention, formula should approximate actual for PE DMA.
PE DMA is single-request with no fan-out or aggregation,
so formula ≈ actual (within small tolerance for SimPy scheduling).
"""
from kernbench.runtime_api.kernel import PeDmaMsg
from kernbench.policy.address.phyaddr import PhysAddr as PA
from kernbench.policy.routing.router import AddressResolver, PathRouter
from kernbench.topology.builder import load_topology as lt
graph = lt(TOPOLOGY_PATH)
engine = GraphEngine(graph)
resolver = AddressResolver(graph)
router = PathRouter(graph)
pa = _hbm_pa(sip=0, cube=0, pe_id=0)
pa_obj = PA.decode(pa)
dst_node = resolver.resolve(pa_obj)
pe_ref = "sip0.cube0.pe0"
path = router.find_path(pe_ref, dst_node)
formula = engine._formula_latency(path, 4096)
msg = PeDmaMsg(
correlation_id="c0", request_id="formula-exact",
src_sip=0, src_cube=0, src_pe=0,
dst_pa=pa, nbytes=4096,
)
h = engine.submit(msg)
engine.wait(h)
_, trace = engine.get_completion(h)
actual = trace["total_ns"]
# No contention: formula should equal actual
assert abs(formula - actual) < 0.01, (
f"formula ({formula:.4f}) ≈ actual ({actual:.4f}) expected with no contention"
)
# ── 10. remote cube access succeeds with higher latency ────────────
def test_engine_remote_cube_latency_higher():
"""Accessing a distant cube's HBM must have strictly higher latency than local.
Uses separate engines to avoid contention effects.
cube15 (far corner of 4x4 mesh) requires multiple UCIe + NOC hops
from IO chiplet compared to cube0 (directly connected).
"""
engine_local = _engine()
engine_remote = _engine()
msg_local = MemoryReadMsg(
correlation_id="c0", request_id="r14a",
src_sip=0, src_cube=0, src_pe=0,
src_pa=_hbm_pa(sip=0, cube=0, pe_id=0), nbytes=4096,
)
msg_remote = MemoryReadMsg(
correlation_id="c0", request_id="r14b",
src_sip=0, src_cube=0, src_pe=0,
src_pa=_hbm_pa(sip=0, cube=15, pe_id=0), nbytes=4096,
)
h_local = engine_local.submit(msg_local)
engine_local.wait(h_local)
_, t_local = engine_local.get_completion(h_local)
h_remote = engine_remote.submit(msg_remote)
engine_remote.wait(h_remote)
comp_remote, t_remote = engine_remote.get_completion(h_remote)
assert comp_remote.ok is True
assert t_remote is not None and t_local is not None
assert t_remote["total_ns"] > t_local["total_ns"], (
f"remote cube {t_remote['total_ns']:.2f} must > local {t_local['total_ns']:.2f}"
)
File diff suppressed because it is too large Load Diff
+269
View File
@@ -0,0 +1,269 @@
"""Phase A component infrastructure tests (ADR-0015).
Verifies:
- TransitComponent, IoCpuComponent apply overhead_ns via run()
- HbmCtrlComponent and SramComponent act as terminal nodes (succeed done)
- MCpuComponent forwards when not terminal; completes when terminal + no ctx
- ComponentRegistry resolves impl strings to correct concrete classes
- GraphEngine passes ComponentContext to every component
- ComponentContext.router and .resolver are correctly populated
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import pytest
import simpy
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.context import ComponentContext
from kernbench.components.impls import (
HbmCtrlComponent,
IoCpuComponent,
MCpuComponent,
PcieEpComponent,
SramComponent,
TransitComponent,
)
from kernbench.sim_engine.engine import GraphEngine
from kernbench.sim_engine.transaction import Transaction
from kernbench.topology.builder import load_topology
from kernbench.topology.types import Node
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
def _node(impl: str, attrs: dict | None = None) -> Node:
return Node(id="test.node", kind="test", impl=impl, attrs=attrs or {}, pos_mm=None)
def _run_worker(comp: ComponentBase, env: simpy.Environment, txn: Transaction) -> None:
"""Wire one in_port, start the component, inject txn, run env until done."""
in_store: simpy.Store = simpy.Store(env)
comp.in_ports["src"] = in_store
comp.start(env)
env.process(_inject(in_store, txn))
env.run(until=txn.done)
def _inject(store: simpy.Store, txn: Transaction):
yield store.put(txn)
# ── 1. run() latency: TransitComponent ───────────────────────────────
def test_transit_component_run_overhead_ns():
"""TransitComponent.run() yields exactly overhead_ns."""
node = _node("forwarding_v1", {"overhead_ns": 7.5})
comp = TransitComponent(node)
env = simpy.Environment()
def proc():
yield from comp.run(env, nbytes=1024)
env.process(proc())
env.run()
assert env.now == pytest.approx(7.5)
def test_transit_component_run_zero_overhead_ns():
"""TransitComponent.run() with overhead_ns=0 completes immediately."""
node = _node("noc_v1", {"overhead_ns": 0.0})
comp = TransitComponent(node)
env = simpy.Environment()
done = []
def proc():
yield from comp.run(env, nbytes=512)
done.append(True)
env.process(proc())
env.run()
assert done == [True]
assert env.now == pytest.approx(0.0)
# ── 2. run() latency: IoCpuComponent ────────────────────────────────
def test_io_cpu_component_run_overhead_ns():
"""IoCpuComponent.run() yields exactly overhead_ns."""
node = _node("io_cpu_v1", {"overhead_ns": 10.0})
comp = IoCpuComponent(node)
env = simpy.Environment()
def proc():
yield from comp.run(env, nbytes=2048)
env.process(proc())
env.run()
assert env.now == pytest.approx(10.0)
# ── 3. Terminal: HbmCtrlComponent succeeds done ──────────────────────
def test_hbm_ctrl_terminal_succeeds_done():
"""HbmCtrlComponent is a terminal node: succeeds txn.done after run()."""
node = _node("hbm_ctrl_v1", {"overhead_ns": 0.0, "capacity": 1})
comp = HbmCtrlComponent(node)
env = simpy.Environment()
done_event = env.event()
txn = Transaction(request=None, path=["test.node"], step=0, nbytes=256, done=done_event)
_run_worker(comp, env, txn)
assert done_event.triggered
def test_hbm_ctrl_resource_serializes_requests():
"""HbmCtrlComponent with capacity=1 serializes concurrent requests."""
node = _node("hbm_ctrl_v1", {"overhead_ns": 5.0, "capacity": 1})
comp = HbmCtrlComponent(node)
env = simpy.Environment()
in_store: simpy.Store = simpy.Store(env)
comp.in_ports["src"] = in_store
comp.start(env)
done1 = env.event()
done2 = env.event()
txn1 = Transaction(request=None, path=["test.node"], step=0, nbytes=0, done=done1)
txn2 = Transaction(request=None, path=["test.node"], step=0, nbytes=0, done=done2)
def inject():
yield in_store.put(txn1)
yield in_store.put(txn2)
env.process(inject())
env.run(until=done2)
# Both must be done; with serialization: t=5 + t=10
assert done1.triggered
assert done2.triggered
assert env.now == pytest.approx(10.0)
# ── 4. Terminal: SramComponent succeeds done ─────────────────────────
def test_sram_terminal_succeeds_done():
"""SramComponent is a terminal node: succeeds txn.done after run()."""
node = _node("sram_v1", {"overhead_ns": 2.0})
comp = SramComponent(node)
env = simpy.Environment()
done_event = env.event()
txn = Transaction(request=None, path=["test.node"], step=0, nbytes=512, done=done_event)
_run_worker(comp, env, txn)
assert done_event.triggered
assert env.now == pytest.approx(2.0)
# ── 5. MCpuComponent: forward when not terminal ──────────────────────
def test_m_cpu_forwards_when_not_terminal():
"""MCpuComponent forwards Transaction to next hop when not terminal."""
node = _node("m_cpu_v1", {"overhead_ns": 5.0})
comp = MCpuComponent(node)
env = simpy.Environment()
# Wire in_port and out_port for a two-hop path [src, test.node, next]
in_store: simpy.Store = simpy.Store(env)
out_store: simpy.Store = simpy.Store(env)
comp.in_ports["src"] = in_store
comp.out_ports["next"] = out_store
comp.start(env)
done_event = env.event()
txn = Transaction(
request=None,
path=["src", "test.node", "next"],
step=1, # currently at test.node; next_hop = "next"
nbytes=128,
done=done_event,
)
forwarded: list[Any] = []
def receiver():
msg = yield out_store.get()
forwarded.append(msg)
msg.done.succeed()
env.process(receiver())
def inject():
yield in_store.put(txn)
env.process(inject())
env.run(until=done_event)
assert len(forwarded) == 1
assert forwarded[0].step == 2 # advanced
assert env.now == pytest.approx(5.0)
# ── 6. MCpuComponent: terminal with no ctx just completes ────────────
def test_m_cpu_terminal_no_ctx_completes():
"""MCpuComponent without ctx completes txn.done when it is the terminal hop."""
node = _node("m_cpu_v1", {"overhead_ns": 0.0})
comp = MCpuComponent(node, ctx=None)
env = simpy.Environment()
done_event = env.event()
txn = Transaction(request=None, path=["test.node"], step=0, nbytes=64, done=done_event)
_run_worker(comp, env, txn)
assert done_event.triggered
# ── 7. ComponentRegistry resolves impl strings ───────────────────────
@pytest.mark.parametrize("impl,expected_cls", [
("forwarding_v1", TransitComponent),
("noc_v1", TransitComponent),
("ucie_v1", TransitComponent),
("xbar_v1", TransitComponent),
("pcie_ep_v1", PcieEpComponent),
("io_cpu_v1", IoCpuComponent),
("m_cpu_v1", MCpuComponent),
("hbm_ctrl_v1", HbmCtrlComponent),
("sram_v1", SramComponent),
])
def test_registry_resolves_impl(impl, expected_cls):
"""ComponentRegistry.create() returns the correct concrete class for each impl."""
node = _node(impl, {"overhead_ns": 0.0})
comp = ComponentRegistry.create(node)
assert isinstance(comp, expected_cls)
# ── 8. GraphEngine passes ComponentContext to components ─────────────
def test_engine_passes_ctx_to_components():
"""GraphEngine injects a non-None ComponentContext into every component."""
graph = load_topology(TOPOLOGY_PATH)
engine = GraphEngine(graph)
for node_id, comp in engine._components.items():
assert comp.ctx is not None, f"{node_id}: ctx is None"
assert isinstance(comp.ctx, ComponentContext), f"{node_id}: ctx wrong type"
def test_engine_ctx_router_and_resolver_populated():
"""ComponentContext.router and .resolver are PathRouter / AddressResolver instances."""
from kernbench.policy.routing.router import AddressResolver, PathRouter
graph = load_topology(TOPOLOGY_PATH)
engine = GraphEngine(graph)
# Spot-check one component
first_comp = next(iter(engine._components.values()))
assert isinstance(first_comp.ctx.router, PathRouter)
assert isinstance(first_comp.ctx.resolver, AddressResolver)
+268
View File
@@ -0,0 +1,268 @@
import pytest
from kernbench.policy.address.allocator import AddressConfig, AllocationError, PEMemAllocator
from kernbench.policy.address.phyaddr import PhysAddr, PhysAddrError, UnitType
_MB = 1 << 20
_GB = 1 << 30
# Topology-matching config: 48GB HBM / 8 slices / 16MB TCM / 4MB reserved / 32MB SRAM
_CFG = AddressConfig(
sip_count=2,
cubes_per_sip=16,
pes_per_cube=8,
hbm_bytes_per_cube=48 * _GB,
hbm_slices_per_cube=8,
tcm_bytes_per_pe=16 * _MB,
tcm_scheduler_reserved_bytes=4 * _MB,
sram_bytes_per_cube=32 * _MB,
)
# ── Immutability & value semantics ──────────────────────────────────
def test_physaddr_immutable():
pa = PhysAddr.hbm_addr(rack_id=0, sip_id=0, cube_id=0, hbm_offset=0)
with pytest.raises(AttributeError):
pa.rack_id = 1 # type: ignore[misc]
# hashable
{pa}
# comparable
pa2 = PhysAddr.hbm_addr(rack_id=0, sip_id=0, cube_id=0, hbm_offset=0)
assert pa == pa2
# ── HBM encode/decode roundtrip ────────────────────────────────────
def test_hbm_encode_decode_roundtrip():
pa = PhysAddr.hbm_addr(rack_id=2, sip_id=3, cube_id=5, hbm_offset=0x1000)
raw = pa.encode()
dec = PhysAddr.decode(raw)
assert dec.rack_id == 2
assert dec.sip_id == 3
assert dec.cube_id == 5
assert dec.kind == "hbm"
assert dec.hbm_offset == 0x1000
# ── PE resource encode/decode roundtrip ─────────────────────────────
def test_pe_resource_encode_decode_roundtrip():
pa = PhysAddr(
rack_id=1, sip_id=2, sip_seg=7, local_offset=0,
kind="pe_resource", cube_id=7,
unit_type=UnitType.PE, pe_id=3, ext=1, sub_offset=0xFF,
)
# manually build local_offset matching bit layout
local_offset = (UnitType.PE << 34) | (3 << 30) | (1 << 29) | 0xFF
pa2 = PhysAddr(
rack_id=1, sip_id=2, sip_seg=7, local_offset=local_offset,
kind="pe_resource", cube_id=7,
unit_type=UnitType.PE, pe_id=3, ext=1, sub_offset=0xFF,
)
raw = pa2.encode()
dec = PhysAddr.decode(raw)
assert dec.kind == "pe_resource"
assert dec.unit_type == UnitType.PE
assert dec.pe_id == 3
assert dec.ext == 1
assert dec.sub_offset == 0xFF
# ── pe_hbm_addr factory ────────────────────────────────────────────
def test_pe_hbm_addr_factory():
SLICE = 6 * (1 << 30) # 6 GB per PE slice
pa = PhysAddr.pe_hbm_addr(
rack_id=0, sip_id=0, cube_id=0,
pe_id=2, pe_local_hbm_offset=1024, slice_size_bytes=SLICE,
)
assert pa.kind == "hbm"
assert pa.cube_id == 0
assert pa.hbm_offset == 2 * SLICE + 1024
def test_pe_hbm_addr_overflow():
SLICE = 6 * (1 << 30)
with pytest.raises(PhysAddrError, match="pe_local_hbm_offset"):
PhysAddr.pe_hbm_addr(
rack_id=0, sip_id=0, cube_id=0,
pe_id=0, pe_local_hbm_offset=SLICE, slice_size_bytes=SLICE,
)
# ── Invalid unit_type decode (fix #1) ──────────────────────────────
def test_invalid_unit_type_raises():
# Craft a PE-resource address with unit_type=7 (invalid)
local_offset = (7 << 34) | (0 << 30) | 0
pa_raw = PhysAddr(
rack_id=0, sip_id=0, sip_seg=0, local_offset=local_offset,
)
raw = pa_raw.encode()
with pytest.raises(PhysAddrError, match="unit_type"):
PhysAddr.decode(raw)
# ── hbm_pe_id utility (fix #3) ─────────────────────────────────────
def test_hbm_pe_id_utility():
SLICE = 6 * (1 << 30) # 6 GB
pa = PhysAddr.pe_hbm_addr(
rack_id=0, sip_id=0, cube_id=0,
pe_id=5, pe_local_hbm_offset=256, slice_size_bytes=SLICE,
)
assert PhysAddr.hbm_pe_id(pa.hbm_offset, SLICE) == 5
# ── UnitType.SRAM exists (fix #5) ──────────────────────────────────
def test_sram_unit_type_exists():
assert UnitType.SRAM == 2
# ── cube_sram_addr factory + roundtrip ──────────────────────────────
def test_cube_sram_addr_roundtrip():
pa = PhysAddr.cube_sram_addr(
rack_id=0, sip_id=1, cube_id=3, sram_offset=0x800,
)
assert pa.kind == "pe_resource"
assert pa.unit_type == UnitType.SRAM
assert pa.cube_id == 3
assert pa.sub_offset == 0x800
# encode → decode roundtrip
dec = PhysAddr.decode(pa.encode())
assert dec.unit_type == UnitType.SRAM
assert dec.cube_id == 3
assert dec.sub_offset == 0x800
def test_cube_sram_addr_range_check():
with pytest.raises(PhysAddrError):
PhysAddr.cube_sram_addr(
rack_id=0, sip_id=0, cube_id=0,
sram_offset=(1 << 29), # exceeds 29-bit sub_offset
)
# ── pe_tcm_addr factory + roundtrip ────────────────────────────────
def test_pe_tcm_addr_roundtrip():
pa = PhysAddr.pe_tcm_addr(
rack_id=0, sip_id=0, cube_id=2, pe_id=7, tcm_offset=0x400,
)
assert pa.kind == "pe_resource"
assert pa.unit_type == UnitType.PE
assert pa.pe_id == 7
assert pa.cube_id == 2
assert pa.sub_offset == 0x400
# encode → decode roundtrip
dec = PhysAddr.decode(pa.encode())
assert dec.unit_type == UnitType.PE
assert dec.pe_id == 7
assert dec.sub_offset == 0x400
def test_pe_tcm_addr_range_check():
with pytest.raises(PhysAddrError):
PhysAddr.pe_tcm_addr(
rack_id=0, sip_id=0, cube_id=0, pe_id=0,
tcm_offset=(1 << 29), # exceeds 29-bit sub_offset
)
# ── AddressConfig ───────────────────────────────────────────────────
def test_address_config_derived_sizes():
assert _CFG.hbm_slice_bytes == 6 * _GB
assert _CFG.tcm_allocatable_bytes == 12 * _MB
# ── PEMemAllocator: HBM ────────────────────────────────────────────
def _make_alloc(pe_id: int = 0) -> PEMemAllocator:
return PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=pe_id, cfg=_CFG)
def test_allocator_hbm_basic():
a = _make_alloc(pe_id=3)
pa = a.alloc_hbm(4096)
assert pa.kind == "hbm"
assert pa.sip_id == 0
assert pa.cube_id == 0
# hbm_offset should be pe3's slice start
assert pa.hbm_offset == 3 * 6 * _GB
def test_allocator_hbm_sequential():
a = _make_alloc()
pa1 = a.alloc_hbm(1024)
pa2 = a.alloc_hbm(2048)
assert pa1.hbm_offset == 0 # pe0 slice start + 0
assert pa2.hbm_offset == 1024 # pe0 slice start + 1024
def test_allocator_hbm_overflow():
a = _make_alloc()
a.alloc_hbm(6 * _GB - 256)
with pytest.raises(AllocationError, match="HBM"):
a.alloc_hbm(512)
# ── PEMemAllocator: TCM ────────────────────────────────────────────
def test_allocator_tcm_basic():
a = _make_alloc(pe_id=5)
pa = a.alloc_tcm(256)
assert pa.kind == "pe_resource"
assert pa.unit_type == UnitType.PE
assert pa.pe_id == 5
assert pa.sub_offset == 0
def test_allocator_tcm_respects_reserved():
a = _make_alloc()
# allocatable = 12 MB, should succeed
a.alloc_tcm(12 * _MB)
assert a.tcm_used == 12 * _MB
assert a.tcm_total == 12 * _MB
def test_allocator_tcm_overflow():
a = _make_alloc()
a.alloc_tcm(12 * _MB)
with pytest.raises(AllocationError, match="TCM"):
a.alloc_tcm(1)
# ── PEMemAllocator: stats & determinism ─────────────────────────────
def test_allocator_stats():
a = _make_alloc()
a.alloc_hbm(1000)
a.alloc_tcm(500)
assert a.hbm_used == 1000
assert a.hbm_total == 6 * _GB
assert a.tcm_used == 500
assert a.tcm_total == 12 * _MB
def test_allocator_deterministic():
a1 = _make_alloc(pe_id=2)
a2 = _make_alloc(pe_id=2)
assert a1.alloc_hbm(4096) == a2.alloc_hbm(4096)
assert a1.alloc_tcm(256) == a2.alloc_tcm(256)
+221
View File
@@ -0,0 +1,221 @@
"""Tests for H2D writes and PE DMA probe latency invariants.
H2D tests use MemoryWriteMsg (pcie_ep → io_cpu → m_cpu → hbm_ctrl → response).
PE DMA tests use PeDmaMsg (direct pe_dma → xbar → hbm_ctrl injection).
"""
from pathlib import Path
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.policy.routing.router import AddressResolver, PathRouter
from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import load_topology
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
def _engine():
return GraphEngine(load_topology(TOPOLOGY_PATH))
def _hbm_pa(sip: int = 0, cube: int = 0, pe_id: int = 0) -> int:
slice_bytes = 48 * (1 << 30) // 8
pa = PhysAddr.pe_hbm_addr(
rack_id=0, sip_id=sip, cube_id=cube, pe_id=pe_id,
pe_local_hbm_offset=0x1000, slice_size_bytes=slice_bytes,
)
return pa.encode()
def _h2d_latency(dst_cube: int, dst_pe: int = 0) -> float:
engine = _engine()
msg = MemoryWriteMsg(
correlation_id="probe", request_id=f"h2d-c{dst_cube}-p{dst_pe}",
dst_sip=0, dst_cube=dst_cube, dst_pe=dst_pe,
dst_pa=_hbm_pa(sip=0, cube=dst_cube, pe_id=dst_pe), nbytes=4096,
pattern="zero", target_pe=dst_pe,
)
h = engine.submit(msg)
engine.wait(h)
_, trace = engine.get_completion(h)
return trace["total_ns"]
# ── 1. Single-PE write completes ──────────────────────────────────
def test_single_pe_write_completes():
"""MemoryWriteMsg(target_pe=0) must complete with ok=True, latency > 0."""
engine = _engine()
msg = MemoryWriteMsg(
correlation_id="probe", request_id="pe-local",
dst_sip=0, dst_cube=0, dst_pe=0,
dst_pa=_hbm_pa(sip=0, cube=0, pe_id=0), nbytes=4096,
pattern="zero", target_pe=0,
)
h = engine.submit(msg)
engine.wait(h)
comp, trace = engine.get_completion(h)
assert comp.ok is True
assert trace["total_ns"] > 0
# ── 2. Cross-cube write positive latency ─────────────────────────
def test_cross_cube_write_positive():
"""Cross-cube MemoryWriteMsg(target_pe=0) must complete with latency > 0."""
lat = _h2d_latency(dst_cube=1, dst_pe=0)
assert lat > 0
# ── 3. H2D latency monotonicity ──────────────────────────────────
def test_h2d_latency_monotonic():
"""1hop < 2hop < 3hop < 4hop."""
cubes = [0, 4, 8, 12]
latencies: list[tuple[int, float]] = []
for cube in cubes:
lat = _h2d_latency(dst_cube=cube, dst_pe=0)
latencies.append((cube, lat))
for i in range(len(latencies) - 1):
assert latencies[i][1] < latencies[i + 1][1], (
f"cube{latencies[i][0]}({latencies[i][1]:.2f}) "
f"must < cube{latencies[i + 1][0]}({latencies[i + 1][1]:.2f})"
)
# ── 4. Single-PE write deterministic ─────────────────────────────
def test_single_pe_write_deterministic():
"""Same MemoryWriteMsg on two engines must produce identical latency."""
msg = MemoryWriteMsg(
correlation_id="probe", request_id="det",
dst_sip=0, dst_cube=0, dst_pe=0,
dst_pa=_hbm_pa(sip=0, cube=0, pe_id=0), nbytes=4096,
pattern="zero", target_pe=0,
)
e1 = _engine()
h1 = e1.submit(msg)
e1.wait(h1)
_, t1 = e1.get_completion(h1)
e2 = _engine()
h2 = e2.submit(msg)
e2.wait(h2)
_, t2 = e2.get_completion(h2)
assert t1["total_ns"] == t2["total_ns"]
# ── 5. Cut-through (wormhole) wire model invariants ──────────────
def test_h2d_local_cube_cut_through():
"""H2D to local cube with cut-through should be < 50ns for 4096B.
Full command path: pcie_ep → io_cpu → ucie → noc → m_cpu
DMA: m_cpu → noc → xbar → hbm_ctrl (drain once at terminal)
Plus response path back.
With store-and-forward each hop would serialize; cut-through keeps it low.
"""
lat = _h2d_latency(dst_cube=0, dst_pe=0)
assert lat < 65.0, f"Local H2D {lat:.2f}ns; cut-through expects < 65ns"
def test_h2d_remote_cube_cut_through():
"""H2D to 1-hop remote cube: cut-through drain dominates, not per-hop serialization.
With store-and-forward, each hop would serialize 4096B, total >> 100ns.
With cut-through, drain happens once at bottleneck.
"""
lat = _h2d_latency(dst_cube=4, dst_pe=0)
assert lat < 80.0, f"Remote H2D {lat:.2f}ns; cut-through expects < 80ns"
# ── 6. PE DMA: direct injection tests ─────────────────────────
def _graph():
return load_topology(TOPOLOGY_PATH)
def _pe_dma_latency(src_cube: int, src_pe: int, dst_pe: int) -> float:
engine = _engine()
msg = PeDmaMsg(
correlation_id="probe", request_id=f"dma-c{src_cube}-p{src_pe}-s{dst_pe}",
src_sip=0, src_cube=src_cube, src_pe=src_pe,
dst_pa=_hbm_pa(sip=0, cube=src_cube, pe_id=dst_pe), nbytes=4096,
)
h = engine.submit(msg)
engine.wait(h)
_, trace = engine.get_completion(h)
return trace["total_ns"]
def _pe_dma_bottleneck(src_cube: int, src_pe: int, dst_pe: int) -> float | None:
graph = _graph()
edge_map = {(e.src, e.dst): e for e in graph.edges}
resolver = AddressResolver(graph)
router = PathRouter(graph)
pa = _hbm_pa(sip=0, cube=src_cube, pe_id=dst_pe)
pa_obj = PhysAddr.decode(pa)
dst_node = resolver.resolve(pa_obj)
pe_ref = f"sip0.cube{src_cube}.pe{src_pe}"
path = router.find_path(pe_ref, dst_node)
bws: list[float] = []
for i in range(len(path) - 1):
e = edge_map.get((path[i], path[i + 1]))
if e and e.bw_gbs:
bws.append(e.bw_gbs)
return min(bws) if bws else None
def test_pe_dma_local_completes():
"""PeDmaMsg to local slice0 must complete with ok=True, latency > 0."""
engine = _engine()
msg = PeDmaMsg(
correlation_id="probe", request_id="dma-local",
src_sip=0, src_cube=0, src_pe=0,
dst_pa=_hbm_pa(sip=0, cube=0, pe_id=0), nbytes=4096,
)
h = engine.submit(msg)
engine.wait(h)
comp, trace = engine.get_completion(h)
assert comp.ok is True
assert trace["total_ns"] > 0
def test_pe_dma_local_bottleneck_256():
"""PE DMA pe0→slice0 (local): bottleneck = 256 GB/s (direct xbar→hbm)."""
bn = _pe_dma_bottleneck(src_cube=0, src_pe=0, dst_pe=0)
assert bn == 256.0, f"Local PE DMA bottleneck {bn}, expected 256.0"
def test_pe_dma_chain_bottleneck_128():
"""PE DMA pe0→slice1 (xbar chain): bottleneck = 128 GB/s."""
bn = _pe_dma_bottleneck(src_cube=0, src_pe=0, dst_pe=1)
assert bn == 128.0, f"Chain PE DMA bottleneck {bn}, expected 128.0"
def test_pe_dma_deterministic():
"""Same PeDmaMsg on two engines must produce identical latency."""
msg = PeDmaMsg(
correlation_id="probe", request_id="det",
src_sip=0, src_cube=0, src_pe=0,
dst_pa=_hbm_pa(sip=0, cube=0, pe_id=0), nbytes=4096,
)
e1 = _engine()
h1 = e1.submit(msg)
e1.wait(h1)
_, t1 = e1.get_completion(h1)
e2 = _engine()
h2 = e2.submit(msg)
e2.wait(h2)
_, t2 = e2.get_completion(h2)
assert t1["total_ns"] == t2["total_ns"]
+226
View File
@@ -0,0 +1,226 @@
import pytest
from pathlib import Path
from kernbench.policy.address.phyaddr import PhysAddr, UnitType
from kernbench.policy.routing.router import AddressResolver, PathRouter, RoutingError
from kernbench.topology.builder import load_topology
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
def _graph():
return load_topology(TOPOLOGY_PATH)
# ── AddressResolver ──────────────────────────────────────────────────
def test_resolve_hbm_addr():
"""HBM address -> sip{S}.cube{C}.hbm_ctrl.slice{P}"""
g = _graph()
resolver = AddressResolver(g)
# hbm_offset=0x1000, slice_size=6GB -> slice 0
pa = PhysAddr.hbm_addr(rack_id=0, sip_id=0, cube_id=3, hbm_offset=0x1000)
assert resolver.resolve(pa) == "sip0.cube3.hbm_ctrl.slice0"
def test_resolve_hbm_addr_slice4():
"""HBM address in PE4's slice range -> slice4."""
g = _graph()
resolver = AddressResolver(g)
# slice_size = 6GB; PE4 offset starts at 4*6GB = 24GB = 0x600000000
pa = PhysAddr.hbm_addr(rack_id=0, sip_id=0, cube_id=0, hbm_offset=0x600000000)
assert resolver.resolve(pa) == "sip0.cube0.hbm_ctrl.slice4"
def test_resolve_pe_tcm_addr():
"""PE TCM address → sip{S}.cube{C}.pe{P}.pe_tcm"""
g = _graph()
resolver = AddressResolver(g)
pa = PhysAddr.pe_tcm_addr(rack_id=0, sip_id=1, cube_id=5, pe_id=7, tcm_offset=0x400)
assert resolver.resolve(pa) == "sip1.cube5.pe7.pe_tcm"
def test_resolve_sram_addr():
"""SRAM address → sip{S}.cube{C}.sram"""
g = _graph()
resolver = AddressResolver(g)
pa = PhysAddr.cube_sram_addr(rack_id=0, sip_id=0, cube_id=10, sram_offset=0x800)
assert resolver.resolve(pa) == "sip0.cube10.sram"
def test_resolve_mcpu_addr():
"""MCPU pe_resource address → sip{S}.cube{C}.m_cpu"""
g = _graph()
resolver = AddressResolver(g)
pa = PhysAddr(
rack_id=0, sip_id=0, sip_seg=2, local_offset=(UnitType.MCPU << 34),
kind="pe_resource", cube_id=2, unit_type=UnitType.MCPU,
)
assert resolver.resolve(pa) == "sip0.cube2.m_cpu"
def test_resolve_nonexistent_node():
"""Address pointing to a node outside the compiled topology raises RoutingError."""
g = _graph()
resolver = AddressResolver(g)
# sip_id=15 doesn't exist in the 2-SIP topology
pa = PhysAddr.hbm_addr(rack_id=0, sip_id=15, cube_id=0, hbm_offset=0)
with pytest.raises(RoutingError):
resolver.resolve(pa)
# ── PathRouter: local HBM (same xbar half) ──────────────────────────
def test_path_local_hbm_same_half():
"""PE0 -> slice0 (local): pe_dma -> xbar.pe0 -> hbm_ctrl.slice0 (no chain hops)."""
g = _graph()
router = PathRouter(g)
path = router.find_path("sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice0")
assert path[0] == "sip0.cube0.pe0.pe_dma"
assert "sip0.cube0.xbar.pe0" in path
assert path[-1] == "sip0.cube0.hbm_ctrl.slice0"
# local access: no bridge and no chain traversal (shortest path = 3 nodes)
assert not any("bridge" in n for n in path)
assert len(path) == 3 # pe_dma → xbar.pe0 → slice0
# ── PathRouter: same-half remote HBM ────────────────────────────────
def test_path_same_half_remote_hbm():
"""PE0 -> slice1: same-half chain traversal pe0→pe1, no bridge."""
g = _graph()
router = PathRouter(g)
path = router.find_path("sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice1")
assert path[0] == "sip0.cube0.pe0.pe_dma"
assert "sip0.cube0.xbar.pe0" in path # enter at pe0
assert "sip0.cube0.xbar.pe1" in path # chain hop to pe1
assert path[-1] == "sip0.cube0.hbm_ctrl.slice1"
assert not any("bridge" in n for n in path)
assert len(path) == 4 # pe_dma → xbar.pe0 → xbar.pe1 → slice1
# ── PathRouter: cross-half HBM ──────────────────────────────────────
def test_path_cross_half_hbm():
"""PE0 -> slice4 (cross-half): pe_dma → xbar.pe0 → bridge.left → xbar.pe4 → slice4."""
g = _graph()
router = PathRouter(g)
path = router.find_path("sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice4")
assert path[0] == "sip0.cube0.pe0.pe_dma"
assert "sip0.cube0.xbar.pe0" in path
assert any("bridge" in n for n in path), "cross-half HBM must traverse bridge"
assert "sip0.cube0.xbar.pe4" in path
assert path[-1] == "sip0.cube0.hbm_ctrl.slice4"
# Shortest cross-half path: pe_dma → xbar.pe0 → bridge.left → xbar.pe4 → slice4
assert len(path) == 5
def test_path_cross_half_requires_bridge():
"""PE4 (bottom) -> slice2 (top) requires bridge traversal."""
g = _graph()
router = PathRouter(g)
path = router.find_path("sip0.cube0.pe4", "sip0.cube0.hbm_ctrl.slice2")
assert any("bridge" in n for n in path), "cross-half HBM must traverse bridge"
assert any("xbar.pe" in n for n in path)
assert path[-1] == "sip0.cube0.hbm_ctrl.slice2"
def test_cross_half_distance_greater():
"""Cross-half HBM access must have greater distance than local-half."""
g = _graph()
router = PathRouter(g)
_, dist_local = router.find_path_with_distance(
"sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice0")
_, dist_cross = router.find_path_with_distance(
"sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice4")
assert dist_cross > dist_local
def test_path_same_half_remote_longer():
"""Same-half remote HBM (PE0->slice3) has greater distance than local (PE0->slice0)."""
g = _graph()
router = PathRouter(g)
_, dist_local = router.find_path_with_distance(
"sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice0")
_, dist_remote = router.find_path_with_distance(
"sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice3")
assert dist_remote > dist_local, (
f"same-half remote ({dist_remote:.2f}mm) must > local ({dist_local:.2f}mm)"
)
def test_path_remote_cube_hbm():
"""PE0 in cube0 can reach HBM in cube1 via UCIe (ADR-0004 D4)."""
g = _graph()
router = PathRouter(g)
path = router.find_path("sip0.cube0.pe0", "sip0.cube1.hbm_ctrl.slice0")
assert path[0] == "sip0.cube0.pe0.pe_dma"
assert path[-1] == "sip0.cube1.hbm_ctrl.slice0"
# inter-cube path must cross a UCIe link
assert any("ucie" in n for n in path), "remote cube path must traverse UCIe"
# must not be trivially short (needs noc + ucie + remote noc + xbar)
assert len(path) >= 5
# ── PathRouter: SRAM via NOC ────────────────────────────────────────
def test_path_sram_via_noc():
"""PE → SRAM must go through NOC (non-HBM data path)."""
g = _graph()
router = PathRouter(g)
path = router.find_path("sip0.cube0.pe0", "sip0.cube0.sram")
assert path[0] == "sip0.cube0.pe0.pe_dma"
assert "sip0.cube0.noc" in path
assert path[-1] == "sip0.cube0.sram"
# should NOT go through xbar (SRAM is non-HBM path)
assert not any("xbar" in n for n in path)
# ── PathRouter: PE TCM (local) ──────────────────────────────────────
def test_path_local_tcm():
"""PE0 → own TCM is PE-internal, not via xbar or noc."""
g = _graph()
router = PathRouter(g)
path = router.find_path("sip0.cube0.pe0", "sip0.cube0.pe0.pe_tcm")
assert path[0] == "sip0.cube0.pe0.pe_dma"
assert path[-1] == "sip0.cube0.pe0.pe_tcm"
# PE-internal path, no fabric
assert not any("xbar" in n or "noc" in n for n in path)
# ── PathRouter: distance monotonic ──────────────────────────────────
def test_path_distance_positive():
"""All routed paths must have accumulated distance > 0 (ADR-0002 D4)."""
g = _graph()
router = PathRouter(g)
_, dist = router.find_path_with_distance("sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice0")
assert dist > 0
def test_path_deterministic():
"""Same (src, dst) must always produce the same path."""
g = _graph()
r1 = PathRouter(g)
r2 = PathRouter(g)
p1 = r1.find_path("sip0.cube0.pe3", "sip0.cube0.hbm_ctrl.slice3")
p2 = r2.find_path("sip0.cube0.pe3", "sip0.cube0.hbm_ctrl.slice3")
assert p1 == p2
def test_remote_cube_path_no_routing_error():
"""Routing to remote cube HBM must not raise RoutingError (ADR-0004 D4)."""
g = _graph()
router = PathRouter(g)
# cube0.PE0 -> cube1.slice0 (adjacent cube, E direction)
path = router.find_path("sip0.cube0.pe0", "sip0.cube1.hbm_ctrl.slice0")
assert len(path) >= 1 # succeeds without exception
+282
View File
@@ -0,0 +1,282 @@
import pytest
from kernbench.policy.address.allocator import AddressConfig, AllocationError, PEMemAllocator
from kernbench.policy.placement.dp import (
ShardSpec,
column_wise,
tiled_column_major,
replicate,
row_wise,
tiled_row_major,
)
from kernbench.runtime_api.kernel import (
KernelLaunchMsg,
KernelRef,
MemoryReadMsg,
MemoryWriteMsg,
ScalarArg,
TensorArg,
TensorArgShard,
)
from kernbench.runtime_api.tensor import (
TensorHandle,
TensorShard,
deploy_tensor,
dtype_itemsize,
)
_MB = 1 << 20
_GB = 1 << 30
_CFG = AddressConfig(
sip_count=2,
cubes_per_sip=16,
pes_per_cube=8,
hbm_bytes_per_cube=48 * _GB,
hbm_slices_per_cube=8,
tcm_bytes_per_pe=16 * _MB,
tcm_scheduler_reserved_bytes=4 * _MB,
sram_bytes_per_cube=32 * _MB,
)
def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]:
return {
i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG)
for i in range(num_pe)
}
# ── Tensor types ─────────────────────────────────────────────────────
def test_tensor_shard_immutable():
ts = TensorShard(sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0)
with pytest.raises(AttributeError):
ts.pa = 0x2000 # type: ignore[misc]
# hashable
{ts}
def test_tensor_handle_nbytes():
th = TensorHandle(
name="A",
shape=(1024, 512),
dtype="fp16",
itemsize=2,
shards=(),
)
assert th.nbytes == 1024 * 512 * 2 # 1 MB
# ── Message types (ADR-0012) ─────────────────────────────────────────
def test_memory_write_msg_fields():
msg = MemoryWriteMsg(
correlation_id="c0",
request_id="r0",
dst_sip=0,
dst_cube=3,
dst_pe=5,
dst_pa=0xDEAD,
nbytes=4096,
pattern="zero",
)
assert msg.msg_type == "memory_write"
assert msg.src_kind == "pattern"
assert msg.dst_pa == 0xDEAD
assert msg.pattern == "zero"
with pytest.raises(AttributeError):
msg.nbytes = 0 # type: ignore[misc]
def test_memory_read_msg_fields():
msg = MemoryReadMsg(
correlation_id="c0",
request_id="r1",
src_sip=1,
src_cube=2,
src_pe=7,
src_pa=0xBEEF,
nbytes=2048,
)
assert msg.msg_type == "memory_read"
assert msg.src_pa == 0xBEEF
assert msg.nbytes == 2048
def test_kernel_launch_msg_fields():
shard = TensorArgShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=1024, offset_bytes=0)
targ = TensorArg(shards=(shard,))
sarg = ScalarArg(dtype="fp32", value=1.0)
kref = KernelRef(name="gemm", kind="builtin")
msg = KernelLaunchMsg(
correlation_id="c0",
request_id="r2",
kernel_ref=kref,
args=(targ, sarg),
)
assert msg.msg_type == "kernel_launch"
assert msg.kernel_ref.name == "gemm"
assert len(msg.args) == 2
assert msg.args[0].arg_kind == "tensor"
assert msg.args[1].arg_kind == "scalar"
# ── Placement: column_wise ───────────────────────────────────────────
def test_column_wise_placement():
"""(1024, 512) fp16 across 8 PEs → K axis split → 8 shards, each (1024, 64) = 128KB"""
shards = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
assert len(shards) == 8
expected_nbytes = 1024 * 64 * 2 # 128 KB
for i, s in enumerate(shards):
assert s.pe_index == i
assert s.nbytes == expected_nbytes
# offsets are contiguous
assert shards[0].offset_bytes == 0
assert shards[1].offset_bytes == expected_nbytes
# total coverage
assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
# ── Placement: row_wise ──────────────────────────────────────────────
def test_row_wise_placement():
"""(1024, 512) fp16 across 8 PEs → M axis split → 8 shards, each (128, 512) = 128KB"""
shards = row_wise(shape=(1024, 512), itemsize=2, num_pe=8)
assert len(shards) == 8
expected_nbytes = 128 * 512 * 2 # 128 KB
for i, s in enumerate(shards):
assert s.pe_index == i
assert s.nbytes == expected_nbytes
assert shards[0].offset_bytes == 0
assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
# ── Placement: replicate ─────────────────────────────────────────────
def test_replicate_placement():
"""(1024, 512) fp16 across 8 PEs → each PE gets full copy = 1MB"""
shards = replicate(shape=(1024, 512), itemsize=2, num_pe=8)
assert len(shards) == 8
full_nbytes = 1024 * 512 * 2 # 1 MB
for i, s in enumerate(shards):
assert s.pe_index == i
assert s.nbytes == full_nbytes
assert s.offset_bytes == 0 # each is a full copy
# ── Placement: tiled_column_major ─────────────────────────────────────
def test_tiled_column_major():
"""(1024, 512) tile=(256, 128) → 4×4=16 tiles, column-major → round-robin 8 PEs"""
shards = tiled_column_major(
shape=(1024, 512), itemsize=2, num_pe=8, tile_m=256, tile_k=128,
)
# 4 tiles along M, 4 tiles along K → 16 tiles total
assert len(shards) == 16
tile_bytes = 256 * 128 * 2 # 64 KB per tile
for s in shards:
assert s.nbytes == tile_bytes
# column-major: iterate K first, then M
# tile (m=0,k=0) → PE0, tile (m=0,k=1) → PE1, ..., (m=0,k=3) → PE3
# tile (m=1,k=0) → PE4, tile (m=1,k=1) → PE5, ..., (m=1,k=3) → PE7
# tile (m=2,k=0) → PE0, ...
assert shards[0].pe_index == 0
assert shards[1].pe_index == 1
assert shards[7].pe_index == 7
assert shards[8].pe_index == 0 # wraps around
# total coverage
assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
# ── Placement: tiled_row_major ────────────────────────────────────────
def test_tiled_row_major():
"""(1024, 512) tile=(256, 128) → 4×4=16 tiles, row-major → round-robin 8 PEs"""
shards = tiled_row_major(
shape=(1024, 512), itemsize=2, num_pe=8, tile_m=256, tile_k=128,
)
assert len(shards) == 16
tile_bytes = 256 * 128 * 2
for s in shards:
assert s.nbytes == tile_bytes
# row-major: iterate M first, then K
# tile (m=0,k=0) → PE0, tile (m=1,k=0) → PE1, ..., (m=3,k=0) → PE3
# tile (m=0,k=1) → PE4, tile (m=1,k=1) → PE5, ..., (m=3,k=1) → PE7
# tile (m=0,k=2) → PE0, ...
assert shards[0].pe_index == 0
assert shards[1].pe_index == 1
assert shards[7].pe_index == 7
assert shards[8].pe_index == 0 # wraps around
# total coverage
assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
# ── deploy_tensor ────────────────────────────────────────────────────
def test_deploy_tensor_hbm():
"""Deploy with column_wise placement → TensorHandle with valid PA shards."""
allocs = _make_allocators()
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
th = deploy_tensor(
name="W",
shape=(1024, 512),
dtype="fp16",
placement=placement,
allocators=allocs,
mem_kind="hbm",
)
assert th.name == "W"
assert th.shape == (1024, 512)
assert th.dtype == "fp16"
assert th.itemsize == 2
assert len(th.shards) == 8
# each shard has a distinct PA
pas = [s.pa for s in th.shards]
assert len(set(pas)) == 8
# each shard placed on correct PE
for i, s in enumerate(th.shards):
assert s.pe == i
assert s.sip == 0
assert s.cube == 0
def test_deploy_tensor_tcm():
"""Deploy with TCM → uses pe_tcm_addr allocation."""
allocs = _make_allocators()
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=256)]
th = deploy_tensor(
name="small",
shape=(128,),
dtype="fp16",
placement=placement,
allocators=allocs,
mem_kind="tcm",
)
assert len(th.shards) == 1
assert th.shards[0].pe == 0
assert th.shards[0].nbytes == 256
def test_deploy_tensor_overflow():
"""Allocation exceeding PE HBM capacity raises AllocationError."""
allocs = _make_allocators()
# 6 GB per PE slice, try to allocate 7 GB
big_shard = ShardSpec(pe_index=0, offset_bytes=0, nbytes=7 * _GB)
with pytest.raises(AllocationError):
deploy_tensor(
name="toobig",
shape=(1,),
dtype="int8",
placement=[big_shard],
allocators=allocs,
)
+409
View File
@@ -0,0 +1,409 @@
from pathlib import Path
from kernbench.topology.builder import load_topology
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
def _graph():
return load_topology(TOPOLOGY_PATH)
# ── Full graph: node counts ──────────────────────────────────────────
def test_full_graph_node_count():
g = _graph()
# 1 switch
# + 2 SIPs × (1 IO × 2 comps + 16 cubes × (cube_comps + 8 PEs × 6 pe_comps))
# cube_comps: 9 (noc, m_cpu, sram, 2 bridge, 4 ucie)
# + 8 xbar.pe{0..7} [replaced xbar.top/xbar.bottom]
# + 8 hbm_slices = 25
# = 1 + 2*(2 + 16*(25+48)) = 1 + 2*(2+1168) = 1 + 2340 = 2341
assert len(g.nodes) == 2341
def test_full_graph_edge_count():
g = _graph()
# Per cube: 144 (88 cube-fabric + 56 PE-internal)
# cube-fabric: 8 pe→xbar.pe + 8 pe→noc + 8 noc→pe_cpu
# + 8 xbar.pe→slice + 8 slice→xbar.pe (bidirectional for response)
# + 12 xbar chain (3 pairs × 2 dir × 2 halves)
# + 8 xbar.pe↔bridge (pe0↔bL, pe4↔bL, pe3↔bR, pe7↔bR, ×2 dir each)
# + 4 noc→ucie + 4 ucie→noc (bidirectional)
# + 8 noc→xbar.pe + 8 xbar.pe→noc (bidirectional for response)
# + 1 m_cpu→noc + 1 noc→m_cpu + 1 noc→sram + 1 sram→noc = 88
# Per SIP: 16*144 + 48 inter-cube(bidirectional) + 8 io↔cube(bidirectional)
# + 1 io_internal + 1 switch→io = 2362
# Total: 2 * 2362 = 4724
assert len(g.edges) == 4724
# ── Full graph: specific nodes exist ─────────────────────────────────
def test_system_switch_exists():
g = _graph()
assert "fabric.switch0" in g.nodes
assert g.nodes["fabric.switch0"].kind == "switch"
assert g.nodes["fabric.switch0"].pos_mm is None # abstract
def test_io_chiplet_nodes_exist():
g = _graph()
for s in range(2):
assert f"sip{s}.io0.pcie_ep" in g.nodes
assert f"sip{s}.io0.io_cpu" in g.nodes
def test_cube_component_nodes_exist():
g = _graph()
cp = "sip0.cube0"
for name in ("noc", "m_cpu",
"bridge.left", "bridge.right",
"ucie-N", "ucie-S", "ucie-E", "ucie-W",
"sram"):
assert f"{cp}.{name}" in g.nodes
# xbar.top/xbar.bottom replaced by per-PE xbar entry nodes
assert "sip0.cube0.xbar.top" not in g.nodes
assert "sip0.cube0.xbar.bottom" not in g.nodes
for pe in range(8):
node_id = f"{cp}.xbar.pe{pe}"
assert node_id in g.nodes, f"{node_id} missing"
assert g.nodes[node_id].kind == "xbar"
# HBM slices (one per PE)
for s in range(8):
assert f"{cp}.hbm_ctrl.slice{s}" in g.nodes
assert g.nodes[f"{cp}.hbm_ctrl.slice{s}"].kind == "hbm_ctrl"
def test_pe_component_nodes_exist():
g = _graph()
for comp in ("pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm"):
assert f"sip0.cube0.pe0.{comp}" in g.nodes
assert f"sip1.cube15.pe7.{comp}" in g.nodes
# ── Full graph: positions ────────────────────────────────────────────
def test_hbm_ctrl_slices_at_cube_center():
g = _graph()
# cube0 origin = (0, 0), cx=8.5, cy=7.0, hbm_ctrl at (cx-2, cy)
# all slices share the same physical position
for s in range(8):
node = g.nodes[f"sip0.cube0.hbm_ctrl.slice{s}"]
assert node.pos_mm == (6.5, 7.0)
def test_hbm_ctrl_slices_cube5_position():
g = _graph()
# cube5 = col=1, row=1 -> origin = (1*18, 1*15) = (18, 15)
# hbm_ctrl = (18 + 6.5, 15 + 7.0) = (24.5, 22.0)
node = g.nodes["sip0.cube5.hbm_ctrl.slice0"]
assert node.pos_mm == (24.5, 22.0)
def test_ucie_ports_at_cube_edges():
g = _graph()
# cube0 origin = (0, 0), cube_w=17, cube_h=14
# UCIe nodes inset by half-size so edges touch boundary
assert g.nodes["sip0.cube0.ucie-N"].pos_mm == (8.5, 0.6)
assert g.nodes["sip0.cube0.ucie-S"].pos_mm == (8.5, 13.4)
assert g.nodes["sip0.cube0.ucie-W"].pos_mm == (1.0, 7.0)
assert g.nodes["sip0.cube0.ucie-E"].pos_mm == (16.0, 7.0)
# ── Full graph: edges ────────────────────────────────────────────────
def _edge_set(g):
return {(e.src, e.dst) for e in g.edges}
def test_inter_cube_ucie_edges():
es = _edge_set(_graph())
# cube0 (0,0) E → cube1 (1,0) W
assert ("sip0.cube0.ucie-E", "sip0.cube1.ucie-W") in es
# cube0 (0,0) S → cube4 (0,1) N
assert ("sip0.cube0.ucie-S", "sip0.cube4.ucie-N") in es
def test_io_to_cube_edges():
es = _edge_set(_graph())
# io0 connects to cubes (0,0)..(3,0) on N side
assert ("sip0.io0.io_cpu", "sip0.cube0.ucie-N") in es
assert ("sip0.io0.io_cpu", "sip0.cube3.ucie-N") in es
def test_switch_to_io_edges():
es = _edge_set(_graph())
assert ("fabric.switch0", "sip0.io0.pcie_ep") in es
assert ("fabric.switch0", "sip1.io0.pcie_ep") in es
def test_pe_to_xbar_edges():
es = _edge_set(_graph())
cp = "sip0.cube0"
# Each PE connects to its own xbar entry (per-PE chain model)
for pe in range(8):
assert (f"{cp}.pe{pe}.pe_dma", f"{cp}.xbar.pe{pe}") in es
# Old shared xbar.top/bottom edges must NOT exist
assert (f"{cp}.pe0.pe_dma", f"{cp}.xbar.top") not in es
assert (f"{cp}.pe4.pe_dma", f"{cp}.xbar.bottom") not in es
def test_command_path_m_cpu_noc_pe_cpu():
es = _edge_set(_graph())
cp = "sip0.cube0"
# m_cpu ↔ noc (bidirectional)
assert (f"{cp}.m_cpu", f"{cp}.noc") in es
assert (f"{cp}.noc", f"{cp}.m_cpu") in es
# noc → pe_cpu for each PE
assert (f"{cp}.noc", f"{cp}.pe0.pe_cpu") in es
assert (f"{cp}.noc", f"{cp}.pe7.pe_cpu") in es
def test_pe_internal_edges():
es = _edge_set(_graph())
pp = "sip0.cube0.pe0"
assert (f"{pp}.pe_cpu", f"{pp}.pe_scheduler") in es
assert (f"{pp}.pe_scheduler", f"{pp}.pe_dma") in es
assert (f"{pp}.pe_scheduler", f"{pp}.pe_gemm") in es
assert (f"{pp}.pe_scheduler", f"{pp}.pe_math") in es
assert (f"{pp}.pe_dma", f"{pp}.pe_tcm") in es
assert (f"{pp}.pe_gemm", f"{pp}.pe_tcm") in es
assert (f"{pp}.pe_math", f"{pp}.pe_tcm") in es
def test_xbar_to_hbm_slice_edges():
"""Each xbar.pe{i} connects only to its own (local) HBM slice."""
es = _edge_set(_graph())
cp = "sip0.cube0"
# xbar.pe_i -> slice_i only (local Y-direction access)
for pe in range(8):
assert (f"{cp}.xbar.pe{pe}", f"{cp}.hbm_ctrl.slice{pe}") in es
# Negative: xbar.pe_i must NOT directly connect to a different slice
assert (f"{cp}.xbar.pe0", f"{cp}.hbm_ctrl.slice1") not in es
assert (f"{cp}.xbar.pe0", f"{cp}.hbm_ctrl.slice4") not in es
assert (f"{cp}.xbar.pe4", f"{cp}.hbm_ctrl.slice0") not in es
# ── Views: system ────────────────────────────────────────────────────
def test_system_view_nodes():
v = _graph().system_view
assert "fabric.switch0" in v.nodes
assert "sip0" in v.nodes
assert "sip1" in v.nodes
assert "sip0.io0" in v.nodes
assert "sip1.io0" in v.nodes
# ── Views: SIP ───────────────────────────────────────────────────────
def test_sip_view_cube_count():
v = _graph().sip_view
cube_nodes = [n for n in v.nodes if n.startswith("cube")]
assert len(cube_nodes) == 16
def test_sip_view_io_chiplets():
v = _graph().sip_view
assert "io0" in v.nodes
def test_sip_view_cube_positions():
v = _graph().sip_view
# cube0 (0,0): center = (8.5, 6+7.0) = (8.5, 13.0) [io_margin=6]
x, y = v.nodes["cube0"].pos_mm
assert x == 8.5
assert y == 13.0
# cube1 (1,0): center = (18+8.5, 13.0) = (26.5, 13.0)
x1, y1 = v.nodes["cube1"].pos_mm
assert x1 == 26.5
assert y1 == 13.0
# ── Views: cube ──────────────────────────────────────────────────────
def test_cube_view_has_all_components():
v = _graph().cube_view
expected = {"ucie-N", "ucie-S", "ucie-W", "ucie-E",
"m_cpu", "hbm_ctrl",
"bridge.left", "bridge.right", "noc", "sram",
"xbar.pe0", "xbar.pe1", "xbar.pe2", "xbar.pe3",
"xbar.pe4", "xbar.pe5", "xbar.pe6", "xbar.pe7",
"pe0", "pe1", "pe2", "pe3", "pe4", "pe5", "pe6", "pe7"}
assert set(v.nodes.keys()) == expected
def test_cube_view_hbm_at_center():
v = _graph().cube_view
assert v.nodes["hbm_ctrl"].pos_mm == (6.5, 7.0)
assert v.nodes["noc"].pos_mm == (10.5, 7.0)
assert v.width_mm == 17.0
assert v.height_mm == 14.0
def test_cube_view_pe_corner_mapping():
v = _graph().cube_view
ves = {(e.src, e.dst) for e in v.edges}
# Each PE connects to its own xbar entry (chain model)
for i in range(8):
assert (f"pe{i}", f"xbar.pe{i}") in ves
# Old shared xbar.top/bottom mapping must not exist
assert ("pe0", "xbar.top") not in ves
assert ("pe4", "xbar.bottom") not in ves
# ── Views: PE ────────────────────────────────────────────────────────
def test_pe_view_has_all_components():
v = _graph().pe_view
assert set(v.nodes.keys()) == {
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm"
}
def test_pe_view_edges():
v = _graph().pe_view
ves = {(e.src, e.dst) for e in v.edges}
assert ("pe_cpu", "pe_scheduler") in ves
assert ("pe_scheduler", "pe_dma") in ves
assert ("pe_scheduler", "pe_gemm") in ves
assert ("pe_scheduler", "pe_math") in ves
assert ("pe_dma", "pe_tcm") in ves
assert ("pe_gemm", "pe_tcm") in ves
assert ("pe_math", "pe_tcm") in ves
# ── SRAM ────────────────────────────────────────────────────────────
def test_sram_node_exists():
g = _graph()
assert "sip0.cube0.sram" in g.nodes
assert g.nodes["sip0.cube0.sram"].kind == "sram"
def test_noc_to_sram_edges():
es = _edge_set(_graph())
cp = "sip0.cube0"
assert (f"{cp}.noc", f"{cp}.sram") in es
assert (f"{cp}.sram", f"{cp}.noc") in es
# ── PE_DMA → NOC (non-HBM data path) ───────────────────────────────
def test_pe_dma_to_noc_edges():
es = _edge_set(_graph())
cp = "sip0.cube0"
for i in range(8):
assert (f"{cp}.pe{i}.pe_dma", f"{cp}.noc") in es
# ── Bridge connects XBAR halves (not NOC) ──────────────────────────
def test_bridge_connects_xbar_halves():
"""bridge.left connects leftmost PE nodes (pe0 top, pe4 bottom).
bridge.right connects rightmost PE nodes (pe3 top, pe7 bottom)."""
es = _edge_set(_graph())
cp = "sip0.cube0"
# bridge.left ↔ pe0 (top-left) and pe4 (bottom-left)
assert (f"{cp}.xbar.pe0", f"{cp}.bridge.left") in es
assert (f"{cp}.bridge.left", f"{cp}.xbar.pe0") in es
assert (f"{cp}.xbar.pe4", f"{cp}.bridge.left") in es
assert (f"{cp}.bridge.left", f"{cp}.xbar.pe4") in es
# bridge.right ↔ pe3 (top-right) and pe7 (bottom-right)
assert (f"{cp}.xbar.pe3", f"{cp}.bridge.right") in es
assert (f"{cp}.bridge.right", f"{cp}.xbar.pe3") in es
assert (f"{cp}.xbar.pe7", f"{cp}.bridge.right") in es
assert (f"{cp}.bridge.right", f"{cp}.xbar.pe7") in es
# Old xbar.top/bottom ↔ bridge edges must NOT exist
assert (f"{cp}.xbar.top", f"{cp}.bridge.left") not in es
assert (f"{cp}.xbar.bottom", f"{cp}.bridge.left") not in es
def test_no_bridge_to_noc_edges():
es = _edge_set(_graph())
cp = "sip0.cube0"
assert (f"{cp}.bridge.left", f"{cp}.noc") not in es
assert (f"{cp}.bridge.right", f"{cp}.noc") not in es
# ── Cube view: new edges ────────────────────────────────────────────
def test_cube_view_pe_to_noc():
v = _graph().cube_view
ves = {(e.src, e.dst) for e in v.edges}
for i in range(8):
assert (f"pe{i}", "noc") in ves
def test_cube_view_sram():
v = _graph().cube_view
assert "sram" in v.nodes
ves = {(e.src, e.dst) for e in v.edges}
assert ("noc", "sram") in ves
assert ("sram", "noc") in ves
def test_cube_view_bridge_xbar():
v = _graph().cube_view
ves = {(e.src, e.dst) for e in v.edges}
# bridge.left connects pe0 (top-left) ↔ pe4 (bottom-left)
assert ("xbar.pe0", "bridge.left") in ves
assert ("bridge.left", "xbar.pe0") in ves
assert ("xbar.pe4", "bridge.left") in ves
assert ("bridge.left", "xbar.pe4") in ves
# bridge.right connects pe3 (top-right) ↔ pe7 (bottom-right)
assert ("xbar.pe3", "bridge.right") in ves
assert ("bridge.right", "xbar.pe3") in ves
assert ("xbar.pe7", "bridge.right") in ves
assert ("bridge.right", "xbar.pe7") in ves
# ── Chain xbar: new topology edges ──────────────────────────────────
def test_xbar_chain_edges():
"""Adjacent xbar.pe nodes within each half are bidirectionally connected."""
es = _edge_set(_graph())
cp = "sip0.cube0"
# Top chain: pe0 ↔ pe1 ↔ pe2 ↔ pe3 (NW→NE direction)
for a, b in [(0, 1), (1, 2), (2, 3)]:
assert (f"{cp}.xbar.pe{a}", f"{cp}.xbar.pe{b}") in es, f"missing pe{a}→pe{b}"
assert (f"{cp}.xbar.pe{b}", f"{cp}.xbar.pe{a}") in es, f"missing pe{b}→pe{a}"
# Bottom chain: pe4 ↔ pe5 ↔ pe6 ↔ pe7
for a, b in [(4, 5), (5, 6), (6, 7)]:
assert (f"{cp}.xbar.pe{a}", f"{cp}.xbar.pe{b}") in es, f"missing pe{a}→pe{b}"
assert (f"{cp}.xbar.pe{b}", f"{cp}.xbar.pe{a}") in es, f"missing pe{b}→pe{a}"
# Negative: no cross-chain direct edges
assert (f"{cp}.xbar.pe0", f"{cp}.xbar.pe2") not in es
assert (f"{cp}.xbar.pe0", f"{cp}.xbar.pe4") not in es
def test_ucie_noc_reverse_edges():
"""UCIe ports must have reverse edges back to NOC (bidirectional)."""
es = _edge_set(_graph())
cp = "sip0.cube1" # non-edge cube to avoid io-cube edges
for port in ("N", "S", "E", "W"):
assert (f"{cp}.ucie-{port}", f"{cp}.noc") in es, \
f"missing ucie-{port}->noc reverse edge"
def test_noc_to_xbar_pe_edges():
"""NOC connects to all xbar.pe nodes (for remote cube HBM access)."""
es = _edge_set(_graph())
cp = "sip0.cube0"
for pe in range(8):
assert (f"{cp}.noc", f"{cp}.xbar.pe{pe}") in es, \
f"missing noc->xbar.pe{pe}"
+60
View File
@@ -0,0 +1,60 @@
from pathlib import Path
from kernbench.topology.builder import _read_spec, resolve_topology
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
def test_topology_yaml_loads_without_error():
# _compile_graph is still stubbed (returns None); load must not raise
resolve_topology(str(TOPOLOGY_PATH))
def test_pe_layout_structure():
spec = _read_spec(TOPOLOGY_PATH)
pe_layout = spec["cube"]["pe_layout"]
assert set(pe_layout["corners"]) == {"NW", "NE", "SW", "SE"}
assert pe_layout["pe_per_corner"] == 2
# derived total must equal original pe_per_cube: 8
assert pe_layout["pe_per_corner"] * len(pe_layout["corners"]) == 8
def test_pe_template_components():
spec = _read_spec(TOPOLOGY_PATH)
comps = spec["cube"]["pe_template"]["components"]
assert set(comps.keys()) == {
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm"
}
def test_pe_template_links_present():
spec = _read_spec(TOPOLOGY_PATH)
links = spec["cube"]["pe_template"]["links"]
required = {
"pe_cpu_to_scheduler_mm",
"scheduler_to_dma_mm",
"scheduler_to_gemm_mm",
"scheduler_to_math_mm",
"dma_to_tcm_bw_gbs", "dma_to_tcm_mm",
"gemm_to_tcm_bw_gbs", "gemm_to_tcm_mm",
"math_to_tcm_bw_gbs", "math_to_tcm_mm",
}
assert required.issubset(set(links.keys()))
def test_pe_dma_not_in_cube_components():
spec = _read_spec(TOPOLOGY_PATH)
assert "pe_dma" not in spec["cube"]["components"]
def test_pe_per_cube_removed():
spec = _read_spec(TOPOLOGY_PATH)
assert "pe_per_cube" not in spec["cube"].get("device", {})
def test_shared_resource_accel_slot():
# ADR-0014 D4: PE_GEMM and PE_MATH share PE_ACCEL capacity = 1
spec = _read_spec(TOPOLOGY_PATH)
comps = spec["cube"]["pe_template"]["components"]
assert comps["pe_gemm"]["attrs"]["shared_resource"] == "accel_slot"
assert comps["pe_math"]["attrs"]["shared_resource"] == "accel_slot"
+81
View File
@@ -0,0 +1,81 @@
from pathlib import Path
from kernbench.topology.builder import load_topology
from kernbench.topology.visualizer import emit_diagrams
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
VIEW_FILES = ["system_view.svg", "sip_view.svg", "cube_view.svg", "pe_view.svg"]
def _emit(tmp_path: Path) -> list[Path]:
graph = load_topology(TOPOLOGY_PATH)
return emit_diagrams(graph, tmp_path)
def test_emit_creates_all_svg_files(tmp_path):
created = _emit(tmp_path)
assert len(created) == 4
for name in VIEW_FILES:
assert (tmp_path / name).exists()
assert (tmp_path / name).stat().st_size > 0
def test_svg_output_is_deterministic(tmp_path):
graph = load_topology(TOPOLOGY_PATH)
emit_diagrams(graph, tmp_path)
first = {name: (tmp_path / name).read_text() for name in VIEW_FILES}
emit_diagrams(graph, tmp_path)
second = {name: (tmp_path / name).read_text() for name in VIEW_FILES}
for name in VIEW_FILES:
assert first[name] == second[name], f"{name} is not deterministic"
def test_cube_svg_contains_hbm_ctrl(tmp_path):
_emit(tmp_path)
svg = (tmp_path / "cube_view.svg").read_text()
assert "HBM CTRL" in svg
def test_cube_svg_contains_ucie_ports(tmp_path):
_emit(tmp_path)
svg = (tmp_path / "cube_view.svg").read_text()
for port in ("UCIe-N", "UCIe-S", "UCIe-W", "UCIe-E"):
assert port in svg
def test_cube_svg_contains_pe_nodes(tmp_path):
_emit(tmp_path)
svg = (tmp_path / "cube_view.svg").read_text()
for i in range(8):
assert f"PE{i}" in svg
def test_pe_svg_contains_all_components(tmp_path):
_emit(tmp_path)
svg = (tmp_path / "pe_view.svg").read_text()
for comp in ("PE CPU", "PE SCHEDULER", "PE DMA", "PE GEMM", "PE MATH", "PE TCM"):
assert comp in svg
def test_sip_svg_contains_cubes(tmp_path):
_emit(tmp_path)
svg = (tmp_path / "sip_view.svg").read_text()
assert "CUBE (0,0)" in svg
assert "CUBE (3,3)" in svg
def test_system_svg_contains_switch_and_sips(tmp_path):
_emit(tmp_path)
svg = (tmp_path / "system_view.svg").read_text()
assert "Fabric Switch" in svg
assert "SIP 0" in svg
assert "SIP 1" in svg
def test_svg_is_valid_xml(tmp_path):
_emit(tmp_path)
for name in VIEW_FILES:
svg = (tmp_path / name).read_text()
assert svg.startswith("<svg")
assert svg.strip().endswith("</svg>")
+349
View File
@@ -0,0 +1,349 @@
"""Tests for Triton emulator: TLContext, command generation, kernel registry."""
from kernbench.common.pe_commands import (
CompletionHandle,
CompositeCmd,
DmaReadCmd,
DmaWriteCmd,
GemmCmd,
MathCmd,
PeCpuOverheadCmd,
TensorHandle,
WaitCmd,
)
from kernbench.triton_emu.registry import clear_registry, get_kernel, register_kernel
from kernbench.triton_emu.tl_context import TLContext, run_kernel
def _ctx(**kwargs) -> TLContext:
return TLContext(dispatch_cycles=0, **kwargs)
def _ctx_with_overhead(**kwargs) -> TLContext:
return TLContext(dispatch_cycles=1, **kwargs)
# ── 1. tl.load → DmaReadCmd ──────────────────────────────────────
def test_tl_load_generates_dma_read():
tl = _ctx()
h = tl.load(0x1000, shape=(32, 64), dtype="f16")
assert isinstance(h, TensorHandle)
assert h.shape == (32, 64)
assert h.nbytes == 32 * 64 * 2
cmds = tl.commands
assert len(cmds) == 1
assert isinstance(cmds[0], DmaReadCmd)
assert cmds[0].src_pa == 0x1000
assert cmds[0].nbytes == 32 * 64 * 2
# ── 2. tl.store → DmaWriteCmd ────────────────────────────────────
def test_tl_store_generates_dma_write():
tl = _ctx()
h = tl.load(0x1000, shape=(16, 16), dtype="f32")
tl.store(0x2000, h)
cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
assert len(cmds) == 1
assert cmds[0].dst_pa == 0x2000
assert cmds[0].nbytes == 16 * 16 * 4
# ── 3. tl.dot → GemmCmd ──────────────────────────────────────────
def test_tl_dot_generates_gemm_cmd():
tl = _ctx()
a = tl.load(0x1000, shape=(32, 64), dtype="f16")
b = tl.load(0x2000, shape=(64, 16), dtype="f16")
out = tl.dot(a, b)
assert out.shape == (32, 16)
cmds = [c for c in tl.commands if isinstance(c, GemmCmd)]
assert len(cmds) == 1
assert cmds[0].m == 32
assert cmds[0].k == 64
assert cmds[0].n == 16
# ── 4. tl.exp, tl.sqrt etc. → MathCmd ────────────────────────────
def test_tl_math_unary_ops():
tl = _ctx()
x = tl.load(0x1000, shape=(8, 8), dtype="f16")
for op_name, op_fn in [
("exp", tl.exp), ("log", tl.log), ("sqrt", tl.sqrt),
("abs", tl.abs), ("sigmoid", tl.sigmoid),
("cos", tl.cos), ("sin", tl.sin),
]:
result = op_fn(x)
assert isinstance(result, TensorHandle)
assert result.shape == x.shape
math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
ops = [c.op for c in math_cmds]
assert ops == ["exp", "log", "sqrt", "abs", "sigmoid", "cos", "sin"]
# ── 5. a + b, a * b → MathCmd ────────────────────────────────────
def test_tl_math_binary_ops():
tl = _ctx()
a = tl.load(0x1000, shape=(4, 4), dtype="f16")
b = tl.load(0x2000, shape=(4, 4), dtype="f16")
r1 = run_kernel(lambda tl: None, tl) # activate context for operators
# Need active context for operators
tl2 = _ctx()
a2 = tl2.load(0x1000, shape=(4, 4), dtype="f16")
b2 = tl2.load(0x2000, shape=(4, 4), dtype="f16")
def kernel(tl):
pass
# Use run_kernel to activate context, then test operators
tl3 = _ctx()
def binary_kernel(tl):
a = tl.load(0x1000, shape=(4, 4), dtype="f16")
b = tl.load(0x2000, shape=(4, 4), dtype="f16")
_ = a + b
_ = a - b
_ = a * b
_ = a / b
run_kernel(binary_kernel, tl3)
math_cmds = [c for c in tl3.commands if isinstance(c, MathCmd)]
ops = [c.op for c in math_cmds]
assert ops == ["add", "sub", "mul", "div"]
# ── 6. tl.sum, tl.max → MathCmd with axis ────────────────────────
def test_tl_reduction_ops():
tl = _ctx()
x = tl.load(0x1000, shape=(32, 64), dtype="f16")
s = tl.sum(x, axis=1)
m = tl.max(x, axis=0)
assert s.shape == (32, 1)
assert m.shape == (1, 64)
math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
assert math_cmds[0].op == "sum" and math_cmds[0].axis == 1
assert math_cmds[1].op == "max" and math_cmds[1].axis == 0
# ── 7. tl.composite → CompositeCmd + CompletionHandle ────────────
def test_tl_composite_nonblocking():
tl = _ctx()
a = tl.load(0x1000, shape=(32, 64), dtype="f16")
b = tl.load(0x2000, shape=(64, 32), dtype="f16")
h = tl.composite(op="gemm", a=a, b=b, out_ptr=0x3000)
assert isinstance(h, CompletionHandle)
comp_cmds = [c for c in tl.commands if isinstance(c, CompositeCmd)]
assert len(comp_cmds) == 1
assert comp_cmds[0].op == "gemm"
assert comp_cmds[0].out_pa == 0x3000
assert comp_cmds[0].out_nbytes == 32 * 32 * 2 # M×N×dtype_bytes
# ── 8. tl.wait(handle) → WaitCmd ─────────────────────────────────
def test_tl_wait_specific():
tl = _ctx()
a = tl.load(0x1000, shape=(4, 4), dtype="f16")
h = tl.composite(op="gemm", a=a, b=a, out_ptr=0x2000)
tl.wait(h)
wait_cmds = [c for c in tl.commands if isinstance(c, WaitCmd)]
assert len(wait_cmds) == 1
assert wait_cmds[0].handle == h
# ── 9. tl.wait() → WaitCmd(handle=None) ──────────────────────────
def test_tl_wait_all():
tl = _ctx()
tl.wait()
wait_cmds = [c for c in tl.commands if isinstance(c, WaitCmd)]
assert len(wait_cmds) == 1
assert wait_cmds[0].handle is None
# ── 10. tl.cycles → PeCpuOverheadCmd ─────────────────────────────
def test_tl_cycles():
tl = _ctx()
tl.cycles(10)
assert len(tl.commands) == 1
assert isinstance(tl.commands[0], PeCpuOverheadCmd)
assert tl.commands[0].cycles == 10
# ── 11. tl.program_id ────────────────────────────────────────────
def test_tl_program_id():
tl = TLContext(pe_id=5, num_programs=8)
assert tl.program_id(0) == 5
assert tl.num_programs(0) == 8
# ── 12. tl.arange, tl.zeros, tl.full ─────────────────────────────
def test_tl_arange_zeros_full():
tl = _ctx()
r = tl.arange(0, 16, dtype="i32")
assert r.shape == (16,)
assert r.dtype == "i32"
z = tl.zeros((4, 8), dtype="f16")
assert z.shape == (4, 8)
assert z.nbytes == 4 * 8 * 2
f = tl.full((2, 3), value=1.0, dtype="f32")
assert f.shape == (2, 3)
assert f.nbytes == 2 * 3 * 4
# ── 13. tl.trans → shape change, no command ───────────────────────
def test_tl_trans_shape():
tl = _ctx()
h = tl.load(0x1000, shape=(32, 64), dtype="f16")
t = tl.trans(h)
assert t.shape == (64, 32)
assert t.id == h.id # same underlying data
# Only DmaReadCmd from load, no command from trans
assert len(tl.commands) == 1
assert isinstance(tl.commands[0], DmaReadCmd)
# ── 14. Kernel registry ──────────────────────────────────────────
def test_kernel_registry():
clear_registry()
def my_kernel(tl):
pass
register_kernel("test_kern", my_kernel)
assert get_kernel("test_kern") is my_kernel
clear_registry()
def test_kernel_registry_missing():
clear_registry()
import pytest
with pytest.raises(KeyError):
get_kernel("nonexistent")
def test_kernel_registry_duplicate():
clear_registry()
register_kernel("dup", lambda tl: None)
import pytest
with pytest.raises(ValueError):
register_kernel("dup", lambda tl: None)
clear_registry()
# ── 15. GEMM kernel → correct command sequence ───────────────────
def test_gemm_kernel_command_sequence():
"""32×64 × 64×32 GEMM kernel produces [DmaRead, DmaRead, Composite]."""
def gemm_kernel(a_ptr, b_ptr, out_ptr, tl):
pid = tl.program_id(0)
a = tl.load(a_ptr, shape=(32, 64), dtype="f16")
b = tl.load(b_ptr + pid * 64 * 32 * 2, shape=(64, 32), dtype="f16")
tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr + pid * 32 * 32 * 2)
tl = _ctx(pe_id=3)
run_kernel(gemm_kernel, tl, a_ptr=0x1000, b_ptr=0x2000, out_ptr=0x3000)
types = [type(c).__name__ for c in tl.commands]
assert types == ["DmaReadCmd", "DmaReadCmd", "CompositeCmd"]
# ── 16. Attention kernel → correct command sequence ───────────────
def test_attention_kernel_command_sequence():
"""Attention kernel: load→dot→math ops→dot→store."""
def attention_kernel(q_ptr, k_ptr, v_ptr, out_ptr, tl,
seq_len=16, head_dim=8):
pid = tl.program_id(0)
q = tl.load(q_ptr, shape=(seq_len, head_dim), dtype="f16")
k = tl.load(k_ptr, shape=(head_dim, seq_len), dtype="f16")
scores = tl.dot(q, k)
row_max = tl.max(scores, axis=1)
scores = scores - row_max
scores = tl.exp(scores)
row_sum = tl.sum(scores, axis=1)
scores = scores / row_sum
v = tl.load(v_ptr, shape=(seq_len, head_dim), dtype="f16")
out = tl.dot(scores, v)
tl.store(out_ptr, out)
tl = _ctx(pe_id=0)
run_kernel(
attention_kernel, tl,
q_ptr=0x1000, k_ptr=0x2000, v_ptr=0x3000, out_ptr=0x4000,
)
types = [type(c).__name__ for c in tl.commands]
# load, load, dot, max, sub, exp, sum, div, load, dot, store
assert types == [
"DmaReadCmd", "DmaReadCmd", # load Q, K
"GemmCmd", # Q @ K
"MathCmd", "MathCmd", "MathCmd", # max, sub, exp
"MathCmd", "MathCmd", # sum, div
"DmaReadCmd", # load V
"GemmCmd", # scores @ V
"DmaWriteCmd", # store output
]
# Verify math ops
math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
math_ops = [c.op for c in math_cmds]
assert math_ops == ["max", "sub", "exp", "sum", "div"]
# ── 17. Dispatch overhead auto-inserted ───────────────────────────
def test_dispatch_overhead_inserted():
"""Each tl API call auto-inserts PeCpuOverheadCmd when dispatch_cycles > 0."""
tl = _ctx_with_overhead()
a = tl.load(0x1000, shape=(4, 4), dtype="f16")
tl.store(0x2000, a)
types = [type(c).__name__ for c in tl.commands]
# overhead, load, overhead, store
assert types == [
"PeCpuOverheadCmd", "DmaReadCmd",
"PeCpuOverheadCmd", "DmaWriteCmd",
]
# ── 18. where operation ──────────────────────────────────────────
def test_tl_where():
tl = _ctx()
cond = tl.load(0x1000, shape=(4, 4), dtype="i32")
a = tl.load(0x2000, shape=(4, 4), dtype="f16")
b = tl.load(0x3000, shape=(4, 4), dtype="f16")
out = tl.where(cond, a, b)
assert isinstance(out, TensorHandle)
math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
assert len(math_cmds) == 1
assert math_cmds[0].op == "where"
assert len(math_cmds[0].inputs) == 3