/*
	This file is part of solidity.
	solidity is free software: you can redistribute it and/or modify
	it under the terms of the GNU General Public License as published by
	the Free Software Foundation, either version 3 of the License, or
	(at your option) any later version.
	solidity is distributed in the hope that it will be useful,
	but WITHOUT ANY WARRANTY; without even the implied warranty of
	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
	GNU General Public License for more details.
	You should have received a copy of the GNU General Public License
	along with solidity.  If not, see <http://www.gnu.org/licenses/>.
*/
// SPDX-License-Identifier: GPL-3.0
#include 
#include 
namespace solidity::frontend
{
namespace
{

/// Determines the contract under which a call to @a _function is looked up, given the
/// most-derived contract of the caller (@a _callingContract, nullptr inside free functions):
/// - a function without an owning contract (free function) maps to nullptr;
/// - a function of the calling contract or one of its bases keeps the calling
///   (most-derived) contract;
/// - otherwise (e.g. a library function) the function's own contract is used.
ContractDefinition const* findScopeContract(FunctionDefinition const& _function, ContractDefinition const* _callingContract)
{
	ContractDefinition const* owningContract = _function.annotation().contract;
	if (!owningContract)
		return nullptr;

	bool const callerInheritsOwner = _callingContract && _callingContract->derivesFrom(*owningContract);
	return callerInheritsOwner ? _callingContract : owningContract;
}

}
void ControlFlowRevertPruner::run()
{
	// Phase 1: register every (contract, function) flow with an unknown revert state
	// and record, for each called function, which flows call it.
	for (auto const& functionFlowEntry: m_cfg.allFunctionFlows())
		collectCalls(*functionFlowEntry.first.function, functionFlowEntry.first.contract);

	// Phase 2: compute the revert state of every registered flow.
	findRevertStates();

	// Phase 3: cut control flow after calls whose callee is treated as always reverting.
	modifyFunctionFlows();
}
/// Resolves the function definition that @a _functionCall invokes when executed in the
/// context of the most-derived contract @a _contract (nullptr when the caller is a free
/// function). Resolved results are memoized in m_resolveCache per (call, contract) pair.
/// @returns the called function definition, or nullptr if the callee's function type has
/// no declaration or the resolved definition is not implemented.
FunctionDefinition const* ControlFlowRevertPruner::resolveCall(FunctionCall const& _functionCall, ContractDefinition const* _contract)
{
	auto result = m_resolveCache.find({&_functionCall, _contract});
	if (result != m_resolveCache.end())
		return result->second;

	auto const& functionType = dynamic_cast<FunctionType const&>(
		*_functionCall.expression().annotation().type
	);

	if (!functionType.hasDeclaration())
		return nullptr;

	auto const& unresolvedFunctionDefinition =
		dynamic_cast<FunctionDefinition const&>(functionType.declaration());

	FunctionDefinition const* returnFunctionDef = &unresolvedFunctionDefinition;

	if (auto const* memberAccess = dynamic_cast<MemberAccess const*>(&_functionCall.expression()))
	{
		if (*memberAccess->annotation().requiredLookup == VirtualLookup::Super)
		{
			if (auto const typeType = dynamic_cast<TypeType const*>(memberAccess->expression().annotation().type))
				if (auto const contractType = dynamic_cast<ContractType const*>(typeType->actualType()))
				{
					solAssert(contractType->isSuper(), "");
					// A `super` lookup is relative to the most-derived contract, so one must exist.
					solAssert(_contract, "");
					ContractDefinition const* superContract = contractType->contractDefinition().superContract(*_contract);

					returnFunctionDef = &unresolvedFunctionDefinition.resolveVirtual(
						*_contract,
						superContract
					);
				}
		}
		else
			// Static lookup: the referenced definition is the one that is called.
			solAssert(*memberAccess->annotation().requiredLookup == VirtualLookup::Static, "");
	}
	else if (auto const* identifier = dynamic_cast<Identifier const*>(&_functionCall.expression()))
	{
		solAssert(*identifier->annotation().requiredLookup == VirtualLookup::Virtual, "");
		// Without a most-derived contract (calls made from free functions) there is nothing
		// to resolve virtually against, so the referenced definition is kept. Dereferencing
		// a null `_contract` here would be undefined behaviour.
		if (_contract)
			returnFunctionDef = &unresolvedFunctionDefinition.resolveVirtual(*_contract);
	}

	if (returnFunctionDef && !returnFunctionDef->isImplemented())
		returnFunctionDef = nullptr;

	return m_resolveCache[{&_functionCall, _contract}] = returnFunctionDef;
}
void ControlFlowRevertPruner::findRevertStates()
{
	std::set pendingFunctions = keys(m_functions);
	while (!pendingFunctions.empty())
	{
		CFG::FunctionContractTuple item = *pendingFunctions.begin();
		pendingFunctions.erase(pendingFunctions.begin());
		if (m_functions[item] != RevertState::Unknown)
			continue;
		bool foundExit = false;
		bool foundUnknown = false;
		FunctionFlow const& functionFlow = m_cfg.functionFlow(*item.function, item.contract);
		solidity::util::BreadthFirstSearch{{functionFlow.entry}}.run(
			[&](CFGNode* _node, auto&& _addChild) {
				if (_node == functionFlow.exit)
					foundExit = true;
				for (auto const* functionCall: _node->functionCalls)
				{
					auto const* resolvedFunction = resolveCall(*functionCall, item.contract);
					if (resolvedFunction == nullptr)
						continue;
					switch (m_functions.at({findScopeContract(*resolvedFunction, item.contract), resolvedFunction}))
					{
						case RevertState::Unknown:
							foundUnknown = true;
							return;
						case RevertState::AllPathsRevert:
							return;
						default:
							break;
					}
				}
				for (CFGNode* exit: _node->exits)
					_addChild(exit);
		});
		auto& revertState = m_functions[item];
		if (foundExit)
			revertState = RevertState::HasNonRevertingPath;
		else if (!foundUnknown)
			revertState = RevertState::AllPathsRevert;
		// Mark all functions depending on this one as modified again
		if (revertState != RevertState::Unknown)
			for (auto& nextItem: m_calledBy[item.function])
				// Ignore different most derived contracts in dependent callees
				if (
					item.contract == nullptr ||
					nextItem.contract == nullptr ||
					nextItem.contract == item.contract
				)
					pendingFunctions.insert(nextItem);
	}
}
void ControlFlowRevertPruner::modifyFunctionFlows()
{
	for (auto& item: m_functions)
	{
		FunctionFlow const& functionFlow = m_cfg.functionFlow(*item.first.function, item.first.contract);
		solidity::util::BreadthFirstSearch{{functionFlow.entry}}.run(
			[&](CFGNode* _node, auto&& _addChild) {
				for (auto const* functionCall: _node->functionCalls)
				{
					auto const* resolvedFunction = resolveCall(*functionCall, item.first.contract);
					if (resolvedFunction == nullptr)
						continue;
					switch (m_functions.at({findScopeContract(*resolvedFunction, item.first.contract), resolvedFunction}))
					{
						case RevertState::Unknown:
							[[fallthrough]];
						case RevertState::AllPathsRevert:
							// If the revert states of the functions do not
							// change anymore, we treat all "unknown" states as
							// "reverting", since they can only be caused by
							// recursion.
							_node->exits = {functionFlow.revert};
							return;
						default:
							break;
					}
				}
				for (CFGNode* exit: _node->exits)
					_addChild(exit);
		});
	}
}
/// Registers the flow of @a _function in the context of @a _mostDerivedContract with an
/// Unknown revert state, and records that flow in m_calledBy as a caller of every function
/// its calls resolve to (via resolveCall). Calls that resolve to nullptr are not recorded.
void ControlFlowRevertPruner::collectCalls(FunctionDefinition const& _function, ContractDefinition const* _mostDerivedContract)
{
	FunctionFlow const& functionFlow = m_cfg.functionFlow(_function, _mostDerivedContract);

	CFG::FunctionContractTuple pair{_mostDerivedContract, &_function};

	solAssert(m_functions.count(pair) == 0, "");
	m_functions[pair] = RevertState::Unknown;

	solidity::util::BreadthFirstSearch<CFGNode*>{{functionFlow.entry}}.run(
		[&](CFGNode* _node, auto&& _addChild) {
			for (auto const* functionCall: _node->functionCalls)
			{
				// Unresolved calls (nullptr) are skipped when revert states are computed,
				// so there is no point in recording callers under a null key.
				if (auto const* calledFunction = resolveCall(*functionCall, _mostDerivedContract))
					m_calledBy[calledFunction].insert(pair);
			}

			for (CFGNode* exit: _node->exits)
				_addChild(exit);
	});
}
}