This commit is contained in:
Daniel Kirchner 2023-06-29 02:42:10 +02:00
parent 5adc255b3c
commit 987278385e
7 changed files with 519 additions and 661 deletions

View File

@ -39,88 +39,130 @@ ASTTransform::ASTTransform(Analysis& _analysis): m_analysis(_analysis), m_errorR
{ {
} }
bool ASTTransform::visit(legacy::ContractDefinition const& _contractDefinition) bool ASTTransform::visit(legacy::TypeDefinition const& _typeDefinition)
{ {
SetNode setNode(*this, _contractDefinition); SetNode setNode(*this, _typeDefinition);
auto [it, newlyInserted] = m_ast->contracts.emplace(&_contractDefinition, AST::ContractInfo{}); auto [it, newlyInserted] = m_ast->typeDefinitions.emplace(&_typeDefinition, term(_typeDefinition));
solAssert(newlyInserted); solAssert(newlyInserted);
AST::ContractInfo& contractInfo = it->second;
for (auto const& node: _contractDefinition.subNodes())
if (auto const* function = dynamic_cast<legacy::FunctionDefinition const*>(node.get()))
solAssert(contractInfo.functions.emplace(string{}, functionDefinition(*function)).second);
else
m_errorReporter.typeError(0000_error, node->location(), "Unsupported contract element.");
return false;
}
bool ASTTransform::visit(legacy::FunctionDefinition const& _functionDefinition)
{
SetNode setNode(*this, _functionDefinition);
solAssert(m_ast->functions.emplace(&_functionDefinition, functionDefinition(_functionDefinition)).second);
return false; return false;
} }
bool ASTTransform::visit(legacy::TypeClassDefinition const& _typeClassDefinition) bool ASTTransform::visit(legacy::TypeClassDefinition const& _typeClassDefinition)
{ {
SetNode setNode(*this, _typeClassDefinition); SetNode setNode(*this, _typeClassDefinition);
auto [it, newlyInserted] = m_ast->typeClasses.emplace(&_typeClassDefinition, AST::TypeClassInformation{}); auto [it, newlyInserted] = m_ast->typeClasses.emplace(&_typeClassDefinition, term(_typeClassDefinition));
solAssert(newlyInserted); solAssert(newlyInserted);
auto& info = it->second;
info.typeVariable = term(_typeClassDefinition.typeVariable());
info.declaration = reference(_typeClassDefinition);
declare(_typeClassDefinition, *info.declaration);
map<std::string, AST::FunctionInfo>& functions = info.functions;
for (auto subNode: _typeClassDefinition.subNodes())
{
auto const *function = dynamic_cast<FunctionDefinition const*>(subNode.get());
solAssert(function);
solAssert(functions.emplace(function->name(), functionDefinition(*function)).second);
}
return false; return false;
} }
bool ASTTransform::visit(legacy::TypeClassInstantiation const& _typeClassInstantiation) bool ASTTransform::visit(legacy::TypeClassInstantiation const& _typeClassInstantiation)
{ {
SetNode setNode(*this, _typeClassInstantiation); SetNode setNode(*this, _typeClassInstantiation);
auto [it, newlyInserted] = m_ast->typeClassInstantiations.emplace(&_typeClassInstantiation, AST::TypeClassInstantiationInformation{}); auto [it, newlyInserted] = m_ast->typeClassInstantiations.emplace(&_typeClassInstantiation, term(_typeClassInstantiation));
solAssert(newlyInserted); solAssert(newlyInserted);
auto& info = it->second;
info.typeClass = std::visit(util::GenericVisitor{
[&](Token _token) -> unique_ptr<Term> { return builtinTypeClass(_token); },
[&](ASTPointer<legacy::IdentifierPath> _identifierPath) -> unique_ptr<Term> {
solAssert(_identifierPath->annotation().referencedDeclaration);
return reference(*_identifierPath->annotation().referencedDeclaration);
}
}, _typeClassInstantiation.typeClass().name());
info.typeConstructor = term(_typeClassInstantiation.typeConstructor());
info.argumentSorts = termOrConstant(_typeClassInstantiation.argumentSorts(), BuiltinConstant::Unit);
map<std::string, AST::FunctionInfo>& functions = info.functions;
for (auto subNode: _typeClassInstantiation.subNodes())
{
auto const *function = dynamic_cast<FunctionDefinition const*>(subNode.get());
solAssert(function);
solAssert(functions.emplace(function->name(), functionDefinition(*function)).second);
}
return false; return false;
} }
bool ASTTransform::visit(legacy::TypeDefinition const& _typeDefinition) bool ASTTransform::visit(legacy::ContractDefinition const& _contractDefinition)
{
SetNode setNode(*this, _contractDefinition);
auto [it, newlyInserted] = m_ast->contracts.emplace(&_contractDefinition, term(_contractDefinition));
solAssert(newlyInserted);
return false;
}
bool ASTTransform::visit(legacy::FunctionDefinition const& _functionDefinition)
{
SetNode setNode(*this, _functionDefinition);
solAssert(m_ast->functions.emplace(&_functionDefinition, term(_functionDefinition)).second);
return false;
}
bool ASTTransform::visitNode(ASTNode const& _node)
{
m_errorReporter.typeError(0000_error, _node.location(), "Unexpected AST node during AST transform.");
return false;
}
unique_ptr<Term> ASTTransform::term(legacy::TypeDefinition const& _typeDefinition)
{ {
SetNode setNode(*this, _typeDefinition); SetNode setNode(*this, _typeDefinition);
auto [it, newlyInserted] = m_ast->typeDefinitions.emplace(&_typeDefinition, AST::TypeInformation{}); unique_ptr<Term> name = reference(_typeDefinition);
solAssert(newlyInserted); unique_ptr<Term> arguments = termOrConstant(_typeDefinition.arguments(), BuiltinConstant::Unit);
auto& info = it->second;
info.declaration = makeTerm<VariableDeclaration>(reference(_typeDefinition), nullptr);
declare(_typeDefinition, *info.declaration);
if (_typeDefinition.arguments())
info.arguments = tuple(_typeDefinition.arguments()->parameters() | ranges::view::transform([&](auto argument){
solAssert(!argument->typeExpression()); // TODO: error handling
return term(*argument);
}) | ranges::view::move | ranges::to<list<unique_ptr<Term>>>);
if (_typeDefinition.typeExpression()) if (_typeDefinition.typeExpression())
info.value = term(*_typeDefinition.typeExpression()); {
return false; unique_ptr<Term> definiens = term(*_typeDefinition.typeExpression());
return application(BuiltinConstant::TypeDefinition, std::move(name), std::move(arguments), std::move(definiens));
}
else
return application(BuiltinConstant::TypeDeclaration, std::move(name), std::move(arguments));
}
unique_ptr<Term> ASTTransform::term(legacy::TypeClassDefinition const& _typeClassDefinition)
{
SetNode setNode(*this, _typeClassDefinition);
unique_ptr<Term> typeVariable = term(_typeClassDefinition.typeVariable());
unique_ptr<Term> name = reference(_typeClassDefinition);
unique_ptr<Term> functions = namedFunctionList(_typeClassDefinition.subNodes());
return application(
BuiltinConstant::TypeClassDefinition,
std::move(typeVariable),
std::move(name),
std::move(functions)
);
}
std::unique_ptr<Term> ASTTransform::term(legacy::TypeClassInstantiation const& _typeClassInstantiation)
{
SetNode setNode(*this, _typeClassInstantiation);
unique_ptr<Term> typeConstructor = term(_typeClassInstantiation.typeConstructor());
unique_ptr<Term> argumentSorts = termOrConstant(_typeClassInstantiation.argumentSorts(), BuiltinConstant::Unit);
unique_ptr<Term> typeClass = term(_typeClassInstantiation.typeClass());
unique_ptr<Term> functions = namedFunctionList(_typeClassInstantiation.subNodes());
return application(
BuiltinConstant::TypeClassInstantiation,
std::move(typeConstructor),
std::move(argumentSorts),
std::move(typeClass),
std::move(functions)
);
}
std::unique_ptr<Term> ASTTransform::term(legacy::FunctionDefinition const& _functionDefinition)
{
SetNode setNode(*this, _functionDefinition);
unique_ptr<Term> name = reference(_functionDefinition);
unique_ptr<Term> arguments = term(_functionDefinition.parameterList());
unique_ptr<Term> returnType = termOrConstant(_functionDefinition.experimentalReturnExpression(), BuiltinConstant::Unit);
if (_functionDefinition.isImplemented())
{
unique_ptr<Term> body = term(_functionDefinition.body());
return application(
BuiltinConstant::FunctionDefinition,
std::move(name),
std::move(arguments),
std::move(returnType),
std::move(body)
);
}
else
return application(
BuiltinConstant::FunctionDeclaration,
std::move(name),
std::move(arguments),
std::move(returnType)
);
}
std::unique_ptr<Term> ASTTransform::term(legacy::ContractDefinition const& _contractDefinition)
{
SetNode setNode(*this, _contractDefinition);
unique_ptr<Term> name = reference(_contractDefinition);
return application(
BuiltinConstant::ContractDefinition,
std::move(name),
namedFunctionList(_contractDefinition.subNodes())
);
} }
unique_ptr<Term> ASTTransform::term(legacy::VariableDeclarationStatement const& _declaration) unique_ptr<Term> ASTTransform::term(legacy::VariableDeclarationStatement const& _declaration)
@ -142,10 +184,44 @@ unique_ptr<Term> ASTTransform::term(legacy::Assignment const& _assignment)
solAssert(false); solAssert(false);
} }
unique_ptr<Term> ASTTransform::term(TypeName const& _name) unique_ptr<Term> ASTTransform::term(legacy::Block const& _block)
{
SetNode setNode(*this, _block);
if (auto statements = ranges::fold_right_last(
_block.statements() | ranges::view::transform([&](auto stmt) { return term(*stmt); }) | ranges::view::move,
[&](auto stmt, auto acc) {
return application(BuiltinConstant::ChainStatements, std::move(stmt), std::move(acc));
}
))
return application(BuiltinConstant::Block, std::move(*statements));
else
return application(BuiltinConstant::Block, constant(BuiltinConstant::Unit));
}
unique_ptr<Term> ASTTransform::term(legacy::Statement const& _statement)
{
SetNode setNode(*this, _statement);
if (auto const* assembly = dynamic_cast<legacy::InlineAssembly const*>(&_statement))
return application(BuiltinConstant::RegularStatement, *assembly);
else if (auto const* declaration = dynamic_cast<legacy::VariableDeclarationStatement const*>(&_statement))
return application(BuiltinConstant::RegularStatement, *declaration);
else if (auto const* assign = dynamic_cast<legacy::Assignment const*>(&_statement))
return application(BuiltinConstant::RegularStatement, *assign);
else if (auto const* expressionStatement = dynamic_cast<legacy::ExpressionStatement const*>(&_statement))
return application(BuiltinConstant::RegularStatement, expressionStatement->expression());
else if (auto const* returnStatement = dynamic_cast<legacy::Return const*>(&_statement))
return application(BuiltinConstant::ReturnStatement, termOrConstant(returnStatement->expression(), BuiltinConstant::Unit));
else
{
m_analysis.errorReporter().fatalTypeError(0000_error, _statement.location(), "Unsupported statement.");
solAssert(false);
}
}
unique_ptr<Term> ASTTransform::term(legacy::TypeName const& _name)
{ {
SetNode setNode(*this, _name); SetNode setNode(*this, _name);
if (auto const* elementaryTypeName = dynamic_cast<ElementaryTypeName const*>(&_name)) if (auto const* elementaryTypeName = dynamic_cast<legacy::ElementaryTypeName const*>(&_name))
{ {
switch (elementaryTypeName->typeName().token()) switch (elementaryTypeName->typeName().token())
{ {
@ -178,69 +254,23 @@ unique_ptr<Term> ASTTransform::term(TypeName const& _name)
solAssert(false); solAssert(false);
} }
unique_ptr<Term> ASTTransform::term(legacy::Statement const& _statement) unique_ptr<Term> ASTTransform::term(legacy::TypeClassName const& _typeClassName)
{ {
SetNode setNode(*this, _statement); SetNode setNode(*this, _typeClassName);
if (auto const* assembly = dynamic_cast<legacy::InlineAssembly const*>(&_statement)) return std::visit(util::GenericVisitor{
return term(*assembly); [&](Token _token) -> unique_ptr<Term> { return builtinTypeClass(_token); },
else if (auto const* declaration = dynamic_cast<legacy::VariableDeclarationStatement const*>(&_statement)) [&](ASTPointer<legacy::IdentifierPath> _identifierPath) -> unique_ptr<Term> {
return term(*declaration); solAssert(_identifierPath->annotation().referencedDeclaration);
else if (auto const* assign = dynamic_cast<legacy::Assignment const*>(&_statement)) return reference(*_identifierPath->annotation().referencedDeclaration);
return term(*assign); }
else if (auto const* expressionStatement = dynamic_cast<legacy::ExpressionStatement const*>(&_statement)) }, _typeClassName.name());
return term(expressionStatement->expression());
else if (auto const* returnStatement = dynamic_cast<legacy::Return const*>(&_statement))
return application(BuiltinConstant::Return, termOrConstant(returnStatement->expression(), BuiltinConstant::Unit));
else
{
m_analysis.errorReporter().fatalTypeError(0000_error, _statement.location(), "Unsupported statement.");
solAssert(false);
}
}
unique_ptr<Term> ASTTransform::term(legacy::Block const& _block)
{
SetNode setNode(*this, _block);
if (_block.statements().empty())
return application(BuiltinConstant::Block, constant(BuiltinConstant::Unit));
auto makeStatement = [&](auto _stmt) {
return application(BuiltinConstant::Statement, *_stmt);
};
return application(
BuiltinConstant::Block,
ranges::fold_right(
_block.statements() | ranges::view::drop(1),
makeStatement(_block.statements().front()),
[&](auto stmt, auto acc) {
return application(BuiltinConstant::ChainStatements, std::move(acc), makeStatement(stmt));
}
)
);
}
AST::FunctionInfo ASTTransform::functionDefinition(legacy::FunctionDefinition const& _functionDefinition)
{
SetNode setNode(*this, _functionDefinition);
std::unique_ptr<Term> body = nullptr;
unique_ptr<Term> argumentExpression = term(_functionDefinition.parameterList());
if (_functionDefinition.isImplemented())
body = term(_functionDefinition.body());
unique_ptr<Term> returnType = termOrConstant(_functionDefinition.experimentalReturnExpression(), BuiltinConstant::Unit);
unique_ptr<Term> name = reference(_functionDefinition);
unique_ptr<Term> function = makeTerm<VariableDeclaration>(std::move(name), std::move(body));
declare(_functionDefinition, *function);
return AST::FunctionInfo{
std::move(function),
std::move(argumentExpression),
std::move(returnType)
};
} }
unique_ptr<Term> ASTTransform::term(legacy::ParameterList const& _parameterList) unique_ptr<Term> ASTTransform::term(legacy::ParameterList const& _parameterList)
{ {
SetNode setNode(*this, _parameterList); SetNode setNode(*this, _parameterList);
return tuple(_parameterList.parameters() | ranges::view::transform([&](auto parameter) { return tuple(_parameterList.parameters() | ranges::view::transform([&](auto parameter) {
solAssert(!parameter->value());
return term(*parameter); return term(*parameter);
}) | ranges::view::move | ranges::to<list<unique_ptr<Term>>>); }) | ranges::view::move | ranges::to<list<unique_ptr<Term>>>);
} }
@ -252,9 +282,10 @@ unique_ptr<Term> ASTTransform::term(legacy::VariableDeclaration const& _variable
unique_ptr<Term> name = reference(_variableDeclaration); unique_ptr<Term> name = reference(_variableDeclaration);
if (_variableDeclaration.typeExpression()) if (_variableDeclaration.typeExpression())
name = constrain(std::move(name), term(*_variableDeclaration.typeExpression())); name = constrain(std::move(name), term(*_variableDeclaration.typeExpression()));
unique_ptr<Term> declaration = makeTerm<VariableDeclaration>(std::move(name), std::move(_initialValue)); if (_initialValue)
declare(_variableDeclaration, *declaration); return application(BuiltinConstant::VariableDefinition, std::move(name), std::move(_initialValue));
return declaration; else
return application(BuiltinConstant::VariableDeclaration, std::move(name));
} }
unique_ptr<Term> ASTTransform::term(legacy::InlineAssembly const& _inlineAssembly) unique_ptr<Term> ASTTransform::term(legacy::InlineAssembly const& _inlineAssembly)
@ -349,47 +380,17 @@ unique_ptr<Term> ASTTransform::term(legacy::Expression const& _expression)
} }
} }
unique_ptr<Term> ASTTransform::binaryOperation(
Token _operator,
unique_ptr<Term> _leftHandSide,
unique_ptr<Term> _rightHandSide
)
{
return application(builtinBinaryOperator(_operator), std::move(_leftHandSide), std::move(_rightHandSide));
}
unique_ptr<Term> ASTTransform::reference(legacy::Declaration const& _declaration) unique_ptr<Term> ASTTransform::reference(legacy::Declaration const& _declaration)
{ {
auto [it, newlyInserted] = m_declarationIndices.emplace(&_declaration, m_ast->declarations.size()); return makeTerm<Reference>(static_cast<size_t>(_declaration.id()), _declaration.name());
if (newlyInserted)
m_ast->declarations.emplace_back(AST::DeclarationInfo{nullptr, {}});
return makeTerm<Reference>(it->second);
} }
size_t ASTTransform::declare(legacy::Declaration const& _declaration, Term& _term) unique_ptr<Term> ASTTransform::tuple(list<unique_ptr<Term>> _components)
{ {
auto [it, newlyInserted] = m_declarationIndices.emplace(&_declaration, m_ast->declarations.size()); if (auto term = ranges::fold_right_last(_components | ranges::view::move, [&](auto a, auto b) { return pair(std::move(a), std::move(b)); }))
if (newlyInserted) return std::move(*term);
m_ast->declarations.emplace_back(AST::DeclarationInfo{&_term, _declaration.name()});
else else
{ return constant(BuiltinConstant::Unit);
auto& info = m_ast->declarations.at(it->second);
solAssert(!info.target);
info.target = &_term;
info.name = _declaration.name();
}
termBase(_term).declaration = it->second;
return it->second;
}
TermBase ASTTransform::makeTermBase()
{
return TermBase{
m_currentLocation,
m_currentNode ? make_optional(m_currentNode->id()) : nullopt,
std::monostate{},
nullopt
};
} }
unique_ptr<Term> ASTTransform::constrain(unique_ptr<Term> _value, unique_ptr<Term> _constraint) unique_ptr<Term> ASTTransform::constrain(unique_ptr<Term> _value, unique_ptr<Term> _constraint)
@ -397,22 +398,18 @@ unique_ptr<Term> ASTTransform::constrain(unique_ptr<Term> _value, unique_ptr<Ter
return application(BuiltinConstant::Constrain, std::move(_value), std::move(_constraint)); return application(BuiltinConstant::Constrain, std::move(_value), std::move(_constraint));
} }
unique_ptr<Term> ASTTransform::builtinTypeClass(langutil::Token _token) std::unique_ptr<Term> ASTTransform::namedFunctionList(std::vector<ASTPointer<ASTNode>> _nodes)
{ {
switch (_token) list<unique_ptr<Term>> functionList;
for (auto subNode: _nodes)
{ {
case Token::Mul: auto const *function = dynamic_cast<FunctionDefinition const*>(subNode.get());
return constant(BuiltinConstant::Mul); solAssert(function);
case Token::Add: unique_ptr<Term> functionName = constant(function->name());
return constant(BuiltinConstant::Add); unique_ptr<Term> functionDefinition = term(*function);
case Token::Integer: functionList.emplace_back(application(BuiltinConstant::NamedTerm, std::move(functionName), std::move(functionDefinition)));
return constant(BuiltinConstant::Integer);
case Token::Equal:
return constant(BuiltinConstant::Equal);
default:
m_analysis.errorReporter().typeError(0000_error, m_currentLocation, "Invalid type class.");
return constant(BuiltinConstant::Undefined);
} }
return tuple(std::move(functionList));
} }
unique_ptr<Term> ASTTransform::builtinBinaryOperator(Token _token) unique_ptr<Term> ASTTransform::builtinBinaryOperator(Token _token)
@ -435,26 +432,29 @@ unique_ptr<Term> ASTTransform::builtinBinaryOperator(Token _token)
} }
} }
unique_ptr<Term> ASTTransform::pair(unique_ptr<Term> _first, unique_ptr<Term> _second) unique_ptr<Term> ASTTransform::builtinTypeClass(langutil::Token _token)
{ {
return application( switch (_token)
application( {
BuiltinConstant::Pair, case Token::Mul:
std::move(_first) return constant(BuiltinConstant::Mul);
), case Token::Add:
std::move(_second) return constant(BuiltinConstant::Add);
); case Token::Integer:
return constant(BuiltinConstant::Integer);
case Token::Equal:
return constant(BuiltinConstant::Equal);
default:
m_analysis.errorReporter().typeError(0000_error, m_currentLocation, "Invalid type class.");
return constant(BuiltinConstant::Undefined);
}
} }
unique_ptr<Term> ASTTransform::tuple(list<unique_ptr<Term>> _components) TermBase ASTTransform::makeTermBase()
{ {
if (auto term = ranges::fold_right_last(_components | ranges::view::move, [&](auto a, auto b) { return pair(std::move(a), std::move(b)); })) return TermBase{
return std::move(*term); m_currentLocation,
else m_currentNode ? make_optional(m_currentNode->id()) : nullopt,
return constant(BuiltinConstant::Unit); std::monostate{}
} };
unique_ptr<Term> ASTTransform::application(unique_ptr<Term> _function, std::list<unique_ptr<Term>> _arguments)
{
return makeTerm<Application>(std::move(_function), tuple(std::move(_arguments)));
} }

