Files
kernbench2/tests/test_topology_compile.py
T
mukesh a7fe785e5f tl.composite: fused epilogue ops with per-op scope
Extend tl.composite() with an ordered epilogue list. Each op carries
a scope flag: output_tile (the default; runs once per (m, n) tile before
STORE), k_tile (runs on every K-tile right after the GEMM), or kernel. The
plan generator slots MATH stages by scope; pe_math reuses pe_dma's
local-loop pattern so chained epilogues (bias -> relu) skip the port
hop. op_log captures per-stage params for telemetry. The topology
gains a gemm->math edge (snapshot test updated).

The API stays backward-compatible: `epilogue=` is opt-in.

Example:
    h = tl.composite(
        op="gemm", a=a, b=b, out_ptr=int(out),
        epilogue=[
            {"op": "dequant", "scale": s_per_k, "scope": "k_tile"},
            {"op": "bias",    "bias":  bias_vec},
            {"op": "relu"},
            {"op": "scale",   "factor": 0.5},
        ],
    )
    tl.wait(h)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-15 10:16:47 -07:00

444 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
from kernbench.policy.routing.router import PathRouter
from kernbench.topology.builder import load_topology
# Topology description exercised by every test in this module: it lives two
# directory levels above this file (the project root next to ``tests/``).
TOPOLOGY_PATH = Path(__file__).parent.parent.joinpath("topology.yaml")


def _graph():
    """Load and return the topology graph described by ``TOPOLOGY_PATH``.

    Every test calls this helper itself instead of sharing one graph object.
    """
    topology = load_topology(TOPOLOGY_PATH)
    return topology
# -- Full graph: node counts --------------------------------------------------
def test_full_graph_node_count():
    """The fully expanded graph contains exactly 4335 nodes."""
    g = _graph()
    # IO chiplet: pcie_ep + io_cpu + noc + 4 io_ucie_ports + 4*4 io_ucie_conn
    io_chiplet = 1 + 1 + 1 + 4 + 4 * 4  # = 23
    # Cube: 32 routers
    #       + (8 hbm_ctrl.peX + 1 m_cpu + 1 sram)
    #       + 4 UCIe ports x (1 port + 4 conn)
    #       + 8 PEs x 9 pe_comps (ADR-0023: +pe_ipcq)
    # = 134 per cube (was 127 before ADR-0019 D1 per-PE HBM CTRL)
    cube = 32 + (8 + 1 + 1) + 4 * (1 + 4) + 8 * 9  # = 134
    # 1 fabric switch + 2 SIPs x (1 IO chiplet + 16 cubes)
    expected = 1 + 2 * (io_chiplet + 16 * cube)
    assert expected == 4335
    assert len(g.nodes) == expected
def test_full_graph_edge_count():
    """Snapshot of the total edge count of the full graph."""
    # Adjustments folded into this snapshot value:
    # - ADR-0023: +3 IPCQ edges per PE.
    # - ADR-0019 D1 (restored): HBM<->router edges shrink from 32 routers x 2
    #   to 8 PE-routers x 2 per cube, i.e. 32 cubes x (16 - 64) = -1536.
    # - Multi-op composite (ADR-0021): +1 gemm->math edge per PE for
    #   epilogue chaining, 2 SIPs x 16 cubes x 8 PEs = +256.
    expected_edges = 12412
    graph = _graph()
    assert len(graph.edges) == expected_edges
# -- Full graph: specific nodes exist -----------------------------------------
def test_system_switch_exists():
    """The fabric switch exists as a 'switch' node with no physical position."""
    nodes = _graph().nodes
    switch_id = "fabric.switch0"
    assert switch_id in nodes
    switch = nodes[switch_id]
    assert switch.kind == "switch"
    assert switch.pos_mm is None  # abstract: not placed on any die
def test_io_chiplet_nodes_exist():
    """Each SIP's io0 chiplet has a PCIe endpoint and an IO CPU node."""
    nodes = _graph().nodes
    for sip in (0, 1):
        for component in ("pcie_ep", "io_cpu"):
            assert f"sip{sip}.io0.{component}" in nodes
