Skip to content

Commit 02251f6

Browse files
committed
feat(BA-4144): Refactor device discovery with GlobalDeviceInfo
Introduce GlobalDeviceInfo dataclass to separate device discovery from allocation map creation in ResourceAllocator. This enables cleaner separation of concerns and more flexible device-based allocation strategies in the future. Key changes include splitting __ainit__() into three distinct phases: device discovery from plugins, allocation map creation, and slot calculation. The _calculate_total_slots() method now uses plugin.available_slots() directly instead of reading from allocation maps, providing cleaner abstraction boundaries. Added comprehensive unit tests covering GlobalDeviceInfo initialization, _create_global_devices() with single and multiple plugins, empty device handling, and slot calculation with aggregation.
1 parent a288c31 commit 02251f6

3 files changed

Lines changed: 278 additions & 11 deletions

File tree

changes/8440.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Introduce GlobalDeviceInfo and device discovery infrastructure to separate device discovery from allocation in ResourceAllocator

src/ai/backend/agent/resources.py

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,23 @@ class ComputerContext:
101101
alloc_map: AbstractAllocMap
102102

103103

104+
@attrs.define(auto_attribs=True, slots=True)
105+
class GlobalDeviceInfo:
106+
"""
107+
Represents discovered devices from a compute plugin.
108+
109+
This class separates device discovery from allocation. It contains
110+
only the plugin reference and discovered devices, without any
111+
allocation map. Allocation maps are created separately when needed.
112+
"""
113+
114+
plugin: AbstractComputePlugin
115+
devices: Sequence[AbstractComputeDevice]
116+
117+
118+
type GlobalDeviceMap = Mapping[DeviceName, GlobalDeviceInfo]
119+
120+
104121
@dataclass
105122
class DeviceView:
106123
device: DeviceName
@@ -543,15 +560,24 @@ def __init__(self, local_config: AgentUnifiedConfig, etcd: AsyncEtcd) -> None:
543560

544561
async def __ainit__(self) -> None:
545562
alloc_map_mod.log_alloc_map = self.local_config.debug.log_alloc_map
546-
computers = await self._load_resources()
563+
plugins = await self._load_resources()
547564

565+
# Phase 1: Discover devices from all plugins (separation of concerns)
566+
global_device_map = await self._create_global_devices(plugins)
567+
568+
# Phase 2: Create allocation maps and computer contexts
548569
computer_contexts: dict[DeviceName, ComputerContext] = {}
549-
for name, computer in computers.items():
550-
devices = await computer.list_devices()
551-
alloc_map = await computer.create_alloc_map()
552-
computer_contexts[name] = ComputerContext(computer, devices, alloc_map)
570+
for device_name, device_info in global_device_map.items():
571+
alloc_map = await device_info.plugin.create_alloc_map()
572+
computer_contexts[device_name] = ComputerContext(
573+
device_info.plugin,
574+
device_info.devices,
575+
alloc_map,
576+
)
553577
self.computers = computer_contexts
554-
total_slots = self._calculate_total_slots()
578+
579+
# Phase 3: Calculate slots and configure agents
580+
total_slots = await self._calculate_total_slots()
555581
self.available_total_slots = self._calculate_available_total_slots(total_slots)
556582

557583
agent_computers = {}
@@ -621,11 +647,19 @@ def get_resource_scaling_factor(self, agent_id: AgentId) -> SlotsMap:
621647
raise AgentIdNotFoundError(f"Agent ID {agent_id} not in computers")
622648
return self.agent_resource_scaling_factor[agent_id]
623649

624-
def _calculate_total_slots(self) -> SlotsMap:
650+
async def _calculate_total_slots(self) -> SlotsMap:
651+
"""
652+
Calculate total available slots by querying each plugin directly.
653+
654+
This method uses the plugin's available_slots() method rather than
655+
reading from allocation maps, providing a cleaner separation between
656+
device discovery and allocation tracking.
657+
"""
625658
total_slots: dict[SlotName, Decimal] = defaultdict(lambda: Decimal("0"))
626-
for device in self.computers.values():
627-
for slot_info in device.alloc_map.device_slots.values():
628-
total_slots[slot_info.slot_name] += slot_info.amount
659+
for ctx in self.computers.values():
660+
plugin_slots = await ctx.instance.available_slots()
661+
for slot_name, amount in plugin_slots.items():
662+
total_slots[slot_name] += amount
629663
return total_slots
630664

