commit - release 1

This commit is contained in:
2026-03-18 11:47:48 -07:00
commit 6f43807900
109 changed files with 14909 additions and 0 deletions
+393
View File
@@ -0,0 +1,393 @@
#!/usr/bin/env python3
"""Generate SVG diagrams illustrating each placement strategy.
Example tensor: (M=1024, K=512) fp16 (itemsize=2), 8 PEs.
Tiled variants use tile_m=256, tile_k=128.
Output: docs/diagrams/placement_*.svg
"""
from __future__ import annotations
import math
from pathlib import Path
# ── Diagram parameters ──────────────────────────────────────────────
M, K = 1024, 512
ITEMSIZE = 2
NUM_PE = 8
TILE_M, TILE_K = 256, 128
PE_COLORS = [
"#3b82f6", # PE0 blue
"#10b981", # PE1 emerald
"#f59e0b", # PE2 amber
"#ef4444", # PE3 red
"#8b5cf6", # PE4 violet
"#ec4899", # PE5 pink
"#06b6d4", # PE6 cyan
"#f97316", # PE7 orange
]
PE_TEXT_COLORS = [
"#fff", "#fff", "#000", "#fff",
"#fff", "#fff", "#000", "#fff",
]
OUT_DIR = Path(__file__).parent.parent / "docs" / "diagrams"
# ── SVG helpers ─────────────────────────────────────────────────────
def _svg_header(w: int, h: int, title: str) -> str:
return (
f'<svg xmlns="http://www.w3.org/2000/svg" width="{w}" height="{h}"'
f' viewBox="0 0 {w} {h}" font-family="monospace">\n'
f'<rect width="{w}" height="{h}" fill="#f8fafc" rx="6"/>\n'
f'<text x="{w // 2}" y="32" text-anchor="middle" font-size="16"'
f' font-weight="bold" fill="#1e293b">{title}</text>\n'
)
def _svg_footer() -> str:
return "</svg>\n"
def _rect(x: float, y: float, w: float, h: float, fill: str,
stroke: str = "#334155", sw: float = 1.0, opacity: float = 1.0) -> str:
return (
f'<rect x="{x:.1f}" y="{y:.1f}" width="{w:.1f}" height="{h:.1f}"'
f' fill="{fill}" stroke="{stroke}" stroke-width="{sw}"'
f' fill-opacity="{opacity}" rx="2"/>\n'
)
def _text(x: float, y: float, txt: str, size: int = 11,
anchor: str = "middle", fill: str = "#1e293b",
weight: str = "normal") -> str:
return (
f'<text x="{x:.1f}" y="{y:.1f}" text-anchor="{anchor}"'
f' font-size="{size}" fill="{fill}" font-weight="{weight}">{txt}</text>\n'
)
def _line(x1: float, y1: float, x2: float, y2: float,
stroke: str = "#94a3b8", sw: float = 1) -> str:
return (
f'<line x1="{x1:.1f}" y1="{y1:.1f}" x2="{x2:.1f}" y2="{y2:.1f}"'
f' stroke="{stroke}" stroke-width="{sw}"/>\n'
)
def _format_bytes(n: int) -> str:
if n >= (1 << 20):
return f"{n >> 20} MB"
if n >= (1 << 10):
return f"{n >> 10} KB"
return f"{n} B"
def _legend(x: float, y0: float, num_pe: int = NUM_PE) -> str:
s = _text(x + 50, y0, "PE Legend", size=12, weight="bold")
for i in range(num_pe):
ly = y0 + 18 + i * 22
s += _rect(x, ly - 12, 16, 16, PE_COLORS[i])
s += _text(x + 22, ly, f"PE{i}", size=11, anchor="start")
return s
def _axes(gx: float, gy: float, gw: float, gh: float,
m_label: str = "M=1024", k_label: str = "K=512") -> str:
"""Draw axis labels and dimension arrows."""
s = ""
# K axis (horizontal) label above grid
s += _text(gx + gw / 2, gy - 8, f"{k_label}", size=11, fill="#475569")
# M axis (vertical) label left of grid
mx = gx - 12
my = gy + gh / 2
s += (
f'<text x="{mx:.1f}" y="{my:.1f}" text-anchor="middle"'
f' font-size="11" fill="#475569"'
f' transform="rotate(-90 {mx:.1f} {my:.1f})">↑ {m_label} ↓</text>\n'
)
return s
def _info_box(x: float, y: float, lines: list[str]) -> str:
"""Rounded info box with key/value lines."""
bw = max(len(l) for l in lines) * 7 + 20
bh = len(lines) * 18 + 12
s = _rect(x, y, bw, bh, "#e2e8f0", stroke="#94a3b8", sw=1)
for i, line in enumerate(lines):
s += _text(x + 10, y + 18 + i * 18, line, size=10, anchor="start", fill="#334155")
return s
# ── Grid drawing ────────────────────────────────────────────────────
def _draw_grid(
gx: float, gy: float, gw: float, gh: float,
cells: list[dict], # [{row, col, rspan, cspan, pe, label?, offset?}]
rows: int, cols: int,
cell_labels: bool = True,
) -> str:
"""Draw a grid of colored cells representing shard placement."""
cw = gw / cols
ch = gh / rows
s = ""
for c in cells:
cx = gx + c["col"] * cw
cy = gy + c["row"] * ch
w = c.get("cspan", 1) * cw
h = c.get("rspan", 1) * ch
pe = c["pe"]
s += _rect(cx, cy, w, h, PE_COLORS[pe], stroke="#334155", sw=1.5)
# PE label
lx = cx + w / 2
ly = cy + h / 2
s += _text(lx, ly - 4, f"PE{pe}", size=12,
fill=PE_TEXT_COLORS[pe], weight="bold")
if cell_labels and "label" in c:
s += _text(lx, ly + 12, c["label"], size=9,
fill=PE_TEXT_COLORS[pe])
# Grid border
s += _rect(gx, gy, gw, gh, "none", stroke="#1e293b", sw=2)
return s
# ── Strategy-specific generators ────────────────────────────────────
def gen_column_wise() -> str:
"""Column-wise: split K into 8 equal parts."""
W, H = 820, 500
s = _svg_header(W, H, "Placement: column_wise")
s += _text(W // 2, 54, f"Tensor ({M}×{K}) fp16 → K axis split into {NUM_PE} parts",
size=12, fill="#475569")
gx, gy, gw, gh = 80, 90, 480, 320
chunk_k = K // NUM_PE # 64
chunk_bytes = M * chunk_k * ITEMSIZE
s += _axes(gx, gy, gw, gh)
cells = []
for i in range(NUM_PE):
cells.append({
"row": 0, "col": i, "rspan": 1, "cspan": 1,
"pe": i,
"label": f"({M}×{chunk_k})",
})
s += _draw_grid(gx, gy, gw, gh, cells, rows=1, cols=NUM_PE)
# Column dimension labels
cw = gw / NUM_PE
for i in range(NUM_PE):
cx = gx + i * cw + cw / 2
off = i * chunk_bytes
s += _text(cx, gy + gh + 16, f"off={_format_bytes(off)}", size=9, fill="#475569")
s += _text(cx, gy + gh + 30, f"{_format_bytes(chunk_bytes)}", size=9, fill="#64748b")
s += _legend(620, 100)
s += _info_box(620, 320, [
f"Strategy: column_wise",
f"Split axis: K",
f"Shards: {NUM_PE}",
f"Each: ({M}, {chunk_k})",
f"Each: {_format_bytes(chunk_bytes)}",
f"Total: {_format_bytes(M * K * ITEMSIZE)}",
])
s += _svg_footer()
return s
def gen_row_wise() -> str:
"""Row-wise: split M into 8 equal parts."""
W, H = 820, 560
s = _svg_header(W, H, "Placement: row_wise")
s += _text(W // 2, 54, f"Tensor ({M}×{K}) fp16 → M axis split into {NUM_PE} parts",
size=12, fill="#475569")
gx, gy, gw, gh = 80, 90, 320, 400
chunk_m = M // NUM_PE # 128
chunk_bytes = chunk_m * K * ITEMSIZE
s += _axes(gx, gy, gw, gh)
cells = []
for i in range(NUM_PE):
cells.append({
"row": i, "col": 0, "rspan": 1, "cspan": 1,
"pe": i,
"label": f"({chunk_m}×{K})",
})
s += _draw_grid(gx, gy, gw, gh, cells, rows=NUM_PE, cols=1)
# Row dimension labels
ch = gh / NUM_PE
for i in range(NUM_PE):
cy = gy + i * ch + ch / 2
off = i * chunk_bytes
s += _text(gx + gw + 10, cy - 4, f"off={_format_bytes(off)}",
size=9, anchor="start", fill="#475569")
s += _text(gx + gw + 10, cy + 10, f"{_format_bytes(chunk_bytes)}",
size=9, anchor="start", fill="#64748b")
s += _legend(580, 100)
s += _info_box(580, 320, [
f"Strategy: row_wise",
f"Split axis: M",
f"Shards: {NUM_PE}",
f"Each: ({chunk_m}, {K})",
f"Each: {_format_bytes(chunk_bytes)}",
f"Total: {_format_bytes(M * K * ITEMSIZE)}",
])
s += _svg_footer()
return s
def gen_replicate() -> str:
"""Replicate: full copy per PE."""
W, H = 820, 500
s = _svg_header(W, H, "Placement: replicate")
s += _text(W // 2, 54, f"Tensor ({M}×{K}) fp16 → full copy to each PE",
size=12, fill="#475569")
full_bytes = M * K * ITEMSIZE
# Show 8 small copies in 2 rows × 4 cols
cols, rows = 4, 2
margin_x, margin_y = 60, 90
gap = 16
bw = (700 - (cols - 1) * gap) / cols
bh = (340 - (rows - 1) * gap) / rows
for i in range(NUM_PE):
r = i // cols
c = i % cols
bx = margin_x + c * (bw + gap)
by = margin_y + r * (bh + gap)
s += _rect(bx, by, bw, bh, PE_COLORS[i], stroke="#334155", sw=1.5)
s += _text(bx + bw / 2, by + bh / 2 - 14, f"PE{i}",
size=14, fill=PE_TEXT_COLORS[i], weight="bold")
s += _text(bx + bw / 2, by + bh / 2 + 6, f"({M}×{K})",
size=11, fill=PE_TEXT_COLORS[i])
s += _text(bx + bw / 2, by + bh / 2 + 22, f"{_format_bytes(full_bytes)}",
size=10, fill=PE_TEXT_COLORS[i])
s += _text(bx + bw / 2, by + bh / 2 + 36, "offset=0",
size=9, fill=PE_TEXT_COLORS[i])
s += _info_box(60, 450, [
f"Strategy: replicate | Shards: {NUM_PE} | Each: {_format_bytes(full_bytes)}"
f" | Total mem: {_format_bytes(full_bytes * NUM_PE)}",
])
s += _svg_footer()
return s
def gen_tiled(column_major: bool) -> str:
"""2D tiled placement. column_major=True → tiled_column_major."""
name = "tiled_column_major" if column_major else "tiled_row_major"
order = "column-major (K first)" if column_major else "row-major (M first)"
tiles_m = M // TILE_M # 4
tiles_k = K // TILE_K # 4
total_tiles = tiles_m * tiles_k # 16
tile_bytes = TILE_M * TILE_K * ITEMSIZE
W, H = 820, 620
s = _svg_header(W, H, f"Placement: {name}")
s += _text(W // 2, 54,
f"Tensor ({M}×{K}) fp16, tile=({TILE_M}×{TILE_K}) → "
f"{tiles_m}×{tiles_k}={total_tiles} tiles, {order}",
size=11, fill="#475569")
gx, gy, gw, gh = 80, 90, 400, 400
s += _axes(gx, gy, gw, gh)
# Build tile → PE mapping
cells = []
idx = 0
if column_major:
# iterate M first (rows), then K (cols) — but column-major means
# we traverse in the order that fills columns first
# Actually: column-major = K axis first within each M row
# The implementation iterates: for mi in tiles_m: for ki in tiles_k
for mi in range(tiles_m):
for ki in range(tiles_k):
pe = idx % NUM_PE
row_bytes = K * ITEMSIZE
offset = (mi * TILE_M * row_bytes) + (ki * TILE_K * ITEMSIZE)
cells.append({
"row": mi, "col": ki, "rspan": 1, "cspan": 1,
"pe": pe,
"label": f"t{idx}",
"offset": offset,
"idx": idx,
})
idx += 1
else:
# row-major: iterate K first (cols), then M (rows)
for ki in range(tiles_k):
for mi in range(tiles_m):
pe = idx % NUM_PE
row_bytes = K * ITEMSIZE
offset = (mi * TILE_M * row_bytes) + (ki * TILE_K * ITEMSIZE)
cells.append({
"row": mi, "col": ki, "rspan": 1, "cspan": 1,
"pe": pe,
"label": f"t{idx}",
"offset": offset,
"idx": idx,
})
idx += 1
s += _draw_grid(gx, gy, gw, gh, cells, rows=tiles_m, cols=tiles_k)
# Tile dimension labels on top
cw = gw / tiles_k
for ki in range(tiles_k):
cx = gx + ki * cw + cw / 2
s += _text(cx, gy + gh + 16, f"k={ki * TILE_K}..{(ki + 1) * TILE_K - 1}",
size=9, fill="#475569")
# Tile dimension labels on left
ch = gh / tiles_m
for mi in range(tiles_m):
cy = gy + mi * ch + ch / 2
s += _text(gx - 16, cy, f"m={mi * TILE_M}..{(mi + 1) * TILE_M - 1}",
size=9, anchor="end", fill="#475569")
s += _legend(540, 90)
# Assignment table
table_y = 310
s += _text(540, table_y, "Tile Assignment Order", size=12, weight="bold")
# Sort cells by idx for table
sorted_cells = sorted(cells, key=lambda c: c["idx"])
for i, c in enumerate(sorted_cells):
ty = table_y + 18 + i * 16
if ty > H - 20:
break
pe = c["pe"]
s += _rect(540, ty - 10, 12, 12, PE_COLORS[pe])
s += _text(558, ty,
f"t{c['idx']:>2d} → PE{pe} ({c['row']},{c['col']})"
f" off={_format_bytes(c['offset'])}",
size=9, anchor="start", fill="#334155")
s += _info_box(80, H - 60, [
f"Strategy: {name} | Tile: ({TILE_M}×{TILE_K})={_format_bytes(tile_bytes)}"
f" | Tiles: {total_tiles} | Total: {_format_bytes(M * K * ITEMSIZE)}",
])
s += _svg_footer()
return s
# ── Main ────────────────────────────────────────────────────────────
def main() -> None:
OUT_DIR.mkdir(parents=True, exist_ok=True)
diagrams = {
"placement_column_wise.svg": gen_column_wise(),
"placement_row_wise.svg": gen_row_wise(),
"placement_replicate.svg": gen_replicate(),
"placement_tiled_column_major.svg": gen_tiled(column_major=True),
"placement_tiled_row_major.svg": gen_tiled(column_major=False),
}
for name, svg in diagrams.items():
path = OUT_DIR / name
path.write_text(svg, encoding="utf-8")
print(f" wrote {path}")
print(f"\nGenerated {len(diagrams)} placement diagrams.")
if __name__ == "__main__":
main()