Unify function call resolve function used in Analysis & Yul CodeGen

This commit is contained in:
Mathias Baumann 2021-06-03 13:09:13 +02:00
parent 1f8f1a3db9
commit 6a0313c456
5 changed files with 60 additions and 101 deletions

View File

@ -54,58 +54,6 @@ void ControlFlowRevertPruner::run()
modifyFunctionFlows(); modifyFunctionFlows();
} }
FunctionDefinition const* ControlFlowRevertPruner::resolveCall(FunctionCall const& _functionCall, ContractDefinition const* _contract)
{
auto result = m_resolveCache.find({&_functionCall, _contract});
if (result != m_resolveCache.end())
return result->second;
auto const& functionType = dynamic_cast<FunctionType const&>(
*_functionCall.expression().annotation().type
);
if (!functionType.hasDeclaration())
return nullptr;
auto const& unresolvedFunctionDefinition =
dynamic_cast<FunctionDefinition const&>(functionType.declaration());
FunctionDefinition const* returnFunctionDef = &unresolvedFunctionDefinition;
if (auto const* memberAccess = dynamic_cast<MemberAccess const*>(&_functionCall.expression()))
{
if (*memberAccess->annotation().requiredLookup == VirtualLookup::Super)
{
if (auto const typeType = dynamic_cast<TypeType const*>(memberAccess->expression().annotation().type))
if (auto const contractType = dynamic_cast<ContractType const*>(typeType->actualType()))
{
solAssert(contractType->isSuper(), "");
ContractDefinition const* superContract = contractType->contractDefinition().superContract(*_contract);
returnFunctionDef = &unresolvedFunctionDefinition.resolveVirtual(
*_contract,
superContract
);
}
}
else
{
solAssert(*memberAccess->annotation().requiredLookup == VirtualLookup::Static, "");
returnFunctionDef = &unresolvedFunctionDefinition;
}
}
else if (auto const* identifier = dynamic_cast<Identifier const*>(&_functionCall.expression()))
{
solAssert(*identifier->annotation().requiredLookup == VirtualLookup::Virtual, "");
returnFunctionDef = &unresolvedFunctionDefinition.resolveVirtual(*_contract);
}
if (returnFunctionDef && !returnFunctionDef->isImplemented())
returnFunctionDef = nullptr;
return m_resolveCache[{&_functionCall, _contract}] = returnFunctionDef;
}
void ControlFlowRevertPruner::findRevertStates() void ControlFlowRevertPruner::findRevertStates()
{ {
std::set<CFG::FunctionContractTuple> pendingFunctions = keys(m_functions); std::set<CFG::FunctionContractTuple> pendingFunctions = keys(m_functions);
@ -130,9 +78,9 @@ void ControlFlowRevertPruner::findRevertStates()
for (auto const* functionCall: _node->functionCalls) for (auto const* functionCall: _node->functionCalls)
{ {
auto const* resolvedFunction = resolveCall(*functionCall, item.contract); auto const* resolvedFunction = ASTNode::resolveFunctionCall(*functionCall, item.contract);
if (resolvedFunction == nullptr) if (resolvedFunction == nullptr || !resolvedFunction->isImplemented())
continue; continue;
switch (m_functions.at({findScopeContract(*resolvedFunction, item.contract), resolvedFunction})) switch (m_functions.at({findScopeContract(*resolvedFunction, item.contract), resolvedFunction}))
@ -180,9 +128,9 @@ void ControlFlowRevertPruner::modifyFunctionFlows()
[&](CFGNode* _node, auto&& _addChild) { [&](CFGNode* _node, auto&& _addChild) {
for (auto const* functionCall: _node->functionCalls) for (auto const* functionCall: _node->functionCalls)
{ {
auto const* resolvedFunction = resolveCall(*functionCall, item.first.contract); auto const* resolvedFunction = ASTNode::resolveFunctionCall(*functionCall, item.first.contract);
if (resolvedFunction == nullptr) if (resolvedFunction == nullptr || !resolvedFunction->isImplemented())
continue; continue;
switch (m_functions.at({findScopeContract(*resolvedFunction, item.first.contract), resolvedFunction})) switch (m_functions.at({findScopeContract(*resolvedFunction, item.first.contract), resolvedFunction}))
@ -223,7 +171,11 @@ void ControlFlowRevertPruner::collectCalls(FunctionDefinition const& _function,
solidity::util::BreadthFirstSearch<CFGNode*>{{functionFlow.entry}}.run( solidity::util::BreadthFirstSearch<CFGNode*>{{functionFlow.entry}}.run(
[&](CFGNode* _node, auto&& _addChild) { [&](CFGNode* _node, auto&& _addChild) {
for (auto const* functionCall: _node->functionCalls) for (auto const* functionCall: _node->functionCalls)
m_calledBy[resolveCall(*functionCall, _mostDerivedContract)].insert(pair); {
auto const* funcDef = ASTNode::resolveFunctionCall(*functionCall, _mostDerivedContract);
if (funcDef && funcDef->isImplemented())
m_calledBy[funcDef].insert(pair);
}
for (CFGNode* exit: _node->exits) for (CFGNode* exit: _node->exits)
_addChild(exit); _addChild(exit);

View File

@ -45,15 +45,6 @@ private:
Unknown, Unknown,
}; };
/// Simple attempt at resolving a function call
/// Does not aim to be able to resolve all calls, only used for variable
/// assignment tracking and revert behavior.
/// @param _functionCall the function call to analyse
/// @param _mostDerivedContract most derived contract
/// @returns function definition to which the call resolved or nullptr if no
/// definition was found.
FunctionDefinition const* resolveCall(FunctionCall const& _functionCall, ContractDefinition const* _mostDerivedContract);
/// Identify revert states of all function flows /// Identify revert states of all function flows
void findRevertStates(); void findRevertStates();

View File

@ -57,6 +57,50 @@ Declaration const* ASTNode::referencedDeclaration(Expression const& _expression)
return nullptr; return nullptr;
} }
FunctionDefinition const* ASTNode::resolveFunctionCall(FunctionCall const& _functionCall, ContractDefinition const* _mostDerivedContract)
{
auto const* functionDef = dynamic_cast<FunctionDefinition const*>(
ASTNode::referencedDeclaration(_functionCall.expression())
);
if (!functionDef)
return nullptr;
if (auto const* memberAccess = dynamic_cast<MemberAccess const*>(&_functionCall.expression()))
{
if (*memberAccess->annotation().requiredLookup == VirtualLookup::Super)
{
if (auto const typeType = dynamic_cast<TypeType const*>(memberAccess->expression().annotation().type))
if (auto const contractType = dynamic_cast<ContractType const*>(typeType->actualType()))
{
solAssert(_mostDerivedContract, "");
solAssert(contractType->isSuper(), "");
ContractDefinition const* superContract = contractType->contractDefinition().superContract(*_mostDerivedContract);
return &functionDef->resolveVirtual(
*_mostDerivedContract,
superContract
);
}
}
else
solAssert(*memberAccess->annotation().requiredLookup == VirtualLookup::Static, "");
}
else if (auto const* identifier = dynamic_cast<Identifier const*>(&_functionCall.expression()))
{
solAssert(*identifier->annotation().requiredLookup == VirtualLookup::Virtual, "");
if (functionDef->virtualSemantics())
{
solAssert(_mostDerivedContract, "");
return &functionDef->resolveVirtual(*_mostDerivedContract);
}
}
else
solAssert(false, "");
return functionDef;
}
ASTAnnotation& ASTNode::annotation() const ASTAnnotation& ASTNode::annotation() const
{ {
if (!m_annotation) if (!m_annotation)

View File

@ -104,6 +104,8 @@ public:
/// Extracts the referenced declaration from all nodes whose annotations support /// Extracts the referenced declaration from all nodes whose annotations support
/// `referencedDeclaration`. /// `referencedDeclaration`.
static Declaration const* referencedDeclaration(Expression const& _expression); static Declaration const* referencedDeclaration(Expression const& _expression);
/// Performs potential super or virtual lookup for a function call based on the most derived contract.
static FunctionDefinition const* resolveFunctionCall(FunctionCall const& _functionCall, ContractDefinition const* _mostDerivedContract);
/// Returns the source code location of this node. /// Returns the source code location of this node.
SourceLocation const& location() const { return m_location; } SourceLocation const& location() const { return m_location; }

View File

@ -884,8 +884,6 @@ void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall)
return; return;
} }
auto const* memberAccess = dynamic_cast<MemberAccess const*>(&_functionCall.expression());
switch (functionType->kind()) switch (functionType->kind())
{ {
case FunctionType::Kind::Declaration: case FunctionType::Kind::Declaration:
@ -893,39 +891,7 @@ void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall)
break; break;
case FunctionType::Kind::Internal: case FunctionType::Kind::Internal:
{ {
auto identifier = dynamic_cast<Identifier const*>(&_functionCall.expression()); FunctionDefinition const* functionDef = ASTNode::resolveFunctionCall(_functionCall, &m_context.mostDerivedContract());
auto const* functionDef = dynamic_cast<FunctionDefinition const*>(
ASTNode::referencedDeclaration(_functionCall.expression())
);
if (functionDef)
{
solAssert(memberAccess || identifier, "");
solAssert(functionType->declaration() == *functionDef, "");
if (identifier)
{
solAssert(*identifier->annotation().requiredLookup == VirtualLookup::Virtual, "");
functionDef = &functionDef->resolveVirtual(m_context.mostDerivedContract());
}
else if (auto typeType = dynamic_cast<TypeType const*>(memberAccess->expression().annotation().type))
if (
auto contractType = dynamic_cast<ContractType const*>(typeType->actualType());
contractType->isSuper()
)
{
ContractDefinition const* super = contractType->contractDefinition().superContract(m_context.mostDerivedContract());
solAssert(super, "Super contract not available.");
solAssert(*memberAccess->annotation().requiredLookup == VirtualLookup::Super, "");
functionDef = &functionDef->resolveVirtual(m_context.mostDerivedContract(), super);
}
solAssert(functionDef && functionDef->isImplemented(), "");
solAssert(
functionDef->parameters().size() == arguments.size() + (functionType->bound() ? 1 : 0),
""
);
}
solAssert(!functionType->takesArbitraryParameters(), ""); solAssert(!functionType->takesArbitraryParameters(), "");
@ -937,11 +903,15 @@ void IRGeneratorForStatements::endVisit(FunctionCall const& _functionCall)
args += convert(*arguments[i], *parameterTypes[i]).stackSlots(); args += convert(*arguments[i], *parameterTypes[i]).stackSlots();
if (functionDef) if (functionDef)
{
solAssert(functionDef->isImplemented(), "");
define(_functionCall) << define(_functionCall) <<
m_context.enqueueFunctionForCodeGeneration(*functionDef) << m_context.enqueueFunctionForCodeGeneration(*functionDef) <<
"(" << "(" <<
joinHumanReadable(args) << joinHumanReadable(args) <<
")\n"; ")\n";
}
else else
{ {
YulArity arity = YulArity::fromType(*functionType); YulArity arity = YulArity::fromType(*functionType);