diff --git a/Changelog.md b/Changelog.md index 84aa53023..3ac18f47c 100644 --- a/Changelog.md +++ b/Changelog.md @@ -5,6 +5,7 @@ Language Features: Compiler Features: * Yul Optimizer: Allow replacing the previously hard-coded cleanup sequence by specifying custom steps after a colon delimiter (``:``) in the sequence string. + * Allow user-defined operators via ``using {f as +} for Typename;``. Bugfixes: diff --git a/libsolidity/analysis/ControlFlowBuilder.cpp b/libsolidity/analysis/ControlFlowBuilder.cpp index 1d1e33f50..25451c89d 100644 --- a/libsolidity/analysis/ControlFlowBuilder.cpp +++ b/libsolidity/analysis/ControlFlowBuilder.cpp @@ -63,6 +63,7 @@ bool ControlFlowBuilder::visit(BinaryOperation const& _operation) case Token::Or: case Token::And: { + solAssert(!_operation.annotation().userDefinedFunction); visitNode(_operation); appendControlFlow(_operation.leftExpression()); @@ -73,10 +74,41 @@ bool ControlFlowBuilder::visit(BinaryOperation const& _operation) return false; } default: - return ASTConstVisitor::visit(_operation); + { + ASTConstVisitor::visit(_operation); + if (_operation.annotation().userDefinedFunction) + { + solAssert(!m_currentNode->resolveFunctionCall(nullptr)); + m_currentNode->functionCall = _operation.annotation().userDefinedFunction; + + auto nextNode = newLabel(); + + connect(m_currentNode, nextNode); + m_currentNode = nextNode; + } + return false; + } } } +bool ControlFlowBuilder::visit(UnaryOperation const& _operation) +{ + solAssert(!!m_currentNode, ""); + + ASTConstVisitor::visit(_operation); + if (_operation.annotation().userDefinedFunction) + { + solAssert(!m_currentNode->resolveFunctionCall(nullptr)); + m_currentNode->functionCall = _operation.annotation().userDefinedFunction; + + auto nextNode = newLabel(); + + connect(m_currentNode, nextNode); + m_currentNode = nextNode; + } + return false; +} + bool ControlFlowBuilder::visit(Conditional const& _conditional) { solAssert(!!m_currentNode, ""); @@ -300,7 +332,7 @@ bool ControlFlowBuilder::visit(FunctionCall const& _functionCall) _functionCall.expression().accept(*this); ASTNode::listAccept(_functionCall.arguments(), *this); - solAssert(!m_currentNode->functionCall); + solAssert(!m_currentNode->resolveFunctionCall(nullptr)); m_currentNode->functionCall = &_functionCall; auto nextNode = newLabel(); diff --git a/libsolidity/analysis/ControlFlowBuilder.h b/libsolidity/analysis/ControlFlowBuilder.h index a150262f1..7782e8edf 100644 --- a/libsolidity/analysis/ControlFlowBuilder.h +++ b/libsolidity/analysis/ControlFlowBuilder.h @@ -50,6 +50,7 @@ private: // Visits for constructing the control flow. bool visit(BinaryOperation const& _operation) override; + bool visit(UnaryOperation const& _operation) override; bool visit(Conditional const& _conditional) override; bool visit(TryStatement const& _tryStatement) override; bool visit(IfStatement const& _ifStatement) override; diff --git a/libsolidity/analysis/ControlFlowGraph.cpp b/libsolidity/analysis/ControlFlowGraph.cpp index ca36b421c..14882743c 100644 --- a/libsolidity/analysis/ControlFlowGraph.cpp +++ b/libsolidity/analysis/ControlFlowGraph.cpp @@ -19,11 +19,21 @@ #include #include +#include using namespace std; +using namespace solidity::util; using namespace solidity::langutil; using namespace solidity::frontend; +FunctionDefinition const* CFGNode::resolveFunctionCall(ContractDefinition const* _mostDerivedContract) const +{ + return std::visit(GenericVisitor{ + [=](FunctionCall const* _funCall) { return _funCall ? ASTNode::resolveFunctionCall(*_funCall, _mostDerivedContract) : nullptr; }, + [](FunctionDefinition const* _funDef) { return _funDef; } + }, functionCall); +} + bool CFG::constructFlow(ASTNode const& _astRoot) { _astRoot.accept(*this); diff --git a/libsolidity/analysis/ControlFlowGraph.h b/libsolidity/analysis/ControlFlowGraph.h index 7383783fd..68f15a30d 100644 --- a/libsolidity/analysis/ControlFlowGraph.h +++ b/libsolidity/analysis/ControlFlowGraph.h @@ -29,6 +29,7 @@ #include #include #include +#include namespace solidity::frontend { @@ -98,8 +99,13 @@ struct CFGNode std::vector entries; /// Exit nodes. All CFG nodes to which control flow may continue after this node. std::vector exits; - /// Function call done by this node - FunctionCall const* functionCall = nullptr; + /// Function call done by this node, either a proper function call (allows virtual lookup) + /// or a direct function definition reference (in case of an operator), + /// or nullptr. + std::variant functionCall = static_cast(nullptr); + /// @returns the actual function called given a most derived contract. If no function is called + /// in this node, returns nullptr. + FunctionDefinition const* resolveFunctionCall(ContractDefinition const* _mostDerivedContract) const; /// Variable occurrences in the node. std::vector variableOccurrences; diff --git a/libsolidity/analysis/ControlFlowRevertPruner.cpp b/libsolidity/analysis/ControlFlowRevertPruner.cpp index ae9c18a13..b42bb3b8a 100644 --- a/libsolidity/analysis/ControlFlowRevertPruner.cpp +++ b/libsolidity/analysis/ControlFlowRevertPruner.cpp @@ -81,27 +81,23 @@ void ControlFlowRevertPruner::findRevertStates() if (_node == functionFlow.exit) foundExit = true; - if (auto const* functionCall = _node->functionCall) + auto const* resolvedFunction = _node->resolveFunctionCall(item.contract); + if (resolvedFunction && resolvedFunction->isImplemented()) { - auto const* resolvedFunction = ASTNode::resolveFunctionCall(*functionCall, item.contract); - - if (resolvedFunction && resolvedFunction->isImplemented()) + CFG::FunctionContractTuple calledFunctionTuple{ + findScopeContract(*resolvedFunction, item.contract), + resolvedFunction + }; + switch (m_functions.at(calledFunctionTuple)) { - CFG::FunctionContractTuple calledFunctionTuple{ - findScopeContract(*resolvedFunction, item.contract), - resolvedFunction - }; - switch (m_functions.at(calledFunctionTuple)) - { - case RevertState::Unknown: - wakeUp[calledFunctionTuple].insert(item); - foundUnknown = true; - return; - case RevertState::AllPathsRevert: - return; - case RevertState::HasNonRevertingPath: - break; - } + case RevertState::Unknown: + wakeUp[calledFunctionTuple].insert(item); + foundUnknown = true; + return; + case RevertState::AllPathsRevert: + return; + case RevertState::HasNonRevertingPath: + break; } } @@ -135,30 +131,26 @@ void ControlFlowRevertPruner::modifyFunctionFlows() FunctionFlow const& functionFlow = m_cfg.functionFlow(*item.first.function, item.first.contract); solidity::util::BreadthFirstSearch{{functionFlow.entry}}.run( [&](CFGNode* _node, auto&& _addChild) { - if (auto const* functionCall = _node->functionCall) - { - auto const* resolvedFunction = ASTNode::resolveFunctionCall(*functionCall, item.first.contract); + auto const* resolvedFunction = _node->resolveFunctionCall(item.first.contract); + if (resolvedFunction && resolvedFunction->isImplemented()) + switch (m_functions.at({findScopeContract(*resolvedFunction, item.first.contract), resolvedFunction})) + { + case RevertState::Unknown: + [[fallthrough]]; + case RevertState::AllPathsRevert: + // If the revert states of the functions do not + // change anymore, we treat all "unknown" states as + // "reverting", since they can only be caused by + // recursion. + for (CFGNode * node: _node->exits) + ranges::remove(node->entries, _node); - if (resolvedFunction && resolvedFunction->isImplemented()) - switch (m_functions.at({findScopeContract(*resolvedFunction, item.first.contract), resolvedFunction})) - { - case RevertState::Unknown: - [[fallthrough]]; - case RevertState::AllPathsRevert: - // If the revert states of the functions do not - // change anymore, we treat all "unknown" states as - // "reverting", since they can only be caused by - // recursion. - for (CFGNode * node: _node->exits) - ranges::remove(node->entries, _node); - - _node->exits = {functionFlow.revert}; - functionFlow.revert->entries.push_back(_node); - return; - default: - break; - } - } + _node->exits = {functionFlow.revert}; + functionFlow.revert->entries.push_back(_node); + return; + default: + break; + } for (CFGNode* exit: _node->exits) _addChild(exit); diff --git a/libsolidity/analysis/FunctionCallGraph.cpp b/libsolidity/analysis/FunctionCallGraph.cpp index e941c99c5..8f7a816d3 100644 --- a/libsolidity/analysis/FunctionCallGraph.cpp +++ b/libsolidity/analysis/FunctionCallGraph.cpp @@ -204,6 +204,20 @@ bool FunctionCallGraphBuilder::visit(MemberAccess const& _memberAccess) return true; } +bool FunctionCallGraphBuilder::visit(BinaryOperation const& _binaryOperation) +{ + if (FunctionDefinition const* function = _binaryOperation.annotation().userDefinedFunction) + functionReferenced(*function, true /* called directly */); + return true; +} + +bool FunctionCallGraphBuilder::visit(UnaryOperation const& _unaryOperation) +{ + if (FunctionDefinition const* function = _unaryOperation.annotation().userDefinedFunction) + functionReferenced(*function, true /* called directly */); + return true; +} + bool FunctionCallGraphBuilder::visit(ModifierInvocation const& _modifierInvocation) { if (auto const* modifier = dynamic_cast(_modifierInvocation.name().annotation().referencedDeclaration)) diff --git a/libsolidity/analysis/FunctionCallGraph.h b/libsolidity/analysis/FunctionCallGraph.h index 3ea1f5b9e..a4e5787a2 100644 --- a/libsolidity/analysis/FunctionCallGraph.h +++ b/libsolidity/analysis/FunctionCallGraph.h @@ -72,6 +72,8 @@ private: bool visit(EmitStatement const& _emitStatement) override; bool visit(Identifier const& _identifier) override; bool visit(MemberAccess const& _memberAccess) override; + bool visit(BinaryOperation const& _binaryOperation) override; + bool visit(UnaryOperation const& _unaryOperation) override; bool visit(ModifierInvocation const& _modifierInvocation) override; bool visit(NewExpression const& _newExpression) override; diff --git a/libsolidity/analysis/PostTypeChecker.cpp b/libsolidity/analysis/PostTypeChecker.cpp index c8b3b5f16..fce50914a 100644 --- a/libsolidity/analysis/PostTypeChecker.cpp +++ b/libsolidity/analysis/PostTypeChecker.cpp @@ -178,6 +178,7 @@ struct ConstStateVarCircularReferenceChecker: public PostTypeChecker::Checker bool visit(Identifier const& _identifier) override { + // TODO add user defined operators? if (m_currentConstVariable) if (auto var = dynamic_cast(_identifier.annotation().referencedDeclaration)) if (var->isConstant()) diff --git a/libsolidity/analysis/SyntaxChecker.cpp b/libsolidity/analysis/SyntaxChecker.cpp index bdf138e20..d9f707ec9 100644 --- a/libsolidity/analysis/SyntaxChecker.cpp +++ b/libsolidity/analysis/SyntaxChecker.cpp @@ -405,6 +405,12 @@ void SyntaxChecker::endVisit(ContractDefinition const&) bool SyntaxChecker::visit(UsingForDirective const& _usingFor) { + if (!_usingFor.usesBraces()) + solAssert( + _usingFor.functionsAndOperators().size() == 1 && + !std::get<1>(_usingFor.functionsAndOperators().front()) + ); + if (!m_currentContractKind && !_usingFor.typeName()) m_errorReporter.syntaxError( 8118_error, diff --git a/libsolidity/analysis/TypeChecker.cpp b/libsolidity/analysis/TypeChecker.cpp index 1025e6576..21d391b80 100644 --- a/libsolidity/analysis/TypeChecker.cpp +++ b/libsolidity/analysis/TypeChecker.cpp @@ -1728,10 +1728,40 @@ bool TypeChecker::visit(UnaryOperation const& _operation) else _operation.subExpression().accept(*this); Type const* subExprType = type(_operation.subExpression()); - TypeResult result = type(_operation.subExpression())->unaryOperatorResult(op); - if (!result) + + + // Check if the operator is built-in or user-defined. + FunctionDefinition const* userDefinedOperator = subExprType->userDefinedOperator( + _operation.getOperator(), + *currentDefinitionScope() + ); + _operation.annotation().userDefinedFunction = userDefinedOperator; + FunctionType const* userDefinedFunctionType = nullptr; + if (userDefinedOperator) + userDefinedFunctionType = &dynamic_cast( + userDefinedOperator->libraryFunction() ? + *userDefinedOperator->typeViaContractName() : + *userDefinedOperator->type() + ); + + TypeResult builtinResult = subExprType->unaryOperatorResult(op); + + solAssert(!builtinResult || !userDefinedOperator); + if (userDefinedOperator) { - string description = "Unary operator " + string(TokenTraits::toString(op)) + " cannot be applied to type " + subExprType->humanReadableName() + "." + (!result.message().empty() ? " " + result.message() : ""); + solAssert(userDefinedFunctionType->parameterTypes().size() == 1); + solAssert(userDefinedFunctionType->returnParameterTypes().size() == 1); + solAssert( + *userDefinedFunctionType->parameterTypes().at(0) == + *userDefinedFunctionType->returnParameterTypes().at(0) + ); + _operation.annotation().type = userDefinedFunctionType->returnParameterTypes().at(0); + } + else if (builtinResult) + _operation.annotation().type = builtinResult; + else + { + string description = "Unary operator " + string(TokenTraits::toString(op)) + " cannot be applied to type " + subExprType->humanReadableName() + "." + (!builtinResult.message().empty() ? " " + builtinResult.message() : ""); if (modifying) // Cannot just report the error, ignore the unary operator, and continue, // because the sub-expression was already processed with requireLValue() @@ -1740,10 +1770,12 @@ bool TypeChecker::visit(UnaryOperation const& _operation) m_errorReporter.typeError(4907_error, _operation.location(), description); _operation.annotation().type = subExprType; } - else - _operation.annotation().type = result.get(); + _operation.annotation().isConstant = false; - _operation.annotation().isPure = !modifying && *_operation.subExpression().annotation().isPure; + _operation.annotation().isPure = + !modifying && + *_operation.subExpression().annotation().isPure && + (!userDefinedFunctionType || userDefinedFunctionType->isPure()); _operation.annotation().isLValue = false; return false; @@ -1753,10 +1785,35 @@ void TypeChecker::endVisit(BinaryOperation const& _operation) { Type const* leftType = type(_operation.leftExpression()); Type const* rightType = type(_operation.rightExpression()); - TypeResult result = leftType->binaryOperatorResult(_operation.getOperator(), rightType); - Type const* commonType = result.get(); - if (!commonType) - { + _operation.annotation().isLValue = false; + _operation.annotation().isConstant = false; + + // Check if the operator is built-in or user-defined. + FunctionDefinition const* userDefinedOperator = leftType->userDefinedOperator( + _operation.getOperator(), + *currentDefinitionScope() + ); + _operation.annotation().userDefinedFunction = userDefinedOperator; + FunctionType const* userDefinedFunctionType = nullptr; + if (userDefinedOperator) + userDefinedFunctionType = &dynamic_cast( + userDefinedOperator->libraryFunction() ? + *userDefinedOperator->typeViaContractName() : + *userDefinedOperator->type() + ); + _operation.annotation().isPure = + *_operation.leftExpression().annotation().isPure && + *_operation.rightExpression().annotation().isPure && + (!userDefinedFunctionType || userDefinedFunctionType->isPure()); + + TypeResult builtinResult = leftType->binaryOperatorResult(_operation.getOperator(), rightType); + Type const* commonType = leftType; + + // Either the operator is user-defined or built-in. + // TODO For enums, we have compare operators. Should we disallow overriding them? + solAssert(!userDefinedOperator || !builtinResult); + + if (!builtinResult && !userDefinedOperator) m_errorReporter.typeError( 2271_error, _operation.location(), @@ -1766,22 +1823,33 @@ void TypeChecker::endVisit(BinaryOperation const& _operation) leftType->humanReadableName() + " and " + rightType->humanReadableName() + "." + - (!result.message().empty() ? " " + result.message() : "") + (!builtinResult.message().empty() ? " " + builtinResult.message() : "") ); - commonType = leftType; + + if (builtinResult) + commonType = builtinResult.get(); + else if (userDefinedOperator) + { + solAssert( + userDefinedFunctionType->parameterTypes().size() == 2 && + *userDefinedFunctionType->parameterTypes().at(0) == + *userDefinedFunctionType->parameterTypes().at(1) + ); + commonType = userDefinedFunctionType->parameterTypes().at(0); } + _operation.annotation().commonType = commonType; _operation.annotation().type = TokenTraits::isCompareOp(_operation.getOperator()) ? TypeProvider::boolean() : commonType; - _operation.annotation().isPure = - *_operation.leftExpression().annotation().isPure && - *_operation.rightExpression().annotation().isPure; - _operation.annotation().isLValue = false; - _operation.annotation().isConstant = false; - if (_operation.getOperator() == Token::Exp || _operation.getOperator() == Token::SHL) + if (userDefinedOperator) + solAssert( + userDefinedFunctionType->returnParameterTypes().size() == 1 && + *userDefinedFunctionType->returnParameterTypes().front() == *_operation.annotation().type + ); + else if (builtinResult && (_operation.getOperator() == Token::Exp || _operation.getOperator() == Token::SHL)) { string operation = _operation.getOperator() == Token::Exp ? "exponentiation" : "shift"; if ( @@ -3784,7 +3852,7 @@ void TypeChecker::endVisit(UsingForDirective const& _usingFor) ); solAssert(normalizedType); - for (ASTPointer const& path: _usingFor.functionsOrLibrary()) + for (auto const& [path, operator_]: _usingFor.functionsAndOperators()) { solAssert(path->annotation().referencedDeclaration); FunctionDefinition const& functionDefinition = @@ -3820,6 +3888,73 @@ void TypeChecker::endVisit(UsingForDirective const& _usingFor) ": " + result.message() ) ); + else if (operator_) + { + if (!_usingFor.typeName()->annotation().type->typeDefinition()) + { + m_errorReporter.typeError( + 5332_error, + path->location(), + "Operators can only be implemented for user-defined types and not for contracts." + ); + continue; + } + // "-" can be used as unary and binary operator. + bool isUnaryNegation = ( + operator_ == Token::Sub && + functionType->parameterTypesIncludingSelf().size() == 1 + ); + if ( + ( + (TokenTraits::isBinaryOp(*operator_) && !isUnaryNegation) || + TokenTraits::isCompareOp(*operator_) + ) && + ( + functionType->parameterTypesIncludingSelf().size() != 2 || + *functionType->parameterTypesIncludingSelf().at(0) != + *functionType->parameterTypesIncludingSelf().at(1) + ) + ) + m_errorReporter.typeError( + 1884_error, + path->location(), + "The function \"" + joinHumanReadable(path->path(), ".") + "\" "+ + "needs to have two parameters of equal type to be used for the operator " + + TokenTraits::friendlyName(*operator_) + + "." + ); + if ( + (isUnaryNegation || (TokenTraits::isUnaryOp(*operator_) && *operator_ != Token::Add)) && + functionType->parameterTypesIncludingSelf().size() != 1 + ) + m_errorReporter.typeError( + 8112_error, + path->location(), + "The function \"" + joinHumanReadable(path->path(), ".") + "\" "+ + "needs to have exactly one parameter to be used for the operator " + + TokenTraits::friendlyName(*operator_) + + "." + ); + Type const* expectedType = + TokenTraits::isCompareOp(*operator_) ? + dynamic_cast(TypeProvider::boolean()) : + functionType->parameterTypesIncludingSelf().at(0); + + if ( + functionType->returnParameterTypes().size() != 1 || + *functionType->returnParameterTypes().front() != *expectedType + ) + m_errorReporter.typeError( + 7743_error, + path->location(), + "The function \"" + joinHumanReadable(path->path(), ".") + "\" "+ + "needs to return exactly one value of type " + + expectedType->toString(true) + + " to be used for the operator " + + TokenTraits::friendlyName(*operator_) + + "." + ); + } } } diff --git a/libsolidity/analysis/ViewPureChecker.cpp b/libsolidity/analysis/ViewPureChecker.cpp index 7483af892..2086d6f00 100644 --- a/libsolidity/analysis/ViewPureChecker.cpp +++ b/libsolidity/analysis/ViewPureChecker.cpp @@ -323,6 +323,8 @@ ViewPureChecker::MutabilityAndLocation const& ViewPureChecker::modifierMutabilit return m_inferredMutability.at(&_modifier); } +// TODO needs to visit binaryoperation as well + void ViewPureChecker::endVisit(FunctionCall const& _functionCall) { if (*_functionCall.annotation().kind != FunctionCallKind::FunctionCall) diff --git a/libsolidity/ast/AST.cpp b/libsolidity/ast/AST.cpp index 8619ee6ae..bc8657126 100644 --- a/libsolidity/ast/AST.cpp +++ b/libsolidity/ast/AST.cpp @@ -895,6 +895,11 @@ MemberAccessAnnotation& MemberAccess::annotation() const return initAnnotation(); } +OperationAnnotation& UnaryOperation::annotation() const +{ + return initAnnotation(); +} + BinaryOperationAnnotation& BinaryOperation::annotation() const { return initAnnotation(); diff --git a/libsolidity/ast/AST.h b/libsolidity/ast/AST.h index d6e41bb5f..c6d4de92a 100644 --- a/libsolidity/ast/AST.h +++ b/libsolidity/ast/AST.h @@ -38,6 +38,7 @@ #include #include +#include #include #include @@ -664,16 +665,19 @@ public: int64_t _id, SourceLocation const& _location, std::vector> _functions, + std::vector> _operators, bool _usesBraces, ASTPointer _typeName, bool _global ): ASTNode(_id, _location), - m_functions(_functions), + m_functions(std::move(_functions)), + m_operators(std::move(_operators)), m_usesBraces(_usesBraces), m_typeName(std::move(_typeName)), m_global{_global} { + solAssert(m_functions.size() == m_operators.size()); } void accept(ASTVisitor& _visitor) override; @@ -684,12 +688,15 @@ public: /// @returns a list of functions or the single library. std::vector> const& functionsOrLibrary() const { return m_functions; } + auto functionsAndOperators() const { return ranges::zip_view(m_functions, m_operators); } bool usesBraces() const { return m_usesBraces; } bool global() const { return m_global; } private: /// Either the single library or a list of functions. std::vector> m_functions; + /// Operators, the functions are applied to. + std::vector> m_operators; bool m_usesBraces; ASTPointer m_typeName; bool m_global = false; @@ -2055,6 +2062,8 @@ public: bool isPrefixOperation() const { return m_isPrefix; } Expression const& subExpression() const { return *m_subExpression; } + OperationAnnotation& annotation() const override; + private: Token m_operator; ASTPointer m_subExpression; diff --git a/libsolidity/ast/ASTAnnotations.h b/libsolidity/ast/ASTAnnotations.h index 2615e6af7..b15598df0 100644 --- a/libsolidity/ast/ASTAnnotations.h +++ b/libsolidity/ast/ASTAnnotations.h @@ -312,7 +312,13 @@ struct MemberAccessAnnotation: ExpressionAnnotation util::SetOnce requiredLookup; }; -struct BinaryOperationAnnotation: ExpressionAnnotation +struct OperationAnnotation: ExpressionAnnotation +{ + // TODO should this be more like "referencedDeclaration"? + FunctionDefinition const* userDefinedFunction = nullptr; +}; + +struct BinaryOperationAnnotation: OperationAnnotation { /// The common type that is used for the operation, not necessarily the result type (which /// e.g. for comparisons is bool). diff --git a/libsolidity/ast/ASTJsonExporter.cpp b/libsolidity/ast/ASTJsonExporter.cpp index d308cd90c..f2ac33325 100644 --- a/libsolidity/ast/ASTJsonExporter.cpp +++ b/libsolidity/ast/ASTJsonExporter.cpp @@ -329,14 +329,17 @@ bool ASTJsonExporter::visit(UsingForDirective const& _node) vector> attributes = { make_pair("typeName", _node.typeName() ? toJson(*_node.typeName()) : Json::nullValue) }; + if (_node.usesBraces()) { Json::Value functionList; - for (auto const& function: _node.functionsOrLibrary()) + for (auto&& [function, op]: _node.functionsAndOperators()) { Json::Value functionNode; functionNode["function"] = toJson(*function); - functionList.append(std::move(functionNode)); + if (op) + functionNode["operator"] = string(TokenTraits::toString(*op)); + functionList.append(move(functionNode)); } attributes.emplace_back("functionList", std::move(functionList)); } @@ -825,6 +828,8 @@ bool ASTJsonExporter::visit(UnaryOperation const& _node) make_pair("operator", TokenTraits::toString(_node.getOperator())), make_pair("subExpression", toJson(_node.subExpression())) }; + if (FunctionDefinition const* function = _node.annotation().userDefinedFunction) + attributes.emplace_back("function", nodeId(*function)); appendExpressionAttributes(attributes, _node.annotation()); setJsonNode(_node, "UnaryOperation", std::move(attributes)); return false; @@ -838,6 +843,8 @@ bool ASTJsonExporter::visit(BinaryOperation const& _node) make_pair("rightExpression", toJson(_node.rightExpression())), make_pair("commonType", typePointerToJson(_node.annotation().commonType)), }; + if (FunctionDefinition const* function = _node.annotation().userDefinedFunction) + attributes.emplace_back("function", nodeId(*function)); appendExpressionAttributes(attributes, _node.annotation()); setJsonNode(_node, "BinaryOperation", std::move(attributes)); return false; diff --git a/libsolidity/ast/ASTJsonImporter.cpp b/libsolidity/ast/ASTJsonImporter.cpp index 96c37003e..61c444eeb 100644 --- a/libsolidity/ast/ASTJsonImporter.cpp +++ b/libsolidity/ast/ASTJsonImporter.cpp @@ -383,15 +383,29 @@ ASTPointer ASTJsonImporter::createInheritanceSpecifier(Jso ASTPointer ASTJsonImporter::createUsingForDirective(Json::Value const& _node) { vector> functions; + vector> operators; if (_node.isMember("libraryName")) + { + solAssert(!_node["libraryName"].isArray()); + solAssert(!_node["libraryName"]["operator"]); functions.emplace_back(createIdentifierPath(_node["libraryName"])); + operators.emplace_back(); + } else if (_node.isMember("functionList")) for (Json::Value const& function: _node["functionList"]) + { functions.emplace_back(createIdentifierPath(function["function"])); + operators.emplace_back( + function.isMember("operator") ? + optional{scanSingleToken(function["operator"])} : + nullopt + ); + } return createASTNode( _node, std::move(functions), + move(operators), !_node.isMember("libraryName"), _node["typeName"].isNull() ? nullptr : convertJsonToASTNode(_node["typeName"]), memberAsBool(_node, "global") diff --git a/libsolidity/ast/AST_accept.h b/libsolidity/ast/AST_accept.h index 6dab08f00..64d785e05 100644 --- a/libsolidity/ast/AST_accept.h +++ b/libsolidity/ast/AST_accept.h @@ -194,7 +194,7 @@ void UsingForDirective::accept(ASTVisitor& _visitor) { if (_visitor.visit(*this)) { - listAccept(functionsOrLibrary(), _visitor); + listAccept(m_functions, _visitor); if (m_typeName) m_typeName->accept(_visitor); } @@ -205,7 +205,7 @@ void UsingForDirective::accept(ASTConstVisitor& _visitor) const { if (_visitor.visit(*this)) { - listAccept(functionsOrLibrary(), _visitor); + listAccept(m_functions, _visitor); if (m_typeName) m_typeName->accept(_visitor); } diff --git a/libsolidity/ast/Types.cpp b/libsolidity/ast/Types.cpp index 9c0179eda..b3bb6847c 100644 --- a/libsolidity/ast/Types.cpp +++ b/libsolidity/ast/Types.cpp @@ -48,6 +48,7 @@ #include #include #include +#include #include #include @@ -337,7 +338,10 @@ Type const* Type::fullEncodingType(bool _inLibraryCall, bool _encoderV2, bool) c return encodingType; } -MemberList::MemberMap Type::boundFunctions(Type const& _type, ASTNode const& _scope) +namespace +{ + +vector usingForDirectivesForType(Type const& _type, ASTNode const& _scope) { vector usingForDirectives; SourceUnit const* sourceUnit = dynamic_cast(&_scope); @@ -362,6 +366,57 @@ MemberList::MemberMap Type::boundFunctions(Type const& _type, ASTNode const& _sc if (auto refType = dynamic_cast(&_type)) typeLocation = refType->location(); + return usingForDirectives | ranges::views::filter([&](UsingForDirective const* _directive) -> bool { + // Convert both types to pointers for comparison to see if the `using for` + // directive applies. + // Further down, we check more detailed for each function if `_type` is + // convertible to the function parameter type. + return + !_directive->typeName() || + *TypeProvider::withLocationIfReference(typeLocation, &_type, true) == + *TypeProvider::withLocationIfReference( + typeLocation, + _directive->typeName()->annotation().type, + true + ); + }) | ranges::to>; +} + +} + +FunctionDefinition const* Type::userDefinedOperator(Token _token, ASTNode const& _scope) const +{ + // Check if it is a user-defined type. + if (!typeDefinition()) + return nullptr; + + set seenFunctions; + for (UsingForDirective const* ufd: usingForDirectivesForType(*this, _scope)) + for (auto const& [pathPointer, operator_]: ufd->functionsAndOperators()) + { + if (operator_ != _token) + continue; + FunctionDefinition const& function = dynamic_cast( + *pathPointer->annotation().referencedDeclaration + ); + FunctionType const* functionType = dynamic_cast( + function.libraryFunction() ? function.typeViaContractName() : function.type() + ); + solAssert(functionType && !functionType->parameterTypes().empty()); + // TODO does this work (data location)? + solAssert(isImplicitlyConvertibleTo(*functionType->parameterTypes().front())); + seenFunctions.insert(&function); + } + // TODO proper error handling. + if (seenFunctions.size() == 1) + return *seenFunctions.begin(); + else + return nullptr; +} + + +MemberList::MemberMap Type::boundFunctions(Type const& _type, ASTNode const& _scope) +{ MemberList::MemberMap members; set> seenFunctions; @@ -381,25 +436,12 @@ MemberList::MemberMap Type::boundFunctions(Type const& _type, ASTNode const& _sc members.emplace_back(&_function, asBoundFunction, *_name); }; - for (UsingForDirective const* ufd: usingForDirectives) - { - // Convert both types to pointers for comparison to see if the `using for` - // directive applies. - // Further down, we check more detailed for each function if `_type` is - // convertible to the function parameter type. - if ( - ufd->typeName() && - *TypeProvider::withLocationIfReference(typeLocation, &_type, true) != - *TypeProvider::withLocationIfReference( - typeLocation, - ufd->typeName()->annotation().type, - true - ) - ) - continue; - - for (auto const& pathPointer: ufd->functionsOrLibrary()) + for (UsingForDirective const* ufd: usingForDirectivesForType(_type, _scope)) + for (auto const& [pathPointer, operator_]: ufd->functionsAndOperators()) { + if (operator_) + continue; + solAssert(pathPointer); Declaration const* declaration = pathPointer->annotation().referencedDeclaration; solAssert(declaration); @@ -420,7 +462,6 @@ MemberList::MemberMap Type::boundFunctions(Type const& _type, ASTNode const& _sc pathPointer->path().back() ); } - } return members; } diff --git a/libsolidity/ast/Types.h b/libsolidity/ast/Types.h index 735bd6902..d41ae3033 100644 --- a/libsolidity/ast/Types.h +++ b/libsolidity/ast/Types.h @@ -377,6 +377,8 @@ public: /// Clears all internally cached values (if any). virtual void clearCache() const; + FunctionDefinition const* userDefinedOperator(Token _token, ASTNode const& _scope) const; + private: /// @returns a member list containing all members added to this type by `using for` directives. static MemberList::MemberMap boundFunctions(Type const& _type, ASTNode const& _scope); diff --git a/libsolidity/codegen/ExpressionCompiler.cpp b/libsolidity/codegen/ExpressionCompiler.cpp index 461de5648..635c279ee 100644 --- a/libsolidity/codegen/ExpressionCompiler.cpp +++ b/libsolidity/codegen/ExpressionCompiler.cpp @@ -502,6 +502,46 @@ bool ExpressionCompiler::visit(BinaryOperation const& _binaryOperation) CompilerContext::LocationSetter locationSetter(m_context, _binaryOperation); Expression const& leftExpression = _binaryOperation.leftExpression(); Expression const& rightExpression = _binaryOperation.rightExpression(); + if (_binaryOperation.annotation().userDefinedFunction) + { + // TODO extract from function call + FunctionDefinition const& function = *_binaryOperation.annotation().userDefinedFunction; + FunctionType const* functionType = dynamic_cast( + function.libraryFunction() ? function.typeViaContractName() : function.type() + ); + solAssert(functionType); + functionType = dynamic_cast(*functionType).asBoundFunction(); + solAssert(functionType); + + evmasm::AssemblyItem returnLabel = m_context.pushNewTag(); + acceptAndConvert(leftExpression, *functionType->selfType()); + acceptAndConvert(rightExpression, *functionType->parameterTypes().at(0)); + + utils().pushCombinedFunctionEntryLabel( + function.resolveVirtual(m_context.mostDerivedContract()), + false + ); + + unsigned parameterSize = + CompilerUtils::sizeOnStack(functionType->parameterTypes()) + + functionType->selfType()->sizeOnStack(); + + if (m_context.runtimeContext()) + // We have a runtime context, so we need the creation part. + utils().rightShiftNumberOnStack(32); + else + // Extract the runtime part. + m_context << ((u256(1) << 32) - 1) << Instruction::AND; + + m_context.appendJump(evmasm::AssemblyItem::JumpType::IntoFunction); + m_context << returnLabel; + + unsigned returnParametersSize = CompilerUtils::sizeOnStack(functionType->returnParameterTypes()); + // callee adds return parameters, but removes arguments and return label + m_context.adjustStackOffset(static_cast(returnParametersSize - parameterSize) - 1); + return false; + } + solAssert(!!_binaryOperation.annotation().commonType, ""); Type const* commonType = _binaryOperation.annotation().commonType; Token const c_op = _binaryOperation.getOperator(); diff --git a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp index 30575930c..6bcd6ab9c 100644 --- a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp +++ b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp @@ -775,10 +775,42 @@ bool IRGeneratorForStatements::visit(BinaryOperation const& _binOp) { setLocation(_binOp); - solAssert(!!_binOp.annotation().commonType); + // TOOD make this nicer + if (_binOp.annotation().userDefinedFunction) + { + _binOp.leftExpression().accept(*this); + _binOp.rightExpression().accept(*this); + setLocation(_binOp); + + // TODO extract from function call + FunctionDefinition const& function = *_binOp.annotation().userDefinedFunction; + FunctionType const* functionType = dynamic_cast( + function.libraryFunction() ? function.typeViaContractName() : function.type() + ); + solAssert(functionType); + functionType = dynamic_cast(*functionType).asBoundFunction(); + solAssert(functionType); + + // TODO virtual? + + string left = expressionAsType(_binOp.leftExpression(), *functionType->selfType()); + string right = expressionAsType(_binOp.rightExpression(), *functionType->parameterTypes().at(0)); + solAssert(!left.empty() && !right.empty()); + + solAssert(function.isImplemented(), ""); + + define(_binOp) << + m_context.enqueueFunctionForCodeGeneration(function) << + ("(" + left + ", " + right + ")\n"); + + return false; + } + + solAssert(!!_binOp.annotation().commonType, ""); Type const* commonType = _binOp.annotation().commonType; langutil::Token op = _binOp.getOperator(); + if (op == Token::And || op == Token::Or) { // This can short-circuit! diff --git a/libsolidity/parsing/Parser.cpp b/libsolidity/parsing/Parser.cpp index afa70609d..96a305235 100644 --- a/libsolidity/parsing/Parser.cpp +++ b/libsolidity/parsing/Parser.cpp @@ -968,6 +968,7 @@ ASTPointer Parser::parseUsingDirective() expectToken(Token::Using); vector> functions; + vector> operators; bool const usesBraces = m_scanner->currentToken() == Token::LBrace; if (usesBraces) { @@ -975,12 +976,38 @@ ASTPointer Parser::parseUsingDirective() { advance(); functions.emplace_back(parseIdentifierPath()); + if (m_scanner->currentToken() == Token::As) + { + advance(); + Token operator_ = m_scanner->currentToken(); + vector overridable = { + // Potential future additions: <<, >>, **, ! + Token::BitOr, Token::BitAnd, Token::BitXor, + Token::Add, Token::Sub, Token::Mul, Token::Div, Token::Mod, + Token::Equal, Token::NotEqual, + Token::LessThan, Token::GreaterThan, Token::LessThanOrEqual, Token::GreaterThanOrEqual, + Token::BitNot + }; + if (!util::contains(overridable, operator_)) + parserError( + 1885_error, + ("The operator " + string{TokenTraits::toString(operator_)} + " cannot be user-implemented. This is only possible for the folloing operators: ") + + util::joinHumanReadable(overridable | ranges::views::transform([](Token _t) { return string{TokenTraits::toString(_t)}; })) + ); + operators.emplace_back(operator_); + advance(); + } + else + operators.emplace_back(); } while (m_scanner->currentToken() == Token::Comma); expectToken(Token::RBrace); } else + { functions.emplace_back(parseIdentifierPath()); + operators.emplace_back(); + } ASTPointer typeName; expectToken(Token::For); @@ -996,7 +1023,7 @@ ASTPointer Parser::parseUsingDirective() } nodeFactory.markEndPosition(); expectToken(Token::Semicolon); - return nodeFactory.createNode(std::move(functions), usesBraces, typeName, global); + return nodeFactory.createNode(std::move(functions), std::move(operators), usesBraces, typeName, global); } ASTPointer Parser::parseModifierInvocation() diff --git a/test/libsolidity/semanticTests/operators/custom/addition_returning_bool.sol b/test/libsolidity/semanticTests/operators/custom/addition_returning_bool.sol new file mode 100644 index 000000000..dcfdbac3d --- /dev/null +++ b/test/libsolidity/semanticTests/operators/custom/addition_returning_bool.sol @@ -0,0 +1,16 @@ +type MyInt is int; +using {add as +} for MyInt; + +function add(MyInt, MyInt) pure returns (bool) { + return true; +} + +contract C { + function f() public pure returns (bool t) { + t = MyInt.wrap(2) + MyInt.wrap(7); + } +} +// ==== +// compileViaYul: also +// ---- +// f() -> true diff --git a/test/libsolidity/semanticTests/operators/custom/all_operators.sol b/test/libsolidity/semanticTests/operators/custom/all_operators.sol new file mode 100644 index 000000000..75ee7d885 --- /dev/null +++ b/test/libsolidity/semanticTests/operators/custom/all_operators.sol @@ -0,0 +1,75 @@ +type Int is int128; +using { + bitor as |, bitand as &, bitxor as ^, bitnot as ~, + add as +, sub as -, unsub as -, mul as *, div as /, mod as %, + eq as ==, noteq as !=, lt as <, gt as >, leq as <=, geq as >= +} for Int; + +function uw(Int x) pure returns (int128) { + return Int.unwrap(x); +} +function w(int128 x) pure returns (Int) { + return Int.wrap(x); +} +function bitor(Int, Int) pure returns (Int) { + return w(1); +} +function bitand(Int, Int) pure returns (Int) { + return w(2); +} +function bitxor(Int, Int) pure returns (Int) { + return w(3); +} +function bitnot(Int) pure returns (Int) { + return w(4); +} +function add(Int, Int) pure returns (Int) { + return w(5); +} +function sub(Int, Int) pure returns (Int) { + return w(6); +} +function unsub(Int) pure returns (Int) { + return w(7); +} +function mul(Int, Int) pure returns (Int) { + return w(8); +} +function div(Int, Int) pure returns (Int) { + return w(9); +} +function mod(Int, Int) pure returns (Int) { + return w(10); +} +function eq(Int x, Int) pure returns (bool) { + return uw(x) == 1; +} +function noteq(Int x, Int) pure returns (bool) { + return uw(x) == 2; +} +function lt(Int x, Int) pure returns (bool) { + return uw(x) == 3; +} +function gt(Int x, Int) pure returns (bool) { + return uw(x) == 4; +} +function leq(Int x, Int) pure returns (bool) { + return uw(x) == 5; +} +function geq(Int x, Int) pure returns (bool) { + return uw(x) == 6; +} + +// TODO test that side-effects are executed properly. +contract C { + function f1() public pure returns (Int) { + require(w(1) | w(2) == w(1)); + require(!(w(1) | w(2) == w(2))); + return w(1) | w(2); + } + // TODO all the other operators +} +// ==== +// compileViaYul: also +// ---- +// f1() diff --git a/test/libsolidity/semanticTests/operators/custom/fixedpoint.sol b/test/libsolidity/semanticTests/operators/custom/fixedpoint.sol new file mode 100644 index 000000000..fa119abf2 --- /dev/null +++ b/test/libsolidity/semanticTests/operators/custom/fixedpoint.sol @@ -0,0 +1,24 @@ +type Fixed is int128; +using {add as +, mul as *} for Fixed; + +int constant MULTIPLIER = 10**18; + +function add(Fixed a, Fixed b) pure returns (Fixed) { + return Fixed.wrap(Fixed.unwrap(a) + Fixed.unwrap(b)); +} + +function mul(Fixed a, Fixed b) pure returns (Fixed) { + int intermediate = (int(Fixed.unwrap(a)) * int(Fixed.unwrap(b))) / MULTIPLIER; + if (int128(intermediate) != intermediate) { revert("Overflow"); } + return Fixed.wrap(int128(intermediate)); +} + +contract C { + function applyInterest(Fixed value, Fixed percentage) public pure returns (Fixed result) { + return value + value * percentage; + } +} +// ==== +// compileViaYul: also +// ---- +// applyInterest(int128,int128): 500000000000000000000, 100000000000000000 -> 550000000000000000000 diff --git a/test/libsolidity/syntaxTests/constants/constant_cyclic_via_user_operators.sol b/test/libsolidity/syntaxTests/constants/constant_cyclic_via_user_operators.sol new file mode 100644 index 000000000..e3f75ce44 --- /dev/null +++ b/test/libsolidity/syntaxTests/constants/constant_cyclic_via_user_operators.sol @@ -0,0 +1,10 @@ +type Type is uint; +using {f as +} for Type; +function f(Type, Type) pure returns (Type) {} + +Type constant t = Type.wrap(1); +Type constant u = v + t; +Type constant v = u + t; +// ---- +// TypeError 8349: (141-146): Initial value for constant variable has to be compile-time constant. +// TypeError 8349: (166-171): Initial value for constant variable has to be compile-time constant. diff --git a/test/libsolidity/syntaxTests/using/operator_for_builtin.sol b/test/libsolidity/syntaxTests/using/operator_for_builtin.sol new file mode 100644 index 000000000..e2b9ee5d9 --- /dev/null +++ b/test/libsolidity/syntaxTests/using/operator_for_builtin.sol @@ -0,0 +1,4 @@ +using {f as +} for uint; +function f(uint, uint) pure returns (uint) {} +// ---- +// TypeError 5332: (7-8): Operators can only be implemented for user-defined types and not for contracts.