slightly broken

This commit is contained in:
Daniel Kirchner 2023-07-03 00:01:34 +02:00
parent 987278385e
commit 32bcbc6cfb
4 changed files with 163 additions and 59 deletions

View File

@ -37,6 +37,7 @@ using namespace solidity::frontend::experimental;
ASTTransform::ASTTransform(Analysis& _analysis): m_analysis(_analysis), m_errorReporter(_analysis.errorReporter()), m_ast(make_unique<AST>())
{
m_ast->nodeById.resize(_analysis.maxAstId() + 1, nullptr);
}
bool ASTTransform::visit(legacy::TypeDefinition const& _typeDefinition)
@ -187,6 +188,7 @@ unique_ptr<Term> ASTTransform::term(legacy::Assignment const& _assignment)
unique_ptr<Term> ASTTransform::term(legacy::Block const& _block)
{
SetNode setNode(*this, _block);
if (auto statements = ranges::fold_right_last(
_block.statements() | ranges::view::transform([&](auto stmt) { return term(*stmt); }) | ranges::view::move,
[&](auto stmt, auto acc) {

View File

@ -152,6 +152,12 @@ private:
{
_parent.m_currentNode = &_node;
_parent.m_currentLocation = _node.location();
// TODO: error robustness
ASTNode const*& node = _parent.m_ast->nodeById.at(static_cast<size_t>(_node.id()));
if (node)
solAssert(node == &_node);
else
node = &_node;
}
SetNode(ASTTransform& _parent, langutil::SourceLocation const& _location):
m_parent(_parent),

View File

@ -28,6 +28,7 @@
#include <range/v3/view/map.hpp>
#include <range/v3/view/reverse.hpp>
#include <range/v3/iterator/insert_iterators.hpp>
using namespace std;
using namespace solidity;
@ -37,6 +38,90 @@ using namespace solidity::frontend::experimental;
namespace
{
void destTuple(Term& _term, list<reference_wrapper<Term>>& _components)
{
if (auto const* app = get_if<Application>(&_term))
if (auto* nestedApp = get_if<Application>(app->expression.get()))
if (auto* constant = get_if<Constant>(nestedApp->expression.get()))
if (constant->name == variant<string, BuiltinConstant>{BuiltinConstant::Pair})
{
_components.emplace_back(*nestedApp->argument);
destTuple(*app->argument, _components);
return;
}
_components.emplace_back(_term);
return;
}
void destTuple(Term const& _term, list<reference_wrapper<Term const>>& _components)
{
if (auto const* app = get_if<Application>(&_term))
if (auto* nestedApp = get_if<Application>(app->expression.get()))
if (auto* constant = get_if<Constant>(nestedApp->expression.get()))
if (constant->name == variant<string, BuiltinConstant>{BuiltinConstant::Pair})
{
_components.emplace_back(*nestedApp->argument);
destTuple(*app->argument, _components);
return;
}
_components.emplace_back(_term);
return;
}
std::optional<Reference> destReference(Term const& _term)
{
if (auto const* ref = get_if<Reference>(&_term))
return *ref;
return nullopt;
}
struct Context
{
Context(Analysis& _analysis, AST& _ast): m_analysis(_analysis), m_ast(_ast)
{
m_env = &m_analysis.typeSystem().env();
}
Analysis& analysis() { return m_analysis; }
langutil::ErrorReporter& errorReporter() { return m_analysis.errorReporter(); }
TypeSystem& typeSystem() { return m_env->typeSystem(); }
TypeEnvironment& env() { return *m_env; }
void unify(Type, Type) {}
Type type(Term const&) const { return Type{};}
frontend::ASTNode const* referencedNode(Term const& _term)
{
if (auto ref = destReference(_term))
return m_ast.nodeById.at(ref->index);
else
errorReporter().typeError(0000_error, locationOf(_term), "Expected reference.");
return nullptr;
}
private:
Analysis& m_analysis;
AST& m_ast;
TypeEnvironment* m_env = nullptr;
};
struct ASTTranslation
{
using Arguments = vector<reference_wrapper<std::unique_ptr<Term>>>;
template<typename R, typename... Args>
ASTTranslation(R _f(Context&, Args...)): m_translation([f = _f](Context& _context, Arguments _arguments) -> std::unique_ptr<Term> {
return invoke(_context, std::move(_arguments), f, std::make_index_sequence<sizeof...(Args)>{});
}), m_numArguments(sizeof...(Args)) {}
private:
template<typename Generator, size_t... Is>
static std::unique_ptr<Term> invoke(Context& _context, Arguments _arguments, Generator const& _generator, std::index_sequence<Is...>)
{
if (_arguments.size() == sizeof...(Is))
return _generator(_context, _arguments[Is].get()...);
else
return {};
}
std::function<std::unique_ptr<Term>(Context&, Arguments)> m_translation;
size_t m_numArguments = 0;
};
struct TPat
{
using Unifier = std::function<void(Type, Type)>;
@ -93,46 +178,82 @@ struct BuiltinConstantInfo
{
std::string name;
std::optional<TPat> builtinType;
std::optional<ASTTranslation> translation;
};
[[maybe_unused]] BuiltinConstantInfo const& builtinConstantInfo(BuiltinConstant _constant)
{
using namespace pattern_ops;
static const TPat unit{PrimitiveType::Unit};
static const auto info = std::map<BuiltinConstant, BuiltinConstantInfo>{
{BuiltinConstant::Unit, {"Unit", unit}},
{BuiltinConstant::Pair, {"Pair", +[](TPat a, TPat b) { return a >> (b >> tuple(a,b)); }}},
{BuiltinConstant::Fun, {"Fun", +[](TPat a, TPat b) { return tuple(a,b) >> (a >> b); }}},
{BuiltinConstant::Constrain, {"Constrain", +[](TPat a) { return tuple(a,a) >> a; }}},
{BuiltinConstant::NamedTerm, {"NamedTerm", +[](TPat a) { return tuple(unit, a) >> a; /* TODO: (name, a) >> a */ }}},
{BuiltinConstant::TypeDeclaration, {"TypeDeclaration", nullopt}},
{BuiltinConstant::TypeDefinition, {"TypeDefinition", +[](TPat type, TPat args, TPat value) {
return tuple(type, args, value) >> (args >> type);
}}},
{BuiltinConstant::TypeClassDefinition, {"TypeClassDefinition", nullopt}},
{BuiltinConstant::TypeClassInstantiation, {"TypeClassInstantiation", nullopt}},
{BuiltinConstant::FunctionDeclaration, {"FunctionDeclaration", nullopt}},
{BuiltinConstant::Unit, {"Unit", unit, nullopt}},
{BuiltinConstant::Pair, {
"Pair",
+[](TPat a, TPat b) {
return a >> (b >> tuple(a,b));
},
nullopt
}},
{BuiltinConstant::Fun, {"Fun", +[](TPat a, TPat b) { return tuple(a,b) >> (a >> b); }, nullopt}},
{BuiltinConstant::Constrain, {
"Constrain",
+[](TPat a) { return tuple(a,a) >> a; },
+[](Context& _context, std::unique_ptr<Term>& _term, std::unique_ptr<Term>& _constraint) {
_context.unify(termBase(*_term).type, _context.type(*_constraint));
return std::move(_term);
}
}},
{BuiltinConstant::NamedTerm, {"NamedTerm", +[](TPat a) { return tuple(unit, a) >> a; /* TODO: (name, a) >> a */ }, nullopt}},
{BuiltinConstant::TypeDeclaration, {"TypeDeclaration", nullopt, nullopt}},
{BuiltinConstant::TypeDefinition, {
"TypeDefinition",
+[](TPat type, TPat args, TPat value) {
return tuple(type, args, value) >> (args >> type);
},
+[](Context& _context, std::unique_ptr<Term>& _name, std::unique_ptr<Term>& _args, std::unique_ptr<Term>& _definiens) -> std::unique_ptr<Term> {
if (auto const* declaration = dynamic_cast<frontend::Declaration const*>(_context.referencedNode(*_name)))
{
std::string canonicalName = declaration->name();
if (auto const* annotation = dynamic_cast<frontend::TypeDeclarationAnnotation const*>(&declaration->annotation()))
canonicalName = *annotation->canonicalName;
else
_context.errorReporter().typeError(0000_error, locationOf(*_name), "Expected type declaration.");
std::list<reference_wrapper<Term>> arguments;
destTuple(*_args, arguments);
TypeConstructor constructor = _context.typeSystem().declareTypeConstructor(declaration->name(), canonicalName, arguments.size(), declaration);
std::vector<Type> argumentTypes = arguments | ranges::view::transform([](Term& _term) { return termBase(_term).type; }) | ranges::to<std::vector<Type>>;
termBase(*_name).type = _context.typeSystem().type(constructor, argumentTypes);
}
else
_context.errorReporter().typeError(0000_error, locationOf(*_name), "Expected declaration.");
(void)_definiens;
return nullptr;
}
}},
{BuiltinConstant::TypeClassDefinition, {"TypeClassDefinition", nullopt, nullopt}},
{BuiltinConstant::TypeClassInstantiation, {"TypeClassInstantiation", nullopt, nullopt}},
{BuiltinConstant::FunctionDeclaration, {"FunctionDeclaration", nullopt, nullopt}},
{BuiltinConstant::FunctionDefinition, {"FunctionDefinition", +[](TPat a, TPat r) {
return tuple(a >> r, a, r, r) >> (a >> r);
}}},
}, nullopt}},
{BuiltinConstant::ContractDefinition, {"ContractDefinition", +[]() {
return tuple(unit, (unit >> unit)) >> unit;
}}},
{BuiltinConstant::VariableDeclaration, {"VariableDeclaration", +[](TPat a) { return a >> a; }}},
{BuiltinConstant::VariableDefinition, {"VariableDefinition", nullopt}},
{BuiltinConstant::Block, {"Block", +[](TPat a) { return a >> a; }}},
{BuiltinConstant::ReturnStatement, {"ReturnStatement", +[](TPat a) { return a >> a; }}},
{BuiltinConstant::RegularStatement, {"RegularStatement", +[](TPat a) { return a >> unit; }}},
{BuiltinConstant::ChainStatements, {"ChainStatements", +[](TPat a, TPat b) { return tuple(a,b) >> b; }}},
{BuiltinConstant::Assign, {"Assign", +[](TPat a) { return tuple(a,a) >> unit; }}},
{BuiltinConstant::MemberAccess, {"MemberAccess", nullopt}},
{BuiltinConstant::Mul, {"Mul", nullopt}},
{BuiltinConstant::Add, {"Add", nullopt}},
{BuiltinConstant::Void, {"Void", nullopt}},
{BuiltinConstant::Word, {"Word", PrimitiveType::Word}},
{BuiltinConstant::Integer, {"Integer", nullopt}},
{BuiltinConstant::Bool, {"Bool", nullopt}},
{BuiltinConstant::Undefined, {"Undefined", nullopt}},
{BuiltinConstant::Equal, {"Equal", nullopt}},
}, nullopt}},
{BuiltinConstant::VariableDeclaration, {"VariableDeclaration", +[](TPat a) { return a >> a; }, nullopt}},
{BuiltinConstant::VariableDefinition, {"VariableDefinition", nullopt, nullopt}},
{BuiltinConstant::Block, {"Block", +[](TPat a) { return a >> a; }, nullopt}},
{BuiltinConstant::ReturnStatement, {"ReturnStatement", +[](TPat a) { return a >> a; }, nullopt}},
{BuiltinConstant::RegularStatement, {"RegularStatement", +[](TPat a) { return a >> unit; }, nullopt}},
{BuiltinConstant::ChainStatements, {"ChainStatements", +[](TPat a, TPat b) { return tuple(a,b) >> b; }, nullopt}},
{BuiltinConstant::Assign, {"Assign", +[](TPat a) { return tuple(a,a) >> unit; }, nullopt}},
{BuiltinConstant::MemberAccess, {"MemberAccess", nullopt, nullopt}},
{BuiltinConstant::Mul, {"Mul", nullopt, nullopt}},
{BuiltinConstant::Add, {"Add", nullopt, nullopt}},
{BuiltinConstant::Void, {"Void", nullopt, nullopt}},
{BuiltinConstant::Word, {"Word", PrimitiveType::Word, nullopt}},
{BuiltinConstant::Integer, {"Integer", nullopt, nullopt}},
{BuiltinConstant::Bool, {"Bool", nullopt, nullopt}},
{BuiltinConstant::Undefined, {"Undefined", nullopt, nullopt}},
{BuiltinConstant::Equal, {"Equal", nullopt, nullopt}},
};
return info.at(_constant);
}
@ -184,22 +305,6 @@ optional<pair<reference_wrapper<Term const>, reference_wrapper<Term const>>> des
return nullopt;
}
void destTuple(Term const& _term, list<reference_wrapper<Term const>>& _components)
{
list<reference_wrapper<Term const>> components;
if (auto const* app = get_if<Application>(&_term))
if (auto* nestedApp = get_if<Application>(app->expression.get()))
if (auto* constant = get_if<Constant>(nestedApp->expression.get()))
if (constant->name == variant<string, BuiltinConstant>{BuiltinConstant::Pair})
{
_components.emplace_back(*nestedApp->argument);
destTuple(*app->argument, _components);
return;
}
_components.emplace_back(_term);
return;
}
Type type(Term const& _term)
{
return std::visit([](auto& term) { return term.type; }, _term);
@ -215,7 +320,7 @@ string colorize(string _color, string _string)
return _color + _string + util::formatting::RESET;
}
string termPrinter(AST& _ast, Term const& _term, TypeEnvironment* _env = nullptr, bool _sugarPairs = true, bool _sugarConsts = true, size_t _indent = 0)
string termPrinter(AST& _ast, Term const& _term, TypeEnvironment* _env = nullptr, bool _sugarPairs = false, bool _sugarConsts = false, size_t _indent = 0)
{
using namespace util::formatting;
auto recurse = [&](Term const& _next) { return termPrinter(_ast, _next, _env, _sugarPairs, _sugarConsts, _indent); };
@ -229,7 +334,7 @@ string termPrinter(AST& _ast, Term const& _term, TypeEnvironment* _env = nullptr
list<reference_wrapper<Term const>> components;
destTuple(_term, components);
std::string result = "(";
result += termPrinter(_ast, components.front(), _env);
result += recurse(components.front());
components.pop_front();
for (auto const& component: components)
{
@ -320,7 +425,7 @@ string termPrinter(AST& _ast, Term const& _term, TypeEnvironment* _env = nullptr
return result;
}
std::string astPrinter(AST& _ast, TypeEnvironment* _env = nullptr, bool _sugarPairs = true, bool _sugarConsts = true, size_t _indent = 0)
std::string astPrinter(AST& _ast, TypeEnvironment* _env = nullptr, bool _sugarPairs = true, bool _sugarConsts = false, size_t _indent = 0)
{
std::string result;
auto printTerm = [&](Term const& _term) { result += termPrinter(_ast, _term, _env, _sugarPairs, _sugarConsts, _indent + 1) + "\n\n"; };
@ -391,16 +496,6 @@ void TypeCheck::operator()(AST& _ast)
auto unify = [&](Type _a, Type _b) { unifyForTerm(_a, _b, &term.get()); };
std::visit(util::GenericVisitor{
[&](Application const& _app) {
/*if (auto* constant = get_if<Constant>(_app.expression.get()))
if (auto* builtin = get_if<BuiltinConstant>(&constant->name))
if (*builtin == BuiltinConstant::Constrain)
if (auto args = destPair(*_app.argument))
{
Type result = type(args->first);
unify(result, type(args->second));
setType(term, result);
return;
}*/
Type resultType = typeSystem.freshTypeVariable({});
unify(helper.functionType(type(*_app.argument), resultType), type(*_app.expression));
setType(term, resultType);

View File

@ -153,6 +153,7 @@ struct AST
std::map<frontend::TypeClassInstantiation const*, std::unique_ptr<Term>, ASTCompareByID<frontend::TypeClassInstantiation>> typeClassInstantiations;
std::map<frontend::FunctionDefinition const*, std::unique_ptr<Term>, ASTCompareByID<frontend::FunctionDefinition>> functions;
std::map<frontend::ContractDefinition const*, std::unique_ptr<Term>, ASTCompareByID<frontend::ContractDefinition>> contracts;
std::vector<frontend::ASTNode const*> nodeById;
};
}