#!/usr/bin/env python3 """Generate SVG diagrams illustrating each placement strategy. Example tensor: (M=1024, K=512) fp16 (itemsize=2), 8 PEs. Tiled variants use tile_m=256, tile_k=128. Output: docs/diagrams/placement_*.svg """ from __future__ import annotations import math from pathlib import Path # ── Diagram parameters ────────────────────────────────────────────── M, K = 1024, 512 ITEMSIZE = 2 NUM_PE = 8 TILE_M, TILE_K = 256, 128 PE_COLORS = [ "#3b82f6", # PE0 blue "#10b981", # PE1 emerald "#f59e0b", # PE2 amber "#ef4444", # PE3 red "#8b5cf6", # PE4 violet "#ec4899", # PE5 pink "#06b6d4", # PE6 cyan "#f97316", # PE7 orange ] PE_TEXT_COLORS = [ "#fff", "#fff", "#000", "#fff", "#fff", "#fff", "#000", "#fff", ] OUT_DIR = Path(__file__).parent.parent / "docs" / "diagrams" # ── SVG helpers ───────────────────────────────────────────────────── def _svg_header(w: int, h: int, title: str) -> str: return ( f'\n' f'\n' f'{title}\n' ) def _svg_footer() -> str: return "\n" def _rect(x: float, y: float, w: float, h: float, fill: str, stroke: str = "#334155", sw: float = 1.0, opacity: float = 1.0) -> str: return ( f'\n' ) def _text(x: float, y: float, txt: str, size: int = 11, anchor: str = "middle", fill: str = "#1e293b", weight: str = "normal") -> str: return ( f'{txt}\n' ) def _line(x1: float, y1: float, x2: float, y2: float, stroke: str = "#94a3b8", sw: float = 1) -> str: return ( f'\n' ) def _format_bytes(n: int) -> str: if n >= (1 << 20): return f"{n >> 20} MB" if n >= (1 << 10): return f"{n >> 10} KB" return f"{n} B" def _legend(x: float, y0: float, num_pe: int = NUM_PE) -> str: s = _text(x + 50, y0, "PE Legend", size=12, weight="bold") for i in range(num_pe): ly = y0 + 18 + i * 22 s += _rect(x, ly - 12, 16, 16, PE_COLORS[i]) s += _text(x + 22, ly, f"PE{i}", size=11, anchor="start") return s def _axes(gx: float, gy: float, gw: float, gh: float, m_label: str = "M=1024", k_label: str = "K=512") -> str: """Draw axis labels and dimension arrows.""" s = "" # K axis (horizontal) label above grid s += _text(gx + gw / 2, gy - 8, f"← {k_label} →", size=11, fill="#475569") # M axis (vertical) label left of grid mx = gx - 12 my = gy + gh / 2 s += ( f'↑ {m_label} ↓\n' ) return s def _info_box(x: float, y: float, lines: list[str]) -> str: """Rounded info box with key/value lines.""" bw = max(len(l) for l in lines) * 7 + 20 bh = len(lines) * 18 + 12 s = _rect(x, y, bw, bh, "#e2e8f0", stroke="#94a3b8", sw=1) for i, line in enumerate(lines): s += _text(x + 10, y + 18 + i * 18, line, size=10, anchor="start", fill="#334155") return s # ── Grid drawing ──────────────────────────────────────────────────── def _draw_grid( gx: float, gy: float, gw: float, gh: float, cells: list[dict], # [{row, col, rspan, cspan, pe, label?, offset?}] rows: int, cols: int, cell_labels: bool = True, ) -> str: """Draw a grid of colored cells representing shard placement.""" cw = gw / cols ch = gh / rows s = "" for c in cells: cx = gx + c["col"] * cw cy = gy + c["row"] * ch w = c.get("cspan", 1) * cw h = c.get("rspan", 1) * ch pe = c["pe"] s += _rect(cx, cy, w, h, PE_COLORS[pe], stroke="#334155", sw=1.5) # PE label lx = cx + w / 2 ly = cy + h / 2 s += _text(lx, ly - 4, f"PE{pe}", size=12, fill=PE_TEXT_COLORS[pe], weight="bold") if cell_labels and "label" in c: s += _text(lx, ly + 12, c["label"], size=9, fill=PE_TEXT_COLORS[pe]) # Grid border s += _rect(gx, gy, gw, gh, "none", stroke="#1e293b", sw=2) return s # ── Strategy-specific generators ──────────────────────────────────── def gen_column_wise() -> str: """Column-wise: split K into 8 equal parts.""" W, H = 820, 500 s = _svg_header(W, H, "Placement: column_wise") s += _text(W // 2, 54, f"Tensor ({M}×{K}) fp16 → K axis split into {NUM_PE} parts", size=12, fill="#475569") gx, gy, gw, gh = 80, 90, 480, 320 chunk_k = K // NUM_PE # 64 chunk_bytes = M * chunk_k * ITEMSIZE s += _axes(gx, gy, gw, gh) cells = [] for i in range(NUM_PE): cells.append({ "row": 0, "col": i, "rspan": 1, "cspan": 1, "pe": i, "label": f"({M}×{chunk_k})", }) s += _draw_grid(gx, gy, gw, gh, cells, rows=1, cols=NUM_PE) # Column dimension labels cw = gw / NUM_PE for i in range(NUM_PE): cx = gx + i * cw + cw / 2 off = i * chunk_bytes s += _text(cx, gy + gh + 16, f"off={_format_bytes(off)}", size=9, fill="#475569") s += _text(cx, gy + gh + 30, f"{_format_bytes(chunk_bytes)}", size=9, fill="#64748b") s += _legend(620, 100) s += _info_box(620, 320, [ f"Strategy: column_wise", f"Split axis: K", f"Shards: {NUM_PE}", f"Each: ({M}, {chunk_k})", f"Each: {_format_bytes(chunk_bytes)}", f"Total: {_format_bytes(M * K * ITEMSIZE)}", ]) s += _svg_footer() return s def gen_row_wise() -> str: """Row-wise: split M into 8 equal parts.""" W, H = 820, 560 s = _svg_header(W, H, "Placement: row_wise") s += _text(W // 2, 54, f"Tensor ({M}×{K}) fp16 → M axis split into {NUM_PE} parts", size=12, fill="#475569") gx, gy, gw, gh = 80, 90, 320, 400 chunk_m = M // NUM_PE # 128 chunk_bytes = chunk_m * K * ITEMSIZE s += _axes(gx, gy, gw, gh) cells = [] for i in range(NUM_PE): cells.append({ "row": i, "col": 0, "rspan": 1, "cspan": 1, "pe": i, "label": f"({chunk_m}×{K})", }) s += _draw_grid(gx, gy, gw, gh, cells, rows=NUM_PE, cols=1) # Row dimension labels ch = gh / NUM_PE for i in range(NUM_PE): cy = gy + i * ch + ch / 2 off = i * chunk_bytes s += _text(gx + gw + 10, cy - 4, f"off={_format_bytes(off)}", size=9, anchor="start", fill="#475569") s += _text(gx + gw + 10, cy + 10, f"{_format_bytes(chunk_bytes)}", size=9, anchor="start", fill="#64748b") s += _legend(580, 100) s += _info_box(580, 320, [ f"Strategy: row_wise", f"Split axis: M", f"Shards: {NUM_PE}", f"Each: ({chunk_m}, {K})", f"Each: {_format_bytes(chunk_bytes)}", f"Total: {_format_bytes(M * K * ITEMSIZE)}", ]) s += _svg_footer() return s def gen_replicate() -> str: """Replicate: full copy per PE.""" W, H = 820, 500 s = _svg_header(W, H, "Placement: replicate") s += _text(W // 2, 54, f"Tensor ({M}×{K}) fp16 → full copy to each PE", size=12, fill="#475569") full_bytes = M * K * ITEMSIZE # Show 8 small copies in 2 rows × 4 cols cols, rows = 4, 2 margin_x, margin_y = 60, 90 gap = 16 bw = (700 - (cols - 1) * gap) / cols bh = (340 - (rows - 1) * gap) / rows for i in range(NUM_PE): r = i // cols c = i % cols bx = margin_x + c * (bw + gap) by = margin_y + r * (bh + gap) s += _rect(bx, by, bw, bh, PE_COLORS[i], stroke="#334155", sw=1.5) s += _text(bx + bw / 2, by + bh / 2 - 14, f"PE{i}", size=14, fill=PE_TEXT_COLORS[i], weight="bold") s += _text(bx + bw / 2, by + bh / 2 + 6, f"({M}×{K})", size=11, fill=PE_TEXT_COLORS[i]) s += _text(bx + bw / 2, by + bh / 2 + 22, f"{_format_bytes(full_bytes)}", size=10, fill=PE_TEXT_COLORS[i]) s += _text(bx + bw / 2, by + bh / 2 + 36, "offset=0", size=9, fill=PE_TEXT_COLORS[i]) s += _info_box(60, 450, [ f"Strategy: replicate | Shards: {NUM_PE} | Each: {_format_bytes(full_bytes)}" f" | Total mem: {_format_bytes(full_bytes * NUM_PE)}", ]) s += _svg_footer() return s def gen_tiled(column_major: bool) -> str: """2D tiled placement. column_major=True → tiled_column_major.""" name = "tiled_column_major" if column_major else "tiled_row_major" order = "column-major (K first)" if column_major else "row-major (M first)" tiles_m = M // TILE_M # 4 tiles_k = K // TILE_K # 4 total_tiles = tiles_m * tiles_k # 16 tile_bytes = TILE_M * TILE_K * ITEMSIZE W, H = 820, 620 s = _svg_header(W, H, f"Placement: {name}") s += _text(W // 2, 54, f"Tensor ({M}×{K}) fp16, tile=({TILE_M}×{TILE_K}) → " f"{tiles_m}×{tiles_k}={total_tiles} tiles, {order}", size=11, fill="#475569") gx, gy, gw, gh = 80, 90, 400, 400 s += _axes(gx, gy, gw, gh) # Build tile → PE mapping cells = [] idx = 0 if column_major: # iterate M first (rows), then K (cols) — but column-major means # we traverse in the order that fills columns first # Actually: column-major = K axis first within each M row # The implementation iterates: for mi in tiles_m: for ki in tiles_k for mi in range(tiles_m): for ki in range(tiles_k): pe = idx % NUM_PE row_bytes = K * ITEMSIZE offset = (mi * TILE_M * row_bytes) + (ki * TILE_K * ITEMSIZE) cells.append({ "row": mi, "col": ki, "rspan": 1, "cspan": 1, "pe": pe, "label": f"t{idx}", "offset": offset, "idx": idx, }) idx += 1 else: # row-major: iterate K first (cols), then M (rows) for ki in range(tiles_k): for mi in range(tiles_m): pe = idx % NUM_PE row_bytes = K * ITEMSIZE offset = (mi * TILE_M * row_bytes) + (ki * TILE_K * ITEMSIZE) cells.append({ "row": mi, "col": ki, "rspan": 1, "cspan": 1, "pe": pe, "label": f"t{idx}", "offset": offset, "idx": idx, }) idx += 1 s += _draw_grid(gx, gy, gw, gh, cells, rows=tiles_m, cols=tiles_k) # Tile dimension labels on top cw = gw / tiles_k for ki in range(tiles_k): cx = gx + ki * cw + cw / 2 s += _text(cx, gy + gh + 16, f"k={ki * TILE_K}..{(ki + 1) * TILE_K - 1}", size=9, fill="#475569") # Tile dimension labels on left ch = gh / tiles_m for mi in range(tiles_m): cy = gy + mi * ch + ch / 2 s += _text(gx - 16, cy, f"m={mi * TILE_M}..{(mi + 1) * TILE_M - 1}", size=9, anchor="end", fill="#475569") s += _legend(540, 90) # Assignment table table_y = 310 s += _text(540, table_y, "Tile Assignment Order", size=12, weight="bold") # Sort cells by idx for table sorted_cells = sorted(cells, key=lambda c: c["idx"]) for i, c in enumerate(sorted_cells): ty = table_y + 18 + i * 16 if ty > H - 20: break pe = c["pe"] s += _rect(540, ty - 10, 12, 12, PE_COLORS[pe]) s += _text(558, ty, f"t{c['idx']:>2d} → PE{pe} ({c['row']},{c['col']})" f" off={_format_bytes(c['offset'])}", size=9, anchor="start", fill="#334155") s += _info_box(80, H - 60, [ f"Strategy: {name} | Tile: ({TILE_M}×{TILE_K})={_format_bytes(tile_bytes)}" f" | Tiles: {total_tiles} | Total: {_format_bytes(M * K * ITEMSIZE)}", ]) s += _svg_footer() return s # ── Main ──────────────────────────────────────────────────────────── def main() -> None: OUT_DIR.mkdir(parents=True, exist_ok=True) diagrams = { "placement_column_wise.svg": gen_column_wise(), "placement_row_wise.svg": gen_row_wise(), "placement_replicate.svg": gen_replicate(), "placement_tiled_column_major.svg": gen_tiled(column_major=True), "placement_tiled_row_major.svg": gen_tiled(column_major=False), } for name, svg in diagrams.items(): path = OUT_DIR / name path.write_text(svg, encoding="utf-8") print(f" wrote {path}") print(f"\nGenerated {len(diagrams)} placement diagrams.") if __name__ == "__main__": main()