From 696095c9b00fd13332a9d3a577202bf5b21fe90a Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 20 Jun 2022 16:01:14 +0200 Subject: [PATCH] Custom knowledge base. --- libyul/optimiser/KnowledgeBase.cpp | 237 +++++++++++++++++++++++++---- libyul/optimiser/KnowledgeBase.h | 3 +- 2 files changed, 206 insertions(+), 34 deletions(-) diff --git a/libyul/optimiser/KnowledgeBase.cpp b/libyul/optimiser/KnowledgeBase.cpp index 460f6707f..6a1239e34 100644 --- a/libyul/optimiser/KnowledgeBase.cpp +++ b/libyul/optimiser/KnowledgeBase.cpp @@ -27,13 +27,210 @@ #include #include +#include + +#include +#include #include +#include +#include using namespace std; using namespace solidity; using namespace solidity::yul; +namespace +{ +struct SumExpression; +SumExpression clean(SumExpression _in); + +/** + * Expression of the form k0 + k1 * x1 + k2 * x2 + ... + * where the ki are u256 constants and the xi are variables. + * The constant term k0 is stored under the empty YulString. + */ +struct SumExpression +{ + static SumExpression variable(YulString _name, u256 _multiplicity = u256(1)) + { + SumExpression result; + result.coefficients[_name] = move(_multiplicity); + return result; + } + static SumExpression constant(u256 _value) + { + return variable(YulString{}, move(_value)); + } + optional isConstant() const + { + if (coefficients.empty()) + return u256(0); + else if (coefficients.size() == 1 && coefficients.begin()->first == YulString{}) + return coefficients.begin()->second; + else + return nullopt; + } + SumExpression operator+(SumExpression const& _other) + { + SumExpression result = *this; + for (auto&& [var, value]: _other.coefficients) + result.coefficients[var] += value; + return clean(move(result)); + } + SumExpression operator*(u256 const& _factor) const + { + if (!_factor) + return SumExpression{}; + if (_factor == 1) + return *this; + + SumExpression result; + for (auto&& [var, value]: coefficients) + result.coefficients[var] = value * _factor; + return result; + } + + map 
coefficients; +}; + +SumExpression clean(SumExpression _in) +{ + SumExpression result; + for (auto&& [var, value]: _in.coefficients) + if (value) + result.coefficients[var] = move(value); + return result; +} + +optional operator+(optional const& _a, optional const& _b) +{ + if (!_a || !_b) + return nullopt; + return *_a + *_b; +} +optional operator-(optional const& _a, optional const& _b) +{ + if (!_a || !_b) + return nullopt; + SumExpression result = *_a; + for (auto&& [var, value]: _b->coefficients) + result.coefficients[var] -= value; + return clean(move(result)); +} + +class SimpleLinearSolver +{ +public: + static optional simplify( + EVMDialect const& _dialect, + std::function _variableValues, + Expression const& _expr + ) + { + SimpleLinearSolver solver(_dialect, _variableValues); + return solver.simplify(_expr); + } + +private: + optional simplify(Expression const& _expr) + { + auto value = toSumExpression(_expr); + if (!value) + return nullopt; + + while (true) + { + if (auto v = value->isConstant()) + return *v; + // TODO this will depend on the sorting order of the variables. This is bad and needs to be fixed. 
+ for (auto&& [var, value]: value->coefficients) + if (var != YulString{} && !m_expandedVariables.count(var) && !m_expandedFailedVariables.count(var)) + m_variablesToExpand.push(var); + if (m_variablesToExpand.empty()) + return nullopt; + YulString var = m_variablesToExpand.front(); + m_variablesToExpand.pop(); + expandVariable(var, *value); + } + } + +private: + optional toSumExpression(Expression const& _expr) + { + return std::visit(util::GenericVisitor{ + [&](FunctionCall const& _funCall) -> optional { + if (BuiltinFunctionForEVM const* builtin = m_dialect.builtin(_funCall.functionName.name)) + { + if (builtin->instruction == evmasm::Instruction::ADD) + return toSumExpression(_funCall.arguments.at(0)) + toSumExpression(_funCall.arguments.at(1)); + else if (builtin->instruction == evmasm::Instruction::SUB) + return toSumExpression(_funCall.arguments.at(0)) - toSumExpression(_funCall.arguments.at(1)); + else + return std::nullopt; + // TODO we could also use multiplication by constants. + } + return std::nullopt; + }, + [&](Identifier const& _identifier) -> optional { + if (m_expandedVariables.count(_identifier.name)) + return m_expandedVariables.at(_identifier.name); + else + return SumExpression::variable(_identifier.name); + }, + [&](Literal const& _literal) -> optional { + return SumExpression::constant(valueOfLiteral(_literal)); + } + }, _expr); + } + + void expandVariable(YulString _variable, SumExpression& _currentExpression) + { + if (m_expandedFailedVariables.count(_variable) || m_expandedVariables.count(_variable)) + return; + if (auto assignedValue = m_variableValues(_variable)) + if (assignedValue->value) + if (auto newValue = toSumExpression(*assignedValue->value)) + { + // TODO this will be exponential. 
+ for (auto& [variable, value]: m_expandedVariables) + expandInExpression(value, _variable, *newValue); + expandInExpression(_currentExpression, _variable, *newValue); + m_expandedVariables[_variable] = move(*newValue); + return; + } + m_expandedFailedVariables.insert(_variable); + } + + void expandInExpression(SumExpression& _expr, YulString _variable, SumExpression const& _value) + { + if (!_expr.coefficients.count(_variable)) + return; + u256 coefficient = _expr.coefficients[_variable]; + _expr.coefficients.erase(_variable); + _expr = _expr + _value * coefficient; + } + + SimpleLinearSolver( + EVMDialect const& _dialect, + std::function _variableValues + ): m_dialect(_dialect), m_variableValues(_variableValues) + {} + + EVMDialect const& m_dialect; + std::function m_variableValues; + + /// Queue of variables we can still expand in the future. + queue m_variablesToExpand; + /// Variables we expanded in the past, mapped to their expansions, which we directly substitute when we + /// encounter them when expanding other variables. + map m_expandedVariables; + /// Set of variables we should not expand because their expansion is not linear. + set m_expandedFailedVariables; +}; + +} + bool KnowledgeBase::knownToBeDifferent(YulString _a, YulString _b) { // Try to use the simplification rules together with the @@ -43,30 +240,19 @@ bool KnowledgeBase::knownToBeDifferent(YulString _a, YulString _b) if (optional difference = differenceIfKnownConstant(_a, _b)) return difference != 0; - Expression expr2 = simplify(FunctionCall{{}, {{}, "eq"_yulstring}, util::make_vector(Identifier{{}, _a}, Identifier{{}, _b})}); - if (holds_alternative(expr2)) - return valueOfLiteral(std::get(expr2)) == 0; + // TODO this is not possible anymore. 
+ // Expression expr2 = simplify(FunctionCall{{}, {{}, "eq"_yulstring}, util::make_vector(Identifier{{}, _a}, Identifier{{}, _b})}); return false; } optional KnowledgeBase::differenceIfKnownConstant(YulString _a, YulString _b) { - // Try to use the simplification rules together with the - // current values to turn `sub(_a, _b)` into a constant. - - Expression expr1 = simplify(FunctionCall{{}, {{}, "sub"_yulstring}, util::make_vector(Identifier{{}, _a}, Identifier{{}, _b})}); - if (Literal const* value = get_if(&expr1)) - return valueOfLiteral(*value); - - return {}; + return simplify(FunctionCall{{}, {{}, "sub"_yulstring}, util::make_vector(Identifier{{}, _a}, Identifier{{}, _b})}); } bool KnowledgeBase::knownToBeDifferentByAtLeast32(YulString _a, YulString _b) { - // Try to use the simplification rules together with the - // current values to turn `sub(_a, _b)` into a constant whose absolute value is at least 32. - if (optional difference = differenceIfKnownConstant(_a, _b)) return difference >= 32 && difference <= u256(0) - 32; @@ -86,23 +272,10 @@ optional KnowledgeBase::valueIfKnownConstant(YulString _a) return {}; } -Expression KnowledgeBase::simplify(Expression _expression) +optional KnowledgeBase::simplify(Expression _expression) { - m_counter = 0; - return simplifyRecursively(move(_expression)); -} - -Expression KnowledgeBase::simplifyRecursively(Expression _expression) -{ - if (m_counter++ > 100) - return _expression; - - if (holds_alternative(_expression)) - for (Expression& arg: std::get(_expression).arguments) - arg = simplifyRecursively(arg); - - if (auto match = SimplificationRules::findFirstMatch(_expression, m_dialect, m_variableValues)) - return simplifyRecursively(match->action().toExpression(debugDataOf(_expression))); - - return _expression; + if (auto dialect = dynamic_cast(&m_dialect)) + return SimpleLinearSolver::simplify(*dialect, m_variableValues, _expression); + else + return nullopt; } diff --git a/libyul/optimiser/KnowledgeBase.h 
b/libyul/optimiser/KnowledgeBase.h index 999d0e312..e0c0cd4cd 100644 --- a/libyul/optimiser/KnowledgeBase.h +++ b/libyul/optimiser/KnowledgeBase.h @@ -58,8 +58,7 @@ public: std::optional valueIfKnownConstant(YulString _a); private: - Expression simplify(Expression _expression); - Expression simplifyRecursively(Expression _expression); + std::optional simplify(Expression _expression); Dialect const& m_dialect; std::function m_variableValues;