commit - release 1
This commit is contained in:
@@ -0,0 +1,64 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from benches.loader import resolve_bench
|
||||
from kernbench.cli.probe import cmd_probe
|
||||
from kernbench.cli.report import format_report
|
||||
from kernbench.common.types import SimEngine
|
||||
from kernbench.runtime_api.bench_runner import run_bench
|
||||
from kernbench.runtime_api.types import DeviceSelector, resolve_device
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import resolve_topology
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
    """Build the ``kernbench`` command-line parser.

    A sub-command is mandatory. Each sub-parser stores its handler function
    under the ``_handler`` default so the caller can dispatch on it:

    * ``run``   -- required ``--topology`` and ``--bench``, optional
      ``--device`` (defaults to ``None``); handled by ``cmd_run``.
    * ``probe`` -- required ``--topology``, optional ``--case`` (defaults to
      ``"all"``); handled by ``cmd_probe``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(prog="kernbench")
    commands = parser.add_subparsers(dest="cmd", required=True)

    # -- run ---------------------------------------------------------
    run_cmd = commands.add_parser("run", help="Run a benchmark")
    for required_flag in ("--topology", "--bench"):
        run_cmd.add_argument(required_flag, required=True)
    run_cmd.add_argument(
        "--device",
        default=None,
        help="Target device: 'all' or 'sip:<N>' (default: all)",
    )
    run_cmd.set_defaults(_handler=cmd_run)

    # -- probe -------------------------------------------------------
    probe_cmd = commands.add_parser(
        "probe",
        help="Probe latency and BW for predefined traffic patterns",
    )
    probe_cmd.add_argument("--topology", required=True)
    probe_cmd.add_argument(
        "--case",
        default="all",
        help="Case name or 'all' (default: all)",
    )
    probe_cmd.set_defaults(_handler=cmd_probe)

    return parser
|
||||
|
||||
|
||||
def engine_factory(topology: object, device: DeviceSelector) -> SimEngine:
    """Create the ``GraphEngine`` for a benchmark run.

    ``topology`` may be the graph itself or a wrapper that exposes it as
    ``topology_obj``; the wrapped object is used when that attribute exists.
    ``device`` is part of the factory signature but is not consulted here.
    """
    graph = getattr(topology, "topology_obj", topology)
    engine = GraphEngine(graph)
    return engine
|
||||
|
||||
|
||||
def cmd_run(args) -> int:
    """Handle ``kernbench run``: execute one benchmark and print its results.

    Resolves the topology, benchmark and device named on the command line,
    runs the benchmark through ``run_bench`` and prints the formatted trace
    report (only when traces were collected) followed by the result summary.

    Returns:
        ``0`` when ``result.completion.ok`` is true, otherwise ``1``.
    """
    print("> Running benchmark with:", args)

    topo = resolve_topology(args.topology)
    bench = resolve_bench(args.bench)
    device = resolve_device(args.device)

    result = run_bench(
        topology=topo,
        bench_fn=bench,
        device=device,
        engine_factory=engine_factory,
    )

    # The report formatter takes the topology spec, if one is available.
    graph = getattr(topo, "topology_obj", topo)
    spec = getattr(graph, "spec", None)
    if result.traces:
        report = format_report(result.traces, title=args.bench, spec=spec)
        print(report)
    print(result.summary_text())

    if result.completion.ok:
        return 0
    return 1
|
||||
|
||||
|
||||
def main(argv=None) -> int:
    """Parse ``argv`` and dispatch to the selected sub-command handler.

    ``argv=None`` lets argparse read the process arguments. The handler's
    return value is converted to ``int`` and returned as the exit status.
    """
    args = build_parser().parse_args(argv)
    handler = args._handler  # installed by build_parser via set_defaults
    return int(handler(args))
|
||||
|
||||
|
||||
# Script entry point: exit the process with main()'s return value.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
@@ -0,0 +1,248 @@
|
||||
"""kernbench probe: latency and BW verification utility.
|
||||
|
||||
Runs predefined traffic patterns through the simulation engine and reports
|
||||
latency, effective bandwidth, bottleneck bandwidth, and utilization for each
|
||||
case. Validates monotonicity invariants across hop counts and access types.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
from kernbench.policy.routing.router import AddressResolver, PathRouter
|
||||
from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg
|
||||
from kernbench.sim_engine.engine import GraphEngine
|
||||
from kernbench.topology.builder import load_topology
|
||||
from kernbench.topology.types import TopologyGraph
|
||||
|
||||
|
||||
# -- Helpers ----------------------------------------------------------
|
||||
|
||||
|
||||
def _hbm_pa(sip: int, cube: int, pe_id: int, spec: dict) -> int:
    """Return the encoded physical address of a probe target in PE-local HBM.

    The slice size is taken from ``spec["cube"]["memory_map"]`` as
    ``hbm_total_gb_per_cube * 2**30 // hbm_slices_per_cube``. The address is
    built for rack 0 with a fixed PE-local offset of ``0x1000``.
    """
    memory_map = spec["cube"]["memory_map"]
    cube_bytes = memory_map["hbm_total_gb_per_cube"] * (1 << 30)
    slice_bytes = cube_bytes // memory_map["hbm_slices_per_cube"]
    return PhysAddr.pe_hbm_addr(
        rack_id=0,
        sip_id=sip,
        cube_id=cube,
        pe_id=pe_id,
        pe_local_hbm_offset=0x1000,
        slice_size_bytes=slice_bytes,
    ).encode()
|
||||
|
||||
|
||||
def _build_edge_map(graph: TopologyGraph) -> dict[tuple[str, str], object]:
    """Index the graph's edges by their ``(src, dst)`` node-id pair.

    When several edges share the same pair, the one listed last wins.
    """
    edge_map: dict[tuple[str, str], object] = {}
    for edge in graph.edges:
        edge_map[(edge.src, edge.dst)] = edge
    return edge_map
|
||||
|
||||
|
||||
def _formula_breakdown(
    path: list[str], nbytes: int, edge_map: dict, graph: TopologyGraph,
) -> tuple[float, float, float, float]:
    """Return ``(wire_ns, overhead_ns, drain_ns, formula_ns)`` for a path.

    * ``wire_ns``     -- sum of ``distance_mm * ns_per_mm`` over the hops found
      in ``edge_map``; ``ns_per_mm`` comes from ``graph.spec["system"]`` and
      defaults to 0.01.
    * ``overhead_ns`` -- sum of the ``overhead_ns`` attribute of every path
      node present in ``graph.nodes`` (0.0 when the attribute is absent).
    * ``drain_ns``    -- ``nbytes`` divided by the smallest non-zero hop
      bandwidth, or 0.0 when no hop has a bandwidth.
    * ``formula_ns``  -- the sum of the three terms above.

    Hops and nodes missing from the lookup tables are skipped.
    """
    ns_per_mm = graph.spec.get("system", {}).get("ns_per_mm", 0.01)
    # One edge lookup per consecutive node pair; None where no edge exists.
    hops = [edge_map.get(pair) for pair in zip(path, path[1:])]

    wire_ns = 0.0
    for hop in hops:
        if hop:
            wire_ns += hop.distance_mm * ns_per_mm

    overhead_ns = 0.0
    for node_id in path:
        node = graph.nodes.get(node_id)
        if node:
            overhead_ns += float(node.attrs.get("overhead_ns", 0.0))

    link_bws = [hop.bw_gbs for hop in hops if hop and hop.bw_gbs]
    if link_bws:
        drain_ns = nbytes / min(link_bws)
    else:
        drain_ns = 0.0

    return wire_ns, overhead_ns, drain_ns, wire_ns + overhead_ns + drain_ns
|
||||
|
||||
|
||||
def _bottleneck_bw(path: list[str], edge_map: dict) -> float | None:
    """Return the smallest hop bandwidth along ``path``.

    This is the per-request bottleneck (a single request uses one
    connection). Hops that are missing from ``edge_map`` or whose ``bw_gbs``
    is falsy are ignored; ``None`` is returned when no hop qualifies.
    """
    hops = (edge_map.get(pair) for pair in zip(path, path[1:]))
    link_bws = [hop.bw_gbs for hop in hops if hop and hop.bw_gbs]
    return min(link_bws, default=None)
|
||||
|
||||
|
||||
|
||||
def _fmt_bw(bw: float | None) -> str:
    """Format a bandwidth with one decimal place, or ``"-"`` for ``None``."""
    if bw is None:
        return "-"
    return f"{bw:.1f}"
|
||||
|
||||
|
||||
def _fmt_util(eff: float, bn: float | None) -> str:
    """Format ``eff`` as a percentage of ``bn`` with one decimal place.

    Returns ``"-"`` when the bottleneck bandwidth is ``None`` or not positive.
    """
    no_bottleneck = bn is None or bn <= 0
    return "-" if no_bottleneck else f"{eff / bn * 100:.1f}%"
|
||||
|
||||
|
||||
def _short_name(node_id: str) -> str:
    """Shorten a dotted node id to its last two segments.

    Two segments are kept (rather than one) to avoid ambiguity such as
    ``xbar.pe0`` vs ``pe0``. Ids with fewer than two segments are returned
    unchanged.
    """
    tail = node_id.rsplit(".", 2)[-2:]
    return ".".join(tail)
|
||||
|
||||
|
||||
def _short_path(path: list[str]) -> str:
    """Render a node path as shortened ids joined by ``" -> "``."""
    short_ids = [_short_name(node_id) for node_id in path]
    return " -> ".join(short_ids)
|
||||
|
||||
|
||||
# -- Probe runner -----------------------------------------------------
|
||||
|
||||
|
||||
def run_probe(topology_path: str, case_filter: str | None = None) -> int:
    """Run the predefined probe cases on a topology and print the results.

    Two groups of cases are executed, each case on a fresh ``GraphEngine``
    with a 4096-byte transfer:

    * H2D writes (``MemoryWriteMsg``) to PE 0 of several cubes on SIP 0,
      followed by a check that latency strictly increases from one case to
      the next (only evaluated when at least two H2D cases ran).
    * PE DMA transfers (``PeDmaMsg``) from a source PE to an HBM slice.

    For every case the measured latency, the formula breakdown of the
    forward path, effective bandwidth, bottleneck bandwidth and utilization
    are printed, followed by the route taken.

    Args:
        topology_path: Path to the topology file; ``~`` is expanded.
        case_filter: Name of a single case to run, or ``None`` / ``"all"``
            to run every case.

    Returns:
        ``0`` on success; ``1`` when ``case_filter`` names no known case or
        when the H2D monotonic-latency check fails, so the process exit
        status reflects the validation outcome.
    """
    # (name, dst_cube, hops)
    h2d_cases = [
        ("h2d-1hop", 0, 1),
        ("h2d-2hop", 4, 2),
        ("h2d-3hop", 8, 3),
        ("h2d-4hop", 12, 4),
    ]
    # (name, sip, src_cube, src_pe, dst_cube, dst_pe)
    pe_cases = [
        ("pe-local-hbm", 0, 0, 0, 0, 0),  # pe0 → slice0 (local, 256 GB/s)
        ("pe-same-half-hbm", 0, 0, 0, 0, 1),  # pe0 → slice1 (xbar chain, 128 GB/s)
        ("pe-cross-half-hbm", 0, 0, 0, 0, 4),  # pe0 → slice4 (xbar chain, 128 GB/s)
        ("pe-cross-cube-hbm", 0, 0, 0, 1, 0),  # cube0.pe0 → cube1.slice0 (NOC, 128 GB/s)
    ]

    show_all = case_filter is None or case_filter == "all"
    if not show_all:
        # Reject a misspelled case up front instead of printing empty tables.
        known = [case[0] for case in h2d_cases] + [case[0] for case in pe_cases]
        if case_filter not in known:
            print(f"Unknown probe case '{case_filter}'. Valid cases: all, {', '.join(known)}")
            return 1

    path = Path(topology_path).expanduser().resolve()
    graph = load_topology(path)
    edge_map = _build_edge_map(graph)
    spec = graph.spec
    resolver = AddressResolver(graph)
    router = PathRouter(graph)

    nbytes = 4096
    ok = True

    # === H2D Write ===
    h2d_latencies: list[float] = []
    h2d_paths: list[tuple[str, list[str], list[str], list[str]]] = []

    print()
    print("=== H2D Write Latency (IO->HBM, varying hop count) ===")
    print(f" {'Case':<14} {'Target':<16} {'Hops':>4} {'Actual':>8}"
          f" {'Ovhd':>6} {'Drain':>6} {'Wire':>5} {'Ovhd%':>6} {'Drain%':>7}"
          f" {'Eff.BW':>8} {'BN.BW':>8} {'Util%':>6}")
    print(" " + "-" * 115)

    for name, cube, hops in h2d_cases:
        if not show_all and case_filter != name:
            continue
        engine = GraphEngine(graph)
        pa = _hbm_pa(sip=0, cube=cube, pe_id=0, spec=spec)
        msg = MemoryWriteMsg(
            correlation_id="probe", request_id=name,
            dst_sip=0, dst_cube=cube, dst_pe=0,
            dst_pa=pa, nbytes=nbytes, pattern="zero",
        )
        h = engine.submit(msg)
        engine.wait(h)
        _, trace = engine.get_completion(h)
        total_ns = trace["total_ns"]
        eff_bw = nbytes / total_ns if total_ns > 0 else 0.0

        pa_obj = PhysAddr.decode(pa)
        dst_node = resolver.resolve(pa_obj)

        pcie_ep = resolver.find_pcie_ep(0)
        io_cpu = resolver.find_io_cpu(0)
        m_cpu = resolver.find_m_cpu(0, cube)
        leg1 = router.find_node_path(pcie_ep, io_cpu)
        leg2 = router.find_node_path(io_cpu, m_cpu)
        leg3 = router.find_mcpu_dma_path(m_cpu, dst_node)

        # Forward path only (the response path is implicit in total_ns).
        # Legs share their junction node, so drop the duplicate when joining.
        fwd_path = leg1 + leg2[1:] + leg3[1:]
        bn_bw = _bottleneck_bw(fwd_path, edge_map)
        wire, ovhd, drain, _ = _formula_breakdown(fwd_path, nbytes, edge_map, graph)

        ovhd_pct = ovhd / total_ns * 100 if total_ns > 0 else 0
        drain_pct = drain / total_ns * 100 if total_ns > 0 else 0

        h2d_latencies.append(total_ns)
        h2d_paths.append((name, leg1, leg2, leg3))
        print(f" {name:<14} cube{cube}.pe0{'':<8} {hops:>4} {total_ns:>8.2f}"
              f" {ovhd:>6.1f} {drain:>6.1f} {wire:>5.2f} {ovhd_pct:>5.1f}% {drain_pct:>5.1f}%"
              f" {eff_bw:>8.2f} {_fmt_bw(bn_bw):>8} {_fmt_util(eff_bw, bn_bw):>6}")

    if len(h2d_latencies) >= 2:
        mono = all(a < b for a, b in zip(h2d_latencies, h2d_latencies[1:]))
        sym = "[v]" if mono else "[x]"
        print(f" {sym} Monotonic increase: {'PASS' if mono else 'FAIL'}")
        if not mono:
            ok = False  # surfaced through the return value

    if h2d_paths:
        print()
        print(" Route Details:")
        print(f" {'Case':<14} {'Leg':>4} Path")
        print(" " + "-" * 80)
        for name, leg1, leg2, leg3 in h2d_paths:
            print(f" {name:<14} {'L1':>4} {_short_path(leg1)}")
            print(f" {'':<14} {'L2':>4} {_short_path(leg2)}")
            print(f" {'':<14} {'L3':>4} {_short_path(leg3)}")

    # === PE DMA → HBM (direct PE-level injection) ===
    pe_results: list[tuple[str, float, float, float | None]] = []
    pe_paths: list[tuple[str, list[str]]] = []

    print()
    print("=== PE DMA Latency (pe_dma -> xbar -> HBM, direct injection) ===")
    print(f" {'Case':<22} {'Target':<28} {'Actual':>8}"
          f" {'Ovhd':>6} {'Drain':>6} {'Wire':>5} {'Ovhd%':>6} {'Drain%':>7}"
          f" {'Eff.BW':>8} {'BN.BW':>8} {'Util%':>6}")
    print(" " + "-" * 120)

    for name, sip, src_cube, src_pe, dst_cube, dst_pe in pe_cases:
        if not show_all and case_filter != name:
            continue
        engine = GraphEngine(graph)
        dst_pa = _hbm_pa(sip=sip, cube=dst_cube, pe_id=dst_pe, spec=spec)
        msg = PeDmaMsg(
            correlation_id="probe", request_id=name,
            src_sip=sip, src_cube=src_cube, src_pe=src_pe,
            dst_pa=dst_pa, nbytes=nbytes,
        )
        h = engine.submit(msg)
        engine.wait(h)
        _, trace = engine.get_completion(h)
        total_ns = trace["total_ns"]
        eff_bw = nbytes / total_ns if total_ns > 0 else 0.0

        pe_ref = f"sip{sip}.cube{src_cube}.pe{src_pe}"
        pa_obj = PhysAddr.decode(dst_pa)
        dst_node = resolver.resolve(pa_obj)
        dma_path = router.find_path(pe_ref, dst_node)
        bn_bw = _bottleneck_bw(dma_path, edge_map)

        wire, ovhd, drain, _ = _formula_breakdown(dma_path, nbytes, edge_map, graph)

        ovhd_pct = ovhd / total_ns * 100 if total_ns > 0 else 0
        drain_pct = drain / total_ns * 100 if total_ns > 0 else 0

        target_str = f"c{src_cube}.pe{src_pe}->c{dst_cube}.slice{dst_pe}"
        pe_results.append((name, total_ns, eff_bw, bn_bw))
        pe_paths.append((name, dma_path))
        print(f" {name:<22} {target_str:<28} {total_ns:>8.2f}"
              f" {ovhd:>6.1f} {drain:>6.1f} {wire:>5.2f} {ovhd_pct:>5.1f}% {drain_pct:>5.1f}%"
              f" {eff_bw:>8.2f} {_fmt_bw(bn_bw):>8} {_fmt_util(eff_bw, bn_bw):>6}")

    if len(pe_results) >= 2:
        # Compare the first "local" case with the first non-local case.
        local = [r for r in pe_results if "local" in r[0]]
        chain = [r for r in pe_results if "local" not in r[0]]
        if local and chain:
            print(f" * Local BN: {_fmt_bw(local[0][3])} GB/s, "
                  f"Chain/NOC BN: {_fmt_bw(chain[0][3])} GB/s")

    if pe_paths:
        print()
        print(" Route Details:")
        print(f" {'Case':<22} Path")
        print(" " + "-" * 80)
        for name, dma_path in pe_paths:
            print(f" {name:<22} {_short_path(dma_path)}")

    print()
    return 0 if ok else 1
|
||||
|
||||
|
||||
def cmd_probe(args) -> int:
    """Handle ``kernbench probe`` by delegating to ``run_probe``.

    ``args.case`` is optional; ``"all"`` is used when the attribute is absent.
    """
    topology_path = args.topology
    case_name = getattr(args, "case", "all")
    return run_probe(topology_path, case_name)
|
||||
@@ -0,0 +1,175 @@
|
||||
"""Performance report formatter for bench results."""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
# Width in bits of each supported dtype name, including common aliases.
_DTYPE_BITS: dict[str, int] = {
    # 16-bit floating point
    "f16": 16,
    "fp16": 16,
    "float16": 16,
    "bf16": 16,
    # 32-bit floating point
    "f32": 32,
    "fp32": 32,
    "float32": 32,
    # integers
    "i8": 8,
    "int8": 8,
    "i16": 16,
    "int16": 16,
    "i32": 32,
    "int32": 32,
}
|
||||
|
||||
|
||||
def format_report(
    traces: list[dict],
    title: str = "Benchmark",
    spec: dict | None = None,
) -> str:
    """Render collected trace dicts as a plain-text performance report.

    Traces whose ``phase`` is ``"kernel"`` go into the kernel table; all
    others go into the command table. A per-PE summary is appended when the
    non-kernel traces touch more than one PE.

    Args:
        traces: Trace dictionaries; missing keys fall back to defaults.
        title: Name shown in the report's title line.
        spec: Topology spec used to look up peak TFLOPS / HBM bandwidth and
            the PE count; may be ``None``.

    Returns:
        The report as a single newline-joined string.
    """
    peak_tflops_f16, peak_hbm_bw_gbs = _extract_peaks(spec)
    num_pes = _count_pes(spec)

    # Split by phase, preserving the incoming order inside each group.
    deploy_entries: list[dict] = []
    kernel_entries: list[dict] = []
    for trace in traces:
        bucket = kernel_entries if trace.get("phase") == "kernel" else deploy_entries
        bucket.append(trace)

    # ── Title ──
    # The command-table header doubles as the width of the title rule.
    cmd_hdr = (f"{'Cmd':<10} {'Name':<12} {'SIP':>4} {'Cube':>5} {'PE':>4} {'Bytes':>10} "
               f"{'Lat(ns)':>10} {'Xfer(ns)':>10} {'Proc(ns)':>10} "
               f"{'BW(GB/s)':>10} {'MinBW':>10} {'Util%':>7}")
    title_line = f"-- {title} Performance Report "
    out: list[str] = [title_line + "-" * max(0, len(cmd_hdr) - len(title_line))]

    # ── Command summary ──
    if deploy_entries:
        out.extend(["", cmd_hdr, "-" * len(cmd_hdr)])
        for entry in deploy_entries:
            lat = entry.get("total_ns", 0.0)
            nb = entry.get("nbytes", 0)
            sip = entry.get("sip", "-")
            pe = entry.get("pe", "-")
            cube = entry.get("cube", "-")
            cmd = entry.get("phase", "deploy")
            xfer_ns = entry.get("xfer_ns", 0.0)
            # Without a transfer time, processing time and MinBW are shown as 0.
            if xfer_ns > 0:
                proc_ns = lat - xfer_ns
            else:
                proc_ns = 0.0
            bw = nb / lat if lat > 0 else 0.0
            min_bw = nb / xfer_ns if xfer_ns > 0 else 0.0
            util = (xfer_ns / lat * 100) if lat > 0 and xfer_ns > 0 else 0.0
            out.append(
                f"{cmd:<10} {entry.get('name', '?'):<12} {str(sip):>4} {str(cube):>5} {str(pe):>4} {nb:>10} "
                f"{lat:>10.1f} {xfer_ns:>10.1f} {proc_ns:>10.1f} "
                f"{bw:>10.1f} {min_bw:>10.1f} {util:>6.1f}%"
            )

    # ── Kernel summary ──
    if kernel_entries:
        k_hdr = (f"{'Phase':<10} {'Name':<12} {'PE':>4} {'E2E(ns)':>10} "
                 f"{'PE(ns)':>10} {'DMA(ns)':>10} {'Comp(ns)':>10} "
                 f"{'Bound':<8} {'TFLOPS':>8} {'Peak':>8} {'Util%':>7}")
        out.extend(["", k_hdr, "-" * len(k_hdr)])
        for entry in kernel_entries:
            e2e_ns = entry.get("total_ns", 0.0)
            pe_ns = entry.get("pe_exec_ns", e2e_ns)
            dma_ns = entry.get("dma_ns", 0.0)
            compute_ns = entry.get("compute_ns", 0.0)
            target_pe = entry.get("target_pe", "-")
            scalars = entry.get("scalars", [])
            if target_pe == "all":
                pe_str, n_active = "all", num_pes
            else:
                pe_str, n_active = str(target_pe), 1

            # Bound indicator based on measured DMA vs compute time
            if not (dma_ns > 0 or compute_ns > 0):
                bound = "-"
            elif dma_ns >= compute_ns:
                bound = "memory"
            else:
                bound = "compute"

            achieved = _calc_tflops(scalars, pe_ns)
            peak_total = peak_tflops_f16 * n_active
            util = (achieved / peak_total * 100) if peak_total > 0 else 0.0
            out.append(
                f"{'kernel':<10} {entry.get('name', '?'):<12} {pe_str:>4} {e2e_ns:>10.1f} "
                f"{pe_ns:>10.1f} {dma_ns:>10.1f} {compute_ns:>10.1f} "
                f"{bound:<8} {achieved:>8.3f} {peak_total:>8.1f} {util:>6.1f}%"
            )

    # ── Per-PE summary ──
    pe_deploy = _per_pe_deploy(deploy_entries)
    if len(pe_deploy) > 1:
        pe_title = (f"-- Per-PE Summary (peak: {peak_tflops_f16:.1f} TFLOPS/PE, "
                    f"{peak_hbm_bw_gbs:.0f} GB/s HBM BW) ")
        pe_hdr = (f"{'PE':>4} {'Deploy(ns)':>10} {'BW(GB/s)':>10} {'BW Util':>8} "
                  f"{'Kernel(ns)':>10} {'TFLOPS':>8} {'Util':>7}")
        pe_width = max(len(pe_title), len(pe_hdr))
        out.append("")
        out.append(pe_title + "-" * max(0, pe_width - len(pe_title)))
        out.append(pe_hdr)
        out.append("-" * pe_width)

        # Kernel figures are shared by every row: total kernel time across
        # all kernel entries, with FLOPs taken from the first entry's scalars
        # and divided evenly among the PEs that received deploy traffic.
        k_ns = sum(e.get("pe_exec_ns", e.get("total_ns", 0.0)) for e in kernel_entries)
        k_scalars = kernel_entries[0].get("scalars", []) if kernel_entries else []
        active_pes = len(pe_deploy)
        total_achieved = _calc_tflops(k_scalars, k_ns)
        per_pe_tflops = total_achieved / active_pes if active_pes > 0 else 0.0
        pe_util = (per_pe_tflops / peak_tflops_f16 * 100) if peak_tflops_f16 > 0 else 0.0

        for pe_id in sorted(pe_deploy):
            d_ns, d_bytes = pe_deploy[pe_id]
            d_bw = d_bytes / d_ns if d_ns > 0 else 0.0
            d_util = (d_bw / peak_hbm_bw_gbs * 100) if peak_hbm_bw_gbs > 0 else 0.0
            out.append(
                f"{pe_id:>4} {d_ns:>10.1f} {d_bw:>10.1f} {d_util:>7.1f}% "
                f"{k_ns:>10.1f} {per_pe_tflops:>8.3f} {pe_util:>6.1f}%"
            )
    out.append("")

    return "\n".join(out)
|
||||
|
||||
|
||||
def _extract_peaks(spec: dict | None) -> tuple[float, float]:
    """Return ``(peak f16 TFLOPS, HBM bandwidth in GB/s)`` read from ``spec``.

    The TFLOPS figure is read from
    ``cube.pe_template.components.pe_gemm.attrs.peak_tflops_f16`` and the
    bandwidth from ``cube.links.xbar_to_hbm_bw_gbs``. Any missing level
    yields 0.0, as does ``spec=None``.
    """
    if spec is None:
        return 0.0, 0.0

    cube = spec.get("cube", {})
    gemm_attrs = (
        cube.get("pe_template", {})
        .get("components", {})
        .get("pe_gemm", {})
        .get("attrs", {})
    )
    peak_tflops = float(gemm_attrs.get("peak_tflops_f16", 0.0))
    hbm_bw = float(cube.get("links", {}).get("xbar_to_hbm_bw_gbs", 0.0))
    return peak_tflops, hbm_bw
|
||||
|
||||
|
||||
def _count_pes(spec: dict | None) -> int:
    """Return the PE count from ``spec["cube"]["pe_layout"]``.

    The count is ``pe_per_corner`` (default 2) times the number of entries
    in ``corners`` (default four corners). ``spec=None`` yields 8.
    """
    if spec is None:
        return 8
    layout = spec.get("cube", {}).get("pe_layout", {})
    per_corner = layout.get("pe_per_corner", 2)
    corner_names = layout.get("corners", ["NW", "NE", "SW", "SE"])
    return per_corner * len(corner_names)
|
||||
|
||||
|
||||
def _calc_tflops(scalars: list, latency_ns: float) -> float:
    """Return achieved TFLOPS for the problem described by ``scalars``.

    The first three scalars are read as ``[M, K, N]`` and the work is counted
    as ``2 * M * K * N`` floating-point operations. Returns 0.0 when fewer
    than three scalars are given or the latency is not positive.
    """
    has_dims = len(scalars) >= 3
    if not has_dims or latency_ns <= 0:
        return 0.0
    m, k, n = scalars[:3]
    flops = 2.0 * m * k * n
    seconds = latency_ns * 1e-9
    return flops / seconds / 1e12
|
||||
|
||||
|
||||
def _per_pe_deploy(deploy_entries: list[dict]) -> dict[int, tuple[float, int]]:
    """Sum latency and byte count of the deploy entries for each PE.

    Entries are grouped by their ``pe`` key (default 0). The result maps each
    PE to ``(total latency in ns, total bytes)``; missing ``total_ns`` /
    ``nbytes`` count as 0.
    """
    totals: dict[int, tuple[float, int]] = {}
    for entry in deploy_entries:
        pe = entry.get("pe", 0)
        lat = entry.get("total_ns", 0.0)
        nb = entry.get("nbytes", 0)
        try:
            prev_ns, prev_bytes = totals[pe]
        except KeyError:
            totals[pe] = (lat, nb)
        else:
            totals[pe] = (prev_ns + lat, prev_bytes + nb)
    return totals
