Merge pull request #830 from chriseth/sol_overridesInConstructorContext

Include virtual function overrides in constructor context.
This commit is contained in:
Gav Wood 2015-01-20 10:02:18 -08:00
commit 30b455e4d6
9 changed files with 61 additions and 25 deletions

View File

@ -82,7 +82,7 @@ map<FixedHash<4>, FunctionDefinition const*> ContractDefinition::getInterfaceFun
FunctionDefinition const* ContractDefinition::getConstructor() const
{
for (ASTPointer<FunctionDefinition> const& f: m_definedFunctions)
if (f->getName() == getName())
if (f->isConstructor())
return f.get();
return nullptr;
}
@ -95,7 +95,7 @@ void ContractDefinition::checkIllegalOverrides() const
for (ContractDefinition const* contract: getLinearizedBaseContracts())
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
{
if (function->getName() == contract->getName())
if (function->isConstructor())
continue; // constructors can neither be overriden nor override anything
FunctionDefinition const*& override = functions[function->getName()];
if (!override)
@ -115,8 +115,7 @@ vector<pair<FixedHash<4>, FunctionDefinition const*>> const& ContractDefinition:
m_interfaceFunctionList.reset(new vector<pair<FixedHash<4>, FunctionDefinition const*>>());
for (ContractDefinition const* contract: getLinearizedBaseContracts())
for (ASTPointer<FunctionDefinition> const& f: contract->getDefinedFunctions())
if (f->isPublic() && f->getName() != contract->getName() &&
functionsSeen.count(f->getName()) == 0)
if (f->isPublic() && !f->isConstructor() && functionsSeen.count(f->getName()) == 0)
{
functionsSeen.insert(f->getName());
FixedHash<4> hash(dev::sha3(f->getCanonicalSignature()));

5
AST.h
View File

@ -281,12 +281,13 @@ class FunctionDefinition: public Declaration
public:
FunctionDefinition(Location const& _location, ASTPointer<ASTString> const& _name,
bool _isPublic,
bool _isConstructor,
ASTPointer<ASTString> const& _documentation,
ASTPointer<ParameterList> const& _parameters,
bool _isDeclaredConst,
ASTPointer<ParameterList> const& _returnParameters,
ASTPointer<Block> const& _body):
Declaration(_location, _name), m_isPublic(_isPublic),
Declaration(_location, _name), m_isPublic(_isPublic), m_isConstructor(_isConstructor),
m_parameters(_parameters),
m_isDeclaredConst(_isDeclaredConst),
m_returnParameters(_returnParameters),
@ -298,6 +299,7 @@ public:
virtual void accept(ASTConstVisitor& _visitor) const override;
bool isPublic() const { return m_isPublic; }
bool isConstructor() const { return m_isConstructor; }
bool isDeclaredConst() const { return m_isDeclaredConst; }
std::vector<ASTPointer<VariableDeclaration>> const& getParameters() const { return m_parameters->getParameters(); }
ParameterList const& getParameterList() const { return *m_parameters; }
@ -321,6 +323,7 @@ public:
private:
bool m_isPublic;
bool m_isConstructor;
ASTPointer<ParameterList> m_parameters;
bool m_isDeclaredConst;
ASTPointer<ParameterList> m_returnParameters;

View File

@ -38,6 +38,7 @@ void CallGraph::addNode(ASTNode const& _node)
set<FunctionDefinition const*> const& CallGraph::getCalls()
{
computeCallGraph();
return m_functionsSeen;
}
@ -45,8 +46,7 @@ void CallGraph::computeCallGraph()
{
while (!m_workQueue.empty())
{
FunctionDefinition const* fun = m_workQueue.front();
fun->accept(*this);
m_workQueue.front()->accept(*this);
m_workQueue.pop();
}
}
@ -55,7 +55,12 @@ bool CallGraph::visit(Identifier const& _identifier)
{
FunctionDefinition const* fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration());
if (fun)
{
if (m_overrideResolver)
fun = (*m_overrideResolver)(fun->getName());
solAssert(fun, "Error finding override for function " + fun->getName());
addFunction(*fun);
}
return true;
}

View File

@ -22,6 +22,7 @@
#include <set>
#include <queue>
#include <functional>
#include <boost/range/iterator_range.hpp>
#include <libsolidity/ASTVisitor.h>
@ -38,8 +39,11 @@ namespace solidity
class CallGraph: private ASTConstVisitor
{
public:
using OverrideResolver = std::function<FunctionDefinition const*(std::string const&)>;
CallGraph(OverrideResolver const& _overrideResolver): m_overrideResolver(&_overrideResolver) {}
void addNode(ASTNode const& _node);
void computeCallGraph();
std::set<FunctionDefinition const*> const& getCalls();
@ -48,8 +52,10 @@ private:
virtual bool visit(Identifier const& _identifier) override;
virtual bool visit(MemberAccess const& _memberAccess) override;
void computeCallGraph();
void addFunction(FunctionDefinition const& _function);
OverrideResolver const* m_overrideResolver;
std::set<FunctionDefinition const*> m_functionsSeen;
std::queue<FunctionDefinition const*> m_workQueue;
};

View File

@ -43,13 +43,13 @@ void Compiler::compileContract(ContractDefinition const& _contract,
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
if (function->getName() != contract->getName()) // don't add the constructor here
if (!function->isConstructor())
m_context.addFunction(*function);
appendFunctionSelector(_contract);
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
if (function->getName() != contract->getName()) // don't add the constructor here
if (!function->isConstructor())
function->accept(*this);
// Swap the runtime context with the creation-time context
@ -93,10 +93,29 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp
}
}
//@TODO add virtual functions
neededFunctions = getFunctionsCalled(nodesUsedInConstructors);
auto overrideResolver = [&](string const& _name) -> FunctionDefinition const*
{
for (ContractDefinition const* contract: bases)
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
if (!function->isConstructor() && function->getName() == _name)
return function.get();
return nullptr;
};
neededFunctions = getFunctionsCalled(nodesUsedInConstructors, overrideResolver);
// First add all overrides (or the functions themselves if there is no override)
for (FunctionDefinition const* fun: neededFunctions)
{
FunctionDefinition const* override = nullptr;
if (!fun->isConstructor())
override = overrideResolver(fun->getName());
if (!!override && neededFunctions.count(override))
m_context.addFunction(*override);
}
// now add the rest
for (FunctionDefinition const* fun: neededFunctions)
if (fun->isConstructor() || overrideResolver(fun->getName()) != fun)
m_context.addFunction(*fun);
// Call constructors in base-to-derived order.
@ -159,10 +178,10 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor)
m_context << returnTag;
}
set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes)
set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes,
function<FunctionDefinition const*(string const&)> const& _resolveOverrides)
{
// TODO this does not add virtual functions
CallGraph callgraph;
CallGraph callgraph(_resolveOverrides);
for (ASTNode const* node: _nodes)
callgraph.addNode(*node);
return callgraph.getCalls();

View File

@ -21,6 +21,7 @@
*/
#include <ostream>
#include <functional>
#include <libsolidity/ASTVisitor.h>
#include <libsolidity/CompilerContext.h>
@ -49,7 +50,9 @@ private:
std::vector<ASTPointer<Expression>> const& _arguments);
void appendConstructorCall(FunctionDefinition const& _constructor);
/// Recursively searches the call graph and returns all functions referenced inside _nodes.
std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes);
/// _resolveOverride is called to resolve virtual function overrides.
std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes,
std::function<FunctionDefinition const*(std::string const&)> const& _resolveOverride);
void appendFunctionSelector(ContractDefinition const& _contract);
/// Creates code that unpacks the arguments for the given function, from memory if
/// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes.

