#!/usr/bin/env python3
"""Generate SVG diagrams illustrating each placement strategy.

Example tensor: (M=1024, K=512) fp16 (itemsize=2), 8 PEs.
Tiled variants use tile_m=256, tile_k=128.

Output: docs/diagrams/placement_*.svg
"""
|
||
from __future__ import annotations
|
||
|
||
import math
|
||
from pathlib import Path
|
||
|
||
# ── Diagram parameters ──────────────────────────────────────────────
|
||
# Example tensor shape: M rows by K columns.
M = 1024
K = 512
# Bytes per element (the module docstring describes the tensor as fp16).
ITEMSIZE = 2
# Number of PEs the tensor is distributed over.
NUM_PE = 8
# Tile shape used by the tiled placement variants.
TILE_M = 256
TILE_K = 128

# Fill colour of each PE, indexed by PE number.
PE_COLORS = [
    "#3b82f6",  # PE0 blue
    "#10b981",  # PE1 emerald
    "#f59e0b",  # PE2 amber
    "#ef4444",  # PE3 red
    "#8b5cf6",  # PE4 violet
    "#ec4899",  # PE5 pink
    "#06b6d4",  # PE6 cyan
    "#f97316",  # PE7 orange
]
# Label colour drawn on top of each PE fill, indexed by PE number:
# black on PE2 and PE6, white on every other PE.
PE_TEXT_COLORS = ["#fff", "#fff", "#000", "#fff"] * 2

# Destination directory: <two levels above this file>/docs/diagrams.
OUT_DIR = Path(__file__).parent.parent.joinpath("docs", "diagrams")
|
||
|
||
# ── SVG helpers ─────────────────────────────────────────────────────
|
||
|
||
def _svg_header(w: int, h: int, title: str) -> str:
    """Return the opening markup of an SVG document.

    The result contains the ``<svg>`` root tag sized ``w`` x ``h`` (with a
    matching ``viewBox``), a full-size background rectangle, and a bold,
    horizontally centred title at y=32.  The root element is left open; the
    caller must append the closing ``</svg>`` tag.

    Args:
        w: Canvas width in user units.
        h: Canvas height in user units.
        title: Plain-text title; ``&``, ``<`` and ``>`` are XML-escaped.

    Returns:
        The header markup, terminated by a newline.
    """
    # Local import keeps this helper usable without touching the file's
    # import block.
    from xml.sax.saxutils import escape

    # The title lands in XML character data, so it must be escaped or a
    # stray "&" / "<" would make the whole document malformed.
    safe_title = escape(title)
    return (
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{w}" height="{h}"'
        f' viewBox="0 0 {w} {h}" font-family="monospace">\n'
        f'<rect width="{w}" height="{h}" fill="#f8fafc" rx="6"/>\n'
        f'<text x="{w // 2}" y="32" text-anchor="middle" font-size="16"'
        f' font-weight="bold" fill="#1e293b">{safe_title}</text>\n'
    )
|
||
|
||
def _svg_footer() -> str:
    """Return the closing ``</svg>`` tag followed by a newline."""
    closing_tag = "</svg>"
    return closing_tag + "\n"
|
||
|
||
def _rect(x: float, y: float, w: float, h: float, fill: str,
          stroke: str = "#334155", sw: float = 1.0, opacity: float = 1.0) -> str:
    """Return an SVG ``<rect>`` element with slightly rounded corners.

    Args:
        x, y: Top-left corner.
        w, h: Width and height.
        fill: Fill colour (any SVG paint value, e.g. ``"none"``).
        stroke: Outline colour.
        sw: Outline width, emitted exactly as passed.
        opacity: Fill opacity, emitted exactly as passed.

    Returns:
        The element markup, terminated by a newline.  Geometry is rounded to
        one decimal place.
    """
    attributes = " ".join((
        f'x="{x:.1f}"',
        f'y="{y:.1f}"',
        f'width="{w:.1f}"',
        f'height="{h:.1f}"',
        f'fill="{fill}"',
        f'stroke="{stroke}"',
        f'stroke-width="{sw}"',
        f'fill-opacity="{opacity}"',
        'rx="2"',
    ))
    return f"<rect {attributes}/>\n"
