commit - release 1
This commit is contained in:
@@ -0,0 +1,367 @@
|
||||
# kernbench/topology/visualizer.py
|
||||
"""
|
||||
SVG diagram generator for TopologyGraph views.
|
||||
|
||||
Produces mm-accurate, deterministic SVG files for each view level
|
||||
(system, SIP, cube, PE) per ADR-0005 and ADR-0006.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .types import Edge, Node, TopologyGraph, ViewGraph
|
||||
|
||||
# ── Color palette by component kind ─────────────────────────────────
|
||||
|
||||
_KIND_COLORS: dict[str, str] = {
|
||||
"switch": "#6366f1", # indigo
|
||||
"sip": "#e0e7ff", # light indigo
|
||||
"iochiplet": "#0ea5e9", # sky blue
|
||||
"pcie_ep": "#0ea5e9",
|
||||
"io_cpu": "#0ea5e9",
|
||||
"ucie_port": "#3b82f6", # blue
|
||||
"noc": "#a78bfa", # purple
|
||||
"m_cpu": "#f59e0b", # amber
|
||||
"xbar": "#f97316", # orange
|
||||
"hbm_ctrl": "#10b981", # emerald
|
||||
"pe": "#94a3b8", # slate
|
||||
"pe_cpu": "#ef4444", # red
|
||||
"pe_scheduler": "#f59e0b", # amber
|
||||
"pe_dma": "#3b82f6", # blue
|
||||
"pe_gemm": "#8b5cf6", # violet
|
||||
"pe_math": "#ec4899", # pink
|
||||
"pe_tcm": "#10b981", # emerald
|
||||
"sram": "#f59e0b", # amber
|
||||
"cube": "#cbd5e1", # slate-300
|
||||
}
|
||||
|
||||
_EDGE_COLORS: dict[str, str] = {
|
||||
"pcie": "#6366f1",
|
||||
"io_internal": "#0ea5e9",
|
||||
"io_to_cube": "#0ea5e9",
|
||||
"ucie_mesh": "#3b82f6",
|
||||
"pe_to_xbar": "#f97316",
|
||||
"xbar_to_hbm": "#10b981",
|
||||
"xbar_to_bridge": "#a78bfa",
|
||||
"bridge_to_xbar": "#a78bfa",
|
||||
"noc_to_ucie": "#a78bfa",
|
||||
"pe_to_noc": "#a78bfa",
|
||||
"noc_to_sram": "#f59e0b",
|
||||
"command": "#f59e0b",
|
||||
"pe_internal": "#94a3b8",
|
||||
}
|
||||
|
||||
# ── Node sizing ──────────────────────────────────────────────────────
|
||||
|
||||
_DEFAULT_NODE_W = 2.0 # mm
|
||||
_DEFAULT_NODE_H = 1.2 # mm
|
||||
|
||||
_KIND_SIZE: dict[str, tuple[float, float]] = {
|
||||
"sip": (60.0, 50.0),
|
||||
"cube": (6.0, 4.0),
|
||||
"iochiplet": (4.0, 1.5),
|
||||
"switch": (5.0, 1.5),
|
||||
}
|
||||
|
||||
|
||||
# ── Public API ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def emit_diagrams(graph: TopologyGraph, out_dir: Path) -> list[Path]:
|
||||
"""Generate SVG diagrams for all views. Returns list of created file paths."""
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
created: list[Path] = []
|
||||
|
||||
views = [
|
||||
("system_view", graph.system_view),
|
||||
("sip_view", graph.sip_view),
|
||||
("cube_view", graph.cube_view),
|
||||
("pe_view", graph.pe_view),
|
||||
]
|
||||
|
||||
for name, view in views:
|
||||
if view is None:
|
||||
continue
|
||||
svg = _render_view_svg(view)
|
||||
path = out_dir / f"{name}.svg"
|
||||
path.write_text(svg, encoding="utf-8")
|
||||
created.append(path)
|
||||
|
||||
return created
|
||||
|
||||
|
||||
# ── SVG rendering ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _render_view_svg(view: ViewGraph) -> str:
|
||||
"""Render a ViewGraph to an SVG string."""
|
||||
scale = _pick_scale(view)
|
||||
pad = 40 # px padding
|
||||
node_sizes = _compute_node_sizes(view, scale)
|
||||
|
||||
# Canvas size in px
|
||||
w_px = int(view.width_mm * scale + 2 * pad)
|
||||
h_px = int(view.height_mm * scale + 2 * pad)
|
||||
|
||||
parts: list[str] = []
|
||||
parts.append(_svg_header(w_px, h_px, view.name))
|
||||
|
||||
# Background
|
||||
parts.append(f' <rect width="{w_px}" height="{h_px}" fill="#f8fafc"/>')
|
||||
|
||||
# Title
|
||||
parts.append(
|
||||
f' <text x="{w_px // 2}" y="18" text-anchor="middle" '
|
||||
f'font-family="monospace" font-size="14" font-weight="bold" fill="#1e293b">'
|
||||
f'{view.name.upper()} VIEW</text>'
|
||||
)
|
||||
|
||||
# Special: draw cube boundary + HBM block background in cube view
|
||||
if view.name == "cube":
|
||||
_draw_cube_boundary(parts, view, scale, pad)
|
||||
_draw_hbm_block(parts, view, scale, pad)
|
||||
|
||||
# Edges (draw before nodes so nodes are on top)
|
||||
# Track fan-out edges to assign per-edge offsets
|
||||
fanout_counter: dict[str, int] = {}
|
||||
for edge in view.edges:
|
||||
if edge.src in view.nodes and edge.dst in view.nodes:
|
||||
_draw_edge(parts, edge, view, node_sizes, scale, pad, fanout_counter)
|
||||
|
||||
# Nodes
|
||||
for node in view.nodes.values():
|
||||
_draw_node(parts, node, node_sizes, scale, pad)
|
||||
|
||||
parts.append("</svg>")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _pick_scale(view: ViewGraph) -> float:
|
||||
"""Pixels per mm, chosen per view type."""
|
||||
return {
|
||||
"system": 4.0,
|
||||
"sip": 8.0,
|
||||
"cube": 28.0,
|
||||
"pe": 35.0,
|
||||
}.get(view.name, 10.0)
|
||||
|
||||
|
||||
def _compute_node_sizes(
|
||||
view: ViewGraph, scale: float,
|
||||
) -> dict[str, tuple[float, float]]:
|
||||
"""Returns (w_px, h_px) for each node."""
|
||||
sizes: dict[str, tuple[float, float]] = {}
|
||||
for nid, node in view.nodes.items():
|
||||
w_mm, h_mm = _KIND_SIZE.get(node.kind, (_DEFAULT_NODE_W, _DEFAULT_NODE_H))
|
||||
# For cube view, use smaller PE nodes
|
||||
if view.name == "cube" and node.kind == "pe":
|
||||
w_mm, h_mm = 1.8, 1.0
|
||||
if view.name == "pe":
|
||||
w_mm, h_mm = 2.5, 1.4
|
||||
sizes[nid] = (w_mm * scale, h_mm * scale)
|
||||
return sizes
|
||||
|
||||
|
||||
def _svg_header(w: int, h: int, title: str) -> str:
|
||||
return (
|
||||
f'<svg xmlns="http://www.w3.org/2000/svg" '
|
||||
f'width="{w}" height="{h}" viewBox="0 0 {w} {h}">\n'
|
||||
f' <title>{title}</title>'
|
||||
)
|
||||
|
||||
|
||||
def _draw_cube_boundary(
|
||||
parts: list[str], view: ViewGraph, scale: float, pad: int,
|
||||
) -> None:
|
||||
"""Draw the cube die outline as a dashed rectangle."""
|
||||
bx = pad
|
||||
by = pad
|
||||
bw = view.width_mm * scale
|
||||
bh = view.height_mm * scale
|
||||
parts.append(
|
||||
f' <rect x="{bx:.1f}" y="{by:.1f}" '
|
||||
f'width="{bw:.1f}" height="{bh:.1f}" '
|
||||
f'rx="6" fill="none" stroke="#475569" stroke-width="2" '
|
||||
f'stroke-dasharray="8,4"/>'
|
||||
)
|
||||
|
||||
|
||||
def _draw_hbm_block(
|
||||
parts: list[str], view: ViewGraph, scale: float, pad: int,
|
||||
) -> None:
|
||||
"""Draw HBM area as a filled rectangle in cube view."""
|
||||
# HBM area: centered at (8.5, 7.0), size 9x5 -> x=[4.0,13.0], y=[4.5,9.5]
|
||||
hbm_x = 4.0 * scale + pad
|
||||
hbm_y = 4.5 * scale + pad
|
||||
hbm_w = 9.0 * scale
|
||||
hbm_h = 5.0 * scale
|
||||
parts.append(
|
||||
f' <rect x="{hbm_x:.1f}" y="{hbm_y:.1f}" '
|
||||
f'width="{hbm_w:.1f}" height="{hbm_h:.1f}" '
|
||||
f'rx="4" fill="#d1fae5" stroke="#10b981" stroke-width="1.5" '
|
||||
f'stroke-dasharray="6,3" opacity="0.5"/>'
|
||||
)
|
||||
cx = 8.5 * scale + pad
|
||||
cy = 8.5 * scale + pad
|
||||
parts.append(
|
||||
f' <text x="{cx:.1f}" y="{cy:.1f}" text-anchor="middle" '
|
||||
f'font-family="monospace" font-size="11" fill="#047857" opacity="0.7">'
|
||||
f'HBM</text>'
|
||||
)
|
||||
|
||||
|
||||
def _draw_node(
|
||||
parts: list[str],
|
||||
node: Node,
|
||||
sizes: dict[str, tuple[float, float]],
|
||||
scale: float,
|
||||
pad: int,
|
||||
) -> None:
|
||||
"""Draw a single node as a rounded rectangle with label."""
|
||||
if node.pos_mm is None:
|
||||
return
|
||||
px = node.pos_mm[0] * scale + pad
|
||||
py = node.pos_mm[1] * scale + pad
|
||||
w, h = sizes.get(node.id, (40, 24))
|
||||
|
||||
x = px - w / 2
|
||||
y = py - h / 2
|
||||
fill = _KIND_COLORS.get(node.kind, "#e2e8f0")
|
||||
text_color = "#ffffff" if _is_dark(fill) else "#1e293b"
|
||||
|
||||
parts.append(
|
||||
f' <rect x="{x:.1f}" y="{y:.1f}" width="{w:.1f}" height="{h:.1f}" '
|
||||
f'rx="4" fill="{fill}" stroke="#475569" stroke-width="1"/>'
|
||||
)
|
||||
|
||||
label = node.label or node.id
|
||||
font_size = _label_font_size(w, label)
|
||||
parts.append(
|
||||
f' <text x="{px:.1f}" y="{py + 4:.1f}" text-anchor="middle" '
|
||||
f'font-family="monospace" font-size="{font_size}" fill="{text_color}">'
|
||||
f'{_escape(label)}</text>'
|
||||
)
|
||||
|
||||
|
||||
# ── Fan-out edge kinds that need offset routing ─────────────────────
|
||||
|
||||
_FANOUT_KINDS = {"pe_to_xbar", "pe_to_noc", "command", "noc_to_ucie"}
|
||||
|
||||
|
||||
def _draw_edge(
|
||||
parts: list[str],
|
||||
edge: Edge,
|
||||
view: ViewGraph,
|
||||
sizes: dict[str, tuple[float, float]],
|
||||
scale: float,
|
||||
pad: int,
|
||||
fanout_counter: dict[str, int],
|
||||
) -> None:
|
||||
"""Draw an edge with orthogonal (90-degree) routing for fan-out kinds."""
|
||||
nodes = view.nodes
|
||||
src_node = nodes[edge.src]
|
||||
dst_node = nodes[edge.dst]
|
||||
if src_node.pos_mm is None or dst_node.pos_mm is None:
|
||||
return
|
||||
|
||||
x1 = src_node.pos_mm[0] * scale + pad
|
||||
y1 = src_node.pos_mm[1] * scale + pad
|
||||
x2 = dst_node.pos_mm[0] * scale + pad
|
||||
y2 = dst_node.pos_mm[1] * scale + pad
|
||||
|
||||
color = _EDGE_COLORS.get(edge.kind, "#94a3b8")
|
||||
width = "1.5" if edge.kind == "pe_internal" else "1"
|
||||
opacity = "0.6" if edge.kind in ("command", "noc_to_ucie") else "0.8"
|
||||
|
||||
if edge.kind in _FANOUT_KINDS and view.name == "cube":
|
||||
# Orthogonal routing: src→horizontal→vertical→dst with per-edge offset.
|
||||
group_key = f"{edge.kind}:{edge.dst}"
|
||||
idx = fanout_counter.get(group_key, 0)
|
||||
fanout_counter[group_key] = idx + 1
|
||||
|
||||
# Route: go vertically from src to a staggered horizontal channel,
|
||||
# then horizontally to dst x, then vertically to dst.
|
||||
mid_y = (y1 + y2) / 2 + (idx - 1.5) * 10 # spread channels vertically
|
||||
|
||||
parts.append(
|
||||
f' <polyline points="{x1:.1f},{y1:.1f} {x1:.1f},{mid_y:.1f} '
|
||||
f'{x2:.1f},{mid_y:.1f} {x2:.1f},{y2:.1f}" '
|
||||
f'fill="none" stroke="{color}" stroke-width="{width}" opacity="{opacity}"/>'
|
||||
)
|
||||
|
||||
# Label on the horizontal segment
|
||||
if edge.distance_mm > 0:
|
||||
lx = (x1 + x2) / 2
|
||||
label = f"{edge.distance_mm:.1f}mm"
|
||||
if edge.bw_gbs:
|
||||
label += f" {edge.bw_gbs:.0f}GB/s"
|
||||
parts.append(
|
||||
f' <text x="{lx:.1f}" y="{mid_y - 3:.1f}" text-anchor="middle" '
|
||||
f'font-family="monospace" font-size="7" fill="#64748b">'
|
||||
f'{label}</text>'
|
||||
)
|
||||
return
|
||||
|
||||
# Non-fanout: orthogonal L-bend
|
||||
if abs(x2 - x1) > 1 and abs(y2 - y1) > 1:
|
||||
# PE view: vertical-first for left→right edges (scheduler→engines),
|
||||
# horizontal-first for right→right edges (engines→tcm)
|
||||
if view.name == "pe":
|
||||
if src_node.pos_mm[0] < view.width_mm / 2:
|
||||
# Source in left half: vertical-first (scheduler fan-out)
|
||||
parts.append(
|
||||
f' <polyline points="{x1:.1f},{y1:.1f} {x1:.1f},{y2:.1f} {x2:.1f},{y2:.1f}" '
|
||||
f'fill="none" stroke="{color}" stroke-width="{width}" opacity="{opacity}"/>'
|
||||
)
|
||||
else:
|
||||
# Source in right half: horizontal-first (dma/math→tcm)
|
||||
parts.append(
|
||||
f' <polyline points="{x1:.1f},{y1:.1f} {x2:.1f},{y1:.1f} {x2:.1f},{y2:.1f}" '
|
||||
f'fill="none" stroke="{color}" stroke-width="{width}" opacity="{opacity}"/>'
|
||||
)
|
||||
else:
|
||||
parts.append(
|
||||
f' <polyline points="{x1:.1f},{y1:.1f} {x2:.1f},{y1:.1f} {x2:.1f},{y2:.1f}" '
|
||||
f'fill="none" stroke="{color}" stroke-width="{width}" opacity="{opacity}"/>'
|
||||
)
|
||||
else:
|
||||
parts.append(
|
||||
f' <line x1="{x1:.1f}" y1="{y1:.1f}" x2="{x2:.1f}" y2="{y2:.1f}" '
|
||||
f'stroke="{color}" stroke-width="{width}" opacity="{opacity}"/>'
|
||||
)
|
||||
|
||||
# Distance label at midpoint
|
||||
if edge.distance_mm > 0:
|
||||
mx = (x1 + x2) / 2
|
||||
my = (y1 + y2) / 2
|
||||
label = f"{edge.distance_mm:.1f}mm"
|
||||
if edge.bw_gbs:
|
||||
label += f" {edge.bw_gbs:.0f}GB/s"
|
||||
parts.append(
|
||||
f' <text x="{mx:.1f}" y="{my - 4:.1f}" text-anchor="middle" '
|
||||
f'font-family="monospace" font-size="7" fill="#64748b">'
|
||||
f'{label}</text>'
|
||||
)
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _is_dark(hex_color: str) -> bool:
|
||||
"""Check if a hex color is dark (for white text)."""
|
||||
h = hex_color.lstrip("#")
|
||||
r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
|
||||
return (r * 0.299 + g * 0.587 + b * 0.114) < 140
|
||||
|
||||
|
||||
def _label_font_size(box_width: float, label: str) -> int:
|
||||
"""Choose font size to fit label in box."""
|
||||
char_w = len(label) * 7
|
||||
if char_w > box_width * 0.9:
|
||||
return max(7, int(box_width * 0.9 / len(label) * 1.4))
|
||||
return 10
|
||||
|
||||
|
||||
def _escape(text: str) -> str:
|
||||
"""Escape XML special characters."""
|
||||
return text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
Reference in New Issue
Block a user