commit - release 1

This commit is contained in:
2026-03-18 11:47:48 -07:00
commit 6f43807900
109 changed files with 14909 additions and 0 deletions
View File
+64
View File
@@ -0,0 +1,64 @@
import argparse
import sys
from benches.loader import resolve_bench
from kernbench.cli.probe import cmd_probe
from kernbench.cli.report import format_report
from kernbench.common.types import SimEngine
from kernbench.runtime_api.bench_runner import run_bench
from kernbench.runtime_api.types import DeviceSelector, resolve_device
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import resolve_topology
def build_parser() -> argparse.ArgumentParser:
    """Build the ``kernbench`` command-line parser.

    Two sub-commands are registered, ``run`` and ``probe``. Each one stores
    its handler callable under the ``_handler`` default so that the caller
    can dispatch on the parsed namespace.
    """
    parser = argparse.ArgumentParser(prog="kernbench")
    commands = parser.add_subparsers(dest="cmd", required=True)

    # kernbench run --topology ... --bench ... [--device ...]
    run_cmd = commands.add_parser("run", help="Run a benchmark")
    for required_flag in ("--topology", "--bench"):
        run_cmd.add_argument(required_flag, required=True)
    run_cmd.add_argument(
        "--device",
        default=None,
        help="Target device: 'all' or 'sip:<N>' (default: all)",
    )
    run_cmd.set_defaults(_handler=cmd_run)

    # kernbench probe --topology ... [--case ...]
    probe_cmd = commands.add_parser(
        "probe",
        help="Probe latency and BW for predefined traffic patterns",
    )
    probe_cmd.add_argument("--topology", required=True)
    probe_cmd.add_argument(
        "--case",
        default="all",
        help="Case name or 'all' (default: all)",
    )
    probe_cmd.set_defaults(_handler=cmd_probe)

    return parser
def engine_factory(topology: object, device: DeviceSelector) -> SimEngine:
    """Create the simulation engine for *topology*.

    When *topology* exposes a ``topology_obj`` attribute that object is
    handed to ``GraphEngine``; otherwise *topology* itself is used.
    *device* is part of the factory signature but is not read here.
    """
    graph = getattr(topology, "topology_obj", topology)
    engine = GraphEngine(graph)
    return engine
def cmd_run(args) -> int:
    """Handle ``kernbench run``.

    Resolves the topology, bench and device named on the command line, runs
    the bench, prints the formatted report when traces were collected, then
    prints the result summary.

    Returns:
        0 when ``result.completion.ok`` is truthy, otherwise 1.
    """
    print("> Running benchmark with:", args)

    topology = resolve_topology(args.topology)
    bench_fn = resolve_bench(args.bench)
    target = resolve_device(args.device)
    outcome = run_bench(
        topology=topology,
        bench_fn=bench_fn,
        device=target,
        engine_factory=engine_factory,
    )

    # The spec (if any) lets the report show peak TFLOPS / bandwidth figures.
    graph = getattr(topology, "topology_obj", topology)
    report_spec = getattr(graph, "spec", None)
    if outcome.traces:
        report = format_report(outcome.traces, title=args.bench, spec=report_spec)
        print(report)

    print(outcome.summary_text())
    if outcome.completion.ok:
        return 0
    return 1
def main(argv=None) -> int:
    """Parse *argv* and dispatch to the selected sub-command handler.

    Passing ``None`` makes argparse read ``sys.argv[1:]``. The handler's
    return value is coerced to ``int`` and returned as the exit status.
    """
    namespace = build_parser().parse_args(argv)
    handler = namespace._handler
    return int(handler(namespace))


if __name__ == "__main__":
    sys.exit(main())
+248
View File
@@ -0,0 +1,248 @@
"""kernbench probe: latency and BW verification utility.
Runs predefined traffic patterns through the simulation engine and reports
latency, effective bandwidth, bottleneck bandwidth, and utilization for each
case. Validates monotonicity invariants across hop counts and access types.
"""
from __future__ import annotations
from pathlib import Path
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.policy.routing.router import AddressResolver, PathRouter
from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg
from kernbench.sim_engine.engine import GraphEngine
from kernbench.topology.builder import load_topology
from kernbench.topology.types import TopologyGraph
# -- Helpers ----------------------------------------------------------
def _hbm_pa(sip: int, cube: int, pe_id: int, spec: dict) -> int:
    """Return the encoded physical address of a PE's HBM slice at offset 0x1000.

    The slice size is ``hbm_total_gb_per_cube * 2**30`` floor-divided by
    ``hbm_slices_per_cube``; both values come from
    ``spec["cube"]["memory_map"]``. The rack id is fixed at 0.
    """
    memory_map = spec["cube"]["memory_map"]
    cube_bytes = memory_map["hbm_total_gb_per_cube"] * (1 << 30)
    slice_size = cube_bytes // memory_map["hbm_slices_per_cube"]
    return PhysAddr.pe_hbm_addr(
        rack_id=0,
        sip_id=sip,
        cube_id=cube,
        pe_id=pe_id,
        pe_local_hbm_offset=0x1000,
        slice_size_bytes=slice_size,
    ).encode()
def _build_edge_map(graph: TopologyGraph) -> dict[tuple[str, str], object]:
    """Index the graph's edges by their ``(src, dst)`` pair.

    If several edges share the same pair, the last one in ``graph.edges`` wins.
    """
    edges_by_endpoints: dict[tuple[str, str], object] = {}
    for edge in graph.edges:
        edges_by_endpoints[(edge.src, edge.dst)] = edge
    return edges_by_endpoints
def _formula_breakdown(
    path: list[str], nbytes: int, edge_map: dict, graph: TopologyGraph,
) -> tuple[float, float, float, float]:
    """Return (wire_ns, overhead_ns, drain_ns, formula_ns) for a path.

    * wire_ns: sum of ``distance_mm * ns_per_mm`` over the hops that have an
      edge in *edge_map* (``ns_per_mm`` defaults to 0.01 when the spec has no
      ``system.ns_per_mm``).
    * overhead_ns: sum of the ``overhead_ns`` attribute of every path node
      found in ``graph.nodes``.
    * drain_ns: ``nbytes`` divided by the smallest non-zero edge bandwidth on
      the path, or 0.0 when no hop has a bandwidth.
    * formula_ns: the sum of the three terms above.
    """
    ns_per_mm = graph.spec.get("system", {}).get("ns_per_mm", 0.01)

    # One lookup per hop; hops without a known edge map to None.
    hop_edges = [edge_map.get(hop) for hop in zip(path, path[1:])]

    wire_ns = 0.0
    for edge in hop_edges:
        if edge:
            wire_ns += edge.distance_mm * ns_per_mm

    overhead_ns = 0.0
    for node_id in path:
        node = graph.nodes.get(node_id)
        if node:
            overhead_ns += float(node.attrs.get("overhead_ns", 0.0))

    link_bws = [edge.bw_gbs for edge in hop_edges if edge and edge.bw_gbs]
    if link_bws:
        drain_ns = nbytes / min(link_bws)
    else:
        drain_ns = 0.0

    return wire_ns, overhead_ns, drain_ns, wire_ns + overhead_ns + drain_ns
def _bottleneck_bw(path: list[str], edge_map: dict) -> float | None:
    """Return the smallest edge bandwidth along *path*.

    Hops with no edge in *edge_map*, or whose ``bw_gbs`` is falsy, are
    ignored. ``None`` is returned when no hop contributes a bandwidth.
    """
    hop_edges = (edge_map.get(hop) for hop in zip(path, path[1:]))
    return min(
        (edge.bw_gbs for edge in hop_edges if edge and edge.bw_gbs),
        default=None,
    )
def _fmt_bw(bw: float | None) -> str:
    """Format a bandwidth with one decimal place, or ``"-"`` when it is None."""
    if bw is None:
        return "-"
    return f"{bw:.1f}"
def _fmt_util(eff: float, bn: float | None) -> str:
    """Format ``eff / bn`` as a percentage with one decimal place.

    Returns ``"-"`` when *bn* is None, zero or negative.
    """
    no_bottleneck = bn is None or bn <= 0
    return "-" if no_bottleneck else f"{eff / bn * 100:.1f}%"
def _short_name(node_id: str) -> str:
    """Return the last two dot-separated segments of *node_id*.

    Two segments are kept so that, for example, ``xbar.pe0`` and ``pe0``
    stay distinguishable. Ids with fewer than two segments come back unchanged.
    """
    # Slicing the last two items of a one-item list yields that one item,
    # so no separate short-id branch is needed.
    return ".".join(node_id.split(".")[-2:])
def _short_path(path: list[str]) -> str:
    """Render *path* as shortened node names joined by ``" -> "``."""
    short_names = map(_short_name, path)
    return " -> ".join(short_names)
# -- Probe runner -----------------------------------------------------
def run_probe(topology_path: str, case_filter: str | None = None) -> int:
"""Run the predefined probe traffic patterns and print latency/BW tables.

Two groups of cases are executed, each request on a freshly constructed
``GraphEngine``:

* H2D writes (``MemoryWriteMsg``) of 4096 bytes to PE 0 of several cubes.
* PE DMA transfers (``PeDmaMsg``) of 4096 bytes from a source PE to an HBM
  slice address.

For every case the ``total_ns`` reported in the trace is printed next to a
formula breakdown (overhead / drain / wire) of the routed path, the
effective bandwidth (bytes per ns) and the path's bottleneck bandwidth.

Args:
topology_path: Path to the topology file; ``~`` is expanded and the
path is resolved before loading.
case_filter: Name of a single case to run. ``None`` or ``"all"`` runs
every case. A name that matches no case leaves the tables empty.

Returns:
Always 0. The monotonicity check result is printed but does not
affect the return value.
"""
path = Path(topology_path).expanduser().resolve()
graph = load_topology(path)
edge_map = _build_edge_map(graph)
spec = graph.spec
resolver = AddressResolver(graph)
router = PathRouter(graph)
# Every probe request transfers the same fixed payload size.
nbytes = 4096
show_all = case_filter is None or case_filter == "all"
# === H2D Write ===
# (name, target cube, hop count shown in the table)
h2d_cases = [
("h2d-1hop", 0, 1),
("h2d-2hop", 4, 2),
("h2d-3hop", 8, 3),
("h2d-4hop", 12, 4),
]
# (name, hops, total_ns, eff_bw, bn_bw) per executed case; only total_ns is read back below.
h2d_results: list[tuple[str, int, float, float, float | None]] = []
# (name, leg1, leg2, leg3) per executed case, for the route table.
h2d_paths: list[tuple[str, list[str], list[str], list[str]]] = []
print()
print("=== H2D Write Latency (IO->HBM, varying hop count) ===")
print(f" {'Case':<14} {'Target':<16} {'Hops':>4} {'Actual':>8}"
f" {'Ovhd':>6} {'Drain':>6} {'Wire':>5} {'Ovhd%':>6} {'Drain%':>7}"
f" {'Eff.BW':>8} {'BN.BW':>8} {'Util%':>6}")
print(" " + "-" * 115)
for name, cube, hops in h2d_cases:
if not show_all and case_filter != name:
continue
# A new engine per case, so no state is carried over from the previous request.
engine = GraphEngine(graph)
pa = _hbm_pa(sip=0, cube=cube, pe_id=0, spec=spec)
msg = MemoryWriteMsg(
correlation_id="probe", request_id=name,
dst_sip=0, dst_cube=cube, dst_pe=0,
dst_pa=pa, nbytes=nbytes, pattern="zero",
)
h = engine.submit(msg)
engine.wait(h)
_, trace = engine.get_completion(h)
total_ns = trace["total_ns"]
# Effective bandwidth in bytes per nanosecond.
eff_bw = nbytes / total_ns if total_ns > 0 else 0.0
# Rebuild the route in three legs: PCIe EP -> IO CPU -> M CPU -> destination node.
pa_obj = PhysAddr.decode(pa)
dst_node = resolver.resolve(pa_obj)
pcie_ep = resolver.find_pcie_ep(0)
io_cpu = resolver.find_io_cpu(0)
m_cpu = resolver.find_m_cpu(0, cube)
leg1 = router.find_node_path(pcie_ep, io_cpu)
leg2 = router.find_node_path(io_cpu, m_cpu)
leg3 = router.find_mcpu_dma_path(m_cpu, dst_node)
# Each leg starts at the node the previous one ended on, hence the [1:] slices.
full_path = leg1 + leg2[1:] + leg3[1:]
bn_bw = _bottleneck_bw(full_path, edge_map)
# Forward path breakdown only (response path is implicit in actual_ns)
# NOTE(review): fwd_path is built from the same expression as full_path above.
fwd_path = leg1 + leg2[1:] + leg3[1:]
# NOTE(review): `formula` is not read afterwards.
wire, ovhd, drain, formula = _formula_breakdown(fwd_path, nbytes, edge_map, graph)
ovhd_pct = ovhd / total_ns * 100 if total_ns > 0 else 0
drain_pct = drain / total_ns * 100 if total_ns > 0 else 0
h2d_results.append((name, hops, total_ns, eff_bw, bn_bw))
h2d_paths.append((name, leg1, leg2, leg3))
print(f" {name:<14} cube{cube}.pe0{'':<8} {hops:>4} {total_ns:>8.2f}"
f" {ovhd:>6.1f} {drain:>6.1f} {wire:>5.2f} {ovhd_pct:>5.1f}% {drain_pct:>5.1f}%"
f" {eff_bw:>8.2f} {_fmt_bw(bn_bw):>8} {_fmt_util(eff_bw, bn_bw):>6}")
# The invariant needs at least two executed cases to compare.
if len(h2d_results) >= 2:
lats = [r[2] for r in h2d_results]
# Latency must strictly increase from each case to the next.
mono = all(lats[i] < lats[i + 1] for i in range(len(lats) - 1))
sym = "[v]" if mono else "[x]"
print(f" {sym} Monotonic increase: {'PASS' if mono else 'FAIL'}")
if h2d_paths:
print()
print(" Route Details:")
print(f" {'Case':<14} {'Leg':>4} Path")
print(" " + "-" * 80)
for name, leg1, leg2, leg3 in h2d_paths:
print(f" {name:<14} {'L1':>4} {_short_path(leg1)}")
print(f" {'':<14} {'L2':>4} {_short_path(leg2)}")
print(f" {'':<14} {'L3':>4} {_short_path(leg3)}")
# === PE DMA → HBM (direct PE-level injection) ===
# (name, sip, src_cube, src_pe, dst_cube, dst_pe)
pe_cases = [
("pe-local-hbm", 0, 0, 0, 0, 0), # pe0 → slice0 (local, 256 GB/s)
("pe-same-half-hbm", 0, 0, 0, 0, 1), # pe0 → slice1 (xbar chain, 128 GB/s)
("pe-cross-half-hbm", 0, 0, 0, 0, 4), # pe0 → slice4 (xbar chain, 128 GB/s)
("pe-cross-cube-hbm", 0, 0, 0, 1, 0), # cube0.pe0 → cube1.slice0 (NOC, 128 GB/s)
]
# (name, total_ns, eff_bw, bn_bw) per executed case.
pe_results: list[tuple[str, float, float, float | None]] = []
pe_paths: list[tuple[str, list[str]]] = []
print()
print("=== PE DMA Latency (pe_dma -> xbar -> HBM, direct injection) ===")
print(f" {'Case':<22} {'Target':<28} {'Actual':>8}"
f" {'Ovhd':>6} {'Drain':>6} {'Wire':>5} {'Ovhd%':>6} {'Drain%':>7}"
f" {'Eff.BW':>8} {'BN.BW':>8} {'Util%':>6}")
print(" " + "-" * 120)
for name, sip, src_cube, src_pe, dst_cube, dst_pe in pe_cases:
if not show_all and case_filter != name:
continue
engine = GraphEngine(graph)
# dst_pe selects which PE's HBM slice the destination address falls in.
dst_pa = _hbm_pa(sip=sip, cube=dst_cube, pe_id=dst_pe, spec=spec)
msg = PeDmaMsg(
correlation_id="probe", request_id=name,
src_sip=sip, src_cube=src_cube, src_pe=src_pe,
dst_pa=dst_pa, nbytes=nbytes,
)
h = engine.submit(msg)
engine.wait(h)
_, trace = engine.get_completion(h)
total_ns = trace["total_ns"]
eff_bw = nbytes / total_ns if total_ns > 0 else 0.0
# Route from the source PE reference to the node that owns the destination address.
pe_ref = f"sip{sip}.cube{src_cube}.pe{src_pe}"
pa_obj = PhysAddr.decode(dst_pa)
dst_node = resolver.resolve(pa_obj)
dma_path = router.find_path(pe_ref, dst_node)
bn_bw = _bottleneck_bw(dma_path, edge_map)
# NOTE(review): `formula` is not read afterwards.
wire, ovhd, drain, formula = _formula_breakdown(dma_path, nbytes, edge_map, graph)
ovhd_pct = ovhd / total_ns * 100 if total_ns > 0 else 0
drain_pct = drain / total_ns * 100 if total_ns > 0 else 0
target_str = f"c{src_cube}.pe{src_pe}->c{dst_cube}.slice{dst_pe}"
pe_results.append((name, total_ns, eff_bw, bn_bw))
pe_paths.append((name, dma_path))
print(f" {name:<22} {target_str:<28} {total_ns:>8.2f}"
f" {ovhd:>6.1f} {drain:>6.1f} {wire:>5.2f} {ovhd_pct:>5.1f}% {drain_pct:>5.1f}%"
f" {eff_bw:>8.2f} {_fmt_bw(bn_bw):>8} {_fmt_util(eff_bw, bn_bw):>6}")
if len(pe_results) >= 2:
# Cases are split by whether their name contains "local".
local = [r for r in pe_results if "local" in r[0]]
chain = [r for r in pe_results if "local" not in r[0]]
if local and chain:
# Only the first case of each group is reported.
print(f" * Local BN: {_fmt_bw(local[0][3])} GB/s, "
f"Chain/NOC BN: {_fmt_bw(chain[0][3])} GB/s")
if pe_paths:
print()
print(" Route Details:")
print(f" {'Case':<22} Path")
print(" " + "-" * 80)
for name, dma_path in pe_paths:
print(f" {name:<22} {_short_path(dma_path)}")
print()
return 0
def cmd_probe(args) -> int:
    """CLI handler for ``kernbench probe``.

    Forwards ``args.topology`` and ``args.case`` to ``run_probe``; a namespace
    without a ``case`` attribute is treated as ``"all"``.
    """
    topology_path = args.topology
    case_name = getattr(args, "case", "all")
    return run_probe(topology_path, case_name)
+175
View File
@@ -0,0 +1,175 @@
"""Performance report formatter for bench results."""
from __future__ import annotations
# Bit width of each supported dtype spelling, grouped by width and kind.
_DTYPE_BITS: dict[str, int] = {
    alias: bits
    for bits, aliases in (
        (16, ("f16", "fp16", "float16", "bf16")),
        (32, ("f32", "fp32", "float32")),
        (8, ("i8", "int8")),
        (16, ("i16", "int16")),
        (32, ("i32", "int32")),
    )
    for alias in aliases
}
def format_report(
traces: list[dict],
title: str = "Benchmark",
spec: dict | None = None,
) -> str:
"""Format collected traces into a human-readable performance report.

The report consists of a title line, a command table for every trace whose
``phase`` is not ``"kernel"``, a kernel table for traces whose ``phase`` is
``"kernel"``, and a per-PE summary that is emitted only when the non-kernel
traces cover more than one ``pe`` value.

Args:
traces: Trace dicts. Missing keys fall back to the defaults passed to
``dict.get`` below.
title: Text placed in the title line.
spec: topology spec dict for peak TFLOPS / BW extraction.

Returns:
The report lines joined with newline characters.
"""
peak_tflops_f16, peak_hbm_bw_gbs = _extract_peaks(spec)
num_pes = _count_pes(spec)
lines: list[str] = []
title_line = f"-- {title} Performance Report "
# Everything that is not a kernel trace is listed in the command table.
deploy_entries = [t for t in traces if t.get("phase") not in ("kernel",)]
kernel_entries = [t for t in traces if t.get("phase") == "kernel"]
# ── Title ──
# Compute max header width for consistent separator lengths
_cmd_hdr = (f"{'Cmd':<10} {'Name':<12} {'SIP':>4} {'Cube':>5} {'PE':>4} {'Bytes':>10} "
f"{'Lat(ns)':>10} {'Xfer(ns)':>10} {'Proc(ns)':>10} "
f"{'BW(GB/s)':>10} {'MinBW':>10} {'Util%':>7}")
report_width = len(_cmd_hdr)
# Pad the title with dashes up to the command-table width (no padding if it is already wider).
lines.append(title_line + "-" * max(0, report_width - len(title_line)))
# ── Command summary ──
if deploy_entries:
lines.append("")
# Same text as _cmd_hdr above.
hdr = (f"{'Cmd':<10} {'Name':<12} {'SIP':>4} {'Cube':>5} {'PE':>4} {'Bytes':>10} "
f"{'Lat(ns)':>10} {'Xfer(ns)':>10} {'Proc(ns)':>10} "
f"{'BW(GB/s)':>10} {'MinBW':>10} {'Util%':>7}")
lines.append(hdr)
lines.append("-" * len(hdr))
for e in deploy_entries:
lat = e.get("total_ns", 0.0)
nb = e.get("nbytes", 0)
sip = e.get("sip", "-")
pe = e.get("pe", "-")
cube = e.get("cube", "-")
cmd = e.get("phase", "deploy")
xfer_ns = e.get("xfer_ns", 0.0)
# Processing time is whatever part of the latency is not transfer time;
# it is reported as 0 when no transfer time was recorded.
proc_ns = lat - xfer_ns if xfer_ns > 0 else 0.0
# bytes / ns for the whole command and for the transfer portion alone.
bw = nb / lat if lat > 0 else 0.0
min_bw = nb / xfer_ns if xfer_ns > 0 else 0.0
# Share of the total latency spent transferring, in percent.
util = (xfer_ns / lat * 100) if lat > 0 and xfer_ns > 0 else 0.0
lines.append(
f"{cmd:<10} {e.get('name', '?'):<12} {str(sip):>4} {str(cube):>5} {str(pe):>4} {nb:>10} "
f"{lat:>10.1f} {xfer_ns:>10.1f} {proc_ns:>10.1f} "
f"{bw:>10.1f} {min_bw:>10.1f} {util:>6.1f}%"
)
# ── Kernel summary ──
if kernel_entries:
lines.append("")
k_hdr = (f"{'Phase':<10} {'Name':<12} {'PE':>4} {'E2E(ns)':>10} "
f"{'PE(ns)':>10} {'DMA(ns)':>10} {'Comp(ns)':>10} "
f"{'Bound':<8} {'TFLOPS':>8} {'Peak':>8} {'Util%':>7}")
lines.append(k_hdr)
lines.append("-" * len(k_hdr))
for e in kernel_entries:
e2e_ns = e.get("total_ns", 0.0)
# PE execution time falls back to the end-to-end time when absent.
pe_ns = e.get("pe_exec_ns", e2e_ns)
dma_ns = e.get("dma_ns", 0.0)
compute_ns = e.get("compute_ns", 0.0)
target_pe = e.get("target_pe", "-")
scalars = e.get("scalars", [])
pe_str = "all" if target_pe == "all" else str(target_pe)
# A kernel targeting "all" is scored against every PE counted from the spec.
n_active = num_pes if target_pe == "all" else 1
# Bound indicator based on measured DMA vs compute time
if dma_ns > 0 or compute_ns > 0:
bound = "memory" if dma_ns >= compute_ns else "compute"
else:
bound = "-"
achieved = _calc_tflops(scalars, pe_ns)
peak_total = peak_tflops_f16 * n_active
util = (achieved / peak_total * 100) if peak_total > 0 else 0.0
lines.append(
f"{'kernel':<10} {e.get('name', '?'):<12} {pe_str:>4} {e2e_ns:>10.1f} "
f"{pe_ns:>10.1f} {dma_ns:>10.1f} {compute_ns:>10.1f} "
f"{bound:<8} {achieved:>8.3f} {peak_total:>8.1f} {util:>6.1f}%"
)
# ── Per-PE summary ──
pe_deploy = _per_pe_deploy(deploy_entries)
# Only shown when more than one PE key appears in the non-kernel traces.
if len(pe_deploy) > 1:
lines.append("")
pe_title = (f"-- Per-PE Summary (peak: {peak_tflops_f16:.1f} TFLOPS/PE, "
f"{peak_hbm_bw_gbs:.0f} GB/s HBM BW) ")
pe_hdr = (f"{'PE':>4} {'Deploy(ns)':>10} {'BW(GB/s)':>10} {'BW Util':>8} "
f"{'Kernel(ns)':>10} {'TFLOPS':>8} {'Util':>7}")
pe_width = max(len(pe_title), len(pe_hdr))
lines.append(pe_title + "-" * max(0, pe_width - len(pe_title)))
lines.append(pe_hdr)
lines.append("-" * pe_width)
# Kernel time is summed over all kernel traces, while the scalars are
# taken from the first kernel trace only.
k_ns = sum(e.get("pe_exec_ns", e.get("total_ns", 0.0)) for e in kernel_entries)
k_scalars = kernel_entries[0].get("scalars", []) if kernel_entries else []
n_active = len(pe_deploy)
total_achieved = _calc_tflops(k_scalars, k_ns)
# The achieved TFLOPS figure is split evenly across the PEs seen in deploy traces.
per_pe_tflops = total_achieved / n_active if n_active > 0 else 0.0
pe_util = (per_pe_tflops / peak_tflops_f16 * 100) if peak_tflops_f16 > 0 else 0.0
for pe_id in sorted(pe_deploy):
d_ns, d_bytes = pe_deploy[pe_id]
d_bw = d_bytes / d_ns if d_ns > 0 else 0.0
d_util = (d_bw / peak_hbm_bw_gbs * 100) if peak_hbm_bw_gbs > 0 else 0.0
# The kernel columns carry the same values on every row.
lines.append(
f"{pe_id:>4} {d_ns:>10.1f} {d_bw:>10.1f} {d_util:>7.1f}% "
f"{k_ns:>10.1f} {per_pe_tflops:>8.3f} {pe_util:>6.1f}%"
)
lines.append("")
return "\n".join(lines)
def _extract_peaks(spec: dict | None) -> tuple[float, float]:
    """Extract peak TFLOPS (f16) and HBM BW (GB/s) from spec.

    The TFLOPS figure is read from
    ``cube.pe_template.components.pe_gemm.attrs.peak_tflops_f16`` and the
    bandwidth from ``cube.links.xbar_to_hbm_bw_gbs``. Any missing level, or a
    ``None`` spec, yields 0.0 for the affected value.
    """
    if spec is None:
        return 0.0, 0.0

    cube = spec.get("cube", {})
    gemm_attrs = (
        cube.get("pe_template", {})
        .get("components", {})
        .get("pe_gemm", {})
        .get("attrs", {})
    )
    peak_tflops = float(gemm_attrs.get("peak_tflops_f16", 0.0))
    hbm_bw = float(cube.get("links", {}).get("xbar_to_hbm_bw_gbs", 0.0))
    return peak_tflops, hbm_bw
def _count_pes(spec: dict | None) -> int:
    """Return the PE count per cube: ``pe_per_corner`` times the number of corners.

    Values come from ``spec["cube"]["pe_layout"]``; the defaults are 2 PEs per
    corner and four corners. A ``None`` spec returns 8.
    """
    if spec is None:
        return 8

    layout = spec.get("cube", {}).get("pe_layout", {})
    corner_names = layout.get("corners", ["NW", "NE", "SW", "SE"])
    return layout.get("pe_per_corner", 2) * len(corner_names)
def _calc_tflops(scalars: list, latency_ns: float) -> float:
    """Calculate achieved TFLOPS from scalar args [M, K, N] and latency.

    The operation count is ``2 * M * K * N``. Returns 0.0 when fewer than
    three scalars are given or the latency is not positive.
    """
    if len(scalars) < 3 or latency_ns <= 0:
        return 0.0

    m, k, n = scalars[:3]
    seconds = latency_ns * 1e-9
    return 2.0 * m * k * n / seconds / 1e12
def _per_pe_deploy(deploy_entries: list[dict]) -> dict[int, tuple[float, int]]:
    """Aggregate deploy latency and bytes per PE.

    Entries are grouped by their ``pe`` value (0 when absent). The result maps
    each PE to ``(summed total_ns, summed nbytes)``.
    """
    totals: dict[int, tuple[float, int]] = {}
    for entry in deploy_entries:
        pe_key = entry.get("pe", 0)
        latency = entry.get("total_ns", 0.0)
        size = entry.get("nbytes", 0)
        seen = totals.get(pe_key)
        if seen is None:
            totals[pe_key] = (latency, size)
        else:
            totals[pe_key] = (seen[0] + latency, seen[1] + size)
    return totals
