This commit is contained in:
Daniel Kirchner 2023-06-24 07:47:33 +02:00
parent 0d1679cee0
commit 4a84818669
6 changed files with 150 additions and 29 deletions

View File

@ -275,6 +275,9 @@ namespace solidity::langutil
K(Word, "word", 0) \ K(Word, "word", 0) \
K(Integer, "integer", 0) \ K(Integer, "integer", 0) \
K(Void, "void", 0) \ K(Void, "void", 0) \
K(Pair, "pair", 0) \
K(Fun, "fun", 0) \
K(Unit, "unit", 0) \
K(StaticAssert, "static_assert", 0) \ K(StaticAssert, "static_assert", 0) \
T(ExperimentalEnd, nullptr, 0) /* used as experimental enum end marker */ \ T(ExperimentalEnd, nullptr, 0) /* used as experimental enum end marker */ \
\ \
@ -302,7 +305,9 @@ namespace TokenTraits
// Predicates // Predicates
constexpr bool isElementaryTypeName(Token tok) constexpr bool isElementaryTypeName(Token tok)
{ {
return (Token::Int <= tok && tok < Token::TypesEnd) || tok == Token::Word || tok == Token::Void || tok == Token::Integer; return (Token::Int <= tok && tok < Token::TypesEnd) ||
tok == Token::Word || tok == Token::Void || tok == Token::Integer ||
tok == Token::Pair || tok == Token::Unit || tok == Token::Fun;
} }
constexpr bool isAssignmentOp(Token tok) { return Token::Assign <= tok && tok <= Token::AssignMod; } constexpr bool isAssignmentOp(Token tok) { return Token::Assign <= tok && tok <= Token::AssignMod; }
constexpr bool isBinaryOp(Token op) { return Token::Comma <= op && op <= Token::Exp; } constexpr bool isBinaryOp(Token op) { return Token::Comma <= op && op <= Token::Exp; }

View File

