Skip to content

Commit e7ec1ab

Browse files
committed
wip: desugar ForEachStmt at AST level
State: Currently trying to figure out how to compute the Continue and Break targets since they need to point to the synthesized While statement. Current test case: ``` func myFunc() { let seq = [0,1,2] for _ in seq { continue } } ``` I have not yet tested any behavior of generated SIL code but it seems reasonable in comparison. So far, I have tested to compile regular (non-async, non-borrowing, non-unsafe) and unsafe ForEachStmts. Here: * include/swift/AST/ASTBridging.h, * include/swift/AST/Expr.h, * include/swift/AST/ExprNodes.def, * include/swift/AST/Stmt.h, * include/swift/AST/StmtNodes.def, * include/swift/AST/TypeCheckRequests.h, * include/swift/AST/TypeCheckerTypeIDZone.def, * include/swift/Sema/ConstraintLocator.h, * include/swift/Sema/SyntacticElementTarget.h, * lib/AST/ASTDumper.cpp, * lib/AST/ASTPrinter.cpp, * lib/AST/ASTScopeCreation.cpp, * lib/AST/ASTVerifier.cpp, * lib/AST/ASTWalker.cpp, * lib/AST/Bridging/StmtBridging.cpp, * lib/AST/Expr.cpp, * lib/AST/Stmt.cpp, * lib/AST/TypeCheckRequests.cpp, * lib/ASTGen/Sources/ASTGen/Stmts.swift, * lib/Parse/ParseStmt.cpp, * lib/SILGen/ASTVisitor.h, * lib/SILGen/SILGenExpr.cpp, * lib/SILGen/SILGenStmt.cpp, * lib/SILOptimizer/Mandatory/MoveOnlyDiagnostics.cpp, * lib/Sema/BuilderTransform.cpp, * lib/Sema/CSApply.cpp, * lib/Sema/CSDiagnostics.cpp, * lib/Sema/CSGen.cpp, * lib/Sema/CSSimplify.cpp, * lib/Sema/CSSyntacticElement.cpp, * lib/Sema/SyntacticElementTarget.cpp, * lib/Sema/TypeCheckEffects.cpp, * lib/Sema/TypeCheckStmt.cpp.
1 parent ef9e3ef commit e7ec1ab

33 files changed

+437
-526
lines changed

include/swift/AST/ASTBridging.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2404,13 +2404,14 @@ BridgedFallthroughStmt_createParsed(swift::SourceLoc loc,
24042404
BridgedDeclContext cDC);
24052405

24062406
SWIFT_NAME("BridgedForEachStmt.createParsed(_:labelInfo:forLoc:tryLoc:awaitLoc:"
2407-
"unsafeLoc:pattern:inLoc:sequence:whereLoc:whereExpr:body:)")
2407+
"unsafeLoc:pattern:inLoc:sequence:whereLoc:whereExpr:body:declContext:)")
24082408
BridgedForEachStmt BridgedForEachStmt_createParsed(
24092409
BridgedASTContext cContext, BridgedLabeledStmtInfo cLabelInfo,
24102410
swift::SourceLoc forLoc, swift::SourceLoc tryLoc, swift::SourceLoc awaitLoc,
24112411
swift::SourceLoc unsafeLoc, BridgedPattern cPat, swift::SourceLoc inLoc,
24122412
BridgedExpr cSequence, swift::SourceLoc whereLoc,
2413-
BridgedNullableExpr cWhereExpr, BridgedBraceStmt cBody);
2413+
BridgedNullableExpr cWhereExpr, BridgedBraceStmt cBody,
2414+
BridgedDeclContext cDeclContext);
24142415

24152416
SWIFT_NAME("BridgedGuardStmt.createParsed(_:guardLoc:conds:body:)")
24162417
BridgedGuardStmt BridgedGuardStmt_createParsed(BridgedASTContext cContext,

include/swift/AST/Expr.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6724,6 +6724,27 @@ class MacroExpansionExpr final : public Expr,
67246724
}
67256725
};
67266726

