fixup! Type inference draft.

This commit is contained in:
Kamil Śliwak 2023-09-20 13:25:31 +02:00
parent e3caed0ea4
commit 6d43dfbd43
2 changed files with 27 additions and 27 deletions

View File

@ -152,14 +152,14 @@ bool TypeInference::visit(FunctionDefinition const& _functionDefinition)
_functionDefinition.parameterList().accept(*this); _functionDefinition.parameterList().accept(*this);
unify(argumentsType, getType(_functionDefinition.parameterList()), _functionDefinition.parameterList().location()); unify(argumentsType, typeAnnotation(_functionDefinition.parameterList()), _functionDefinition.parameterList().location());
if (_functionDefinition.experimentalReturnExpression()) if (_functionDefinition.experimentalReturnExpression())
{ {
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type}; ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type};
_functionDefinition.experimentalReturnExpression()->accept(*this); _functionDefinition.experimentalReturnExpression()->accept(*this);
unify( unify(
returnType, returnType,
getType(*_functionDefinition.experimentalReturnExpression()), typeAnnotation(*_functionDefinition.experimentalReturnExpression()),
_functionDefinition.experimentalReturnExpression()->location() _functionDefinition.experimentalReturnExpression()->location()
); );
} }
@ -177,7 +177,7 @@ void TypeInference::endVisit(Return const& _return)
solAssert(m_currentFunctionType); solAssert(m_currentFunctionType);
Type functionReturnType = std::get<1>(TypeSystemHelpers{m_typeSystem}.destFunctionType(*m_currentFunctionType)); Type functionReturnType = std::get<1>(TypeSystemHelpers{m_typeSystem}.destFunctionType(*m_currentFunctionType));
if (_return.expression()) if (_return.expression())
unify(functionReturnType, getType(*_return.expression()), _return.location()); unify(functionReturnType, typeAnnotation(*_return.expression()), _return.location());
else else
unify(functionReturnType, m_unitType, _return.location()); unify(functionReturnType, m_unitType, _return.location());
} }
@ -187,7 +187,7 @@ void TypeInference::endVisit(ParameterList const& _parameterList)
auto& listAnnotation = annotation(_parameterList); auto& listAnnotation = annotation(_parameterList);
solAssert(!listAnnotation.type); solAssert(!listAnnotation.type);
listAnnotation.type = TypeSystemHelpers{m_typeSystem}.tupleType( listAnnotation.type = TypeSystemHelpers{m_typeSystem}.tupleType(
_parameterList.parameters() | ranges::views::transform([&](auto _arg) { return getType(*_arg); }) | ranges::to<std::vector<Type>> _parameterList.parameters() | ranges::views::transform([&](auto _arg) { return typeAnnotation(*_arg); }) | ranges::to<std::vector<Type>>
); );
} }
@ -218,7 +218,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
auto const* functionDefinition = dynamic_cast<FunctionDefinition const*>(subNode.get()); auto const* functionDefinition = dynamic_cast<FunctionDefinition const*>(subNode.get());
solAssert(functionDefinition); solAssert(functionDefinition);
// TODO: need polymorphicInstance? // TODO: need polymorphicInstance?
auto functionType = polymorphicInstance(getType(*functionDefinition)); auto functionType = polymorphicInstance(typeAnnotation(*functionDefinition));
if (!functionTypes.emplace(functionDefinition->name(), functionType).second) if (!functionTypes.emplace(functionDefinition->name(), functionType).second)
m_errorReporter.fatalTypeError(3195_error, functionDefinition->location(), "Function in type class declared multiple times."); m_errorReporter.fatalTypeError(3195_error, functionDefinition->location(), "Function in type class declared multiple times.");
auto typeVars = TypeEnvironmentHelpers{*m_env}.typeVars(functionType); auto typeVars = TypeEnvironmentHelpers{*m_env}.typeVars(functionType);
@ -242,7 +242,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
m_errorReporter.typeError(1807_error, _typeClassDefinition.location(), "Function " + functionName + " depends on invalid type variable."); m_errorReporter.typeError(1807_error, _typeClassDefinition.location(), "Function " + functionName + " depends on invalid type variable.");
} }
unify(getType(_typeClassDefinition.typeVariable()), m_typeSystem.freshTypeVariable({{typeClass}}), _typeClassDefinition.location()); unify(typeAnnotation(_typeClassDefinition.typeVariable()), m_typeSystem.freshTypeVariable({{typeClass}}), _typeClassDefinition.location());
for (auto instantiation: m_analysis.annotation<TypeRegistration>(_typeClassDefinition).instantiations | ranges::views::values) for (auto instantiation: m_analysis.annotation<TypeRegistration>(_typeClassDefinition).instantiations | ranges::views::values)
// TODO: recursion-safety? Order of instantiation? // TODO: recursion-safety? Order of instantiation?
instantiation->accept(*this); instantiation->accept(*this);
@ -275,7 +275,7 @@ bool TypeInference::visit(InlineAssembly const& _inlineAssembly)
solAssert(!!declaration, ""); solAssert(!!declaration, "");
solAssert(identifierInfo->suffix == "", ""); solAssert(identifierInfo->suffix == "", "");
unify(getType(*declaration), m_wordType, originLocationOf(_identifier)); unify(typeAnnotation(*declaration), m_wordType, originLocationOf(_identifier));
identifierInfo->valueSize = 1; identifierInfo->valueSize = 1;
return true; return true;
}; };
@ -351,7 +351,7 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
_binaryOperation.leftExpression().accept(*this); _binaryOperation.leftExpression().accept(*this);
_binaryOperation.rightExpression().accept(*this); _binaryOperation.rightExpression().accept(*this);
Type argTuple = helper.tupleType({getType(_binaryOperation.leftExpression()), getType(_binaryOperation.rightExpression())}); Type argTuple = helper.tupleType({typeAnnotation(_binaryOperation.leftExpression()), typeAnnotation(_binaryOperation.rightExpression())});
Type resultType = m_typeSystem.freshTypeVariable({}); Type resultType = m_typeSystem.freshTypeVariable({});
Type genericFunctionType = helper.functionType(argTuple, resultType); Type genericFunctionType = helper.functionType(argTuple, resultType);
unify(functionType, genericFunctionType, _binaryOperation.location()); unify(functionType, genericFunctionType, _binaryOperation.location());
@ -365,8 +365,8 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type}; ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type};
_binaryOperation.rightExpression().accept(*this); _binaryOperation.rightExpression().accept(*this);
} }
Type leftType = getType(_binaryOperation.leftExpression()); Type leftType = typeAnnotation(_binaryOperation.leftExpression());
unify(leftType, getType(_binaryOperation.rightExpression()), _binaryOperation.location()); unify(leftType, typeAnnotation(_binaryOperation.rightExpression()), _binaryOperation.location());
operationAnnotation.type = leftType; operationAnnotation.type = leftType;
} }
else else
@ -383,21 +383,21 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Sort}; ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Sort};
_binaryOperation.rightExpression().accept(*this); _binaryOperation.rightExpression().accept(*this);
} }
Type leftType = getType(_binaryOperation.leftExpression()); Type leftType = typeAnnotation(_binaryOperation.leftExpression());
unify(leftType, getType(_binaryOperation.rightExpression()), _binaryOperation.location()); unify(leftType, typeAnnotation(_binaryOperation.rightExpression()), _binaryOperation.location());
operationAnnotation.type = leftType; operationAnnotation.type = leftType;
} }
else if (_binaryOperation.getOperator() == Token::RightArrow) else if (_binaryOperation.getOperator() == Token::RightArrow)
{ {
_binaryOperation.leftExpression().accept(*this); _binaryOperation.leftExpression().accept(*this);
_binaryOperation.rightExpression().accept(*this); _binaryOperation.rightExpression().accept(*this);
operationAnnotation.type = helper.functionType(getType(_binaryOperation.leftExpression()), getType(_binaryOperation.rightExpression())); operationAnnotation.type = helper.functionType(typeAnnotation(_binaryOperation.leftExpression()), typeAnnotation(_binaryOperation.rightExpression()));
} }
else if (_binaryOperation.getOperator() == Token::BitOr) else if (_binaryOperation.getOperator() == Token::BitOr)
{ {
_binaryOperation.leftExpression().accept(*this); _binaryOperation.leftExpression().accept(*this);
_binaryOperation.rightExpression().accept(*this); _binaryOperation.rightExpression().accept(*this);
operationAnnotation.type = helper.sumType({getType(_binaryOperation.leftExpression()), getType(_binaryOperation.rightExpression())}); operationAnnotation.type = helper.sumType({typeAnnotation(_binaryOperation.leftExpression()), typeAnnotation(_binaryOperation.rightExpression())});
} }
else else
{ {
@ -421,9 +421,9 @@ void TypeInference::endVisit(VariableDeclarationStatement const& _variableDeclar
m_errorReporter.typeError(2655_error, _variableDeclarationStatement.location(), "Multi variable declaration not supported."); m_errorReporter.typeError(2655_error, _variableDeclarationStatement.location(), "Multi variable declaration not supported.");
return; return;
} }
Type variableType = getType(*_variableDeclarationStatement.declarations().front()); Type variableType = typeAnnotation(*_variableDeclarationStatement.declarations().front());
if (_variableDeclarationStatement.initialValue()) if (_variableDeclarationStatement.initialValue())
unify(variableType, getType(*_variableDeclarationStatement.initialValue()), _variableDeclarationStatement.location()); unify(variableType, typeAnnotation(*_variableDeclarationStatement.initialValue()), _variableDeclarationStatement.location());
} }
bool TypeInference::visit(VariableDeclaration const& _variableDeclaration) bool TypeInference::visit(VariableDeclaration const& _variableDeclaration)
@ -439,7 +439,7 @@ bool TypeInference::visit(VariableDeclaration const& _variableDeclaration)
{ {
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type}; ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type};
_variableDeclaration.typeExpression()->accept(*this); _variableDeclaration.typeExpression()->accept(*this);
variableAnnotation.type = getType(*_variableDeclaration.typeExpression()); variableAnnotation.type = typeAnnotation(*_variableDeclaration.typeExpression());
return false; return false;
} }
variableAnnotation.type = m_typeSystem.freshTypeVariable({}); variableAnnotation.type = m_typeSystem.freshTypeVariable({});
@ -450,7 +450,7 @@ bool TypeInference::visit(VariableDeclaration const& _variableDeclaration)
{ {
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Sort}; ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Sort};
_variableDeclaration.typeExpression()->accept(*this); _variableDeclaration.typeExpression()->accept(*this);
unify(*variableAnnotation.type, getType(*_variableDeclaration.typeExpression()), _variableDeclaration.typeExpression()->location()); unify(*variableAnnotation.type, typeAnnotation(*_variableDeclaration.typeExpression()), _variableDeclaration.typeExpression()->location());
} }
return false; return false;
case ExpressionContext::Sort: case ExpressionContext::Sort:
@ -473,7 +473,7 @@ void TypeInference::endVisit(IfStatement const& _ifStatement)
return; return;
} }
unify(getType(_ifStatement.condition()), m_boolType, _ifStatement.condition().location()); unify(typeAnnotation(_ifStatement.condition()), m_boolType, _ifStatement.condition().location());
ifAnnotation.type = m_unitType; ifAnnotation.type = m_unitType;
} }
@ -490,8 +490,8 @@ void TypeInference::endVisit(Assignment const& _assignment)
return; return;
} }
Type leftType = getType(_assignment.leftHandSide()); Type leftType = typeAnnotation(_assignment.leftHandSide());
unify(leftType, getType(_assignment.rightHandSide()), _assignment.location()); unify(leftType, typeAnnotation(_assignment.rightHandSide()), _assignment.location());
assignmentAnnotation.type = leftType; assignmentAnnotation.type = leftType;
} }
@ -731,7 +731,7 @@ bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation)
auto const* functionDefinition = dynamic_cast<FunctionDefinition const*>(subNode.get()); auto const* functionDefinition = dynamic_cast<FunctionDefinition const*>(subNode.get());
solAssert(functionDefinition); solAssert(functionDefinition);
subNode->accept(*this); subNode->accept(*this);
if (!functionTypes.emplace(functionDefinition->name(), getType(*functionDefinition)).second) if (!functionTypes.emplace(functionDefinition->name(), typeAnnotation(*functionDefinition)).second)
m_errorReporter.typeError(3654_error, subNode->location(), "Duplicate definition of function " + functionDefinition->name() + " during type class instantiation."); m_errorReporter.typeError(3654_error, subNode->location(), "Duplicate definition of function " + functionDefinition->name() + " during type class instantiation.");
} }
@ -804,7 +804,7 @@ void TypeInference::endVisit(MemberAccess const& _memberAccess)
{ {
auto& memberAccessAnnotation = annotation(_memberAccess); auto& memberAccessAnnotation = annotation(_memberAccess);
solAssert(!memberAccessAnnotation.type); solAssert(!memberAccessAnnotation.type);
Type expressionType = getType(_memberAccess.expression()); Type expressionType = typeAnnotation(_memberAccess.expression());
memberAccessAnnotation.type = memberType(expressionType, _memberAccess.memberName(), _memberAccess.location()); memberAccessAnnotation.type = memberType(expressionType, _memberAccess.memberName(), _memberAccess.location());
} }
@ -854,7 +854,7 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
auto& functionCallAnnotation = annotation(_functionCall); auto& functionCallAnnotation = annotation(_functionCall);
solAssert(!functionCallAnnotation.type); solAssert(!functionCallAnnotation.type);
Type functionType = getType(_functionCall.expression()); Type functionType = typeAnnotation(_functionCall.expression());
TypeSystemHelpers helper{m_typeSystem}; TypeSystemHelpers helper{m_typeSystem};
std::vector<Type> argTypes; std::vector<Type> argTypes;
@ -864,7 +864,7 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
{ {
case ExpressionContext::Term: case ExpressionContext::Term:
case ExpressionContext::Type: case ExpressionContext::Type:
argTypes.emplace_back(getType(*arg)); argTypes.emplace_back(typeAnnotation(*arg));
break; break;
case ExpressionContext::Sort: case ExpressionContext::Sort:
m_errorReporter.typeError(9173_error, _functionCall.location(), "Function call in sort context."); m_errorReporter.typeError(9173_error, _functionCall.location(), "Function call in sort context.");
@ -1210,7 +1210,7 @@ void TypeInference::unify(Type _a, Type _b, langutil::SourceLocation _location)
} }
} }
experimental::Type TypeInference::getType(ASTNode const& _node) const experimental::Type TypeInference::typeAnnotation(ASTNode const& _node) const
{ {
auto result = annotation(_node).type; auto result = annotation(_node).type;
solAssert(result); solAssert(result);

View File

@ -104,7 +104,7 @@ private:
Type m_boolType; Type m_boolType;
std::optional<Type> m_currentFunctionType; std::optional<Type> m_currentFunctionType;
Type getType(ASTNode const& _node) const; Type typeAnnotation(ASTNode const& _node) const;
Annotation& annotation(ASTNode const& _node); Annotation& annotation(ASTNode const& _node);
Annotation const& annotation(ASTNode const& _node) const; Annotation const& annotation(ASTNode const& _node) const;