"""Tests for ``PhysAddr`` encoding/decoding, its factories and enums, and ``PEMemAllocator``."""

import pytest

from kernbench.policy.address.allocator import (
    AddressConfig,
    AllocationError,
    PEMemAllocator,
)
from kernbench.policy.address.phyaddr import (
    PhysAddr,
    PhysAddrError,
    UnitType,
    PESubUnit,
    MCPUSubUnit,
    IOCPUSubUnit,
)

_MB = 1 << 20
_GB = 1 << 30

# Topology-matching config: 48GB HBM / 8 slices / 16MB TCM / 4MB reserved / 32MB SRAM
_CFG = AddressConfig(
    sip_count=2,
    cubes_per_sip=16,
    pes_per_cube=8,
    hbm_bytes_per_cube=48 * _GB,
    hbm_slices_per_cube=8,
    tcm_bytes_per_pe=16 * _MB,
    tcm_scheduler_reserved_bytes=4 * _MB,
    sram_bytes_per_cube=32 * _MB,
)


# ── Immutability & value semantics ──────────────────────────────────
def test_physaddr_immutable():
    """PhysAddr rejects attribute assignment and compares/hashes by value."""
    pa = PhysAddr.hbm_addr(sip_id=0, die_id=0, hbm_offset=0)
    with pytest.raises(AttributeError):
        pa.sip_id = 1  # type: ignore[misc]

    pa2 = PhysAddr.hbm_addr(sip_id=0, die_id=0, hbm_offset=0)
    assert pa == pa2
    # Equal objects must hash equal, so two equal addresses collapse to one
    # set element (this also proves PhysAddr is hashable).
    assert hash(pa) == hash(pa2)
    assert len({pa, pa2}) == 1


# ── HBM encode/decode roundtrip ────────────────────────────────────
def test_hbm_encode_decode_roundtrip():
    """An HBM address keeps sip/die/offset through encode() -> decode()."""
    pa = PhysAddr.hbm_addr(sip_id=3, die_id=5, hbm_offset=0x1000)
    raw = pa.encode()
    dec = PhysAddr.decode(raw)
    assert dec.sip_id == 3
    assert dec.die_id == 5
    assert dec.kind == "hbm"
    assert dec.hbm_offset == 0x1000


# ── PE resource encode/decode roundtrip (new layout) ───────────────
def test_pe_resource_encode_decode_roundtrip():
    """A PE-resource address keeps every field through encode() -> decode()."""
    pa = PhysAddr.pe_resource_addr(
        sip_id=2,
        die_id=7,
        pe_id=3,
        pe_sub_unit=PESubUnit.PE_TCM,
        sub_offset=0xFF,
    )
    raw = pa.encode()
    dec = PhysAddr.decode(raw)
    assert dec.kind == "pe_resource"
    assert dec.unit_type == UnitType.PE
    assert dec.pe_id == 3
    assert dec.pe_sub_unit == PESubUnit.PE_TCM
    assert dec.sub_offset == 0xFF
    assert dec.die_id == 7
    assert dec.sip_id == 2


def test_pe_resource_all_sub_units():
    """Each PE sub-unit roundtrips correctly."""
    for su in PESubUnit:
        pa = PhysAddr.pe_resource_addr(
            sip_id=0,
            die_id=0,
            pe_id=0,
            pe_sub_unit=su,
            sub_offset=42,
        )
        dec = PhysAddr.decode(pa.encode())
        assert dec.pe_sub_unit == su
        assert dec.sub_offset == 42
# ── pe_hbm_addr factory ────────────────────────────────────────────
def test_pe_hbm_addr_factory():
    """pe_hbm_addr offsets a PE-local HBM offset by that PE's slice base."""
    slice_bytes = 6 * _GB  # 48 GB HBM / 8 slices, as in _CFG
    pa = PhysAddr.pe_hbm_addr(
        sip_id=0,
        die_id=0,
        pe_id=2,
        pe_local_hbm_offset=1024,
        slice_size_bytes=slice_bytes,
    )
    assert pa.kind == "hbm"
    assert pa.die_id == 0
    # PE 2's slice begins after two full slices.
    assert pa.hbm_offset == 2 * slice_bytes + 1024


