feat(iavl): add Pin and UnsafeBytes design for managing mmaps (#25657)

This commit is contained in:
Aaron Craelius 2025-12-10 21:42:25 -05:00 committed by GitHub
parent f2d4a98039
commit 6065146fa6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 190 additions and 44 deletions

View File

@ -43,13 +43,13 @@ func (node *MemNode) Version() uint32 {
}
// Key implements the Node interface.
func (node *MemNode) Key() ([]byte, error) {
return node.key, nil
func (node *MemNode) Key() (UnsafeBytes, error) {
return WrapSafeBytes(node.key), nil
}
// Value implements the Node interface.
func (node *MemNode) Value() ([]byte, error) {
return node.value, nil
func (node *MemNode) Value() (UnsafeBytes, error) {
return WrapSafeBytes(node.value), nil
}
// Left implements the Node interface.
@ -63,8 +63,8 @@ func (node *MemNode) Right() *NodePointer {
}
// Hash implements the Node interface.
func (node *MemNode) Hash() []byte {
return node.hash
func (node *MemNode) Hash() UnsafeBytes {
return WrapSafeBytes(node.hash)
}
// MutateBranch implements the Node interface.
@ -76,35 +76,37 @@ func (node *MemNode) MutateBranch(version uint32) (*MemNode, error) {
}
// Get implements the Node interface.
func (node *MemNode) Get(key []byte) (value []byte, index int64, err error) {
func (node *MemNode) Get(key []byte) (value UnsafeBytes, index int64, err error) {
if node.IsLeaf() {
switch bytes.Compare(node.key, key) {
case -1:
return nil, 1, nil
return UnsafeBytes{}, 1, nil
case 1:
return nil, 0, nil
return UnsafeBytes{}, 0, nil
default:
return node.value, 0, nil
return WrapSafeBytes(node.value), 0, nil
}
}
if bytes.Compare(key, node.key) < 0 {
leftNode, err := node.left.Resolve()
leftNode, pin, err := node.left.Resolve()
defer pin.Unpin()
if err != nil {
return nil, 0, err
return UnsafeBytes{}, 0, err
}
return leftNode.Get(key)
}
rightNode, err := node.right.Resolve()
rightNode, pin, err := node.right.Resolve()
defer pin.Unpin()
if err != nil {
return nil, 0, err
return UnsafeBytes{}, 0, err
}
value, index, err = rightNode.Get(key)
if err != nil {
return nil, 0, err
return UnsafeBytes{}, 0, err
}
index += node.size - rightNode.Size()

View File

@ -11,13 +11,16 @@ func TestMemNode_Getters(t *testing.T) {
right := NewNodePointer(&MemNode{})
nodeId := NewNodeID(true, 5, 10)
testKey := []byte("testkey")
testValue := []byte("testvalue")
testHash := []byte("testhash")
node := &MemNode{
height: 3,
version: 7,
size: 42,
key: []byte("testkey"),
value: []byte("testvalue"),
hash: []byte("testhash"),
key: testKey,
value: testValue,
hash: testHash,
left: left,
right: right,
nodeId: nodeId,
@ -29,16 +32,16 @@ func TestMemNode_Getters(t *testing.T) {
require.Equal(t, int64(42), node.Size())
require.Equal(t, left, node.Left())
require.Equal(t, right, node.Right())
require.Equal(t, []byte("testhash"), node.Hash())
require.Equal(t, testHash, node.Hash().UnsafeBytes())
require.Equal(t, nodeId, node.ID())
key, err := node.Key()
require.NoError(t, err)
require.Equal(t, []byte("testkey"), key)
require.Equal(t, testKey, key.UnsafeBytes())
value, err := node.Value()
require.NoError(t, err)
require.Equal(t, []byte("testvalue"), value)
require.Equal(t, testValue, value.UnsafeBytes())
}
func TestMemNode_IsLeaf(t *testing.T) {
@ -98,12 +101,14 @@ func TestMemNode_String(t *testing.T) {
}
func TestMemNode_MutateBranch(t *testing.T) {
key := []byte("key")
origHash := []byte("origHash")
original := &MemNode{
height: 2,
version: 5,
size: 10,
key: []byte("key"),
hash: []byte("oldhash"),
key: key,
hash: origHash,
left: NewNodePointer(&MemNode{}),
right: NewNodePointer(&MemNode{}),
}
@ -113,13 +118,13 @@ func TestMemNode_MutateBranch(t *testing.T) {
// Version updated, hash cleared
require.Equal(t, uint32(12), mutated.Version())
require.Nil(t, mutated.Hash())
require.Nil(t, mutated.Hash().UnsafeBytes())
// Other fields preserved
require.Equal(t, original.Height(), mutated.Height())
require.Equal(t, original.Size(), mutated.Size())
key, _ := mutated.Key()
require.Equal(t, []byte("key"), key)
key2, _ := mutated.Key()
require.Equal(t, key, key2.UnsafeBytes())
require.Equal(t, original.Left(), mutated.Left())
require.Equal(t, original.Right(), mutated.Right())
@ -128,7 +133,7 @@ func TestMemNode_MutateBranch(t *testing.T) {
// Original unchanged
require.Equal(t, uint32(5), original.Version())
require.Equal(t, []byte("oldhash"), original.Hash())
require.Equal(t, origHash, original.Hash().UnsafeBytes())
}
func TestMemNode_Get_Leaf(t *testing.T) {
@ -180,7 +185,7 @@ func TestMemNode_Get_Leaf(t *testing.T) {
}
val, idx, err := node.Get([]byte(tt.searchKey))
require.NoError(t, err)
require.Equal(t, tt.wantValue, val)
require.Equal(t, tt.wantValue, val.UnsafeBytes())
require.Equal(t, tt.wantIndex, idx)
})
}
@ -254,7 +259,7 @@ func TestMemNode_Get_Branch(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
val, idx, err := root.Get([]byte(tt.searchKey))
require.NoError(t, err)
require.Equal(t, tt.wantValue, val)
require.Equal(t, tt.wantValue, val.UnsafeBytes())
require.Equal(t, tt.wantIndex, idx)
})
}
@ -312,7 +317,7 @@ func TestMemNode_Get_DeeperTree(t *testing.T) {
t.Run(tt.searchKey, func(t *testing.T) {
val, idx, err := root.Get([]byte(tt.searchKey))
require.NoError(t, err)
require.Equal(t, tt.wantValue, val)
require.Equal(t, tt.wantValue, val.UnsafeBytes())
require.Equal(t, tt.wantIndex, idx)
})
}

View File

@ -12,11 +12,10 @@ type Node interface {
IsLeaf() bool
// Key returns the key of this node.
Key() ([]byte, error)
Key() (UnsafeBytes, error)
// Value returns the value of this node.
// Calling this on a non-leaf node will return nil and possibly an error.
Value() ([]byte, error)
// Value returns the value of this node. It is an error to call this method on non-leaf nodes.
Value() (UnsafeBytes, error)
// Left returns a pointer to the left child node.
// If this is called on a leaf node, it returns nil.
@ -28,7 +27,7 @@ type Node interface {
// Hash returns the hash of this node.
// Hash may or may not have been computed yet.
Hash() []byte
Hash() UnsafeBytes
// Height returns the height of the subtree rooted at this node.
Height() uint8
@ -45,7 +44,7 @@ type Node interface {
// The index is the 0-based position where the key exists or would be inserted
// in sorted order among all leaf keys in this subtree. This is useful for
// range queries and determining a key's position even when it doesn't exist.
Get(key []byte) (value []byte, index int64, err error)
Get(key []byte) (value UnsafeBytes, index int64, err error)
// MutateBranch creates a mutable copy of this branch node created at the specified version.
// Since this is an immutable tree, whenever we need to modify a branch node, we should call this method

View File

@ -7,10 +7,10 @@ import (
// NodePointer is a pointer to a Node, which may be either in-memory, on-disk or both.
type NodePointer struct {
mem atomic.Pointer[MemNode]
changeset *Changeset
fileIdx uint32 // absolute index in file, 1-based, zero means we don't have an offset
id NodeID
mem atomic.Pointer[MemNode]
// changeset *Changeset // commented to satisfy linter, will uncomment in a future PR when we wire it up
fileIdx uint32 // absolute index in file, 1-based, zero means we don't have an offset
id NodeID
}
// NewNodePointer creates a new NodePointer pointing to the given in-memory node.
@ -20,13 +20,23 @@ func NewNodePointer(memNode *MemNode) *NodePointer {
return n
}
// Resolve resolves the NodePointer to a Node, loading from memory or disk as necessary.
func (p *NodePointer) Resolve() (Node, error) {
// Resolve resolves the NodePointer to a Node, loading from memory or disk as necessary
// as well as a Pin which MUST be unpinned after the caller is done using the node.
// Resolve will ALWAYS return a valid Pin even if there is an error. For clarity and
// consistency it is recommended to introduce a defer pin.Unpin() immediately after
// calling Resolve and BEFORE checking the error return value like this:
//
// node, pin, err := nodePointer.Resolve()
// defer pin.Unpin()
// if err != nil {
// // handle error
// }
func (p *NodePointer) Resolve() (Node, Pin, error) {
mem := p.mem.Load()
if mem != nil {
return mem, nil
return mem, NoopPin{}, nil
}
return p.changeset.Resolve(p.id, p.fileIdx)
return nil, NoopPin{}, fmt.Errorf("node not in memory and on-disk loading will be implemented in a future PR")
}
// String implements the fmt.Stringer interface.

35
iavl/internal/pin.go Normal file
View File

@ -0,0 +1,35 @@
package internal
// Pin represents a handle that pins some memory-mapped file data in memory.
// When the Pin is released via Unpin(), the data may be unmapped from memory.
// Pin must be used to ensure that any UnsafeBytes obtained from memory-mapped
// data remains valid while in use.
// The caller must ensure that Unpin() is called exactly once
// for each Pin obtained. It is recommended to use the following pattern:
//
// node, pin, err := nodePointer.Resolve()
// defer pin.Unpin()
// if err != nil {
// // handle error
// }
//
// When we are using arrays directly addressed to memory mapped files, these arrays
// are not part of the normal Go garbage collected memory. We must map and unmap
// these regions of memory explicitly. Pin represents a commitment to keep the memory
// mapped at least until Unpin() is called. During normal operation, changeset files
// will be mapped and unmapped as needed either because the file size has grown, we have
// compacted a changeset, or simply to manage open file descriptors.
// Under the hood pins use a reference counting mechanism to keep track of how many
// active users there are of a particular memory-mapped region.
type Pin interface {
// Unpin releases the Pin, allowing the underlying memory to be unmapped.
// Implementors should ensure that Unpin() is idempotent and only unpins the
// memory once even if called multiple times.
Unpin()
}
// NoopPin is a Pin that does nothing on Unpin().
type NoopPin struct{}
// Unpin implements the Pin interface.
func (NoopPin) Unpin() {}

View File

@ -0,0 +1,56 @@
package internal
// UnsafeBytes wraps a byte slice that may or not be a direct reference to
// a memory-mapped file.
// Generally, an unsafe byte slice cannot be expected to live longer than the
// Pin on the object it was obtained from.
// As long as it is pinned, it is safe to use the UnsafeBytes() method to get
// the underlying byte slice without copying.
// If the byte slice needs to be retained beyond the Pin's lifetime, the
// SafeCopy() method must be used to get a safe copy of the byte slice.
type UnsafeBytes struct {
bz []byte
safe bool
}
// WrapUnsafeBytes wraps an unsafe byte slice as UnsafeBytes, indicating that
// it is unsafe to use without copying.
// Use this method when you are wrapping a byte slice obtained from a memory-mapped file.
func WrapUnsafeBytes(bz []byte) UnsafeBytes {
return UnsafeBytes{bz: bz, safe: false}
}
// WrapSafeBytes wraps a safe byte slice as UnsafeBytes, indicating that
// it is safe to use without copying.
// Use this method when you are wrapping a byte slice that is known to be safe,
// e.g., a byte slice allocated in regular garbage-collected memory.
func WrapSafeBytes(bz []byte) UnsafeBytes {
return UnsafeBytes{bz: bz, safe: true}
}
// IsNil returns true if the underlying byte slice is nil.
func (ub UnsafeBytes) IsNil() bool {
return ub.bz == nil
}
// UnsafeBytes returns the underlying byte slice without copying.
// The caller must ensure that the byte slice is not used beyond the lifetime
// of the Pin on the object it was obtained from.
func (ub UnsafeBytes) UnsafeBytes() []byte {
return ub.bz
}
// SafeCopy returns a safe copy of the underlying byte slice.
// If the underlying byte slice is already safe or nil, it is returned as is.
// If the underlying byte slice is unsafe, a copy is made and returned.
func (ub UnsafeBytes) SafeCopy() []byte {
if ub.safe {
return ub.bz
}
if ub.bz == nil {
return nil
}
copied := make([]byte, len(ub.bz))
copy(copied, ub.bz)
return copied
}

View File

@ -0,0 +1,39 @@
package internal
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestUnsafeBytes(t *testing.T) {
hello := []byte("hello")
unsafe := WrapUnsafeBytes(hello)
require.False(t, unsafe.IsNil())
require.Equal(t, hello, unsafe.UnsafeBytes())
safeCopy := unsafe.SafeCopy()
require.Equal(t, hello, safeCopy)
require.NotSame(t, &hello[0], &safeCopy[0]) // different underlying array
safe := WrapSafeBytes(hello)
require.False(t, safe.IsNil())
require.Equal(t, hello, safe.UnsafeBytes())
safeCopy2 := safe.SafeCopy()
require.Equal(t, hello, safeCopy2)
require.Same(t, &hello[0], &safeCopy2[0]) // same underlying array
nilUnsafe := WrapUnsafeBytes(nil)
require.True(t, nilUnsafe.IsNil())
require.Nil(t, nilUnsafe.UnsafeBytes())
require.Nil(t, nilUnsafe.SafeCopy())
nilSafe := WrapSafeBytes(nil)
require.True(t, nilSafe.IsNil())
require.Nil(t, nilSafe.UnsafeBytes())
require.Nil(t, nilSafe.SafeCopy())
nilInit := UnsafeBytes{}
require.True(t, nilInit.IsNil())
require.Nil(t, nilInit.UnsafeBytes())
require.Nil(t, nilInit.SafeCopy())
}