Add SIP-level tensor parallelism, component registry YAML, VA offset verification
- DPPolicy: 3-level (sip/cube/pe), unified naming (column_wise/row_wise)
- PE_CPU: auto num_programs from cube shard count
- context.launch(): per-SIP KernelLaunchMsg with local va_base + auto local shape
- deploy_tensor: removed mmus param, MMU mapping is context-only responsibility
- ComponentRegistry: YAML-based lazy loading (components.yaml), impls→builtin rename
- VA offset bench + tests: 2D/1D, standard Triton kernel pattern

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,34 @@
|
||||
"""Concrete component implementations.
|
||||
|
||||
Loaded from components.yaml via ComponentRegistry.load_components_yaml().
|
||||
Manual imports are no longer needed — add new impls to components.yaml.
|
||||
|
||||
Classes are still importable from this package via lazy __getattr__.
|
||||
"""
|
||||
|
||||
from kernbench.components.base import ComponentRegistry
|
||||
|
||||
ComponentRegistry.load_components_yaml()
|
||||
|
||||
# Lazy re-export: allow `from kernbench.components.builtin import FooComponent`
|
||||
# without eagerly importing every module.
|
||||
_CLASS_MAP: dict[str, str] = {} # ClassName → "module.path:ClassName"
|
||||
|
||||
|
||||
def _build_class_map() -> None:
    """Index lazily registered component classes by their bare class name.

    Reads the ``"module.path:ClassName"`` strings stored in
    ``ComponentRegistry._lazy`` and records ``ClassName -> class path`` in
    ``_CLASS_MAP``.

    The registry is re-scanned on every call instead of only while the map
    is empty, so classes registered after the first lookup are still found.
    This only runs on attribute misses of the package, so the extra scan is
    not on a hot path.

    Raises:
        ValueError: If a registered class path contains no ``":"``.
    """
    for class_path in ComponentRegistry._lazy.values():
        # Only the class name is needed as the key; the module part is
        # resolved later by __getattr__ from the stored full path.
        _, class_name = class_path.rsplit(":", 1)
        _CLASS_MAP[class_name] = class_path
|
||||
|
||||
|
||||
def __getattr__(name: str):
    """Resolve ``name`` lazily to a registered component class (PEP 562).

    The class map is refreshed from the registry, and if ``name`` is a known
    class its defining module is imported on demand and the class returned.

    Args:
        name: Attribute being looked up on this package.

    Returns:
        The component class named ``name``.

    Raises:
        AttributeError: If ``name`` is not a registered component class.
            A module-level ``__getattr__`` must signal a miss with
            ``AttributeError``: ``hasattr()``, ``getattr(..., default)``,
            star-imports and ``from package import submodule`` all rely on
            it. ``from package import Missing`` still surfaces as
            ``ImportError`` because the import machinery converts it.
    """
    _build_class_map()
    class_path = _CLASS_MAP.get(name)
    if class_path is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Imported here so that merely importing the package stays cheap.
    import importlib

    module_path, class_name = class_path.rsplit(":", 1)
    return getattr(importlib.import_module(module_path), class_name)
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class TransitComponent(ComponentBase):
    """Pass-through component used for NOC, UCIe and XBAR nodes.

    The only behaviour defined here is the processing delay in ``run``.
    Moving the Transaction on to its next hop is left to the base class
    (the inherited ``_forward_txn()``).
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        """Initialise the base component with ``node`` and optional ``ctx``."""
        super().__init__(node, ctx)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` attribute (0.0 when absent).

        ``nbytes`` is accepted for interface compatibility but does not
        influence the delay.
        """
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)
|
||||
@@ -0,0 +1,129 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class HbmCtrlComponent(ComponentBase):
    """HBM controller: terminal component that models HBM access latency.

    Dual-channel model: reads and writes each get their own
    ``simpy.Resource`` so a read and a write can be serviced concurrently,
    while requests of the same kind queue on their channel. The capacity of
    each channel comes from the node's ``capacity`` attribute (default 1).

    On completion the controller either succeeds the transaction's ``done``
    event or sends a Transaction back along the reversed path, so that the
    return trip is modelled through the fabric (see ``_send_response``).
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Channel resources are created in start(), once an env exists.
        self._read: simpy.Resource | None = None
        self._write: simpy.Resource | None = None

    def start(self, env: simpy.Environment) -> None:
        """Create the read/write channel resources, then start the base."""
        capacity = int(self.node.attrs.get("capacity", 1))
        self._read = simpy.Resource(env, capacity=capacity)
        self._write = simpy.Resource(env, capacity=capacity)
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` attribute (0.0 when absent)."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _select_channel(self, txn: Any) -> simpy.Resource:
        """Pick the channel for ``txn``.

        ``MemoryWriteMsg`` and ``PeDmaMsg`` with a truthy ``is_write`` use
        the write channel; every other request uses the read channel.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg, PeDmaMsg

        assert self._read is not None and self._write is not None
        req = txn.request
        is_write = isinstance(req, MemoryWriteMsg) or (
            isinstance(req, PeDmaMsg) and req.is_write
        )
        return self._write if is_write else self._read

    def _worker(self, env: simpy.Environment) -> Generator:
        """Spawn one process per incoming txn so channels work in parallel."""
        while True:
            txn: Any = yield self._inbox.get()
            env.process(self._handle_txn(env, txn))

    def _handle_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Acquire the channel, run, apply drain, then send the response."""
        channel = self._select_channel(txn)
        with channel.request() as req:
            yield req
            yield from self.run(env, txn.nbytes)
            drain = getattr(txn, "drain_ns", 0.0)
            if drain > 0:
                yield env.timeout(drain)
        # NOTE(review): the response is emitted after the channel has been
        # released; confirm this matches the intended occupancy model.
        yield from self._send_response(env, txn)

    def _emit_reverse_txn(self, reverse_path: list[str], **txn_fields: Any) -> Generator:
        """Build a Transaction on ``reverse_path`` and put it on the first hop.

        ``txn_fields`` are passed through to ``Transaction`` unchanged, so a
        caller that omits ``is_response`` gets the Transaction default.
        """
        resp_txn = Transaction(path=reverse_path, step=0, **txn_fields)
        yield self.out_ports[reverse_path[1]].put(resp_txn.advance())

    def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Route completion based on the request and path type.

        - ``PeDmaMsg``: a zero-byte response carrying the original ``done``
          event is sent on the reversed path.
        - Bypass path (no hop name contains ``"m_cpu"``): ``MemoryReadMsg``
          sends ``request.nbytes`` back on the reversed path with the
          original ``done`` event; anything else succeeds ``done``.
        - M_CPU DMA path: a ``ResponseMsg`` is sent on the reversed path so
          the upstream component can aggregate it.

        In every case, if the reversed path has fewer than two hops there is
        nowhere to send a response and ``txn.done`` is succeeded instead.
        """
        from kernbench.runtime_api.kernel import MemoryReadMsg, PeDmaMsg, ResponseMsg

        request = txn.request
        reverse_path = list(reversed(txn.path))
        can_route_back = len(reverse_path) >= 2

        if isinstance(request, PeDmaMsg):
            if can_route_back:
                yield from self._emit_reverse_txn(
                    reverse_path,
                    request=request, nbytes=0, done=txn.done, is_response=True,
                )
            else:
                txn.done.succeed()
            return

        # Bypass path: no m_cpu in the transaction path
        is_bypass = not any("m_cpu" in n for n in txn.path)
        if is_bypass:
            if isinstance(request, MemoryReadMsg) and can_route_back:
                # D2H: the data travels back, so the response keeps the
                # request's byte count and is not flagged is_response.
                yield from self._emit_reverse_txn(
                    reverse_path,
                    request=request, nbytes=request.nbytes, done=txn.done,
                )
            else:
                # MemoryWrite bypass or short path: done
                txn.done.succeed()
            return

        # M_CPU DMA path: send ResponseMsg for aggregation
        if can_route_back and self.ctx:
            # Node ids are split on "."; index 1 is read as "cube<N>" and
            # index 3 as "slice<N>".
            parts = self.node.id.split(".")
            cube_id = int(parts[1].replace("cube", ""))
            pe_id = int(parts[3].replace("slice", ""))
            resp_msg = ResponseMsg(
                correlation_id=request.correlation_id,
                request_id=request.request_id,
                src_cube=cube_id, src_pe=pe_id, success=True,
            )
            yield from self._emit_reverse_txn(
                reverse_path,
                request=resp_msg, nbytes=0, done=env.event(), is_response=True,
            )
        else:
            txn.done.succeed()
