cosmos-sdk/math/dec_rapid_test.go
Alexander Peters 4ed1087cf5
feat(math): Upstream GDA based decimal type (#21982)
Co-authored-by: samricotta <samanthalricotta@gmail.com>
Co-authored-by: samricotta <37125168+samricotta@users.noreply.github.com>
2024-11-22 10:29:01 +00:00

528 lines
12 KiB
Go

package math
import (
"fmt"
"regexp"
"strconv"
"testing"
"github.com/stretchr/testify/require"
"pgregory.net/rapid"
)
// TestDecWithRapid runs the Dec property tests (driven by rapid, a Go library
// for property-based testing) as subtests, followed by a handful of fixed
// example-based checks.
func TestDecWithRapid(t *testing.T) {
	// Property tests, executed as subtests in the order listed.
	properties := []struct {
		name  string
		check func(*rapid.T)
	}{
		{"TestNewDecFromInt64", testDecInt64},
		// Properties about *FromString functions
		{"TestInvalidNewDecFromString", testInvalidNewDecFromString},
		// Properties about addition
		{"TestAddLeftIdentity", testAddLeftIdentity},
		{"TestAddRightIdentity", testAddRightIdentity},
		{"TestAddCommutative", testAddCommutative},
		{"TestAddAssociative", testAddAssociative},
		// Properties about subtraction
		{"TestSubRightIdentity", testSubRightIdentity},
		{"TestSubZero", testSubZero},
		// Properties about multiplication
		{"TestMulLeftIdentity", testMulLeftIdentity},
		{"TestMulRightIdentity", testMulRightIdentity},
		{"TestMulCommutative", testMulCommutative},
		{"TestMulAssociative", testMulAssociative},
		{"TestZeroIdentity", testMulZero},
		// Properties about division
		{"TestDivisionBySelf", testSelfQuo},
		{"TestDivisionByOne", testQuoByOne},
		// Properties combining operations
		{"TestSubAdd", testSubAdd},
		{"TestAddSub", testAddSub},
		{"TestMulQuoA", testMulQuoA},
		{"TestMulQuoB", testMulQuoB},
		{"TestMulQuoExact", testMulQuoExact},
		{"TestQuoMulExact", testQuoMulExact},
		// Properties about comparison and equality
		{"TestCmpInverse", testCmpInverse},
		{"TestEqualCommutative", testEqualCommutative},
		// Properties about tests on a single Dec
		{"TestIsZero", testIsZero},
		{"TestIsNegative", testIsNegative},
		{"TestIsPositive", testIsPositive},
		{"TestNumDecimalPlaces", testNumDecimalPlaces},
	}
	for _, p := range properties {
		t.Run(p.name, rapid.MakeCheck(p.check))
	}

	// Unit tests

	// mustParse builds a Dec from its string form and fails the test on a
	// parse error.
	mustParse := func(s string) Dec {
		parsed, parseErr := NewDecFromString(s)
		require.NoError(t, parseErr)
		return parsed
	}

	zero := Dec{}
	one := NewDecFromInt64(1)
	two := NewDecFromInt64(2)
	three := NewDecFromInt64(3)
	four := NewDecFromInt64(4)
	five := NewDecFromInt64(5)
	minusOne := NewDecFromInt64(-1)

	onePointOneFive := mustParse("1.15")
	twoPointThreeFour := mustParse("2.34")
	threePointFourNine := mustParse("3.49")
	onePointFourNine := mustParse("1.49")
	minusFivePointZero := mustParse("-5.0")

	twoThousand := NewDecWithExp(2, 3)
	require.True(t, twoThousand.Equal(NewDecFromInt64(2000)))

	// Integer arithmetic.
	got, err := two.Add(zero)
	require.NoError(t, err)
	require.True(t, got.Equal(two))

	got, err = five.Sub(two)
	require.NoError(t, err)
	require.True(t, got.Equal(three))

	got, err = four.Quo(two)
	require.NoError(t, err)
	require.True(t, got.Equal(two))

	got, err = five.QuoInteger(two)
	require.NoError(t, err)
	require.True(t, got.Equal(two))

	got, err = five.Modulo(two)
	require.NoError(t, err)
	require.True(t, got.Equal(one))

	// Conversions.
	asInt, err := four.Int64()
	require.NoError(t, err)
	require.Equal(t, int64(4), asInt)

	require.Equal(t, "5", five.String())

	// Fractional arithmetic.
	got, err = onePointOneFive.Add(twoPointThreeFour)
	require.NoError(t, err)
	require.True(t, got.Equal(threePointFourNine))

	got, err = threePointFourNine.Sub(two)
	require.NoError(t, err)
	require.True(t, got.Equal(onePointFourNine))

	got, err = minusOne.Sub(four)
	require.NoError(t, err)
	require.True(t, got.Equal(minusFivePointZero))

	// Sign predicates on zero, a positive value and a negative value.
	require.True(t, zero.IsZero())
	require.False(t, zero.IsPositive())
	require.False(t, zero.IsNegative())

	require.False(t, one.IsZero())
	require.True(t, one.IsPositive())
	require.False(t, one.IsNegative())

	require.False(t, minusOne.IsZero())
	require.False(t, minusOne.IsPositive())
	require.True(t, minusOne.IsNegative())

	got, err = one.MulExact(two)
	require.NoError(t, err)
	require.True(t, got.Equal(two))
}
// genDec generates Dec values by drawing a float64 and parsing its "%g"
// rendering with NewDecFromString.
var genDec = rapid.Custom(func(t *rapid.T) Dec {
	raw := rapid.Float64().Draw(t, "f")
	text := fmt.Sprintf("%g", raw)
	parsed, err := NewDecFromString(text)
	require.NoError(t, err)
	return parsed
})
// floatAndDec pairs a Dec value with the float64 used to create it, so that
// properties can compare the Dec against its source float.
type floatAndDec struct {
	// float is the source value.
	float float64
	// dec is the Dec built from float.
	dec Dec
}
// genFloatAndDec generates a Dec (parsed from the "%g" rendering of a drawn
// float64) together with the float it was created from.
var genFloatAndDec = rapid.Custom(func(t *rapid.T) floatAndDec {
	raw := rapid.Float64().Draw(t, "f")
	text := fmt.Sprintf("%g", raw)
	parsed, err := NewDecFromString(text)
	require.NoError(t, err)
	return floatAndDec{float: raw, dec: parsed}
})
// testDecInt64 checks the round-trip property
// n == NewDecFromInt64(n).Int64() for a drawn int64 n.
func testDecInt64(t *rapid.T) {
	want := rapid.Int64().Draw(t, "n")
	asDec := NewDecFromInt64(want)
	got, err := asDec.Int64()
	require.NoError(t, err)
	require.Equal(t, want, got)
}
// testInvalidNewDecFromString checks that NewDecFromString returns an error
// for any non-empty string made up solely of alphabetic characters.
func testInvalidNewDecFromString(t *rapid.T) {
	lettersOnly := rapid.StringMatching("[[:alpha:]]+")
	input := lettersOnly.Draw(t, "s")
	_, parseErr := NewDecFromString(input)
	require.Error(t, parseErr)
}
// testAddLeftIdentity checks that zero is a left identity for Add:
// 0 + a == a.
func testAddLeftIdentity(t *rapid.T) {
	operand := genDec.Draw(t, "a")
	identity := NewDecFromInt64(0)
	sum, err := identity.Add(operand)
	require.NoError(t, err)
	require.True(t, operand.Equal(sum))
}
// testAddRightIdentity checks that zero is a right identity for Add:
// a + 0 == a.
func testAddRightIdentity(t *rapid.T) {
	operand := genDec.Draw(t, "a")
	identity := NewDecFromInt64(0)
	sum, err := operand.Add(identity)
	require.NoError(t, err)
	require.True(t, operand.Equal(sum))
}
// testAddCommutative checks that Add is commutative: a + b == b + a.
func testAddCommutative(t *rapid.T) {
	lhs := genDec.Draw(t, "a")
	rhs := genDec.Draw(t, "b")

	forward, err := lhs.Add(rhs)
	require.NoError(t, err)

	swapped, err := rhs.Add(lhs)
	require.NoError(t, err)

	require.True(t, forward.Equal(swapped))
}
// testAddAssociative checks that Add is associative:
// (a + b) + c == a + (b + c).
func testAddAssociative(t *rapid.T) {
	decA := genDec.Draw(t, "a")
	decB := genDec.Draw(t, "b")
	decC := genDec.Draw(t, "c")

	// Left grouping: (a + b) + c
	sumAB, err := decA.Add(decB)
	require.NoError(t, err)
	leftGrouped, err := sumAB.Add(decC)
	require.NoError(t, err)

	// Right grouping: a + (b + c)
	sumBC, err := decB.Add(decC)
	require.NoError(t, err)
	rightGrouped, err := decA.Add(sumBC)
	require.NoError(t, err)

	require.True(t, leftGrouped.Equal(rightGrouped))
}
// testSubRightIdentity checks that subtracting zero leaves a value
// unchanged: a - 0 == a.
func testSubRightIdentity(t *rapid.T) {
	operand := genDec.Draw(t, "a")
	identity := NewDecFromInt64(0)
	diff, err := operand.Sub(identity)
	require.NoError(t, err)
	require.True(t, operand.Equal(diff))
}
// testSubZero checks that a value minus itself is zero: a - a == 0.
func testSubZero(t *rapid.T) {
	operand := genDec.Draw(t, "a")
	want := NewDecFromInt64(0)
	diff, err := operand.Sub(operand)
	require.NoError(t, err)
	require.True(t, diff.Equal(want))
}
// testMulLeftIdentity checks that one is a left identity for Mul:
// 1 * a == a.
func testMulLeftIdentity(t *rapid.T) {
	operand := genDec.Draw(t, "a")
	identity := NewDecFromInt64(1)
	product, err := identity.Mul(operand)
	require.NoError(t, err)
	require.True(t, operand.Equal(product))
}
// testMulRightIdentity checks that one is a right identity for Mul:
// a * 1 == a.
func testMulRightIdentity(t *rapid.T) {
	operand := genDec.Draw(t, "a")
	identity := NewDecFromInt64(1)
	product, err := operand.Mul(identity)
	require.NoError(t, err)
	require.True(t, operand.Equal(product))
}
// testMulCommutative checks that Mul is commutative: a * b == b * a.
func testMulCommutative(t *rapid.T) {
	lhs := genDec.Draw(t, "a")
	rhs := genDec.Draw(t, "b")

	forward, err := lhs.Mul(rhs)
	require.NoError(t, err)

	swapped, err := rhs.Mul(lhs)
	require.NoError(t, err)

	require.True(t, forward.Equal(swapped))
}
// testMulAssociative checks that Mul is associative:
// (a * b) * c == a * (b * c).
func testMulAssociative(t *rapid.T) {
	decA := genDec.Draw(t, "a")
	decB := genDec.Draw(t, "b")
	decC := genDec.Draw(t, "c")

	// Left grouping: (a * b) * c
	prodAB, err := decA.Mul(decB)
	require.NoError(t, err)
	leftGrouped, err := prodAB.Mul(decC)
	require.NoError(t, err)

	// Right grouping: a * (b * c)
	prodBC, err := decB.Mul(decC)
	require.NoError(t, err)
	rightGrouped, err := decA.Mul(prodBC)
	require.NoError(t, err)

	require.True(t, leftGrouped.Equal(rightGrouped))
}
// testSubAdd checks that adding back what was subtracted restores the
// original value: (a - b) + b == a.
func testSubAdd(t *rapid.T) {
	start := genDec.Draw(t, "a")
	delta := genDec.Draw(t, "b")

	reduced, err := start.Sub(delta)
	require.NoError(t, err)

	restored, err := reduced.Add(delta)
	require.NoError(t, err)

	require.True(t, start.Equal(restored))
}
// testAddSub checks that subtracting what was added restores the original
// value: (a + b) - b == a.
func testAddSub(t *rapid.T) {
	start := genDec.Draw(t, "a")
	delta := genDec.Draw(t, "b")

	increased, err := start.Add(delta)
	require.NoError(t, err)

	restored, err := increased.Sub(delta)
	require.NoError(t, err)

	require.True(t, start.Equal(restored))
}
// testMulZero checks that multiplying by the zero-value Dec yields a result
// for which IsZero reports true: a * 0 = 0.
func testMulZero(t *rapid.T) {
	operand := genDec.Draw(t, "a")
	var zeroValue Dec
	product, err := operand.Mul(zeroValue)
	require.NoError(t, err)
	require.True(t, product.IsZero())
}
// testSelfQuo checks that a non-zero value divided by itself is one:
// a/a = 1.
func testSelfQuo(t *rapid.T) {
	// Division by zero is excluded by filtering the generator.
	nonZero := genDec.Filter(func(d Dec) bool { return !d.IsZero() })
	operand := nonZero.Draw(t, "a")
	want := NewDecFromInt64(1)
	quotient, err := operand.Quo(operand)
	require.NoError(t, err)
	require.True(t, want.Equal(quotient))
}
// testQuoByOne checks that dividing by one leaves a value unchanged:
// a/1 = a.
func testQuoByOne(t *rapid.T) {
	operand := genDec.Draw(t, "a")
	divisor := NewDecFromInt64(1)
	quotient, err := operand.Quo(divisor)
	require.NoError(t, err)
	require.True(t, operand.Equal(quotient))
}
// testMulQuoA checks that dividing a product by its (non-zero) left factor
// yields the right factor: (a * b) / a == b.
func testMulQuoA(t *rapid.T) {
	// The divisor "a" must not be zero.
	nonZero := genDec.Filter(func(d Dec) bool { return !d.IsZero() })
	left := nonZero.Draw(t, "a")
	right := genDec.Draw(t, "b")

	product, err := left.Mul(right)
	require.NoError(t, err)

	quotient, err := product.Quo(left)
	require.NoError(t, err)

	require.True(t, right.Equal(quotient))
}
// testMulQuoB checks that dividing a product by its (non-zero) right factor
// yields the left factor: (a * b) / b == a.
func testMulQuoB(t *rapid.T) {
	left := genDec.Draw(t, "a")
	// The divisor "b" must not be zero.
	nonZero := genDec.Filter(func(d Dec) bool { return !d.IsZero() })
	right := nonZero.Draw(t, "b")

	product, err := left.Mul(right)
	require.NoError(t, err)

	quotient, err := product.Quo(right)
	require.NoError(t, err)

	require.True(t, left.Equal(quotient))
}
// testMulQuoExact checks (a * 10^b) / 10^b == a using MulExact and QuoExact,
// with b drawn from [0, 32] and a restricted to values having no more than b
// decimal places.
func testMulQuoExact(t *rapid.T) {
	exponent := rapid.Uint32Range(0, 32).Draw(t, "b")

	// Only keep values whose decimal places fit within the drawn exponent.
	fitsExponent := func(d Dec) bool { return d.NumDecimalPlaces() <= exponent }
	operand := genDec.Filter(fitsExponent).Draw(t, "a")

	scale := NewDecWithExp(1, int32(exponent))

	scaled, err := operand.MulExact(scale)
	require.NoError(t, err)

	restored, err := scaled.QuoExact(scale)
	require.NoError(t, err)

	require.True(t, operand.Equal(restored))
}
// testQuoMulExact checks (a / 10^b) * 10^b == a using QuoExact and MulExact,
// with a drawn as an unsigned 64-bit integer and b drawn from [0, 32].
func testQuoMulExact(t *rapid.T) {
	a := rapid.Uint64().Draw(t, "a")
	// Parse from the decimal string form: a uint64 may exceed the int64 range
	// accepted by NewDecFromInt64.
	aDec, err := NewDecFromString(strconv.FormatUint(a, 10))
	require.NoError(t, err)

	b := rapid.Uint32Range(0, 32).Draw(t, "b")
	// NewDecWithExp returns no error, so there is nothing to check here.
	c := NewDecWithExp(1, int32(b))

	d, err := aDec.QuoExact(c)
	require.NoError(t, err)

	e, err := d.MulExact(c)
	require.NoError(t, err)

	require.True(t, aDec.Equal(e))
}
// testCmpInverse checks that swapping the operands of Cmp negates the
// result: Cmp(a, b) == -Cmp(b, a).
func testCmpInverse(t *rapid.T) {
	lhs := genDec.Draw(t, "a")
	rhs := genDec.Draw(t, "b")
	forward := lhs.Cmp(rhs)
	reversed := rhs.Cmp(lhs)
	require.Equal(t, forward, -reversed)
}
// testEqualCommutative checks that Equal does not depend on operand order:
// Equal(a, b) == Equal(b, a).
func testEqualCommutative(t *rapid.T) {
	lhs := genDec.Draw(t, "a")
	rhs := genDec.Draw(t, "b")
	forward := lhs.Equal(rhs)
	reversed := rhs.Equal(lhs)
	require.Equal(t, forward, reversed)
}
// testIsZero checks that a Dec created from a float reports IsZero exactly
// when the float equals 0.
func testIsZero(t *rapid.T) {
	pair := genFloatAndDec.Draw(t, "floatAndDec")
	want := pair.float == 0
	require.Equal(t, want, pair.dec.IsZero())
}
// testIsNegative checks that a Dec created from a float reports IsNegative
// exactly when the float is less than 0.
func testIsNegative(t *rapid.T) {
	pair := genFloatAndDec.Draw(t, "floatAndDec")
	want := pair.float < 0
	require.Equal(t, want, pair.dec.IsNegative())
}
// testIsPositive checks that a Dec created from a float reports IsPositive
// exactly when the float is greater than 0.
func testIsPositive(t *rapid.T) {
	pair := genFloatAndDec.Draw(t, "floatAndDec")
	want := pair.float > 0
	require.Equal(t, want, pair.dec.IsPositive())
}
// testNumDecimalPlaces checks that NumDecimalPlaces of a Dec created from a
// float matches the count floatDecimalPlaces derives from that float.
func testNumDecimalPlaces(t *rapid.T) {
	pair := genFloatAndDec.Draw(t, "floatAndDec")
	want := floatDecimalPlaces(t, pair.float)
	got := pair.dec.NumDecimalPlaces()
	require.Equal(t, want, got)
}
// reScientificFloat matches the "%g" rendering of a float: an optional sign,
// a mantissa with optional fractional digits (capture groups 1 and 2), and an
// optional exponent that is either non-negative (group 3) or negative
// (group 4). It is compiled once at package scope rather than on every call.
var reScientificFloat = regexp.MustCompile(`^\-?(?:[[:digit:]]+(?:\.([[:digit:]]+))?|\.([[:digit:]]+))(?:e?(?:\+?([[:digit:]]+)|(-[[:digit:]]+)))?$`)

// floatDecimalPlaces returns the number of decimal places in the "%g"
// rendering of f: the count of fractional digits in the mantissa minus the
// exponent, clamped at zero. It fails the test if the rendering does not match
// the expected plain or scientific notation.
func floatDecimalPlaces(t *rapid.T, f float64) uint32 {
	fStr := fmt.Sprintf("%g", f)
	groups := reScientificFloat.FindStringSubmatch(fStr)
	if groups == nil {
		t.Fatalf("Didn't match float: %g", f)
	}

	// basePlaces is the number of decimal places in the decimal part of the
	// string; the fractional digits are in group 1 or group 2 depending on
	// which mantissa alternative matched.
	basePlaces := 0
	switch {
	case groups[1] != "":
		basePlaces = len(groups[1])
	case groups[2] != "":
		basePlaces = len(groups[2])
	}
	t.Logf("Base places: %d", basePlaces)

	// exp is the exponent: group 3 holds a non-negative one, group 4 a
	// negative one (including its sign). It stays 0 when neither is present.
	expStr := groups[3]
	if expStr == "" {
		expStr = groups[4]
	}
	exp := 0
	if expStr != "" {
		var err error
		exp, err = strconv.Atoi(expStr)
		require.NoError(t, err)
	}

	// A positive exponent shifts the decimal point right and removes decimal
	// places, while a negative one adds them; the count never goes below zero.
	res := basePlaces - exp
	if res <= 0 {
		return 0
	}
	return uint32(res)
}