Move error flag from CHC to SymbolicState

This commit is contained in:
Leonardo Alt 2020-09-15 19:17:18 +02:00
parent 9115100f2a
commit ac93ee1d08
4 changed files with 57 additions and 40 deletions

View File

@ -115,12 +115,13 @@ void CHC::endVisit(ContractDefinition const& _contract)
auto implicitConstructorPredicate = createSymbolicBlock( auto implicitConstructorPredicate = createSymbolicBlock(
implicitConstructorSort(), implicitConstructorSort(),
"implicit_constructor_" + contractSuffix(_contract), "implicit_constructor_" + contractSuffix(_contract),
PredicateType::ImplicitConstructor,
&_contract &_contract
); );
auto implicitConstructor = (*implicitConstructorPredicate)({}); auto implicitConstructor = (*implicitConstructorPredicate)({});
addRule(implicitConstructor, implicitConstructor.name); addRule(implicitConstructor, implicitConstructor.name);
m_currentBlock = implicitConstructor; m_currentBlock = implicitConstructor;
m_context.addAssertion(m_error.currentValue() == 0); m_context.addAssertion(errorFlag().currentValue() == 0);
if (auto constructor = _contract.constructor()) if (auto constructor = _contract.constructor())
constructor->accept(*this); constructor->accept(*this);
@ -133,8 +134,8 @@ void CHC::endVisit(ContractDefinition const& _contract)
vector<smtutil::Expression> symbArgs = currentFunctionVariables(*m_currentContract); vector<smtutil::Expression> symbArgs = currentFunctionVariables(*m_currentContract);
setCurrentBlock(*m_constructorSummaryPredicate, &symbArgs); setCurrentBlock(*m_constructorSummaryPredicate, &symbArgs);
addAssertVerificationTarget(m_currentContract, m_currentBlock, smtutil::Expression(true), m_error.currentValue()); addAssertVerificationTarget(m_currentContract, m_currentBlock, smtutil::Expression(true), errorFlag().currentValue());
connectBlocks(m_currentBlock, interface(), m_error.currentValue() == 0); connectBlocks(m_currentBlock, interface(), errorFlag().currentValue() == 0);
SMTEncoder::endVisit(_contract); SMTEncoder::endVisit(_contract);
} }
@ -173,7 +174,7 @@ bool CHC::visit(FunctionDefinition const& _function)
else else
addRule(functionPred, functionPred.name); addRule(functionPred, functionPred.name);
m_context.addAssertion(m_error.currentValue() == 0); m_context.addAssertion(errorFlag().currentValue() == 0);
for (auto const* var: m_stateVariables) for (auto const* var: m_stateVariables)
m_context.addAssertion(m_context.variable(*var)->valueAtIndex(0) == currentValue(*var)); m_context.addAssertion(m_context.variable(*var)->valueAtIndex(0) == currentValue(*var));
for (auto const& var: _function.parameters()) for (auto const& var: _function.parameters())
@ -227,7 +228,7 @@ void CHC::endVisit(FunctionDefinition const& _function)
} }
else else
{ {
auto assertionError = m_error.currentValue(); auto assertionError = errorFlag().currentValue();
auto sum = summary(_function); auto sum = summary(_function);
connectBlocks(m_currentBlock, sum); connectBlocks(m_currentBlock, sum);
@ -485,18 +486,18 @@ void CHC::visitAssert(FunctionCall const& _funCall)
else else
m_functionAssertions[m_currentFunction].insert(&_funCall); m_functionAssertions[m_currentFunction].insert(&_funCall);
auto previousError = m_error.currentValue(); auto previousError = errorFlag().currentValue();
m_error.increaseIndex(); errorFlag().increaseIndex();
connectBlocks( connectBlocks(
m_currentBlock, m_currentBlock,
m_currentFunction->isConstructor() ? summary(*m_currentContract) : summary(*m_currentFunction), m_currentFunction->isConstructor() ? summary(*m_currentContract) : summary(*m_currentFunction),
currentPathConditions() && !m_context.expression(*args.front())->currentValue() && ( currentPathConditions() && !m_context.expression(*args.front())->currentValue() && (
m_error.currentValue() == newErrorId(_funCall) errorFlag().currentValue() == newErrorId(_funCall)
) )
); );
m_context.addAssertion(m_error.currentValue() == previousError); m_context.addAssertion(errorFlag().currentValue() == previousError);
} }
void CHC::internalFunctionCall(FunctionCall const& _funCall) void CHC::internalFunctionCall(FunctionCall const& _funCall)
@ -518,18 +519,18 @@ void CHC::internalFunctionCall(FunctionCall const& _funCall)
m_context.addAssertion(interface(*contract)); m_context.addAssertion(interface(*contract));
} }
auto previousError = m_error.currentValue(); auto previousError = errorFlag().currentValue();
m_context.addAssertion(predicate(_funCall)); m_context.addAssertion(predicate(_funCall));
connectBlocks( connectBlocks(
m_currentBlock, m_currentBlock,
(m_currentFunction && !m_currentFunction->isConstructor()) ? summary(*m_currentFunction) : summary(*m_currentContract), (m_currentFunction && !m_currentFunction->isConstructor()) ? summary(*m_currentFunction) : summary(*m_currentContract),
(m_error.currentValue() > 0) (errorFlag().currentValue() > 0)
); );
m_context.addAssertion(m_error.currentValue() == 0); m_context.addAssertion(errorFlag().currentValue() == 0);
m_error.increaseIndex(); errorFlag().increaseIndex();
m_context.addAssertion(m_error.currentValue() == previousError); m_context.addAssertion(errorFlag().currentValue() == previousError);
} }
void CHC::externalFunctionCall(FunctionCall const& _funCall) void CHC::externalFunctionCall(FunctionCall const& _funCall)
@ -558,7 +559,7 @@ void CHC::externalFunctionCall(FunctionCall const& _funCall)
auto nondet = (*m_nondetInterfaces.at(m_currentContract))(preCallState + currentStateVariables()); auto nondet = (*m_nondetInterfaces.at(m_currentContract))(preCallState + currentStateVariables());
m_context.addAssertion(nondet); m_context.addAssertion(nondet);
m_context.addAssertion(m_error.currentValue() == 0); m_context.addAssertion(errorFlag().currentValue() == 0);
} }
void CHC::unknownFunctionCall(FunctionCall const&) void CHC::unknownFunctionCall(FunctionCall const&)
@ -583,13 +584,13 @@ void CHC::makeArrayPopVerificationTarget(FunctionCall const& _arrayPop)
auto symbArray = dynamic_pointer_cast<SymbolicArrayVariable>(m_context.expression(memberAccess->expression())); auto symbArray = dynamic_pointer_cast<SymbolicArrayVariable>(m_context.expression(memberAccess->expression()));
solAssert(symbArray, ""); solAssert(symbArray, "");
auto previousError = m_error.currentValue(); auto previousError = errorFlag().currentValue();
m_error.increaseIndex(); errorFlag().increaseIndex();
addVerificationTarget(&_arrayPop, VerificationTarget::Type::PopEmptyArray, m_error.currentValue()); addVerificationTarget(&_arrayPop, VerificationTarget::Type::PopEmptyArray, errorFlag().currentValue());
smtutil::Expression target = (symbArray->length() <= 0) && (m_error.currentValue() == newErrorId(_arrayPop)); smtutil::Expression target = (symbArray->length() <= 0) && (errorFlag().currentValue() == newErrorId(_arrayPop));
m_context.addAssertion((m_error.currentValue() == previousError) || target); m_context.addAssertion((errorFlag().currentValue() == previousError) || target);
} }
pair<smtutil::Expression, smtutil::Expression> CHC::arithmeticOperation( pair<smtutil::Expression, smtutil::Expression> CHC::arithmeticOperation(
@ -613,8 +614,8 @@ pair<smtutil::Expression, smtutil::Expression> CHC::arithmeticOperation(
if (_op == Token::Mod || (_op == Token::Div && !intType->isSigned())) if (_op == Token::Mod || (_op == Token::Div && !intType->isSigned()))
return values; return values;
auto previousError = m_error.currentValue(); auto previousError = errorFlag().currentValue();
m_error.increaseIndex(); errorFlag().increaseIndex();
VerificationTarget::Type targetType; VerificationTarget::Type targetType;
unsigned errorId = newErrorId(_expression); unsigned errorId = newErrorId(_expression);
@ -623,24 +624,24 @@ pair<smtutil::Expression, smtutil::Expression> CHC::arithmeticOperation(
if (_op == Token::Div) if (_op == Token::Div)
{ {
targetType = VerificationTarget::Type::Overflow; targetType = VerificationTarget::Type::Overflow;
target = values.second > intType->maxValue() && m_error.currentValue() == errorId; target = values.second > intType->maxValue() && errorFlag().currentValue() == errorId;
} }
else if (intType->isSigned()) else if (intType->isSigned())
{ {
unsigned secondErrorId = newErrorId(_expression); unsigned secondErrorId = newErrorId(_expression);
targetType = VerificationTarget::Type::UnderOverflow; targetType = VerificationTarget::Type::UnderOverflow;
target = (values.second < intType->minValue() && m_error.currentValue() == errorId) || target = (values.second < intType->minValue() && errorFlag().currentValue() == errorId) ||
(values.second > intType->maxValue() && m_error.currentValue() == secondErrorId); (values.second > intType->maxValue() && errorFlag().currentValue() == secondErrorId);
} }
else if (_op == Token::Sub) else if (_op == Token::Sub)
{ {
targetType = VerificationTarget::Type::Underflow; targetType = VerificationTarget::Type::Underflow;
target = values.second < intType->minValue() && m_error.currentValue() == errorId; target = values.second < intType->minValue() && errorFlag().currentValue() == errorId;
} }
else if (_op == Token::Add || _op == Token::Mul) else if (_op == Token::Add || _op == Token::Mul)
{ {
targetType = VerificationTarget::Type::Overflow; targetType = VerificationTarget::Type::Overflow;
target = values.second > intType->maxValue() && m_error.currentValue() == errorId; target = values.second > intType->maxValue() && errorFlag().currentValue() == errorId;
} }
else else
solAssert(false, ""); solAssert(false, "");
@ -648,10 +649,10 @@ pair<smtutil::Expression, smtutil::Expression> CHC::arithmeticOperation(
addVerificationTarget( addVerificationTarget(
&_expression, &_expression,
targetType, targetType,
m_error.currentValue() errorFlag().currentValue()
); );
m_context.addAssertion((m_error.currentValue() == previousError) || *target); m_context.addAssertion((errorFlag().currentValue() == previousError) || *target);
return values; return values;
} }
@ -700,7 +701,7 @@ void CHC::resetContractAnalysis()
m_unknownFunctionCallSeen = false; m_unknownFunctionCallSeen = false;
m_breakDest = nullptr; m_breakDest = nullptr;
m_continueDest = nullptr; m_continueDest = nullptr;
m_error.resetIndex(); errorFlag().resetIndex();
} }
void CHC::eraseKnowledge() void CHC::eraseKnowledge()
@ -818,7 +819,7 @@ void CHC::defineInterfacesAndSummaries(SourceUnit const& _source)
auto nondetPre = iface(state0 + state1); auto nondetPre = iface(state0 + state1);
auto nondetPost = iface(state0 + state2); auto nondetPost = iface(state0 + state2);
vector<smtutil::Expression> args{m_error.currentValue()}; vector<smtutil::Expression> args{errorFlag().currentValue()};
args += state1 + args += state1 +
applyMap(function->parameters(), [this](auto _var) { return valueAtIndex(*_var, 0); }) + applyMap(function->parameters(), [this](auto _var) { return valueAtIndex(*_var, 0); }) +
state2 + state2 +
@ -1035,8 +1036,8 @@ smtutil::Expression CHC::predicate(FunctionCall const& _funCall)
if (!function) if (!function)
return smtutil::Expression(true); return smtutil::Expression(true);
m_error.increaseIndex(); errorFlag().increaseIndex();
vector<smtutil::Expression> args{m_error.currentValue()}; vector<smtutil::Expression> args{errorFlag().currentValue()};
auto const* contract = function->annotation().contract; auto const* contract = function->annotation().contract;
FunctionType const& funType = dynamic_cast<FunctionType const&>(*_funCall.expression().annotation().type); FunctionType const& funType = dynamic_cast<FunctionType const&>(*_funCall.expression().annotation().type);
bool otherContract = contract->isLibrary() || bool otherContract = contract->isLibrary() ||
@ -1427,3 +1428,8 @@ unsigned CHC::newErrorId(frontend::Expression const& _expr)
m_errorIds.emplace(_expr.id(), errorId); m_errorIds.emplace(_expr.id(), errorId);
return errorId; return errorId;
} }
SymbolicIntVariable& CHC::errorFlag()
{
return m_context.state().errorFlag();
}

