This commit is contained in:
Daniel Kirchner 2023-06-24 07:47:33 +02:00
parent 0d1679cee0
commit 4a84818669
6 changed files with 150 additions and 29 deletions

View File

@ -275,6 +275,9 @@ namespace solidity::langutil
K(Word, "word", 0) \
K(Integer, "integer", 0) \
K(Void, "void", 0) \
K(Pair, "pair", 0) \
K(Fun, "fun", 0) \
K(Unit, "unit", 0) \
K(StaticAssert, "static_assert", 0) \
T(ExperimentalEnd, nullptr, 0) /* used as experimental enum end marker */ \
\
@ -302,7 +305,9 @@ namespace TokenTraits
// Predicates
constexpr bool isElementaryTypeName(Token tok)
{
return (Token::Int <= tok && tok < Token::TypesEnd) || tok == Token::Word || tok == Token::Void || tok == Token::Integer;
return (Token::Int <= tok && tok < Token::TypesEnd) ||
tok == Token::Word || tok == Token::Void || tok == Token::Integer ||
tok == Token::Pair || tok == Token::Unit || tok == Token::Fun;
}
constexpr bool isAssignmentOp(Token tok) { return Token::Assign <= tok && tok <= Token::AssignMod; }
constexpr bool isBinaryOp(Token op) { return Token::Comma <= op && op <= Token::Exp; }

View File

