Unary operators with using for directive fix

This commit is contained in:
wechman 2022-07-07 14:07:59 +02:00
parent 6482f5bb17
commit 56bcb525bc
6 changed files with 136 additions and 34 deletions

View File

@ -1733,7 +1733,8 @@ bool TypeChecker::visit(UnaryOperation const& _operation)
// Check if the operator is built-in or user-defined. // Check if the operator is built-in or user-defined.
FunctionDefinition const* userDefinedOperator = subExprType->userDefinedOperator( FunctionDefinition const* userDefinedOperator = subExprType->userDefinedOperator(
_operation.getOperator(), _operation.getOperator(),
*currentDefinitionScope() *currentDefinitionScope(),
true // _unaryOperation
); );
_operation.annotation().userDefinedFunction = userDefinedOperator; _operation.annotation().userDefinedFunction = userDefinedOperator;
FunctionType const* userDefinedFunctionType = nullptr; FunctionType const* userDefinedFunctionType = nullptr;
@ -1791,7 +1792,8 @@ void TypeChecker::endVisit(BinaryOperation const& _operation)
// Check if the operator is built-in or user-defined. // Check if the operator is built-in or user-defined.
FunctionDefinition const* userDefinedOperator = leftType->userDefinedOperator( FunctionDefinition const* userDefinedOperator = leftType->userDefinedOperator(
_operation.getOperator(), _operation.getOperator(),
*currentDefinitionScope() *currentDefinitionScope(),
false // _unaryOperation
); );
_operation.annotation().userDefinedFunction = userDefinedOperator; _operation.annotation().userDefinedFunction = userDefinedOperator;
FunctionType const* userDefinedFunctionType = nullptr; FunctionType const* userDefinedFunctionType = nullptr;
@ -3899,15 +3901,10 @@ void TypeChecker::endVisit(UsingForDirective const& _usingFor)
); );
continue; continue;
} }
// "-" can be used as unary and binary operator.
bool isUnaryNegation = (
operator_ == Token::Sub &&
functionType->parameterTypesIncludingSelf().size() == 1
);
if ( if (
( (
(TokenTraits::isBinaryOp(*operator_) && !isUnaryNegation) || (TokenTraits::isBinaryOp(*operator_) && !TokenTraits::isUnaryOp(*operator_)) || TokenTraits::isCompareOp(*operator_)
TokenTraits::isCompareOp(*operator_)
) && ) &&
( (
functionType->parameterTypesIncludingSelf().size() != 2 || functionType->parameterTypesIncludingSelf().size() != 2 ||

View File

@ -384,7 +384,7 @@ vector<UsingForDirective const*> usingForDirectivesForType(Type const& _type, AS
} }
FunctionDefinition const* Type::userDefinedOperator(Token _token, ASTNode const& _scope) const FunctionDefinition const* Type::userDefinedOperator(Token _token, ASTNode const& _scope, bool _unaryOperation) const
{ {
// Check if it is a user-defined type. // Check if it is a user-defined type.
if (!typeDefinition()) if (!typeDefinition())
@ -405,8 +405,11 @@ FunctionDefinition const* Type::userDefinedOperator(Token _token, ASTNode const&
solAssert(functionType && !functionType->parameterTypes().empty()); solAssert(functionType && !functionType->parameterTypes().empty());
// TODO does this work (data location)? // TODO does this work (data location)?
solAssert(isImplicitlyConvertibleTo(*functionType->parameterTypes().front())); solAssert(isImplicitlyConvertibleTo(*functionType->parameterTypes().front()));
seenFunctions.insert(&function); if ((_unaryOperation && function.parameterList().parameters().size() == 1) ||
(!_unaryOperation && function.parameterList().parameters().size() == 2))
seenFunctions.insert(&function);
} }
// TODO proper error handling. // TODO proper error handling.
if (seenFunctions.size() == 1) if (seenFunctions.size() == 1)
return *seenFunctions.begin(); return *seenFunctions.begin();

View File

@ -377,7 +377,7 @@ public:
/// Clears all internally cached values (if any). /// Clears all internally cached values (if any).
virtual void clearCache() const; virtual void clearCache() const;
FunctionDefinition const* userDefinedOperator(Token _token, ASTNode const& _scope) const; FunctionDefinition const* userDefinedOperator(Token _token, ASTNode const& _scope, bool _unaryOperation) const;
private: private:
/// @returns a member list containing all members added to this type by `using for` directives. /// @returns a member list containing all members added to this type by `using for` directives.

View File

@ -410,6 +410,47 @@ bool ExpressionCompiler::visit(TupleExpression const& _tuple)
bool ExpressionCompiler::visit(UnaryOperation const& _unaryOperation) bool ExpressionCompiler::visit(UnaryOperation const& _unaryOperation)
{ {
CompilerContext::LocationSetter locationSetter(m_context, _unaryOperation); CompilerContext::LocationSetter locationSetter(m_context, _unaryOperation);
if (_unaryOperation.annotation().userDefinedFunction)
{
FunctionDefinition const& function = *_unaryOperation.annotation().userDefinedFunction;
FunctionType const* functionType = dynamic_cast<FunctionType const*>(
function.libraryFunction() ? function.typeViaContractName() : function.type());
solAssert(functionType);
functionType = dynamic_cast<FunctionType const&>(*functionType).asBoundFunction();
solAssert(functionType);
evmasm::AssemblyItem returnLabel = m_context.pushNewTag();
_unaryOperation.subExpression().accept(*this);
utils().pushCombinedFunctionEntryLabel(
function.resolveVirtual(m_context.mostDerivedContract()),
false
);
unsigned parameterSize =
CompilerUtils::sizeOnStack(functionType->parameterTypes()) +
functionType->selfType()->sizeOnStack();
if (m_context.runtimeContext())
// We have a runtime context, so we need the creation part.
utils().rightShiftNumberOnStack(32);
else
// Extract the runtime part.
m_context << ((u256(1) << 32) - 1) << Instruction::AND;
m_context.appendJump(evmasm::AssemblyItem::JumpType::IntoFunction);
m_context << returnLabel;
unsigned returnParametersSize = CompilerUtils::sizeOnStack(functionType->returnParameterTypes());
// callee adds return parameters, but removes arguments and return label
m_context.adjustStackOffset(static_cast<int>(returnParametersSize - parameterSize) - 1);
return false;
}
Type const& type = *_unaryOperation.annotation().type; Type const& type = *_unaryOperation.annotation().type;
if (type.category() == Type::Category::RationalNumber) if (type.category() == Type::Category::RationalNumber)
{ {

View File

@ -672,6 +672,34 @@ void IRGeneratorForStatements::endVisit(Return const& _return)
bool IRGeneratorForStatements::visit(UnaryOperation const& _unaryOperation) bool IRGeneratorForStatements::visit(UnaryOperation const& _unaryOperation)
{ {
setLocation(_unaryOperation); setLocation(_unaryOperation);
if (_unaryOperation.annotation().userDefinedFunction)
{
_unaryOperation.subExpression().accept(*this);
setLocation(_unaryOperation);
// TODO extract from function call
FunctionDefinition const& function = *_unaryOperation.annotation().userDefinedFunction;
FunctionType const* functionType = dynamic_cast<FunctionType const*>(
function.libraryFunction() ? function.typeViaContractName() : function.type()
);
solAssert(functionType);
functionType = dynamic_cast<FunctionType const&>(*functionType).asBoundFunction();
solAssert(functionType);
// TODO virtual?
string parameter = expressionAsType(_unaryOperation.subExpression(), *functionType->selfType());
solAssert(!parameter.empty());
solAssert(function.isImplemented(), "");
define(_unaryOperation) <<
m_context.enqueueFunctionForCodeGeneration(function) <<
("(" + parameter + ")\n");
return false;
}
Type const& resultType = type(_unaryOperation); Type const& resultType = type(_unaryOperation);
Token const op = _unaryOperation.getOperator(); Token const op = _unaryOperation.getOperator();

View File

@ -12,34 +12,34 @@ function w(int128 x) pure returns (Int) {
return Int.wrap(x); return Int.wrap(x);
} }
function bitor(Int, Int) pure returns (Int) { function bitor(Int, Int) pure returns (Int) {
return w(1); return w(10);
} }
function bitand(Int, Int) pure returns (Int) { function bitand(Int, Int) pure returns (Int) {
return w(2); return w(11);
} }
function bitxor(Int, Int) pure returns (Int) { function bitxor(Int, Int) pure returns (Int) {
return w(3); return w(12);
} }
function bitnot(Int) pure returns (Int) { function bitnot(Int) pure returns (Int) {
return w(4); return w(13);
} }
function add(Int, Int) pure returns (Int) { function add(Int x, Int) pure returns (int128) {
return w(5); return uw(x) + 10;
} }
function sub(Int, Int) pure returns (Int) { function sub(Int, Int) pure returns (Int) {
return w(6); return w(15);
} }
function unsub(Int) pure returns (Int) { function unsub(Int) pure returns (Int) {
return w(7); return w(16);
} }
function mul(Int, Int) pure returns (Int) { function mul(Int, Int) pure returns (Int) {
return w(8); return w(17);
} }
function div(Int, Int) pure returns (Int) { function div(Int, Int) pure returns (Int) {
return w(9); return w(18);
} }
function mod(Int, Int) pure returns (Int) { function mod(Int, Int) pure returns (Int) {
return w(10); return w(19);
} }
function eq(Int x, Int) pure returns (bool) { function eq(Int x, Int) pure returns (bool) {
return uw(x) == 1; return uw(x) == 1;
@ -48,28 +48,61 @@ function noteq(Int x, Int) pure returns (bool) {
return uw(x) == 2; return uw(x) == 2;
} }
function lt(Int x, Int) pure returns (bool) { function lt(Int x, Int) pure returns (bool) {
return uw(x) == 3; return uw(x) < 10;
} }
function gt(Int x, Int) pure returns (bool) { function gt(Int x, Int) pure returns (bool) {
return uw(x) == 4; return uw(x) > 10;
} }
function leq(Int x, Int) pure returns (bool) { function leq(Int x, Int) pure returns (bool) {
return uw(x) == 5; return uw(x) <= 10;
} }
function geq(Int x, Int) pure returns (bool) { function geq(Int x, Int) pure returns (bool) {
return uw(x) == 6; return uw(x) >= 10;
} }
// TODO test that side-effects are executed properly. // TODO test that side-effects are executed properly.
contract C { contract C {
function f1() public pure returns (Int) { function test_bitor() public pure returns (Int) { return w(1) | w(2); }
require(w(1) | w(2) == w(1)); function test_bitand() public pure returns (Int) { return w(1) | w(2); }
require(!(w(1) | w(2) == w(2))); function test_bitxor() public pure returns (Int) { return w(1) ^ w(2); }
return w(1) | w(2); function test_bitnot() public pure returns (Int) { return ~w(1); }
} function test_add(int128 x) public pure returns (int128) { return w(x) + w(2); }
// TODO all the other operators function test_sub() public pure returns (Int) { return w(1) - w(2); }
function test_unsub() public pure returns (Int) { return -w(1); }
function test_mul() public pure returns (Int) { return w(1) * w(2); }
function test_div() public pure returns (Int) { return w(1) / w(2); }
function test_mod() public pure returns (Int) { return w(1) % w(2); }
function test_eq(int128 x) public pure returns (bool) { return w(x) == w(2); }
function test_neq(int128 x) public pure returns (bool) { return w(x) != w(2); }
function test_lt(int128 x) public pure returns (bool) { return w(x) < w(2); }
function test_gt(int128 x) public pure returns (bool) { return w(x) > w(2); }
function test_leq(int128 x) public pure returns (bool) { return w(x) <= w(2); }
function test_geq(int128 x) public pure returns (bool) { return w(x) >= w(2); }
} }
// ==== // ====
// compileViaYul: also // compileViaYul: also
// ---- // ----
// f1() // test_bitor() -> 10
// test_bitand() -> 10
// test_bitxor() -> 12
// test_bitnot() -> 13
// test_add(int128): 4 -> 14
// test_add(int128): 104 -> 114
// test_sub() -> 15
// test_unsub() -> 16
// test_mul() -> 17
// test_div() -> 18
// test_mod() -> 19
// test_eq(int128): 1 -> true
// test_eq(int128): 2 -> false
// test_neq(int128): 2 -> true
// test_neq(int128): 1 -> false
// test_lt(int128): 9 -> true
// test_lt(int128): 10 -> false
// test_gt(int128): 11 -> true
// test_gt(int128): 10 -> false
// test_leq(int128): 10 -> true
// test_leq(int128): 11 -> false
// test_geq(int128): 10 -> true
// test_geq(int128): 9 -> false