Merge pull request #830 from chriseth/sol_overridesInConstructorContext

Include virtual function overrides in constructor context.
This commit is contained in:
Gav Wood 2015-01-20 10:02:18 -08:00
commit 30b455e4d6
9 changed files with 61 additions and 25 deletions

View File

@ -82,7 +82,7 @@ map<FixedHash<4>, FunctionDefinition const*> ContractDefinition::getInterfaceFun
FunctionDefinition const* ContractDefinition::getConstructor() const FunctionDefinition const* ContractDefinition::getConstructor() const
{ {
for (ASTPointer<FunctionDefinition> const& f: m_definedFunctions) for (ASTPointer<FunctionDefinition> const& f: m_definedFunctions)
if (f->getName() == getName()) if (f->isConstructor())
return f.get(); return f.get();
return nullptr; return nullptr;
} }
@ -95,7 +95,7 @@ void ContractDefinition::checkIllegalOverrides() const
for (ContractDefinition const* contract: getLinearizedBaseContracts()) for (ContractDefinition const* contract: getLinearizedBaseContracts())
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
{ {
if (function->getName() == contract->getName()) if (function->isConstructor())
continue; // constructors can neither be overriden nor override anything continue; // constructors can neither be overriden nor override anything
FunctionDefinition const*& override = functions[function->getName()]; FunctionDefinition const*& override = functions[function->getName()];
if (!override) if (!override)
@ -115,8 +115,7 @@ vector<pair<FixedHash<4>, FunctionDefinition const*>> const& ContractDefinition:
m_interfaceFunctionList.reset(new vector<pair<FixedHash<4>, FunctionDefinition const*>>()); m_interfaceFunctionList.reset(new vector<pair<FixedHash<4>, FunctionDefinition const*>>());
for (ContractDefinition const* contract: getLinearizedBaseContracts()) for (ContractDefinition const* contract: getLinearizedBaseContracts())
for (ASTPointer<FunctionDefinition> const& f: contract->getDefinedFunctions()) for (ASTPointer<FunctionDefinition> const& f: contract->getDefinedFunctions())
if (f->isPublic() && f->getName() != contract->getName() && if (f->isPublic() && !f->isConstructor() && functionsSeen.count(f->getName()) == 0)
functionsSeen.count(f->getName()) == 0)
{ {
functionsSeen.insert(f->getName()); functionsSeen.insert(f->getName());
FixedHash<4> hash(dev::sha3(f->getCanonicalSignature())); FixedHash<4> hash(dev::sha3(f->getCanonicalSignature()));

5
AST.h
View File

@ -281,12 +281,13 @@ class FunctionDefinition: public Declaration
public: public:
FunctionDefinition(Location const& _location, ASTPointer<ASTString> const& _name, FunctionDefinition(Location const& _location, ASTPointer<ASTString> const& _name,
bool _isPublic, bool _isPublic,
bool _isConstructor,
ASTPointer<ASTString> const& _documentation, ASTPointer<ASTString> const& _documentation,
ASTPointer<ParameterList> const& _parameters, ASTPointer<ParameterList> const& _parameters,
bool _isDeclaredConst, bool _isDeclaredConst,
ASTPointer<ParameterList> const& _returnParameters, ASTPointer<ParameterList> const& _returnParameters,
ASTPointer<Block> const& _body): ASTPointer<Block> const& _body):
Declaration(_location, _name), m_isPublic(_isPublic), Declaration(_location, _name), m_isPublic(_isPublic), m_isConstructor(_isConstructor),
m_parameters(_parameters), m_parameters(_parameters),
m_isDeclaredConst(_isDeclaredConst), m_isDeclaredConst(_isDeclaredConst),
m_returnParameters(_returnParameters), m_returnParameters(_returnParameters),
@ -298,6 +299,7 @@ public:
virtual void accept(ASTConstVisitor& _visitor) const override; virtual void accept(ASTConstVisitor& _visitor) const override;
bool isPublic() const { return m_isPublic; } bool isPublic() const { return m_isPublic; }
bool isConstructor() const { return m_isConstructor; }
bool isDeclaredConst() const { return m_isDeclaredConst; } bool isDeclaredConst() const { return m_isDeclaredConst; }
std::vector<ASTPointer<VariableDeclaration>> const& getParameters() const { return m_parameters->getParameters(); } std::vector<ASTPointer<VariableDeclaration>> const& getParameters() const { return m_parameters->getParameters(); }
ParameterList const& getParameterList() const { return *m_parameters; } ParameterList const& getParameterList() const { return *m_parameters; }
@ -321,6 +323,7 @@ public:
private: private:
bool m_isPublic; bool m_isPublic;
bool m_isConstructor;
ASTPointer<ParameterList> m_parameters; ASTPointer<ParameterList> m_parameters;
bool m_isDeclaredConst; bool m_isDeclaredConst;
ASTPointer<ParameterList> m_returnParameters; ASTPointer<ParameterList> m_returnParameters;

View File

@ -38,6 +38,7 @@ void CallGraph::addNode(ASTNode const& _node)
set<FunctionDefinition const*> const& CallGraph::getCalls() set<FunctionDefinition const*> const& CallGraph::getCalls()
{ {
computeCallGraph();
return m_functionsSeen; return m_functionsSeen;
} }
@ -45,8 +46,7 @@ void CallGraph::computeCallGraph()
{ {
while (!m_workQueue.empty()) while (!m_workQueue.empty())
{ {
FunctionDefinition const* fun = m_workQueue.front(); m_workQueue.front()->accept(*this);
fun->accept(*this);
m_workQueue.pop(); m_workQueue.pop();
} }
} }
@ -55,7 +55,12 @@ bool CallGraph::visit(Identifier const& _identifier)
{ {
FunctionDefinition const* fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration()); FunctionDefinition const* fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration());
if (fun) if (fun)
{
if (m_overrideResolver)
fun = (*m_overrideResolver)(fun->getName());
solAssert(fun, "Error finding override for function " + fun->getName());
addFunction(*fun); addFunction(*fun);
}
return true; return true;
} }

View File

@ -22,6 +22,7 @@
#include <set> #include <set>
#include <queue> #include <queue>
#include <functional>
#include <boost/range/iterator_range.hpp> #include <boost/range/iterator_range.hpp>
#include <libsolidity/ASTVisitor.h> #include <libsolidity/ASTVisitor.h>
@ -38,8 +39,11 @@ namespace solidity
class CallGraph: private ASTConstVisitor class CallGraph: private ASTConstVisitor
{ {
public: public:
using OverrideResolver = std::function<FunctionDefinition const*(std::string const&)>;
CallGraph(OverrideResolver const& _overrideResolver): m_overrideResolver(&_overrideResolver) {}
void addNode(ASTNode const& _node); void addNode(ASTNode const& _node);
void computeCallGraph();
std::set<FunctionDefinition const*> const& getCalls(); std::set<FunctionDefinition const*> const& getCalls();
@ -48,8 +52,10 @@ private:
virtual bool visit(Identifier const& _identifier) override; virtual bool visit(Identifier const& _identifier) override;
virtual bool visit(MemberAccess const& _memberAccess) override; virtual bool visit(MemberAccess const& _memberAccess) override;
void computeCallGraph();
void addFunction(FunctionDefinition const& _function); void addFunction(FunctionDefinition const& _function);
OverrideResolver const* m_overrideResolver;
std::set<FunctionDefinition const*> m_functionsSeen; std::set<FunctionDefinition const*> m_functionsSeen;
std::queue<FunctionDefinition const*> m_workQueue; std::queue<FunctionDefinition const*> m_workQueue;
}; };

View File

@ -43,13 +43,13 @@ void Compiler::compileContract(ContractDefinition const& _contract,
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
if (function->getName() != contract->getName()) // don't add the constructor here if (!function->isConstructor())
m_context.addFunction(*function); m_context.addFunction(*function);
appendFunctionSelector(_contract); appendFunctionSelector(_contract);
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
if (function->getName() != contract->getName()) // don't add the constructor here if (!function->isConstructor())
function->accept(*this); function->accept(*this);
// Swap the runtime context with the creation-time context // Swap the runtime context with the creation-time context
@ -93,10 +93,29 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp
} }
} }
//@TODO add virtual functions auto overrideResolver = [&](string const& _name) -> FunctionDefinition const*
neededFunctions = getFunctionsCalled(nodesUsedInConstructors); {
for (ContractDefinition const* contract: bases)
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
if (!function->isConstructor() && function->getName() == _name)
return function.get();
return nullptr;
};
neededFunctions = getFunctionsCalled(nodesUsedInConstructors, overrideResolver);
// First add all overrides (or the functions themselves if there is no override)
for (FunctionDefinition const* fun: neededFunctions) for (FunctionDefinition const* fun: neededFunctions)
{
FunctionDefinition const* override = nullptr;
if (!fun->isConstructor())
override = overrideResolver(fun->getName());
if (!!override && neededFunctions.count(override))
m_context.addFunction(*override);
}
// now add the rest
for (FunctionDefinition const* fun: neededFunctions)
if (fun->isConstructor() || overrideResolver(fun->getName()) != fun)
m_context.addFunction(*fun); m_context.addFunction(*fun);
// Call constructors in base-to-derived order. // Call constructors in base-to-derived order.
@ -159,10 +178,10 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor)
m_context << returnTag; m_context << returnTag;
} }
set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes) set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes,
function<FunctionDefinition const*(string const&)> const& _resolveOverrides)
{ {
// TODO this does not add virtual functions CallGraph callgraph(_resolveOverrides);
CallGraph callgraph;
for (ASTNode const* node: _nodes) for (ASTNode const* node: _nodes)
callgraph.addNode(*node); callgraph.addNode(*node);
return callgraph.getCalls(); return callgraph.getCalls();

View File

@ -21,6 +21,7 @@
*/ */
#include <ostream> #include <ostream>
#include <functional>
#include <libsolidity/ASTVisitor.h> #include <libsolidity/ASTVisitor.h>
#include <libsolidity/CompilerContext.h> #include <libsolidity/CompilerContext.h>
@ -49,7 +50,9 @@ private:
std::vector<ASTPointer<Expression>> const& _arguments); std::vector<ASTPointer<Expression>> const& _arguments);
void appendConstructorCall(FunctionDefinition const& _constructor); void appendConstructorCall(FunctionDefinition const& _constructor);
/// Recursively searches the call graph and returns all functions referenced inside _nodes. /// Recursively searches the call graph and returns all functions referenced inside _nodes.
std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes); /// _resolveOverride is called to resolve virtual function overrides.
std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes,
std::function<FunctionDefinition const*(std::string const&)> const& _resolveOverride);
void appendFunctionSelector(ContractDefinition const& _contract); void appendFunctionSelector(ContractDefinition const& _contract);
/// Creates code that unpacks the arguments for the given function, from memory if /// Creates code that unpacks the arguments for the given function, from memory if
/// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes. /// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes.

