Merge pull request #7912 from ethereum/overrideForModifiers

Override for modifiers
This commit is contained in:
chriseth 2019-12-09 19:37:02 +01:00 committed by GitHub
commit bffd999b20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 442 additions and 242 deletions

View File

@ -305,8 +305,9 @@ Modifier Overriding
=================== ===================
Function modifiers can override each other. This works in the same way as Function modifiers can override each other. This works in the same way as
function overriding (except that there is no overloading for modifiers). The `function overriding <function-overriding>`_ (except that there is no overloading for modifiers). The
``override`` keyword must be used in the overriding contract: ``virtual`` keyword must be used on the overridden modifier
and the ``override`` keyword must be used in the overriding modifier:
:: ::
@ -314,7 +315,7 @@ function overriding (except that there is no overloading for modifiers). The
contract Base contract Base
{ {
modifier foo() {_;} modifier foo() virtual {_;}
} }
contract Inherited is Base contract Inherited is Base
@ -332,12 +333,12 @@ explicitly:
contract Base1 contract Base1
{ {
modifier foo() {_;} modifier foo() virtual {_;}
} }
contract Base2 contract Base2
{ {
modifier foo() {_;} modifier foo() virtual {_;}
} }
contract Inherited is Base1, Base2 contract Inherited is Base1, Base2

View File

