Custom knowledge base.

This commit is contained in:
chriseth 2022-06-20 16:01:14 +02:00
parent f94b85e841
commit 696095c9b0
2 changed files with 206 additions and 34 deletions

View File

@ -27,13 +27,210 @@
#include <libyul/optimiser/DataFlowAnalyzer.h>
#include <libsolutil/CommonData.h>
#include <libsolutil/Visitor.h>
#include <libyul/AsmPrinter.h>
#include <libyul/backends/evm/EVMDialect.h>
#include <variant>
#include <functional>
#include <queue>
using namespace std;
using namespace solidity;
using namespace solidity::yul;
namespace
{
struct SumExpression;
SumExpression clean(SumExpression _in);
/**
* Expression of the form k0 + k1 * x2 + x2 * x2 + ...
* where the ki are u256 constants and the xi are variables.
* The constant term is using the empty yul string.
*/
struct SumExpression
{
static SumExpression variable(YulString _name, u256 _multiplicity = u256(1))
{
SumExpression result;
result.coefficients[_name] = move(_multiplicity);
return result;
}
static SumExpression constant(u256 _value)
{
return variable(YulString{}, move(_value));
}
optional<u256> isConstant() const
{
if (coefficients.empty())
return u256(0);
else if (coefficients.size() == 1 && coefficients.begin()->first == YulString{})
return coefficients.begin()->second;
else
return nullopt;
}
SumExpression operator+(SumExpression const& _other)
{
SumExpression result = *this;
for (auto&& [var, value]: _other.coefficients)
result.coefficients[var] += value;
return clean(move(result));
}
SumExpression operator*(u256 const& _factor) const
{
if (!_factor)
return SumExpression{};
if (_factor == 1)
return *this;
SumExpression result;
for (auto&& [var, value]: coefficients)
result.coefficients[var] = value * _factor;
return result;
}
map<YulString, u256> coefficients;
};
SumExpression clean(SumExpression _in)
{
SumExpression result;
for (auto&& [var, value]: _in.coefficients)
if (value)
result.coefficients[var] = move(value);
return result;
}
optional<SumExpression> operator+(optional<SumExpression> const& _a, optional<SumExpression> const& _b)
{
if (!_a || !_b)
return nullopt;
return *_a + *_b;
}
optional<SumExpression> operator-(optional<SumExpression> const& _a, optional<SumExpression> const& _b)
{
if (!_a || !_b)
return nullopt;
SumExpression result = *_a;
for (auto&& [var, value]: _b->coefficients)
result.coefficients[var] -= value;
return clean(move(result));
}
class SimpleLinearSolver
{
public:
static optional<u256> simplify(
EVMDialect const& _dialect,
std::function<AssignedValue const*(YulString)> _variableValues,
Expression const& _expr
)
{
SimpleLinearSolver solver(_dialect, _variableValues);
return solver.simplify(_expr);
}
private:
optional<u256> simplify(Expression const& _expr)
{
auto value = toSumExpression(_expr);
if (!value)
return nullopt;
while (true)
{
if (auto v = value->isConstant())
return *v;
// TODO this will depend on the sorting order of the variables. This is bad and needs to be fixed.
for (auto&& [var, value]: value->coefficients)
if (var != YulString{} && !m_expandedVariables.count(var) && !m_expandedFailedVariables.count(var))
m_variablesToExpand.push(var);
if (m_variablesToExpand.empty())
return nullopt;
YulString var = m_variablesToExpand.front();
m_variablesToExpand.pop();
expandVariable(var, *value);
}
}
private:
optional<SumExpression> toSumExpression(Expression const& _expr)
{
return std::visit(util::GenericVisitor{
[&](FunctionCall const& _funCall) -> optional<SumExpression> {
if (BuiltinFunctionForEVM const* builtin = m_dialect.builtin(_funCall.functionName.name))
{
if (builtin->instruction == evmasm::Instruction::ADD)
return toSumExpression(_funCall.arguments.at(0)) + toSumExpression(_funCall.arguments.at(1));
else if (builtin->instruction == evmasm::Instruction::SUB)
return toSumExpression(_funCall.arguments.at(0)) - toSumExpression(_funCall.arguments.at(1));
else
return std::nullopt;
// TODO we could also use multiplication by constants.
}
return std::nullopt;
},
[&](Identifier const& _identifier) -> optional<SumExpression> {
if (m_expandedVariables.count(_identifier.name))
return m_expandedVariables.at(_identifier.name);
else
return SumExpression::variable(_identifier.name);
},
[&](Literal const& _literal) -> optional<SumExpression> {
return SumExpression::constant(valueOfLiteral(_literal));
}
}, _expr);
}
void expandVariable(YulString _variable, SumExpression& _currentExpression)
{
if (m_expandedFailedVariables.count(_variable) || m_expandedVariables.count(_variable))
return;
if (auto assignedValue = m_variableValues(_variable))
if (assignedValue->value)
if (auto newValue = toSumExpression(*assignedValue->value))
{
// TODO this will be exponential.
for (auto& [variable, value]: m_expandedVariables)
expandInExpression(value, _variable, *newValue);
expandInExpression(_currentExpression, _variable, *newValue);
m_expandedVariables[_variable] = move(*newValue);
return;
}
m_expandedFailedVariables.insert(_variable);
}
void expandInExpression(SumExpression& _expr, YulString _variable, SumExpression const& _value)
{
if (!_expr.coefficients.count(_variable))
return;
u256 coefficient = _expr.coefficients[_variable];
_expr.coefficients.erase(_variable);
_expr = _expr + _value * coefficient;
}
SimpleLinearSolver(
EVMDialect const& _dialect,
std::function<AssignedValue const*(YulString)> _variableValues
): m_dialect(_dialect), m_variableValues(_variableValues)
{}
EVMDialect const& m_dialect;
std::function<AssignedValue const*(YulString)> m_variableValues;
/// Queue of variables we can still expand in the future.
queue<YulString> m_variablesToExpand;
/// Set of variables we expanded in the past and we should directly expand when we
/// encounter them when expanding other variables.
map<YulString, SumExpression> m_expandedVariables;
/// Set of variables we should not expand because their expansion is not linear.
set<YulString> m_expandedFailedVariables;
};
}
bool KnowledgeBase::knownToBeDifferent(YulString _a, YulString _b)
{
// Try to use the simplification rules together with the
@ -43,30 +240,19 @@ bool KnowledgeBase::knownToBeDifferent(YulString _a, YulString _b)
if (optional<u256> difference = differenceIfKnownConstant(_a, _b))
return difference != 0;
Expression expr2 = simplify(FunctionCall{{}, {{}, "eq"_yulstring}, util::make_vector<Expression>(Identifier{{}, _a}, Identifier{{}, _b})});
if (holds_alternative<Literal>(expr2))
return valueOfLiteral(std::get<Literal>(expr2)) == 0;
// TOOD this is not possible anymore.
// Expression expr2 = simplify(FunctionCall{{}, {{}, "eq"_yulstring}, util::make_vector<Expression>(Identifier{{}, _a}, Identifier{{}, _b})});
return false;
}
optional<u256> KnowledgeBase::differenceIfKnownConstant(YulString _a, YulString _b)
{
// Try to use the simplification rules together with the
// current values to turn `sub(_a, _b)` into a constant.
Expression expr1 = simplify(FunctionCall{{}, {{}, "sub"_yulstring}, util::make_vector<Expression>(Identifier{{}, _a}, Identifier{{}, _b})});
if (Literal const* value = get_if<Literal>(&expr1))
return valueOfLiteral(*value);
return {};
return simplify(FunctionCall{{}, {{}, "sub"_yulstring}, util::make_vector<Expression>(Identifier{{}, _a}, Identifier{{}, _b})});
}
bool KnowledgeBase::knownToBeDifferentByAtLeast32(YulString _a, YulString _b)
{
// Try to use the simplification rules together with the
// current values to turn `sub(_a, _b)` into a constant whose absolute value is at least 32.
if (optional<u256> difference = differenceIfKnownConstant(_a, _b))
return difference >= 32 && difference <= u256(0) - 32;
@ -86,23 +272,10 @@ optional<u256> KnowledgeBase::valueIfKnownConstant(YulString _a)
return {};
}
Expression KnowledgeBase::simplify(Expression _expression)
optional<u256> KnowledgeBase::simplify(Expression _expression)
{
m_counter = 0;
return simplifyRecursively(move(_expression));
}
Expression KnowledgeBase::simplifyRecursively(Expression _expression)
{
if (m_counter++ > 100)
return _expression;
if (holds_alternative<FunctionCall>(_expression))
for (Expression& arg: std::get<FunctionCall>(_expression).arguments)
arg = simplifyRecursively(arg);
if (auto match = SimplificationRules::findFirstMatch(_expression, m_dialect, m_variableValues))
return simplifyRecursively(match->action().toExpression(debugDataOf(_expression)));
return _expression;
if (auto dialect = dynamic_cast<EVMDialect const*>(&m_dialect))
return SimpleLinearSolver::simplify(*dialect, m_variableValues, _expression);
else
return nullopt;
}

View File

@ -58,8 +58,7 @@ public:
std::optional<u256> valueIfKnownConstant(YulString _a);
private:
Expression simplify(Expression _expression);
Expression simplifyRecursively(Expression _expression);
std::optional<u256> simplify(Expression _expression);
Dialect const& m_dialect;
std::function<AssignedValue const*(YulString)> m_variableValues;