|
||||
@@ -0,0 +1,157 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class IoCpuComponent(ComponentBase):
    """IO_CPU component: multi-cube fan-out with response aggregation.

    Forward path:
      1. Applies overhead_ns processing overhead.
      2. Resolves target cube(s) from request.target_cubes.
      3. Fans out sub-Transactions to each target cube's M_CPU.

    Response path:
      Collects one response per dispatched sub-Transaction. When all of
      them have been received, succeeds the parent txn.done.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Pending fan-out tracking: request_id → (expected, received, parent_txn_done)
        self._pending: dict[str, tuple[int, int, simpy.Event]] = {}

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` attribute (0.0 when absent)."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Count responses inline; fan forward txns out in their own process."""
        while True:
            txn: Any = yield self._inbox.get()
            if getattr(txn, "is_response", False):
                self._collect_response(txn)
            else:
                yield from self.run(env, txn.nbytes)
                env.process(self._dispatch_to_m_cpus(env, txn))

    def _collect_response(self, resp_txn: Any) -> None:
        """Receive a cube response and increment the aggregation counter.

        Responses whose ``request_id`` is not being tracked are ignored.
        When the last expected response arrives, the parent ``done`` event
        is succeeded and the tracking entry is removed.
        """
        key = resp_txn.request.request_id
        if key not in self._pending:
            return
        expected, received, parent_done = self._pending[key]
        received += 1
        if received >= expected:
            parent_done.succeed()
            del self._pending[key]
        else:
            self._pending[key] = (expected, received, parent_done)

    def _dispatch_to_m_cpus(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out sub-Transactions to the target cubes' M_CPUs.

        Completion is signalled later by ``_collect_response`` once every
        dispatched sub-Transaction has been answered. If target resolution
        fails, or no target is routable, ``txn.done`` is succeeded
        immediately.
        """
        request = txn.request
        try:
            cube_targets = self._resolve_cube_targets(request)
        except Exception:
            # Best effort: an unresolvable request completes immediately.
            txn.done.succeed()
            return

        # Resolve every route BEFORE registering the aggregation. The
        # expected-response count must equal the number of sub-Transactions
        # actually sent; counting unroutable targets would leave the parent
        # `done` event waiting forever.
        routes: list[list[str]] = []
        for sip, cube in cube_targets:
            try:
                m_cpu_id = self.ctx.resolver.find_m_cpu(sip, cube)
                path = self.ctx.router.find_node_path(self.node.id, m_cpu_id)
            except Exception:
                continue  # unroutable target: skipped, not counted
            if len(path) >= 2:
                routes.append(path)

        if not routes:
            txn.done.succeed()
            return

        # Setup aggregation (before the first put, so no response can
        # arrive while the entry is still missing).
        self._pending[request.request_id] = (len(routes), 0, txn.done)

        # Fan out to each target cube's M_CPU
        for path in routes:
            sub_txn = Transaction(
                request=request, path=path, step=0,
                nbytes=txn.nbytes, done=env.event(),
                result_data=txn.result_data,
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())

    def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]:
        """Return the list of (sip, cube) pairs to fan out to.

        ``request.target_cubes`` (default ``"all"``) is either ``"all"`` or
        an iterable of cube ids. What ``"all"`` means depends on the request:

        - ``MemoryWriteMsg`` / ``MemoryReadMsg``: the single cube decoded
          from the physical address, falling back to the request's cube id.
        - ``KernelLaunchMsg``: every distinct cube holding a tensor shard on
          this IO_CPU's SIP, in first-seen order.
        - ``MmuMapMsg`` / ``MmuUnmapMsg``: every cube of this SIP, using
          ``system.sips.cubes_per_sip`` from the spec (default 16).

        Any other request type yields an empty list.
        """
        from kernbench.runtime_api.kernel import (
            KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, MmuMapMsg, MmuUnmapMsg,
        )

        target_cubes = getattr(request, "target_cubes", "all")
        fan_out_all = target_cubes == "all"

        if isinstance(request, MemoryWriteMsg):
            sip = request.dst_sip
            if fan_out_all:
                cube = self._cube_from_pa(request.dst_pa, fallback=request.dst_cube)
                return [(sip, cube)]
            return [(sip, c) for c in target_cubes]

        if isinstance(request, MemoryReadMsg):
            sip = request.src_sip
            if fan_out_all:
                cube = self._cube_from_pa(request.src_pa, fallback=request.src_cube)
                return [(sip, cube)]
            return [(sip, c) for c in target_cubes]

        if isinstance(request, KernelLaunchMsg):
            my_sip = self._my_sip()
            if not fan_out_all:
                return [(my_sip, c) for c in target_cubes]
            # "all": derive from tensor shards, filtered to this SIP
            seen: set[tuple[int, int]] = set()
            targets: list[tuple[int, int]] = []
            for arg in request.args:
                if arg.arg_kind != "tensor":
                    continue
                for shard in arg.shards:
                    if shard.sip != my_sip:
                        continue
                    key = (shard.sip, shard.cube)
                    if key not in seen:
                        seen.add(key)
                        targets.append(key)
            return targets

        if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
            my_sip = self._my_sip()
            if not fan_out_all:
                return [(my_sip, c) for c in target_cubes]
            n_cubes = 16
            if self.ctx and self.ctx.spec:
                sips = self.ctx.spec.get("system", {}).get("sips", {})
                n_cubes = sips.get("cubes_per_sip", 16)
            return [(my_sip, c) for c in range(n_cubes)]

        return []

    def _cube_from_pa(self, pa_val: int, fallback: int) -> int:
        """Extract cube_id from a physical address, with fallback.

        ``fallback`` is returned when ``PhysAddr.decode`` raises.
        """
        from kernbench.policy.address.phyaddr import PhysAddr

        try:
            return PhysAddr.decode(pa_val).cube_id
        except Exception:
            return fallback

    def _my_sip(self) -> int:
        """Extract this IO_CPU's SIP ID from its node ID (e.g. 'sip0.io0.io_cpu' → 0)."""
        return int(self.node.id.split(".")[0].replace("sip", ""))
|
||||
@@ -0,0 +1,332 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class MCpuComponent(ComponentBase):
    """M_CPU component: multi-PE fan-out with response aggregation.

    Forward path (ADR-0015 D5):
      When a forward Transaction arrives at m_cpu as its terminal hop, M_CPU
      fans out sub-Transactions: DMA transfers to HBM slices, kernel
      launches to PE_CPUs, or MMU map/unmap messages to PE_MMUs. target_pe
      on the request controls fan-out: int → single PE, otherwise "all".

    Response path:
      Responses for DMA and kernel-launch fan-outs arrive back at m_cpu and
      are counted by ``_collect_response``. Once all are collected, m_cpu
      sends an aggregate ResponseMsg on the reverse command path.

    Transit:
      A forward Transaction that still has a next hop is forwarded to it.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        # Pending fan-out tracking: request_id → (expected, received, all_done_event)
        self._pending: dict[str, tuple[int, int, simpy.Event]] = {}
        # Parent txn of each in-flight fan-out: request_id → parent_txn
        self._parent_txns: dict[str, Any] = {}
        # DMA engine resources (ADR-0015 D5, ADR-0014 D4): capacity=1 each
        self._dma_write: simpy.Resource | None = None
        self._dma_read: simpy.Resource | None = None

    def start(self, env: simpy.Environment) -> None:
        """Create the DMA engine resources, then start the base component."""
        self._dma_write = simpy.Resource(env, capacity=1)
        self._dma_read = simpy.Resource(env, capacity=1)
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the node's ``overhead_ns`` attribute (0.0 when absent)."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch forward txns, collect response txns."""
        from kernbench.runtime_api.kernel import KernelLaunchMsg, MmuMapMsg, MmuUnmapMsg

        while True:
            txn: Any = yield self._inbox.get()
            if getattr(txn, "is_response", False):
                self._collect_response(txn)
                continue

            yield from self.run(env, txn.nbytes)
            next_hop = txn.next_hop
            if next_hop:
                # Transit: m_cpu is not the terminal hop.
                yield self.out_ports[next_hop].put(txn.advance())
            elif self.ctx is not None and txn.request is not None:
                # Terminal hop: each fan-out runs in its own process so the
                # worker can keep draining the inbox (including responses).
                if isinstance(txn.request, KernelLaunchMsg):
                    env.process(self._kernel_launch_fanout(env, txn))
                elif isinstance(txn.request, (MmuMapMsg, MmuUnmapMsg)):
                    env.process(self._mmu_msg_fanout(env, txn))
                else:
                    env.process(self._dma_fanout(env, txn))
            else:
                txn.done.succeed()

    def _collect_response(self, resp_txn: Any) -> None:
        """Receive a PE response and increment the aggregation counter.

        Responses whose ``request_id`` is not being tracked are ignored.
        When the last expected response arrives, the fan-out's ``all_done``
        event is succeeded and the tracking entry is removed.
        """
        key = resp_txn.request.request_id
        if key not in self._pending:
            return
        expected, received, all_done = self._pending[key]
        received += 1
        if received >= expected:
            all_done.succeed()
            del self._pending[key]
        else:
            self._pending[key] = (expected, received, all_done)

    # ── Shared helpers ───────────────────────────────────────────────

    def _slices_per_cube(self) -> int:
        """Return ``cube.memory_map.hbm_slices_per_cube`` from the spec (default 8)."""
        n_slices = 8
        if self.ctx and self.ctx.spec:
            mm = self.ctx.spec.get("cube", {}).get("memory_map", {})
            n_slices = mm.get("hbm_slices_per_cube", 8)
        return n_slices

    def _paths_to_pe_nodes(self, pe_ids: list[int], leaf: str) -> list[list[str]]:
        """Return routable paths from this node to ``<cube>.pe<id>.<leaf>``.

        PEs whose path lookup raises, or whose path has fewer than two hops,
        are skipped.
        """
        cube_prefix = self.node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0"
        paths: list[list[str]] = []
        for pe_id in pe_ids:
            try:
                path = self.ctx.router.find_node_path(
                    self.node.id, f"{cube_prefix}.pe{pe_id}.{leaf}"
                )
            except Exception:
                continue
            if len(path) >= 2:
                paths.append(path)
        return paths

    def _send_aggregate_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Send one ResponseMsg for ``txn`` on its reversed path.

        If the reversed path has fewer than two hops there is nowhere to
        send a response, so ``txn.done`` is succeeded instead.
        """
        reverse_path = list(reversed(txn.path))
        if len(reverse_path) < 2:
            txn.done.succeed()
            return

        from kernbench.runtime_api.kernel import ResponseMsg

        request = txn.request
        # Node ids are split on "."; index 1 is read as "cube<N>".
        cube_id = int(self.node.id.split(".")[1].replace("cube", ""))
        resp_msg = ResponseMsg(
            correlation_id=request.correlation_id,
            request_id=request.request_id,
            src_cube=cube_id, src_pe=-1, success=True,
        )
        resp_txn = Transaction(
            request=resp_msg, path=reverse_path, step=0,
            nbytes=0, done=env.event(), is_response=True,
        )
        yield self.out_ports[reverse_path[1]].put(resp_txn.advance())

    # ── Fan-out paths ────────────────────────────────────────────────

    def _dma_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out DMA sub-Transactions to target PE(s), wait for responses,
        then send aggregate response on reverse command path.

        Each put acquires the DMA resource (capacity=1 per ADR-0014 D4), so
        issuing the sub-Transactions is serialized through the DMA engine.
        ``txn.result_data["xfer_ns"]`` is set to the largest drain time of
        the dispatched transfers.
        """
        from kernbench.runtime_api.kernel import MemoryWriteMsg

        request = txn.request
        target_pe = getattr(request, "target_pe", "all")
        dst_nodes = self._resolve_dma_destinations(request, target_pe)

        # Resolve every route BEFORE registering the aggregation. The
        # expected-response count must equal the number of sub-Transactions
        # actually sent; counting unroutable destinations would leave
        # `all_done` waiting forever and leak the tracking entries.
        plan: list[tuple[list[str], float]] = []
        max_drain_ns = 0.0
        for dst_node in dst_nodes:
            try:
                dma_path = self.ctx.router.find_mcpu_dma_path(self.node.id, dst_node)
            except Exception:
                continue
            if len(dma_path) < 2:
                continue
            drain_ns = self.ctx.compute_drain_ns(dma_path, txn.nbytes)
            max_drain_ns = max(max_drain_ns, drain_ns)
            plan.append((dma_path, drain_ns))

        if not plan:
            # Nothing dispatched: same convention as the kernel-launch path.
            txn.done.succeed()
            return

        # Setup aggregation
        all_done = env.event()
        self._pending[request.request_id] = (len(plan), 0, all_done)
        self._parent_txns[request.request_id] = txn

        # Select DMA resource based on operation type
        dma_res = self._dma_write if isinstance(request, MemoryWriteMsg) else self._dma_read

        # Fan out DMA sub-txns (serialized through DMA resource)
        for dma_path, drain_ns in plan:
            sub_txn = Transaction(
                request=request, path=dma_path, step=0,
                nbytes=txn.nbytes, done=env.event(),
                drain_ns=drain_ns,
            )
            with dma_res.request() as req:
                yield req
                yield self.out_ports[dma_path[1]].put(sub_txn.advance())

        # Wait for all PE responses
        yield all_done
        txn.result_data["xfer_ns"] = max_drain_ns
        del self._parent_txns[request.request_id]

        # Send aggregate response on reverse command path
        yield from self._send_aggregate_response(env, txn)

    def _kernel_launch_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out KernelLaunchMsg to target PE_CPU(s) via NOC (ADR-0009 D3).

        Routes through find_node_path (M_CPU → NOC → PE_CPU command edges),
        waits until one response per dispatched PE_CPU has been collected,
        copies the per-PE maxima of ``pe_exec_ns``, ``dma_ns`` and
        ``compute_ns`` into ``txn.result_data`` and finally sends the
        aggregate ResponseMsg on the reverse path.
        """
        request = txn.request
        target_pe = getattr(request, "target_pe", "all")
        paths = self._paths_to_pe_nodes(self._resolve_pe_ids(target_pe), "pe_cpu")
        if not paths:
            txn.done.succeed()
            return

        # Register the aggregation before the first put so that a response
        # arriving while later sub-Transactions are still being issued is
        # counted instead of being dropped as "unknown request".
        all_done = env.event()
        self._pending[request.request_id] = (len(paths), 0, all_done)
        self._parent_txns[request.request_id] = txn

        sub_txns: list[Transaction] = []
        for path in paths:
            sub_txn = Transaction(
                request=request, path=path, step=0,
                nbytes=0, done=env.event(),
            )
            sub_txns.append(sub_txn)
            yield self.out_ports[path[1]].put(sub_txn.advance())

        # Wait for all PE_CPU responses via NOC
        yield all_done
        del self._parent_txns[request.request_id]

        # Aggregate PE-internal metrics (max across PEs)
        for metric in ("pe_exec_ns", "dma_ns", "compute_ns"):
            txn.result_data[metric] = max(
                st.result_data.get(metric, 0.0) for st in sub_txns
            )

        # Send aggregate response on reverse command path back to IO_CPU
        yield from self._send_aggregate_response(env, txn)

    def _resolve_dma_destinations(self, request: Any, target_pe: int | str) -> list[str]:
        """Return list of HBM destination node_ids for DMA fan-out.

        - ``target_pe`` is an int: that slice of the local cube.
        - Otherwise, if the request carries ``dst_pa`` (or else ``src_pa``)
          and it decodes, the single node the resolver maps it to. This is
          what allows the target to lie in a remote cube.
        - Otherwise: every HBM slice of the local cube.
        """
        cube_prefix = self.node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0"

        if isinstance(target_pe, int):
            return [f"{cube_prefix}.hbm_ctrl.slice{target_pe}"]

        # PA-based resolution. The check is `is None`, not truthiness:
        # address 0 must not be mistaken for "no address given".
        pa_val = getattr(request, "dst_pa", None)
        if pa_val is None:
            pa_val = getattr(request, "src_pa", None)
        if pa_val is not None:
            from kernbench.policy.address.phyaddr import PhysAddr

            try:
                return [self.ctx.resolver.resolve(PhysAddr.decode(pa_val))]
            except Exception:
                pass  # undecodable/unresolvable PA: fall back to all slices

        # "all" without a usable PA: all slices in local cube
        return [
            f"{cube_prefix}.hbm_ctrl.slice{i}" for i in range(self._slices_per_cube())
        ]

    def _mmu_msg_fanout(self, env: simpy.Environment, txn: Any) -> Generator:
        """Fan out MmuMapMsg/MmuUnmapMsg to target PE_MMU(s) via NOC.

        Routes through find_node_path (M_CPU → NOC → PE_MMU command edges).
        Completion is tracked through each sub-Transaction's own ``done``
        event rather than through ``_collect_response``; afterwards the
        aggregate ResponseMsg is sent on the reverse path.
        """
        request = txn.request
        target_pe = getattr(request, "target_pe", "all")
        pe_ids = self._resolve_pe_ids(target_pe)
        if not pe_ids:
            txn.done.succeed()
            return

        # Fan out to each PE_MMU
        sub_dones: list[simpy.Event] = []
        for path in self._paths_to_pe_nodes(pe_ids, "pe_mmu"):
            sub_done = env.event()
            sub_txn = Transaction(
                request=request, path=path, step=0,
                nbytes=0, done=sub_done,
            )
            yield self.out_ports[path[1]].put(sub_txn.advance())
            sub_dones.append(sub_done)

        # Wait for all PE_MMUs to complete
        for sd in sub_dones:
            yield sd

        # Send aggregate response on reverse path
        yield from self._send_aggregate_response(env, txn)

    def _resolve_pe_ids(self, target_pe: int | str) -> list[int]:
        """Return list of PE IDs to fan out to.

        An int selects that single PE; anything else selects one PE per HBM
        slice of the local cube.
        """
        if isinstance(target_pe, int):
            return [target_pe]
        # "all": all PEs in local cube
        return list(range(self._slices_per_cube()))
|
||||
@@ -0,0 +1,224 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class TwoDMeshNocComponent(ComponentBase):
|
||||
"""2D mesh NOC modeled as a single smart node.
|
||||
|
||||
Latency model:
|
||||
- Traversal latency = Manhattan distance between prev_hop and next_hop
|
||||
node positions, split into XY segments, traversed with pipeline.
|
||||
- overhead_ns (from node.attrs) is added once per traversal.
|
||||
|
||||
Contention model:
|
||||
- Each directed XY segment is a simpy.Resource(capacity=1).
|
||||
- Pipeline: next segment's resource is requested before the current
|
||||
segment's timeout completes, so a free downstream segment is acquired
|
||||
immediately (wormhole-style cut-through).
|
||||
- Two transactions sharing a segment (same row or column band) contend.
|
||||
|
||||
Concurrency:
|
||||
- _worker spawns an independent SimPy process per transaction, so the
|
||||
NOC is never serialized at the node level — only at segment resources.
|
||||
"""
|
||||
|
||||
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
    """Initialise the base component and the (still empty) mesh state."""
    super().__init__(node, ctx)
    # Sorted grid coordinates; filled by the _build_grid* methods.
    self._x_grid: list[float] = []
    self._y_grid: list[float] = []
    # Per-segment resources, created on first use by _get_link().
    self._links: dict[tuple, simpy.Resource] = {}
    # Set in start(); _get_link() needs it to create resources.
    self._env: simpy.Environment | None = None
|
||||
|
||||
def start(self, env: simpy.Environment) -> None:
    """Remember ``env``, build the grid, then start the base component."""
    # Kept first: _get_link() creates its resources against this env.
    self._env = env
    self._build_grid()
    super().start(env)
|
||||
|
||||
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
    """Yield a single zero-length timeout.

    This component's ``_worker`` hands transactions to ``_route`` and does
    not call ``run``; ``nbytes`` is unused.
    """
    no_delay = env.timeout(0)
    yield no_delay
|
||||
|
||||
# ── Grid construction ────────────────────────────────────────────
|
||||
|
||||
def _build_grid(self) -> None:
|
||||
if not self.ctx:
|
||||
return
|
||||
mesh = self.ctx.spec.get("_mesh") if self.ctx.spec else None
|
||||
if mesh:
|
||||
self._build_grid_from_mesh(mesh)
|
||||
else:
|
||||
self._build_grid_from_positions()
|
||||
|
||||
def _build_grid_from_mesh(self, mesh: dict) -> None:
|
||||
"""Build XY grid from cube_mesh.yaml router positions (authoritative)."""
|
||||
origin_x, origin_y = self._cube_origin()
|
||||
xs: set[float] = set()
|
||||
ys: set[float] = set()
|
||||
for key, router in mesh.get("routers", {}).items():
|
||||
if router is not None:
|
||||
xs.add(round(origin_x + router["pos_mm"][0], 2))
|
||||
ys.add(round(origin_y + router["pos_mm"][1], 2))
|
||||
self._x_grid = sorted(xs)
|
||||
self._y_grid = sorted(ys)
|
||||
|
||||
def _build_grid_from_positions(self) -> None:
|
||||
"""Fallback: infer grid from all node positions in the cube."""
|
||||
cube_prefix = self.node.id.rsplit(".", 1)[0]
|
||||
xs: set[float] = set()
|
||||
ys: set[float] = set()
|
||||
for node_id, pos in self.ctx.positions.items():
|
||||
if node_id.startswith(cube_prefix + ".") and pos is not None:
|
||||
xs.add(round(pos[0], 2))
|
||||
ys.add(round(pos[1], 2))
|
||||
self._x_grid = sorted(xs)
|
||||
self._y_grid = sorted(ys)
|
||||
|
||||
def _cube_origin(self) -> tuple[float, float]:
|
||||
"""Compute absolute origin (top-left) of this cube from cube_id."""
|
||||
parts = self.node.id.split(".")
|
||||
cube_str = [p for p in parts if p.startswith("cube")][0]
|
||||
cube_id = int(cube_str[4:])
|
||||
spec = self.ctx.spec
|
||||
sip_spec = spec.get("sip", {})
|
||||
cube_spec = spec.get("cube", {})
|
||||
mesh_w = sip_spec.get("cube_mesh", {}).get("w", 4)
|
||||
cube_w = cube_spec.get("geometry", {}).get("cube_mm", {}).get("w", 17.0)
|
||||
cube_h = cube_spec.get("geometry", {}).get("cube_mm", {}).get("h", 14.0)
|
||||
seam = sip_spec.get("links", {}).get("inter_cube_mesh", {}).get(
|
||||
"distance_mm_across_seam", 1.0)
|
||||
col = cube_id % mesh_w
|
||||
row = cube_id // mesh_w
|
||||
return (col * (cube_w + seam), row * (cube_h + seam))
|
||||
|
||||
def _get_link(self, key: tuple) -> simpy.Resource:
|
||||
if key not in self._links:
|
||||
assert self._env is not None
|
||||
self._links[key] = simpy.Resource(self._env, capacity=1)
|
||||
return self._links[key]
|
||||
|
||||
# ── Worker ───────────────────────────────────────────────────────
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
|
||||
while True:
|
||||
txn: Any = yield self._inbox.get()
|
||||
env.process(self._route(env, txn))
|
||||
|
||||
def _route(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||
prev_hop = txn.path[txn.step - 1] if txn.step > 0 else None
|
||||
next_hop = txn.next_hop
|
||||
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
|
||||
|
||||
links: list[tuple[tuple, float]] = []
|
||||
if prev_hop and next_hop and self.ctx:
|
||||
src_pos = self.ctx.positions.get(prev_hop)
|
||||
dst_pos = self.ctx.positions.get(next_hop)
|
||||
if src_pos and dst_pos:
|
||||
links = self._xy_links(src_pos, dst_pos)
|
||||
|
||||
if links:
|
||||
yield from self._traverse(env, links, overhead_ns)
|
||||
else:
|
||||
yield env.timeout(overhead_ns)
|
||||
|
||||
if next_hop:
|
||||
yield self.out_ports[next_hop].put(txn.advance())
|
||||
else:
|
||||
drain = getattr(txn, "drain_ns", 0.0)
|
||||
if drain > 0:
|
||||
yield env.timeout(drain)
|
||||
txn.done.succeed()
|
||||
|
||||
# ── XY routing and pipelined link traversal ──────────────────────
|
||||
|
||||
def _traverse(
|
||||
self,
|
||||
env: simpy.Environment,
|
||||
links: list[tuple[tuple, float]],
|
||||
overhead_ns: float,
|
||||
) -> Generator:
|
||||
"""Pipeline: request next segment before current timeout finishes."""
|
||||
ns_per_mm = self.ctx.ns_per_mm # type: ignore[union-attr]
|
||||
|
||||
# Acquire first link
|
||||
first_key, _ = links[0]
|
||||
current_resource = self._get_link(first_key)
|
||||
current_req = current_resource.request()
|
||||
yield current_req
|
||||
|
||||
for i, (_, dist_mm) in enumerate(links):
|
||||
# Request next link before current timeout (pipeline)
|
||||
if i + 1 < len(links):
|
||||
next_key, _ = links[i + 1]
|
||||
next_resource = self._get_link(next_key)
|
||||
next_req = next_resource.request()
|
||||
|
||||
yield env.timeout(dist_mm * ns_per_mm + (overhead_ns if i == 0 else 0.0))
|
||||
current_resource.release(current_req)
|
||||
|
||||
if i + 1 < len(links):
|
||||
yield next_req # usually already fulfilled (pipeline)
|
||||
current_resource = next_resource
|
||||
current_req = next_req
|
||||
|
||||
def _xy_links(
|
||||
self,
|
||||
src: tuple[float, float],
|
||||
dst: tuple[float, float],
|
||||
) -> list[tuple[tuple, float]]:
|
||||
"""XY routing: horizontal segment first, then vertical.
|
||||
|
||||
Returns list of (link_key, dist_mm) pairs, where link_key uniquely
|
||||
identifies a directed segment shared across concurrent transactions.
|
||||
"""
|
||||
x0, y0 = src
|
||||
x1, y1 = dst
|
||||
links: list[tuple[tuple, float]] = []
|
||||
|
||||
# Horizontal segment at y≈y0
|
||||
if abs(x0 - x1) > 1e-9:
|
||||
y_band = self._snap(y0, self._y_grid)
|
||||
for xa, xb in self._segments(x0, x1, self._x_grid):
|
||||
d = abs(xb - xa)
|
||||
if d > 1e-9:
|
||||
lo, hi = (xa, xb) if xa < xb else (xb, xa)
|
||||
dir_h = "E" if xb > xa else "W"
|
||||
links.append((("H", round(y_band, 2), round(lo, 2), round(hi, 2), dir_h), d))
|
||||
|
||||
# Vertical segment at x≈x1
|
||||
if abs(y0 - y1) > 1e-9:
|
||||
x_band = self._snap(x1, self._x_grid)
|
||||
for ya, yb in self._segments(y0, y1, self._y_grid):
|
||||
d = abs(yb - ya)
|
||||
if d > 1e-9:
|
||||
lo, hi = (ya, yb) if ya < yb else (yb, ya)
|
||||
dir_v = "S" if yb > ya else "N"
|
||||
links.append((("V", round(x_band, 2), round(lo, 2), round(hi, 2), dir_v), d))
|
||||
|
||||
return links
|
||||
|
||||
@staticmethod
|
||||
def _snap(val: float, grid: list[float]) -> float:
|
||||
if not grid:
|
||||
return val
|
||||
return min(grid, key=lambda g: abs(g - val))
|
||||
|
||||
@staticmethod
|
||||
def _segments(a: float, b: float, grid: list[float]) -> list[tuple[float, float]]:
|
||||
"""Consecutive (p_i, p_{i+1}) pairs covering range [a, b] using grid waypoints."""
|
||||
if abs(a - b) < 1e-9:
|
||||
return []
|
||||
lo, hi = (a, b) if a < b else (b, a)
|
||||
pts = [lo] + [g for g in grid if lo + 1e-9 < g < hi - 1e-9] + [hi]
|
||||
pairs = [(pts[i], pts[i + 1]) for i in range(len(pts) - 1)]
|
||||
if a > b:
|
||||
pairs = [(p2, p1) for p1, p2 in reversed(pairs)]
|
||||
return pairs
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PcieEpComponent(ComponentBase):
    """PCIe endpoint: charges a protocol-processing delay per transaction.

    The delay comes from the node's ``overhead_ns`` attribute (0 when
    absent). Forwarding is not overridden here and is left to
    ``ComponentBase``.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for the configured PCIe overhead; ``nbytes`` is not used."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)
