From 93d84f35544b293bae23d0a2750239db3ac09301 Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 9 Dec 2019 23:43:58 +0100 Subject: [PATCH] Split out override checker into its own file. --- libsolidity/CMakeLists.txt | 2 + libsolidity/analysis/ContractLevelChecker.cpp | 620 +--------------- libsolidity/analysis/ContractLevelChecker.h | 62 +- libsolidity/analysis/OverrideChecker.cpp | 674 ++++++++++++++++++ libsolidity/analysis/OverrideChecker.h | 118 +++ 5 files changed, 798 insertions(+), 678 deletions(-) create mode 100644 libsolidity/analysis/OverrideChecker.cpp create mode 100644 libsolidity/analysis/OverrideChecker.h diff --git a/libsolidity/CMakeLists.txt b/libsolidity/CMakeLists.txt index 00f10ea5c..87eb0e9b7 100644 --- a/libsolidity/CMakeLists.txt +++ b/libsolidity/CMakeLists.txt @@ -18,6 +18,8 @@ set(sources analysis/GlobalContext.h analysis/NameAndTypeResolver.cpp analysis/NameAndTypeResolver.h + analysis/OverrideChecker.cpp + analysis/OverrideChecker.h analysis/PostTypeChecker.cpp analysis/PostTypeChecker.h analysis/ReferencesResolver.cpp diff --git a/libsolidity/analysis/ContractLevelChecker.cpp b/libsolidity/analysis/ContractLevelChecker.cpp index 4e49efc68..687511c53 100644 --- a/libsolidity/analysis/ContractLevelChecker.cpp +++ b/libsolidity/analysis/ContractLevelChecker.cpp @@ -37,38 +37,6 @@ using namespace dev::solidity; namespace { -// Helper struct to do a search by name -struct MatchByName -{ - string const& m_name; - bool operator()(CallableDeclaration const* _callable) - { - return _callable->name() == m_name; - } -}; - -vector> sortByContract(vector> const& _list) -{ - auto sorted = _list; - - stable_sort(sorted.begin(), sorted.end(), - [] (ASTPointer _a, ASTPointer _b) { - if (!_a || !_b) - return _a < _b; - - Declaration const* aDecl = _a->annotation().referencedDeclaration; - Declaration const* bDecl = _b->annotation().referencedDeclaration; - - if (!aDecl || !bDecl) - return aDecl < bDecl; - - return aDecl->id() < bDecl->id(); - } - ); - - return sorted; -} - template bool hasEqualNameAndParameters(T const& _a, B const& _b) { @@ -79,61 +47,13 @@ bool hasEqualNameAndParameters(T const& _a, B const& _b) ); } -vector resolveDirectBaseContracts(ContractDefinition const& _contract) -{ - vector resolvedContracts; - - for (ASTPointer const& specifier: _contract.baseContracts()) - { - Declaration const* baseDecl = - specifier->name().annotation().referencedDeclaration; - auto contract = dynamic_cast(baseDecl); - solAssert(contract, "contract is null"); - resolvedContracts.emplace_back(contract); - } - - return resolvedContracts; -} - -} - -bool ContractLevelChecker::LessFunction::operator()(ModifierDefinition const* _a, ModifierDefinition const* _b) const -{ - return _a->name() < _b->name(); -} - -bool ContractLevelChecker::LessFunction::operator()(FunctionDefinition const* _a, FunctionDefinition const* _b) const -{ - if (_a->name() != _b->name()) - return _a->name() < _b->name(); - - if (_a->kind() != _b->kind()) - return _a->kind() < _b->kind(); - - return boost::lexicographical_compare( - FunctionType(*_a).asCallableFunction(false)->parameterTypes(), - FunctionType(*_b).asCallableFunction(false)->parameterTypes(), - [](auto const& _paramTypeA, auto const& _paramTypeB) - { - return _paramTypeA->richIdentifier() < _paramTypeB->richIdentifier(); - } - ); -} - -bool ContractLevelChecker::LessFunction::operator()(ContractDefinition const* _a, ContractDefinition const* _b) const -{ - if (!_a || !_b) - return _a < _b; - - return _a->id() < _b->id(); } bool ContractLevelChecker::check(ContractDefinition const& _contract) { checkDuplicateFunctions(_contract); checkDuplicateEvents(_contract); - checkIllegalOverrides(_contract); - checkAmbiguousOverrides(_contract); + m_overrideChecker.check(_contract); checkBaseConstructorArguments(_contract); checkAbstractFunctions(_contract); checkExternalTypeClashes(_contract); @@ -236,220 +156,6 @@ void ContractLevelChecker::findDuplicateDefinitions(map> const } } -void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _contract) -{ - FunctionMultiSet const& inheritedFuncs = inheritedFunctions(_contract); - ModifierMultiSet const& inheritedMods = inheritedModifiers(_contract); - - for (auto const* stateVar: _contract.stateVariables()) - { - if (!stateVar->isPublic()) - continue; - - bool found = false; - for ( - auto it = find_if(inheritedFuncs.begin(), inheritedFuncs.end(), MatchByName{stateVar->name()}); - it != inheritedFuncs.end(); - it = find_if(++it, inheritedFuncs.end(), MatchByName{stateVar->name()}) - ) - { - if (!hasEqualNameAndParameters(*stateVar, **it)) - continue; - - if ((*it)->visibility() != Declaration::Visibility::External) - overrideError(*stateVar, **it, "Public state variables can only override functions with external visibility."); - else - checkOverride(*stateVar, **it); - - found = true; - } - - if (!found && stateVar->overrides()) - m_errorReporter.typeError( - stateVar->overrides()->location(), - "Public state variable has override specified but does not override anything." - ); - } - - for (ModifierDefinition const* modifier: _contract.functionModifiers()) - { - if (contains_if(inheritedFuncs, MatchByName{modifier->name()})) - m_errorReporter.typeError( - modifier->location(), - "Override changes function to modifier." - ); - - auto [begin, end] = inheritedMods.equal_range(modifier); - - if (begin == end && modifier->overrides()) - m_errorReporter.typeError( - modifier->overrides()->location(), - "Modifier has override specified but does not override anything." - ); - - for (; begin != end; begin++) - if (ModifierType(**begin) != ModifierType(*modifier)) - m_errorReporter.typeError( - modifier->location(), - "Override changes modifier signature." - ); - - checkOverrideList(inheritedMods, *modifier); - } - - for (FunctionDefinition const* function: _contract.definedFunctions()) - { - if (function->isConstructor()) - continue; - - if (contains_if(inheritedMods, MatchByName{function->name()})) - m_errorReporter.typeError(function->location(), "Override changes modifier to function."); - - // No inheriting functions found - if (!inheritedFuncs.count(function) && function->overrides()) - m_errorReporter.typeError( - function->overrides()->location(), - "Function has override specified but does not override anything." - ); - - checkOverrideList(inheritedFuncs, *function); - } -} - -template -void ContractLevelChecker::checkOverride(T const& _overriding, U const& _super) -{ - static_assert( - std::is_same::value || - std::is_same::value || - std::is_same::value, - "Invalid call to checkOverride." - ); - - static_assert( - std::is_same::value || - std::is_same::value, - "Invalid call to checkOverride." - ); - static_assert( - !std::is_same::value || - std::is_same::value, - "Invalid call to checkOverride." - ); - - string overridingName; - if constexpr(std::is_same::value) - overridingName = "function"; - else if constexpr(std::is_same::value) - overridingName = "modifier"; - else - overridingName = "public state variable"; - - string superName; - if constexpr(std::is_same::value) - superName = "function"; - else - superName = "modifier"; - - if (!_overriding.overrides()) - overrideError(_overriding, _super, "Overriding " + overridingName + " is missing 'override' specifier."); - - if (!_super.virtualSemantics()) - overrideError( - _super, - _overriding, - "Trying to override non-virtual " + superName + ". Did you forget to add \"virtual\"?", - "Overriding " + overridingName + " is here:" - ); - - if (_overriding.visibility() != _super.visibility()) - { - // Visibility change from external to public is fine. - // Any other change is disallowed. - if (!( - _super.visibility() == FunctionDefinition::Visibility::External && - _overriding.visibility() == FunctionDefinition::Visibility::Public - )) - overrideError(_overriding, _super, "Overriding " + overridingName + " visibility differs."); - } - - // This is only relevant for overriding functions by functions or state variables, - // it is skipped for modifiers. - if constexpr(std::is_same::value) - { - FunctionTypePointer functionType = FunctionType(_overriding).asCallableFunction(false); - FunctionTypePointer superType = FunctionType(_super).asCallableFunction(false); - - solAssert(functionType->hasEqualParameterTypes(*superType), "Override doesn't have equal parameters!"); - - if (!functionType->hasEqualReturnTypes(*superType)) - overrideError(_overriding, _super, "Overriding " + overridingName + " return types differ."); - - // This is only relevant for a function overriding a function. - if constexpr(std::is_same::value) - { - _overriding.annotation().baseFunctions.emplace(&_super); - - if (_overriding.stateMutability() != _super.stateMutability()) - overrideError( - _overriding, - _super, - "Overriding function changes state mutability from \"" + - stateMutabilityToString(_super.stateMutability()) + - "\" to \"" + - stateMutabilityToString(_overriding.stateMutability()) + - "\"." - ); - - if (!_overriding.isImplemented() && _super.isImplemented()) - overrideError( - _overriding, - _super, - "Overriding an implemented function with an unimplemented function is not allowed." - ); - } - } -} - -void ContractLevelChecker::overrideListError( - CallableDeclaration const& _callable, - set _secondary, - string const& _message1, - string const& _message2 -) -{ - // Using a set rather than a vector so the order is always the same - set names; - SecondarySourceLocation ssl; - for (Declaration const* c: _secondary) - { - ssl.append("This contract: ", c->location()); - names.insert(c->name()); - } - string contractSingularPlural = "contract "; - if (_secondary.size() > 1) - contractSingularPlural = "contracts "; - - m_errorReporter.typeError( - _callable.overrides() ? _callable.overrides()->location() : _callable.location(), - ssl, - _message1 + - contractSingularPlural + - _message2 + - joinHumanReadable(names, ", ", " and ") + - "." - ); -} - -void ContractLevelChecker::overrideError(Declaration const& _overriding, Declaration const& _super, string _message, string _secondaryMsg) -{ - m_errorReporter.typeError( - _overriding.location(), - SecondarySourceLocation().append(_secondaryMsg, _super.location()), - _message - ); -} - void ContractLevelChecker::checkAbstractFunctions(ContractDefinition const& _contract) { // Mapping from name to function definition (exactly one per argument type equality class) and @@ -728,330 +434,6 @@ void ContractLevelChecker::checkBaseABICompatibility(ContractDefinition const& _ } -void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _contract) const -{ - std::function compareById = - [](auto const* _a, auto const* _b) { return _a->id() < _b->id(); }; - - { - // Fetch inherited functions and sort them by signature. - // We get at least one function per signature and direct base contract, which is - // enough because we re-construct the inheritance graph later. - FunctionMultiSet nonOverriddenFunctions = inheritedFunctions(_contract); - // Remove all functions that match the signature of a function in the current contract. - nonOverriddenFunctions -= _contract.definedFunctions(); - - // Walk through the set of functions signature by signature. - for (auto it = nonOverriddenFunctions.cbegin(); it != nonOverriddenFunctions.cend();) - { - std::set baseFunctions(compareById); - for (auto nextSignature = nonOverriddenFunctions.upper_bound(*it); it != nextSignature; ++it) - baseFunctions.insert(*it); - - checkAmbiguousOverridesInternal(std::move(baseFunctions), _contract.location()); - } - } - - { - ModifierMultiSet modifiers = inheritedModifiers(_contract); - modifiers -= _contract.functionModifiers(); - for (auto it = modifiers.cbegin(); it != modifiers.cend();) - { - std::set baseModifiers(compareById); - for (auto next = modifiers.upper_bound(*it); it != next; ++it) - baseModifiers.insert(*it); - - checkAmbiguousOverridesInternal(std::move(baseModifiers), _contract.location()); - } - - } -} - -void ContractLevelChecker::checkAmbiguousOverridesInternal(set< - CallableDeclaration const*, - std::function -> _baseCallables, SourceLocation const& _location) const -{ - if (_baseCallables.size() <= 1) - return; - - // Construct the override graph for this signature. - // Reserve node 0 for the current contract and node - // 1 for an artificial top node to which all override paths - // connect at the end. - struct OverrideGraph - { - OverrideGraph(decltype(_baseCallables) const& __baseCallables) - { - for (auto const* baseFunction: __baseCallables) - addEdge(0, visit(baseFunction)); - } - std::map nodes; - std::map nodeInv; - std::map> edges; - int numNodes = 2; - void addEdge(int _a, int _b) - { - edges[_a].insert(_b); - edges[_b].insert(_a); - } - private: - /// Completes the graph starting from @a _function and - /// @returns the node ID. - int visit(CallableDeclaration const* _function) - { - auto it = nodes.find(_function); - if (it != nodes.end()) - return it->second; - int currentNode = numNodes++; - nodes[_function] = currentNode; - nodeInv[currentNode] = _function; - if (_function->overrides()) - for (auto const* baseFunction: _function->annotation().baseFunctions) - addEdge(currentNode, visit(baseFunction)); - else - addEdge(currentNode, 1); - - return currentNode; - } - } overrideGraph(_baseCallables); - - // Detect cut vertices following https://en.wikipedia.org/wiki/Biconnected_component#Pseudocode - // Can ignore the root node, since it is never a cut vertex in our case. - struct CutVertexFinder - { - CutVertexFinder(OverrideGraph const& _graph): m_graph(_graph) - { - run(); - } - std::set const& cutVertices() const { return m_cutVertices; } - - private: - OverrideGraph const& m_graph; - - std::vector m_visited = std::vector(m_graph.numNodes, false); - std::vector m_depths = std::vector(m_graph.numNodes, -1); - std::vector m_low = std::vector(m_graph.numNodes, -1); - std::vector m_parent = std::vector(m_graph.numNodes, -1); - std::set m_cutVertices{}; - - void run(int _u = 0, int _depth = 0) - { - m_visited.at(_u) = true; - m_depths.at(_u) = m_low.at(_u) = _depth; - for (int v: m_graph.edges.at(_u)) - if (!m_visited.at(v)) - { - m_parent[v] = _u; - run(v, _depth + 1); - if (m_low[v] >= m_depths[_u] && m_parent[_u] != -1) - m_cutVertices.insert(m_graph.nodeInv.at(_u)); - m_low[_u] = min(m_low[_u], m_low[v]); - } - else if (v != m_parent[_u]) - m_low[_u] = min(m_low[_u], m_depths[v]); - } - } cutVertexFinder{overrideGraph}; - - // Remove all base functions overridden by cut vertices (they don't need to be overridden). - for (auto const* function: cutVertexFinder.cutVertices()) - { - std::set toTraverse = function->annotation().baseFunctions; - while (!toTraverse.empty()) - { - auto const *base = *toTraverse.begin(); - toTraverse.erase(toTraverse.begin()); - _baseCallables.erase(base); - for (CallableDeclaration const* f: base->annotation().baseFunctions) - toTraverse.insert(f); - } - // Remove unimplemented base functions at the cut vertices itself as well. - if (auto opt = dynamic_cast(function)) - if (!opt->isImplemented()) - _baseCallables.erase(function); - } - - // If more than one function is left, they have to be overridden. - if (_baseCallables.size() <= 1) - return; - - SecondarySourceLocation ssl; - for (auto const* baseFunction: _baseCallables) - { - string contractName = dynamic_cast(*baseFunction->scope()).name(); - ssl.append("Definition in \"" + contractName + "\": ", baseFunction->location()); - } - - string callableName; - string distinguishigProperty; - if (dynamic_cast(*_baseCallables.begin())) - { - callableName = "function"; - distinguishigProperty = "name and parameter types"; - } - else if (dynamic_cast(*_baseCallables.begin())) - { - callableName = "modifier"; - distinguishigProperty = "name"; - } - else - solAssert(false, "Invalid type for ambiguous override."); - - m_errorReporter.typeError( - _location, - ssl, - "Derived contract must override " + callableName + " \"" + - (*_baseCallables.begin())->name() + - "\". Two or more base classes define " + callableName + " with same " + distinguishigProperty + "." - ); -} - -set ContractLevelChecker::resolveOverrideList(OverrideSpecifier const& _overrides) const -{ - set resolved; - - for (ASTPointer const& override: _overrides.overrides()) - { - Declaration const* decl = override->annotation().referencedDeclaration; - solAssert(decl, "Expected declaration to be resolved."); - - // If it's not a contract it will be caught - // in the reference resolver - if (ContractDefinition const* contract = dynamic_cast(decl)) - resolved.insert(contract); - } - - return resolved; -} - -template -void ContractLevelChecker::checkOverrideList( - std::multiset const& _inheritedCallables, - T const& _callable -) -{ - set specifiedContracts = - _callable.overrides() ? - resolveOverrideList(*_callable.overrides()) : - decltype(specifiedContracts){}; - - // Check for duplicates in override list - if (_callable.overrides() && specifiedContracts.size() != _callable.overrides()->overrides().size()) - { - // Sort by contract id to find duplicate for error reporting - vector> list = - sortByContract(_callable.overrides()->overrides()); - - // Find duplicates and output error - for (size_t i = 1; i < list.size(); i++) - { - Declaration const* aDecl = list[i]->annotation().referencedDeclaration; - Declaration const* bDecl = list[i-1]->annotation().referencedDeclaration; - if (!aDecl || !bDecl) - continue; - - if (aDecl->id() == bDecl->id()) - { - SecondarySourceLocation ssl; - ssl.append("First occurrence here: ", list[i-1]->location()); - m_errorReporter.typeError( - list[i]->location(), - ssl, - "Duplicate contract \"" + - joinHumanReadable(list[i]->namePath(), ".") + - "\" found in override list of \"" + - _callable.name() + - "\"." - ); - } - } - } - - decltype(specifiedContracts) expectedContracts; - - // Build list of expected contracts - for (auto [begin, end] = _inheritedCallables.equal_range(&_callable); begin != end; begin++) - { - // Validate the override - checkOverride(_callable, **begin); - - expectedContracts.insert(&dynamic_cast(*(*begin)->scope())); - } - - decltype(specifiedContracts) missingContracts; - decltype(specifiedContracts) surplusContracts; - - // If we expect only one contract, no contract needs to be specified - if (expectedContracts.size() > 1) - missingContracts = expectedContracts - specifiedContracts; - - surplusContracts = specifiedContracts - expectedContracts; - - if (!missingContracts.empty()) - overrideListError( - _callable, - missingContracts, - "Function needs to specify overridden ", - "" - ); - - if (!surplusContracts.empty()) - overrideListError( - _callable, - surplusContracts, - "Invalid ", - "specified in override list: " - ); -} - -ContractLevelChecker::FunctionMultiSet const& ContractLevelChecker::inheritedFunctions(ContractDefinition const& _contract) const -{ - if (!m_inheritedFunctions.count(&_contract)) - { - FunctionMultiSet set; - - for (auto const* base: resolveDirectBaseContracts(_contract)) - { - std::set functionsInBase; - for (FunctionDefinition const* fun: base->definedFunctions()) - if (!fun->isConstructor()) - functionsInBase.emplace(fun); - - for (auto const& func: inheritedFunctions(*base)) - functionsInBase.insert(func); - - set += functionsInBase; - } - - m_inheritedFunctions[&_contract] = set; - } - - return m_inheritedFunctions[&_contract]; -} - -ContractLevelChecker::ModifierMultiSet const& ContractLevelChecker::inheritedModifiers(ContractDefinition const& _contract) const -{ - auto const& result = m_contractBaseModifiers.find(&_contract); - - if (result != m_contractBaseModifiers.cend()) - return result->second; - - ModifierMultiSet set; - - for (auto const* base: resolveDirectBaseContracts(_contract)) - { - std::set tmpSet = - convertContainer(base->functionModifiers()); - - for (auto const& mod: inheritedModifiers(*base)) - tmpSet.insert(mod); - - set += tmpSet; - } - - return m_contractBaseModifiers[&_contract] = set; -} - void ContractLevelChecker::checkPayableFallbackWithoutReceive(ContractDefinition const& _contract) { if (auto const* fallback = _contract.fallbackFunction()) diff --git a/libsolidity/analysis/ContractLevelChecker.h b/libsolidity/analysis/ContractLevelChecker.h index 0696c2d02..b9ac5b7a9 100644 --- a/libsolidity/analysis/ContractLevelChecker.h +++ b/libsolidity/analysis/ContractLevelChecker.h @@ -22,6 +22,7 @@ #pragma once #include +#include #include #include #include @@ -47,6 +48,7 @@ public: /// @param _errorReporter provides the error logging functionality. explicit ContractLevelChecker(langutil::ErrorReporter& _errorReporter): + m_overrideChecker{_errorReporter}, m_errorReporter(_errorReporter) {} @@ -55,47 +57,12 @@ public: bool check(ContractDefinition const& _contract); private: - /** - * Comparator that compares - * - functions such that equality means that the functions override each other - * - modifiers by name - * - contracts by AST id. - */ - struct LessFunction - { - bool operator()(ModifierDefinition const* _a, ModifierDefinition const* _b) const; - bool operator()(FunctionDefinition const* _a, FunctionDefinition const* _b) const; - bool operator()(ContractDefinition const* _a, ContractDefinition const* _b) const; - }; - - using FunctionMultiSet = std::multiset; - using ModifierMultiSet = std::multiset; - /// Checks that two functions defined in this contract with the same name have different /// arguments and that there is at most one constructor. void checkDuplicateFunctions(ContractDefinition const& _contract); void checkDuplicateEvents(ContractDefinition const& _contract); template void findDuplicateDefinitions(std::map> const& _definitions, std::string _message); - void checkIllegalOverrides(ContractDefinition const& _contract); - /// Performs various checks related to @a _overriding overriding @a _super like - /// different return type, invalid visibility change, etc. - /// Works on functions, modifiers and public state variables. - /// Also stores @a _super as a base function of @a _function in its AST annotations. - template - void checkOverride(T const& _overriding, U const& _super); - void overrideListError( - CallableDeclaration const& _callable, - std::set _secondary, - std::string const& _message1, - std::string const& _message2 - ); - void overrideError( - Declaration const& _overriding, - Declaration const& _super, - std::string _message, - std::string _secondaryMsg = "Overridden function is here:" - ); void checkAbstractFunctions(ContractDefinition const& _contract); /// Checks that the base constructor arguments are properly provided. /// Fills the list of unimplemented functions in _contract's annotations. @@ -114,35 +81,12 @@ private: void checkLibraryRequirements(ContractDefinition const& _contract); /// Checks base contracts for ABI compatibility void checkBaseABICompatibility(ContractDefinition const& _contract); - /// Checks for functions in different base contracts which conflict with each - /// other and thus need to be overridden explicitly. - void checkAmbiguousOverrides(ContractDefinition const& _contract) const; - void checkAmbiguousOverridesInternal(std::set< - CallableDeclaration const*, - std::function - > _baseCallables, langutil::SourceLocation const& _location) const; - /// Resolves an override list of UserDefinedTypeNames to a list of contracts. - std::set resolveOverrideList(OverrideSpecifier const& _overrides) const; - - template - void checkOverrideList( - std::multiset const& _funcSet, - T const& _function - ); - - /// Returns all functions of bases that have not yet been overwritten. - /// May contain the same function multiple times when used with shared bases. - FunctionMultiSet const& inheritedFunctions(ContractDefinition const& _contract) const; - ModifierMultiSet const& inheritedModifiers(ContractDefinition const& _contract) const; /// Warns if the contract has a payable fallback, but no receive ether function. void checkPayableFallbackWithoutReceive(ContractDefinition const& _contract); + OverrideChecker m_overrideChecker; langutil::ErrorReporter& m_errorReporter; - - /// Cache for inheritedFunctions(). - std::map mutable m_inheritedFunctions; - std::map mutable m_contractBaseModifiers; }; } diff --git a/libsolidity/analysis/OverrideChecker.cpp b/libsolidity/analysis/OverrideChecker.cpp new file mode 100644 index 000000000..922070171 --- /dev/null +++ b/libsolidity/analysis/OverrideChecker.cpp @@ -0,0 +1,674 @@ +/* + This file is part of solidity. + + solidity is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + solidity is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with solidity. If not, see . +*/ +/** + * Component that verifies overloads, abstract contracts, function clashes and others + * checks at contract or function level. + */ + +#include + +#include +#include +#include +#include +#include +#include + + +using namespace std; +using namespace dev; +using namespace langutil; +using namespace dev::solidity; + +namespace +{ + +// Helper struct to do a search by name +struct MatchByName +{ + string const& m_name; + bool operator()(CallableDeclaration const* _callable) + { + return _callable->name() == m_name; + } +}; + +vector> sortByContract(vector> const& _list) +{ + auto sorted = _list; + + stable_sort(sorted.begin(), sorted.end(), + [] (ASTPointer _a, ASTPointer _b) { + if (!_a || !_b) + return _a < _b; + + Declaration const* aDecl = _a->annotation().referencedDeclaration; + Declaration const* bDecl = _b->annotation().referencedDeclaration; + + if (!aDecl || !bDecl) + return aDecl < bDecl; + + return aDecl->id() < bDecl->id(); + } + ); + + return sorted; +} + +template +bool hasEqualNameAndParameters(T const& _a, B const& _b) +{ + return + _a.name() == _b.name() && + FunctionType(_a).asCallableFunction(false)->hasEqualParameterTypes( + *FunctionType(_b).asCallableFunction(false) + ); +} + +vector resolveDirectBaseContracts(ContractDefinition const& _contract) +{ + vector resolvedContracts; + + for (ASTPointer const& specifier: _contract.baseContracts()) + { + Declaration const* baseDecl = + specifier->name().annotation().referencedDeclaration; + auto contract = dynamic_cast(baseDecl); + solAssert(contract, "contract is null"); + resolvedContracts.emplace_back(contract); + } + + return resolvedContracts; +} + +} + +bool OverrideChecker::LessFunction::operator()(ModifierDefinition const* _a, ModifierDefinition const* _b) const +{ + return _a->name() < _b->name(); +} + +bool OverrideChecker::LessFunction::operator()(FunctionDefinition const* _a, FunctionDefinition const* _b) const +{ + if (_a->name() != _b->name()) + return _a->name() < _b->name(); + + if (_a->kind() != _b->kind()) + return _a->kind() < _b->kind(); + + return boost::lexicographical_compare( + FunctionType(*_a).asCallableFunction(false)->parameterTypes(), + FunctionType(*_b).asCallableFunction(false)->parameterTypes(), + [](auto const& _paramTypeA, auto const& _paramTypeB) + { + return _paramTypeA->richIdentifier() < _paramTypeB->richIdentifier(); + } + ); +} + +bool OverrideChecker::LessFunction::operator()(ContractDefinition const* _a, ContractDefinition const* _b) const +{ + if (!_a || !_b) + return _a < _b; + + return _a->id() < _b->id(); +} + +void OverrideChecker::check(ContractDefinition const& _contract) +{ + checkIllegalOverrides(_contract); + checkAmbiguousOverrides(_contract); +} + +void OverrideChecker::checkIllegalOverrides(ContractDefinition const& _contract) +{ + FunctionMultiSet const& inheritedFuncs = inheritedFunctions(_contract); + ModifierMultiSet const& inheritedMods = inheritedModifiers(_contract); + + for (auto const* stateVar: _contract.stateVariables()) + { + if (!stateVar->isPublic()) + continue; + + bool found = false; + for ( + auto it = find_if(inheritedFuncs.begin(), inheritedFuncs.end(), MatchByName{stateVar->name()}); + it != inheritedFuncs.end(); + it = find_if(++it, inheritedFuncs.end(), MatchByName{stateVar->name()}) + ) + { + if (!hasEqualNameAndParameters(*stateVar, **it)) + continue; + + if ((*it)->visibility() != Declaration::Visibility::External) + overrideError(*stateVar, **it, "Public state variables can only override functions with external visibility."); + else + checkOverride(*stateVar, **it); + + found = true; + } + + if (!found && stateVar->overrides()) + m_errorReporter.typeError( + stateVar->overrides()->location(), + "Public state variable has override specified but does not override anything." + ); + } + + for (ModifierDefinition const* modifier: _contract.functionModifiers()) + { + if (contains_if(inheritedFuncs, MatchByName{modifier->name()})) + m_errorReporter.typeError( + modifier->location(), + "Override changes function to modifier." + ); + + auto [begin, end] = inheritedMods.equal_range(modifier); + + if (begin == end && modifier->overrides()) + m_errorReporter.typeError( + modifier->overrides()->location(), + "Modifier has override specified but does not override anything." + ); + + for (; begin != end; begin++) + if (ModifierType(**begin) != ModifierType(*modifier)) + m_errorReporter.typeError( + modifier->location(), + "Override changes modifier signature." + ); + + checkOverrideList(inheritedMods, *modifier); + } + + for (FunctionDefinition const* function: _contract.definedFunctions()) + { + if (function->isConstructor()) + continue; + + if (contains_if(inheritedMods, MatchByName{function->name()})) + m_errorReporter.typeError(function->location(), "Override changes modifier to function."); + + // No inheriting functions found + if (!inheritedFuncs.count(function) && function->overrides()) + m_errorReporter.typeError( + function->overrides()->location(), + "Function has override specified but does not override anything." + ); + + checkOverrideList(inheritedFuncs, *function); + } +} + +template +void OverrideChecker::checkOverride(T const& _overriding, U const& _super) +{ + static_assert( + std::is_same::value || + std::is_same::value || + std::is_same::value, + "Invalid call to checkOverride." + ); + + static_assert( + std::is_same::value || + std::is_same::value, + "Invalid call to checkOverride." + ); + static_assert( + !std::is_same::value || + std::is_same::value, + "Invalid call to checkOverride." + ); + + string overridingName; + if constexpr(std::is_same::value) + overridingName = "function"; + else if constexpr(std::is_same::value) + overridingName = "modifier"; + else + overridingName = "public state variable"; + + string superName; + if constexpr(std::is_same::value) + superName = "function"; + else + superName = "modifier"; + + if (!_overriding.overrides()) + overrideError(_overriding, _super, "Overriding " + overridingName + " is missing 'override' specifier."); + + if (!_super.virtualSemantics()) + overrideError( + _super, + _overriding, + "Trying to override non-virtual " + superName + ". Did you forget to add \"virtual\"?", + "Overriding " + overridingName + " is here:" + ); + + if (_overriding.visibility() != _super.visibility()) + { + // Visibility change from external to public is fine. + // Any other change is disallowed. + if (!( + _super.visibility() == FunctionDefinition::Visibility::External && + _overriding.visibility() == FunctionDefinition::Visibility::Public + )) + overrideError(_overriding, _super, "Overriding " + overridingName + " visibility differs."); + } + + // This is only relevant for overriding functions by functions or state variables, + // it is skipped for modifiers. + if constexpr(std::is_same::value) + { + FunctionTypePointer functionType = FunctionType(_overriding).asCallableFunction(false); + FunctionTypePointer superType = FunctionType(_super).asCallableFunction(false); + + solAssert(functionType->hasEqualParameterTypes(*superType), "Override doesn't have equal parameters!"); + + if (!functionType->hasEqualReturnTypes(*superType)) + overrideError(_overriding, _super, "Overriding " + overridingName + " return types differ."); + + // This is only relevant for a function overriding a function. + if constexpr(std::is_same::value) + { + _overriding.annotation().baseFunctions.emplace(&_super); + + if (_overriding.stateMutability() != _super.stateMutability()) + overrideError( + _overriding, + _super, + "Overriding function changes state mutability from \"" + + stateMutabilityToString(_super.stateMutability()) + + "\" to \"" + + stateMutabilityToString(_overriding.stateMutability()) + + "\"." + ); + + if (!_overriding.isImplemented() && _super.isImplemented()) + overrideError( + _overriding, + _super, + "Overriding an implemented function with an unimplemented function is not allowed." + ); + } + } +} + +void OverrideChecker::overrideListError( + CallableDeclaration const& _callable, + set _secondary, + string const& _message1, + string const& _message2 +) +{ + // Using a set rather than a vector so the order is always the same + set names; + SecondarySourceLocation ssl; + for (Declaration const* c: _secondary) + { + ssl.append("This contract: ", c->location()); + names.insert(c->name()); + } + string contractSingularPlural = "contract "; + if (_secondary.size() > 1) + contractSingularPlural = "contracts "; + + m_errorReporter.typeError( + _callable.overrides() ? _callable.overrides()->location() : _callable.location(), + ssl, + _message1 + + contractSingularPlural + + _message2 + + joinHumanReadable(names, ", ", " and ") + + "." + ); +} + +void OverrideChecker::overrideError(Declaration const& _overriding, Declaration const& _super, string _message, string _secondaryMsg) +{ + m_errorReporter.typeError( + _overriding.location(), + SecondarySourceLocation().append(_secondaryMsg, _super.location()), + _message + ); +} + +void OverrideChecker::checkAmbiguousOverrides(ContractDefinition const& _contract) const +{ + std::function compareById = + [](auto const* _a, auto const* _b) { return _a->id() < _b->id(); }; + + { + // Fetch inherited functions and sort them by signature. + // We get at least one function per signature and direct base contract, which is + // enough because we re-construct the inheritance graph later. + FunctionMultiSet nonOverriddenFunctions = inheritedFunctions(_contract); + // Remove all functions that match the signature of a function in the current contract. + nonOverriddenFunctions -= _contract.definedFunctions(); + + // Walk through the set of functions signature by signature. + for (auto it = nonOverriddenFunctions.cbegin(); it != nonOverriddenFunctions.cend();) + { + std::set baseFunctions(compareById); + for (auto nextSignature = nonOverriddenFunctions.upper_bound(*it); it != nextSignature; ++it) + baseFunctions.insert(*it); + + checkAmbiguousOverridesInternal(std::move(baseFunctions), _contract.location()); + } + } + + { + ModifierMultiSet modifiers = inheritedModifiers(_contract); + modifiers -= _contract.functionModifiers(); + for (auto it = modifiers.cbegin(); it != modifiers.cend();) + { + std::set baseModifiers(compareById); + for (auto next = modifiers.upper_bound(*it); it != next; ++it) + baseModifiers.insert(*it); + + checkAmbiguousOverridesInternal(std::move(baseModifiers), _contract.location()); + } + + } +} + +void OverrideChecker::checkAmbiguousOverridesInternal(set< + CallableDeclaration const*, + std::function +> _baseCallables, SourceLocation const& _location) const +{ + if (_baseCallables.size() <= 1) + return; + + // Construct the override graph for this signature. + // Reserve node 0 for the current contract and node + // 1 for an artificial top node to which all override paths + // connect at the end. + struct OverrideGraph + { + OverrideGraph(decltype(_baseCallables) const& __baseCallables) + { + for (auto const* baseFunction: __baseCallables) + addEdge(0, visit(baseFunction)); + } + std::map nodes; + std::map nodeInv; + std::map> edges; + int numNodes = 2; + void addEdge(int _a, int _b) + { + edges[_a].insert(_b); + edges[_b].insert(_a); + } + private: + /// Completes the graph starting from @a _function and + /// @returns the node ID. + int visit(CallableDeclaration const* _function) + { + auto it = nodes.find(_function); + if (it != nodes.end()) + return it->second; + int currentNode = numNodes++; + nodes[_function] = currentNode; + nodeInv[currentNode] = _function; + if (_function->overrides()) + for (auto const* baseFunction: _function->annotation().baseFunctions) + addEdge(currentNode, visit(baseFunction)); + else + addEdge(currentNode, 1); + + return currentNode; + } + } overrideGraph(_baseCallables); + + // Detect cut vertices following https://en.wikipedia.org/wiki/Biconnected_component#Pseudocode + // Can ignore the root node, since it is never a cut vertex in our case. + struct CutVertexFinder + { + CutVertexFinder(OverrideGraph const& _graph): m_graph(_graph) + { + run(); + } + std::set const& cutVertices() const { return m_cutVertices; } + + private: + OverrideGraph const& m_graph; + + std::vector m_visited = std::vector(m_graph.numNodes, false); + std::vector m_depths = std::vector(m_graph.numNodes, -1); + std::vector m_low = std::vector(m_graph.numNodes, -1); + std::vector m_parent = std::vector(m_graph.numNodes, -1); + std::set m_cutVertices{}; + + void run(int _u = 0, int _depth = 0) + { + m_visited.at(_u) = true; + m_depths.at(_u) = m_low.at(_u) = _depth; + for (int v: m_graph.edges.at(_u)) + if (!m_visited.at(v)) + { + m_parent[v] = _u; + run(v, _depth + 1); + if (m_low[v] >= m_depths[_u] && m_parent[_u] != -1) + m_cutVertices.insert(m_graph.nodeInv.at(_u)); + m_low[_u] = min(m_low[_u], m_low[v]); + } + else if (v != m_parent[_u]) + m_low[_u] = min(m_low[_u], m_depths[v]); + } + } cutVertexFinder{overrideGraph}; + + // Remove all base functions overridden by cut vertices (they don't need to be overridden). + for (auto const* function: cutVertexFinder.cutVertices()) + { + std::set toTraverse = function->annotation().baseFunctions; + while (!toTraverse.empty()) + { + auto const *base = *toTraverse.begin(); + toTraverse.erase(toTraverse.begin()); + _baseCallables.erase(base); + for (CallableDeclaration const* f: base->annotation().baseFunctions) + toTraverse.insert(f); + } + // Remove unimplemented base functions at the cut vertices itself as well. + if (auto opt = dynamic_cast(function)) + if (!opt->isImplemented()) + _baseCallables.erase(function); + } + + // If more than one function is left, they have to be overridden. + if (_baseCallables.size() <= 1) + return; + + SecondarySourceLocation ssl; + for (auto const* baseFunction: _baseCallables) + { + string contractName = dynamic_cast(*baseFunction->scope()).name(); + ssl.append("Definition in \"" + contractName + "\": ", baseFunction->location()); + } + + string callableName; + string distinguishigProperty; + if (dynamic_cast(*_baseCallables.begin())) + { + callableName = "function"; + distinguishigProperty = "name and parameter types"; + } + else if (dynamic_cast(*_baseCallables.begin())) + { + callableName = "modifier"; + distinguishigProperty = "name"; + } + else + solAssert(false, "Invalid type for ambiguous override."); + + m_errorReporter.typeError( + _location, + ssl, + "Derived contract must override " + callableName + " \"" + + (*_baseCallables.begin())->name() + + "\". Two or more base classes define " + callableName + " with same " + distinguishigProperty + "." + ); +} + +set OverrideChecker::resolveOverrideList(OverrideSpecifier const& _overrides) const +{ + set resolved; + + for (ASTPointer const& override: _overrides.overrides()) + { + Declaration const* decl = override->annotation().referencedDeclaration; + solAssert(decl, "Expected declaration to be resolved."); + + // If it's not a contract it will be caught + // in the reference resolver + if (ContractDefinition const* contract = dynamic_cast(decl)) + resolved.insert(contract); + } + + return resolved; +} + +template +void OverrideChecker::checkOverrideList( + std::multiset const& _inheritedCallables, + T const& _callable +) +{ + set specifiedContracts = + _callable.overrides() ? + resolveOverrideList(*_callable.overrides()) : + decltype(specifiedContracts){}; + + // Check for duplicates in override list + if (_callable.overrides() && specifiedContracts.size() != _callable.overrides()->overrides().size()) + { + // Sort by contract id to find duplicate for error reporting + vector> list = + sortByContract(_callable.overrides()->overrides()); + + // Find duplicates and output error + for (size_t i = 1; i < list.size(); i++) + { + Declaration const* aDecl = list[i]->annotation().referencedDeclaration; + Declaration const* bDecl = list[i-1]->annotation().referencedDeclaration; + if (!aDecl || !bDecl) + continue; + + if (aDecl->id() == bDecl->id()) + { + SecondarySourceLocation ssl; + ssl.append("First occurrence here: ", list[i-1]->location()); + m_errorReporter.typeError( + list[i]->location(), + ssl, + "Duplicate contract \"" + + joinHumanReadable(list[i]->namePath(), ".") + + "\" found in override list of \"" + + _callable.name() + + "\"." + ); + } + } + } + + decltype(specifiedContracts) expectedContracts; + + // Build list of expected contracts + for (auto [begin, end] = _inheritedCallables.equal_range(&_callable); begin != end; begin++) + { + // Validate the override + checkOverride(_callable, **begin); + + expectedContracts.insert(&dynamic_cast(*(*begin)->scope())); + } + + decltype(specifiedContracts) missingContracts; + decltype(specifiedContracts) surplusContracts; + + // If we expect only one contract, no contract needs to be specified + if (expectedContracts.size() > 1) + missingContracts = expectedContracts - specifiedContracts; + + surplusContracts = specifiedContracts - expectedContracts; + + if (!missingContracts.empty()) + overrideListError( + _callable, + missingContracts, + "Function needs to specify overridden ", + "" + ); + + if (!surplusContracts.empty()) + overrideListError( + _callable, + surplusContracts, + "Invalid ", + "specified in override list: " + ); +} + +OverrideChecker::FunctionMultiSet const& OverrideChecker::inheritedFunctions(ContractDefinition const& _contract) const +{ + if (!m_inheritedFunctions.count(&_contract)) + { + FunctionMultiSet set; + + for (auto const* base: resolveDirectBaseContracts(_contract)) + { + std::set functionsInBase; + for (FunctionDefinition const* fun: base->definedFunctions()) + if (!fun->isConstructor()) + functionsInBase.emplace(fun); + + for (auto const& func: inheritedFunctions(*base)) + functionsInBase.insert(func); + + set += functionsInBase; + } + + m_inheritedFunctions[&_contract] = set; + } + + return m_inheritedFunctions[&_contract]; +} + +OverrideChecker::ModifierMultiSet const& OverrideChecker::inheritedModifiers(ContractDefinition const& _contract) const +{ + auto const& result = m_contractBaseModifiers.find(&_contract); + + if (result != m_contractBaseModifiers.cend()) + return result->second; + + ModifierMultiSet set; + + for (auto const* base: resolveDirectBaseContracts(_contract)) + { + std::set tmpSet = + convertContainer(base->functionModifiers()); + + for (auto const& mod: inheritedModifiers(*base)) + tmpSet.insert(mod); + + set += tmpSet; + } + + return m_contractBaseModifiers[&_contract] = set; +} + diff --git a/libsolidity/analysis/OverrideChecker.h b/libsolidity/analysis/OverrideChecker.h new file mode 100644 index 000000000..59801e2ef --- /dev/null +++ b/libsolidity/analysis/OverrideChecker.h @@ -0,0 +1,118 @@ +/* + This file is part of solidity. + + solidity is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + solidity is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with solidity. If not, see . +*/ +/** + * Component that verifies override properties. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace langutil +{ +class ErrorReporter; +} + +namespace dev +{ +namespace solidity +{ + +/** + * Component that verifies override properties. + */ +class OverrideChecker +{ +public: + + /// @param _errorReporter provides the error logging functionality. + explicit OverrideChecker(langutil::ErrorReporter& _errorReporter): + m_errorReporter(_errorReporter) + {} + + void check(ContractDefinition const& _contract); + +private: + /** + * Comparator that compares + * - functions such that equality means that the functions override each other + * - modifiers by name + * - contracts by AST id. + */ + struct LessFunction + { + bool operator()(ModifierDefinition const* _a, ModifierDefinition const* _b) const; + bool operator()(FunctionDefinition const* _a, FunctionDefinition const* _b) const; + bool operator()(ContractDefinition const* _a, ContractDefinition const* _b) const; + }; + + using FunctionMultiSet = std::multiset; + using ModifierMultiSet = std::multiset; + + void checkIllegalOverrides(ContractDefinition const& _contract); + /// Performs various checks related to @a _overriding overriding @a _super like + /// different return type, invalid visibility change, etc. + /// Works on functions, modifiers and public state variables. + /// Also stores @a _super as a base function of @a _function in its AST annotations. + template + void checkOverride(T const& _overriding, U const& _super); + void overrideListError( + CallableDeclaration const& _callable, + std::set _secondary, + std::string const& _message1, + std::string const& _message2 + ); + void overrideError( + Declaration const& _overriding, + Declaration const& _super, + std::string _message, + std::string _secondaryMsg = "Overridden function is here:" + ); + /// Checks for functions in different base contracts which conflict with each + /// other and thus need to be overridden explicitly. + void checkAmbiguousOverrides(ContractDefinition const& _contract) const; + void checkAmbiguousOverridesInternal(std::set< + CallableDeclaration const*, + std::function + > _baseCallables, langutil::SourceLocation const& _location) const; + /// Resolves an override list of UserDefinedTypeNames to a list of contracts. + std::set resolveOverrideList(OverrideSpecifier const& _overrides) const; + + template + void checkOverrideList( + std::multiset const& _funcSet, + T const& _function + ); + + /// Returns all functions of bases that have not yet been overwritten. + /// May contain the same function multiple times when used with shared bases. + FunctionMultiSet const& inheritedFunctions(ContractDefinition const& _contract) const; + ModifierMultiSet const& inheritedModifiers(ContractDefinition const& _contract) const; + + langutil::ErrorReporter& m_errorReporter; + + /// Cache for inheritedFunctions(). + std::map mutable m_inheritedFunctions; + std::map mutable m_contractBaseModifiers; +}; + +} +}