114510d4b9
- Add cycle-accurate PE accelerator scheduler (SchedulerV2) with tiled GEMM/Math pipelines (DMA_IN → GEMM → MATH → DMA_WB) - Add DPPolicy num_pes/num_cubes/num_sips overrides for single-PE testing - Support tuple target_pe for targeting specific PE subsets - Add gemm_single_pe and gpt3_qkv benchmarks - Switch default topology to pe_scheduler_v2 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
335 lines
13 KiB
Python
335 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Generator
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import simpy
|
|
|
|
from kernbench.components.base import ComponentBase
|
|
from kernbench.sim_engine.transaction import Transaction
|
|
|
|
if TYPE_CHECKING:
|
|
from kernbench.components.context import ComponentContext
|
|
from kernbench.topology.types import Node
|
|
|
|
|
|
class MCpuComponent(ComponentBase):
    """M_CPU component: multi-PE fan-out with response aggregation.

    Forward path (ADR-0015 D5):
        When a forward Transaction arrives at m_cpu (terminal hop), M_CPU fans
        out sub-Transactions according to the request type:
          * KernelLaunchMsg          -> each target PE's pe_cpu (via NOC)
          * MmuMapMsg / MmuUnmapMsg  -> each target PE's pe_mmu (via NOC)
          * any other request (DMA)  -> HBM slice(s) via the M_CPU DMA path
        ``target_pe`` on the request selects the targets: int -> single PE,
        tuple -> that PE subset (kernel launch / MMU), "all" -> all PEs in the
        cube. For DMA, a decodable dst_pa/src_pa takes precedence over "all".

    Response path:
        ResponseMsgs from hbm_ctrl / pe_cpu arrive back at m_cpu as terminal
        response hops. Once all expected responses for a request are
        collected, m_cpu sends an aggregate ResponseMsg on the reverse command
        path back to io_cpu.

    Transit:
        When m_cpu is NOT the terminal hop (forward transit or response
        relay), the Transaction is forwarded to its next hop after the node's
        ``overhead_ns`` delay.
    """

    # Fallback HBM-slice (and PE) count per cube when the spec has no
    # cube.memory_map.hbm_slices_per_cube entry.
    _DEFAULT_SLICES_PER_CUBE = 8

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Pending fan-out tracking: request_id -> (expected, received, all_done).
        # ``expected`` is None while sub-transactions are still being
        # dispatched, so responses that arrive early are counted, not dropped.
        self._pending: dict[str, tuple[int | None, int, simpy.Event]] = {}
        # Parent txn of each in-flight fan-out: request_id -> parent_txn
        self._parent_txns: dict[str, Any] = {}
        # DMA engine resources (ADR-0015 D5, ADR-0014 D4): capacity=1 each,
        # created in start() once the simpy environment is available.
        self._dma_write: simpy.Resource | None = None
        self._dma_read: simpy.Resource | None = None

    def start(self, env: simpy.Environment) -> None:
        """Create the per-direction DMA engine resources, then start the base worker."""
        self._dma_write = simpy.Resource(env, capacity=1)
        self._dma_read = simpy.Resource(env, capacity=1)
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Model the fixed per-hop processing delay (node attr ``overhead_ns``).

        ``nbytes`` is accepted for interface compatibility but not used.
        """
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    # ------------------------------------------------------------------
    # Inbox dispatch
    # ------------------------------------------------------------------

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch forward txns, relay transit responses, collect terminal responses."""
        from kernbench.runtime_api.kernel import KernelLaunchMsg, MmuMapMsg, MmuUnmapMsg

        while True:
            txn: Any = yield self._inbox.get()

            if getattr(txn, "is_response", False) and not getattr(txn, "next_hop", None):
                # Terminal response hop: count it toward its pending fan-out.
                self._collect_response(txn)
                continue

            yield from self.run(env, txn.nbytes)
            next_hop = txn.next_hop
            if next_hop:
                # Transit hop: forward request or relayed response.
                yield self.out_ports[next_hop].put(txn.advance())
            elif self.ctx is None or txn.request is None:
                txn.done.succeed()
            elif isinstance(txn.request, KernelLaunchMsg):
                env.process(self._kernel_launch_fanout(env, txn))
            elif isinstance(txn.request, (MmuMapMsg, MmuUnmapMsg)):
                env.process(self._mmu_msg_fanout(env, txn))
            else:
                env.process(self._dma_fanout(env, txn))

    # ------------------------------------------------------------------
    # Response aggregation
    # ------------------------------------------------------------------

    def _collect_response(self, resp_txn: Any) -> None:
        """Count one sub-response; fire the fan-out's event once all have arrived.

        Responses whose request_id has no pending fan-out are ignored.
        """
        key = resp_txn.request.request_id
        entry = self._pending.get(key)
        if entry is None:
            return
        expected, received, all_done = entry
        received += 1
        # expected is None while dispatch is still in progress (see
        # _open_aggregation); completion is then decided by _seal_aggregation.
        if expected is not None and received >= expected:
            del self._pending[key]
            all_done.succeed()
        else:
            self._pending[key] = (expected, received, all_done)

    def _open_aggregation(self, env: simpy.Environment, txn: Any) -> simpy.Event:
        """Register *txn*'s fan-out BEFORE dispatching so no response is missed.

        Returns the event that fires when all sub-responses have been received.
        """
        key = txn.request.request_id
        all_done = env.event()
        self._pending[key] = (None, 0, all_done)
        self._parent_txns[key] = txn
        return all_done

    def _seal_aggregation(self, request_id: str, n_dispatched: int) -> None:
        """Fix the expected response count after dispatch; fire if already met."""
        _, received, all_done = self._pending[request_id]
        if received >= n_dispatched:
            del self._pending[request_id]
            all_done.succeed()
        else:
            self._pending[request_id] = (n_dispatched, received, all_done)

    def _abort_aggregation(self, request_id: str) -> None:
        """Drop the bookkeeping of a fan-out that dispatched nothing."""
        self._pending.pop(request_id, None)
        self._parent_txns.pop(request_id, None)

    def _send_aggregate_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Send one aggregate ResponseMsg for *txn* back along its reversed path.

        When the reversed path has fewer than 2 nodes there is no previous hop
        to send to, so ``txn.done`` is succeeded directly instead.
        """
        reverse_path = list(reversed(txn.path))
        if len(reverse_path) < 2:
            txn.done.succeed()
            return

        from kernbench.runtime_api.kernel import ResponseMsg

        request = txn.request
        resp_msg = ResponseMsg(
            correlation_id=request.correlation_id,
            request_id=request.request_id,
            src_cube=self._cube_index(), src_pe=-1, success=True,
        )
        resp_txn = Transaction(
            request=resp_msg, path=reverse_path, step=0,
            nbytes=0, done=env.event(), is_response=True,
        )
        yield self.out_ports[reverse_path[1]].put(resp_txn.advance())

    # ------------------------------------------------------------------
    # Fan-outs
    # ------------------------------------------------------------------

    def _dma_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out DMA sub-Transactions to target HBM slice(s), wait for all
        responses, then send an aggregate response on the reverse command path.

        Each DMA transfer acquires the DMA resource (capacity=1 per ADR-0014 D4),
        so multi-PE fan-out is serialized through the DMA engine. Destinations
        that cannot be routed are skipped and are not waited for. If none can
        be routed, ``txn.done`` is succeeded without a response.
        Sets ``txn.result_data["xfer_ns"]`` to the largest per-path drain time.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg

        request = txn.request
        target_pe = getattr(request, "target_pe", "all")

        dst_nodes = self._resolve_dma_destinations(request, target_pe)
        if not dst_nodes:
            txn.done.succeed()
            return

        # Register before dispatching: responses may arrive while later
        # transfers are still waiting for the DMA engine.
        all_done = self._open_aggregation(env, txn)

        # Select DMA resource based on operation type
        dma_res = self._dma_write if isinstance(request, MemoryWriteMsg) else self._dma_read

        max_drain_ns = 0.0
        n_dispatched = 0
        for dst_node in dst_nodes:
            try:
                dma_path = self.ctx.router.find_mcpu_dma_path(self.node.id, dst_node)
            except Exception:
                # Best effort: unroutable destinations are skipped (and, since
                # they are not counted below, not waited for).
                continue
            if len(dma_path) < 2:
                continue
            drain_ns = self.ctx.compute_drain_ns(dma_path, txn.nbytes)
            max_drain_ns = max(max_drain_ns, drain_ns)
            sub_txn = Transaction(
                request=request, path=dma_path, step=0,
                nbytes=txn.nbytes, done=env.event(),
                drain_ns=drain_ns,
            )
            with dma_res.request() as req:
                yield req
                yield self.out_ports[dma_path[1]].put(sub_txn.advance())
            n_dispatched += 1

        if n_dispatched == 0:
            # Nothing in flight: waiting on all_done would hang forever.
            self._abort_aggregation(request.request_id)
            txn.done.succeed()
            return
        self._seal_aggregation(request.request_id, n_dispatched)

        # Wait for all HBM responses
        yield all_done
        txn.result_data["xfer_ns"] = max_drain_ns
        self._parent_txns.pop(request.request_id, None)

        yield from self._send_aggregate_response(env, txn)

    def _kernel_launch_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out KernelLaunchMsg to target PE_CPU(s) via NOC (ADR-0009 D3).

        Routes through find_node_path (M_CPU -> NOC -> PE_CPU command edges).
        PE_CPU sends ResponseMsg back via NOC -> M_CPU on completion. Once all
        dispatched PEs have responded, the max of each PE metric
        (pe_exec_ns / dma_ns / compute_ns) read from the sub-txns' result_data
        is stored on ``txn.result_data``, and an aggregate ResponseMsg is sent
        back to IO_CPU on the reverse path. If no PE can be reached,
        ``txn.done`` is succeeded without a response.
        """
        request = txn.request
        pe_ids = self._resolve_pe_ids(getattr(request, "target_pe", "all"))
        if not pe_ids:
            txn.done.succeed()
            return

        cube_prefix = self._cube_prefix()
        # Register before dispatching so a PE_CPU response that arrives while
        # later puts are still blocked is not dropped by _collect_response.
        all_done = self._open_aggregation(env, txn)

        sub_txns: list[Transaction] = []
        for pe_id in pe_ids:
            pe_cpu_id = f"{cube_prefix}.pe{pe_id}.pe_cpu"
            try:
                path = self.ctx.router.find_node_path(self.node.id, pe_cpu_id)
            except Exception:
                # Best effort: unreachable PEs are skipped and not waited for.
                continue
            if len(path) < 2:
                continue
            sub_txn = Transaction(
                request=request, path=path, step=0,
                nbytes=0, done=env.event(),
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())
            sub_txns.append(sub_txn)

        if not sub_txns:
            self._abort_aggregation(request.request_id)
            txn.done.succeed()
            return
        self._seal_aggregation(request.request_id, len(sub_txns))

        # Wait for all PE_CPU responses via NOC
        yield all_done
        self._parent_txns.pop(request.request_id, None)

        # Aggregate PE-internal metrics (max across PEs)
        for metric in ("pe_exec_ns", "dma_ns", "compute_ns"):
            txn.result_data[metric] = max(st.result_data.get(metric, 0.0) for st in sub_txns)

        yield from self._send_aggregate_response(env, txn)

    def _mmu_msg_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out MmuMapMsg/MmuUnmapMsg to target PE_MMU(s) via NOC.

        Routes through find_node_path (M_CPU -> NOC -> PE_MMU command edges).
        Completion of each PE_MMU is observed through the sub-txn's ``done``
        event (no ResponseMsg aggregation). After all dispatched sub-txns are
        done, an aggregate ResponseMsg is sent on the reverse path.
        """
        request = txn.request
        pe_ids = self._resolve_pe_ids(getattr(request, "target_pe", "all"))
        if not pe_ids:
            txn.done.succeed()
            return

        cube_prefix = self._cube_prefix()
        sub_dones: list[simpy.Event] = []
        for pe_id in pe_ids:
            pe_mmu_id = f"{cube_prefix}.pe{pe_id}.pe_mmu"
            try:
                path = self.ctx.router.find_node_path(self.node.id, pe_mmu_id)
            except Exception:
                # Best effort: unreachable PE_MMUs are skipped.
                continue
            if len(path) < 2:
                continue
            sub_done = env.event()
            sub_txn = Transaction(
                request=request, path=path, step=0,
                nbytes=0, done=sub_done,
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())
            sub_dones.append(sub_done)

        # Wait for all dispatched PE_MMUs to complete
        for sub_done in sub_dones:
            yield sub_done

        yield from self._send_aggregate_response(env, txn)

    # ------------------------------------------------------------------
    # Target resolution helpers
    # ------------------------------------------------------------------

    def _resolve_dma_destinations(self, request: Any, target_pe: int | str) -> list[str]:
        """Return the HBM destination node ids for a DMA fan-out.

        * int ``target_pe`` -> that PE's slice in the local cube (PA ignored).
        * otherwise, if the request carries a dst_pa (else src_pa) that decodes
          and resolves, the single resolved node. This enables cross-cube DMA
          routing when the PA points to a remote cube.
        * otherwise -> every HBM slice of the local cube.

        NOTE(review): a tuple ``target_pe`` (accepted by _resolve_pe_ids) is
        not special-cased here and falls through to the PA / all-slices
        rules. Confirm that is intended for DMA requests.
        """
        cube_prefix = self._cube_prefix()

        if isinstance(target_pe, int):
            return [f"{cube_prefix}.hbm_ctrl.slice{target_pe}"]

        # NOTE(review): ``or`` treats a dst_pa of 0 as absent and falls back
        # to src_pa. Confirm PA 0 can never be a real destination.
        pa_val = getattr(request, "dst_pa", None) or getattr(request, "src_pa", None)
        if pa_val is not None:
            from kernbench.policy.address.phyaddr import PhysAddr

            try:
                return [self.ctx.resolver.resolve(PhysAddr.decode(pa_val))]
            except Exception:
                # Best effort: an undecodable / unresolvable PA falls back to
                # fanning out to every local slice below.
                pass

        # "all" without a usable PA: all slices in the local cube
        return [
            f"{cube_prefix}.hbm_ctrl.slice{i}" for i in range(self._slices_per_cube())
        ]

    def _resolve_pe_ids(self, target_pe: int | tuple | str) -> list[int]:
        """Return the PE ids a kernel-launch / MMU fan-out targets.

        int -> [target_pe]; tuple -> its elements in order; anything else
        (e.g. "all") -> every PE in the local cube, counted as one per HBM slice.
        """
        if isinstance(target_pe, int):
            return [target_pe]
        if isinstance(target_pe, tuple):
            return list(target_pe)
        return list(range(self._slices_per_cube()))

    def _slices_per_cube(self) -> int:
        """HBM slices per cube from spec["cube"]["memory_map"], default 8."""
        if self.ctx and self.ctx.spec:
            mm = self.ctx.spec.get("cube", {}).get("memory_map", {})
            return mm.get("hbm_slices_per_cube", self._DEFAULT_SLICES_PER_CUBE)
        return self._DEFAULT_SLICES_PER_CUBE

    def _cube_prefix(self) -> str:
        """Return the node id without its last dotted component (e.g. "sip0.cube0")."""
        return self.node.id.rsplit(".", 1)[0]

    def _cube_index(self) -> int:
        """Return the cube number parsed from the node id's second component ("cubeN")."""
        return int(self.node.id.split(".")[1].replace("cube", ""))
|