View File

@ -35,52 +35,74 @@ class ASTTransform: public ASTConstVisitor
public: public:
ASTTransform(Analysis& _analysis); ASTTransform(Analysis& _analysis);
std::unique_ptr<AST> ast() std::unique_ptr<AST> ast() { return std::move(m_ast); }
{
return std::move(m_ast);
}
private: private:
bool visit(legacy::SourceUnit const&) override { return true; } bool visit(legacy::SourceUnit const&) override { return true; }
bool visit(legacy::PragmaDirective const&) override { return true; } bool visit(legacy::PragmaDirective const&) override { return true; }
bool visit(legacy::ImportDirective const&) override { return true; } bool visit(legacy::ImportDirective const&) override { return true; }
bool visit(legacy::ContractDefinition const& _contractDefinition) override;
bool visit(legacy::FunctionDefinition const& _functionDefinition) override; bool visit(legacy::TypeDefinition const& _typeDefinition) override;
bool visit(legacy::TypeClassDefinition const& _typeClassDefinition) override; bool visit(legacy::TypeClassDefinition const& _typeClassDefinition) override;
bool visit(legacy::TypeClassInstantiation const& _typeClassInstantiation) override; bool visit(legacy::TypeClassInstantiation const& _typeClassInstantiation) override;
bool visit(legacy::TypeDefinition const& _typeDefinition) override; bool visit(legacy::FunctionDefinition const& _functionDefinition) override;
bool visit(legacy::ContractDefinition const& _contractDefinition) override;
bool visitNode(ASTNode const& _node) override;
AST::FunctionInfo functionDefinition(legacy::FunctionDefinition const& _functionDefinition); std::unique_ptr<Term> term(legacy::TypeClassDefinition const& _typeClassDefinition);
std::unique_ptr<Term> term(legacy::ParameterList const& _parameterList); std::unique_ptr<Term> term(legacy::TypeClassInstantiation const& _typeClassInstantiation);
std::unique_ptr<Term> term(legacy::VariableDeclaration const& _variableDeclaration, std::unique_ptr<Term> _initialValue = nullptr); std::unique_ptr<Term> term(legacy::TypeDefinition const& _typeDefinition);
std::unique_ptr<Term> term(legacy::Block const& _block); std::unique_ptr<Term> term(legacy::ContractDefinition const& _contractDefinition);
std::unique_ptr<Term> term(std::vector<ASTPointer<legacy::Statement>> const& _statements); std::unique_ptr<Term> term(legacy::FunctionDefinition const& _functionDefinition);
std::unique_ptr<Term> term(legacy::InlineAssembly const& _assembly);
std::unique_ptr<Term> term(legacy::VariableDeclarationStatement const& _declaration); std::unique_ptr<Term> term(legacy::VariableDeclarationStatement const& _declaration);
std::unique_ptr<Term> term(legacy::Statement const& _statements);
std::unique_ptr<Term> term(legacy::Expression const& _expression);
std::unique_ptr<Term> term(legacy::Assignment const& _assignment); std::unique_ptr<Term> term(legacy::Assignment const& _assignment);
std::unique_ptr<Term> term(legacy::Block const& _block);
std::unique_ptr<Term> term(legacy::Statement const& _statements);
std::unique_ptr<Term> term(legacy::TypeName const& _name); std::unique_ptr<Term> term(legacy::TypeName const& _name);
std::unique_ptr<Term> term(legacy::TypeClassName const& _typeClassName);
std::unique_ptr<Term> binaryOperation( std::unique_ptr<Term> term(legacy::ParameterList const& _parameterList);
langutil::Token _operator, std::unique_ptr<Term> term(legacy::VariableDeclaration const& _variableDeclaration, std::unique_ptr<Term> _initialValue = {});
std::unique_ptr<Term> _leftHandSide, std::unique_ptr<Term> term(legacy::InlineAssembly const& _assembly);
std::unique_ptr<Term> _rightHandSide std::unique_ptr<Term> term(legacy::Expression const& _expression);
);
std::unique_ptr<Term> constant(BuiltinConstant _constant)
{
return makeTerm<Constant>(_constant);
}
std::unique_ptr<Term> constant(std::string _name)
{
return makeTerm<Constant>(_name);
}
// Allows for easy uniform treatment in the variadic templates below. // Allows for easy uniform treatment in the variadic templates below.
std::unique_ptr<Term> term(std::unique_ptr<Term> _term) { return _term; } std::unique_ptr<Term> term(std::unique_ptr<Term> _term) { return _term; }
std::unique_ptr<Term> namedFunctionList(std::vector<ASTPointer<ASTNode>> _nodes);
std::unique_ptr<Term> binaryOperation(
langutil::Token _operator,
std::unique_ptr<Term> _leftHandSide,
std::unique_ptr<Term> _rightHandSide
)
{
return application(builtinBinaryOperator(_operator), std::move(_leftHandSide), std::move(_rightHandSide));
}
std::unique_ptr<Term> reference(legacy::Declaration const& _declaration);
std::unique_ptr<Term> constant(BuiltinConstant _constant) { return makeTerm<Constant>(_constant); }
std::unique_ptr<Term> constant(std::string _name) { return makeTerm<Constant>(_name); }
template<typename T>
std::unique_ptr<Term> termOrConstant(T const* _node, BuiltinConstant _constant)
{
return _node ? term(*_node) : constant(_constant);
}
std::unique_ptr<Term> pair(std::unique_ptr<Term> _first, std::unique_ptr<Term> _second)
{
// Note: BuiltinConstant::Pair has signature a -> (b -> (a, b))
// This reduces n-ary functions to unary functions only as primitives.
return application(
application(
BuiltinConstant::Pair,
std::move(_first)
),
std::move(_second)
);
}
std::unique_ptr<Term> tuple(std::list<std::unique_ptr<Term>> _components); std::unique_ptr<Term> tuple(std::list<std::unique_ptr<Term>> _components);
std::unique_ptr<Term> constrain(std::unique_ptr<Term> _value, std::unique_ptr<Term> _constraint);
std::unique_ptr<Term> builtinBinaryOperator(langutil::Token);
std::unique_ptr<Term> builtinTypeClass(langutil::Token);
template<typename... Args> template<typename... Args>
std::unique_ptr<Term> tuple(Args&&... _args) std::unique_ptr<Term> tuple(Args&&... _args)
{ {
@ -88,8 +110,10 @@ private:
(components.emplace_back(term(std::forward<Args>(_args))), ...); (components.emplace_back(term(std::forward<Args>(_args))), ...);
return tuple(std::move(components)); return tuple(std::move(components));
} }
std::unique_ptr<Term> pair(std::unique_ptr<Term> _first, std::unique_ptr<Term> _second); std::unique_ptr<Term> application(std::unique_ptr<Term> _function, std::list<std::unique_ptr<Term>> _arguments)
std::unique_ptr<Term> application(std::unique_ptr<Term> _function, std::list<std::unique_ptr<Term>> _argument); {
return makeTerm<Application>(std::move(_function), tuple(std::move(_arguments)));
}
template<typename... Args> template<typename... Args>
std::unique_ptr<Term> application(std::unique_ptr<Term> _function, Args&&... _args) std::unique_ptr<Term> application(std::unique_ptr<Term> _function, Args&&... _args)
{ {
@ -97,7 +121,6 @@ private:
(components.emplace_back(term(std::forward<Args>(_args))), ...); (components.emplace_back(term(std::forward<Args>(_args))), ...);
return application(std::move(_function), std::move(components)); return application(std::move(_function), std::move(components));
} }
template<typename... Args> template<typename... Args>
std::unique_ptr<Term> application(BuiltinConstant _function, Args&&... _args) std::unique_ptr<Term> application(BuiltinConstant _function, Args&&... _args)
{ {
@ -106,25 +129,34 @@ private:
return application(constant(_function), std::move(components)); return application(constant(_function), std::move(components));
} }
std::unique_ptr<Term> constrain(std::unique_ptr<Term> _value, std::unique_ptr<Term> _constraint); TermBase makeTermBase();
std::unique_ptr<Term> builtinBinaryOperator(langutil::Token); template<typename TermKind, typename... Args>
std::unique_ptr<Term> builtinTypeClass(langutil::Token); std::unique_ptr<Term> makeTerm(Args&&... _args)
std::unique_ptr<Term> reference(legacy::Declaration const& _declaration); {
size_t declare(legacy::Declaration const& _declaration, Term& _term); return std::make_unique<Term>(TermKind{
makeTermBase(),
std::forward<Args>(_args)...
});
}
Analysis& m_analysis;
langutil::ErrorReporter& m_errorReporter;
std::unique_ptr<AST> m_ast;
struct SetNode { struct SetNode {
SetNode(ASTTransform& _parent, ASTNode const& _node): SetNode(ASTTransform& _parent, ASTNode const& _node):
m_parent(_parent), m_parent(_parent),
m_previousNode(_parent.m_currentNode), m_previousNode(_parent.m_currentNode),
m_previousLocation(_parent.m_currentLocation) m_previousLocation(_parent.m_currentLocation)
{ {
_parent.m_currentNode = &_node; _parent.m_currentNode = &_node;
_parent.m_currentLocation = _node.location(); _parent.m_currentLocation = _node.location();
} }
SetNode(ASTTransform& _parent, langutil::SourceLocation const& _location): SetNode(ASTTransform& _parent, langutil::SourceLocation const& _location):
m_parent(_parent), m_parent(_parent),
m_previousNode(_parent.m_currentNode), m_previousNode(_parent.m_currentNode),
m_previousLocation(_parent.m_currentLocation) m_previousLocation(_parent.m_currentLocation)
{ {
_parent.m_currentNode = nullptr; _parent.m_currentNode = nullptr;
_parent.m_currentLocation = _location; _parent.m_currentLocation = _location;
@ -138,31 +170,8 @@ private:
ASTNode const* m_previousNode = nullptr; ASTNode const* m_previousNode = nullptr;
langutil::SourceLocation m_previousLocation; langutil::SourceLocation m_previousLocation;
}; };
TermBase makeTermBase();
template<typename TermKind, typename... Args>
std::unique_ptr<Term> makeTerm(Args&&... _args)
{
return std::make_unique<Term>(TermKind{
makeTermBase(),
std::forward<Args>(_args)...
});
}
template<typename T>
std::unique_ptr<Term> termOrConstant(T const* _node, BuiltinConstant _constant)
{
if (_node)
return term(*_node);
else
return constant(_constant);
}
Analysis& m_analysis;
langutil::ErrorReporter& m_errorReporter;
ASTNode const* m_currentNode = nullptr; ASTNode const* m_currentNode = nullptr;
langutil::SourceLocation m_currentLocation; langutil::SourceLocation m_currentLocation;
std::unique_ptr<AST> m_ast;
std::map<frontend::Declaration const*, size_t, ASTCompareByID<frontend::Declaration>> m_declarationIndices;
}; };
} }

