Restructure legacy backups, remove pe_accel, fix DMA self-routing
- Move builtin_legacy/ → legacy/builtin/ (cleaner structure) - Move pe_accel_legacy/ → legacy/pe_accel/ - Remove custom/pe_accel/ (replaced by new builtin) - Remove pe_scheduler_v2 from components.yaml - Switch topology.yaml to pe_scheduler_v1 (new builtin) - Fix PE_DMA self-routing: handle consecutive DMA_READ stages (same component consecutive stages processed in-place, not via port) 382 tests passing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,157 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import simpy
|
||||
|
||||
from kernbench.components.base import ComponentBase
|
||||
from kernbench.sim_engine.transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kernbench.components.context import ComponentContext
|
||||
from kernbench.topology.types import Node
|
||||
|
||||
|
||||
class IoCpuComponent(ComponentBase):
|
||||
"""IO_CPU component: multi-cube fan-out with response aggregation.
|
||||
|
||||
Forward path:
|
||||
1. Applies overhead_ns processing overhead.
|
||||
2. Resolves target cube(s) from request.target_cubes.
|
||||
3. Fans out sub-Transactions to each target cube's M_CPU.
|
||||
|
||||
Response path:
|
||||
Collects ResponseMsg from each M_CPU. When all cube responses are
|
||||
received, succeeds the parent txn.done.
|
||||
"""
|
||||
|
||||
def __init__(self, node: Node, ctx: ComponentContext | None = None) -> None:
|
||||
super().__init__(node, ctx)
|
||||
# Pending fan-out tracking: request_id → (expected, received, parent_txn_done)
|
||||
self._pending: dict[str, tuple[int, int, simpy.Event]] = {}
|
||||
|
||||
def run(self, env: simpy.Environment, nbytes: int) -> Generator:
|
||||
overhead_ns = float(self.node.attrs.get("overhead_ns", 0.0))
|
||||
yield env.timeout(overhead_ns)
|
||||
|
||||
def _worker(self, env: simpy.Environment) -> Generator:
|
||||
while True:
|
||||
txn: Any = yield self._inbox.get()
|
||||
if getattr(txn, "is_response", False):
|
||||
self._collect_response(txn)
|
||||
else:
|
||||
yield from self.run(env, txn.nbytes)
|
||||
env.process(self._dispatch_to_m_cpus(env, txn))
|
||||
|
||||
def _collect_response(self, resp_txn: Any) -> None:
|
||||
"""Receive a cube response and increment the aggregation counter."""
|
||||
key = resp_txn.request.request_id
|
||||
if key not in self._pending:
|
||||
return
|
||||
expected, received, parent_done = self._pending[key]
|
||||
received += 1
|
||||
if received >= expected:
|
||||
parent_done.succeed()
|
||||
del self._pending[key]
|
||||
else:
|
||||
self._pending[key] = (expected, received, parent_done)
|
||||
|
||||
def _dispatch_to_m_cpus(self, env: simpy.Environment, txn: Any) -> Generator:
|
||||
"""Fan out sub-Transactions to target cube M_CPUs, wait for responses."""
|
||||
from kernbench.runtime_api.kernel import KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg
|
||||
|
||||
request = txn.request
|
||||
try:
|
||||
cube_targets = self._resolve_cube_targets(request)
|
||||
except Exception:
|
||||
txn.done.succeed()
|
||||
return
|
||||
|
||||
if not cube_targets:
|
||||
txn.done.succeed()
|
||||
return
|
||||
|
||||
# Setup aggregation
|
||||
self._pending[request.request_id] = (len(cube_targets), 0, txn.done)
|
||||
|
||||
# Fan out to each target cube's M_CPU
|
||||
for sip, cube in cube_targets:
|
||||
try:
|
||||
m_cpu_id = self.ctx.resolver.find_m_cpu(sip, cube)
|
||||
path = self.ctx.router.find_node_path(self.node.id, m_cpu_id)
|
||||
except Exception:
|
||||
continue
|
||||
if len(path) < 2:
|
||||
continue
|
||||
sub_txn = Transaction(
|
||||
request=request, path=path, step=0,
|
||||
nbytes=txn.nbytes, done=env.event(),
|
||||
result_data=txn.result_data,
|
||||
)
|
||||
yield self.out_ports[path[1]].put(sub_txn.advance())
|
||||
|
||||
def _resolve_cube_targets(self, request: Any) -> list[tuple[int, int]]:
|
||||
"""Return list of (sip, cube) pairs to fan out to."""
|
||||
from kernbench.runtime_api.kernel import (
|
||||
KernelLaunchMsg, MemoryReadMsg, MemoryWriteMsg, MmuMapMsg, MmuUnmapMsg,
|
||||
)
|
||||
|
||||
target_cubes = getattr(request, "target_cubes", "all")
|
||||
|
||||
if isinstance(request, MemoryWriteMsg):
|
||||
sip = request.dst_sip
|
||||
if target_cubes == "all":
|
||||
cube = self._cube_from_pa(request.dst_pa, fallback=request.dst_cube)
|
||||
return [(sip, cube)]
|
||||
return [(sip, c) for c in target_cubes]
|
||||
|
||||
if isinstance(request, MemoryReadMsg):
|
||||
sip = request.src_sip
|
||||
if target_cubes == "all":
|
||||
cube = self._cube_from_pa(request.src_pa, fallback=request.src_cube)
|
||||
return [(sip, cube)]
|
||||
return [(sip, c) for c in target_cubes]
|
||||
|
||||
if isinstance(request, KernelLaunchMsg):
|
||||
my_sip = self._my_sip()
|
||||
if target_cubes != "all":
|
||||
return [(my_sip, c) for c in target_cubes]
|
||||
# "all": derive from tensor shards, filtered to this SIP
|
||||
seen: set[tuple[int, int]] = set()
|
||||
targets: list[tuple[int, int]] = []
|
||||
for arg in request.args:
|
||||
if arg.arg_kind != "tensor":
|
||||
continue
|
||||
for shard in arg.shards:
|
||||
if shard.sip != my_sip:
|
||||
continue
|
||||
key = (shard.sip, shard.cube)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
targets.append(key)
|
||||
return targets
|
||||
|
||||
if isinstance(request, (MmuMapMsg, MmuUnmapMsg)):
|
||||
my_sip = self._my_sip()
|
||||
if target_cubes == "all":
|
||||
n_cubes = 16
|
||||
if self.ctx and self.ctx.spec:
|
||||
sips = self.ctx.spec.get("system", {}).get("sips", {})
|
||||
n_cubes = sips.get("cubes_per_sip", 16)
|
||||
return [(my_sip, c) for c in range(n_cubes)]
|
||||
return [(my_sip, c) for c in target_cubes]
|
||||
|
||||
return []
|
||||
|
||||
def _cube_from_pa(self, pa_val: int, fallback: int) -> int:
|
||||
"""Extract cube_id from a physical address, with fallback."""
|
||||
from kernbench.policy.address.phyaddr import PhysAddr
|
||||
try:
|
||||
return PhysAddr.decode(pa_val).cube_id
|
||||
except Exception:
|
||||
return fallback
|
||||
|
||||
def _my_sip(self) -> int:
|
||||
"""Extract this IO_CPU's SIP ID from its node ID (e.g. 'sip0.io0.io_cpu' → 0)."""
|
||||
return int(self.node.id.split(".")[0].replace("sip", ""))
|
||||
Reference in New Issue
Block a user