From c3a77d626831b4ffe37ed4f8640e67e70ad5b220 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Thu, 29 Sep 2016 16:51:32 +0200 Subject: [PATCH] trie: fix delete bug for values contained in fullNode Delete crashed if a fullNode contained a valueNode directly. This bug is very unlikely to occur with SecureTrie, but can happen with regular tries. This commit also introduces a randomised test which triggers all trie operations, which should prevent such bugs in the future. Credit for finding this bug goes to Github user @rjl493456442. --- trie/trie.go | 3 + trie/trie_test.go | 159 ++++++++++++++++++++++++++++++++++++---------- 2 files changed, 127 insertions(+), 35 deletions(-) diff --git a/trie/trie.go b/trie/trie.go index 55481f4f7..55598af98 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -377,6 +377,9 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { // n still contains at least two values and cannot be reduced. return true, n, nil + case valueNode: + return true, nil, nil + case nil: return false, nil, nil diff --git a/trie/trie_test.go b/trie/trie_test.go index 5a3ea1be9..87a7ec258 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -21,8 +21,11 @@ import ( "encoding/binary" "fmt" "io/ioutil" + "math/rand" "os" + "reflect" "testing" + "testing/quick" "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/common" @@ -297,41 +300,6 @@ func TestReplication(t *testing.T) { } } -func paranoiaCheck(t1 *Trie) (bool, *Trie) { - t2 := new(Trie) - it := NewIterator(t1) - for it.Next() { - t2.Update(it.Key, it.Value) - } - return t2.Hash() == t1.Hash(), t2 -} - -func TestParanoia(t *testing.T) { - t.Skip() - trie := newEmpty() - - vals := []struct{ k, v string }{ - {"do", "verb"}, - {"ether", "wookiedoo"}, - {"horse", "stallion"}, - {"shaman", "horse"}, - {"doge", "coin"}, - {"ether", ""}, - {"dog", "puppy"}, - {"shaman", ""}, - {"somethingveryoddindeedthis is", "myothernodedata"}, - } - for _, val := range vals { - updateString(trie, val.k, val.v) - } - trie.Commit() - - ok, t2 := paranoiaCheck(trie) - if !ok { - t.Errorf("trie paranoia check failed %x %x", trie.Hash(), t2.Hash()) - } -} - // Not an actual test func TestOutput(t *testing.T) { t.Skip() @@ -356,7 +324,128 @@ func TestLargeValue(t *testing.T) { trie.Update([]byte("key1"), []byte{99, 99, 99, 99}) trie.Update([]byte("key2"), bytes.Repeat([]byte{1}, 32)) trie.Hash() +} +type randTestStep struct { + op int + key []byte // for opUpdate, opDelete, opGet + value []byte // for opUpdate +} + +type randTest []randTestStep + +const ( + opUpdate = iota + opDelete + opGet + opCommit + opHash + opReset + opItercheckhash + opMax // boundary value, not an actual op +) + +func (randTest) Generate(r *rand.Rand, size int) reflect.Value { + var allKeys [][]byte + genKey := func() []byte { + if len(allKeys) < 2 || r.Intn(100) < 10 { + // new key + key := make([]byte, r.Intn(50)) + randRead(r, key) + allKeys = append(allKeys, key) + return key + } + // use existing key + return allKeys[r.Intn(len(allKeys))] + } + + var steps randTest + for i := 0; i < size; i++ { + step := randTestStep{op: r.Intn(opMax)} + switch step.op { + case opUpdate: + step.key = genKey() + step.value = make([]byte, 8) + binary.BigEndian.PutUint64(step.value, uint64(i)) + case opGet, opDelete: + step.key = genKey() + } + steps = append(steps, step) + } + return reflect.ValueOf(steps) +} + +// rand.Rand provides a Read method in Go 1.7 and later, but +// we can't use it yet. +func randRead(r *rand.Rand, b []byte) { + pos := 0 + val := 0 + for n := 0; n < len(b); n++ { + if pos == 0 { + val = r.Int() + pos = 7 + } + b[n] = byte(val) + val >>= 8 + pos-- + } +} + +func runRandTest(rt randTest) bool { + db, _ := ethdb.NewMemDatabase() + tr, _ := New(common.Hash{}, db) + values := make(map[string]string) // tracks content of the trie + + for _, step := range rt { + switch step.op { + case opUpdate: + tr.Update(step.key, step.value) + values[string(step.key)] = string(step.value) + case opDelete: + tr.Delete(step.key) + delete(values, string(step.key)) + case opGet: + v := tr.Get(step.key) + want := values[string(step.key)] + if string(v) != want { + fmt.Printf("mismatch for key 0x%x, got 0x%x want 0x%x", step.key, v, want) + return false + } + case opCommit: + if _, err := tr.Commit(); err != nil { + panic(err) + } + case opHash: + tr.Hash() + case opReset: + hash, err := tr.Commit() + if err != nil { + panic(err) + } + newtr, err := New(hash, db) + if err != nil { + panic(err) + } + tr = newtr + case opItercheckhash: + checktr, _ := New(common.Hash{}, nil) + it := tr.Iterator() + for it.Next() { + checktr.Update(it.Key, it.Value) + } + if tr.Hash() != checktr.Hash() { + fmt.Println("hashes not equal") + return false + } + } + } + return true +} + +func TestRandom(t *testing.T) { + if err := quick.Check(runRandTest, nil); err != nil { + t.Fatal(err) + } } func BenchmarkGet(b *testing.B) { benchGet(b, false) }