Merge pull request #7912 from ethereum/overrideForModifiers

Override for modifiers
This commit is contained in:
chriseth 2019-12-09 19:37:02 +01:00 committed by GitHub
commit bffd999b20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 442 additions and 242 deletions

View File

@ -305,8 +305,9 @@ Modifier Overriding
=================== ===================
Function modifiers can override each other. This works in the same way as Function modifiers can override each other. This works in the same way as
function overriding (except that there is no overloading for modifiers). The `function overriding <function-overriding>`_ (except that there is no overloading for modifiers). The
``override`` keyword must be used in the overriding contract: ``virtual`` keyword must be used on the overridden modifier
and the ``override`` keyword must be used in the overriding modifier:
:: ::
@ -314,7 +315,7 @@ function overriding (except that there is no overloading for modifiers). The
contract Base contract Base
{ {
modifier foo() {_;} modifier foo() virtual {_;}
} }
contract Inherited is Base contract Inherited is Base
@ -332,12 +333,12 @@ explicitly:
contract Base1 contract Base1
{ {
modifier foo() {_;} modifier foo() virtual {_;}
} }
contract Base2 contract Base2
{ {
modifier foo() {_;} modifier foo() virtual {_;}
} }
contract Inherited is Base1, Base2 contract Inherited is Base1, Base2

View File

