This commit is contained in:
Daniel Kirchner 2023-06-25 07:04:50 +02:00
parent 6f352cbcbe
commit 5bfe862bc4
5 changed files with 66 additions and 72 deletions

View File

@ -149,7 +149,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
map<string, Type> functionTypes; map<string, Type> functionTypes;
Type typeVar = m_typeSystem.freshTypeVariable(false, {}); Type typeVar = m_typeSystem.freshTypeVariable({});
auto& typeMembers = annotation().members[&_typeClassDefinition]; auto& typeMembers = annotation().members[&_typeClassDefinition];
@ -160,7 +160,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
solAssert(functionDefinition); solAssert(functionDefinition);
auto functionDefinitionType = annotation(*functionDefinition).type; auto functionDefinitionType = annotation(*functionDefinition).type;
solAssert(functionDefinitionType); solAssert(functionDefinitionType);
auto functionType = m_env->fresh(*functionDefinitionType, true); auto functionType = m_env->fresh(*functionDefinitionType);
functionTypes[functionDefinition->name()] = functionType; functionTypes[functionDefinition->name()] = functionType;
auto typeVars = TypeSystemHelpers{m_typeSystem}.typeVars(functionType); auto typeVars = TypeSystemHelpers{m_typeSystem}.typeVars(functionType);
if(typeVars.size() != 1) if(typeVars.size() != 1)
@ -176,7 +176,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
solAssert(typeVariableAnnotation.type); solAssert(typeVariableAnnotation.type);
TypeSystemHelpers helper{m_typeSystem}; TypeSystemHelpers helper{m_typeSystem};
unify(*typeVariableAnnotation.type, helper.kindType(m_typeSystem.freshTypeVariable(false, Sort{{TypeClass{&_typeClassDefinition}}})), _typeClassDefinition.location()); unify(*typeVariableAnnotation.type, helper.kindType(m_typeSystem.freshTypeVariable(Sort{{TypeClass{&_typeClassDefinition}}})), _typeClassDefinition.location());
for (auto instantiation: m_analysis.annotation<TypeRegistration>(_typeClassDefinition).instantiations | ranges::views::values) for (auto instantiation: m_analysis.annotation<TypeRegistration>(_typeClassDefinition).instantiations | ranges::views::values)
// TODO: recursion-safety? // TODO: recursion-safety?
@ -255,7 +255,7 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression)
if (m_expressionContext != ExpressionContext::Type) if (m_expressionContext != ExpressionContext::Type)
{ {
m_errorReporter.typeError(0000_error, _expression.location(), "Elementary type name expression only supported in type context."); m_errorReporter.typeError(0000_error, _expression.location(), "Elementary type name expression only supported in type context.");
expressionAnnotation.type = m_typeSystem.freshTypeVariable(false, {}); expressionAnnotation.type = m_typeSystem.freshTypeVariable({});
return false; return false;
} }
TypeSystemHelpers helper{m_typeSystem}; TypeSystemHelpers helper{m_typeSystem};
@ -278,8 +278,8 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression)
break; break;
case Token::Pair: case Token::Pair:
{ {
auto leftType = m_typeSystem.freshTypeVariable(false, {}); auto leftType = m_typeSystem.freshTypeVariable({});
auto rightType = m_typeSystem.freshTypeVariable(false, {}); auto rightType = m_typeSystem.freshTypeVariable({});
expressionAnnotation.type = expressionAnnotation.type =
helper.functionType( helper.functionType(
helper.kindType(helper.tupleType({leftType, rightType})), helper.kindType(helper.tupleType({leftType, rightType})),
@ -289,8 +289,8 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression)
} }
case Token::Fun: case Token::Fun:
{ {
auto argType = m_typeSystem.freshTypeVariable(false, {}); auto argType = m_typeSystem.freshTypeVariable({});
auto resultType = m_typeSystem.freshTypeVariable(false, {}); auto resultType = m_typeSystem.freshTypeVariable({});
expressionAnnotation.type = expressionAnnotation.type =
helper.functionType( helper.functionType(
helper.kindType(helper.tupleType({argType, resultType})), helper.kindType(helper.tupleType({argType, resultType})),
@ -300,7 +300,7 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression)
} }
default: default:
m_errorReporter.typeError(0000_error, _expression.location(), "Only elementary types are supported."); m_errorReporter.typeError(0000_error, _expression.location(), "Only elementary types are supported.");
expressionAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, {})); expressionAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable({}));
break; break;
} }
return false; return false;
@ -318,7 +318,7 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
if (auto* operatorInfo = util::valueOrNullptr(m_analysis.annotation<TypeRegistration>().operators, _binaryOperation.getOperator())) if (auto* operatorInfo = util::valueOrNullptr(m_analysis.annotation<TypeRegistration>().operators, _binaryOperation.getOperator()))
{ {
auto [typeClass, functionName] = *operatorInfo; auto [typeClass, functionName] = *operatorInfo;
Type functionType = m_env->fresh(m_typeSystem.typeClassInfo(typeClass)->functions.at(functionName), true); Type functionType = m_env->fresh(m_typeSystem.typeClassInfo(typeClass)->functions.at(functionName));
_binaryOperation.leftExpression().accept(*this); _binaryOperation.leftExpression().accept(*this);
solAssert(leftAnnotation.type); solAssert(leftAnnotation.type);
@ -326,7 +326,7 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
solAssert(rightAnnotation.type); solAssert(rightAnnotation.type);
Type argTuple = helper.tupleType({*leftAnnotation.type, *rightAnnotation.type}); Type argTuple = helper.tupleType({*leftAnnotation.type, *rightAnnotation.type});
Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshTypeVariable(false, {})); Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshTypeVariable({}));
unify(functionType, genericFunctionType, _binaryOperation.location()); unify(functionType, genericFunctionType, _binaryOperation.location());
operationAnnotation.type = m_env->resolve(std::get<1>(helper.destFunctionType(m_env->resolve(genericFunctionType)))); operationAnnotation.type = m_env->resolve(std::get<1>(helper.destFunctionType(m_env->resolve(genericFunctionType))));
@ -336,7 +336,7 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
else else
{ {
m_errorReporter.typeError(0000_error, _binaryOperation.location(), "Binary operations in term context not yet supported."); m_errorReporter.typeError(0000_error, _binaryOperation.location(), "Binary operations in term context not yet supported.");
operationAnnotation.type = m_typeSystem.freshTypeVariable(false, {}); operationAnnotation.type = m_typeSystem.freshTypeVariable({});
return false; return false;
} }
case ExpressionContext::Type: case ExpressionContext::Type:
@ -363,7 +363,7 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
else else
{ {
m_errorReporter.typeError(0000_error, _binaryOperation.leftExpression().location(), "Expected type but got " + m_env->typeToString(*leftAnnotation.type)); m_errorReporter.typeError(0000_error, _binaryOperation.leftExpression().location(), "Expected type but got " + m_env->typeToString(*leftAnnotation.type));
return m_typeSystem.freshTypeVariable(false, {}); return m_typeSystem.freshTypeVariable({});
} }
}; };
Type leftType = getType(leftAnnotation.type); Type leftType = getType(leftAnnotation.type);
@ -373,12 +373,12 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
else else
{ {
m_errorReporter.typeError(0000_error, _binaryOperation.location(), "Invalid binary operations in type context."); m_errorReporter.typeError(0000_error, _binaryOperation.location(), "Invalid binary operations in type context.");
operationAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, {})); operationAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable({}));
} }
return false; return false;
case ExpressionContext::Sort: case ExpressionContext::Sort:
m_errorReporter.typeError(0000_error, _binaryOperation.location(), "Invalid binary operation in sort context."); m_errorReporter.typeError(0000_error, _binaryOperation.location(), "Invalid binary operation in sort context.");
operationAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, {})); operationAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable({}));
return false; return false;
} }
return false; return false;
@ -419,19 +419,19 @@ bool TypeInference::visit(VariableDeclaration const& _variableDeclaration)
auto& typeExpressionAnnotation = annotation(*_variableDeclaration.typeExpression()); auto& typeExpressionAnnotation = annotation(*_variableDeclaration.typeExpression());
solAssert(typeExpressionAnnotation.type); solAssert(typeExpressionAnnotation.type);
if (helper.isKindType(*typeExpressionAnnotation.type)) if (helper.isKindType(*typeExpressionAnnotation.type))
variableAnnotation.type = m_env->fresh(helper.destKindType(*typeExpressionAnnotation.type), false); variableAnnotation.type = helper.destKindType(*typeExpressionAnnotation.type);
else else
{ {
m_errorReporter.typeError(0000_error, _variableDeclaration.typeExpression()->location(), "Expected type, but got " + m_env->typeToString(*typeExpressionAnnotation.type)); m_errorReporter.typeError(0000_error, _variableDeclaration.typeExpression()->location(), "Expected type, but got " + m_env->typeToString(*typeExpressionAnnotation.type));
variableAnnotation.type = m_typeSystem.freshTypeVariable(false, {}); variableAnnotation.type = m_typeSystem.freshTypeVariable({});
} }
return false; return false;
} }
variableAnnotation.type = m_typeSystem.freshTypeVariable(false, {}); variableAnnotation.type = m_typeSystem.freshTypeVariable({});
return false; return false;
case ExpressionContext::Type: case ExpressionContext::Type:
variableAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, {})); variableAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable({}));
if (_variableDeclaration.typeExpression()) if (_variableDeclaration.typeExpression())
{ {
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Sort}; ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Sort};
@ -463,7 +463,7 @@ void TypeInference::endVisit(Assignment const& _assignment)
if (m_expressionContext != ExpressionContext::Term) if (m_expressionContext != ExpressionContext::Term)
{ {
m_errorReporter.typeError(0000_error, _assignment.location(), "Assignment outside term context."); m_errorReporter.typeError(0000_error, _assignment.location(), "Assignment outside term context.");
assignmentAnnotation.type = m_typeSystem.freshTypeVariable(false, {}); assignmentAnnotation.type = m_typeSystem.freshTypeVariable({});
return; return;
} }
@ -511,13 +511,13 @@ experimental::Type TypeInference::handleIdentifierByReferencedDeclaration(langut
solAssert(declarationAnnotation.type); solAssert(declarationAnnotation.type);
if (dynamic_cast<FunctionDefinition const*>(&_declaration)) if (dynamic_cast<FunctionDefinition const*>(&_declaration))
return m_env->fresh(*declarationAnnotation.type, true); return m_env->fresh(*declarationAnnotation.type);
else if (dynamic_cast<VariableDeclaration const*>(&_declaration)) else if (dynamic_cast<VariableDeclaration const*>(&_declaration))
return *declarationAnnotation.type; return *declarationAnnotation.type;
else if (dynamic_cast<TypeClassDefinition const*>(&_declaration)) else if (dynamic_cast<TypeClassDefinition const*>(&_declaration))
return *declarationAnnotation.type; return m_env->fresh(*declarationAnnotation.type);
else if (dynamic_cast<TypeDefinition const*>(&_declaration)) else if (dynamic_cast<TypeDefinition const*>(&_declaration))
return *declarationAnnotation.type; return m_env->fresh(*declarationAnnotation.type);
else else
solAssert(false); solAssert(false);
break; break;
@ -549,7 +549,7 @@ experimental::Type TypeInference::handleIdentifierByReferencedDeclaration(langut
else if (dynamic_cast<TypeDefinition const*>(&_declaration)) else if (dynamic_cast<TypeDefinition const*>(&_declaration))
{ {
// TODO: helper.destKindType(*declarationAnnotation.type); // TODO: helper.destKindType(*declarationAnnotation.type);
return m_env->fresh(*declarationAnnotation.type, true); return m_env->fresh(*declarationAnnotation.type);
} }
else else
solAssert(false); solAssert(false);
@ -561,12 +561,12 @@ experimental::Type TypeInference::handleIdentifierByReferencedDeclaration(langut
{ {
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Term}; ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Term};
typeClass->accept(*this); typeClass->accept(*this);
return helper.kindType(m_typeSystem.freshTypeVariable(false, Sort{{TypeClass{typeClass}}})); return helper.kindType(m_typeSystem.freshTypeVariable(Sort{{TypeClass{typeClass}}}));
} }
else else
{ {
m_errorReporter.typeError(0000_error, _location, "Expected type class."); m_errorReporter.typeError(0000_error, _location, "Expected type class.");
return helper.kindType(m_typeSystem.freshTypeVariable(false, {})); return helper.kindType(m_typeSystem.freshTypeVariable({}));
} }
break; break;
} }
@ -595,7 +595,7 @@ bool TypeInference::visit(Identifier const& _identifier)
case ExpressionContext::Type: case ExpressionContext::Type:
{ {
// TODO: register free type variable name! // TODO: register free type variable name!
identifierAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, {})); identifierAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable({}));
return false; return false;
} }
case ExpressionContext::Sort: case ExpressionContext::Sort:
@ -626,7 +626,7 @@ void TypeInference::endVisit(TupleExpression const& _tupleExpression)
else else
{ {
m_errorReporter.typeError(0000_error, _expr->location(), "Expected type, but got " + m_env->typeToString(*componentAnnotation.type)); m_errorReporter.typeError(0000_error, _expr->location(), "Expected type, but got " + m_env->typeToString(*componentAnnotation.type));
return m_typeSystem.freshTypeVariable(false, {}); return m_typeSystem.freshTypeVariable({});
} }
case ExpressionContext::Sort: case ExpressionContext::Sort:
return *componentAnnotation.type; return *componentAnnotation.type;
@ -643,7 +643,7 @@ void TypeInference::endVisit(TupleExpression const& _tupleExpression)
break; break;
case ExpressionContext::Sort: case ExpressionContext::Sort:
{ {
Type type = helper.kindType(m_typeSystem.freshTypeVariable(false, {})); Type type = helper.kindType(m_typeSystem.freshTypeVariable({}));
for (auto componentType: componentTypes) for (auto componentType: componentTypes)
unify(type, componentType, _tupleExpression.location()); unify(type, componentType, _tupleExpression.location());
expressionAnnotation.type = type; expressionAnnotation.type = type;
@ -768,19 +768,19 @@ void TypeInference::endVisit(MemberAccess const& _memberAccess)
auto constructor = std::get<0>(helper.destTypeConstant(expressionType)); auto constructor = std::get<0>(helper.destTypeConstant(expressionType));
if (auto* typeMember = util::valueOrNullptr(annotation().members.at(constructor), _memberAccess.memberName())) if (auto* typeMember = util::valueOrNullptr(annotation().members.at(constructor), _memberAccess.memberName()))
{ {
Type type = m_env->fresh(typeMember->type, true); Type type = m_env->fresh(typeMember->type);
annotation(_memberAccess).type = type; annotation(_memberAccess).type = type;
return; return;
} }
else else
{ {
m_errorReporter.typeError(0000_error, _memberAccess.memberLocation(), "Member not found."); m_errorReporter.typeError(0000_error, _memberAccess.memberLocation(), "Member not found.");
annotation(_memberAccess).type = m_typeSystem.freshTypeVariable(false, {}); annotation(_memberAccess).type = m_typeSystem.freshTypeVariable({});
return; return;
} }
} }
m_errorReporter.typeError(0000_error, _memberAccess.expression().location(), "Unsupported member access expression."); m_errorReporter.typeError(0000_error, _memberAccess.expression().location(), "Unsupported member access expression.");
annotation(_memberAccess).type = m_typeSystem.freshTypeVariable(false, {}); annotation(_memberAccess).type = m_typeSystem.freshTypeVariable({});
} }
bool TypeInference::visit(MemberAccess const& _memberAccess) bool TypeInference::visit(MemberAccess const& _memberAccess)
@ -788,7 +788,7 @@ bool TypeInference::visit(MemberAccess const& _memberAccess)
if (m_expressionContext != ExpressionContext::Term) if (m_expressionContext != ExpressionContext::Term)
{ {
m_errorReporter.typeError(0000_error, _memberAccess.location(), "Member access outside term context."); m_errorReporter.typeError(0000_error, _memberAccess.location(), "Member access outside term context.");
annotation(_memberAccess).type = m_typeSystem.freshTypeVariable(false, {}); annotation(_memberAccess).type = m_typeSystem.freshTypeVariable({});
return false; return false;
} }
return true; return true;
@ -815,7 +815,8 @@ bool TypeInference::visit(TypeDefinition const& _typeDefinition)
vector<Type> arguments; vector<Type> arguments;
if (_typeDefinition.arguments()) if (_typeDefinition.arguments())
for (size_t i = 0; i < _typeDefinition.arguments()->parameters().size(); ++i) for (size_t i = 0; i < _typeDefinition.arguments()->parameters().size(); ++i)
arguments.emplace_back(m_typeSystem.freshTypeVariable(true, {})); // TODO: GENERALIZE
arguments.emplace_back(m_typeSystem.freshTypeVariable({}));
Type type = m_typeSystem.type(TypeConstructor{&_typeDefinition}, arguments); Type type = m_typeSystem.type(TypeConstructor{&_typeDefinition}, arguments);
if (arguments.empty()) if (arguments.empty())
@ -866,12 +867,12 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
else else
{ {
m_errorReporter.typeError(0000_error, arg->location(), "Expected type, but got " + m_env->typeToString(*argAnnotation.type)); m_errorReporter.typeError(0000_error, arg->location(), "Expected type, but got " + m_env->typeToString(*argAnnotation.type));
argTypes.emplace_back(m_typeSystem.freshTypeVariable(false, {})); argTypes.emplace_back(m_typeSystem.freshTypeVariable({}));
} }
break; break;
case ExpressionContext::Sort: case ExpressionContext::Sort:
m_errorReporter.typeError(0000_error, _functionCall.location(), "Function call in sort context."); m_errorReporter.typeError(0000_error, _functionCall.location(), "Function call in sort context.");
functionCallAnnotation.type = m_typeSystem.freshTypeVariable(false, {}); functionCallAnnotation.type = m_typeSystem.freshTypeVariable({});
break; break;
} }
} }
@ -881,7 +882,7 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
case ExpressionContext::Term: case ExpressionContext::Term:
{ {
Type argTuple = helper.tupleType(argTypes); Type argTuple = helper.tupleType(argTypes);
Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshTypeVariable(false, {})); Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshTypeVariable({}));
unify(functionType, genericFunctionType, _functionCall.location()); unify(functionType, genericFunctionType, _functionCall.location());
functionCallAnnotation.type = m_env->resolve(std::get<1>(helper.destFunctionType(m_env->resolve(genericFunctionType)))); functionCallAnnotation.type = m_env->resolve(std::get<1>(helper.destFunctionType(m_env->resolve(genericFunctionType))));
@ -890,7 +891,7 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
case ExpressionContext::Type: case ExpressionContext::Type:
{ {
Type argTuple = helper.kindType(helper.tupleType(argTypes)); Type argTuple = helper.kindType(helper.tupleType(argTypes));
Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshKindVariable(false, {})); Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshKindVariable({}));
unify(functionType, genericFunctionType, _functionCall.location()); unify(functionType, genericFunctionType, _functionCall.location());
functionCallAnnotation.type = m_env->resolve(std::get<1>(helper.destFunctionType(m_env->resolve(genericFunctionType)))); functionCallAnnotation.type = m_env->resolve(std::get<1>(helper.destFunctionType(m_env->resolve(genericFunctionType))));
@ -1075,6 +1076,6 @@ bool TypeInference::visit(Literal const& _literal)
m_errorReporter.typeError(0000_error, _literal.location(), "Only integers are supported."); m_errorReporter.typeError(0000_error, _literal.location(), "Only integers are supported.");
return false; return false;
} }
literalAnnotation.type = m_typeSystem.freshTypeVariable(false, Sort{{TypeClass{BuiltinClass::Integer}}}); literalAnnotation.type = m_typeSystem.freshTypeVariable(Sort{{TypeClass{BuiltinClass::Integer}}});
return false; return false;
} }