def test_pe_hbm_addr_overflow():
    """A PE-local offset equal to the slice size is out of range."""
    slice_bytes = 6 * _GB
    with pytest.raises(PhysAddrError, match="pe_local_hbm_offset"):
        PhysAddr.pe_hbm_addr(
            sip_id=0,
            die_id=0,
            pe_id=0,
            pe_local_hbm_offset=slice_bytes,
            slice_size_bytes=slice_bytes,
        )


# ── Invalid resource_kind decode ──────────────────────────────────
def test_invalid_resource_kind_raises():
    """Decoding an address whose resource_kind field is 7 raises PhysAddrError."""
    # resource_kind=7 (invalid), addr_space=0
    local_offset = (7 << 34) | 0
    pa_raw = PhysAddr(sip_id=0, die_id=0, local_offset=local_offset)
    raw = pa_raw.encode()
    with pytest.raises(PhysAddrError, match="resource_kind"):
        PhysAddr.decode(raw)


# ── hbm_pe_id utility ─────────────────────────────────────────────
def test_hbm_pe_id_utility():
    """hbm_pe_id recovers the owning PE from an HBM offset and slice size."""
    slice_bytes = 6 * _GB
    pa = PhysAddr.pe_hbm_addr(
        sip_id=0,
        die_id=0,
        pe_id=5,
        pe_local_hbm_offset=256,
        slice_size_bytes=slice_bytes,
    )
    assert PhysAddr.hbm_pe_id(pa.hbm_offset, slice_bytes) == 5


# ── UnitType / sub-unit enums ──────────────────────────────────────
def test_sram_unit_type_exists():
    assert UnitType.SRAM == 2


def test_pe_sub_unit_enum():
    assert PESubUnit.PE_TCM == 6
    assert PESubUnit.IPCQ == 2


def test_mcpu_sub_unit_enum():
    assert MCPUSubUnit.MCPU_SRAM == 5


def test_iocpu_sub_unit_enum():
    assert IOCPUSubUnit.IO_SRAM == 5


# ── cube_sram_addr factory + roundtrip ──────────────────────────────
def test_cube_sram_addr_roundtrip():
    """cube_sram_addr builds an SRAM pe_resource address that roundtrips."""
    pa = PhysAddr.cube_sram_addr(sip_id=1, die_id=3, sram_offset=0x800)
    assert pa.kind == "pe_resource"
    assert pa.unit_type == UnitType.SRAM
    assert pa.die_id == 3
    assert pa.sub_offset == 0x800

    dec = PhysAddr.decode(pa.encode())
    assert dec.unit_type == UnitType.SRAM
    assert dec.die_id == 3
    assert dec.sub_offset == 0x800


def test_cube_sram_addr_range_check():
    """An SRAM offset of 1 << 25 does not fit and is rejected."""
    with pytest.raises(PhysAddrError):
        PhysAddr.cube_sram_addr(
            sip_id=0,
            die_id=0,
            sram_offset=(1 << 25),  # exceeds 25-bit sub_offset
        )


# ── pe_tcm_addr factory + roundtrip ────────────────────────────────
def test_pe_tcm_addr_roundtrip():
    """pe_tcm_addr builds a PE_TCM pe_resource address that roundtrips."""
    pa = PhysAddr.pe_tcm_addr(sip_id=0, die_id=2, pe_id=7, tcm_offset=0x400)
    assert pa.kind == "pe_resource"
    assert pa.unit_type == UnitType.PE
    assert pa.pe_id == 7
    assert pa.die_id == 2
    assert pa.pe_sub_unit == PESubUnit.PE_TCM
    assert pa.sub_offset == 0x400

    dec = PhysAddr.decode(pa.encode())
    assert dec.unit_type == UnitType.PE
    assert dec.pe_id == 7
    assert dec.pe_sub_unit == PESubUnit.PE_TCM
    assert dec.sub_offset == 0x400


def test_pe_tcm_addr_range_check():
    """A TCM offset of 1 << 25 does not fit and is rejected."""
    with pytest.raises(PhysAddrError):
        PhysAddr.pe_tcm_addr(
            sip_id=0,
            die_id=0,
            pe_id=0,
            tcm_offset=(1 << 25),  # exceeds 25-bit sub_offset
        )


