Composite GEMM: K-loop accumulator residency, pinned operands, sweep + deck
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,69 @@
|
|||||||
|
"""Single-PE composite GEMM for PE_accelerator perf characterization.
|
||||||
|
|
||||||
|
Three operand-staging variants are selectable via MATMUL_VARIANT:
|
||||||
|
|
||||||
|
- "ref_ref" (default): a = tl.ref, b = tl.ref
|
||||||
|
Both operands HBM-resident; scheduler streams per-tile DMA.
|
||||||
|
- "load_ref": a = tl.load, b = tl.ref
|
||||||
|
A eagerly DMA'd into TCM up-front; B streamed per-tile.
|
||||||
|
- "load_load": a = tl.load, b = tl.load
|
||||||
|
Both eagerly DMA'd into TCM up-front.
|
||||||
|
|
||||||
|
Other env vars: MATMUL_M, MATMUL_K, MATMUL_N, MATMUL_DTYPE.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
MATMUL_M=256 MATMUL_K=256 MATMUL_N=256 MATMUL_VARIANT=load_ref \
|
||||||
|
kernbench run --topology topology.yaml --bench matmul_composite
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
from kernbench.policy.placement.dp import DPPolicy
|
||||||
|
|
||||||
|
# Problem configuration, read from the environment once at import time.
# The three GEMM dimensions share the same default of 256.
M, K, N = (int(os.getenv(f"MATMUL_{axis}", "256")) for axis in "MKN")
# Element dtype name handed to the tensors allocated in run().
DTYPE = os.getenv("MATMUL_DTYPE", "f16")
# Operand-staging variant; must be one of the keys of _KERNELS.
VARIANT = os.getenv("MATMUL_VARIANT", "ref_ref")
|
||||||
|
|
||||||
|
|
||||||
|
def _kernel_ref_ref(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
|
||||||
|
M, K, N = int(M), int(K), int(N)
|
||||||
|
a = tl.ref(int(a_ptr), shape=(M, K), dtype=DTYPE)
|
||||||
|
b = tl.ref(int(b_ptr), shape=(K, N), dtype=DTYPE)
|
||||||
|
h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
|
||||||
|
tl.wait(h)
|
||||||
|
|
||||||
|
|
||||||
|
def _kernel_load_ref(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
|
||||||
|
M, K, N = int(M), int(K), int(N)
|
||||||
|
a = tl.load(int(a_ptr), shape=(M, K), dtype=DTYPE)
|
||||||
|
b = tl.ref(int(b_ptr), shape=(K, N), dtype=DTYPE)
|
||||||
|
h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
|
||||||
|
tl.wait(h)
|
||||||
|
|
||||||
|
|
||||||
|
def _kernel_load_load(a_ptr, b_ptr, out_ptr, M, K, N, tl, DTYPE="f16"):
|
||||||
|
M, K, N = int(M), int(K), int(N)
|
||||||
|
a = tl.load(int(a_ptr), shape=(M, K), dtype=DTYPE)
|
||||||
|
b = tl.load(int(b_ptr), shape=(K, N), dtype=DTYPE)
|
||||||
|
h = tl.composite(op="gemm", a=a, b=b, out_ptr=int(out_ptr))
|
||||||
|
tl.wait(h)
|
||||||
|
|
||||||
|
|
||||||
|
# Dispatch table: MATMUL_VARIANT value -> kernel entry point.
_KERNELS = dict(
    ref_ref=_kernel_ref_ref,
    load_ref=_kernel_load_ref,
    load_load=_kernel_load_load,
)
|
||||||
|
|
||||||
|
|
||||||
|
def run(torch):
    """Allocate the GEMM operands and launch the kernel selected by VARIANT.

    Args:
        torch: Bench runtime object; this function uses its ``empty`` and
            ``launch`` methods.

    Raises:
        ValueError: If VARIANT is not a key of ``_KERNELS``.
    """
    kernel_fn = _KERNELS.get(VARIANT)
    if kernel_fn is None:
        raise ValueError(f"unknown MATMUL_VARIANT={VARIANT!r}; "
                         f"expected one of {list(_KERNELS)}")

    # Single cube, single PE, replicated placement.
    dp = DPPolicy(cube="replicate", pe="replicate", num_cubes=1, num_pes=1)

    # Allocated in a, b, out order, matching the kernel's pointer arguments.
    operand_shapes = {"a": (M, K), "b": (K, N), "out": (M, N)}
    tensors = [
        torch.empty(shape, dtype=DTYPE, dp=dp, name=label)
        for label, shape in operand_shapes.items()
    ]
    torch.launch(f"matmul_composite_{VARIANT}", kernel_fn, *tensors, M, K, N)
|
||||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
+2168
-13
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,232 @@
|
|||||||
|
"""Sweep GEMM shapes through kernbench and dump PE_accelerator engine times.
|
||||||
|
|
||||||
|
For each shape:
|
||||||
|
- run benches.matmul_composite via the same run_bench path the CLI uses
|
||||||
|
- read result.engine.op_log
|
||||||
|
- filter to per-PE engines: pe_dma, pe_fetch_store, pe_gemm, pe_math
|
||||||
|
- record sum-of-durations (engine occupancy) AND wall-clock active interval
|
||||||
|
|
||||||
|
Output: docs/diagrams/gemm_sweep.json
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Default sweep: under-tile, single-tile, multi-tile and asymmetric shapes.
# An entry is either one integer (square, M=K=N=S) or an "MxKxN" triple.
# Set SWEEP_SHAPES (comma-separated, e.g. "16,32,16x2048x16") to override.
DEFAULT_SHAPES = [
    "32x32x32",     # 1 tile, K=32 < TILE_K=64 → under-tile in K
    "32x64x32",     # 1 tile, exact single-tile fit
    "32x128x32",    # 2 tiles, aligned
    "32x128x128",   # 8 tiles, aligned
    "32x3072x32",   # 48 tiles, all K-axis (tall-skinny)
    "8x128x128",    # 8 tiles, but M=8 < TILE_M=32 → MAC array under-fed
    "128x8x128",    # 16 tiles, but K=8 < TILE_K=64 → MAC array under-fed
    "512",          # 2048 tiles, fully aligned — "well-pipelined" reference
]