View File

@ -47,7 +47,7 @@ m_typeSystem(_analysis.typeSystem())
m_typeSystem.declareTypeConstructor(type, name, arity); m_typeSystem.declareTypeConstructor(type, name, arity);
auto declareBuiltinClass = [&](BuiltinClass _class, auto _memberCreator, Sort _sort = {}) { auto declareBuiltinClass = [&](BuiltinClass _class, auto _memberCreator, Sort _sort = {}) {
Type type = m_typeSystem.freshTypeVariable(false, std::move(_sort)); Type type = m_typeSystem.freshTypeVariable(std::move(_sort));
auto error = m_typeSystem.declareTypeClass( auto error = m_typeSystem.declareTypeClass(
TypeClass{_class}, TypeClass{_class},
type, type,

View File

@ -112,15 +112,13 @@ struct Arity
struct TypeVariable struct TypeVariable
{ {
size_t index() const { return m_index; } size_t index() const { return m_index; }
bool generic() const { return m_generic; }
Sort const& sort() const { return m_sort; } Sort const& sort() const { return m_sort; }
private: private:
friend class TypeSystem; friend class TypeSystem;
size_t m_index = 0; size_t m_index = 0;
Sort m_sort; Sort m_sort;
bool m_generic = false; TypeVariable(size_t _index, Sort _sort):
TypeVariable(size_t _index, Sort _sort, bool _generic): m_index(_index), m_sort(std::move(_sort)) {}
m_index(_index), m_sort(std::move(_sort)), m_generic(_generic) {}
}; };
} }

View File

@ -160,7 +160,7 @@ std::string TypeEnvironment::typeToString(Type const& _type) const
}, },
[](TypeVariable const& _type) { [](TypeVariable const& _type) {
std::stringstream stream; std::stringstream stream;
stream << (_type.generic() ? '?' : '\'') << "var" << _type.index(); stream << "'var" << _type.index();
switch (_type.sort().classes.size()) switch (_type.sort().classes.size())
{ {
case 0: case 0:
@ -204,7 +204,7 @@ vector<TypeEnvironment::UnificationFailure> TypeEnvironment::unify(Type _a, Type
failures += instantiate(_right, _left); failures += instantiate(_right, _left);
else else
{ {
Type newVar = m_typeSystem.freshVariable(_left.generic() && _right.generic(), _left.sort() + _right.sort()); Type newVar = m_typeSystem.freshVariable(_left.sort() + _right.sort());
failures += instantiate(_left, newVar); failures += instantiate(_left, newVar);
failures += instantiate(_right, newVar); failures += instantiate(_right, newVar);
} }
@ -285,22 +285,22 @@ TypeSystem::TypeSystem()
}; };
} }
experimental::Type TypeSystem::freshVariable(bool _generic, Sort _sort) experimental::Type TypeSystem::freshVariable(Sort _sort)
{ {
uint64_t index = m_numTypeVariables++; uint64_t index = m_numTypeVariables++;
return TypeVariable(index, std::move(_sort), _generic); return TypeVariable(index, std::move(_sort));
} }
experimental::Type TypeSystem::freshTypeVariable(bool _generic, Sort _sort) experimental::Type TypeSystem::freshTypeVariable(Sort _sort)
{ {
_sort.classes.emplace(TypeClass{BuiltinClass::Type}); _sort.classes.emplace(TypeClass{BuiltinClass::Type});
return freshVariable(_generic, _sort); return freshVariable(_sort);
} }
experimental::Type TypeSystem::freshKindVariable(bool _generic, Sort _sort) experimental::Type TypeSystem::freshKindVariable(Sort _sort)
{ {
_sort.classes.emplace(TypeClass{BuiltinClass::Kind}); _sort.classes.emplace(TypeClass{BuiltinClass::Kind});
return freshVariable(_generic, _sort); return freshVariable(_sort);
} }
vector<TypeEnvironment::UnificationFailure> TypeEnvironment::instantiate(TypeVariable _variable, Type _type) vector<TypeEnvironment::UnificationFailure> TypeEnvironment::instantiate(TypeVariable _variable, Type _type)
@ -418,38 +418,33 @@ experimental::Type TypeSystem::type(TypeConstructor _constructor, std::vector<Ty
return TypeConstant{_constructor, _arguments}; return TypeConstant{_constructor, _arguments};
} }
experimental::Type TypeEnvironment::fresh(Type _type, bool _generalize) experimental::Type TypeEnvironment::fresh(Type _type)
{ {
std::unordered_map<uint64_t, Type> mapping; std::unordered_map<uint64_t, Type> mapping;
auto freshImpl = [&](Type _type, bool _generalize, auto _recurse) -> Type { auto freshImpl = [&](Type _type, auto _recurse) -> Type {
return std::visit(util::GenericVisitor{ return std::visit(util::GenericVisitor{
[&](TypeConstant const& _type) -> Type { [&](TypeConstant const& _type) -> Type {
return TypeConstant{ return TypeConstant{
_type.constructor, _type.constructor,
_type.arguments | ranges::views::transform([&](Type _argType) { _type.arguments | ranges::views::transform([&](Type _argType) {
return _recurse(_argType, _generalize, _recurse); return _recurse(_argType, _recurse);
}) | ranges::to<vector<Type>> }) | ranges::to<vector<Type>>
}; };
}, },
[&](TypeVariable const& _var) -> Type { [&](TypeVariable const& _var) -> Type {
if (_generalize || _var.generic()) if (auto* mapped = util::valueOrNullptr(mapping, _var.index()))
{ {
if (auto* mapped = util::valueOrNullptr(mapping, _var.index())) auto* typeVariable = get_if<TypeVariable>(mapped);
{ solAssert(typeVariable);
auto* typeVariable = get_if<TypeVariable>(mapped); // TODO: can there be a mismatch?
solAssert(typeVariable); solAssert(typeVariable->sort() == _var.sort());
// TODO: can there be a mismatch? return *mapped;
solAssert(typeVariable->sort() == _var.sort());
return *mapped;
}
return mapping[_var.index()] = m_typeSystem.freshTypeVariable(true, _var.sort());
} }
else return mapping[_var.index()] = m_typeSystem.freshTypeVariable(_var.sort());
return _type;
}, },
}, resolve(_type)); }, resolve(_type));
}; };
return freshImpl(_type, _generalize, freshImpl); return freshImpl(_type, freshImpl);
} }
std::optional<std::string> TypeSystem::instantiateClass(Type _instanceVariable, Arity _arity, map<string, Type> _functionTypes) std::optional<std::string> TypeSystem::instantiateClass(Type _instanceVariable, Arity _arity, map<string, Type> _functionTypes)