|
||
|
||
def _text(x: float, y: float, txt: str, size: int = 11,
          anchor: str = "middle", fill: str = "#1e293b",
          weight: str = "normal") -> str:
    """Return an SVG ``<text>`` element.

    Args:
        x, y: Anchor position (rounded to one decimal place in the output).
        txt: Plain-text content; ``&``, ``<`` and ``>`` are XML-escaped.
        size: Font size.
        anchor: Value of the ``text-anchor`` attribute.
        fill: Text colour.
        weight: Value of the ``font-weight`` attribute.

    Returns:
        The element markup, terminated by a newline.
    """
    # Local import keeps this helper usable without touching the file's
    # import block.
    from xml.sax.saxutils import escape

    # The content is XML character data: an unescaped "&" or "<" would make
    # the generated SVG malformed.
    content = escape(txt)
    return (
        f'<text x="{x:.1f}" y="{y:.1f}" text-anchor="{anchor}"'
        f' font-size="{size}" fill="{fill}" font-weight="{weight}">{content}</text>\n'
    )
|
||
|
||
def _line(x1: float, y1: float, x2: float, y2: float,
          stroke: str = "#94a3b8", sw: float = 1) -> str:
    """Return an SVG ``<line>`` element from (x1, y1) to (x2, y2).

    Coordinates are rounded to one decimal place; ``stroke`` and ``sw``
    (stroke width) are emitted exactly as passed.  The markup ends with a
    newline.
    """
    endpoints = f'x1="{x1:.1f}" y1="{y1:.1f}" x2="{x2:.1f}" y2="{y2:.1f}"'
    pen = f'stroke="{stroke}" stroke-width="{sw}"'
    return f"<line {endpoints} {pen}/>\n"
|
||
|
||
def _format_bytes(n: int) -> str:
    """Format a byte count using binary units (1 KB = 1024 B, 1 MB = 1024 KB).

    Exact multiples of the chosen unit print as integers (``"128 KB"``).
    Other values keep up to two decimals (``"256.25 KB"``) instead of being
    truncated, so that distinct byte counts such as 262144 and 262400 do not
    both render as ``"256 KB"``.

    Args:
        n: Number of bytes.

    Returns:
        ``"<value> MB"`` for n >= 2**20, ``"<value> KB"`` for n >= 2**10,
        otherwise ``"<n> B"``.
    """
    for shift, unit in ((20, "MB"), (10, "KB")):
        unit_size = 1 << shift
        if n < unit_size:
            continue
        if n % unit_size == 0:
            return f"{n >> shift} {unit}"
        # Trim trailing zeros (and a dangling ".") so 1.50 prints as 1.5.
        value = f"{n / unit_size:.2f}".rstrip("0").rstrip(".")
        return f"{value} {unit}"
    return f"{n} B"
|
||
|
||
def _legend(x: float, y0: float, num_pe: int = NUM_PE) -> str:
    """Return the markup of a vertical PE colour legend.

    A bold "PE Legend" heading is centred at (x + 50, y0).  Below it, one
    row per PE (22 units apart) shows a 16x16 swatch in that PE's colour
    followed by its "PE<i>" label.

    Args:
        x: Left edge of the swatches.
        y0: Baseline of the heading.
        num_pe: Number of legend rows; each row indexes ``PE_COLORS``.
    """
    pieces = [_text(x + 50, y0, "PE Legend", size=12, weight="bold")]
    for pe in range(num_pe):
        baseline = y0 + 18 + pe * 22
        pieces.append(_rect(x, baseline - 12, 16, 16, PE_COLORS[pe]))
        pieces.append(_text(x + 22, baseline, f"PE{pe}", size=11, anchor="start"))
    return "".join(pieces)