View File
+150
View File
@@ -0,0 +1,150 @@
"""PE-internal command types and handles (ADR-0014).
Generated by triton_emu (TLContext) and consumed by PE component
implementations (PE_CPU, PE_SCHEDULER, PE_DMA, PE_GEMM, PE_MATH).
Command lifecycle:
Triton kernel → TLContext → [PeCommand list] → PE_CPU → PE_SCHEDULER → engines
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
import simpy
# ── Handles ───────────────────────────────────────────────────────
@dataclass(frozen=True)
class TensorHandle:
    """Opaque reference to a tensor residing in PE_TCM.

    Returned by tl.load, tl.dot, tl.exp, etc.
    Carries metadata for command generation; the ``data`` field is reserved
    for a future validate mode (numpy array).

    Equality and hashing are based on the metadata fields only. ``data`` is
    excluded so that a handle stays hashable and ``==`` stays a plain bool
    even when the payload is an unhashable object or an array whose
    comparison is element-wise.
    """

    id: str
    pa: int  # physical address in HBM/TCM
    shape: tuple[int, ...]
    dtype: str
    nbytes: int  # total byte size
    # Reserved for validate mode; not part of the handle's identity.
    data: object = field(default=None, compare=False)
@dataclass(frozen=True)
class CompletionHandle:
    """Opaque token identifying a non-blocking composite command.

    Produced by tl.composite and later passed to tl.wait.
    """

    id: str  # identifier of the composite command this handle stands for
# ── PE Commands ───────────────────────────────────────────────────
@dataclass(frozen=True)
class DmaReadCmd:
    """DMA READ command: copy from HBM into PE_TCM."""

    handle: TensorHandle  # tensor handle associated with this transfer
    src_pa: int  # source physical address
    nbytes: int  # number of bytes to read
@dataclass(frozen=True)
class DmaWriteCmd:
    """DMA WRITE command: copy from PE_TCM out to HBM."""

    handle: TensorHandle  # tensor handle associated with this transfer
    dst_pa: int  # destination physical address
    nbytes: int  # number of bytes to write
@dataclass(frozen=True)
class GemmCmd:
    """Matrix-multiply command for the GEMM engine.

    Computes ``out = a @ b`` with every operand resident in TCM.
    """

    # Operands and result.
    a: TensorHandle
    b: TensorHandle
    out: TensorHandle
    # Problem dimensions carried with the command.
    m: int
    k: int
    n: int
@dataclass(frozen=True)
class MathCmd:
    """Command for the MATH engine: unary, binary or reduction op on TCM data.

    Documented values for ``op``:
        unary:     "exp", "log", "sqrt", "abs", "sigmoid", "cos", "sin"
        binary:    "add", "sub", "mul", "div", "where"
        reduction: "sum", "max", "min"
    """

    op: str
    inputs: tuple[TensorHandle, ...]
    out: TensorHandle
    axis: int | None = None  # only meaningful for reductions
@dataclass(frozen=True)
class CompositeCmd:
    """Tiled pipeline command combining DMA_READ, COMPUTE and DMA_WRITE.

    The command is non-blocking: it is submitted to PE_SCHEDULER, which takes
    care of tile splitting and pipeline overlap (ADR-0014 D3.2).
    """

    completion: CompletionHandle
    op: Literal["gemm", "math"]
    a: TensorHandle
    b: TensorHandle | None
    out_pa: int
    out_nbytes: int
    math_op: str | None = None  # selects the math operation when op == "math"
@dataclass(frozen=True)
class WaitCmd:
    """Block until one composite, or every pending composite, has finished."""

    # Leaving this as None means "wait for all pending composites".
    handle: CompletionHandle | None = None
@dataclass(frozen=True)
class PeCpuOverheadCmd:
    """Scalar execution overhead on PE_CPU, expressed in cycles."""

    cycles: int  # overhead in cycles
# Union type for all PE commands.
# Built with the ``|`` operator at import time, so it is a runtime object and
# not only an annotation.
PeCommand = (
DmaReadCmd | DmaWriteCmd | GemmCmd | MathCmd
| CompositeCmd | WaitCmd | PeCpuOverheadCmd
)
@dataclass
class PeInternalTxn:
    """Message passed inside a PE: PE_CPU → PE_SCHEDULER → engines.

    Each instance wraps exactly one PeCommand together with its completion
    event. During the replay phase PE_CPU builds one PeInternalTxn per command
    and hands it to PE_SCHEDULER, which forwards it to the matching engine
    (PE_DMA, PE_GEMM, PE_MATH). The engine signals ``done`` when it finishes.
    """

    command: PeCommand
    # Succeeded by the engine once this command has completed.
    done: simpy.Event
    # e.g. "sip0.cube0.pe0" — PE_DMA needs it for path resolution.
    pe_prefix: str = ""
    # Fresh dict per instance.
    result_data: dict[str, Any] = field(default_factory=dict)
+29
View File
@@ -0,0 +1,29 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, NewType, Protocol, TypeAlias
# Distinct static type for request handles; declared on top of ``str``.
RequestHandle = NewType("RequestHandle", str)
# Trace payloads are left unconstrained: the alias is simply ``Any``.
Trace: TypeAlias = Any
@dataclass(frozen=True)
class Completion:
    """Immutable outcome of a request: a success flag plus optional error details."""

    ok: bool
    # Both error fields default to None.
    error_code: str | None = None
    error_message: str | None = None
class SimEngine(Protocol):
    """Contract every backend simulation/runner engine has to satisfy.

    An engine must be able to:
    - accept requests created by RuntimeContext (submit/dispatch)
    - report completion and optional trace for a given handle
    """

    def submit(self, request: Any) -> RequestHandle:
        """Accept *request* and return the handle that identifies it."""
        ...

    def wait(self, handle: RequestHandle) -> None:
        """Wait on the request identified by *handle*."""
        ...

    def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
        """Return the completion and, if available, the trace for *handle*."""
        ...
+4
View File
@@ -0,0 +1,4 @@
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.context import ComponentContext
# Public names re-exported by this package.
__all__ = ["ComponentBase", "ComponentRegistry", "ComponentContext"]
+167
View File
@@ -0,0 +1,167 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class ComponentBase(ABC):
    """Common base for every SimPy component implementation (ADR-0007 D3, ADR-0015).

    One component instance stands for one node of the compiled topology
    graph. Its processing overhead is expressed as a SimPy generator, which
    leaves room for later implementations to add queueing and contention.

    Port model (ADR-0015 D1):
        in_ports[src_node_id]  — SimPy Store for incoming messages from src
        out_ports[dst_node_id] — SimPy Store for outgoing messages to dst

    GraphEngine wires the ports at initialization; wire processes model the
    propagation delay between connected ports (ADR-0015 D2).

    Context (ADR-0015 D4):
        ctx — ComponentContext with router and resolver.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        self.node = node
        self.ctx = ctx
        self.in_ports: dict[str, simpy.Store] = {}
        self.out_ports: dict[str, simpy.Store] = {}

    def start(self, env: simpy.Environment) -> None:
        """Launch this component's processes; called once after port wiring.

        The default spawns one fan-in relay per input port plus a generic
        forwarding worker. The worker applies self.run() for the per-component
        latency and then either routes the Transaction to its next hop or
        signals done (duck-typed; Transaction is not imported here to avoid
        circular dependencies). A component without input ports starts nothing.

        Override in components that need custom fan-out / aggregation logic
        (e.g. MCpuComponent, IoCpuComponent for kernel launch).
        """
        if self.in_ports:
            self._inbox: simpy.Store = simpy.Store(env)
            for source_port in self.in_ports.values():
                env.process(self._fan_in(source_port))
            env.process(self._worker(env))

    def _fan_in(self, port: simpy.Store) -> Generator:
        """Move every message arriving on *port* into the shared inbox."""
        while True:
            incoming = yield port.get()
            yield self._inbox.put(incoming)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Generic forwarding worker: one _forward_txn process per message (pipeline)."""
        while True:
            txn: Any = yield self._inbox.get()
            env.process(self._forward_txn(env, txn))

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Apply run() latency, then forward to the next hop or drain at the terminal."""
        yield from self.run(env, txn.nbytes)

        next_hop = txn.next_hop  # duck-typed: Transaction.next_hop
        if next_hop:
            yield self.out_ports[next_hop].put(txn.advance())
            return

        # Terminal node: wait out the drain time (if any), then signal completion.
        drain_ns = getattr(txn, "drain_ns", 0.0)
        if drain_ns > 0:
            yield env.timeout(drain_ns)
        txn.done.succeed()

    @abstractmethod
    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """SimPy process body: yield one or more events for this node's processing.

        Subclasses yield env.timeout(overhead_ns) or compute latency dynamically.
        Called by _forward_txn and by subclass-specific handlers.
        """
        ...
class PeEngineBase(ComponentBase):
    """Shared base for the PE-internal engines (PE_DMA, PE_GEMM, PE_MATH).

    Provides:
    - ``_pe_prefix``: taken from node.id (e.g. "sip0.cube0.pe0")
    - Dual-message ``_worker``: PeInternalTxn goes to ``handle_command()``,
      anything else goes to the inherited ``_forward_txn()``.
    - ``init_resources(env)``: hook for subclass resource initialization,
      invoked by ``start()`` before the worker is spawned.

    Subclass contract:
    1. Override ``handle_command(env, pe_txn)`` — process a PeInternalTxn.
    2. Override ``run(env, nbytes)`` — yield component latency.
    3. Optionally override ``init_resources(env)`` for DMA channels, etc.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Everything before the last "." of the node id (the whole id if it has no dot).
        prefix, *_ = node.id.rsplit(".", 1)
        self._pe_prefix: str = prefix

    def start(self, env: simpy.Environment) -> None:
        self.init_resources(env)
        super().start(env)

    def init_resources(self, env: simpy.Environment) -> None:
        """Hook for subclass resource initialization. Called before worker spawn."""

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dual-message dispatch: PeInternalTxn → handle_command, Transaction → _forward_txn."""
        from kernbench.common.pe_commands import PeInternalTxn

        while True:
            msg: Any = yield self._inbox.get()
            if isinstance(msg, PeInternalTxn):
                handler = self.handle_command
            else:
                handler = self._forward_txn
            env.process(handler(env, msg))

    @abstractmethod
    def handle_command(self, env: simpy.Environment, pe_txn: Any) -> Generator:
        """Process a PE-internal command (PeInternalTxn).

        Subclass must:
        - Perform engine-specific work (acquire resources, compute, etc.)
        - Call ``pe_txn.done.succeed()`` on completion.
        """
        ...
class ComponentRegistry:
    """DI registry mapping node.impl strings to ComponentBase subclasses.

    Resolution order for ComponentRegistry.create(node, overrides, ctx):
    1. overrides[node.impl] — caller-injected override
    2. _registry[node.impl] — globally registered impl
    3. Error — no fallback; every node must have an impl
    """

    _registry: dict[str, type[ComponentBase]] = {}

    @classmethod
    def register(cls, impl: str, component_cls: type[ComponentBase]) -> None:
        """Register (or replace) the component class used for *impl*."""
        cls._registry[impl] = component_cls

    @classmethod
    def create(
        cls,
        node: Node,
        overrides: dict[str, type[ComponentBase]] | None = None,
        ctx: ComponentContext | None = None,
    ) -> ComponentBase:
        """Instantiate the component for *node*.

        Raises:
            ValueError: when neither *overrides* nor the global registry has
                an entry for ``node.impl``.
        """
        # Caller overrides take precedence over the global registry.
        for table in (overrides, cls._registry):
            if table and node.impl in table:
                component_cls = table[node.impl]
                return component_cls(node, ctx)

        raise ValueError(
            f"No component registered for impl '{node.impl}' (node: {node.id}). "
            f"Register it in kernbench.components.impls.__init__."
        )
+52
View File
@@ -0,0 +1,52 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import simpy
from kernbench.policy.routing.router import AddressResolver, PathRouter
@dataclass
class ComponentContext:
    """Topology services handed to every component implementation.

    Needed by components that route or resolve addresses (IoCpuComponent,
    MCpuComponent, …); TransitComponent ignores ctx.
    Supplied through ComponentRegistry.create(node, overrides, ctx=ctx).
    """

    router: PathRouter
    resolver: AddressResolver
    positions: dict[str, tuple[float, float] | None]  # node_id → pos_mm
    ns_per_mm: float  # wire propagation constant (from topology spec)
    edge_map: dict[tuple[str, str], Any] = field(default_factory=dict)
    spec: dict = field(default_factory=dict)  # topology spec (cube layout, PE count, etc.)

    def get_shared_resource(
        self, env: simpy.Environment, key: str, capacity: int = 1,
    ) -> simpy.Resource:
        """Return the shared SimPy Resource stored under *key*, creating it if needed.

        Intended for PE components that share one resource across engines of
        the same PE (e.g. accel_slot shared by PE_GEMM and PE_MATH).
        Keys should be scoped per PE, e.g. "sip0.cube0.pe0.accel_slot".
        *capacity* is only used when the resource is first created.
        """
        # The pool is created lazily on first use.
        try:
            pool = self._shared_resources
        except AttributeError:
            pool = self._shared_resources = {}

        if key not in pool:
            pool[key] = simpy.Resource(env, capacity=capacity)
        return pool[key]

    def compute_drain_ns(self, path: list[str], nbytes: int) -> float:
        """Wormhole drain time: nbytes / bottleneck_bw along path.

        Hops without an edge, or whose edge has no truthy ``bw_gbs``, are
        skipped. Returns 0.0 when no hop contributes a bandwidth.
        """
        unbounded = float("inf")
        bottleneck = unbounded
        for hop in zip(path, path[1:]):
            edge = self.edge_map.get(hop)
            if not edge:
                continue
            bw = getattr(edge, "bw_gbs", None)
            if bw:
                bottleneck = min(bottleneck, bw)
        return 0.0 if bottleneck == unbounded else nbytes / bottleneck
@@ -0,0 +1,54 @@
"""Concrete component implementations.
Each module registers its component(s) with ComponentRegistry on import.
Import this package to activate all built-in implementations.
"""
from kernbench.components.base import ComponentRegistry
from kernbench.components.impls.forwarding import TransitComponent
from kernbench.components.impls.hbm_ctrl import HbmCtrlComponent
from kernbench.components.impls.io_cpu import IoCpuComponent
from kernbench.components.impls.m_cpu import MCpuComponent
from kernbench.components.impls.noc import TwoDMeshNocComponent
from kernbench.components.impls.pcie_ep import PcieEpComponent
from kernbench.components.impls.pe_cpu import PeCpuComponent
from kernbench.components.impls.pe_dma import PeDmaComponent
from kernbench.components.impls.pe_gemm import PeGemmComponent
from kernbench.components.impls.pe_math import PeMathComponent
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
from kernbench.components.impls.pe_tcm import PeTcmComponent
from kernbench.components.impls.sram import SramComponent
# Register every built-in implementation under its impl key.
# These impl keys are all served by the generic TransitComponent.
ComponentRegistry.register("forwarding_v1", TransitComponent)
ComponentRegistry.register("switch_v1", TransitComponent)
ComponentRegistry.register("noc_v1", TransitComponent)
# Only the 2D-mesh NOC key maps to a dedicated NOC component class.
ComponentRegistry.register("noc_2d_mesh_v1", TwoDMeshNocComponent)
ComponentRegistry.register("ucie_v1", TransitComponent)
ComponentRegistry.register("xbar_v1", TransitComponent)
# Host-side and memory components.
ComponentRegistry.register("pcie_ep_v1", PcieEpComponent)
ComponentRegistry.register("io_cpu_v1", IoCpuComponent)
ComponentRegistry.register("m_cpu_v1", MCpuComponent)
ComponentRegistry.register("hbm_ctrl_v1", HbmCtrlComponent)
ComponentRegistry.register("sram_v1", SramComponent)
# PE-internal components.
ComponentRegistry.register("pe_cpu_v1", PeCpuComponent)
ComponentRegistry.register("pe_scheduler_v1", PeSchedulerComponent)
ComponentRegistry.register("pe_dma_v1", PeDmaComponent)
ComponentRegistry.register("pe_gemm_v1", PeGemmComponent)
ComponentRegistry.register("pe_math_v1", PeMathComponent)
ComponentRegistry.register("pe_tcm_v1", PeTcmComponent)
# Component classes re-exported by this package.
__all__ = [
"HbmCtrlComponent",
"IoCpuComponent",
"MCpuComponent",
"PcieEpComponent",
"PeCpuComponent",
"PeDmaComponent",
"PeGemmComponent",
"PeMathComponent",
"PeSchedulerComponent",
"PeTcmComponent",
"TransitComponent",
"TwoDMeshNocComponent",
"SramComponent",
]
@@ -0,0 +1,27 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING
import simpy
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class TransitComponent(ComponentBase):
    """Pass-through component used for NOC, UCIe and XBAR nodes.

    It waits for the node's ``overhead_ns`` (read from node.attrs, 0.0 when
    absent); forwarding of the Transaction to the next hop is done by the
    inherited _forward_txn().
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Yield a single timeout equal to the node's processing overhead."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)
+101
View File
@@ -0,0 +1,101 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class HbmCtrlComponent(ComponentBase):
    """Terminal HBM controller that models access latency on two channels.

    Reads and writes are arbitrated by two independent ``simpy.Resource``
    objects, so a read and a write can be in flight together while requests
    of the same kind queue behind each other.  Both channels take their
    capacity from ``node.attrs["capacity"]`` (1 when absent).

    When an access finishes, the controller either completes the transaction
    in place (``PeDmaMsg``) or emits a ``ResponseMsg`` along the reversed
    request path so the response also travels through the fabric.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Both channels are created in start(), once an environment exists.
        self._read: simpy.Resource | None = None
        self._write: simpy.Resource | None = None

    def start(self, env: simpy.Environment) -> None:
        """Create the read and write channels, then start the base component."""
        slots = int(self.node.attrs.get("capacity", 1))
        self._read = simpy.Resource(env, capacity=slots)
        self._write = simpy.Resource(env, capacity=slots)
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for ``node.attrs["overhead_ns"]`` (0.0 when absent)."""
        delay_ns = self.node.attrs.get("overhead_ns", 0.0)
        yield env.timeout(float(delay_ns))

    def _select_channel(self, txn: Any) -> simpy.Resource:
        """Return the write channel for write requests, the read channel otherwise.

        A request counts as a write when it is a ``MemoryWriteMsg`` or a
        ``PeDmaMsg`` whose ``is_write`` flag is truthy.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg

        assert self._read is not None and self._write is not None
        request = txn.request
        wants_write = isinstance(request, MemoryWriteMsg) or (
            isinstance(request, PeDmaMsg) and request.is_write
        )
        return self._write if wants_write else self._read

    def _worker(self, env: simpy.Environment) -> Generator:
        """Start one process per incoming transaction so the channels can overlap."""
        while True:
            incoming: Any = yield self._inbox.get()
            env.process(self._handle_txn(env, incoming))

    def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Hold the selected channel for the access and drain time, then respond."""
        with self._select_channel(txn).request() as grant:
            yield grant
            yield from self.run(env, txn.nbytes)
            drain_ns = getattr(txn, "drain_ns", 0.0)
            if drain_ns > 0:
                yield env.timeout(drain_ns)
        yield from self._send_response(env, txn)

    def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Complete ``txn`` directly or send a ``ResponseMsg`` back along its path.

        ``PeDmaMsg`` requests are completed by succeeding ``txn.done``.  The
        same happens when the reversed path has fewer than two nodes or no
        context is attached.  Otherwise a response Transaction is put on the
        out port of the first node of the reversed path.
        """
        from kernbench.runtime_api.kernel import PeDmaMsg, ResponseMsg

        if isinstance(txn.request, PeDmaMsg):
            txn.done.succeed()
            return

        return_route = list(reversed(txn.path))
        if len(return_route) < 2 or not self.ctx:
            txn.done.succeed()
            return

        # The node id is read as "<sip>.cube<N>.<ctrl>.slice<M>": field 1 is
        # the cube, field 3 the slice, which is reported as the source PE.
        id_fields = self.node.id.split(".")
        cube_id = int(id_fields[1].replace("cube", ""))
        pe_id = int(id_fields[3].replace("slice", ""))
        response = ResponseMsg(
            correlation_id=txn.request.correlation_id,
            request_id=txn.request.request_id,
            src_cube=cube_id,
            src_pe=pe_id,
            success=True,
        )
        response_txn = Transaction(
            request=response,
            path=return_route,
            step=0,
            nbytes=0,
            done=env.event(),
            is_response=True,
        )
        yield self.out_ports[return_route[1]].put(response_txn.advance())
+145
View File
@@ -0,0 +1,145 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class IoCpuComponent(ComponentBase):
    """IO_CPU component: multi-cube fan-out with response aggregation.

    Forward path:
        1. Applies the ``overhead_ns`` processing delay from ``node.attrs``.
        2. Resolves the target (sip, cube) pairs for the request.
        3. Sends one sub-Transaction towards each target cube's M_CPU.

    Response path:
        Counts response transactions per ``request_id``.  Once as many have
        arrived as sub-Transactions were actually dispatched, the parent
        transaction's ``done`` event is succeeded.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Pending fan-out tracking: request_id → (expected, received, parent_txn_done)
        self._pending: dict[str, tuple[int, int, simpy.Event]] = {}

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for ``node.attrs["overhead_ns"]`` (0.0 when absent)."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Count responses inline; apply overhead to requests and fan them out."""
        while True:
            txn: Any = yield self._inbox.get()
            if getattr(txn, "is_response", False):
                self._collect_response(txn)
            else:
                yield from self.run(env, txn.nbytes)
                # The fan-out runs as its own process so the worker can keep
                # draining the inbox.
                env.process(self._dispatch_to_m_cpus(env, txn))

    def _collect_response(self, resp_txn: Any) -> None:
        """Record one cube response; complete the parent when all have arrived.

        Responses whose ``request_id`` is not being tracked are ignored.
        """
        key = resp_txn.request.request_id
        if key not in self._pending:
            return
        expected, received, parent_done = self._pending[key]
        received += 1
        if received >= expected:
            parent_done.succeed()
            del self._pending[key]
        else:
            self._pending[key] = (expected, received, parent_done)

    def _dispatch_to_m_cpus(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out sub-Transactions to the target cubes' M_CPUs.

        Completion of ``txn`` is driven by ``_collect_response``.  The
        expected response count is the number of sub-Transactions that can
        actually be routed: targets whose M_CPU lookup or path search fails
        (or yields a path shorter than two nodes) are skipped and must not be
        waited for, otherwise ``txn.done`` would never fire.  If nothing can
        be dispatched, ``txn.done`` is succeeded immediately.
        """
        request = txn.request
        try:
            cube_targets = self._resolve_cube_targets(request)
        except Exception:
            # Best effort: an unresolvable request is completed, not crashed on.
            txn.done.succeed()
            return

        # Resolve every route before registering the aggregation entry so the
        # expected count matches what is really sent.
        routes: list[Any] = []
        for sip, cube in cube_targets:
            try:
                m_cpu_id = self.ctx.resolver.find_m_cpu(sip, cube)
                path = self.ctx.router.find_node_path(self.node.id, m_cpu_id)
            except Exception:
                continue
            if len(path) >= 2:
                routes.append(path)

        if not routes:
            txn.done.succeed()
            return

        # Registered before the first put so an early response is counted.
        self._pending[request.request_id] = (len(routes), 0, txn.done)

        for path in routes:
            sub_txn = Transaction(
                request=request, path=path, step=0,
                nbytes=txn.nbytes, done=env.event(),
                result_data=txn.result_data,
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())

    def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]:
        """Return the (sip, cube) pairs to fan out to.

        ``request.target_cubes`` defaults to ``"all"``.  For memory
        reads/writes, ``"all"`` means the single cube decoded from the
        physical address (falling back to the cube field on the message).
        For kernel launches, ``"all"`` means every distinct cube on this
        IO_CPU's SIP that holds a shard of a tensor argument.  Any other
        value is iterated as an explicit list of cube ids.  Unknown request
        types yield an empty list.
        """
        from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg

        target_cubes = getattr(request, "target_cubes", "all")

        if isinstance(request, MemoryWriteMsg):
            sip = request.dst_sip
            if target_cubes == "all":
                cube = self._cube_from_pa(request.dst_pa, fallback=request.dst_cube)
                return [(sip, cube)]
            return [(sip, c) for c in target_cubes]

        if isinstance(request, MemoryReadMsg):
            sip = request.src_sip
            if target_cubes == "all":
                cube = self._cube_from_pa(request.src_pa, fallback=request.src_cube)
                return [(sip, cube)]
            return [(sip, c) for c in target_cubes]

        if isinstance(request, KernelLaunchMsg):
            my_sip = self._my_sip()
            if target_cubes != "all":
                return [(my_sip, c) for c in target_cubes]
            # "all": derive from tensor shards, filtered to this SIP,
            # keeping first-seen order.
            seen: set[tuple[int, int]] = set()
            targets: list[tuple[int, int]] = []
            for arg in request.args:
                if arg.arg_kind != "tensor":
                    continue
                for shard in arg.shards:
                    if shard.sip != my_sip:
                        continue
                    key = (shard.sip, shard.cube)
                    if key not in seen:
                        seen.add(key)
                        targets.append(key)
            return targets

        return []

    def _cube_from_pa(self, pa_val: int, fallback: int) -> int:
        """Return the cube id decoded from ``pa_val``, or ``fallback`` if decoding fails."""
        from kernbench.policy.address.phyaddr import PhysAddr

        try:
            return PhysAddr.decode(pa_val).cube_id
        except Exception:
            return fallback

    def _my_sip(self) -> int:
        """Extract this IO_CPU's SIP ID from its node ID (e.g. 'sip0.io0.io_cpu' → 0)."""
        return int(self.node.id.split(".")[0].replace("sip", ""))
+269
View File
@@ -0,0 +1,269 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class MCpuComponent(ComponentBase):
    """M_CPU component: multi-PE DMA fan-out with response aggregation.

    Forward path (ADR-0015 D5):
        When a forward Transaction arrives at m_cpu as its terminal hop,
        M_CPU fans out sub-Transactions.  Kernel launches go to PE_CPUs; all
        other requests become DMA sub-Transactions to HBM slices.
        ``target_pe`` on the request controls fan-out: int → single PE,
        anything else ("all") → PA-based or whole-cube resolution.

    Response path:
        Responses from the HBM controllers arrive back at m_cpu.  Once all
        dispatched sub-Transactions have answered, m_cpu sends one aggregate
        ResponseMsg on the reverse command path.

    Transit:
        When m_cpu is NOT the terminal hop, a forward Transaction is simply
        passed to its next hop.
    """

    # Used when the context carries no spec or the spec omits the slice count.
    _DEFAULT_SLICES_PER_CUBE = 8

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Pending fan-out tracking: request_id → (expected, received, all_done_event)
        self._pending: dict[str, tuple[int, int, simpy.Event]] = {}
        # Parent txn of each in-flight DMA fan-out: request_id → parent_txn
        self._parent_txns: dict[str, Any] = {}
        # DMA engine resources (ADR-0015 D5, ADR-0014 D4): capacity=1 each
        self._dma_write: simpy.Resource | None = None
        self._dma_read: simpy.Resource | None = None

    def start(self, env: simpy.Environment) -> None:
        """Create the two DMA engine resources, then start the base component."""
        self._dma_write = simpy.Resource(env, capacity=1)
        self._dma_read = simpy.Resource(env, capacity=1)
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for ``node.attrs["overhead_ns"]`` (0.0 when absent)."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch forward txns, collect response txns."""
        from kernbench.runtime_api.kernel import KernelLaunchMsg

        while True:
            txn: Any = yield self._inbox.get()
            if getattr(txn, "is_response", False):
                self._collect_response(txn)
                continue

            yield from self.run(env, txn.nbytes)
            next_hop = txn.next_hop
            if next_hop:
                # Transit: not the terminal hop, pass the txn along.
                yield self.out_ports[next_hop].put(txn.advance())
            elif self.ctx is not None and txn.request is not None:
                # Terminal hop: fan out in a separate process so the worker
                # keeps receiving the responses that fan-out waits for.
                if isinstance(txn.request, KernelLaunchMsg):
                    env.process(self._kernel_launch_fanout(env, txn))
                else:
                    env.process(self._dma_fanout(env, txn))
            else:
                txn.done.succeed()

    def _collect_response(self, resp_txn: Any) -> None:
        """Record one response; fire the fan-out's event when all have arrived.

        Responses whose ``request_id`` is not being tracked are ignored.
        """
        key = resp_txn.request.request_id
        if key not in self._pending:
            return
        expected, received, all_done = self._pending[key]
        received += 1
        if received >= expected:
            all_done.succeed()
            del self._pending[key]
        else:
            self._pending[key] = (expected, received, all_done)

    def _dma_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out DMA sub-Transactions, wait for responses, then respond upstream.

        Each sub-Transaction is issued while holding the read or write DMA
        resource (capacity=1 per ADR-0014 D4), so issue is serialized.

        The number of responses waited for equals the number of
        sub-Transactions that can actually be routed.  Destinations whose
        path lookup fails or is shorter than two nodes are skipped and never
        answer; counting them would leave this process blocked forever.  If
        no destination is routable, ``txn.done`` is succeeded immediately,
        as for an empty destination list.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg

        request = txn.request
        target_pe = getattr(request, "target_pe", "all")
        dst_nodes = self._resolve_dma_destinations(request, target_pe)

        # Plan every transfer up front so the expected count is exact.
        plans: list[tuple[Any, float]] = []
        for dst_node in dst_nodes:
            try:
                dma_path = self.ctx.router.find_mcpu_dma_path(self.node.id, dst_node)
            except Exception:
                continue
            if len(dma_path) < 2:
                continue
            plans.append((dma_path, self.ctx.compute_drain_ns(dma_path, txn.nbytes)))

        if not plans:
            txn.done.succeed()
            return

        # Setup aggregation before the first put so early responses count.
        all_done = env.event()
        self._pending[request.request_id] = (len(plans), 0, all_done)
        self._parent_txns[request.request_id] = txn

        # Select DMA resource based on operation type
        dma_res = self._dma_write if isinstance(request, MemoryWriteMsg) else self._dma_read

        max_drain_ns = 0.0
        for dma_path, drain_ns in plans:
            max_drain_ns = max(max_drain_ns, drain_ns)
            sub_txn = Transaction(
                request=request, path=dma_path, step=0,
                nbytes=txn.nbytes, done=env.event(),
                drain_ns=drain_ns,
            )
            with dma_res.request() as req:
                yield req
                yield self.out_ports[dma_path[1]].put(sub_txn.advance())

        # Wait for all dispatched sub-Transactions to be answered
        yield all_done
        txn.result_data["xfer_ns"] = max_drain_ns
        self._parent_txns.pop(request.request_id, None)

        yield from self._send_aggregate_response(env, txn)

    def _kernel_launch_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out KernelLaunchMsg to target PE_CPU(s) via NOC (ADR-0009 D3).

        Routes through ``find_node_path``.  PE_CPUs that cannot be routed to
        are skipped.  Completion is tracked via each sub-Transaction's own
        ``done`` event (no ResponseMsg in the PE direction).  Afterwards the
        per-PE metrics are reduced with ``max`` and one aggregate
        ResponseMsg is sent on the reverse path.
        """
        request = txn.request
        target_pe = getattr(request, "target_pe", "all")
        cube_prefix = self.node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0"

        # (sub_txn, done_event) for every PE_CPU actually reached.
        dispatched: list[tuple[Transaction, simpy.Event]] = []
        for pe_id in self._resolve_pe_ids(target_pe):
            pe_cpu_id = f"{cube_prefix}.pe{pe_id}.pe_cpu"
            try:
                path = self.ctx.router.find_node_path(self.node.id, pe_cpu_id)
            except Exception:
                continue
            if len(path) < 2:
                continue
            sub_done = env.event()
            sub_txn = Transaction(
                request=request, path=path, step=0,
                nbytes=0, done=sub_done,
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())
            dispatched.append((sub_txn, sub_done))

        if not dispatched:
            txn.done.succeed()
            return

        # Wait for all PE_CPUs to complete
        for _, sub_done in dispatched:
            yield sub_done

        # Aggregate PE-internal metrics (max across PEs; missing values count as 0.0)
        for metric in ("pe_exec_ns", "dma_ns", "compute_ns"):
            txn.result_data[metric] = max(
                sub_txn.result_data.get(metric, 0.0) for sub_txn, _ in dispatched
            )

        yield from self._send_aggregate_response(env, txn)

    def _send_aggregate_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Send one ResponseMsg for ``txn`` back along its reversed path.

        If the reversed path has fewer than two nodes there is nowhere to
        send a response, so ``txn.done`` is succeeded instead.
        """
        reverse_path = list(reversed(txn.path))
        if len(reverse_path) < 2:
            txn.done.succeed()
            return

        from kernbench.runtime_api.kernel import ResponseMsg

        request = txn.request
        # Node id is read as "<sip>.cube<N>.<name>": field 1 carries the cube.
        cube_id = int(self.node.id.split(".")[1].replace("cube", ""))
        resp_msg = ResponseMsg(
            correlation_id=request.correlation_id,
            request_id=request.request_id,
            src_cube=cube_id, src_pe=-1, success=True,
        )
        resp_txn = Transaction(
            request=resp_msg, path=reverse_path, step=0,
            nbytes=0, done=env.event(), is_response=True,
        )
        yield self.out_ports[reverse_path[1]].put(resp_txn.advance())

    def _slices_per_cube(self) -> int:
        """Return ``spec["cube"]["memory_map"]["hbm_slices_per_cube"]``, or the default."""
        if self.ctx and self.ctx.spec:
            mm = self.ctx.spec.get("cube", {}).get("memory_map", {})
            return mm.get("hbm_slices_per_cube", self._DEFAULT_SLICES_PER_CUBE)
        return self._DEFAULT_SLICES_PER_CUBE

    def _resolve_dma_destinations(self, request: Any, target_pe: int | str) -> list[str]:
        """Return list of HBM destination node_ids for DMA fan-out.

        Resolution order:
            1. An integer ``target_pe`` selects that slice in the local cube.
            2. Otherwise, if the request carries a physical address
               (``dst_pa``, else ``src_pa``) that decodes and resolves, the
               single resolved node is used; it may be in a remote cube.
            3. Otherwise every slice of the local cube is targeted.
        """
        cube_prefix = self.node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0"

        if isinstance(target_pe, int):
            return [f"{cube_prefix}.hbm_ctrl.slice{target_pe}"]

        # NOTE(review): `or` treats a physical address of 0 as "not set", so
        # such a request falls back to src_pa or to the all-slices case —
        # confirm whether PA 0 is a valid target.
        pa_val = getattr(request, "dst_pa", None) or getattr(request, "src_pa", None)
        if pa_val is not None:
            from kernbench.policy.address.phyaddr import PhysAddr

            try:
                pa = PhysAddr.decode(pa_val)
                return [self.ctx.resolver.resolve(pa)]
            except Exception:
                # Best effort: an unresolvable PA falls through to all slices.
                pass

        return [f"{cube_prefix}.hbm_ctrl.slice{i}" for i in range(self._slices_per_cube())]

    def _resolve_pe_ids(self, target_pe: int | str) -> list[int]:
        """Return list of PE IDs to fan out to (used by kernel launch fan-out).

        An integer selects that PE; anything else selects one PE per HBM
        slice of the local cube.
        """
        if isinstance(target_pe, int):
            return [target_pe]
        return list(range(self._slices_per_cube()))
+187
View File
@@ -0,0 +1,187 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class TwoDMeshNocComponent(ComponentBase):
    """Single-node model of a 2D mesh NOC.

    Latency:
        A traversal is routed X-first, then Y, between the positions of the
        previous and the next hop.  Each grid segment costs
        ``dist_mm * ctx.ns_per_mm``; ``overhead_ns`` (from ``node.attrs``)
        is charged once, on the first segment.

    Contention:
        Every directed segment is a lazily created ``simpy.Resource`` with
        capacity 1.  The request for the following segment is issued before
        the current segment's delay elapses, so an idle downstream segment
        is already granted when the transaction gets there.

    Concurrency:
        The worker starts one SimPy process per incoming transaction, so
        transactions queue only on shared segments, never on the node.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._env: simpy.Environment | None = None
        # Directed-segment key → resource, filled on first use.
        self._links: dict[tuple, simpy.Resource] = {}
        # Sorted, de-duplicated coordinates of the nodes in this cube.
        self._x_grid: list[float] = []
        self._y_grid: list[float] = []

    def start(self, env: simpy.Environment) -> None:
        """Remember the environment, build the coordinate grid, start the base."""
        self._env = env
        self._build_grid()
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Zero-delay placeholder; the latency is modeled in ``_route``."""
        yield env.timeout(0)

    # ── Grid construction ────────────────────────────────────────────

    def _build_grid(self) -> None:
        """Collect the rounded x/y coordinates of every positioned node in this cube."""
        if not self.ctx:
            return
        prefix = self.node.id.rsplit(".", 1)[0] + "."
        placed = [
            pos
            for node_id, pos in self.ctx.positions.items()
            if node_id.startswith(prefix) and pos is not None
        ]
        self._x_grid = sorted({round(pos[0], 2) for pos in placed})
        self._y_grid = sorted({round(pos[1], 2) for pos in placed})

    def _get_link(self, key: tuple) -> simpy.Resource:
        """Return the resource for segment ``key``, creating it on first use."""
        link = self._links.get(key)
        if link is None:
            assert self._env is not None
            link = simpy.Resource(self._env, capacity=1)
            self._links[key] = link
        return link

    # ── Worker ───────────────────────────────────────────────────────

    def _worker(self, env: simpy.Environment) -> Generator:
        """Hand every incoming transaction to its own routing process."""
        while True:
            incoming: Any = yield self._inbox.get()
            env.process(self._route(env, incoming))

    def _route(self, env: simpy.Environment, txn: Any) -> Generator:
        """Carry one transaction across the mesh, then forward or complete it.

        Without usable positions for both neighbours (or without a context)
        only ``overhead_ns`` is charged.  A transaction with no next hop
        waits for its ``drain_ns`` (if positive) and is then completed.
        """
        came_from = txn.path[txn.step - 1] if txn.step > 0 else None
        going_to = txn.next_hop
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))

        segments: list[tuple[tuple, float]] = []
        if came_from and going_to and self.ctx:
            entry_pos = self.ctx.positions.get(came_from)
            exit_pos = self.ctx.positions.get(going_to)
            if entry_pos and exit_pos:
                segments = self._xy_links(entry_pos, exit_pos)

        if segments:
            yield from self._traverse(env, segments, overhead_ns)
        else:
            yield env.timeout(overhead_ns)

        if going_to:
            yield self.out_ports[going_to].put(txn.advance())
            return

        drain_ns = getattr(txn, "drain_ns", 0.0)
        if drain_ns > 0:
            yield env.timeout(drain_ns)
        txn.done.succeed()

    # ── XY routing and pipelined link traversal ──────────────────────

    def _traverse(
        self,
        env: simpy.Environment,
        links: list[tuple[tuple, float]],
        overhead_ns: float,
    ) -> Generator:
        """Walk ``links`` in order, requesting each next segment one step early."""
        ns_per_mm = self.ctx.ns_per_mm  # type: ignore[union-attr]
        final_idx = len(links) - 1

        # The first segment has to be owned before any time is spent on it.
        held_link = self._get_link(links[0][0])
        held_req = held_link.request()
        yield held_req

        for idx, (_key, length_mm) in enumerate(links):
            more_to_come = idx < final_idx
            if more_to_come:
                # Queue for the next segment now; the request may be granted
                # while this segment's delay is still running.
                upcoming_link = self._get_link(links[idx + 1][0])
                upcoming_req = upcoming_link.request()

            yield env.timeout(length_mm * ns_per_mm + (overhead_ns if idx == 0 else 0.0))
            held_link.release(held_req)

            if more_to_come:
                yield upcoming_req  # usually already granted
                held_link = upcoming_link
                held_req = upcoming_req

    @staticmethod
    def _directed_links(
        axis: str,
        band: float,
        pairs: list[tuple[float, float]],
        forward: str,
        backward: str,
    ) -> list[tuple[tuple, float]]:
        """Turn (start, end) coordinate pairs along one axis into keyed segments.

        The key is ``(axis, band, low, high, direction)`` with coordinates
        rounded to two decimals; ``direction`` is ``forward`` when the end
        coordinate is the larger one, else ``backward``.  Pairs no longer
        than 1e-9 are dropped.
        """
        found: list[tuple[tuple, float]] = []
        for start, end in pairs:
            span = abs(end - start)
            if span > 1e-9:
                if start < end:
                    low, high = start, end
                else:
                    low, high = end, start
                heading = forward if end > start else backward
                found.append(
                    ((axis, round(band, 2), round(low, 2), round(high, 2), heading), span)
                )
        return found

    def _xy_links(
        self,
        src: tuple[float, float],
        dst: tuple[float, float],
    ) -> list[tuple[tuple, float]]:
        """XY routing: horizontal run first, then the vertical run.

        Returns ``(link_key, dist_mm)`` pairs.  The horizontal run is keyed
        on the grid row nearest to the source's y; the vertical run on the
        grid column nearest to the destination's x.
        """
        src_x, src_y = src
        dst_x, dst_y = dst
        route: list[tuple[tuple, float]] = []

        # Horizontal run at y≈src_y
        if abs(src_x - dst_x) > 1e-9:
            row = self._snap(src_y, self._y_grid)
            hops = self._segments(src_x, dst_x, self._x_grid)
            route.extend(self._directed_links("H", row, hops, "E", "W"))

        # Vertical run at x≈dst_x
        if abs(src_y - dst_y) > 1e-9:
            column = self._snap(dst_x, self._x_grid)
            hops = self._segments(src_y, dst_y, self._y_grid)
            route.extend(self._directed_links("V", column, hops, "S", "N"))

        return route

    @staticmethod
    def _snap(val: float, grid: list[float]) -> float:
        """Return the grid value closest to ``val`` (``val`` itself for an empty grid)."""
        return min(grid, key=lambda g: abs(g - val)) if grid else val

    @staticmethod
    def _segments(a: float, b: float, grid: list[float]) -> list[tuple[float, float]]:
        """Split the span from ``a`` to ``b`` at the grid points strictly inside it.

        Returns consecutive (from, to) pairs in travel order, i.e. starting
        at ``a`` and ending at ``b``; empty when the two are within 1e-9.
        """
        if abs(a - b) < 1e-9:
            return []
        low, high = (a, b) if a < b else (b, a)
        interior = [g for g in grid if low + 1e-9 < g < high - 1e-9]
        waypoints = [low, *interior, high]
        ascending = list(zip(waypoints, waypoints[1:]))
        if a > b:
            # Travelling downwards: reverse both the order and each pair.
            return [(end, start) for start, end in reversed(ascending)]
        return ascending
