From 56bcb525bca4ea25989d47372e1bc0eda616235f Mon Sep 17 00:00:00 2001 From: wechman Date: Thu, 7 Jul 2022 14:07:59 +0200 Subject: [PATCH] Unary operators with using for directive fix --- libsolidity/analysis/TypeChecker.cpp | 15 ++-- libsolidity/ast/Types.cpp | 7 +- libsolidity/ast/Types.h | 2 +- libsolidity/codegen/ExpressionCompiler.cpp | 41 ++++++++++ .../codegen/ir/IRGeneratorForStatements.cpp | 28 +++++++ .../operators/custom/all_operators.sol | 77 +++++++++++++------ 6 files changed, 136 insertions(+), 34 deletions(-) diff --git a/libsolidity/analysis/TypeChecker.cpp b/libsolidity/analysis/TypeChecker.cpp index 987ad1177..3d3ad0244 100644 --- a/libsolidity/analysis/TypeChecker.cpp +++ b/libsolidity/analysis/TypeChecker.cpp @@ -1733,7 +1733,8 @@ bool TypeChecker::visit(UnaryOperation const& _operation) // Check if the operator is built-in or user-defined. FunctionDefinition const* userDefinedOperator = subExprType->userDefinedOperator( _operation.getOperator(), - *currentDefinitionScope() + *currentDefinitionScope(), + true // _unaryOperation ); _operation.annotation().userDefinedFunction = userDefinedOperator; FunctionType const* userDefinedFunctionType = nullptr; @@ -1791,7 +1792,8 @@ void TypeChecker::endVisit(BinaryOperation const& _operation) // Check if the operator is built-in or user-defined. FunctionDefinition const* userDefinedOperator = leftType->userDefinedOperator( _operation.getOperator(), - *currentDefinitionScope() + *currentDefinitionScope(), + false // _unaryOperation ); _operation.annotation().userDefinedFunction = userDefinedOperator; FunctionType const* userDefinedFunctionType = nullptr; @@ -3899,15 +3901,10 @@ void TypeChecker::endVisit(UsingForDirective const& _usingFor) ); continue; } - // "-" can be used as unary and binary operator. - bool isUnaryNegation = ( - operator_ == Token::Sub && - functionType->parameterTypesIncludingSelf().size() == 1 - ); + if ( ( - (TokenTraits::isBinaryOp(*operator_) && !isUnaryNegation) || - TokenTraits::isCompareOp(*operator_) + (TokenTraits::isBinaryOp(*operator_) && !TokenTraits::isUnaryOp(*operator_)) || TokenTraits::isCompareOp(*operator_) ) && ( functionType->parameterTypesIncludingSelf().size() != 2 || diff --git a/libsolidity/ast/Types.cpp b/libsolidity/ast/Types.cpp index b3bb6847c..7d8bef9cd 100644 --- a/libsolidity/ast/Types.cpp +++ b/libsolidity/ast/Types.cpp @@ -384,7 +384,7 @@ vector usingForDirectivesForType(Type const& _type, AS } -FunctionDefinition const* Type::userDefinedOperator(Token _token, ASTNode const& _scope) const +FunctionDefinition const* Type::userDefinedOperator(Token _token, ASTNode const& _scope, bool _unaryOperation) const { // Check if it is a user-defined type. if (!typeDefinition()) @@ -405,8 +405,11 @@ FunctionDefinition const* Type::userDefinedOperator(Token _token, ASTNode const& solAssert(functionType && !functionType->parameterTypes().empty()); // TODO does this work (data location)? solAssert(isImplicitlyConvertibleTo(*functionType->parameterTypes().front())); - seenFunctions.insert(&function); + if ((_unaryOperation && function.parameterList().parameters().size() == 1) || + (!_unaryOperation && function.parameterList().parameters().size() == 2)) + seenFunctions.insert(&function); } + // TODO proper error handling. if (seenFunctions.size() == 1) return *seenFunctions.begin(); diff --git a/libsolidity/ast/Types.h b/libsolidity/ast/Types.h index d41ae3033..eec9974e0 100644 --- a/libsolidity/ast/Types.h +++ b/libsolidity/ast/Types.h @@ -377,7 +377,7 @@ public: /// Clears all internally cached values (if any). virtual void clearCache() const; - FunctionDefinition const* userDefinedOperator(Token _token, ASTNode const& _scope) const; + FunctionDefinition const* userDefinedOperator(Token _token, ASTNode const& _scope, bool _unaryOperation) const; private: /// @returns a member list containing all members added to this type by `using for` directives. diff --git a/libsolidity/codegen/ExpressionCompiler.cpp b/libsolidity/codegen/ExpressionCompiler.cpp index 635c279ee..68d2547a2 100644 --- a/libsolidity/codegen/ExpressionCompiler.cpp +++ b/libsolidity/codegen/ExpressionCompiler.cpp @@ -410,6 +410,47 @@ bool ExpressionCompiler::visit(TupleExpression const& _tuple) bool ExpressionCompiler::visit(UnaryOperation const& _unaryOperation) { CompilerContext::LocationSetter locationSetter(m_context, _unaryOperation); + + if (_unaryOperation.annotation().userDefinedFunction) + { + FunctionDefinition const& function = *_unaryOperation.annotation().userDefinedFunction; + FunctionType const* functionType = dynamic_cast( + function.libraryFunction() ? function.typeViaContractName() : function.type()); + + solAssert(functionType); + functionType = dynamic_cast(*functionType).asBoundFunction(); + solAssert(functionType); + evmasm::AssemblyItem returnLabel = m_context.pushNewTag(); + _unaryOperation.subExpression().accept(*this); + utils().pushCombinedFunctionEntryLabel( + function.resolveVirtual(m_context.mostDerivedContract()), + false + ); + + unsigned parameterSize = + CompilerUtils::sizeOnStack(functionType->parameterTypes()) + + functionType->selfType()->sizeOnStack(); + + if (m_context.runtimeContext()) + // We have a runtime context, so we need the creation part. + utils().rightShiftNumberOnStack(32); + else + // Extract the runtime part. + m_context << ((u256(1) << 32) - 1) << Instruction::AND; + + m_context.appendJump(evmasm::AssemblyItem::JumpType::IntoFunction); + m_context << returnLabel; + + unsigned returnParametersSize = CompilerUtils::sizeOnStack(functionType->returnParameterTypes()); + + // callee adds return parameters, but removes arguments and return label + m_context.adjustStackOffset(static_cast(returnParametersSize - parameterSize) - 1); + + return false; + } + + + Type const& type = *_unaryOperation.annotation().type; if (type.category() == Type::Category::RationalNumber) { diff --git a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp index 61e70f257..f23307faf 100644 --- a/libsolidity/codegen/ir/IRGeneratorForStatements.cpp +++ b/libsolidity/codegen/ir/IRGeneratorForStatements.cpp @@ -672,6 +672,34 @@ void IRGeneratorForStatements::endVisit(Return const& _return) bool IRGeneratorForStatements::visit(UnaryOperation const& _unaryOperation) { setLocation(_unaryOperation); + + if (_unaryOperation.annotation().userDefinedFunction) + { + _unaryOperation.subExpression().accept(*this); + setLocation(_unaryOperation); + + // TODO extract from function call + FunctionDefinition const& function = *_unaryOperation.annotation().userDefinedFunction; + FunctionType const* functionType = dynamic_cast( + function.libraryFunction() ? function.typeViaContractName() : function.type() + ); + solAssert(functionType); + functionType = dynamic_cast(*functionType).asBoundFunction(); + solAssert(functionType); + + // TODO virtual? + + string parameter = expressionAsType(_unaryOperation.subExpression(), *functionType->selfType()); + solAssert(!parameter.empty()); + solAssert(function.isImplemented(), ""); + + define(_unaryOperation) << + m_context.enqueueFunctionForCodeGeneration(function) << + ("(" + parameter + ")\n"); + + return false; + } + Type const& resultType = type(_unaryOperation); Token const op = _unaryOperation.getOperator(); diff --git a/test/libsolidity/semanticTests/operators/custom/all_operators.sol b/test/libsolidity/semanticTests/operators/custom/all_operators.sol index 75ee7d885..33df2b851 100644 --- a/test/libsolidity/semanticTests/operators/custom/all_operators.sol +++ b/test/libsolidity/semanticTests/operators/custom/all_operators.sol @@ -12,34 +12,34 @@ function w(int128 x) pure returns (Int) { return Int.wrap(x); } function bitor(Int, Int) pure returns (Int) { - return w(1); + return w(10); } function bitand(Int, Int) pure returns (Int) { - return w(2); + return w(11); } function bitxor(Int, Int) pure returns (Int) { - return w(3); + return w(12); } function bitnot(Int) pure returns (Int) { - return w(4); + return w(13); } -function add(Int, Int) pure returns (Int) { - return w(5); +function add(Int x, Int) pure returns (int128) { + return uw(x) + 10; } function sub(Int, Int) pure returns (Int) { - return w(6); + return w(15); } function unsub(Int) pure returns (Int) { - return w(7); + return w(16); } function mul(Int, Int) pure returns (Int) { - return w(8); + return w(17); } function div(Int, Int) pure returns (Int) { - return w(9); + return w(18); } function mod(Int, Int) pure returns (Int) { - return w(10); + return w(19); } function eq(Int x, Int) pure returns (bool) { return uw(x) == 1; @@ -48,28 +48,61 @@ function noteq(Int x, Int) pure returns (bool) { return uw(x) == 2; } function lt(Int x, Int) pure returns (bool) { - return uw(x) == 3; + return uw(x) < 10; } function gt(Int x, Int) pure returns (bool) { - return uw(x) == 4; + return uw(x) > 10; } function leq(Int x, Int) pure returns (bool) { - return uw(x) == 5; + return uw(x) <= 10; } function geq(Int x, Int) pure returns (bool) { - return uw(x) == 6; + return uw(x) >= 10; } // TODO test that side-effects are executed properly. contract C { - function f1() public pure returns (Int) { - require(w(1) | w(2) == w(1)); - require(!(w(1) | w(2) == w(2))); - return w(1) | w(2); - } - // TODO all the other operators + function test_bitor() public pure returns (Int) { return w(1) | w(2); } + function test_bitand() public pure returns (Int) { return w(1) | w(2); } + function test_bitxor() public pure returns (Int) { return w(1) ^ w(2); } + function test_bitnot() public pure returns (Int) { return ~w(1); } + function test_add(int128 x) public pure returns (int128) { return w(x) + w(2); } + function test_sub() public pure returns (Int) { return w(1) - w(2); } + function test_unsub() public pure returns (Int) { return -w(1); } + function test_mul() public pure returns (Int) { return w(1) * w(2); } + function test_div() public pure returns (Int) { return w(1) / w(2); } + function test_mod() public pure returns (Int) { return w(1) % w(2); } + function test_eq(int128 x) public pure returns (bool) { return w(x) == w(2); } + function test_neq(int128 x) public pure returns (bool) { return w(x) != w(2); } + function test_lt(int128 x) public pure returns (bool) { return w(x) < w(2); } + function test_gt(int128 x) public pure returns (bool) { return w(x) > w(2); } + function test_leq(int128 x) public pure returns (bool) { return w(x) <= w(2); } + function test_geq(int128 x) public pure returns (bool) { return w(x) >= w(2); } } + // ==== // compileViaYul: also // ---- -// f1() +// test_bitor() -> 10 +// test_bitand() -> 10 +// test_bitxor() -> 12 +// test_bitnot() -> 13 +// test_add(int128): 4 -> 14 +// test_add(int128): 104 -> 114 +// test_sub() -> 15 +// test_unsub() -> 16 +// test_mul() -> 17 +// test_div() -> 18 +// test_mod() -> 19 +// test_eq(int128): 1 -> true +// test_eq(int128): 2 -> false +// test_neq(int128): 2 -> true +// test_neq(int128): 1 -> false +// test_lt(int128): 9 -> true +// test_lt(int128): 10 -> false +// test_gt(int128): 11 -> true +// test_gt(int128): 10 -> false +// test_leq(int128): 10 -> true +// test_leq(int128): 11 -> false +// test_geq(int128): 10 -> true +// test_geq(int128): 9 -> false