@ -238,10 +238,8 @@ void ContractLevelChecker::findDuplicateDefinitions(map<string, vector<T>> const
void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _contract) void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _contract)
{ {
FunctionMultiSet const& funcSet = inheritedFunctions(_contract); FunctionMultiSet const& inheritedFuncs = inheritedFunctions(_contract);
ModifierMultiSet const& modSet = inheritedModifiers(_contract); ModifierMultiSet const& inheritedMods = inheritedModifiers(_contract);
checkModifierOverrides(funcSet, modSet, _contract.functionModifiers());
for (auto const* stateVar: _contract.stateVariables()) for (auto const* stateVar: _contract.stateVariables())
{ {
@ -250,9 +248,9 @@ void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _cont
bool found = false; bool found = false;
for ( for (
auto it = find_if(funcSet.begin(), funcSet.end(), MatchByName{stateVar->name()}); auto it = find_if(inheritedFuncs.begin(), inheritedFuncs.end(), MatchByName{stateVar->name()});
it != funcSet.end(); it != inheritedFuncs.end();
it = find_if(++it, funcSet.end(), MatchByName{stateVar->name()}) it = find_if(++it, inheritedFuncs.end(), MatchByName{stateVar->name()})
) )
{ {
if (!hasEqualNameAndParameters(*stateVar, **it)) if (!hasEqualNameAndParameters(*stateVar, **it))
@ -261,7 +259,7 @@ void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _cont
if ((*it)->visibility() != Declaration::Visibility::External) if ((*it)->visibility() != Declaration::Visibility::External)
overrideError(*stateVar, **it, "Public state variables can only override functions with external visibility."); overrideError(*stateVar, **it, "Public state variables can only override functions with external visibility.");
else else
checkFunctionOverride(*stateVar, **it); checkOverride(*stateVar, **it);
found = true; found = true;
} }
@ -273,51 +271,96 @@ void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _cont
); );
} }
for (ModifierDefinition const* modifier: _contract.functionModifiers())
{
if (contains_if(inheritedFuncs, MatchByName{modifier->name()}))
m_errorReporter.typeError(
modifier->location(),
"Override changes function to modifier."
);
auto [begin, end] = inheritedMods.equal_range(modifier);
if (begin == end && modifier->overrides())
m_errorReporter.typeError(
modifier->overrides()->location(),
"Modifier has override specified but does not override anything."
);
for (; begin != end; begin++)
if (ModifierType(**begin) != ModifierType(*modifier))
m_errorReporter.typeError(
modifier->location(),
"Override changes modifier signature."
);
checkOverrideList(inheritedMods, *modifier);
}
for (FunctionDefinition const* function: _contract.definedFunctions()) for (FunctionDefinition const* function: _contract.definedFunctions())
{ {
if (function->isConstructor()) if (function->isConstructor())
continue; continue;
if (contains_if(modSet, MatchByName{function->name()})) if (contains_if(inheritedMods, MatchByName{function->name()}))
m_errorReporter.typeError(function->location(), "Override changes modifier to function."); m_errorReporter.typeError(function->location(), "Override changes modifier to function.");
// No inheriting functions found // No inheriting functions found
if (!funcSet.count(function) && function->overrides()) if (!inheritedFuncs.count(function) && function->overrides())
m_errorReporter.typeError( m_errorReporter.typeError(
function->overrides()->location(), function->overrides()->location(),
"Function has override specified but does not override anything." "Function has override specified but does not override anything."
); );
checkOverrideList(funcSet, *function); checkOverrideList(inheritedFuncs, *function);
} }
} }
template<class T> template<class T, class U>
void ContractLevelChecker::checkFunctionOverride(T const& _overriding, FunctionDefinition const& _super) void ContractLevelChecker::checkOverride(T const& _overriding, U const& _super)
{ {
string overridingName; static_assert(
std::is_same<VariableDeclaration, T>::value ||
std::is_same<FunctionDefinition, T>::value ||
std::is_same<ModifierDefinition, T>::value,
"Invalid call to checkOverride."
);
static_assert(
std::is_same<FunctionDefinition, U>::value ||
std::is_same<ModifierDefinition, U>::value,
"Invalid call to checkOverride."
);
static_assert(
!std::is_same<ModifierDefinition, U>::value ||
std::is_same<ModifierDefinition, T>::value,
"Invalid call to checkOverride."
);
string overridingName;
if constexpr(std::is_same<FunctionDefinition, T>::value) if constexpr(std::is_same<FunctionDefinition, T>::value)
overridingName = "function"; overridingName = "function";
else if constexpr(std::is_same<ModifierDefinition, T>::value)
overridingName = "modifier";
else else
overridingName = "public state variable"; overridingName = "public state variable";
FunctionTypePointer functionType = FunctionType(_overriding).asCallableFunction(false); string superName;
FunctionTypePointer superType = FunctionType(_super).asCallableFunction(false); if constexpr(std::is_same<FunctionDefinition, U>::value)
superName = "function";
solAssert(functionType->hasEqualParameterTypes(*superType), "Override doesn't have equal parameters!"); else
superName = "modifier";
if (!_overriding.overrides()) if (!_overriding.overrides())
overrideError(_overriding, _super, "Overriding " + overridingName + " is missing 'override' specifier."); overrideError(_overriding, _super, "Overriding " + overridingName + " is missing 'override' specifier.");
if (!_super.virtualSemantics()) if (!_super.virtualSemantics())
overrideError( _super, _overriding, "Trying to override non-virtual function. Did you forget to add \"virtual\"?", "Overriding " + overridingName + " is here:"); overrideError(
_super,
if (!functionType->hasEqualReturnTypes(*superType)) _overriding,
overrideError(_overriding, _super, "Overriding " + overridingName + " return types differ."); "Trying to override non-virtual " + superName + ". Did you forget to add \"virtual\"?",
"Overriding " + overridingName + " is here:"
if constexpr(std::is_same<T, FunctionDefinition>::value) );
_overriding.annotation().baseFunctions.emplace(&_super);
if (_overriding.visibility() != _super.visibility()) if (_overriding.visibility() != _super.visibility())
{ {
@ -330,8 +373,23 @@ void ContractLevelChecker::checkFunctionOverride(T const& _overriding, FunctionD
overrideError(_overriding, _super, "Overriding " + overridingName + " visibility differs."); overrideError(_overriding, _super, "Overriding " + overridingName + " visibility differs.");
} }
// This is only relevant for overriding functions by functions or state variables,
// it is skipped for modifiers.
if constexpr(std::is_same<FunctionDefinition, U>::value)
{
FunctionTypePointer functionType = FunctionType(_overriding).asCallableFunction(false);
FunctionTypePointer superType = FunctionType(_super).asCallableFunction(false);
solAssert(functionType->hasEqualParameterTypes(*superType), "Override doesn't have equal parameters!");
if (!functionType->hasEqualReturnTypes(*superType))
overrideError(_overriding, _super, "Overriding " + overridingName + " return types differ.");
// This is only relevant for a function overriding a function.
if constexpr(std::is_same<T, FunctionDefinition>::value) if constexpr(std::is_same<T, FunctionDefinition>::value)
{ {
_overriding.annotation().baseFunctions.emplace(&_super);
if (_overriding.stateMutability() != _super.stateMutability()) if (_overriding.stateMutability() != _super.stateMutability())
overrideError( overrideError(
_overriding, _overriding,
@ -351,8 +409,14 @@ void ContractLevelChecker::checkFunctionOverride(T const& _overriding, FunctionD
); );
} }
} }
}
void ContractLevelChecker::overrideListError(FunctionDefinition const& function, set<ContractDefinition const*, LessFunction> _secondary, string const& _message1, string const& _message2) void ContractLevelChecker::overrideListError(
CallableDeclaration const& _callable,
set<ContractDefinition const*, LessFunction> _secondary,
string const& _message1,
string const& _message2
)
{ {
// Using a set rather than a vector so the order is always the same // Using a set rather than a vector so the order is always the same
set<string> names; set<string> names;
@ -367,7 +431,7 @@ void ContractLevelChecker::overrideListError(FunctionDefinition const& function,
contractSingularPlural = "contracts "; contractSingularPlural = "contracts ";
m_errorReporter.typeError( m_errorReporter.typeError(
function.overrides() ? function.overrides()->location() : function.location(), _callable.overrides() ? _callable.overrides()->location() : _callable.location(),
ssl, ssl,
_message1 + _message1 +
contractSingularPlural + contractSingularPlural +
@ -665,6 +729,10 @@ void ContractLevelChecker::checkBaseABICompatibility(ContractDefinition const& _
} }
void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _contract) const void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _contract) const
{
std::function<bool(CallableDeclaration const*, CallableDeclaration const*)> compareById =
[](auto const* _a, auto const* _b) { return _a->id() < _b->id(); };
{ {
// Fetch inherited functions and sort them by signature. // Fetch inherited functions and sort them by signature.
// We get at least one function per signature and direct base contract, which is // We get at least one function per signature and direct base contract, which is
@ -676,13 +744,36 @@ void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _co
// Walk through the set of functions signature by signature. // Walk through the set of functions signature by signature.
for (auto it = nonOverriddenFunctions.cbegin(); it != nonOverriddenFunctions.cend();) for (auto it = nonOverriddenFunctions.cbegin(); it != nonOverriddenFunctions.cend();)
{ {
static constexpr auto compareById = [](auto const* a, auto const* b) { return a->id() < b->id(); }; std::set<CallableDeclaration const*, decltype(compareById)> baseFunctions(compareById);
std::set<FunctionDefinition const*, decltype(compareById)> baseFunctions(compareById);
for (auto nextSignature = nonOverriddenFunctions.upper_bound(*it); it != nextSignature; ++it) for (auto nextSignature = nonOverriddenFunctions.upper_bound(*it); it != nextSignature; ++it)
baseFunctions.insert(*it); baseFunctions.insert(*it);
if (baseFunctions.size() <= 1) checkAmbiguousOverridesInternal(std::move(baseFunctions), _contract.location());
continue; }
}
{
ModifierMultiSet modifiers = inheritedModifiers(_contract);
modifiers -= _contract.functionModifiers();
for (auto it = modifiers.cbegin(); it != modifiers.cend();)
{
std::set<CallableDeclaration const*, decltype(compareById)> baseModifiers(compareById);
for (auto next = modifiers.upper_bound(*it); it != next; ++it)
baseModifiers.insert(*it);
checkAmbiguousOverridesInternal(std::move(baseModifiers), _contract.location());
}
}
}
void ContractLevelChecker::checkAmbiguousOverridesInternal(set<
CallableDeclaration const*,
std::function<bool(CallableDeclaration const*, CallableDeclaration const*)>
> _baseCallables, SourceLocation const& _location) const
{
if (_baseCallables.size() <= 1)
return;
// Construct the override graph for this signature. // Construct the override graph for this signature.
// Reserve node 0 for the current contract and node // Reserve node 0 for the current contract and node
@ -690,13 +781,13 @@ void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _co
// connect at the end. // connect at the end.
struct OverrideGraph struct OverrideGraph
{ {
OverrideGraph(decltype(baseFunctions) const& _baseFunctions) OverrideGraph(decltype(_baseCallables) const& __baseCallables)
{ {
for (auto const* baseFunction: _baseFunctions) for (auto const* baseFunction: __baseCallables)
addEdge(0, visit(baseFunction)); addEdge(0, visit(baseFunction));
} }
std::map<FunctionDefinition const*, int> nodes; std::map<CallableDeclaration const*, int> nodes;
std::map<int, FunctionDefinition const*> nodeInv; std::map<int, CallableDeclaration const*> nodeInv;
std::map<int, std::set<int>> edges; std::map<int, std::set<int>> edges;
int numNodes = 2; int numNodes = 2;
void addEdge(int _a, int _b) void addEdge(int _a, int _b)
@ -707,7 +798,7 @@ void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _co
private: private:
/// Completes the graph starting from @a _function and /// Completes the graph starting from @a _function and
/// @returns the node ID. /// @returns the node ID.
int visit(FunctionDefinition const* _function) int visit(CallableDeclaration const* _function)
{ {
auto it = nodes.find(_function); auto it = nodes.find(_function);
if (it != nodes.end()) if (it != nodes.end())
@ -723,7 +814,7 @@ void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _co
return currentNode; return currentNode;
} }
} overrideGraph(baseFunctions); } overrideGraph(_baseCallables);
// Detect cut vertices following https://en.wikipedia.org/wiki/Biconnected_component#Pseudocode // Detect cut vertices following https://en.wikipedia.org/wiki/Biconnected_component#Pseudocode
// Can ignore the root node, since it is never a cut vertex in our case. // Can ignore the root node, since it is never a cut vertex in our case.
@ -733,7 +824,7 @@ void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _co
{ {
run(); run();
} }
std::set<FunctionDefinition const*> const& cutVertices() const { return m_cutVertices; } std::set<CallableDeclaration const*> const& cutVertices() const { return m_cutVertices; }
private: private:
OverrideGraph const& m_graph; OverrideGraph const& m_graph;
@ -742,7 +833,7 @@ void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _co
std::vector<int> m_depths = std::vector<int>(m_graph.numNodes, -1); std::vector<int> m_depths = std::vector<int>(m_graph.numNodes, -1);
std::vector<int> m_low = std::vector<int>(m_graph.numNodes, -1); std::vector<int> m_low = std::vector<int>(m_graph.numNodes, -1);
std::vector<int> m_parent = std::vector<int>(m_graph.numNodes, -1); std::vector<int> m_parent = std::vector<int>(m_graph.numNodes, -1);
std::set<FunctionDefinition const*> m_cutVertices{}; std::set<CallableDeclaration const*> m_cutVertices{};
void run(int _u = 0, int _depth = 0) void run(int _u = 0, int _depth = 0)
{ {
@ -765,37 +856,55 @@ void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _co
// Remove all base functions overridden by cut vertices (they don't need to be overridden). // Remove all base functions overridden by cut vertices (they don't need to be overridden).
for (auto const* function: cutVertexFinder.cutVertices()) for (auto const* function: cutVertexFinder.cutVertices())
{ {
std::set<FunctionDefinition const*> toTraverse = function->annotation().baseFunctions; std::set<CallableDeclaration const*> toTraverse = function->annotation().baseFunctions;
while (!toTraverse.empty()) while (!toTraverse.empty())
{ {
auto const *base = *toTraverse.begin(); auto const *base = *toTraverse.begin();
toTraverse.erase(toTraverse.begin()); toTraverse.erase(toTraverse.begin());
baseFunctions.erase(base); _baseCallables.erase(base);
for (auto const* f: base->annotation().baseFunctions) for (CallableDeclaration const* f: base->annotation().baseFunctions)
toTraverse.insert(f); toTraverse.insert(f);
} }
// Remove unimplemented base functions at the cut vertices themselves as well. // Remove unimplemented base functions at the cut vertices itself as well.
if (!function->isImplemented()) if (auto opt = dynamic_cast<ImplementationOptional const*>(function))
baseFunctions.erase(function); if (!opt->isImplemented())
_baseCallables.erase(function);
} }
// If more than one function is left, they have to be overridden. // If more than one function is left, they have to be overridden.
if (baseFunctions.size() <= 1) if (_baseCallables.size() <= 1)
continue; return;
SecondarySourceLocation ssl; SecondarySourceLocation ssl;
for (auto const* baseFunction: baseFunctions) for (auto const* baseFunction: _baseCallables)
ssl.append("Definition here: ", baseFunction->location()); {
string contractName = dynamic_cast<ContractDefinition const&>(*baseFunction->scope()).name();
ssl.append("Definition in \"" + contractName + "\": ", baseFunction->location());
}
string callableName;
string distinguishigProperty;
if (dynamic_cast<FunctionDefinition const*>(*_baseCallables.begin()))
{
callableName = "function";
distinguishigProperty = "name and parameter types";
}
else if (dynamic_cast<ModifierDefinition const*>(*_baseCallables.begin()))
{
callableName = "modifier";
distinguishigProperty = "name";
}
else
solAssert(false, "Invalid type for ambiguous override.");
m_errorReporter.typeError( m_errorReporter.typeError(
_contract.location(), _location,
ssl, ssl,
"Derived contract must override function \"" + "Derived contract must override " + callableName + " \"" +
(*baseFunctions.begin())->name() + (*_baseCallables.begin())->name() +
"\". Function with the same name and parameter types defined in two or more base classes." "\". Two or more base classes define " + callableName + " with same " + distinguishigProperty + "."
); );
} }
}
set<ContractDefinition const*, ContractLevelChecker::LessFunction> ContractLevelChecker::resolveOverrideList(OverrideSpecifier const& _overrides) const set<ContractDefinition const*, ContractLevelChecker::LessFunction> ContractLevelChecker::resolveOverrideList(OverrideSpecifier const& _overrides) const
{ {
@ -815,49 +924,23 @@ set<ContractDefinition const*, ContractLevelChecker::LessFunction> ContractLevel
return resolved; return resolved;
} }
template <class T>
void ContractLevelChecker::checkModifierOverrides(FunctionMultiSet const& _funcSet, ModifierMultiSet const& _modSet, std::vector<ModifierDefinition const*> _modifiers) void ContractLevelChecker::checkOverrideList(
{ std::multiset<T const*, LessFunction> const& _inheritedCallables,
for (ModifierDefinition const* modifier: _modifiers) T const& _callable
{ )
if (contains_if(_funcSet, MatchByName{modifier->name()}))
m_errorReporter.typeError(
modifier->location(),
"Override changes function to modifier."
);
auto [begin,end] = _modSet.equal_range(modifier);
// Skip if no modifiers found in bases
if (begin == end)
continue;
if (!modifier->overrides())
overrideError(*modifier, **begin, "Overriding modifier is missing 'override' specifier.");
for (; begin != end; begin++)
if (ModifierType(**begin) != ModifierType(*modifier))
m_errorReporter.typeError(
modifier->location(),
"Override changes modifier signature."
);
}
}
void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, FunctionDefinition const& _function)
{ {
set<ContractDefinition const*, LessFunction> specifiedContracts = set<ContractDefinition const*, LessFunction> specifiedContracts =
_function.overrides() ? _callable.overrides() ?
resolveOverrideList(*_function.overrides()) : resolveOverrideList(*_callable.overrides()) :
decltype(specifiedContracts){}; decltype(specifiedContracts){};
// Check for duplicates in override list // Check for duplicates in override list
if (_function.overrides() && specifiedContracts.size() != _function.overrides()->overrides().size()) if (_callable.overrides() && specifiedContracts.size() != _callable.overrides()->overrides().size())
{ {
// Sort by contract id to find duplicate for error reporting // Sort by contract id to find duplicate for error reporting
vector<ASTPointer<UserDefinedTypeName>> list = vector<ASTPointer<UserDefinedTypeName>> list =
sortByContract(_function.overrides()->overrides()); sortByContract(_callable.overrides()->overrides());
// Find duplicates and output error // Find duplicates and output error
for (size_t i = 1; i < list.size(); i++) for (size_t i = 1; i < list.size(); i++)
@ -877,7 +960,7 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F
"Duplicate contract \"" + "Duplicate contract \"" +
joinHumanReadable(list[i]->namePath(), ".") + joinHumanReadable(list[i]->namePath(), ".") +
"\" found in override list of \"" + "\" found in override list of \"" +
_function.name() + _callable.name() +
"\"." "\"."
); );
} }
@ -887,12 +970,12 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F
decltype(specifiedContracts) expectedContracts; decltype(specifiedContracts) expectedContracts;
// Build list of expected contracts // Build list of expected contracts
for (auto [begin, end] = _funcSet.equal_range(&_function); begin != end; begin++) for (auto [begin, end] = _inheritedCallables.equal_range(&_callable); begin != end; begin++)
{ {
// Validate the override // Validate the override
checkFunctionOverride(_function, **begin); checkOverride(_callable, **begin);
expectedContracts.insert((*begin)->annotation().contract); expectedContracts.insert(&dynamic_cast<ContractDefinition const&>(*(*begin)->scope()));
} }
decltype(specifiedContracts) missingContracts; decltype(specifiedContracts) missingContracts;
@ -906,7 +989,7 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F
if (!missingContracts.empty()) if (!missingContracts.empty())
overrideListError( overrideListError(
_function, _callable,
missingContracts, missingContracts,
"Function needs to specify overridden ", "Function needs to specify overridden ",
"" ""
@ -914,7 +997,7 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F
if (!surplusContracts.empty()) if (!surplusContracts.empty())
overrideListError( overrideListError(
_function, _callable,
surplusContracts, surplusContracts,
"Invalid ", "Invalid ",
"specified in override list: " "specified in override list: "

View File

@ -22,7 +22,9 @@
#pragma once #pragma once
#include <libsolidity/ast/ASTForward.h> #include <libsolidity/ast/ASTForward.h>
#include <liblangutil/SourceLocation.h>
#include <map> #include <map>
#include <functional>
#include <set> #include <set>
namespace langutil namespace langutil
@ -53,6 +55,12 @@ public:
bool check(ContractDefinition const& _contract); bool check(ContractDefinition const& _contract);
private: private:
/**
* Comparator that compares
* - functions such that equality means that the functions override each other
* - modifiers by name
* - contracts by AST id.
*/
struct LessFunction struct LessFunction
{ {
bool operator()(ModifierDefinition const* _a, ModifierDefinition const* _b) const; bool operator()(ModifierDefinition const* _a, ModifierDefinition const* _b) const;
@ -70,13 +78,24 @@ private:
template <class T> template <class T>
void findDuplicateDefinitions(std::map<std::string, std::vector<T>> const& _definitions, std::string _message); void findDuplicateDefinitions(std::map<std::string, std::vector<T>> const& _definitions, std::string _message);
void checkIllegalOverrides(ContractDefinition const& _contract); void checkIllegalOverrides(ContractDefinition const& _contract);
/// Performs various checks related to @a _function overriding @a _super like /// Performs various checks related to @a _overriding overriding @a _super like
/// different return type, invalid visibility change, etc. /// different return type, invalid visibility change, etc.
/// Works on functions, modifiers and public state variables.
/// Also stores @a _super as a base function of @a _function in its AST annotations. /// Also stores @a _super as a base function of @a _function in its AST annotations.
template<class T> template<class T, class U>
void checkFunctionOverride(T const& _overriding, FunctionDefinition const& _super); void checkOverride(T const& _overriding, U const& _super);
void overrideListError(FunctionDefinition const& function, std::set<ContractDefinition const*, LessFunction> _secondary, std::string const& _message1, std::string const& _message2); void overrideListError(
void overrideError(Declaration const& _overriding, Declaration const& _super, std::string _message, std::string _secondaryMsg = "Overridden function is here:"); CallableDeclaration const& _callable,
std::set<ContractDefinition const*, LessFunction> _secondary,
std::string const& _message1,
std::string const& _message2
);
void overrideError(
Declaration const& _overriding,
Declaration const& _super,
std::string _message,
std::string _secondaryMsg = "Overridden function is here:"
);
void checkAbstractFunctions(ContractDefinition const& _contract); void checkAbstractFunctions(ContractDefinition const& _contract);
/// Checks that the base constructor arguments are properly provided. /// Checks that the base constructor arguments are properly provided.
/// Fills the list of unimplemented functions in _contract's annotations. /// Fills the list of unimplemented functions in _contract's annotations.
@ -98,11 +117,18 @@ private:
/// Checks for functions in different base contracts which conflict with each /// Checks for functions in different base contracts which conflict with each
/// other and thus need to be overridden explicitly. /// other and thus need to be overridden explicitly.
void checkAmbiguousOverrides(ContractDefinition const& _contract) const; void checkAmbiguousOverrides(ContractDefinition const& _contract) const;
void checkAmbiguousOverridesInternal(std::set<
CallableDeclaration const*,
std::function<bool(CallableDeclaration const*, CallableDeclaration const*)>
> _baseCallables, langutil::SourceLocation const& _location) const;
/// Resolves an override list of UserDefinedTypeNames to a list of contracts. /// Resolves an override list of UserDefinedTypeNames to a list of contracts.
std::set<ContractDefinition const*, LessFunction> resolveOverrideList(OverrideSpecifier const& _overrides) const; std::set<ContractDefinition const*, LessFunction> resolveOverrideList(OverrideSpecifier const& _overrides) const;
void checkModifierOverrides(FunctionMultiSet const& _funcSet, ModifierMultiSet const& _modSet, std::vector<ModifierDefinition const*> _modifiers); template <class T>
void checkOverrideList(FunctionMultiSet const& _funcSet, FunctionDefinition const& _function); void checkOverrideList(
std::multiset<T const*, LessFunction> const& _funcSet,
T const& _function
);
/// Returns all functions of bases that have not yet been overwritten. /// Returns all functions of bases that have not yet been overwritten.
/// May contain the same function multiple times when used with shared bases. /// May contain the same function multiple times when used with shared bases.

View File

@ -312,6 +312,16 @@ ContractDefinition::ContractKind FunctionDefinition::inContractKind() const
return contractDef->contractKind(); return contractDef->contractKind();
} }
CallableDeclarationAnnotation& CallableDeclaration::annotation() const
{
solAssert(
m_annotation,
"CallableDeclarationAnnotation is an abstract base, need to call annotation on the concrete class first."
);
return dynamic_cast<CallableDeclarationAnnotation&>(*m_annotation);
}
FunctionTypePointer FunctionDefinition::functionType(bool _internal) const FunctionTypePointer FunctionDefinition::functionType(bool _internal) const
{ {
if (_internal) if (_internal)

View File

@ -624,6 +624,8 @@ public:
bool markedVirtual() const { return m_isVirtual; } bool markedVirtual() const { return m_isVirtual; }
virtual bool virtualSemantics() const { return markedVirtual(); } virtual bool virtualSemantics() const { return markedVirtual(); }
CallableDeclarationAnnotation& annotation() const override;
protected: protected:
ASTPointer<ParameterList> m_parameters; ASTPointer<ParameterList> m_parameters;
ASTPointer<OverrideSpecifier> m_overrides; ASTPointer<OverrideSpecifier> m_overrides;

View File

@ -104,19 +104,23 @@ struct ContractDefinitionAnnotation: TypeDeclarationAnnotation, DocumentedAnnota
std::map<FunctionDefinition const*, ASTNode const*> baseConstructorArguments; std::map<FunctionDefinition const*, ASTNode const*> baseConstructorArguments;
}; };
struct FunctionDefinitionAnnotation: ASTAnnotation, DocumentedAnnotation struct CallableDeclarationAnnotation: ASTAnnotation
{
/// The set of functions/modifiers/events this callable overrides.
std::set<CallableDeclaration const*> baseFunctions;
};
struct FunctionDefinitionAnnotation: CallableDeclarationAnnotation, DocumentedAnnotation
{ {
/// The set of functions this function overrides.
std::set<FunctionDefinition const*> baseFunctions;
/// Pointer to the contract this function is defined in /// Pointer to the contract this function is defined in
ContractDefinition const* contract = nullptr; ContractDefinition const* contract = nullptr;
}; };
struct EventDefinitionAnnotation: ASTAnnotation, DocumentedAnnotation struct EventDefinitionAnnotation: CallableDeclarationAnnotation, DocumentedAnnotation
{ {
}; };
struct ModifierDefinitionAnnotation: ASTAnnotation, DocumentedAnnotation struct ModifierDefinitionAnnotation: CallableDeclarationAnnotation, DocumentedAnnotation
{ {
}; };

View File

@ -392,7 +392,7 @@ bool ASTJsonConverter::visit(VariableDeclaration const& _node)
bool ASTJsonConverter::visit(ModifierDefinition const& _node) bool ASTJsonConverter::visit(ModifierDefinition const& _node)
{ {
setJsonNode(_node, "ModifierDefinition", { std::vector<pair<string, Json::Value>> attributes = {
make_pair("name", _node.name()), make_pair("name", _node.name()),
make_pair("documentation", _node.documentation() ? Json::Value(*_node.documentation()) : Json::nullValue), make_pair("documentation", _node.documentation() ? Json::Value(*_node.documentation()) : Json::nullValue),
make_pair("visibility", Declaration::visibilityToString(_node.visibility())), make_pair("visibility", Declaration::visibilityToString(_node.visibility())),
@ -400,7 +400,10 @@ bool ASTJsonConverter::visit(ModifierDefinition const& _node)
make_pair("virtual", _node.markedVirtual()), make_pair("virtual", _node.markedVirtual()),
make_pair("overrides", _node.overrides() ? toJson(*_node.overrides()) : Json::nullValue), make_pair("overrides", _node.overrides() ? toJson(*_node.overrides()) : Json::nullValue),
make_pair("body", toJson(_node.body())) make_pair("body", toJson(_node.body()))
}); };
if (!_node.annotation().baseFunctions.empty())
attributes.emplace_back(make_pair("baseModifiers", getContainerIds(_node.annotation().baseFunctions, true)));
setJsonNode(_node, "ModifierDefinition", std::move(attributes));
return false; return false;
} }

View File

@ -2313,7 +2313,7 @@ BOOST_AUTO_TEST_CASE(function_modifier_overriding)
char const* sourceCode = R"( char const* sourceCode = R"(
contract A { contract A {
function f() mod public returns (bool r) { return true; } function f() mod public returns (bool r) { return true; }
modifier mod { _; } modifier mod virtual { _; }
} }
contract C is A { contract C is A {
modifier mod override { if (false) _; } modifier mod override { if (false) _; }
@ -2352,7 +2352,7 @@ BOOST_AUTO_TEST_CASE(function_modifier_for_constructor)
contract A { contract A {
uint data; uint data;
constructor() mod1 public { data |= 2; } constructor() mod1 public { data |= 2; }
modifier mod1 { data |= 1; _; } modifier mod1 virtual { data |= 1; _; }
function getData() public returns (uint r) { return data; } function getData() public returns (uint r) { return data; }
} }
contract C is A { contract C is A {

View File

@ -16,5 +16,5 @@ abstract contract B is I {
contract C is A, B { contract C is A, B {
} }
// ---- // ----
// TypeError: (342-364): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes. // TypeError: (342-364): Derived contract must override function "f". Two or more base classes define function with same name and parameter types.
// TypeError: (342-364): Derived contract must override function "g". Function with the same name and parameter types defined in two or more base classes. // TypeError: (342-364): Derived contract must override function "g". Two or more base classes define function with same name and parameter types.

View File

@ -14,4 +14,4 @@ abstract contract B is I {
contract C is A, B { contract C is A, B {
} }
// ---- // ----
// TypeError: (254-276): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes. // TypeError: (254-276): Derived contract must override function "f". Two or more base classes define function with same name and parameter types.

View File

@ -13,5 +13,5 @@ abstract contract B is I {
contract C is A, B { contract C is A, B {
} }
// ---- // ----
// TypeError: (292-314): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes. // TypeError: (292-314): Derived contract must override function "f". Two or more base classes define function with same name and parameter types.
// TypeError: (292-314): Derived contract must override function "g". Function with the same name and parameter types defined in two or more base classes. // TypeError: (292-314): Derived contract must override function "g". Two or more base classes define function with same name and parameter types.

View File

@ -6,4 +6,4 @@ contract B {
} }
contract C is A, B {} contract C is A, B {}
// ---- // ----
// TypeError: (126-147): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes. // TypeError: (126-147): Derived contract must override function "f". Two or more base classes define function with same name and parameter types.

View File

@ -0,0 +1,9 @@
contract A {
modifier f() virtual { _; }
}
contract B {
modifier f() virtual { _; }
}
contract C is A, B {
modifier f() override(A,B) { _; }
}

View File

@ -0,0 +1,10 @@
contract A {
modifier f() virtual { _; }
}
contract B {
modifier f() virtual { _; }
}
contract C is A, B {
}
// ----
// TypeError: (94-116): Derived contract must override modifier "f". Two or more base classes define modifier with same name.

View File

@ -0,0 +1,10 @@
contract A {
modifier f(uint a) virtual { _; }
}
contract B {
modifier f() virtual { _; }
}
contract C is A, B {
}
// ----
// TypeError: (100-122): Derived contract must override modifier "f". Two or more base classes define modifier with same name.

View File

@ -0,0 +1,11 @@
contract A {
modifier f(uint a) virtual { _; }
}
contract B {
modifier f() virtual { _; }
}
contract C is A, B {
modifier f() virtual override(A, B) { _; }
}
// ----
// TypeError: (125-167): Override changes modifier signature.

View File

@ -9,5 +9,5 @@ abstract contract B {
contract C is A, B { contract C is A, B {
} }
// ---- // ----
// TypeError: (176-198): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes. // TypeError: (176-198): Derived contract must override function "f". Two or more base classes define function with same name and parameter types.
// TypeError: (176-198): Derived contract must override function "g". Function with the same name and parameter types defined in two or more base classes. // TypeError: (176-198): Derived contract must override function "g". Two or more base classes define function with same name and parameter types.

View File

@ -0,0 +1,19 @@
contract I {
modifier f() virtual { _; }
}
contract J {
modifier f() virtual { _; }
}
contract IJ is I, J {
modifier f() virtual override (I, J) { _; }
}
contract A is IJ
{
modifier f() override { _; }
}
contract B is IJ
{
}
contract C is A, B {}
// ----
// TypeError: (229-250): Derived contract must override modifier "f". Two or more base classes define modifier with same name.

View File

@ -9,4 +9,4 @@ abstract contract X is A, B {
function test() internal override returns (uint256) {} function test() internal override returns (uint256) {}
} }
// ---- // ----
// TypeError: (205-292): Derived contract must override function "foo". Function with the same name and parameter types defined in two or more base classes. // TypeError: (205-292): Derived contract must override function "foo". Two or more base classes define function with same name and parameter types.

View File

@ -0,0 +1,7 @@
abstract contract A {
}
abstract contract X is A {
modifier f() override { _; }
}
// ----
// TypeError: (65-73): Modifier has override specified but does not override anything.

View File

@ -10,4 +10,4 @@ contract C is A, B
{ {
} }
// ---- // ----
// TypeError: (94-116): Derived contract must override function "foo". Function with the same name and parameter types defined in two or more base classes. // TypeError: (94-116): Derived contract must override function "foo". Two or more base classes define function with same name and parameter types.

View File

@ -9,4 +9,4 @@ abstract contract X is A, B {
function test() internal override virtual returns (uint256); function test() internal override virtual returns (uint256);
} }
// ---- // ----
// TypeError: (203-296): Derived contract must override function "foo". Function with the same name and parameter types defined in two or more base classes. // TypeError: (203-296): Derived contract must override function "foo". Two or more base classes define function with same name and parameter types.

View File

@ -8,4 +8,4 @@ contract X is A, B {
uint public override foo; uint public override foo;
} }
// ---- // ----
// TypeError: (162-211): Derived contract must override function "foo". Function with the same name and parameter types defined in two or more base classes. // TypeError: (162-211): Derived contract must override function "foo". Two or more base classes define function with same name and parameter types.

View File

@ -11,4 +11,4 @@ contract X is B, C {
uint public override foo; uint public override foo;
} }
// ---- // ----
// TypeError: (271-320): Derived contract must override function "foo". Function with the same name and parameter types defined in two or more base classes. // TypeError: (271-320): Derived contract must override function "foo". Two or more base classes define function with same name and parameter types.

View File

@ -12,4 +12,4 @@ contract X is B, C {
} }
// ---- // ----
// DeclarationError: (245-269): Identifier already declared. // DeclarationError: (245-269): Identifier already declared.
// TypeError: (223-272): Derived contract must override function "foo". Function with the same name and parameter types defined in two or more base classes. // TypeError: (223-272): Derived contract must override function "foo". Two or more base classes define function with same name and parameter types.

View File

@ -8,4 +8,4 @@ contract X is A, B {
uint public override(A, B) foo; uint public override(A, B) foo;
} }
// ---- // ----
// TypeError: (162-217): Derived contract must override function "foo". Function with the same name and parameter types defined in two or more base classes. // TypeError: (162-217): Derived contract must override function "foo". Two or more base classes define function with same name and parameter types.

View File

@ -3,4 +3,4 @@ contract B is A { function f() public pure virtual override {} }
contract C is A, B { } contract C is A, B { }
contract D is A, B { function f() public pure override(A, B) {} } contract D is A, B { function f() public pure override(A, B) {} }
// ---- // ----
// TypeError: (116-138): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes. // TypeError: (116-138): Derived contract must override function "f". Two or more base classes define function with same name and parameter types.

View File

@ -1,5 +1,6 @@
contract A { modifier mod(uint a) { _; } } contract A { modifier mod(uint a) { _; } }
contract B is A { modifier mod(uint8 a) { _; } } contract B is A { modifier mod(uint8 a) { _; } }
// ---- // ----
// TypeError: (61-89): Overriding modifier is missing 'override' specifier.
// TypeError: (61-89): Override changes modifier signature. // TypeError: (61-89): Override changes modifier signature.
// TypeError: (61-89): Overriding modifier is missing 'override' specifier.
// TypeError: (13-40): Trying to override non-virtual modifier. Did you forget to add "virtual"?

View File

@ -1,2 +1,2 @@
contract A { modifier mod(uint a) { _; } } contract A { modifier mod(uint a) virtual { _; } }
contract B is A { modifier mod(uint a) override { _; } } contract B is A { modifier mod(uint a) override { _; } }

View File

@ -0,0 +1,4 @@
contract A { modifier mod(uint a) { _; } }
contract B is A { modifier mod(uint a) override { _; } }
// ----
// TypeError: (13-40): Trying to override non-virtual modifier. Did you forget to add "virtual"?