Files
kernbench2/scripts/gen_placement_diagrams.py
T
2026-03-18 11:47:48 -07:00

394 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Generate SVG diagrams illustrating each placement strategy.
Example tensor: (M=1024, K=512) fp16 (itemsize=2), 8 PEs.
Tiled variants use tile_m=256, tile_k=128.
Output: docs/diagrams/placement_*.svg
"""
from __future__ import annotations
import html
import math
from pathlib import Path
# ── Diagram parameters ──────────────────────────────────────────────
M = 1024  # rows of the example tensor (vertical axis of every diagram)
K = 512  # columns of the example tensor (horizontal axis)
ITEMSIZE = 2  # bytes per element (fp16)
NUM_PE = 8  # number of processing elements shards are placed on
TILE_M = 256  # tile height used by the tiled_* strategies
TILE_K = 128  # tile width used by the tiled_* strategies

# (fill colour, label colour) for each PE index, kept side by side so the
# two derived lists below cannot drift out of sync. Amber and cyan get black
# labels; every other fill gets white labels.
_PE_PALETTE = (
    ("#3b82f6", "#fff"),  # PE0 blue
    ("#10b981", "#fff"),  # PE1 emerald
    ("#f59e0b", "#000"),  # PE2 amber
    ("#ef4444", "#fff"),  # PE3 red
    ("#8b5cf6", "#fff"),  # PE4 violet
    ("#ec4899", "#fff"),  # PE5 pink
    ("#06b6d4", "#000"),  # PE6 cyan
    ("#f97316", "#fff"),  # PE7 orange
)
PE_COLORS = [fill for fill, _ in _PE_PALETTE]
PE_TEXT_COLORS = [label for _, label in _PE_PALETTE]

# Output directory: <script dir>/../docs/diagrams
OUT_DIR = Path(__file__).parent.parent.joinpath("docs", "diagrams")
# ── SVG helpers ─────────────────────────────────────────────────────
def _svg_header(w: int, h: int, title: str) -> str:
return (
f'<svg xmlns="http://www.w3.org/2000/svg" width="{w}" height="{h}"'
f' viewBox="0 0 {w} {h}" font-family="monospace">\n'
f'<rect width="{w}" height="{h}" fill="#f8fafc" rx="6"/>\n'
f'<text x="{w // 2}" y="32" text-anchor="middle" font-size="16"'
f' font-weight="bold" fill="#1e293b">{title}</text>\n'
)
def _svg_footer() -> str:
return "</svg>\n"
def _rect(x: float, y: float, w: float, h: float, fill: str,
stroke: str = "#334155", sw: float = 1.0, opacity: float = 1.0) -> str:
return (
f'<rect x="{x:.1f}" y="{y:.1f}" width="{w:.1f}" height="{h:.1f}"'
f' fill="{fill}" stroke="{stroke}" stroke-width="{sw}"'
f' fill-opacity="{opacity}" rx="2"/>\n'
)
def _text(x: float, y: float, txt: str, size: int = 11,
anchor: str = "middle", fill: str = "#1e293b",
weight: str = "normal") -> str:
return (
f'<text x="{x:.1f}" y="{y:.1f}" text-anchor="{anchor}"'
f' font-size="{size}" fill="{fill}" font-weight="{weight}">{txt}</text>\n'
)
def _line(x1: float, y1: float, x2: float, y2: float,
stroke: str = "#94a3b8", sw: float = 1) -> str:
return (
f'<line x1="{x1:.1f}" y1="{y1:.1f}" x2="{x2:.1f}" y2="{y2:.1f}"'
f' stroke="{stroke}" stroke-width="{sw}"/>\n'
)
def _format_bytes(n: int) -> str:
if n >= (1 << 20):
return f"{n >> 20} MB"
if n >= (1 << 10):
return f"{n >> 10} KB"
return f"{n} B"
def _legend(x: float, y0: float, num_pe: int = NUM_PE) -> str:
    """Render a vertical key mapping each PE index to its swatch colour.

    A bold "PE Legend" heading is centred 50px right of ``x`` at ``y0``.
    Below it come ``num_pe`` rows spaced 22px apart; each row has a 16x16
    swatch in ``PE_COLORS[i]`` and a left-aligned ``PE<i>`` caption.
    """
    parts = [_text(x + 50, y0, "PE Legend", size=12, weight="bold")]
    for pe in range(num_pe):
        baseline = y0 + 18 + pe * 22
        parts.append(_rect(x, baseline - 12, 16, 16, PE_COLORS[pe]))
        parts.append(_text(x + 22, baseline, f"PE{pe}", size=11, anchor="start"))
    return "".join(parts)
def _axes(gx: float, gy: float, gw: float, gh: float,
          m_label: str | None = None, k_label: str | None = None) -> str:
    """Draw the K (horizontal) and M (vertical) axis labels around a grid.

    The K label is centred 8px above the grid's top edge. The M label is
    rotated -90 degrees, centred 12px left of the grid, and flanked by
    ``↑``/``↓`` glyphs. No arrow shapes are drawn; the glyphs are plain text.

    Args:
        gx, gy, gw, gh: Grid origin and size in SVG user units.
        m_label: Vertical-axis text. Defaults to ``"M=<M>"``, derived from the
            module constant so it cannot go stale if ``M`` changes.
        k_label: Horizontal-axis text. Defaults to ``"K=<K>"``.

    Returns:
        SVG markup for both labels.
    """
    if m_label is None:
        m_label = f"M={M}"
    if k_label is None:
        k_label = f"K={K}"
    s = _text(gx + gw / 2, gy - 8, k_label, size=11, fill="#475569")
    mx = gx - 12
    my = gy + gh / 2
    # Rotated text cannot go through _text (no transform support), so the
    # label is escaped here to match _text's handling.
    safe_m = html.escape(str(m_label), quote=False)
    s += (
        f'<text x="{mx:.1f}" y="{my:.1f}" text-anchor="middle"'
        f' font-size="11" fill="#475569"'
        f' transform="rotate(-90 {mx:.1f} {my:.1f})">↑ {safe_m} ↓</text>\n'
    )
    return s
def _info_box(x: float, y: float, lines: list[str]) -> str:
    """Draw a shaded, rounded box with one left-aligned text row per entry.

    The width is approximated as 7px per character of the longest line (the
    text is 10px in the document's monospace font) plus 20px of padding.
    The height is 18px per line plus 12px. An empty ``lines`` yields an empty
    20x12 box instead of raising ``ValueError`` from ``max()``.

    Returns:
        SVG markup for the box and its text rows.
    """
    widest = max((len(line) for line in lines), default=0)
    box_w = widest * 7 + 20
    box_h = len(lines) * 18 + 12
    out = _rect(x, y, box_w, box_h, "#e2e8f0", stroke="#94a3b8", sw=1)
    for row, line in enumerate(lines):
        out += _text(x + 10, y + 18 + row * 18, line, size=10,
                     anchor="start", fill="#334155")
    return out
# ── Grid drawing ────────────────────────────────────────────────────
def _draw_grid(
    gx: float, gy: float, gw: float, gh: float,
    cells: list[dict],  # [{row, col, rspan, cspan, pe, label?, offset?}]
    rows: int, cols: int,
    cell_labels: bool = True,
) -> str:
    """Paint PE-coloured cells onto a ``rows`` x ``cols`` grid, then outline it.

    Each cell dict supplies ``row``/``col`` (grid position), optional
    ``rspan``/``cspan`` (default 1) and ``pe`` (a colour index). Every cell
    shows a bold ``PE<n>`` caption at its centre. When ``cell_labels`` is true
    and the dict has a ``label`` key, that label is drawn just below the
    caption. Other keys, such as ``offset``, are ignored here.
    """
    col_w = gw / cols
    row_h = gh / rows
    out = []
    for cell in cells:
        left = gx + cell["col"] * col_w
        top = gy + cell["row"] * row_h
        width = cell.get("cspan", 1) * col_w
        height = cell.get("rspan", 1) * row_h
        pe = cell["pe"]
        mid_x = left + width / 2
        mid_y = top + height / 2
        out.append(_rect(left, top, width, height, PE_COLORS[pe],
                         stroke="#334155", sw=1.5))
        ink = PE_TEXT_COLORS[pe]
        out.append(_text(mid_x, mid_y - 4, f"PE{pe}", size=12,
                         fill=ink, weight="bold"))
        if cell_labels and "label" in cell:
            out.append(_text(mid_x, mid_y + 12, cell["label"], size=9, fill=ink))
    # Heavy outer border drawn last so it sits on top of the cell strokes.
    out.append(_rect(gx, gy, gw, gh, "none", stroke="#1e293b", sw=2))
    return "".join(out)
# ── Strategy-specific generators ────────────────────────────────────
def gen_column_wise() -> str:
    """Draw the ``column_wise`` placement, where K is split into NUM_PE equal slabs.

    PE ``i`` owns every row of column slab ``i`` (``K // NUM_PE`` columns
    wide). Under each slab the diagram prints ``off=i * slab_bytes`` and the
    slab size.

    Returns:
        The complete SVG document as a string.
    """
    canvas_w, canvas_h = 820, 500
    gx, gy, gw, gh = 80, 90, 480, 320
    chunk_k = K // NUM_PE  # 64 columns per PE with the default parameters
    chunk_bytes = M * chunk_k * ITEMSIZE
    slab_w = gw / NUM_PE

    shards = [
        {"row": 0, "col": pe, "rspan": 1, "cspan": 1, "pe": pe,
         "label": f"({M}×{chunk_k})"}
        for pe in range(NUM_PE)
    ]

    pieces = [
        _svg_header(canvas_w, canvas_h, "Placement: column_wise"),
        _text(canvas_w // 2, 54,
              f"Tensor ({M}×{K}) fp16 → K axis split into {NUM_PE} parts",
              size=12, fill="#475569"),
        _axes(gx, gy, gw, gh),
        _draw_grid(gx, gy, gw, gh, shards, rows=1, cols=NUM_PE),
    ]

    # Per-slab offset/size captions below the grid.
    # NOTE(review): off= is pe * chunk_bytes (slabs laid out back to back),
    # whereas gen_tiled labels offsets by position in the row-major source
    # tensor. Confirm which convention the placement code actually uses.
    for pe in range(NUM_PE):
        mid_x = gx + pe * slab_w + slab_w / 2
        pieces.append(_text(mid_x, gy + gh + 16,
                            f"off={_format_bytes(pe * chunk_bytes)}",
                            size=9, fill="#475569"))
        pieces.append(_text(mid_x, gy + gh + 30, _format_bytes(chunk_bytes),
                            size=9, fill="#64748b"))

    pieces.append(_legend(620, 100))
    pieces.append(_info_box(620, 320, [
        "Strategy: column_wise",
        "Split axis: K",
        f"Shards: {NUM_PE}",
        f"Each: ({M}, {chunk_k})",
        f"Each: {_format_bytes(chunk_bytes)}",
        f"Total: {_format_bytes(M * K * ITEMSIZE)}",
    ]))
    pieces.append(_svg_footer())
    return "".join(pieces)
def gen_row_wise() -> str:
    """Draw the ``row_wise`` placement, where M is split into NUM_PE equal bands.

    PE ``i`` owns every column of row band ``i`` (``M // NUM_PE`` rows tall).
    To the right of each band the diagram prints ``off=i * band_bytes`` and
    the band size.

    Returns:
        The complete SVG document as a string.
    """
    canvas_w, canvas_h = 820, 560
    gx, gy, gw, gh = 80, 90, 320, 400
    chunk_m = M // NUM_PE  # 128 rows per PE with the default parameters
    chunk_bytes = chunk_m * K * ITEMSIZE
    band_h = gh / NUM_PE
    caption_x = gx + gw + 10

    bands = [
        {"row": pe, "col": 0, "rspan": 1, "cspan": 1, "pe": pe,
         "label": f"({chunk_m}×{K})"}
        for pe in range(NUM_PE)
    ]

    pieces = [
        _svg_header(canvas_w, canvas_h, "Placement: row_wise"),
        _text(canvas_w // 2, 54,
              f"Tensor ({M}×{K}) fp16 → M axis split into {NUM_PE} parts",
              size=12, fill="#475569"),
        _axes(gx, gy, gw, gh),
        _draw_grid(gx, gy, gw, gh, bands, rows=NUM_PE, cols=1),
    ]

    # Per-band offset/size captions to the right of the grid.
    for pe in range(NUM_PE):
        mid_y = gy + pe * band_h + band_h / 2
        pieces.append(_text(caption_x, mid_y - 4,
                            f"off={_format_bytes(pe * chunk_bytes)}",
                            size=9, anchor="start", fill="#475569"))
        pieces.append(_text(caption_x, mid_y + 10, _format_bytes(chunk_bytes),
                            size=9, anchor="start", fill="#64748b"))

    pieces.append(_legend(580, 100))
    pieces.append(_info_box(580, 320, [
        "Strategy: row_wise",
        "Split axis: M",
        f"Shards: {NUM_PE}",
        f"Each: ({chunk_m}, {K})",
        f"Each: {_format_bytes(chunk_bytes)}",
        f"Total: {_format_bytes(M * K * ITEMSIZE)}",
    ]))
    pieces.append(_svg_footer())
    return "".join(pieces)
def gen_replicate() -> str:
    """Draw the ``replicate`` placement, where every PE holds the full tensor.

    NUM_PE coloured boxes are laid out four per row. Each box shows the PE
    name, the full tensor shape, its byte size and ``offset=0``. A one-line
    summary at the bottom gives the total memory across all copies.

    Returns:
        The complete SVG document as a string.
    """
    canvas_w, canvas_h = 820, 500
    full_bytes = M * K * ITEMSIZE
    # Boxes share a 700 x 340 area starting at (60, 90), with 16px gutters.
    per_row, n_rows = 4, 2
    x0, y0 = 60, 90
    gap = 16
    box_w = (700 - (per_row - 1) * gap) / per_row
    box_h = (340 - (n_rows - 1) * gap) / n_rows

    pieces = [
        _svg_header(canvas_w, canvas_h, "Placement: replicate"),
        _text(canvas_w // 2, 54, f"Tensor ({M}×{K}) fp16 → full copy to each PE",
              size=12, fill="#475569"),
    ]
    for pe in range(NUM_PE):
        r, c = divmod(pe, per_row)
        left = x0 + c * (box_w + gap)
        top = y0 + r * (box_h + gap)
        mid_x = left + box_w / 2
        mid_y = top + box_h / 2
        ink = PE_TEXT_COLORS[pe]
        pieces.append(_rect(left, top, box_w, box_h, PE_COLORS[pe],
                            stroke="#334155", sw=1.5))
        pieces.append(_text(mid_x, mid_y - 14, f"PE{pe}",
                            size=14, fill=ink, weight="bold"))
        pieces.append(_text(mid_x, mid_y + 6, f"({M}×{K})", size=11, fill=ink))
        pieces.append(_text(mid_x, mid_y + 22, _format_bytes(full_bytes),
                            size=10, fill=ink))
        pieces.append(_text(mid_x, mid_y + 36, "offset=0", size=9, fill=ink))

    pieces.append(_info_box(60, 450, [
        f"Strategy: replicate | Shards: {NUM_PE} | Each: {_format_bytes(full_bytes)}"
        f" | Total mem: {_format_bytes(full_bytes * NUM_PE)}",
    ]))
    pieces.append(_svg_footer())
    return "".join(pieces)
def gen_tiled(column_major: bool) -> str:
    """Draw a 2D tiled placement (``tiled_column_major`` or ``tiled_row_major``).

    The tensor is cut into ``(TILE_M, TILE_K)`` tiles. Tiles are numbered in
    traversal order and dealt to PEs round-robin: tile ``t`` goes to PE
    ``t % NUM_PE``. The SVG contains:

    - the tile grid;
    - the K/M ranges of each tile column and row;
    - a PE legend;
    - an assignment table (index, PE, grid position, source byte offset);
    - a summary box.

    The canvas height grows when needed so that every table row stays clear
    of the summary box.

    Args:
        column_major: ``True`` numbers tiles with K varying fastest
            (``for mi: for ki``), labelled "column-major (K first)".
            ``False`` numbers them with M varying fastest
            (``for ki: for mi``), labelled "row-major (M first)".

    Returns:
        The complete SVG document as a string.
    """
    name = "tiled_column_major" if column_major else "tiled_row_major"
    order = "column-major (K first)" if column_major else "row-major (M first)"
    tiles_m = M // TILE_M  # 4 with the default parameters
    tiles_k = K // TILE_K  # 4
    total_tiles = tiles_m * tiles_k  # 16
    tile_bytes = TILE_M * TILE_K * ITEMSIZE
    row_bytes = K * ITEMSIZE  # byte stride between consecutive M rows

    # Assignment-table geometry: the heading sits at table_y and rows follow
    # every table_row_h px. The info box is anchored 60px above the canvas
    # bottom and drawn after the table. With a fixed 620px canvas it would
    # cover the last table row, so the canvas is sized to clear it.
    table_y = 310
    table_row_h = 16
    last_row_y = table_y + 18 + (total_tiles - 1) * table_row_h
    W = 820
    H = max(620, last_row_y + 14 + 60)
    info_y = H - 60

    # Tile traversal order.
    # NOTE(review): by the usual convention, K-fastest order (what
    # column_major=True produces here) is called row-major. The names are
    # kept because they presumably mirror the placement implementation's
    # strategy names; confirm before renaming.
    if column_major:
        coords = [(mi, ki) for mi in range(tiles_m) for ki in range(tiles_k)]
    else:
        coords = [(mi, ki) for ki in range(tiles_k) for mi in range(tiles_m)]

    # Built in traversal order, so cells[i]["idx"] == i.
    cells = [
        {
            "row": mi, "col": ki, "rspan": 1, "cspan": 1,
            "pe": idx % NUM_PE,
            "label": f"t{idx}",
            # Byte offset of the tile's first element in the row-major source.
            "offset": (mi * TILE_M * row_bytes) + (ki * TILE_K * ITEMSIZE),
            "idx": idx,
        }
        for idx, (mi, ki) in enumerate(coords)
    ]

    s = _svg_header(W, H, f"Placement: {name}")
    s += _text(W // 2, 54,
               f"Tensor ({M}×{K}) fp16, tile=({TILE_M}×{TILE_K}) → "
               f"{tiles_m}×{tiles_k}={total_tiles} tiles, {order}",
               size=11, fill="#475569")
    gx, gy, gw, gh = 80, 90, 400, 400
    s += _axes(gx, gy, gw, gh)
    s += _draw_grid(gx, gy, gw, gh, cells, rows=tiles_m, cols=tiles_k)

    # K-range label below each tile column.
    cw = gw / tiles_k
    for ki in range(tiles_k):
        cx = gx + ki * cw + cw / 2
        s += _text(cx, gy + gh + 16, f"k={ki * TILE_K}..{(ki + 1) * TILE_K - 1}",
                   size=9, fill="#475569")
    # M-range label to the left of each tile row.
    ch = gh / tiles_m
    for mi in range(tiles_m):
        cy = gy + mi * ch + ch / 2
        s += _text(gx - 16, cy, f"m={mi * TILE_M}..{(mi + 1) * TILE_M - 1}",
                   size=9, anchor="end", fill="#475569")

    s += _legend(540, 90)

    # Assignment table: one row per tile, in index order.
    s += _text(540, table_y, "Tile Assignment Order", size=12, weight="bold")
    for i, c in enumerate(cells):
        ty = table_y + 18 + i * table_row_h
        pe = c["pe"]
        s += _rect(540, ty - 10, 12, 12, PE_COLORS[pe])
        s += _text(558, ty,
                   f"t{c['idx']:>2d} → PE{pe} ({c['row']},{c['col']})"
                   f" off={_format_bytes(c['offset'])}",
                   size=9, anchor="start", fill="#334155")

    s += _info_box(80, info_y, [
        f"Strategy: {name} | Tile: ({TILE_M}×{TILE_K})={_format_bytes(tile_bytes)}"
        f" | Tiles: {total_tiles} | Total: {_format_bytes(M * K * ITEMSIZE)}",
    ])
    s += _svg_footer()
    return s
# ── Main ────────────────────────────────────────────────────────────
def main() -> None:
    """Render every placement diagram and write each one as UTF-8 under OUT_DIR.

    All SVGs are rendered before any file is written. A failure in one
    generator therefore leaves no partially updated set of files.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    rendered = (
        ("placement_column_wise.svg", gen_column_wise()),
        ("placement_row_wise.svg", gen_row_wise()),
        ("placement_replicate.svg", gen_replicate()),
        ("placement_tiled_column_major.svg", gen_tiled(column_major=True)),
        ("placement_tiled_row_major.svg", gen_tiled(column_major=False)),
    )
    for filename, svg in rendered:
        target = OUT_DIR / filename
        target.write_text(svg, encoding="utf-8")
        print(f" wrote {target}")
    print(f"\nGenerated {len(rendered)} placement diagrams.")


if __name__ == "__main__":
    main()