+27
View File
@@ -0,0 +1,27 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING
import simpy
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PcieEpComponent(ComponentBase):
    """PCIe endpoint hop that charges a protocol-handling delay.

    The delay is the node's ``overhead_ns`` attribute; moving the
    Transaction on to its next hop is left to the inherited
    ``_forward_txn()``.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for ``node.attrs["overhead_ns"]`` (0.0 when absent).

        ``nbytes`` is accepted for interface compatibility and is not used.
        """
        delay_ns = self.node.attrs.get("overhead_ns", 0.0)
        yield env.timeout(float(delay_ns))
+154
View File
@@ -0,0 +1,154 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeCpuComponent(ComponentBase):
    """PE_CPU: kernel execution controller (Stage 2).

    A kernel launch is handled in two phases (ADR-0014 D1):
        Phase 1 (compile): fetch the kernel from the registry and run it
            against a ``TLContext``, which records a list of PE commands.
        Phase 2 (replay): walk the recorded commands, sending them to this
            PE's PE_SCHEDULER as ``PeInternalTxn`` objects and waiting
            where the command requires it.

    Any other Transaction goes through the inherited ``_forward_txn()``.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._pe_prefix = node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0.pe0"
        id_parts = node.id.split(".")
        # Indices used to match shards to this PE; each is 0 if unparsable.
        self._pe_idx = self._index_or_zero(lambda: self._pe_prefix.rsplit("pe", 1)[1])
        self._sip_idx = self._index_or_zero(lambda: id_parts[0].replace("sip", ""))
        self._cube_idx = self._index_or_zero(lambda: id_parts[1].replace("cube", ""))

    @staticmethod
    def _index_or_zero(read_text: Any) -> int:
        """Return ``int(read_text())``, or 0 on ``IndexError``/``ValueError``."""
        try:
            return int(read_text())
        except (IndexError, ValueError):
            return 0

    def _find_shard(self, shards: tuple) -> Any:
        """Return the shard placed on this PE's (sip, cube, pe).

        When no shard matches, the shard at position ``pe_idx`` is returned,
        clamped to the last one.
        """
        mine = (self._sip_idx, self._cube_idx, self._pe_idx)
        for candidate in shards:
            if (candidate.sip, candidate.cube, candidate.pe) == mine:
                return candidate
        return shards[min(self._pe_idx, len(shards) - 1)]

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for ``node.attrs["overhead_ns"]`` (0.0 when absent)."""
        delay_ns = self.node.attrs.get("overhead_ns", 0.0)
        yield env.timeout(float(delay_ns))

    def _worker(self, env: simpy.Environment) -> Generator:
        """Execute kernel launches; forward every other transaction."""
        while True:
            incoming: Any = yield self._inbox.get()
            from kernbench.runtime_api.kernel import KernelLaunchMsg

            is_launch = hasattr(incoming, "request") and isinstance(
                incoming.request, KernelLaunchMsg
            )
            handler = self._execute_kernel if is_launch else self._forward_txn
            yield from handler(env, incoming)

    def _execute_kernel(self, env: simpy.Environment, txn: Any) -> Generator:
        """Compile the kernel into a command list, replay it, then complete ``txn``.

        On completion ``txn.result_data`` holds ``pe_exec_ns`` (simulated
        time spent in the replay phase) plus ``dma_ns`` and ``compute_ns``
        summed over the result data of all composite commands.
        """
        from kernbench.common.pe_commands import (
            CompositeCmd,
            PeCpuOverheadCmd,
            PeInternalTxn,
            WaitCmd,
        )
        from kernbench.triton_emu.registry import get_kernel
        from kernbench.triton_emu.tl_context import TLContext, run_kernel

        launch = txn.request

        # ── Phase 1: compile ────────────────────────────────────────
        # Setup overhead first, then let the kernel record its commands.
        yield from self.run(env, 0)
        kernel_fn = get_kernel(launch.kernel_ref.name)
        tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)

        # Tensor arguments become the PA of this PE's shard, scalars their
        # value; arguments of any other kind are not passed.
        positional: list = []
        for arg in launch.args:
            kind = arg.arg_kind
            if kind == "tensor":
                positional.append(self._find_shard(arg.shards).pa)
            elif kind == "scalar":
                positional.append(arg.value)

        run_kernel(kernel_fn, tl, *positional)

        # ── Phase 2: replay ─────────────────────────────────────────
        replay_started = env.now
        scheduler_port = f"{self._pe_prefix}.pe_scheduler"
        in_flight: dict[str, simpy.Event] = {}  # completion_id → done event
        composite_data: list[dict] = []  # result_data of each composite txn

        for cmd in tl.commands:
            if isinstance(cmd, PeCpuOverheadCmd):
                yield env.timeout(cmd.cycles)
                continue

            if isinstance(cmd, WaitCmd):
                if cmd.handle is not None:
                    # Wait for one specific completion, if still tracked.
                    evt = in_flight.pop(cmd.handle.id, None)
                    if evt:
                        yield evt
                else:
                    # No handle: wait for everything outstanding.
                    for evt in in_flight.values():
                        yield evt
                    in_flight.clear()
                continue

            # Everything else is sent to the scheduler.  Composite commands
            # are tracked and waited for later; all others block here.
            finished = env.event()
            pe_txn = PeInternalTxn(
                command=cmd, done=finished,
                pe_prefix=self._pe_prefix,
            )
            tracked = isinstance(cmd, CompositeCmd)
            if tracked:
                composite_data.append(pe_txn.result_data)
            yield self.out_ports[scheduler_port].put(pe_txn)
            if tracked:
                in_flight[cmd.completion.id] = finished
            else:
                yield finished

        # Anything still outstanding must finish before the kernel is done.
        for evt in in_flight.values():
            yield evt

        results = txn.result_data
        results["pe_exec_ns"] = env.now - replay_started

        dma_total = 0.0
        compute_total = 0.0
        for data in composite_data:
            dma_total += data.get("dma_ns", 0.0)
            compute_total += data.get("compute_ns", 0.0)
        results["dma_ns"] = dma_total
        results["compute_ns"] = compute_total

        # Signal original Transaction done
        txn.done.succeed()
+116
View File
@@ -0,0 +1,116 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import PeEngineBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeDmaComponent(PeEngineBase):
    """PE_DMA: dual-channel DMA engine with READ and WRITE resources.

    Each channel has capacity=1 (ADR-0014 D4):
        - DMA_READ and DMA_WRITE may execute concurrently.
        - Multiple READs cannot overlap; multiple WRITEs cannot overlap.

    Handles two message types:
        - Transaction: external fabric messages (PeDmaMsg probes, M_CPU DMA)
        - PeInternalTxn: PE-internal commands from PE_SCHEDULER
          (DmaReadCmd → HBM read, DmaWriteCmd → HBM write)
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Created in init_resources(), once an environment is available.
        self._dma_read: simpy.Resource | None = None
        self._dma_write: simpy.Resource | None = None

    def init_resources(self, env: simpy.Environment) -> None:
        """Create the read and write channel resources (capacity 1 each)."""
        self._dma_read = simpy.Resource(env, capacity=1)
        self._dma_write = simpy.Resource(env, capacity=1)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Zero-delay placeholder; DMA timing comes from the fabric path."""
        yield env.timeout(0)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Handle PE-internal DMA command: resolve PA → HBM path → transfer.

        The matching channel is held only while the sub-Transaction is
        issued; the wait for completion happens after it is released.
        ``pe_txn.done`` is succeeded when the transfer completes.  It is
        succeeded immediately for commands that are neither ``DmaReadCmd``
        nor ``DmaWriteCmd``, and when the resolved path has fewer than two
        nodes (nothing would be sent, so nothing could ever complete the
        sub-Transaction).
        """
        from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
        from kernbench.policy.address.phyaddr import PhysAddr
        from kernbench.runtime_api.kernel import PeDmaMsg

        cmd = pe_txn.command
        assert self._dma_read is not None and self._dma_write is not None

        # Determine direction and target PA
        if isinstance(cmd, DmaReadCmd):
            dma_res = self._dma_read
            target_pa = cmd.src_pa
            is_write = False
        elif isinstance(cmd, DmaWriteCmd):
            dma_res = self._dma_write
            target_pa = cmd.dst_pa
            is_write = True
        else:
            pe_txn.done.succeed()
            return

        # Resolve PA → HBM node and compute path
        pa = PhysAddr.decode(target_pa)
        dst_node = self.ctx.resolver.resolve(pa)
        path = self.ctx.router.find_path(self._pe_prefix, dst_node)
        if len(path) < 2:
            # No next hop: waiting on a sub-Transaction that is never sent
            # would block this command forever.
            pe_txn.done.succeed()
            return
        drain_ns = self.ctx.compute_drain_ns(path, cmd.nbytes)

        sub_done = env.event()
        # Acquire DMA channel (command issue serialization)
        with dma_res.request() as req:
            yield req
            # PeDmaMsg lets the HBM controller complete sub_done directly.
            sub_request = PeDmaMsg(
                correlation_id="pe_internal",
                request_id=f"dma_{id(pe_txn)}",
                src_sip=0, src_cube=0, src_pe=0,
                dst_pa=target_pa, nbytes=cmd.nbytes,
                is_write=is_write,
            )
            sub_txn = Transaction(
                request=sub_request, path=path, step=0,
                nbytes=cmd.nbytes, done=sub_done, drain_ns=drain_ns,
            )
            # path[0] is the origin; the sub-Transaction is put on the port
            # towards path[1].
            yield self.out_ports[path[1]].put(sub_txn.advance())
        # DMA channel released after issue

        # Wait for HBM transfer completion
        yield sub_done
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Handle external Transaction (PeDmaMsg probe, M_CPU DMA) with channel acquisition.

        With a next hop the transaction is forwarded; otherwise this node is
        the end of the path, so the drain time (if positive) is waited out
        and the transaction completed.
        """
        dma_res = self._select_channel(txn)
        with dma_res.request() as req:
            yield req
            next_hop = txn.next_hop
            if next_hop:
                yield self.out_ports[next_hop].put(txn.advance())
            else:
                drain = getattr(txn, "drain_ns", 0.0)
                if drain > 0:
                    yield env.timeout(drain)
                txn.done.succeed()

    def _select_channel(self, txn: Any) -> simpy.Resource:
        """Select DMA channel based on request type.

        ``MemoryWriteMsg`` and ``PeDmaMsg`` with ``is_write`` set use the
        write channel (the same rule ``HbmCtrlComponent._select_channel``
        applies); everything else uses the read channel.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg

        assert self._dma_read is not None and self._dma_write is not None
        request = txn.request
        if isinstance(request, MemoryWriteMsg):
            return self._dma_write
        if isinstance(request, PeDmaMsg) and request.is_write:
            return self._dma_write
        return self._dma_read
+90
View File
@@ -0,0 +1,90 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import PeEngineBase
if TYPE_CHECKING:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
# Element bit width per dtype name, including aliases (for TFLOPS scaling).
_DTYPE_BITS: dict[str, int] = {
    "f16": 16, "fp16": 16, "float16": 16, "bf16": 16,
    "f32": 32, "fp32": 32, "float32": 32,
    "i8": 8, "int8": 8,
    "i16": 16, "int16": 16,
    "i32": 32, "int32": 32,
}
class PeGemmComponent(PeEngineBase):
    """PE_GEMM: matrix multiplication engine sharing accel_slot (ADR-0014 D4).

    When the node names a ``shared_resource``, work is done while holding
    that shared resource (PE_ACCEL, capacity=1), which makes it mutually
    exclusive with PE_MATH in the same PE.

    Compute latency model:
        FLOPs            = 2 * M * K * N
        effective_tflops = peak_tflops_f16 * (16 / dtype_bits)
        compute_ns       = FLOPs / (effective_tflops * 1e3)
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Looked up in init_resources(); stays None if none is configured.
        self._accel: simpy.Resource | None = None
        self._peak_tflops_f16: float = float(node.attrs.get("peak_tflops_f16", 0.0))

    def init_resources(self, env: simpy.Environment) -> None:
        """Fetch the shared accelerator resource named by ``shared_resource``."""
        shared_name = self.node.attrs.get("shared_resource")
        if not (shared_name and self.ctx):
            return
        self._accel = self.ctx.get_shared_resource(
            env, f"{self._pe_prefix}.{shared_name}"
        )

    def _compute_ns(self, m: int, k: int, n: int, dtype: str) -> float:
        """Return the GEMM latency in nanoseconds for an (m, k, n) problem.

        Without a positive ``peak_tflops_f16`` the node's ``overhead_ns`` is
        returned.  Dtypes missing from ``_DTYPE_BITS`` are treated as 16-bit.
        """
        peak = self._peak_tflops_f16
        if peak <= 0:
            return float(self.node.attrs.get("overhead_ns", 0.0))
        bits = _DTYPE_BITS.get(dtype, 16)
        effective_tflops = peak * (16.0 / bits)
        flops = 2.0 * m * k * n
        return flops / (effective_tflops * 1e3)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for ``node.attrs["overhead_ns"]`` (0.0 when absent)."""
        delay_ns = self.node.attrs.get("overhead_ns", 0.0)
        yield env.timeout(float(delay_ns))

    def _run_command(self, env: simpy.Environment, cmd: Any) -> Generator:
        """Spend the modeled GEMM time for a ``GemmCmd``, else the plain overhead."""
        from kernbench.common.pe_commands import GemmCmd

        if isinstance(cmd, GemmCmd):
            yield env.timeout(self._compute_ns(cmd.m, cmd.k, cmd.n, cmd.a.dtype))
        else:
            yield from self.run(env, 0)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Execute a PE-internal command, then succeed ``pe_txn.done``.

        The shared accelerator slot, when present, is held for the duration
        of the command.
        """
        cmd = pe_txn.command
        if self._accel:
            with self._accel.request() as grant:
                yield grant
                yield from self._run_command(env, cmd)
        else:
            yield from self._run_command(env, cmd)
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Forward a Transaction, holding the shared accelerator slot if there is one."""
        if not self._accel:
            yield from super()._forward_txn(env, txn)
            return
        with self._accel.request() as grant:
            yield grant
            yield from super()._forward_txn(env, txn)
+54
View File
@@ -0,0 +1,54 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import PeEngineBase
if TYPE_CHECKING:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeMathComponent(PeEngineBase):
    """PE_MATH: element-wise compute engine that shares the accel slot (ADR-0014 D4).

    The engine acquires a shared compute resource (PE_ACCEL, capacity=1), which
    makes it mutually exclusive with PE_GEMM inside the same PE.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Bound in init_resources(); stays None when the node names no shared
        # resource or when no component context is available.
        self._accel: simpy.Resource | None = None

    def init_resources(self, env: simpy.Environment) -> None:
        """Look up the shared accel resource named by the ``shared_resource`` attr."""
        shared_name = self.node.attrs.get("shared_resource")
        if not (shared_name and self.ctx):
            return
        self._accel = self.ctx.get_shared_resource(
            env, f"{self._pe_prefix}.{shared_name}"
        )

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's fixed ``overhead_ns``; ``nbytes`` is not used."""
        yield env.timeout(float(self.node.attrs.get("overhead_ns", 0.0)))

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Wait for the fixed overhead, then succeed ``pe_txn.done``.

        When an accel resource is bound, it is held for the duration of the wait.
        """
        if not self._accel:
            yield from self.run(env, 0)
        else:
            with self._accel.request() as slot:
                yield slot
                yield from self.run(env, 0)
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Forward ``txn`` via the base class, holding the accel slot if one is bound."""
        if not self._accel:
            yield from super()._forward_txn(env, txn)
            return
        with self._accel.request() as slot:
            yield slot
            yield from super()._forward_txn(env, txn)