|
||||
@@ -0,0 +1,186 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeCpuComponent(ComponentBase):
    """PE_CPU: kernel execution controller (Stage 2).

    Two-phase kernel execution (ADR-0014 D1):
      Phase 1 (compile): look up the kernel in the registry and run it with a
                         TLContext to generate a PeCommand list.
      Phase 2 (replay):  iterate the commands, dispatch them to PE_SCHEDULER
                         as PeInternalTxn, and wait on blocking commands.

    Transactions whose request is not a KernelLaunchMsg are forwarded via the
    inherited ``_forward_txn``.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._pe_prefix = node.id.rsplit(".", 1)[0]  # e.g. "sip0.cube0.pe0"
        try:
            self._pe_idx = int(self._pe_prefix.rsplit("pe", 1)[1])
        except (IndexError, ValueError):
            self._pe_idx = 0
        # Extract sip/cube index for multi-SIP/cube shard matching.
        # Ids that do not follow the "sipN.cubeM..." layout fall back to 0.
        parts = node.id.split(".")
        try:
            self._sip_idx = int(parts[0].replace("sip", ""))
        except (IndexError, ValueError):
            self._sip_idx = 0
        try:
            self._cube_idx = int(parts[1].replace("cube", ""))
        except (IndexError, ValueError):
            self._cube_idx = 0

    def _find_shard(self, shards: tuple) -> Any:
        """Return the shard for this PE's (sip, cube, pe).

        Falls back to positional lookup by PE index, clamped to the last
        shard, when no shard matches exactly.

        Raises:
            IndexError: if ``shards`` is empty.
        """
        if not shards:
            # Same exception type as the bare index lookup, but with context.
            raise IndexError(f"{self.node.id}: tensor argument has no shards")
        wanted = (self._sip_idx, self._cube_idx, self._pe_idx)
        for shard in shards:
            if (shard.sip, shard.cube, shard.pe) == wanted:
                return shard
        return shards[min(self._pe_idx, len(shards) - 1)]

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for this node's ``overhead_ns``; ``nbytes`` is not used."""
        overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(overhead_ns)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Serve the inbox: execute kernel launches, forward everything else."""
        # Imported once per worker rather than once per received message.
        from kernbench.runtime_api.kernel import KernelLaunchMsg

        while True:
            txn: Any = yield self._inbox.get()
            if hasattr(txn, "request") and isinstance(txn.request, KernelLaunchMsg):
                yield from self._execute_kernel(env, txn)
            else:
                yield from self._forward_txn(env, txn)

    def _execute_kernel(self, env: simpy.Environment, txn: Any) -> Generator:
        """Compile the kernel function, replay its command trace, and respond.

        Writes ``pe_exec_ns`` (replay duration) plus the summed ``dma_ns`` /
        ``compute_ns`` of all composite commands into ``txn.result_data``.
        """
        from kernbench.triton_emu.registry import get_kernel
        from kernbench.triton_emu.tl_context import TLContext, run_kernel

        request = txn.request

        # Phase 1: Compile — apply PE_CPU setup overhead, then run the kernel.
        yield from self.run(env, 0)
        kernel_fn = get_kernel(request.kernel_ref.name)
        tl = TLContext(
            pe_id=self._pe_idx,
            num_programs=self._derive_num_programs(request),
            dispatch_cycles=0,
        )
        run_kernel(kernel_fn, tl, *self._unpack_kernel_args(request))

        # Phase 2: Replay — dispatch commands to PE_SCHEDULER.
        pe_exec_start = env.now
        composite_results = yield from self._replay_commands(env, tl.commands)

        # Record PE-internal execution time and per-engine totals.
        txn.result_data["pe_exec_ns"] = env.now - pe_exec_start
        txn.result_data["dma_ns"] = sum(
            (rd.get("dma_ns", 0.0) for rd in composite_results), 0.0
        )
        txn.result_data["compute_ns"] = sum(
            (rd.get("compute_ns", 0.0) for rd in composite_results), 0.0
        )

        yield from self._send_kernel_response(env, txn, request)

    def _derive_num_programs(self, request: Any) -> int:
        """Return the largest per-tensor count of shards in this cube (min 1)."""
        num_programs = 1
        for arg in request.args:
            if arg.arg_kind != "tensor":
                continue
            cube_pe_count = sum(
                1 for s in arg.shards
                if s.sip == self._sip_idx and s.cube == self._cube_idx
            )
            num_programs = max(num_programs, cube_pe_count)
        return num_programs

    def _unpack_kernel_args(self, request: Any) -> list:
        """Turn KernelLaunchMsg.args into positional kernel arguments.

        Tensor args become ``va_base`` when it is set (truthy), otherwise the
        PA of this PE's shard; scalar args pass their value through. Args of
        any other kind are skipped.
        """
        kernel_args: list = []
        for arg in request.args:
            if arg.arg_kind == "tensor":
                if arg.va_base:
                    kernel_args.append(arg.va_base)
                else:
                    kernel_args.append(self._find_shard(arg.shards).pa)
            elif arg.arg_kind == "scalar":
                kernel_args.append(arg.value)
        return kernel_args

    def _replay_commands(self, env: simpy.Environment, commands: Any) -> Generator:
        """Replay the command trace; return the composite commands' result dicts.

        Overhead commands wait for ``cmd.cycles``, wait commands block on one
        or all pending completions, composite commands are dispatched without
        blocking, and every other command is dispatched and awaited. Any
        completions still pending at the end are awaited before returning.
        """
        from kernbench.common.pe_commands import (
            CompositeCmd,
            PeCpuOverheadCmd,
            PeInternalTxn,
            WaitCmd,
        )

        scheduler_id = f"{self._pe_prefix}.pe_scheduler"
        pending: dict[str, simpy.Event] = {}  # completion_id → done event
        composite_results: list[dict] = []

        for cmd in commands:
            if isinstance(cmd, PeCpuOverheadCmd):
                yield env.timeout(cmd.cycles)
                continue

            if isinstance(cmd, WaitCmd):
                if cmd.handle is not None:
                    evt = pending.pop(cmd.handle.id, None)
                    if evt:
                        yield evt
                else:
                    # No handle: wait for all pending completions.
                    for evt in pending.values():
                        yield evt
                    pending.clear()
                continue

            done_evt = env.event()
            pe_txn = PeInternalTxn(
                command=cmd, done=done_evt,
                pe_prefix=self._pe_prefix,
            )
            if isinstance(cmd, CompositeCmd):
                # Non-blocking: dispatch to scheduler, track completion.
                composite_results.append(pe_txn.result_data)
                yield self.out_ports[scheduler_id].put(pe_txn)
                pending[cmd.completion.id] = done_evt
            else:
                # Blocking: dispatch and wait for completion.
                yield self.out_ports[scheduler_id].put(pe_txn)
                yield done_evt

        # Wait for any remaining pending completions.
        for evt in pending.values():
            yield evt
        return composite_results

    def _send_kernel_response(
        self, env: simpy.Environment, txn: Any, request: Any
    ) -> Generator:
        """Send a ResponseMsg back along the reversed path of ``txn``.

        When the path is too short to have a reverse hop, ``txn.done`` is
        triggered directly instead.
        """
        reverse_path = list(reversed(txn.path))
        if len(reverse_path) < 2:
            txn.done.succeed()
            return

        from kernbench.runtime_api.kernel import ResponseMsg

        resp_msg = ResponseMsg(
            correlation_id=request.correlation_id,
            request_id=request.request_id,
            src_cube=self._cube_idx, src_pe=self._pe_idx,
            success=True,
        )
        resp_txn = Transaction(
            request=resp_msg, path=reverse_path, step=0,
            nbytes=0, done=env.event(), is_response=True,
        )
        yield self.out_ports[reverse_path[1]].put(resp_txn.advance())
|
||||
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import PeEngineBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeDmaComponent(PeEngineBase):
    """PE_DMA: dual-channel DMA engine with READ and WRITE resources.

    Each channel has capacity=1 (ADR-0014 D4):
      - DMA_READ and DMA_WRITE may execute concurrently.
      - Multiple READs cannot overlap; multiple WRITEs cannot overlap.

    Handles two message types:
      - Transaction: external fabric messages (PeDmaMsg probes, M_CPU DMA)
      - PeInternalTxn: PE-internal commands from PE_SCHEDULER
        (DmaReadCmd → HBM read, DmaWriteCmd → HBM write)
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._dma_read: simpy.Resource | None = None
        self._dma_write: simpy.Resource | None = None
        self._mmu = None  # PeMMU instance, set by engine wiring

    def init_resources(self, env: simpy.Environment) -> None:
        """Create the two independent capacity-1 DMA channels."""
        self._dma_read = simpy.Resource(env, capacity=1)
        self._dma_write = simpy.Resource(env, capacity=1)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Zero-delay step; this component charges no fixed overhead here."""
        yield env.timeout(0)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Handle a PE-internal DMA command: resolve PA → HBM path → transfer.

        The address is translated through the MMU when one is wired, falling
        back to treating it as a PA on PageFault. The matching channel is
        held only while the sub-transaction is issued; completion is awaited
        after the channel is released. Commands that are neither DmaReadCmd
        nor DmaWriteCmd complete immediately.
        """
        from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd
        from kernbench.policy.address.phyaddr import PhysAddr
        from kernbench.runtime_api.kernel import PeDmaMsg

        cmd = pe_txn.command
        assert self._dma_read is not None and self._dma_write is not None

        # Determine direction and target address (VA → PA via MMU)
        if isinstance(cmd, DmaReadCmd):
            dma_res = self._dma_read
            raw_addr = cmd.src_addr
            is_write = False
        elif isinstance(cmd, DmaWriteCmd):
            dma_res = self._dma_write
            raw_addr = cmd.dst_addr
            is_write = True
        else:
            pe_txn.done.succeed()
            return

        # Translate VA → PA via MMU (if available), then resolve HBM node.
        # If the MMU has no mapping for this address (PageFault), treat it as
        # a PA directly (backward-compatible with PA-only mode).
        if self._mmu is not None:
            from kernbench.policy.address.pe_mmu import PageFault
            try:
                target_pa = self._mmu.translate(raw_addr)
                if self._mmu.overhead_ns > 0:
                    yield env.timeout(self._mmu.overhead_ns)
            except PageFault:
                target_pa = raw_addr
        else:
            target_pa = raw_addr  # fallback: treat as PA directly
        pa = PhysAddr.decode(target_pa)
        dst_node = self.ctx.resolver.resolve(pa)
        path = self.ctx.router.find_path(self._pe_prefix, dst_node)
        drain_ns = self.ctx.compute_drain_ns(path, cmd.nbytes)

        # Acquire DMA channel (command issue serialization)
        with dma_res.request() as req:
            yield req
            # Create sub-Transaction with PeDmaMsg (HbmCtrl handles it directly)
            sub_done = env.event()
            # NOTE(review): the source coordinates are hard-coded to 0 —
            # confirm no consumer relies on them identifying this PE.
            sub_request = PeDmaMsg(
                correlation_id="pe_internal",
                request_id=f"dma_{id(pe_txn)}",
                src_sip=0, src_cube=0, src_pe=0,
                dst_pa=target_pa, nbytes=cmd.nbytes,
                is_write=is_write,
            )
            sub_txn = Transaction(
                request=sub_request, path=path, step=0,
                nbytes=cmd.nbytes, done=sub_done, drain_ns=drain_ns,
            )
            # Send to next hop (path[0] is pe_dma itself, path[1] is xbar)
            if len(path) > 1:
                yield self.out_ports[path[1]].put(sub_txn.advance())
            else:
                # No next hop: nothing downstream would ever trigger
                # sub_done, so the wait below would block forever. Finish
                # locally, applying the drain as _forward_txn does at a
                # terminal.
                if drain_ns > 0:
                    yield env.timeout(drain_ns)
                sub_done.succeed()
        # DMA channel released after issue

        # Wait for HBM transfer completion
        yield sub_done
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Handle external Transaction (PeDmaMsg probe, M_CPU DMA) with channel acquisition."""
        # Response transactions bypass DMA channel (no outbound resource needed)
        if getattr(txn, "is_response", False):
            next_hop = txn.next_hop
            if next_hop:
                yield self.out_ports[next_hop].put(txn.advance())
            else:
                txn.done.succeed()
            return

        dma_res = self._select_channel(txn)
        with dma_res.request() as req:
            yield req
            next_hop = txn.next_hop
            if next_hop:
                yield self.out_ports[next_hop].put(txn.advance())
            else:
                # Terminal hop: apply drain time, then complete.
                drain = getattr(txn, "drain_ns", 0.0)
                if drain > 0:
                    yield env.timeout(drain)
                txn.done.succeed()

    def _select_channel(self, txn: Any) -> simpy.Resource:
        """Select DMA channel: WRITE for MemoryWriteMsg requests, READ otherwise."""
        from kernbench.runtime_api.kernel import MemoryWriteMsg

        assert self._dma_read is not None and self._dma_write is not None
        if isinstance(txn.request, MemoryWriteMsg):
            return self._dma_write
        return self._dma_read
|
||||
@@ -0,0 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import PeEngineBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
# dtype → bit width (for TFLOPS scaling)
|
||||
_DTYPE_BITS: dict[str, int] = {
|
||||
"f16": 16, "fp16": 16, "float16": 16, "bf16": 16,
|
||||
"f32": 32, "fp32": 32, "float32": 32,
|
||||
"i8": 8, "int8": 8,
|
||||
"i16": 16, "int16": 16,
|
||||
"i32": 32, "int32": 32,
|
||||
}
|
||||
|
||||
|
||||
class PeGemmComponent(PeEngineBase):
    """PE_GEMM: matrix-multiply engine behind the shared accel slot (ADR-0014 D4).

    When the node names a ``shared_resource``, work is serialised through
    that per-PE resource, which PE_MATH requests under the same name.

    Latency model:
        FLOPs            = 2 * M * K * N
        effective_tflops = peak_tflops_f16 * (16 / dtype_bits)
        compute_ns       = FLOPs / (effective_tflops * 1e3)
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._accel: simpy.Resource | None = None
        self._peak_tflops_f16: float = float(node.attrs.get("peak_tflops_f16", 0.0))

    def init_resources(self, env: simpy.Environment) -> None:
        """Look up the shared accel resource, if the node declares one."""
        shared_name = self.node.attrs.get("shared_resource")
        if not (shared_name and self.ctx):
            return
        self._accel = self.ctx.get_shared_resource(
            env, f"{self._pe_prefix}.{shared_name}"
        )

    def _compute_ns(self, m: int, k: int, n: int, dtype: str) -> float:
        """Return the GEMM latency in nanoseconds.

        Without a positive ``peak_tflops_f16`` the node's ``overhead_ns`` is
        returned instead. Unknown dtypes are treated as 16-bit.
        """
        peak = self._peak_tflops_f16
        if peak <= 0:
            return float(self.node.attrs.get("overhead_ns", 0.0))
        bits = _DTYPE_BITS.get(dtype, 16)
        scaled_tflops = peak * (16.0 / bits)
        # 1 TFLOP/s == 1e3 FLOP/ns
        return (2.0 * m * k * n) / (scaled_tflops * 1e3)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for this node's ``overhead_ns``; ``nbytes`` is not used."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)

    def _gemm_or_overhead(self, env: simpy.Environment, cmd: Any) -> Generator:
        """Spend the modelled GEMM time for GemmCmd, else the fixed overhead."""
        from kernbench.common.pe_commands import GemmCmd

        if isinstance(cmd, GemmCmd):
            yield env.timeout(self._compute_ns(cmd.m, cmd.k, cmd.n, cmd.a.dtype))
        else:
            yield from self.run(env, 0)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Execute one PE-internal command, holding the accel slot if present."""
        command = pe_txn.command
        if not self._accel:
            yield from self._gemm_or_overhead(env, command)
        else:
            with self._accel.request() as slot:
                yield slot
                yield from self._gemm_or_overhead(env, command)
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Forward a transaction, holding the accel slot while doing so."""
        if not self._accel:
            yield from super()._forward_txn(env, txn)
            return
        with self._accel.request() as slot:
            yield slot
            yield from super()._forward_txn(env, txn)
