Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,9 @@ class alignas(1 << TypeAlignInBits) TypeBase
/// can only be a GenericTypeParamType.
bool isRootParameterPack();

bool isParameterPackExpansion();
bool isRootParameterPackExpansion();

/// Determine whether this type is a value parameter 'let N: Int', which is a
/// GenericTypeParamType.
///
Expand Down Expand Up @@ -7529,6 +7532,10 @@ static CanGenericTypeParamType getType(unsigned depth, unsigned index,
return CanGenericTypeParamType(
GenericTypeParamType::getType(depth, index, C));
}
static CanGenericTypeParamType getPackType(unsigned depth, unsigned index, const ASTContext &ctx) {
return CanGenericTypeParamType(
GenericTypeParamType::getPack(depth, index, ctx));
}
static CanGenericTypeParamType getOpaqueResultType(unsigned depth, unsigned index,
const ASTContext &C) {
return CanGenericTypeParamType(
Expand Down
5 changes: 5 additions & 0 deletions lib/AST/ParameterPack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ bool TypeBase::isRootParameterPack() {
t->castTo<GenericTypeParamType>()->isParameterPack();
}

bool TypeBase::isParameterPackExpansion() {
Type t(this);
return t->getKind() == TypeKind::PackExpansion;
}

PackType *TypeBase::getPackSubstitutionAsPackType() {
if (auto pack = getAs<PackType>()) {
return pack;
Expand Down
160 changes: 103 additions & 57 deletions lib/AST/RequirementMachine/InterfaceType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,13 @@ MutableTerm RewriteContext::getMutableTermForType(CanType paramType,
const ProtocolDecl *proto) {
ASSERT(paramType->isTypeParameter());

llvm::dbgs() << "getMutableTermForType "<< paramType << "\n";

// Collect zero or more nested type names in reverse order.
bool innermostAssocTypeWasResolved = false;


bool containsShape = false;

SmallVector<Symbol, 3> symbols;
while (auto memberType = dyn_cast<DependentMemberType>(paramType)) {
paramType = memberType.getBase();
Expand All @@ -171,6 +175,10 @@ MutableTerm RewriteContext::getMutableTermForType(CanType paramType,
thisProto = proto;
innermostAssocTypeWasResolved = true;
}
if (paramType->getKind() == TypeKind::Pack) {
containsShape = true;
symbols.pop_back();
}
symbols.push_back(Symbol::forAssociatedType(thisProto,
assocType->getName(),
*this));
Expand All @@ -195,6 +203,11 @@ MutableTerm RewriteContext::getMutableTermForType(CanType paramType,

std::reverse(symbols.begin(), symbols.end());

if (containsShape)
symbols.push_back(Symbol::forShape(*this));

llvm::dbgs() << MutableTerm(symbols);

return MutableTerm(symbols);
}

Expand Down Expand Up @@ -270,47 +283,47 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end,

for (auto *iter = begin; iter != end; ++iter) {
auto symbol = *iter;

if (!result) {
// A valid term always begins with a generic parameter, protocol or
// associated type symbol.
switch (symbol.getKind()) {
case Symbol::Kind::GenericParam:
handleRoot(symbol.getGenericParam());
continue;

case Symbol::Kind::Protocol:
handleRoot(ctx.getASTContext().TheSelfType);
continue;

case Symbol::Kind::AssociatedType:
handleRoot(ctx.getASTContext().TheSelfType);

// An associated type symbol at the root means we have a dependent
// member type rooted at Self; handle the associated type below.
break;

case Symbol::Kind::PackElement:
continue;

case Symbol::Kind::Name:
case Symbol::Kind::Layout:
case Symbol::Kind::Superclass:
case Symbol::Kind::ConcreteType:
case Symbol::Kind::ConcreteConformance:
case Symbol::Kind::Shape:
ABORT([&](auto &out) {
out << "Invalid root symbol: " << MutableTerm(begin, end);
});
case Symbol::Kind::GenericParam:
handleRoot(symbol.getGenericParam());
continue;
case Symbol::Kind::Protocol:
handleRoot(ctx.getASTContext().TheSelfType);
continue;
case Symbol::Kind::AssociatedType:
handleRoot(ctx.getASTContext().TheSelfType);
// An associated type symbol at the root means we have a dependent
// member type rooted at Self; handle the associated type below.
break;
case Symbol::Kind::PackElement:
continue;
case Symbol::Kind::Name:
case Symbol::Kind::Layout:
case Symbol::Kind::Superclass:
case Symbol::Kind::ConcreteType:
case Symbol::Kind::ConcreteConformance:
case Symbol::Kind::Shape:
ABORT([&](auto &out) {
out << "Invalid root symbol: " << MutableTerm(begin, end);
});
}
}

// An unresolved type can appear if we have invalid requirements.
if (symbol.getKind() == Symbol::Kind::Name) {
result = DependentMemberType::get(result, symbol.getName());
continue;
}

// We can end up with an unsimplified term like this:
//
// X.[P].[P:X]
Expand All @@ -323,12 +336,12 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end,
auto proto = (iter + 1)->getProtocol();
ASSERT(proto == symbol.getProtocol());
}

continue;
}

ASSERT(symbol.getKind() == Symbol::Kind::AssociatedType);

MutableTerm prefix;
if (begin == iter) {
// If the term begins with an associated type symbol, look for the
Expand All @@ -345,17 +358,17 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end,
// for an associated type in those protocols.
prefix.append(begin, iter);
}

auto *props = map.lookUpProperties(prefix.rbegin(), prefix.rend());
if (props == nullptr) {
ABORT([&](auto &out) {
out << "Cannot build interface type for term "
<< MutableTerm(begin, end) << "\n";
<< MutableTerm(begin, end) << "\n";
out << "Prefix does not conform to any protocols: " << prefix << "\n\n";
map.dump(out);
});
}

// Assert that the associated type's protocol appears among the set
// of protocols that the prefix conforms to.
if (CONDITIONAL_ASSERT_enabled()) {
Expand All @@ -364,24 +377,24 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end,
symbol.getProtocol())
!= conformsTo.end());
}

auto *assocType = props->getAssociatedType(symbol.getName());
if (assocType == nullptr) {
ABORT([&](auto &out) {
out << "Cannot build interface type for term "
<< MutableTerm(begin, end) << "\n";
<< MutableTerm(begin, end) << "\n";
out << "Prefix term does not have a nested type named "
<< symbol.getName() << ": " << prefix << "\n";
<< symbol.getName() << ": " << prefix << "\n";
out << "Property map entry: ";
props->dump(out);
out << "\n\n";
map.dump(out);
});
}

result = DependentMemberType::get(result, assocType);
}

llvm::dbgs() << "type for symbol range " << result << " from " << MutableTerm(begin,end);
return result;
}

Expand Down Expand Up @@ -420,12 +433,21 @@ MutableTerm
RewriteContext::getRelativeTermForType(CanType typeWitness,
ArrayRef<Term> substitutions) {
MutableTerm result;

llvm::dbgs() << "getrelativeterm "<< typeWitness << "\n ";
for(auto s : substitutions)
s.dump(llvm::dbgs());
llvm::dbgs() <<"\n\n";
// Get the substitution S corresponding to τ_0_n.
unsigned index = getGenericParamIndex(typeWitness->getRootGenericParam());

result = MutableTerm(substitutions[index]);
ASSERT(result.back().getKind() != Symbol::Kind::Shape);
MutableTerm endInShape;

// If the substitution ends in a Shape, save it for the end of processing
if (result.isPackTerm()) {
endInShape = MutableTerm(result.removeEnd());
}

// If the substitution is a term consisting of a single protocol symbol
// [P], save P for later.
const ProtocolDecl *proto = nullptr;
Expand Down Expand Up @@ -463,6 +485,9 @@ RewriteContext::getRelativeTermForType(CanType typeWitness,
for (auto iter = symbols.rbegin(), end = symbols.rend(); iter != end; ++iter)
result.add(*iter);

if (!endInShape.empty())
result.append(endInShape);
llvm::dbgs() << "result of relative term for type ", result.dump(llvm::dbgs());
return result;
}

Expand All @@ -480,12 +505,16 @@ Type PropertyMap::getTypeFromSubstitutionSchema(
return schema.transformWithPosition(
TypePosition::Invariant,
[&](Type t, TypePosition pos) -> std::optional<Type> {
llvm::dbgs() << "\n transform with position in getTypeFromSubstSchema ", t.dump(llvm::dbgs());

// Consider if as in TypeDifference, there should be a different treatment for PackExpansionTypes here
if (t->is<GenericTypeParamType>()) {
auto index = RewriteContext::getGenericParamIndex(t);
auto substitution = substitutions[index];

bool isShapePosition = (pos == TypePosition::Shape);
bool isShapeTerm = (substitution.back() == Symbol::forShape(Context));
auto shapeTerm = MutableTerm(substitution);
bool isShapeTerm = shapeTerm.isPackTerm();
if (isShapePosition != isShapeTerm) {
ABORT([&](auto &out) {
out << "Shape vs. type mixup\n\n";
Expand All @@ -504,8 +533,8 @@ Type PropertyMap::getTypeFromSubstitutionSchema(
// Undo the thing where the count type of a PackExpansionType
// becomes a shape term.
if (isShapeTerm) {
MutableTerm mutTerm(substitution.begin(), substitution.end() - 1);
substitution = Term::get(mutTerm, Context);
shapeTerm.removeEnd();
substitution = Term::get(shapeTerm, Context);
}

// Prepend the prefix of the lookup key to the substitution.
Expand Down Expand Up @@ -563,21 +592,35 @@ RewriteContext::getRelativeSubstitutionSchemaFromType(
[&](Type t, TypePosition pos) -> std::optional<Type> {
if (!t->isTypeParameter())
return std::nullopt;

auto term = getRelativeTermForType(CanType(t), substitutions);


llvm::dbgs() << "\n transform with position in getRelativeSubstitutionFromTYpe ", term.dump(llvm::dbgs());

unsigned index = result.size();

// PackExpansionType(pattern=T, count=U) becomes
// PackExpansionType(pattern=τ_0_0, count=τ_0_1) with
//
// τ_0_0 := T
// τ_0_1 := U.[shape]
ASSERT(pos != TypePosition::Shape && "Not implemented");

unsigned index = result.size();


// Turn a count type of PackExpansion type away from a shape term
if (pos == TypePosition::Shape) {
if (term.isPackTerm()) {
MutableTerm mutTerm(term);
mutTerm.removeEnd();
term = mutTerm;
}

result.push_back(Term::get(term, *this));
return CanGenericTypeParamType::getPackType(/*depth=*/0, index, Context);
}

result.push_back(Term::get(term, *this));

return CanGenericTypeParamType::getType(/*depth=*/0, index, Context);

}));
}

Expand Down Expand Up @@ -610,13 +653,16 @@ RewriteContext::getSubstitutionSchemaFromType(CanType concreteType,
// τ_0_0 := T
// τ_0_1 := U.[shape]
MutableTerm term = getMutableTermForType(CanType(t), proto);
if (pos == TypePosition::Shape)
bool shapePos = pos == TypePosition::Shape;
if (shapePos)
term.add(Symbol::forShape(*this));

unsigned index = result.size();

result.push_back(Term::get(term, *this));

return CanGenericTypeParamType::getType(/*depth=*/0, index, Context);
if (shapePos)
return CanGenericTypeParamType::getPackType(0, index, Context);
else
return CanGenericTypeParamType::getType(0, index, Context);
}));
}
10 changes: 6 additions & 4 deletions lib/AST/RequirementMachine/PropertyUnification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,10 +421,10 @@ void PropertyMap::unifyConcreteTypes(Term key,

bool debug = Debug.contains(DebugFlags::ConcreteUnification);

if (debug) {
//if (debug) {
llvm::dbgs() << "% Unifying " << lhsProperty
<< " with " << rhsProperty << "\n";
}
//}

std::optional<unsigned> lhsDifferenceID;
std::optional<unsigned> rhsDifferenceID;
Expand Down Expand Up @@ -489,8 +489,8 @@ void PropertyMap::unifyConcreteTypes(Term key,
ASSERT(!rhsDifferenceID);

const auto &lhsDifference = System.getTypeDifference(*lhsDifferenceID);
ASSERT(lhsProperty == lhsDifference.LHS);
ASSERT(rhsProperty == lhsDifference.RHS);
//ASSERT(lhsProperty == lhsDifference.LHS);
//ASSERT(rhsProperty == lhsDifference.RHS);

// Build a rewrite path (T.[RHS] => T).
RewritePath path;
Expand Down Expand Up @@ -570,7 +570,9 @@ void PropertyMap::unifyConcreteTypes(
// Unify this rule with all other concrete type rules we've seen so far,
// to record rewrite loops relating the rules and their projections.
for (auto pair : existingRules) {
llvm::dbgs() << "one call of unify types\n";
unifyConcreteTypes(key, pair.first, pair.second, property, ruleID);
llvm::dbgs() << "call of unify type ends\n";
}

// Record the new rule.
Expand Down
2 changes: 2 additions & 0 deletions lib/AST/RequirementMachine/RequirementBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,14 @@ void RequirementBuilder::addRequirementRules(ArrayRef<unsigned> rules) {

if (constraintTerm.back().getKind() == Symbol::Kind::Shape) {
ASSERT(rule.getRHS().back().getKind() == Symbol::Kind::Shape);
llvm::dbgs() << "stripping off shape term\n";
// Strip off the shape symbol from the constraint term.
constraintTerm = MutableTerm(constraintTerm.begin(),
constraintTerm.end() - 1);
}

if (constraintTerm.front().getKind() == Symbol::Kind::PackElement) {
llvm::dbgs() << "stripping off front PackELement\n";
// Strip off the element symbol from the constraint term.
constraintTerm = MutableTerm(constraintTerm.begin() + 1,
constraintTerm.end());
Expand Down
Loading