diff --git a/include/swift/AST/ASTBridging.h b/include/swift/AST/ASTBridging.h index d4f17ace66efe..34891d5921669 100644 --- a/include/swift/AST/ASTBridging.h +++ b/include/swift/AST/ASTBridging.h @@ -2404,13 +2404,14 @@ BridgedFallthroughStmt_createParsed(swift::SourceLoc loc, BridgedDeclContext cDC); SWIFT_NAME("BridgedForEachStmt.createParsed(_:labelInfo:forLoc:tryLoc:awaitLoc:" - "unsafeLoc:pattern:inLoc:sequence:whereLoc:whereExpr:body:)") + "unsafeLoc:pattern:inLoc:sequence:whereLoc:whereExpr:body:declContext:)") BridgedForEachStmt BridgedForEachStmt_createParsed( BridgedASTContext cContext, BridgedLabeledStmtInfo cLabelInfo, swift::SourceLoc forLoc, swift::SourceLoc tryLoc, swift::SourceLoc awaitLoc, swift::SourceLoc unsafeLoc, BridgedPattern cPat, swift::SourceLoc inLoc, BridgedExpr cSequence, swift::SourceLoc whereLoc, - BridgedNullableExpr cWhereExpr, BridgedBraceStmt cBody); + BridgedNullableExpr cWhereExpr, BridgedBraceStmt cBody, + BridgedDeclContext cDeclContext); SWIFT_NAME("BridgedGuardStmt.createParsed(_:guardLoc:conds:body:)") BridgedGuardStmt BridgedGuardStmt_createParsed(BridgedASTContext cContext, diff --git a/include/swift/AST/Expr.h b/include/swift/AST/Expr.h index 12a1946d35bc8..0b18d0eeae003 100644 --- a/include/swift/AST/Expr.h +++ b/include/swift/AST/Expr.h @@ -6724,6 +6724,27 @@ class MacroExpansionExpr final : public Expr, } }; +/// OpaqueExpr - created to serve as an indirection to a ForEachStmt's sequence +/// expr and where clause to avoid visiting it twice in the ASTWalker after +/// having desugared the loop. This will only be processed in SILGen to emit +/// the underlying expression. +class OpaqueExpr final : public Expr { + Expr *OriginalExpr; + +public: + OpaqueExpr(Expr* originalExpr) + : Expr(ExprKind::Opaque, /*implicit*/ true, originalExpr->getType()), + OriginalExpr(originalExpr) {} + + Expr *getOriginalExpr() const { return OriginalExpr; } + SourceLoc getStartLoc() const { return OriginalExpr->getStartLoc(); } + SourceLoc getEndLoc() const { return OriginalExpr->getEndLoc(); } + + static bool classof(const Expr *E) { + return E->getKind() == ExprKind::Opaque; + } +}; + inline bool Expr::isInfixOperator() const { return isa(this) || isa(this) || isa(this) || isa(this); diff --git a/include/swift/AST/ExprNodes.def b/include/swift/AST/ExprNodes.def index a193b8d699ec0..6bfa3d17c29de 100644 --- a/include/swift/AST/ExprNodes.def +++ b/include/swift/AST/ExprNodes.def @@ -218,6 +218,7 @@ EXPR(Tap, Expr) UNCHECKED_EXPR(TypeJoin, Expr) EXPR(MacroExpansion, Expr) EXPR(TypeValue, Expr) +EXPR(Opaque, Expr) // Don't forget to update the LAST_EXPR below when adding a new Expr here. LAST_EXPR(TypeValue) diff --git a/include/swift/AST/Stmt.h b/include/swift/AST/Stmt.h index 50238fef5fc39..74ea5e3a7f5ad 100644 --- a/include/swift/AST/Stmt.h +++ b/include/swift/AST/Stmt.h @@ -1003,13 +1003,13 @@ class ForEachStmt : public LabeledStmt { SourceLoc WhereLoc; Expr *WhereExpr = nullptr; BraceStmt *Body; + DeclContext* DC = nullptr; // Set by Sema: ProtocolConformanceRef sequenceConformance = ProtocolConformanceRef(); Type sequenceType; - PatternBindingDecl *iteratorVar = nullptr; Expr *nextCall = nullptr; - OpaqueValueExpr *elementExpr = nullptr; + BraceStmt *desugaredStmt = nullptr; Expr *convertElementExpr = nullptr; public: @@ -1017,7 +1017,7 @@ class ForEachStmt : public LabeledStmt { SourceLoc AwaitLoc, SourceLoc UnsafeLoc, Pattern *Pat, SourceLoc InLoc, Expr *Sequence, SourceLoc WhereLoc, Expr *WhereExpr, BraceStmt *Body, - std::optional implicit = std::nullopt) + DeclContext* DC, std::optional implicit = std::nullopt) : LabeledStmt(StmtKind::ForEach, getDefaultImplicitFlag(implicit, ForLoc), LabelInfo), ForLoc(ForLoc), TryLoc(TryLoc), AwaitLoc(AwaitLoc), UnsafeLoc(UnsafeLoc), @@ -1026,15 +1026,9 @@ class ForEachStmt : public LabeledStmt { setPattern(Pat); } - void setIteratorVar(PatternBindingDecl *var) { iteratorVar = var; } - PatternBindingDecl *getIteratorVar() const { return iteratorVar; } - void setNextCall(Expr *next) { nextCall = next; } Expr *getNextCall() const { return nextCall; } - void setElementExpr(OpaqueValueExpr *expr) { elementExpr = expr; } - OpaqueValueExpr *getElementExpr() const { return elementExpr; } - void setConvertElementExpr(Expr *expr) { convertElementExpr = expr; } Expr *getConvertElementExpr() const { return convertElementExpr; } @@ -1076,20 +1070,23 @@ class ForEachStmt : public LabeledStmt { Expr *getParsedSequence() const { return Sequence; } void setParsedSequence(Expr *S) { Sequence = S; } - /// Type-checked version of the sequence or nullptr if this statement - /// yet to be type-checked. - Expr *getTypeCheckedSequence() const; - /// getBody - Retrieve the body of the loop. BraceStmt *getBody() const { return Body; } void setBody(BraceStmt *B) { Body = B; } SourceLoc getStartLoc() const { return getLabelLocOrKeywordLoc(ForLoc); } SourceLoc getEndLoc() const { return Body->getEndLoc(); } + + DeclContext *getDeclContext() const { return DC; } + void setDeclContext(DeclContext *newDC) { DC = newDC; } static bool classof(const Stmt *S) { return S->getKind() == StmtKind::ForEach; } + + BraceStmt* desugar(); + BraceStmt* getDesugaredStmt() const { return desugaredStmt; } + void setDesugaredStmt(BraceStmt* newStmt) { desugaredStmt = newStmt; } }; /// A pattern and an optional guard expression used in a 'case' statement. @@ -1541,6 +1538,31 @@ class DoCatchStmt final } }; +/// OpaqueStmt - created to serve as an indirection to a ForEachStmt's body +/// to avoid visiting it twice in the ASTWalker after having desugared the loop. +/// This ensures we only visit the body once, and this OpaqueStmt will only be +/// visited to emit the underlying statement in SILGen. +class OpaqueStmt final : public Stmt { + SourceLoc StartLoc; + SourceLoc EndLoc; + BraceStmt *Body; // FIXME: should I just use Stmt * so that this is more versatile? + // If not, should the class be renamed to be more specific? + public: + OpaqueStmt(BraceStmt* body, SourceLoc startLoc, SourceLoc endLoc) + : Stmt(StmtKind::Opaque, true /*always implicit*/), + StartLoc(startLoc), EndLoc(endLoc), Body(body) {} + + SourceLoc getLoc() const { return StartLoc; } + SourceLoc getStartLoc() const { return StartLoc; } + SourceLoc getEndLoc() const { return EndLoc; } + + BraceStmt* getUnderlyingStmt() { return Body; } + + static bool classof(const Stmt *S) { + return S->getKind() == StmtKind::Opaque; + } +}; + /// BreakStmt - The "break" and "break label" statement. class BreakStmt : public Stmt { SourceLoc Loc; diff --git a/include/swift/AST/StmtNodes.def b/include/swift/AST/StmtNodes.def index b35149ad7f437..a3da49f053814 100644 --- a/include/swift/AST/StmtNodes.def +++ b/include/swift/AST/StmtNodes.def @@ -61,6 +61,7 @@ ABSTRACT_STMT(Labeled, Stmt) LABELED_STMT(ForEach, LabeledStmt) LABELED_STMT(Switch, LabeledStmt) STMT_RANGE(Labeled, If, Switch) +STMT(Opaque, Stmt) STMT(Case, Stmt) STMT(Break, Stmt) STMT(Continue, Stmt) diff --git a/include/swift/AST/TypeCheckRequests.h b/include/swift/AST/TypeCheckRequests.h index 6125c326c1295..146fd21031cfa 100644 --- a/include/swift/AST/TypeCheckRequests.h +++ b/include/swift/AST/TypeCheckRequests.h @@ -5591,6 +5591,25 @@ class IsCustomAvailabilityDomainPermanentlyEnabled } }; +class DesugarForEachStmtRequest + : public SimpleRequest { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + // Evaluation. + BraceStmt *evaluate(Evaluator &evaluator, ForEachStmt *FES) const; + +public: + bool isCached() const { return true; } + std::optional getCachedResult() const; + void cacheResult(BraceStmt *stmt) const; +}; + #define SWIFT_TYPEID_ZONE TypeChecker #define SWIFT_TYPEID_HEADER "swift/AST/TypeCheckerTypeIDZone.def" #include "swift/Basic/DefineTypeIDZone.h" diff --git a/include/swift/AST/TypeCheckerTypeIDZone.def b/include/swift/AST/TypeCheckerTypeIDZone.def index 4ebd381402475..f68bdafff3ebc 100644 --- a/include/swift/AST/TypeCheckerTypeIDZone.def +++ b/include/swift/AST/TypeCheckerTypeIDZone.def @@ -674,3 +674,7 @@ SWIFT_REQUEST(TypeChecker, IsCustomAvailabilityDomainPermanentlyEnabled, SWIFT_REQUEST(TypeChecker, EmitPerformanceHints, evaluator::SideEffect(SourceFile *), Cached, NoLocationInfo) + +SWIFT_REQUEST(TypeChecker, DesugarForEachStmtRequest, + Stmt*(const ForEachStmt*), + Cached, NoLocationInfo) diff --git a/include/swift/Sema/ConstraintLocator.h b/include/swift/Sema/ConstraintLocator.h index fdea2f5e50966..6d1b0079c8e32 100644 --- a/include/swift/Sema/ConstraintLocator.h +++ b/include/swift/Sema/ConstraintLocator.h @@ -83,6 +83,8 @@ enum ContextualTypePurpose : uint8_t { CTP_ExprPattern, ///< `~=` operator application associated with expression /// pattern. + + CTP_ForEachElement, ///< Element expression associated with `for-in` loop. }; namespace constraints { diff --git a/include/swift/Sema/SyntacticElementTarget.h b/include/swift/Sema/SyntacticElementTarget.h index 47d084bf2bb98..d0efb8b893a9b 100644 --- a/include/swift/Sema/SyntacticElementTarget.h +++ b/include/swift/Sema/SyntacticElementTarget.h @@ -41,12 +41,6 @@ struct SequenceIterationInfo { /// The type of the pattern that matches the elements. Type initType; - - /// Implicit `$iterator = .makeIterator()` - PatternBindingDecl *makeIteratorVar; - - /// Implicit `$iterator.next()` call. - Expr *nextCall; }; /// Describes information about a for-in loop over a pack that needs to be @@ -605,6 +599,7 @@ class SyntacticElementTarget { case CTP_Initialization: case CTP_ForEachSequence: case CTP_ExprPattern: + case CTP_ForEachElement: break; default: assert(false && "Unexpected contextual type purpose"); diff --git a/lib/AST/ASTDumper.cpp b/lib/AST/ASTDumper.cpp index 218d4c2fe4bbc..67301f1d7a32e 100644 --- a/lib/AST/ASTDumper.cpp +++ b/lib/AST/ASTDumper.cpp @@ -3252,6 +3252,10 @@ class PrintStmt : public StmtVisitor, printFlag(S->TrailingSemiLoc.isValid(), "trailing_semi"); } + void visitOpaqueStmt(OpaqueStmt *S, Label label){ + visitBraceStmt(S->getUnderlyingStmt(), label); + } + void visitBraceStmt(BraceStmt *S, Label label) { printCommon(S, "brace_stmt", label); printList(S->getElements(), [&](auto &Elt, Label label) { @@ -3332,20 +3336,15 @@ class PrintStmt : public StmtVisitor, printRec(S->getWhere(), Label::always("where")); } printRec(S->getParsedSequence(), Label::optional("parsed_sequence")); - if (S->getIteratorVar()) { - printRec(S->getIteratorVar(), Label::optional("iterator_var")); - } - if (S->getNextCall()) { - printRec(S->getNextCall(), Label::optional("next_call")); - } if (S->getConvertElementExpr()) { printRec(S->getConvertElementExpr(), Label::optional("convert_element_expr")); } - if (S->getElementExpr()) { - printRec(S->getElementExpr(), Label::optional("element_expr")); - } + printRec(S->getBody(), Label::optional("body")); + + printRec(S->getDesugaredStmt(), Label::optional("desugared_loop")); + printFoot(); } void visitBreakStmt(BreakStmt *S, Label label) { @@ -4237,6 +4236,10 @@ class PrintExpr : public ExprVisitor, printFoot(); } + void visitOpaqueExpr(OpaqueExpr *E, Label label){ + visit(E->getOriginalExpr(), label); + } + void visitPropertyWrapperValuePlaceholderExpr( PropertyWrapperValuePlaceholderExpr *E, Label label) { printCommon(E, "property_wrapper_value_placeholder_expr", label); diff --git a/lib/AST/ASTPrinter.cpp b/lib/AST/ASTPrinter.cpp index a135b5c1f8468..aec38dd4258f2 100644 --- a/lib/AST/ASTPrinter.cpp +++ b/lib/AST/ASTPrinter.cpp @@ -5633,6 +5633,16 @@ void PrintAST::visitTypeValueExpr(TypeValueExpr *expr) { expr->getType()->print(Printer, Options); } +void PrintAST::visitOpaqueExpr(OpaqueExpr *expr) { + // FIXME: unsure about this, maybe do nothing? + visit(expr->getOriginalExpr()); +} + +void PrintAST::visitOpaqueStmt(OpaqueStmt *stmt) { + // FIXME: unsure about this, maybe do nothing? + printBraceStmt(stmt->getUnderlyingStmt()); +} + void PrintAST::visitBraceStmt(BraceStmt *stmt) { printBraceStmt(stmt); } @@ -5810,7 +5820,7 @@ void PrintAST::visitForEachStmt(ForEachStmt *stmt) { printPattern(stmt->getPattern()); Printer << " " << tok::kw_in << " "; // FIXME: print container - if (auto *seq = stmt->getTypeCheckedSequence()) { + if (auto *seq = stmt->getParsedSequence()) { // Look through the call to '.makeIterator()' if (auto *CE = dyn_cast(seq)) { diff --git a/lib/AST/ASTScopeCreation.cpp b/lib/AST/ASTScopeCreation.cpp index b947c44dcc8d8..d7b63c679cab5 100644 --- a/lib/AST/ASTScopeCreation.cpp +++ b/lib/AST/ASTScopeCreation.cpp @@ -414,6 +414,7 @@ class NodeAdder VISIT_AND_IGNORE(ContinueStmt) VISIT_AND_IGNORE(FallthroughStmt) VISIT_AND_IGNORE(FailStmt) + VISIT_AND_IGNORE(OpaqueStmt) #undef VISIT_AND_IGNORE diff --git a/lib/AST/ASTVerifier.cpp b/lib/AST/ASTVerifier.cpp index 1f4d3eced7d6f..3223c399c761e 100644 --- a/lib/AST/ASTVerifier.cpp +++ b/lib/AST/ASTVerifier.cpp @@ -802,11 +802,6 @@ class Verifier : public ASTWalker { ForEachPatternSequences.insert(expansion); } - if (!S->getElementExpr()) - return true; - - assert(!OpaqueValues.count(S->getElementExpr())); - OpaqueValues[S->getElementExpr()] = 0; return true; } @@ -819,12 +814,6 @@ class Verifier : public ASTWalker { // Clean up for real. cleanup(expansion); } - - if (!S->getElementExpr()) - return; - - assert(OpaqueValues.count(S->getElementExpr())); - OpaqueValues.erase(S->getElementExpr()); } bool shouldVerify(InterpolatedStringLiteralExpr *expr) { diff --git a/lib/AST/ASTWalker.cpp b/lib/AST/ASTWalker.cpp index e39f3c4c3c7ed..1f7ee271036d9 100644 --- a/lib/AST/ASTWalker.cpp +++ b/lib/AST/ASTWalker.cpp @@ -644,6 +644,8 @@ class Traversal : public ASTVisitorgetOpaqueValuePlaceholder()) { @@ -1896,6 +1898,11 @@ Stmt *Traversal::visitPoundAssertStmt(PoundAssertStmt *S) { return S; } +Stmt* Traversal::visitOpaqueStmt(OpaqueStmt* OS){ + // We do not want to visit it. + return OS; +} + Stmt *Traversal::visitBraceStmt(BraceStmt *BS) { for (auto &Elem : BS->getElements()) { if (auto *SubExpr = Elem.dyn_cast()) { @@ -2066,28 +2073,11 @@ Stmt *Traversal::visitForEachStmt(ForEachStmt *S) { return nullptr; } - // The iterator decl is built directly on top of the sequence - // expression, so don't visit both. - // - // If for-in is already type-checked, the type-checked version - // of the sequence is going to be visited as part of `iteratorVar`. - if (auto IteratorVar = S->getIteratorVar()) { - if (doIt(IteratorVar)) - return nullptr; - - if (auto NextCall = S->getNextCall()) { - if ((NextCall = doIt(NextCall))) - S->setNextCall(NextCall); - else - return nullptr; - } - } else { - if (Expr *Sequence = S->getParsedSequence()) { + if (Expr *Sequence = S->getParsedSequence()) { if ((Sequence = doIt(Sequence))) S->setParsedSequence(Sequence); else return nullptr; - } } if (Expr *Where = S->getWhere()) { @@ -2111,6 +2101,13 @@ Stmt *Traversal::visitForEachStmt(ForEachStmt *S) { return nullptr; } + if (Stmt *Desugared = S->getDesugaredStmt()) { + if ((Desugared = doIt(Desugared))) + S->setDesugaredStmt(cast(Desugared)); + else + return nullptr; + } + return S; } diff --git a/lib/AST/Bridging/StmtBridging.cpp b/lib/AST/Bridging/StmtBridging.cpp index e8c11bb98d0cf..b285b3d120c12 100644 --- a/lib/AST/Bridging/StmtBridging.cpp +++ b/lib/AST/Bridging/StmtBridging.cpp @@ -191,11 +191,11 @@ BridgedForEachStmt BridgedForEachStmt_createParsed( SourceLoc forLoc, SourceLoc tryLoc, SourceLoc awaitLoc, SourceLoc unsafeLoc, BridgedPattern cPat, SourceLoc inLoc, BridgedExpr cSequence, SourceLoc whereLoc, BridgedNullableExpr cWhereExpr, - BridgedBraceStmt cBody) { + BridgedBraceStmt cBody, BridgedDeclContext cDeclContext) { return new (cContext.unbridged()) ForEachStmt(cLabelInfo.unbridged(), forLoc, tryLoc, awaitLoc, unsafeLoc, cPat.unbridged(), inLoc, cSequence.unbridged(), whereLoc, - cWhereExpr.unbridged(), cBody.unbridged()); + cWhereExpr.unbridged(), cBody.unbridged(), cDeclContext.unbridged()); } BridgedGuardStmt BridgedGuardStmt_createParsed(BridgedASTContext cContext, diff --git a/lib/AST/Expr.cpp b/lib/AST/Expr.cpp index b1e96674cbf59..da233cd0e4e97 100644 --- a/lib/AST/Expr.cpp +++ b/lib/AST/Expr.cpp @@ -466,6 +466,7 @@ ConcreteDeclRef Expr::getReferencedDecl(bool stopAtParenExpr) const { NO_REFERENCE(TypeJoin); SIMPLE_REFERENCE(MacroExpansion, getMacroRef); NO_REFERENCE(TypeValue); + NO_REFERENCE(Opaque); #undef SIMPLE_REFERENCE #undef NO_REFERENCE @@ -840,6 +841,7 @@ bool Expr::canAppendPostfixExpression(bool appendingPostfixOperator) const { case ExprKind::MacroExpansion: case ExprKind::CurrentContextIsolation: + case ExprKind::Opaque: /* FIXME: unsure about this */ return true; } @@ -1044,6 +1046,7 @@ bool Expr::isValidParentOfTypeExpr(Expr *typeExpr) const { case ExprKind::ActorIsolationErasure: case ExprKind::ExtractFunctionIsolation: case ExprKind::UnsafeCast: + case ExprKind::Opaque: return false; } diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp index 45abe32bce521..ce46f1d91af10 100644 --- a/lib/AST/Stmt.cpp +++ b/lib/AST/Stmt.cpp @@ -89,6 +89,8 @@ StringRef Stmt::getDescriptiveKindName(StmtKind K) { return "discard"; case StmtKind::PoundAssert: return "#assert"; + case StmtKind::Opaque: + return "opaque"; } llvm_unreachable("Unhandled case in switch!"); } @@ -453,13 +455,6 @@ void ForEachStmt::setPattern(Pattern *p) { Pat->markOwnedByStatement(this); } -Expr *ForEachStmt::getTypeCheckedSequence() const { - if (auto *expansion = dyn_cast(getParsedSequence())) - return expansion; - - return iteratorVar ? iteratorVar->getInit(/*index=*/0) : nullptr; -} - DoCatchStmt *DoCatchStmt::create(DeclContext *dc, LabeledStmtInfo labelInfo, SourceLoc doLoc, SourceLoc throwsLoc, TypeLoc thrownType, Stmt *body, @@ -486,6 +481,13 @@ bool DoCatchStmt::isSyntacticallyExhaustive() const { return false; } +BraceStmt *ForEachStmt::desugar() { + auto &ctx = this->getDeclContext()->getASTContext(); + return evaluateOrDefault(ctx.evaluator, + DesugarForEachStmtRequest{this}, + nullptr); +} + Type DoCatchStmt::getExplicitCaughtType() const { ASTContext &ctx = DC->getASTContext(); return CatchNode(const_cast(this)).getExplicitCaughtType(ctx); diff --git a/lib/AST/TypeCheckRequests.cpp b/lib/AST/TypeCheckRequests.cpp index fbb490e0932ed..90367c754f694 100644 --- a/lib/AST/TypeCheckRequests.cpp +++ b/lib/AST/TypeCheckRequests.cpp @@ -2890,3 +2890,20 @@ void IsCustomAvailabilityDomainPermanentlyEnabled::cacheResult( domain->flags.isPermanentlyEnabledComputed = true; domain->flags.isPermanentlyEnabled = isPermanentlyEnabled; } + +//----------------------------------------------------------------------------// +// DesugarForEachStmtRequest computation. +//----------------------------------------------------------------------------// +std::optional DesugarForEachStmtRequest::getCachedResult() const { + auto *fes = std::get<0>(getStorage()); + auto* desugaredStmt = fes->getDesugaredStmt(); + if (!desugaredStmt){ + return std::nullopt; + } + return desugaredStmt; +} + +void DesugarForEachStmtRequest::cacheResult(BraceStmt *stmt) const { + auto *fes = std::get<0>(getStorage()); + fes->setDesugaredStmt(stmt); +} diff --git a/lib/ASTGen/Sources/ASTGen/Stmts.swift b/lib/ASTGen/Sources/ASTGen/Stmts.swift index db3bbc465f17c..9d4f4eaab9041 100644 --- a/lib/ASTGen/Sources/ASTGen/Stmts.swift +++ b/lib/ASTGen/Sources/ASTGen/Stmts.swift @@ -389,7 +389,8 @@ extension ASTGenVisitor { sequence: self.generate(expr: node.sequence), whereLoc: self.generateSourceLoc(node.whereClause?.whereKeyword), whereExpr: self.generate(expr: node.whereClause?.condition), - body: self.generate(codeBlock: node.body) + body: self.generate(codeBlock: node.body), + declContext: self.declContext ) } diff --git a/lib/Parse/ParseStmt.cpp b/lib/Parse/ParseStmt.cpp index 921a865c297e7..df9a665e37e4a 100644 --- a/lib/Parse/ParseStmt.cpp +++ b/lib/Parse/ParseStmt.cpp @@ -2506,7 +2506,7 @@ ParserResult Parser::parseStmtForEach(LabeledStmtInfo LabelInfo) { new (Context) ForEachStmt(LabelInfo, ForLoc, TryLoc, AwaitLoc, UnsafeLoc, pattern.get(), InLoc, Container.get(), WhereLoc, Where.getPtrOrNull(), - Body.get())); + Body.get(), CurDeclContext)); } /// diff --git a/lib/SILGen/ASTVisitor.h b/lib/SILGen/ASTVisitor.h index 919bf08904ce7..28f8b127dc5dc 100644 --- a/lib/SILGen/ASTVisitor.h +++ b/lib/SILGen/ASTVisitor.h @@ -56,6 +56,10 @@ class ASTVisitor : public swift::ASTVisitorgetOriginalExpr()); +} + RValue RValueEmitter::visitOpaqueValueExpr(OpaqueValueExpr *E, SGFContext C) { auto found = SGF.OpaqueValues.find(E); assert(found != SGF.OpaqueValues.end()); diff --git a/lib/SILGen/SILGenStmt.cpp b/lib/SILGen/SILGenStmt.cpp index 78ab99c3aa731..f6d3bc41afe15 100644 --- a/lib/SILGen/SILGenStmt.cpp +++ b/lib/SILGen/SILGenStmt.cpp @@ -1404,10 +1404,14 @@ void StmtEmitter::visitRepeatWhileStmt(RepeatWhileStmt *S) { SGF.BreakContinueDestStack.pop_back(); } +void StmtEmitter::visitOpaqueStmt(OpaqueStmt *S) { + visitBraceStmt(S->getUnderlyingStmt()); +} + void StmtEmitter::visitForEachStmt(ForEachStmt *S) { - if (auto *expansion = - dyn_cast(S->getTypeCheckedSequence())) { + if (auto *expansion = + dyn_cast(S->getParsedSequence())) { auto formalPackType = dyn_cast( PackType::get(SGF.getASTContext(), expansion->getType()) ->getCanonicalType()); @@ -1442,170 +1446,9 @@ void StmtEmitter::visitForEachStmt(ForEachStmt *S) { return; } - // Emit the 'iterator' variable that we'll be using for iteration. - LexicalScope OuterForScope(SGF, CleanupLocation(S)); - SGF.emitPatternBinding(S->getIteratorVar(), - /*index=*/0, /*debuginfo*/ true); - - // If we ever reach an unreachable point, stop emitting statements. - // This will need revision if we ever add goto. - if (!SGF.B.hasValidInsertionPoint()) return; - - // If generator's optional result is address-only, create a stack allocation - // to hold the results. This will be initialized on every entry into the loop - // header and consumed by the loop body. On loop exit, the terminating value - // will be in the buffer. - CanType optTy = S->getNextCall()->getType()->getCanonicalType(); - auto &optTL = SGF.getTypeLowering(optTy); - - SILValue addrOnlyBuf; - bool nextResultTyIsAddressOnly = - optTL.isAddressOnly() && SGF.silConv.useLoweredAddresses(); - - if (nextResultTyIsAddressOnly) - addrOnlyBuf = SGF.emitTemporaryAllocation(S, optTL.getLoweredType()); - - // Create a new basic block and jump into it. - JumpDest loopDest = createJumpDest(S->getBody()); - SGF.B.emitBlock(loopDest.getBlock(), S); - - // Set the destinations for 'break' and 'continue'. - JumpDest endDest = createJumpDest(S->getBody()); - SGF.BreakContinueDestStack.push_back({ S, endDest, loopDest }); - - bool hasElementConversion = S->getElementExpr(); - auto buildElementRValue = [&](SGFContext ctx) { - RValue result; - result = SGF.emitRValue(S->getNextCall(), - hasElementConversion ? SGFContext() : ctx); - return result; - }; - - ManagedValue nextBufOrElement; - // Then emit the loop destination block. - // - // Advance the generator. Use a scope to ensure that any temporary stack - // allocations in the subexpression are immediately released. - if (nextResultTyIsAddressOnly) { - // Create the initialization outside of the innerForScope so that the - // innerForScope doesn't clean it up. - auto nextInit = SGF.useBufferAsTemporary(addrOnlyBuf, optTL); - { - ArgumentScope innerForScope(SGF, SILLocation(S)); - SILLocation loc = SILLocation(S); - RValue result = buildElementRValue(SGFContext(nextInit.get())); - if (!result.isInContext()) { - ArgumentSource(SILLocation(S->getTypeCheckedSequence()), - std::move(result).ensurePlusOne(SGF, loc)) - .forwardInto(SGF, nextInit.get()); - } - innerForScope.pop(); - } - nextBufOrElement = nextInit->getManagedAddress(); - } else { - ArgumentScope innerForScope(SGF, SILLocation(S)); - nextBufOrElement = innerForScope.popPreservingValue( - buildElementRValue(SGFContext()) - .getAsSingleValue(SGF, SILLocation(S))); - } - - SILBasicBlock *failExitingBlock = createBasicBlock(); - SwitchEnumBuilder switchEnumBuilder(SGF.B, S, nextBufOrElement); - - auto convertElementRValue = [&](ManagedValue inputValue, SGFContext ctx) -> ManagedValue { - SILGenFunction::OpaqueValueRAII pushOpaqueValue(SGF, S->getElementExpr(), - inputValue); - return SGF.emitRValue(S->getConvertElementExpr(), ctx) - .getAsSingleValue(SGF, SILLocation(S)); - }; - - switchEnumBuilder.addOptionalSomeCase( - createBasicBlock(), loopDest.getBlock(), - [&](ManagedValue inputValue, SwitchCaseFullExpr &&scope) { - SGF.emitProfilerIncrement(S->getBody()); - - // Emit the loop body. - // The declared variable(s) for the current element are destroyed - // at the end of each loop iteration. - { - Scope innerForScope(SGF.Cleanups, CleanupLocation(S->getBody())); - // Emit the initialization for the pattern. If any of the bound - // patterns - // fail (because this is a 'for case' pattern with a refutable - // pattern, - // the code should jump to the continue block. - InitializationPtr initLoopVars = - SGF.emitPatternBindingInitialization(S->getPattern(), loopDest); - - // If we had a loadable "next" generator value, we know it is present. - // Get the value out of the optional, and wrap it up with a cleanup so - // that any exits out of this scope properly clean it up. - // - // *NOTE* If we do not have an address only value, then inputValue is - // *already properly unwrapped. - SGFContext loopVarCtx{initLoopVars.get()}; - if (nextResultTyIsAddressOnly) { - inputValue = SGF.emitUncheckedGetOptionalValueFrom( - S, inputValue, optTL, - hasElementConversion ? SGFContext() : loopVarCtx); - } - - CanType optConvertedTy = optTy; - if (hasElementConversion) { - inputValue = convertElementRValue(inputValue, loopVarCtx); - optConvertedTy = - OptionalType::get(S->getConvertElementExpr()->getType()) - ->getCanonicalType(); - } - if (!inputValue.isInContext()) - RValue(SGF, S, optConvertedTy.getOptionalObjectType(), inputValue) - .forwardInto(SGF, S->getBody(), initLoopVars.get()); - - // Now that the pattern has been initialized, check any where - // condition. - // If it fails, loop around as if 'continue' happened. - if (auto *Where = S->getWhere()) { - auto cond = SGF.emitCondition(Where, /*invert*/ true); - // If self is null, branch to the epilog. - cond.enterTrue(SGF); - SGF.Cleanups.emitBranchAndCleanups(loopDest, Where, {}); - cond.exitTrue(SGF); - cond.complete(SGF); - } - - visit(S->getBody()); - } - - // If we emitted an unreachable in the body, we will not have a valid - // insertion point. Just return early. - if (!SGF.B.hasValidInsertionPoint()) { - scope.unreachableExit(); - return; - } - - // Otherwise, associate the loop body's closing brace with this branch. - RegularLocation L(S->getBody()); - L.pointToEnd(); - scope.exitAndBranch(L); - }, - SGF.loadProfilerCount(S->getBody())); - - // We add loop fail block, just to be defensive about intermediate - // transformations performing cleanups at scope.exit(). We still jump to the - // contBlock. - switchEnumBuilder.addOptionalNoneCase( - createBasicBlock(), failExitingBlock, - [&](ManagedValue inputValue, SwitchCaseFullExpr &&scope) { - assert(!inputValue && "None should not be passed an argument!"); - scope.exitAndBranch(S); - }, - SGF.loadProfilerCount(S)); - - std::move(switchEnumBuilder).emit(); - - SGF.B.emitBlock(failExitingBlock); - emitOrDeleteBlock(SGF, endDest, S); - SGF.BreakContinueDestStack.pop_back(); + auto* braceStmt = S->getDesugaredStmt(); + if (braceStmt) + visitBraceStmt(braceStmt); } void StmtEmitter::visitBreakStmt(BreakStmt *S) { diff --git a/lib/SILOptimizer/Mandatory/MoveOnlyDiagnostics.cpp b/lib/SILOptimizer/Mandatory/MoveOnlyDiagnostics.cpp index 514391ada4e0c..82b2a99b71e15 100644 --- a/lib/SILOptimizer/Mandatory/MoveOnlyDiagnostics.cpp +++ b/lib/SILOptimizer/Mandatory/MoveOnlyDiagnostics.cpp @@ -187,6 +187,7 @@ void DiagnosticEmitter::emitMissingConsumeInDiscardingContext( case StmtKind::Case: case StmtKind::Fallthrough: case StmtKind::Discard: + case StmtKind::Opaque: return false; }; } diff --git a/lib/Sema/BuilderTransform.cpp b/lib/Sema/BuilderTransform.cpp index 74a74fb2b591e..23d398d11d4e2 100644 --- a/lib/Sema/BuilderTransform.cpp +++ b/lib/Sema/BuilderTransform.cpp @@ -735,7 +735,7 @@ class ResultBuilderTransform forEachStmt->getParsedSequence(), forEachStmt->getWhereLoc(), forEachStmt->getWhere(), cloneBraceWith(forEachStmt->getBody(), newBody), - forEachStmt->isImplicit()); + forEachStmt->getDeclContext(), forEachStmt->isImplicit()); // For a body of new `do` statement that holds updated `for-in` loop // and epilog that consists of a call to `buildArray` that forms the @@ -771,6 +771,7 @@ class ResultBuilderTransform UNSUPPORTED_STMT(Fail) UNSUPPORTED_STMT(PoundAssert) UNSUPPORTED_STMT(Case) + UNSUPPORTED_STMT(Opaque) #undef UNSUPPORTED_STMT diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index c28f85d68ba38..e53e45a5ef2e3 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -2758,6 +2758,11 @@ namespace { return expr; } + Expr *visitOpaqueExpr(OpaqueExpr *expr) { + // Do nothing with error expressions. + return expr; + } + Expr *visitCodeCompletionExpr(CodeCompletionExpr *expr) { // Do nothing with code completion expressions. auto toType = simplifyType(cs.getType(expr)); @@ -9336,107 +9341,26 @@ applySolutionToForEachStmtPreamble(ForEachStmt *stmt, auto &ctx = cs.getASTContext(); auto *parsedSequence = stmt->getParsedSequence(); - bool isAsync = stmt->getAwaitLoc().isValid(); // Simplify the various types. info.sequenceType = solution.simplifyType(info.sequenceType); info.elementType = solution.simplifyType(info.elementType); info.initType = solution.simplifyType(info.initType); - // First, let's apply the solution to the expression. - auto *makeIteratorVar = info.makeIteratorVar; + auto sequenceTarget = *cs.getTargetFor(parsedSequence); - auto makeIteratorTarget = *cs.getTargetFor({makeIteratorVar, /*index=*/0}); - - auto rewrittenTarget = rewriter.rewriteTarget(makeIteratorTarget); + auto rewrittenTarget = rewriter.rewriteTarget(sequenceTarget); if (!rewrittenTarget) return std::nullopt; - // Set type-checked initializer and mark it as such. - { - makeIteratorVar->setInit(/*index=*/0, rewrittenTarget->getAsExpr()); - makeIteratorVar->setInitializerChecked(/*index=*/0); - } - - stmt->setIteratorVar(makeIteratorVar); - - // Now, `$iterator.next()` call. - { - auto nextTarget = *cs.getTargetFor(info.nextCall); - - auto rewrittenTarget = rewriter.rewriteTarget(nextTarget); - if (!rewrittenTarget) - return std::nullopt; - - Expr *nextCall = rewrittenTarget->getAsExpr(); - // Wrap a call to `next()` into `try await` since `AsyncIteratorProtocol` - // witness could be `async throws`. - if (isAsync) { - // Cannot use `forEachChildExpr` here because we need to - // to wrap a call in `try` and then stop immediately after. - struct TryInjector : ASTWalker { - ASTContext &C; - const Solution &S; - - bool ShouldStop = false; - - TryInjector(ASTContext &ctx, const Solution &solution) - : C(ctx), S(solution) {} - - MacroWalking getMacroWalkingBehavior() const override { - return MacroWalking::Expansion; - } - - PreWalkResult walkToExprPre(Expr *E) override { - if (ShouldStop) - return Action::Stop(); - - if (auto *call = dyn_cast(E)) { - // There is a single call expression in `nextCall`. - ShouldStop = true; - - auto nextRefType = - S.getResolvedType(call->getFn())->castTo(); - - // If the inferred witness is throwing, we need to wrap the call - // into `try` expression. - if (nextRefType->isThrowing()) { - auto *tryExpr = TryExpr::createImplicit( - C, /*tryLoc=*/call->getStartLoc(), call, call->getType()); - // Cannot stop here because we need to make sure that - // the new expression gets injected into AST. - return Action::SkipNode(tryExpr); - } - } - - return Action::Continue(E); - } - }; - - nextCall->walk(TryInjector(ctx, solution)); - } - - stmt->setNextCall(nextCall); - } + stmt->setParsedSequence(rewrittenTarget->getAsExpr()); + // FIXME: Next used to be optional but we have to deal w this differently now + // This models the gap between the type of the next call and the pattern's + // We will be typechecking next separately w the contextual type of the + // pattern we already have so this shouldnt be necessary anymore // Convert that std::optional value to the type of the pattern. - auto optPatternType = OptionalType::get(info.initType); - Type nextResultType = OptionalType::get(info.elementType); - if (!optPatternType->isEqual(nextResultType)) { - OpaqueValueExpr *elementExpr = new (ctx) OpaqueValueExpr( - stmt->getInLoc(), nextResultType->getOptionalObjectType(), - /*isPlaceholder=*/false); - cs.cacheExprTypes(elementExpr); - - auto *loc = cs.getConstraintLocator(parsedSequence, - ConstraintLocator::SequenceElementType); - auto *convertExpr = solution.coerceToType(elementExpr, info.initType, loc); - if (!convertExpr) - return std::nullopt; - - stmt->setElementExpr(elementExpr); - stmt->setConvertElementExpr(convertExpr); - } + // (deleted code here) // Get the conformance of the sequence type to the Sequence protocol. auto sequenceProto = TypeChecker::getProtocol( @@ -9590,6 +9514,7 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) { case CTP_Condition: case CTP_WrappedProperty: case CTP_SingleValueStmtBranch: + case CTP_ForEachElement: result.setExpr(rewrittenExpr); break; } diff --git a/lib/Sema/CSDiagnostics.cpp b/lib/Sema/CSDiagnostics.cpp index 80abf67a9db35..aaf5b9891a021 100644 --- a/lib/Sema/CSDiagnostics.cpp +++ b/lib/Sema/CSDiagnostics.cpp @@ -868,6 +868,7 @@ GenericArgumentsMismatchFailure::getDiagnosticFor( case CTP_EnumCaseRawValue: case CTP_ExprPattern: case CTP_SingleValueStmtBranch: + case CTP_ForEachElement: break; } return std::nullopt; @@ -2963,6 +2964,7 @@ getContextualNilDiagnostic(ContextualTypePurpose CTP) { case CTP_WrappedProperty: case CTP_ExprPattern: case CTP_SingleValueStmtBranch: + case CTP_ForEachElement: return std::nullopt; case CTP_EnumCaseRawValue: @@ -3748,6 +3750,7 @@ ContextualFailure::getDiagnosticFor(ContextualTypePurpose context, case CTP_Unused: case CTP_YieldByReference: case CTP_ExprPattern: + case CTP_ForEachElement: break; } return std::nullopt; diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index ed18eb6b96617..4697b818e5209 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -4243,6 +4243,10 @@ namespace { return resultType; } + virtual Type visitOpaqueExpr(OpaqueExpr *E) { + return E->getOriginalExpr()->getType(); + } + static bool isTriggerFallbackDiagnosticBuiltin(UnresolvedDotExpr *UDE, ASTContext &Context) { auto *DRE = dyn_cast(UDE->getBase()); @@ -4660,14 +4664,6 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc, bool isAsync = stmt->getAwaitLoc().isValid(); auto *sequenceExpr = stmt->getParsedSequence(); - // If we have an unsafe expression for the sequence, lift it out of the - // sequence expression. We'll put it back after we've introduced the - // various calls. - UnsafeExpr *unsafeExpr = dyn_cast(sequenceExpr); - if (unsafeExpr) { - sequenceExpr = unsafeExpr->getSubExpr(); - } - auto contextualLocator = cs.getConstraintLocator( sequenceExpr, LocatorPathElt::ContextualType(CTP_ForEachSequence)); auto elementLocator = cs.getConstraintLocator( @@ -4682,164 +4678,36 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc, if (!sequenceProto) return std::nullopt; - std::string name; - { - if (auto np = dyn_cast_or_null(stmt->getPattern())) - name = "$"+np->getBoundName().str().str(); - name += "$generator"; - } - - auto *makeIteratorVar = new (ctx) - VarDecl(/*isStatic=*/false, VarDecl::Introducer::Var, - sequenceExpr->getStartLoc(), ctx.getIdentifier(name), dc); - makeIteratorVar->setImplicit(); - - // FIXME: Apply `nonisolated(unsafe)` to async iterators. - // - // Async iterators are not `Sendable`; they're only meant to be used from - // the isolation domain that creates them. But the `next()` method runs on - // the generic executor, so calling it from an actor-isolated context passes - // non-`Sendable` state across the isolation boundary. `next()` should - // inherit the isolation of the caller, but for now, use the opt out. - if (isAsync) { - auto *nonisolated = - NonisolatedAttr::createImplicit(ctx, NonIsolatedModifier::Unsafe); - makeIteratorVar->addAttribute(nonisolated); - } - - // First, let's form a call from sequence to `.makeIterator()` and save - // that in a special variable which is going to be used by SILGen. - { - FuncDecl *makeIterator = isAsync ? ctx.getAsyncSequenceMakeAsyncIterator() - : ctx.getSequenceMakeIterator(); - - auto *makeIteratorRef = new (ctx) UnresolvedDotExpr( - sequenceExpr, SourceLoc(), DeclNameRef(makeIterator->getName()), - DeclNameLoc(stmt->getForLoc()), /*implicit=*/true); - makeIteratorRef->setFunctionRefInfo(FunctionRefInfo::singleBaseNameApply()); - - Expr *makeIteratorCall = - CallExpr::createImplicitEmpty(ctx, makeIteratorRef); - - // Swap in the 'unsafe' expression. - if (unsafeExpr) { - unsafeExpr->setSubExpr(makeIteratorCall); - makeIteratorCall = unsafeExpr; - } - - Pattern *pattern = NamedPattern::createImplicit(ctx, makeIteratorVar); - auto *PB = PatternBindingDecl::createImplicit( - ctx, StaticSpellingKind::None, pattern, makeIteratorCall, dc); - auto makeIteratorTarget = SyntacticElementTarget::forInitialization( - makeIteratorCall, /*patternType=*/Type(), PB, /*index=*/0, - /*shouldBindPatternsOneWay=*/false); + ContextualTypeInfo contextInfo(sequenceProto->getDeclaredInterfaceType(), + CTP_ForEachSequence); + cs.setContextualInfo(sequenceExpr, contextInfo); - ContextualTypeInfo contextInfo(sequenceProto->getDeclaredInterfaceType(), - CTP_ForEachSequence); - cs.setContextualInfo(sequenceExpr, contextInfo); + auto seqExprTarget = SyntacticElementTarget(sequenceExpr, + dc, contextInfo, false); - if (cs.generateConstraints(makeIteratorTarget)) - return std::nullopt; - - sequenceIterationInfo.makeIteratorVar = PB; - - // Type of sequence expression has to conform to Sequence protocol. - // - // Note that the following emulates having `$generator` separately - // type-checked by introducing a `TVO_PrefersSubtypeBinding` type - // variable that would make sure that result of `.makeIterator` would - // get ranked standalone. - { - auto *externalIteratorType = cs.createTypeVariable( - cs.getConstraintLocator(sequenceExpr), TVO_PrefersSubtypeBinding); - - cs.addConstraint(ConstraintKind::Equal, externalIteratorType, - cs.getType(sequenceExpr), - externalIteratorType->getImpl().getLocator()); - - cs.addConstraint(ConstraintKind::ConformsTo, externalIteratorType, - sequenceProto->getDeclaredInterfaceType(), - contextualLocator); - - sequenceIterationInfo.sequenceType = cs.getType(sequenceExpr); - } + if (cs.generateConstraints(seqExprTarget)) + return std::nullopt; + cs.setTargetFor(sequenceExpr, seqExprTarget); + auto seqType = cs.getType(sequenceExpr); + // Type of sequence expression has to conform to Sequence protocol. + // + // Note that the following emulates having `$generator` separately + // type-checked by introducing a `TVO_PrefersSubtypeBinding` type + // variable that would make sure that result of `.makeIterator` would + // get ranked standalone. + auto *externalSequenceType = cs.createTypeVariable( + cs.getConstraintLocator(sequenceExpr), TVO_PrefersSubtypeBinding); - cs.setTargetFor({PB, /*index=*/0}, makeIteratorTarget); - } + cs.addConstraint(ConstraintKind::Equal, externalSequenceType, + seqType, + externalSequenceType->getImpl().getLocator()); - // Now, result type of `.makeIterator()` is used to form a call to - // `.next()`. `next()` is called on each iteration of the loop. - { - FuncDecl *nextFn = - TypeChecker::getForEachIteratorNextFunction(dc, stmt->getForLoc(), isAsync); - Identifier nextId = nextFn ? nextFn->getName().getBaseIdentifier() - : ctx.Id_next; - TinyPtrVector labels; - if (nextFn && nextFn->getParameters()->size() == 1) - labels.push_back(ctx.Id_isolation); - auto *makeIteratorVarRef = - new (ctx) DeclRefExpr(makeIteratorVar, DeclNameLoc(stmt->getForLoc()), - /*Implicit=*/true); - auto *nextRef = new (ctx) - UnresolvedDotExpr(makeIteratorVarRef, SourceLoc(), - DeclNameRef(DeclName(ctx, nextId, labels)), - DeclNameLoc(stmt->getForLoc()), /*implicit=*/true); - nextRef->setFunctionRefInfo(FunctionRefInfo::singleBaseNameApply()); - - ArgumentList *nextArgs; - if (nextFn && nextFn->getParameters()->size() == 1) { - auto isolationArg = - new (ctx) CurrentContextIsolationExpr(stmt->getForLoc(), Type()); - nextArgs = ArgumentList::createImplicit( - ctx, {Argument(SourceLoc(), ctx.Id_isolation, isolationArg)}); - } else { - nextArgs = ArgumentList::createImplicit(ctx, {}); - } - Expr *nextCall = CallExpr::createImplicit(ctx, nextRef, nextArgs); - - // `next` is always async but witness might not be throwing - if (isAsync) { - nextCall = - AwaitExpr::createImplicit(ctx, nextCall->getLoc(), nextCall); - } - - // Wrap the 'next' call in 'unsafe', if the for..in loop has that - // effect or if the loop is async (in which case the iterator variable - // is nonisolated(unsafe). - if (stmt->getUnsafeLoc().isValid() || - (isAsync && - ctx.LangOpts.StrictConcurrencyLevel == StrictConcurrency::Complete)) { - SourceLoc loc = stmt->getUnsafeLoc(); - bool implicit = stmt->getUnsafeLoc().isInvalid(); - if (loc.isInvalid()) - loc = stmt->getForLoc(); - nextCall = new (ctx) UnsafeExpr(loc, nextCall, Type(), implicit); - } - - // The iterator type must conform to IteratorProtocol. - { - ProtocolDecl *iteratorProto = TypeChecker::getProtocol( - cs.getASTContext(), stmt->getForLoc(), - isAsync ? KnownProtocolKind::AsyncIteratorProtocol - : KnownProtocolKind::IteratorProtocol); - if (!iteratorProto) - return std::nullopt; - - ContextualTypeInfo contextInfo(iteratorProto->getDeclaredInterfaceType(), - CTP_ForEachSequence); - cs.setContextualInfo(nextRef->getBase(), contextInfo); - } - - SyntacticElementTarget nextTarget(nextCall, dc, CTP_Unused, - /*contextualType=*/Type(), - /*isDiscarded=*/false); - if (cs.generateConstraints(nextTarget, FreeTypeVariableBinding::Disallow)) - return std::nullopt; + cs.addConstraint(ConstraintKind::ConformsTo, seqType, + sequenceProto->getDeclaredInterfaceType(), + contextualLocator); - sequenceIterationInfo.nextCall = nextTarget.getAsExpr(); - cs.setTargetFor(sequenceIterationInfo.nextCall, nextTarget); - } + sequenceIterationInfo.sequenceType = seqType; // Generate constraints for the pattern. Type initType = @@ -4850,17 +4718,9 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc, // Add a conversion constraint between the element type of the sequence // and the type of the element pattern. - auto *elementTypeLoc = cs.getConstraintLocator( - elementLocator, ConstraintLocator::OptionalInjection); - auto elementType = cs.createTypeVariable(elementTypeLoc, - /*flags=*/0); - { - auto nextType = cs.getType(sequenceIterationInfo.nextCall); - cs.addConstraint(ConstraintKind::OptionalObject, nextType, elementType, - elementTypeLoc); - cs.addConstraint(ConstraintKind::Conversion, elementType, initType, + auto* elementType = DependentMemberType::get(externalSequenceType, sequenceProto->getAssociatedType(ctx.Id_Element)); + cs.addConstraint(ConstraintKind::Conversion, elementType, initType, elementLocator); - } // Populate all of the information for a for-each loop. sequenceIterationInfo.elementType = elementType; diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index a1f64680131ba..ac507c7ddf86d 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -16701,6 +16701,7 @@ void ConstraintSystem::addContextualConversionConstraint( case CTP_WrappedProperty: case CTP_ExprPattern: case CTP_SingleValueStmtBranch: + case CTP_ForEachElement: break; } diff --git a/lib/Sema/CSSyntacticElement.cpp b/lib/Sema/CSSyntacticElement.cpp index 503cd35d783a9..78fe5f51b1def 100644 --- a/lib/Sema/CSSyntacticElement.cpp +++ b/lib/Sema/CSSyntacticElement.cpp @@ -866,6 +866,7 @@ class SyntacticElementConstraintGenerator } // These statements don't require any type-checking. + void visitOpaqueStmt(OpaqueStmt *opaqueStmt) {} void visitBreakStmt(BreakStmt *breakStmt) {} void visitContinueStmt(ContinueStmt *continueStmt) {} void visitDeferStmt(DeferStmt *deferStmt) {} @@ -1816,6 +1817,10 @@ class SyntacticElementSolutionApplication rewriter.addLocalDeclToTypeCheck(decl); } + ASTNode visitOpaqueStmt(OpaqueStmt *opaqueStmt) { + return opaqueStmt; + } + ASTNode visitBreakStmt(BreakStmt *breakStmt) { // Force the target to be computed in case it produces diagnostics. (void)breakStmt->getTarget(); diff --git a/lib/Sema/SyntacticElementTarget.cpp b/lib/Sema/SyntacticElementTarget.cpp index e978dddbf7743..08b37ea476f3f 100644 --- a/lib/Sema/SyntacticElementTarget.cpp +++ b/lib/Sema/SyntacticElementTarget.cpp @@ -275,6 +275,7 @@ bool SyntacticElementTarget::contextualTypeIsOnlyAHint() const { case CTP_WrappedProperty: case CTP_ExprPattern: case CTP_SingleValueStmtBranch: + case CTP_ForEachElement: return false; } llvm_unreachable("invalid contextual type"); diff --git a/lib/Sema/TypeCheckEffects.cpp b/lib/Sema/TypeCheckEffects.cpp index b9766dbc9e724..6b504c8afe53b 100644 --- a/lib/Sema/TypeCheckEffects.cpp +++ b/lib/Sema/TypeCheckEffects.cpp @@ -1988,15 +1988,19 @@ class ApplyClassifier { classifier.AsyncKind, /*FIXME:*/PotentialEffectReason::forApply()); } - case EffectKind::Unsafe: - llvm_unreachable("Unimplemented"); + case EffectKind::Unsafe: { + FunctionUnsafeClassifier classifier(*this); + stmt->walk(classifier); + return classifier.classification; + } } + llvm_unreachable("Bad effect"); } /// Check to see if the given for-each statement to determine if it /// throws or is async. Classification classifyForEach(ForEachStmt *stmt) { - if (!stmt->getNextCall()) + if (!stmt->getDesugaredStmt()) return Classification::forInvalidCode(); // If there is an 'await', the for-each loop is always async. @@ -2011,10 +2015,10 @@ class ApplyClassifier { } // Merge the thrown result from the next/nextElement call. - result.merge(classifyExpr(stmt->getNextCall(), EffectKind::Throws)); + result.merge(classifyStmt(stmt->getDesugaredStmt(), EffectKind::Throws)); // Merge unsafe effect from the next/nextElement call. - result.merge(classifyExpr(stmt->getNextCall(), EffectKind::Unsafe)); + result.merge(classifyStmt(stmt->getDesugaredStmt(), EffectKind::Unsafe)); return result; } @@ -3629,10 +3633,6 @@ class CheckEffectsCoverage : public EffectsHandlingWalker llvm::DenseMap> uncoveredAsync; llvm::DenseMap parentMap; - /// The next/nextElement call expressions within for-in statements we've - /// seen. - llvm::SmallDenseSet forEachNextCallExprs; - /// Expressions that are assumed to be safe because they are being /// passed directly into an explicitly `@safe` function. llvm::DenseSet assumedSafeArguments; @@ -4384,11 +4384,6 @@ class CheckEffectsCoverage : public EffectsHandlingWalker /*stopAtAutoClosure=*/false, EffectKind::Unsafe); - // We don't diagnose uncovered unsafe uses within the next/nextElement - // call, because they're handled already by the for-in loop checking. - if (forEachNextCallExprs.contains(anchor)) - break; - // Figure out a location to use if the unsafe use didn't have one. SourceLoc replacementLoc; if (anchor) @@ -4585,18 +4580,17 @@ class CheckEffectsCoverage : public EffectsHandlingWalker ShouldRecurse_t checkForEach(ForEachStmt *S) { // Reparent the type-checked sequence on the parsed sequence, so we can // find an anchor. - if (auto typeCheckedExpr = S->getTypeCheckedSequence()) { + if (auto typeCheckedExpr = S->getParsedSequence()) { parentMap = typeCheckedExpr->getParentMap(); - - if (auto parsedSequence = S->getParsedSequence()) { - parentMap[typeCheckedExpr] = parsedSequence; - } } - // Note the nextCall expression. - if (auto nextCall = S->getNextCall()) { - forEachNextCallExprs.insert(nextCall); - } + // Walk everything + S->getParsedSequence()->walk(*this); + S->getBody()->walk(*this); + if (S->getWhere()) + S->getWhere()->walk(*this); + + S->getDesugaredStmt()->walk(*this); auto classification = getApplyClassifier().classifyForEach(S); @@ -4638,7 +4632,13 @@ class CheckEffectsCoverage : public EffectsHandlingWalker } } - return ShouldRecurse; + if (S->getUnsafeLoc().isValid() && !classification.hasUnsafe()){ + Ctx.Diags.diagnose(S->getUnsafeLoc(), + diag::no_unsafe_in_unsafe_for) + .fixItRemove(S->getUnsafeLoc()); + } + + return ShouldNotRecurse; } ShouldRecurse_t checkDefer(DeferStmt *S) { @@ -4704,11 +4704,6 @@ class CheckEffectsCoverage : public EffectsHandlingWalker return; } - Ctx.Diags.diagnose(E->getUnsafeLoc(), - forEachNextCallExprs.contains(E) - ? diag::no_unsafe_in_unsafe_for - : diag::no_unsafe_in_unsafe) - .fixItRemove(E->getUnsafeLoc()); } void noteLabeledConditionalStmt(LabeledConditionalStmt *stmt) { diff --git a/lib/Sema/TypeCheckStmt.cpp b/lib/Sema/TypeCheckStmt.cpp index 9584ec27111e7..611ec04f57a15 100644 --- a/lib/Sema/TypeCheckStmt.cpp +++ b/lib/Sema/TypeCheckStmt.cpp @@ -137,6 +137,11 @@ namespace { CS->setDeclContext(ParentDC); if (auto *FS = dyn_cast(S)) FS->setDeclContext(ParentDC); + if (auto *FES = dyn_cast(S)) + { + FES->setDeclContext(ParentDC); + FES->desugar(); + } return Action::Continue(S); } @@ -1520,6 +1525,10 @@ class StmtChecker : public StmtVisitor { return S; } + Stmt *visitOpaqueStmt(OpaqueStmt *S) { + return S; + } + Stmt *visitBreakStmt(BreakStmt *S) { // Force the target to be computed in case it produces diagnostics. (void)S->getTarget(); @@ -3430,3 +3439,172 @@ FuncDecl *TypeChecker::getForEachIteratorNextFunction( // Fall back to AsyncIteratorProtocol.next(). return ctx.getAsyncIteratorNext(); } + +static BraceStmt *desugarForEachStmt(ForEachStmt* stmt){ + auto *parsedSequence = stmt->getParsedSequence(); + auto *dc = stmt->getDeclContext(); + auto &ctx = dc->getASTContext(); + bool isAsync = stmt->getAwaitLoc().isValid(); + + // If we have an unsafe expression for the sequence, lift it out of the + // sequence expression. We'll put it back after we've introduced the + // various calls. + UnsafeExpr *unsafeExpr = dyn_cast(parsedSequence); + if (unsafeExpr) { + parsedSequence = unsafeExpr->getSubExpr(); + } + + auto opaqueSeqExpr = new (ctx) OpaqueExpr(parsedSequence); + + std::string name; + { + if (auto np = dyn_cast_or_null(stmt->getPattern())) + name = "$"+np->getBoundName().str().str(); + name += "$generator"; + } + + auto *makeIteratorVar = new (ctx) + VarDecl(/*isStatic=*/false, VarDecl::Introducer::Var, + opaqueSeqExpr->getStartLoc(), + ctx.getIdentifier(name), dc); + makeIteratorVar->setImplicit(); + + // Async iterators are not `Sendable`; they're only meant to be used from + // the isolation domain that creates them. But the `next()` method runs on + // the generic executor, so calling it from an actor-isolated context passes + // non-`Sendable` state across the isolation boundary. `next()` should + // inherit the isolation of the caller, but for now, use the opt out. + if (isAsync) { + auto *nonisolated = + NonisolatedAttr::createImplicit(ctx, NonIsolatedModifier::Unsafe); + makeIteratorVar->addAttribute(nonisolated); + } + + // First, let's form a call from sequence to `.makeIterator()` and save + // that in a special variable which is going to be used by SILGen. + FuncDecl *makeIterator = isAsync ? ctx.getAsyncSequenceMakeAsyncIterator() + : ctx.getSequenceMakeIterator(); + + auto *makeIteratorRef = new (ctx) UnresolvedDotExpr( + opaqueSeqExpr, SourceLoc(), DeclNameRef(makeIterator->getName()), + DeclNameLoc(stmt->getForLoc()), /*implicit=*/true); + makeIteratorRef->setFunctionRefInfo(FunctionRefInfo::singleBaseNameApply()); + + Expr *makeIteratorCall = + CallExpr::createImplicitEmpty(ctx, makeIteratorRef); + + // Swap in the 'unsafe' expression. + if (unsafeExpr) { + unsafeExpr = UnsafeExpr::createImplicit(ctx, unsafeExpr->getUnsafeLoc(), makeIteratorCall); + makeIteratorCall = unsafeExpr; + } + + Pattern *pattern = NamedPattern::createImplicit(ctx, makeIteratorVar); + auto *PB = PatternBindingDecl::createImplicit( + ctx, StaticSpellingKind::None, pattern, makeIteratorCall, dc); + + if (TypeChecker::typeCheckPatternBinding(PB, 0)) + return nullptr; + + // The result type of `.makeIterator()` is used to form a call to + // `.next()`. `next()` is called on each iteration of the loop. + FuncDecl *nextFn = + TypeChecker::getForEachIteratorNextFunction(dc, stmt->getForLoc(), isAsync); + Identifier nextId = nextFn ? nextFn->getName().getBaseIdentifier() + : ctx.Id_next; + TinyPtrVector labels; + if (nextFn && nextFn->getParameters()->size() == 1) + labels.push_back(ctx.Id_isolation); + auto *makeIteratorVarRef = + new (ctx) DeclRefExpr(makeIteratorVar, DeclNameLoc(stmt->getForLoc()), + /*Implicit=*/true); + auto *nextRef = new (ctx) + UnresolvedDotExpr(makeIteratorVarRef, SourceLoc(), + DeclNameRef(DeclName(ctx, nextId, labels)), + DeclNameLoc(stmt->getForLoc()), /*implicit=*/true); + nextRef->setFunctionRefInfo(FunctionRefInfo::singleBaseNameApply()); + + ArgumentList *nextArgs; + if (nextFn && nextFn->getParameters()->size() == 1) { + auto isolationArg = + new (ctx) CurrentContextIsolationExpr(stmt->getForLoc(), Type()); + nextArgs = ArgumentList::createImplicit( + ctx, {Argument(SourceLoc(), ctx.Id_isolation, isolationArg)}); + } else { + nextArgs = ArgumentList::createImplicit(ctx, {}); + } + Expr *nextCall = CallExpr::createImplicit(ctx, nextRef, nextArgs); + + // `next` is always async but witness might not be throwing + if (isAsync) { + nextCall = + AwaitExpr::createImplicit(ctx, nextCall->getLoc(), nextCall); + } + + // Wrap the 'next' call in 'unsafe', if the for..in loop has that + // effect or if the loop is async (in which case the iterator variable + // is nonisolated(unsafe). + if (stmt->getUnsafeLoc().isValid() || + (isAsync && + ctx.LangOpts.StrictConcurrencyLevel == StrictConcurrency::Complete)) { + SourceLoc loc = stmt->getUnsafeLoc(); + bool implicit = stmt->getUnsafeLoc().isInvalid(); + if (loc.isInvalid()) + loc = stmt->getForLoc(); + nextCall = new (ctx) UnsafeExpr(loc, nextCall, Type(), implicit); + } + + auto elementPattern = stmt->getPattern(); + auto optPatternType = OptionalType::get(elementPattern->getType()); + swift::constraints::SyntacticElementTarget nextTarget(nextCall, dc, CTP_ForEachElement, + /*contextualType=*/optPatternType, + /*isDiscarded=*/false); + + auto nextCallTarget = TypeChecker::typeCheckExpression(nextTarget); + if (nextCallTarget == std::nullopt) + return nullptr; + nextCall = nextCallTarget->getAsExpr(); + + SmallVector cond; + + auto *somePattern = OptionalSomePattern::createImplicit(ctx, stmt->getPattern()); + somePattern->setType(optPatternType); + + auto PBI = ConditionalPatternBindingInfo::create(ctx, SourceLoc(), somePattern, nextCall); + auto conditionElement = StmtConditionElement(PBI); + cond.push_back(conditionElement); + + /* for ... in ... where cond { body } + * becomes: + * while ... { if cond then body else continue } + */ + auto* whereClause = stmt->getWhere(); + auto* forBody = stmt->getBody(); + + Stmt* whileBody = new (ctx) OpaqueStmt(forBody, SourceLoc(), SourceLoc()); + + if (whereClause) + { + SmallVector thenClause{whileBody}; + + whereClause = new (ctx) OpaqueExpr(whereClause); + + whileBody = new (ctx) IfStmt(SourceLoc(), whereClause, + BraceStmt::create(ctx, SourceLoc(), thenClause, SourceLoc()), SourceLoc(), + nullptr, /*implicit*/ true, ctx); + } + + // FIXME: do we need to do anything extra here or elseswhere if the for each + // stmt is async? + auto* whileStmt = new (ctx) WhileStmt(stmt->getLabelInfo(), SourceLoc(), ctx.AllocateCopy(cond), whileBody, true); + + SmallVector stmts; + stmts.push_back(PB); + stmts.push_back(whileStmt); + + return BraceStmt::create(ctx, stmt->getStartLoc(), stmts, stmt->getEndLoc()); +} + +BraceStmt* DesugarForEachStmtRequest::evaluate(Evaluator &evaluator, ForEachStmt *stmt) const { + return desugarForEachStmt(stmt); +}