This commit is contained in:
Daniel Kirchner 2023-06-25 07:04:50 +02:00
parent 6f352cbcbe
commit 5bfe862bc4
5 changed files with 66 additions and 72 deletions

View File

@ -149,7 +149,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
map<string, Type> functionTypes;
Type typeVar = m_typeSystem.freshTypeVariable(false, {});
Type typeVar = m_typeSystem.freshTypeVariable({});
auto& typeMembers = annotation().members[&_typeClassDefinition];
@ -160,7 +160,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
solAssert(functionDefinition);
auto functionDefinitionType = annotation(*functionDefinition).type;
solAssert(functionDefinitionType);
auto functionType = m_env->fresh(*functionDefinitionType, true);
auto functionType = m_env->fresh(*functionDefinitionType);
functionTypes[functionDefinition->name()] = functionType;
auto typeVars = TypeSystemHelpers{m_typeSystem}.typeVars(functionType);
if(typeVars.size() != 1)
@ -176,7 +176,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
solAssert(typeVariableAnnotation.type);
TypeSystemHelpers helper{m_typeSystem};
unify(*typeVariableAnnotation.type, helper.kindType(m_typeSystem.freshTypeVariable(false, Sort{{TypeClass{&_typeClassDefinition}}})), _typeClassDefinition.location());
unify(*typeVariableAnnotation.type, helper.kindType(m_typeSystem.freshTypeVariable(Sort{{TypeClass{&_typeClassDefinition}}})), _typeClassDefinition.location());
for (auto instantiation: m_analysis.annotation<TypeRegistration>(_typeClassDefinition).instantiations | ranges::views::values)
// TODO: recursion-safety?
@ -255,7 +255,7 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression)
if (m_expressionContext != ExpressionContext::Type)
{
m_errorReporter.typeError(0000_error, _expression.location(), "Elementary type name expression only supported in type context.");
expressionAnnotation.type = m_typeSystem.freshTypeVariable(false, {});
expressionAnnotation.type = m_typeSystem.freshTypeVariable({});
return false;
}
TypeSystemHelpers helper{m_typeSystem};
@ -278,8 +278,8 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression)
break;
case Token::Pair:
{
auto leftType = m_typeSystem.freshTypeVariable(false, {});
auto rightType = m_typeSystem.freshTypeVariable(false, {});
auto leftType = m_typeSystem.freshTypeVariable({});
auto rightType = m_typeSystem.freshTypeVariable({});
expressionAnnotation.type =
helper.functionType(
helper.kindType(helper.tupleType({leftType, rightType})),
@ -289,8 +289,8 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression)
}
case Token::Fun:
{
auto argType = m_typeSystem.freshTypeVariable(false, {});
auto resultType = m_typeSystem.freshTypeVariable(false, {});
auto argType = m_typeSystem.freshTypeVariable({});
auto resultType = m_typeSystem.freshTypeVariable({});
expressionAnnotation.type =
helper.functionType(
helper.kindType(helper.tupleType({argType, resultType})),
@ -300,7 +300,7 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression)
}
default:
m_errorReporter.typeError(0000_error, _expression.location(), "Only elementary types are supported.");
expressionAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, {}));
expressionAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable({}));
break;
}
return false;
@ -318,7 +318,7 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
if (auto* operatorInfo = util::valueOrNullptr(m_analysis.annotation<TypeRegistration>().operators, _binaryOperation.getOperator()))
{
auto [typeClass, functionName] = *operatorInfo;
Type functionType = m_env->fresh(m_typeSystem.typeClassInfo(typeClass)->functions.at(functionName), true);
Type functionType = m_env->fresh(m_typeSystem.typeClassInfo(typeClass)->functions.at(functionName));
_binaryOperation.leftExpression().accept(*this);
solAssert(leftAnnotation.type);
@ -326,7 +326,7 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
solAssert(rightAnnotation.type);
Type argTuple = helper.tupleType({*leftAnnotation.type, *rightAnnotation.type});
Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshTypeVariable(false, {}));
Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshTypeVariable({}));
unify(functionType, genericFunctionType, _binaryOperation.location());
operationAnnotation.type = m_env->resolve(std::get<1>(helper.destFunctionType(m_env->resolve(genericFunctionType))));
@ -336,7 +336,7 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
else
{
m_errorReporter.typeError(0000_error, _binaryOperation.location(), "Binary operations in term context not yet supported.");
operationAnnotation.type = m_typeSystem.freshTypeVariable(false, {});
operationAnnotation.type = m_typeSystem.freshTypeVariable({});
return false;
}
case ExpressionContext::Type:
@ -363,7 +363,7 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
else
{
m_errorReporter.typeError(0000_error, _binaryOperation.leftExpression().location(), "Expected type but got " + m_env->typeToString(*leftAnnotation.type));
return m_typeSystem.freshTypeVariable(false, {});
return m_typeSystem.freshTypeVariable({});
}
};
Type leftType = getType(leftAnnotation.type);
@ -373,12 +373,12 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
else
{
m_errorReporter.typeError(0000_error, _binaryOperation.location(), "Invalid binary operations in type context.");
operationAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, {}));
operationAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable({}));
}
return false;
case ExpressionContext::Sort:
m_errorReporter.typeError(0000_error, _binaryOperation.location(), "Invalid binary operation in sort context.");
operationAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, {}));
operationAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable({}));
return false;
}
return false;
@ -419,19 +419,19 @@ bool TypeInference::visit(VariableDeclaration const& _variableDeclaration)
auto& typeExpressionAnnotation = annotation(*_variableDeclaration.typeExpression());
solAssert(typeExpressionAnnotation.type);
if (helper.isKindType(*typeExpressionAnnotation.type))
variableAnnotation.type = m_env->fresh(helper.destKindType(*typeExpressionAnnotation.type), false);
variableAnnotation.type = helper.destKindType(*typeExpressionAnnotation.type);
else
{
m_errorReporter.typeError(0000_error, _variableDeclaration.typeExpression()->location(), "Expected type, but got " + m_env->typeToString(*typeExpressionAnnotation.type));
variableAnnotation.type = m_typeSystem.freshTypeVariable(false, {});
variableAnnotation.type = m_typeSystem.freshTypeVariable({});
}
return false;
}
variableAnnotation.type = m_typeSystem.freshTypeVariable(false, {});
variableAnnotation.type = m_typeSystem.freshTypeVariable({});
return false;
case ExpressionContext::Type:
variableAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, {}));
variableAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable({}));
if (_variableDeclaration.typeExpression())
{
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Sort};
@ -463,7 +463,7 @@ void TypeInference::endVisit(Assignment const& _assignment)
if (m_expressionContext != ExpressionContext::Term)
{
m_errorReporter.typeError(0000_error, _assignment.location(), "Assignment outside term context.");
assignmentAnnotation.type = m_typeSystem.freshTypeVariable(false, {});
assignmentAnnotation.type = m_typeSystem.freshTypeVariable({});
return;
}
@ -511,13 +511,13 @@ experimental::Type TypeInference::handleIdentifierByReferencedDeclaration(langut
solAssert(declarationAnnotation.type);
if (dynamic_cast<FunctionDefinition const*>(&_declaration))
return m_env->fresh(*declarationAnnotation.type, true);
return m_env->fresh(*declarationAnnotation.type);
else if (dynamic_cast<VariableDeclaration const*>(&_declaration))
return *declarationAnnotation.type;
else if (dynamic_cast<TypeClassDefinition const*>(&_declaration))
return *declarationAnnotation.type;
return m_env->fresh(*declarationAnnotation.type);
else if (dynamic_cast<TypeDefinition const*>(&_declaration))
return *declarationAnnotation.type;
return m_env->fresh(*declarationAnnotation.type);
else
solAssert(false);
break;
@ -549,7 +549,7 @@ experimental::Type TypeInference::handleIdentifierByReferencedDeclaration(langut
else if (dynamic_cast<TypeDefinition const*>(&_declaration))
{
// TODO: helper.destKindType(*declarationAnnotation.type);
return m_env->fresh(*declarationAnnotation.type, true);
return m_env->fresh(*declarationAnnotation.type);
}
else
solAssert(false);
@ -561,12 +561,12 @@ experimental::Type TypeInference::handleIdentifierByReferencedDeclaration(langut
{
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Term};
typeClass->accept(*this);
return helper.kindType(m_typeSystem.freshTypeVariable(false, Sort{{TypeClass{typeClass}}}));
return helper.kindType(m_typeSystem.freshTypeVariable(Sort{{TypeClass{typeClass}}}));
}
else
{
m_errorReporter.typeError(0000_error, _location, "Expected type class.");
return helper.kindType(m_typeSystem.freshTypeVariable(false, {}));
return helper.kindType(m_typeSystem.freshTypeVariable({}));
}
break;
}
@ -595,7 +595,7 @@ bool TypeInference::visit(Identifier const& _identifier)
case ExpressionContext::Type:
{
// TODO: register free type variable name!
identifierAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, {}));
identifierAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable({}));
return false;
}
case ExpressionContext::Sort:
@ -626,7 +626,7 @@ void TypeInference::endVisit(TupleExpression const& _tupleExpression)
else
{
m_errorReporter.typeError(0000_error, _expr->location(), "Expected type, but got " + m_env->typeToString(*componentAnnotation.type));
return m_typeSystem.freshTypeVariable(false, {});
return m_typeSystem.freshTypeVariable({});
}
case ExpressionContext::Sort:
return *componentAnnotation.type;
@ -643,7 +643,7 @@ void TypeInference::endVisit(TupleExpression const& _tupleExpression)
break;
case ExpressionContext::Sort:
{
Type type = helper.kindType(m_typeSystem.freshTypeVariable(false, {}));
Type type = helper.kindType(m_typeSystem.freshTypeVariable({}));
for (auto componentType: componentTypes)
unify(type, componentType, _tupleExpression.location());
expressionAnnotation.type = type;
@ -768,19 +768,19 @@ void TypeInference::endVisit(MemberAccess const& _memberAccess)
auto constructor = std::get<0>(helper.destTypeConstant(expressionType));
if (auto* typeMember = util::valueOrNullptr(annotation().members.at(constructor), _memberAccess.memberName()))
{
Type type = m_env->fresh(typeMember->type, true);
Type type = m_env->fresh(typeMember->type);
annotation(_memberAccess).type = type;
return;
}
else
{
m_errorReporter.typeError(0000_error, _memberAccess.memberLocation(), "Member not found.");
annotation(_memberAccess).type = m_typeSystem.freshTypeVariable(false, {});
annotation(_memberAccess).type = m_typeSystem.freshTypeVariable({});
return;
}
}
m_errorReporter.typeError(0000_error, _memberAccess.expression().location(), "Unsupported member access expression.");
annotation(_memberAccess).type = m_typeSystem.freshTypeVariable(false, {});
annotation(_memberAccess).type = m_typeSystem.freshTypeVariable({});
}
bool TypeInference::visit(MemberAccess const& _memberAccess)
@ -788,7 +788,7 @@ bool TypeInference::visit(MemberAccess const& _memberAccess)
if (m_expressionContext != ExpressionContext::Term)
{
m_errorReporter.typeError(0000_error, _memberAccess.location(), "Member access outside term context.");
annotation(_memberAccess).type = m_typeSystem.freshTypeVariable(false, {});
annotation(_memberAccess).type = m_typeSystem.freshTypeVariable({});
return false;
}
return true;
@ -815,7 +815,8 @@ bool TypeInference::visit(TypeDefinition const& _typeDefinition)
vector<Type> arguments;
if (_typeDefinition.arguments())
for (size_t i = 0; i < _typeDefinition.arguments()->parameters().size(); ++i)
arguments.emplace_back(m_typeSystem.freshTypeVariable(true, {}));
// TODO: GENERALIZE
arguments.emplace_back(m_typeSystem.freshTypeVariable({}));
Type type = m_typeSystem.type(TypeConstructor{&_typeDefinition}, arguments);
if (arguments.empty())
@ -866,12 +867,12 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
else
{
m_errorReporter.typeError(0000_error, arg->location(), "Expected type, but got " + m_env->typeToString(*argAnnotation.type));
argTypes.emplace_back(m_typeSystem.freshTypeVariable(false, {}));
argTypes.emplace_back(m_typeSystem.freshTypeVariable({}));
}
break;
case ExpressionContext::Sort:
m_errorReporter.typeError(0000_error, _functionCall.location(), "Function call in sort context.");
functionCallAnnotation.type = m_typeSystem.freshTypeVariable(false, {});
functionCallAnnotation.type = m_typeSystem.freshTypeVariable({});
break;
}
}
@ -881,7 +882,7 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
case ExpressionContext::Term:
{
Type argTuple = helper.tupleType(argTypes);
Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshTypeVariable(false, {}));
Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshTypeVariable({}));
unify(functionType, genericFunctionType, _functionCall.location());
functionCallAnnotation.type = m_env->resolve(std::get<1>(helper.destFunctionType(m_env->resolve(genericFunctionType))));
@ -890,7 +891,7 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
case ExpressionContext::Type:
{
Type argTuple = helper.kindType(helper.tupleType(argTypes));
Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshKindVariable(false, {}));
Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshKindVariable({}));
unify(functionType, genericFunctionType, _functionCall.location());
functionCallAnnotation.type = m_env->resolve(std::get<1>(helper.destFunctionType(m_env->resolve(genericFunctionType))));
@ -1075,6 +1076,6 @@ bool TypeInference::visit(Literal const& _literal)
m_errorReporter.typeError(0000_error, _literal.location(), "Only integers are supported.");
return false;
}
literalAnnotation.type = m_typeSystem.freshTypeVariable(false, Sort{{TypeClass{BuiltinClass::Integer}}});
literalAnnotation.type = m_typeSystem.freshTypeVariable(Sort{{TypeClass{BuiltinClass::Integer}}});
return false;
}

