fixup! Type inference draft.

This commit is contained in:
Kamil Śliwak 2023-09-20 13:25:31 +02:00
parent e3caed0ea4
commit 6d43dfbd43
2 changed files with 27 additions and 27 deletions

View File

@ -152,14 +152,14 @@ bool TypeInference::visit(FunctionDefinition const& _functionDefinition)
_functionDefinition.parameterList().accept(*this);
unify(argumentsType, getType(_functionDefinition.parameterList()), _functionDefinition.parameterList().location());
unify(argumentsType, typeAnnotation(_functionDefinition.parameterList()), _functionDefinition.parameterList().location());
if (_functionDefinition.experimentalReturnExpression())
{
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type};
_functionDefinition.experimentalReturnExpression()->accept(*this);
unify(
returnType,
getType(*_functionDefinition.experimentalReturnExpression()),
typeAnnotation(*_functionDefinition.experimentalReturnExpression()),
_functionDefinition.experimentalReturnExpression()->location()
);
}
@ -177,7 +177,7 @@ void TypeInference::endVisit(Return const& _return)
solAssert(m_currentFunctionType);
Type functionReturnType = std::get<1>(TypeSystemHelpers{m_typeSystem}.destFunctionType(*m_currentFunctionType));
if (_return.expression())
unify(functionReturnType, getType(*_return.expression()), _return.location());
unify(functionReturnType, typeAnnotation(*_return.expression()), _return.location());
else
unify(functionReturnType, m_unitType, _return.location());
}
@ -187,7 +187,7 @@ void TypeInference::endVisit(ParameterList const& _parameterList)
auto& listAnnotation = annotation(_parameterList);
solAssert(!listAnnotation.type);
listAnnotation.type = TypeSystemHelpers{m_typeSystem}.tupleType(
_parameterList.parameters() | ranges::views::transform([&](auto _arg) { return getType(*_arg); }) | ranges::to<std::vector<Type>>
_parameterList.parameters() | ranges::views::transform([&](auto _arg) { return typeAnnotation(*_arg); }) | ranges::to<std::vector<Type>>
);
}
@ -218,7 +218,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
auto const* functionDefinition = dynamic_cast<FunctionDefinition const*>(subNode.get());
solAssert(functionDefinition);
// TODO: need polymorphicInstance?
auto functionType = polymorphicInstance(getType(*functionDefinition));
auto functionType = polymorphicInstance(typeAnnotation(*functionDefinition));
if (!functionTypes.emplace(functionDefinition->name(), functionType).second)
m_errorReporter.fatalTypeError(3195_error, functionDefinition->location(), "Function in type class declared multiple times.");
auto typeVars = TypeEnvironmentHelpers{*m_env}.typeVars(functionType);
@ -242,7 +242,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
m_errorReporter.typeError(1807_error, _typeClassDefinition.location(), "Function " + functionName + " depends on invalid type variable.");
}
unify(getType(_typeClassDefinition.typeVariable()), m_typeSystem.freshTypeVariable({{typeClass}}), _typeClassDefinition.location());
unify(typeAnnotation(_typeClassDefinition.typeVariable()), m_typeSystem.freshTypeVariable({{typeClass}}), _typeClassDefinition.location());
for (auto instantiation: m_analysis.annotation<TypeRegistration>(_typeClassDefinition).instantiations | ranges::views::values)
// TODO: recursion-safety? Order of instantiation?
instantiation->accept(*this);
@ -275,7 +275,7 @@ bool TypeInference::visit(InlineAssembly const& _inlineAssembly)
solAssert(!!declaration, "");
solAssert(identifierInfo->suffix == "", "");
unify(getType(*declaration), m_wordType, originLocationOf(_identifier));
unify(typeAnnotation(*declaration), m_wordType, originLocationOf(_identifier));
identifierInfo->valueSize = 1;
return true;
};
@ -351,7 +351,7 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
_binaryOperation.leftExpression().accept(*this);
_binaryOperation.rightExpression().accept(*this);
Type argTuple = helper.tupleType({getType(_binaryOperation.leftExpression()), getType(_binaryOperation.rightExpression())});
Type argTuple = helper.tupleType({typeAnnotation(_binaryOperation.leftExpression()), typeAnnotation(_binaryOperation.rightExpression())});
Type resultType = m_typeSystem.freshTypeVariable({});
Type genericFunctionType = helper.functionType(argTuple, resultType);
unify(functionType, genericFunctionType, _binaryOperation.location());
@ -365,8 +365,8 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type};
_binaryOperation.rightExpression().accept(*this);
}
Type leftType = getType(_binaryOperation.leftExpression());
unify(leftType, getType(_binaryOperation.rightExpression()), _binaryOperation.location());
Type leftType = typeAnnotation(_binaryOperation.leftExpression());
unify(leftType, typeAnnotation(_binaryOperation.rightExpression()), _binaryOperation.location());
operationAnnotation.type = leftType;
}
else
@ -383,21 +383,21 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Sort};
_binaryOperation.rightExpression().accept(*this);
}
Type leftType = getType(_binaryOperation.leftExpression());
unify(leftType, getType(_binaryOperation.rightExpression()), _binaryOperation.location());
Type leftType = typeAnnotation(_binaryOperation.leftExpression());
unify(leftType, typeAnnotation(_binaryOperation.rightExpression()), _binaryOperation.location());
operationAnnotation.type = leftType;
}
else if (_binaryOperation.getOperator() == Token::RightArrow)
{
_binaryOperation.leftExpression().accept(*this);
_binaryOperation.rightExpression().accept(*this);
operationAnnotation.type = helper.functionType(getType(_binaryOperation.leftExpression()), getType(_binaryOperation.rightExpression()));
operationAnnotation.type = helper.functionType(typeAnnotation(_binaryOperation.leftExpression()), typeAnnotation(_binaryOperation.rightExpression()));
}
else if (_binaryOperation.getOperator() == Token::BitOr)
{
_binaryOperation.leftExpression().accept(*this);
_binaryOperation.rightExpression().accept(*this);
operationAnnotation.type = helper.sumType({getType(_binaryOperation.leftExpression()), getType(_binaryOperation.rightExpression())});
operationAnnotation.type = helper.sumType({typeAnnotation(_binaryOperation.leftExpression()), typeAnnotation(_binaryOperation.rightExpression())});
}
else
{
@ -421,9 +421,9 @@ void TypeInference::endVisit(VariableDeclarationStatement const& _variableDeclar
m_errorReporter.typeError(2655_error, _variableDeclarationStatement.location(), "Multi variable declaration not supported.");
return;
}
Type variableType = getType(*_variableDeclarationStatement.declarations().front());
Type variableType = typeAnnotation(*_variableDeclarationStatement.declarations().front());
if (_variableDeclarationStatement.initialValue())
unify(variableType, getType(*_variableDeclarationStatement.initialValue()), _variableDeclarationStatement.location());
unify(variableType, typeAnnotation(*_variableDeclarationStatement.initialValue()), _variableDeclarationStatement.location());
}
bool TypeInference::visit(VariableDeclaration const& _variableDeclaration)
@ -439,7 +439,7 @@ bool TypeInference::visit(VariableDeclaration const& _variableDeclaration)
{
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type};
_variableDeclaration.typeExpression()->accept(*this);
variableAnnotation.type = getType(*_variableDeclaration.typeExpression());
variableAnnotation.type = typeAnnotation(*_variableDeclaration.typeExpression());
return false;
}
variableAnnotation.type = m_typeSystem.freshTypeVariable({});
@ -450,7 +450,7 @@ bool TypeInference::visit(VariableDeclaration const& _variableDeclaration)
{
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Sort};
_variableDeclaration.typeExpression()->accept(*this);
unify(*variableAnnotation.type, getType(*_variableDeclaration.typeExpression()), _variableDeclaration.typeExpression()->location());
unify(*variableAnnotation.type, typeAnnotation(*_variableDeclaration.typeExpression()), _variableDeclaration.typeExpression()->location());
}
return false;
case ExpressionContext::Sort:
@ -473,7 +473,7 @@ void TypeInference::endVisit(IfStatement const& _ifStatement)
return;
}
unify(getType(_ifStatement.condition()), m_boolType, _ifStatement.condition().location());
unify(typeAnnotation(_ifStatement.condition()), m_boolType, _ifStatement.condition().location());
ifAnnotation.type = m_unitType;
}
@ -490,8 +490,8 @@ void TypeInference::endVisit(Assignment const& _assignment)
return;
}
Type leftType = getType(_assignment.leftHandSide());
unify(leftType, getType(_assignment.rightHandSide()), _assignment.location());
Type leftType = typeAnnotation(_assignment.leftHandSide());
unify(leftType, typeAnnotation(_assignment.rightHandSide()), _assignment.location());
assignmentAnnotation.type = leftType;
}
@ -731,7 +731,7 @@ bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation)
auto const* functionDefinition = dynamic_cast<FunctionDefinition const*>(subNode.get());
solAssert(functionDefinition);
subNode->accept(*this);
if (!functionTypes.emplace(functionDefinition->name(), getType(*functionDefinition)).second)
if (!functionTypes.emplace(functionDefinition->name(), typeAnnotation(*functionDefinition)).second)
m_errorReporter.typeError(3654_error, subNode->location(), "Duplicate definition of function " + functionDefinition->name() + " during type class instantiation.");
}
@ -804,7 +804,7 @@ void TypeInference::endVisit(MemberAccess const& _memberAccess)
{
auto& memberAccessAnnotation = annotation(_memberAccess);
solAssert(!memberAccessAnnotation.type);
Type expressionType = getType(_memberAccess.expression());
Type expressionType = typeAnnotation(_memberAccess.expression());
memberAccessAnnotation.type = memberType(expressionType, _memberAccess.memberName(), _memberAccess.location());
}
@ -854,7 +854,7 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
auto& functionCallAnnotation = annotation(_functionCall);
solAssert(!functionCallAnnotation.type);
Type functionType = getType(_functionCall.expression());
Type functionType = typeAnnotation(_functionCall.expression());
TypeSystemHelpers helper{m_typeSystem};
std::vector<Type> argTypes;
@ -864,7 +864,7 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
{
case ExpressionContext::Term:
case ExpressionContext::Type:
argTypes.emplace_back(getType(*arg));
argTypes.emplace_back(typeAnnotation(*arg));
break;
case ExpressionContext::Sort:
m_errorReporter.typeError(9173_error, _functionCall.location(), "Function call in sort context.");
@ -1210,7 +1210,7 @@ void TypeInference::unify(Type _a, Type _b, langutil::SourceLocation _location)
}
}
experimental::Type TypeInference::getType(ASTNode const& _node) const
experimental::Type TypeInference::typeAnnotation(ASTNode const& _node) const
{
auto result = annotation(_node).type;
solAssert(result);

View File

@ -104,7 +104,7 @@ private:
Type m_boolType;
std::optional<Type> m_currentFunctionType;
Type getType(ASTNode const& _node) const;
Type typeAnnotation(ASTNode const& _node) const;
Annotation& annotation(ASTNode const& _node);
Annotation const& annotation(ASTNode const& _node) const;