View File

@ -18,13 +18,17 @@
#include <libsolidity/experimental/analysis/TypeCheck.h> #include <libsolidity/experimental/analysis/TypeCheck.h>
#include <libsolidity/experimental/ast/TypeSystemHelper.h> #include <libsolidity/experimental/ast/TypeSystemHelper.h>
#include <libsolidity/ast/AST.h> #include <libsolidity/ast/AST.h>
#include <libsolutil/CommonIO.h>
#include <libsolutil/Visitor.h>
#include <range/v3/view/map.hpp>
#include <liblangutil/ErrorReporter.h> #include <liblangutil/ErrorReporter.h>
#include <liblangutil/Exceptions.h> #include <liblangutil/Exceptions.h>
#include <libsolutil/AnsiColorized.h>
#include <libsolutil/CommonIO.h>
#include <libsolutil/Visitor.h>
#include <range/v3/view/map.hpp>
#include <range/v3/view/reverse.hpp>
using namespace std; using namespace std;
using namespace solidity; using namespace solidity;
using namespace langutil; using namespace langutil;
@ -32,8 +36,143 @@ using namespace solidity::frontend::experimental;
namespace namespace
{ {
using Term = std::variant<Application, Lambda, InlineAssembly, VariableDeclaration, Reference, Constant>;
struct TPat
{
using Unifier = std::function<void(Type, Type)>;
template<typename R, typename... Args>
TPat(R _f(Args...)): generator([f = _f](TypeEnvironment& _env, Unifier _unifier) -> Type {
return invoke(_env, _unifier, f, std::make_index_sequence<sizeof...(Args)>{});
}) {}
TPat(PrimitiveType _type): generator([type = _type](TypeEnvironment& _env, Unifier) { return _env.typeSystem().type(type, {}); }) {}
Type realize(TypeEnvironment& _env, Unifier _unifier) const { return generator(_env, _unifier); }
TPat(std::function<Type(TypeEnvironment&, Unifier)> _generator):generator(std::move(_generator)) {}
TPat(Type _t): generator([t = _t](TypeEnvironment&, Unifier) -> Type { return t; }) {}
private:
template<size_t I>
static TPat makeFreshVariable(TypeEnvironment& _env) { return TPat{_env.typeSystem().freshTypeVariable({})}; }
template<typename Generator, size_t... Is>
static Type invoke(TypeEnvironment& _env, Unifier _unifier, Generator const& _generator, std::index_sequence<Is...>)
{
// Use an auxiliary array to ensure deterministic evaluation order.
[[maybe_unused]] std::array<TPat, sizeof...(Is)> patterns{makeFreshVariable<Is>(_env)...};
return (_generator(std::move(patterns[Is])...)).realize(_env, _unifier);
}
std::function<Type(TypeEnvironment&, Unifier)> generator;
};
namespace pattern_ops
{
using Unifier = std::function<void(Type, Type)>;
inline TPat operator>>(TPat _a, TPat _b)
{
return TPat([a = std::move(_a), b = std::move(_b)](TypeEnvironment& _env, Unifier _unifier) -> Type {
return TypeSystemHelpers{_env.typeSystem()}.functionType(a.realize(_env, _unifier), b.realize(_env, _unifier));
});
}
inline TPat operator==(TPat _a, TPat _b)
{
return TPat([a = std::move(_a), b = std::move(_b)](TypeEnvironment& _env, Unifier _unifier) -> Type {
Type left = a.realize(_env, _unifier);
Type right = b.realize(_env, _unifier);
_unifier(left, right);
return left;
});
}
template<typename... Args>
TPat tuple(Args... args)
{
return TPat([args = std::array<TPat, sizeof...(Args)>{{std::move(args)...}}](TypeEnvironment& _env, Unifier _unifier) -> Type {
return TypeSystemHelpers{_env.typeSystem()}.tupleType(
args | ranges::view::transform([&](TPat _pat) { return _pat.realize(_env, _unifier); }) | ranges::to<vector<Type>>
);
});
}
}
struct BuiltinConstantInfo
{
std::string name;
std::optional<TPat> builtinType;
};
[[maybe_unused]] BuiltinConstantInfo const& builtinConstantInfo(BuiltinConstant _constant)
{
using namespace pattern_ops;
static const TPat unit{PrimitiveType::Unit};
static const auto info = std::map<BuiltinConstant, BuiltinConstantInfo>{
{BuiltinConstant::Unit, {"Unit", unit}},
{BuiltinConstant::Pair, {"Pair", +[](TPat a, TPat b) { return a >> (b >> tuple(a,b)); }}},
{BuiltinConstant::Fun, {"Fun", +[](TPat a, TPat b) { return tuple(a,b) >> (a >> b); }}},
{BuiltinConstant::Constrain, {"Constrain", +[](TPat a) { return tuple(a,a) >> a; }}},
{BuiltinConstant::NamedTerm, {"NamedTerm", +[](TPat a) { return tuple(unit, a) >> a; /* TODO: (name, a) >> a */ }}},
{BuiltinConstant::TypeDeclaration, {"TypeDeclaration", nullopt}},
{BuiltinConstant::TypeDefinition, {"TypeDefinition", +[](TPat type, TPat args, TPat value) {
return tuple(type, args, value) >> (args >> type);
}}},
{BuiltinConstant::TypeClassDefinition, {"TypeClassDefinition", nullopt}},
{BuiltinConstant::TypeClassInstantiation, {"TypeClassInstantiation", nullopt}},
{BuiltinConstant::FunctionDeclaration, {"FunctionDeclaration", nullopt}},
{BuiltinConstant::FunctionDefinition, {"FunctionDefinition", +[](TPat a, TPat r) {
return tuple(a >> r, a, r, r) >> (a >> r);
}}},
{BuiltinConstant::ContractDefinition, {"ContractDefinition", +[]() {
return tuple(unit, (unit >> unit)) >> unit;
}}},
{BuiltinConstant::VariableDeclaration, {"VariableDeclaration", +[](TPat a) { return a >> a; }}},
{BuiltinConstant::VariableDefinition, {"VariableDefinition", nullopt}},
{BuiltinConstant::Block, {"Block", +[](TPat a) { return a >> a; }}},
{BuiltinConstant::ReturnStatement, {"ReturnStatement", +[](TPat a) { return a >> a; }}},
{BuiltinConstant::RegularStatement, {"RegularStatement", +[](TPat a) { return a >> unit; }}},
{BuiltinConstant::ChainStatements, {"ChainStatements", +[](TPat a, TPat b) { return tuple(a,b) >> b; }}},
{BuiltinConstant::Assign, {"Assign", +[](TPat a) { return tuple(a,a) >> unit; }}},
{BuiltinConstant::MemberAccess, {"MemberAccess", nullopt}},
{BuiltinConstant::Mul, {"Mul", nullopt}},
{BuiltinConstant::Add, {"Add", nullopt}},
{BuiltinConstant::Void, {"Void", nullopt}},
{BuiltinConstant::Word, {"Word", PrimitiveType::Word}},
{BuiltinConstant::Integer, {"Integer", nullopt}},
{BuiltinConstant::Bool, {"Bool", nullopt}},
{BuiltinConstant::Undefined, {"Undefined", nullopt}},
{BuiltinConstant::Equal, {"Equal", nullopt}},
};
return info.at(_constant);
}
template<typename Visitor>
void forEachTopLevelTerm(AST& _ast, Visitor _visitor)
{
for (auto& term: _ast.typeDefinitions | ranges::view::values)
_visitor(*term);
for (auto& term: _ast.typeClasses | ranges::view::values)
_visitor(*term);
for (auto& term: _ast.typeClassInstantiations | ranges::view::values)
_visitor(*term);
for (auto& term: _ast.functions | ranges::views::values)
_visitor(*term);
for (auto& term: _ast.contracts | ranges::views::values)
_visitor(*term);
}
template<typename Visitor>
void forEachImmediateSubTerm(Term& _term, Visitor _visitor)
{
std::visit(util::GenericVisitor{
[&](Application const& _app) {
_visitor(*_app.expression);
_visitor(*_app.argument);
},
[&](Lambda const& _lambda)
{
_visitor(*_lambda.argument);
_visitor(*_lambda.value);
},
[&](InlineAssembly const&)
{
// TODO
},
[&](Reference const&) {},
[&](Constant const&) {}
}, _term);
}
optional<pair<reference_wrapper<Term const>, reference_wrapper<Term const>>> destPair(Term const& _term) optional<pair<reference_wrapper<Term const>, reference_wrapper<Term const>>> destPair(Term const& _term)
{ {
@ -71,29 +210,15 @@ void setType(Term& _term, Type _type)
std::visit([&](auto& term) { term.type = _type; }, _term); std::visit([&](auto& term) { term.type = _type; }, _term);
} }
string colorize(string _color, string _string)
{
return _color + _string + util::formatting::RESET;
}
string termPrinter(AST& _ast, Term const& _term, TypeEnvironment* _env = nullptr, bool _sugarPairs = true, bool _sugarConsts = true, size_t _indent = 0) string termPrinter(AST& _ast, Term const& _term, TypeEnvironment* _env = nullptr, bool _sugarPairs = true, bool _sugarConsts = true, size_t _indent = 0)
{ {
using namespace util::formatting;
auto recurse = [&](Term const& _next) { return termPrinter(_ast, _next, _env, _sugarPairs, _sugarConsts, _indent); }; auto recurse = [&](Term const& _next) { return termPrinter(_ast, _next, _env, _sugarPairs, _sugarConsts, _indent); };
static const std::map<BuiltinConstant, const char*> builtinConstants = {
{BuiltinConstant::Unit, "()"},
{BuiltinConstant::Pair, "Pair"},
{BuiltinConstant::Fun, "Fun"},
{BuiltinConstant::Constrain, "Constrain"},
{BuiltinConstant::Return, "Return"},
{BuiltinConstant::Block, "Block"},
{BuiltinConstant::Statement, "Statement"},
{BuiltinConstant::ChainStatements, "ChainStatements"},
{BuiltinConstant::Assign, "Assign"},
{BuiltinConstant::MemberAccess, "MemberAccess"},
{BuiltinConstant::Mul, "Mul"},
{BuiltinConstant::Add, "Add"},
{BuiltinConstant::Void, "void"},
{BuiltinConstant::Word, "word"},
{BuiltinConstant::Integer, "Integer"},
{BuiltinConstant::Bool, "Bool"},
{BuiltinConstant::Undefined, "Undefined"},
{BuiltinConstant::Equal, "Equal"}
};
string result = std::visit(util::GenericVisitor{ string result = std::visit(util::GenericVisitor{
[&](Application const& _app) { [&](Application const& _app) {
if (_sugarPairs) if (_sugarPairs)
@ -154,8 +279,10 @@ string termPrinter(AST& _ast, Term const& _term, TypeEnvironment* _env = nullptr
if (auto pair = destPair(*_app.argument)) if (auto pair = destPair(*_app.argument))
return recurse(pair->first) + "\n" + std::string(_indent, '\t') + recurse(pair->second); return recurse(pair->first) + "\n" + std::string(_indent, '\t') + recurse(pair->second);
break; break;
case BuiltinConstant::Statement: case BuiltinConstant::RegularStatement:
return recurse(*_app.argument) + ";"; return recurse(*_app.argument) + ";";
case BuiltinConstant::ReturnStatement:
return colorize(CYAN, "return ") + recurse(*_app.argument) + ";";
default: default:
break; break;
} }
@ -166,23 +293,20 @@ string termPrinter(AST& _ast, Term const& _term, TypeEnvironment* _env = nullptr
return "(" + recurse(*_lambda.argument) + " -> " + recurse(*_lambda.value) + ")"; return "(" + recurse(*_lambda.argument) + " -> " + recurse(*_lambda.value) + ")";
}, },
[&](InlineAssembly const&) -> string { [&](InlineAssembly const&) -> string {
return "assembly"; return colorize(CYAN, "assembly");
},
[&](VariableDeclaration const& _varDecl) {
return "let " + recurse(*_varDecl.namePattern) + (_varDecl.initialValue ? " = " + recurse(*_varDecl.initialValue) : "");
}, },
[&](Reference const& _reference) { [&](Reference const& _reference) {
return "" + _ast.declarations.at(_reference.index).name + ""; return _reference.name.empty() ? util::toString(_reference.index) : _reference.name;
}, },
[&](Constant const& _constant) { [&](Constant const& _constant) {
return "" + std::visit(util::GenericVisitor{ return colorize(BLUE, std::visit(util::GenericVisitor{
[](BuiltinConstant _constant) -> string { [](BuiltinConstant _constant) -> string {
return builtinConstants.at(_constant); return builtinConstantInfo(_constant).name;
}, },
[](std::string const& _name) { [](std::string const& _name) {
return _name; return _name;
} }
}, _constant.name) + ""; }, _constant.name));
} }
}, _term); }, _term);
if (_env) if (_env)
@ -190,140 +314,39 @@ string termPrinter(AST& _ast, Term const& _term, TypeEnvironment* _env = nullptr
Type termType = type(_term); Type termType = type(_term);
if (!holds_alternative<std::monostate>(termType)) if (!holds_alternative<std::monostate>(termType))
{ {
result += "[:" + TypeEnvironmentHelpers{*_env}.typeToString(termType) + "]"; result += colorize(GREEN, "[:" + TypeEnvironmentHelpers{*_env}.typeToString(termType) + "]");
} }
} }
return result; return result;
} }
std::string functionPrinter(AST& _ast, AST::FunctionInfo const& _info, TypeEnvironment* _env = nullptr, bool _sugarPairs = true, bool _sugarConsts = true, size_t _indent = 0)
{
auto printTerm = [&](Term const& _term) { return termPrinter(_ast, _term, _env, _sugarPairs, _sugarConsts, _indent + 1); };
return "function (" + printTerm(*_info.arguments) + ") -> " + printTerm(*_info.returnType) + " = " + printTerm(*_info.function) + "\n";
}
std::string astPrinter(AST& _ast, TypeEnvironment* _env = nullptr, bool _sugarPairs = true, bool _sugarConsts = true, size_t _indent = 0) std::string astPrinter(AST& _ast, TypeEnvironment* _env = nullptr, bool _sugarPairs = true, bool _sugarConsts = true, size_t _indent = 0)
{ {
auto printTerm = [&](Term const& _term) { return termPrinter(_ast, _term, _env, _sugarPairs, _sugarConsts, _indent + 1); };
auto printFunction = [&](AST::FunctionInfo const& _info) { return functionPrinter(_ast, _info, _env, _sugarPairs, _sugarConsts, _indent + 1); };
std::string result; std::string result;
for (auto& info: _ast.typeDefinitions | ranges::view::values) auto printTerm = [&](Term const& _term) { result += termPrinter(_ast, _term, _env, _sugarPairs, _sugarConsts, _indent + 1) + "\n\n"; };
{ forEachTopLevelTerm(_ast, printTerm);
result += "type " + printTerm(*info.declaration);
if (info.arguments)
result += " " + printTerm(*info.arguments);
if (info.value)
result += " = " + printTerm(*info.declaration);
result += "\n\n";
}
for (auto& info: _ast.typeClasses | ranges::view::values)
{
result += "class " + printTerm(*info.typeVariable) + ":" + printTerm(*info.declaration) + " {";
_indent++;
for (auto&& functionInfo: info.functions | ranges::view::values)
result += printFunction(functionInfo);
_indent--;
result += "}\n\n";
}
for (auto& info: _ast.typeClassInstantiations | ranges::view::values)
{
result += "instantiation " + printTerm(*info.typeConstructor) + "(" + printTerm(*info.argumentSorts) + "):" + printTerm(*info.typeClass) + "{\n";
_indent++;
for (auto&& functionInfo: info.functions | ranges::view::values)
result += printFunction(functionInfo);
_indent--;
}
for (auto& functionInfo: _ast.functions | ranges::views::values)
{
result += printFunction(functionInfo);
result += "\n";
}
for (auto& [contract, info]: _ast.contracts)
{
result += "contract " + contract->name() + " {\n";
_indent++;
for(auto& function: info.functions | ranges::view::values)
result += printFunction(function);
_indent--;
result += "}\n\n";
}
return result; return result;
} }
} }
namespace
{
struct TVar
{
TypeEnvironment& env;
Type type;
};
inline TVar operator>>(TVar a, TVar b)
{
TypeSystemHelpers helper{a.env.typeSystem()};
return TVar{a.env, helper.functionType(a.type, b.type)};
}
inline TVar operator,(TVar a, TVar b)
{
TypeSystemHelpers helper{a.env.typeSystem()};
return TVar{a.env, helper.tupleType({a.type, b.type})};
}
template <typename T>
struct ArgumentCount;
template <typename R, typename... Args>
struct ArgumentCount<std::function<R(Args...)>> {
static constexpr size_t value = sizeof...(Args);
};
struct TypeGenerator
{
template<typename Generator>
TypeGenerator(Generator&& _generator):generator([generator = std::move(_generator)](TypeEnvironment& _env) -> Type {
return invoke(_env, generator, std::make_index_sequence<ArgumentCount<decltype(std::function{_generator})>::value>{});
}) {}
TypeGenerator(TVar _type):generator([type = _type.type](TypeEnvironment& _env) -> Type { return _env.fresh(type); }) {}
TypeGenerator(PrimitiveType _type): generator([type = _type](TypeEnvironment& _env) -> Type { return _env.typeSystem().type(type, {}); }) {}
Type operator()(TypeEnvironment& _env) const { return generator(_env); }
private:
template<size_t I>
static TVar makeFreshVariable(TypeEnvironment& _env) { return TVar{_env, _env.typeSystem().freshTypeVariable({}) }; }
template<typename Generator, size_t... Is>
static Type invoke(TypeEnvironment& _env, Generator&& _generator, std::index_sequence<Is...>)
{
// Use an auxiliary array to ensure deterministic evaluation order.
std::array<TVar, sizeof...(Is)> tvars{makeFreshVariable<Is>(_env)...};
return std::invoke(_generator, tvars[Is]...).type;
}
std::function<Type(TypeEnvironment&)> generator;
};
}
void TypeCheck::operator()(AST& _ast) void TypeCheck::operator()(AST& _ast)
{ {
TypeSystem& typeSystem = m_analysis.typeSystem(); TypeSystem& typeSystem = m_analysis.typeSystem();
TypeSystemHelpers helper{typeSystem}; TypeSystemHelpers helper{typeSystem};
TypeEnvironment& env = typeSystem.env(); TypeEnvironment& env = typeSystem.env();
TVar unit = TVar{env, typeSystem.type(PrimitiveType::Unit, {})}; list<reference_wrapper<Term>> toCheck;
TVar word = TVar{env, typeSystem.type(PrimitiveType::Word, {})}; forEachTopLevelTerm(_ast, [&](Term& _root) {
std::unique_ptr<TVar> currentReturn; list<reference_wrapper<Term>> staged{{_root}};
std::map<BuiltinConstant, TypeGenerator> builtinConstantTypeGenerators{ while (!staged.empty())
{BuiltinConstant::Unit, unit}, {
{BuiltinConstant::Pair, [](TVar a, TVar b) { return a >> (b >> (a,b)); }}, Term& term = staged.front().get();
{BuiltinConstant::Word, word}, staged.pop_front();
{BuiltinConstant::Assign, [=](TVar a) { return (a,a) >> unit; }}, // TODO: (a,a) >> a toCheck.push_back(term);
{BuiltinConstant::Block, [](TVar a) { return a >> a; }}, forEachImmediateSubTerm(term, [&](Term& _subTerm) { staged.push_back(_subTerm); });
{BuiltinConstant::ChainStatements, [](TVar a, TVar b) { return (a,b) >> b; }}, }
{BuiltinConstant::Statement, [=](TVar a) { return a >> unit; }}, });
{BuiltinConstant::Return, [&]() {
solAssert(currentReturn);
return *currentReturn >> unit;
}},
{BuiltinConstant::Fun, [&](TVar a, TVar b) {
return (a,b) >> (a >> b);
}},
};
auto unifyForTerm = [&](Type _a, Type _b, Term* _term) { auto unifyForTerm = [&](Type _a, Type _b, Term* _term) {
for (auto failure: env.unify(_a, _b)) for (auto failure: env.unify(_a, _b))
@ -361,161 +384,65 @@ void TypeCheck::operator()(AST& _ast)
}, failure); }, failure);
} }
}; };
auto checkTerm = [&](Term& _root) {
std::list<reference_wrapper<Term>> heap;
heap.emplace_back(_root);
auto checked = [](Term const& _term) {
return !holds_alternative<std::monostate>(type(_term));
};
auto canCheck = [&](Term& _term) -> bool {
bool hasUnchecked = false;
auto stage = [&](Term& _term) {
if (!checked(_term))
{
heap.push_back(_term);
hasUnchecked = true;
}
};
std::visit(util::GenericVisitor{
[&](Application const& _app) {
stage(*_app.expression);
stage(*_app.argument);
},
[&](Lambda const& _lambda)
{
stage(*_lambda.argument);
stage(*_lambda.value);
},
[&](InlineAssembly const&)
{
// TODO
},
[&](VariableDeclaration const& _varDecl)
{
stage(*_varDecl.namePattern);
if (_varDecl.initialValue)
stage(*_varDecl.initialValue);
},
[&](Reference const&)
{
},
[&](Constant const&) {}
}, _term);
if (hasUnchecked)
{
stage(_term);
return false;
}
return true;
};
std::map<size_t, Type> declarationTypes;
while (!heap.empty())
{
Term& current = heap.front();
heap.pop_front();
if (checked(current))
continue;
if (!canCheck(current))
continue;
auto unify = [&](Type _a, Type _b) { unifyForTerm(_a, _b, &current); }; std::map<size_t, Type> declarationTypes;
for(auto term: toCheck | ranges::view::reverse)
std::visit(util::GenericVisitor{ {
[&](Application const& _app) { auto unify = [&](Type _a, Type _b) { unifyForTerm(_a, _b, &term.get()); };
if (auto* constant = get_if<Constant>(_app.expression.get())) std::visit(util::GenericVisitor{
if (auto* builtin = get_if<BuiltinConstant>(&constant->name)) [&](Application const& _app) {
if (*builtin == BuiltinConstant::Constrain) /*if (auto* constant = get_if<Constant>(_app.expression.get()))
if (auto args = destPair(*_app.argument)) if (auto* builtin = get_if<BuiltinConstant>(&constant->name))
{ if (*builtin == BuiltinConstant::Constrain)
Type result = type(args->first); if (auto args = destPair(*_app.argument))
unify(result, type(args->second));
setType(current, result);
return;
}
Type resultType = typeSystem.freshTypeVariable({});
unify(helper.functionType(type(*_app.argument), resultType), type(*_app.expression));
setType(current, resultType);
},
[&](Lambda const& _lambda)
{
setType(current, helper.functionType(type(*_lambda.argument), type(*_lambda.value)));
},
[&](InlineAssembly const& _inlineAssembly)
{
// TODO
(void)_inlineAssembly;
setType(current, typeSystem.type(PrimitiveType::Unit, {}));
},
[&](VariableDeclaration const& _varDecl)
{
Type name = type(*_varDecl.namePattern);
if (_varDecl.initialValue)
unify(name, type(*_varDecl.initialValue));
setType(current, name);
},
[&](Reference const& _reference)
{
Type result = typeSystem.freshTypeVariable({});
if (
auto [it, newlyInserted] = declarationTypes.emplace(_reference.index, result);
!newlyInserted
)
unify(result, it->second);
setType(current, result);
},
[&](Constant const& _constant)
{
bool assigned = std::visit(util::GenericVisitor{
[&](std::string const&) { return false; },
[&](BuiltinConstant const& _constant) {
if (auto* generator = util::valueOrNullptr(builtinConstantTypeGenerators, _constant))
{ {
setType(current, (*generator)(env)); Type result = type(args->first);
return true; unify(result, type(args->second));
} setType(term, result);
return false; return;
} }*/
}, _constant.name); Type resultType = typeSystem.freshTypeVariable({});
if (!assigned) unify(helper.functionType(type(*_app.argument), resultType), type(*_app.expression));
setType(current, typeSystem.freshTypeVariable({})); setType(term, resultType);
} },
}, current); [&](Lambda const& _lambda)
solAssert(checked(current));
if (auto declaration = termBase(current).declaration)
{ {
setType(term, helper.functionType(type(*_lambda.argument), type(*_lambda.value)));
},
[&](InlineAssembly const& _inlineAssembly)
{
// TODO
(void)_inlineAssembly;
setType(term, typeSystem.type(PrimitiveType::Unit, {}));
},
[&](Reference const& _reference)
{
Type result = typeSystem.freshTypeVariable({});
if ( if (
auto [it, newlyInserted] = declarationTypes.emplace(*declaration, type(current)); auto [it, newlyInserted] = declarationTypes.emplace(_reference.index, result);
!newlyInserted !newlyInserted
) )
unify(type(current), it->second); unify(result, it->second);
setType(term, result);
},
[&](Constant const& _constant)
{
bool assigned = std::visit(util::GenericVisitor{
[&](std::string const&) { return false; },
[&](BuiltinConstant const& _constant) {
if (auto generator = builtinConstantInfo(_constant).builtinType)
{
setType(term, (*generator).realize(env, unify));
return true;
}
return false;
}
}, _constant.name);
if (!assigned)
setType(term, typeSystem.freshTypeVariable({}));
} }
} }, term.get());
}; solAssert(!holds_alternative<std::monostate>(type(term)));
for(auto& info: _ast.typeDefinitions | ranges::view::values)
{
if (info.arguments)
checkTerm(*info.arguments);
if (info.value)
checkTerm(*info.value);
checkTerm(*info.declaration);
}
for(auto& info: _ast.contracts | ranges::view::values)
for(auto& function: info.functions | ranges::view::values)
{
checkTerm(*function.returnType);
ScopedSaveAndRestore returnType{currentReturn, std::make_unique<TVar>(TVar{env,type(*function.returnType)})};
checkTerm(*function.function);
checkTerm(*function.arguments);
// TODO: unify stuff?
}
for(auto&& info: _ast.functions | ranges::view::values)
{
checkTerm(*info.returnType);
ScopedSaveAndRestore returnType{currentReturn, std::make_unique<TVar>(TVar{env,type(*info.returnType)})};
checkTerm(*info.function);
checkTerm(*info.arguments);
// TODO: unify stuff
} }
std::cout << astPrinter(_ast, &env) << std::endl; std::cout << astPrinter(_ast, &env) << std::endl;

