commit - release 1

This commit is contained in:
2026-03-18 11:47:48 -07:00
commit 6f43807900
109 changed files with 14909 additions and 0 deletions
+4
View File
@@ -0,0 +1,4 @@
from kernbench.components.base import ComponentBase, ComponentRegistry
from kernbench.components.context import ComponentContext
__all__ = ["ComponentBase", "ComponentRegistry", "ComponentContext"]
+167
View File
@@ -0,0 +1,167 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class ComponentBase(ABC):
    """Abstract parent of all SimPy component implementations (ADR-0007 D3, ADR-0015).

    One instance stands for a single node of the compiled topology graph and
    expresses that node's processing overhead as a SimPy generator, leaving
    room for later implementations to add queueing and contention.

    Port model (ADR-0015 D1):
        in_ports[src_node_id]  -- SimPy Store receiving messages sent by src
        out_ports[dst_node_id] -- SimPy Store holding messages bound for dst
    GraphEngine wires the ports at initialization; wire processes model the
    propagation delay between connected ports (ADR-0015 D2).

    Context (ADR-0015 D4):
        ctx -- ComponentContext carrying the router and resolver.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        self.node, self.ctx = node, ctx
        # Both port maps start empty; they are filled in from outside before
        # start() is called.
        self.in_ports: dict[str, simpy.Store] = {}
        self.out_ports: dict[str, simpy.Store] = {}

    def start(self, env: simpy.Environment) -> None:
        """Launch the component's processes; invoked once after port wiring.

        The default spawns one fan-in relay per input port plus a generic
        forwarding worker. The worker applies self.run() for the node's
        latency and then either routes the Transaction onwards or signals
        completion (duck-typed, so Transaction is not imported here and no
        circular dependency arises).

        A component without input ports starts nothing. Components needing
        custom fan-out / aggregation (e.g. MCpuComponent, IoCpuComponent for
        kernel launch) override the worker.
        """
        if self.in_ports:
            # Single queue that merges traffic from every input port.
            self._inbox: simpy.Store = simpy.Store(env)
            for source_port in self.in_ports.values():
                env.process(self._fan_in(source_port))
            env.process(self._worker(env))

    def _fan_in(self, port: simpy.Store) -> Generator:
        """Endlessly move messages from one input port into the shared inbox."""
        while True:
            item = yield port.get()
            yield self._inbox.put(item)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Generic worker: one independent _forward_txn process per message (pipeline)."""
        while True:
            incoming: Any = yield self._inbox.get()
            env.process(self._forward_txn(env, incoming))

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Spend run() latency, then hand off to the next hop or finish here."""
        yield from self.run(env, txn.nbytes)
        target = txn.next_hop  # duck-typed: Transaction.next_hop
        if not target:
            # Terminal node: wait out the optional drain time, then complete.
            tail_ns = getattr(txn, "drain_ns", 0.0)
            if tail_ns > 0:
                yield env.timeout(tail_ns)
            txn.done.succeed()
            return
        yield self.out_ports[target].put(txn.advance())

    @abstractmethod
    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """SimPy process yielding the event(s) that model this node's work.

        Implementations yield env.timeout(overhead_ns) or derive the latency
        dynamically. Invoked from _forward_txn and from subclass handlers.
        """
        ...
class PeEngineBase(ComponentBase):
    """Shared base for the engines inside a PE (PE_DMA, PE_GEMM, PE_MATH).

    What it adds on top of ComponentBase:
      - ``_pe_prefix``: node.id with its last dotted segment removed
        (e.g. "sip0.cube0.pe0").
      - A dual-message ``_worker`` that sends PeInternalTxn messages to
        ``handle_command()`` and everything else to the inherited
        ``_forward_txn()``.
      - ``init_resources(env)``: a hook run by ``start()`` before the worker
        is spawned.

    Subclass contract:
      1. Implement ``handle_command(env, pe_txn)`` to process a PeInternalTxn.
      2. Implement ``run(env, nbytes)`` to yield the component latency.
      3. Optionally override ``init_resources(env)`` (DMA channels, etc.).
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Everything before the final "." of the node id names the owning PE.
        prefix, *_ = node.id.rsplit(".", 1)
        self._pe_prefix: str = prefix

    def start(self, env: simpy.Environment) -> None:
        # Resources must exist before the first message can be handled.
        self.init_resources(env)
        super().start(env)

    def init_resources(self, env: simpy.Environment) -> None:
        """Subclass hook for resource setup; runs before the worker is spawned."""

    def _worker(self, env: simpy.Environment) -> Generator:
        """Route PeInternalTxn to handle_command and any other message to _forward_txn."""
        from kernbench.common.pe_commands import PeInternalTxn

        while True:
            incoming: Any = yield self._inbox.get()
            if isinstance(incoming, PeInternalTxn):
                handler = self.handle_command
            else:
                handler = self._forward_txn
            env.process(handler(env, incoming))

    @abstractmethod
    def handle_command(self, env: simpy.Environment, pe_txn: Any) -> Generator:
        """Handle one PE-internal command (PeInternalTxn).

        Implementations must:
          - do the engine-specific work (acquire resources, compute, etc.);
          - call ``pe_txn.done.succeed()`` once finished.
        """
        ...
class ComponentRegistry:
    """DI registry that maps node.impl strings onto ComponentBase subclasses.

    ComponentRegistry.create(node, overrides, ctx) resolves in this order:
      1. overrides[node.impl]  -- override injected by the caller
      2. _registry[node.impl]  -- globally registered implementation
      3. ValueError            -- there is no fallback; each node needs an impl
    """

    # Shared by the class itself, so registrations are process-wide.
    _registry: dict[str, type[ComponentBase]] = {}

    @classmethod
    def register(cls, impl: str, component_cls: type[ComponentBase]) -> None:
        """Bind *impl* to *component_cls*, replacing any earlier binding."""
        cls._registry[impl] = component_cls

    @classmethod
    def create(
        cls,
        node: Node,
        overrides: dict[str, type[ComponentBase]] | None = None,
        ctx: ComponentContext | None = None,
    ) -> ComponentBase:
        """Instantiate the component class selected for ``node.impl``.

        Raises:
            ValueError: when neither *overrides* nor the global registry
                knows ``node.impl``.
        """
        impl_key = node.impl
        if overrides and impl_key in overrides:
            factory = overrides[impl_key]
        elif impl_key in cls._registry:
            factory = cls._registry[impl_key]
        else:
            raise ValueError(
                f"No component registered for impl '{node.impl}' (node: {node.id}). "
                f"Register it in kernbench.components.impls.__init__."
            )
        return factory(node, ctx)
+52
View File
@@ -0,0 +1,52 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import simpy
from kernbench.policy.routing.router import AddressResolver, PathRouter
@dataclass
class ComponentContext:
    """Topology services handed to every component implementation.

    Needed by components that route or resolve addresses (IoCpuComponent,
    MCpuComponent, ...); TransitComponent ignores ctx. Supplied through
    ComponentRegistry.create(node, overrides, ctx=ctx).
    """

    router: PathRouter
    resolver: AddressResolver
    positions: dict[str, tuple[float, float] | None]  # node_id → pos_mm
    ns_per_mm: float  # wire propagation constant (from topology spec)
    edge_map: dict[tuple[str, str], Any] = field(default_factory=dict)
    spec: dict = field(default_factory=dict)  # topology spec (cube layout, PE count, etc.)

    def get_shared_resource(
        self, env: simpy.Environment, key: str, capacity: int = 1,
    ) -> simpy.Resource:
        """Fetch the shared SimPy Resource stored under *key*, creating it if absent.

        Intended for PE components that share one resource between engines
        of the same PE (e.g. accel_slot shared by PE_GEMM and PE_MATH). Keys
        should be scoped per PE, e.g. "sip0.cube0.pe0.accel_slot".

        *capacity* only takes effect on the call that creates the resource;
        later calls return the existing one unchanged.
        """
        try:
            pool = self._shared_resources
        except AttributeError:
            # First use on this context: the pool is created lazily.
            pool = {}
            self._shared_resources: dict[str, simpy.Resource] = pool
        if key not in pool:
            pool[key] = simpy.Resource(env, capacity=capacity)
        return pool[key]

    def compute_drain_ns(self, path: list[str], nbytes: int) -> float:
        """Wormhole drain time: nbytes divided by the smallest bw_gbs on *path*.

        Hops with no edge_map entry, or whose edge has no truthy ``bw_gbs``,
        are ignored. Returns 0.0 when no hop contributes a bandwidth.
        """
        unlimited = float("inf")
        bottleneck = unlimited
        for hop in zip(path, path[1:]):
            link = self.edge_map.get(hop)
            if not link:
                continue
            if not getattr(link, "bw_gbs", None):
                continue
            bottleneck = min(bottleneck, link.bw_gbs)
        return 0.0 if bottleneck == unlimited else nbytes / bottleneck