@@ -0,0 +1,245 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeSchedulerComponent(ComponentBase):
    """PE_SCHEDULER: the single dispatcher inside a PE (ADR-0014 D1).

    PeInternalTxn messages taken from the inbox are routed by command type:
      - DmaReadCmd / DmaWriteCmd -> PE_DMA
      - GemmCmd                  -> PE_GEMM
      - MathCmd                  -> PE_MATH
      - CompositeCmd             -> tiled pipeline (Stage 3: ADR-0014 D3.2)

    Composite GEMM pipeline, per tile (TILE_K x TILE_N slices of b):
      DMA_READ(b_tile_t) -> COMPUTE(t) -> DMA_WRITE(out_tile_t)
    The write of tile t is left in flight and is only awaited after the
    compute of tile t+1 has been submitted.

    The scheduler's overhead_ns is applied before each command is dispatched.
    Any message that is not a PeInternalTxn is forwarded through the inherited
    _forward_txn().
    """

    # Scheduler tile dimensions (ADR-0014 D3.2)
    TILE_M = 32
    TILE_K = 64
    TILE_N = 32

    # Command type -> engine node suffix. Filled on first construction because
    # the command classes are imported lazily. To support a new engine, add one
    # entry in _ensure_dispatch_table (e.g. ConvCmd: "pe_conv").
    _CMD_DISPATCH: dict[type, str] = {}

    @classmethod
    def _ensure_dispatch_table(cls) -> None:
        """Populate ``_CMD_DISPATCH`` once; later calls are no-ops."""
        if not cls._CMD_DISPATCH:
            from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd

            cls._CMD_DISPATCH = {
                DmaReadCmd: "pe_dma",
                DmaWriteCmd: "pe_dma",
                GemmCmd: "pe_gemm",
                MathCmd: "pe_math",
            }

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Everything before the last "." of the node id names the owning PE.
        self._pe_prefix = node.id.rsplit(".", 1)[0]
        self._ensure_dispatch_table()

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's fixed ``overhead_ns``; ``nbytes`` is not used."""
        yield env.timeout(float(self.node.attrs.get("overhead_ns", 0.0)))

    def _worker(self, env: simpy.Environment) -> Generator:
        """Inbox loop: spawn a dispatch process per PeInternalTxn, forward the rest."""
        from kernbench.common.pe_commands import PeInternalTxn

        while True:
            msg: Any = yield self._inbox.get()
            if not isinstance(msg, PeInternalTxn):
                # Forwarding runs inline, so the next inbox read waits for it.
                yield from self._forward_txn(env, msg)
                continue
            # Each command gets its own process; the loop does not wait for it.
            env.process(self._dispatch(env, msg))

    def _dispatch(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Send ``pe_txn`` to the engine selected by its command type."""
        from kernbench.common.pe_commands import CompositeCmd

        # Scheduler overhead is paid before any routing decision.
        yield from self.run(env, 0)

        cmd = pe_txn.command
        suffix = self._CMD_DISPATCH.get(type(cmd))
        if suffix is not None:
            # Simple command: hand the transaction to the engine's port.
            yield self.out_ports[f"{self._pe_prefix}.{suffix}"].put(pe_txn)
        elif isinstance(cmd, CompositeCmd):
            # Composite command: run the tiled pipeline instead of forwarding.
            yield from self._dispatch_composite(env, pe_txn)
        else:
            # Unknown command: report completion immediately.
            pe_txn.done.succeed()

    def _dispatch_composite(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Run the composite pipeline that matches the command (ADR-0014 D3.2).

        A "gemm" op that carries a ``b`` operand uses the tiled GEMM pipeline;
        every other composite uses the sequential math pipeline.
        """
        from kernbench.common.pe_commands import CompositeCmd

        cmd = pe_txn.command
        assert isinstance(cmd, CompositeCmd)
        wants_gemm = cmd.op == "gemm" and cmd.b is not None
        pipeline = self._pipeline_gemm if wants_gemm else self._pipeline_math
        yield from pipeline(env, pe_txn, cmd)

    def _pipeline_gemm(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
        """Tiled GEMM: read each b tile, compute on it, write the output tile.

        Tensor a is in TCM (loaded via tl.load). Tensor b is in HBM (via tl.ref).
        Tiles cover the K x N extent of b in TILE_K x TILE_N steps. The summed
        wait times are stored in ``pe_txn.result_data`` under "dma_ns" and
        "compute_ns" before ``pe_txn.done`` is succeeded.
        """
        from kernbench.common.pe_commands import (
            DmaReadCmd,
            DmaWriteCmd,
            GemmCmd,
            PeInternalTxn as PeTxn,
            TensorHandle,
        )

        prefix = self._pe_prefix
        lhs = cmd.a  # already in TCM
        rhs = cmd.b  # HBM reference (via tl.ref)
        rows, k_lhs = lhs.shape[-2], lhs.shape[-1]
        k_rhs, cols = rhs.shape[-2], rhs.shape[-1]
        dtype = lhs.dtype
        # Element size is derived from b's byte size; 2 is used for an empty b.
        rhs_elems = k_rhs * cols
        elem_bytes = rhs.nbytes // rhs_elems if rhs_elems > 0 else 2

        # Tile counts (always at least one tile per axis).
        k_tiles = max(1, (k_lhs + self.TILE_K - 1) // self.TILE_K)
        col_tiles = max(1, (cols + self.TILE_N - 1) // self.TILE_N)

        pending_compute = None
        pending_write = None
        dma_ns = 0.0
        compute_ns = 0.0

        for tile_idx in range(k_tiles * col_tiles):
            k_idx, col_idx = divmod(tile_idx, col_tiles)
            k0 = k_idx * self.TILE_K
            col0 = col_idx * self.TILE_N
            tile_k = min(self.TILE_K, k_lhs - k0)
            tile_n = min(self.TILE_N, cols - col0)
            tile_nbytes = tile_k * tile_n * elem_bytes

            # --- Stage 1: DMA_READ of the b tile from HBM ---
            read_done = env.event()
            src_pa = rhs.pa + (k0 * cols + col0) * elem_bytes
            rhs_tile = TensorHandle(
                id=f"b_tile_{tile_idx}", pa=src_pa,
                shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes,
            )
            read_cmd = DmaReadCmd(handle=rhs_tile, src_pa=src_pa, nbytes=tile_nbytes)
            read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=prefix)
            started = env.now
            yield self.out_ports[f"{prefix}.pe_dma"].put(read_txn)

            # The previous tile's compute must be done before this tile's starts.
            if pending_compute is not None:
                yield pending_compute
            # Then wait for this tile's read.
            yield read_done
            dma_ns += env.now - started

            # --- Stage 2: COMPUTE (GEMM) ---
            compute_done = env.event()
            out_nbytes = rows * tile_n * elem_bytes
            out_tile = TensorHandle(
                id=f"out_tile_{tile_idx}", pa=0,
                shape=(rows, tile_n), dtype=dtype,
                nbytes=out_nbytes,
            )
            compute_cmd = GemmCmd(a=lhs, b=rhs_tile, out=out_tile,
                                  m=rows, k=tile_k, n=tile_n)
            compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=prefix)
            started = env.now
            yield self.out_ports[f"{prefix}.pe_gemm"].put(compute_txn)

            # Serialize DMA writes: the previous write must finish first.
            if pending_write is not None:
                yield pending_write
            # Then wait for this tile's compute.
            yield compute_done
            compute_ns += env.now - started
            pending_compute = compute_done

            # --- Stage 3: DMA_WRITE of the output tile to HBM ---
            write_done = env.event()
            dst_pa = cmd.out_pa + col0 * elem_bytes
            write_cmd = DmaWriteCmd(handle=out_tile, dst_pa=dst_pa, nbytes=out_nbytes)
            write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=prefix)
            yield self.out_ports[f"{prefix}.pe_dma"].put(write_txn)
            # Not awaited here; the next iteration (or the tail below) waits.
            pending_write = write_done

        # Drain the last write and count the time spent waiting for it.
        if pending_write is not None:
            started = env.now
            yield pending_write
            dma_ns += env.now - started

        pe_txn.result_data["dma_ns"] = dma_ns
        pe_txn.result_data["compute_ns"] = compute_ns
        pe_txn.done.succeed()

    def _pipeline_math(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
        """Non-GEMM composite: one math command, then one DMA write (no tiling)."""
        from kernbench.common.pe_commands import (
            DmaWriteCmd,
            MathCmd,
            PeInternalTxn as PeTxn,
        )

        prefix = self._pe_prefix

        # Step 1: compute on PE_MATH and wait for it to finish.
        math_done = env.event()
        math_cmd = MathCmd(
            op=cmd.math_op or "identity",
            inputs=(cmd.a,), out=cmd.a,
        )
        math_txn = PeTxn(command=math_cmd, done=math_done, pe_prefix=prefix)
        yield self.out_ports[f"{prefix}.pe_math"].put(math_txn)
        yield math_done

        # Step 2: write the result to HBM and wait for the write.
        store_done = env.event()
        store_cmd = DmaWriteCmd(handle=cmd.a, dst_pa=cmd.out_pa, nbytes=cmd.out_nbytes)
        store_txn = PeTxn(command=store_cmd, done=store_done, pe_prefix=prefix)
        yield self.out_ports[f"{prefix}.pe_dma"].put(store_txn)
        yield store_done

        pe_txn.done.succeed()
+25
View File
@@ -0,0 +1,25 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeTcmComponent(ComponentBase):
    """PE_TCM: tightly-coupled memory / local SRAM staging buffer.

    Terminal storage component for PE-internal dataflow (ADR-0014 D5).
    Phase 0: applies overhead_ns and drain_ns at terminal.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env, nbytes: int) -> Generator:
        """Wait for the node's fixed ``overhead_ns``; ``nbytes`` is not used."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)
+59
View File
@@ -0,0 +1,59 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class SramComponent(ComponentBase):
    """Cube SRAM: terminal component that models SRAM access latency.

    Each transaction pays the node's overhead_ns (from node.attrs) plus the
    transaction's own drain_ns, after which a ResponseMsg is sent back along
    the reversed request path.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's fixed ``overhead_ns``; ``nbytes`` is not used."""
        yield env.timeout(float(self.node.attrs.get("overhead_ns", 0.0)))

    def _worker(self, env: simpy.Environment) -> Generator:
        """Terminal loop: take a transaction, apply latency and drain, respond."""
        while True:
            txn: Any = yield self._inbox.get()
            yield from self.run(env, txn.nbytes)
            # drain_ns is optional on the transaction; non-positive means no wait.
            drain_ns = getattr(txn, "drain_ns", 0.0)
            if drain_ns > 0:
                yield env.timeout(drain_ns)
            yield from self._send_response(env, txn)

    def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Send a ResponseMsg on the reverse path, or complete ``txn`` directly.

        When the path has fewer than two nodes, or no component context is set,
        there is no hop to respond to and ``txn.done`` is succeeded instead.
        """
        path_back = list(reversed(txn.path))
        if len(path_back) < 2 or not self.ctx:
            txn.done.succeed()
            return

        from kernbench.runtime_api.kernel import ResponseMsg

        # The second dotted component of the node id carries the cube index,
        # written as "cube<N>".
        cube_token = self.node.id.split(".")[1]
        cube_id = int(cube_token.replace("cube", ""))
        response = ResponseMsg(
            correlation_id=txn.request.correlation_id,
            request_id=txn.request.request_id,
            src_cube=cube_id, src_pe=-1, success=True,
        )
        response_txn = Transaction(
            request=response, path=path_back, step=0,
            nbytes=0, done=env.event(), is_response=True,
        )
        # path_back[0] is this node, so the first hop is path_back[1].
        yield self.out_ports[path_back[1]].put(response_txn.advance())
View File
+85
View File
@@ -0,0 +1,85 @@
from __future__ import annotations
from dataclasses import dataclass
from kernbench.policy.address.phyaddr import PhysAddr
class AllocationError(Exception):
    """Raised when a PE memory allocation request cannot be satisfied."""
@dataclass(frozen=True)
class AddressConfig:
    """Immutable sizing parameters used by the memory allocators."""

    sip_count: int
    cubes_per_sip: int
    pes_per_cube: int
    hbm_bytes_per_cube: int
    hbm_slices_per_cube: int
    tcm_bytes_per_pe: int
    tcm_scheduler_reserved_bytes: int
    sram_bytes_per_cube: int

    @property
    def hbm_slice_bytes(self) -> int:
        """Size of one HBM slice: cube HBM bytes floor-divided by the slice count."""
        total, slices = self.hbm_bytes_per_cube, self.hbm_slices_per_cube
        return total // slices

    @property
    def tcm_allocatable_bytes(self) -> int:
        """TCM bytes per PE left after the scheduler-reserved part is removed."""
        reserved = self.tcm_scheduler_reserved_bytes
        return self.tcm_bytes_per_pe - reserved
class PEMemAllocator:
    """Bump allocator for one PE's HBM slice and TCM region.

    Each region keeps a byte cursor that starts at 0 and only moves forward;
    nothing is ever freed. An allocation returns the address at the current
    cursor and then advances the cursor by the requested size.
    """

    def __init__(
        self, rack_id: int, sip_id: int, cube_id: int, pe_id: int, cfg: AddressConfig,
    ) -> None:
        self._rack_id = rack_id
        self._sip_id = sip_id
        self._cube_id = cube_id
        self._pe_id = pe_id
        self._cfg = cfg
        self._hbm_cursor = 0
        self._tcm_cursor = 0

    @staticmethod
    def _check_request(region: str, nbytes: int, cursor: int, capacity: int) -> None:
        """Raise AllocationError unless ``nbytes`` is non-negative and fits after ``cursor``."""
        if nbytes < 0:
            # A negative size would pass the overflow test and move the cursor
            # backwards, making later allocations overlap earlier ones.
            raise AllocationError(
                f"{region} allocation size must be non-negative: {nbytes}"
            )
        if cursor + nbytes > capacity:
            raise AllocationError(
                f"{region} overflow: need {nbytes}, "
                f"available {capacity - cursor}"
            )

    def alloc_hbm(self, nbytes: int) -> PhysAddr:
        """Allocate ``nbytes`` from this PE's HBM slice and return its address.

        Raises:
            AllocationError: if ``nbytes`` is negative or does not fit in the
                remaining part of the slice.
        """
        slice_bytes = self._cfg.hbm_slice_bytes
        self._check_request("HBM", nbytes, self._hbm_cursor, slice_bytes)
        pa = PhysAddr.pe_hbm_addr(
            rack_id=self._rack_id, sip_id=self._sip_id, cube_id=self._cube_id,
            pe_id=self._pe_id, pe_local_hbm_offset=self._hbm_cursor,
            slice_size_bytes=slice_bytes,
        )
        # Advance only after the address was built successfully.
        self._hbm_cursor += nbytes
        return pa

    def alloc_tcm(self, nbytes: int) -> PhysAddr:
        """Allocate ``nbytes`` from this PE's allocatable TCM and return its address.

        Raises:
            AllocationError: if ``nbytes`` is negative or does not fit in the
                remaining allocatable TCM.
        """
        self._check_request(
            "TCM", nbytes, self._tcm_cursor, self._cfg.tcm_allocatable_bytes
        )
        pa = PhysAddr.pe_tcm_addr(
            rack_id=self._rack_id, sip_id=self._sip_id, cube_id=self._cube_id,
            pe_id=self._pe_id, tcm_offset=self._tcm_cursor,
        )
        # Advance only after the address was built successfully.
        self._tcm_cursor += nbytes
        return pa

    @property
    def hbm_used(self) -> int:
        """Bytes of the HBM slice handed out so far."""
        return self._hbm_cursor

    @property
    def hbm_total(self) -> int:
        """Total bytes in this PE's HBM slice."""
        return self._cfg.hbm_slice_bytes

    @property
    def tcm_used(self) -> int:
        """Bytes of TCM handed out so far."""
        return self._tcm_cursor

    @property
    def tcm_total(self) -> int:
        """Total allocatable TCM bytes for this PE."""
        return self._cfg.tcm_allocatable_bytes
+184
View File
@@ -0,0 +1,184 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum
from typing import Literal
# Largest value that fits in 51 bits; PhysAddr.encode/decode use it as the
# upper bound of a valid physical address.
MAX_51 = (1 << 51) - 1
class PhysAddrError(Exception):
    """Raised when a physical address or one of its fields is out of range."""
def _chk_range(name: str, v: int, bits: int) -> None:
if not (0 <= v < (1 << bits)):
raise PhysAddrError(f"{name} out of range for {bits} bits: {v}")
def _chk_max(name: str, v: int, maxv: int) -> None:
if not (0 <= v <= maxv):
raise PhysAddrError(f"{name} out of range (0..{maxv}): {v}")
class UnitType(IntEnum):
    """Unit selector stored in bits [36:34] of a PE-resource local offset."""

    PE = 0    # per-PE resource (PhysAddr.pe_tcm_addr builds these)
    MCPU = 1  # M_CPU
    SRAM = 2  # cube SRAM (PhysAddr.cube_sram_addr builds these)
@dataclass(frozen=True)
class PhysAddr:
    """
    51-bit physical address value object.

    Layout:
    [50:47] rack_id (4)
    [46:43] sip_id (4)
    [42:38] sip_seg (5) # cube_id
    [37:0] local_offset (38) => each segment is 256GB

    local_offset:
    [37] selector: 1 = HBM window (128GB reserved), 0 = PE resource window

    PE resource window (selector 0), as packed and unpacked by this class:
    [36:34] unit_type (3)
    [33:30] pe_id (4)
    [29] ext (1)
    [28:0] sub_offset (29)
    """

    # Fields that are packed into the 51-bit value by encode().
    rack_id: int
    sip_id: int
    sip_seg: int
    local_offset: int
    # Decoded view of local_offset; these are not consulted by encode().
    kind: Literal["hbm", "pe_resource", "raw"] = "raw"
    cube_id: int = 0
    unit_type: UnitType = UnitType.PE
    pe_id: int = 0
    ext: int = 0
    sub_offset: int = 0
    hbm_offset: int = 0

    HBM_WINDOW_BYTES = 1 << 37  # 128GB

    def encode(self) -> int:
        """Pack rack_id, sip_id, sip_seg and local_offset into one integer.

        Raises:
            PhysAddrError: if any packed field does not fit its bit width.
        """
        field_widths = (
            ("rack_id", self.rack_id, 4),
            ("sip_id", self.sip_id, 4),
            ("sip_seg", self.sip_seg, 5),
            ("local_offset", self.local_offset, 38),
        )
        for label, value, width in field_widths:
            _chk_range(label, value, width)
        packed = (self.rack_id << 47) | (self.sip_id << 43) | (self.sip_seg << 38) | self.local_offset
        if not (0 <= packed <= MAX_51):
            raise PhysAddrError("address exceeds 51-bit space")
        return packed

    @staticmethod
    def decode(addr: int) -> PhysAddr:
        """Unpack a 51-bit integer into a PhysAddr of kind "hbm" or "pe_resource".

        Raises:
            PhysAddrError: if ``addr`` is outside the 51-bit range, or if a
                PE-resource address carries an unknown unit_type value.
        """
        if not (0 <= addr <= MAX_51):
            raise PhysAddrError("addr must be a 51-bit value")

        rack_id = (addr >> 47) & 0xF
        sip_id = (addr >> 43) & 0xF
        segment = (addr >> 38) & 0x1F
        local = addr & ((1 << 38) - 1)
        # The segment number doubles as the cube id.
        cube_id = segment

        selector = (local >> 37) & 0x1
        if selector == 1:
            # HBM window: the low 37 bits are the offset inside the window.
            return PhysAddr(
                rack_id=rack_id,
                sip_id=sip_id,
                sip_seg=segment,
                local_offset=local,
                kind="hbm",
                cube_id=cube_id,
                hbm_offset=int(local & ((1 << 37) - 1)),
            )

        # PE resource window.
        unit_bits = int((local >> 34) & 0x7)
        try:
            unit = UnitType(unit_bits)
        except ValueError:
            raise PhysAddrError(f"unknown unit_type: {unit_bits}") from None

        return PhysAddr(
            rack_id=rack_id,
            sip_id=sip_id,
            sip_seg=segment,
            local_offset=local,
            kind="pe_resource",
            cube_id=cube_id,
            unit_type=unit,
            pe_id=int((local >> 30) & 0xF),
            ext=int((local >> 29) & 0x1),
            sub_offset=int(local & ((1 << 29) - 1)),
            hbm_offset=0,
        )

    @staticmethod
    def hbm_addr(*, rack_id: int, sip_id: int, cube_id: int, hbm_offset: int) -> PhysAddr:
        """Build an HBM-window address for ``hbm_offset`` inside ``cube_id``.

        Raises:
            PhysAddrError: if ``cube_id`` exceeds 31 or ``hbm_offset`` does not
                fit in 37 bits.
        """
        _chk_max("cube_id", cube_id, 31)
        _chk_range("hbm_offset", hbm_offset, 37)
        offset = int(hbm_offset)
        return PhysAddr(
            rack_id=rack_id,
            sip_id=sip_id,
            sip_seg=cube_id,
            # Bit 37 set selects the HBM window.
            local_offset=(1 << 37) | offset,
            kind="hbm",
            cube_id=cube_id,
            hbm_offset=offset,
        )

    @staticmethod
    def pe_hbm_addr(
        *,
        rack_id: int,
        sip_id: int,
        cube_id: int,
        pe_id: int,
        pe_local_hbm_offset: int,
        slice_size_bytes: int,
    ) -> PhysAddr:
        """Build the HBM address of an offset inside one PE's slice.

        The window offset is ``pe_id * slice_size_bytes + pe_local_hbm_offset``.

        Raises:
            PhysAddrError: if an id is out of range, the local offset lies
                outside the slice, or the result leaves the 128GB window.
        """
        _chk_max("cube_id", cube_id, 31)
        _chk_range("pe_id", pe_id, 4)
        if not (0 <= pe_local_hbm_offset < slice_size_bytes):
            raise PhysAddrError("pe_local_hbm_offset out of PE local slice range")
        slice_base = int(pe_id) * int(slice_size_bytes)
        window_offset = slice_base + int(pe_local_hbm_offset)
        if not (0 <= window_offset < PhysAddr.HBM_WINDOW_BYTES):
            raise PhysAddrError("HBM offset exceeds reserved 128GB window")
        return PhysAddr.hbm_addr(
            rack_id=rack_id, sip_id=sip_id, cube_id=cube_id, hbm_offset=window_offset
        )

    @staticmethod
    def hbm_pe_id(hbm_offset: int, slice_size_bytes: int) -> int:
        """Return the index of the slice that contains ``hbm_offset``."""
        return hbm_offset // slice_size_bytes

    @staticmethod
    def cube_sram_addr(
        *, rack_id: int, sip_id: int, cube_id: int, sram_offset: int,
    ) -> PhysAddr:
        """Build a PE-resource address that targets cube SRAM at ``sram_offset``.

        Raises:
            PhysAddrError: if ``cube_id`` exceeds 31 or ``sram_offset`` does not
                fit in 29 bits.
        """
        _chk_max("cube_id", cube_id, 31)
        _chk_range("sram_offset", sram_offset, 29)
        return PhysAddr(
            rack_id=rack_id, sip_id=sip_id, sip_seg=cube_id,
            local_offset=(UnitType.SRAM << 34) | sram_offset,
            kind="pe_resource", cube_id=cube_id,
            unit_type=UnitType.SRAM, sub_offset=sram_offset,
        )

    @staticmethod
    def pe_tcm_addr(
        *, rack_id: int, sip_id: int, cube_id: int, pe_id: int, tcm_offset: int,
    ) -> PhysAddr:
        """Build a PE-resource address that targets a PE's TCM at ``tcm_offset``.

        Raises:
            PhysAddrError: if ``cube_id`` exceeds 31, ``pe_id`` does not fit in
                4 bits, or ``tcm_offset`` does not fit in 29 bits.
        """
        _chk_max("cube_id", cube_id, 31)
        _chk_range("pe_id", pe_id, 4)
        _chk_range("tcm_offset", tcm_offset, 29)
        packed_local = (UnitType.PE << 34) | (pe_id << 30) | tcm_offset
        return PhysAddr(
            rack_id=rack_id, sip_id=sip_id, sip_seg=cube_id,
            local_offset=packed_local,
            kind="pe_resource", cube_id=cube_id,
            unit_type=UnitType.PE, pe_id=pe_id, sub_offset=tcm_offset,
        )
+174
View File
@@ -0,0 +1,174 @@
from __future__ import annotations
from dataclasses import dataclass
from math import ceil
from typing import Literal
@dataclass(frozen=True)
class DPPolicy:
    """Two-level data-parallel policy: one choice per cube level and per PE level."""

    # How the tensor is distributed across cubes.
    cube: Literal["replicate", "shard_m", "shard_k"] = "replicate"
    # How each cube's part is distributed across that cube's PEs.
    pe: Literal["replicate", "column_wise", "row_wise"] = "replicate"
def resolve_dp_policy(
    policy: DPPolicy,
    *,
    shape: tuple[int, int],
    itemsize: int,
    num_pe: int,
    num_cubes: int = 1,
) -> list[ShardSpec]:
    """Expand a DPPolicy into the list of shards it describes.

    With at most one cube, only the pe-level policy is applied. Otherwise the
    cube-level policy picks each cube's shape and base offset, the pe-level
    policy splits that shape across the cube's PEs, and every resulting shard
    gets the flat index ``cube_id * num_pe + pe_id``.

    Raises:
        ValueError: if the pe-level policy is unknown, or (with more than one
            cube) if the cube-level policy is unknown.
    """
    pe_resolvers = {
        "replicate": replicate,
        "column_wise": column_wise,
        "row_wise": row_wise,
    }
    split_within_cube = pe_resolvers.get(policy.pe)
    if split_within_cube is None:
        raise ValueError(f"Unknown pe-level policy: {policy.pe}")

    if num_cubes <= 1:
        return split_within_cube(shape=shape, itemsize=itemsize, num_pe=num_pe)

    # Cube level: every cube gets the same shape; only the base offset differs,
    # growing by a fixed byte stride per cube.
    rows, cols = shape
    if policy.cube == "replicate":
        cube_shape = (rows, cols)
        cube_stride = 0
    elif policy.cube == "shard_m":
        rows_per_cube = rows // num_cubes
        cube_shape = (rows_per_cube, cols)
        cube_stride = rows_per_cube * cols * itemsize
    elif policy.cube == "shard_k":
        cols_per_cube = cols // num_cubes
        cube_shape = (rows, cols_per_cube)
        cube_stride = rows * cols_per_cube * itemsize
    else:
        raise ValueError(f"Unknown cube-level policy: {policy.cube}")

    # PE level: the split of one cube's shape is identical for every cube.
    per_cube = split_within_cube(shape=cube_shape, itemsize=itemsize, num_pe=num_pe)

    flattened: list[ShardSpec] = []
    for cube_id in range(num_cubes):
        base_offset = cube_id * cube_stride
        first_index = cube_id * num_pe
        flattened.extend(
            ShardSpec(
                pe_index=first_index + shard.pe_index,
                offset_bytes=base_offset + shard.offset_bytes,
                nbytes=shard.nbytes,
            )
            for shard in per_cube
        )
    return flattened
@dataclass(frozen=True)
class ShardSpec:
    """One shard of a tensor: the PE that owns it and its byte range."""

    pe_index: int      # index of the owning PE
    offset_bytes: int  # byte offset of the shard
    nbytes: int        # size of the shard in bytes
def column_wise(
    *, shape: tuple[int, int], itemsize: int, num_pe: int,
) -> list[ShardSpec]:
    """Split the K axis into num_pe parts of K // num_pe columns each.

    Every PE gets one shard of (M, K // num_pe) elements; shard i starts at
    i times the shard size in bytes.
    """
    rows, cols = shape
    cols_per_pe = cols // num_pe
    shard_bytes = rows * cols_per_pe * itemsize
    return [
        ShardSpec(pe_index=pe, offset_bytes=pe * shard_bytes, nbytes=shard_bytes)
        for pe in range(num_pe)
    ]
def row_wise(
    *, shape: tuple[int, int], itemsize: int, num_pe: int,
) -> list[ShardSpec]:
    """Split the M axis into num_pe parts of M // num_pe rows each.

    Every PE gets one shard of (M // num_pe, K) elements; shard i starts at
    i times the shard size in bytes.
    """
    rows, cols = shape
    rows_per_pe = rows // num_pe
    shard_bytes = rows_per_pe * cols * itemsize
    return [
        ShardSpec(pe_index=pe, offset_bytes=pe * shard_bytes, nbytes=shard_bytes)
        for pe in range(num_pe)
    ]
def replicate(
    *, shape: tuple[int, int], itemsize: int, num_pe: int,
) -> list[ShardSpec]:
    """Give every PE a full (M, K) copy: offset 0, M * K * itemsize bytes."""
    rows, cols = shape
    copy_bytes = rows * cols * itemsize
    shards: list[ShardSpec] = []
    for pe in range(num_pe):
        shards.append(ShardSpec(pe_index=pe, offset_bytes=0, nbytes=copy_bytes))
    return shards
def tiled_column_major(
    *, shape: tuple[int, int], itemsize: int, num_pe: int,
    tile_m: int, tile_k: int,
) -> list[ShardSpec]:
    """2D tiling, column-major order (K axis first), round-robin across PEs.

    Tiles are emitted with the K tile index varying fastest. Every tile,
    including those on the trailing edges, is reported with the full
    tile_m * tile_k * itemsize bytes.
    """
    rows, cols = shape
    m_tiles = ceil(rows / tile_m)
    k_tiles = ceil(cols / tile_k)
    full_tile_bytes = tile_m * tile_k * itemsize
    row_stride = cols * itemsize

    shards: list[ShardSpec] = []
    for mi in range(m_tiles):
        row_base = mi * tile_m * row_stride
        for ki in range(k_tiles):
            # len(shards) is the running tile count, used for round-robin.
            shards.append(ShardSpec(
                pe_index=len(shards) % num_pe,
                offset_bytes=row_base + (ki * tile_k * itemsize),
                nbytes=full_tile_bytes,
            ))
    return shards
def tiled_row_major(
    *, shape: tuple[int, int], itemsize: int, num_pe: int,
    tile_m: int, tile_k: int,
) -> list[ShardSpec]:
    """2D tiling, row-major order (M axis first), round-robin across PEs.

    Tiles are emitted with the M tile index varying fastest. Every tile,
    including those on the trailing edges, is reported with the full
    tile_m * tile_k * itemsize bytes.
    """
    rows, cols = shape
    m_tiles = ceil(rows / tile_m)
    k_tiles = ceil(cols / tile_k)
    full_tile_bytes = tile_m * tile_k * itemsize
    row_stride = cols * itemsize

    shards: list[ShardSpec] = []
    for ki in range(k_tiles):
        col_base = ki * tile_k * itemsize
        for mi in range(m_tiles):
            # len(shards) is the running tile count, used for round-robin.
            shards.append(ShardSpec(
                pe_index=len(shards) % num_pe,
                offset_bytes=(mi * tile_m * row_stride) + col_base,
                nbytes=full_tile_bytes,
            ))
    return shards
+184
View File
@@ -0,0 +1,184 @@
from __future__ import annotations
import heapq
from collections import defaultdict
from kernbench.policy.address.phyaddr import PhysAddr, UnitType
from kernbench.topology.types import TopologyGraph
class RoutingError(Exception):
    """Raised when an address or a path cannot be resolved in the topology."""
class AddressResolver:
    """Map a PhysAddr to the id of its destination node in the compiled graph.

    Named lookups (find_m_cpu, find_pcie_ep, ...) are provided as well, so that
    component implementations never build node-id strings themselves. Keeping
    the naming convention in one place means a single change propagates
    everywhere (ADR-0015 D4).
    """

    def __init__(self, graph: TopologyGraph) -> None:
        self._node_ids = set(graph.nodes)
        memory_map = graph.spec["cube"]["memory_map"]
        total_bytes = memory_map["hbm_total_gb_per_cube"] * (1 << 30)
        self._slice_size_bytes = total_bytes // memory_map["hbm_slices_per_cube"]

    # ── Physical-address resolution ──────────────────────────────────

    def resolve(self, addr: PhysAddr) -> str:
        """Return the node id that ``addr`` targets.

        Raises:
            RoutingError: if the address kind or unit type is unsupported, or
                if the resulting node is not part of the topology.
        """
        cube_prefix = f"sip{addr.sip_id}.cube{addr.cube_id}"
        if addr.kind == "hbm":
            slice_idx = PhysAddr.hbm_pe_id(addr.hbm_offset, self._slice_size_bytes)
            node_id = f"{cube_prefix}.hbm_ctrl.slice{slice_idx}"
        elif addr.kind == "pe_resource":
            unit = addr.unit_type
            if unit == UnitType.PE:
                node_id = f"{cube_prefix}.pe{addr.pe_id}.pe_tcm"
            elif unit == UnitType.SRAM:
                node_id = f"{cube_prefix}.sram"
            elif unit == UnitType.MCPU:
                node_id = f"{cube_prefix}.m_cpu"
            else:
                raise RoutingError(f"unsupported unit_type: {addr.unit_type}")
        else:
            raise RoutingError(f"unsupported address kind: {addr.kind}")
        return self._require(node_id, f"node {node_id} not found in topology")

    # ── Named node lookups ───────────────────────────────────────────

    def find_m_cpu(self, sip: int, cube: int) -> str:
        """Return the M_CPU node id of ``cube`` in ``sip``; RoutingError if absent."""
        node_id = f"sip{sip}.cube{cube}.m_cpu"
        return self._require(node_id, f"M_CPU not found: {node_id}")

    def find_pcie_ep(self, sip: int, io_id: str = "io0") -> str:
        """Return the PCIE_EP node id of ``io_id`` in ``sip``; RoutingError if absent."""
        node_id = f"sip{sip}.{io_id}.pcie_ep"
        return self._require(node_id, f"PCIE_EP not found: {node_id}")

    def find_io_cpu(self, sip: int, io_id: str = "io0") -> str:
        """Return the IO_CPU node id of ``io_id`` in ``sip``; RoutingError if absent."""
        node_id = f"sip{sip}.{io_id}.io_cpu"
        return self._require(node_id, f"IO_CPU not found: {node_id}")

    def find_all_pcie_eps(self) -> list[str]:
        """Return all PCIE_EP node ids across all SIPs, sorted."""
        matches = [nid for nid in self._node_ids if nid.endswith(".pcie_ep")]
        matches.sort()
        return matches

    def _require(self, node_id: str, message: str) -> str:
        """Return ``node_id`` if the topology contains it, else raise RoutingError."""
        if node_id not in self._node_ids:
            raise RoutingError(message)
        return node_id
class PathRouter:
    """Find shortest paths between nodes of the topology graph.

    Three adjacency maps are built from the graph's edges:
    _adj          — every edge except kind "command" (PE DMA routing, find_path)
    _adj_all      — every edge (component-to-component routing, find_node_path;
                    required because M_CPU↔NOC links are "command")
    _adj_mcpu_dma — every edge except the kinds in _MCPU_DMA_EXCLUDE
                    (cross-cube M_CPU DMA routing, find_mcpu_dma_path)

    An edge's weight is its routing_weight_mm when set, else its distance_mm.
    Path queries never modify the adjacency maps.
    """

    # Edge kinds excluded from M_CPU DMA adjacency: prevents routing through
    # PE-internal pipeline nodes when computing DMA paths.
    _MCPU_DMA_EXCLUDE = {"pe_internal", "pe_to_xbar"}

    def __init__(self, graph: TopologyGraph) -> None:
        self._adj: dict[str, list[tuple[str, float]]] = defaultdict(list)
        self._adj_all: dict[str, list[tuple[str, float]]] = defaultdict(list)
        self._adj_mcpu_dma: dict[str, list[tuple[str, float]]] = defaultdict(list)
        for e in graph.edges:
            w = e.routing_weight_mm if e.routing_weight_mm is not None else e.distance_mm
            self._adj_all[e.src].append((e.dst, w))
            if e.kind != "command":
                self._adj[e.src].append((e.dst, w))
            if e.kind not in self._MCPU_DMA_EXCLUDE:
                self._adj_mcpu_dma[e.src].append((e.dst, w))

    def find_path(self, src_pe: str, dst_node: str) -> list[str]:
        """PE DMA routing: starts at ``<src_pe>.pe_dma``, excludes command edges."""
        start = f"{src_pe}.pe_dma"
        return self._run_dijkstra(self._adj, start, dst_node)

    def find_path_with_distance(self, src_pe: str, dst_node: str) -> tuple[list[str], float]:
        """Like find_path, but also returns the total weight of the path."""
        start = f"{src_pe}.pe_dma"
        return self._run_dijkstra_with_dist(self._adj, start, dst_node)

    def find_mcpu_dma_path(self, m_cpu_id: str, dst_hbm_slice_id: str) -> list[str]:
        """M_CPU DMA path: never routes through PE-internal nodes (ADR-0015 D5).

        Same-cube: deterministic [m_cpu, noc, xbar.pe_i, hbm_ctrl.slice_i],
        built from the ids alone without consulting the graph.
        Cross-cube: Dijkstra via _adj_mcpu_dma (pe_internal/pe_to_xbar excluded)
        → routes through NOC → UCIe → target cube NOC → xbar → HBM.
        """
        # The first two dotted components ("sipN.cubeM") identify the cube.
        m_cube = ".".join(m_cpu_id.split(".")[:2])
        d_cube = ".".join(dst_hbm_slice_id.split(".")[:2])
        if m_cube == d_cube:
            slice_idx = int(dst_hbm_slice_id.rsplit("slice", 1)[1])
            return [
                m_cpu_id,
                f"{m_cube}.noc",
                f"{m_cube}.xbar.pe{slice_idx}",
                dst_hbm_slice_id,
            ]
        return self._run_dijkstra(self._adj_mcpu_dma, m_cpu_id, dst_hbm_slice_id)

    def find_node_path(self, src: str, dst: str) -> list[str]:
        """General routing between arbitrary nodes, including command edges.

        Used by components (IoCpuComponent, MCpuComponent) that route through
        M_CPU↔NOC command-kind links.
        """
        return self._run_dijkstra(self._adj_all, src, dst)

    def _run_dijkstra(
        self,
        adj: dict[str, list[tuple[str, float]]],
        start: str,
        goal: str,
    ) -> list[str]:
        """Return the shortest path from ``start`` to ``goal`` over ``adj``."""
        path, _ = self._run_dijkstra_with_dist(adj, start, goal)
        return path

    def _run_dijkstra_with_dist(
        self,
        adj: dict[str, list[tuple[str, float]]],
        start: str,
        goal: str,
    ) -> tuple[list[str], float]:
        """Dijkstra over ``adj``; return (path, total weight).

        Raises:
            RoutingError: if ``goal`` cannot be reached from ``start``.
        """
        if start == goal:
            return [start], 0.0
        best: dict[str, float] = {start: 0.0}
        prev: dict[str, str] = {}
        heap: list[tuple[float, str]] = [(0.0, start)]
        while heap:
            d, node = heapq.heappop(heap)
            if node == goal:
                # Walk the predecessor chain back to the start.
                path: list[str] = []
                cur = goal
                while cur != start:
                    path.append(cur)
                    cur = prev[cur]
                path.append(start)
                path.reverse()
                return path, d
            if d > best.get(node, float("inf")):
                # Stale heap entry: a shorter route to this node was found.
                continue
            # .get() instead of [] so that expanding a node without outgoing
            # edges does not insert an empty list into the defaultdict.
            for neighbor, edge_dist in adj.get(node, ()):
                new_d = d + edge_dist
                if new_d < best.get(neighbor, float("inf")):
                    best[neighbor] = new_d
                    prev[neighbor] = node
                    heapq.heappush(heap, (new_d, neighbor))
        raise RoutingError(f"no path from {start} to {goal}")

    # ── backward-compat shims (used by existing tests) ───────────────

    def _dijkstra(self, start: str, goal: str) -> list[str]:
        """Shortest path over the non-command adjacency map."""
        return self._run_dijkstra(self._adj, start, goal)

    def _dijkstra_with_dist(self, start: str, goal: str) -> tuple[list[str], float]:
        """Shortest path and its total weight over the non-command adjacency map."""
        return self._run_dijkstra_with_dist(self._adj, start, goal)
+96
View File
@@ -0,0 +1,96 @@
from __future__ import annotations
from collections.abc import Callable
from enum import Enum
from typing import Any
from kernbench.common.types import Completion, SimEngine, Trace
from .context import RuntimeContext
from .types import BenchResult, DeviceSelector
class CompletionPolicy(str, Enum):
    """How run_bench derives the overall completion from the submitted requests."""

    # Report the completion of the last submitted request.
    LAST_SUBMITTED = "last_submitted"
    # Requires trace/timestamps or engine support; stub for now.
    LAST_COMPLETED = "last_completed"
    # Report the first failing request, or success when none fails.
    ALL_OK_FAIL_FAST = "all_ok_fail_fast"
# A benchmark body: called with the RuntimeContext it submits requests through.
BenchFn = Callable[[RuntimeContext], Any]
# Builds the simulation engine from the topology object and the device selector.
EngineFactory = Callable[[object, DeviceSelector], SimEngine]
def run_bench(
    *,
    topology: object,
    bench_fn: BenchFn,
    device: DeviceSelector,
    engine_factory: EngineFactory,
    correlation_id: str = "bench0",
    completion_policy: CompletionPolicy = CompletionPolicy.LAST_SUBMITTED,
) -> BenchResult:
    """
    Minimal bench runner.

    - topology: compiled topology object (opaque to runtime here)
    - bench_fn: callable that receives RuntimeContext and submits requests
    - device: DeviceSelector ("all" or "sip:<N>")
    - engine_factory: builds sim_engine for given topology & device
    - completion_policy: how to determine overall completion/result

    Returns a BenchResult. When the bench submitted no requests, the result
    carries a failed completion with error code "NO_REQUESTS".
    """
    engine = engine_factory(topology, device)

    # The spec lives on the TopologyGraph, which may be wrapped in a handle
    # exposing it as ``topology_obj``.
    topo_obj = getattr(topology, "topology_obj", topology)
    ctx = RuntimeContext(
        engine=engine, target_device=device,
        correlation_id=correlation_id, spec=getattr(topo_obj, "spec", None),
    )

    bench_fn(ctx)
    ctx.wait_all()

    # An empty trace list is reported as None.
    collected_traces = ctx._traces or None
    handles = ctx.handles()

    def make_result(completion: Completion, trace: Trace | None) -> BenchResult:
        return BenchResult(
            completion=completion, correlation_id=correlation_id,
            trace=trace, traces=collected_traces,
        )

    if not handles:
        no_requests = Completion(
            ok=False, error_code="NO_REQUESTS", error_message="Bench submitted no requests"
        )
        return make_result(no_requests, None)

    if completion_policy == CompletionPolicy.ALL_OK_FAIL_FAST:
        latest_trace: Trace | None = None
        for handle in handles:
            outcome, handle_trace = engine.get_completion(handle)
            if handle_trace is not None:
                latest_trace = handle_trace
            if not outcome.ok:
                # Stop at the first failure and report it.
                return make_result(outcome, latest_trace)
        return make_result(Completion(ok=True), latest_trace)

    # LAST_SUBMITTED, and LAST_COMPLETED as a placeholder (needs engine support
    # for timing): report the last submitted request.
    completion, trace = engine.get_completion(handles[-1])
    return make_result(completion, trace)
+282
View File
@@ -0,0 +1,282 @@
# kernbench/runtime_api/context.py
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from kernbench.common.types import Completion, RequestHandle, SimEngine
from .types import DeviceSelector
@dataclass
class RuntimeContext:
    """Host-side handle a bench function uses to talk to a simulation engine.

    The context forwards requests to ``engine``, remembers every handle it
    submitted, records which handles have been waited on, and layers a small
    PyTorch-like tensor/kernel API (``zeros``, ``empty``, ``launch``) on top.
    """

    engine: SimEngine
    target_device: DeviceSelector
    correlation_id: str
    # Topology spec dict; only required by the tensor API (_ensure_allocators).
    spec: dict | None = None
    # Handles in submission order.
    _handles: list[RequestHandle] = field(default_factory=list, init=False)
    # Handles that wait() has already been called on.
    _completed: set[RequestHandle] = field(default_factory=set, init=False)
    # Flat PE index -> allocator, filled lazily by _ensure_allocators().
    _allocators: dict[int, Any] = field(default_factory=dict, init=False)
    # Source of auto-generated tensor names ("t1", "t2", ...).
    _tensor_counter: int = field(default=0, init=False)
    # Trace entries recorded by wait() calls that carried ``_meta``.
    _traces: list[dict] = field(default_factory=list, init=False)

    def submit(self, request: Any) -> RequestHandle:
        """Hand ``request`` to the engine and remember the returned handle.

        Raises:
            AttributeError: if the engine has no ``submit`` attribute.
        """
        engine_submit = getattr(self.engine, "submit", None)
        if engine_submit is None:
            raise AttributeError("Engine does not implement submit(request) -> RequestHandle.")
        new_handle: RequestHandle = engine_submit(request)  # type: ignore[call-arg]
        self._handles.append(new_handle)
        return new_handle

    def is_completed(self, handle: RequestHandle) -> bool:
        """Return True once ``wait`` has been called for ``handle``."""
        return handle in self._completed

    def wait(self, handle: RequestHandle, *, _meta: dict | None = None) -> Completion:
        """Block until ``handle`` finishes and return its Completion.

        The first wait on a handle calls the engine's ``wait`` (when it has
        one), marks the handle completed and, if ``_meta`` is given and the
        engine produced a trace, appends ``trace + _meta`` to ``_traces``.
        Later waits on the same handle only re-read the completion.
        """
        if handle in self._completed:
            completion, _ = self.engine.get_completion(handle)
            return completion

        engine_wait = getattr(self.engine, "wait", None)
        if engine_wait is not None:
            engine_wait(handle)  # type: ignore[misc]

        completion, trace = self.engine.get_completion(handle)
        self._completed.add(handle)
        if _meta is None or trace is None:
            return completion

        # Non-dict traces are wrapped so the entry is always a dict.
        record = dict(trace) if isinstance(trace, dict) else {"raw": trace}
        record.update(_meta)
        self._traces.append(record)
        return completion

    def wait_all(self) -> None:
        """Wait on every submitted handle that has not been waited on yet."""
        for pending in self._handles:
            if pending in self._completed:
                continue
            self.wait(pending)

    def handles(self) -> list[RequestHandle]:
        """Return a copy of the submitted handles, in submission order."""
        return [*self._handles]

    # ── PyTorch-like tensor API ──────────────────────────────────────

    def _ensure_allocators(self) -> dict:
        """Build the per-PE allocators from ``spec`` on first use and return them.

        Also stores ``_pes_per_cube``, ``_num_cubes`` and ``_num_sips`` on the
        instance. Missing spec keys fall back to the defaults written below.

        Raises:
            RuntimeError: if ``spec`` is None.
        """
        if self._allocators:
            return self._allocators
        if self.spec is None:
            raise RuntimeError(
                "RuntimeContext.spec is required for tensor operations. "
                "Pass spec=graph.spec when creating RuntimeContext."
            )

        from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator

        system = self.spec.get("system", {})
        cube = self.spec.get("cube", {})
        memory_map = cube.get("memory_map", {})
        pe_components = cube.get("pe_template", {}).get("components", {})
        tcm_attrs = pe_components.get("pe_tcm", {}).get("attrs", {})
        sips_cfg = system.get("sips", {})
        pe_layout = cube.get("pe_layout", {})

        sip_count = sips_cfg.get("count", 1)
        cubes_per_sip = sips_cfg.get("cubes_per_sip", 16)
        corners = pe_layout.get("corners", ["NW", "NE", "SW", "SE"])
        pes_per_cube = pe_layout.get("pe_per_corner", 2) * len(corners)

        hbm_gb = memory_map.get("hbm_total_gb_per_cube", 48)
        hbm_slices = memory_map.get("hbm_slices_per_cube", 8)
        tcm_mb = tcm_attrs.get("size_mb", 16)

        gib = 1 << 30
        mib = 1 << 20
        cfg = AddressConfig(
            sip_count=sip_count,
            cubes_per_sip=cubes_per_sip,
            pes_per_cube=pes_per_cube,
            hbm_bytes_per_cube=hbm_gb * gib,
            hbm_slices_per_cube=hbm_slices,
            tcm_bytes_per_pe=tcm_mb * mib,
            tcm_scheduler_reserved_bytes=4 * mib,
            sram_bytes_per_cube=32 * mib,
        )

        self._pes_per_cube = pes_per_cube
        self._num_cubes = cubes_per_sip
        self._num_sips = sip_count

        # One allocator per (sip, cube, pe), keyed by the flat index
        #   sip_id * cubes_per_sip * pes_per_cube + cube_id * pes_per_cube + pe_id
        pes_per_sip = cubes_per_sip * pes_per_cube
        for sip_id in range(sip_count):
            sip_base = sip_id * pes_per_sip
            for cube_id in range(cubes_per_sip):
                cube_base = sip_base + cube_id * pes_per_cube
                for pe_id in range(pes_per_cube):
                    self._allocators[cube_base + pe_id] = PEMemAllocator(
                        rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
                    )
        return self._allocators

    def _next_tensor_name(self) -> str:
        """Return the next auto-generated tensor name ("t1", "t2", ...)."""
        self._tensor_counter += 1
        return f"t{self._tensor_counter}"

    def zeros(
        self,
        shape: tuple[int, ...],
        dtype: str = "f16",
        *,
        placement: list | None = None,
        dp: Any = None,
        name: str | None = None,
    ):
        """Create a tensor and submit a "zero" pattern write for each shard (like torch.zeros)."""
        return self._create_tensor(
            shape=shape, dtype=dtype, placement=placement, name=name,
            pattern="zero", dp=dp,
        )

    def empty(
        self,
        shape: tuple[int, ...],
        dtype: str = "f16",
        *,
        placement: list | None = None,
        dp: Any = None,
        name: str | None = None,
    ):
        """Allocate a tensor without submitting any write (like torch.empty)."""
        return self._create_tensor(
            shape=shape, dtype=dtype, placement=placement, name=name,
            pattern=None, dp=dp,
        )

    def _create_tensor(
        self,
        shape: tuple[int, ...],
        dtype: str,
        placement: list | None,
        name: str | None,
        pattern: str | None,
        dp: Any = None,
    ):
        """Shared implementation of ``zeros``/``empty``.

        Placement is chosen in this order: a ``DPPolicy`` passed as ``dp``
        (resolved against the first two dimensions of ``shape``), then an
        explicit ``placement``, then a single shard on PE index 0. When
        ``pattern`` is not None, one MemoryWriteMsg per shard is submitted
        and waited on.
        """
        from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
        from kernbench.runtime_api.kernel import MemoryWriteMsg
        from kernbench.runtime_api.tensor import Tensor, deploy_tensor, dtype_itemsize

        tensor_name = name or self._next_tensor_name()
        tensor = Tensor(shape=shape, dtype=dtype, name=tensor_name)

        dp_policy: DPPolicy | None = None
        use_dp = dp is not None and isinstance(dp, DPPolicy)
        if use_dp:
            # dp= wins over placement=; the allocators are built first because
            # that is what populates _num_sips/_num_cubes/_pes_per_cube.
            dp_policy = dp
            self._ensure_allocators()
            itemsize = dtype_itemsize(dtype)
            rows_cols = (shape[0], shape[1])  # type: tuple[int, int]
            placement = resolve_dp_policy(
                dp, shape=rows_cols, itemsize=itemsize,
                num_pe=self._pes_per_cube,
                num_cubes=self._num_sips * self._num_cubes,
            )
        elif placement is None:
            placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=tensor.nbytes)]

        # Several distinct PE indices -> "all"; exactly one -> that index.
        distinct_pes = {spec.pe_index for spec in placement}
        target_pe: int | str
        if len(distinct_pes) > 1:
            target_pe = "all"
        else:
            target_pe = next(iter(distinct_pes))
        tensor.to(placement=placement, target_pe=target_pe, dp_policy=dp_policy)

        deployed = deploy_tensor(
            name=tensor_name,
            shape=shape,
            dtype=dtype,
            placement=placement,
            allocators=self._ensure_allocators(),
        )
        tensor._handle = deployed

        if pattern is None:
            return tensor

        for shard in deployed.shards:
            write = MemoryWriteMsg(
                correlation_id=self.correlation_id,
                request_id=f"deploy_{tensor_name}_pe{shard.pe}",
                dst_sip=shard.sip, dst_cube=shard.cube, dst_pe=shard.pe,
                dst_pa=shard.pa, nbytes=shard.nbytes, pattern=pattern,
                target_cubes=(shard.cube,), target_pe=shard.pe,
            )
            self.wait(self.submit(write), _meta={
                "phase": "memory_write", "name": tensor_name,
                "sip": shard.sip, "cube": shard.cube, "pe": shard.pe,
                "nbytes": shard.nbytes,
            })
        return tensor

    def launch(
        self,
        kernel_name: str,
        kernel_fn: Any,
        *args: Any,
        **kwargs: Any,
    ) -> RequestHandle:
        """Register ``kernel_fn`` under ``kernel_name``, launch it and wait for it.

        Positional Tensor arguments become tensor args; positional and keyword
        int/float values become ScalarArg ("f32" for float, "i32" otherwise).
        Keyword names are dropped (order is kept) and any other value is
        skipped. Returns the handle of the (already waited-on) launch.
        """
        from kernbench.runtime_api.kernel import (
            KernelLaunchMsg,
            KernelRef,
            ScalarArg,
        )
        from kernbench.runtime_api.tensor import Tensor
        from kernbench.triton_emu.registry import register_kernel

        # ValueError from registration is ignored so that launching the same
        # kernel name repeatedly works (presumably raised for an already
        # registered name — confirm against register_kernel).
        try:
            register_kernel(kernel_name, kernel_fn)
        except ValueError:
            pass

        def scalar_arg(value: Any) -> Any:
            return ScalarArg(dtype="f32" if isinstance(value, float) else "i32", value=value)

        launch_args: list = []
        target_pe: int | str = 0
        for item in args:
            if isinstance(item, Tensor):
                launch_args.append(item.to_tensor_arg())
                meta = item._dp_metadata
                if meta is None:
                    continue
                # "all" from any tensor is sticky; otherwise the last int wins.
                wanted = meta.target_pe
                if wanted == "all":
                    target_pe = "all"
                elif isinstance(wanted, int) and target_pe != "all":
                    target_pe = wanted
            elif isinstance(item, (int, float)):
                launch_args.append(scalar_arg(item))
        launch_args.extend(
            scalar_arg(value) for value in kwargs.values() if isinstance(value, (int, float))
        )

        # Target every cube that holds a shard of a positional tensor;
        # fall back to cube 0 when there is none.
        cubes = {
            shard.cube
            for item in args
            if isinstance(item, Tensor) and item._handle is not None
            for shard in item._handle.shards
        }
        target_cubes = tuple(sorted(cubes)) if cubes else (0,)

        # Scalar values are recorded in the trace entry for this launch.
        scalar_vals = [arg.value for arg in launch_args if hasattr(arg, "value")]

        launched = self.submit(KernelLaunchMsg(
            correlation_id=self.correlation_id,
            request_id=kernel_name,
            kernel_ref=KernelRef(name=kernel_name, kind="builtin"),
            args=tuple(launch_args),
            target_cubes=target_cubes,
            target_pe=target_pe,
        ))
        self.wait(launched, _meta={
            "phase": "kernel", "name": kernel_name,
            "target_pe": target_pe, "scalars": scalar_vals,
        })
        return launched
