feat(iavl): add Pin and UnsafeBytes design for managing mmaps (#25657)
This commit is contained in:
parent
f2d4a98039
commit
6065146fa6
@ -43,13 +43,13 @@ func (node *MemNode) Version() uint32 {
|
||||
}
|
||||
|
||||
// Key implements the Node interface.
|
||||
func (node *MemNode) Key() ([]byte, error) {
|
||||
return node.key, nil
|
||||
func (node *MemNode) Key() (UnsafeBytes, error) {
|
||||
return WrapSafeBytes(node.key), nil
|
||||
}
|
||||
|
||||
// Value implements the Node interface.
|
||||
func (node *MemNode) Value() ([]byte, error) {
|
||||
return node.value, nil
|
||||
func (node *MemNode) Value() (UnsafeBytes, error) {
|
||||
return WrapSafeBytes(node.value), nil
|
||||
}
|
||||
|
||||
// Left implements the Node interface.
|
||||
@ -63,8 +63,8 @@ func (node *MemNode) Right() *NodePointer {
|
||||
}
|
||||
|
||||
// Hash implements the Node interface.
|
||||
func (node *MemNode) Hash() []byte {
|
||||
return node.hash
|
||||
func (node *MemNode) Hash() UnsafeBytes {
|
||||
return WrapSafeBytes(node.hash)
|
||||
}
|
||||
|
||||
// MutateBranch implements the Node interface.
|
||||
@ -76,35 +76,37 @@ func (node *MemNode) MutateBranch(version uint32) (*MemNode, error) {
|
||||
}
|
||||
|
||||
// Get implements the Node interface.
|
||||
func (node *MemNode) Get(key []byte) (value []byte, index int64, err error) {
|
||||
func (node *MemNode) Get(key []byte) (value UnsafeBytes, index int64, err error) {
|
||||
if node.IsLeaf() {
|
||||
switch bytes.Compare(node.key, key) {
|
||||
case -1:
|
||||
return nil, 1, nil
|
||||
return UnsafeBytes{}, 1, nil
|
||||
case 1:
|
||||
return nil, 0, nil
|
||||
return UnsafeBytes{}, 0, nil
|
||||
default:
|
||||
return node.value, 0, nil
|
||||
return WrapSafeBytes(node.value), 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
if bytes.Compare(key, node.key) < 0 {
|
||||
leftNode, err := node.left.Resolve()
|
||||
leftNode, pin, err := node.left.Resolve()
|
||||
defer pin.Unpin()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return UnsafeBytes{}, 0, err
|
||||
}
|
||||
|
||||
return leftNode.Get(key)
|
||||
}
|
||||
|
||||
rightNode, err := node.right.Resolve()
|
||||
rightNode, pin, err := node.right.Resolve()
|
||||
defer pin.Unpin()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return UnsafeBytes{}, 0, err
|
||||
}
|
||||
|
||||
value, index, err = rightNode.Get(key)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return UnsafeBytes{}, 0, err
|
||||
}
|
||||
|
||||
index += node.size - rightNode.Size()
|
||||
|
||||
@ -11,13 +11,16 @@ func TestMemNode_Getters(t *testing.T) {
|
||||
right := NewNodePointer(&MemNode{})
|
||||
nodeId := NewNodeID(true, 5, 10)
|
||||
|
||||
testKey := []byte("testkey")
|
||||
testValue := []byte("testvalue")
|
||||
testHash := []byte("testhash")
|
||||
node := &MemNode{
|
||||
height: 3,
|
||||
version: 7,
|
||||
size: 42,
|
||||
key: []byte("testkey"),
|
||||
value: []byte("testvalue"),
|
||||
hash: []byte("testhash"),
|
||||
key: testKey,
|
||||
value: testValue,
|
||||
hash: testHash,
|
||||
left: left,
|
||||
right: right,
|
||||
nodeId: nodeId,
|
||||
@ -29,16 +32,16 @@ func TestMemNode_Getters(t *testing.T) {
|
||||
require.Equal(t, int64(42), node.Size())
|
||||
require.Equal(t, left, node.Left())
|
||||
require.Equal(t, right, node.Right())
|
||||
require.Equal(t, []byte("testhash"), node.Hash())
|
||||
require.Equal(t, testHash, node.Hash().UnsafeBytes())
|
||||
require.Equal(t, nodeId, node.ID())
|
||||
|
||||
key, err := node.Key()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("testkey"), key)
|
||||
require.Equal(t, testKey, key.UnsafeBytes())
|
||||
|
||||
value, err := node.Value()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("testvalue"), value)
|
||||
require.Equal(t, testValue, value.UnsafeBytes())
|
||||
}
|
||||
|
||||
func TestMemNode_IsLeaf(t *testing.T) {
|
||||
@ -98,12 +101,14 @@ func TestMemNode_String(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMemNode_MutateBranch(t *testing.T) {
|
||||
key := []byte("key")
|
||||
origHash := []byte("origHash")
|
||||
original := &MemNode{
|
||||
height: 2,
|
||||
version: 5,
|
||||
size: 10,
|
||||
key: []byte("key"),
|
||||
hash: []byte("oldhash"),
|
||||
key: key,
|
||||
hash: origHash,
|
||||
left: NewNodePointer(&MemNode{}),
|
||||
right: NewNodePointer(&MemNode{}),
|
||||
}
|
||||
@ -113,13 +118,13 @@ func TestMemNode_MutateBranch(t *testing.T) {
|
||||
|
||||
// Version updated, hash cleared
|
||||
require.Equal(t, uint32(12), mutated.Version())
|
||||
require.Nil(t, mutated.Hash())
|
||||
require.Nil(t, mutated.Hash().UnsafeBytes())
|
||||
|
||||
// Other fields preserved
|
||||
require.Equal(t, original.Height(), mutated.Height())
|
||||
require.Equal(t, original.Size(), mutated.Size())
|
||||
key, _ := mutated.Key()
|
||||
require.Equal(t, []byte("key"), key)
|
||||
key2, _ := mutated.Key()
|
||||
require.Equal(t, key, key2.UnsafeBytes())
|
||||
require.Equal(t, original.Left(), mutated.Left())
|
||||
require.Equal(t, original.Right(), mutated.Right())
|
||||
|
||||
@ -128,7 +133,7 @@ func TestMemNode_MutateBranch(t *testing.T) {
|
||||
|
||||
// Original unchanged
|
||||
require.Equal(t, uint32(5), original.Version())
|
||||
require.Equal(t, []byte("oldhash"), original.Hash())
|
||||
require.Equal(t, origHash, original.Hash().UnsafeBytes())
|
||||
}
|
||||
|
||||
func TestMemNode_Get_Leaf(t *testing.T) {
|
||||
@ -180,7 +185,7 @@ func TestMemNode_Get_Leaf(t *testing.T) {
|
||||
}
|
||||
val, idx, err := node.Get([]byte(tt.searchKey))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantValue, val)
|
||||
require.Equal(t, tt.wantValue, val.UnsafeBytes())
|
||||
require.Equal(t, tt.wantIndex, idx)
|
||||
})
|
||||
}
|
||||
@ -254,7 +259,7 @@ func TestMemNode_Get_Branch(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, idx, err := root.Get([]byte(tt.searchKey))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantValue, val)
|
||||
require.Equal(t, tt.wantValue, val.UnsafeBytes())
|
||||
require.Equal(t, tt.wantIndex, idx)
|
||||
})
|
||||
}
|
||||
@ -312,7 +317,7 @@ func TestMemNode_Get_DeeperTree(t *testing.T) {
|
||||
t.Run(tt.searchKey, func(t *testing.T) {
|
||||
val, idx, err := root.Get([]byte(tt.searchKey))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantValue, val)
|
||||
require.Equal(t, tt.wantValue, val.UnsafeBytes())
|
||||
require.Equal(t, tt.wantIndex, idx)
|
||||
})
|
||||
}
|
||||
|
||||
@ -12,11 +12,10 @@ type Node interface {
|
||||
IsLeaf() bool
|
||||
|
||||
// Key returns the key of this node.
|
||||
Key() ([]byte, error)
|
||||
Key() (UnsafeBytes, error)
|
||||
|
||||
// Value returns the value of this node.
|
||||
// Calling this on a non-leaf node will return nil and possibly an error.
|
||||
Value() ([]byte, error)
|
||||
// Value returns the value of this node. It is an error to call this method on non-leaf nodes.
|
||||
Value() (UnsafeBytes, error)
|
||||
|
||||
// Left returns a pointer to the left child node.
|
||||
// If this is called on a leaf node, it returns nil.
|
||||
@ -28,7 +27,7 @@ type Node interface {
|
||||
|
||||
// Hash returns the hash of this node.
|
||||
// Hash may or may not have been computed yet.
|
||||
Hash() []byte
|
||||
Hash() UnsafeBytes
|
||||
|
||||
// Height returns the height of the subtree rooted at this node.
|
||||
Height() uint8
|
||||
@ -45,7 +44,7 @@ type Node interface {
|
||||
// The index is the 0-based position where the key exists or would be inserted
|
||||
// in sorted order among all leaf keys in this subtree. This is useful for
|
||||
// range queries and determining a key's position even when it doesn't exist.
|
||||
Get(key []byte) (value []byte, index int64, err error)
|
||||
Get(key []byte) (value UnsafeBytes, index int64, err error)
|
||||
|
||||
// MutateBranch creates a mutable copy of this branch node created at the specified version.
|
||||
// Since this is an immutable tree, whenever we need to modify a branch node, we should call this method
|
||||
|
||||
@ -7,10 +7,10 @@ import (
|
||||
|
||||
// NodePointer is a pointer to a Node, which may be either in-memory, on-disk or both.
|
||||
type NodePointer struct {
|
||||
mem atomic.Pointer[MemNode]
|
||||
changeset *Changeset
|
||||
fileIdx uint32 // absolute index in file, 1-based, zero means we don't have an offset
|
||||
id NodeID
|
||||
mem atomic.Pointer[MemNode]
|
||||
// changeset *Changeset // commented to satisfy linter, will uncomment in a future PR when we wire it up
|
||||
fileIdx uint32 // absolute index in file, 1-based, zero means we don't have an offset
|
||||
id NodeID
|
||||
}
|
||||
|
||||
// NewNodePointer creates a new NodePointer pointing to the given in-memory node.
|
||||
@ -20,13 +20,23 @@ func NewNodePointer(memNode *MemNode) *NodePointer {
|
||||
return n
|
||||
}
|
||||
|
||||
// Resolve resolves the NodePointer to a Node, loading from memory or disk as necessary.
|
||||
func (p *NodePointer) Resolve() (Node, error) {
|
||||
// Resolve resolves the NodePointer to a Node, loading from memory or disk as necessary
|
||||
// and also returns a Pin which MUST be unpinned after the caller is done using the node.
|
||||
// Resolve will ALWAYS return a valid Pin even if there is an error. For clarity and
|
||||
// consistency it is recommended to introduce a defer pin.Unpin() immediately after
|
||||
// calling Resolve and BEFORE checking the error return value like this:
|
||||
//
|
||||
// node, pin, err := nodePointer.Resolve()
|
||||
// defer pin.Unpin()
|
||||
// if err != nil {
|
||||
// // handle error
|
||||
// }
|
||||
func (p *NodePointer) Resolve() (Node, Pin, error) {
|
||||
mem := p.mem.Load()
|
||||
if mem != nil {
|
||||
return mem, nil
|
||||
return mem, NoopPin{}, nil
|
||||
}
|
||||
return p.changeset.Resolve(p.id, p.fileIdx)
|
||||
return nil, NoopPin{}, fmt.Errorf("node not in memory and on-disk loading will be implemented in a future PR")
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer interface.
|
||||
|
||||
35
iavl/internal/pin.go
Normal file
35
iavl/internal/pin.go
Normal file
@ -0,0 +1,35 @@
|
||||
package internal
|
||||
|
||||
// Pin represents a handle that pins some memory-mapped file data in memory.
// When the Pin is released via Unpin(), the data may be unmapped from memory.
// Pin must be used to ensure that any UnsafeBytes obtained from memory-mapped
// data remains valid while in use.
//
// The caller must ensure that Unpin() is called exactly once for each Pin
// obtained. It is recommended to use the following pattern:
//
//	node, pin, err := nodePointer.Resolve()
//	defer pin.Unpin()
//	if err != nil {
//		// handle error
//	}
//
// When we are using arrays directly addressed to memory mapped files, these
// arrays are not part of the normal Go garbage collected memory. We must map
// and unmap these regions of memory explicitly. Pin represents a commitment to
// keep the memory mapped at least until Unpin() is called. During normal
// operation, changeset files will be mapped and unmapped as needed either
// because the file size has grown, we have compacted a changeset, or simply to
// manage open file descriptors.
//
// Under the hood pins use a reference counting mechanism to keep track of how
// many active users there are of a particular memory-mapped region.
type Pin interface {
	// Unpin releases the Pin, allowing the underlying memory to be unmapped.
	//
	// Implementors should ensure that Unpin() is idempotent and only unpins
	// the memory once even if called multiple times. Callers are nevertheless
	// expected to call it exactly once per Pin; idempotence is a safety net
	// against accidental double calls, not a licence to rely on them.
	Unpin()
}

// NoopPin is a Pin that does nothing on Unpin().
// It is suitable wherever there is nothing to unmap, for example when the
// data lives in regular garbage-collected memory.
type NoopPin struct{}

// Compile-time check that NoopPin satisfies Pin.
var _ Pin = NoopPin{}

// Unpin implements the Pin interface. It is a no-op and may be called any
// number of times.
func (NoopPin) Unpin() {}
|
||||
56
iavl/internal/unsafe_bytes.go
Normal file
56
iavl/internal/unsafe_bytes.go
Normal file
@ -0,0 +1,56 @@
|
||||
package internal
|
||||
|
||||
// UnsafeBytes wraps a byte slice that may or may not be a direct reference to
// a memory-mapped file.
//
// Generally, an unsafe byte slice cannot be expected to live longer than the
// Pin on the object it was obtained from.
// As long as it is pinned, it is safe to use the UnsafeBytes() method to get
// the underlying byte slice without copying.
// If the byte slice needs to be retained beyond the Pin's lifetime, the
// SafeCopy() method must be used to get a safe copy of the byte slice.
//
// The zero value wraps a nil slice: IsNil() reports true and both
// UnsafeBytes() and SafeCopy() return nil.
type UnsafeBytes struct {
	// bz is the wrapped slice; it may be nil.
	bz []byte
	// safe records whether bz may be retained without copying.
	safe bool
}

// WrapUnsafeBytes wraps an unsafe byte slice as UnsafeBytes, indicating that
// it is unsafe to use without copying.
// Use this method when you are wrapping a byte slice obtained from a
// memory-mapped file.
func WrapUnsafeBytes(bz []byte) UnsafeBytes {
	return UnsafeBytes{bz: bz, safe: false}
}

// WrapSafeBytes wraps a safe byte slice as UnsafeBytes, indicating that
// it is safe to use without copying.
// Use this method when you are wrapping a byte slice that is known to be safe,
// e.g., a byte slice allocated in regular garbage-collected memory.
func WrapSafeBytes(bz []byte) UnsafeBytes {
	return UnsafeBytes{bz: bz, safe: true}
}

// IsNil returns true if the underlying byte slice is nil.
func (ub UnsafeBytes) IsNil() bool {
	return ub.bz == nil
}

// UnsafeBytes returns the underlying byte slice without copying.
// The caller must ensure that the byte slice is not used beyond the lifetime
// of the Pin on the object it was obtained from.
func (ub UnsafeBytes) UnsafeBytes() []byte {
	return ub.bz
}

// SafeCopy returns a byte slice that is safe to retain.
//
// If the underlying byte slice was wrapped as safe, or is nil, it is returned
// as is: no copy is made, so the result aliases the wrapped slice.
// If the underlying byte slice is unsafe, a fresh copy is allocated and
// returned; an empty non-nil slice yields an empty non-nil copy.
func (ub UnsafeBytes) SafeCopy() []byte {
	if ub.safe {
		return ub.bz
	}
	if ub.bz == nil {
		return nil
	}
	copied := make([]byte, len(ub.bz))
	copy(copied, ub.bz)
	return copied
}
|
||||
39
iavl/internal/unsafe_bytes_test.go
Normal file
39
iavl/internal/unsafe_bytes_test.go
Normal file
@ -0,0 +1,39 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUnsafeBytes(t *testing.T) {
|
||||
hello := []byte("hello")
|
||||
unsafe := WrapUnsafeBytes(hello)
|
||||
require.False(t, unsafe.IsNil())
|
||||
require.Equal(t, hello, unsafe.UnsafeBytes())
|
||||
safeCopy := unsafe.SafeCopy()
|
||||
require.Equal(t, hello, safeCopy)
|
||||
require.NotSame(t, &hello[0], &safeCopy[0]) // different underlying array
|
||||
|
||||
safe := WrapSafeBytes(hello)
|
||||
require.False(t, safe.IsNil())
|
||||
require.Equal(t, hello, safe.UnsafeBytes())
|
||||
safeCopy2 := safe.SafeCopy()
|
||||
require.Equal(t, hello, safeCopy2)
|
||||
require.Same(t, &hello[0], &safeCopy2[0]) // same underlying array
|
||||
|
||||
nilUnsafe := WrapUnsafeBytes(nil)
|
||||
require.True(t, nilUnsafe.IsNil())
|
||||
require.Nil(t, nilUnsafe.UnsafeBytes())
|
||||
require.Nil(t, nilUnsafe.SafeCopy())
|
||||
|
||||
nilSafe := WrapSafeBytes(nil)
|
||||
require.True(t, nilSafe.IsNil())
|
||||
require.Nil(t, nilSafe.UnsafeBytes())
|
||||
require.Nil(t, nilSafe.SafeCopy())
|
||||
|
||||
nilInit := UnsafeBytes{}
|
||||
require.True(t, nilInit.IsNil())
|
||||
require.Nil(t, nilInit.UnsafeBytes())
|
||||
require.Nil(t, nilInit.SafeCopy())
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user