diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go index 204584c95..adeaa1daa 100644 --- a/core/state/snapshot/generate.go +++ b/core/state/snapshot/generate.go @@ -230,7 +230,9 @@ func (dl *diskLayer) proveRange(ctx *generatorContext, trieId *trie.ID, prefix [ if origin == nil && !diskMore { stackTr := trie.NewStackTrie(nil) for i, key := range keys { - stackTr.Update(key, vals[i]) + if err := stackTr.Update(key, vals[i]); err != nil { + return nil, err + } } if gotRoot := stackTr.Hash(); gotRoot != root { return &proofResult{ diff --git a/trie/stacktrie.go b/trie/stacktrie.go index 423afdec8..f2f5355c4 100644 --- a/trie/stacktrie.go +++ b/trie/stacktrie.go @@ -18,6 +18,7 @@ package trie import ( "bytes" + "errors" "sync" "github.com/ethereum/go-ethereum/common" @@ -92,12 +93,14 @@ func NewStackTrie(options *StackTrieOptions) *StackTrie { // Update inserts a (key, value) pair into the stack trie. func (t *StackTrie) Update(key, value []byte) error { - k := keybytesToHex(key) if len(value) == 0 { - panic("deletion not supported") + return errors.New("trying to insert empty (deletion)") } + k := keybytesToHex(key) k = k[:len(k)-1] // chop the termination flag - + if bytes.Compare(t.last, k) >= 0 { + return errors.New("non-ascending key order") + } // track the first and last inserted entries. if t.first == nil { t.first = append([]byte{}, k...) diff --git a/trie/stacktrie_test.go b/trie/stacktrie_test.go index 629586e2b..909a77062 100644 --- a/trie/stacktrie_test.go +++ b/trie/stacktrie_test.go @@ -26,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/trie/testutil" + "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" ) @@ -463,3 +464,24 @@ func TestPartialStackTrie(t *testing.T) { } } } + +func TestStackTrieErrors(t *testing.T) { + s := NewStackTrie(nil) + // Deletion + if err := s.Update(nil, nil); err == nil { + t.Fatal("expected error") + } + if err := s.Update(nil, []byte{}); err == nil { + t.Fatal("expected error") + } + if err := s.Update([]byte{0xa}, []byte{}); err == nil { + t.Fatal("expected error") + } + // Non-ascending keys (going backwards or repeating) + assert.Nil(t, s.Update([]byte{0xaa}, []byte{0xa})) + assert.NotNil(t, s.Update([]byte{0xaa}, []byte{0xa}), "repeat insert same key") + assert.NotNil(t, s.Update([]byte{0xaa}, []byte{0xb}), "repeat insert same key") + assert.Nil(t, s.Update([]byte{0xab}, []byte{0xa})) + assert.NotNil(t, s.Update([]byte{0x10}, []byte{0xb}), "out of order insert") + assert.NotNil(t, s.Update([]byte{0xaa}, []byte{0xb}), "repeat insert same key") +}