@ -40,6 +40,7 @@ m_typeSystem(_analysis.typeSystem())
m_voidType = m_typeSystem.builtinType(BuiltinType::Void, {}); m_voidType = m_typeSystem.builtinType(BuiltinType::Void, {});
m_wordType = m_typeSystem.builtinType(BuiltinType::Word, {}); m_wordType = m_typeSystem.builtinType(BuiltinType::Word, {});
m_integerType = m_typeSystem.builtinType(BuiltinType::Integer, {}); m_integerType = m_typeSystem.builtinType(BuiltinType::Integer, {});
m_unitType = m_typeSystem.builtinType(BuiltinType::Unit, {});
m_env = &m_typeSystem.env(); m_env = &m_typeSystem.env();
} }
@ -93,7 +94,7 @@ void TypeInference::endVisit(Return const& _return)
auto& returnExpressionAnnotation = annotation(*_return.expression()); auto& returnExpressionAnnotation = annotation(*_return.expression());
solAssert(returnExpressionAnnotation.type); solAssert(returnExpressionAnnotation.type);
Type functionReturnType = get<1>(TypeSystemHelpers{m_typeSystem}.destFunctionType(*m_currentFunctionType)); Type functionReturnType = get<1>(TypeSystemHelpers{m_typeSystem}.destFunctionType(*m_currentFunctionType));
unify(functionReturnType, *returnExpressionAnnotation.type); unify(functionReturnType, *returnExpressionAnnotation.type, _return.location());
} }
} }
@ -133,7 +134,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
subNode->accept(*this); subNode->accept(*this);
solAssert(typeVariableAnnotation.type); solAssert(typeVariableAnnotation.type);
unify(*typeVariableAnnotation.type, m_typeSystem.freshTypeVariable(false, Sort{{TypeClass{&_typeClassDefinition}}})); unify(*typeVariableAnnotation.type, m_typeSystem.freshTypeVariable(false, Sort{{TypeClass{&_typeClassDefinition}}}), _typeClassDefinition.location());
return false; return false;
} }
@ -180,10 +181,23 @@ experimental::Type TypeInference::fromTypeName(TypeName const& _typeName)
} }
*/ */
void TypeInference::unify(Type _a, Type _b) void TypeInference::unify(Type _a, Type _b, langutil::SourceLocation _location)
{ {
for (auto failure: m_env->unify(_a, _b)) for (auto failure: m_env->unify(_a, _b))
m_errorReporter.typeError(0000_error, {}, fmt::format("Cannot unify {} and {}.", m_env->typeToString(_a), m_env->typeToString(_b))); std::visit(util::GenericVisitor{
[&](TypeEnvironment::TypeMismatch _typeMismatch) {
m_errorReporter.typeError(0000_error, _location, fmt::format("Cannot unify {} and {}.", m_env->typeToString(_typeMismatch.a), m_env->typeToString(_typeMismatch.b)));
},
[&](TypeEnvironment::SortMismatch _sortMismatch) {
m_errorReporter.typeError(0000_error, _location, fmt::format(
"Cannot unify {} and {}: {} does not have sort {}",
m_env->typeToString(_a),
m_env->typeToString(_b),
m_env->typeToString(_sortMismatch.type),
TypeSystemHelpers{m_typeSystem}.sortToString(_sortMismatch.sort)
));
}
}, failure);
} }
bool TypeInference::visit(InlineAssembly const& _inlineAssembly) bool TypeInference::visit(InlineAssembly const& _inlineAssembly)
@ -214,7 +228,7 @@ bool TypeInference::visit(InlineAssembly const& _inlineAssembly)
auto& declarationAnnotation = annotation(*declaration); auto& declarationAnnotation = annotation(*declaration);
solAssert(declarationAnnotation.type); solAssert(declarationAnnotation.type);
unify(*declarationAnnotation.type, m_wordType); unify(*declarationAnnotation.type, m_wordType, originLocationOf(_identifier));
identifierInfo.valueSize = 1; identifierInfo.valueSize = 1;
return true; return true;
}; };
@ -251,6 +265,33 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression)
case Token::Integer: case Token::Integer:
expressionAnnotation.type = m_integerType; expressionAnnotation.type = m_integerType;
break; break;
case Token::Unit:
expressionAnnotation.type = m_unitType;
break;
case Token::Pair:
{
auto leftType = m_typeSystem.freshTypeVariable(true, {});
auto rightType = m_typeSystem.freshTypeVariable(true, {});
TypeSystemHelpers helper{m_typeSystem};
expressionAnnotation.type =
helper.functionType(
helper.tupleType({leftType, rightType}),
m_typeSystem.type(BuiltinType::Pair, {leftType, rightType})
);
break;
}
case Token::Fun:
{
auto argType = m_typeSystem.freshTypeVariable(true, {});
auto resultType = m_typeSystem.freshTypeVariable(true, {});
TypeSystemHelpers helper{m_typeSystem};
expressionAnnotation.type =
helper.functionType(
helper.tupleType({argType, resultType}),
m_typeSystem.type(BuiltinType::Function, {argType, resultType})
);
break;
}
default: default:
m_errorReporter.typeError(0000_error, _expression.location(), "Only elementary types are supported."); m_errorReporter.typeError(0000_error, _expression.location(), "Only elementary types are supported.");
expressionAnnotation.type = m_typeSystem.freshTypeVariable(false, {}); expressionAnnotation.type = m_typeSystem.freshTypeVariable(false, {});
@ -280,7 +321,7 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
} }
solAssert(leftAnnotation.type); solAssert(leftAnnotation.type);
solAssert(rightAnnotation.type); solAssert(rightAnnotation.type);
unify(*leftAnnotation.type, *rightAnnotation.type); unify(*leftAnnotation.type, *rightAnnotation.type, _binaryOperation.location());
operationAnnotation.type = leftAnnotation.type; operationAnnotation.type = leftAnnotation.type;
} }
else else
@ -339,7 +380,7 @@ void TypeInference::endVisit(Assignment const& _assignment)
solAssert(lhsAnnotation.type); solAssert(lhsAnnotation.type);
auto& rhsAnnotation = annotation(_assignment.rightHandSide()); auto& rhsAnnotation = annotation(_assignment.rightHandSide());
solAssert(rhsAnnotation.type); solAssert(rhsAnnotation.type);
unify(*lhsAnnotation.type, *rhsAnnotation.type); unify(*lhsAnnotation.type, *rhsAnnotation.type, _assignment.location());
assignmentAnnotation.type = m_env->resolve(*lhsAnnotation.type); assignmentAnnotation.type = m_env->resolve(*lhsAnnotation.type);
} }
@ -459,7 +500,7 @@ void TypeInference::endVisit(TupleExpression const& _tupleExpression)
{ {
Type type = m_typeSystem.freshTypeVariable(false, {}); Type type = m_typeSystem.freshTypeVariable(false, {});
for (auto componentType: componentTypes) for (auto componentType: componentTypes)
unify(type, componentType); unify(type, componentType, _tupleExpression.location());
expressionAnnotation.type = type; expressionAnnotation.type = type;
break; break;
} }
@ -480,13 +521,18 @@ bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation)
map<string, Type> functionTypes; map<string, Type> functionTypes;
Type typeVar = m_typeSystem.freshTypeVariable(false, {});
for (auto subNode: typeClass->subNodes()) for (auto subNode: typeClass->subNodes())
{ {
auto const* functionDefinition = dynamic_cast<FunctionDefinition const*>(subNode.get()); auto const* functionDefinition = dynamic_cast<FunctionDefinition const*>(subNode.get());
solAssert(functionDefinition); solAssert(functionDefinition);
auto functionType = annotation(*functionDefinition).type; auto functionType = annotation(*functionDefinition).type;
solAssert(functionType); solAssert(functionType);
functionTypes[functionDefinition->name()] = *functionType; functionTypes[functionDefinition->name()] = m_env->fresh(*functionType, true);
auto typeVars = TypeSystemHelpers{m_typeSystem}.typeVars(functionTypes[functionDefinition->name()]);
solAssert(typeVars.size() == 1);
unify(typeVars.front(), typeVar);
} }
for (auto subNode: _typeClassInstantiation.subNodes()) for (auto subNode: _typeClassInstantiation.subNodes())
{ {
@ -508,7 +554,7 @@ bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation)
auto functionType = annotation(*functionDefinition).type; auto functionType = annotation(*functionDefinition).type;
solAssert(functionType); solAssert(functionType);
// TODO: require exact match? // TODO: require exact match?
unify(*functionType, m_env->fresh(*expectedFunctionType, true)); unify(*functionType, *expectedFunctionType, functionDefinition->location());
functionTypes.erase(functionDefinition->name()); functionTypes.erase(functionDefinition->name());
} }
} }
@ -618,7 +664,7 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
{ {
Type argTuple = TypeSystemHelpers{m_typeSystem}.tupleType(argTypes); Type argTuple = TypeSystemHelpers{m_typeSystem}.tupleType(argTypes);
Type genericFunctionType = TypeSystemHelpers{m_typeSystem}.functionType(argTuple, m_typeSystem.freshTypeVariable(false, {})); Type genericFunctionType = TypeSystemHelpers{m_typeSystem}.functionType(argTuple, m_typeSystem.freshTypeVariable(false, {}));
unify(genericFunctionType, functionType); unify(genericFunctionType, functionType, _functionCall.location());
functionCallAnnotation.type = m_env->resolve(std::get<1>(TypeSystemHelpers{m_typeSystem}.destFunctionType(m_env->resolve(genericFunctionType)))); functionCallAnnotation.type = m_env->resolve(std::get<1>(TypeSystemHelpers{m_typeSystem}.destFunctionType(m_env->resolve(genericFunctionType))));
break; break;

View File

@ -81,11 +81,12 @@ private:
Type m_voidType; Type m_voidType;
Type m_wordType; Type m_wordType;
Type m_integerType; Type m_integerType;
Type m_unitType;
std::optional<Type> m_currentFunctionType; std::optional<Type> m_currentFunctionType;
Annotation& annotation(ASTNode const& _node); Annotation& annotation(ASTNode const& _node);
void unify(Type _a, Type _b); void unify(Type _a, Type _b, langutil::SourceLocation _location = {});
enum class ExpressionContext enum class ExpressionContext
{ {
Term, Term,

View File

@ -81,6 +81,12 @@ bool TypeRegistration::visit(TypeClassInstantiation const& _typeClassInstantiati
return BuiltinType::Void; return BuiltinType::Void;
case Token::Integer: case Token::Integer:
return BuiltinType::Integer; return BuiltinType::Integer;
case Token::Pair:
return BuiltinType::Pair;
case Token::Function:
return BuiltinType::Function;
case Token::Unit:
return BuiltinType::Function;
default: default:
m_errorReporter.typeError(0000_error, typeName.location(), "Only elementary types are supported."); m_errorReporter.typeError(0000_error, typeName.location(), "Only elementary types are supported.");
return BuiltinType::Void; return BuiltinType::Void;
@ -117,7 +123,7 @@ bool TypeRegistration::visit(TypeClassInstantiation const& _typeClassInstantiati
if (!dynamic_cast<TypeClassDefinition const*>(referencedDeclaration)) if (!dynamic_cast<TypeClassDefinition const*>(referencedDeclaration))
m_errorReporter.fatalTypeError(0000_error, argumentSort->location(), "Argument sort has to be a type class."); m_errorReporter.fatalTypeError(0000_error, argumentSort->location(), "Argument sort has to be a type class.");
// TODO: multi arities // TODO: multi arities
arity._argumentSorts.emplace_back(Sort{{TypeClass{referencedDeclaration}}}); arity.argumentSorts.emplace_back(Sort{{TypeClass{referencedDeclaration}}});
} }
else else
{ {

View File

@ -100,8 +100,11 @@ std::string experimental::canonicalTypeName(Type _type)
std::visit(util::GenericVisitor{ std::visit(util::GenericVisitor{
[&](Declaration const* _declaration) { [&](Declaration const* _declaration) {
printTypeArguments(); printTypeArguments();
// TODO: canonical name if (auto const* typeDeclarationAnnotation = dynamic_cast<TypeDeclarationAnnotation const*>(&_declaration->annotation()))
stream << _declaration->name(); stream << *typeDeclarationAnnotation->canonicalName;
else
// TODO: canonical name
stream << _declaration->name();
}, },
[&](BuiltinType _builtinType) { [&](BuiltinType _builtinType) {
printTypeArguments(); printTypeArguments();
@ -212,7 +215,7 @@ vector<TypeEnvironment::UnificationFailure> TypeEnvironment::unify(Type _a, Type
{ {
vector<UnificationFailure> failures; vector<UnificationFailure> failures;
auto unificationFailure = [&]() { auto unificationFailure = [&]() {
failures.emplace_back(UnificationFailure{_a, _b}); failures.emplace_back(UnificationFailure{TypeMismatch{_a, _b}});
}; };
_a = resolve(_a); _a = resolve(_a);
_b = resolve(_b); _b = resolve(_b);
@ -226,22 +229,22 @@ vector<TypeEnvironment::UnificationFailure> TypeEnvironment::unify(Type _a, Type
else else
{ {
if (_left.sort() < _right.sort()) if (_left.sort() < _right.sort())
instantiate(_left, _right); failures += instantiate(_left, _right);
else if (_right.sort() < _left.sort()) else if (_right.sort() < _left.sort())
instantiate(_right, _left); failures += instantiate(_right, _left);
else else
{ {
Type newVar = m_typeSystem.freshTypeVariable(_left.generic() && _right.generic(), _left.sort() + _right.sort()); Type newVar = m_typeSystem.freshTypeVariable(_left.generic() && _right.generic(), _left.sort() + _right.sort());
instantiate(_left, newVar); failures += instantiate(_left, newVar);
instantiate(_right, newVar); failures += instantiate(_right, newVar);
} }
} }
}, },
[&](TypeVariable _var, auto) { [&](TypeVariable _var, auto) {
instantiate(_var, _b); failures += instantiate(_var, _b);
}, },
[&](auto, TypeVariable _var) { [&](auto, TypeVariable _var) {
instantiate(_var, _a); failures += instantiate(_var, _a);
}, },
[&](TypeExpression _left, TypeExpression _right) { [&](TypeExpression _left, TypeExpression _right) {
if(_left.constructor != _right.constructor) if(_left.constructor != _right.constructor)
@ -271,9 +274,15 @@ experimental::Type TypeSystem::freshTypeVariable(bool _generic, Sort const& _sor
return TypeVariable(index, _sort, _generic); return TypeVariable(index, _sort, _generic);
} }
void TypeEnvironment::instantiate(TypeVariable _variable, Type _type) vector<TypeEnvironment::UnificationFailure> TypeEnvironment::instantiate(TypeVariable _variable, Type _type)
{ {
Sort sort = m_typeSystem.sort(_type);
if (!(_variable.sort() < sort))
{
return {UnificationFailure{SortMismatch{_type, _variable.sort()}}};
}
solAssert(m_typeVariables.emplace(_variable.index(), _type).second); solAssert(m_typeVariables.emplace(_variable.index(), _type).second);
return {};
} }
experimental::Type TypeEnvironment::resolve(Type _type) const experimental::Type TypeEnvironment::resolve(Type _type) const
@ -304,6 +313,36 @@ experimental::Type TypeEnvironment::resolveRecursive(Type _type) const
}, resolve(_type)); }, resolve(_type));
} }
Sort TypeSystem::sort(Type _type) const
{
return std::visit(util::GenericVisitor{
[&](TypeExpression const& _expression) -> Sort
{
auto const& constructorInfo = m_typeConstructors.at(_expression.constructor);
auto argumentSorts = _expression.arguments | ranges::views::transform([&](Type _type) {
return sort(_type);
}) | ranges::to<vector<Sort>>;
Sort sort;
for (auto const& arity: constructorInfo.arities)
{
solAssert(arity.argumentSorts.size() == argumentSorts.size());
bool hasArity = true;
for (auto&& [argumentSort, arityArgumentSort]: ranges::zip_view(argumentSorts, arity.argumentSorts))
{
if (!(argumentSort < arityArgumentSort))
{
hasArity = false;
break;
}
}
if (hasArity)
sort.classes.insert(arity.typeClass);
}
return sort;
},
[](TypeVariable const& _variable) -> Sort { return _variable.sort(); },
}, _type);
}
void TypeSystem::declareTypeConstructor(TypeExpression::Constructor _typeConstructor, std::string _name, size_t _arguments) void TypeSystem::declareTypeConstructor(TypeExpression::Constructor _typeConstructor, std::string _name, size_t _arguments)
{ {
@ -367,7 +406,7 @@ void TypeSystem::instantiateClass(TypeExpression::Constructor _typeConstructor,
{ {
// TODO: proper error handling // TODO: proper error handling
auto& typeConstructorInfo = m_typeConstructors.at(_typeConstructor); auto& typeConstructorInfo = m_typeConstructors.at(_typeConstructor);
solAssert(_arity._argumentSorts.size() == typeConstructorInfo.arguments, "Invalid arity."); solAssert(_arity.argumentSorts.size() == typeConstructorInfo.arguments, "Invalid arity.");
typeConstructorInfo.arities.emplace_back(_arity); typeConstructorInfo.arities.emplace_back(_arity);
} }
@ -463,4 +502,24 @@ vector<experimental::Type> TypeSystemHelpers::typeVars(Type _type) const
typeVarsImpl(_type, typeVarsImpl); typeVarsImpl(_type, typeVarsImpl);
return typeVars; return typeVars;
}
std::string TypeSystemHelpers::sortToString(Sort _sort) const
{
switch (_sort.classes.size())
{
case 0:
return "()";
case 1:
return _sort.classes.begin()->declaration->name();
default:
{
std::stringstream stream;
stream << "(";
for (auto typeClass: _sort.classes | ranges::views::drop_last(1))
stream << typeClass.declaration->name() << ", ";
stream << _sort.classes.rbegin()->declaration->name() << ")";
return stream.str();
}
}
} }

View File

@ -88,8 +88,8 @@ struct Sort
struct Arity struct Arity
{ {
std::vector<Sort> _argumentSorts; std::vector<Sort> argumentSorts;
TypeClass _class; TypeClass typeClass;
}; };
struct TypeVariable struct TypeVariable
@ -130,12 +130,14 @@ public:
Type resolve(Type _type) const; Type resolve(Type _type) const;
Type resolveRecursive(Type _type) const; Type resolveRecursive(Type _type) const;
Type fresh(Type _type, bool _generalize); Type fresh(Type _type, bool _generalize);
struct UnificationFailure { Type a; Type b; }; struct TypeMismatch { Type a; Type b; };
struct SortMismatch { Type type; Sort sort; };
using UnificationFailure = std::variant<TypeMismatch, SortMismatch>;
[[nodiscard]] std::vector<UnificationFailure> unify(Type _a, Type _b); [[nodiscard]] std::vector<UnificationFailure> unify(Type _a, Type _b);
std::string typeToString(Type const& _type) const; std::string typeToString(Type const& _type) const;
private: private:
TypeEnvironment(TypeEnvironment&& _env): m_typeSystem(_env.m_typeSystem), m_typeVariables(std::move(_env.m_typeVariables)) {} TypeEnvironment(TypeEnvironment&& _env): m_typeSystem(_env.m_typeSystem), m_typeVariables(std::move(_env.m_typeVariables)) {}
void instantiate(TypeVariable _variable, Type _type); [[nodiscard]] std::vector<TypeEnvironment::UnificationFailure> instantiate(TypeVariable _variable, Type _type);
TypeSystem& m_typeSystem; TypeSystem& m_typeSystem;
std::map<size_t, Type> m_typeVariables; std::map<size_t, Type> m_typeVariables;
}; };
@ -165,6 +167,7 @@ public:
TypeEnvironment const& env() const { return m_globalTypeEnvironment; } TypeEnvironment const& env() const { return m_globalTypeEnvironment; }
TypeEnvironment& env() { return m_globalTypeEnvironment; } TypeEnvironment& env() { return m_globalTypeEnvironment; }
Sort sort(Type _type) const;
private: private:
size_t m_numTypeVariables = 0; size_t m_numTypeVariables = 0;
struct TypeConstructorInfo struct TypeConstructorInfo
@ -186,6 +189,7 @@ struct TypeSystemHelpers
Type functionType(Type _argType, Type _resultType) const; Type functionType(Type _argType, Type _resultType) const;
std::tuple<Type, Type> destFunctionType(Type _functionType) const; std::tuple<Type, Type> destFunctionType(Type _functionType) const;
std::vector<Type> typeVars(Type _type) const; std::vector<Type> typeVars(Type _type) const;
std::string sortToString(Sort _sort) const;
}; };
} }