From e6a0138004861bbb6360586892b988d00210e4ba Mon Sep 17 00:00:00 2001 From: Jeroen Ketema Date: Mon, 22 Sep 2025 11:02:21 +0200 Subject: [PATCH] Swift: Assign indexes to extensions looking at all the extensions of a type --- swift/extractor/mangler/SwiftMangler.cpp | 46 ++++++++++++++++++++---- swift/extractor/mangler/SwiftMangler.h | 4 ++- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/swift/extractor/mangler/SwiftMangler.cpp b/swift/extractor/mangler/SwiftMangler.cpp index 7e1f1f0bfe87..91e3359fc4de 100644 --- a/swift/extractor/mangler/SwiftMangler.cpp +++ b/swift/extractor/mangler/SwiftMangler.cpp @@ -112,12 +112,11 @@ unsigned SwiftMangler::getExtensionIndex(const swift::ExtensionDecl* decl, if (auto found = preloadedExtensionIndexes.extract(decl)) { return found.mapped(); } - if (auto parentModule = llvm::dyn_cast(parent)) { - llvm::SmallVector siblings; - parentModule->getTopLevelDecls(siblings); - indexExtensions(siblings); + + if (llvm::isa(parent)) { + indexNominalTypeExtensions(decl->getExtendedNominal()->getExtensions(), parent); } else if (auto iterableParent = llvm::dyn_cast(parent)) { - indexExtensions(iterableParent->getAllMembers()); + indexIterableExtensions(iterableParent->getAllMembers()); } else { // TODO use a generic logging handle for Swift entities here, once it's available CODEQL_ASSERT(false, "non-local context must be module or iterable decl context"); @@ -128,13 +127,46 @@ unsigned SwiftMangler::getExtensionIndex(const swift::ExtensionDecl* decl, return found.mapped(); } -void SwiftMangler::indexExtensions(llvm::ArrayRef siblings) { +void SwiftMangler::indexNominalTypeExtensions(swift::ExtensionRange unsortedExtensions, + const swift::Decl* parent) { + std::vector extensions; + for (const auto& extension : unsortedExtensions) { + extensions.emplace_back(extension); + } + + std::ranges::sort(extensions, [](const auto& e1, const auto& e2) { + if (auto f1 = e1->getSourceFileName()) { + if (auto f2 = e2->getSourceFileName()) { + int result = f1.value().compare(f2.value()); + if (result != 0) { + return result < 0; + } + } + } + if (auto o1 = e1->getSourceOrder()) { + if (auto o2 = e2->getSourceOrder()) { + return o1.value() < o2.value(); + } + } + return false; + }); + + auto index = 0u; + for (const auto& extension : extensions) { + if (getParent(extension) == parent) { + preloadedExtensionIndexes.emplace(extension, index); + index++; + } + } +} + +void SwiftMangler::indexIterableExtensions(llvm::ArrayRef siblings) { auto index = 0u; for (auto sibling : siblings) { if (sibling->getKind() == swift::DeclKind::Extension) { preloadedExtensionIndexes.emplace(sibling, index); + ++index; } - ++index; } } diff --git a/swift/extractor/mangler/SwiftMangler.h b/swift/extractor/mangler/SwiftMangler.h index 2e3acbb9103c..d0b1b12f7e8b 100644 --- a/swift/extractor/mangler/SwiftMangler.h +++ b/swift/extractor/mangler/SwiftMangler.h @@ -112,7 +112,9 @@ class SwiftMangler : private swift::TypeVisitor, virtual SwiftMangledName fetch(const swift::TypeBase* type) = 0; SwiftMangledName fetch(swift::Type type) { return fetch(type.getPointer()); } - void indexExtensions(llvm::ArrayRef siblings); + void indexNominalTypeExtensions(swift::ExtensionRange unsortedExtensions, + const swift::Decl* parent); + void indexIterableExtensions(llvm::ArrayRef siblings); unsigned int getExtensionIndex(const swift::ExtensionDecl* decl, const swift::Decl* parent); static SwiftMangledName initMangled(const swift::TypeBase* type); SwiftMangledName initMangled(const swift::Decl* decl);