This commit is contained in:
Daniel Kirchner 2023-06-24 10:10:00 +02:00
parent 60e1e53f20
commit bd4be335d4
4 changed files with 219 additions and 70 deletions

View File

@ -52,6 +52,7 @@ bool TypeInference::analyze(SourceUnit const& _sourceUnit)
bool TypeInference::visit(FunctionDefinition const& _functionDefinition) bool TypeInference::visit(FunctionDefinition const& _functionDefinition)
{ {
solAssert(m_expressionContext == ExpressionContext::Term);
ScopedSaveAndRestore signatureRestore(m_currentFunctionType, nullopt); ScopedSaveAndRestore signatureRestore(m_currentFunctionType, nullopt);
auto& functionAnnotation = annotation(_functionDefinition); auto& functionAnnotation = annotation(_functionDefinition);
if (functionAnnotation.type) if (functionAnnotation.type)
@ -125,16 +126,21 @@ bool TypeInference::visitNode(ASTNode const& _node)
bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition) bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
{ {
solAssert(m_expressionContext == ExpressionContext::Term);
auto& typeVariableAnnotation = annotation(_typeClassDefinition.typeVariable()); auto& typeVariableAnnotation = annotation(_typeClassDefinition.typeVariable());
if (typeVariableAnnotation.type) if (typeVariableAnnotation.type)
return false; return false;
{
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type};
_typeClassDefinition.typeVariable().accept(*this); _typeClassDefinition.typeVariable().accept(*this);
}
for (auto const& subNode: _typeClassDefinition.subNodes()) for (auto const& subNode: _typeClassDefinition.subNodes())
subNode->accept(*this); subNode->accept(*this);
solAssert(typeVariableAnnotation.type); solAssert(typeVariableAnnotation.type);
unify(*typeVariableAnnotation.type, m_typeSystem.freshTypeVariable(false, Sort{{TypeClass{&_typeClassDefinition}}}), _typeClassDefinition.location()); TypeSystemHelpers helper{m_typeSystem};
unify(*typeVariableAnnotation.type, helper.kindType(m_typeSystem.freshTypeVariable(false, Sort{{TypeClass{&_typeClassDefinition}}})), _typeClassDefinition.location());
return false; return false;
} }
@ -190,9 +196,7 @@ void TypeInference::unify(Type _a, Type _b, langutil::SourceLocation _location)
}, },
[&](TypeEnvironment::SortMismatch _sortMismatch) { [&](TypeEnvironment::SortMismatch _sortMismatch) {
m_errorReporter.typeError(0000_error, _location, fmt::format( m_errorReporter.typeError(0000_error, _location, fmt::format(
"Cannot unify {} and {}: {} does not have sort {}", "{} does not have sort {}",
m_env->typeToString(_a),
m_env->typeToString(_b),
m_env->typeToString(_sortMismatch.type), m_env->typeToString(_sortMismatch.type),
TypeSystemHelpers{m_typeSystem}.sortToString(_sortMismatch.sort) TypeSystemHelpers{m_typeSystem}.sortToString(_sortMismatch.sort)
)); ));
@ -254,29 +258,29 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression)
expressionAnnotation.type = m_typeSystem.freshTypeVariable(false, {}); expressionAnnotation.type = m_typeSystem.freshTypeVariable(false, {});
return false; return false;
} }
TypeSystemHelpers helper{m_typeSystem};
switch(_expression.type().typeName().token()) switch(_expression.type().typeName().token())
{ {
case Token::Word: case Token::Word:
expressionAnnotation.type = m_wordType; expressionAnnotation.type = helper.kindType(m_wordType);
break; break;
case Token::Void: case Token::Void:
expressionAnnotation.type = m_voidType; expressionAnnotation.type = helper.kindType(m_voidType);
break; break;
case Token::Integer: case Token::Integer:
expressionAnnotation.type = m_integerType; expressionAnnotation.type = helper.kindType(m_integerType);
break; break;
case Token::Unit: case Token::Unit:
expressionAnnotation.type = m_unitType; expressionAnnotation.type = helper.kindType(m_unitType);
break; break;
case Token::Pair: case Token::Pair:
{ {
auto leftType = m_typeSystem.freshTypeVariable(true, {}); auto leftType = m_typeSystem.freshTypeVariable(true, {});
auto rightType = m_typeSystem.freshTypeVariable(true, {}); auto rightType = m_typeSystem.freshTypeVariable(true, {});
TypeSystemHelpers helper{m_typeSystem};
expressionAnnotation.type = expressionAnnotation.type =
helper.functionType( helper.functionType(
helper.tupleType({leftType, rightType}), helper.kindType(helper.tupleType({leftType, rightType})),
m_typeSystem.type(BuiltinType::Pair, {leftType, rightType}) helper.kindType(m_typeSystem.type(BuiltinType::Pair, {leftType, rightType}))
); );
break; break;
} }
@ -284,17 +288,16 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression)
{ {
auto argType = m_typeSystem.freshTypeVariable(true, {}); auto argType = m_typeSystem.freshTypeVariable(true, {});
auto resultType = m_typeSystem.freshTypeVariable(true, {}); auto resultType = m_typeSystem.freshTypeVariable(true, {});
TypeSystemHelpers helper{m_typeSystem};
expressionAnnotation.type = expressionAnnotation.type =
helper.functionType( helper.functionType(
helper.tupleType({argType, resultType}), helper.kindType(helper.tupleType({argType, resultType})),
m_typeSystem.type(BuiltinType::Function, {argType, resultType}) helper.kindType(m_typeSystem.type(BuiltinType::Function, {argType, resultType}))
); );
break; break;
} }
default: default:
m_errorReporter.typeError(0000_error, _expression.location(), "Only elementary types are supported."); m_errorReporter.typeError(0000_error, _expression.location(), "Only elementary types are supported.");
expressionAnnotation.type = m_typeSystem.freshTypeVariable(false, {}); expressionAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, {}));
break; break;
} }
return false; return false;
@ -305,6 +308,7 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
auto& operationAnnotation = annotation(_binaryOperation); auto& operationAnnotation = annotation(_binaryOperation);
auto& leftAnnotation = annotation(_binaryOperation.leftExpression()); auto& leftAnnotation = annotation(_binaryOperation.leftExpression());
auto& rightAnnotation = annotation(_binaryOperation.rightExpression()); auto& rightAnnotation = annotation(_binaryOperation.rightExpression());
TypeSystemHelpers helper{m_typeSystem};
switch (m_expressionContext) switch (m_expressionContext)
{ {
case ExpressionContext::Term: case ExpressionContext::Term:
@ -327,12 +331,12 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
else else
{ {
m_errorReporter.typeError(0000_error, _binaryOperation.location(), "Binary operations other than colon in type context not yet supported."); m_errorReporter.typeError(0000_error, _binaryOperation.location(), "Binary operations other than colon in type context not yet supported.");
operationAnnotation.type = m_typeSystem.freshTypeVariable(false, {}); operationAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, {}));
} }
return false; return false;
case ExpressionContext::Sort: case ExpressionContext::Sort:
m_errorReporter.typeError(0000_error, _binaryOperation.location(), "Invalid binary operation in sort context."); m_errorReporter.typeError(0000_error, _binaryOperation.location(), "Invalid binary operation in sort context.");
operationAnnotation.type = m_typeSystem.freshTypeVariable(false, {}); operationAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, {}));
return false; return false;
} }
return false; return false;
@ -344,18 +348,33 @@ bool TypeInference::visit(VariableDeclaration const& _variableDeclaration)
auto& variableAnnotation = annotation(_variableDeclaration); auto& variableAnnotation = annotation(_variableDeclaration);
solAssert(!variableAnnotation.type); solAssert(!variableAnnotation.type);
TypeSystemHelpers helper{m_typeSystem};
switch (m_expressionContext)
{
case ExpressionContext::Term:
if (_variableDeclaration.typeExpression()) if (_variableDeclaration.typeExpression())
{ {
ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type}; ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type};
_variableDeclaration.typeExpression()->accept(*this); _variableDeclaration.typeExpression()->accept(*this);
auto& typeExpressionAnnotation = annotation(*_variableDeclaration.typeExpression()); auto& typeExpressionAnnotation = annotation(*_variableDeclaration.typeExpression());
solAssert(typeExpressionAnnotation.type); solAssert(typeExpressionAnnotation.type);
variableAnnotation.type = m_env->fresh(*typeExpressionAnnotation.type, false); variableAnnotation.type = m_env->fresh(helper.destKindType(*typeExpressionAnnotation.type), false);
return false; return false;
} }
variableAnnotation.type = m_typeSystem.freshTypeVariable(false, {}); variableAnnotation.type = m_typeSystem.freshTypeVariable(false, {});
return false;
case ExpressionContext::Type:
if (_variableDeclaration.typeExpression())
m_errorReporter.typeError(0000_error, _variableDeclaration.location(), "Variable declaration in type context with type expression.");
variableAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, {}));
return false;
case ExpressionContext::Sort:
m_errorReporter.typeError(0000_error, _variableDeclaration.location(), "Variable declaration in sort context.");
return false;
}
solAssert(false);
return false; return false;
} }
@ -396,6 +415,7 @@ bool TypeInference::visit(Identifier const& _identifier)
auto const* referencedDeclaration = _identifier.annotation().referencedDeclaration; auto const* referencedDeclaration = _identifier.annotation().referencedDeclaration;
TypeSystemHelpers helper{m_typeSystem};
switch(m_expressionContext) switch(m_expressionContext)
{ {
case ExpressionContext::Term: case ExpressionContext::Term:
@ -448,16 +468,22 @@ bool TypeInference::visit(Identifier const& _identifier)
solAssert(declarationAnnotation.type); solAssert(declarationAnnotation.type);
if (dynamic_cast<VariableDeclaration const*>(referencedDeclaration)) if (dynamic_cast<VariableDeclaration const*>(referencedDeclaration))
{
helper.destKindType(*declarationAnnotation.type);
identifierAnnotation.type = declarationAnnotation.type; identifierAnnotation.type = declarationAnnotation.type;
}
else if (dynamic_cast<TypeDefinition const*>(referencedDeclaration)) else if (dynamic_cast<TypeDefinition const*>(referencedDeclaration))
{
helper.destKindType(*declarationAnnotation.type);
identifierAnnotation.type = m_env->fresh(*declarationAnnotation.type, true); identifierAnnotation.type = m_env->fresh(*declarationAnnotation.type, true);
}
else else
solAssert(false); solAssert(false);
} }
else else
{ {
// TODO: register free type variable name! // TODO: register free type variable name!
identifierAnnotation.type = m_typeSystem.freshTypeVariable(false, {}); identifierAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, {}));
return false; return false;
} }
break; break;
@ -472,7 +498,7 @@ bool TypeInference::visit(Identifier const& _identifier)
) )
m_errorReporter.fatalTypeError(0000_error, _identifier.location(), "Expected type class."); m_errorReporter.fatalTypeError(0000_error, _identifier.location(), "Expected type class.");
identifierAnnotation.type = m_typeSystem.freshTypeVariable(false, Sort{{TypeClass{typeClass}}}); identifierAnnotation.type = helper.kindType(m_typeSystem.freshTypeVariable(false, Sort{{TypeClass{typeClass}}}));
break; break;
} }
} }
@ -490,15 +516,18 @@ void TypeInference::endVisit(TupleExpression const& _tupleExpression)
solAssert(componentAnnotation.type); solAssert(componentAnnotation.type);
return *componentAnnotation.type; return *componentAnnotation.type;
}) | ranges::to<vector<Type>>; }) | ranges::to<vector<Type>>;
TypeSystemHelpers helper{m_typeSystem};
switch (m_expressionContext) switch (m_expressionContext)
{ {
case ExpressionContext::Type:
case ExpressionContext::Term: case ExpressionContext::Term:
expressionAnnotation.type = TypeSystemHelpers{m_typeSystem}.tupleType(componentTypes); expressionAnnotation.type = helper.tupleType(componentTypes);
break;
case ExpressionContext::Type:
expressionAnnotation.type = helper.kindType(helper.tupleType(componentTypes));
break; break;
case ExpressionContext::Sort: case ExpressionContext::Sort:
{ {
Type type = m_typeSystem.freshTypeVariable(false, {}); Type type = helper.kindType(m_typeSystem.freshTypeVariable(false, {}));
for (auto componentType: componentTypes) for (auto componentType: componentTypes)
unify(type, componentType, _tupleExpression.location()); unify(type, componentType, _tupleExpression.location());
expressionAnnotation.type = type; expressionAnnotation.type = type;
@ -627,13 +656,14 @@ bool TypeInference::visit(TypeDefinition const& _typeDefinition)
for (size_t i = 0; i < _typeDefinition.arguments()->parameters().size(); ++i) for (size_t i = 0; i < _typeDefinition.arguments()->parameters().size(); ++i)
arguments.emplace_back(m_typeSystem.freshTypeVariable(true, {})); arguments.emplace_back(m_typeSystem.freshTypeVariable(true, {}));
TypeSystemHelpers helper{m_typeSystem};
if (arguments.empty()) if (arguments.empty())
typeDefinitionAnnotation.type = m_typeSystem.type(TypeExpression::Constructor{&_typeDefinition}, arguments); typeDefinitionAnnotation.type = helper.kindType(m_typeSystem.type(TypeExpression::Constructor{&_typeDefinition}, arguments));
else else
typeDefinitionAnnotation.type = typeDefinitionAnnotation.type =
TypeSystemHelpers{m_typeSystem}.functionType( helper.functionType(
TypeSystemHelpers{m_typeSystem}.tupleType(arguments), helper.kindType(helper.tupleType(arguments)),
m_typeSystem.type(TypeExpression::Constructor{&_typeDefinition}, arguments) helper.kindType(m_typeSystem.type(TypeExpression::Constructor{&_typeDefinition}, arguments))
); );
return false; return false;
} }
@ -657,16 +687,25 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
argTypes.emplace_back(*argAnnotation.type); argTypes.emplace_back(*argAnnotation.type);
} }
TypeSystemHelpers helper{m_typeSystem};
switch(m_expressionContext) switch(m_expressionContext)
{ {
case ExpressionContext::Term: case ExpressionContext::Term:
case ExpressionContext::Type:
{ {
Type argTuple = TypeSystemHelpers{m_typeSystem}.tupleType(argTypes); Type argTuple = helper.tupleType(argTypes);
Type genericFunctionType = TypeSystemHelpers{m_typeSystem}.functionType(argTuple, m_typeSystem.freshTypeVariable(false, {})); Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshTypeVariable(false, {}));
unify(genericFunctionType, functionType, _functionCall.location()); unify(genericFunctionType, functionType, _functionCall.location());
functionCallAnnotation.type = m_env->resolve(std::get<1>(TypeSystemHelpers{m_typeSystem}.destFunctionType(m_env->resolve(genericFunctionType)))); functionCallAnnotation.type = m_env->resolve(std::get<1>(helper.destFunctionType(m_env->resolve(genericFunctionType))));
break;
}
case ExpressionContext::Type:
{
Type argTuple = helper.tupleType(argTypes);
Type genericFunctionType = helper.functionType(argTuple, m_typeSystem.freshTypeVariable(false, {}));
unify(genericFunctionType, functionType, _functionCall.location());
functionCallAnnotation.type = m_env->resolve(std::get<1>(helper.destFunctionType(m_env->resolve(genericFunctionType))));
break; break;
} }
case ExpressionContext::Sort: case ExpressionContext::Sort:

View File

@ -37,7 +37,6 @@ m_typeSystem(_analysis.typeSystem())
{ {
for (auto [type, name, arity]: std::initializer_list<std::tuple<BuiltinType, const char*, uint64_t>> { for (auto [type, name, arity]: std::initializer_list<std::tuple<BuiltinType, const char*, uint64_t>> {
{BuiltinType::Void, "void", 0}, {BuiltinType::Void, "void", 0},
{BuiltinType::Function, "fun", 2},
{BuiltinType::Unit, "unit", 0}, {BuiltinType::Unit, "unit", 0},
{BuiltinType::Pair, "pair", 2}, {BuiltinType::Pair, "pair", 2},
{BuiltinType::Word, "word", 0}, {BuiltinType::Word, "word", 0},
@ -120,10 +119,11 @@ bool TypeRegistration::visit(TypeClassInstantiation const& _typeClassInstantiati
{ {
if (auto const* referencedDeclaration = argumentSort->annotation().referencedDeclaration) if (auto const* referencedDeclaration = argumentSort->annotation().referencedDeclaration)
{ {
if (!dynamic_cast<TypeClassDefinition const*>(referencedDeclaration)) if (auto const* typeClassDefinition = dynamic_cast<TypeClassDefinition const*>(referencedDeclaration))
m_errorReporter.fatalTypeError(0000_error, argumentSort->location(), "Argument sort has to be a type class.");
// TODO: multi arities // TODO: multi arities
arity.argumentSorts.emplace_back(Sort{{TypeClass{referencedDeclaration}}}); arity.argumentSorts.emplace_back(Sort{{TypeClass{typeClassDefinition}}});
else
m_errorReporter.fatalTypeError(0000_error, argumentSort->location(), "Argument sort has to be a type class.");
} }
else else
{ {

View File

@ -37,7 +37,39 @@ using namespace solidity::frontend::experimental;
bool TypeClass::operator<(TypeClass const& _rhs) const bool TypeClass::operator<(TypeClass const& _rhs) const
{ {
return declaration->id() < _rhs.declaration->id(); return std::visit(util::GenericVisitor{
[](BuiltinClass _left, BuiltinClass _right) { return _left < _right; },
[](TypeClassDefinition const* _left, TypeClassDefinition const* _right) { return _left->id() < _right->id(); },
[](BuiltinClass, TypeClassDefinition const*) { return true; },
[](TypeClassDefinition const*, BuiltinClass) { return false; },
}, declaration, _rhs.declaration);
}
bool TypeClass::operator==(TypeClass const& _rhs) const
{
return std::visit(util::GenericVisitor{
[](BuiltinClass _left, BuiltinClass _right) { return _left == _right; },
[](TypeClassDefinition const* _left, TypeClassDefinition const* _right) { return _left->id() == _right->id(); },
[](BuiltinClass, TypeClassDefinition const*) { return false; },
[](TypeClassDefinition const*, BuiltinClass) { return false; },
}, declaration, _rhs.declaration);
}
string TypeClass::toString() const
{
return std::visit(util::GenericVisitor{
[](BuiltinClass _class) -> string {
switch(_class)
{
case BuiltinClass::Type:
return "type";
case BuiltinClass::Kind:
return "kind";
}
solAssert(false);
},
[](TypeClassDefinition const* _declaration) { return _declaration->name(); },
}, declaration);
} }
bool Sort::operator==(Sort const& _rhs) const bool Sort::operator==(Sort const& _rhs) const
@ -65,6 +97,14 @@ Sort Sort::operator+(Sort const& _rhs) const
return result; return result;
} }
Sort Sort::operator-(Sort const& _rhs) const
{
Sort result { classes };
result.classes -= _rhs.classes;
return result;
}
bool TypeExpression::operator<(TypeExpression const& _rhs) const bool TypeExpression::operator<(TypeExpression const& _rhs) const
{ {
if (constructor < _rhs.constructor) if (constructor < _rhs.constructor)
@ -110,6 +150,9 @@ std::string experimental::canonicalTypeName(Type _type)
printTypeArguments(); printTypeArguments();
switch(_builtinType) switch(_builtinType)
{ {
case BuiltinType::Type:
stream << "type";
break;
case BuiltinType::Void: case BuiltinType::Void:
stream << "void"; stream << "void";
break; break;
@ -156,8 +199,8 @@ std::string TypeEnvironment::typeToString(Type const& _type) const
}; };
std::visit(util::GenericVisitor{ std::visit(util::GenericVisitor{
[&](Declaration const* _declaration) { [&](Declaration const* _declaration) {
printTypeArguments();
stream << m_typeSystem.typeName(_declaration); stream << m_typeSystem.typeName(_declaration);
printTypeArguments();
}, },
[&](BuiltinType _builtinType) { [&](BuiltinType _builtinType) {
switch (_builtinType) switch (_builtinType)
@ -179,9 +222,15 @@ std::string TypeEnvironment::typeToString(Type const& _type) const
stream << typeToString(tupleTypes.back()) << ")"; stream << typeToString(tupleTypes.back()) << ")";
break; break;
} }
case BuiltinType::Type:
{
solAssert(_type.arguments.size() == 1);
stream << "TYPE(" << typeToString(_type.arguments.front()) << ")";
break;
}
default: default:
printTypeArguments();
stream << m_typeSystem.typeName(_builtinType); stream << m_typeSystem.typeName(_builtinType);
printTypeArguments();
break; break;
} }
} }
@ -196,14 +245,14 @@ std::string TypeEnvironment::typeToString(Type const& _type) const
case 0: case 0:
break; break;
case 1: case 1:
stream << ":" << _type.sort().classes.begin()->declaration->name(); stream << ":" << _type.sort().classes.begin()->toString();
break; break;
default: default:
stream << ":{"; stream << ":(";
for (auto typeClass: _type.sort().classes | ranges::views::drop_last(1)) for (auto typeClass: _type.sort().classes | ranges::views::drop_last(1))
stream << typeClass.declaration->name() << ", "; stream << typeClass.toString() << ", ";
stream << _type.sort().classes.rbegin()->declaration->name(); stream << _type.sort().classes.rbegin()->toString();
stream << "}"; stream << ")";
break; break;
} }
return stream.str(); return stream.str();
@ -234,7 +283,7 @@ vector<TypeEnvironment::UnificationFailure> TypeEnvironment::unify(Type _a, Type
failures += instantiate(_right, _left); failures += instantiate(_right, _left);
else else
{ {
Type newVar = m_typeSystem.freshTypeVariable(_left.generic() && _right.generic(), _left.sort() + _right.sort()); Type newVar = m_typeSystem.freshVariable(_left.generic() && _right.generic(), _left.sort() + _right.sort());
failures += instantiate(_left, newVar); failures += instantiate(_left, newVar);
failures += instantiate(_right, newVar); failures += instantiate(_right, newVar);
} }
@ -268,10 +317,39 @@ TypeEnvironment TypeEnvironment::clone() const
return result; return result;
} }
experimental::Type TypeSystem::freshTypeVariable(bool _generic, Sort const& _sort) TypeSystem::TypeSystem()
{
Sort kindSort{{TypeClass{BuiltinClass::Kind}}};
Sort typeSort{{TypeClass{BuiltinClass::Type}}};
m_typeConstructors[BuiltinType::Type] = TypeConstructorInfo{
"type",
{Arity{vector<Sort>{{kindSort}}, TypeClass{BuiltinClass::Kind}}}
};
m_typeConstructors[BuiltinType::Function] = TypeConstructorInfo{
"fun",
{
Arity{vector<Sort>{{kindSort, kindSort}}, TypeClass{BuiltinClass::Kind}},
Arity{vector<Sort>{{typeSort, typeSort}}, TypeClass{BuiltinClass::Type}}
}
};
}
experimental::Type TypeSystem::freshVariable(bool _generic, Sort _sort)
{ {
uint64_t index = m_numTypeVariables++; uint64_t index = m_numTypeVariables++;
return TypeVariable(index, _sort, _generic); return TypeVariable(index, std::move(_sort), _generic);
}
experimental::Type TypeSystem::freshTypeVariable(bool _generic, Sort _sort)
{
_sort.classes.emplace(TypeClass{BuiltinClass::Type});
return freshVariable(_generic, _sort);
}
experimental::Type TypeSystem::freshKindVariable(bool _generic, Sort _sort)
{
_sort.classes.emplace(TypeClass{BuiltinClass::Kind});
return freshVariable(_generic, _sort);
} }
vector<TypeEnvironment::UnificationFailure> TypeEnvironment::instantiate(TypeVariable _variable, Type _type) vector<TypeEnvironment::UnificationFailure> TypeEnvironment::instantiate(TypeVariable _variable, Type _type)
@ -279,7 +357,7 @@ vector<TypeEnvironment::UnificationFailure> TypeEnvironment::instantiate(TypeVar
Sort sort = m_typeSystem.sort(_type); Sort sort = m_typeSystem.sort(_type);
if (!(_variable.sort() < sort)) if (!(_variable.sort() < sort))
{ {
return {UnificationFailure{SortMismatch{_type, _variable.sort()}}}; return {UnificationFailure{SortMismatch{_type, _variable.sort() - sort}}};
} }
solAssert(m_typeVariables.emplace(_variable.index(), _type).second); solAssert(m_typeVariables.emplace(_variable.index(), _type).second);
return {}; return {};
@ -346,10 +424,10 @@ Sort TypeSystem::sort(Type _type) const
void TypeSystem::declareTypeConstructor(TypeExpression::Constructor _typeConstructor, std::string _name, size_t _arguments) void TypeSystem::declareTypeConstructor(TypeExpression::Constructor _typeConstructor, std::string _name, size_t _arguments)
{ {
Sort baseSort{{TypeClass{BuiltinClass::Type}}};
bool newlyInserted = m_typeConstructors.emplace(std::make_pair(_typeConstructor, TypeConstructorInfo{ bool newlyInserted = m_typeConstructors.emplace(std::make_pair(_typeConstructor, TypeConstructorInfo{
_name, _name,
_arguments, {Arity{vector<Sort>{_arguments, baseSort}, TypeClass{BuiltinClass::Type}}}
{}
})).second; })).second;
// TODO: proper error handling. // TODO: proper error handling.
solAssert(newlyInserted, "Type constructor already declared."); solAssert(newlyInserted, "Type constructor already declared.");
@ -359,7 +437,7 @@ experimental::Type TypeSystem::type(TypeExpression::Constructor _constructor, st
{ {
// TODO: proper error handling // TODO: proper error handling
auto const& info = m_typeConstructors.at(_constructor); auto const& info = m_typeConstructors.at(_constructor);
solAssert(info.arguments == _arguments.size(), "Invalid arity."); solAssert(info.arguments() == _arguments.size(), "Invalid arity.");
return TypeExpression{_constructor, _arguments}; return TypeExpression{_constructor, _arguments};
} }
@ -401,7 +479,7 @@ void TypeSystem::instantiateClass(TypeExpression::Constructor _typeConstructor,
{ {
// TODO: proper error handling // TODO: proper error handling
auto& typeConstructorInfo = m_typeConstructors.at(_typeConstructor); auto& typeConstructorInfo = m_typeConstructors.at(_typeConstructor);
solAssert(_arity.argumentSorts.size() == typeConstructorInfo.arguments, "Invalid arity."); solAssert(_arity.argumentSorts.size() == typeConstructorInfo.arguments(), "Invalid arity.");
typeConstructorInfo.arities.emplace_back(_arity); typeConstructorInfo.arities.emplace_back(_arity);
} }
@ -499,6 +577,19 @@ vector<experimental::Type> TypeSystemHelpers::typeVars(Type _type) const
} }
experimental::Type TypeSystemHelpers::kindType(Type _type) const
{
return typeSystem.type(BuiltinType::Type, {_type});
}
experimental::Type TypeSystemHelpers::destKindType(Type _type) const
{
auto [constructor, arguments] = destTypeExpression(_type);
solAssert(constructor == TypeExpression::Constructor{BuiltinType::Type});
solAssert(arguments.size() == 1);
return arguments.front();
}
std::string TypeSystemHelpers::sortToString(Sort _sort) const std::string TypeSystemHelpers::sortToString(Sort _sort) const
{ {
switch (_sort.classes.size()) switch (_sort.classes.size())
@ -506,14 +597,14 @@ std::string TypeSystemHelpers::sortToString(Sort _sort) const
case 0: case 0:
return "()"; return "()";
case 1: case 1:
return _sort.classes.begin()->declaration->name(); return _sort.classes.begin()->toString();
default: default:
{ {
std::stringstream stream; std::stringstream stream;
stream << "("; stream << "(";
for (auto typeClass: _sort.classes | ranges::views::drop_last(1)) for (auto typeClass: _sort.classes | ranges::views::drop_last(1))
stream << typeClass.declaration->name() << ", "; stream << typeClass.toString() << ", ";
stream << _sort.classes.rbegin()->declaration->name() << ")"; stream << _sort.classes.rbegin()->toString() << ")";
return stream.str(); return stream.str();
} }
} }

View File

@ -27,6 +27,7 @@
namespace solidity::frontend namespace solidity::frontend
{ {
class Declaration; class Declaration;
class TypeClassDefinition;
} }
namespace solidity::frontend::experimental namespace solidity::frontend::experimental
@ -44,6 +45,7 @@ std::string canonicalTypeName(Type _type);
enum class BuiltinType enum class BuiltinType
{ {
Type,
Void, Void,
Function, Function,
Unit, Unit,
@ -69,12 +71,19 @@ struct TypeExpression
} }
}; };
enum class BuiltinClass
{
Type,
Kind
};
struct TypeClass struct TypeClass
{ {
Declaration const* declaration = nullptr; std::variant<BuiltinClass, TypeClassDefinition const*> declaration;
std::string toString() const;
bool operator<(TypeClass const& _rhs) const; bool operator<(TypeClass const& _rhs) const;
bool operator==(TypeClass const& _rhs) const { return declaration == _rhs.declaration; } bool operator==(TypeClass const& _rhs) const;
bool operator!=(TypeClass const& _rhs) const { return declaration != _rhs.declaration; } bool operator!=(TypeClass const& _rhs) const { return !operator==(_rhs); }
}; };
struct Sort struct Sort
@ -84,6 +93,7 @@ struct Sort
bool operator!=(Sort const& _rhs) const { return !operator==(_rhs); } bool operator!=(Sort const& _rhs) const { return !operator==(_rhs); }
bool operator<(Sort const& _rhs) const; bool operator<(Sort const& _rhs) const;
Sort operator+(Sort const& _rhs) const; Sort operator+(Sort const& _rhs) const;
Sort operator-(Sort const& _rhs) const;
}; };
struct Arity struct Arity
@ -145,7 +155,7 @@ private:
class TypeSystem class TypeSystem
{ {
public: public:
TypeSystem() {} TypeSystem();
TypeSystem(TypeSystem const&) = delete; TypeSystem(TypeSystem const&) = delete;
TypeSystem const& operator=(TypeSystem const&) = delete; TypeSystem const& operator=(TypeSystem const&) = delete;
Type type(TypeExpression::Constructor _typeConstructor, std::vector<Type> _arguments) const; Type type(TypeExpression::Constructor _typeConstructor, std::vector<Type> _arguments) const;
@ -158,22 +168,29 @@ public:
size_t constructorArguments(TypeExpression::Constructor _typeConstructor) const size_t constructorArguments(TypeExpression::Constructor _typeConstructor) const
{ {
// TODO: error handling // TODO: error handling
return m_typeConstructors.at(_typeConstructor).arguments; return m_typeConstructors.at(_typeConstructor).arguments();
} }
void instantiateClass(TypeExpression::Constructor _typeConstructor, Arity _arity); void instantiateClass(TypeExpression::Constructor _typeConstructor, Arity _arity);
Type freshTypeVariable(bool _generic, Sort const& _sort); Type freshTypeVariable(bool _generic, Sort _sort);
Type freshKindVariable(bool _generic, Sort _sort);
TypeEnvironment const& env() const { return m_globalTypeEnvironment; } TypeEnvironment const& env() const { return m_globalTypeEnvironment; }
TypeEnvironment& env() { return m_globalTypeEnvironment; } TypeEnvironment& env() { return m_globalTypeEnvironment; }
Sort sort(Type _type) const; Sort sort(Type _type) const;
Type freshVariable(bool _generic, Sort _sort);
private: private:
size_t m_numTypeVariables = 0; size_t m_numTypeVariables = 0;
struct TypeConstructorInfo struct TypeConstructorInfo
{ {
std::string name; std::string name;
size_t arguments;
std::vector<Arity> arities; std::vector<Arity> arities;
size_t arguments() const
{
solAssert(!arities.empty());
return arities.front().argumentSorts.size();
}
}; };
std::map<TypeExpression::Constructor, TypeConstructorInfo> m_typeConstructors; std::map<TypeExpression::Constructor, TypeConstructorInfo> m_typeConstructors;
TypeEnvironment m_globalTypeEnvironment{*this}; TypeEnvironment m_globalTypeEnvironment{*this};
@ -189,6 +206,8 @@ struct TypeSystemHelpers
std::tuple<Type, Type> destFunctionType(Type _functionType) const; std::tuple<Type, Type> destFunctionType(Type _functionType) const;
std::vector<Type> typeVars(Type _type) const; std::vector<Type> typeVars(Type _type) const;
std::string sortToString(Sort _sort) const; std::string sortToString(Sort _sort) const;
Type kindType(Type _type) const;
Type destKindType(Type _type) const;
}; };
} }