View File

@ -37,7 +37,7 @@ public:
TypeEnvironment clone() const; TypeEnvironment clone() const;
Type resolve(Type _type) const; Type resolve(Type _type) const;
Type resolveRecursive(Type _type) const; Type resolveRecursive(Type _type) const;
Type fresh(Type _type, bool _generalize); Type fresh(Type _type);
struct TypeMismatch { Type a; Type b; }; struct TypeMismatch { Type a; Type b; };
struct SortMismatch { Type type; Sort sort; }; struct SortMismatch { Type type; Sort sort; };
using UnificationFailure = std::variant<TypeMismatch, SortMismatch>; using UnificationFailure = std::variant<TypeMismatch, SortMismatch>;
@ -99,13 +99,13 @@ public:
[[nodiscard]] std::optional<std::string> declareTypeClass(TypeClass _class, Type _typeVariable, std::map<std::string, Type> _functions); [[nodiscard]] std::optional<std::string> declareTypeClass(TypeClass _class, Type _typeVariable, std::map<std::string, Type> _functions);
[[nodiscard]] std::optional<std::string> instantiateClass(Type _instanceVariable, Arity _arity, std::map<std::string, Type> _functions); [[nodiscard]] std::optional<std::string> instantiateClass(Type _instanceVariable, Arity _arity, std::map<std::string, Type> _functions);
Type freshTypeVariable(bool _generic, Sort _sort); Type freshTypeVariable(Sort _sort);
Type freshKindVariable(bool _generic, Sort _sort); Type freshKindVariable(Sort _sort);
TypeEnvironment const& env() const { return m_globalTypeEnvironment; } TypeEnvironment const& env() const { return m_globalTypeEnvironment; }
TypeEnvironment& env() { return m_globalTypeEnvironment; } TypeEnvironment& env() { return m_globalTypeEnvironment; }
Type freshVariable(bool _generic, Sort _sort); Type freshVariable(Sort _sort);
private: private:
size_t m_numTypeVariables = 0; size_t m_numTypeVariables = 0;
std::map<TypeConstructor, TypeConstructorInfo> m_typeConstructors; std::map<TypeConstructor, TypeConstructorInfo> m_typeConstructors;