|
||||
@@ -0,0 +1,150 @@
|
||||
"""PE-internal command types and handles (ADR-0014).
|
||||
|
||||
Generated by triton_emu (TLContext) and consumed by PE component
|
||||
implementations (PE_CPU, PE_SCHEDULER, PE_DMA, PE_GEMM, PE_MATH).
|
||||
|
||||
Command lifecycle:
|
||||
Triton kernel → TLContext → [PeCommand list] → PE_CPU → PE_SCHEDULER → engines
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import simpy
|
||||
|
||||
|
||||
# ── Handles ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TensorHandle:
    """Immutable, opaque descriptor of a tensor residing in PE_TCM.

    Handles are what ``tl.load``, ``tl.dot``, ``tl.exp`` and similar calls
    return. They carry the metadata used for command generation. The
    ``data`` field is reserved for a future validate mode (numpy array) and
    defaults to ``None``.
    """

    id: str
    pa: int  # physical address in HBM/TCM
    shape: tuple[int, ...]
    dtype: str
    nbytes: int  # total size in bytes
    data: object = None  # reserved for validate mode
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CompletionHandle:
    """Immutable, opaque handle for a non-blocking composite command.

    ``tl.composite`` returns one; ``tl.wait`` consumes it.
    """

    id: str
|
||||
|
||||
|
||||
# ── PE Commands ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DmaReadCmd:
    """DMA READ command: copy from HBM into PE_TCM."""

    handle: TensorHandle  # tensor handle associated with this read
    src_pa: int  # source physical address
    nbytes: int  # number of bytes to read
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DmaWriteCmd:
    """DMA WRITE command: copy from PE_TCM out to HBM."""

    handle: TensorHandle  # tensor handle associated with this write
    dst_pa: int  # destination physical address
    nbytes: int  # number of bytes to write
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GemmCmd:
    """GEMM engine command: ``out = a @ b`` with every operand in TCM."""

    a: TensorHandle
    b: TensorHandle
    out: TensorHandle
    # Matrix-multiply dimensions.
    m: int
    k: int
    n: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MathCmd:
    """MATH engine command: unary, binary or reduction op on TCM data.

    Documented ``op`` values:

    * unary:     "exp", "log", "sqrt", "abs", "sigmoid", "cos", "sin"
    * binary:    "add", "sub", "mul", "div", "where"
    * reduction: "sum", "max", "min"
    """

    op: str
    inputs: tuple[TensorHandle, ...]
    out: TensorHandle
    axis: int | None = None  # for reductions
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CompositeCmd:
    """Composite command: a tiled DMA_READ + COMPUTE + DMA_WRITE pipeline.

    The command is non-blocking. It is submitted to PE_SCHEDULER, which
    manages tile splitting and pipeline overlap (ADR-0014 D3.2).
    """

    completion: CompletionHandle
    op: Literal["gemm", "math"]
    a: TensorHandle
    b: TensorHandle | None
    out_pa: int
    out_nbytes: int
    math_op: str | None = None  # for op="math": which math operation
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WaitCmd:
    """Wait for one composite command, or for all pending composites."""

    handle: CompletionHandle | None = None  # None = wait all
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PeCpuOverheadCmd:
    """Scalar-execution overhead on PE_CPU, expressed in cycles."""

    cycles: int
|
||||
|
||||
|
||||
# Union type for all PE commands
|
||||
PeCommand = (
    DmaReadCmd
    | DmaWriteCmd
    | GemmCmd
    | MathCmd
    | CompositeCmd
    | WaitCmd
    | PeCpuOverheadCmd
)
|
||||
|
||||
|
||||
@dataclass
class PeInternalTxn:
    """PE-internal message travelling PE_CPU → PE_SCHEDULER → engines.

    Each transaction wraps exactly one ``PeCommand`` together with its
    completion event. During the replay phase PE_CPU creates one transaction
    per command and hands it to PE_SCHEDULER, which routes it to the matching
    engine (PE_DMA, PE_GEMM, PE_MATH). That engine signals ``done`` when the
    command has completed.
    """

    command: PeCommand
    done: simpy.Event  # succeeded when the engine completes this command
    pe_prefix: str = ""  # e.g. "sip0.cube0.pe0" — needed by PE_DMA for path resolution
    result_data: dict[str, Any] = field(default_factory=dict)
|
||||
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, NewType, Protocol, TypeAlias
|
||||
|
||||
# Distinct static type for the handle returned when a request is submitted;
# at runtime a RequestHandle is a plain str.
RequestHandle = NewType("RequestHandle", str)
|
||||
|
||||
# Trace payloads are deliberately left untyped (alias of Any).
Trace: TypeAlias = Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Completion:
|
||||
ok: bool
|
||||
error_code: str | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class SimEngine(Protocol):
|
||||
"""
|
||||
Backend simulation/runner engine contract.
|
||||
|
||||
Engine must be able to:
|
||||
- accept requests created by RuntimeContext (submit/dispatch)
|
||||
- report completion and optional trace for a given handle
|
||||
"""
|
||||
|
||||
def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]: ...
|
||||
def submit(self, request: Any) -> RequestHandle: ...
|
||||
def wait(self, handle: RequestHandle) -> None: ...
|
||||
@@ -0,0 +1,4 @@
|
||||
from kernbench.components.base import ComponentBase, ComponentRegistry
|
||||
from kernbench.components.context import ComponentContext
|
||||
|
||||
# Names re-exported as this package's public API.
__all__ = ["ComponentBase", "ComponentRegistry", "ComponentContext"]
|
||||
@@ -0,0 +1,167 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class ComponentBase(ABC):
    """Common base for every SimPy component (ADR-0007 D3, ADR-0015).

    One component instance stands for one node of the compiled topology
    graph and models that node's processing overhead as a SimPy generator,
    which leaves room for queueing and contention in later implementations.

    Port model (ADR-0015 D1):
        in_ports[src_node_id]  — SimPy Store for incoming messages from src
        out_ports[dst_node_id] — SimPy Store for outgoing messages to dst
    GraphEngine wires the ports at initialization; wire processes model the
    propagation delay between connected ports (ADR-0015 D2).

    Context (ADR-0015 D4):
        ctx — ComponentContext with router and resolver.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        self.node = node
        self.ctx = ctx
        # Both port tables start empty and are filled in by the wiring step.
        self.in_ports: dict[str, simpy.Store] = {}
        self.out_ports: dict[str, simpy.Store] = {}

    def start(self, env: simpy.Environment) -> None:
        """Launch this component's processes; call once after port wiring.

        The default creates a shared inbox, one fan-in relay per input port
        and a single forwarding worker. The worker applies ``run()`` latency
        to each message, then either routes it to the next hop or signals
        its ``done`` event. Messages are duck-typed; Transaction is not
        imported here, to avoid circular dependencies.

        A component without input ports starts nothing.

        Override in components that need custom fan-out / aggregation logic
        (e.g. MCpuComponent, IoCpuComponent for kernel launch).
        """
        if self.in_ports:
            self._inbox: simpy.Store = simpy.Store(env)
            for in_port in self.in_ports.values():
                env.process(self._fan_in(in_port))
            env.process(self._worker(env))

    def _fan_in(self, port: simpy.Store) -> Generator:
        """Forever move messages from one input port into the shared inbox."""
        while True:
            incoming = yield port.get()
            yield self._inbox.put(incoming)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Take messages off the inbox and forward each in its own process.

        One process per message lets messages overlap (pipeline).
        """
        while True:
            message: Any = yield self._inbox.get()
            env.process(self._forward_txn(env, message))

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Apply ``run()`` latency, then forward the message or finish it.

        With a next hop, the advanced transaction is put on that hop's output
        port. At the terminal node the optional ``drain_ns`` delay is applied
        and ``txn.done`` is succeeded.
        """
        yield from self.run(env, txn.nbytes)

        next_hop = txn.next_hop  # duck-typed: Transaction.next_hop
        if next_hop:
            yield self.out_ports[next_hop].put(txn.advance())
            return

        # Terminal node: wait out the drain time (if any), then complete.
        drain_ns = getattr(txn, "drain_ns", 0.0)
        if drain_ns > 0:
            yield env.timeout(drain_ns)
        txn.done.succeed()

    @abstractmethod
    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """SimPy process that models this node's processing of ``nbytes``.

        Subclasses yield ``env.timeout(overhead_ns)`` or compute the latency
        dynamically. Called by ``_forward_txn`` and by subclass-specific
        handlers.
        """
        ...
|
||||
|
||||
|
||||
class PeEngineBase(ComponentBase):
    """Shared base for the PE-internal engines (PE_DMA, PE_GEMM, PE_MATH).

    What it adds on top of ``ComponentBase``:
      - ``_pe_prefix``: the node id without its last dotted segment
        (e.g. "sip0.cube0.pe0").
      - A dual-message ``_worker`` that sends ``PeInternalTxn`` messages to
        ``handle_command()`` and everything else to the inherited
        ``_forward_txn()``.
      - ``init_resources(env)``: a hook that ``start()`` calls before the
        worker is spawned.

    Subclass contract:
      1. Override ``handle_command(env, pe_txn)`` — process a PeInternalTxn.
      2. Override ``run(env, nbytes)`` — yield component latency.
      3. Optionally override ``init_resources(env)`` for DMA channels, etc.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        owner_id = node.id.rsplit(".", 1)[0]
        self._pe_prefix: str = owner_id

    def start(self, env: simpy.Environment) -> None:
        """Initialize engine resources, then start the base processes."""
        self.init_resources(env)
        super().start(env)

    def init_resources(self, env: simpy.Environment) -> None:
        """Subclass hook for resource setup; runs before the worker is spawned."""

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch each inbox message by type, one process per message.

        ``PeInternalTxn`` → ``handle_command``; anything else → ``_forward_txn``.
        """
        # Imported here rather than at module level.
        from kernbench.common.pe_commands import PeInternalTxn

        while True:
            message: Any = yield self._inbox.get()
            handler = (
                self.handle_command
                if isinstance(message, PeInternalTxn)
                else self._forward_txn
            )
            env.process(handler(env, message))

    @abstractmethod
    def handle_command(self, env: simpy.Environment, pe_txn: Any) -> Generator:
        """Process one PE-internal command (``PeInternalTxn``).

        Implementations must:
          - perform the engine-specific work (acquire resources, compute, …)
          - call ``pe_txn.done.succeed()`` once the command has completed.
        """
        ...
|
||||
|
||||
|
||||
class ComponentRegistry:
    """DI registry mapping ``node.impl`` strings to ComponentBase subclasses.

    ``create(node, overrides, ctx)`` resolves the implementation in this
    order:
      1. ``overrides[node.impl]`` — caller-injected override
      2. ``_registry[node.impl]`` — globally registered implementation
      3. ``ValueError``           — no fallback; every node must have an impl
    """

    # Shared by the whole process: impl name -> component class.
    _registry: dict[str, type[ComponentBase]] = {}

    @classmethod
    def register(cls, impl: str, component_cls: type[ComponentBase]) -> None:
        """Register (or replace) the component class for ``impl``."""
        cls._registry[impl] = component_cls

    @classmethod
    def create(
        cls,
        node: Node,
        overrides: dict[str, type[ComponentBase]] | None = None,
        ctx: ComponentContext | None = None,
    ) -> ComponentBase:
        """Instantiate the component for ``node``.

        Raises:
            ValueError: neither ``overrides`` nor the registry has an entry
                for ``node.impl``.
        """
        impl = node.impl
        # Caller overrides take precedence over the global registry.
        for table in (overrides, cls._registry):
            if table and impl in table:
                return table[impl](node, ctx)
        raise ValueError(
            f"No component registered for impl '{impl}' (node: {node.id}). "
            f"Register it in kernbench.components.impls.__init__."
        )
|
||||
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.policy.routing.router import AddressResolver, PathRouter
|
||||
|
||||
|
||||
@dataclass
class ComponentContext:
    """Topology services handed to every component implementation.

    Needed by components that route or resolve addresses (IoCpuComponent,
    MCpuComponent, ...); components that only forward may ignore it.

    Supplied through ``ComponentRegistry.create(node, overrides, ctx=ctx)``.
    """

    router: PathRouter
    resolver: AddressResolver
    positions: dict[str, tuple[float, float] | None]  # node_id → pos_mm
    ns_per_mm: float  # wire propagation constant (from topology spec)
    edge_map: dict[tuple[str, str], Any] = field(default_factory=dict)
    spec: dict = field(default_factory=dict)  # topology spec (cube layout, PE count, etc.)

    def get_shared_resource(
        self, env: simpy.Environment, key: str, capacity: int = 1,
    ) -> simpy.Resource:
        """Return the SimPy Resource stored under *key*, creating it if absent.

        Lets several engines of one PE contend for the same resource (e.g.
        an accel_slot shared by PE_GEMM and PE_MATH). Keys should be scoped
        per PE, e.g. "sip0.cube0.pe0.accel_slot".

        *env* and *capacity* are only used the first time a key is seen;
        later calls return the existing Resource unchanged.
        """
        # The pool is not a dataclass field; it is attached on first use.
        try:
            pool = self._shared_resources
        except AttributeError:
            pool = self._shared_resources = {}
        if key not in pool:
            pool[key] = simpy.Resource(env, capacity=capacity)
        return pool[key]

    def compute_drain_ns(self, path: list[str], nbytes: int) -> float:
        """Wormhole drain time: *nbytes* divided by the bottleneck bandwidth.

        The bottleneck is the smallest truthy ``bw_gbs`` among the
        ``edge_map`` entries for consecutive hops of *path*. Hops without an
        entry, or whose entry has no usable ``bw_gbs``, are ignored. Returns
        0.0 when no hop contributes a bandwidth.
        """
        unlimited = float("inf")
        bottleneck = unlimited
        for hop_pair in zip(path, path[1:]):
            link = self.edge_map.get(hop_pair)
            if not link:
                continue
            bandwidth = getattr(link, "bw_gbs", None)
            if bandwidth:
                bottleneck = min(bottleneck, bandwidth)
        return 0.0 if bottleneck == unlimited else nbytes / bottleneck
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Concrete component implementations.
|
||||
|
||||
Each module registers its component(s) with ComponentRegistry on import.
|
||||
Import this package to activate all built-in implementations.
|
||||
"""
|
||||
|
||||
from kernbench.components.base import ComponentRegistry
|
||||
from kernbench.components.impls.forwarding import TransitComponent
|
||||
from kernbench.components.impls.hbm_ctrl import HbmCtrlComponent
|
||||
from kernbench.components.impls.io_cpu import IoCpuComponent
|
||||
from kernbench.components.impls.m_cpu import MCpuComponent
|
||||
from kernbench.components.impls.noc import TwoDMeshNocComponent
|
||||
from kernbench.components.impls.pcie_ep import PcieEpComponent
|
||||
from kernbench.components.impls.pe_cpu import PeCpuComponent
|
||||
from kernbench.components.impls.pe_dma import PeDmaComponent
|
||||
from kernbench.components.impls.pe_gemm import PeGemmComponent
|
||||
from kernbench.components.impls.pe_math import PeMathComponent
|
||||
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
|
||||
from kernbench.components.impls.pe_tcm import PeTcmComponent
|
||||
from kernbench.components.impls.sram import SramComponent
|
||||
|
||||
# Built-in impl strings and the component class behind each one, registered
# in this order. Several plain transit impls share TransitComponent.
for _impl_name, _component_cls in (
    ("forwarding_v1", TransitComponent),
    ("switch_v1", TransitComponent),
    ("noc_v1", TransitComponent),
    ("noc_2d_mesh_v1", TwoDMeshNocComponent),
    ("ucie_v1", TransitComponent),
    ("xbar_v1", TransitComponent),
    ("pcie_ep_v1", PcieEpComponent),
    ("io_cpu_v1", IoCpuComponent),
    ("m_cpu_v1", MCpuComponent),
    ("hbm_ctrl_v1", HbmCtrlComponent),
    ("sram_v1", SramComponent),
    ("pe_cpu_v1", PeCpuComponent),
    ("pe_scheduler_v1", PeSchedulerComponent),
    ("pe_dma_v1", PeDmaComponent),
    ("pe_gemm_v1", PeGemmComponent),
    ("pe_math_v1", PeMathComponent),
    ("pe_tcm_v1", PeTcmComponent),
):
    ComponentRegistry.register(_impl_name, _component_cls)
# Do not leave the loop variables behind as module attributes.
del _impl_name, _component_cls

# Public names re-exported by this package.
__all__ = [
    "HbmCtrlComponent",
    "IoCpuComponent",
    "MCpuComponent",
    "PcieEpComponent",
    "PeCpuComponent",
    "PeDmaComponent",
    "PeGemmComponent",
    "PeMathComponent",
    "PeSchedulerComponent",
    "PeTcmComponent",
    "TransitComponent",
    "TwoDMeshNocComponent",
    "SramComponent",
]
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class TransitComponent(ComponentBase):
    """Pass-through component for plain transit nodes (NOC, UCIe, XBAR).

    ``run`` waits for the node's ``overhead_ns`` attribute; moving the
    Transaction on to the next hop is left to the inherited
    ``_forward_txn()`` machinery of ComponentBase.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait ``overhead_ns`` (0 when the attribute is absent); *nbytes* is unused."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)
|
||||
@@ -0,0 +1,101 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class HbmCtrlComponent(ComponentBase):
    """HBM controller: terminal component that models HBM access latency.

    Dual-channel model: reads and writes each go through their own
    ``simpy.Resource``, so a read and a write can be serviced concurrently
    while requests of the same kind queue on their channel. Both channels
    take their capacity from the node's ``capacity`` attribute (default 1).

    After servicing a request the controller sends a ResponseMsg back along
    the reverse path so that response latency is modeled through the fabric.
    PeDmaMsg requests are the exception: their ``txn.done`` is completed
    directly.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Created in start(), once the SimPy environment is available.
        self._read: simpy.Resource | None = None
        self._write: simpy.Resource | None = None

    def start(self, env: simpy.Environment) -> None:
        """Create the read and write channels, then run the base-class start."""
        slots = int(self.node.attrs.get("capacity", 1))
        self._read = simpy.Resource(env, capacity=slots)
        self._write = simpy.Resource(env, capacity=slots)
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait the node's ``overhead_ns`` (0 when absent); *nbytes* is unused."""
        yield env.timeout(float(self.node.attrs.get("overhead_ns", 0.0)))

    def _select_channel(self, txn: Any) -> simpy.Resource:
        """Pick the channel for *txn*: write requests → write, everything else → read.

        A request counts as a write when it is a MemoryWriteMsg, or a
        PeDmaMsg whose ``is_write`` is truthy.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg

        assert self._read is not None and self._write is not None
        request = txn.request
        wants_write = isinstance(request, MemoryWriteMsg) or (
            isinstance(request, PeDmaMsg) and request.is_write
        )
        return self._write if wants_write else self._read

    def _worker(self, env: simpy.Environment) -> Generator:
        """Spawn one process per incoming txn so the two channels can overlap."""
        while True:
            incoming: Any = yield self._inbox.get()
            env.process(self._handle_txn(env, incoming))

    def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Acquire the channel, apply overhead and drain time, then respond.

        The channel stays held until the response has been handed off.
        """
        with self._select_channel(txn).request() as grant:
            yield grant
            yield from self.run(env, txn.nbytes)
            drain_ns = getattr(txn, "drain_ns", 0.0)
            if drain_ns > 0:
                yield env.timeout(drain_ns)
            yield from self._send_response(env, txn)

    def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Answer *txn*: send a ResponseMsg on the reverse path, or complete it directly.

        PeDmaMsg is a direct probe with no IO_CPU/M_CPU aggregation in the
        path, so ``txn.done`` is succeeded instead of sending a response
        Transaction. The same happens when there is no context or the path
        has fewer than two nodes.
        """
        from kernbench.runtime_api.kernel import PeDmaMsg

        if isinstance(txn.request, PeDmaMsg):
            txn.done.succeed()
            return

        back_path = list(reversed(txn.path))
        if len(back_path) < 2 or not self.ctx:
            txn.done.succeed()
            return

        from kernbench.runtime_api.kernel import ResponseMsg

        # The node id is parsed positionally: field 1 is "cube<N>" and
        # field 3 is "slice<N>".
        id_fields = self.node.id.split(".")
        cube_id = int(id_fields[1].replace("cube", ""))
        pe_id = int(id_fields[3].replace("slice", ""))
        response = ResponseMsg(
            correlation_id=txn.request.correlation_id,
            request_id=txn.request.request_id,
            src_cube=cube_id, src_pe=pe_id, success=True,
        )
        response_txn = Transaction(
            request=response, path=back_path, step=0,
            nbytes=0, done=env.event(), is_response=True,
        )
        yield self.out_ports[back_path[1]].put(response_txn.advance())
|
||||
@@ -0,0 +1,145 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class IoCpuComponent(ComponentBase):
    """IO_CPU component: multi-cube fan-out with response aggregation.

    Forward path:
      1. Applies the node's ``overhead_ns`` processing overhead.
      2. Resolves the target cube(s) for the request
         (see ``_resolve_cube_targets``).
      3. Fans out one sub-Transaction to the M_CPU of every routable
         target cube.

    Response path:
      Counts the response Transactions that come back, matched on
      ``request_id``. When as many responses as dispatched sub-Transactions
      have arrived, the parent ``txn.done`` is succeeded.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Pending fan-out tracking: request_id → (expected, received, parent_txn_done)
        self._pending: dict[str, tuple[int, int, simpy.Event]] = {}

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait the node's ``overhead_ns`` (0 when absent); *nbytes* is unused."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Collect responses inline; apply overhead to forward txns, then fan out.

        The overhead is waited for inside this loop, so forward txns pay it
        one after another; the fan-out itself runs as a separate process.
        """
        while True:
            txn: Any = yield self._inbox.get()
            if getattr(txn, "is_response", False):
                self._collect_response(txn)
            else:
                yield from self.run(env, txn.nbytes)
                env.process(self._dispatch_to_m_cpus(env, txn))

    def _collect_response(self, resp_txn: Any) -> None:
        """Count one cube response; complete the parent once all have arrived.

        Responses whose ``request_id`` has no pending entry are ignored.
        """
        key = resp_txn.request.request_id
        if key not in self._pending:
            return
        expected, received, parent_done = self._pending[key]
        received += 1
        if received >= expected:
            parent_done.succeed()
            del self._pending[key]
        else:
            self._pending[key] = (expected, received, parent_done)

    def _dispatch_to_m_cpus(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out sub-Transactions to the target cubes' M_CPUs.

        Completion is signalled later by ``_collect_response``. If target
        resolution fails, or no target cube is routable, ``txn.done`` is
        succeeded immediately.
        """
        request = txn.request
        try:
            cube_targets = self._resolve_cube_targets(request)
        except Exception:
            # Best effort: an unresolvable request completes without fan-out.
            txn.done.succeed()
            return

        # Resolve every route BEFORE registering the aggregation entry, so
        # that the expected-response count only covers sub-Transactions that
        # are really sent. Counting unroutable cubes would leave the parent
        # waiting forever for responses that can never arrive.
        routes: list[list[str]] = []
        for sip, cube in cube_targets:
            try:
                m_cpu_id = self.ctx.resolver.find_m_cpu(sip, cube)
                path = self.ctx.router.find_node_path(self.node.id, m_cpu_id)
            except Exception:
                # Best effort: skip cubes that cannot be resolved or routed.
                continue
            if len(path) >= 2:
                routes.append(path)

        if not routes:
            txn.done.succeed()
            return

        # Registered before the first put, so no response can arrive ahead
        # of its aggregation entry.
        self._pending[request.request_id] = (len(routes), 0, txn.done)

        for path in routes:
            sub_txn = Transaction(
                request=request, path=path, step=0,
                nbytes=txn.nbytes, done=env.event(),
                result_data=txn.result_data,
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())

    def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]:
        """Return the (sip, cube) pairs to fan out to.

        ``request.target_cubes`` (default ``"all"``) lists explicit cubes.
        With ``"all"``:
          - MemoryWriteMsg / MemoryReadMsg target the single cube decoded
            from the physical address, falling back to the message's own
            cube field;
          - KernelLaunchMsg targets every distinct cube that holds a tensor
            shard on this IO_CPU's SIP, in first-seen order.
        Any other request type yields no targets.
        """
        from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg

        target_cubes = getattr(request, "target_cubes", "all")

        if isinstance(request, MemoryWriteMsg):
            sip = request.dst_sip
            if target_cubes == "all":
                cube = self._cube_from_pa(request.dst_pa, fallback=request.dst_cube)
                return [(sip, cube)]
            return [(sip, c) for c in target_cubes]

        if isinstance(request, MemoryReadMsg):
            sip = request.src_sip
            if target_cubes == "all":
                cube = self._cube_from_pa(request.src_pa, fallback=request.src_cube)
                return [(sip, cube)]
            return [(sip, c) for c in target_cubes]

        if isinstance(request, KernelLaunchMsg):
            my_sip = self._my_sip()
            if target_cubes != "all":
                return [(my_sip, c) for c in target_cubes]
            # "all": derive from tensor shards, filtered to this SIP
            seen: set[tuple[int, int]] = set()
            targets: list[tuple[int, int]] = []
            for arg in request.args:
                if arg.arg_kind != "tensor":
                    continue
                for shard in arg.shards:
                    if shard.sip != my_sip:
                        continue
                    key = (shard.sip, shard.cube)
                    if key not in seen:
                        seen.add(key)
                        targets.append(key)
            return targets

        return []

    def _cube_from_pa(self, pa_val: int, fallback: int) -> int:
        """Return the cube id decoded from *pa_val*, or *fallback* if decoding fails."""
        from kernbench.policy.address.phyaddr import PhysAddr

        try:
            return PhysAddr.decode(pa_val).cube_id
        except Exception:
            return fallback

    def _my_sip(self) -> int:
        """Extract this IO_CPU's SIP ID from its node ID (e.g. 'sip0.io0.io_cpu' → 0)."""
        return int(self.node.id.split(".")[0].replace("sip", ""))
|
||||
@@ -0,0 +1,269 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class MCpuComponent(ComponentBase):
    """M_CPU component: multi-PE fan-out with response aggregation.

    Forward path (ADR-0015 D5):
      When a forward Transaction reaches m_cpu as its terminal hop, M_CPU
      fans out sub-Transactions. KernelLaunchMsg requests go to the PE_CPUs;
      every other request becomes DMA sub-Transactions towards HBM slices.
      ``target_pe`` on the request controls the fan-out: an int selects a
      single PE/slice, anything else (default ``"all"``) is resolved from
      the request's physical address or covers the whole local cube.

    Response path:
      Response Transactions are counted per ``request_id``. Once every
      dispatched DMA sub-Transaction has been answered, m_cpu sends one
      aggregate ResponseMsg on the reverse command path.

    Transit:
      When m_cpu is not the terminal hop, a forward Transaction is passed on
      to its next hop.
    """

    # Used when the topology spec does not give ``hbm_slices_per_cube``.
    _DEFAULT_SLICES_PER_CUBE = 8

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Pending fan-out tracking: request_id → (expected, received, all_done_event)
        self._pending: dict[str, tuple[int, int, simpy.Event]] = {}
        # Parent txn of each in-flight DMA fan-out: request_id → parent_txn
        self._parent_txns: dict[str, Any] = {}
        # DMA engine resources (ADR-0015 D5, ADR-0014 D4): capacity=1 each
        self._dma_write: simpy.Resource | None = None
        self._dma_read: simpy.Resource | None = None

    def start(self, env: simpy.Environment) -> None:
        """Create the read and write DMA resources, then run the base-class start."""
        self._dma_write = simpy.Resource(env, capacity=1)
        self._dma_read = simpy.Resource(env, capacity=1)
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait the node's ``overhead_ns`` (0 when absent); *nbytes* is unused."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch forward txns, collect response txns.

        After the overhead, a forward txn is forwarded if it has a next hop.
        Otherwise it is fanned out (kernel launch or DMA) when both a
        context and a request are present, and completed directly if not.
        """
        from kernbench.runtime_api.kernel import KernelLaunchMsg

        while True:
            txn: Any = yield self._inbox.get()
            if getattr(txn, "is_response", False):
                self._collect_response(txn)
                continue

            yield from self.run(env, txn.nbytes)
            next_hop = txn.next_hop
            if next_hop:
                yield self.out_ports[next_hop].put(txn.advance())
            elif self.ctx is not None and txn.request is not None:
                if isinstance(txn.request, KernelLaunchMsg):
                    env.process(self._kernel_launch_fanout(env, txn))
                else:
                    env.process(self._dma_fanout(env, txn))
            else:
                txn.done.succeed()

    def _collect_response(self, resp_txn: Any) -> None:
        """Count one PE response; fire the aggregation event once all have arrived.

        Responses whose ``request_id`` has no pending entry are ignored.
        """
        key = resp_txn.request.request_id
        if key not in self._pending:
            return
        expected, received, all_done = self._pending[key]
        received += 1
        if received >= expected:
            all_done.succeed()
            del self._pending[key]
        else:
            self._pending[key] = (expected, received, all_done)

    def _dma_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out DMA sub-Transactions, wait for all responses, then send
        the aggregate response on the reverse command path.

        Each hand-off acquires the DMA resource (capacity=1 per ADR-0014 D4),
        so multi-PE fan-out is serialized through the DMA engine. If there
        is no destination, or none is routable, ``txn.done`` is succeeded
        immediately.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg

        request = txn.request
        target_pe = getattr(request, "target_pe", "all")

        # Resolve every route BEFORE registering the aggregation entry, so
        # that the expected-response count only covers sub-Transactions that
        # are really sent. Counting unroutable destinations would make the
        # wait below block forever.
        dma_paths: list[list[str]] = []
        for dst_node in self._resolve_dma_destinations(request, target_pe):
            try:
                dma_path = self.ctx.router.find_mcpu_dma_path(self.node.id, dst_node)
            except Exception:
                # Best effort: skip destinations that cannot be routed.
                continue
            if len(dma_path) >= 2:
                dma_paths.append(dma_path)

        if not dma_paths:
            txn.done.succeed()
            return

        # Setup aggregation (before the first put, so no response can arrive
        # ahead of its aggregation entry).
        all_done = env.event()
        self._pending[request.request_id] = (len(dma_paths), 0, all_done)
        self._parent_txns[request.request_id] = txn

        # Select DMA resource based on operation type
        dma_res = self._dma_write if isinstance(request, MemoryWriteMsg) else self._dma_read

        # Fan out DMA sub-txns (serialized through DMA resource)
        max_drain_ns = 0.0
        for dma_path in dma_paths:
            drain_ns = self.ctx.compute_drain_ns(dma_path, txn.nbytes)
            max_drain_ns = max(max_drain_ns, drain_ns)
            sub_txn = Transaction(
                request=request, path=dma_path, step=0,
                nbytes=txn.nbytes, done=env.event(),
                drain_ns=drain_ns,
            )
            with dma_res.request() as req:
                yield req
                yield self.out_ports[dma_path[1]].put(sub_txn.advance())

        # Wait for all PE responses
        yield all_done
        txn.result_data["xfer_ns"] = max_drain_ns
        self._parent_txns.pop(request.request_id, None)

        yield from self._send_aggregate_response(env, txn)

    def _kernel_launch_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out KernelLaunchMsg to target PE_CPU(s) via NOC (ADR-0009 D3).

        Routes through find_node_path (M_CPU → NOC → PE_CPU command edges)
        and waits on each ``sub_txn.done`` directly; no ResponseMsg is used
        in the PE direction. The per-PE metrics are then aggregated and one
        aggregate ResponseMsg is sent back on the reverse path.
        """
        request = txn.request
        target_pe = getattr(request, "target_pe", "all")
        cube_prefix = self.node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0"

        # Fan out to each routable PE_CPU, keeping the sub-txns to wait on.
        sub_txns: list[Transaction] = []
        for pe_id in self._resolve_pe_ids(target_pe):
            pe_cpu_id = f"{cube_prefix}.pe{pe_id}.pe_cpu"
            try:
                path = self.ctx.router.find_node_path(self.node.id, pe_cpu_id)
            except Exception:
                # Best effort: skip PEs that cannot be routed.
                continue
            if len(path) < 2:
                continue
            sub_txn = Transaction(
                request=request, path=path, step=0,
                nbytes=0, done=env.event(),
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())
            sub_txns.append(sub_txn)

        if not sub_txns:
            txn.done.succeed()
            return

        # Wait for all PE_CPUs to complete
        for sub_txn in sub_txns:
            yield sub_txn.done

        # Aggregate PE-internal metrics (max across PEs; 0.0 when a PE did
        # not report the metric).
        for metric in ("pe_exec_ns", "dma_ns", "compute_ns"):
            txn.result_data[metric] = max(
                st.result_data.get(metric, 0.0) for st in sub_txns
            )

        yield from self._send_aggregate_response(env, txn)

    def _send_aggregate_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Send one aggregate ResponseMsg for *txn* along its reversed path.

        When the path has fewer than two nodes there is nowhere to send a
        response, so ``txn.done`` is succeeded instead.
        """
        reverse_path = list(reversed(txn.path))
        if len(reverse_path) < 2:
            txn.done.succeed()
            return

        from kernbench.runtime_api.kernel import ResponseMsg

        request = txn.request
        # Field 1 of the node id is "cube<N>".
        cube_id = int(self.node.id.split(".")[1].replace("cube", ""))
        resp_msg = ResponseMsg(
            correlation_id=request.correlation_id,
            request_id=request.request_id,
            src_cube=cube_id, src_pe=-1, success=True,
        )
        resp_txn = Transaction(
            request=resp_msg, path=reverse_path, step=0,
            nbytes=0, done=env.event(), is_response=True,
        )
        yield self.out_ports[reverse_path[1]].put(resp_txn.advance())

    def _slices_per_cube(self) -> int:
        """Return ``cube.memory_map.hbm_slices_per_cube`` from the spec, or the default."""
        if self.ctx and self.ctx.spec:
            mm = self.ctx.spec.get("cube", {}).get("memory_map", {})
            return mm.get("hbm_slices_per_cube", self._DEFAULT_SLICES_PER_CUBE)
        return self._DEFAULT_SLICES_PER_CUBE

    def _resolve_dma_destinations(self, request: Any, target_pe: int | str) -> list[str]:
        """Return list of HBM destination node_ids for DMA fan-out.

        Resolution order:
          1. int *target_pe* → that slice of the local cube;
          2. otherwise the request's physical address (``dst_pa``, else
             ``src_pa``), resolved through the address resolver, which
             enables cross-cube DMA routing when the PA points elsewhere;
          3. otherwise every slice of the local cube.
        """
        cube_prefix = self.node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0"

        if isinstance(target_pe, int):
            return [f"{cube_prefix}.hbm_ctrl.slice{target_pe}"]

        # PA-based resolution. Test for None explicitly: a physical address
        # of 0 is falsy and must not be mistaken for a missing one.
        pa_val = getattr(request, "dst_pa", None)
        if pa_val is None:
            pa_val = getattr(request, "src_pa", None)
        if pa_val is not None:
            from kernbench.policy.address.phyaddr import PhysAddr

            try:
                pa = PhysAddr.decode(pa_val)
                return [self.ctx.resolver.resolve(pa)]
            except Exception:
                # Best effort: fall back to all local slices below.
                pass

        # "all" without a usable PA: all slices in local cube
        return [
            f"{cube_prefix}.hbm_ctrl.slice{i}" for i in range(self._slices_per_cube())
        ]

    def _resolve_pe_ids(self, target_pe: int | str) -> list[int]:
        """Return list of PE IDs to fan out to (used by kernel launch fan-out)."""
        if isinstance(target_pe, int):
            return [target_pe]
        # "all": one PE per HBM slice of the local cube
        return list(range(self._slices_per_cube()))
|
||||
@@ -0,0 +1,187 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class TwoDMeshNocComponent(ComponentBase):
    """2D mesh NOC modeled as a single smart node.

    Latency model:
      - The route between the positions of the previous and next hop is
        split into XY segments (horizontal first, then vertical); each
        segment costs ``dist_mm * ns_per_mm``.
      - ``overhead_ns`` (from node.attrs) is added once per traversal.

    Contention model:
      - Each directed XY segment is a ``simpy.Resource(capacity=1)``.
      - Pipeline: the next segment's resource is requested before the
        current segment's timeout completes, so a free downstream segment is
        acquired immediately (wormhole-style cut-through).
      - Two transactions that need the same segment contend for it.

    Concurrency:
      - ``_worker`` spawns an independent SimPy process per transaction, so
        the NOC serializes only at segment resources, not at node level.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._env: simpy.Environment | None = None
        # Segment key → Resource, created on first use by _get_link().
        self._links: dict[tuple, simpy.Resource] = {}
        # Sorted grid coordinates (rounded to 2 decimals) of this cube's nodes.
        self._x_grid: list[float] = []
        self._y_grid: list[float] = []

    def start(self, env: simpy.Environment) -> None:
        """Remember the environment, build the grid, then run the base-class start."""
        self._env = env
        self._build_grid()
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Zero-delay placeholder; the latency is applied by ``_route``."""
        yield env.timeout(0)

    # ── Grid construction ────────────────────────────────────────────

    def _build_grid(self) -> None:
        """Collect the x and y coordinates of all positioned nodes in this cube.

        A node belongs to the cube when its id starts with this node's id
        minus the last dotted component. Does nothing without a context.
        """
        if not self.ctx:
            return
        prefix = self.node.id.rsplit(".", 1)[0] + "."
        placed = [
            pos
            for node_id, pos in self.ctx.positions.items()
            if node_id.startswith(prefix) and pos is not None
        ]
        self._x_grid = sorted({round(pos[0], 2) for pos in placed})
        self._y_grid = sorted({round(pos[1], 2) for pos in placed})

    def _get_link(self, key: tuple) -> simpy.Resource:
        """Return the capacity-1 Resource for segment *key*, creating it on demand."""
        link = self._links.get(key)
        if link is None:
            assert self._env is not None
            link = self._links[key] = simpy.Resource(self._env, capacity=1)
        return link

    # ── Worker ───────────────────────────────────────────────────────

    def _worker(self, env: simpy.Environment) -> Generator:
        """Start an independent routing process for every incoming transaction."""
        while True:
            incoming: Any = yield self._inbox.get()
            env.process(self._route(env, incoming))

    def _route(self, env: simpy.Environment, txn: Any) -> Generator:
        """Carry *txn* across the mesh, then forward it or complete it.

        When both neighbouring hops have known positions, the XY segments
        between them are traversed; otherwise only ``overhead_ns`` is
        waited. A txn without a next hop waits its ``drain_ns`` (if
        positive) and has ``done`` succeeded.
        """
        prev_hop = txn.path[txn.step - 1] if txn.step > 0 else None
        next_hop = txn.next_hop
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))

        segments: list[tuple[tuple, float]] = []
        if prev_hop and next_hop and self.ctx:
            entry_pos = self.ctx.positions.get(prev_hop)
            exit_pos = self.ctx.positions.get(next_hop)
            if entry_pos and exit_pos:
                segments = self._xy_links(entry_pos, exit_pos)

        if segments:
            yield from self._traverse(env, segments, overhead_ns)
        else:
            yield env.timeout(overhead_ns)

        if not next_hop:
            # Terminal hop: apply the drain time, then complete.
            drain_ns = getattr(txn, "drain_ns", 0.0)
            if drain_ns > 0:
                yield env.timeout(drain_ns)
            txn.done.succeed()
            return

        yield self.out_ports[next_hop].put(txn.advance())

    # ── XY routing and pipelined link traversal ──────────────────────

    def _traverse(
        self,
        env: simpy.Environment,
        links: list[tuple[tuple, float]],
        overhead_ns: float,
    ) -> Generator:
        """Pipeline: request next segment before current timeout finishes.

        *links* must be non-empty. ``overhead_ns`` is added to the first
        segment only. Each segment is released once its timeout has
        elapsed, before the next segment's grant is awaited.
        """
        ns_per_mm = self.ctx.ns_per_mm  # type: ignore[union-attr]

        # Acquire the first segment up front.
        held_link = self._get_link(links[0][0])
        held_req = held_link.request()
        yield held_req

        last_idx = len(links) - 1
        for idx, (_, dist_mm) in enumerate(links):
            has_next = idx < last_idx
            if has_next:
                # Queue for the following segment while this one is crossed.
                upcoming_link = self._get_link(links[idx + 1][0])
                upcoming_req = upcoming_link.request()

            hop_ns = dist_mm * ns_per_mm + (overhead_ns if idx == 0 else 0.0)
            yield env.timeout(hop_ns)
            held_link.release(held_req)

            if has_next:
                yield upcoming_req  # usually already fulfilled (pipeline)
                held_link, held_req = upcoming_link, upcoming_req

    def _xy_links(
        self,
        src: tuple[float, float],
        dst: tuple[float, float],
    ) -> list[tuple[tuple, float]]:
        """XY routing: horizontal segments first, then vertical ones.

        Returns (link_key, dist_mm) pairs. A link_key is
        ``(axis, band, lo, hi, direction)`` with coordinates rounded to two
        decimals, so concurrent transactions crossing the same directed
        segment share one key. The horizontal run uses the y band nearest
        to the source's y; the vertical run uses the x band nearest to the
        destination's x.
        """
        x0, y0 = src
        x1, y1 = dst
        # (axis tag, start, end, grid along the axis, coordinate fixing the
        #  band, grid of the band, label when increasing, label when decreasing)
        axis_plan = (
            ("H", x0, x1, self._x_grid, y0, self._y_grid, "E", "W"),
            ("V", y0, y1, self._y_grid, x1, self._x_grid, "S", "N"),
        )

        links: list[tuple[tuple, float]] = []
        for tag, start, end, along_grid, band_at, band_grid, fwd, back in axis_plan:
            if not abs(start - end) > 1e-9:
                continue  # no movement along this axis
            band = self._snap(band_at, band_grid)
            for p_from, p_to in self._segments(start, end, along_grid):
                length = abs(p_to - p_from)
                if length > 1e-9:
                    lo, hi = (p_from, p_to) if p_from < p_to else (p_to, p_from)
                    heading = fwd if p_to > p_from else back
                    key = (tag, round(band, 2), round(lo, 2), round(hi, 2), heading)
                    links.append((key, length))
        return links

    @staticmethod
    def _snap(val: float, grid: list[float]) -> float:
        """Return the grid value closest to *val*, or *val* itself for an empty grid."""
        return min(grid, key=lambda g: abs(g - val)) if grid else val

    @staticmethod
    def _segments(a: float, b: float, grid: list[float]) -> list[tuple[float, float]]:
        """Consecutive (p_i, p_{i+1}) pairs covering range [a, b] using grid waypoints.

        Only grid values strictly inside the range (1e-9 tolerance) become
        waypoints. Pairs are ordered and oriented from *a* towards *b*;
        empty when the two ends coincide.
        """
        if abs(a - b) < 1e-9:
            return []
        lo, hi = (a, b) if a < b else (b, a)
        waypoints = [lo, *(g for g in grid if lo + 1e-9 < g < hi - 1e-9), hi]
        ascending = list(zip(waypoints, waypoints[1:]))
        if a > b:
            # Travelling downwards: reverse the order and flip each pair.
            return [(upper, lower) for lower, upper in ascending[::-1]]
        return ascending
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PcieEpComponent(ComponentBase):
    """PCIe endpoint component.

    Its only own behaviour is a fixed protocol-processing delay taken from
    the node attribute ``overhead_ns``; everything else (including
    forwarding) comes from ``ComponentBase``.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait ``overhead_ns`` (default 0.0) of simulated time; ``nbytes`` is unused."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)
