Do not require overriding for functions in common base with unique implementation.

This commit is contained in:
Daniel Kirchner 2019-12-03 13:58:30 +01:00 committed by chriseth
parent 071a52f0ff
commit 4c7f9f9751
12 changed files with 264 additions and 57 deletions

View File

@ -112,24 +112,32 @@ inline std::set<T> operator+(std::set<T>&& _a, U&& _b)
return ret; return ret;
} }
/// Remove one set from another one. /// Remove the elements of a container from a set.
template <class... T> template <class C, class... T>
inline std::set<T...>& operator-=(std::set<T...>& _a, std::set<T...> const& _b) inline std::set<T...>& operator-=(std::set<T...>& _a, C const& _b)
{ {
for (auto const& x: _b) for (auto const& x: _b)
_a.erase(x); _a.erase(x);
return _a; return _a;
} }
template <class... T> template <class C, class... T>
inline std::set<T...> operator-(std::set<T...> const& _a, std::set<T...> const& _b) inline std::set<T...> operator-(std::set<T...> const& _a, C const& _b)
{ {
auto result = _a; auto result = _a;
result -= _b; result -= _b;
return result; return result;
} }
/// Remove the elements of a container from a multiset.
template <class C, class... T>
inline std::multiset<T...>& operator-=(std::multiset<T...>& _a, C const& _b)
{
for (auto const& x: _b)
_a.erase(x);
return _a;
}
namespace dev namespace dev
{ {

View File

@ -238,8 +238,8 @@ void ContractLevelChecker::findDuplicateDefinitions(map<string, vector<T>> const
void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _contract) void ContractLevelChecker::checkIllegalOverrides(ContractDefinition const& _contract)
{ {
FunctionMultiSet const& funcSet = inheritedFunctions(&_contract); FunctionMultiSet const& funcSet = inheritedFunctions(_contract);
ModifierMultiSet const& modSet = inheritedModifiers(&_contract); ModifierMultiSet const& modSet = inheritedModifiers(_contract);
checkModifierOverrides(funcSet, modSet, _contract.functionModifiers()); checkModifierOverrides(funcSet, modSet, _contract.functionModifiers());
@ -666,56 +666,132 @@ void ContractLevelChecker::checkBaseABICompatibility(ContractDefinition const& _
void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _contract) const void ContractLevelChecker::checkAmbiguousOverrides(ContractDefinition const& _contract) const
{ {
vector<FunctionDefinition const*> contractFuncs = _contract.definedFunctions(); // Fetch inherited functions and sort them by signature.
// We get at least one function per signature and direct base contract, which is
// enough because we re-construct the inheritance graph later.
FunctionMultiSet nonOverriddenFunctions = inheritedFunctions(_contract);
// Remove all functions that match the signature of a function in the current contract.
nonOverriddenFunctions -= _contract.definedFunctions();
auto const resolvedBases = resolveDirectBaseContracts(_contract); // Walk through the set of functions signature by signature.
for (auto it = nonOverriddenFunctions.cbegin(); it != nonOverriddenFunctions.cend();)
FunctionMultiSet inheritedFuncs = inheritedFunctions(&_contract);;
// Check the sets of the most-inherited functions
for (auto it = inheritedFuncs.cbegin(); it != inheritedFuncs.cend(); it = inheritedFuncs.upper_bound(*it))
{ {
auto [begin,end] = inheritedFuncs.equal_range(*it); static constexpr auto compareById = [](auto const* a, auto const* b) { return a->id() < b->id(); };
std::set<FunctionDefinition const*, decltype(compareById)> baseFunctions(compareById);
for (auto nextSignature = nonOverriddenFunctions.upper_bound(*it); it != nextSignature; ++it)
baseFunctions.insert(*it);
// Only one function if (baseFunctions.size() <= 1)
if (next(begin) == end)
continue; continue;
// Not an overridable function // Construct the override graph for this signature.
if ((*it)->isConstructor()) // Reserve node 0 for the current contract and node
// 1 for an artificial top node to which all override paths
// connect at the end.
struct OverrideGraph
{ {
for (begin++; begin != end; begin++) OverrideGraph(decltype(baseFunctions) const& _baseFunctions)
solAssert((*begin)->isConstructor(), "All functions in range expected to be constructors!"); {
continue; for (auto const* baseFunction: _baseFunctions)
} addEdge(0, visit(baseFunction));
// Function has been explicitly overridden
if (contains_if(
contractFuncs,
[&] (FunctionDefinition const* _f) {
return hasEqualNameAndParameters(*_f, **it);
} }
)) std::map<FunctionDefinition const*, int> nodes;
continue; std::map<int, FunctionDefinition const*> nodeInv;
std::map<int, std::set<int>> edges;
int numNodes = 2;
void addEdge(int _a, int _b)
{
edges[_a].insert(_b);
edges[_b].insert(_a);
}
private:
/// Completes the graph starting from @a _function and
/// @returns the node ID.
int visit(FunctionDefinition const* _function)
{
auto it = nodes.find(_function);
if (it != nodes.end())
return it->second;
int currentNode = numNodes++;
nodes[_function] = currentNode;
nodeInv[currentNode] = _function;
if (_function->overrides())
for (auto const* baseFunction: _function->annotation().baseFunctions)
addEdge(currentNode, visit(baseFunction));
else
addEdge(currentNode, 1);
set<FunctionDefinition const*> ambiguousFunctions; return currentNode;
SecondarySourceLocation ssl; }
} overrideGraph(baseFunctions);
for (;begin != end; begin++) // Detect cut vertices following https://en.wikipedia.org/wiki/Biconnected_component#Pseudocode
// Can ignore the root node, since it is never a cut vertex in our case.
struct CutVertexFinder
{ {
ambiguousFunctions.insert(*begin); CutVertexFinder(OverrideGraph const& _graph): m_graph(_graph)
ssl.append("Definition here: ", (*begin)->location()); {
run();
}
std::set<FunctionDefinition const*> const& cutVertices() const { return m_cutVertices; }
private:
OverrideGraph const& m_graph;
std::vector<bool> m_visited = std::vector<bool>(m_graph.numNodes, false);
std::vector<int> m_depths = std::vector<int>(m_graph.numNodes, -1);
std::vector<int> m_low = std::vector<int>(m_graph.numNodes, -1);
std::vector<int> m_parent = std::vector<int>(m_graph.numNodes, -1);
std::set<FunctionDefinition const*> m_cutVertices{};
void run(int _u = 0, int _depth = 0)
{
m_visited.at(_u) = true;
m_depths.at(_u) = m_low.at(_u) = _depth;
for (int v: m_graph.edges.at(_u))
if (!m_visited.at(v))
{
m_parent[v] = _u;
run(v, _depth + 1);
if (m_low[v] >= m_depths[_u] && m_parent[_u] != -1)
m_cutVertices.insert(m_graph.nodeInv.at(_u));
m_low[_u] = min(m_low[_u], m_low[v]);
}
else if (v != m_parent[_u])
m_low[_u] = min(m_low[_u], m_depths[v]);
}
} cutVertexFinder{overrideGraph};
// Remove all base functions overridden by cut vertices (they don't need to be overridden).
for (auto const* function: cutVertexFinder.cutVertices())
{
std::set<FunctionDefinition const*> toTraverse = function->annotation().baseFunctions;
while (!toTraverse.empty())
{
auto const* base = *toTraverse.begin();
toTraverse.erase(toTraverse.begin());
baseFunctions.erase(base);
for (auto const* f: base->annotation().baseFunctions)
toTraverse.insert(f);
}
// Remove unimplemented base functions at the cut vertices themselves as well.
if (!function->isImplemented())
baseFunctions.erase(function);
} }
// Make sure the functions are not from the same base contract // If more than one function is left, they have to be overridden.
if (ambiguousFunctions.size() == 1) if (baseFunctions.size() <= 1)
continue; continue;
SecondarySourceLocation ssl;
for (auto const* baseFunction: baseFunctions)
ssl.append("Definition here: ", baseFunction->location());
m_errorReporter.typeError( m_errorReporter.typeError(
_contract.location(), _contract.location(),
ssl, ssl,
"Derived contract must override function \"" + "Derived contract must override function \"" +
(*it)->name() + (*baseFunctions.begin())->name() +
"\". Function with the same name and parameter types defined in two or more base classes." "\". Function with the same name and parameter types defined in two or more base classes."
); );
} }
@ -845,50 +921,52 @@ void ContractLevelChecker::checkOverrideList(FunctionMultiSet const& _funcSet, F
); );
} }
ContractLevelChecker::FunctionMultiSet const& ContractLevelChecker::inheritedFunctions(ContractDefinition const* _contract) const ContractLevelChecker::FunctionMultiSet const& ContractLevelChecker::inheritedFunctions(ContractDefinition const& _contract) const
{ {
if (!m_inheritedFunctions.count(_contract)) if (!m_inheritedFunctions.count(&_contract))
{ {
FunctionMultiSet set; FunctionMultiSet set;
for (auto const* base: resolveDirectBaseContracts(*_contract)) for (auto const* base: resolveDirectBaseContracts(_contract))
{ {
std::set<FunctionDefinition const*, LessFunction> tmpSet = std::set<FunctionDefinition const*, LessFunction> functionsInBase;
convertContainer<decltype(tmpSet)>(base->definedFunctions()); for (FunctionDefinition const* fun: base->definedFunctions())
if (!fun->isConstructor())
functionsInBase.emplace(fun);
for (auto const& func: inheritedFunctions(base)) for (auto const& func: inheritedFunctions(*base))
tmpSet.insert(func); functionsInBase.insert(func);
set += tmpSet; set += functionsInBase;
} }
m_inheritedFunctions[_contract] = set; m_inheritedFunctions[&_contract] = set;
} }
return m_inheritedFunctions[_contract]; return m_inheritedFunctions[&_contract];
} }
ContractLevelChecker::ModifierMultiSet const& ContractLevelChecker::inheritedModifiers(ContractDefinition const* _contract) const ContractLevelChecker::ModifierMultiSet const& ContractLevelChecker::inheritedModifiers(ContractDefinition const& _contract) const
{ {
auto const& result = m_contractBaseModifiers.find(_contract); auto const& result = m_contractBaseModifiers.find(&_contract);
if (result != m_contractBaseModifiers.cend()) if (result != m_contractBaseModifiers.cend())
return result->second; return result->second;
ModifierMultiSet set; ModifierMultiSet set;
for (auto const* base: resolveDirectBaseContracts(*_contract)) for (auto const* base: resolveDirectBaseContracts(_contract))
{ {
std::set<ModifierDefinition const*, LessFunction> tmpSet = std::set<ModifierDefinition const*, LessFunction> tmpSet =
convertContainer<decltype(tmpSet)>(base->functionModifiers()); convertContainer<decltype(tmpSet)>(base->functionModifiers());
for (auto const& mod: inheritedModifiers(base)) for (auto const& mod: inheritedModifiers(*base))
tmpSet.insert(mod); tmpSet.insert(mod);
set += tmpSet; set += tmpSet;
} }
return m_contractBaseModifiers[_contract] = set; return m_contractBaseModifiers[&_contract] = set;
} }
void ContractLevelChecker::checkPayableFallbackWithoutReceive(ContractDefinition const& _contract) void ContractLevelChecker::checkPayableFallbackWithoutReceive(ContractDefinition const& _contract)

View File

@ -106,8 +106,8 @@ private:
/// Returns all functions of bases that have not yet been overwritten. /// Returns all functions of bases that have not yet been overwritten.
/// May contain the same function multiple times when used with shared bases. /// May contain the same function multiple times when used with shared bases.
FunctionMultiSet const& inheritedFunctions(ContractDefinition const* _contract) const; FunctionMultiSet const& inheritedFunctions(ContractDefinition const& _contract) const;
ModifierMultiSet const& inheritedModifiers(ContractDefinition const* _contract) const; ModifierMultiSet const& inheritedModifiers(ContractDefinition const& _contract) const;
/// Warns if the contract has a payable fallback, but no receive ether function. /// Warns if the contract has a payable fallback, but no receive ether function.
void checkPayableFallbackWithoutReceive(ContractDefinition const& _contract); void checkPayableFallbackWithoutReceive(ContractDefinition const& _contract);

View File

@ -0,0 +1,20 @@
interface I {
function f() external;
function g() external;
}
interface J {
function f() external;
}
abstract contract A is I, J {
function f() external override (I, J) {}
function g() external override virtual;
}
abstract contract B is I {
function f() external override virtual;
function g() external override {}
}
contract C is A, B {
}
// ----
// TypeError: (342-364): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes.
// TypeError: (342-364): Derived contract must override function "g". Function with the same name and parameter types defined in two or more base classes.

View File

@ -0,0 +1,17 @@
interface I {
function f() external;
function g() external;
}
interface J {
function f() external;
}
abstract contract A is I, J {
function f() external override (I, J) {}
}
abstract contract B is I {
function g() external override {}
}
contract C is A, B {
}
// ----
// TypeError: (254-276): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes.

View File

@ -0,0 +1,12 @@
contract A {
function f() external virtual {}
}
contract B {
function f() external virtual {}
}
contract C is A, B {
function f() external override (A, B) {}
}
contract X is C {
}
// ----

View File

@ -0,0 +1,17 @@
interface I {
function f() external;
function g() external;
}
abstract contract A is I {
function f() external override {}
function g() external override virtual;
}
abstract contract B is I {
function g() external override {}
function f() external override virtual;
}
contract C is A, B {
}
// ----
// TypeError: (292-314): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes.
// TypeError: (292-314): Derived contract must override function "g". Function with the same name and parameter types defined in two or more base classes.

View File

@ -0,0 +1,12 @@
interface I {
function f() external;
function g() external;
}
abstract contract A is I {
function f() external override {}
}
abstract contract B is I {
function g() external override {}
}
contract C is A, B {
}

View File

@ -0,0 +1,13 @@
abstract contract A {
function f() external {}
function g() external virtual;
}
abstract contract B {
function g() external {}
function f() external virtual;
}
contract C is A, B {
}
// ----
// TypeError: (176-198): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes.
// TypeError: (176-198): Derived contract must override function "g". Function with the same name and parameter types defined in two or more base classes.

View File

@ -0,0 +1,19 @@
interface I {
function f() external;
function g() external;
}
interface J {
function f() external;
}
abstract contract IJ is I, J {
function f() external virtual override (I, J);
}
abstract contract A is IJ
{
function f() external override {}
}
abstract contract B is IJ
{
function g() external override {}
}
contract C is A, B {}

View File

@ -0,0 +1,6 @@
contract A { function f() public pure virtual {} }
contract B is A { function f() public pure virtual override {} }
contract C is A, B { }
contract D is A, B { function f() public pure override(A, B) {} }
// ----
// TypeError: (116-138): Derived contract must override function "f". Function with the same name and parameter types defined in two or more base classes.

View File

@ -0,0 +1,5 @@
abstract contract A { function f() public pure virtual; }
contract B is A { function f() public pure virtual override {} }
contract C is A, B { }
contract D is A, B { function f() public pure override(A, B) {} }
// ----