From 6065146fa60066dad4b1d76a36a8a57f13fc98e8 Mon Sep 17 00:00:00 2001 From: Aaron Craelius Date: Wed, 10 Dec 2025 21:42:25 -0500 Subject: [PATCH] feat(iavl): add Pin and UnsafeBytes design for managing mmaps (#25657) --- iavl/internal/mem_node.go | 32 +++++++++-------- iavl/internal/mem_node_test.go | 35 +++++++++++-------- iavl/internal/node.go | 11 +++--- iavl/internal/node_pointer.go | 26 +++++++++----- iavl/internal/pin.go | 35 +++++++++++++++++++ iavl/internal/unsafe_bytes.go | 56 ++++++++++++++++++++++++++++++ iavl/internal/unsafe_bytes_test.go | 39 +++++++++++++++++++++ 7 files changed, 190 insertions(+), 44 deletions(-) create mode 100644 iavl/internal/pin.go create mode 100644 iavl/internal/unsafe_bytes.go create mode 100644 iavl/internal/unsafe_bytes_test.go diff --git a/iavl/internal/mem_node.go b/iavl/internal/mem_node.go index b0b3f8fbed..bd14c590f7 100644 --- a/iavl/internal/mem_node.go +++ b/iavl/internal/mem_node.go @@ -43,13 +43,13 @@ func (node *MemNode) Version() uint32 { } // Key implements the Node interface. -func (node *MemNode) Key() ([]byte, error) { - return node.key, nil +func (node *MemNode) Key() (UnsafeBytes, error) { + return WrapSafeBytes(node.key), nil } // Value implements the Node interface. -func (node *MemNode) Value() ([]byte, error) { - return node.value, nil +func (node *MemNode) Value() (UnsafeBytes, error) { + return WrapSafeBytes(node.value), nil } // Left implements the Node interface. @@ -63,8 +63,8 @@ func (node *MemNode) Right() *NodePointer { } // Hash implements the Node interface. -func (node *MemNode) Hash() []byte { - return node.hash +func (node *MemNode) Hash() UnsafeBytes { + return WrapSafeBytes(node.hash) } // MutateBranch implements the Node interface. @@ -76,35 +76,37 @@ func (node *MemNode) MutateBranch(version uint32) (*MemNode, error) { } // Get implements the Node interface. 
-func (node *MemNode) Get(key []byte) (value []byte, index int64, err error) { +func (node *MemNode) Get(key []byte) (value UnsafeBytes, index int64, err error) { if node.IsLeaf() { switch bytes.Compare(node.key, key) { case -1: - return nil, 1, nil + return UnsafeBytes{}, 1, nil case 1: - return nil, 0, nil + return UnsafeBytes{}, 0, nil default: - return node.value, 0, nil + return WrapSafeBytes(node.value), 0, nil } } if bytes.Compare(key, node.key) < 0 { - leftNode, err := node.left.Resolve() + leftNode, pin, err := node.left.Resolve() + defer pin.Unpin() if err != nil { - return nil, 0, err + return UnsafeBytes{}, 0, err } return leftNode.Get(key) } - rightNode, err := node.right.Resolve() + rightNode, pin, err := node.right.Resolve() + defer pin.Unpin() if err != nil { - return nil, 0, err + return UnsafeBytes{}, 0, err } value, index, err = rightNode.Get(key) if err != nil { - return nil, 0, err + return UnsafeBytes{}, 0, err } index += node.size - rightNode.Size() diff --git a/iavl/internal/mem_node_test.go b/iavl/internal/mem_node_test.go index ca948f0701..8fe4e9b1f7 100644 --- a/iavl/internal/mem_node_test.go +++ b/iavl/internal/mem_node_test.go @@ -11,13 +11,16 @@ func TestMemNode_Getters(t *testing.T) { right := NewNodePointer(&MemNode{}) nodeId := NewNodeID(true, 5, 10) + testKey := []byte("testkey") + testValue := []byte("testvalue") + testHash := []byte("testhash") node := &MemNode{ height: 3, version: 7, size: 42, - key: []byte("testkey"), - value: []byte("testvalue"), - hash: []byte("testhash"), + key: testKey, + value: testValue, + hash: testHash, left: left, right: right, nodeId: nodeId, @@ -29,16 +32,16 @@ func TestMemNode_Getters(t *testing.T) { require.Equal(t, int64(42), node.Size()) require.Equal(t, left, node.Left()) require.Equal(t, right, node.Right()) - require.Equal(t, []byte("testhash"), node.Hash()) + require.Equal(t, testHash, node.Hash().UnsafeBytes()) require.Equal(t, nodeId, node.ID()) key, err := node.Key() require.NoError(t, 
err) - require.Equal(t, []byte("testkey"), key) + require.Equal(t, testKey, key.UnsafeBytes()) value, err := node.Value() require.NoError(t, err) - require.Equal(t, []byte("testvalue"), value) + require.Equal(t, testValue, value.UnsafeBytes()) } func TestMemNode_IsLeaf(t *testing.T) { @@ -98,12 +101,14 @@ func TestMemNode_String(t *testing.T) { } func TestMemNode_MutateBranch(t *testing.T) { + key := []byte("key") + origHash := []byte("origHash") original := &MemNode{ height: 2, version: 5, size: 10, - key: []byte("key"), - hash: []byte("oldhash"), + key: key, + hash: origHash, left: NewNodePointer(&MemNode{}), right: NewNodePointer(&MemNode{}), } @@ -113,13 +118,13 @@ func TestMemNode_MutateBranch(t *testing.T) { // Version updated, hash cleared require.Equal(t, uint32(12), mutated.Version()) - require.Nil(t, mutated.Hash()) + require.Nil(t, mutated.Hash().UnsafeBytes()) // Other fields preserved require.Equal(t, original.Height(), mutated.Height()) require.Equal(t, original.Size(), mutated.Size()) - key, _ := mutated.Key() - require.Equal(t, []byte("key"), key) + key2, _ := mutated.Key() + require.Equal(t, key, key2.UnsafeBytes()) require.Equal(t, original.Left(), mutated.Left()) require.Equal(t, original.Right(), mutated.Right()) @@ -128,7 +133,7 @@ func TestMemNode_MutateBranch(t *testing.T) { // Original unchanged require.Equal(t, uint32(5), original.Version()) - require.Equal(t, []byte("oldhash"), original.Hash()) + require.Equal(t, origHash, original.Hash().UnsafeBytes()) } func TestMemNode_Get_Leaf(t *testing.T) { @@ -180,7 +185,7 @@ func TestMemNode_Get_Leaf(t *testing.T) { } val, idx, err := node.Get([]byte(tt.searchKey)) require.NoError(t, err) - require.Equal(t, tt.wantValue, val) + require.Equal(t, tt.wantValue, val.UnsafeBytes()) require.Equal(t, tt.wantIndex, idx) }) } @@ -254,7 +259,7 @@ func TestMemNode_Get_Branch(t *testing.T) { t.Run(tt.name, func(t *testing.T) { val, idx, err := root.Get([]byte(tt.searchKey)) require.NoError(t, err) - 
require.Equal(t, tt.wantValue, val) + require.Equal(t, tt.wantValue, val.UnsafeBytes()) require.Equal(t, tt.wantIndex, idx) }) } @@ -312,7 +317,7 @@ func TestMemNode_Get_DeeperTree(t *testing.T) { t.Run(tt.searchKey, func(t *testing.T) { val, idx, err := root.Get([]byte(tt.searchKey)) require.NoError(t, err) - require.Equal(t, tt.wantValue, val) + require.Equal(t, tt.wantValue, val.UnsafeBytes()) require.Equal(t, tt.wantIndex, idx) }) } diff --git a/iavl/internal/node.go b/iavl/internal/node.go index f449ccdf45..ab1a87db27 100644 --- a/iavl/internal/node.go +++ b/iavl/internal/node.go @@ -12,11 +12,10 @@ type Node interface { IsLeaf() bool // Key returns the key of this node. - Key() ([]byte, error) + Key() (UnsafeBytes, error) - // Value returns the value of this node. - // Calling this on a non-leaf node will return nil and possibly an error. - Value() ([]byte, error) + // Value returns the value of this node. It is an error to call this method on non-leaf nodes. + Value() (UnsafeBytes, error) // Left returns a pointer to the left child node. // If this is called on a leaf node, it returns nil. @@ -28,7 +27,7 @@ type Node interface { // Hash returns the hash of this node. // Hash may or may not have been computed yet. - Hash() []byte + Hash() UnsafeBytes // Height returns the height of the subtree rooted at this node. Height() uint8 @@ -45,7 +44,7 @@ type Node interface { // The index is the 0-based position where the key exists or would be inserted // in sorted order among all leaf keys in this subtree. This is useful for // range queries and determining a key's position even when it doesn't exist. - Get(key []byte) (value []byte, index int64, err error) + Get(key []byte) (value UnsafeBytes, index int64, err error) // MutateBranch creates a mutable copy of this branch node created at the specified version. 
// Since this is an immutable tree, whenever we need to modify a branch node, we should call this method diff --git a/iavl/internal/node_pointer.go b/iavl/internal/node_pointer.go index fa0afa130a..e2a2511910 100644 --- a/iavl/internal/node_pointer.go +++ b/iavl/internal/node_pointer.go @@ -7,10 +7,10 @@ import ( // NodePointer is a pointer to a Node, which may be either in-memory, on-disk or both. type NodePointer struct { - mem atomic.Pointer[MemNode] - changeset *Changeset - fileIdx uint32 // absolute index in file, 1-based, zero means we don't have an offset - id NodeID + mem atomic.Pointer[MemNode] + // changeset *Changeset // commented to satisfy linter, will uncomment in a future PR when we wire it up + fileIdx uint32 // absolute index in file, 1-based, zero means we don't have an offset + id NodeID } // NewNodePointer creates a new NodePointer pointing to the given in-memory node. @@ -20,13 +20,23 @@ func NewNodePointer(memNode *MemNode) *NodePointer { return n } -// Resolve resolves the NodePointer to a Node, loading from memory or disk as necessary. -func (p *NodePointer) Resolve() (Node, error) { +// Resolve resolves the NodePointer to a Node, loading from memory or disk as necessary +// as well as a Pin which MUST be unpinned after the caller is done using the node. +// Resolve will ALWAYS return a valid Pin even if there is an error. 
For clarity and +// consistency it is recommended to introduce a defer pin.Unpin() immediately after +// calling Resolve and BEFORE checking the error return value like this: +// +// node, pin, err := nodePointer.Resolve() +// defer pin.Unpin() +// if err != nil { +// // handle error +// } +func (p *NodePointer) Resolve() (Node, Pin, error) { mem := p.mem.Load() if mem != nil { - return mem, nil + return mem, NoopPin{}, nil } - return p.changeset.Resolve(p.id, p.fileIdx) + return nil, NoopPin{}, fmt.Errorf("node not in memory and on-disk loading will be implemented in a future PR") } // String implements the fmt.Stringer interface. diff --git a/iavl/internal/pin.go b/iavl/internal/pin.go new file mode 100644 index 0000000000..858097e051 --- /dev/null +++ b/iavl/internal/pin.go @@ -0,0 +1,35 @@ +package internal + +// Pin represents a handle that pins some memory-mapped file data in memory. +// When the Pin is released via Unpin(), the data may be unmapped from memory. +// Pin must be used to ensure that any UnsafeBytes obtained from memory-mapped +// data remains valid while in use. +// The caller must ensure that Unpin() is called exactly once +// for each Pin obtained. It is recommended to use the following pattern: +// +// node, pin, err := nodePointer.Resolve() +// defer pin.Unpin() +// if err != nil { +// // handle error +// } +// +// When we are using arrays directly addressed to memory mapped files, these arrays +// are not part of the normal Go garbage collected memory. We must map and unmap +// these regions of memory explicitly. Pin represents a commitment to keep the memory +// mapped at least until Unpin() is called. During normal operation, changeset files +// will be mapped and unmapped as needed either because the file size has grown, we have +// compacted a changeset, or simply to manage open file descriptors. 
+// Under the hood pins use a reference counting mechanism to keep track of how many +// active users there are of a particular memory-mapped region. +type Pin interface { + // Unpin releases the Pin, allowing the underlying memory to be unmapped. + // Implementors should ensure that Unpin() is idempotent and only unpins the + // memory once even if called multiple times. + Unpin() +} + +// NoopPin is a Pin that does nothing on Unpin(). +type NoopPin struct{} + +// Unpin implements the Pin interface. +func (NoopPin) Unpin() {} diff --git a/iavl/internal/unsafe_bytes.go b/iavl/internal/unsafe_bytes.go new file mode 100644 index 0000000000..d6e87413dd --- /dev/null +++ b/iavl/internal/unsafe_bytes.go @@ -0,0 +1,56 @@ +package internal + +// UnsafeBytes wraps a byte slice that may or may not be a direct reference to +// a memory-mapped file. +// Generally, an unsafe byte slice cannot be expected to live longer than the +// Pin on the object it was obtained from. +// As long as it is pinned, it is safe to use the UnsafeBytes() method to get +// the underlying byte slice without copying. +// If the byte slice needs to be retained beyond the Pin's lifetime, the +// SafeCopy() method must be used to get a safe copy of the byte slice. +type UnsafeBytes struct { + bz []byte + safe bool +} + +// WrapUnsafeBytes wraps an unsafe byte slice as UnsafeBytes, indicating that +// it is unsafe to use without copying. +// Use this method when you are wrapping a byte slice obtained from a memory-mapped file. +func WrapUnsafeBytes(bz []byte) UnsafeBytes { + return UnsafeBytes{bz: bz, safe: false} +} + +// WrapSafeBytes wraps a safe byte slice as UnsafeBytes, indicating that +// it is safe to use without copying. +// Use this method when you are wrapping a byte slice that is known to be safe, +// e.g., a byte slice allocated in regular garbage-collected memory. 
+func WrapSafeBytes(bz []byte) UnsafeBytes { + return UnsafeBytes{bz: bz, safe: true} +} + +// IsNil returns true if the underlying byte slice is nil. +func (ub UnsafeBytes) IsNil() bool { + return ub.bz == nil +} + +// UnsafeBytes returns the underlying byte slice without copying. +// The caller must ensure that the byte slice is not used beyond the lifetime +// of the Pin on the object it was obtained from. +func (ub UnsafeBytes) UnsafeBytes() []byte { + return ub.bz +} + +// SafeCopy returns a safe copy of the underlying byte slice. +// If the underlying byte slice is already safe or nil, it is returned as is. +// If the underlying byte slice is unsafe, a copy is made and returned. +func (ub UnsafeBytes) SafeCopy() []byte { + if ub.safe { + return ub.bz + } + if ub.bz == nil { + return nil + } + copied := make([]byte, len(ub.bz)) + copy(copied, ub.bz) + return copied +} diff --git a/iavl/internal/unsafe_bytes_test.go b/iavl/internal/unsafe_bytes_test.go new file mode 100644 index 0000000000..97562f5e4a --- /dev/null +++ b/iavl/internal/unsafe_bytes_test.go @@ -0,0 +1,39 @@ +package internal + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUnsafeBytes(t *testing.T) { + hello := []byte("hello") + unsafe := WrapUnsafeBytes(hello) + require.False(t, unsafe.IsNil()) + require.Equal(t, hello, unsafe.UnsafeBytes()) + safeCopy := unsafe.SafeCopy() + require.Equal(t, hello, safeCopy) + require.NotSame(t, &hello[0], &safeCopy[0]) // different underlying array + + safe := WrapSafeBytes(hello) + require.False(t, safe.IsNil()) + require.Equal(t, hello, safe.UnsafeBytes()) + safeCopy2 := safe.SafeCopy() + require.Equal(t, hello, safeCopy2) + require.Same(t, &hello[0], &safeCopy2[0]) // same underlying array + + nilUnsafe := WrapUnsafeBytes(nil) + require.True(t, nilUnsafe.IsNil()) + require.Nil(t, nilUnsafe.UnsafeBytes()) + require.Nil(t, nilUnsafe.SafeCopy()) + + nilSafe := WrapSafeBytes(nil) + require.True(t, nilSafe.IsNil()) + require.Nil(t, 
nilSafe.UnsafeBytes()) + require.Nil(t, nilSafe.SafeCopy()) + + nilInit := UnsafeBytes{} + require.True(t, nilInit.IsNil()) + require.Nil(t, nilInit.UnsafeBytes()) + require.Nil(t, nilInit.SafeCopy()) +}