|
||
|
||
def _axes(gx: float, gy: float, gw: float, gh: float,
          m_label: str = "M=1024", k_label: str = "K=512") -> str:
    """Return the two axis captions for a grid at (gx, gy) of size gw x gh.

    The K caption is centred horizontally just above the grid; the M caption
    is centred on the left edge and rotated 90 degrees counter-clockwise
    about its own anchor point.
    """
    # Horizontal (K) caption above the grid.
    k_caption = _text(gx + gw / 2, gy - 8, f"← {k_label} →", size=11, fill="#475569")

    # Vertical (M) caption left of the grid.  Built by hand because it needs
    # a `transform` attribute.
    pivot_x = gx - 12
    pivot_y = gy + gh / 2
    m_caption = (
        f'<text x="{pivot_x:.1f}" y="{pivot_y:.1f}" text-anchor="middle"'
        f' font-size="11" fill="#475569"'
        f' transform="rotate(-90 {pivot_x:.1f} {pivot_y:.1f})">↑ {m_label} ↓</text>\n'
    )
    return k_caption + m_caption
|
||
|
||
def _info_box(x: float, y: float, lines: list[str]) -> str:
    """Return a filled box containing one left-aligned text line per entry.

    The box width is estimated as 7 units per character of the longest line
    plus 20 units of padding; the height is 18 units per line plus 12.

    Args:
        x, y: Top-left corner of the box.
        lines: Text lines to show, top to bottom.

    Returns:
        The markup for the box and its text, or an empty string when
        ``lines`` is empty (there is nothing to draw, and ``max()`` over an
        empty sequence would otherwise raise ``ValueError``).
    """
    if not lines:
        return ""

    box_w = max(len(line) for line in lines) * 7 + 20
    box_h = len(lines) * 18 + 12
    pieces = [_rect(x, y, box_w, box_h, "#e2e8f0", stroke="#94a3b8", sw=1)]
    pieces.extend(
        _text(x + 10, y + 18 + i * 18, line, size=10, anchor="start", fill="#334155")
        for i, line in enumerate(lines)
    )
    return "".join(pieces)
|
||
|
||
# ── Grid drawing ────────────────────────────────────────────────────
|
||
|
||
def _draw_grid(
    gx: float, gy: float, gw: float, gh: float,
    cells: list[dict],  # [{row, col, rspan, cspan, pe, label?, offset?}]
    rows: int, cols: int,
    cell_labels: bool = True,
) -> str:
    """Return the markup for a grid of PE-coloured cells plus an outer border.

    The area (gx, gy, gw, gh) is divided into ``rows`` x ``cols`` equal
    cells.  Each entry of ``cells`` is drawn at its ``row``/``col`` position,
    spanning ``rspan`` x ``cspan`` cells (both default to 1), filled with the
    colour of its ``pe`` and captioned "PE<pe>".  When ``cell_labels`` is
    true and the entry has a ``label`` key, that label is drawn underneath
    the caption.
    """
    cell_w = gw / cols
    cell_h = gh / rows
    pieces = []
    for cell in cells:
        left = gx + cell["col"] * cell_w
        top = gy + cell["row"] * cell_h
        width = cell.get("cspan", 1) * cell_w
        height = cell.get("rspan", 1) * cell_h
        pe = cell["pe"]
        pieces.append(_rect(left, top, width, height, PE_COLORS[pe],
                            stroke="#334155", sw=1.5))

        # Caption sits slightly above the cell centre so an optional label
        # fits below it.
        mid_x = left + width / 2
        mid_y = top + height / 2
        pieces.append(_text(mid_x, mid_y - 4, f"PE{pe}", size=12,
                            fill=PE_TEXT_COLORS[pe], weight="bold"))
        if cell_labels and "label" in cell:
            pieces.append(_text(mid_x, mid_y + 12, cell["label"], size=9,
                                fill=PE_TEXT_COLORS[pe]))

    # Outer border, drawn last so it sits on top of the cell outlines.
    pieces.append(_rect(gx, gy, gw, gh, "none", stroke="#1e293b", sw=2))
    return "".join(pieces)
