diff --git a/libsolidity/analysis/experimental/TypeInference.cpp b/libsolidity/analysis/experimental/TypeInference.cpp index ee76fb82f..5dc53fe7d 100644 --- a/libsolidity/analysis/experimental/TypeInference.cpp +++ b/libsolidity/analysis/experimental/TypeInference.cpp @@ -70,17 +70,11 @@ bool TypeInference::visit(FunctionDefinition const& _functionDefinition) if (_functionDefinition.returnParameterList()) _functionDefinition.returnParameterList()->accept(*this); - auto typeFromParameterList = [&](ParameterList const* _list) { - if (!_list) - return m_unitType; - auto& listAnnotation = annotation(*_list); - solAssert(listAnnotation.type); - return *listAnnotation.type; - }; + auto getListType = [&](ParameterList const* _list) { return _list ? getType(*_list) : m_unitType; }; Type functionType = TypeSystemHelpers{m_typeSystem}.functionType( - typeFromParameterList(&_functionDefinition.parameterList()), - typeFromParameterList(_functionDefinition.returnParameterList().get()) + getListType(&_functionDefinition.parameterList()), + getListType(_functionDefinition.returnParameterList().get()) ); m_currentFunctionType = functionType; @@ -100,36 +94,19 @@ void TypeInference::endVisit(Return const& _return) solAssert(m_currentFunctionType); if (_return.expression()) { - auto& returnExpressionAnnotation = annotation(*_return.expression()); - solAssert(returnExpressionAnnotation.type); + Type returnExpressionType = getType(*_return.expression()); Type functionReturnType = get<1>(TypeSystemHelpers{m_typeSystem}.destFunctionType(*m_currentFunctionType)); - unify(functionReturnType, *returnExpressionAnnotation.type, _return.location()); + unify(functionReturnType, returnExpressionType, _return.location()); } } -bool TypeInference::visit(ParameterList const&) -{ - return true; -} - void TypeInference::endVisit(ParameterList const& _parameterList) { auto& listAnnotation = annotation(_parameterList); solAssert(!listAnnotation.type); - std::vector argTypes; - for(auto arg: _parameterList.parameters()) - { - auto& argAnnotation = annotation(*arg); - solAssert(argAnnotation.type); - argTypes.emplace_back(*argAnnotation.type); - } - listAnnotation.type = TypeSystemHelpers{m_typeSystem}.tupleType(argTypes); -} - -bool TypeInference::visitNode(ASTNode const& _node) -{ - m_errorReporter.fatalTypeError(0000_error, _node.location(), "Unsupported AST node during type inference."); - return false; + listAnnotation.type = TypeSystemHelpers{m_typeSystem}.tupleType( + _parameterList.parameters() | ranges::views::transform([&](auto _arg) { return getType(*_arg); }) | ranges::to> + ); } bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition) @@ -140,12 +117,10 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition) return false; m_typeSystem.declareTypeConstructor(&_typeClassDefinition, _typeClassDefinition.name(), 0); typeClassAnnotation.type = TypeConstant{&_typeClassDefinition, {}}; - auto& typeVariableAnnotation = annotation(_typeClassDefinition.typeVariable()); { ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type}; _typeClassDefinition.typeVariable().accept(*this); } - solAssert(typeVariableAnnotation.type); map functionTypes; @@ -174,8 +149,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition) if (auto error = m_typeSystem.declareTypeClass(TypeClass{&_typeClassDefinition}, typeVar, std::move(functionTypes))) m_errorReporter.fatalTypeError(0000_error, _typeClassDefinition.location(), *error); - solAssert(typeVariableAnnotation.type); - unify(*typeVariableAnnotation.type, m_typeSystem.freshTypeVariable(Sort{{TypeClass{&_typeClassDefinition}}}), _typeClassDefinition.location()); + unify(getType(_typeClassDefinition.typeVariable()), m_typeSystem.freshTypeVariable({{{&_typeClassDefinition}}}), _typeClassDefinition.location()); for (auto instantiation: m_analysis.annotation(_typeClassDefinition).instantiations | ranges::views::values) // TODO: recursion-safety? instantiation->accept(*this); @@ -183,36 +157,6 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition) return false; } -void TypeInference::unify(Type _a, Type _b, langutil::SourceLocation _location, TypeEnvironment* _env) -{ - if (!_env) - _env = m_env; - for (auto failure: _env->unify(_a, _b)) - { - TypeEnvironmentHelpers helper{*_env}; - std::visit(util::GenericVisitor{ - [&](TypeEnvironment::TypeMismatch _typeMismatch) { - m_errorReporter.typeError( - 0000_error, - _location, - fmt::format( - "Cannot unify {} and {}.", - helper.typeToString(_typeMismatch.a), - helper.typeToString(_typeMismatch.b)) - ); - }, - [&](TypeEnvironment::SortMismatch _sortMismatch) { - m_errorReporter.typeError(0000_error, _location, fmt::format( - "{} does not have sort {}", - helper.typeToString(_sortMismatch.type), - TypeSystemHelpers{m_typeSystem}.sortToString(_sortMismatch.sort) - )); - } - }, failure); - - } - -} bool TypeInference::visit(InlineAssembly const& _inlineAssembly) { // External references have already been resolved in a prior stage and stored in the annotation. @@ -239,9 +183,7 @@ bool TypeInference::visit(InlineAssembly const& _inlineAssembly) solAssert(identifierInfo.suffix == "", ""); - auto& declarationAnnotation = annotation(*declaration); - solAssert(declarationAnnotation.type); - unify(*declarationAnnotation.type, m_wordType, originLocationOf(_identifier)); + unify(getType(*declaration), m_wordType, originLocationOf(_identifier)); identifierInfo.valueSize = 1; return true; }; @@ -300,11 +242,10 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression) { auto argType = m_typeSystem.freshTypeVariable({}); auto resultType = m_typeSystem.freshTypeVariable({}); - expressionAnnotation.type = - helper.typeFunctionType( - helper.tupleType({argType, resultType}), - m_typeSystem.type(BuiltinType::Function, {argType, resultType}) - ); + expressionAnnotation.type = helper.typeFunctionType( + helper.tupleType({argType, resultType}), + m_typeSystem.type(BuiltinType::Function, {argType, resultType}) + ); break; } default: @@ -318,8 +259,6 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression) bool TypeInference::visit(BinaryOperation const& _binaryOperation) { auto& operationAnnotation = annotation(_binaryOperation); - auto& leftAnnotation = annotation(_binaryOperation.leftExpression()); - auto& rightAnnotation = annotation(_binaryOperation.rightExpression()); TypeSystemHelpers helper{m_typeSystem}; switch (m_expressionContext) { @@ -330,11 +269,9 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation) Type functionType = m_env->fresh(m_typeSystem.typeClassInfo(typeClass)->functions.at(functionName)); _binaryOperation.leftExpression().accept(*this); - solAssert(leftAnnotation.type); _binaryOperation.rightExpression().accept(*this); - solAssert(rightAnnotation.type); - Type argTuple = helper.tupleType({*leftAnnotation.type, *rightAnnotation.type}); + Type argTuple = helper.tupleType({getType(_binaryOperation.leftExpression()), getType(_binaryOperation.rightExpression())}); Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshTypeVariable({})); unify(functionType, genericFunctionType, _binaryOperation.location()); @@ -356,20 +293,15 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation) ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Sort}; _binaryOperation.rightExpression().accept(*this); } - solAssert(leftAnnotation.type); - solAssert(rightAnnotation.type); - unify(*leftAnnotation.type, *rightAnnotation.type, _binaryOperation.location()); - operationAnnotation.type = leftAnnotation.type; + Type leftType = getType(_binaryOperation.leftExpression()); + unify(leftType, getType(_binaryOperation.rightExpression()), _binaryOperation.location()); + operationAnnotation.type = leftType; } else if (_binaryOperation.getOperator() == Token::RightArrow) { _binaryOperation.leftExpression().accept(*this); _binaryOperation.rightExpression().accept(*this); - solAssert(leftAnnotation.type); - solAssert(rightAnnotation.type); - Type leftType = *leftAnnotation.type; - Type rightType = *rightAnnotation.type; - operationAnnotation.type = helper.functionType(leftType, rightType); + operationAnnotation.type = helper.functionType(getType(_binaryOperation.leftExpression()), getType(_binaryOperation.rightExpression())); } else { @@ -393,14 +325,9 @@ void TypeInference::endVisit(VariableDeclarationStatement const& _variableDeclar m_errorReporter.typeError(0000_error, _variableDeclarationStatement.location(), "Multi variable declaration not supported."); return; } - auto& variableAnnotation = annotation(*_variableDeclarationStatement.declarations().front()); - solAssert(variableAnnotation.type); + Type variableType = getType(*_variableDeclarationStatement.declarations().front()); if (_variableDeclarationStatement.initialValue()) - { - auto& expressionAnnotation = annotation(*_variableDeclarationStatement.initialValue()); - solAssert(expressionAnnotation.type); - unify(*variableAnnotation.type, *expressionAnnotation.type, _variableDeclarationStatement.location()); - } + unify(variableType, getType(*_variableDeclarationStatement.initialValue()), _variableDeclarationStatement.location()); } bool TypeInference::visit(VariableDeclaration const& _variableDeclaration) @@ -416,12 +343,9 @@ bool TypeInference::visit(VariableDeclaration const& _variableDeclaration) { ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type}; _variableDeclaration.typeExpression()->accept(*this); - auto& typeExpressionAnnotation = annotation(*_variableDeclaration.typeExpression()); - solAssert(typeExpressionAnnotation.type); - variableAnnotation.type = *typeExpressionAnnotation.type; + variableAnnotation.type = getType(*_variableDeclaration.typeExpression()); return false; } - variableAnnotation.type = m_typeSystem.freshTypeVariable({}); return false; case ExpressionContext::Type: @@ -430,23 +354,15 @@ bool TypeInference::visit(VariableDeclaration const& _variableDeclaration) { ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Sort}; _variableDeclaration.typeExpression()->accept(*this); - auto& typeExpressionAnnotation = annotation(*_variableDeclaration.typeExpression()); - solAssert(typeExpressionAnnotation.type); - - unify(*variableAnnotation.type, *typeExpressionAnnotation.type, _variableDeclaration.typeExpression()->location()); + unify(*variableAnnotation.type, getType(*_variableDeclaration.typeExpression()), _variableDeclaration.typeExpression()->location()); } return false; case ExpressionContext::Sort: m_errorReporter.typeError(0000_error, _variableDeclaration.location(), "Variable declaration in sort context."); + variableAnnotation.type = m_typeSystem.freshTypeVariable({}); return false; } - solAssert(false); - return false; -} - -bool TypeInference::visit(Assignment const&) -{ - return true; + util::unreachable(); } void TypeInference::endVisit(Assignment const& _assignment) @@ -461,22 +377,9 @@ void TypeInference::endVisit(Assignment const& _assignment) return; } - auto& lhsAnnotation = annotation(_assignment.leftHandSide()); - solAssert(lhsAnnotation.type); - auto& rhsAnnotation = annotation(_assignment.rightHandSide()); - solAssert(rhsAnnotation.type); - unify(*lhsAnnotation.type, *rhsAnnotation.type, _assignment.location()); - assignmentAnnotation.type = m_env->resolve(*lhsAnnotation.type); -} - -TypeInference::Annotation& TypeInference::annotation(ASTNode const& _node) -{ - return m_analysis.annotation(_node); -} - -TypeInference::GlobalAnnotation& TypeInference::annotation() -{ - return m_analysis.annotation(); + Type leftType = getType(_assignment.leftHandSide()); + unify(leftType, getType(_assignment.rightHandSide()), _assignment.location()); + assignmentAnnotation.type = m_env->resolve(leftType); } experimental::Type TypeInference::handleIdentifierByReferencedDeclaration(langutil::SourceLocation _location, Declaration const& _declaration) @@ -503,10 +406,10 @@ experimental::Type TypeInference::handleIdentifierByReferencedDeclaration(langut solAssert(declarationAnnotation.type); - if (dynamic_cast(&_declaration)) - return m_env->fresh(*declarationAnnotation.type); - else if (dynamic_cast(&_declaration)) + if (dynamic_cast(&_declaration)) return *declarationAnnotation.type; + else if (dynamic_cast(&_declaration)) + return m_env->fresh(*declarationAnnotation.type); else if (dynamic_cast(&_declaration)) return m_env->fresh(*declarationAnnotation.type); else if (dynamic_cast(&_declaration)) @@ -535,15 +438,9 @@ experimental::Type TypeInference::handleIdentifierByReferencedDeclaration(langut solAssert(declarationAnnotation.type); if (dynamic_cast(&_declaration)) - { - // TODO: helper.destKindType(*declarationAnnotation.type); return *declarationAnnotation.type; - } else if (dynamic_cast(&_declaration)) - { - // TODO: helper.destKindType(*declarationAnnotation.type); return m_env->fresh(*declarationAnnotation.type); - } else solAssert(false); break; @@ -585,11 +482,9 @@ bool TypeInference::visit(Identifier const& _identifier) solAssert(false); break; case ExpressionContext::Type: - { // TODO: register free type variable name! identifierAnnotation.type = m_typeSystem.freshTypeVariable({}); return false; - } case ExpressionContext::Sort: // TODO: error handling solAssert(false); @@ -644,6 +539,7 @@ bool TypeInference::visit(IdentifierPath const& _identifierPath) bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation) { + // TODO: Deal with dependencies between type class instantiations. auto& instantiationAnnotation = annotation(_typeClassInstantiation); if (instantiationAnnotation.type) return false; @@ -715,9 +611,8 @@ bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation) auto const* functionDefinition = dynamic_cast(subNode.get()); solAssert(functionDefinition); subNode->accept(*this); - solAssert(annotation(*functionDefinition).type); - if (!functionTypes.emplace(functionDefinition->name(), *annotation(*functionDefinition).type).second) - m_errorReporter.typeError(0000_error, subNode->location(), "Multiple definitions of function " + functionDefinition->name() + " during type class instantiation."); + if (!functionTypes.emplace(functionDefinition->name(), getType(*functionDefinition)).second) + m_errorReporter.typeError(0000_error, subNode->location(), "Duplicate definition of function " + functionDefinition->name() + " during type class instantiation."); } if (auto error = m_typeSystem.instantiateClass(type, arity, std::move(functionTypes))) @@ -726,34 +621,6 @@ bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation) return false; } -void TypeInference::endVisit(MemberAccess const& _memberAccess) -{ - auto &memberAccessAnnotation = annotation(_memberAccess); - solAssert(!memberAccessAnnotation.type); - auto& expressionAnnotation = annotation(_memberAccess.expression()); - solAssert(expressionAnnotation.type); - TypeSystemHelpers helper{m_typeSystem}; - if (helper.isTypeConstant(*expressionAnnotation.type)) - { - Type expressionType = *expressionAnnotation.type; - auto constructor = std::get<0>(helper.destTypeConstant(expressionType)); - if (auto* typeMember = util::valueOrNullptr(annotation().members.at(constructor), _memberAccess.memberName())) - { - Type type = m_env->fresh(typeMember->type); - annotation(_memberAccess).type = type; - return; - } - else - { - m_errorReporter.typeError(0000_error, _memberAccess.memberLocation(), "Member not found."); - annotation(_memberAccess).type = m_typeSystem.freshTypeVariable({}); - return; - } - } - m_errorReporter.typeError(0000_error, _memberAccess.expression().location(), "Unsupported member access expression."); - annotation(_memberAccess).type = m_typeSystem.freshTypeVariable({}); -} - bool TypeInference::visit(MemberAccess const& _memberAccess) { if (m_expressionContext != ExpressionContext::Term) @@ -765,6 +632,33 @@ bool TypeInference::visit(MemberAccess const& _memberAccess) return true; } +void TypeInference::endVisit(MemberAccess const& _memberAccess) +{ + auto &memberAccessAnnotation = annotation(_memberAccess); + solAssert(!memberAccessAnnotation.type); + Type expressionType = getType(_memberAccess.expression()); + TypeSystemHelpers helper{m_typeSystem}; + if (helper.isTypeConstant(expressionType)) + { + auto constructor = std::get<0>(helper.destTypeConstant(expressionType)); + if (auto* typeMember = util::valueOrNullptr(annotation().members.at(constructor), _memberAccess.memberName())) + { + Type type = m_env->fresh(typeMember->type); + annotation(_memberAccess).type = type; + } + else + { + m_errorReporter.typeError(0000_error, _memberAccess.memberLocation(), "Member not found."); + annotation(_memberAccess).type = m_typeSystem.freshTypeVariable({}); + } + } + else + { + m_errorReporter.typeError(0000_error, _memberAccess.expression().location(), "Unsupported member access expression."); + annotation(_memberAccess).type = m_typeSystem.freshTypeVariable({}); + } +} + bool TypeInference::visit(TypeDefinition const& _typeDefinition) { TypeSystemHelpers helper{m_typeSystem}; @@ -772,6 +666,9 @@ bool TypeInference::visit(TypeDefinition const& _typeDefinition) if (typeDefinitionAnnotation.type) return false; + if (_typeDefinition.arguments()) + _typeDefinition.arguments()->accept(*this); + std::optional underlyingType; if (_typeDefinition.typeExpression()) { @@ -792,14 +689,13 @@ bool TypeInference::visit(TypeDefinition const& _typeDefinition) else typeDefinitionAnnotation.type = helper.typeFunctionType(helper.tupleType(arguments), type); - auto& typeMembers = annotation().members[&_typeDefinition]; - + auto [members, newlyInserted] = annotation().members.emplace(&_typeDefinition, map{}); + solAssert(newlyInserted); if (underlyingType) { - typeMembers.emplace("abs", TypeMember{helper.functionType(*underlyingType, type)}); - typeMembers.emplace("rep", TypeMember{helper.functionType(type, *underlyingType)}); + members->second.emplace("abs", TypeMember{helper.functionType(*underlyingType, type)}); + members->second.emplace("rep", TypeMember{helper.functionType(type, *underlyingType)}); } - return false; } @@ -809,22 +705,17 @@ void TypeInference::endVisit(FunctionCall const& _functionCall) auto& functionCallAnnotation = annotation(_functionCall); solAssert(!functionCallAnnotation.type); - auto& expressionAnnotation = annotation(_functionCall.expression()); - solAssert(expressionAnnotation.type); - - Type functionType = *expressionAnnotation.type; + Type functionType = getType(_functionCall.expression()); TypeSystemHelpers helper{m_typeSystem}; std::vector argTypes; for(auto arg: _functionCall.arguments()) { - auto& argAnnotation = annotation(*arg); - solAssert(argAnnotation.type); switch(m_expressionContext) { case ExpressionContext::Term: case ExpressionContext::Type: - argTypes.emplace_back(*argAnnotation.type); + argTypes.emplace_back(getType(*arg)); break; case ExpressionContext::Sort: m_errorReporter.typeError(0000_error, _functionCall.location(), "Function call in sort context."); @@ -837,10 +728,8 @@ void TypeInference::endVisit(FunctionCall const& _functionCall) { case ExpressionContext::Term: { - Type argTuple = helper.tupleType(argTypes); - Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshTypeVariable({})); + Type genericFunctionType = helper.functionType(helper.tupleType(argTypes), m_typeSystem.freshTypeVariable({})); unify(functionType, genericFunctionType, _functionCall.location()); - functionCallAnnotation.type = m_env->resolve(std::get<1>(helper.destFunctionType(m_env->resolve(genericFunctionType)))); break; } @@ -849,7 +738,6 @@ void TypeInference::endVisit(FunctionCall const& _functionCall) Type argTuple = helper.tupleType(argTypes); Type genericFunctionType = helper.typeFunctionType(argTuple, m_typeSystem.freshKindVariable({})); unify(functionType, genericFunctionType, _functionCall.location()); - functionCallAnnotation.type = m_env->resolve(std::get<1>(helper.destTypeFunctionType(m_env->resolve(genericFunctionType)))); break; } @@ -1035,3 +923,60 @@ bool TypeInference::visit(Literal const& _literal) literalAnnotation.type = m_typeSystem.freshTypeVariable(Sort{{TypeClass{BuiltinClass::Integer}}}); return false; } + +void TypeInference::unify(Type _a, Type _b, langutil::SourceLocation _location, TypeEnvironment* _env) +{ + if (!_env) + _env = m_env; + for (auto failure: _env->unify(_a, _b)) + { + TypeEnvironmentHelpers helper{*_env}; + std::visit(util::GenericVisitor{ + [&](TypeEnvironment::TypeMismatch _typeMismatch) { + m_errorReporter.typeError( + 0000_error, + _location, + fmt::format( + "Cannot unify {} and {}.", + helper.typeToString(_typeMismatch.a), + helper.typeToString(_typeMismatch.b)) + ); + }, + [&](TypeEnvironment::SortMismatch _sortMismatch) { + m_errorReporter.typeError(0000_error, _location, fmt::format( + "{} does not have sort {}", + helper.typeToString(_sortMismatch.type), + TypeSystemHelpers{m_typeSystem}.sortToString(_sortMismatch.sort) + )); + } + }, failure); + } +} + +experimental::Type TypeInference::getType(ASTNode const& _node) const +{ + auto result = annotation(_node).type; + solAssert(result); + return *result; +} + +bool TypeInference::visitNode(ASTNode const& _node) +{ + m_errorReporter.fatalTypeError(0000_error, _node.location(), "Unsupported AST node during type inference."); + return false; +} + +TypeInference::Annotation& TypeInference::annotation(ASTNode const& _node) +{ + return m_analysis.annotation(_node); +} + +TypeInference::Annotation const& TypeInference::annotation(ASTNode const& _node) const +{ + return m_analysis.annotation(_node); +} + +TypeInference::GlobalAnnotation& TypeInference::annotation() +{ + return m_analysis.annotation(); +} diff --git a/libsolidity/analysis/experimental/TypeInference.h b/libsolidity/analysis/experimental/TypeInference.h index afbea93e9..b4024a005 100644 --- a/libsolidity/analysis/experimental/TypeInference.h +++ b/libsolidity/analysis/experimental/TypeInference.h @@ -53,7 +53,7 @@ public: bool visit(VariableDeclaration const& _variableDeclaration) override; bool visit(FunctionDefinition const& _functionDefinition) override; - bool visit(ParameterList const& _parameterList) override; + bool visit(ParameterList const&) override { return true; } void endVisit(ParameterList const& _parameterList) override; bool visit(SourceUnit const&) override { return true; } bool visit(ContractDefinition const&) override { return true; } @@ -61,8 +61,8 @@ public: bool visit(PragmaDirective const&) override { return false; } bool visit(ExpressionStatement const&) override { return true; } - bool visit(Assignment const&) override; - void endVisit(Assignment const&) override; + bool visit(Assignment const&) override { return true; } + void endVisit(Assignment const& _assignment) override; bool visit(Identifier const&) override; bool visit(IdentifierPath const&) override; bool visit(FunctionCall const& _functionCall) override; @@ -97,7 +97,10 @@ private: Type m_boolType; std::optional m_currentFunctionType; + Type getType(ASTNode const& _node) const; + Annotation& annotation(ASTNode const& _node); + Annotation const& annotation(ASTNode const& _node) const; GlobalAnnotation& annotation(); void unify(Type _a, Type _b, langutil::SourceLocation _location = {}, TypeEnvironment* _env = nullptr);