Reimplement constant evaluator.

This commit is contained in:
Christian Parpart 2020-11-16 12:19:52 +01:00 committed by Leonardo Alt
parent 49bde69afa
commit c5d172c058
15 changed files with 245 additions and 96 deletions

View File

@ -32,6 +32,8 @@ using namespace solidity;
using namespace solidity::frontend; using namespace solidity::frontend;
using namespace solidity::langutil; using namespace solidity::langutil;
using TypedRational = ConstantEvaluator::TypedRational;
namespace namespace
{ {
@ -227,85 +229,162 @@ optional<rational> ConstantEvaluator::evaluateUnaryOperator(Token _operator, rat
} }
} }
optional<TypedRational> convertType(rational const& _value, Type const& _type)
{
if (_type.category() == Type::Category::RationalNumber)
return TypedRational{TypeProvider::rationalNumber(_value), _value};
else if (auto const* integerType = dynamic_cast<IntegerType const*>(&_type))
{
if (_value > integerType->maxValue() || _value < integerType->minValue())
return nullopt;
else
return TypedRational{&_type, _value.numerator() / _value.denominator()};
}
else
return nullopt;
}
optional<TypedRational> convertType(optional<TypedRational> const& _value, Type const& _type)
{
return _value ? convertType(_value->value, _type) : nullopt;
}
optional<TypedRational> constantToTypedValue(Type const& _type)
{
if (_type.category() == Type::Category::RationalNumber)
return TypedRational{&_type, dynamic_cast<RationalNumberType const&>(_type).value()};
else
return nullopt;
}
optional<TypedRational> ConstantEvaluator::evaluate(
langutil::ErrorReporter& _errorReporter,
Expression const& _expr
)
{
return ConstantEvaluator{_errorReporter}.evaluate(_expr);
}
optional<TypedRational> ConstantEvaluator::evaluate(ASTNode const& _node)
{
if (!m_values.count(&_node))
{
if (auto const* varDecl = dynamic_cast<VariableDeclaration const*>(&_node))
{
solAssert(varDecl->isConstant(), "");
if (!varDecl->value())
m_values[&_node] = nullopt;
else
{
m_depth++;
if (m_depth > 32)
m_errorReporter.fatalTypeError(
5210_error,
varDecl->location(),
"Cyclic constant definition (or maximum recursion depth exhausted)."
);
m_values[&_node] = convertType(evaluate(*varDecl->value()), *varDecl->type());
m_depth--;
}
}
else if (auto const* expression = dynamic_cast<Expression const*>(&_node))
{
expression->accept(*this);
if (!m_values.count(&_node))
m_values[&_node] = nullopt;
}
}
return m_values.at(&_node);
}
void ConstantEvaluator::endVisit(UnaryOperation const& _operation) void ConstantEvaluator::endVisit(UnaryOperation const& _operation)
{ {
auto sub = type(_operation.subExpression()); optional<TypedRational> value = evaluate(_operation.subExpression());
if (sub) if (!value)
setType(_operation, sub->unaryOperatorResult(_operation.getOperator())); return;
TypePointer resultType = value->type->unaryOperatorResult(_operation.getOperator());
if (!resultType)
return;
value = convertType(value, *resultType);
if (!value)
return;
if (optional<rational> result = evaluateUnaryOperator(_operation.getOperator(), value->value))
{
optional<TypedRational> convertedValue = convertType(*result, *resultType);
if (!convertedValue)
m_errorReporter.fatalTypeError(
3667_error,
_operation.location(),
"Arithmetic error when computing constant value."
);
m_values[&_operation] = convertedValue;
}
} }
void ConstantEvaluator::endVisit(BinaryOperation const& _operation) void ConstantEvaluator::endVisit(BinaryOperation const& _operation)
{ {
auto left = type(_operation.leftExpression()); optional<TypedRational> left = evaluate(_operation.leftExpression());
auto right = type(_operation.rightExpression()); optional<TypedRational> right = evaluate(_operation.rightExpression());
if (left && right) if (!left || !right)
return;
// If this is implemented in the future: Comparison operators have a "binaryOperatorResult"
// that is non-bool, but the result has to be bool.
if (TokenTraits::isCompareOp(_operation.getOperator()))
return;
TypePointer resultType = left->type->binaryOperatorResult(_operation.getOperator(), right->type);
if (!resultType)
{ {
TypePointer commonType = left->binaryOperatorResult(_operation.getOperator(), right);
if (!commonType)
m_errorReporter.fatalTypeError( m_errorReporter.fatalTypeError(
6020_error, 6020_error,
_operation.location(), _operation.location(),
"Operator " + "Operator " +
string(TokenTraits::toString(_operation.getOperator())) + string(TokenTraits::toString(_operation.getOperator())) +
" not compatible with types " + " not compatible with types " +
left->toString() + left->type->toString() +
" and " + " and " +
right->toString() right->type->toString()
); );
setType( return;
_operation, }
TokenTraits::isCompareOp(_operation.getOperator()) ?
TypeProvider::boolean() : left = convertType(left, *resultType);
commonType right = convertType(right, *resultType);
if (!left || !right)
return;
if (optional<rational> value = evaluateBinaryOperator(_operation.getOperator(), left->value, right->value))
{
optional<TypedRational> convertedValue = convertType(*value, *resultType);
if (!convertedValue)
m_errorReporter.fatalTypeError(
2643_error,
_operation.location(),
"Arithmetic error when computing constant value."
); );
m_values[&_operation] = convertedValue;
} }
} }
void ConstantEvaluator::endVisit(Literal const& _literal) void ConstantEvaluator::endVisit(Literal const& _literal)
{ {
setType(_literal, TypeProvider::forLiteral(_literal)); if (Type const* literalType = TypeProvider::forLiteral(_literal))
m_values[&_literal] = constantToTypedValue(*literalType);
} }
void ConstantEvaluator::endVisit(Identifier const& _identifier) void ConstantEvaluator::endVisit(Identifier const& _identifier)
{ {
VariableDeclaration const* variableDeclaration = dynamic_cast<VariableDeclaration const*>(_identifier.annotation().referencedDeclaration); VariableDeclaration const* variableDeclaration = dynamic_cast<VariableDeclaration const*>(_identifier.annotation().referencedDeclaration);
if (!variableDeclaration) if (variableDeclaration && variableDeclaration->isConstant())
return; m_values[&_identifier] = evaluate(*variableDeclaration);
if (!variableDeclaration->isConstant())
return;
ASTPointer<Expression> const& value = variableDeclaration->value();
if (!value)
return;
else if (!m_types->count(value.get()))
{
if (m_depth > 32)
m_errorReporter.fatalTypeError(5210_error, _identifier.location(), "Cyclic constant definition (or maximum recursion depth exhausted).");
ConstantEvaluator(m_errorReporter, m_depth + 1, m_types).evaluate(*value);
}
setType(_identifier, type(*value));
} }
void ConstantEvaluator::endVisit(TupleExpression const& _tuple) void ConstantEvaluator::endVisit(TupleExpression const& _tuple)
{ {
if (!_tuple.isInlineArray() && _tuple.components().size() == 1) if (!_tuple.isInlineArray() && _tuple.components().size() == 1)
setType(_tuple, type(*_tuple.components().front())); m_values[&_tuple] = evaluate(*_tuple.components().front());
}
void ConstantEvaluator::setType(ASTNode const& _node, TypePointer const& _type)
{
if (_type && _type->category() == Type::Category::RationalNumber)
(*m_types)[&_node] = _type;
}
TypePointer ConstantEvaluator::type(ASTNode const& _node)
{
return (*m_types)[&_node];
}
TypePointer ConstantEvaluator::evaluate(Expression const& _expr)
{
_expr.accept(*this);
return type(_expr);
} }