|
||
|
||
|
||
# ── Strategy-specific generators ────────────────────────────────────
|
||
|
||
def gen_column_wise() -> str:
    """Return the SVG for the column_wise strategy: K split into NUM_PE parts.

    The tensor is drawn as a single row of NUM_PE columns, one per PE, with
    an offset and a size caption under each column, a PE legend and a
    summary box on the right.
    """
    W, H = 820, 500
    gx, gy, gw, gh = 80, 90, 480, 320
    chunk_k = K // NUM_PE  # 64 with the default parameters
    chunk_bytes = M * chunk_k * ITEMSIZE
    chunk_size_label = _format_bytes(chunk_bytes)

    pieces = [
        _svg_header(W, H, "Placement: column_wise"),
        _text(W // 2, 54, f"Tensor ({M}×{K}) fp16 → K axis split into {NUM_PE} parts",
              size=12, fill="#475569"),
        _axes(gx, gy, gw, gh),
    ]

    # One full-height column per PE.
    cells = [
        {
            "row": 0, "col": pe, "rspan": 1, "cspan": 1,
            "pe": pe,
            "label": f"({M}×{chunk_k})",
        }
        for pe in range(NUM_PE)
    ]
    pieces.append(_draw_grid(gx, gy, gw, gh, cells, rows=1, cols=NUM_PE))

    # Offset / size captions below each column.
    col_w = gw / NUM_PE
    for pe in range(NUM_PE):
        mid_x = gx + pe * col_w + col_w / 2
        pieces.append(_text(mid_x, gy + gh + 16,
                            f"off={_format_bytes(pe * chunk_bytes)}",
                            size=9, fill="#475569"))
        pieces.append(_text(mid_x, gy + gh + 30, chunk_size_label,
                            size=9, fill="#64748b"))

    pieces.append(_legend(620, 100))
    pieces.append(_info_box(620, 320, [
        "Strategy: column_wise",
        "Split axis: K",
        f"Shards: {NUM_PE}",
        f"Each: ({M}, {chunk_k})",
        f"Each: {chunk_size_label}",
        f"Total: {_format_bytes(M * K * ITEMSIZE)}",
    ]))
    pieces.append(_svg_footer())
    return "".join(pieces)
|
||
|
||
|
||
def gen_row_wise() -> str:
    """Return the SVG for the row_wise strategy: M split into NUM_PE parts.

    The tensor is drawn as a single column of NUM_PE rows, one per PE, with
    an offset and a size caption to the right of each row, a PE legend and a
    summary box.
    """
    W, H = 820, 560
    gx, gy, gw, gh = 80, 90, 320, 400
    chunk_m = M // NUM_PE  # 128 with the default parameters
    chunk_bytes = chunk_m * K * ITEMSIZE
    chunk_size_label = _format_bytes(chunk_bytes)

    pieces = [
        _svg_header(W, H, "Placement: row_wise"),
        _text(W // 2, 54, f"Tensor ({M}×{K}) fp16 → M axis split into {NUM_PE} parts",
              size=12, fill="#475569"),
        _axes(gx, gy, gw, gh),
    ]

    # One full-width row per PE.
    cells = [
        {
            "row": pe, "col": 0, "rspan": 1, "cspan": 1,
            "pe": pe,
            "label": f"({chunk_m}×{K})",
        }
        for pe in range(NUM_PE)
    ]
    pieces.append(_draw_grid(gx, gy, gw, gh, cells, rows=NUM_PE, cols=1))

    # Offset / size captions to the right of each row.
    row_h = gh / NUM_PE
    caption_x = gx + gw + 10
    for pe in range(NUM_PE):
        mid_y = gy + pe * row_h + row_h / 2
        pieces.append(_text(caption_x, mid_y - 4,
                            f"off={_format_bytes(pe * chunk_bytes)}",
                            size=9, anchor="start", fill="#475569"))
        pieces.append(_text(caption_x, mid_y + 10, chunk_size_label,
                            size=9, anchor="start", fill="#64748b"))

    pieces.append(_legend(580, 100))
    pieces.append(_info_box(580, 320, [
        "Strategy: row_wise",
        "Split axis: M",
        f"Shards: {NUM_PE}",
        f"Each: ({chunk_m}, {K})",
        f"Each: {chunk_size_label}",
        f"Total: {_format_bytes(M * K * ITEMSIZE)}",
    ]))
    pieces.append(_svg_footer())
    return "".join(pieces)
