Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions cuda_core/cuda/core/_memory/_device_memory_resource.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import multiprocessing
import platform # no-cython-lint
import uuid

from ._peer_access_utils import plan_peer_access_update
from cuda.core._utils.cuda_utils import check_multiprocessing_start_method

__all__ = ['DeviceMemoryResource', 'DeviceMemoryResourceOptions']
Expand Down Expand Up @@ -281,17 +282,24 @@ cdef inline _DMR_query_peer_access(DeviceMemoryResource self):
cdef inline _DMR_set_peer_accessible_by(DeviceMemoryResource self, devices):
from .._device import Device

cdef set[int] target_ids = {Device(dev).device_id for dev in devices}
target_ids.discard(self._dev_id)
this_dev = Device(self._dev_id)
cdef list bad = [dev for dev in target_ids if not this_dev.can_access_peer(dev)]
if bad:
raise ValueError(f"Device {self._dev_id} cannot access peer(s): {', '.join(map(str, bad))}")
cdef object resolve_device_id = lambda dev: Device(dev).device_id
cdef object plan
cdef tuple target_ids
cdef tuple to_add
cdef tuple to_rm
if not self._mempool_owned:
_DMR_query_peer_access(self)
cdef set[int] cur_ids = set(self._peer_accessible_by)
cdef set[int] to_add = target_ids - cur_ids
cdef set[int] to_rm = cur_ids - target_ids
plan = plan_peer_access_update(
owner_device_id=self._dev_id,
current_peer_ids=self._peer_accessible_by,
requested_devices=devices,
resolve_device_id=resolve_device_id,
can_access_peer=this_dev.can_access_peer,
)
target_ids = plan.target_ids
to_add = plan.to_add
to_rm = plan.to_remove
cdef size_t count = len(to_add) + len(to_rm)
cdef cydriver.CUmemAccessDesc* access_desc = NULL
cdef size_t i = 0
Expand Down
59 changes: 59 additions & 0 deletions cuda_core/cuda/core/_memory/_peer_access_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from collections.abc import Callable, Iterable
from dataclasses import dataclass


@dataclass(frozen=True)
class PeerAccessPlan:
    """Immutable description of a peer-access change.

    Holds the normalized target state together with the two deltas that
    must be applied to reach it, as computed by ``plan_peer_access_update``.
    """

    # Peer device IDs that should have access once the plan is applied.
    target_ids: tuple[int, ...]
    # Peer device IDs in the target state that are not currently enabled.
    to_add: tuple[int, ...]
    # Peer device IDs currently enabled that are absent from the target state.
    to_remove: tuple[int, ...]


def normalize_peer_access_targets(
    owner_device_id: int,
    requested_devices: Iterable[object],
    *,
    resolve_device_id: Callable[[object], int],
) -> tuple[int, ...]:
    """Reduce requested devices to a canonical tuple of peer device IDs.

    Args:
        owner_device_id: ID of the owning device; it is never reported as a
            peer of itself.
        requested_devices: Devices (in any form ``resolve_device_id``
            accepts) that should be granted access.
        resolve_device_id: Callable mapping one requested device to its
            integer device ID.

    Returns:
        The distinct resolved IDs, minus ``owner_device_id``, in ascending
        order.
    """

    # Resolving into a set collapses duplicate requests for the same device.
    resolved_ids = set(map(resolve_device_id, requested_devices))
    peer_ids = resolved_ids - {owner_device_id}
    return tuple(sorted(peer_ids))


def plan_peer_access_update(
    owner_device_id: int,
    current_peer_ids: Iterable[int],
    requested_devices: Iterable[object],
    *,
    resolve_device_id: Callable[[object], int],
    can_access_peer: Callable[[int], bool],
) -> PeerAccessPlan:
    """Work out which peers must gain or lose access to reach the request.

    Args:
        owner_device_id: ID of the owning device.
        current_peer_ids: Peer device IDs that currently have access.
        requested_devices: Devices that should have access afterwards.
        resolve_device_id: Callable mapping one requested device to its
            integer device ID.
        can_access_peer: Predicate answering whether the owner device can
            access the peer with the given ID.

    Returns:
        A ``PeerAccessPlan`` holding the normalized target IDs plus the
        sorted IDs to add and to remove.

    Raises:
        ValueError: If ``can_access_peer`` rejects any requested peer; the
            message lists every rejected ID.
    """

    wanted = normalize_peer_access_targets(
        owner_device_id,
        requested_devices,
        resolve_device_id=resolve_device_id,
    )

    # Validate every requested peer up front so the error names all of the
    # offenders rather than only the first one.
    bad = [peer_id for peer_id in wanted if not can_access_peer(peer_id)]
    if bad:
        bad_ids = ", ".join(map(str, bad))
        raise ValueError(f"Device {owner_device_id} cannot access peer(s): {bad_ids}")

    enabled_now = frozenset(current_peer_ids)
    enabled_after = frozenset(wanted)
    additions = sorted(enabled_after.difference(enabled_now))
    removals = sorted(enabled_now.difference(enabled_after))
    return PeerAccessPlan(
        target_ids=wanted,
        to_add=tuple(additions),
        to_remove=tuple(removals),
    )
