mirror of
				https://github.com/ethereum/solidity
				synced 2023-10-03 13:03:40 +00:00 
			
		
		
		
	Merge pull request #830 from chriseth/sol_overridesInConstructorContext
Include virtual function overrides in constructor context.
This commit is contained in:
		
						commit
						30b455e4d6
					
				
							
								
								
									
										7
									
								
								AST.cpp
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								AST.cpp
									
									
									
									
									
								
							| @ -82,7 +82,7 @@ map<FixedHash<4>, FunctionDefinition const*> ContractDefinition::getInterfaceFun | ||||
| FunctionDefinition const* ContractDefinition::getConstructor() const | ||||
| { | ||||
| 	for (ASTPointer<FunctionDefinition> const& f: m_definedFunctions) | ||||
| 		if (f->getName() == getName()) | ||||
| 		if (f->isConstructor()) | ||||
| 			return f.get(); | ||||
| 	return nullptr; | ||||
| } | ||||
| @ -95,7 +95,7 @@ void ContractDefinition::checkIllegalOverrides() const | ||||
| 	for (ContractDefinition const* contract: getLinearizedBaseContracts()) | ||||
| 		for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) | ||||
| 		{ | ||||
| 			if (function->getName() == contract->getName()) | ||||
| 			if (function->isConstructor()) | ||||
| 				continue; // constructors can neither be overriden nor override anything
 | ||||
| 			FunctionDefinition const*& override = functions[function->getName()]; | ||||
| 			if (!override) | ||||
| @ -115,8 +115,7 @@ vector<pair<FixedHash<4>, FunctionDefinition const*>> const& ContractDefinition: | ||||
| 		m_interfaceFunctionList.reset(new vector<pair<FixedHash<4>, FunctionDefinition const*>>()); | ||||
| 		for (ContractDefinition const* contract: getLinearizedBaseContracts()) | ||||
| 			for (ASTPointer<FunctionDefinition> const& f: contract->getDefinedFunctions()) | ||||
| 				if (f->isPublic() && f->getName() != contract->getName() && | ||||
| 						functionsSeen.count(f->getName()) == 0) | ||||
| 				if (f->isPublic() && !f->isConstructor() && functionsSeen.count(f->getName()) == 0) | ||||
| 				{ | ||||
| 					functionsSeen.insert(f->getName()); | ||||
| 					FixedHash<4> hash(dev::sha3(f->getCanonicalSignature())); | ||||
|  | ||||
							
								
								
									
										5
									
								
								AST.h
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								AST.h
									
									
									
									
									
								
							| @ -281,12 +281,13 @@ class FunctionDefinition: public Declaration | ||||
| public: | ||||
| 	FunctionDefinition(Location const& _location, ASTPointer<ASTString> const& _name, | ||||
| 					bool _isPublic, | ||||
| 					bool _isConstructor, | ||||
| 					ASTPointer<ASTString> const& _documentation, | ||||
| 					ASTPointer<ParameterList> const& _parameters, | ||||
| 					bool _isDeclaredConst, | ||||
| 					ASTPointer<ParameterList> const& _returnParameters, | ||||
| 					ASTPointer<Block> const& _body): | ||||
| 	Declaration(_location, _name), m_isPublic(_isPublic), | ||||
| 	Declaration(_location, _name), m_isPublic(_isPublic), m_isConstructor(_isConstructor), | ||||
| 	m_parameters(_parameters), | ||||
| 	m_isDeclaredConst(_isDeclaredConst), | ||||
| 	m_returnParameters(_returnParameters), | ||||
| @ -298,6 +299,7 @@ public: | ||||
| 	virtual void accept(ASTConstVisitor& _visitor) const override; | ||||
| 
 | ||||
| 	bool isPublic() const { return m_isPublic; } | ||||
| 	bool isConstructor() const { return m_isConstructor; } | ||||
| 	bool isDeclaredConst() const { return m_isDeclaredConst; } | ||||
| 	std::vector<ASTPointer<VariableDeclaration>> const& getParameters() const { return m_parameters->getParameters(); } | ||||
| 	ParameterList const& getParameterList() const { return *m_parameters; } | ||||
| @ -321,6 +323,7 @@ public: | ||||
| 
 | ||||
| private: | ||||
| 	bool m_isPublic; | ||||
| 	bool m_isConstructor; | ||||
| 	ASTPointer<ParameterList> m_parameters; | ||||
| 	bool m_isDeclaredConst; | ||||
| 	ASTPointer<ParameterList> m_returnParameters; | ||||
|  | ||||
| @ -38,6 +38,7 @@ void CallGraph::addNode(ASTNode const& _node) | ||||
| 
 | ||||
| set<FunctionDefinition const*> const& CallGraph::getCalls() | ||||
| { | ||||
| 	computeCallGraph(); | ||||
| 	return m_functionsSeen; | ||||
| } | ||||
| 
 | ||||
| @ -45,8 +46,7 @@ void CallGraph::computeCallGraph() | ||||
| { | ||||
| 	while (!m_workQueue.empty()) | ||||
| 	{ | ||||
| 		FunctionDefinition const* fun = m_workQueue.front(); | ||||
| 		fun->accept(*this); | ||||
| 		m_workQueue.front()->accept(*this); | ||||
| 		m_workQueue.pop(); | ||||
| 	} | ||||
| } | ||||
| @ -55,7 +55,12 @@ bool CallGraph::visit(Identifier const& _identifier) | ||||
| { | ||||
| 	FunctionDefinition const* fun = dynamic_cast<FunctionDefinition const*>(_identifier.getReferencedDeclaration()); | ||||
| 	if (fun) | ||||
| 	{ | ||||
| 		if (m_overrideResolver) | ||||
| 			fun = (*m_overrideResolver)(fun->getName()); | ||||
| 		solAssert(fun, "Error finding override for function " + fun->getName()); | ||||
| 		addFunction(*fun); | ||||
| 	} | ||||
| 	return true; | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -22,6 +22,7 @@ | ||||
| 
 | ||||
| #include <set> | ||||
| #include <queue> | ||||
| #include <functional> | ||||
| #include <boost/range/iterator_range.hpp> | ||||
| #include <libsolidity/ASTVisitor.h> | ||||
| 
 | ||||
| @ -38,8 +39,11 @@ namespace solidity | ||||
| class CallGraph: private ASTConstVisitor | ||||
| { | ||||
| public: | ||||
| 	using OverrideResolver = std::function<FunctionDefinition const*(std::string const&)>; | ||||
| 
 | ||||
| 	CallGraph(OverrideResolver const& _overrideResolver): m_overrideResolver(&_overrideResolver) {} | ||||
| 
 | ||||
| 	void addNode(ASTNode const& _node); | ||||
| 	void computeCallGraph(); | ||||
| 
 | ||||
| 	std::set<FunctionDefinition const*> const& getCalls(); | ||||
| 
 | ||||
| @ -48,8 +52,10 @@ private: | ||||
| 	virtual bool visit(Identifier const& _identifier) override; | ||||
| 	virtual bool visit(MemberAccess const& _memberAccess) override; | ||||
| 
 | ||||
| 	void computeCallGraph(); | ||||
| 	void addFunction(FunctionDefinition const& _function); | ||||
| 
 | ||||
| 	OverrideResolver const* m_overrideResolver; | ||||
| 	std::set<FunctionDefinition const*> m_functionsSeen; | ||||
| 	std::queue<FunctionDefinition const*> m_workQueue; | ||||
| }; | ||||
|  | ||||
							
								
								
									
										35
									
								
								Compiler.cpp
									
									
									
									
									
								
							
							
						
						
									
										35
									
								
								Compiler.cpp
									
									
									
									
									
								
							| @ -43,13 +43,13 @@ void Compiler::compileContract(ContractDefinition const& _contract, | ||||
| 
 | ||||
| 	for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) | ||||
| 		for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) | ||||
| 			if (function->getName() != contract->getName()) // don't add the constructor here
 | ||||
| 			if (!function->isConstructor()) | ||||
| 				m_context.addFunction(*function); | ||||
| 
 | ||||
| 	appendFunctionSelector(_contract); | ||||
| 	for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) | ||||
| 		for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) | ||||
| 			if (function->getName() != contract->getName()) // don't add the constructor here
 | ||||
