From c865677442448be31f02f4f74320a79800a51105 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 16 Mar 2026 16:45:52 -0400 Subject: [PATCH] fix(core): unit-test peer access planning without multi-GPU hardware Extract the peer-access transition planning from DeviceMemoryResource so stale-state regressions can be covered on single-GPU systems. Keep the existing multi-GPU integration tests for end-to-end peer access behavior. Made-with: Cursor --- .../core/_memory/_device_memory_resource.pyx | 24 ++++--- .../cuda/core/_memory/_peer_access_utils.py | 59 ++++++++++++++++ cuda_core/tests/test_memory_peer_access.py | 34 ---------- .../tests/test_memory_peer_access_utils.py | 67 +++++++++++++++++++ 4 files changed, 142 insertions(+), 42 deletions(-) create mode 100644 cuda_core/cuda/core/_memory/_peer_access_utils.py create mode 100644 cuda_core/tests/test_memory_peer_access_utils.py diff --git a/cuda_core/cuda/core/_memory/_device_memory_resource.pyx b/cuda_core/cuda/core/_memory/_device_memory_resource.pyx index 744c58e021..9f8e4bcd53 100644 --- a/cuda_core/cuda/core/_memory/_device_memory_resource.pyx +++ b/cuda_core/cuda/core/_memory/_device_memory_resource.pyx @@ -25,6 +25,7 @@ import multiprocessing import platform # no-cython-lint import uuid +from ._peer_access_utils import plan_peer_access_update from cuda.core._utils.cuda_utils import check_multiprocessing_start_method __all__ = ['DeviceMemoryResource', 'DeviceMemoryResourceOptions'] @@ -281,17 +282,24 @@ cdef inline _DMR_query_peer_access(DeviceMemoryResource self): cdef inline _DMR_set_peer_accessible_by(DeviceMemoryResource self, devices): from .._device import Device - cdef set[int] target_ids = {Device(dev).device_id for dev in devices} - target_ids.discard(self._dev_id) this_dev = Device(self._dev_id) - cdef list bad = [dev for dev in target_ids if not this_dev.can_access_peer(dev)] - if bad: - raise ValueError(f"Device {self._dev_id} cannot access peer(s): {', '.join(map(str, bad))}") + cdef object resolve_device_id = lambda dev: Device(dev).device_id + cdef object plan + cdef tuple target_ids + cdef tuple to_add + cdef tuple to_rm if not self._mempool_owned: _DMR_query_peer_access(self) - cdef set[int] cur_ids = set(self._peer_accessible_by) - cdef set[int] to_add = target_ids - cur_ids - cdef set[int] to_rm = cur_ids - target_ids + plan = plan_peer_access_update( + owner_device_id=self._dev_id, + current_peer_ids=self._peer_accessible_by, + requested_devices=devices, + resolve_device_id=resolve_device_id, + can_access_peer=this_dev.can_access_peer, + ) + target_ids = plan.target_ids + to_add = plan.to_add + to_rm = plan.to_remove cdef size_t count = len(to_add) + len(to_rm) cdef cydriver.CUmemAccessDesc* access_desc = NULL cdef size_t i = 0 diff --git a/cuda_core/cuda/core/_memory/_peer_access_utils.py b/cuda_core/cuda/core/_memory/_peer_access_utils.py new file mode 100644 index 0000000000..e08de69f2c --- /dev/null +++ b/cuda_core/cuda/core/_memory/_peer_access_utils.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Callable, Iterable +from dataclasses import dataclass + + +@dataclass(frozen=True) +class PeerAccessPlan: + """Normalized peer-access target state and the driver updates it requires.""" + + target_ids: tuple[int, ...] + to_add: tuple[int, ...] + to_remove: tuple[int, ...] + + +def normalize_peer_access_targets( + owner_device_id: int, + requested_devices: Iterable[object], + *, + resolve_device_id: Callable[[object], int], +) -> tuple[int, ...]: + """Return sorted, unique peer device IDs, excluding the owner device.""" + + target_ids = {resolve_device_id(device) for device in requested_devices} + target_ids.discard(owner_device_id) + return tuple(sorted(target_ids)) + + +def plan_peer_access_update( + owner_device_id: int, + current_peer_ids: Iterable[int], + requested_devices: Iterable[object], + *, + resolve_device_id: Callable[[object], int], + can_access_peer: Callable[[int], bool], +) -> PeerAccessPlan: + """Compute the peer-access target state and add/remove deltas.""" + + target_ids = normalize_peer_access_targets( + owner_device_id, + requested_devices, + resolve_device_id=resolve_device_id, + ) + bad = tuple(dev_id for dev_id in target_ids if not can_access_peer(dev_id)) + if bad: + bad_ids = ", ".join(str(dev_id) for dev_id in bad) + raise ValueError(f"Device {owner_device_id} cannot access peer(s): {bad_ids}") + + current_ids = set(current_peer_ids) + target_id_set = set(target_ids) + return PeerAccessPlan( + target_ids=target_ids, + to_add=tuple(sorted(target_id_set - current_ids)), + to_remove=tuple(sorted(current_ids - target_id_set)), + ) diff --git a/cuda_core/tests/test_memory_peer_access.py b/cuda_core/tests/test_memory_peer_access.py index b7d5747b75..1a79064586 100644 --- a/cuda_core/tests/test_memory_peer_access.py +++ b/cuda_core/tests/test_memory_peer_access.py @@ -4,7 +4,6 @@ import pytest from helpers.buffers import PatternGen, compare_buffer_to_constant, make_scratch_buffer -import cuda.core from cuda.core import DeviceMemoryResource, DeviceMemoryResourceOptions from cuda.core._utils.cuda_utils import CUDAError @@ -48,39 +47,6 @@ def test_peer_access_basic(mempool_device_x2): zero_on_dev0.copy_from(buf_on_dev1, stream=stream_on_dev0) -def test_peer_access_property_x2(mempool_device_x2): - """The the dmr.peer_accessible_by property (but not its functionality).""" - # The peer access list is a sorted tuple and always excludes the self - # device. - dev0, dev1 = mempool_device_x2 - # Use owned pool to ensure clean initial state (no stale peer access). - dmr = DeviceMemoryResource(dev0, DeviceMemoryResourceOptions()) - - def check(expected): - assert isinstance(dmr.peer_accessible_by, tuple) - assert dmr.peer_accessible_by == expected - - # No access to begin with. - check(expected=()) - # fmt: off - dmr.peer_accessible_by = (0,) ; check(expected=()) - dmr.peer_accessible_by = (1,) ; check(expected=(1,)) - dmr.peer_accessible_by = (0, 1) ; check(expected=(1,)) - dmr.peer_accessible_by = () ; check(expected=()) - dmr.peer_accessible_by = [0, 1] ; check(expected=(1,)) - dmr.peer_accessible_by = set() ; check(expected=()) - dmr.peer_accessible_by = [1, 1, 1, 1, 1] ; check(expected=(1,)) - # fmt: on - - with pytest.raises(ValueError, match=r"device_id must be \>\= 0"): - dmr.peer_accessible_by = [-1] # device ID out of bounds - - num_devices = len(cuda.core.Device.get_all_devices()) - - with pytest.raises(ValueError, match=r"device_id must be within \[0, \d+\)"): - dmr.peer_accessible_by = [num_devices] # device ID out of bounds - - def test_peer_access_transitions(mempool_device_x3): """Advanced tests for dmr.peer_accessible_by.""" diff --git a/cuda_core/tests/test_memory_peer_access_utils.py b/cuda_core/tests/test_memory_peer_access_utils.py new file mode 100644 index 0000000000..97fab3c619 --- /dev/null +++ b/cuda_core/tests/test_memory_peer_access_utils.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + +import pytest + +from cuda.core._memory._peer_access_utils import PeerAccessPlan, plan_peer_access_update + + +@dataclass(frozen=True) +class DummyDevice: + device_id: int + + +def _resolve_device_id(device) -> int: + if isinstance(device, DummyDevice): + return device.device_id + return int(device) + + +def test_plan_peer_access_update_normalizes_requests(): + plan = plan_peer_access_update( + owner_device_id=1, + current_peer_ids=(), + requested_devices=[1, DummyDevice(3), 2, DummyDevice(2), 3], + resolve_device_id=_resolve_device_id, + can_access_peer=lambda _device_id: True, + ) + + assert plan == PeerAccessPlan( + target_ids=(2, 3), + to_add=(2, 3), + to_remove=(), + ) + + +def test_plan_peer_access_update_rejects_inaccessible_peers(): + with pytest.raises(ValueError, match=r"Device 0 cannot access peer\(s\): 2, 4"): + plan_peer_access_update( + owner_device_id=0, + current_peer_ids=(1,), + requested_devices=[4, 0, DummyDevice(2), 1], + resolve_device_id=_resolve_device_id, + can_access_peer=lambda device_id: device_id == 1, + ) + + +def test_plan_peer_access_update_covers_all_state_transitions(): + states = [(), (1,), (2,), (1, 2)] + for current_state in states: + for requested_state in states: + plan = plan_peer_access_update( + owner_device_id=0, + current_peer_ids=current_state, + requested_devices=requested_state, + resolve_device_id=_resolve_device_id, + can_access_peer=lambda device_id: device_id in {1, 2}, + ) + + assert plan == PeerAccessPlan( + target_ids=requested_state, + to_add=tuple(sorted(set(requested_state) - set(current_state))), + to_remove=tuple(sorted(set(current_state) - set(requested_state))), + )