diff --git a/core/state/state_object.go b/core/state/state_object.go index 121a2ec5c..cbd50e2a3 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -75,9 +75,11 @@ type StateObject struct { dbErr error // Write caches. - trie *trie.SecureTrie // storage trie, which becomes non-nil on first access - code Code // contract bytecode, which gets set when code is loaded - storage Storage // Cached storage (flushed when updated) + trie *trie.SecureTrie // storage trie, which becomes non-nil on first access + code Code // contract bytecode, which gets set when code is loaded + + cachedStorage Storage // Storage entry cache to avoid duplicate reads + dirtyStorage Storage // Storage entries that need to be flushed to disk // Cache flags. // When an object is marked for deletion it will be delete from the trie @@ -105,7 +107,7 @@ func NewObject(address common.Address, data Account, onDirty func(addr common.Ad if data.CodeHash == nil { data.CodeHash = emptyCodeHash } - return &StateObject{address: address, data: data, storage: make(Storage), onDirty: onDirty} + return &StateObject{address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty} } // EncodeRLP implements rlp.Encoder. @@ -145,7 +147,7 @@ func (c *StateObject) getTrie(db trie.Database) *trie.SecureTrie { // GetState returns a value in account storage. func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash { - value, exists := self.storage[key] + value, exists := self.cachedStorage[key] if exists { return value } @@ -155,14 +157,16 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash rlp.DecodeBytes(tr.Get(key[:]), &ret) value = common.BytesToHash(ret) if (value != common.Hash{}) { - self.storage[key] = value + self.cachedStorage[key] = value } return value } // SetState updates a value in account storage. func (self *StateObject) SetState(key, value common.Hash) { - self.storage[key] = value + self.cachedStorage[key] = value + self.dirtyStorage[key] = value + if self.onDirty != nil { self.onDirty(self.Address()) self.onDirty = nil @@ -172,7 +176,8 @@ func (self *StateObject) SetState(key, value common.Hash) { // updateTrie writes cached storage modifications into the object's storage trie. func (self *StateObject) updateTrie(db trie.Database) { tr := self.getTrie(db) - for key, value := range self.storage { + for key, value := range self.dirtyStorage { + delete(self.dirtyStorage, key) if (value == common.Hash{}) { tr.Delete(key[:]) continue @@ -241,7 +246,8 @@ func (self *StateObject) Copy(db trie.Database, onDirty func(addr common.Address stateObject := NewObject(self.address, self.data, onDirty) stateObject.trie = self.trie stateObject.code = self.code - stateObject.storage = self.storage.Copy() + stateObject.dirtyStorage = self.dirtyStorage.Copy() + stateObject.cachedStorage = self.dirtyStorage.Copy() stateObject.remove = self.remove stateObject.dirtyCode = self.dirtyCode stateObject.deleted = self.deleted @@ -312,7 +318,7 @@ func (self *StateObject) Value() *big.Int { func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) { // When iterating over the storage check the cache first - for h, value := range self.storage { + for h, value := range self.cachedStorage { cb(h, value) } @@ -320,7 +326,7 @@ func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) { for it.Next() { // ignore cached values key := common.BytesToHash(self.trie.GetKey(it.Key)) - if _, ok := self.storage[key]; !ok { + if _, ok := self.cachedStorage[key]; !ok { cb(key, common.BytesToHash(it.Value)) } } diff --git a/core/state/state_test.go b/core/state/state_test.go index 5fe98939b..7b9b39e06 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -208,16 +208,16 @@ func compareStateObjects(so0, so1 *StateObject, t *testing.T) { t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code) } - if len(so1.storage) != len(so0.storage) { - t.Errorf("Storage size mismatch: have %d, want %d", len(so1.storage), len(so0.storage)) + if len(so1.cachedStorage) != len(so0.cachedStorage) { + t.Errorf("Storage size mismatch: have %d, want %d", len(so1.cachedStorage), len(so0.cachedStorage)) } - for k, v := range so1.storage { - if so0.storage[k] != v { - t.Errorf("Storage key %x mismatch: have %v, want %v", k, so0.storage[k], v) + for k, v := range so1.cachedStorage { + if so0.cachedStorage[k] != v { + t.Errorf("Storage key %x mismatch: have %v, want %v", k, so0.cachedStorage[k], v) } } - for k, v := range so0.storage { - if so1.storage[k] != v { + for k, v := range so0.cachedStorage { + if so1.cachedStorage[k] != v { t.Errorf("Storage key %x mismatch: have %v, want none.", k, v) } }