@ -40,6 +40,7 @@ m_typeSystem(_analysis.typeSystem())
m_voidType = m_typeSystem.builtinType(BuiltinType::Void, {});
m_wordType = m_typeSystem.builtinType(BuiltinType::Word, {});
m_integerType = m_typeSystem.builtinType(BuiltinType::Integer, {});
m_unitType = m_typeSystem.builtinType(BuiltinType::Unit, {});
m_env = &m_typeSystem.env();
}
@ -93,7 +94,7 @@ void TypeInference::endVisit(Return const& _return)
auto& returnExpressionAnnotation = annotation(*_return.expression());
solAssert(returnExpressionAnnotation.type);
Type functionReturnType = get<1>(TypeSystemHelpers{m_typeSystem}.destFunctionType(*m_currentFunctionType));
unify(functionReturnType, *returnExpressionAnnotation.type);
unify(functionReturnType, *returnExpressionAnnotation.type, _return.location());
}
}
@ -133,7 +134,7 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
subNode->accept(*this);
solAssert(typeVariableAnnotation.type);
unify(*typeVariableAnnotation.type, m_typeSystem.freshTypeVariable(false, Sort{{TypeClass{&_typeClassDefinition}}}));
unify(*typeVariableAnnotation.type, m_typeSystem.freshTypeVariable(false, Sort{{TypeClass{&_typeClassDefinition}}}), _typeClassDefinition.location());
return false;
}
@ -180,10 +181,23 @@ experimental::Type TypeInference::fromTypeName(TypeName const& _typeName)
}
*/
void TypeInference::unify(Type _a, Type _b)
void TypeInference::unify(Type _a, Type _b, langutil::SourceLocation _location)
{
for (auto failure: m_env->unify(_a, _b))
m_errorReporter.typeError(0000_error, {}, fmt::format("Cannot unify {} and {}.", m_env->typeToString(_a), m_env->typeToString(_b)));
std::visit(util::GenericVisitor{
[&](TypeEnvironment::TypeMismatch _typeMismatch) {
m_errorReporter.typeError(0000_error, _location, fmt::format("Cannot unify {} and {}.", m_env->typeToString(_typeMismatch.a), m_env->typeToString(_typeMismatch.b)));
},
[&](TypeEnvironment::SortMismatch _sortMismatch) {
m_errorReporter.typeError(0000_error, _location, fmt::format(
"Cannot unify {} and {}: {} does not have sort {}",
m_env->typeToString(_a),
m_env->typeToString(_b),
m_env->typeToString(_sortMismatch.type),
TypeSystemHelpers{m_typeSystem}.sortToString(_sortMismatch.sort)
));
}
}, failure);
}
bool TypeInference::visit(InlineAssembly const& _inlineAssembly)
@ -214,7 +228,7 @@ bool TypeInference::visit(InlineAssembly const& _inlineAssembly)
auto& declarationAnnotation = annotation(*declaration);
solAssert(declarationAnnotation.type);
unify(*declarationAnnotation.type, m_wordType);
unify(*declarationAnnotation.type, m_wordType, originLocationOf(_identifier));
identifierInfo.valueSize = 1;
return true;
};
@ -251,6 +265,33 @@ bool TypeInference::visit(ElementaryTypeNameExpression const& _expression)
case Token::Integer:
expressionAnnotation.type = m_integerType;
break;
case Token::Unit:
expressionAnnotation.type = m_unitType;
break;
case Token::Pair:
{
auto leftType = m_typeSystem.freshTypeVariable(true, {});
auto rightType = m_typeSystem.freshTypeVariable(true, {});
TypeSystemHelpers helper{m_typeSystem};
expressionAnnotation.type =
helper.functionType(
helper.tupleType({leftType, rightType}),
m_typeSystem.type(BuiltinType::Pair, {leftType, rightType})
);
break;
}
case Token::Fun:
{
auto argType = m_typeSystem.freshTypeVariable(true, {});
auto resultType = m_typeSystem.freshTypeVariable(true, {});
TypeSystemHelpers helper{m_typeSystem};
expressionAnnotation.type =
helper.functionType(
helper.tupleType({argType, resultType}),
m_typeSystem.type(BuiltinType::Function, {argType, resultType})
);
break;
}
default:
m_errorReporter.typeError(0000_error, _expression.location(), "Only elementary types are supported.");
expressionAnnotation.type = m_typeSystem.freshTypeVariable(false, {});
@ -280,7 +321,7 @@ bool TypeInference::visit(BinaryOperation const& _binaryOperation)
}
solAssert(leftAnnotation.type);
solAssert(rightAnnotation.type);
unify(*leftAnnotation.type, *rightAnnotation.type);
unify(*leftAnnotation.type, *rightAnnotation.type, _binaryOperation.location());
operationAnnotation.type = leftAnnotation.type;
}
else
@ -339,7 +380,7 @@ void TypeInference::endVisit(Assignment const& _assignment)
solAssert(lhsAnnotation.type);
auto& rhsAnnotation = annotation(_assignment.rightHandSide());
solAssert(rhsAnnotation.type);
unify(*lhsAnnotation.type, *rhsAnnotation.type);
unify(*lhsAnnotation.type, *rhsAnnotation.type, _assignment.location());
assignmentAnnotation.type = m_env->resolve(*lhsAnnotation.type);
}
@ -459,7 +500,7 @@ void TypeInference::endVisit(TupleExpression const& _tupleExpression)
{
Type type = m_typeSystem.freshTypeVariable(false, {});
for (auto componentType: componentTypes)
unify(type, componentType);
unify(type, componentType, _tupleExpression.location());
expressionAnnotation.type = type;
break;
}
@ -480,13 +521,18 @@ bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation)
map<string, Type> functionTypes;
Type typeVar = m_typeSystem.freshTypeVariable(false, {});
for (auto subNode: typeClass->subNodes())
{
auto const* functionDefinition = dynamic_cast<FunctionDefinition const*>(subNode.get());
solAssert(functionDefinition);
auto functionType = annotation(*functionDefinition).type;
solAssert(functionType);
functionTypes[functionDefinition->name()] = *functionType;
functionTypes[functionDefinition->name()] = m_env->fresh(*functionType, true);
auto typeVars = TypeSystemHelpers{m_typeSystem}.typeVars(functionTypes[functionDefinition->name()]);
solAssert(typeVars.size() == 1);
unify(typeVars.front(), typeVar);
}
for (auto subNode: _typeClassInstantiation.subNodes())
{
@ -508,7 +554,7 @@ bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation)
auto functionType = annotation(*functionDefinition).type;
solAssert(functionType);
// TODO: require exact match?
unify(*functionType, m_env->fresh(*expectedFunctionType, true));
unify(*functionType, *expectedFunctionType, functionDefinition->location());
functionTypes.erase(functionDefinition->name());
}
}
@ -618,7 +664,7 @@ void TypeInference::endVisit(FunctionCall const& _functionCall)
{
Type argTuple = TypeSystemHelpers{m_typeSystem}.tupleType(argTypes);
Type genericFunctionType = TypeSystemHelpers{m_typeSystem}.functionType(argTuple, m_typeSystem.freshTypeVariable(false, {}));
unify(genericFunctionType, functionType);
unify(genericFunctionType, functionType, _functionCall.location());
functionCallAnnotation.type = m_env->resolve(std::get<1>(TypeSystemHelpers{m_typeSystem}.destFunctionType(m_env->resolve(genericFunctionType))));
break;

View File

@ -81,11 +81,12 @@ private:
Type m_voidType;
Type m_wordType;
Type m_integerType;
Type m_unitType;
std::optional<Type> m_currentFunctionType;
Annotation& annotation(ASTNode const& _node);
void unify(Type _a, Type _b);
void unify(Type _a, Type _b, langutil::SourceLocation _location = {});
enum class ExpressionContext
{
Term,

View File

@ -81,6 +81,12 @@ bool TypeRegistration::visit(TypeClassInstantiation const& _typeClassInstantiati
return BuiltinType::Void;
case Token::Integer:
return BuiltinType::Integer;
case Token::Pair:
return BuiltinType::Pair;
case Token::Function:
return BuiltinType::Function;
case Token::Unit:
return BuiltinType::Function;
default:
m_errorReporter.typeError(0000_error, typeName.location(), "Only elementary types are supported.");
return BuiltinType::Void;
@ -117,7 +123,7 @@ bool TypeRegistration::visit(TypeClassInstantiation const& _typeClassInstantiati
if (!dynamic_cast<TypeClassDefinition const*>(referencedDeclaration))
m_errorReporter.fatalTypeError(0000_error, argumentSort->location(), "Argument sort has to be a type class.");
// TODO: multi arities
arity._argumentSorts.emplace_back(Sort{{TypeClass{referencedDeclaration}}});
arity.argumentSorts.emplace_back(Sort{{TypeClass{referencedDeclaration}}});
}
else
{

View File

@ -100,8 +100,11 @@ std::string experimental::canonicalTypeName(Type _type)
std::visit(util::GenericVisitor{
[&](Declaration const* _declaration) {
printTypeArguments();
// TODO: canonical name
stream << _declaration->name();
if (auto const* typeDeclarationAnnotation = dynamic_cast<TypeDeclarationAnnotation const*>(&_declaration->annotation()))
stream << *typeDeclarationAnnotation->canonicalName;
else
// TODO: canonical name
stream << _declaration->name();
},
[&](BuiltinType _builtinType) {
printTypeArguments();
@ -212,7 +215,7 @@ vector<TypeEnvironment::UnificationFailure> TypeEnvironment::unify(Type _a, Type
{
vector<UnificationFailure> failures;
auto unificationFailure = [&]() {
failures.emplace_back(UnificationFailure{_a, _b});
failures.emplace_back(UnificationFailure{TypeMismatch{_a, _b}});
};
_a = resolve(_a);
_b = resolve(_b);
@ -226,22 +229,22 @@ vector<TypeEnvironment::UnificationFailure> TypeEnvironment::unify(Type _a, Type
else
{
if (_left.sort() < _right.sort())
instantiate(_left, _right);
failures += instantiate(_left, _right);
else if (_right.sort() < _left.sort())
instantiate(_right, _left);
failures += instantiate(_right, _left);
else
{
Type newVar = m_typeSystem.freshTypeVariable(_left.generic() && _right.generic(), _left.sort() + _right.sort());
instantiate(_left, newVar);
instantiate(_right, newVar);
failures += instantiate(_left, newVar);
failures += instantiate(_right, newVar);
}
}
},
[&](TypeVariable _var, auto) {
instantiate(_var, _b);
failures += instantiate(_var, _b);
},
[&](auto, TypeVariable _var) {
instantiate(_var, _a);
failures += instantiate(_var, _a);
},
[&](TypeExpression _left, TypeExpression _right) {
if(_left.constructor != _right.constructor)
@ -271,9 +274,15 @@ experimental::Type TypeSystem::freshTypeVariable(bool _generic, Sort const& _sor
return TypeVariable(index, _sort, _generic);
}
void TypeEnvironment::instantiate(TypeVariable _variable, Type _type)
vector<TypeEnvironment::UnificationFailure> TypeEnvironment::instantiate(TypeVariable _variable, Type _type)
{
Sort sort = m_typeSystem.sort(_type);
if (!(_variable.sort() < sort))
{
return {UnificationFailure{SortMismatch{_type, _variable.sort()}}};
}
solAssert(m_typeVariables.emplace(_variable.index(), _type).second);
return {};
}
experimental::Type TypeEnvironment::resolve(Type _type) const
@ -304,6 +313,36 @@ experimental::Type TypeEnvironment::resolveRecursive(Type _type) const
}, resolve(_type));
}
Sort TypeSystem::sort(Type _type) const
{
return std::visit(util::GenericVisitor{
[&](TypeExpression const& _expression) -> Sort
{
auto const& constructorInfo = m_typeConstructors.at(_expression.constructor);
auto argumentSorts = _expression.arguments | ranges::views::transform([&](Type _type) {
return sort(_type);
}) | ranges::to<vector<Sort>>;
Sort sort;
for (auto const& arity: constructorInfo.arities)
{
solAssert(arity.argumentSorts.size() == argumentSorts.size());
bool hasArity = true;
for (auto&& [argumentSort, arityArgumentSort]: ranges::zip_view(argumentSorts, arity.argumentSorts))
{
if (!(argumentSort < arityArgumentSort))
{
hasArity = false;
break;
}
}
if (hasArity)
sort.classes.insert(arity.typeClass);
}
return sort;
},
[](TypeVariable const& _variable) -> Sort { return _variable.sort(); },
}, _type);
}
void TypeSystem::declareTypeConstructor(TypeExpression::Constructor _typeConstructor, std::string _name, size_t _arguments)
{
@ -367,7 +406,7 @@ void TypeSystem::instantiateClass(TypeExpression::Constructor _typeConstructor,
{
// TODO: proper error handling
auto& typeConstructorInfo = m_typeConstructors.at(_typeConstructor);
solAssert(_arity._argumentSorts.size() == typeConstructorInfo.arguments, "Invalid arity.");
solAssert(_arity.argumentSorts.size() == typeConstructorInfo.arguments, "Invalid arity.");
typeConstructorInfo.arities.emplace_back(_arity);
}
@ -464,3 +503,23 @@ vector<experimental::Type> TypeSystemHelpers::typeVars(Type _type) const
return typeVars;
}
std::string TypeSystemHelpers::sortToString(Sort _sort) const
{
switch (_sort.classes.size())
{
case 0:
return "()";
case 1:
return _sort.classes.begin()->declaration->name();
default:
{
std::stringstream stream;
stream << "(";
for (auto typeClass: _sort.classes | ranges::views::drop_last(1))
stream << typeClass.declaration->name() << ", ";
stream << _sort.classes.rbegin()->declaration->name() << ")";
return stream.str();
}
}
}

View File

@ -88,8 +88,8 @@ struct Sort
struct Arity
{
std::vector<Sort> _argumentSorts;
TypeClass _class;
std::vector<Sort> argumentSorts;
TypeClass typeClass;
};
struct TypeVariable
@ -130,12 +130,14 @@ public:
Type resolve(Type _type) const;
Type resolveRecursive(Type _type) const;
Type fresh(Type _type, bool _generalize);
struct UnificationFailure { Type a; Type b; };
struct TypeMismatch { Type a; Type b; };
struct SortMismatch { Type type; Sort sort; };
using UnificationFailure = std::variant<TypeMismatch, SortMismatch>;
[[nodiscard]] std::vector<UnificationFailure> unify(Type _a, Type _b);
std::string typeToString(Type const& _type) const;
private:
TypeEnvironment(TypeEnvironment&& _env): m_typeSystem(_env.m_typeSystem), m_typeVariables(std::move(_env.m_typeVariables)) {}
void instantiate(TypeVariable _variable, Type _type);
[[nodiscard]] std::vector<TypeEnvironment::UnificationFailure> instantiate(TypeVariable _variable, Type _type);
TypeSystem& m_typeSystem;
std::map<size_t, Type> m_typeVariables;
};
@ -165,6 +167,7 @@ public:
TypeEnvironment const& env() const { return m_globalTypeEnvironment; }
TypeEnvironment& env() { return m_globalTypeEnvironment; }
Sort sort(Type _type) const;
private:
size_t m_numTypeVariables = 0;
struct TypeConstructorInfo
@ -186,6 +189,7 @@ struct TypeSystemHelpers
Type functionType(Type _argType, Type _resultType) const;
std::tuple<Type, Type> destFunctionType(Type _functionType) const;
std::vector<Type> typeVars(Type _type) const;
std::string sortToString(Sort _sort) const;
};
}