feat(iavl): add Node, MemNode, and NodePointer (#25633)

This commit is contained in:
Aaron Craelius 2025-12-04 16:54:31 -05:00 committed by GitHub
parent 0e1f43c8f1
commit 18e85dec5b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 559 additions and 1 deletions

View File

@ -50,7 +50,7 @@ type BranchLayout struct {
// and an additional byte of padding is already reserved below for this purpose.
KeyOffset uint32
// Height is the height of this branch node in the tree.
// Height is the height of the subtree rooted at this branch node.
Height uint8
// NOTE: there are two bytes of padding here that could be used for something else in the future if needed

View File

@ -0,0 +1,11 @@
package internal
import "fmt"
// NOTE: This is a placeholder implementation. We will add the implementation in a future PR.
type Changeset struct{}
func (cs *Changeset) Resolve(id NodeID, fileIdx uint32) (Node, error) {
return nil, fmt.Errorf("not implemented")
}

126
iavl/internal/mem_node.go Normal file
View File

@ -0,0 +1,126 @@
package internal
import (
"bytes"
"fmt"
)
// MemNode represents an in-memory node that has recently been created and may or may not have
// been serialized to disk yet.
type MemNode struct {
height uint8
version uint32
size int64
key []byte
value []byte
left *NodePointer
right *NodePointer
hash []byte
nodeId NodeID // ID of this node, 0 if not yet assigned
keyOffset uint32
}
var _ Node = (*MemNode)(nil)
// ID implements the Node interface.
func (node *MemNode) ID() NodeID {
return node.nodeId
}
// Height implements the Node interface.
func (node *MemNode) Height() uint8 {
return node.height
}
// Size implements the Node interface.
func (node *MemNode) Size() int64 {
return node.size
}
// Version implements the Node interface.
func (node *MemNode) Version() uint32 {
return node.version
}
// Key implements the Node interface.
func (node *MemNode) Key() ([]byte, error) {
return node.key, nil
}
// Value implements the Node interface.
func (node *MemNode) Value() ([]byte, error) {
return node.value, nil
}
// Left implements the Node interface.
func (node *MemNode) Left() *NodePointer {
return node.left
}
// Right implements the Node interface.
func (node *MemNode) Right() *NodePointer {
return node.right
}
// Hash implements the Node interface.
func (node *MemNode) Hash() []byte {
return node.hash
}
// MutateBranch implements the Node interface.
func (node *MemNode) MutateBranch(version uint32) (*MemNode, error) {
n := *node
n.version = version
n.hash = nil
return &n, nil
}
// Get implements the Node interface.
func (node *MemNode) Get(key []byte) (value []byte, index int64, err error) {
if node.IsLeaf() {
switch bytes.Compare(node.key, key) {
case -1:
return nil, 1, nil
case 1:
return nil, 0, nil
default:
return node.value, 0, nil
}
}
if bytes.Compare(key, node.key) < 0 {
leftNode, err := node.left.Resolve()
if err != nil {
return nil, 0, err
}
return leftNode.Get(key)
}
rightNode, err := node.right.Resolve()
if err != nil {
return nil, 0, err
}
value, index, err = rightNode.Get(key)
if err != nil {
return nil, 0, err
}
index += node.size - rightNode.Size()
return value, index, nil
}
// IsLeaf implements the Node interface.
func (node *MemNode) IsLeaf() bool {
return node.height == 0
}
// String implements the fmt.Stringer interface.
func (node *MemNode) String() string {
if node.IsLeaf() {
return fmt.Sprintf("MemNode{key:%x, version:%d, size:%d, value:%x}", node.key, node.version, node.size, node.value)
} else {
return fmt.Sprintf("MemNode{key:%x, version:%d, size:%d, height:%d, left:%s, right:%s}", node.key, node.version, node.size, node.height, node.left, node.right)
}
}

View File

@ -0,0 +1,319 @@
package internal
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestMemNode_Getters(t *testing.T) {
left := NewNodePointer(&MemNode{})
right := NewNodePointer(&MemNode{})
nodeId := NewNodeID(true, 5, 10)
node := &MemNode{
height: 3,
version: 7,
size: 42,
key: []byte("testkey"),
value: []byte("testvalue"),
hash: []byte("testhash"),
left: left,
right: right,
nodeId: nodeId,
keyOffset: 100,
}
require.Equal(t, uint8(3), node.Height())
require.Equal(t, uint32(7), node.Version())
require.Equal(t, int64(42), node.Size())
require.Equal(t, left, node.Left())
require.Equal(t, right, node.Right())
require.Equal(t, []byte("testhash"), node.Hash())
require.Equal(t, nodeId, node.ID())
key, err := node.Key()
require.NoError(t, err)
require.Equal(t, []byte("testkey"), key)
value, err := node.Value()
require.NoError(t, err)
require.Equal(t, []byte("testvalue"), value)
}
func TestMemNode_IsLeaf(t *testing.T) {
tests := []struct {
name string
height uint8
want bool
}{
{name: "leaf", height: 0, want: true},
{name: "branch height 1", height: 1, want: false},
{name: "branch height 5", height: 5, want: false},
{name: "branch max height", height: 255, want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
node := &MemNode{height: tt.height}
require.Equal(t, tt.want, node.IsLeaf())
})
}
}
func TestMemNode_String(t *testing.T) {
tests := []struct {
name string
node *MemNode
want string
}{
{
name: "leaf node",
node: &MemNode{
height: 0,
version: 1,
size: 1,
key: []byte{0xab, 0xcd},
value: []byte{0x12, 0x34},
},
want: "MemNode{key:abcd, version:1, size:1, value:1234}",
},
{
name: "branch node",
node: &MemNode{
height: 2,
version: 5,
size: 10,
key: []byte{0xff},
left: &NodePointer{id: NewNodeID(true, 1, 1)},
right: &NodePointer{id: NewNodeID(true, 1, 2)},
},
want: "MemNode{key:ff, version:5, size:10, height:2, left:NodePointer{id: NodeID{leaf:true, version:1, index:1}, fileIdx: 0}, right:NodePointer{id: NodeID{leaf:true, version:1, index:2}, fileIdx: 0}}",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, tt.node.String())
})
}
}
func TestMemNode_MutateBranch(t *testing.T) {
original := &MemNode{
height: 2,
version: 5,
size: 10,
key: []byte("key"),
hash: []byte("oldhash"),
left: NewNodePointer(&MemNode{}),
right: NewNodePointer(&MemNode{}),
}
mutated, err := original.MutateBranch(12)
require.NoError(t, err)
// Version updated, hash cleared
require.Equal(t, uint32(12), mutated.Version())
require.Nil(t, mutated.Hash())
// Other fields preserved
require.Equal(t, original.Height(), mutated.Height())
require.Equal(t, original.Size(), mutated.Size())
key, _ := mutated.Key()
require.Equal(t, []byte("key"), key)
require.Equal(t, original.Left(), mutated.Left())
require.Equal(t, original.Right(), mutated.Right())
// Is a copy, not same pointer
require.NotSame(t, original, mutated)
// Original unchanged
require.Equal(t, uint32(5), original.Version())
require.Equal(t, []byte("oldhash"), original.Hash())
}
func TestMemNode_Get_Leaf(t *testing.T) {
// When Get is called on a leaf node:
// - If key matches: returns (value, 0, nil)
// - If key not found: returns (nil, index, nil) where index is the insertion point
// - key < nodeKey: index=0 (would insert before this leaf)
// - key > nodeKey: index=1 (would insert after this leaf)
tests := []struct {
name string
nodeKey string
nodeValue string
searchKey string
wantValue []byte
wantIndex int64
}{
{
name: "exact match",
nodeKey: "b",
nodeValue: "val_b",
searchKey: "b",
wantValue: []byte("val_b"),
wantIndex: 0,
},
{
name: "search key less than node key",
nodeKey: "b",
nodeValue: "val_b",
searchKey: "a",
wantValue: nil,
wantIndex: 0, // "a" would be inserted before "b"
},
{
name: "search key greater than node key",
nodeKey: "b",
nodeValue: "val_b",
searchKey: "c",
wantValue: nil,
wantIndex: 1, // "c" would be inserted after "b"
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
node := &MemNode{
height: 0,
size: 1,
key: []byte(tt.nodeKey),
value: []byte(tt.nodeValue),
}
val, idx, err := node.Get([]byte(tt.searchKey))
require.NoError(t, err)
require.Equal(t, tt.wantValue, val)
require.Equal(t, tt.wantIndex, idx)
})
}
}
func TestMemNode_Get_Branch(t *testing.T) {
// Hand-construct a simple tree:
//
// [b] <- branch, key="b", size=2
// / \
// [a] [b] <- leaves (index 0, index 1)
//
// In IAVL, branch key = smallest key in right subtree
//
// Index is the 0-based position in sorted leaf order:
// - "a" is at index 0, "b" is at index 1
// - Keys not found return the insertion point
leftLeaf := &MemNode{
height: 0,
size: 1,
key: []byte("a"),
value: []byte("val_a"),
}
rightLeaf := &MemNode{
height: 0,
size: 1,
key: []byte("b"),
value: []byte("val_b"),
}
root := &MemNode{
height: 1,
size: 2,
key: []byte("b"), // smallest key in right subtree
left: NewNodePointer(leftLeaf),
right: NewNodePointer(rightLeaf),
}
tests := []struct {
name string
searchKey string
wantValue []byte
wantIndex int64
}{
{
name: "find in left subtree",
searchKey: "a",
wantValue: []byte("val_a"),
wantIndex: 0,
},
{
name: "find in right subtree",
searchKey: "b",
wantValue: []byte("val_b"),
wantIndex: 1,
},
{
name: "key not found - less than all",
searchKey: "0",
wantValue: nil,
wantIndex: 0, // "0" would be inserted at position 0
},
{
name: "key not found - greater than all",
searchKey: "z",
wantValue: nil,
wantIndex: 2, // "z" would be inserted at position 2 (after both leaves)
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, idx, err := root.Get([]byte(tt.searchKey))
require.NoError(t, err)
require.Equal(t, tt.wantValue, val)
require.Equal(t, tt.wantIndex, idx)
})
}
}
func TestMemNode_Get_DeeperTree(t *testing.T) {
// Hand-construct a 3-level tree:
//
// [c] <- root, size=4
// / \
// [b] [d] <- branches, size=2 each
// / \ / \
// [a] [b] [c] [d] <- leaves
//
// Sorted keys: a=0, b=1, c=2, d=3
leafA := &MemNode{height: 0, size: 1, key: []byte("a"), value: []byte("val_a")}
leafB := &MemNode{height: 0, size: 1, key: []byte("b"), value: []byte("val_b")}
leafC := &MemNode{height: 0, size: 1, key: []byte("c"), value: []byte("val_c")}
leafD := &MemNode{height: 0, size: 1, key: []byte("d"), value: []byte("val_d")}
branchLeft := &MemNode{
height: 1,
size: 2,
key: []byte("b"),
left: NewNodePointer(leafA),
right: NewNodePointer(leafB),
}
branchRight := &MemNode{
height: 1,
size: 2,
key: []byte("d"),
left: NewNodePointer(leafC),
right: NewNodePointer(leafD),
}
root := &MemNode{
height: 2,
size: 4,
key: []byte("c"), // smallest key in right subtree
left: NewNodePointer(branchLeft),
right: NewNodePointer(branchRight),
}
tests := []struct {
searchKey string
wantValue []byte
wantIndex int64
}{
{"a", []byte("val_a"), 0},
{"b", []byte("val_b"), 1},
{"c", []byte("val_c"), 2},
{"d", []byte("val_d"), 3},
}
for _, tt := range tests {
t.Run(tt.searchKey, func(t *testing.T) {
val, idx, err := root.Get([]byte(tt.searchKey))
require.NoError(t, err)
require.Equal(t, tt.wantValue, val)
require.Equal(t, tt.wantIndex, idx)
})
}
}