|
||||
@@ -0,0 +1,154 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeCpuComponent(ComponentBase):
    """PE_CPU: kernel execution controller (Stage 2).

    Kernel launches run in two phases (ADR-0014 D1):

    1. Compile: fetch the kernel from the registry and run it against a
       ``TLContext`` so that it records a list of PE commands.
    2. Replay: walk that list, handing commands to PE_SCHEDULER wrapped in
       ``PeInternalTxn`` and waiting wherever a command is blocking.

    Inbox messages that are not kernel launches go to ``_forward_txn``.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._pe_prefix = node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0.pe0"
        self._pe_idx = self._index_or_zero(lambda: self._pe_prefix.rsplit("pe", 1)[1])
        # sip / cube indices are used to match shards in multi-SIP/cube runs.
        tokens = node.id.split(".")
        self._sip_idx = self._index_or_zero(lambda: tokens[0].replace("sip", ""))
        self._cube_idx = self._index_or_zero(lambda: tokens[1].replace("cube", ""))

    @staticmethod
    def _index_or_zero(extract: Any) -> int:
        """Return ``int(extract())``, or 0 if the text is missing or non-numeric."""
        try:
            return int(extract())
        except (IndexError, ValueError):
            return 0

    def _find_shard(self, shards: tuple) -> Any:
        """Pick the shard whose (sip, cube, pe) equals this PE's indices.

        When none matches, fall back to the shard at position ``pe_idx``,
        clamped to the last shard.
        """
        hit = next(
            (
                s
                for s in shards
                if s.sip == self._sip_idx and s.cube == self._cube_idx and s.pe == self._pe_idx
            ),
            None,
        )
        if hit is not None:
            return hit
        return shards[min(self._pe_idx, len(shards) - 1)]

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait the node's ``overhead_ns`` (default 0.0); ``nbytes`` is unused."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Serve the inbox forever: execute kernel launches, forward the rest."""
        from kernbench.runtime_api.kernel import KernelLaunchMsg

        while True:
            txn: Any = yield self._inbox.get()
            is_launch = isinstance(getattr(txn, "request", None), KernelLaunchMsg)
            handler = self._execute_kernel if is_launch else self._forward_txn
            yield from handler(env, txn)

    def _execute_kernel(self, env: simpy.Environment, txn: Any) -> Generator:
        """Compile the kernel into a command trace, replay it, then complete ``txn``.

        Writes ``pe_exec_ns`` (replay duration), ``dma_ns`` and ``compute_ns``
        (sums over the composite commands' result data) into
        ``txn.result_data`` before triggering ``txn.done``.
        """
        from kernbench.common.pe_commands import (
            CompositeCmd,
            PeCpuOverheadCmd,
            PeInternalTxn,
            WaitCmd,
        )
        from kernbench.triton_emu.registry import get_kernel
        from kernbench.triton_emu.tl_context import TLContext, run_kernel

        launch = txn.request

        # Phase 1: compile. PE_CPU setup overhead is paid first.
        yield from self.run(env, 0)

        kernel_fn = get_kernel(launch.kernel_ref.name)
        tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)

        # Tensor args become the matching shard's PA, scalar args their value;
        # any other arg kind is not passed to the kernel.
        positional: list = []
        for arg in launch.args:
            if arg.arg_kind == "tensor":
                positional.append(self._find_shard(arg.shards).pa)
            elif arg.arg_kind == "scalar":
                positional.append(arg.value)

        run_kernel(kernel_fn, tl, *positional)
        trace = tl.commands

        # Phase 2: replay through PE_SCHEDULER.
        replay_start = env.now
        scheduler_port = f"{self._pe_prefix}.pe_scheduler"
        in_flight: dict[str, simpy.Event] = {}  # completion id -> done event
        composite_data: list[dict] = []  # result_data dicts of composite txns

        for cmd in trace:
            if isinstance(cmd, PeCpuOverheadCmd):
                yield env.timeout(cmd.cycles)
                continue

            if isinstance(cmd, WaitCmd):
                if cmd.handle is None:
                    # No handle: drain every outstanding completion.
                    for evt in in_flight.values():
                        yield evt
                    in_flight.clear()
                else:
                    evt = in_flight.pop(cmd.handle.id, None)
                    if evt:
                        yield evt
                continue

            finished = env.event()
            pe_txn = PeInternalTxn(
                command=cmd, done=finished,
                pe_prefix=self._pe_prefix,
            )
            if isinstance(cmd, CompositeCmd):
                # Non-blocking: remember the completion and keep issuing.
                composite_data.append(pe_txn.result_data)
                yield self.out_ports[scheduler_port].put(pe_txn)
                in_flight[cmd.completion.id] = finished
            else:
                # Blocking: wait for the command before moving on.
                yield self.out_ports[scheduler_port].put(pe_txn)
                yield finished

        # Anything still outstanding must finish before the kernel is done.
        for evt in in_flight.values():
            yield evt

        txn.result_data["pe_exec_ns"] = env.now - replay_start

        dma_total = 0.0
        compute_total = 0.0
        for data in composite_data:
            dma_total += data.get("dma_ns", 0.0)
            compute_total += data.get("compute_ns", 0.0)
        txn.result_data["dma_ns"] = dma_total
        txn.result_data["compute_ns"] = compute_total

        # Signal the original transaction as complete.
        txn.done.succeed()
|
||||
@@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import PeEngineBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeDmaComponent(PeEngineBase):
    """PE_DMA: dual-channel DMA engine with READ and WRITE resources.

    Each channel is a ``simpy.Resource`` with capacity=1 (ADR-0014 D4):
    - a read and a write may be issued concurrently;
    - two reads, or two writes, are serialized.

    Handles two message types:
    - Transaction: external fabric messages (PeDmaMsg probes, M_CPU DMA)
    - PeInternalTxn: PE-internal commands from PE_SCHEDULER
      (DmaReadCmd -> HBM read, DmaWriteCmd -> HBM write)
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Channels are created in init_resources() once an environment exists.
        self._dma_read: simpy.Resource | None = None
        self._dma_write: simpy.Resource | None = None

    def init_resources(self, env: simpy.Environment) -> None:
        """Create the two single-slot DMA channels."""
        self._dma_read = simpy.Resource(env, capacity=1)
        self._dma_write = simpy.Resource(env, capacity=1)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Zero-delay processing step; ``nbytes`` is unused."""
        yield env.timeout(0)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Handle a PE-internal DMA command: resolve PA -> path -> transfer.

        A ``DmaReadCmd`` uses the read channel and targets ``cmd.src_pa``; a
        ``DmaWriteCmd`` uses the write channel and targets ``cmd.dst_pa``.
        Any other command is completed immediately. The channel is held only
        while the sub-transaction is issued, not for the whole transfer.
        ``pe_txn.done`` is triggered once the transfer has completed.
        """
        from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
        from kernbench.policy.address.phyaddr import PhysAddr
        from kernbench.runtime_api.kernel import PeDmaMsg

        cmd = pe_txn.command
        assert self._dma_read is not None and self._dma_write is not None

        # Direction decides both the channel and which PA is the target.
        if isinstance(cmd, DmaReadCmd):
            dma_res = self._dma_read
            target_pa = cmd.src_pa
            is_write = False
        elif isinstance(cmd, DmaWriteCmd):
            dma_res = self._dma_write
            target_pa = cmd.dst_pa
            is_write = True
        else:
            pe_txn.done.succeed()
            return

        # Resolve the PA to its destination node and compute the path to it.
        pa = PhysAddr.decode(target_pa)
        dst_node = self.ctx.resolver.resolve(pa)
        path = self.ctx.router.find_path(self._pe_prefix, dst_node)
        drain_ns = self.ctx.compute_drain_ns(path, cmd.nbytes)

        # Acquire the DMA channel (command-issue serialization).
        with dma_res.request() as req:
            yield req
            sub_done = env.event()
            sub_request = PeDmaMsg(
                correlation_id="pe_internal",
                request_id=f"dma_{id(pe_txn)}",
                src_sip=0, src_cube=0, src_pe=0,
                dst_pa=target_pa, nbytes=cmd.nbytes,
                is_write=is_write,
            )
            sub_txn = Transaction(
                request=sub_request, path=path, step=0,
                nbytes=cmd.nbytes, done=sub_done, drain_ns=drain_ns,
            )
            if len(path) > 1:
                # path[0] is this component; hand the transaction to path[1].
                yield self.out_ports[path[1]].put(sub_txn.advance())
            else:
                # No downstream hop exists, so nothing would ever trigger
                # sub_done. Finish the transfer here the same way the
                # terminal branch of _forward_txn does, to avoid hanging.
                if drain_ns > 0:
                    yield env.timeout(drain_ns)
                sub_done.succeed()
            # DMA channel released after issue

        # Wait for the transfer to complete, then complete the command.
        yield sub_done
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Handle an external Transaction while holding the selected channel.

        Forwards to ``txn.next_hop`` when there is one; otherwise this is
        the terminal hop: wait ``drain_ns`` (if positive) and trigger
        ``txn.done``.
        """
        dma_res = self._select_channel(txn)
        with dma_res.request() as req:
            yield req
            next_hop = txn.next_hop
            if next_hop:
                yield self.out_ports[next_hop].put(txn.advance())
            else:
                drain = getattr(txn, "drain_ns", 0.0)
                if drain > 0:
                    yield env.timeout(drain)
                txn.done.succeed()

    def _select_channel(self, txn: Any) -> simpy.Resource:
        """Return the write channel for ``MemoryWriteMsg`` requests, else the read channel.

        NOTE(review): a ``PeDmaMsg`` carrying ``is_write=True`` also lands on
        the read channel here — confirm that is intended.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg

        assert self._dma_read is not None and self._dma_write is not None
        if isinstance(txn.request, MemoryWriteMsg):
            return self._dma_write
        return self._dma_read
|
||||
@@ -0,0 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import PeEngineBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
# dtype → bit width (for TFLOPS scaling)
|
||||
_DTYPE_BITS: dict[str, int] = {
|
||||
"f16": 16, "fp16": 16, "float16": 16, "bf16": 16,
|
||||
"f32": 32, "fp32": 32, "float32": 32,
|
||||
"i8": 8, "int8": 8,
|
||||
"i16": 16, "int16": 16,
|
||||
"i32": 32, "int32": 32,
|
||||
}
|
||||
|
||||
|
||||
class PeGemmComponent(PeEngineBase):
    """PE_GEMM: matrix multiplication engine sharing accel_slot (ADR-0014 D4).

    When the node names a ``shared_resource``, every command and forwarded
    transaction first acquires that shared resource, which the design notes
    describe as mutually exclusive with PE_MATH inside the same PE.

    Compute latency model:
        FLOPs = 2 * M * K * N
        effective_tflops = peak_tflops_f16 * (16 / dtype_bits)
        compute_ns = FLOPs / (effective_tflops * 1e3)
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._accel: simpy.Resource | None = None
        self._peak_tflops_f16: float = float(node.attrs.get("peak_tflops_f16", 0.0))

    def init_resources(self, env: simpy.Environment) -> None:
        """Look up the PE-scoped shared resource, if the node declares one."""
        shared_name = self.node.attrs.get("shared_resource")
        if shared_name and self.ctx:
            self._accel = self.ctx.get_shared_resource(
                env, f"{self._pe_prefix}.{shared_name}"
            )

    def _compute_ns(self, m: int, k: int, n: int, dtype: str) -> float:
        """Return the modelled GEMM latency in nanoseconds.

        Falls back to the node's ``overhead_ns`` when no positive peak TFLOPS
        is configured. Unknown dtypes are treated as 16-bit.
        """
        peak = self._peak_tflops_f16
        if peak <= 0:
            return float(self.node.attrs.get("overhead_ns", 0.0))
        bits = _DTYPE_BITS.get(dtype, 16)
        scaled_tflops = peak * (16.0 / bits)
        total_flops = 2.0 * m * k * n
        # 1 TFLOPS == 1e3 FLOPs per nanosecond.
        return total_flops / (scaled_tflops * 1e3)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait the node's ``overhead_ns`` (default 0.0); ``nbytes`` is unused."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)

    def _gemm_or_overhead(self, env: simpy.Environment, cmd: Any) -> Generator:
        """Spend the GEMM latency for a ``GemmCmd``, plain overhead otherwise."""
        from kernbench.common.pe_commands import GemmCmd

        if isinstance(cmd, GemmCmd):
            ns = self._compute_ns(cmd.m, cmd.k, cmd.n, cmd.a.dtype)
            yield env.timeout(ns)
        else:
            yield from self.run(env, 0)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Execute one PE-internal command, then trigger ``pe_txn.done``.

        The shared accelerator resource, when present, is held for the whole
        duration of the command.
        """
        cmd = pe_txn.command
        slot = self._accel
        if not slot:
            yield from self._gemm_or_overhead(env, cmd)
        else:
            with slot.request() as grant:
                yield grant
                yield from self._gemm_or_overhead(env, cmd)
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Forward a transaction via the base class, holding the shared resource if any."""
        slot = self._accel
        if not slot:
            yield from super()._forward_txn(env, txn)
            return
        with slot.request() as grant:
            yield grant
            yield from super()._forward_txn(env, txn)
