diff --git a/trie/fullnode.go b/trie/fullnode.go index 522fdb373..95c8eb4be 100644 --- a/trie/fullnode.go +++ b/trie/fullnode.go @@ -1,17 +1,16 @@ package trie -import "fmt" - type FullNode struct { trie *Trie nodes [17]Node + dirty bool } func NewFullNode(t *Trie) *FullNode { return &FullNode{trie: t} } -func (self *FullNode) Dirty() bool { return true } +func (self *FullNode) Dirty() bool { return self.dirty } func (self *FullNode) Value() Node { self.nodes[16] = self.trie.trans(self.nodes[16]) return self.nodes[16] @@ -27,6 +26,7 @@ func (self *FullNode) Copy(t *Trie) Node { nnode.nodes[i] = node.Copy(t) } } + nnode.dirty = true return nnode } @@ -60,11 +60,8 @@ func (self *FullNode) RlpData() interface{} { } func (self *FullNode) set(k byte, value Node) { - if _, ok := value.(*ValueNode); ok && k != 16 { - fmt.Println(value, k) - } - self.nodes[int(k)] = value + self.dirty = true } func (self *FullNode) branch(i byte) Node { @@ -75,3 +72,7 @@ func (self *FullNode) branch(i byte) Node { } return nil } + +func (self *FullNode) setDirty(dirty bool) { + self.dirty = dirty +} diff --git a/trie/hashnode.go b/trie/hashnode.go index 8125cc3c9..e82ab8069 100644 --- a/trie/hashnode.go +++ b/trie/hashnode.go @@ -3,12 +3,13 @@ package trie import "github.com/ethereum/go-ethereum/common" type HashNode struct { - key []byte - trie *Trie + key []byte + trie *Trie + dirty bool } func NewHash(key []byte, trie *Trie) *HashNode { - return &HashNode{key, trie} + return &HashNode{key, trie, false} } func (self *HashNode) RlpData() interface{} { @@ -19,6 +20,10 @@ func (self *HashNode) Hash() interface{} { return self.key } +func (self *HashNode) setDirty(dirty bool) { + self.dirty = dirty +} + // These methods will never be called but we have to satisfy Node interface func (self *HashNode) Value() Node { return nil } func (self *HashNode) Dirty() bool { return true } diff --git a/trie/node.go b/trie/node.go index 0d8a7cff9..dccbc64a3 100644 --- a/trie/node.go +++ b/trie/node.go @@ -11,6 +11,7 @@ type Node interface { fstring(string) string Hash() interface{} RlpData() interface{} + setDirty(dirty bool) } // Value node diff --git a/trie/shortnode.go b/trie/shortnode.go index edd490b4d..c86e50096 100644 --- a/trie/shortnode.go +++ b/trie/shortnode.go @@ -6,20 +6,22 @@ type ShortNode struct { trie *Trie key []byte value Node + dirty bool } func NewShortNode(t *Trie, key []byte, value Node) *ShortNode { - return &ShortNode{t, []byte(CompactEncode(key)), value} + return &ShortNode{t, []byte(CompactEncode(key)), value, false} } func (self *ShortNode) Value() Node { self.value = self.trie.trans(self.value) return self.value } -func (self *ShortNode) Dirty() bool { return true } +func (self *ShortNode) Dirty() bool { return self.dirty } func (self *ShortNode) Copy(t *Trie) Node { - node := &ShortNode{t, nil, self.value.Copy(t)} + node := &ShortNode{t, nil, self.value.Copy(t), self.dirty} node.key = common.CopyBytes(self.key) + node.dirty = true return node } @@ -33,3 +35,7 @@ func (self *ShortNode) Hash() interface{} { func (self *ShortNode) Key() []byte { return CompactDecode(string(self.key)) } + +func (self *ShortNode) setDirty(dirty bool) { + self.dirty = dirty +} diff --git a/trie/trie.go b/trie/trie.go index d990338ee..7e17baa2f 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -117,7 +117,9 @@ func (self *Trie) Update(key, value []byte) Node { k := CompactHexDecode(string(key)) if len(value) != 0 { - self.root = self.insert(self.root, k, &ValueNode{self, value}) + node := NewValueNode(self, value) + node.dirty = true + self.root = self.insert(self.root, k, node) } else { self.root = self.delete(self.root, k) } @@ -157,7 +159,9 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node { } if node == nil { - return NewShortNode(self, key, value) + node := NewShortNode(self, key, value) + node.dirty = true + return node } switch node := node.(type) { @@ -165,7 +169,10 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node { k := node.Key() cnode := node.Value() if bytes.Equal(k, key) { - return NewShortNode(self, key, value) + node := NewShortNode(self, key, value) + node.dirty = true + return node + } var n Node @@ -176,6 +183,7 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node { pnode := self.insert(nil, k[matchlength+1:], cnode) nnode := self.insert(nil, key[matchlength+1:], value) fulln := NewFullNode(self) + fulln.dirty = true fulln.set(k[matchlength], pnode) fulln.set(key[matchlength], nnode) n = fulln @@ -184,11 +192,14 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node { return n } - return NewShortNode(self, key[:matchlength], n) + snode := NewShortNode(self, key[:matchlength], n) + snode.dirty = true + return snode case *FullNode: cpy := node.Copy(self).(*FullNode) cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value)) + cpy.dirty = true return cpy @@ -242,8 +253,10 @@ func (self *Trie) delete(node Node, key []byte) Node { case *ShortNode: nkey := append(k, child.Key()...) n = NewShortNode(self, nkey, child.Value()) + n.(*ShortNode).dirty = true case *FullNode: sn := NewShortNode(self, node.Key(), child) + sn.dirty = true sn.key = node.key n = sn } @@ -256,6 +269,7 @@ func (self *Trie) delete(node Node, key []byte) Node { case *FullNode: n := node.Copy(self).(*FullNode) n.set(key[0], self.delete(n.branch(key[0]), key[1:])) + n.dirty = true pos := -1 for i := 0; i < 17; i++ { @@ -271,6 +285,7 @@ func (self *Trie) delete(node Node, key []byte) Node { var nnode Node if pos == 16 { nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos))) + nnode.(*ShortNode).dirty = true } else if pos >= 0 { cnode := n.branch(byte(pos)) switch cnode := cnode.(type) { @@ -278,8 +293,10 @@ func (self *Trie) delete(node Node, key []byte) Node { // Stitch keys k := append([]byte{byte(pos)}, cnode.Key()...) nnode = NewShortNode(self, k, cnode.Value()) + nnode.(*ShortNode).dirty = true case *FullNode: nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos))) + nnode.(*ShortNode).dirty = true } } else { nnode = n @@ -304,7 +321,7 @@ func (self *Trie) mknode(value *common.Value) Node { if value.Get(0).Len() != 0 { key := CompactDecode(string(value.Get(0).Bytes())) if key[len(key)-1] == 16 { - return NewShortNode(self, key, &ValueNode{self, value.Get(1).Bytes()}) + return NewShortNode(self, key, NewValueNode(self, value.Get(1).Bytes())) } else { return NewShortNode(self, key, self.mknode(value.Get(1))) } @@ -318,10 +335,10 @@ func (self *Trie) mknode(value *common.Value) Node { return fnode } case 32: - return &HashNode{value.Bytes(), self} + return NewHash(value.Bytes(), self) } - return &ValueNode{self, value.Bytes()} + return NewValueNode(self, value.Bytes()) } func (self *Trie) trans(node Node) Node { @@ -338,7 +355,11 @@ func (self *Trie) store(node Node) interface{} { data := common.Encode(node) if len(data) >= 32 { key := crypto.Sha3(data) - self.cache.Put(key, data) + if node.Dirty() { + //fmt.Println("save", node) + //fmt.Println() + self.cache.Put(key, data) + } return key } diff --git a/trie/trie_test.go b/trie/trie_test.go index 9a58958d8..60f0873a8 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -152,7 +152,7 @@ func TestReplication(t *testing.T) { } trie.Commit() - trie2 := New(trie.roothash, trie.cache.backend) + trie2 := New(trie.Root(), trie.cache.backend) if string(trie2.GetString("horse")) != "stallion" { t.Error("expected to have horse => stallion") } diff --git a/trie/valuenode.go b/trie/valuenode.go index 7bf8ff06e..6adb59652 100644 --- a/trie/valuenode.go +++ b/trie/valuenode.go @@ -3,13 +3,24 @@ package trie import "github.com/ethereum/go-ethereum/common" type ValueNode struct { - trie *Trie - data []byte + trie *Trie + data []byte + dirty bool } -func (self *ValueNode) Value() Node { return self } // Best not to call :-) -func (self *ValueNode) Val() []byte { return self.data } -func (self *ValueNode) Dirty() bool { return true } -func (self *ValueNode) Copy(t *Trie) Node { return &ValueNode{t, common.CopyBytes(self.data)} } +func NewValueNode(trie *Trie, data []byte) *ValueNode { + return &ValueNode{trie, data, false} +} + +func (self *ValueNode) Value() Node { return self } // Best not to call :-) +func (self *ValueNode) Val() []byte { return self.data } +func (self *ValueNode) Dirty() bool { return self.dirty } +func (self *ValueNode) Copy(t *Trie) Node { + return &ValueNode{t, common.CopyBytes(self.data), self.dirty} +} func (self *ValueNode) RlpData() interface{} { return self.data } func (self *ValueNode) Hash() interface{} { return self.data } + +func (self *ValueNode) setDirty(dirty bool) { + self.dirty = dirty +}