Generalize cycle detection.

This commit is contained in:
chriseth 2018-03-15 19:53:29 +01:00 committed by Alex Beregszaszi
parent 5bdadff0d8
commit eb5b18e814
4 changed files with 104 additions and 35 deletions

76
libdevcore/Algorithms.h Normal file
View File

@ -0,0 +1,76 @@
/*
This file is part of solidity.
solidity is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
solidity is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with solidity. If not, see <http://www.gnu.org/licenses/>.
*/
#pragma once
#include <functional>
#include <set>
namespace dev
{
/**
* Detector for cycles in directed graphs. It returns the first
* vertex on the path towards a cycle or a nullptr if there is
* no reachable cycle starting from a given vertex.
*/
template <typename V>
class CycleDetector
{
public:
/// Initializes the cycle detector
/// @param _visit function that is given the current vertex
/// and is supposed to call @a run on all
/// adjacent vertices.
explicit CycleDetector(std::function<void(V const&, CycleDetector&)> _visit):
m_visit(std::move(_visit))
{ }
/// Recursively perform cycle detection starting
/// (or continuing) with @param _vertex
/// @returns the first vertex on the path towards a cycle from @a _vertex
/// or nullptr if no cycle is reachable from @a _vertex.
V const* run(V const& _vertex)
{
if (m_firstCycleVertex)
return m_firstCycleVertex;
if (m_processed.count(&_vertex))
return nullptr;
else if (m_processing.count(&_vertex))
return m_firstCycleVertex = &_vertex;
m_processing.insert(&_vertex);
m_depth++;
m_visit(_vertex, *this);
m_depth--;
if (m_firstCycleVertex && m_depth == 1)
m_firstCycleVertex = &_vertex;
m_processing.erase(&_vertex);
m_processed.insert(&_vertex);
return m_firstCycleVertex;
}
private:
std::function<void(V const&, CycleDetector&)> m_visit;
std::set<V const*> m_processing;
std::set<V const*> m_processed;
size_t m_depth = 0;
V const* m_firstCycleVertex = nullptr;
};
}

View File

@ -21,6 +21,8 @@
#include <libsolidity/interface/ErrorReporter.h>
#include <libsolidity/interface/Version.h>
#include <libdevcore/Algorithms.h>
#include <boost/range/adaptor/map.hpp>
#include <memory>
@ -47,7 +49,7 @@ void PostTypeChecker::endVisit(ContractDefinition const&)
{
solAssert(!m_currentConstVariable, "");
for (auto declaration: m_constVariables)
if (auto identifier = findCycle(declaration))
if (auto identifier = findCycle(*declaration))
m_errorReporter.typeError(
declaration->location(),
"The value of the constant " + declaration->name() +
@ -87,20 +89,24 @@ bool PostTypeChecker::visit(Identifier const& _identifier)
return true;
}
VariableDeclaration const* PostTypeChecker::findCycle(
VariableDeclaration const* _startingFrom,
set<VariableDeclaration const*> const& _seen
)
VariableDeclaration const* PostTypeChecker::findCycle(VariableDeclaration const& _startingFrom)
{
if (_seen.count(_startingFrom))
return _startingFrom;
else if (m_constVariableDependencies.count(_startingFrom))
auto visitor = [&](VariableDeclaration const& _variable, CycleDetector<VariableDeclaration>& _cycleDetector)
{
set<VariableDeclaration const*> seen(_seen);
seen.insert(_startingFrom);
for (auto v: m_constVariableDependencies[_startingFrom])
if (findCycle(v, seen))
return v;
}
return nullptr;
// Iterating through the dependencies needs to be deterministic and thus cannot
// depend on the memory layout.
// Because of that, we sort by AST node id.
vector<VariableDeclaration const*> dependencies(
m_constVariableDependencies[&_variable].begin(),
m_constVariableDependencies[&_variable].end()
);
sort(dependencies.begin(), dependencies.end(), [](VariableDeclaration const* _a, VariableDeclaration const* _b) -> bool
{
return _a->id() < _b->id();
});
for (auto v: dependencies)
if (_cycleDetector.run(*v))
return;
};
return CycleDetector<VariableDeclaration>(visitor).run(_startingFrom);
}

View File

@ -55,10 +55,7 @@ private:
virtual bool visit(Identifier const& _identifier) override;
VariableDeclaration const* findCycle(
VariableDeclaration const* _startingFrom,
std::set<VariableDeclaration const*> const& _seen = std::set<VariableDeclaration const*>{}
);
VariableDeclaration const* findCycle(VariableDeclaration const& _startingFrom);
ErrorReporter& m_errorReporter;

View File

@ -28,6 +28,7 @@
#include <libdevcore/CommonData.h>
#include <libdevcore/SHA3.h>
#include <libdevcore/UTF8.h>
#include <libdevcore/Algorithms.h>
#include <boost/algorithm/string/join.hpp>
#include <boost/algorithm/string/replace.hpp>
@ -1971,30 +1972,19 @@ bool StructType::recursive() const
{
if (!m_recursive.is_initialized())
{
set<StructDefinition const*> structsSeen;
set<StructDefinition const*> structsProcessed;
function<bool(StructType const*)> check = [&](StructType const* t) -> bool
auto visitor = [&](StructDefinition const& _struct, CycleDetector<StructDefinition>& _cycleDetector)
{
StructDefinition const* str = &t->structDefinition();
if (structsProcessed.count(str))
return false;
if (structsSeen.count(str))
return true;
structsSeen.insert(str);
for (ASTPointer<VariableDeclaration> const& variable: str->members())
for (ASTPointer<VariableDeclaration> const& variable: _struct.members())
{
Type const* memberType = variable->annotation().type.get();
while (dynamic_cast<ArrayType const*>(memberType))
memberType = dynamic_cast<ArrayType const*>(memberType)->baseType().get();
if (StructType const* innerStruct = dynamic_cast<StructType const*>(memberType))
if (check(innerStruct))
return true;
if (_cycleDetector.run(innerStruct->structDefinition()))
return;
}
structsSeen.erase(str);
structsProcessed.insert(str);
return false;
};
m_recursive = check(this);
m_recursive = (CycleDetector<StructDefinition>(visitor).run(structDefinition()) != nullptr);
}
return *m_recursive;
}