From 527c073bb90ce019f6eb605f165c137237f1f684 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 22 Jul 2020 10:28:04 +0200 Subject: [PATCH] Checked arithmetic by default. --- Changelog.md | 3 + liblangutil/Token.h | 2 +- libsolidity/analysis/SyntaxChecker.cpp | 31 ++- libsolidity/analysis/SyntaxChecker.h | 6 + libsolidity/ast/AST.h | 8 +- libsolidity/ast/ASTEnums.h | 2 + libsolidity/ast/ASTJsonConverter.cpp | 2 +- libsolidity/ast/ASTJsonImporter.cpp | 11 +- libsolidity/ast/ASTJsonImporter.h | 2 +- libsolidity/codegen/CompilerContext.h | 5 + libsolidity/codegen/ContractCompiler.cpp | 14 ++ libsolidity/codegen/ExpressionCompiler.cpp | 139 +++++++++---- libsolidity/codegen/ExpressionCompiler.h | 2 +- libsolidity/codegen/YulUtilFunctions.cpp | 184 +++++++++++++++--- libsolidity/codegen/YulUtilFunctions.h | 20 +- libsolidity/codegen/ir/IRGenerationContext.h | 5 + .../codegen/ir/IRGeneratorForStatements.cpp | 41 ++-- .../codegen/ir/IRGeneratorForStatements.h | 2 + libsolidity/parsing/Parser.cpp | 19 +- libsolidity/parsing/Parser.h | 4 +- .../SolidityExpressionCompiler.cpp | 1 + 21 files changed, 405 insertions(+), 98 deletions(-) diff --git a/Changelog.md b/Changelog.md index 831c121c3..88ae1d1e5 100644 --- a/Changelog.md +++ b/Changelog.md @@ -15,6 +15,9 @@ Language Features: * New AST Node ``IdentifierPath`` replacing in many places the ``UserDefinedTypeName`` +AST Changes: + * New node type: unchecked block - used for ``unchecked { ... }``. + ### 0.7.4 (unreleased) Important Bugfixes: diff --git a/liblangutil/Token.h b/liblangutil/Token.h index 183d23887..2eb97054e 100644 --- a/liblangutil/Token.h +++ b/liblangutil/Token.h @@ -191,6 +191,7 @@ namespace solidity::langutil K(Throw, "throw", 0) \ K(Try, "try", 0) \ K(Type, "type", 0) \ + K(Unchecked, "unchecked", 0) \ K(Unicode, "unicode", 0) \ K(Using, "using", 0) \ K(View, "view", 0) \ @@ -266,7 +267,6 @@ namespace solidity::langutil K(Switch, "switch", 0) \ K(Typedef, "typedef", 0) \ K(TypeOf, "typeof", 0) \ - K(Unchecked, "unchecked", 0) \ K(Var, "var", 0) \ \ /* Yul-specific tokens, but not keywords. */ \ diff --git a/libsolidity/analysis/SyntaxChecker.cpp b/libsolidity/analysis/SyntaxChecker.cpp index 076592053..0fc1fa4ab 100644 --- a/libsolidity/analysis/SyntaxChecker.cpp +++ b/libsolidity/analysis/SyntaxChecker.cpp @@ -190,6 +190,28 @@ void SyntaxChecker::endVisit(ForStatement const&) m_inLoopDepth--; } +bool SyntaxChecker::visit(Block const& _block) +{ + if (_block.unchecked()) + { + if (m_uncheckedArithmetic) + m_errorReporter.syntaxError( + 1941_error, + _block.location(), + "\"unchecked\" blocks cannot be nested." + ); + + m_uncheckedArithmetic = true; + } + return true; +} + +void SyntaxChecker::endVisit(Block const& _block) +{ + if (_block.unchecked()) + m_uncheckedArithmetic = false; +} + bool SyntaxChecker::visit(Continue const& _continueStatement) { if (m_inLoopDepth <= 0) @@ -288,8 +310,15 @@ bool SyntaxChecker::visit(InlineAssembly const& _inlineAssembly) return false; } -bool SyntaxChecker::visit(PlaceholderStatement const&) +bool SyntaxChecker::visit(PlaceholderStatement const& _placeholder) { + if (m_uncheckedArithmetic) + m_errorReporter.syntaxError( + 2573_error, + _placeholder.location(), + "The placeholder statement \"_\" cannot be used inside an \"unchecked\" block." + ); + m_placeholderFound = true; return true; } diff --git a/libsolidity/analysis/SyntaxChecker.h b/libsolidity/analysis/SyntaxChecker.h index 52c82ff6e..46f6c486e 100644 --- a/libsolidity/analysis/SyntaxChecker.h +++ b/libsolidity/analysis/SyntaxChecker.h @@ -71,6 +71,9 @@ private: bool visit(ForStatement const& _forStatement) override; void endVisit(ForStatement const& _forStatement) override; + bool visit(Block const& _block) override; + void endVisit(Block const& _block) override; + bool visit(Continue const& _continueStatement) override; bool visit(Break const& _breakStatement) override; @@ -100,6 +103,9 @@ private: /// Flag that indicates whether some version pragma was present. bool m_versionPragmaFound = false; + /// Flag that indicates whether we are inside an unchecked block. + bool m_uncheckedArithmetic = false; + int m_inLoopDepth = 0; std::optional m_currentContractKind; diff --git a/libsolidity/ast/AST.h b/libsolidity/ast/AST.h index 505028a9f..a4d1f1654 100644 --- a/libsolidity/ast/AST.h +++ b/libsolidity/ast/AST.h @@ -1384,18 +1384,24 @@ public: int64_t _id, SourceLocation const& _location, ASTPointer const& _docString, + bool _unchecked, std::vector> _statements ): - Statement(_id, _location, _docString), m_statements(std::move(_statements)) {} + Statement(_id, _location, _docString), + m_statements(std::move(_statements)), + m_unchecked(_unchecked) + {} void accept(ASTVisitor& _visitor) override; void accept(ASTConstVisitor& _visitor) const override; std::vector> const& statements() const { return m_statements; } + bool unchecked() const { return m_unchecked; } BlockAnnotation& annotation() const override; private: std::vector> m_statements; + bool m_unchecked; }; /** diff --git a/libsolidity/ast/ASTEnums.h b/libsolidity/ast/ASTEnums.h index 09dff7337..87f9817c3 100644 --- a/libsolidity/ast/ASTEnums.h +++ b/libsolidity/ast/ASTEnums.h @@ -39,6 +39,8 @@ enum class StateMutability { Pure, View, NonPayable, Payable }; /// Visibility ordered from restricted to unrestricted. enum class Visibility { Default, Private, Internal, Public, External }; +enum class Arithmetic { Checked, Wrapping }; + inline std::string stateMutabilityToString(StateMutability const& _stateMutability) { switch (_stateMutability) diff --git a/libsolidity/ast/ASTJsonConverter.cpp b/libsolidity/ast/ASTJsonConverter.cpp index c1137f2b9..0d24726db 100644 --- a/libsolidity/ast/ASTJsonConverter.cpp +++ b/libsolidity/ast/ASTJsonConverter.cpp @@ -596,7 +596,7 @@ bool ASTJsonConverter::visit(InlineAssembly const& _node) bool ASTJsonConverter::visit(Block const& _node) { - setJsonNode(_node, "Block", { + setJsonNode(_node, _node.unchecked() ? "UncheckedBlock" : "Block", { make_pair("statements", toJson(_node.statements())) }); return false; diff --git a/libsolidity/ast/ASTJsonImporter.cpp b/libsolidity/ast/ASTJsonImporter.cpp index 1dec4b4cb..a3d88b435 100644 --- a/libsolidity/ast/ASTJsonImporter.cpp +++ b/libsolidity/ast/ASTJsonImporter.cpp @@ -154,7 +154,9 @@ ASTPointer ASTJsonImporter::convertJsonToASTNode(Json::Value const& _js if (nodeType == "InlineAssembly") return createInlineAssembly(_json); if (nodeType == "Block") - return createBlock(_json); + return createBlock(_json, false); + if (nodeType == "UncheckedBlock") + return createBlock(_json, true); if (nodeType == "PlaceholderStatement") return createPlaceholderStatement(_json); if (nodeType == "IfStatement") @@ -439,7 +441,7 @@ ASTPointer ASTJsonImporter::createFunctionDefinition(Json::V createParameterList(member(_node, "parameters")), modifiers, createParameterList(member(_node, "returnParameters")), - memberAsBool(_node, "implemented") ? createBlock(member(_node, "body")) : nullptr + memberAsBool(_node, "implemented") ? createBlock(member(_node, "body"), false) : nullptr ); } @@ -489,7 +491,7 @@ ASTPointer ASTJsonImporter::createModifierDefinition(Json::V createParameterList(member(_node, "parameters")), memberAsBool(_node, "virtual"), _node["overrides"].isNull() ? nullptr : createOverrideSpecifier(member(_node, "overrides")), - _node["body"].isNull() ? nullptr: createBlock(member(_node, "body")) + _node["body"].isNull() ? nullptr: createBlock(member(_node, "body"), false) ); } @@ -589,7 +591,7 @@ ASTPointer ASTJsonImporter::createInlineAssembly(Json::Value con ); } -ASTPointer ASTJsonImporter::createBlock(Json::Value const& _node) +ASTPointer ASTJsonImporter::createBlock(Json::Value const& _node, bool _unchecked) { std::vector> statements; for (auto& stat: member(_node, "statements")) @@ -597,6 +599,7 @@ ASTPointer ASTJsonImporter::createBlock(Json::Value const& _node) return createASTNode( _node, nullOrASTString(_node, "documentation"), + _unchecked, statements ); } diff --git a/libsolidity/ast/ASTJsonImporter.h b/libsolidity/ast/ASTJsonImporter.h index 824249e59..c64790d4d 100644 --- a/libsolidity/ast/ASTJsonImporter.h +++ b/libsolidity/ast/ASTJsonImporter.h @@ -93,7 +93,7 @@ private: ASTPointer createMapping(Json::Value const& _node); ASTPointer createArrayTypeName(Json::Value const& _node); ASTPointer createInlineAssembly(Json::Value const& _node); - ASTPointer createBlock(Json::Value const& _node); + ASTPointer createBlock(Json::Value const& _node, bool _unchecked); ASTPointer createPlaceholderStatement(Json::Value const& _node); ASTPointer createIfStatement(Json::Value const& _node); ASTPointer createTryCatchClause(Json::Value const& _node); diff --git a/libsolidity/codegen/CompilerContext.h b/libsolidity/codegen/CompilerContext.h index bd37c3686..b552bf6b1 100644 --- a/libsolidity/codegen/CompilerContext.h +++ b/libsolidity/codegen/CompilerContext.h @@ -123,6 +123,9 @@ public: void setMostDerivedContract(ContractDefinition const& _contract) { m_mostDerivedContract = &_contract; } ContractDefinition const& mostDerivedContract() const; + void setArithmetic(Arithmetic _value) { m_arithmetic = _value; } + Arithmetic arithmetic() const { return m_arithmetic; } + /// @returns the next function in the queue of functions that are still to be compiled /// (i.e. that were referenced during compilation but where we did not yet generate code for). /// Returns nullptr if the queue is empty. Does not remove the function from the queue, @@ -380,6 +383,8 @@ private: std::map> m_localVariables; /// The contract currently being compiled. Virtual function lookup starts from this contarct. ContractDefinition const* m_mostDerivedContract = nullptr; + /// Whether to use checked arithmetic. + Arithmetic m_arithmetic = Arithmetic::Checked; /// Stack of current visited AST nodes, used for location attachment std::stack m_visitedNodes; /// The runtime context if in Creation mode, this is used for generating tags that would be stored into the storage and then used at runtime. diff --git a/libsolidity/codegen/ContractCompiler.cpp b/libsolidity/codegen/ContractCompiler.cpp index 35780fdc2..eb3081e13 100644 --- a/libsolidity/codegen/ContractCompiler.cpp +++ b/libsolidity/codegen/ContractCompiler.cpp @@ -1247,19 +1247,31 @@ bool ContractCompiler::visit(PlaceholderStatement const& _placeholderStatement) { StackHeightChecker checker(m_context); CompilerContext::LocationSetter locationSetter(m_context, _placeholderStatement); + solAssert(m_context.arithmetic() == Arithmetic::Checked, "Placeholder cannot be used inside checked block."); appendModifierOrFunctionCode(); + solAssert(m_context.arithmetic() == Arithmetic::Checked, "Arithmetic not reset to 'checked'."); checker.check(); return true; } bool ContractCompiler::visit(Block const& _block) { + if (_block.unchecked()) + { + solAssert(m_context.arithmetic() == Arithmetic::Checked, ""); + m_context.setArithmetic(Arithmetic::Wrapping); + } storeStackHeight(&_block); return true; } void ContractCompiler::endVisit(Block const& _block) { + if (_block.unchecked()) + { + solAssert(m_context.arithmetic() == Arithmetic::Wrapping, ""); + m_context.setArithmetic(Arithmetic::Checked); + } // Frees local variables declared in the scope of this block. popScopedVariables(&_block); } @@ -1327,6 +1339,8 @@ void ContractCompiler::appendModifierOrFunctionCode() if (codeBlock) { + m_context.setArithmetic(Arithmetic::Checked); + std::set experimentalFeaturesOutside = m_context.experimentalFeaturesActive(); m_context.setExperimentalFeatures(codeBlock->sourceUnit().annotation().experimentalFeatures); diff --git a/libsolidity/codegen/ExpressionCompiler.cpp b/libsolidity/codegen/ExpressionCompiler.cpp index 9dff56afc..7303f1429 100644 --- a/libsolidity/codegen/ExpressionCompiler.cpp +++ b/libsolidity/codegen/ExpressionCompiler.cpp @@ -275,7 +275,7 @@ bool ExpressionCompiler::visit(Assignment const& _assignment) solAssert(*_assignment.annotation().type == leftType, ""); bool cleanupNeeded = false; if (op != Token::Assign) - cleanupNeeded = cleanupNeededForOp(leftType.category(), binOp); + cleanupNeeded = cleanupNeededForOp(leftType.category(), binOp, m_context.arithmetic()); _assignment.rightHandSide().accept(*this); // Perform some conversion already. This will convert storage types to memory and literals // to their actual type, but will not convert e.g. memory to storage. @@ -381,9 +381,10 @@ bool ExpressionCompiler::visit(TupleExpression const& _tuple) bool ExpressionCompiler::visit(UnaryOperation const& _unaryOperation) { CompilerContext::LocationSetter locationSetter(m_context, _unaryOperation); - if (_unaryOperation.annotation().type->category() == Type::Category::RationalNumber) + Type const& type = *_unaryOperation.annotation().type; + if (type.category() == Type::Category::RationalNumber) { - m_context << _unaryOperation.annotation().type->literalValue(nullptr); + m_context << type.literalValue(nullptr); return false; } @@ -406,24 +407,39 @@ bool ExpressionCompiler::visit(UnaryOperation const& _unaryOperation) case Token::Dec: // -- (pre- or postfix) solAssert(!!m_currentLValue, "LValue not retrieved."); solUnimplementedAssert( - _unaryOperation.annotation().type->category() != Type::Category::FixedPoint, + type.category() != Type::Category::FixedPoint, "Not yet implemented - FixedPointType." ); m_currentLValue->retrieveValue(_unaryOperation.location()); if (!_unaryOperation.isPrefixOperation()) { // store value for later - solUnimplementedAssert(_unaryOperation.annotation().type->sizeOnStack() == 1, "Stack size != 1 not implemented."); + solUnimplementedAssert(type.sizeOnStack() == 1, "Stack size != 1 not implemented."); m_context << Instruction::DUP1; if (m_currentLValue->sizeOnStack() > 0) for (unsigned i = 1 + m_currentLValue->sizeOnStack(); i > 0; --i) m_context << swapInstruction(i); } - m_context << u256(1); if (_unaryOperation.getOperator() == Token::Inc) - m_context << Instruction::ADD; + { + if (m_context.arithmetic() == Arithmetic::Checked) + m_context.callYulFunction(m_context.utilFunctions().incrementCheckedFunction(type), 1, 1); + else + { + m_context << u256(1); + m_context << Instruction::ADD; + } + } else - m_context << Instruction::SWAP1 << Instruction::SUB; + { + if (m_context.arithmetic() == Arithmetic::Checked) + m_context.callYulFunction(m_context.utilFunctions().decrementCheckedFunction(type), 1, 1); + else + { + m_context << u256(1); + m_context << Instruction::SWAP1 << Instruction::SUB; + } + } // Stack for prefix: [ref...] (*ref)+-1 // Stack for postfix: *ref [ref...] (*ref)+-1 for (unsigned i = m_currentLValue->sizeOnStack(); i > 0; --i) @@ -437,7 +453,10 @@ bool ExpressionCompiler::visit(UnaryOperation const& _unaryOperation) // unary add, so basically no-op break; case Token::Sub: // - - m_context << u256(0) << Instruction::SUB; + if (m_context.arithmetic() == Arithmetic::Checked) + m_context.callYulFunction(m_context.utilFunctions().negateNumberCheckedFunction(type), 1, 1); + else + m_context << u256(0) << Instruction::SUB; break; default: solAssert(false, "Invalid unary operator: " + string(TokenTraits::toString(_unaryOperation.getOperator()))); @@ -460,7 +479,7 @@ bool ExpressionCompiler::visit(BinaryOperation const& _binaryOperation) m_context << commonType->literalValue(nullptr); else { - bool cleanupNeeded = cleanupNeededForOp(commonType->category(), c_op); + bool cleanupNeeded = cleanupNeededForOp(commonType->category(), c_op, m_context.arithmetic()); TypePointer leftTargetType = commonType; TypePointer rightTargetType = @@ -2112,34 +2131,65 @@ void ExpressionCompiler::appendArithmeticOperatorCode(Token _operator, Type cons solUnimplemented("Not yet implemented - FixedPointType."); IntegerType const& type = dynamic_cast(_type); - bool const c_isSigned = type.isSigned(); - - switch (_operator) + if (m_context.arithmetic() == Arithmetic::Checked) { - case Token::Add: - m_context << Instruction::ADD; - break; - case Token::Sub: - m_context << Instruction::SUB; - break; - case Token::Mul: - m_context << Instruction::MUL; - break; - case Token::Div: - case Token::Mod: - { - // Test for division by zero - m_context << Instruction::DUP2 << Instruction::ISZERO; - m_context.appendConditionalInvalid(); - - if (_operator == Token::Div) - m_context << (c_isSigned ? Instruction::SDIV : Instruction::DIV); - else - m_context << (c_isSigned ? Instruction::SMOD : Instruction::MOD); - break; + string functionName; + switch (_operator) + { + case Token::Add: + functionName = m_context.utilFunctions().overflowCheckedIntAddFunction(type); + break; + case Token::Sub: + functionName = m_context.utilFunctions().overflowCheckedIntSubFunction(type); + break; + case Token::Mul: + functionName = m_context.utilFunctions().overflowCheckedIntMulFunction(type); + break; + case Token::Div: + functionName = m_context.utilFunctions().overflowCheckedIntDivFunction(type); + break; + case Token::Mod: + functionName = m_context.utilFunctions().intModFunction(type); + break; + case Token::Exp: + // EXP is handled in a different function. + default: + solAssert(false, "Unknown arithmetic operator."); + } + // TODO Maybe we want to force-inline this? + m_context.callYulFunction(functionName, 2, 1); } - default: - solAssert(false, "Unknown arithmetic operator."); + else + { + bool const c_isSigned = type.isSigned(); + + switch (_operator) + { + case Token::Add: + m_context << Instruction::ADD; + break; + case Token::Sub: + m_context << Instruction::SUB; + break; + case Token::Mul: + m_context << Instruction::MUL; + break; + case Token::Div: + case Token::Mod: + { + // Test for division by zero + m_context << Instruction::DUP2 << Instruction::ISZERO; + m_context.appendConditionalInvalid(); + + if (_operator == Token::Div) + m_context << (c_isSigned ? Instruction::SDIV : Instruction::DIV); + else + m_context << (c_isSigned ? Instruction::SMOD : Instruction::MOD); + break; + } + default: + solAssert(false, "Unknown arithmetic operator."); + } } } @@ -2237,7 +2287,14 @@ void ExpressionCompiler::appendExpOperatorCode(Type const& _valueType, Type cons solAssert(_valueType.category() == Type::Category::Integer, ""); solAssert(!dynamic_cast(_exponentType).isSigned(), ""); - m_context << Instruction::EXP; + + if (m_context.arithmetic() == Arithmetic::Checked) + m_context.callYulFunction(m_context.utilFunctions().overflowCheckedIntExpFunction( + dynamic_cast(_valueType), + dynamic_cast(_exponentType) + ), 2, 1); + else + m_context << Instruction::EXP; } void ExpressionCompiler::appendExternalFunctionCall( @@ -2561,11 +2618,15 @@ void ExpressionCompiler::setLValueToStorageItem(Expression const& _expression) setLValue(_expression, *_expression.annotation().type); } -bool ExpressionCompiler::cleanupNeededForOp(Type::Category _type, Token _op) +bool ExpressionCompiler::cleanupNeededForOp(Type::Category _type, Token _op, Arithmetic _arithmetic) { if (TokenTraits::isCompareOp(_op) || TokenTraits::isShiftOp(_op)) return true; - else if (_type == Type::Category::Integer && (_op == Token::Div || _op == Token::Mod || _op == Token::Exp)) + else if ( + _arithmetic == Arithmetic::Wrapping && + _type == Type::Category::Integer && + (_op == Token::Div || _op == Token::Mod || _op == Token::Exp) + ) // We need cleanup for EXP because 0**0 == 1, but 0**0x100 == 0 // It would suffice to clean the exponent, though. return true; diff --git a/libsolidity/codegen/ExpressionCompiler.h b/libsolidity/codegen/ExpressionCompiler.h index 7b574b59c..66399d070 100644 --- a/libsolidity/codegen/ExpressionCompiler.h +++ b/libsolidity/codegen/ExpressionCompiler.h @@ -132,7 +132,7 @@ private: /// @returns true if the operator applied to the given type requires a cleanup prior to the /// operation. - static bool cleanupNeededForOp(Type::Category _type, Token _op); + static bool cleanupNeededForOp(Type::Category _type, Token _op, Arithmetic _arithmetic); void acceptAndConvert(Expression const& _expression, Type const& _type, bool _cleanupNeeded = false); diff --git a/libsolidity/codegen/YulUtilFunctions.cpp b/libsolidity/codegen/YulUtilFunctions.cpp index 14c907f49..95f11da7d 100644 --- a/libsolidity/codegen/YulUtilFunctions.cpp +++ b/libsolidity/codegen/YulUtilFunctions.cpp @@ -483,6 +483,22 @@ string YulUtilFunctions::overflowCheckedIntAddFunction(IntegerType const& _type) }); } +string YulUtilFunctions::wrappingIntAddFunction(IntegerType const& _type) +{ + string functionName = "wrapping_add_" + _type.identifier(); + return m_functionCollector.createFunction(functionName, [&]() { + return + Whiskers(R"( + function (x, y) -> sum { + sum := (add(x, y)) + } + )") + ("functionName", functionName) + ("cleanupFunction", cleanupFunction(_type)) + .render(); + }); +} + string YulUtilFunctions::overflowCheckedIntMulFunction(IntegerType const& _type) { string functionName = "checked_mul_" + _type.identifier(); @@ -519,6 +535,22 @@ string YulUtilFunctions::overflowCheckedIntMulFunction(IntegerType const& _type) }); } +string YulUtilFunctions::wrappingIntMulFunction(IntegerType const& _type) +{ + string functionName = "wrapping_mul_" + _type.identifier(); + return m_functionCollector.createFunction(functionName, [&]() { + return + Whiskers(R"( + function (x, y) -> product { + product := (mul(x, y)) + } + )") + ("functionName", functionName) + ("cleanupFunction", cleanupFunction(_type)) + .render(); + }); +} + string YulUtilFunctions::overflowCheckedIntDivFunction(IntegerType const& _type) { string functionName = "checked_div_" + _type.identifier(); @@ -548,9 +580,30 @@ string YulUtilFunctions::overflowCheckedIntDivFunction(IntegerType const& _type) }); } -string YulUtilFunctions::checkedIntModFunction(IntegerType const& _type) +string YulUtilFunctions::wrappingIntDivFunction(IntegerType const& _type) { - string functionName = "checked_mod_" + _type.identifier(); + string functionName = "wrapping_div_" + _type.identifier(); + return m_functionCollector.createFunction(functionName, [&]() { + return + Whiskers(R"( + function (x, y) -> r { + x := (x) + y := (y) + if iszero(y) { () } + r := sdiv(x, y) + } + )") + ("functionName", functionName) + ("cleanupFunction", cleanupFunction(_type)) + ("signed", _type.isSigned()) + ("error", panicFunction()) + .render(); + }); +} + +string YulUtilFunctions::intModFunction(IntegerType const& _type) +{ + string functionName = "mod_" + _type.identifier(); return m_functionCollector.createFunction(functionName, [&]() { return Whiskers(R"( @@ -599,6 +652,22 @@ string YulUtilFunctions::overflowCheckedIntSubFunction(IntegerType const& _type) }); } +string YulUtilFunctions::wrappingIntSubFunction(IntegerType const& _type) +{ + string functionName = "wrapping_sub_" + _type.identifier(); + return m_functionCollector.createFunction(functionName, [&] { + return + Whiskers(R"( + function (x, y) -> diff { + diff := (sub(x, y)) + } + )") + ("functionName", functionName) + ("cleanupFunction", cleanupFunction(_type)) + .render(); + }); +} + string YulUtilFunctions::overflowCheckedIntExpFunction( IntegerType const& _type, IntegerType const& _exponentType @@ -894,6 +963,30 @@ string YulUtilFunctions::overflowCheckedExpLoopFunction() }); } +string YulUtilFunctions::wrappingIntExpFunction( + IntegerType const& _type, + IntegerType const& _exponentType +) +{ + solAssert(!_exponentType.isSigned(), ""); + + string functionName = "wrapping_exp_" + _type.identifier() + "_" + _exponentType.identifier(); + return m_functionCollector.createFunction(functionName, [&]() { + return + Whiskers(R"( + function (base, exponent) -> power { + base := (base) + exponent := (exponent) + power := (exp(base, exponent)) + } + )") + ("functionName", functionName) + ("baseCleanupFunction", cleanupFunction(_type)) + ("exponentCleanupFunction", cleanupFunction(_exponentType)) + .render(); + }); +} + string YulUtilFunctions::extractByteArrayLengthFunction() { string functionName = "extract_byte_array_length"; @@ -2951,30 +3044,39 @@ std::string YulUtilFunctions::decrementCheckedFunction(Type const& _type) string const functionName = "decrement_" + _type.identifier(); return m_functionCollector.createFunction(functionName, [&]() { - u256 minintval; - - // Smallest admissible value to decrement - if (type.isSigned()) - minintval = 0 - (u256(1) << (type.numBits() - 1)) + 1; - else - minintval = 1; - return Whiskers(R"( function (value) -> ret { value := (value) - if (value, ) { () } + if eq(value, ) { () } ret := sub(value, 1) } )") ("functionName", functionName) ("panic", panicFunction()) - ("minval", toCompactHexWithPrefix(minintval)) - ("lt", type.isSigned() ? "slt" : "lt") + ("minval", toCompactHexWithPrefix(type.min())) ("cleanupFunction", cleanupFunction(_type)) .render(); }); } +std::string YulUtilFunctions::decrementWrappingFunction(Type const& _type) +{ + IntegerType const& type = dynamic_cast(_type); + + string const functionName = "decrement_wrapping_" + _type.identifier(); + + return m_functionCollector.createFunction(functionName, [&]() { + return Whiskers(R"( + function (value) -> ret { + ret := (sub(value, 1)) + } + )") + ("functionName", functionName) + ("cleanupFunction", cleanupFunction(type)) + .render(); + }); +} + std::string YulUtilFunctions::incrementCheckedFunction(Type const& _type) { IntegerType const& type = dynamic_cast(_type); @@ -2982,55 +3084,79 @@ std::string YulUtilFunctions::incrementCheckedFunction(Type const& _type) string const functionName = "increment_" + _type.identifier(); return m_functionCollector.createFunction(functionName, [&]() { - u256 maxintval; - - // Biggest admissible value to increment - if (type.isSigned()) - maxintval = (u256(1) << (type.numBits() - 1)) - 2; - else - maxintval = (u256(1) << type.numBits()) - 2; - return Whiskers(R"( function (value) -> ret { value := (value) - if (value, ) { () } + if eq(value, ) { () } ret := add(value, 1) } )") ("functionName", functionName) - ("maxval", toCompactHexWithPrefix(maxintval)) - ("gt", type.isSigned() ? "sgt" : "gt") + ("maxval", toCompactHexWithPrefix(type.max())) ("panic", panicFunction()) ("cleanupFunction", cleanupFunction(_type)) .render(); }); } +std::string YulUtilFunctions::incrementWrappingFunction(Type const& _type) +{ + IntegerType const& type = dynamic_cast(_type); + + string const functionName = "increment_wrapping_" + _type.identifier(); + + return m_functionCollector.createFunction(functionName, [&]() { + return Whiskers(R"( + function (value) -> ret { + ret := (add(value, 1)) + } + )") + ("functionName", functionName) + ("cleanupFunction", cleanupFunction(type)) + .render(); + }); +} + string YulUtilFunctions::negateNumberCheckedFunction(Type const& _type) { IntegerType const& type = dynamic_cast(_type); solAssert(type.isSigned(), "Expected signed type!"); string const functionName = "negate_" + _type.identifier(); - - u256 const minintval = 0 - (u256(1) << (type.numBits() - 1)) + 1; - return m_functionCollector.createFunction(functionName, [&]() { return Whiskers(R"( function (value) -> ret { value := (value) - if slt(value, ) { () } + if eq(value, ) { () } ret := sub(0, value) } )") ("functionName", functionName) - ("minval", toCompactHexWithPrefix(minintval)) + ("minval", toCompactHexWithPrefix(type.min())) ("cleanupFunction", cleanupFunction(_type)) ("panic", panicFunction()) .render(); }); } +string YulUtilFunctions::negateNumberWrappingFunction(Type const& _type) +{ + IntegerType const& type = dynamic_cast(_type); + solAssert(type.isSigned(), "Expected signed type!"); + + string const functionName = "negate_" + _type.identifier(); + return m_functionCollector.createFunction(functionName, [&]() { + return Whiskers(R"( + function (value) -> ret { + value := (sub(0, value))) + } + )") + ("functionName", functionName) + ("cleanupFunction", cleanupFunction(type)) + .render(); + }); +} + string YulUtilFunctions::zeroValueFunction(Type const& _type, bool _splitFunctionTypes) { solAssert(_type.category() != Type::Category::Mapping, ""); diff --git a/libsolidity/codegen/YulUtilFunctions.h b/libsolidity/codegen/YulUtilFunctions.h index 7b6d66d96..731a7f939 100644 --- a/libsolidity/codegen/YulUtilFunctions.h +++ b/libsolidity/codegen/YulUtilFunctions.h @@ -106,24 +106,35 @@ public: /// signature: (x, y) -> sum std::string overflowCheckedIntAddFunction(IntegerType const& _type); + /// signature: (x, y) -> sum + std::string wrappingIntAddFunction(IntegerType const& _type); /// signature: (x, y) -> product std::string overflowCheckedIntMulFunction(IntegerType const& _type); + /// signature: (x, y) -> product + std::string wrappingIntMulFunction(IntegerType const& _type); /// @returns name of function to perform division on integers. /// Checks for division by zero and the special case of /// signed division of the smallest number by -1. std::string overflowCheckedIntDivFunction(IntegerType const& _type); + /// @returns name of function to perform division on integers. + /// Checks for division by zero. + std::string wrappingIntDivFunction(IntegerType const& _type); /// @returns name of function to perform modulo on integers. /// Reverts for modulo by zero. - std::string checkedIntModFunction(IntegerType const& _type); + std::string intModFunction(IntegerType const& _type); /// @returns computes the difference between two values. /// Assumes the input to be in range for the type. /// signature: (x, y) -> diff std::string overflowCheckedIntSubFunction(IntegerType const& _type); + /// @returns computes the difference between two values. + /// signature: (x, y) -> diff + std::string wrappingIntSubFunction(IntegerType const& _type); + /// @returns the name of the exponentiation function. /// signature: (base, exponent) -> power std::string overflowCheckedIntExpFunction(IntegerType const& _type, IntegerType const& _exponentType); @@ -151,6 +162,10 @@ public: /// signature: (power, base, exponent, max) -> power std::string overflowCheckedExpLoopFunction(); + /// @returns the name of the exponentiation function. + /// signature: (base, exponent) -> power + std::string wrappingIntExpFunction(IntegerType const& _type, IntegerType const& _exponentType); + /// @returns the name of a function that fetches the length of the given /// array /// signature: (array) -> length @@ -367,9 +382,12 @@ public: std::string forwardingRevertFunction(); std::string incrementCheckedFunction(Type const& _type); + std::string incrementWrappingFunction(Type const& _type); std::string decrementCheckedFunction(Type const& _type); + std::string decrementWrappingFunction(Type const& _type); std::string negateNumberCheckedFunction(Type const& _type); + std::string negateNumberWrappingFunction(Type const& _type); /// @returns the name of a function that returns the zero value for the /// provided type. diff --git a/libsolidity/codegen/ir/IRGenerationContext.h b/libsolidity/codegen/ir/IRGenerationContext.h index 8f102dfda..d044a6d4a 100644 --- a/libsolidity/codegen/ir/IRGenerationContext.h +++ b/libsolidity/codegen/ir/IRGenerationContext.h @@ -132,6 +132,9 @@ public: langutil::EVMVersion evmVersion() const { return m_evmVersion; }; + void setArithmetic(Arithmetic _value) { m_arithmetic = _value; } + Arithmetic arithmetic() const { return m_arithmetic; } + ABIFunctions abiFunctions(); /// @returns code that stores @param _message for revert reason @@ -161,6 +164,8 @@ private: std::map> m_stateVariables; MultiUseYulFunctionCollector m_functions; size_t m_varCounter = 0; + /// Whether to use checked or wrapping arithmetic. + Arithmetic m_arithmetic = Arithmetic::Checked; /// Flag indicating whether any inline assembly block was seen. bool m_inlineAssemblySeen = false; diff --git a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp index c26012e17..6102b8603 100644 --- a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp +++ b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp @@ -474,6 +474,25 @@ bool IRGeneratorForStatements::visit(TupleExpression const& _tuple) return false; } +bool IRGeneratorForStatements::visit(Block const& _block) +{ + if (_block.unchecked()) + { + solAssert(m_context.arithmetic() == Arithmetic::Checked, ""); + m_context.setArithmetic(Arithmetic::Wrapping); + } + return true; +} + +void IRGeneratorForStatements::endVisit(Block const& _block) +{ + if (_block.unchecked()) + { + solAssert(m_context.arithmetic() == Arithmetic::Wrapping, ""); + m_context.setArithmetic(Arithmetic::Checked); + } +} + bool IRGeneratorForStatements::visit(IfStatement const& _ifStatement) { _ifStatement.condition().accept(*this); @@ -618,11 +637,11 @@ void IRGeneratorForStatements::endVisit(UnaryOperation const& _unaryOperation) else if (op == Token::Sub) { IntegerType const& intType = *dynamic_cast(&resultType); - define(_unaryOperation) << - m_utils.negateNumberCheckedFunction(intType) << - "(" << - IRVariable(_unaryOperation.subExpression()).name() << - ")\n"; + define(_unaryOperation) << ( + m_context.arithmetic() == Arithmetic::Checked ? + m_utils.negateNumberCheckedFunction(intType) : + m_utils.negateNumberWrappingFunction(intType) + ) << "(" << IRVariable(_unaryOperation.subExpression()).name() << ")\n"; } else solUnimplementedAssert(false, "Unary operator not yet implemented"); @@ -2560,23 +2579,23 @@ string IRGeneratorForStatements::binaryOperation( if (IntegerType const* type = dynamic_cast(&_type)) { string fun; - // TODO: Implement all operations for signed and unsigned types. + bool checked = m_context.arithmetic() == Arithmetic::Checked; switch (_operator) { case Token::Add: - fun = m_utils.overflowCheckedIntAddFunction(*type); + fun = checked ? m_utils.overflowCheckedIntAddFunction(*type) : m_utils.wrappingIntAddFunction(*type); break; case Token::Sub: - fun = m_utils.overflowCheckedIntSubFunction(*type); + fun = checked ? m_utils.overflowCheckedIntSubFunction(*type) : m_utils.wrappingIntSubFunction(*type); break; case Token::Mul: - fun = m_utils.overflowCheckedIntMulFunction(*type); + fun = checked ? m_utils.overflowCheckedIntMulFunction(*type) : m_utils.wrappingIntMulFunction(*type); break; case Token::Div: - fun = m_utils.overflowCheckedIntDivFunction(*type); + fun = checked ? m_utils.overflowCheckedIntDivFunction(*type) : m_utils.wrappingIntDivFunction(*type); break; case Token::Mod: - fun = m_utils.checkedIntModFunction(*type); + fun = m_utils.intModFunction(*type); break; case Token::BitOr: fun = "or"; diff --git a/libsolidity/codegen/ir/IRGeneratorForStatements.h b/libsolidity/codegen/ir/IRGeneratorForStatements.h index 3d87eae57..2f17436c6 100644 --- a/libsolidity/codegen/ir/IRGeneratorForStatements.h +++ b/libsolidity/codegen/ir/IRGeneratorForStatements.h @@ -66,6 +66,8 @@ public: bool visit(Conditional const& _conditional) override; bool visit(Assignment const& _assignment) override; bool visit(TupleExpression const& _tuple) override; + bool visit(Block const& _block) override; + void endVisit(Block const& _block) override; bool visit(IfStatement const& _ifStatement) override; bool visit(ForStatement const& _forStatement) override; bool visit(WhileStatement const& _whileStatement) override; diff --git a/libsolidity/parsing/Parser.cpp b/libsolidity/parsing/Parser.cpp index a870d9064..79eeb015a 100644 --- a/libsolidity/parsing/Parser.cpp +++ b/libsolidity/parsing/Parser.cpp @@ -1096,16 +1096,23 @@ ASTPointer Parser::parseParameterList( return nodeFactory.createNode(parameters); } -ASTPointer Parser::parseBlock(ASTPointer const& _docString) +ASTPointer Parser::parseBlock(bool _allowUnchecked, ASTPointer const& _docString) { RecursionGuard recursionGuard(*this); ASTNodeFactory nodeFactory(*this); + bool const unchecked = m_scanner->currentToken() == Token::Unchecked; + if (unchecked) + { + if (!_allowUnchecked) + parserError(5296_error, "\"unchecked\" blocks can only be used inside regular blocks."); + m_scanner->next(); + } expectToken(Token::LBrace); vector> statements; try { while (m_scanner->currentToken() != Token::RBrace) - statements.push_back(parseStatement()); + statements.push_back(parseStatement(true)); nodeFactory.markEndPosition(); } catch (FatalError const&) @@ -1122,10 +1129,10 @@ ASTPointer Parser::parseBlock(ASTPointer const& _docString) expectTokenOrConsumeUntil(Token::RBrace, "Block"); else expectToken(Token::RBrace); - return nodeFactory.createNode(_docString, statements); + return nodeFactory.createNode(_docString, unchecked, statements); } -ASTPointer Parser::parseStatement() +ASTPointer Parser::parseStatement(bool _allowUnchecked) { RecursionGuard recursionGuard(*this); ASTPointer docString; @@ -1144,9 +1151,9 @@ ASTPointer Parser::parseStatement() return parseDoWhileStatement(docString); case Token::For: return parseForStatement(docString); + case Token::Unchecked: case Token::LBrace: - return parseBlock(docString); - // starting from here, all statements must be terminated by a semicolon + return parseBlock(_allowUnchecked, docString); case Token::Continue: statement = ASTNodeFactory(*this).createNode(docString); m_scanner->next(); diff --git a/libsolidity/parsing/Parser.h b/libsolidity/parsing/Parser.h index 08426287c..c5c85d76c 100644 --- a/libsolidity/parsing/Parser.h +++ b/libsolidity/parsing/Parser.h @@ -115,8 +115,8 @@ private: VarDeclParserOptions const& _options = {}, bool _allowEmpty = true ); - ASTPointer parseBlock(ASTPointer const& _docString = {}); - ASTPointer parseStatement(); + ASTPointer parseBlock(bool _allowUncheckedBlock = false, ASTPointer const& _docString = {}); + ASTPointer parseStatement(bool _allowUncheckedBlock = false); ASTPointer parseInlineAssembly(ASTPointer const& _docString = {}); ASTPointer parseIfStatement(ASTPointer const& _docString); ASTPointer parseTryStatement(ASTPointer const& _docString); diff --git a/test/libsolidity/SolidityExpressionCompiler.cpp b/test/libsolidity/SolidityExpressionCompiler.cpp index 8c0ce780d..54e851514 100644 --- a/test/libsolidity/SolidityExpressionCompiler.cpp +++ b/test/libsolidity/SolidityExpressionCompiler.cpp @@ -139,6 +139,7 @@ bytes compileFirstExpression( ); context.resetVisitedNodes(contract); context.setMostDerivedContract(*contract); + context.setArithmetic(Arithmetic::Wrapping); size_t parametersSize = _localVariables.size(); // assume they are all one slot on the stack context.adjustStackOffset(static_cast(parametersSize)); for (vector const& variable: _localVariables)