6727+
/// OpaqueExpr - created to serve as an indirection to a ForEachStmt's sequence
6728+
/// expr and where clause to avoid visiting it twice in the ASTWalker after
6729+
/// having desugared the loop. This will only be processed in SILGen to emit
6730+
/// the underlying expression.
6731+
class OpaqueExpr final : public Expr {
6732+
Expr *OriginalExpr;
6733+
6734+
public:
6735+
OpaqueExpr(Expr* originalExpr)
6736+
: Expr(ExprKind::Opaque, /*implicit*/ true, originalExpr->getType()),
6737+
OriginalExpr(originalExpr) {}
6738+
6739+
Expr *getOriginalExpr() const { return OriginalExpr; }
6740+
SourceLoc getStartLoc() const { return OriginalExpr->getStartLoc(); }
6741+
SourceLoc getEndLoc() const { return OriginalExpr->getEndLoc(); }
6742+
6743+
static bool classof(const Expr *E) {
6744+
return E->getKind() == ExprKind::Opaque;
6745+
}
6746+
};
6747+
67276748
inline bool Expr::isInfixOperator() const {
67286749
return isa<BinaryExpr>(this) || isa<TernaryExpr>(this) ||
67296750
isa<AssignExpr>(this) || isa<ExplicitCastExpr>(this);

include/swift/AST/ExprNodes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ EXPR(Tap, Expr)
218218
UNCHECKED_EXPR(TypeJoin, Expr)
219219
EXPR(MacroExpansion, Expr)
220220
EXPR(TypeValue, Expr)
221+
EXPR(Opaque, Expr)
221222
// Don't forget to update the LAST_EXPR below when adding a new Expr here.
222223
LAST_EXPR(TypeValue)
223224

include/swift/AST/Stmt.h

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,21 +1003,21 @@ class ForEachStmt : public LabeledStmt {
10031003
SourceLoc WhereLoc;
10041004
Expr *WhereExpr = nullptr;
10051005
BraceStmt *Body;
1006+
DeclContext* DC = nullptr;
10061007

10071008
// Set by Sema:
10081009
ProtocolConformanceRef sequenceConformance = ProtocolConformanceRef();
10091010
Type sequenceType;
1010-
PatternBindingDecl *iteratorVar = nullptr;
10111011
Expr *nextCall = nullptr;
1012-
OpaqueValueExpr *elementExpr = nullptr;
1012+
BraceStmt *desugaredStmt = nullptr;
10131013
Expr *convertElementExpr = nullptr;
10141014

10151015
public:
10161016
ForEachStmt(LabeledStmtInfo LabelInfo, SourceLoc ForLoc, SourceLoc TryLoc,
10171017
SourceLoc AwaitLoc, SourceLoc UnsafeLoc, Pattern *Pat,
10181018
SourceLoc InLoc, Expr *Sequence,
10191019
SourceLoc WhereLoc, Expr *WhereExpr, BraceStmt *Body,
1020-
std::optional<bool> implicit = std::nullopt)
1020+
DeclContext* DC, std::optional<bool> implicit = std::nullopt)
10211021
: LabeledStmt(StmtKind::ForEach, getDefaultImplicitFlag(implicit, ForLoc),
10221022
LabelInfo),
10231023
ForLoc(ForLoc), TryLoc(TryLoc), AwaitLoc(AwaitLoc), UnsafeLoc(UnsafeLoc),
@@ -1026,15 +1026,9 @@ class ForEachStmt : public LabeledStmt {
10261026
setPattern(Pat);
10271027
}
10281028

1029-
void setIteratorVar(PatternBindingDecl *var) { iteratorVar = var; }
1030-
PatternBindingDecl *getIteratorVar() const { return iteratorVar; }
1031-
10321029
void setNextCall(Expr *next) { nextCall = next; }
10331030
Expr *getNextCall() const { return nextCall; }
10341031

1035-
void setElementExpr(OpaqueValueExpr *expr) { elementExpr = expr; }
1036-
OpaqueValueExpr *getElementExpr() const { return elementExpr; }
1037-
10381032
void setConvertElementExpr(Expr *expr) { convertElementExpr = expr; }
10391033
Expr *getConvertElementExpr() const { return convertElementExpr; }
10401034

@@ -1076,20 +1070,23 @@ class ForEachStmt : public LabeledStmt {
10761070
Expr *getParsedSequence() const { return Sequence; }
10771071
void setParsedSequence(Expr *S) { Sequence = S; }
10781072

1079-
/// Type-checked version of the sequence or nullptr if this statement
1080-
/// yet to be type-checked.
1081-
Expr *getTypeCheckedSequence() const;
1082-
10831073
/// getBody - Retrieve the body of the loop.
10841074
BraceStmt *getBody() const { return Body; }
10851075
void setBody(BraceStmt *B) { Body = B; }
10861076

10871077
SourceLoc getStartLoc() const { return getLabelLocOrKeywordLoc(ForLoc); }
10881078
SourceLoc getEndLoc() const { return Body->getEndLoc(); }
1079+
1080+
DeclContext *getDeclContext() const { return DC; }
1081+
void setDeclContext(DeclContext *newDC) { DC = newDC; }
10891082

10901083
static bool classof(const Stmt *S) {
10911084
return S->getKind() == StmtKind::ForEach;
10921085
}
1086+
1087+
BraceStmt* desugar();
1088+
BraceStmt* getDesugaredStmt() const { return desugaredStmt; }
1089+
void setDesugaredStmt(BraceStmt* newStmt) { desugaredStmt = newStmt; }
10931090
};
10941091

10951092
/// A pattern and an optional guard expression used in a 'case' statement.
@@ -1541,6 +1538,31 @@ class DoCatchStmt final
15411538
}
15421539
};
15431540

