diff --git a/CMakePresets.json b/CMakePresets.json index 61f23932..836f5905 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -14,6 +14,7 @@ "installDir": "${sourceDir}/build/install/${presetName}", "toolchainFile": "$env{VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake", "cacheVariables": { + "WIL_ENABLE_ASAN": true, "CMAKE_CONFIGURATION_TYPES": "Debug;RelWithDebInfo;Release;MinSizeRel", "CMAKE_CXX_COMPILER": "cl", "CMAKE_C_COMPILER": "cl" @@ -29,9 +30,19 @@ } }, "cacheVariables": { + "WIL_ENABLE_ASAN": false, "CMAKE_CXX_COMPILER": "clang-cl", "CMAKE_C_COMPILER": "clang-cl" } + }, + { + "name": "clang-release", + "inherits": "clang", + "hidden": false, + "cacheVariables": { + "WIL_ENABLE_ASAN": true, + "WIL_ENABLE_UBSAN": true + } } ], "buildPresets": [ @@ -39,19 +50,13 @@ "name": "msvc-debug", "displayName": "MSVC Debug", "configurePreset": "msvc", - "configuration": "Debug", - "cacheVariables": { - "WIL_ENABLE_ASAN": true - } + "configuration": "Debug" }, { "name": "msvc-release", "displayName": "MSVC Release (debuggable)", "configurePreset": "msvc", - "configuration": "RelWithDebInfo", - "cacheVariables": { - "WIL_ENABLE_ASAN": true - } + "configuration": "RelWithDebInfo" }, { "name": "clang-debug", @@ -62,12 +67,8 @@ { "name": "clang-release", "displayName": "clang Release (debuggable)", - "configurePreset": "clang", - "configuration": "RelWithDebInfo", - "cacheVariables": { - "WIL_ENABLE_ASAN": true, - "WIL_ENABLE_UBSAN": true - } + "configurePreset": "clang-release", + "configuration": "RelWithDebInfo" } ], "testPresets": [ diff --git a/_codeql_detected_source_root b/_codeql_detected_source_root new file mode 120000 index 00000000..945c9b46 --- /dev/null +++ b/_codeql_detected_source_root @@ -0,0 +1 @@ +. \ No newline at end of file diff --git a/include/wil/result.h b/include/wil/result.h index f4de6c6b..15ee33c8 100644 --- a/include/wil/result.h +++ b/include/wil/result.h @@ -375,73 +375,91 @@ namespace details_abi template class ThreadLocalStorage { - public: - ThreadLocalStorage(const ThreadLocalStorage&) = delete; - ThreadLocalStorage& operator=(const ThreadLocalStorage&) = delete; - - ThreadLocalStorage() = default; + struct Node + { + Node* next{nullptr}; + DWORD threadId = 0xffffffffU; + T value{}; + }; - ~ThreadLocalStorage() WI_NOEXCEPT + struct Bucket { - for (auto& entry : m_hashArray) + wil::srwlock lock; + Node* head{nullptr}; + + ~Bucket() WI_NOEXCEPT { - Node* pNode = entry; - while (pNode != nullptr) + // Cleanup in a loop rather than recursively + while (head) { - auto pCurrent = pNode; -#pragma warning(push) -#pragma warning(disable : 6001) // https://github.com/microsoft/wil/issues/164 - pNode = pNode->pNext; -#pragma warning(pop) - pCurrent->~Node(); - ::HeapFree(::GetProcessHeap(), 0, pCurrent); + auto tmp = head; + head = tmp->next; + tmp->~Node(); + details::FreeProcessHeap(tmp); } - entry = nullptr; } - } + }; + + Bucket m_hashArray[13]{}; + + public: + ThreadLocalStorage(const ThreadLocalStorage&) = delete; + ThreadLocalStorage& operator=(const ThreadLocalStorage&) = delete; + + ThreadLocalStorage() = default; + ~ThreadLocalStorage() WI_NOEXCEPT = default; // Note: Can return nullptr even when (shouldAllocate == true) upon allocation failure T* GetLocal(bool shouldAllocate = false) WI_NOEXCEPT { + // Get the current thread ID DWORD const threadId = ::GetCurrentThreadId(); + + // Determine the appropriate bucket for this thread size_t const index = ((threadId >> 2) % ARRAYSIZE(m_hashArray)); // Reduce hash collisions; thread IDs are even. - for (auto pNode = m_hashArray[index]; pNode != nullptr; pNode = pNode->pNext) + Bucket& bucket = m_hashArray[index]; + + // Lock the bucket and search for an existing entry { - if (pNode->threadId == threadId) + auto lock = bucket.lock.lock_shared(); + for (auto pNode = bucket.head; pNode != nullptr; pNode = pNode->next) { - return &pNode->value; + if (pNode->threadId == threadId) + { + return &pNode->value; + } } } - if (shouldAllocate) + if (!shouldAllocate) { - if (auto pNewRaw = details::ProcessHeapAlloc(0, sizeof(Node))) - { - auto pNew = new (pNewRaw) Node{threadId}; + return nullptr; + } - Node* pFirst; - do - { - pFirst = m_hashArray[index]; - pNew->pNext = pFirst; - } while (::InterlockedCompareExchangePointer(reinterpret_cast(m_hashArray + index), pNew, pFirst) != - pFirst); + // No entry for us, make a new one and insert it at the head + void* newNodeStore = details::ProcessHeapAlloc(0, sizeof(Node)); + if (!newNodeStore) + { + return nullptr; + } + auto node = new (newNodeStore) Node{nullptr, threadId}; - return &pNew->value; + // Look again and insert the new node + auto lock = bucket.lock.lock_exclusive(); + for (auto pNode = bucket.head; pNode != nullptr; pNode = pNode->next) + { + if (pNode->threadId == threadId) + { + node->~Node(); + details::FreeProcessHeap(node); + return &pNode->value; } } - return nullptr; - } - private: - struct Node - { - DWORD threadId = 0xffffffffU; - Node* pNext = nullptr; - T value{}; - }; - - Node* volatile m_hashArray[10]{}; + node->next = bucket.head; + bucket.head = node; + return &bucket.head->value; + } }; struct ThreadLocalFailureInfo diff --git a/tests/ResultTests.cpp b/tests/ResultTests.cpp index 8939c4a8..ed10c83d 100644 --- a/tests/ResultTests.cpp +++ b/tests/ResultTests.cpp @@ -12,6 +12,10 @@ #include "common.h" +#include +#include +#include + static volatile long objectCount = 0; struct SharedObject { @@ -773,3 +777,61 @@ TEST_CASE("ResultTests::ReportDoesNotChangeLastError", "[result]") LOG_IF_WIN32_BOOL_FALSE(FALSE); REQUIRE(::GetLastError() == ERROR_ABIOS_ERROR); } + +TEST_CASE("ResultTests::ThreadLocalStorage", "[result]") +{ + constexpr int NUM_THREADS = 10; + constexpr int ITERATIONS = 1000; + + wil::details_abi::ThreadLocalStorage storage; + std::atomic errors{0}; + std::vector threads; + + // Create multiple threads that will access the thread local storage concurrently + for (int i = 0; i < NUM_THREADS; ++i) + { + threads.emplace_back([&storage, &errors, i]() { + for (int j = 0; j < ITERATIONS; ++j) + { + // Get or create thread local value + int* pValue = storage.GetLocal(true); + if (!pValue) + { + errors.fetch_add(1); + continue; + } + + // First time should be zero-initialized + if (j == 0 && *pValue != 0) + { + errors.fetch_add(1); + } + + // Set a thread-specific value + *pValue = i * 1000 + j; + + // Verify we get the same pointer on subsequent calls + int* pValue2 = storage.GetLocal(false); + if (pValue != pValue2) + { + errors.fetch_add(1); + } + + // Verify the value is correct + if (*pValue2 != i * 1000 + j) + { + errors.fetch_add(1); + } + } + }); + } + + // Wait for all threads to complete + for (auto& thread : threads) + { + thread.join(); + } + + // Verify no errors occurred + REQUIRE(errors.load() == 0); +}