def test_cube_component_nodes_exist():
    """sip0.cube0 exposes the current component set and no legacy nodes."""
    nodes = _graph().nodes
    cube = "sip0.cube0"

    def node_id(local):
        # Fully qualified id of a component inside ``cube``.
        return f"{cube}.{local}"

    # Core cube components (noc, xbar and bridge no longer exist).
    for local in ("m_cpu", "sram", "ucie-N", "ucie-S", "ucie-E", "ucie-W"):
        assert node_id(local) in nodes
    # Legacy nodes must be absent.
    for local in ("noc", "xbar_top", "xbar_bot", "bridge.left", "bridge.right"):
        assert node_id(local) not in nodes

    # Router mesh: a 6x6 grid with the 4 central null holes removed -> 32.
    routers = [n for n in nodes if n.startswith(node_id("r"))]
    assert len(routers) == 32
    # Spot-check two corner routers.
    assert node_id("r0c0") in nodes
    assert nodes[node_id("r0c0")].kind == "noc_router"
    assert node_id("r5c5") in nodes
    # The null holes must not be instantiated.
    for hole in ("r2c2", "r2c3", "r3c2", "r3c3"):
        assert node_id(hole) not in nodes

    # Per-PE HBM CTRL (ADR-0019 D1): 8 instances, and neither the legacy
    # single controller node nor the legacy slice nodes.
    for pe in range(8):
        assert nodes[node_id(f"hbm_ctrl.pe{pe}")].kind == "hbm_ctrl"
    assert node_id("hbm_ctrl") not in nodes
    for slice_idx in range(8):
        assert node_id(f"hbm_ctrl.slice{slice_idx}") not in nodes
def test_pe_component_nodes_exist():
    """Every per-PE component exists in the first and last PE of the system.

    The component list mirrors the nine PE components enumerated by
    ``test_pe_view_has_all_components`` and the "9 pe_comps" term in the
    node-count derivation. A component dropped from the full graph (e.g.
    pe_fetch_store, pe_mmu or pe_ipcq) is therefore caught here as well.
    """
    g = _graph()
    components = (
        "pe_cpu", "pe_scheduler", "pe_dma", "pe_fetch_store",
        "pe_gemm", "pe_math", "pe_mmu", "pe_tcm", "pe_ipcq",
    )
    # First and last PE in the system: sip0/cube0/pe0 and sip1/cube15/pe7.
    for pe_prefix in ("sip0.cube0.pe0", "sip1.cube15.pe7"):
        for comp in components:
            node_id = f"{pe_prefix}.{comp}"
            assert node_id in g.nodes, f"missing {node_id}"
# -- Full graph: positions ----------------------------------------------------
def test_hbm_ctrl_at_cube_center():
    """All per-PE hbm_ctrl nodes of sip0.cube0 share the cube's HBM placement.

    Note: despite the test name, (6.5, 7.0) is not the geometric centre of a
    17 x 14 mm cube (that would be (8.5, 7.0)); it is the HBM placement.
    """
    nodes = _graph().nodes
    # ADR-0019 D1: the per-PE controllers share one placement. cube0's origin
    # is (0, 0), so the absolute position equals the in-cube HBM offset.
    expected = (6.5, 7.0)
    for pe in range(8):
        assert nodes[f"sip0.cube0.hbm_ctrl.pe{pe}"].pos_mm == expected
def test_hbm_ctrl_cube5_position():
    """hbm_ctrl placement is offset by the owning cube's origin."""
    # cube5 sits at col=1, row=1; cube origins step by 18 mm in x, 15 mm in y.
    origin_x, origin_y = 1 * 18, 1 * 15
    hbm_dx, hbm_dy = 6.5, 7.0
    expected = (origin_x + hbm_dx, origin_y + hbm_dy)  # (24.5, 22.0)
    node = _graph().nodes["sip0.cube5.hbm_ctrl.pe0"]
    assert node.pos_mm == expected
