Merge pull request #1369 from obscuren/statedb-update-cleanup

core, core/state: throw out intermediate state
This commit is contained in:
Jeffrey Wilcke 2015-07-04 03:42:13 -07:00
commit 9c3db1be1d
16 changed files with 106 additions and 49 deletions

View File

@ -77,7 +77,7 @@ func (self *BlockProcessor) ApplyTransaction(coinbase *state.StateObject, stated
} }
// Update the state with pending changes // Update the state with pending changes
statedb.Update() statedb.SyncIntermediate()
usedGas.Add(usedGas, gas) usedGas.Add(usedGas, gas)
receipt := types.NewReceipt(statedb.Root().Bytes(), usedGas) receipt := types.NewReceipt(statedb.Root().Bytes(), usedGas)
@ -243,7 +243,7 @@ func (sm *BlockProcessor) processWithParent(block, parent *types.Block) (logs st
// Commit state objects/accounts to a temporary trie (does not save) // Commit state objects/accounts to a temporary trie (does not save)
// used to calculate the state root. // used to calculate the state root.
state.Update() state.SyncObjects()
if header.Root != state.Root() { if header.Root != state.Root() {
err = fmt.Errorf("invalid merkle root. received=%x got=%x", header.Root, state.Root()) err = fmt.Errorf("invalid merkle root. received=%x got=%x", header.Root, state.Root())
return return

View File

@ -77,7 +77,7 @@ func (b *BlockGen) AddTx(tx *types.Transaction) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
b.statedb.Update() b.statedb.SyncIntermediate()
b.header.GasUsed.Add(b.header.GasUsed, gas) b.header.GasUsed.Add(b.header.GasUsed, gas)
receipt := types.NewReceipt(b.statedb.Root().Bytes(), b.header.GasUsed) receipt := types.NewReceipt(b.statedb.Root().Bytes(), b.header.GasUsed)
logs := b.statedb.GetLogs(tx.Hash()) logs := b.statedb.GetLogs(tx.Hash())
@ -135,7 +135,7 @@ func GenerateChain(parent *types.Block, db common.Database, n int, gen func(int,
gen(i, b) gen(i, b)
} }
AccumulateRewards(statedb, h, b.uncles) AccumulateRewards(statedb, h, b.uncles)
statedb.Update() statedb.SyncIntermediate()
h.Root = statedb.Root() h.Root = statedb.Root()
return types.NewBlock(h, b.txs, b.uncles, b.receipts) return types.NewBlock(h, b.txs, b.uncles, b.receipts)
} }

View File

@ -64,7 +64,7 @@ func GenesisBlockForTesting(db common.Database, addr common.Address, balance *bi
statedb := state.New(common.Hash{}, db) statedb := state.New(common.Hash{}, db)
obj := statedb.GetOrNewStateObject(addr) obj := statedb.GetOrNewStateObject(addr)
obj.SetBalance(balance) obj.SetBalance(balance)
statedb.Update() statedb.SyncObjects()
statedb.Sync() statedb.Sync()
block := types.NewBlock(&types.Header{ block := types.NewBlock(&types.Header{
Difficulty: params.GenesisDifficulty, Difficulty: params.GenesisDifficulty,

View File

@ -57,8 +57,6 @@ type StateObject struct {
initCode Code initCode Code
// Cached storage (flushed when updated) // Cached storage (flushed when updated)
storage Storage storage Storage
// Temporary prepaid gas, reward after transition
prepaid *big.Int
// Total gas pool is the total amount of gas currently // Total gas pool is the total amount of gas currently
// left if this object is the coinbase. Gas is directly // left if this object is the coinbase. Gas is directly
@ -77,14 +75,10 @@ func (self *StateObject) Reset() {
} }
func NewStateObject(address common.Address, db common.Database) *StateObject { func NewStateObject(address common.Address, db common.Database) *StateObject {
// This to ensure that it has 20 bytes (and not 0 bytes), thus left or right pad doesn't matter.
//address := common.ToAddress(addr)
object := &StateObject{db: db, address: address, balance: new(big.Int), gasPool: new(big.Int), dirty: true} object := &StateObject{db: db, address: address, balance: new(big.Int), gasPool: new(big.Int), dirty: true}
object.trie = trie.NewSecure((common.Hash{}).Bytes(), db) object.trie = trie.NewSecure((common.Hash{}).Bytes(), db)
object.storage = make(Storage) object.storage = make(Storage)
object.gasPool = new(big.Int) object.gasPool = new(big.Int)
object.prepaid = new(big.Int)
return object return object
} }
@ -110,7 +104,6 @@ func NewStateObjectFromBytes(address common.Address, data []byte, db common.Data
object.trie = trie.NewSecure(extobject.Root[:], db) object.trie = trie.NewSecure(extobject.Root[:], db)
object.storage = make(map[string]common.Hash) object.storage = make(map[string]common.Hash)
object.gasPool = new(big.Int) object.gasPool = new(big.Int)
object.prepaid = new(big.Int)
object.code, _ = db.Get(extobject.CodeHash) object.code, _ = db.Get(extobject.CodeHash)
return object return object
@ -172,7 +165,6 @@ func (self *StateObject) Update() {
self.setAddr([]byte(key), value) self.setAddr([]byte(key), value)
} }
self.storage = make(Storage)
} }
func (c *StateObject) GetInstr(pc *big.Int) *common.Value { func (c *StateObject) GetInstr(pc *big.Int) *common.Value {

View File

@ -72,7 +72,7 @@ func TestNull(t *testing.T) {
//value := common.FromHex("0x823140710bf13990e4500136726d8b55") //value := common.FromHex("0x823140710bf13990e4500136726d8b55")
var value common.Hash var value common.Hash
state.SetState(address, common.Hash{}, value) state.SetState(address, common.Hash{}, value)
state.Update() state.SyncIntermediate()
state.Sync() state.Sync()
value = state.GetState(address, common.Hash{}) value = state.GetState(address, common.Hash{})
if !common.EmptyHash(value) { if !common.EmptyHash(value) {

View File

@ -18,6 +18,7 @@ import (
type StateDB struct { type StateDB struct {
db common.Database db common.Database
trie *trie.SecureTrie trie *trie.SecureTrie
root common.Hash
stateObjects map[string]*StateObject stateObjects map[string]*StateObject
@ -31,7 +32,7 @@ type StateDB struct {
// Create a new state from a given trie // Create a new state from a given trie
func New(root common.Hash, db common.Database) *StateDB { func New(root common.Hash, db common.Database) *StateDB {
trie := trie.NewSecure(root[:], db) trie := trie.NewSecure(root[:], db)
return &StateDB{db: db, trie: trie, stateObjects: make(map[string]*StateObject), refund: new(big.Int), logs: make(map[common.Hash]Logs)} return &StateDB{root: root, db: db, trie: trie, stateObjects: make(map[string]*StateObject), refund: new(big.Int), logs: make(map[common.Hash]Logs)}
} }
func (self *StateDB) PrintRoot() { func (self *StateDB) PrintRoot() {
@ -185,7 +186,7 @@ func (self *StateDB) DeleteStateObject(stateObject *StateObject) {
addr := stateObject.Address() addr := stateObject.Address()
self.trie.Delete(addr[:]) self.trie.Delete(addr[:])
delete(self.stateObjects, addr.Str()) //delete(self.stateObjects, addr.Str())
} }
// Retrieve a state object given my the address. Nil if not found // Retrieve a state object given my the address. Nil if not found
@ -323,7 +324,8 @@ func (self *StateDB) Refunds() *big.Int {
return self.refund return self.refund
} }
func (self *StateDB) Update() { // SyncIntermediate updates the intermediate state and all mid steps
func (self *StateDB) SyncIntermediate() {
self.refund = new(big.Int) self.refund = new(big.Int)
for _, stateObject := range self.stateObjects { for _, stateObject := range self.stateObjects {
@ -340,6 +342,24 @@ func (self *StateDB) Update() {
} }
} }
// SyncObjects syncs the changed objects to the trie
func (self *StateDB) SyncObjects() {
self.trie = trie.NewSecure(self.root[:], self.db)
self.refund = new(big.Int)
for _, stateObject := range self.stateObjects {
if stateObject.remove {
self.DeleteStateObject(stateObject)
} else {
stateObject.Update()
self.UpdateStateObject(stateObject)
}
stateObject.dirty = false
}
}
// Debug stuff // Debug stuff
func (self *StateDB) CreateOutputForDiff() { func (self *StateDB) CreateOutputForDiff() {
for _, stateObject := range self.stateObjects { for _, stateObject := range self.stateObjects {

View File

@ -453,7 +453,7 @@ func (self *worker) commitNewWork() {
if atomic.LoadInt32(&self.mining) == 1 { if atomic.LoadInt32(&self.mining) == 1 {
// commit state root after all state transitions. // commit state root after all state transitions.
core.AccumulateRewards(self.current.state, header, uncles) core.AccumulateRewards(self.current.state, header, uncles)
current.state.Update() current.state.SyncObjects()
self.current.state.Sync() self.current.state.Sync()
header.Root = current.state.Root() header.Root = current.state.Root()
} }

View File

@ -215,7 +215,7 @@ func (t *BlockTest) InsertPreState(ethereum *eth.Ethereum) (*state.StateDB, erro
} }
} }
// sync objects to trie // sync objects to trie
statedb.Update() statedb.SyncObjects()
// sync trie to disk // sync trie to disk
statedb.Sync() statedb.Sync()

View File

@ -175,7 +175,7 @@ func RunState(statedb *state.StateDB, env, tx map[string]string) ([]byte, state.
if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || state.IsGasLimitErr(err) { if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || state.IsGasLimitErr(err) {
statedb.Set(snapshot) statedb.Set(snapshot)
} }
statedb.Update() statedb.SyncObjects()
return ret, vmenv.state.Logs(), vmenv.Gas, err return ret, vmenv.state.Logs(), vmenv.Gas, err
} }

View File

@ -1,17 +1,16 @@
package trie package trie
import "fmt"
type FullNode struct { type FullNode struct {
trie *Trie trie *Trie
nodes [17]Node nodes [17]Node
dirty bool
} }
func NewFullNode(t *Trie) *FullNode { func NewFullNode(t *Trie) *FullNode {
return &FullNode{trie: t} return &FullNode{trie: t}
} }
func (self *FullNode) Dirty() bool { return true } func (self *FullNode) Dirty() bool { return self.dirty }
func (self *FullNode) Value() Node { func (self *FullNode) Value() Node {
self.nodes[16] = self.trie.trans(self.nodes[16]) self.nodes[16] = self.trie.trans(self.nodes[16])
return self.nodes[16] return self.nodes[16]
@ -24,9 +23,10 @@ func (self *FullNode) Copy(t *Trie) Node {
nnode := NewFullNode(t) nnode := NewFullNode(t)
for i, node := range self.nodes { for i, node := range self.nodes {
if node != nil { if node != nil {
nnode.nodes[i] = node.Copy(t) nnode.nodes[i] = node
} }
} }
nnode.dirty = true
return nnode return nnode
} }
@ -60,11 +60,8 @@ func (self *FullNode) RlpData() interface{} {
} }
func (self *FullNode) set(k byte, value Node) { func (self *FullNode) set(k byte, value Node) {
if _, ok := value.(*ValueNode); ok && k != 16 {
fmt.Println(value, k)
}
self.nodes[int(k)] = value self.nodes[int(k)] = value
self.dirty = true
} }
func (self *FullNode) branch(i byte) Node { func (self *FullNode) branch(i byte) Node {
@ -75,3 +72,7 @@ func (self *FullNode) branch(i byte) Node {
} }
return nil return nil
} }
func (self *FullNode) setDirty(dirty bool) {
self.dirty = dirty
}

View File

@ -3,12 +3,13 @@ package trie
import "github.com/ethereum/go-ethereum/common" import "github.com/ethereum/go-ethereum/common"
type HashNode struct { type HashNode struct {
key []byte key []byte
trie *Trie trie *Trie
dirty bool
} }
func NewHash(key []byte, trie *Trie) *HashNode { func NewHash(key []byte, trie *Trie) *HashNode {
return &HashNode{key, trie} return &HashNode{key, trie, false}
} }
func (self *HashNode) RlpData() interface{} { func (self *HashNode) RlpData() interface{} {
@ -19,6 +20,10 @@ func (self *HashNode) Hash() interface{} {
return self.key return self.key
} }
func (self *HashNode) setDirty(dirty bool) {
self.dirty = dirty
}
// These methods will never be called but we have to satisfy Node interface // These methods will never be called but we have to satisfy Node interface
func (self *HashNode) Value() Node { return nil } func (self *HashNode) Value() Node { return nil }
func (self *HashNode) Dirty() bool { return true } func (self *HashNode) Dirty() bool { return true }

View File

@ -11,6 +11,7 @@ type Node interface {
fstring(string) string fstring(string) string
Hash() interface{} Hash() interface{}
RlpData() interface{} RlpData() interface{}
setDirty(dirty bool)
} }
// Value node // Value node

View File

@ -6,20 +6,22 @@ type ShortNode struct {
trie *Trie trie *Trie
key []byte key []byte
value Node value Node
dirty bool
} }
func NewShortNode(t *Trie, key []byte, value Node) *ShortNode { func NewShortNode(t *Trie, key []byte, value Node) *ShortNode {
return &ShortNode{t, []byte(CompactEncode(key)), value} return &ShortNode{t, []byte(CompactEncode(key)), value, false}
} }
func (self *ShortNode) Value() Node { func (self *ShortNode) Value() Node {
self.value = self.trie.trans(self.value) self.value = self.trie.trans(self.value)
return self.value return self.value
} }
func (self *ShortNode) Dirty() bool { return true } func (self *ShortNode) Dirty() bool { return self.dirty }
func (self *ShortNode) Copy(t *Trie) Node { func (self *ShortNode) Copy(t *Trie) Node {
node := &ShortNode{t, nil, self.value.Copy(t)} node := &ShortNode{t, nil, self.value.Copy(t), self.dirty}
node.key = common.CopyBytes(self.key) node.key = common.CopyBytes(self.key)
node.dirty = true
return node return node
} }
@ -33,3 +35,7 @@ func (self *ShortNode) Hash() interface{} {
func (self *ShortNode) Key() []byte { func (self *ShortNode) Key() []byte {
return CompactDecode(string(self.key)) return CompactDecode(string(self.key))
} }
func (self *ShortNode) setDirty(dirty bool) {
self.dirty = dirty
}

View File

@ -117,7 +117,9 @@ func (self *Trie) Update(key, value []byte) Node {
k := CompactHexDecode(string(key)) k := CompactHexDecode(string(key))
if len(value) != 0 { if len(value) != 0 {
self.root = self.insert(self.root, k, &ValueNode{self, value}) node := NewValueNode(self, value)
node.dirty = true
self.root = self.insert(self.root, k, node)
} else { } else {
self.root = self.delete(self.root, k) self.root = self.delete(self.root, k)
} }
@ -157,7 +159,9 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node {
} }
if node == nil { if node == nil {
return NewShortNode(self, key, value) node := NewShortNode(self, key, value)
node.dirty = true
return node
} }
switch node := node.(type) { switch node := node.(type) {
@ -165,7 +169,10 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node {
k := node.Key() k := node.Key()
cnode := node.Value() cnode := node.Value()
if bytes.Equal(k, key) { if bytes.Equal(k, key) {
return NewShortNode(self, key, value) node := NewShortNode(self, key, value)
node.dirty = true
return node
} }
var n Node var n Node
@ -176,6 +183,7 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node {
pnode := self.insert(nil, k[matchlength+1:], cnode) pnode := self.insert(nil, k[matchlength+1:], cnode)
nnode := self.insert(nil, key[matchlength+1:], value) nnode := self.insert(nil, key[matchlength+1:], value)
fulln := NewFullNode(self) fulln := NewFullNode(self)
fulln.dirty = true
fulln.set(k[matchlength], pnode) fulln.set(k[matchlength], pnode)
fulln.set(key[matchlength], nnode) fulln.set(key[matchlength], nnode)
n = fulln n = fulln
@ -184,11 +192,14 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node {
return n return n
} }
return NewShortNode(self, key[:matchlength], n) snode := NewShortNode(self, key[:matchlength], n)
snode.dirty = true
return snode
case *FullNode: case *FullNode:
cpy := node.Copy(self).(*FullNode) cpy := node.Copy(self).(*FullNode)
cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value)) cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value))
cpy.dirty = true
return cpy return cpy
@ -242,8 +253,10 @@ func (self *Trie) delete(node Node, key []byte) Node {
case *ShortNode: case *ShortNode:
nkey := append(k, child.Key()...) nkey := append(k, child.Key()...)
n = NewShortNode(self, nkey, child.Value()) n = NewShortNode(self, nkey, child.Value())
n.(*ShortNode).dirty = true
case *FullNode: case *FullNode:
sn := NewShortNode(self, node.Key(), child) sn := NewShortNode(self, node.Key(), child)
sn.dirty = true
sn.key = node.key sn.key = node.key
n = sn n = sn
} }
@ -256,6 +269,7 @@ func (self *Trie) delete(node Node, key []byte) Node {
case *FullNode: case *FullNode:
n := node.Copy(self).(*FullNode) n := node.Copy(self).(*FullNode)
n.set(key[0], self.delete(n.branch(key[0]), key[1:])) n.set(key[0], self.delete(n.branch(key[0]), key[1:]))
n.dirty = true
pos := -1 pos := -1
for i := 0; i < 17; i++ { for i := 0; i < 17; i++ {
@ -271,6 +285,7 @@ func (self *Trie) delete(node Node, key []byte) Node {
var nnode Node var nnode Node
if pos == 16 { if pos == 16 {
nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos))) nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos)))
nnode.(*ShortNode).dirty = true
} else if pos >= 0 { } else if pos >= 0 {
cnode := n.branch(byte(pos)) cnode := n.branch(byte(pos))
switch cnode := cnode.(type) { switch cnode := cnode.(type) {
@ -278,8 +293,10 @@ func (self *Trie) delete(node Node, key []byte) Node {
// Stitch keys // Stitch keys
k := append([]byte{byte(pos)}, cnode.Key()...) k := append([]byte{byte(pos)}, cnode.Key()...)
nnode = NewShortNode(self, k, cnode.Value()) nnode = NewShortNode(self, k, cnode.Value())
nnode.(*ShortNode).dirty = true
case *FullNode: case *FullNode:
nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos))) nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos)))
nnode.(*ShortNode).dirty = true
} }
} else { } else {
nnode = n nnode = n
@ -304,7 +321,7 @@ func (self *Trie) mknode(value *common.Value) Node {
if value.Get(0).Len() != 0 { if value.Get(0).Len() != 0 {
key := CompactDecode(string(value.Get(0).Bytes())) key := CompactDecode(string(value.Get(0).Bytes()))
if key[len(key)-1] == 16 { if key[len(key)-1] == 16 {
return NewShortNode(self, key, &ValueNode{self, value.Get(1).Bytes()}) return NewShortNode(self, key, NewValueNode(self, value.Get(1).Bytes()))
} else { } else {
return NewShortNode(self, key, self.mknode(value.Get(1))) return NewShortNode(self, key, self.mknode(value.Get(1)))
} }
@ -318,10 +335,10 @@ func (self *Trie) mknode(value *common.Value) Node {
return fnode return fnode
} }
case 32: case 32:
return &HashNode{value.Bytes(), self} return NewHash(value.Bytes(), self)
} }
return &ValueNode{self, value.Bytes()} return NewValueNode(self, value.Bytes())
} }
func (self *Trie) trans(node Node) Node { func (self *Trie) trans(node Node) Node {
@ -338,7 +355,11 @@ func (self *Trie) store(node Node) interface{} {
data := common.Encode(node) data := common.Encode(node)
if len(data) >= 32 { if len(data) >= 32 {
key := crypto.Sha3(data) key := crypto.Sha3(data)
self.cache.Put(key, data) if node.Dirty() {
//fmt.Println("save", node)
//fmt.Println()
self.cache.Put(key, data)
}
return key return key
} }

View File

@ -152,7 +152,7 @@ func TestReplication(t *testing.T) {
} }
trie.Commit() trie.Commit()
trie2 := New(trie.roothash, trie.cache.backend) trie2 := New(trie.Root(), trie.cache.backend)
if string(trie2.GetString("horse")) != "stallion" { if string(trie2.GetString("horse")) != "stallion" {
t.Error("expected to have horse => stallion") t.Error("expected to have horse => stallion")
} }

View File

@ -3,13 +3,24 @@ package trie
import "github.com/ethereum/go-ethereum/common" import "github.com/ethereum/go-ethereum/common"
type ValueNode struct { type ValueNode struct {
trie *Trie trie *Trie
data []byte data []byte
dirty bool
} }
func (self *ValueNode) Value() Node { return self } // Best not to call :-) func NewValueNode(trie *Trie, data []byte) *ValueNode {
func (self *ValueNode) Val() []byte { return self.data } return &ValueNode{trie, data, false}
func (self *ValueNode) Dirty() bool { return true } }
func (self *ValueNode) Copy(t *Trie) Node { return &ValueNode{t, common.CopyBytes(self.data)} }
func (self *ValueNode) Value() Node { return self } // Best not to call :-)
func (self *ValueNode) Val() []byte { return self.data }
func (self *ValueNode) Dirty() bool { return self.dirty }
func (self *ValueNode) Copy(t *Trie) Node {
return &ValueNode{t, common.CopyBytes(self.data), self.dirty}
}
func (self *ValueNode) RlpData() interface{} { return self.data } func (self *ValueNode) RlpData() interface{} { return self.data }
func (self *ValueNode) Hash() interface{} { return self.data } func (self *ValueNode) Hash() interface{} { return self.data }
func (self *ValueNode) setDirty(dirty bool) {
self.dirty = dirty
}