diff --git a/blockstore/autobatch.go b/blockstore/autobatch.go index cd3991246..e9df6a3b3 100644 --- a/blockstore/autobatch.go +++ b/blockstore/autobatch.go @@ -25,12 +25,13 @@ type AutobatchBlockstore struct { // TODO: drop if memory consumption is too high addedCids map[cid.Cid]struct{} - lock sync.Mutex + stateLock sync.Mutex + doFlushLock sync.Mutex bufferedBatch blockBatch - // the flush worker has sole control (including read) over the flushingBatch.blockList and flushErr until shutdown - flushingBatch blockBatch - flushErr error + flushingBatch blockBatch + flushErr error + flushWorkerDone bool flushCh chan struct{} @@ -51,12 +52,13 @@ func NewAutobatch(ctx context.Context, backingBs Blockstore, bufferCapacity int) bufferCapacity: bufferCapacity, flushCtx: ctx, flushCh: make(chan struct{}, 1), + shutdownCh: make(chan struct{}), // could be made configable flushRetryDelay: time.Millisecond * 100, + flushWorkerDone: false, } bs.bufferedBatch.blockMap = make(map[cid.Cid]block.Block) - bs.flushingBatch.blockMap = make(map[cid.Cid]block.Block) go bs.flushWorker() @@ -64,8 +66,8 @@ func NewAutobatch(ctx context.Context, backingBs Blockstore, bufferCapacity int) } func (bs *AutobatchBlockstore) Put(ctx context.Context, blk block.Block) error { - bs.lock.Lock() - defer bs.lock.Unlock() + bs.stateLock.Lock() + defer bs.stateLock.Unlock() _, ok := bs.addedCids[blk.Cid()] if !ok { @@ -87,6 +89,11 @@ func (bs *AutobatchBlockstore) Put(ctx context.Context, blk block.Block) error { } func (bs *AutobatchBlockstore) flushWorker() { + defer func() { + bs.stateLock.Lock() + bs.flushWorkerDone = true + bs.stateLock.Unlock() + }() for { select { case <-bs.flushCh: @@ -106,34 +113,47 @@ func (bs *AutobatchBlockstore) flushWorker() { } } -// caller must NOT hold lock +// caller must NOT hold stateLock func (bs *AutobatchBlockstore) doFlush(ctx context.Context) error { + bs.doFlushLock.Lock() + defer bs.doFlushLock.Unlock() if bs.flushErr == nil { - bs.lock.Lock() + bs.stateLock.Lock() // We do NOT clear addedCids here, because its purpose is to expedite Puts bs.flushingBatch = bs.bufferedBatch bs.bufferedBatch.blockList = make([]block.Block, 0, len(bs.flushingBatch.blockList)) bs.bufferedBatch.blockMap = make(map[cid.Cid]block.Block, len(bs.flushingBatch.blockMap)) - bs.lock.Unlock() + bs.stateLock.Unlock() } bs.flushErr = bs.backingBs.PutMany(ctx, bs.flushingBatch.blockList) + bs.stateLock.Lock() + bs.flushingBatch = blockBatch{} + bs.stateLock.Unlock() + return bs.flushErr } -// caller must NOT hold lock +// caller must NOT hold stateLock func (bs *AutobatchBlockstore) Flush(ctx context.Context) error { return bs.doFlush(ctx) } func (bs *AutobatchBlockstore) Shutdown(ctx context.Context) error { - // shutdown the flush worker - bs.shutdownCh <- struct{}{} + bs.stateLock.Lock() + flushDone := bs.flushWorkerDone + bs.stateLock.Unlock() + if !flushDone { + // may racily block forever if Shutdown is called in parallel + bs.shutdownCh <- struct{}{} + } return bs.flushErr } func (bs *AutobatchBlockstore) Get(ctx context.Context, c cid.Cid) (block.Block, error) { + bs.stateLock.Lock() + defer bs.stateLock.Unlock() // may seem backward to check the backingBs first, but that is the likeliest case blk, err := bs.backingBs.Get(ctx, c) if err == nil { @@ -144,20 +164,17 @@ func (bs *AutobatchBlockstore) Get(ctx context.Context, c cid.Cid) (block.Block, return blk, err } - bs.lock.Lock() - defer bs.lock.Unlock() v, ok := bs.flushingBatch.blockMap[c] if ok { return v, nil } - v, ok = bs.flushingBatch.blockMap[c] + v, ok = bs.bufferedBatch.blockMap[c] if ok { return v, nil } - // check the backingBs in case it just got put in the backingBs (and removed from the batch maps) while we were here - return bs.backingBs.Get(ctx, c) + return nil, ErrNotFound } func (bs *AutobatchBlockstore) DeleteBlock(context.Context, cid.Cid) error { @@ -218,9 +235,10 @@ func (bs *AutobatchBlockstore) HashOnRead(enabled bool) { } func (bs *AutobatchBlockstore) View(ctx context.Context, cid cid.Cid, callback func([]byte) error) error { - if err := bs.Flush(ctx); err != nil { + blk, err := bs.Get(ctx, cid) + if err != nil { return err } - return bs.backingBs.View(ctx, cid, callback) + return callback(blk.RawData()) } diff --git a/blockstore/autobatch_test.go b/blockstore/autobatch_test.go index fe52c55c7..57a3b7d6c 100644 --- a/blockstore/autobatch_test.go +++ b/blockstore/autobatch_test.go @@ -17,8 +17,6 @@ func TestAutobatchBlockstore(t *testing.T) { require.NoError(t, ab.Put(ctx, b1)) require.NoError(t, ab.Put(ctx, b2)) - ab.Flush(ctx) - v0, err := ab.Get(ctx, b0.Cid()) require.NoError(t, err) require.Equal(t, b0.RawData(), v0.RawData()) @@ -30,4 +28,7 @@ func TestAutobatchBlockstore(t *testing.T) { v2, err := ab.Get(ctx, b2.Cid()) require.NoError(t, err) require.Equal(t, b2.RawData(), v2.RawData()) + + require.NoError(t, ab.Flush(ctx)) + require.NoError(t, ab.Shutdown(ctx)) }