def test_ucie_ports_at_cube_edges():
    """The four UCIe ports of sip0.cube0 sit on the cube boundary."""
    nodes = _graph().nodes
    # cube0 origin = (0, 0), cube_w = 17, cube_h = 14. Each UCIe node is
    # inset by half its size so that its edge touches the boundary.
    expected_positions = {
        "ucie-N": (8.5, 0.6),
        "ucie-S": (8.5, 13.4),
        "ucie-W": (1.0, 7.0),
        "ucie-E": (16.0, 7.0),
    }
    for port, position in expected_positions.items():
        assert nodes[f"sip0.cube0.{port}"].pos_mm == position
# -- Full graph: edges --------------------------------------------------------
def _edge_set(g):
return {(e.src, e.dst) for e in g.edges}
def test_inter_cube_ucie_edges():
    """Neighbouring cubes are linked through facing UCIe ports."""
    edges = _edge_set(_graph())
    expected = [
        # cube0 (0,0) east  -> cube1 (1,0) west
        ("sip0.cube0.ucie-E", "sip0.cube1.ucie-W"),
        # cube0 (0,0) south -> cube4 (0,1) north
        ("sip0.cube0.ucie-S", "sip0.cube4.ucie-N"),
    ]
    for pair in expected:
        assert pair in edges
def test_io_to_cube_edges():
    """io0's UCIe PHYs connect to the N-side UCIe ports of the cubes."""
    edges = _edge_set(_graph())
    for phy, cube in (("P0", "cube0"), ("P3", "cube3")):
        assert (f"sip0.io0.ucie-{phy}", f"sip0.{cube}.ucie-N") in edges
def test_switch_to_io_edges():
    """The fabric switch has an edge to each SIP's io0 PCIe endpoint."""
    edges = _edge_set(_graph())
    for sip in ("sip0", "sip1"):
        assert ("fabric.switch0", f"{sip}.io0.pcie_ep") in edges
def test_pe_dma_to_router():
    """Each PE's pe_dma in sip0.cube0 has an edge to its attaching router.

    Only the edge's existence is checked, not its kind.
    """
    edges = _edge_set(_graph())
    cube = "sip0.cube0"
    attach = [
        (0, "r0c0"), (1, "r0c1"),
        (2, "r1c4"), (3, "r1c5"),
        (4, "r4c0"), (5, "r4c1"),
        (6, "r5c4"), (7, "r5c5"),
    ]
    for pe, router in attach:
        assert (f"{cube}.pe{pe}.pe_dma", f"{cube}.{router}") in edges
def test_command_path_m_cpu_router_pe_cpu():
    """Command path: m_cpu <-> r1c2, and PE routers -> pe_cpu."""
    edges = _edge_set(_graph())
    cube = "sip0.cube0"
    m_cpu, hub = f"{cube}.m_cpu", f"{cube}.r1c2"
    # m_cpu and its router are linked in both directions.
    assert (m_cpu, hub) in edges
    assert (hub, m_cpu) in edges
    # A PE's router has an edge to that PE's pe_cpu.
    for router, pe in (("r0c0", "pe0"), ("r5c5", "pe7")):
        assert (f"{cube}.{router}", f"{cube}.{pe}.pe_cpu") in edges
def test_pe_internal_edges():
    """Scheduler fan-out and TCM edges exist inside sip0.cube0.pe0."""
    edges = _edge_set(_graph())
    pe = "sip0.cube0.pe0"
    expected = [
        ("pe_cpu", "pe_scheduler"),
        ("pe_scheduler", "pe_dma"),
        ("pe_scheduler", "pe_gemm"),
        ("pe_scheduler", "pe_math"),
        ("pe_dma", "pe_tcm"),
        ("pe_gemm", "pe_tcm"),
        ("pe_math", "pe_tcm"),
    ]
    for src, dst in expected:
        assert (f"{pe}.{src}", f"{pe}.{dst}") in edges
