Using for for operators.

This commit is contained in:
chriseth 2021-12-01 16:22:50 +01:00 committed by wechman
parent 2201526a90
commit 3bd047f188
28 changed files with 607 additions and 93 deletions

View File

@ -5,6 +5,7 @@ Language Features:
Compiler Features:
* Yul Optimizer: Allow replacing the previously hard-coded cleanup sequence by specifying custom steps after a colon delimiter (``:``) in the sequence string.
* Allow user-defined operators via ``using {f as +} for Typename;``.
Bugfixes:

View File

@ -63,6 +63,7 @@ bool ControlFlowBuilder::visit(BinaryOperation const& _operation)
case Token::Or:
case Token::And:
{
solAssert(!_operation.annotation().userDefinedFunction);
visitNode(_operation);
appendControlFlow(_operation.leftExpression());
@ -73,8 +74,39 @@ bool ControlFlowBuilder::visit(BinaryOperation const& _operation)
return false;
}
default:
return ASTConstVisitor::visit(_operation);
{
ASTConstVisitor::visit(_operation);
if (_operation.annotation().userDefinedFunction)
{
solAssert(!m_currentNode->resolveFunctionCall(nullptr));
m_currentNode->functionCall = _operation.annotation().userDefinedFunction;
auto nextNode = newLabel();
connect(m_currentNode, nextNode);
m_currentNode = nextNode;
}
return false;
}
}
}
bool ControlFlowBuilder::visit(UnaryOperation const& _operation)
{
solAssert(!!m_currentNode, "");
ASTConstVisitor::visit(_operation);
if (_operation.annotation().userDefinedFunction)
{
solAssert(!m_currentNode->resolveFunctionCall(nullptr));
m_currentNode->functionCall = _operation.annotation().userDefinedFunction;
auto nextNode = newLabel();
connect(m_currentNode, nextNode);
m_currentNode = nextNode;
}
return false;
}
bool ControlFlowBuilder::visit(Conditional const& _conditional)
@ -300,7 +332,7 @@ bool ControlFlowBuilder::visit(FunctionCall const& _functionCall)
_functionCall.expression().accept(*this);
ASTNode::listAccept(_functionCall.arguments(), *this);
solAssert(!m_currentNode->functionCall);
solAssert(!m_currentNode->resolveFunctionCall(nullptr));
m_currentNode->functionCall = &_functionCall;
auto nextNode = newLabel();

View File

@ -50,6 +50,7 @@ private:
// Visits for constructing the control flow.
bool visit(BinaryOperation const& _operation) override;
bool visit(UnaryOperation const& _operation) override;
bool visit(Conditional const& _conditional) override;
bool visit(TryStatement const& _tryStatement) override;
bool visit(IfStatement const& _ifStatement) override;

View File