@@ -0,0 +1,54 @@
"""Concrete component implementations.
Each module registers its component(s) with ComponentRegistry on import.
Import this package to activate all built-in implementations.
"""
from kernbench.components.base import ComponentRegistry
from kernbench.components.impls.forwarding import TransitComponent
from kernbench.components.impls.hbm_ctrl import HbmCtrlComponent
from kernbench.components.impls.io_cpu import IoCpuComponent
from kernbench.components.impls.m_cpu import MCpuComponent
from kernbench.components.impls.noc import TwoDMeshNocComponent
from kernbench.components.impls.pcie_ep import PcieEpComponent
from kernbench.components.impls.pe_cpu import PeCpuComponent
from kernbench.components.impls.pe_dma import PeDmaComponent
from kernbench.components.impls.pe_gemm import PeGemmComponent
from kernbench.components.impls.pe_math import PeMathComponent
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
from kernbench.components.impls.pe_tcm import PeTcmComponent
from kernbench.components.impls.sram import SramComponent
# Bind each built-in impl string to its component class. These calls run as
# a side effect of importing this package.
# The forwarding, switch and plain NOC impl strings share TransitComponent.
ComponentRegistry.register("forwarding_v1", TransitComponent)
ComponentRegistry.register("switch_v1", TransitComponent)
ComponentRegistry.register("noc_v1", TransitComponent)
# The 2D-mesh NOC has a dedicated component.
ComponentRegistry.register("noc_2d_mesh_v1", TwoDMeshNocComponent)
# UCIe and XBAR also use TransitComponent.
ComponentRegistry.register("ucie_v1", TransitComponent)
ComponentRegistry.register("xbar_v1", TransitComponent)
# Host-side, controller and memory components.
ComponentRegistry.register("pcie_ep_v1", PcieEpComponent)
ComponentRegistry.register("io_cpu_v1", IoCpuComponent)
ComponentRegistry.register("m_cpu_v1", MCpuComponent)
ComponentRegistry.register("hbm_ctrl_v1", HbmCtrlComponent)
ComponentRegistry.register("sram_v1", SramComponent)
# PE-internal components.
ComponentRegistry.register("pe_cpu_v1", PeCpuComponent)
ComponentRegistry.register("pe_scheduler_v1", PeSchedulerComponent)
ComponentRegistry.register("pe_dma_v1", PeDmaComponent)
ComponentRegistry.register("pe_gemm_v1", PeGemmComponent)
ComponentRegistry.register("pe_math_v1", PeMathComponent)
ComponentRegistry.register("pe_tcm_v1", PeTcmComponent)
# Component classes re-exported by this package.
__all__ = [
"HbmCtrlComponent",
"IoCpuComponent",
"MCpuComponent",
"PcieEpComponent",
"PeCpuComponent",
"PeDmaComponent",
"PeGemmComponent",
"PeMathComponent",
"PeSchedulerComponent",
"PeTcmComponent",
"TransitComponent",
"TwoDMeshNocComponent",
"SramComponent",
]
@@ -0,0 +1,27 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING
import simpy
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class TransitComponent(ComponentBase):
    """Pass-through component used for NOC, UCIe and XBAR nodes.

    Its only cost is the ``overhead_ns`` delay read from node.attrs (0.0 when
    absent); forwarding to the next hop is done by the inherited
    _forward_txn().
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Yield a single timeout of ``overhead_ns``; *nbytes* is not used."""
        yield env.timeout(float(self.node.attrs.get("overhead_ns", 0.0)))
