lotus/chain/types/bitfield.go

146 lines
2.5 KiB
Go
Raw Normal View History

package types
import (
"fmt"
"io"
rlepluslazy "github.com/filecoin-project/lotus/lib/rlepluslazy"
cbg "github.com/whyrusleeping/cbor-gen"
"golang.org/x/xerrors"
)
type BitField struct {
rle rlepluslazy.RLE
bits map[uint64]struct{}
}
func NewBitField() BitField {
rle, err := rlepluslazy.FromBuf([]byte{})
if err != nil {
panic(err)
}
return BitField{
rle: rle,
bits: make(map[uint64]struct{}),
}
}
2019-09-19 16:17:49 +00:00
func BitFieldFromSet(setBits []uint64) BitField {
res := BitField{bits: make(map[uint64]struct{})}
for _, b := range setBits {
res.bits[b] = struct{}{}
}
return res
}
func (bf BitField) sum() (rlepluslazy.RunIterator, error) {
if len(bf.bits) == 0 {
return bf.rle.RunIterator()
}
a, err := bf.rle.RunIterator()
if err != nil {
return nil, err
}
slc := make([]uint64, 0, len(bf.bits))
for b := range bf.bits {
slc = append(slc, b)
}
b, err := rlepluslazy.RunsFromSlice(slc)
if err != nil {
return nil, err
}
res, err := rlepluslazy.Sum(a, b)
if err != nil {
return nil, err
}
return res, nil
}
// Set ...s bit in the BitField
func (bf BitField) Set(bit uint64) {
bf.bits[bit] = struct{}{}
}
func (bf BitField) Count() (uint64, error) {
s, err := bf.sum()
if err != nil {
return 0, err
}
return rlepluslazy.Count(s)
}
// All returns all set bits
func (bf BitField) All() ([]uint64, error) {
runs, err := bf.sum()
if err != nil {
return nil, err
}
res, err := rlepluslazy.SliceFromRuns(runs)
if err != nil {
return nil, err
}
return res, err
}
func (bf BitField) MarshalCBOR(w io.Writer) error {
ints := make([]uint64, 0, len(bf.bits))
for i := range bf.bits {
ints = append(ints, i)
}
s, err := bf.sum()
if err != nil {
return err
}
rle, err := rlepluslazy.EncodeRuns(s, []byte{})
if err != nil {
return err
}
if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajByteString, uint64(len(rle)))); err != nil {
return err
}
if _, err = w.Write(rle); err != nil {
return xerrors.Errorf("writing rle: %w", err)
}
return nil
}
func (bf *BitField) UnmarshalCBOR(r io.Reader) error {
br := cbg.GetPeeker(r)
maj, extra, err := cbg.CborReadHeader(br)
if err != nil {
return err
}
if extra > 8192 {
return fmt.Errorf("array too large")
}
if maj != cbg.MajByteString {
return fmt.Errorf("expected byte array")
}
buf := make([]byte, extra)
if _, err := io.ReadFull(br, buf); err != nil {
return err
}
rle, err := rlepluslazy.FromBuf(buf)
if err != nil {
return xerrors.Errorf("could not decode rle+: %w", err)
}
bf.rle = rle
bf.bits = make(map[uint64]struct{})
return nil
}