57
iavl/internal/node.go Normal file
View File

@ -0,0 +1,57 @@
package internal
import "fmt"
// Node represents a traversable node in the IAVL tree.
type Node interface {
// ID returns the unique identifier of the node.
// If the node has not been assigned an ID yet, it returns the zero value of NodeID.
ID() NodeID
// IsLeaf indicates whether this node is a leaf node.
IsLeaf() bool
// Key returns the key of this node.
Key() ([]byte, error)
// Value returns the value of this node.
// Calling this on a non-leaf node will return nil and possibly an error.
Value() ([]byte, error)
// Left returns a pointer to the left child node.
// If this is called on a leaf node, it returns nil.
Left() *NodePointer
// Right returns a pointer to the right child node.
// If this is called on a leaf node, it returns nil.
Right() *NodePointer
// Hash returns the hash of this node.
// Hash may or may not have been computed yet.
Hash() []byte
// Height returns the height of the subtree rooted at this node.
Height() uint8
// Size returns the number of leaf nodes in the subtree rooted at this node.
Size() int64
// Version returns the version at which this node was created.
Version() uint32
// Get traverses this subtree to find the value associated with the given key.
// If the key is found, value contains the associated value.
// If the key is not found, value is nil (not an error).
// The index is the 0-based position where the key exists or would be inserted
// in sorted order among all leaf keys in this subtree. This is useful for
// range queries and determining a key's position even when it doesn't exist.
Get(key []byte) (value []byte, index int64, err error)
// MutateBranch creates a mutable copy of this branch node created at the specified version.
// Since this is an immutable tree, whenever we need to modify a branch node, we should call this method
// to create a mutable copy of it with its version updated.
// This method should only be called on branch nodes; calling it on leaf nodes will result in an error.
MutateBranch(version uint32) (*MemNode, error)
fmt.Stringer
}