+101
View File
@@ -0,0 +1,101 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class HbmCtrlComponent(ComponentBase):
    """HBM controller: terminal component modelling HBM access latency.

    Two channels are kept, one simpy.Resource for reads and one for writes,
    so a read and a write can be serviced together while requests of the same
    kind queue on their own channel. Both resources take their capacity from
    node.attrs["capacity"] (1 when the attribute is absent).

    When a request finishes, a ResponseMsg is sent back along the reversed
    path so response latency is modelled through the fabric; PeDmaMsg probes
    are completed directly instead.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Channels are created in start(), once an environment is available.
        self._read: simpy.Resource | None = None
        self._write: simpy.Resource | None = None

    def start(self, env: simpy.Environment) -> None:
        """Create the read/write channels, then start the base processes."""
        slots = int(self.node.attrs.get("capacity", 1))
        self._read = simpy.Resource(env, capacity=slots)
        self._write = simpy.Resource(env, capacity=slots)
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Yield a single timeout of ``overhead_ns``; *nbytes* is not used."""
        yield env.timeout(float(self.node.attrs.get("overhead_ns", 0.0)))

    def _select_channel(self, txn: Any) -> simpy.Resource:
        """Pick the write channel for write requests and the read channel otherwise.

        A request counts as a write when it is a MemoryWriteMsg, or a
        PeDmaMsg whose ``is_write`` is truthy.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg

        assert self._read is not None and self._write is not None
        payload = txn.request
        wants_write = isinstance(payload, MemoryWriteMsg) or (
            isinstance(payload, PeDmaMsg) and payload.is_write
        )
        return self._write if wants_write else self._read

    def _worker(self, env: simpy.Environment) -> Generator:
        """Give each incoming txn its own process so the channels can overlap."""
        while True:
            incoming: Any = yield self._inbox.get()
            env.process(self._handle_txn(env, incoming))

    def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Hold the channel for run() plus drain time, then send the response."""
        lane = self._select_channel(txn)
        with lane.request() as grant:
            yield grant
            yield from self.run(env, txn.nbytes)
            tail_ns = getattr(txn, "drain_ns", 0.0)
            if tail_ns > 0:
                yield env.timeout(tail_ns)
        yield from self._send_response(env, txn)

    def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Build a ResponseMsg and send it back along the reversed path.

        PeDmaMsg is a direct probe with no IO_CPU/M_CPU aggregation in the
        path, so txn.done is succeeded directly rather than sending a
        response Transaction. The same happens when the path has fewer than
        two nodes or no context is attached.
        """
        from kernbench.runtime_api.kernel import PeDmaMsg

        if isinstance(txn.request, PeDmaMsg):
            txn.done.succeed()
            return

        back_path = list(reversed(txn.path))
        if len(back_path) < 2 or not self.ctx:
            txn.done.succeed()
            return

        from kernbench.runtime_api.kernel import ResponseMsg

        # The node id is read as dotted fields: field 1 carries "cube<N>",
        # field 3 carries "slice<N>".
        id_fields = self.node.id.split(".")
        cube_id = int(id_fields[1].replace("cube", ""))
        pe_id = int(id_fields[3].replace("slice", ""))
        reply = ResponseMsg(
            correlation_id=txn.request.correlation_id,
            request_id=txn.request.request_id,
            src_cube=cube_id, src_pe=pe_id, success=True,
        )
        reply_txn = Transaction(
            request=reply, path=back_path, step=0,
            nbytes=0, done=env.event(), is_response=True,
        )
        yield self.out_ports[back_path[1]].put(reply_txn.advance())
+145
View File
@@ -0,0 +1,145 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class IoCpuComponent(ComponentBase):
    """IO_CPU component: multi-cube fan-out with response aggregation.

    Forward path:
      1. Applies the overhead_ns processing overhead.
      2. Resolves the target cube(s) from request.target_cubes.
      3. Fans out one sub-Transaction to each target cube's M_CPU.

    Response path:
      Collects a ResponseMsg from each M_CPU. Once every dispatched
      sub-Transaction has been answered, the parent txn.done is succeeded.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Pending fan-out tracking: request_id → (expected, received, parent_txn_done)
        self._pending: dict[str, tuple[int, int, simpy.Event]] = {}

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Yield a single timeout of ``overhead_ns``; *nbytes* is not used."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Count responses inline; pay run() overhead then fan out forward txns."""
        while True:
            txn: Any = yield self._inbox.get()
            if getattr(txn, "is_response", False):
                self._collect_response(txn)
            else:
                # The overhead is paid inside the worker loop, so forward
                # requests are processed one after another here.
                yield from self.run(env, txn.nbytes)
                env.process(self._dispatch_to_m_cpus(env, txn))

    def _collect_response(self, resp_txn: Any) -> None:
        """Receive a cube response and increment the aggregation counter.

        Responses whose request_id is not pending are ignored. When the
        count reaches the expected number, the parent done event is
        succeeded and the entry is removed.
        """
        key = resp_txn.request.request_id
        if key not in self._pending:
            return
        expected, received, parent_done = self._pending[key]
        received += 1
        if received >= expected:
            parent_done.succeed()
            del self._pending[key]
        else:
            self._pending[key] = (expected, received, parent_done)

    def _dispatch_to_m_cpus(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out sub-Transactions to target cube M_CPUs.

        Routes are resolved up front so that the aggregation entry expects
        exactly as many responses as sub-Transactions are actually sent. If
        it were sized from the raw target list, any unroutable target would
        leave the counter short forever and txn.done would never fire.

        The parent txn.done is succeeded immediately when target resolution
        fails, when there are no targets, or when no target is routable.
        """
        request = txn.request
        try:
            cube_targets = self._resolve_cube_targets(request)
        except Exception:
            # Best-effort: an unresolvable request is completed, not raised.
            txn.done.succeed()
            return

        # Collect every usable path before any message is sent.
        routes: list[list[str]] = []
        for sip, cube in cube_targets:
            try:
                m_cpu_id = self.ctx.resolver.find_m_cpu(sip, cube)
                path = self.ctx.router.find_node_path(self.node.id, m_cpu_id)
            except Exception:
                # Best-effort: skip targets that cannot be resolved or routed.
                continue
            if len(path) < 2:
                continue
            routes.append(path)

        if not routes:
            txn.done.succeed()
            return

        # Register the aggregation before the first put so that a response
        # arriving while later puts are still pending is counted.
        self._pending[request.request_id] = (len(routes), 0, txn.done)

        for path in routes:
            sub_txn = Transaction(
                request=request, path=path, step=0,
                nbytes=txn.nbytes, done=env.event(),
                result_data=txn.result_data,
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())

    def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]:
        """Return the list of (sip, cube) pairs to fan out to.

        - MemoryWriteMsg / MemoryReadMsg: with target_cubes == "all" the
          single cube is decoded from the physical address (falling back to
          the request's cube field); otherwise one pair per listed cube.
        - KernelLaunchMsg: explicit target_cubes on this IO_CPU's SIP, or for
          "all" the distinct (sip, cube) of tensor-argument shards that live
          on this SIP, in first-seen order.
        - Any other request type: an empty list.
        """
        from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg

        target_cubes = getattr(request, "target_cubes", "all")

        if isinstance(request, MemoryWriteMsg):
            sip = request.dst_sip
            if target_cubes == "all":
                cube = self._cube_from_pa(request.dst_pa, fallback=request.dst_cube)
                return [(sip, cube)]
            return [(sip, c) for c in target_cubes]

        if isinstance(request, MemoryReadMsg):
            sip = request.src_sip
            if target_cubes == "all":
                cube = self._cube_from_pa(request.src_pa, fallback=request.src_cube)
                return [(sip, cube)]
            return [(sip, c) for c in target_cubes]

        if isinstance(request, KernelLaunchMsg):
            my_sip = self._my_sip()
            if target_cubes != "all":
                return [(my_sip, c) for c in target_cubes]
            # "all": derive from tensor shards, filtered to this SIP
            seen: set[tuple[int, int]] = set()
            targets: list[tuple[int, int]] = []
            for arg in request.args:
                if arg.arg_kind != "tensor":
                    continue
                for shard in arg.shards:
                    if shard.sip != my_sip:
                        continue
                    key = (shard.sip, shard.cube)
                    if key not in seen:
                        seen.add(key)
                        targets.append(key)
            return targets

        return []

    def _cube_from_pa(self, pa_val: int, fallback: int) -> int:
        """Extract cube_id from a physical address; return *fallback* if decoding fails."""
        from kernbench.policy.address.phyaddr import PhysAddr

        try:
            return PhysAddr.decode(pa_val).cube_id
        except Exception:
            return fallback

    def _my_sip(self) -> int:
        """Extract this IO_CPU's SIP ID from its node ID (e.g. 'sip0.io0.io_cpu' → 0)."""
        return int(self.node.id.split(".")[0].replace("sip", ""))
+269
View File
@@ -0,0 +1,269 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class MCpuComponent(ComponentBase):
    """M_CPU component: multi-PE DMA fan-out with response aggregation.

    Forward path (ADR-0015 D5):
        When a forward Transaction arrives at m_cpu (terminal hop), M_CPU
        fans out DMA sub-Transactions to the target PEs' HBM slices.
        target_pe on the request controls fan-out: int → single PE,
        "all" → all PEs in the cube.

    Response path:
        A ResponseMsg from each hbm_ctrl arrives back at m_cpu. Once all
        responses are collected, m_cpu sends an aggregate ResponseMsg on the
        reverse command path back to io_cpu.

    Transit:
        When m_cpu is NOT the terminal hop, the Transaction is forwarded
        normally to the next hop.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Pending fan-out tracking: request_id → (expected, received, all_done_event)
        self._pending: dict[str, tuple[int, int, simpy.Event]] = {}
        # Parent txn of each in-flight DMA fan-out: request_id → parent_txn
        self._parent_txns: dict[str, Any] = {}
        # DMA engine resources (ADR-0015 D5, ADR-0014 D4): capacity=1 each
        self._dma_write: simpy.Resource | None = None
        self._dma_read: simpy.Resource | None = None

    def start(self, env: simpy.Environment) -> None:
        """Create the DMA read/write resources, then start the base processes."""
        self._dma_write = simpy.Resource(env, capacity=1)
        self._dma_read = simpy.Resource(env, capacity=1)
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Yield a single timeout of ``overhead_ns``; *nbytes* is not used."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch forward txns, collect response txns."""
        from kernbench.runtime_api.kernel import KernelLaunchMsg

        while True:
            txn: Any = yield self._inbox.get()
            if getattr(txn, "is_response", False):
                self._collect_response(txn)
                continue

            yield from self.run(env, txn.nbytes)
            next_hop = txn.next_hop
            if next_hop:
                # Transit: m_cpu is not the last node on the path.
                yield self.out_ports[next_hop].put(txn.advance())
            elif self.ctx is not None and txn.request is not None:
                # Terminal hop: fan out in a separate process so the worker
                # keeps draining the inbox (responses arrive through it).
                if isinstance(txn.request, KernelLaunchMsg):
                    env.process(self._kernel_launch_fanout(env, txn))
                else:
                    env.process(self._dma_fanout(env, txn))
            else:
                txn.done.succeed()

    def _collect_response(self, resp_txn: Any) -> None:
        """Receive a PE response and increment the aggregation counter.

        Responses whose request_id is not pending are ignored. When the
        count reaches the expected number, the all-done event is succeeded
        and the entry is removed.
        """
        key = resp_txn.request.request_id
        if key not in self._pending:
            return
        expected, received, all_done = self._pending[key]
        received += 1
        if received >= expected:
            all_done.succeed()
            del self._pending[key]
        else:
            self._pending[key] = (expected, received, all_done)

    def _dma_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out DMA sub-Transactions to target PE(s), wait for responses,
        then send the aggregate response on the reverse command path.

        Each DMA transfer acquires the DMA resource (capacity=1 per
        ADR-0014 D4), so multi-PE fan-out is serialized through the DMA
        engine.

        Routes are resolved before the aggregation entry is registered, so
        the expected response count equals the number of sub-Transactions
        actually sent. Sizing it from the raw destination list would leave
        the counter short whenever a destination is unroutable, and this
        process would wait on all_done forever. If nothing is routable,
        txn.done is succeeded immediately, as for an empty destination list.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg

        request = txn.request
        target_pe = getattr(request, "target_pe", "all")
        dst_nodes = self._resolve_dma_destinations(request, target_pe)

        # (path, drain_ns) for every destination that can be reached.
        routes: list[tuple[list[str], float]] = []
        max_drain_ns = 0.0
        for dst_node in dst_nodes:
            try:
                dma_path = self.ctx.router.find_mcpu_dma_path(self.node.id, dst_node)
            except Exception:
                # Best-effort: skip destinations the router cannot reach.
                continue
            if len(dma_path) < 2:
                continue
            drain_ns = self.ctx.compute_drain_ns(dma_path, txn.nbytes)
            max_drain_ns = max(max_drain_ns, drain_ns)
            routes.append((dma_path, drain_ns))

        if not routes:
            txn.done.succeed()
            return

        # Setup aggregation before the first put so early responses count.
        all_done = env.event()
        self._pending[request.request_id] = (len(routes), 0, all_done)
        self._parent_txns[request.request_id] = txn

        # Select DMA resource based on operation type
        dma_res = self._dma_write if isinstance(request, MemoryWriteMsg) else self._dma_read

        # Fan out DMA sub-txns (serialized through DMA resource)
        for dma_path, drain_ns in routes:
            sub_txn = Transaction(
                request=request, path=dma_path, step=0,
                nbytes=txn.nbytes, done=env.event(),
                drain_ns=drain_ns,
            )
            with dma_res.request() as req:
                yield req
                yield self.out_ports[dma_path[1]].put(sub_txn.advance())

        # Wait for all PE responses
        yield all_done
        txn.result_data["xfer_ns"] = max_drain_ns
        del self._parent_txns[request.request_id]

        yield from self._send_aggregate_response(env, txn)

    def _kernel_launch_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out KernelLaunchMsg to target PE_CPU(s) via NOC (ADR-0009 D3).

        Routes through find_node_path (M_CPU → NOC → PE_CPU command edges).
        Waits for each sub-Transaction's done event directly; no ResponseMsg
        is needed for the PE direction. Then sends the aggregate ResponseMsg
        back to IO_CPU on the reverse path.
        """
        request = txn.request
        target_pe = getattr(request, "target_pe", "all")
        cube_prefix = self.node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0"

        pe_ids = self._resolve_pe_ids(target_pe)
        if not pe_ids:
            txn.done.succeed()
            return

        # Fan out to each PE_CPU and collect done events
        sub_dones: list[simpy.Event] = []
        sub_txns: list[Transaction] = []
        for pe_id in pe_ids:
            pe_cpu_id = f"{cube_prefix}.pe{pe_id}.pe_cpu"
            try:
                path = self.ctx.router.find_node_path(self.node.id, pe_cpu_id)
            except Exception:
                # Best-effort: skip PEs the router cannot reach.
                continue
            if len(path) < 2:
                continue
            sub_done = env.event()
            sub_txn = Transaction(
                request=request, path=path, step=0,
                nbytes=0, done=sub_done,
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())
            sub_dones.append(sub_done)
            sub_txns.append(sub_txn)

        if not sub_dones:
            txn.done.succeed()
            return

        # Wait for all PE_CPUs to complete
        for sub_done in sub_dones:
            yield sub_done

        # Aggregate PE-internal metrics (max across PEs); a PE that did not
        # report a metric counts as 0.0.
        for metric in ("pe_exec_ns", "dma_ns", "compute_ns"):
            txn.result_data[metric] = max(
                st.result_data.get(metric, 0.0) for st in sub_txns
            )

        yield from self._send_aggregate_response(env, txn)

    def _send_aggregate_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Send one aggregate ResponseMsg for *txn* along its reversed path.

        When the reversed path has fewer than two nodes there is nowhere to
        send a response, so txn.done is succeeded instead.
        """
        reverse_path = list(reversed(txn.path))
        if len(reverse_path) < 2:
            txn.done.succeed()
            return

        from kernbench.runtime_api.kernel import ResponseMsg

        request = txn.request
        # Field 1 of the dotted node id carries "cube<N>".
        cube_id = int(self.node.id.split(".")[1].replace("cube", ""))
        resp_msg = ResponseMsg(
            correlation_id=request.correlation_id,
            request_id=request.request_id,
            src_cube=cube_id, src_pe=-1, success=True,
        )
        resp_txn = Transaction(
            request=resp_msg, path=reverse_path, step=0,
            nbytes=0, done=env.event(), is_response=True,
        )
        yield self.out_ports[reverse_path[1]].put(resp_txn.advance())

    def _slices_per_cube(self) -> int:
        """Read cube.memory_map.hbm_slices_per_cube from the spec (default 8)."""
        n_slices = 8
        if self.ctx and self.ctx.spec:
            mm = self.ctx.spec.get("cube", {}).get("memory_map", {})
            n_slices = mm.get("hbm_slices_per_cube", 8)
        return n_slices

    def _resolve_dma_destinations(self, request: Any, target_pe: int | str) -> list[str]:
        """Return the list of HBM destination node_ids for DMA fan-out.

        An int target_pe selects that slice in the local cube. Otherwise the
        request's physical address is decoded and resolved, which enables
        cross-cube DMA routing when the PA points to a remote cube. If no PA
        is present or it cannot be resolved, all slices of the local cube
        are returned.
        """
        cube_prefix = self.node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0"

        if isinstance(target_pe, int):
            return [f"{cube_prefix}.hbm_ctrl.slice{target_pe}"]

        # PA-based resolution: extract actual target from physical address.
        # NOTE(review): `or` treats a dst_pa of 0 as missing and falls back
        # to src_pa; confirm that 0 is never a valid destination address.
        pa_val = getattr(request, "dst_pa", None) or getattr(request, "src_pa", None)
        if pa_val is not None:
            from kernbench.policy.address.phyaddr import PhysAddr

            try:
                pa = PhysAddr.decode(pa_val)
                return [self.ctx.resolver.resolve(pa)]
            except Exception:
                # Best-effort: fall through to the local-cube default.
                pass

        # "all" without PA (KernelLaunch): all slices in local cube
        return [f"{cube_prefix}.hbm_ctrl.slice{i}" for i in range(self._slices_per_cube())]

    def _resolve_pe_ids(self, target_pe: int | str) -> list[int]:
        """Return the list of PE IDs to fan out to (used by kernel launch fan-out).

        An int selects that single PE; anything else selects one PE per HBM
        slice of the local cube.
        """
        if isinstance(target_pe, int):
            return [target_pe]
        # "all": all PEs in local cube
        return list(range(self._slices_per_cube()))
+187
View File
@@ -0,0 +1,187 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class TwoDMeshNocComponent(ComponentBase):
"""2D mesh NOC modeled as a single smart node.
Latency model:
- Traversal latency = Manhattan distance between prev_hop and next_hop
node positions, split into XY segments, traversed with pipeline.
- overhead_ns (from node.attrs) is added once per traversal.
Contention model:
- Each directed XY segment is a simpy.Resource(capacity=1).
- Pipeline: next segment's resource is requested before the current
segment's timeout completes, so a free downstream segment is acquired
immediately (wormhole-style cut-through).
- Two transactions sharing a segment (same row or column band) contend.
Concurrency:
- _worker spawns an independent SimPy process per transaction, so the
NOC is never serialized at the node level — only at segment resources.
"""
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
super().__init__(node, ctx)
self._env: simpy.Environment | None = None
self._links: dict[tuple, simpy.Resource] = {}
self._x_grid: list[float] = []
self._y_grid: list[float] = []
def start(self, env: simpy.Environment) -> None:
self._env = env
self._build_grid()
super().start(env)
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
yield env.timeout(0)
# ── Grid construction ────────────────────────────────────────────
def _build_grid(self) -> None:
if not self.ctx:
return
cube_prefix = self.node.id.rsplit(".", 1)[0]
xs: set[float] = set()
ys: set[float] = set()
for node_id, pos in self.ctx.positions.items():
if node_id.startswith(cube_prefix + ".") and pos is not None:
xs.add(round(pos[0], 2))
ys.add(round(pos[1], 2))
self._x_grid = sorted(xs)
self._y_grid = sorted(ys)
def _get_link(self, key: tuple) -> simpy.Resource:
if key not in self._links:
assert self._env is not None
self._links[key] = simpy.Resource(self._env, capacity=1)
return self._links[key]
# ── Worker ───────────────────────────────────────────────────────
def _worker(self, env: simpy.Environment) -> Generator:
while True:
txn: Any = yield self._inbox.get()
env.process(self._route(env, txn))
def _route(self, env: simpy.Environment, txn: Any) -> Generator:
prev_hop = txn.path[txn.step - 1] if txn.step > 0 else None
next_hop = txn.next_hop
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
links: list[tuple[tuple, float]] = []
if prev_hop and next_hop and self.ctx:
src_pos = self.ctx.positions.get(prev_hop)
dst_pos = self.ctx.positions.get(next_hop)
if src_pos and dst_pos:
links = self._xy_links(src_pos, dst_pos)
if links:
yield from self._traverse(env, links, overhead_ns)
else:
yield env.timeout(overhead_ns)
if next_hop:
yield self.out_ports[next_hop].put(txn.advance())
else:
drain = getattr(txn, "drain_ns", 0.0)
if drain > 0:
yield env.timeout(drain)
txn.done.succeed()
# ── XY routing and pipelined link traversal ──────────────────────
def _traverse(
self,
env: simpy.Environment,
links: list[tuple[tuple, float]],
overhead_ns: float,
) -> Generator:
"""Pipeline: request next segment before current timeout finishes."""
ns_per_mm = self.ctx.ns_per_mm # type: ignore[union-attr]
# Acquire first link
first_key, _ = links[0]
current_resource = self._get_link(first_key)
current_req = current_resource.request()
yield current_req
for i, (_, dist_mm) in enumerate(links):
# Request next link before current timeout (pipeline)
if i + 1 < len(links):
next_key, _ = links[i + 1]
next_resource = self._get_link(next_key)
next_req = next_resource.request()
yield env.timeout(dist_mm * ns_per_mm + (overhead_ns if i == 0 else 0.0))
current_resource.release(current_req)
if i + 1 < len(links):
yield next_req # usually already fulfilled (pipeline)
current_resource = next_resource
current_req = next_req
def _xy_links(
self,
src: tuple[float, float],
dst: tuple[float, float],
) -> list[tuple[tuple, float]]:
"""XY routing: horizontal segment first, then vertical.
Returns list of (link_key, dist_mm) pairs, where link_key uniquely
identifies a directed segment shared across concurrent transactions.
"""
x0, y0 = src
x1, y1 = dst
links: list[tuple[tuple, float]] = []
# Horizontal segment at y≈y0
if abs(x0 - x1) > 1e-9:
y_band = self._snap(y0, self._y_grid)
for xa, xb in self._segments(x0, x1, self._x_grid):
d = abs(xb - xa)
if d > 1e-9:
lo, hi = (xa, xb) if xa < xb else (xb, xa)
dir_h = "E" if xb > xa else "W"
links.append((("H", round(y_band, 2), round(lo, 2), round(hi, 2), dir_h), d))
# Vertical segment at x≈x1
if abs(y0 - y1) > 1e-9:
x_band = self._snap(x1, self._x_grid)
for ya, yb in self._segments(y0, y1, self._y_grid):
d = abs(yb - ya)
if d > 1e-9:
lo, hi = (ya, yb) if ya < yb else (yb, ya)
dir_v = "S" if yb > ya else "N"
links.append((("V", round(x_band, 2), round(lo, 2), round(hi, 2), dir_v), d))
return links
@staticmethod
def _snap(val: float, grid: list[float]) -> float:
    """Return the entry of ``grid`` closest to ``val``.

    An empty grid returns ``val`` unchanged; on a tie the entry that
    appears first in ``grid`` wins.
    """
    nearest = val
    best_gap = None
    for candidate in grid:
        gap = abs(candidate - val)
        # Strict comparison keeps the earliest candidate on equal distance.
        if best_gap is None or gap < best_gap:
            nearest, best_gap = candidate, gap
    return nearest
@staticmethod
def _segments(a: float, b: float, grid: list[float]) -> list[tuple[float, float]]:
    """Split the span from ``a`` to ``b`` at the grid waypoints strictly inside it.

    Args:
        a: Start coordinate.
        b: End coordinate.
        grid: Candidate waypoints, in any order. Only values strictly
            between ``a`` and ``b`` (with a 1e-9 margin) are used.

    Returns:
        Consecutive ``(p_i, p_{i+1})`` pairs ordered and oriented in the
        direction of travel from ``a`` to ``b``; an empty list when the two
        endpoints are closer than 1e-9.
    """
    if abs(a - b) < 1e-9:
        return []
    lo, hi = (a, b) if a < b else (b, a)
    # Sort the interior waypoints so the result never back-tracks, even if
    # the caller's grid is not stored in ascending order.
    interior = sorted(g for g in grid if lo + 1e-9 < g < hi - 1e-9)
    pts = [lo, *interior, hi]
    pairs = list(zip(pts, pts[1:]))
    if a > b:
        # Travelling towards smaller coordinates: reverse order and orientation.
        pairs = [(end, start) for start, end in reversed(pairs)]
    return pairs
+27
View File
@@ -0,0 +1,27 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING
import simpy
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PcieEpComponent(ComponentBase):
    """PCIe endpoint component.

    Charges the node's ``overhead_ns`` attribute for PCIe protocol handling;
    forwarding is left to the inherited ``_forward_txn()``.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for ``overhead_ns`` (0.0 when the attribute is absent).

        ``nbytes`` is accepted for interface compatibility and not used.
        """
        node_attrs = self.node.attrs
        delay_ns = float(node_attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)
+154
View File
@@ -0,0 +1,154 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeCpuComponent(ComponentBase):
    """PE_CPU: kernel execution controller (Stage 2).

    Two-phase kernel execution (ADR-0014 D1):
      Phase 1 (compile): look up kernel from registry, run it with TLContext
        to generate a PeCommand list.
      Phase 2 (replay): iterate commands, dispatch to PE_SCHEDULER via
        PeInternalTxn, wait for blocking commands.

    Non-kernel Transactions are forwarded normally.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Node id minus its last dotted component, e.g. "sip0.cube0.pe0".
        self._pe_prefix = node.id.rsplit(".", 1)[0]
        # Each index below falls back to 0 when the id does not follow the
        # "sip<N>.cube<M>.pe<K>.<name>" pattern.
        try:
            self._pe_idx = int(self._pe_prefix.rsplit("pe", 1)[1])
        except (IndexError, ValueError):
            self._pe_idx = 0

        # Extract sip/cube index for multi-SIP/cube shard matching.
        parts = node.id.split(".")
        try:
            self._sip_idx = int(parts[0].replace("sip", ""))
        except (IndexError, ValueError):
            self._sip_idx = 0
        try:
            self._cube_idx = int(parts[1].replace("cube", ""))
        except (IndexError, ValueError):
            self._cube_idx = 0

    def _find_shard(self, shards: tuple) -> Any:
        """Return the shard addressed to this PE.

        A shard whose ``(sip, cube, pe)`` equals this PE's indices is
        preferred; otherwise the shard at position ``pe`` is used, clamped to
        the last one.

        Raises:
            IndexError: ``shards`` is empty.
        """
        if not shards:
            # Same exception type the bare index lookup used to raise, but
            # with a message that names the component.
            raise IndexError(f"{self.node.id}: tensor argument has no shards")
        for s in shards:
            if s.sip == self._sip_idx and s.cube == self._cube_idx and s.pe == self._pe_idx:
                return s
        return shards[min(self._pe_idx, len(shards) - 1)]

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` attribute (default 0.0)."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Serve the inbox forever: execute kernel launches, forward the rest."""
        # Imported inside the function (as in the original), but once per
        # worker instead of once per message.
        from kernbench.runtime_api.kernel import KernelLaunchMsg

        while True:
            txn: Any = yield self._inbox.get()
            if hasattr(txn, "request") and isinstance(txn.request, KernelLaunchMsg):
                yield from self._execute_kernel(env, txn)
            else:
                yield from self._forward_txn(env, txn)

    def _execute_kernel(self, env: simpy.Environment, txn: Any) -> Generator:
        """Compile the kernel function and replay its command trace.

        On completion writes ``pe_exec_ns``, ``dma_ns`` and ``compute_ns``
        into ``txn.result_data`` and triggers ``txn.done``.
        """
        from kernbench.common.pe_commands import (
            CompositeCmd,
            PeCpuOverheadCmd,
            PeInternalTxn,
            WaitCmd,
        )
        from kernbench.triton_emu.registry import get_kernel
        from kernbench.triton_emu.tl_context import TLContext, run_kernel

        request = txn.request

        # Phase 1: Compile — apply PE_CPU setup overhead, then run kernel.
        yield from self.run(env, 0)

        kernel_fn = get_kernel(request.kernel_ref.name)
        tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)

        # Unpack KernelLaunchMsg.args into positional args for the kernel:
        # TensorArg -> PA of this PE's shard, ScalarArg -> value.
        # Arguments of any other kind are skipped.
        kernel_args: list = []
        for arg in request.args:
            if arg.arg_kind == "tensor":
                shard = self._find_shard(arg.shards)
                kernel_args.append(shard.pa)
            elif arg.arg_kind == "scalar":
                kernel_args.append(arg.value)

        run_kernel(kernel_fn, tl, *kernel_args)
        commands = tl.commands

        # Phase 2: Replay — dispatch commands to PE_SCHEDULER.
        pe_exec_start = env.now
        scheduler_id = f"{self._pe_prefix}.pe_scheduler"
        pending: dict[str, simpy.Event] = {}  # completion_id -> done event
        # result_data dicts of the CompositeCmd transactions; they are read
        # after the replay, once the scheduler has filled them in.
        composite_results: list[dict] = []

        for cmd in commands:
            if isinstance(cmd, PeCpuOverheadCmd):
                # cmd.cycles is passed to the timeout unconverted.
                yield env.timeout(cmd.cycles)

            elif isinstance(cmd, WaitCmd):
                if cmd.handle is not None:
                    # Wait for one completion; unknown handles are ignored.
                    evt = pending.pop(cmd.handle.id, None)
                    if evt is not None:
                        yield evt
                else:
                    # Wait for all pending completions.
                    for evt in pending.values():
                        yield evt
                    pending.clear()

            elif isinstance(cmd, CompositeCmd):
                # Non-blocking: dispatch to scheduler, track completion.
                done_evt = env.event()
                pe_txn = PeInternalTxn(
                    command=cmd, done=done_evt,
                    pe_prefix=self._pe_prefix,
                )
                composite_results.append(pe_txn.result_data)
                yield self.out_ports[scheduler_id].put(pe_txn)
                pending[cmd.completion.id] = done_evt

            else:
                # Blocking: dispatch and wait for completion.
                done_evt = env.event()
                pe_txn = PeInternalTxn(
                    command=cmd, done=done_evt,
                    pe_prefix=self._pe_prefix,
                )
                yield self.out_ports[scheduler_id].put(pe_txn)
                yield done_evt

        # Wait for any remaining pending completions.
        for evt in pending.values():
            yield evt

        # Record PE-internal execution time.
        txn.result_data["pe_exec_ns"] = env.now - pe_exec_start

        # Aggregate dma_ns / compute_ns from CompositeCmd results; a result
        # that lacks a key contributes 0.0.
        total_dma_ns = 0.0
        total_compute_ns = 0.0
        for rd in composite_results:
            total_dma_ns += rd.get("dma_ns", 0.0)
            total_compute_ns += rd.get("compute_ns", 0.0)
        txn.result_data["dma_ns"] = total_dma_ns
        txn.result_data["compute_ns"] = total_compute_ns

        # Signal original Transaction done.
        txn.done.succeed()
+116
View File
@@ -0,0 +1,116 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import PeEngineBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeDmaComponent(PeEngineBase):
    """PE_DMA: dual-channel DMA engine with READ and WRITE resources.

    Each channel has capacity=1 (ADR-0014 D4):
      - DMA_READ and DMA_WRITE may execute concurrently.
      - Multiple READs cannot overlap; multiple WRITEs cannot overlap.

    Handles two message types:
      - Transaction: external fabric messages (PeDmaMsg probes, M_CPU DMA)
      - PeInternalTxn: PE-internal commands from PE_SCHEDULER
        (DmaReadCmd -> HBM read, DmaWriteCmd -> HBM write)
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Created in init_resources(), once an environment is available.
        self._dma_read: simpy.Resource | None = None
        self._dma_write: simpy.Resource | None = None

    def init_resources(self, env: simpy.Environment) -> None:
        """Create the two single-slot channel resources."""
        self._dma_read = simpy.Resource(env, capacity=1)
        self._dma_write = simpy.Resource(env, capacity=1)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Zero-delay placeholder; this method adds no time of its own."""
        yield env.timeout(0)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Handle a PE-internal DMA command: resolve PA -> path -> transfer.

        A command that is neither ``DmaReadCmd`` nor ``DmaWriteCmd`` is
        completed immediately. Otherwise the matching channel is held while
        a sub-Transaction is issued to the next hop; ``pe_txn.done`` fires
        once that sub-Transaction completes.
        """
        from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
        from kernbench.policy.address.phyaddr import PhysAddr
        from kernbench.runtime_api.kernel import PeDmaMsg

        cmd = pe_txn.command
        assert self._dma_read is not None and self._dma_write is not None

        # Determine direction and target PA.
        if isinstance(cmd, DmaReadCmd):
            dma_res = self._dma_read
            target_pa = cmd.src_pa
            is_write = False
        elif isinstance(cmd, DmaWriteCmd):
            dma_res = self._dma_write
            target_pa = cmd.dst_pa
            is_write = True
        else:
            pe_txn.done.succeed()
            return

        # Resolve PA -> destination node and compute path.
        pa = PhysAddr.decode(target_pa)
        dst_node = self.ctx.resolver.resolve(pa)
        path = self.ctx.router.find_path(self._pe_prefix, dst_node)
        drain_ns = self.ctx.compute_drain_ns(path, cmd.nbytes)

        sub_done: simpy.Event | None = None

        # Acquire DMA channel (command issue serialization).
        with dma_res.request() as req:
            yield req

            if len(path) > 1:
                # Create sub-Transaction with PeDmaMsg (HbmCtrl handles it directly).
                sub_done = env.event()
                sub_request = PeDmaMsg(
                    correlation_id="pe_internal",
                    request_id=f"dma_{id(pe_txn)}",
                    src_sip=0, src_cube=0, src_pe=0,
                    dst_pa=target_pa, nbytes=cmd.nbytes,
                    is_write=is_write,
                )
                sub_txn = Transaction(
                    request=sub_request, path=path, step=0,
                    nbytes=cmd.nbytes, done=sub_done, drain_ns=drain_ns,
                )
                # Send to next hop (path[0] is pe_dma itself, path[1] is xbar).
                yield self.out_ports[path[1]].put(sub_txn.advance())
            elif drain_ns > 0:
                # No next hop: nothing will ever trigger a sub-Transaction's
                # done event, so finish here the way _forward_txn() treats a
                # terminal hop instead of waiting forever.
                yield env.timeout(drain_ns)
        # DMA channel released after issue.

        # Wait for the transfer to complete (only if one was issued).
        if sub_done is not None:
            yield sub_done
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Handle external Transaction (PeDmaMsg probe, M_CPU DMA) with channel acquisition.

        While holding the selected channel, either forwards the transaction
        to its next hop or, at the end of the path, waits ``drain_ns`` and
        triggers ``txn.done``.
        """
        dma_res = self._select_channel(txn)
        with dma_res.request() as req:
            yield req
            next_hop = txn.next_hop
            if next_hop:
                yield self.out_ports[next_hop].put(txn.advance())
            else:
                drain = getattr(txn, "drain_ns", 0.0)
                if drain > 0:
                    yield env.timeout(drain)
                txn.done.succeed()

    def _select_channel(self, txn: Any) -> simpy.Resource:
        """Return the WRITE channel for ``MemoryWriteMsg`` requests, else READ."""
        from kernbench.runtime_api.kernel import MemoryWriteMsg

        assert self._dma_read is not None and self._dma_write is not None
        if isinstance(txn.request, MemoryWriteMsg):
            return self._dma_write
        return self._dma_read
+90
View File
@@ -0,0 +1,90 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import PeEngineBase
if TYPE_CHECKING:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
# Bit width of each recognised dtype name; used to scale the f16 peak TFLOPS
# to other element widths.
_DTYPE_BITS: dict[str, int] = {
    "f16": 16,
    "fp16": 16,
    "float16": 16,
    "bf16": 16,
    "f32": 32,
    "fp32": 32,
    "float32": 32,
    "i8": 8,
    "int8": 8,
    "i16": 16,
    "int16": 16,
    "i32": 32,
    "int32": 32,
}
class PeGemmComponent(PeEngineBase):
    """PE_GEMM: matrix multiplication engine sharing accel_slot (ADR-0014 D4).

    Uses a shared compute resource (PE_ACCEL capacity=1) that is mutually
    exclusive with PE_MATH within the same PE.

    Compute latency model:
      FLOPs = 2 * M * K * N
      effective_tflops = peak_tflops_f16 * (16 / dtype_bits)
      compute_ns = FLOPs / (effective_tflops * 1e3)
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Shared slot; stays None unless init_resources() finds one configured.
        self._accel: simpy.Resource | None = None
        self._peak_tflops_f16: float = float(node.attrs.get("peak_tflops_f16", 0.0))

    def init_resources(self, env: simpy.Environment) -> None:
        """Look up the shared resource named by the ``shared_resource`` attribute."""
        slot_name = self.node.attrs.get("shared_resource")
        if not (slot_name and self.ctx):
            return
        self._accel = self.ctx.get_shared_resource(
            env, f"{self._pe_prefix}.{slot_name}"
        )

    def _compute_ns(self, m: int, k: int, n: int, dtype: str) -> float:
        """Return the modelled GEMM latency in nanoseconds.

        Falls back to the node's ``overhead_ns`` when no positive peak
        throughput is configured. Unknown dtypes are treated as 16-bit.
        """
        peak = self._peak_tflops_f16
        if peak <= 0:
            return float(self.node.attrs.get("overhead_ns", 0.0))
        bits = _DTYPE_BITS.get(dtype, 16)
        scaled_tflops = peak * (16.0 / bits)
        total_flops = 2.0 * m * k * n
        return total_flops / (scaled_tflops * 1e3)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` attribute (default 0.0)."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Execute one PE-internal command, then trigger ``pe_txn.done``.

        A ``GemmCmd`` takes ``_compute_ns()``; any other command takes the
        fixed ``run()`` overhead. The work happens while holding the shared
        slot when one is configured.
        """
        from kernbench.common.pe_commands import GemmCmd

        cmd = pe_txn.command

        def busy() -> Generator:
            # The time this engine is occupied by the command.
            if isinstance(cmd, GemmCmd):
                yield env.timeout(self._compute_ns(cmd.m, cmd.k, cmd.n, cmd.a.dtype))
            else:
                yield from self.run(env, 0)

        if not self._accel:
            yield from busy()
        else:
            with self._accel.request() as grant:
                yield grant
                yield from busy()
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Forward via the base implementation, holding the shared slot if any."""
        if not self._accel:
            yield from super()._forward_txn(env, txn)
            return
        with self._accel.request() as grant:
            yield grant
            yield from super()._forward_txn(env, txn)
+54
View File
@@ -0,0 +1,54 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import PeEngineBase
if TYPE_CHECKING:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeMathComponent(PeEngineBase):
    """PE_MATH: element-wise computation engine sharing accel_slot (ADR-0014 D4).

    Uses a shared compute resource (PE_ACCEL capacity=1) that is mutually
    exclusive with PE_GEMM within the same PE.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Shared slot; stays None unless init_resources() finds one configured.
        self._accel: simpy.Resource | None = None

    def init_resources(self, env: simpy.Environment) -> None:
        """Look up the shared resource named by the ``shared_resource`` attribute."""
        slot_name = self.node.attrs.get("shared_resource")
        if not (slot_name and self.ctx):
            return
        self._accel = self.ctx.get_shared_resource(
            env, f"{self._pe_prefix}.{slot_name}"
        )

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` attribute (default 0.0)."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Spend the ``run()`` overhead on the command, then trigger ``pe_txn.done``.

        The overhead is spent while holding the shared slot when one is
        configured.
        """
        if not self._accel:
            yield from self.run(env, 0)
        else:
            with self._accel.request() as grant:
                yield grant
                yield from self.run(env, 0)
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Forward via the base implementation, holding the shared slot if any."""
        if not self._accel:
            yield from super()._forward_txn(env, txn)
            return
        with self._accel.request() as grant:
            yield grant
            yield from super()._forward_txn(env, txn)
@@ -0,0 +1,245 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.common.pe_commands import PeInternalTxn
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeSchedulerComponent(ComponentBase):
    """PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1).

    Receives PeInternalTxn from PE_CPU, routes to the appropriate engine:
      - DmaReadCmd / DmaWriteCmd -> PE_DMA
      - GemmCmd -> PE_GEMM
      - MathCmd -> PE_MATH
      - CompositeCmd -> tiled pipeline (Stage 3: ADR-0014 D3.2)

    Composite GEMM pipeline (32x64x32 tiles):
      DMA_READ(b_tile_t) -> COMPUTE(t) -> DMA_WRITE(out_tile_t)
      with overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1)

    Applies scheduler overhead_ns before dispatching each command.
    Non-PeInternalTxn messages are forwarded via inherited _forward_txn().
    """

    # Scheduler tile dimensions (ADR-0014 D3.2). Only TILE_K and TILE_N are
    # referenced by the methods of this class.
    TILE_M = 32
    TILE_K = 64
    TILE_N = 32

    # Command -> engine suffix dispatch table, filled lazily by
    # _ensure_dispatch_table().
    # New engines: add a single entry there (e.g. ConvCmd: "pe_conv").
    _CMD_DISPATCH: dict[type, str] = {}

    @classmethod
    def _ensure_dispatch_table(cls) -> None:
        """Populate ``_CMD_DISPATCH`` on first use; later calls are no-ops."""
        if cls._CMD_DISPATCH:
            return
        from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd

        cls._CMD_DISPATCH = {
            DmaReadCmd: "pe_dma",
            DmaWriteCmd: "pe_dma",
            GemmCmd: "pe_gemm",
            MathCmd: "pe_math",
        }

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Node id minus its last dotted component; prefix of sibling engine ids.
        self._pe_prefix = node.id.rsplit(".", 1)[0]
        self._ensure_dispatch_table()

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` attribute (default 0.0)."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Serve the inbox forever.

        Each PeInternalTxn is dispatched in its own process, so the worker
        does not wait for it; any other message is forwarded inline.
        """
        from kernbench.common.pe_commands import PeInternalTxn

        while True:
            msg: Any = yield self._inbox.get()
            if isinstance(msg, PeInternalTxn):
                env.process(self._dispatch(env, msg))
            else:
                yield from self._forward_txn(env, msg)

    def _dispatch(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Route a PeInternalTxn to the correct engine via dispatch table."""
        from kernbench.common.pe_commands import CompositeCmd

        # Scheduler overhead.
        yield from self.run(env, 0)

        cmd = pe_txn.command

        # Check dispatch table first (lookup is by the command's exact type).
        engine_suffix = self._CMD_DISPATCH.get(type(cmd))
        if engine_suffix is not None:
            # The engine that receives pe_txn is responsible for pe_txn.done.
            yield self.out_ports[f"{self._pe_prefix}.{engine_suffix}"].put(pe_txn)
            return

        # CompositeCmd: tiled pipeline (not a simple forward).
        if isinstance(cmd, CompositeCmd):
            yield from self._dispatch_composite(env, pe_txn)
            return

        # Unknown command — signal done immediately.
        pe_txn.done.succeed()

    def _dispatch_composite(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Composite tiled pipeline (ADR-0014 D3.2).

        GEMM: 3-stage pipeline with b-tile streaming from HBM.
        MATH: sequential compute + DMA_WRITE (no tiling).
        """
        from kernbench.common.pe_commands import CompositeCmd

        cmd = pe_txn.command
        assert isinstance(cmd, CompositeCmd)

        # Anything that is not a GEMM with a second operand takes the math path.
        if cmd.op == "gemm" and cmd.b is not None:
            yield from self._pipeline_gemm(env, pe_txn, cmd)
        else:
            yield from self._pipeline_math(env, pe_txn, cmd)

    def _pipeline_gemm(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
        """Tiled GEMM pipeline: stream b tiles from HBM, compute, write results.

        Tensor a is in TCM (loaded via tl.load). Tensor b is in HBM (via tl.ref).
        Per tile: DMA_READ(b_tile_t) -> COMPUTE(t) -> DMA_WRITE(out_tile_t).

        On completion stores ``dma_ns`` and ``compute_ns`` in
        ``pe_txn.result_data`` and triggers ``pe_txn.done``.

        NOTE(review): as written, the loop waits for READ(t) before issuing
        COMPUTE(t) and for COMPUTE(t) before the next iteration issues
        READ(t+1); only WRITE(t-1) is left in flight while tile t is read and
        computed. Confirm whether the READ(t+1) || COMPUTE(t) overlap
        described for this pipeline is meant to be modelled here.
        """
        from kernbench.common.pe_commands import (
            DmaReadCmd,
            DmaWriteCmd,
            GemmCmd,
            PeInternalTxn as PeTxn,
            TensorHandle,
        )

        pp = self._pe_prefix
        a = cmd.a  # already in TCM
        b = cmd.b  # HBM reference (via tl.ref)

        M, K_a = a.shape[-2], a.shape[-1]
        K_b, N = b.shape[-2], b.shape[-1]
        dtype = a.dtype
        # Element size derived from b's byte count; 2 when b has no elements.
        dtype_bytes = b.nbytes // (K_b * N) if (K_b * N) > 0 else 2

        # Tile counts (at least one tile per dimension).
        n_tiles_k = max(1, (K_a + self.TILE_K - 1) // self.TILE_K)
        n_tiles_n = max(1, (N + self.TILE_N - 1) // self.TILE_N)
        n_tiles = n_tiles_k * n_tiles_n

        prev_compute_done = None
        prev_write_done = None
        total_dma_ns = 0.0
        total_compute_ns = 0.0

        for tile_idx in range(n_tiles):
            # Tiles are visited k-major: all n tiles of a k slab, then the next slab.
            tk = tile_idx // n_tiles_n
            tn = tile_idx % n_tiles_n
            k_start = tk * self.TILE_K
            n_start = tn * self.TILE_N
            tile_k = min(self.TILE_K, K_a - k_start)
            tile_n = min(self.TILE_N, N - n_start)
            tile_nbytes = tile_k * tile_n * dtype_bytes

            # --- Stage 1: DMA_READ b_tile from HBM ---
            read_done = env.event()
            b_tile_pa = b.pa + (k_start * N + n_start) * dtype_bytes
            b_tile_handle = TensorHandle(
                id=f"b_tile_{tile_idx}", pa=b_tile_pa,
                shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes,
            )
            read_cmd = DmaReadCmd(handle=b_tile_handle, src_pa=b_tile_pa, nbytes=tile_nbytes)
            read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp)

            t0 = env.now
            yield self.out_ports[f"{pp}.pe_dma"].put(read_txn)

            # Wait for previous compute before starting this tile's compute.
            if prev_compute_done is not None:
                yield prev_compute_done

            # Wait for this tile's DMA_READ.
            yield read_done
            total_dma_ns += env.now - t0

            # --- Stage 2: COMPUTE (GEMM) ---
            compute_done = env.event()
            out_handle = TensorHandle(
                id=f"out_tile_{tile_idx}", pa=0,
                shape=(M, tile_n), dtype=dtype,
                nbytes=M * tile_n * dtype_bytes,
            )
            compute_cmd = GemmCmd(a=a, b=b_tile_handle, out=out_handle,
                                  m=M, k=tile_k, n=tile_n)
            compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp)

            t0 = env.now
            yield self.out_ports[f"{pp}.pe_gemm"].put(compute_txn)

            # Wait for previous write (DMA_WRITE serialization). This wait
            # falls inside the span added to total_compute_ns below.
            if prev_write_done is not None:
                yield prev_write_done

            # Wait for compute of THIS tile.
            yield compute_done
            total_compute_ns += env.now - t0
            prev_compute_done = compute_done

            # --- Stage 3: DMA_WRITE out_tile to HBM ---
            # Issued without waiting; the next iteration (or the final wait
            # below) picks up its completion.
            write_done = env.event()
            out_tile_pa = cmd.out_pa + n_start * dtype_bytes
            write_nbytes = M * tile_n * dtype_bytes
            write_cmd = DmaWriteCmd(handle=out_handle, dst_pa=out_tile_pa, nbytes=write_nbytes)
            write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)

            yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
            prev_write_done = write_done

        # Wait for final write.
        if prev_write_done is not None:
            t0 = env.now
            yield prev_write_done
            total_dma_ns += env.now - t0

        pe_txn.result_data["dma_ns"] = total_dma_ns
        pe_txn.result_data["compute_ns"] = total_compute_ns
        pe_txn.done.succeed()

    def _pipeline_math(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
        """Non-GEMM composite: sequential compute + DMA_WRITE (no tiling).

        On completion stores ``dma_ns`` and ``compute_ns`` in
        ``pe_txn.result_data`` (same keys as the GEMM pipeline) and triggers
        ``pe_txn.done``.
        """
        from kernbench.common.pe_commands import (
            DmaWriteCmd,
            MathCmd,
            PeInternalTxn as PeTxn,
        )

        pp = self._pe_prefix

        # Step 1: Compute (MATH), in place on cmd.a.
        compute_done = env.event()
        compute_cmd = MathCmd(
            op=cmd.math_op or "identity",
            inputs=(cmd.a,), out=cmd.a,
        )
        compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp)
        t0 = env.now
        yield self.out_ports[f"{pp}.pe_math"].put(compute_txn)
        yield compute_done
        compute_ns = env.now - t0

        # Step 2: DMA_WRITE result to HBM.
        write_done = env.event()
        write_cmd = DmaWriteCmd(handle=cmd.a, dst_pa=cmd.out_pa, nbytes=cmd.out_nbytes)
        write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
        t0 = env.now
        yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
        yield write_done
        dma_ns = env.now - t0

        # Report the same metrics as _pipeline_gemm so that consumers which
        # sum them over composite commands do not see zeros for math ops.
        pe_txn.result_data["dma_ns"] = dma_ns
        pe_txn.result_data["compute_ns"] = compute_ns
        pe_txn.done.succeed()
+25
View File
@@ -0,0 +1,25 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING
from kernbench.components.base import ComponentBase
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class PeTcmComponent(ComponentBase):
    """PE_TCM: tightly-coupled memory / local SRAM staging buffer.

    Terminal storage component for PE-internal dataflow (ADR-0014 D5).
    Phase 0: applies overhead_ns and drain_ns at terminal; this class itself
    only implements the overhead_ns part, in ``run()``.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` attribute (default 0.0).

        ``nbytes`` is accepted for interface compatibility and not used.
        """
        node_attrs = self.node.attrs
        delay_ns = float(node_attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)
+59
View File
@@ -0,0 +1,59 @@
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
import simpy
from kernbench.components.base import ComponentBase
from kernbench.sim_engine.transaction import Transaction
if TYPE_CHECKING:
from kernbench.components.context import ComponentContext
from kernbench.topology.types import Node
class SramComponent(ComponentBase):
    """Cube SRAM: terminal component that models SRAM access latency.

    Applies overhead_ns processing overhead (from node.attrs).
    On completion, sends a ResponseMsg back on the reverse path.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` attribute (default 0.0)."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Terminal worker: process, apply drain, send response."""
        while True:
            txn: Any = yield self._inbox.get()
            yield from self.run(env, txn.nbytes)
            drain_ns = getattr(txn, "drain_ns", 0.0)
            if drain_ns > 0:
                yield env.timeout(drain_ns)
            yield from self._send_response(env, txn)

    def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Answer ``txn`` with a ResponseMsg travelling the reversed path.

        When the reversed path has fewer than two nodes, or no context is
        set, ``txn.done`` is triggered directly instead.
        """
        way_back = list(reversed(txn.path))
        if len(way_back) < 2 or not self.ctx:
            txn.done.succeed()
            return

        from kernbench.runtime_api.kernel import ResponseMsg

        # The cube index is the number in the second dotted id component.
        cube_token = self.node.id.split(".")[1]
        cube_id = int(cube_token.replace("cube", ""))
        original = txn.request
        resp_msg = ResponseMsg(
            correlation_id=original.correlation_id,
            request_id=original.request_id,
            src_cube=cube_id, src_pe=-1, success=True,
        )
        resp_txn = Transaction(
            request=resp_msg, path=way_back, step=0,
            nbytes=0, done=env.event(), is_response=True,
        )
        # way_back[0] is this node; hand the response to the next one.
        yield self.out_ports[way_back[1]].put(resp_txn.advance())