Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions swift/extractor/mangler/SwiftMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,11 @@ unsigned SwiftMangler::getExtensionIndex(const swift::ExtensionDecl* decl,
if (auto found = preloadedExtensionIndexes.extract(decl)) {
return found.mapped();
}
if (auto parentModule = llvm::dyn_cast<swift::ModuleDecl>(parent)) {
llvm::SmallVector<swift::Decl*> siblings;
parentModule->getTopLevelDecls(siblings);
indexExtensions(siblings);

if (llvm::isa<swift::ModuleDecl>(parent)) {
indexNominalTypeExtensions(decl->getExtendedNominal()->getExtensions(), parent);
} else if (auto iterableParent = llvm::dyn_cast<swift::IterableDeclContext>(parent)) {
indexExtensions(iterableParent->getAllMembers());
indexIterableExtensions(iterableParent->getAllMembers());
} else {
// TODO use a generic logging handle for Swift entities here, once it's available
CODEQL_ASSERT(false, "non-local context must be module or iterable decl context");
Expand All @@ -128,13 +127,46 @@ unsigned SwiftMangler::getExtensionIndex(const swift::ExtensionDecl* decl,
return found.mapped();
}

void SwiftMangler::indexExtensions(llvm::ArrayRef<swift::Decl*> siblings) {
void SwiftMangler::indexNominalTypeExtensions(swift::ExtensionRange unsortedExtensions,
const swift::Decl* parent) {
std::vector<const swift::ExtensionDecl*> extensions;
for (const auto& extension : unsortedExtensions) {
extensions.emplace_back(extension);
}

std::ranges::sort(extensions, [](const auto& e1, const auto& e2) {
if (auto f1 = e1->getSourceFileName()) {
if (auto f2 = e2->getSourceFileName()) {
int result = f1.value().compare(f2.value());
if (result != 0) {
return result < 0;
}
}
}
if (auto o1 = e1->getSourceOrder()) {
if (auto o2 = e2->getSourceOrder()) {
return o1.value() < o2.value();
}
}
return false;
});

auto index = 0u;
for (const auto& extension : extensions) {
if (getParent(extension) == parent) {
preloadedExtensionIndexes.emplace(extension, index);
index++;
}
}
}

void SwiftMangler::indexIterableExtensions(llvm::ArrayRef<swift::Decl*> siblings) {
auto index = 0u;
for (auto sibling : siblings) {
if (sibling->getKind() == swift::DeclKind::Extension) {
preloadedExtensionIndexes.emplace(sibling, index);
++index;
}
++index;
}
}

Expand Down
4 changes: 3 additions & 1 deletion swift/extractor/mangler/SwiftMangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ class SwiftMangler : private swift::TypeVisitor<SwiftMangler, SwiftMangledName>,
virtual SwiftMangledName fetch(const swift::TypeBase* type) = 0;
SwiftMangledName fetch(swift::Type type) { return fetch(type.getPointer()); }

void indexExtensions(llvm::ArrayRef<swift::Decl*> siblings);
void indexNominalTypeExtensions(swift::ExtensionRange unsortedExtensions,
const swift::Decl* parent);
void indexIterableExtensions(llvm::ArrayRef<swift::Decl*> siblings);
unsigned int getExtensionIndex(const swift::ExtensionDecl* decl, const swift::Decl* parent);
static SwiftMangledName initMangled(const swift::TypeBase* type);
SwiftMangledName initMangled(const swift::Decl* decl);
Expand Down
Loading