commit - release 1
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
from kernbench.components.base import ComponentBase, ComponentRegistry
|
||||
from kernbench.components.context import ComponentContext
|
||||
|
||||
__all__ = ["ComponentBase", "ComponentRegistry", "ComponentContext"]
|
||||
@@ -0,0 +1,167 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class ComponentBase(ABC):
    """Abstract base of every SimPy component implementation (ADR-0007 D3, ADR-0015).

    One instance stands for one node of the compiled topology graph.  The
    node's processing overhead is expressed as a SimPy generator (``run``),
    which leaves room for implementations that add queueing and contention.

    Port model (ADR-0015 D1):
      in_ports[src_node_id]  -- SimPy Store holding messages arriving from src
      out_ports[dst_node_id] -- SimPy Store holding messages leaving for dst
      GraphEngine wires the ports at initialization; wire processes model the
      propagation delay between connected ports (ADR-0015 D2).

    Context (ADR-0015 D4):
      ctx -- ComponentContext with router and resolver.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        self.node = node
        self.ctx = ctx
        # Both port maps start empty; they are filled in from outside before
        # start() is called.
        self.in_ports: dict[str, simpy.Store] = {}
        self.out_ports: dict[str, simpy.Store] = {}

    def start(self, env: simpy.Environment) -> None:
        """Launch the component's processes; called once after port wiring.

        The default spawns one fan-in relay per input port plus a single
        forwarding worker.  The worker applies ``run()`` for the per-component
        latency and then either routes the Transaction to its next hop or
        signals completion.  Transactions are duck-typed (no direct
        Transaction import, to avoid circular dependencies).

        Components needing custom fan-out / aggregation logic override this
        (e.g. MCpuComponent, IoCpuComponent for kernel launch).
        """
        if self.in_ports:
            # The shared inbox only exists for components with inputs.
            self._inbox: simpy.Store = simpy.Store(env)
            for src_port in self.in_ports.values():
                env.process(self._fan_in(src_port))
            env.process(self._worker(env))

    def _fan_in(self, port: simpy.Store) -> Generator:
        """Forever move messages from a single in_port into the shared inbox."""
        while True:
            incoming = yield port.get()
            yield self._inbox.put(incoming)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Generic worker: one ``_forward_txn`` process per inbox message (pipeline)."""
        while True:
            message: Any = yield self._inbox.get()
            env.process(self._forward_txn(env, message))

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Apply ``run()`` latency, then forward to the next hop or drain at the terminal."""
        yield from self.run(env, txn.nbytes)

        hop = txn.next_hop  # duck-typed: Transaction.next_hop
        if hop:
            yield self.out_ports[hop].put(txn.advance())
            return

        # Terminal hop: wait out the optional drain time, then complete.
        tail_ns = getattr(txn, "drain_ns", 0.0)
        if tail_ns > 0:
            yield env.timeout(tail_ns)
        txn.done.succeed()

    @abstractmethod
    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """SimPy process yielding one or more events for this node's processing.

        Subclasses yield ``env.timeout(overhead_ns)`` or compute the latency
        dynamically.  Invoked from ``_forward_txn`` and from subclass-specific
        handlers.
        """
        ...
|
||||
|
||||
|
||||
class PeEngineBase(ComponentBase):
    """Shared base for the PE-internal engines (PE_DMA, PE_GEMM, PE_MATH).

    What it adds on top of ComponentBase:
      - ``_pe_prefix``: ``node.id`` with its last dot-separated part removed
        (e.g. "sip0.cube0.pe0").
      - A dual-message ``_worker`` that hands PeInternalTxn messages to
        ``handle_command()`` and everything else to the inherited
        ``_forward_txn()``.
      - ``init_resources(env)``: resource-setup hook that ``start()`` calls
        before the worker is spawned.

    Subclass contract:
      1. Override ``handle_command(env, pe_txn)`` to process a PeInternalTxn.
      2. Override ``run(env, nbytes)`` to yield the component latency.
      3. Optionally override ``init_resources(env)`` for DMA channels, etc.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        head_and_leaf = node.id.rsplit(".", 1)
        self._pe_prefix: str = head_and_leaf[0]

    def start(self, env: simpy.Environment) -> None:
        """Run the resource hook first, then the default ComponentBase start-up."""
        self.init_resources(env)
        super().start(env)

    def init_resources(self, env: simpy.Environment) -> None:
        """Hook for subclass resource initialization. Called before worker spawn."""

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dual-message dispatch: PeInternalTxn → handle_command, Transaction → _forward_txn."""
        # Imported at call time rather than at module import time.
        from kernbench.common.pe_commands import PeInternalTxn

        while True:
            incoming: Any = yield self._inbox.get()
            is_command = isinstance(incoming, PeInternalTxn)
            handler = self.handle_command if is_command else self._forward_txn
            env.process(handler(env, incoming))

    @abstractmethod
    def handle_command(self, env: simpy.Environment, pe_txn: Any) -> Generator:
        """Process one PE-internal command (PeInternalTxn).

        Implementations must:
          - Do the engine-specific work (acquire resources, compute, etc.).
          - Call ``pe_txn.done.succeed()`` when finished.
        """
        ...
|
||||
|
||||
|
||||
class ComponentRegistry:
    """DI registry mapping ``node.impl`` strings to ComponentBase subclasses.

    ``ComponentRegistry.create(node, overrides, ctx)`` resolves in this order:
      1. overrides[node.impl]  -- override injected by the caller
      2. _registry[node.impl]  -- globally registered implementation
      3. ValueError            -- there is no fallback; every node needs an impl
    """

    # Class-level table shared by all callers; filled through register().
    _registry: dict[str, type[ComponentBase]] = {}

    @classmethod
    def register(cls, impl: str, component_cls: type[ComponentBase]) -> None:
        """Bind ``impl`` to ``component_cls``, replacing any earlier binding."""
        cls._registry[impl] = component_cls

    @classmethod
    def create(
        cls,
        node: Node,
        overrides: dict[str, type[ComponentBase]] | None = None,
        ctx: ComponentContext | None = None,
    ) -> ComponentBase:
        """Instantiate the component class selected by ``node.impl``.

        Caller-supplied ``overrides`` win over the global registry.

        Raises:
            ValueError: if neither table knows ``node.impl``.
        """
        # Probe the tables in priority order; an empty or missing table
        # simply cannot match.
        for table in (overrides, cls._registry):
            if table and node.impl in table:
                return table[node.impl](node, ctx)
        raise ValueError(
            f"No component registered for impl '{node.impl}' (node: {node.id}). "
            f"Register it in kernbench.components.impls.__init__."
        )
|
||||
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.policy.routing.router import AddressResolver, PathRouter
|
||||
|
||||
|
||||
@dataclass
class ComponentContext:
    """Topology services injected into every component implementation.

    Needed by components that perform routing or address resolution
    (IoCpuComponent, MCpuComponent, …). TransitComponent ignores ctx.

    Passed via ComponentRegistry.create(node, overrides, ctx=ctx).
    """

    router: PathRouter
    resolver: AddressResolver
    positions: dict[str, tuple[float, float] | None]  # node_id → pos_mm
    ns_per_mm: float  # wire propagation constant (from topology spec)
    edge_map: dict[tuple[str, str], Any] = field(default_factory=dict)
    spec: dict = field(default_factory=dict)  # topology spec (cube layout, PE count, etc.)

    def get_shared_resource(
        self, env: simpy.Environment, key: str, capacity: int = 1,
    ) -> simpy.Resource:
        """Return the shared SimPy Resource stored under ``key``, creating it on first use.

        Used by PE components that share a resource across engines within
        the same PE (e.g. accel_slot shared by PE_GEMM and PE_MATH).
        Key should be scoped per PE: e.g. "sip0.cube0.pe0.accel_slot".

        ``env`` and ``capacity`` only take effect on the call that creates the
        resource; later calls with the same key return the existing object
        unchanged.
        """
        # The cache is created lazily so it stays out of the dataclass fields.
        if not hasattr(self, "_shared_resources"):
            self._shared_resources: dict[str, simpy.Resource] = {}
        cache = self._shared_resources
        if key not in cache:
            cache[key] = simpy.Resource(env, capacity=capacity)
        return cache[key]

    def compute_drain_ns(self, path: list[str], nbytes: int) -> float:
        """Wormhole drain time: ``nbytes`` divided by the bottleneck bandwidth on ``path``.

        Only hops with an ``edge_map`` entry carrying a truthy ``bw_gbs`` are
        considered.  Returns 0.0 when no hop qualifies (including paths with
        fewer than two nodes).
        """
        bottleneck = float("inf")
        # Walk consecutive (src, dst) pairs instead of indexing by position.
        for hop in zip(path, path[1:]):
            edge = self.edge_map.get(hop)
            if edge and getattr(edge, "bw_gbs", None):
                bottleneck = min(bottleneck, edge.bw_gbs)
        if bottleneck == float("inf"):
            return 0.0
        # NOTE(review): presumably bw_gbs is GB/s (= bytes/ns), which makes
        # the quotient nanoseconds — confirm against the edge definition.
        return nbytes / bottleneck
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Concrete component implementations.
|
||||
|
||||
Each module registers its component(s) with ComponentRegistry on import.
|
||||
Import this package to activate all built-in implementations.
|
||||
"""
|
||||
|
||||
from kernbench.components.base import ComponentRegistry
|
||||
from kernbench.components.impls.forwarding import TransitComponent
|
||||
from kernbench.components.impls.hbm_ctrl import HbmCtrlComponent
|
||||
from kernbench.components.impls.io_cpu import IoCpuComponent
|
||||
from kernbench.components.impls.m_cpu import MCpuComponent
|
||||
from kernbench.components.impls.noc import TwoDMeshNocComponent
|
||||
from kernbench.components.impls.pcie_ep import PcieEpComponent
|
||||
from kernbench.components.impls.pe_cpu import PeCpuComponent
|
||||
from kernbench.components.impls.pe_dma import PeDmaComponent
|
||||
from kernbench.components.impls.pe_gemm import PeGemmComponent
|
||||
from kernbench.components.impls.pe_math import PeMathComponent
|
||||
from kernbench.components.impls.pe_scheduler import PeSchedulerComponent
|
||||
from kernbench.components.impls.pe_tcm import PeTcmComponent
|
||||
from kernbench.components.impls.sram import SramComponent
|
||||
|
||||
# Built-in impl name → component class, registered in this order at import
# time.  Several impl names share the generic TransitComponent.
for _impl_name, _component_cls in (
    ("forwarding_v1", TransitComponent),
    ("switch_v1", TransitComponent),
    ("noc_v1", TransitComponent),
    ("noc_2d_mesh_v1", TwoDMeshNocComponent),
    ("ucie_v1", TransitComponent),
    ("xbar_v1", TransitComponent),
    ("pcie_ep_v1", PcieEpComponent),
    ("io_cpu_v1", IoCpuComponent),
    ("m_cpu_v1", MCpuComponent),
    ("hbm_ctrl_v1", HbmCtrlComponent),
    ("sram_v1", SramComponent),
    ("pe_cpu_v1", PeCpuComponent),
    ("pe_scheduler_v1", PeSchedulerComponent),
    ("pe_dma_v1", PeDmaComponent),
    ("pe_gemm_v1", PeGemmComponent),
    ("pe_math_v1", PeMathComponent),
    ("pe_tcm_v1", PeTcmComponent),
):
    ComponentRegistry.register(_impl_name, _component_cls)