def test_per_pe_hbm_ctrl_connects_only_to_owning_router():
    """Each hbm_ctrl.pe{X} connects ONLY to PE_X's attaching router
    (ADR-0019 D4). Replaces a prior test that asserted the
    spec-violating all-routers consolidation (commit 5917b34).

    Both directions are checked. The owner link must exist both ways, and no
    other router of the cube may have an edge to *or* from the controller.
    """
    g = _graph()
    es = _edge_set(g)
    cp = "sip0.cube0"
    pe_router = {0: "r0c0", 1: "r0c1", 2: "r1c4", 3: "r1c5",
                 4: "r4c0", 5: "r4c1", 6: "r5c4", 7: "r5c5"}
    # Router ids of this cube, collected once rather than rescanning every
    # node of the full graph for each PE.
    cube_routers = [n for n in g.nodes if n.startswith(f"{cp}.r")]
    for pe, rkey in pe_router.items():
        nid = f"{cp}.hbm_ctrl.pe{pe}"
        owner = f"{cp}.{rkey}"
        assert (owner, nid) in es, f"missing {owner} -> {nid}"
        assert (nid, owner) in es, f"missing {nid} -> {owner}"
        for other in cube_routers:
            if other == owner:
                continue
            assert (other, nid) not in es, f"unexpected edge {other} -> {nid}"
            assert (nid, other) not in es, f"unexpected edge {nid} -> {other}"
def test_router_mesh_edges():
    """Adjacent routers are connected by router_mesh edges in both directions."""
    kind_of = {}
    for edge in _graph().edges:
        kind_of[(edge.src, edge.dst)] = edge.kind
    cube = "sip0.cube0"
    neighbours = [
        ("r0c0", "r0c1"),  # horizontal neighbours
        ("r0c0", "r1c0"),  # vertical neighbours
    ]
    for a, b in neighbours:
        assert kind_of.get((f"{cube}.{a}", f"{cube}.{b}")) == "router_mesh"
        assert kind_of.get((f"{cube}.{b}", f"{cube}.{a}")) == "router_mesh"
# -- Views: system ------------------------------------------------------------
def test_system_view_nodes():
    """The system view shows the switch, both SIPs and their io0 chiplets."""
    view = _graph().system_view
    expected = ("fabric.switch0", "sip0", "sip1", "sip0.io0", "sip1.io0")
    for node_id in expected:
        assert node_id in view.nodes
# -- Views: SIP ---------------------------------------------------------------
def test_sip_view_cube_count():
    """The SIP view contains exactly 16 cube nodes."""
    view = _graph().sip_view
    count = sum(1 for name in view.nodes if name.startswith("cube"))
    assert count == 16
def test_sip_view_io_chiplets():
    """The SIP view includes the io0 chiplet."""
    nodes = _graph().sip_view.nodes
    assert "io0" in nodes
def test_sip_view_cube_positions():
    """Cube centres in the SIP view include the IO margin."""
    nodes = _graph().sip_view.nodes
    # cube0 (0,0): centre = (8.5, 6 + 7.0) = (8.5, 13.0)   [io_margin = 6]
    # cube1 (1,0): centre = (18 + 8.5, 13.0) = (26.5, 13.0)
    expected = (("cube0", (8.5, 13.0)), ("cube1", (26.5, 13.0)))
    for cube, (want_x, want_y) in expected:
        got_x, got_y = nodes[cube].pos_mm
        assert got_x == want_x
        assert got_y == want_y
# -- Views: cube ---------------------------------------------------------------
def test_cube_view_has_all_components():
    """The cube view contains exactly the expected set of components."""
    view = _graph().cube_view
    expected = {"ucie-N", "ucie-S", "ucie-W", "ucie-E", "m_cpu", "hbm_ctrl", "sram"}
    expected.update(f"pe{i}" for i in range(8))
    # 6x6 router grid minus the four central null holes.
    holes = {(2, 2), (2, 3), (3, 2), (3, 3)}
    expected.update(
        f"r{row}c{col}"
        for row in range(6)
        for col in range(6)
        if (row, col) not in holes
    )
    # UCIe connection nodes: 4 ports x 4 connections each.
    expected.update(f"ucie-{port}.conn{ci}" for port in "NSEW" for ci in range(4))
    assert set(view.nodes.keys()) == expected