View File

@ -112,9 +112,9 @@ ASTPointer<ImportDirective> Parser::parseImportDirective()
ASTPointer<ContractDefinition> Parser::parseContractDefinition()
{
ASTNodeFactory nodeFactory(*this);
ASTPointer<ASTString> docstring;
ASTPointer<ASTString> docString;
if (m_scanner->getCurrentCommentLiteral() != "")
docstring = make_shared<ASTString>(m_scanner->getCurrentCommentLiteral());
docString = make_shared<ASTString>(m_scanner->getCurrentCommentLiteral());
expectToken(Token::CONTRACT);
ASTPointer<ASTString> name = expectIdentifierToken();
vector<ASTPointer<InheritanceSpecifier>> baseContracts;
@ -142,7 +142,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition()
expectToken(Token::COLON);
}
else if (currentToken == Token::FUNCTION)
functions.push_back(parseFunctionDefinition(visibilityIsPublic));
functions.push_back(parseFunctionDefinition(visibilityIsPublic, name.get()));
else if (currentToken == Token::STRUCT)
structs.push_back(parseStructDefinition());
else if (currentToken == Token::IDENTIFIER || currentToken == Token::MAPPING ||
@ -157,7 +157,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition()
}
nodeFactory.markEndPosition();
expectToken(Token::RBRACE);
return nodeFactory.createNode<ContractDefinition>(name, docstring, baseContracts, structs,
return nodeFactory.createNode<ContractDefinition>(name, docString, baseContracts, structs,
stateVariables, functions);
}
@ -178,7 +178,7 @@ ASTPointer<InheritanceSpecifier> Parser::parseInheritanceSpecifier()
return nodeFactory.createNode<InheritanceSpecifier>(name, arguments);
}
ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic)
ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic, ASTString const* _contractName)
{
ASTNodeFactory nodeFactory(*this);
ASTPointer<ASTString> docstring;
@ -210,7 +210,8 @@ ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic)
}
ASTPointer<Block> block = parseBlock();
nodeFactory.setEndPositionFromNode(block);
return nodeFactory.createNode<FunctionDefinition>(name, _isPublic, docstring,
bool const c_isConstructor = (_contractName && *name == *_contractName);
return nodeFactory.createNode<FunctionDefinition>(name, _isPublic, c_isConstructor, docstring,
parameters,
isDeclaredConst, returnParameters, block);
}

View File

@ -50,7 +50,7 @@ private:
ASTPointer<ImportDirective> parseImportDirective();
ASTPointer<ContractDefinition> parseContractDefinition();
ASTPointer<InheritanceSpecifier> parseInheritanceSpecifier();
ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic);
ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic, ASTString const* _contractName);
ASTPointer<StructDefinition> parseStructDefinition();
ASTPointer<VariableDeclaration> parseVariableDeclaration(bool _allowVar);
ASTPointer<TypeName> parseTypeName(bool _allowVar);

View File

@ -716,7 +716,7 @@ MemberList const& TypeType::getMembers() const
// We are accessing the type of a base contract, so add all public and private
// functions. Note that this does not add inherited functions on purpose.
for (ASTPointer<FunctionDefinition> const& f: contract.getDefinedFunctions())
if (f->getName() != contract.getName())
if (!f->isConstructor())
members[f->getName()] = make_shared<FunctionType>(*f);
}
m_members.reset(new MemberList(members));