commit - release 1
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
import kernbench.cli.main as cli_main
|
||||
|
||||
|
||||
def test_cli_main_arg_parsing(monkeypatch):
|
||||
|
||||
def fake_cmd_run(args) -> int:
|
||||
assert args.cmd == "run"
|
||||
assert args.topology == "topology.yaml"
|
||||
assert args.bench == "qkv_gemm"
|
||||
assert args.device == None
|
||||
return 0
|
||||
|
||||
# monkey patch the handler to test arg parsing without running the actual bench
|
||||
monkeypatch.setattr(cli_main, "cmd_run", fake_cmd_run)
|
||||
rc = cli_main.main(["run", "--topology", "topology.yaml", "--bench", "qkv_gemm"])
|
||||
assert rc == 0
|
||||
|
||||
|
||||
def test_cli_main():
|
||||
|
||||
rc = cli_main.main(["run", "--topology", "topology.yaml", "--bench", "qkv_gemm"])
|
||||
assert rc == 0
|
||||
@@ -0,0 +1,187 @@
|
||||
"""Tests for the SimPy component model and DI registry (ADR-0007 D3).
|
||||
|
||||
Phase 1 verification: all tests FAIL until Phase 2 implements production code.
|
||||
|
||||
Latency invariant after refactor:
|
||||
total_ns = Σ(wire propagation) + Σ(component.run() overhead_ns) + nbytes / bottleneck_bw
|
||||
This is identical to the current formula for Phase 0 (no contention).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import simpy
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.components.base import ComponentBase, ComponentRegistry
|
||||
from kernbench.components.impls.forwarding import TransitComponent
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
from kernbench.runtime_api.kernel import MemoryReadMsg
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import load_topology
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
|
||||
def _graph():
|
||||
return load_topology(TOPOLOGY_PATH)
|
||||
|
||||
|
||||
def _hbm_pa(pe_id: int = 0) -> int:
|
||||
slice_bytes = 48 * (1 << 30) // 8
|
||||
pa = PhysAddr.pe_hbm_addr(
|
||||
rack_id=0, sip_id=0, cube_id=0, pe_id=pe_id,
|
||||
pe_local_hbm_offset=0x1000, slice_size_bytes=slice_bytes,
|
||||
)
|
||||
return pa.encode()
|
||||
|
||||
|
||||
def _node(impl: str, overhead_ns: float = 0.0) -> Node:
|
||||
return Node(id="test", kind="xbar", impl=impl, attrs={"overhead_ns": overhead_ns}, pos_mm=None)
|
||||
|
||||
|
||||
# ── 1. unknown impl → error ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_registry_unknown_impl_raises_error():
|
||||
"""Unregistered impl raises ValueError (no fallback)."""
|
||||
node = _node("totally_unknown_v99", overhead_ns=5.0)
|
||||
with pytest.raises(ValueError, match="No component registered"):
|
||||
ComponentRegistry.create(node)
|
||||
|
||||
|
||||
# ── 2. TransitComponent yields exactly overhead_ns via simpy timeout ──
|
||||
|
||||
|
||||
def test_transit_component_yields_overhead_ns():
|
||||
"""TransitComponent.run() yields exactly node.attrs['overhead_ns'] ns."""
|
||||
node = _node("xbar_v1", overhead_ns=3.0)
|
||||
comp = TransitComponent(node)
|
||||
env = simpy.Environment()
|
||||
|
||||
def proc():
|
||||
yield from comp.run(env, nbytes=4096)
|
||||
|
||||
env.process(proc())
|
||||
env.run()
|
||||
assert env.now == pytest.approx(3.0)
|
||||
|
||||
|
||||
def test_transit_component_zero_overhead_ns():
|
||||
"""TransitComponent with overhead_ns=0 still yields (no infinite loop)."""
|
||||
node = _node("noc_v1", overhead_ns=0.0)
|
||||
comp = TransitComponent(node)
|
||||
env = simpy.Environment()
|
||||
|
||||
done = []
|
||||
|
||||
def proc():
|
||||
yield from comp.run(env, nbytes=1024)
|
||||
done.append(True)
|
||||
|
||||
env.process(proc())
|
||||
env.run()
|
||||
assert done == [True]
|
||||
assert env.now == pytest.approx(0.0)
|
||||
|
||||
|
||||
# ── 3. DI override: custom component is invoked by engine ────────────
|
||||
|
||||
|
||||
def test_engine_component_override_is_called():
|
||||
"""Custom component injected via component_overrides is invoked during simulation."""
|
||||
|
||||
class SpyXbar(ComponentBase):
|
||||
calls = 0
|
||||
|
||||
def run(self, env, nbytes):
|
||||
SpyXbar.calls += 1
|
||||
yield env.timeout(0)
|
||||
|
||||
SpyXbar.calls = 0
|
||||
graph = _graph()
|
||||
engine = GraphEngine(graph, component_overrides={"xbar_v1": SpyXbar})
|
||||
msg = MemoryReadMsg(
|
||||
correlation_id="c", request_id="r",
|
||||
src_sip=0, src_cube=0, src_pe=0,
|
||||
src_pa=_hbm_pa(pe_id=0), nbytes=4096,
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
# PE0→slice0 path passes through xbar.pe0 (impl=xbar_v1)
|
||||
assert SpyXbar.calls > 0
|
||||
|
||||
|
||||
# ── 4. behavior unchanged: total_ns matches existing formula ─────────
|
||||
|
||||
|
||||
def test_engine_component_model_same_latency_as_before():
|
||||
"""Phase B component model total_ns for PE0→slice0 local HBM (4096B).
|
||||
|
||||
Cut-through (wormhole) wire model: wires apply propagation only.
|
||||
Serialization (drain) is computed per-path and applied once at the terminal.
|
||||
|
||||
Forward path:
|
||||
Path 1: pcie_ep(5.0) + wire(1.0mm=0.01) + io_cpu(10.0)
|
||||
Path 2: wire(3.5mm=0.035) + ucie-N(1.0)
|
||||
+ 2DMeshNOC(ucie-N→m_cpu: Manhattan 10.9mm=0.109) + m_cpu(5.0)
|
||||
Path 3 DMA (m_cpu→noc→xbar.pe0→hbm_ctrl.slice0):
|
||||
+ 2DMeshNOC(m_cpu→xbar.pe0: Manhattan 15.0mm=0.15)
|
||||
+ xbar.pe0(2.0) + wire(2.5mm=0.025) + hbm_ctrl(0.0)
|
||||
+ drain_ns(4096/128 = 32.0, bottleneck = noc_to_xbar 128 GB/s)
|
||||
|
||||
Response path (reverse, nbytes=0, drain=0):
|
||||
DMA response: hbm_ctrl→xbar.pe0→noc→m_cpu (propagation + xbar overhead_ns)
|
||||
Command response: m_cpu→noc→ucie-N→io_cpu (propagation + ucie overhead_ns)
|
||||
|
||||
Total: ~58.648 ns
|
||||
"""
|
||||
graph = _graph()
|
||||
engine = GraphEngine(graph)
|
||||
msg = MemoryReadMsg(
|
||||
correlation_id="c", request_id="r",
|
||||
src_sip=0, src_cube=0, src_pe=0,
|
||||
src_pa=_hbm_pa(pe_id=0), nbytes=4096,
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
_, trace = engine.get_completion(h)
|
||||
assert trace["total_ns"] == pytest.approx(58.648, rel=1e-4)
|
||||
|
||||
|
||||
# ── 5. override is scoped: only targeted impl is replaced ────────────
|
||||
|
||||
|
||||
def test_engine_override_is_scoped_to_impl():
|
||||
"""xbar_v1 override (ZeroXbar, no overhead_ns) reduces total_ns by exactly 4.0 ns.
|
||||
|
||||
xbar.pe0 has overhead_ns=2.0. It is traversed on both the forward DMA path
|
||||
and the reverse response path, so replacing it with a zero-latency impl
|
||||
removes 2.0 ns × 2 = 4.0 ns; all other components are unchanged.
|
||||
"""
|
||||
|
||||
class ZeroXbar(ComponentBase):
|
||||
def run(self, env, nbytes):
|
||||
yield env.timeout(0)
|
||||
|
||||
graph = _graph()
|
||||
engine_default = GraphEngine(graph)
|
||||
engine_override = GraphEngine(graph, component_overrides={"xbar_v1": ZeroXbar})
|
||||
|
||||
msg = MemoryReadMsg(
|
||||
correlation_id="c", request_id="r",
|
||||
src_sip=0, src_cube=0, src_pe=0,
|
||||
src_pa=_hbm_pa(pe_id=0), nbytes=4096,
|
||||
)
|
||||
|
||||
h_d = engine_default.submit(msg)
|
||||
engine_default.wait(h_d)
|
||||
_, t_default = engine_default.get_completion(h_d)
|
||||
|
||||
h_o = engine_override.submit(msg)
|
||||
engine_override.wait(h_o)
|
||||
_, t_override = engine_override.get_completion(h_o)
|
||||
|
||||
# ZeroXbar removes overhead_ns=2.0 from xbar.pe0 on forward + response = 4.0 ns faster
|
||||
assert t_override["total_ns"] < t_default["total_ns"]
|
||||
assert t_default["total_ns"] - t_override["total_ns"] == pytest.approx(4.0, rel=1e-6)
|
||||
@@ -0,0 +1,405 @@
|
||||
import pytest
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.common.types import Completion, RequestHandle
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
from kernbench.runtime_api.kernel import (
|
||||
KernelLaunchMsg,
|
||||
KernelRef,
|
||||
MemoryReadMsg,
|
||||
MemoryWriteMsg,
|
||||
ScalarArg,
|
||||
TensorArg,
|
||||
TensorArgShard,
|
||||
)
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import load_topology
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
|
||||
def _engine():
|
||||
graph = load_topology(TOPOLOGY_PATH)
|
||||
return GraphEngine(graph)
|
||||
|
||||
|
||||
def _hbm_pa(sip: int = 0, cube: int = 0, pe_id: int = 0) -> int:
|
||||
"""Create an HBM physical address targeting a specific PE's HBM slice."""
|
||||
# 48 GB / 8 slices = 6 GB per slice
|
||||
slice_bytes = 48 * (1 << 30) // 8
|
||||
pa = PhysAddr.pe_hbm_addr(
|
||||
rack_id=0, sip_id=sip, cube_id=cube, pe_id=pe_id,
|
||||
pe_local_hbm_offset=0x1000, slice_size_bytes=slice_bytes,
|
||||
)
|
||||
return pa.encode()
|
||||
|
||||
|
||||
def _sram_pa(sip: int = 0, cube: int = 0) -> int:
|
||||
"""Create an SRAM physical address."""
|
||||
pa = PhysAddr.cube_sram_addr(rack_id=0, sip_id=sip, cube_id=cube, sram_offset=0x800)
|
||||
return pa.encode()
|
||||
|
||||
|
||||
# ── 1. submit returns handle ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_engine_submit_returns_handle():
|
||||
"""submit() must return a RequestHandle (non-empty string)."""
|
||||
engine = _engine()
|
||||
msg = MemoryWriteMsg(
|
||||
correlation_id="c0", request_id="r0",
|
||||
dst_sip=0, dst_cube=0, dst_pe=0,
|
||||
dst_pa=_hbm_pa(), nbytes=4096, pattern="zero",
|
||||
)
|
||||
handle = engine.submit(msg)
|
||||
assert isinstance(handle, str)
|
||||
assert len(handle) > 0
|
||||
|
||||
|
||||
# ── 2. memory write completion ──────────────────────────────────────
|
||||
|
||||
|
||||
def test_engine_memory_write_completion():
|
||||
"""MemoryWrite must complete with ok=True."""
|
||||
engine = _engine()
|
||||
msg = MemoryWriteMsg(
|
||||
correlation_id="c0", request_id="r1",
|
||||
dst_sip=0, dst_cube=0, dst_pe=0,
|
||||
dst_pa=_hbm_pa(), nbytes=4096, pattern="zero",
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
comp, trace = engine.get_completion(h)
|
||||
assert comp.ok is True
|
||||
|
||||
|
||||
# ── 3. memory read completion ───────────────────────────────────────
|
||||
|
||||
|
||||
def test_engine_memory_read_completion():
|
||||
"""MemoryRead must complete with ok=True."""
|
||||
engine = _engine()
|
||||
msg = MemoryReadMsg(
|
||||
correlation_id="c0", request_id="r2",
|
||||
src_sip=0, src_cube=0, src_pe=0,
|
||||
src_pa=_hbm_pa(), nbytes=4096,
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
comp, trace = engine.get_completion(h)
|
||||
assert comp.ok is True
|
||||
|
||||
|
||||
# ── 4. latency positive ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_engine_latency_positive():
|
||||
"""Trace total_ns must be > 0 (ADR-0002 D4)."""
|
||||
engine = _engine()
|
||||
msg = MemoryWriteMsg(
|
||||
correlation_id="c0", request_id="r3",
|
||||
dst_sip=0, dst_cube=0, dst_pe=0,
|
||||
dst_pa=_hbm_pa(), nbytes=4096, pattern="zero",
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
_, trace = engine.get_completion(h)
|
||||
assert trace["total_ns"] > 0
|
||||
|
||||
|
||||
# ── 5. trace has total_ns and nbytes ───────────────────────────────
|
||||
|
||||
|
||||
def test_engine_trace_has_total_ns_and_nbytes():
|
||||
"""Trace must contain 'total_ns' and 'nbytes'."""
|
||||
engine = _engine()
|
||||
msg = MemoryWriteMsg(
|
||||
correlation_id="c0", request_id="r4",
|
||||
dst_sip=0, dst_cube=0, dst_pe=0,
|
||||
dst_pa=_hbm_pa(), nbytes=4096, pattern="zero",
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
_, trace = engine.get_completion(h)
|
||||
assert "total_ns" in trace
|
||||
assert "nbytes" in trace
|
||||
assert trace["nbytes"] == 4096
|
||||
|
||||
|
||||
# ── 6. latency includes node overhead_ns ────────────────────────────
|
||||
|
||||
|
||||
def test_engine_latency_includes_node_overhead_ns():
|
||||
"""Path traverses components with overhead_ns > 0, so total >= some minimum."""
|
||||
engine = _engine()
|
||||
msg = MemoryWriteMsg(
|
||||
correlation_id="c0", request_id="r7",
|
||||
dst_sip=0, dst_cube=0, dst_pe=0,
|
||||
dst_pa=_hbm_pa(), nbytes=4096, pattern="zero",
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
_, trace = engine.get_completion(h)
|
||||
# pcie_ep (5.0) + io_cpu (10.0) + m_cpu (5.0) = at least 20 ns
|
||||
assert trace["total_ns"] >= 20.0
|
||||
|
||||
|
||||
# ── 7. concurrent requests ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_engine_concurrent_requests():
|
||||
"""Two requests submitted before wait must both complete with traces."""
|
||||
engine = _engine()
|
||||
msg1 = MemoryWriteMsg(
|
||||
correlation_id="c0", request_id="r9a",
|
||||
dst_sip=0, dst_cube=0, dst_pe=0,
|
||||
dst_pa=_hbm_pa(), nbytes=4096, pattern="zero",
|
||||
)
|
||||
msg2 = MemoryWriteMsg(
|
||||
correlation_id="c0", request_id="r9b",
|
||||
dst_sip=0, dst_cube=0, dst_pe=1,
|
||||
dst_pa=_hbm_pa(pe_id=1), nbytes=4096, pattern="zero",
|
||||
)
|
||||
h1 = engine.submit(msg1)
|
||||
h2 = engine.submit(msg2)
|
||||
engine.wait(h1)
|
||||
engine.wait(h2)
|
||||
comp1, trace1 = engine.get_completion(h1)
|
||||
comp2, trace2 = engine.get_completion(h2)
|
||||
assert comp1.ok is True
|
||||
assert comp2.ok is True
|
||||
assert trace1["total_ns"] > 0
|
||||
assert trace2["total_ns"] > 0
|
||||
|
||||
|
||||
# ── 8. kernel launch ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_engine_kernel_launch_simplified():
|
||||
"""KernelLaunch returns latency > 0."""
|
||||
from kernbench.triton_emu.registry import clear_registry, register_kernel
|
||||
|
||||
clear_registry()
|
||||
hbm_pa = _hbm_pa(pe_id=0)
|
||||
|
||||
def gemm_kernel(a_ptr, tl):
|
||||
a = tl.load(a_ptr, shape=(4, 4), dtype="f16")
|
||||
tl.store(a_ptr, a)
|
||||
|
||||
register_kernel("gemm", gemm_kernel)
|
||||
|
||||
engine = _engine()
|
||||
shard0 = TensorArgShard(
|
||||
sip=0, cube=0, pe=0,
|
||||
pa=_hbm_pa(pe_id=0), nbytes=4096, offset_bytes=0,
|
||||
)
|
||||
shard1 = TensorArgShard(
|
||||
sip=0, cube=0, pe=1,
|
||||
pa=_hbm_pa(pe_id=1), nbytes=4096, offset_bytes=4096,
|
||||
)
|
||||
msg = KernelLaunchMsg(
|
||||
correlation_id="c0", request_id="r10",
|
||||
kernel_ref=KernelRef(name="gemm", kind="builtin"),
|
||||
args=(TensorArg(shards=(shard0, shard1)),),
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
comp, trace = engine.get_completion(h)
|
||||
assert comp.ok is True
|
||||
assert trace["total_ns"] > 0
|
||||
clear_registry()
|
||||
|
||||
|
||||
# ── 9. deterministic ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_engine_deterministic():
|
||||
"""Same request on two engines must produce identical latency."""
|
||||
msg = MemoryWriteMsg(
|
||||
correlation_id="c0", request_id="r11",
|
||||
dst_sip=0, dst_cube=0, dst_pe=0,
|
||||
dst_pa=_hbm_pa(), nbytes=4096, pattern="zero",
|
||||
)
|
||||
e1 = _engine()
|
||||
h1 = e1.submit(msg)
|
||||
e1.wait(h1)
|
||||
_, t1 = e1.get_completion(h1)
|
||||
|
||||
e2 = _engine()
|
||||
h2 = e2.submit(msg)
|
||||
e2.wait(h2)
|
||||
_, t2 = e2.get_completion(h2)
|
||||
|
||||
assert t1["total_ns"] == t2["total_ns"]
|
||||
|
||||
|
||||
# ── 10. remote cube access succeeds with higher latency ────────────
|
||||
|
||||
|
||||
def test_dma_capacity_serializes_concurrent():
|
||||
"""Two concurrent DMA writes to the same cube must contend at DMA capacity=1.
|
||||
|
||||
When two MemoryWrite requests target the same cube's M_CPU simultaneously,
|
||||
the DMA engine (capacity=1) serializes them. The slower request must take
|
||||
longer than a single isolated request (ADR-0014 D4, ADR-0015 D5).
|
||||
"""
|
||||
# Single isolated write baseline
|
||||
engine_single = _engine()
|
||||
msg_single = MemoryWriteMsg(
|
||||
correlation_id="c0", request_id="single",
|
||||
dst_sip=0, dst_cube=0, dst_pe=0,
|
||||
dst_pa=_hbm_pa(sip=0, cube=0, pe_id=0), nbytes=4096,
|
||||
pattern="zero", target_pe=0,
|
||||
)
|
||||
h1 = engine_single.submit(msg_single)
|
||||
engine_single.wait(h1)
|
||||
_, t1 = engine_single.get_completion(h1)
|
||||
single_ns = t1["total_ns"]
|
||||
|
||||
# Two concurrent writes to same cube (different PEs) → DMA contention
|
||||
engine_conc = _engine()
|
||||
msg_a = MemoryWriteMsg(
|
||||
correlation_id="c0", request_id="conc-a",
|
||||
dst_sip=0, dst_cube=0, dst_pe=0,
|
||||
dst_pa=_hbm_pa(sip=0, cube=0, pe_id=0), nbytes=4096,
|
||||
pattern="zero", target_pe=0,
|
||||
)
|
||||
msg_b = MemoryWriteMsg(
|
||||
correlation_id="c0", request_id="conc-b",
|
||||
dst_sip=0, dst_cube=0, dst_pe=1,
|
||||
dst_pa=_hbm_pa(sip=0, cube=0, pe_id=1), nbytes=4096,
|
||||
pattern="zero", target_pe=1,
|
||||
)
|
||||
ha = engine_conc.submit(msg_a)
|
||||
hb = engine_conc.submit(msg_b)
|
||||
engine_conc.wait(ha)
|
||||
engine_conc.wait(hb)
|
||||
_, ta = engine_conc.get_completion(ha)
|
||||
_, tb = engine_conc.get_completion(hb)
|
||||
|
||||
# At least one must be delayed by DMA contention
|
||||
max_ns = max(ta["total_ns"], tb["total_ns"])
|
||||
assert max_ns > single_ns, (
|
||||
f"concurrent max ({max_ns:.2f}ns) must > single ({single_ns:.2f}ns) "
|
||||
f"due to DMA capacity=1 contention"
|
||||
)
|
||||
|
||||
|
||||
# ── 11. formula latency lower bound ──────────────────────────────
|
||||
|
||||
|
||||
def test_formula_latency_lower_bound():
|
||||
"""_formula_latency must be <= actual latency (ADR-0015 D7).
|
||||
|
||||
Uses PE DMA path which is fully known at engine level.
|
||||
"""
|
||||
from kernbench.policy.address.phyaddr import PhysAddr as PA
|
||||
from kernbench.policy.routing.router import AddressResolver, PathRouter
|
||||
from kernbench.topology.builder import load_topology as lt
|
||||
|
||||
graph = lt(TOPOLOGY_PATH)
|
||||
engine = GraphEngine(graph)
|
||||
resolver = AddressResolver(graph)
|
||||
router = PathRouter(graph)
|
||||
|
||||
pa = _hbm_pa(sip=0, cube=0, pe_id=1)
|
||||
pa_obj = PA.decode(pa)
|
||||
dst_node = resolver.resolve(pa_obj)
|
||||
pe_ref = "sip0.cube0.pe0"
|
||||
path = router.find_path(pe_ref, dst_node)
|
||||
formula = engine._formula_latency(path, 4096)
|
||||
|
||||
# Run actual simulation
|
||||
msg = MemoryReadMsg(
|
||||
correlation_id="c0", request_id="formula-lb",
|
||||
src_sip=0, src_cube=0, src_pe=0,
|
||||
src_pa=pa, nbytes=4096, target_pe=1,
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
_, trace = engine.get_completion(h)
|
||||
actual = trace["total_ns"]
|
||||
|
||||
assert formula <= actual, (
|
||||
f"formula ({formula:.2f}) must <= actual ({actual:.2f})"
|
||||
)
|
||||
assert formula > 0, "formula must be > 0"
|
||||
|
||||
|
||||
def test_formula_latency_exact_no_contention():
|
||||
"""With no contention, formula should approximate actual for PE DMA.
|
||||
|
||||
PE DMA is single-request with no fan-out or aggregation,
|
||||
so formula ≈ actual (within small tolerance for SimPy scheduling).
|
||||
"""
|
||||
from kernbench.runtime_api.kernel import PeDmaMsg
|
||||
from kernbench.policy.address.phyaddr import PhysAddr as PA
|
||||
from kernbench.policy.routing.router import AddressResolver, PathRouter
|
||||
from kernbench.topology.builder import load_topology as lt
|
||||
|
||||
graph = lt(TOPOLOGY_PATH)
|
||||
engine = GraphEngine(graph)
|
||||
resolver = AddressResolver(graph)
|
||||
router = PathRouter(graph)
|
||||
|
||||
pa = _hbm_pa(sip=0, cube=0, pe_id=0)
|
||||
pa_obj = PA.decode(pa)
|
||||
dst_node = resolver.resolve(pa_obj)
|
||||
pe_ref = "sip0.cube0.pe0"
|
||||
path = router.find_path(pe_ref, dst_node)
|
||||
formula = engine._formula_latency(path, 4096)
|
||||
|
||||
msg = PeDmaMsg(
|
||||
correlation_id="c0", request_id="formula-exact",
|
||||
src_sip=0, src_cube=0, src_pe=0,
|
||||
dst_pa=pa, nbytes=4096,
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
_, trace = engine.get_completion(h)
|
||||
actual = trace["total_ns"]
|
||||
|
||||
# No contention: formula should equal actual
|
||||
assert abs(formula - actual) < 0.01, (
|
||||
f"formula ({formula:.4f}) ≈ actual ({actual:.4f}) expected with no contention"
|
||||
)
|
||||
|
||||
|
||||
# ── 10. remote cube access succeeds with higher latency ────────────
|
||||
|
||||
|
||||
def test_engine_remote_cube_latency_higher():
|
||||
"""Accessing a distant cube's HBM must have strictly higher latency than local.
|
||||
|
||||
Uses separate engines to avoid contention effects.
|
||||
cube15 (far corner of 4x4 mesh) requires multiple UCIe + NOC hops
|
||||
from IO chiplet compared to cube0 (directly connected).
|
||||
"""
|
||||
engine_local = _engine()
|
||||
engine_remote = _engine()
|
||||
msg_local = MemoryReadMsg(
|
||||
correlation_id="c0", request_id="r14a",
|
||||
src_sip=0, src_cube=0, src_pe=0,
|
||||
src_pa=_hbm_pa(sip=0, cube=0, pe_id=0), nbytes=4096,
|
||||
)
|
||||
msg_remote = MemoryReadMsg(
|
||||
correlation_id="c0", request_id="r14b",
|
||||
src_sip=0, src_cube=0, src_pe=0,
|
||||
src_pa=_hbm_pa(sip=0, cube=15, pe_id=0), nbytes=4096,
|
||||
)
|
||||
h_local = engine_local.submit(msg_local)
|
||||
engine_local.wait(h_local)
|
||||
_, t_local = engine_local.get_completion(h_local)
|
||||
|
||||
h_remote = engine_remote.submit(msg_remote)
|
||||
engine_remote.wait(h_remote)
|
||||
comp_remote, t_remote = engine_remote.get_completion(h_remote)
|
||||
|
||||
assert comp_remote.ok is True
|
||||
assert t_remote is not None and t_local is not None
|
||||
assert t_remote["total_ns"] > t_local["total_ns"], (
|
||||
f"remote cube {t_remote['total_ns']:.2f} must > local {t_local['total_ns']:.2f}"
|
||||
)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,269 @@
|
||||
"""Phase A component infrastructure tests (ADR-0015).
|
||||
|
||||
Verifies:
|
||||
- TransitComponent, IoCpuComponent apply overhead_ns via run()
|
||||
- HbmCtrlComponent and SramComponent act as terminal nodes (succeed done)
|
||||
- MCpuComponent forwards when not terminal; completes when terminal + no ctx
|
||||
- ComponentRegistry resolves impl strings to correct concrete classes
|
||||
- GraphEngine passes ComponentContext to every component
|
||||
- ComponentContext.router and .resolver are correctly populated
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase, ComponentRegistry
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.components.impls import (
|
||||
HbmCtrlComponent,
|
||||
IoCpuComponent,
|
||||
MCpuComponent,
|
||||
PcieEpComponent,
|
||||
SramComponent,
|
||||
TransitComponent,
|
||||
)
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
from kernbench.topology.builder import load_topology
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
|
||||
def _node(impl: str, attrs: dict | None = None) -> Node:
|
||||
return Node(id="test.node", kind="test", impl=impl, attrs=attrs or {}, pos_mm=None)
|
||||
|
||||
|
||||
def _run_worker(comp: ComponentBase, env: simpy.Environment, txn: Transaction) -> None:
|
||||
"""Wire one in_port, start the component, inject txn, run env until done."""
|
||||
in_store: simpy.Store = simpy.Store(env)
|
||||
comp.in_ports["src"] = in_store
|
||||
comp.start(env)
|
||||
env.process(_inject(in_store, txn))
|
||||
env.run(until=txn.done)
|
||||
|
||||
|
||||
def _inject(store: simpy.Store, txn: Transaction):
|
||||
yield store.put(txn)
|
||||
|
||||
|
||||
# ── 1. run() latency: TransitComponent ───────────────────────────────
|
||||
|
||||
|
||||
def test_transit_component_run_overhead_ns():
|
||||
"""TransitComponent.run() yields exactly overhead_ns."""
|
||||
node = _node("forwarding_v1", {"overhead_ns": 7.5})
|
||||
comp = TransitComponent(node)
|
||||
env = simpy.Environment()
|
||||
|
||||
def proc():
|
||||
yield from comp.run(env, nbytes=1024)
|
||||
|
||||
env.process(proc())
|
||||
env.run()
|
||||
assert env.now == pytest.approx(7.5)
|
||||
|
||||
|
||||
def test_transit_component_run_zero_overhead_ns():
|
||||
"""TransitComponent.run() with overhead_ns=0 completes immediately."""
|
||||
node = _node("noc_v1", {"overhead_ns": 0.0})
|
||||
comp = TransitComponent(node)
|
||||
env = simpy.Environment()
|
||||
done = []
|
||||
|
||||
def proc():
|
||||
yield from comp.run(env, nbytes=512)
|
||||
done.append(True)
|
||||
|
||||
env.process(proc())
|
||||
env.run()
|
||||
assert done == [True]
|
||||
assert env.now == pytest.approx(0.0)
|
||||
|
||||
|
||||
# ── 2. run() latency: IoCpuComponent ────────────────────────────────
|
||||
|
||||
|
||||
def test_io_cpu_component_run_overhead_ns():
|
||||
"""IoCpuComponent.run() yields exactly overhead_ns."""
|
||||
node = _node("io_cpu_v1", {"overhead_ns": 10.0})
|
||||
comp = IoCpuComponent(node)
|
||||
env = simpy.Environment()
|
||||
|
||||
def proc():
|
||||
yield from comp.run(env, nbytes=2048)
|
||||
|
||||
env.process(proc())
|
||||
env.run()
|
||||
assert env.now == pytest.approx(10.0)
|
||||
|
||||
|
||||
# ── 3. Terminal: HbmCtrlComponent succeeds done ──────────────────────
|
||||
|
||||
|
||||
def test_hbm_ctrl_terminal_succeeds_done():
|
||||
"""HbmCtrlComponent is a terminal node: succeeds txn.done after run()."""
|
||||
node = _node("hbm_ctrl_v1", {"overhead_ns": 0.0, "capacity": 1})
|
||||
comp = HbmCtrlComponent(node)
|
||||
env = simpy.Environment()
|
||||
done_event = env.event()
|
||||
txn = Transaction(request=None, path=["test.node"], step=0, nbytes=256, done=done_event)
|
||||
|
||||
_run_worker(comp, env, txn)
|
||||
|
||||
assert done_event.triggered
|
||||
|
||||
|
||||
def test_hbm_ctrl_resource_serializes_requests():
|
||||
"""HbmCtrlComponent with capacity=1 serializes concurrent requests."""
|
||||
node = _node("hbm_ctrl_v1", {"overhead_ns": 5.0, "capacity": 1})
|
||||
comp = HbmCtrlComponent(node)
|
||||
env = simpy.Environment()
|
||||
in_store: simpy.Store = simpy.Store(env)
|
||||
comp.in_ports["src"] = in_store
|
||||
comp.start(env)
|
||||
|
||||
done1 = env.event()
|
||||
done2 = env.event()
|
||||
txn1 = Transaction(request=None, path=["test.node"], step=0, nbytes=0, done=done1)
|
||||
txn2 = Transaction(request=None, path=["test.node"], step=0, nbytes=0, done=done2)
|
||||
|
||||
def inject():
|
||||
yield in_store.put(txn1)
|
||||
yield in_store.put(txn2)
|
||||
|
||||
env.process(inject())
|
||||
env.run(until=done2)
|
||||
|
||||
# Both must be done; with serialization: t=5 + t=10
|
||||
assert done1.triggered
|
||||
assert done2.triggered
|
||||
assert env.now == pytest.approx(10.0)
|
||||
|
||||
|
||||
# ── 4. Terminal: SramComponent succeeds done ─────────────────────────
|
||||
|
||||
|
||||
def test_sram_terminal_succeeds_done():
|
||||
"""SramComponent is a terminal node: succeeds txn.done after run()."""
|
||||
node = _node("sram_v1", {"overhead_ns": 2.0})
|
||||
comp = SramComponent(node)
|
||||
env = simpy.Environment()
|
||||
done_event = env.event()
|
||||
txn = Transaction(request=None, path=["test.node"], step=0, nbytes=512, done=done_event)
|
||||
|
||||
_run_worker(comp, env, txn)
|
||||
|
||||
assert done_event.triggered
|
||||
assert env.now == pytest.approx(2.0)
|
||||
|
||||
|
||||
# ── 5. MCpuComponent: forward when not terminal ──────────────────────
|
||||
|
||||
|
||||
def test_m_cpu_forwards_when_not_terminal():
|
||||
"""MCpuComponent forwards Transaction to next hop when not terminal."""
|
||||
node = _node("m_cpu_v1", {"overhead_ns": 5.0})
|
||||
comp = MCpuComponent(node)
|
||||
env = simpy.Environment()
|
||||
|
||||
# Wire in_port and out_port for a two-hop path [src, test.node, next]
|
||||
in_store: simpy.Store = simpy.Store(env)
|
||||
out_store: simpy.Store = simpy.Store(env)
|
||||
comp.in_ports["src"] = in_store
|
||||
comp.out_ports["next"] = out_store
|
||||
comp.start(env)
|
||||
|
||||
done_event = env.event()
|
||||
txn = Transaction(
|
||||
request=None,
|
||||
path=["src", "test.node", "next"],
|
||||
step=1, # currently at test.node; next_hop = "next"
|
||||
nbytes=128,
|
||||
done=done_event,
|
||||
)
|
||||
|
||||
forwarded: list[Any] = []
|
||||
|
||||
def receiver():
|
||||
msg = yield out_store.get()
|
||||
forwarded.append(msg)
|
||||
msg.done.succeed()
|
||||
|
||||
env.process(receiver())
|
||||
|
||||
def inject():
|
||||
yield in_store.put(txn)
|
||||
|
||||
env.process(inject())
|
||||
env.run(until=done_event)
|
||||
|
||||
assert len(forwarded) == 1
|
||||
assert forwarded[0].step == 2 # advanced
|
||||
assert env.now == pytest.approx(5.0)
|
||||
|
||||
|
||||
# ── 6. MCpuComponent: terminal with no ctx just completes ────────────
|
||||
|
||||
|
||||
def test_m_cpu_terminal_no_ctx_completes():
|
||||
"""MCpuComponent without ctx completes txn.done when it is the terminal hop."""
|
||||
node = _node("m_cpu_v1", {"overhead_ns": 0.0})
|
||||
comp = MCpuComponent(node, ctx=None)
|
||||
env = simpy.Environment()
|
||||
done_event = env.event()
|
||||
txn = Transaction(request=None, path=["test.node"], step=0, nbytes=64, done=done_event)
|
||||
|
||||
_run_worker(comp, env, txn)
|
||||
|
||||
assert done_event.triggered
|
||||
|
||||
|
||||
# ── 7. ComponentRegistry resolves impl strings ───────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize("impl,expected_cls", [
|
||||
("forwarding_v1", TransitComponent),
|
||||
("noc_v1", TransitComponent),
|
||||
("ucie_v1", TransitComponent),
|
||||
("xbar_v1", TransitComponent),
|
||||
("pcie_ep_v1", PcieEpComponent),
|
||||
("io_cpu_v1", IoCpuComponent),
|
||||
("m_cpu_v1", MCpuComponent),
|
||||
("hbm_ctrl_v1", HbmCtrlComponent),
|
||||
("sram_v1", SramComponent),
|
||||
])
|
||||
def test_registry_resolves_impl(impl, expected_cls):
|
||||
"""ComponentRegistry.create() returns the correct concrete class for each impl."""
|
||||
node = _node(impl, {"overhead_ns": 0.0})
|
||||
comp = ComponentRegistry.create(node)
|
||||
assert isinstance(comp, expected_cls)
|
||||
|
||||
|
||||
# ── 8. GraphEngine passes ComponentContext to components ─────────────
|
||||
|
||||
|
||||
def test_engine_passes_ctx_to_components():
|
||||
"""GraphEngine injects a non-None ComponentContext into every component."""
|
||||
graph = load_topology(TOPOLOGY_PATH)
|
||||
engine = GraphEngine(graph)
|
||||
for node_id, comp in engine._components.items():
|
||||
assert comp.ctx is not None, f"{node_id}: ctx is None"
|
||||
assert isinstance(comp.ctx, ComponentContext), f"{node_id}: ctx wrong type"
|
||||
|
||||
|
||||
def test_engine_ctx_router_and_resolver_populated():
|
||||
"""ComponentContext.router and .resolver are PathRouter / AddressResolver instances."""
|
||||
from kernbench.policy.routing.router import AddressResolver, PathRouter
|
||||
|
||||
graph = load_topology(TOPOLOGY_PATH)
|
||||
engine = GraphEngine(graph)
|
||||
# Spot-check one component
|
||||
first_comp = next(iter(engine._components.values()))
|
||||
assert isinstance(first_comp.ctx.router, PathRouter)
|
||||
assert isinstance(first_comp.ctx.resolver, AddressResolver)
|
||||
@@ -0,0 +1,268 @@
|
||||
import pytest
|
||||
|
||||
from kernbench.policy.address.allocator import AddressConfig, AllocationError, PEMemAllocator
|
||||
from kernbench.policy.address.phyaddr import PhysAddr, PhysAddrError, UnitType
|
||||
|
||||
_MB = 1 << 20
|
||||
_GB = 1 << 30
|
||||
|
||||
# Topology-matching config: 48GB HBM / 8 slices / 16MB TCM / 4MB reserved / 32MB SRAM
|
||||
_CFG = AddressConfig(
|
||||
sip_count=2,
|
||||
cubes_per_sip=16,
|
||||
pes_per_cube=8,
|
||||
hbm_bytes_per_cube=48 * _GB,
|
||||
hbm_slices_per_cube=8,
|
||||
tcm_bytes_per_pe=16 * _MB,
|
||||
tcm_scheduler_reserved_bytes=4 * _MB,
|
||||
sram_bytes_per_cube=32 * _MB,
|
||||
)
|
||||
|
||||
|
||||
# ── Immutability & value semantics ──────────────────────────────────
|
||||
|
||||
|
||||
def test_physaddr_immutable():
|
||||
pa = PhysAddr.hbm_addr(rack_id=0, sip_id=0, cube_id=0, hbm_offset=0)
|
||||
with pytest.raises(AttributeError):
|
||||
pa.rack_id = 1 # type: ignore[misc]
|
||||
# hashable
|
||||
{pa}
|
||||
# comparable
|
||||
pa2 = PhysAddr.hbm_addr(rack_id=0, sip_id=0, cube_id=0, hbm_offset=0)
|
||||
assert pa == pa2
|
||||
|
||||
|
||||
# ── HBM encode/decode roundtrip ────────────────────────────────────
|
||||
|
||||
|
||||
def test_hbm_encode_decode_roundtrip():
|
||||
pa = PhysAddr.hbm_addr(rack_id=2, sip_id=3, cube_id=5, hbm_offset=0x1000)
|
||||
raw = pa.encode()
|
||||
dec = PhysAddr.decode(raw)
|
||||
assert dec.rack_id == 2
|
||||
assert dec.sip_id == 3
|
||||
assert dec.cube_id == 5
|
||||
assert dec.kind == "hbm"
|
||||
assert dec.hbm_offset == 0x1000
|
||||
|
||||
|
||||
# ── PE resource encode/decode roundtrip ─────────────────────────────
|
||||
|
||||
|
||||
def test_pe_resource_encode_decode_roundtrip():
|
||||
pa = PhysAddr(
|
||||
rack_id=1, sip_id=2, sip_seg=7, local_offset=0,
|
||||
kind="pe_resource", cube_id=7,
|
||||
unit_type=UnitType.PE, pe_id=3, ext=1, sub_offset=0xFF,
|
||||
)
|
||||
# manually build local_offset matching bit layout
|
||||
local_offset = (UnitType.PE << 34) | (3 << 30) | (1 << 29) | 0xFF
|
||||
pa2 = PhysAddr(
|
||||
rack_id=1, sip_id=2, sip_seg=7, local_offset=local_offset,
|
||||
kind="pe_resource", cube_id=7,
|
||||
unit_type=UnitType.PE, pe_id=3, ext=1, sub_offset=0xFF,
|
||||
)
|
||||
raw = pa2.encode()
|
||||
dec = PhysAddr.decode(raw)
|
||||
assert dec.kind == "pe_resource"
|
||||
assert dec.unit_type == UnitType.PE
|
||||
assert dec.pe_id == 3
|
||||
assert dec.ext == 1
|
||||
assert dec.sub_offset == 0xFF
|
||||
|
||||
|
||||
# ── pe_hbm_addr factory ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_pe_hbm_addr_factory():
|
||||
SLICE = 6 * (1 << 30) # 6 GB per PE slice
|
||||
pa = PhysAddr.pe_hbm_addr(
|
||||
rack_id=0, sip_id=0, cube_id=0,
|
||||
pe_id=2, pe_local_hbm_offset=1024, slice_size_bytes=SLICE,
|
||||
)
|
||||
assert pa.kind == "hbm"
|
||||
assert pa.cube_id == 0
|
||||
assert pa.hbm_offset == 2 * SLICE + 1024
|
||||
|
||||
|
||||
def test_pe_hbm_addr_overflow():
|
||||
SLICE = 6 * (1 << 30)
|
||||
with pytest.raises(PhysAddrError, match="pe_local_hbm_offset"):
|
||||
PhysAddr.pe_hbm_addr(
|
||||
rack_id=0, sip_id=0, cube_id=0,
|
||||
pe_id=0, pe_local_hbm_offset=SLICE, slice_size_bytes=SLICE,
|
||||
)
|
||||
|
||||
|
||||
# ── Invalid unit_type decode (fix #1) ──────────────────────────────
|
||||
|
||||
|
||||
def test_invalid_unit_type_raises():
|
||||
# Craft a PE-resource address with unit_type=7 (invalid)
|
||||
local_offset = (7 << 34) | (0 << 30) | 0
|
||||
pa_raw = PhysAddr(
|
||||
rack_id=0, sip_id=0, sip_seg=0, local_offset=local_offset,
|
||||
)
|
||||
raw = pa_raw.encode()
|
||||
with pytest.raises(PhysAddrError, match="unit_type"):
|
||||
PhysAddr.decode(raw)
|
||||
|
||||
|
||||
# ── hbm_pe_id utility (fix #3) ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_hbm_pe_id_utility():
|
||||
SLICE = 6 * (1 << 30) # 6 GB
|
||||
pa = PhysAddr.pe_hbm_addr(
|
||||
rack_id=0, sip_id=0, cube_id=0,
|
||||
pe_id=5, pe_local_hbm_offset=256, slice_size_bytes=SLICE,
|
||||
)
|
||||
assert PhysAddr.hbm_pe_id(pa.hbm_offset, SLICE) == 5
|
||||
|
||||
|
||||
# ── UnitType.SRAM exists (fix #5) ──────────────────────────────────
|
||||
|
||||
|
||||
def test_sram_unit_type_exists():
|
||||
assert UnitType.SRAM == 2
|
||||
|
||||
|
||||
# ── cube_sram_addr factory + roundtrip ──────────────────────────────
|
||||
|
||||
|
||||
def test_cube_sram_addr_roundtrip():
|
||||
pa = PhysAddr.cube_sram_addr(
|
||||
rack_id=0, sip_id=1, cube_id=3, sram_offset=0x800,
|
||||
)
|
||||
assert pa.kind == "pe_resource"
|
||||
assert pa.unit_type == UnitType.SRAM
|
||||
assert pa.cube_id == 3
|
||||
assert pa.sub_offset == 0x800
|
||||
# encode → decode roundtrip
|
||||
dec = PhysAddr.decode(pa.encode())
|
||||
assert dec.unit_type == UnitType.SRAM
|
||||
assert dec.cube_id == 3
|
||||
assert dec.sub_offset == 0x800
|
||||
|
||||
|
||||
def test_cube_sram_addr_range_check():
|
||||
with pytest.raises(PhysAddrError):
|
||||
PhysAddr.cube_sram_addr(
|
||||
rack_id=0, sip_id=0, cube_id=0,
|
||||
sram_offset=(1 << 29), # exceeds 29-bit sub_offset
|
||||
)
|
||||
|
||||
|
||||
# ── pe_tcm_addr factory + roundtrip ────────────────────────────────
|
||||
|
||||
|
||||
def test_pe_tcm_addr_roundtrip():
|
||||
pa = PhysAddr.pe_tcm_addr(
|
||||
rack_id=0, sip_id=0, cube_id=2, pe_id=7, tcm_offset=0x400,
|
||||
)
|
||||
assert pa.kind == "pe_resource"
|
||||
assert pa.unit_type == UnitType.PE
|
||||
assert pa.pe_id == 7
|
||||
assert pa.cube_id == 2
|
||||
assert pa.sub_offset == 0x400
|
||||
# encode → decode roundtrip
|
||||
dec = PhysAddr.decode(pa.encode())
|
||||
assert dec.unit_type == UnitType.PE
|
||||
assert dec.pe_id == 7
|
||||
assert dec.sub_offset == 0x400
|
||||
|
||||
|
||||
def test_pe_tcm_addr_range_check():
|
||||
with pytest.raises(PhysAddrError):
|
||||
PhysAddr.pe_tcm_addr(
|
||||
rack_id=0, sip_id=0, cube_id=0, pe_id=0,
|
||||
tcm_offset=(1 << 29), # exceeds 29-bit sub_offset
|
||||
)
|
||||
|
||||
|
||||
# ── AddressConfig ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_address_config_derived_sizes():
|
||||
assert _CFG.hbm_slice_bytes == 6 * _GB
|
||||
assert _CFG.tcm_allocatable_bytes == 12 * _MB
|
||||
|
||||
|
||||
# ── PEMemAllocator: HBM ────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_alloc(pe_id: int = 0) -> PEMemAllocator:
|
||||
return PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=pe_id, cfg=_CFG)
|
||||
|
||||
|
||||
def test_allocator_hbm_basic():
|
||||
a = _make_alloc(pe_id=3)
|
||||
pa = a.alloc_hbm(4096)
|
||||
assert pa.kind == "hbm"
|
||||
assert pa.sip_id == 0
|
||||
assert pa.cube_id == 0
|
||||
# hbm_offset should be pe3's slice start
|
||||
assert pa.hbm_offset == 3 * 6 * _GB
|
||||
|
||||
|
||||
def test_allocator_hbm_sequential():
|
||||
a = _make_alloc()
|
||||
pa1 = a.alloc_hbm(1024)
|
||||
pa2 = a.alloc_hbm(2048)
|
||||
assert pa1.hbm_offset == 0 # pe0 slice start + 0
|
||||
assert pa2.hbm_offset == 1024 # pe0 slice start + 1024
|
||||
|
||||
|
||||
def test_allocator_hbm_overflow():
|
||||
a = _make_alloc()
|
||||
a.alloc_hbm(6 * _GB - 256)
|
||||
with pytest.raises(AllocationError, match="HBM"):
|
||||
a.alloc_hbm(512)
|
||||
|
||||
|
||||
# ── PEMemAllocator: TCM ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_allocator_tcm_basic():
|
||||
a = _make_alloc(pe_id=5)
|
||||
pa = a.alloc_tcm(256)
|
||||
assert pa.kind == "pe_resource"
|
||||
assert pa.unit_type == UnitType.PE
|
||||
assert pa.pe_id == 5
|
||||
assert pa.sub_offset == 0
|
||||
|
||||
|
||||
def test_allocator_tcm_respects_reserved():
|
||||
a = _make_alloc()
|
||||
# allocatable = 12 MB, should succeed
|
||||
a.alloc_tcm(12 * _MB)
|
||||
assert a.tcm_used == 12 * _MB
|
||||
assert a.tcm_total == 12 * _MB
|
||||
|
||||
|
||||
def test_allocator_tcm_overflow():
|
||||
a = _make_alloc()
|
||||
a.alloc_tcm(12 * _MB)
|
||||
with pytest.raises(AllocationError, match="TCM"):
|
||||
a.alloc_tcm(1)
|
||||
|
||||
|
||||
# ── PEMemAllocator: stats & determinism ─────────────────────────────
|
||||
|
||||
|
||||
def test_allocator_stats():
|
||||
a = _make_alloc()
|
||||
a.alloc_hbm(1000)
|
||||
a.alloc_tcm(500)
|
||||
assert a.hbm_used == 1000
|
||||
assert a.hbm_total == 6 * _GB
|
||||
assert a.tcm_used == 500
|
||||
assert a.tcm_total == 12 * _MB
|
||||
|
||||
|
||||
def test_allocator_deterministic():
|
||||
a1 = _make_alloc(pe_id=2)
|
||||
a2 = _make_alloc(pe_id=2)
|
||||
assert a1.alloc_hbm(4096) == a2.alloc_hbm(4096)
|
||||
assert a1.alloc_tcm(256) == a2.alloc_tcm(256)
|
||||
@@ -0,0 +1,221 @@
|
||||
"""Tests for H2D writes and PE DMA probe latency invariants.
|
||||
|
||||
H2D tests use MemoryWriteMsg (pcie_ep → io_cpu → m_cpu → hbm_ctrl → response).
|
||||
PE DMA tests use PeDmaMsg (direct pe_dma → xbar → hbm_ctrl injection).
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
from kernbench.policy.routing.router import AddressResolver, PathRouter
|
||||
from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import load_topology
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
|
||||
def _engine():
|
||||
return GraphEngine(load_topology(TOPOLOGY_PATH))
|
||||
|
||||
|
||||
def _hbm_pa(sip: int = 0, cube: int = 0, pe_id: int = 0) -> int:
|
||||
slice_bytes = 48 * (1 << 30) // 8
|
||||
pa = PhysAddr.pe_hbm_addr(
|
||||
rack_id=0, sip_id=sip, cube_id=cube, pe_id=pe_id,
|
||||
pe_local_hbm_offset=0x1000, slice_size_bytes=slice_bytes,
|
||||
)
|
||||
return pa.encode()
|
||||
|
||||
|
||||
def _h2d_latency(dst_cube: int, dst_pe: int = 0) -> float:
|
||||
engine = _engine()
|
||||
msg = MemoryWriteMsg(
|
||||
correlation_id="probe", request_id=f"h2d-c{dst_cube}-p{dst_pe}",
|
||||
dst_sip=0, dst_cube=dst_cube, dst_pe=dst_pe,
|
||||
dst_pa=_hbm_pa(sip=0, cube=dst_cube, pe_id=dst_pe), nbytes=4096,
|
||||
pattern="zero", target_pe=dst_pe,
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
_, trace = engine.get_completion(h)
|
||||
return trace["total_ns"]
|
||||
|
||||
|
||||
# ── 1. Single-PE write completes ──────────────────────────────────
|
||||
|
||||
|
||||
def test_single_pe_write_completes():
|
||||
"""MemoryWriteMsg(target_pe=0) must complete with ok=True, latency > 0."""
|
||||
engine = _engine()
|
||||
msg = MemoryWriteMsg(
|
||||
correlation_id="probe", request_id="pe-local",
|
||||
dst_sip=0, dst_cube=0, dst_pe=0,
|
||||
dst_pa=_hbm_pa(sip=0, cube=0, pe_id=0), nbytes=4096,
|
||||
pattern="zero", target_pe=0,
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
comp, trace = engine.get_completion(h)
|
||||
assert comp.ok is True
|
||||
assert trace["total_ns"] > 0
|
||||
|
||||
|
||||
# ── 2. Cross-cube write positive latency ─────────────────────────
|
||||
|
||||
|
||||
def test_cross_cube_write_positive():
|
||||
"""Cross-cube MemoryWriteMsg(target_pe=0) must complete with latency > 0."""
|
||||
lat = _h2d_latency(dst_cube=1, dst_pe=0)
|
||||
assert lat > 0
|
||||
|
||||
|
||||
# ── 3. H2D latency monotonicity ──────────────────────────────────
|
||||
|
||||
|
||||
def test_h2d_latency_monotonic():
|
||||
"""1hop < 2hop < 3hop < 4hop."""
|
||||
cubes = [0, 4, 8, 12]
|
||||
latencies: list[tuple[int, float]] = []
|
||||
for cube in cubes:
|
||||
lat = _h2d_latency(dst_cube=cube, dst_pe=0)
|
||||
latencies.append((cube, lat))
|
||||
|
||||
for i in range(len(latencies) - 1):
|
||||
assert latencies[i][1] < latencies[i + 1][1], (
|
||||
f"cube{latencies[i][0]}({latencies[i][1]:.2f}) "
|
||||
f"must < cube{latencies[i + 1][0]}({latencies[i + 1][1]:.2f})"
|
||||
)
|
||||
|
||||
|
||||
# ── 4. Single-PE write deterministic ─────────────────────────────
|
||||
|
||||
|
||||
def test_single_pe_write_deterministic():
|
||||
"""Same MemoryWriteMsg on two engines must produce identical latency."""
|
||||
msg = MemoryWriteMsg(
|
||||
correlation_id="probe", request_id="det",
|
||||
dst_sip=0, dst_cube=0, dst_pe=0,
|
||||
dst_pa=_hbm_pa(sip=0, cube=0, pe_id=0), nbytes=4096,
|
||||
pattern="zero", target_pe=0,
|
||||
)
|
||||
e1 = _engine()
|
||||
h1 = e1.submit(msg)
|
||||
e1.wait(h1)
|
||||
_, t1 = e1.get_completion(h1)
|
||||
|
||||
e2 = _engine()
|
||||
h2 = e2.submit(msg)
|
||||
e2.wait(h2)
|
||||
_, t2 = e2.get_completion(h2)
|
||||
|
||||
assert t1["total_ns"] == t2["total_ns"]
|
||||
|
||||
|
||||
# ── 5. Cut-through (wormhole) wire model invariants ──────────────
|
||||
|
||||
|
||||
def test_h2d_local_cube_cut_through():
|
||||
"""H2D to local cube with cut-through should be < 50ns for 4096B.
|
||||
|
||||
Full command path: pcie_ep → io_cpu → ucie → noc → m_cpu
|
||||
DMA: m_cpu → noc → xbar → hbm_ctrl (drain once at terminal)
|
||||
Plus response path back.
|
||||
With store-and-forward each hop would serialize; cut-through keeps it low.
|
||||
"""
|
||||
lat = _h2d_latency(dst_cube=0, dst_pe=0)
|
||||
assert lat < 65.0, f"Local H2D {lat:.2f}ns; cut-through expects < 65ns"
|
||||
|
||||
|
||||
def test_h2d_remote_cube_cut_through():
|
||||
"""H2D to 1-hop remote cube: cut-through drain dominates, not per-hop serialization.
|
||||
|
||||
With store-and-forward, each hop would serialize 4096B, total >> 100ns.
|
||||
With cut-through, drain happens once at bottleneck.
|
||||
"""
|
||||
lat = _h2d_latency(dst_cube=4, dst_pe=0)
|
||||
assert lat < 80.0, f"Remote H2D {lat:.2f}ns; cut-through expects < 80ns"
|
||||
|
||||
|
||||
# ── 6. PE DMA: direct injection tests ─────────────────────────
|
||||
|
||||
|
||||
def _graph():
|
||||
return load_topology(TOPOLOGY_PATH)
|
||||
|
||||
|
||||
def _pe_dma_latency(src_cube: int, src_pe: int, dst_pe: int) -> float:
|
||||
engine = _engine()
|
||||
msg = PeDmaMsg(
|
||||
correlation_id="probe", request_id=f"dma-c{src_cube}-p{src_pe}-s{dst_pe}",
|
||||
src_sip=0, src_cube=src_cube, src_pe=src_pe,
|
||||
dst_pa=_hbm_pa(sip=0, cube=src_cube, pe_id=dst_pe), nbytes=4096,
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
_, trace = engine.get_completion(h)
|
||||
return trace["total_ns"]
|
||||
|
||||
|
||||
def _pe_dma_bottleneck(src_cube: int, src_pe: int, dst_pe: int) -> float | None:
|
||||
graph = _graph()
|
||||
edge_map = {(e.src, e.dst): e for e in graph.edges}
|
||||
resolver = AddressResolver(graph)
|
||||
router = PathRouter(graph)
|
||||
pa = _hbm_pa(sip=0, cube=src_cube, pe_id=dst_pe)
|
||||
pa_obj = PhysAddr.decode(pa)
|
||||
dst_node = resolver.resolve(pa_obj)
|
||||
pe_ref = f"sip0.cube{src_cube}.pe{src_pe}"
|
||||
path = router.find_path(pe_ref, dst_node)
|
||||
bws: list[float] = []
|
||||
for i in range(len(path) - 1):
|
||||
e = edge_map.get((path[i], path[i + 1]))
|
||||
if e and e.bw_gbs:
|
||||
bws.append(e.bw_gbs)
|
||||
return min(bws) if bws else None
|
||||
|
||||
|
||||
def test_pe_dma_local_completes():
|
||||
"""PeDmaMsg to local slice0 must complete with ok=True, latency > 0."""
|
||||
engine = _engine()
|
||||
msg = PeDmaMsg(
|
||||
correlation_id="probe", request_id="dma-local",
|
||||
src_sip=0, src_cube=0, src_pe=0,
|
||||
dst_pa=_hbm_pa(sip=0, cube=0, pe_id=0), nbytes=4096,
|
||||
)
|
||||
h = engine.submit(msg)
|
||||
engine.wait(h)
|
||||
comp, trace = engine.get_completion(h)
|
||||
assert comp.ok is True
|
||||
assert trace["total_ns"] > 0
|
||||
|
||||
|
||||
def test_pe_dma_local_bottleneck_256():
|
||||
"""PE DMA pe0→slice0 (local): bottleneck = 256 GB/s (direct xbar→hbm)."""
|
||||
bn = _pe_dma_bottleneck(src_cube=0, src_pe=0, dst_pe=0)
|
||||
assert bn == 256.0, f"Local PE DMA bottleneck {bn}, expected 256.0"
|
||||
|
||||
|
||||
def test_pe_dma_chain_bottleneck_128():
|
||||
"""PE DMA pe0→slice1 (xbar chain): bottleneck = 128 GB/s."""
|
||||
bn = _pe_dma_bottleneck(src_cube=0, src_pe=0, dst_pe=1)
|
||||
assert bn == 128.0, f"Chain PE DMA bottleneck {bn}, expected 128.0"
|
||||
|
||||
|
||||
def test_pe_dma_deterministic():
|
||||
"""Same PeDmaMsg on two engines must produce identical latency."""
|
||||
msg = PeDmaMsg(
|
||||
correlation_id="probe", request_id="det",
|
||||
src_sip=0, src_cube=0, src_pe=0,
|
||||
dst_pa=_hbm_pa(sip=0, cube=0, pe_id=0), nbytes=4096,
|
||||
)
|
||||
e1 = _engine()
|
||||
h1 = e1.submit(msg)
|
||||
e1.wait(h1)
|
||||
_, t1 = e1.get_completion(h1)
|
||||
|
||||
e2 = _engine()
|
||||
h2 = e2.submit(msg)
|
||||
e2.wait(h2)
|
||||
_, t2 = e2.get_completion(h2)
|
||||
|
||||
assert t1["total_ns"] == t2["total_ns"]
|
||||
@@ -0,0 +1,226 @@
|
||||
import pytest
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.policy.address.phyaddr import PhysAddr, UnitType
|
||||
from kernbench.policy.routing.router import AddressResolver, PathRouter, RoutingError
|
||||
from kernbench.topology.builder import load_topology
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
|
||||
def _graph():
|
||||
return load_topology(TOPOLOGY_PATH)
|
||||
|
||||
|
||||
# ── AddressResolver ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_resolve_hbm_addr():
|
||||
"""HBM address -> sip{S}.cube{C}.hbm_ctrl.slice{P}"""
|
||||
g = _graph()
|
||||
resolver = AddressResolver(g)
|
||||
# hbm_offset=0x1000, slice_size=6GB -> slice 0
|
||||
pa = PhysAddr.hbm_addr(rack_id=0, sip_id=0, cube_id=3, hbm_offset=0x1000)
|
||||
assert resolver.resolve(pa) == "sip0.cube3.hbm_ctrl.slice0"
|
||||
|
||||
|
||||
def test_resolve_hbm_addr_slice4():
|
||||
"""HBM address in PE4's slice range -> slice4."""
|
||||
g = _graph()
|
||||
resolver = AddressResolver(g)
|
||||
# slice_size = 6GB; PE4 offset starts at 4*6GB = 24GB = 0x600000000
|
||||
pa = PhysAddr.hbm_addr(rack_id=0, sip_id=0, cube_id=0, hbm_offset=0x600000000)
|
||||
assert resolver.resolve(pa) == "sip0.cube0.hbm_ctrl.slice4"
|
||||
|
||||
|
||||
def test_resolve_pe_tcm_addr():
|
||||
"""PE TCM address → sip{S}.cube{C}.pe{P}.pe_tcm"""
|
||||
g = _graph()
|
||||
resolver = AddressResolver(g)
|
||||
pa = PhysAddr.pe_tcm_addr(rack_id=0, sip_id=1, cube_id=5, pe_id=7, tcm_offset=0x400)
|
||||
assert resolver.resolve(pa) == "sip1.cube5.pe7.pe_tcm"
|
||||
|
||||
|
||||
def test_resolve_sram_addr():
|
||||
"""SRAM address → sip{S}.cube{C}.sram"""
|
||||
g = _graph()
|
||||
resolver = AddressResolver(g)
|
||||
pa = PhysAddr.cube_sram_addr(rack_id=0, sip_id=0, cube_id=10, sram_offset=0x800)
|
||||
assert resolver.resolve(pa) == "sip0.cube10.sram"
|
||||
|
||||
|
||||
def test_resolve_mcpu_addr():
|
||||
"""MCPU pe_resource address → sip{S}.cube{C}.m_cpu"""
|
||||
g = _graph()
|
||||
resolver = AddressResolver(g)
|
||||
pa = PhysAddr(
|
||||
rack_id=0, sip_id=0, sip_seg=2, local_offset=(UnitType.MCPU << 34),
|
||||
kind="pe_resource", cube_id=2, unit_type=UnitType.MCPU,
|
||||
)
|
||||
assert resolver.resolve(pa) == "sip0.cube2.m_cpu"
|
||||
|
||||
|
||||
def test_resolve_nonexistent_node():
|
||||
"""Address pointing to a node outside the compiled topology raises RoutingError."""
|
||||
g = _graph()
|
||||
resolver = AddressResolver(g)
|
||||
# sip_id=15 doesn't exist in the 2-SIP topology
|
||||
pa = PhysAddr.hbm_addr(rack_id=0, sip_id=15, cube_id=0, hbm_offset=0)
|
||||
with pytest.raises(RoutingError):
|
||||
resolver.resolve(pa)
|
||||
|
||||
|
||||
# ── PathRouter: local HBM (same xbar half) ──────────────────────────
|
||||
|
||||
|
||||
def test_path_local_hbm_same_half():
|
||||
"""PE0 -> slice0 (local): pe_dma -> xbar.pe0 -> hbm_ctrl.slice0 (no chain hops)."""
|
||||
g = _graph()
|
||||
router = PathRouter(g)
|
||||
path = router.find_path("sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice0")
|
||||
assert path[0] == "sip0.cube0.pe0.pe_dma"
|
||||
assert "sip0.cube0.xbar.pe0" in path
|
||||
assert path[-1] == "sip0.cube0.hbm_ctrl.slice0"
|
||||
# local access: no bridge and no chain traversal (shortest path = 3 nodes)
|
||||
assert not any("bridge" in n for n in path)
|
||||
assert len(path) == 3 # pe_dma → xbar.pe0 → slice0
|
||||
|
||||
|
||||
# ── PathRouter: same-half remote HBM ────────────────────────────────
|
||||
|
||||
|
||||
def test_path_same_half_remote_hbm():
|
||||
"""PE0 -> slice1: same-half chain traversal pe0→pe1, no bridge."""
|
||||
g = _graph()
|
||||
router = PathRouter(g)
|
||||
path = router.find_path("sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice1")
|
||||
assert path[0] == "sip0.cube0.pe0.pe_dma"
|
||||
assert "sip0.cube0.xbar.pe0" in path # enter at pe0
|
||||
assert "sip0.cube0.xbar.pe1" in path # chain hop to pe1
|
||||
assert path[-1] == "sip0.cube0.hbm_ctrl.slice1"
|
||||
assert not any("bridge" in n for n in path)
|
||||
assert len(path) == 4 # pe_dma → xbar.pe0 → xbar.pe1 → slice1
|
||||
|
||||
|
||||
# ── PathRouter: cross-half HBM ──────────────────────────────────────
|
||||
|
||||
|
||||
def test_path_cross_half_hbm():
|
||||
"""PE0 -> slice4 (cross-half): pe_dma → xbar.pe0 → bridge.left → xbar.pe4 → slice4."""
|
||||
g = _graph()
|
||||
router = PathRouter(g)
|
||||
path = router.find_path("sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice4")
|
||||
assert path[0] == "sip0.cube0.pe0.pe_dma"
|
||||
assert "sip0.cube0.xbar.pe0" in path
|
||||
assert any("bridge" in n for n in path), "cross-half HBM must traverse bridge"
|
||||
assert "sip0.cube0.xbar.pe4" in path
|
||||
assert path[-1] == "sip0.cube0.hbm_ctrl.slice4"
|
||||
# Shortest cross-half path: pe_dma → xbar.pe0 → bridge.left → xbar.pe4 → slice4
|
||||
assert len(path) == 5
|
||||
|
||||
|
||||
def test_path_cross_half_requires_bridge():
|
||||
"""PE4 (bottom) -> slice2 (top) requires bridge traversal."""
|
||||
g = _graph()
|
||||
router = PathRouter(g)
|
||||
path = router.find_path("sip0.cube0.pe4", "sip0.cube0.hbm_ctrl.slice2")
|
||||
assert any("bridge" in n for n in path), "cross-half HBM must traverse bridge"
|
||||
assert any("xbar.pe" in n for n in path)
|
||||
assert path[-1] == "sip0.cube0.hbm_ctrl.slice2"
|
||||
|
||||
|
||||
def test_cross_half_distance_greater():
|
||||
"""Cross-half HBM access must have greater distance than local-half."""
|
||||
g = _graph()
|
||||
router = PathRouter(g)
|
||||
_, dist_local = router.find_path_with_distance(
|
||||
"sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice0")
|
||||
_, dist_cross = router.find_path_with_distance(
|
||||
"sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice4")
|
||||
assert dist_cross > dist_local
|
||||
|
||||
|
||||
def test_path_same_half_remote_longer():
|
||||
"""Same-half remote HBM (PE0->slice3) has greater distance than local (PE0->slice0)."""
|
||||
g = _graph()
|
||||
router = PathRouter(g)
|
||||
_, dist_local = router.find_path_with_distance(
|
||||
"sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice0")
|
||||
_, dist_remote = router.find_path_with_distance(
|
||||
"sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice3")
|
||||
assert dist_remote > dist_local, (
|
||||
f"same-half remote ({dist_remote:.2f}mm) must > local ({dist_local:.2f}mm)"
|
||||
)
|
||||
|
||||
|
||||
def test_path_remote_cube_hbm():
|
||||
"""PE0 in cube0 can reach HBM in cube1 via UCIe (ADR-0004 D4)."""
|
||||
g = _graph()
|
||||
router = PathRouter(g)
|
||||
path = router.find_path("sip0.cube0.pe0", "sip0.cube1.hbm_ctrl.slice0")
|
||||
assert path[0] == "sip0.cube0.pe0.pe_dma"
|
||||
assert path[-1] == "sip0.cube1.hbm_ctrl.slice0"
|
||||
# inter-cube path must cross a UCIe link
|
||||
assert any("ucie" in n for n in path), "remote cube path must traverse UCIe"
|
||||
# must not be trivially short (needs noc + ucie + remote noc + xbar)
|
||||
assert len(path) >= 5
|
||||
|
||||
|
||||
# ── PathRouter: SRAM via NOC ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_path_sram_via_noc():
|
||||
"""PE → SRAM must go through NOC (non-HBM data path)."""
|
||||
g = _graph()
|
||||
router = PathRouter(g)
|
||||
path = router.find_path("sip0.cube0.pe0", "sip0.cube0.sram")
|
||||
assert path[0] == "sip0.cube0.pe0.pe_dma"
|
||||
assert "sip0.cube0.noc" in path
|
||||
assert path[-1] == "sip0.cube0.sram"
|
||||
# should NOT go through xbar (SRAM is non-HBM path)
|
||||
assert not any("xbar" in n for n in path)
|
||||
|
||||
|
||||
# ── PathRouter: PE TCM (local) ──────────────────────────────────────
|
||||
|
||||
|
||||
def test_path_local_tcm():
|
||||
"""PE0 → own TCM is PE-internal, not via xbar or noc."""
|
||||
g = _graph()
|
||||
router = PathRouter(g)
|
||||
path = router.find_path("sip0.cube0.pe0", "sip0.cube0.pe0.pe_tcm")
|
||||
assert path[0] == "sip0.cube0.pe0.pe_dma"
|
||||
assert path[-1] == "sip0.cube0.pe0.pe_tcm"
|
||||
# PE-internal path, no fabric
|
||||
assert not any("xbar" in n or "noc" in n for n in path)
|
||||
|
||||
|
||||
# ── PathRouter: distance monotonic ──────────────────────────────────
|
||||
|
||||
|
||||
def test_path_distance_positive():
|
||||
"""All routed paths must have accumulated distance > 0 (ADR-0002 D4)."""
|
||||
g = _graph()
|
||||
router = PathRouter(g)
|
||||
_, dist = router.find_path_with_distance("sip0.cube0.pe0", "sip0.cube0.hbm_ctrl.slice0")
|
||||
assert dist > 0
|
||||
|
||||
|
||||
def test_path_deterministic():
|
||||
"""Same (src, dst) must always produce the same path."""
|
||||
g = _graph()
|
||||
r1 = PathRouter(g)
|
||||
r2 = PathRouter(g)
|
||||
p1 = r1.find_path("sip0.cube0.pe3", "sip0.cube0.hbm_ctrl.slice3")
|
||||
p2 = r2.find_path("sip0.cube0.pe3", "sip0.cube0.hbm_ctrl.slice3")
|
||||
assert p1 == p2
|
||||
|
||||
|
||||
def test_remote_cube_path_no_routing_error():
|
||||
"""Routing to remote cube HBM must not raise RoutingError (ADR-0004 D4)."""
|
||||
g = _graph()
|
||||
router = PathRouter(g)
|
||||
# cube0.PE0 -> cube1.slice0 (adjacent cube, E direction)
|
||||
path = router.find_path("sip0.cube0.pe0", "sip0.cube1.hbm_ctrl.slice0")
|
||||
assert len(path) >= 1 # succeeds without exception
|
||||
@@ -0,0 +1,282 @@
|
||||
import pytest
|
||||
|
||||
from kernbench.policy.address.allocator import AddressConfig, AllocationError, PEMemAllocator
|
||||
from kernbench.policy.placement.dp import (
|
||||
ShardSpec,
|
||||
column_wise,
|
||||
tiled_column_major,
|
||||
replicate,
|
||||
row_wise,
|
||||
tiled_row_major,
|
||||
)
|
||||
from kernbench.runtime_api.kernel import (
|
||||
KernelLaunchMsg,
|
||||
KernelRef,
|
||||
MemoryReadMsg,
|
||||
MemoryWriteMsg,
|
||||
ScalarArg,
|
||||
TensorArg,
|
||||
TensorArgShard,
|
||||
)
|
||||
from kernbench.runtime_api.tensor import (
|
||||
TensorHandle,
|
||||
TensorShard,
|
||||
deploy_tensor,
|
||||
dtype_itemsize,
|
||||
)
|
||||
|
||||
_MB = 1 << 20
|
||||
_GB = 1 << 30
|
||||
|
||||
_CFG = AddressConfig(
|
||||
sip_count=2,
|
||||
cubes_per_sip=16,
|
||||
pes_per_cube=8,
|
||||
hbm_bytes_per_cube=48 * _GB,
|
||||
hbm_slices_per_cube=8,
|
||||
tcm_bytes_per_pe=16 * _MB,
|
||||
tcm_scheduler_reserved_bytes=4 * _MB,
|
||||
sram_bytes_per_cube=32 * _MB,
|
||||
)
|
||||
|
||||
|
||||
def _make_allocators(num_pe: int = 8) -> dict[int, PEMemAllocator]:
|
||||
return {
|
||||
i: PEMemAllocator(rack_id=0, sip_id=0, cube_id=0, pe_id=i, cfg=_CFG)
|
||||
for i in range(num_pe)
|
||||
}
|
||||
|
||||
|
||||
# ── Tensor types ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tensor_shard_immutable():
|
||||
ts = TensorShard(sip=0, cube=0, pe=0, pa=0x1000, nbytes=4096, offset_bytes=0)
|
||||
with pytest.raises(AttributeError):
|
||||
ts.pa = 0x2000 # type: ignore[misc]
|
||||
# hashable
|
||||
{ts}
|
||||
|
||||
|
||||
def test_tensor_handle_nbytes():
|
||||
th = TensorHandle(
|
||||
name="A",
|
||||
shape=(1024, 512),
|
||||
dtype="fp16",
|
||||
itemsize=2,
|
||||
shards=(),
|
||||
)
|
||||
assert th.nbytes == 1024 * 512 * 2 # 1 MB
|
||||
|
||||
|
||||
# ── Message types (ADR-0012) ─────────────────────────────────────────
|
||||
|
||||
|
||||
def test_memory_write_msg_fields():
|
||||
msg = MemoryWriteMsg(
|
||||
correlation_id="c0",
|
||||
request_id="r0",
|
||||
dst_sip=0,
|
||||
dst_cube=3,
|
||||
dst_pe=5,
|
||||
dst_pa=0xDEAD,
|
||||
nbytes=4096,
|
||||
pattern="zero",
|
||||
)
|
||||
assert msg.msg_type == "memory_write"
|
||||
assert msg.src_kind == "pattern"
|
||||
assert msg.dst_pa == 0xDEAD
|
||||
assert msg.pattern == "zero"
|
||||
with pytest.raises(AttributeError):
|
||||
msg.nbytes = 0 # type: ignore[misc]
|
||||
|
||||
|
||||
def test_memory_read_msg_fields():
|
||||
msg = MemoryReadMsg(
|
||||
correlation_id="c0",
|
||||
request_id="r1",
|
||||
src_sip=1,
|
||||
src_cube=2,
|
||||
src_pe=7,
|
||||
src_pa=0xBEEF,
|
||||
nbytes=2048,
|
||||
)
|
||||
assert msg.msg_type == "memory_read"
|
||||
assert msg.src_pa == 0xBEEF
|
||||
assert msg.nbytes == 2048
|
||||
|
||||
|
||||
def test_kernel_launch_msg_fields():
|
||||
shard = TensorArgShard(sip=0, cube=0, pe=0, pa=0x100, nbytes=1024, offset_bytes=0)
|
||||
targ = TensorArg(shards=(shard,))
|
||||
sarg = ScalarArg(dtype="fp32", value=1.0)
|
||||
kref = KernelRef(name="gemm", kind="builtin")
|
||||
msg = KernelLaunchMsg(
|
||||
correlation_id="c0",
|
||||
request_id="r2",
|
||||
kernel_ref=kref,
|
||||
args=(targ, sarg),
|
||||
)
|
||||
assert msg.msg_type == "kernel_launch"
|
||||
assert msg.kernel_ref.name == "gemm"
|
||||
assert len(msg.args) == 2
|
||||
assert msg.args[0].arg_kind == "tensor"
|
||||
assert msg.args[1].arg_kind == "scalar"
|
||||
|
||||
|
||||
# ── Placement: column_wise ───────────────────────────────────────────
|
||||
|
||||
|
||||
def test_column_wise_placement():
|
||||
"""(1024, 512) fp16 across 8 PEs → K axis split → 8 shards, each (1024, 64) = 128KB"""
|
||||
shards = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
assert len(shards) == 8
|
||||
expected_nbytes = 1024 * 64 * 2 # 128 KB
|
||||
for i, s in enumerate(shards):
|
||||
assert s.pe_index == i
|
||||
assert s.nbytes == expected_nbytes
|
||||
# offsets are contiguous
|
||||
assert shards[0].offset_bytes == 0
|
||||
assert shards[1].offset_bytes == expected_nbytes
|
||||
# total coverage
|
||||
assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
|
||||
|
||||
|
||||
# ── Placement: row_wise ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_row_wise_placement():
|
||||
"""(1024, 512) fp16 across 8 PEs → M axis split → 8 shards, each (128, 512) = 128KB"""
|
||||
shards = row_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
assert len(shards) == 8
|
||||
expected_nbytes = 128 * 512 * 2 # 128 KB
|
||||
for i, s in enumerate(shards):
|
||||
assert s.pe_index == i
|
||||
assert s.nbytes == expected_nbytes
|
||||
assert shards[0].offset_bytes == 0
|
||||
assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
|
||||
|
||||
|
||||
# ── Placement: replicate ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_replicate_placement():
|
||||
"""(1024, 512) fp16 across 8 PEs → each PE gets full copy = 1MB"""
|
||||
shards = replicate(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
assert len(shards) == 8
|
||||
full_nbytes = 1024 * 512 * 2 # 1 MB
|
||||
for i, s in enumerate(shards):
|
||||
assert s.pe_index == i
|
||||
assert s.nbytes == full_nbytes
|
||||
assert s.offset_bytes == 0 # each is a full copy
|
||||
|
||||
|
||||
# ── Placement: tiled_column_major ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_tiled_column_major():
|
||||
"""(1024, 512) tile=(256, 128) → 4×4=16 tiles, column-major → round-robin 8 PEs"""
|
||||
shards = tiled_column_major(
|
||||
shape=(1024, 512), itemsize=2, num_pe=8, tile_m=256, tile_k=128,
|
||||
)
|
||||
# 4 tiles along M, 4 tiles along K → 16 tiles total
|
||||
assert len(shards) == 16
|
||||
tile_bytes = 256 * 128 * 2 # 64 KB per tile
|
||||
for s in shards:
|
||||
assert s.nbytes == tile_bytes
|
||||
# column-major: iterate K first, then M
|
||||
# tile (m=0,k=0) → PE0, tile (m=0,k=1) → PE1, ..., (m=0,k=3) → PE3
|
||||
# tile (m=1,k=0) → PE4, tile (m=1,k=1) → PE5, ..., (m=1,k=3) → PE7
|
||||
# tile (m=2,k=0) → PE0, ...
|
||||
assert shards[0].pe_index == 0
|
||||
assert shards[1].pe_index == 1
|
||||
assert shards[7].pe_index == 7
|
||||
assert shards[8].pe_index == 0 # wraps around
|
||||
# total coverage
|
||||
assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
|
||||
|
||||
|
||||
# ── Placement: tiled_row_major ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tiled_row_major():
|
||||
"""(1024, 512) tile=(256, 128) → 4×4=16 tiles, row-major → round-robin 8 PEs"""
|
||||
shards = tiled_row_major(
|
||||
shape=(1024, 512), itemsize=2, num_pe=8, tile_m=256, tile_k=128,
|
||||
)
|
||||
assert len(shards) == 16
|
||||
tile_bytes = 256 * 128 * 2
|
||||
for s in shards:
|
||||
assert s.nbytes == tile_bytes
|
||||
# row-major: iterate M first, then K
|
||||
# tile (m=0,k=0) → PE0, tile (m=1,k=0) → PE1, ..., (m=3,k=0) → PE3
|
||||
# tile (m=0,k=1) → PE4, tile (m=1,k=1) → PE5, ..., (m=3,k=1) → PE7
|
||||
# tile (m=0,k=2) → PE0, ...
|
||||
assert shards[0].pe_index == 0
|
||||
assert shards[1].pe_index == 1
|
||||
assert shards[7].pe_index == 7
|
||||
assert shards[8].pe_index == 0 # wraps around
|
||||
# total coverage
|
||||
assert sum(s.nbytes for s in shards) == 1024 * 512 * 2
|
||||
|
||||
|
||||
# ── deploy_tensor ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_deploy_tensor_hbm():
|
||||
"""Deploy with column_wise placement → TensorHandle with valid PA shards."""
|
||||
allocs = _make_allocators()
|
||||
placement = column_wise(shape=(1024, 512), itemsize=2, num_pe=8)
|
||||
th = deploy_tensor(
|
||||
name="W",
|
||||
shape=(1024, 512),
|
||||
dtype="fp16",
|
||||
placement=placement,
|
||||
allocators=allocs,
|
||||
mem_kind="hbm",
|
||||
)
|
||||
assert th.name == "W"
|
||||
assert th.shape == (1024, 512)
|
||||
assert th.dtype == "fp16"
|
||||
assert th.itemsize == 2
|
||||
assert len(th.shards) == 8
|
||||
# each shard has a distinct PA
|
||||
pas = [s.pa for s in th.shards]
|
||||
assert len(set(pas)) == 8
|
||||
# each shard placed on correct PE
|
||||
for i, s in enumerate(th.shards):
|
||||
assert s.pe == i
|
||||
assert s.sip == 0
|
||||
assert s.cube == 0
|
||||
|
||||
|
||||
def test_deploy_tensor_tcm():
|
||||
"""Deploy with TCM → uses pe_tcm_addr allocation."""
|
||||
allocs = _make_allocators()
|
||||
placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=256)]
|
||||
th = deploy_tensor(
|
||||
name="small",
|
||||
shape=(128,),
|
||||
dtype="fp16",
|
||||
placement=placement,
|
||||
allocators=allocs,
|
||||
mem_kind="tcm",
|
||||
)
|
||||
assert len(th.shards) == 1
|
||||
assert th.shards[0].pe == 0
|
||||
assert th.shards[0].nbytes == 256
|
||||
|
||||
|
||||
def test_deploy_tensor_overflow():
|
||||
"""Allocation exceeding PE HBM capacity raises AllocationError."""
|
||||
allocs = _make_allocators()
|
||||
# 6 GB per PE slice, try to allocate 7 GB
|
||||
big_shard = ShardSpec(pe_index=0, offset_bytes=0, nbytes=7 * _GB)
|
||||
with pytest.raises(AllocationError):
|
||||
deploy_tensor(
|
||||
name="toobig",
|
||||
shape=(1,),
|
||||
dtype="int8",
|
||||
placement=[big_shard],
|
||||
allocators=allocs,
|
||||
)
|
||||
@@ -0,0 +1,409 @@
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.topology.builder import load_topology
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
|
||||
def _graph():
|
||||
return load_topology(TOPOLOGY_PATH)
|
||||
|
||||
|
||||
# ── Full graph: node counts ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_full_graph_node_count():
|
||||
g = _graph()
|
||||
# 1 switch
|
||||
# + 2 SIPs × (1 IO × 2 comps + 16 cubes × (cube_comps + 8 PEs × 6 pe_comps))
|
||||
# cube_comps: 9 (noc, m_cpu, sram, 2 bridge, 4 ucie)
|
||||
# + 8 xbar.pe{0..7} [replaced xbar.top/xbar.bottom]
|
||||
# + 8 hbm_slices = 25
|
||||
# = 1 + 2*(2 + 16*(25+48)) = 1 + 2*(2+1168) = 1 + 2340 = 2341
|
||||
assert len(g.nodes) == 2341
|
||||
|
||||
|
||||
def test_full_graph_edge_count():
|
||||
g = _graph()
|
||||
# Per cube: 144 (88 cube-fabric + 56 PE-internal)
|
||||
# cube-fabric: 8 pe→xbar.pe + 8 pe→noc + 8 noc→pe_cpu
|
||||
# + 8 xbar.pe→slice + 8 slice→xbar.pe (bidirectional for response)
|
||||
# + 12 xbar chain (3 pairs × 2 dir × 2 halves)
|
||||
# + 8 xbar.pe↔bridge (pe0↔bL, pe4↔bL, pe3↔bR, pe7↔bR, ×2 dir each)
|
||||
# + 4 noc→ucie + 4 ucie→noc (bidirectional)
|
||||
# + 8 noc→xbar.pe + 8 xbar.pe→noc (bidirectional for response)
|
||||
# + 1 m_cpu→noc + 1 noc→m_cpu + 1 noc→sram + 1 sram→noc = 88
|
||||
# Per SIP: 16*144 + 48 inter-cube(bidirectional) + 8 io↔cube(bidirectional)
|
||||
# + 1 io_internal + 1 switch→io = 2362
|
||||
# Total: 2 * 2362 = 4724
|
||||
assert len(g.edges) == 4724
|
||||
|
||||
|
||||
# ── Full graph: specific nodes exist ─────────────────────────────────
|
||||
|
||||
|
||||
def test_system_switch_exists():
|
||||
g = _graph()
|
||||
assert "fabric.switch0" in g.nodes
|
||||
assert g.nodes["fabric.switch0"].kind == "switch"
|
||||
assert g.nodes["fabric.switch0"].pos_mm is None # abstract
|
||||
|
||||
|
||||
def test_io_chiplet_nodes_exist():
|
||||
g = _graph()
|
||||
for s in range(2):
|
||||
assert f"sip{s}.io0.pcie_ep" in g.nodes
|
||||
assert f"sip{s}.io0.io_cpu" in g.nodes
|
||||
|
||||
|
||||
def test_cube_component_nodes_exist():
|
||||
g = _graph()
|
||||
cp = "sip0.cube0"
|
||||
for name in ("noc", "m_cpu",
|
||||
"bridge.left", "bridge.right",
|
||||
"ucie-N", "ucie-S", "ucie-E", "ucie-W",
|
||||
"sram"):
|
||||
assert f"{cp}.{name}" in g.nodes
|
||||
# xbar.top/xbar.bottom replaced by per-PE xbar entry nodes
|
||||
assert "sip0.cube0.xbar.top" not in g.nodes
|
||||
assert "sip0.cube0.xbar.bottom" not in g.nodes
|
||||
for pe in range(8):
|
||||
node_id = f"{cp}.xbar.pe{pe}"
|
||||
assert node_id in g.nodes, f"{node_id} missing"
|
||||
assert g.nodes[node_id].kind == "xbar"
|
||||
# HBM slices (one per PE)
|
||||
for s in range(8):
|
||||
assert f"{cp}.hbm_ctrl.slice{s}" in g.nodes
|
||||
assert g.nodes[f"{cp}.hbm_ctrl.slice{s}"].kind == "hbm_ctrl"
|
||||
|
||||
|
||||
def test_pe_component_nodes_exist():
|
||||
g = _graph()
|
||||
for comp in ("pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm"):
|
||||
assert f"sip0.cube0.pe0.{comp}" in g.nodes
|
||||
assert f"sip1.cube15.pe7.{comp}" in g.nodes
|
||||
|
||||
|
||||
# ── Full graph: positions ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_hbm_ctrl_slices_at_cube_center():
|
||||
g = _graph()
|
||||
# cube0 origin = (0, 0), cx=8.5, cy=7.0, hbm_ctrl at (cx-2, cy)
|
||||
# all slices share the same physical position
|
||||
for s in range(8):
|
||||
node = g.nodes[f"sip0.cube0.hbm_ctrl.slice{s}"]
|
||||
assert node.pos_mm == (6.5, 7.0)
|
||||
|
||||
|
||||
def test_hbm_ctrl_slices_cube5_position():
|
||||
g = _graph()
|
||||
# cube5 = col=1, row=1 -> origin = (1*18, 1*15) = (18, 15)
|
||||
# hbm_ctrl = (18 + 6.5, 15 + 7.0) = (24.5, 22.0)
|
||||
node = g.nodes["sip0.cube5.hbm_ctrl.slice0"]
|
||||
assert node.pos_mm == (24.5, 22.0)
|
||||
|
||||
|
||||
def test_ucie_ports_at_cube_edges():
|
||||
g = _graph()
|
||||
# cube0 origin = (0, 0), cube_w=17, cube_h=14
|
||||
# UCIe nodes inset by half-size so edges touch boundary
|
||||
assert g.nodes["sip0.cube0.ucie-N"].pos_mm == (8.5, 0.6)
|
||||
assert g.nodes["sip0.cube0.ucie-S"].pos_mm == (8.5, 13.4)
|
||||
assert g.nodes["sip0.cube0.ucie-W"].pos_mm == (1.0, 7.0)
|
||||
assert g.nodes["sip0.cube0.ucie-E"].pos_mm == (16.0, 7.0)
|
||||
|
||||
|
||||
# ── Full graph: edges ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _edge_set(g):
|
||||
return {(e.src, e.dst) for e in g.edges}
|
||||
|
||||
|
||||
def test_inter_cube_ucie_edges():
|
||||
es = _edge_set(_graph())
|
||||
# cube0 (0,0) E → cube1 (1,0) W
|
||||
assert ("sip0.cube0.ucie-E", "sip0.cube1.ucie-W") in es
|
||||
# cube0 (0,0) S → cube4 (0,1) N
|
||||
assert ("sip0.cube0.ucie-S", "sip0.cube4.ucie-N") in es
|
||||
|
||||
|
||||
def test_io_to_cube_edges():
|
||||
es = _edge_set(_graph())
|
||||
# io0 connects to cubes (0,0)..(3,0) on N side
|
||||
assert ("sip0.io0.io_cpu", "sip0.cube0.ucie-N") in es
|
||||
assert ("sip0.io0.io_cpu", "sip0.cube3.ucie-N") in es
|
||||
|
||||
|
||||
def test_switch_to_io_edges():
|
||||
es = _edge_set(_graph())
|
||||
assert ("fabric.switch0", "sip0.io0.pcie_ep") in es
|
||||
assert ("fabric.switch0", "sip1.io0.pcie_ep") in es
|
||||
|
||||
|
||||
def test_pe_to_xbar_edges():
|
||||
es = _edge_set(_graph())
|
||||
cp = "sip0.cube0"
|
||||
# Each PE connects to its own xbar entry (per-PE chain model)
|
||||
for pe in range(8):
|
||||
assert (f"{cp}.pe{pe}.pe_dma", f"{cp}.xbar.pe{pe}") in es
|
||||
# Old shared xbar.top/bottom edges must NOT exist
|
||||
assert (f"{cp}.pe0.pe_dma", f"{cp}.xbar.top") not in es
|
||||
assert (f"{cp}.pe4.pe_dma", f"{cp}.xbar.bottom") not in es
|
||||
|
||||
|
||||
def test_command_path_m_cpu_noc_pe_cpu():
|
||||
es = _edge_set(_graph())
|
||||
cp = "sip0.cube0"
|
||||
# m_cpu ↔ noc (bidirectional)
|
||||
assert (f"{cp}.m_cpu", f"{cp}.noc") in es
|
||||
assert (f"{cp}.noc", f"{cp}.m_cpu") in es
|
||||
# noc → pe_cpu for each PE
|
||||
assert (f"{cp}.noc", f"{cp}.pe0.pe_cpu") in es
|
||||
assert (f"{cp}.noc", f"{cp}.pe7.pe_cpu") in es
|
||||
|
||||
|
||||
def test_pe_internal_edges():
|
||||
es = _edge_set(_graph())
|
||||
pp = "sip0.cube0.pe0"
|
||||
assert (f"{pp}.pe_cpu", f"{pp}.pe_scheduler") in es
|
||||
assert (f"{pp}.pe_scheduler", f"{pp}.pe_dma") in es
|
||||
assert (f"{pp}.pe_scheduler", f"{pp}.pe_gemm") in es
|
||||
assert (f"{pp}.pe_scheduler", f"{pp}.pe_math") in es
|
||||
assert (f"{pp}.pe_dma", f"{pp}.pe_tcm") in es
|
||||
assert (f"{pp}.pe_gemm", f"{pp}.pe_tcm") in es
|
||||
assert (f"{pp}.pe_math", f"{pp}.pe_tcm") in es
|
||||
|
||||
|
||||
def test_xbar_to_hbm_slice_edges():
|
||||
"""Each xbar.pe{i} connects only to its own (local) HBM slice."""
|
||||
es = _edge_set(_graph())
|
||||
cp = "sip0.cube0"
|
||||
# xbar.pe_i -> slice_i only (local Y-direction access)
|
||||
for pe in range(8):
|
||||
assert (f"{cp}.xbar.pe{pe}", f"{cp}.hbm_ctrl.slice{pe}") in es
|
||||
# Negative: xbar.pe_i must NOT directly connect to a different slice
|
||||
assert (f"{cp}.xbar.pe0", f"{cp}.hbm_ctrl.slice1") not in es
|
||||
assert (f"{cp}.xbar.pe0", f"{cp}.hbm_ctrl.slice4") not in es
|
||||
assert (f"{cp}.xbar.pe4", f"{cp}.hbm_ctrl.slice0") not in es
|
||||
|
||||
|
||||
# ── Views: system ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_system_view_nodes():
|
||||
v = _graph().system_view
|
||||
assert "fabric.switch0" in v.nodes
|
||||
assert "sip0" in v.nodes
|
||||
assert "sip1" in v.nodes
|
||||
assert "sip0.io0" in v.nodes
|
||||
assert "sip1.io0" in v.nodes
|
||||
|
||||
|
||||
# ── Views: SIP ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_sip_view_cube_count():
|
||||
v = _graph().sip_view
|
||||
cube_nodes = [n for n in v.nodes if n.startswith("cube")]
|
||||
assert len(cube_nodes) == 16
|
||||
|
||||
|
||||
def test_sip_view_io_chiplets():
|
||||
v = _graph().sip_view
|
||||
assert "io0" in v.nodes
|
||||
|
||||
|
||||
def test_sip_view_cube_positions():
|
||||
v = _graph().sip_view
|
||||
# cube0 (0,0): center = (8.5, 6+7.0) = (8.5, 13.0) [io_margin=6]
|
||||
x, y = v.nodes["cube0"].pos_mm
|
||||
assert x == 8.5
|
||||
assert y == 13.0
|
||||
# cube1 (1,0): center = (18+8.5, 13.0) = (26.5, 13.0)
|
||||
x1, y1 = v.nodes["cube1"].pos_mm
|
||||
assert x1 == 26.5
|
||||
assert y1 == 13.0
|
||||
|
||||
|
||||
# ── Views: cube ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_cube_view_has_all_components():
|
||||
v = _graph().cube_view
|
||||
expected = {"ucie-N", "ucie-S", "ucie-W", "ucie-E",
|
||||
"m_cpu", "hbm_ctrl",
|
||||
"bridge.left", "bridge.right", "noc", "sram",
|
||||
"xbar.pe0", "xbar.pe1", "xbar.pe2", "xbar.pe3",
|
||||
"xbar.pe4", "xbar.pe5", "xbar.pe6", "xbar.pe7",
|
||||
"pe0", "pe1", "pe2", "pe3", "pe4", "pe5", "pe6", "pe7"}
|
||||
assert set(v.nodes.keys()) == expected
|
||||
|
||||
|
||||
def test_cube_view_hbm_at_center():
|
||||
v = _graph().cube_view
|
||||
assert v.nodes["hbm_ctrl"].pos_mm == (6.5, 7.0)
|
||||
assert v.nodes["noc"].pos_mm == (10.5, 7.0)
|
||||
assert v.width_mm == 17.0
|
||||
assert v.height_mm == 14.0
|
||||
|
||||
|
||||
def test_cube_view_pe_corner_mapping():
|
||||
v = _graph().cube_view
|
||||
ves = {(e.src, e.dst) for e in v.edges}
|
||||
# Each PE connects to its own xbar entry (chain model)
|
||||
for i in range(8):
|
||||
assert (f"pe{i}", f"xbar.pe{i}") in ves
|
||||
# Old shared xbar.top/bottom mapping must not exist
|
||||
assert ("pe0", "xbar.top") not in ves
|
||||
assert ("pe4", "xbar.bottom") not in ves
|
||||
|
||||
|
||||
# ── Views: PE ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_pe_view_has_all_components():
|
||||
v = _graph().pe_view
|
||||
assert set(v.nodes.keys()) == {
|
||||
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm"
|
||||
}
|
||||
|
||||
|
||||
def test_pe_view_edges():
|
||||
v = _graph().pe_view
|
||||
ves = {(e.src, e.dst) for e in v.edges}
|
||||
assert ("pe_cpu", "pe_scheduler") in ves
|
||||
assert ("pe_scheduler", "pe_dma") in ves
|
||||
assert ("pe_scheduler", "pe_gemm") in ves
|
||||
assert ("pe_scheduler", "pe_math") in ves
|
||||
assert ("pe_dma", "pe_tcm") in ves
|
||||
assert ("pe_gemm", "pe_tcm") in ves
|
||||
assert ("pe_math", "pe_tcm") in ves
|
||||
|
||||
|
||||
# ── SRAM ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_sram_node_exists():
|
||||
g = _graph()
|
||||
assert "sip0.cube0.sram" in g.nodes
|
||||
assert g.nodes["sip0.cube0.sram"].kind == "sram"
|
||||
|
||||
|
||||
def test_noc_to_sram_edges():
|
||||
es = _edge_set(_graph())
|
||||
cp = "sip0.cube0"
|
||||
assert (f"{cp}.noc", f"{cp}.sram") in es
|
||||
assert (f"{cp}.sram", f"{cp}.noc") in es
|
||||
|
||||
|
||||
# ── PE_DMA → NOC (non-HBM data path) ───────────────────────────────
|
||||
|
||||
|
||||
def test_pe_dma_to_noc_edges():
|
||||
es = _edge_set(_graph())
|
||||
cp = "sip0.cube0"
|
||||
for i in range(8):
|
||||
assert (f"{cp}.pe{i}.pe_dma", f"{cp}.noc") in es
|
||||
|
||||
|
||||
# ── Bridge connects XBAR halves (not NOC) ──────────────────────────
|
||||
|
||||
|
||||
def test_bridge_connects_xbar_halves():
|
||||
"""bridge.left connects leftmost PE nodes (pe0 top, pe4 bottom).
|
||||
bridge.right connects rightmost PE nodes (pe3 top, pe7 bottom)."""
|
||||
es = _edge_set(_graph())
|
||||
cp = "sip0.cube0"
|
||||
# bridge.left ↔ pe0 (top-left) and pe4 (bottom-left)
|
||||
assert (f"{cp}.xbar.pe0", f"{cp}.bridge.left") in es
|
||||
assert (f"{cp}.bridge.left", f"{cp}.xbar.pe0") in es
|
||||
assert (f"{cp}.xbar.pe4", f"{cp}.bridge.left") in es
|
||||
assert (f"{cp}.bridge.left", f"{cp}.xbar.pe4") in es
|
||||
# bridge.right ↔ pe3 (top-right) and pe7 (bottom-right)
|
||||
assert (f"{cp}.xbar.pe3", f"{cp}.bridge.right") in es
|
||||
assert (f"{cp}.bridge.right", f"{cp}.xbar.pe3") in es
|
||||
assert (f"{cp}.xbar.pe7", f"{cp}.bridge.right") in es
|
||||
assert (f"{cp}.bridge.right", f"{cp}.xbar.pe7") in es
|
||||
# Old xbar.top/bottom ↔ bridge edges must NOT exist
|
||||
assert (f"{cp}.xbar.top", f"{cp}.bridge.left") not in es
|
||||
assert (f"{cp}.xbar.bottom", f"{cp}.bridge.left") not in es
|
||||
|
||||
|
||||
def test_no_bridge_to_noc_edges():
|
||||
es = _edge_set(_graph())
|
||||
cp = "sip0.cube0"
|
||||
assert (f"{cp}.bridge.left", f"{cp}.noc") not in es
|
||||
assert (f"{cp}.bridge.right", f"{cp}.noc") not in es
|
||||
|
||||
|
||||
# ── Cube view: new edges ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_cube_view_pe_to_noc():
|
||||
v = _graph().cube_view
|
||||
ves = {(e.src, e.dst) for e in v.edges}
|
||||
for i in range(8):
|
||||
assert (f"pe{i}", "noc") in ves
|
||||
|
||||
|
||||
def test_cube_view_sram():
|
||||
v = _graph().cube_view
|
||||
assert "sram" in v.nodes
|
||||
ves = {(e.src, e.dst) for e in v.edges}
|
||||
assert ("noc", "sram") in ves
|
||||
assert ("sram", "noc") in ves
|
||||
|
||||
|
||||
def test_cube_view_bridge_xbar():
|
||||
v = _graph().cube_view
|
||||
ves = {(e.src, e.dst) for e in v.edges}
|
||||
# bridge.left connects pe0 (top-left) ↔ pe4 (bottom-left)
|
||||
assert ("xbar.pe0", "bridge.left") in ves
|
||||
assert ("bridge.left", "xbar.pe0") in ves
|
||||
assert ("xbar.pe4", "bridge.left") in ves
|
||||
assert ("bridge.left", "xbar.pe4") in ves
|
||||
# bridge.right connects pe3 (top-right) ↔ pe7 (bottom-right)
|
||||
assert ("xbar.pe3", "bridge.right") in ves
|
||||
assert ("bridge.right", "xbar.pe3") in ves
|
||||
assert ("xbar.pe7", "bridge.right") in ves
|
||||
assert ("bridge.right", "xbar.pe7") in ves
|
||||
|
||||
|
||||
# ── Chain xbar: new topology edges ──────────────────────────────────
|
||||
|
||||
|
||||
def test_xbar_chain_edges():
|
||||
"""Adjacent xbar.pe nodes within each half are bidirectionally connected."""
|
||||
es = _edge_set(_graph())
|
||||
cp = "sip0.cube0"
|
||||
# Top chain: pe0 ↔ pe1 ↔ pe2 ↔ pe3 (NW→NE direction)
|
||||
for a, b in [(0, 1), (1, 2), (2, 3)]:
|
||||
assert (f"{cp}.xbar.pe{a}", f"{cp}.xbar.pe{b}") in es, f"missing pe{a}→pe{b}"
|
||||
assert (f"{cp}.xbar.pe{b}", f"{cp}.xbar.pe{a}") in es, f"missing pe{b}→pe{a}"
|
||||
# Bottom chain: pe4 ↔ pe5 ↔ pe6 ↔ pe7
|
||||
for a, b in [(4, 5), (5, 6), (6, 7)]:
|
||||
assert (f"{cp}.xbar.pe{a}", f"{cp}.xbar.pe{b}") in es, f"missing pe{a}→pe{b}"
|
||||
assert (f"{cp}.xbar.pe{b}", f"{cp}.xbar.pe{a}") in es, f"missing pe{b}→pe{a}"
|
||||
# Negative: no cross-chain direct edges
|
||||
assert (f"{cp}.xbar.pe0", f"{cp}.xbar.pe2") not in es
|
||||
assert (f"{cp}.xbar.pe0", f"{cp}.xbar.pe4") not in es
|
||||
|
||||
|
||||
def test_ucie_noc_reverse_edges():
|
||||
"""UCIe ports must have reverse edges back to NOC (bidirectional)."""
|
||||
es = _edge_set(_graph())
|
||||
cp = "sip0.cube1" # non-edge cube to avoid io-cube edges
|
||||
for port in ("N", "S", "E", "W"):
|
||||
assert (f"{cp}.ucie-{port}", f"{cp}.noc") in es, \
|
||||
f"missing ucie-{port}->noc reverse edge"
|
||||
|
||||
|
||||
def test_noc_to_xbar_pe_edges():
|
||||
"""NOC connects to all xbar.pe nodes (for remote cube HBM access)."""
|
||||
es = _edge_set(_graph())
|
||||
cp = "sip0.cube0"
|
||||
for pe in range(8):
|
||||
assert (f"{cp}.noc", f"{cp}.xbar.pe{pe}") in es, \
|
||||
f"missing noc->xbar.pe{pe}"
|
||||
@@ -0,0 +1,60 @@
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.topology.builder import _read_spec, resolve_topology
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
|
||||
def test_topology_yaml_loads_without_error():
|
||||
# _compile_graph is still stubbed (returns None); load must not raise
|
||||
resolve_topology(str(TOPOLOGY_PATH))
|
||||
|
||||
|
||||
def test_pe_layout_structure():
|
||||
spec = _read_spec(TOPOLOGY_PATH)
|
||||
pe_layout = spec["cube"]["pe_layout"]
|
||||
assert set(pe_layout["corners"]) == {"NW", "NE", "SW", "SE"}
|
||||
assert pe_layout["pe_per_corner"] == 2
|
||||
# derived total must equal original pe_per_cube: 8
|
||||
assert pe_layout["pe_per_corner"] * len(pe_layout["corners"]) == 8
|
||||
|
||||
|
||||
def test_pe_template_components():
|
||||
spec = _read_spec(TOPOLOGY_PATH)
|
||||
comps = spec["cube"]["pe_template"]["components"]
|
||||
assert set(comps.keys()) == {
|
||||
"pe_cpu", "pe_scheduler", "pe_dma", "pe_gemm", "pe_math", "pe_tcm"
|
||||
}
|
||||
|
||||
|
||||
def test_pe_template_links_present():
|
||||
spec = _read_spec(TOPOLOGY_PATH)
|
||||
links = spec["cube"]["pe_template"]["links"]
|
||||
required = {
|
||||
"pe_cpu_to_scheduler_mm",
|
||||
"scheduler_to_dma_mm",
|
||||
"scheduler_to_gemm_mm",
|
||||
"scheduler_to_math_mm",
|
||||
"dma_to_tcm_bw_gbs", "dma_to_tcm_mm",
|
||||
"gemm_to_tcm_bw_gbs", "gemm_to_tcm_mm",
|
||||
"math_to_tcm_bw_gbs", "math_to_tcm_mm",
|
||||
}
|
||||
assert required.issubset(set(links.keys()))
|
||||
|
||||
|
||||
def test_pe_dma_not_in_cube_components():
|
||||
spec = _read_spec(TOPOLOGY_PATH)
|
||||
assert "pe_dma" not in spec["cube"]["components"]
|
||||
|
||||
|
||||
def test_pe_per_cube_removed():
|
||||
spec = _read_spec(TOPOLOGY_PATH)
|
||||
assert "pe_per_cube" not in spec["cube"].get("device", {})
|
||||
|
||||
|
||||
def test_shared_resource_accel_slot():
|
||||
# ADR-0014 D4: PE_GEMM and PE_MATH share PE_ACCEL capacity = 1
|
||||
spec = _read_spec(TOPOLOGY_PATH)
|
||||
comps = spec["cube"]["pe_template"]["components"]
|
||||
assert comps["pe_gemm"]["attrs"]["shared_resource"] == "accel_slot"
|
||||
assert comps["pe_math"]["attrs"]["shared_resource"] == "accel_slot"
|
||||
@@ -0,0 +1,81 @@
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.topology.builder import load_topology
|
||||
from kernbench.topology.visualizer import emit_diagrams
|
||||
|
||||
TOPOLOGY_PATH = Path(__file__).parent.parent / "topology.yaml"
|
||||
|
||||
VIEW_FILES = ["system_view.svg", "sip_view.svg", "cube_view.svg", "pe_view.svg"]
|
||||
|
||||
|
||||
def _emit(tmp_path: Path) -> list[Path]:
|
||||
graph = load_topology(TOPOLOGY_PATH)
|
||||
return emit_diagrams(graph, tmp_path)
|
||||
|
||||
|
||||
def test_emit_creates_all_svg_files(tmp_path):
|
||||
created = _emit(tmp_path)
|
||||
assert len(created) == 4
|
||||
for name in VIEW_FILES:
|
||||
assert (tmp_path / name).exists()
|
||||
assert (tmp_path / name).stat().st_size > 0
|
||||
|
||||
|
||||
def test_svg_output_is_deterministic(tmp_path):
|
||||
graph = load_topology(TOPOLOGY_PATH)
|
||||
emit_diagrams(graph, tmp_path)
|
||||
first = {name: (tmp_path / name).read_text() for name in VIEW_FILES}
|
||||
emit_diagrams(graph, tmp_path)
|
||||
second = {name: (tmp_path / name).read_text() for name in VIEW_FILES}
|
||||
for name in VIEW_FILES:
|
||||
assert first[name] == second[name], f"{name} is not deterministic"
|
||||
|
||||
|
||||
def test_cube_svg_contains_hbm_ctrl(tmp_path):
|
||||
_emit(tmp_path)
|
||||
svg = (tmp_path / "cube_view.svg").read_text()
|
||||
assert "HBM CTRL" in svg
|
||||
|
||||
|
||||
def test_cube_svg_contains_ucie_ports(tmp_path):
|
||||
_emit(tmp_path)
|
||||
svg = (tmp_path / "cube_view.svg").read_text()
|
||||
for port in ("UCIe-N", "UCIe-S", "UCIe-W", "UCIe-E"):
|
||||
assert port in svg
|
||||
|
||||
|
||||
def test_cube_svg_contains_pe_nodes(tmp_path):
|
||||
_emit(tmp_path)
|
||||
svg = (tmp_path / "cube_view.svg").read_text()
|
||||
for i in range(8):
|
||||
assert f"PE{i}" in svg
|
||||
|
||||
|
||||
def test_pe_svg_contains_all_components(tmp_path):
|
||||
_emit(tmp_path)
|
||||
svg = (tmp_path / "pe_view.svg").read_text()
|
||||
for comp in ("PE CPU", "PE SCHEDULER", "PE DMA", "PE GEMM", "PE MATH", "PE TCM"):
|
||||
assert comp in svg
|
||||
|
||||
|
||||
def test_sip_svg_contains_cubes(tmp_path):
|
||||
_emit(tmp_path)
|
||||
svg = (tmp_path / "sip_view.svg").read_text()
|
||||
assert "CUBE (0,0)" in svg
|
||||
assert "CUBE (3,3)" in svg
|
||||
|
||||
|
||||
def test_system_svg_contains_switch_and_sips(tmp_path):
|
||||
_emit(tmp_path)
|
||||
svg = (tmp_path / "system_view.svg").read_text()
|
||||
assert "Fabric Switch" in svg
|
||||
assert "SIP 0" in svg
|
||||
assert "SIP 1" in svg
|
||||
|
||||
|
||||
def test_svg_is_valid_xml(tmp_path):
|
||||
_emit(tmp_path)
|
||||
for name in VIEW_FILES:
|
||||
svg = (tmp_path / name).read_text()
|
||||
assert svg.startswith("<svg")
|
||||
assert svg.strip().endswith("</svg>")
|
||||
@@ -0,0 +1,349 @@
|
||||
"""Tests for Triton emulator: TLContext, command generation, kernel registry."""
|
||||
from kernbench.common.pe_commands import (
|
||||
CompletionHandle,
|
||||
CompositeCmd,
|
||||
DmaReadCmd,
|
||||
DmaWriteCmd,
|
||||
GemmCmd,
|
||||
MathCmd,
|
||||
PeCpuOverheadCmd,
|
||||
TensorHandle,
|
||||
WaitCmd,
|
||||
)
|
||||
from kernbench.triton_emu.registry import clear_registry, get_kernel, register_kernel
|
||||
from kernbench.triton_emu.tl_context import TLContext, run_kernel
|
||||
|
||||
|
||||
def _ctx(**kwargs) -> TLContext:
|
||||
return TLContext(dispatch_cycles=0, **kwargs)
|
||||
|
||||
|
||||
def _ctx_with_overhead(**kwargs) -> TLContext:
|
||||
return TLContext(dispatch_cycles=1, **kwargs)
|
||||
|
||||
|
||||
# ── 1. tl.load → DmaReadCmd ──────────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_load_generates_dma_read():
|
||||
tl = _ctx()
|
||||
h = tl.load(0x1000, shape=(32, 64), dtype="f16")
|
||||
assert isinstance(h, TensorHandle)
|
||||
assert h.shape == (32, 64)
|
||||
assert h.nbytes == 32 * 64 * 2
|
||||
cmds = tl.commands
|
||||
assert len(cmds) == 1
|
||||
assert isinstance(cmds[0], DmaReadCmd)
|
||||
assert cmds[0].src_pa == 0x1000
|
||||
assert cmds[0].nbytes == 32 * 64 * 2
|
||||
|
||||
|
||||
# ── 2. tl.store → DmaWriteCmd ────────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_store_generates_dma_write():
|
||||
tl = _ctx()
|
||||
h = tl.load(0x1000, shape=(16, 16), dtype="f32")
|
||||
tl.store(0x2000, h)
|
||||
cmds = [c for c in tl.commands if isinstance(c, DmaWriteCmd)]
|
||||
assert len(cmds) == 1
|
||||
assert cmds[0].dst_pa == 0x2000
|
||||
assert cmds[0].nbytes == 16 * 16 * 4
|
||||
|
||||
|
||||
# ── 3. tl.dot → GemmCmd ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_dot_generates_gemm_cmd():
|
||||
tl = _ctx()
|
||||
a = tl.load(0x1000, shape=(32, 64), dtype="f16")
|
||||
b = tl.load(0x2000, shape=(64, 16), dtype="f16")
|
||||
out = tl.dot(a, b)
|
||||
assert out.shape == (32, 16)
|
||||
cmds = [c for c in tl.commands if isinstance(c, GemmCmd)]
|
||||
assert len(cmds) == 1
|
||||
assert cmds[0].m == 32
|
||||
assert cmds[0].k == 64
|
||||
assert cmds[0].n == 16
|
||||
|
||||
|
||||
# ── 4. tl.exp, tl.sqrt etc. → MathCmd ────────────────────────────
|
||||
|
||||
|
||||
def test_tl_math_unary_ops():
|
||||
tl = _ctx()
|
||||
x = tl.load(0x1000, shape=(8, 8), dtype="f16")
|
||||
for op_name, op_fn in [
|
||||
("exp", tl.exp), ("log", tl.log), ("sqrt", tl.sqrt),
|
||||
("abs", tl.abs), ("sigmoid", tl.sigmoid),
|
||||
("cos", tl.cos), ("sin", tl.sin),
|
||||
]:
|
||||
result = op_fn(x)
|
||||
assert isinstance(result, TensorHandle)
|
||||
assert result.shape == x.shape
|
||||
|
||||
math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
|
||||
ops = [c.op for c in math_cmds]
|
||||
assert ops == ["exp", "log", "sqrt", "abs", "sigmoid", "cos", "sin"]
|
||||
|
||||
|
||||
# ── 5. a + b, a * b → MathCmd ────────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_math_binary_ops():
|
||||
tl = _ctx()
|
||||
a = tl.load(0x1000, shape=(4, 4), dtype="f16")
|
||||
b = tl.load(0x2000, shape=(4, 4), dtype="f16")
|
||||
r1 = run_kernel(lambda tl: None, tl) # activate context for operators
|
||||
|
||||
# Need active context for operators
|
||||
tl2 = _ctx()
|
||||
a2 = tl2.load(0x1000, shape=(4, 4), dtype="f16")
|
||||
b2 = tl2.load(0x2000, shape=(4, 4), dtype="f16")
|
||||
|
||||
def kernel(tl):
|
||||
pass
|
||||
|
||||
# Use run_kernel to activate context, then test operators
|
||||
tl3 = _ctx()
|
||||
|
||||
def binary_kernel(tl):
|
||||
a = tl.load(0x1000, shape=(4, 4), dtype="f16")
|
||||
b = tl.load(0x2000, shape=(4, 4), dtype="f16")
|
||||
_ = a + b
|
||||
_ = a - b
|
||||
_ = a * b
|
||||
_ = a / b
|
||||
|
||||
run_kernel(binary_kernel, tl3)
|
||||
math_cmds = [c for c in tl3.commands if isinstance(c, MathCmd)]
|
||||
ops = [c.op for c in math_cmds]
|
||||
assert ops == ["add", "sub", "mul", "div"]
|
||||
|
||||
|
||||
# ── 6. tl.sum, tl.max → MathCmd with axis ────────────────────────
|
||||
|
||||
|
||||
def test_tl_reduction_ops():
|
||||
tl = _ctx()
|
||||
x = tl.load(0x1000, shape=(32, 64), dtype="f16")
|
||||
s = tl.sum(x, axis=1)
|
||||
m = tl.max(x, axis=0)
|
||||
assert s.shape == (32, 1)
|
||||
assert m.shape == (1, 64)
|
||||
math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
|
||||
assert math_cmds[0].op == "sum" and math_cmds[0].axis == 1
|
||||
assert math_cmds[1].op == "max" and math_cmds[1].axis == 0
|
||||
|
||||
|
||||
# ── 7. tl.composite → CompositeCmd + CompletionHandle ────────────
|
||||
|
||||
|
||||
def test_tl_composite_nonblocking():
|
||||
tl = _ctx()
|
||||
a = tl.load(0x1000, shape=(32, 64), dtype="f16")
|
||||
b = tl.load(0x2000, shape=(64, 32), dtype="f16")
|
||||
h = tl.composite(op="gemm", a=a, b=b, out_ptr=0x3000)
|
||||
assert isinstance(h, CompletionHandle)
|
||||
comp_cmds = [c for c in tl.commands if isinstance(c, CompositeCmd)]
|
||||
assert len(comp_cmds) == 1
|
||||
assert comp_cmds[0].op == "gemm"
|
||||
assert comp_cmds[0].out_pa == 0x3000
|
||||
assert comp_cmds[0].out_nbytes == 32 * 32 * 2 # M×N×dtype_bytes
|
||||
|
||||
|
||||
# ── 8. tl.wait(handle) → WaitCmd ─────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_wait_specific():
|
||||
tl = _ctx()
|
||||
a = tl.load(0x1000, shape=(4, 4), dtype="f16")
|
||||
h = tl.composite(op="gemm", a=a, b=a, out_ptr=0x2000)
|
||||
tl.wait(h)
|
||||
wait_cmds = [c for c in tl.commands if isinstance(c, WaitCmd)]
|
||||
assert len(wait_cmds) == 1
|
||||
assert wait_cmds[0].handle == h
|
||||
|
||||
|
||||
# ── 9. tl.wait() → WaitCmd(handle=None) ──────────────────────────
|
||||
|
||||
|
||||
def test_tl_wait_all():
|
||||
tl = _ctx()
|
||||
tl.wait()
|
||||
wait_cmds = [c for c in tl.commands if isinstance(c, WaitCmd)]
|
||||
assert len(wait_cmds) == 1
|
||||
assert wait_cmds[0].handle is None
|
||||
|
||||
|
||||
# ── 10. tl.cycles → PeCpuOverheadCmd ─────────────────────────────
|
||||
|
||||
|
||||
def test_tl_cycles():
|
||||
tl = _ctx()
|
||||
tl.cycles(10)
|
||||
assert len(tl.commands) == 1
|
||||
assert isinstance(tl.commands[0], PeCpuOverheadCmd)
|
||||
assert tl.commands[0].cycles == 10
|
||||
|
||||
|
||||
# ── 11. tl.program_id ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_program_id():
|
||||
tl = TLContext(pe_id=5, num_programs=8)
|
||||
assert tl.program_id(0) == 5
|
||||
assert tl.num_programs(0) == 8
|
||||
|
||||
|
||||
# ── 12. tl.arange, tl.zeros, tl.full ─────────────────────────────
|
||||
|
||||
|
||||
def test_tl_arange_zeros_full():
|
||||
tl = _ctx()
|
||||
r = tl.arange(0, 16, dtype="i32")
|
||||
assert r.shape == (16,)
|
||||
assert r.dtype == "i32"
|
||||
|
||||
z = tl.zeros((4, 8), dtype="f16")
|
||||
assert z.shape == (4, 8)
|
||||
assert z.nbytes == 4 * 8 * 2
|
||||
|
||||
f = tl.full((2, 3), value=1.0, dtype="f32")
|
||||
assert f.shape == (2, 3)
|
||||
assert f.nbytes == 2 * 3 * 4
|
||||
|
||||
|
||||
# ── 13. tl.trans → shape change, no command ───────────────────────
|
||||
|
||||
|
||||
def test_tl_trans_shape():
|
||||
tl = _ctx()
|
||||
h = tl.load(0x1000, shape=(32, 64), dtype="f16")
|
||||
t = tl.trans(h)
|
||||
assert t.shape == (64, 32)
|
||||
assert t.id == h.id # same underlying data
|
||||
# Only DmaReadCmd from load, no command from trans
|
||||
assert len(tl.commands) == 1
|
||||
assert isinstance(tl.commands[0], DmaReadCmd)
|
||||
|
||||
|
||||
# ── 14. Kernel registry ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_kernel_registry():
|
||||
clear_registry()
|
||||
|
||||
def my_kernel(tl):
|
||||
pass
|
||||
|
||||
register_kernel("test_kern", my_kernel)
|
||||
assert get_kernel("test_kern") is my_kernel
|
||||
clear_registry()
|
||||
|
||||
|
||||
def test_kernel_registry_missing():
|
||||
clear_registry()
|
||||
import pytest
|
||||
with pytest.raises(KeyError):
|
||||
get_kernel("nonexistent")
|
||||
|
||||
|
||||
def test_kernel_registry_duplicate():
|
||||
clear_registry()
|
||||
register_kernel("dup", lambda tl: None)
|
||||
import pytest
|
||||
with pytest.raises(ValueError):
|
||||
register_kernel("dup", lambda tl: None)
|
||||
clear_registry()
|
||||
|
||||
|
||||
# ── 15. GEMM kernel → correct command sequence ───────────────────
|
||||
|
||||
|
||||
def test_gemm_kernel_command_sequence():
|
||||
"""32×64 × 64×32 GEMM kernel produces [DmaRead, DmaRead, Composite]."""
|
||||
def gemm_kernel(a_ptr, b_ptr, out_ptr, tl):
|
||||
pid = tl.program_id(0)
|
||||
a = tl.load(a_ptr, shape=(32, 64), dtype="f16")
|
||||
b = tl.load(b_ptr + pid * 64 * 32 * 2, shape=(64, 32), dtype="f16")
|
||||
tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr + pid * 32 * 32 * 2)
|
||||
|
||||
tl = _ctx(pe_id=3)
|
||||
run_kernel(gemm_kernel, tl, a_ptr=0x1000, b_ptr=0x2000, out_ptr=0x3000)
|
||||
types = [type(c).__name__ for c in tl.commands]
|
||||
assert types == ["DmaReadCmd", "DmaReadCmd", "CompositeCmd"]
|
||||
|
||||
|
||||
# ── 16. Attention kernel → correct command sequence ───────────────
|
||||
|
||||
|
||||
def test_attention_kernel_command_sequence():
|
||||
"""Attention kernel: load→dot→math ops→dot→store."""
|
||||
def attention_kernel(q_ptr, k_ptr, v_ptr, out_ptr, tl,
|
||||
seq_len=16, head_dim=8):
|
||||
pid = tl.program_id(0)
|
||||
q = tl.load(q_ptr, shape=(seq_len, head_dim), dtype="f16")
|
||||
k = tl.load(k_ptr, shape=(head_dim, seq_len), dtype="f16")
|
||||
scores = tl.dot(q, k)
|
||||
row_max = tl.max(scores, axis=1)
|
||||
scores = scores - row_max
|
||||
scores = tl.exp(scores)
|
||||
row_sum = tl.sum(scores, axis=1)
|
||||
scores = scores / row_sum
|
||||
v = tl.load(v_ptr, shape=(seq_len, head_dim), dtype="f16")
|
||||
out = tl.dot(scores, v)
|
||||
tl.store(out_ptr, out)
|
||||
|
||||
tl = _ctx(pe_id=0)
|
||||
run_kernel(
|
||||
attention_kernel, tl,
|
||||
q_ptr=0x1000, k_ptr=0x2000, v_ptr=0x3000, out_ptr=0x4000,
|
||||
)
|
||||
types = [type(c).__name__ for c in tl.commands]
|
||||
# load, load, dot, max, sub, exp, sum, div, load, dot, store
|
||||
assert types == [
|
||||
"DmaReadCmd", "DmaReadCmd", # load Q, K
|
||||
"GemmCmd", # Q @ K
|
||||
"MathCmd", "MathCmd", "MathCmd", # max, sub, exp
|
||||
"MathCmd", "MathCmd", # sum, div
|
||||
"DmaReadCmd", # load V
|
||||
"GemmCmd", # scores @ V
|
||||
"DmaWriteCmd", # store output
|
||||
]
|
||||
# Verify math ops
|
||||
math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
|
||||
math_ops = [c.op for c in math_cmds]
|
||||
assert math_ops == ["max", "sub", "exp", "sum", "div"]
|
||||
|
||||
|
||||
# ── 17. Dispatch overhead auto-inserted ───────────────────────────
|
||||
|
||||
|
||||
def test_dispatch_overhead_inserted():
|
||||
"""Each tl API call auto-inserts PeCpuOverheadCmd when dispatch_cycles > 0."""
|
||||
tl = _ctx_with_overhead()
|
||||
a = tl.load(0x1000, shape=(4, 4), dtype="f16")
|
||||
tl.store(0x2000, a)
|
||||
types = [type(c).__name__ for c in tl.commands]
|
||||
# overhead, load, overhead, store
|
||||
assert types == [
|
||||
"PeCpuOverheadCmd", "DmaReadCmd",
|
||||
"PeCpuOverheadCmd", "DmaWriteCmd",
|
||||
]
|
||||
|
||||
|
||||
# ── 18. where operation ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tl_where():
|
||||
tl = _ctx()
|
||||
cond = tl.load(0x1000, shape=(4, 4), dtype="i32")
|
||||
a = tl.load(0x2000, shape=(4, 4), dtype="f16")
|
||||
b = tl.load(0x3000, shape=(4, 4), dtype="f16")
|
||||
out = tl.where(cond, a, b)
|
||||
assert isinstance(out, TensorHandle)
|
||||
math_cmds = [c for c in tl.commands if isinstance(c, MathCmd)]
|
||||
assert len(math_cmds) == 1
|
||||
assert math_cmds[0].op == "where"
|
||||
assert len(math_cmds[0].inputs) == 3
|
||||
Reference in New Issue
Block a user