View File

@ -39,22 +39,23 @@ class TypeChecker;
/** /**
* Small drop-in replacement for TypeChecker to evaluate simple expressions of integer constants. * Small drop-in replacement for TypeChecker to evaluate simple expressions of integer constants.
*
* Note: This always use "checked arithmetic" in the sense that any over- or underflow
* results in "unknown" value.
*/ */
class ConstantEvaluator: private ASTConstVisitor class ConstantEvaluator: private ASTConstVisitor
{ {
public: public:
ConstantEvaluator( struct TypedRational
langutil::ErrorReporter& _errorReporter,
size_t _newDepth = 0,
std::shared_ptr<std::map<ASTNode const*, TypePointer>> _types = std::make_shared<std::map<ASTNode const*, TypePointer>>()
):
m_errorReporter(_errorReporter),
m_depth(_newDepth),
m_types(std::move(_types))
{ {
} TypePointer type;
rational value;
};
TypePointer evaluate(Expression const& _expr); static std::optional<TypedRational> evaluate(
langutil::ErrorReporter& _errorReporter,
Expression const& _expr
);
/// Performs arbitrary-precision evaluation of a binary operator. Returns nullopt on cases like /// Performs arbitrary-precision evaluation of a binary operator. Returns nullopt on cases like
/// division by zero or e.g. bit operators applied to fractional values. /// division by zero or e.g. bit operators applied to fractional values.
@ -65,19 +66,21 @@ public:
static std::optional<rational> evaluateUnaryOperator(Token _operator, rational const& _input); static std::optional<rational> evaluateUnaryOperator(Token _operator, rational const& _input);
private: private:
explicit ConstantEvaluator(langutil::ErrorReporter& _errorReporter): m_errorReporter(_errorReporter) {}
std::optional<TypedRational> evaluate(ASTNode const& _node);
void endVisit(BinaryOperation const& _operation) override; void endVisit(BinaryOperation const& _operation) override;
void endVisit(UnaryOperation const& _operation) override; void endVisit(UnaryOperation const& _operation) override;
void endVisit(Literal const& _literal) override; void endVisit(Literal const& _literal) override;
void endVisit(Identifier const& _identifier) override; void endVisit(Identifier const& _identifier) override;
void endVisit(TupleExpression const& _tuple) override; void endVisit(TupleExpression const& _tuple) override;
void setType(ASTNode const& _node, TypePointer const& _type);
TypePointer type(ASTNode const& _node);
langutil::ErrorReporter& m_errorReporter; langutil::ErrorReporter& m_errorReporter;
/// Current recursion depth. /// Current recursion depth.
size_t m_depth = 0; size_t m_depth = 0;
std::shared_ptr<std::map<ASTNode const*, TypePointer>> m_types; /// Values of sub-expressions and variable declarations.
std::map<ASTNode const*, std::optional<TypedRational>> m_values;
}; };
} }