| 			if (!function->isConstructor()) | ||||
| 				function->accept(*this); | ||||
| 
 | ||||
| 	// Swap the runtime context with the creation-time context
 | ||||
| @ -93,11 +93,30 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	//@TODO add virtual functions
 | ||||
| 	neededFunctions = getFunctionsCalled(nodesUsedInConstructors); | ||||
| 	auto overrideResolver = [&](string const& _name) -> FunctionDefinition const* | ||||
| 	{ | ||||
| 		for (ContractDefinition const* contract: bases) | ||||
| 			for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions()) | ||||
| 				if (!function->isConstructor() && function->getName() == _name) | ||||
| 					return function.get(); | ||||
| 		return nullptr; | ||||
| 	}; | ||||
| 
 | ||||
| 	neededFunctions = getFunctionsCalled(nodesUsedInConstructors, overrideResolver); | ||||
| 
 | ||||
| 	// First add all overrides (or the functions themselves if there is no override)
 | ||||
| 	for (FunctionDefinition const* fun: neededFunctions) | ||||
| 		m_context.addFunction(*fun); | ||||
| 	{ | ||||
| 		FunctionDefinition const* override = nullptr; | ||||
| 		if (!fun->isConstructor()) | ||||
| 			override = overrideResolver(fun->getName()); | ||||
| 		if (!!override && neededFunctions.count(override)) | ||||
| 			m_context.addFunction(*override); | ||||
| 	} | ||||
| 	// now add the rest
 | ||||
