commit - release 1

This commit is contained in:
2026-03-18 11:47:48 -07:00
commit 6f43807900
109 changed files with 14909 additions and 0 deletions
+282
View File
@@ -0,0 +1,282 @@
# kernbench/runtime_api/context.py
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from kernbench.common.types import Completion, RequestHandle, SimEngine
from .types import DeviceSelector
@dataclass
class RuntimeContext:
    """Front end over a simulation engine for one correlated run.

    The context forwards requests to ``engine``, remembers every handle it
    has submitted and which of them have been waited on, and layers a small
    tensor / kernel-launch API on top (``zeros``, ``empty``, ``launch``).
    """

    engine: SimEngine
    # Not read by any method of this class; kept for callers.
    target_device: DeviceSelector
    # Stamped on every message this context submits.
    correlation_id: str
    # Hardware description; required only by the tensor API.
    spec: dict | None = None
    # Every handle returned by submit(), in submission order.
    _handles: list[RequestHandle] = field(default_factory=list, init=False)
    # Handles that have gone through wait() at least once.
    _completed: set[RequestHandle] = field(default_factory=set, init=False)
    # Flat PE index -> PEMemAllocator, filled lazily by _ensure_allocators().
    _allocators: dict[int, Any] = field(default_factory=dict, init=False)
    # Source of auto-generated tensor names ("t1", "t2", ...).
    _tensor_counter: int = field(default=0, init=False)
    # Trace entries recorded by wait() when a ``_meta`` dict is supplied.
    _traces: list[dict] = field(default_factory=list, init=False)

    def submit(self, request: Any) -> RequestHandle:
        """Forward *request* to the engine and remember the returned handle.

        Raises:
            AttributeError: if the engine has no ``submit`` attribute.
        """
        submit_fn = getattr(self.engine, "submit", None)
        if submit_fn is None:
            raise AttributeError(
                "Engine does not implement submit(request) -> RequestHandle."
            )
        handle: RequestHandle = submit_fn(request)  # type: ignore[call-arg]
        self._handles.append(handle)
        return handle

    def is_completed(self, handle: RequestHandle) -> bool:
        """Return True if *handle* has already been waited on via this context."""
        return handle in self._completed

    def wait(self, handle: RequestHandle, *, _meta: dict | None = None) -> Completion:
        """Wait for *handle* and return its completion.

        If the handle was already waited on, the completion is simply
        re-fetched from the engine. Otherwise the engine's ``wait`` is called
        (when it has one), the handle is marked completed, and -- when
        ``_meta`` is given and the engine returned a trace -- a trace entry
        merged with ``_meta`` is appended to ``self._traces``.
        """
        if handle in self._completed:
            completion, _ = self.engine.get_completion(handle)
            return completion

        wait_fn = getattr(self.engine, "wait", None)
        if wait_fn is not None:
            wait_fn(handle)  # type: ignore[misc]
        completion, trace = self.engine.get_completion(handle)
        self._completed.add(handle)

        if _meta is not None and trace is not None:
            # Copy dict traces so the engine's own object is never mutated;
            # wrap anything else under a "raw" key.
            entry = dict(trace) if isinstance(trace, dict) else {"raw": trace}
            entry.update(_meta)
            self._traces.append(entry)
        return completion

    def wait_all(self) -> None:
        """Wait on every submitted handle that has not been waited on yet."""
        for handle in self._handles:
            # Skip completed handles here: wait() would otherwise re-query
            # the engine for a completion nobody reads.
            if handle not in self._completed:
                self.wait(handle)

    def handles(self) -> list[RequestHandle]:
        """Return a copy of the submitted handles, in submission order."""
        return list(self._handles)

    # ── PyTorch-like tensor API ──────────────────────────────────────

    def _address_params(self) -> dict[str, int]:
        """Translate ``self.spec`` into keyword arguments for ``AddressConfig``.

        Missing spec entries fall back to the defaults written below. Byte
        sizes are coerced to ``int`` so that a fractional GB/MB value in the
        spec cannot leak a float into address arithmetic.

        Raises:
            RuntimeError: if ``self.spec`` is None.
        """
        if self.spec is None:
            raise RuntimeError(
                "RuntimeContext.spec is required for tensor operations. "
                "Pass spec=graph.spec when creating RuntimeContext."
            )
        gib = 1 << 30
        mib = 1 << 20

        sips = self.spec.get("system", {}).get("sips", {})
        cube = self.spec.get("cube", {})
        memory_map = cube.get("memory_map", {})
        pe_layout = cube.get("pe_layout", {})
        tcm_attrs = (
            cube.get("pe_template", {})
            .get("components", {})
            .get("pe_tcm", {})
            .get("attrs", {})
        )
        corners = pe_layout.get("corners", ["NW", "NE", "SW", "SE"])

        return {
            "sip_count": sips.get("count", 1),
            "cubes_per_sip": sips.get("cubes_per_sip", 16),
            "pes_per_cube": pe_layout.get("pe_per_corner", 2) * len(corners),
            "hbm_bytes_per_cube": int(memory_map.get("hbm_total_gb_per_cube", 48) * gib),
            "hbm_slices_per_cube": memory_map.get("hbm_slices_per_cube", 8),
            "tcm_bytes_per_pe": int(tcm_attrs.get("size_mb", 16) * mib),
            "tcm_scheduler_reserved_bytes": 4 * mib,
            "sram_bytes_per_cube": 32 * mib,
        }

    def _ensure_allocators(self) -> dict:
        """Lazily create PEMemAllocator instances from spec.

        One allocator is created per (sip, cube, pe) and stored under the
        flat index
        ``sip_id * cubes_per_sip * pes_per_cube + cube_id * pes_per_cube + pe_id``.
        The topology counts are also stored on ``self`` (``_num_sips``,
        ``_num_cubes``, ``_pes_per_cube``) for ``_create_tensor``.

        Raises:
            RuntimeError: if ``self.spec`` is None.
        """
        if self._allocators:
            return self._allocators

        params = self._address_params()

        from kernbench.policy.address.allocator import AddressConfig, PEMemAllocator

        cfg = AddressConfig(**params)
        sip_count = params["sip_count"]
        cubes_per_sip = params["cubes_per_sip"]
        pes_per_cube = params["pes_per_cube"]

        self._pes_per_cube = pes_per_cube
        self._num_cubes = cubes_per_sip
        self._num_sips = sip_count

        # Create allocators for all SIPs × cubes × PEs
        pes_per_sip = cubes_per_sip * pes_per_cube
        for sip_id in range(sip_count):
            for cube_id in range(cubes_per_sip):
                for pe_id in range(pes_per_cube):
                    flat_idx = sip_id * pes_per_sip + cube_id * pes_per_cube + pe_id
                    self._allocators[flat_idx] = PEMemAllocator(
                        rack_id=0, sip_id=sip_id, cube_id=cube_id, pe_id=pe_id, cfg=cfg,
                    )
        return self._allocators

    def _next_tensor_name(self) -> str:
        """Return the next auto-generated tensor name ("t1", "t2", ...)."""
        self._tensor_counter += 1
        return f"t{self._tensor_counter}"

    def zeros(
        self,
        shape: tuple[int, ...],
        dtype: str = "f16",
        *,
        placement: list | None = None,
        dp: Any = None,
        name: str | None = None,
    ):
        """Create a tensor and deploy to HBM with zero-fill (like torch.zeros)."""
        return self._create_tensor(shape, dtype, placement, name, pattern="zero", dp=dp)

    def empty(
        self,
        shape: tuple[int, ...],
        dtype: str = "f16",
        *,
        placement: list | None = None,
        dp: Any = None,
        name: str | None = None,
    ):
        """Allocate a tensor in HBM without initialization (like torch.empty)."""
        return self._create_tensor(shape, dtype, placement, name, pattern=None, dp=dp)

    def _create_tensor(
        self,
        shape: tuple[int, ...],
        dtype: str,
        placement: list | None,
        name: str | None,
        pattern: str | None,
        dp: Any = None,
    ):
        """Build a Tensor, allocate its shards and optionally write a fill pattern.

        Args:
            shape: tensor shape; only the first two dimensions are passed to
                the DP resolver when ``dp`` is used.
            dtype: element type name understood by ``dtype_itemsize``.
            placement: explicit shard list; ignored when ``dp`` is a DPPolicy.
                Defaults to a single shard on PE 0 covering the whole tensor.
            name: tensor name; auto-generated when falsy.
            pattern: fill pattern for the per-shard memory writes, or None to
                skip the writes entirely.
            dp: a DPPolicy from which the placement is derived. Any other
                value is ignored.

        Returns:
            The created Tensor with its deployment handle attached.

        Raises:
            ValueError: if ``dp`` is used with a shape of fewer than two
                dimensions, or if the resulting placement is empty.
            RuntimeError: if ``self.spec`` is None.
        """
        from kernbench.policy.placement.dp import DPPolicy, ShardSpec, resolve_dp_policy
        from kernbench.runtime_api.kernel import MemoryWriteMsg
        from kernbench.runtime_api.tensor import Tensor, deploy_tensor, dtype_itemsize

        tensor_name = name or self._next_tensor_name()
        t = Tensor(shape=shape, dtype=dtype, name=tensor_name)
        dp_policy: DPPolicy | None = None

        # Resolve placement: dp= takes priority over placement=
        if isinstance(dp, DPPolicy):
            dp_policy = dp
            # Called for its side effect: it populates the topology counts
            # (_num_sips / _num_cubes / _pes_per_cube) read just below.
            self._ensure_allocators()
            if len(shape) < 2:
                raise ValueError(
                    f"dp placement needs a shape with at least 2 dimensions, got {shape!r}"
                )
            # NOTE(review): dimensions beyond the first two are not passed to
            # the resolver -- confirm this is intended for >2-D tensors.
            shape_2d: tuple[int, int] = (shape[0], shape[1])
            placement = resolve_dp_policy(
                dp,
                shape=shape_2d,
                itemsize=dtype_itemsize(dtype),
                num_pe=self._pes_per_cube,
                num_cubes=self._num_sips * self._num_cubes,
            )
        elif placement is None:
            placement = [ShardSpec(pe_index=0, offset_bytes=0, nbytes=t.nbytes)]

        # Infer target_pe from placement: multi-PE → "all", single PE → pe_index
        pe_indices = {s.pe_index for s in placement}
        if not pe_indices:
            # Without this guard next(iter(...)) below would leak StopIteration.
            raise ValueError(
                f"placement for tensor {tensor_name!r} must contain at least one shard"
            )
        target_pe: int | str = "all" if len(pe_indices) > 1 else next(iter(pe_indices))
        t.to(placement=placement, target_pe=target_pe, dp_policy=dp_policy)

        # Allocate PAs via PEMemAllocator
        handle = deploy_tensor(
            name=tensor_name,
            shape=shape,
            dtype=dtype,
            placement=placement,
            allocators=self._ensure_allocators(),
        )
        t._handle = handle

        # Submit MemoryWriteMsg per shard (deploy data to device)
        if pattern is not None:
            for shard in handle.shards:
                # NOTE(review): request_id only encodes the PE, so shards on
                # different cubes/SIPs can share an id -- confirm the engine
                # does not require uniqueness.
                write_handle = self.submit(MemoryWriteMsg(
                    correlation_id=self.correlation_id,
                    request_id=f"deploy_{tensor_name}_pe{shard.pe}",
                    dst_sip=shard.sip, dst_cube=shard.cube, dst_pe=shard.pe,
                    dst_pa=shard.pa, nbytes=shard.nbytes, pattern=pattern,
                    target_cubes=(shard.cube,), target_pe=shard.pe,
                ))
                # Each write is waited on before the next one is submitted.
                self.wait(write_handle, _meta={
                    "phase": "memory_write", "name": tensor_name,
                    "sip": shard.sip, "cube": shard.cube, "pe": shard.pe,
                    "nbytes": shard.nbytes,
                })
        return t

    def launch(
        self,
        kernel_name: str,
        kernel_fn: Any,
        *args: Any,
        **kwargs: Any,
    ) -> RequestHandle:
        """Register and launch a kernel (like a fused torch op).

        Positional args: Tensor objects become TensorArg, int/float become ScalarArg.
        Keyword args: become ScalarArg (name is discarded, order preserved).

        Arguments of any other type are skipped. The launch is waited on
        before this method returns; the returned handle is therefore already
        completed from this context's point of view.

        The target PE is taken from the tensors' DP metadata ("all" wins over
        a specific index, default 0); the target cubes are the cubes of all
        shards of the positional tensors (default ``(0,)``).
        """
        from kernbench.runtime_api.kernel import (
            KernelLaunchMsg,
            KernelRef,
            ScalarArg,
        )
        from kernbench.runtime_api.tensor import Tensor
        from kernbench.triton_emu.registry import register_kernel

        # Register kernel (idempotent)
        try:
            register_kernel(kernel_name, kernel_fn)
        except ValueError:
            # Deliberate best-effort; presumably raised when the name is
            # already registered -- TODO confirm against the registry.
            pass

        def as_scalar(value: int | float) -> Any:
            # bool is an int subclass and therefore maps to "i32" as well.
            return ScalarArg(dtype="f32" if isinstance(value, float) else "i32", value=value)

        # Build kernel args from positional + keyword args
        kernel_args: list = []
        target_pe: int | str = 0
        cube_set: set[int] = set()
        for arg in args:
            if isinstance(arg, Tensor):
                kernel_args.append(arg.to_tensor_arg())
                # Infer target_pe from tensor DP metadata
                if arg._dp_metadata is not None:
                    dp_target = arg._dp_metadata.target_pe
                    if dp_target == "all":
                        target_pe = "all"
                    elif isinstance(dp_target, int) and target_pe != "all":
                        target_pe = dp_target
                # Determine target cubes from all tensor shards
                if arg._handle is not None:
                    cube_set.update(s.cube for s in arg._handle.shards)
            elif isinstance(arg, (int, float)):
                kernel_args.append(as_scalar(arg))
            # NOTE(review): any other positional type is dropped silently,
            # which shifts the positions of the remaining kernel arguments.

        # Keyword values that are not int/float (including Tensors) are skipped.
        kernel_args.extend(
            as_scalar(v) for v in kwargs.values() if isinstance(v, (int, float))
        )

        target_cubes = tuple(sorted(cube_set)) if cube_set else (0,)

        # Collect scalar values for GEMM FLOP calculation
        scalar_vals = [a.value for a in kernel_args if hasattr(a, "value")]

        # NOTE(review): request_id is the kernel name, so repeated launches of
        # the same kernel reuse the id.
        h = self.submit(KernelLaunchMsg(
            correlation_id=self.correlation_id,
            request_id=kernel_name,
            kernel_ref=KernelRef(name=kernel_name, kind="builtin"),
            args=tuple(kernel_args),
            target_cubes=target_cubes,
            target_pe=target_pe,
        ))
        self.wait(h, _meta={
            "phase": "kernel", "name": kernel_name,
            "target_pe": target_pe, "scalars": scalar_vals,
        })
        return h