Merge pull request #11481 from ethereum/unify-resolve

Unify function call resolve function used in Analysis & Yul CodeGen
This commit is contained in:
chriseth 2021-06-07 16:54:13 +02:00 committed by GitHub
commit e3e6729f22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 60 additions and 101 deletions

View File

@ -54,58 +54,6 @@ void ControlFlowRevertPruner::run()
modifyFunctionFlows();
}
FunctionDefinition const* ControlFlowRevertPruner::resolveCall(FunctionCall const& _functionCall, ContractDefinition const* _contract)
{
auto result = m_resolveCache.find({&_functionCall, _contract});
if (result != m_resolveCache.end())
return result->second;
auto const& functionType = dynamic_cast<FunctionType const&>(
*_functionCall.expression().annotation().type
);
if (!functionType.hasDeclaration())
return nullptr;
auto const& unresolvedFunctionDefinition =
dynamic_cast<FunctionDefinition const&>(functionType.declaration());
FunctionDefinition const* returnFunctionDef = &unresolvedFunctionDefinition;
if (auto const* memberAccess = dynamic_cast<MemberAccess const*>(&_functionCall.expression()))
{
if (*memberAccess->annotation().requiredLookup == VirtualLookup::Super)
{
if (auto const typeType = dynamic_cast<TypeType const*>(memberAccess->expression().annotation().type))
if (auto const contractType = dynamic_cast<ContractType const*>(typeType->actualType()))
{
solAssert(contractType->isSuper(), "");
ContractDefinition const* superContract = contractType->contractDefinition().superContract(*_contract);
returnFunctionDef = &unresolvedFunctionDefinition.resolveVirtual(
*_contract,
superContract
);
}
}
else
{
solAssert(*memberAccess->annotation().requiredLookup == VirtualLookup::Static, "");
returnFunctionDef = &unresolvedFunctionDefinition;
}
}
else if (auto const* identifier = dynamic_cast<Identifier const*>(&_functionCall.expression()))
{
solAssert(*identifier->annotation().requiredLookup == VirtualLookup::Virtual, "");
returnFunctionDef = &unresolvedFunctionDefinition.resolveVirtual(*_contract);
}
if (returnFunctionDef && !returnFunctionDef->isImplemented())
returnFunctionDef = nullptr;
return m_resolveCache[{&_functionCall, _contract}] = returnFunctionDef;
}
void ControlFlowRevertPruner::findRevertStates()
{
std::set<CFG::FunctionContractTuple> pendingFunctions = keys(m_functions);
@ -130,9 +78,9 @@ void ControlFlowRevertPruner::findRevertStates()
for (auto const* functionCall: _node->functionCalls)
{
auto const* resolvedFunction = resolveCall(*functionCall, item.contract);
auto const* resolvedFunction = ASTNode::resolveFunctionCall(*functionCall, item.contract);
if (resolvedFunction == nullptr)
if (resolvedFunction == nullptr || !resolvedFunction->isImplemented())
continue;
switch (m_functions.at({findScopeContract(*resolvedFunction, item.contract), resolvedFunction}))
@ -180,9 +128,9 @@ void ControlFlowRevertPruner::modifyFunctionFlows()
[&](CFGNode* _node, auto&& _addChild) {
for (auto const* functionCall: _node->functionCalls)
{
auto const* resolvedFunction = resolveCall(*functionCall, item.first.contract);
auto const* resolvedFunction = ASTNode::resolveFunctionCall(*functionCall, item.first.contract);
if (resolvedFunction == nullptr)
if (resolvedFunction == nullptr || !resolvedFunction->isImplemented())
continue;
switch (m_functions.at({findScopeContract(*resolvedFunction, item.first.contract), resolvedFunction}))
@ -223,7 +171,11 @@ void ControlFlowRevertPruner::collectCalls(FunctionDefinition const& _function,
solidity::util::BreadthFirstSearch<CFGNode*>{{functionFlow.entry}}.run(
[&](CFGNode* _node, auto&& _addChild) {
for (auto const* functionCall: _node->functionCalls)
m_calledBy[resolveCall(*functionCall, _mostDerivedContract)].insert(pair);
{
auto const* funcDef = ASTNode::resolveFunctionCall(*functionCall, _mostDerivedContract);
if (funcDef && funcDef->isImplemented())
m_calledBy[funcDef].insert(pair);
}
for (CFGNode* exit: _node->exits)
_addChild(exit);

View File

@ -45,15 +45,6 @@ private:
Unknown,
};
/// Simple attempt at resolving a function call
/// Does not aim to be able to resolve all calls, only used for variable
/// assignment tracking and revert behavior.
/// @param _functionCall the function call to analyse
/// @param _mostDerivedContract most derived contract
/// @returns function definition to which the call resolved or nullptr if no
/// definition was found.
FunctionDefinition const* resolveCall(FunctionCall const& _functionCall, ContractDefinition const* _mostDerivedContract);
/// Identify revert states of all function flows
void findRevertStates();

View File

@ -57,6 +57,50 @@ Declaration const* ASTNode::referencedDeclaration(Expression const& _expression)
return nullptr;
}
FunctionDefinition const* ASTNode::resolveFunctionCall(FunctionCall const& _functionCall, ContractDefinition const* _mostDerivedContract)
{
auto const* functionDef = dynamic_cast<FunctionDefinition const*>(
ASTNode::referencedDeclaration(_functionCall.expression())
);
if (!functionDef)
return nullptr;
if (auto const* memberAccess = dynamic_cast<MemberAccess const*>(&_functionCall.expression()))
{
if (*memberAccess->annotation().requiredLookup == VirtualLookup::Super)
{
if (auto const typeType = dynamic_cast<TypeType const*>(memberAccess->expression().annotation().type))
if (auto const contractType = dynamic_cast<ContractType const*>(typeType->actualType()))
{
solAssert(_mostDerivedContract, "");
solAssert(contractType->isSuper(), "");
ContractDefinition const* superContract = contractType->contractDefinition().superContract(*_mostDerivedContract);
return &functionDef->resolveVirtual(
*_mostDerivedContract,
superContract
);
}
}
else
solAssert(*memberAccess->annotation().requiredLookup == VirtualLookup::Static, "");
}
else if (auto const* identifier = dynamic_cast<Identifier const*>(&_functionCall.expression()))
{
solAssert(*identifier->annotation().requiredLookup == VirtualLookup::Virtual, "");
if (functionDef->virtualSemantics())
{
solAssert(_mostDerivedContract, "");
return &functionDef->resolveVirtual(*_mostDerivedContract);
}
}
else
solAssert(false, "");
return functionDef;
}
ASTAnnotation& ASTNode::annotation() const
{
if (!m_annotation)

View File

@ -104,6 +104,8 @@ public:
/// Extracts the referenced declaration from all nodes whose annotations support
/// `referencedDeclaration`.
static Declaration const* referencedDeclaration(Expression const& _expression);
/// Performs potential super or virtual lookup for a function call based on the most derived contract.
static FunctionDefinition const* resolveFunctionCall(FunctionCall const& _functionCall, ContractDefinition const* _mostDerivedContract);
/// Returns the source code location of this node.
SourceLocation const& location() const { return m_location; }

View File

@ -884,8 +884,6 @@ void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall)
return;
}
auto const* memberAccess = dynamic_cast<MemberAccess const*>(&_functionCall.expression());
switch (functionType->kind())
{
case FunctionType::Kind::Declaration:
@ -893,39 +891,7 @@ void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall)
break;
case FunctionType::Kind::Internal:
{
auto identifier = dynamic_cast<Identifier const*>(&_functionCall.expression());
auto const* functionDef = dynamic_cast<FunctionDefinition const*>(
ASTNode::referencedDeclaration(_functionCall.expression())
);
if (functionDef)
{
solAssert(memberAccess || identifier, "");
solAssert(functionType->declaration() == *functionDef, "");
if (identifier)
{
solAssert(*identifier->annotation().requiredLookup == VirtualLookup::Virtual, "");
functionDef = &functionDef->resolveVirtual(m_context.mostDerivedContract());
}
else if (auto typeType = dynamic_cast<TypeType const*>(memberAccess->expression().annotation().type))
if (
auto contractType = dynamic_cast<ContractType const*>(typeType->actualType());
contractType->isSuper()
)
{
ContractDefinition const* super = contractType->contractDefinition().superContract(m_context.mostDerivedContract());
solAssert(super, "Super contract not available.");
solAssert(*memberAccess->annotation().requiredLookup == VirtualLookup::Super, "");
functionDef = &functionDef->resolveVirtual(m_context.mostDerivedContract(), super);
}
solAssert(functionDef && functionDef->isImplemented(), "");
solAssert(
functionDef->parameters().size() == arguments.size() + (functionType->bound() ? 1 : 0),
""
);
}
FunctionDefinition const* functionDef = ASTNode::resolveFunctionCall(_functionCall, &m_context.mostDerivedContract());
solAssert(!functionType->takesArbitraryParameters(), "");
@ -937,11 +903,15 @@ void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall)
args += convert(*arguments[i], *parameterTypes[i]).stackSlots();
if (functionDef)
{
solAssert(functionDef->isImplemented(), "");
define(_functionCall) <<
m_context.enqueueFunctionForCodeGeneration(*functionDef) <<
"(" <<
joinHumanReadable(args) <<
")\n";
}
else
{
YulArity arity = YulArity::fromType(*functionType);