# Drop the loop variables so the module namespace matches the explicit form.
del _impl_name, _component_cls

# Public component classes re-exported by this package.
__all__ = [
    "HbmCtrlComponent",
    "IoCpuComponent",
    "MCpuComponent",
    "PcieEpComponent",
    "PeCpuComponent",
    "PeDmaComponent",
    "PeGemmComponent",
    "PeMathComponent",
    "PeSchedulerComponent",
    "PeTcmComponent",
    "TransitComponent",
    "TwoDMeshNocComponent",
    "SramComponent",
]
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class TransitComponent(ComponentBase):
    """Pass-through component used for NOC, UCIe and XBAR nodes.

    Each Transaction is delayed by the node's ``overhead_ns`` attribute and
    then handed to the next hop by the inherited ``_forward_txn()``.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Yield a single timeout of ``node.attrs["overhead_ns"]`` (0.0 when absent)."""
        configured = self.node.attrs.get("overhead_ns", 0.0)
        yield env.timeout(float(configured))
|
||||
@@ -0,0 +1,101 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class HbmCtrlComponent(ComponentBase):
    """HBM controller: terminal component that models HBM access latency.

    Dual-channel model: separate read and write SimPy resources, so a read and
    a write can be in service concurrently (like PE_DMA) while requests of the
    same kind queue on their own channel.  Both channels take their capacity
    from ``node.attrs["capacity"]`` (default 1).

    On completion, creates a ResponseMsg and sends it back on the reverse path
    so that response latency is modeled through the fabric.  PeDmaMsg requests
    are completed directly through ``txn.done`` instead.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Channel resources need a SimPy environment, so they are created in
        # start() and stay None until then.
        self._read: simpy.Resource | None = None
        self._write: simpy.Resource | None = None

    def start(self, env: simpy.Environment) -> None:
        """Create the read/write channel resources, then run the default start-up."""
        capacity = int(self.node.attrs.get("capacity", 1))
        self._read = simpy.Resource(env, capacity=capacity)
        self._write = simpy.Resource(env, capacity=capacity)
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Yield a single timeout of ``node.attrs["overhead_ns"]`` (0.0 when absent)."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _select_channel(self, txn: Any) -> simpy.Resource:
        """Pick the channel for ``txn``.

        MemoryWriteMsg and PeDmaMsg with ``is_write`` set use the write
        channel; every other request uses the read channel.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg

        # Both channels exist once start() has run.
        assert self._read is not None and self._write is not None
        req = txn.request
        if isinstance(req, MemoryWriteMsg):
            return self._write
        if isinstance(req, PeDmaMsg) and req.is_write:
            return self._write
        return self._read

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch each incoming txn to a concurrent process for channel-level parallelism."""
        while True:
            txn: Any = yield self._inbox.get()
            # One process per transaction, so waiting on one channel does not
            # block requests destined for the other.
            env.process(self._handle_txn(env, txn))

    def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Acquire channel, run, apply drain, send response."""
        channel = self._select_channel(txn)
        # NOTE(review): the channel is assumed to be held for run() plus the
        # drain time and released before the response is sent — confirm the
        # extent of this ``with`` block against the original file layout.
        with channel.request() as req:
            yield req
            yield from self.run(env, txn.nbytes)
            drain = getattr(txn, "drain_ns", 0.0)
            if drain > 0:
                yield env.timeout(drain)
        yield from self._send_response(env, txn)

    def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Create ResponseMsg and send on reverse path back to originator.

        PeDmaMsg is a direct probe with no IO_CPU/M_CPU aggregation in the path,
        so we succeed txn.done directly instead of sending a response Transaction.

        When the reversed path has fewer than two nodes, or no ctx is set,
        ``txn.done`` is succeeded directly as well.
        """
        from kernbench.runtime_api.kernel import PeDmaMsg

        if isinstance(txn.request, PeDmaMsg):
            txn.done.succeed()
            return

        reverse_path = list(reversed(txn.path))
        if len(reverse_path) >= 2 and self.ctx:
            from kernbench.runtime_api.kernel import ResponseMsg

            # Cube and slice numbers are parsed from the 2nd and 4th
            # dot-separated parts of the node id.
            # NOTE(review): presumably ids look like
            # "sipX.cubeY.hbm_ctrl.sliceZ" — confirm against the topology.
            parts = self.node.id.split(".")
            cube_id = int(parts[1].replace("cube", ""))
            pe_id = int(parts[3].replace("slice", ""))
            resp_msg = ResponseMsg(
                correlation_id=txn.request.correlation_id,
                request_id=txn.request.request_id,
                src_cube=cube_id, src_pe=pe_id, success=True,
            )
            # The response is a fresh zero-byte Transaction flagged
            # is_response, with its own done event.
            resp_txn = Transaction(
                request=resp_msg, path=reverse_path, step=0,
                nbytes=0, done=env.event(), is_response=True,
            )
            yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
        else:
            txn.done.succeed()
|
||||
@@ -0,0 +1,145 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class IoCpuComponent(ComponentBase):
    """IO_CPU component: multi-cube fan-out with response aggregation.

    Forward path:
      1. Applies overhead_ns processing overhead.
      2. Resolves target cube(s) from request.target_cubes.
      3. Fans out sub-Transactions to each target cube's M_CPU.

    Response path:
      Collects ResponseMsg from each M_CPU. When all cube responses are
      received, succeeds the parent txn.done.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Pending fan-out tracking: request_id → (expected, received, parent_txn_done)
        self._pending: dict[str, tuple[int, int, simpy.Event]] = {}

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Yield a single timeout of ``node.attrs["overhead_ns"]`` (0.0 when absent)."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Collect response txns; apply overhead to forward txns and fan them out."""
        while True:
            txn: Any = yield self._inbox.get()
            if getattr(txn, "is_response", False):
                self._collect_response(txn)
                continue
            # The overhead is paid inside the worker loop, so forward requests
            # are processed one at a time; only the fan-out runs concurrently.
            yield from self.run(env, txn.nbytes)
            env.process(self._dispatch_to_m_cpus(env, txn))

    def _collect_response(self, resp_txn: Any) -> None:
        """Count one cube response; succeed the parent once all have arrived.

        Responses whose request_id is not being tracked are ignored.
        """
        key = resp_txn.request.request_id
        if key not in self._pending:
            return
        expected, received, parent_done = self._pending[key]
        received += 1
        if received >= expected:
            parent_done.succeed()
            del self._pending[key]
        else:
            self._pending[key] = (expected, received, parent_done)

    def _dispatch_to_m_cpus(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out sub-Transactions to the target cubes' M_CPUs.

        The parent ``txn.done`` is succeeded immediately when target
        resolution fails, when there are no targets, or when no target is
        routable.  Otherwise it is succeeded by ``_collect_response`` after
        one response per dispatched sub-Transaction.
        """
        request = txn.request
        try:
            cube_targets = self._resolve_cube_targets(request)
        except Exception:
            # Best effort: an unresolvable request completes without fan-out.
            txn.done.succeed()
            return

        # Routes are planned before anything is sent so that the aggregation
        # counter only expects responses from sub-Transactions that are
        # really dispatched.  Counting skipped targets would leave the parent
        # done event untriggered forever.
        routes = self._routable_paths(cube_targets)
        if not routes:
            txn.done.succeed()
            return

        # Setup aggregation
        self._pending[request.request_id] = (len(routes), 0, txn.done)

        # Fan out to each target cube's M_CPU
        for path in routes:
            sub_txn = Transaction(
                request=request, path=path, step=0,
                nbytes=txn.nbytes, done=env.event(),
                result_data=txn.result_data,
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())

    def _routable_paths(self, cube_targets: list[tuple[int, int]]) -> list[list[str]]:
        """Return the IO_CPU → M_CPU path of every target that can be routed.

        Targets whose M_CPU lookup or path search raises, or whose path has
        fewer than two nodes, are skipped.
        """
        paths: list[list[str]] = []
        for sip, cube in cube_targets:
            try:
                m_cpu_id = self.ctx.resolver.find_m_cpu(sip, cube)
                path = self.ctx.router.find_node_path(self.node.id, m_cpu_id)
            except Exception:
                continue
            if len(path) >= 2:
                paths.append(path)
        return paths

    def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]:
        """Return list of (sip, cube) pairs to fan out to.

        ``request.target_cubes`` defaults to "all" when the attribute is
        missing.  For memory reads/writes "all" means the single cube decoded
        from the physical address; for kernel launches it means every cube of
        this SIP that holds a tensor shard.  Other request types yield an
        empty list.
        """
        from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg

        target_cubes = getattr(request, "target_cubes", "all")

        if isinstance(request, MemoryWriteMsg):
            sip = request.dst_sip
            if target_cubes == "all":
                cube = self._cube_from_pa(request.dst_pa, fallback=request.dst_cube)
                return [(sip, cube)]
            return [(sip, c) for c in target_cubes]

        if isinstance(request, MemoryReadMsg):
            sip = request.src_sip
            if target_cubes == "all":
                cube = self._cube_from_pa(request.src_pa, fallback=request.src_cube)
                return [(sip, cube)]
            return [(sip, c) for c in target_cubes]

        if isinstance(request, KernelLaunchMsg):
            my_sip = self._my_sip()
            if target_cubes != "all":
                return [(my_sip, c) for c in target_cubes]
            # "all": derive from tensor shards, filtered to this SIP.
            # First-seen order is kept and duplicates are dropped.
            seen: set[tuple[int, int]] = set()
            targets: list[tuple[int, int]] = []
            for arg in request.args:
                if arg.arg_kind != "tensor":
                    continue
                for shard in arg.shards:
                    if shard.sip != my_sip:
                        continue
                    key = (shard.sip, shard.cube)
                    if key not in seen:
                        seen.add(key)
                        targets.append(key)
            return targets

        return []

    def _cube_from_pa(self, pa_val: int, fallback: int) -> int:
        """Extract cube_id from a physical address; return ``fallback`` if decoding fails."""
        from kernbench.policy.address.phyaddr import PhysAddr

        try:
            return PhysAddr.decode(pa_val).cube_id
        except Exception:
            return fallback

    def _my_sip(self) -> int:
        """Extract this IO_CPU's SIP ID from its node ID (e.g. 'sip0.io0.io_cpu' → 0)."""
        return int(self.node.id.split(".")[0].replace("sip", ""))
|
||||
@@ -0,0 +1,269 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class MCpuComponent(ComponentBase):
|
||||
"""M_CPU component: multi-PE DMA fan-out with response aggregation.
|
||||
|
||||
Forward path (ADR-0015 D5):
|
||||
When a forward Transaction arrives at m_cpu (terminal hop), M_CPU fans out
|
||||
DMA sub-Transactions to target PEs' HBM slices. target_pe on the request
|
||||
controls fan-out: int → single PE, "all" → all PEs in the cube.
|
||||
|
||||
Response path:
|
||||
ResponseMsg from each hbm_ctrl arrives back at m_cpu. Once all PE responses
|
||||
are collected, m_cpu sends an aggregate ResponseMsg on the reverse command
|
||||
path back to io_cpu.
|
||||
|
||||
Transit:
|
||||
When m_cpu is NOT the terminal hop (transit or response relay), the
|
||||
Transaction is forwarded normally to the next hop.
|
||||
"""
|
||||
|
||||
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
|
||||
super().__init__(node, ctx)
|
||||
# Pending fan-out tracking: request_id → (expected, received, all_done_event)
|
||||
self._pending: dict[str, tuple[int, int, simpy.Event]] = {}
|
||||
# Store parent txn for response sending: request_id → parent_txn
|
||||
self._parent_txns: dict[str, Any] = {}
|
||||
# DMA engine resources (ADR-0015 D5, ADR-0014 D4): capacity=1 each
|
||||
self._dma_write: simpy.Resource | None = None
|
||||
self._dma_read: simpy.Resource | None = None
|
||||
|
||||
def start(self, env: simpy.Environment) -> None:
|
||||
self._dma_write = simpy.Resource(env, capacity=1)
|
||||
self._dma_read = simpy.Resource(env, capacity=1)
|
||||
super().start(env)
|
||||
|
||||
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
|
||||
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
|
||||
yield env.timeout(overhead_ns)
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
|
||||
"""Dispatch forward txns, collect response txns."""
|
||||
from kernbench.runtime_api.kernel import KernelLaunchMsg
|
||||
|
||||
while True:
|
||||
txn: Any = yield self._inbox.get()
|
||||
if getattr(txn, "is_response", False):
|
||||
self._collect_response(txn)
|
||||
else:
|
||||
yield from self.run(env, txn.nbytes)
|
||||
next_hop = txn.next_hop
|
||||
if next_hop:
|
||||
yield self.out_ports[next_hop].put(txn.advance())
|
||||
elif self.ctx is not None and txn.request is not None:
|
||||
if isinstance(txn.request, KernelLaunchMsg):
|
||||
env.process(self._kernel_launch_fanout(env, txn))
|
||||
else:
|
||||
env.process(self._dma_fanout(env, txn))
|
||||
else:
|
||||
txn.done.succeed()
|
||||
|
||||
def _collect_response(self, resp_txn: Any) -> None:
|
||||
"""Receive a PE response and increment the aggregation counter."""
|
||||
key = resp_txn.request.request_id
|
||||
if key not in self._pending:
|
||||
return
|
||||
expected, received, all_done = self._pending[key]
|
||||
received += 1
|
||||
if received >= expected:
|
||||
all_done.succeed()
|
||||
del self._pending[key]
|
||||
else:
|
||||
self._pending[key] = (expected, received, all_done)
|
||||
|
||||
def _dma_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||
"""Fan out DMA sub-Transactions to target PE(s), wait for responses,
|
||||
then send aggregate response on reverse command path.
|
||||
|
||||
Each DMA transfer acquires the DMA resource (capacity=1 per ADR-0014 D4),
|
||||
so multi-PE fan-out is serialized through the DMA engine.
|
||||
"""
|
||||
from kernbench.runtime_api.kernel import MemoryWriteMsg
|
||||
|
||||
request = txn.request
|
||||
target_pe = getattr(request, "target_pe", "all")
|
||||
|
||||
dst_nodes = self._resolve_dma_destinations(request, target_pe)
|
||||
if not dst_nodes:
|
||||
txn.done.succeed()
|
||||
return
|
||||
|
||||
# Setup aggregation
|
||||
all_done = env.event()
|
||||
self._pending[request.request_id] = (len(dst_nodes), 0, all_done)
|
||||
self._parent_txns[request.request_id] = txn
|
||||
|
||||
# Select DMA resource based on operation type
|
||||
dma_res = self._dma_write if isinstance(request, MemoryWriteMsg) else self._dma_read
|
||||
|
||||
# Fan out DMA sub-txns (serialized through DMA resource)
|
||||
max_drain_ns = 0.0
|
||||
for dst_node in dst_nodes:
|
||||
try:
|
||||
dma_path = self.ctx.router.find_mcpu_dma_path(self.node.id, dst_node)
|
||||
except Exception:
|
||||
continue
|
||||
if len(dma_path) < 2:
|
||||
continue
|
||||
drain_ns = self.ctx.compute_drain_ns(dma_path, txn.nbytes)
|
||||
max_drain_ns = max(max_drain_ns, drain_ns)
|
||||
sub_txn = Transaction(
|
||||
request=request, path=dma_path, step=0,
|
||||
nbytes=txn.nbytes, done=env.event(),
|
||||
drain_ns=drain_ns,
|
||||
)
|
||||
with dma_res.request() as req:
|
||||
yield req
|
||||
yield self.out_ports[dma_path[1]].put(sub_txn.advance())
|
||||
|
||||
# Wait for all PE responses
|
||||
yield all_done
|
||||
txn.result_data["xfer_ns"] = max_drain_ns
|
||||
del self._parent_txns[request.request_id]
|
||||
|
||||
# Send aggregate response on reverse command path
|
||||
reverse_path = list(reversed(txn.path))
|
||||
if len(reverse_path) >= 2:
|
||||
from kernbench.runtime_api.kernel import ResponseMsg
|
||||
|
||||
parts = self.node.id.split(".")
|
||||
cube_id = int(parts[1].replace("cube", ""))
|
||||
resp_msg = ResponseMsg(
|
||||
correlation_id=request.correlation_id,
|
||||
request_id=request.request_id,
|
||||
src_cube=cube_id, src_pe=-1, success=True,
|
||||
)
|
||||
resp_txn = Transaction(
|
||||
request=resp_msg, path=reverse_path, step=0,
|
||||
nbytes=0, done=env.event(), is_response=True,
|
||||
)
|
||||
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
|
||||
else:
|
||||
txn.done.succeed()
|
||||
|
||||
def _kernel_launch_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||
"""Fan out KernelLaunchMsg to target PE_CPU(s) via NOC (ADR-0009 D3).
|
||||
|
||||
Routes through find_node_path (M_CPU → NOC → PE_CPU command edges).
|
||||
Waits for sub_txn.done directly — no ResponseMsg needed for PE direction.
|
||||
Then sends aggregate ResponseMsg back to IO_CPU on the reverse path.
|
||||
"""
|
||||
request = txn.request
|
||||
target_pe = getattr(request, "target_pe", "all")
|
||||
cube_prefix = self.node.id.rsplit(".", 1)[0] # e.g. "sip0.cube0"
|
||||
pe_ids = self._resolve_pe_ids(target_pe)
|
||||
|
||||
if not pe_ids:
|
||||
txn.done.succeed()
|
||||
return
|
||||
|
||||
# Fan out to each PE_CPU and collect done events
|
||||
sub_dones: list[simpy.Event] = []
|
||||
sub_txns: list[Transaction] = []
|
||||
for pe_id in pe_ids:
|
||||
pe_cpu_id = f"{cube_prefix}.pe{pe_id}.pe_cpu"
|
||||
try:
|
||||
path = self.ctx.router.find_node_path(self.node.id, pe_cpu_id)
|
||||
except Exception:
|
||||
continue
|
||||
if len(path) < 2:
|
||||
continue
|
||||
sub_done = env.event()
|
||||
sub_txn = Transaction(
|
||||
request=request, path=path, step=0,
|
||||
nbytes=0, done=sub_done,
|
||||
)
|
||||
yield self.out_ports[path[1]].put(sub_txn.advance())
|
||||
sub_dones.append(sub_done)
|
||||
sub_txns.append(sub_txn)
|
||||
|
||||
if not sub_dones:
|
||||
txn.done.succeed()
|
||||
return
|
||||
|
||||
# Wait for all PE_CPUs to complete
|
||||
for sd in sub_dones:
|
||||
yield sd
|
||||
|
||||
# Aggregate PE-internal metrics (max across PEs)
|
||||
pe_exec_values = [st.result_data.get("pe_exec_ns", 0.0) for st in sub_txns]
|
||||
if pe_exec_values:
|
||||
txn.result_data["pe_exec_ns"] = max(pe_exec_values)
|
||||
dma_values = [st.result_data.get("dma_ns", 0.0) for st in sub_txns]
|
||||
if dma_values:
|
||||
txn.result_data["dma_ns"] = max(dma_values)
|
||||
compute_values = [st.result_data.get("compute_ns", 0.0) for st in sub_txns]
|
||||
if compute_values:
|
||||
txn.result_data["compute_ns"] = max(compute_values)
|
||||
|
||||
# Send aggregate response on reverse command path back to IO_CPU
|
||||
reverse_path = list(reversed(txn.path))
|
||||
if len(reverse_path) >= 2:
|
||||
from kernbench.runtime_api.kernel import ResponseMsg
|
||||
|
||||
parts = self.node.id.split(".")
|
||||
cube_id = int(parts[1].replace("cube", ""))
|
||||
resp_msg = ResponseMsg(
|
||||
correlation_id=request.correlation_id,
|
||||
request_id=request.request_id,
|
||||
src_cube=cube_id, src_pe=-1, success=True,
|
||||
)
|
||||
resp_txn = Transaction(
|
||||
request=resp_msg, path=reverse_path, step=0,
|
||||
nbytes=0, done=env.event(), is_response=True,
|
||||
)
|
||||
yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
|
||||
else:
|
||||
txn.done.succeed()
|
||||
|
||||
def _resolve_dma_destinations(self, request: Any, target_pe: int | str) -> list[str]:
    """Return the HBM destination node ids for a DMA fan-out.

    Resolution order:

    1. An integer ``target_pe`` maps to the same-numbered HBM slice of the
       local cube.
    2. Otherwise, when the request carries a physical address (``dst_pa``,
       else ``src_pa``), the address is decoded and resolved through
       ``self.ctx.resolver``; the resolved node may belong to another cube.
    3. Otherwise (e.g. "all" on a request without a PA, or when PA
       resolution fails) every HBM slice of the local cube is returned.

    Args:
        request: Message being fanned out; only its optional ``dst_pa`` and
            ``src_pa`` attributes are read.
        target_pe: Integer slice index, or a non-int selector such as
            ``"all"``.

    Returns:
        List of HBM controller slice node ids.
    """
    import logging

    cube_prefix = self.node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0"

    if isinstance(target_pe, int):
        return [f"{cube_prefix}.hbm_ctrl.slice{target_pe}"]

    # PA-based resolution: extract the actual target from the physical address.
    # NOTE(review): `or` treats a dst_pa of 0 like a missing attribute and
    # falls through to src_pa — confirm that 0 is never a valid destination PA.
    pa_val = getattr(request, "dst_pa", None) or getattr(request, "src_pa", None)
    if pa_val is not None:
        from kernbench.policy.address.phyaddr import PhysAddr

        try:
            return [self.ctx.resolver.resolve(PhysAddr.decode(pa_val))]
        except Exception:
            # Deliberate best-effort: fall back to the local fan-out below,
            # but leave a trace so a broken resolver is not silently masked
            # by a (much wider) all-slices fan-out.
            logging.getLogger(__name__).debug(
                "PA resolution failed for %r; falling back to all local HBM slices",
                pa_val,
                exc_info=True,
            )

    # "all" without a usable PA (KernelLaunch): all slices in the local cube.
    n_slices = 8
    if self.ctx and self.ctx.spec:
        mm = self.ctx.spec.get("cube", {}).get("memory_map", {})
        n_slices = mm.get("hbm_slices_per_cube", 8)
    return [f"{cube_prefix}.hbm_ctrl.slice{i}" for i in range(n_slices)]
|
||||
|
||||
def _resolve_pe_ids(self, target_pe: int | str) -> list[int]:
    """Return the PE ids a kernel launch should fan out to.

    An integer ``target_pe`` yields just that id. Any other selector
    (e.g. ``"all"``) yields ``0 .. count-1``, where ``count`` is read from the
    spec's ``cube.memory_map.hbm_slices_per_cube`` entry and defaults to 8
    when no context/spec is available.
    """
    if not isinstance(target_pe, int):
        # "all": every PE in the local cube.
        pe_count = 8
        if self.ctx and self.ctx.spec:
            memory_map = self.ctx.spec.get("cube", {}).get("memory_map", {})
            pe_count = memory_map.get("hbm_slices_per_cube", 8)
        return [*range(pe_count)]
    return [target_pe]
|
||||
@@ -0,0 +1,187 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class TwoDMeshNocComponent(ComponentBase):
    """A 2D mesh NOC represented by one smart node.

    Latency:
        The route between the previous and next hop positions is split into
        XY segments (horizontal first, then vertical). Each segment costs
        ``dist_mm * ctx.ns_per_mm``; ``overhead_ns`` (from ``node.attrs``) is
        charged once, on the first segment.

    Contention:
        Every directed segment is a ``simpy.Resource`` with capacity 1, created
        lazily. The next segment is requested before the current segment's
        timeout elapses, so a free downstream segment is already held when the
        transaction arrives (wormhole-style cut-through). Transactions that
        share a segment contend for it.

    Concurrency:
        ``_worker`` starts one SimPy process per incoming transaction, so only
        the segment resources serialize traffic, never the node itself.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._env: simpy.Environment | None = None
        self._links: dict[tuple, simpy.Resource] = {}
        self._x_grid: list[float] = []
        self._y_grid: list[float] = []

    def start(self, env: simpy.Environment) -> None:
        """Remember the environment, build the waypoint grid, then start the base workers."""
        self._env = env
        self._build_grid()
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Zero-delay placeholder; the real latency is modeled in ``_route``."""
        yield env.timeout(0)

    # ── Grid construction ────────────────────────────────────────────

    def _build_grid(self) -> None:
        """Collect the distinct x / y coordinates (2-decimal rounded) of this cube's nodes."""
        ctx = self.ctx
        if not ctx:
            return
        prefix = self.node.id.rsplit(".", 1)[0] + "."
        local_positions = [
            pos
            for node_id, pos in ctx.positions.items()
            if node_id.startswith(prefix) and pos is not None
        ]
        self._x_grid = sorted({round(pos[0], 2) for pos in local_positions})
        self._y_grid = sorted({round(pos[1], 2) for pos in local_positions})

    def _get_link(self, key: tuple) -> simpy.Resource:
        """Return the capacity-1 resource for segment ``key``, creating it on first use."""
        link = self._links.get(key)
        if link is None:
            assert self._env is not None
            link = simpy.Resource(self._env, capacity=1)
            self._links[key] = link
        return link

    # ── Worker ───────────────────────────────────────────────────────

    def _worker(self, env: simpy.Environment) -> Generator:
        """Drain the inbox forever, routing each transaction in its own process."""
        while True:
            incoming: Any = yield self._inbox.get()
            env.process(self._route(env, incoming))

    def _route(self, env: simpy.Environment, txn: Any) -> Generator:
        """Apply traversal latency to ``txn`` and hand it to the next hop (or finish it)."""
        prev_hop = txn.path[txn.step - 1] if txn.step > 0 else None
        next_hop = txn.next_hop
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))

        # Segments are only known when both endpoints have a position.
        segments: list[tuple[tuple, float]] = []
        if prev_hop and next_hop and self.ctx:
            positions = self.ctx.positions
            src_pos = positions.get(prev_hop)
            dst_pos = positions.get(next_hop)
            if src_pos and dst_pos:
                segments = self._xy_links(src_pos, dst_pos)

        if not segments:
            yield env.timeout(overhead_ns)
        else:
            yield from self._traverse(env, segments, overhead_ns)

        if next_hop:
            yield self.out_ports[next_hop].put(txn.advance())
            return

        # Terminal hop: apply the drain time (if any) and complete the transaction.
        drain = getattr(txn, "drain_ns", 0.0)
        if drain > 0:
            yield env.timeout(drain)
        txn.done.succeed()

    # ── XY routing and pipelined link traversal ──────────────────────

    def _traverse(
        self,
        env: simpy.Environment,
        links: list[tuple[tuple, float]],
        overhead_ns: float,
    ) -> Generator:
        """Walk ``links`` in order, requesting each next segment before the current delay ends."""
        ns_per_mm = self.ctx.ns_per_mm  # type: ignore[union-attr]

        # Take hold of the first segment before any time is spent.
        held_resource = self._get_link(links[0][0])
        held_request = held_resource.request()
        yield held_request

        last_index = len(links) - 1
        for idx, (_key, dist_mm) in enumerate(links):
            # Queue for the following segment up front (pipeline).
            upcoming = None
            if idx < last_index:
                following = self._get_link(links[idx + 1][0])
                upcoming = (following, following.request())

            yield env.timeout(dist_mm * ns_per_mm + (overhead_ns if idx == 0 else 0.0))
            held_resource.release(held_request)

            if upcoming is not None:
                held_resource, held_request = upcoming
                yield held_request  # usually already granted (pipeline)

    def _xy_links(
        self,
        src: tuple[float, float],
        dst: tuple[float, float],
    ) -> list[tuple[tuple, float]]:
        """Build the XY route from ``src`` to ``dst``: horizontal leg first, then vertical.

        Returns ``(link_key, dist_mm)`` pairs. A key is
        ``(axis, band, lo, hi, direction)`` with coordinates rounded to two
        decimals, so concurrent transactions on the same directed segment map
        to the same resource.
        """
        x0, y0 = src
        x1, y1 = dst
        route: list[tuple[tuple, float]] = []

        # Horizontal leg along the row nearest to y0.
        if abs(x0 - x1) > 1e-9:
            row = round(self._snap(y0, self._y_grid), 2)
            for xa, xb in self._segments(x0, x1, self._x_grid):
                length = abs(xb - xa)
                if length > 1e-9:
                    low, high = (xa, xb) if xa < xb else (xb, xa)
                    heading = "E" if xb > xa else "W"
                    route.append((("H", row, round(low, 2), round(high, 2), heading), length))

        # Vertical leg along the column nearest to x1.
        if abs(y0 - y1) > 1e-9:
            column = round(self._snap(x1, self._x_grid), 2)
            for ya, yb in self._segments(y0, y1, self._y_grid):
                length = abs(yb - ya)
                if length > 1e-9:
                    low, high = (ya, yb) if ya < yb else (yb, ya)
                    heading = "S" if yb > ya else "N"
                    route.append((("V", column, round(low, 2), round(high, 2), heading), length))

        return route

    @staticmethod
    def _snap(val: float, grid: list[float]) -> float:
        """Return the grid value closest to ``val`` (``val`` itself for an empty grid)."""
        if grid:
            return min(grid, key=lambda g: abs(g - val))
        return val

    @staticmethod
    def _segments(a: float, b: float, grid: list[float]) -> list[tuple[float, float]]:
        """Split the span from ``a`` to ``b`` at interior grid points.

        The pairs are consecutive and oriented in travel direction (from ``a``
        towards ``b``); an (almost) zero-length span yields no pairs.
        """
        if abs(a - b) < 1e-9:
            return []
        low, high = (a, b) if a < b else (b, a)
        interior = [g for g in grid if low + 1e-9 < g < high - 1e-9]
        waypoints = [low, *interior, high]
        pairs = list(zip(waypoints, waypoints[1:]))
        if a > b:
            # Travelling towards smaller coordinates: reverse order and orientation.
            pairs = [(end, start) for start, end in reversed(pairs)]
        return pairs
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PcieEpComponent(ComponentBase):
    """PCIe endpoint that charges a protocol-processing delay.

    ``run`` waits for ``overhead_ns`` (read from ``node.attrs``, default 0);
    forwarding itself is left to the inherited machinery.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Yield a single timeout of ``overhead_ns``; ``nbytes`` is not used."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)
|
||||
@@ -0,0 +1,154 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeCpuComponent(ComponentBase):
    """PE_CPU: kernel execution controller (Stage 2).

    Two-phase kernel execution (ADR-0014 D1):
      Phase 1 (compile): look up the kernel in the registry and run it with a
        TLContext to generate a PeCommand list.
      Phase 2 (replay): iterate the commands, dispatch them to PE_SCHEDULER as
        PeInternalTxn, and wait for blocking commands.

    Transactions that do not carry a KernelLaunchMsg are forwarded normally.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._pe_prefix = node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0.pe0"
        try:
            self._pe_idx = int(self._pe_prefix.rsplit("pe", 1)[1])
        except (IndexError, ValueError):
            self._pe_idx = 0
        # Extract sip/cube index for multi-SIP/cube shard matching; an id that
        # does not follow the "sipN.cubeM..." layout falls back to index 0.
        parts = node.id.split(".")
        try:
            self._sip_idx = int(parts[0].replace("sip", ""))
        except (IndexError, ValueError):
            self._sip_idx = 0
        try:
            self._cube_idx = int(parts[1].replace("cube", ""))
        except (IndexError, ValueError):
            self._cube_idx = 0

    def _find_shard(self, shards: tuple) -> Any:
        """Return the shard for this PE.

        Prefers the shard whose ``(sip, cube, pe)`` equals this component's
        indices; otherwise falls back to the positional shard at
        ``min(pe_idx, len(shards) - 1)``.

        Raises:
            IndexError: if ``shards`` is empty.
        """
        if not shards:
            # Previously surfaced as an opaque "tuple index out of range".
            raise IndexError(f"{self.node.id}: tensor argument has no shards to select from")
        for s in shards:
            if s.sip == self._sip_idx and s.cube == self._cube_idx and s.pe == self._pe_idx:
                return s
        return shards[min(self._pe_idx, len(shards) - 1)]

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Charge the PE_CPU ``overhead_ns`` (from ``node.attrs``); ``nbytes`` is unused."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Serve the inbox: execute kernel launches, forward everything else."""
        # Imported once per worker instead of once per received message.
        from kernbench.runtime_api.kernel import KernelLaunchMsg

        while True:
            txn: Any = yield self._inbox.get()
            if isinstance(getattr(txn, "request", None), KernelLaunchMsg):
                yield from self._execute_kernel(env, txn)
            else:
                yield from self._forward_txn(env, txn)

    def _kernel_args(self, request: Any) -> list:
        """Flatten ``request.args`` into positional kernel arguments.

        Tensor arguments contribute the PA of the shard selected by
        ``_find_shard``; scalar arguments contribute their value.
        """
        kernel_args: list = []
        for arg in request.args:
            if arg.arg_kind == "tensor":
                kernel_args.append(self._find_shard(arg.shards).pa)
            elif arg.arg_kind == "scalar":
                kernel_args.append(arg.value)
            # NOTE(review): any other arg_kind is skipped, which shifts the
            # positions of later arguments — confirm no other kinds exist.
        return kernel_args

    def _execute_kernel(self, env: simpy.Environment, txn: Any) -> Generator:
        """Compile the kernel function and replay its command trace.

        On completion writes ``pe_exec_ns`` (replay wall time), ``dma_ns`` and
        ``compute_ns`` (sums over all CompositeCmd results) into
        ``txn.result_data`` and triggers ``txn.done``.
        """
        from kernbench.common.pe_commands import (
            CompositeCmd,
            PeCpuOverheadCmd,
            PeInternalTxn,
            WaitCmd,
        )
        from kernbench.triton_emu.registry import get_kernel
        from kernbench.triton_emu.tl_context import TLContext, run_kernel

        request = txn.request

        # Phase 1: Compile — apply PE_CPU setup overhead, then run the kernel.
        yield from self.run(env, 0)

        kernel_fn = get_kernel(request.kernel_ref.name)
        tl = TLContext(pe_id=self._pe_idx, dispatch_cycles=0)
        run_kernel(kernel_fn, tl, *self._kernel_args(request))
        commands = tl.commands

        # Phase 2: Replay — dispatch commands to PE_SCHEDULER.
        pe_exec_start = env.now
        scheduler_port = self.out_ports[f"{self._pe_prefix}.pe_scheduler"]
        pending: dict[str, simpy.Event] = {}  # completion_id → done event
        # result_data dicts of CompositeCmd txns; they are filled in by the
        # scheduler, so only references are collected here.
        composite_results: list[dict] = []

        for cmd in commands:
            if isinstance(cmd, PeCpuOverheadCmd):
                yield env.timeout(cmd.cycles)
                continue

            if isinstance(cmd, WaitCmd):
                if cmd.handle is not None:
                    # Wait for one specific completion (if still tracked).
                    evt = pending.pop(cmd.handle.id, None)
                    if evt is not None:
                        yield evt
                else:
                    # Wait for all pending completions.
                    for evt in pending.values():
                        yield evt
                    pending.clear()
                continue

            done_evt = env.event()
            pe_txn = PeInternalTxn(
                command=cmd, done=done_evt,
                pe_prefix=self._pe_prefix,
            )
            if isinstance(cmd, CompositeCmd):
                # Non-blocking: dispatch to the scheduler, track completion.
                composite_results.append(pe_txn.result_data)
                yield scheduler_port.put(pe_txn)
                pending[cmd.completion.id] = done_evt
            else:
                # Blocking: dispatch and wait for completion.
                yield scheduler_port.put(pe_txn)
                yield done_evt

        # Wait for any remaining pending completions.
        for evt in pending.values():
            yield evt

        # Record PE-internal execution time.
        txn.result_data["pe_exec_ns"] = env.now - pe_exec_start

        # Aggregate dma_ns / compute_ns from CompositeCmd results.
        total_dma_ns = 0.0
        total_compute_ns = 0.0
        for rd in composite_results:
            total_dma_ns += rd.get("dma_ns", 0.0)
            total_compute_ns += rd.get("compute_ns", 0.0)
        txn.result_data["dma_ns"] = total_dma_ns
        txn.result_data["compute_ns"] = total_compute_ns

        # Signal the original Transaction done.
        txn.done.succeed()
