Use proxies.

This commit is contained in:
chriseth 2019-12-10 09:48:01 +01:00
parent 3e1b00b459
commit 57824566e6
3 changed files with 495 additions and 335 deletions

View File

@ -42,23 +42,12 @@ namespace
struct MatchByName
{
string const& m_name;
bool operator()(CallableDeclaration const* _callable)
bool operator()(OverrideProxy const& _item)
{
return _callable->name() == m_name;
return _item.name() == m_name;
}
};
template <class T, class B>
bool hasEqualNameAndParameters(T const& _a, B const& _b)
{
return
_a.name() == _b.name() &&
FunctionType(_a).asCallableFunction(false)->hasEqualParameterTypes(
*FunctionType(_b).asCallableFunction(false)
);
}
vector<ContractDefinition const*> resolveDirectBaseContracts(ContractDefinition const& _contract)
{
vector<ContractDefinition const*> resolvedContracts;
@ -97,71 +86,260 @@ vector<ASTPointer<UserDefinedTypeName>> sortByContract(vector<ASTPointer<UserDef
return sorted;
}
OverrideProxy makeOverrideProxy(CallableDeclaration const& _callable)
{
if (auto const* fun = dynamic_cast<FunctionDefinition const*>(&_callable))
return OverrideProxy{fun};
else if (auto const* mod = dynamic_cast<ModifierDefinition const*>(&_callable))
return OverrideProxy{mod};
else
solAssert(false, "Invalid call to makeOverrideProxy.");
return {};
}
}
bool OverrideProxy::operator<(OverrideProxy const& _other) const
{
return std::visit(GenericVisitor{
[&](FunctionDefinition const* _function)
{
return std::visit(GenericVisitor{
[&](FunctionDefinition const* _other)
{
if (_function->name() != _other->name())
return _function->name() < _other->name();
if (_function->kind() != _other->kind())
return _function->kind() < _other->kind();
return boost::lexicographical_compare(
FunctionType(*_function).asCallableFunction(false)->parameterTypes(),
FunctionType(*_other).asCallableFunction(false)->parameterTypes(),
[](auto const& _paramTypeA, auto const& _paramTypeB)
{
return _paramTypeA->richIdentifier() < _paramTypeB->richIdentifier();
}
);
},
[&](VariableDeclaration const* _other)
{
if (_function->name() != _other->name())
return _function->name() < _other->name();
if (_function->kind() != Token::Function)
return _function->kind() < Token::Function;
return boost::lexicographical_compare(
FunctionType(*_function).asCallableFunction(false)->parameterTypes(),
FunctionType(*_other).asCallableFunction(false)->parameterTypes(),
[](auto const& _paramTypeA, auto const& _paramTypeB)
{
return _paramTypeA->richIdentifier() < _paramTypeB->richIdentifier();
}
);
},
[&](ModifierDefinition const*)
{
solAssert(false, "Compared function to something else than function or state variable.");
return false;
}
}, _other.item);
},
[&](ModifierDefinition const*)
{
solAssert(false, "todo");
return false;
},
[&](VariableDeclaration const*)
{
solAssert(false, "todo");
return false;
}
}, item);
return false;
return id() < _other.id();
}
bool OverrideProxy::isVariable() const
{
return holds_alternative<VariableDeclaration const*>(m_item);
}
bool OverrideProxy::isFunction() const
{
return holds_alternative<FunctionDefinition const*>(m_item);
}
bool OverrideProxy::isModifier() const
{
return holds_alternative<ModifierDefinition const*>(m_item);
}
bool OverrideProxy::CompareBySignature::operator()(OverrideProxy const& _a, OverrideProxy const& _b) const
{
return _a.overrideComparator() < _b.overrideComparator();
}
size_t OverrideProxy::id() const
{
return std::visit(GenericVisitor{
[&](auto const* _item) -> size_t { return _item->id(); }
}, m_item);
}
shared_ptr<OverrideSpecifier> OverrideProxy::overrides() const
{
return std::visit(GenericVisitor{
[&](auto const* _item) { return _item->overrides(); }
}, m_item);
}
set<OverrideProxy> OverrideProxy::baseFunctions() const
{
return std::visit(GenericVisitor{
[&](auto const* _item) -> set<OverrideProxy> {
set<OverrideProxy> ret;
for (auto const* f: _item->annotation().baseFunctions)
ret.insert(makeOverrideProxy(*f));
return ret;
}
}, m_item);
}
void OverrideProxy::storeBaseFunction(OverrideProxy const& _base) const
{
std::visit(GenericVisitor{
[&](FunctionDefinition const* _item) {
_item->annotation().baseFunctions.emplace(std::get<FunctionDefinition const*>(_base.m_item));
},
[&](ModifierDefinition const* _item) {
_item->annotation().baseFunctions.emplace(std::get<ModifierDefinition const*>(_base.m_item));
},
[&](VariableDeclaration const* _item) {
_item->annotation().baseFunctions.emplace(std::get<FunctionDefinition const*>(_base.m_item));
}
}, m_item);
}
string const& OverrideProxy::name() const
{
return std::visit(GenericVisitor{
[&](auto const* _item) -> string const& { return _item->name(); }
}, m_item);
}
ContractDefinition const& OverrideProxy::contract() const
{
return std::visit(GenericVisitor{
[&](auto const* _item) -> ContractDefinition const& {
return dynamic_cast<ContractDefinition const&>(*_item->scope());
}
}, m_item);
}
string const& OverrideProxy::contractName() const
{
return contract().name();
}
Visibility OverrideProxy::visibility() const
{
return std::visit(GenericVisitor{
[&](FunctionDefinition const* _item) { return _item->visibility(); },
[&](ModifierDefinition const* _item) { return _item->visibility(); },
[&](VariableDeclaration const*) { return Visibility::External; }
}, m_item);
}
StateMutability OverrideProxy::stateMutability() const
{
return std::visit(GenericVisitor{
[&](FunctionDefinition const* _item) { return _item->stateMutability(); },
[&](ModifierDefinition const*) { solAssert(false, "Requested state mutability from modifier."); return StateMutability{}; },
[&](VariableDeclaration const*) { return StateMutability::View; }
}, m_item);
}
bool OverrideProxy::virtualSemantics() const
{
return std::visit(GenericVisitor{
[&](FunctionDefinition const* _item) { return _item->virtualSemantics(); },
[&](ModifierDefinition const* _item) { return _item->virtualSemantics(); },
[&](VariableDeclaration const*) { return false; }
}, m_item);
}
Token OverrideProxy::functionKind() const
{
return std::visit(GenericVisitor{
[&](FunctionDefinition const* _item) { return _item->kind(); },
[&](ModifierDefinition const*) { return Token::Function; },
[&](VariableDeclaration const*) { return Token::Function; }
}, m_item);
}
FunctionType const* OverrideProxy::functionType() const
{
return std::visit(GenericVisitor{
[&](FunctionDefinition const* _item) { return FunctionType(*_item).asCallableFunction(false); },
[&](VariableDeclaration const* _item) { return FunctionType(*_item).asCallableFunction(false); },
[&](ModifierDefinition const*) -> FunctionType const* { solAssert(false, "Requested function type of modifier."); return nullptr; }
}, m_item);
}
ModifierType const* OverrideProxy::modifierType() const
{
return std::visit(GenericVisitor{
[&](FunctionDefinition const*) -> ModifierType const* { solAssert(false, "Requested modifier type of function."); return nullptr; },
[&](VariableDeclaration const*) -> ModifierType const* { solAssert(false, "Requested modifier type of variable."); return nullptr; },
[&](ModifierDefinition const* _modifier) -> ModifierType const* { return TypeProvider::modifier(*_modifier); }
}, m_item);
}
SourceLocation const& OverrideProxy::location() const
{
return std::visit(GenericVisitor{
[&](auto const* _item) -> SourceLocation const& { return _item->location(); }
}, m_item);
}
string OverrideProxy::astNodeName() const
{
return std::visit(GenericVisitor{
[&](FunctionDefinition const*) { return "function"; },
[&](ModifierDefinition const*) { return "modifier"; },
[&](VariableDeclaration const*) { return "public state variable"; },
}, m_item);
}
string OverrideProxy::astNodeNameCapitalized() const
{
return std::visit(GenericVisitor{
[&](FunctionDefinition const*) { return "Function"; },
[&](ModifierDefinition const*) { return "Modifier"; },
[&](VariableDeclaration const*) { return "Public state variable"; },
}, m_item);
}
string OverrideProxy::distinguishingProperty() const
{
return std::visit(GenericVisitor{
[&](FunctionDefinition const*) { return "name and parameter types"; },
[&](ModifierDefinition const*) { return "name"; },
[&](VariableDeclaration const*) { return "name and parameter types"; },
}, m_item);
}
bool OverrideProxy::unimplemented() const
{
return std::visit(GenericVisitor{
[&](FunctionDefinition const* _item) { return !_item->isImplemented(); },
[&](ModifierDefinition const*) { return false; },
[&](VariableDeclaration const*) { return false; }
}, m_item);
}
bool OverrideProxy::OverrideComparator::operator<(OverrideComparator const& _other) const
{
if (name != _other.name)
return name < _other.name;
if (!functionKind || !_other.functionKind)
return false;
if (functionKind != _other.functionKind)
return *functionKind < *_other.functionKind;
if (!parameterTypes || !_other.parameterTypes)
return false;
return boost::lexicographical_compare(*parameterTypes, *_other.parameterTypes);
}
OverrideProxy::OverrideComparator const& OverrideProxy::overrideComparator() const
{
if (!m_comparator)
{
m_comparator = make_shared<OverrideComparator>(std::visit(GenericVisitor{
[&](FunctionDefinition const* _function)
{
vector<string> paramTypes;
for (Type const* t: functionType()->parameterTypes())
paramTypes.emplace_back(t->richIdentifier());
return OverrideComparator{
_function->name(),
_function->kind(),
std::move(paramTypes)
};
},
[&](VariableDeclaration const* _var)
{
vector<string> paramTypes;
for (Type const* t: functionType()->parameterTypes())
paramTypes.emplace_back(t->richIdentifier());
return OverrideComparator{
_var->name(),
Token::Function,
std::move(paramTypes)
};
},
[&](ModifierDefinition const* _mod)
{
return OverrideComparator{
_mod->name(),
{},
{}
};
}
}, m_item));
}
return *m_comparator;
}
bool OverrideChecker::LessFunction::operator()(ModifierDefinition const* _a, ModifierDefinition const* _b) const
{
@ -177,8 +355,8 @@ bool OverrideChecker::LessFunction::operator()(FunctionDefinition const* _a, Fun
return _a->kind() < _b->kind();
return boost::lexicographical_compare(
FunctionType(*_a).asCallableFunction(false)->parameterTypes(),
FunctionType(*_b).asCallableFunction(false)->parameterTypes(),
TypeProvider::function(*_a)->asCallableFunction(false)->parameterTypes(),
TypeProvider::function(*_b)->asCallableFunction(false)->parameterTypes(),
[](auto const& _paramTypeA, auto const& _paramTypeB)
{
return _paramTypeA->richIdentifier() < _paramTypeB->richIdentifier();
@ -202,38 +380,15 @@ void OverrideChecker::check(ContractDefinition const& _contract)
void OverrideChecker::checkIllegalOverrides(ContractDefinition const& _contract)
{
FunctionMultiSet const& inheritedFuncs = inheritedFunctions(_contract);
ModifierMultiSet const& inheritedMods = inheritedModifiers(_contract);
OverrideProxyBySignatureMultiSet const& inheritedFuncs = inheritedFunctions(_contract);
OverrideProxyBySignatureMultiSet const& inheritedMods = inheritedModifiers(_contract);
for (auto const* stateVar: _contract.stateVariables())
{
if (!stateVar->isPublic())
continue;
bool found = false;
for (
auto it = find_if(inheritedFuncs.begin(), inheritedFuncs.end(), MatchByName{stateVar->name()});
it != inheritedFuncs.end();
it = find_if(++it, inheritedFuncs.end(), MatchByName{stateVar->name()})
)
{
// -> compare equal
if (!hasEqualNameAndParameters(*stateVar, **it))
continue;
if ((*it)->visibility() != Visibility::External)
overrideError(*stateVar, **it, "Public state variables can only override functions with external visibility.");
else
checkOverride(*stateVar, **it);
found = true;
}
if (!found && stateVar->overrides())
m_errorReporter.typeError(
stateVar->overrides()->location(),
"Public state variable has override specified but does not override anything."
);
checkOverrideList(OverrideProxy{stateVar}, inheritedFuncs);
}
for (ModifierDefinition const* modifier: _contract.functionModifiers())
@ -244,22 +399,7 @@ void OverrideChecker::checkIllegalOverrides(ContractDefinition const& _contract)
"Override changes function to modifier."
);
auto [begin, end] = inheritedMods.equal_range(modifier);
if (begin == end && modifier->overrides())
m_errorReporter.typeError(
modifier->overrides()->location(),
"Modifier has override specified but does not override anything."
);
for (; begin != end; begin++)
if (ModifierType(**begin) != ModifierType(*modifier))
m_errorReporter.typeError(
modifier->location(),
"Override changes modifier signature."
);
checkOverrideList(inheritedMods, *modifier);
checkOverrideList(OverrideProxy{modifier}, inheritedMods);
}
for (FunctionDefinition const* function: _contract.definedFunctions())
@ -270,64 +410,41 @@ void OverrideChecker::checkIllegalOverrides(ContractDefinition const& _contract)
if (contains_if(inheritedMods, MatchByName{function->name()}))
m_errorReporter.typeError(function->location(), "Override changes modifier to function.");
// No inheriting functions found
if (!inheritedFuncs.count(function) && function->overrides())
m_errorReporter.typeError(
function->overrides()->location(),
"Function has override specified but does not override anything."
);
checkOverrideList(inheritedFuncs, *function);
checkOverrideList(OverrideProxy{function}, inheritedFuncs);
}
}
template<class T, class U>
void OverrideChecker::checkOverride(T const& _overriding, U const& _super)
void OverrideChecker::checkOverride(OverrideProxy const& _overriding, OverrideProxy const& _super)
{
static_assert(
std::is_same<VariableDeclaration, T>::value ||
std::is_same<FunctionDefinition, T>::value ||
std::is_same<ModifierDefinition, T>::value,
"Invalid call to checkOverride."
);
solAssert(_super.isFunction() || _super.isModifier(), "");
solAssert(_super.isModifier() == _overriding.isModifier(), "");
static_assert(
std::is_same<FunctionDefinition, U>::value ||
std::is_same<ModifierDefinition, U>::value,
"Invalid call to checkOverride."
);
static_assert(
!std::is_same<ModifierDefinition, U>::value ||
std::is_same<ModifierDefinition, T>::value,
"Invalid call to checkOverride."
);
_overriding.storeBaseFunction(_super);
string overridingName;
if constexpr(std::is_same<FunctionDefinition, T>::value)
overridingName = "function";
else if constexpr(std::is_same<ModifierDefinition, T>::value)
overridingName = "modifier";
else
overridingName = "public state variable";
string superName;
if constexpr(std::is_same<FunctionDefinition, U>::value)
superName = "function";
else
superName = "modifier";
if (_overriding.isModifier() && *_overriding.modifierType() != *_super.modifierType())
m_errorReporter.typeError(
_overriding.location(),
"Override changes modifier signature."
);
if (!_overriding.overrides())
overrideError(_overriding, _super, "Overriding " + overridingName + " is missing 'override' specifier.");
overrideError(_overriding, _super, "Overriding " + _overriding.astNodeName() + " is missing 'override' specifier.");
if (!_super.virtualSemantics())
overrideError(
_super,
_overriding,
"Trying to override non-virtual " + superName + ". Did you forget to add \"virtual\"?",
"Overriding " + overridingName + " is here:"
"Trying to override non-virtual " + _super.astNodeName() + ". Did you forget to add \"virtual\"?",
"Overriding " + _overriding.astNodeName() + " is here:"
);
if (_overriding.visibility() != _super.visibility())
if (_overriding.isVariable())
{
if (_super.visibility() != Visibility::External)
overrideError(_overriding, _super, "Public state variables can only override functions with external visibility.");
solAssert(_overriding.visibility() == Visibility::External, "");
}
else if (_overriding.visibility() != _super.visibility())
{
// Visibility change from external to public is fine.
// Any other change is disallowed.
@ -335,26 +452,22 @@ void OverrideChecker::checkOverride(T const& _overriding, U const& _super)
_super.visibility() == Visibility::External &&
_overriding.visibility() == Visibility::Public
))
overrideError(_overriding, _super, "Overriding " + overridingName + " visibility differs.");
overrideError(_overriding, _super, "Overriding " + _overriding.astNodeName() + " visibility differs.");
}
// This is only relevant for overriding functions by functions or state variables,
// it is skipped for modifiers.
if constexpr(std::is_same<FunctionDefinition, U>::value)
if (_super.isFunction())
{
FunctionTypePointer functionType = FunctionType(_overriding).asCallableFunction(false);
FunctionTypePointer superType = FunctionType(_super).asCallableFunction(false);
FunctionType const* functionType = _overriding.functionType();
FunctionType const* superType = _super.functionType();
solAssert(functionType->hasEqualParameterTypes(*superType), "Override doesn't have equal parameters!");
if (!functionType->hasEqualReturnTypes(*superType))
overrideError(_overriding, _super, "Overriding " + overridingName + " return types differ.");
overrideError(_overriding, _super, "Overriding " + _overriding.astNodeName() + " return types differ.");
// This is only relevant for a function overriding a function.
if constexpr(std::is_same<T, FunctionDefinition>::value)
if (_overriding.isFunction())
{
_overriding.annotation().baseFunctions.emplace(&_super);
if (_overriding.stateMutability() != _super.stateMutability())
overrideError(
_overriding,
@ -366,7 +479,7 @@ void OverrideChecker::checkOverride(T const& _overriding, U const& _super)
"\"."
);
if (!_overriding.isImplemented() && _super.isImplemented())
if (_overriding.unimplemented() && !_super.unimplemented())
overrideError(
_overriding,
_super,
@ -377,7 +490,7 @@ void OverrideChecker::checkOverride(T const& _overriding, U const& _super)
}
void OverrideChecker::overrideListError(
CallableDeclaration const& _callable,
OverrideProxy const& _item,
set<ContractDefinition const*, LessFunction> _secondary,
string const& _message1,
string const& _message2
@ -396,7 +509,7 @@ void OverrideChecker::overrideListError(
contractSingularPlural = "contracts ";
m_errorReporter.typeError(
_callable.overrides() ? _callable.overrides()->location() : _callable.location(),
_item.overrides() ? _item.overrides()->location() : _item.location(),
ssl,
_message1 +
contractSingularPlural +
@ -406,7 +519,17 @@ void OverrideChecker::overrideListError(
);
}
void OverrideChecker::overrideError(Declaration const& _overriding, Declaration const& _super, string _message, string _secondaryMsg)
void OverrideChecker::overrideError(Declaration const& _overriding, Declaration const& _super, string const& _message, string const& _secondaryMsg)
{
m_errorReporter.typeError(
_overriding.location(),
SecondarySourceLocation().append(_secondaryMsg, _super.location()),
_message
);
}
void OverrideChecker::overrideError(OverrideProxy const& _overriding, OverrideProxy const& _super, string const& _message, string const& _secondaryMsg)
{
m_errorReporter.typeError(
_overriding.location(),
@ -417,21 +540,23 @@ void OverrideChecker::overrideError(Declaration const& _overriding, Declaration
void OverrideChecker::checkAmbiguousOverrides(ContractDefinition const& _contract) const
{
std::function<bool(CallableDeclaration const*, CallableDeclaration const*)> compareById =
[](auto const* _a, auto const* _b) { return _a->id() < _b->id(); };
{
// Fetch inherited functions and sort them by signature.
// We get at least one function per signature and direct base contract, which is
// enough because we re-construct the inheritance graph later.
FunctionMultiSet nonOverriddenFunctions = inheritedFunctions(_contract);
OverrideProxyBySignatureMultiSet nonOverriddenFunctions = inheritedFunctions(_contract);
for (OverrideProxy stateVar: inheritedPublicStateVariables(_contract))
nonOverriddenFunctions.insert(stateVar);
// Remove all functions that match the signature of a function in the current contract.
nonOverriddenFunctions -= _contract.definedFunctions();
for (FunctionDefinition const* f: _contract.definedFunctions())
nonOverriddenFunctions.erase(OverrideProxy{f});
// Walk through the set of functions signature by signature.
for (auto it = nonOverriddenFunctions.cbegin(); it != nonOverriddenFunctions.cend();)
{
std::set<CallableDeclaration const*, decltype(compareById)> baseFunctions(compareById);
std::set<OverrideProxy> baseFunctions;
for (auto nextSignature = nonOverriddenFunctions.upper_bound(*it); it != nextSignature; ++it)
baseFunctions.insert(*it);
@ -440,11 +565,13 @@ void OverrideChecker::checkAmbiguousOverrides(ContractDefinition const& _contrac
}
{
ModifierMultiSet modifiers = inheritedModifiers(_contract);
modifiers -= _contract.functionModifiers();
OverrideProxyBySignatureMultiSet modifiers = inheritedModifiers(_contract);
for (ModifierDefinition const* mod: _contract.functionModifiers())
modifiers.erase(OverrideProxy{mod});
for (auto it = modifiers.cbegin(); it != modifiers.cend();)
{
std::set<CallableDeclaration const*, decltype(compareById)> baseModifiers(compareById);
std::set<OverrideProxy> baseModifiers;
for (auto next = modifiers.upper_bound(*it); it != next; ++it)
baseModifiers.insert(*it);
@ -454,10 +581,7 @@ void OverrideChecker::checkAmbiguousOverrides(ContractDefinition const& _contrac
}
}
void OverrideChecker::checkAmbiguousOverridesInternal(set<
CallableDeclaration const*,
std::function<bool(CallableDeclaration const*, CallableDeclaration const*)>
> _baseCallables, SourceLocation const& _location) const
void OverrideChecker::checkAmbiguousOverridesInternal(set<OverrideProxy> _baseCallables, SourceLocation const& _location) const
{
if (_baseCallables.size() <= 1)
return;
@ -468,13 +592,13 @@ void OverrideChecker::checkAmbiguousOverridesInternal(set<
// connect at the end.
struct OverrideGraph
{
OverrideGraph(decltype(_baseCallables) const& __baseCallables)
OverrideGraph(decltype(_baseCallables) const& _baseCallables)
{
for (auto const* baseFunction: __baseCallables)
for (auto const& baseFunction: _baseCallables)
addEdge(0, visit(baseFunction));
}
std::map<CallableDeclaration const*, int> nodes;
std::map<int, CallableDeclaration const*> nodeInv;
std::map<OverrideProxy, int> nodes;
std::map<int, OverrideProxy> nodeInv;
std::map<int, std::set<int>> edges;
int numNodes = 2;
void addEdge(int _a, int _b)
@ -485,7 +609,7 @@ void OverrideChecker::checkAmbiguousOverridesInternal(set<
private:
/// Completes the graph starting from @a _function and
/// @returns the node ID.
int visit(CallableDeclaration const* _function)
int visit(OverrideProxy const& _function)
{
auto it = nodes.find(_function);
if (it != nodes.end())
@ -493,8 +617,8 @@ void OverrideChecker::checkAmbiguousOverridesInternal(set<
int currentNode = numNodes++;
nodes[_function] = currentNode;
nodeInv[currentNode] = _function;
if (_function->overrides())
for (auto const* baseFunction: _function->annotation().baseFunctions)
if (_function.overrides())
for (auto const& baseFunction: _function.baseFunctions())
addEdge(currentNode, visit(baseFunction));
else
addEdge(currentNode, 1);
@ -511,7 +635,7 @@ void OverrideChecker::checkAmbiguousOverridesInternal(set<
{
run();
}
std::set<CallableDeclaration const*> const& cutVertices() const { return m_cutVertices; }
std::set<OverrideProxy> const& cutVertices() const { return m_cutVertices; }
private:
OverrideGraph const& m_graph;
@ -520,7 +644,7 @@ void OverrideChecker::checkAmbiguousOverridesInternal(set<
std::vector<int> m_depths = std::vector<int>(m_graph.numNodes, -1);
std::vector<int> m_low = std::vector<int>(m_graph.numNodes, -1);
std::vector<int> m_parent = std::vector<int>(m_graph.numNodes, -1);
std::set<CallableDeclaration const*> m_cutVertices{};
std::set<OverrideProxy> m_cutVertices{};
void run(int _u = 0, int _depth = 0)
{
@ -541,21 +665,20 @@ void OverrideChecker::checkAmbiguousOverridesInternal(set<
} cutVertexFinder{overrideGraph};
// Remove all base functions overridden by cut vertices (they don't need to be overridden).
for (auto const* function: cutVertexFinder.cutVertices())
for (OverrideProxy const& function: cutVertexFinder.cutVertices())
{
std::set<CallableDeclaration const*> toTraverse = function->annotation().baseFunctions;
std::set<OverrideProxy> toTraverse = function.baseFunctions();
while (!toTraverse.empty())
{
auto const *base = *toTraverse.begin();
OverrideProxy base = *toTraverse.begin();
toTraverse.erase(toTraverse.begin());
_baseCallables.erase(base);
for (CallableDeclaration const* f: base->annotation().baseFunctions)
for (OverrideProxy const& f: base.baseFunctions())
toTraverse.insert(f);
}
// Remove unimplemented base functions at the cut vertices itself as well.
if (auto opt = dynamic_cast<ImplementationOptional const*>(function))
if (!opt->isImplemented())
_baseCallables.erase(function);
if (function.unimplemented())
_baseCallables.erase(function);
}
// If more than one function is left, they have to be overridden.
@ -563,34 +686,28 @@ void OverrideChecker::checkAmbiguousOverridesInternal(set<
return;
SecondarySourceLocation ssl;
for (auto const* baseFunction: _baseCallables)
{
string contractName = dynamic_cast<ContractDefinition const&>(*baseFunction->scope()).name();
ssl.append("Definition in \"" + contractName + "\": ", baseFunction->location());
}
for (OverrideProxy const& baseFunction: _baseCallables)
ssl.append("Definition in \"" + baseFunction.contractName() + "\": ", baseFunction.location());
string callableName;
string distinguishigProperty;
if (dynamic_cast<FunctionDefinition const*>(*_baseCallables.begin()))
{
callableName = "function";
distinguishigProperty = "name and parameter types";
}
else if (dynamic_cast<ModifierDefinition const*>(*_baseCallables.begin()))
{
callableName = "modifier";
distinguishigProperty = "name";
}
else
solAssert(false, "Invalid type for ambiguous override.");
string callableName = _baseCallables.begin()->isModifier() ? _baseCallables.begin()->astNodeName() : "function";
string distinguishigProperty = _baseCallables.begin()->distinguishingProperty();
m_errorReporter.typeError(
_location,
ssl,
bool foundVariable = false;
for (auto const& base: _baseCallables)
if (base.isVariable())
foundVariable = true;
string message =
"Derived contract must override " + callableName + " \"" +
(*_baseCallables.begin())->name() +
"\". Two or more base classes define " + callableName + " with same " + distinguishigProperty + "."
);
_baseCallables.begin()->name() +
"\". Two or more base classes define " + callableName + " with same " + distinguishigProperty + ".";
if (foundVariable)
message +=
" Since one of the bases defines a public state variable which cannot be overridden, "
"you have to change the inheritance layout or the names of the functions.";
m_errorReporter.typeError(_location, ssl, message);
}
set<ContractDefinition const*, OverrideChecker::LessFunction> OverrideChecker::resolveOverrideList(OverrideSpecifier const& _overrides) const
@ -611,23 +728,19 @@ set<ContractDefinition const*, OverrideChecker::LessFunction> OverrideChecker::r
return resolved;
}
template <class T>
void OverrideChecker::checkOverrideList(
std::multiset<T const*, LessFunction> const& _inheritedCallables,
T const& _callable
)
void OverrideChecker::checkOverrideList(OverrideProxy _item, OverrideProxyBySignatureMultiSet const& _inherited)
{
set<ContractDefinition const*, LessFunction> specifiedContracts =
_callable.overrides() ?
resolveOverrideList(*_callable.overrides()) :
_item.overrides() ?
resolveOverrideList(*_item.overrides()) :
decltype(specifiedContracts){};
// Check for duplicates in override list
if (_callable.overrides() && specifiedContracts.size() != _callable.overrides()->overrides().size())
if (_item.overrides() && specifiedContracts.size() != _item.overrides()->overrides().size())
{
// Sort by contract id to find duplicate for error reporting
vector<ASTPointer<UserDefinedTypeName>> list =
sortByContract(_callable.overrides()->overrides());
sortByContract(_item.overrides()->overrides());
// Find duplicates and output error
for (size_t i = 1; i < list.size(); i++)
@ -647,144 +760,123 @@ void OverrideChecker::checkOverrideList(
"Duplicate contract \"" +
joinHumanReadable(list[i]->namePath(), ".") +
"\" found in override list of \"" +
_callable.name() +
_item.name() +
"\"."
);
}
}
}
decltype(specifiedContracts) expectedContracts;
set<ContractDefinition const*, LessFunction> expectedContracts;
// Build list of expected contracts
for (auto [begin, end] = _inheritedCallables.equal_range(&_callable); begin != end; begin++)
for (auto [begin, end] = _inherited.equal_range(_item); begin != end; begin++)
{
// Validate the override
checkOverride(_callable, **begin);
checkOverride(_item, *begin);
expectedContracts.insert(&dynamic_cast<ContractDefinition const&>(*(*begin)->scope()));
expectedContracts.insert(&begin->contract());
}
decltype(specifiedContracts) missingContracts;
decltype(specifiedContracts) surplusContracts;
if (_item.overrides() && expectedContracts.empty())
m_errorReporter.typeError(
_item.overrides()->location(),
_item.astNodeNameCapitalized() + " has override specified but does not override anything."
);
decltype(specifiedContracts) missingContracts;
// If we expect only one contract, no contract needs to be specified
if (expectedContracts.size() > 1)
missingContracts = expectedContracts - specifiedContracts;
surplusContracts = specifiedContracts - expectedContracts;
if (!missingContracts.empty())
overrideListError(
_callable,
_item,
missingContracts,
"Function needs to specify overridden ",
_item.astNodeNameCapitalized() + " needs to specify overridden ",
""
);
auto surplusContracts = specifiedContracts - expectedContracts;
if (!surplusContracts.empty())
overrideListError(
_callable,
_item,
surplusContracts,
"Invalid ",
"specified in override list: "
);
}
OverrideChecker::FunctionMultiSet const& OverrideChecker::inheritedFunctions(ContractDefinition const& _contract) const
OverrideChecker::OverrideProxyBySignatureMultiSet const& OverrideChecker::inheritedFunctions(ContractDefinition const& _contract) const
{
if (!m_inheritedFunctions.count(&_contract))
{
FunctionMultiSet set;
OverrideProxyBySignatureMultiSet result;
for (auto const* base: resolveDirectBaseContracts(_contract))
{
std::set<FunctionDefinition const*, LessFunction> functionsInBase;
for (FunctionDefinition const* fun: base->definedFunctions())
if (!fun->isConstructor())
functionsInBase.emplace(fun);
for (auto const& func: inheritedFunctions(*base))
functionsInBase.insert(func);
set += functionsInBase;
}
m_inheritedFunctions[&_contract] = set;
}
return m_inheritedFunctions[&_contract];
}
OverrideChecker::ModifierMultiSet const& OverrideChecker::inheritedModifiers(ContractDefinition const& _contract) const
{
auto const& result = m_contractBaseModifiers.find(&_contract);
if (result != m_contractBaseModifiers.cend())
return result->second;
ModifierMultiSet set;
for (auto const* base: resolveDirectBaseContracts(_contract))
{
std::set<ModifierDefinition const*, LessFunction> tmpSet =
convertContainer<decltype(tmpSet)>(base->functionModifiers());
for (auto const& mod: inheritedModifiers(*base))
tmpSet.insert(mod);
set += tmpSet;
}
return m_contractBaseModifiers[&_contract] = set;
}
multiset<OverrideProxy> const& OverrideChecker::inheritedFunctionsByProxy(ContractDefinition const& _contract) const
{
if (!m_inheritedFunctionsByProxy.count(&_contract))
{
multiset<OverrideProxy> result;
for (auto const* base: resolveDirectBaseContracts(_contract))
{
std::set<OverrideProxy> functionsInBase;
set<OverrideProxy, OverrideProxy::CompareBySignature> functionsInBase;
for (FunctionDefinition const* fun: base->definedFunctions())
if (!fun->isConstructor())
functionsInBase.emplace(OverrideProxy{fun});
for (OverrideProxy const& func: inheritedFunctionsByProxy(*base))
for (OverrideProxy const& func: inheritedFunctions(*base))
functionsInBase.insert(func);
result += functionsInBase;
}
m_inheritedFunctionsByProxy[&_contract] = result;
m_inheritedFunctions[&_contract] = result;
}
return m_inheritedFunctionsByProxy[&_contract];
return m_inheritedFunctions[&_contract];
}
multiset<OverrideProxy> const& OverrideChecker::inheritedModifiersByProxy(ContractDefinition const& _contract) const
OverrideChecker::OverrideProxyBySignatureMultiSet const& OverrideChecker::inheritedModifiers(ContractDefinition const& _contract) const
{
if (!m_inheritedModifiersByProxy.count(&_contract))
if (!m_inheritedModifiers.count(&_contract))
{
multiset<OverrideProxy> result;
OverrideProxyBySignatureMultiSet result;
for (auto const* base: resolveDirectBaseContracts(_contract))
{
std::set<OverrideProxy> modifiersInBase;
set<OverrideProxy, OverrideProxy::CompareBySignature> modifiersInBase;
for (ModifierDefinition const* mod: base->functionModifiers())
modifiersInBase.emplace(OverrideProxy{mod});
for (OverrideProxy const& mod: inheritedModifiersByProxy(*base))
for (OverrideProxy const& mod: inheritedModifiers(*base))
modifiersInBase.insert(mod);
result += modifiersInBase;
}
m_inheritedModifiersByProxy[&_contract] = result;
m_inheritedModifiers[&_contract] = result;
}
return m_inheritedModifiersByProxy[&_contract];
return m_inheritedModifiers[&_contract];
}
OverrideChecker::OverrideProxyBySignatureMultiSet const& OverrideChecker::inheritedPublicStateVariables(ContractDefinition const& _contract) const
{
if (!m_inheritedPublicStateVariables.count(&_contract))
{
OverrideProxyBySignatureMultiSet result;
for (auto const* base: resolveDirectBaseContracts(_contract))
{
set<OverrideProxy, OverrideProxy::CompareBySignature> stateVarsInBase;
for (VariableDeclaration const* var: base->stateVariables())
if (var->isPublic())
stateVarsInBase.emplace(OverrideProxy{var});
for (OverrideProxy const& mod: inheritedPublicStateVariables(*base))
stateVarsInBase.insert(mod);
result += stateVarsInBase;
}
m_inheritedPublicStateVariables[&_contract] = result;
}
return m_inheritedPublicStateVariables[&_contract];
}

View File

@ -21,37 +21,107 @@
#pragma once
#include <libsolidity/ast/ASTForward.h>
#include <libsolidity/ast/ASTEnums.h>
#include <liblangutil/SourceLocation.h>
#include <map>
#include <functional>
#include <set>
#include <variant>
#include <optional>
namespace langutil
{
class ErrorReporter;
struct SourceLocation;
}
namespace dev
{
namespace solidity
{
class FunctionType;
class ModifierType;
/**
* Class that represents a function, public state variable or modifier
* and helps with overload checking.
* Comparison results in two elements being equal when they can override each
* Regular comparison is performed based on AST node, while CompareBySignature
* results in two elements being equal when they can override each
* other.
*/
class OverrideProxy
{
public:
OverrideProxy() {}
explicit OverrideProxy(FunctionDefinition const* _fun): m_item{_fun} {}
explicit OverrideProxy(ModifierDefinition const* _mod): m_item{_mod} {}
explicit OverrideProxy(VariableDeclaration const* _var): m_item{_var} {}
bool operator<(OverrideProxy const& _other) const;
struct CompareBySignature
{
bool operator()(OverrideProxy const& _a, OverrideProxy const& _b) const;
};
bool isVariable() const;
bool isFunction() const;
bool isModifier() const;
size_t id() const;
std::shared_ptr<OverrideSpecifier> overrides() const;
std::set<OverrideProxy> baseFunctions() const;
/// This stores the item in the list of base items.
void storeBaseFunction(OverrideProxy const& _base) const;
std::string const& name() const;
ContractDefinition const& contract() const;
std::string const& contractName() const;
Visibility visibility() const;
StateMutability stateMutability() const;
bool virtualSemantics() const;
/// @returns receive / fallback / function (only the latter for modifiers and variables);
langutil::Token functionKind() const;
FunctionType const* functionType() const;
ModifierType const* modifierType() const;
langutil::SourceLocation const& location() const;
std::string astNodeName() const;
std::string astNodeNameCapitalized() const;
std::string distinguishingProperty() const;
/// @returns true if this AST elements supports the feature of being unimplemented
/// and is actually not implemented.
bool unimplemented() const;
/**
* Struct to help comparing override items about whether they override each other.
* Does not produce a total order.
*/
struct OverrideComparator
{
std::string name;
std::optional<langutil::Token> functionKind;
std::optional<std::vector<std::string>> parameterTypes;
bool operator<(OverrideComparator const& _other) const;
};
/// @returns a structure used to compare override items with regards to whether
/// they override each other.
OverrideComparator const& overrideComparator() const;
private:
std::variant<
FunctionDefinition const*,
ModifierDefinition const*,
VariableDeclaration const*
> item;
> m_item;
std::shared_ptr<OverrideComparator> mutable m_comparator;
};
@ -91,10 +161,9 @@ private:
/// different return type, invalid visibility change, etc.
/// Works on functions, modifiers and public state variables.
/// Also stores @a _super as a base function of @a _function in its AST annotations.
template<class T, class U>
void checkOverride(T const& _overriding, U const& _super);
void checkOverride(OverrideProxy const& _overriding, OverrideProxy const& _super);
void overrideListError(
CallableDeclaration const& _callable,
OverrideProxy const& _item,
std::set<ContractDefinition const*, LessFunction> _secondary,
std::string const& _message1,
std::string const& _message2
@ -102,41 +171,38 @@ private:
void overrideError(
Declaration const& _overriding,
Declaration const& _super,
std::string _message,
std::string _secondaryMsg = "Overridden function is here:"
std::string const& _message,
std::string const& _secondaryMsg = "Overridden function is here:"
);
void overrideError(
OverrideProxy const& _overriding,
OverrideProxy const& _super,
std::string const& _message,
std::string const& _secondaryMsg = "Overridden function is here:"
);
/// Checks for functions in different base contracts which conflict with each
/// other and thus need to be overridden explicitly.
void checkAmbiguousOverrides(ContractDefinition const& _contract) const;
void checkAmbiguousOverridesInternal(std::set<
CallableDeclaration const*,
std::function<bool(CallableDeclaration const*, CallableDeclaration const*)>
> _baseCallables, langutil::SourceLocation const& _location) const;
void checkAmbiguousOverridesInternal(std::set<OverrideProxy> _baseCallables, langutil::SourceLocation const& _location) const;
/// Resolves an override list of UserDefinedTypeNames to a list of contracts.
std::set<ContractDefinition const*, LessFunction> resolveOverrideList(OverrideSpecifier const& _overrides) const;
template <class T>
void checkOverrideList(
std::multiset<T const*, LessFunction> const& _funcSet,
T const& _function
);
using OverrideProxyBySignatureMultiSet = std::multiset<OverrideProxy, OverrideProxy::CompareBySignature>;
void checkOverrideList(OverrideProxy _item, OverrideProxyBySignatureMultiSet const& _inherited);
/// Returns all functions of bases that have not yet been overwritten.
/// May contain the same function multiple times when used with shared bases.
FunctionMultiSet const& inheritedFunctions(ContractDefinition const& _contract) const;
ModifierMultiSet const& inheritedModifiers(ContractDefinition const& _contract) const;
std::multiset<OverrideProxy> const& inheritedFunctionsByProxy(ContractDefinition const& _contract) const;
std::multiset<OverrideProxy> const& inheritedModifiersByProxy(ContractDefinition const& _contract) const;
OverrideProxyBySignatureMultiSet const& inheritedFunctions(ContractDefinition const& _contract) const;
OverrideProxyBySignatureMultiSet const& inheritedModifiers(ContractDefinition const& _contract) const;
OverrideProxyBySignatureMultiSet const& inheritedPublicStateVariables(ContractDefinition const& _contract) const;
langutil::ErrorReporter& m_errorReporter;
/// Cache for inheritedFunctions().
std::map<ContractDefinition const*, FunctionMultiSet> mutable m_inheritedFunctions;
std::map<ContractDefinition const*, ModifierMultiSet> mutable m_contractBaseModifiers;
std::map<ContractDefinition const*, std::multiset<OverrideProxy>> mutable m_inheritedFunctionsByProxy;
std::map<ContractDefinition const*, std::multiset<OverrideProxy>> mutable m_inheritedModifiersByProxy;
std::map<ContractDefinition const*, OverrideProxyBySignatureMultiSet> mutable m_inheritedFunctions;
std::map<ContractDefinition const*, OverrideProxyBySignatureMultiSet> mutable m_inheritedModifiers;
std::map<ContractDefinition const*, OverrideProxyBySignatureMultiSet> mutable m_inheritedPublicStateVariables;
};
}

View File

@ -128,6 +128,8 @@ struct VariableDeclarationAnnotation: ASTAnnotation
{
/// Type of variable (type of identifier referencing this variable).
TypePointer type = nullptr;
/// The set of functions this (public state) variable overrides.
std::set<CallableDeclaration const*> baseFunctions;
};
struct StatementAnnotation: ASTAnnotation, DocumentedAnnotation