631665
def _calculate_available_total_slots(self, total_slots: SlotsMap) -> SlotsMap:
@@ -691,6 +725,37 @@ async def _load_resources(self) -> Mapping[DeviceName, AbstractComputePlugin]:
691725
self.local_config.model_dump(by_alias=True),
692726
)
693727

728+
async def _create_global_devices(
729+
self,
730+
plugins: Mapping[DeviceName, AbstractComputePlugin],
731+
) -> GlobalDeviceMap:
732+
"""
733+
Discover available devices from all compute plugins.
734+
735+
This method iterates through all registered compute plugins and
736+
discovers the physical devices available from each. The result is
737+
a mapping of device names to GlobalDeviceInfo, which contains the
738+
plugin reference and the discovered devices.
739+
740+
This separation allows device discovery to be performed independently
741+
of allocation map creation, enabling more flexible device-based
742+
allocation strategies in the future.
743+
744+
Args:
745+
plugins: Mapping of device names to compute plugins
746+
747+
Returns:
748+
GlobalDeviceMap containing discovered devices from all plugins
749+
"""
750+
global_devices: dict[DeviceName, GlobalDeviceInfo] = {}
751+
for device_name, plugin in plugins.items():
752+
devices = await plugin.list_devices()
753+
global_devices[device_name] = GlobalDeviceInfo(
754+
plugin=plugin,
755+
devices=list(devices),
756+
)
757+
return global_devices
758+
694759
async def _scan_available_resources(self) -> Mapping[SlotName, Decimal]:
695760
return await self._agent_discovery.scan_available_resources({
696761
name: cctx.instance for name, cctx in self.computers.items()

tests/unit/agent/test_resources.py

Lines changed: 202 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import textwrap
44
import unittest.mock
55
import uuid
6+
from collections.abc import Sequence
67
from decimal import Decimal
78
from pathlib import Path
89
from unittest import mock
10+
from unittest.mock import AsyncMock, Mock
911

1012
import pytest
1113
from aioresponses import aioresponses
@@ -15,7 +17,15 @@
1517
from ai.backend.agent.affinity_map import AffinityMap, AffinityPolicy
1618
from ai.backend.agent.dummy.intrinsic import CPUPlugin, MemoryPlugin
1719
from ai.backend.agent.exception import FractionalResourceFragmented, InsufficientResource
18-
from ai.backend.agent.resources import ComputerContext, align_memory, scan_resource_usage_per_slot
20+
from ai.backend.agent.resources import (
21+
AbstractComputeDevice,
22+
AbstractComputePlugin,
23+
ComputerContext,
24+
GlobalDeviceInfo,
25+
ResourceAllocator,
26+
align_memory,
27+
scan_resource_usage_per_slot,
28+
)
1929
from ai.backend.agent.vendor import linux
2030
from ai.backend.common.types import DeviceId, DeviceName, KernelId, ResourceSlot, SlotName
2131

@@ -510,3 +520,194 @@ def test_align_memory():
510520
assert usable % align == 0
511521
assert usable + actual_reserved == orig
512522
assert 990 <= actual_reserved <= 1010
523+
524+
525+
class TestGlobalDeviceInfo:
526+
"""Tests for GlobalDeviceInfo dataclass."""
527+
528+
def test_initialization_with_devices(self) -> None:
529+
"""Verify GlobalDeviceInfo correctly stores plugin and devices."""
530+
mock_plugin = Mock(spec=AbstractComputePlugin)
531+
mock_device = Mock(spec=AbstractComputeDevice)
532+
mock_device.device_id = DeviceId("0")
533+
534+
info = GlobalDeviceInfo(plugin=mock_plugin, devices=[mock_device])
535+
536+
assert info.plugin is mock_plugin
537+
assert len(info.devices) == 1
538+
assert info.devices[0] is mock_device
539+
540+
def test_initialization_with_empty_devices(self) -> None:
541+
"""Verify GlobalDeviceInfo handles empty device list."""
542+
mock_plugin = Mock(spec=AbstractComputePlugin)
543+
544+
info = GlobalDeviceInfo(plugin=mock_plugin, devices=[])
545+
546+
assert info.plugin is mock_plugin
547+
assert len(info.devices) == 0
548+
assert isinstance(info.devices, Sequence)
549+
550+
def test_no_alloc_map_attribute(self) -> None:
551+
"""Verify GlobalDeviceInfo does not have alloc_map (separation of concerns)."""
552+
mock_plugin = Mock(spec=AbstractComputePlugin)
553+
554+
info = GlobalDeviceInfo(plugin=mock_plugin, devices=[])
555+
556+
assert not hasattr(info, "alloc_map")
557+
558+
559+
@pytest.mark.asyncio
560+
class TestCreateGlobalDevices:
561+
"""Tests for _create_global_devices method."""
562+
563+
async def test_discovers_devices_from_single_plugin(self) -> None:
564+
"""Verify device discovery works with a single plugin."""
565+
mock_device = Mock(spec=AbstractComputeDevice)
566+
mock_device.device_id = DeviceId("gpu-0")
567+
568+
mock_plugin = AsyncMock(spec=AbstractComputePlugin)
569+
mock_plugin.list_devices.return_value = [mock_device]
570+
571+
plugins = {DeviceName("cuda"): mock_plugin}
572+
573+
# Create a minimal ResourceAllocator mock to test the method
574+
allocator = Mock(spec=ResourceAllocator)
575+
allocator._create_global_devices = ResourceAllocator._create_global_devices.__get__(
576+
allocator, ResourceAllocator
577+
)
578+
579+
result = await allocator._create_global_devices(plugins)
580+
581+
assert DeviceName("cuda") in result
582+
assert result[DeviceName("cuda")].plugin is mock_plugin
583+
assert len(result[DeviceName("cuda")].devices) == 1
584+
mock_plugin.list_devices.assert_called_once()
585+
586+
async def test_discovers_devices_from_multiple_plugins(self) -> None:
587+
"""Verify correct aggregation of devices from CPU, memory, and accelerator plugins."""
588+
cpu_device = Mock(spec=AbstractComputeDevice)
589+
cpu_device.device_id = DeviceId("0")
590+
mem_device = Mock(spec=AbstractComputeDevice)
591+
mem_device.device_id = DeviceId("root")
592+
gpu_device = Mock(spec=AbstractComputeDevice)
593+
gpu_device.device_id = DeviceId("gpu-0")
594+
595+
cpu_plugin = AsyncMock(spec=AbstractComputePlugin)
596+
cpu_plugin.list_devices.return_value = [cpu_device]
597+
mem_plugin = AsyncMock(spec=AbstractComputePlugin)
598+
mem_plugin.list_devices.return_value = [mem_device]
599+
gpu_plugin = AsyncMock(spec=AbstractComputePlugin)
600+
gpu_plugin.list_devices.return_value = [gpu_device]
601+
602+
plugins = {
603+
DeviceName("cpu"): cpu_plugin,
604+
DeviceName("mem"): mem_plugin,
605+
DeviceName("cuda"): gpu_plugin,
606+
}
607+
608+
allocator = Mock(spec=ResourceAllocator)
609+
allocator._create_global_devices = ResourceAllocator._create_global_devices.__get__(
610+
allocator, ResourceAllocator
611+
)
612+
613+
result = await allocator._create_global_devices(plugins)
614+
615+
assert len(result) == 3
616+
assert DeviceName("cpu") in result
617+
assert DeviceName("mem") in result
618+
assert DeviceName("cuda") in result
619+
620+
# Verify each plugin's devices are correctly mapped
621+
assert result[DeviceName("cpu")].devices[0].device_id == DeviceId("0")
622+
assert result[DeviceName("mem")].devices[0].device_id == DeviceId("root")
623+
assert result[DeviceName("cuda")].devices[0].device_id == DeviceId("gpu-0")
624+
625+
626+
@pytest.mark.asyncio
627+
class TestEmptyPluginHandling:
628+
"""Tests for behavior when a plugin reports no devices."""
629+
630+
async def test_handles_plugin_with_no_devices(self) -> None:
631+
"""Verify behavior when a plugin reports no devices."""
632+
mock_plugin = AsyncMock(spec=AbstractComputePlugin)
633+
mock_plugin.list_devices.return_value = []
634+
635+
plugins = {DeviceName("mock"): mock_plugin}
636+
637+
allocator = Mock(spec=ResourceAllocator)
638+
allocator._create_global_devices = ResourceAllocator._create_global_devices.__get__(
639+
allocator, ResourceAllocator
640+
)
641+
642+
result = await allocator._create_global_devices(plugins)
643+
644+
assert DeviceName("mock") in result
645+
assert len(result[DeviceName("mock")].devices) == 0
646+
assert result[DeviceName("mock")].plugin is mock_plugin
647+
648+
649+
@pytest.mark.asyncio
650+
class TestCalculateTotalSlots:
651+
"""Tests for _calculate_total_slots method."""
652+
653+
async def test_calculate_total_slots_with_plugins(self) -> None:
654+
"""Verify _calculate_total_slots returns correct values using plugin.available_slots()."""
655+
# Create mock plugins that return specific slot amounts
656+
cpu_plugin = AsyncMock()
657+
cpu_plugin.available_slots.return_value = {SlotName("cpu"): Decimal(4)}
658+
659+
mem_plugin = AsyncMock()
660+
mem_plugin.available_slots.return_value = {SlotName("mem"): Decimal(8192)}
661+
662+
# Create mock computer contexts
663+
cpu_ctx = Mock(spec=ComputerContext)
664+
cpu_ctx.instance = cpu_plugin
665+
666+
mem_ctx = Mock(spec=ComputerContext)
667+
mem_ctx.instance = mem_plugin
668+
669+
# Create allocator mock with computers attribute
670+
allocator = Mock(spec=ResourceAllocator)
671+
allocator.computers = {
672+
DeviceName("cpu"): cpu_ctx,
673+
DeviceName("mem"): mem_ctx,
674+
}
675+
allocator._calculate_total_slots = ResourceAllocator._calculate_total_slots.__get__(
676+
allocator, ResourceAllocator
677+
)
678+
679+
total_slots = await allocator._calculate_total_slots()
680+
681+
assert total_slots[SlotName("cpu")] == Decimal(4)
682+
assert total_slots[SlotName("mem")] == Decimal(8192)
683+
cpu_plugin.available_slots.assert_called_once()
684+
mem_plugin.available_slots.assert_called_once()
685+
686+
async def test_calculate_total_slots_aggregates_same_slot_names(self) -> None:
687+
"""Verify _calculate_total_slots aggregates slots with the same name from multiple plugins."""
688+
# Create two plugins that both report the same "cuda.shares" slot
689+
gpu1_plugin = AsyncMock()
690+
gpu1_plugin.available_slots.return_value = {SlotName("cuda.shares"): Decimal("2.0")}
691+
692+
gpu2_plugin = AsyncMock()
693+
gpu2_plugin.available_slots.return_value = {SlotName("cuda.shares"): Decimal("3.0")}
694+
695+
gpu1_ctx = Mock(spec=ComputerContext)
696+
gpu1_ctx.instance = gpu1_plugin
697+
698+
gpu2_ctx = Mock(spec=ComputerContext)
699+
gpu2_ctx.instance = gpu2_plugin
700+
701+
allocator = Mock(spec=ResourceAllocator)
702+
allocator.computers = {
703+
DeviceName("cuda1"): gpu1_ctx,
704+
DeviceName("cuda2"): gpu2_ctx,
705+
}
706+
allocator._calculate_total_slots = ResourceAllocator._calculate_total_slots.__get__(
707+
allocator, ResourceAllocator
708+
)
709+
710+
total_slots = await allocator._calculate_total_slots()
711+
712+
# Slots with the same name should be aggregated
713+
assert total_slots[SlotName("cuda.shares")] == Decimal("5.0")

0 commit comments

Comments
 (0)