View File

@ -37,6 +37,11 @@ func (id NodeID) IsLeaf() bool {
return id.FlagIndex.IsLeaf()
}
// IsEmpty returns true if the NodeID is the zero value.
func (id NodeID) IsEmpty() bool {
return id.Version == 0 && id.FlagIndex == 0
}
// String returns a string representation of the NodeID.
func (id NodeID) String() string {
return fmt.Sprintf("NodeID{leaf:%t, version:%d, index:%d}", id.IsLeaf(), id.Version, id.FlagIndex.Index())

View File

@ -34,3 +34,8 @@ func TestNodeID(t *testing.T) {
})
}
}
func TestNodeID_IsEmpty(t *testing.T) {
require.True(t, NodeID{}.IsEmpty())
require.False(t, NewNodeID(true, 1, 1).IsEmpty())
}

View File

@ -0,0 +1,35 @@
package internal
import (
"fmt"
"sync/atomic"
)
// NodePointer is a pointer to a Node, which may be either in-memory, on-disk or both.
type NodePointer struct {
mem atomic.Pointer[MemNode]
changeset *Changeset
fileIdx uint32 // absolute index in file, 1-based, zero means we don't have an offset
id NodeID
}
// NewNodePointer creates a new NodePointer pointing to the given in-memory node.
func NewNodePointer(memNode *MemNode) *NodePointer {
n := &NodePointer{}
n.mem.Store(memNode)
return n
}
// Resolve resolves the NodePointer to a Node, loading from memory or disk as necessary.
func (p *NodePointer) Resolve() (Node, error) {
mem := p.mem.Load()
if mem != nil {
return mem, nil
}
return p.changeset.Resolve(p.id, p.fileIdx)
}
// String implements the fmt.Stringer interface.
func (p *NodePointer) String() string {
return fmt.Sprintf("NodePointer{id: %s, fileIdx: %d}", p.id.String(), p.fileIdx)
}