@ -238,10 +238,8 @@ void ContractLevelChecker::findDuplicateDefinitions(map<string, vector<T>> const
void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _contract) void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _contract)
{ {
FunctionMultiSet const& funcSet = inheritedFunctions(_contract); FunctionMultiSet const& inheritedFuncs = inheritedFunctions(_contract);
ModifierMultiSet const& modSet = inheritedModifiers(_contract); ModifierMultiSet const& inheritedMods = inheritedModifiers(_contract);
checkModifierOverrides(funcSet, modSet, _contract.functionModifiers());
for (auto const* stateVar: _contract.stateVariables()) for (auto const* stateVar: _contract.stateVariables())
{ {
@ -250,9 +248,9 @@ void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _cont
bool found = false; bool found = false;
for ( for (
auto it = find_if(funcSet.begin(), funcSet.end(), MatchByName{stateVar->name()}); auto it = find_if(inheritedFuncs.begin(), inheritedFuncs.end(), MatchByName{stateVar->name()});
it != funcSet.end(); it != inheritedFuncs.end();
it = find_if(++it, funcSet.end(), MatchByName{stateVar->name()}) it = find_if(++it, inheritedFuncs.end(), MatchByName{stateVar->name()})
) )
{ {
if (!hasEqualNameAndParameters(*stateVar, **it)) if (!hasEqualNameAndParameters(*stateVar, **it))
@ -261,7 +259,7 @@ void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _cont
if ((*it)->visibility() != Declaration::Visibility::External) if ((*it)->visibility() != Declaration::Visibility::External)
overrideError(*stateVar, **it, "Public state variables can only override functions with external visibility."); overrideError(*stateVar, **it, "Public state variables can only override functions with external visibility.");
else else
checkFunctionOverride(*stateVar, **it); checkOverride(*stateVar, **it);
found = true; found = true;
} }
@ -273,51 +271,96 @@ void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _cont
); );
} }
for (ModifierDefinition const* modifier: _contract.functionModifiers())
{
if (contains_if(inheritedFuncs, MatchByName{modifier->name()}))
m_errorReporter.typeError(
modifier->location(),
"Override changes function to modifier."
);
auto [begin, end] = inheritedMods.equal_range(modifier);
if (begin == end && modifier->overrides())
m_errorReporter.typeError(
modifier->overrides()->location(),
"Modifier has override specified but does not override anything."
);
for (; begin != end; begin++)
if (ModifierType(**begin) != ModifierType(*modifier))
m_errorReporter.typeError(
modifier->location(),
"Override changes modifier signature."
);
checkOverrideList(inheritedMods, *modifier);
}
for (FunctionDefinition const* function: _contract.definedFunctions()) for (FunctionDefinition const* function: _contract.definedFunctions())
{ {
if (function->isConstructor()) if (function->isConstructor())
continue; continue;
if (contains_if(modSet, MatchByName{function->name()})) if (contains_if(inheritedMods, MatchByName{function->name()}))
m_errorReporter.typeError(function->location(), "Override changes modifier to function."); m_errorReporter.typeError(function->location(), "Override changes modifier to function.");
// No inheriting functions found // No inheriting functions found
if (!funcSet.count(function) && function->overrides()) if (!inheritedFuncs.count(function) && function->overrides())
m_errorReporter.typeError( m_errorReporter.typeError(
function->overrides()->location(), function->overrides()->location(),
"Function has override specified but does not override anything." "Function has override specified but does not override anything."
); );
checkOverrideList(funcSet, *function); checkOverrideList(inheritedFuncs, *function);
} }
} }
template<class T> template<class T, class U>
void ContractLevelChecker::checkFunctionOverride(T const& _overriding, FunctionDefinition const& _super) void ContractLevelChecker::checkOverride(T const& _overriding, U const& _super)
{ {
string overridingName; static_assert(
std::is_same<VariableDeclaration, T>::value ||
std::is_same<FunctionDefinition, T>::value ||
std::is_same<ModifierDefinition, T>::value,
"Invalid call to checkOverride."
);
static_assert(
std::is_same<FunctionDefinition, U>::value ||
std::is_same<ModifierDefinition, U>::value,
"Invalid call to checkOverride."
);
static_assert(
!std::is_same<ModifierDefinition, U>::value ||
std::is_same<ModifierDefinition, T>::value,
"Invalid call to checkOverride."
);
string overridingName;
if constexpr(std::is_same<FunctionDefinition, T>::value) if constexpr(std::is_same<FunctionDefinition, T>::value)
overridingName = "function"; overridingName = "function";
else if constexpr(std::is_same<ModifierDefinition, T>::value)
overridingName = "modifier";
else else
overridingName = "public state variable"; overridingName = "public state variable";
FunctionTypePointer functionType = FunctionType(_overriding).asCallableFunction(false); string superName;
FunctionTypePointer superType = FunctionType(_super).asCallableFunction(false); if constexpr(std::is_same<FunctionDefinition, U>::value)
superName = "function";
solAssert(functionType->hasEqualParameterTypes(*superType), "Override doesn't have equal parameters!"); else
superName = "modifier";
if (!_overriding.overrides()) if (!_overriding.overrides())
overrideError(_overriding, _super, "Overriding " + overridingName + " is missing 'override' specifier."); overrideError(_overriding, _super, "Overriding " + overridingName + " is missing 'override' specifier.");
if (!_super.virtualSemantics()) if (!_super.virtualSemantics())
overrideError( _super, _overriding, "Trying to override non-virtual function. Did you forget to add \"virtual\"?", "Overriding " + overridingName + " is here:"); overrideError(
_super,
if (!functionType->hasEqualReturnTypes(*superType)) _overriding,
overrideError(_overriding, _super, "Overriding " + overridingName + " return types differ."); "Trying to override non-virtual " + superName + ". Did you forget to add \"virtual\"?",
"Overriding " + overridingName + " is here:"
if constexpr(std::is_same<T, FunctionDefinition>::value) );
_overriding.annotation().baseFunctions.emplace(&_super);
if (_overriding.visibility() != _super.visibility()) if (_overriding.visibility() != _super.visibility())
{ {
@ -330,29 +373,50 @@ void ContractLevelChecker::checkFunctionOverride(T const& _overriding, FunctionD
overrideError(_overriding, _super, "Overriding " + overridingName + " visibility differs."); overrideError(_overriding, _super, "Overriding " + overridingName + " visibility differs.");
} }
if constexpr(std::is_same<T, FunctionDefinition>::value) // This is only relevant for overriding functions by functions or state variables,
// it is skipped for modifiers.
if constexpr(std::is_same<FunctionDefinition, U>::value)
{ {
if (_overriding.stateMutability() != _super.stateMutability()) FunctionTypePointer functionType = FunctionType(_overriding).asCallableFunction(false);
overrideError( FunctionTypePointer superType = FunctionType(_super).asCallableFunction(false);
_overriding,
_super,
"Overriding function changes state mutability from \"" +
stateMutabilityToString(_super.stateMutability()) +
"\" to \"" +
stateMutabilityToString(_overriding.stateMutability()) +
"\"."
);
if (!_overriding.isImplemented() && _super.isImplemented()) solAssert(functionType->hasEqualParameterTypes(*superType), "Override doesn't have equal parameters!");
overrideError(
_overriding, if (!functionType->hasEqualReturnTypes(*superType))
_super, overrideError(_overriding, _super, "Overriding " + overridingName + " return types differ.");
"Overriding an implemented function with an unimplemented function is not allowed."
); // This is only relevant for a function overriding a function.
if constexpr(std::is_same<T, FunctionDefinition>::value)
{
_overriding.annotation().baseFunctions.emplace(&_super);
if (_overriding.stateMutability() != _super.stateMutability())
overrideError(
_overriding,
_super,
"Overriding function changes state mutability from \"" +
stateMutabilityToString(_super.stateMutability()) +
"\" to \"" +
stateMutabilityToString(_overriding.stateMutability()) +
"\"."
);
if (!_overriding.isImplemented() && _super.isImplemented())
overrideError(
_overriding,
_super,
"Overriding an implemented function with an unimplemented function is not allowed."
);
}
} }
} }
void ContractLevelChecker::overrideListError(FunctionDefinition const& function, set<ContractDefinition const*, LessFunction> _secondary, string const& _message1, string const& _message2) void ContractLevelChecker::overrideListError(
CallableDeclaration const& _callable,
set<ContractDefinition const*, LessFunction> _secondary,
string const& _message1,
string const& _message2
)
{ {
// Using a set rather than a vector so the order is always the same // Using a set rather than a vector so the order is always the same
set<string> names; set<string> names;
@ -367,7 +431,7 @@ void ContractLevelChecker::overrideListError(FunctionDefinition const& function,
contractSingularPlural = "contracts "; contractSingularPlural = "contracts ";
m_errorReporter.typeError( m_errorReporter.typeError(
function.overrides() ? function.overrides()->location() : function.location(), _callable.overrides() ? _callable.overrides()->location() : _callable.location(),
ssl, ssl,
_message1 + _message1 +
contractSingularPlural + contractSingularPlural +
@ -666,137 +730,182 @@ void ContractLevelChecker::checkBaseABICompatibility(ContractDefinition const& _
void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _contract) const void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _contract) const
{ {
// Fetch inherited functions and sort them by signature. std::function<bool(CallableDeclaration const*, CallableDeclaration const*)> compareById =
// We get at least one function per signature and direct base contract, which is [](auto const* _a, auto const* _b) { return _a->id() < _b->id(); };
// enough because we re-construct the inheritance graph later.
FunctionMultiSet nonOverriddenFunctions = inheritedFunctions(_contract);
// Remove all functions that match the signature of a function in the current contract.
nonOverriddenFunctions -= _contract.definedFunctions();
// Walk through the set of functions signature by signature.
for (auto it = nonOverriddenFunctions.cbegin(); it != nonOverriddenFunctions.cend();)
{ {
static constexpr auto compareById = [](auto const* a, auto const* b) { return a->id() < b->id(); }; // Fetch inherited functions and sort them by signature.
std::set<FunctionDefinition const*, decltype(compareById)> baseFunctions(compareById); // We get at least one function per signature and direct base contract, which is
for (auto nextSignature = nonOverriddenFunctions.upper_bound(*it); it != nextSignature; ++it) // enough because we re-construct the inheritance graph later.
baseFunctions.insert(*it); FunctionMultiSet nonOverriddenFunctions = inheritedFunctions(_contract);
// Remove all functions that match the signature of a function in the current contract.
nonOverriddenFunctions -= _contract.definedFunctions();
if (baseFunctions.size() <= 1) // Walk through the set of functions signature by signature.
continue; for (auto it = nonOverriddenFunctions.cbegin(); it != nonOverriddenFunctions.cend();)
// Construct the override graph for this signature.
// Reserve node 0 for the current contract and node
// 1 for an artificial top node to which all override paths
// connect at the end.
struct OverrideGraph
{ {
OverrideGraph(decltype(baseFunctions) const& _baseFunctions) std::set<CallableDeclaration const*, decltype(compareById)> baseFunctions(compareById);
{ for (auto nextSignature = nonOverriddenFunctions.upper_bound(*it); it != nextSignature; ++it)
for (auto const* baseFunction: _baseFunctions) baseFunctions.insert(*it);
addEdge(0, visit(baseFunction));
}
std::map<FunctionDefinition const*, int> nodes;
std::map<int, FunctionDefinition const*> nodeInv;
std::map<int, std::set<int>> edges;
int numNodes = 2;
void addEdge(int _a, int _b)
{
edges[_a].insert(_b);
edges[_b].insert(_a);
}
private:
/// Completes the graph starting from @a _function and
/// @returns the node ID.
int visit(FunctionDefinition const* _function)
{
auto it = nodes.find(_function);
if (it != nodes.end())
return it->second;
int currentNode = numNodes++;
nodes[_function] = currentNode;
nodeInv[currentNode] = _function;
if (_function->overrides())
for (auto const* baseFunction: _function->annotation().baseFunctions)
addEdge(currentNode, visit(baseFunction));
else
addEdge(currentNode, 1);
return currentNode; checkAmbiguousOverridesInternal(std::move(baseFunctions), _contract.location());
} }
} overrideGraph(baseFunctions); }
// Detect cut vertices following https://en.wikipedia.org/wiki/Biconnected_component#Pseudocode {
// Can ignore the root node, since it is never a cut vertex in our case. ModifierMultiSet modifiers = inheritedModifiers(_contract);
struct CutVertexFinder modifiers -= _contract.functionModifiers();
for (auto it = modifiers.cbegin(); it != modifiers.cend();)
{ {
CutVertexFinder(OverrideGraph const& _graph): m_graph(_graph) std::set<CallableDeclaration const*, decltype(compareById)> baseModifiers(compareById);
{ for (auto next = modifiers.upper_bound(*it); it != next; ++it)
run(); baseModifiers.insert(*it);
}
std::set<FunctionDefinition const*> const& cutVertices() const { return m_cutVertices; }
private: checkAmbiguousOverridesInternal(std::move(baseModifiers), _contract.location());
OverrideGraph const& m_graph;
std::vector<bool> m_visited = std::vector<bool>(m_graph.numNodes, false);
std::vector<int> m_depths = std::vector<int>(m_graph.numNodes, -1);
std::vector<int> m_low = std::vector<int>(m_graph.numNodes, -1);
std::vector<int> m_parent = std::vector<int>(m_graph.numNodes, -1);
std::set<FunctionDefinition const*> m_cutVertices{};
void run(int _u = 0, int _depth = 0)
{
m_visited.at(_u) = true;
m_depths.at(_u) = m_low.at(_u) = _depth;
for (int v: m_graph.edges.at(_u))
if (!m_visited.at(v))
{
m_parent[v] = _u;
run(v, _depth + 1);
if (m_low[v] >= m_depths[_u] && m_parent[_u] != -1)
m_cutVertices.insert(m_graph.nodeInv.at(_u));
m_low[_u] = min(m_low[_u], m_low[v]);
}
else if (v != m_parent[_u])
m_low[_u] = min(m_low[_u], m_depths[v]);
}
} cutVertexFinder{overrideGraph};
// Remove all base functions overridden by cut vertices (they don't need to be overridden).
for (auto const* function: cutVertexFinder.cutVertices())
{
std::set<FunctionDefinition const*> toTraverse = function->annotation().baseFunctions;
while (!toTraverse.empty())
{
auto const* base = *toTraverse.begin();
toTraverse.erase(toTraverse.begin());
baseFunctions.erase(base);
for (auto const* f: base->annotation().baseFunctions)
toTraverse.insert(f);
}
// Remove unimplemented base functions at the cut vertices themselves as well.
if (!function->isImplemented())
baseFunctions.erase(function);
} }
// If more than one function is left, they have to be overridden.
if (baseFunctions.size() <= 1)
continue;
SecondarySourceLocation ssl;
for (auto const* baseFunction: baseFunctions)
ssl.append("Definition here: ", baseFunction->location());
m_errorReporter.typeError(
_contract.location(),
ssl,
"Derived contract must override function \"" +
(*baseFunctions.begin())->name() +
"\". Function with the same name and parameter types defined in two or more base classes."
);
} }
} }
void ContractLevelChecker::checkAmbiguousOverridesInternal(set<
CallableDeclaration const*,
std::function<bool(CallableDeclaration const*, CallableDeclaration const*)>
> _baseCallables, SourceLocation const& _location) const
{
if (_baseCallables.size() <= 1)
return;
// Construct the override graph for this signature.
// Reserve node 0 for the current contract and node
// 1 for an artificial top node to which all override paths
// connect at the end.
struct OverrideGraph
{
OverrideGraph(decltype(_baseCallables) const& __baseCallables)
{
for (auto const* baseFunction: __baseCallables)
addEdge(0, visit(baseFunction));
}
std::map<CallableDeclaration const*, int> nodes;
std::map<int, CallableDeclaration const*> nodeInv;
std::map<int, std::set<int>> edges;
int numNodes = 2;
void addEdge(int _a, int _b)
{
edges[_a].insert(_b);
edges[_b].insert(_a);
}
private:
/// Completes the graph starting from @a _function and
/// @returns the node ID.
int visit(CallableDeclaration const* _function)
{
auto it = nodes.find(_function);
if (it != nodes.end())
return it->second;
int currentNode = numNodes++;
nodes[_function] = currentNode;
nodeInv[currentNode] = _function;
if (_function->overrides())
for (auto const* baseFunction: _function->annotation().baseFunctions)
addEdge(currentNode, visit(baseFunction));
else
addEdge(currentNode, 1);
return currentNode;
}
} overrideGraph(_baseCallables);
// Detect cut vertices following https://en.wikipedia.org/wiki/Biconnected_component#Pseudocode
// Can ignore the root node, since it is never a cut vertex in our case.
struct CutVertexFinder
{
CutVertexFinder(OverrideGraph const& _graph): m_graph(_graph)
{
run();
}
std::set<CallableDeclaration const*> const& cutVertices() const { return m_cutVertices; }
private:
OverrideGraph const& m_graph;
std::vector<bool> m_visited = std::vector<bool>(m_graph.numNodes, false);
std::vector<int> m_depths = std::vector<int>(m_graph.numNodes, -1);
std::vector<int> m_low = std::vector<int>(m_graph.numNodes, -1);
std::vector<int> m_parent = std::vector<int>(m_graph.numNodes, -1);
std::set<CallableDeclaration const*> m_cutVertices{};
void run(int _u = 0, int _depth = 0)
{
m_visited.at(_u) = true;
m_depths.at(_u) = m_low.at(_u) = _depth;
for (int v: m_graph.edges.at(_u))
if (!m_visited.at(v))
{
m_parent[v] = _u;
run(v, _depth + 1);
if (m_low[v] >= m_depths[_u] && m_parent[_u] != -1)
m_cutVertices.insert(m_graph.nodeInv.at(_u));
m_low[_u] = min(m_low[_u], m_low[v]);
}
else if (v != m_parent[_u])
m_low[_u] = min(m_low[_u], m_depths[v]);
}
} cutVertexFinder{overrideGraph};
// Remove all base functions overridden by cut vertices (they don't need to be overridden).
for (auto const* function: cutVertexFinder.cutVertices())
{
std::set<CallableDeclaration const*> toTraverse = function->annotation().baseFunctions;
while (!toTraverse.empty())
{
auto const *base = *toTraverse.begin();
toTraverse.erase(toTraverse.begin());
_baseCallables.erase(base);
for (CallableDeclaration const* f: base->annotation().baseFunctions)
toTraverse.insert(f);
}
// Remove unimplemented base functions at the cut vertices itself as well.
if (auto opt = dynamic_cast<ImplementationOptional const*>(function))
if (!opt->isImplemented())
_baseCallables.erase(function);
}
// If more than one function is left, they have to be overridden.
if (_baseCallables.size() <= 1)
return;
SecondarySourceLocation ssl;
for (auto const* baseFunction: _baseCallables)
{
string contractName = dynamic_cast<ContractDefinition const&>(*baseFunction->scope()).name();
ssl.append("Definition in \"" + contractName + "\": ", baseFunction->location());
}
string callableName;
string distinguishigProperty;
if (dynamic_cast<FunctionDefinition const*>(*_baseCallables.begin()))
{
callableName = "function";
distinguishigProperty = "name and parameter types";
}
else if (dynamic_cast<ModifierDefinition const*>(*_baseCallables.begin()))
{
callableName = "modifier";
distinguishigProperty = "name";
}
else
solAssert(false, "Invalid type for ambiguous override.");
m_errorReporter.typeError(
_location,
ssl,
"Derived contract must override " + callableName + " \"" +
(*_baseCallables.begin())->name() +
"\". Two or more base classes define " + callableName + " with same " + distinguishigProperty + "."
);
}
set<ContractDefinition const*, ContractLevelChecker::LessFunction> ContractLevelChecker::resolveOverrideList(OverrideSpecifier const& _overrides) const set<ContractDefinition const*, ContractLevelChecker::LessFunction> ContractLevelChecker::resolveOverrideList(OverrideSpecifier const& _overrides) const
{ {
set<ContractDefinition const*, LessFunction> resolved; set<ContractDefinition const*, LessFunction> resolved;
@ -815,49 +924,23 @@ set<ContractDefinition const*, ContractLevelChecker::LessFunction> ContractLevel
return resolved; return resolved;
} }
template <class T>
void ContractLevelChecker::checkModifierOverrides(FunctionMultiSet const& _funcSet, ModifierMultiSet const& _modSet, std::vector<ModifierDefinition const*> _modifiers) void ContractLevelChecker::checkOverrideList(
{ std::multiset<T const*, LessFunction> const& _inheritedCallables,
for (ModifierDefinition const* modifier: _modifiers) T const& _callable
{ )
if (contains_if(_funcSet, MatchByName{modifier->name()}))
m_errorReporter.typeError(
modifier->location(),
"Override changes function to modifier."
);
auto [begin,end] = _modSet.equal_range(modifier);
// Skip if no modifiers found in bases
if (begin == end)
continue;
if (!modifier->overrides())
overrideError(*modifier, **begin, "Overriding modifier is missing 'override' specifier.");
for (; begin != end; begin++)
if (ModifierType(**begin) != ModifierType(*modifier))
m_errorReporter.typeError(
modifier->location(),
"Override changes modifier signature."
);
}
}
void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, FunctionDefinition const& _function)
{ {
set<ContractDefinition const*, LessFunction> specifiedContracts = set<ContractDefinition const*, LessFunction> specifiedContracts =
_function.overrides() ? _callable.overrides() ?
resolveOverrideList(*_function.overrides()) : resolveOverrideList(*_callable.overrides()) :
decltype(specifiedContracts){}; decltype(specifiedContracts){};
// Check for duplicates in override list // Check for duplicates in override list
if (_function.overrides() && specifiedContracts.size() != _function.overrides()->overrides().size()) if (_callable.overrides() && specifiedContracts.size() != _callable.overrides()->overrides().size())
{ {
// Sort by contract id to find duplicate for error reporting // Sort by contract id to find duplicate for error reporting
vector<ASTPointer<UserDefinedTypeName>> list = vector<ASTPointer<UserDefinedTypeName>> list =
sortByContract(_function.overrides()->overrides()); sortByContract(_callable.overrides()->overrides());
// Find duplicates and output error // Find duplicates and output error
for (size_t i = 1; i < list.size(); i++) for (size_t i = 1; i < list.size(); i++)
@ -877,7 +960,7 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F
"Duplicate contract \"" + "Duplicate contract \"" +
joinHumanReadable(list[i]->namePath(), ".") + joinHumanReadable(list[i]->namePath(), ".") +
"\" found in override list of \"" + "\" found in override list of \"" +
_function.name() + _callable.name() +
"\"." "\"."
); );
} }
@ -887,12 +970,12 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F
decltype(specifiedContracts) expectedContracts; decltype(specifiedContracts) expectedContracts;
// Build list of expected contracts // Build list of expected contracts
for (auto [begin, end] = _funcSet.equal_range(&_function); begin != end; begin++) for (auto [begin, end] = _inheritedCallables.equal_range(&_callable); begin != end; begin++)
{ {
// Validate the override // Validate the override
checkFunctionOverride(_function, **begin); checkOverride(_callable, **begin);
expectedContracts.insert((*begin)->annotation().contract); expectedContracts.insert(&dynamic_cast<ContractDefinition const&>(*(*begin)->scope()));
} }
decltype(specifiedContracts) missingContracts; decltype(specifiedContracts) missingContracts;
@ -906,7 +989,7 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F
if (!missingContracts.empty()) if (!missingContracts.empty())
overrideListError( overrideListError(
_function, _callable,
missingContracts, missingContracts,
"Function needs to specify overridden ", "Function needs to specify overridden ",
"" ""
@ -914,7 +997,7 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F
if (!surplusContracts.empty()) if (!surplusContracts.empty())
overrideListError( overrideListError(
_function, _callable,
surplusContracts, surplusContracts,
"Invalid ", "Invalid ",
"specified in override list: " "specified in override list: "

View File

@ -22,7 +22,9 @@
#pragma once #pragma once
#include <libsolidity/ast/ASTForward.h> #include <libsolidity/ast/ASTForward.h>
#include <liblangutil/SourceLocation.h>
#include <map> #include <map>
#include <functional>
#include <set> #include <set>
namespace langutil namespace langutil
@ -53,6 +55,12 @@ public:
bool check(ContractDefinition const& _contract); bool check(ContractDefinition const& _contract);
private: private:
/**
* Comparator that compares
* - functions such that equality means that the functions override each other
* - modifiers by name
* - contracts by AST id.
*/
struct LessFunction struct LessFunction
{ {
bool operator()(ModifierDefinition const* _a, ModifierDefinition const* _b) const; bool operator()(ModifierDefinition const* _a, ModifierDefinition const* _b) const;
@ -70,13 +78,24 @@ private:
template <class T> template <class T>
void findDuplicateDefinitions(std::map<std::string, std::vector<T>> const& _definitions, std::string _message); void findDuplicateDefinitions(std::map<std::string, std::vector<T>> const& _definitions, std::string _message);
void checkIllegalOverrides(ContractDefinition const& _contract); void checkIllegalOverrides(ContractDefinition const& _contract);
/// Performs various checks related to @a _function overriding @a _super like /// Performs various checks related to @a _overriding overriding @a _super like
/// different return type, invalid visibility change, etc. /// different return type, invalid visibility change, etc.
/// Works on functions, modifiers and public state variables.
/// Also stores @a _super as a base function of @a _function in its AST annotations. /// Also stores @a _super as a base function of @a _function in its AST annotations.
template<class T> template<class T, class U>
void checkFunctionOverride(T const& _overriding, FunctionDefinition const& _super); void checkOverride(T const& _overriding, U const& _super);
void overrideListError(FunctionDefinition const& function, std::set<ContractDefinition const*, LessFunction> _secondary, std::string const& _message1, std::string const& _message2); void overrideListError(
void overrideError(Declaration const& _overriding, Declaration const& _super, std::string _message, std::string _secondaryMsg = "Overridden function is here:"); CallableDeclaration const& _callable,
std::set<ContractDefinition const*, LessFunction> _secondary,
std::string const& _message1,
std::string const& _message2
);
void overrideError(
Declaration const& _overriding,
Declaration const& _super,
std::string _message,
std::string _secondaryMsg = "Overridden function is here:"
);
void checkAbstractFunctions(ContractDefinition const& _contract); void checkAbstractFunctions(ContractDefinition const& _contract);
/// Checks that the base constructor arguments are properly provided. /// Checks that the base constructor arguments are properly provided.
/// Fills the list of unimplemented functions in _contract's annotations. /// Fills the list of unimplemented functions in _contract's annotations.
@ -98,11 +117,18 @@ private:
/// Checks for functions in different base contracts which conflict with each /// Checks for functions in different base contracts which conflict with each
/// other and thus need to be overridden explicitly. /// other and thus need to be overridden explicitly.
void checkAmbiguousOverrides(ContractDefinition const& _contract) const; void checkAmbiguousOverrides(ContractDefinition const& _contract) const;
void checkAmbiguousOverridesInternal(std::set<
CallableDeclaration const*,
std::function<bool(CallableDeclaration const*, CallableDeclaration const*)>
> _baseCallables, langutil::SourceLocation const& _location) const;
/// Resolves an override list of UserDefinedTypeNames to a list of contracts. /// Resolves an override list of UserDefinedTypeNames to a list of contracts.
std::set<ContractDefinition const*, LessFunction> resolveOverrideList(OverrideSpecifier const& _overrides) const; std::set<ContractDefinition const*, LessFunction> resolveOverrideList(OverrideSpecifier const& _overrides) const;
void checkModifierOverrides(FunctionMultiSet const& _funcSet, ModifierMultiSet const& _modSet, std::vector<ModifierDefinition const*> _modifiers); template <class T>
void checkOverrideList(FunctionMultiSet const& _funcSet, FunctionDefinition const& _function); void checkOverrideList(
std::multiset<T const*, LessFunction> const& _funcSet,
T const& _function
);
/// Returns all functions of bases that have not yet been overwritten. /// Returns all functions of bases that have not yet been overwritten.
/// May contain the same function multiple times when used with shared bases. /// May contain the same function multiple times when used with shared bases.

View File

@ -312,6 +312,16 @@ ContractDefinition::ContractKind FunctionDefinition::inContractKind() const
return contractDef->contractKind(); return contractDef->contractKind();
} }
CallableDeclarationAnnotation& CallableDeclaration::annotation() const
{
solAssert(
m_annotation,
"CallableDeclarationAnnotation is an abstract base, need to call annotation on the concrete class first."
);
return dynamic_cast<CallableDeclarationAnnotation&>(*m_annotation);
}
FunctionTypePointer FunctionDefinition::functionType(bool _internal) const FunctionTypePointer FunctionDefinition::functionType(bool _internal) const
{ {
if (_internal) if (_internal)

View File

@ -624,6 +624,8 @@ public:
bool markedVirtual() const { return m_isVirtual; } bool markedVirtual() const { return m_isVirtual; }
virtual bool virtualSemantics() const { return markedVirtual(); } virtual bool virtualSemantics() const { return markedVirtual(); }
CallableDeclarationAnnotation& annotation() const override;
protected: protected:
ASTPointer<ParameterList> m_parameters; ASTPointer<ParameterList> m_parameters;
ASTPointer<OverrideSpecifier> m_overrides; ASTPointer<OverrideSpecifier> m_overrides;

View File

@ -104,19 +104,23 @@ struct ContractDefinitionAnnotation: TypeDeclarationAnnotation, DocumentedAnnota
std::map<FunctionDefinition const*, ASTNode const*> baseConstructorArguments; std::map<FunctionDefinition const*, ASTNode const*> baseConstructorArguments;
}; };
struct FunctionDefinitionAnnotation: ASTAnnotation, DocumentedAnnotation struct CallableDeclarationAnnotation: ASTAnnotation
{
/// The set of functions/modifiers/events this callable overrides.
std::set<CallableDeclaration const*> baseFunctions;
};
struct FunctionDefinitionAnnotation: CallableDeclarationAnnotation, DocumentedAnnotation
{ {
/// The set of functions this function overrides.
std::set<FunctionDefinition const*> baseFunctions;
/// Pointer to the contract this function is defined in /// Pointer to the contract this function is defined in
ContractDefinition const* contract = nullptr; ContractDefinition const* contract = nullptr;
}; };
struct EventDefinitionAnnotation: ASTAnnotation, DocumentedAnnotation struct EventDefinitionAnnotation: CallableDeclarationAnnotation, DocumentedAnnotation
{ {
}; };
struct ModifierDefinitionAnnotation: ASTAnnotation, DocumentedAnnotation struct ModifierDefinitionAnnotation: CallableDeclarationAnnotation, DocumentedAnnotation
{ {
}; };

View File

@ -392,7 +392,7 @@ bool ASTJsonConverter::visit(VariableDeclaration const& _node)
bool ASTJsonConverter::visit(ModifierDefinition const& _node) bool ASTJsonConverter::visit(ModifierDefinition const& _node)
{ {
setJsonNode(_node, "ModifierDefinition", { std::vector<pair<string, Json::Value>> attributes = {
make_pair("name", _node.name()), make_pair("name", _node.name()),
make_pair("documentation", _node.documentation() ? Json::Value(*_node.documentation()) : Json::nullValue), make_pair("documentation", _node.documentation() ? Json::Value(*_node.documentation()) : Json::nullValue),
make_pair("visibility", Declaration::visibilityToString(_node.visibility())), make_pair("visibility", Declaration::visibilityToString(_node.visibility())),
@ -400,7 +400,10 @@ bool ASTJsonConverter::visit(ModifierDefinition const& _node)
make_pair("virtual", _node.markedVirtual()), make_pair("virtual", _node.markedVirtual()),
make_pair("overrides", _node.overrides() ? toJson(*_node.overrides()) : Json::nullValue), make_pair("overrides", _node.overrides() ? toJson(*_node.overrides()) : Json::nullValue),
make_pair("body", toJson(_node.body())) make_pair("body", toJson(_node.body()))
}); };
if (!_node.annotation().baseFunctions.empty())
attributes.emplace_back(make_pair("baseModifiers", getContainerIds(_node.annotation().baseFunctions, true)));
setJsonNode(_node, "ModifierDefinition", std::move(attributes));
return false; return false;
} }

View File

@ -2313,7 +2313,7 @@ BOOST_AUTO_TEST_CASE(function_modifier_overriding)
char const* sourceCode = R"( char const* sourceCode = R"(
contract A { contract A {
function f() mod public returns (bool r) { return true; } function f() mod public returns (bool r) { return true; }
modifier mod { _; } modifier mod virtual { _; }
} }
contract C is A { contract C is A {
modifier mod override { if (false) _; } modifier mod override { if (false) _; }
@ -2352,7 +2352,7 @@ BOOST_AUTO_TEST_CASE(function_modifier_for_constructor)
contract A { contract A {
uint data; uint data;
constructor() mod1 public { data |= 2; } constructor() mod1 public { data |= 2; }
modifier mod1 { data |= 1; _; } modifier mod1 virtual { data |= 1; _; }
function getData() public returns (uint r) { return data; } function getData() public returns (uint r) { return data; }
} }
contract C is A { contract C is A {

View File

@ -16,5 +16,5 @@ abstract contract B is I {
contract C is A, B { contract C is A, B {
} }
// ---- // ----
// TypeError: (342-364): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes. // TypeError: (342-364): Derived contract must override function "f". Two or more base classes define function with same name and parameter types.
// TypeError: (342-364): Derived contract must override function "g". Function with the same name and parameter types defined in two or more base classes. // TypeError: (342-364): Derived contract must override function "g". Two or more base classes define function with same name and parameter types.

View File

@ -14,4 +14,4 @@ abstract contract B is I {
contract C is A, B { contract C is A, B {
} }
// ---- // ----
// TypeError: (254-276): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes. // TypeError: (254-276): Derived contract must override function "f". Two or more base classes define function with same name and parameter types.

View File

@ -13,5 +13,5 @@ abstract contract B is I {
contract C is A, B { contract C is A, B {
} }
// ---- // ----
// TypeError: (292-314): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes. // TypeError: (292-314): Derived contract must override function "f". Two or more base classes define function with same name and parameter types.
// TypeError: (292-314): Derived contract must override function "g". Function with the same name and parameter types defined in two or more base classes. // TypeError: (292-314): Derived contract must override function "g". Two or more base classes define function with same name and parameter types.

View File

@ -6,4 +6,4 @@ contract B {
} }
contract C is A, B {} contract C is A, B {}
// ---- // ----
// TypeError: (126-147): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes. // TypeError: (126-147): Derived contract must override function "f". Two or more base classes define function with same name and parameter types.

View File

@ -0,0 +1,9 @@
contract A {
modifier f() virtual { _; }
}
contract B {
modifier f() virtual { _; }
}
contract C is A, B {
modifier f() override(A,B) { _; }
}

View File

@ -0,0 +1,10 @@
contract A {
modifier f() virtual { _; }
}
contract B {
modifier f() virtual { _; }
}
contract C is A, B {
}
// ----
// TypeError: (94-116): Derived contract must override modifier "f". Two or more base classes define modifier with same name.

View File

@ -0,0 +1,10 @@
contract A {
modifier f(uint a) virtual { _; }
}
contract B {
modifier f() virtual { _; }
}
contract C is A, B {
}
// ----
// TypeError: (100-122): Derived contract must override modifier "f". Two or more base classes define modifier with same name.

View File

@ -0,0 +1,11 @@
contract A {
modifier f(uint a) virtual { _; }
}
contract B {
modifier f() virtual { _; }
}
contract C is A, B {
modifier f() virtual override(A, B) { _; }
}
// ----
// TypeError: (125-167): Override changes modifier signature.

View File

@ -9,5 +9,5 @@ abstract contract B {
contract C is A, B { contract C is A, B {
} }
// ---- // ----
// TypeError: (176-198): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes. // TypeError: (176-198): Derived contract must override function "f". Two or more base classes define function with same name and parameter types.
// TypeError: (176-198): Derived contract must override function "g". Function with the same name and parameter types defined in two or more base classes. // TypeError: (176-198): Derived contract must override function "g". Two or more base classes define function with same name and parameter types.

View File

@ -0,0 +1,19 @@
contract I {
modifier f() virtual { _; }
}
contract J {
modifier f() virtual { _; }
}
contract IJ is I, J {
modifier f() virtual override (I, J) { _; }
}
contract A is IJ
{
modifier f() override { _; }
}
contract B is IJ
{
}
contract C is A, B {}
// ----
// TypeError: (229-250): Derived contract must override modifier "f". Two or more base classes define modifier with same name.

View File

@ -9,4 +9,4 @@ abstract contract X is A, B {
function test() internal override returns (uint256) {} function test() internal override returns (uint256) {}
} }
// ---- // ----
// TypeError: (205-292): Derived contract must override function "foo". Function with the same name and parameter types defined in two or more base classes. // TypeError: (205-292): Derived contract must override function "foo". Two or more base classes define function with same name and parameter types.

View File

@ -0,0 +1,7 @@
abstract contract A {
}
abstract contract X is A {
modifier f() override { _; }
}
// ----
// TypeError: (65-73): Modifier has override specified but does not override anything.

View File

@ -10,4 +10,4 @@ contract C is A, B
{ {
} }
// ---- // ----
// TypeError: (94-116): Derived contract must override function "foo". Function with the same name and parameter types defined in two or more base classes. // TypeError: (94-116): Derived contract must override function "foo". Two or more base classes define function with same name and parameter types.

View File

@ -9,4 +9,4 @@ abstract contract X is A, B {
function test() internal override virtual returns (uint256); function test() internal override virtual returns (uint256);
} }
// ---- // ----
// TypeError: (203-296): Derived contract must override function "foo". Function with the same name and parameter types defined in two or more base classes. // TypeError: (203-296): Derived contract must override function "foo". Two or more base classes define function with same name and parameter types.

View File

@ -8,4 +8,4 @@ contract X is A, B {
uint public override foo; uint public override foo;
} }
// ---- // ----
// TypeError: (162-211): Derived contract must override function "foo". Function with the same name and parameter types defined in two or more base classes. // TypeError: (162-211): Derived contract must override function "foo". Two or more base classes define function with same name and parameter types.

View File

@ -11,4 +11,4 @@ contract X is B, C {
uint public override foo; uint public override foo;
} }
// ---- // ----
// TypeError: (271-320): Derived contract must override function "foo". Function with the same name and parameter types defined in two or more base classes. // TypeError: (271-320): Derived contract must override function "foo". Two or more base classes define function with same name and parameter types.

View File

@ -12,4 +12,4 @@ contract X is B, C {
} }
// ---- // ----
// DeclarationError: (245-269): Identifier already declared. // DeclarationError: (245-269): Identifier already declared.
// TypeError: (223-272): Derived contract must override function "foo". Function with the same name and parameter types defined in two or more base classes. // TypeError: (223-272): Derived contract must override function "foo". Two or more base classes define function with same name and parameter types.

View File

@ -8,4 +8,4 @@ contract X is A, B {
uint public override(A, B) foo; uint public override(A, B) foo;
} }
// ---- // ----
// TypeError: (162-217): Derived contract must override function "foo". Function with the same name and parameter types defined in two or more base classes. // TypeError: (162-217): Derived contract must override function "foo". Two or more base classes define function with same name and parameter types.

View File

@ -3,4 +3,4 @@ contract B is A { function f() public pure virtual override {} }
contract C is A, B { } contract C is A, B { }
contract D is A, B { function f() public pure override(A, B) {} } contract D is A, B { function f() public pure override(A, B) {} }
// ---- // ----
// TypeError: (116-138): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes. // TypeError: (116-138): Derived contract must override function "f". Two or more base classes define function with same name and parameter types.

View File

@ -1,5 +1,6 @@
contract A { modifier mod(uint a) { _; } } contract A { modifier mod(uint a) { _; } }
contract B is A { modifier mod(uint8 a) { _; } } contract B is A { modifier mod(uint8 a) { _; } }
// ---- // ----
// TypeError: (61-89): Overriding modifier is missing 'override' specifier.
// TypeError: (61-89): Override changes modifier signature. // TypeError: (61-89): Override changes modifier signature.
// TypeError: (61-89): Overriding modifier is missing 'override' specifier.
// TypeError: (13-40): Trying to override non-virtual modifier. Did you forget to add "virtual"?

View File

@ -1,2 +1,2 @@
contract A { modifier mod(uint a) { _; } } contract A { modifier mod(uint a) virtual { _; } }
contract B is A { modifier mod(uint a) override { _; } } contract B is A { modifier mod(uint a) override { _; } }

View File

@ -0,0 +1,4 @@
contract A { modifier mod(uint a) { _; } }
contract B is A { modifier mod(uint a) override { _; } }
// ----
// TypeError: (13-40): Trying to override non-virtual modifier. Did you forget to add "virtual"?