# ── MCPU resource factory + roundtrip ──────────────────────────────
def test_mcpu_resource_roundtrip():
    """mcpu_resource_addr builds an MCPU pe_resource address that roundtrips."""
    pa = PhysAddr.mcpu_resource_addr(
        sip_id=0,
        die_id=1,
        mcpu_sub_unit=MCPUSubUnit.MCPU_SRAM,
        sub_offset=0x100,
    )
    assert pa.kind == "pe_resource"
    assert pa.unit_type == UnitType.MCPU
    assert pa.mcpu_sub_unit == MCPUSubUnit.MCPU_SRAM
    assert pa.sub_offset == 0x100

    dec = PhysAddr.decode(pa.encode())
    assert dec.unit_type == UnitType.MCPU
    assert dec.mcpu_sub_unit == MCPUSubUnit.MCPU_SRAM
    assert dec.sub_offset == 0x100


# ── IOCHIPLET: IOCPU factory + roundtrip ────────────────────────────
def test_iocpu_resource_roundtrip():
    """An IOCPU address on chiplet die 17 roundtrips with its sub-unit."""
    pa = PhysAddr.iocpu_resource_addr(
        sip_id=1,
        die_id=17,
        iocpu_sub_unit=IOCPUSubUnit.IPCQ,
        sub_offset=0x20000,
    )
    assert pa.kind == "iocpu"
    assert pa.iocpu_sub_unit == IOCPUSubUnit.IPCQ
    assert pa.sub_offset == 0x20000

    dec = PhysAddr.decode(pa.encode())
    assert dec.kind == "iocpu"
    assert dec.iocpu_sub_unit == IOCPUSubUnit.IPCQ
    assert dec.sub_offset == 0x20000
    assert dec.die_id == 17


def test_iocpu_die_range_check():
    """IOCPU addresses are only accepted on IOCHIPLET dies."""
    with pytest.raises(PhysAddrError, match="IOCHIPLET"):
        PhysAddr.iocpu_resource_addr(
            sip_id=0,
            die_id=5,  # not a chiplet die
            iocpu_sub_unit=0,
            sub_offset=0,
        )
# ── IOCHIPLET: UAL factory + roundtrip ──────────────────────────────
def test_ual_addr_roundtrip():
    """A UAL address on chiplet die 16 roundtrips and lands at/above 2 GB."""
    pa = PhysAddr.ual_addr(sip_id=0, die_id=16, ual_offset=0x1000)
    assert pa.kind == "ual"

    dec = PhysAddr.decode(pa.encode())
    assert dec.kind == "ual"
    assert dec.die_id == 16
    assert dec.chiplet_offset >= (1 << 31)  # >= 2 GB boundary


# ── die_id dispatch ────────────────────────────────────────────────
def test_die_id_ahbm_range():
    """AHBM die ids 0 and 15 decode as HBM and keep their die_id."""
    for die in (0, 15):
        pa = PhysAddr.hbm_addr(sip_id=0, die_id=die, hbm_offset=0)
        dec = PhysAddr.decode(pa.encode())
        assert dec.kind == "hbm"
        assert dec.die_id == die


def test_die_id_chiplet_range():
    """Chiplet die ids 16 and 20 decode as IOCPU and keep their die_id."""
    for die in (16, 20):
        pa = PhysAddr.iocpu_resource_addr(
            sip_id=0,
            die_id=die,
            iocpu_sub_unit=0,
            sub_offset=0,
        )
        dec = PhysAddr.decode(pa.encode())
        assert dec.kind == "iocpu"
        assert dec.die_id == die


def test_die_id_reserved_raises():
    """Decoding a raw address carrying a reserved die_id raises."""
    # Hand-built raw value: sip_id=0 at bit 47, die_id=21 at bit 42, local_offset=0.
    raw = (0 << 47) | (21 << 42) | 0  # die_id=21 (reserved)
    with pytest.raises(PhysAddrError, match="reserved"):
        PhysAddr.decode(raw)


# ── Boundary values ────────────────────────────────────────────────
def test_sip_boundary():
    """sip_id=15 survives an encode/decode roundtrip."""
    pa = PhysAddr.hbm_addr(sip_id=15, die_id=0, hbm_offset=0)
    dec = PhysAddr.decode(pa.encode())
    assert dec.sip_id == 15


