From e1d6ce2b662eb925104a09ab5cf1462994091674 Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 6 Dec 2019 10:07:56 +0100 Subject: [PATCH] Override checks for modifiers. --- libsolidity/analysis/ContractLevelChecker.cpp | 212 +++++++++++------- libsolidity/analysis/ContractLevelChecker.h | 34 ++- test/libsolidity/SolidityEndToEndTest.cpp | 4 +- .../override/modifier_ambiguous.sol | 9 + .../override/modifier_ambiguous_fail.sol | 10 + .../modifiers/illegal_modifier_override.sol | 3 +- .../modifiers/legal_modifier_override.sol | 2 +- .../non-virtual_modifier_override.sol | 4 + 8 files changed, 182 insertions(+), 96 deletions(-) create mode 100644 test/libsolidity/syntaxTests/inheritance/override/modifier_ambiguous.sol create mode 100644 test/libsolidity/syntaxTests/inheritance/override/modifier_ambiguous_fail.sol create mode 100644 test/libsolidity/syntaxTests/modifiers/non-virtual_modifier_override.sol diff --git a/libsolidity/analysis/ContractLevelChecker.cpp b/libsolidity/analysis/ContractLevelChecker.cpp index f581f9976..b9ee92f84 100644 --- a/libsolidity/analysis/ContractLevelChecker.cpp +++ b/libsolidity/analysis/ContractLevelChecker.cpp @@ -238,10 +238,8 @@ void ContractLevelChecker::findDuplicateDefinitions(map> const void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _contract) { - FunctionMultiSet const& funcSet = inheritedFunctions(_contract); - ModifierMultiSet const& modSet = inheritedModifiers(_contract); - - checkModifierOverrides(funcSet, modSet, _contract.functionModifiers()); + FunctionMultiSet const& inheritedFuncs = inheritedFunctions(_contract); + ModifierMultiSet const& inheritedMods = inheritedModifiers(_contract); for (auto const* stateVar: _contract.stateVariables()) { @@ -250,9 +248,9 @@ void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _cont bool found = false; for ( - auto it = find_if(funcSet.begin(), funcSet.end(), MatchByName{stateVar->name()}); - it != funcSet.end(); - it = find_if(++it, funcSet.end(), MatchByName{stateVar->name()}) + auto it = find_if(inheritedFuncs.begin(), inheritedFuncs.end(), MatchByName{stateVar->name()}); + it != inheritedFuncs.end(); + it = find_if(++it, inheritedFuncs.end(), MatchByName{stateVar->name()}) ) { if (!hasEqualNameAndParameters(*stateVar, **it)) @@ -261,7 +259,7 @@ void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _cont if ((*it)->visibility() != Declaration::Visibility::External) overrideError(*stateVar, **it, "Public state variables can only override functions with external visibility."); else - checkFunctionOverride(*stateVar, **it); + checkOverride(*stateVar, **it); found = true; } @@ -273,51 +271,96 @@ void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _cont ); } + for (ModifierDefinition const* modifier: _contract.functionModifiers()) + { + if (contains_if(inheritedFuncs, MatchByName{modifier->name()})) + m_errorReporter.typeError( + modifier->location(), + "Override changes function to modifier." + ); + + auto [begin, end] = inheritedMods.equal_range(modifier); + + if (begin == end && modifier->overrides()) + m_errorReporter.typeError( + modifier->overrides()->location(), + "Modifier has override specified but does not override anything." + ); + + for (; begin != end; begin++) + if (ModifierType(**begin) != ModifierType(*modifier)) + m_errorReporter.typeError( + modifier->location(), + "Override changes modifier signature." + ); + + checkOverrideList(inheritedMods, *modifier); + } + for (FunctionDefinition const* function: _contract.definedFunctions()) { if (function->isConstructor()) continue; - if (contains_if(modSet, MatchByName{function->name()})) + if (contains_if(inheritedMods, MatchByName{function->name()})) m_errorReporter.typeError(function->location(), "Override changes modifier to function."); // No inheriting functions found - if (!funcSet.count(function) && function->overrides()) + if (!inheritedFuncs.count(function) && function->overrides()) m_errorReporter.typeError( function->overrides()->location(), "Function has override specified but does not override anything." ); - checkOverrideList(funcSet, *function); + checkOverrideList(inheritedFuncs, *function); } } -template -void ContractLevelChecker::checkFunctionOverride(T const& _overriding, FunctionDefinition const& _super) +template +void ContractLevelChecker::checkOverride(T const& _overriding, U const& _super) { - string overridingName; + static_assert( + std::is_same::value || + std::is_same::value || + std::is_same::value, + "Invalid call to checkOverride." + ); + static_assert( + std::is_same::value || + std::is_same::value, + "Invalid call to checkOverride." + ); + static_assert( + !std::is_same::value || + std::is_same::value, + "Invalid call to checkOverride." + ); + + string overridingName; if constexpr(std::is_same::value) overridingName = "function"; + else if constexpr(std::is_same::value) + overridingName = "modifier"; else overridingName = "public state variable"; - FunctionTypePointer functionType = FunctionType(_overriding).asCallableFunction(false); - FunctionTypePointer superType = FunctionType(_super).asCallableFunction(false); - - solAssert(functionType->hasEqualParameterTypes(*superType), "Override doesn't have equal parameters!"); + string superName; + if constexpr(std::is_same::value) + superName = "function"; + else + superName = "modifier"; if (!_overriding.overrides()) overrideError(_overriding, _super, "Overriding " + overridingName + " is missing 'override' specifier."); if (!_super.virtualSemantics()) - overrideError( _super, _overriding, "Trying to override non-virtual function. Did you forget to add \"virtual\"?", "Overriding " + overridingName + " is here:"); - - if (!functionType->hasEqualReturnTypes(*superType)) - overrideError(_overriding, _super, "Overriding " + overridingName + " return types differ."); - - if constexpr(std::is_same::value) - _overriding.annotation().baseFunctions.emplace(&_super); + overrideError( + _super, + _overriding, + "Trying to override non-virtual " + superName + ". Did you forget to add \"virtual\"?", + "Overriding " + overridingName + " is here:" + ); if (_overriding.visibility() != _super.visibility()) { @@ -330,29 +373,50 @@ void ContractLevelChecker::checkFunctionOverride(T const& _overriding, FunctionD overrideError(_overriding, _super, "Overriding " + overridingName + " visibility differs."); } - if constexpr(std::is_same::value) + // This is only relevant for overriding functions by functions or state variables, + // it is skipped for modifiers. + if constexpr(std::is_same::value) { - if (_overriding.stateMutability() != _super.stateMutability()) - overrideError( - _overriding, - _super, - "Overriding function changes state mutability from \"" + - stateMutabilityToString(_super.stateMutability()) + - "\" to \"" + - stateMutabilityToString(_overriding.stateMutability()) + - "\"." - ); + FunctionTypePointer functionType = FunctionType(_overriding).asCallableFunction(false); + FunctionTypePointer superType = FunctionType(_super).asCallableFunction(false); - if (!_overriding.isImplemented() && _super.isImplemented()) - overrideError( - _overriding, - _super, - "Overriding an implemented function with an unimplemented function is not allowed." - ); + solAssert(functionType->hasEqualParameterTypes(*superType), "Override doesn't have equal parameters!"); + + if (!functionType->hasEqualReturnTypes(*superType)) + overrideError(_overriding, _super, "Overriding " + overridingName + " return types differ."); + + // This is only relevant for a function overriding a function. + if constexpr(std::is_same::value) + { + _overriding.annotation().baseFunctions.emplace(&_super); + + if (_overriding.stateMutability() != _super.stateMutability()) + overrideError( + _overriding, + _super, + "Overriding function changes state mutability from \"" + + stateMutabilityToString(_super.stateMutability()) + + "\" to \"" + + stateMutabilityToString(_overriding.stateMutability()) + + "\"." + ); + + if (!_overriding.isImplemented() && _super.isImplemented()) + overrideError( + _overriding, + _super, + "Overriding an implemented function with an unimplemented function is not allowed." + ); + } } } -void ContractLevelChecker::overrideListError(FunctionDefinition const& function, set _secondary, string const& _message1, string const& _message2) +void ContractLevelChecker::overrideListError( + CallableDeclaration const& _callable, + set _secondary, + string const& _message1, + string const& _message2 +) { // Using a set rather than a vector so the order is always the same set names; @@ -367,7 +431,7 @@ void ContractLevelChecker::overrideListError(FunctionDefinition const& function, contractSingularPlural = "contracts "; m_errorReporter.typeError( - function.overrides() ? function.overrides()->location() : function.location(), + _callable.overrides() ? _callable.overrides()->location() : _callable.location(), ssl, _message1 + contractSingularPlural + @@ -666,6 +730,10 @@ void ContractLevelChecker::checkBaseABICompatibility(ContractDefinition const& _ void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _contract) const { + // TODO same for modifiers. + + + // Fetch inherited functions and sort them by signature. // We get at least one function per signature and direct base contract, which is // enough because we re-construct the inheritance graph later. @@ -815,49 +883,23 @@ set ContractLevel return resolved; } - -void ContractLevelChecker::checkModifierOverrides(FunctionMultiSet const& _funcSet, ModifierMultiSet const& _modSet, std::vector _modifiers) -{ - for (ModifierDefinition const* modifier: _modifiers) - { - if (contains_if(_funcSet, MatchByName{modifier->name()})) - m_errorReporter.typeError( - modifier->location(), - "Override changes function to modifier." - ); - - auto [begin,end] = _modSet.equal_range(modifier); - - // Skip if no modifiers found in bases - if (begin == end) - continue; - - if (!modifier->overrides()) - overrideError(*modifier, **begin, "Overriding modifier is missing 'override' specifier."); - - for (; begin != end; begin++) - if (ModifierType(**begin) != ModifierType(*modifier)) - m_errorReporter.typeError( - modifier->location(), - "Override changes modifier signature." - ); - } - -} - -void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, FunctionDefinition const& _function) +template +void ContractLevelChecker::checkOverrideList( + std::multiset const& _inheritedCallables, + T const& _callable +) { set specifiedContracts = - _function.overrides() ? - resolveOverrideList(*_function.overrides()) : + _callable.overrides() ? + resolveOverrideList(*_callable.overrides()) : decltype(specifiedContracts){}; // Check for duplicates in override list - if (_function.overrides() && specifiedContracts.size() != _function.overrides()->overrides().size()) + if (_callable.overrides() && specifiedContracts.size() != _callable.overrides()->overrides().size()) { // Sort by contract id to find duplicate for error reporting vector> list = - sortByContract(_function.overrides()->overrides()); + sortByContract(_callable.overrides()->overrides()); // Find duplicates and output error for (size_t i = 1; i < list.size(); i++) @@ -877,7 +919,7 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F "Duplicate contract \"" + joinHumanReadable(list[i]->namePath(), ".") + "\" found in override list of \"" + - _function.name() + + _callable.name() + "\"." ); } @@ -887,12 +929,12 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F decltype(specifiedContracts) expectedContracts; // Build list of expected contracts - for (auto [begin, end] = _funcSet.equal_range(&_function); begin != end; begin++) + for (auto [begin, end] = _inheritedCallables.equal_range(&_callable); begin != end; begin++) { // Validate the override - checkFunctionOverride(_function, **begin); + checkOverride(_callable, **begin); - expectedContracts.insert((*begin)->annotation().contract); + expectedContracts.insert(&dynamic_cast(*(*begin)->scope())); } decltype(specifiedContracts) missingContracts; @@ -906,7 +948,7 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F if (!missingContracts.empty()) overrideListError( - _function, + _callable, missingContracts, "Function needs to specify overridden ", "" @@ -914,7 +956,7 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F if (!surplusContracts.empty()) overrideListError( - _function, + _callable, surplusContracts, "Invalid ", "specified in override list: " diff --git a/libsolidity/analysis/ContractLevelChecker.h b/libsolidity/analysis/ContractLevelChecker.h index bcc54a643..111191724 100644 --- a/libsolidity/analysis/ContractLevelChecker.h +++ b/libsolidity/analysis/ContractLevelChecker.h @@ -53,6 +53,12 @@ public: bool check(ContractDefinition const& _contract); private: + /** + * Comparator that compares + * - functions such that equality means that the functions override each other + * - modifiers by name + * - contracts by AST id. + */ struct LessFunction { bool operator()(ModifierDefinition const* _a, ModifierDefinition const* _b) const; @@ -70,13 +76,24 @@ private: template void findDuplicateDefinitions(std::map> const& _definitions, std::string _message); void checkIllegalOverrides(ContractDefinition const& _contract); - /// Performs various checks related to @a _function overriding @a _super like + /// Performs various checks related to @a _overriding overriding @a _super like /// different return type, invalid visibility change, etc. + /// Works on functions, modifiers and public state variables. /// Also stores @a _super as a base function of @a _function in its AST annotations. - template - void checkFunctionOverride(T const& _overriding, FunctionDefinition const& _super); - void overrideListError(FunctionDefinition const& function, std::set _secondary, std::string const& _message1, std::string const& _message2); - void overrideError(Declaration const& _overriding, Declaration const& _super, std::string _message, std::string _secondaryMsg = "Overridden function is here:"); + template + void checkOverride(T const& _overriding, U const& _super); + void overrideListError( + CallableDeclaration const& _callable, + std::set _secondary, + std::string const& _message1, + std::string const& _message2 + ); + void overrideError( + Declaration const& _overriding, + Declaration const& _super, + std::string _message, + std::string _secondaryMsg = "Overridden function is here:" + ); void checkAbstractFunctions(ContractDefinition const& _contract); /// Checks that the base constructor arguments are properly provided. /// Fills the list of unimplemented functions in _contract's annotations. @@ -101,8 +118,11 @@ private: /// Resolves an override list of UserDefinedTypeNames to a list of contracts. std::set resolveOverrideList(OverrideSpecifier const& _overrides) const; - void checkModifierOverrides(FunctionMultiSet const& _funcSet, ModifierMultiSet const& _modSet, std::vector _modifiers); - void checkOverrideList(FunctionMultiSet const& _funcSet, FunctionDefinition const& _function); + template + void checkOverrideList( + std::multiset const& _funcSet, + T const& _function + ); /// Returns all functions of bases that have not yet been overwritten. /// May contain the same function multiple times when used with shared bases. diff --git a/test/libsolidity/SolidityEndToEndTest.cpp b/test/libsolidity/SolidityEndToEndTest.cpp index d6f959957..3a72912fe 100644 --- a/test/libsolidity/SolidityEndToEndTest.cpp +++ b/test/libsolidity/SolidityEndToEndTest.cpp @@ -2313,7 +2313,7 @@ BOOST_AUTO_TEST_CASE(function_modifier_overriding) char const* sourceCode = R"( contract A { function f() mod public returns (bool r) { return true; } - modifier mod { _; } + modifier mod virtual { _; } } contract C is A { modifier mod override { if (false) _; } @@ -2352,7 +2352,7 @@ BOOST_AUTO_TEST_CASE(function_modifier_for_constructor) contract A { uint data; constructor() mod1 public { data |= 2; } - modifier mod1 { data |= 1; _; } + modifier mod1 virtual { data |= 1; _; } function getData() public returns (uint r) { return data; } } contract C is A { diff --git a/test/libsolidity/syntaxTests/inheritance/override/modifier_ambiguous.sol b/test/libsolidity/syntaxTests/inheritance/override/modifier_ambiguous.sol new file mode 100644 index 000000000..fd459252c --- /dev/null +++ b/test/libsolidity/syntaxTests/inheritance/override/modifier_ambiguous.sol @@ -0,0 +1,9 @@ +contract A { + modifier f() virtual { _; } +} +contract B { + modifier f() virtual { _; } +} +contract C is A, B { + modifier f() override(A,B) { _; } +} diff --git a/test/libsolidity/syntaxTests/inheritance/override/modifier_ambiguous_fail.sol b/test/libsolidity/syntaxTests/inheritance/override/modifier_ambiguous_fail.sol new file mode 100644 index 000000000..82a417fa8 --- /dev/null +++ b/test/libsolidity/syntaxTests/inheritance/override/modifier_ambiguous_fail.sol @@ -0,0 +1,10 @@ +contract A { + modifier f() virtual { _; } +} +contract B { + modifier f() virtual { _; } +} +contract C is A, B { +} +// ---- +// THIS NEEDS TO BE AN ERROR diff --git a/test/libsolidity/syntaxTests/modifiers/illegal_modifier_override.sol b/test/libsolidity/syntaxTests/modifiers/illegal_modifier_override.sol index 3abac89c1..0b2087e7e 100644 --- a/test/libsolidity/syntaxTests/modifiers/illegal_modifier_override.sol +++ b/test/libsolidity/syntaxTests/modifiers/illegal_modifier_override.sol @@ -1,5 +1,6 @@ contract A { modifier mod(uint a) { _; } } contract B is A { modifier mod(uint8 a) { _; } } // ---- -// TypeError: (61-89): Overriding modifier is missing 'override' specifier. // TypeError: (61-89): Override changes modifier signature. +// TypeError: (61-89): Overriding modifier is missing 'override' specifier. +// TypeError: (13-40): Trying to override non-virtual modifier. Did you forget to add "virtual"? diff --git a/test/libsolidity/syntaxTests/modifiers/legal_modifier_override.sol b/test/libsolidity/syntaxTests/modifiers/legal_modifier_override.sol index a661193e6..9ee7cb027 100644 --- a/test/libsolidity/syntaxTests/modifiers/legal_modifier_override.sol +++ b/test/libsolidity/syntaxTests/modifiers/legal_modifier_override.sol @@ -1,2 +1,2 @@ -contract A { modifier mod(uint a) { _; } } +contract A { modifier mod(uint a) virtual { _; } } contract B is A { modifier mod(uint a) override { _; } } diff --git a/test/libsolidity/syntaxTests/modifiers/non-virtual_modifier_override.sol b/test/libsolidity/syntaxTests/modifiers/non-virtual_modifier_override.sol new file mode 100644 index 000000000..81d7ba901 --- /dev/null +++ b/test/libsolidity/syntaxTests/modifiers/non-virtual_modifier_override.sol @@ -0,0 +1,4 @@ +contract A { modifier mod(uint a) { _; } } +contract B is A { modifier mod(uint a) override { _; } } +// ---- +// TypeError: (13-40): Trying to override non-virtual modifier. Did you forget to add "virtual"?