commit - release 1
This commit is contained in:
@@ -0,0 +1,393 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate SVG diagrams illustrating each placement strategy.
|
||||
|
||||
Example tensor: (M=1024, K=512) fp16 (itemsize=2), 8 PEs.
|
||||
Tiled variants use tile_m=256, tile_k=128.
|
||||
|
||||
Output: docs/diagrams/placement_*.svg
|
||||
"""
|
||||
from __future__ import annotations

import math
from pathlib import Path
from xml.sax.saxutils import escape
|
||||
|
||||
# ── Diagram parameters ──────────────────────────────────────────────
|
||||
# Example tensor shape: M rows by K columns.
M, K = 1024, 512
# Bytes per element (2 matches the fp16 example in the module docstring).
ITEMSIZE = 2
# Number of PEs the tensor is distributed across.
NUM_PE = 8
# Tile shape (rows, columns) used by the tiled placement diagrams.
TILE_M, TILE_K = 256, 128

# Fill colour of each PE, indexed by PE number.
PE_COLORS = [
    "#3b82f6",  # PE0 blue
    "#10b981",  # PE1 emerald
    "#f59e0b",  # PE2 amber
    "#ef4444",  # PE3 red
    "#8b5cf6",  # PE4 violet
    "#ec4899",  # PE5 pink
    "#06b6d4",  # PE6 cyan
    "#f97316",  # PE7 orange
]
# Label colour drawn on top of the PE_COLORS entry with the same index.
PE_TEXT_COLORS = [
    "#fff", "#fff", "#000", "#fff",
    "#fff", "#fff", "#000", "#fff",
]

# Diagrams are written to docs/diagrams under the parent of this script's
# directory.
OUT_DIR = Path(__file__).parent.parent / "docs" / "diagrams"
|
||||
|
||||
# ── SVG helpers ─────────────────────────────────────────────────────
|
||||
|
||||
def _svg_header(w: int, h: int, title: str) -> str:
|
||||
return (
|
||||
f'<svg xmlns="http://www.w3.org/2000/svg" width="{w}" height="{h}"'
|
||||
f' viewBox="0 0 {w} {h}" font-family="monospace">\n'
|
||||
f'<rect width="{w}" height="{h}" fill="#f8fafc" rx="6"/>\n'
|
||||
f'<text x="{w // 2}" y="32" text-anchor="middle" font-size="16"'
|
||||
f' font-weight="bold" fill="#1e293b">{title}</text>\n'
|
||||
)
|
||||
|
||||
def _svg_footer() -> str:
|
||||
return "</svg>\n"
|
||||
|
||||
def _rect(x: float, y: float, w: float, h: float, fill: str,
|
||||
stroke: str = "#334155", sw: float = 1.0, opacity: float = 1.0) -> str:
|
||||
return (
|
||||
f'<rect x="{x:.1f}" y="{y:.1f}" width="{w:.1f}" height="{h:.1f}"'
|
||||
f' fill="{fill}" stroke="{stroke}" stroke-width="{sw}"'
|
||||
f' fill-opacity="{opacity}" rx="2"/>\n'
|
||||
)
|
||||
|
||||
def _text(x: float, y: float, txt: str, size: int = 11,
|
||||
anchor: str = "middle", fill: str = "#1e293b",
|
||||
weight: str = "normal") -> str:
|
||||
return (
|
||||
f'<text x="{x:.1f}" y="{y:.1f}" text-anchor="{anchor}"'
|
||||
f' font-size="{size}" fill="{fill}" font-weight="{weight}">{txt}</text>\n'
|
||||
)
|
||||
|
||||
def _line(x1: float, y1: float, x2: float, y2: float,
|
||||
stroke: str = "#94a3b8", sw: float = 1) -> str:
|
||||
return (
|
||||
f'<line x1="{x1:.1f}" y1="{y1:.1f}" x2="{x2:.1f}" y2="{y2:.1f}"'
|
||||
f' stroke="{stroke}" stroke-width="{sw}"/>\n'
|
||||
)
|
||||
|
||||
def _format_bytes(n: int) -> str:
|
||||
if n >= (1 << 20):
|
||||
return f"{n >> 20} MB"
|
||||
if n >= (1 << 10):
|
||||
return f"{n >> 10} KB"
|
||||
return f"{n} B"
|
||||
|
||||
def _legend(x: float, y0: float, num_pe: int = NUM_PE) -> str:
    """Return a vertical legend: a bold heading plus one swatch row per PE.

    The heading is centred at ``x + 50`` on baseline ``y0``; rows follow at
    22-pixel intervals, each a 16x16 colour swatch and a "PE<i>" label.
    """
    pieces = [_text(x + 50, y0, "PE Legend", size=12, weight="bold")]
    for pe in range(num_pe):
        baseline = y0 + 18 + pe * 22
        pieces.append(_rect(x, baseline - 12, 16, 16, PE_COLORS[pe]))
        pieces.append(_text(x + 22, baseline, f"PE{pe}", size=11, anchor="start"))
    return "".join(pieces)
|
||||
|
||||
def _axes(gx: float, gy: float, gw: float, gh: float,
          m_label: str | None = None, k_label: str | None = None) -> str:
    """Draw the K (horizontal) and M (vertical) axis captions for a grid.

    Args:
        gx: Left edge of the grid.
        gy: Top edge of the grid.
        gw: Grid width.
        gh: Grid height.
        m_label: Caption for the vertical axis.  Defaults to ``"M=<M>"``
            built from the module-level ``M`` so it tracks the diagram
            parameters instead of a hard-coded value.
        k_label: Caption for the horizontal axis.  Defaults to ``"K=<K>"``.

    Returns:
        SVG markup for both captions.
    """
    if m_label is None:
        m_label = f"M={M}"
    if k_label is None:
        k_label = f"K={K}"

    # K axis (horizontal) caption, centred just above the grid.
    s = _text(gx + gw / 2, gy - 8, f"← {k_label} →", size=11, fill="#475569")

    # M axis (vertical) caption: rotated 90° counter-clockwise about its own
    # anchor point so it reads along the left edge of the grid.
    mx = gx - 12
    my = gy + gh / 2
    s += (
        f'<text x="{mx:.1f}" y="{my:.1f}" text-anchor="middle"'
        f' font-size="11" fill="#475569"'
        f' transform="rotate(-90 {mx:.1f} {my:.1f})">↑ {m_label} ↓</text>\n'
    )
    return s
|
||||
|
||||
def _info_box(x: float, y: float, lines: list[str]) -> str:
    """Return a grey info box containing one text row per entry of *lines*.

    The box is sized from the content: 7 pixels per character of the longest
    line plus 20 pixels of padding wide, and 18 pixels per line plus 12
    pixels of padding tall.

    Args:
        x: Left edge of the box.
        y: Top edge of the box.
        lines: Text rows, drawn top to bottom and left-aligned.

    Returns:
        SVG markup for the box and its text, or an empty string when
        *lines* is empty (there is nothing to show, and ``max()`` over an
        empty sequence would otherwise raise).
    """
    if not lines:
        return ""

    box_w = max(len(line) for line in lines) * 7 + 20
    box_h = len(lines) * 18 + 12
    s = _rect(x, y, box_w, box_h, "#e2e8f0", stroke="#94a3b8", sw=1)
    for i, line in enumerate(lines):
        s += _text(x + 10, y + 18 + i * 18, line, size=10, anchor="start", fill="#334155")
    return s
|
||||
|
||||
# ── Grid drawing ────────────────────────────────────────────────────
|
||||
|
||||
def _draw_grid(
    gx: float, gy: float, gw: float, gh: float,
    cells: list[dict],  # [{row, col, rspan, cspan, pe, label?, offset?}]
    rows: int, cols: int,
    cell_labels: bool = True,
) -> str:
    """Render shard cells as coloured rectangles inside a bordered grid.

    The grid area (gx, gy, gw, gh) is divided into ``rows`` x ``cols`` equal
    cells.  Each entry of *cells* is placed by its ``row``/``col``, sized by
    ``rspan``/``cspan`` (default 1), filled with the colour of its ``pe`` and
    captioned "PE<n>".  When ``cell_labels`` is true and the entry has a
    ``label`` key, that label is drawn underneath the caption.  A heavier
    outline around the whole grid is drawn last.
    """
    unit_w = gw / cols
    unit_h = gh / rows
    fragments = []
    for cell in cells:
        left = gx + cell["col"] * unit_w
        top = gy + cell["row"] * unit_h
        width = cell.get("cspan", 1) * unit_w
        height = cell.get("rspan", 1) * unit_h
        owner = cell["pe"]
        fragments.append(
            _rect(left, top, width, height, PE_COLORS[owner], stroke="#334155", sw=1.5)
        )

        # Caption sits slightly above the cell centre, the optional label
        # slightly below it.
        mid_x = left + width / 2
        mid_y = top + height / 2
        fragments.append(
            _text(mid_x, mid_y - 4, f"PE{owner}", size=12,
                  fill=PE_TEXT_COLORS[owner], weight="bold")
        )
        if cell_labels and "label" in cell:
            fragments.append(
                _text(mid_x, mid_y + 12, cell["label"], size=9,
                      fill=PE_TEXT_COLORS[owner])
            )

    # Outer border on top of the cells.
    fragments.append(_rect(gx, gy, gw, gh, "none", stroke="#1e293b", sw=2))
    return "".join(fragments)
|
||||
|
||||
|
||||
# ── Strategy-specific generators ────────────────────────────────────
|
||||
|
||||
def gen_column_wise() -> str:
    """Build the column_wise diagram: K is cut into NUM_PE equal slices.

    Returns the complete SVG document as a string.
    """
    width, height = 820, 500
    grid_x, grid_y, grid_w, grid_h = 80, 90, 480, 320
    slice_k = K // NUM_PE  # 64 columns per PE with the default parameters
    slice_bytes = M * slice_k * ITEMSIZE

    parts = [
        _svg_header(width, height, "Placement: column_wise"),
        _text(width // 2, 54,
              f"Tensor ({M}×{K}) fp16 → K axis split into {NUM_PE} parts",
              size=12, fill="#475569"),
        _axes(grid_x, grid_y, grid_w, grid_h),
    ]

    # One full-height cell per PE, laid out left to right.
    shards = [
        {
            "row": 0, "col": pe, "rspan": 1, "cspan": 1,
            "pe": pe,
            "label": f"({M}×{slice_k})",
        }
        for pe in range(NUM_PE)
    ]
    parts.append(_draw_grid(grid_x, grid_y, grid_w, grid_h, shards, rows=1, cols=NUM_PE))

    # Offset and size captions underneath each column.
    col_w = grid_w / NUM_PE
    for pe in range(NUM_PE):
        mid_x = grid_x + pe * col_w + col_w / 2
        start = pe * slice_bytes
        parts.append(_text(mid_x, grid_y + grid_h + 16, f"off={_format_bytes(start)}",
                           size=9, fill="#475569"))
        parts.append(_text(mid_x, grid_y + grid_h + 30, _format_bytes(slice_bytes),
                           size=9, fill="#64748b"))

    parts.append(_legend(620, 100))
    parts.append(_info_box(620, 320, [
        "Strategy: column_wise",
        "Split axis: K",
        f"Shards: {NUM_PE}",
        f"Each: ({M}, {slice_k})",
        f"Each: {_format_bytes(slice_bytes)}",
        f"Total: {_format_bytes(M * K * ITEMSIZE)}",
    ]))
    parts.append(_svg_footer())
    return "".join(parts)
|
||||
|
||||
|
||||
def gen_row_wise() -> str:
    """Build the row_wise diagram: M is cut into NUM_PE equal bands.

    Returns the complete SVG document as a string.
    """
    width, height = 820, 560
    grid_x, grid_y, grid_w, grid_h = 80, 90, 320, 400
    band_m = M // NUM_PE  # 128 rows per PE with the default parameters
    band_bytes = band_m * K * ITEMSIZE

    parts = [
        _svg_header(width, height, "Placement: row_wise"),
        _text(width // 2, 54,
              f"Tensor ({M}×{K}) fp16 → M axis split into {NUM_PE} parts",
              size=12, fill="#475569"),
        _axes(grid_x, grid_y, grid_w, grid_h),
    ]

    # One full-width cell per PE, stacked top to bottom.
    shards = [
        {
            "row": pe, "col": 0, "rspan": 1, "cspan": 1,
            "pe": pe,
            "label": f"({band_m}×{K})",
        }
        for pe in range(NUM_PE)
    ]
    parts.append(_draw_grid(grid_x, grid_y, grid_w, grid_h, shards, rows=NUM_PE, cols=1))

    # Offset and size captions to the right of each band.
    band_h = grid_h / NUM_PE
    for pe in range(NUM_PE):
        mid_y = grid_y + pe * band_h + band_h / 2
        start = pe * band_bytes
        parts.append(_text(grid_x + grid_w + 10, mid_y - 4, f"off={_format_bytes(start)}",
                           size=9, anchor="start", fill="#475569"))
        parts.append(_text(grid_x + grid_w + 10, mid_y + 10, _format_bytes(band_bytes),
                           size=9, anchor="start", fill="#64748b"))

    parts.append(_legend(580, 100))
    parts.append(_info_box(580, 320, [
        "Strategy: row_wise",
        "Split axis: M",
        f"Shards: {NUM_PE}",
        f"Each: ({band_m}, {K})",
        f"Each: {_format_bytes(band_bytes)}",
        f"Total: {_format_bytes(M * K * ITEMSIZE)}",
    ]))
    parts.append(_svg_footer())
    return "".join(parts)
|
||||
|
||||
|
||||
def gen_replicate() -> str:
    """Build the replicate diagram: every PE holds a full copy of the tensor.

    One box per PE is laid out in a grid four columns wide.  The number of
    rows is derived from ``NUM_PE`` (2 rows for the default 8 PEs), so the
    boxes always fit the 700x340 drawing area instead of overflowing it when
    the PE count changes.

    Returns:
        The complete SVG document as a string.
    """
    W, H = 820, 500
    s = _svg_header(W, H, "Placement: replicate")
    s += _text(W // 2, 54, f"Tensor ({M}×{K}) fp16 → full copy to each PE",
               size=12, fill="#475569")

    full_bytes = M * K * ITEMSIZE

    # Box layout: 4 columns, as many rows as needed (at least one so the
    # height computation never divides by zero).
    cols = 4
    rows = max(1, math.ceil(NUM_PE / cols))
    margin_x, margin_y = 60, 90
    gap = 16
    bw = (700 - (cols - 1) * gap) / cols
    bh = (340 - (rows - 1) * gap) / rows

    for i in range(NUM_PE):
        r, c = divmod(i, cols)
        bx = margin_x + c * (bw + gap)
        by = margin_y + r * (bh + gap)
        center_x = bx + bw / 2
        center_y = by + bh / 2
        s += _rect(bx, by, bw, bh, PE_COLORS[i], stroke="#334155", sw=1.5)
        s += _text(center_x, center_y - 14, f"PE{i}",
                   size=14, fill=PE_TEXT_COLORS[i], weight="bold")
        s += _text(center_x, center_y + 6, f"({M}×{K})",
                   size=11, fill=PE_TEXT_COLORS[i])
        s += _text(center_x, center_y + 22, f"{_format_bytes(full_bytes)}",
                   size=10, fill=PE_TEXT_COLORS[i])
        # Every replica is shown at offset 0.
        s += _text(center_x, center_y + 36, "offset=0",
                   size=9, fill=PE_TEXT_COLORS[i])

    s += _info_box(60, 450, [
        f"Strategy: replicate | Shards: {NUM_PE} | Each: {_format_bytes(full_bytes)}"
        f" | Total mem: {_format_bytes(full_bytes * NUM_PE)}",
    ])
    s += _svg_footer()
    return s
|
||||
|
||||
|
||||
def gen_tiled(column_major: bool) -> str:
    """Build a 2D tiled placement diagram.

    The tensor is cut into ``TILE_M`` x ``TILE_K`` tiles.  Tiles are numbered
    in traversal order and assigned to PEs round-robin (``idx % NUM_PE``).

    Args:
        column_major: If true, render ``tiled_column_major``: for each tile
            row, K advances fastest.  If false, render ``tiled_row_major``:
            for each tile column, M advances fastest.

    Returns:
        The complete SVG document as a string.
    """
    name = "tiled_column_major" if column_major else "tiled_row_major"
    order = "column-major (K first)" if column_major else "row-major (M first)"

    tiles_m = M // TILE_M  # 4 with the default parameters
    tiles_k = K // TILE_K  # 4
    total_tiles = tiles_m * tiles_k  # 16
    tile_bytes = TILE_M * TILE_K * ITEMSIZE
    row_bytes = K * ITEMSIZE

    W, H = 820, 620
    info_y = H - 60  # top edge of the summary box at the bottom
    s = _svg_header(W, H, f"Placement: {name}")
    s += _text(W // 2, 54,
               f"Tensor ({M}×{K}) fp16, tile=({TILE_M}×{TILE_K}) → "
               f"{tiles_m}×{tiles_k}={total_tiles} tiles, {order}",
               size=11, fill="#475569")

    gx, gy, gw, gh = 80, 90, 400, 400
    s += _axes(gx, gy, gw, gh)

    # Traversal order of (tile row, tile column) pairs; the position in this
    # list is the tile index.
    if column_major:
        visit = [(mi, ki) for mi in range(tiles_m) for ki in range(tiles_k)]
    else:
        visit = [(mi, ki) for ki in range(tiles_k) for mi in range(tiles_m)]

    cells = [
        {
            "row": mi, "col": ki, "rspan": 1, "cspan": 1,
            "pe": idx % NUM_PE,
            "label": f"t{idx}",
            # Offset of the tile's first element: mi * TILE_M whole rows plus
            # ki * TILE_K elements into the row.
            "offset": (mi * TILE_M * row_bytes) + (ki * TILE_K * ITEMSIZE),
            "idx": idx,
        }
        for idx, (mi, ki) in enumerate(visit)
    ]

    s += _draw_grid(gx, gy, gw, gh, cells, rows=tiles_m, cols=tiles_k)

    # K index range of each tile column, below the grid.
    cw = gw / tiles_k
    for ki in range(tiles_k):
        cx = gx + ki * cw + cw / 2
        s += _text(cx, gy + gh + 16, f"k={ki * TILE_K}..{(ki + 1) * TILE_K - 1}",
                   size=9, fill="#475569")

    # M index range of each tile row, left of the grid.
    ch = gh / tiles_m
    for mi in range(tiles_m):
        cy = gy + mi * ch + ch / 2
        s += _text(gx - 16, cy, f"m={mi * TILE_M}..{(mi + 1) * TILE_M - 1}",
                   size=9, anchor="end", fill="#475569")

    s += _legend(540, 90)

    # Assignment table.  It starts at y=290 so that all 16 default rows end
    # above the summary box, and the heading is start-anchored so it does not
    # spill leftwards over the grid (whose right edge is at gx + gw).
    table_x, table_y = 540, 290
    s += _text(table_x, table_y, "Tile Assignment Order", size=12,
               anchor="start", weight="bold")
    # `cells` is already in tile-index order.
    for i, c in enumerate(cells):
        ty = table_y + 18 + i * 16
        if ty + 2 > info_y:
            # Stop before a row (swatch bottom is ty + 2) would run into the
            # summary box.
            break
        pe = c["pe"]
        s += _rect(table_x, ty - 10, 12, 12, PE_COLORS[pe])
        s += _text(table_x + 18, ty,
                   f"t{c['idx']:>2d} → PE{pe} ({c['row']},{c['col']})"
                   f" off={_format_bytes(c['offset'])}",
                   size=9, anchor="start", fill="#334155")

    s += _info_box(80, info_y, [
        f"Strategy: {name} | Tile: ({TILE_M}×{TILE_K})={_format_bytes(tile_bytes)}"
        f" | Tiles: {total_tiles} | Total: {_format_bytes(M * K * ITEMSIZE)}",
    ])
    s += _svg_footer()
    return s
|
||||
|
||||
|
||||
# ── Main ────────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
    """Render every placement diagram and write it into ``OUT_DIR``.

    The output directory is created if needed; one line is printed per file
    written, followed by a summary count.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    rendered = [
        ("placement_column_wise.svg", gen_column_wise()),
        ("placement_row_wise.svg", gen_row_wise()),
        ("placement_replicate.svg", gen_replicate()),
        ("placement_tiled_column_major.svg", gen_tiled(column_major=True)),
        ("placement_tiled_row_major.svg", gen_tiled(column_major=False)),
    ]

    for filename, markup in rendered:
        target = OUT_DIR / filename
        target.write_text(markup, encoding="utf-8")
        print(f" wrote {target}")

    print(f"\nGenerated {len(rendered)} placement diagrams.")
|
||||
|
||||
|
||||
# Generate the diagrams only when run as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user