diff --git a/chain/types/bigint.go b/chain/types/bigint.go index 6a0a875bb..5aeb22e33 100644 --- a/chain/types/bigint.go +++ b/chain/types/bigint.go @@ -7,9 +7,9 @@ import ( "math/big" "github.com/filecoin-project/go-lotus/build" - cbor "github.com/ipfs/go-ipld-cbor" "github.com/polydawn/refmt/obj/atlas" + cbg "github.com/whyrusleeping/cbor-gen" "golang.org/x/xerrors" ) @@ -17,18 +17,14 @@ import ( const BigIntMaxSerializedLen = 128 // is this big enough? or too big? func init() { - cbor.RegisterCborType(atlas.BuildEntry(BigInt{}).UseTag(2).Transform(). + cbor.RegisterCborType(atlas.BuildEntry(BigInt{}).Transform(). TransformMarshal(atlas.MakeMarshalTransformFunc( func(i BigInt) ([]byte, error) { - if i.Int == nil { - return []byte{}, nil - } - - return i.Bytes(), nil + return i.cborBytes(), nil })). TransformUnmarshal(atlas.MakeUnmarshalTransformFunc( func(x []byte) (BigInt, error) { - return BigFromBytes(x), nil + return fromCborBytes(x) })). Complete()) } @@ -117,30 +113,56 @@ func (bi *BigInt) UnmarshalJSON(b []byte) error { return nil } +func (bi *BigInt) cborBytes() []byte { + if bi.Int == nil { + return []byte{} + } + + switch { + case bi.Sign() == 0: + return []byte{} + case bi.Sign() > 0: + return append([]byte{0}, bi.Bytes()...) + case bi.Sign() < 0: + return append([]byte{1}, bi.Bytes()...) + } + + panic("unreachable") +} + +func fromCborBytes(buf []byte) (BigInt, error) { + var negative bool + switch buf[0] { + case 0: + negative = false + case 1: + negative = true + default: + return EmptyInt, fmt.Errorf("big int prefix should be either 0 or 1, got %d", buf[0]) + } + + i := big.NewInt(0).SetBytes(buf[1:]) + if negative { + i.Neg(i) + } + + return BigInt{i}, nil +} + func (bi *BigInt) MarshalCBOR(w io.Writer) error { if bi.Int == nil { zero := NewInt(0) return zero.MarshalCBOR(w) } - tag := uint64(2) - if bi.Sign() < 0 { - tag = 3 - } + enc := bi.cborBytes() - header := cbg.CborEncodeMajorType(cbg.MajTag, tag) + header := cbg.CborEncodeMajorType(cbg.MajByteString, uint64(len(enc))) if _, err := w.Write(header); err != nil { return err } - b := bi.Bytes() - - header = cbg.CborEncodeMajorType(cbg.MajByteString, uint64(len(b))) - if _, err := w.Write(header); err != nil { - return err - } - - if _, err := w.Write(b); err != nil { + if _, err := w.Write(enc); err != nil { return err } @@ -153,19 +175,13 @@ func (bi *BigInt) UnmarshalCBOR(br io.Reader) error { return err } - if maj != cbg.MajTag && extra != 2 && extra != 3 { - return fmt.Errorf("cbor input for big int was not a tagged big int") - } - - minus := extra & 1 - - maj, extra, err = cbg.CborReadHeader(br) - if err != nil { - return err - } - if maj != cbg.MajByteString { - return fmt.Errorf("cbor input for big int was not a tagged byte string") + return fmt.Errorf("cbor input for fil big int was not a byte string") + } + + if extra == 0 { + bi.Int = big.NewInt(0) + return nil } if extra > BigIntMaxSerializedLen { @@ -177,10 +193,12 @@ func (bi *BigInt) UnmarshalCBOR(br io.Reader) error { return err } - bi.Int = big.NewInt(0).SetBytes(buf) - if minus > 0 { - bi.Int.Neg(bi.Int) + i, err := fromCborBytes(buf) + if err != nil { + return err } + *bi = i + return nil }