1541+
/// OpaqueStmt - created to serve as an indirection to a ForEachStmt's body
1542+
/// to avoid visiting it twice in the ASTWalker after having desugared the loop.
1543+
/// This ensures we only visit the body once, and this OpaqueStmt will only be
1544+
/// visited to emit the underlying statement in SILGen.
1545+
class OpaqueStmt final : public Stmt {
1546+
SourceLoc StartLoc;
1547+
SourceLoc EndLoc;
1548+
BraceStmt *Body; // FIXME: should I just use Stmt * so that this is more versatile?
1549+
// If not, should the class be renamed to be more specific?
1550+
public:
1551+
OpaqueStmt(BraceStmt* body, SourceLoc startLoc, SourceLoc endLoc)
1552+
: Stmt(StmtKind::Opaque, true /*always implicit*/),
1553+
StartLoc(startLoc), EndLoc(endLoc), Body(body) {}
1554+
1555+
SourceLoc getLoc() const { return StartLoc; }
1556+
SourceLoc getStartLoc() const { return StartLoc; }
1557+
SourceLoc getEndLoc() const { return EndLoc; }
1558+
1559+
BraceStmt* getUnderlyingStmt() { return Body; }
1560+
1561+
static bool classof(const Stmt *S) {
1562+
return S->getKind() == StmtKind::Opaque;
1563+
}
1564+
};
1565+
15441566
/// BreakStmt - The "break" and "break label" statement.
15451567
class BreakStmt : public Stmt {
15461568
SourceLoc Loc;

include/swift/AST/StmtNodes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ ABSTRACT_STMT(Labeled, Stmt)
6161
LABELED_STMT(ForEach, LabeledStmt)
6262
LABELED_STMT(Switch, LabeledStmt)
6363
STMT_RANGE(Labeled, If, Switch)
64+
STMT(Opaque, Stmt)
6465
STMT(Case, Stmt)
6566
STMT(Break, Stmt)
6667
STMT(Continue, Stmt)

include/swift/AST/TypeCheckRequests.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5591,6 +5591,25 @@ class IsCustomAvailabilityDomainPermanentlyEnabled
55915591
}
55925592
};
55935593

