This commit is contained in:
chriseth 2022-03-01 23:49:03 +01:00
parent 797651c74b
commit b4dd0420ca
7 changed files with 141 additions and 43 deletions

View File

@ -97,31 +97,53 @@ void BooleanLPSolver::declareVariable(string const& _name, SortPointer const& _s
void BooleanLPSolver::addAssertion(Expression const& _expr) void BooleanLPSolver::addAssertion(Expression const& _expr)
{ {
solAssert(_expr.sort->kind == Kind::Bool);
if (_expr.arguments.empty()) if (_expr.arguments.empty())
state().clauses.emplace_back(Clause{vector<Literal>{*parseLiteral(_expr)}}); {
solAssert(isBooleanVariable(_expr.name));
state().clauses.emplace_back(Clause{Literal{
true,
state().variables.at(_expr.name)
}});
}
else if (_expr.name == "=") else if (_expr.name == "=")
{ {
// Try to see if both sides are linear. solAssert(_expr.arguments.at(0).sort->kind == _expr.arguments.at(1).sort->kind);
optional<LinearExpression> left = parseLinearSum(_expr.arguments.at(0)); if (_expr.arguments.at(0).sort->kind == Kind::Bool)
optional<LinearExpression> right = parseLinearSum(_expr.arguments.at(1));
if (left && right)
{ {
LinearExpression data = *left - *right; if (_expr.arguments.at(0).arguments.empty() && isBooleanVariable(_expr.arguments.at(0).name))
data[0] *= -1; addBooleanEquality(*parseLiteral(_expr.arguments.at(0)), _expr.arguments.at(1));
Constraint c{move(data), _expr.name == "=", {}}; else if (_expr.arguments.at(1).arguments.empty() && isBooleanVariable(_expr.arguments.at(1).name))
if (!tryAddDirectBounds(c)) addBooleanEquality(*parseLiteral(_expr.arguments.at(1)), _expr.arguments.at(0));
state().fixedConstraints.emplace_back(move(c)); else
{
Literal newBoolean = *parseLiteral(declareInternalVariable(true));
addBooleanEquality(newBoolean, _expr.arguments.at(0));
addBooleanEquality(newBoolean, _expr.arguments.at(1));
}
}
else if (_expr.arguments.at(0).sort->kind == Kind::Int)
{
// Try to see if both sides are linear.
optional<LinearExpression> left = parseLinearSum(_expr.arguments.at(0));
optional<LinearExpression> right = parseLinearSum(_expr.arguments.at(1));
if (left && right)
{
LinearExpression data = *left - *right;
data[0] *= -1;
Constraint c{move(data), _expr.name == "=", {}};
if (!tryAddDirectBounds(c))
state().fixedConstraints.emplace_back(move(c));
}
else
{
Expression left = parseLinearSumOrReturnEqualVariable(_expr.arguments.at(0));
Expression right = parseLinearSumOrReturnEqualVariable(_expr.arguments.at(1));
addAssertion(left == right);
}
} }
else if (_expr.arguments.at(0).arguments.empty() && isBooleanVariable(_expr.arguments.at(0).name))
addBooleanEquality(*parseLiteral(_expr.arguments.at(0)), _expr.arguments.at(1));
else if (_expr.arguments.at(1).arguments.empty() && isBooleanVariable(_expr.arguments.at(1).name))
addBooleanEquality(*parseLiteral(_expr.arguments.at(1)), _expr.arguments.at(0));
else else
{ solAssert(false);
Literal newBoolean = *parseLiteral(declareInternalBoolean());
addBooleanEquality(newBoolean, _expr.arguments.at(0));
addBooleanEquality(newBoolean, _expr.arguments.at(1));
}
} }
else if (_expr.name == "and") else if (_expr.name == "and")
{ {
@ -137,7 +159,7 @@ void BooleanLPSolver::addAssertion(Expression const& _expr)
{ {
// We cannot have more than one constraint per clause. // We cannot have more than one constraint per clause.
// TODO Why? // TODO Why?
right = *parseLiteral(declareInternalBoolean()); right = *parseLiteral(declareInternalVariable(true));
addBooleanEquality(right, _expr.arguments.at(1)); addBooleanEquality(right, _expr.arguments.at(1));
} }
state().clauses.emplace_back(Clause{vector<Literal>{left, right}}); state().clauses.emplace_back(Clause{vector<Literal>{left, right}});
@ -152,18 +174,20 @@ void BooleanLPSolver::addAssertion(Expression const& _expr)
{ {
optional<LinearExpression> left = parseLinearSum(_expr.arguments.at(0)); optional<LinearExpression> left = parseLinearSum(_expr.arguments.at(0));
optional<LinearExpression> right = parseLinearSum(_expr.arguments.at(1)); optional<LinearExpression> right = parseLinearSum(_expr.arguments.at(1));
if (!left || !right) if (left && right)
{ {
cout << "Unable to parse expression" << endl; LinearExpression data = *left - *right;
// TODO fail in some way data[0] *= -1;
return; Constraint c{move(data), _expr.name == "=", {}};
if (!tryAddDirectBounds(c))
state().fixedConstraints.emplace_back(move(c));
}
else
{
Expression left = parseLinearSumOrReturnEqualVariable(_expr.arguments.at(0));
Expression right = parseLinearSumOrReturnEqualVariable(_expr.arguments.at(1));
addAssertion(left <= right);
} }
LinearExpression data = *left - *right;
data[0] *= -1;
Constraint c{move(data), _expr.name == "=", {}};
if (!tryAddDirectBounds(c))
state().fixedConstraints.emplace_back(move(c));
} }
else if (_expr.name == ">=") else if (_expr.name == ">=")
addAssertion(_expr.arguments.at(1) <= _expr.arguments.at(0)); addAssertion(_expr.arguments.at(1) <= _expr.arguments.at(0));
@ -275,11 +299,11 @@ string BooleanLPSolver::toString() const
return result; return result;
} }
Expression BooleanLPSolver::declareInternalBoolean() Expression BooleanLPSolver::declareInternalVariable(bool _boolean)
{ {
string name = "$" + to_string(state().variables.size() + 1); string name = "$" + to_string(state().variables.size() + 1);
declareVariable(name, true); declareVariable(name, _boolean);
return smtutil::Expression(name, {}, SortProvider::boolSort); return smtutil::Expression(name, {}, _boolean ? SortProvider::boolSort : SortProvider::uintSort);
} }
void BooleanLPSolver::declareVariable(string const& _name, bool _boolean) void BooleanLPSolver::declareVariable(string const& _name, bool _boolean)
@ -350,17 +374,42 @@ Literal BooleanLPSolver::negate(Literal const& _lit)
Literal BooleanLPSolver::parseLiteralOrReturnEqualBoolean(Expression const& _expr) Literal BooleanLPSolver::parseLiteralOrReturnEqualBoolean(Expression const& _expr)
{ {
// TODO hen can this fail? solAssert(_expr.sort->kind == Kind::Bool);
// TODO when can this fail?
if (optional<Literal> literal = parseLiteral(_expr)) if (optional<Literal> literal = parseLiteral(_expr))
return *literal; return *literal;
else else
{ {
Literal newBoolean = *parseLiteral(declareInternalBoolean()); Literal newBoolean = *parseLiteral(declareInternalVariable(true));
addBooleanEquality(newBoolean, _expr); addBooleanEquality(newBoolean, _expr);
return newBoolean; return newBoolean;
} }
} }
Expression BooleanLPSolver::parseLinearSumOrReturnEqualVariable(Expression const& _expr)
{
solAssert(_expr.sort->kind == Kind::Int);
if (_expr.name == "ite")
{
Literal condition = parseLiteralOrReturnEqualBoolean(_expr.arguments.at(0));
// TODO this adds too many variables
Expression conditionBoolean = declareInternalVariable(true);
addBooleanEquality(condition, conditionBoolean);
Expression trueValue = parseLinearSumOrReturnEqualVariable(_expr.arguments.at(1));
Expression falseValue = parseLinearSumOrReturnEqualVariable(_expr.arguments.at(2));
Expression result = declareInternalVariable(false);
addAssertion(conditionBoolean || (result == trueValue));
addAssertion(!conditionBoolean || (result == falseValue));
return result;
}
if (_expr.arguments.empty())
return _expr;
solAssert(parseLinearSum(_expr));
Expression result = declareInternalVariable(false);
addAssertion(result == _expr);
return result;
}
optional<LinearExpression> BooleanLPSolver::parseLinearSum(smtutil::Expression const& _expr) const optional<LinearExpression> BooleanLPSolver::parseLinearSum(smtutil::Expression const& _expr) const
{ {
if (_expr.arguments.empty() || _expr.name == "*") if (_expr.arguments.empty() || _expr.name == "*")
@ -374,7 +423,12 @@ optional<LinearExpression> BooleanLPSolver::parseLinearSum(smtutil::Expression c
return _expr.name == "+" ? *left + *right : *left - *right; return _expr.name == "+" ? *left + *right : *left - *right;
} }
else else
{
// TOOD This should just resort to parseLinearSumOrReturn...
// and then use that variable
cout << "Invalid operator " << _expr.name << endl;
return std::nullopt; return std::nullopt;
}
} }
optional<LinearExpression> BooleanLPSolver::parseProduct(smtutil::Expression const& _expr) const optional<LinearExpression> BooleanLPSolver::parseProduct(smtutil::Expression const& _expr) const
@ -477,6 +531,7 @@ size_t BooleanLPSolver::addConditionalConstraint(Constraint _constraint)
void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expression const& _right) void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expression const& _right)
{ {
solAssert(_right.sort->kind == Kind::Bool);
if (optional<Literal> right = parseLiteral(_right)) if (optional<Literal> right = parseLiteral(_right))
{ {
// includes: not, <=, <, >=, >, boolean variables. // includes: not, <=, <, >=, >, boolean variables.
@ -501,7 +556,7 @@ void BooleanLPSolver::addBooleanEquality(Literal const& _left, smtutil::Expressi
{ {
// We cannot have more than one constraint per clause. // We cannot have more than one constraint per clause.
// TODO Why? // TODO Why?
b = *parseLiteral(declareInternalBoolean()); b = *parseLiteral(declareInternalVariable(true));
addBooleanEquality(b, _right.arguments.at(1)); addBooleanEquality(b, _right.arguments.at(1));
} }

View File

@ -47,12 +47,14 @@ struct State
}; };
/** /**
* Component that satisfies the SMT SolverInterface and uses an LP solver plus the DPLL * Component that satisfies the SMT SolverInterface and uses an LP solver plus the CDCL
* algorithm internally. * algorithm internally.
* It uses a rational relaxation of the integer program and thus will not be able to answer * It uses a rational relaxation of the integer program and thus will not be able to answer
* "satisfiable", but its answers are still correct. * "satisfiable", but its answers are still correct.
* *
* TODO are integers always non-negative? * Contrary to the usual SMT type system, it adds an implicit constraint for all variables
* and sub-expressions to be non-negative.
* TODO this does not apply to e.g. `x + y - something`
* *
* Integers are unbounded. * Integers are unbounded.
*/ */
@ -77,13 +79,15 @@ public:
private: private:
using rational = boost::rational<bigint>; using rational = boost::rational<bigint>;
smtutil::Expression declareInternalBoolean(); smtutil::Expression declareInternalVariable(bool _boolean);
void declareVariable(std::string const& _name, bool _boolean); void declareVariable(std::string const& _name, bool _boolean);
/// Parses an expression of sort bool and returns a literal.
std::optional<Literal> parseLiteral(smtutil::Expression const& _expr); std::optional<Literal> parseLiteral(smtutil::Expression const& _expr);
Literal negate(Literal const& _lit); Literal negate(Literal const& _lit);
Literal parseLiteralOrReturnEqualBoolean(smtutil::Expression const& _expr); Literal parseLiteralOrReturnEqualBoolean(smtutil::Expression const& _expr);
smtutil::Expression parseLinearSumOrReturnEqualVariable(smtutil::Expression const& _expr);
/// Parses the expression and expects a linear sum of variables. /// Parses the expression and expects a linear sum of variables.
/// Returns a vector with the first element being the constant and the /// Returns a vector with the first element being the constant and the

View File

@ -28,6 +28,8 @@
#include <libsolutil/CommonData.h> #include <libsolutil/CommonData.h>
#include <libsolutil/BooleanLP.h>
#include <utility> #include <utility>
#include <memory> #include <memory>
@ -40,7 +42,10 @@ using namespace solidity::smtutil;
void ReasoningBasedSimplifier::run(OptimiserStepContext& _context, Block& _ast) void ReasoningBasedSimplifier::run(OptimiserStepContext& _context, Block& _ast)
{ {
set<YulString> ssaVars = SSAValueTracker::ssaVariables(_ast); set<YulString> ssaVars = SSAValueTracker::ssaVariables(_ast);
ReasoningBasedSimplifier{_context.dialect, ssaVars}(_ast); ReasoningBasedSimplifier simpl{_context.dialect, ssaVars};
// Hack to inject the boolean lp solver.
simpl.m_solver = make_unique<BooleanLPSolver>();
simpl(_ast);
} }
std::optional<string> ReasoningBasedSimplifier::invalidInCurrentEnvironment() std::optional<string> ReasoningBasedSimplifier::invalidInCurrentEnvironment()
@ -120,12 +125,21 @@ smtutil::Expression ReasoningBasedSimplifier::encodeEVMBuiltin(
switch (_instruction) switch (_instruction)
{ {
case evmasm::Instruction::ADD: case evmasm::Instruction::ADD:
return wrap(arguments.at(0) + arguments.at(1)); {
auto result = arguments.at(0) + arguments.at(1) - (bigint(1) << 256) * newZeroOneVariable();
restrictToEVMWord(result);
return result;
}
case evmasm::Instruction::MUL: case evmasm::Instruction::MUL:
return wrap(arguments.at(0) * arguments.at(1)); return wrap(arguments.at(0) * arguments.at(1));
case evmasm::Instruction::SUB: case evmasm::Instruction::SUB:
return wrap(arguments.at(0) - arguments.at(1)); {
auto result = arguments.at(0) - arguments.at(1) + (bigint(1) << 256) * newZeroOneVariable();
restrictToEVMWord(result);
return result;
}
case evmasm::Instruction::DIV: case evmasm::Instruction::DIV:
// TODO add assertion that result is <= input
return smtutil::Expression::ite( return smtutil::Expression::ite(
arguments.at(1) == constantValue(0), arguments.at(1) == constantValue(0),
constantValue(0), constantValue(0),
@ -224,3 +238,15 @@ smtutil::Expression ReasoningBasedSimplifier::encodeEVMBuiltin(
} }
return newRestrictedVariable(); return newRestrictedVariable();
} }
smtutil::Expression ReasoningBasedSimplifier::newZeroOneVariable()
{
smtutil::Expression var = newVariable();
m_solver->addAssertion(var == 0 || var == 1);
return var;
}
void ReasoningBasedSimplifier::restrictToEVMWord(smtutil::Expression _value)
{
m_solver->addAssertion(0 <= _value && _value < bigint(1) << 256);
}

View File

@ -69,6 +69,9 @@ private:
std::vector<Expression> const& _arguments std::vector<Expression> const& _arguments
) override; ) override;
smtutil::Expression newZeroOneVariable();
void restrictToEVMWord(smtutil::Expression _value);
Dialect const& m_dialect; Dialect const& m_dialect;
}; };

View File

@ -81,6 +81,11 @@ smtutil::Expression SMTSolver::newVariable()
return m_solver->newVariable(uniqueName(), defaultSort()); return m_solver->newVariable(uniqueName(), defaultSort());
} }
smtutil::Expression SMTSolver::newBooleanVariable()
{
return m_solver->newVariable(uniqueName(), SortProvider::boolSort);
}
smtutil::Expression SMTSolver::newRestrictedVariable(bigint _maxValue) smtutil::Expression SMTSolver::newRestrictedVariable(bigint _maxValue)
{ {
smtutil::Expression var = newVariable(); smtutil::Expression var = newVariable();
@ -100,6 +105,7 @@ shared_ptr<Sort> SMTSolver::defaultSort() const
smtutil::Expression SMTSolver::booleanValue(smtutil::Expression _value) smtutil::Expression SMTSolver::booleanValue(smtutil::Expression _value)
{ {
// TODO should not use ite
return smtutil::Expression::ite(_value, constantValue(1), constantValue(0)); return smtutil::Expression::ite(_value, constantValue(1), constantValue(0));
} }
@ -115,6 +121,7 @@ smtutil::Expression SMTSolver::literalValue(Literal const& _literal)
smtutil::Expression SMTSolver::twosComplementToSigned(smtutil::Expression _value) smtutil::Expression SMTSolver::twosComplementToSigned(smtutil::Expression _value)
{ {
// TODO will that work for LP?
return smtutil::Expression::ite( return smtutil::Expression::ite(
_value < smtutil::Expression(bigint(1) << 255), _value < smtutil::Expression(bigint(1) << 255),
_value, _value,
@ -136,6 +143,7 @@ smtutil::Expression SMTSolver::wrap(smtutil::Expression _value)
smtutil::Expression rest = newRestrictedVariable(); smtutil::Expression rest = newRestrictedVariable();
smtutil::Expression multiplier = newVariable(); smtutil::Expression multiplier = newVariable();
m_solver->addAssertion(_value == multiplier * smtutil::Expression(bigint(1) << 256) + rest); m_solver->addAssertion(_value == multiplier * smtutil::Expression(bigint(1) << 256) + rest);
m_solver->addAssertion(0 <= rest && rest < bigint(1) << 256);
return rest; return rest;
} }

View File

@ -68,6 +68,7 @@ protected:
static smtutil::Expression bv2int(smtutil::Expression _arg); static smtutil::Expression bv2int(smtutil::Expression _arg);
smtutil::Expression newVariable(); smtutil::Expression newVariable();
smtutil::Expression newBooleanVariable();
virtual smtutil::Expression newRestrictedVariable(bigint _maxValue = (bigint(1) << 256) - 1); virtual smtutil::Expression newRestrictedVariable(bigint _maxValue = (bigint(1) << 256) - 1);
std::string uniqueName(); std::string uniqueName();

View File

@ -53,12 +53,13 @@ YulOptimizerTest::YulOptimizerTest(string const& _filename):
BOOST_THROW_EXCEPTION(runtime_error("Filename path has to contain a directory: \"" + _filename + "\".")); BOOST_THROW_EXCEPTION(runtime_error("Filename path has to contain a directory: \"" + _filename + "\"."));
m_optimizerStep = std::prev(std::prev(path.end()))->string(); m_optimizerStep = std::prev(std::prev(path.end()))->string();
/*
if (m_optimizerStep == "reasoningBasedSimplifier" && ( if (m_optimizerStep == "reasoningBasedSimplifier" && (
solidity::test::CommonOptions::get().disableSMT || solidity::test::CommonOptions::get().disableSMT ||
ReasoningBasedSimplifier::invalidInCurrentEnvironment() ReasoningBasedSimplifier::invalidInCurrentEnvironment()
)) ))
m_shouldRun = false; m_shouldRun = false;
*/
m_source = m_reader.source(); m_source = m_reader.source();
auto dialectName = m_reader.stringSetting("dialect", "evm"); auto dialectName = m_reader.stringSetting("dialect", "evm");