View File

@ -265,26 +265,30 @@ void DeclarationTypeChecker::endVisit(ArrayTypeName const& _typeName)
solAssert(baseType->storageBytes() != 0, "Illegal base type of storage size zero for array."); solAssert(baseType->storageBytes() != 0, "Illegal base type of storage size zero for array.");
if (Expression const* length = _typeName.length()) if (Expression const* length = _typeName.length())
{ {
TypePointer& lengthTypeGeneric = length->annotation().type; optional<rational> lengthValue;
if (!lengthTypeGeneric) if (length->annotation().type && length->annotation().type->category() == Type::Category::RationalNumber)
lengthTypeGeneric = ConstantEvaluator(m_errorReporter).evaluate(*length); lengthValue = dynamic_cast<RationalNumberType const&>(*length->annotation().type).value();
RationalNumberType const* lengthType = dynamic_cast<RationalNumberType const*>(lengthTypeGeneric); else if (optional<ConstantEvaluator::TypedRational> value = ConstantEvaluator::evaluate(m_errorReporter, *length))
u256 lengthValue = 0; lengthValue = value->value;
if (!lengthType || !lengthType->mobileType())
if (!lengthValue || lengthValue > TypeProvider::uint256()->max())
m_errorReporter.typeError( m_errorReporter.typeError(
5462_error, 5462_error,
length->location(), length->location(),
"Invalid array length, expected integer literal or constant expression." "Invalid array length, expected integer literal or constant expression."
); );
else if (lengthType->isZero()) else if (*lengthValue == 0)
m_errorReporter.typeError(1406_error, length->location(), "Array with zero length specified."); m_errorReporter.typeError(1406_error, length->location(), "Array with zero length specified.");
else if (lengthType->isFractional()) else if (lengthValue->denominator() != 1)
m_errorReporter.typeError(3208_error, length->location(), "Array with fractional length specified."); m_errorReporter.typeError(3208_error, length->location(), "Array with fractional length specified.");
else if (lengthType->isNegative()) else if (*lengthValue < 0)
m_errorReporter.typeError(3658_error, length->location(), "Array with negative length specified."); m_errorReporter.typeError(3658_error, length->location(), "Array with negative length specified.");
else
lengthValue = lengthType->literalValue(nullptr); _typeName.annotation().type = TypeProvider::array(
_typeName.annotation().type = TypeProvider::array(DataLocation::Storage, baseType, lengthValue); DataLocation::Storage,
baseType,
lengthValue ? u256(lengthValue->numerator()) : u256(0)
);
} }
else else
_typeName.annotation().type = TypeProvider::array(DataLocation::Storage, baseType); _typeName.annotation().type = TypeProvider::array(DataLocation::Storage, baseType);

View File

