diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 7e09abb11..74203a468 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -172,8 +172,9 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call ethereum.CallM func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call ethereum.CallMsg) ([]byte, error) { b.mu.Lock() defer b.mu.Unlock() + defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot()) - rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy()) + rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) return rval, err } @@ -197,8 +198,9 @@ func (b *SimulatedBackend) SuggestGasPrice(ctx context.Context) (*big.Int, error func (b *SimulatedBackend) EstimateGas(ctx context.Context, call ethereum.CallMsg) (*big.Int, error) { b.mu.Lock() defer b.mu.Unlock() + defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot()) - _, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy()) + _, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState) return gas, err } diff --git a/cmd/evm/main.go b/cmd/evm/main.go index 09ade1577..22707c1cc 100644 --- a/cmd/evm/main.go +++ b/cmd/evm/main.go @@ -227,22 +227,22 @@ type ruleSet struct{} func (ruleSet) IsHomestead(*big.Int) bool { return true } -func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} } -func (self *VMEnv) Vm() vm.Vm { return self.evm } -func (self *VMEnv) Db() vm.Database { return self.state } -func (self *VMEnv) MakeSnapshot() vm.Database { return self.state.Copy() } -func (self *VMEnv) SetSnapshot(db vm.Database) { self.state.Set(db.(*state.StateDB)) } -func (self *VMEnv) Origin() common.Address { return *self.transactor } -func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 } -func (self *VMEnv) Coinbase() common.Address { return *self.transactor } -func (self *VMEnv) Time() *big.Int { return self.time } -func (self *VMEnv) Difficulty() *big.Int { return common.Big1 } -func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) } -func (self *VMEnv) Value() *big.Int { return self.value } -func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) } -func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy } -func (self *VMEnv) Depth() int { return 0 } -func (self *VMEnv) SetDepth(i int) { self.depth = i } +func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} } +func (self *VMEnv) Vm() vm.Vm { return self.evm } +func (self *VMEnv) Db() vm.Database { return self.state } +func (self *VMEnv) SnapshotDatabase() int { return self.state.Snapshot() } +func (self *VMEnv) RevertToSnapshot(snap int) { self.state.RevertToSnapshot(snap) } +func (self *VMEnv) Origin() common.Address { return *self.transactor } +func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 } +func (self *VMEnv) Coinbase() common.Address { return *self.transactor } +func (self *VMEnv) Time() *big.Int { return self.time } +func (self *VMEnv) Difficulty() *big.Int { return common.Big1 } +func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) } +func (self *VMEnv) Value() *big.Int { return self.value } +func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) } +func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy } +func (self *VMEnv) Depth() int { return 0 } +func (self *VMEnv) SetDepth(i int) { self.depth = i } func (self *VMEnv) GetHash(n uint64) common.Hash { if self.block.Number().Cmp(big.NewInt(int64(n))) == 0 { return self.block.Hash() diff --git a/core/chain_makers.go b/core/chain_makers.go index 0b9a5f75d..e3ad9cda0 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -131,7 +131,7 @@ func (b *BlockGen) AddUncheckedReceipt(receipt *types.Receipt) { // TxNonce returns the next valid transaction nonce for the // account at addr. It panics if the account does not exist. func (b *BlockGen) TxNonce(addr common.Address) uint64 { - if !b.statedb.HasAccount(addr) { + if !b.statedb.Exist(addr) { panic("account does not exist") } return b.statedb.GetNonce(addr) diff --git a/core/execution.go b/core/execution.go index 1bc02f7fb..1cb507ee7 100644 --- a/core/execution.go +++ b/core/execution.go @@ -85,7 +85,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A createAccount = true } - snapshotPreTransfer := env.MakeSnapshot() + snapshotPreTransfer := env.SnapshotDatabase() var ( from = env.Db().GetAccount(caller.Address()) to vm.Account @@ -129,7 +129,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A if err != nil && (env.RuleSet().IsHomestead(env.BlockNumber()) || err != vm.CodeStoreOutOfGasError) { contract.UseGas(contract.Gas) - env.SetSnapshot(snapshotPreTransfer) + env.RevertToSnapshot(snapshotPreTransfer) } return ret, addr, err @@ -144,7 +144,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA return nil, common.Address{}, vm.DepthError } - snapshot := env.MakeSnapshot() + snapshot := env.SnapshotDatabase() var to vm.Account if !env.Db().Exist(*toAddr) { @@ -162,7 +162,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA if err != nil { contract.UseGas(contract.Gas) - env.SetSnapshot(snapshot) + env.RevertToSnapshot(snapshot) } return ret, addr, err diff --git a/core/state/dump.go b/core/state/dump.go index 58ecd852b..8294d61b9 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -52,7 +52,7 @@ func (self *StateDB) RawDump() Dump { panic(err) } - obj := NewObject(common.BytesToAddress(addr), data, nil) + obj := newObject(nil, common.BytesToAddress(addr), data, nil) account := DumpAccount{ Balance: data.Balance.String(), Nonce: data.Nonce, diff --git a/core/state/journal.go b/core/state/journal.go new file mode 100644 index 000000000..540ade6fb --- /dev/null +++ b/core/state/journal.go @@ -0,0 +1,117 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package state + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/common" +) + +type journalEntry interface { + undo(*StateDB) +} + +type journal []journalEntry + +type ( + // Changes to the account trie. + createObjectChange struct { + account *common.Address + } + resetObjectChange struct { + prev *StateObject + } + deleteAccountChange struct { + account *common.Address + prev bool // whether account had already suicided + prevbalance *big.Int + } + + // Changes to individual accounts. + balanceChange struct { + account *common.Address + prev *big.Int + } + nonceChange struct { + account *common.Address + prev uint64 + } + storageChange struct { + account *common.Address + key, prevalue common.Hash + } + codeChange struct { + account *common.Address + prevcode, prevhash []byte + } + + // Changes to other state values. + refundChange struct { + prev *big.Int + } + addLogChange struct { + txhash common.Hash + } +) + +func (ch createObjectChange) undo(s *StateDB) { + s.GetStateObject(*ch.account).deleted = true + delete(s.stateObjects, *ch.account) + delete(s.stateObjectsDirty, *ch.account) +} + +func (ch resetObjectChange) undo(s *StateDB) { + s.setStateObject(ch.prev) +} + +func (ch deleteAccountChange) undo(s *StateDB) { + obj := s.GetStateObject(*ch.account) + if obj != nil { + obj.remove = ch.prev + obj.setBalance(ch.prevbalance) + } +} + +func (ch balanceChange) undo(s *StateDB) { + s.GetStateObject(*ch.account).setBalance(ch.prev) +} + +func (ch nonceChange) undo(s *StateDB) { + s.GetStateObject(*ch.account).setNonce(ch.prev) +} + +func (ch codeChange) undo(s *StateDB) { + s.GetStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode) +} + +func (ch storageChange) undo(s *StateDB) { + s.GetStateObject(*ch.account).setState(ch.key, ch.prevalue) +} + +func (ch refundChange) undo(s *StateDB) { + s.refund = ch.prev +} + +func (ch addLogChange) undo(s *StateDB) { + logs := s.logs[ch.txhash] + if len(logs) == 1 { + delete(s.logs, ch.txhash) + } else { + s.logs[ch.txhash] = logs[:len(logs)-1] + } +} diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go index baa53428f..3f7bc2aa8 100644 --- a/core/state/managed_state_test.go +++ b/core/state/managed_state_test.go @@ -29,11 +29,8 @@ func create() (*ManagedState, *account) { db, _ := ethdb.NewMemDatabase() statedb, _ := New(common.Hash{}, db) ms := ManageState(statedb) - so := &StateObject{address: addr} - so.SetNonce(100) - ms.StateDB.stateObjects[addr] = so - ms.accounts[addr] = newAccount(so) - + ms.StateDB.SetNonce(addr, 100) + ms.accounts[addr] = newAccount(ms.StateDB.GetStateObject(addr)) return ms, ms.accounts[addr] } diff --git a/core/state/state_object.go b/core/state/state_object.go index cbd50e2a3..31ff9bcd8 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -66,6 +66,7 @@ func (self Storage) Copy() Storage { type StateObject struct { address common.Address // Ethereum address of this account data Account + db *StateDB // DB error. // State objects are used by the consensus core and VM which are @@ -99,15 +100,15 @@ type Account struct { CodeHash []byte } -// NewObject creates a state object. -func NewObject(address common.Address, data Account, onDirty func(addr common.Address)) *StateObject { +// newObject creates a state object. +func newObject(db *StateDB, address common.Address, data Account, onDirty func(addr common.Address)) *StateObject { if data.Balance == nil { data.Balance = new(big.Int) } if data.CodeHash == nil { data.CodeHash = emptyCodeHash } - return &StateObject{address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty} + return &StateObject{db: db, address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty} } // EncodeRLP implements rlp.Encoder. @@ -122,7 +123,7 @@ func (self *StateObject) setError(err error) { } } -func (self *StateObject) MarkForDeletion() { +func (self *StateObject) markForDeletion() { self.remove = true if self.onDirty != nil { self.onDirty(self.Address()) @@ -163,7 +164,16 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash } // SetState updates a value in account storage. -func (self *StateObject) SetState(key, value common.Hash) { +func (self *StateObject) SetState(db trie.Database, key, value common.Hash) { + self.db.journal = append(self.db.journal, storageChange{ + account: &self.address, + key: key, + prevalue: self.GetState(db, key), + }) + self.setState(key, value) +} + +func (self *StateObject) setState(key, value common.Hash) { self.cachedStorage[key] = value self.dirtyStorage[key] = value @@ -189,7 +199,7 @@ func (self *StateObject) updateTrie(db trie.Database) { } // UpdateRoot sets the trie root to the current root hash of -func (self *StateObject) UpdateRoot(db trie.Database) { +func (self *StateObject) updateRoot(db trie.Database) { self.updateTrie(db) self.data.Root = self.trie.Hash() } @@ -232,6 +242,14 @@ func (c *StateObject) SubBalance(amount *big.Int) { } func (self *StateObject) SetBalance(amount *big.Int) { + self.db.journal = append(self.db.journal, balanceChange{ + account: &self.address, + prev: new(big.Int).Set(self.data.Balance), + }) + self.setBalance(amount) +} + +func (self *StateObject) setBalance(amount *big.Int) { self.data.Balance = amount if self.onDirty != nil { self.onDirty(self.Address()) @@ -242,8 +260,8 @@ func (self *StateObject) SetBalance(amount *big.Int) { // Return the gas back to the origin. Used by the Virtual machine or Closures func (c *StateObject) ReturnGas(gas, price *big.Int) {} -func (self *StateObject) Copy(db trie.Database, onDirty func(addr common.Address)) *StateObject { - stateObject := NewObject(self.address, self.data, onDirty) +func (self *StateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *StateObject { + stateObject := newObject(db, self.address, self.data, onDirty) stateObject.trie = self.trie stateObject.code = self.code stateObject.dirtyStorage = self.dirtyStorage.Copy() @@ -280,6 +298,16 @@ func (self *StateObject) Code(db trie.Database) []byte { } func (self *StateObject) SetCode(codeHash common.Hash, code []byte) { + prevcode := self.Code(self.db.db) + self.db.journal = append(self.db.journal, codeChange{ + account: &self.address, + prevhash: self.CodeHash(), + prevcode: prevcode, + }) + self.setCode(codeHash, code) +} + +func (self *StateObject) setCode(codeHash common.Hash, code []byte) { self.code = code self.data.CodeHash = codeHash[:] self.dirtyCode = true @@ -290,6 +318,14 @@ func (self *StateObject) SetCode(codeHash common.Hash, code []byte) { } func (self *StateObject) SetNonce(nonce uint64) { + self.db.journal = append(self.db.journal, nonceChange{ + account: &self.address, + prev: self.data.Nonce, + }) + self.setNonce(nonce) +} + +func (self *StateObject) setNonce(nonce uint64) { self.data.Nonce = nonce if self.onDirty != nil { self.onDirty(self.Address()) @@ -322,7 +358,7 @@ func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) { cb(h, value) } - it := self.trie.Iterator() + it := self.getTrie(self.db.db).Iterator() for it.Next() { // ignore cached values key := common.BytesToHash(self.trie.GetKey(it.Key)) diff --git a/core/state/state_test.go b/core/state/state_test.go index 7b9b39e06..b86d8b140 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -46,8 +46,8 @@ func (s *StateSuite) TestDump(c *checker.C) { obj3.SetBalance(big.NewInt(44)) // write some of them to the trie - s.state.UpdateStateObject(obj1) - s.state.UpdateStateObject(obj2) + s.state.updateStateObject(obj1) + s.state.updateStateObject(obj2) s.state.Commit() // check that dump contains the state objects that are in trie @@ -116,12 +116,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) { // set initial state object value s.state.SetState(stateobjaddr, storageaddr, data1) // get snapshot of current state - snapshot := s.state.Copy() + snapshot := s.state.Snapshot() // set new state object value s.state.SetState(stateobjaddr, storageaddr, data2) // restore snapshot - s.state.Set(snapshot) + s.state.RevertToSnapshot(snapshot) // get state storage value res := s.state.GetState(stateobjaddr, storageaddr) @@ -129,6 +129,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) { c.Assert(data1, checker.DeepEquals, res) } +func TestSnapshotEmpty(t *testing.T) { + db, _ := ethdb.NewMemDatabase() + state, _ := New(common.Hash{}, db) + state.RevertToSnapshot(state.Snapshot()) +} + // use testing instead of checker because checker does not support // printing/logging in tests (-check.vv does not work) func TestSnapshot2(t *testing.T) { @@ -152,7 +158,7 @@ func TestSnapshot2(t *testing.T) { so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'}) so0.remove = false so0.deleted = false - state.SetStateObject(so0) + state.setStateObject(so0) root, _ := state.Commit() state.Reset(root) @@ -164,15 +170,15 @@ func TestSnapshot2(t *testing.T) { so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'}) so1.remove = true so1.deleted = true - state.SetStateObject(so1) + state.setStateObject(so1) so1 = state.GetStateObject(stateobjaddr1) if so1 != nil { t.Fatalf("deleted object not nil when getting") } - snapshot := state.Copy() - state.Set(snapshot) + snapshot := state.Snapshot() + state.RevertToSnapshot(snapshot) so0Restored := state.GetStateObject(stateobjaddr0) // Update lazily-loaded values before comparing. diff --git a/core/state/statedb.go b/core/state/statedb.go index 4204c456e..4f74302c3 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -20,6 +20,7 @@ package state import ( "fmt" "math/big" + "sort" "sync" "github.com/ethereum/go-ethereum/common" @@ -40,12 +41,17 @@ var StartingNonce uint64 const ( // Number of past tries to keep. The arbitrarily chosen value here // is max uncle depth + 1. - maxJournalLength = 8 + maxTrieCacheLength = 8 // Number of codehash->size associations to keep. codeSizeCacheSize = 100000 ) +type revision struct { + id int + journalIndex int +} + // StateDBs within the ethereum protocol are used to store anything // within the merkle trie. StateDBs take care of caching and storing // nested states. It's the general query interface to retrieve: @@ -69,6 +75,12 @@ type StateDB struct { logs map[common.Hash]vm.Logs logSize uint + // Journal of state modifications. This is the backbone of + // Snapshot and RevertToSnapshot. + journal journal + validRevisions []revision + nextRevisionId int + lock sync.Mutex } @@ -124,12 +136,12 @@ func (self *StateDB) Reset(root common.Hash) error { self.trie = tr self.stateObjects = make(map[common.Address]*StateObject) self.stateObjectsDirty = make(map[common.Address]struct{}) - self.refund = new(big.Int) self.thash = common.Hash{} self.bhash = common.Hash{} self.txIndex = 0 self.logs = make(map[common.Hash]vm.Logs) self.logSize = 0 + self.clearJournalAndRefund() return nil } @@ -150,7 +162,7 @@ func (self *StateDB) pushTrie(t *trie.SecureTrie) { self.lock.Lock() defer self.lock.Unlock() - if len(self.pastTries) >= maxJournalLength { + if len(self.pastTries) >= maxTrieCacheLength { copy(self.pastTries, self.pastTries[1:]) self.pastTries[len(self.pastTries)-1] = t } else { @@ -165,6 +177,8 @@ func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) { } func (self *StateDB) AddLog(log *vm.Log) { + self.journal = append(self.journal, addLogChange{txhash: self.thash}) + log.TxHash = self.thash log.BlockHash = self.bhash log.TxIndex = uint(self.txIndex) @@ -186,13 +200,12 @@ func (self *StateDB) Logs() vm.Logs { } func (self *StateDB) AddRefund(gas *big.Int) { + self.journal = append(self.journal, refundChange{prev: new(big.Int).Set(self.refund)}) self.refund.Add(self.refund, gas) } -func (self *StateDB) HasAccount(addr common.Address) bool { - return self.GetStateObject(addr) != nil -} - +// Exist reports whether the given account address exists in the state. +// Notably this also returns true for suicided accounts. func (self *StateDB) Exist(addr common.Address) bool { return self.GetStateObject(addr) != nil } @@ -207,7 +220,6 @@ func (self *StateDB) GetBalance(addr common.Address) *big.Int { if stateObject != nil { return stateObject.Balance() } - return common.Big0 } @@ -282,6 +294,13 @@ func (self *StateDB) AddBalance(addr common.Address, amount *big.Int) { } } +func (self *StateDB) SetBalance(addr common.Address, amount *big.Int) { + stateObject := self.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SetBalance(amount) + } +} + func (self *StateDB) SetNonce(addr common.Address, nonce uint64) { stateObject := self.GetOrNewStateObject(addr) if stateObject != nil { @@ -299,27 +318,36 @@ func (self *StateDB) SetCode(addr common.Address, code []byte) { func (self *StateDB) SetState(addr common.Address, key common.Hash, value common.Hash) { stateObject := self.GetOrNewStateObject(addr) if stateObject != nil { - stateObject.SetState(key, value) + stateObject.SetState(self.db, key, value) } } +// Delete marks the given account as suicided. +// This clears the account balance. +// +// The account's state object is still available until the state is committed, +// GetStateObject will return a non-nil account after Delete. func (self *StateDB) Delete(addr common.Address) bool { stateObject := self.GetStateObject(addr) - if stateObject != nil { - stateObject.MarkForDeletion() - stateObject.data.Balance = new(big.Int) - return true + if stateObject == nil { + return false } - - return false + self.journal = append(self.journal, deleteAccountChange{ + account: &addr, + prev: stateObject.remove, + prevbalance: new(big.Int).Set(stateObject.Balance()), + }) + stateObject.markForDeletion() + stateObject.data.Balance = new(big.Int) + return true } // // Setting, updating & deleting state object methods // -// Update the given state object and apply it to state trie -func (self *StateDB) UpdateStateObject(stateObject *StateObject) { +// updateStateObject writes the given object to the trie. +func (self *StateDB) updateStateObject(stateObject *StateObject) { addr := stateObject.Address() data, err := rlp.EncodeToBytes(stateObject) if err != nil { @@ -328,10 +356,9 @@ func (self *StateDB) UpdateStateObject(stateObject *StateObject) { self.trie.Update(addr[:], data) } -// Delete the given state object and delete it from the state trie -func (self *StateDB) DeleteStateObject(stateObject *StateObject) { +// deleteStateObject removes the given object from the state trie. +func (self *StateDB) deleteStateObject(stateObject *StateObject) { stateObject.deleted = true - addr := stateObject.Address() self.trie.Delete(addr[:]) } @@ -357,12 +384,12 @@ func (self *StateDB) GetStateObject(addr common.Address) (stateObject *StateObje return nil } // Insert into the live set. - obj := NewObject(addr, data, self.MarkStateObjectDirty) - self.SetStateObject(obj) + obj := newObject(self, addr, data, self.MarkStateObjectDirty) + self.setStateObject(obj) return obj } -func (self *StateDB) SetStateObject(object *StateObject) { +func (self *StateDB) setStateObject(object *StateObject) { self.stateObjects[object.Address()] = object } @@ -370,52 +397,55 @@ func (self *StateDB) SetStateObject(object *StateObject) { func (self *StateDB) GetOrNewStateObject(addr common.Address) *StateObject { stateObject := self.GetStateObject(addr) if stateObject == nil || stateObject.deleted { - stateObject = self.CreateStateObject(addr) + stateObject, _ = self.createObject(addr) } - return stateObject } -// NewStateObject create a state object whether it exist in the trie or not -func (self *StateDB) newStateObject(addr common.Address) *StateObject { - if glog.V(logger.Core) { - glog.Infof("(+) %x\n", addr) - } - obj := NewObject(addr, Account{}, self.MarkStateObjectDirty) - obj.SetNonce(StartingNonce) // sets the object to dirty - self.stateObjects[addr] = obj - return obj -} - // MarkStateObjectDirty adds the specified object to the dirty map to avoid costly // state object cache iteration to find a handful of modified ones. func (self *StateDB) MarkStateObjectDirty(addr common.Address) { self.stateObjectsDirty[addr] = struct{}{} } -// Creates creates a new state object and takes ownership. -func (self *StateDB) CreateStateObject(addr common.Address) *StateObject { - // Get previous (if any) - so := self.GetStateObject(addr) - // Create a new one - newSo := self.newStateObject(addr) - - // If it existed set the balance to the new account - if so != nil { - newSo.data.Balance = so.data.Balance +// createObject creates a new state object. If there is an existing account with +// the given address, it is overwritten and returned as the second return value. +func (self *StateDB) createObject(addr common.Address) (newobj, prev *StateObject) { + prev = self.GetStateObject(addr) + newobj = newObject(self, addr, Account{}, self.MarkStateObjectDirty) + newobj.setNonce(StartingNonce) // sets the object to dirty + if prev == nil { + if glog.V(logger.Core) { + glog.Infof("(+) %x\n", addr) + } + self.journal = append(self.journal, createObjectChange{account: &addr}) + } else { + self.journal = append(self.journal, resetObjectChange{prev: prev}) } - - return newSo + self.setStateObject(newobj) + return newobj, prev } +// CreateAccount explicitly creates a state object. If a state object with the address +// already exists the balance is carried over to the new account. +// +// CreateAccount is called during the EVM CREATE operation. The situation might arise that +// a contract does the following: +// +// 1. sends funds to sha(account ++ (nonce + 1)) +// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) +// +// Carrying over the balance ensures that Ether doesn't disappear. func (self *StateDB) CreateAccount(addr common.Address) vm.Account { - return self.CreateStateObject(addr) + new, prev := self.createObject(addr) + if prev != nil { + new.setBalance(prev.data.Balance) + } + return new } -// -// Setting, copying of the state methods -// - +// Copy creates a deep, independent copy of the state. +// Snapshots of the copied state cannot be applied to the copy. func (self *StateDB) Copy() *StateDB { self.lock.Lock() defer self.lock.Unlock() @@ -434,7 +464,7 @@ func (self *StateDB) Copy() *StateDB { } // Copy the dirty states and logs for addr, _ := range self.stateObjectsDirty { - state.stateObjects[addr] = self.stateObjects[addr].Copy(self.db, state.MarkStateObjectDirty) + state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state, state.MarkStateObjectDirty) state.stateObjectsDirty[addr] = struct{}{} } for hash, logs := range self.logs { @@ -444,21 +474,38 @@ func (self *StateDB) Copy() *StateDB { return state } -func (self *StateDB) Set(state *StateDB) { - self.lock.Lock() - defer self.lock.Unlock() - - self.db = state.db - self.trie = state.trie - self.pastTries = state.pastTries - self.stateObjects = state.stateObjects - self.stateObjectsDirty = state.stateObjectsDirty - self.codeSizeCache = state.codeSizeCache - self.refund = state.refund - self.logs = state.logs - self.logSize = state.logSize +// Snapshot returns an identifier for the current revision of the state. +func (self *StateDB) Snapshot() int { + id := self.nextRevisionId + self.nextRevisionId++ + self.validRevisions = append(self.validRevisions, revision{id, len(self.journal)}) + return id } +// RevertToSnapshot reverts all state changes made since the given revision. +func (self *StateDB) RevertToSnapshot(revid int) { + // Find the snapshot in the stack of valid snapshots. + idx := sort.Search(len(self.validRevisions), func(i int) bool { + return self.validRevisions[i].id >= revid + }) + if idx == len(self.validRevisions) || self.validRevisions[idx].id != revid { + panic(fmt.Errorf("revision id %v cannot be reverted", revid)) + } + snapshot := self.validRevisions[idx].journalIndex + + // Replay the journal to undo changes. + for i := len(self.journal) - 1; i >= snapshot; i-- { + self.journal[i].undo(self) + } + self.journal = self.journal[:snapshot] + + // Remove invalidated snapshots from the stack. + self.validRevisions = self.validRevisions[:idx] +} + +// GetRefund returns the current value of the refund counter. +// The return value must not be modified by the caller and will become +// invalid at the next call to AddRefund. func (self *StateDB) GetRefund() *big.Int { return self.refund } @@ -467,16 +514,17 @@ func (self *StateDB) GetRefund() *big.Int { // It is called in between transactions to get the root hash that // goes into transaction receipts. func (s *StateDB) IntermediateRoot() common.Hash { - s.refund = new(big.Int) for addr, _ := range s.stateObjectsDirty { stateObject := s.stateObjects[addr] if stateObject.remove { - s.DeleteStateObject(stateObject) + s.deleteStateObject(stateObject) } else { - stateObject.UpdateRoot(s.db) - s.UpdateStateObject(stateObject) + stateObject.updateRoot(s.db) + s.updateStateObject(stateObject) } } + // Invalidate journal because reverting across transactions is not allowed. + s.clearJournalAndRefund() return s.trie.Hash() } @@ -486,9 +534,9 @@ func (s *StateDB) IntermediateRoot() common.Hash { // DeleteSuicides should not be used for consensus related updates // under any circumstances. func (s *StateDB) DeleteSuicides() { - // Reset refund so that any used-gas calculations can use - // this method. - s.refund = new(big.Int) + // Reset refund so that any used-gas calculations can use this method. + s.clearJournalAndRefund() + for addr, _ := range s.stateObjectsDirty { stateObject := s.stateObjects[addr] @@ -516,15 +564,21 @@ func (s *StateDB) CommitBatch() (root common.Hash, batch ethdb.Batch) { return root, batch } -func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) { +func (s *StateDB) clearJournalAndRefund() { + s.journal = nil + s.validRevisions = s.validRevisions[:0] s.refund = new(big.Int) +} + +func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) { + defer s.clearJournalAndRefund() // Commit objects to the trie. for addr, stateObject := range s.stateObjects { if stateObject.remove { // If the object has been removed, don't bother syncing it // and just mark it for deletion in the trie. - s.DeleteStateObject(stateObject) + s.deleteStateObject(stateObject) } else if _, ok := s.stateObjectsDirty[addr]; ok { // Write any contract code associated with the state object if stateObject.code != nil && stateObject.dirtyCode { @@ -538,7 +592,7 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) return common.Hash{}, err } // Update the object in the main account trie. - s.UpdateStateObject(stateObject) + s.updateStateObject(stateObject) } delete(s.stateObjectsDirty, addr) } @@ -549,7 +603,3 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) } return root, err } - -func (self *StateDB) Refunds() *big.Int { - return self.refund -} diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 7930b620d..e236cb8f3 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -17,11 +17,19 @@ package state import ( + "bytes" + "encoding/binary" + "fmt" + "math" "math/big" + "math/rand" + "reflect" + "strings" "testing" + "testing/quick" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/ethdb" ) @@ -34,16 +42,16 @@ func TestUpdateLeaks(t *testing.T) { // Update it with some accounts for i := byte(0); i < 255; i++ { - obj := state.GetOrNewStateObject(common.BytesToAddress([]byte{i})) - obj.AddBalance(big.NewInt(int64(11 * i))) - obj.SetNonce(uint64(42 * i)) + addr := common.BytesToAddress([]byte{i}) + state.AddBalance(addr, big.NewInt(int64(11*i))) + state.SetNonce(addr, uint64(42*i)) if i%2 == 0 { - obj.SetState(common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i})) + state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i})) } if i%3 == 0 { - obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i}) + state.SetCode(addr, []byte{i, i, i, i, i}) } - state.UpdateStateObject(obj) + state.IntermediateRoot() } // Ensure that no data was leaked into the database for _, key := range db.Keys() { @@ -61,51 +69,38 @@ func TestIntermediateLeaks(t *testing.T) { transState, _ := New(common.Hash{}, transDb) finalState, _ := New(common.Hash{}, finalDb) - // Update the states with some objects - for i := byte(0); i < 255; i++ { - // Create a new state object with some data into the transition database - obj := transState.GetOrNewStateObject(common.BytesToAddress([]byte{i})) - obj.SetBalance(big.NewInt(int64(11 * i))) - obj.SetNonce(uint64(42 * i)) + modify := func(state *StateDB, addr common.Address, i, tweak byte) { + state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak))) + state.SetNonce(addr, uint64(42*i+tweak)) if i%2 == 0 { - obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.BytesToHash([]byte{i, i, i, i, 0})) + state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{}) + state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak}) } if i%3 == 0 { - obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 0}), []byte{i, i, i, i, i, 0}) + state.SetCode(addr, []byte{i, i, i, i, i, tweak}) } - transState.UpdateStateObject(obj) - - // Overwrite all the data with new values in the transition database - obj.SetBalance(big.NewInt(int64(11*i + 1))) - obj.SetNonce(uint64(42*i + 1)) - if i%2 == 0 { - obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.Hash{}) - obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1})) - } - if i%3 == 0 { - obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1}) - } - transState.UpdateStateObject(obj) - - // Create the final state object directly in the final database - obj = finalState.GetOrNewStateObject(common.BytesToAddress([]byte{i})) - obj.SetBalance(big.NewInt(int64(11*i + 1))) - obj.SetNonce(uint64(42*i + 1)) - if i%2 == 0 { - obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1})) - } - if i%3 == 0 { - obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1}) - } - finalState.UpdateStateObject(obj) } + + // Modify the transient state. + for i := byte(0); i < 255; i++ { + modify(transState, common.Address{byte(i)}, i, 0) + } + // Write modifications to trie. + transState.IntermediateRoot() + + // Overwrite all the data with new values in the transient database. + for i := byte(0); i < 255; i++ { + modify(transState, common.Address{byte(i)}, i, 99) + modify(finalState, common.Address{byte(i)}, i, 99) + } + + // Commit and cross check the databases. if _, err := transState.Commit(); err != nil { t.Fatalf("failed to commit transition state: %v", err) } if _, err := finalState.Commit(); err != nil { t.Fatalf("failed to commit final state: %v", err) } - // Cross check the databases to ensure they are the same for _, key := range finalDb.Keys() { if _, err := transDb.Get(key); err != nil { val, _ := finalDb.Get(key) @@ -119,3 +114,243 @@ func TestIntermediateLeaks(t *testing.T) { } } } + +func TestSnapshotRandom(t *testing.T) { + config := &quick.Config{MaxCount: 1000} + err := quick.Check((*snapshotTest).run, config) + if cerr, ok := err.(*quick.CheckError); ok { + test := cerr.In[0].(*snapshotTest) + t.Errorf("%v:\n%s", test.err, test) + } else if err != nil { + t.Error(err) + } +} + +// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes +// captured by the snapshot. Instances of this test with pseudorandom content are created +// by Generate. +// +// The test works as follows: +// +// A new state is created and all actions are applied to it. Several snapshots are taken +// in between actions. The test then reverts each snapshot. For each snapshot the actions +// leading up to it are replayed on a fresh, empty state. The behaviour of all public +// accessor methods on the reverted state must match the return value of the equivalent +// methods on the replayed state. +type snapshotTest struct { + addrs []common.Address // all account addresses + actions []testAction // modifications to the state + snapshots []int // actions indexes at which snapshot is taken + err error // failure details are reported through this field +} + +type testAction struct { + name string + fn func(testAction, *StateDB) + args []int64 + noAddr bool +} + +// newTestAction creates a random action that changes state. +func newTestAction(addr common.Address, r *rand.Rand) testAction { + actions := []testAction{ + { + name: "SetBalance", + fn: func(a testAction, s *StateDB) { + s.SetBalance(addr, big.NewInt(a.args[0])) + }, + args: make([]int64, 1), + }, + { + name: "AddBalance", + fn: func(a testAction, s *StateDB) { + s.AddBalance(addr, big.NewInt(a.args[0])) + }, + args: make([]int64, 1), + }, + { + name: "SetNonce", + fn: func(a testAction, s *StateDB) { + s.SetNonce(addr, uint64(a.args[0])) + }, + args: make([]int64, 1), + }, + { + name: "SetState", + fn: func(a testAction, s *StateDB) { + var key, val common.Hash + binary.BigEndian.PutUint16(key[:], uint16(a.args[0])) + binary.BigEndian.PutUint16(val[:], uint16(a.args[1])) + s.SetState(addr, key, val) + }, + args: make([]int64, 2), + }, + { + name: "SetCode", + fn: func(a testAction, s *StateDB) { + code := make([]byte, 16) + binary.BigEndian.PutUint64(code, uint64(a.args[0])) + binary.BigEndian.PutUint64(code[8:], uint64(a.args[1])) + s.SetCode(addr, code) + }, + args: make([]int64, 2), + }, + { + name: "CreateAccount", + fn: func(a testAction, s *StateDB) { + s.CreateAccount(addr) + }, + }, + { + name: "Delete", + fn: func(a testAction, s *StateDB) { + s.Delete(addr) + }, + }, + { + name: "AddRefund", + fn: func(a testAction, s *StateDB) { + s.AddRefund(big.NewInt(a.args[0])) + }, + args: make([]int64, 1), + noAddr: true, + }, + { + name: "AddLog", + fn: func(a testAction, s *StateDB) { + data := make([]byte, 2) + binary.BigEndian.PutUint16(data, uint16(a.args[0])) + s.AddLog(&vm.Log{Address: addr, Data: data}) + }, + args: make([]int64, 1), + }, + } + action := actions[r.Intn(len(actions))] + var nameargs []string + if !action.noAddr { + nameargs = append(nameargs, addr.Hex()) + } + for _, i := range action.args { + action.args[i] = rand.Int63n(100) + nameargs = append(nameargs, fmt.Sprint(action.args[i])) + } + action.name += strings.Join(nameargs, ", ") + return action +} + +// Generate returns a new snapshot test of the given size. All randomness is +// derived from r. +func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value { + // Generate random actions. + addrs := make([]common.Address, 50) + for i := range addrs { + addrs[i][0] = byte(i) + } + actions := make([]testAction, size) + for i := range actions { + addr := addrs[r.Intn(len(addrs))] + actions[i] = newTestAction(addr, r) + } + // Generate snapshot indexes. + nsnapshots := int(math.Sqrt(float64(size))) + if size > 0 && nsnapshots == 0 { + nsnapshots = 1 + } + snapshots := make([]int, nsnapshots) + snaplen := len(actions) / nsnapshots + for i := range snapshots { + // Try to place the snapshots some number of actions apart from each other. + snapshots[i] = (i * snaplen) + r.Intn(snaplen) + } + return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil}) +} + +func (test *snapshotTest) String() string { + out := new(bytes.Buffer) + sindex := 0 + for i, action := range test.actions { + if len(test.snapshots) > sindex && i == test.snapshots[sindex] { + fmt.Fprintf(out, "---- snapshot %d ----\n", sindex) + sindex++ + } + fmt.Fprintf(out, "%4d: %s\n", i, action.name) + } + return out.String() +} + +func (test *snapshotTest) run() bool { + // Run all actions and create snapshots. + var ( + db, _ = ethdb.NewMemDatabase() + state, _ = New(common.Hash{}, db) + snapshotRevs = make([]int, len(test.snapshots)) + sindex = 0 + ) + for i, action := range test.actions { + if len(test.snapshots) > sindex && i == test.snapshots[sindex] { + snapshotRevs[sindex] = state.Snapshot() + sindex++ + } + action.fn(action, state) + } + + // Revert all snapshots in reverse order. Each revert must yield a state + // that is equivalent to fresh state with all actions up the snapshot applied. + for sindex--; sindex >= 0; sindex-- { + checkstate, _ := New(common.Hash{}, db) + for _, action := range test.actions[:test.snapshots[sindex]] { + action.fn(action, checkstate) + } + state.RevertToSnapshot(snapshotRevs[sindex]) + if err := test.checkEqual(state, checkstate); err != nil { + test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err) + return false + } + } + return true +} + +// checkEqual checks that methods of state and checkstate return the same values. +func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { + for _, addr := range test.addrs { + var err error + checkeq := func(op string, a, b interface{}) bool { + if err == nil && !reflect.DeepEqual(a, b) { + err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b) + return false + } + return true + } + // Check basic accessor methods. + checkeq("Exist", state.Exist(addr), checkstate.Exist(addr)) + checkeq("IsDeleted", state.IsDeleted(addr), checkstate.IsDeleted(addr)) + checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr)) + checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr)) + checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr)) + checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr)) + checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr)) + // Check storage. + if obj := state.GetStateObject(addr); obj != nil { + obj.ForEachStorage(func(key, val common.Hash) bool { + return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key)) + }) + checkobj := checkstate.GetStateObject(addr) + checkobj.ForEachStorage(func(key, checkval common.Hash) bool { + return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval) + }) + } + if err != nil { + return err + } + } + + if state.GetRefund().Cmp(checkstate.GetRefund()) != 0 { + return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d", + state.GetRefund(), checkstate.GetRefund()) + } + if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) { + return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v", + state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) + } + return nil +} diff --git a/core/state/sync_test.go b/core/state/sync_test.go index 670e1fb1b..949df7301 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -57,7 +57,7 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i}) acc.code = []byte{i, i, i, i, i} } - state.UpdateStateObject(obj) + state.updateStateObject(obj) accounts = append(accounts, acc) } root, _ := state.Commit() diff --git a/core/tx_pool.go b/core/tx_pool.go index f8b11a7ce..10a110e0b 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -257,7 +257,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction) error { // Make sure the account exist. Non existent accounts // haven't got funds and well therefor never pass. - if !currentState.HasAccount(from) { + if !currentState.Exist(from) { return ErrNonExistentAccount } diff --git a/core/vm/environment.go b/core/vm/environment.go index daf6fb90d..1038e69d5 100644 --- a/core/vm/environment.go +++ b/core/vm/environment.go @@ -36,9 +36,9 @@ type Environment interface { // The state database Db() Database // Creates a restorable snapshot - MakeSnapshot() Database + SnapshotDatabase() int // Set database to previous snapshot - SetSnapshot(Database) + RevertToSnapshot(int) // Address of the original invoker (first occurrence of the VM invoker) Origin() common.Address // The block number this VM is invoked on diff --git a/core/vm/jit_test.go b/core/vm/jit_test.go index e6922aeb7..a6de710e1 100644 --- a/core/vm/jit_test.go +++ b/core/vm/jit_test.go @@ -179,8 +179,8 @@ func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) } //func (self *Env) PrevHash() []byte { return self.parent } func (self *Env) Coinbase() common.Address { return common.Address{} } -func (self *Env) MakeSnapshot() Database { return nil } -func (self *Env) SetSnapshot(Database) {} +func (self *Env) SnapshotDatabase() int { return 0 } +func (self *Env) RevertToSnapshot(int) {} func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } func (self *Env) Db() Database { return nil } diff --git a/core/vm/runtime/env.go b/core/vm/runtime/env.go index a4793c98f..59fbaa792 100644 --- a/core/vm/runtime/env.go +++ b/core/vm/runtime/env.go @@ -86,11 +86,11 @@ func (self *Env) SetDepth(i int) { self.depth = i } func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool { return self.state.GetBalance(from).Cmp(balance) >= 0 } -func (self *Env) MakeSnapshot() vm.Database { - return self.state.Copy() +func (self *Env) SnapshotDatabase() int { + return self.state.Snapshot() } -func (self *Env) SetSnapshot(copy vm.Database) { - self.state.Set(copy.(*state.StateDB)) +func (self *Env) RevertToSnapshot(snapshot int) { + self.state.RevertToSnapshot(snapshot) } func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { diff --git a/core/vm_env.go b/core/vm_env.go index e541eaef4..d62eebbd9 100644 --- a/core/vm_env.go +++ b/core/vm_env.go @@ -89,12 +89,12 @@ func (self *VMEnv) CanTransfer(from common.Address, balance *big.Int) bool { return self.state.GetBalance(from).Cmp(balance) >= 0 } -func (self *VMEnv) MakeSnapshot() vm.Database { - return self.state.Copy() +func (self *VMEnv) SnapshotDatabase() int { + return self.state.Snapshot() } -func (self *VMEnv) SetSnapshot(copy vm.Database) { - self.state.Set(copy.(*state.StateDB)) +func (self *VMEnv) RevertToSnapshot(snapshot int) { + self.state.RevertToSnapshot(snapshot) } func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) { diff --git a/eth/api_backend.go b/eth/api_backend.go index 4adeb0aa0..42b84bf9b 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -98,12 +98,12 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int { } func (b *EthApiBackend) GetVMEnv(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header) (vm.Environment, func() error, error) { - stateDb := state.(EthApiState).state.Copy() + statedb := state.(EthApiState).state addr, _ := msg.From() - from := stateDb.GetOrNewStateObject(addr) + from := statedb.GetOrNewStateObject(addr) from.SetBalance(common.MaxBig) vmError := func() error { return nil } - return core.NewEnv(stateDb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil + return core.NewEnv(statedb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil } func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { diff --git a/internal/ethapi/tracer_test.go b/internal/ethapi/tracer_test.go index 7c831d299..127af32a8 100644 --- a/internal/ethapi/tracer_test.go +++ b/internal/ethapi/tracer_test.go @@ -50,14 +50,14 @@ func (self *Env) Origin() common.Address { return common.Address{} } func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) } //func (self *Env) PrevHash() []byte { return self.parent } -func (self *Env) Coinbase() common.Address { return common.Address{} } -func (self *Env) MakeSnapshot() vm.Database { return nil } -func (self *Env) SetSnapshot(vm.Database) {} -func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } -func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } -func (self *Env) Db() vm.Database { return nil } -func (self *Env) GasLimit() *big.Int { return self.gasLimit } -func (self *Env) VmType() vm.Type { return vm.StdVmTy } +func (self *Env) Coinbase() common.Address { return common.Address{} } +func (self *Env) SnapshotDatabase() int { return 0 } +func (self *Env) RevertToSnapshot(int) {} +func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } +func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } +func (self *Env) Db() vm.Database { return nil } +func (self *Env) GasLimit() *big.Int { return self.gasLimit } +func (self *Env) VmType() vm.Type { return vm.StdVmTy } func (self *Env) GetHash(n uint64) common.Hash { return common.BytesToHash(crypto.Keccak256([]byte(big.NewInt(int64(n)).String()))) } diff --git a/light/state_test.go b/light/state_test.go index d4fe95022..a6b115786 100644 --- a/light/state_test.go +++ b/light/state_test.go @@ -23,7 +23,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/state" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/trie" "golang.org/x/net/context" @@ -54,16 +53,13 @@ func makeTestState() (common.Hash, ethdb.Database) { sdb, _ := ethdb.NewMemDatabase() st, _ := state.New(common.Hash{}, sdb) for i := byte(0); i < 100; i++ { - so := st.GetOrNewStateObject(common.Address{i}) + addr := common.Address{i} for j := byte(0); j < 100; j++ { - val := common.Hash{i, j} - so.SetState(common.Hash{j}, val) - so.SetNonce(100) + st.SetState(addr, common.Hash{j}, common.Hash{i, j}) } - so.AddBalance(big.NewInt(int64(i))) - so.SetCode(crypto.Keccak256Hash([]byte{i, i, i}), []byte{i, i, i}) - so.UpdateRoot(sdb) - st.UpdateStateObject(so) + st.SetNonce(addr, 100) + st.AddBalance(addr, big.NewInt(int64(i))) + st.SetCode(addr, []byte{i, i, i}) } root, _ := st.Commit() return root, sdb diff --git a/miner/worker.go b/miner/worker.go index ac1ef5ba3..e5348cef4 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -171,7 +171,7 @@ func (self *worker) pending() (*types.Block, *state.StateDB) { self.current.receipts, ), self.current.state } - return self.current.Block, self.current.state + return self.current.Block, self.current.state.Copy() } func (self *worker) start() { @@ -618,7 +618,7 @@ func (env *Work) commitTransactions(mux *event.TypeMux, txs *types.TransactionsB } func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, gp *core.GasPool) (error, vm.Logs) { - snap := env.state.Copy() + snap := env.state.Snapshot() // this is a bit of a hack to force jit for the miners config := env.config.VmConfig @@ -629,7 +629,7 @@ func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, g receipt, logs, _, err := core.ApplyTransaction(env.config, bc, gp, env.state, env.header, tx, env.header.GasUsed, config) if err != nil { - env.state.Set(snap) + env.state.RevertToSnapshot(snap) return err, nil } env.txs = append(env.txs, tx) diff --git a/tests/state_test_util.go b/tests/state_test_util.go index 67e4bf832..3c4b42a18 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -95,14 +95,7 @@ func BenchStateTest(ruleSet RuleSet, p string, conf bconf, b *testing.B) error { func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *testing.B) { b.StopTimer() db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) - for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) - statedb.SetStateObject(obj) - for a, v := range account.Storage { - obj.SetState(common.HexToHash(a), common.HexToHash(v)) - } - } + statedb := makePreState(db, test.Pre) b.StartTimer() RunState(ruleSet, statedb, env, test.Exec) @@ -134,14 +127,7 @@ func runStateTests(ruleSet RuleSet, tests map[string]VmTest, skipTests []string) func runStateTest(ruleSet RuleSet, test VmTest) error { db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) - for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) - statedb.SetStateObject(obj) - for a, v := range account.Storage { - obj.SetState(common.HexToHash(a), common.HexToHash(v)) - } - } + statedb := makePreState(db, test.Pre) // XXX Yeah, yeah... env := make(map[string]string) @@ -227,7 +213,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string } // Set pre compiled contracts vm.Precompiled = vm.PrecompiledContracts() - snapshot := statedb.Copy() + snapshot := statedb.Snapshot() gaspool := new(core.GasPool).AddGas(common.Big(env["currentGasLimit"])) key, _ := hex.DecodeString(tx["secretKey"]) @@ -237,7 +223,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string vmenv.origin = addr ret, _, err := core.ApplyMessage(vmenv, message, gaspool) if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || core.IsGasLimitErr(err) { - statedb.Set(snapshot) + statedb.RevertToSnapshot(snapshot) } statedb.Commit() diff --git a/tests/util.go b/tests/util.go index ffbcb9d56..8a9d09213 100644 --- a/tests/util.go +++ b/tests/util.go @@ -103,19 +103,25 @@ func (self Log) Topics() [][]byte { return t } -func StateObjectFromAccount(db ethdb.Database, addr string, account Account, onDirty func(common.Address)) *state.StateObject { +func makePreState(db ethdb.Database, accounts map[string]Account) *state.StateDB { + statedb, _ := state.New(common.Hash{}, db) + for addr, account := range accounts { + insertAccount(statedb, addr, account) + } + return statedb +} + +func insertAccount(state *state.StateDB, saddr string, account Account) { if common.IsHex(account.Code) { account.Code = account.Code[2:] } - code := common.Hex2Bytes(account.Code) - codeHash := crypto.Keccak256Hash(code) - obj := state.NewObject(common.HexToAddress(addr), state.Account{ - Balance: common.Big(account.Balance), - CodeHash: codeHash[:], - Nonce: common.Big(account.Nonce).Uint64(), - }, onDirty) - obj.SetCode(codeHash, code) - return obj + addr := common.HexToAddress(saddr) + state.SetCode(addr, common.Hex2Bytes(account.Code)) + state.SetNonce(addr, common.Big(account.Nonce).Uint64()) + state.SetBalance(addr, common.Big(account.Balance)) + for a, v := range account.Storage { + state.SetState(addr, common.HexToHash(a), common.HexToHash(v)) + } } type VmEnv struct { @@ -229,11 +235,11 @@ func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool { return self.state.GetBalance(from).Cmp(balance) >= 0 } -func (self *Env) MakeSnapshot() vm.Database { - return self.state.Copy() +func (self *Env) SnapshotDatabase() int { + return self.state.Snapshot() } -func (self *Env) SetSnapshot(copy vm.Database) { - self.state.Set(copy.(*state.StateDB)) +func (self *Env) RevertToSnapshot(snapshot int) { + self.state.RevertToSnapshot(snapshot) } func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { diff --git a/tests/vm_test_util.go b/tests/vm_test_util.go index 4ad72d91c..c269f21e0 100644 --- a/tests/vm_test_util.go +++ b/tests/vm_test_util.go @@ -101,14 +101,7 @@ func BenchVmTest(p string, conf bconf, b *testing.B) error { func benchVmTest(test VmTest, env map[string]string, b *testing.B) { b.StopTimer() db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) - for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) - statedb.SetStateObject(obj) - for a, v := range account.Storage { - obj.SetState(common.HexToHash(a), common.HexToHash(v)) - } - } + statedb := makePreState(db, test.Pre) b.StartTimer() RunVm(statedb, env, test.Exec) @@ -152,14 +145,7 @@ func runVmTests(tests map[string]VmTest, skipTests []string) error { func runVmTest(test VmTest) error { db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, db) - for addr, account := range test.Pre { - obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) - statedb.SetStateObject(obj) - for a, v := range account.Storage { - obj.SetState(common.HexToHash(a), common.HexToHash(v)) - } - } + statedb := makePreState(db, test.Pre) // XXX Yeah, yeah... env := make(map[string]string)