+123
View File
@@ -0,0 +1,123 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal, TypeAlias
@dataclass(frozen=True)
class MemoryWriteMsg:
    """Immutable host request to write ``nbytes`` at ``dst_pa`` on one PE."""

    # Request identity.
    correlation_id: str
    request_id: str

    # Destination of the write.
    dst_sip: int
    dst_cube: int
    dst_pe: int
    dst_pa: int
    nbytes: int

    # Where the data comes from; ``pattern`` is only meaningful for "pattern".
    src_kind: Literal["pattern", "host_buffer_ref"] = "pattern"
    pattern: str | None = None

    # Targeting hints; both default to "all".
    target_cubes: tuple[int, ...] | Literal["all"] = "all"
    target_pe: int | Literal["all"] = "all"

    # Message discriminator.
    msg_type: Literal["memory_write"] = "memory_write"
@dataclass(frozen=True)
class MemoryReadMsg:
    """Immutable host request to read ``nbytes`` from ``src_pa`` on one PE."""

    # Request identity.
    correlation_id: str
    request_id: str

    # Source of the read.
    src_sip: int
    src_cube: int
    src_pe: int
    src_pa: int
    nbytes: int

    # Targeting hints; both default to "all".
    target_cubes: tuple[int, ...] | Literal["all"] = "all"
    target_pe: int | Literal["all"] = "all"

    # Message discriminator.
    msg_type: Literal["memory_read"] = "memory_read"
@dataclass(frozen=True)
class KernelRef:
    """Pointer to the kernel a launch should run.

    A launch message never carries source code or IR. It either names a
    simulator built-in timing model or points at a binary that was written
    to device memory beforehand (via MemoryWriteMsg):

    - ``kind="deployed"``: binary already resident in HBM/SRAM at ``deploy_pa``.
    - ``kind="builtin"``: built-in timing model looked up by ``name``.
    """

    name: str
    kind: Literal["deployed", "builtin"]
    # Location and size of a deployed binary.
    deploy_pa: int | None = None
    deploy_sip: int = 0
    deploy_cube: int = 0
    deploy_pe: int = 0
    nbytes_code: int = 0
@dataclass(frozen=True)
class TensorArgShard:
sip: int
cube: int
pe: int
pa: int
nbytes: int
offset_bytes: int
@dataclass(frozen=True)
class TensorArg:
shards: tuple[TensorArgShard, ...]
arg_kind: Literal["tensor"] = "tensor"
@dataclass(frozen=True)
class ScalarArg:
dtype: str
value: float | int
arg_kind: Literal["scalar"] = "scalar"
KernelArg: TypeAlias = TensorArg | ScalarArg
@dataclass(frozen=True)
class KernelLaunchMsg:
    """Immutable host request to run the kernel ``kernel_ref`` with ``args``."""

    # Request identity.
    correlation_id: str
    request_id: str

    # What to run and with which arguments.
    kernel_ref: KernelRef
    args: tuple[KernelArg, ...]

    # Targeting hints; both default to "all".
    target_cubes: tuple[int, ...] | Literal["all"] = "all"
    target_pe: int | Literal["all"] = "all"

    # Message discriminator.
    msg_type: Literal["kernel_launch"] = "kernel_launch"
@dataclass(frozen=True)
class ResponseMsg:
    """Device-to-host reply reporting the outcome of a PE's execution."""

    # Identity of the request being answered.
    correlation_id: str
    request_id: str

    # Which PE produced the reply, and whether it succeeded.
    src_cube: int
    src_pe: int
    success: bool

    # Message discriminator.
    msg_type: Literal["response"] = "response"
@dataclass(frozen=True)
class PeDmaMsg:
    """Transfer request the host injects straight at the PE_DMA level.

    The probe utility uses it to measure PE→HBM latency without going
    through the full PE_CPU → scheduler → DMA pipeline.
    """

    # Request identity.
    correlation_id: str
    request_id: str

    # PE issuing the transfer.
    src_sip: int
    src_cube: int
    src_pe: int

    # Target address and size.
    dst_pa: int
    nbytes: int

    # Direction flag; False unless stated otherwise.
    is_write: bool = False

    # Message discriminator.
    msg_type: Literal["pe_dma"] = "pe_dma"
+166
View File
@@ -0,0 +1,166 @@
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Literal
from kernbench.policy.address.allocator import PEMemAllocator
from kernbench.policy.placement.dp import DPPolicy, ShardSpec
from kernbench.runtime_api.kernel import TensorArg, TensorArgShard
@dataclass(frozen=True)
class TensorShard:
    """One allocated piece of a deployed tensor."""

    # Allocator coordinates the shard was placed on.
    sip: int
    cube: int
    pe: int
    # Encoded physical address returned by the allocator.
    pa: int
    # Shard size and its byte offset, copied from the placement spec.
    nbytes: int
    offset_bytes: int
@dataclass(frozen=True)
class TensorHandle:
    """Immutable record of a deployed tensor and the shards backing it."""

    name: str
    shape: tuple[int, ...]
    dtype: str
    # Bytes per element for ``dtype``.
    itemsize: int
    shards: tuple[TensorShard, ...]

    @property
    def nbytes(self) -> int:
        """Total size in bytes: number of elements times ``itemsize``."""
        element_count = math.prod(self.shape)
        return element_count * self.itemsize
# Bytes per element for every dtype spelling accepted by dtype_itemsize().
_DTYPE_ITEMSIZE = {
    # 16-bit float
    "fp16": 2,
    "float16": 2,
    "f16": 2,
    # 32-bit float
    "fp32": 4,
    "float32": 4,
    "f32": 4,
    # bfloat16
    "bf16": 2,
    # integers
    "int8": 1,
    "i8": 1,
    "int16": 2,
    "i16": 2,
    "int32": 4,
    "i32": 4,
}


def dtype_itemsize(dtype: str) -> int:
    """Return the element size in bytes for ``dtype``.

    Raises:
        ValueError: if ``dtype`` is not one of the known spellings.
    """
    size = _DTYPE_ITEMSIZE.get(dtype)
    if size is None:
        raise ValueError(f"unsupported dtype: {dtype}")
    return size
def deploy_tensor(
    *,
    name: str,
    shape: tuple[int, ...],
    dtype: str,
    placement: list[ShardSpec],
    allocators: dict[int, PEMemAllocator],
    mem_kind: Literal["hbm", "tcm"] = "hbm",
) -> TensorHandle:
    """Allocate memory for every placement entry and return the resulting handle.

    For each spec in ``placement`` the allocator at ``allocators[spec.pe_index]``
    is asked for ``spec.nbytes`` bytes; ``mem_kind == "hbm"`` uses
    ``alloc_hbm`` and any other value uses ``alloc_tcm``. The dtype is
    validated (via ``dtype_itemsize``) before anything is allocated.
    """
    itemsize = dtype_itemsize(dtype)
    use_hbm = mem_kind == "hbm"

    def place(spec: ShardSpec) -> TensorShard:
        # Allocate one shard and record where the allocator put it.
        allocator = allocators[spec.pe_index]
        allocate = allocator.alloc_hbm if use_hbm else allocator.alloc_tcm
        address = allocate(spec.nbytes)
        return TensorShard(
            sip=allocator._sip_id,
            cube=allocator._cube_id,
            pe=allocator._pe_id,
            pa=address.encode(),
            nbytes=spec.nbytes,
            offset_bytes=spec.offset_bytes,
        )

    placed = tuple([place(spec) for spec in placement])
    return TensorHandle(
        name=name,
        shape=shape,
        dtype=dtype,
        itemsize=itemsize,
        shards=placed,
    )
# ── PyTorch-like Tensor API ──────────────────────────────────────────
@dataclass(frozen=True)
class DPMetadata:
    """Placement information attached to a Tensor as ``Tensor._dp_metadata``."""

    # Shard layout of the tensor.
    placement: list[ShardSpec]
    # Policy the layout came from, when there was one.
    dp_policy: DPPolicy | None = None
    sip: int = 0
    cube: int = 0
    # An int selects a single PE; the string "all" selects every PE.
    target_pe: int | str = 0
class Tensor:
    """Lightweight, PyTorch-flavoured tensor used by benchmark code.

    A Tensor only carries shape/dtype/name until it is given placement
    metadata (``to``) and a deployment handle (``_handle``).

    Usage::

        a = ctx.zeros((M, K), dtype="f16")
        a = ctx.zeros((M, K), dtype="f16", placement=dp.replicate(num_pe=8))
        ctx.launch("kernel_name", kernel_fn, a, b, out, M=M, K=K)
    """

    def __init__(
        self,
        shape: tuple[int, ...],
        dtype: str = "f16",
        name: str = "",
    ) -> None:
        self.shape = shape
        self.dtype = dtype
        self.name = name
        # Both are filled in later: placement by to(), handle on deployment.
        self._dp_metadata: DPMetadata | None = None
        self._handle: TensorHandle | None = None

    @property
    def itemsize(self) -> int:
        """Bytes per element for this tensor's dtype."""
        return dtype_itemsize(self.dtype)

    @property
    def nbytes(self) -> int:
        """Total size in bytes (element count times itemsize)."""
        element_count = math.prod(self.shape)
        return element_count * self.itemsize

    @property
    def pa(self) -> int:
        """PA of the first shard, used as the kernel pointer argument.

        Raises:
            RuntimeError: if the tensor has no handle or the handle has no shards.
        """
        handle = self._handle
        if handle is None or not handle.shards:
            raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
        return handle.shards[0].pa

    def to(
        self,
        placement: list[ShardSpec] | None = None,
        *,
        dp_policy: DPPolicy | None = None,
        sip: int = 0,
        cube: int = 0,
        target_pe: int | str = 0,
    ) -> Tensor:
        """Attach placement metadata and return ``self`` (like torch.Tensor.to()).

        With no ``placement`` the whole tensor becomes one shard on PE index 0.
        """
        if placement is None:
            whole = ShardSpec(pe_index=0, offset_bytes=0, nbytes=self.nbytes)
            placement = [whole]
        self._dp_metadata = DPMetadata(
            placement=placement,
            dp_policy=dp_policy,
            sip=sip,
            cube=cube,
            target_pe=target_pe,
        )
        return self

    def to_tensor_arg(self) -> TensorArg:
        """Turn the deployed shards into a TensorArg for KernelLaunchMsg.

        Raises:
            RuntimeError: if the tensor has no handle yet.
        """
        if self._handle is None:
            raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
        arg_shards = [
            TensorArgShard(
                sip=shard.sip,
                cube=shard.cube,
                pe=shard.pe,
                pa=shard.pa,
                nbytes=shard.nbytes,
                offset_bytes=shard.offset_bytes,
            )
            for shard in self._handle.shards
        ]
        return TensorArg(shards=tuple(arg_shards))
+71
View File
@@ -0,0 +1,71 @@
from __future__ import annotations
import re
from dataclasses import dataclass
from kernbench.common.types import Completion, Trace
@dataclass(frozen=True)
class BenchResult:
    """Outcome of one bench run: the completion plus any collected traces."""

    completion: Completion
    correlation_id: str
    # Trace of the request the completion was taken from, if any.
    trace: Trace | None = None
    # Per-request trace entries gathered during the run, if any.
    traces: list[dict] | None = None

    def summary_text(self) -> str:
        """Return a one-line ``[OK] ...`` or ``[FAIL:<code>] ...`` summary."""
        outcome = self.completion
        if not outcome.ok:
            # Missing code/message fall back to "ERROR" / empty text.
            code = outcome.error_code or "ERROR"
            msg = outcome.error_message or ""
            return f"[FAIL:{code}] correlation_id={self.correlation_id} {msg}".rstrip()
        return f"[OK] correlation_id={self.correlation_id}"
@dataclass(frozen=True)
class DeviceSelector:
    """Which device(s) a bench targets.

    Two spellings are supported:
    - "all"     : all SIPs in the tray topology
    - "sip:<N>" : a single SIP index
    """

    raw: str  # "all" or "sip:<N>"

    @property
    def is_all(self) -> bool:
        """True when the selector is exactly "all"."""
        return self.raw == "all"

    @property
    def sip_index(self) -> int:
        """Return N for a "sip:<N>" selector.

        Raises:
            ValueError: if the selector is "all" or is not of the form "sip:<N>".
        """
        if self.is_all:
            raise ValueError("DeviceSelector is 'all'; no single sip_index.")
        matched = re.fullmatch(r"sip:(\d+)", self.raw)
        if matched is None:
            raise ValueError(
                f"Invalid device '{self.raw}'. Expected 'all' or 'sip:<N>' (e.g., sip:0)."
            )
        return int(matched[1])
def resolve_device(raw: str | None) -> DeviceSelector:
    """Turn the CLI ``--device`` string into a DeviceSelector.

    The input is stripped and lower-cased first. ``None``, an empty string
    or whitespace mean "all"; otherwise only "all" and "sip:<N>" are accepted.

    Raises:
        ValueError: for any other value.
    """
    text = "" if raw is None else raw.strip()
    if text == "":
        return DeviceSelector(raw="all")

    text = text.lower()
    if text == "all":
        return DeviceSelector(raw="all")

    if re.fullmatch(r"sip:(\d+)", text) is None:
        raise ValueError(f"Invalid device '{text}'. Expected 'all' or 'sip:<N>' (e.g., sip:0).")
    return DeviceSelector(raw=text)
+31
View File
@@ -0,0 +1,31 @@
# kernbench/engine/dummy.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from kernbench.common.types import Completion, RequestHandle, SimEngine, Trace
@dataclass
class DummyEngine(SimEngine):
    """Trivial engine: every submitted request completes successfully at once."""

    topology: object
    device_raw: str
    # Number of requests submitted so far; used to number the handles.
    _n: int = 0
    # Handle string -> (completion, trace); replaced by a fresh dict in __post_init__.
    _store: dict[str, tuple[Completion, Trace | None]] = None  # type: ignore

    def __post_init__(self) -> None:
        # Always start from an empty store, whatever was passed to __init__.
        self._store = {}

    def submit(self, request: Any) -> RequestHandle:
        """Record an OK completion for ``request`` and return a new handle."""
        self._n += 1
        handle = RequestHandle(f"h{self._n}")
        # This is where request processing / simulation / scheduling would happen.
        trace = {"request": request, "device": self.device_raw}
        self._store[str(handle)] = (Completion(ok=True), trace)
        return handle

    def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
        """Return the (completion, trace) pair stored for ``handle``."""
        return self._store[str(handle)]

    def wait(self, handle: RequestHandle) -> None:
        """Do nothing: results are already stored by submit()."""
        pass