|
||||
@@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import PeEngineBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeDmaComponent(PeEngineBase):
    """PE_DMA: dual-channel DMA engine with READ and WRITE resources.

    Each channel has capacity=1 (ADR-0014 D4):
      - DMA_READ and DMA_WRITE may execute concurrently.
      - Multiple READs cannot overlap; multiple WRITEs cannot overlap.

    Handles two message types:
      - Transaction: external fabric messages (PeDmaMsg probes, M_CPU DMA)
      - PeInternalTxn: PE-internal commands from PE_SCHEDULER
        (DmaReadCmd → HBM read, DmaWriteCmd → HBM write)
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._dma_read: simpy.Resource | None = None
        self._dma_write: simpy.Resource | None = None

    def init_resources(self, env: simpy.Environment) -> None:
        """Create the two capacity-1 channel resources."""
        self._dma_read = simpy.Resource(env, capacity=1)
        self._dma_write = simpy.Resource(env, capacity=1)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Zero-delay placeholder; DMA cost is modeled by the transfer itself."""
        yield env.timeout(0)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Handle a PE-internal DMA command: resolve PA → HBM path → transfer.

        The matching channel is held only while the sub-transaction is issued;
        ``pe_txn.done`` is triggered once the transfer's own ``done`` event
        fires. Commands that are neither DmaReadCmd nor DmaWriteCmd complete
        immediately.
        """
        from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
        from kernbench.policy.address.phyaddr import PhysAddr
        from kernbench.runtime_api.kernel import PeDmaMsg

        cmd = pe_txn.command
        assert self._dma_read is not None and self._dma_write is not None

        # Determine direction and target PA.
        if isinstance(cmd, DmaReadCmd):
            dma_res = self._dma_read
            target_pa = cmd.src_pa
            is_write = False
        elif isinstance(cmd, DmaWriteCmd):
            dma_res = self._dma_write
            target_pa = cmd.dst_pa
            is_write = True
        else:
            pe_txn.done.succeed()
            return

        # Resolve PA → HBM node and compute the path.
        pa = PhysAddr.decode(target_pa)
        dst_node = self.ctx.resolver.resolve(pa)
        path = self.ctx.router.find_path(self._pe_prefix, dst_node)
        drain_ns = self.ctx.compute_drain_ns(path, cmd.nbytes)

        # Without a second path entry there is no port to send on. Waiting on
        # the sub-transaction would then block forever, so only wait when it
        # was actually dispatched.
        dispatched = len(path) > 1

        # Acquire the DMA channel (command issue serialization).
        with dma_res.request() as req:
            yield req
            # Create a sub-Transaction with PeDmaMsg (HbmCtrl handles it directly).
            # The src_* ids are filled with 0 for PE-internal transfers.
            sub_done = env.event()
            sub_request = PeDmaMsg(
                correlation_id="pe_internal",
                request_id=f"dma_{id(pe_txn)}",
                src_sip=0, src_cube=0, src_pe=0,
                dst_pa=target_pa, nbytes=cmd.nbytes,
                is_write=is_write,
            )
            sub_txn = Transaction(
                request=sub_request, path=path, step=0,
                nbytes=cmd.nbytes, done=sub_done, drain_ns=drain_ns,
            )
            # Send to the next hop (path[0] is pe_dma itself, path[1] is xbar).
            if dispatched:
                yield self.out_ports[path[1]].put(sub_txn.advance())
        # DMA channel released after issue.

        # Wait for HBM transfer completion.
        if dispatched:
            yield sub_done
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Handle an external Transaction (PeDmaMsg probe, M_CPU DMA) with channel acquisition."""
        dma_res = self._select_channel(txn)
        with dma_res.request() as req:
            yield req
            next_hop = txn.next_hop
            if next_hop:
                yield self.out_ports[next_hop].put(txn.advance())
            else:
                # Terminal hop: apply drain time (if any) and complete.
                drain = getattr(txn, "drain_ns", 0.0)
                if drain > 0:
                    yield env.timeout(drain)
                txn.done.succeed()

    def _select_channel(self, txn: Any) -> simpy.Resource:
        """Select the DMA channel: WRITE for MemoryWriteMsg requests, READ otherwise."""
        from kernbench.runtime_api.kernel import MemoryWriteMsg

        assert self._dma_read is not None and self._dma_write is not None
        # NOTE(review): a PeDmaMsg with is_write=True also lands on the READ
        # channel here — confirm whether it should use the WRITE channel.
        if isinstance(txn.request, MemoryWriteMsg):
            return self._dma_write
        return self._dma_read