View File

@ -112,9 +112,9 @@ ASTPointer<ImportDirective> Parser::parseImportDirective()
ASTPointer<ContractDefinition> Parser::parseContractDefinition() ASTPointer<ContractDefinition> Parser::parseContractDefinition()
{ {
ASTNodeFactory nodeFactory(*this); ASTNodeFactory nodeFactory(*this);
ASTPointer<ASTString> docstring; ASTPointer<ASTString> docString;
if (m_scanner->getCurrentCommentLiteral() != "") if (m_scanner->getCurrentCommentLiteral() != "")
docstring = make_shared<ASTString>(m_scanner->getCurrentCommentLiteral()); docString = make_shared<ASTString>(m_scanner->getCurrentCommentLiteral());
expectToken(Token::CONTRACT); expectToken(Token::CONTRACT);
ASTPointer<ASTString> name = expectIdentifierToken(); ASTPointer<ASTString> name = expectIdentifierToken();
vector<ASTPointer<InheritanceSpecifier>> baseContracts; vector<ASTPointer<InheritanceSpecifier>> baseContracts;
@ -142,7 +142,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition()
expectToken(Token::COLON); expectToken(Token::COLON);
} }
else if (currentToken == Token::FUNCTION) else if (currentToken == Token::FUNCTION)
functions.push_back(parseFunctionDefinition(visibilityIsPublic)); functions.push_back(parseFunctionDefinition(visibilityIsPublic, name.get()));
else if (currentToken == Token::STRUCT) else if (currentToken == Token::STRUCT)
structs.push_back(parseStructDefinition()); structs.push_back(parseStructDefinition());
else if (currentToken == Token::IDENTIFIER || currentToken == Token::MAPPING || else if (currentToken == Token::IDENTIFIER || currentToken == Token::MAPPING ||
@ -157,7 +157,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition()
} }
nodeFactory.markEndPosition(); nodeFactory.markEndPosition();
expectToken(Token::RBRACE); expectToken(Token::RBRACE);
return nodeFactory.createNode<ContractDefinition>(name, docstring, baseContracts, structs, return nodeFactory.createNode<ContractDefinition>(name, docString, baseContracts, structs,
stateVariables, functions); stateVariables, functions);
} }
@ -178,7 +178,7 @@ ASTPointer<InheritanceSpecifier> Parser::parseInheritanceSpecifier()
return nodeFactory.createNode<InheritanceSpecifier>(name, arguments); return nodeFactory.createNode<InheritanceSpecifier>(name, arguments);
} }
ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic) ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic, ASTString const* _contractName)
{ {
ASTNodeFactory nodeFactory(*this); ASTNodeFactory nodeFactory(*this);
ASTPointer<ASTString> docstring; ASTPointer<ASTString> docstring;
@ -210,7 +210,8 @@ ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic)
} }
ASTPointer<Block> block = parseBlock(); ASTPointer<Block> block = parseBlock();
nodeFactory.setEndPositionFromNode(block); nodeFactory.setEndPositionFromNode(block);
return nodeFactory.createNode<FunctionDefinition>(name, _isPublic, docstring, bool const c_isConstructor = (_contractName && *name == *_contractName);
return nodeFactory.createNode<FunctionDefinition>(name, _isPublic, c_isConstructor, docstring,
parameters, parameters,
isDeclaredConst, returnParameters, block); isDeclaredConst, returnParameters, block);
} }