+298
View File
@@ -0,0 +1,298 @@
from __future__ import annotations
from typing import Any
import simpy
from kernbench.common.types import Completion, RequestHandle, Trace
import kernbench.components.impls # noqa: F401 — registers built-in implementations
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.context import ComponentContext
from kernbench.policy.address.phyaddr import PhysAddr
from kernbench.policy.routing.router import AddressResolver, PathRouter
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, PeDmaMsg
from kernbench.sim_engine.transaction import Transaction
from kernbench.topology.types import Edge, TopologyGraph
class GraphEngine:
    """Discrete-event simulation engine built on simpy.

    Phase B: each request becomes a Transaction that the engine drops into
    the host queue of a PCIE_EP. From there the components route it themselves:
      Path 1: PCIE_EP → IO_CPU   (path computed here and pre-loaded in the Transaction)
      Path 2: IO_CPU → M_CPU     (dispatched by IO_CPU, fire-and-forget callback)
      Path 3: M_CPU.DMA → HBM    (dispatched by M_CPU, fire-and-forget callback)

    Component implementations can be swapped through ``component_overrides``
    (ADR-0007 D3).
    """

    def __init__(
        self,
        graph: TopologyGraph,
        *,
        component_overrides: dict[str, type[ComponentBase]] | None = None,
    ) -> None:
        self._env = simpy.Environment()
        self._resolver = AddressResolver(graph)
        self._router = PathRouter(graph)
        self._nodes = graph.nodes
        # (src, dst) -> Edge; a later duplicate edge replaces an earlier one.
        self._edge_map: dict[tuple[str, str], Edge] = {
            (edge.src, edge.dst): edge for edge in graph.edges
        }
        self._ns_per_mm: float = graph.spec.get("system", {}).get("ns_per_mm", 0.01)
        # Per-handle results and completion events, keyed by str(handle).
        self._results: dict[str, tuple[Completion, Trace]] = {}
        self._events: dict[str, simpy.Event] = {}
        self._counter = 0

        overrides = component_overrides or {}
        shared_ctx = ComponentContext(
            router=self._router,
            resolver=self._resolver,
            positions={node_id: node.pos_mm for node_id, node in graph.nodes.items()},
            ns_per_mm=self._ns_per_mm,
            edge_map=self._edge_map,
            spec=graph.spec,
        )
        self._components: dict[str, ComponentBase] = {
            node_id: ComponentRegistry.create(node, overrides, shared_ctx)
            for node_id, node in graph.nodes.items()
        }

        # Ports: one Store per directed edge (ADR-0015 D1).
        for edge in graph.edges:
            upstream = self._components.get(edge.src)
            downstream = self._components.get(edge.dst)
            if upstream is not None and downstream is not None:
                port: simpy.Store = simpy.Store(self._env)
                upstream.out_ports[edge.dst] = port
                downstream.in_ports[edge.src] = port

        # Wires: one relay process per edge, adding propagation delay only
        # (ADR-0015 D2). Cut-through (wormhole) model: serialization (drain)
        # is computed per path and applied once at the terminal.
        for edge in graph.edges:
            upstream = self._components.get(edge.src)
            downstream = self._components.get(edge.dst)
            if upstream is None or downstream is None:
                continue
            delay_ns = edge.distance_mm * self._ns_per_mm
            relay = self._wire(
                upstream.out_ports[edge.dst], downstream.in_ports[edge.src], delay_ns,
            )
            self._env.process(relay)

        # Host queues on every PCIE_EP, attached before start() (ADR-0015 D3).
        self._host_queues: dict[str, simpy.Store] = {}
        for ep_id in self._resolver.find_all_pcie_eps():
            ep_queue: simpy.Store = simpy.Store(self._env)
            self._components[ep_id].in_ports["host"] = ep_queue
            self._host_queues[ep_id] = ep_queue

        # Host queues on every PE_DMA node, for direct PE DMA injection.
        self._pe_dma_queues: dict[str, simpy.Store] = {}
        for node_id, node in graph.nodes.items():
            if node.kind != "pe_dma":
                continue
            dma_queue = simpy.Store(self._env)
            self._components[node_id].in_ports["host"] = dma_queue
            self._pe_dma_queues[node_id] = dma_queue

        # Components start only once every port is wired (ADR-0015 D3).
        for component in self._components.values():
            component.start(self._env)

    def submit(self, request: Any) -> RequestHandle:
        """Schedule ``request`` for simulation and return its handle ("h1", "h2", ...)."""
        self._counter += 1
        handle = RequestHandle(f"h{self._counter}")
        finished = self._env.event()
        self._events[str(handle)] = finished
        self._env.process(self._process(str(handle), request, finished))
        return handle

    def wait(self, handle: RequestHandle) -> None:
        """Run the simulation until the request behind ``handle`` has finished."""
        finished = self._events[str(handle)]
        if finished.triggered:
            return
        self._env.run(until=finished)

    def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
        """Return the stored (completion, trace) pair; KeyError if none is stored yet."""
        return self._results[str(handle)]

    # ── internal ────────────────────────────────────────────────────

    def _wire(
        self,
        out_port: simpy.Store,
        in_port: simpy.Store,
        prop_ns: float,
    ):
        """SimPy process that forwards messages from ``out_port`` to ``in_port``.

        Only the propagation delay ``prop_ns`` is applied per hop; in the
        cut-through (wormhole) model the drain time is charged once, at the
        terminal component.
        """
        while True:
            message = yield out_port.get()
            if prop_ns > 0:
                yield self._env.timeout(prop_ns)
            yield in_port.put(message)

    def _process(self, key: str, request: Any, done: simpy.Event):
        """SimPy process that runs one request and stores its result under ``key``."""
        if isinstance(request, PeDmaMsg):
            yield from self._process_pe_dma(key, request, done)
            return

        entries = self._entry_points(request)
        if not entries:
            # Nothing to inject (e.g. a launch without tensor args): finish at once.
            self._results[key] = (
                Completion(ok=True),
                {"total_ns": 0.0, "nbytes": 0},
            )
            done.succeed()
            return

        began_ns = self._env.now
        total_nbytes = 0
        root_txn: Transaction | None = None

        if len(entries) == 1:
            # Single SIP: inject directly (common path, no extra events).
            ep_id, io_cpu_id, nbytes = entries[0]
            total_nbytes = nbytes
            leg_done = self._env.event()
            root_txn = Transaction(
                request=request,
                path=self._router.find_node_path(ep_id, io_cpu_id),
                step=0,
                nbytes=nbytes,
                done=leg_done,
            )
            yield self._host_queues[ep_id].put(root_txn)
            yield leg_done
        else:
            # Multi SIP: one injection per SIP, then wait for all (ADR-0007).
            per_sip: list[Transaction] = []
            for ep_id, io_cpu_id, nbytes in entries:
                total_nbytes = max(total_nbytes, nbytes)
                route = self._router.find_node_path(ep_id, io_cpu_id)
                leg = Transaction(
                    request=request, path=route, step=0,
                    nbytes=nbytes, done=self._env.event(),
                )
                yield self._host_queues[ep_id].put(leg)
                per_sip.append(leg)
            for leg in per_sip:
                yield leg.done

            # Report the largest pe_exec_ns seen across the SIPs.
            reported = (leg.result_data.get("pe_exec_ns") for leg in per_sip)
            pe_times = [value for value in reported if value is not None]
            if pe_times:
                root_txn = per_sip[0]
                root_txn.result_data["pe_exec_ns"] = max(pe_times)

        elapsed_ns = self._env.now - began_ns
        trace: dict[str, Any] = {"total_ns": elapsed_ns, "nbytes": total_nbytes}
        if root_txn is not None and root_txn.result_data:
            trace.update(root_txn.result_data)
        self._results[key] = (
            Completion(ok=True),
            trace,
        )
        done.succeed()

    def _process_pe_dma(self, key: str, request: PeDmaMsg, done: simpy.Event):
        """Inject a Transaction directly at PE_DMA for PE→HBM latency measurement."""
        pe_prefix = f"sip{request.src_sip}.cube{request.src_cube}.pe{request.src_pe}"
        pe_dma_id = f"{pe_prefix}.pe_dma"

        target = self._resolver.resolve(PhysAddr.decode(request.dst_pa))
        path = self._router.find_path(pe_prefix, target)
        drain_ns = self._path_drain_ns(path, request.nbytes)

        began_ns = self._env.now
        leg_done = self._env.event()
        txn = Transaction(
            request=request, path=path, step=0, nbytes=request.nbytes,
            done=leg_done, drain_ns=drain_ns,
        )
        yield self._pe_dma_queues[pe_dma_id].put(txn)
        yield leg_done

        elapsed_ns = self._env.now - began_ns
        formula_ns = self._formula_latency(path, request.nbytes)
        self._results[key] = (
            Completion(ok=True),
            {"total_ns": elapsed_ns, "formula_ns": formula_ns, "nbytes": request.nbytes},
        )
        done.succeed()

    def _path_drain_ns(self, path: list[str], nbytes: int) -> float:
        """Wormhole drain time: ``nbytes`` divided by the smallest edge bandwidth.

        Hops without a known edge or with a falsy ``bw_gbs`` are ignored;
        the result is 0.0 when no hop contributes a bandwidth.
        """
        unlimited = float("inf")
        rated = []
        for src, dst in zip(path, path[1:]):
            edge = self._edge_map.get((src, dst))
            if edge and edge.bw_gbs:
                rated.append(edge.bw_gbs)
        bottleneck = min([unlimited, *rated])
        if bottleneck == unlimited:
            return 0.0
        return nbytes / bottleneck

    def _formula_latency(self, path: list[str], nbytes: int) -> float:
        """Lower-bound latency formula (ADR-0015 D7).

        formula = Σ(wire propagation) + Σ(component overhead_ns) + drain_ns
        Phase 0: formula == actual (no contention).
        Phase 1+: formula <= actual (contention adds queueing).
        """
        latency = 0.0
        # Propagation over every known edge of the path.
        for src, dst in zip(path, path[1:]):
            edge = self._edge_map.get((src, dst))
            if edge:
                latency += edge.distance_mm * self._ns_per_mm
        # Fixed overhead of every known node on the path.
        for node_id in path:
            node = self._nodes.get(node_id)
            if node:
                latency += float(node.attrs.get("overhead_ns", 0.0))
        # Drain, charged once.
        latency += self._path_drain_ns(path, nbytes)
        return latency

    def _sip_entry(self, sip: int, nbytes: int) -> tuple[str, str, int]:
        """Build the (pcie_ep_id, io_cpu_id, nbytes) entry for one SIP."""
        return (
            self._resolver.find_pcie_ep(sip),
            self._resolver.find_io_cpu(sip),
            nbytes,
        )

    def _entry_points(self, request: Any) -> list[tuple[str, str, int]]:
        """Return one (pcie_ep_id, io_cpu_id, nbytes) entry per target SIP.

        Memory write/read requests yield a single entry. A KernelLaunchMsg
        yields one entry per distinct SIP found in its tensor shards, using
        the size of the first shard seen on that SIP.

        Raises:
            ValueError: for any other request type.
        """
        if isinstance(request, MemoryWriteMsg):
            return [self._sip_entry(request.dst_sip, request.nbytes)]
        if isinstance(request, MemoryReadMsg):
            return [self._sip_entry(request.src_sip, request.nbytes)]
        if isinstance(request, KernelLaunchMsg):
            visited: set[int] = set()
            found: list[tuple[str, str, int]] = []
            for arg in request.args:
                if arg.arg_kind != "tensor":
                    continue
                for shard in arg.shards:
                    if shard.sip in visited:
                        continue
                    visited.add(shard.sip)
                    found.append(self._sip_entry(shard.sip, shard.nbytes))
            return found
        raise ValueError(f"unsupported request type: {type(request)}")
+49
View File
@@ -0,0 +1,49 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import simpy
@dataclass
class Transaction:
    """A host request in flight along one leg of the device fabric (ADR-0015 D4).

    Every component on ``path`` takes the Transaction from its in_port,
    spends its own latency (overhead_ns or other), and passes the advanced
    Transaction to the next hop through its out_port. Propagation delay
    between hops is modelled by the wire processes (ADR-0015 D2).

    Flows with several legs (e.g. IO_CPU → M_CPU, then M_CPU.DMA → HBM) use
    one Transaction per leg: the terminal component of the first leg creates
    the second and succeeds the first leg's ``done`` only after the second
    leg's ``done`` has fired.
    """

    request: Any        # original host request (MemoryReadMsg, KernelLaunchMsg, …)
    path: list[str]     # node ids visited on this leg, in order
    step: int           # index in ``path`` of the component holding the Transaction
    nbytes: int         # payload size in bytes
    done: simpy.Event   # succeeded when this leg completes
    drain_ns: float = 0.0       # wormhole drain time, applied once at the terminal
    is_response: bool = False   # True when carrying a ResponseMsg on the reverse path
    result_data: dict[str, Any] = field(default_factory=dict)  # PE-level metrics (pe_exec_ns, etc.)

    @property
    def next_hop(self) -> str | None:
        """Id of the node after the current one, or None at the end of the path."""
        following = self.step + 1
        if following >= len(self.path):
            return None
        return self.path[following]

    def advance(self) -> Transaction:
        """Return a copy positioned one step further along the path.

        The copy shares ``path``, ``done`` and ``result_data`` with this
        Transaction; only ``step`` differs.
        """
        following = self.step + 1
        return Transaction(
            request=self.request,
            path=self.path,
            step=following,
            nbytes=self.nbytes,
            done=self.done,
            drain_ns=self.drain_ns,
            is_response=self.is_response,
            result_data=self.result_data,
        )