|
||||
@@ -0,0 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import PeEngineBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
# dtype → bit width (for TFLOPS scaling)
|
||||
_DTYPE_BITS: dict[str, int] = {
    "f16": 16,
    "fp16": 16,
    "float16": 16,
    "bf16": 16,
    "f32": 32,
    "fp32": 32,
    "float32": 32,
    "i8": 8,
    "int8": 8,
    "i16": 16,
    "int16": 16,
    "i32": 32,
    "int32": 32,
}


class PeGemmComponent(PeEngineBase):
    """PE_GEMM: matrix-multiply engine behind the shared accel_slot (ADR-0014 D4).

    The engine takes the shared compute resource named by the node's
    ``shared_resource`` attribute (PE_ACCEL, capacity=1), which makes it
    mutually exclusive with PE_MATH inside the same PE.

    Compute latency model:
        FLOPs            = 2 * M * K * N
        effective_tflops = peak_tflops_f16 * (16 / dtype_bits)
        compute_ns       = FLOPs / (effective_tflops * 1e3)
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._accel: simpy.Resource | None = None
        self._peak_tflops_f16: float = float(node.attrs.get("peak_tflops_f16", 0.0))

    def init_resources(self, env: simpy.Environment) -> None:
        """Look up the shared accel resource when the node names one and a context exists."""
        shared_name = self.node.attrs.get("shared_resource")
        if not (shared_name and self.ctx):
            return
        self._accel = self.ctx.get_shared_resource(
            env, f"{self._pe_prefix}.{shared_name}"
        )

    def _compute_ns(self, m: int, k: int, n: int, dtype: str) -> float:
        """Return the GEMM latency in nanoseconds.

        Falls back to the node's ``overhead_ns`` when no positive peak TFLOPS
        is configured. Unknown dtypes are treated as 16-bit.
        """
        peak = self._peak_tflops_f16
        if peak <= 0:
            return float(self.node.attrs.get("overhead_ns", 0.0))
        bits = _DTYPE_BITS.get(dtype, 16)
        flop_count = 2.0 * m * k * n
        flops_per_ns = peak * (16.0 / bits) * 1e3
        return flop_count / flops_per_ns

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Charge the fixed ``overhead_ns`` from ``node.attrs``; ``nbytes`` is unused."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Execute a PE-internal command, holding the accel slot if one is configured.

        A GemmCmd costs ``_compute_ns``; any other command costs the fixed
        overhead of ``run``. ``pe_txn.done`` is triggered afterwards.
        """
        from kernbench.common.pe_commands import GemmCmd

        cmd = pe_txn.command

        def work() -> Generator:
            # The latency part shared by the accel and no-accel paths.
            if isinstance(cmd, GemmCmd):
                yield env.timeout(self._compute_ns(cmd.m, cmd.k, cmd.n, cmd.a.dtype))
            else:
                yield from self.run(env, 0)

        if self._accel:
            with self._accel.request() as grant:
                yield grant
                yield from work()
        else:
            yield from work()
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Forward a Transaction, holding the accel slot for the duration if present."""
        if not self._accel:
            yield from super()._forward_txn(env, txn)
            return
        with self._accel.request() as grant:
            yield grant
            yield from super()._forward_txn(env, txn)
|
||||
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import PeEngineBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeMathComponent(PeEngineBase):
    """PE_MATH: element-wise engine behind the shared accel_slot (ADR-0014 D4).

    The engine takes the shared compute resource named by the node's
    ``shared_resource`` attribute (PE_ACCEL, capacity=1), which makes it
    mutually exclusive with PE_GEMM inside the same PE.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._accel: simpy.Resource | None = None

    def init_resources(self, env: simpy.Environment) -> None:
        """Look up the shared accel resource when the node names one and a context exists."""
        shared_name = self.node.attrs.get("shared_resource")
        if not (shared_name and self.ctx):
            return
        self._accel = self.ctx.get_shared_resource(
            env, f"{self._pe_prefix}.{shared_name}"
        )

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Charge the fixed ``overhead_ns`` from ``node.attrs``; ``nbytes`` is unused."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Charge the fixed overhead (under the accel slot if configured), then signal done."""
        if not self._accel:
            yield from self.run(env, 0)
        else:
            with self._accel.request() as grant:
                yield grant
                yield from self.run(env, 0)
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Forward a Transaction, holding the accel slot for the duration if present."""
        if not self._accel:
            yield from super()._forward_txn(env, txn)
            return
        with self._accel.request() as grant:
            yield grant
            yield from super()._forward_txn(env, txn)
|
||||
@@ -0,0 +1,245 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeSchedulerComponent(ComponentBase):
|
||||
"""PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1).
|
||||
|
||||
Receives PeInternalTxn from PE_CPU, routes to the appropriate engine:
|
||||
- DmaReadCmd / DmaWriteCmd → PE_DMA
|
||||
- GemmCmd → PE_GEMM
|
||||
- MathCmd → PE_MATH
|
||||
- CompositeCmd → tiled pipeline (Stage 3: ADR-0014 D3.2)
|
||||
|
||||
Composite GEMM pipeline (32x64x32 tiles):
|
||||
DMA_READ(b_tile_t) → COMPUTE(t) → DMA_WRITE(out_tile_t)
|
||||
with overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1)
|
||||
|
||||
Applies scheduler overhead_ns before dispatching each command.
|
||||
Non-PeInternalTxn messages are forwarded via inherited _forward_txn().
|
||||
"""
|
||||
|
||||
# Scheduler tile dimensions (ADR-0014 D3.2)
|
||||
TILE_M = 32
|
||||
TILE_K = 64
|
||||
TILE_N = 32
|
||||
|
||||
# Command → engine suffix dispatch table.
|
||||
# New engines: add a single entry here (e.g. ConvCmd: "pe_conv").
|
||||
_CMD_DISPATCH: dict[type, str] = {}
|
||||
|
||||
@classmethod
|
||||
def _ensure_dispatch_table(cls) -> None:
|
||||
if cls._CMD_DISPATCH:
|
||||
return
|
||||
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd
|
||||
|
||||
cls._CMD_DISPATCH = {
|
||||
DmaReadCmd: "pe_dma",
|
||||
DmaWriteCmd: "pe_dma",
|
||||
GemmCmd: "pe_gemm",
|
||||
MathCmd: "pe_math",
|
||||
}
|
||||
|
||||
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
|
||||
super().__init__(node, ctx)
|
||||
self._pe_prefix = node.id.rsplit(".", 1)[0]
|
||||
self._ensure_dispatch_table()
|
||||
|
||||
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
|
||||
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
|
||||
yield env.timeout(overhead_ns)
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
|
||||
while True:
|
||||
msg: Any = yield self._inbox.get()
|
||||
if isinstance(msg, PeInternalTxn):
|
||||
env.process(self._dispatch(env, msg))
|
||||
else:
|
||||
yield from self._forward_txn(env, msg)
|
||||
|
||||
def _dispatch(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
|
||||
"""Route a PeInternalTxn to the correct engine via dispatch table."""
|
||||
from kernbench.common.pe_commands import CompositeCmd
|
||||
|
||||
# Scheduler overhead
|
||||
yield from self.run(env, 0)
|
||||
|
||||
cmd = pe_txn.command
|
||||
|
||||
# Check dispatch table first
|
||||
engine_suffix = self._CMD_DISPATCH.get(type(cmd))
|
||||
if engine_suffix is not None:
|
||||
yield self.out_ports[f"{self._pe_prefix}.{engine_suffix}"].put(pe_txn)
|
||||
return
|
||||
|
||||
# CompositeCmd: tiled pipeline (not a simple forward)
|
||||
if isinstance(cmd, CompositeCmd):
|
||||
yield from self._dispatch_composite(env, pe_txn)
|
||||
return
|
||||
|
||||
# Unknown command — signal done immediately
|
||||
pe_txn.done.succeed()
|
||||
|
||||
def _dispatch_composite(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
|
||||
"""Composite tiled pipeline (ADR-0014 D3.2).
|
||||
|
||||
GEMM: 3-stage pipeline with b-tile streaming from HBM.
|
||||
MATH: sequential compute + DMA_WRITE (no tiling).
|
||||
"""
|
||||
from kernbench.common.pe_commands import CompositeCmd
|
||||
|
||||
cmd = pe_txn.command
|
||||
assert isinstance(cmd, CompositeCmd)
|
||||
if cmd.op == "gemm" and cmd.b is not None:
|
||||
yield from self._pipeline_gemm(env, pe_txn, cmd)
|
||||
else:
|
||||
yield from self._pipeline_math(env, pe_txn, cmd)
|
||||
|
||||
def _pipeline_gemm(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
|
||||
"""Tiled GEMM pipeline: stream b tiles from HBM, compute, write results.
|
||||
|
||||
Tensor a is in TCM (loaded via tl.load). Tensor b is in HBM (via tl.ref).
|
||||
Pipeline: DMA_READ(b_tile_t) -> COMPUTE(t) -> DMA_WRITE(out_tile_t)
|
||||
Overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1)
|
||||
"""
|
||||
from kernbench.common.pe_commands import (
|
||||
DmaReadCmd,
|
||||
DmaWriteCmd,
|
||||
GemmCmd,
|
||||
PeInternalTxn as PeTxn,
|
||||
TensorHandle,
|
||||
)
|
||||
|
||||
pp = self._pe_prefix
|
||||
a = cmd.a # already in TCM
|
||||
b = cmd.b # HBM reference (via tl.ref)
|
||||
|
||||
M, K_a = a.shape[-2], a.shape[-1]
|
||||
K_b, N = b.shape[-2], b.shape[-1]
|
||||
dtype = a.dtype
|
||||
dtype_bytes = b.nbytes // (K_b * N) if (K_b * N) > 0 else 2
|
||||
|
||||
# Tile counts
|
||||
n_tiles_k = max(1, (K_a + self.TILE_K - 1) // self.TILE_K)
|
||||
n_tiles_n = max(1, (N + self.TILE_N - 1) // self.TILE_N)
|
||||
n_tiles = n_tiles_k * n_tiles_n
|
||||
|
||||
prev_compute_done = None
|
||||
prev_write_done = None
|
||||
total_dma_ns = 0.0
|
||||
total_compute_ns = 0.0
|
||||
|
||||
for tile_idx in range(n_tiles):
|
||||
tk = tile_idx // n_tiles_n
|
||||
tn = tile_idx % n_tiles_n
|
||||
|
||||
k_start = tk * self.TILE_K
|
||||
n_start = tn * self.TILE_N
|
||||
tile_k = min(self.TILE_K, K_a - k_start)
|
||||
tile_n = min(self.TILE_N, N - n_start)
|
||||
tile_nbytes = tile_k * tile_n * dtype_bytes
|
||||
|
||||
# --- Stage 1: DMA_READ b_tile from HBM ---
|
||||
read_done = env.event()
|
||||
b_tile_pa = b.pa + (k_start * N + n_start) * dtype_bytes
|
||||
b_tile_handle = TensorHandle(
|
||||
id=f"b_tile_{tile_idx}", pa=b_tile_pa,
|
||||
shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes,
|
||||
)
|
||||
read_cmd = DmaReadCmd(handle=b_tile_handle, src_pa=b_tile_pa, nbytes=tile_nbytes)
|
||||
read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp)
|
||||
t0 = env.now
|
||||
yield self.out_ports[f"{pp}.pe_dma"].put(read_txn)
|
||||
|
||||
# Wait for previous compute before starting this tile's compute
|
||||
if prev_compute_done is not None:
|
||||
yield prev_compute_done
|
||||
|
||||
# Wait for this tile's DMA_READ
|
||||
yield read_done
|
||||
total_dma_ns += env.now - t0
|
||||
|
||||
# --- Stage 2: COMPUTE (GEMM) ---
|
||||
compute_done = env.event()
|
||||
out_handle = TensorHandle(
|
||||
id=f"out_tile_{tile_idx}", pa=0,
|
||||
shape=(M, tile_n), dtype=dtype,
|
||||
nbytes=M * tile_n * dtype_bytes,
|
||||
)
|
||||
compute_cmd = GemmCmd(a=a, b=b_tile_handle, out=out_handle,
|
||||
m=M, k=tile_k, n=tile_n)
|
||||
compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp)
|
||||
t0 = env.now
|
||||
yield self.out_ports[f"{pp}.pe_gemm"].put(compute_txn)
|
||||
|
||||
# Wait for previous write (DMA_WRITE serialization)
|
||||
if prev_write_done is not None:
|
||||
yield prev_write_done
|
||||
|
||||
# Wait for compute of THIS tile
|
||||
yield compute_done
|
||||
total_compute_ns += env.now - t0
|
||||
prev_compute_done = compute_done
|
||||
|
||||
# --- Stage 3: DMA_WRITE out_tile to HBM ---
|
||||
write_done = env.event()
|
||||
out_tile_pa = cmd.out_pa + n_start * dtype_bytes
|
||||
write_nbytes = M * tile_n * dtype_bytes
|
||||
write_cmd = DmaWriteCmd(handle=out_handle, dst_pa=out_tile_pa, nbytes=write_nbytes)
|
||||
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
|
||||
t0 = env.now
|
||||
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
|
||||
prev_write_done = write_done
|
||||
|
||||
# Wait for final write
|
||||
if prev_write_done is not None:
|
||||
t0 = env.now
|
||||
yield prev_write_done
|
||||
total_dma_ns += env.now - t0
|
||||
|
||||
pe_txn.result_data["dma_ns"] = total_dma_ns
|
||||
pe_txn.result_data["compute_ns"] = total_compute_ns
|
||||
pe_txn.done.succeed()
|
||||
|
||||
def _pipeline_math(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
    """Run a non-GEMM composite command as two strictly sequential steps.

    A single MATH sub-command is dispatched to the ``pe_math`` port and
    awaited; only then is a single DMA_WRITE sub-command dispatched to the
    ``pe_dma`` port and awaited.  There is no tiling and no overlap between
    the two steps.

    Args:
        env: SimPy environment used to create the completion events.
        pe_txn: Parent transaction; its ``done`` event is succeeded once
            the write has completed.
        cmd: Composite command.  This method reads ``math_op``, ``a``,
            ``out_pa`` and ``out_nbytes`` from it.

    Yields:
        SimPy events: the port ``put`` and the completion event of each
        sub-transaction, in order.
    """
    from kernbench.common.pe_commands import DmaWriteCmd, MathCmd
    from kernbench.common.pe_commands import PeInternalTxn as PeTxn

    prefix = self._pe_prefix

    def issue_and_wait(port_key: str, command: Any) -> Generator:
        # Wrap the command in a sub-transaction, hand it to the named
        # output port, then block until the receiver signals completion.
        finished = env.event()
        sub_txn = PeTxn(command=command, done=finished, pe_prefix=prefix)
        yield self.out_ports[port_key].put(sub_txn)
        yield finished

    # Step 1: MATH.  The operand handle doubles as the output handle, and
    # a missing/falsy math_op falls back to the "identity" op.
    yield from issue_and_wait(
        f"{prefix}.pe_math",
        MathCmd(op=cmd.math_op or "identity", inputs=(cmd.a,), out=cmd.a),
    )

    # Step 2: DMA_WRITE of the result to the destination address carried
    # by the composite command.
    yield from issue_and_wait(
        f"{prefix}.pe_dma",
        DmaWriteCmd(handle=cmd.a, dst_pa=cmd.out_pa, nbytes=cmd.out_nbytes),
    )

    # Both steps are finished: release whoever waits on the parent txn.
    pe_txn.done.succeed()
|
||||
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeTcmComponent(ComponentBase):
    """PE_TCM: tightly-coupled memory / local SRAM staging buffer.

    Terminal storage component for PE-internal dataflow (ADR-0014 D5).
    Phase 0: ``run()`` models only the node's ``overhead_ns`` attribute.
    NOTE(review): no ``drain_ns`` handling is implemented in this class;
    presumably it is applied by the inherited worker -- confirm.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        """Store the topology node and optional context via the base class."""
        super().__init__(node=node, ctx=ctx)

    def run(self, env, nbytes: int) -> Generator:
        """Hold the caller for this node's fixed processing overhead.

        Args:
            env: SimPy environment that supplies the timeout event.
            nbytes: Transfer size; not used by this latency model.

        Yields:
            One timeout event lasting ``overhead_ns`` (0.0 when the node
            attribute is absent).
        """
        node_attrs = self.node.attrs
        delay_ns = float(node_attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)
