Annotate struct definitions with a recursive flag.

This commit is contained in:
Daniel Kirchner 2020-04-14 16:36:37 +02:00
parent 95349b3634
commit df1809f8da
13 changed files with 210 additions and 133 deletions

View File

@ -31,8 +31,9 @@ using namespace solidity::frontend;
bool DeclarationTypeChecker::visit(ElementaryTypeName const& _typeName) bool DeclarationTypeChecker::visit(ElementaryTypeName const& _typeName)
{ {
if (!_typeName.annotation().type) if (_typeName.annotation().type)
{ return false;
_typeName.annotation().type = TypeProvider::fromElementaryTypeName(_typeName.typeName()); _typeName.annotation().type = TypeProvider::fromElementaryTypeName(_typeName.typeName());
if (_typeName.stateMutability().has_value()) if (_typeName.stateMutability().has_value())
{ {
@ -54,17 +55,62 @@ bool DeclarationTypeChecker::visit(ElementaryTypeName const& _typeName)
break; break;
} }
} }
}
return true; return true;
} }
bool DeclarationTypeChecker::visit(StructDefinition const& _struct)
{
if (_struct.annotation().recursive.has_value())
{
if (!m_currentStructsSeen.empty() && *_struct.annotation().recursive)
m_recursiveStructSeen = true;
return false;
}
if (m_currentStructsSeen.count(&_struct))
{
_struct.annotation().recursive = true;
m_recursiveStructSeen = true;
return false;
}
bool previousRecursiveStructSeen = m_recursiveStructSeen;
bool hasRecursiveChild = false;
m_currentStructsSeen.insert(&_struct);
for (auto const& _member: _struct.members())
{
m_recursiveStructSeen = false;
_member->accept(*this);
if (m_recursiveStructSeen)
hasRecursiveChild = true;
}
if (!_struct.annotation().recursive.has_value())
_struct.annotation().recursive = hasRecursiveChild;
m_recursiveStructSeen = previousRecursiveStructSeen || *_struct.annotation().recursive;
m_currentStructsSeen.erase(&_struct);
if (m_currentStructsSeen.empty())
m_recursiveStructSeen = false;
return false;
}
void DeclarationTypeChecker::endVisit(UserDefinedTypeName const& _typeName) void DeclarationTypeChecker::endVisit(UserDefinedTypeName const& _typeName)
{ {
if (_typeName.annotation().type)
return;
Declaration const* declaration = _typeName.annotation().referencedDeclaration; Declaration const* declaration = _typeName.annotation().referencedDeclaration;
solAssert(declaration, ""); solAssert(declaration, "");
if (StructDefinition const* structDef = dynamic_cast<StructDefinition const*>(declaration)) if (StructDefinition const* structDef = dynamic_cast<StructDefinition const*>(declaration))
{
if (!m_insideFunctionType && !m_currentStructsSeen.empty())
structDef->accept(*this);
_typeName.annotation().type = TypeProvider::structType(*structDef, DataLocation::Storage); _typeName.annotation().type = TypeProvider::structType(*structDef, DataLocation::Storage);
}
else if (EnumDefinition const* enumDef = dynamic_cast<EnumDefinition const*>(declaration)) else if (EnumDefinition const* enumDef = dynamic_cast<EnumDefinition const*>(declaration))
_typeName.annotation().type = TypeProvider::enumType(*enumDef); _typeName.annotation().type = TypeProvider::enumType(*enumDef);
else if (ContractDefinition const* contract = dynamic_cast<ContractDefinition const*>(declaration)) else if (ContractDefinition const* contract = dynamic_cast<ContractDefinition const*>(declaration))
@ -75,8 +121,17 @@ void DeclarationTypeChecker::endVisit(UserDefinedTypeName const& _typeName)
fatalTypeError(_typeName.location(), "Name has to refer to a struct, enum or contract."); fatalTypeError(_typeName.location(), "Name has to refer to a struct, enum or contract.");
} }
} }
void DeclarationTypeChecker::endVisit(FunctionTypeName const& _typeName) bool DeclarationTypeChecker::visit(FunctionTypeName const& _typeName)
{ {
if (_typeName.annotation().type)
return false;
bool previousInsideFunctionType = m_insideFunctionType;
m_insideFunctionType = true;
_typeName.parameterTypeList()->accept(*this);
_typeName.returnParameterTypeList()->accept(*this);
m_insideFunctionType = previousInsideFunctionType;
switch (_typeName.visibility()) switch (_typeName.visibility())
{ {
case Visibility::Internal: case Visibility::Internal:
@ -84,30 +139,22 @@ void DeclarationTypeChecker::endVisit(FunctionTypeName const& _typeName)
break; break;
default: default:
fatalTypeError(_typeName.location(), "Invalid visibility, can only be \"external\" or \"internal\"."); fatalTypeError(_typeName.location(), "Invalid visibility, can only be \"external\" or \"internal\".");
return; return false;
} }
if (_typeName.isPayable() && _typeName.visibility() != Visibility::External) if (_typeName.isPayable() && _typeName.visibility() != Visibility::External)
{ {
fatalTypeError(_typeName.location(), "Only external function types can be payable."); fatalTypeError(_typeName.location(), "Only external function types can be payable.");
return; return false;
} }
if (_typeName.visibility() == Visibility::External)
for (auto const& t: _typeName.parameterTypes() + _typeName.returnParameterTypes())
{
solAssert(t->annotation().type, "Type not set for parameter.");
if (!t->annotation().type->interfaceType(false).get())
{
fatalTypeError(t->location(), "Internal type cannot be used for external function type.");
return;
}
}
_typeName.annotation().type = TypeProvider::function(_typeName); _typeName.annotation().type = TypeProvider::function(_typeName);
return false;
} }
void DeclarationTypeChecker::endVisit(Mapping const& _mapping) void DeclarationTypeChecker::endVisit(Mapping const& _mapping)
{ {
if (_mapping.annotation().type)
return;
if (auto const* typeName = dynamic_cast<UserDefinedTypeName const*>(&_mapping.keyType())) if (auto const* typeName = dynamic_cast<UserDefinedTypeName const*>(&_mapping.keyType()))
{ {
if (auto const* contractType = dynamic_cast<ContractType const*>(typeName->annotation().type)) if (auto const* contractType = dynamic_cast<ContractType const*>(typeName->annotation().type))
@ -140,6 +187,9 @@ void DeclarationTypeChecker::endVisit(Mapping const& _mapping)
void DeclarationTypeChecker::endVisit(ArrayTypeName const& _typeName) void DeclarationTypeChecker::endVisit(ArrayTypeName const& _typeName)
{ {
if (_typeName.annotation().type)
return;
TypePointer baseType = _typeName.baseType().annotation().type; TypePointer baseType = _typeName.baseType().annotation().type;
if (!baseType) if (!baseType)
{ {

View File

@ -53,10 +53,11 @@ private:
bool visit(ElementaryTypeName const& _typeName) override; bool visit(ElementaryTypeName const& _typeName) override;
void endVisit(UserDefinedTypeName const& _typeName) override; void endVisit(UserDefinedTypeName const& _typeName) override;
void endVisit(FunctionTypeName const& _typeName) override; bool visit(FunctionTypeName const& _typeName) override;
void endVisit(Mapping const& _mapping) override; void endVisit(Mapping const& _mapping) override;
void endVisit(ArrayTypeName const& _typeName) override; void endVisit(ArrayTypeName const& _typeName) override;
void endVisit(VariableDeclaration const& _variable) override; void endVisit(VariableDeclaration const& _variable) override;
bool visit(StructDefinition const& _struct) override;
/// Adds a new error to the list of errors. /// Adds a new error to the list of errors.
void typeError(langutil::SourceLocation const& _location, std::string const& _description); void typeError(langutil::SourceLocation const& _location, std::string const& _description);
@ -67,6 +68,9 @@ private:
langutil::ErrorReporter& m_errorReporter; langutil::ErrorReporter& m_errorReporter;
bool m_errorOccurred = false; bool m_errorOccurred = false;
langutil::EVMVersion m_evmVersion; langutil::EVMVersion m_evmVersion;
bool m_insideFunctionType = false;
bool m_recursiveStructSeen = false;
std::set<StructDefinition const*> m_currentStructsSeen;
}; };
} }

View File

@ -631,7 +631,15 @@ void TypeChecker::endVisit(FunctionTypeName const& _funType)
{ {
FunctionType const& fun = dynamic_cast<FunctionType const&>(*_funType.annotation().type); FunctionType const& fun = dynamic_cast<FunctionType const&>(*_funType.annotation().type);
if (fun.kind() == FunctionType::Kind::External) if (fun.kind() == FunctionType::Kind::External)
{
for (auto const& t: _funType.parameterTypes() + _funType.returnParameterTypes())
{
solAssert(t->annotation().type, "Type not set for parameter.");
if (!t->annotation().type->interfaceType(false).get())
m_errorReporter.typeError(t->location(), "Internal type cannot be used for external function type.");
}
solAssert(fun.interfaceType(false), "External function type uses internal types."); solAssert(fun.interfaceType(false), "External function type uses internal types.");
}
} }
bool TypeChecker::visit(InlineAssembly const& _inlineAssembly) bool TypeChecker::visit(InlineAssembly const& _inlineAssembly)

View File

@ -257,12 +257,13 @@ TypeNameAnnotation& TypeName::annotation() const
TypePointer StructDefinition::type() const TypePointer StructDefinition::type() const
{ {
solAssert(annotation().recursive.has_value(), "Requested struct type before DeclarationTypeChecker.");
return TypeProvider::typeType(TypeProvider::structType(*this, DataLocation::Storage)); return TypeProvider::typeType(TypeProvider::structType(*this, DataLocation::Storage));
} }
TypeDeclarationAnnotation& StructDefinition::annotation() const StructDeclarationAnnotation& StructDefinition::annotation() const
{ {
return initAnnotation<TypeDeclarationAnnotation>(); return initAnnotation<StructDeclarationAnnotation>();
} }
TypePointer EnumValue::type() const TypePointer EnumValue::type() const

View File

@ -607,7 +607,7 @@ public:
bool isVisibleInDerivedContracts() const override { return true; } bool isVisibleInDerivedContracts() const override { return true; }
bool isVisibleViaContractTypeAccess() const override { return true; } bool isVisibleViaContractTypeAccess() const override { return true; }
TypeDeclarationAnnotation& annotation() const override; StructDeclarationAnnotation& annotation() const override;
private: private:
std::vector<ASTPointer<VariableDeclaration>> m_members; std::vector<ASTPointer<VariableDeclaration>> m_members;

View File

@ -128,6 +128,12 @@ struct TypeDeclarationAnnotation: DeclarationAnnotation
std::string canonicalName; std::string canonicalName;
}; };
struct StructDeclarationAnnotation: TypeDeclarationAnnotation
{
/// Whether the struct is recursive. Will be filled in by the DeclarationTypeChecker.
std::optional<bool> recursive;
};
struct ContractDefinitionAnnotation: TypeDeclarationAnnotation, StructurallyDocumentedAnnotation struct ContractDefinitionAnnotation: TypeDeclarationAnnotation, StructurallyDocumentedAnnotation
{ {
/// List of functions without a body. Can also contain functions from base classes. /// List of functions without a body. Can also contain functions from base classes.

View File

@ -2175,32 +2175,57 @@ MemberList::MemberMap StructType::nativeMembers(ContractDefinition const*) const
TypeResult StructType::interfaceType(bool _inLibrary) const TypeResult StructType::interfaceType(bool _inLibrary) const
{ {
if (_inLibrary && m_interfaceType_library.has_value()) if (!_inLibrary)
return *m_interfaceType_library; {
if (!m_interfaceType.has_value())
if (!_inLibrary && m_interfaceType.has_value()) {
if (recursive())
m_interfaceType = TypeResult::err("Recursive type not allowed for public or external contract functions.");
else
{
TypeResult result{TypePointer{}};
for (ASTPointer<VariableDeclaration> const& member: m_struct.members())
{
if (!member->annotation().type)
{
result = TypeResult::err("Invalid type!");
break;
}
auto interfaceType = member->annotation().type->interfaceType(false);
if (!interfaceType.get())
{
solAssert(!interfaceType.message().empty(), "Expected detailed error message!");
result = interfaceType;
break;
}
}
if (result.message().empty())
m_interfaceType = TypeProvider::withLocation(this, DataLocation::Memory, true);
else
m_interfaceType = result;
}
}
return *m_interfaceType; return *m_interfaceType;
}
else if (m_interfaceType_library.has_value())
return *m_interfaceType_library;
TypeResult result{TypePointer{}}; TypeResult result{TypePointer{}};
m_recursive = false; util::BreadthFirstSearch<StructDefinition const*> breadthFirstSearch{{&m_struct}};
breadthFirstSearch.run(
auto visitor = [&]( [&](StructDefinition const* _struct, auto&& _addChild) {
StructDefinition const& _struct,
util::CycleDetector<StructDefinition>& _cycleDetector,
size_t /*_depth*/
)
{
// Check that all members have interface types. // Check that all members have interface types.
// Return an error if at least one struct member does not have a type. // Return an error if at least one struct member does not have a type.
// This might happen, for example, if the type of the member does not exist. // This might happen, for example, if the type of the member does not exist.
for (ASTPointer<VariableDeclaration> const& variable: _struct.members()) for (ASTPointer<VariableDeclaration> const& variable: _struct->members())
{ {
// If the struct member does not have a type return false. // If the struct member does not have a type return false.
// A TypeError is expected in this case. // A TypeError is expected in this case.
if (!variable->annotation().type) if (!variable->annotation().type)
{ {
result = TypeResult::err("Invalid type!"); result = TypeResult::err("Invalid type!");
breadthFirstSearch.abort();
return; return;
} }
@ -2210,58 +2235,47 @@ TypeResult StructType::interfaceType(bool _inLibrary) const
memberType = dynamic_cast<ArrayType const*>(memberType)->baseType(); memberType = dynamic_cast<ArrayType const*>(memberType)->baseType();
if (StructType const* innerStruct = dynamic_cast<StructType const*>(memberType)) if (StructType const* innerStruct = dynamic_cast<StructType const*>(memberType))
if (
innerStruct->m_recursive == true ||
_cycleDetector.run(innerStruct->structDefinition())
)
{ {
m_recursive = true; if (innerStruct->recursive() && !(_inLibrary && location() == DataLocation::Storage))
if (_inLibrary && location() == DataLocation::Storage)
continue;
else
{ {
result = TypeResult::err("Recursive structs can only be passed as storage pointers to libraries, not as memory objects to contract functions."); result = TypeResult::err(
"Recursive structs can only be passed as storage pointers to libraries, not as memory objects to contract functions."
);
breadthFirstSearch.abort();
return; return;
} }
else
_addChild(&innerStruct->structDefinition());
} }
else
{
auto iType = memberType->interfaceType(_inLibrary); auto iType = memberType->interfaceType(_inLibrary);
if (!iType.get()) if (!iType.get())
{ {
solAssert(!iType.message().empty(), "Expected detailed error message!"); solAssert(!iType.message().empty(), "Expected detailed error message!");
result = iType; result = iType;
breadthFirstSearch.abort();
return; return;
} }
} }
}; }
}
);
m_recursive = m_recursive.value() || (util::CycleDetector<StructDefinition>(visitor).run(structDefinition()) != nullptr);
std::string const recursiveErrMsg = "Recursive type not allowed for public or external contract functions.";
if (_inLibrary)
{
if (!result.message().empty()) if (!result.message().empty())
m_interfaceType_library = result; return result;
else if (location() == DataLocation::Storage)
if (location() == DataLocation::Storage)
m_interfaceType_library = this; m_interfaceType_library = this;
else else
m_interfaceType_library = TypeProvider::withLocation(this, DataLocation::Memory, true); m_interfaceType_library = TypeProvider::withLocation(this, DataLocation::Memory, true);
if (m_recursive.value())
m_interfaceType = TypeResult::err(recursiveErrMsg);
return *m_interfaceType_library; return *m_interfaceType_library;
} }
if (m_recursive.value()) bool StructType::recursive() const
m_interfaceType = TypeResult::err(recursiveErrMsg); {
else if (!result.message().empty()) solAssert(m_struct.annotation().recursive.has_value(), "Called StructType::recursive() before DeclarationTypeChecker.");
m_interfaceType = result; return *m_struct.annotation().recursive;
else
m_interfaceType = TypeProvider::withLocation(this, DataLocation::Memory, true);
return *m_interfaceType;
} }
std::unique_ptr<ReferenceType> StructType::copyForLocation(DataLocation _location, bool _isPointer) const std::unique_ptr<ReferenceType> StructType::copyForLocation(DataLocation _location, bool _isPointer) const
@ -2644,21 +2658,11 @@ FunctionType::FunctionType(FunctionTypeName const& _typeName):
for (auto const& t: _typeName.parameterTypes()) for (auto const& t: _typeName.parameterTypes())
{ {
solAssert(t->annotation().type, "Type not set for parameter."); solAssert(t->annotation().type, "Type not set for parameter.");
if (m_kind == Kind::External)
solAssert(
t->annotation().type->interfaceType(false).get(),
"Internal type used as parameter for external function."
);
m_parameterTypes.push_back(t->annotation().type); m_parameterTypes.push_back(t->annotation().type);
} }
for (auto const& t: _typeName.returnParameterTypes()) for (auto const& t: _typeName.returnParameterTypes())
{ {
solAssert(t->annotation().type, "Type not set for return parameter."); solAssert(t->annotation().type, "Type not set for return parameter.");
if (m_kind == Kind::External)
solAssert(
t->annotation().type->interfaceType(false).get(),
"Internal type used as return parameter for external function."
);
m_returnParameterTypes.push_back(t->annotation().type); m_returnParameterTypes.push_back(t->annotation().type);
} }

View File

@ -934,15 +934,7 @@ public:
Type const* encodingType() const override; Type const* encodingType() const override;
TypeResult interfaceType(bool _inLibrary) const override; TypeResult interfaceType(bool _inLibrary) const override;
bool recursive() const bool recursive() const;
{
if (m_recursive.has_value())
return m_recursive.value();
interfaceType(false);
return m_recursive.value();
}
std::unique_ptr<ReferenceType> copyForLocation(DataLocation _location, bool _isPointer) const override; std::unique_ptr<ReferenceType> copyForLocation(DataLocation _location, bool _isPointer) const override;
@ -971,7 +963,6 @@ private:
// Caches for interfaceType(bool) // Caches for interfaceType(bool)
mutable std::optional<TypeResult> m_interfaceType; mutable std::optional<TypeResult> m_interfaceType;
mutable std::optional<TypeResult> m_interfaceType_library; mutable std::optional<TypeResult> m_interfaceType_library;
mutable std::optional<bool> m_recursive;
}; };
/** /**

View File

@ -114,6 +114,10 @@ struct BreadthFirstSearch
} }
return *this; return *this;
} }
void abort()
{
verticesToTraverse.clear();
}
std::set<V> verticesToTraverse; std::set<V> verticesToTraverse;
std::set<V> visited{}; std::set<V> visited{};

View File

@ -184,6 +184,7 @@ BOOST_AUTO_TEST_CASE(type_identifiers)
BOOST_CHECK_EQUAL(ContractType(c, true).identifier(), "t_super$_MyContract$$$_$2"); BOOST_CHECK_EQUAL(ContractType(c, true).identifier(), "t_super$_MyContract$$$_$2");
StructDefinition s(++id, {}, make_shared<string>("Struct"), {}); StructDefinition s(++id, {}, make_shared<string>("Struct"), {});
s.annotation().recursive = false;
BOOST_CHECK_EQUAL(s.type()->identifier(), "t_type$_t_struct$_Struct_$3_storage_ptr_$"); BOOST_CHECK_EQUAL(s.type()->identifier(), "t_type$_t_struct$_Struct_$3_storage_ptr_$");
EnumDefinition e(++id, {}, make_shared<string>("Enum"), {}); EnumDefinition e(++id, {}, make_shared<string>("Enum"), {});

View File

@ -0,0 +1,10 @@
pragma experimental ABIEncoderV2;
contract C {
struct S {
uint a;
function() external returns (S memory) sub;
}
function f() public pure returns (S memory) {
}
}
// ----

View File

@ -4,4 +4,3 @@ contract C {
} }
// ---- // ----
// TypeError: (37-64): Data location must be "memory" for parameter in function, but "storage" was given. // TypeError: (37-64): Data location must be "memory" for parameter in function, but "storage" was given.
// TypeError: (37-64): Internal type cannot be used for external function type.

View File

@ -4,4 +4,3 @@ contract C {
} }
// ---- // ----
// TypeError: (57-84): Data location must be "memory" for return parameter in function, but "storage" was given. // TypeError: (57-84): Data location must be "memory" for return parameter in function, but "storage" was given.
// TypeError: (57-84): Internal type cannot be used for external function type.