diff --git a/libsolidity/formal/SMTChecker.cpp b/libsolidity/formal/SMTChecker.cpp index 8639317ba..a4d9500bc 100644 --- a/libsolidity/formal/SMTChecker.cpp +++ b/libsolidity/formal/SMTChecker.cpp @@ -68,7 +68,7 @@ bool SMTChecker::visit(ContractDefinition const& _contract) void SMTChecker::endVisit(ContractDefinition const&) { - m_stateVariables.clear(); + m_variables.clear(); } void SMTChecker::endVisit(VariableDeclaration const& _varDecl) @@ -86,12 +86,10 @@ bool SMTChecker::visit(FunctionDefinition const& _function) ); m_currentFunction = &_function; m_interface->reset(); - m_variables.clear(); - m_variables.insert(m_stateVariables.begin(), m_stateVariables.end()); m_pathConditions.clear(); m_loopExecutionHappened = false; - initializeLocalVariables(_function); resetStateVariables(); + initializeLocalVariables(_function); return true; } @@ -100,6 +98,7 @@ void SMTChecker::endVisit(FunctionDefinition const&) // TOOD we could check for "reachability", i.e. satisfiability here. // We only handle local variables, so we clear at the beginning of the function. // If we add storage variables, those should be cleared differently. + removeLocalVariables(); m_currentFunction = nullptr; } @@ -110,7 +109,7 @@ bool SMTChecker::visit(IfStatement const& _node) checkBooleanNotConstant(_node.condition(), "Condition is always $VALUE."); auto countersEndTrue = visitBranch(_node.trueStatement(), expr(_node.condition())); - vector touchedVariables = m_variableUsage->touchedVariables(_node.trueStatement()); + vector touchedVariables = m_variableUsage->touchedVariables(_node.trueStatement()); decltype(countersEndTrue) countersEndFalse; if (_node.falseStatement()) { @@ -230,10 +229,10 @@ void SMTChecker::endVisit(Assignment const& _assignment) ); else if (Identifier const* identifier = dynamic_cast(&_assignment.leftHandSide())) { - Declaration const* decl = identifier->annotation().referencedDeclaration; - if (knownVariable(*decl)) + VariableDeclaration const& decl = dynamic_cast(*identifier->annotation().referencedDeclaration); + if (knownVariable(decl)) { - assignment(*decl, _assignment.rightHandSide(), _assignment.location()); + assignment(decl, _assignment.rightHandSide(), _assignment.location()); defineExpr(_assignment, expr(_assignment.rightHandSide())); } else @@ -296,12 +295,12 @@ void SMTChecker::endVisit(UnaryOperation const& _op) solAssert(_op.subExpression().annotation().lValueRequested, ""); if (Identifier const* identifier = dynamic_cast(&_op.subExpression())) { - Declaration const* decl = identifier->annotation().referencedDeclaration; - if (knownVariable(*decl)) + VariableDeclaration const& decl = dynamic_cast(*identifier->annotation().referencedDeclaration); + if (knownVariable(decl)) { - auto innerValue = currentValue(*decl); + auto innerValue = currentValue(decl); auto newValue = _op.getOperator() == Token::Inc ? innerValue + 1 : innerValue - 1; - assignment(*decl, newValue, _op.location()); + assignment(decl, newValue, _op.location()); defineExpr(_op, _op.isPrefixOperation() ? newValue : innerValue); } else @@ -383,14 +382,15 @@ void SMTChecker::endVisit(FunctionCall const& _funCall) void SMTChecker::endVisit(Identifier const& _identifier) { - Declaration const* decl = _identifier.annotation().referencedDeclaration; - solAssert(decl, ""); if (_identifier.annotation().lValueRequested) { // Will be translated as part of the node that requested the lvalue. } else if (SSAVariable::isSupportedType(_identifier.annotation().type->category())) - defineExpr(_identifier, currentValue(*decl)); + { + VariableDeclaration const& decl = dynamic_cast(*(_identifier.annotation().referencedDeclaration)); + defineExpr(_identifier, currentValue(decl)); + } else if (FunctionType const* fun = dynamic_cast(_identifier.annotation().type.get())) { if (fun->kind() == FunctionType::Kind::Assert || fun->kind() == FunctionType::Kind::Require) @@ -530,12 +530,12 @@ smt::Expression SMTChecker::division(smt::Expression _left, smt::Expression _rig return _left / _right; } -void SMTChecker::assignment(Declaration const& _variable, Expression const& _value, SourceLocation const& _location) +void SMTChecker::assignment(VariableDeclaration const& _variable, Expression const& _value, SourceLocation const& _location) { assignment(_variable, expr(_value), _location); } -void SMTChecker::assignment(Declaration const& _variable, smt::Expression const& _value, SourceLocation const& _location) +void SMTChecker::assignment(VariableDeclaration const& _variable, smt::Expression const& _value, SourceLocation const& _location) { TypePointer type = _variable.type(); if (auto const* intType = dynamic_cast(type.get())) @@ -583,19 +583,7 @@ void SMTChecker::checkCondition( expressionsToEvaluate.emplace_back(*_additionalValue); expressionNames.push_back(_additionalValueName); } - for (auto const& param: m_currentFunction->parameters()) - if (knownVariable(*param)) - { - expressionsToEvaluate.emplace_back(currentValue(*param)); - expressionNames.push_back(param->name()); - } - for (auto const& var: m_currentFunction->localVariables()) - if (knownVariable(*var)) - { - expressionsToEvaluate.emplace_back(currentValue(*var)); - expressionNames.push_back(var->name()); - } - for (auto const& var: m_stateVariables) + for (auto const& var: m_variables) if (knownVariable(*var.first)) { expressionsToEvaluate.emplace_back(currentValue(*var.first)); @@ -740,14 +728,17 @@ void SMTChecker::initializeLocalVariables(FunctionDefinition const& _function) void SMTChecker::resetStateVariables() { - for (auto const& variable: m_stateVariables) + for (auto const& variable: m_variables) { - newValue(*variable.first); - setUnknownValue(*variable.first); + if (variable.first->isStateVariable()) + { + newValue(*variable.first); + setUnknownValue(*variable.first); + } } } -void SMTChecker::resetVariables(vector _variables) +void SMTChecker::resetVariables(vector _variables) { for (auto const* decl: _variables) { @@ -756,9 +747,9 @@ void SMTChecker::resetVariables(vector _variables) } } -void SMTChecker::mergeVariables(vector const& _variables, smt::Expression const& _condition, VariableSequenceCounters const& _countersEndTrue, VariableSequenceCounters const& _countersEndFalse) +void SMTChecker::mergeVariables(vector const& _variables, smt::Expression const& _condition, VariableSequenceCounters const& _countersEndTrue, VariableSequenceCounters const& _countersEndFalse) { - set uniqueVars(_variables.begin(), _variables.end()); + set uniqueVars(_variables.begin(), _variables.end()); for (auto const* decl: uniqueVars) { int trueCounter = _countersEndTrue.at(decl).index(); @@ -777,14 +768,7 @@ bool SMTChecker::createVariable(VariableDeclaration const& _varDecl) if (SSAVariable::isSupportedType(_varDecl.type()->category())) { solAssert(m_variables.count(&_varDecl) == 0, ""); - solAssert(m_stateVariables.count(&_varDecl) == 0, ""); - if (_varDecl.isLocalVariable()) - m_variables.emplace(&_varDecl, SSAVariable(_varDecl, *m_interface)); - else - { - solAssert(_varDecl.isStateVariable(), ""); - m_stateVariables.emplace(&_varDecl, SSAVariable(_varDecl, *m_interface)); - } + m_variables.emplace(&_varDecl, SSAVariable(_varDecl, *m_interface)); return true; } else @@ -802,37 +786,37 @@ string SMTChecker::uniqueSymbol(Expression const& _expr) return "expr_" + to_string(_expr.id()); } -bool SMTChecker::knownVariable(Declaration const& _decl) +bool SMTChecker::knownVariable(VariableDeclaration const& _decl) { return m_variables.count(&_decl); } -smt::Expression SMTChecker::currentValue(Declaration const& _decl) +smt::Expression SMTChecker::currentValue(VariableDeclaration const& _decl) { solAssert(knownVariable(_decl), ""); return m_variables.at(&_decl)(); } -smt::Expression SMTChecker::valueAtSequence(Declaration const& _decl, int _sequence) +smt::Expression SMTChecker::valueAtSequence(VariableDeclaration const& _decl, int _sequence) { solAssert(knownVariable(_decl), ""); return m_variables.at(&_decl)(_sequence); } -smt::Expression SMTChecker::newValue(Declaration const& _decl) +smt::Expression SMTChecker::newValue(VariableDeclaration const& _decl) { solAssert(knownVariable(_decl), ""); ++m_variables.at(&_decl); return m_variables.at(&_decl)(); } -void SMTChecker::setZeroValue(Declaration const& _decl) +void SMTChecker::setZeroValue(VariableDeclaration const& _decl) { solAssert(knownVariable(_decl), ""); m_variables.at(&_decl).setZeroValue(); } -void SMTChecker::setUnknownValue(Declaration const& _decl) +void SMTChecker::setUnknownValue(VariableDeclaration const& _decl) { solAssert(knownVariable(_decl), ""); m_variables.at(&_decl).setUnknownValue(); @@ -909,3 +893,14 @@ void SMTChecker::addPathImpliedExpression(smt::Expression const& _e) { m_interface->addAssertion(smt::Expression::implies(currentPathConditions(), _e)); } + +void SMTChecker::removeLocalVariables() +{ + for (auto it = m_variables.begin(); it != m_variables.end(); ) + { + if (it->first->isLocalVariable()) + it = m_variables.erase(it); + else + ++it; + } +} diff --git a/libsolidity/formal/SMTChecker.h b/libsolidity/formal/SMTChecker.h index 50d40ab9c..6cf4e48af 100644 --- a/libsolidity/formal/SMTChecker.h +++ b/libsolidity/formal/SMTChecker.h @@ -76,11 +76,11 @@ private: /// of rounding for signed division. smt::Expression division(smt::Expression _left, smt::Expression _right, IntegerType const& _type); - void assignment(Declaration const& _variable, Expression const& _value, SourceLocation const& _location); - void assignment(Declaration const& _variable, smt::Expression const& _value, SourceLocation const& _location); + void assignment(VariableDeclaration const& _variable, Expression const& _value, SourceLocation const& _location); + void assignment(VariableDeclaration const& _variable, smt::Expression const& _value, SourceLocation const& _location); /// Maps a variable to an SSA index. - using VariableSequenceCounters = std::map; + using VariableSequenceCounters = std::map; /// Visits the branch given by the statement, pushes and pops the current path conditions. /// @param _condition if present, asserts that this condition is true within the branch. @@ -114,11 +114,11 @@ private: void initializeLocalVariables(FunctionDefinition const& _function); void resetStateVariables(); - void resetVariables(std::vector _variables); + void resetVariables(std::vector _variables); /// Given two different branches and the touched variables, /// merge the touched variables into after-branch ite variables /// using the branch condition as guard. - void mergeVariables(std::vector const& _variables, smt::Expression const& _condition, VariableSequenceCounters const& _countersEndTrue, VariableSequenceCounters const& _countersEndFalse); + void mergeVariables(std::vector const& _variables, smt::Expression const& _condition, VariableSequenceCounters const& _countersEndTrue, VariableSequenceCounters const& _countersEndFalse); /// Tries to create an uninitialized variable and returns true on success. /// This fails if the type is not supported. bool createVariable(VariableDeclaration const& _varDecl); @@ -127,21 +127,21 @@ private: /// @returns true if _delc is a variable that is known at the current point, i.e. /// has a valid sequence number - bool knownVariable(Declaration const& _decl); + bool knownVariable(VariableDeclaration const& _decl); /// @returns an expression denoting the value of the variable declared in @a _decl /// at the current point. - smt::Expression currentValue(Declaration const& _decl); + smt::Expression currentValue(VariableDeclaration const& _decl); /// @returns an expression denoting the value of the variable declared in @a _decl /// at the given sequence point. Does not ensure that this sequence point exists. - smt::Expression valueAtSequence(Declaration const& _decl, int _sequence); + smt::Expression valueAtSequence(VariableDeclaration const& _decl, int _sequence); /// Allocates a new sequence number for the declaration, updates the current /// sequence number to this value and returns the expression. - smt::Expression newValue(Declaration const& _decl); + smt::Expression newValue(VariableDeclaration const& _decl); /// Sets the value of the declaration to zero. - void setZeroValue(Declaration const& _decl); + void setZeroValue(VariableDeclaration const& _decl); /// Resets the variable to an unknown value (in its range). - void setUnknownValue(Declaration const& decl); + void setUnknownValue(VariableDeclaration const& decl); /// Returns the expression corresponding to the AST node. Throws if the expression does not exist. smt::Expression expr(Expression const& _e); @@ -161,12 +161,14 @@ private: /// Add to the solver: the given expression implied by the current path conditions void addPathImpliedExpression(smt::Expression const& _e); + /// Removes the local variables of a function. + void removeLocalVariables(); + std::shared_ptr m_interface; std::shared_ptr m_variableUsage; bool m_loopExecutionHappened = false; std::map m_expressions; - std::map m_variables; - std::map m_stateVariables; + std::map m_variables; std::vector m_pathConditions; ErrorReporter& m_errorReporter; diff --git a/libsolidity/formal/VariableUsage.cpp b/libsolidity/formal/VariableUsage.cpp index c2dea844a..9282a5606 100644 --- a/libsolidity/formal/VariableUsage.cpp +++ b/libsolidity/formal/VariableUsage.cpp @@ -50,12 +50,12 @@ VariableUsage::VariableUsage(ASTNode const& _node) _node.accept(reducer); } -vector VariableUsage::touchedVariables(ASTNode const& _node) const +vector VariableUsage::touchedVariables(ASTNode const& _node) const { if (!m_children.count(&_node) && !m_touchedVariable.count(&_node)) return {}; - set touched; + set touched; vector toVisit; toVisit.push_back(&_node); diff --git a/libsolidity/formal/VariableUsage.h b/libsolidity/formal/VariableUsage.h index 62561cce4..dda13de25 100644 --- a/libsolidity/formal/VariableUsage.h +++ b/libsolidity/formal/VariableUsage.h @@ -27,7 +27,7 @@ namespace solidity { class ASTNode; -class Declaration; +class VariableDeclaration; /** * This class collects information about which local variables of value type @@ -38,11 +38,11 @@ class VariableUsage public: explicit VariableUsage(ASTNode const& _node); - std::vector touchedVariables(ASTNode const& _node) const; + std::vector touchedVariables(ASTNode const& _node) const; private: // Variable touched by a specific AST node. - std::map m_touchedVariable; + std::map m_touchedVariable; std::map> m_children; };