From df1809f8dabcbad040d77cbf965584eaedfdb5da Mon Sep 17 00:00:00 2001 From: Daniel Kirchner Date: Tue, 14 Apr 2020 16:36:37 +0200 Subject: [PATCH] Annotate struct definitions with a recursive flag. --- .../analysis/DeclarationTypeChecker.cpp | 120 +++++++++---- libsolidity/analysis/DeclarationTypeChecker.h | 6 +- libsolidity/analysis/TypeChecker.cpp | 8 + libsolidity/ast/AST.cpp | 5 +- libsolidity/ast/AST.h | 2 +- libsolidity/ast/ASTAnnotations.h | 6 + libsolidity/ast/Types.cpp | 168 +++++++++--------- libsolidity/ast/Types.h | 11 +- libsolutil/Algorithms.h | 4 + test/libsolidity/SolidityTypes.cpp | 1 + .../recursive_struct_function_pointer.sol | 10 ++ .../function_type_argument_external.sol | 1 - .../mapping/function_type_return_external.sol | 1 - 13 files changed, 210 insertions(+), 133 deletions(-) create mode 100644 test/libsolidity/syntaxTests/structs/recursion/recursive_struct_function_pointer.sol diff --git a/libsolidity/analysis/DeclarationTypeChecker.cpp b/libsolidity/analysis/DeclarationTypeChecker.cpp index 12e84d53d..71fd52f14 100644 --- a/libsolidity/analysis/DeclarationTypeChecker.cpp +++ b/libsolidity/analysis/DeclarationTypeChecker.cpp @@ -31,40 +31,86 @@ using namespace solidity::frontend; bool DeclarationTypeChecker::visit(ElementaryTypeName const& _typeName) { - if (!_typeName.annotation().type) + if (_typeName.annotation().type) + return false; + + _typeName.annotation().type = TypeProvider::fromElementaryTypeName(_typeName.typeName()); + if (_typeName.stateMutability().has_value()) { - _typeName.annotation().type = TypeProvider::fromElementaryTypeName(_typeName.typeName()); - if (_typeName.stateMutability().has_value()) + // for non-address types this was already caught by the parser + solAssert(_typeName.annotation().type->category() == Type::Category::Address, ""); + switch (*_typeName.stateMutability()) { - // for non-address types this was already caught by the parser - solAssert(_typeName.annotation().type->category() == Type::Category::Address, ""); - switch (*_typeName.stateMutability()) - { - case StateMutability::Payable: - _typeName.annotation().type = TypeProvider::payableAddress(); - break; - case StateMutability::NonPayable: - _typeName.annotation().type = TypeProvider::address(); - break; - default: - typeError( - _typeName.location(), - "Address types can only be payable or non-payable." - ); - break; - } + case StateMutability::Payable: + _typeName.annotation().type = TypeProvider::payableAddress(); + break; + case StateMutability::NonPayable: + _typeName.annotation().type = TypeProvider::address(); + break; + default: + typeError( + _typeName.location(), + "Address types can only be payable or non-payable." + ); + break; } } return true; } +bool DeclarationTypeChecker::visit(StructDefinition const& _struct) +{ + if (_struct.annotation().recursive.has_value()) + { + if (!m_currentStructsSeen.empty() && *_struct.annotation().recursive) + m_recursiveStructSeen = true; + return false; + } + + if (m_currentStructsSeen.count(&_struct)) + { + _struct.annotation().recursive = true; + m_recursiveStructSeen = true; + return false; + } + + bool previousRecursiveStructSeen = m_recursiveStructSeen; + bool hasRecursiveChild = false; + + m_currentStructsSeen.insert(&_struct); + + for (auto const& _member: _struct.members()) + { + m_recursiveStructSeen = false; + _member->accept(*this); + if (m_recursiveStructSeen) + hasRecursiveChild = true; + } + + if (!_struct.annotation().recursive.has_value()) + _struct.annotation().recursive = hasRecursiveChild; + m_recursiveStructSeen = previousRecursiveStructSeen || *_struct.annotation().recursive; + m_currentStructsSeen.erase(&_struct); + if (m_currentStructsSeen.empty()) + m_recursiveStructSeen = false; + + return false; +} + void DeclarationTypeChecker::endVisit(UserDefinedTypeName const& _typeName) { + if (_typeName.annotation().type) + return; + Declaration const* declaration = _typeName.annotation().referencedDeclaration; solAssert(declaration, ""); if (StructDefinition const* structDef = dynamic_cast(declaration)) + { + if (!m_insideFunctionType && !m_currentStructsSeen.empty()) + structDef->accept(*this); _typeName.annotation().type = TypeProvider::structType(*structDef, DataLocation::Storage); + } else if (EnumDefinition const* enumDef = dynamic_cast(declaration)) _typeName.annotation().type = TypeProvider::enumType(*enumDef); else if (ContractDefinition const* contract = dynamic_cast(declaration)) @@ -75,8 +121,17 @@ void DeclarationTypeChecker::endVisit(UserDefinedTypeName const& _typeName) fatalTypeError(_typeName.location(), "Name has to refer to a struct, enum or contract."); } } -void DeclarationTypeChecker::endVisit(FunctionTypeName const& _typeName) +bool DeclarationTypeChecker::visit(FunctionTypeName const& _typeName) { + if (_typeName.annotation().type) + return false; + + bool previousInsideFunctionType = m_insideFunctionType; + m_insideFunctionType = true; + _typeName.parameterTypeList()->accept(*this); + _typeName.returnParameterTypeList()->accept(*this); + m_insideFunctionType = previousInsideFunctionType; + switch (_typeName.visibility()) { case Visibility::Internal: @@ -84,30 +139,22 @@ void DeclarationTypeChecker::endVisit(FunctionTypeName const& _typeName) break; default: fatalTypeError(_typeName.location(), "Invalid visibility, can only be \"external\" or \"internal\"."); - return; + return false; } if (_typeName.isPayable() && _typeName.visibility() != Visibility::External) { fatalTypeError(_typeName.location(), "Only external function types can be payable."); - return; + return false; } - - if (_typeName.visibility() == Visibility::External) - for (auto const& t: _typeName.parameterTypes() + _typeName.returnParameterTypes()) - { - solAssert(t->annotation().type, "Type not set for parameter."); - if (!t->annotation().type->interfaceType(false).get()) - { - fatalTypeError(t->location(), "Internal type cannot be used for external function type."); - return; - } - } - _typeName.annotation().type = TypeProvider::function(_typeName); + return false; } void DeclarationTypeChecker::endVisit(Mapping const& _mapping) { + if (_mapping.annotation().type) + return; + if (auto const* typeName = dynamic_cast(&_mapping.keyType())) { if (auto const* contractType = dynamic_cast(typeName->annotation().type)) @@ -140,6 +187,9 @@ void DeclarationTypeChecker::endVisit(Mapping const& _mapping) void DeclarationTypeChecker::endVisit(ArrayTypeName const& _typeName) { + if (_typeName.annotation().type) + return; + TypePointer baseType = _typeName.baseType().annotation().type; if (!baseType) { diff --git a/libsolidity/analysis/DeclarationTypeChecker.h b/libsolidity/analysis/DeclarationTypeChecker.h index 6878359ab..a3f2c0c8f 100644 --- a/libsolidity/analysis/DeclarationTypeChecker.h +++ b/libsolidity/analysis/DeclarationTypeChecker.h @@ -53,10 +53,11 @@ private: bool visit(ElementaryTypeName const& _typeName) override; void endVisit(UserDefinedTypeName const& _typeName) override; - void endVisit(FunctionTypeName const& _typeName) override; + bool visit(FunctionTypeName const& _typeName) override; void endVisit(Mapping const& _mapping) override; void endVisit(ArrayTypeName const& _typeName) override; void endVisit(VariableDeclaration const& _variable) override; + bool visit(StructDefinition const& _struct) override; /// Adds a new error to the list of errors. void typeError(langutil::SourceLocation const& _location, std::string const& _description); @@ -67,6 +68,9 @@ private: langutil::ErrorReporter& m_errorReporter; bool m_errorOccurred = false; langutil::EVMVersion m_evmVersion; + bool m_insideFunctionType = false; + bool m_recursiveStructSeen = false; + std::set m_currentStructsSeen; }; } diff --git a/libsolidity/analysis/TypeChecker.cpp b/libsolidity/analysis/TypeChecker.cpp index 76e805c27..dba8aaab9 100644 --- a/libsolidity/analysis/TypeChecker.cpp +++ b/libsolidity/analysis/TypeChecker.cpp @@ -631,7 +631,15 @@ void TypeChecker::endVisit(FunctionTypeName const& _funType) { FunctionType const& fun = dynamic_cast(*_funType.annotation().type); if (fun.kind() == FunctionType::Kind::External) + { + for (auto const& t: _funType.parameterTypes() + _funType.returnParameterTypes()) + { + solAssert(t->annotation().type, "Type not set for parameter."); + if (!t->annotation().type->interfaceType(false).get()) + m_errorReporter.typeError(t->location(), "Internal type cannot be used for external function type."); + } solAssert(fun.interfaceType(false), "External function type uses internal types."); + } } bool TypeChecker::visit(InlineAssembly const& _inlineAssembly) diff --git a/libsolidity/ast/AST.cpp b/libsolidity/ast/AST.cpp index d615babf5..10825217c 100644 --- a/libsolidity/ast/AST.cpp +++ b/libsolidity/ast/AST.cpp @@ -257,12 +257,13 @@ TypeNameAnnotation& TypeName::annotation() const TypePointer StructDefinition::type() const { + solAssert(annotation().recursive.has_value(), "Requested struct type before DeclarationTypeChecker."); return TypeProvider::typeType(TypeProvider::structType(*this, DataLocation::Storage)); } -TypeDeclarationAnnotation& StructDefinition::annotation() const +StructDeclarationAnnotation& StructDefinition::annotation() const { - return initAnnotation(); + return initAnnotation(); } TypePointer EnumValue::type() const diff --git a/libsolidity/ast/AST.h b/libsolidity/ast/AST.h index 04067329d..bc4fed110 100644 --- a/libsolidity/ast/AST.h +++ b/libsolidity/ast/AST.h @@ -607,7 +607,7 @@ public: bool isVisibleInDerivedContracts() const override { return true; } bool isVisibleViaContractTypeAccess() const override { return true; } - TypeDeclarationAnnotation& annotation() const override; + StructDeclarationAnnotation& annotation() const override; private: std::vector> m_members; diff --git a/libsolidity/ast/ASTAnnotations.h b/libsolidity/ast/ASTAnnotations.h index 9f7edf84e..126a40b43 100644 --- a/libsolidity/ast/ASTAnnotations.h +++ b/libsolidity/ast/ASTAnnotations.h @@ -128,6 +128,12 @@ struct TypeDeclarationAnnotation: DeclarationAnnotation std::string canonicalName; }; +struct StructDeclarationAnnotation: TypeDeclarationAnnotation +{ + /// Whether the struct is recursive. Will be filled in by the DeclarationTypeChecker. + std::optional recursive; +}; + struct ContractDefinitionAnnotation: TypeDeclarationAnnotation, StructurallyDocumentedAnnotation { /// List of functions without a body. Can also contain functions from base classes. diff --git a/libsolidity/ast/Types.cpp b/libsolidity/ast/Types.cpp index 59eeb18ee..554ff5015 100644 --- a/libsolidity/ast/Types.cpp +++ b/libsolidity/ast/Types.cpp @@ -2175,93 +2175,107 @@ MemberList::MemberMap StructType::nativeMembers(ContractDefinition const*) const TypeResult StructType::interfaceType(bool _inLibrary) const { - if (_inLibrary && m_interfaceType_library.has_value()) - return *m_interfaceType_library; - - if (!_inLibrary && m_interfaceType.has_value()) + if (!_inLibrary) + { + if (!m_interfaceType.has_value()) + { + if (recursive()) + m_interfaceType = TypeResult::err("Recursive type not allowed for public or external contract functions."); + else + { + TypeResult result{TypePointer{}}; + for (ASTPointer const& member: m_struct.members()) + { + if (!member->annotation().type) + { + result = TypeResult::err("Invalid type!"); + break; + } + auto interfaceType = member->annotation().type->interfaceType(false); + if (!interfaceType.get()) + { + solAssert(!interfaceType.message().empty(), "Expected detailed error message!"); + result = interfaceType; + break; + } + } + if (result.message().empty()) + m_interfaceType = TypeProvider::withLocation(this, DataLocation::Memory, true); + else + m_interfaceType = result; + } + } return *m_interfaceType; + } + else if (m_interfaceType_library.has_value()) + return *m_interfaceType_library; TypeResult result{TypePointer{}}; - m_recursive = false; - - auto visitor = [&]( - StructDefinition const& _struct, - util::CycleDetector& _cycleDetector, - size_t /*_depth*/ - ) - { - // Check that all members have interface types. - // Return an error if at least one struct member does not have a type. - // This might happen, for example, if the type of the member does not exist. - for (ASTPointer const& variable: _struct.members()) - { - // If the struct member does not have a type return false. - // A TypeError is expected in this case. - if (!variable->annotation().type) - { - result = TypeResult::err("Invalid type!"); - return; - } - - Type const* memberType = variable->annotation().type; - - while (dynamic_cast(memberType)) - memberType = dynamic_cast(memberType)->baseType(); - - if (StructType const* innerStruct = dynamic_cast(memberType)) - if ( - innerStruct->m_recursive == true || - _cycleDetector.run(innerStruct->structDefinition()) - ) + util::BreadthFirstSearch breadthFirstSearch{{&m_struct}}; + breadthFirstSearch.run( + [&](StructDefinition const* _struct, auto&& _addChild) { + // Check that all members have interface types. + // Return an error if at least one struct member does not have a type. + // This might happen, for example, if the type of the member does not exist. + for (ASTPointer const& variable: _struct->members()) { - m_recursive = true; - if (_inLibrary && location() == DataLocation::Storage) - continue; - else + // If the struct member does not have a type return false. + // A TypeError is expected in this case. + if (!variable->annotation().type) { - result = TypeResult::err("Recursive structs can only be passed as storage pointers to libraries, not as memory objects to contract functions."); + result = TypeResult::err("Invalid type!"); + breadthFirstSearch.abort(); return; } - } - auto iType = memberType->interfaceType(_inLibrary); - if (!iType.get()) - { - solAssert(!iType.message().empty(), "Expected detailed error message!"); - result = iType; - return; + Type const* memberType = variable->annotation().type; + + while (dynamic_cast(memberType)) + memberType = dynamic_cast(memberType)->baseType(); + + if (StructType const* innerStruct = dynamic_cast(memberType)) + { + if (innerStruct->recursive() && !(_inLibrary && location() == DataLocation::Storage)) + { + result = TypeResult::err( + "Recursive structs can only be passed as storage pointers to libraries, not as memory objects to contract functions." + ); + breadthFirstSearch.abort(); + return; + } + else + _addChild(&innerStruct->structDefinition()); + } + else + { + auto iType = memberType->interfaceType(_inLibrary); + if (!iType.get()) + { + solAssert(!iType.message().empty(), "Expected detailed error message!"); + result = iType; + breadthFirstSearch.abort(); + return; + } + } } } - }; + ); - m_recursive = m_recursive.value() || (util::CycleDetector(visitor).run(structDefinition()) != nullptr); + if (!result.message().empty()) + return result; - std::string const recursiveErrMsg = "Recursive type not allowed for public or external contract functions."; - - if (_inLibrary) - { - if (!result.message().empty()) - m_interfaceType_library = result; - else if (location() == DataLocation::Storage) - m_interfaceType_library = this; - else - m_interfaceType_library = TypeProvider::withLocation(this, DataLocation::Memory, true); - - if (m_recursive.value()) - m_interfaceType = TypeResult::err(recursiveErrMsg); - - return *m_interfaceType_library; - } - - if (m_recursive.value()) - m_interfaceType = TypeResult::err(recursiveErrMsg); - else if (!result.message().empty()) - m_interfaceType = result; + if (location() == DataLocation::Storage) + m_interfaceType_library = this; else - m_interfaceType = TypeProvider::withLocation(this, DataLocation::Memory, true); + m_interfaceType_library = TypeProvider::withLocation(this, DataLocation::Memory, true); + return *m_interfaceType_library; +} - return *m_interfaceType; +bool StructType::recursive() const +{ + solAssert(m_struct.annotation().recursive.has_value(), "Called StructType::recursive() before DeclarationTypeChecker."); + return *m_struct.annotation().recursive; } std::unique_ptr StructType::copyForLocation(DataLocation _location, bool _isPointer) const @@ -2644,21 +2658,11 @@ FunctionType::FunctionType(FunctionTypeName const& _typeName): for (auto const& t: _typeName.parameterTypes()) { solAssert(t->annotation().type, "Type not set for parameter."); - if (m_kind == Kind::External) - solAssert( - t->annotation().type->interfaceType(false).get(), - "Internal type used as parameter for external function." - ); m_parameterTypes.push_back(t->annotation().type); } for (auto const& t: _typeName.returnParameterTypes()) { solAssert(t->annotation().type, "Type not set for return parameter."); - if (m_kind == Kind::External) - solAssert( - t->annotation().type->interfaceType(false).get(), - "Internal type used as return parameter for external function." - ); m_returnParameterTypes.push_back(t->annotation().type); } diff --git a/libsolidity/ast/Types.h b/libsolidity/ast/Types.h index d07c47dde..7ab9f7d91 100644 --- a/libsolidity/ast/Types.h +++ b/libsolidity/ast/Types.h @@ -934,15 +934,7 @@ public: Type const* encodingType() const override; TypeResult interfaceType(bool _inLibrary) const override; - bool recursive() const - { - if (m_recursive.has_value()) - return m_recursive.value(); - - interfaceType(false); - - return m_recursive.value(); - } + bool recursive() const; std::unique_ptr copyForLocation(DataLocation _location, bool _isPointer) const override; @@ -971,7 +963,6 @@ private: // Caches for interfaceType(bool) mutable std::optional m_interfaceType; mutable std::optional m_interfaceType_library; - mutable std::optional m_recursive; }; /** diff --git a/libsolutil/Algorithms.h b/libsolutil/Algorithms.h index b9028f19b..3897d65d2 100644 --- a/libsolutil/Algorithms.h +++ b/libsolutil/Algorithms.h @@ -114,6 +114,10 @@ struct BreadthFirstSearch } return *this; } + void abort() + { + verticesToTraverse.clear(); + } std::set verticesToTraverse; std::set visited{}; diff --git a/test/libsolidity/SolidityTypes.cpp b/test/libsolidity/SolidityTypes.cpp index c6cd25c56..b21c8f2e1 100644 --- a/test/libsolidity/SolidityTypes.cpp +++ b/test/libsolidity/SolidityTypes.cpp @@ -184,6 +184,7 @@ BOOST_AUTO_TEST_CASE(type_identifiers) BOOST_CHECK_EQUAL(ContractType(c, true).identifier(), "t_super$_MyContract$$$_$2"); StructDefinition s(++id, {}, make_shared("Struct"), {}); + s.annotation().recursive = false; BOOST_CHECK_EQUAL(s.type()->identifier(), "t_type$_t_struct$_Struct_$3_storage_ptr_$"); EnumDefinition e(++id, {}, make_shared("Enum"), {}); diff --git a/test/libsolidity/syntaxTests/structs/recursion/recursive_struct_function_pointer.sol b/test/libsolidity/syntaxTests/structs/recursion/recursive_struct_function_pointer.sol new file mode 100644 index 000000000..dc40ae3b2 --- /dev/null +++ b/test/libsolidity/syntaxTests/structs/recursion/recursive_struct_function_pointer.sol @@ -0,0 +1,10 @@ +pragma experimental ABIEncoderV2; +contract C { + struct S { + uint a; + function() external returns (S memory) sub; + } + function f() public pure returns (S memory) { + } +} +// ---- diff --git a/test/libsolidity/syntaxTests/types/mapping/function_type_argument_external.sol b/test/libsolidity/syntaxTests/types/mapping/function_type_argument_external.sol index 34f957019..8638baf85 100644 --- a/test/libsolidity/syntaxTests/types/mapping/function_type_argument_external.sol +++ b/test/libsolidity/syntaxTests/types/mapping/function_type_argument_external.sol @@ -4,4 +4,3 @@ contract C { } // ---- // TypeError: (37-64): Data location must be "memory" for parameter in function, but "storage" was given. -// TypeError: (37-64): Internal type cannot be used for external function type. diff --git a/test/libsolidity/syntaxTests/types/mapping/function_type_return_external.sol b/test/libsolidity/syntaxTests/types/mapping/function_type_return_external.sol index aed9b3878..b9bd5bc3f 100644 --- a/test/libsolidity/syntaxTests/types/mapping/function_type_return_external.sol +++ b/test/libsolidity/syntaxTests/types/mapping/function_type_return_external.sol @@ -4,4 +4,3 @@ contract C { } // ---- // TypeError: (57-84): Data location must be "memory" for return parameter in function, but "storage" was given. -// TypeError: (57-84): Internal type cannot be used for external function type.