/*
This file is part of solidity.
solidity is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
solidity is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with solidity. If not, see <http://www.gnu.org/licenses/>.
*/
// SPDX-License-Identifier: GPL-3.0
#include
#include
#include
namespace solidity::frontend
{
namespace
{
/// Determines the contract that scopes a call to @a _function issued from @a _callingContract.
/// - Free functions (no defining contract) are scoped to nullptr.
/// - If the calling contract derives from the contract that defines @a _function (a call to a
///   base function), the calling, i.e. most derived, contract is kept.
/// - Otherwise (e.g. a library function, or no calling contract at all) the contract that
///   defines @a _function is used.
ContractDefinition const* findScopeContract(FunctionDefinition const& _function, ContractDefinition const* _callingContract)
{
	ContractDefinition const* definingContract = _function.annotation().contract;
	if (definingContract == nullptr)
		return nullptr;

	bool const callerDerivesFromDefiner =
		_callingContract != nullptr && _callingContract->derivesFrom(*definingContract);
	return callerDerivesFromDefiner ? _callingContract : definingContract;
}
}
/// Entry point of the pruner.
/// 1. Registers every (contract, function) flow of the CFG and records which flows call which
///    resolved functions.
/// 2. Computes the revert state of each registered function.
/// 3. Redirects call sites of (possibly) always-reverting functions to the revert node.
void ControlFlowRevertPruner::run()
{
	for (auto const& [functionAndContract, flow]: m_cfg.allFunctionFlows())
		collectCalls(*functionAndContract.function, functionAndContract.contract);

	findRevertStates();
	modifyFunctionFlows();
}
/// Resolves which function definition @a _functionCall actually invokes when the calling code
/// runs in the context of @a _contract (the most derived contract, or nullptr for free functions).
/// Handles `super` member accesses, statically bound member accesses and virtual lookup through
/// plain identifiers.
/// @returns the resolved definition, or nullptr if the called function type has no declaration
/// or the resolved definition is not implemented.
/// Results for calls with a declaration are memoised in m_resolveCache per (call, contract) pair.
FunctionDefinition const* ControlFlowRevertPruner::resolveCall(FunctionCall const& _functionCall, ContractDefinition const* _contract)
{
	auto result = m_resolveCache.find({&_functionCall, _contract});
	if (result != m_resolveCache.end())
		return result->second;

	auto const& functionType = dynamic_cast<FunctionType const&>(
		*_functionCall.expression().annotation().type
	);

	if (!functionType.hasDeclaration())
		return nullptr;

	auto const& unresolvedFunctionDefinition =
		dynamic_cast<FunctionDefinition const&>(functionType.declaration());

	FunctionDefinition const* returnFunctionDef = &unresolvedFunctionDefinition;

	if (auto const* memberAccess = dynamic_cast<MemberAccess const*>(&_functionCall.expression()))
	{
		if (*memberAccess->annotation().requiredLookup == VirtualLookup::Super)
		{
			if (auto const typeType = dynamic_cast<TypeType const*>(memberAccess->expression().annotation().type))
				if (auto const contractType = dynamic_cast<ContractType const*>(typeType->actualType()))
				{
					solAssert(contractType->isSuper(), "");
					// Start the virtual lookup at the contract that follows the current one in
					// the linearization of the most derived contract.
					ContractDefinition const* superContract = contractType->contractDefinition().superContract(*_contract);

					returnFunctionDef = &unresolvedFunctionDefinition.resolveVirtual(
						*_contract,
						superContract
					);
				}
		}
		else
		{
			// Non-super member accesses are bound statically: no virtual lookup.
			solAssert(*memberAccess->annotation().requiredLookup == VirtualLookup::Static, "");
			returnFunctionDef = &unresolvedFunctionDefinition;
		}
	}
	else if (auto const* identifier = dynamic_cast<Identifier const*>(&_functionCall.expression()))
	{
		solAssert(*identifier->annotation().requiredLookup == VirtualLookup::Virtual, "");
		// Calls made from free functions have no most derived contract (_contract is nullptr).
		// Dereferencing it would be undefined behaviour, so keep the unresolved definition in
		// that case. Presumably such calls can only target free functions, which are not
		// subject to virtual lookup anyway.
		if (_contract)
			returnFunctionDef = &unresolvedFunctionDefinition.resolveVirtual(*_contract);
	}

	// Unimplemented functions cannot be analysed; callers treat nullptr as "unknown callee".
	if (returnFunctionDef && !returnFunctionDef->isImplemented())
		returnFunctionDef = nullptr;

	return m_resolveCache[{&_functionCall, _contract}] = returnFunctionDef;
}
void ControlFlowRevertPruner::findRevertStates()
{
std::set pendingFunctions = keys(m_functions);
while (!pendingFunctions.empty())
{
CFG::FunctionContractTuple item = *pendingFunctions.begin();
pendingFunctions.erase(pendingFunctions.begin());
if (m_functions[item] != RevertState::Unknown)
continue;
bool foundExit = false;
bool foundUnknown = false;
FunctionFlow const& functionFlow = m_cfg.functionFlow(*item.function, item.contract);
solidity::util::BreadthFirstSearch{{functionFlow.entry}}.run(
[&](CFGNode* _node, auto&& _addChild) {
if (_node == functionFlow.exit)
foundExit = true;
for (auto const* functionCall: _node->functionCalls)
{
auto const* resolvedFunction = resolveCall(*functionCall, item.contract);
if (resolvedFunction == nullptr)
continue;
switch (m_functions.at({findScopeContract(*resolvedFunction, item.contract), resolvedFunction}))
{
case RevertState::Unknown:
foundUnknown = true;
return;
case RevertState::AllPathsRevert:
return;
default:
break;
}
}
for (CFGNode* exit: _node->exits)
_addChild(exit);
});
auto& revertState = m_functions[item];
if (foundExit)
revertState = RevertState::HasNonRevertingPath;
else if (!foundUnknown)
revertState = RevertState::AllPathsRevert;
// Mark all functions depending on this one as modified again
if (revertState != RevertState::Unknown)
for (auto& nextItem: m_calledBy[item.function])
// Ignore different most derived contracts in dependent callees
if (
item.contract == nullptr ||
nextItem.contract == nullptr ||
nextItem.contract == item.contract
)
pendingFunctions.insert(nextItem);
}
}
void ControlFlowRevertPruner::modifyFunctionFlows()
{
for (auto& item: m_functions)
{
FunctionFlow const& functionFlow = m_cfg.functionFlow(*item.first.function, item.first.contract);
solidity::util::BreadthFirstSearch{{functionFlow.entry}}.run(
[&](CFGNode* _node, auto&& _addChild) {
for (auto const* functionCall: _node->functionCalls)
{
auto const* resolvedFunction = resolveCall(*functionCall, item.first.contract);
if (resolvedFunction == nullptr)
continue;
switch (m_functions.at({findScopeContract(*resolvedFunction, item.first.contract), resolvedFunction}))
{
case RevertState::Unknown:
[[fallthrough]];
case RevertState::AllPathsRevert:
// If the revert states of the functions do not
// change anymore, we treat all "unknown" states as
// "reverting", since they can only be caused by
// recursion.
for (CFGNode * node: _node->exits)
ranges::remove(node->entries, _node);
_node->exits = {functionFlow.revert};
functionFlow.revert->entries.push_back(_node);
return;
default:
break;
}
}
for (CFGNode* exit: _node->exits)
_addChild(exit);
});
}
}
/// Registers the (@a _mostDerivedContract, @a _function) tuple in m_functions with an Unknown
/// revert state and records it in m_calledBy as a caller of every function it calls, with each
/// call resolved in the context of @a _mostDerivedContract.
/// Calls that cannot be resolved (resolveCall returns nullptr) are skipped: findRevertStates
/// only ever looks up m_calledBy with a non-null function, so such entries would be dead data.
void ControlFlowRevertPruner::collectCalls(FunctionDefinition const& _function, ContractDefinition const* _mostDerivedContract)
{
	FunctionFlow const& functionFlow = m_cfg.functionFlow(_function, _mostDerivedContract);

	CFG::FunctionContractTuple pair{_mostDerivedContract, &_function};

	// Every flow is expected to be registered exactly once.
	solAssert(m_functions.count(pair) == 0, "");
	m_functions[pair] = RevertState::Unknown;

	solidity::util::BreadthFirstSearch<CFGNode*>{{functionFlow.entry}}.run(
		[&](CFGNode* _node, auto&& _addChild) {
			for (auto const* functionCall: _node->functionCalls)
				if (auto const* callee = resolveCall(*functionCall, _mostDerivedContract))
					m_calledBy[callee].insert(pair);

			for (CFGNode* exit: _node->exits)
				_addChild(exit);
		});
}
}