View File

@ -34,17 +34,16 @@ struct TermBase
langutil::SourceLocation location; langutil::SourceLocation location;
std::optional<int64_t> legacyId; std::optional<int64_t> legacyId;
Type type; Type type;
std::optional<size_t> declaration;
}; };
struct Application; struct Application;
struct Lambda; struct Lambda;
struct InlineAssembly; struct InlineAssembly;
struct VariableDeclaration;
struct Reference: TermBase struct Reference: TermBase
{ {
size_t index = std::numeric_limits<size_t>::max(); size_t index = std::numeric_limits<size_t>::max();
std::string name;
}; };
enum class BuiltinConstant enum class BuiltinConstant
@ -54,9 +53,19 @@ enum class BuiltinConstant
Pair, Pair,
Fun, Fun,
Constrain, Constrain,
Return, NamedTerm,
TypeDeclaration,
TypeDefinition,
TypeClassDefinition,
TypeClassInstantiation,
FunctionDeclaration,
FunctionDefinition,
ContractDefinition,
VariableDeclaration,
VariableDefinition,
Block, Block,
Statement, ReturnStatement,
RegularStatement,
ChainStatements, ChainStatements,
Assign, Assign,
MemberAccess, MemberAccess,
@ -75,7 +84,7 @@ struct Constant: TermBase
std::variant<std::string, BuiltinConstant> name; std::variant<std::string, BuiltinConstant> name;
}; };
using Term = std::variant<Application, Lambda, InlineAssembly, VariableDeclaration, Reference, Constant>; using Term = std::variant<Application, Lambda, InlineAssembly, Reference, Constant>;
struct InlineAssembly: TermBase struct InlineAssembly: TermBase
{ {
@ -83,12 +92,6 @@ struct InlineAssembly: TermBase
std::map<yul::Identifier const*, std::unique_ptr<Term>> references; std::map<yul::Identifier const*, std::unique_ptr<Term>> references;
}; };
struct VariableDeclaration: TermBase
{
std::unique_ptr<Term> namePattern;
std::unique_ptr<Term> initialValue;
};
struct Application: TermBase struct Application: TermBase
{ {
std::unique_ptr<Term> expression; std::unique_ptr<Term> expression;
@ -145,41 +148,11 @@ langutil::SourceLocation locationOf(T const& _t)
struct AST struct AST
{ {
struct FunctionInfo { std::map<frontend::TypeDefinition const*, std::unique_ptr<Term>, ASTCompareByID<frontend::TypeDefinition>> typeDefinitions;
std::unique_ptr<Term> function; std::map<frontend::TypeClassDefinition const*, std::unique_ptr<Term>, ASTCompareByID<frontend::TypeClassDefinition>> typeClasses;
std::unique_ptr<Term> arguments; std::map<frontend::TypeClassInstantiation const*, std::unique_ptr<Term>, ASTCompareByID<frontend::TypeClassInstantiation>> typeClassInstantiations;
std::unique_ptr<Term> returnType; std::map<frontend::FunctionDefinition const*, std::unique_ptr<Term>, ASTCompareByID<frontend::FunctionDefinition>> functions;
}; std::map<frontend::ContractDefinition const*, std::unique_ptr<Term>, ASTCompareByID<frontend::ContractDefinition>> contracts;
std::map<frontend::FunctionDefinition const*, FunctionInfo, ASTCompareByID<frontend::FunctionDefinition>> functions;
struct ContractInfo {
std::map<std::string, FunctionInfo> functions;
};
std::map<frontend::ContractDefinition const*, ContractInfo, ASTCompareByID<frontend::ContractDefinition>> contracts;
struct TypeInformation {
std::unique_ptr<Term> declaration;
std::unique_ptr<Term> arguments;
std::unique_ptr<Term> value;
};
std::map<frontend::TypeDefinition const*, TypeInformation, ASTCompareByID<frontend::TypeDefinition>> typeDefinitions;
struct TypeClassInformation {
std::unique_ptr<Term> declaration;
std::unique_ptr<Term> typeVariable;
std::map<std::string, FunctionInfo> functions;
};
struct TypeClassInstantiationInformation {
std::unique_ptr<Term> typeConstructor;
std::unique_ptr<Term> argumentSorts;
std::unique_ptr<Term> typeClass;
std::map<std::string, FunctionInfo> functions;
};
std::map<frontend::TypeClassDefinition const*, TypeClassInformation, ASTCompareByID<frontend::TypeClassDefinition>> typeClasses;
std::map<frontend::TypeClassInstantiation const*, TypeClassInstantiationInformation, ASTCompareByID<frontend::TypeClassInstantiation>> typeClassInstantiations;
struct DeclarationInfo
{
Term const* target = nullptr;
std::string name;
};
std::vector<DeclarationInfo> declarations;
}; };
} }