|
||||
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import PeEngineBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeMathComponent(PeEngineBase):
    """PE_MATH: element-wise computation engine sharing accel_slot (ADR-0014 D4).

    When the node names a ``shared_resource``, every command and forwarded
    transaction first acquires that shared resource, which the design notes
    describe as mutually exclusive with PE_GEMM inside the same PE.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._accel: simpy.Resource | None = None

    def init_resources(self, env: simpy.Environment) -> None:
        """Look up the PE-scoped shared resource, if the node declares one."""
        shared_name = self.node.attrs.get("shared_resource")
        if shared_name and self.ctx:
            self._accel = self.ctx.get_shared_resource(
                env, f"{self._pe_prefix}.{shared_name}"
            )

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait the node's ``overhead_ns`` (default 0.0); ``nbytes`` is unused."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Spend the node overhead (under the shared resource, if any), then finish ``pe_txn``."""
        slot = self._accel
        if not slot:
            yield from self.run(env, 0)
        else:
            with slot.request() as grant:
                yield grant
                yield from self.run(env, 0)
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Forward a transaction via the base class, holding the shared resource if any."""
        slot = self._accel
        if not slot:
            yield from super()._forward_txn(env, txn)
            return
        with slot.request() as grant:
            yield grant
            yield from super()._forward_txn(env, txn)
|
||||
@@ -0,0 +1,245 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeSchedulerComponent(ComponentBase):
    """PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1).

    Receives ``PeInternalTxn`` messages and routes each command:
      - DmaReadCmd / DmaWriteCmd -> PE_DMA
      - GemmCmd                  -> PE_GEMM
      - MathCmd                  -> PE_MATH
      - CompositeCmd             -> tiled pipeline (Stage 3: ADR-0014 D3.2)

    Composite GEMM runs per tile as
    DMA_READ(b_tile) -> COMPUTE -> DMA_WRITE(out_tile).

    The scheduler's ``overhead_ns`` is paid before each dispatch. Messages
    that are not ``PeInternalTxn`` go to the inherited ``_forward_txn``.
    """

    # Scheduler tile dimensions (ADR-0014 D3.2)
    TILE_M = 32
    TILE_K = 64
    TILE_N = 32

    # Command type -> engine node suffix. Filled lazily because the command
    # classes are imported inside _ensure_dispatch_table().
    # New engines: add a single entry there (e.g. ConvCmd: "pe_conv").
    _CMD_DISPATCH: dict[type, str] = {}

    @classmethod
    def _ensure_dispatch_table(cls) -> None:
        """Populate ``_CMD_DISPATCH`` the first time it is needed."""
        if not cls._CMD_DISPATCH:
            from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd

            cls._CMD_DISPATCH = {
                DmaReadCmd: "pe_dma",
                DmaWriteCmd: "pe_dma",
                GemmCmd: "pe_gemm",
                MathCmd: "pe_math",
            }

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._pe_prefix = node.id.rsplit(".", 1)[0]
        self._ensure_dispatch_table()

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait the node's ``overhead_ns`` (default 0.0); ``nbytes`` is unused."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Serve the inbox forever.

        Each ``PeInternalTxn`` is dispatched in its own process so the loop
        can immediately take the next message; other messages are forwarded
        inline.
        """
        from kernbench.common.pe_commands import PeInternalTxn

        while True:
            msg: Any = yield self._inbox.get()
            if not isinstance(msg, PeInternalTxn):
                yield from self._forward_txn(env, msg)
                continue
            env.process(self._dispatch(env, msg))

    def _dispatch(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Route one ``PeInternalTxn`` to its engine after the scheduler overhead.

        Lookup is by the command's exact type. Composite commands go through
        the tiled pipeline; unrecognized commands are completed at once.
        """
        from kernbench.common.pe_commands import CompositeCmd

        yield from self.run(env, 0)

        cmd = pe_txn.command
        suffix = self._CMD_DISPATCH.get(type(cmd))
        if suffix is not None:
            # Plain command: hand the transaction to the engine unchanged.
            yield self.out_ports[f"{self._pe_prefix}.{suffix}"].put(pe_txn)
        elif isinstance(cmd, CompositeCmd):
            yield from self._dispatch_composite(env, pe_txn)
        else:
            # Unknown command: signal done immediately.
            pe_txn.done.succeed()

    def _dispatch_composite(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Run a composite command (ADR-0014 D3.2).

        ``op == "gemm"`` with a ``b`` operand uses the tiled GEMM pipeline;
        everything else uses the compute-then-write math path.
        """
        from kernbench.common.pe_commands import CompositeCmd

        cmd = pe_txn.command
        assert isinstance(cmd, CompositeCmd)
        use_gemm = cmd.op == "gemm" and cmd.b is not None
        pipeline = self._pipeline_gemm if use_gemm else self._pipeline_math
        yield from pipeline(env, pe_txn, cmd)

    def _pipeline_gemm(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
        """Tiled GEMM: stream ``b`` tiles through read, compute and write.

        ``a`` is passed whole to every GEMM command (per the design notes it
        already sits in TCM); ``b`` is read tile by tile (TILE_K x TILE_N).
        For each tile the read is issued and awaited, the compute is issued
        and awaited, and the write is issued without waiting, so a tile's
        write can overlap the next tile's read and compute. The previous
        write is awaited before the current compute is considered done.

        NOTE(review): the read for tile t+1 is only issued after compute(t)
        has finished, so reads do not overlap computes as written.

        Stores ``dma_ns`` and ``compute_ns`` in ``pe_txn.result_data`` and
        triggers ``pe_txn.done``.
        """
        from kernbench.common.pe_commands import (
            DmaReadCmd,
            DmaWriteCmd,
            GemmCmd,
            PeInternalTxn as PeTxn,
            TensorHandle,
        )

        prefix = self._pe_prefix
        dma_port = f"{prefix}.pe_dma"
        gemm_port = f"{prefix}.pe_gemm"
        lhs = cmd.a
        rhs = cmd.b

        rows, k_lhs = lhs.shape[-2], lhs.shape[-1]
        k_rhs, cols = rhs.shape[-2], rhs.shape[-1]
        dtype = lhs.dtype
        rhs_elems = k_rhs * cols
        # Element size is derived from b's byte count; 2 if b has no elements.
        elem_bytes = rhs.nbytes // rhs_elems if rhs_elems > 0 else 2

        # Ceil-divide the K and N extents into tiles (at least one each).
        k_tile_count = max(1, (k_lhs + self.TILE_K - 1) // self.TILE_K)
        n_tile_count = max(1, (cols + self.TILE_N - 1) // self.TILE_N)

        last_compute = None
        last_write = None
        dma_ns = 0.0
        compute_ns = 0.0

        for tile_idx in range(k_tile_count * n_tile_count):
            tk, tn = divmod(tile_idx, n_tile_count)
            k_off = tk * self.TILE_K
            n_off = tn * self.TILE_N
            tile_k = min(self.TILE_K, k_lhs - k_off)
            tile_n = min(self.TILE_N, cols - n_off)
            tile_nbytes = tile_k * tile_n * elem_bytes

            # --- Stage 1: DMA_READ of this b tile ---
            read_done = env.event()
            tile_pa = rhs.pa + (k_off * cols + n_off) * elem_bytes
            tile_handle = TensorHandle(
                id=f"b_tile_{tile_idx}", pa=tile_pa,
                shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes,
            )
            read_txn = PeTxn(
                command=DmaReadCmd(handle=tile_handle, src_pa=tile_pa, nbytes=tile_nbytes),
                done=read_done, pe_prefix=prefix,
            )
            stage_start = env.now
            yield self.out_ports[dma_port].put(read_txn)
            # The previous compute must be over before this tile's compute.
            if last_compute is not None:
                yield last_compute
            yield read_done
            dma_ns += env.now - stage_start

            # --- Stage 2: COMPUTE (GEMM) on the tile ---
            compute_done = env.event()
            out_nbytes = rows * tile_n * elem_bytes
            out_handle = TensorHandle(
                id=f"out_tile_{tile_idx}", pa=0,
                shape=(rows, tile_n), dtype=dtype,
                nbytes=out_nbytes,
            )
            compute_txn = PeTxn(
                command=GemmCmd(a=lhs, b=tile_handle, out=out_handle,
                                m=rows, k=tile_k, n=tile_n),
                done=compute_done, pe_prefix=prefix,
            )
            stage_start = env.now
            yield self.out_ports[gemm_port].put(compute_txn)
            # Writes are serialized: the previous one must have landed.
            if last_write is not None:
                yield last_write
            yield compute_done
            compute_ns += env.now - stage_start
            last_compute = compute_done

            # --- Stage 3: DMA_WRITE of the output tile (not awaited here) ---
            write_done = env.event()
            write_cmd = DmaWriteCmd(
                handle=out_handle,
                dst_pa=cmd.out_pa + n_off * elem_bytes,
                nbytes=out_nbytes,
            )
            write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=prefix)
            yield self.out_ports[dma_port].put(write_txn)
            last_write = write_done

        # Only the final write's wait is charged to dma_ns.
        if last_write is not None:
            stage_start = env.now
            yield last_write
            dma_ns += env.now - stage_start

        pe_txn.result_data["dma_ns"] = dma_ns
        pe_txn.result_data["compute_ns"] = compute_ns
        pe_txn.done.succeed()

    def _pipeline_math(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
        """Non-GEMM composite: one MATH command, then one DMA_WRITE, in sequence.

        The math op defaults to "identity" and operates on ``cmd.a`` in
        place; the write sends ``cmd.out_nbytes`` bytes to ``cmd.out_pa``.
        No timing breakdown is recorded on ``pe_txn.result_data``.
        """
        from kernbench.common.pe_commands import (
            DmaWriteCmd,
            MathCmd,
            PeInternalTxn as PeTxn,
        )

        prefix = self._pe_prefix

        # Step 1: compute on PE_MATH and wait for it.
        math_done = env.event()
        math_cmd = MathCmd(
            op=cmd.math_op or "identity",
            inputs=(cmd.a,), out=cmd.a,
        )
        math_txn = PeTxn(command=math_cmd, done=math_done, pe_prefix=prefix)
        yield self.out_ports[f"{prefix}.pe_math"].put(math_txn)
        yield math_done

        # Step 2: write the result out through PE_DMA and wait for it.
        store_done = env.event()
        store_cmd = DmaWriteCmd(handle=cmd.a, dst_pa=cmd.out_pa, nbytes=cmd.out_nbytes)
        store_txn = PeTxn(command=store_cmd, done=store_done, pe_prefix=prefix)
        yield self.out_ports[f"{prefix}.pe_dma"].put(store_txn)
        yield store_done

        pe_txn.done.succeed()
|
||||
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeTcmComponent(ComponentBase):
    """PE_TCM: tightly-coupled memory / local SRAM staging buffer.

    Described by the design notes as the terminal storage component of the
    PE-internal dataflow (ADR-0014 D5). The only behaviour defined here is
    the ``overhead_ns`` delay in ``run``; the rest comes from
    ``ComponentBase``.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env, nbytes: int) -> Generator:
        """Wait the node's ``overhead_ns`` (default 0.0); ``nbytes`` is unused."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)
|
||||
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class SramComponent(ComponentBase):
    """Cube SRAM: terminal component modelling SRAM access latency.

    Each incoming transaction costs ``overhead_ns`` (from node.attrs) plus
    its own ``drain_ns``; afterwards a ``ResponseMsg`` is sent back along
    the reversed path.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait the node's ``overhead_ns`` (default 0.0); ``nbytes`` is unused."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Terminal worker: process, apply drain, send response — one message at a time."""
        while True:
            txn: Any = yield self._inbox.get()
            yield from self.run(env, txn.nbytes)
            drain_ns = getattr(txn, "drain_ns", 0.0)
            if drain_ns > 0:
                yield env.timeout(drain_ns)
            yield from self._send_response(env, txn)

    def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Reply on the reverse path, or complete ``txn`` directly.

        A response is sent only when the path has at least two nodes and a
        context is set; otherwise ``txn.done`` is triggered here.
        """
        reverse_path = list(reversed(txn.path))
        if len(reverse_path) < 2 or not self.ctx:
            txn.done.succeed()
            return

        from kernbench.runtime_api.kernel import ResponseMsg

        # The cube index is the number in the second dotted part of the node id.
        id_parts = self.node.id.split(".")
        cube_id = int(id_parts[1].replace("cube", ""))
        reply = ResponseMsg(
            correlation_id=txn.request.correlation_id,
            request_id=txn.request.request_id,
            src_cube=cube_id, src_pe=-1, success=True,
        )
        reply_txn = Transaction(
            request=reply, path=reverse_path, step=0,
            nbytes=0, done=env.event(), is_response=True,
        )
        # reverse_path[0] is this node; hand the reply to the next one back.
        yield self.out_ports[reverse_path[1]].put(reply_txn.advance())
|
||||
@@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
|
||||
|
||||
class AllocationError(Exception):
    """Raised by ``PEMemAllocator`` when a request does not fit in the remaining space."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AddressConfig:
    """Immutable sizing parameters used by the memory allocator."""

    # Topology counts
    sip_count: int
    cubes_per_sip: int
    pes_per_cube: int
    # HBM: total bytes per cube and the number of slices it is divided into
    hbm_bytes_per_cube: int
    hbm_slices_per_cube: int
    # TCM: bytes per PE and the part reserved for the scheduler
    tcm_bytes_per_pe: int
    tcm_scheduler_reserved_bytes: int
    # SRAM bytes per cube
    sram_bytes_per_cube: int

    @property
    def hbm_slice_bytes(self) -> int:
        """Bytes in one HBM slice (floor division of cube bytes by slice count)."""
        total, slices = self.hbm_bytes_per_cube, self.hbm_slices_per_cube
        return total // slices

    @property
    def tcm_allocatable_bytes(self) -> int:
        """TCM bytes per PE left after the scheduler reservation."""
        reserved = self.tcm_scheduler_reserved_bytes
        return self.tcm_bytes_per_pe - reserved
|
||||
|
||||
|
||||
class PEMemAllocator:
    """Bump allocator for one PE's HBM slice and its allocatable TCM.

    Each region has a cursor that only moves forward; there is no free
    operation. Capacity comes from ``cfg.hbm_slice_bytes`` and
    ``cfg.tcm_allocatable_bytes``.
    """

    def __init__(
        self, rack_id: int, sip_id: int, cube_id: int, pe_id: int, cfg: AddressConfig,
    ) -> None:
        self._rack_id = rack_id
        self._sip_id = sip_id
        self._cube_id = cube_id
        self._pe_id = pe_id
        self._cfg = cfg
        # Next free byte offset inside each region.
        self._hbm_cursor = 0
        self._tcm_cursor = 0

    def alloc_hbm(self, nbytes: int) -> PhysAddr:
        """Reserve ``nbytes`` of this PE's HBM slice and return its address.

        Raises:
            AllocationError: if ``nbytes`` is negative or exceeds the space
                left in the slice. The cursor is unchanged in both cases.
        """
        if nbytes < 0:
            # A negative size would rewind the cursor and make later
            # allocations overlap earlier ones.
            raise AllocationError(f"HBM allocation size must be non-negative: {nbytes}")
        if self._hbm_cursor + nbytes > self._cfg.hbm_slice_bytes:
            raise AllocationError(
                f"HBM overflow: need {nbytes}, "
                f"available {self._cfg.hbm_slice_bytes - self._hbm_cursor}"
            )
        pa = PhysAddr.pe_hbm_addr(
            rack_id=self._rack_id, sip_id=self._sip_id, cube_id=self._cube_id,
            pe_id=self._pe_id, pe_local_hbm_offset=self._hbm_cursor,
            slice_size_bytes=self._cfg.hbm_slice_bytes,
        )
        self._hbm_cursor += nbytes
        return pa

    def alloc_tcm(self, nbytes: int) -> PhysAddr:
        """Reserve ``nbytes`` of this PE's allocatable TCM and return its address.

        Raises:
            AllocationError: if ``nbytes`` is negative or exceeds the space
                left. The cursor is unchanged in both cases.
        """
        if nbytes < 0:
            # Same reasoning as alloc_hbm: never let the cursor move backwards.
            raise AllocationError(f"TCM allocation size must be non-negative: {nbytes}")
        if self._tcm_cursor + nbytes > self._cfg.tcm_allocatable_bytes:
            raise AllocationError(
                f"TCM overflow: need {nbytes}, "
                f"available {self._cfg.tcm_allocatable_bytes - self._tcm_cursor}"
            )
        pa = PhysAddr.pe_tcm_addr(
            rack_id=self._rack_id, sip_id=self._sip_id, cube_id=self._cube_id,
            pe_id=self._pe_id, tcm_offset=self._tcm_cursor,
        )
        self._tcm_cursor += nbytes
        return pa

    @property
    def hbm_used(self) -> int:
        """HBM bytes handed out so far."""
        return self._hbm_cursor

    @property
    def hbm_total(self) -> int:
        """Capacity of this PE's HBM slice in bytes."""
        return self._cfg.hbm_slice_bytes

    @property
    def tcm_used(self) -> int:
        """TCM bytes handed out so far."""
        return self._tcm_cursor

    @property
    def tcm_total(self) -> int:
        """Allocatable TCM capacity in bytes."""
        return self._cfg.tcm_allocatable_bytes
|
||||
@@ -0,0 +1,184 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from typing import Literal
|
||||
|
||||
# Largest value representable in the 51-bit physical address space.
MAX_51 = 2**51 - 1
|
||||
|
||||
|
||||
class PhysAddrError(Exception):
    """Raised when a physical-address field or encoded value is out of range."""
|
||||
|
||||
|
||||
def _chk_range(name: str, v: int, bits: int) -> None:
|
||||
if not (0 <= v < (1 << bits)):
|
||||
raise PhysAddrError(f"{name} out of range for {bits} bits: {v}")
|
||||
|
||||
|
||||
def _chk_max(name: str, v: int, maxv: int) -> None:
|
||||
if not (0 <= v <= maxv):
|
||||
raise PhysAddrError(f"{name} out of range (0..{maxv}): {v}")
|
||||
|
||||
|
||||
class UnitType(IntEnum):
    """Unit selector stored in a PE-resource address (values 0..2)."""

    PE = 0
    MCPU = 1
    SRAM = 2
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PhysAddr:
    """Immutable value object describing a 51-bit physical address.

    Bit layout of the encoded address:
        [50:47] rack_id       (4 bits)
        [46:43] sip_id        (4 bits)
        [42:38] sip_seg       (5 bits, carries the cube id)
        [37:0]  local_offset  (38 bits => each segment is 256GB)

    Bit 37 of ``local_offset`` selects the window:
        1 -> HBM window (128GB reserved); bits [36:0] hold ``hbm_offset``
        0 -> PE resource window, decoded as
             [36:34] unit_type, [33:30] pe_id, [29] ext, [28:0] sub_offset

    ``encode`` packs only ``rack_id``, ``sip_id``, ``sip_seg`` and
    ``local_offset``; the remaining fields are the decoded view of
    ``local_offset`` filled in by ``decode`` and the factory methods.
    """

    rack_id: int
    sip_id: int
    sip_seg: int
    local_offset: int

    kind: Literal["hbm", "pe_resource", "raw"] = "raw"
    cube_id: int = 0
    unit_type: UnitType = UnitType.PE
    pe_id: int = 0
    ext: int = 0
    sub_offset: int = 0
    hbm_offset: int = 0

    HBM_WINDOW_BYTES = 1 << 37  # 128GB

    def encode(self) -> int:
        """Pack the four positional fields into one integer.

        Raises:
            PhysAddrError: if a field does not fit its bit width or the
                packed value exceeds 51 bits.
        """
        for label, value, width in (
            ("rack_id", self.rack_id, 4),
            ("sip_id", self.sip_id, 4),
            ("sip_seg", self.sip_seg, 5),
            ("local_offset", self.local_offset, 38),
        ):
            _chk_range(label, value, width)
        packed = self.local_offset
        packed |= self.sip_seg << 38
        packed |= self.sip_id << 43
        packed |= self.rack_id << 47
        if not (0 <= packed <= MAX_51):
            raise PhysAddrError("address exceeds 51-bit space")
        return packed

    @staticmethod
    def decode(addr: int) -> PhysAddr:
        """Unpack a 51-bit integer into a ``PhysAddr`` of kind hbm or pe_resource.

        Raises:
            PhysAddrError: if ``addr`` is outside 0..MAX_51, or the PE
                resource window carries an unknown unit type.
        """
        if not (0 <= addr <= MAX_51):
            raise PhysAddrError("addr must be a 51-bit value")
        local = addr & ((1 << 38) - 1)
        seg = (addr >> 38) & 0x1F
        sip = (addr >> 43) & 0xF
        rack = (addr >> 47) & 0xF
        # Fields shared by both windows; the segment doubles as the cube id.
        shared = dict(
            rack_id=rack,
            sip_id=sip,
            sip_seg=seg,
            local_offset=local,
            cube_id=seg,
        )
        selector = (local >> 37) & 0x1
        if selector == 1:
            return PhysAddr(
                kind="hbm",
                hbm_offset=int(local & ((1 << 37) - 1)),
                **shared,
            )
        # Selector 0: PE resource window.
        unit_bits = int((local >> 34) & 0x7)
        try:
            unit = UnitType(unit_bits)
        except ValueError:
            raise PhysAddrError(f"unknown unit_type: {unit_bits}") from None
        return PhysAddr(
            kind="pe_resource",
            unit_type=unit,
            pe_id=int((local >> 30) & 0xF),
            ext=int((local >> 29) & 0x1),
            sub_offset=int(local & ((1 << 29) - 1)),
            hbm_offset=0,
            **shared,
        )

    @staticmethod
    def hbm_addr(*, rack_id: int, sip_id: int, cube_id: int, hbm_offset: int) -> PhysAddr:
        """Build an HBM-window address from a cube-wide HBM offset.

        Raises:
            PhysAddrError: if ``cube_id`` exceeds 31 or ``hbm_offset`` does
                not fit in 37 bits.
        """
        _chk_max("cube_id", cube_id, 31)
        _chk_range("hbm_offset", hbm_offset, 37)
        offset = int(hbm_offset)
        return PhysAddr(
            rack_id=rack_id,
            sip_id=sip_id,
            sip_seg=cube_id,
            local_offset=(1 << 37) | offset,  # bit 37 set selects the HBM window
            kind="hbm",
            cube_id=cube_id,
            hbm_offset=offset,
        )

    @staticmethod
    def pe_hbm_addr(
        *,
        rack_id: int,
        sip_id: int,
        cube_id: int,
        pe_id: int,
        pe_local_hbm_offset: int,
        slice_size_bytes: int,
    ) -> PhysAddr:
        """Build an HBM address from a PE id and an offset inside that PE's slice.

        The cube-wide offset is ``pe_id * slice_size_bytes + pe_local_hbm_offset``.

        Raises:
            PhysAddrError: if an id is out of range, the local offset is
                outside ``0..slice_size_bytes-1``, or the resulting offset
                leaves the reserved HBM window.
        """
        _chk_max("cube_id", cube_id, 31)
        _chk_range("pe_id", pe_id, 4)
        if not (0 <= pe_local_hbm_offset < slice_size_bytes):
            raise PhysAddrError("pe_local_hbm_offset out of PE local slice range")
        slice_base = int(pe_id) * int(slice_size_bytes)
        cube_offset = slice_base + int(pe_local_hbm_offset)
        if not (0 <= cube_offset < PhysAddr.HBM_WINDOW_BYTES):
            raise PhysAddrError("HBM offset exceeds reserved 128GB window")
        return PhysAddr.hbm_addr(
            rack_id=rack_id, sip_id=sip_id, cube_id=cube_id, hbm_offset=cube_offset
        )

    @staticmethod
    def hbm_pe_id(hbm_offset: int, slice_size_bytes: int) -> int:
        """Return the index of the slice that contains ``hbm_offset``."""
        slice_index = hbm_offset // slice_size_bytes
        return slice_index

    @staticmethod
    def cube_sram_addr(
        *, rack_id: int, sip_id: int, cube_id: int, sram_offset: int,
    ) -> PhysAddr:
        """Build a PE-resource address targeting the cube SRAM unit.

        Raises:
            PhysAddrError: if ``cube_id`` exceeds 31 or ``sram_offset`` does
                not fit in 29 bits.
        """
        _chk_max("cube_id", cube_id, 31)
        _chk_range("sram_offset", sram_offset, 29)
        unit_field = UnitType.SRAM << 34
        return PhysAddr(
            rack_id=rack_id,
            sip_id=sip_id,
            sip_seg=cube_id,
            local_offset=unit_field | sram_offset,
            kind="pe_resource",
            cube_id=cube_id,
            unit_type=UnitType.SRAM,
            sub_offset=sram_offset,
        )

    @staticmethod
    def pe_tcm_addr(
        *, rack_id: int, sip_id: int, cube_id: int, pe_id: int, tcm_offset: int,
    ) -> PhysAddr:
        """Build a PE-resource address targeting one PE's TCM.

        Raises:
            PhysAddrError: if ``cube_id`` exceeds 31, ``pe_id`` does not fit
                in 4 bits, or ``tcm_offset`` does not fit in 29 bits.
        """
        _chk_max("cube_id", cube_id, 31)
        _chk_range("pe_id", pe_id, 4)
        _chk_range("tcm_offset", tcm_offset, 29)
        unit_field = UnitType.PE << 34
        pe_field = pe_id << 30
        return PhysAddr(
            rack_id=rack_id,
            sip_id=sip_id,
            sip_seg=cube_id,
            local_offset=unit_field | pe_field | tcm_offset,
            kind="pe_resource",
            cube_id=cube_id,
            unit_type=UnitType.PE,
            pe_id=pe_id,
            sub_offset=tcm_offset,
        )
|
||||
@@ -0,0 +1,174 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Literal
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DPPolicy:
    """Data-parallel policy with one setting per level (cube, then PE)."""

    # How the tensor is distributed across cubes.
    cube: Literal["replicate", "shard_m", "shard_k"] = "replicate"
    # How each cube's portion is distributed across the PEs of that cube.
    pe: Literal["replicate", "column_wise", "row_wise"] = "replicate"
|
||||
|
||||
|
||||
def resolve_dp_policy(
    policy: DPPolicy,
    *,
    shape: tuple[int, int],
    itemsize: int,
    num_pe: int,
    num_cubes: int = 1,
) -> list[ShardSpec]:
    """Expand a DPPolicy into the flat list of shards it describes.

    With ``num_cubes <= 1`` only the pe-level policy is applied. Otherwise the
    cube-level policy picks each cube's sub-shape and base byte offset, the
    pe-level policy splits that sub-shape across ``num_pe`` PEs, and every
    resulting shard is re-indexed as ``cube_id * num_pe + pe_id``.

    Raises:
        ValueError: for an unknown pe-level policy, or (when more than one
            cube is used) an unknown cube-level policy.
    """
    pe_strategies = {
        "replicate": replicate,
        "column_wise": column_wise,
        "row_wise": row_wise,
    }
    if policy.pe not in pe_strategies:
        raise ValueError(f"Unknown pe-level policy: {policy.pe}")
    split_within_cube = pe_strategies[policy.pe]

    if num_cubes <= 1:
        return split_within_cube(shape=shape, itemsize=itemsize, num_pe=num_pe)

    # Two-level resolution: the per-cube shape and the byte stride between
    # consecutive cubes depend only on the policy, not on the cube index.
    M, K = shape
    if policy.cube == "replicate":
        cube_shape = (M, K)
        cube_stride = 0
    elif policy.cube == "shard_m":
        rows_per_cube = M // num_cubes
        cube_shape = (rows_per_cube, K)
        cube_stride = rows_per_cube * K * itemsize
    elif policy.cube == "shard_k":
        cols_per_cube = K // num_cubes
        cube_shape = (M, cols_per_cube)
        cube_stride = M * cols_per_cube * itemsize
    else:
        raise ValueError(f"Unknown cube-level policy: {policy.cube}")

    all_shards: list[ShardSpec] = []
    for cube_id in range(num_cubes):
        base_offset = cube_id * cube_stride
        first_flat_index = cube_id * num_pe
        # Shift each per-cube shard into the flat PE index space and add the
        # cube's base offset.
        all_shards.extend(
            ShardSpec(
                pe_index=first_flat_index + local.pe_index,
                offset_bytes=base_offset + local.offset_bytes,
                nbytes=local.nbytes,
            )
            for local in split_within_cube(
                shape=cube_shape, itemsize=itemsize, num_pe=num_pe
            )
        )
    return all_shards
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ShardSpec:
    """One shard of a tensor: which PE gets it and which byte range it covers."""

    pe_index: int  # PE the shard is assigned to
    offset_bytes: int  # byte offset of the shard within the tensor
    nbytes: int  # size of the shard in bytes
|
||||
|
||||
|
||||
def column_wise(
    *, shape: tuple[int, int], itemsize: int, num_pe: int,
) -> list[ShardSpec]:
    """Split the K axis into ``num_pe`` equal parts; PE i gets an (M, K // P) shard.

    Shard i starts at ``i * shard_bytes``; any remainder of ``K % num_pe``
    columns is not assigned.
    """
    M, K = shape
    cols_per_pe = K // num_pe
    shard_bytes = M * cols_per_pe * itemsize
    return [
        ShardSpec(pe_index=pe, offset_bytes=pe * shard_bytes, nbytes=shard_bytes)
        for pe in range(num_pe)
    ]
|
||||
|
||||
|
||||
def row_wise(
    *, shape: tuple[int, int], itemsize: int, num_pe: int,
) -> list[ShardSpec]:
    """Split the M axis into ``num_pe`` equal parts; PE i gets an (M // P, K) shard.

    Shard i starts at ``i * shard_bytes``; any remainder of ``M % num_pe``
    rows is not assigned.
    """
    M, K = shape
    rows_per_pe = M // num_pe
    shard_bytes = rows_per_pe * K * itemsize
    return [
        ShardSpec(pe_index=pe, offset_bytes=pe * shard_bytes, nbytes=shard_bytes)
        for pe in range(num_pe)
    ]
|
||||
|
||||
|
||||
def replicate(
    *, shape: tuple[int, int], itemsize: int, num_pe: int,
) -> list[ShardSpec]:
    """Give every PE a full (M, K) copy: offset 0 and the whole tensor size."""
    M, K = shape
    full_bytes = M * K * itemsize
    copies = []
    for pe in range(num_pe):
        copies.append(ShardSpec(pe_index=pe, offset_bytes=0, nbytes=full_bytes))
    return copies
|
||||
|
||||
|
||||
def tiled_column_major(
    *, shape: tuple[int, int], itemsize: int, num_pe: int,
    tile_m: int, tile_k: int,
) -> list[ShardSpec]:
    """2D tiling, column-major order (K axis first), round-robin across PEs.

    Tiles are emitted one tile-row at a time with the K tile index varying
    fastest, and tile number ``n`` goes to PE ``n % num_pe``.
    ``offset_bytes`` is the byte offset of the tile's first element, computed
    with a row stride of ``K * itemsize``.

    Tiles on the bottom/right edge are clipped to the tensor bounds, so
    ``nbytes`` never counts elements beyond row M or column K when M or K is
    not a multiple of the tile size. For evenly divisible shapes every tile
    is ``tile_m * tile_k * itemsize`` bytes.
    """
    M, K = shape
    # Integer ceil-division; avoids the float round-trip of ceil(M / tile_m).
    tiles_m = -(-M // tile_m)
    tiles_k = -(-K // tile_k)
    row_bytes = K * itemsize
    shards = []
    idx = 0
    for mi in range(tiles_m):
        rows = min(tile_m, M - mi * tile_m)  # clipped height of this tile row
        for ki in range(tiles_k):
            cols = min(tile_k, K - ki * tile_k)  # clipped width of this tile
            offset = (mi * tile_m * row_bytes) + (ki * tile_k * itemsize)
            shards.append(ShardSpec(
                pe_index=idx % num_pe,
                offset_bytes=offset,
                nbytes=rows * cols * itemsize,
            ))
            idx += 1
    return shards
|
||||
|
||||
|
||||
def tiled_row_major(
    *, shape: tuple[int, int], itemsize: int, num_pe: int,
    tile_m: int, tile_k: int,
) -> list[ShardSpec]:
    """2D tiling, row-major order (M axis first), round-robin across PEs.

    Tiles are emitted one tile-column at a time with the M tile index varying
    fastest, and tile number ``n`` goes to PE ``n % num_pe``.
    ``offset_bytes`` is the byte offset of the tile's first element, computed
    with a row stride of ``K * itemsize``.

    Tiles on the bottom/right edge are clipped to the tensor bounds, so
    ``nbytes`` never counts elements beyond row M or column K when M or K is
    not a multiple of the tile size. For evenly divisible shapes every tile
    is ``tile_m * tile_k * itemsize`` bytes.
    """
    M, K = shape
    # Integer ceil-division; avoids the float round-trip of ceil(M / tile_m).
    tiles_m = -(-M // tile_m)
    tiles_k = -(-K // tile_k)
    row_bytes = K * itemsize
    shards = []
    idx = 0
    for ki in range(tiles_k):
        cols = min(tile_k, K - ki * tile_k)  # clipped width of this tile column
        for mi in range(tiles_m):
            rows = min(tile_m, M - mi * tile_m)  # clipped height of this tile
            offset = (mi * tile_m * row_bytes) + (ki * tile_k * itemsize)
            shards.append(ShardSpec(
                pe_index=idx % num_pe,
                offset_bytes=offset,
                nbytes=rows * cols * itemsize,
            ))
            idx += 1
    return shards
|
||||
@@ -0,0 +1,184 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import heapq
|
||||
from collections import defaultdict
|
||||
|
||||
from kernbench.policy.address.phyaddr import PhysAddr, UnitType
|
||||
from kernbench.topology.types import TopologyGraph
|
||||
|
||||
|
||||
class RoutingError(Exception):
    """Raised when an address cannot be resolved to a node or no path exists."""
|
||||
|
||||
|
||||
class AddressResolver:
    """Map a PhysAddr to the id of its destination node in the compiled graph.

    The class also offers named node lookups (find_m_cpu, find_pcie_ep, ...)
    so component implementations never build node_id strings themselves;
    keeping the naming convention in one place means a single change
    propagates everywhere (ADR-0015 D4).
    """

    def __init__(self, graph: TopologyGraph) -> None:
        """Snapshot the graph's node ids and derive the HBM slice size in bytes."""
        self._node_ids = set(graph.nodes)
        memory_map = graph.spec["cube"]["memory_map"]
        hbm_bytes = memory_map["hbm_total_gb_per_cube"] * (1 << 30)
        self._slice_size_bytes = hbm_bytes // memory_map["hbm_slices_per_cube"]

    def _known(self, node_id: str, missing_msg: str) -> str:
        """Return ``node_id`` if the graph contains it, else raise RoutingError."""
        if node_id in self._node_ids:
            return node_id
        raise RoutingError(missing_msg)

    # -- Physical-address resolution ------------------------------------

    def resolve(self, addr: PhysAddr) -> str:
        """Return the node id addressed by ``addr``.

        HBM addresses map to the HBM controller slice containing the offset;
        PE-resource addresses map to the PE's TCM, the cube SRAM or the cube
        M_CPU depending on the unit type.

        Raises:
            RoutingError: for an unsupported kind or unit type, or when the
                resulting node is not part of the topology.
        """
        cube_prefix = f"sip{addr.sip_id}.cube{addr.cube_id}"
        if addr.kind == "hbm":
            slice_idx = PhysAddr.hbm_pe_id(addr.hbm_offset, self._slice_size_bytes)
            node_id = f"{cube_prefix}.hbm_ctrl.slice{slice_idx}"
        elif addr.kind == "pe_resource":
            unit = addr.unit_type
            if unit == UnitType.PE:
                node_id = f"{cube_prefix}.pe{addr.pe_id}.pe_tcm"
            elif unit == UnitType.SRAM:
                node_id = f"{cube_prefix}.sram"
            elif unit == UnitType.MCPU:
                node_id = f"{cube_prefix}.m_cpu"
            else:
                raise RoutingError(f"unsupported unit_type: {addr.unit_type}")
        else:
            raise RoutingError(f"unsupported address kind: {addr.kind}")
        return self._known(node_id, f"node {node_id} not found in topology")

    # -- Named node lookups ---------------------------------------------

    def find_m_cpu(self, sip: int, cube: int) -> str:
        """Return the M_CPU node id of the given cube, or raise RoutingError."""
        node_id = f"sip{sip}.cube{cube}.m_cpu"
        return self._known(node_id, f"M_CPU not found: {node_id}")

    def find_pcie_ep(self, sip: int, io_id: str = "io0") -> str:
        """Return the PCIE_EP node id of the given SIP/IO block, or raise RoutingError."""
        node_id = f"sip{sip}.{io_id}.pcie_ep"
        return self._known(node_id, f"PCIE_EP not found: {node_id}")

    def find_io_cpu(self, sip: int, io_id: str = "io0") -> str:
        """Return the IO_CPU node id of the given SIP/IO block, or raise RoutingError."""
        node_id = f"sip{sip}.{io_id}.io_cpu"
        return self._known(node_id, f"IO_CPU not found: {node_id}")

    def find_all_pcie_eps(self) -> list[str]:
        """Return every node id ending in ``.pcie_ep``, sorted."""
        endpoints = [nid for nid in self._node_ids if nid.endswith(".pcie_ep")]
        endpoints.sort()
        return endpoints
|
||||
|
||||
|
||||
class PathRouter:
    """Find data-path from a source PE (or arbitrary node) to a destination node.

    Three adjacency maps are built from ``graph.edges``:
      _adj          - excludes command edges (used by PE DMA routing, find_path)
      _adj_all      - includes all edges (used by component-to-component
                      routing, find_node_path; required because M_CPU<->NOC
                      links are "command")
      _adj_mcpu_dma - excludes the edge kinds in _MCPU_DMA_EXCLUDE (used by
                      cross-cube M_CPU DMA routing, find_mcpu_dma_path)

    The weight of an edge is its ``routing_weight_mm`` when that is not None,
    otherwise its ``distance_mm``.
    """

    # Edge kinds excluded from M_CPU DMA adjacency: prevents routing through
    # PE-internal pipeline nodes when computing DMA paths.
    _MCPU_DMA_EXCLUDE = {"pe_internal", "pe_to_xbar"}

    def __init__(self, graph: TopologyGraph) -> None:
        """Index every edge of ``graph`` into the three adjacency maps."""
        self._adj: dict[str, list[tuple[str, float]]] = defaultdict(list)
        self._adj_all: dict[str, list[tuple[str, float]]] = defaultdict(list)
        self._adj_mcpu_dma: dict[str, list[tuple[str, float]]] = defaultdict(list)
        for e in graph.edges:
            w = e.routing_weight_mm if e.routing_weight_mm is not None else e.distance_mm
            self._adj_all[e.src].append((e.dst, w))
            if e.kind != "command":
                self._adj[e.src].append((e.dst, w))
            if e.kind not in self._MCPU_DMA_EXCLUDE:
                self._adj_mcpu_dma[e.src].append((e.dst, w))

    def find_path(self, src_pe: str, dst_node: str) -> list[str]:
        """PE DMA routing: prepends .pe_dma, excludes command edges.

        Returns the node ids from ``{src_pe}.pe_dma`` to ``dst_node``.
        Raises RoutingError when no path exists.
        """
        start = f"{src_pe}.pe_dma"
        return self._run_dijkstra(self._adj, start, dst_node)

    def find_path_with_distance(self, src_pe: str, dst_node: str) -> tuple[list[str], float]:
        """Like find_path, but also return the total weight of the path."""
        start = f"{src_pe}.pe_dma"
        return self._run_dijkstra_with_dist(self._adj, start, dst_node)

    def find_mcpu_dma_path(self, m_cpu_id: str, dst_hbm_slice_id: str) -> list[str]:
        """M_CPU DMA path: never routes through PE-internal nodes (ADR-0015 D5).

        Same-cube: deterministic [m_cpu, noc, xbar.pe_i, hbm_ctrl.slice_i],
        built from the node-id strings without consulting the graph.
        Cross-cube: Dijkstra via _adj_mcpu_dma (pe_internal/pe_to_xbar excluded)
                    -> routes through NOC -> UCIe -> target cube NOC -> xbar -> HBM.
        """
        # The first two dotted components ("sipN.cubeM") identify the cube.
        m_cube = ".".join(m_cpu_id.split(".")[:2])
        d_cube = ".".join(dst_hbm_slice_id.split(".")[:2])
        if m_cube == d_cube:
            # The slice index is whatever follows the last "slice" in the id.
            slice_idx = int(dst_hbm_slice_id.rsplit("slice", 1)[1])
            return [
                m_cpu_id,
                f"{m_cube}.noc",
                f"{m_cube}.xbar.pe{slice_idx}",
                dst_hbm_slice_id,
            ]
        return self._run_dijkstra(self._adj_mcpu_dma, m_cpu_id, dst_hbm_slice_id)

    def find_node_path(self, src: str, dst: str) -> list[str]:
        """General routing between arbitrary nodes, including command edges.

        Used by components (IoCpuComponent, MCpuComponent) that route through
        M_CPU<->NOC command-kind links.
        """
        return self._run_dijkstra(self._adj_all, src, dst)

    def _run_dijkstra(
        self,
        adj: dict[str, list[tuple[str, float]]],
        start: str,
        goal: str,
    ) -> list[str]:
        """Return only the node list of the shortest path from start to goal."""
        path, _ = self._run_dijkstra_with_dist(adj, start, goal)
        return path

    def _run_dijkstra_with_dist(
        self,
        adj: dict[str, list[tuple[str, float]]],
        start: str,
        goal: str,
    ) -> tuple[list[str], float]:
        """Dijkstra over ``adj``; return (path, total weight).

        Raises:
            RoutingError: if ``goal`` is unreachable from ``start``.
        """
        if start == goal:
            return [start], 0.0
        best: dict[str, float] = {start: 0.0}
        prev: dict[str, str] = {}
        heap: list[tuple[float, str]] = [(0.0, start)]
        while heap:
            d, node = heapq.heappop(heap)
            if node == goal:
                # Walk the predecessor chain back to the start.
                path: list[str] = []
                cur = goal
                while cur != start:
                    path.append(cur)
                    cur = prev[cur]
                path.append(start)
                path.reverse()
                return path, d
            if d > best.get(node, float("inf")):
                continue  # stale heap entry; a shorter route was already found
            # .get() instead of adj[node]: the maps are defaultdicts, and
            # indexing would insert an empty list for every sink node visited,
            # mutating the adjacency data during a read-only query.
            for neighbor, edge_dist in adj.get(node, ()):
                new_d = d + edge_dist
                if new_d < best.get(neighbor, float("inf")):
                    best[neighbor] = new_d
                    prev[neighbor] = node
                    heapq.heappush(heap, (new_d, neighbor))
        raise RoutingError(f"no path from {start} to {goal}")

    # -- backward-compat shims (used by existing tests) -------------------

    def _dijkstra(self, start: str, goal: str) -> list[str]:
        """Shortest path over the non-command adjacency map."""
        return self._run_dijkstra(self._adj, start, goal)

    def _dijkstra_with_dist(self, start: str, goal: str) -> tuple[list[str], float]:
        """Shortest path and its weight over the non-command adjacency map."""
        return self._run_dijkstra_with_dist(self._adj, start, goal)
|
||||
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from kernbench.common.types import Completion, SimEngine, Trace
|
||||
|
||||
from .context import RuntimeContext
|
||||
from .types import BenchResult, DeviceSelector
|
||||
|
||||
|
||||
class CompletionPolicy(str, Enum):
    """How run_bench derives the overall completion from the submitted requests."""

    LAST_SUBMITTED = "last_submitted"
    # Requires trace/timestamps or engine support; stub for now.
    LAST_COMPLETED = "last_completed"
    ALL_OK_FAIL_FAST = "all_ok_fail_fast"
|
||||
|
||||
|
||||
# A benchmark body: receives the RuntimeContext and submits its requests.
BenchFn = Callable[[RuntimeContext], Any]
# Builds the simulation engine for a (topology, device selector) pair.
EngineFactory = Callable[[object, DeviceSelector], SimEngine]
|
||||
|
||||
|
||||
def run_bench(
    *,
    topology: object,
    bench_fn: BenchFn,
    device: DeviceSelector,
    engine_factory: EngineFactory,
    correlation_id: str = "bench0",
    completion_policy: CompletionPolicy = CompletionPolicy.LAST_SUBMITTED,
) -> BenchResult:
    """Run one benchmark function against a freshly built engine.

    Args:
        topology: compiled topology object (opaque to the runtime here); a
            ``topology_obj`` attribute, when present, is unwrapped to find
            the ``spec``.
        bench_fn: callable that receives the RuntimeContext and submits requests.
        device: DeviceSelector ("all" or "sip:<N>").
        engine_factory: builds the sim engine for the given topology & device.
        correlation_id: id stored on the context and on the result.
        completion_policy: how the overall completion/result is determined.

    Returns:
        A BenchResult. If the bench submitted nothing, its completion has
        ``ok=False`` and ``error_code="NO_REQUESTS"``.
    """
    engine = engine_factory(topology, device)

    # The spec lives on the TopologyGraph, which may be wrapped by a handle.
    graph_like = getattr(topology, "topology_obj", topology)
    ctx = RuntimeContext(
        engine=engine,
        target_device=device,
        correlation_id=correlation_id,
        spec=getattr(graph_like, "spec", None),
    )

    bench_fn(ctx)
    ctx.wait_all()

    # An empty trace list is reported as None.
    collected_traces = ctx._traces or None
    handles = ctx.handles()

    if not handles:
        nothing_submitted = Completion(
            ok=False, error_code="NO_REQUESTS", error_message="Bench submitted no requests"
        )
        return BenchResult(
            completion=nothing_submitted,
            correlation_id=correlation_id,
            trace=None,
            traces=collected_traces,
        )

    if completion_policy == CompletionPolicy.ALL_OK_FAIL_FAST:
        last_trace: Trace | None = None
        for h in handles:
            c, t = engine.get_completion(h)
            if t is not None:
                last_trace = t
            if not c.ok:
                # Stop at the first failing request and report it.
                return BenchResult(
                    completion=c, correlation_id=correlation_id,
                    trace=last_trace, traces=collected_traces,
                )
        return BenchResult(
            completion=Completion(ok=True), correlation_id=correlation_id,
            trace=last_trace, traces=collected_traces,
        )

    # LAST_SUBMITTED, and the LAST_COMPLETED placeholder (needs engine support
    # for timing), both report the most recently submitted request.
    completion, trace = engine.get_completion(handles[-1])
    return BenchResult(
        completion=completion, correlation_id=correlation_id,
        trace=trace, traces=collected_traces,
    )
|
||||
@@ -0,0 +1,282 @@
|
||||
# kernbench/runtime_api/context.py
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from kernbench.common.types import Completion, RequestHandle, SimEngine
|
||||
|
||||
from .types import DeviceSelector
|
||||
|
||||
|
||||
@dataclass
class RuntimeContext:
    """Per-benchmark wrapper around a simulation engine.

    Keeps the handles of submitted requests, the set of handles already
    waited on, and the annotated traces gathered by ``wait``. On top of that
    it offers a small tensor API (``zeros`` / ``empty``) and ``launch`` for
    kernels.
    """

    engine: SimEngine
    target_device: DeviceSelector
    correlation_id: str
    spec: dict | None = None

    _handles: list[RequestHandle] = field(default_factory=list, init=False)
    _completed: set[RequestHandle] = field(default_factory=set, init=False)
    _allocators: dict[int, Any] = field(default_factory=dict, init=False)
    _tensor_counter: int = field(default=0, init=False)
    _traces: list[dict] = field(default_factory=list, init=False)

    def submit(self, request: Any) -> RequestHandle:
        """Forward ``request`` to the engine and remember the returned handle.

        Raises:
            AttributeError: if the engine has no ``submit`` attribute.
        """
        engine_submit = getattr(self.engine, "submit", None)
        if engine_submit is not None:
            new_handle: RequestHandle = engine_submit(request)  # type: ignore[call-arg]
            self._handles.append(new_handle)
            return new_handle
        raise AttributeError("Engine does not implement submit(request) -> RequestHandle.")

    def is_completed(self, handle: RequestHandle) -> bool:
        """Return True once ``handle`` has been waited on through this context."""
        return handle in self._completed

    def wait(self, handle: RequestHandle, *, _meta: dict | None = None) -> Completion:
        """Wait for ``handle`` and return its completion.

        The engine's ``wait`` (if it has one) is called only the first time a
        handle is waited on. On that first wait, when both ``_meta`` and the
        engine's trace are present, a trace entry merged with ``_meta`` is
        appended to the collected traces.
        """
        already_done = handle in self._completed
        if not already_done:
            engine_wait = getattr(self.engine, "wait", None)
            if engine_wait is not None:
                engine_wait(handle)  # type: ignore[misc]

        result, trace = self.engine.get_completion(handle)
        if already_done:
            return result

        self._completed.add(handle)
        if _meta is None or trace is None:
            return result
        # Non-dict traces are wrapped so metadata can be merged in.
        if isinstance(trace, dict):
            record = dict(trace)
        else:
            record = {"raw": trace}
        record.update(_meta)
        self._traces.append(record)
        return result

    def wait_all(self) -> None:
        """Wait on every submitted handle that has not been waited on yet."""
        for pending in self._handles:
            if pending in self._completed:
                continue
            self.wait(pending)

    def handles(self) -> list[RequestHandle]:
        """Return a copy of the submitted handles, in submission order."""
        return self._handles.copy()

    # -- PyTorch-like tensor API ------------------------------------------

    def _ensure_allocators(self) -> dict:
        """Lazily create PEMemAllocator instances from spec.

        One allocator is created per (sip, cube, pe), keyed by the flat index
        ``sip_id * cubes_per_sip * pes_per_cube + cube_id * pes_per_cube + pe_id``.
        Also records ``_pes_per_cube``, ``_num_cubes`` and ``_num_sips``.

        Raises:
            RuntimeError: if ``spec`` is None.
        """
        if self._allocators:
            return self._allocators
        if self.spec is None:
            raise RuntimeError(
                "RuntimeContext.spec is required for tensor operations. "
                "Pass spec=graph.spec when creating RuntimeContext."
            )
        from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator

        # Walk the spec; every missing level falls back to a default.
        sips_cfg = self.spec.get("system", {}).get("sips", {})
        cube = self.spec.get("cube", {})
        mm = cube.get("memory_map", {})
        pe_layout = cube.get("pe_layout", {})
        tcm_cfg = (
            cube.get("pe_template", {})
            .get("components", {})
            .get("pe_tcm", {})
            .get("attrs", {})
        )

        sip_count = sips_cfg.get("count", 1)
        cubes_per_sip = sips_cfg.get("cubes_per_sip", 16)
        pe_per_corner = pe_layout.get("pe_per_corner", 2)
        corners = pe_layout.get("corners", ["NW", "NE", "SW", "SE"])
        pes_per_cube = pe_per_corner * len(corners)
        hbm_gb = mm.get("hbm_total_gb_per_cube", 48)
        hbm_slices = mm.get("hbm_slices_per_cube", 8)
        tcm_mb = tcm_cfg.get("size_mb", 16)

        cfg = AddressConfig(
            sip_count=sip_count,
            cubes_per_sip=cubes_per_sip,
            pes_per_cube=pes_per_cube,
            hbm_bytes_per_cube=hbm_gb * (1 << 30),
            hbm_slices_per_cube=hbm_slices,
            tcm_bytes_per_pe=tcm_mb * (1 << 20),
            tcm_scheduler_reserved_bytes=4 * (1 << 20),
            sram_bytes_per_cube=32 * (1 << 20),
        )

        self._pes_per_cube = pes_per_cube
        self._num_cubes = cubes_per_sip
        self._num_sips = sip_count

        # Create allocators for all SIPs x cubes x PEs.
        for sip_id in range(sip_count):
            for cube_id in range(cubes_per_sip):
                cube_base = (sip_id * cubes_per_sip + cube_id) * pes_per_cube
                for pe_id in range(pes_per_cube):
                    self._allocators[cube_base + pe_id] = PEMemAllocator(
                        rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
                    )
        return self._allocators

    def _next_tensor_name(self) -> str:
        """Return the next auto-generated tensor name (t1, t2, ...)."""
        self._tensor_counter += 1
        return f"t{self._tensor_counter}"

    def zeros(
        self,
        shape: tuple[int, ...],
        dtype: str = "f16",
        *,
        placement: list | None = None,
        dp: Any = None,
        name: str | None = None,
    ):
        """Create a tensor and deploy to HBM with zero-fill (like torch.zeros)."""
        return self._create_tensor(shape, dtype, placement, name, pattern="zero", dp=dp)

    def empty(
        self,
        shape: tuple[int, ...],
        dtype: str = "f16",
        *,
        placement: list | None = None,
        dp: Any = None,
        name: str | None = None,
    ):
        """Allocate a tensor in HBM without initialization (like torch.empty)."""
        return self._create_tensor(shape, dtype, placement, name, pattern=None, dp=dp)

    def _create_tensor(
        self,
        shape: tuple[int, ...],
        dtype: str,
        placement: list | None,
        name: str | None,
        pattern: str | None,
        dp: Any = None,
    ):
        """Build a Tensor, place it, allocate addresses and optionally fill it.

        A DPPolicy passed as ``dp`` takes priority over ``placement``; with
        neither, the whole tensor goes to PE index 0. When ``pattern`` is not
        None, one MemoryWriteMsg per shard is submitted and waited on.
        """
        from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
        from kernbench.runtime_api.kernel import MemoryWriteMsg
        from kernbench.runtime_api.tensor import Tensor, deploy_tensor, dtype_itemsize

        tensor_name = name or self._next_tensor_name()
        t = Tensor(shape=shape, dtype=dtype, name=tensor_name)

        dp_policy: DPPolicy | None = None

        # Resolve placement: dp= takes priority over placement=
        if isinstance(dp, DPPolicy):
            dp_policy = dp
            # Also populates _num_sips / _num_cubes / _pes_per_cube used below.
            self._ensure_allocators()
            itemsize = dtype_itemsize(dtype)
            shape_2d = (shape[0], shape[1])  # type: tuple[int, int]
            placement = resolve_dp_policy(
                dp,
                shape=shape_2d,
                itemsize=itemsize,
                num_pe=self._pes_per_cube,
                num_cubes=self._num_sips * self._num_cubes,
            )
        elif placement is None:
            placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=t.nbytes)]

        # Infer target_pe from placement: multi-PE -> "all", single PE -> pe_index
        pe_indices = {spec.pe_index for spec in placement}
        target_pe: int | str
        if len(pe_indices) > 1:
            target_pe = "all"
        else:
            target_pe = next(iter(pe_indices))
        t.to(placement=placement, target_pe=target_pe, dp_policy=dp_policy)

        # Allocate PAs via PEMemAllocator
        handle = deploy_tensor(
            name=tensor_name,
            shape=shape,
            dtype=dtype,
            placement=placement,
            allocators=self._ensure_allocators(),
        )
        t._handle = handle

        if pattern is None:
            return t

        # Submit MemoryWriteMsg per shard (deploy data to device)
        for shard in handle.shards:
            write_msg = MemoryWriteMsg(
                correlation_id=self.correlation_id,
                request_id=f"deploy_{tensor_name}_pe{shard.pe}",
                dst_sip=shard.sip, dst_cube=shard.cube, dst_pe=shard.pe,
                dst_pa=shard.pa, nbytes=shard.nbytes, pattern=pattern,
                target_cubes=(shard.cube,), target_pe=shard.pe,
            )
            write_meta = {
                "phase": "memory_write", "name": tensor_name,
                "sip": shard.sip, "cube": shard.cube, "pe": shard.pe,
                "nbytes": shard.nbytes,
            }
            self.wait(self.submit(write_msg), _meta=write_meta)
        return t

    def launch(
        self,
        kernel_name: str,
        kernel_fn: Any,
        *args: Any,
        **kwargs: Any,
    ) -> RequestHandle:
        """Register and launch a kernel (like a fused torch op).

        Positional args: Tensor objects become TensorArg, int/float become ScalarArg.
        Keyword args: become ScalarArg (name is discarded, order preserved).
        Arguments of any other type are ignored. The launch is waited on
        before the handle is returned.
        """
        from kernbench.runtime_api.kernel import (
            KernelLaunchMsg,
            KernelRef,
            ScalarArg,
        )
        from kernbench.runtime_api.tensor import Tensor
        from kernbench.triton_emu.registry import register_kernel

        # Register kernel (idempotent): a ValueError from the registry is ignored.
        try:
            register_kernel(kernel_name, kernel_fn)
        except ValueError:
            pass

        def as_scalar(value):
            # floats are tagged f32, every other accepted number i32
            return ScalarArg(dtype="f32" if isinstance(value, float) else "i32", value=value)

        # Build kernel args from positional + keyword args
        kernel_args: list = []
        target_pe: int | str = 0

        for arg in args:
            if isinstance(arg, Tensor):
                kernel_args.append(arg.to_tensor_arg())
                # Infer target_pe from tensor DP metadata; "all" is sticky.
                meta = arg._dp_metadata
                if meta is not None:
                    dp_target = meta.target_pe
                    if dp_target == "all":
                        target_pe = "all"
                    elif isinstance(dp_target, int) and target_pe != "all":
                        target_pe = dp_target
            elif isinstance(arg, (int, float)):
                kernel_args.append(as_scalar(arg))

        kernel_args.extend(
            as_scalar(value) for value in kwargs.values() if isinstance(value, (int, float))
        )

        # Determine target cubes from all tensor shards
        cube_set: set[int] = {
            shard.cube
            for arg in args
            if isinstance(arg, Tensor) and arg._handle is not None
            for shard in arg._handle.shards
        }
        target_cubes = tuple(sorted(cube_set)) if cube_set else (0,)

        # Collect scalar values for GEMM FLOP calculation
        scalar_vals = [ka.value for ka in kernel_args if hasattr(ka, "value")]

        launch_msg = KernelLaunchMsg(
            correlation_id=self.correlation_id,
            request_id=kernel_name,
            kernel_ref=KernelRef(name=kernel_name, kind="builtin"),
            args=tuple(kernel_args),
            target_cubes=target_cubes,
            target_pe=target_pe,
        )
        h = self.submit(launch_msg)
        self.wait(h, _meta={
            "phase": "kernel", "name": kernel_name,
            "target_pe": target_pe, "scalars": scalar_vals,
        })
        return h
|
||||
@@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MemoryWriteMsg:
    """Immutable host request describing a write into device memory.

    The simulation engine picks its entry point from ``dst_sip`` and uses
    ``nbytes`` as the payload size of the injected transaction.
    """

    # Request identity.
    correlation_id: str
    request_id: str

    # Destination location (SIP / cube / PE indices) and physical address.
    dst_sip: int
    dst_cube: int
    dst_pe: int
    dst_pa: int

    # Payload size in bytes.
    nbytes: int

    # Where the written data comes from; ``pattern`` is optional.
    src_kind: Literal["pattern", "host_buffer_ref"] = "pattern"
    pattern: str | None = None

    # Targeting selectors; the default "all" applies no restriction.
    target_cubes: tuple[int, ...] | Literal["all"] = "all"
    target_pe: int | Literal["all"] = "all"

    # Fixed discriminator identifying this message kind.
    msg_type: Literal["memory_write"] = "memory_write"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MemoryReadMsg:
    """Immutable host request describing a read from device memory.

    The simulation engine picks its entry point from ``src_sip`` and uses
    ``nbytes`` as the payload size of the injected transaction.
    """

    # Request identity.
    correlation_id: str
    request_id: str

    # Source location (SIP / cube / PE indices) and physical address.
    src_sip: int
    src_cube: int
    src_pe: int
    src_pa: int

    # Number of bytes to read.
    nbytes: int

    # Targeting selectors; the default "all" applies no restriction.
    target_cubes: tuple[int, ...] | Literal["all"] = "all"
    target_pe: int | Literal["all"] = "all"

    # Fixed discriminator identifying this message kind.
    msg_type: Literal["memory_read"] = "memory_read"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class KernelRef:
    """Reference to a kernel binary or builtin timing model.

    Kernel binaries must be pre-deployed to device memory via MemoryWriteMsg.
    KernelLaunchMsg references the deployed location by PA — source code or IR
    MUST NOT be embedded in launch messages.

    - "deployed": kernel binary pre-deployed to HBM/SRAM at deploy_pa.
    - "builtin":  simulator built-in timing model, identified by name.
    """

    # Kernel identifier and which of the two kinds above it refers to.
    name: str
    kind: Literal["deployed", "builtin"]

    # Deployment location; these keep their defaults unless the caller sets
    # them (presumably only meaningful for kind="deployed" — confirm).
    deploy_pa: int | None = None
    deploy_sip: int = 0
    deploy_cube: int = 0
    deploy_pe: int = 0

    # Size of the deployed code in bytes.
    nbytes_code: int = 0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TensorArgShard:
    """One shard of a tensor kernel argument: where it lives and how big it is."""

    # Placement of the shard (SIP / cube / PE indices).
    sip: int
    cube: int
    pe: int

    # Physical address of the shard.
    pa: int

    # Shard size and its byte offset (presumably within the full tensor — confirm).
    nbytes: int
    offset_bytes: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TensorArg:
    """Tensor kernel argument, expressed as the tuple of its shards."""

    # All shards that make up the tensor.
    shards: tuple[TensorArgShard, ...]

    # Fixed discriminator; the engine checks it to tell tensors from scalars.
    arg_kind: Literal["tensor"] = "tensor"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ScalarArg:
    """Scalar kernel argument: a dtype tag plus the numeric value."""

    # Dtype name as a plain string (e.g. "f32" or "i32").
    dtype: str

    # The scalar value itself.
    value: float | int

    # Fixed discriminator distinguishing scalars from tensor arguments.
    arg_kind: Literal["scalar"] = "scalar"
|
||||
|
||||
|
||||
# Union of the argument kinds accepted in KernelLaunchMsg.args.
KernelArg: TypeAlias = TensorArg | ScalarArg
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class KernelLaunchMsg:
    """Immutable host request that launches a kernel on the device.

    The simulation engine derives one entry point per distinct SIP found in
    the tensor shards of ``args``.
    """

    # Request identity.
    correlation_id: str
    request_id: str

    # Which kernel to run and the arguments to pass, in order.
    kernel_ref: KernelRef
    args: tuple[KernelArg, ...]

    # Targeting selectors; the default "all" applies no restriction.
    target_cubes: tuple[int, ...] | Literal["all"] = "all"
    target_pe: int | Literal["all"] = "all"

    # Fixed discriminator identifying this message kind.
    msg_type: Literal["kernel_launch"] = "kernel_launch"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ResponseMsg:
    """Device→Host response carrying PE execution result."""

    # Identity of the request this response belongs to.
    correlation_id: str
    request_id: str

    # Cube / PE indices of the responder.
    src_cube: int
    src_pe: int

    # Outcome flag for the execution.
    success: bool

    # Fixed discriminator identifying this message kind.
    msg_type: Literal["response"] = "response"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PeDmaMsg:
    """Direct PE DMA request: host injects a transfer at PE_DMA level.

    Used by the probe utility to measure PE→HBM latency without requiring
    the full PE_CPU → scheduler → DMA pipeline.
    """

    # Request identity.
    correlation_id: str
    request_id: str

    # Issuing PE (SIP / cube / PE indices); the engine builds the PE_DMA
    # node id from these three values.
    src_sip: int
    src_cube: int
    src_pe: int

    # Encoded destination physical address; the engine decodes it with
    # PhysAddr.decode to find the destination node.
    dst_pa: int

    # Transfer size in bytes.
    nbytes: int

    # Direction flag; defaults to a read.
    is_write: bool = False

    # Fixed discriminator identifying this message kind.
    msg_type: Literal["pe_dma"] = "pe_dma"
|
||||
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from kernbench.policy.address.allocator import PEMemAllocator
|
||||
from kernbench.policy.placement.dp import DPPolicy, ShardSpec
|
||||
from kernbench.runtime_api.kernel import TensorArg, TensorArgShard
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TensorShard:
    """Location and extent of one deployed tensor shard."""

    # Placement of the shard (SIP / cube / PE indices).
    sip: int
    cube: int
    pe: int

    # Encoded physical address of the allocation backing the shard.
    pa: int

    # Shard size and byte offset, copied from the placement spec.
    nbytes: int
    offset_bytes: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TensorHandle:
    """Immutable record of a deployed tensor and the shards backing it."""

    # Tensor identity and logical layout.
    name: str
    shape: tuple[int, ...]
    dtype: str

    # Size in bytes of a single element of ``dtype``.
    itemsize: int

    # Deployed shards, in placement order.
    shards: tuple[TensorShard, ...]

    @property
    def nbytes(self) -> int:
        """Total size in bytes: the product of ``shape`` times ``itemsize``."""
        element_count = math.prod(self.shape)
        return element_count * self.itemsize
|
||||
|
||||
|
||||
_DTYPE_ITEMSIZE = {
|
||||
"fp16": 2, "float16": 2, "f16": 2,
|
||||
"fp32": 4, "float32": 4, "f32": 4,
|
||||
"bf16": 2,
|
||||
"int8": 1, "i8": 1,
|
||||
"int16": 2, "i16": 2,
|
||||
"int32": 4, "i32": 4,
|
||||
}
|
||||
|
||||
|
||||
def dtype_itemsize(dtype: str) -> int:
    """Return the element size in bytes for *dtype*.

    Raises:
        ValueError: if *dtype* is not a key of ``_DTYPE_ITEMSIZE``.
    """
    size = _DTYPE_ITEMSIZE.get(dtype)
    if size is None:
        raise ValueError(f"unsupported dtype: {dtype}")
    return size
|
||||
|
||||
|
||||
def deploy_tensor(
    *,
    name: str,
    shape: tuple[int, ...],
    dtype: str,
    placement: list[ShardSpec],
    allocators: dict[int, PEMemAllocator],
    mem_kind: Literal["hbm", "tcm"] = "hbm",
) -> TensorHandle:
    """Allocate memory for every shard in *placement* and return the handle.

    Args:
        name: Tensor name recorded on the handle.
        shape: Logical tensor shape recorded on the handle.
        dtype: Dtype name; validated through ``dtype_itemsize`` before any
            allocation takes place.
        placement: One spec per shard (PE index, size and byte offset).
        allocators: Per-PE allocators, keyed by ``ShardSpec.pe_index``.
        mem_kind: ``"hbm"`` allocates with ``alloc_hbm``; any other value
            allocates with ``alloc_tcm``.

    Returns:
        A ``TensorHandle`` whose shards follow the order of *placement*.

    Raises:
        ValueError: if *dtype* is unsupported.
        KeyError: if a spec names a PE index missing from *allocators*.
    """
    item_bytes = dtype_itemsize(dtype)
    use_hbm = mem_kind == "hbm"

    def _place(spec: ShardSpec) -> TensorShard:
        # Allocate on the PE named by the spec, then record where it landed.
        pe_alloc = allocators[spec.pe_index]
        if use_hbm:
            addr = pe_alloc.alloc_hbm(spec.nbytes)
        else:
            addr = pe_alloc.alloc_tcm(spec.nbytes)
        return TensorShard(
            sip=pe_alloc._sip_id,
            cube=pe_alloc._cube_id,
            pe=pe_alloc._pe_id,
            pa=addr.encode(),
            nbytes=spec.nbytes,
            offset_bytes=spec.offset_bytes,
        )

    placed = tuple([_place(spec) for spec in placement])
    return TensorHandle(
        name=name,
        shape=shape,
        dtype=dtype,
        itemsize=item_bytes,
        shards=placed,
    )
|
||||
|
||||
|
||||
# ── PyTorch-like Tensor API ──────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DPMetadata:
    """Data-parallel placement metadata (stored as Tensor._dp_metadata)."""

    # Per-shard placement specs.
    placement: list[ShardSpec]

    # Optional policy object associated with the placement.
    dp_policy: DPPolicy | None = None

    # SIP and cube indices for the placement.
    sip: int = 0
    cube: int = 0

    # int → single PE, "all" → all PEs
    target_pe: int | str = 0
|
||||
|
||||
|
||||
class Tensor:
    """PyTorch-like tensor for benchmark code.

    Usage::

        a = ctx.zeros((M, K), dtype="f16")
        a = ctx.zeros((M, K), dtype="f16", placement=dp.replicate(num_pe=8))
        ctx.launch("kernel_name", kernel_fn, a, b, out, M=M, K=K)
    """

    def __init__(
        self,
        shape: tuple[int, ...],
        dtype: str = "f16",
        name: str = "",
    ) -> None:
        """Create an undeployed tensor with no placement metadata."""
        self.shape = shape
        self.dtype = dtype
        self.name = name
        # Both stay None until set: placement by ``to()``, handle externally.
        self._dp_metadata: DPMetadata | None = None
        self._handle: TensorHandle | None = None

    @property
    def itemsize(self) -> int:
        """Element size in bytes for this tensor's dtype."""
        return dtype_itemsize(self.dtype)

    @property
    def nbytes(self) -> int:
        """Total size in bytes: the product of ``shape`` times ``itemsize``."""
        element_count = math.prod(self.shape)
        return element_count * self.itemsize

    @property
    def pa(self) -> int:
        """Primary PA (first shard). Used as kernel pointer argument.

        Raises:
            RuntimeError: if no handle is attached or it has no shards.
        """
        handle = self._handle
        shards = () if handle is None else handle.shards
        if not shards:
            raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
        return shards[0].pa

    def to(
        self,
        placement: list[ShardSpec] | None = None,
        *,
        dp_policy: DPPolicy | None = None,
        sip: int = 0,
        cube: int = 0,
        target_pe: int | str = 0,
    ) -> Tensor:
        """Set DP placement metadata (like torch.Tensor.to()).

        With no *placement*, a single shard covering the whole tensor on PE
        index 0 is used.  Returns ``self`` so calls can be chained.
        """
        if placement is None:
            whole = ShardSpec(pe_index=0, offset_bytes=0, nbytes=self.nbytes)
            placement = [whole]
        self._dp_metadata = DPMetadata(
            placement=placement,
            dp_policy=dp_policy,
            sip=sip,
            cube=cube,
            target_pe=target_pe,
        )
        return self

    def to_tensor_arg(self) -> TensorArg:
        """Convert deployed shards to KernelLaunchMsg TensorArg.

        Raises:
            RuntimeError: if no handle is attached.
        """
        handle = self._handle
        if handle is None:
            raise RuntimeError(f"Tensor '{self.name}' is not deployed yet")
        arg_shards = [
            TensorArgShard(
                sip=shard.sip,
                cube=shard.cube,
                pe=shard.pe,
                pa=shard.pa,
                nbytes=shard.nbytes,
                offset_bytes=shard.offset_bytes,
            )
            for shard in handle.shards
        ]
        return TensorArg(shards=tuple(arg_shards))
|
||||
@@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
from kernbench.common.types import Completion, Trace
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class BenchResult:
    """Outcome of a benchmark run plus optional trace data."""

    # Completion status of the run.
    completion: Completion

    # Correlation id shown in the summary line.
    correlation_id: str

    # Optional trace payloads.
    trace: Trace | None = None
    traces: list[dict] | None = None

    def summary_text(self) -> str:
        """Return a one-line ``[OK]`` / ``[FAIL:<code>]`` summary.

        On failure a missing error code is shown as ``ERROR`` and trailing
        whitespace left by an empty message is stripped.
        """
        outcome = self.completion
        if not outcome.ok:
            code = outcome.error_code or "ERROR"
            detail = outcome.error_message or ""
            line = f"[FAIL:{code}] correlation_id={self.correlation_id} {detail}"
            return line.rstrip()
        return f"[OK] correlation_id={self.correlation_id}"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DeviceSelector:
    """
    Device selector.

    Supported:
      - "all"      : all SIPs in the tray topology
      - "sip:<N>"  : a single SIP index
    """

    raw: str  # "all" or "sip:<N>"

    @property
    def is_all(self) -> bool:
        """True when the selector is exactly ``"all"``."""
        return self.raw == "all"

    @property
    def sip_index(self) -> int:
        """Return N for a ``"sip:<N>"`` selector.

        Raises:
            ValueError: if the selector is ``"all"`` or is not ``sip:<N>``.
        """
        if self.is_all:
            raise ValueError("DeviceSelector is 'all'; no single sip_index.")
        if (match := re.fullmatch(r"sip:(\d+)", self.raw)) is None:
            raise ValueError(
                f"Invalid device '{self.raw}'. Expected 'all' or 'sip:<N>' (e.g., sip:0)."
            )
        return int(match.group(1))
|
||||
|
||||
|
||||
def resolve_device(raw: str | None) -> DeviceSelector:
    """
    Resolve the CLI --device string into a DeviceSelector.

    Semantics:
      - if omitted/empty -> "all"
      - else accept "all" or "sip:<N>"

    Input is stripped and lower-cased before matching.

    Raises:
        ValueError: if the normalized string is neither "all" nor "sip:<N>".
    """
    normalized = "" if raw is None else raw.strip().lower()
    # Omitted, blank and explicit "all" resolve to the same selector.
    if normalized in {"", "all"}:
        return DeviceSelector(raw="all")

    if re.fullmatch(r"sip:(\d+)", normalized) is None:
        raise ValueError(
            f"Invalid device '{normalized}'. Expected 'all' or 'sip:<N>' (e.g., sip:0)."
        )
    return DeviceSelector(raw=normalized)
|
||||
@@ -0,0 +1,31 @@
|
||||
# kernbench/engine/dummy.py
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from kernbench.common.types import Completion, RequestHandle, SimEngine, Trace
|
||||
|
||||
|
||||
@dataclass
class DummyEngine(SimEngine):
    """Stand-in engine that completes every request immediately with ok=True."""

    topology: object
    device_raw: str
    # Running count of submitted requests; used to mint handle names.
    _n: int = 0
    # Handle string → (completion, trace); replaced by a fresh dict in __post_init__.
    _store: dict[str, tuple[Completion, Trace | None]] = None  # type: ignore

    def __post_init__(self) -> None:
        """Give each instance its own empty result store."""
        self._store = {}

    def submit(self, request: Any) -> RequestHandle:
        """Record an immediately-successful completion and return its handle."""
        self._n += 1
        handle = RequestHandle(f"h{self._n}")
        # Request handling / simulation / scheduling would be performed here.
        trace = {"request": request, "device": self.device_raw}
        self._store[str(handle)] = (Completion(ok=True), trace)
        return handle

    def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
        """Return the stored (completion, trace) pair; KeyError if the handle is unknown."""
        return self._store[str(handle)]

    def wait(self, handle: RequestHandle) -> None:
        """No-op: results are already available when ``submit`` returns."""
        pass
|
||||
@@ -0,0 +1,298 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.common.types import Completion, RequestHandle, Trace
|
||||
import kernbench.components.impls # noqa: F401 — registers built-in implementations
|
||||
from kernbench.components.base import ComponentBase, ComponentRegistry
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
from kernbench.policy.routing.router import AddressResolver, PathRouter
|
||||
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, PeDmaMsg
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
from kernbench.topology.types import Edge, TopologyGraph
|
||||
|
||||
|
||||
class GraphEngine:
    """simpy-based discrete-event simulation engine.

    Phase B: engine injects a Transaction into the PCIE_EP host queue for
    each request.  Components handle their own routing:
      Path 1: PCIE_EP → IO_CPU  (engine-computed path, pre-loaded in Transaction)
      Path 2: IO_CPU → M_CPU    (IO_CPU dispatches, fire-and-forget callback)
      Path 3: M_CPU.DMA → HBM   (M_CPU dispatches, fire-and-forget callback)

    Component implementations are DI-injectable via component_overrides (ADR-0007 D3).
    """

    def __init__(
        self,
        graph: TopologyGraph,
        *,
        component_overrides: dict[str, type[ComponentBase]] | None = None,
    ) -> None:
        """Instantiate components, wire their ports and start them.

        Args:
            graph: Compiled topology providing nodes, edges and the raw spec.
            component_overrides: Optional overrides handed to
                ``ComponentRegistry.create`` for every node.
        """
        self._env = simpy.Environment()
        self._resolver = AddressResolver(graph)
        self._router = PathRouter(graph)
        self._nodes = graph.nodes
        # Directed (src, dst) → Edge lookup used by the drain/latency helpers.
        self._edge_map: dict[tuple[str, str], Edge] = {}
        for e in graph.edges:
            self._edge_map[(e.src, e.dst)] = e
        self._ns_per_mm: float = graph.spec.get("system", {}).get("ns_per_mm", 0.01)
        self._results: dict[str, tuple[Completion, Trace]] = {}
        self._events: dict[str, simpy.Event] = {}
        self._counter = 0
        overrides = component_overrides or {}
        ctx = ComponentContext(
            router=self._router,
            resolver=self._resolver,
            positions={nid: n.pos_mm for nid, n in graph.nodes.items()},
            ns_per_mm=self._ns_per_mm,
            edge_map=self._edge_map,
            spec=graph.spec,
        )
        self._components: dict[str, ComponentBase] = {
            node_id: ComponentRegistry.create(node, overrides, ctx)
            for node_id, node in graph.nodes.items()
        }

        # Wire ports: a TX Store (source side) and an RX Store (destination
        # side) per directed edge (ADR-0015 D1).  They must be distinct
        # objects: the wire process relays TX → RX after the propagation
        # delay.  With one shared Store the wire would put the message back
        # into the queue it just read from, so the destination's pending
        # get() could take a newly arrived message with no delay, and the
        # wire could take its own message again and delay it twice.
        for e in graph.edges:
            src_comp = self._components.get(e.src)
            dst_comp = self._components.get(e.dst)
            if src_comp is None or dst_comp is None:
                continue
            tx_store: simpy.Store = simpy.Store(self._env)
            rx_store: simpy.Store = simpy.Store(self._env)
            src_comp.out_ports[e.dst] = tx_store
            dst_comp.in_ports[e.src] = rx_store

        # Wire processes: propagation delay per edge (ADR-0015 D2)
        # Cut-through (wormhole) model: wires apply propagation only.
        # Serialization (drain) is computed per-path and applied once at the terminal.
        for e in graph.edges:
            src_comp = self._components.get(e.src)
            dst_comp = self._components.get(e.dst)
            if src_comp is None or dst_comp is None:
                continue
            prop_ns = e.distance_mm * self._ns_per_mm
            self._env.process(
                self._wire(src_comp.out_ports[e.dst], dst_comp.in_ports[e.src], prop_ns)
            )

        # Attach host queues to PCIE_EP in_ports before start() (ADR-0015 D3)
        self._host_queues: dict[str, simpy.Store] = {}
        for pcie_ep_id in self._resolver.find_all_pcie_eps():
            host_q: simpy.Store = simpy.Store(self._env)
            self._components[pcie_ep_id].in_ports["host"] = host_q
            self._host_queues[pcie_ep_id] = host_q

        # Attach host queues to PE_DMA nodes for direct PE DMA injection
        self._pe_dma_queues: dict[str, simpy.Store] = {}
        for node_id, node in graph.nodes.items():
            if node.kind == "pe_dma":
                host_q = simpy.Store(self._env)
                self._components[node_id].in_ports["host"] = host_q
                self._pe_dma_queues[node_id] = host_q

        # Start components after all ports are wired (ADR-0015 D3)
        for comp in self._components.values():
            comp.start(self._env)

    def submit(self, request: Any) -> RequestHandle:
        """Schedule *request* for simulation and return its handle.

        Nothing is simulated here; the request's process is only registered.
        Simulated time advances when ``wait`` runs the environment.
        """
        self._counter += 1
        handle = RequestHandle(f"h{self._counter}")
        event = self._env.event()
        self._events[str(handle)] = event
        self._env.process(self._process(str(handle), request, event))
        return handle

    def wait(self, handle: RequestHandle) -> None:
        """Run the simulation until the request behind *handle* completes.

        Raises:
            KeyError: if *handle* was not returned by ``submit``.
        """
        key = str(handle)
        event = self._events[key]
        if not event.triggered:
            self._env.run(until=event)

    def get_completion(self, handle: RequestHandle) -> tuple[Completion, Trace | None]:
        """Return the (completion, trace) pair recorded for *handle*.

        Raises:
            KeyError: if the request has not completed yet (call ``wait`` first).
        """
        return self._results[str(handle)]

    # ── internal ────────────────────────────────────────────────────

    def _wire(
        self,
        out_port: simpy.Store,
        in_port: simpy.Store,
        prop_ns: float,
    ):
        """SimPy process: relay messages with propagation delay only.

        Cut-through (wormhole) model: serialization (drain) is computed per-path
        and applied once at the terminal component, not at every wire hop.

        ``out_port`` and ``in_port`` must be different Stores; see the port
        wiring comment in ``__init__``.
        """
        while True:
            msg = yield out_port.get()
            if prop_ns > 0:
                yield self._env.timeout(prop_ns)
            yield in_port.put(msg)

    def _process(self, key: str, request: Any, done: simpy.Event):
        """SimPy process driving one request from injection to completion.

        Records ``(Completion, trace)`` under *key* and succeeds *done*.
        The trace always holds ``total_ns`` and ``nbytes``; any
        ``result_data`` collected on the transaction is merged in.
        """
        if isinstance(request, PeDmaMsg):
            yield from self._process_pe_dma(key, request, done)
            return

        entries = self._entry_points(request)
        if not entries:
            # Nothing to inject (e.g. a launch without tensor args): finish at once.
            self._results[key] = (
                Completion(ok=True),
                {"total_ns": 0.0, "nbytes": 0},
            )
            done.succeed()
            return

        start_ns = self._env.now
        total_nbytes = 0

        root_txn: Transaction | None = None
        if len(entries) == 1:
            # Single-SIP: direct inject (common path, no extra events)
            pcie_ep_id, io_cpu_id, nbytes = entries[0]
            total_nbytes = nbytes
            path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
            txn_done = self._env.event()
            txn = Transaction(request=request, path=path, step=0, nbytes=nbytes, done=txn_done)
            root_txn = txn
            yield self._host_queues[pcie_ep_id].put(txn)
            yield txn_done
        else:
            # Multi-SIP: inject per SIP, aggregate completions (ADR-0007)
            sub_dones: list[simpy.Event] = []
            sub_txns: list[Transaction] = []
            for pcie_ep_id, io_cpu_id, nbytes in entries:
                total_nbytes = max(total_nbytes, nbytes)
                path = self._router.find_node_path(pcie_ep_id, io_cpu_id)
                txn_done = self._env.event()
                txn = Transaction(
                    request=request, path=path, step=0,
                    nbytes=nbytes, done=txn_done,
                )
                yield self._host_queues[pcie_ep_id].put(txn)
                sub_dones.append(txn_done)
                sub_txns.append(txn)
            for sd in sub_dones:
                yield sd
            # Aggregate pe_exec_ns from multi-SIP (max)
            pe_vals = [st.result_data.get("pe_exec_ns") for st in sub_txns]
            pe_vals = [v for v in pe_vals if v is not None]
            if pe_vals:
                # The first sub-transaction carries the aggregate; other
                # sub-transactions' result_data is not merged into the trace.
                root_txn = sub_txns[0]
                root_txn.result_data["pe_exec_ns"] = max(pe_vals)

        total_ns = self._env.now - start_ns
        result_trace: dict[str, Any] = {"total_ns": total_ns, "nbytes": total_nbytes}
        if root_txn is not None and root_txn.result_data:
            result_trace.update(root_txn.result_data)
        self._results[key] = (
            Completion(ok=True),
            result_trace,
        )
        done.succeed()

    def _process_pe_dma(self, key: str, request: PeDmaMsg, done: simpy.Event):
        """Inject a Transaction directly at PE_DMA for PE→HBM latency measurement."""
        pe_prefix = f"sip{request.src_sip}.cube{request.src_cube}.pe{request.src_pe}"
        pe_dma_id = f"{pe_prefix}.pe_dma"
        pa = PhysAddr.decode(request.dst_pa)
        dst_node = self._resolver.resolve(pa)
        path = self._router.find_path(pe_prefix, dst_node)
        drain_ns = self._path_drain_ns(path, request.nbytes)

        start_ns = self._env.now
        txn_done = self._env.event()
        txn = Transaction(request=request, path=path, step=0, nbytes=request.nbytes,
                          done=txn_done, drain_ns=drain_ns)
        yield self._pe_dma_queues[pe_dma_id].put(txn)
        yield txn_done
        total_ns = self._env.now - start_ns
        formula_ns = self._formula_latency(path, request.nbytes)
        self._results[key] = (
            Completion(ok=True),
            {"total_ns": total_ns, "formula_ns": formula_ns, "nbytes": request.nbytes},
        )
        done.succeed()

    def _path_drain_ns(self, path: list[str], nbytes: int) -> float:
        """Wormhole drain time: nbytes / bottleneck_bw along path.

        Edges that are missing from the edge map or have a falsy ``bw_gbs``
        are ignored; if no edge contributes a bandwidth the result is 0.0.
        """
        min_bw = float("inf")
        for i in range(len(path) - 1):
            edge = self._edge_map.get((path[i], path[i + 1]))
            if edge and edge.bw_gbs:
                min_bw = min(min_bw, edge.bw_gbs)
        if min_bw == float("inf"):
            return 0.0
        return nbytes / min_bw

    def _formula_latency(self, path: list[str], nbytes: int) -> float:
        """Lower-bound formula latency (ADR-0015 D7).

        formula = Σ(wire propagation) + Σ(component overhead_ns) + drain_ns

        Phase 0: formula == actual  (no contention).
        Phase 1+: formula <= actual (contention adds queueing).
        """
        total = 0.0
        # Wire propagation delays
        for i in range(len(path) - 1):
            edge = self._edge_map.get((path[i], path[i + 1]))
            if edge:
                total += edge.distance_mm * self._ns_per_mm
        # Component overhead_ns
        for node_id in path:
            node = self._nodes.get(node_id)
            if node:
                total += float(node.attrs.get("overhead_ns", 0.0))
        # Drain
        total += self._path_drain_ns(path, nbytes)
        return total

    def _entry_points(self, request: Any) -> list[tuple[str, str, int]]:
        """Return list of (pcie_ep_id, io_cpu_id, nbytes) per target SIP.

        For Memory{Write,Read}: single SIP entry.
        For KernelLaunchMsg: one entry per distinct SIP in tensor shards; its
        nbytes is that of the first shard seen for the SIP.

        Raises:
            ValueError: for any other request type.
        """
        if isinstance(request, MemoryWriteMsg):
            sip = request.dst_sip
            return [(
                self._resolver.find_pcie_ep(sip),
                self._resolver.find_io_cpu(sip),
                request.nbytes,
            )]

        if isinstance(request, MemoryReadMsg):
            sip = request.src_sip
            return [(
                self._resolver.find_pcie_ep(sip),
                self._resolver.find_io_cpu(sip),
                request.nbytes,
            )]

        if isinstance(request, KernelLaunchMsg):
            seen: set[int] = set()
            entries: list[tuple[str, str, int]] = []
            for arg in request.args:
                if arg.arg_kind != "tensor":
                    continue
                for shard in arg.shards:
                    if shard.sip not in seen:
                        seen.add(shard.sip)
                        entries.append((
                            self._resolver.find_pcie_ep(shard.sip),
                            self._resolver.find_io_cpu(shard.sip),
                            shard.nbytes,
                        ))
            return entries

        raise ValueError(f"unsupported request type: {type(request)}")
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import simpy
|
||||
|
||||
|
||||
@dataclass
class Transaction:
    """In-flight request traversing the device fabric hop-by-hop (ADR-0015 D4).

    A Transaction carries a host request through one leg of the device fabric.
    Each component on the path reads from its in_port, processes (overhead_ns or
    other latency), and advances the Transaction to the next hop via out_port.
    Wire processes (ADR-0015 D2) model propagation delay between hops.

    Multi-leg flows (e.g. IO_CPU → M_CPU as leg 1, M_CPU.DMA → HBM as leg 2)
    use separate Transactions: the terminal component of leg 1 creates leg 2
    and waits for leg 2's done before succeeding leg 1's done.
    """

    request: Any                # original host request (MemoryReadMsg, KernelLaunchMsg, …)
    path: list[str]             # node_id sequence for this leg
    step: int                   # index of the component currently holding this Transaction
    nbytes: int                 # payload size (bytes)
    done: simpy.Event           # succeeded when this leg completes
    drain_ns: float = 0.0       # wormhole drain time: nbytes / bottleneck_bw (applied once at terminal)
    is_response: bool = False   # True when carrying ResponseMsg on reverse path
    result_data: dict[str, Any] = field(default_factory=dict)  # PE-level metrics (pe_exec_ns, etc.)

    @property
    def next_hop(self) -> str | None:
        """Node id of the next component, or None if this is the terminal hop."""
        following = self.step + 1
        if following >= len(self.path):
            return None
        return self.path[following]

    def advance(self) -> Transaction:
        """Return a copy of this Transaction advanced one step along the path.

        The copy shares ``path``, ``done`` and ``result_data`` with the
        original, so data written on any hop is visible through all copies.
        """
        following = self.step + 1
        return Transaction(
            request=self.request,
            path=self.path,
            step=following,
            nbytes=self.nbytes,
            done=self.done,
            drain_ns=self.drain_ns,
            is_response=self.is_response,
            result_data=self.result_data,
        )
|
||||
@@ -0,0 +1,965 @@
|
||||
# kernbench/topology/builder.py
|
||||
"""
|
||||
Topology compiler: parses topology.yaml and produces a fully-instantiated
|
||||
TopologyGraph with nodes, edges, and representative view projections.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from .types import Edge, Node, TopologyGraph, TopologyHandle, ViewGraph
|
||||
|
||||
|
||||
# PE component offsets from PE center (small, intra-PE distances ~0.5mm)
|
||||
_PE_COMP_OFFSETS = {
|
||||
"pe_cpu": (-0.3, 0.0),
|
||||
"pe_scheduler": (-0.15, 0.0),
|
||||
"pe_dma": (0.0, -0.15),
|
||||
"pe_gemm": (0.0, 0.0),
|
||||
"pe_math": (0.0, 0.15),
|
||||
"pe_tcm": (0.3, 0.0),
|
||||
}
|
||||
|
||||
|
||||
# ── Public API ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def resolve_topology(path_str: str) -> TopologyHandle:
    """Validate path and build compiled topology graph.

    The path is user-expanded and resolved to an absolute path first.

    Raises:
        FileNotFoundError: if the resolved path does not exist.
        ValueError: if the resolved path exists but is not a regular file.
    """
    topo_path = Path(path_str).expanduser().resolve()
    if not topo_path.exists():
        raise FileNotFoundError(f"Topology file not found: {topo_path}")
    if not topo_path.is_file():
        raise ValueError(f"Topology path is not a file: {topo_path}")
    compiled = load_topology(topo_path)
    return TopologyHandle(path=topo_path, topology_obj=compiled)
|
||||
|
||||
|
||||
def load_topology(path: Path) -> TopologyGraph:
    """Load topology spec from file and compile into a topology graph.

    Runs three stages in order: read the YAML, validate it, compile the graph.
    """
    raw_spec = _read_spec(path)
    _validate_spec(raw_spec)
    graph = _compile_graph(raw_spec)
    return graph
|
||||
|
||||
|
||||
def _read_spec(path: Path) -> dict[str, Any]:
    """Read YAML topology spec file and return a dict.

    Raises:
        ValueError: if the YAML cannot be parsed (with line/column when the
            parser reports them), if the document is empty, or if its root is
            not a mapping.
    """
    try:
        with path.open("r", encoding="utf-8") as stream:
            loaded = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        # Append the 1-based error position when the parser provides a mark.
        location = ""
        mark = getattr(exc, "problem_mark", None)
        if mark is not None:
            location = f" (line {mark.line + 1}, column {mark.column + 1})"
        raise ValueError(f"Failed to parse YAML topology: {path}{location}") from exc

    if loaded is None:
        raise ValueError(f"Topology YAML is empty: {path}")
    if isinstance(loaded, dict):
        return loaded
    raise ValueError(
        f"Topology YAML root must be a mapping/dict: {path} (got {type(loaded).__name__})"
    )
|
||||
|
||||
|
||||
def _validate_spec(spec: dict) -> None:
|
||||
# TODO: schema validation
|
||||
return
|
||||
|
||||
|
||||
# ── Graph Compiler ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _compile_graph(spec: dict) -> TopologyGraph:
    """Build fully-instantiated flat graph + representative view projections.

    Adds the system-level nodes first, then for each SIP its IO chiplets,
    its cubes in row-major order, and the inter-cube, IO→cube and
    switch→IO edges.  The four view projections are built from *spec*.
    """
    nodes: dict[str, Node] = {}
    edges: list[Edge] = []

    system = spec["system"]
    sip_spec = spec["sip"]
    cube_spec = spec["cube"]

    mesh = sip_spec["cube_mesh"]
    mesh_w = mesh["w"]
    mesh_h = mesh["h"]
    cube_mm = cube_spec["geometry"]["cube_mm"]
    cube_w = cube_mm["w"]
    cube_h = cube_mm["h"]
    seam = sip_spec["links"]["inter_cube_mesh"]["distance_mm_across_seam"]
    # Distance between the origins of neighbouring cubes along each axis.
    pitch_x = cube_w + seam
    pitch_y = cube_h + seam

    # System-level
    _instantiate_system(nodes, system)

    # Per-SIP
    sip_count = system["sips"]["count"]
    for sip_id in range(sip_count):
        sip_prefix = f"sip{sip_id}"

        # IO chiplets
        _instantiate_io_chiplets(
            nodes, edges, sip_prefix, sip_spec,
            cube_w, cube_h, mesh_w, mesh_h, seam,
        )

        # Cubes + PEs: cube ids are assigned row by row.
        for row in range(mesh_h):
            for col in range(mesh_w):
                cube_id = row * mesh_w + col
                cube_prefix = f"{sip_prefix}.cube{cube_id}"
                cube_origin = (col * pitch_x, row * pitch_y)
                _instantiate_cube(nodes, edges, cube_prefix, cube_spec, cube_origin)

        # Inter-cube UCIe mesh
        _add_inter_cube_edges(edges, sip_prefix, mesh_w, mesh_h, sip_spec)

        # IO → cube UCIe
        _add_io_to_cube_edges(edges, sip_prefix, sip_spec, mesh_w)

        # Switch → IO pcie_ep
        _add_system_to_io_edges(edges, sip_prefix, sip_spec, system)

    # Build views
    return TopologyGraph(
        spec=spec,
        nodes=nodes,
        edges=edges,
        system_view=_build_system_view(spec),
        sip_view=_build_sip_view(spec),
        cube_view=_build_cube_view(spec),
        pe_view=_build_pe_view(spec),
    )
|
||||
|
||||
|
||||
# ── Layout helpers ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _cube_local_positions(cube_w: float, cube_h: float) -> dict[str, tuple[float, float]]:
|
||||
"""Cube-internal component positions relative to cube origin (0,0) at top-left."""
|
||||
cx, cy = cube_w / 2, cube_h / 2
|
||||
# UCIe node half-sizes (default 2.0×1.2mm) — inset so edges touch boundary
|
||||
uh = 0.6 # half height
|
||||
uw = 1.0 # half width
|
||||
return {
|
||||
"ucie-N": (cx, uh),
|
||||
"ucie-S": (cx, cube_h - uh),
|
||||
"ucie-W": (uw, cy),
|
||||
"ucie-E": (cube_w - uw, cy),
|
||||
"m_cpu": (cube_w - 2.5, cy - 1.5),
|
||||
"xbar.top": (cx, 3.5), # Y reference for top-half xbar.pe nodes
|
||||
"hbm_ctrl": (cx - 2.0, cy),
|
||||
"xbar.bottom": (cx, cube_h - 3.5), # Y reference for bottom-half xbar.pe nodes
|
||||
"bridge.left": (2.5, cy + 2.0),
|
||||
"bridge.right": (cube_w - 2.5, cy + 2.0),
|
||||
"noc": (cx + 2.0, cy),
|
||||
"sram": (2.5, cy - 1.5),
|
||||
}
|
||||
|
||||
|
||||
def _corner_pe_positions(cube_w: float, cube_h: float) -> dict[str, list[tuple[float, float]]]:
|
||||
"""PE center positions per corner, relative to cube origin."""
|
||||
return {
|
||||
"NW": [(1.5, 1.5), (4.5, 1.5)],
|
||||
"NE": [(cube_w - 4.5, 1.5), (cube_w - 1.5, 1.5)],
|
||||
"SW": [(1.5, cube_h - 1.5), (4.5, cube_h - 1.5)],
|
||||
"SE": [(cube_w - 4.5, cube_h - 1.5), (cube_w - 1.5, cube_h - 1.5)],
|
||||
}
|
||||
|
||||
|
||||
# ── Instantiation: system ───────────────────────────────────────────
|
||||
|
||||
|
||||
def _instantiate_system(nodes: dict[str, Node], system: dict) -> None:
    """Add system-level nodes (fabric switch).

    Registers one node with id ``fabric.switch0`` in *nodes*, taken from
    ``system["components"]["switch"]``; ``attrs`` defaults to an empty dict.
    """
    switch_spec = system["components"]["switch"]
    switch_id = "fabric.switch0"
    switch_node = Node(
        id=switch_id,
        kind=switch_spec["kind"],
        impl=switch_spec["impl"],
        attrs=switch_spec.get("attrs", {}),
        pos_mm=None,
        label="Switch",
    )
    nodes[switch_id] = switch_node
|
||||
|
||||
|
||||
# ── Instantiation: IO chiplets ──────────────────────────────────────
|
||||
|
||||
|
||||
def _instantiate_io_chiplets(
    nodes: dict[str, Node],
    edges: list[Edge],
    sp: str,
    sip_spec: dict,
    cube_w: float,
    cube_h: float,
    mesh_w: int,
    mesh_h: int,
    seam: float,
) -> None:
    """Create the ``pcie_ep`` and ``io_cpu`` nodes of every IO chiplet instance.

    For each entry of ``sip_spec["iochiplet"]["instances"]`` two nodes named
    ``f"{sp}.{id}.pcie_ep"`` and ``f"{sp}.{id}.io_cpu"`` are added to ``nodes``
    and one ``pcie_ep -> io_cpu`` edge of kind ``"io_internal"`` is appended
    to ``edges``.  Both nodes sit on the horizontal centre of the cube mesh,
    above it for side ``"N"`` and below it for any other side value.
    """
    io_spec = sip_spec["iochiplet"]
    comp_specs = io_spec["components"]
    link_spec = io_spec["links"]

    # Overall extent of the cube mesh, seams between cubes included.
    span_w = mesh_w * cube_w + (mesh_w - 1) * seam
    span_h = mesh_h * cube_h + (mesh_h - 1) * seam
    center_x = span_w / 2

    # (pcie_ep y, io_cpu y) for chiplets north / south of the mesh.
    north_rows = (-5.0, -3.0)
    south_rows = (span_h + 5.0, span_h + 3.0)

    for inst in io_spec["instances"]:
        prefix = f"{sp}.{inst['id']}"
        pcie_y, cpu_y = north_rows if inst["place"]["side"] == "N" else south_rows

        created: list[str] = []
        for comp_name, y_mm, label in (
            ("pcie_ep", pcie_y, "PCIe EP"),
            ("io_cpu", cpu_y, "IO CPU"),
        ):
            comp_spec = comp_specs[comp_name]
            node_id = f"{prefix}.{comp_name}"
            nodes[node_id] = Node(
                id=node_id,
                kind=comp_spec["kind"],
                impl=comp_spec["impl"],
                attrs=comp_spec["attrs"],
                pos_mm=(center_x, y_mm),
                label=label,
            )
            created.append(node_id)

        # Chiplet-internal link from the PCIe endpoint to the IO CPU.
        ep_id, cpu_id = created
        edges.append(Edge(
            src=ep_id,
            dst=cpu_id,
            distance_mm=link_spec["pcie_ep_to_io_cpu_mm"],
            bw_gbs=link_spec["pcie_ep_to_io_cpu_bw_gbs"],
            kind="io_internal",
        ))
|
||||
|
||||
|
||||
# ── Instantiation: cube + PEs ───────────────────────────────────────
|
||||
|
||||
|
||||
def _instantiate_cube(
    nodes: dict[str, Node],
    edges: list[Edge],
    cp: str,
    cube: dict,
    origin: tuple[float, float],
) -> None:
    """Add all cube-internal nodes and edges, including PE instances.

    Mutates ``nodes`` and ``edges`` in place for a single cube.  Nothing is
    returned.

    Args:
        nodes: Flat node map keyed by node id; new entries are inserted.
        edges: Flat edge list; new edges are appended.
        cp: Id prefix of this cube; every node id created here starts with
            ``f"{cp}."``.
        cube: Cube spec dict.  Keys read here: ``geometry``, ``links``,
            ``memory_map``, ``ucie``, ``components``, ``pe_layout`` and
            ``pe_template``.
        origin: ``(x, y)`` offset added to every cube-local position.

    Raises:
        KeyError: If a spec key read below is missing.
    """
    cube_w = cube["geometry"]["cube_mm"]["w"]
    cube_h = cube["geometry"]["cube_mm"]["h"]
    ox, oy = origin
    local_pos = _cube_local_positions(cube_w, cube_h)
    clinks = cube["links"]
    n_slices = cube["memory_map"]["hbm_slices_per_cube"]

    # ── UCIe ports ──
    ucie_ns = cube["ucie"]["overhead_ns"]
    for port in cube["ucie"]["ports"]:
        pid = f"{cp}.ucie-{port}"
        lx, ly = local_pos[f"ucie-{port}"]
        nodes[pid] = Node(
            id=pid, kind="ucie_port", impl="ucie_v1",
            attrs={"overhead_ns": ucie_ns}, pos_mm=(ox + lx, oy + ly),
            label=f"UCIe-{port}",
        )

    # ── Named components: noc, m_cpu, sram ──
    for name in ("noc", "m_cpu", "sram"):
        c = cube["components"][name]
        nid = f"{cp}.{name}"
        lx, ly = local_pos[name]
        nodes[nid] = Node(
            id=nid, kind=c["kind"], impl=c["impl"],
            attrs=c["attrs"], pos_mm=(ox + lx, oy + ly),
            label=name.upper().replace("_", " "),
        )

    # ── HBM controller slices (one per PE) ──
    # Every slice is placed at the same hbm_ctrl position and receives the
    # same attrs dict object taken from the spec.
    hbm_spec = cube["components"]["hbm_ctrl"]
    hbm_lx, hbm_ly = local_pos["hbm_ctrl"]
    for sl in range(n_slices):
        sid = f"{cp}.hbm_ctrl.slice{sl}"
        nodes[sid] = Node(
            id=sid, kind=hbm_spec["kind"], impl=hbm_spec["impl"],
            attrs=hbm_spec["attrs"], pos_mm=(ox + hbm_lx, oy + hbm_ly),
            label=f"HBM SLICE{sl}",
        )

    # ── Bridges ──
    # Each bridge id must have a matching "bridge.<id>" entry in local_pos
    # (only "left" and "right" are defined there).
    for br in cube["components"]["xbar"]["bridges"]:
        bname = br["id"]
        nid = f"{cp}.bridge.{bname}"
        lx, ly = local_pos[f"bridge.{bname}"]
        nodes[nid] = Node(
            id=nid, kind=br["kind"], impl=br["impl"],
            attrs=br["attrs"], pos_mm=(ox + lx, oy + ly),
            label=f"Bridge {bname.upper()}",
        )

    # ── PE instances + per-PE xbar entry nodes ──
    corners = cube["pe_layout"]["corners"]
    pe_per_corner = cube["pe_layout"]["pe_per_corner"]
    corner_pos = _corner_pe_positions(cube_w, cube_h)
    pe_tmpl = cube["pe_template"]
    pe_links = pe_tmpl["links"]

    xbar_pe_spec = cube["components"]["xbar"]["pe"]
    xbar_top_y = local_pos["xbar.top"][1]
    xbar_bot_y = local_pos["xbar.bottom"][1]

    # PEs are numbered consecutively in the order the corners are listed.
    pe_idx = 0
    for corner in corners:
        # North corners attach to the top xbar row, all others to the bottom row.
        is_top = corner in ("NW", "NE")
        xbar_y = xbar_top_y if is_top else xbar_bot_y
        mm_key = "pe_to_xbar_row_n_mm" if is_top else "pe_to_xbar_row_s_mm"
        for ci in range(pe_per_corner):
            pp = f"{cp}.pe{pe_idx}"
            # corner_pos holds two centres per corner, so ci must be 0 or 1.
            pe_cx, pe_cy = corner_pos[corner][ci]

            # Per-PE xbar entry node (same X as the PE, Y of its xbar row)
            xbar_nid = f"{cp}.xbar.pe{pe_idx}"
            nodes[xbar_nid] = Node(
                id=xbar_nid, kind=xbar_pe_spec["kind"], impl=xbar_pe_spec["impl"],
                attrs=xbar_pe_spec["attrs"], pos_mm=(ox + pe_cx, oy + xbar_y),
                label=f"XBAR PE{pe_idx}",
            )

            # PE template components; components without an entry in
            # _PE_COMP_OFFSETS are placed on the PE centre.
            for comp_name, comp_spec in pe_tmpl["components"].items():
                cid = f"{pp}.{comp_name}"
                dx, dy = _PE_COMP_OFFSETS.get(comp_name, (0.0, 0.0))
                nodes[cid] = Node(
                    id=cid, kind=comp_spec["kind"], impl=comp_spec["impl"],
                    attrs=comp_spec["attrs"],
                    pos_mm=(ox + pe_cx + dx, oy + pe_cy + dy),
                    label=comp_name.upper().replace("_", " "),
                )

            # PE-internal edges
            _add_pe_internal_edges(edges, pp, pe_links)

            # PE_DMA → xbar.pe_i (HBM data path)
            edges.append(Edge(
                src=f"{pp}.pe_dma", dst=xbar_nid,
                distance_mm=clinks[mm_key],
                bw_gbs=clinks["pe_to_xbar_bw_gbs"],
                kind="pe_to_xbar",
            ))

            # PE_DMA → noc (non-HBM data path: SRAM, inter-cube, etc.)
            edges.append(Edge(
                src=f"{pp}.pe_dma", dst=f"{cp}.noc",
                distance_mm=clinks["pe_dma_to_noc_mm"],
                bw_gbs=clinks["pe_dma_to_noc_bw_gbs"],
                kind="pe_to_noc",
            ))

            # noc → PE_CPU (command delivery)
            edges.append(Edge(
                src=f"{cp}.noc", dst=f"{pp}.pe_cpu",
                distance_mm=clinks["noc_to_pe_cpu_mm"],
                kind="command",
            ))

            pe_idx += 1

    # ── Cube fabric edges ──
    # NOTE(review): the loops below address xbar.pe{i} for i < n_slices, while
    # those nodes are created once per PE above — this presumably relies on
    # len(corners) * pe_per_corner == n_slices; confirm against the spec.

    # xbar.pe_i ↔ hbm_ctrl.slice_i (local Y-path, bidirectional for response)
    for i in range(n_slices):
        edges.append(Edge(
            src=f"{cp}.xbar.pe{i}", dst=f"{cp}.hbm_ctrl.slice{i}",
            distance_mm=clinks["xbar_to_hbm_mm"],
            bw_gbs=clinks["xbar_to_hbm_bw_gbs"],
            kind="xbar_to_hbm",
        ))
        edges.append(Edge(
            src=f"{cp}.hbm_ctrl.slice{i}", dst=f"{cp}.xbar.pe{i}",
            distance_mm=clinks["xbar_to_hbm_mm"],
            bw_gbs=clinks["xbar_to_hbm_bw_gbs"],
            kind="hbm_to_xbar",
        ))

    # xbar chain: pe0↔pe1↔pe2↔pe3 (top), pe4↔pe5↔pe6↔pe7 (bottom)
    half = n_slices // 2
    for half_start in (0, half):
        for i in range(half_start, half_start + half - 1):
            # A hop is intra-corner unless i is the last PE of its corner.
            intra = ((i - half_start) % pe_per_corner) != (pe_per_corner - 1)
            x_dist = clinks["xbar_chain_intra_corner_mm"] if intra else clinks["xbar_chain_inter_corner_mm"]
            # One edge per direction between neighbouring xbar nodes.
            for a, b in [(i, i + 1), (i + 1, i)]:
                edges.append(Edge(
                    src=f"{cp}.xbar.pe{a}", dst=f"{cp}.xbar.pe{b}",
                    distance_mm=x_dist,
                    bw_gbs=clinks["xbar_x_bw_gbs"],
                    kind="xbar_chain",
                ))

    # bridge connections: pe0↔bridge.left↔pe4, pe3↔bridge.right↔pe7
    # (first / last xbar node of each half, one edge per direction)
    for bname, pe_top, pe_bot in [("left", 0, half), ("right", half - 1, n_slices - 1)]:
        br_node = f"{cp}.bridge.{bname}"
        for pe_i, br_mm_key in [(pe_top, "xbar_row_n_to_bridge_mm"),
                                (pe_bot, "xbar_row_s_to_bridge_mm")]:
            xbar_node = f"{cp}.xbar.pe{pe_i}"
            edges.append(Edge(
                src=xbar_node, dst=br_node,
                distance_mm=clinks[br_mm_key],
                bw_gbs=clinks["xbar_to_bridge_bw_gbs"],
                kind="xbar_to_bridge",
            ))
            edges.append(Edge(
                src=br_node, dst=xbar_node,
                distance_mm=clinks[br_mm_key],
                bw_gbs=clinks["xbar_to_bridge_bw_gbs"],
                kind="bridge_to_xbar",
            ))

    # ucie ↔ noc (UCIe-NOC boundary; per_connection_bw_gbs = 128 GB/s, n_connections = 4)
    # All ucie→noc edges are appended first, then all noc→ucie edges.
    _noc_ucie = clinks["noc_to_ucie"]
    for port in cube["ucie"]["ports"]:
        edges.append(Edge(
            src=f"{cp}.ucie-{port}", dst=f"{cp}.noc",
            distance_mm=0.0,
            bw_gbs=_noc_ucie["per_connection_bw_gbs"],
            n_connections=_noc_ucie["n_connections"],
            kind="ucie_to_noc",
        ))

    for port in cube["ucie"]["ports"]:
        edges.append(Edge(
            src=f"{cp}.noc", dst=f"{cp}.ucie-{port}",
            distance_mm=0.0,
            bw_gbs=_noc_ucie["per_connection_bw_gbs"],
            n_connections=_noc_ucie["n_connections"],
            kind="noc_to_ucie",
        ))

    # noc ↔ xbar.pe{i}: wire delay is 0 (NOC traversal latency computed by TwoDMeshNocComponent);
    # routing_weight_mm=50.0 steers PE DMA Dijkstra away from this path (prefer direct pe_dma→xbar)
    # "noc_to_xbar" is optional in the spec; when absent the edges get bw_gbs=None.
    _noc_xbar = clinks.get("noc_to_xbar", {})
    _noc_xbar_bw = _noc_xbar.get("per_connection_bw_gbs")
    for i in range(n_slices):
        edges.append(Edge(
            src=f"{cp}.noc", dst=f"{cp}.xbar.pe{i}",
            distance_mm=0.0,
            bw_gbs=_noc_xbar_bw,
            routing_weight_mm=50.0,
            kind="noc_to_xbar",
        ))
        edges.append(Edge(
            src=f"{cp}.xbar.pe{i}", dst=f"{cp}.noc",
            distance_mm=0.0,
            bw_gbs=_noc_xbar_bw,
            routing_weight_mm=50.0,
            kind="xbar_to_noc",
        ))

    # m_cpu ↔ noc (command dispatch, both directions)
    edges.append(Edge(
        src=f"{cp}.m_cpu", dst=f"{cp}.noc",
        distance_mm=clinks["m_cpu_to_noc_mm"],
        kind="command",
    ))
    edges.append(Edge(
        src=f"{cp}.noc", dst=f"{cp}.m_cpu",
        distance_mm=clinks["m_cpu_to_noc_mm"],
        kind="command",
    ))

    # noc ↔ sram (shared SRAM access; per_connection_bw_gbs = 128 GB/s, n_connections = 4)
    # Both directions carry kind "noc_to_sram".
    _noc_sram = clinks["noc_to_sram"]
    edges.append(Edge(
        src=f"{cp}.noc", dst=f"{cp}.sram",
        distance_mm=clinks["noc_to_sram_mm"],
        bw_gbs=_noc_sram["per_connection_bw_gbs"],
        n_connections=_noc_sram["n_connections"],
        kind="noc_to_sram",
    ))
    edges.append(Edge(
        src=f"{cp}.sram", dst=f"{cp}.noc",
        distance_mm=clinks["noc_to_sram_mm"],
        bw_gbs=_noc_sram["per_connection_bw_gbs"],
        n_connections=_noc_sram["n_connections"],
        kind="noc_to_sram",
    ))
|
||||
|
||||
|
||||
def _add_pe_internal_edges(edges: list[Edge], pp: str, pe_links: dict) -> None:
    """Append the internal wiring of the PE whose id prefix is ``pp``.

    Seven edges of kind ``"pe_internal"`` are appended in this order:
    ``pe_cpu -> pe_scheduler``, the scheduler to each engine (dma, gemm,
    math), then each engine to ``pe_tcm``.  Only the engine-to-TCM edges
    carry a bandwidth.
    """
    scheduler_id = f"{pp}.pe_scheduler"
    tcm_id = f"{pp}.pe_tcm"
    engines = ("pe_dma", "pe_gemm", "pe_math")

    edges.append(Edge(
        src=f"{pp}.pe_cpu",
        dst=scheduler_id,
        distance_mm=pe_links["pe_cpu_to_scheduler_mm"],
        kind="pe_internal",
    ))

    # Scheduler fan-out to the three engines (distance only).
    dispatch_keys = ("scheduler_to_dma_mm", "scheduler_to_gemm_mm", "scheduler_to_math_mm")
    for engine, dist_key in zip(engines, dispatch_keys):
        edges.append(Edge(
            src=scheduler_id,
            dst=f"{pp}.{engine}",
            distance_mm=pe_links[dist_key],
            kind="pe_internal",
        ))

    # Engine data paths into the TCM (distance and bandwidth).
    tcm_keys = (
        ("dma_to_tcm_mm", "dma_to_tcm_bw_gbs"),
        ("gemm_to_tcm_mm", "gemm_to_tcm_bw_gbs"),
        ("math_to_tcm_mm", "math_to_tcm_bw_gbs"),
    )
    for engine, (dist_key, bw_key) in zip(engines, tcm_keys):
        edges.append(Edge(
            src=f"{pp}.{engine}",
            dst=tcm_id,
            distance_mm=pe_links[dist_key],
            bw_gbs=pe_links[bw_key],
            kind="pe_internal",
        ))
|
||||
|
||||
|
||||
# ── Inter-cube / IO / system edges ──────────────────────────────────
|
||||
|
||||
|
||||
def _add_inter_cube_edges(
    edges: list[Edge], sp: str, mesh_w: int, mesh_h: int, sip_spec: dict,
) -> None:
    """Connect neighbouring cubes of one SIP through their UCIe ports.

    Cubes are numbered row-major (``row * mesh_w + col``).  For every cube,
    the link to its east neighbour (``ucie-E`` / ``ucie-W``) is appended
    first and the link to its south neighbour (``ucie-S`` / ``ucie-N``)
    second; each link contributes one edge per direction, all of kind
    ``"ucie_mesh"``.
    """
    mesh = sip_spec["links"]["inter_cube_mesh"]
    phy_bw = mesh["bw_gbs_per_ucie_phy"]
    seam_mm = mesh["distance_mm_across_seam"]

    def _link_both_ways(port_a: str, port_b: str) -> None:
        # Forward edge first, then the reverse edge.
        for src, dst in ((port_a, port_b), (port_b, port_a)):
            edges.append(Edge(
                src=src, dst=dst,
                distance_mm=seam_mm, bw_gbs=phy_bw, kind="ucie_mesh",
            ))

    for row in range(mesh_h):
        has_south = row + 1 < mesh_h
        for col in range(mesh_w):
            here = row * mesh_w + col
            if col + 1 < mesh_w:
                east = here + 1
                _link_both_ways(f"{sp}.cube{here}.ucie-E", f"{sp}.cube{east}.ucie-W")
            if has_south:
                south = here + mesh_w
                _link_both_ways(f"{sp}.cube{here}.ucie-S", f"{sp}.cube{south}.ucie-N")
|
||||
|
||||
|
||||
def _add_io_to_cube_edges(
    edges: list[Edge], sp: str, sip_spec: dict, mesh_w: int,
) -> None:
    """Wire every IO chiplet's ``io_cpu`` to the cube UCIe ports it serves.

    Each entry of an instance's ``cube_ports`` yields an ``"io_to_cube"``
    edge followed by the reverse ``"cube_to_io"`` edge.  The distance of both
    is the chiplet-wide ``io_cpu_to_ucie_mm`` plus the port's own
    ``distance_mm``.
    """
    chiplets = sip_spec["iochiplet"]
    base_mm = chiplets["links"]["io_cpu_to_ucie_mm"]
    link_bw = chiplets["links"]["io_cpu_to_ucie_bw_gbs"]

    for inst in sip_spec["iochiplet"]["instances"]:
        io_cpu_id = f"{sp}.{inst['id']}.io_cpu"
        for port in inst["cube_ports"]:
            col, row = port["cube"]["xy"]
            cube_index = row * mesh_w + col  # row-major cube numbering
            ucie_id = f"{sp}.cube{cube_index}.ucie-{port['cube_side']}"
            total_mm = base_mm + port["distance_mm"]
            for src, dst, kind in (
                (io_cpu_id, ucie_id, "io_to_cube"),
                (ucie_id, io_cpu_id, "cube_to_io"),
            ):
                edges.append(Edge(
                    src=src, dst=dst,
                    distance_mm=total_mm,
                    bw_gbs=link_bw,
                    kind=kind,
                ))
|
||||
|
||||
|
||||
def _add_system_to_io_edges(
    edges: list[Edge], sp: str, sip_spec: dict, system: dict,
) -> None:
    """Append one ``"pcie"`` edge from the fabric switch to each IO chiplet.

    The edge runs from ``"fabric.switch0"`` to ``f"{sp}.{id}.pcie_ep"`` and
    takes distance and bandwidth from ``system["links"]["io_ep_to_switch"]``.
    Only the switch-to-endpoint direction is created.
    """
    switch_id = "fabric.switch0"
    uplink = system["links"]["io_ep_to_switch"]
    edges.extend(
        Edge(
            src=switch_id,
            dst=f"{sp}.{inst['id']}.pcie_ep",
            distance_mm=uplink["distance_mm"],
            bw_gbs=uplink["bw_gbs_per_ep"],
            kind="pcie",
        )
        for inst in sip_spec["iochiplet"]["instances"]
    )
|
||||
|
||||
|
||||
# ── View builders ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _build_system_view(spec: dict) -> ViewGraph:
    """Build the system-level view: fabric switch, SIP blocks and IO chiplets.

    SIPs are drawn as fixed-size blocks laid out left to right below the
    switch.  Every SIP shows the IO chiplet instances declared in
    ``spec["sip"]["iochiplet"]["instances"]``, each linked to the switch by a
    ``"pcie"`` edge.
    """
    system = spec["system"]
    sip_count = system["sips"]["count"]

    # Fixed drawing geometry of this view.
    sip_w, sip_h = 71.0, 59.0
    gap = 30.0
    sip_top = 20.0
    canvas_w = sip_count * sip_w + (sip_count - 1) * gap
    canvas_h = sip_h + 20.0

    nodes: dict[str, Node] = {}
    view_edges: list[Edge] = []

    switch_spec = system["components"]["switch"]
    switch_id = "fabric.switch0"
    nodes[switch_id] = Node(
        id=switch_id,
        kind=switch_spec["kind"],
        impl=switch_spec["impl"],
        attrs=switch_spec.get("attrs", {}),
        pos_mm=(canvas_w / 2, 5.0),
        label="Fabric Switch",
    )

    sip_center_y = sip_top + sip_h / 2
    for index in range(sip_count):
        sip_left = index * (sip_w + gap)
        sip_center_x = sip_left + sip_w / 2
        sip_id = f"sip{index}"

        nodes[sip_id] = Node(
            id=sip_id,
            kind="sip",
            impl="",
            attrs={"w_mm": sip_w, "h_mm": sip_h},
            pos_mm=(sip_center_x, sip_center_y),
            label=f"SIP {index}",
        )

        for inst in spec["sip"]["iochiplet"]["instances"]:
            chiplet = inst["id"]
            chiplet_id = f"{sip_id}.{chiplet}"
            # "N" chiplets sit on the SIP's top edge, all others on its bottom edge.
            if inst["place"]["side"] == "N":
                chiplet_y = sip_top
            else:
                chiplet_y = sip_top + sip_h
            nodes[chiplet_id] = Node(
                id=chiplet_id,
                kind="iochiplet",
                impl="",
                attrs={},
                pos_mm=(sip_center_x, chiplet_y),
                label=f"IO {chiplet}",
            )
            uplink = system["links"]["io_ep_to_switch"]
            view_edges.append(Edge(
                src=switch_id,
                dst=chiplet_id,
                distance_mm=uplink["distance_mm"],
                bw_gbs=uplink["bw_gbs_per_ep"],
                kind="pcie",
            ))

    return ViewGraph(
        name="system",
        nodes=nodes,
        edges=view_edges,
        width_mm=canvas_w,
        height_mm=canvas_h,
    )
|
||||
|
||||
|
||||
def _build_sip_view(spec: dict) -> ViewGraph:
    """Build the SIP-level view: the cube mesh plus the IO chiplets.

    Cubes appear as opaque blocks named ``cube<N>`` (row-major numbering)
    with one ``"ucie_mesh"`` edge towards the east and south neighbour.  IO
    chiplets are placed in a margin above or below the mesh and linked to
    the cubes listed in their ``cube_ports`` by ``"io_to_cube"`` edges.
    """
    sip_spec = spec["sip"]
    cube_spec = spec["cube"]
    mesh_w = sip_spec["cube_mesh"]["w"]
    mesh_h = sip_spec["cube_mesh"]["h"]
    cube_w = cube_spec["geometry"]["cube_mm"]["w"]
    cube_h = cube_spec["geometry"]["cube_mm"]["h"]
    seam = sip_spec["links"]["inter_cube_mesh"]["distance_mm_across_seam"]

    # Cube pitch and overall mesh extent, seams included.
    stride_x = cube_w + seam
    stride_y = cube_h + seam
    mesh_total_w = mesh_w * cube_w + (mesh_w - 1) * seam
    mesh_total_h = mesh_h * cube_h + (mesh_h - 1) * seam

    io_margin = 6.0  # strip reserved above and below the mesh for IO chiplets
    canvas_w = mesh_total_w
    canvas_h = mesh_total_h + 2 * io_margin

    nodes: dict[str, Node] = {}
    view_edges: list[Edge] = []
    cells = [(row, col) for row in range(mesh_h) for col in range(mesh_w)]

    # Cubes as opaque blocks.
    for row, col in cells:
        cube_name = f"cube{row * mesh_w + col}"
        center = (
            col * stride_x + cube_w / 2,
            io_margin + row * stride_y + cube_h / 2,
        )
        nodes[cube_name] = Node(
            id=cube_name,
            kind="cube",
            impl="",
            attrs={"w_mm": cube_w, "h_mm": cube_h, "col": col, "row": row},
            pos_mm=center,
            label=f"CUBE ({col},{row})",
        )

    # Inter-cube mesh edges: east neighbour first, then south neighbour.
    mesh_link = sip_spec["links"]["inter_cube_mesh"]
    for row, col in cells:
        here = row * mesh_w + col
        neighbours = []
        if col + 1 < mesh_w:
            neighbours.append(row * mesh_w + (col + 1))
        if row + 1 < mesh_h:
            neighbours.append((row + 1) * mesh_w + col)
        for other in neighbours:
            view_edges.append(Edge(
                src=f"cube{here}",
                dst=f"cube{other}",
                distance_mm=mesh_link["distance_mm_across_seam"],
                bw_gbs=mesh_link["bw_gbs_per_ucie_phy"],
                kind="ucie_mesh",
            ))

    # IO chiplets.
    io_links = sip_spec["iochiplet"]["links"]
    for inst in sip_spec["iochiplet"]["instances"]:
        chiplet = inst["id"]
        if inst["place"]["side"] == "N":
            chiplet_y = 2.0
        else:
            chiplet_y = canvas_h - 2.0
        nodes[chiplet] = Node(
            id=chiplet,
            kind="iochiplet",
            impl="",
            attrs={},
            pos_mm=(mesh_total_w / 2, chiplet_y),
            label=f"IO {chiplet}",
        )
        for port in inst["cube_ports"]:
            port_col, port_row = port["cube"]["xy"]
            target = port_row * mesh_w + port_col
            view_edges.append(Edge(
                src=chiplet,
                dst=f"cube{target}",
                distance_mm=io_links["io_cpu_to_ucie_mm"] + port["distance_mm"],
                bw_gbs=io_links["io_cpu_to_ucie_bw_gbs"],
                kind="io_to_cube",
            ))

    return ViewGraph(
        name="sip",
        nodes=nodes,
        edges=view_edges,
        width_mm=canvas_w,
        height_mm=canvas_h,
    )
|
||||
|
||||
|
||||
def _build_cube_view(spec: dict) -> ViewGraph:
    """Cube-level view: representative single cube, PEs as opaque blocks.

    Node ids carry no cube prefix and positions are cube-local (no origin
    offset).  Compared with the flat graph built by ``_instantiate_cube``
    the view is simplified: each PE is one ``"pe"`` node, the HBM controller
    is a single ``hbm_ctrl`` node, and only a subset of the edge directions
    is emitted.

    Args:
        spec: Full topology spec; only ``spec["cube"]`` is read.

    Returns:
        A ``ViewGraph`` named ``"cube"`` whose canvas is the cube's width and
        height.

    Raises:
        KeyError: If a spec key read below is missing.
    """
    cube = spec["cube"]
    cube_w = cube["geometry"]["cube_mm"]["w"]
    cube_h = cube["geometry"]["cube_mm"]["h"]
    local_pos = _cube_local_positions(cube_w, cube_h)
    clinks = cube["links"]
    n_slices = cube["memory_map"]["hbm_slices_per_cube"]

    nodes: dict[str, Node] = {}
    view_edges: list[Edge] = []

    # UCIe ports (view nodes carry empty attrs)
    for port in cube["ucie"]["ports"]:
        pid = f"ucie-{port}"
        lx, ly = local_pos[pid]
        nodes[pid] = Node(
            id=pid, kind="ucie_port", impl="ucie_v1",
            attrs={}, pos_mm=(lx, ly), label=f"UCIe-{port}",
        )

    # Named components (hbm_ctrl as single representative node in view)
    for name in ("noc", "m_cpu", "hbm_ctrl", "sram"):
        c = cube["components"][name]
        lx, ly = local_pos[name]
        nodes[name] = Node(
            id=name, kind=c["kind"], impl=c["impl"],
            attrs=c["attrs"], pos_mm=(lx, ly),
            label=name.upper().replace("_", " "),
        )

    # Bridges
    for br in cube["components"]["xbar"]["bridges"]:
        bname = br["id"]
        bid = f"bridge.{bname}"
        lx, ly = local_pos[bid]
        nodes[bid] = Node(
            id=bid, kind=br["kind"], impl=br["impl"],
            attrs=br["attrs"], pos_mm=(lx, ly),
            label=f"Bridge {bname.upper()}",
        )

    # PEs as opaque blocks + per-PE xbar entry nodes
    corners = cube["pe_layout"]["corners"]
    pe_per_corner = cube["pe_layout"]["pe_per_corner"]
    corner_pos = _corner_pe_positions(cube_w, cube_h)
    xbar_pe_spec = cube["components"]["xbar"]["pe"]
    xbar_top_y = local_pos["xbar.top"][1]
    xbar_bot_y = local_pos["xbar.bottom"][1]

    # PEs are numbered consecutively in the order the corners are listed.
    pe_idx = 0
    for corner in corners:
        # North corners attach to the top xbar row, all others to the bottom row.
        is_top = corner in ("NW", "NE")
        xbar_y = xbar_top_y if is_top else xbar_bot_y
        mm_key = "pe_to_xbar_row_n_mm" if is_top else "pe_to_xbar_row_s_mm"
        for ci in range(pe_per_corner):
            pid = f"pe{pe_idx}"
            xbar_id = f"xbar.pe{pe_idx}"
            # corner_pos holds two centres per corner, so ci must be 0 or 1.
            px, py = corner_pos[corner][ci]

            nodes[pid] = Node(
                id=pid, kind="pe", impl="",
                attrs={"corner": corner}, pos_mm=(px, py),
                label=f"PE{pe_idx}",
            )
            nodes[xbar_id] = Node(
                id=xbar_id, kind=xbar_pe_spec["kind"], impl=xbar_pe_spec["impl"],
                attrs=xbar_pe_spec["attrs"], pos_mm=(px, xbar_y),
                label=f"XBAR PE{pe_idx}",
            )

            # PE → xbar.pe_i (HBM data path)
            view_edges.append(Edge(
                src=pid, dst=xbar_id,
                distance_mm=clinks[mm_key],
                bw_gbs=clinks["pe_to_xbar_bw_gbs"],
                kind="pe_to_xbar",
            ))
            # PE → noc (non-HBM data path)
            view_edges.append(Edge(
                src=pid, dst="noc",
                distance_mm=clinks["pe_dma_to_noc_mm"],
                bw_gbs=clinks["pe_dma_to_noc_bw_gbs"],
                kind="pe_to_noc",
            ))
            # noc → PE (command delivery)
            view_edges.append(Edge(
                src="noc", dst=pid,
                distance_mm=clinks["noc_to_pe_cpu_mm"],
                kind="command",
            ))
            pe_idx += 1

    # Cube fabric edges
    # NOTE(review): the loops below address xbar.pe{i} for i < n_slices, while
    # those nodes are created once per PE above — this presumably relies on
    # len(corners) * pe_per_corner == n_slices; confirm against the spec.

    # xbar.pe_i → hbm_ctrl (single representative node in view)
    for i in range(n_slices):
        view_edges.append(Edge(
            src=f"xbar.pe{i}", dst="hbm_ctrl",
            distance_mm=clinks["xbar_to_hbm_mm"],
            bw_gbs=clinks["xbar_to_hbm_bw_gbs"],
            kind="xbar_to_hbm",
        ))

    # xbar chain
    half = n_slices // 2
    for half_start in (0, half):
        for i in range(half_start, half_start + half - 1):
            # A hop is intra-corner unless i is the last PE of its corner.
            intra = ((i - half_start) % pe_per_corner) != (pe_per_corner - 1)
            x_dist = clinks["xbar_chain_intra_corner_mm"] if intra else clinks["xbar_chain_inter_corner_mm"]
            # One edge per direction between neighbouring xbar nodes.
            for a, b in [(i, i + 1), (i + 1, i)]:
                view_edges.append(Edge(
                    src=f"xbar.pe{a}", dst=f"xbar.pe{b}",
                    distance_mm=x_dist,
                    bw_gbs=clinks["xbar_x_bw_gbs"],
                    kind="xbar_chain",
                ))

    # bridge connections
    # (first / last xbar node of each half, one edge per direction)
    for bname, pe_top, pe_bot in [("left", 0, half), ("right", half - 1, n_slices - 1)]:
        br_id = f"bridge.{bname}"
        for pe_i, br_mm_key in [(pe_top, "xbar_row_n_to_bridge_mm"),
                                (pe_bot, "xbar_row_s_to_bridge_mm")]:
            xbar_id = f"xbar.pe{pe_i}"
            view_edges.append(Edge(
                src=xbar_id, dst=br_id,
                distance_mm=clinks[br_mm_key],
                bw_gbs=clinks["xbar_to_bridge_bw_gbs"],
                kind="xbar_to_bridge",
            ))
            view_edges.append(Edge(
                src=br_id, dst=xbar_id,
                distance_mm=clinks[br_mm_key],
                bw_gbs=clinks["xbar_to_bridge_bw_gbs"],
                kind="bridge_to_xbar",
            ))

    # noc → ucie ports (only this direction is emitted in the view)
    _noc_ucie_v = clinks["noc_to_ucie"]
    for port in cube["ucie"]["ports"]:
        view_edges.append(Edge(
            src="noc", dst=f"ucie-{port}",
            distance_mm=0.0,
            bw_gbs=_noc_ucie_v["per_connection_bw_gbs"],
            n_connections=_noc_ucie_v["n_connections"],
            kind="noc_to_ucie",
        ))

    # m_cpu ↔ noc (command dispatch, both directions)
    view_edges.append(Edge(
        src="m_cpu", dst="noc",
        distance_mm=clinks["m_cpu_to_noc_mm"],
        kind="command",
    ))
    view_edges.append(Edge(
        src="noc", dst="m_cpu",
        distance_mm=clinks["m_cpu_to_noc_mm"],
        kind="command",
    ))

    # noc ↔ sram (shared SRAM access, bidirectional)
    # Both directions carry kind "noc_to_sram".
    _noc_sram_v = clinks["noc_to_sram"]
    view_edges.append(Edge(
        src="noc", dst="sram",
        distance_mm=clinks["noc_to_sram_mm"],
        bw_gbs=_noc_sram_v["per_connection_bw_gbs"],
        n_connections=_noc_sram_v["n_connections"],
        kind="noc_to_sram",
    ))
    view_edges.append(Edge(
        src="sram", dst="noc",
        distance_mm=clinks["noc_to_sram_mm"],
        bw_gbs=_noc_sram_v["per_connection_bw_gbs"],
        n_connections=_noc_sram_v["n_connections"],
        kind="noc_to_sram",
    ))

    return ViewGraph(
        name="cube", nodes=nodes, edges=view_edges,
        width_mm=cube_w, height_mm=cube_h,
    )
|
||||
|
||||
|
||||
def _build_pe_view(spec: dict) -> ViewGraph:
    """PE-level view: representative single PE with all template components.

    One node is created per entry of ``spec["cube"]["pe_template"]["components"]``
    and placed on a fixed 12.0 x 8.0 canvas.  The edges mirror the PE-internal
    wiring: ``pe_cpu -> pe_scheduler``, the scheduler to each engine, and each
    engine to ``pe_tcm``; all are of kind ``"pe_internal"``.

    Args:
        spec: Full topology spec; only ``spec["cube"]["pe_template"]`` is read.

    Returns:
        A ``ViewGraph`` named ``"pe"``.

    Raises:
        KeyError: If a template or link key read below is missing.
    """
    pe_tmpl = spec["cube"]["pe_template"]
    pe_links = pe_tmpl["links"]
    canvas_w, canvas_h = 12.0, 8.0

    # Hand-placed layout of the known template components.
    positions = {
        "pe_cpu": (1.5, 4.0),
        "pe_scheduler": (4.0, 4.0),
        "pe_dma": (7.0, 1.5),
        "pe_gemm": (7.0, 4.0),
        "pe_math": (7.0, 6.5),
        "pe_tcm": (10.0, 4.0),
    }
    # A template component missing from the table above is parked on the
    # canvas centre instead of raising KeyError.  This mirrors the
    # ``_PE_COMP_OFFSETS.get(comp_name, (0.0, 0.0))`` fallback applied when
    # PEs are instantiated, so a spec accepted by the flat graph also renders.
    fallback_pos = (canvas_w / 2, canvas_h / 2)

    nodes: dict[str, Node] = {}
    view_edges: list[Edge] = []

    for comp_name, comp_spec in pe_tmpl["components"].items():
        px, py = positions.get(comp_name, fallback_pos)
        nodes[comp_name] = Node(
            id=comp_name, kind=comp_spec["kind"], impl=comp_spec["impl"],
            attrs=comp_spec["attrs"], pos_mm=(px, py),
            label=comp_name.upper().replace("_", " "),
        )

    # pe_cpu → pe_scheduler
    view_edges.append(Edge(
        src="pe_cpu", dst="pe_scheduler",
        distance_mm=pe_links["pe_cpu_to_scheduler_mm"],
        kind="pe_internal",
    ))
    # pe_scheduler → engines (distance only)
    for eng, key in [("pe_dma", "scheduler_to_dma_mm"),
                     ("pe_gemm", "scheduler_to_gemm_mm"),
                     ("pe_math", "scheduler_to_math_mm")]:
        view_edges.append(Edge(
            src="pe_scheduler", dst=eng,
            distance_mm=pe_links[key],
            kind="pe_internal",
        ))
    # engines → pe_tcm (distance and bandwidth)
    for eng, mm_key, bw_key in [("pe_dma", "dma_to_tcm_mm", "dma_to_tcm_bw_gbs"),
                                ("pe_gemm", "gemm_to_tcm_mm", "gemm_to_tcm_bw_gbs"),
                                ("pe_math", "math_to_tcm_mm", "math_to_tcm_bw_gbs")]:
        view_edges.append(Edge(
            src=eng, dst="pe_tcm",
            distance_mm=pe_links[mm_key],
            bw_gbs=pe_links[bw_key],
            kind="pe_internal",
        ))

    return ViewGraph(
        name="pe", nodes=nodes, edges=view_edges,
        width_mm=canvas_w, height_mm=canvas_h,
    )
|
||||
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
class Node:
    """A single vertex of a topology graph or of a view projection."""

    # Unique node id; edges refer to nodes by this string.
    id: str
    # Component kind and implementation name.
    kind: str
    impl: str
    # Free-form component attributes.
    attrs: dict[str, Any]
    # (x_mm, y_mm); None for abstract nodes
    pos_mm: tuple[float, float] | None
    # Optional display label.
    label: str = ""
|
||||
|
||||
|
||||
@dataclass
class Edge:
    """A directed link between two nodes, identified by their ids."""

    # Source and destination node ids.
    src: str
    dst: str
    # physical wire delay distance (ns = distance_mm * ns_per_mm)
    distance_mm: float
    # Dijkstra cost; None → use distance_mm
    routing_weight_mm: float | None = None
    # Link bandwidth; None when the link declares none.
    bw_gbs: float | None = None
    # multi-connection links; single request uses 1 connection
    n_connections: int | None = None
    # Edge category.
    kind: str = "link"
|
||||
|
||||
|
||||
@dataclass
class ViewGraph:
    """One view projection: its nodes, its edges and the canvas size."""

    # "system" | "sip" | "cube" | "pe"
    name: str
    # Nodes keyed by node id, and the edges between them.
    nodes: dict[str, Node]
    edges: list[Edge]
    # Canvas extent.
    width_mm: float
    height_mm: float
|
||||
|
||||
|
||||
@dataclass
class TopologyGraph:
    """A topology spec together with its flat graph and its view projections."""

    # The spec dict this graph belongs to.
    spec: dict[str, Any]

    # Full instantiated flat graph (used by sim_engine); each instance gets
    # its own empty containers by default.
    nodes: dict[str, Node] = field(default_factory=dict)
    edges: list[Edge] = field(default_factory=list)

    # Representative view projections (used by visualizer); None when absent.
    system_view: ViewGraph | None = None
    sip_view: ViewGraph | None = None
    cube_view: ViewGraph | None = None
    pe_view: ViewGraph | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TopologyHandle:
    """Immutable pairing of a path with the topology graph attached to it."""

    path: Path
    # None until _compile_graph is implemented
    topology_obj: TopologyGraph | None
|
||||
@@ -0,0 +1,367 @@
|
||||
# kernbench/topology/visualizer.py
|
||||
"""
|
||||
SVG diagram generator for TopologyGraph views.
|
||||
|
||||
Produces mm-accurate, deterministic SVG files for each view level
|
||||
(system, SIP, cube, PE) per ADR-0005 and ADR-0006.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .types import Edge, Node, TopologyGraph, ViewGraph
|
||||
|
||||
# ── Color palette by component kind ─────────────────────────────────
|
||||
|
||||
# Hex colour per component kind.  Keys are kind strings; several kinds
# deliberately share a colour (the IO chiplet family, the amber group).
# NOTE(review): presumably looked up by ``Node.kind`` in the drawing helpers,
# which are outside this excerpt — confirm how kinds missing here are handled.
_KIND_COLORS: dict[str, str] = {
    "switch": "#6366f1",  # indigo
    "sip": "#e0e7ff",  # light indigo
    "iochiplet": "#0ea5e9",  # sky blue
    "pcie_ep": "#0ea5e9",  # sky blue (same as iochiplet)
    "io_cpu": "#0ea5e9",  # sky blue (same as iochiplet)
    "ucie_port": "#3b82f6",  # blue
    "noc": "#a78bfa",  # purple
    "m_cpu": "#f59e0b",  # amber
    "xbar": "#f97316",  # orange
    "hbm_ctrl": "#10b981",  # emerald
    "pe": "#94a3b8",  # slate
    "pe_cpu": "#ef4444",  # red
    "pe_scheduler": "#f59e0b",  # amber
    "pe_dma": "#3b82f6",  # blue
    "pe_gemm": "#8b5cf6",  # violet
    "pe_math": "#ec4899",  # pink
    "pe_tcm": "#10b981",  # emerald
    "sram": "#f59e0b",  # amber
    "cube": "#cbd5e1",  # slate-300
}
|
||||
|
||||
# Hex colour per edge kind (the ``Edge.kind`` strings used by the views).
# NOTE(review): the cube view also emits edges of kind "xbar_chain", which
# has no entry here — confirm that the edge drawing code falls back to a
# default colour for unlisted kinds.
_EDGE_COLORS: dict[str, str] = {
    "pcie": "#6366f1",  # indigo
    "io_internal": "#0ea5e9",  # sky blue
    "io_to_cube": "#0ea5e9",  # sky blue
    "ucie_mesh": "#3b82f6",  # blue
    "pe_to_xbar": "#f97316",  # orange
    "xbar_to_hbm": "#10b981",  # emerald
    "xbar_to_bridge": "#a78bfa",  # purple
    "bridge_to_xbar": "#a78bfa",  # purple
    "noc_to_ucie": "#a78bfa",  # purple
    "pe_to_noc": "#a78bfa",  # purple
    "noc_to_sram": "#f59e0b",  # amber
    "command": "#f59e0b",  # amber
    "pe_internal": "#94a3b8",  # slate
}
|
||||
|
||||
# ── Node sizing ──────────────────────────────────────────────────────
|
||||
|
||||
# Default node footprint, used for kinds without an entry in _KIND_SIZE
# (presumably — the consumer is outside this excerpt).
_DEFAULT_NODE_W = 2.0  # mm
_DEFAULT_NODE_H = 1.2  # mm

# Per-kind (width, height) overrides for the larger block-style nodes.
# NOTE(review): units are presumably mm like the defaults above — confirm.
_KIND_SIZE: dict[str, tuple[float, float]] = {
    "sip": (60.0, 50.0),
    "cube": (6.0, 4.0),
    "iochiplet": (4.0, 1.5),
    "switch": (5.0, 1.5),
}
|
||||
|
||||
|
||||
# ── Public API ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def emit_diagrams(graph: TopologyGraph, out_dir: Path) -> list[Path]:
    """Write one SVG file per available view of ``graph`` into ``out_dir``.

    ``out_dir`` is created (including parents) when missing.  Views that are
    ``None`` are skipped; the remaining ones are rendered in the order
    system, sip, cube, pe and saved as ``<view attribute name>.svg``.

    Returns:
        The paths of the files written, in that same order.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # Collect every view up front, keyed by the attribute it is stored under.
    stems = ("system_view", "sip_view", "cube_view", "pe_view")
    pending = [(stem, getattr(graph, stem)) for stem in stems]

    written: list[Path] = []
    for stem, view in pending:
        if view is not None:
            markup = _render_view_svg(view)
            target = out_dir / f"{stem}.svg"
            target.write_text(markup, encoding="utf-8")
            written.append(target)
    return written
|
||||
|
||||
|
||||
# ── SVG rendering ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _render_view_svg(view: ViewGraph) -> str:
    """Return the SVG document for *view* as a single string.

    Layers are emitted back to front: background, title, (cube view only)
    die outline and HBM block, then edges, then nodes, so nodes paint over
    the edges that reach them.
    """
    pad = 40  # px margin around the drawing area
    scale = _pick_scale(view)
    node_sizes = _compute_node_sizes(view, scale)

    # Canvas size in px, truncated to whole pixels.
    w_px = int(view.width_mm * scale + 2 * pad)
    h_px = int(view.height_mm * scale + 2 * pad)

    background = f' <rect width="{w_px}" height="{h_px}" fill="#f8fafc"/>'
    heading = (
        f' <text x="{w_px // 2}" y="18" text-anchor="middle" '
        f'font-family="monospace" font-size="14" font-weight="bold" fill="#1e293b">'
        f'{view.name.upper()} VIEW</text>'
    )
    parts: list[str] = [_svg_header(w_px, h_px, view.name), background, heading]

    if view.name == "cube":
        _draw_cube_boundary(parts, view, scale, pad)
        _draw_hbm_block(parts, view, scale, pad)

    # Shared across all edges of this view so fan-out edges that target the
    # same node get distinct routing offsets.
    fanout_counter: dict[str, int] = {}
    for edge in view.edges:
        # Edges with an endpoint outside this view are not drawn.
        if edge.src not in view.nodes or edge.dst not in view.nodes:
            continue
        _draw_edge(parts, edge, view, node_sizes, scale, pad, fanout_counter)

    for node in view.nodes.values():
        _draw_node(parts, node, node_sizes, scale, pad)

    parts.append("</svg>")
    return "\n".join(parts)
|
||||
|
||||
|
||||
def _pick_scale(view: ViewGraph) -> float:
    """Return the drawing scale in pixels per mm for *view*.

    The scale is looked up by ``view.name``; unknown names use 10.0.
    """
    px_per_mm = {"system": 4.0, "sip": 8.0, "cube": 28.0, "pe": 35.0}
    name = view.name
    if name in px_per_mm:
        return px_per_mm[name]
    return 10.0
|
||||
|
||||
|
||||
def _compute_node_sizes(
    view: ViewGraph, scale: float,
) -> dict[str, tuple[float, float]]:
    """Map every node id in *view* to its box size ``(w_px, h_px)``.

    Sizes come from ``_KIND_SIZE`` (or the default footprint), with two
    view-specific overrides: PE nodes are drawn smaller in the cube view,
    and every node in the PE view uses one fixed footprint.
    """

    def footprint_mm(kind: str) -> tuple[float, float]:
        # The PE-view override wins over everything else.
        if view.name == "pe":
            return 2.5, 1.4
        if view.name == "cube" and kind == "pe":
            return 1.8, 1.0
        return _KIND_SIZE.get(kind, (_DEFAULT_NODE_W, _DEFAULT_NODE_H))

    sizes: dict[str, tuple[float, float]] = {}
    for nid, node in view.nodes.items():
        w_mm, h_mm = footprint_mm(node.kind)
        sizes[nid] = (w_mm * scale, h_mm * scale)
    return sizes
|
||||
|
||||
|
||||
def _svg_header(w: int, h: int, title: str) -> str:
    """Return the opening ``<svg>`` tag plus a ``<title>`` element.

    *w* and *h* are used for the width/height attributes and the viewBox.
    *title* is inserted as-is (no XML escaping is applied here).
    """
    opening = (
        '<svg xmlns="http://www.w3.org/2000/svg" '
        f'width="{w}" height="{h}" viewBox="0 0 {w} {h}">'
    )
    title_elem = f' <title>{title}</title>'
    return "\n".join((opening, title_elem))
|
||||
|
||||
|
||||
def _draw_cube_boundary(
    parts: list[str], view: ViewGraph, scale: float, pad: int,
) -> None:
    """Append a dashed outline covering the whole view area to *parts*.

    The rectangle starts at the padding offset and spans the view's
    width/height converted from mm to px.
    """
    outline_w = view.width_mm * scale
    outline_h = view.height_mm * scale
    attrs = " ".join((
        f'x="{pad:.1f}"',
        f'y="{pad:.1f}"',
        f'width="{outline_w:.1f}"',
        f'height="{outline_h:.1f}"',
        'rx="6"',
        'fill="none"',
        'stroke="#475569"',
        'stroke-width="2"',
        'stroke-dasharray="8,4"',
    ))
    parts.append(f' <rect {attrs}/>')
|
||||
|
||||
|
||||
def _draw_hbm_block(
    parts: list[str], view: ViewGraph, scale: float, pad: int,
) -> None:
    """Append the HBM background rectangle and its "HBM" caption to *parts*.

    The geometry is hard-coded in mm: a 9x5 block spanning x=[4.0, 13.0],
    y=[4.5, 9.5]. *view* is accepted for signature symmetry with the other
    drawing helpers but is not read.
    """
    left_mm, top_mm = 4.0, 4.5
    width_mm, height_mm = 9.0, 5.0
    # Caption anchor: x is the block centre; y=8.5 is below the block's
    # vertical centre (7.0).
    # NOTE(review): presumably intentional so the caption sits clear of the
    # centre of the block; confirm.
    label_x_mm, label_y_mm = 8.5, 8.5

    rect_x = left_mm * scale + pad
    rect_y = top_mm * scale + pad
    rect_w = width_mm * scale
    rect_h = height_mm * scale
    parts.append(
        f' <rect x="{rect_x:.1f}" y="{rect_y:.1f}" '
        f'width="{rect_w:.1f}" height="{rect_h:.1f}" '
        'rx="4" fill="#d1fae5" stroke="#10b981" stroke-width="1.5" '
        'stroke-dasharray="6,3" opacity="0.5"/>'
    )

    text_x = label_x_mm * scale + pad
    text_y = label_y_mm * scale + pad
    parts.append(
        f' <text x="{text_x:.1f}" y="{text_y:.1f}" text-anchor="middle" '
        'font-family="monospace" font-size="11" fill="#047857" opacity="0.7">'
        'HBM</text>'
    )
|
||||
|
||||
|
||||
def _draw_node(
    parts: list[str],
    node: Node,
    sizes: dict[str, tuple[float, float]],
    scale: float,
    pad: int,
) -> None:
    """Append a rounded box and a centred label for *node* to *parts*.

    Nodes without a position are skipped. The box is centred on the node's
    position; its size comes from *sizes* (40x24 px if the id is missing).
    The fill colour is chosen by node kind and the label colour is picked
    for contrast against that fill.
    """
    center_mm = node.pos_mm
    if center_mm is None:
        return

    cx = center_mm[0] * scale + pad
    cy = center_mm[1] * scale + pad
    box_w, box_h = sizes.get(node.id, (40, 24))
    left = cx - box_w / 2
    top = cy - box_h / 2

    fill = _KIND_COLORS.get(node.kind, "#e2e8f0")
    if _is_dark(fill):
        text_color = "#ffffff"
    else:
        text_color = "#1e293b"

    parts.append(
        f' <rect x="{left:.1f}" y="{top:.1f}" width="{box_w:.1f}" height="{box_h:.1f}" '
        f'rx="4" fill="{fill}" stroke="#475569" stroke-width="1"/>'
    )

    # Fall back to the node id when no label is set.
    caption = node.label or node.id
    font_size = _label_font_size(box_w, caption)
    # The +4 px shifts the text baseline down from the box centre.
    parts.append(
        f' <text x="{cx:.1f}" y="{cy + 4:.1f}" text-anchor="middle" '
        f'font-family="monospace" font-size="{font_size}" fill="{text_color}">'
        f'{_escape(caption)}</text>'
    )
|
||||
|
||||
|
||||
# ── Fan-out edge kinds that need offset routing ─────────────────────
|
||||
|
||||
_FANOUT_KINDS = {"pe_to_xbar", "pe_to_noc", "command", "noc_to_ucie"}


def _draw_edge(
    parts: list[str],
    edge: Edge,
    view: ViewGraph,
    sizes: dict[str, tuple[float, float]],
    scale: float,
    pad: int,
    fanout_counter: dict[str, int],
) -> None:
    """Append the SVG for one edge (and its optional label) to *parts*.

    Routing rules:
      * Fan-out kinds in the cube view use a three-segment path
        (vertical, horizontal, vertical) whose horizontal channel is offset
        per edge; *fanout_counter* tracks how many edges of the same kind
        already target the same destination and is updated in place.
      * Other edges whose endpoints differ by more than 1 px on both axes
        use an L-bend: vertical-first in the PE view when the source lies
        in the left half, horizontal-first otherwise.
      * Remaining edges are straight lines.

    Edges with an unpositioned endpoint are skipped. *sizes* is accepted
    but not read.
    """
    nodes = view.nodes
    src_node = nodes[edge.src]
    dst_node = nodes[edge.dst]
    if src_node.pos_mm is None or dst_node.pos_mm is None:
        return

    x1 = src_node.pos_mm[0] * scale + pad
    y1 = src_node.pos_mm[1] * scale + pad
    x2 = dst_node.pos_mm[0] * scale + pad
    y2 = dst_node.pos_mm[1] * scale + pad

    color = _EDGE_COLORS.get(edge.kind, "#94a3b8")
    width = "1"
    if edge.kind == "pe_internal":
        width = "1.5"
    opacity = "0.8"
    if edge.kind in ("command", "noc_to_ucie"):
        opacity = "0.6"
    stroke_attrs = f'stroke="{color}" stroke-width="{width}" opacity="{opacity}"/>'

    def add_polyline(*points: tuple[float, float]) -> None:
        # Emit an unfilled polyline through the given px points.
        coords = " ".join(f"{px:.1f},{py:.1f}" for px, py in points)
        parts.append(f' <polyline points="{coords}" fill="none" {stroke_attrs}')

    def add_label(lx: float, ly: float) -> None:
        # Emit the "<distance>mm [<bw>GB/s]" caption; only for positive distances.
        if not edge.distance_mm > 0:
            return
        label = f"{edge.distance_mm:.1f}mm"
        if edge.bw_gbs:
            label += f" {edge.bw_gbs:.0f}GB/s"
        parts.append(
            f' <text x="{lx:.1f}" y="{ly:.1f}" text-anchor="middle" '
            f'font-family="monospace" font-size="7" fill="#64748b">'
            f'{label}</text>'
        )

    if edge.kind in _FANOUT_KINDS and view.name == "cube":
        # Edges of the same kind into the same destination share a counter;
        # each one gets the next horizontal channel, 10 px apart.
        group_key = f"{edge.kind}:{edge.dst}"
        idx = fanout_counter.get(group_key, 0)
        fanout_counter[group_key] = idx + 1
        mid_y = (y1 + y2) / 2 + (idx - 1.5) * 10

        # Leave src vertically, run along the channel, drop vertically into dst.
        add_polyline((x1, y1), (x1, mid_y), (x2, mid_y), (x2, y2))
        add_label((x1 + x2) / 2, mid_y - 3)
        return

    needs_bend = abs(x2 - x1) > 1 and abs(y2 - y1) > 1
    if needs_bend:
        # PE view, source in the left half: vertical-first (scheduler fan-out).
        # Everything else: horizontal-first.
        vertical_first = (
            view.name == "pe" and src_node.pos_mm[0] < view.width_mm / 2
        )
        corner = (x1, y2) if vertical_first else (x2, y1)
        add_polyline((x1, y1), corner, (x2, y2))
    else:
        parts.append(
            f' <line x1="{x1:.1f}" y1="{y1:.1f}" x2="{x2:.1f}" y2="{y2:.1f}" '
            f'{stroke_attrs}'
        )

    # Caption sits 4 px above the straight-line midpoint of the endpoints.
    add_label((x1 + x2) / 2, (y1 + y2) / 2 - 4)
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _is_dark(hex_color: str) -> bool:
    """Return True when *hex_color* is dark enough to need white text.

    *hex_color* is a six-digit hex colour with an optional leading ``#``.
    The weighted brightness of its channels is compared against 140.
    """
    digits = hex_color.lstrip("#")
    red, green, blue = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    brightness = red * 0.299 + green * 0.587 + blue * 0.114
    return brightness < 140
|
||||
|
||||
|
||||
def _label_font_size(box_width: float, label: str) -> int:
    """Return a font size for *label* inside a box *box_width* px wide.

    Labels are estimated at 7 px per character. If that fits within 90% of
    the box the size is 10; otherwise it is scaled down, but never below 7.
    """
    usable_px = box_width * 0.9
    overflows = len(label) * 7 > usable_px
    return max(7, int(usable_px / len(label) * 1.4)) if overflows else 10
|
||||
|
||||
|
||||
def _escape(text: str) -> str:
    """Escape XML special characters in *text* for use as SVG text content.

    Replaces the ampersand, less-than and greater-than characters with their
    XML entities (ampersand first, so entities produced for the other two
    are not escaped again). Quotes are left untouched, so the result is
    meant for element content, not attribute values.
    """
    # Local import keeps this helper self-contained; the stdlib routine
    # performs exactly the three replacements described above.
    from xml.sax.saxutils import escape

    return escape(text)
|
||||
@@ -0,0 +1,11 @@
|
||||
"""Triton emulator: fake tl module for kernel performance simulation.
|
||||
|
||||
Provides TLContext (the fake `tl` parameter) that kernels use to express
|
||||
memory access patterns and compute operations. Kernel functions are plain
|
||||
Python — no yield, no async — and generate a PeCommand trace that PE_CPU
|
||||
replays through SimPy.
|
||||
|
||||
Usage:
|
||||
from kernbench.triton_emu.registry import register_kernel, get_kernel
|
||||
from kernbench.triton_emu.tl_context import TLContext
|
||||
"""
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Kernel registry: maps kernel names to plain Python kernel callables.
|
||||
|
||||
Benchmarks register kernel functions here; PE_CPU looks them up by
|
||||
KernelRef.name at execution time.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
# Process-wide table of registered kernels, keyed by kernel name.
_kernels: dict[str, Callable[..., None]] = {}


def register_kernel(name: str, fn: Callable[..., None]) -> None:
    """Add *fn* to the registry under *name*.

    Raises:
        ValueError: if *name* is already registered (even with the same
            function); the existing entry is left untouched.
    """
    already_known = name in _kernels
    if not already_known:
        _kernels[name] = fn
        return
    raise ValueError(f"kernel '{name}' already registered")
|
||||
|
||||
|
||||
def get_kernel(name: str) -> Callable[..., None]:
    """Return the kernel function registered under *name*.

    Raises:
        KeyError: if no kernel with that name has been registered.
    """
    # A private sentinel distinguishes "missing" from any stored value.
    missing = object()
    found = _kernels.get(name, missing)
    if found is missing:
        raise KeyError(f"kernel '{name}' not registered")
    return found
|
||||
|
||||
|
||||
def clear_registry() -> None:
    """Remove every registered kernel (intended for test isolation).

    The registry dict is emptied in place, so existing references to it
    observe the change.
    """
    _kernels.clear()
|
||||
@@ -0,0 +1,356 @@
|
||||
"""TLContext: fake Triton Language module for kernel performance simulation.
|
||||
|
||||
Passed as the `tl` parameter to kernel functions. Each API call records a
|
||||
PeCommand in the internal trace. After the kernel returns, PE_CPU replays
|
||||
the command list through SimPy.
|
||||
|
||||
Kernel code looks like standard Python — no yield, no async:
|
||||
|
||||
def my_kernel(a_ptr, b_ptr, out_ptr, tl):
|
||||
pid = tl.program_id(0)
|
||||
a = tl.load(a_ptr, shape=(32, 64), dtype="f16")
|
||||
b = tl.load(b_ptr + pid * stride, shape=(64, 32), dtype="f16")
|
||||
tl.composite(op="gemm", a=a, b=b, out_ptr=out_ptr)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
from kernbench.common.pe_commands import (
|
||||
CompletionHandle,
|
||||
CompositeCmd,
|
||||
DmaReadCmd,
|
||||
DmaWriteCmd,
|
||||
GemmCmd,
|
||||
MathCmd,
|
||||
PeCommand,
|
||||
PeCpuOverheadCmd,
|
||||
TensorHandle,
|
||||
WaitCmd,
|
||||
)
|
||||
|
||||
# Element size in bytes for each supported dtype name.
_DTYPE_BYTES: dict[str, int] = {
    "f16": 2, "f32": 4, "f64": 8,
    "bf16": 2,
    "i8": 1, "i16": 2, "i32": 4, "i64": 8,
    "u8": 1, "u16": 2, "u32": 4, "u64": 8,
}


class TLContext:
    """Fake Triton Language context that records a PeCommand trace.

    Every API call that models work appends one or more commands to an
    internal list, exposed through ``commands``.

    Args:
        pe_id: program instance index (returned by program_id).
        num_programs: total number of program instances.
        dispatch_cycles: PE_CPU overhead per tl API call (auto-inserted);
            a value <= 0 disables the automatic overhead command.
    """

    def __init__(
        self,
        pe_id: int = 0,
        num_programs: int = 1,
        dispatch_cycles: int = 1,
    ) -> None:
        self._pe_id = pe_id
        self._num_programs = num_programs
        self._dispatch_cycles = dispatch_cycles
        self._commands: list[PeCommand] = []
        # Counters backing the "t<N>" tensor ids and "c<N>" completion ids.
        self._handle_counter = 0
        self._completion_counter = 0

    @property
    def commands(self) -> list[PeCommand]:
        """Return the recorded command trace (the live list, not a copy)."""
        return self._commands

    # ── helpers ────────────────────────────────────────────────────

    def _next_handle_id(self) -> str:
        """Return the next tensor handle id: "t1", "t2", ..."""
        self._handle_counter += 1
        return f"t{self._handle_counter}"

    def _next_completion_id(self) -> str:
        """Return the next completion id: "c1", "c2", ..."""
        self._completion_counter += 1
        return f"c{self._completion_counter}"

    def _dtype_bytes(self, dtype: str) -> int:
        """Return the element size of *dtype*; unknown names count as 2 bytes."""
        return _DTYPE_BYTES.get(dtype, 2)

    def _nbytes(self, shape: tuple[int, ...], dtype: str) -> int:
        """Return the total byte size of a tensor with *shape* and *dtype*."""
        return math.prod(shape) * self._dtype_bytes(dtype)

    def _emit_dispatch_overhead(self) -> None:
        """Record the per-call PE_CPU dispatch overhead, if enabled."""
        if self._dispatch_cycles > 0:
            self._commands.append(PeCpuOverheadCmd(cycles=self._dispatch_cycles))

    def _make_handle(
        self, pa: int, shape: tuple[int, ...], dtype: str,
    ) -> TensorHandle:
        """Create a TensorHandle with a fresh id and a computed byte size."""
        return TensorHandle(
            id=self._next_handle_id(),
            pa=pa, shape=shape, dtype=dtype,
            nbytes=self._nbytes(shape, dtype),
        )

    def _composite_out_nbytes(
        self, op: str, a: TensorHandle, b: TensorHandle | None,
    ) -> int:
        """Return the output byte size of a composite command.

        For a gemm with both operands the output is ``(*batch, M, N)`` where
        ``batch`` are the leading dims of *a*, matching the shape ``dot``
        produces. In every other case the output is sized like *a*.
        """
        if op == "gemm" and b is not None:
            # math.prod(()) == 1, so 2-D operands give exactly M * N elements.
            batch = math.prod(a.shape[:-2])
            m, n = a.shape[-2], b.shape[-1]
            return batch * m * n * self._dtype_bytes(a.dtype)
        return a.nbytes

    # ── Reference (no DMA, metadata only) ────────────────────────

    def ref(
        self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
    ) -> TensorHandle:
        """Create a TensorHandle referencing HBM data without issuing DMA.

        Used when the scheduler will stream data per-tile (e.g., tensor b
        in a composite GEMM). No command is generated.
        """
        return self._make_handle(pa=ptr, shape=shape, dtype=dtype)

    # ── Data Movement (blocking, DMA engine) ──────────────────────

    def load(
        self, ptr: int, shape: tuple[int, ...], dtype: str = "f16",
    ) -> TensorHandle:
        """Load tensor from HBM to TCM. Returns TensorHandle.

        Records the dispatch overhead followed by a DmaReadCmd.
        """
        self._emit_dispatch_overhead()
        handle = self._make_handle(pa=ptr, shape=shape, dtype=dtype)
        self._commands.append(DmaReadCmd(
            handle=handle, src_pa=ptr, nbytes=handle.nbytes,
        ))
        return handle

    def store(self, ptr: int, handle: TensorHandle) -> None:
        """Store tensor from TCM to HBM.

        Records the dispatch overhead followed by a DmaWriteCmd.
        """
        self._emit_dispatch_overhead()
        self._commands.append(DmaWriteCmd(
            handle=handle, dst_pa=ptr, nbytes=handle.nbytes,
        ))

    # ── GEMM Engine (blocking) ────────────────────────────────────

    def dot(self, a: TensorHandle, b: TensorHandle) -> TensorHandle:
        """Matrix multiply: out = a @ b. Both operands must be in TCM.

        a: (..., M, K), b: (..., K, N) → out: (*a's leading dims, M, N),
        with a's dtype.

        Raises:
            ValueError: if either operand has fewer than two dims, or the
                inner dimensions do not match.
        """
        if len(a.shape) < 2 or len(b.shape) < 2:
            raise ValueError("dot requires 2D tensors")
        m, k = a.shape[-2], a.shape[-1]
        k2, n = b.shape[-2], b.shape[-1]
        if k != k2:
            raise ValueError(f"dot shape mismatch: a.K={k} != b.K={k2}")
        out_shape = (*a.shape[:-2], m, n)
        out = self._make_handle(pa=0, shape=out_shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._commands.append(GemmCmd(a=a, b=b, out=out, m=m, k=k, n=n))
        return out

    # ── MATH Engine: unary (blocking) ─────────────────────────────

    def _unary_math(self, op: str, x: TensorHandle) -> TensorHandle:
        """Record an element-wise unary MathCmd; output mirrors x's shape/dtype."""
        out = self._make_handle(pa=0, shape=x.shape, dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._commands.append(MathCmd(op=op, inputs=(x,), out=out))
        return out

    def exp(self, x: TensorHandle) -> TensorHandle:
        """Record an element-wise "exp" on *x*."""
        return self._unary_math("exp", x)

    def log(self, x: TensorHandle) -> TensorHandle:
        """Record an element-wise "log" on *x*."""
        return self._unary_math("log", x)

    def sqrt(self, x: TensorHandle) -> TensorHandle:
        """Record an element-wise "sqrt" on *x*."""
        return self._unary_math("sqrt", x)

    def abs(self, x: TensorHandle) -> TensorHandle:
        """Record an element-wise "abs" on *x*."""
        return self._unary_math("abs", x)

    def sigmoid(self, x: TensorHandle) -> TensorHandle:
        """Record an element-wise "sigmoid" on *x*."""
        return self._unary_math("sigmoid", x)

    def cos(self, x: TensorHandle) -> TensorHandle:
        """Record an element-wise "cos" on *x*."""
        return self._unary_math("cos", x)

    def sin(self, x: TensorHandle) -> TensorHandle:
        """Record an element-wise "sin" on *x*."""
        return self._unary_math("sin", x)

    # ── MATH Engine: reduction (blocking) ─────────────────────────

    def _reduction(
        self, op: str, x: TensorHandle, axis: int,
    ) -> TensorHandle:
        """Record a reduction MathCmd; the reduced axis is kept with size 1."""
        out_shape = list(x.shape)
        out_shape[axis] = 1
        out = self._make_handle(pa=0, shape=tuple(out_shape), dtype=x.dtype)
        self._emit_dispatch_overhead()
        self._commands.append(MathCmd(op=op, inputs=(x,), out=out, axis=axis))
        return out

    def sum(self, x: TensorHandle, axis: int) -> TensorHandle:
        """Record a "sum" reduction of *x* along *axis*."""
        return self._reduction("sum", x, axis)

    def max(self, x: TensorHandle, axis: int) -> TensorHandle:
        """Record a "max" reduction of *x* along *axis*."""
        return self._reduction("max", x, axis)

    def min(self, x: TensorHandle, axis: int) -> TensorHandle:
        """Record a "min" reduction of *x* along *axis*."""
        return self._reduction("min", x, axis)

    # ── MATH Engine: binary (blocking) ────────────────────────────

    def _binary_math(
        self, op: str, a: TensorHandle, b: TensorHandle,
    ) -> TensorHandle:
        """Record a binary MathCmd; output mirrors a's shape/dtype (no broadcasting)."""
        out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._commands.append(MathCmd(op=op, inputs=(a, b), out=out))
        return out

    def where(
        self, cond: TensorHandle, a: TensorHandle, b: TensorHandle,
    ) -> TensorHandle:
        """Record a "where" MathCmd over (cond, a, b); output mirrors a."""
        out = self._make_handle(pa=0, shape=a.shape, dtype=a.dtype)
        self._emit_dispatch_overhead()
        self._commands.append(MathCmd(op="where", inputs=(cond, a, b), out=out))
        return out

    # ── Index / Scalar (PE_CPU, no engine) ────────────────────────

    def program_id(self, axis: int = 0) -> int:
        """Return program instance index (*axis* is accepted but ignored)."""
        return self._pe_id

    def num_programs(self, axis: int = 0) -> int:
        """Return total number of program instances (*axis* is ignored)."""
        return self._num_programs

    def arange(self, start: int, end: int, dtype: str = "i32") -> TensorHandle:
        """Create index range tensor in TCM. No command is recorded."""
        n = end - start
        return self._make_handle(pa=0, shape=(n,), dtype=dtype)

    def zeros(self, shape: tuple[int, ...], dtype: str = "f16") -> TensorHandle:
        """Create zero-filled tensor in TCM. No command is recorded."""
        return self._make_handle(pa=0, shape=shape, dtype=dtype)

    def full(
        self, shape: tuple[int, ...], value: float | int, dtype: str = "f16",
    ) -> TensorHandle:
        """Create constant-filled tensor in TCM.

        No command is recorded and *value* is not stored on the handle.
        """
        return self._make_handle(pa=0, shape=shape, dtype=dtype)

    # ── Metadata (no compute, no DMA) ─────────────────────────────

    def trans(self, x: TensorHandle) -> TensorHandle:
        """Transpose — shape change only, no command generated.

        Swaps the last two dims; the returned handle reuses x's id.

        Raises:
            ValueError: if *x* has fewer than two dims.
        """
        if len(x.shape) < 2:
            raise ValueError("trans requires at least 2D tensor")
        new_shape = (*x.shape[:-2], x.shape[-1], x.shape[-2])
        return TensorHandle(
            id=x.id, pa=x.pa, shape=new_shape,
            dtype=x.dtype, nbytes=x.nbytes, data=x.data,
        )

    # ── Composite + Control ───────────────────────────────────────

    def composite(
        self,
        op: Literal["gemm", "math"],
        a: TensorHandle,
        b: TensorHandle | None = None,
        out_ptr: int = 0,
        math_op: str | None = None,
    ) -> CompletionHandle:
        """Submit a composite command (non-blocking, tiled pipeline).

        The output size is ``(*batch, M, N)`` elements of a's dtype for a
        gemm with both operands, and a's byte size otherwise.

        Returns CompletionHandle for use with wait().
        """
        out_nbytes = self._composite_out_nbytes(op, a, b)

        completion = CompletionHandle(id=self._next_completion_id())
        self._emit_dispatch_overhead()
        self._commands.append(CompositeCmd(
            completion=completion, op=op,
            a=a, b=b, out_pa=out_ptr, out_nbytes=out_nbytes,
            math_op=math_op,
        ))
        return completion

    def wait(self, handle: CompletionHandle | None = None) -> None:
        """Wait for a specific composite or all pending composites.

        Records a WaitCmd only; no dispatch overhead is added.
        """
        self._commands.append(WaitCmd(handle=handle))

    def cycles(self, n: int) -> None:
        """Declare PE_CPU scalar execution overhead (cycles)."""
        self._commands.append(PeCpuOverheadCmd(cycles=n))
|
||||
|
||||
|
||||
# ── TensorHandle arithmetic operators ─────────────────────────────
|
||||
# Enables: a + b, a * b, a - b, a / b in kernel code.
|
||||
# Each creates a MathCmd via a module-level helper that requires a
|
||||
# TLContext. The active context is kept in a thread-local that run_kernel() sets.
|
||||
|
||||
|
||||
def _enable_tensor_ops() -> None:
    """Install ``+``, ``-``, ``*`` and ``/`` on TensorHandle.

    Called once at module load. Each operator looks up the TLContext stored
    in a thread-local slot and records a binary MathCmd through it. The
    slot's setter and getter are attached to TLContext as the static
    methods ``_set_active`` and ``_get_active``.
    """
    import threading

    state = threading.local()

    def set_active_context(ctx: TLContext | None) -> None:
        state.ctx = ctx

    def get_active_context() -> TLContext:
        active = getattr(state, "ctx", None)
        if active is not None:
            return active
        raise RuntimeError("TensorHandle ops require an active TLContext")

    def _binop(op: str):
        # Factory so each operator captures its own op name.
        def method(self: TensorHandle, other: TensorHandle) -> TensorHandle:
            return get_active_context()._binary_math(op, self, other)

        return method

    # Operator dunder -> MathCmd op name.
    operator_table = (
        ("__add__", "add"),
        ("__sub__", "sub"),
        ("__mul__", "mul"),
        ("__truediv__", "div"),
    )
    for dunder, op_name in operator_table:
        setattr(TensorHandle, dunder, _binop(op_name))

    # Expose context management
    TLContext._set_active = staticmethod(set_active_context)  # type: ignore[attr-defined]
    TLContext._get_active = staticmethod(get_active_context)  # type: ignore[attr-defined]


_enable_tensor_ops()
|
||||
|
||||
|
||||
def run_kernel(
    kernel_fn,
    tl_ctx: TLContext,
    *args,
    **kwargs,
) -> list[PeCommand]:
    """Run *kernel_fn* against *tl_ctx* and return the recorded commands.

    *tl_ctx* is made the active context for TensorHandle operators for the
    duration of the call and is passed to the kernel as the ``tl`` keyword.
    Afterwards the active context is reset to ``None`` (also when the
    kernel raises); a previously active context is not restored.
    """
    activate = TLContext._set_active  # type: ignore[attr-defined]
    activate(tl_ctx)
    try:
        kernel_fn(*args, tl=tl_ctx, **kwargs)
    finally:
        activate(None)
    return tl_ctx.commands
|
||||
Reference in New Issue
Block a user