From edf3529db60d7d63388a059d088bad0d48c269b3 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 2 Feb 2021 16:14:44 +0100 Subject: [PATCH] Syntax changes for catching custom errors. --- libsolidity/analysis/TypeChecker.cpp | 127 +++++++++++------- libsolidity/ast/AST.cpp | 7 +- libsolidity/ast/AST.h | 27 +++- libsolidity/ast/ASTJsonConverter.cpp | 18 ++- libsolidity/ast/ASTJsonConverter.h | 1 + libsolidity/ast/ASTJsonImporter.cpp | 21 ++- libsolidity/ast/ASTJsonImporter.h | 1 + libsolidity/ast/AST_accept.h | 4 + libsolidity/codegen/ContractCompiler.cpp | 8 +- libsolidity/parsing/Parser.cpp | 32 ++++- .../tryCatch/invalid_error_name.sol | 3 +- 11 files changed, 179 insertions(+), 70 deletions(-) diff --git a/libsolidity/analysis/TypeChecker.cpp b/libsolidity/analysis/TypeChecker.cpp index 2730b3014..097e9d2ba 100644 --- a/libsolidity/analysis/TypeChecker.cpp +++ b/libsolidity/analysis/TypeChecker.cpp @@ -994,9 +994,10 @@ void TypeChecker::endVisit(TryStatement const& _tryStatement) TryCatchClause const* panicClause = nullptr; TryCatchClause const* errorClause = nullptr; TryCatchClause const* lowLevelClause = nullptr; + map seenErrors; for (auto const& clause: _tryStatement.clauses() | ranges::views::drop_exactly(1) | views::dereferenceChecked) { - if (clause.errorName() == "") + if (clause.kind() == TryCatchClause::Kind::Fallback) { if (lowLevelClause) m_errorReporter.typeError( @@ -1022,59 +1023,85 @@ void TypeChecker::endVisit(TryStatement const& _tryStatement) "). You need at least a Byzantium-compatible EVM or use `catch { ... }`." ); } + continue; } - else if (clause.errorName() == "Error" || clause.errorName() == "Panic") - { - if (!m_evmVersion.supportsReturndata()) - m_errorReporter.typeError( - 1812_error, - clause.location(), - "This catch clause type cannot be used on the selected EVM version (" + - m_evmVersion.name() + - "). You need at least a Byzantium-compatible EVM or use `catch { ... }`." - ); - if (clause.errorName() == "Error") - { - if (errorClause) - m_errorReporter.typeError( - 1036_error, - clause.location(), - SecondarySourceLocation{}.append("The first clause is here:", errorClause->location()), - "This try statement already has an \"Error\" catch clause." - ); - errorClause = &clause; - if ( - !clause.parameters() || - clause.parameters()->parameters().size() != 1 || - *clause.parameters()->parameters().front()->type() != *TypeProvider::stringMemory() - ) - m_errorReporter.typeError(2943_error, clause.location(), "Expected `catch Error(string memory ...) { ... }`."); - } - else - { - if (panicClause) - m_errorReporter.typeError( - 6732_error, - clause.location(), - SecondarySourceLocation{}.append("The first clause is here:", panicClause->location()), - "This try statement already has a \"Panic\" catch clause." - ); - panicClause = &clause; - if ( - !clause.parameters() || - clause.parameters()->parameters().size() != 1 || - *clause.parameters()->parameters().front()->type() != *TypeProvider::uint256() - ) - m_errorReporter.typeError(1271_error, clause.location(), "Expected `catch Panic(uint ...) { ... }`."); - } + if (!m_evmVersion.supportsReturndata()) + m_errorReporter.typeError( + 1812_error, + clause.location(), + "This catch clause type cannot be used on the selected EVM version (" + + m_evmVersion.name() + + "). You need at least a Byzantium-compatible EVM or use `catch { ... }`." + ); + + if (clause.kind() == TryCatchClause::Kind::Error) + { + if (errorClause) + m_errorReporter.typeError( + 1036_error, + clause.location(), + SecondarySourceLocation{}.append("The first clause is here:", errorClause->location()), + "This try statement already has an \"Error\" catch clause." + ); + errorClause = &clause; + if ( + !clause.parameters() || + clause.parameters()->parameters().size() != 1 || + *clause.parameters()->parameters().front()->type() != *TypeProvider::stringMemory() + ) + m_errorReporter.typeError(2943_error, clause.location(), "Expected `catch Error(string memory ...) { ... }`."); + } + else if (clause.kind() == TryCatchClause::Kind::Panic) + { + if (panicClause) + m_errorReporter.typeError( + 6732_error, + clause.location(), + SecondarySourceLocation{}.append("The first clause is here:", panicClause->location()), + "This try statement already has a \"Panic\" catch clause." + ); + panicClause = &clause; + if ( + !clause.parameters() || + clause.parameters()->parameters().size() != 1 || + *clause.parameters()->parameters().front()->type() != *TypeProvider::uint256() + ) + m_errorReporter.typeError(1271_error, clause.location(), "Expected `catch Panic(uint ...) { ... }`."); } else - m_errorReporter.typeError( - 3542_error, - clause.location(), - "Invalid catch clause name. Expected either `catch (...)`, `catch Error(...)`, or `catch Panic(...)`." - ); + { + solAssert(clause.kind() == TryCatchClause::Kind::UserDefined, ""); + solAssert(*clause.errorName().annotation().requiredLookup == VirtualLookup::Static, ""); + ErrorDefinition const* error = dynamic_cast(clause.errorName().annotation().referencedDeclaration); + if (!error) + { + m_errorReporter.typeError(1178_error, clause.location(), "Expected the name of an error."); + continue; + } + + if (!seenErrors.emplace(error, &clause).second) + m_errorReporter.typeError( + 6853_error, + clause.location(), + SecondarySourceLocation{}.append("The first clause is here:", seenErrors[error]->location()), + "This try statement already has a \"" + error->name() + "\" catch clause." + ); + if ( + !clause.parameters() || + clause.parameters()->parameters().size() != error->parameters().size() + ) + m_errorReporter.typeError(1271_error, clause.location(), "Expected `catch Panic(uint ...) { ... }`."); + else + for (auto&& [varDecl, parameter]: ranges::views::zip(clause.parameters()->parameters(), error->parameters())) + if (*varDecl->type() != *parameter->type()) + m_errorReporter.typeError( + 63958_error, + varDecl->location(), + ("Expected a parameter of type \"" + parameter->type()->toString(true) + "\" ") + + ("but got \"" + varDecl->type()->toString(true) + "\"") + ); + } } } diff --git a/libsolidity/ast/AST.cpp b/libsolidity/ast/AST.cpp index 106214d36..9e67f7de5 100644 --- a/libsolidity/ast/AST.cpp +++ b/libsolidity/ast/AST.cpp @@ -856,13 +856,14 @@ string Literal::getChecksummedAddress() const TryCatchClause const* TryStatement::successClause() const { solAssert(m_clauses.size() > 0, ""); + solAssert(m_clauses.front()->kind() == TryCatchClause::Kind::Success, ""); return m_clauses[0].get(); } TryCatchClause const* TryStatement::panicClause() const { for (size_t i = 1; i < m_clauses.size(); ++i) - if (m_clauses[i]->errorName() == "Panic") + if (m_clauses[i]->kind() == TryCatchClause::Kind::Panic) return m_clauses[i].get(); return nullptr; } @@ -870,7 +871,7 @@ TryCatchClause const* TryStatement::panicClause() const TryCatchClause const* TryStatement::errorClause() const { for (size_t i = 1; i < m_clauses.size(); ++i) - if (m_clauses[i]->errorName() == "Error") + if (m_clauses[i]->kind() == TryCatchClause::Kind::Error) return m_clauses[i].get(); return nullptr; } @@ -878,7 +879,7 @@ TryCatchClause const* TryStatement::errorClause() const TryCatchClause const* TryStatement::fallbackClause() const { for (size_t i = 1; i < m_clauses.size(); ++i) - if (m_clauses[i]->errorName().empty()) + if (m_clauses[i]->kind() == TryCatchClause::Kind::Fallback) return m_clauses[i].get(); return nullptr; } diff --git a/libsolidity/ast/AST.h b/libsolidity/ast/AST.h index b409c84cd..ef40142ec 100644 --- a/libsolidity/ast/AST.h +++ b/libsolidity/ast/AST.h @@ -1503,29 +1503,46 @@ private: class TryCatchClause: public ASTNode, public Scopable, public ScopeOpener { public: + enum Kind + { + Success, + Panic, + Error, + Fallback, + UserDefined + }; + TryCatchClause( int64_t _id, SourceLocation const& _location, - ASTPointer _errorName, + Kind _kind, + ASTPointer _errorName, ASTPointer _parameters, ASTPointer _block ): ASTNode(_id, _location), + m_kind(_kind), m_errorName(std::move(_errorName)), m_parameters(std::move(_parameters)), m_block(std::move(_block)) - {} + { + solAssert(!!m_errorName == (m_kind == Kind::UserDefined), ""); + } void accept(ASTVisitor& _visitor) override; void accept(ASTConstVisitor& _visitor) const override; - ASTString const& errorName() const { return *m_errorName; } + + Kind kind() const { return m_kind; } + /// @returns the name of the error. Should only be called if catch kind is UserDefined. + IdentifierPath const& errorName() const { return *m_errorName; } ParameterList const* parameters() const { return m_parameters.get(); } Block const& block() const { return *m_block; } TryCatchClauseAnnotation& annotation() const override; private: - ASTPointer m_errorName; + Kind m_kind; + ASTPointer m_errorName; ASTPointer m_parameters; ASTPointer m_block; }; @@ -1535,6 +1552,8 @@ private: * Syntax: * try returns (uint x, uint y) { * // success code + * } catch Custom(uint data) { + * // custom error handler * } catch Panic(uint errorCode) { * // panic * } catch Error(string memory cause) { diff --git a/libsolidity/ast/ASTJsonConverter.cpp b/libsolidity/ast/ASTJsonConverter.cpp index d8139714e..cbd772c12 100644 --- a/libsolidity/ast/ASTJsonConverter.cpp +++ b/libsolidity/ast/ASTJsonConverter.cpp @@ -578,7 +578,10 @@ bool ASTJsonConverter::visit(IfStatement const& _node) bool ASTJsonConverter::visit(TryCatchClause const& _node) { setJsonNode(_node, "TryCatchClause", { - make_pair("errorName", _node.errorName()), + make_pair("kind", catchClauseKind(_node.kind())), + make_pair("errorName", toJsonOrNull( + _node.kind() == TryCatchClause::Kind::UserDefined ? &_node.errorName() : nullptr + )), make_pair("parameters", toJsonOrNull(_node.parameters())), make_pair("block", toJson(_node.block())) }); @@ -931,6 +934,19 @@ string ASTJsonConverter::functionCallKind(FunctionCallKind _kind) } } +string ASTJsonConverter::catchClauseKind(TryCatchClause::Kind _kind) +{ + switch (_kind) + { + case TryCatchClause::Kind::Success: return "success"; + case TryCatchClause::Kind::Panic: return "panic"; + case TryCatchClause::Kind::Error: return "error"; + case TryCatchClause::Kind::Fallback: return "fallback"; + case TryCatchClause::Kind::UserDefined: return "userDefined"; + } + solAssert(false, "Unknown kind of catch clause."); +} + string ASTJsonConverter::literalTokenKind(Token _token) { switch (_token) diff --git a/libsolidity/ast/ASTJsonConverter.h b/libsolidity/ast/ASTJsonConverter.h index 551c972a6..113c5f7e1 100644 --- a/libsolidity/ast/ASTJsonConverter.h +++ b/libsolidity/ast/ASTJsonConverter.h @@ -153,6 +153,7 @@ private: static std::string location(VariableDeclaration::Location _location); static std::string contractKind(ContractKind _kind); static std::string functionCallKind(FunctionCallKind _kind); + static std::string catchClauseKind(TryCatchClause::Kind _kind); static std::string literalTokenKind(Token _token); static std::string type(Expression const& _expression); static std::string type(VariableDeclaration const& _varDecl); diff --git a/libsolidity/ast/ASTJsonImporter.cpp b/libsolidity/ast/ASTJsonImporter.cpp index 41647c55a..8f2e99c0f 100644 --- a/libsolidity/ast/ASTJsonImporter.cpp +++ b/libsolidity/ast/ASTJsonImporter.cpp @@ -642,7 +642,8 @@ ASTPointer ASTJsonImporter::createTryCatchClause(Json::Value con { return createASTNode( _node, - memberAsASTString(_node, "errorName"), + tryCatchClauseKind(_node), + nullOrCast(member(_node, "errorName")), nullOrCast(member(_node, "parameters")), convertJsonToASTNode(member(_node, "block")) ); @@ -959,6 +960,24 @@ bool ASTJsonImporter::memberAsBool(Json::Value const& _node, string const& _name // =========== JSON to definition helpers ======================= +TryCatchClause::Kind ASTJsonImporter::tryCatchClauseKind(Json::Value const& _node) +{ + astAssert(!member(_node, "kind").isNull(), ""); + if (_node["kind"].asString() == "success") + return TryCatchClause::Kind::Success; + else if (_node["kind"].asString() == "error") + return TryCatchClause::Kind::Error; + else if (_node["kind"].asString() == "panic") + return TryCatchClause::Kind::Panic; + else if (_node["kind"].asString() == "fallback") + return TryCatchClause::Kind::Fallback; + else if (_node["kind"].asString() == "userDefined") + return TryCatchClause::Kind::UserDefined; + else + astAssert(false, "Unknown try catch clause kind"); + return {}; +} + ContractKind ASTJsonImporter::contractKind(Json::Value const& _node) { ContractKind kind; diff --git a/libsolidity/ast/ASTJsonImporter.h b/libsolidity/ast/ASTJsonImporter.h index a2e8b6410..767ef6d50 100644 --- a/libsolidity/ast/ASTJsonImporter.h +++ b/libsolidity/ast/ASTJsonImporter.h @@ -144,6 +144,7 @@ private: Visibility visibility(Json::Value const& _node); StateMutability stateMutability(Json::Value const& _node); VariableDeclaration::Location location(Json::Value const& _node); + TryCatchClause::Kind tryCatchClauseKind(Json::Value const& _node); ContractKind contractKind(Json::Value const& _node); Token literalTokenKind(Json::Value const& _node); Literal::SubDenomination subdenomination(Json::Value const& _node); diff --git a/libsolidity/ast/AST_accept.h b/libsolidity/ast/AST_accept.h index 678a9ef7e..d6b20cb80 100644 --- a/libsolidity/ast/AST_accept.h +++ b/libsolidity/ast/AST_accept.h @@ -542,6 +542,8 @@ void TryCatchClause::accept(ASTVisitor& _visitor) { if (_visitor.visit(*this)) { + if (m_errorName) + m_errorName->accept(_visitor); if (m_parameters) m_parameters->accept(_visitor); m_block->accept(_visitor); @@ -553,6 +555,8 @@ void TryCatchClause::accept(ASTConstVisitor& _visitor) const { if (_visitor.visit(*this)) { + if (m_errorName) + m_errorName->accept(_visitor); if (m_parameters) m_parameters->accept(_visitor); m_block->accept(_visitor); diff --git a/libsolidity/codegen/ContractCompiler.cpp b/libsolidity/codegen/ContractCompiler.cpp index 9ce2b743d..410eae4c4 100644 --- a/libsolidity/codegen/ContractCompiler.cpp +++ b/libsolidity/codegen/ContractCompiler.cpp @@ -975,14 +975,14 @@ void ContractCompiler::handleCatch(vector> const& _ca ASTPointer panic{}; ASTPointer fallback{}; for (size_t i = 1; i < _catchClauses.size(); ++i) - if (_catchClauses[i]->errorName() == "Error") + if (_catchClauses[i]->kind() == TryCatchClause::Kind::Error) error = _catchClauses[i]; - else if (_catchClauses[i]->errorName() == "Panic") + else if (_catchClauses[i]->kind() == TryCatchClause::Kind::Panic) panic = _catchClauses[i]; - else if (_catchClauses[i]->errorName().empty()) + else if (_catchClauses[i]->kind() == TryCatchClause::Kind::Fallback) fallback = _catchClauses[i]; else - solAssert(false, ""); + solUnimplementedAssert(false, ""); solAssert(_catchClauses.size() == 1ul + (error ? 1 : 0) + (panic ? 1 : 0) + (fallback ? 1 : 0), ""); diff --git a/libsolidity/parsing/Parser.cpp b/libsolidity/parsing/Parser.cpp index a648cca43..49958d328 100644 --- a/libsolidity/parsing/Parser.cpp +++ b/libsolidity/parsing/Parser.cpp @@ -1323,7 +1323,7 @@ ASTPointer Parser::parseTryStatement(ASTPointer const& ASTPointer successBlock = parseBlock(); successClauseFactory.setEndPositionFromNode(successBlock); clauses.emplace_back(successClauseFactory.createNode( - make_shared(), returnsParameters, successBlock + TryCatchClause::Kind::Success, ASTPointer{}, returnsParameters, successBlock )); do @@ -1342,20 +1342,42 @@ ASTPointer Parser::parseCatchClause() RecursionGuard recursionGuard(*this); ASTNodeFactory nodeFactory(*this); expectToken(Token::Catch); - ASTPointer errorName = make_shared(); + TryCatchClause::Kind kind; + ASTPointer errorName; ASTPointer errorParameters; if (m_scanner->currentToken() != Token::LBrace) { if (m_scanner->currentToken() == Token::Identifier) - errorName = expectIdentifierToken(); + { + string name = m_scanner->currentLiteral(); + if (name == "Error") + { + kind = TryCatchClause::Kind::Error; + m_scanner->next(); + } + else if (name == "Panic") + { + kind = TryCatchClause::Kind::Panic; + m_scanner->next(); + } + else + { + errorName = parseIdentifierPath(); + kind = TryCatchClause::Kind::UserDefined; + } + } + else + kind = TryCatchClause::Kind::Fallback; VarDeclParserOptions options; options.allowEmptyName = true; options.allowLocationSpecifier = true; - errorParameters = parseParameterList(options, !errorName->empty()); + errorParameters = parseParameterList(options); } + else + kind = TryCatchClause::Kind::Fallback; ASTPointer block = parseBlock(); nodeFactory.setEndPositionFromNode(block); - return nodeFactory.createNode(errorName, errorParameters, block); + return nodeFactory.createNode(kind, errorName, errorParameters, block); } ASTPointer Parser::parseWhileStatement(ASTPointer const& _docString) diff --git a/test/libsolidity/syntaxTests/tryCatch/invalid_error_name.sol b/test/libsolidity/syntaxTests/tryCatch/invalid_error_name.sol index d777065e6..27c67b505 100644 --- a/test/libsolidity/syntaxTests/tryCatch/invalid_error_name.sol +++ b/test/libsolidity/syntaxTests/tryCatch/invalid_error_name.sol @@ -7,5 +7,4 @@ contract C { } } // ---- -// TypeError 3542: (93-119): Invalid catch clause name. Expected either `catch (...)`, `catch Error(...)`, or `catch Panic(...)`. -// TypeError 3542: (120-143): Invalid catch clause name. Expected either `catch (...)`, `catch Error(...)`, or `catch Panic(...)`. +// DeclarationError 7920: (99-105): Identifier not found or not unique.