View File

@ -37,45 +37,6 @@ using namespace solidity::langutil;
using namespace solidity::frontend; using namespace solidity::frontend;
using namespace solidity::frontend::experimental; using namespace solidity::frontend::experimental;
/*std::optional<TypeConstructor> experimental::typeConstructorFromTypeName(Analysis const& _analysis, TypeName const& _typeName)
{
if (auto const* elementaryTypeName = dynamic_cast<ElementaryTypeName const*>(&_typeName))
{
if (auto constructor = typeConstructorFromToken(_analysis, elementaryTypeName->typeName().token()))
return *constructor;
}
else if (auto const* userDefinedType = dynamic_cast<UserDefinedTypeName const*>(&_typeName))
{
if (auto const* referencedDeclaration = userDefinedType->pathNode().annotation().referencedDeclaration)
return _analysis.annotation<TypeRegistration>(*referencedDeclaration).typeConstructor;
}
return nullopt;
}*/
/*
std::optional<TypeConstructor> experimental::typeConstructorFromToken(Analysis const& _analysis, langutil::Token _token)
{
TypeSystem const& typeSystem = _analysis.typeSystem();
switch(_token)
{
case Token::Void:
return typeSystem.builtinConstructor(BuiltinType::Void);
case Token::Fun:
return typeSystem.builtinConstructor(BuiltinType::Function);
case Token::Unit:
return typeSystem.builtinConstructor(BuiltinType::Unit);
case Token::Pair:
return typeSystem.builtinConstructor(BuiltinType::Pair);
case Token::Word:
return typeSystem.builtinConstructor(BuiltinType::Word);
case Token::Integer:
return typeSystem.builtinConstructor(BuiltinType::Integer);
case Token::Bool:
return typeSystem.builtinConstructor(BuiltinType::Bool);
default:
return nullopt;
}
}*/
std::optional<BuiltinClass> experimental::builtinClassFromToken(langutil::Token _token) std::optional<BuiltinClass> experimental::builtinClassFromToken(langutil::Token _token)
{ {
switch (_token) switch (_token)
@ -100,22 +61,7 @@ std::optional<BuiltinClass> experimental::builtinClassFromToken(langutil::Token
return nullopt; return nullopt;
} }
} }
/*
std::optional<TypeClass> experimental::typeClassFromTypeClassName(TypeClassName const& _typeClass)
{
return std::visit(util::GenericVisitor{
[&](ASTPointer<IdentifierPath> _path) -> optional<TypeClass> {
auto classDefinition = dynamic_cast<TypeClassDefinition const*>(_path->annotation().referencedDeclaration);
if (!classDefinition)
return nullopt;
return TypeClass{classDefinition};
},
[&](Token _token) -> optional<TypeClass> {
return typeClassFromToken(_token);
}
}, _typeClass.name());
}
*/
experimental::Type TypeSystemHelpers::tupleType(vector<Type> _elements) const experimental::Type TypeSystemHelpers::tupleType(vector<Type> _elements) const
{ {
if (_elements.empty()) if (_elements.empty())

View File

@ -18,16 +18,13 @@
#pragma once #pragma once
#include <libsolidity/experimental/ast/TypeSystem.h> #include <libsolidity/experimental/ast/TypeSystem.h>
#include <libsolidity/ast/ASTForward.h>
#include <liblangutil/Token.h> #include <liblangutil/Token.h>
namespace solidity::frontend::experimental namespace solidity::frontend::experimental
{ {
class Analysis; class Analysis;
enum class BuiltinClass; enum class BuiltinClass;
//std::optional<TypeConstructor> typeConstructorFromTypeName(Analysis const& _analysis, TypeName const& _typeName);
//std::optional<TypeConstructor> typeConstructorFromToken(Analysis const& _analysis, langutil::Token _token);
//std::optional<TypeClass> typeClassFromTypeClassName(TypeClassName const& _typeClass);
std::optional<BuiltinClass> builtinClassFromToken(langutil::Token _token); std::optional<BuiltinClass> builtinClassFromToken(langutil::Token _token);
struct TypeSystemHelpers struct TypeSystemHelpers

View File

@ -1,6 +1,8 @@
pragma experimental solidity; pragma experimental solidity;
type uint256 = word; type uint256 = word;
/*
type double(a) = (a,a);
instantiation uint256: + { instantiation uint256: + {
function add(x, y) -> uint256 { function add(x, y) -> uint256 {
@ -51,18 +53,22 @@ instantiation word: == {
function f(x:uint256->uint256,y:uint256) -> uint256 function f(x:uint256->uint256,y:uint256) -> uint256
{ {
return x(y); x(y);
} }
*/
function g(x:uint256) -> uint256 function g(x) -> word
{ {
return x; return x;
} }
contract C { contract C {
fallback() external { fallback() external {
let z : uint256;
let x : word; let x : word;
assembly { let y;
y = (x:word);
x = g(x);
/*assembly {
x := 0x10 x := 0x10
} }
let w = uint256.abs(x); let w = uint256.abs(x);
@ -75,7 +81,7 @@ contract C {
assembly { assembly {
mstore(0, y) mstore(0, y)
return(0, 32) return(0, 32)
} }*/
} }
} }
// ==== // ====