ConstantEvaluator: respect integer type hints in arithmetic operations.

This commit is contained in:
Christian Parpart 2020-11-11 13:58:11 +01:00
parent 0ae504f883
commit 74912d65cb
6 changed files with 110 additions and 51 deletions

View File

@ -30,25 +30,48 @@
#include <libsolutil/Common.h>
using namespace std;
using namespace solidity;
using namespace solidity::frontend;
using namespace solidity::langutil;
using std::optional;
using std::nullopt;
using std::string;
void ConstantEvaluator::endVisit(UnaryOperation const& _operation)
{
auto sub = evaluatedValue(_operation.subExpression());
if (sub)
setValue(_operation, sub->unaryOperatorResult(_operation.getOperator()));
if (auto const sub = result(_operation.subExpression()); sub.has_value())
{
auto const res = sub.value().type->unaryOperatorResult(_operation.getOperator());
if (auto const rationalType = dynamic_cast<RationalNumberType const*>(res.get()))
{
auto const subType = sub.value().type;
if (subType && subType->category() == Type::Category::Integer)
{
rational const frac = rationalType->value();
bigint const num = frac.numerator() / frac.denominator();
setValue(_operation, rational(num, 1));
}
else
setValue(_operation, rationalType->value());
}
}
}
void ConstantEvaluator::endVisit(BinaryOperation const& _operation)
{
auto left = evaluatedValue(_operation.leftExpression());
auto right = evaluatedValue(_operation.rightExpression());
auto left = value(_operation.leftExpression());
auto right = value(_operation.rightExpression());
if (left && right)
{
TypePointer commonType = left->binaryOperatorResult(_operation.getOperator(), right);
TypePointer const commonType = TypeProvider::rationalNumber(*left)->binaryOperatorResult(
_operation.getOperator(),
TypeProvider::rationalNumber(*right)
);
auto const leftType = result(_operation.leftExpression()).value().type;
auto const rightType = result(_operation.rightExpression()).value().type;
if (!commonType)
m_errorReporter.fatalTypeError(
6020_error,
@ -56,22 +79,37 @@ void ConstantEvaluator::endVisit(BinaryOperation const& _operation)
"Operator " +
string(TokenTraits::toString(_operation.getOperator())) +
" not compatible with types " +
left->toString() +
leftType->toString() +
" and " +
right->toString()
rightType->toString()
);
setValue(
_operation,
TokenTraits::isCompareOp(_operation.getOperator()) ?
TypeProvider::boolean() :
commonType
);
if (auto const rationalCommonType = dynamic_cast<RationalNumberType const*>(commonType))
{
if (leftType && leftType->category() == Type::Category::Integer &&
rightType && rightType->category() == Type::Category::Integer)
{
rational const frac = rationalCommonType->value();
bigint const num = frac.numerator() / frac.denominator();
setValue(_operation, rational(num, 1));
}
else
setValue(_operation, rationalCommonType->value());
}
// other types, such as BoolType are currently impossible to get, and in the old
// code, have been ignored, too.
// When we want to widen the constexpr support in Solidity, then we
// need to touch here, too.
}
}
void ConstantEvaluator::endVisit(Literal const& _literal)
{
setValue(_literal, TypeProvider::forLiteral(_literal));
auto const literalType = TypeProvider::forLiteral(_literal);
if (auto const p = dynamic_cast<RationalNumberType const*>(literalType))
setResult(_literal, TypedValue{literalType, p->value()});
}
bool ConstantEvaluator::evaluated(ASTNode const& _node) const noexcept
@ -99,28 +137,24 @@ void ConstantEvaluator::endVisit(Identifier const& _identifier)
}
// Link LHS's identifier to the evaluation result of the RHS expression.
setResult(_identifier, result(*value));
if (auto const resultOpt = result(*value); resultOpt.has_value())
setResult(_identifier, TypedValue{variableDeclaration->annotation().type, resultOpt.value().value});
}
void ConstantEvaluator::endVisit(TupleExpression const& _tuple) // TODO: do we actually ever need this code path here?
{
if (!_tuple.isInlineArray() && _tuple.components().size() == 1)
setValue(_tuple, evaluatedValue(*_tuple.components().front()));
}
void ConstantEvaluator::setValue(ASTNode const& _node, TypePointer const& _value)
{
setResult(_node, TypedValue{_value, _value});
if (auto v = value(*_tuple.components().front()); v.has_value())
setValue(_tuple, v.value());
}
void ConstantEvaluator::setResult(ASTNode const& _node, optional<ConstantEvaluator::TypedValue> _result)
{
if (_result.has_value())
{
auto const sourceType = _result.value().sourceType;
auto const value = _result.value().evaluatedValue;
if (value && value->category() == Type::Category::RationalNumber)
m_evaluations[&_node] = {sourceType, value};
auto const type = _result.value().type;
auto const value = _result.value().value;
m_evaluations[&_node] = {type, value};
}
}
@ -132,35 +166,35 @@ optional<ConstantEvaluator::TypedValue> ConstantEvaluator::result(ASTNode const&
return nullopt;
}
TypePointer ConstantEvaluator::sourceType(ASTNode const& _node)
TypePointer ConstantEvaluator::type(ASTNode const& _node)
{
if (auto p = m_evaluations.find(&_node); p != m_evaluations.end())
return p->second.sourceType;
return p->second.type;
return nullptr;
}
TypePointer ConstantEvaluator::evaluatedValue(ASTNode const& _node)
optional<rational> ConstantEvaluator::value(ASTNode const& _node)
{
if (auto p = m_evaluations.find(&_node); p != m_evaluations.end())
return p->second.evaluatedValue;
return p->second.value;
return nullptr;
return nullopt;
}
TypePointer ConstantEvaluator::evaluate(langutil::ErrorReporter& _errorReporter, Expression const& _expr)
std::optional<rational> ConstantEvaluator::evaluate(langutil::ErrorReporter& _errorReporter, Expression const& _expr)
{
EvaluationMap evaluations;
ConstantEvaluator evaluator(_errorReporter, evaluations);
return evaluator.evaluate(_expr);
}
TypePointer ConstantEvaluator::evaluate(Expression const& _expr)
std::optional<rational> ConstantEvaluator::evaluate(Expression const& _expr)
{
m_depth++;
ScopeGuard _([&]() { m_depth--; });
_expr.accept(*this);
return evaluatedValue(_expr);
return value(_expr);
}

