Override checks for modifiers.

This commit is contained in:
chriseth 2019-12-06 10:07:56 +01:00
parent 871a5b83ff
commit e1d6ce2b66
8 changed files with 182 additions and 96 deletions

View File

@ -238,10 +238,8 @@ void ContractLevelChecker::findDuplicateDefinitions(map<string, vector<T>> const
void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _contract) void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _contract)
{ {
FunctionMultiSet const& funcSet = inheritedFunctions(_contract); FunctionMultiSet const& inheritedFuncs = inheritedFunctions(_contract);
ModifierMultiSet const& modSet = inheritedModifiers(_contract); ModifierMultiSet const& inheritedMods = inheritedModifiers(_contract);
checkModifierOverrides(funcSet, modSet, _contract.functionModifiers());
for (auto const* stateVar: _contract.stateVariables()) for (auto const* stateVar: _contract.stateVariables())
{ {
@ -250,9 +248,9 @@ void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _cont
bool found = false; bool found = false;
for ( for (
auto it = find_if(funcSet.begin(), funcSet.end(), MatchByName{stateVar->name()}); auto it = find_if(inheritedFuncs.begin(), inheritedFuncs.end(), MatchByName{stateVar->name()});
it != funcSet.end(); it != inheritedFuncs.end();
it = find_if(++it, funcSet.end(), MatchByName{stateVar->name()}) it = find_if(++it, inheritedFuncs.end(), MatchByName{stateVar->name()})
) )
{ {
if (!hasEqualNameAndParameters(*stateVar, **it)) if (!hasEqualNameAndParameters(*stateVar, **it))
@ -261,7 +259,7 @@ void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _cont
if ((*it)->visibility() != Declaration::Visibility::External) if ((*it)->visibility() != Declaration::Visibility::External)
overrideError(*stateVar, **it, "Public state variables can only override functions with external visibility."); overrideError(*stateVar, **it, "Public state variables can only override functions with external visibility.");
else else
checkFunctionOverride(*stateVar, **it); checkOverride(*stateVar, **it);
found = true; found = true;
} }
@ -273,51 +271,96 @@ void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _cont
); );
} }
for (ModifierDefinition const* modifier: _contract.functionModifiers())
{
if (contains_if(inheritedFuncs, MatchByName{modifier->name()}))
m_errorReporter.typeError(
modifier->location(),
"Override changes function to modifier."
);
auto [begin, end] = inheritedMods.equal_range(modifier);
if (begin == end && modifier->overrides())
m_errorReporter.typeError(
modifier->overrides()->location(),
"Modifier has override specified but does not override anything."
);
for (; begin != end; begin++)
if (ModifierType(**begin) != ModifierType(*modifier))
m_errorReporter.typeError(
modifier->location(),
"Override changes modifier signature."
);
checkOverrideList(inheritedMods, *modifier);
}
for (FunctionDefinition const* function: _contract.definedFunctions()) for (FunctionDefinition const* function: _contract.definedFunctions())
{ {
if (function->isConstructor()) if (function->isConstructor())
continue; continue;
if (contains_if(modSet, MatchByName{function->name()})) if (contains_if(inheritedMods, MatchByName{function->name()}))
m_errorReporter.typeError(function->location(), "Override changes modifier to function."); m_errorReporter.typeError(function->location(), "Override changes modifier to function.");
// No inheriting functions found // No inheriting functions found
if (!funcSet.count(function) && function->overrides()) if (!inheritedFuncs.count(function) && function->overrides())
m_errorReporter.typeError( m_errorReporter.typeError(
function->overrides()->location(), function->overrides()->location(),
"Function has override specified but does not override anything." "Function has override specified but does not override anything."
); );
checkOverrideList(funcSet, *function); checkOverrideList(inheritedFuncs, *function);
} }
} }
template<class T> template<class T, class U>
void ContractLevelChecker::checkFunctionOverride(T const& _overriding, FunctionDefinition const& _super) void ContractLevelChecker::checkOverride(T const& _overriding, U const& _super)
{ {
string overridingName; static_assert(
std::is_same<VariableDeclaration, T>::value ||
std::is_same<FunctionDefinition, T>::value ||
std::is_same<ModifierDefinition, T>::value,
"Invalid call to checkOverride."
);
static_assert(
std::is_same<FunctionDefinition, U>::value ||
std::is_same<ModifierDefinition, U>::value,
"Invalid call to checkOverride."
);
static_assert(
!std::is_same<ModifierDefinition, U>::value ||
std::is_same<ModifierDefinition, T>::value,
"Invalid call to checkOverride."
);
string overridingName;
if constexpr(std::is_same<FunctionDefinition, T>::value) if constexpr(std::is_same<FunctionDefinition, T>::value)
overridingName = "function"; overridingName = "function";
else if constexpr(std::is_same<ModifierDefinition, T>::value)
overridingName = "modifier";
else else
overridingName = "public state variable"; overridingName = "public state variable";
FunctionTypePointer functionType = FunctionType(_overriding).asCallableFunction(false); string superName;
FunctionTypePointer superType = FunctionType(_super).asCallableFunction(false); if constexpr(std::is_same<FunctionDefinition, U>::value)
superName = "function";
solAssert(functionType->hasEqualParameterTypes(*superType), "Override doesn't have equal parameters!"); else
superName = "modifier";
if (!_overriding.overrides()) if (!_overriding.overrides())
overrideError(_overriding, _super, "Overriding " + overridingName + " is missing 'override' specifier."); overrideError(_overriding, _super, "Overriding " + overridingName + " is missing 'override' specifier.");
if (!_super.virtualSemantics()) if (!_super.virtualSemantics())
overrideError( _super, _overriding, "Trying to override non-virtual function. Did you forget to add \"virtual\"?", "Overriding " + overridingName + " is here:"); overrideError(
_super,
if (!functionType->hasEqualReturnTypes(*superType)) _overriding,
overrideError(_overriding, _super, "Overriding " + overridingName + " return types differ."); "Trying to override non-virtual " + superName + ". Did you forget to add \"virtual\"?",
"Overriding " + overridingName + " is here:"
if constexpr(std::is_same<T, FunctionDefinition>::value) );
_overriding.annotation().baseFunctions.emplace(&_super);
if (_overriding.visibility() != _super.visibility()) if (_overriding.visibility() != _super.visibility())
{ {
@ -330,8 +373,23 @@ void ContractLevelChecker::checkFunctionOverride(T const& _overriding, FunctionD
overrideError(_overriding, _super, "Overriding " + overridingName + " visibility differs."); overrideError(_overriding, _super, "Overriding " + overridingName + " visibility differs.");
} }
// This is only relevant for overriding functions by functions or state variables,
// it is skipped for modifiers.
if constexpr(std::is_same<FunctionDefinition, U>::value)
{
FunctionTypePointer functionType = FunctionType(_overriding).asCallableFunction(false);
FunctionTypePointer superType = FunctionType(_super).asCallableFunction(false);
solAssert(functionType->hasEqualParameterTypes(*superType), "Override doesn't have equal parameters!");
if (!functionType->hasEqualReturnTypes(*superType))
overrideError(_overriding, _super, "Overriding " + overridingName + " return types differ.");
// This is only relevant for a function overriding a function.
if constexpr(std::is_same<T, FunctionDefinition>::value) if constexpr(std::is_same<T, FunctionDefinition>::value)
{ {
_overriding.annotation().baseFunctions.emplace(&_super);
if (_overriding.stateMutability() != _super.stateMutability()) if (_overriding.stateMutability() != _super.stateMutability())
overrideError( overrideError(
_overriding, _overriding,
@ -350,9 +408,15 @@ void ContractLevelChecker::checkFunctionOverride(T const& _overriding, FunctionD
"Overriding an implemented function with an unimplemented function is not allowed." "Overriding an implemented function with an unimplemented function is not allowed."
); );
} }
}
} }
void ContractLevelChecker::overrideListError(FunctionDefinition const& function, set<ContractDefinition const*, LessFunction> _secondary, string const& _message1, string const& _message2) void ContractLevelChecker::overrideListError(
CallableDeclaration const& _callable,
set<ContractDefinition const*, LessFunction> _secondary,
string const& _message1,
string const& _message2
)
{ {
// Using a set rather than a vector so the order is always the same // Using a set rather than a vector so the order is always the same
set<string> names; set<string> names;
@ -367,7 +431,7 @@ void ContractLevelChecker::overrideListError(FunctionDefinition const& function,
contractSingularPlural = "contracts "; contractSingularPlural = "contracts ";
m_errorReporter.typeError( m_errorReporter.typeError(
function.overrides() ? function.overrides()->location() : function.location(), _callable.overrides() ? _callable.overrides()->location() : _callable.location(),
ssl, ssl,
_message1 + _message1 +
contractSingularPlural + contractSingularPlural +
@ -666,6 +730,10 @@ void ContractLevelChecker::checkBaseABICompatibility(ContractDefinition const& _
void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _contract) const void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _contract) const
{ {
// TODO same for modifiers.
// Fetch inherited functions and sort them by signature. // Fetch inherited functions and sort them by signature.
// We get at least one function per signature and direct base contract, which is // We get at least one function per signature and direct base contract, which is
// enough because we re-construct the inheritance graph later. // enough because we re-construct the inheritance graph later.
@ -815,49 +883,23 @@ set<ContractDefinition const*, ContractLevelChecker::LessFunction> ContractLevel
return resolved; return resolved;
} }
template <class T>
void ContractLevelChecker::checkModifierOverrides(FunctionMultiSet const& _funcSet, ModifierMultiSet const& _modSet, std::vector<ModifierDefinition const*> _modifiers) void ContractLevelChecker::checkOverrideList(
{ std::multiset<T const*, LessFunction> const& _inheritedCallables,
for (ModifierDefinition const* modifier: _modifiers) T const& _callable
{ )
if (contains_if(_funcSet, MatchByName{modifier->name()}))
m_errorReporter.typeError(
modifier->location(),
"Override changes function to modifier."
);
auto [begin,end] = _modSet.equal_range(modifier);
// Skip if no modifiers found in bases
if (begin == end)
continue;
if (!modifier->overrides())
overrideError(*modifier, **begin, "Overriding modifier is missing 'override' specifier.");
for (; begin != end; begin++)
if (ModifierType(**begin) != ModifierType(*modifier))
m_errorReporter.typeError(
modifier->location(),
"Override changes modifier signature."
);
}
}
void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, FunctionDefinition const& _function)
{ {
set<ContractDefinition const*, LessFunction> specifiedContracts = set<ContractDefinition const*, LessFunction> specifiedContracts =
_function.overrides() ? _callable.overrides() ?
resolveOverrideList(*_function.overrides()) : resolveOverrideList(*_callable.overrides()) :
decltype(specifiedContracts){}; decltype(specifiedContracts){};
// Check for duplicates in override list // Check for duplicates in override list
if (_function.overrides() && specifiedContracts.size() != _function.overrides()->overrides().size()) if (_callable.overrides() && specifiedContracts.size() != _callable.overrides()->overrides().size())
{ {
// Sort by contract id to find duplicate for error reporting // Sort by contract id to find duplicate for error reporting
vector<ASTPointer<UserDefinedTypeName>> list = vector<ASTPointer<UserDefinedTypeName>> list =
sortByContract(_function.overrides()->overrides()); sortByContract(_callable.overrides()->overrides());
// Find duplicates and output error // Find duplicates and output error
for (size_t i = 1; i < list.size(); i++) for (size_t i = 1; i < list.size(); i++)
@ -877,7 +919,7 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F
"Duplicate contract \"" + "Duplicate contract \"" +
joinHumanReadable(list[i]->namePath(), ".") + joinHumanReadable(list[i]->namePath(), ".") +
"\" found in override list of \"" + "\" found in override list of \"" +
_function.name() + _callable.name() +
"\"." "\"."
); );
} }
@ -887,12 +929,12 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F
decltype(specifiedContracts) expectedContracts; decltype(specifiedContracts) expectedContracts;
// Build list of expected contracts // Build list of expected contracts
for (auto [begin, end] = _funcSet.equal_range(&_function); begin != end; begin++) for (auto [begin, end] = _inheritedCallables.equal_range(&_callable); begin != end; begin++)
{ {
// Validate the override // Validate the override
checkFunctionOverride(_function, **begin); checkOverride(_callable, **begin);
expectedContracts.insert((*begin)->annotation().contract); expectedContracts.insert(&dynamic_cast<ContractDefinition const&>(*(*begin)->scope()));
} }
decltype(specifiedContracts) missingContracts; decltype(specifiedContracts) missingContracts;
@ -906,7 +948,7 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F
if (!missingContracts.empty()) if (!missingContracts.empty())
overrideListError( overrideListError(
_function, _callable,
missingContracts, missingContracts,
"Function needs to specify overridden ", "Function needs to specify overridden ",
"" ""
@ -914,7 +956,7 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F
if (!surplusContracts.empty()) if (!surplusContracts.empty())
overrideListError( overrideListError(
_function, _callable,
surplusContracts, surplusContracts,
"Invalid ", "Invalid ",
"specified in override list: " "specified in override list: "

View File

@ -53,6 +53,12 @@ public:
bool check(ContractDefinition const& _contract); bool check(ContractDefinition const& _contract);
private: private:
/**
* Comparator that compares
* - functions such that equality means that the functions override each other
* - modifiers by name
* - contracts by AST id.
*/
struct LessFunction struct LessFunction
{ {
bool operator()(ModifierDefinition const* _a, ModifierDefinition const* _b) const; bool operator()(ModifierDefinition const* _a, ModifierDefinition const* _b) const;
@ -70,13 +76,24 @@ private:
template <class T> template <class T>
void findDuplicateDefinitions(std::map<std::string, std::vector<T>> const& _definitions, std::string _message); void findDuplicateDefinitions(std::map<std::string, std::vector<T>> const& _definitions, std::string _message);
void checkIllegalOverrides(ContractDefinition const& _contract); void checkIllegalOverrides(ContractDefinition const& _contract);
/// Performs various checks related to @a _function overriding @a _super like /// Performs various checks related to @a _overriding overriding @a _super like
/// different return type, invalid visibility change, etc. /// different return type, invalid visibility change, etc.
/// Works on functions, modifiers and public state variables.
/// Also stores @a _super as a base function of @a _function in its AST annotations. /// Also stores @a _super as a base function of @a _function in its AST annotations.
template<class T> template<class T, class U>
void checkFunctionOverride(T const& _overriding, FunctionDefinition const& _super); void checkOverride(T const& _overriding, U const& _super);
void overrideListError(FunctionDefinition const& function, std::set<ContractDefinition const*, LessFunction> _secondary, std::string const& _message1, std::string const& _message2); void overrideListError(
void overrideError(Declaration const& _overriding, Declaration const& _super, std::string _message, std::string _secondaryMsg = "Overridden function is here:"); CallableDeclaration const& _callable,
std::set<ContractDefinition const*, LessFunction> _secondary,
std::string const& _message1,
std::string const& _message2
);
void overrideError(
Declaration const& _overriding,
Declaration const& _super,
std::string _message,
std::string _secondaryMsg = "Overridden function is here:"
);
void checkAbstractFunctions(ContractDefinition const& _contract); void checkAbstractFunctions(ContractDefinition const& _contract);
/// Checks that the base constructor arguments are properly provided. /// Checks that the base constructor arguments are properly provided.
/// Fills the list of unimplemented functions in _contract's annotations. /// Fills the list of unimplemented functions in _contract's annotations.
@ -101,8 +118,11 @@ private:
/// Resolves an override list of UserDefinedTypeNames to a list of contracts. /// Resolves an override list of UserDefinedTypeNames to a list of contracts.
std::set<ContractDefinition const*, LessFunction> resolveOverrideList(OverrideSpecifier const& _overrides) const; std::set<ContractDefinition const*, LessFunction> resolveOverrideList(OverrideSpecifier const& _overrides) const;
void checkModifierOverrides(FunctionMultiSet const& _funcSet, ModifierMultiSet const& _modSet, std::vector<ModifierDefinition const*> _modifiers); template <class T>
void checkOverrideList(FunctionMultiSet const& _funcSet, FunctionDefinition const& _function); void checkOverrideList(
std::multiset<T const*, LessFunction> const& _funcSet,
T const& _function
);
/// Returns all functions of bases that have not yet been overwritten. /// Returns all functions of bases that have not yet been overwritten.
/// May contain the same function multiple times when used with shared bases. /// May contain the same function multiple times when used with shared bases.

View File

@ -2313,7 +2313,7 @@ BOOST_AUTO_TEST_CASE(function_modifier_overriding)
char const* sourceCode = R"( char const* sourceCode = R"(
contract A { contract A {
function f() mod public returns (bool r) { return true; } function f() mod public returns (bool r) { return true; }
modifier mod { _; } modifier mod virtual { _; }
} }
contract C is A { contract C is A {
modifier mod override { if (false) _; } modifier mod override { if (false) _; }
@ -2352,7 +2352,7 @@ BOOST_AUTO_TEST_CASE(function_modifier_for_constructor)
contract A { contract A {
uint data; uint data;
constructor() mod1 public { data |= 2; } constructor() mod1 public { data |= 2; }
modifier mod1 { data |= 1; _; } modifier mod1 virtual { data |= 1; _; }
function getData() public returns (uint r) { return data; } function getData() public returns (uint r) { return data; }
} }
contract C is A { contract C is A {

View File

@ -0,0 +1,9 @@
contract A {
modifier f() virtual { _; }
}
contract B {
modifier f() virtual { _; }
}
contract C is A, B {
modifier f() override(A,B) { _; }
}

View File

@ -0,0 +1,10 @@
contract A {
modifier f() virtual { _; }
}
contract B {
modifier f() virtual { _; }
}
contract C is A, B {
}
// ----
// THIS NEEDS TO BE AN ERROR

View File

@ -1,5 +1,6 @@
contract A { modifier mod(uint a) { _; } } contract A { modifier mod(uint a) { _; } }
contract B is A { modifier mod(uint8 a) { _; } } contract B is A { modifier mod(uint8 a) { _; } }
// ---- // ----
// TypeError: (61-89): Overriding modifier is missing 'override' specifier.
// TypeError: (61-89): Override changes modifier signature. // TypeError: (61-89): Override changes modifier signature.
// TypeError: (61-89): Overriding modifier is missing 'override' specifier.
// TypeError: (13-40): Trying to override non-virtual modifier. Did you forget to add "virtual"?

View File

@ -1,2 +1,2 @@
contract A { modifier mod(uint a) { _; } } contract A { modifier mod(uint a) virtual { _; } }
contract B is A { modifier mod(uint a) override { _; } } contract B is A { modifier mod(uint a) override { _; } }

View File

@ -0,0 +1,4 @@
contract A { modifier mod(uint a) { _; } }
contract B is A { modifier mod(uint a) override { _; } }
// ----
// TypeError: (13-40): Trying to override non-virtual modifier. Did you forget to add "virtual"?