Composite GEMM: K-loop accumulator residency, pinned operands, sweep + deck
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -34,6 +34,7 @@ class TensorHandle:
|
||||
nbytes: int # total byte size
|
||||
data: object = None # reserved for validate mode
|
||||
space: str = "tcm" # MemoryStore space ("tcm" | "hbm" | "sram")
|
||||
pinned: bool = False # operand already DMA-staged in TCM (via tl.load)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@@ -163,6 +163,8 @@ class PeSchedulerComponent(ComponentBase):
|
||||
bytes_per_element=bpe,
|
||||
A_addr=a.addr, B_addr=b.addr, C_addr=cmd.out_addr,
|
||||
pe_prefix=pp,
|
||||
a_pinned=getattr(a, "pinned", False),
|
||||
b_pinned=getattr(b, "pinned", False),
|
||||
)
|
||||
else:
|
||||
# Math composite
|
||||
|
||||
@@ -21,15 +21,22 @@ def generate_gemm_plan(
|
||||
bytes_per_element: int,
|
||||
A_addr: int, B_addr: int, C_addr: int,
|
||||
pe_prefix: str,
|
||||
a_pinned: bool = False,
|
||||
b_pinned: bool = False,
|
||||
) -> PipelinePlan:
|
||||
"""Generate GEMM tile plan: M→N→K order.
|
||||
|
||||
Each tile follows stage sequence:
|
||||
DMA_READ(A) → DMA_READ(B) → FETCH → GEMM → STORE
|
||||
On last K-tile per (m,n): → DMA_WRITE
|
||||
[DMA_READ(A)] → [DMA_READ(B)] → FETCH → GEMM → [STORE → DMA_WRITE]
|
||||
DMA_READ(A) skipped when a_pinned=True (operand pre-staged in TCM).
|
||||
DMA_READ(B) skipped when b_pinned=True.
|
||||
STORE + DMA_WRITE only emitted on last K-tile per (m,n) — accumulator
|
||||
stays in RegFile across K loop.
|
||||
|
||||
Args:
|
||||
pe_prefix: e.g. "sip0.cube0.pe0" — used to build component IDs.
|
||||
a_pinned: A operand already resident in TCM (via prior tl.load).
|
||||
b_pinned: B operand already resident in TCM.
|
||||
"""
|
||||
M_tiles = max(1, ceil(M / tile_m))
|
||||
K_tiles = max(1, ceil(K / tile_k))
|
||||
@@ -58,23 +65,26 @@ def generate_gemm_plan(
|
||||
|
||||
stages: list[Stage] = []
|
||||
|
||||
# DMA READ: load A and B tiles from HBM → TCM
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.DMA_READ,
|
||||
component=dma_id,
|
||||
params={
|
||||
"src_addr": a_addr, "nbytes": a_bytes,
|
||||
"operand": "A", "tile_m": tile_m, "tile_k": tile_k,
|
||||
},
|
||||
))
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.DMA_READ,
|
||||
component=dma_id,
|
||||
params={
|
||||
"src_addr": b_addr, "nbytes": b_bytes,
|
||||
"operand": "B", "tile_k": tile_k, "tile_n": tile_n,
|
||||
},
|
||||
))
|
||||
# DMA READ: load A and B tiles from HBM → TCM.
|
||||
# Skip if the operand is already pre-staged via tl.load.
|
||||
if not a_pinned:
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.DMA_READ,
|
||||
component=dma_id,
|
||||
params={
|
||||
"src_addr": a_addr, "nbytes": a_bytes,
|
||||
"operand": "A", "tile_m": tile_m, "tile_k": tile_k,
|
||||
},
|
||||
))
|
||||
if not b_pinned:
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.DMA_READ,
|
||||
component=dma_id,
|
||||
params={
|
||||
"src_addr": b_addr, "nbytes": b_bytes,
|
||||
"operand": "B", "tile_k": tile_k, "tile_n": tile_n,
|
||||
},
|
||||
))
|
||||
|
||||
# FETCH: TCM → Register File
|
||||
stages.append(Stage(
|
||||
@@ -96,18 +106,17 @@ def generate_gemm_plan(
|
||||
},
|
||||
))
|
||||
|
||||
# STORE: Register File → TCM
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.STORE,
|
||||
component=fetch_id,
|
||||
params={
|
||||
"direction": "write",
|
||||
"nbytes": out_bytes,
|
||||
},
|
||||
))
|
||||
|
||||
# DMA WRITE: TCM → HBM (only on last K-tile)
|
||||
# STORE + DMA_WRITE only on last K-tile per (m,n). The C
|
||||
# accumulator stays in RegFile across the K loop.
|
||||
if last_k:
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.STORE,
|
||||
component=fetch_id,
|
||||
params={
|
||||
"direction": "write",
|
||||
"nbytes": out_bytes,
|
||||
},
|
||||
))
|
||||
stages.append(Stage(
|
||||
stage_type=StageType.DMA_WRITE,
|
||||
component=dma_id,
|
||||
|
||||
@@ -44,11 +44,25 @@ class OpLogger:
|
||||
return self._records
|
||||
|
||||
def record_start(self, t: float, component_id: str, msg: Any) -> None:
|
||||
"""Called by ComponentBase._on_process_start."""
|
||||
"""Called by ComponentBase._on_process_start.
|
||||
|
||||
Snapshots TileToken stage_type at start time so we can attribute the
|
||||
record correctly even if the token advances stage_idx before
|
||||
record_end fires.
|
||||
"""
|
||||
snap: dict[str, Any] = {}
|
||||
# TileToken (ADR-0021 pipeline) — capture which stage this is.
|
||||
try:
|
||||
stage = getattr(msg, "current_stage", None)
|
||||
if stage is not None:
|
||||
snap["stage_type"] = stage.stage_type.name
|
||||
except Exception:
|
||||
pass
|
||||
self._pending[id(msg)] = {
|
||||
"t_start": t,
|
||||
"component_id": component_id,
|
||||
"msg": msg,
|
||||
"snap": snap,
|
||||
}
|
||||
|
||||
def record_end(self, t: float, component_id: str, msg: Any) -> None:
|
||||
@@ -57,6 +71,16 @@ class OpLogger:
|
||||
if pending is None:
|
||||
return
|
||||
op_kind, op_name, params = _extract_op_info(msg)
|
||||
# Merge TileToken stage_type captured at record_start into params,
|
||||
# and reflect it in op_name so reporting can disambiguate
|
||||
# DMA_READ vs DMA_WRITE and FETCH vs STORE on the same component.
|
||||
snap = pending.get("snap", {})
|
||||
stage_type = snap.get("stage_type")
|
||||
if stage_type is not None:
|
||||
params = dict(params)
|
||||
params["stage_type"] = stage_type
|
||||
if op_name == "TileToken":
|
||||
op_name = f"TileToken/{stage_type}"
|
||||
# Snapshot data at record time so Phase 2 replay sidesteps
|
||||
# downstream mutations of source addrs (e.g. a tl.store that
|
||||
# overwrites HBM after a load handle was sent, or a slot that
|
||||
|
||||
@@ -123,13 +123,14 @@ class TLContext:
|
||||
|
||||
def _make_handle(
|
||||
self, addr: int, shape: tuple[int, ...], dtype: str,
|
||||
space: str = "tcm",
|
||||
space: str = "tcm", pinned: bool = False,
|
||||
) -> TensorHandle:
|
||||
return TensorHandle(
|
||||
id=self._next_handle_id(),
|
||||
addr=addr, shape=shape, dtype=dtype,
|
||||
nbytes=self._nbytes(shape, dtype),
|
||||
space=space,
|
||||
pinned=pinned,
|
||||
)
|
||||
|
||||
def _make_compute_out(
|
||||
@@ -184,15 +185,17 @@ class TLContext:
|
||||
actually lives in Phase 2 storage.
|
||||
"""
|
||||
self._emit_dispatch_overhead()
|
||||
handle = self._make_handle(addr=ptr, shape=shape, dtype=dtype, space="hbm")
|
||||
handle = self._make_handle(
|
||||
addr=ptr, shape=shape, dtype=dtype, space="hbm", pinned=True,
|
||||
)
|
||||
cmd = DmaReadCmd(handle=handle, src_addr=ptr, nbytes=handle.nbytes)
|
||||
data = self._emit(cmd)
|
||||
if data is not None:
|
||||
# Greenlet mode: attach real data to handle (preserve space)
|
||||
# Greenlet mode: attach real data to handle (preserve space + pinned)
|
||||
return TensorHandle(
|
||||
id=handle.id, addr=handle.addr, shape=handle.shape,
|
||||
dtype=handle.dtype, nbytes=handle.nbytes, data=data,
|
||||
space=handle.space,
|
||||
space=handle.space, pinned=handle.pinned,
|
||||
)
|
||||
return handle
|
||||
|
||||
|
||||
Reference in New Issue
Block a user