| 	for (FunctionDefinition const* fun: neededFunctions) | ||||
| 		if (fun->isConstructor() || overrideResolver(fun->getName()) != fun) | ||||
| 			m_context.addFunction(*fun); | ||||
| 
 | ||||
| 	// Call constructors in base-to-derived order.
 | ||||
| 	// The Constructor for the most derived contract is called later.
 | ||||
| @ -159,10 +178,10 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor) | ||||
| 	m_context << returnTag; | ||||
| } | ||||
| 
 | ||||
| set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes) | ||||
| set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes, | ||||
| 						function<FunctionDefinition const*(string const&)> const& _resolveOverrides) | ||||
| { | ||||
| 	// TODO this does not add virtual functions
 | ||||
| 	CallGraph callgraph; | ||||
| 	CallGraph callgraph(_resolveOverrides); | ||||
| 	for (ASTNode const* node: _nodes) | ||||
| 		callgraph.addNode(*node); | ||||
| 	return callgraph.getCalls(); | ||||
|  | ||||
| @ -21,6 +21,7 @@ | ||||
|  */ | ||||
| 
 | ||||
| #include <ostream> | ||||
| #include <functional> | ||||
| #include <libsolidity/ASTVisitor.h> | ||||
| #include <libsolidity/CompilerContext.h> | ||||
| 
 | ||||
| @ -49,7 +50,9 @@ private: | ||||
| 								   std::vector<ASTPointer<Expression>> const& _arguments); | ||||
| 	void appendConstructorCall(FunctionDefinition const& _constructor); | ||||
| 	/// Recursively searches the call graph and returns all functions referenced inside _nodes.
 | ||||
| 	std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes); | ||||
| 	/// _resolveOverride is called to resolve virtual function overrides.
 | ||||
| 	std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes, | ||||
| 					std::function<FunctionDefinition const*(std::string const&)> const& _resolveOverride); | ||||
| 	void appendFunctionSelector(ContractDefinition const& _contract); | ||||
| 	/// Creates code that unpacks the arguments for the given function, from memory if
 | ||||
| 	/// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes.
 | ||||
|  | ||||
							
								
								
									
										13
									
								
								Parser.cpp
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								Parser.cpp
									
									
									
									
									
								
							| @ -112,9 +112,9 @@ ASTPointer<ImportDirective> Parser::parseImportDirective() | ||||
