feat(math): add safe arithmetic (#18552)
This commit is contained in:
parent
7d5c2dbae6
commit
7d5e9f1c23
@ -37,6 +37,7 @@ Ref: https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.j
|
||||
## [Unreleased]
|
||||
|
||||
### Features
|
||||
* [#18552](https://github.com/cosmos/cosmos-sdk/pull/18552) Add safe arithmetic operations for `math.Int` that return an error in case of an overflow or any mishap.
|
||||
|
||||
* [#18421](https://github.com/cosmos/cosmos-sdk/pull/18421) Add mutative api for `LegacyDec.BigInt()`.
|
||||
|
||||
|
||||
102
math/int.go
102
math/int.go
@ -3,6 +3,7 @@ package math
|
||||
import (
|
||||
"encoding"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
@ -13,6 +14,14 @@ import (
|
||||
// MaxBitLen defines the maximum bit length supported bit Int and Uint types.
|
||||
const MaxBitLen = 256
|
||||
|
||||
// Integer errors
|
||||
var (
|
||||
// ErrIntOverflow is the error returned when an integer overflow occurs
|
||||
ErrIntOverflow = errors.New("Integer overflow")
|
||||
// ErrDivideByZero is the error returned when a divide by zero occurs
|
||||
ErrDivideByZero = errors.New("Divide by zero")
|
||||
)
|
||||
|
||||
func newIntegerFromString(s string) (*big.Int, bool) {
|
||||
return new(big.Int).SetString(s, 0)
|
||||
}
|
||||
@ -259,12 +268,12 @@ func (i Int) LTE(i2 Int) bool {
|
||||
|
||||
// Add adds Int from another
|
||||
func (i Int) Add(i2 Int) (res Int) {
|
||||
res = Int{add(i.i, i2.i)}
|
||||
// Check overflow
|
||||
if res.i.BitLen() > MaxBitLen {
|
||||
panic("Int overflow")
|
||||
x, err := i.SafeAdd(i2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
return x
|
||||
}
|
||||
|
||||
// AddRaw adds int64 to Int
|
||||
@ -272,14 +281,24 @@ func (i Int) AddRaw(i2 int64) Int {
|
||||
return i.Add(NewInt(i2))
|
||||
}
|
||||
|
||||
// Sub subtracts Int from another
|
||||
func (i Int) Sub(i2 Int) (res Int) {
|
||||
res = Int{sub(i.i, i2.i)}
|
||||
// SafeAdd adds Int from another and returns an error if overflow
|
||||
func (i Int) SafeAdd(i2 Int) (res Int, err error) {
|
||||
res = Int{add(i.i, i2.i)}
|
||||
// Check overflow
|
||||
if res.i.BitLen() > MaxBitLen {
|
||||
panic("Int overflow")
|
||||
return Int{}, ErrIntOverflow
|
||||
}
|
||||
return
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Sub subtracts Int from another
|
||||
func (i Int) Sub(i2 Int) (res Int) {
|
||||
// Check overflow
|
||||
x, err := i.SafeSub(i2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// SubRaw subtracts int64 from Int
|
||||
@ -287,18 +306,24 @@ func (i Int) SubRaw(i2 int64) Int {
|
||||
return i.Sub(NewInt(i2))
|
||||
}
|
||||
|
||||
// SafeSub subtracts Int from another and returns an error if overflow or underflow
|
||||
func (i Int) SafeSub(i2 Int) (res Int, err error) {
|
||||
res = Int{sub(i.i, i2.i)}
|
||||
// Check overflow/underflow
|
||||
if res.i.BitLen() > MaxBitLen {
|
||||
return Int{}, ErrIntOverflow
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Mul multiples two Ints
|
||||
func (i Int) Mul(i2 Int) (res Int) {
|
||||
// Check overflow
|
||||
if i.i.BitLen()+i2.i.BitLen()-1 > MaxBitLen {
|
||||
panic("Int overflow")
|
||||
x, err := i.SafeMul(i2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
res = Int{mul(i.i, i2.i)}
|
||||
// Check overflow if sign of both are same
|
||||
if res.i.BitLen() > MaxBitLen {
|
||||
panic("Int overflow")
|
||||
}
|
||||
return
|
||||
return x
|
||||
}
|
||||
|
||||
// MulRaw multipies Int and int64
|
||||
@ -306,13 +331,28 @@ func (i Int) MulRaw(i2 int64) Int {
|
||||
return i.Mul(NewInt(i2))
|
||||
}
|
||||
|
||||
// SafeMul multiples Int from another and returns an error if overflow
|
||||
func (i Int) SafeMul(i2 Int) (res Int, err error) {
|
||||
// Check overflow
|
||||
if i.i.BitLen()+i2.i.BitLen()-1 > MaxBitLen {
|
||||
return Int{}, ErrIntOverflow
|
||||
}
|
||||
res = Int{mul(i.i, i2.i)}
|
||||
// Check overflow if sign of both are same
|
||||
if res.i.BitLen() > MaxBitLen {
|
||||
return Int{}, ErrIntOverflow
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Quo divides Int with Int
|
||||
func (i Int) Quo(i2 Int) (res Int) {
|
||||
// Check division-by-zero
|
||||
if i2.i.Sign() == 0 {
|
||||
x, err := i.SafeQuo(i2)
|
||||
if err != nil {
|
||||
panic("Division by zero")
|
||||
}
|
||||
return Int{div(i.i, i2.i)}
|
||||
return x
|
||||
}
|
||||
|
||||
// QuoRaw divides Int with int64
|
||||
@ -320,12 +360,22 @@ func (i Int) QuoRaw(i2 int64) Int {
|
||||
return i.Quo(NewInt(i2))
|
||||
}
|
||||
|
||||
// SafeQuo divides Int with Int and returns an error if division by zero
|
||||
func (i Int) SafeQuo(i2 Int) (res Int, err error) {
|
||||
// Check division-by-zero
|
||||
if i2.i.Sign() == 0 {
|
||||
return Int{}, ErrDivideByZero
|
||||
}
|
||||
return Int{div(i.i, i2.i)}, nil
|
||||
}
|
||||
|
||||
// Mod returns remainder after dividing with Int
|
||||
func (i Int) Mod(i2 Int) Int {
|
||||
if i2.Sign() == 0 {
|
||||
panic("division-by-zero")
|
||||
x, err := i.SafeMod(i2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return Int{mod(i.i, i2.i)}
|
||||
return x
|
||||
}
|
||||
|
||||
// ModRaw returns remainder after dividing with int64
|
||||
@ -333,6 +383,14 @@ func (i Int) ModRaw(i2 int64) Int {
|
||||
return i.Mod(NewInt(i2))
|
||||
}
|
||||
|
||||
// SafeMod returns remainder after dividing with Int and returns an error if division by zero
|
||||
func (i Int) SafeMod(i2 Int) (res Int, err error) {
|
||||
if i2.Sign() == 0 {
|
||||
return Int{}, ErrDivideByZero
|
||||
}
|
||||
return Int{mod(i.i, i2.i)}, nil
|
||||
}
|
||||
|
||||
// Neg negates Int
|
||||
func (i Int) Neg() (res Int) {
|
||||
return Int{neg(i.i)}
|
||||
|
||||
@ -111,32 +111,66 @@ func (s *intTestSuite) TestIntPanic() {
|
||||
s.Require().NotPanics(func() { i1.Add(i1) })
|
||||
s.Require().NotPanics(func() { i2.Add(i2) })
|
||||
s.Require().Panics(func() { i3.Add(i3) })
|
||||
_, err := i1.SafeAdd(i1)
|
||||
s.Require().Nil(err)
|
||||
_, err = i2.SafeAdd(i2)
|
||||
s.Require().Nil(err)
|
||||
_, err = i3.SafeAdd(i3)
|
||||
s.Require().Error(err)
|
||||
|
||||
s.Require().NotPanics(func() { i1.Sub(i1.Neg()) })
|
||||
s.Require().NotPanics(func() { i2.Sub(i2.Neg()) })
|
||||
s.Require().Panics(func() { i3.Sub(i3.Neg()) })
|
||||
_, err = i1.SafeSub(i1.Neg())
|
||||
s.Require().Nil(err)
|
||||
_, err = i2.SafeSub(i2.Neg())
|
||||
s.Require().Nil(err)
|
||||
_, err = i3.SafeSub(i3.Neg())
|
||||
s.Require().Error(err)
|
||||
|
||||
s.Require().Panics(func() { i1.Mul(i1) })
|
||||
s.Require().Panics(func() { i2.Mul(i2) })
|
||||
s.Require().Panics(func() { i3.Mul(i3) })
|
||||
_, err = i1.SafeMul(i1)
|
||||
s.Require().Error(err)
|
||||
_, err = i2.SafeMul(i2)
|
||||
s.Require().Error(err)
|
||||
_, err = i3.SafeMul(i3)
|
||||
s.Require().Error(err)
|
||||
|
||||
s.Require().Panics(func() { i1.Neg().Mul(i1.Neg()) })
|
||||
s.Require().Panics(func() { i2.Neg().Mul(i2.Neg()) })
|
||||
s.Require().Panics(func() { i3.Neg().Mul(i3.Neg()) })
|
||||
_, err = i1.Neg().SafeMul(i1.Neg())
|
||||
s.Require().Error(err)
|
||||
_, err = i2.Neg().SafeMul(i2.Neg())
|
||||
s.Require().Error(err)
|
||||
_, err = i3.Neg().SafeMul(i3.Neg())
|
||||
s.Require().Error(err)
|
||||
|
||||
// // Underflow check
|
||||
// Underflow check
|
||||
i3n := i3.Neg()
|
||||
s.Require().NotPanics(func() { i3n.Sub(i1) })
|
||||
s.Require().NotPanics(func() { i3n.Sub(i2) })
|
||||
s.Require().Panics(func() { i3n.Sub(i3) })
|
||||
_, err = i3n.SafeSub(i3)
|
||||
s.Require().Error(err)
|
||||
|
||||
s.Require().NotPanics(func() { i3n.Add(i1.Neg()) })
|
||||
s.Require().NotPanics(func() { i3n.Add(i2.Neg()) })
|
||||
s.Require().Panics(func() { i3n.Add(i3.Neg()) })
|
||||
_, err = i3n.SafeAdd(i3.Neg())
|
||||
s.Require().Error(err)
|
||||
|
||||
s.Require().Panics(func() { i1.Mul(i1.Neg()) })
|
||||
s.Require().Panics(func() { i2.Mul(i2.Neg()) })
|
||||
s.Require().Panics(func() { i3.Mul(i3.Neg()) })
|
||||
_, err = i1.SafeMul(i1.Neg())
|
||||
s.Require().Error(err)
|
||||
_, err = i2.SafeMul(i2.Neg())
|
||||
s.Require().Error(err)
|
||||
_, err = i3.SafeMul(i3.Neg())
|
||||
s.Require().Error(err)
|
||||
|
||||
// Bound check
|
||||
intmax := math.NewIntFromBigInt(new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil), big.NewInt(1)))
|
||||
@ -145,12 +179,18 @@ func (s *intTestSuite) TestIntPanic() {
|
||||
s.Require().NotPanics(func() { intmin.Sub(math.ZeroInt()) })
|
||||
s.Require().Panics(func() { intmax.Add(math.OneInt()) })
|
||||
s.Require().Panics(func() { intmin.Sub(math.OneInt()) })
|
||||
_, err = intmax.SafeAdd(math.OneInt())
|
||||
s.Require().Error(err)
|
||||
_, err = intmin.SafeSub(math.OneInt())
|
||||
s.Require().Error(err)
|
||||
|
||||
s.Require().NotPanics(func() { math.NewIntFromBigInt(nil) })
|
||||
s.Require().True(math.NewIntFromBigInt(nil).IsNil())
|
||||
|
||||
// Division-by-zero check
|
||||
s.Require().Panics(func() { i1.Quo(math.NewInt(0)) })
|
||||
_, err = i1.SafeQuo(math.NewInt(0))
|
||||
s.Require().Error(err)
|
||||
|
||||
s.Require().NotPanics(func() { math.Int{}.BigInt() })
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user