diff --git a/libsolidity/analysis/FunctionCallGraph.cpp b/libsolidity/analysis/FunctionCallGraph.cpp index 8e605b793..117d18efa 100644 --- a/libsolidity/analysis/FunctionCallGraph.cpp +++ b/libsolidity/analysis/FunctionCallGraph.cpp @@ -98,7 +98,7 @@ bool FunctionCallGraphBuilder::visit(Identifier const& _identifier) // For events kind() == Event, so we have an extra check here if (funType && funType->kind() == FunctionType::Kind::Internal) { - processFunction(callable->resolveVirtual(*m_contract), _identifier.annotation()); + processFunction(callable->resolveVirtual(*m_contract), _identifier.annotation().calledDirectly); solAssert(m_currentNode.has_value(), ""); } @@ -125,19 +125,37 @@ void FunctionCallGraphBuilder::endVisit(MemberAccess const& _memberAccess) // Super functions if (*_memberAccess.annotation().requiredLookup == VirtualLookup::Super) { - if (ContractType const* type = dynamic_cast(_memberAccess.expression().annotation().type)) - { - solAssert(type->isSuper(), ""); - functionDef = &functionDef->resolveVirtual(*m_contract, type->contractDefinition().superContract(*m_contract)); - } + if (auto const* typeType = dynamic_cast(_memberAccess.expression().annotation().type)) + if (auto const contractType = dynamic_cast(typeType->actualType())) + { + solAssert(contractType->isSuper(), ""); + functionDef = + &functionDef->resolveVirtual( + *m_contract, + contractType->contractDefinition().superContract(*m_contract) + ); + } } else solAssert(*_memberAccess.annotation().requiredLookup == VirtualLookup::Static, ""); - processFunction(*functionDef, _memberAccess.annotation()); + processFunction(*functionDef, _memberAccess.annotation().calledDirectly); return; } +void FunctionCallGraphBuilder::endVisit(ModifierInvocation const& _modifierInvocation) +{ + if (auto const* modifier = dynamic_cast(_modifierInvocation.name().annotation().referencedDeclaration)) + { + if (*_modifierInvocation.name().annotation().requiredLookup == VirtualLookup::Virtual) + modifier = &modifier->resolveVirtual(*m_contract); + else + solAssert(*_modifierInvocation.name().annotation().requiredLookup == VirtualLookup::Static, ""); + + processFunction(*modifier); + } +} + void FunctionCallGraphBuilder::visitCallable(CallableDeclaration const* _callable, bool _directCall) { solAssert(!m_graph->edges.count(_callable), ""); @@ -175,13 +193,13 @@ bool FunctionCallGraphBuilder::add(Node _caller, Node _callee) return m_graph->edges[_caller].insert(_callee).second; } -void FunctionCallGraphBuilder::processFunction(CallableDeclaration const& _callable, ExpressionAnnotation const& _annotation) +void FunctionCallGraphBuilder::processFunction(CallableDeclaration const& _callable, bool _calledDirectly) { if (m_graph->edges.count(&_callable)) return; // Create edge to creation dispatch - if (!_annotation.calledDirectly) + if (!_calledDirectly) add(m_currentDispatch, &_callable); - visitCallable(&_callable, _annotation.calledDirectly); + visitCallable(&_callable, _calledDirectly); } diff --git a/libsolidity/analysis/FunctionCallGraph.h b/libsolidity/analysis/FunctionCallGraph.h index f90dd056c..6133fc19e 100644 --- a/libsolidity/analysis/FunctionCallGraph.h +++ b/libsolidity/analysis/FunctionCallGraph.h @@ -81,12 +81,13 @@ private: bool visit(Identifier const& _identifier) override; bool visit(NewExpression const& _newExpression) override; void endVisit(MemberAccess const& _memberAccess) override; + void endVisit(ModifierInvocation const& _modifierInvocation) override; void visitCallable(CallableDeclaration const* _callable, bool _directCall = true); void visitConstructor(ContractDefinition const& _contract); bool add(Node _caller, Node _callee); - void processFunction(CallableDeclaration const& _callable, ExpressionAnnotation const& _annotation); + void processFunction(CallableDeclaration const& _callable, bool _calledDirectly = true); ContractDefinition const* m_contract = nullptr; std::optional m_currentNode;