Skip to content

Commit 01ca519

Browse files
committed
Sema: Sink Protocols down from BindingSet into PotentialBindings
1 parent 33e2cfa commit 01ca519

File tree

5 files changed

+59
-43
lines changed

5 files changed

+59
-43
lines changed

include/swift/Sema/CSBindings.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,9 @@ struct PotentialBindings {
241241
llvm::SmallVector<std::pair<TypeVariableType *, Constraint *>, 4> SupertypeOf;
242242
llvm::SmallVector<std::pair<TypeVariableType *, Constraint *>, 4> EquivalentTo;
243243

244+
/// The set of protocol conformance requirements imposed on this type variable.
245+
llvm::SmallVector<Constraint *, 4> Protocols;
246+
244247
ASTNode AssociatedCodeCompletionToken = ASTNode();
245248

246249
/// Add a potential binding to the list of bindings,
@@ -256,6 +259,10 @@ struct PotentialBindings {
256259
});
257260
}
258261

262+
ArrayRef<Constraint *> getConformanceRequirements() const {
263+
return Protocols;
264+
}
265+
259266
private:
260267
/// Attempt to infer a new binding and other useful information
261268
/// (i.e. whether bindings should be delayed) from the given
@@ -365,9 +372,6 @@ class BindingSet {
365372
public:
366373
swift::SmallSetVector<PotentialBinding, 4> Bindings;
367374

368-
/// The set of protocol conformance requirements placed on this type variable.
369-
llvm::SmallVector<Constraint *, 4> Protocols;
370-
371375
/// The set of unique literal protocol requirements placed on this
372376
/// type variable or inferred transitively through subtype chains.
373377
///
@@ -494,10 +498,6 @@ class BindingSet {
494498
return hasViableBindings() || isDirectHole();
495499
}
496500

497-
ArrayRef<Constraint *> getConformanceRequirements() const {
498-
return Protocols;
499-
}
500-
501501
unsigned getNumViableLiteralBindings() const;
502502

503503
unsigned getNumViableDefaultableBindings() const {

include/swift/Sema/CSTrail.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ GRAPH_NODE_CHANGE(RemovedConstraint)
7777
GRAPH_NODE_CHANGE(InferredBindings)
7878
GRAPH_NODE_CHANGE(RetractedBindings)
7979
GRAPH_NODE_CHANGE(RetractedDelayedBy)
80+
GRAPH_NODE_CHANGE(RetractedProtocol)
8081

8182
BINDING_RELATION_CHANGE(RetractedAdjacentVar)
8283
BINDING_RELATION_CHANGE(RetractedSubtypeOf)

lib/Sema/CSBindings.cpp

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,6 @@ BindingSet::BindingSet(ConstraintSystem &CS, TypeVariableType *TypeVar,
4747

4848
for (auto *constraint : info.Constraints) {
4949
switch (constraint->getKind()) {
50-
case ConstraintKind::NonisolatedConformsTo:
51-
case ConstraintKind::ConformsTo:
52-
if (constraint->getSecondType()->is<ProtocolType>())
53-
Protocols.push_back(constraint);
54-
break;
55-
5650
case ConstraintKind::LiteralConformsTo:
5751
addLiteralRequirement(constraint);
5852
break;
@@ -435,6 +429,8 @@ void BindingSet::inferTransitiveProtocolRequirements() {
435429
}
436430

437431
auto &bindings = node.getBindingSet();
432+
auto conformanceReqs =
433+
node.getPotentialBindings().getConformanceRequirements();
438434

439435
// If current variable already has transitive protocol
440436
// conformances inferred, there is no need to look deeper
@@ -443,8 +439,8 @@ void BindingSet::inferTransitiveProtocolRequirements() {
443439
TypeVariableType *parent = nullptr;
444440
std::tie(parent, currentVar) = workList.pop_back_val();
445441
assert(parent);
446-
propagateProtocolsTo(parent, bindings.getConformanceRequirements(),
447-
*bindings.TransitiveProtocols);
442+
propagateProtocolsTo(parent, conformanceReqs,
443+
*bindings.TransitiveProtocols);
448444
continue;
449445
}
450446

@@ -485,14 +481,16 @@ void BindingSet::inferTransitiveProtocolRequirements() {
485481
if (!node.hasBindingSet())
486482
continue;
487483

488-
const auto &bindings = node.getBindingSet();
484+
auto conformanceReqs =
485+
node.getPotentialBindings().getConformanceRequirements();
489486

490487
llvm::SmallPtrSet<Constraint *, 2> placeholder;
491488
// Add any direct protocols from members of the
492489
// equivalence class, so they could be propagated
493490
// to all of the members.
494-
propagateProtocolsTo(currentVar, bindings.getConformanceRequirements(),
495-
placeholder);
491+
propagateProtocolsTo(currentVar, conformanceReqs, placeholder);
492+
493+
const auto &bindings = node.getBindingSet();
496494

497495
// Since type variables are equal, current type variable
498496
// becomes a subtype to any supertype found in the current
@@ -512,8 +510,7 @@ void BindingSet::inferTransitiveProtocolRequirements() {
512510
// are transitive to its parent, propagate them down the subtype/equivalence
513511
// chain.
514512
if (parent) {
515-
propagateProtocolsTo(parent, bindings.getConformanceRequirements(),
516-
protocols[currentVar]);
513+
propagateProtocolsTo(parent, conformanceReqs, protocols[currentVar]);
517514
}
518515

519516
auto &inferredProtocols = protocols[currentVar];
@@ -526,9 +523,8 @@ void BindingSet::inferTransitiveProtocolRequirements() {
526523
// - all of the transitive protocols inferred through
527524
// the members of the equivalence class.
528525
{
529-
auto directRequirements = bindings.getConformanceRequirements();
530-
protocolsForEquivalence.insert(directRequirements.begin(),
531-
directRequirements.end());
526+
protocolsForEquivalence.insert(conformanceReqs.begin(),
527+
conformanceReqs.end());
532528

533529
protocolsForEquivalence.insert(inferredProtocols.begin(),
534530
inferredProtocols.end());
@@ -2063,6 +2059,12 @@ void PotentialBindings::infer(ConstraintSystem &CS,
20632059
break;
20642060
}
20652061

2062+
case ConstraintKind::NonisolatedConformsTo:
2063+
case ConstraintKind::ConformsTo:
2064+
if (constraint->getSecondType()->is<ProtocolType>())
2065+
Protocols.push_back(constraint);
2066+
break;
2067+
20662068
case ConstraintKind::BridgingConversion:
20672069
case ConstraintKind::CheckedCast:
20682070
case ConstraintKind::EscapableFunctionOf:
@@ -2076,8 +2078,6 @@ void PotentialBindings::infer(ConstraintSystem &CS,
20762078
case ConstraintKind::PackElementOf:
20772079
case ConstraintKind::SameShape:
20782080
case ConstraintKind::MaterializePackExpansion:
2079-
case ConstraintKind::NonisolatedConformsTo:
2080-
case ConstraintKind::ConformsTo:
20812081
case ConstraintKind::LiteralConformsTo:
20822082
case ConstraintKind::Defaultable:
20832083
case ConstraintKind::FallbackType:
@@ -2206,21 +2206,27 @@ void PotentialBindings::retract(ConstraintSystem &CS,
22062206
}),
22072207
Bindings.end());
22082208

2209+
#define CALLBACK(ChangeKind) \
2210+
[&](Constraint *other) { \
2211+
if (other == constraint) { \
2212+
if (recordingChanges) { \
2213+
CS.recordChange(SolverTrail::Change::ChangeKind( \
2214+
TypeVar, constraint)); \
2215+
} \
2216+
return true; \
2217+
} \
2218+
return false; \
2219+
}
2220+
22092221
DelayedBy.erase(
2210-
llvm::remove_if(DelayedBy,
2211-
[&](Constraint *existing) {
2212-
if (existing == constraint) {
2213-
if (recordingChanges) {
2214-
CS.recordChange(SolverTrail::Change::RetractedDelayedBy(
2215-
TypeVar, constraint));
2216-
}
2217-
return true;
2218-
}
2219-
return false;
2220-
}),
2222+
llvm::remove_if(DelayedBy, CALLBACK(RetractedDelayedBy)),
22212223
DelayedBy.end());
22222224

2223-
#define CALLBACK(ChangeKind) \
2225+
Protocols.erase(
2226+
llvm::remove_if(Protocols, CALLBACK(RetractedProtocol)),
2227+
Protocols.end());
2228+
2229+
#define PAIR_CALLBACK(ChangeKind) \
22242230
[&](std::pair<TypeVariableType *, Constraint *> pair) { \
22252231
if (pair.second == constraint) { \
22262232
if (recordingChanges) { \
@@ -2233,19 +2239,19 @@ void PotentialBindings::retract(ConstraintSystem &CS,
22332239
}
22342240

22352241
AdjacentVars.erase(
2236-
llvm::remove_if(AdjacentVars, CALLBACK(RetractedAdjacentVar)),
2242+
llvm::remove_if(AdjacentVars, PAIR_CALLBACK(RetractedAdjacentVar)),
22372243
AdjacentVars.end());
22382244

22392245
SubtypeOf.erase(
2240-
llvm::remove_if(SubtypeOf, CALLBACK(RetractedSubtypeOf)),
2246+
llvm::remove_if(SubtypeOf, PAIR_CALLBACK(RetractedSubtypeOf)),
22412247
SubtypeOf.end());
22422248

22432249
SupertypeOf.erase(
2244-
llvm::remove_if(SupertypeOf, CALLBACK(RetractedSupertypeOf)),
2250+
llvm::remove_if(SupertypeOf, PAIR_CALLBACK(RetractedSupertypeOf)),
22452251
SupertypeOf.end());
22462252

22472253
EquivalentTo.erase(
2248-
llvm::remove_if(EquivalentTo, CALLBACK(RetractedEquivalentTo)),
2254+
llvm::remove_if(EquivalentTo, PAIR_CALLBACK(RetractedEquivalentTo)),
22492255
EquivalentTo.end());
22502256

22512257
#undef CALLBACK

lib/Sema/CSTrail.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,11 @@ void SolverTrail::Change::undo(ConstraintSystem &cs) const {
538538
.DelayedBy.push_back(TheConstraint.Constraint);
539539
break;
540540

541+
case ChangeKind::RetractedProtocol:
542+
cg[TheConstraint.TypeVar].getPotentialBindings()
543+
.Protocols.push_back(TheConstraint.Constraint);
544+
break;
545+
541546
case ChangeKind::RetractedAdjacentVar:
542547
cg[BindingRelation.TypeVar].getPotentialBindings()
543548
.AdjacentVars.emplace_back(BindingRelation.OtherTypeVar,

unittests/Sema/BindingInferenceTests.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,9 @@ TEST_F(SemaTest, TestTransitiveProtocolInference) {
197197
CTP_Initialization)));
198198

199199
auto &bindings = inferBindings(cs, typeVar);
200-
ASSERT_TRUE(bindings.getConformanceRequirements().empty());
200+
ASSERT_TRUE(cs.getConstraintGraph()[typeVar]
201+
.getPotentialBindings().getConformanceRequirements().empty());
202+
201203
ASSERT_TRUE(bool(bindings.TransitiveProtocols));
202204
verifyProtocolInferenceResults(*bindings.TransitiveProtocols,
203205
{protocolTy1});
@@ -218,8 +220,10 @@ TEST_F(SemaTest, TestTransitiveProtocolInference) {
218220
cs.addConstraint(ConstraintKind::Conversion, typeVar, GPT1,
219221
cs.getConstraintLocator({}));
220222

223+
ASSERT_TRUE(cs.getConstraintGraph()[typeVar]
224+
.getPotentialBindings().getConformanceRequirements().empty());
225+
221226
auto &bindings = inferBindings(cs, typeVar);
222-
ASSERT_TRUE(bindings.getConformanceRequirements().empty());
223227
ASSERT_TRUE(bool(bindings.TransitiveProtocols));
224228
verifyProtocolInferenceResults(*bindings.TransitiveProtocols,
225229
{protocolTy1, protocolTy2});

0 commit comments

Comments
 (0)