|
||||
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import PeEngineBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeMathComponent(PeEngineBase):
    """PE_MATH: element-wise engine behind the shared accel slot (ADR-0014 D4).

    When the node names a ``shared_resource``, work is serialised through
    that per-PE resource, which PE_GEMM requests under the same name.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._accel: simpy.Resource | None = None

    def init_resources(self, env: simpy.Environment) -> None:
        """Look up the shared accel resource, if the node declares one."""
        shared_name = self.node.attrs.get("shared_resource")
        if not (shared_name and self.ctx):
            return
        self._accel = self.ctx.get_shared_resource(
            env, f"{self._pe_prefix}.{shared_name}"
        )

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait for this node's ``overhead_ns``; ``nbytes`` is not used."""
        delay_ns = float(self.node.attrs.get("overhead_ns", 0.0))
        yield env.timeout(delay_ns)

    def handle_command(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
        """Charge the fixed overhead for a command, holding the accel slot if present."""
        if not self._accel:
            yield from self.run(env, 0)
        else:
            with self._accel.request() as slot:
                yield slot
                yield from self.run(env, 0)
        pe_txn.done.succeed()

    def _forward_txn(self, env: simpy.Environment, txn: Any) -> Generator:
        """Forward a transaction, holding the accel slot while doing so."""
        if not self._accel:
            yield from super()._forward_txn(env, txn)
            return
        with self._accel.request() as slot:
            yield slot
            yield from super()._forward_txn(env, txn)
|
||||
@@ -0,0 +1,66 @@
|
||||
"""PE_MMU component: address translation unit.
|
||||
|
||||
Component role: receives MmuMapMsg/MmuUnmapMsg via inbox (independent of PE_CPU).
|
||||
Utility role: PE_DMA/PE_GEMM call mmu.translate() directly (no SimPy overhead).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase, ComponentRegistry
|
||||
from kernbench.policy.address.pe_mmu import PeMMU
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeMmuComponent(ComponentBase):
    """PE_MMU: per-PE virtual-to-physical address translation.

    Owns a ``PeMMU`` page table configured from the node's ``page_size``
    (default 2 MiB) and ``tlb_overhead_ns`` (default 0). MmuMapMsg and
    MmuUnmapMsg arriving on the inbox update that table; other components
    reach the table through the ``mmu`` property for direct translation.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        attrs = node.attrs
        self._mmu = PeMMU(
            page_size=int(attrs.get("page_size", 2 * 1024 * 1024)),
            overhead_ns=float(attrs.get("tlb_overhead_ns", 0.0)),
        )

    @property
    def mmu(self) -> PeMMU:
        """The underlying PeMMU utility object for direct translate() calls."""
        return self._mmu

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Zero-delay step; this component charges no fixed overhead here."""
        yield env.timeout(0)

    def _worker(self, env: simpy.Environment) -> Generator:
        """Apply map/unmap requests from the inbox; forward everything else."""
        from kernbench.runtime_api.kernel import MmuMapMsg, MmuUnmapMsg

        while True:
            txn: Any = yield self._inbox.get()
            # Messages without a ``request`` attribute fall through to forwarding.
            request = getattr(txn, "request", None)

            if isinstance(request, MmuMapMsg):
                for entry in request.entries:
                    self._mmu.map(
                        va=entry["va"], pa=entry["pa"], size=entry["size"],
                    )
                txn.done.succeed()
            elif isinstance(request, MmuUnmapMsg):
                for entry in request.entries:
                    self._mmu.unmap(va=entry["va"], size=entry["size"])
                txn.done.succeed()
            else:
                yield from self._forward_txn(env, txn)
|
||||
@@ -0,0 +1,245 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeSchedulerComponent(ComponentBase):
|
||||
"""PE_SCHEDULER: sole dispatcher inside a PE (ADR-0014 D1).
|
||||
|
||||
Receives PeInternalTxn from PE_CPU, routes to the appropriate engine:
|
||||
- DmaReadCmd / DmaWriteCmd → PE_DMA
|
||||
- GemmCmd → PE_GEMM
|
||||
- MathCmd → PE_MATH
|
||||
- CompositeCmd → tiled pipeline (Stage 3: ADR-0014 D3.2)
|
||||
|
||||
Composite GEMM pipeline (32x64x32 tiles):
|
||||
DMA_READ(b_tile_t) → COMPUTE(t) → DMA_WRITE(out_tile_t)
|
||||
with overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1)
|
||||
|
||||
Applies scheduler overhead_ns before dispatching each command.
|
||||
Non-PeInternalTxn messages are forwarded via inherited _forward_txn().
|
||||
"""
|
||||
|
||||
# Scheduler tile dimensions (ADR-0014 D3.2)
|
||||
TILE_M = 32
|
||||
TILE_K = 64
|
||||
TILE_N = 32
|
||||
|
||||
# Command → engine suffix dispatch table.
|
||||
# New engines: add a single entry here (e.g. ConvCmd: "pe_conv").
|
||||
_CMD_DISPATCH: dict[type, str] = {}
|
||||
|
||||
@classmethod
|
||||
def _ensure_dispatch_table(cls) -> None:
|
||||
if cls._CMD_DISPATCH:
|
||||
return
|
||||
from kernbench.common.pe_commands import DmaReadCmd, DmaWriteCmd, GemmCmd, MathCmd
|
||||
|
||||
cls._CMD_DISPATCH = {
|
||||
DmaReadCmd: "pe_dma",
|
||||
DmaWriteCmd: "pe_dma",
|
||||
GemmCmd: "pe_gemm",
|
||||
MathCmd: "pe_math",
|
||||
}
|
||||
|
||||
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
|
||||
super().__init__(node, ctx)
|
||||
self._pe_prefix = node.id.rsplit(".", 1)[0]
|
||||
self._ensure_dispatch_table()
|
||||
|
||||
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
|
||||
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
|
||||
yield env.timeout(overhead_ns)
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
|
||||
from kernbench.common.pe_commands import PeInternalTxn
|
||||
|
||||
while True:
|
||||
msg: Any = yield self._inbox.get()
|
||||
if isinstance(msg, PeInternalTxn):
|
||||
env.process(self._dispatch(env, msg))
|
||||
else:
|
||||
yield from self._forward_txn(env, msg)
|
||||
|
||||
def _dispatch(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
|
||||
"""Route a PeInternalTxn to the correct engine via dispatch table."""
|
||||
from kernbench.common.pe_commands import CompositeCmd
|
||||
|
||||
# Scheduler overhead
|
||||
yield from self.run(env, 0)
|
||||
|
||||
cmd = pe_txn.command
|
||||
|
||||
# Check dispatch table first
|
||||
engine_suffix = self._CMD_DISPATCH.get(type(cmd))
|
||||
if engine_suffix is not None:
|
||||
yield self.out_ports[f"{self._pe_prefix}.{engine_suffix}"].put(pe_txn)
|
||||
return
|
||||
|
||||
# CompositeCmd: tiled pipeline (not a simple forward)
|
||||
if isinstance(cmd, CompositeCmd):
|
||||
yield from self._dispatch_composite(env, pe_txn)
|
||||
return
|
||||
|
||||
# Unknown command — signal done immediately
|
||||
pe_txn.done.succeed()
|
||||
|
||||
def _dispatch_composite(self, env: simpy.Environment, pe_txn: PeInternalTxn) -> Generator:
|
||||
"""Composite tiled pipeline (ADR-0014 D3.2).
|
||||
|
||||
GEMM: 3-stage pipeline with b-tile streaming from HBM.
|
||||
MATH: sequential compute + DMA_WRITE (no tiling).
|
||||
"""
|
||||
from kernbench.common.pe_commands import CompositeCmd
|
||||
|
||||
cmd = pe_txn.command
|
||||
assert isinstance(cmd, CompositeCmd)
|
||||
if cmd.op == "gemm" and cmd.b is not None:
|
||||
yield from self._pipeline_gemm(env, pe_txn, cmd)
|
||||
else:
|
||||
yield from self._pipeline_math(env, pe_txn, cmd)
|
||||
|
||||
def _pipeline_gemm(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
|
||||
"""Tiled GEMM pipeline: stream b tiles from HBM, compute, write results.
|
||||
|
||||
Tensor a is in TCM (loaded via tl.load). Tensor b is in HBM (via tl.ref).
|
||||
Pipeline: DMA_READ(b_tile_t) -> COMPUTE(t) -> DMA_WRITE(out_tile_t)
|
||||
Overlap: READ(t+1) || COMPUTE(t) || WRITE(t-1)
|
||||
"""
|
||||
from kernbench.common.pe_commands import (
|
||||
DmaReadCmd,
|
||||
DmaWriteCmd,
|
||||
GemmCmd,
|
||||
PeInternalTxn as PeTxn,
|
||||
TensorHandle,
|
||||
)
|
||||
|
||||
pp = self._pe_prefix
|
||||
a = cmd.a # already in TCM
|
||||
b = cmd.b # HBM reference (via tl.ref)
|
||||
|
||||
M, K_a = a.shape[-2], a.shape[-1]
|
||||
K_b, N = b.shape[-2], b.shape[-1]
|
||||
dtype = a.dtype
|
||||
dtype_bytes = b.nbytes // (K_b * N) if (K_b * N) > 0 else 2
|
||||
|
||||
# Tile counts
|
||||
n_tiles_k = max(1, (K_a + self.TILE_K - 1) // self.TILE_K)
|
||||
n_tiles_n = max(1, (N + self.TILE_N - 1) // self.TILE_N)
|
||||
n_tiles = n_tiles_k * n_tiles_n
|
||||
|
||||
prev_compute_done = None
|
||||
prev_write_done = None
|
||||
total_dma_ns = 0.0
|
||||
total_compute_ns = 0.0
|
||||
|
||||
for tile_idx in range(n_tiles):
|
||||
tk = tile_idx // n_tiles_n
|
||||
tn = tile_idx % n_tiles_n
|
||||
|
||||
k_start = tk * self.TILE_K
|
||||
n_start = tn * self.TILE_N
|
||||
tile_k = min(self.TILE_K, K_a - k_start)
|
||||
tile_n = min(self.TILE_N, N - n_start)
|
||||
tile_nbytes = tile_k * tile_n * dtype_bytes
|
||||
|
||||
# --- Stage 1: DMA_READ b_tile from HBM ---
|
||||
read_done = env.event()
|
||||
b_tile_addr = b.addr + (k_start * N + n_start) * dtype_bytes
|
||||
b_tile_handle = TensorHandle(
|
||||
id=f"b_tile_{tile_idx}", addr=b_tile_addr,
|
||||
shape=(tile_k, tile_n), dtype=dtype, nbytes=tile_nbytes,
|
||||
)
|
||||
read_cmd = DmaReadCmd(handle=b_tile_handle, src_addr=b_tile_addr, nbytes=tile_nbytes)
|
||||
read_txn = PeTxn(command=read_cmd, done=read_done, pe_prefix=pp)
|
||||
t0 = env.now
|
||||
yield self.out_ports[f"{pp}.pe_dma"].put(read_txn)
|
||||
|
||||
# Wait for previous compute before starting this tile's compute
|
||||
if prev_compute_done is not None:
|
||||
yield prev_compute_done
|
||||
|
||||
# Wait for this tile's DMA_READ
|
||||
yield read_done
|
||||
total_dma_ns += env.now - t0
|
||||
|
||||
# --- Stage 2: COMPUTE (GEMM) ---
|
||||
compute_done = env.event()
|
||||
out_handle = TensorHandle(
|
||||
id=f"out_tile_{tile_idx}", addr=0,
|
||||
shape=(M, tile_n), dtype=dtype,
|
||||
nbytes=M * tile_n * dtype_bytes,
|
||||
)
|
||||
compute_cmd = GemmCmd(a=a, b=b_tile_handle, out=out_handle,
|
||||
m=M, k=tile_k, n=tile_n)
|
||||
compute_txn = PeTxn(command=compute_cmd, done=compute_done, pe_prefix=pp)
|
||||
t0 = env.now
|
||||
yield self.out_ports[f"{pp}.pe_gemm"].put(compute_txn)
|
||||
|
||||
# Wait for previous write (DMA_WRITE serialization)
|
||||
if prev_write_done is not None:
|
||||
yield prev_write_done
|
||||
|
||||
# Wait for compute of THIS tile
|
||||
yield compute_done
|
||||
total_compute_ns += env.now - t0
|
||||
prev_compute_done = compute_done
|
||||
|
||||
# --- Stage 3: DMA_WRITE out_tile to HBM ---
|
||||
write_done = env.event()
|
||||
out_tile_pa = cmd.out_addr + n_start * dtype_bytes
|
||||
write_nbytes = M * tile_n * dtype_bytes
|
||||
write_cmd = DmaWriteCmd(handle=out_handle, dst_addr=out_tile_pa, nbytes=write_nbytes)
|
||||
write_txn = PeTxn(command=write_cmd, done=write_done, pe_prefix=pp)
|
||||
t0 = env.now
|
||||
yield self.out_ports[f"{pp}.pe_dma"].put(write_txn)
|
||||
prev_write_done = write_done
|
||||
|
||||
# Wait for final write
|
||||
if prev_write_done is not None:
|
||||
t0 = env.now
|
||||
yield prev_write_done
|
||||
total_dma_ns += env.now - t0
|
||||
|
||||
pe_txn.result_data["dma_ns"] = total_dma_ns
|
||||
pe_txn.result_data["compute_ns"] = total_compute_ns
|
||||
pe_txn.done.succeed()
|
||||
|
||||
def _pipeline_math(self, env: simpy.Environment, pe_txn: PeInternalTxn, cmd: Any) -> Generator:
    """Run a non-GEMM composite as two strictly sequential stages (no tiling).

    Stage 1 sends a ``MathCmd`` to the ``<pe_prefix>.pe_math`` port and waits
    for its completion event.  Stage 2 sends a ``DmaWriteCmd`` targeting
    ``cmd.out_addr`` / ``cmd.out_nbytes`` to the ``<pe_prefix>.pe_dma`` port
    and waits for it.  ``pe_txn.done`` is succeeded only after both finish.

    Args:
        env: Simulation environment used to create the completion events.
        pe_txn: The composite transaction being serviced; only its ``done``
            event is touched here.
        cmd: Composite command; ``math_op``, ``a``, ``out_addr`` and
            ``out_nbytes`` are read from it.
    """
    from kernbench.common.pe_commands import (
        DmaWriteCmd,
        MathCmd,
        PeInternalTxn as PeTxn,
    )

    prefix = self._pe_prefix

    # --- Stage 1: MATH ---
    # The operand handle doubles as the output handle; a falsy math_op
    # falls back to "identity".
    math_finished = env.event()
    math_cmd = MathCmd(
        op=cmd.math_op or "identity",
        inputs=(cmd.a,),
        out=cmd.a,
    )
    math_txn = PeTxn(command=math_cmd, done=math_finished, pe_prefix=prefix)
    yield self.out_ports[f"{prefix}.pe_math"].put(math_txn)
    yield math_finished

    # --- Stage 2: DMA_WRITE of the result ---
    store_finished = env.event()
    store_cmd = DmaWriteCmd(
        handle=cmd.a,
        dst_addr=cmd.out_addr,
        nbytes=cmd.out_nbytes,
    )
    store_txn = PeTxn(command=store_cmd, done=store_finished, pe_prefix=prefix)
    yield self.out_ports[f"{prefix}.pe_dma"].put(store_txn)
    yield store_finished

    # Both stages are complete: release whoever waits on the composite.
    pe_txn.done.succeed()
|
||||
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PeTcmComponent(ComponentBase):
    """PE_TCM: tightly-coupled memory / local SRAM staging buffer.

    Terminal storage component for PE-internal dataflow (ADR-0014 D5).

    Phase 0 model: ``run`` charges the node's ``overhead_ns`` attribute as a
    fixed delay.
    NOTE(review): drain_ns is not applied anywhere in this class; presumably
    the inherited worker handles it at the terminal — confirm against
    ComponentBase.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env, nbytes: int) -> Generator:
        """Wait ``overhead_ns`` (default 0.0) of simulated time.

        ``nbytes`` is accepted but does not influence the delay.
        """
        raw_overhead = self.node.attrs.get("overhead_ns", 0.0)
        yield env.timeout(float(raw_overhead))
