Unary operators with using for directive fix

This commit is contained in:
wechman 2022-07-07 14:07:59 +02:00
parent 6482f5bb17
commit 56bcb525bc
6 changed files with 136 additions and 34 deletions

View File

@ -1733,7 +1733,8 @@ bool TypeChecker::visit(UnaryOperation const& _operation)
// Check if the operator is built-in or user-defined.
FunctionDefinition const* userDefinedOperator = subExprType->userDefinedOperator(
_operation.getOperator(),
*currentDefinitionScope()
*currentDefinitionScope(),
true // _unaryOperation
);
_operation.annotation().userDefinedFunction = userDefinedOperator;
FunctionType const* userDefinedFunctionType = nullptr;
@ -1791,7 +1792,8 @@ void TypeChecker::endVisit(BinaryOperation const& _operation)
// Check if the operator is built-in or user-defined.
FunctionDefinition const* userDefinedOperator = leftType->userDefinedOperator(
_operation.getOperator(),
*currentDefinitionScope()
*currentDefinitionScope(),
false // _unaryOperation
);
_operation.annotation().userDefinedFunction = userDefinedOperator;
FunctionType const* userDefinedFunctionType = nullptr;
@ -3899,15 +3901,10 @@ void TypeChecker::endVisit(UsingForDirective const& _usingFor)
);
continue;
}
// "-" can be used as unary and binary operator.
bool isUnaryNegation = (
operator_ == Token::Sub &&
functionType->parameterTypesIncludingSelf().size() == 1
);
if (
(
(TokenTraits::isBinaryOp(*operator_) && !isUnaryNegation) ||
TokenTraits::isCompareOp(*operator_)
(TokenTraits::isBinaryOp(*operator_) && !TokenTraits::isUnaryOp(*operator_)) || TokenTraits::isCompareOp(*operator_)
) &&
(
functionType->parameterTypesIncludingSelf().size() != 2 ||

View File

@ -384,7 +384,7 @@ vector<UsingForDirective const*> usingForDirectivesForType(Type const& _type, AS
}
FunctionDefinition const* Type::userDefinedOperator(Token _token, ASTNode const& _scope) const
FunctionDefinition const* Type::userDefinedOperator(Token _token, ASTNode const& _scope, bool _unaryOperation) const
{
// Check if it is a user-defined type.
if (!typeDefinition())
@ -405,8 +405,11 @@ FunctionDefinition const* Type::userDefinedOperator(Token _token, ASTNode const&
solAssert(functionType && !functionType->parameterTypes().empty());
// TODO does this work (data location)?
solAssert(isImplicitlyConvertibleTo(*functionType->parameterTypes().front()));
seenFunctions.insert(&function);
if ((_unaryOperation && function.parameterList().parameters().size() == 1) ||
(!_unaryOperation && function.parameterList().parameters().size() == 2))
seenFunctions.insert(&function);
}
// TODO proper error handling.
if (seenFunctions.size() == 1)
return *seenFunctions.begin();

View File

@ -377,7 +377,7 @@ public:
/// Clears all internally cached values (if any).
virtual void clearCache() const;
FunctionDefinition const* userDefinedOperator(Token _token, ASTNode const& _scope) const;
FunctionDefinition const* userDefinedOperator(Token _token, ASTNode const& _scope, bool _unaryOperation) const;
private:
/// @returns a member list containing all members added to this type by `using for` directives.

View File

@ -410,6 +410,47 @@ bool ExpressionCompiler::visit(TupleExpression const& _tuple)
bool ExpressionCompiler::visit(UnaryOperation const& _unaryOperation)
{
CompilerContext::LocationSetter locationSetter(m_context, _unaryOperation);
if (_unaryOperation.annotation().userDefinedFunction)
{
FunctionDefinition const& function = *_unaryOperation.annotation().userDefinedFunction;
FunctionType const* functionType = dynamic_cast<FunctionType const*>(
function.libraryFunction() ? function.typeViaContractName() : function.type());
solAssert(functionType);
functionType = dynamic_cast<FunctionType const&>(*functionType).asBoundFunction();
solAssert(functionType);
evmasm::AssemblyItem returnLabel = m_context.pushNewTag();
_unaryOperation.subExpression().accept(*this);
utils().pushCombinedFunctionEntryLabel(
function.resolveVirtual(m_context.mostDerivedContract()),
false
);
unsigned parameterSize =
CompilerUtils::sizeOnStack(functionType->parameterTypes()) +
functionType->selfType()->sizeOnStack();
if (m_context.runtimeContext())
// We have a runtime context, so we need the creation part.
utils().rightShiftNumberOnStack(32);
else
// Extract the runtime part.
m_context << ((u256(1) << 32) - 1) << Instruction::AND;
m_context.appendJump(evmasm::AssemblyItem::JumpType::IntoFunction);
m_context << returnLabel;
unsigned returnParametersSize = CompilerUtils::sizeOnStack(functionType->returnParameterTypes());
// callee adds return parameters, but removes arguments and return label
m_context.adjustStackOffset(static_cast<int>(returnParametersSize - parameterSize) - 1);
return false;
}
Type const& type = *_unaryOperation.annotation().type;
if (type.category() == Type::Category::RationalNumber)
{

View File

@ -672,6 +672,34 @@ void IRGeneratorForStatements::endVisit(Return const& _return)
bool IRGeneratorForStatements::visit(UnaryOperation const& _unaryOperation)
{
setLocation(_unaryOperation);
if (_unaryOperation.annotation().userDefinedFunction)
{
_unaryOperation.subExpression().accept(*this);
setLocation(_unaryOperation);
// TODO extract from function call
FunctionDefinition const& function = *_unaryOperation.annotation().userDefinedFunction;
FunctionType const* functionType = dynamic_cast<FunctionType const*>(
function.libraryFunction() ? function.typeViaContractName() : function.type()
);
solAssert(functionType);
functionType = dynamic_cast<FunctionType const&>(*functionType).asBoundFunction();
solAssert(functionType);
// TODO virtual?
string parameter = expressionAsType(_unaryOperation.subExpression(), *functionType->selfType());
solAssert(!parameter.empty());
solAssert(function.isImplemented(), "");
define(_unaryOperation) <<
m_context.enqueueFunctionForCodeGeneration(function) <<
("(" + parameter + ")\n");
return false;
}
Type const& resultType = type(_unaryOperation);
Token const op = _unaryOperation.getOperator();

View File

@ -12,34 +12,34 @@ function w(int128 x) pure returns (Int) {
return Int.wrap(x);
}
function bitor(Int, Int) pure returns (Int) {
return w(1);
return w(10);
}
function bitand(Int, Int) pure returns (Int) {
return w(2);
return w(11);
}
function bitxor(Int, Int) pure returns (Int) {
return w(3);
return w(12);
}
function bitnot(Int) pure returns (Int) {
return w(4);
return w(13);
}
function add(Int, Int) pure returns (Int) {
return w(5);
function add(Int x, Int) pure returns (int128) {
return uw(x) + 10;
}
function sub(Int, Int) pure returns (Int) {
return w(6);
return w(15);
}
function unsub(Int) pure returns (Int) {
return w(7);
return w(16);
}
function mul(Int, Int) pure returns (Int) {
return w(8);
return w(17);
}
function div(Int, Int) pure returns (Int) {
return w(9);
return w(18);
}
function mod(Int, Int) pure returns (Int) {
return w(10);
return w(19);
}
function eq(Int x, Int) pure returns (bool) {
return uw(x) == 1;
@ -48,28 +48,61 @@ function noteq(Int x, Int) pure returns (bool) {
return uw(x) == 2;
}
function lt(Int x, Int) pure returns (bool) {
return uw(x) == 3;
return uw(x) < 10;
}
function gt(Int x, Int) pure returns (bool) {
return uw(x) == 4;
return uw(x) > 10;
}
function leq(Int x, Int) pure returns (bool) {
return uw(x) == 5;
return uw(x) <= 10;
}
function geq(Int x, Int) pure returns (bool) {
return uw(x) == 6;
return uw(x) >= 10;
}
// TODO test that side-effects are executed properly.
contract C {
function f1() public pure returns (Int) {
require(w(1) | w(2) == w(1));
require(!(w(1) | w(2) == w(2)));
return w(1) | w(2);
}
// TODO all the other operators
function test_bitor() public pure returns (Int) { return w(1) | w(2); }
function test_bitand() public pure returns (Int) { return w(1) | w(2); }
function test_bitxor() public pure returns (Int) { return w(1) ^ w(2); }
function test_bitnot() public pure returns (Int) { return ~w(1); }
function test_add(int128 x) public pure returns (int128) { return w(x) + w(2); }
function test_sub() public pure returns (Int) { return w(1) - w(2); }
function test_unsub() public pure returns (Int) { return -w(1); }
function test_mul() public pure returns (Int) { return w(1) * w(2); }
function test_div() public pure returns (Int) { return w(1) / w(2); }
function test_mod() public pure returns (Int) { return w(1) % w(2); }
function test_eq(int128 x) public pure returns (bool) { return w(x) == w(2); }
function test_neq(int128 x) public pure returns (bool) { return w(x) != w(2); }
function test_lt(int128 x) public pure returns (bool) { return w(x) < w(2); }
function test_gt(int128 x) public pure returns (bool) { return w(x) > w(2); }
function test_leq(int128 x) public pure returns (bool) { return w(x) <= w(2); }
function test_geq(int128 x) public pure returns (bool) { return w(x) >= w(2); }
}
// ====
// compileViaYul: also
// ----
// f1()
// test_bitor() -> 10
// test_bitand() -> 10
// test_bitxor() -> 12
// test_bitnot() -> 13
// test_add(int128): 4 -> 14
// test_add(int128): 104 -> 114
// test_sub() -> 15
// test_unsub() -> 16
// test_mul() -> 17
// test_div() -> 18
// test_mod() -> 19
// test_eq(int128): 1 -> true
// test_eq(int128): 2 -> false
// test_neq(int128): 2 -> true
// test_neq(int128): 1 -> false
// test_lt(int128): 9 -> true
// test_lt(int128): 10 -> false
// test_gt(int128): 11 -> true
// test_gt(int128): 10 -> false
// test_leq(int128): 10 -> true
// test_leq(int128): 11 -> false
// test_geq(int128): 10 -> true
// test_geq(int128): 9 -> false