View File
+965
View File
@@ -0,0 +1,965 @@
# kernbench/topology/builder.py
"""
Topology compiler: parses topology.yaml and produces a fully-instantiated
TopologyGraph with nodes, edges, and representative view projections.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import yaml
from .types import Edge, Node, TopologyGraph, TopologyHandle, ViewGraph
# (dx, dy) offset of each PE-internal component from the PE center
# (small, intra-PE distances ~0.5mm).
_PE_COMP_OFFSETS = dict(
    pe_cpu=(-0.3, 0.0),
    pe_scheduler=(-0.15, 0.0),
    pe_dma=(0.0, -0.15),
    pe_gemm=(0.0, 0.0),
    pe_math=(0.0, 0.15),
    pe_tcm=(0.3, 0.0),
)
# ── Public API ───────────────────────────────────────────────────────
def resolve_topology(path_str: str) -> TopologyHandle:
    """Check that ``path_str`` names an existing file and compile it.

    The path is user-expanded and resolved before the checks.

    Raises:
        FileNotFoundError: if nothing exists at the resolved path.
        ValueError: if the resolved path exists but is not a file.
    """
    topo_path = Path(path_str).expanduser().resolve()
    if not topo_path.exists():
        raise FileNotFoundError(f"Topology file not found: {topo_path}")
    if not topo_path.is_file():
        raise ValueError(f"Topology path is not a file: {topo_path}")
    return TopologyHandle(path=topo_path, topology_obj=load_topology(topo_path))
def load_topology(path: Path) -> TopologyGraph:
    """Read the spec at ``path``, validate it, and compile it into a graph."""
    raw_spec = _read_spec(path)
    _validate_spec(raw_spec)
    graph = _compile_graph(raw_spec)
    return graph
def _read_spec(path: Path) -> dict[str, Any]:
    """Parse the YAML file at ``path`` and return its top-level mapping.

    Raises:
        ValueError: if the YAML cannot be parsed (with line/column when the
            parser reports a position), if the document is empty, or if its
            root is not a mapping.
    """
    try:
        with path.open("r", encoding="utf-8") as stream:
            loaded = yaml.safe_load(stream)
    except yaml.YAMLError as err:
        message = f"Failed to parse YAML topology: {path}"
        position = getattr(err, "problem_mark", None)
        if position is not None:
            # The mark is zero-based; report one-based line and column.
            message = message + f" (line {position.line + 1}, column {position.column + 1})"
        raise ValueError(message) from err

    if isinstance(loaded, dict):
        return loaded
    if loaded is None:
        raise ValueError(f"Topology YAML is empty: {path}")
    raise ValueError(
        f"Topology YAML root must be a mapping/dict: {path} (got {type(loaded).__name__})"
    )
def _validate_spec(spec: dict) -> None:
    """Placeholder for spec validation; currently accepts any spec."""
    # TODO: schema validation
    return None
# ── Graph Compiler ───────────────────────────────────────────────────
def _compile_graph(spec: dict) -> TopologyGraph:
    """Instantiate every node and edge described by ``spec`` and build the views.

    System-level nodes are created once. Each SIP then gets its IO chiplets,
    its cubes (numbered row-major over the cube mesh), the inter-cube mesh
    edges, the IO→cube edges and the switch→IO edges.
    """
    nodes: dict[str, Node] = {}
    edges: list[Edge] = []

    system = spec["system"]
    sip_spec = spec["sip"]
    cube_spec = spec["cube"]

    mesh = sip_spec["cube_mesh"]
    mesh_w, mesh_h = mesh["w"], mesh["h"]
    cube_mm = cube_spec["geometry"]["cube_mm"]
    cube_w, cube_h = cube_mm["w"], cube_mm["h"]
    seam = sip_spec["links"]["inter_cube_mesh"]["distance_mm_across_seam"]
    # Distance between the origins of neighbouring cubes.
    pitch_x = cube_w + seam
    pitch_y = cube_h + seam

    # System-level
    _instantiate_system(nodes, system)

    # Per-SIP
    for sip_id in range(system["sips"]["count"]):
        sip_prefix = f"sip{sip_id}"

        # IO chiplets
        _instantiate_io_chiplets(
            nodes, edges, sip_prefix, sip_spec,
            cube_w, cube_h, mesh_w, mesh_h, seam,
        )

        # Cubes + PEs
        for row in range(mesh_h):
            row_base = row * mesh_w
            origin_y = row * pitch_y
            for col in range(mesh_w):
                cube_prefix = f"{sip_prefix}.cube{row_base + col}"
                _instantiate_cube(
                    nodes, edges, cube_prefix, cube_spec, (col * pitch_x, origin_y),
                )

        # Inter-cube UCIe mesh
        _add_inter_cube_edges(edges, sip_prefix, mesh_w, mesh_h, sip_spec)
        # IO → cube UCIe
        _add_io_to_cube_edges(edges, sip_prefix, sip_spec, mesh_w)
        # Switch → IO pcie_ep
        _add_system_to_io_edges(edges, sip_prefix, sip_spec, system)

    # Build views
    system_view = _build_system_view(spec)
    sip_view = _build_sip_view(spec)
    cube_view = _build_cube_view(spec)
    pe_view = _build_pe_view(spec)
    return TopologyGraph(
        spec=spec,
        nodes=nodes,
        edges=edges,
        system_view=system_view,
        sip_view=sip_view,
        cube_view=cube_view,
        pe_view=pe_view,
    )
# ── Layout helpers ───────────────────────────────────────────────────
def _cube_local_positions(cube_w: float, cube_h: float) -> dict[str, tuple[float, float]]:
    """Return cube-internal component positions relative to the cube origin.

    The origin (0, 0) is the cube's top-left corner. The ``xbar.top`` and
    ``xbar.bottom`` entries only provide the Y reference for the per-PE xbar
    nodes of the top and bottom half.

    Args:
        cube_w: Cube width in mm.
        cube_h: Cube height in mm.

    Returns:
        Mapping from component name to its ``(x_mm, y_mm)`` position.
    """
    mid_x = cube_w / 2
    mid_y = cube_h / 2

    # UCIe node half-sizes (default 2.0x1.2mm), used as an inset so the port
    # rectangles touch the cube boundary.
    ucie_half_h = 0.6
    ucie_half_w = 1.0

    positions: dict[str, tuple[float, float]] = {}
    positions["ucie-N"] = (mid_x, ucie_half_h)
    positions["ucie-S"] = (mid_x, cube_h - ucie_half_h)
    positions["ucie-W"] = (ucie_half_w, mid_y)
    positions["ucie-E"] = (cube_w - ucie_half_w, mid_y)
    positions["m_cpu"] = (cube_w - 2.5, mid_y - 1.5)
    positions["xbar.top"] = (mid_x, 3.5)
    positions["hbm_ctrl"] = (mid_x - 2.0, mid_y)
    positions["xbar.bottom"] = (mid_x, cube_h - 3.5)
    positions["bridge.left"] = (2.5, mid_y + 2.0)
    positions["bridge.right"] = (cube_w - 2.5, mid_y + 2.0)
    positions["noc"] = (mid_x + 2.0, mid_y)
    positions["sram"] = (2.5, mid_y - 1.5)
    return positions
def _corner_pe_positions(cube_w: float, cube_h: float) -> dict[str, list[tuple[float, float]]]:
    """Return the PE centre positions of each cube corner.

    Positions are relative to the cube origin. Every corner holds two PEs,
    listed in ascending X order.

    Args:
        cube_w: Cube width in mm.
        cube_h: Cube height in mm.

    Returns:
        Mapping from corner name (``NW``, ``NE``, ``SW``, ``SE``) to the list
        of ``(x_mm, y_mm)`` PE centres in that corner.
    """
    edge_inset = 1.5   # distance of the outer PE centre from the cube edge
    inner_inset = 4.5  # distance of the inner PE centre from the cube edge

    west_xs = (edge_inset, inner_inset)
    east_xs = (cube_w - inner_inset, cube_w - edge_inset)
    north_y = edge_inset
    south_y = cube_h - edge_inset

    return {
        "NW": [(x, north_y) for x in west_xs],
        "NE": [(x, north_y) for x in east_xs],
        "SW": [(x, south_y) for x in west_xs],
        "SE": [(x, south_y) for x in east_xs],
    }
# ── Instantiation: system ───────────────────────────────────────────
def _instantiate_system(nodes: dict[str, Node], system: dict) -> None:
    """Register the system-level fabric switch node.

    Args:
        nodes: Graph node table, updated in place.
        system: The ``system`` section of the topology spec.
    """
    switch_spec = system["components"]["switch"]
    # The switch has no physical position in the flat graph (pos_mm=None).
    switch_node = Node(
        id="fabric.switch0",
        kind=switch_spec["kind"],
        impl=switch_spec["impl"],
        attrs=switch_spec.get("attrs", {}),
        pos_mm=None,
        label="Switch",
    )
    nodes["fabric.switch0"] = switch_node
# ── Instantiation: IO chiplets ──────────────────────────────────────
def _instantiate_io_chiplets(
    nodes: dict[str, Node],
    edges: list[Edge],
    sp: str,
    sip_spec: dict,
    cube_w: float,
    cube_h: float,
    mesh_w: int,
    mesh_h: int,
    seam: float,
) -> None:
    """Add the IO chiplet nodes of one SIP and their internal pcie_ep -> io_cpu edge.

    Args:
        nodes: Graph node table, updated in place.
        edges: Graph edge list, appended to in place.
        sp: Id prefix of the SIP (for example ``"sip0"``).
        sip_spec: The ``sip`` section of the topology spec.
        cube_w: Cube width in mm.
        cube_h: Cube height in mm.
        mesh_w: Number of cube columns in the SIP mesh.
        mesh_h: Number of cube rows in the SIP mesh.
        seam: Seam width between adjacent cubes in mm.
    """
    io_spec = sip_spec["iochiplet"]
    components = io_spec["components"]
    io_links = io_spec["links"]

    # Overall extent of the cube mesh: all cubes plus the seams between them.
    span_w = mesh_w * cube_w + (mesh_w - 1) * seam
    span_h = mesh_h * cube_h + (mesh_h - 1) * seam

    for instance in io_spec["instances"]:
        base = f"{sp}.{instance['id']}"
        on_north_side = instance["place"]["side"] == "N"
        center_x = span_w / 2

        # North chiplets sit above the mesh (negative Y); every other side
        # value is placed below it.
        if on_north_side:
            ep_y, cpu_y = -5.0, -3.0
        else:
            ep_y, cpu_y = span_h + 5.0, span_h + 3.0

        # pcie_ep first, then io_cpu, so node insertion order is stable.
        for comp_name, label, y in (
            ("pcie_ep", "PCIe EP", ep_y),
            ("io_cpu", "IO CPU", cpu_y),
        ):
            comp_spec = components[comp_name]
            node_id = f"{base}.{comp_name}"
            nodes[node_id] = Node(
                id=node_id,
                kind=comp_spec["kind"],
                impl=comp_spec["impl"],
                attrs=comp_spec["attrs"],
                pos_mm=(center_x, y),
                label=label,
            )

        # Chiplet-internal link from the PCIe endpoint to the IO CPU.
        edges.append(Edge(
            src=f"{base}.pcie_ep",
            dst=f"{base}.io_cpu",
            distance_mm=io_links["pcie_ep_to_io_cpu_mm"],
            bw_gbs=io_links["pcie_ep_to_io_cpu_bw_gbs"],
            kind="io_internal",
        ))
# ── Instantiation: cube + PEs ───────────────────────────────────────
def _instantiate_cube(
    nodes: dict[str, Node],
    edges: list[Edge],
    cp: str,
    cube: dict,
    origin: tuple[float, float],
) -> None:
    """Instantiate one cube: its nodes, its PE instances and all cube-internal edges.

    Nodes are added in this order: UCIe ports, noc/m_cpu/sram, HBM controller
    slices, bridges, then per PE its xbar entry node followed by the PE
    template components. Edges follow in this order: per-PE links, xbar<->HBM,
    xbar chain, bridge taps, UCIe<->noc, noc<->xbar, m_cpu<->noc, noc<->sram.

    Args:
        nodes: Graph node table, updated in place.
        edges: Graph edge list, appended to in place.
        cp: Id prefix of this cube (for example ``"sip0.cube3"``).
        cube: The ``cube`` section of the topology spec.
        origin: Position of the cube origin; added to every cube-local position.
    """
    cube_w = cube["geometry"]["cube_mm"]["w"]
    cube_h = cube["geometry"]["cube_mm"]["h"]
    ox, oy = origin
    local_pos = _cube_local_positions(cube_w, cube_h)
    clinks = cube["links"]
    n_slices = cube["memory_map"]["hbm_slices_per_cube"]

    emit = edges.append
    noc_id = f"{cp}.noc"

    def placed(local_xy: tuple[float, float]) -> tuple[float, float]:
        """Shift a cube-local position by the cube origin."""
        lx, ly = local_xy
        return (ox + lx, oy + ly)

    # ── UCIe ports ──
    ucie_overhead_ns = cube["ucie"]["overhead_ns"]
    for side in cube["ucie"]["ports"]:
        port_id = f"{cp}.ucie-{side}"
        nodes[port_id] = Node(
            id=port_id,
            kind="ucie_port",
            impl="ucie_v1",
            attrs={"overhead_ns": ucie_overhead_ns},
            pos_mm=placed(local_pos[f"ucie-{side}"]),
            label=f"UCIe-{side}",
        )

    # ── Named components: noc, m_cpu, sram ──
    for comp_name in ("noc", "m_cpu", "sram"):
        comp_spec = cube["components"][comp_name]
        node_id = f"{cp}.{comp_name}"
        nodes[node_id] = Node(
            id=node_id,
            kind=comp_spec["kind"],
            impl=comp_spec["impl"],
            attrs=comp_spec["attrs"],
            pos_mm=placed(local_pos[comp_name]),
            label=comp_name.upper().replace("_", " "),
        )

    # ── HBM controller slices; all slices share the hbm_ctrl position ──
    hbm_spec = cube["components"]["hbm_ctrl"]
    hbm_pos = placed(local_pos["hbm_ctrl"])
    for slice_idx in range(n_slices):
        slice_id = f"{cp}.hbm_ctrl.slice{slice_idx}"
        nodes[slice_id] = Node(
            id=slice_id,
            kind=hbm_spec["kind"],
            impl=hbm_spec["impl"],
            attrs=hbm_spec["attrs"],
            pos_mm=hbm_pos,
            label=f"HBM SLICE{slice_idx}",
        )

    # ── Bridges ──
    for bridge in cube["components"]["xbar"]["bridges"]:
        bridge_name = bridge["id"]
        bridge_id = f"{cp}.bridge.{bridge_name}"
        nodes[bridge_id] = Node(
            id=bridge_id,
            kind=bridge["kind"],
            impl=bridge["impl"],
            attrs=bridge["attrs"],
            pos_mm=placed(local_pos[f"bridge.{bridge_name}"]),
            label=f"Bridge {bridge_name.upper()}",
        )

    # ── PE instances + per-PE xbar entry nodes ──
    corners = cube["pe_layout"]["corners"]
    pe_per_corner = cube["pe_layout"]["pe_per_corner"]
    corner_pos = _corner_pe_positions(cube_w, cube_h)
    pe_tmpl = cube["pe_template"]
    pe_links = pe_tmpl["links"]
    xbar_pe_spec = cube["components"]["xbar"]["pe"]
    top_row_y = local_pos["xbar.top"][1]
    bottom_row_y = local_pos["xbar.bottom"][1]

    for corner_idx, corner in enumerate(corners):
        # Northern corners attach to the top xbar row, the others to the bottom row.
        if corner in ("NW", "NE"):
            xbar_y = top_row_y
            pe_to_xbar_key = "pe_to_xbar_row_n_mm"
        else:
            xbar_y = bottom_row_y
            pe_to_xbar_key = "pe_to_xbar_row_s_mm"

        for slot in range(pe_per_corner):
            # PEs are numbered consecutively in corner order.
            pe_idx = corner_idx * pe_per_corner + slot
            pe_prefix = f"{cp}.pe{pe_idx}"
            pe_x, pe_y = corner_pos[corner][slot]

            # Per-PE xbar entry node, vertically aligned with its PE.
            xbar_id = f"{cp}.xbar.pe{pe_idx}"
            nodes[xbar_id] = Node(
                id=xbar_id,
                kind=xbar_pe_spec["kind"],
                impl=xbar_pe_spec["impl"],
                attrs=xbar_pe_spec["attrs"],
                pos_mm=(ox + pe_x, oy + xbar_y),
                label=f"XBAR PE{pe_idx}",
            )

            # PE template components; unknown names get a zero offset.
            for comp_name, comp_spec in pe_tmpl["components"].items():
                comp_id = f"{pe_prefix}.{comp_name}"
                off_x, off_y = _PE_COMP_OFFSETS.get(comp_name, (0.0, 0.0))
                nodes[comp_id] = Node(
                    id=comp_id,
                    kind=comp_spec["kind"],
                    impl=comp_spec["impl"],
                    attrs=comp_spec["attrs"],
                    pos_mm=(ox + pe_x + off_x, oy + pe_y + off_y),
                    label=comp_name.upper().replace("_", " "),
                )

            _add_pe_internal_edges(edges, pe_prefix, pe_links)

            dma_id = f"{pe_prefix}.pe_dma"
            # PE_DMA → xbar.pe_i (HBM data path)
            emit(Edge(
                src=dma_id,
                dst=xbar_id,
                distance_mm=clinks[pe_to_xbar_key],
                bw_gbs=clinks["pe_to_xbar_bw_gbs"],
                kind="pe_to_xbar",
            ))
            # PE_DMA → noc (non-HBM data path: SRAM, inter-cube, etc.)
            emit(Edge(
                src=dma_id,
                dst=noc_id,
                distance_mm=clinks["pe_dma_to_noc_mm"],
                bw_gbs=clinks["pe_dma_to_noc_bw_gbs"],
                kind="pe_to_noc",
            ))
            # noc → PE_CPU (command delivery)
            emit(Edge(
                src=noc_id,
                dst=f"{pe_prefix}.pe_cpu",
                distance_mm=clinks["noc_to_pe_cpu_mm"],
                kind="command",
            ))

    # ── Cube fabric edges ──
    # xbar.pe_i ↔ hbm_ctrl.slice_i (local Y-path, bidirectional for response)
    for i in range(n_slices):
        xbar_i = f"{cp}.xbar.pe{i}"
        slice_i = f"{cp}.hbm_ctrl.slice{i}"
        for src, dst, kind in (
            (xbar_i, slice_i, "xbar_to_hbm"),
            (slice_i, xbar_i, "hbm_to_xbar"),
        ):
            emit(Edge(
                src=src,
                dst=dst,
                distance_mm=clinks["xbar_to_hbm_mm"],
                bw_gbs=clinks["xbar_to_hbm_bw_gbs"],
                kind=kind,
            ))

    # xbar chain: neighbouring xbar entries of the same half are linked both
    # ways; the first half is the top row, the second half the bottom row.
    half = n_slices // 2
    for row_start in (0, half):
        for left in range(row_start, row_start + half - 1):
            # The hop leaving the last PE of a corner crosses into the next corner.
            crosses_corner = (left - row_start) % pe_per_corner == pe_per_corner - 1
            if crosses_corner:
                hop_mm = clinks["xbar_chain_inter_corner_mm"]
            else:
                hop_mm = clinks["xbar_chain_intra_corner_mm"]
            for a, b in ((left, left + 1), (left + 1, left)):
                emit(Edge(
                    src=f"{cp}.xbar.pe{a}",
                    dst=f"{cp}.xbar.pe{b}",
                    distance_mm=hop_mm,
                    bw_gbs=clinks["xbar_x_bw_gbs"],
                    kind="xbar_chain",
                ))

    # bridge connections: each bridge taps the first/last xbar entry of the
    # top row and of the bottom row (left: first entries, right: last entries).
    for bridge_name, top_pe, bottom_pe in (
        ("left", 0, half),
        ("right", half - 1, n_slices - 1),
    ):
        bridge_id = f"{cp}.bridge.{bridge_name}"
        for pe_i, dist_key in (
            (top_pe, "xbar_row_n_to_bridge_mm"),
            (bottom_pe, "xbar_row_s_to_bridge_mm"),
        ):
            xbar_id = f"{cp}.xbar.pe{pe_i}"
            for src, dst, kind in (
                (xbar_id, bridge_id, "xbar_to_bridge"),
                (bridge_id, xbar_id, "bridge_to_xbar"),
            ):
                emit(Edge(
                    src=src,
                    dst=dst,
                    distance_mm=clinks[dist_key],
                    bw_gbs=clinks["xbar_to_bridge_bw_gbs"],
                    kind=kind,
                ))

    # ucie ↔ noc (UCIe-NOC boundary): all inbound edges first, then all outbound.
    ucie_link = clinks["noc_to_ucie"]
    for side in cube["ucie"]["ports"]:
        emit(Edge(
            src=f"{cp}.ucie-{side}",
            dst=noc_id,
            distance_mm=0.0,
            bw_gbs=ucie_link["per_connection_bw_gbs"],
            n_connections=ucie_link["n_connections"],
            kind="ucie_to_noc",
        ))
    for side in cube["ucie"]["ports"]:
        emit(Edge(
            src=noc_id,
            dst=f"{cp}.ucie-{side}",
            distance_mm=0.0,
            bw_gbs=ucie_link["per_connection_bw_gbs"],
            n_connections=ucie_link["n_connections"],
            kind="noc_to_ucie",
        ))

    # noc ↔ xbar.pe{i}: wire delay is 0 (NOC traversal latency computed by TwoDMeshNocComponent);
    # routing_weight_mm=50.0 steers PE DMA Dijkstra away from this path (prefer direct pe_dma→xbar)
    noc_xbar_bw = clinks.get("noc_to_xbar", {}).get("per_connection_bw_gbs")
    for i in range(n_slices):
        xbar_i = f"{cp}.xbar.pe{i}"
        for src, dst, kind in (
            (noc_id, xbar_i, "noc_to_xbar"),
            (xbar_i, noc_id, "xbar_to_noc"),
        ):
            emit(Edge(
                src=src,
                dst=dst,
                distance_mm=0.0,
                bw_gbs=noc_xbar_bw,
                routing_weight_mm=50.0,
                kind=kind,
            ))

    # m_cpu ↔ noc (command dispatch, both directions)
    m_cpu_id = f"{cp}.m_cpu"
    for src, dst in ((m_cpu_id, noc_id), (noc_id, m_cpu_id)):
        emit(Edge(
            src=src,
            dst=dst,
            distance_mm=clinks["m_cpu_to_noc_mm"],
            kind="command",
        ))

    # noc ↔ sram (shared SRAM access); both directions use kind "noc_to_sram".
    sram_link = clinks["noc_to_sram"]
    sram_id = f"{cp}.sram"
    for src, dst in ((noc_id, sram_id), (sram_id, noc_id)):
        emit(Edge(
            src=src,
            dst=dst,
            distance_mm=clinks["noc_to_sram_mm"],
            bw_gbs=sram_link["per_connection_bw_gbs"],
            n_connections=sram_link["n_connections"],
            kind="noc_to_sram",
        ))
def _add_pe_internal_edges(edges: list[Edge], pp: str, pe_links: dict) -> None:
    """Append the internal edges of one PE instance.

    Emits, in order: pe_cpu -> pe_scheduler, pe_scheduler -> each engine
    (dma, gemm, math), and each engine -> pe_tcm. Only the engine -> tcm
    edges carry a bandwidth.

    Args:
        edges: Graph edge list, appended to in place.
        pp: Id prefix of the PE (for example ``"sip0.cube0.pe3"``).
        pe_links: The ``links`` section of the PE template.
    """
    scheduler_id = f"{pp}.pe_scheduler"
    tcm_id = f"{pp}.pe_tcm"

    edges.append(Edge(
        src=f"{pp}.pe_cpu",
        dst=scheduler_id,
        distance_mm=pe_links["pe_cpu_to_scheduler_mm"],
        kind="pe_internal",
    ))

    scheduler_fanout = (
        ("pe_dma", "scheduler_to_dma_mm"),
        ("pe_gemm", "scheduler_to_gemm_mm"),
        ("pe_math", "scheduler_to_math_mm"),
    )
    for engine, dist_key in scheduler_fanout:
        edges.append(Edge(
            src=scheduler_id,
            dst=f"{pp}.{engine}",
            distance_mm=pe_links[dist_key],
            kind="pe_internal",
        ))

    engine_to_tcm = (
        ("pe_dma", "dma_to_tcm_mm", "dma_to_tcm_bw_gbs"),
        ("pe_gemm", "gemm_to_tcm_mm", "gemm_to_tcm_bw_gbs"),
        ("pe_math", "math_to_tcm_mm", "math_to_tcm_bw_gbs"),
    )
    for engine, dist_key, bw_key in engine_to_tcm:
        edges.append(Edge(
            src=f"{pp}.{engine}",
            dst=tcm_id,
            distance_mm=pe_links[dist_key],
            bw_gbs=pe_links[bw_key],
            kind="pe_internal",
        ))
# ── Inter-cube / IO / system edges ──────────────────────────────────
def _add_inter_cube_edges(
    edges: list[Edge], sp: str, mesh_w: int, mesh_h: int, sip_spec: dict,
) -> None:
    """Connect adjacent cubes of one SIP through their facing UCIe ports.

    Cubes are visited row-major. Each cube is linked to its east neighbour
    (E <-> W ports) and then to its south neighbour (S <-> N ports); every
    link is emitted as two directed edges.

    Args:
        edges: Graph edge list, appended to in place.
        sp: Id prefix of the SIP.
        mesh_w: Number of cube columns.
        mesh_h: Number of cube rows.
        sip_spec: The ``sip`` section of the topology spec.
    """
    mesh_link = sip_spec["links"]["inter_cube_mesh"]
    link_bw = mesh_link["bw_gbs_per_ucie_phy"]
    link_mm = mesh_link["distance_mm_across_seam"]

    def connect(port_a: str, port_b: str) -> None:
        """Emit the a -> b edge followed by the b -> a edge."""
        for src, dst in ((port_a, port_b), (port_b, port_a)):
            edges.append(Edge(
                src=src, dst=dst,
                distance_mm=link_mm, bw_gbs=link_bw, kind="ucie_mesh",
            ))

    for row in range(mesh_h):
        for col in range(mesh_w):
            here = row * mesh_w + col
            if col + 1 < mesh_w:
                east = row * mesh_w + (col + 1)
                connect(f"{sp}.cube{here}.ucie-E", f"{sp}.cube{east}.ucie-W")
            if row + 1 < mesh_h:
                south = (row + 1) * mesh_w + col
                connect(f"{sp}.cube{here}.ucie-S", f"{sp}.cube{south}.ucie-N")
def _add_io_to_cube_edges(
    edges: list[Edge], sp: str, sip_spec: dict, mesh_w: int,
) -> None:
    """Link every IO chiplet's io_cpu with the cube UCIe ports it attaches to.

    Each attachment produces an ``io_to_cube`` edge followed by the reverse
    ``cube_to_io`` edge. The edge length is the chiplet-level io_cpu-to-UCIe
    distance plus the per-port distance.

    Args:
        edges: Graph edge list, appended to in place.
        sp: Id prefix of the SIP.
        sip_spec: The ``sip`` section of the topology spec.
        mesh_w: Number of cube columns, used to turn (col, row) into a cube id.
    """
    io_links = sip_spec["iochiplet"]["links"]
    base_mm = io_links["io_cpu_to_ucie_mm"]
    link_bw = io_links["io_cpu_to_ucie_bw_gbs"]

    for instance in sip_spec["iochiplet"]["instances"]:
        io_cpu_id = f"{sp}.{instance['id']}.io_cpu"
        for attachment in instance["cube_ports"]:
            col, row = attachment["cube"]["xy"]
            cube_index = row * mesh_w + col
            ucie_id = f"{sp}.cube{cube_index}.ucie-{attachment['cube_side']}"
            for src, dst, kind in (
                (io_cpu_id, ucie_id, "io_to_cube"),
                (ucie_id, io_cpu_id, "cube_to_io"),
            ):
                edges.append(Edge(
                    src=src,
                    dst=dst,
                    distance_mm=base_mm + attachment["distance_mm"],
                    bw_gbs=link_bw,
                    kind=kind,
                ))
def _add_system_to_io_edges(
    edges: list[Edge], sp: str, sip_spec: dict, system: dict,
) -> None:
    """Add one directed PCIe edge from the fabric switch to each IO chiplet's pcie_ep.

    Args:
        edges: Graph edge list, appended to in place.
        sp: Id prefix of the SIP.
        sip_spec: The ``sip`` section of the topology spec.
        system: The ``system`` section of the topology spec.
    """
    switch_id = "fabric.switch0"
    pcie_link = system["links"]["io_ep_to_switch"]
    for instance in sip_spec["iochiplet"]["instances"]:
        endpoint_id = f"{sp}.{instance['id']}.pcie_ep"
        edges.append(Edge(
            src=switch_id,
            dst=endpoint_id,
            distance_mm=pcie_link["distance_mm"],
            bw_gbs=pcie_link["bw_gbs_per_ep"],
            kind="pcie",
        ))
# ── View builders ────────────────────────────────────────────────────
def _build_system_view(spec: dict) -> ViewGraph:
    """Build the system-level view: fabric switch, SIP blocks and their IO chiplets.

    SIPs are laid out left to right as fixed-size blocks; each IO chiplet is
    drawn on the north or south edge of its SIP and linked to the switch.

    Args:
        spec: Parsed topology mapping.

    Returns:
        The ``"system"`` view graph.
    """
    system = spec["system"]
    sip_count = system["sips"]["count"]

    # Fixed drawing dimensions (mm) of a SIP block and the gap between blocks.
    sip_w, sip_h = 71.0, 59.0
    gap = 30.0
    canvas_w = sip_count * sip_w + (sip_count - 1) * gap
    canvas_h = sip_h + 20.0

    view_nodes: dict[str, Node] = {}
    view_edges: list[Edge] = []

    switch_spec = system["components"]["switch"]
    switch_id = "fabric.switch0"
    view_nodes[switch_id] = Node(
        id=switch_id,
        kind=switch_spec["kind"],
        impl=switch_spec["impl"],
        attrs=switch_spec.get("attrs", {}),
        pos_mm=(canvas_w / 2, 5.0),
        label="Fabric Switch",
    )

    for index in range(sip_count):
        left = index * (sip_w + gap)
        top = 20.0
        sip_id = f"sip{index}"
        view_nodes[sip_id] = Node(
            id=sip_id,
            kind="sip",
            impl="",
            attrs={"w_mm": sip_w, "h_mm": sip_h},
            pos_mm=(left + sip_w / 2, top + sip_h / 2),
            label=f"SIP {index}",
        )

        for instance in spec["sip"]["iochiplet"]["instances"]:
            chiplet_name = instance["id"]
            chiplet_id = f"{sip_id}.{chiplet_name}"
            # North chiplets sit on the SIP's top edge, all others on the bottom edge.
            if instance["place"]["side"] == "N":
                chiplet_y = top
            else:
                chiplet_y = top + sip_h
            view_nodes[chiplet_id] = Node(
                id=chiplet_id,
                kind="iochiplet",
                impl="",
                attrs={},
                pos_mm=(left + sip_w / 2, chiplet_y),
                label=f"IO {chiplet_name}",
            )
            view_edges.append(Edge(
                src=switch_id,
                dst=chiplet_id,
                distance_mm=system["links"]["io_ep_to_switch"]["distance_mm"],
                bw_gbs=system["links"]["io_ep_to_switch"]["bw_gbs_per_ep"],
                kind="pcie",
            ))

    return ViewGraph(
        name="system",
        nodes=view_nodes,
        edges=view_edges,
        width_mm=canvas_w,
        height_mm=canvas_h,
    )
def _build_sip_view(spec: dict) -> ViewGraph:
    """Build the SIP-level view: the cube mesh as opaque blocks plus IO chiplets.

    The view is a representative single SIP; node ids carry no SIP prefix.
    Mesh and IO links are emitted as one directed edge each.

    Args:
        spec: Parsed topology mapping.

    Returns:
        The ``"sip"`` view graph.
    """
    sip_spec = spec["sip"]
    cube_spec = spec["cube"]
    cols = sip_spec["cube_mesh"]["w"]
    rows = sip_spec["cube_mesh"]["h"]
    cube_w = cube_spec["geometry"]["cube_mm"]["w"]
    cube_h = cube_spec["geometry"]["cube_mm"]["h"]
    seam = sip_spec["links"]["inter_cube_mesh"]["distance_mm_across_seam"]
    pitch_x = cube_w + seam
    pitch_y = cube_h + seam
    span_w = cols * cube_w + (cols - 1) * seam
    span_h = rows * cube_h + (rows - 1) * seam

    # Vertical margin reserved above and below the mesh for the IO chiplets.
    io_margin = 6.0
    canvas_w = span_w
    canvas_h = span_h + 2 * io_margin

    view_nodes: dict[str, Node] = {}
    view_edges: list[Edge] = []

    # Cubes as opaque blocks, positioned at their centre.
    for row in range(rows):
        for col in range(cols):
            cube_id = f"cube{row * cols + col}"
            center_x = col * pitch_x + cube_w / 2
            center_y = io_margin + row * pitch_y + cube_h / 2
            view_nodes[cube_id] = Node(
                id=cube_id,
                kind="cube",
                impl="",
                attrs={"w_mm": cube_w, "h_mm": cube_h, "col": col, "row": row},
                pos_mm=(center_x, center_y),
                label=f"CUBE ({col},{row})",
            )

    # Inter-cube mesh edges: east neighbour first, then south neighbour.
    mesh_link = sip_spec["links"]["inter_cube_mesh"]
    for row in range(rows):
        for col in range(cols):
            here = row * cols + col
            neighbours = []
            if col + 1 < cols:
                neighbours.append(row * cols + (col + 1))
            if row + 1 < rows:
                neighbours.append((row + 1) * cols + col)
            for other in neighbours:
                view_edges.append(Edge(
                    src=f"cube{here}",
                    dst=f"cube{other}",
                    distance_mm=mesh_link["distance_mm_across_seam"],
                    bw_gbs=mesh_link["bw_gbs_per_ucie_phy"],
                    kind="ucie_mesh",
                ))

    # IO chiplets and their attachments to cubes.
    io_links = sip_spec["iochiplet"]["links"]
    for instance in sip_spec["iochiplet"]["instances"]:
        chiplet_id = instance["id"]
        if instance["place"]["side"] == "N":
            chiplet_y = 2.0
        else:
            chiplet_y = canvas_h - 2.0
        view_nodes[chiplet_id] = Node(
            id=chiplet_id,
            kind="iochiplet",
            impl="",
            attrs={},
            pos_mm=(span_w / 2, chiplet_y),
            label=f"IO {chiplet_id}",
        )
        for attachment in instance["cube_ports"]:
            col, row = attachment["cube"]["xy"]
            target = row * cols + col
            view_edges.append(Edge(
                src=chiplet_id,
                dst=f"cube{target}",
                distance_mm=io_links["io_cpu_to_ucie_mm"] + attachment["distance_mm"],
                bw_gbs=io_links["io_cpu_to_ucie_bw_gbs"],
                kind="io_to_cube",
            ))

    return ViewGraph(
        name="sip",
        nodes=view_nodes,
        edges=view_edges,
        width_mm=canvas_w,
        height_mm=canvas_h,
    )
def _build_cube_view(spec: dict) -> ViewGraph:
    """Build the cube-level view: one representative cube with PEs as opaque blocks.

    Node ids carry no SIP/cube prefix, and the HBM controller is a single
    node instead of one node per slice. Positions are cube-local.

    Args:
        spec: Parsed topology mapping.

    Returns:
        The ``"cube"`` view graph, sized to the cube dimensions.
    """
    cube = spec["cube"]
    cube_w = cube["geometry"]["cube_mm"]["w"]
    cube_h = cube["geometry"]["cube_mm"]["h"]
    local_pos = _cube_local_positions(cube_w, cube_h)
    clinks = cube["links"]
    n_slices = cube["memory_map"]["hbm_slices_per_cube"]

    view_nodes: dict[str, Node] = {}
    view_edges: list[Edge] = []
    link = view_edges.append

    # UCIe ports
    for side in cube["ucie"]["ports"]:
        port_id = f"ucie-{side}"
        port_x, port_y = local_pos[port_id]
        view_nodes[port_id] = Node(
            id=port_id,
            kind="ucie_port",
            impl="ucie_v1",
            attrs={},
            pos_mm=(port_x, port_y),
            label=f"UCIe-{side}",
        )

    # Named components (hbm_ctrl as single representative node in view)
    for comp_name in ("noc", "m_cpu", "hbm_ctrl", "sram"):
        comp_spec = cube["components"][comp_name]
        comp_x, comp_y = local_pos[comp_name]
        view_nodes[comp_name] = Node(
            id=comp_name,
            kind=comp_spec["kind"],
            impl=comp_spec["impl"],
            attrs=comp_spec["attrs"],
            pos_mm=(comp_x, comp_y),
            label=comp_name.upper().replace("_", " "),
        )

    # Bridges
    for bridge in cube["components"]["xbar"]["bridges"]:
        bridge_name = bridge["id"]
        bridge_id = f"bridge.{bridge_name}"
        bridge_x, bridge_y = local_pos[bridge_id]
        view_nodes[bridge_id] = Node(
            id=bridge_id,
            kind=bridge["kind"],
            impl=bridge["impl"],
            attrs=bridge["attrs"],
            pos_mm=(bridge_x, bridge_y),
            label=f"Bridge {bridge_name.upper()}",
        )

    # PEs as opaque blocks + per-PE xbar entry nodes
    corners = cube["pe_layout"]["corners"]
    pe_per_corner = cube["pe_layout"]["pe_per_corner"]
    corner_pos = _corner_pe_positions(cube_w, cube_h)
    xbar_pe_spec = cube["components"]["xbar"]["pe"]
    top_row_y = local_pos["xbar.top"][1]
    bottom_row_y = local_pos["xbar.bottom"][1]

    for corner_idx, corner in enumerate(corners):
        # Northern corners attach to the top xbar row, the others to the bottom row.
        if corner in ("NW", "NE"):
            xbar_y = top_row_y
            pe_to_xbar_key = "pe_to_xbar_row_n_mm"
        else:
            xbar_y = bottom_row_y
            pe_to_xbar_key = "pe_to_xbar_row_s_mm"

        for slot in range(pe_per_corner):
            # PEs are numbered consecutively in corner order.
            pe_idx = corner_idx * pe_per_corner + slot
            pe_id = f"pe{pe_idx}"
            xbar_id = f"xbar.pe{pe_idx}"
            pe_x, pe_y = corner_pos[corner][slot]

            view_nodes[pe_id] = Node(
                id=pe_id,
                kind="pe",
                impl="",
                attrs={"corner": corner},
                pos_mm=(pe_x, pe_y),
                label=f"PE{pe_idx}",
            )
            view_nodes[xbar_id] = Node(
                id=xbar_id,
                kind=xbar_pe_spec["kind"],
                impl=xbar_pe_spec["impl"],
                attrs=xbar_pe_spec["attrs"],
                pos_mm=(pe_x, xbar_y),
                label=f"XBAR PE{pe_idx}",
            )

            # PE → xbar.pe_i (HBM data path)
            link(Edge(
                src=pe_id,
                dst=xbar_id,
                distance_mm=clinks[pe_to_xbar_key],
                bw_gbs=clinks["pe_to_xbar_bw_gbs"],
                kind="pe_to_xbar",
            ))
            # PE → noc (non-HBM data path)
            link(Edge(
                src=pe_id,
                dst="noc",
                distance_mm=clinks["pe_dma_to_noc_mm"],
                bw_gbs=clinks["pe_dma_to_noc_bw_gbs"],
                kind="pe_to_noc",
            ))
            # noc → PE (command delivery)
            link(Edge(
                src="noc",
                dst=pe_id,
                distance_mm=clinks["noc_to_pe_cpu_mm"],
                kind="command",
            ))

    # Cube fabric edges
    # xbar.pe_i → hbm_ctrl (single representative node in view)
    for i in range(n_slices):
        link(Edge(
            src=f"xbar.pe{i}",
            dst="hbm_ctrl",
            distance_mm=clinks["xbar_to_hbm_mm"],
            bw_gbs=clinks["xbar_to_hbm_bw_gbs"],
            kind="xbar_to_hbm",
        ))

    # xbar chain: neighbouring entries within each half, both directions.
    half = n_slices // 2
    for row_start in (0, half):
        for left in range(row_start, row_start + half - 1):
            # The hop leaving the last PE of a corner crosses into the next corner.
            crosses_corner = (left - row_start) % pe_per_corner == pe_per_corner - 1
            if crosses_corner:
                hop_mm = clinks["xbar_chain_inter_corner_mm"]
            else:
                hop_mm = clinks["xbar_chain_intra_corner_mm"]
            for a, b in ((left, left + 1), (left + 1, left)):
                link(Edge(
                    src=f"xbar.pe{a}",
                    dst=f"xbar.pe{b}",
                    distance_mm=hop_mm,
                    bw_gbs=clinks["xbar_x_bw_gbs"],
                    kind="xbar_chain",
                ))

    # bridge connections: left bridge taps the first entry of each row,
    # right bridge taps the last entry of each row.
    for bridge_name, top_pe, bottom_pe in (
        ("left", 0, half),
        ("right", half - 1, n_slices - 1),
    ):
        bridge_id = f"bridge.{bridge_name}"
        for pe_i, dist_key in (
            (top_pe, "xbar_row_n_to_bridge_mm"),
            (bottom_pe, "xbar_row_s_to_bridge_mm"),
        ):
            tap_id = f"xbar.pe{pe_i}"
            for src, dst, kind in (
                (tap_id, bridge_id, "xbar_to_bridge"),
                (bridge_id, tap_id, "bridge_to_xbar"),
            ):
                link(Edge(
                    src=src,
                    dst=dst,
                    distance_mm=clinks[dist_key],
                    bw_gbs=clinks["xbar_to_bridge_bw_gbs"],
                    kind=kind,
                ))

    # noc → ucie (the view shows the outbound direction only)
    ucie_link = clinks["noc_to_ucie"]
    for side in cube["ucie"]["ports"]:
        link(Edge(
            src="noc",
            dst=f"ucie-{side}",
            distance_mm=0.0,
            bw_gbs=ucie_link["per_connection_bw_gbs"],
            n_connections=ucie_link["n_connections"],
            kind="noc_to_ucie",
        ))

    # m_cpu ↔ noc (command dispatch, both directions)
    for src, dst in (("m_cpu", "noc"), ("noc", "m_cpu")):
        link(Edge(
            src=src,
            dst=dst,
            distance_mm=clinks["m_cpu_to_noc_mm"],
            kind="command",
        ))

    # noc ↔ sram (shared SRAM access, bidirectional)
    sram_link = clinks["noc_to_sram"]
    for src, dst in (("noc", "sram"), ("sram", "noc")):
        link(Edge(
            src=src,
            dst=dst,
            distance_mm=clinks["noc_to_sram_mm"],
            bw_gbs=sram_link["per_connection_bw_gbs"],
            n_connections=sram_link["n_connections"],
            kind="noc_to_sram",
        ))

    return ViewGraph(
        name="cube",
        nodes=view_nodes,
        edges=view_edges,
        width_mm=cube_w,
        height_mm=cube_h,
    )
def _build_pe_view(spec: dict) -> ViewGraph:
    """Build the PE-level view: one representative PE with its template components.

    Component positions come from a fixed layout table on a 12.0 x 8.0 mm
    canvas, so every component named in the PE template must appear in it.

    Args:
        spec: Parsed topology mapping.

    Returns:
        The ``"pe"`` view graph.

    Raises:
        KeyError: If the template names a component missing from the layout table.
    """
    pe_tmpl = spec["cube"]["pe_template"]
    pe_links = pe_tmpl["links"]
    canvas_w, canvas_h = 12.0, 8.0

    # Fixed layout: CPU and scheduler on the left, engines in a column, TCM on the right.
    layout = {
        "pe_cpu": (1.5, 4.0),
        "pe_scheduler": (4.0, 4.0),
        "pe_dma": (7.0, 1.5),
        "pe_gemm": (7.0, 4.0),
        "pe_math": (7.0, 6.5),
        "pe_tcm": (10.0, 4.0),
    }

    view_nodes: dict[str, Node] = {}
    view_edges: list[Edge] = []

    for comp_name, comp_spec in pe_tmpl["components"].items():
        comp_x, comp_y = layout[comp_name]
        view_nodes[comp_name] = Node(
            id=comp_name,
            kind=comp_spec["kind"],
            impl=comp_spec["impl"],
            attrs=comp_spec["attrs"],
            pos_mm=(comp_x, comp_y),
            label=comp_name.upper().replace("_", " "),
        )

    view_edges.append(Edge(
        src="pe_cpu",
        dst="pe_scheduler",
        distance_mm=pe_links["pe_cpu_to_scheduler_mm"],
        kind="pe_internal",
    ))

    scheduler_fanout = (
        ("pe_dma", "scheduler_to_dma_mm"),
        ("pe_gemm", "scheduler_to_gemm_mm"),
        ("pe_math", "scheduler_to_math_mm"),
    )
    for engine, dist_key in scheduler_fanout:
        view_edges.append(Edge(
            src="pe_scheduler",
            dst=engine,
            distance_mm=pe_links[dist_key],
            kind="pe_internal",
        ))

    engine_to_tcm = (
        ("pe_dma", "dma_to_tcm_mm", "dma_to_tcm_bw_gbs"),
        ("pe_gemm", "gemm_to_tcm_mm", "gemm_to_tcm_bw_gbs"),
        ("pe_math", "math_to_tcm_mm", "math_to_tcm_bw_gbs"),
    )
    for engine, dist_key, bw_key in engine_to_tcm:
        view_edges.append(Edge(
            src=engine,
            dst="pe_tcm",
            distance_mm=pe_links[dist_key],
            bw_gbs=pe_links[bw_key],
            kind="pe_internal",
        ))

    return ViewGraph(
        name="pe",
        nodes=view_nodes,
        edges=view_edges,
        width_mm=canvas_w,
        height_mm=canvas_h,
    )
View File
+56
View File
@@ -0,0 +1,56 @@
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass
class Node:
    """One component instance in a topology graph or a view graph."""

    id: str  # node id; edges refer to nodes by this string
    kind: str  # component kind, e.g. "ucie_port" or "pe_dma"
    impl: str  # implementation name; may be "" for view-only block nodes
    attrs: dict[str, Any]  # free-form attributes taken from the spec
    pos_mm: tuple[float, float] | None  # (x_mm, y_mm); None for abstract nodes
    label: str = ""  # display label; empty by default
@dataclass
class Edge:
    """A directed link between two nodes; a bidirectional link is two Edge objects."""

    src: str  # node id
    dst: str  # node id
    distance_mm: float  # physical wire delay distance (ns = distance_mm * ns_per_mm)
    routing_weight_mm: float | None = None  # Dijkstra cost; None → use distance_mm
    bw_gbs: float | None = None  # link bandwidth; None when the link carries no bandwidth figure
    n_connections: int | None = None  # multi-connection links; single request uses 1 connection
    kind: str = "link"  # link category, e.g. "ucie_mesh" or "pe_internal"
@dataclass
class ViewGraph:
    """A representative projection of the topology at one level, with its canvas size."""

    name: str  # "system" | "sip" | "cube" | "pe"
    nodes: dict[str, Node]  # view nodes keyed by node id
    edges: list[Edge]  # view edges; src/dst are keys of ``nodes``
    width_mm: float  # canvas width in mm
    height_mm: float  # canvas height in mm
@dataclass
class TopologyGraph:
    """Compiled topology: the source spec, the flat graph and the per-level views."""

    spec: dict[str, Any]  # the parsed spec the graph was compiled from
    # Full instantiated flat graph (used by sim_engine)
    nodes: dict[str, Node] = field(default_factory=dict)
    edges: list[Edge] = field(default_factory=list)
    # Representative view projections (used by visualizer); None when a view was not built
    system_view: ViewGraph | None = None
    sip_view: ViewGraph | None = None
    cube_view: ViewGraph | None = None
    pe_view: ViewGraph | None = None
@dataclass(frozen=True)
class TopologyHandle:
    """Immutable pairing of a topology file path with its compiled graph."""

    path: Path  # presumably the topology file this handle was resolved from; verify against caller
    # Compiled graph, or None when no graph is attached.
    # NOTE(review): the earlier comment said "None until _compile_graph is
    # implemented"; a _compile_graph implementation now exists, so confirm
    # whether None can still occur.
    topology_obj: TopologyGraph | None
+367
View File
@@ -0,0 +1,367 @@
# kernbench/topology/visualizer.py
"""
SVG diagram generator for TopologyGraph views.
Produces mm-accurate, deterministic SVG files for each view level
(system, SIP, cube, PE) per ADR-0005 and ADR-0006.
"""
from __future__ import annotations
from pathlib import Path
from .types import Edge, Node, TopologyGraph, ViewGraph
# ── Color palette by component kind ─────────────────────────────────
# Fill colour (hex) per node kind; kinds missing here fall back to a
# default fill chosen where the node is drawn.
_KIND_COLORS: dict[str, str] = {
    "switch": "#6366f1", # indigo
    "sip": "#e0e7ff", # light indigo
    "iochiplet": "#0ea5e9", # sky blue
    "pcie_ep": "#0ea5e9",  # sky blue (same as iochiplet)
    "io_cpu": "#0ea5e9",  # sky blue (same as iochiplet)
    "ucie_port": "#3b82f6", # blue
    "noc": "#a78bfa", # purple
    "m_cpu": "#f59e0b", # amber
    "xbar": "#f97316", # orange
    "hbm_ctrl": "#10b981", # emerald
    "pe": "#94a3b8", # slate
    "pe_cpu": "#ef4444", # red
    "pe_scheduler": "#f59e0b", # amber
    "pe_dma": "#3b82f6", # blue
    "pe_gemm": "#8b5cf6", # violet
    "pe_math": "#ec4899", # pink
    "pe_tcm": "#10b981", # emerald
    "sram": "#f59e0b", # amber
    "cube": "#cbd5e1", # slate-300
}
# Stroke colour (hex) per edge kind; kinds missing here fall back to a
# default stroke chosen where the edge is drawn.
_EDGE_COLORS: dict[str, str] = {
    "pcie": "#6366f1",
    "io_internal": "#0ea5e9",
    "io_to_cube": "#0ea5e9",
    "ucie_mesh": "#3b82f6",
    "pe_to_xbar": "#f97316",
    "xbar_to_hbm": "#10b981",
    "xbar_to_bridge": "#a78bfa",
    "bridge_to_xbar": "#a78bfa",
    "noc_to_ucie": "#a78bfa",
    "pe_to_noc": "#a78bfa",
    "noc_to_sram": "#f59e0b",
    "command": "#f59e0b",
    "pe_internal": "#94a3b8",
}
# ── Node sizing ──────────────────────────────────────────────────────
# Node rectangle size used for kinds that have no entry in _KIND_SIZE.
_DEFAULT_NODE_W = 2.0 # mm
_DEFAULT_NODE_H = 1.2 # mm
# Per-kind node rectangle size as (width_mm, height_mm).
_KIND_SIZE: dict[str, tuple[float, float]] = {
    "sip": (60.0, 50.0),
    "cube": (6.0, 4.0),
    "iochiplet": (4.0, 1.5),
    "switch": (5.0, 1.5),
}
# ── Public API ───────────────────────────────────────────────────────
def emit_diagrams(graph: TopologyGraph, out_dir: Path) -> list[Path]:
    """Write one SVG file per available view of ``graph`` into ``out_dir``.

    The directory is created if needed. Views that are ``None`` are skipped.
    Files are named ``<view attribute name>.svg``.

    Args:
        graph: Compiled topology whose views are rendered.
        out_dir: Destination directory.

    Returns:
        Paths of the files written, in system/sip/cube/pe order.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # Collect all four views up front so a missing attribute fails before any write.
    view_by_stem = {
        "system_view": graph.system_view,
        "sip_view": graph.sip_view,
        "cube_view": graph.cube_view,
        "pe_view": graph.pe_view,
    }

    written: list[Path] = []
    for stem, view in view_by_stem.items():
        if view is not None:
            target = out_dir / f"{stem}.svg"
            target.write_text(_render_view_svg(view), encoding="utf-8")
            written.append(target)
    return written