View File

@ -50,7 +50,7 @@ private:
ASTPointer<ImportDirective> parseImportDirective(); ASTPointer<ImportDirective> parseImportDirective();
ASTPointer<ContractDefinition> parseContractDefinition(); ASTPointer<ContractDefinition> parseContractDefinition();
ASTPointer<InheritanceSpecifier> parseInheritanceSpecifier(); ASTPointer<InheritanceSpecifier> parseInheritanceSpecifier();
ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic); ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic, ASTString const* _contractName);
ASTPointer<StructDefinition> parseStructDefinition(); ASTPointer<StructDefinition> parseStructDefinition();
ASTPointer<VariableDeclaration> parseVariableDeclaration(bool _allowVar); ASTPointer<VariableDeclaration> parseVariableDeclaration(bool _allowVar);
ASTPointer<TypeName> parseTypeName(bool _allowVar); ASTPointer<TypeName> parseTypeName(bool _allowVar);

View File

@ -716,7 +716,7 @@ MemberList const& TypeType::getMembers() const
// We are accessing the type of a base contract, so add all public and private // We are accessing the type of a base contract, so add all public and private
// functions. Note that this does not add inherited functions on purpose. // functions. Note that this does not add inherited functions on purpose.
for (ASTPointer<FunctionDefinition> const& f: contract.getDefinedFunctions()) for (ASTPointer<FunctionDefinition> const& f: contract.getDefinedFunctions())
if (f->getName() != contract.getName()) if (!f->isConstructor())
members[f->getName()] = make_shared<FunctionType>(*f); members[f->getName()] = make_shared<FunctionType>(*f);
} }
m_members.reset(new MemberList(members)); m_members.reset(new MemberList(members));