|
||||
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class SramComponent(ComponentBase):
    """Cube SRAM: terminal component modelling SRAM access latency.

    Every incoming transaction is charged the node's ``overhead_ns`` plus
    the transaction's own ``drain_ns`` (when positive).  Afterwards a
    ResponseMsg is sent back along the reversed request path, or -- when no
    reverse hop or no context is available -- the transaction's ``done``
    event is succeeded directly.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        """Store the topology node and optional context via the base class."""
        super().__init__(node=node, ctx=ctx)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Hold the caller for this node's fixed processing overhead.

        Args:
            env: SimPy environment that supplies the timeout event.
            nbytes: Transfer size; not used by this latency model.

        Yields:
            One timeout event lasting ``overhead_ns`` (0.0 when the node
            attribute is absent).
        """
        node_attrs = self.node.attrs
        delay_ns = float(node_attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Serve the inbox forever: process, apply drain, send response.

        ``self._inbox`` is not defined in this class; presumably it is the
        fan-in store set up by the base class -- confirm.
        """
        while True:
            incoming: Any = yield self._inbox.get()

            # Fixed access overhead for this SRAM node.
            yield from self.run(env, incoming.nbytes)

            # Optional per-transaction drain time; skipped when the
            # attribute is missing or not positive.
            drain_ns = getattr(incoming, "drain_ns", 0.0)
            if drain_ns > 0:
                yield env.timeout(drain_ns)

            yield from self._send_response(env, incoming)

    def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Answer *txn*: send a ResponseMsg on the reverse path, or finish it.

        Args:
            env: SimPy environment used to create the response's done event.
            txn: Completed request transaction; ``path``, ``request`` and
                ``done`` are read from it.

        Yields:
            The ``put`` event on the output port towards the first hop of
            the reversed path (only when a response is actually sent).
        """
        return_route = list(txn.path)
        return_route.reverse()

        # Without a hop to return to, or without a context, nothing can be
        # routed back: complete the request transaction in place.
        if len(return_route) < 2 or not self.ctx:
            txn.done.succeed()
            return

        from kernbench.runtime_api.kernel import ResponseMsg

        # The second dot-separated segment of the node id carries the cube
        # index, e.g. "<x>.cube3.<y>" -> 3.
        cube_token = self.node.id.split(".")[1]
        cube_id = int(cube_token.replace("cube", ""))

        origin = txn.request
        # src_pe=-1: presumably "no originating PE" for an SRAM reply -- confirm.
        reply = ResponseMsg(
            correlation_id=origin.correlation_id,
            request_id=origin.request_id,
            src_cube=cube_id,
            src_pe=-1,
            success=True,
        )
        reply_txn = Transaction(
            request=reply,
            path=return_route,
            step=0,
            nbytes=0,
            done=env.event(),
            is_response=True,
        )

        # return_route[0] is this node; hand the advanced transaction to
        # the port leading to the next hop.
        next_hop_port = self.out_ports[return_route[1]]
        yield next_hop_port.put(reply_txn.advance())
|
||||
Reference in New Issue
Block a user