View File

@ -47,7 +47,7 @@ m_typeSystem(_analysis.typeSystem())
m_typeSystem.declareTypeConstructor(type, name, arity);
auto declareBuiltinClass = [&](BuiltinClass _class, auto _memberCreator, Sort _sort = {}) {
Type type = m_typeSystem.freshTypeVariable(false, std::move(_sort));
Type type = m_typeSystem.freshTypeVariable(std::move(_sort));
auto error = m_typeSystem.declareTypeClass(
TypeClass{_class},
type,

View File

@ -112,15 +112,13 @@ struct Arity
struct TypeVariable
{
size_t index() const { return m_index; }
bool generic() const { return m_generic; }
Sort const& sort() const { return m_sort; }
private:
friend class TypeSystem;
size_t m_index = 0;
Sort m_sort;
bool m_generic = false;
TypeVariable(size_t _index, Sort _sort, bool _generic):
m_index(_index), m_sort(std::move(_sort)), m_generic(_generic) {}
TypeVariable(size_t _index, Sort _sort):
m_index(_index), m_sort(std::move(_sort)) {}
};
}

View File

@ -160,7 +160,7 @@ std::string TypeEnvironment::typeToString(Type const& _type) const
},
[](TypeVariable const& _type) {
std::stringstream stream;
stream << (_type.generic() ? '?' : '\'') << "var" << _type.index();
stream << "'var" << _type.index();
switch (_type.sort().classes.size())
{
case 0:
@ -204,7 +204,7 @@ vector<TypeEnvironment::UnificationFailure> TypeEnvironment::unify(Type _a, Type
failures += instantiate(_right, _left);
else
{
Type newVar = m_typeSystem.freshVariable(_left.generic() && _right.generic(), _left.sort() + _right.sort());
Type newVar = m_typeSystem.freshVariable(_left.sort() + _right.sort());
failures += instantiate(_left, newVar);
failures += instantiate(_right, newVar);
}
@ -285,22 +285,22 @@ TypeSystem::TypeSystem()
};
}
experimental::Type TypeSystem::freshVariable(bool _generic, Sort _sort)
experimental::Type TypeSystem::freshVariable(Sort _sort)
{
uint64_t index = m_numTypeVariables++;
return TypeVariable(index, std::move(_sort), _generic);
return TypeVariable(index, std::move(_sort));
}
experimental::Type TypeSystem::freshTypeVariable(bool _generic, Sort _sort)
experimental::Type TypeSystem::freshTypeVariable(Sort _sort)
{
_sort.classes.emplace(TypeClass{BuiltinClass::Type});
return freshVariable(_generic, _sort);
return freshVariable(_sort);
}
experimental::Type TypeSystem::freshKindVariable(bool _generic, Sort _sort)
experimental::Type TypeSystem::freshKindVariable(Sort _sort)
{
_sort.classes.emplace(TypeClass{BuiltinClass::Kind});
return freshVariable(_generic, _sort);
return freshVariable(_sort);
}
vector<TypeEnvironment::UnificationFailure> TypeEnvironment::instantiate(TypeVariable _variable, Type _type)
@ -418,38 +418,33 @@ experimental::Type TypeSystem::type(TypeConstructor _constructor, std::vector<Ty
return TypeConstant{_constructor, _arguments};
}
experimental::Type TypeEnvironment::fresh(Type _type, bool _generalize)
experimental::Type TypeEnvironment::fresh(Type _type)
{
std::unordered_map<uint64_t, Type> mapping;
auto freshImpl = [&](Type _type, bool _generalize, auto _recurse) -> Type {
auto freshImpl = [&](Type _type, auto _recurse) -> Type {
return std::visit(util::GenericVisitor{
[&](TypeConstant const& _type) -> Type {
return TypeConstant{
_type.constructor,
_type.arguments | ranges::views::transform([&](Type _argType) {
return _recurse(_argType, _generalize, _recurse);
return _recurse(_argType, _recurse);
}) | ranges::to<vector<Type>>
};
},
[&](TypeVariable const& _var) -> Type {
if (_generalize || _var.generic())
if (auto* mapped = util::valueOrNullptr(mapping, _var.index()))
{
if (auto* mapped = util::valueOrNullptr(mapping, _var.index()))
{
auto* typeVariable = get_if<TypeVariable>(mapped);
solAssert(typeVariable);
// TODO: can there be a mismatch?
solAssert(typeVariable->sort() == _var.sort());
return *mapped;
}
return mapping[_var.index()] = m_typeSystem.freshTypeVariable(true, _var.sort());
auto* typeVariable = get_if<TypeVariable>(mapped);
solAssert(typeVariable);
// TODO: can there be a mismatch?
solAssert(typeVariable->sort() == _var.sort());
return *mapped;
}
else
return _type;
return mapping[_var.index()] = m_typeSystem.freshTypeVariable(_var.sort());
},
}, resolve(_type));
};
return freshImpl(_type, _generalize, freshImpl);
return freshImpl(_type, freshImpl);
}
std::optional<std::string> TypeSystem::instantiateClass(Type _instanceVariable, Arity _arity, map<string, Type> _functionTypes)

View File

@ -37,7 +37,7 @@ public:
TypeEnvironment clone() const;
Type resolve(Type _type) const;
Type resolveRecursive(Type _type) const;
Type fresh(Type _type, bool _generalize);
Type fresh(Type _type);
struct TypeMismatch { Type a; Type b; };
struct SortMismatch { Type type; Sort sort; };
using UnificationFailure = std::variant<TypeMismatch, SortMismatch>;
@ -99,13 +99,13 @@ public:
[[nodiscard]] std::optional<std::string> declareTypeClass(TypeClass _class, Type _typeVariable, std::map<std::string, Type> _functions);
[[nodiscard]] std::optional<std::string> instantiateClass(Type _instanceVariable, Arity _arity, std::map<std::string, Type> _functions);
Type freshTypeVariable(bool _generic, Sort _sort);
Type freshKindVariable(bool _generic, Sort _sort);
Type freshTypeVariable(Sort _sort);
Type freshKindVariable(Sort _sort);
TypeEnvironment const& env() const { return m_globalTypeEnvironment; }
TypeEnvironment& env() { return m_globalTypeEnvironment; }
Type freshVariable(bool _generic, Sort _sort);
Type freshVariable(Sort _sort);
private:
size_t m_numTypeVariables = 0;
std::map<TypeConstructor, TypeConstructorInfo> m_typeConstructors;