diff --git a/libsmtutil/SolverInterface.h b/libsmtutil/SolverInterface.h index c9cf8f700..4a21f1841 100644 --- a/libsmtutil/SolverInterface.h +++ b/libsmtutil/SolverInterface.h @@ -123,7 +123,7 @@ public: explicit Expression(std::shared_ptr _sort, std::string _name = ""): Expression(std::move(_name), {}, _sort) {} explicit Expression(std::string _name, std::vector _arguments, SortPointer _sort): name(std::move(_name)), arguments(std::move(_arguments)), sort(std::move(_sort)) {} - Expression(size_t _number): Expression(std::to_string(_number), {}, SortProvider::sintSort) {} + Expression(size_t _number): Expression(std::to_string(_number), {}, SortProvider::uintSort) {} Expression(u256 const& _number): Expression(_number.str(), {}, SortProvider::sintSort) {} Expression(s256 const& _number): Expression( _number >= 0 ? _number.str() : "-", @@ -189,7 +189,10 @@ public: static Expression ite(Expression _condition, Expression _trueValue, Expression _falseValue) { - smtAssert(*_trueValue.sort == *_falseValue.sort, ""); + if (_trueValue.sort->kind == Kind::Int) + smtAssert(_trueValue.sort->kind == _falseValue.sort->kind, ""); + else + smtAssert(*_trueValue.sort == *_falseValue.sort, ""); SortPointer sort = _trueValue.sort; return Expression("ite", std::vector{ std::move(_condition), std::move(_trueValue), std::move(_falseValue) @@ -213,7 +216,10 @@ public: std::shared_ptr arraySort = std::dynamic_pointer_cast(_array.sort); smtAssert(arraySort, ""); smtAssert(_index.sort, ""); - smtAssert(*arraySort->domain == *_index.sort, ""); + if (arraySort->domain->kind == Kind::Int) + smtAssert(arraySort->domain->kind == _index.sort->kind, ""); + else + smtAssert(*arraySort->domain == *_index.sort, ""); return Expression( "select", std::vector{std::move(_array), std::move(_index)}, @@ -230,7 +236,10 @@ public: smtAssert(_index.sort, ""); smtAssert(_element.sort, ""); smtAssert(*arraySort->domain == *_index.sort, ""); - smtAssert(*arraySort->range == *_element.sort, ""); + if (arraySort->domain->kind == Kind::Int) + smtAssert(arraySort->range->kind == _element.sort->kind, ""); + else + smtAssert(*arraySort->range == *_element.sort, ""); return Expression( "store", std::vector{std::move(_array), std::move(_index), std::move(_element)}, @@ -245,7 +254,10 @@ public: auto arraySort = std::dynamic_pointer_cast(sortSort->inner); smtAssert(sortSort && arraySort, ""); smtAssert(_value.sort, ""); - smtAssert(*arraySort->range == *_value.sort, ""); + if (arraySort->domain->kind == Kind::Int) + smtAssert(arraySort->range->kind == _value.sort->kind, ""); + else + smtAssert(*arraySort->range == *_value.sort, ""); return Expression( "const_array", std::vector{std::move(_sort), std::move(_value)}, diff --git a/libsmtutil/Sorts.h b/libsmtutil/Sorts.h index 107dc1f2c..419ee9dc5 100644 --- a/libsmtutil/Sorts.h +++ b/libsmtutil/Sorts.h @@ -57,9 +57,14 @@ struct IntSort: public Sort isSigned(_signed) {} - bool operator==(IntSort const& _other) const + bool operator==(Sort const& _other) const override { - return Sort::operator==(_other) && isSigned == _other.isSigned; + if (!Sort::operator==(_other)) + return false; + + auto otherIntSort = dynamic_cast(&_other); + smtAssert(otherIntSort); + return isSigned == otherIntSort->isSigned; } bool isSigned; @@ -72,9 +77,14 @@ struct BitVectorSort: public Sort size(_size) {} - bool operator==(BitVectorSort const& _other) const + bool operator==(Sort const& _other) const override { - return Sort::operator==(_other) && size == _other.size; + if (!Sort::operator==(_other)) + return false; + + auto otherBitVectorSort = dynamic_cast(&_other); + smtAssert(otherBitVectorSort); + return size == otherBitVectorSort->size; } unsigned size;