@ -290,10 +290,8 @@ bool StaticAnalyzer::visit(BinaryOperation const& _operation)
*_operation.rightExpression().annotation().isPure && *_operation.rightExpression().annotation().isPure &&
(_operation.getOperator() == Token::Div || _operation.getOperator() == Token::Mod) (_operation.getOperator() == Token::Div || _operation.getOperator() == Token::Mod)
) )
if (auto rhs = dynamic_cast<RationalNumberType const*>( if (auto rhs = ConstantEvaluator::evaluate(m_errorReporter, _operation.rightExpression()))
ConstantEvaluator(m_errorReporter).evaluate(_operation.rightExpression()) if (rhs->value == 0)
))
if (rhs->isZero())
m_errorReporter.typeError( m_errorReporter.typeError(
1211_error, 1211_error,
_operation.location(), _operation.location(),
@ -313,10 +311,8 @@ bool StaticAnalyzer::visit(FunctionCall const& _functionCall)
{ {
solAssert(_functionCall.arguments().size() == 3, ""); solAssert(_functionCall.arguments().size() == 3, "");
if (*_functionCall.arguments()[2]->annotation().isPure) if (*_functionCall.arguments()[2]->annotation().isPure)
if (auto lastArg = dynamic_cast<RationalNumberType const*>( if (auto lastArg = ConstantEvaluator::evaluate(m_errorReporter, *(_functionCall.arguments())[2]))
ConstantEvaluator(m_errorReporter).evaluate(*(_functionCall.arguments())[2]) if (lastArg->value == 0)
))
if (lastArg->isZero())
m_errorReporter.typeError( m_errorReporter.typeError(
4195_error, 4195_error,
_functionCall.location(), _functionCall.location(),

View File

@ -568,6 +568,11 @@ public:
u256 literalValue(Literal const* _literal) const override; u256 literalValue(Literal const* _literal) const override;
TypePointer mobileType() const override; TypePointer mobileType() const override;
/// @returns the underlying raw literal value.
///
/// @see literalValue(Literal const*))
rational const& value() const noexcept { return m_value; }
/// @returns the smallest integer type that can hold the value or an empty pointer if not possible. /// @returns the smallest integer type that can hold the value or an empty pointer if not possible.
IntegerType const* integerType() const; IntegerType const* integerType() const;
/// @returns the smallest fixed type that can hold the value or incurs the least precision loss, /// @returns the smallest fixed type that can hold the value or incurs the least precision loss,

View File

@ -2,8 +2,8 @@ set(sources
Algorithms.h Algorithms.h
AnsiColorized.h AnsiColorized.h
Assertions.h Assertions.h
Common.h
Common.cpp Common.cpp
Common.h
CommonData.cpp CommonData.cpp
CommonData.h CommonData.h
CommonIO.cpp CommonIO.cpp

View File

@ -0,0 +1,15 @@
contract C {
int constant a = 7;
int constant b = 3;
int constant c = a / b;
int constant d = (-a) / b;
function f() public pure returns (uint, int, uint, int) {
uint[c] memory x;
uint[-d] memory y;
return (x.length, c, y.length, -d);
}
}
// ====
// compileViaYul: also
// ----
// f() -> 2, 2, 2, 2

View File

@ -0,0 +1,14 @@
contract C {
uint constant a = 12;
uint constant b = 10;
function f() public pure returns (uint, uint) {
uint[(a / b) * b] memory x;
return (x.length, (a / b) * b);
}
}
// ====
// compileViaYul: true
// ----
// constructor() ->
// f() -> 0x0a, 0x0a

View File

@ -7,4 +7,4 @@ contract C {
} }
} }
// ---- // ----
// TypeError 5210: (36-39): Cyclic constant definition (or maximum recursion depth exhausted). // TypeError 5210: (17-44): Cyclic constant definition (or maximum recursion depth exhausted).

View File

@ -3,4 +3,4 @@ contract C {
uint[L] ids; uint[L] ids;
} }
// ---- // ----
// TypeError 3208: (51-52): Array with fractional length specified. // TypeError 5462: (51-52): Invalid array length, expected integer literal or constant expression.

View File

@ -5,4 +5,4 @@ contract C {
} }
} }
// ---- // ----
// TypeError 5210: (37-40): Cyclic constant definition (or maximum recursion depth exhausted). // TypeError 5210: (17-40): Cyclic constant definition (or maximum recursion depth exhausted).

View File

@ -0,0 +1,9 @@
contract C {
uint8 constant a = 255;
uint16 constant b = a + 2;
function f() public pure {
uint[b] memory x;
}
}
// ----
// TypeError 2643: (65-70): Arithmetic error when computing constant value.

View File

@ -0,0 +1,8 @@
contract C {
int8 constant a = -7;
function f() public pure {
uint[-a] memory x;
x[0] = 2;
}
}
// ----

View File

@ -0,0 +1,8 @@
contract C {
uint8 constant a = 0;
function f() public pure {
uint[a - 1] memory x;
}
}
// ----
// TypeError 2643: (83-88): Arithmetic error when computing constant value.

View File

@ -0,0 +1,8 @@
contract C {
int8 constant a = -128;
function f() public pure {
uint[-a] memory x;
}
}
// ----
// TypeError 3667: (85-87): Arithmetic error when computing constant value.