From f179fb4496ff8a9ec5c02f1cecb45197e80edc22 Mon Sep 17 00:00:00 2001 From: "Neil R. Spruit" Date: Fri, 16 Jan 2026 10:30:45 -0800 Subject: [PATCH] Ensure explicit multidriver teardown during context destroy - When multiple drivers are present, delayed init did not always update the handles in the alldrivers such that the drivers were unloaded at teardown. To fix this, the loader now creates a unique list of driver library handles opened during the process that need to be released during context destruction. - Unit tests added to ensure this functionality continues to operate correctly. Signed-off-by: Neil R. Spruit --- source/loader/ze_loader.cpp | 43 ++- test/CMakeLists.txt | 30 +- test/driver_teardown_unit_tests.cpp | 496 ++++++++++++++++++++++++++++ 3 files changed, 554 insertions(+), 15 deletions(-) create mode 100644 test/driver_teardown_unit_tests.cpp diff --git a/source/loader/ze_loader.cpp b/source/loader/ze_loader.cpp index 333f159f..88369a95 100644 --- a/source/loader/ze_loader.cpp +++ b/source/loader/ze_loader.cpp @@ -1,6 +1,6 @@ /* * - * Copyright (C) 2019-2025 Intel Corporation + * Copyright (C) 2019-2026 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -9,6 +9,7 @@ #include "driver_discovery.h" #include +#include #ifdef __linux__ #include @@ -824,18 +825,36 @@ namespace loader } } - for( auto& drv : allDrivers ) - { + // Collect all unique driver handles from allDrivers, zeDrivers, and zesDrivers + // to ensure we free each library exactly once, avoiding double-free issues + std::set uniqueHandles; + for (const auto& drv : allDrivers) { if (drv.handle) { - auto free_result = FREE_DRIVER_LIBRARY( drv.handle ); - auto failure = FREE_DRIVER_LIBRARY_FAILURE_CHECK(free_result); - if (debugTraceEnabled && failure) { - GET_LIBRARY_ERROR(freeLibraryErrorValue); - if (!freeLibraryErrorValue.empty()) { - std::string errorMessage = "Free Library Failed for " + drv.name + " With "; - debug_trace_message(errorMessage, freeLibraryErrorValue); - freeLibraryErrorValue.clear(); - } + uniqueHandles.insert(drv.handle); + } + } + for (const auto& drv : zeDrivers) { + if (drv.handle) { + uniqueHandles.insert(drv.handle); + } + } + for (const auto& drv : zesDrivers) { + if (drv.handle) { + uniqueHandles.insert(drv.handle); + } + } + + // Free each unique driver library exactly once + for (auto handle : uniqueHandles) + { + auto free_result = FREE_DRIVER_LIBRARY( handle ); + auto failure = FREE_DRIVER_LIBRARY_FAILURE_CHECK(free_result); + if (debugTraceEnabled && failure) { + GET_LIBRARY_ERROR(freeLibraryErrorValue); + if (!freeLibraryErrorValue.empty()) { + std::string errorMessage = "Free Library Failed With "; + debug_trace_message(errorMessage, freeLibraryErrorValue); + freeLibraryErrorValue.clear(); } } } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index e13c178f..e60bd3c3 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -9,10 +9,13 @@ add_executable( loader_tracing_layer.cpp ) -# Only include driver_ordering_unit_tests for static builds or non-Windows platforms -# as it requires internal loader symbols that are not exported in Windows DLLs +# Only include driver_ordering_unit_tests and driver_teardown_unit_tests for static builds or non-Windows platforms +# as they require internal loader symbols that are not exported in Windows DLLs if(BUILD_STATIC OR NOT WIN32) - target_sources(tests PRIVATE driver_ordering_unit_tests.cpp) + target_sources(tests PRIVATE + driver_ordering_unit_tests.cpp + driver_teardown_unit_tests.cpp + ) endif() # For builds on non-Windows platforms, include init_driver_unit_tests @@ -745,7 +748,28 @@ set_property(TEST driver_ordering_trim_function PROPERTY ENVIRONMENT "ZE_ENABLE_ add_test(NAME driver_ordering_parse_driver_order COMMAND tests --gtest_filter=DriverOrderingHelperFunctionsTest.ParseDriverOrder_*) set_property(TEST driver_ordering_parse_driver_order PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1") +# Driver Teardown Unit Tests +add_test(NAME driver_teardown_unit_tests COMMAND tests --gtest_filter=DriverTeardownUnitTest.*) +set_property(TEST driver_teardown_unit_tests PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1") + +# Individual Driver Teardown Unit Tests for better granular reporting +add_test(NAME driver_teardown_basic_collection COMMAND tests --gtest_filter=DriverTeardownUnitTest.NoDrivers_*:DriverTeardownUnitTest.AllNullHandles_*:DriverTeardownUnitTest.SingleDriverInAllDrivers_*:DriverTeardownUnitTest.MultipleUniqueHandlesInAllDrivers_*) +set_property(TEST driver_teardown_basic_collection PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1") + +add_test(NAME driver_teardown_duplicate_handles COMMAND tests --gtest_filter=DriverTeardownUnitTest.SameHandleInAllThreeVectors_*:DriverTeardownUnitTest.MultipleDuplicateHandlesAcrossVectors_*:DriverTeardownUnitTest.SameHandleInZeAndZesOnly_*) +set_property(TEST driver_teardown_duplicate_handles PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1") + +add_test(NAME driver_teardown_handle_distribution COMMAND tests --gtest_filter=DriverTeardownUnitTest.OnlyZeDriversHaveHandles_*:DriverTeardownUnitTest.OnlyZesDriversHaveHandles_*:DriverTeardownUnitTest.DifferentHandlesInEachVector_*) +set_property(TEST driver_teardown_handle_distribution PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1") + +add_test(NAME driver_teardown_realistic_scenarios COMMAND tests --gtest_filter=DriverTeardownUnitTest.SingleDriverScenario_*:DriverTeardownUnitTest.MultipleDriversScenario_*:DriverTeardownUnitTest.PartialInitialization_*:DriverTeardownUnitTest.DifferentDriversInitializedInZeVsZes_*) +set_property(TEST driver_teardown_realistic_scenarios PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1") + +add_test(NAME driver_teardown_edge_cases COMMAND tests --gtest_filter=DriverTeardownUnitTest.MixedNullAndValidHandles_*:DriverTeardownUnitTest.VectorSizeMismatch_*:DriverTeardownUnitTest.LargeNumberOfDrivers_*:DriverTeardownUnitTest.CustomDriverScenario_*) +set_property(TEST driver_teardown_edge_cases PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1") +add_test(NAME driver_teardown_driver_types COMMAND tests --gtest_filter=DriverTeardownUnitTest.MixedDriverTypes_*:DriverTeardownUnitTest.ComplexRealWorldScenario_*) +set_property(TEST driver_teardown_driver_types PROPERTY ENVIRONMENT "ZE_ENABLE_LOADER_DEBUG_TRACE=1;ZE_ENABLE_NULL_DRIVER=1") # Init Driver Unit Tests add_test(NAME init_driver_unit_tests COMMAND tests --gtest_filter=InitDriverUnitTest.*) if (MSVC) diff --git a/test/driver_teardown_unit_tests.cpp b/test/driver_teardown_unit_tests.cpp new file mode 100644 index 00000000..87208d1f --- /dev/null +++ b/test/driver_teardown_unit_tests.cpp @@ -0,0 +1,496 @@ +/* + * + * Copyright (C) 2026 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#include "gtest/gtest.h" + +#include "source/loader/ze_loader_internal.h" +#include "ze_api.h" + +#include +#include +#include + +#if defined(_WIN32) + #include +#else + #include +#endif + +namespace { + +// Mock handle values for testing (using distinct non-null values) +#if defined(_WIN32) + #define MOCK_HANDLE_1 reinterpret_cast(0x1000) + #define MOCK_HANDLE_2 reinterpret_cast(0x2000) + #define MOCK_HANDLE_3 reinterpret_cast(0x3000) + #define MOCK_HANDLE_4 reinterpret_cast(0x4000) +#else + #define MOCK_HANDLE_1 reinterpret_cast(0x1000) + #define MOCK_HANDLE_2 reinterpret_cast(0x2000) + #define MOCK_HANDLE_3 reinterpret_cast(0x3000) + #define MOCK_HANDLE_4 reinterpret_cast(0x4000) +#endif + +// Helper function to create a mock driver with specific handle +loader::driver_t createMockDriverWithHandle(const std::string& name, HMODULE handle, loader::zel_driver_type_t type = loader::ZEL_DRIVER_TYPE_GPU) { + loader::driver_t driver; + driver.name = name; + driver.handle = handle; + driver.driverType = type; + driver.initStatus = ZE_RESULT_SUCCESS; + driver.driverInuse = false; + driver.ddiInitialized = false; + return driver; +} + +// Helper function to collect unique handles (mimics destructor logic) +std::set collectUniqueHandles( + const loader::driver_vector_t& allDrivers, + const loader::driver_vector_t& zeDrivers, + const loader::driver_vector_t& zesDrivers) +{ + std::set uniqueHandles; + + for (const auto& drv : allDrivers) { + if (drv.handle) { + uniqueHandles.insert(drv.handle); + } + } + for (const auto& drv : zeDrivers) { + if (drv.handle) { + uniqueHandles.insert(drv.handle); + } + } + for (const auto& drv : zesDrivers) { + if (drv.handle) { + uniqueHandles.insert(drv.handle); + } + } + + return uniqueHandles; +} + +// Test fixture for driver teardown functionality +class DriverTeardownUnitTest : public ::testing::Test { +protected: + void SetUp() override { + // Clear all driver vectors before each test + allDrivers.clear(); + zeDrivers.clear(); + zesDrivers.clear(); + } + + void TearDown() override { + // Cleanup after each test + allDrivers.clear(); + zeDrivers.clear(); + zesDrivers.clear(); + } + + loader::driver_vector_t allDrivers; + loader::driver_vector_t zeDrivers; + loader::driver_vector_t zesDrivers; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Tests for basic handle collection +/////////////////////////////////////////////////////////////////////////////// + +TEST_F(DriverTeardownUnitTest, NoDrivers_ShouldReturnEmptySet) { + // Arrange - all vectors are empty from SetUp + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert + EXPECT_EQ(uniqueHandles.size(), 0); +} + +TEST_F(DriverTeardownUnitTest, AllNullHandles_ShouldReturnEmptySet) { + // Arrange + allDrivers.push_back(createMockDriverWithHandle("driver1", nullptr)); + allDrivers.push_back(createMockDriverWithHandle("driver2", nullptr)); + zeDrivers.push_back(createMockDriverWithHandle("driver1", nullptr)); + zesDrivers.push_back(createMockDriverWithHandle("driver1", nullptr)); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert + EXPECT_EQ(uniqueHandles.size(), 0); +} + +TEST_F(DriverTeardownUnitTest, SingleDriverInAllDrivers_ShouldReturnOneHandle) { + // Arrange + allDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert + EXPECT_EQ(uniqueHandles.size(), 1); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); +} + +TEST_F(DriverTeardownUnitTest, MultipleUniqueHandlesInAllDrivers_ShouldReturnAllHandles) { + // Arrange + allDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + allDrivers.push_back(createMockDriverWithHandle("driver2", MOCK_HANDLE_2)); + allDrivers.push_back(createMockDriverWithHandle("driver3", MOCK_HANDLE_3)); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert + EXPECT_EQ(uniqueHandles.size(), 3); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_2), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_3), uniqueHandles.end()); +} + +/////////////////////////////////////////////////////////////////////////////// +// Tests for duplicate handle scenarios +/////////////////////////////////////////////////////////////////////////////// + +TEST_F(DriverTeardownUnitTest, SameHandleInAllThreeVectors_ShouldReturnOneHandle) { + // Arrange - simulate the case where all vectors have copies with the same handle + allDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zeDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zesDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert - should deduplicate to single handle + EXPECT_EQ(uniqueHandles.size(), 1); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); +} + +TEST_F(DriverTeardownUnitTest, MultipleDuplicateHandlesAcrossVectors_ShouldDeduplicate) { + // Arrange + allDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + allDrivers.push_back(createMockDriverWithHandle("driver2", MOCK_HANDLE_2)); + + zeDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zeDrivers.push_back(createMockDriverWithHandle("driver2", MOCK_HANDLE_2)); + + zesDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zesDrivers.push_back(createMockDriverWithHandle("driver2", MOCK_HANDLE_2)); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert - should have only 2 unique handles + EXPECT_EQ(uniqueHandles.size(), 2); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_2), uniqueHandles.end()); +} + +TEST_F(DriverTeardownUnitTest, SameHandleInZeAndZesOnly_ShouldReturnOneHandle) { + // Arrange - allDrivers has null, but zeDrivers and zesDrivers have handles + allDrivers.push_back(createMockDriverWithHandle("driver1", nullptr)); + zeDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zesDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert - should deduplicate the handle from zeDrivers and zesDrivers + EXPECT_EQ(uniqueHandles.size(), 1); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); +} + +/////////////////////////////////////////////////////////////////////////////// +// Tests for handle distribution across vectors +/////////////////////////////////////////////////////////////////////////////// + +TEST_F(DriverTeardownUnitTest, OnlyZeDriversHaveHandles_ShouldReturnThoseHandles) { + // Arrange + allDrivers.push_back(createMockDriverWithHandle("driver1", nullptr)); + allDrivers.push_back(createMockDriverWithHandle("driver2", nullptr)); + + zeDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zeDrivers.push_back(createMockDriverWithHandle("driver2", MOCK_HANDLE_2)); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert + EXPECT_EQ(uniqueHandles.size(), 2); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_2), uniqueHandles.end()); +} + +TEST_F(DriverTeardownUnitTest, OnlyZesDriversHaveHandles_ShouldReturnThoseHandles) { + // Arrange + allDrivers.push_back(createMockDriverWithHandle("driver1", nullptr)); + allDrivers.push_back(createMockDriverWithHandle("driver2", nullptr)); + + zesDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zesDrivers.push_back(createMockDriverWithHandle("driver2", MOCK_HANDLE_2)); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert + EXPECT_EQ(uniqueHandles.size(), 2); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_2), uniqueHandles.end()); +} + +TEST_F(DriverTeardownUnitTest, DifferentHandlesInEachVector_ShouldReturnAllUniqueHandles) { + // Arrange - each vector has different handles + allDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zeDrivers.push_back(createMockDriverWithHandle("driver2", MOCK_HANDLE_2)); + zesDrivers.push_back(createMockDriverWithHandle("driver3", MOCK_HANDLE_3)); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert + EXPECT_EQ(uniqueHandles.size(), 3); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_2), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_3), uniqueHandles.end()); +} + +/////////////////////////////////////////////////////////////////////////////// +// Tests for realistic initialization scenarios +/////////////////////////////////////////////////////////////////////////////// + +TEST_F(DriverTeardownUnitTest, SingleDriverScenario_AllVectorsHaveSameHandle) { + // Arrange - simulate single driver initialization where all vectors get the handle + allDrivers.push_back(createMockDriverWithHandle("gpu_driver", MOCK_HANDLE_1)); + zeDrivers.push_back(createMockDriverWithHandle("gpu_driver", MOCK_HANDLE_1)); + zesDrivers.push_back(createMockDriverWithHandle("gpu_driver", MOCK_HANDLE_1)); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert - should deduplicate to single handle + EXPECT_EQ(uniqueHandles.size(), 1); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); +} + +TEST_F(DriverTeardownUnitTest, MultipleDriversScenario_AllDriversHasNulls_OthersHaveHandles) { + // Arrange - simulate multiple drivers where allDrivers initially has nulls, + // but zeDrivers and zesDrivers get handles loaded via init_driver() + allDrivers.push_back(createMockDriverWithHandle("driver1", nullptr)); + allDrivers.push_back(createMockDriverWithHandle("driver2", nullptr)); + allDrivers.push_back(createMockDriverWithHandle("driver3", nullptr)); + + // After init, zeDrivers and zesDrivers have handles loaded + zeDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zeDrivers.push_back(createMockDriverWithHandle("driver2", MOCK_HANDLE_2)); + zeDrivers.push_back(createMockDriverWithHandle("driver3", MOCK_HANDLE_3)); + + zesDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zesDrivers.push_back(createMockDriverWithHandle("driver2", MOCK_HANDLE_2)); + zesDrivers.push_back(createMockDriverWithHandle("driver3", MOCK_HANDLE_3)); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert - should deduplicate to 3 unique handles + EXPECT_EQ(uniqueHandles.size(), 3); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_2), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_3), uniqueHandles.end()); +} + +TEST_F(DriverTeardownUnitTest, PartialInitialization_SomeDriversInitialized) { + // Arrange - simulate partial initialization where only some drivers loaded + allDrivers.push_back(createMockDriverWithHandle("driver1", nullptr)); + allDrivers.push_back(createMockDriverWithHandle("driver2", nullptr)); + allDrivers.push_back(createMockDriverWithHandle("driver3", nullptr)); + + // Only first two drivers initialized in zeDrivers + zeDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zeDrivers.push_back(createMockDriverWithHandle("driver2", MOCK_HANDLE_2)); + zeDrivers.push_back(createMockDriverWithHandle("driver3", nullptr)); + + // All three initialized in zesDrivers + zesDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zesDrivers.push_back(createMockDriverWithHandle("driver2", MOCK_HANDLE_2)); + zesDrivers.push_back(createMockDriverWithHandle("driver3", MOCK_HANDLE_3)); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert - should have all 3 handles + EXPECT_EQ(uniqueHandles.size(), 3); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_2), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_3), uniqueHandles.end()); +} + +TEST_F(DriverTeardownUnitTest, DifferentDriversInitializedInZeVsZes) { + // Arrange - simulate where zeDrivers and zesDrivers initialized different sets + allDrivers.push_back(createMockDriverWithHandle("driver1", nullptr)); + allDrivers.push_back(createMockDriverWithHandle("driver2", nullptr)); + + // Only driver1 initialized in zeDrivers + zeDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zeDrivers.push_back(createMockDriverWithHandle("driver2", nullptr)); + + // Only driver2 initialized in zesDrivers + zesDrivers.push_back(createMockDriverWithHandle("driver1", nullptr)); + zesDrivers.push_back(createMockDriverWithHandle("driver2", MOCK_HANDLE_2)); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert - should have both handles + EXPECT_EQ(uniqueHandles.size(), 2); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_2), uniqueHandles.end()); +} + +/////////////////////////////////////////////////////////////////////////////// +// Tests for edge cases and mixed scenarios +/////////////////////////////////////////////////////////////////////////////// + +TEST_F(DriverTeardownUnitTest, MixedNullAndValidHandles_ShouldOnlyReturnValid) { + // Arrange + allDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + allDrivers.push_back(createMockDriverWithHandle("driver2", nullptr)); + allDrivers.push_back(createMockDriverWithHandle("driver3", MOCK_HANDLE_2)); + + zeDrivers.push_back(createMockDriverWithHandle("driver1", nullptr)); + zeDrivers.push_back(createMockDriverWithHandle("driver2", MOCK_HANDLE_3)); + + zesDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert - should have 3 unique valid handles + EXPECT_EQ(uniqueHandles.size(), 3); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_2), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_3), uniqueHandles.end()); +} + +TEST_F(DriverTeardownUnitTest, VectorSizeMismatch_ShouldHandleCorrectly) { + // Arrange - vectors have different sizes + allDrivers.push_back(createMockDriverWithHandle("driver1", nullptr)); + allDrivers.push_back(createMockDriverWithHandle("driver2", nullptr)); + allDrivers.push_back(createMockDriverWithHandle("driver3", nullptr)); + + zeDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zeDrivers.push_back(createMockDriverWithHandle("driver2", MOCK_HANDLE_2)); + // zeDrivers has only 2 elements + + zesDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + // zesDrivers has only 1 element + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert - should handle size mismatch and return unique handles + EXPECT_EQ(uniqueHandles.size(), 2); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_2), uniqueHandles.end()); +} + +TEST_F(DriverTeardownUnitTest, LargeNumberOfDrivers_ShouldHandleEfficiently) { + // Arrange - simulate a large number of drivers + const size_t numDrivers = 100; + + for (size_t i = 0; i < numDrivers; ++i) { + HMODULE handle = reinterpret_cast(0x1000 + i * 0x100); + allDrivers.push_back(createMockDriverWithHandle("driver" + std::to_string(i), handle)); + zeDrivers.push_back(createMockDriverWithHandle("driver" + std::to_string(i), handle)); + zesDrivers.push_back(createMockDriverWithHandle("driver" + std::to_string(i), handle)); + } + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert - should deduplicate all to numDrivers unique handles + EXPECT_EQ(uniqueHandles.size(), numDrivers); +} + +TEST_F(DriverTeardownUnitTest, CustomDriverScenario_ShouldIncludeCustomHandles) { + // Arrange - simulate custom drivers alongside standard drivers + allDrivers.push_back(createMockDriverWithHandle("standard_driver", MOCK_HANDLE_1)); + allDrivers.push_back(createMockDriverWithHandle("custom_driver", MOCK_HANDLE_2)); + + auto customDriver = createMockDriverWithHandle("custom_driver", MOCK_HANDLE_2); + customDriver.customDriver = true; + + zeDrivers.push_back(createMockDriverWithHandle("standard_driver", MOCK_HANDLE_1)); + zeDrivers.push_back(customDriver); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert + EXPECT_EQ(uniqueHandles.size(), 2); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_2), uniqueHandles.end()); +} + +/////////////////////////////////////////////////////////////////////////////// +// Tests for driver type specific scenarios +/////////////////////////////////////////////////////////////////////////////// + +TEST_F(DriverTeardownUnitTest, MixedDriverTypes_ShouldCollectAllHandles) { + // Arrange - different driver types + allDrivers.push_back(createMockDriverWithHandle("discrete_gpu", MOCK_HANDLE_1, + loader::ZEL_DRIVER_TYPE_DISCRETE_GPU)); + allDrivers.push_back(createMockDriverWithHandle("integrated_gpu", MOCK_HANDLE_2, + loader::ZEL_DRIVER_TYPE_INTEGRATED_GPU)); + allDrivers.push_back(createMockDriverWithHandle("npu", MOCK_HANDLE_3, + loader::ZEL_DRIVER_TYPE_NPU)); + + zeDrivers = allDrivers; // Copy to zeDrivers + zesDrivers = allDrivers; // Copy to zesDrivers + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert - should deduplicate regardless of driver type + EXPECT_EQ(uniqueHandles.size(), 3); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_2), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_3), uniqueHandles.end()); +} + +TEST_F(DriverTeardownUnitTest, ComplexRealWorldScenario_ShouldHandleCorrectly) { + // Arrange - simulate a complex real-world scenario + // Initial state: allDrivers has some handles, some nulls + allDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + allDrivers.push_back(createMockDriverWithHandle("driver2", nullptr)); + allDrivers.push_back(createMockDriverWithHandle("driver3", nullptr)); + allDrivers.push_back(createMockDriverWithHandle("driver4", MOCK_HANDLE_4)); + + // After zeInit: some drivers initialized + zeDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zeDrivers.push_back(createMockDriverWithHandle("driver2", MOCK_HANDLE_2)); + zeDrivers.push_back(createMockDriverWithHandle("driver3", nullptr)); // Failed to init + zeDrivers.push_back(createMockDriverWithHandle("driver4", MOCK_HANDLE_4)); + + // After zesInit: different set initialized + zesDrivers.push_back(createMockDriverWithHandle("driver1", MOCK_HANDLE_1)); + zesDrivers.push_back(createMockDriverWithHandle("driver2", nullptr)); // Not needed for sysman + zesDrivers.push_back(createMockDriverWithHandle("driver3", MOCK_HANDLE_3)); // Sysman initialized this + zesDrivers.push_back(createMockDriverWithHandle("driver4", MOCK_HANDLE_4)); + + // Act + auto uniqueHandles = collectUniqueHandles(allDrivers, zeDrivers, zesDrivers); + + // Assert - should collect all unique handles + EXPECT_EQ(uniqueHandles.size(), 4); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_1), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_2), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_3), uniqueHandles.end()); + EXPECT_NE(uniqueHandles.find(MOCK_HANDLE_4), uniqueHandles.end()); +} + +} // namespace