|
||||
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class SramComponent(ComponentBase):
    """Cube SRAM: terminal component that models SRAM access latency.

    Each request is charged ``overhead_ns`` (from ``node.attrs``) plus the
    transaction's optional ``drain_ns``.  Afterwards a ``ResponseMsg`` is
    sent back along the reversed request path; when that is not possible
    the request's ``done`` event is succeeded directly.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Wait ``overhead_ns`` (default 0.0); ``nbytes`` does not affect it."""
        raw_overhead = self.node.attrs.get("overhead_ns", 0.0)
        yield env.timeout(float(raw_overhead))

    def _worker(self, env: simpy.Environment) -> Generator:
        """Terminal worker loop: access delay, optional drain, then respond.

        Requests are taken from the inbox and handled one at a time.
        """
        while True:
            request_txn: Any = yield self._inbox.get()
            yield from self.run(env, request_txn.nbytes)

            # drain_ns is optional on the transaction; absent means no drain.
            drain_ns = getattr(request_txn, "drain_ns", 0.0)
            if drain_ns > 0:
                yield env.timeout(drain_ns)

            yield from self._send_response(env, request_txn)

    def _send_response(self, env: simpy.Environment, txn: Any) -> Generator:
        """Build a ResponseMsg and push it onto the first hop of the reverse path.

        Without a component context, or when the path has fewer than two
        hops, no response is routed and ``txn.done`` is succeeded instead.
        """
        return_path = list(txn.path)
        return_path.reverse()

        if len(return_path) < 2 or not self.ctx:
            txn.done.succeed()
            return

        from kernbench.runtime_api.kernel import ResponseMsg

        # The cube index is parsed from the second dot-separated component
        # of the node id (e.g. "...cube3...").
        id_fields = self.node.id.split(".")
        source_cube = int(id_fields[1].replace("cube", ""))

        response = ResponseMsg(
            correlation_id=txn.request.correlation_id,
            request_id=txn.request.request_id,
            src_cube=source_cube,
            src_pe=-1,
            success=True,
        )
        response_txn = Transaction(
            request=response,
            path=return_path,
            step=0,
            nbytes=0,
            done=env.event(),
            is_response=True,
        )
        # return_path[0] is this node; the response leaves via the next hop.
        first_hop = return_path[1]
        yield self.out_ports[first_hop].put(response_txn.advance())
