mirror of
https://github.com/ethereum/solidity
synced 2023-10-03 13:03:40 +00:00
Merge pull request #830 from chriseth/sol_overridesInConstructorContext
Include virtual function overrides in constructor context.
This commit is contained in:
commit
30b455e4d6
7
AST.cpp
7
AST.cpp
@ -82,7 +82,7 @@ map<FixedHash<4>, FunctionDefinition const*> ContractDefinition::getInterfaceFun
|
||||
FunctionDefinition const* ContractDefinition::getConstructor() const
|
||||
{
|
||||
for (ASTPointer<FunctionDefinition> const& f: m_definedFunctions)
|
||||
if (f->getName() == getName())
|
||||
if (f->isConstructor())
|
||||
return f.get();
|
||||
return nullptr;
|
||||
}
|
||||
@ -95,7 +95,7 @@ void ContractDefinition::checkIllegalOverrides() const
|
||||
for (ContractDefinition const* contract: getLinearizedBaseContracts())
|
||||
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
|
||||
{
|
||||
if (function->getName() == contract->getName())
|
||||
if (function->isConstructor())
|
||||
continue; // constructors can neither be overriden nor override anything
|
||||
FunctionDefinition const*& override = functions[function->getName()];
|
||||
if (!override)
|
||||
@ -115,8 +115,7 @@ vector<pair<FixedHash<4>, FunctionDefinition const*>> const& ContractDefinition:
|
||||
m_interfaceFunctionList.reset(new vector<pair<FixedHash<4>, FunctionDefinition const*>>());
|
||||
for (ContractDefinition const* contract: getLinearizedBaseContracts())
|
||||
for (ASTPointer<FunctionDefinition> const& f: contract->getDefinedFunctions())
|
||||
if (f->isPublic() && f->getName() != contract->getName() &&
|
||||
functionsSeen.count(f->getName()) == 0)
|
||||
if (f->isPublic() && !f->isConstructor() && functionsSeen.count(f->getName()) == 0)
|
||||
{
|
||||
functionsSeen.insert(f->getName());
|
||||
FixedHash<4> hash(dev::sha3(f->getCanonicalSignature()));
|
||||
|
5
AST.h
5
AST.h
@ -281,12 +281,13 @@ class FunctionDefinition: public Declaration
|
||||
public:
|
||||
FunctionDefinition(Location const& _location, ASTPointer<ASTString> const& _name,
|
||||
bool _isPublic,
|
||||
bool _isConstructor,
|
||||
ASTPointer<ASTString> const& _documentation,
|
||||
ASTPointer<ParameterList> const& _parameters,
|
||||
bool _isDeclaredConst,
|
||||
ASTPointer<ParameterList> const& _returnParameters,
|
||||
ASTPointer<Block> const& _body):
|
||||
Declaration(_location, _name), m_isPublic(_isPublic),
|
||||
Declaration(_location, _name), m_isPublic(_isPublic), m_isConstructor(_isConstructor),
|
||||
m_parameters(_parameters),
|
||||
m_isDeclaredConst(_isDeclaredConst),
|
||||
m_returnParameters(_returnParameters),
|
||||
@ -298,6 +299,7 @@ public:
|
||||
virtual void accept(ASTConstVisitor& _visitor) const override;
|
||||
|
||||
bool isPublic() const { return m_isPublic; }
|
||||
bool isConstructor() const { return m_isConstructor; }
|
||||
bool isDeclaredConst() const { return m_isDeclaredConst; }
|
||||
std::vector<ASTPointer<VariableDeclaration>> const& getParameters() const { return m_parameters->getParameters(); }
|
||||
ParameterList const& getParameterList() const { return *m_parameters; }
|
||||
@ -321,6 +323,7 @@ public:
|
||||
|
||||
private:
|
||||
bool m_isPublic;
|
||||
bool m_isConstructor;
|
||||
ASTPointer<ParameterList> m_parameters;
|
||||
bool m_isDeclaredConst;
|
||||
ASTPointer<ParameterList> m_returnParameters;
|
||||
|
@ -38,6 +38,7 @@ void CallGraph::addNode(ASTNode const& _node)
|
||||
|
||||
set<FunctionDefinition const*> const& CallGraph::getCalls()
|
||||
{
|
||||
computeCallGraph();
|
||||
return m_functionsSeen;
|
||||
}
|
||||
|
||||
@ -45,8 +46,7 @@ void CallGraph::computeCallGraph()
|
||||
{
|
||||
while (!m_workQueue.empty())
|
||||
{
|
||||
FunctionDefinition const* fun = m_workQueue.front();
|
||||
fun->accept(*this);
|
||||
m_workQueue.front()->accept(*this);
|
||||
m_workQueue.pop();
|
||||
}
|
||||
}
|
||||
@ -55,7 +55,12 @@ bool CallGraph::visit(Identifier const& _identifier)
|
||||
{
|
||||
FunctionDefinition const* fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration());
|
||||
if (fun)
|
||||
{
|
||||
if (m_overrideResolver)
|
||||
fun = (*m_overrideResolver)(fun->getName());
|
||||
solAssert(fun, "Error finding override for function " + fun->getName());
|
||||
addFunction(*fun);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -22,6 +22,7 @@
|
||||
|
||||
#include <set>
|
||||
#include <queue>
|
||||
#include <functional>
|
||||
#include <boost/range/iterator_range.hpp>
|
||||
#include <libsolidity/ASTVisitor.h>
|
||||
|
||||
@ -38,8 +39,11 @@ namespace solidity
|
||||
class CallGraph: private ASTConstVisitor
|
||||
{
|
||||
public:
|
||||
using OverrideResolver = std::function<FunctionDefinition const*(std::string const&)>;
|
||||
|
||||
CallGraph(OverrideResolver const& _overrideResolver): m_overrideResolver(&_overrideResolver) {}
|
||||
|
||||
void addNode(ASTNode const& _node);
|
||||
void computeCallGraph();
|
||||
|
||||
std::set<FunctionDefinition const*> const& getCalls();
|
||||
|
||||
@ -48,8 +52,10 @@ private:
|
||||
virtual bool visit(Identifier const& _identifier) override;
|
||||
virtual bool visit(MemberAccess const& _memberAccess) override;
|
||||
|
||||
void computeCallGraph();
|
||||
void addFunction(FunctionDefinition const& _function);
|
||||
|
||||
OverrideResolver const* m_overrideResolver;
|
||||
std::set<FunctionDefinition const*> m_functionsSeen;
|
||||
std::queue<FunctionDefinition const*> m_workQueue;
|
||||
};
|
||||
|
33
Compiler.cpp
33
Compiler.cpp
@ -43,13 +43,13 @@ void Compiler::compileContract(ContractDefinition const& _contract,
|
||||
|
||||
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
|
||||
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
|
||||
if (function->getName() != contract->getName()) // don't add the constructor here
|
||||
if (!function->isConstructor())
|
||||
m_context.addFunction(*function);
|
||||
|
||||
appendFunctionSelector(_contract);
|
||||
for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
|
||||
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
|
||||
if (function->getName() != contract->getName()) // don't add the constructor here
|
||||
if (!function->isConstructor())
|
||||
function->accept(*this);
|
||||
|
||||
// Swap the runtime context with the creation-time context
|
||||
@ -93,10 +93,29 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp
|
||||
}
|
||||
}
|
||||
|
||||
//@TODO add virtual functions
|
||||
neededFunctions = getFunctionsCalled(nodesUsedInConstructors);
|
||||
auto overrideResolver = [&](string const& _name) -> FunctionDefinition const*
|
||||
{
|
||||
for (ContractDefinition const* contract: bases)
|
||||
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
|
||||
if (!function->isConstructor() && function->getName() == _name)
|
||||
return function.get();
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
neededFunctions = getFunctionsCalled(nodesUsedInConstructors, overrideResolver);
|
||||
|
||||
// First add all overrides (or the functions themselves if there is no override)
|
||||
for (FunctionDefinition const* fun: neededFunctions)
|
||||
{
|
||||
FunctionDefinition const* override = nullptr;
|
||||
if (!fun->isConstructor())
|
||||
override = overrideResolver(fun->getName());
|
||||
if (!!override && neededFunctions.count(override))
|
||||
m_context.addFunction(*override);
|
||||
}
|
||||
// now add the rest
|
||||
for (FunctionDefinition const* fun: neededFunctions)
|
||||
if (fun->isConstructor() || overrideResolver(fun->getName()) != fun)
|
||||
m_context.addFunction(*fun);
|
||||
|
||||
// Call constructors in base-to-derived order.
|
||||
@ -159,10 +178,10 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor)
|
||||
m_context << returnTag;
|
||||
}
|
||||
|
||||
set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes)
|
||||
set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes,
|
||||
function<FunctionDefinition const*(string const&)> const& _resolveOverrides)
|
||||
{
|
||||
// TODO this does not add virtual functions
|
||||
CallGraph callgraph;
|
||||
CallGraph callgraph(_resolveOverrides);
|
||||
for (ASTNode const* node: _nodes)
|
||||
callgraph.addNode(*node);
|
||||
return callgraph.getCalls();
|
||||
|
@ -21,6 +21,7 @@
|
||||
*/
|
||||
|
||||
#include <ostream>
|
||||
#include <functional>
|
||||
#include <libsolidity/ASTVisitor.h>
|
||||
#include <libsolidity/CompilerContext.h>
|
||||
|
||||
@ -49,7 +50,9 @@ private:
|
||||
std::vector<ASTPointer<Expression>> const& _arguments);
|
||||
void appendConstructorCall(FunctionDefinition const& _constructor);
|
||||
/// Recursively searches the call graph and returns all functions referenced inside _nodes.
|
||||
std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes);
|
||||
/// _resolveOverride is called to resolve virtual function overrides.
|
||||
std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes,
|
||||
std::function<FunctionDefinition const*(std::string const&)> const& _resolveOverride);
|
||||
void appendFunctionSelector(ContractDefinition const& _contract);
|
||||
/// Creates code that unpacks the arguments for the given function, from memory if
|
||||
/// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes.
|
||||
|
13
Parser.cpp
13
Parser.cpp
@ -112,9 +112,9 @@ ASTPointer<ImportDirective> Parser::parseImportDirective()
|
||||
ASTPointer<ContractDefinition> Parser::parseContractDefinition()
|
||||
{
|
||||
ASTNodeFactory nodeFactory(*this);
|
||||
ASTPointer<ASTString> docstring;
|
||||
ASTPointer<ASTString> docString;
|
||||
if (m_scanner->getCurrentCommentLiteral() != "")
|
||||
docstring = make_shared<ASTString>(m_scanner->getCurrentCommentLiteral());
|
||||
docString = make_shared<ASTString>(m_scanner->getCurrentCommentLiteral());
|
||||
expectToken(Token::CONTRACT);
|
||||
ASTPointer<ASTString> name = expectIdentifierToken();
|
||||
vector<ASTPointer<InheritanceSpecifier>> baseContracts;
|
||||
@ -142,7 +142,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition()
|
||||
expectToken(Token::COLON);
|
||||
}
|
||||
else if (currentToken == Token::FUNCTION)
|
||||
functions.push_back(parseFunctionDefinition(visibilityIsPublic));
|
||||
functions.push_back(parseFunctionDefinition(visibilityIsPublic, name.get()));
|
||||
else if (currentToken == Token::STRUCT)
|
||||
structs.push_back(parseStructDefinition());
|
||||
else if (currentToken == Token::IDENTIFIER || currentToken == Token::MAPPING ||
|
||||
@ -157,7 +157,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition()
|
||||
}
|
||||
nodeFactory.markEndPosition();
|
||||
expectToken(Token::RBRACE);
|
||||
return nodeFactory.createNode<ContractDefinition>(name, docstring, baseContracts, structs,
|
||||
return nodeFactory.createNode<ContractDefinition>(name, docString, baseContracts, structs,
|
||||
stateVariables, functions);
|
||||
}
|
||||
|
||||
@ -178,7 +178,7 @@ ASTPointer<InheritanceSpecifier> Parser::parseInheritanceSpecifier()
|
||||
return nodeFactory.createNode<InheritanceSpecifier>(name, arguments);
|
||||
}
|
||||
|
||||
ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic)
|
||||
ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic, ASTString const* _contractName)
|
||||
{
|
||||
ASTNodeFactory nodeFactory(*this);
|
||||
ASTPointer<ASTString> docstring;
|
||||
@ -210,7 +210,8 @@ ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic)
|
||||
}
|
||||
ASTPointer<Block> block = parseBlock();
|
||||
nodeFactory.setEndPositionFromNode(block);
|
||||
return nodeFactory.createNode<FunctionDefinition>(name, _isPublic, docstring,
|
||||
bool const c_isConstructor = (_contractName && *name == *_contractName);
|
||||
return nodeFactory.createNode<FunctionDefinition>(name, _isPublic, c_isConstructor, docstring,
|
||||
parameters,
|
||||
isDeclaredConst, returnParameters, block);
|
||||
}
|
||||
|
2
Parser.h
2
Parser.h
@ -50,7 +50,7 @@ private:
|
||||
ASTPointer<ImportDirective> parseImportDirective();
|
||||
ASTPointer<ContractDefinition> parseContractDefinition();
|
||||
ASTPointer<InheritanceSpecifier> parseInheritanceSpecifier();
|
||||
ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic);
|
||||
ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic, ASTString const* _contractName);
|
||||
ASTPointer<StructDefinition> parseStructDefinition();
|
||||
ASTPointer<VariableDeclaration> parseVariableDeclaration(bool _allowVar);
|
||||
ASTPointer<TypeName> parseTypeName(bool _allowVar);
|
||||
|
@ -716,7 +716,7 @@ MemberList const& TypeType::getMembers() const
|
||||
// We are accessing the type of a base contract, so add all public and private
|
||||
// functions. Note that this does not add inherited functions on purpose.
|
||||
for (ASTPointer<FunctionDefinition> const& f: contract.getDefinedFunctions())
|
||||
if (f->getName() != contract.getName())
|
||||
if (!f->isConstructor())
|
||||
members[f->getName()] = make_shared<FunctionType>(*f);
|
||||
}
|
||||
m_members.reset(new MemberList(members));
|
||||
|
Loading…
Reference in New Issue
Block a user