From 7d5e9f1c232caaae5f5a2135a82e15be79b3817d Mon Sep 17 00:00:00 2001 From: ocnc <97242826+ocnc@users.noreply.github.com> Date: Tue, 28 Nov 2023 02:54:04 -0500 Subject: [PATCH] feat(math): add safe arithmetic (#18552) --- math/CHANGELOG.md | 1 + math/int.go | 102 ++++++++++++++++++++++++++++++++++++---------- math/int_test.go | 42 ++++++++++++++++++- 3 files changed, 122 insertions(+), 23 deletions(-) diff --git a/math/CHANGELOG.md b/math/CHANGELOG.md index e385c915a0..2f512456c3 100644 --- a/math/CHANGELOG.md +++ b/math/CHANGELOG.md @@ -37,6 +37,7 @@ Ref: https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.j ## [Unreleased] ### Features +* [#18552](https://github.com/cosmos/cosmos-sdk/pull/18552) Add safe arithmetic operations for `math.Int` that return an error in case of an overflow or any mishap. * [#18421](https://github.com/cosmos/cosmos-sdk/pull/18421) Add mutative api for `LegacyDec.BigInt()`. diff --git a/math/int.go b/math/int.go index 9df290d8ee..bf6eb9b1fd 100644 --- a/math/int.go +++ b/math/int.go @@ -3,6 +3,7 @@ package math import ( "encoding" "encoding/json" + "errors" "fmt" "math/big" "strings" @@ -13,6 +14,14 @@ import ( // MaxBitLen defines the maximum bit length supported bit Int and Uint types. const MaxBitLen = 256 +// Integer errors +var ( + // ErrIntOverflow is the error returned when an integer overflow occurs + ErrIntOverflow = errors.New("Integer overflow") + // ErrDivideByZero is the error returned when a divide by zero occurs + ErrDivideByZero = errors.New("Divide by zero") +) + func newIntegerFromString(s string) (*big.Int, bool) { return new(big.Int).SetString(s, 0) } @@ -259,12 +268,12 @@ func (i Int) LTE(i2 Int) bool { // Add adds Int from another func (i Int) Add(i2 Int) (res Int) { - res = Int{add(i.i, i2.i)} // Check overflow - if res.i.BitLen() > MaxBitLen { - panic("Int overflow") + x, err := i.SafeAdd(i2) + if err != nil { + panic(err) } - return + return x } // AddRaw adds int64 to Int @@ -272,14 +281,24 @@ func (i Int) AddRaw(i2 int64) Int { return i.Add(NewInt(i2)) } -// Sub subtracts Int from another -func (i Int) Sub(i2 Int) (res Int) { - res = Int{sub(i.i, i2.i)} +// SafeAdd adds Int from another and returns an error if overflow +func (i Int) SafeAdd(i2 Int) (res Int, err error) { + res = Int{add(i.i, i2.i)} // Check overflow if res.i.BitLen() > MaxBitLen { - panic("Int overflow") + return Int{}, ErrIntOverflow } - return + return res, nil +} + +// Sub subtracts Int from another +func (i Int) Sub(i2 Int) (res Int) { + // Check overflow + x, err := i.SafeSub(i2) + if err != nil { + panic(err) + } + return x } // SubRaw subtracts int64 from Int @@ -287,18 +306,24 @@ func (i Int) SubRaw(i2 int64) Int { return i.Sub(NewInt(i2)) } +// SafeSub subtracts Int from another and returns an error if overflow or underflow +func (i Int) SafeSub(i2 Int) (res Int, err error) { + res = Int{sub(i.i, i2.i)} + // Check overflow/underflow + if res.i.BitLen() > MaxBitLen { + return Int{}, ErrIntOverflow + } + return res, nil +} + // Mul multiples two Ints func (i Int) Mul(i2 Int) (res Int) { // Check overflow - if i.i.BitLen()+i2.i.BitLen()-1 > MaxBitLen { - panic("Int overflow") + x, err := i.SafeMul(i2) + if err != nil { + panic(err) } - res = Int{mul(i.i, i2.i)} - // Check overflow if sign of both are same - if res.i.BitLen() > MaxBitLen { - panic("Int overflow") - } - return + return x } // MulRaw multipies Int and int64 @@ -306,13 +331,28 @@ func (i Int) MulRaw(i2 int64) Int { return i.Mul(NewInt(i2)) } +// SafeMul multiples Int from another and returns an error if overflow +func (i Int) SafeMul(i2 Int) (res Int, err error) { + // Check overflow + if i.i.BitLen()+i2.i.BitLen()-1 > MaxBitLen { + return Int{}, ErrIntOverflow + } + res = Int{mul(i.i, i2.i)} + // Check overflow if sign of both are same + if res.i.BitLen() > MaxBitLen { + return Int{}, ErrIntOverflow + } + return res, nil +} + // Quo divides Int with Int func (i Int) Quo(i2 Int) (res Int) { // Check division-by-zero - if i2.i.Sign() == 0 { + x, err := i.SafeQuo(i2) + if err != nil { panic("Division by zero") } - return Int{div(i.i, i2.i)} + return x } // QuoRaw divides Int with int64 @@ -320,12 +360,22 @@ func (i Int) QuoRaw(i2 int64) Int { return i.Quo(NewInt(i2)) } +// SafeQuo divides Int with Int and returns an error if division by zero +func (i Int) SafeQuo(i2 Int) (res Int, err error) { + // Check division-by-zero + if i2.i.Sign() == 0 { + return Int{}, ErrDivideByZero + } + return Int{div(i.i, i2.i)}, nil +} + // Mod returns remainder after dividing with Int func (i Int) Mod(i2 Int) Int { - if i2.Sign() == 0 { - panic("division-by-zero") + x, err := i.SafeMod(i2) + if err != nil { + panic(err) } - return Int{mod(i.i, i2.i)} + return x } // ModRaw returns remainder after dividing with int64 @@ -333,6 +383,14 @@ func (i Int) ModRaw(i2 int64) Int { return i.Mod(NewInt(i2)) } +// SafeMod returns remainder after dividing with Int and returns an error if division by zero +func (i Int) SafeMod(i2 Int) (res Int, err error) { + if i2.Sign() == 0 { + return Int{}, ErrDivideByZero + } + return Int{mod(i.i, i2.i)}, nil +} + // Neg negates Int func (i Int) Neg() (res Int) { return Int{neg(i.i)} diff --git a/math/int_test.go b/math/int_test.go index cbcf29d632..1895f4bc1d 100644 --- a/math/int_test.go +++ b/math/int_test.go @@ -111,32 +111,66 @@ func (s *intTestSuite) TestIntPanic() { s.Require().NotPanics(func() { i1.Add(i1) }) s.Require().NotPanics(func() { i2.Add(i2) }) s.Require().Panics(func() { i3.Add(i3) }) + _, err := i1.SafeAdd(i1) + s.Require().Nil(err) + _, err = i2.SafeAdd(i2) + s.Require().Nil(err) + _, err = i3.SafeAdd(i3) + s.Require().Error(err) s.Require().NotPanics(func() { i1.Sub(i1.Neg()) }) s.Require().NotPanics(func() { i2.Sub(i2.Neg()) }) s.Require().Panics(func() { i3.Sub(i3.Neg()) }) + _, err = i1.SafeSub(i1.Neg()) + s.Require().Nil(err) + _, err = i2.SafeSub(i2.Neg()) + s.Require().Nil(err) + _, err = i3.SafeSub(i3.Neg()) + s.Require().Error(err) s.Require().Panics(func() { i1.Mul(i1) }) s.Require().Panics(func() { i2.Mul(i2) }) s.Require().Panics(func() { i3.Mul(i3) }) + _, err = i1.SafeMul(i1) + s.Require().Error(err) + _, err = i2.SafeMul(i2) + s.Require().Error(err) + _, err = i3.SafeMul(i3) + s.Require().Error(err) s.Require().Panics(func() { i1.Neg().Mul(i1.Neg()) }) s.Require().Panics(func() { i2.Neg().Mul(i2.Neg()) }) s.Require().Panics(func() { i3.Neg().Mul(i3.Neg()) }) + _, err = i1.Neg().SafeMul(i1.Neg()) + s.Require().Error(err) + _, err = i2.Neg().SafeMul(i2.Neg()) + s.Require().Error(err) + _, err = i3.Neg().SafeMul(i3.Neg()) + s.Require().Error(err) - // // Underflow check + // Underflow check i3n := i3.Neg() s.Require().NotPanics(func() { i3n.Sub(i1) }) s.Require().NotPanics(func() { i3n.Sub(i2) }) s.Require().Panics(func() { i3n.Sub(i3) }) + _, err = i3n.SafeSub(i3) + s.Require().Error(err) s.Require().NotPanics(func() { i3n.Add(i1.Neg()) }) s.Require().NotPanics(func() { i3n.Add(i2.Neg()) }) s.Require().Panics(func() { i3n.Add(i3.Neg()) }) + _, err = i3n.SafeAdd(i3.Neg()) + s.Require().Error(err) s.Require().Panics(func() { i1.Mul(i1.Neg()) }) s.Require().Panics(func() { i2.Mul(i2.Neg()) }) s.Require().Panics(func() { i3.Mul(i3.Neg()) }) + _, err = i1.SafeMul(i1.Neg()) + s.Require().Error(err) + _, err = i2.SafeMul(i2.Neg()) + s.Require().Error(err) + _, err = i3.SafeMul(i3.Neg()) + s.Require().Error(err) // Bound check intmax := math.NewIntFromBigInt(new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil), big.NewInt(1))) @@ -145,12 +179,18 @@ func (s *intTestSuite) TestIntPanic() { s.Require().NotPanics(func() { intmin.Sub(math.ZeroInt()) }) s.Require().Panics(func() { intmax.Add(math.OneInt()) }) s.Require().Panics(func() { intmin.Sub(math.OneInt()) }) + _, err = intmax.SafeAdd(math.OneInt()) + s.Require().Error(err) + _, err = intmin.SafeSub(math.OneInt()) + s.Require().Error(err) s.Require().NotPanics(func() { math.NewIntFromBigInt(nil) }) s.Require().True(math.NewIntFromBigInt(nil).IsNil()) // Division-by-zero check s.Require().Panics(func() { i1.Quo(math.NewInt(0)) }) + _, err = i1.SafeQuo(math.NewInt(0)) + s.Require().Error(err) s.Require().NotPanics(func() { math.Int{}.BigInt() }) }