34 changes: 0 additions & 34 deletions cuda_core/tests/test_memory_peer_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pytest
from helpers.buffers import PatternGen, compare_buffer_to_constant, make_scratch_buffer

import cuda.core
from cuda.core import DeviceMemoryResource, DeviceMemoryResourceOptions
from cuda.core._utils.cuda_utils import CUDAError

Expand Down Expand Up @@ -48,39 +47,6 @@ def test_peer_access_basic(mempool_device_x2):
zero_on_dev0.copy_from(buf_on_dev1, stream=stream_on_dev0)


def test_peer_access_property_x2(mempool_device_x2):
    """Test the dmr.peer_accessible_by property (but not its functionality)."""
    # The property reports a sorted tuple that never contains the owning
    # device itself.
    dev0, dev1 = mempool_device_x2
    # An owned pool guarantees a clean initial state (no stale peer access).
    dmr = DeviceMemoryResource(dev0, DeviceMemoryResourceOptions())

    def check(expected):
        assert isinstance(dmr.peer_accessible_by, tuple)
        assert dmr.peer_accessible_by == expected

    # No access to begin with.
    check(expected=())

    # (value assigned, value expected to be read back), applied in order.
    assignments = [
        ((0,), ()),
        ((1,), (1,)),
        ((0, 1), (1,)),
        ((), ()),
        ([0, 1], (1,)),
        (set(), ()),
        ([1, 1, 1, 1, 1], (1,)),
    ]
    for requested, expected in assignments:
        dmr.peer_accessible_by = requested
        check(expected=expected)

    # A negative device ID is out of bounds.
    with pytest.raises(ValueError, match=r"device_id must be \>\= 0"):
        dmr.peer_accessible_by = [-1]

    num_devices = len(cuda.core.Device.get_all_devices())

    # One past the last valid device ID is out of bounds as well.
    with pytest.raises(ValueError, match=r"device_id must be within \[0, \d+\)"):
        dmr.peer_accessible_by = [num_devices]


def test_peer_access_transitions(mempool_device_x3):
"""Advanced tests for dmr.peer_accessible_by."""

Expand Down
67 changes: 67 additions & 0 deletions cuda_core/tests/test_memory_peer_access_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass

import pytest

from cuda.core._memory._peer_access_utils import PeerAccessPlan, plan_peer_access_update


@dataclass(frozen=True)
class DummyDevice:
    """Minimal stand-in for a device object that only carries a device ID."""

    # Integer ID that the tests' resolver extracts from this object.
    device_id: int


def _resolve_device_id(device) -> int:
    """Return the device ID for a ``DummyDevice`` or for anything ``int()`` accepts."""
    return device.device_id if isinstance(device, DummyDevice) else int(device)


def test_plan_peer_access_update_normalizes_requests():
    """Mixed, duplicated requests collapse to sorted IDs without the owner."""
    # Device 1 is the owner; 2 and 3 each appear both as an int and as an object.
    requested = [1, DummyDevice(3), 2, DummyDevice(2), 3]
    expected = PeerAccessPlan(
        target_ids=(2, 3),
        to_add=(2, 3),
        to_remove=(),
    )

    actual = plan_peer_access_update(
        owner_device_id=1,
        current_peer_ids=(),
        requested_devices=requested,
        resolve_device_id=_resolve_device_id,
        can_access_peer=lambda _device_id: True,
    )

    assert actual == expected


def test_plan_peer_access_update_rejects_inaccessible_peers():
    """Every peer the owner cannot reach is listed in a single ValueError."""

    def only_device_one_is_reachable(device_id):
        return device_id == 1

    # Devices 2 and 4 are unreachable; 0 is the owner and 1 is reachable.
    requested = [4, 0, DummyDevice(2), 1]
    with pytest.raises(ValueError, match=r"Device 0 cannot access peer\(s\): 2, 4"):
        plan_peer_access_update(
            owner_device_id=0,
            current_peer_ids=(1,),
            requested_devices=requested,
            resolve_device_id=_resolve_device_id,
            can_access_peer=only_device_one_is_reachable,
        )


def test_plan_peer_access_update_covers_all_state_transitions():
    """Check the add/remove deltas for every (current, requested) state pair."""
    states = [(), (1,), (2,), (1, 2)]
    reachable = {1, 2}
    transitions = [(current, requested) for current in states for requested in states]

    for current_state, requested_state in transitions:
        enabled = set(current_state)
        wanted = set(requested_state)
        expected = PeerAccessPlan(
            target_ids=requested_state,
            to_add=tuple(sorted(wanted - enabled)),
            to_remove=tuple(sorted(enabled - wanted)),
        )

        plan = plan_peer_access_update(
            owner_device_id=0,
            current_peer_ids=current_state,
            requested_devices=requested_state,
            resolve_device_id=_resolve_device_id,
            can_access_peer=lambda device_id: device_id in reachable,
        )

        assert plan == expected
Loading