diff --git a/benches/matmul_composite.py b/benches/matmul_composite.py
new file mode 100644
index 0000000..0365d6e
--- /dev/null
+++ b/benches/matmul_composite.py
@@ -0,0 +1,85 @@
+r"""Single-PE composite GEMM for PE_accelerator perf characterization.
+
+Three operand-staging variants are selectable via MATMUL_VARIANT:
+
+  - "ref_ref" (default): a = tl.ref, b = tl.ref
+      Both operands HBM-resident; scheduler streams per-tile DMA.
+  - "load_ref": a = tl.load, b = tl.ref
+      A eagerly DMA'd into TCM up-front; B streamed per-tile.
+  - "load_load": a = tl.load, b = tl.load
+      Both eagerly DMA'd into TCM up-front.
+
+Other env vars: MATMUL_M, MATMUL_K, MATMUL_N, MATMUL_DTYPE.
+
+Run:
+    MATMUL_M=256 MATMUL_K=256 MATMUL_N=256 MATMUL_VARIANT=load_ref \
+        kernbench run --topology topology.yaml --bench matmul_composite
+"""
+import os
+
+from kernbench.policy.placement.dp import DPPolicy
+
+# Benchmark configuration, read from the environment once at import time.
+M = int(os.environ.get("MATMUL_M", "256"))
+K = int(os.environ.get("MATMUL_K", "256"))
+N = int(os.environ.get("MATMUL_N", "256"))
+DTYPE = os.environ.get("MATMUL_DTYPE", "f16")
+VARIANT = os.environ.get("MATMUL_VARIANT", "ref_ref")
+
+
+def _kernel_ref_ref(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
+    """Variant "ref_ref": A and B both via tl.ref; one composite GEMM."""
+    M, K, N = int(M), int(K), int(N)
+    a = tl.ref(int(a_ptr), shape=(M, K), dtype=DTYPE)
+    b = tl.ref(int(b_ptr), shape=(K, N), dtype=DTYPE)
+    h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
+    tl.wait(h)
+
+
+def _kernel_load_ref(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
+    """Variant "load_ref": A via tl.load, B via tl.ref; one composite GEMM."""
+    M, K, N = int(M), int(K), int(N)
+    a = tl.load(int(a_ptr), shape=(M, K), dtype=DTYPE)
+    b = tl.ref(int(b_ptr), shape=(K, N), dtype=DTYPE)
+    h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
+    tl.wait(h)
+
+
+def _kernel_load_load(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
+    """Variant "load_load": A and B both via tl.load; one composite GEMM."""
+    M, K, N = int(M), int(K), int(N)
+    a = tl.load(int(a_ptr), shape=(M, K), dtype=DTYPE)
+    b = tl.load(int(b_ptr), shape=(K, N), dtype=DTYPE)
+    h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
+    tl.wait(h)
+
+
+# Maps each accepted MATMUL_VARIANT value to its kernel function.
+_KERNELS = {
+    "ref_ref": _kernel_ref_ref,
+    "load_ref": _kernel_load_ref,
+    "load_load": _kernel_load_load,
+}
+
+
+def run(torch):
+    """Allocate a, b and out, then launch the kernel chosen by MATMUL_VARIANT.
+
+    All three tensors share one DPPolicy built with num_cubes=1, num_pes=1.
+    Only ``torch.empty`` and ``torch.launch`` are used on the passed-in object.
+
+    Raises:
+        ValueError: MATMUL_VARIANT is not one of the keys of ``_KERNELS``.
+    """
+    if VARIANT not in _KERNELS:
+        raise ValueError(f"unknown MATMUL_VARIANT={VARIANT!r}; "
+                         f"expected one of {list(_KERNELS)}")
+    kernel_fn = _KERNELS[VARIANT]
+    dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)
+    a = torch.empty((M, K), dtype=DTYPE, dp=dp, name="a")
+    b = torch.empty((K, N), dtype=DTYPE, dp=dp, name="b")
+    out = torch.empty((M, N), dtype=DTYPE, dp=dp, name="out")
+    # NOTE(review): DTYPE is not among the launch arguments, so each kernel
+    # falls back to its own DTYPE="f16" default unless the launcher supplies
+    # it -- confirm this is intended when MATMUL_DTYPE is not "f16".
+    torch.launch(f"matmul_composite_{VARIANT}", kernel_fn, a, b, out, M, K, N)
diff --git a/docs/diagrams/gemm_sweep.json b/docs/diagrams/gemm_sweep.json new file mode 100644 index 0000000..9ab6fbd --- /dev/null +++ b/docs/diagrams/gemm_sweep.json @@ -0,0 +1,1612 @@ +{ + "tile_sizes": { + "M": 32, + "K": 64, + "N": 32 + }, + "engines": [ + "pe_dma", + "pe_fetch_store", + "pe_gemm", + "pe_math" + ], + "stages": [ + "DMA_READ", + "DMA_WRITE", + "FETCH", + "STORE", + "GEMM", + "MATH" + ], + "variants": [ + "ref_ref", + "load_ref", + "load_load" + ], + "rows": [ + { + "M": 32, + "K": 32, + "N": 32, + "variant": "ref_ref", + "flops": 65536, + "bytes_hbm": 6144, + "arith_intensity": 10.666666666666666, + "tile_count_expected": 1, + "sim_wall_clock_s": 0.569, + "engines": { + "pe_dma": { + "occupancy_ns": 52.0, + "wall_ns": 52.0, + "record_count": 3 + }, + "pe_fetch_store": { + "occupancy_ns": 20.0, + "wall_ns": 20.0, + "record_count": 2 + }, + "pe_gemm": { + "occupancy_ns": 16.384000000000015, + "wall_ns": 16.384000000000015, + "record_count": 1 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 40.0, + "wall_ns": 40.0, + "record_count": 2 + }, + "DMA_WRITE": { + "occupancy_ns": 12.0, + "wall_ns": 12.0, + "record_count": 1 + }, + "FETCH": { + "occupancy_ns": 16.0, + "wall_ns": 16.0, + "record_count": 1 + }, + "STORE": { + "occupancy_ns": 4.0, + "wall_ns": 4.0, 
+ "record_count": 1 + }, + "GEMM": { + "occupancy_ns": 16.384000000000015, + "wall_ns": 16.384000000000015, + "record_count": 1 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 88.38400000000001 + }, + { + "M": 32, + "K": 32, + "N": 32, + "variant": "load_ref", + "flops": 65536, + "bytes_hbm": 6144, + "arith_intensity": 10.666666666666666, + "tile_count_expected": 1, + "sim_wall_clock_s": 0.409, + "engines": { + "pe_dma": { + "occupancy_ns": 44.5, + "wall_ns": 44.5, + "record_count": 3 + }, + "pe_fetch_store": { + "occupancy_ns": 20.0, + "wall_ns": 20.0, + "record_count": 2 + }, + "pe_gemm": { + "occupancy_ns": 16.384000000000015, + "wall_ns": 16.384000000000015, + "record_count": 1 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 20.0, + "wall_ns": 20.0, + "record_count": 1 + }, + "DMA_WRITE": { + "occupancy_ns": 12.0, + "wall_ns": 12.0, + "record_count": 1 + }, + "FETCH": { + "occupancy_ns": 16.0, + "wall_ns": 16.0, + "record_count": 1 + }, + "STORE": { + "occupancy_ns": 4.0, + "wall_ns": 4.0, + "record_count": 1 + }, + "GEMM": { + "occupancy_ns": 16.384000000000015, + "wall_ns": 16.384000000000015, + "record_count": 1 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 81.894 + }, + { + "M": 32, + "K": 32, + "N": 32, + "variant": "load_load", + "flops": 65536, + "bytes_hbm": 6144, + "arith_intensity": 10.666666666666666, + "tile_count_expected": 1, + "sim_wall_clock_s": 0.567, + "engines": { + "pe_dma": { + "occupancy_ns": 37.0, + "wall_ns": 37.0, + "record_count": 3 + }, + "pe_fetch_store": { + "occupancy_ns": 20.0, + "wall_ns": 20.0, + "record_count": 2 + }, + "pe_gemm": { + "occupancy_ns": 16.384000000000015, + "wall_ns": 16.384000000000015, + "record_count": 1 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + 
"DMA_READ": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + }, + "DMA_WRITE": { + "occupancy_ns": 12.0, + "wall_ns": 12.0, + "record_count": 1 + }, + "FETCH": { + "occupancy_ns": 16.0, + "wall_ns": 16.0, + "record_count": 1 + }, + "STORE": { + "occupancy_ns": 4.0, + "wall_ns": 4.0, + "record_count": 1 + }, + "GEMM": { + "occupancy_ns": 16.384000000000015, + "wall_ns": 16.384000000000015, + "record_count": 1 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 75.404 + }, + { + "M": 32, + "K": 64, + "N": 32, + "variant": "ref_ref", + "flops": 131072, + "bytes_hbm": 10240, + "arith_intensity": 12.8, + "tile_count_expected": 1, + "sim_wall_clock_s": 0.838, + "engines": { + "pe_dma": { + "occupancy_ns": 52.0, + "wall_ns": 52.0, + "record_count": 3 + }, + "pe_fetch_store": { + "occupancy_ns": 20.0, + "wall_ns": 20.0, + "record_count": 2 + }, + "pe_gemm": { + "occupancy_ns": 16.384000000000015, + "wall_ns": 16.384000000000015, + "record_count": 1 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 40.0, + "wall_ns": 40.0, + "record_count": 2 + }, + "DMA_WRITE": { + "occupancy_ns": 12.0, + "wall_ns": 12.0, + "record_count": 1 + }, + "FETCH": { + "occupancy_ns": 16.0, + "wall_ns": 16.0, + "record_count": 1 + }, + "STORE": { + "occupancy_ns": 4.0, + "wall_ns": 4.0, + "record_count": 1 + }, + "GEMM": { + "occupancy_ns": 16.384000000000015, + "wall_ns": 16.384000000000015, + "record_count": 1 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 88.38400000000001 + }, + { + "M": 32, + "K": 64, + "N": 32, + "variant": "load_ref", + "flops": 131072, + "bytes_hbm": 10240, + "arith_intensity": 12.8, + "tile_count_expected": 1, + "sim_wall_clock_s": 1.097, + "engines": { + "pe_dma": { + "occupancy_ns": 52.5, + "wall_ns": 52.5, + "record_count": 3 + }, + "pe_fetch_store": { + 
"occupancy_ns": 20.0, + "wall_ns": 20.0, + "record_count": 2 + }, + "pe_gemm": { + "occupancy_ns": 16.384000000000015, + "wall_ns": 16.384000000000015, + "record_count": 1 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 20.0, + "wall_ns": 20.0, + "record_count": 1 + }, + "DMA_WRITE": { + "occupancy_ns": 12.0, + "wall_ns": 12.0, + "record_count": 1 + }, + "FETCH": { + "occupancy_ns": 16.0, + "wall_ns": 16.0, + "record_count": 1 + }, + "STORE": { + "occupancy_ns": 4.0, + "wall_ns": 4.0, + "record_count": 1 + }, + "GEMM": { + "occupancy_ns": 16.384000000000015, + "wall_ns": 16.384000000000015, + "record_count": 1 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 89.894 + }, + { + "M": 32, + "K": 64, + "N": 32, + "variant": "load_load", + "flops": 131072, + "bytes_hbm": 10240, + "arith_intensity": 12.8, + "tile_count_expected": 1, + "sim_wall_clock_s": 1.264, + "engines": { + "pe_dma": { + "occupancy_ns": 53.0, + "wall_ns": 53.0, + "record_count": 3 + }, + "pe_fetch_store": { + "occupancy_ns": 20.0, + "wall_ns": 20.0, + "record_count": 2 + }, + "pe_gemm": { + "occupancy_ns": 16.384000000000015, + "wall_ns": 16.384000000000015, + "record_count": 1 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + }, + "DMA_WRITE": { + "occupancy_ns": 12.0, + "wall_ns": 12.0, + "record_count": 1 + }, + "FETCH": { + "occupancy_ns": 16.0, + "wall_ns": 16.0, + "record_count": 1 + }, + "STORE": { + "occupancy_ns": 4.0, + "wall_ns": 4.0, + "record_count": 1 + }, + "GEMM": { + "occupancy_ns": 16.384000000000015, + "wall_ns": 16.384000000000015, + "record_count": 1 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 91.404 + }, + { + "M": 32, + "K": 128, + "N": 32, + "variant": 
"ref_ref", + "flops": 262144, + "bytes_hbm": 18432, + "arith_intensity": 14.222222222222221, + "tile_count_expected": 2, + "sim_wall_clock_s": 1.187, + "engines": { + "pe_dma": { + "occupancy_ns": 131.995, + "wall_ns": 80.0, + "record_count": 5 + }, + "pe_fetch_store": { + "occupancy_ns": 36.0, + "wall_ns": 36.0, + "record_count": 3 + }, + "pe_gemm": { + "occupancy_ns": 33.152000000000044, + "wall_ns": 32.76800000000003, + "record_count": 2 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 119.995, + "wall_ns": 68.0, + "record_count": 4 + }, + "DMA_WRITE": { + "occupancy_ns": 12.0, + "wall_ns": 12.0, + "record_count": 1 + }, + "FETCH": { + "occupancy_ns": 32.0, + "wall_ns": 32.0, + "record_count": 2 + }, + "STORE": { + "occupancy_ns": 4.0, + "wall_ns": 4.0, + "record_count": 1 + }, + "GEMM": { + "occupancy_ns": 33.152000000000044, + "wall_ns": 32.76800000000003, + "record_count": 2 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 100.76800000000003 + }, + { + "M": 32, + "K": 128, + "N": 32, + "variant": "load_ref", + "flops": 262144, + "bytes_hbm": 18432, + "arith_intensity": 14.222222222222221, + "tile_count_expected": 2, + "sim_wall_clock_s": 1.13, + "engines": { + "pe_dma": { + "occupancy_ns": 104.495, + "wall_ns": 84.5, + "record_count": 4 + }, + "pe_fetch_store": { + "occupancy_ns": 36.0, + "wall_ns": 36.0, + "record_count": 3 + }, + "pe_gemm": { + "occupancy_ns": 33.152000000000044, + "wall_ns": 32.76800000000003, + "record_count": 2 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 55.995000000000005, + "wall_ns": 36.0, + "record_count": 2 + }, + "DMA_WRITE": { + "occupancy_ns": 12.0, + "wall_ns": 12.0, + "record_count": 1 + }, + "FETCH": { + "occupancy_ns": 32.0, + "wall_ns": 32.0, + "record_count": 2 + }, + "STORE": { + 
"occupancy_ns": 4.0, + "wall_ns": 4.0, + "record_count": 1 + }, + "GEMM": { + "occupancy_ns": 33.152000000000044, + "wall_ns": 32.76800000000003, + "record_count": 2 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 106.27300000000002 + }, + { + "M": 32, + "K": 128, + "N": 32, + "variant": "load_load", + "flops": 262144, + "bytes_hbm": 18432, + "arith_intensity": 14.222222222222221, + "tile_count_expected": 2, + "sim_wall_clock_s": 1.113, + "engines": { + "pe_dma": { + "occupancy_ns": 85.0, + "wall_ns": 85.0, + "record_count": 3 + }, + "pe_fetch_store": { + "occupancy_ns": 51.995000000000005, + "wall_ns": 36.0, + "record_count": 3 + }, + "pe_gemm": { + "occupancy_ns": 33.152000000000044, + "wall_ns": 32.76800000000003, + "record_count": 2 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + }, + "DMA_WRITE": { + "occupancy_ns": 12.0, + "wall_ns": 12.0, + "record_count": 1 + }, + "FETCH": { + "occupancy_ns": 47.995000000000005, + "wall_ns": 32.0, + "record_count": 2 + }, + "STORE": { + "occupancy_ns": 4.0, + "wall_ns": 4.0, + "record_count": 1 + }, + "GEMM": { + "occupancy_ns": 33.152000000000044, + "wall_ns": 32.76800000000003, + "record_count": 2 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 123.78300000000002 + }, + { + "M": 32, + "K": 128, + "N": 128, + "variant": "ref_ref", + "flops": 1048576, + "bytes_hbm": 49152, + "arith_intensity": 21.333333333333332, + "tile_count_expected": 8, + "sim_wall_clock_s": 1.451, + "engines": { + "pe_dma": { + "occupancy_ns": 1687.995, + "wall_ns": 272.0, + "record_count": 20 + }, + "pe_fetch_store": { + "occupancy_ns": 201.6959999999999, + "wall_ns": 132.0, + "record_count": 12 + }, + "pe_gemm": { + "occupancy_ns": 136.0640000000003, + "wall_ns": 131.07200000000012, + "record_count": 8 + }, + 
"pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 1631.995, + "wall_ns": 260.0, + "record_count": 16 + }, + "DMA_WRITE": { + "occupancy_ns": 56.0, + "wall_ns": 40.0, + "record_count": 4 + }, + "FETCH": { + "occupancy_ns": 148.0, + "wall_ns": 132.0, + "record_count": 8 + }, + "STORE": { + "occupancy_ns": 53.69599999999991, + "wall_ns": 47.23199999999997, + "record_count": 4 + }, + "GEMM": { + "occupancy_ns": 136.0640000000003, + "wall_ns": 131.07200000000012, + "record_count": 8 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 297.9200000000001 + }, + { + "M": 32, + "K": 128, + "N": 128, + "variant": "load_ref", + "flops": 1048576, + "bytes_hbm": 49152, + "arith_intensity": 21.333333333333332, + "tile_count_expected": 8, + "sim_wall_clock_s": 1.269, + "engines": { + "pe_dma": { + "occupancy_ns": 700.495, + "wall_ns": 180.5, + "record_count": 13 + }, + "pe_fetch_store": { + "occupancy_ns": 201.6959999999999, + "wall_ns": 132.0, + "record_count": 12 + }, + "pe_gemm": { + "occupancy_ns": 136.0640000000003, + "wall_ns": 131.07200000000012, + "record_count": 8 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 607.995, + "wall_ns": 132.0, + "record_count": 8 + }, + "DMA_WRITE": { + "occupancy_ns": 56.0, + "wall_ns": 40.0, + "record_count": 4 + }, + "FETCH": { + "occupancy_ns": 148.0, + "wall_ns": 132.0, + "record_count": 8 + }, + "STORE": { + "occupancy_ns": 53.69599999999991, + "wall_ns": 47.23199999999997, + "record_count": 4 + }, + "GEMM": { + "occupancy_ns": 136.0640000000003, + "wall_ns": 131.07200000000012, + "record_count": 8 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 207.42500000000007 + }, + { + "M": 32, + "K": 128, + "N": 128, + "variant": "load_load", + "flops": 1048576, + "bytes_hbm": 
49152, + "arith_intensity": 21.333333333333332, + "tile_count_expected": 8, + "sim_wall_clock_s": 1.225, + "engines": { + "pe_dma": { + "occupancy_ns": 217.0, + "wall_ns": 217.0, + "record_count": 6 + }, + "pe_fetch_store": { + "occupancy_ns": 591.995, + "wall_ns": 128.0, + "record_count": 12 + }, + "pe_gemm": { + "occupancy_ns": 141.82400000000052, + "wall_ns": 131.07200000000012, + "record_count": 8 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + }, + "DMA_WRITE": { + "occupancy_ns": 48.0, + "wall_ns": 48.0, + "record_count": 4 + }, + "FETCH": { + "occupancy_ns": 575.995, + "wall_ns": 128.0, + "record_count": 8 + }, + "STORE": { + "occupancy_ns": 16.0, + "wall_ns": 16.0, + "record_count": 4 + }, + "GEMM": { + "occupancy_ns": 141.82400000000052, + "wall_ns": 131.07200000000012, + "record_count": 8 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 318.0870000000001 + }, + { + "M": 32, + "K": 3072, + "N": 32, + "variant": "ref_ref", + "flops": 6291456, + "bytes_hbm": 395264, + "arith_intensity": 15.917098445595855, + "tile_count_expected": 48, + "sim_wall_clock_s": 2.724, + "engines": { + "pe_dma": { + "occupancy_ns": 55883.995, + "wall_ns": 1552.0, + "record_count": 97 + }, + "pe_fetch_store": { + "occupancy_ns": 791.1039999999994, + "wall_ns": 772.0, + "record_count": 49 + }, + "pe_gemm": { + "occupancy_ns": 1215.584000000017, + "wall_ns": 786.4320000000007, + "record_count": 48 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 55871.995, + "wall_ns": 1540.0, + "record_count": 96 + }, + "DMA_WRITE": { + "occupancy_ns": 12.0, + "wall_ns": 12.0, + "record_count": 1 + }, + "FETCH": { + "occupancy_ns": 772.0, + "wall_ns": 772.0, + "record_count": 48 + }, + "STORE": { + "occupancy_ns": 19.10399999999936, 
+ "wall_ns": 19.10399999999936, + "record_count": 1 + }, + "GEMM": { + "occupancy_ns": 1215.584000000017, + "wall_ns": 786.4320000000007, + "record_count": 48 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 1590.4320000000007 + }, + { + "M": 32, + "K": 3072, + "N": 32, + "variant": "load_ref", + "flops": 6291456, + "bytes_hbm": 395264, + "arith_intensity": 15.917098445595855, + "tile_count_expected": 48, + "sim_wall_clock_s": 2.137, + "engines": { + "pe_dma": { + "occupancy_ns": 19792.495, + "wall_ns": 1556.5, + "record_count": 50 + }, + "pe_fetch_store": { + "occupancy_ns": 791.1039999999994, + "wall_ns": 772.0, + "record_count": 49 + }, + "pe_gemm": { + "occupancy_ns": 1215.584000000017, + "wall_ns": 786.4320000000007, + "record_count": 48 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 19007.995, + "wall_ns": 772.0, + "record_count": 48 + }, + "DMA_WRITE": { + "occupancy_ns": 12.0, + "wall_ns": 12.0, + "record_count": 1 + }, + "FETCH": { + "occupancy_ns": 772.0, + "wall_ns": 772.0, + "record_count": 48 + }, + "STORE": { + "occupancy_ns": 19.10399999999936, + "wall_ns": 19.10399999999936, + "record_count": 1 + }, + "GEMM": { + "occupancy_ns": 1215.584000000017, + "wall_ns": 786.4320000000007, + "record_count": 48 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 1595.9370000000008 + }, + { + "M": 32, + "K": 3072, + "N": 32, + "variant": "load_load", + "flops": 6291456, + "bytes_hbm": 395264, + "arith_intensity": 15.917098445595855, + "tile_count_expected": 48, + "sim_wall_clock_s": 1.245, + "engines": { + "pe_dma": { + "occupancy_ns": 1557.0, + "wall_ns": 1557.0, + "record_count": 3 + }, + "pe_fetch_store": { + "occupancy_ns": 18819.99500000001, + "wall_ns": 772.0000000000002, + "record_count": 49 + }, + "pe_gemm": { + "occupancy_ns": 1219.5839999999987, + "wall_ns": 
786.4320000000005, + "record_count": 48 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + }, + "DMA_WRITE": { + "occupancy_ns": 12.0, + "wall_ns": 12.0, + "record_count": 1 + }, + "FETCH": { + "occupancy_ns": 18815.99500000001, + "wall_ns": 768.0000000000002, + "record_count": 48 + }, + "STORE": { + "occupancy_ns": 4.0, + "wall_ns": 4.0, + "record_count": 1 + }, + "GEMM": { + "occupancy_ns": 1219.5839999999987, + "wall_ns": 786.4320000000005, + "record_count": 48 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 2349.447000000001 + }, + { + "M": 8, + "K": 128, + "N": 128, + "variant": "ref_ref", + "flops": 262144, + "bytes_hbm": 36864, + "arith_intensity": 7.111111111111111, + "tile_count_expected": 8, + "sim_wall_clock_s": 1.477, + "engines": { + "pe_dma": { + "occupancy_ns": 1687.995, + "wall_ns": 272.0, + "record_count": 20 + }, + "pe_fetch_store": { + "occupancy_ns": 201.6959999999999, + "wall_ns": 132.0, + "record_count": 12 + }, + "pe_gemm": { + "occupancy_ns": 136.0640000000003, + "wall_ns": 131.07200000000012, + "record_count": 8 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 1631.995, + "wall_ns": 260.0, + "record_count": 16 + }, + "DMA_WRITE": { + "occupancy_ns": 56.0, + "wall_ns": 40.0, + "record_count": 4 + }, + "FETCH": { + "occupancy_ns": 148.0, + "wall_ns": 132.0, + "record_count": 8 + }, + "STORE": { + "occupancy_ns": 53.69599999999991, + "wall_ns": 47.23199999999997, + "record_count": 4 + }, + "GEMM": { + "occupancy_ns": 136.0640000000003, + "wall_ns": 131.07200000000012, + "record_count": 8 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 297.9200000000001 + }, + { + "M": 8, + "K": 128, + "N": 128, + "variant": "load_ref", + 
"flops": 262144, + "bytes_hbm": 36864, + "arith_intensity": 7.111111111111111, + "tile_count_expected": 8, + "sim_wall_clock_s": 1.443, + "engines": { + "pe_dma": { + "occupancy_ns": 676.495, + "wall_ns": 156.5, + "record_count": 13 + }, + "pe_fetch_store": { + "occupancy_ns": 201.6959999999999, + "wall_ns": 132.0, + "record_count": 12 + }, + "pe_gemm": { + "occupancy_ns": 136.0640000000003, + "wall_ns": 131.07200000000012, + "record_count": 8 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 607.995, + "wall_ns": 132.0, + "record_count": 8 + }, + "DMA_WRITE": { + "occupancy_ns": 56.0, + "wall_ns": 40.0, + "record_count": 4 + }, + "FETCH": { + "occupancy_ns": 148.0, + "wall_ns": 132.0, + "record_count": 8 + }, + "STORE": { + "occupancy_ns": 53.69599999999991, + "wall_ns": 47.23199999999997, + "record_count": 4 + }, + "GEMM": { + "occupancy_ns": 136.0640000000003, + "wall_ns": 131.07200000000012, + "record_count": 8 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 183.42500000000007 + }, + { + "M": 8, + "K": 128, + "N": 128, + "variant": "load_load", + "flops": 262144, + "bytes_hbm": 36864, + "arith_intensity": 7.111111111111111, + "tile_count_expected": 8, + "sim_wall_clock_s": 1.198, + "engines": { + "pe_dma": { + "occupancy_ns": 193.0, + "wall_ns": 193.0, + "record_count": 6 + }, + "pe_fetch_store": { + "occupancy_ns": 591.995, + "wall_ns": 128.0, + "record_count": 12 + }, + "pe_gemm": { + "occupancy_ns": 141.82400000000052, + "wall_ns": 131.07200000000012, + "record_count": 8 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + }, + "DMA_WRITE": { + "occupancy_ns": 48.0, + "wall_ns": 48.0, + "record_count": 4 + }, + "FETCH": { + "occupancy_ns": 575.995, + "wall_ns": 128.0, + "record_count": 8 + }, + 
"STORE": { + "occupancy_ns": 16.0, + "wall_ns": 16.0, + "record_count": 4 + }, + "GEMM": { + "occupancy_ns": 141.82400000000052, + "wall_ns": 131.07200000000012, + "record_count": 8 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 294.0870000000001 + }, + { + "M": 128, + "K": 8, + "N": 128, + "variant": "ref_ref", + "flops": 262144, + "bytes_hbm": 36864, + "arith_intensity": 7.111111111111111, + "tile_count_expected": 16, + "sim_wall_clock_s": 1.983, + "engines": { + "pe_dma": { + "occupancy_ns": 6547.771, + "wall_ns": 560.0, + "record_count": 48 + }, + "pe_fetch_store": { + "occupancy_ns": 481.72799999999916, + "wall_ns": 268.0, + "record_count": 32 + }, + "pe_gemm": { + "occupancy_ns": 293.2480000000014, + "wall_ns": 262.14400000000023, + "record_count": 16 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 6335.995, + "wall_ns": 516.0, + "record_count": 32 + }, + "DMA_WRITE": { + "occupancy_ns": 211.77599999999984, + "wall_ns": 169.15200000000004, + "record_count": 16 + }, + "FETCH": { + "occupancy_ns": 308.0, + "wall_ns": 260.0, + "record_count": 16 + }, + "STORE": { + "occupancy_ns": 173.72799999999916, + "wall_ns": 164.2559999999994, + "record_count": 16 + }, + "GEMM": { + "occupancy_ns": 293.2480000000014, + "wall_ns": 262.14400000000023, + "record_count": 16 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 572.9920000000002 + }, + { + "M": 128, + "K": 8, + "N": 128, + "variant": "load_ref", + "flops": 262144, + "bytes_hbm": 36864, + "arith_intensity": 7.111111111111111, + "tile_count_expected": 16, + "sim_wall_clock_s": 1.699, + "engines": { + "pe_dma": { + "occupancy_ns": 2464.2709999999997, + "wall_ns": 316.5, + "record_count": 33 + }, + "pe_fetch_store": { + "occupancy_ns": 481.72799999999916, + "wall_ns": 268.0, + "record_count": 32 + }, + "pe_gemm": { + 
"occupancy_ns": 293.2480000000014, + "wall_ns": 262.14400000000023, + "record_count": 16 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 2239.995, + "wall_ns": 260.0, + "record_count": 16 + }, + "DMA_WRITE": { + "occupancy_ns": 211.77599999999984, + "wall_ns": 169.15200000000004, + "record_count": 16 + }, + "FETCH": { + "occupancy_ns": 308.0, + "wall_ns": 260.0, + "record_count": 16 + }, + "STORE": { + "occupancy_ns": 173.72799999999916, + "wall_ns": 164.2559999999994, + "record_count": 16 + }, + "GEMM": { + "occupancy_ns": 293.2480000000014, + "wall_ns": 262.14400000000023, + "record_count": 16 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 330.4970000000002 + }, + { + "M": 128, + "K": 8, + "N": 128, + "variant": "load_load", + "flops": 262144, + "bytes_hbm": 36864, + "arith_intensity": 7.111111111111111, + "tile_count_expected": 16, + "sim_wall_clock_s": 1.402, + "engines": { + "pe_dma": { + "occupancy_ns": 217.0, + "wall_ns": 217.0, + "record_count": 18 + }, + "pe_fetch_store": { + "occupancy_ns": 2239.995, + "wall_ns": 264.0, + "record_count": 32 + }, + "pe_gemm": { + "occupancy_ns": 308.224000000002, + "wall_ns": 262.14400000000023, + "record_count": 16 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + }, + "DMA_WRITE": { + "occupancy_ns": 192.0, + "wall_ns": 192.0, + "record_count": 16 + }, + "FETCH": { + "occupancy_ns": 2175.995, + "wall_ns": 256.0, + "record_count": 16 + }, + "STORE": { + "occupancy_ns": 64.0, + "wall_ns": 64.0, + "record_count": 16 + }, + "GEMM": { + "occupancy_ns": 308.224000000002, + "wall_ns": 262.14400000000023, + "record_count": 16 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 321.1590000000002 + }, + { + "M": 
512, + "K": 512, + "N": 512, + "variant": "ref_ref", + "flops": 268435456, + "bytes_hbm": 1572864, + "arith_intensity": 170.66666666666666, + "tile_count_expected": 2048, + "sim_wall_clock_s": 89.111, + "engines": { + "pe_dma": { + "occupancy_ns": 100690943.995, + "wall_ns": 65612.00000000001, + "record_count": 4352 + }, + "pe_fetch_store": { + "occupancy_ns": 43566.52800034459, + "wall_ns": 32796.00000000001, + "record_count": 2304 + }, + "pe_gemm": { + "occupancy_ns": 833762.8159962555, + "wall_ns": 33554.431999996836, + "record_count": 2048 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 100687871.995, + "wall_ns": 65540.00000000001, + "record_count": 4096 + }, + "DMA_WRITE": { + "occupancy_ns": 3072.0, + "wall_ns": 3072.0, + "record_count": 256 + }, + "FETCH": { + "occupancy_ns": 40936.00000000001, + "wall_ns": 32772.00000000001, + "record_count": 2048 + }, + "STORE": { + "occupancy_ns": 2630.5280003445805, + "wall_ns": 2630.5280003445805, + "record_count": 256 + }, + "GEMM": { + "occupancy_ns": 833762.8159962555, + "wall_ns": 33554.431999996836, + "record_count": 2048 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 66360.12799999685 + }, + { + "M": 512, + "K": 512, + "N": 512, + "variant": "load_ref", + "flops": 268435456, + "bytes_hbm": 1572864, + "arith_intensity": 170.66666666666666, + "tile_count_expected": 2048, + "sim_wall_clock_s": 48.616, + "engines": { + "pe_dma": { + "occupancy_ns": 33584132.495, + "wall_ns": 34896.5, + "record_count": 2305 + }, + "pe_fetch_store": { + "occupancy_ns": 43562.81600011295, + "wall_ns": 32796.0, + "record_count": 2304 + }, + "pe_gemm": { + "occupancy_ns": 833762.8159987241, + "wall_ns": 33554.43199999785, + "record_count": 2048 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 33579007.995, + 
"wall_ns": 32772.0, + "record_count": 2048 + }, + "DMA_WRITE": { + "occupancy_ns": 3072.0, + "wall_ns": 3072.0, + "record_count": 256 + }, + "FETCH": { + "occupancy_ns": 40936.0, + "wall_ns": 32772.0, + "record_count": 2048 + }, + "STORE": { + "occupancy_ns": 2626.816000112947, + "wall_ns": 2626.816000112947, + "record_count": 256 + }, + "GEMM": { + "occupancy_ns": 833762.8159987241, + "wall_ns": 33554.43199999785, + "record_count": 2048 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 35645.632999997855 + }, + { + "M": 512, + "K": 512, + "N": 512, + "variant": "load_load", + "flops": 268435456, + "bytes_hbm": 1572864, + "arith_intensity": 170.66666666666666, + "tile_count_expected": 2048, + "sim_wall_clock_s": 7.072, + "engines": { + "pe_dma": { + "occupancy_ns": 7177.0, + "wall_ns": 7177.0, + "record_count": 258 + }, + "pe_fetch_store": { + "occupancy_ns": 33571839.995, + "wall_ns": 32792.0, + "record_count": 2304 + }, + "pe_gemm": { + "occupancy_ns": 838467.5839984363, + "wall_ns": 33554.43199999763, + "record_count": 2048 + }, + "pe_math": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "stages": { + "DMA_READ": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + }, + "DMA_WRITE": { + "occupancy_ns": 3072.0, + "wall_ns": 3072.0, + "record_count": 256 + }, + "FETCH": { + "occupancy_ns": 33570815.995, + "wall_ns": 32767.999999999996, + "record_count": 2048 + }, + "STORE": { + "occupancy_ns": 1024.0, + "wall_ns": 1024.0, + "record_count": 256 + }, + "GEMM": { + "occupancy_ns": 838467.5839984363, + "wall_ns": 33554.43199999763, + "record_count": 2048 + }, + "MATH": { + "occupancy_ns": 0, + "wall_ns": 0.0, + "record_count": 0 + } + }, + "pe_window_ns": 37677.44699999763 + } + ] +} \ No newline at end of file diff --git a/docs/diagrams/kernbench2_overview.pptx b/docs/diagrams/kernbench2_overview.pptx index 0941ab1..dde76aa 100644 Binary files 
a/docs/diagrams/kernbench2_overview.pptx and b/docs/diagrams/kernbench2_overview.pptx differ diff --git a/scripts/build_overview_slides.py b/scripts/build_overview_slides.py index 1219762..2eb0df5 100644 --- a/scripts/build_overview_slides.py +++ b/scripts/build_overview_slides.py @@ -1,11 +1,14 @@ -"""Generate a 5-slide PPTX summarizing the kernbench2 model. +"""Generate a multi-slide PPTX summarizing the kernbench2 model. -Slides (in order): +Slides: 1. Overall architecture — how PEs are connected (cube_mesh_view) 2. Model correctness — DMA vs P2P latency (pe2pe overview) 3. PE-to-PE IPCQ communication (ipcq_two_pe_dma) 4. 6-device allreduce — model vs theoretical vs ext-sim (overview_broken) 5. IPCQ buffer-kind sweep — TCM vs SRAM vs HBM (buffer_kind_sweep) + 6. PE_accelerator data path (composite GEMM pipeline structure) + 7. PE_SCHEDULER — plan generation & tile dispatch + 8-18. Composite GEMM deep dive — sequences, tiling, sweep, HBM path, utilization, efficiency (titles in SLIDES) This is a derived-artifact generator — no production code touched. """ @@ -17,6 +20,7 @@ from PIL import Image from pptx import Presentation from pptx.dml.color import RGBColor from pptx.enum.shapes import MSO_SHAPE +from pptx.enum.text import PP_ALIGN from pptx.util import Emu, Inches, Pt ROOT = Path(__file__).resolve().parent.parent @@ -77,9 +81,80 @@ SLIDES = [ "At 64 KB / PE: TCM 12.0 µs < HBM 21.4 µs < SRAM 24.3 µs — SRAM is slowest because of its narrow bank link", ], }, + { + "title": "6. PE_accelerator Data Path: Composite GEMM Pipeline", + "render": "pipeline_structure", + }, + { + "title": "7. PE_SCHEDULER: Plan Generation & Tile Dispatch", + "render": "scheduler", + }, + { + "title": "8. matmul(32, 128, 32) — Composite GEMM Execution Sequence", + "render": "sequence_32x128x32", + }, + { + "title": "9. matmul(32, 128, 128) — Pipeline Scaling & HBM Contention", + "render": "sequence_32x128x128", + }, + { + "title": "10. 
Tiling Walkthrough: 32×128×32 — K-loop Only, No Inter-(m,n) Flush", + "render": "tiling_32x128x32", + }, + { + "title": "11. Tiling Walkthrough: 32×128×128 — K-loop & Inter-(m,n) Flushes", + "render": "tiling_32x128x128", + }, + { + "title": "12. GEMM Sweep — Stage Wall-Clock (load_ref)", + "render": "stage_breakdown_load_ref", + }, + { + "title": "13. Why DMA Isn't Local: Cube-Shared HBM Path", + "render": "hbm_topology", + }, + { + "title": "14. GEMM Utilization + Useful Pipeline Efficiency (load_ref)", + "render": "mac_utilization", + }, + { + "title": "15. GEMM Utilization + Useful Pipeline Efficiency (ref_ref — both A & B via DMA_R)", + "render": "mac_utilization_ref_ref", + }, + { + "title": "16. Pipeline Efficiency Walkthrough — 32×128×128 (with inter flushes)", + "render": "pipeline_eff_walkthrough", + }, + { + "title": "17. Pipeline Efficiency Walkthrough — 32×3072×32 (large K, no flushes)", + "render": "pipeline_eff_walkthrough_largeK", + }, + { + "title": "18. Useful Pipelined Efficiency (ideal pipeline × GEMM util)", + "render": "tflops_table", + }, ] +# ── Palette for the shape-drawn slides ───────────────────────────────────── + +COL_TEXT_DARK = RGBColor(0x1E, 0x29, 0x3B) +COL_TEXT_LIGHT = RGBColor(0xFF, 0xFF, 0xFF) +COL_MUTED = RGBColor(0x47, 0x55, 0x69) +COL_RED = RGBColor(0xDC, 0x26, 0x26) +COL_STORAGE = RGBColor(0xE2, 0xE8, 0xF0) +COL_STORAGE_STROKE = RGBColor(0x47, 0x55, 0x69) +COL_DMA = RGBColor(0x3B, 0x82, 0xF6) # blue +COL_FS = RGBColor(0x10, 0xB9, 0x81) # emerald +COL_GEMM = RGBColor(0xF5, 0x9E, 0x0B) # amber +COL_HBM_BG = RGBColor(0xDB, 0xEA, 0xFE) +COL_TCM_BG = RGBColor(0xD1, 0xFA, 0xE5) +COL_REG_BG = RGBColor(0xFE, 0xF3, 0xC7) +COL_HBM_BORDER = COL_DMA +COL_TCM_BORDER = COL_FS +COL_REG_BORDER = RGBColor(0xD9, 0x77, 0x06) + + def _add_title(slide, text): left = Inches(0.4) top = Inches(0.25) @@ -139,6 +214,2082 @@ def _add_footer(slide, idx, total): run.font.color.rgb = RGBColor(0x88, 0x88, 0x88) +# ── Shape-drawing primitives for the 
def _txt(box, text, size=11, bold=False, color=COL_TEXT_DARK,
         align=PP_ALIGN.CENTER):
    """Write ``text`` into ``box``'s text frame, one paragraph per line.

    Args:
        box: Any shape exposing ``text_frame`` (text box or auto-shape).
        text: Label text; ``"\\n"`` separates paragraphs.
        size: Font size in points.
        bold: Whether the runs are bold.
        color: Font colour (an ``RGBColor``).
        align: Paragraph alignment (a ``PP_ALIGN`` member).
    """
    tf = box.text_frame
    tf.word_wrap = True
    tf.margin_left = Emu(18000)
    tf.margin_right = Emu(18000)
    tf.margin_top = Emu(9000)
    tf.margin_bottom = Emu(9000)
    for i, line_text in enumerate(text.split("\n")):
        # The frame already owns one paragraph; reuse it for the first line.
        para = tf.paragraphs[0] if i == 0 else tf.add_paragraph()
        para.alignment = align
        para.text = ""  # reset the paragraph before adding our own run
        run = para.add_run()
        run.text = line_text
        run.font.size = Pt(size)
        run.font.bold = bold
        run.font.name = "Consolas"
        run.font.color.rgb = color


def _textbox(slide, x, y, w, h, text, **kw):
    """Add a text box at (x, y) of size (w, h), all in inches, and fill it.

    Extra keyword arguments are forwarded to ``_txt``. Returns the new shape.
    """
    tb = slide.shapes.add_textbox(Inches(x), Inches(y), Inches(w), Inches(h))
    _txt(tb, text, **kw)
    return tb


def _filled_shape(slide, shape_type, x, y, w, h, fill, stroke, line_pt):
    """Add an auto-shape with a solid fill and a solid outline.

    Shared by ``_rrect`` and ``_rect_band``, which differ only in the shape
    type and the outline width. Geometry is in inches, ``line_pt`` in points.
    """
    shape = slide.shapes.add_shape(
        shape_type, Inches(x), Inches(y), Inches(w), Inches(h),
    )
    shape.fill.solid()
    shape.fill.fore_color.rgb = fill
    shape.line.color.rgb = stroke
    shape.line.width = Pt(line_pt)
    return shape


def _rrect(slide, x, y, w, h, fill, stroke, text="", **kw):
    """Add a rounded rectangle (1.3 pt outline), optionally labelled.

    ``text`` is written with ``_txt`` (keyword arguments forwarded) only when
    it is non-empty. Returns the new shape.
    """
    s = _filled_shape(slide, MSO_SHAPE.ROUNDED_RECTANGLE,
                      x, y, w, h, fill, stroke, 1.3)
    if text:
        _txt(s, text, **kw)
    return s


def _rect_band(slide, x, y, w, h, fill, stroke):
    """Add a plain, unlabelled rectangle with a 1.0 pt outline."""
    return _filled_shape(slide, MSO_SHAPE.RECTANGLE,
                         x, y, w, h, fill, stroke, 1.0)


def _arrow(slide, x1, y1, x2, y2, color=COL_MUTED, width_pt=1.5):
    """Add a connector from (x1, y1) to (x2, y2) with a triangle arrowhead.

    Coordinates are in inches. Returns the connector shape.
    """
    from lxml import etree
    from pptx.enum.shapes import MSO_CONNECTOR
    from pptx.oxml.ns import qn

    # ELBOW is the named form of the connector type previously passed as 2.
    conn = slide.shapes.add_connector(
        MSO_CONNECTOR.ELBOW, Inches(x1), Inches(y1), Inches(x2), Inches(y2),
    )
    conn.line.color.rgb = color
    conn.line.width = Pt(width_pt)
    # python-pptx has no public arrowhead API, so the <a:tailEnd> element is
    # set directly on the line's XML (reusing an existing one if present).
    ln = conn.line._get_or_add_ln()
    tail = ln.find(qn("a:tailEnd"))
    if tail is None:
        tail = etree.SubElement(ln, qn("a:tailEnd"))
    tail.set("type", "triangle")
    tail.set("w", "med")
    tail.set("len", "med")
    return conn


def _vline_dashed(slide, x, y1, y2, color, width_pt=2.0):
    """Add a dashed vertical line at ``x`` spanning ``y1``..``y2`` (inches).

    Returns the connector shape.
    """
    from pptx.enum.dml import MSO_LINE
    from pptx.enum.shapes import MSO_CONNECTOR

    # STRAIGHT is the named form of the connector type previously passed as 1.
    conn = slide.shapes.add_connector(
        MSO_CONNECTOR.STRAIGHT, Inches(x), Inches(y1), Inches(x), Inches(y2),
    )
    conn.line.color.rgb = color
    conn.line.width = Pt(width_pt)
    # Public API for the preset dash; writes the same <a:prstDash val="dash"/>
    # that was previously inserted by hand through lxml.
    conn.line.dash_style = MSO_LINE.DASH
    return conn
bold=True, color=COL_GEMM, align=PP_ALIGN.LEFT) + + # Inside GEMM Unit: FETCH (top), RegFile (below), MAC label (bottom) + sub_h = 0.50 + inner_x = gemm_x + 0.5 + inner_w = gemm_w - 1.0 + + fetch_y = y_gemm + 0.55 + reg_y = fetch_y + sub_h + 0.30 + + _rrect(slide, inner_x, fetch_y, inner_w, sub_h, + COL_FS, COL_TEXT_DARK, + "FETCH unit", + size=11, bold=True, color=COL_TEXT_LIGHT) + _rrect(slide, inner_x, reg_y, inner_w, sub_h, + COL_REG_BG, COL_REG_BORDER, + "RegFile (A, B, C accumulator)", + size=11, bold=True, color=COL_REG_BORDER) + + # MAC annotation (text-only) — implicit MAC operation on RegFile + mac_label_y = reg_y + sub_h + 0.05 + _textbox(slide, inner_x, mac_label_y, inner_w, 0.32, + "↻ MAC accumulate (32 × 64 × 32 array)", + size=11, bold=True, color=COL_GEMM, align=PP_ALIGN.CENTER) + + # FETCH → RegFile (internal, vertical down) + inner_cx = inner_x + inner_w / 2 + _arrow(slide, inner_cx, fetch_y + sub_h + 0.02, + inner_cx, reg_y - 0.02, + color=COL_FS, width_pt=1.8) + + # DMA_out (below GEMM Unit) + _rrect(slide, box_x, y_dma_out, box_w, box_h, + WRITE_COLOR, COL_TEXT_DARK, + "DMA_out", + size=13, bold=True, color=COL_TEXT_LIGHT) + + # ── TCM on the right (tall vertical) ───────────────────────────── + tcm_x = 9.4 + tcm_w = 2.6 + tcm_y = y_dma_in - 0.05 + tcm_h = (y_dma_out + box_h + 0.05) - tcm_y + _rrect(slide, tcm_x, tcm_y, tcm_w, tcm_h, + COL_TCM_BG, COL_TCM_BORDER, "", + size=12, color=COL_TCM_BORDER) + _textbox(slide, tcm_x + 0.1, tcm_y + tcm_h / 2 - 0.6, + tcm_w - 0.2, 1.2, + "TCM\n\n(PE-local SRAM,\n512 GB/s)", + size=14, bold=True, color=COL_TCM_BORDER, align=PP_ALIGN.CENTER) + + # ── Wires ─────────────────────────────────────────────────────── + # HBM → DMA_in (vertical down) + _arrow(slide, col_cx, y_hbm + box_h + 0.02, col_cx, y_dma_in - 0.02, + color=READ_COLOR, width_pt=2.5) + _textbox(slide, col_cx + 0.15, (y_hbm + box_h + y_dma_in) / 2 - 0.10, + 1.5, 0.22, "DMA_R", size=10, bold=True, + color=READ_COLOR, align=PP_ALIGN.LEFT) + + # 
DMA_in → TCM (horizontal right, at DMA_in y) + dma_in_cy = y_dma_in + box_h / 2 + _arrow(slide, box_x + box_w + 0.02, dma_in_cy, + tcm_x - 0.02, dma_in_cy, + color=READ_COLOR, width_pt=2.5) + mid_x = (box_x + box_w + tcm_x) / 2 + _textbox(slide, mid_x - 1.4, dma_in_cy - 0.32, + 2.8, 0.22, "store to TCM", + size=10, bold=True, color=READ_COLOR, align=PP_ALIGN.CENTER) + + # TCM → FETCH (horizontal left, into FETCH inside GEMM Unit) + fetch_right = inner_x + inner_w + fetch_cy = fetch_y + sub_h / 2 + _arrow(slide, tcm_x - 0.02, fetch_cy, + fetch_right + 0.02, fetch_cy, + color=COL_FS, width_pt=2.5) + _textbox(slide, (fetch_right + tcm_x) / 2 - 1.4, fetch_cy - 0.32, + 2.8, 0.22, "FETCH (TCM → Reg)", + size=10, bold=True, color=COL_FS, align=PP_ALIGN.CENTER) + + # GEMM (RegFile) → TCM (STORE, horizontal right exit at RegFile y) + reg_right = inner_x + inner_w + store_cy = reg_y + sub_h / 2 + _arrow(slide, reg_right + 0.02, store_cy, + tcm_x - 0.02, store_cy, + color=WRITE_COLOR, width_pt=2.5) + _textbox(slide, (reg_right + tcm_x) / 2 - 1.6, store_cy + 0.08, + 3.2, 0.22, "STORE (Reg → TCM, after last K)", + size=10, bold=True, color=WRITE_COLOR, align=PP_ALIGN.CENTER) + + # TCM → DMA_out (horizontal left at DMA_out y) + dma_out_cy = y_dma_out + box_h / 2 + _arrow(slide, tcm_x - 0.02, dma_out_cy, + box_x + box_w + 0.02, dma_out_cy, + color=WRITE_COLOR, width_pt=2.5) + _textbox(slide, mid_x - 1.4, dma_out_cy - 0.32, + 2.8, 0.22, "read from TCM", + size=10, bold=True, color=WRITE_COLOR, align=PP_ALIGN.CENTER) + + # DMA_out → HBM (off-page; just an arrow + label, no HBM box at bottom) + _arrow(slide, col_cx, y_dma_out + box_h + 0.02, + col_cx, out_arr_end_y - 0.02, + color=WRITE_COLOR, width_pt=2.5) + _textbox(slide, col_cx + 0.15, out_arr_end_y - 0.30, + 3.0, 0.22, "DMA_W → HBM", + size=10, bold=True, color=WRITE_COLOR, align=PP_ALIGN.LEFT) + + +# ── Slide 7: PE_SCHEDULER ────────────────────────────────────────────────── + +def _render_scheduler(slide): + """PE_SCHEDULER: 
sole command dispatcher inside a PE. + + Two paths: simple cmd → direct engine dispatch; CompositeCmd → + generate_plan + FIFO feed → TileToken self-routes through stages. + """ + SCHED_FILL = RGBColor(0xF3, 0xE8, 0xFF) + SCHED_BORDER = RGBColor(0x7C, 0x3A, 0xED) + PE_MATH_COL = RGBColor(0x9C, 0xA3, 0xAF) + + _textbox(slide, 0.4, 1.0, 12.6, 0.5, + "Sole dispatcher inside a PE. CompositeCmd is expanded into a " + "TilePlan and fed tile-by-tile in FIFO order; each TileToken " + "self-routes through the pipeline stages.", + size=12, color=COL_MUTED, align=PP_ALIGN.LEFT) + + # SCHEDULER container (compact — left side of slide) + sched_x, sched_y = 0.6, 2.45 + sched_w, sched_h = 4.6, 2.4 + + # PE_CPU box (centered above scheduler) + pe_cpu_w, pe_cpu_h = 2.0, 0.50 + pe_cpu_x = sched_x + sched_w / 2 - pe_cpu_w / 2 + pe_cpu_y = 1.65 + _rrect(slide, pe_cpu_x, pe_cpu_y, pe_cpu_w, pe_cpu_h, + COL_STORAGE, COL_STORAGE_STROKE, + "PE_CPU", + size=12, bold=True, color=COL_TEXT_DARK) + + _rrect(slide, sched_x, sched_y, sched_w, sched_h, + SCHED_FILL, SCHED_BORDER, "", + size=10, color=SCHED_BORDER) + _textbox(slide, sched_x + 0.15, sched_y + 0.08, sched_w - 0.30, 0.32, + "PE_SCHEDULER", + size=13, bold=True, color=SCHED_BORDER, align=PP_ALIGN.LEFT) + + # CompositeCmd description — text only, no inner box + text_x = sched_x + 0.25 + text_y = sched_y + 0.55 + text_w = sched_w - 0.50 + _textbox(slide, text_x, text_y, text_w, 0.30, + "CompositeCmd → generate plan", + size=12, bold=True, color=COL_TEXT_DARK, align=PP_ALIGN.LEFT) + _textbox(slide, text_x + 0.10, text_y + 0.40, text_w - 0.15, + sched_h - 1.05, + "generate_plan(M, K, N)\n" + " → ⌈M/32⌉ × ⌈K/64⌉ × ⌈N/32⌉ tiles\n" + " each tile:\n" + " [DMA_R, FETCH, GEMM,\n" + " STORE, DMA_W] stages", + size=11, color=COL_MUTED, align=PP_ALIGN.LEFT) + + # Engines (compact, right of scheduler — still on the LEFT half of slide) + engines_x = sched_x + sched_w + 0.40 + engines_w = 2.4 + engines_y = sched_y + engine_h = 0.45 + engine_gap = 
0.18 + engines = [ + ("pe_dma", COL_DMA, COL_TEXT_LIGHT), + ("pe_fetch_store", COL_FS, COL_TEXT_LIGHT), + ("pe_gemm", COL_GEMM, COL_TEXT_DARK), + ("pe_math", PE_MATH_COL, COL_TEXT_LIGHT), + ] + for i, (name, fill, tcol) in enumerate(engines): + y = engines_y + i * (engine_h + engine_gap) + _rrect(slide, engines_x, y, engines_w, engine_h, + fill, COL_TEXT_DARK, + name, + size=12, bold=True, color=tcol) + + # TileToken / PipelineContext annotation under the engine stack + last_y = engines_y + len(engines) * (engine_h + engine_gap) - engine_gap + note_y = last_y + 0.15 + _textbox(slide, engines_x, note_y, engines_w, 0.28, + "↻ TileToken.advance()", + size=10, bold=True, color=SCHED_BORDER, align=PP_ALIGN.CENTER) + _textbox(slide, engines_x, note_y + 0.28, engines_w, 0.28, + "PipelineContext counts tiles", + size=9, color=COL_MUTED, align=PP_ALIGN.CENTER) + + # PE_CPU → SCHEDULER arrow + _arrow(slide, pe_cpu_x + pe_cpu_w / 2, pe_cpu_y + pe_cpu_h + 0.02, + pe_cpu_x + pe_cpu_w / 2, sched_y - 0.02, + color=COL_TEXT_DARK, width_pt=2.5) + _textbox(slide, pe_cpu_x + pe_cpu_w + 0.10, + (pe_cpu_y + pe_cpu_h + sched_y) / 2 - 0.12, + 3.0, 0.22, "PeInternalTxn(cmd)", + size=11, bold=True, color=COL_MUTED, align=PP_ALIGN.LEFT) + + # SCHEDULER → engines arrows (one per engine, colour-coded) + sched_right = sched_x + sched_w + sched_cy = sched_y + sched_h / 2 + for i, (name, fill, _) in enumerate(engines): + engine_cy = engines_y + i * (engine_h + engine_gap) + engine_h / 2 + _arrow(slide, sched_right + 0.02, sched_cy, + engines_x - 0.02, engine_cy, + color=fill, width_pt=1.5) + + # Bottom note (full width) + _textbox(slide, 0.4, 6.10, 12.6, 0.85, + "Key invariants: (1) FIFO across commands via the single " + "feeder process — no inter-command tile interleaving. " + "(2) TileToken carries its own plan; each engine reads " + "token.current_stage, advances stage_idx, and forwards to " + "the next stage's component. (3) PipelineContext." 
def _draw_sequence_panel(slide, ops, *, x_left=0.6, x_right=12.9,
                         y_top=1.7, panel_h=4.4,
                         tile_boundary_after=None,
                         tile_labels=None,
                         contention_after=None):
    """Draw a three-band (HBM / TCM / RegFile) sequence diagram.

    Each op gets its own column, left to right in list order. Its box sits in
    the gap between two bands (or on the RegFile band), as selected by
    ``_KIND_INFO[kind]``, with arrows showing the transfer direction.

    Args:
        slide: Target slide.
        ops: Sequence of ``(step, name, kind, bytes_str, ns_str)`` tuples;
            ``kind`` must be a key of ``_KIND_INFO``.
        x_left, x_right: Horizontal extent of the bands, in inches.
        y_top: Top of the HBM band, in inches.
        panel_h: Distance from the top of the HBM band to the bottom of the
            RegFile band, in inches.
        tile_boundary_after: Column index at whose left edge a red dashed
            divider is drawn; ``None`` draws no divider.
        tile_labels: Up to two captions for the divider: the first is placed
            to its left, the second to its right. A missing second caption is
            simply omitted; extra entries are ignored. Only used when
            ``tile_boundary_after`` is set.
        contention_after: Column index at which to anchor the HBM-contention
            callout; ``None`` draws no callout.
    """
    left = PP_ALIGN.LEFT
    n_ops = len(ops)
    band_x = x_left
    band_w = x_right - x_left
    band_h = 0.55
    hbm_y = y_top
    tcm_y = y_top + panel_h * 0.5 - band_h / 2
    reg_y = y_top + panel_h - band_h

    # Memory bands: all rectangles first, then their captions, so the
    # captions are stacked above every band.
    bands = (
        (hbm_y, COL_HBM_BG, COL_HBM_BORDER, 0.8, "HBM"),
        (tcm_y, COL_TCM_BG, COL_TCM_BORDER, 0.8, "TCM"),
        (reg_y, COL_REG_BG, COL_REG_BORDER, 0.85, "RegFile"),
    )
    for band_y, bg, border, _label_w, _caption in bands:
        _rect_band(slide, band_x, band_y, band_w, band_h, bg, border)
    for band_y, _bg, border, label_w, caption in bands:
        _textbox(slide, band_x + 0.05, band_y, label_w, band_h, caption,
                 size=12, bold=True, color=border, align=left)

    # Op columns start after the band captions and stop short of the right
    # edge; max() keeps an empty ops list from dividing by zero.
    ops_left = band_x + 1.0
    ops_w = band_w - 1.1
    col_w = ops_w / max(n_ops, 1)
    op_box_w = col_w * 0.86
    op_box_h = 0.70

    hbm_bot = hbm_y + band_h
    tcm_top = tcm_y
    tcm_bot = tcm_y + band_h
    reg_top = reg_y
    gap_ht = (hbm_bot + tcm_top) / 2   # vertical centre of the HBM–TCM gap
    gap_tr = (tcm_bot + reg_top) / 2   # vertical centre of the TCM–RegFile gap

    for idx, (step, name, kind, byt, ns) in enumerate(ops):
        info = _KIND_INFO[kind]
        x_center = ops_left + idx * col_w + col_w / 2
        x = x_center - op_box_w / 2
        # Step number above the HBM band
        _textbox(slide, x, hbm_y - 0.35, op_box_w, 0.25,
                 f"#{step}", size=9, bold=True, color=COL_TEXT_DARK)
        if info["band"] == "between_hbm_tcm":
            y_box = gap_ht - op_box_h / 2
        elif info["band"] == "between_tcm_reg":
            y_box = gap_tr - op_box_h / 2
        else:
            y_box = reg_y + band_h / 2 - op_box_h / 2
        text_color = COL_TEXT_LIGHT if kind != "gemm" else COL_TEXT_DARK
        label = f"{name}\n{byt} {ns}"
        _rrect(slide, x, y_box, op_box_w, op_box_h,
               info["fill"], COL_TEXT_DARK, label,
               size=8, bold=True, color=text_color)
        # Arrows: band edge -> box, then box -> opposite band edge.
        if info["dir"] == "down":
            src = hbm_bot if info["band"] == "between_hbm_tcm" else tcm_bot
            dst = tcm_top if info["band"] == "between_hbm_tcm" else reg_top
            _arrow(slide, x_center, src, x_center, y_box,
                   color=info["fill"], width_pt=1.6)
            _arrow(slide, x_center, y_box + op_box_h, x_center, dst,
                   color=info["fill"], width_pt=1.6)
        elif info["dir"] == "up":
            src = reg_top if info["band"] == "between_tcm_reg" else tcm_top
            dst = tcm_bot if info["band"] == "between_tcm_reg" else hbm_bot
            _arrow(slide, x_center, src, x_center, y_box + op_box_h,
                   color=info["fill"], width_pt=1.6)
            _arrow(slide, x_center, y_box, x_center, dst,
                   color=info["fill"], width_pt=1.6)

    # Tile boundary
    if tile_boundary_after is not None:
        bx = ops_left + tile_boundary_after * col_w
        _vline_dashed(slide, bx, hbm_y - 0.05, reg_y + band_h + 0.05,
                      COL_RED, width_pt=2.0)
        # (x, width, alignment) of the caption slot on each side of the line.
        label_slots = (
            (bx - 1.7, 1.6, PP_ALIGN.RIGHT),
            (bx + 0.1, 1.9, PP_ALIGN.LEFT),
        )
        # zip() stops at the shorter input, so one caption no longer raises
        # IndexError and captions beyond the second are ignored.
        for caption, (lx, lw, lalign) in zip(tile_labels or (), label_slots):
            _textbox(slide, lx, hbm_y - 0.65, lw, 0.3,
                     caption, size=10, bold=True,
                     color=COL_RED, align=lalign)

    # HBM contention callout
    if contention_after is not None:
        cx = ops_left + contention_after * col_w
        cy = hbm_y + band_h + 0.05
        callout_w = 2.6
        callout_h = 0.7
        # Centre the callout on cx, clamped to stay 0.5 in inside the bands.
        callout_x = max(band_x + 0.5, cx - callout_w / 2)
        callout_x = min(callout_x, band_x + band_w - callout_w - 0.5)
        _rrect(slide, callout_x, cy + 0.05, callout_w, callout_h,
               RGBColor(0xFE, 0xE2, 0xE2), COL_RED,
               "HBM contention:\nDMA_WRITE out + next DMA_READ A,B\ncompete for HBM BW",
               size=8, bold=True, color=COL_RED)
        _arrow(slide, callout_x + callout_w / 2, cy + 0.05,
               cx, hbm_y + band_h * 0.5,
               color=COL_RED, width_pt=1.8)

    # Engine legend (below the panel): (swatch dx, caption dx, caption width,
    # colour, caption).
    legend_y = reg_y + band_h + 0.6
    legend = (
        (0.1, 0.4, 2.4, COL_DMA, "pe_dma (HBM↔TCM)"),
        (3.0, 3.3, 3.0, COL_FS, "pe_fetch_store (TCM↔RegFile)"),
        (6.6, 6.9, 2.6, COL_GEMM, "pe_gemm (MAC compute)"),
    )
    for swatch_dx, text_dx, text_w, colour, caption in legend:
        _rect_band(slide, band_x + swatch_dx, legend_y, 0.25, 0.22,
                   colour, COL_TEXT_DARK)
        _textbox(slide, band_x + text_dx, legend_y - 0.05, text_w, 0.35,
                 caption, size=10, color=COL_TEXT_DARK, align=left)
def _render_sequence_32x128x32(slide):
    """Render the execution-sequence slide for matmul(32, 128, 32).

    Draws, in order: the intro paragraph, the op sequence panel, the
    scheduler setup block, the "A pinned" note over the TCM band, and the
    closing remark.
    """
    left = PP_ALIGN.LEFT

    intro = (
        "load_ref assumption: A (32×128) is pre-loaded into TCM via "
        "tl.load before the kernel starts; only B is DMA_R'd per "
        "tile. FETCH can start as soon as the first DMA_R(B) finishes "
        "— A is already in TCM. Scheduler tile = 32×64×32 → 1·1·2 = "
        "2 tiles. The PLAN-GEN/FEED block is the scheduler-side "
        "setup delay before the first DMA fires."
    )
    _textbox(slide, 0.4, 1.05, 12.6, 0.65, intro,
             size=12, color=COL_MUTED, align=left)

    # load_ref: A is never DMA'd here, only B per tile.
    # Rows are (kind, name, bytes, duration); step numbers are the 1-based
    # row positions.
    rows = [
        ("dma_read", "DMA_R B (tile 0)", "4 KB", "~16 ns"),
        ("fetch", "FETCH (tile 0)", "8 KB", "~16 ns"),
        ("dma_read", "DMA_R B (tile 1)", "4 KB", "~16 ns"),
        ("gemm", "GEMM K=0 (accum)", "—", "~17 ns"),
        ("fetch", "FETCH (tile 1)", "8 KB", "~16 ns"),
        ("gemm", "GEMM K=1 (last)", "—", "~17 ns"),
        ("store", "STORE final", "2 KB", "~4 ns"),
        ("dma_write", "DMA_W out", "2 KB", "~8 ns"),
    ]
    ops = [
        (step, name, kind, nbytes, duration)
        for step, (kind, name, nbytes, duration) in enumerate(rows, start=1)
    ]
    _draw_sequence_panel(slide, ops)
    _draw_composite_setup_block(slide)

    # Note placed over the left end of the TCM band (default panel geometry).
    _textbox(slide, 0.65, 3.92, 1.6, 0.22,
             "[ A pinned via tl.load ]",
             size=9, bold=True, color=COL_TCM_BORDER, align=left)

    closing = (
        "Pipeline is balanced (DMA, FETCH, GEMM all ~16 ns/tile) — "
        "DMA engine carries only B, so per-tile DMA cost halves vs "
        "ref_ref. Wall = setup + head_latency + N_tiles·T_stage + "
        "final STORE+DMA_W."
    )
    _textbox(slide, 0.4, 6.55, 12.6, 0.45, closing,
             size=11, color=COL_MUTED, align=left)
Simulator currently models PE_DMA read / " + "write as separate resources (full-duplex) — flag for revisit " + "if half-duplex matters.", + size=10, color=COL_RED, align=PP_ALIGN.LEFT) + _textbox(slide, 0.4, 6.95, 12.6, 0.3, + "Pattern repeats for (m,n) = (0,1), (0,2), (0,3). DMA engine " + "only carries B per tile, so the pipeline stays balanced — " + "no DMA bottleneck like ref_ref.", + size=10, color=COL_MUTED, align=PP_ALIGN.LEFT) + + +# ── Slides 10 & 11: Tiling walkthroughs ──────────────────────────────────── + +def _draw_matrix_tiles(slide, *, x, y, n_rows, n_cols, tile_w, tile_h, + fill, border, label_prefix, label_fmt=None, + title=None, title_size=11): + """Draw a matrix as a grid of tiles. Returns (right_x, bottom_y). + + label_fmt(r, c) returns the per-tile label; if None defaults to + f"{label_prefix}{r}{c}" for 2-d or f"{label_prefix}{c}" for 1-row. + """ + if title is not None: + _textbox(slide, x, y - 0.27, n_cols * tile_w, 0.22, + title, size=title_size, bold=True, + color=COL_TEXT_DARK, align=PP_ALIGN.CENTER) + for r in range(n_rows): + for c in range(n_cols): + if label_fmt is not None: + label = label_fmt(r, c) + elif n_rows == 1: + label = f"{label_prefix}{c}" + else: + label = f"{label_prefix}{r}{c}" + _rrect(slide, x + c * tile_w, y + r * tile_h, + tile_w, tile_h, + fill, border, label, + size=10 if n_rows == 1 else 9, + bold=True, color=border) + return x + n_cols * tile_w, y + n_rows * tile_h + + +def _render_tiling_32x128x32(slide): + """32×128×32 — K-loop only, single (m,n) → NO inter-(m,n) flush. + + Visualises why a tall-thin K shape is friendly: the accumulator stays + in RegFile across the entire K loop, and STORE + DMA_W fire ONCE at + the very end. No inter-pair flush serialises HBM bandwidth. + """ + _textbox(slide, 0.4, 1.0, 12.6, 0.55, + "Scheduler tile = 32×64×32 → 1·2·1 = 2 tiles. Only ONE (m,n) " + "output → the K-loop accumulates entirely in RegFile, and " + "STORE + DMA_W fire just once at the very end. 
No inter-(m,n) " + "flush — DMA never has to drain mid-compute.", + size=12, color=COL_MUTED, align=PP_ALIGN.LEFT) + + # ── Matrix formula C = A · B (centered on slide) ───────────────── + tile_w = 0.60 + tile_h = 0.60 + matrix_top = 2.10 + + # C (1 × 1) — single output tile + c_x = 1.6 + c_y = matrix_top + _draw_matrix_tiles(slide, x=c_x, y=c_y, + n_rows=1, n_cols=1, + tile_w=tile_w, tile_h=tile_h, + fill=RGBColor(0xFE, 0xF3, 0xC7), + border=COL_REG_BORDER, + label_prefix="C", title="C (32 × 32)") + + # "=" + _textbox(slide, c_x + tile_w + 0.10, c_y - 0.05, 0.40, 0.70, + "=", size=26, bold=True, color=COL_TEXT_DARK, + align=PP_ALIGN.CENTER) + + # A (1 × 2) — 2 K-tiles + a_x = c_x + tile_w + 0.60 + a_y = matrix_top + _draw_matrix_tiles(slide, x=a_x, y=a_y, + n_rows=1, n_cols=2, + tile_w=tile_w, tile_h=tile_h, + fill=RGBColor(0xDB, 0xEA, 0xFE), + border=COL_HBM_BORDER, + label_prefix="A", title="A (32 × 128)") + _textbox(slide, a_x, a_y + tile_h + 0.05, 2 * tile_w, 0.22, + "← K (2 tiles) →", + size=9, color=COL_MUTED, align=PP_ALIGN.CENTER) + + # "·" + _textbox(slide, a_x + 2 * tile_w + 0.10, c_y - 0.10, 0.30, 0.70, + "·", size=28, bold=True, color=COL_TEXT_DARK, + align=PP_ALIGN.CENTER) + + # B (2 × 1) — 2 K-tiles, 1 N-tile + b_x = a_x + 2 * tile_w + 0.55 + b_y = matrix_top - tile_h / 2 # center on row + _draw_matrix_tiles(slide, x=b_x, y=b_y, + n_rows=2, n_cols=1, + tile_w=tile_w, tile_h=tile_h, + fill=RGBColor(0xDB, 0xEA, 0xFE), + border=COL_HBM_BORDER, + label_prefix="B", title="B (128 × 32)", + label_fmt=lambda r, c: f"B{r}") + _textbox(slide, b_x - 0.45, b_y + tile_h - 0.10, 0.40, 0.30, + "K\n↓", size=9, color=COL_MUTED, align=PP_ALIGN.CENTER) + + # Formula + _textbox(slide, 0.4, matrix_top + 2 * tile_h + 0.30, 12.6, 0.40, + "C = A0·B0 + A1·B1 (K-loop, 2 iterations — " + "accumulator stays in RegFile)", + size=13, bold=True, color=COL_TEXT_DARK, + align=PP_ALIGN.CENTER) + + # ── Timeline: single (m,n) pair, NO flush ─────────────────────── + tl_y = matrix_top 
+ 2 * tile_h + 1.10 + _textbox(slide, 0.4, tl_y - 0.30, 12.6, 0.22, + "Execution timeline — 1 (m,n) output, 0 inter-(m,n) flushes:", + size=12, bold=True, color=COL_TEXT_DARK, + align=PP_ALIGN.LEFT) + + pair_w = 4.50 + pair_h = 0.85 + pair_x = (SLIDE_W_IN - pair_w) / 2 + k0_w = pair_w * 0.45 + k1_w = pair_w * 0.53 + + _rrect(slide, pair_x, tl_y, k0_w, pair_h, + RGBColor(0xD1, 0xFA, 0xE5), COL_TCM_BORDER, + "K=0 (accumulate into RegFile)", + size=11, bold=True, color=COL_TCM_BORDER) + _rrect(slide, pair_x + k0_w + 0.02, tl_y, k1_w, pair_h, + RGBColor(0xFE, 0xF3, 0xC7), COL_REG_BORDER, + "K=1 last\nSTORE + DMA_W (final drain)", + size=11, bold=True, color=COL_REG_BORDER) + _textbox(slide, pair_x, tl_y + pair_h + 0.05, pair_w, 0.25, + "(m,n)=(0,0) → C (single output tile)", + size=11, bold=True, color=COL_TEXT_DARK, + align=PP_ALIGN.CENTER) + # "no flush" callout to the right + _textbox(slide, pair_x + pair_w + 0.20, tl_y + 0.15, 2.5, 0.55, + "✓ No inter-(m,n) flush\n (only 1 output pair)", + size=11, bold=True, color=COL_TCM_BORDER, + align=PP_ALIGN.LEFT) + + # ── Bottom note ───────────────────────────────────────────────── + _textbox(slide, 0.4, 6.30, 12.6, 0.70, + "Why growing K helps: each (m,n) pair amortises its single " + "STORE+DMA_W over K_tiles iterations of pure compute. With " + "N=32 (one N-tile), there is no NEXT (m,n) pair, so no " + "inter-pair flush at all. Pipeline efficiency is bottlenecked " + "only by head latency and the final drain.", + size=11, color=COL_MUTED, align=PP_ALIGN.LEFT) + + +def _render_tiling_32x128x128(slide): + """32×128×128 — K-loop + 3 inter-(m,n) flushes (N_tiles=4 → 4 pairs).""" + _textbox(slide, 0.4, 1.0, 12.6, 0.55, + "Scheduler tile = 32×64×32 → 1·2·4 = 8 tiles. A is split along " + "K (2 tiles); B along K and N (2×4); C along N (4). 
For each " + "(m,n) the K-loop accumulates in RegFile; STORE + DMA_W fire " + "only on last K → 3 inter-(m,n) flushes between the 4 pairs.", + size=12, color=COL_MUTED, align=PP_ALIGN.LEFT) + + # ── Matrix formula C = A · B (centered) ────────────────────────── + tile_w = 0.55 + tile_h = 0.55 + matrix_top = 2.10 + + # C (1 × 4) + c_x = 1.50 + c_y = matrix_top + _draw_matrix_tiles(slide, x=c_x, y=c_y, + n_rows=1, n_cols=4, + tile_w=tile_w, tile_h=tile_h, + fill=RGBColor(0xFE, 0xF3, 0xC7), + border=COL_REG_BORDER, + label_prefix="C", title="C (32 × 128)") + _textbox(slide, c_x, c_y + tile_h + 0.05, 4 * tile_w, 0.22, + "← N (4 tiles) →", + size=9, color=COL_MUTED, align=PP_ALIGN.CENTER) + + # "=" + eq_x = c_x + 4 * tile_w + 0.20 + _textbox(slide, eq_x, c_y - 0.05, 0.40, 0.70, + "=", size=26, bold=True, color=COL_TEXT_DARK, + align=PP_ALIGN.CENTER) + + # A (1 × 2) + a_x = eq_x + 0.55 + a_y = matrix_top + _draw_matrix_tiles(slide, x=a_x, y=a_y, + n_rows=1, n_cols=2, + tile_w=tile_w, tile_h=tile_h, + fill=RGBColor(0xDB, 0xEA, 0xFE), + border=COL_HBM_BORDER, + label_prefix="A", title="A (32 × 128)") + _textbox(slide, a_x, a_y + tile_h + 0.05, 2 * tile_w, 0.22, + "← K (2 tiles) →", + size=9, color=COL_MUTED, align=PP_ALIGN.CENTER) + + # "·" + dot_x = a_x + 2 * tile_w + 0.15 + _textbox(slide, dot_x, c_y - 0.10, 0.30, 0.70, + "·", size=28, bold=True, color=COL_TEXT_DARK, + align=PP_ALIGN.CENTER) + + # B (2 × 4) + b_x = dot_x + 0.45 + b_y = matrix_top - tile_h / 2 # center on the row of A and C + _draw_matrix_tiles(slide, x=b_x, y=b_y, + n_rows=2, n_cols=4, + tile_w=tile_w, tile_h=tile_h, + fill=RGBColor(0xDB, 0xEA, 0xFE), + border=COL_HBM_BORDER, + label_prefix="B", title="B (128 × 128)") + _textbox(slide, b_x, b_y + 2 * tile_h + 0.05, 4 * tile_w, 0.22, + "← N (4 tiles) →", + size=9, color=COL_MUTED, align=PP_ALIGN.CENTER) + _textbox(slide, b_x - 0.45, b_y + tile_h - 0.10, 0.40, 0.30, + "K\n↓", size=9, color=COL_MUTED, align=PP_ALIGN.CENTER) + + # Formula + _textbox(slide, 
0.4, matrix_top + 2 * tile_h + 0.30, 12.6, 0.40, + "Cn = A0·B0n + A1·B1n (K-loop, 2 iters per (m,n) — " + "accumulator stays in RegFile)", + size=13, bold=True, color=COL_TEXT_DARK, + align=PP_ALIGN.CENTER) + + # ── Timeline: 4 (m,n) pairs + 3 inter-pair flushes ────────────── + tl_y = matrix_top + 2 * tile_h + 1.10 + _textbox(slide, 0.4, tl_y - 0.30, 12.6, 0.22, + "Execution timeline — 4 (m,n) outputs, 3 inter-(m,n) flushes:", + size=12, bold=True, color=COL_TEXT_DARK, + align=PP_ALIGN.LEFT) + + pair_w = 1.85 + pair_h = 0.75 + flush_w = 0.65 + total_w = 4 * pair_w + 3 * flush_w + start_x = (SLIDE_W_IN - total_w) / 2 + + pair_x = start_x + for n in range(4): + k0_w = pair_w * 0.45 + k1_w = pair_w * 0.53 + _rrect(slide, pair_x, tl_y, k0_w, pair_h, + RGBColor(0xD1, 0xFA, 0xE5), COL_TCM_BORDER, + "K=0\n(accum)", + size=9, bold=True, color=COL_TCM_BORDER) + _rrect(slide, pair_x + k0_w + 0.02, tl_y, k1_w, pair_h, + RGBColor(0xFE, 0xF3, 0xC7), COL_REG_BORDER, + "K=1 last\nSTORE +\nDMA_W", + size=9, bold=True, color=COL_REG_BORDER) + _textbox(slide, pair_x, tl_y + pair_h + 0.05, pair_w, 0.22, + f"(0,{n}) → C{n}", + size=10, bold=True, color=COL_TEXT_DARK, + align=PP_ALIGN.CENTER) + if n < 3: + flush_x = pair_x + pair_w + 0.05 + _rrect(slide, flush_x, tl_y, flush_w - 0.10, pair_h, + RGBColor(0xFE, 0xE2, 0xE2), COL_RED, + "FLUSH\n(DMA_W\n vs DMA_R)", + size=8, bold=True, color=COL_RED) + pair_x = flush_x + flush_w - 0.05 + else: + pair_x += pair_w + 0.05 + + # ── Bottom note ──────────────────────────────────────────────── + _textbox(slide, 0.4, 6.30, 12.6, 0.70, + "Why flushes hurt: at every (m,n) boundary, DMA_W of Cn " + "competes with DMA_R of the next pair for the cube-shared " + "HBM channel. Inter-flush count = (M_tiles · N_tiles − 1) — " + "for 32×128×128 that's 1·4−1 = 3 flushes. 
Bigger N → more " + "flushes; bigger K alone (with small N) → none.", + size=11, color=COL_MUTED, align=PP_ALIGN.LEFT) + + +_RENDERERS = { + "pipeline_structure": _render_pipeline_structure, + "scheduler": _render_scheduler, + "sequence_32x128x32": _render_sequence_32x128x32, + "sequence_32x128x128": _render_sequence_32x128x128, + "tiling_32x128x32": _render_tiling_32x128x32, + "tiling_32x128x128": _render_tiling_32x128x128, +} + + +# ── Bar-chart slides (read from docs/diagrams/gemm_sweep.json) ───────────── + +import json + +GEMM_SWEEP_JSON = DIAG / "gemm_sweep.json" + + +def _under_tile(M, K, N, tile_M, tile_K, tile_N): + return M < tile_M or K < tile_K or N < tile_N + +STAGE_KEYS = ["DMA_READ", "FETCH", "GEMM", "DMA_WRITE"] +STAGE_DISPLAY = { + "DMA_READ": "DMA in", + "FETCH": "Fetch", + "GEMM": "GEMM", + "MATH": "Math", + "DMA_WRITE": "DMA out", +} +STAGE_COLORS_PPTX = { + "DMA_READ": RGBColor(0x3B, 0x82, 0xF6), + "FETCH": RGBColor(0x10, 0xB9, 0x81), + "GEMM": RGBColor(0xF5, 0x9E, 0x0B), + "MATH": RGBColor(0xEF, 0x44, 0x44), + "DMA_WRITE": RGBColor(0xA8, 0x55, 0xF7), +} +VARIANT_COLORS_PPTX = { + "ref_ref": RGBColor(0x10, 0xB9, 0x81), + "load_ref": RGBColor(0xF5, 0x9E, 0x0B), + "load_load": RGBColor(0xEF, 0x44, 0x44), +} + + +def _shape_label(r: dict) -> str: + if r["M"] == r["K"] == r["N"]: + return f"M=K=N={r['M']}" + return f"M={r['M']}\nK={r['K']}\nN={r['N']}" + + +def _draw_native_bar_chart(slide, *, plot_x, plot_y, plot_w, plot_h, + shape_labels, flagged, tile_counts, + series, colors_map, display_map, + wall_clocks=None, + y_label="ns", + legend_x, legend_w, + foot_note=None, + threshold_line=None, + flagged_bar_color=None, + flagged_series_only=None): + """Render a grouped bar chart natively in PPTX. Linear Y scale. + + series: dict[str -> list[float]] — series_name → value per shape. + threshold_line: if set, draws a dashed horizontal reference line at this y-value. + flagged_bar_color: if set, overrides colors_map[sname] for flagged shapes. 
+ flagged_series_only: if set, the flag override applies only to this series. + """ + n_shapes = len(shape_labels) + if n_shapes == 0: + return + series_names = list(series.keys()) + n_series = len(series_names) + + # Y-axis range: linear, top = max value * 1.10 (include wall-clock). + all_vals = [v for vals in series.values() for v in vals if v > 0] + if wall_clocks: + all_vals.extend([w for w in wall_clocks if w > 0]) + if threshold_line is not None: + all_vals.append(threshold_line) + y_max = max(all_vals) * 1.10 if all_vals else 1.0 + + def y_of(v): + v = max(v, 0.0) + return plot_y + plot_h * (1 - v / y_max) + + # Plot box background + _rect_band(slide, plot_x, plot_y, plot_w, plot_h, + RGBColor(0xFF, 0xFF, 0xFF), RGBColor(0xCB, 0xD5, 0xE1)) + + # Y-axis ticks (6 levels) + for i in range(6): + v = y_max * i / 5 + y = y_of(v) + # gridline + s = slide.shapes.add_connector(1, Inches(plot_x), Inches(y), + Inches(plot_x + plot_w), Inches(y)) + s.line.color.rgb = RGBColor(0xE2, 0xE8, 0xF0) + s.line.width = Pt(0.5) + # label + _textbox(slide, plot_x - 0.85, y - 0.12, 0.75, 0.25, + f"{v:>8.0f}", size=9, color=COL_MUTED, align=PP_ALIGN.RIGHT) + + # Y-axis title + _textbox(slide, plot_x - 0.85, plot_y + plot_h / 2 - 0.15, 0.75, 0.3, + y_label, size=10, color=COL_TEXT_DARK) + + # Threshold reference line (dashed) + if threshold_line is not None: + ty = y_of(threshold_line) + line = slide.shapes.add_connector( + 1, Inches(plot_x), Inches(ty), + Inches(plot_x + plot_w), Inches(ty), + ) + line.line.color.rgb = COL_TEXT_DARK + line.line.width = Pt(1.5) + from pptx.oxml.ns import qn + from lxml import etree + ln = line.line._get_or_add_ln() + pr = ln.find(qn("a:prstDash")) + if pr is None: + pr = etree.SubElement(ln, qn("a:prstDash")) + pr.set("val", "dash") + _textbox(slide, plot_x + plot_w - 0.7, ty - 0.30, 0.7, 0.25, + f"{threshold_line:.0f}% peak", + size=9, bold=True, color=COL_TEXT_DARK, align=PP_ALIGN.RIGHT) + + # Geometry per shape group + group_w = plot_w / (n_shapes 
* 1.4) + bar_w = group_w / max(n_series, 1) + gap = (plot_w - n_shapes * group_w) / (n_shapes + 1) + + y_base = plot_y + plot_h + for i in range(n_shapes): + x_group = plot_x + gap + i * (group_w + gap) + cx = x_group + group_w / 2 + + for j, sname in enumerate(series_names): + v = series[sname][i] + if v <= 0: + continue + bx = x_group + j * bar_w + y_top = y_of(v) + s = slide.shapes.add_shape( + MSO_SHAPE.RECTANGLE, + Inches(bx), Inches(y_top), + Inches(bar_w * 0.85), Inches(y_base - y_top), + ) + s.fill.solid() + apply_flag = (flagged_bar_color is not None and flagged[i] + and (flagged_series_only is None + or sname == flagged_series_only)) + fill_color = flagged_bar_color if apply_flag else colors_map[sname] + s.fill.fore_color.rgb = fill_color + s.line.color.rgb = COL_TEXT_DARK + s.line.width = Pt(0.4) + + # Wall-clock dot + if wall_clocks and wall_clocks[i] > 0: + wy = y_of(wall_clocks[i]) + r = 0.05 + d = slide.shapes.add_shape( + MSO_SHAPE.OVAL, + Inches(cx - r), Inches(wy - r), + Inches(r * 2), Inches(r * 2), + ) + d.fill.solid() + d.fill.fore_color.rgb = COL_TEXT_DARK + d.line.color.rgb = COL_TEXT_DARK + + # Shape label below the group + lab_color = COL_RED if flagged[i] else COL_TEXT_DARK + _textbox(slide, x_group, y_base + 0.05, group_w, 0.7, + shape_labels[i], size=9, bold=flagged[i], color=lab_color) + if tile_counts: + _textbox(slide, x_group, y_base + 0.65, group_w, 0.3, + f"{tile_counts[i]} tiles", + size=8, color=COL_MUTED) + if flagged[i]: + _textbox(slide, x_group, y_base + 0.92, group_w, 0.3, + "↑ under-tile", + size=8, bold=True, color=COL_RED) + + # Legend on the right + _textbox(slide, legend_x, plot_y, legend_w, 0.3, + "Stages (per bar):" if "DMA_READ" in series_names + else "Variants (per bar):", + size=11, bold=True, color=COL_TEXT_DARK, align=PP_ALIGN.LEFT) + ly = plot_y + 0.4 + for sname in series_names: + # color swatch + sw = slide.shapes.add_shape( + MSO_SHAPE.RECTANGLE, + Inches(legend_x), Inches(ly), Inches(0.2), Inches(0.2), + ) + 
sw.fill.solid() + sw.fill.fore_color.rgb = colors_map[sname] + sw.line.color.rgb = COL_TEXT_DARK + sw.line.width = Pt(0.5) + _textbox(slide, legend_x + 0.28, ly - 0.05, legend_w - 0.3, 0.3, + display_map.get(sname, sname), + size=10, color=COL_TEXT_DARK, align=PP_ALIGN.LEFT) + ly += 0.28 + + if wall_clocks: + ly += 0.08 + d = slide.shapes.add_shape( + MSO_SHAPE.OVAL, + Inches(legend_x + 0.05), Inches(ly + 0.04), + Inches(0.12), Inches(0.12), + ) + d.fill.solid() + d.fill.fore_color.rgb = COL_TEXT_DARK + d.line.color.rgb = COL_TEXT_DARK + _textbox(slide, legend_x + 0.28, ly - 0.05, legend_w - 0.3, 0.3, + "kernel wall-clock", + size=10, color=COL_TEXT_DARK, align=PP_ALIGN.LEFT) + + # Extra legend entry for flagged (under-tile) bars + if flagged_bar_color is not None and any(flagged): + ly += 0.30 + sw = slide.shapes.add_shape( + MSO_SHAPE.RECTANGLE, + Inches(legend_x), Inches(ly), Inches(0.2), Inches(0.2), + ) + sw.fill.solid() + sw.fill.fore_color.rgb = flagged_bar_color + sw.line.color.rgb = COL_TEXT_DARK + sw.line.width = Pt(0.5) + flagged_label_target = flagged_series_only or ( + list(display_map.keys())[0] if display_map else "value" + ) + flagged_label = ( + display_map.get(flagged_label_target, flagged_label_target) + + " (under-tile)" + ) + _textbox(slide, legend_x + 0.28, ly - 0.05, legend_w - 0.3, 0.3, + flagged_label, + size=10, color=COL_TEXT_DARK, align=PP_ALIGN.LEFT) + + if foot_note: + _textbox(slide, plot_x, y_base + 1.3, plot_w, 0.4, + foot_note, size=10, color=COL_MUTED, align=PP_ALIGN.LEFT) + + +EXCLUDED_SHAPES = {(512, 512, 512)} + + +def _load_sweep_data() -> dict: + if not GEMM_SWEEP_JSON.exists(): + return {"rows": []} + data = json.loads(GEMM_SWEEP_JSON.read_text()) + data["rows"] = [ + r for r in data.get("rows", []) + if (r["M"], r["K"], r["N"]) not in EXCLUDED_SHAPES + ] + return data + + +def _render_stage_breakdown(slide, variant: str, *, per_tile: bool = False): + """Stage breakdown for one variant. Linear Y. 
If per_tile, divide by tile count. + + Uses wall_ns (interval-union of records on each engine) — the honest + engine-active time. Avoids the double-counting that occupancy_ns has + when multiple ops overlap on a contended resource (HBM_CTRL queue). + """ + data = _load_sweep_data() + rows = [r for r in data["rows"] if r.get("variant") == variant] + if not rows: + _textbox(slide, 0.4, 3.0, 12.6, 1.0, + f"No sweep data found for variant '{variant}'. " + f"Run scripts/gemm_sweep.py first.", + size=14, color=COL_RED, align=PP_ALIGN.LEFT) + return + tile = data["tile_sizes"] + subtitle_unit = "Per-tile" if per_tile else "Per-stage" + _textbox(slide, 0.4, 1.0, 12.6, 0.45, + f"Variant: {variant} | {subtitle_unit} engine wall-clock " + f"(linear) — DMA in / Fetch / GEMM / DMA out per shape. " + f"Tile size {tile['M']}×{tile['K']}×{tile['N']}.", + size=12, color=COL_MUTED, align=PP_ALIGN.LEFT) + + shape_labels = [_shape_label(r) for r in rows] + flagged = [_under_tile(r["M"], r["K"], r["N"], + tile["M"], tile["K"], tile["N"]) for r in rows] + tile_counts = [r["tile_count_expected"] for r in rows] + + def _val(r, s): + v = r.get("stages", {}).get(s, {}).get("wall_ns", 0.0) + if per_tile: + tc = r.get("tile_count_expected", 0) or 1 + return v / tc + return v + + series = {s: [_val(r, s) for r in rows] for s in STAGE_KEYS} + foot_note = ( + "Bars = engine wall-clock ÷ tile count (amortized per-tile cost). " + "Falls with tile count as the pipeline fills." + if per_tile else + "Bars = engine wall-clock interval (max t_end − min t_start, " + "merged overlaps). Strips queue-wait double-counting." 
+ ) + _draw_native_bar_chart( + slide, + plot_x=1.0, plot_y=1.65, plot_w=10.0, plot_h=4.45, + shape_labels=shape_labels, flagged=flagged, + tile_counts=tile_counts, + series=series, colors_map=STAGE_COLORS_PPTX, + display_map=STAGE_DISPLAY, + wall_clocks=None, + y_label="ns/tile" if per_tile else "ns", + legend_x=11.4, legend_w=1.85, + foot_note=foot_note, + ) + + +def _render_stage_breakdown_ref_ref(slide): + _render_stage_breakdown(slide, "ref_ref") + + +def _render_stage_breakdown_load_ref(slide): + _render_stage_breakdown(slide, "load_ref") + + +def _render_stage_breakdown_load_load(slide): + _render_stage_breakdown(slide, "load_load") + + +def _render_variant_comparison(slide): + """Wall-clock per shape per variant (3 bars per shape).""" + data = _load_sweep_data() + rows = data["rows"] + if not rows: + _textbox(slide, 0.4, 3.0, 12.6, 1.0, + "No sweep data. Run scripts/gemm_sweep.py first.", + size=14, color=COL_RED, align=PP_ALIGN.LEFT) + return + tile = data["tile_sizes"] + variants = data.get("variants", ["ref_ref", "load_ref", "load_load"]) + + # Group by shape (preserve first-seen order). + by_shape: dict = {} + for r in rows: + key = (r["M"], r["K"], r["N"]) + by_shape.setdefault(key, {})[r["variant"]] = r + shapes = list(by_shape.keys()) + + sample = next(iter(by_shape.values()))[next(iter(by_shape[shapes[0]]))] + sample_label = _shape_label(sample) + _ = sample_label # silence unused warning + + shape_labels = [_shape_label(by_shape[k][next(iter(by_shape[k]))]) + for k in shapes] + flagged = [_under_tile(k[0], k[1], k[2], tile["M"], tile["K"], tile["N"]) + for k in shapes] + tile_counts = [by_shape[k][next(iter(by_shape[k]))]["tile_count_expected"] + for k in shapes] + series = { + v: [(by_shape[k].get(v) or {}).get("pe_window_ns", 0.0) for k in shapes] + for v in variants + } + display_map = {v: v for v in variants} + + _textbox(slide, 0.4, 1.0, 12.6, 0.4, + "Kernel wall-clock per variant per shape (linear). 
" + "ref_ref baseline, load_ref pins A, load_load pins both A and B.", + size=12, color=COL_MUTED, align=PP_ALIGN.LEFT) + + _draw_native_bar_chart( + slide, + plot_x=1.0, plot_y=1.6, plot_w=10.0, plot_h=4.5, + shape_labels=shape_labels, flagged=flagged, + tile_counts=tile_counts, + series=series, colors_map=VARIANT_COLORS_PPTX, + display_map=display_map, + wall_clocks=None, + y_label="wall ns", + legend_x=11.4, legend_w=1.85, + foot_note=("After Phase 2 fix (gated STORE/DMA_WRITE + pinned operand " + "skip): load_ref / load_load are faster than ref_ref."), + ) + + +def _render_hbm_topology(slide): + """Show cube-shared HBM_CTRL path: PE → router → HBM_CTRL → HBM.""" + _textbox(slide, 0.4, 1.0, 12.6, 0.55, + "DMA reads cross the cube fabric — HBM_CTRL is one per cube " + "(NOT per PE). All 8 PEs serialize at the controller's single " + "channel resource. Even one active PE pays the round-trip on " + "every K-tile miss.", + size=13, color=COL_MUTED, align=PP_ALIGN.LEFT) + + # Cube outline (encloses 8 PEs + their TCMs + the router mesh) + cube_x, cube_y, cube_w, cube_h = 0.4, 1.75, 6.4, 4.45 + s = slide.shapes.add_shape( + MSO_SHAPE.RECTANGLE, + Inches(cube_x), Inches(cube_y), Inches(cube_w), Inches(cube_h), + ) + s.fill.background() + s.line.color.rgb = COL_MUTED + s.line.width = Pt(1.5) + _textbox(slide, cube_x + 0.1, cube_y + 0.04, 2.5, 0.3, + "CUBE (8 PEs share HBM)", size=11, bold=True, color=COL_MUTED, + align=PP_ALIGN.LEFT) + + # 8 PEs in a 4-cols × 2-rows grid, each with TCM beneath + pe_w = 1.05 + pe_h = 0.62 + tcm_h = 0.42 + col_gap = 0.18 + row_gap = 0.30 + grid_x0 = cube_x + 0.25 + grid_y0 = cube_y + 0.50 + row_h = pe_h + tcm_h + row_gap + pe_right_y = [] # y-center of each PE for the link arrows + for i in range(8): + row = i // 4 + col = i % 4 + px = grid_x0 + col * (pe_w + col_gap) + py = grid_y0 + row * row_h + _rrect(slide, px, py, pe_w, pe_h, + RGBColor(0xFE, 0xF3, 0xC7), COL_REG_BORDER, + f"PE{i}", size=11, bold=True, color=COL_TEXT_DARK) + 
_rrect(slide, px, py + pe_h + 0.05, pe_w, tcm_h, + COL_TCM_BG, COL_TCM_BORDER, + "TCM (local)\n512 GB/s", + size=8, color=COL_TCM_BORDER) + if col == 3: + pe_right_y.append((px + pe_w, py + pe_h / 2)) + + # Router mesh strip on the right edge of the cube + router_x = cube_x + cube_w - 0.85 + router_y = cube_y + 0.50 + router_w = 0.65 + router_h = cube_h - 0.65 + _rrect(slide, router_x, router_y, router_w, router_h, + RGBColor(0xDB, 0xEA, 0xFE), COL_DMA, + "ROUTER\nMESH\n\n256 GB/s\nper link", + size=10, bold=True, color=COL_DMA) + + # Arrows from each row's last PE → router strip + for (rx, ry) in pe_right_y: + _arrow(slide, rx + 0.02, ry, router_x, ry, + color=COL_DMA, width_pt=1.4) + + # HBM_CTRL just outside the cube on the right + ctrl_x = cube_x + cube_w + 0.45 + ctrl_w = 1.85 + ctrl_h = 1.55 + ctrl_y = cube_y + cube_h / 2 - ctrl_h / 2 + _rrect(slide, ctrl_x, ctrl_y, ctrl_w, ctrl_h, + RGBColor(0xFE, 0xE2, 0xE2), COL_RED, + "HBM_CTRL\n(1 per cube)\n\nread channel\ncap = 1\n", + size=11, bold=True, color=COL_RED) + + # Bottleneck label + _textbox(slide, ctrl_x - 0.1, ctrl_y + ctrl_h + 0.05, ctrl_w + 0.2, 0.35, + "BOTTLENECK", size=12, bold=True, color=COL_RED, + align=PP_ALIGN.CENTER) + + # Router strip → HBM_CTRL arrow + link_y = cube_y + cube_h / 2 + _arrow(slide, router_x + router_w, link_y, + ctrl_x, link_y, color=COL_DMA, width_pt=3.0) + _textbox(slide, router_x + router_w + 0.02, + link_y - 0.36, ctrl_x - (router_x + router_w) - 0.05, 0.3, + "256 GB/s", size=10, bold=True, color=COL_DMA, + align=PP_ALIGN.CENTER) + + # HBM banks on the far right + hbm_x = ctrl_x + ctrl_w + 0.45 + hbm_w = 1.85 + hbm_h = ctrl_h + 0.35 + hbm_y = ctrl_y - 0.175 + _rrect(slide, hbm_x, hbm_y, hbm_w, hbm_h, + COL_HBM_BG, COL_HBM_BORDER, + "HBM BANKS\n(per-cube)\n\n256 GB/s\naggregated", + size=11, bold=True, color=COL_HBM_BORDER) + _arrow(slide, ctrl_x + ctrl_w, link_y, + hbm_x, link_y, color=COL_DMA, width_pt=3.0) + + # Side-by-side key takeaways at the bottom + _textbox(slide, 
0.4, 6.40, 6.3, 0.55, + "TCM is per-PE local → fetch/store don't contend.\n" + "HBM_CTRL is cube-shared → every DMA serializes on cap=1 channel.", + size=11, color=COL_TEXT_DARK, align=PP_ALIGN.LEFT) + _textbox(slide, 6.9, 6.40, 6.2, 0.55, + "Per-op DMA cost grows with #in-flight ops even on 1 PE.\n" + "load_ref/load_load pin operands → 1 HBM trip instead of per K-tile.", + size=11, color=COL_RED, align=PP_ALIGN.LEFT) + + +def _render_per_op_dma(slide): + """Per-op DMA_READ cost = wall_ns / record_count. + + wall_ns is interval-union of all DMA_READ records ≈ (max t_end - min t_start) + when ops overlap. Dividing by count gives the amortized per-op cost in the + DMA window — converges to the bandwidth-bound floor. + + load_load is excluded — its eager up-front DMAs sit outside the composite + plan so their stage_type isn't DMA_READ and they don't appear here. + """ + data = _load_sweep_data() + rows = data["rows"] + if not rows: + _textbox(slide, 0.4, 3.0, 12.6, 1.0, + "No sweep data. Run scripts/gemm_sweep.py first.", + size=14, color=COL_RED, align=PP_ALIGN.LEFT) + return + tile = data["tile_sizes"] + variants = ["ref_ref", "load_ref"] + + by_shape: dict = {} + for r in rows: + key = (r["M"], r["K"], r["N"]) + by_shape.setdefault(key, {})[r["variant"]] = r + shapes = list(by_shape.keys()) + + shape_labels = [_shape_label(by_shape[k][next(iter(by_shape[k]))]) + for k in shapes] + flagged = [_under_tile(k[0], k[1], k[2], tile["M"], tile["K"], tile["N"]) + for k in shapes] + tile_counts = [by_shape[k][next(iter(by_shape[k]))]["tile_count_expected"] + for k in shapes] + + def _ns_per_op(r): + s = r.get("stages", {}).get("DMA_READ", {}) + cnt = s.get("record_count", 0) + wall = s.get("wall_ns", 0.0) + return (wall / cnt) if cnt else 0.0 + + series = { + v: [_ns_per_op(by_shape[k].get(v) or {"stages": {}}) for k in shapes] + for v in variants + } + display_map = {v: v for v in variants} + + _textbox(slide, 0.4, 1.0, 12.6, 0.55, + "Amortized per-op cost = (DMA window 
wall-clock) ÷ (#DMA ops). " + "Strips out queue-wait double-counting: when N ops overlap, the " + "window is N·drain_ns, so the average per op = drain_ns ≈ 16 ns " + "(bandwidth-bound floor at 4 KB ÷ 256 GB/s).", + size=12, color=COL_MUTED, align=PP_ALIGN.LEFT) + + _draw_native_bar_chart( + slide, + plot_x=1.0, plot_y=1.7, plot_w=10.0, plot_h=4.4, + shape_labels=shape_labels, flagged=flagged, + tile_counts=tile_counts, + series=series, colors_map=VARIANT_COLORS_PPTX, + display_map=display_map, + wall_clocks=None, + y_label="ns / op", + legend_x=11.4, legend_w=1.85, + foot_note=("Flat ~16-20 ns across shapes confirms the per-op " + "transfer is constant — what looked like growing per-op " + "cost on slide before was queue wait being absorbed."), + ) + + +def _render_mac_utilization(slide): + """GEMM util (shape fill) AND Useful pipeline eff (computed from formula). + + Useful eff = pipe_eff × GEMM_util, where + pipe_eff = (N_tiles × T_stage) / (head + N_tiles × T_stage + inter DMA_W) + """ + data = _load_sweep_data() + rows = data["rows"] + if not rows: + _textbox(slide, 0.4, 3.0, 12.6, 1.0, + "No sweep data. 
Run scripts/gemm_sweep.py first.", + size=14, color=COL_RED, align=PP_ALIGN.LEFT) + return + tile = data["tile_sizes"] + TILE_M, TILE_K, TILE_N = tile["M"], tile["K"], tile["N"] + tile_flops = 2 * TILE_M * TILE_K * TILE_N + + HBM_GBS = 256.0 + bpe = 2 + T_STAGE = 16.0 + D_STAGES = 3 + head_ns = (D_STAGES - 1) * T_STAGE + dma_w_per_pair_ns = (TILE_M * TILE_N * bpe) / HBM_GBS + + by_shape: dict = {} + for r in rows: + if r["variant"] != "load_ref": + continue + by_shape[(r["M"], r["K"], r["N"])] = r + shapes = list(by_shape.keys()) + + shape_labels = [_shape_label(by_shape[k]) for k in shapes] + flagged = [_under_tile(k[0], k[1], k[2], TILE_M, TILE_K, TILE_N) + for k in shapes] + tile_counts = [by_shape[k]["tile_count_expected"] for k in shapes] + + gemm_util = [] + useful_eff = [] + for k in shapes: + r = by_shape[k] + M, K, N = r["M"], r["K"], r["N"] + useful = 2 * M * K * N + tiles = r["tile_count_expected"] + gu = useful / (tile_flops * tiles) * 100 + gemm_util.append(gu) + + m_tiles = (M + TILE_M - 1) // TILE_M + n_tiles = (N + TILE_N - 1) // TILE_N + n_mn = m_tiles * n_tiles + gemm_total = tiles * T_STAGE + inter_dma_w = max(0, n_mn - 1) * dma_w_per_pair_ns + wall = head_ns + gemm_total + inter_dma_w + ueff = (gemm_total * (gu / 100.0) / wall) * 100 if wall > 0 else 0.0 + useful_eff.append(ueff) + + series = { + "GEMM util": gemm_util, + "Useful eff": useful_eff, + } + colors_map = { + "GEMM util": COL_FS, # emerald + "Useful eff": RGBColor(0xF5, 0x9E, 0x0B), # amber + } + display_map = { + "GEMM util": "GEMM util %", + "Useful eff": "Useful eff %", + } + + _textbox(slide, 0.4, 1.0, 12.6, 0.70, + f"GEMM util = useful FLOPs ÷ (tile FLOPs × tile count) — pure " + f"shape-vs-tile metric. 
" + f"Useful eff = (N_tiles × T_stage × GEMM_util) ÷ wall — " + f"useful FLOPs delivered as a fraction of peak over the " + f"ideal-pipelined wall (head + K-loop + inter-(m,n) DMA_W).", + size=11, color=COL_MUTED, align=PP_ALIGN.LEFT) + + _draw_native_bar_chart( + slide, + plot_x=1.0, plot_y=1.75, plot_w=10.0, plot_h=4.35, + shape_labels=shape_labels, flagged=flagged, + tile_counts=tile_counts, + series=series, colors_map=colors_map, + display_map=display_map, + wall_clocks=None, + y_label="%", + legend_x=11.4, legend_w=1.85, + foot_note=("GEMM util < 100% → shape mismatch (padded zeros). " + "Useful eff < GEMM util → pipeline overhead " + "(head + inter-(m,n) DMA_W) eats more of the wall."), + threshold_line=100.0, + flagged_bar_color=COL_RED, + flagged_series_only="GEMM util", + ) + + +def _render_mac_utilization_ref_ref(slide): + """Same metric as slide 14 but for the ref_ref variant. + + In ref_ref both A and B are loaded from HBM by the scheduler — that's + TWO back-to-back DMA_R per tile, so the DMA stage takes 2 × T_stage = + 32 ns/tile while FETCH/GEMM are still 16 ns/tile. The pipeline is + DMA-bound — steady-state cycle = 32 ns/tile — so useful pipeline + efficiency caps near 50 % × GEMM_util. + """ + data = _load_sweep_data() + rows = data["rows"] + if not rows: + _textbox(slide, 0.4, 3.0, 12.6, 1.0, + "No sweep data. 
Run scripts/gemm_sweep.py first.", + size=14, color=COL_RED, align=PP_ALIGN.LEFT) + return + tile = data["tile_sizes"] + TILE_M, TILE_K, TILE_N = tile["M"], tile["K"], tile["N"] + tile_flops = 2 * TILE_M * TILE_K * TILE_N + + HBM_GBS = 256.0 + bpe = 2 + T_STAGE_COMPUTE = 16.0 # FETCH = GEMM = 16 ns/tile + T_STAGE_DMA_REF_REF = 2 * 16.0 # 2 DMA_R per tile (A + B) + T_STAGE = T_STAGE_DMA_REF_REF # DMA-bound steady-state cycle + D_STAGES = 3 # DMA, FETCH, GEMM + head_ns = (D_STAGES - 1) * T_STAGE_COMPUTE # pipeline fill = 32 ns + dma_w_per_pair_ns = (TILE_M * TILE_N * bpe) / HBM_GBS + + by_shape: dict = {} + for r in rows: + if r["variant"] != "ref_ref": + continue + by_shape[(r["M"], r["K"], r["N"])] = r + shapes = list(by_shape.keys()) + + shape_labels = [_shape_label(by_shape[k]) for k in shapes] + flagged = [_under_tile(k[0], k[1], k[2], TILE_M, TILE_K, TILE_N) + for k in shapes] + tile_counts = [by_shape[k]["tile_count_expected"] for k in shapes] + + gemm_util = [] + useful_eff = [] + for k in shapes: + r = by_shape[k] + M, K, N = r["M"], r["K"], r["N"] + useful = 2 * M * K * N + tiles = r["tile_count_expected"] + gu = useful / (tile_flops * tiles) * 100 + gemm_util.append(gu) + + m_tiles = (M + TILE_M - 1) // TILE_M + n_tiles = (N + TILE_N - 1) // TILE_N + n_mn = m_tiles * n_tiles + # Useful compute time: each tile delivers T_STAGE_COMPUTE worth + # of MAC. Steady-state pipeline cycle is DMA-bound (32 ns/tile). 
+ compute_total = tiles * T_STAGE_COMPUTE + wall_steady = tiles * T_STAGE + inter_dma_w = max(0, n_mn - 1) * dma_w_per_pair_ns + wall = head_ns + wall_steady + inter_dma_w + ueff = (compute_total * (gu / 100.0) / wall) * 100 \ + if wall > 0 else 0.0 + useful_eff.append(ueff) + + series = { + "GEMM util": gemm_util, + "Useful eff": useful_eff, + } + colors_map = { + "GEMM util": COL_FS, + "Useful eff": RGBColor(0xF5, 0x9E, 0x0B), + } + display_map = { + "GEMM util": "GEMM util %", + "Useful eff": "Useful eff % (ref_ref)", + } + + _textbox(slide, 0.4, 1.0, 12.6, 0.75, + "ref_ref: scheduler issues DMA_R for BOTH A and B every tile. " + "Per-tile DMA cost = 2 × T_stage = 32 ns; FETCH and GEMM stay " + "at 16 ns. Pipeline cycle is DMA-bound → useful eff caps near " + "50 % × GEMM_util, regardless of K-loop length.", + size=11, color=COL_MUTED, align=PP_ALIGN.LEFT) + + _draw_native_bar_chart( + slide, + plot_x=1.0, plot_y=1.85, plot_w=10.0, plot_h=4.25, + shape_labels=shape_labels, flagged=flagged, + tile_counts=tile_counts, + series=series, colors_map=colors_map, + display_map=display_map, + wall_clocks=None, + y_label="%", + legend_x=11.4, legend_w=1.85, + foot_note=("Compare to load_ref (prev slide): tl.load pins A in " + "TCM once, eliminating the per-tile A DMA_R → DMA stage " + "halves to 16 ns/tile → useful eff roughly doubles at " + "the same GEMM util."), + threshold_line=100.0, + flagged_bar_color=COL_RED, + flagged_series_only="GEMM util", + ) + + +def _render_tflops_table(slide): + """Ideal pipelined pipe_eff: assumes non-blocking tl.load + multi-channel HBM. + + Three-stage pipeline (DMA_R → FETCH → GEMM), all stages bandwidth-balanced + at T_stage = 16 ns/tile. Wall = pipeline fill + steady-state K-loop + + inter-(m,n) DMA_W (final flush excluded — tail, not in pipeline). + """ + data = _load_sweep_data() + rows = data["rows"] + if not rows: + _textbox(slide, 0.4, 3.0, 12.6, 1.0, + "No sweep data. 
Run scripts/gemm_sweep.py first.", + size=14, color=COL_RED, align=PP_ALIGN.LEFT) + return + tile = data["tile_sizes"] + TILE_M, TILE_K, TILE_N = tile["M"], tile["K"], tile["N"] + + HBM_GBS = 256.0 # bytes/ns + bpe = 2 # f16 + T_STAGE = 16.0 # ns per pipeline stage (all stages BW-balanced) + D_STAGES = 3 # DMA_R, FETCH, GEMM + head_ns = (D_STAGES - 1) * T_STAGE # pipeline fill: (D-1) × T_stage = 32 ns + dma_w_per_pair_ns = (TILE_M * TILE_N * bpe) / HBM_GBS # = 8 ns + + by_shape: dict = {} + for r in rows: + by_shape.setdefault((r["M"], r["K"], r["N"]), {})[r["variant"]] = r + shapes = list(by_shape.keys()) + + _textbox(slide, 0.4, 1.0, 12.6, 0.85, + f"Ideal pipelined model — assumes non-blocking tl.load + " + f"multi-channel HBM so DMA, FETCH and GEMM all run at " + f"T_stage = {T_STAGE:.0f} ns/tile.\n" + f" wall = head_latency + N_tiles × T_stage + Σ inter-(m,n) DMA_W\n" + f" head_latency = (D−1) × T_stage = {head_ns:.0f} ns " + f"(pipeline fill, D = {D_STAGES} stages).\n" + f" inter DMA_W = (N_mn − 1) × {dma_w_per_pair_ns:.0f} ns " + f"(final flush is tail — excluded).\n" + f" useful eff = (N_tiles × T_stage × GEMM_util) / wall — " + f"MAC time producing real output, not padded zeros.", + size=11, color=COL_MUTED, align=PP_ALIGN.LEFT) + + headers = ["Input shape", "GEMM util %", "Useful eff %"] + n_rows = len(shapes) + 1 + n_cols = len(headers) + table_x = 2.5 + table_y = 2.10 + table_w = 8.4 + table_h = 4.30 + + tbl_shape = slide.shapes.add_table( + n_rows, n_cols, + Inches(table_x), Inches(table_y), Inches(table_w), Inches(table_h), + ) + tbl = tbl_shape.table + + widths_in = [3.0, 2.5, 2.9] + for ci, w in enumerate(widths_in): + tbl.columns[ci].width = Inches(w) + + def _set_cell(ci, ri, text, *, bold=False, color=COL_TEXT_DARK, + fill=None, size=12, align=PP_ALIGN.CENTER): + cell = tbl.cell(ri, ci) + if fill is not None: + cell.fill.solid() + cell.fill.fore_color.rgb = fill + tf = cell.text_frame + tf.margin_left = Emu(36000) + tf.margin_right = Emu(36000) + 
tf.margin_top = Emu(18000) + tf.margin_bottom = Emu(18000) + tf.word_wrap = True + p = tf.paragraphs[0] + p.alignment = align + p.text = "" + run = p.add_run() + run.text = text + run.font.size = Pt(size) + run.font.bold = bold + run.font.name = "Consolas" + run.font.color.rgb = color + + for ci, h in enumerate(headers): + _set_cell(ci, 0, h, bold=True, color=COL_TEXT_LIGHT, + fill=RGBColor(0x10, 0x2A, 0x55), size=12) + + tile_flops = 2 * TILE_M * TILE_K * TILE_N + for ri, k in enumerate(shapes, start=1): + M, K, N = k + useful = 2 * M * K * N + any_row = next(iter(by_shape[k].values())) + tiles = any_row["tile_count_expected"] + gemm_util = useful / (tile_flops * tiles) * 100 + is_under = _under_tile(M, K, N, TILE_M, TILE_K, TILE_N) + + m_tiles = (M + TILE_M - 1) // TILE_M + n_tiles = (N + TILE_N - 1) // TILE_N + n_mn = m_tiles * n_tiles + + gemm_total_ns = tiles * T_STAGE + inter_dma_w_ns = max(0, n_mn - 1) * dma_w_per_pair_ns + wall_ns = head_ns + gemm_total_ns + inter_dma_w_ns + # Useful eff = pipe_eff × GEMM_util (shape-waste included). + pipe_eff = ((gemm_total_ns * (gemm_util / 100.0)) / wall_ns) * 100 \ + if wall_ns > 0 else 0.0 + + row_fill = (RGBColor(0xFE, 0xF2, 0xF2) if is_under + else RGBColor(0xF8, 0xFA, 0xFC)) + text_color = COL_RED if is_under else COL_TEXT_DARK + + shape_txt = f"{M}×{K}×{N}" + (" (under-tile)" if is_under else "") + _set_cell(0, ri, shape_txt, bold=is_under, color=text_color, + fill=row_fill, align=PP_ALIGN.LEFT) + _set_cell(1, ri, f"{gemm_util:.1f} %", bold=is_under, + color=text_color, fill=row_fill) + _set_cell(2, ri, f"{pipe_eff:.1f} %", fill=row_fill, + color=COL_TEXT_DARK, bold=True) + + _textbox(slide, 0.4, 6.50, 12.6, 0.6, + "What the model could achieve with non-blocking tl.load + " + "multi-channel HBM (A streams tile-by-tile, no upfront serial). 
" + "Today's simulator caps at ~50 % for tall-skinny K because tl.load " + "is whole-operand blocking and the cube has a single HBM channel.", + size=11, color=COL_MUTED, align=PP_ALIGN.LEFT) + + +def _render_pipeline_eff_walkthrough(slide): + """Visualize how Useful eff is calculated for one example shape. + + Example: 32×128×128 (8 K-tiles in 4 (m,n) groups, head + 3 inter DMA_W). + Shows the 3-stage pipeline structure, the time-line of all GEMM tiles + + overheads, and evaluates the formula numerically. + """ + # Example parameters (kept hardcoded — this slide is illustrative) + M, K, N = 32, 128, 128 + T_STAGE = 16 + D_STAGES = 3 # DMA_R, FETCH, GEMM + K_TILES = 2 # K=128 → K_tiles = 2 + N_MN = 4 # M_tile × N_tile = 1 × 4 + TOTAL_TILES = K_TILES * N_MN # 8 + HEAD_NS = (D_STAGES - 1) * T_STAGE # 32 + DMA_W_NS = 8 # 32×32×2 / 256 = 8 ns + INTER_FLUSH_COUNT = N_MN - 1 # 3 flushes (final excluded) + GEMM_TOTAL = TOTAL_TILES * T_STAGE # 128 + INTER_DMA_W_TOTAL = INTER_FLUSH_COUNT * DMA_W_NS # 24 + WALL = HEAD_NS + GEMM_TOTAL + INTER_DMA_W_TOTAL # 184 + USEFUL_EFF = GEMM_TOTAL / WALL * 100 # 69.6% + + # Colors + C_HEAD = RGBColor(0x94, 0xA3, 0xB8) # slate gray + C_GEMM = RGBColor(0x10, 0xB9, 0x81) # emerald (useful) + C_FLUSH = RGBColor(0xF5, 0x9E, 0x0B) # amber (inter-(m,n)) + C_TAIL = RGBColor(0xFC, 0xA5, 0xA5) # light red (excluded) + C_DMA_R = RGBColor(0x3B, 0x82, 0xF6) # blue + C_FETCH = RGBColor(0x10, 0xB9, 0x81) + C_GEMM_BOX = RGBColor(0xF5, 0x9E, 0x0B) + + _textbox(slide, 0.4, 1.0, 12.6, 0.5, + f"Example: M=N=32, K=128, N-output=128 → {TOTAL_TILES} tiles in " + f"{N_MN} output groups of {K_TILES} K-tiles each. 
T_stage = " + f"{T_STAGE} ns/tile, pipeline depth D = {D_STAGES}.", + size=12, color=COL_MUTED, align=PP_ALIGN.LEFT) + + # ── Top: 3-stage pipeline structure diagram ───────────────────── + pd_y = 1.7 + pd_h = 0.55 + box_w = 1.4 + box_gap = 0.35 + pd_total_w = 3 * box_w + 2 * box_gap + pd_x0 = (SLIDE_W_IN - pd_total_w) / 2 + stages = [("DMA_R\nHBM → TCM", C_DMA_R, COL_TEXT_LIGHT), + ("FETCH\nTCM → Reg", C_FETCH, COL_TEXT_LIGHT), + ("GEMM\nMAC array", C_GEMM_BOX, COL_TEXT_DARK)] + for i, (lbl, fill, tc) in enumerate(stages): + x = pd_x0 + i * (box_w + box_gap) + _rrect(slide, x, pd_y, box_w, pd_h, fill, COL_TEXT_DARK, + lbl, size=10, bold=True, color=tc) + if i < 2: + ax1 = x + box_w + 0.03 + ax2 = x + box_w + box_gap - 0.03 + _arrow(slide, ax1, pd_y + pd_h / 2, ax2, pd_y + pd_h / 2, + color=COL_MUTED, width_pt=1.8) + _textbox(slide, pd_x0, pd_y + pd_h + 0.05, pd_total_w, 0.28, + f"each stage = {T_STAGE} ns/tile → " + f"head latency = (D−1) × T_stage = {HEAD_NS} ns", + size=10, color=COL_MUTED, align=PP_ALIGN.CENTER) + + # ── Middle: timeline of the full kernel ───────────────────────── + # Layout: head + 4 K-loops separated by 3 DMA_W + 1 tail (excluded) + tl_y = 3.4 + tl_h = 0.85 + tl_label_y = tl_y - 0.3 + tl_legend_y = tl_y + tl_h + 0.18 + + margin = 0.8 + tl_total_ns = WALL + DMA_W_NS # include tail visually + tl_w_total = SLIDE_W_IN - 2 * margin + px_per_ns = tl_w_total / tl_total_ns + + cur_x = margin + # Head + w = HEAD_NS * px_per_ns + _rect_band(slide, cur_x, tl_y, w, tl_h, C_HEAD, COL_TEXT_DARK) + _textbox(slide, cur_x, tl_y + tl_h / 2 - 0.13, w, 0.3, + f"head\n{HEAD_NS} ns", + size=9, bold=True, color=COL_TEXT_LIGHT) + cur_x += w + + # K-loops + inter DMA_W + for g in range(N_MN): + # K-loop: K_TILES GEMM blocks + for t in range(K_TILES): + w = T_STAGE * px_per_ns + _rect_band(slide, cur_x, tl_y, w, tl_h, C_GEMM, COL_TEXT_DARK) + tile_no = g * K_TILES + t + 1 + _textbox(slide, cur_x, tl_y + tl_h / 2 - 0.13, w, 0.3, + f"GEMM\nT{tile_no}", + size=9, 
bold=True, color=COL_TEXT_LIGHT) + cur_x += w + # Mark (m,n) group label above + group_start_x = cur_x - K_TILES * T_STAGE * px_per_ns + group_w = K_TILES * T_STAGE * px_per_ns + _textbox(slide, group_start_x, tl_label_y, + group_w, 0.25, + f"(m=0,n={g}) K-loop", + size=9, bold=True, color=COL_TEXT_DARK) + if g < N_MN - 1: + w = DMA_W_NS * px_per_ns + _rect_band(slide, cur_x, tl_y, w, tl_h, C_FLUSH, COL_TEXT_DARK) + _textbox(slide, cur_x, tl_y + tl_h / 2 - 0.13, w, 0.3, + "flush", size=8, bold=True, color=COL_TEXT_DARK) + cur_x += w + + # Tail flush (excluded) + w = DMA_W_NS * px_per_ns + _rect_band(slide, cur_x, tl_y, w, tl_h, C_TAIL, COL_RED) + _textbox(slide, cur_x, tl_y + tl_h / 2 - 0.13, w, 0.3, + "tail", size=8, bold=True, color=COL_RED) + _textbox(slide, cur_x - 0.3, tl_y + tl_h + 0.02, w + 0.6, 0.28, + "(excluded\nfrom wall)", + size=8, bold=True, color=COL_RED) + + # Wall span indicator below the timeline + wall_end_x = margin + WALL * px_per_ns + _arrow(slide, margin, tl_y + tl_h + 0.55, + wall_end_x, tl_y + tl_h + 0.55, color=COL_TEXT_DARK, width_pt=1.5) + _arrow(slide, wall_end_x, tl_y + tl_h + 0.55, + margin, tl_y + tl_h + 0.55, color=COL_TEXT_DARK, width_pt=1.5) + _textbox(slide, margin, tl_y + tl_h + 0.6, + wall_end_x - margin, 0.3, + f"wall = {WALL} ns", + size=11, bold=True, color=COL_TEXT_DARK, align=PP_ALIGN.CENTER) + + # ── Bottom: numerical evaluation ─────────────────────────────── + formula_y = 5.7 + _textbox(slide, 0.4, formula_y, 12.6, 0.35, + f"wall = head + N_tiles × T_stage + (N_mn − 1) × T_dma_w = " + f"{HEAD_NS} + {TOTAL_TILES}×{T_STAGE} + {INTER_FLUSH_COUNT}×{DMA_W_NS} " + f"= {WALL} ns", + size=12, color=COL_TEXT_DARK, align=PP_ALIGN.LEFT) + _textbox(slide, 0.4, formula_y + 0.35, 12.6, 0.35, + f"GEMM useful time = N_tiles × T_stage × GEMM_util = " + f"{TOTAL_TILES}×{T_STAGE}×100 % = {GEMM_TOTAL} ns", + size=12, color=COL_TEXT_DARK, align=PP_ALIGN.LEFT) + _textbox(slide, 0.4, formula_y + 0.7, 12.6, 0.45, + f"Useful efficiency = 
{GEMM_TOTAL} / {WALL} = {USEFUL_EFF:.1f} %", + size=14, bold=True, color=COL_TEXT_DARK, align=PP_ALIGN.LEFT) + _textbox(slide, 0.4, formula_y + 1.15, 12.6, 0.35, + "Overhead = head (pipeline fill) + inter-(m,n) flushes. " + "Bigger K (more amortization) and smaller N (fewer groups) " + "both raise the efficiency.", + size=10, color=COL_MUTED, align=PP_ALIGN.LEFT) + + +def _render_pipeline_eff_walkthrough_largeK(slide): + """Walkthrough for the tall-skinny case: M=N=32, K=3072. + + N_mn = 1 (one output group) → zero inter-(m,n) DMA_W flushes. + K-loop has 48 tiles — compressed visually (show first 4 + ... + last 2). + """ + M, K, N = 32, 3072, 32 + T_STAGE = 16 + D_STAGES = 3 + K_TILES = 48 + N_MN = 1 + TOTAL_TILES = K_TILES + HEAD_NS = (D_STAGES - 1) * T_STAGE + DMA_W_NS = 8 + GEMM_TOTAL = TOTAL_TILES * T_STAGE + INTER_DMA_W_TOTAL = 0 + WALL = HEAD_NS + GEMM_TOTAL + INTER_DMA_W_TOTAL + USEFUL_EFF = GEMM_TOTAL / WALL * 100 + + C_HEAD = RGBColor(0x94, 0xA3, 0xB8) + C_GEMM = RGBColor(0x10, 0xB9, 0x81) + C_TAIL = RGBColor(0xFC, 0xA5, 0xA5) + C_DMA_R = RGBColor(0x3B, 0x82, 0xF6) + C_FETCH = RGBColor(0x10, 0xB9, 0x81) + C_GEMM_BOX = RGBColor(0xF5, 0x9E, 0x0B) + + _textbox(slide, 0.4, 1.0, 12.6, 0.5, + f"Example: M=N=32, K=3072 → {TOTAL_TILES} tiles, " + f"N_mn = M_tiles × N_tiles = 1 × 1 = 1 → " + f"no inter-(m,n) DMA_W flushes. 
Long K-loop amortizes the head.", + size=12, color=COL_MUTED, align=PP_ALIGN.LEFT) + + # 3-stage pipeline structure (same as before) + pd_y = 1.7 + pd_h = 0.55 + box_w = 1.4 + box_gap = 0.35 + pd_total_w = 3 * box_w + 2 * box_gap + pd_x0 = (SLIDE_W_IN - pd_total_w) / 2 + stages = [("DMA_R\nHBM → TCM", C_DMA_R, COL_TEXT_LIGHT), + ("FETCH\nTCM → Reg", C_FETCH, COL_TEXT_LIGHT), + ("GEMM\nMAC array", C_GEMM_BOX, COL_TEXT_DARK)] + for i, (lbl, fill, tc) in enumerate(stages): + x = pd_x0 + i * (box_w + box_gap) + _rrect(slide, x, pd_y, box_w, pd_h, fill, COL_TEXT_DARK, + lbl, size=10, bold=True, color=tc) + if i < 2: + ax1 = x + box_w + 0.03 + ax2 = x + box_w + box_gap - 0.03 + _arrow(slide, ax1, pd_y + pd_h / 2, ax2, pd_y + pd_h / 2, + color=COL_MUTED, width_pt=1.8) + _textbox(slide, pd_x0, pd_y + pd_h + 0.05, pd_total_w, 0.28, + f"each stage = {T_STAGE} ns/tile → " + f"head latency = (D−1) × T_stage = {HEAD_NS} ns", + size=10, color=COL_MUTED, align=PP_ALIGN.CENTER) + + # Timeline — compressed (show 4 tiles + gap + 2 tiles) + tl_y = 3.4 + tl_h = 0.85 + tl_label_y = tl_y - 0.3 + + margin = 0.8 + visible_first = 4 + visible_last = 2 + skipped = TOTAL_TILES - visible_first - visible_last + # Width budget: head + visible tiles + ellipsis block + tail + ellipsis_ns_equiv = 6 * T_STAGE # rendered width = 6 tiles worth + tl_total_ns = (HEAD_NS + (visible_first + visible_last) * T_STAGE + + ellipsis_ns_equiv + DMA_W_NS) + tl_w_total = SLIDE_W_IN - 2 * margin + px_per_ns = tl_w_total / tl_total_ns + + cur_x = margin + # Head + w = HEAD_NS * px_per_ns + _rect_band(slide, cur_x, tl_y, w, tl_h, C_HEAD, COL_TEXT_DARK) + _textbox(slide, cur_x, tl_y + tl_h / 2 - 0.13, w, 0.3, + f"head\n{HEAD_NS} ns", + size=9, bold=True, color=COL_TEXT_LIGHT) + cur_x += w + + # Group label for the entire K-loop (one (m,n)) + kloop_start_x = cur_x + # First few tiles + for t in range(visible_first): + w = T_STAGE * px_per_ns + _rect_band(slide, cur_x, tl_y, w, tl_h, C_GEMM, COL_TEXT_DARK) + 
_textbox(slide, cur_x, tl_y + tl_h / 2 - 0.13, w, 0.3, + f"GEMM\nT{t + 1}", + size=9, bold=True, color=COL_TEXT_LIGHT) + cur_x += w + # Ellipsis block + ew = ellipsis_ns_equiv * px_per_ns + _rect_band(slide, cur_x, tl_y, ew, tl_h, + RGBColor(0x86, 0xEF, 0xAC), COL_TEXT_DARK) + _textbox(slide, cur_x, tl_y + tl_h / 2 - 0.2, ew, 0.5, + f". . . {skipped} more GEMM tiles . . .", + size=10, bold=True, color=COL_TEXT_DARK) + cur_x += ew + # Last tiles + for t in range(visible_last): + w = T_STAGE * px_per_ns + _rect_band(slide, cur_x, tl_y, w, tl_h, C_GEMM, COL_TEXT_DARK) + _textbox(slide, cur_x, tl_y + tl_h / 2 - 0.13, w, 0.3, + f"GEMM\nT{TOTAL_TILES - visible_last + t + 1}", + size=9, bold=True, color=COL_TEXT_LIGHT) + cur_x += w + # K-loop group label + kloop_end_x = cur_x + _textbox(slide, kloop_start_x, tl_label_y, + kloop_end_x - kloop_start_x, 0.25, + f"(m=0,n=0) K-loop — all {TOTAL_TILES} tiles in one group, " + f"NO inter flushes", + size=10, bold=True, color=COL_TEXT_DARK, + align=PP_ALIGN.CENTER) + + # Tail + w = DMA_W_NS * px_per_ns + _rect_band(slide, cur_x, tl_y, w, tl_h, C_TAIL, COL_RED) + _textbox(slide, cur_x, tl_y + tl_h / 2 - 0.13, w, 0.3, + "tail", size=8, bold=True, color=COL_RED) + _textbox(slide, cur_x - 0.3, tl_y + tl_h + 0.02, w + 0.6, 0.28, + "(excluded\nfrom wall)", + size=8, bold=True, color=COL_RED) + + # Wall arrow + wall_end_x = margin + (HEAD_NS + (visible_first + visible_last) * T_STAGE + + ellipsis_ns_equiv) * px_per_ns + _arrow(slide, margin, tl_y + tl_h + 0.55, + wall_end_x, tl_y + tl_h + 0.55, color=COL_TEXT_DARK, width_pt=1.5) + _arrow(slide, wall_end_x, tl_y + tl_h + 0.55, + margin, tl_y + tl_h + 0.55, color=COL_TEXT_DARK, width_pt=1.5) + _textbox(slide, margin, tl_y + tl_h + 0.6, + wall_end_x - margin, 0.3, + f"wall = {WALL} ns", + size=11, bold=True, color=COL_TEXT_DARK, align=PP_ALIGN.CENTER) + + # Formula evaluation + formula_y = 5.7 + _textbox(slide, 0.4, formula_y, 12.6, 0.35, + f"wall = head + N_tiles × T_stage + (N_mn − 1) × 
T_dma_w = " + f"{HEAD_NS} + {TOTAL_TILES}×{T_STAGE} + 0 = {WALL} ns", + size=12, color=COL_TEXT_DARK, align=PP_ALIGN.LEFT) + _textbox(slide, 0.4, formula_y + 0.35, 12.6, 0.35, + f"GEMM useful time = {TOTAL_TILES}×{T_STAGE}×100 % = {GEMM_TOTAL} ns", + size=12, color=COL_TEXT_DARK, align=PP_ALIGN.LEFT) + _textbox(slide, 0.4, formula_y + 0.7, 12.6, 0.45, + f"Useful efficiency = {GEMM_TOTAL} / {WALL} = {USEFUL_EFF:.1f} %", + size=14, bold=True, color=COL_TEXT_DARK, align=PP_ALIGN.LEFT) + _textbox(slide, 0.4, formula_y + 1.15, 12.6, 0.35, + "Long K-loop, one output group → head amortized over 48 GEMM " + "tiles, no flush penalty. Approaches 100 % as K grows.", + size=10, color=COL_MUTED, align=PP_ALIGN.LEFT) + + +_RENDERERS.update({ + "stage_breakdown_ref_ref": _render_stage_breakdown_ref_ref, + "stage_breakdown_load_ref": _render_stage_breakdown_load_ref, + "stage_breakdown_load_load": _render_stage_breakdown_load_load, + "variant_comparison": _render_variant_comparison, + "hbm_topology": _render_hbm_topology, + "per_op_dma": _render_per_op_dma, + "mac_utilization": _render_mac_utilization, + "mac_utilization_ref_ref": _render_mac_utilization_ref_ref, + "tflops_table": _render_tflops_table, + "pipeline_eff_walkthrough": _render_pipeline_eff_walkthrough, + "pipeline_eff_walkthrough_largeK": _render_pipeline_eff_walkthrough_largeK, +}) + + def build(): prs = Presentation() prs.slide_width = Inches(SLIDE_W_IN) @@ -149,17 +2300,21 @@ def build(): slide = prs.slides.add_slide(blank) _add_title(slide, cfg["title"]) - # Layout: image on the left (8.4 in wide), bullets on the right (4.4 in). - _add_image_centered( - slide, cfg["image"], - left_in=0.3, top_in=1.05, - max_w_in=8.3, max_h_in=5.9, - ) - _add_bullets( - slide, cfg["bullets"], - left_in=8.8, top_in=1.2, - width_in=4.3, height_in=5.7, - ) + if "render" in cfg: + # Shape-drawn slide (sequence diagram / pipeline structure). 
+ _RENDERERS[cfg["render"]](slide) + else: + # Default: image on the left (8.4 in wide), bullets on the right. + _add_image_centered( + slide, cfg["image"], + left_in=0.3, top_in=1.05, + max_w_in=8.3, max_h_in=5.9, + ) + _add_bullets( + slide, cfg["bullets"], + left_in=8.8, top_in=1.2, + width_in=4.3, height_in=5.7, + ) _add_footer(slide, i, len(SLIDES)) OUT.parent.mkdir(parents=True, exist_ok=True) diff --git a/scripts/gemm_sweep.py b/scripts/gemm_sweep.py new file mode 100644 index 0000000..c991ed6 --- /dev/null +++ b/scripts/gemm_sweep.py @@ -0,0 +1,232 @@ +"""Sweep GEMM shapes through kernbench and dump PE_accelerator engine times. + +For each shape: + - run benches.matmul_composite via the same run_bench path the CLI uses + - read result.engine.op_log + - filter to per-PE engines: pe_dma, pe_fetch_store, pe_gemm, pe_math + - record sum-of-durations (engine occupancy) AND wall-clock active interval + +Output: docs/diagrams/gemm_sweep.json +""" +from __future__ import annotations + +import json +import os +import sys +import time +from pathlib import Path + +# Default sweep covering under-tile, single-tile, multi-tile, and asymmetric regimes. +# Each entry is either a single integer (square M=K=N=S) or "MxKxN". +# Override via env: SWEEP_SHAPES="16,32,16x2048x16,..." +DEFAULT_SHAPES = [ + "32x32x32", # 1 tile, K=32 < TILE_K=64 → under-tile in K + "32x64x32", # 1 tile, exact single-tile fit + "32x128x32", # 2 tiles, aligned + "32x128x128", # 8 tiles, aligned + "32x3072x32", # 48 tiles, all K-axis (tall-skinny) + "8x128x128", # 8 tiles, but M=8 < TILE_M=32 → MAC array under-fed + "128x8x128", # 16 tiles, but K=8 < TILE_K=64 → MAC array under-fed + "512", # 2048 tiles, fully aligned — "well-pipelined" reference +] + +# Operand-staging variants exercised per shape. +VARIANTS = ["ref_ref", "load_ref", "load_load"] + +# Engines whose timings we collect (component_id suffix match). 
+ENGINES = ["pe_dma", "pe_fetch_store", "pe_gemm", "pe_math"] + +# Per-stage breakdown labels (StageType enum names from pe_types.py). +STAGES = ["DMA_READ", "DMA_WRITE", "FETCH", "STORE", "GEMM", "MATH"] + +# Scheduler tile sizes (mirror of PeSchedulerComponent.TILE_M/K/N). +TILE_M, TILE_K, TILE_N = 32, 64, 32 + +OUT_PATH = Path(__file__).parent.parent / "docs" / "diagrams" / "gemm_sweep.json" + + +def _engine_wall_ns(records, suffix: str) -> float: + """Wall-clock interval the engine was active (union of overlapping ops).""" + intervals = [(r.t_start, r.t_end) for r in records + if r.component_id.endswith("." + suffix)] + if not intervals: + return 0.0 + intervals.sort() + merged_end = intervals[0][1] + merged_start = intervals[0][0] + total = 0.0 + for s, e in intervals[1:]: + if s <= merged_end: + merged_end = max(merged_end, e) + else: + total += merged_end - merged_start + merged_start, merged_end = s, e + total += merged_end - merged_start + return total + + +def _engine_occupancy_ns(records, suffix: str) -> float: + return sum(r.t_end - r.t_start for r in records + if r.component_id.endswith("." + suffix)) + + +def _engine_count(records, suffix: str) -> int: + return sum(1 for r in records if r.component_id.endswith("." + suffix)) + + +def _stage_occupancy_ns(records, stage_type: str) -> float: + """Sum t_end - t_start over op_log records whose params.stage_type matches. + + Requires op_log records produced post the TileToken stage_type capture + (sim_engine/op_log.py). 
+ """ + return sum( + r.t_end - r.t_start + for r in records + if r.params.get("stage_type") == stage_type + ) + + +def _stage_wall_ns(records, stage_type: str) -> float: + """Interval-union wall-clock for records whose stage_type matches.""" + intervals = sorted( + (r.t_start, r.t_end) for r in records + if r.params.get("stage_type") == stage_type + ) + if not intervals: + return 0.0 + total = 0.0 + cs, ce = intervals[0] + for s, e in intervals[1:]: + if s <= ce: + ce = max(ce, e) + else: + total += ce - cs + cs, ce = s, e + total += ce - cs + return total + + +def _stage_count(records, stage_type: str) -> int: + return sum(1 for r in records if r.params.get("stage_type") == stage_type) + + +def _run_one(M: int, K: int, N: int, topology: str, variant: str = "ref_ref") -> dict: + os.environ["MATMUL_M"] = str(M) + os.environ["MATMUL_K"] = str(K) + os.environ["MATMUL_N"] = str(N) + os.environ["MATMUL_VARIANT"] = variant + + # Late imports so env vars are read by benches/matmul_composite at module load. + # Force re-import to pick up new env values. 
+ for mod_name in [m for m in list(sys.modules) if m.startswith("benches.matmul_composite")]: + del sys.modules[mod_name] + + from benches.loader import resolve_bench + from kernbench.runtime_api.bench_runner import run_bench + from kernbench.runtime_api.types import resolve_device + from kernbench.sim_engine.engine import GraphEngine + from kernbench.topology.builder import resolve_topology + + topo = resolve_topology(topology) + bench = resolve_bench("matmul_composite") + device = resolve_device(None) + + t0 = time.time() + result = run_bench( + topology=topo, bench_fn=bench, device=device, + engine_factory=lambda t, d: GraphEngine( + getattr(t, "topology_obj", t), enable_data=True, + ), + ) + wall = time.time() - t0 + + op_log = result.engine.op_log + if not result.completion.ok: + raise RuntimeError(f"bench failed at M={M},K={K},N={N}: {result.completion}") + + # Bytes touched at f16 (2 B): full A + full B + full out (each operand + # streamed once through HBM by the composite plan). + bytes_total = (M * K + K * N + M * N) * 2 + row = { + "M": M, "K": K, "N": N, + "variant": variant, + "flops": 2 * M * K * N, + "bytes_hbm": bytes_total, + "arith_intensity": (2 * M * K * N) / bytes_total, # flops/byte + "tile_count_expected": _ceil(M, TILE_M) * _ceil(N, TILE_N) * _ceil(K, TILE_K), + "sim_wall_clock_s": round(wall, 3), + "engines": {}, + } + for eng in ENGINES: + row["engines"][eng] = { + "occupancy_ns": _engine_occupancy_ns(op_log, eng), + "wall_ns": _engine_wall_ns(op_log, eng), + "record_count": _engine_count(op_log, eng), + } + row["stages"] = {} + for stage in STAGES: + row["stages"][stage] = { + "occupancy_ns": _stage_occupancy_ns(op_log, stage), + "wall_ns": _stage_wall_ns(op_log, stage), + "record_count": _stage_count(op_log, stage), + } + # Kernel-window wall-clock = max t_end - min t_start over PE engine records. + pe_records = [r for r in op_log + if any(r.component_id.endswith("." 
+ e) for e in ENGINES)] + if pe_records: + row["pe_window_ns"] = max(r.t_end for r in pe_records) \ + - min(r.t_start for r in pe_records) + else: + row["pe_window_ns"] = 0.0 + return row + + +def _ceil(a: int, b: int) -> int: + return (a + b - 1) // b + + +def main() -> int: + shapes_env = os.environ.get("SWEEP_SHAPES") + raw = (shapes_env.split(",") if shapes_env else DEFAULT_SHAPES) + shapes: list[tuple[int, int, int]] = [] + for s in raw: + s = s.strip() + if not s: + continue + if "x" in s.lower(): + parts = s.lower().split("x") + shapes.append((int(parts[0]), int(parts[1]), int(parts[2]))) + else: + v = int(s) + shapes.append((v, v, v)) + topology = os.environ.get("SWEEP_TOPOLOGY", "topology.yaml") + + rows = [] + for M, K, N in shapes: + for variant in VARIANTS: + print(f"[sweep] M={M} K={K} N={N} variant={variant} ...", flush=True) + row = _run_one(M, K, N, topology, variant=variant) + rows.append(row) + eng_dma = row["engines"]["pe_dma"] + eng_gem = row["engines"]["pe_gemm"] + print(f" tiles={row['tile_count_expected']:>6} " + f"pe_window={row['pe_window_ns']:8.1f}ns " + f"dma_occ={eng_dma['occupancy_ns']:9.1f} " + f"gemm_occ={eng_gem['occupancy_ns']:8.1f} " + f"(sim {row['sim_wall_clock_s']:.1f}s)") + + OUT_PATH.parent.mkdir(parents=True, exist_ok=True) + OUT_PATH.write_text(json.dumps({ + "tile_sizes": {"M": TILE_M, "K": TILE_K, "N": TILE_N}, + "engines": ENGINES, + "stages": STAGES, + "variants": VARIANTS, + "rows": rows, + }, indent=2)) + print(f"\n[sweep] wrote {OUT_PATH}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/kernbench/common/pe_commands.py b/src/kernbench/common/pe_commands.py index e70c367..72c245e 100644 --- a/src/kernbench/common/pe_commands.py +++ b/src/kernbench/common/pe_commands.py @@ -34,6 +34,7 @@ class TensorHandle: nbytes: int # total byte size data: object = None # reserved for validate mode space: str = "tcm" # MemoryStore space ("tcm" | "hbm" | "sram") + pinned: bool = False # operand 
already DMA-staged in TCM (via tl.load) @dataclass(frozen=True) diff --git a/src/kernbench/components/builtin/pe_scheduler.py b/src/kernbench/components/builtin/pe_scheduler.py index ea07ee2..4f50dd5 100644 --- a/src/kernbench/components/builtin/pe_scheduler.py +++ b/src/kernbench/components/builtin/pe_scheduler.py @@ -163,6 +163,8 @@ class PeSchedulerComponent(ComponentBase): bytes_per_element=bpe, A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr, pe_prefix=pp, + a_pinned=getattr(a, "pinned", False), + b_pinned=getattr(b, "pinned", False), ) else: # Math composite diff --git a/src/kernbench/components/builtin/tiling.py b/src/kernbench/components/builtin/tiling.py index 4ee63ad..321da19 100644 --- a/src/kernbench/components/builtin/tiling.py +++ b/src/kernbench/components/builtin/tiling.py @@ -21,15 +21,22 @@ def generate_gemm_plan( bytes_per_element: int, A_addr: int, B_addr: int, C_addr: int, pe_prefix: str, + a_pinned: bool = False, + b_pinned: bool = False, ) -> PipelinePlan: """Generate GEMM tile plan: M→N→K order. Each tile follows stage sequence: - DMA_READ(A) → DMA_READ(B) → FETCH → GEMM → STORE - On last K-tile per (m,n): → DMA_WRITE + [DMA_READ(A)] → [DMA_READ(B)] → FETCH → GEMM → [STORE → DMA_WRITE] + DMA_READ(A) skipped when a_pinned=True (operand pre-staged in TCM). + DMA_READ(B) skipped when b_pinned=True. + STORE + DMA_WRITE only emitted on last K-tile per (m,n) — accumulator + stays in RegFile across K loop. Args: pe_prefix: e.g. "sip0.cube0.pe0" — used to build component IDs. + a_pinned: A operand already resident in TCM (via prior tl.load). + b_pinned: B operand already resident in TCM. 
""" M_tiles = max(1, ceil(M / tile_m)) K_tiles = max(1, ceil(K / tile_k)) @@ -58,23 +65,26 @@ def generate_gemm_plan( stages: list[Stage] = [] - # DMA READ: load A and B tiles from HBM → TCM - stages.append(Stage( - stage_type=StageType.DMA_READ, - component=dma_id, - params={ - "src_addr": a_addr, "nbytes": a_bytes, - "operand": "A", "tile_m": tile_m, "tile_k": tile_k, - }, - )) - stages.append(Stage( - stage_type=StageType.DMA_READ, - component=dma_id, - params={ - "src_addr": b_addr, "nbytes": b_bytes, - "operand": "B", "tile_k": tile_k, "tile_n": tile_n, - }, - )) + # DMA READ: load A and B tiles from HBM → TCM. + # Skip if the operand is already pre-staged via tl.load. + if not a_pinned: + stages.append(Stage( + stage_type=StageType.DMA_READ, + component=dma_id, + params={ + "src_addr": a_addr, "nbytes": a_bytes, + "operand": "A", "tile_m": tile_m, "tile_k": tile_k, + }, + )) + if not b_pinned: + stages.append(Stage( + stage_type=StageType.DMA_READ, + component=dma_id, + params={ + "src_addr": b_addr, "nbytes": b_bytes, + "operand": "B", "tile_k": tile_k, "tile_n": tile_n, + }, + )) # FETCH: TCM → Register File stages.append(Stage( @@ -96,18 +106,17 @@ def generate_gemm_plan( }, )) - # STORE: Register File → TCM - stages.append(Stage( - stage_type=StageType.STORE, - component=fetch_id, - params={ - "direction": "write", - "nbytes": out_bytes, - }, - )) - - # DMA WRITE: TCM → HBM (only on last K-tile) + # STORE + DMA_WRITE only on last K-tile per (m,n). The C + # accumulator stays in RegFile across the K loop. 
if last_k: + stages.append(Stage( + stage_type=StageType.STORE, + component=fetch_id, + params={ + "direction": "write", + "nbytes": out_bytes, + }, + )) stages.append(Stage( stage_type=StageType.DMA_WRITE, component=dma_id, diff --git a/src/kernbench/sim_engine/op_log.py b/src/kernbench/sim_engine/op_log.py index 20f9da1..9e083c5 100644 --- a/src/kernbench/sim_engine/op_log.py +++ b/src/kernbench/sim_engine/op_log.py @@ -44,11 +44,25 @@ class OpLogger: return self._records def record_start(self, t: float, component_id: str, msg: Any) -> None: - """Called by ComponentBase._on_process_start.""" + """Called by ComponentBase._on_process_start. + + Snapshots TileToken stage_type at start time so we can attribute the + record correctly even if the token advances stage_idx before + record_end fires. + """ + snap: dict[str, Any] = {} + # TileToken (ADR-0021 pipeline) — capture which stage this is. + try: + stage = getattr(msg, "current_stage", None) + if stage is not None: + snap["stage_type"] = stage.stage_type.name + except Exception: + pass self._pending[id(msg)] = { "t_start": t, "component_id": component_id, "msg": msg, + "snap": snap, } def record_end(self, t: float, component_id: str, msg: Any) -> None: @@ -57,6 +71,16 @@ class OpLogger: if pending is None: return op_kind, op_name, params = _extract_op_info(msg) + # Merge TileToken stage_type captured at record_start into params, + # and reflect it in op_name so reporting can disambiguate + # DMA_READ vs DMA_WRITE and FETCH vs STORE on the same component. + snap = pending.get("snap", {}) + stage_type = snap.get("stage_type") + if stage_type is not None: + params = dict(params) + params["stage_type"] = stage_type + if op_name == "TileToken": + op_name = f"TileToken/{stage_type}" # Snapshot data at record time so Phase 2 replay sidesteps # downstream mutations of source addrs (e.g. 
a tl.store that # overwrites HBM after a load handle was sent, or a slot that diff --git a/src/kernbench/triton_emu/tl_context.py b/src/kernbench/triton_emu/tl_context.py index e35ccbe..bc1984d 100644 --- a/src/kernbench/triton_emu/tl_context.py +++ b/src/kernbench/triton_emu/tl_context.py @@ -123,13 +123,14 @@ class TLContext: def _make_handle( self, addr: int, shape: tuple[int, ...], dtype: str, - space: str = "tcm", + space: str = "tcm", pinned: bool = False, ) -> TensorHandle: return TensorHandle( id=self._next_handle_id(), addr=addr, shape=shape, dtype=dtype, nbytes=self._nbytes(shape, dtype), space=space, + pinned=pinned, ) def _make_compute_out( @@ -184,15 +185,17 @@ class TLContext: actually lives in Phase 2 storage. """ self._emit_dispatch_overhead() - handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype, space="hbm") + handle = self._make_handle( + addr=ptr, shape=shape, dtype=dtype, space="hbm", pinned=True, + ) cmd = DmaReadCmd(handle=handle, src_addr=ptr, nbytes=handle.nbytes) data = self._emit(cmd) if data is not None: - # Greenlet mode: attach real data to handle (preserve space) + # Greenlet mode: attach real data to handle (preserve space + pinned) return TensorHandle( id=handle.id, addr=handle.addr, shape=handle.shape, dtype=handle.dtype, nbytes=handle.nbytes, data=data, - space=handle.space, + space=handle.space, pinned=handle.pinned, ) return handle diff --git a/tests/test_pe_pipeline.py b/tests/test_pe_pipeline.py index 4d2dd35..cb7d7a6 100644 --- a/tests/test_pe_pipeline.py +++ b/tests/test_pe_pipeline.py @@ -150,7 +150,11 @@ def test_gemm_plan_stage_sequence(): def test_gemm_plan_intermediate_k_no_dma_write(): - """Intermediate K-tiles don't have DMA_WRITE stage.""" + """Intermediate K-tiles don't have DMA_WRITE or STORE stage. + + The C accumulator stays in RegFile across the K loop; STORE + + DMA_WRITE only fire on the last K-tile per (m,n). 
+ """ from kernbench.components.builtin.tiling import generate_gemm_plan plan = generate_gemm_plan( @@ -162,15 +166,72 @@ def test_gemm_plan_intermediate_k_no_dma_write(): ) assert len(plan.tiles) == 2 - # First tile (k=0): no DMA_WRITE + # First tile (k=0): no STORE, no DMA_WRITE — accumulator stays in RegFile t0_types = [s.stage_type for s in plan.tiles[0].stages] + assert StageType.STORE not in t0_types assert StageType.DMA_WRITE not in t0_types - # Last tile (k=1, last_k=True): has DMA_WRITE + # Last tile (k=1, last_k=True): has both STORE and DMA_WRITE t1_types = [s.stage_type for s in plan.tiles[1].stages] + assert StageType.STORE in t1_types assert StageType.DMA_WRITE in t1_types +def test_gemm_plan_pinned_operand_skips_dma_read(): + """When a_pinned=True, A's per-tile DMA_READ is omitted. + + Same for b_pinned. FETCH is unaffected — it still stages from TCM + into RegFile. + """ + from kernbench.components.builtin.tiling import generate_gemm_plan + + # Baseline: neither pinned — both A and B get DMA_READ per tile. + base = generate_gemm_plan( + M=32, K=128, N=32, # K_tiles=2 + tile_m=32, tile_k=64, tile_n=32, + bytes_per_element=2, + A_addr=0, B_addr=0x1000, C_addr=0x2000, + pe_prefix="sip0.cube0.pe0", + ) + for tile in base.tiles: + operands = [s.params.get("operand") for s in tile.stages + if s.stage_type == StageType.DMA_READ] + assert operands == ["A", "B"], \ + f"baseline tile should DMA_READ A and B, got {operands}" + + # a_pinned: no A DMA_READ. 
+ plan_a = generate_gemm_plan( + M=32, K=128, N=32, + tile_m=32, tile_k=64, tile_n=32, + bytes_per_element=2, + A_addr=0, B_addr=0x1000, C_addr=0x2000, + pe_prefix="sip0.cube0.pe0", + a_pinned=True, + ) + for tile in plan_a.tiles: + operands = [s.params.get("operand") for s in tile.stages + if s.stage_type == StageType.DMA_READ] + assert operands == ["B"], \ + f"a_pinned should leave only B DMA_READ, got {operands}" + # FETCH must still exist + assert any(s.stage_type == StageType.FETCH for s in tile.stages) + + # Both pinned: no DMA_READ at all. + plan_both = generate_gemm_plan( + M=32, K=128, N=32, + tile_m=32, tile_k=64, tile_n=32, + bytes_per_element=2, + A_addr=0, B_addr=0x1000, C_addr=0x2000, + pe_prefix="sip0.cube0.pe0", + a_pinned=True, b_pinned=True, + ) + for tile in plan_both.tiles: + dma_reads = [s for s in tile.stages + if s.stage_type == StageType.DMA_READ] + assert dma_reads == [], \ + f"both pinned should skip all DMA_READ, got {dma_reads}" + + def test_math_plan_stage_sequence(): """Math plan has READ→FETCH→MATH→STORE→WRITE sequence.""" from kernbench.components.builtin.tiling import generate_math_plan