5594+
class DesugarForEachStmtRequest
5595+
: public SimpleRequest<DesugarForEachStmtRequest,
5596+
BraceStmt *(ForEachStmt*),
5597+
RequestFlags::SeparatelyCached> {
5598+
public:
5599+
using SimpleRequest::SimpleRequest;
5600+
5601+
private:
5602+
friend SimpleRequest;
5603+
5604+
// Evaluation.
5605+
BraceStmt *evaluate(Evaluator &evaluator, ForEachStmt *FES) const;
5606+
5607+
public:
5608+
bool isCached() const { return true; }
5609+
std::optional<BraceStmt*> getCachedResult() const;
5610+
void cacheResult(BraceStmt *stmt) const;
5611+
};
5612+
55945613
#define SWIFT_TYPEID_ZONE TypeChecker
55955614
#define SWIFT_TYPEID_HEADER "swift/AST/TypeCheckerTypeIDZone.def"
55965615
#include "swift/Basic/DefineTypeIDZone.h"

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,3 +674,7 @@ SWIFT_REQUEST(TypeChecker, IsCustomAvailabilityDomainPermanentlyEnabled,
674674
SWIFT_REQUEST(TypeChecker, EmitPerformanceHints,
675675
evaluator::SideEffect(SourceFile *),
676676
Cached, NoLocationInfo)
677+
678+
SWIFT_REQUEST(TypeChecker, DesugarForEachStmtRequest,
679+
Stmt*(const ForEachStmt*),
680+
Cached, NoLocationInfo)

include/swift/Sema/ConstraintLocator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ enum ContextualTypePurpose : uint8_t {
8383

8484
CTP_ExprPattern, ///< `~=` operator application associated with expression
8585
/// pattern.
86+
87+
CTP_ForEachElement, ///< Element expression associated with `for-in` loop.
8688
};
8789

8890
namespace constraints {

include/swift/Sema/SyntacticElementTarget.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,6 @@ struct SequenceIterationInfo {
4141

4242
/// The type of the pattern that matches the elements.
4343
Type initType;
44-
45-
/// Implicit `$iterator = <sequence>.makeIterator()`
46-
PatternBindingDecl *makeIteratorVar;
47-
48-
/// Implicit `$iterator.next()` call.
49-
Expr *nextCall;
5044
};
5145

5246
/// Describes information about a for-in loop over a pack that needs to be
@@ -605,6 +599,7 @@ class SyntacticElementTarget {
605599
case CTP_Initialization:
606600
case CTP_ForEachSequence:
607601
case CTP_ExprPattern:
602+
case CTP_ForEachElement:
608603
break;
609604
default:
610605
assert(false && "Unexpected contextual type purpose");

lib/AST/ASTDumper.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3252,6 +3252,10 @@ class PrintStmt : public StmtVisitor<PrintStmt, void, Label>,
32523252
printFlag(S->TrailingSemiLoc.isValid(), "trailing_semi");
32533253
}
32543254

3255+
void visitOpaqueStmt(OpaqueStmt *S, Label label){
3256+
visitBraceStmt(S->getUnderlyingStmt(), label);
3257+
}
3258+
32553259
void visitBraceStmt(BraceStmt *S, Label label) {
32563260
printCommon(S, "brace_stmt", label);
32573261
printList(S->getElements(), [&](auto &Elt, Label label) {
@@ -3332,20 +3336,15 @@ class PrintStmt : public StmtVisitor<PrintStmt, void, Label>,
33323336
printRec(S->getWhere(), Label::always("where"));
33333337
}
33343338
printRec(S->getParsedSequence(), Label::optional("parsed_sequence"));
3335-
if (S->getIteratorVar()) {
3336-
printRec(S->getIteratorVar(), Label::optional("iterator_var"));
3337-
}
3338-
if (S->getNextCall()) {
3339-
printRec(S->getNextCall(), Label::optional("next_call"));
3340-
}
33413339
if (S->getConvertElementExpr()) {
33423340
printRec(S->getConvertElementExpr(),
33433341
Label::optional("convert_element_expr"));
33443342
}
3345-
if (S->getElementExpr()) {
3346-
printRec(S->getElementExpr(), Label::optional("element_expr"));
3347-
}
3343+
33483344
printRec(S->getBody(), Label::optional("body"));
3345+
3346+
printRec(S->getDesugaredStmt(), Label::optional("desugared_loop"));
3347+
33493348
printFoot();
33503349
}
33513350
void visitBreakStmt(BreakStmt *S, Label label) {
@@ -4237,6 +4236,10 @@ class PrintExpr : public ExprVisitor<PrintExpr, void, Label>,
42374236
printFoot();
42384237
}
42394238

4239+
void visitOpaqueExpr(OpaqueExpr *E, Label label){
4240+
visit(E->getOriginalExpr(), label);
4241+
}
4242+
42404243
void visitPropertyWrapperValuePlaceholderExpr(
42414244
PropertyWrapperValuePlaceholderExpr *E, Label label) {
42424245
printCommon(E, "property_wrapper_value_placeholder_expr", label);

0 commit comments

Comments
 (0)