# Operand-staging variants run for every shape.
VARIANTS = ["ref_ref", "load_ref", "load_load"]

# Engines whose timings are collected; matched as a component_id suffix.
ENGINES = ["pe_dma", "pe_fetch_store", "pe_gemm", "pe_math"]

# Labels for the per-stage breakdown (StageType enum names from pe_types.py).
STAGES = ["DMA_READ", "DMA_WRITE", "FETCH", "STORE", "GEMM", "MATH"]

# Scheduler tile sizes (mirror of PeSchedulerComponent.TILE_M/K/N).
TILE_M = 32
TILE_K = 64
TILE_N = 32

# Results file: <parent of this script's directory>/docs/diagrams/gemm_sweep.json
OUT_PATH = Path(__file__).parent.parent.joinpath(
    "docs", "diagrams", "gemm_sweep.json",
)
|
||||||
|
|
||||||
|
|
||||||
|
def _engine_wall_ns(records, suffix: str) -> float:
|
||||||
|
"""Wall-clock interval the engine was active (union of overlapping ops)."""
|
||||||
|
intervals = [(r.t_start, r.t_end) for r in records
|
||||||
|
if r.component_id.endswith("." + suffix)]
|
||||||
|
if not intervals:
|
||||||
|
return 0.0
|
||||||
|
intervals.sort()
|
||||||
|
merged_end = intervals[0][1]
|
||||||
|
merged_start = intervals[0][0]
|
||||||
|
total = 0.0
|
||||||
|
for s, e in intervals[1:]:
|
||||||
|
if s <= merged_end:
|
||||||
|
merged_end = max(merged_end, e)
|
||||||
|
else:
|
||||||
|
total += merged_end - merged_start
|
||||||
|
merged_start, merged_end = s, e
|
||||||
|
total += merged_end - merged_start
|
||||||
|
return total
|
||||||
|
|
||||||
|
|
||||||
|
def _engine_occupancy_ns(records, suffix: str) -> float:
|
||||||
|
return sum(r.t_end - r.t_start for r in records
|
||||||
|
if r.component_id.endswith("." + suffix))
|
||||||
|
|
||||||
|
|
||||||
|
def _engine_count(records, suffix: str) -> int:
|
||||||
|
return sum(1 for r in records if r.component_id.endswith("." + suffix))
|
||||||
|
|
||||||
|
|
||||||
|
def _stage_occupancy_ns(records, stage_type: str) -> float:
|
||||||
|
"""Sum t_end - t_start over op_log records whose params.stage_type matches.
|
||||||
|
|
||||||
|
Requires op_log records produced post the TileToken stage_type capture
|
||||||
|
(sim_engine/op_log.py).
|
||||||
|
"""
|
||||||
|
return sum(
|
||||||
|
r.t_end - r.t_start
|
||||||
|
for r in records
|
||||||
|
if r.params.get("stage_type") == stage_type
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _stage_wall_ns(records, stage_type: str) -> float:
|
||||||
|
"""Interval-union wall-clock for records whose stage_type matches."""
|
||||||
|
intervals = sorted(
|
||||||
|
(r.t_start, r.t_end) for r in records
|
||||||
|
if r.params.get("stage_type") == stage_type
|
||||||
|
)
|
||||||
|
if not intervals:
|
||||||
|
return 0.0
|
||||||
|
total = 0.0
|
||||||
|
cs, ce = intervals[0]
|
||||||
|
for s, e in intervals[1:]:
|
||||||
|
if s <= ce:
|
||||||
|
ce = max(ce, e)
|
||||||
|
else:
|
||||||
|
total += ce - cs
|
||||||
|
cs, ce = s, e
|
||||||
|
total += ce - cs
|
||||||
|
return total
|
||||||
|
|
||||||
|
|
||||||
|
def _stage_count(records, stage_type: str) -> int:
|
||||||
|
return sum(1 for r in records if r.params.get("stage_type") == stage_type)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_one(M: int, K: int, N: int, topology: str, variant: str = "ref_ref") -> dict:
    """Run one matmul_composite bench and summarize its PE engine timings.

    Args:
        M, K, N: GEMM dimensions, exported as MATMUL_M / MATMUL_K / MATMUL_N.
        topology: Topology spec handed to ``resolve_topology``.
        variant: Operand-staging variant, exported as MATMUL_VARIANT.

    Returns:
        A dict with the shape, variant, FLOP and byte counts, expected
        tile count, simulator wall-clock seconds, per-engine and per-stage
        ``occupancy_ns`` / ``wall_ns`` / ``record_count``, and
        ``pe_window_ns``.

    Raises:
        RuntimeError: If the bench reports an unsuccessful completion.

    Note:
        The MATMUL_* environment variables set here are left in place
        after the call.
    """
    os.environ["MATMUL_M"] = str(M)
    os.environ["MATMUL_K"] = str(K)
    os.environ["MATMUL_N"] = str(N)
    os.environ["MATMUL_VARIANT"] = variant

    # The bench module reads its env vars at module load, so drop any cached
    # copy to force a re-import that picks up the values set above.
    for mod_name in [m for m in list(sys.modules) if m.startswith("benches.matmul_composite")]:
        del sys.modules[mod_name]

    # Late imports: these must happen after the env vars are in place.
    from benches.loader import resolve_bench
    from kernbench.runtime_api.bench_runner import run_bench
    from kernbench.runtime_api.types import resolve_device
    from kernbench.sim_engine.engine import GraphEngine
    from kernbench.topology.builder import resolve_topology

    topo = resolve_topology(topology)
    bench = resolve_bench("matmul_composite")
    device = resolve_device(None)

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments during the run.
    t0 = time.perf_counter()
    result = run_bench(
        topology=topo, bench_fn=bench, device=device,
        engine_factory=lambda t, d: GraphEngine(
            getattr(t, "topology_obj", t), enable_data=True,
        ),
    )
    wall = time.perf_counter() - t0

    # Fail before touching the op log, and name the variant: every shape is
    # run once per variant, so the shape alone does not identify the run.
    if not result.completion.ok:
        raise RuntimeError(
            f"bench failed at M={M},K={K},N={N},variant={variant}: "
            f"{result.completion}"
        )
    op_log = result.engine.op_log

    # Bytes touched at f16 (2 B): full A + full B + full out (each operand
    # streamed once through HBM by the composite plan).
    # NOTE(review): the 2 B/element factor ignores MATMUL_DTYPE — confirm
    # the sweep is only ever run at f16.
    bytes_total = (M * K + K * N + M * N) * 2
    flops = 2 * M * K * N
    row = {
        "M": M, "K": K, "N": N,
        "variant": variant,
        "flops": flops,
        "bytes_hbm": bytes_total,
        # flops/byte; an all-zero shape moves no bytes, so report 0.0
        # instead of dividing by zero.
        "arith_intensity": flops / bytes_total if bytes_total else 0.0,
        "tile_count_expected": _ceil(M, TILE_M) * _ceil(N, TILE_N) * _ceil(K, TILE_K),
        "sim_wall_clock_s": round(wall, 3),
        "engines": {},
    }
    for eng in ENGINES:
        row["engines"][eng] = {
            "occupancy_ns": _engine_occupancy_ns(op_log, eng),
            "wall_ns": _engine_wall_ns(op_log, eng),
            "record_count": _engine_count(op_log, eng),
        }
    row["stages"] = {}
    for stage in STAGES:
        row["stages"][stage] = {
            "occupancy_ns": _stage_occupancy_ns(op_log, stage),
            "wall_ns": _stage_wall_ns(op_log, stage),
            "record_count": _stage_count(op_log, stage),
        }

    # Kernel-window wall-clock = max t_end - min t_start over PE engine records.
    engine_suffixes = tuple("." + e for e in ENGINES)
    pe_records = [r for r in op_log if r.component_id.endswith(engine_suffixes)]
    if pe_records:
        row["pe_window_ns"] = (
            max(r.t_end for r in pe_records)
            - min(r.t_start for r in pe_records)
        )
    else:
        row["pe_window_ns"] = 0.0
    return row
|
||||||
|
|
||||||
|
|
||||||
|
def _ceil(a: int, b: int) -> int:
|
||||||
|
return (a + b - 1) // b
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_shape(spec: str) -> tuple[int, int, int]:
    """Parse one sweep shape: ``"S"`` (square) or ``"MxKxN"`` into (M, K, N)."""
    parts = spec.strip().lower().split("x")
    if len(parts) == 1:
        side = int(parts[0])
        return (side, side, side)
    if len(parts) != 3:
        raise ValueError(
            f"invalid sweep shape {spec!r}: expected a single integer S "
            f"or three integers as MxKxN"
        )
    m, k, n = (int(p) for p in parts)
    return (m, k, n)


def main() -> int:
    """Run the GEMM sweep and write the collected rows to OUT_PATH as JSON.

    Shapes come from the comma-separated SWEEP_SHAPES environment variable
    when it is set and non-empty, otherwise from DEFAULT_SHAPES; blank
    entries are skipped. Every shape is run once per entry in VARIANTS
    against the topology named by SWEEP_TOPOLOGY (default "topology.yaml").

    Returns:
        0 on success, used as the process exit status.

    Raises:
        ValueError: If a shape entry is neither ``S`` nor ``MxKxN``. All
            shapes are parsed before the first bench is run.
    """
    shapes_env = os.environ.get("SWEEP_SHAPES")
    raw = shapes_env.split(",") if shapes_env else DEFAULT_SHAPES
    # Parse everything up front so a malformed entry fails before any
    # simulation time is spent.
    shapes: list[tuple[int, int, int]] = [
        _parse_shape(s) for s in raw if s.strip()
    ]
    topology = os.environ.get("SWEEP_TOPOLOGY", "topology.yaml")

    rows = []
    for M, K, N in shapes:
        for variant in VARIANTS:
            print(f"[sweep] M={M} K={K} N={N} variant={variant} ...", flush=True)
            row = _run_one(M, K, N, topology, variant=variant)
            rows.append(row)
            eng_dma = row["engines"]["pe_dma"]
            eng_gem = row["engines"]["pe_gemm"]
            print(f"  tiles={row['tile_count_expected']:>6}  "
                  f"pe_window={row['pe_window_ns']:8.1f}ns  "
                  f"dma_occ={eng_dma['occupancy_ns']:9.1f}  "
                  f"gemm_occ={eng_gem['occupancy_ns']:8.1f}  "
                  f"(sim {row['sim_wall_clock_s']:.1f}s)")

    OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    OUT_PATH.write_text(json.dumps({
        "tile_sizes": {"M": TILE_M, "K": TILE_K, "N": TILE_N},
        "engines": ENGINES,
        "stages": STAGES,
        "variants": VARIANTS,
        "rows": rows,
    }, indent=2))
    print(f"\n[sweep] wrote {OUT_PATH}")
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
@@ -34,6 +34,7 @@ class TensorHandle:
|
|||||||
nbytes: int # total byte size
|
nbytes: int # total byte size
|
||||||
data: object = None # reserved for validate mode
|
data: object = None # reserved for validate mode
|
||||||
space: str = "tcm" # MemoryStore space ("tcm" | "hbm" | "sram")
|
space: str = "tcm" # MemoryStore space ("tcm" | "hbm" | "sram")
|
||||||
|
pinned: bool = False # operand already DMA-staged in TCM (via tl.load)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|||||||
@@ -163,6 +163,8 @@ class PeSchedulerComponent(ComponentBase):
|
|||||||
bytes_per_element=bpe,
|
bytes_per_element=bpe,
|
||||||
A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr,
|
A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr,
|
||||||
pe_prefix=pp,
|
pe_prefix=pp,
|
||||||
|
a_pinned=getattr(a, "pinned", False),
|
||||||
|
b_pinned=getattr(b, "pinned", False),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Math composite
|
# Math composite
|
||||||
|
|||||||
@@ -21,15 +21,22 @@ def generate_gemm_plan(
|
|||||||
bytes_per_element: int,
|
bytes_per_element: int,
|
||||||
A_addr: int, B_addr: int, C_addr: int,
|
A_addr: int, B_addr: int, C_addr: int,
|
||||||
pe_prefix: str,
|
pe_prefix: str,
|
||||||
|
a_pinned: bool = False,
|
||||||
|
b_pinned: bool = False,
|
||||||
) -> PipelinePlan:
|
) -> PipelinePlan:
|
||||||
"""Generate GEMM tile plan: M→N→K order.
|
"""Generate GEMM tile plan: M→N→K order.
|
||||||
|
|
||||||
Each tile follows stage sequence:
|
Each tile follows stage sequence:
|
||||||
DMA_READ(A) → DMA_READ(B) → FETCH → GEMM → STORE
|
[DMA_READ(A)] → [DMA_READ(B)] → FETCH → GEMM → [STORE → DMA_WRITE]
|
||||||
On last K-tile per (m,n): → DMA_WRITE
|
DMA_READ(A) skipped when a_pinned=True (operand pre-staged in TCM).
|
||||||
|
DMA_READ(B) skipped when b_pinned=True.
|
||||||
|
STORE + DMA_WRITE only emitted on last K-tile per (m,n) — accumulator
|
||||||
|
stays in RegFile across K loop.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pe_prefix: e.g. "sip0.cube0.pe0" — used to build component IDs.
|
pe_prefix: e.g. "sip0.cube0.pe0" — used to build component IDs.
|
||||||
|
a_pinned: A operand already resident in TCM (via prior tl.load).
|
||||||
|
b_pinned: B operand already resident in TCM.
|
||||||
"""
|
"""
|
||||||
M_tiles = max(1, ceil(M / tile_m))
|
M_tiles = max(1, ceil(M / tile_m))
|
||||||
K_tiles = max(1, ceil(K / tile_k))
|
K_tiles = max(1, ceil(K / tile_k))
|
||||||
@@ -58,23 +65,26 @@ def generate_gemm_plan(
|
|||||||
|
|
||||||
stages: list[Stage] = []
|
stages: list[Stage] = []
|
||||||
|
|
||||||
# DMA READ: load A and B tiles from HBM → TCM
|
# DMA READ: load A and B tiles from HBM → TCM.
|
||||||
stages.append(Stage(
|
# Skip if the operand is already pre-staged via tl.load.
|
||||||
stage_type=StageType.DMA_READ,
|
if not a_pinned:
|
||||||
component=dma_id,
|
stages.append(Stage(
|
||||||
params={
|
stage_type=StageType.DMA_READ,
|
||||||
"src_addr": a_addr, "nbytes": a_bytes,
|
component=dma_id,
|
||||||
"operand": "A", "tile_m": tile_m, "tile_k": tile_k,
|
params={
|
||||||
},
|
"src_addr": a_addr, "nbytes": a_bytes,
|
||||||
))
|
"operand": "A", "tile_m": tile_m, "tile_k": tile_k,
|
||||||
stages.append(Stage(
|
},
|
||||||
stage_type=StageType.DMA_READ,
|
))
|
||||||
component=dma_id,
|
if not b_pinned:
|
||||||
params={
|
stages.append(Stage(
|
||||||
"src_addr": b_addr, "nbytes": b_bytes,
|
stage_type=StageType.DMA_READ,
|
||||||
"operand": "B", "tile_k": tile_k, "tile_n": tile_n,
|
component=dma_id,
|
||||||
},
|
params={
|
||||||
))
|
"src_addr": b_addr, "nbytes": b_bytes,
|
||||||
|
"operand": "B", "tile_k": tile_k, "tile_n": tile_n,
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
# FETCH: TCM → Register File
|
# FETCH: TCM → Register File
|
||||||
stages.append(Stage(
|
stages.append(Stage(
|
||||||
@@ -96,18 +106,17 @@ def generate_gemm_plan(
|
|||||||
},
|
},
|
||||||
))
|
))
|
||||||
|
|
||||||
# STORE: Register File → TCM
|
# STORE + DMA_WRITE only on last K-tile per (m,n). The C
|
||||||
stages.append(Stage(
|
# accumulator stays in RegFile across the K loop.
|
||||||
stage_type=StageType.STORE,
|
|
||||||
component=fetch_id,
|
|
||||||
params={
|
|
||||||
"direction": "write",
|
|
||||||
"nbytes": out_bytes,
|
|
||||||
},
|
|
||||||
))
|
|
||||||
|
|
||||||
# DMA WRITE: TCM → HBM (only on last K-tile)
|
|
||||||
if last_k:
|
if last_k:
|
||||||
|
stages.append(Stage(
|
||||||
|
stage_type=StageType.STORE,
|
||||||
|
component=fetch_id,
|
||||||
|
params={
|
||||||
|
"direction": "write",
|
||||||
|
"nbytes": out_bytes,
|
||||||
|
},
|
||||||
|
))
|
||||||
stages.append(Stage(
|
stages.append(Stage(
|
||||||
stage_type=StageType.DMA_WRITE,
|
stage_type=StageType.DMA_WRITE,
|
||||||
component=dma_id,
|
component=dma_id,
|
||||||
|
|||||||
@@ -44,11 +44,25 @@ class OpLogger:
|
|||||||
return self._records
|
return self._records
|
||||||
|
|
||||||
def record_start(self, t: float, component_id: str, msg: Any) -> None:
|
def record_start(self, t: float, component_id: str, msg: Any) -> None:
|
||||||
"""Called by ComponentBase._on_process_start."""
|
"""Called by ComponentBase._on_process_start.
|
||||||
|
|
||||||
|
Snapshots TileToken stage_type at start time so we can attribute the
|
||||||
|
record correctly even if the token advances stage_idx before
|
||||||
|
record_end fires.
|
||||||
|
"""
|
||||||
|
snap: dict[str, Any] = {}
|
||||||
|
# TileToken (ADR-0021 pipeline) — capture which stage this is.
|
||||||
|
try:
|
||||||
|
stage = getattr(msg, "current_stage", None)
|
||||||
|
if stage is not None:
|
||||||
|
snap["stage_type"] = stage.stage_type.name
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
self._pending[id(msg)] = {
|
self._pending[id(msg)] = {
|
||||||
"t_start": t,
|
"t_start": t,
|
||||||
"component_id": component_id,
|
"component_id": component_id,
|
||||||
"msg": msg,
|
"msg": msg,
|
||||||
|
"snap": snap,
|
||||||
}
|
}
|
||||||
|
|
||||||
def record_end(self, t: float, component_id: str, msg: Any) -> None:
|
def record_end(self, t: float, component_id: str, msg: Any) -> None:
|
||||||
@@ -57,6 +71,16 @@ class OpLogger:
|
|||||||
if pending is None:
|
if pending is None:
|
||||||
return
|
return
|
||||||
op_kind, op_name, params = _extract_op_info(msg)
|
op_kind, op_name, params = _extract_op_info(msg)
|
||||||
|
# Merge TileToken stage_type captured at record_start into params,
|
||||||
|
# and reflect it in op_name so reporting can disambiguate
|
||||||
|
# DMA_READ vs DMA_WRITE and FETCH vs STORE on the same component.
|
||||||
|
snap = pending.get("snap", {})
|
||||||
|
stage_type = snap.get("stage_type")
|
||||||
|
if stage_type is not None:
|
||||||
|
params = dict(params)
|
||||||
|
params["stage_type"] = stage_type
|
||||||
|
if op_name == "TileToken":
|
||||||
|
op_name = f"TileToken/{stage_type}"
|
||||||
# Snapshot data at record time so Phase 2 replay sidesteps
|
# Snapshot data at record time so Phase 2 replay sidesteps
|
||||||
# downstream mutations of source addrs (e.g. a tl.store that
|
# downstream mutations of source addrs (e.g. a tl.store that
|
||||||
# overwrites HBM after a load handle was sent, or a slot that
|
# overwrites HBM after a load handle was sent, or a slot that
|
||||||
|
|||||||
@@ -123,13 +123,14 @@ class TLContext:
|
|||||||
|
|
||||||
def _make_handle(
|
def _make_handle(
|
||||||
self, addr: int, shape: tuple[int, ...], dtype: str,
|
self, addr: int, shape: tuple[int, ...], dtype: str,
|
||||||
space: str = "tcm",
|
space: str = "tcm", pinned: bool = False,
|
||||||
) -> TensorHandle:
|
) -> TensorHandle:
|
||||||
return TensorHandle(
|
return TensorHandle(
|
||||||
id=self._next_handle_id(),
|
id=self._next_handle_id(),
|
||||||
addr=addr, shape=shape, dtype=dtype,
|
addr=addr, shape=shape, dtype=dtype,
|
||||||
nbytes=self._nbytes(shape, dtype),
|
nbytes=self._nbytes(shape, dtype),
|
||||||
space=space,
|
space=space,
|
||||||
|
pinned=pinned,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _make_compute_out(
|
def _make_compute_out(
|
||||||
@@ -184,15 +185,17 @@ class TLContext:
|
|||||||
actually lives in Phase 2 storage.
|
actually lives in Phase 2 storage.
|
||||||
"""
|
"""
|
||||||
self._emit_dispatch_overhead()
|
self._emit_dispatch_overhead()
|
||||||
handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype, space="hbm")
|
handle = self._make_handle(
|
||||||
|
addr=ptr, shape=shape, dtype=dtype, space="hbm", pinned=True,
|
||||||
|
)
|
||||||
cmd = DmaReadCmd(handle=handle, src_addr=ptr, nbytes=handle.nbytes)
|
cmd = DmaReadCmd(handle=handle, src_addr=ptr, nbytes=handle.nbytes)
|
||||||
data = self._emit(cmd)
|
data = self._emit(cmd)
|
||||||
if data is not None:
|
if data is not None:
|
||||||
# Greenlet mode: attach real data to handle (preserve space)
|
# Greenlet mode: attach real data to handle (preserve space + pinned)
|
||||||
return TensorHandle(
|
return TensorHandle(
|
||||||
id=handle.id, addr=handle.addr, shape=handle.shape,
|
id=handle.id, addr=handle.addr, shape=handle.shape,
|
||||||
dtype=handle.dtype, nbytes=handle.nbytes, data=data,
|
dtype=handle.dtype, nbytes=handle.nbytes, data=data,
|
||||||
space=handle.space,
|
space=handle.space, pinned=handle.pinned,
|
||||||
)
|
)
|
||||||
return handle
|
return handle
|
||||||
|
|
||||||
|
|||||||
@@ -150,7 +150,11 @@ def test_gemm_plan_stage_sequence():
|
|||||||
|
|
||||||
|
|
||||||
def test_gemm_plan_intermediate_k_no_dma_write():
|
def test_gemm_plan_intermediate_k_no_dma_write():
|
||||||
"""Intermediate K-tiles don't have DMA_WRITE stage."""
|
"""Intermediate K-tiles don't have DMA_WRITE or STORE stage.
|
||||||
|
|
||||||
|
The C accumulator stays in RegFile across the K loop; STORE +
|
||||||
|
DMA_WRITE only fire on the last K-tile per (m,n).
|
||||||
|
"""
|
||||||
from kernbench.components.builtin.tiling import generate_gemm_plan
|
from kernbench.components.builtin.tiling import generate_gemm_plan
|
||||||
|
|
||||||
plan = generate_gemm_plan(
|
plan = generate_gemm_plan(
|
||||||
@@ -162,15 +166,72 @@ def test_gemm_plan_intermediate_k_no_dma_write():
|
|||||||
)
|
)
|
||||||
assert len(plan.tiles) == 2
|
assert len(plan.tiles) == 2
|
||||||
|
|
||||||
# First tile (k=0): no DMA_WRITE
|
# First tile (k=0): no STORE, no DMA_WRITE — accumulator stays in RegFile
|
||||||
t0_types = [s.stage_type for s in plan.tiles[0].stages]
|
t0_types = [s.stage_type for s in plan.tiles[0].stages]
|
||||||
|
assert StageType.STORE not in t0_types
|
||||||
assert StageType.DMA_WRITE not in t0_types
|
assert StageType.DMA_WRITE not in t0_types
|
||||||
|
|
||||||
# Last tile (k=1, last_k=True): has DMA_WRITE
|
# Last tile (k=1, last_k=True): has both STORE and DMA_WRITE
|
||||||
t1_types = [s.stage_type for s in plan.tiles[1].stages]
|
t1_types = [s.stage_type for s in plan.tiles[1].stages]
|
||||||
|
assert StageType.STORE in t1_types
|
||||||
assert StageType.DMA_WRITE in t1_types
|
assert StageType.DMA_WRITE in t1_types
|
||||||
|
|
||||||
|
|
||||||
|
def test_gemm_plan_pinned_operand_skips_dma_read():
    """Pinned operands drop their per-tile DMA_READ stage.

    With ``a_pinned=True`` only B is DMA-read; with both pinned no
    DMA_READ stage is emitted at all. FETCH must still be present when
    A is pinned.
    """
    from kernbench.components.builtin.tiling import generate_gemm_plan

    # Shared problem for all three plans: K=128 with tile_k=64 → K_tiles=2.
    common = dict(
        M=32, K=128, N=32,
        tile_m=32, tile_k=64, tile_n=32,
        bytes_per_element=2,
        A_addr=0, B_addr=0x1000, C_addr=0x2000,
        pe_prefix="sip0.cube0.pe0",
    )

    def dma_read_stages(tile):
        return [s for s in tile.stages if s.stage_type == StageType.DMA_READ]

    # Baseline: neither pinned — both A and B get DMA_READ per tile.
    base = generate_gemm_plan(**common)
    for tile in base.tiles:
        operands = [s.params.get("operand") for s in dma_read_stages(tile)]
        assert operands == ["A", "B"], \
            f"baseline tile should DMA_READ A and B, got {operands}"

    # a_pinned: no A DMA_READ.
    plan_a = generate_gemm_plan(**common, a_pinned=True)
    for tile in plan_a.tiles:
        operands = [s.params.get("operand") for s in dma_read_stages(tile)]
        assert operands == ["B"], \
            f"a_pinned should leave only B DMA_READ, got {operands}"
        # FETCH must still exist
        assert any(s.stage_type == StageType.FETCH for s in tile.stages)

    # Both pinned: no DMA_READ at all.
    plan_both = generate_gemm_plan(**common, a_pinned=True, b_pinned=True)
    for tile in plan_both.tiles:
        dma_reads = dma_read_stages(tile)
        assert dma_reads == [], \
            f"both pinned should skip all DMA_READ, got {dma_reads}"
|
||||||
|
|
||||||
|
|
||||||
def test_math_plan_stage_sequence():
|
def test_math_plan_stage_sequence():
|
||||||
"""Math plan has READ→FETCH→MATH→STORE→WRITE sequence."""
|
"""Math plan has READ→FETCH→MATH→STORE→WRITE sequence."""
|
||||||
from kernbench.components.builtin.tiling import generate_math_plan
|
from kernbench.components.builtin.tiling import generate_math_plan
|
||||||
|
|||||||
Reference in New Issue
Block a user