@ -19,11 +19,21 @@
#include <libsolidity/analysis/ControlFlowGraph.h>
#include <libsolidity/analysis/ControlFlowBuilder.h>
#include <libsolutil/Visitor.h>
using namespace std;
using namespace solidity::util;
using namespace solidity::langutil;
using namespace solidity::frontend;
FunctionDefinition const* CFGNode::resolveFunctionCall(ContractDefinition const* _mostDerivedContract) const
{
return std::visit(GenericVisitor{
[=](FunctionCall const* _funCall) { return _funCall ? ASTNode::resolveFunctionCall(*_funCall, _mostDerivedContract) : nullptr; },
[](FunctionDefinition const* _funDef) { return _funDef; }
}, functionCall);
}
bool CFG::constructFlow(ASTNode const& _astRoot)
{
_astRoot.accept(*this);

View File

@ -29,6 +29,7 @@
#include <stack>
#include <utility>
#include <vector>
#include <variant>
namespace solidity::frontend
{
@ -98,8 +99,13 @@ struct CFGNode
std::vector<CFGNode*> entries;
/// Exit nodes. All CFG nodes to which control flow may continue after this node.
std::vector<CFGNode*> exits;
/// Function call done by this node
FunctionCall const* functionCall = nullptr;
/// Function call done by this node, either a proper function call (allows virtual lookup)
/// or a direct function definition reference (in case of an operator),
/// or nullptr.
std::variant<FunctionCall const*, FunctionDefinition const*> functionCall = static_cast<FunctionCall const*>(nullptr);
/// @returns the actual function called given a most derived contract. If no function is called
/// in this node, returns nullptr.
FunctionDefinition const* resolveFunctionCall(ContractDefinition const* _mostDerivedContract) const;
/// Variable occurrences in the node.
std::vector<VariableOccurrence> variableOccurrences;

View File

@ -81,10 +81,7 @@ void ControlFlowRevertPruner::findRevertStates()
if (_node == functionFlow.exit)
foundExit = true;
if (auto const* functionCall = _node->functionCall)
{
auto const* resolvedFunction = ASTNode::resolveFunctionCall(*functionCall, item.contract);
auto const* resolvedFunction = _node->resolveFunctionCall(item.contract);
if (resolvedFunction && resolvedFunction->isImplemented())
{
CFG::FunctionContractTuple calledFunctionTuple{
@ -103,7 +100,6 @@ void ControlFlowRevertPruner::findRevertStates()
break;
}
}
}
for (CFGNode* exit: _node->exits)
_addChild(exit);
@ -135,10 +131,7 @@ void ControlFlowRevertPruner::modifyFunctionFlows()
FunctionFlow const& functionFlow = m_cfg.functionFlow(*item.first.function, item.first.contract);
solidity::util::BreadthFirstSearch<CFGNode*>{{functionFlow.entry}}.run(
[&](CFGNode* _node, auto&& _addChild) {
if (auto const* functionCall = _node->functionCall)
{
auto const* resolvedFunction = ASTNode::resolveFunctionCall(*functionCall, item.first.contract);
auto const* resolvedFunction = _node->resolveFunctionCall(item.first.contract);
if (resolvedFunction && resolvedFunction->isImplemented())
switch (m_functions.at({findScopeContract(*resolvedFunction, item.first.contract), resolvedFunction}))
{
@ -158,7 +151,6 @@ void ControlFlowRevertPruner::modifyFunctionFlows()
default:
break;
}
}
for (CFGNode* exit: _node->exits)
_addChild(exit);

View File

@ -204,6 +204,20 @@ bool FunctionCallGraphBuilder::visit(MemberAccess const& _memberAccess)
return true;
}
bool FunctionCallGraphBuilder::visit(BinaryOperation const& _binaryOperation)
{
if (FunctionDefinition const* function = _binaryOperation.annotation().userDefinedFunction)
functionReferenced(*function, true /* called directly */);
return true;
}
bool FunctionCallGraphBuilder::visit(UnaryOperation const& _unaryOperation)
{
if (FunctionDefinition const* function = _unaryOperation.annotation().userDefinedFunction)
functionReferenced(*function, true /* called directly */);
return true;
}
bool FunctionCallGraphBuilder::visit(ModifierInvocation const& _modifierInvocation)
{
if (auto const* modifier = dynamic_cast<ModifierDefinition const*>(_modifierInvocation.name().annotation().referencedDeclaration))

View File

@ -72,6 +72,8 @@ private:
bool visit(EmitStatement const& _emitStatement) override;
bool visit(Identifier const& _identifier) override;
bool visit(MemberAccess const& _memberAccess) override;
bool visit(BinaryOperation const& _binaryOperation) override;
bool visit(UnaryOperation const& _unaryOperation) override;
bool visit(ModifierInvocation const& _modifierInvocation) override;
bool visit(NewExpression const& _newExpression) override;

View File

@ -178,6 +178,7 @@ struct ConstStateVarCircularReferenceChecker: public PostTypeChecker::Checker
bool visit(Identifier const& _identifier) override
{
// TODO add user defined operators?
if (m_currentConstVariable)
if (auto var = dynamic_cast<VariableDeclaration const*>(_identifier.annotation().referencedDeclaration))
if (var->isConstant())

View File

@ -405,6 +405,12 @@ void SyntaxChecker::endVisit(ContractDefinition const&)
bool SyntaxChecker::visit(UsingForDirective const& _usingFor)
{
if (!_usingFor.usesBraces())
solAssert(
_usingFor.functionsAndOperators().size() == 1 &&
!std::get<1>(_usingFor.functionsAndOperators().front())
);
if (!m_currentContractKind && !_usingFor.typeName())
m_errorReporter.syntaxError(
8118_error,

View File

@ -1728,10 +1728,40 @@ bool TypeChecker::visit(UnaryOperation const& _operation)
else
_operation.subExpression().accept(*this);
Type const* subExprType = type(_operation.subExpression());
TypeResult result = type(_operation.subExpression())->unaryOperatorResult(op);
if (!result)
// Check if the operator is built-in or user-defined.
FunctionDefinition const* userDefinedOperator = subExprType->userDefinedOperator(
_operation.getOperator(),
*currentDefinitionScope()
);
_operation.annotation().userDefinedFunction = userDefinedOperator;
FunctionType const* userDefinedFunctionType = nullptr;
if (userDefinedOperator)
userDefinedFunctionType = &dynamic_cast<FunctionType const&>(
userDefinedOperator->libraryFunction() ?
*userDefinedOperator->typeViaContractName() :
*userDefinedOperator->type()
);
TypeResult builtinResult = subExprType->unaryOperatorResult(op);
solAssert(!builtinResult || !userDefinedOperator);
if (userDefinedOperator)
{
string description = "Unary operator " + string(TokenTraits::toString(op)) + " cannot be applied to type " + subExprType->humanReadableName() + "." + (!result.message().empty() ? " " + result.message() : "");
solAssert(userDefinedFunctionType->parameterTypes().size() == 1);
solAssert(userDefinedFunctionType->returnParameterTypes().size() == 1);
solAssert(
*userDefinedFunctionType->parameterTypes().at(0) ==
*userDefinedFunctionType->returnParameterTypes().at(0)
);
_operation.annotation().type = userDefinedFunctionType->returnParameterTypes().at(0);
}
else if (builtinResult)
_operation.annotation().type = builtinResult;
else
{
string description = "Unary operator " + string(TokenTraits::toString(op)) + " cannot be applied to type " + subExprType->humanReadableName() + "." + (!builtinResult.message().empty() ? " " + builtinResult.message() : "");
if (modifying)
// Cannot just report the error, ignore the unary operator, and continue,
// because the sub-expression was already processed with requireLValue()
@ -1740,10 +1770,12 @@ bool TypeChecker::visit(UnaryOperation const& _operation)
m_errorReporter.typeError(4907_error, _operation.location(), description);
_operation.annotation().type = subExprType;
}
else
_operation.annotation().type = result.get();
_operation.annotation().isConstant = false;
_operation.annotation().isPure = !modifying && *_operation.subExpression().annotation().isPure;
_operation.annotation().isPure =
!modifying &&
*_operation.subExpression().annotation().isPure &&
(!userDefinedFunctionType || userDefinedFunctionType->isPure());
_operation.annotation().isLValue = false;
return false;
@ -1753,10 +1785,35 @@ void TypeChecker::endVisit(BinaryOperation const& _operation)
{
Type const* leftType = type(_operation.leftExpression());
Type const* rightType = type(_operation.rightExpression());
TypeResult result = leftType->binaryOperatorResult(_operation.getOperator(), rightType);
Type const* commonType = result.get();
if (!commonType)
{
_operation.annotation().isLValue = false;
_operation.annotation().isConstant = false;
// Check if the operator is built-in or user-defined.
FunctionDefinition const* userDefinedOperator = leftType->userDefinedOperator(
_operation.getOperator(),
*currentDefinitionScope()
);
_operation.annotation().userDefinedFunction = userDefinedOperator;
FunctionType const* userDefinedFunctionType = nullptr;
if (userDefinedOperator)
userDefinedFunctionType = &dynamic_cast<FunctionType const&>(
userDefinedOperator->libraryFunction() ?
*userDefinedOperator->typeViaContractName() :
*userDefinedOperator->type()
);
_operation.annotation().isPure =
*_operation.leftExpression().annotation().isPure &&
*_operation.rightExpression().annotation().isPure &&
(!userDefinedFunctionType || userDefinedFunctionType->isPure());
TypeResult builtinResult = leftType->binaryOperatorResult(_operation.getOperator(), rightType);
Type const* commonType = leftType;
// Either the operator is user-defined or built-in.
// TODO For enums, we have compare operators. Should we disallow overriding them?
solAssert(!userDefinedOperator || !builtinResult);
if (!builtinResult && !userDefinedOperator)
m_errorReporter.typeError(
2271_error,
_operation.location(),
@ -1766,22 +1823,33 @@ void TypeChecker::endVisit(BinaryOperation const& _operation)
leftType->humanReadableName() +
" and " +
rightType->humanReadableName() + "." +
(!result.message().empty() ? " " + result.message() : "")
(!builtinResult.message().empty() ? " " + builtinResult.message() : "")
);
commonType = leftType;
if (builtinResult)
commonType = builtinResult.get();
else if (userDefinedOperator)
{
solAssert(
userDefinedFunctionType->parameterTypes().size() == 2 &&
*userDefinedFunctionType->parameterTypes().at(0) ==
*userDefinedFunctionType->parameterTypes().at(1)
);
commonType = userDefinedFunctionType->parameterTypes().at(0);
}
_operation.annotation().commonType = commonType;
_operation.annotation().type =
TokenTraits::isCompareOp(_operation.getOperator()) ?
TypeProvider::boolean() :
commonType;
_operation.annotation().isPure =
*_operation.leftExpression().annotation().isPure &&
*_operation.rightExpression().annotation().isPure;
_operation.annotation().isLValue = false;
_operation.annotation().isConstant = false;
if (_operation.getOperator() == Token::Exp || _operation.getOperator() == Token::SHL)
if (userDefinedOperator)
solAssert(
userDefinedFunctionType->returnParameterTypes().size() == 1 &&
*userDefinedFunctionType->returnParameterTypes().front() == *_operation.annotation().type
);
else if (builtinResult && (_operation.getOperator() == Token::Exp || _operation.getOperator() == Token::SHL))
{
string operation = _operation.getOperator() == Token::Exp ? "exponentiation" : "shift";
if (
@ -3784,7 +3852,7 @@ void TypeChecker::endVisit(UsingForDirective const& _usingFor)
);
solAssert(normalizedType);
for (ASTPointer<IdentifierPath> const& path: _usingFor.functionsOrLibrary())
for (auto const& [path, operator_]: _usingFor.functionsAndOperators())
{
solAssert(path->annotation().referencedDeclaration);
FunctionDefinition const& functionDefinition =
@ -3820,6 +3888,73 @@ void TypeChecker::endVisit(UsingForDirective const& _usingFor)
": " + result.message()
)
);
else if (operator_)
{
if (!_usingFor.typeName()->annotation().type->typeDefinition())
{
m_errorReporter.typeError(
5332_error,
path->location(),
"Operators can only be implemented for user-defined types and not for contracts."
);
continue;
}
// "-" can be used as unary and binary operator.
bool isUnaryNegation = (
operator_ == Token::Sub &&
functionType->parameterTypesIncludingSelf().size() == 1
);
if (
(
(TokenTraits::isBinaryOp(*operator_) && !isUnaryNegation) ||
TokenTraits::isCompareOp(*operator_)
) &&
(
functionType->parameterTypesIncludingSelf().size() != 2 ||
*functionType->parameterTypesIncludingSelf().at(0) !=
*functionType->parameterTypesIncludingSelf().at(1)
)
)
m_errorReporter.typeError(
1884_error,
path->location(),
"The function \"" + joinHumanReadable(path->path(), ".") + "\" "+
"needs to have two parameters of equal type to be used for the operator " +
TokenTraits::friendlyName(*operator_) +
"."
);
if (
(isUnaryNegation || (TokenTraits::isUnaryOp(*operator_) && *operator_ != Token::Add)) &&
functionType->parameterTypesIncludingSelf().size() != 1
)
m_errorReporter.typeError(
8112_error,
path->location(),
"The function \"" + joinHumanReadable(path->path(), ".") + "\" "+
"needs to have exactly one parameter to be used for the operator " +
TokenTraits::friendlyName(*operator_) +
"."
);
Type const* expectedType =
TokenTraits::isCompareOp(*operator_) ?
dynamic_cast<Type const*>(TypeProvider::boolean()) :
functionType->parameterTypesIncludingSelf().at(0);
if (
functionType->returnParameterTypes().size() != 1 ||
*functionType->returnParameterTypes().front() != *expectedType
)
m_errorReporter.typeError(
7743_error,
path->location(),
"The function \"" + joinHumanReadable(path->path(), ".") + "\" "+
"needs to return exactly one value of type " +
expectedType->toString(true) +
" to be used for the operator " +
TokenTraits::friendlyName(*operator_) +
"."
);
}
}
}

View File

@ -323,6 +323,8 @@ ViewPureChecker::MutabilityAndLocation const& ViewPureChecker::modifierMutabilit
return m_inferredMutability.at(&_modifier);
}
// TODO needs to visit binaryoperation as well
void ViewPureChecker::endVisit(FunctionCall const& _functionCall)
{
if (*_functionCall.annotation().kind != FunctionCallKind::FunctionCall)

View File

@ -895,6 +895,11 @@ MemberAccessAnnotation& MemberAccess::annotation() const
return initAnnotation<MemberAccessAnnotation>();
}
OperationAnnotation& UnaryOperation::annotation() const
{
return initAnnotation<OperationAnnotation>();
}
BinaryOperationAnnotation& BinaryOperation::annotation() const
{
return initAnnotation<BinaryOperationAnnotation>();

View File

@ -38,6 +38,7 @@
#include <json/json.h>
#include <range/v3/view/subrange.hpp>
#include <range/v3/view/zip.hpp>
#include <range/v3/view/map.hpp>
#include <memory>
@ -664,16 +665,19 @@ public:
int64_t _id,
SourceLocation const& _location,
std::vector<ASTPointer<IdentifierPath>> _functions,
std::vector<std::optional<Token>> _operators,
bool _usesBraces,
ASTPointer<TypeName> _typeName,
bool _global
):
ASTNode(_id, _location),
m_functions(_functions),
m_functions(std::move(_functions)),
m_operators(std::move(_operators)),
m_usesBraces(_usesBraces),
m_typeName(std::move(_typeName)),
m_global{_global}
{
solAssert(m_functions.size() == m_operators.size());
}
void accept(ASTVisitor& _visitor) override;
@ -684,12 +688,15 @@ public:
/// @returns a list of functions or the single library.
std::vector<ASTPointer<IdentifierPath>> const& functionsOrLibrary() const { return m_functions; }
auto functionsAndOperators() const { return ranges::zip_view(m_functions, m_operators); }
bool usesBraces() const { return m_usesBraces; }
bool global() const { return m_global; }
private:
/// Either the single library or a list of functions.
std::vector<ASTPointer<IdentifierPath>> m_functions;
/// Operators, the functions are applied to.
std::vector<std::optional<Token>> m_operators;
bool m_usesBraces;
ASTPointer<TypeName> m_typeName;
bool m_global = false;
@ -2055,6 +2062,8 @@ public:
bool isPrefixOperation() const { return m_isPrefix; }
Expression const& subExpression() const { return *m_subExpression; }
OperationAnnotation& annotation() const override;
private:
Token m_operator;
ASTPointer<Expression> m_subExpression;

View File

@ -312,7 +312,13 @@ struct MemberAccessAnnotation: ExpressionAnnotation
util::SetOnce<VirtualLookup> requiredLookup;
};
struct BinaryOperationAnnotation: ExpressionAnnotation
struct OperationAnnotation: ExpressionAnnotation
{
// TODO should this be more like "referencedDeclaration"?
FunctionDefinition const* userDefinedFunction = nullptr;
};
struct BinaryOperationAnnotation: OperationAnnotation
{
/// The common type that is used for the operation, not necessarily the result type (which
/// e.g. for comparisons is bool).

View File

@ -329,14 +329,17 @@ bool ASTJsonExporter::visit(UsingForDirective const& _node)
vector<pair<string, Json::Value>> attributes = {
make_pair("typeName", _node.typeName() ? toJson(*_node.typeName()) : Json::nullValue)
};
if (_node.usesBraces())
{
Json::Value functionList;
for (auto const& function: _node.functionsOrLibrary())
for (auto&& [function, op]: _node.functionsAndOperators())
{
Json::Value functionNode;
functionNode["function"] = toJson(*function);
functionList.append(std::move(functionNode));
if (op)
functionNode["operator"] = string(TokenTraits::toString(*op));
functionList.append(move(functionNode));
}
attributes.emplace_back("functionList", std::move(functionList));
}
@ -825,6 +828,8 @@ bool ASTJsonExporter::visit(UnaryOperation const& _node)
make_pair("operator", TokenTraits::toString(_node.getOperator())),
make_pair("subExpression", toJson(_node.subExpression()))
};
if (FunctionDefinition const* function = _node.annotation().userDefinedFunction)
attributes.emplace_back("function", nodeId(*function));
appendExpressionAttributes(attributes, _node.annotation());
setJsonNode(_node, "UnaryOperation", std::move(attributes));
return false;
@ -838,6 +843,8 @@ bool ASTJsonExporter::visit(BinaryOperation const& _node)
make_pair("rightExpression", toJson(_node.rightExpression())),
make_pair("commonType", typePointerToJson(_node.annotation().commonType)),
};
if (FunctionDefinition const* function = _node.annotation().userDefinedFunction)
attributes.emplace_back("function", nodeId(*function));
appendExpressionAttributes(attributes, _node.annotation());
setJsonNode(_node, "BinaryOperation", std::move(attributes));
return false;

View File

@ -383,15 +383,29 @@ ASTPointer<InheritanceSpecifier> ASTJsonImporter::createInheritanceSpecifier(Jso
ASTPointer<UsingForDirective> ASTJsonImporter::createUsingForDirective(Json::Value const& _node)
{
vector<ASTPointer<IdentifierPath>> functions;
vector<optional<Token>> operators;
if (_node.isMember("libraryName"))
{
solAssert(!_node["libraryName"].isArray());
solAssert(!_node["libraryName"]["operator"]);
functions.emplace_back(createIdentifierPath(_node["libraryName"]));
operators.emplace_back();
}
else if (_node.isMember("functionList"))
for (Json::Value const& function: _node["functionList"])
{
functions.emplace_back(createIdentifierPath(function["function"]));
operators.emplace_back(
function.isMember("operator") ?
optional<Token>{scanSingleToken(function["operator"])} :
nullopt
);
}
return createASTNode<UsingForDirective>(
_node,
std::move(functions),
move(operators),
!_node.isMember("libraryName"),
_node["typeName"].isNull() ? nullptr : convertJsonToASTNode<TypeName>(_node["typeName"]),
memberAsBool(_node, "global")

View File

@ -194,7 +194,7 @@ void UsingForDirective::accept(ASTVisitor& _visitor)
{
if (_visitor.visit(*this))
{
listAccept(functionsOrLibrary(), _visitor);
listAccept(m_functions, _visitor);
if (m_typeName)
m_typeName->accept(_visitor);
}
@ -205,7 +205,7 @@ void UsingForDirective::accept(ASTConstVisitor& _visitor) const
{
if (_visitor.visit(*this))
{
listAccept(functionsOrLibrary(), _visitor);
listAccept(m_functions, _visitor);
if (m_typeName)
m_typeName->accept(_visitor);
}

View File

@ -48,6 +48,7 @@
#include <range/v3/view/reverse.hpp>
#include <range/v3/view/tail.hpp>
#include <range/v3/view/transform.hpp>
#include <range/v3/view/filter.hpp>
#include <limits>
#include <unordered_set>
@ -337,7 +338,10 @@ Type const* Type::fullEncodingType(bool _inLibraryCall, bool _encoderV2, bool) c
return encodingType;
}
MemberList::MemberMap Type::boundFunctions(Type const& _type, ASTNode const& _scope)
namespace
{
vector<UsingForDirective const*> usingForDirectivesForType(Type const& _type, ASTNode const& _scope)
{
vector<UsingForDirective const*> usingForDirectives;
SourceUnit const* sourceUnit = dynamic_cast<SourceUnit const*>(&_scope);
@ -362,6 +366,57 @@ MemberList::MemberMap Type::boundFunctions(Type const& _type, ASTNode const& _sc
if (auto refType = dynamic_cast<ReferenceType const*>(&_type))
typeLocation = refType->location();
return usingForDirectives | ranges::views::filter([&](UsingForDirective const* _directive) -> bool {
// Convert both types to pointers for comparison to see if the `using for`
// directive applies.
// Further down, we check more detailed for each function if `_type` is
// convertible to the function parameter type.
return
!_directive->typeName() ||
*TypeProvider::withLocationIfReference(typeLocation, &_type, true) ==
*TypeProvider::withLocationIfReference(
typeLocation,
_directive->typeName()->annotation().type,
true
);
}) | ranges::to<vector<UsingForDirective const*>>;
}
}
FunctionDefinition const* Type::userDefinedOperator(Token _token, ASTNode const& _scope) const
{
// Check if it is a user-defined type.
if (!typeDefinition())
return nullptr;
set<FunctionDefinition const*> seenFunctions;
for (UsingForDirective const* ufd: usingForDirectivesForType(*this, _scope))
for (auto const& [pathPointer, operator_]: ufd->functionsAndOperators())
{
if (operator_ != _token)
continue;
FunctionDefinition const& function = dynamic_cast<FunctionDefinition const&>(
*pathPointer->annotation().referencedDeclaration
);
FunctionType const* functionType = dynamic_cast<FunctionType const*>(
function.libraryFunction() ? function.typeViaContractName() : function.type()
);
solAssert(functionType && !functionType->parameterTypes().empty());
// TODO does this work (data location)?
solAssert(isImplicitlyConvertibleTo(*functionType->parameterTypes().front()));
seenFunctions.insert(&function);
}
// TODO proper error handling.
if (seenFunctions.size() == 1)
return *seenFunctions.begin();
else
return nullptr;
}
MemberList::MemberMap Type::boundFunctions(Type const& _type, ASTNode const& _scope)
{
MemberList::MemberMap members;
set<pair<string, Declaration const*>> seenFunctions;
@ -381,25 +436,12 @@ MemberList::MemberMap Type::boundFunctions(Type const& _type, ASTNode const& _sc
members.emplace_back(&_function, asBoundFunction, *_name);
};
for (UsingForDirective const* ufd: usingForDirectives)
for (UsingForDirective const* ufd: usingForDirectivesForType(_type, _scope))
for (auto const& [pathPointer, operator_]: ufd->functionsAndOperators())
{
// Convert both types to pointers for comparison to see if the `using for`
// directive applies.
// Further down, we check more detailed for each function if `_type` is
// convertible to the function parameter type.
if (
ufd->typeName() &&
*TypeProvider::withLocationIfReference(typeLocation, &_type, true) !=
*TypeProvider::withLocationIfReference(
typeLocation,
ufd->typeName()->annotation().type,
true
)
)
if (operator_)
continue;
for (auto const& pathPointer: ufd->functionsOrLibrary())
{
solAssert(pathPointer);
Declaration const* declaration = pathPointer->annotation().referencedDeclaration;
solAssert(declaration);
@ -420,7 +462,6 @@ MemberList::MemberMap Type::boundFunctions(Type const& _type, ASTNode const& _sc
pathPointer->path().back()
);
}
}
return members;
}

View File

@ -377,6 +377,8 @@ public:
/// Clears all internally cached values (if any).
virtual void clearCache() const;
FunctionDefinition const* userDefinedOperator(Token _token, ASTNode const& _scope) const;
private:
/// @returns a member list containing all members added to this type by `using for` directives.
static MemberList::MemberMap boundFunctions(Type const& _type, ASTNode const& _scope);

View File

@ -502,6 +502,46 @@ bool ExpressionCompiler::visit(BinaryOperation const& _binaryOperation)
CompilerContext::LocationSetter locationSetter(m_context, _binaryOperation);
Expression const& leftExpression = _binaryOperation.leftExpression();
Expression const& rightExpression = _binaryOperation.rightExpression();
if (_binaryOperation.annotation().userDefinedFunction)
{
// TODO extract from function call
FunctionDefinition const& function = *_binaryOperation.annotation().userDefinedFunction;
FunctionType const* functionType = dynamic_cast<FunctionType const*>(
function.libraryFunction() ? function.typeViaContractName() : function.type()
);
solAssert(functionType);
functionType = dynamic_cast<FunctionType const&>(*functionType).asBoundFunction();
solAssert(functionType);
evmasm::AssemblyItem returnLabel = m_context.pushNewTag();
acceptAndConvert(leftExpression, *functionType->selfType());
acceptAndConvert(rightExpression, *functionType->parameterTypes().at(0));
utils().pushCombinedFunctionEntryLabel(
function.resolveVirtual(m_context.mostDerivedContract()),
false
);
unsigned parameterSize =
CompilerUtils::sizeOnStack(functionType->parameterTypes()) +
functionType->selfType()->sizeOnStack();
if (m_context.runtimeContext())
// We have a runtime context, so we need the creation part.
utils().rightShiftNumberOnStack(32);
else
// Extract the runtime part.
m_context << ((u256(1) << 32) - 1) << Instruction::AND;
m_context.appendJump(evmasm::AssemblyItem::JumpType::IntoFunction);
m_context << returnLabel;
unsigned returnParametersSize = CompilerUtils::sizeOnStack(functionType->returnParameterTypes());
// callee adds return parameters, but removes arguments and return label
m_context.adjustStackOffset(static_cast<int>(returnParametersSize - parameterSize) - 1);
return false;
}
solAssert(!!_binaryOperation.annotation().commonType, "");
Type const* commonType = _binaryOperation.annotation().commonType;
Token const c_op = _binaryOperation.getOperator();

View File

@ -775,10 +775,42 @@ bool IRGeneratorForStatements::visit(BinaryOperation const& _binOp)
{
setLocation(_binOp);
solAssert(!!_binOp.annotation().commonType);
// TOOD make this nicer
if (_binOp.annotation().userDefinedFunction)
{
_binOp.leftExpression().accept(*this);
_binOp.rightExpression().accept(*this);
setLocation(_binOp);
// TODO extract from function call
FunctionDefinition const& function = *_binOp.annotation().userDefinedFunction;
FunctionType const* functionType = dynamic_cast<FunctionType const*>(
function.libraryFunction() ? function.typeViaContractName() : function.type()
);
solAssert(functionType);
functionType = dynamic_cast<FunctionType const&>(*functionType).asBoundFunction();
solAssert(functionType);
// TODO virtual?
string left = expressionAsType(_binOp.leftExpression(), *functionType->selfType());
string right = expressionAsType(_binOp.rightExpression(), *functionType->parameterTypes().at(0));
solAssert(!left.empty() && !right.empty());
solAssert(function.isImplemented(), "");
define(_binOp) <<
m_context.enqueueFunctionForCodeGeneration(function) <<
("(" + left + ", " + right + ")\n");
return false;
}
solAssert(!!_binOp.annotation().commonType, "");
Type const* commonType = _binOp.annotation().commonType;
langutil::Token op = _binOp.getOperator();
if (op == Token::And || op == Token::Or)
{
// This can short-circuit!

View File

@ -968,6 +968,7 @@ ASTPointer<UsingForDirective> Parser::parseUsingDirective()
expectToken(Token::Using);
vector<ASTPointer<IdentifierPath>> functions;
vector<optional<Token>> operators;
bool const usesBraces = m_scanner->currentToken() == Token::LBrace;
if (usesBraces)
{
@ -975,12 +976,38 @@ ASTPointer<UsingForDirective> Parser::parseUsingDirective()
{
advance();
functions.emplace_back(parseIdentifierPath());
if (m_scanner->currentToken() == Token::As)
{
advance();
Token operator_ = m_scanner->currentToken();
vector<Token> overridable = {
// Potential future additions: <<, >>, **, !
Token::BitOr, Token::BitAnd, Token::BitXor,
Token::Add, Token::Sub, Token::Mul, Token::Div, Token::Mod,
Token::Equal, Token::NotEqual,
Token::LessThan, Token::GreaterThan, Token::LessThanOrEqual, Token::GreaterThanOrEqual,
Token::BitNot
};
if (!util::contains(overridable, operator_))
parserError(
1885_error,
("The operator " + string{TokenTraits::toString(operator_)} + " cannot be user-implemented. This is only possible for the folloing operators: ") +
util::joinHumanReadable(overridable | ranges::views::transform([](Token _t) { return string{TokenTraits::toString(_t)}; }))
);
operators.emplace_back(operator_);
advance();
}
else
operators.emplace_back();
}
while (m_scanner->currentToken() == Token::Comma);
expectToken(Token::RBrace);
}
else
{
functions.emplace_back(parseIdentifierPath());
operators.emplace_back();
}
ASTPointer<TypeName> typeName;
expectToken(Token::For);
@ -996,7 +1023,7 @@ ASTPointer<UsingForDirective> Parser::parseUsingDirective()
}
nodeFactory.markEndPosition();
expectToken(Token::Semicolon);
return nodeFactory.createNode<UsingForDirective>(std::move(functions), usesBraces, typeName, global);
return nodeFactory.createNode<UsingForDirective>(std::move(functions), std::move(operators), usesBraces, typeName, global);
}
ASTPointer<ModifierInvocation> Parser::parseModifierInvocation()

View File

@ -0,0 +1,16 @@
type MyInt is int;
using {add as +} for MyInt;
function add(MyInt, MyInt) pure returns (bool) {
return true;
}
contract C {
function f() public pure returns (bool t) {
t = MyInt.wrap(2) + MyInt.wrap(7);
}
}
// ====
// compileViaYul: also
// ----
// f() -> true

View File

@ -0,0 +1,75 @@
type Int is int128;
using {
bitor as |, bitand as &, bitxor as ^, bitnot as ~,
add as +, sub as -, unsub as -, mul as *, div as /, mod as %,
eq as ==, noteq as !=, lt as <, gt as >, leq as <=, geq as >=
} for Int;
function uw(Int x) pure returns (int128) {
return Int.unwrap(x);
}
function w(int128 x) pure returns (Int) {
return Int.wrap(x);
}
function bitor(Int, Int) pure returns (Int) {
return w(1);
}
function bitand(Int, Int) pure returns (Int) {
return w(2);
}
function bitxor(Int, Int) pure returns (Int) {
return w(3);
}
function bitnot(Int) pure returns (Int) {
return w(4);
}
function add(Int, Int) pure returns (Int) {
return w(5);
}
function sub(Int, Int) pure returns (Int) {
return w(6);
}
function unsub(Int) pure returns (Int) {
return w(7);
}
function mul(Int, Int) pure returns (Int) {
return w(8);
}
function div(Int, Int) pure returns (Int) {
return w(9);
}
function mod(Int, Int) pure returns (Int) {
return w(10);
}
function eq(Int x, Int) pure returns (bool) {
return uw(x) == 1;
}
function noteq(Int x, Int) pure returns (bool) {
return uw(x) == 2;
}
function lt(Int x, Int) pure returns (bool) {
return uw(x) == 3;
}
function gt(Int x, Int) pure returns (bool) {
return uw(x) == 4;
}
function leq(Int x, Int) pure returns (bool) {
return uw(x) == 5;
}
function geq(Int x, Int) pure returns (bool) {
return uw(x) == 6;
}
// TODO test that side-effects are executed properly.
contract C {
function f1() public pure returns (Int) {
require(w(1) | w(2) == w(1));
require(!(w(1) | w(2) == w(2)));
return w(1) | w(2);
}
// TODO all the other operators
}
// ====
// compileViaYul: also
// ----
// f1()

View File

@ -0,0 +1,24 @@
type Fixed is int128;
using {add as +, mul as *} for Fixed;
int constant MULTIPLIER = 10**18;
function add(Fixed a, Fixed b) pure returns (Fixed) {
return Fixed.wrap(Fixed.unwrap(a) + Fixed.unwrap(b));
}
function mul(Fixed a, Fixed b) pure returns (Fixed) {
int intermediate = (int(Fixed.unwrap(a)) * int(Fixed.unwrap(b))) / MULTIPLIER;
if (int128(intermediate) != intermediate) { revert("Overflow"); }
return Fixed.wrap(int128(intermediate));
}
contract C {
function applyInterest(Fixed value, Fixed percentage) public pure returns (Fixed result) {
return value + value * percentage;
}
}
// ====
// compileViaYul: also
// ----
// applyInterest(int128,int128): 500000000000000000000, 100000000000000000 -> 550000000000000000000

View File

@ -0,0 +1,10 @@
type Type is uint;
using {f as +} for Type;
function f(Type, Type) pure returns (Type) {}
Type constant t = Type.wrap(1);
Type constant u = v + t;
Type constant v = u + t;
// ----
// TypeError 8349: (141-146): Initial value for constant variable has to be compile-time constant.
// TypeError 8349: (166-171): Initial value for constant variable has to be compile-time constant.

View File

@ -0,0 +1,4 @@
using {f as +} for uint;
function f(uint, uint) pure returns (uint) {}
// ----
// TypeError 5332: (7-8): Operators can only be implemented for user-defined types and not for contracts.