diff --git a/include/swift/AST/Stmt.h b/include/swift/AST/Stmt.h index 50238fef5fc39..ac40b5d206fa0 100644 --- a/include/swift/AST/Stmt.h +++ b/include/swift/AST/Stmt.h @@ -725,6 +725,10 @@ class alignas(1 << PatternAlignInBits) StmtConditionElement { bool rebindsSelf(ASTContext &Ctx, bool requiresCaptureListRef = false, bool requireLoadExpr = false) const; + /// Returns the synthesized RHS for a shorthand if let (eg. `if let x`), or + /// null if this element does not represent a shorthand if let. + Expr *getSynthesizedShorthandInitOrNull() const; + SourceLoc getStartLoc() const; SourceLoc getEndLoc() const; SourceRange getSourceRange() const; diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp index 45abe32bce521..835f8c736eb44 100644 --- a/lib/AST/Stmt.cpp +++ b/lib/AST/Stmt.cpp @@ -597,6 +597,28 @@ bool StmtConditionElement::rebindsSelf(ASTContext &Ctx, return false; } +Expr *StmtConditionElement::getSynthesizedShorthandInitOrNull() const { + auto *init = getInitializerOrNull(); + if (!init) + return nullptr; + + auto *pattern = dyn_cast_or_null(getPattern()); + if (!pattern) + return nullptr; + + auto *var = pattern->getSubPattern()->getSingleVar(); + if (!var) + return nullptr; + + // If the right-hand side has the same location as the variable, it was + // synthesized. + if (var->getLoc().isValid() && var->getLoc() == init->getStartLoc() && + init->getStartLoc() == init->getEndLoc()) { + return init; + } + return nullptr; +} + SourceRange ConditionalPatternBindingInfo::getSourceRange() const { SourceLoc Start; if (IntroducerLoc.isValid()) diff --git a/lib/Refactoring/Async/AsyncConverter.cpp b/lib/Refactoring/Async/AsyncConverter.cpp index 1b4bc8708b79f..75b53a6552cb4 100644 --- a/lib/Refactoring/Async/AsyncConverter.cpp +++ b/lib/Refactoring/Async/AsyncConverter.cpp @@ -424,6 +424,11 @@ bool AsyncConverter::walkToDeclPost(Decl *D) { #define PLACEHOLDER_START "<#" #define PLACEHOLDER_END "#>" bool AsyncConverter::walkToExprPre(Expr *E) { + // We've already added any shorthand if declaration, don't add its + // synthesized initializer as well. + if (shorthandIfInits.contains(E)) + return true; + // TODO: Handle Result.get as well if (auto *DRE = dyn_cast(E)) { if (auto *D = DRE->getDecl()) { @@ -530,6 +535,15 @@ bool AsyncConverter::walkToExprPost(Expr *E) { #undef PLACEHOLDER_END bool AsyncConverter::walkToStmtPre(Stmt *S) { + // Keep track of any shorthand initializer expressions + if (auto *labeledConditional = dyn_cast(S)) { + for (const auto &condition : labeledConditional->getCond()) { + if (auto *init = condition.getSynthesizedShorthandInitOrNull()) { + shorthandIfInits.insert(init); + } + } + } + // CaseStmt has an implicit BraceStmt inside it, which *should* start a new // scope, so don't check isImplicit here. if (startsNewScope(S)) { diff --git a/lib/Refactoring/Async/AsyncRefactoring.h b/lib/Refactoring/Async/AsyncRefactoring.h index 89a7ee2f9375f..e1773deb5597f 100644 --- a/lib/Refactoring/Async/AsyncRefactoring.h +++ b/lib/Refactoring/Async/AsyncRefactoring.h @@ -969,6 +969,10 @@ class AsyncConverter : private SourceEntityWalker { SmallString<0> Buffer; llvm::raw_svector_ostream OS; + // Any initializer expressions in a shorthand if that we need to skip (as it + // points to the same identifier as the declaration itself). + llvm::DenseSet shorthandIfInits; + // Decls where any force unwrap or optional chain of that decl should be // elided, e.g for a previously optional closure parameter that has become a // non-optional local. diff --git a/lib/Sema/TypeCheckEffects.cpp b/lib/Sema/TypeCheckEffects.cpp index 8d452a84f39f8..92f043575a7a6 100644 --- a/lib/Sema/TypeCheckEffects.cpp +++ b/lib/Sema/TypeCheckEffects.cpp @@ -4713,37 +4713,7 @@ class CheckEffectsCoverage : public EffectsHandlingWalker // Make a note of any initializers that are the synthesized right-hand side // for an "if let x". for (const auto &condition: stmt->getCond()) { - switch (condition.getKind()) { - case StmtConditionElement::CK_Availability: - case StmtConditionElement::CK_Boolean: - case StmtConditionElement::CK_HasSymbol: - continue; - - case StmtConditionElement::CK_PatternBinding: - break; - } - - auto init = condition.getInitializer(); - if (!init) - continue; - - auto pattern = condition.getPattern(); - if (!pattern) - continue; - - auto optPattern = dyn_cast(pattern); - if (!optPattern) - continue; - - auto var = optPattern->getSubPattern()->getSingleVar(); - if (!var) - continue; - - // If the right-hand side has the same location as the variable, it was - // synthesized. - if (var->getLoc().isValid() && - var->getLoc() == init->getStartLoc() && - init->getStartLoc() == init->getEndLoc()) + if (auto *init = condition.getSynthesizedShorthandInitOrNull()) synthesizedIfLetInitializers.insert(init); } } diff --git a/test/refactoring/ConvertAsync/convert_shorthand_if.swift b/test/refactoring/ConvertAsync/convert_shorthand_if.swift new file mode 100644 index 0000000000000..6709ef2185d5f --- /dev/null +++ b/test/refactoring/ConvertAsync/convert_shorthand_if.swift @@ -0,0 +1,25 @@ +// REQUIRES: concurrency + +// RUN: %empty-directory(%t) + +func foo(_ fn: @escaping (String, Error?) -> Void) {} +func foo() async throws -> String { return "" } + +// RUN: %refactor-check-compiles -convert-to-async -dump-text -source-filename %s -pos=%(line+1):1 | %FileCheck %s +func shorthandIf(completion: @escaping (String?, Error?) -> Void) { + foo { str, error in + if let error { + completion(nil, error) + } else { + completion(str, nil) + } + } +} +// CHECK: func shorthandIf() async throws -> String { +// CHECK-NEXT: return try await withCheckedThrowingContinuation { continuation in +// CHECK-NEXT: foo { str, error in +// CHECK-NEXT: if let error { +// CHECK-NEXT: continuation.resume(throwing: error) +// CHECK-NEXT: } else { +// CHECK-NEXT: continuation.resume(returning: str) +// CHECK-NEXT: }