mirror of
https://github.com/ethereum/solidity
synced 2023-10-03 13:03:40 +00:00
Merge pull request #830 from chriseth/sol_overridesInConstructorContext
Include virtual function overrides in constructor context.
This commit is contained in:
commit
30b455e4d6
7
AST.cpp
7
AST.cpp
@ -82,7 +82,7 @@ map<FixedHash<4>, FunctionDefinition const*> ContractDefinition::getInterfaceFun
|
|||||||
FunctionDefinition const* ContractDefinition::getConstructor() const
|
FunctionDefinition const* ContractDefinition::getConstructor() const
|
||||||
{
|
{
|
||||||
for (ASTPointer<FunctionDefinition> const& f: m_definedFunctions)
|
for (ASTPointer<FunctionDefinition> const& f: m_definedFunctions)
|
||||||
if (f->getName() == getName())
|
if (f->isConstructor())
|
||||||
return f.get();
|
return f.get();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@ -95,7 +95,7 @@ void ContractDefinition::checkIllegalOverrides() const
|
|||||||
for (ContractDefinition const* contract: getLinearizedBaseContracts())
|
for (ContractDefinition const* contract: getLinearizedBaseContracts())
|
||||||
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
|
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
|
||||||
{
|
{
|
||||||
if (function->getName() == contract->getName())
|
if (function->isConstructor())
|
||||||
continue; // constructors can neither be overriden nor override anything
|
continue; // constructors can neither be overriden nor override anything
|
||||||
FunctionDefinition const*& override = functions[function->getName()];
|
FunctionDefinition const*& override = functions[function->getName()];
|
||||||
if (!override)
|
if (!override)
|
||||||
@ -115,8 +115,7 @@ vector<pair<FixedHash<4>, FunctionDefinition const*>> const& ContractDefinition:
|
|||||||
m_interfaceFunctionList.reset(new vector<pair<FixedHash<4>, FunctionDefinition const*>>());
|
m_interfaceFunctionList.reset(new vector<pair<FixedHash<4>, FunctionDefinition const*>>());
|
||||||
for (ContractDefinition const* contract: getLinearizedBaseContracts())
|
for (ContractDefinition const* contract: getLinearizedBaseContracts())
|
||||||
for (ASTPointer<FunctionDefinition> const& f: contract->getDefinedFunctions())
|
for (ASTPointer<FunctionDefinition> const& f: contract->getDefinedFunctions())
|
||||||
if (f->isPublic() && f->getName() != contract->getName() &&
|
if (f->isPublic() && !f->isConstructor() && functionsSeen.count(f->getName()) == 0)
|
||||||
functionsSeen.count(f->getName()) == 0)
|
|
||||||
{
|
{
|
||||||
functionsSeen.insert(f->getName());
|
functionsSeen.insert(f->getName());
|
||||||
FixedHash<4> hash(dev::sha3(f->getCanonicalSignature()));
|
FixedHash<4> hash(dev::sha3(f->getCanonicalSignature()));
|
||||||
|
5
AST.h
5
AST.h
@ -281,12 +281,13 @@ class FunctionDefinition: public Declaration
|
|||||||
public:
|
public:
|
||||||
FunctionDefinition(Location const& _location, ASTPointer<ASTString> const& _name,
|
FunctionDefinition(Location const& _location, ASTPointer<ASTString> const& _name,
|
||||||
bool _isPublic,
|
bool _isPublic,
|
||||||
|
bool _isConstructor,
|
||||||
ASTPointer<ASTString> const& _documentation,
|
ASTPointer<ASTString> const& _documentation,
|
||||||
ASTPointer<ParameterList> const& _parameters,
|
ASTPointer<ParameterList> const& _parameters,
|
||||||
bool _isDeclaredConst,
|
bool _isDeclaredConst,
|
||||||
ASTPointer<ParameterList> const& _returnParameters,
|
ASTPointer<ParameterList> const& _returnParameters,
|
||||||
ASTPointer<Block> const& _body):
|
ASTPointer<Block> const& _body):
|
||||||
Declaration(_location, _name), m_isPublic(_isPublic),
|
Declaration(_location, _name), m_isPublic(_isPublic), m_isConstructor(_isConstructor),
|
||||||
m_parameters(_parameters),
|
m_parameters(_parameters),
|
||||||
m_isDeclaredConst(_isDeclaredConst),
|
m_isDeclaredConst(_isDeclaredConst),
|
||||||
m_returnParameters(_returnParameters),
|
m_returnParameters(_returnParameters),
|
||||||
@ -298,6 +299,7 @@ public:
|
|||||||
virtual void accept(ASTConstVisitor& _visitor) const override;
|
virtual void accept(ASTConstVisitor& _visitor) const override;
|
||||||
|
|
||||||
bool isPublic() const { return m_isPublic; }
|
bool isPublic() const { return m_isPublic; }
|
||||||
|
bool isConstructor() const { return m_isConstructor; }
|
||||||
bool isDeclaredConst() const { return m_isDeclaredConst; }
|
bool isDeclaredConst() const { return m_isDeclaredConst; }
|
||||||
std::vector<ASTPointer<VariableDeclaration>> const& getParameters() const { return m_parameters->getParameters(); }
|
std::vector<ASTPointer<VariableDeclaration>> const& getParameters() const { return m_parameters->getParameters(); }
|
||||||
ParameterList const& getParameterList() const { return *m_parameters; }
|
ParameterList const& getParameterList() const { return *m_parameters; }
|
||||||
@ -321,6 +323,7 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
bool m_isPublic;
|
bool m_isPublic;
|
||||||
|
bool m_isConstructor;
|
||||||
ASTPointer<ParameterList> m_parameters;
|
ASTPointer<ParameterList> m_parameters;
|
||||||
bool m_isDeclaredConst;
|
bool m_isDeclaredConst;
|
||||||
ASTPointer<ParameterList> m_returnParameters;
|
ASTPointer<ParameterList> m_returnParameters;
|
||||||
|
@ -38,6 +38,7 @@ void CallGraph::addNode(ASTNode const& _node)
|
|||||||
|
|
||||||
set<FunctionDefinition const*> const& CallGraph::getCalls()
|
set<FunctionDefinition const*> const& CallGraph::getCalls()
|
||||||
{
|
{
|
||||||
|
computeCallGraph();
|
||||||
return m_functionsSeen;
|
return m_functionsSeen;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,8 +46,7 @@ void CallGraph::computeCallGraph()
|
|||||||
{
|
{
|
||||||
while (!m_workQueue.empty())
|
while (!m_workQueue.empty())
|
||||||
{
|
{
|
||||||
FunctionDefinition const* fun = m_workQueue.front();
|
m_workQueue.front()->accept(*this);
|
||||||
fun->accept(*this);
|
|
||||||
m_workQueue.pop();
|
m_workQueue.pop();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -55,7 +55,12 @@ bool CallGraph::visit(Identifier const& _identifier)
|
|||||||
{
|
{
|
||||||
FunctionDefinition const* fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration());
|
FunctionDefinition const* fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration());
|
||||||
if (fun)
|
if (fun)
|
||||||
|
{
|
||||||
|
if (m_overrideResolver)
|
||||||
|
fun = (*m_overrideResolver)(fun->getName());
|
||||||
|
solAssert(fun, "Error finding override for function " + fun->getName());
|
||||||
addFunction(*fun);
|
addFunction(*fun);
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@
|
|||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
|
#include <functional>
|
||||||
#include <boost/range/iterator_range.hpp>
|
#include <boost/range/iterator_range.hpp>
|
||||||
#include <libsolidity/ASTVisitor.h>
|
#include <libsolidity/ASTVisitor.h>
|
||||||
|
|
||||||
@ -38,8 +39,11 @@ namespace solidity
|
|||||||
class CallGraph: private ASTConstVisitor
|
class CallGraph: private ASTConstVisitor
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
using OverrideResolver = std::function<FunctionDefinition const*(std::string const&)>;
|
||||||
|
|
||||||
|
CallGraph(OverrideResolver const& _overrideResolver): m_overrideResolver(&_overrideResolver) {}
|
||||||
|
|
||||||
void addNode(ASTNode const& _node);
|
void addNode(ASTNode const& _node);
|
||||||
void computeCallGraph();
|
|
||||||
|
|
||||||
std::set<FunctionDefinition const*> const& getCalls();
|
std::set<FunctionDefinition const*> const& getCalls();
|
||||||
|
|
||||||
@ -48,8 +52,10 @@ private:
|
|||||||
virtual bool visit(Identifier const& _identifier) override;
|
virtual bool visit(Identifier const& _identifier) override;
|
||||||
virtual bool visit(MemberAccess const& _memberAccess) override;
|
virtual bool visit(MemberAccess const& _memberAccess) override;
|
||||||
|
|
||||||
|
void computeCallGraph();
|
||||||
void addFunction(FunctionDefinition const& _function);
|
void addFunction(FunctionDefinition const& _function);
|
||||||
|
|
||||||
|
OverrideResolver const* m_overrideResolver;
|
||||||
std::set<FunctionDefinition const*> m_functionsSeen;
|
std::set<FunctionDefinition const*> m_functionsSeen;
|
||||||
std::queue<FunctionDefinition const*> m_workQueue;
|
std::queue<FunctionDefinition const*> m_workQueue;
|
||||||
};
|
};
|
||||||
|
33
Compiler.cpp
33
Compiler.cpp
@ -43,13 +43,13 @@ void Compiler::compileContract(ContractDefinition const& _contract,
|
|||||||
|
|
||||||
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
|
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
|
||||||
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
|
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
|
||||||
if (function->getName() != contract->getName()) // don't add the constructor here
|
if (!function->isConstructor())
|
||||||
m_context.addFunction(*function);
|
m_context.addFunction(*function);
|
||||||
|
|
||||||
appendFunctionSelector(_contract);
|
appendFunctionSelector(_contract);
|
||||||
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
|
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
|
||||||
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
|
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
|
||||||
if (function->getName() != contract->getName()) // don't add the constructor here
|
if (!function->isConstructor())
|
||||||
function->accept(*this);
|
function->accept(*this);
|
||||||
|
|
||||||
// Swap the runtime context with the creation-time context
|
// Swap the runtime context with the creation-time context
|
||||||
@ -93,10 +93,29 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//@TODO add virtual functions
|
auto overrideResolver = [&](string const& _name) -> FunctionDefinition const*
|
||||||
neededFunctions = getFunctionsCalled(nodesUsedInConstructors);
|
{
|
||||||
|
for (ContractDefinition const* contract: bases)
|
||||||
|
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
|
||||||
|
if (!function->isConstructor() && function->getName() == _name)
|
||||||
|
return function.get();
|
||||||
|
return nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
neededFunctions = getFunctionsCalled(nodesUsedInConstructors, overrideResolver);
|
||||||
|
|
||||||
|
// First add all overrides (or the functions themselves if there is no override)
|
||||||
for (FunctionDefinition const* fun: neededFunctions)
|
for (FunctionDefinition const* fun: neededFunctions)
|
||||||
|
{
|
||||||
|
FunctionDefinition const* override = nullptr;
|
||||||
|
if (!fun->isConstructor())
|
||||||
|
override = overrideResolver(fun->getName());
|
||||||
|
if (!!override && neededFunctions.count(override))
|
||||||
|
m_context.addFunction(*override);
|
||||||
|
}
|
||||||
|
// now add the rest
|
||||||
|
for (FunctionDefinition const* fun: neededFunctions)
|
||||||
|
if (fun->isConstructor() || overrideResolver(fun->getName()) != fun)
|
||||||
m_context.addFunction(*fun);
|
m_context.addFunction(*fun);
|
||||||
|
|
||||||
// Call constructors in base-to-derived order.
|
// Call constructors in base-to-derived order.
|
||||||
@ -159,10 +178,10 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor)
|
|||||||
m_context << returnTag;
|
m_context << returnTag;
|
||||||
}
|
}
|
||||||
|
|
||||||
set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes)
|
set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes,
|
||||||
|
function<FunctionDefinition const*(string const&)> const& _resolveOverrides)
|
||||||
{
|
{
|
||||||
// TODO this does not add virtual functions
|
CallGraph callgraph(_resolveOverrides);
|
||||||
CallGraph callgraph;
|
|
||||||
for (ASTNode const* node: _nodes)
|
for (ASTNode const* node: _nodes)
|
||||||
callgraph.addNode(*node);
|
callgraph.addNode(*node);
|
||||||
return callgraph.getCalls();
|
return callgraph.getCalls();
|
||||||
|
@ -21,6 +21,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#include <ostream>
|
#include <ostream>
|
||||||
|
#include <functional>
|
||||||
#include <libsolidity/ASTVisitor.h>
|
#include <libsolidity/ASTVisitor.h>
|
||||||
#include <libsolidity/CompilerContext.h>
|
#include <libsolidity/CompilerContext.h>
|
||||||
|
|
||||||
@ -49,7 +50,9 @@ private:
|
|||||||
std::vector<ASTPointer<Expression>> const& _arguments);
|
std::vector<ASTPointer<Expression>> const& _arguments);
|
||||||
void appendConstructorCall(FunctionDefinition const& _constructor);
|
void appendConstructorCall(FunctionDefinition const& _constructor);
|
||||||
/// Recursively searches the call graph and returns all functions referenced inside _nodes.
|
/// Recursively searches the call graph and returns all functions referenced inside _nodes.
|
||||||
std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes);
|
/// _resolveOverride is called to resolve virtual function overrides.
|
||||||
|
std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes,
|
||||||
|
std::function<FunctionDefinition const*(std::string const&)> const& _resolveOverride);
|
||||||
void appendFunctionSelector(ContractDefinition const& _contract);
|
void appendFunctionSelector(ContractDefinition const& _contract);
|
||||||
/// Creates code that unpacks the arguments for the given function, from memory if
|
/// Creates code that unpacks the arguments for the given function, from memory if
|
||||||
/// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes.
|
/// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes.
|
||||||
|
13
Parser.cpp
13
Parser.cpp
@ -112,9 +112,9 @@ ASTPointer<ImportDirective> Parser::parseImportDirective()
|
|||||||
ASTPointer<ContractDefinition> Parser::parseContractDefinition()
|
ASTPointer<ContractDefinition> Parser::parseContractDefinition()
|
||||||
{
|
{
|
||||||
ASTNodeFactory nodeFactory(*this);
|
ASTNodeFactory nodeFactory(*this);
|
||||||
ASTPointer<ASTString> docstring;
|
ASTPointer<ASTString> docString;
|
||||||
if (m_scanner->getCurrentCommentLiteral() != "")
|
if (m_scanner->getCurrentCommentLiteral() != "")
|
||||||
docstring = make_shared<ASTString>(m_scanner->getCurrentCommentLiteral());
|
docString = make_shared<ASTString>(m_scanner->getCurrentCommentLiteral());
|
||||||
expectToken(Token::CONTRACT);
|
expectToken(Token::CONTRACT);
|
||||||
ASTPointer<ASTString> name = expectIdentifierToken();
|
ASTPointer<ASTString> name = expectIdentifierToken();
|
||||||
vector<ASTPointer<InheritanceSpecifier>> baseContracts;
|
vector<ASTPointer<InheritanceSpecifier>> baseContracts;
|
||||||
@ -142,7 +142,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition()
|
|||||||
expectToken(Token::COLON);
|
expectToken(Token::COLON);
|
||||||
}
|
}
|
||||||
else if (currentToken == Token::FUNCTION)
|
else if (currentToken == Token::FUNCTION)
|
||||||
functions.push_back(parseFunctionDefinition(visibilityIsPublic));
|
functions.push_back(parseFunctionDefinition(visibilityIsPublic, name.get()));
|
||||||
else if (currentToken == Token::STRUCT)
|
else if (currentToken == Token::STRUCT)
|
||||||
structs.push_back(parseStructDefinition());
|
structs.push_back(parseStructDefinition());
|
||||||
else if (currentToken == Token::IDENTIFIER || currentToken == Token::MAPPING ||
|
else if (currentToken == Token::IDENTIFIER || currentToken == Token::MAPPING ||
|
||||||
@ -157,7 +157,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition()
|
|||||||
}
|
}
|
||||||
nodeFactory.markEndPosition();
|
nodeFactory.markEndPosition();
|
||||||
expectToken(Token::RBRACE);
|
expectToken(Token::RBRACE);
|
||||||
return nodeFactory.createNode<ContractDefinition>(name, docstring, baseContracts, structs,
|
return nodeFactory.createNode<ContractDefinition>(name, docString, baseContracts, structs,
|
||||||
stateVariables, functions);
|
stateVariables, functions);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -178,7 +178,7 @@ ASTPointer<InheritanceSpecifier> Parser::parseInheritanceSpecifier()
|
|||||||
return nodeFactory.createNode<InheritanceSpecifier>(name, arguments);
|
return nodeFactory.createNode<InheritanceSpecifier>(name, arguments);
|
||||||
}
|
}
|
||||||
|
|
||||||
ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic)
|
ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic, ASTString const* _contractName)
|
||||||
{
|
{
|
||||||
ASTNodeFactory nodeFactory(*this);
|
ASTNodeFactory nodeFactory(*this);
|
||||||
ASTPointer<ASTString> docstring;
|
ASTPointer<ASTString> docstring;
|
||||||
@ -210,7 +210,8 @@ ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic)
|
|||||||
}
|
}
|
||||||
ASTPointer<Block> block = parseBlock();
|
ASTPointer<Block> block = parseBlock();
|
||||||
nodeFactory.setEndPositionFromNode(block);
|
nodeFactory.setEndPositionFromNode(block);
|
||||||
return nodeFactory.createNode<FunctionDefinition>(name, _isPublic, docstring,
|
bool const c_isConstructor = (_contractName && *name == *_contractName);
|
||||||
|
return nodeFactory.createNode<FunctionDefinition>(name, _isPublic, c_isConstructor, docstring,
|
||||||
parameters,
|
parameters,
|
||||||
isDeclaredConst, returnParameters, block);
|
isDeclaredConst, returnParameters, block);
|
||||||
}
|
}
|
||||||
|
2
Parser.h
2
Parser.h
@ -50,7 +50,7 @@ private:
|
|||||||
ASTPointer<ImportDirective> parseImportDirective();
|
ASTPointer<ImportDirective> parseImportDirective();
|
||||||
ASTPointer<ContractDefinition> parseContractDefinition();
|
ASTPointer<ContractDefinition> parseContractDefinition();
|
||||||
ASTPointer<InheritanceSpecifier> parseInheritanceSpecifier();
|
ASTPointer<InheritanceSpecifier> parseInheritanceSpecifier();
|
||||||
ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic);
|
ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic, ASTString const* _contractName);
|
||||||
ASTPointer<StructDefinition> parseStructDefinition();
|
ASTPointer<StructDefinition> parseStructDefinition();
|
||||||
ASTPointer<VariableDeclaration> parseVariableDeclaration(bool _allowVar);
|
ASTPointer<VariableDeclaration> parseVariableDeclaration(bool _allowVar);
|
||||||
ASTPointer<TypeName> parseTypeName(bool _allowVar);
|
ASTPointer<TypeName> parseTypeName(bool _allowVar);
|
||||||
|
@ -716,7 +716,7 @@ MemberList const& TypeType::getMembers() const
|
|||||||
// We are accessing the type of a base contract, so add all public and private
|
// We are accessing the type of a base contract, so add all public and private
|
||||||
// functions. Note that this does not add inherited functions on purpose.
|
// functions. Note that this does not add inherited functions on purpose.
|
||||||
for (ASTPointer<FunctionDefinition> const& f: contract.getDefinedFunctions())
|
for (ASTPointer<FunctionDefinition> const& f: contract.getDefinedFunctions())
|
||||||
if (f->getName() != contract.getName())
|
if (!f->isConstructor())
|
||||||
members[f->getName()] = make_shared<FunctionType>(*f);
|
members[f->getName()] = make_shared<FunctionType>(*f);
|
||||||
}
|
}
|
||||||
m_members.reset(new MemberList(members));
|
m_members.reset(new MemberList(members));
|
||||||
|
Loading…
Reference in New Issue
Block a user