Composite GEMM: K-loop accumulator residency, pinned operands, sweep + deck

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-13 15:00:41 -07:00
parent 5accd98171
commit 83ea97b05f
11 changed files with 4219 additions and 51 deletions
+1
View File
@@ -34,6 +34,7 @@ class TensorHandle:
nbytes: int # total byte size
data: object = None # reserved for validate mode
space: str = "tcm" # MemoryStore space ("tcm" | "hbm" | "sram")
pinned: bool = False # operand already DMA-staged in TCM (via tl.load)
@dataclass(frozen=True)
@@ -163,6 +163,8 @@ class PeSchedulerComponent(ComponentBase):
bytes_per_element=bpe,
A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr,
pe_prefix=pp,
a_pinned=getattr(a, "pinned", False),
b_pinned=getattr(b, "pinned", False),
)
else:
# Math composite
+39 -30
View File
@@ -21,15 +21,22 @@ def generate_gemm_plan(
bytes_per_element: int,
A_addr: int, B_addr: int, C_addr: int,
pe_prefix: str,
a_pinned: bool = False,
b_pinned: bool = False,
) -> PipelinePlan:
"""Generate GEMM tile plan: M→N→K order.
Each tile follows this stage sequence:
DMA_READ(A) → DMA_READ(B) → FETCH → GEMM → STORE
On last K-tile per (m,n): → DMA_WRITE
[DMA_READ(A)][DMA_READ(B)] → FETCH → GEMM → [STORE → DMA_WRITE]
DMA_READ(A) is skipped when a_pinned=True (operand pre-staged in TCM).
DMA_READ(B) is skipped when b_pinned=True.
STORE + DMA_WRITE are emitted only on the last K-tile per (m,n); the
accumulator stays in the RegFile across the K loop.
Args:
pe_prefix: e.g. "sip0.cube0.pe0" — used to build component IDs.
a_pinned: A operand already resident in TCM (via prior tl.load).
b_pinned: B operand already resident in TCM.
"""
M_tiles = max(1, ceil(M / tile_m))
K_tiles = max(1, ceil(K / tile_k))
@@ -58,23 +65,26 @@ def generate_gemm_plan(
stages: list[Stage] = []
# DMA READ: load A and B tiles from HBM → TCM
stages.append(Stage(
stage_type=StageType.DMA_READ,
component=dma_id,
params={
"src_addr": a_addr, "nbytes": a_bytes,
"operand": "A", "tile_m": tile_m, "tile_k": tile_k,
},
))
stages.append(Stage(
stage_type=StageType.DMA_READ,
component=dma_id,
params={
"src_addr": b_addr, "nbytes": b_bytes,
"operand": "B", "tile_k": tile_k, "tile_n": tile_n,
},
))
# DMA READ: load A and B tiles from HBM → TCM.
# Skip if the operand is already pre-staged via tl.load.
if not a_pinned:
stages.append(Stage(
stage_type=StageType.DMA_READ,
component=dma_id,
params={
"src_addr": a_addr, "nbytes": a_bytes,
"operand": "A", "tile_m": tile_m, "tile_k": tile_k,
},
))
if not b_pinned:
stages.append(Stage(
stage_type=StageType.DMA_READ,
component=dma_id,
params={
"src_addr": b_addr, "nbytes": b_bytes,
"operand": "B", "tile_k": tile_k, "tile_n": tile_n,
},
))
# FETCH: TCM → Register File
stages.append(Stage(
@@ -96,18 +106,17 @@ def generate_gemm_plan(
},
))
# STORE: Register File → TCM
stages.append(Stage(
stage_type=StageType.STORE,
component=fetch_id,
params={
"direction": "write",
"nbytes": out_bytes,
},
))
# DMA WRITE: TCM → HBM (only on last K-tile)
# STORE + DMA_WRITE only on last K-tile per (m,n). The C
# accumulator stays in RegFile across the K loop.
if last_k:
stages.append(Stage(
stage_type=StageType.STORE,
component=fetch_id,
params={
"direction": "write",
"nbytes": out_bytes,
},
))
stages.append(Stage(
stage_type=StageType.DMA_WRITE,
component=dma_id,
+25 -1
View File
@@ -44,11 +44,25 @@ class OpLogger:
return self._records
def record_start(self, t: float, component_id: str, msg: Any) -> None:
"""Called by ComponentBase._on_process_start."""
"""Called by ComponentBase._on_process_start.
Snapshots TileToken stage_type at start time so we can attribute the
record correctly even if the token advances stage_idx before
record_end fires.
"""
snap: dict[str, Any] = {}
# TileToken (ADR-0021 pipeline) — capture which stage this is.
try:
stage = getattr(msg, "current_stage", None)
if stage is not None:
snap["stage_type"] = stage.stage_type.name
except Exception:
pass
self._pending[id(msg)] = {
"t_start": t,
"component_id": component_id,
"msg": msg,
"snap": snap,
}
def record_end(self, t: float, component_id: str, msg: Any) -> None:
@@ -57,6 +71,16 @@ class OpLogger:
if pending is None:
return
op_kind, op_name, params = _extract_op_info(msg)
# Merge TileToken stage_type captured at record_start into params,
# and reflect it in op_name so reporting can disambiguate
# DMA_READ vs DMA_WRITE and FETCH vs STORE on the same component.
snap = pending.get("snap", {})
stage_type = snap.get("stage_type")
if stage_type is not None:
params = dict(params)
params["stage_type"] = stage_type
if op_name == "TileToken":
op_name = f"TileToken/{stage_type}"
# Snapshot data at record time so Phase 2 replay sidesteps
# downstream mutations of source addrs (e.g. a tl.store that
# overwrites HBM after a load handle was sent, or a slot that
+7 -4
View File
@@ -123,13 +123,14 @@ class TLContext:
def _make_handle(
self, addr: int, shape: tuple[int, ...], dtype: str,
space: str = "tcm",
space: str = "tcm", pinned: bool = False,
) -> TensorHandle:
return TensorHandle(
id=self._next_handle_id(),
addr=addr, shape=shape, dtype=dtype,
nbytes=self._nbytes(shape, dtype),
space=space,
pinned=pinned,
)
def _make_compute_out(
@@ -184,15 +185,17 @@ class TLContext:
actually lives in Phase 2 storage.
"""
self._emit_dispatch_overhead()
handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype, space="hbm")
handle = self._make_handle(
addr=ptr, shape=shape, dtype=dtype, space="hbm", pinned=True,
)
cmd = DmaReadCmd(handle=handle, src_addr=ptr, nbytes=handle.nbytes)
data = self._emit(cmd)
if data is not None:
# Greenlet mode: attach real data to handle (preserve space)
# Greenlet mode: attach real data to handle (preserve space + pinned)
return TensorHandle(
id=handle.id, addr=handle.addr, shape=handle.shape,
dtype=handle.dtype, nbytes=handle.nbytes, data=data,
space=handle.space,
space=handle.space, pinned=handle.pinned,
)
return handle