This commit is contained in:
Daniel Kirchner 2023-06-25 08:33:45 +02:00
parent fd1db21d64
commit 58b7344c5a
2 changed files with 134 additions and 186 deletions

View File

@ -70,17 +70,11 @@ bool TypeInference::visit(FunctionDefinition const& _functionDefinition)
if (_functionDefinition.returnParameterList())
_functionDefinition.returnParameterList()->accept(*this);
auto typeFromParameterList = [&](ParameterList const* _list) {
if (!_list)
return m_unitType;
auto& listAnnotation = annotation(*_list);
solAssert(listAnnotation.type);
return *listAnnotation.type;
};
auto getListType = [&](ParameterList const* _list) { return _list ? getType(*_list) : m_unitType; };
Type functionType = TypeSystemHelpers{m_typeSystem}.functionType(
typeFromParameterList(&_functionDefinition.parameterList()),
typeFromParameterList(_functionDefinition.returnParameterList().get())
getListType(&_functionDefinition.parameterList()),
getListType(_functionDefinition.returnParameterList().get())
);
m_currentFunctionType = functionType;
@ -100,36 +94,19 @@ void TypeInference::endVisit(Return const& _return)
solAssert(m_currentFunctionType);
if (_return.expression())
{
auto& returnExpressionAnnotation = annotation(*_return.expression());
solAssert(returnExpressionAnnotation.type);
Type returnExpressionType = getType(*_return.expression());
Type functionReturnType = get<1>(TypeSystemHelpers{m_typeSystem}.destFunctionType(*m_currentFunctionType));
unify(functionReturnType, *returnExpressionAnnotation.type, _return.location());
unify(functionReturnType, returnExpressionType, _return.location());
}
}
bool TypeInference::visit(ParameterList const&)
{
return true;
}
void TypeInference::endVisit(ParameterList const& _parameterList)
{
auto& listAnnotation = annotation(_parameterList);
solAssert(!listAnnotation.type);
std::vector<Type> argTypes;
for(auto arg: _parameterList.parameters())
{
auto& argAnnotation = annotation(*arg);
solAssert(argAnnotation.type);
argTypes.emplace_back(*argAnnotation.type);
}
listAnnotation.type = TypeSystemHelpers{m_typeSystem}.tupleType(argTypes);
}
bool TypeInference::visitNode(ASTNode const& _node)
{
m_errorReporter.fatalTypeError(0000_error, _node.location(), "Unsupported AST node during type inference.");
return false;
listAnnotation.type = TypeSystemHelpers{m_typeSystem}.tupleType(
_parameterList.parameters() | ranges::views::transform([&](auto _arg) { return getType(*_arg); }) | ranges::to<vector<Type>>
);
}
bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
@ -140,12 +117,10 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
return false;
m_typeSystem.declareTypeConstructor(&_typeClassDefinition, _typeClassDefinition.name(), 0);
typeClassAnnotation.type = TypeConstant{&_typeClassDefinition, {}};
auto& typeVariableAnnotation = annotation(_typeClassDefinition.typeVariable());
{
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type};
_typeClassDefinition.typeVariable().accept(*this);
}
solAssert(typeVariableAnnotation.type);
map<string, Type> functionTypes;
@ -174,8 +149,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
if (auto error = m_typeSystem.declareTypeClass(TypeClass{&_typeClassDefinition}, typeVar, std::move(functionTypes)))
m_errorReporter.fatalTypeError(0000_error, _typeClassDefinition.location(), *error);
solAssert(typeVariableAnnotation.type);
unify(*typeVariableAnnotation.type, m_typeSystem.freshTypeVariable(Sort{{TypeClass{&_typeClassDefinition}}}), _typeClassDefinition.location());
unify(getType(_typeClassDefinition.typeVariable()), m_typeSystem.freshTypeVariable({{{&_typeClassDefinition}}}), _typeClassDefinition.location());
for (auto instantiation: m_analysis.annotation<TypeRegistration>(_typeClassDefinition).instantiations | ranges::views::values)
// TODO: recursion-safety?
instantiation->accept(*this);
@ -183,36 +157,6 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
return false;
}
void TypeInference::unify(Type _a, Type _b, langutil::SourceLocation _location, TypeEnvironment* _env)
{
if (!_env)
_env = m_env;
for (auto failure: _env->unify(_a, _b))
{
TypeEnvironmentHelpers helper{*_env};
std::visit(util::GenericVisitor{
[&](TypeEnvironment::TypeMismatch _typeMismatch) {
m_errorReporter.typeError(
0000_error,
_location,
fmt::format(
"Cannot unify {} and {}.",
helper.typeToString(_typeMismatch.a),
helper.typeToString(_typeMismatch.b))
);
},
[&](TypeEnvironment::SortMismatch _sortMismatch) {
m_errorReporter.typeError(0000_error, _location, fmt::format(
"{} does not have sort {}",
helper.typeToString(_sortMismatch.type),
TypeSystemHelpers{m_typeSystem}.sortToString(_sortMismatch.sort)
));
}
}, failure);
}
}
bool TypeInference::visit(InlineAssembly const& _inlineAssembly)
{
// External references have already been resolved in a prior stage and stored in the annotation.
@ -239,9 +183,7 @@ bool TypeInference::visit(InlineAssembly const& _inlineAssembly)
solAssert(identifierInfo.suffix == "", "");
auto& declarationAnnotation = annotation(*declaration);
solAssert(declarationAnnotation.type);
unify(*declarationAnnotation.type, m_wordType, originLocationOf(_identifier));
unify(getType(*declaration), m_wordType, originLocationOf(_identifier));
identifierInfo.valueSize = 1;
return true;
};
@ -300,11 +242,10 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression)
{
auto argType = m_typeSystem.freshTypeVariable({});
auto resultType = m_typeSystem.freshTypeVariable({});
expressionAnnotation.type =
helper.typeFunctionType(
helper.tupleType({argType, resultType}),
m_typeSystem.type(BuiltinType::Function, {argType, resultType})
);
expressionAnnotation.type = helper.typeFunctionType(
helper.tupleType({argType, resultType}),
m_typeSystem.type(BuiltinType::Function, {argType, resultType})
);
break;
}
default:
@ -318,8 +259,6 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression)
bool TypeInference::visit(BinaryOperation const& _binaryOperation)
{
auto& operationAnnotation = annotation(_binaryOperation);
auto& leftAnnotation = annotation(_binaryOperation.leftExpression());
auto& rightAnnotation = annotation(_binaryOperation.rightExpression());
TypeSystemHelpers helper{m_typeSystem};
switch (m_expressionContext)
{
@ -330,11 +269,9 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
Type functionType = m_env->fresh(m_typeSystem.typeClassInfo(typeClass)->functions.at(functionName));
_binaryOperation.leftExpression().accept(*this);
solAssert(leftAnnotation.type);
_binaryOperation.rightExpression().accept(*this);
solAssert(rightAnnotation.type);
Type argTuple = helper.tupleType({*leftAnnotation.type, *rightAnnotation.type});
Type argTuple = helper.tupleType({getType(_binaryOperation.leftExpression()), getType(_binaryOperation.rightExpression())});
Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshTypeVariable({}));
unify(functionType, genericFunctionType, _binaryOperation.location());
@ -356,20 +293,15 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Sort};
_binaryOperation.rightExpression().accept(*this);
}
solAssert(leftAnnotation.type);
solAssert(rightAnnotation.type);
unify(*leftAnnotation.type, *rightAnnotation.type, _binaryOperation.location());
operationAnnotation.type = leftAnnotation.type;
Type leftType = getType(_binaryOperation.leftExpression());
unify(leftType, getType(_binaryOperation.rightExpression()), _binaryOperation.location());
operationAnnotation.type = leftType;
}
else if (_binaryOperation.getOperator() == Token::RightArrow)
{
_binaryOperation.leftExpression().accept(*this);
_binaryOperation.rightExpression().accept(*this);
solAssert(leftAnnotation.type);
solAssert(rightAnnotation.type);
Type leftType = *leftAnnotation.type;
Type rightType = *rightAnnotation.type;
operationAnnotation.type = helper.functionType(leftType, rightType);
operationAnnotation.type = helper.functionType(getType(_binaryOperation.leftExpression()), getType(_binaryOperation.rightExpression()));
}
else
{
@ -393,14 +325,9 @@ void TypeInference::endVisit(VariableDeclarationStatement const& _variableDeclar
m_errorReporter.typeError(0000_error, _variableDeclarationStatement.location(), "Multi variable declaration not supported.");
return;
}
auto& variableAnnotation = annotation(*_variableDeclarationStatement.declarations().front());
solAssert(variableAnnotation.type);
Type variableType = getType(*_variableDeclarationStatement.declarations().front());
if (_variableDeclarationStatement.initialValue())
{
auto& expressionAnnotation = annotation(*_variableDeclarationStatement.initialValue());
solAssert(expressionAnnotation.type);
unify(*variableAnnotation.type, *expressionAnnotation.type, _variableDeclarationStatement.location());
}
unify(variableType, getType(*_variableDeclarationStatement.initialValue()), _variableDeclarationStatement.location());
}
bool TypeInference::visit(VariableDeclaration const& _variableDeclaration)
@ -416,12 +343,9 @@ bool TypeInference::visit(VariableDeclaration const& _variableDeclaration)
{
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type};
_variableDeclaration.typeExpression()->accept(*this);
auto& typeExpressionAnnotation = annotation(*_variableDeclaration.typeExpression());
solAssert(typeExpressionAnnotation.type);
variableAnnotation.type = *typeExpressionAnnotation.type;
variableAnnotation.type = getType(*_variableDeclaration.typeExpression());
return false;
}
variableAnnotation.type = m_typeSystem.freshTypeVariable({});
return false;
case ExpressionContext::Type:
@ -430,23 +354,15 @@ bool TypeInference::visit(VariableDeclaration const& _variableDeclaration)
{
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Sort};
_variableDeclaration.typeExpression()->accept(*this);
auto& typeExpressionAnnotation = annotation(*_variableDeclaration.typeExpression());
solAssert(typeExpressionAnnotation.type);
unify(*variableAnnotation.type, *typeExpressionAnnotation.type, _variableDeclaration.typeExpression()->location());
unify(*variableAnnotation.type, getType(*_variableDeclaration.typeExpression()), _variableDeclaration.typeExpression()->location());
}
return false;
case ExpressionContext::Sort:
m_errorReporter.typeError(0000_error, _variableDeclaration.location(), "Variable declaration in sort context.");
variableAnnotation.type = m_typeSystem.freshTypeVariable({});
return false;
}
solAssert(false);
return false;
}
bool TypeInference::visit(Assignment const&)
{
return true;
util::unreachable();
}
void TypeInference::endVisit(Assignment const& _assignment)
@ -461,22 +377,9 @@ void TypeInference::endVisit(Assignment const& _assignment)
return;
}
auto& lhsAnnotation = annotation(_assignment.leftHandSide());
solAssert(lhsAnnotation.type);
auto& rhsAnnotation = annotation(_assignment.rightHandSide());
solAssert(rhsAnnotation.type);
unify(*lhsAnnotation.type, *rhsAnnotation.type, _assignment.location());
assignmentAnnotation.type = m_env->resolve(*lhsAnnotation.type);
}
TypeInference::Annotation& TypeInference::annotation(ASTNode const& _node)
{
return m_analysis.annotation<TypeInference>(_node);
}
TypeInference::GlobalAnnotation& TypeInference::annotation()
{
return m_analysis.annotation<TypeInference>();
Type leftType = getType(_assignment.leftHandSide());
unify(leftType, getType(_assignment.rightHandSide()), _assignment.location());
assignmentAnnotation.type = m_env->resolve(leftType);
}
experimental::Type TypeInference::handleIdentifierByReferencedDeclaration(langutil::SourceLocation _location, Declaration const& _declaration)
@ -503,10 +406,10 @@ experimental::Type TypeInference::handleIdentifierByReferencedDeclaration(langut
solAssert(declarationAnnotation.type);
if (dynamic_cast<FunctionDefinition const*>(&_declaration))
return m_env->fresh(*declarationAnnotation.type);
else if (dynamic_cast<VariableDeclaration const*>(&_declaration))
if (dynamic_cast<VariableDeclaration const*>(&_declaration))
return *declarationAnnotation.type;
else if (dynamic_cast<FunctionDefinition const*>(&_declaration))
return m_env->fresh(*declarationAnnotation.type);
else if (dynamic_cast<TypeClassDefinition const*>(&_declaration))
return m_env->fresh(*declarationAnnotation.type);
else if (dynamic_cast<TypeDefinition const*>(&_declaration))
@ -535,15 +438,9 @@ experimental::Type TypeInference::handleIdentifierByReferencedDeclaration(langut
solAssert(declarationAnnotation.type);
if (dynamic_cast<VariableDeclaration const*>(&_declaration))
{
// TODO: helper.destKindType(*declarationAnnotation.type);
return *declarationAnnotation.type;
}
else if (dynamic_cast<TypeDefinition const*>(&_declaration))
{
// TODO: helper.destKindType(*declarationAnnotation.type);
return m_env->fresh(*declarationAnnotation.type);
}
else
solAssert(false);
break;
@ -585,11 +482,9 @@ bool TypeInference::visit(Identifier const& _identifier)
solAssert(false);
break;
case ExpressionContext::Type:
{
// TODO: register free type variable name!
identifierAnnotation.type = m_typeSystem.freshTypeVariable({});
return false;
}
case ExpressionContext::Sort:
// TODO: error handling
solAssert(false);
@ -644,6 +539,7 @@ bool TypeInference::visit(IdentifierPath const& _identifierPath)
bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation)
{
// TODO: Deal with dependencies between type class instantiations.
auto& instantiationAnnotation = annotation(_typeClassInstantiation);
if (instantiationAnnotation.type)
return false;
@ -715,9 +611,8 @@ bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation)
auto const* functionDefinition = dynamic_cast<FunctionDefinition const*>(subNode.get());
solAssert(functionDefinition);
subNode->accept(*this);
solAssert(annotation(*functionDefinition).type);
if (!functionTypes.emplace(functionDefinition->name(), *annotation(*functionDefinition).type).second)
m_errorReporter.typeError(0000_error, subNode->location(), "Multiple definitions of function " + functionDefinition->name() + " during type class instantiation.");
if (!functionTypes.emplace(functionDefinition->name(), getType(*functionDefinition)).second)
m_errorReporter.typeError(0000_error, subNode->location(), "Duplicate definition of function " + functionDefinition->name() + " during type class instantiation.");
}
if (auto error = m_typeSystem.instantiateClass(type, arity, std::move(functionTypes)))
@ -726,34 +621,6 @@ bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation)
return false;
}
void TypeInference::endVisit(MemberAccess const& _memberAccess)
{
auto &memberAccessAnnotation = annotation(_memberAccess);
solAssert(!memberAccessAnnotation.type);
auto& expressionAnnotation = annotation(_memberAccess.expression());
solAssert(expressionAnnotation.type);
TypeSystemHelpers helper{m_typeSystem};
if (helper.isTypeConstant(*expressionAnnotation.type))
{
Type expressionType = *expressionAnnotation.type;
auto constructor = std::get<0>(helper.destTypeConstant(expressionType));
if (auto* typeMember = util::valueOrNullptr(annotation().members.at(constructor), _memberAccess.memberName()))
{
Type type = m_env->fresh(typeMember->type);
annotation(_memberAccess).type = type;
return;
}
else
{
m_errorReporter.typeError(0000_error, _memberAccess.memberLocation(), "Member not found.");
annotation(_memberAccess).type = m_typeSystem.freshTypeVariable({});
return;
}
}
m_errorReporter.typeError(0000_error, _memberAccess.expression().location(), "Unsupported member access expression.");
annotation(_memberAccess).type = m_typeSystem.freshTypeVariable({});
}
bool TypeInference::visit(MemberAccess const& _memberAccess)
{
if (m_expressionContext != ExpressionContext::Term)
@ -765,6 +632,33 @@ bool TypeInference::visit(MemberAccess const& _memberAccess)
return true;
}
void TypeInference::endVisit(MemberAccess const& _memberAccess)
{
auto &memberAccessAnnotation = annotation(_memberAccess);
solAssert(!memberAccessAnnotation.type);
Type expressionType = getType(_memberAccess.expression());
TypeSystemHelpers helper{m_typeSystem};
if (helper.isTypeConstant(expressionType))
{
auto constructor = std::get<0>(helper.destTypeConstant(expressionType));
if (auto* typeMember = util::valueOrNullptr(annotation().members.at(constructor), _memberAccess.memberName()))
{
Type type = m_env->fresh(typeMember->type);
annotation(_memberAccess).type = type;
}
else
{
m_errorReporter.typeError(0000_error, _memberAccess.memberLocation(), "Member not found.");
annotation(_memberAccess).type = m_typeSystem.freshTypeVariable({});
}
}
else
{
m_errorReporter.typeError(0000_error, _memberAccess.expression().location(), "Unsupported member access expression.");
annotation(_memberAccess).type = m_typeSystem.freshTypeVariable({});
}
}
bool TypeInference::visit(TypeDefinition const& _typeDefinition)
{
TypeSystemHelpers helper{m_typeSystem};
@ -772,6 +666,9 @@ bool TypeInference::visit(TypeDefinition const& _typeDefinition)
if (typeDefinitionAnnotation.type)
return false;
if (_typeDefinition.arguments())
_typeDefinition.arguments()->accept(*this);
std::optional<Type> underlyingType;
if (_typeDefinition.typeExpression())
{
@ -792,14 +689,13 @@ bool TypeInference::visit(TypeDefinition const& _typeDefinition)
else
typeDefinitionAnnotation.type = helper.typeFunctionType(helper.tupleType(arguments), type);
auto& typeMembers = annotation().members[&_typeDefinition];
auto [members, newlyInserted] = annotation().members.emplace(&_typeDefinition, map<string, TypeMember>{});
solAssert(newlyInserted);
if (underlyingType)
{
typeMembers.emplace("abs", TypeMember{helper.functionType(*underlyingType, type)});
typeMembers.emplace("rep", TypeMember{helper.functionType(type, *underlyingType)});
members->second.emplace("abs", TypeMember{helper.functionType(*underlyingType, type)});
members->second.emplace("rep", TypeMember{helper.functionType(type, *underlyingType)});
}
return false;
}
@ -809,22 +705,17 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
auto& functionCallAnnotation = annotation(_functionCall);
solAssert(!functionCallAnnotation.type);
auto& expressionAnnotation = annotation(_functionCall.expression());
solAssert(expressionAnnotation.type);
Type functionType = *expressionAnnotation.type;
Type functionType = getType(_functionCall.expression());
TypeSystemHelpers helper{m_typeSystem};
std::vector<Type> argTypes;
for(auto arg: _functionCall.arguments())
{
auto& argAnnotation = annotation(*arg);
solAssert(argAnnotation.type);
switch(m_expressionContext)
{
case ExpressionContext::Term:
case ExpressionContext::Type:
argTypes.emplace_back(*argAnnotation.type);
argTypes.emplace_back(getType(*arg));
break;
case ExpressionContext::Sort:
m_errorReporter.typeError(0000_error, _functionCall.location(), "Function call in sort context.");
@ -837,10 +728,8 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
{
case ExpressionContext::Term:
{
Type argTuple = helper.tupleType(argTypes);
Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshTypeVariable({}));
Type genericFunctionType = helper.functionType(helper.tupleType(argTypes), m_typeSystem.freshTypeVariable({}));
unify(functionType, genericFunctionType, _functionCall.location());
functionCallAnnotation.type = m_env->resolve(std::get<1>(helper.destFunctionType(m_env->resolve(genericFunctionType))));
break;
}
@ -849,7 +738,6 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
Type argTuple = helper.tupleType(argTypes);
Type genericFunctionType = helper.typeFunctionType(argTuple, m_typeSystem.freshKindVariable({}));
unify(functionType, genericFunctionType, _functionCall.location());
functionCallAnnotation.type = m_env->resolve(std::get<1>(helper.destTypeFunctionType(m_env->resolve(genericFunctionType))));
break;
}
@ -1035,3 +923,60 @@ bool TypeInference::visit(Literal const& _literal)
literalAnnotation.type = m_typeSystem.freshTypeVariable(Sort{{TypeClass{BuiltinClass::Integer}}});
return false;
}
void TypeInference::unify(Type _a, Type _b, langutil::SourceLocation _location, TypeEnvironment* _env)
{
if (!_env)
_env = m_env;
for (auto failure: _env->unify(_a, _b))
{
TypeEnvironmentHelpers helper{*_env};
std::visit(util::GenericVisitor{
[&](TypeEnvironment::TypeMismatch _typeMismatch) {
m_errorReporter.typeError(
0000_error,
_location,
fmt::format(
"Cannot unify {} and {}.",
helper.typeToString(_typeMismatch.a),
helper.typeToString(_typeMismatch.b))
);
},
[&](TypeEnvironment::SortMismatch _sortMismatch) {
m_errorReporter.typeError(0000_error, _location, fmt::format(
"{} does not have sort {}",
helper.typeToString(_sortMismatch.type),
TypeSystemHelpers{m_typeSystem}.sortToString(_sortMismatch.sort)
));
}
}, failure);
}
}
experimental::Type TypeInference::getType(ASTNode const& _node) const
{
auto result = annotation(_node).type;
solAssert(result);
return *result;
}
bool TypeInference::visitNode(ASTNode const& _node)
{
m_errorReporter.fatalTypeError(0000_error, _node.location(), "Unsupported AST node during type inference.");
return false;
}
TypeInference::Annotation& TypeInference::annotation(ASTNode const& _node)
{
return m_analysis.annotation<TypeInference>(_node);
}
TypeInference::Annotation const& TypeInference::annotation(ASTNode const& _node) const
{
return m_analysis.annotation<TypeInference>(_node);
}
TypeInference::GlobalAnnotation& TypeInference::annotation()
{
return m_analysis.annotation<TypeInference>();
}

View File

@ -53,7 +53,7 @@ public:
bool visit(VariableDeclaration const& _variableDeclaration) override;
bool visit(FunctionDefinition const& _functionDefinition) override;
bool visit(ParameterList const& _parameterList) override;
bool visit(ParameterList const&) override { return true; }
void endVisit(ParameterList const& _parameterList) override;
bool visit(SourceUnit const&) override { return true; }
bool visit(ContractDefinition const&) override { return true; }
@ -61,8 +61,8 @@ public:
bool visit(PragmaDirective const&) override { return false; }
bool visit(ExpressionStatement const&) override { return true; }
bool visit(Assignment const&) override;
void endVisit(Assignment const&) override;
bool visit(Assignment const&) override { return true; }
void endVisit(Assignment const& _assignment) override;
bool visit(Identifier const&) override;
bool visit(IdentifierPath const&) override;
bool visit(FunctionCall const& _functionCall) override;
@ -97,7 +97,10 @@ private:
Type m_boolType;
std::optional<Type> m_currentFunctionType;
Type getType(ASTNode const& _node) const;
Annotation& annotation(ASTNode const& _node);
Annotation const& annotation(ASTNode const& _node) const;
GlobalAnnotation& annotation();
void unify(Type _a, Type _b, langutil::SourceLocation _location = {}, TypeEnvironment* _env = nullptr);