|
||||
@@ -0,0 +1,168 @@
|
||||
"""Position-aware XBAR component.
|
||||
|
||||
Models crossbar latency as base_overhead_ns + internal_distance * ns_per_mm,
|
||||
where internal_distance is the absolute X-axis distance between the entry port
|
||||
(PE router attachment) and exit port (HBM slice logical position) within
|
||||
the crossbar matrix.
|
||||
|
||||
PE router positions come from cube_mesh.yaml (via ctx.spec["_mesh"]).
|
||||
HBM slice positions are uniformly distributed across the HBM physical width.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class PositionAwareXbarComponent(ComponentBase):
    """XBAR with position-dependent latency based on PE-to-slice distance.

    Latency = base_overhead_ns + |entry_port_x - exit_port_x| * ns_per_mm

    Entry/exit port X positions are determined from the transaction path:
    - PE_DMA nodes: router X from cube_mesh.yaml
    - HBM slices: uniformly distributed across HBM physical width
    - Bridge nodes: physical X from topology positions
    - NOC: resolved by scanning path for PE_DMA node

    When no position can be resolved for either port (or no mesh data was
    available at start), only ``base_overhead_ns`` is charged.
    """

    def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
        super().__init__(node, ctx)
        self._base_overhead_ns = float(node.attrs.get("overhead_ns", 0.0))
        # Port-id -> cube-local X position (mm); filled by _build_position_map().
        self._pe_router_xs: dict[str, float] = {}
        self._slice_xs: dict[str, float] = {}
        self._bridge_xs: dict[str, float] = {}
        # Stays 0.0 (distance term disabled) until mesh data is loaded.
        self._ns_per_mm: float = 0.0

    def start(self, env: simpy.Environment) -> None:
        """Build the port position tables, then start the base component."""
        self._build_position_map()
        super().start(env)

    def run(self, env: simpy.Environment, nbytes: int) -> Generator:
        """Charge only the position-independent base overhead."""
        yield env.timeout(self._base_overhead_ns)

    # ── Position map construction ─────────────────────────────────

    def _build_position_map(self) -> None:
        """Populate the PE-router, HBM-slice and bridge X-position tables.

        Does nothing when there is no context, no spec, or no ``_mesh`` entry
        in the spec; in that case ``_ns_per_mm`` stays 0.0 and the component
        falls back to base overhead only.
        """
        if not self.ctx or not self.ctx.spec:
            return
        mesh = self.ctx.spec.get("_mesh")
        if not mesh:
            return

        self._ns_per_mm = self.ctx.ns_per_mm

        # Node id is "<cube_prefix>.<xbar_name>"; anything other than
        # "xbar_top" is treated as the bottom crossbar.
        id_parts = self.node.id.rsplit(".", 1)
        cube_prefix = id_parts[0]
        is_top = id_parts[1] == "xbar_top"

        self._map_pe_routers(mesh, cube_prefix, is_top)
        self._map_hbm_slices(cube_prefix, is_top)
        self._map_bridges(cube_prefix)

    def _map_pe_routers(self, mesh: Any, cube_prefix: str, is_top: bool) -> None:
        """Record router X for every PE DMA attached to this crossbar's routers."""
        xbar_key = "top" if is_top else "bottom"
        router_ids = mesh.get("xbar", {}).get(xbar_key, {}).get("routers", [])
        for router_id in router_ids:
            router_data = mesh["routers"].get(router_id)
            if router_data is None:
                continue
            router_x = router_data["pos_mm"][0]
            for attach in router_data.get("attach", []):
                # Only "<pe_name>.dma" attachments are crossbar entry ports.
                if attach.endswith(".dma"):
                    pe_name = attach.split(".")[0]
                    self._pe_router_xs[f"{cube_prefix}.{pe_name}.pe_dma"] = router_x

    def _map_hbm_slices(self, cube_prefix: str, is_top: bool) -> None:
        """Spread this crossbar's half of the HBM slices evenly over the HBM width.

        The top crossbar owns slices ``[0, n/2)`` and the bottom crossbar
        ``[n/2, n)``.  The HBM is taken to be centred within the cube width.
        """
        cube_spec = self.ctx.spec.get("cube", {})
        geometry = cube_spec.get("geometry", {})
        cube_w = geometry.get("cube_mm", {}).get("w", 17.0)
        hbm_w = geometry.get("hbm_mm", {}).get("w", 9.0)
        n_slices = cube_spec.get("memory_map", {}).get("hbm_slices_per_cube", 8)
        half = n_slices // 2
        hbm_left = (cube_w - hbm_w) / 2

        slice_ids = range(half) if is_top else range(half, n_slices)
        count = len(slice_ids)  # range supports len() without materializing
        for i, sl in enumerate(slice_ids):
            if count > 1:
                x = hbm_left + i * hbm_w / (count - 1)
            else:
                # A single slice sits at the cube centre.
                x = cube_w / 2
            self._slice_xs[f"{cube_prefix}.hbm_ctrl.slice{sl}"] = x

    def _map_bridges(self, cube_prefix: str) -> None:
        """Record cube-local X for every positioned bridge node of this cube."""
        bridge_prefix = cube_prefix + ".bridge."
        origin_x: float | None = None
        for node_id, pos in self.ctx.positions.items():
            if not node_id.startswith(bridge_prefix) or pos is None:
                continue
            # The cube origin is loop-invariant: compute it once, and only
            # when at least one bridge actually needs it.
            if origin_x is None:
                origin_x = self._cube_origin_x()
            self._bridge_xs[node_id] = pos[0] - origin_x

    def _cube_origin_x(self) -> float:
        """Compute the X origin of this cube from its column in the cube mesh.

        The cube index is taken from the first node-id component starting
        with "cube"; the origin is ``column * (cube width + seam distance)``.
        Raises IndexError if the node id has no such component.
        """
        parts = self.node.id.split(".")
        cube_str = [p for p in parts if p.startswith("cube")][0]
        cube_id = int(cube_str[4:])
        spec = self.ctx.spec
        sip_spec = spec.get("sip", {})
        cube_spec = spec.get("cube", {})
        mesh_w = sip_spec.get("cube_mesh", {}).get("w", 4)
        cube_w = cube_spec.get("geometry", {}).get("cube_mm", {}).get("w", 17.0)
        seam = sip_spec.get("links", {}).get("inter_cube_mesh", {}).get(
            "distance_mm_across_seam", 1.0)
        col = cube_id % mesh_w
        return col * (cube_w + seam)

    # ── Worker override ───────────────────────────────────────────

    def _worker(self, env: simpy.Environment) -> Generator:
        """Dispatch each incoming transaction to its own forwarding process.

        Transactions are therefore delayed independently rather than queued
        behind one another.
        """
        while True:
            txn: Any = yield self._inbox.get()
            env.process(self._position_aware_forward(env, txn))

    def _position_aware_forward(
        self, env: simpy.Environment, txn: Any,
    ) -> Generator:
        """Delay ``txn`` by the position-dependent latency, then forward it.

        If there is no next hop, this node is the terminal: the optional
        ``drain_ns`` is applied and ``txn.done`` is succeeded.
        """
        prev_hop = txn.path[txn.step - 1] if txn.step > 0 else None
        next_hop = txn.next_hop

        overhead = self._base_overhead_ns
        if prev_hop and next_hop and self._ns_per_mm > 0:
            entry_x = self._get_port_x(prev_hop, txn.path)
            exit_x = self._get_port_x(next_hop, txn.path)
            # Add the distance term only when both ports resolve.
            if entry_x is not None and exit_x is not None:
                overhead = self._base_overhead_ns + abs(entry_x - exit_x) * self._ns_per_mm

        yield env.timeout(overhead)

        if next_hop:
            yield self.out_ports[next_hop].put(txn.advance())
        else:
            drain = getattr(txn, "drain_ns", 0.0)
            if drain > 0:
                yield env.timeout(drain)
            txn.done.succeed()

    def _get_port_x(self, node_id: str, path: list[str]) -> float | None:
        """Resolve the X position of an XBAR port from node context.

        Lookup order: PE DMA table, HBM slice table, bridge table.  For a
        node whose id contains "noc", the first PE DMA node found anywhere
        in ``path`` is used.  Returns None when nothing matches.
        """
        # Direct lookup: PE DMA
        if node_id in self._pe_router_xs:
            return self._pe_router_xs[node_id]
        # Direct lookup: HBM slice
        if node_id in self._slice_xs:
            return self._slice_xs[node_id]
        # Direct lookup: bridge
        if node_id in self._bridge_xs:
            return self._bridge_xs[node_id]
        # NOC: scan path for PE DMA node
        if "noc" in node_id:
            for p in path:
                if p in self._pe_router_xs:
                    return self._pe_router_xs[p]
        return None
|
||||
Reference in New Issue
Block a user