| ASTPointer<ContractDefinition> Parser::parseContractDefinition() | ||||
| { | ||||
| 	ASTNodeFactory nodeFactory(*this); | ||||
| 	ASTPointer<ASTString> docstring; | ||||
| 	ASTPointer<ASTString> docString; | ||||
| 	if (m_scanner->getCurrentCommentLiteral() != "") | ||||
| 		docstring = make_shared<ASTString>(m_scanner->getCurrentCommentLiteral()); | ||||
| 		docString = make_shared<ASTString>(m_scanner->getCurrentCommentLiteral()); | ||||
| 	expectToken(Token::CONTRACT); | ||||
| 	ASTPointer<ASTString> name = expectIdentifierToken(); | ||||
| 	vector<ASTPointer<InheritanceSpecifier>> baseContracts; | ||||
| @ -142,7 +142,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition() | ||||
| 			expectToken(Token::COLON); | ||||
| 		} | ||||
| 		else if (currentToken == Token::FUNCTION) | ||||
| 			functions.push_back(parseFunctionDefinition(visibilityIsPublic)); | ||||
| 			functions.push_back(parseFunctionDefinition(visibilityIsPublic, name.get())); | ||||
| 		else if (currentToken == Token::STRUCT) | ||||
| 			structs.push_back(parseStructDefinition()); | ||||
| 		else if (currentToken == Token::IDENTIFIER || currentToken == Token::MAPPING || | ||||
| @ -157,7 +157,7 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition() | ||||
| 	} | ||||
| 	nodeFactory.markEndPosition(); | ||||
| 	expectToken(Token::RBRACE); | ||||
| 	return nodeFactory.createNode<ContractDefinition>(name, docstring, baseContracts, structs, | ||||
| 	return nodeFactory.createNode<ContractDefinition>(name, docString, baseContracts, structs, | ||||
| 													  stateVariables, functions); | ||||
| } | ||||
| 
 | ||||
| @ -178,7 +178,7 @@ ASTPointer<InheritanceSpecifier> Parser::parseInheritanceSpecifier() | ||||
| 	return nodeFactory.createNode<InheritanceSpecifier>(name, arguments); | ||||
| } | ||||
| 
 | ||||
| ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic) | ||||
| ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic, ASTString const* _contractName) | ||||
| { | ||||
| 	ASTNodeFactory nodeFactory(*this); | ||||
| 	ASTPointer<ASTString> docstring; | ||||
| @ -210,7 +210,8 @@ ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic) | ||||
| 	} | ||||
| 	ASTPointer<Block> block = parseBlock(); | ||||
| 	nodeFactory.setEndPositionFromNode(block); | ||||
| 	return nodeFactory.createNode<FunctionDefinition>(name, _isPublic, docstring, | ||||
| 	bool const c_isConstructor = (_contractName && *name == *_contractName); | ||||
| 	return nodeFactory.createNode<FunctionDefinition>(name, _isPublic, c_isConstructor, docstring, | ||||
| 													  parameters, | ||||
| 													  isDeclaredConst, returnParameters, block); | ||||
| } | ||||
|  | ||||
							
								
								
									
										2
									
								
								Parser.h
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								Parser.h
									
									
									
									
									
								
							| @ -50,7 +50,7 @@ private: | ||||
| 	ASTPointer<ImportDirective> parseImportDirective(); | ||||
| 	ASTPointer<ContractDefinition> parseContractDefinition(); | ||||
| 	ASTPointer<InheritanceSpecifier> parseInheritanceSpecifier(); | ||||
| 	ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic); | ||||
| 	ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic, ASTString const* _contractName); | ||||
| 	ASTPointer<StructDefinition> parseStructDefinition(); | ||||
| 	ASTPointer<VariableDeclaration> parseVariableDeclaration(bool _allowVar); | ||||
| 	ASTPointer<TypeName> parseTypeName(bool _allowVar); | ||||
|  | ||||
| @ -716,7 +716,7 @@ MemberList const& TypeType::getMembers() const | ||||
| 				// We are accessing the type of a base contract, so add all public and private
 | ||||
| 				// functions. Note that this does not add inherited functions on purpose.
 | ||||
| 				for (ASTPointer<FunctionDefinition> const& f: contract.getDefinedFunctions()) | ||||
| 					if (f->getName() != contract.getName()) | ||||
| 					if (!f->isConstructor()) | ||||
| 						members[f->getName()] = make_shared<FunctionType>(*f); | ||||
| 		} | ||||
| 		m_members.reset(new MemberList(members)); | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user