cosmos-sdk/store/v2/db/memdb.go
package db
import (
"bytes"
"context"
"fmt"
"sync"
"github.com/google/btree"
corestore "cosmossdk.io/core/store"
"cosmossdk.io/store/v2/errors"
)
const (
// The approximate number of items and children per B-tree node. Tuned with benchmarks.
bTreeDegree = 32
)
// item is a btree.Item with byte slices as keys and values.
type item struct {
key []byte
value []byte
}
// Less implements btree.Item.
func (i item) Less(other btree.Item) bool {
// this considers nil == []byte{}, but that's ok since we handle nil endpoints
// in iterators specially anyway
return bytes.Compare(i.key, other.(item).key) == -1
}
// newKey creates a new key item.
func newKey(key []byte) item {
return item{key: key}
}
// newPair creates a new pair item.
func newPair(key, value []byte) item {
return item{key: key, value: value}
}
var _ corestore.KVStoreWithBatch = (*MemDB)(nil)
// MemDB is an in-memory database backend using a B-tree for storage.
//
// For performance reasons, keys and values passed in and returned are not copied; they alias the
// data held by the in-memory database, so modifying them will modify the stored values as well.
// All DB methods
// already specify that keys and values should be considered read-only, but this is especially
// important with MemDB.
type MemDB struct {
mtx sync.RWMutex
btree *btree.BTree
}
// NewMemDB creates a new in-memory database.
func NewMemDB() *MemDB {
database := &MemDB{
btree: btree.New(bTreeDegree),
}
return database
}
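// Example (illustrative sketch): basic usage. Get returns (nil, nil) for a missing key, and the
// returned slice aliases the stored value, so copy it before mutating.
//
//	db := NewMemDB()
//	if err := db.Set([]byte("key"), []byte("value")); err != nil {
//		// handle err
//	}
//	v, err := db.Get([]byte("key")) // v aliases the stored value; treat it as read-only
//	if err != nil {
//		// handle err
//	}
//	mine := append([]byte(nil), v...) // copy before modifying
//	_ = mine
//	_ = db.Delete([]byte("key"))
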
// Get implements DB.
func (db *MemDB) Get(key []byte) ([]byte, error) {
if len(key) == 0 {
return nil, errors.ErrKeyEmpty
}
db.mtx.RLock()
defer db.mtx.RUnlock()
i := db.btree.Get(newKey(key))
if i != nil {
return i.(item).value, nil
}
return nil, nil
}
// Has implements DB.
func (db *MemDB) Has(key []byte) (bool, error) {
if len(key) == 0 {
return false, errors.ErrKeyEmpty
}
db.mtx.RLock()
defer db.mtx.RUnlock()
return db.btree.Has(newKey(key)), nil
}
// Set implements DB.
func (db *MemDB) Set(key, value []byte) error {
if len(key) == 0 {
return errors.ErrKeyEmpty
}
if value == nil {
return errors.ErrValueNil
}
db.mtx.Lock()
defer db.mtx.Unlock()
db.set(key, value)
return nil
}
// set sets a value without locking the mutex.
func (db *MemDB) set(key, value []byte) {
db.btree.ReplaceOrInsert(newPair(key, value))
}
// SetSync implements DB.
func (db *MemDB) SetSync(key, value []byte) error {
return db.Set(key, value)
}
// Delete implements DB.
func (db *MemDB) Delete(key []byte) error {
if len(key) == 0 {
return errors.ErrKeyEmpty
}
db.mtx.Lock()
defer db.mtx.Unlock()
db.delete(key)
return nil
}
// delete deletes a key without locking the mutex.
func (db *MemDB) delete(key []byte) {
db.btree.Delete(newKey(key))
}
// DeleteSync implements DB.
func (db *MemDB) DeleteSync(key []byte) error {
return db.Delete(key)
}
// Close implements DB.
func (db *MemDB) Close() error {
// Close is a no-op: an in-memory database has no destination to flush its contents to, and we
// don't want any data loss on invoking Close().
// See the discussion in https://github.com/tendermint/tendermint/libs/pull/56
return nil
}
// Print implements DB.
func (db *MemDB) Print() error {
db.mtx.RLock()
defer db.mtx.RUnlock()
db.btree.Ascend(func(i btree.Item) bool {
item := i.(item)
fmt.Printf("[%X]:\t[%X]\n", item.key, item.value)
return true
})
return nil
}
// Stats implements DB.
func (db *MemDB) Stats() map[string]string {
db.mtx.RLock()
defer db.mtx.RUnlock()
stats := make(map[string]string)
stats["database.type"] = "memDB"
stats["database.size"] = fmt.Sprintf("%d", db.btree.Len())
return stats
}
// NewBatch implements DB.
func (db *MemDB) NewBatch() corestore.Batch {
return newMemDBBatch(db)
}
// NewBatchWithSize implements DB.
// It behaves identically to NewBatch because memDBBatch cannot be pre-allocated.
func (db *MemDB) NewBatchWithSize(size int) corestore.Batch {
return newMemDBBatch(db)
}
// Iterator implements DB.
// Takes out a read-lock on the database until the iterator is closed.
func (db *MemDB) Iterator(start, end []byte) (corestore.Iterator, error) {
if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) {
return nil, errors.ErrKeyEmpty
}
return newMemDBIterator(db, start, end, false), nil
}
// ReverseIterator implements DB.
// Takes out a read-lock on the database until the iterator is closed.
func (db *MemDB) ReverseIterator(start, end []byte) (corestore.Iterator, error) {
if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) {
return nil, errors.ErrKeyEmpty
}
return newMemDBIterator(db, start, end, true), nil
}
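// Example (illustrative sketch): per the contract above, Iterator and ReverseIterator take a
// read-lock that is held until the iterator is closed, so close iterators before issuing writes
// that need the write lock.
//
//	it, err := db.Iterator(nil, nil) // nil bounds iterate the whole key space
//	if err != nil {
//		// handle err
//	}
//	// ... consume the iterator ...
//	_ = it.Close()                       // releases the read-lock
//	_ = db.Set([]byte("k"), []byte("v")) // safe: no read-lock is held any more
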
// IteratorNoMtx makes an iterator with no mutex.
func (db *MemDB) IteratorNoMtx(start, end []byte) (corestore.Iterator, error) {
if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) {
return nil, errors.ErrKeyEmpty
}
return newMemDBIteratorMtxChoice(db, start, end, false, false), nil
}
// ReverseIteratorNoMtx makes an iterator with no mutex.
func (db *MemDB) ReverseIteratorNoMtx(start, end []byte) (corestore.Iterator, error) {
if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) {
return nil, errors.ErrKeyEmpty
}
return newMemDBIteratorMtxChoice(db, start, end, true, false), nil
}
const (
// Size of the channel buffer between traversal goroutine and iterator. Using an unbuffered
// channel causes two context switches per item sent, while buffering allows more work per
// context switch. Tuned with benchmarks.
chBufferSize = 64
)
// memDBIterator is a memDB iterator.
type memDBIterator struct {
ch <-chan *item
cancel context.CancelFunc
item *item
start []byte
end []byte
useMtx bool
}
var _ corestore.Iterator = (*memDBIterator)(nil)
// newMemDBIterator creates a new memDBIterator.
func newMemDBIterator(db *MemDB, start, end []byte, reverse bool) *memDBIterator {
return newMemDBIteratorMtxChoice(db, start, end, reverse, true)
}
func newMemDBIteratorMtxChoice(db *MemDB, start, end []byte, reverse, useMtx bool) *memDBIterator {
ctx, cancel := context.WithCancel(context.Background())
ch := make(chan *item, chBufferSize)
iter := &memDBIterator{
ch: ch,
cancel: cancel,
start: start,
end: end,
useMtx: useMtx,
}
if useMtx {
db.mtx.RLock()
}
go func() {
if useMtx {
defer db.mtx.RUnlock()
}
// Because we use [start, end) for reverse ranges, while btree uses (start, end], we need
// the following variables to handle some reverse iteration conditions ourselves.
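// For example, with keys {a, b, c}, a reverse iterator over [a, c) should yield b then a:
// DescendLessOrEqual(c) visits c, b, a, so we set skipEqual = c to drop the endpoint and
// abortLessThan = a to stop before any key below the start.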
var (
skipEqual []byte
abortLessThan []byte
)
visitor := func(i btree.Item) bool {
item := i.(item)
if skipEqual != nil && bytes.Equal(item.key, skipEqual) {
skipEqual = nil
return true
}
if abortLessThan != nil && bytes.Compare(item.key, abortLessThan) == -1 {
return false
}
select {
case <-ctx.Done():
return false
case ch <- &item:
return true
}
}
switch {
case start == nil && end == nil && !reverse:
db.btree.Ascend(visitor)
case start == nil && end == nil && reverse:
db.btree.Descend(visitor)
case end == nil && !reverse:
// must handle this specially, since nil is considered less than anything else
db.btree.AscendGreaterOrEqual(newKey(start), visitor)
case !reverse:
db.btree.AscendRange(newKey(start), newKey(end), visitor)
case end == nil:
// abort after start, since we use [start, end) while btree uses (start, end]
abortLessThan = start
db.btree.Descend(visitor)
default:
// skip end and abort after start, since we use [start, end) while btree uses (start, end]
skipEqual = end
abortLessThan = start
db.btree.DescendLessOrEqual(newKey(end), visitor)
}
close(ch)
}()
// prime the iterator with the first value, if any
if item, ok := <-ch; ok {
iter.item = item
}
return iter
}
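// The traversal goroutine above runs until it has visited the whole range or until ctx is
// canceled. Close cancels ctx and drains ch so that any blocked send completes and the goroutine
// can return, releasing the read-lock it holds when useMtx is set.
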
// Close implements Iterator.
func (i *memDBIterator) Close() error {
i.cancel()
for range i.ch { // drain channel
}
i.item = nil
return nil
}
// Domain implements Iterator.
func (i *memDBIterator) Domain() ([]byte, []byte) {
return i.start, i.end
}
// Valid implements Iterator.
func (i *memDBIterator) Valid() bool {
return i.item != nil
}
// Next implements Iterator.
func (i *memDBIterator) Next() {
i.assertIsValid()
item, ok := <-i.ch
switch {
case ok:
i.item = item
default:
i.item = nil
}
}
// Error implements Iterator.
func (i *memDBIterator) Error() error {
return nil // famous last words
}
// Key implements Iterator.
func (i *memDBIterator) Key() []byte {
i.assertIsValid()
return i.item.key
}
// Value implements Iterator.
func (i *memDBIterator) Value() []byte {
i.assertIsValid()
return i.item.value
}
func (i *memDBIterator) assertIsValid() {
if !i.Valid() {
panic("iterator is invalid")
}
}
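// Example (illustrative sketch): the usual consumption loop for a memDBIterator obtained from
// Iterator or ReverseIterator.
//
//	for ; it.Valid(); it.Next() {
//		k, v := it.Key(), it.Value()
//		// k and v alias the database contents; treat them as read-only
//		_, _ = k, v
//	}
//	if err := it.Error(); err != nil {
//		// handle err
//	}
//	_ = it.Close()
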
// memDBBatch operations
type opType int
const (
opTypeSet opType = iota + 1
opTypeDelete
)
type operation struct {
opType
key []byte
value []byte
}
// memDBBatch handles in-memory batching.
type memDBBatch struct {
db *MemDB
ops []operation
size int
}
var _ corestore.Batch = (*memDBBatch)(nil)
// newMemDBBatch creates a new memDBBatch.
func newMemDBBatch(db *MemDB) *memDBBatch {
return &memDBBatch{
db: db,
ops: []operation{},
size: 0,
}
}
// Set implements Batch.
func (b *memDBBatch) Set(key, value []byte) error {
if len(key) == 0 {
return errors.ErrKeyEmpty
}
if value == nil {
return errors.ErrValueNil
}
if b.ops == nil {
return errors.ErrBatchClosed
}
b.size += len(key) + len(value)
b.ops = append(b.ops, operation{opTypeSet, key, value})
return nil
}
// Delete implements Batch.
func (b *memDBBatch) Delete(key []byte) error {
if len(key) == 0 {
return errors.ErrKeyEmpty
}
if b.ops == nil {
return errors.ErrBatchClosed
}
b.size += len(key)
b.ops = append(b.ops, operation{opTypeDelete, key, nil})
return nil
}
// Write implements Batch.
func (b *memDBBatch) Write() error {
if b.ops == nil {
return errors.ErrBatchClosed
}
b.db.mtx.Lock()
defer b.db.mtx.Unlock()
for _, op := range b.ops {
switch op.opType {
case opTypeSet:
b.db.set(op.key, op.value)
case opTypeDelete:
b.db.delete(op.key)
default:
return fmt.Errorf("unknown operation type %v (%v)", op.opType, op)
}
}
// Make sure the batch cannot be used afterwards. Callers should still call Close(), in case it returns an error.
return b.Close()
}
// WriteSync implements Batch.
func (b *memDBBatch) WriteSync() error {
return b.Write()
}
// Close implements Batch.
func (b *memDBBatch) Close() error {
b.ops = nil
b.size = 0
return nil
}
// GetByteSize implements Batch.
func (b *memDBBatch) GetByteSize() (int, error) {
if b.ops == nil {
return 0, errors.ErrBatchClosed
}
return b.size, nil
}
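// Example (illustrative sketch): batched writes are buffered in memory and applied under the
// database write lock when Write is called; Write closes the batch, so it cannot be reused
// afterwards.
//
//	b := db.NewBatch()
//	_ = b.Set([]byte("k1"), []byte("v1"))
//	_ = b.Set([]byte("k2"), []byte("v2"))
//	_ = b.Delete([]byte("k1"))
//	if err := b.Write(); err != nil {
//		// handle err
//	}
//	_ = b.Close() // safe to call again; the batch is already closed by Write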