|
||
|
||
|
||
def gen_replicate() -> str:
    """Return the SVG for the replicate strategy: a full copy on every PE.

    Each PE is drawn as its own box (4 per row) captioned with the PE name,
    the full tensor shape, the full tensor size and "offset=0".  A one-line
    summary box is placed underneath.
    """
    W, H = 820, 500
    full_bytes = M * K * ITEMSIZE
    full_size_label = _format_bytes(full_bytes)

    # Layout: 2 rows x 4 columns of boxes inside a 700 x 340 area.
    cols, rows = 4, 2
    margin_x, margin_y = 60, 90
    gap = 16
    box_w = (700 - (cols - 1) * gap) / cols
    box_h = (340 - (rows - 1) * gap) / rows

    pieces = [
        _svg_header(W, H, "Placement: replicate"),
        _text(W // 2, 54, f"Tensor ({M}×{K}) fp16 → full copy to each PE",
              size=12, fill="#475569"),
    ]

    for pe in range(NUM_PE):
        grid_row, grid_col = divmod(pe, cols)
        left = margin_x + grid_col * (box_w + gap)
        top = margin_y + grid_row * (box_h + gap)
        pieces.append(_rect(left, top, box_w, box_h, PE_COLORS[pe],
                            stroke="#334155", sw=1.5))

        mid_x = left + box_w / 2
        mid_y = top + box_h / 2
        ink = PE_TEXT_COLORS[pe]
        # (vertical shift from the box centre, caption, font size, weight)
        captions = (
            (-14, f"PE{pe}", 14, "bold"),
            (6, f"({M}×{K})", 11, "normal"),
            (22, full_size_label, 10, "normal"),
            (36, "offset=0", 9, "normal"),
        )
        for shift, caption, font_size, font_weight in captions:
            pieces.append(_text(mid_x, mid_y + shift, caption,
                                size=font_size, fill=ink, weight=font_weight))

    summary = (
        f"Strategy: replicate | Shards: {NUM_PE} | Each: {full_size_label}"
        f" | Total mem: {_format_bytes(full_bytes * NUM_PE)}"
    )
    pieces.append(_info_box(60, 450, [summary]))
    pieces.append(_svg_footer())
    return "".join(pieces)
