2019-09-17 14:20:08 +00:00
|
|
|
package types
|
|
|
|
|
|
|
|
import (
|
2020-01-06 19:53:03 +00:00
|
|
|
"errors"
|
2019-09-17 14:20:08 +00:00
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
|
2019-12-07 14:47:36 +00:00
|
|
|
rlepluslazy "github.com/filecoin-project/lotus/lib/rlepluslazy"
|
2019-09-17 14:20:08 +00:00
|
|
|
cbg "github.com/whyrusleeping/cbor-gen"
|
|
|
|
"golang.org/x/xerrors"
|
|
|
|
)
|
|
|
|
|
2020-01-06 19:53:03 +00:00
|
|
|
// ErrBitFieldTooMany is returned (wrapped) by All and AllMap when the
// bitfield holds more set bits than the caller-supplied maximum.
// Match it with errors.Is rather than by comparing message text.
var ErrBitFieldTooMany = errors.New("too many items in RLE")
|
|
|
|
|
2019-09-17 14:20:08 +00:00
|
|
|
type BitField struct {
|
2019-12-07 14:47:36 +00:00
|
|
|
rle rlepluslazy.RLE
|
|
|
|
|
2019-09-17 14:20:08 +00:00
|
|
|
bits map[uint64]struct{}
|
|
|
|
}
|
|
|
|
|
2019-09-18 15:10:03 +00:00
|
|
|
func NewBitField() BitField {
|
2019-12-07 15:19:54 +00:00
|
|
|
rle, err := rlepluslazy.FromBuf([]byte{})
|
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
2019-12-07 14:47:36 +00:00
|
|
|
return BitField{
|
|
|
|
rle: rle,
|
|
|
|
bits: make(map[uint64]struct{}),
|
|
|
|
}
|
2019-09-18 15:10:03 +00:00
|
|
|
}
|
|
|
|
|
2019-09-19 16:17:49 +00:00
|
|
|
func BitFieldFromSet(setBits []uint64) BitField {
|
|
|
|
res := BitField{bits: make(map[uint64]struct{})}
|
|
|
|
for _, b := range setBits {
|
|
|
|
res.bits[b] = struct{}{}
|
|
|
|
}
|
|
|
|
return res
|
|
|
|
}
|
|
|
|
|
2019-12-06 14:06:42 +00:00
|
|
|
func MergeBitFields(a, b BitField) (BitField, error) {
|
|
|
|
ra, err := a.rle.RunIterator()
|
|
|
|
if err != nil {
|
|
|
|
return BitField{}, err
|
|
|
|
}
|
|
|
|
|
|
|
|
rb, err := b.rle.RunIterator()
|
|
|
|
if err != nil {
|
|
|
|
return BitField{}, err
|
|
|
|
}
|
|
|
|
|
|
|
|
merge, err := rlepluslazy.Sum(ra, rb)
|
|
|
|
if err != nil {
|
|
|
|
return BitField{}, err
|
|
|
|
}
|
|
|
|
|
|
|
|
mergebytes, err := rlepluslazy.EncodeRuns(merge, nil)
|
|
|
|
if err != nil {
|
|
|
|
return BitField{}, err
|
|
|
|
}
|
|
|
|
|
|
|
|
rle, err := rlepluslazy.FromBuf(mergebytes)
|
|
|
|
if err != nil {
|
|
|
|
return BitField{}, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return BitField{
|
|
|
|
rle: rle,
|
|
|
|
bits: make(map[uint64]struct{}),
|
|
|
|
}, nil
|
|
|
|
}
|
|
|
|
|
2019-12-07 14:47:36 +00:00
|
|
|
func (bf BitField) sum() (rlepluslazy.RunIterator, error) {
|
|
|
|
if len(bf.bits) == 0 {
|
|
|
|
return bf.rle.RunIterator()
|
|
|
|
}
|
|
|
|
|
|
|
|
a, err := bf.rle.RunIterator()
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
slc := make([]uint64, 0, len(bf.bits))
|
|
|
|
for b := range bf.bits {
|
|
|
|
slc = append(slc, b)
|
|
|
|
}
|
|
|
|
|
|
|
|
b, err := rlepluslazy.RunsFromSlice(slc)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
res, err := rlepluslazy.Sum(a, b)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return res, nil
|
|
|
|
}
|
|
|
|
|
2019-09-17 14:20:08 +00:00
|
|
|
// Set ...s bit in the BitField
|
|
|
|
func (bf BitField) Set(bit uint64) {
|
|
|
|
bf.bits[bit] = struct{}{}
|
|
|
|
}
|
|
|
|
|
2019-12-07 14:47:36 +00:00
|
|
|
func (bf BitField) Count() (uint64, error) {
|
|
|
|
s, err := bf.sum()
|
|
|
|
if err != nil {
|
|
|
|
return 0, err
|
|
|
|
}
|
|
|
|
return rlepluslazy.Count(s)
|
2019-09-17 14:20:08 +00:00
|
|
|
}
|
|
|
|
|
2019-12-07 15:19:54 +00:00
|
|
|
// All returns all set bits
|
2020-01-06 19:53:03 +00:00
|
|
|
func (bf BitField) All(max uint64) ([]uint64, error) {
|
|
|
|
c, err := bf.Count()
|
|
|
|
if err != nil {
|
|
|
|
return nil, xerrors.Errorf("count errror: %w", err)
|
|
|
|
}
|
|
|
|
if c > max {
|
|
|
|
return nil, xerrors.Errorf("expected %d, got %d: %w", max, c, ErrBitFieldTooMany)
|
|
|
|
}
|
2019-12-07 14:47:36 +00:00
|
|
|
|
|
|
|
runs, err := bf.sum()
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
2019-09-18 15:10:03 +00:00
|
|
|
}
|
2019-10-04 18:18:11 +00:00
|
|
|
|
2019-12-07 14:47:36 +00:00
|
|
|
res, err := rlepluslazy.SliceFromRuns(runs)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2019-12-06 14:06:42 +00:00
|
|
|
return res, nil
|
|
|
|
}
|
|
|
|
|
2020-01-06 19:53:03 +00:00
|
|
|
func (bf BitField) AllMap(max uint64) (map[uint64]bool, error) {
|
|
|
|
c, err := bf.Count()
|
|
|
|
if err != nil {
|
|
|
|
return nil, xerrors.Errorf("count errror: %w", err)
|
|
|
|
}
|
|
|
|
if c > max {
|
|
|
|
return nil, xerrors.Errorf("expected %d, got %d: %w", max, c, ErrBitFieldTooMany)
|
|
|
|
}
|
2019-12-06 14:06:42 +00:00
|
|
|
|
|
|
|
runs, err := bf.sum()
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
res, err := rlepluslazy.SliceFromRuns(runs)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
out := make(map[uint64]bool)
|
|
|
|
for _, i := range res {
|
|
|
|
out[i] = true
|
|
|
|
}
|
|
|
|
return out, nil
|
2019-09-18 15:10:03 +00:00
|
|
|
}
|
|
|
|
|
2019-09-17 14:20:08 +00:00
|
|
|
func (bf BitField) MarshalCBOR(w io.Writer) error {
|
|
|
|
ints := make([]uint64, 0, len(bf.bits))
|
|
|
|
for i := range bf.bits {
|
|
|
|
ints = append(ints, i)
|
|
|
|
}
|
|
|
|
|
2019-12-07 14:47:36 +00:00
|
|
|
s, err := bf.sum()
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
rle, err := rlepluslazy.EncodeRuns(s, []byte{})
|
2019-09-17 14:20:08 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2019-12-16 23:44:14 +00:00
|
|
|
if len(rle) > 8192 {
|
|
|
|
return xerrors.Errorf("encoded bitfield was too large (%d)", len(rle))
|
|
|
|
}
|
|
|
|
|
2019-09-17 14:20:08 +00:00
|
|
|
if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajByteString, uint64(len(rle)))); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if _, err = w.Write(rle); err != nil {
|
|
|
|
return xerrors.Errorf("writing rle: %w", err)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (bf *BitField) UnmarshalCBOR(r io.Reader) error {
|
|
|
|
br := cbg.GetPeeker(r)
|
|
|
|
|
|
|
|
maj, extra, err := cbg.CborReadHeader(br)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if extra > 8192 {
|
|
|
|
return fmt.Errorf("array too large")
|
|
|
|
}
|
|
|
|
|
|
|
|
if maj != cbg.MajByteString {
|
|
|
|
return fmt.Errorf("expected byte array")
|
|
|
|
}
|
|
|
|
|
2019-12-07 14:47:36 +00:00
|
|
|
buf := make([]byte, extra)
|
|
|
|
if _, err := io.ReadFull(br, buf); err != nil {
|
2019-09-17 14:20:08 +00:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2019-12-07 14:47:36 +00:00
|
|
|
rle, err := rlepluslazy.FromBuf(buf)
|
2019-09-17 14:20:08 +00:00
|
|
|
if err != nil {
|
|
|
|
return xerrors.Errorf("could not decode rle+: %w", err)
|
|
|
|
}
|
2019-12-07 14:47:36 +00:00
|
|
|
bf.rle = rle
|
2019-09-17 14:20:08 +00:00
|
|
|
bf.bits = make(map[uint64]struct{})
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|