Address review part 2

This commit is contained in:
Aayush Rajasekaran 2022-01-12 15:03:34 -05:00
parent 083c5b003c
commit 893998cb70
2 changed files with 41 additions and 22 deletions

View File

@ -25,12 +25,13 @@ type AutobatchBlockstore struct {
// TODO: drop if memory consumption is too high // TODO: drop if memory consumption is too high
addedCids map[cid.Cid]struct{} addedCids map[cid.Cid]struct{}
lock sync.Mutex stateLock sync.Mutex
doFlushLock sync.Mutex
bufferedBatch blockBatch bufferedBatch blockBatch
// the flush worker has sole control (including read) over the flushingBatch.blockList and flushErr until shutdown
flushingBatch blockBatch flushingBatch blockBatch
flushErr error flushErr error
flushWorkerDone bool
flushCh chan struct{} flushCh chan struct{}
@ -51,12 +52,13 @@ func NewAutobatch(ctx context.Context, backingBs Blockstore, bufferCapacity int)
bufferCapacity: bufferCapacity, bufferCapacity: bufferCapacity,
flushCtx: ctx, flushCtx: ctx,
flushCh: make(chan struct{}, 1), flushCh: make(chan struct{}, 1),
shutdownCh: make(chan struct{}),
// could be made configable // could be made configable
flushRetryDelay: time.Millisecond * 100, flushRetryDelay: time.Millisecond * 100,
flushWorkerDone: false,
} }
bs.bufferedBatch.blockMap = make(map[cid.Cid]block.Block) bs.bufferedBatch.blockMap = make(map[cid.Cid]block.Block)
bs.flushingBatch.blockMap = make(map[cid.Cid]block.Block)
go bs.flushWorker() go bs.flushWorker()
@ -64,8 +66,8 @@ func NewAutobatch(ctx context.Context, backingBs Blockstore, bufferCapacity int)
} }
func (bs *AutobatchBlockstore) Put(ctx context.Context, blk block.Block) error { func (bs *AutobatchBlockstore) Put(ctx context.Context, blk block.Block) error {
bs.lock.Lock() bs.stateLock.Lock()
defer bs.lock.Unlock() defer bs.stateLock.Unlock()
_, ok := bs.addedCids[blk.Cid()] _, ok := bs.addedCids[blk.Cid()]
if !ok { if !ok {
@ -87,6 +89,11 @@ func (bs *AutobatchBlockstore) Put(ctx context.Context, blk block.Block) error {
} }
func (bs *AutobatchBlockstore) flushWorker() { func (bs *AutobatchBlockstore) flushWorker() {
defer func() {
bs.stateLock.Lock()
bs.flushWorkerDone = true
bs.stateLock.Unlock()
}()
for { for {
select { select {
case <-bs.flushCh: case <-bs.flushCh:
@ -106,34 +113,47 @@ func (bs *AutobatchBlockstore) flushWorker() {
} }
} }
// caller must NOT hold lock // caller must NOT hold stateLock
func (bs *AutobatchBlockstore) doFlush(ctx context.Context) error { func (bs *AutobatchBlockstore) doFlush(ctx context.Context) error {
bs.doFlushLock.Lock()
defer bs.doFlushLock.Unlock()
if bs.flushErr == nil { if bs.flushErr == nil {
bs.lock.Lock() bs.stateLock.Lock()
// We do NOT clear addedCids here, because its purpose is to expedite Puts // We do NOT clear addedCids here, because its purpose is to expedite Puts
bs.flushingBatch = bs.bufferedBatch bs.flushingBatch = bs.bufferedBatch
bs.bufferedBatch.blockList = make([]block.Block, 0, len(bs.flushingBatch.blockList)) bs.bufferedBatch.blockList = make([]block.Block, 0, len(bs.flushingBatch.blockList))
bs.bufferedBatch.blockMap = make(map[cid.Cid]block.Block, len(bs.flushingBatch.blockMap)) bs.bufferedBatch.blockMap = make(map[cid.Cid]block.Block, len(bs.flushingBatch.blockMap))
bs.lock.Unlock() bs.stateLock.Unlock()
} }
bs.flushErr = bs.backingBs.PutMany(ctx, bs.flushingBatch.blockList) bs.flushErr = bs.backingBs.PutMany(ctx, bs.flushingBatch.blockList)
bs.stateLock.Lock()
bs.flushingBatch = blockBatch{}
bs.stateLock.Unlock()
return bs.flushErr return bs.flushErr
} }
// caller must NOT hold lock // caller must NOT hold stateLock
func (bs *AutobatchBlockstore) Flush(ctx context.Context) error { func (bs *AutobatchBlockstore) Flush(ctx context.Context) error {
return bs.doFlush(ctx) return bs.doFlush(ctx)
} }
func (bs *AutobatchBlockstore) Shutdown(ctx context.Context) error { func (bs *AutobatchBlockstore) Shutdown(ctx context.Context) error {
// shutdown the flush worker bs.stateLock.Lock()
flushDone := bs.flushWorkerDone
bs.stateLock.Unlock()
if !flushDone {
// may racily block forever if Shutdown is called in parallel
bs.shutdownCh <- struct{}{} bs.shutdownCh <- struct{}{}
}
return bs.flushErr return bs.flushErr
} }
func (bs *AutobatchBlockstore) Get(ctx context.Context, c cid.Cid) (block.Block, error) { func (bs *AutobatchBlockstore) Get(ctx context.Context, c cid.Cid) (block.Block, error) {
bs.stateLock.Lock()
defer bs.stateLock.Unlock()
// may seem backward to check the backingBs first, but that is the likeliest case // may seem backward to check the backingBs first, but that is the likeliest case
blk, err := bs.backingBs.Get(ctx, c) blk, err := bs.backingBs.Get(ctx, c)
if err == nil { if err == nil {
@ -144,20 +164,17 @@ func (bs *AutobatchBlockstore) Get(ctx context.Context, c cid.Cid) (block.Block,
return blk, err return blk, err
} }
bs.lock.Lock()
defer bs.lock.Unlock()
v, ok := bs.flushingBatch.blockMap[c] v, ok := bs.flushingBatch.blockMap[c]
if ok { if ok {
return v, nil return v, nil
} }
v, ok = bs.flushingBatch.blockMap[c] v, ok = bs.bufferedBatch.blockMap[c]
if ok { if ok {
return v, nil return v, nil
} }
// check the backingBs in case it just got put in the backingBs (and removed from the batch maps) while we were here return nil, ErrNotFound
return bs.backingBs.Get(ctx, c)
} }
func (bs *AutobatchBlockstore) DeleteBlock(context.Context, cid.Cid) error { func (bs *AutobatchBlockstore) DeleteBlock(context.Context, cid.Cid) error {
@ -218,9 +235,10 @@ func (bs *AutobatchBlockstore) HashOnRead(enabled bool) {
} }
func (bs *AutobatchBlockstore) View(ctx context.Context, cid cid.Cid, callback func([]byte) error) error { func (bs *AutobatchBlockstore) View(ctx context.Context, cid cid.Cid, callback func([]byte) error) error {
if err := bs.Flush(ctx); err != nil { blk, err := bs.Get(ctx, cid)
if err != nil {
return err return err
} }
return bs.backingBs.View(ctx, cid, callback) return callback(blk.RawData())
} }

View File

@ -17,8 +17,6 @@ func TestAutobatchBlockstore(t *testing.T) {
require.NoError(t, ab.Put(ctx, b1)) require.NoError(t, ab.Put(ctx, b1))
require.NoError(t, ab.Put(ctx, b2)) require.NoError(t, ab.Put(ctx, b2))
ab.Flush(ctx)
v0, err := ab.Get(ctx, b0.Cid()) v0, err := ab.Get(ctx, b0.Cid())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, b0.RawData(), v0.RawData()) require.Equal(t, b0.RawData(), v0.RawData())
@ -30,4 +28,7 @@ func TestAutobatchBlockstore(t *testing.T) {
v2, err := ab.Get(ctx, b2.Cid()) v2, err := ab.Get(ctx, b2.Cid())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, b2.RawData(), v2.RawData()) require.Equal(t, b2.RawData(), v2.RawData())
require.NoError(t, ab.Flush(ctx))
require.NoError(t, ab.Shutdown(ctx))
} }