View File

@ -245,6 +245,8 @@ private:
/// @returns a new unique error id associated with _expr and stores /// @returns a new unique error id associated with _expr and stores
/// it into m_errorIds. /// it into m_errorIds.
unsigned newErrorId(Expression const& _expr); unsigned newErrorId(Expression const& _expr);
smt::SymbolicIntVariable& errorFlag();
//@} //@}
/// Predicates. /// Predicates.
@ -269,13 +271,6 @@ private:
/// Function predicates. /// Function predicates.
std::map<ContractDefinition const*, std::map<FunctionDefinition const*, Predicate const*>> m_summaries; std::map<ContractDefinition const*, std::map<FunctionDefinition const*, Predicate const*>> m_summaries;
smt::SymbolicIntVariable m_error{
TypeProvider::uint256(),
TypeProvider::uint256(),
"error",
m_context
};
//@} //@}
/// Variables. /// Variables.

View File

@ -33,6 +33,7 @@ void SymbolicState::reset()
{ {
m_thisAddress.resetIndex(); m_thisAddress.resetIndex();
m_balances.resetIndex(); m_balances.resetIndex();
m_error.resetIndex();
} }
// Blockchain // Blockchain
@ -52,6 +53,11 @@ smtutil::Expression SymbolicState::balance(smtutil::Expression _address)
return smtutil::Expression::select(m_balances.elements(), move(_address)); return smtutil::Expression::select(m_balances.elements(), move(_address));
} }
SymbolicIntVariable& SymbolicState::errorFlag()
{
return m_error;
}
void SymbolicState::transfer(smtutil::Expression _from, smtutil::Expression _to, smtutil::Expression _value) void SymbolicState::transfer(smtutil::Expression _from, smtutil::Expression _to, smtutil::Expression _value)
{ {
unsigned indexBefore = m_balances.index(); unsigned indexBefore = m_balances.index();

View File

@ -46,6 +46,9 @@ public:
smtutil::Expression balance(); smtutil::Expression balance();
/// @returns the symbolic balance of an address. /// @returns the symbolic balance of an address.
smtutil::Expression balance(smtutil::Expression _address); smtutil::Expression balance(smtutil::Expression _address);
SymbolicIntVariable& errorFlag();
/// Transfer _value from _from to _to. /// Transfer _value from _from to _to.
void transfer(smtutil::Expression _from, smtutil::Expression _to, smtutil::Expression _value); void transfer(smtutil::Expression _from, smtutil::Expression _to, smtutil::Expression _value);
//@} //@}
@ -68,6 +71,13 @@ private:
"balances", "balances",
m_context m_context
}; };
smt::SymbolicIntVariable m_error{
TypeProvider::uint256(),
TypeProvider::uint256(),
"error",
m_context
};
}; };
} }