fixup! Type inference draft.

This commit is contained in:
Kamil Śliwak 2023-09-11 13:53:23 +02:00 committed by Nikola Matic
parent 761f5b342f
commit 0e95ca163c

View File

@ -192,11 +192,12 @@ void TypeInference::endVisit(ParameterList const& _parameterList)
bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition) bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
{ {
solAssert(m_expressionContext == ExpressionContext::Term); solAssert(m_expressionContext == ExpressionContext::Term);
auto& typeClassAnnotation = annotation(_typeClassDefinition); auto& typeClassDefinitionAnnotation = annotation(_typeClassDefinition);
if (typeClassAnnotation.type) if (typeClassDefinitionAnnotation.type)
return false; return false;
typeClassAnnotation.type = type(&_typeClassDefinition, {}); typeClassDefinitionAnnotation.type = type(&_typeClassDefinition, {});
{ {
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type}; ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type};
_typeClassDefinition.typeVariable().accept(*this); _typeClassDefinition.typeVariable().accept(*this);
@ -205,7 +206,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
std::map<std::string, Type> functionTypes; std::map<std::string, Type> functionTypes;
Type typeVar = m_typeSystem.freshTypeVariable({}); Type typeVar = m_typeSystem.freshTypeVariable({});
auto& typeMembers = annotation().members[typeConstructor(&_typeClassDefinition)]; auto& typeMembersAnnotation = annotation().members[typeConstructor(&_typeClassDefinition)];
for (auto subNode: _typeClassDefinition.subNodes()) for (auto subNode: _typeClassDefinition.subNodes())
{ {
@ -220,7 +221,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
if (typeVars.size() != 1) if (typeVars.size() != 1)
m_errorReporter.fatalTypeError(8379_error, functionDefinition->location(), "Function in type class may only depend on the type class variable."); m_errorReporter.fatalTypeError(8379_error, functionDefinition->location(), "Function in type class may only depend on the type class variable.");
unify(typeVars.front(), typeVar, functionDefinition->location()); unify(typeVars.front(), typeVar, functionDefinition->location());
typeMembers[functionDefinition->name()] = TypeMember{functionType}; typeMembersAnnotation[functionDefinition->name()] = TypeMember{functionType};
} }
TypeClass typeClass = std::visit(util::GenericVisitor{ TypeClass typeClass = std::visit(util::GenericVisitor{
@ -230,7 +231,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
util::unreachable(); util::unreachable();
} }
}, m_typeSystem.declareTypeClass(typeVar, _typeClassDefinition.name(), &_typeClassDefinition)); }, m_typeSystem.declareTypeClass(typeVar, _typeClassDefinition.name(), &_typeClassDefinition));
annotation(_typeClassDefinition).typeClass = typeClass; typeClassDefinitionAnnotation.typeClass = typeClass;
annotation().typeClassFunctions[typeClass] = std::move(functionTypes); annotation().typeClassFunctions[typeClass] = std::move(functionTypes);
for (auto [functionName, functionType]: functionTypes) for (auto [functionName, functionType]: functionTypes)
@ -569,16 +570,16 @@ experimental::Type TypeInference::handleIdentifierByReferencedDeclaration(langut
} }
case ExpressionContext::Sort: case ExpressionContext::Sort:
{ {
if (auto const* typeClass = dynamic_cast<TypeClassDefinition const*>(&_declaration)) if (auto const* typeClassDefinition = dynamic_cast<TypeClassDefinition const*>(&_declaration))
{ {
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Term}; ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Term};
typeClass->accept(*this); typeClassDefinition->accept(*this);
if (!annotation(*typeClass).typeClass) if (!annotation(*typeClassDefinition).typeClass)
{ {
m_errorReporter.typeError(2736_error, _location, "Unregistered type class."); m_errorReporter.typeError(2736_error, _location, "Unregistered type class.");
return m_typeSystem.freshTypeVariable({}); return m_typeSystem.freshTypeVariable({});
} }
return m_typeSystem.freshTypeVariable(Sort{{*annotation(*typeClass).typeClass}}); return m_typeSystem.freshTypeVariable(Sort{{*annotation(*typeClassDefinition).typeClass}});
} }
else else
{ {
@ -674,14 +675,14 @@ bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation)
instantiationAnnotation.type = m_voidType; instantiationAnnotation.type = m_voidType;
std::optional<TypeClass> typeClass = std::visit(util::GenericVisitor{ std::optional<TypeClass> typeClass = std::visit(util::GenericVisitor{
[&](ASTPointer<IdentifierPath> _typeClassName) -> std::optional<TypeClass> { [&](ASTPointer<IdentifierPath> _typeClassName) -> std::optional<TypeClass> {
if (auto const* typeClass = dynamic_cast<TypeClassDefinition const*>(_typeClassName->annotation().referencedDeclaration)) if (auto const* typeClassDefinition = dynamic_cast<TypeClassDefinition const*>(_typeClassName->annotation().referencedDeclaration))
{ {
// visiting the type class will re-visit this instantiation // visiting the type class will re-visit this instantiation
typeClass->accept(*this); typeClassDefinition->accept(*this);
// TODO: more error handling? Should be covered by the visit above. // TODO: more error handling? Should be covered by the visit above.
if (!annotation(*typeClass).typeClass) if (!annotation(*typeClassDefinition).typeClass)
m_errorReporter.typeError(8503_error, _typeClassInstantiation.typeClass().location(), "Expected type class."); m_errorReporter.typeError(8503_error, _typeClassInstantiation.typeClass().location(), "Expected type class.");
return annotation(*typeClass).typeClass; return annotation(*typeClassDefinition).typeClass;
} }
else else
{ {