compilable trie (tests fail)

This commit is contained in:
obscuren 2015-03-16 16:28:16 +01:00
parent 20b7162a62
commit d338650089
6 changed files with 92 additions and 73 deletions

View File

@ -27,27 +27,27 @@ func StringToAddress(s string) Address { return BytesToAddress([]byte(s)) }
// Don't use the default 'String' method in case we want to overwrite // Don't use the default 'String' method in case we want to overwrite
// Get the string representation of the underlying hash // Get the string representation of the underlying hash
func (h Hash) Str() string { func (h *Hash) Str() string {
return string(h[:]) return string(h[:])
} }
// Sets the hash to the value of b. If b is larger than len(h) it will panic // Sets the hash to the value of b. If b is larger than len(h) it will panic
func (h Hash) SetBytes(b []byte) { func (h *Hash) SetBytes(b []byte) {
if len(b) > len(h) { if len(b) > len(h) {
panic("unable to set bytes. too big") panic("unable to set bytes. too big")
} }
// reverse loop // reverse loop
for i := len(b); i >= 0; i-- { for i := len(b) - 1; i >= 0; i-- {
h[i] = b[i] h[i] = b[i]
} }
} }
// Set string `s` to h. If s is larger than len(h) it will panic // Set string `s` to h. If s is larger than len(h) it will panic
func (h Hash) SetString(s string) { h.SetBytes([]byte(s)) } func (h *Hash) SetString(s string) { h.SetBytes([]byte(s)) }
// Sets h to other // Sets h to other
func (h Hash) Set(other Hash) { func (h *Hash) Set(other Hash) {
for i, v := range other { for i, v := range other {
h[i] = v h[i] = v
} }

View File

@ -2,17 +2,19 @@ package trie
import ( import (
"bytes" "bytes"
"github.com/ethereum/go-ethereum/common"
) )
type Iterator struct { type Iterator struct {
trie *Trie trie *Trie
Key []byte Key common.Hash
Value []byte Value []byte
} }
func NewIterator(trie *Trie) *Iterator { func NewIterator(trie *Trie) *Iterator {
return &Iterator{trie: trie, Key: nil} return &Iterator{trie: trie}
} }
func (self *Iterator) Next() bool { func (self *Iterator) Next() bool {
@ -20,15 +22,15 @@ func (self *Iterator) Next() bool {
defer self.trie.mu.Unlock() defer self.trie.mu.Unlock()
isIterStart := false isIterStart := false
if self.Key == nil { if (self.Key == common.Hash{}) {
isIterStart = true isIterStart = true
self.Key = make([]byte, 32) //self.Key = make([]byte, 32)
} }
key := RemTerm(CompactHexDecode(string(self.Key))) key := RemTerm(CompactHexDecode(self.Key.Str()))
k := self.next(self.trie.root, key, isIterStart) k := self.next(self.trie.root, key, isIterStart)
self.Key = []byte(DecodeCompact(k)) self.Key = common.StringToHash(DecodeCompact(k))
return len(k) > 0 return len(k) > 0
} }

View File

@ -22,7 +22,7 @@ func TestIterator(t *testing.T) {
it := trie.Iterator() it := trie.Iterator()
for it.Next() { for it.Next() {
v[string(it.Key)] = true v[it.Key.Str()] = true
} }
for k, found := range v { for k, found := range v {

View File

@ -1,34 +1,38 @@
package trie package trie
import "github.com/ethereum/go-ethereum/crypto" import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
)
type SecureTrie struct { type SecureTrie struct {
*Trie *Trie
} }
func NewSecure(root []byte, backend Backend) *SecureTrie { func NewSecure(root common.Hash, backend Backend) *SecureTrie {
return &SecureTrie{New(root, backend)} return &SecureTrie{New(root, backend)}
} }
func (self *SecureTrie) Update(key, value []byte) Node { func (self *SecureTrie) Update(key common.Hash, value []byte) Node {
return self.Trie.Update(crypto.Sha3(key), value) return self.Trie.Update(common.BytesToHash(crypto.Sha3(key[:])), value)
}
func (self *SecureTrie) UpdateString(key, value string) Node {
return self.Update([]byte(key), []byte(value))
} }
func (self *SecureTrie) Get(key []byte) []byte { func (self *SecureTrie) UpdateString(key, value string) Node {
return self.Trie.Get(crypto.Sha3(key)) return self.Update(common.StringToHash(key), []byte(value))
}
func (self *SecureTrie) Get(key common.Hash) []byte {
return self.Trie.Get(common.BytesToHash(crypto.Sha3(key[:])))
} }
func (self *SecureTrie) GetString(key string) []byte { func (self *SecureTrie) GetString(key string) []byte {
return self.Get([]byte(key)) return self.Get(common.StringToHash(key))
} }
func (self *SecureTrie) Delete(key []byte) Node { func (self *SecureTrie) Delete(key common.Hash) Node {
return self.Trie.Delete(crypto.Sha3(key)) return self.Trie.Delete(common.BytesToHash(crypto.Sha3(key[:])))
} }
func (self *SecureTrie) DeleteString(key string) Node { func (self *SecureTrie) DeleteString(key string) Node {
return self.Delete([]byte(key)) return self.Delete(common.StringToHash(key))
} }
func (self *SecureTrie) Copy() *SecureTrie { func (self *SecureTrie) Copy() *SecureTrie {

View File

@ -11,14 +11,15 @@ import (
) )
func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) { func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) {
t2 := New(nil, backend) t2 := New(common.Hash{}, backend)
it := t1.Iterator() it := t1.Iterator()
for it.Next() { for it.Next() {
t2.Update(it.Key, it.Value) t2.Update(it.Key, it.Value)
} }
return bytes.Equal(t2.Hash(), t1.Hash()), t2 a, b := t2.Hash(), t1.Hash()
return bytes.Equal(a[:], b[:]), t2
} }
type Trie struct { type Trie struct {
@ -38,8 +39,8 @@ func New(root common.Hash, backend Backend) *Trie {
trie.cache = NewCache(backend) trie.cache = NewCache(backend)
} }
if root != nil { if (root != common.Hash{}) {
value := common.NewValueFromBytes(trie.cache.Get(root)) value := common.NewValueFromBytes(trie.cache.Get(root[:]))
trie.root = trie.mknode(value) trie.root = trie.mknode(value)
} }
@ -51,12 +52,13 @@ func (self *Trie) Iterator() *Iterator {
} }
func (self *Trie) Copy() *Trie { func (self *Trie) Copy() *Trie {
//cpy := make([]byte, 32)
//copy(cpy, self.roothash)
// cheap copying method // cheap copying method
var cpy common.Hash var cpy common.Hash
cpy.Set(self.roothash[:]) cpy.Set(self.roothash)
cpy := make([]byte, 32) trie := New(common.Hash{}, nil)
copy(cpy, self.roothash)
trie := New(nil, nil)
trie.cache = self.cache.Copy() trie.cache = self.cache.Copy()
if self.root != nil { if self.root != nil {
trie.root = self.root.Copy(trie) trie.root = self.root.Copy(trie)
@ -66,21 +68,21 @@ func (self *Trie) Copy() *Trie {
} }
// Legacy support // Legacy support
func (self *Trie) Root() []byte { return self.Hash() } func (self *Trie) Root() common.Hash { return self.Hash() }
func (self *Trie) Hash() []byte { func (self *Trie) Hash() common.Hash {
var hash []byte var hash common.Hash
if self.root != nil { if self.root != nil {
t := self.root.Hash() t := self.root.Hash()
if byts, ok := t.([]byte); ok && len(byts) > 0 { if h, ok := t.(common.Hash); ok && (h != common.Hash{}) {
hash = byts hash = h
} else { } else {
hash = crypto.Sha3(common.Encode(self.root.RlpData())) hash = common.BytesToHash(crypto.Sha3(common.Encode(self.root.RlpData())))
} }
} else { } else {
hash = crypto.Sha3(common.Encode("")) hash = common.BytesToHash(crypto.Sha3(common.Encode("")))
} }
if !bytes.Equal(hash, self.roothash) { if hash != self.roothash {
self.revisions.PushBack(self.roothash) self.revisions.PushBack(self.roothash)
self.roothash = hash self.roothash = hash
} }
@ -105,19 +107,21 @@ func (self *Trie) Reset() {
self.cache.Reset() self.cache.Reset()
if self.revisions.Len() > 0 { if self.revisions.Len() > 0 {
revision := self.revisions.Remove(self.revisions.Back()).([]byte) revision := self.revisions.Remove(self.revisions.Back()).(common.Hash)
self.roothash = revision self.roothash = revision
} }
value := common.NewValueFromBytes(self.cache.Get(self.roothash)) value := common.NewValueFromBytes(self.cache.Get(self.roothash[:]))
self.root = self.mknode(value) self.root = self.mknode(value)
} }
func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) } func (self *Trie) UpdateString(key, value string) Node {
func (self *Trie) Update(key, value []byte) Node { return self.Update(common.StringToHash(key), []byte(value))
}
func (self *Trie) Update(key common.Hash, value []byte) Node {
self.mu.Lock() self.mu.Lock()
defer self.mu.Unlock() defer self.mu.Unlock()
k := CompactHexDecode(string(key)) k := CompactHexDecode(key.Str())
if len(value) != 0 { if len(value) != 0 {
self.root = self.insert(self.root, k, &ValueNode{self, value}) self.root = self.insert(self.root, k, &ValueNode{self, value})
@ -128,12 +132,12 @@ func (self *Trie) Update(key, value []byte) Node {
return self.root return self.root
} }
func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) } func (self *Trie) GetString(key string) []byte { return self.Get(common.StringToHash(key)) }
func (self *Trie) Get(key []byte) []byte { func (self *Trie) Get(key common.Hash) []byte {
self.mu.Lock() self.mu.Lock()
defer self.mu.Unlock() defer self.mu.Unlock()
k := CompactHexDecode(string(key)) k := CompactHexDecode(key.Str())
n := self.get(self.root, k) n := self.get(self.root, k)
if n != nil { if n != nil {
@ -143,12 +147,12 @@ func (self *Trie) Get(key []byte) []byte {
return nil return nil
} }
func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) } func (self *Trie) DeleteString(key string) Node { return self.Delete(common.StringToHash(key)) }
func (self *Trie) Delete(key []byte) Node { func (self *Trie) Delete(key common.Hash) Node {
self.mu.Lock() self.mu.Lock()
defer self.mu.Unlock() defer self.mu.Unlock()
k := CompactHexDecode(string(key)) k := CompactHexDecode(key.Str())
self.root = self.delete(self.root, k) self.root = self.delete(self.root, k)
return self.root return self.root

View File

@ -5,8 +5,8 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
) )
type Db map[string][]byte type Db map[string][]byte
@ -16,18 +16,18 @@ func (self Db) Put(k, v []byte) { self[string(k)] = v }
// Used for testing // Used for testing
func NewEmpty() *Trie { func NewEmpty() *Trie {
return New(nil, make(Db)) return New(common.Hash{}, make(Db))
} }
func NewEmptySecure() *SecureTrie { func NewEmptySecure() *SecureTrie {
return NewSecure(nil, make(Db)) return NewSecure(common.Hash{}, make(Db))
} }
func TestEmptyTrie(t *testing.T) { func TestEmptyTrie(t *testing.T) {
trie := NewEmpty() trie := NewEmpty()
res := trie.Hash() res := trie.Hash()
exp := crypto.Sha3(common.Encode("")) exp := crypto.Sha3(common.Encode(""))
if !bytes.Equal(res, exp) { if !bytes.Equal(res[:], exp[:]) {
t.Errorf("expected %x got %x", exp, res) t.Errorf("expected %x got %x", exp, res)
} }
} }
@ -41,7 +41,7 @@ func TestInsert(t *testing.T) {
exp := common.Hex2Bytes("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3") exp := common.Hex2Bytes("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3")
root := trie.Hash() root := trie.Hash()
if !bytes.Equal(root, exp) { if !bytes.Equal(root[:], exp[:]) {
t.Errorf("exp %x got %x", exp, root) t.Errorf("exp %x got %x", exp, root)
} }
@ -50,7 +50,7 @@ func TestInsert(t *testing.T) {
exp = common.Hex2Bytes("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab") exp = common.Hex2Bytes("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
root = trie.Hash() root = trie.Hash()
if !bytes.Equal(root, exp) { if !bytes.Equal(root[:], exp) {
t.Errorf("exp %x got %x", exp, root) t.Errorf("exp %x got %x", exp, root)
} }
} }
@ -96,7 +96,7 @@ func TestDelete(t *testing.T) {
hash := trie.Hash() hash := trie.Hash()
exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
if !bytes.Equal(hash, exp) { if !bytes.Equal(hash[:], exp) {
t.Errorf("expected %x got %x", exp, hash) t.Errorf("expected %x got %x", exp, hash)
} }
} }
@ -120,7 +120,7 @@ func TestEmptyValues(t *testing.T) {
hash := trie.Hash() hash := trie.Hash()
exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
if !bytes.Equal(hash, exp) { if !bytes.Equal(hash[:], exp) {
t.Errorf("expected %x got %x", exp, hash) t.Errorf("expected %x got %x", exp, hash)
} }
} }
@ -150,7 +150,7 @@ func TestReplication(t *testing.T) {
hash := trie2.Hash() hash := trie2.Hash()
exp := trie.Hash() exp := trie.Hash()
if !bytes.Equal(hash, exp) { if !bytes.Equal(hash[:], exp[:]) {
t.Errorf("root failure. expected %x got %x", exp, hash) t.Errorf("root failure. expected %x got %x", exp, hash)
} }
@ -168,7 +168,9 @@ func TestReset(t *testing.T) {
} }
trie.Commit() trie.Commit()
before := common.CopyBytes(trie.roothash) var before common.Hash
before.Set(trie.roothash)
trie.UpdateString("should", "revert") trie.UpdateString("should", "revert")
trie.Hash() trie.Hash()
// Should have no effect // Should have no effect
@ -177,9 +179,11 @@ func TestReset(t *testing.T) {
// ### // ###
trie.Reset() trie.Reset()
after := common.CopyBytes(trie.roothash)
if !bytes.Equal(before, after) { var after common.Hash
after.Set(trie.roothash)
if before != after {
t.Errorf("expected roots to be equal. %x - %x", before, after) t.Errorf("expected roots to be equal. %x - %x", before, after)
} }
} }
@ -248,7 +252,7 @@ func BenchmarkGets(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
trie.Get([]byte("horse")) trie.GetString("horse")
} }
} }
@ -263,7 +267,8 @@ func BenchmarkUpdate(b *testing.B) {
} }
type kv struct { type kv struct {
k, v []byte k common.Hash
v []byte
t bool t bool
} }
@ -272,17 +277,21 @@ func TestLargeData(t *testing.T) {
vals := make(map[string]*kv) vals := make(map[string]*kv)
for i := byte(0); i < 255; i++ { for i := byte(0); i < 255; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} var k1 common.Hash
value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false} k1.SetBytes([]byte{i})
var k2 common.Hash
k2.SetBytes([]byte{10, i})
value := &kv{k1, []byte{i}, false}
value2 := &kv{k2, []byte{i}, false}
trie.Update(value.k, value.v) trie.Update(value.k, value.v)
trie.Update(value2.k, value2.v) trie.Update(value2.k, value2.v)
vals[string(value.k)] = value vals[value.k.Str()] = value
vals[string(value2.k)] = value2 vals[value2.k.Str()] = value2
} }
it := trie.Iterator() it := trie.Iterator()
for it.Next() { for it.Next() {
vals[string(it.Key)].t = true vals[it.Key.Str()].t = true
} }
var untouched []*kv var untouched []*kv
@ -323,7 +332,7 @@ func TestSecureDelete(t *testing.T) {
hash := trie.Hash() hash := trie.Hash()
exp := common.Hex2Bytes("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d") exp := common.Hex2Bytes("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d")
if !bytes.Equal(hash, exp) { if !bytes.Equal(hash[:], exp) {
t.Errorf("expected %x got %x", exp, hash) t.Errorf("expected %x got %x", exp, hash)
} }
} }