def test_mbz_enforcement_ahbm():
    """AHBM local_offset bits [41:38] must be zero."""
    local_offset = (1 << 38) | (1 << 37)  # MBZ bit set + HBM
    pa = PhysAddr(sip_id=0, die_id=0, local_offset=local_offset)
    # `match` is a regex searched in the message, hence the escaped brackets.
    with pytest.raises(PhysAddrError, match=r"bits \[41:38\]"):
        pa.encode()


def test_mbz_enforcement_chiplet():
    """IOCHIPLET local_offset bits [41:40] must be zero."""
    local_offset = (1 << 40) | 0  # MBZ bit set
    pa = PhysAddr(sip_id=0, die_id=16, local_offset=local_offset)
    with pytest.raises(PhysAddrError, match=r"bits \[41:40\]"):
        pa.encode()


# ── AddressConfig ───────────────────────────────────────────────────
def test_address_config_derived_sizes():
    """Derived sizes follow from _CFG: 48 GB / 8 slices; 16 MB TCM - 4 MB reserved."""
    assert _CFG.hbm_slice_bytes == 6 * _GB
    assert _CFG.tcm_allocatable_bytes == 12 * _MB
# ── PEMemAllocator: HBM ────────────────────────────────────────────
def _make_alloc(pe_id: int = 0) -> PEMemAllocator:
    """Return a fresh allocator for ``pe_id`` on sip 0 / die 0, configured by ``_CFG``."""
    return PEMemAllocator(sip_id=0, die_id=0, pe_id=pe_id, cfg=_CFG)


def test_allocator_hbm_basic():
    """The first HBM allocation of a PE starts at that PE's slice base."""
    a = _make_alloc(pe_id=3)
    pa = a.alloc_hbm(4096)
    assert pa.kind == "hbm"
    assert pa.sip_id == 0
    assert pa.die_id == 0
    # 6 GB slices (48 GB / 8 in _CFG); PE 3 starts after three of them.
    assert pa.hbm_offset == 3 * 6 * _GB


def test_allocator_hbm_sequential():
    """Consecutive HBM allocations are laid out back to back."""
    a = _make_alloc()
    pa1 = a.alloc_hbm(1024)
    pa2 = a.alloc_hbm(2048)
    assert pa1.hbm_offset == 0
    assert pa2.hbm_offset == 1024


def test_allocator_hbm_overflow():
    """An HBM request larger than the remaining slice space raises."""
    a = _make_alloc()
    a.alloc_hbm(6 * _GB - 256)  # leaves only 256 bytes in the slice
    with pytest.raises(AllocationError, match="HBM"):
        a.alloc_hbm(512)


# ── PEMemAllocator: TCM ────────────────────────────────────────────
def test_allocator_tcm_basic():
    """The first TCM allocation is a PE pe_resource address at sub_offset 0."""
    a = _make_alloc(pe_id=5)
    pa = a.alloc_tcm(256)
    assert pa.kind == "pe_resource"
    assert pa.unit_type == UnitType.PE
    assert pa.pe_id == 5
    assert pa.sub_offset == 0


def test_allocator_tcm_respects_reserved():
    """Only the non-reserved TCM (16 MB - 4 MB scheduler reserve) is allocatable."""
    a = _make_alloc()
    a.alloc_tcm(12 * _MB)
    assert a.tcm_used == 12 * _MB
    assert a.tcm_total == 12 * _MB


def test_allocator_tcm_overflow():
    """Once allocatable TCM is exhausted, even a 1-byte request raises."""
    a = _make_alloc()
    a.alloc_tcm(12 * _MB)
    with pytest.raises(AllocationError, match="TCM"):
        a.alloc_tcm(1)


# ── PEMemAllocator: stats & determinism ─────────────────────────────
def test_allocator_stats():
    """Usage counters track allocated bytes; totals reflect _CFG capacities."""
    a = _make_alloc()
    a.alloc_hbm(1000)
    a.alloc_tcm(500)
    assert a.hbm_used == 1000
    assert a.hbm_total == 6 * _GB
    assert a.tcm_used == 500
    assert a.tcm_total == 12 * _MB


def test_allocator_deterministic():
    """Two identically configured allocators hand out identical addresses."""
    a1 = _make_alloc(pe_id=2)
    a2 = _make_alloc(pe_id=2)
    assert a1.alloc_hbm(4096) == a2.alloc_hbm(4096)
    assert a1.alloc_tcm(256) == a2.alloc_tcm(256)