|
||
|
||
|
||
def gen_tiled(column_major: bool) -> str:
    """Return the SVG for a 2D tiled placement.

    The tensor is cut into TILE_M x TILE_K tiles.  Tiles are numbered in
    traversal order and assigned to PEs round-robin (tile index modulo
    NUM_PE).  The diagram shows the coloured tile grid, the index ranges of
    each tile row/column, a PE legend, a table listing every tile in
    assignment order, and a one-line summary box.

    Args:
        column_major: Selects the strategy name and traversal.  True draws
            "tiled_column_major", where the K tile index varies fastest
            (``for mi: for ki``); False draws "tiled_row_major", where the M
            tile index varies fastest (``for ki: for mi``).
            NOTE(review): the original comments say this naming mirrors the
            placement implementation — confirm against that code.

    Returns:
        The complete SVG document as a string.
    """
    name = "tiled_column_major" if column_major else "tiled_row_major"
    order = "column-major (K first)" if column_major else "row-major (M first)"

    tiles_m = M // TILE_M  # 4 with the default parameters
    tiles_k = K // TILE_K  # 4 with the default parameters
    total_tiles = tiles_m * tiles_k  # 16 with the default parameters
    tile_bytes = TILE_M * TILE_K * ITEMSIZE
    row_bytes = K * ITEMSIZE  # bytes in one full tensor row

    W, H = 820, 620
    gx, gy, gw, gh = 80, 90, 400, 400
    info_y = H - 60  # top edge of the summary box

    pieces = [
        _svg_header(W, H, f"Placement: {name}"),
        _text(W // 2, 54,
              f"Tensor ({M}×{K}) fp16, tile=({TILE_M}×{TILE_K}) → "
              f"{tiles_m}×{tiles_k}={total_tiles} tiles, {order}",
              size=11, fill="#475569"),
        _axes(gx, gy, gw, gh),
    ]

    # Tile traversal order; only the loop nesting differs between variants.
    if column_major:
        visit_order = [(mi, ki) for mi in range(tiles_m) for ki in range(tiles_k)]
    else:
        visit_order = [(mi, ki) for ki in range(tiles_k) for mi in range(tiles_m)]

    # Build the tile -> PE mapping.  `offset` is the byte position of the
    # tile's first element: mi tile-rows of full tensor rows, plus ki tiles
    # along K.
    cells = [
        {
            "row": mi, "col": ki, "rspan": 1, "cspan": 1,
            "pe": idx % NUM_PE,
            "label": f"t{idx}",
            "offset": (mi * TILE_M * row_bytes) + (ki * TILE_K * ITEMSIZE),
            "idx": idx,
        }
        for idx, (mi, ki) in enumerate(visit_order)
    ]
    pieces.append(_draw_grid(gx, gy, gw, gh, cells, rows=tiles_m, cols=tiles_k))

    # K index range of each tile column, drawn below the grid.
    col_w = gw / tiles_k
    for ki in range(tiles_k):
        mid_x = gx + ki * col_w + col_w / 2
        pieces.append(_text(mid_x, gy + gh + 16,
                            f"k={ki * TILE_K}..{(ki + 1) * TILE_K - 1}",
                            size=9, fill="#475569"))

    # M index range of each tile row, drawn left of the grid.
    row_h = gh / tiles_m
    for mi in range(tiles_m):
        mid_y = gy + mi * row_h + row_h / 2
        pieces.append(_text(gx - 16, mid_y,
                            f"m={mi * TILE_M}..{(mi + 1) * TILE_M - 1}",
                            size=9, anchor="end", fill="#475569"))

    pieces.append(_legend(540, 90))

    # Assignment table.  It starts at y=290 so that all 16 rows end above
    # the summary box; at the previous y=310 the last row (baseline 568) was
    # hidden behind the box, which is opaque and drawn afterwards.
    table_y = 290
    pieces.append(_text(540, table_y, "Tile Assignment Order", size=12, weight="bold"))
    # `cells` is already in assignment (idx) order, so no sort is needed.
    for position, cell in enumerate(cells):
        ty = table_y + 18 + position * 16
        # Each row's swatch extends to ty + 2; stop before touching the box.
        if ty + 2 > info_y:
            break
        pe = cell["pe"]
        pieces.append(_rect(540, ty - 10, 12, 12, PE_COLORS[pe]))
        pieces.append(_text(558, ty,
                            f"t{cell['idx']:>2d} → PE{pe} ({cell['row']},{cell['col']})"
                            f" off={_format_bytes(cell['offset'])}",
                            size=9, anchor="start", fill="#334155"))

    pieces.append(_info_box(80, info_y, [
        f"Strategy: {name} | Tile: ({TILE_M}×{TILE_K})={_format_bytes(tile_bytes)}"
        f" | Tiles: {total_tiles} | Total: {_format_bytes(M * K * ITEMSIZE)}",
    ]))
    pieces.append(_svg_footer())
    return "".join(pieces)
|
||
|
||
|
||
# ── Main ────────────────────────────────────────────────────────────
|
||
|
||
def main() -> None:
    """Render every placement diagram and write it into ``OUT_DIR``.

    The output directory is created if missing.  All diagrams are rendered
    before the first file is written; each written path is printed, followed
    by a final count.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    # (file name, SVG text) pairs, in output order.
    rendered = [
        ("placement_column_wise.svg", gen_column_wise()),
        ("placement_row_wise.svg", gen_row_wise()),
        ("placement_replicate.svg", gen_replicate()),
        ("placement_tiled_column_major.svg", gen_tiled(column_major=True)),
        ("placement_tiled_row_major.svg", gen_tiled(column_major=False)),
    ]

    for filename, svg_text in rendered:
        target = OUT_DIR / filename
        target.write_text(svg_text, encoding="utf-8")
        print(f" wrote {target}")

    print(f"\nGenerated {len(rendered)} placement diagrams.")


if __name__ == "__main__":
    main()
|