View File

@ -44,7 +44,7 @@ class TypeChecker;
class ConstantEvaluator: private ASTConstVisitor
{
public:
struct TypedValue { TypePointer sourceType; TypePointer evaluatedValue; };
struct TypedValue { TypePointer type; rational value; };
using EvaluationMap = std::map<ASTNode const*, TypedValue>;
ConstantEvaluator(langutil::ErrorReporter& _errorReporter, EvaluationMap& _evaluations):
@ -54,12 +54,12 @@ public:
{
}
static TypePointer evaluate(
static std::optional<rational> evaluate(
langutil::ErrorReporter& _errorReporter,
Expression const& _expr
);
TypePointer evaluate(Expression const& _expr);
std::optional<rational> evaluate(Expression const& _expr);
private:
void endVisit(BinaryOperation const& _operation) override;
@ -68,14 +68,21 @@ private:
void endVisit(Identifier const& _identifier) override;
void endVisit(TupleExpression const& _tuple) override;
void setValue(ASTNode const& _node, TypePointer const& _value);
TypePointer sourceType(ASTNode const& _node);
TypePointer evaluatedValue(ASTNode const& _node);
TypePointer type(ASTNode const& _node);
std::optional<rational> value(ASTNode const& _node);
/// @return typed evaluation result or std::nullopt if not evaluated yet.
std::optional<TypedValue> result(ASTNode const& _node);
/// Conditionally sets the evaluation result for the given ASTNode @p _node.
void setResult(ASTNode const& _node, std::optional<TypedValue> _result);
void setValue(ASTNode const& _node, rational const& _value)
{
setResult(_node, TypedValue{nullptr, _value});
}
/// @returns boolean indicating whether or not given ASTNode @p _node has been evaluated already or not.
bool evaluated(ASTNode const& _node) const noexcept;
langutil::ErrorReporter& m_errorReporter;

View File

@ -249,7 +249,8 @@ void DeclarationTypeChecker::endVisit(ArrayTypeName const& _typeName)
{
TypePointer& lengthTypeGeneric = length->annotation().type;
if (!lengthTypeGeneric)
lengthTypeGeneric = ConstantEvaluator::evaluate(m_errorReporter, *length);
if (auto const p = ConstantEvaluator::evaluate(m_errorReporter, *length); p.has_value())
lengthTypeGeneric = TypeProvider::rationalNumber(*p);
RationalNumberType const* lengthType = dynamic_cast<RationalNumberType const*>(lengthTypeGeneric);
u256 lengthValue = 0;
if (!lengthType || !lengthType->mobileType())

View File

@ -290,10 +290,8 @@ bool StaticAnalyzer::visit(BinaryOperation const& _operation)
*_operation.rightExpression().annotation().isPure &&
(_operation.getOperator() == Token::Div || _operation.getOperator() == Token::Mod)
)
if (auto rhs = dynamic_cast<RationalNumberType const*>(
ConstantEvaluator::evaluate(m_errorReporter, _operation.rightExpression())
))
if (rhs->isZero())
if (auto const rhs = ConstantEvaluator::evaluate(m_errorReporter, _operation.rightExpression()); rhs.has_value())
if (rhs.value().numerator() == 0)
m_errorReporter.typeError(
1211_error,
_operation.location(),
@ -313,10 +311,10 @@ bool StaticAnalyzer::visit(FunctionCall const& _functionCall)
{
solAssert(_functionCall.arguments().size() == 3, "");
if (*_functionCall.arguments()[2]->annotation().isPure)
if (auto lastArg = dynamic_cast<RationalNumberType const*>(
ConstantEvaluator::evaluate(m_errorReporter, *(_functionCall.arguments())[2])
))
if (lastArg->isZero())
if (auto const lastArg = ConstantEvaluator::evaluate(m_errorReporter, *(_functionCall.arguments())[2]);
lastArg.has_value()
)
if (lastArg.value().numerator() == 0)
m_errorReporter.typeError(
4195_error,
_functionCall.location(),

View File

@ -568,6 +568,11 @@ public:
u256 literalValue(Literal const* _literal) const override;
TypePointer mobileType() const override;
/// @returns the underlying raw literal value.
///
/// @see literalValue(Literal const*))
rational const& value() const noexcept { return m_value; }
/// @returns the smallest integer type that can hold the value or an empty pointer if not possible.
IntegerType const* integerType() const;
/// @returns the smallest fixed type that can hold the value or incurs the least precision loss,

View File

@ -0,0 +1,14 @@
contract C {
uint constant a = 12;
uint constant b = 10;
function f() public pure returns (uint) {
uint[(a / b) * b] memory x;
return x.length;
}
}
// ====
// compileViaYul: true
// ----
// constructor() ->
// f() -> 10