def test_cube_view_hbm_at_center():
    """Cube view geometry: HBM placement, router presence and cube size."""
    view = _graph().cube_view
    # Note: (6.5, 7.0) is the HBM placement, not the geometric centre of the
    # 17 x 14 mm cube asserted below (which would be (8.5, 7.0)).
    assert view.nodes["hbm_ctrl"].pos_mm == (6.5, 7.0)
    assert "r0c0" in view.nodes  # routers are part of the cube view
    assert (view.width_mm, view.height_mm) == (17.0, 14.0)
def test_cube_view_pe_to_router():
    """In the cube view each PE has an edge to its assigned router."""
    view = _graph().cube_view
    links = set()
    for edge in view.edges:
        links.add((edge.src, edge.dst))
    assignments = (
        ("pe0", "r0c0"), ("pe1", "r0c1"), ("pe2", "r1c4"), ("pe3", "r1c5"),
        ("pe4", "r4c0"), ("pe5", "r4c1"), ("pe6", "r5c4"), ("pe7", "r5c5"),
    )
    for pe, router in assignments:
        assert (pe, router) in links, f"{pe} should connect to {router}"
# -- Views: PE ----------------------------------------------------------------
def test_pe_view_has_all_components():
    """The PE view lists exactly the nine PE components."""
    components = set(_graph().pe_view.nodes.keys())
    expected = {
        "pe_cpu", "pe_dma", "pe_fetch_store", "pe_gemm", "pe_ipcq",
        "pe_math", "pe_mmu", "pe_scheduler", "pe_tcm",
    }
    assert components == expected
def test_pe_view_edges():
    """The PE view carries the scheduler fan-out and the TCM edges."""
    view = _graph().pe_view
    links = {(edge.src, edge.dst) for edge in view.edges}
    units = ("pe_dma", "pe_gemm", "pe_math")
    fan_out = [("pe_scheduler", unit) for unit in units]
    to_tcm = [(unit, "pe_tcm") for unit in units]
    for pair in [("pe_cpu", "pe_scheduler"), *fan_out, *to_tcm]:
        assert pair in links
# -- SRAM ----------------------------------------------------------------------
def test_sram_node_exists():
    """sip0.cube0 has an SRAM node of kind 'sram'."""
    nodes = _graph().nodes
    sram_id = "sip0.cube0.sram"
    assert sram_id in nodes
    assert nodes[sram_id].kind == "sram"
def test_sram_to_router_edges():
    """SRAM is attached to router r3c0 in both directions."""
    edges = _edge_set(_graph())
    sram, router = "sip0.cube0.sram", "sip0.cube0.r3c0"
    assert (sram, router) in edges
    assert (router, sram) in edges
# -- PE_DMA -> Router (data path) ---------------------------------------------
def test_pe_dma_to_router_edges():
    """Every PE DMA in sip0.cube0 has an edge to its local router."""
    # NOTE(review): these checks duplicate test_pe_dma_to_router.
    edges = _edge_set(_graph())
    routers = ["r0c0", "r0c1", "r1c4", "r1c5", "r4c0", "r4c1", "r5c4", "r5c5"]
    for pe, router in enumerate(routers):
        assert (f"sip0.cube0.pe{pe}.pe_dma", f"sip0.cube0.{router}") in edges
# -- UCIe conn nodes connect to routers (not NOC) -----------------------------
def test_ucie_noc_reverse_edges():
    """Each UCIe port of sip0.cube1 is linked to its conn nodes both ways.

    Only the port <-> conn hop is checked here. The conn -> router hop is
    covered, for sip0.cube0, by ``test_ucie_conn_edge_bw``.
    """
    es = _edge_set(_graph())
    # cube1 is at (1, 0), i.e. in row 0 together with cube0 and cube3, whose
    # N ports io0 attaches to (see test_io_to_cube_edges); NOTE(review): its
    # N port may face io0 as well -- confirm. The port <-> conn edges checked
    # below are internal to the cube, so this does not affect the test.
    cp = "sip0.cube1"
    for port in ("N", "S", "E", "W"):
        for ci in range(4):
            conn = f"{cp}.ucie-{port}.conn{ci}"
            assert (f"{cp}.ucie-{port}", conn) in es, \
                f"missing ucie-{port}->conn{ci}"
            assert (conn, f"{cp}.ucie-{port}") in es, \
                f"missing conn{ci}->ucie-{port}"