# ── SVG rendering ────────────────────────────────────────────────────
def _render_view_svg(view: ViewGraph) -> str:
    """Render ``view`` as a complete SVG document string.

    Drawing order is: header, background, title text, cube-only decorations,
    edges, then nodes (so nodes are painted over edges).

    Args:
        view: The view graph to draw.

    Returns:
        The SVG markup, one element per line.
    """
    scale = _pick_scale(view)
    pad = 40  # px padding around the drawing area
    node_sizes = _compute_node_sizes(view, scale)

    # Canvas size in px (fractional pixels are truncated).
    w_px = int(view.width_mm * scale + 2 * pad)
    h_px = int(view.height_mm * scale + 2 * pad)

    header = _svg_header(w_px, h_px, view.name)
    background = f' <rect width="{w_px}" height="{h_px}" fill="#f8fafc"/>'
    title = (
        f' <text x="{w_px // 2}" y="18" text-anchor="middle" '
        f'font-family="monospace" font-size="14" font-weight="bold" fill="#1e293b">'
        f'{view.name.upper()} VIEW</text>'
    )
    parts: list[str] = [header, background, title]

    # Cube view only: die outline and HBM block drawn beneath everything else.
    if view.name == "cube":
        _draw_cube_boundary(parts, view, scale, pad)
        _draw_hbm_block(parts, view, scale, pad)

    # Edges whose endpoints are both present in the view; the counter lets
    # _draw_edge assign a distinct offset to each edge of a fan-out group.
    fanout_counter: dict[str, int] = {}
    for edge in view.edges:
        if edge.src not in view.nodes or edge.dst not in view.nodes:
            continue
        _draw_edge(parts, edge, view, node_sizes, scale, pad, fanout_counter)

    for node in view.nodes.values():
        _draw_node(parts, node, node_sizes, scale, pad)

    parts.append("</svg>")
    return "\n".join(parts)
def _pick_scale(view: ViewGraph) -> float:
    """Return the drawing scale in pixels per mm for ``view``.

    Known view names have a fixed scale; any other name uses 10.0.
    """
    scale_by_view = {
        "system": 4.0,
        "sip": 8.0,
        "cube": 28.0,
        "pe": 35.0,
    }
    try:
        return scale_by_view[view.name]
    except KeyError:
        return 10.0
def _compute_node_sizes(
    view: ViewGraph, scale: float,
) -> dict[str, tuple[float, float]]:
    """Return the drawn size ``(w_px, h_px)`` of every node in ``view``.

    The size comes from the per-kind table (or the default size), with two
    overrides: PE nodes are smaller in the cube view, and every node of the
    PE view uses one common size.

    Args:
        view: The view whose nodes are sized.
        scale: Pixels per mm.
    """

    def size_mm(node: Node) -> tuple[float, float]:
        """Pick the node's size in mm, applying the view-specific overrides."""
        dims = _KIND_SIZE.get(node.kind, (_DEFAULT_NODE_W, _DEFAULT_NODE_H))
        if view.name == "cube" and node.kind == "pe":
            dims = (1.8, 1.0)
        if view.name == "pe":
            dims = (2.5, 1.4)
        return dims

    sizes: dict[str, tuple[float, float]] = {}
    for node_id, node in view.nodes.items():
        w_mm, h_mm = size_mm(node)
        sizes[node_id] = (w_mm * scale, h_mm * scale)
    return sizes
def _svg_header(w: int, h: int, title: str) -> str:
    """Return the opening ``<svg>`` tag followed by a ``<title>`` element.

    The title is XML-escaped so that a name containing ``&``, ``<`` or ``>``
    cannot produce malformed SVG. Titles without those characters are
    emitted unchanged.

    Args:
        w: Canvas width in px; also used for the viewBox.
        h: Canvas height in px; also used for the viewBox.
        title: Text placed inside ``<title>``.

    Returns:
        Two lines of markup; the caller is responsible for the closing ``</svg>``.
    """
    # Imported locally so this helper brings its own dependency into scope.
    from xml.sax.saxutils import escape

    safe_title = escape(str(title))
    return (
        f'<svg xmlns="http://www.w3.org/2000/svg" '
        f'width="{w}" height="{h}" viewBox="0 0 {w} {h}">\n'
        f' <title>{safe_title}</title>'
    )
def _draw_cube_boundary(
    parts: list[str], view: ViewGraph, scale: float, pad: int,
) -> None:
    """Append a dashed rectangle outlining the cube die to ``parts``.

    The rectangle starts at the padding offset and spans the view's full
    width and height.

    Args:
        parts: SVG fragments collected so far; one element is appended.
        view: The cube view; only its dimensions are read.
        scale: Pixels per mm.
        pad: Canvas padding in px.
    """
    outline_w = view.width_mm * scale
    outline_h = view.height_mm * scale
    outline = (
        f' <rect x="{pad:.1f}" y="{pad:.1f}" '
        f'width="{outline_w:.1f}" height="{outline_h:.1f}" '
        f'rx="6" fill="none" stroke="#475569" stroke-width="2" '
        f'stroke-dasharray="8,4"/>'
    )
    parts.append(outline)
def _draw_hbm_block(
    parts: list[str], view: ViewGraph, scale: float, pad: int,
) -> None:
    """Append the HBM area (a translucent rectangle plus an "HBM" label) to ``parts``.

    The geometry is fixed in mm and does not depend on ``view``.

    Args:
        parts: SVG fragments collected so far; two elements are appended.
        view: Unused; kept for a uniform drawing-helper signature.
        scale: Pixels per mm.
        pad: Canvas padding in px.
    """
    # HBM area in mm: x=[4.0, 13.0], y=[4.5, 9.5], i.e. a 9x5 block.
    left_mm, top_mm = 4.0, 4.5
    width_mm, height_mm = 9.0, 5.0
    # Label anchor in mm.
    label_x_mm, label_y_mm = 8.5, 8.5

    rect_x = left_mm * scale + pad
    rect_y = top_mm * scale + pad
    rect_w = width_mm * scale
    rect_h = height_mm * scale
    parts.append(
        f' <rect x="{rect_x:.1f}" y="{rect_y:.1f}" '
        f'width="{rect_w:.1f}" height="{rect_h:.1f}" '
        f'rx="4" fill="#d1fae5" stroke="#10b981" stroke-width="1.5" '
        f'stroke-dasharray="6,3" opacity="0.5"/>'
    )

    label_x = label_x_mm * scale + pad
    label_y = label_y_mm * scale + pad
    parts.append(
        f' <text x="{label_x:.1f}" y="{label_y:.1f}" text-anchor="middle" '
        f'font-family="monospace" font-size="11" fill="#047857" opacity="0.7">'
        f'HBM</text>'
    )
def _draw_node(
    parts: list[str],
    node: Node,
    sizes: dict[str, tuple[float, float]],
    scale: float,
    pad: int,
) -> None:
    """Append a rounded rectangle and a centred label for ``node`` to ``parts``.

    Nodes without a position are skipped. The rectangle is centred on the
    node position; the label falls back to the node id when the node has no
    label, and is escaped before being embedded.

    Args:
        parts: SVG fragments collected so far.
        node: The node to draw.
        sizes: Node id -> ``(w_px, h_px)``; missing ids use 40x24 px.
        scale: Pixels per mm.
        pad: Canvas padding in px.
    """
    position = node.pos_mm
    if position is None:
        return

    center_x = position[0] * scale + pad
    center_y = position[1] * scale + pad
    box_w, box_h = sizes.get(node.id, (40, 24))
    left = center_x - box_w / 2
    top = center_y - box_h / 2

    fill = _KIND_COLORS.get(node.kind, "#e2e8f0")
    # Light text on dark fills, dark text otherwise.
    if _is_dark(fill):
        text_color = "#ffffff"
    else:
        text_color = "#1e293b"

    parts.append(
        f' <rect x="{left:.1f}" y="{top:.1f}" width="{box_w:.1f}" height="{box_h:.1f}" '
        f'rx="4" fill="{fill}" stroke="#475569" stroke-width="1"/>'
    )

    text = node.label or node.id
    font_size = _label_font_size(box_w, text)
    # The baseline is nudged 4 px below the centre.
    parts.append(
        f' <text x="{center_x:.1f}" y="{center_y + 4:.1f}" text-anchor="middle" '
        f'font-family="monospace" font-size="{font_size}" fill="{text_color}">'
        f'{_escape(text)}</text>'
    )
# ── Fan-out edge kinds that need offset routing ─────────────────────
# Edge kinds that get staggered orthogonal routing when drawn in the cube view.
_FANOUT_KINDS = {"pe_to_xbar", "pe_to_noc", "command", "noc_to_ucie"}
def _draw_edge(
parts: list[str],
edge: Edge,
view: ViewGraph,
sizes: dict[str, tuple[float, float]],
scale: float,
pad: int,
fanout_counter: dict[str, int],
) -> None:
"""Draw an edge with orthogonal (90-degree) routing for fan-out kinds."""
nodes = view.nodes
src_node = nodes[edge.src]
dst_node = nodes[edge.dst]
if src_node.pos_mm is None or dst_node.pos_mm is None:
return
x1 = src_node.pos_mm[0] * scale + pad
y1 = src_node.pos_mm[1] * scale + pad
x2 = dst_node.pos_mm[0] * scale + pad
y2 = dst_node.pos_mm[1] * scale + pad
color = _EDGE_COLORS.get(edge.kind, "#94a3b8")
width = "1.5" if edge.kind == "pe_internal" else "1"
opacity = "0.6" if edge.kind in ("command", "noc_to_ucie") else "0.8"
if edge.kind in _FANOUT_KINDS and view.name == "cube":
# Orthogonal routing: src→horizontal→vertical→dst with per-edge offset.
group_key = f"{edge.kind}:{edge.dst}"
idx = fanout_counter.get(group_key, 0)
fanout_counter[group_key] = idx + 1
# Route: go vertically from src to a staggered horizontal channel,
# then horizontally to dst x, then vertically to dst.
mid_y = (y1 + y2) / 2 + (idx - 1.5) * 10 # spread channels vertically
parts.append(
f' <polyline points="{x1:.1f},{y1:.1f} {x1:.1f},{mid_y:.1f} '
f'{x2:.1f},{mid_y:.1f} {x2:.1f},{y2:.1f}" '
f'fill="none" stroke="{color}" stroke-width="{width}" opacity="{opacity}"/>'
)
# Label on the horizontal segment
if edge.distance_mm > 0:
lx = (x1 + x2) / 2
label = f"{edge.distance_mm:.1f}mm"
if edge.bw_gbs:
label += f" {edge.bw_gbs:.0f}GB/s"
parts.append(
f' <text x="{lx:.1f}" y="{mid_y - 3:.1f}" text-anchor="middle" '
f'font-family="monospace" font-size="7" fill="#64748b">'
f'{label}</text>'
)
return
# Non-fanout: orthogonal L-bend
if abs(x2 - x1) > 1 and abs(y2 - y1) > 1:
# PE view: vertical-first for left→right edges (scheduler→engines),
# horizontal-first for right→right edges (engines→tcm)
if view.name == "pe":
if src_node.pos_mm[0] < view.width_mm / 2:
# Source in left half: vertical-first (scheduler fan-out)
parts.append(
f' <polyline points="{x1:.1f},{y1:.1f} {x1:.1f},{y2:.1f} {x2:.1f},{y2:.1f}" '
f'fill="none" stroke="{color}" stroke-width="{width}" opacity="{opacity}"/>'
)
else:
# Source in right half: horizontal-first (dma/math→tcm)
parts.append(
f' <polyline points="{x1:.1f},{y1:.1f} {x2:.1f},{y1:.1f} {x2:.1f},{y2:.1f}" '
f'fill="none" stroke="{color}" stroke-width="{width}" opacity="{opacity}"/>'
)
else:
parts.append(
f' <polyline points="{x1:.1f},{y1:.1f} {x2:.1f},{y1:.1f} {x2:.1f},{y2:.1f}" '
f'fill="none" stroke="{color}" stroke-width="{width}" opacity="{opacity}"/>'
)
else:
parts.append(
f' <line x1="{x1:.1f}" y1="{y1:.1f}" x2="{x2:.1f}" y2="{y2:.1f}" '
f'stroke="{color}" stroke-width="{width}" opacity="{opacity}"/>'
)
# Distance label at midpoint
if edge.distance_mm > 0:
mx = (x1 + x2) / 2
my = (y1 + y2) / 2
label = f"{edge.distance_mm:.1f}mm"
if edge.bw_gbs:
label += f" {edge.bw_gbs:.0f}GB/s"
parts.append(
f' <text x="{mx:.1f}" y="{my - 4:.1f}" text-anchor="middle" '
f'font-family="monospace" font-size="7" fill="#64748b">'
f'{label}</text>'
)
# ── Helpers ──────────────────────────────────────────────────────────
def _is_dark(hex_color: str) -> bool:
"""Check if a hex color is dark (for white text)."""
h = hex_color.lstrip("#")
r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
return (r * 0.299 + g * 0.587 + b * 0.114) < 140
def _label_font_size(box_width: float, label: str) -> int:
"""Choose font size to fit label in box."""
char_w = len(label) * 7
if char_w > box_width * 0.9:
return max(7, int(box_width * 0.9 / len(label) * 1.4))
return 10
def _escape(text: str) -> str:
"""Escape XML special characters."""
return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
+11
View File
@@ -0,0 +1,11 @@
"""Triton emulator: fake tl module for kernel performance simulation.
Provides TLContext (the fake `tl` parameter) that kernels use to express
memory access patterns and compute operations. Kernel functions are plain
Python — no yield, no async — and generate a PeCommand trace that PE_CPU
replays through SimPy.
Usage:
from kernbench.triton_emu.registry import register_kernel, get_kernel
from kernbench.triton_emu.tl_context import TLContext
"""
+30
View File
@@ -0,0 +1,30 @@
"""Kernel registry: maps kernel names to Python callable generators.
Benchmarks register kernel functions here; PE_CPU looks them up by
KernelRef.name at execution time.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import Any
# Module-level table of registered kernels, keyed by kernel name.  Mutated
# only through register_kernel() and clear_registry().
_kernels: dict[str, Callable[..., None]] = {}
def register_kernel(name: str, fn: Callable[..., None]) -> None:
    """Add *fn* to the registry under *name*.

    Raises:
        ValueError: if *name* is already registered; the existing entry is
            left in place.
    """
    if name not in _kernels:
        _kernels[name] = fn
        return
    raise ValueError(f"kernel '{name}' already registered")
def get_kernel(name: str) -> Callable[..., None]:
    """Return the kernel function registered under *name*.

    Raises:
        KeyError: if no kernel has been registered under *name*.
    """
    if name in _kernels:
        return _kernels[name]
    raise KeyError(f"kernel '{name}' not registered")
def clear_registry() -> None:
    """Remove every registered kernel (intended for test isolation).

    The registry dict is emptied in place.
    """
    _kernels.clear()
+356
View File
@@ -0,0 +1,356 @@
"""TLContext: fake Triton Language module for kernel performance simulation.
Passed as the `tl` parameter to kernel functions. Each API call records a
PeCommand in the internal trace. After the kernel returns, PE_CPU replays
the command list through SimPy.
Kernel code looks like standard Python — no yield, no async:
def my_kernel(a_ptr, b_ptr, out_ptr, tl):
pid = tl.program_id(0)
a = tl.load(a_ptr, shape=(32, 64), dtype="f16")
b = tl.load(b_ptr + pid * stride, shape=(64, 32), dtype="f16")
tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr)
"""
from __future__ import annotations
import math
from typing import Literal
from kernbench.common.pe_commands import (
CompletionHandle,
CompositeCmd,
DmaReadCmd,
DmaWriteCmd,
GemmCmd,
MathCmd,
PeCommand,
PeCpuOverheadCmd,
TensorHandle,
WaitCmd,
)
_DTYPE_BYTES: dict[str, int] = {
"f16": 2, "f32": 4, "f64": 8,
"bf16": 2,
"i8": 1, "i16": 2, "i32": 4, "i64": 8,
"u8": 1, "u16": 2, "u32": 4, "u64": 8,
}
class TLContext:
    """Fake Triton Language context that records commands instead of running.

    Kernels receive an instance as their ``tl`` argument.  Calls that model
    engine work append PeCommand objects to an internal list (exposed via
    ``commands``); calls that only build metadata return a TensorHandle and
    record nothing.

    Args:
        pe_id: program instance index (returned by program_id).
        num_programs: total number of program instances.
        dispatch_cycles: cycle count of the PeCpuOverheadCmd recorded ahead
            of each DMA, GEMM, math and composite command; a value <= 0
            disables it.
    """

    def __init__(
        self,
        pe_id: int = 0,
        num_programs: int = 1,
        dispatch_cycles: int = 1,
    ) -> None:
        self._pe_id = pe_id
        self._num_programs = num_programs
        self._dispatch_cycles = dispatch_cycles
        self._commands: list[PeCommand] = []
        self._handle_counter = 0
        self._completion_counter = 0

    @property
    def commands(self) -> list[PeCommand]:
        """The live list of recorded commands (not a copy)."""
        return self._commands

    # ── helpers ────────────────────────────────────────────────────

    def _next_handle_id(self) -> str:
        """Return the next tensor-handle id: "t1", "t2", ..."""
        serial = self._handle_counter + 1
        self._handle_counter = serial
        return f"t{serial}"

    def _next_completion_id(self) -> str:
        """Return the next completion-handle id: "c1", "c2", ..."""
        serial = self._completion_counter + 1
        self._completion_counter = serial
        return f"c{serial}"

    def _dtype_bytes(self, dtype: str) -> int:
        """Element size of *dtype*; names missing from _DTYPE_BYTES count as 2."""
        return _DTYPE_BYTES.get(dtype, 2)

    def _nbytes(self, shape: tuple[int, ...], dtype: str) -> int:
        """Total byte size of a tensor with the given shape and dtype."""
        element_count = math.prod(shape)
        return element_count * self._dtype_bytes(dtype)

    def _emit_dispatch_overhead(self) -> None:
        """Record the per-call PE_CPU overhead, if one is configured."""
        cost = self._dispatch_cycles
        if cost > 0:
            self._commands.append(PeCpuOverheadCmd(cycles=cost))

    def _make_handle(
        self, pa: int, shape: tuple[int, ...], dtype: str,
    ) -> TensorHandle:
        """Build a TensorHandle with a fresh id and a computed byte size."""
        # The id is allocated before the size is computed, so the counter
        # advances even if the shape turns out to be unusable.
        handle_id = self._next_handle_id()
        size = self._nbytes(shape, dtype)
        return TensorHandle(
            id=handle_id, pa=pa, shape=shape, dtype=dtype, nbytes=size,
        )

    # ── Reference (no DMA, metadata only) ────────────────────────

    def ref(
        self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
    ) -> TensorHandle:
        """Create a TensorHandle referencing HBM data without issuing DMA.

        Used when the scheduler will stream data per-tile (e.g., tensor b
        in a composite GEMM). No command is generated.
        """
        return self._make_handle(pa=ptr, shape=shape, dtype=dtype)

    # ── Data Movement (blocking, DMA engine) ──────────────────────

    def load(
        self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
    ) -> TensorHandle:
        """Load tensor from HBM to TCM. Returns TensorHandle.

        Records the dispatch overhead followed by a DmaReadCmd for the
        handle's full byte size.
        """
        self._emit_dispatch_overhead()
        tile = self._make_handle(pa=ptr, shape=shape, dtype=dtype)
        read = DmaReadCmd(handle=tile, src_pa=ptr, nbytes=tile.nbytes)
        self._commands.append(read)
        return tile

    def store(self, ptr: int, handle: TensorHandle) -> None:
        """Store tensor from TCM to HBM.

        Records the dispatch overhead followed by a DmaWriteCmd for the
        handle's full byte size.
        """
        self._emit_dispatch_overhead()
        write = DmaWriteCmd(handle=handle, dst_pa=ptr, nbytes=handle.nbytes)
        self._commands.append(write)

    # ── GEMM Engine (blocking) ────────────────────────────────────

    def dot(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
        """Matrix multiply: out = a @ b. Both operands must be in TCM.

        a: (..., M, K), b: (..., K, N) → out: (*a's leading dims, M, N),
        with a's dtype.

        Raises:
            ValueError: if either operand has fewer than two dimensions or
                the inner dimensions disagree.
        """
        if len(a.shape) < 2 or len(b.shape) < 2:
            raise ValueError("dot requires 2D tensors")
        *batch, m, k = a.shape
        k_rhs, n = b.shape[-2:]
        if k != k_rhs:
            raise ValueError(f"dot shape mismatch: a.K={k} != b.K={k_rhs}")

        result = self._make_handle(pa=0, shape=(*batch, m, n), dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._commands.append(GemmCmd(a=a, b=b, out=result, m=m, k=k, n=n))
        return result

    # ── MATH Engine: unary (blocking) ─────────────────────────────

    def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
        """Record a one-input MathCmd; the output mirrors x's shape and dtype."""
        result = self._make_handle(pa=0, shape=x.shape, dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._commands.append(MathCmd(op=op, inputs=(x,), out=result))
        return result

    def exp(self, x: TensorHandle) -> TensorHandle:
        """Record a unary "exp" MathCmd on *x*."""
        return self._unary_math("exp", x)

    def log(self, x: TensorHandle) -> TensorHandle:
        """Record a unary "log" MathCmd on *x*."""
        return self._unary_math("log", x)

    def sqrt(self, x: TensorHandle) -> TensorHandle:
        """Record a unary "sqrt" MathCmd on *x*."""
        return self._unary_math("sqrt", x)

    def abs(self, x: TensorHandle) -> TensorHandle:
        """Record a unary "abs" MathCmd on *x*."""
        return self._unary_math("abs", x)

    def sigmoid(self, x: TensorHandle) -> TensorHandle:
        """Record a unary "sigmoid" MathCmd on *x*."""
        return self._unary_math("sigmoid", x)

    def cos(self, x: TensorHandle) -> TensorHandle:
        """Record a unary "cos" MathCmd on *x*."""
        return self._unary_math("cos", x)

    def sin(self, x: TensorHandle) -> TensorHandle:
        """Record a unary "sin" MathCmd on *x*."""
        return self._unary_math("sin", x)

    # ── MATH Engine: reduction (blocking) ─────────────────────────

    def _reduction(
        self, op: str, x: TensorHandle, axis: int,
    ) -> TensorHandle:
        """Record a reduction MathCmd; the reduced axis is kept with size 1."""
        dims = [*x.shape]
        dims[axis] = 1
        result = self._make_handle(pa=0, shape=tuple(dims), dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._commands.append(
            MathCmd(op=op, inputs=(x,), out=result, axis=axis)
        )
        return result

    def sum(self, x: TensorHandle, axis: int) -> TensorHandle:
        """Record a "sum" reduction of *x* along *axis*."""
        return self._reduction("sum", x, axis)

    def max(self, x: TensorHandle, axis: int) -> TensorHandle:
        """Record a "max" reduction of *x* along *axis*."""
        return self._reduction("max", x, axis)

    def min(self, x: TensorHandle, axis: int) -> TensorHandle:
        """Record a "min" reduction of *x* along *axis*."""
        return self._reduction("min", x, axis)

    # ── MATH Engine: binary (blocking) ────────────────────────────

    def _binary_math(
        self, op: str, a: TensorHandle, b: TensorHandle,
    ) -> TensorHandle:
        """Record a two-input MathCmd; the output takes a's shape and dtype."""
        result = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._commands.append(MathCmd(op=op, inputs=(a, b), out=result))
        return result

    def where(
        self, cond: TensorHandle, a: TensorHandle, b: TensorHandle,
    ) -> TensorHandle:
        """Record a "where" MathCmd over (cond, a, b); output is shaped like a."""
        result = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._commands.append(
            MathCmd(op="where", inputs=(cond, a, b), out=result)
        )
        return result

    # ── Index / Scalar (PE_CPU, no engine) ────────────────────────

    def program_id(self, axis: int = 0) -> int:
        """Return program instance index (*axis* is accepted but not used)."""
        return self._pe_id

    def num_programs(self, axis: int = 0) -> int:
        """Return total number of program instances (*axis* is not used)."""
        return self._num_programs

    def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle:
        """Create index range tensor in TCM.

        Only a handle of shape ``(end - start,)`` is built; no command is
        recorded.
        """
        return self._make_handle(pa=0, shape=(end - start,), dtype=dtype)

    def zeros(self, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
        """Create zero-filled tensor in TCM (handle only, no command)."""
        return self._make_handle(pa=0, shape=shape, dtype=dtype)

    def full(
        self, shape: tuple[int, ...], value: float | int, dtype: str = "f16",
    ) -> TensorHandle:
        """Create constant-filled tensor in TCM (handle only, no command).

        *value* is not stored on the handle.
        """
        return self._make_handle(pa=0, shape=shape, dtype=dtype)

    # ── Metadata (no compute, no DMA) ─────────────────────────────

    def trans(self, x: TensorHandle) -> TensorHandle:
        """Transpose — shape change only, no command generated.

        The returned handle swaps the last two dimensions and otherwise
        copies x's fields, including its id.

        Raises:
            ValueError: if *x* has fewer than two dimensions.
        """
        if len(x.shape) < 2:
            raise ValueError("trans requires at least 2D tensor")
        *lead, rows, cols = x.shape
        return TensorHandle(
            id=x.id,
            pa=x.pa,
            shape=(*lead, cols, rows),
            dtype=x.dtype,
            nbytes=x.nbytes,
            data=x.data,
        )

    # ── Composite + Control ───────────────────────────────────────

    def composite(
        self,
        op: Literal["gemm", "math"],
        a: TensorHandle,
        b: TensorHandle | None = None,
        out_ptr: int = 0,
        math_op: str | None = None,
    ) -> CompletionHandle:
        """Submit a composite command (non-blocking, tiled pipeline).

        Returns CompletionHandle for use with wait().

        The output size is ``M * N * sizeof(a.dtype)`` for a gemm with *b*
        given (M from a's second-to-last dim, N from b's last dim); in every
        other case it equals ``a.nbytes``.
        """
        # NOTE(review): unlike dot(), leading (batch) dims of `a` do not
        # contribute to the gemm output size here — confirm this is intended.
        if op != "gemm" or b is None:
            out_nbytes = a.nbytes
        else:
            rows = a.shape[-2]
            cols = b.shape[-1]
            out_nbytes = rows * cols * self._dtype_bytes(a.dtype)

        ticket = CompletionHandle(id=self._next_completion_id())
        self._emit_dispatch_overhead()
        self._commands.append(
            CompositeCmd(
                completion=ticket,
                op=op,
                a=a,
                b=b,
                out_pa=out_ptr,
                out_nbytes=out_nbytes,
                math_op=math_op,
            )
        )
        return ticket

    def wait(self, handle: CompletionHandle | None = None) -> None:
        """Wait for a specific composite or all pending composites.

        Records a WaitCmd only; no dispatch overhead is added.
        """
        self._commands.append(WaitCmd(handle=handle))

    def cycles(self, n: int) -> None:
        """Declare PE_CPU scalar execution overhead (cycles).

        Records exactly one PeCpuOverheadCmd of *n* cycles.
        """
        self._commands.append(PeCpuOverheadCmd(cycles=n))
# ── TensorHandle arithmetic operators ─────────────────────────────
# Enables: a + b, a * b, a - b, a / b in kernel code.
# Each operator records a MathCmd through whichever TLContext is currently
# active; the active context lives in a thread-local slot that the operator
# closures read at call time (run_kernel sets and clears it).
def _enable_tensor_ops() -> None:
    """Install ``+ - * /`` on TensorHandle and context accessors on TLContext.

    Called once at module load.  The four operators are attached to the
    TensorHandle class; each one looks up the active TLContext (stored in a
    thread-local created here) and forwards to its ``_binary_math`` with op
    "add", "sub", "mul" or "div".  The setter/getter for the active
    context are exposed as the static methods ``TLContext._set_active`` and
    ``TLContext._get_active``.
    """
    import threading

    _state = threading.local()

    def set_active_context(ctx: TLContext | None) -> None:
        _state.ctx = ctx

    def get_active_context() -> TLContext:
        # The attribute is absent on threads that never set a context.
        active = getattr(_state, "ctx", None)
        if active is not None:
            return active
        raise RuntimeError("TensorHandle ops require an active TLContext")

    def _binop(op: str):
        def method(self: TensorHandle, other: TensorHandle) -> TensorHandle:
            return get_active_context()._binary_math(op, self, other)
        return method

    # Patch TensorHandle class with operators
    for dunder, op_name in (
        ("__add__", "add"),
        ("__sub__", "sub"),
        ("__mul__", "mul"),
        ("__truediv__", "div"),
    ):
        setattr(TensorHandle, dunder, _binop(op_name))

    # Expose context management
    TLContext._set_active = staticmethod(set_active_context)  # type: ignore[attr-defined]
    TLContext._get_active = staticmethod(get_active_context)  # type: ignore[attr-defined]
# Install the TensorHandle operators and the TLContext accessors at import.
_enable_tensor_ops()
def run_kernel(
    kernel_fn,
    tl_ctx: TLContext,
    *args,
    **kwargs,
) -> list[PeCommand]:
    """Execute a kernel function with the given TLContext and return commands.

    *tl_ctx* is made the active context for TensorHandle operators while the
    kernel runs.  Afterwards the context that was active on entry is put
    back (None when there was none), so a kernel that itself calls
    run_kernel does not lose its own active context.

    Args:
        kernel_fn: Callable invoked as ``kernel_fn(*args, tl=tl_ctx, **kwargs)``.
        tl_ctx: Context that records the kernel's commands.
        *args: Positional arguments forwarded to the kernel.
        **kwargs: Keyword arguments forwarded to the kernel.

    Returns:
        ``tl_ctx.commands`` — the context's own command list, not a copy.

    Any exception raised by the kernel propagates after the previous active
    context has been restored.
    """
    try:
        previous = TLContext._get_active()  # type: ignore[attr-defined]
    except RuntimeError:
        # _get_active raises when no context is currently active.
        previous = None

    TLContext._set_active(tl_ctx)  # type: ignore[attr-defined]
    try:
        kernel_fn(*args, tl=tl_ctx, **kwargs)
    finally:
        TLContext._set_active(previous)  # type: ignore[attr-defined]
    return tl_ctx.commands