def test_ucie_conn_nodes_exist():
    """Every UCIe port of sip0.cube0 owns four zero-overhead ucie_conn nodes."""
    nodes = _graph().nodes
    for port in "NSEW":
        for ci in range(4):
            conn_id = f"sip0.cube0.ucie-{port}.conn{ci}"
            assert conn_id in nodes, f"missing {conn_id}"
            conn = nodes[conn_id]
            assert conn.kind == "ucie_conn"
            assert conn.attrs["overhead_ns"] == 0.0
def test_ucie_conn_edge_bw():
    """conn->router edges must have per_connection_bw_gbs (128 GB/s).

    Every conn node of every UCIe port in sip0.cube0 must have exactly one
    outgoing ``ucie_conn_to_router`` edge, and that edge must carry 128 GB/s.
    """
    g = _graph()
    cp = "sip0.cube0"
    # Index the ucie_conn_to_router edges by source once, instead of
    # rescanning the full edge list for each of the 16 conn nodes.
    by_src = {}
    for e in g.edges:
        if e.kind == "ucie_conn_to_router":
            by_src.setdefault(e.src, []).append(e)
    # Check every conn (not only conn0) of every port.
    for port in ("N", "S", "E", "W"):
        for ci in range(4):
            conn_id = f"{cp}.ucie-{port}.conn{ci}"
            conn_edges = by_src.get(conn_id, [])
            assert len(conn_edges) == 1, f"expected 1 ucie_conn_to_router from {conn_id}"
            assert conn_edges[0].bw_gbs == 128.0, (
                f"{conn_id}: bw_gbs={conn_edges[0].bw_gbs}, expected 128.0"
            )
def test_cross_cube_path_includes_conn():
    """A PE-to-neighbour-cube-HBM path goes through at least two conn nodes."""
    path_router = PathRouter(_graph())
    path = path_router.find_path("sip0.cube0.pe0", "sip0.cube1.hbm_ctrl.pe0")
    # Presumably one conn when leaving cube0 and one when entering cube1.
    conn_nodes = [hop for hop in path if ".conn" in hop]
    assert len(conn_nodes) >= 2, f"Expected >=2 conn nodes in path, got {conn_nodes}"
# -- Cube view: edges ---------------------------------------------------------
def test_cube_view_pe_to_router_edges():
    """All PEs connect to their routers in the cube view."""
    # NOTE(review): same checks as test_cube_view_pe_to_router.
    view = _graph().cube_view
    links = {(e.src, e.dst) for e in view.edges}
    routers = ("r0c0", "r0c1", "r1c4", "r1c5", "r4c0", "r4c1", "r5c4", "r5c5")
    for index, router in enumerate(routers):
        pe = f"pe{index}"
        assert (pe, router) in links, f"{pe} should connect to {router}"
def test_cube_view_sram():
    """The cube view shows SRAM and its edge to router r3c0."""
    view = _graph().cube_view
    assert "sram" in view.nodes
    links = {(edge.src, edge.dst) for edge in view.edges}
    assert ("sram", "r3c0") in links
def test_cube_view_hbm_router():
    """Cube view: PE routers connect to hbm_ctrl (PE0's r0c0 is sampled)."""
    links = {(edge.src, edge.dst) for edge in _graph().cube_view.edges}
    assert ("r0c0", "hbm_ctrl") in links
def test_cube_view_m_cpu_router():
    """Cube view: m_cpu and router r1c2 are linked in both directions."""
    links = {(edge.src, edge.dst) for edge in _graph().cube_view.edges}
    for pair in (("m_cpu", "r1c2"), ("r1c2", "m_cpu")):
        assert pair in links