core/state: implement reverts by journaling all changes

This commit replaces the deep-copy based state revert mechanism with a
linear complexity journal. This commit also hides several internal
StateDB methods to limit the number of ways in which calling code can
use the journal incorrectly.

As usual consultation and bug fixes to the initial implementation were
provided by @karalabe, @obscuren and @Arachnid. Thank you!
This commit is contained in:
Felix Lange 2016-10-04 12:36:02 +02:00
parent ab7adb0027
commit 1f1ea18b54
24 changed files with 670 additions and 253 deletions

View File

@ -172,8 +172,9 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call ethereum.CallM
func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call ethereum.CallMsg) ([]byte, error) { func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call ethereum.CallMsg) ([]byte, error) {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot())
rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy()) rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState)
return rval, err return rval, err
} }
@ -197,8 +198,9 @@ func (b *SimulatedBackend) SuggestGasPrice(ctx context.Context) (*big.Int, error
func (b *SimulatedBackend) EstimateGas(ctx context.Context, call ethereum.CallMsg) (*big.Int, error) { func (b *SimulatedBackend) EstimateGas(ctx context.Context, call ethereum.CallMsg) (*big.Int, error) {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot())
_, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy()) _, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState)
return gas, err return gas, err
} }

View File

@ -227,22 +227,22 @@ type ruleSet struct{}
func (ruleSet) IsHomestead(*big.Int) bool { return true } func (ruleSet) IsHomestead(*big.Int) bool { return true }
func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} } func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} }
func (self *VMEnv) Vm() vm.Vm { return self.evm } func (self *VMEnv) Vm() vm.Vm { return self.evm }
func (self *VMEnv) Db() vm.Database { return self.state } func (self *VMEnv) Db() vm.Database { return self.state }
func (self *VMEnv) MakeSnapshot() vm.Database { return self.state.Copy() } func (self *VMEnv) SnapshotDatabase() int { return self.state.Snapshot() }
func (self *VMEnv) SetSnapshot(db vm.Database) { self.state.Set(db.(*state.StateDB)) } func (self *VMEnv) RevertToSnapshot(snap int) { self.state.RevertToSnapshot(snap) }
func (self *VMEnv) Origin() common.Address { return *self.transactor } func (self *VMEnv) Origin() common.Address { return *self.transactor }
func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 } func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 }
func (self *VMEnv) Coinbase() common.Address { return *self.transactor } func (self *VMEnv) Coinbase() common.Address { return *self.transactor }
func (self *VMEnv) Time() *big.Int { return self.time } func (self *VMEnv) Time() *big.Int { return self.time }
func (self *VMEnv) Difficulty() *big.Int { return common.Big1 } func (self *VMEnv) Difficulty() *big.Int { return common.Big1 }
func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) } func (self *VMEnv) BlockHash() []byte { return make([]byte, 32) }
func (self *VMEnv) Value() *big.Int { return self.value } func (self *VMEnv) Value() *big.Int { return self.value }
func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) } func (self *VMEnv) GasLimit() *big.Int { return big.NewInt(1000000000) }
func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy } func (self *VMEnv) VmType() vm.Type { return vm.StdVmTy }
func (self *VMEnv) Depth() int { return 0 } func (self *VMEnv) Depth() int { return 0 }
func (self *VMEnv) SetDepth(i int) { self.depth = i } func (self *VMEnv) SetDepth(i int) { self.depth = i }
func (self *VMEnv) GetHash(n uint64) common.Hash { func (self *VMEnv) GetHash(n uint64) common.Hash {
if self.block.Number().Cmp(big.NewInt(int64(n))) == 0 { if self.block.Number().Cmp(big.NewInt(int64(n))) == 0 {
return self.block.Hash() return self.block.Hash()

View File

@ -131,7 +131,7 @@ func (b *BlockGen) AddUncheckedReceipt(receipt *types.Receipt) {
// TxNonce returns the next valid transaction nonce for the // TxNonce returns the next valid transaction nonce for the
// account at addr. It panics if the account does not exist. // account at addr. It panics if the account does not exist.
func (b *BlockGen) TxNonce(addr common.Address) uint64 { func (b *BlockGen) TxNonce(addr common.Address) uint64 {
if !b.statedb.HasAccount(addr) { if !b.statedb.Exist(addr) {
panic("account does not exist") panic("account does not exist")
} }
return b.statedb.GetNonce(addr) return b.statedb.GetNonce(addr)

View File

@ -85,7 +85,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A
createAccount = true createAccount = true
} }
snapshotPreTransfer := env.MakeSnapshot() snapshotPreTransfer := env.SnapshotDatabase()
var ( var (
from = env.Db().GetAccount(caller.Address()) from = env.Db().GetAccount(caller.Address())
to vm.Account to vm.Account
@ -129,7 +129,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A
if err != nil && (env.RuleSet().IsHomestead(env.BlockNumber()) || err != vm.CodeStoreOutOfGasError) { if err != nil && (env.RuleSet().IsHomestead(env.BlockNumber()) || err != vm.CodeStoreOutOfGasError) {
contract.UseGas(contract.Gas) contract.UseGas(contract.Gas)
env.SetSnapshot(snapshotPreTransfer) env.RevertToSnapshot(snapshotPreTransfer)
} }
return ret, addr, err return ret, addr, err
@ -144,7 +144,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA
return nil, common.Address{}, vm.DepthError return nil, common.Address{}, vm.DepthError
} }
snapshot := env.MakeSnapshot() snapshot := env.SnapshotDatabase()
var to vm.Account var to vm.Account
if !env.Db().Exist(*toAddr) { if !env.Db().Exist(*toAddr) {
@ -162,7 +162,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA
if err != nil { if err != nil {
contract.UseGas(contract.Gas) contract.UseGas(contract.Gas)
env.SetSnapshot(snapshot) env.RevertToSnapshot(snapshot)
} }
return ret, addr, err return ret, addr, err

View File

@ -52,7 +52,7 @@ func (self *StateDB) RawDump() Dump {
panic(err) panic(err)
} }
obj := NewObject(common.BytesToAddress(addr), data, nil) obj := newObject(nil, common.BytesToAddress(addr), data, nil)
account := DumpAccount{ account := DumpAccount{
Balance: data.Balance.String(), Balance: data.Balance.String(),
Nonce: data.Nonce, Nonce: data.Nonce,

117
core/state/journal.go Normal file
View File

@ -0,0 +1,117 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"math/big"
"github.com/ethereum/go-ethereum/common"
)
type journalEntry interface {
undo(*StateDB)
}
type journal []journalEntry
type (
// Changes to the account trie.
createObjectChange struct {
account *common.Address
}
resetObjectChange struct {
prev *StateObject
}
deleteAccountChange struct {
account *common.Address
prev bool // whether account had already suicided
prevbalance *big.Int
}
// Changes to individual accounts.
balanceChange struct {
account *common.Address
prev *big.Int
}
nonceChange struct {
account *common.Address
prev uint64
}
storageChange struct {
account *common.Address
key, prevalue common.Hash
}
codeChange struct {
account *common.Address
prevcode, prevhash []byte
}
// Changes to other state values.
refundChange struct {
prev *big.Int
}
addLogChange struct {
txhash common.Hash
}
)
func (ch createObjectChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).deleted = true
delete(s.stateObjects, *ch.account)
delete(s.stateObjectsDirty, *ch.account)
}
func (ch resetObjectChange) undo(s *StateDB) {
s.setStateObject(ch.prev)
}
func (ch deleteAccountChange) undo(s *StateDB) {
obj := s.GetStateObject(*ch.account)
if obj != nil {
obj.remove = ch.prev
obj.setBalance(ch.prevbalance)
}
}
func (ch balanceChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setBalance(ch.prev)
}
func (ch nonceChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setNonce(ch.prev)
}
func (ch codeChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode)
}
func (ch storageChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setState(ch.key, ch.prevalue)
}
func (ch refundChange) undo(s *StateDB) {
s.refund = ch.prev
}
func (ch addLogChange) undo(s *StateDB) {
logs := s.logs[ch.txhash]
if len(logs) == 1 {
delete(s.logs, ch.txhash)
} else {
s.logs[ch.txhash] = logs[:len(logs)-1]
}
}

View File

@ -29,11 +29,8 @@ func create() (*ManagedState, *account) {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := New(common.Hash{}, db) statedb, _ := New(common.Hash{}, db)
ms := ManageState(statedb) ms := ManageState(statedb)
so := &StateObject{address: addr} ms.StateDB.SetNonce(addr, 100)
so.SetNonce(100) ms.accounts[addr] = newAccount(ms.StateDB.GetStateObject(addr))
ms.StateDB.stateObjects[addr] = so
ms.accounts[addr] = newAccount(so)
return ms, ms.accounts[addr] return ms, ms.accounts[addr]
} }

View File

@ -66,6 +66,7 @@ func (self Storage) Copy() Storage {
type StateObject struct { type StateObject struct {
address common.Address // Ethereum address of this account address common.Address // Ethereum address of this account
data Account data Account
db *StateDB
// DB error. // DB error.
// State objects are used by the consensus core and VM which are // State objects are used by the consensus core and VM which are
@ -99,15 +100,15 @@ type Account struct {
CodeHash []byte CodeHash []byte
} }
// NewObject creates a state object. // newObject creates a state object.
func NewObject(address common.Address, data Account, onDirty func(addr common.Address)) *StateObject { func newObject(db *StateDB, address common.Address, data Account, onDirty func(addr common.Address)) *StateObject {
if data.Balance == nil { if data.Balance == nil {
data.Balance = new(big.Int) data.Balance = new(big.Int)
} }
if data.CodeHash == nil { if data.CodeHash == nil {
data.CodeHash = emptyCodeHash data.CodeHash = emptyCodeHash
} }
return &StateObject{address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty} return &StateObject{db: db, address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty}
} }
// EncodeRLP implements rlp.Encoder. // EncodeRLP implements rlp.Encoder.
@ -122,7 +123,7 @@ func (self *StateObject) setError(err error) {
} }
} }
func (self *StateObject) MarkForDeletion() { func (self *StateObject) markForDeletion() {
self.remove = true self.remove = true
if self.onDirty != nil { if self.onDirty != nil {
self.onDirty(self.Address()) self.onDirty(self.Address())
@ -163,7 +164,16 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash
} }
// SetState updates a value in account storage. // SetState updates a value in account storage.
func (self *StateObject) SetState(key, value common.Hash) { func (self *StateObject) SetState(db trie.Database, key, value common.Hash) {
self.db.journal = append(self.db.journal, storageChange{
account: &self.address,
key: key,
prevalue: self.GetState(db, key),
})
self.setState(key, value)
}
func (self *StateObject) setState(key, value common.Hash) {
self.cachedStorage[key] = value self.cachedStorage[key] = value
self.dirtyStorage[key] = value self.dirtyStorage[key] = value
@ -189,7 +199,7 @@ func (self *StateObject) updateTrie(db trie.Database) {
} }
// UpdateRoot sets the trie root to the current root hash of // UpdateRoot sets the trie root to the current root hash of
func (self *StateObject) UpdateRoot(db trie.Database) { func (self *StateObject) updateRoot(db trie.Database) {
self.updateTrie(db) self.updateTrie(db)
self.data.Root = self.trie.Hash() self.data.Root = self.trie.Hash()
} }
@ -232,6 +242,14 @@ func (c *StateObject) SubBalance(amount *big.Int) {
} }
func (self *StateObject) SetBalance(amount *big.Int) { func (self *StateObject) SetBalance(amount *big.Int) {
self.db.journal = append(self.db.journal, balanceChange{
account: &self.address,
prev: new(big.Int).Set(self.data.Balance),
})
self.setBalance(amount)
}
func (self *StateObject) setBalance(amount *big.Int) {
self.data.Balance = amount self.data.Balance = amount
if self.onDirty != nil { if self.onDirty != nil {
self.onDirty(self.Address()) self.onDirty(self.Address())
@ -242,8 +260,8 @@ func (self *StateObject) SetBalance(amount *big.Int) {
// Return the gas back to the origin. Used by the Virtual machine or Closures // Return the gas back to the origin. Used by the Virtual machine or Closures
func (c *StateObject) ReturnGas(gas, price *big.Int) {} func (c *StateObject) ReturnGas(gas, price *big.Int) {}
func (self *StateObject) Copy(db trie.Database, onDirty func(addr common.Address)) *StateObject { func (self *StateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *StateObject {
stateObject := NewObject(self.address, self.data, onDirty) stateObject := newObject(db, self.address, self.data, onDirty)
stateObject.trie = self.trie stateObject.trie = self.trie
stateObject.code = self.code stateObject.code = self.code
stateObject.dirtyStorage = self.dirtyStorage.Copy() stateObject.dirtyStorage = self.dirtyStorage.Copy()
@ -280,6 +298,16 @@ func (self *StateObject) Code(db trie.Database) []byte {
} }
func (self *StateObject) SetCode(codeHash common.Hash, code []byte) { func (self *StateObject) SetCode(codeHash common.Hash, code []byte) {
prevcode := self.Code(self.db.db)
self.db.journal = append(self.db.journal, codeChange{
account: &self.address,
prevhash: self.CodeHash(),
prevcode: prevcode,
})
self.setCode(codeHash, code)
}
func (self *StateObject) setCode(codeHash common.Hash, code []byte) {
self.code = code self.code = code
self.data.CodeHash = codeHash[:] self.data.CodeHash = codeHash[:]
self.dirtyCode = true self.dirtyCode = true
@ -290,6 +318,14 @@ func (self *StateObject) SetCode(codeHash common.Hash, code []byte) {
} }
func (self *StateObject) SetNonce(nonce uint64) { func (self *StateObject) SetNonce(nonce uint64) {
self.db.journal = append(self.db.journal, nonceChange{
account: &self.address,
prev: self.data.Nonce,
})
self.setNonce(nonce)
}
func (self *StateObject) setNonce(nonce uint64) {
self.data.Nonce = nonce self.data.Nonce = nonce
if self.onDirty != nil { if self.onDirty != nil {
self.onDirty(self.Address()) self.onDirty(self.Address())
@ -322,7 +358,7 @@ func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) {
cb(h, value) cb(h, value)
} }
it := self.trie.Iterator() it := self.getTrie(self.db.db).Iterator()
for it.Next() { for it.Next() {
// ignore cached values // ignore cached values
key := common.BytesToHash(self.trie.GetKey(it.Key)) key := common.BytesToHash(self.trie.GetKey(it.Key))

View File

@ -46,8 +46,8 @@ func (s *StateSuite) TestDump(c *checker.C) {
obj3.SetBalance(big.NewInt(44)) obj3.SetBalance(big.NewInt(44))
// write some of them to the trie // write some of them to the trie
s.state.UpdateStateObject(obj1) s.state.updateStateObject(obj1)
s.state.UpdateStateObject(obj2) s.state.updateStateObject(obj2)
s.state.Commit() s.state.Commit()
// check that dump contains the state objects that are in trie // check that dump contains the state objects that are in trie
@ -116,12 +116,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) {
// set initial state object value // set initial state object value
s.state.SetState(stateobjaddr, storageaddr, data1) s.state.SetState(stateobjaddr, storageaddr, data1)
// get snapshot of current state // get snapshot of current state
snapshot := s.state.Copy() snapshot := s.state.Snapshot()
// set new state object value // set new state object value
s.state.SetState(stateobjaddr, storageaddr, data2) s.state.SetState(stateobjaddr, storageaddr, data2)
// restore snapshot // restore snapshot
s.state.Set(snapshot) s.state.RevertToSnapshot(snapshot)
// get state storage value // get state storage value
res := s.state.GetState(stateobjaddr, storageaddr) res := s.state.GetState(stateobjaddr, storageaddr)
@ -129,6 +129,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) {
c.Assert(data1, checker.DeepEquals, res) c.Assert(data1, checker.DeepEquals, res)
} }
func TestSnapshotEmpty(t *testing.T) {
db, _ := ethdb.NewMemDatabase()
state, _ := New(common.Hash{}, db)
state.RevertToSnapshot(state.Snapshot())
}
// use testing instead of checker because checker does not support // use testing instead of checker because checker does not support
// printing/logging in tests (-check.vv does not work) // printing/logging in tests (-check.vv does not work)
func TestSnapshot2(t *testing.T) { func TestSnapshot2(t *testing.T) {
@ -152,7 +158,7 @@ func TestSnapshot2(t *testing.T) {
so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'}) so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'})
so0.remove = false so0.remove = false
so0.deleted = false so0.deleted = false
state.SetStateObject(so0) state.setStateObject(so0)
root, _ := state.Commit() root, _ := state.Commit()
state.Reset(root) state.Reset(root)
@ -164,15 +170,15 @@ func TestSnapshot2(t *testing.T) {
so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'}) so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'})
so1.remove = true so1.remove = true
so1.deleted = true so1.deleted = true
state.SetStateObject(so1) state.setStateObject(so1)
so1 = state.GetStateObject(stateobjaddr1) so1 = state.GetStateObject(stateobjaddr1)
if so1 != nil { if so1 != nil {
t.Fatalf("deleted object not nil when getting") t.Fatalf("deleted object not nil when getting")
} }
snapshot := state.Copy() snapshot := state.Snapshot()
state.Set(snapshot) state.RevertToSnapshot(snapshot)
so0Restored := state.GetStateObject(stateobjaddr0) so0Restored := state.GetStateObject(stateobjaddr0)
// Update lazily-loaded values before comparing. // Update lazily-loaded values before comparing.

View File

@ -20,6 +20,7 @@ package state
import ( import (
"fmt" "fmt"
"math/big" "math/big"
"sort"
"sync" "sync"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
@ -40,12 +41,17 @@ var StartingNonce uint64
const ( const (
// Number of past tries to keep. The arbitrarily chosen value here // Number of past tries to keep. The arbitrarily chosen value here
// is max uncle depth + 1. // is max uncle depth + 1.
maxJournalLength = 8 maxTrieCacheLength = 8
// Number of codehash->size associations to keep. // Number of codehash->size associations to keep.
codeSizeCacheSize = 100000 codeSizeCacheSize = 100000
) )
type revision struct {
id int
journalIndex int
}
// StateDBs within the ethereum protocol are used to store anything // StateDBs within the ethereum protocol are used to store anything
// within the merkle trie. StateDBs take care of caching and storing // within the merkle trie. StateDBs take care of caching and storing
// nested states. It's the general query interface to retrieve: // nested states. It's the general query interface to retrieve:
@ -69,6 +75,12 @@ type StateDB struct {
logs map[common.Hash]vm.Logs logs map[common.Hash]vm.Logs
logSize uint logSize uint
// Journal of state modifications. This is the backbone of
// Snapshot and RevertToSnapshot.
journal journal
validRevisions []revision
nextRevisionId int
lock sync.Mutex lock sync.Mutex
} }
@ -124,12 +136,12 @@ func (self *StateDB) Reset(root common.Hash) error {
self.trie = tr self.trie = tr
self.stateObjects = make(map[common.Address]*StateObject) self.stateObjects = make(map[common.Address]*StateObject)
self.stateObjectsDirty = make(map[common.Address]struct{}) self.stateObjectsDirty = make(map[common.Address]struct{})
self.refund = new(big.Int)
self.thash = common.Hash{} self.thash = common.Hash{}
self.bhash = common.Hash{} self.bhash = common.Hash{}
self.txIndex = 0 self.txIndex = 0
self.logs = make(map[common.Hash]vm.Logs) self.logs = make(map[common.Hash]vm.Logs)
self.logSize = 0 self.logSize = 0
self.clearJournalAndRefund()
return nil return nil
} }
@ -150,7 +162,7 @@ func (self *StateDB) pushTrie(t *trie.SecureTrie) {
self.lock.Lock() self.lock.Lock()
defer self.lock.Unlock() defer self.lock.Unlock()
if len(self.pastTries) >= maxJournalLength { if len(self.pastTries) >= maxTrieCacheLength {
copy(self.pastTries, self.pastTries[1:]) copy(self.pastTries, self.pastTries[1:])
self.pastTries[len(self.pastTries)-1] = t self.pastTries[len(self.pastTries)-1] = t
} else { } else {
@ -165,6 +177,8 @@ func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) {
} }
func (self *StateDB) AddLog(log *vm.Log) { func (self *StateDB) AddLog(log *vm.Log) {
self.journal = append(self.journal, addLogChange{txhash: self.thash})
log.TxHash = self.thash log.TxHash = self.thash
log.BlockHash = self.bhash log.BlockHash = self.bhash
log.TxIndex = uint(self.txIndex) log.TxIndex = uint(self.txIndex)
@ -186,13 +200,12 @@ func (self *StateDB) Logs() vm.Logs {
} }
func (self *StateDB) AddRefund(gas *big.Int) { func (self *StateDB) AddRefund(gas *big.Int) {
self.journal = append(self.journal, refundChange{prev: new(big.Int).Set(self.refund)})
self.refund.Add(self.refund, gas) self.refund.Add(self.refund, gas)
} }
func (self *StateDB) HasAccount(addr common.Address) bool { // Exist reports whether the given account address exists in the state.
return self.GetStateObject(addr) != nil // Notably this also returns true for suicided accounts.
}
func (self *StateDB) Exist(addr common.Address) bool { func (self *StateDB) Exist(addr common.Address) bool {
return self.GetStateObject(addr) != nil return self.GetStateObject(addr) != nil
} }
@ -207,7 +220,6 @@ func (self *StateDB) GetBalance(addr common.Address) *big.Int {
if stateObject != nil { if stateObject != nil {
return stateObject.Balance() return stateObject.Balance()
} }
return common.Big0 return common.Big0
} }
@ -282,6 +294,13 @@ func (self *StateDB) AddBalance(addr common.Address, amount *big.Int) {
} }
} }
func (self *StateDB) SetBalance(addr common.Address, amount *big.Int) {
stateObject := self.GetOrNewStateObject(addr)
if stateObject != nil {
stateObject.SetBalance(amount)
}
}
func (self *StateDB) SetNonce(addr common.Address, nonce uint64) { func (self *StateDB) SetNonce(addr common.Address, nonce uint64) {
stateObject := self.GetOrNewStateObject(addr) stateObject := self.GetOrNewStateObject(addr)
if stateObject != nil { if stateObject != nil {
@ -299,27 +318,36 @@ func (self *StateDB) SetCode(addr common.Address, code []byte) {
func (self *StateDB) SetState(addr common.Address, key common.Hash, value common.Hash) { func (self *StateDB) SetState(addr common.Address, key common.Hash, value common.Hash) {
stateObject := self.GetOrNewStateObject(addr) stateObject := self.GetOrNewStateObject(addr)
if stateObject != nil { if stateObject != nil {
stateObject.SetState(key, value) stateObject.SetState(self.db, key, value)
} }
} }
// Delete marks the given account as suicided.
// This clears the account balance.
//
// The account's state object is still available until the state is committed,
// GetStateObject will return a non-nil account after Delete.
func (self *StateDB) Delete(addr common.Address) bool { func (self *StateDB) Delete(addr common.Address) bool {
stateObject := self.GetStateObject(addr) stateObject := self.GetStateObject(addr)
if stateObject != nil { if stateObject == nil {
stateObject.MarkForDeletion() return false
stateObject.data.Balance = new(big.Int)
return true
} }
self.journal = append(self.journal, deleteAccountChange{
return false account: &addr,
prev: stateObject.remove,
prevbalance: new(big.Int).Set(stateObject.Balance()),
})
stateObject.markForDeletion()
stateObject.data.Balance = new(big.Int)
return true
} }
// //
// Setting, updating & deleting state object methods // Setting, updating & deleting state object methods
// //
// Update the given state object and apply it to state trie // updateStateObject writes the given object to the trie.
func (self *StateDB) UpdateStateObject(stateObject *StateObject) { func (self *StateDB) updateStateObject(stateObject *StateObject) {
addr := stateObject.Address() addr := stateObject.Address()
data, err := rlp.EncodeToBytes(stateObject) data, err := rlp.EncodeToBytes(stateObject)
if err != nil { if err != nil {
@ -328,10 +356,9 @@ func (self *StateDB) UpdateStateObject(stateObject *StateObject) {
self.trie.Update(addr[:], data) self.trie.Update(addr[:], data)
} }
// Delete the given state object and delete it from the state trie // deleteStateObject removes the given object from the state trie.
func (self *StateDB) DeleteStateObject(stateObject *StateObject) { func (self *StateDB) deleteStateObject(stateObject *StateObject) {
stateObject.deleted = true stateObject.deleted = true
addr := stateObject.Address() addr := stateObject.Address()
self.trie.Delete(addr[:]) self.trie.Delete(addr[:])
} }
@ -357,12 +384,12 @@ func (self *StateDB) GetStateObject(addr common.Address) (stateObject *StateObje
return nil return nil
} }
// Insert into the live set. // Insert into the live set.
obj := NewObject(addr, data, self.MarkStateObjectDirty) obj := newObject(self, addr, data, self.MarkStateObjectDirty)
self.SetStateObject(obj) self.setStateObject(obj)
return obj return obj
} }
func (self *StateDB) SetStateObject(object *StateObject) { func (self *StateDB) setStateObject(object *StateObject) {
self.stateObjects[object.Address()] = object self.stateObjects[object.Address()] = object
} }
@ -370,52 +397,55 @@ func (self *StateDB) SetStateObject(object *StateObject) {
func (self *StateDB) GetOrNewStateObject(addr common.Address) *StateObject { func (self *StateDB) GetOrNewStateObject(addr common.Address) *StateObject {
stateObject := self.GetStateObject(addr) stateObject := self.GetStateObject(addr)
if stateObject == nil || stateObject.deleted { if stateObject == nil || stateObject.deleted {
stateObject = self.CreateStateObject(addr) stateObject, _ = self.createObject(addr)
} }
return stateObject return stateObject
} }
// NewStateObject create a state object whether it exist in the trie or not
func (self *StateDB) newStateObject(addr common.Address) *StateObject {
if glog.V(logger.Core) {
glog.Infof("(+) %x\n", addr)
}
obj := NewObject(addr, Account{}, self.MarkStateObjectDirty)
obj.SetNonce(StartingNonce) // sets the object to dirty
self.stateObjects[addr] = obj
return obj
}
// MarkStateObjectDirty adds the specified object to the dirty map to avoid costly // MarkStateObjectDirty adds the specified object to the dirty map to avoid costly
// state object cache iteration to find a handful of modified ones. // state object cache iteration to find a handful of modified ones.
func (self *StateDB) MarkStateObjectDirty(addr common.Address) { func (self *StateDB) MarkStateObjectDirty(addr common.Address) {
self.stateObjectsDirty[addr] = struct{}{} self.stateObjectsDirty[addr] = struct{}{}
} }
// Creates creates a new state object and takes ownership. // createObject creates a new state object. If there is an existing account with
func (self *StateDB) CreateStateObject(addr common.Address) *StateObject { // the given address, it is overwritten and returned as the second return value.
// Get previous (if any) func (self *StateDB) createObject(addr common.Address) (newobj, prev *StateObject) {
so := self.GetStateObject(addr) prev = self.GetStateObject(addr)
// Create a new one newobj = newObject(self, addr, Account{}, self.MarkStateObjectDirty)
newSo := self.newStateObject(addr) newobj.setNonce(StartingNonce) // sets the object to dirty
if prev == nil {
// If it existed set the balance to the new account if glog.V(logger.Core) {
if so != nil { glog.Infof("(+) %x\n", addr)
newSo.data.Balance = so.data.Balance }
self.journal = append(self.journal, createObjectChange{account: &addr})
} else {
self.journal = append(self.journal, resetObjectChange{prev: prev})
} }
self.setStateObject(newobj)
return newSo return newobj, prev
} }
// CreateAccount explicitly creates a state object. If a state object with the address
// already exists the balance is carried over to the new account.
//
// CreateAccount is called during the EVM CREATE operation. The situation might arise that
// a contract does the following:
//
// 1. sends funds to sha(account ++ (nonce + 1))
// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1)
//
// Carrying over the balance ensures that Ether doesn't disappear.
func (self *StateDB) CreateAccount(addr common.Address) vm.Account { func (self *StateDB) CreateAccount(addr common.Address) vm.Account {
return self.CreateStateObject(addr) new, prev := self.createObject(addr)
if prev != nil {
new.setBalance(prev.data.Balance)
}
return new
} }
// // Copy creates a deep, independent copy of the state.
// Setting, copying of the state methods // Snapshots of the copied state cannot be applied to the copy.
//
func (self *StateDB) Copy() *StateDB { func (self *StateDB) Copy() *StateDB {
self.lock.Lock() self.lock.Lock()
defer self.lock.Unlock() defer self.lock.Unlock()
@ -434,7 +464,7 @@ func (self *StateDB) Copy() *StateDB {
} }
// Copy the dirty states and logs // Copy the dirty states and logs
for addr, _ := range self.stateObjectsDirty { for addr, _ := range self.stateObjectsDirty {
state.stateObjects[addr] = self.stateObjects[addr].Copy(self.db, state.MarkStateObjectDirty) state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state, state.MarkStateObjectDirty)
state.stateObjectsDirty[addr] = struct{}{} state.stateObjectsDirty[addr] = struct{}{}
} }
for hash, logs := range self.logs { for hash, logs := range self.logs {
@ -444,21 +474,38 @@ func (self *StateDB) Copy() *StateDB {
return state return state
} }
func (self *StateDB) Set(state *StateDB) { // Snapshot returns an identifier for the current revision of the state.
self.lock.Lock() func (self *StateDB) Snapshot() int {
defer self.lock.Unlock() id := self.nextRevisionId
self.nextRevisionId++
self.db = state.db self.validRevisions = append(self.validRevisions, revision{id, len(self.journal)})
self.trie = state.trie return id
self.pastTries = state.pastTries
self.stateObjects = state.stateObjects
self.stateObjectsDirty = state.stateObjectsDirty
self.codeSizeCache = state.codeSizeCache
self.refund = state.refund
self.logs = state.logs
self.logSize = state.logSize
} }
// RevertToSnapshot reverts all state changes made since the given revision.
func (self *StateDB) RevertToSnapshot(revid int) {
// Find the snapshot in the stack of valid snapshots.
idx := sort.Search(len(self.validRevisions), func(i int) bool {
return self.validRevisions[i].id >= revid
})
if idx == len(self.validRevisions) || self.validRevisions[idx].id != revid {
panic(fmt.Errorf("revision id %v cannot be reverted", revid))
}
snapshot := self.validRevisions[idx].journalIndex
// Replay the journal to undo changes.
for i := len(self.journal) - 1; i >= snapshot; i-- {
self.journal[i].undo(self)
}
self.journal = self.journal[:snapshot]
// Remove invalidated snapshots from the stack.
self.validRevisions = self.validRevisions[:idx]
}
// GetRefund returns the current value of the refund counter.
// The return value must not be modified by the caller and will become
// invalid at the next call to AddRefund.
func (self *StateDB) GetRefund() *big.Int { func (self *StateDB) GetRefund() *big.Int {
return self.refund return self.refund
} }
@ -467,16 +514,17 @@ func (self *StateDB) GetRefund() *big.Int {
// It is called in between transactions to get the root hash that // It is called in between transactions to get the root hash that
// goes into transaction receipts. // goes into transaction receipts.
func (s *StateDB) IntermediateRoot() common.Hash { func (s *StateDB) IntermediateRoot() common.Hash {
s.refund = new(big.Int)
for addr, _ := range s.stateObjectsDirty { for addr, _ := range s.stateObjectsDirty {
stateObject := s.stateObjects[addr] stateObject := s.stateObjects[addr]
if stateObject.remove { if stateObject.remove {
s.DeleteStateObject(stateObject) s.deleteStateObject(stateObject)
} else { } else {
stateObject.UpdateRoot(s.db) stateObject.updateRoot(s.db)
s.UpdateStateObject(stateObject) s.updateStateObject(stateObject)
} }
} }
// Invalidate journal because reverting across transactions is not allowed.
s.clearJournalAndRefund()
return s.trie.Hash() return s.trie.Hash()
} }
@ -486,9 +534,9 @@ func (s *StateDB) IntermediateRoot() common.Hash {
// DeleteSuicides should not be used for consensus related updates // DeleteSuicides should not be used for consensus related updates
// under any circumstances. // under any circumstances.
func (s *StateDB) DeleteSuicides() { func (s *StateDB) DeleteSuicides() {
// Reset refund so that any used-gas calculations can use // Reset refund so that any used-gas calculations can use this method.
// this method. s.clearJournalAndRefund()
s.refund = new(big.Int)
for addr, _ := range s.stateObjectsDirty { for addr, _ := range s.stateObjectsDirty {
stateObject := s.stateObjects[addr] stateObject := s.stateObjects[addr]
@ -516,15 +564,21 @@ func (s *StateDB) CommitBatch() (root common.Hash, batch ethdb.Batch) {
return root, batch return root, batch
} }
func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) { func (s *StateDB) clearJournalAndRefund() {
s.journal = nil
s.validRevisions = s.validRevisions[:0]
s.refund = new(big.Int) s.refund = new(big.Int)
}
func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) {
defer s.clearJournalAndRefund()
// Commit objects to the trie. // Commit objects to the trie.
for addr, stateObject := range s.stateObjects { for addr, stateObject := range s.stateObjects {
if stateObject.remove { if stateObject.remove {
// If the object has been removed, don't bother syncing it // If the object has been removed, don't bother syncing it
// and just mark it for deletion in the trie. // and just mark it for deletion in the trie.
s.DeleteStateObject(stateObject) s.deleteStateObject(stateObject)
} else if _, ok := s.stateObjectsDirty[addr]; ok { } else if _, ok := s.stateObjectsDirty[addr]; ok {
// Write any contract code associated with the state object // Write any contract code associated with the state object
if stateObject.code != nil && stateObject.dirtyCode { if stateObject.code != nil && stateObject.dirtyCode {
@ -538,7 +592,7 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error)
return common.Hash{}, err return common.Hash{}, err
} }
// Update the object in the main account trie. // Update the object in the main account trie.
s.UpdateStateObject(stateObject) s.updateStateObject(stateObject)
} }
delete(s.stateObjectsDirty, addr) delete(s.stateObjectsDirty, addr)
} }
@ -549,7 +603,3 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error)
} }
return root, err return root, err
} }
func (self *StateDB) Refunds() *big.Int {
return self.refund
}

View File

@ -17,11 +17,19 @@
package state package state
import ( import (
"bytes"
"encoding/binary"
"fmt"
"math"
"math/big" "math/big"
"math/rand"
"reflect"
"strings"
"testing" "testing"
"testing/quick"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
) )
@ -34,16 +42,16 @@ func TestUpdateLeaks(t *testing.T) {
// Update it with some accounts // Update it with some accounts
for i := byte(0); i < 255; i++ { for i := byte(0); i < 255; i++ {
obj := state.GetOrNewStateObject(common.BytesToAddress([]byte{i})) addr := common.BytesToAddress([]byte{i})
obj.AddBalance(big.NewInt(int64(11 * i))) state.AddBalance(addr, big.NewInt(int64(11*i)))
obj.SetNonce(uint64(42 * i)) state.SetNonce(addr, uint64(42*i))
if i%2 == 0 { if i%2 == 0 {
obj.SetState(common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i})) state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
} }
if i%3 == 0 { if i%3 == 0 {
obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i}) state.SetCode(addr, []byte{i, i, i, i, i})
} }
state.UpdateStateObject(obj) state.IntermediateRoot()
} }
// Ensure that no data was leaked into the database // Ensure that no data was leaked into the database
for _, key := range db.Keys() { for _, key := range db.Keys() {
@ -61,51 +69,38 @@ func TestIntermediateLeaks(t *testing.T) {
transState, _ := New(common.Hash{}, transDb) transState, _ := New(common.Hash{}, transDb)
finalState, _ := New(common.Hash{}, finalDb) finalState, _ := New(common.Hash{}, finalDb)
// Update the states with some objects modify := func(state *StateDB, addr common.Address, i, tweak byte) {
for i := byte(0); i < 255; i++ { state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
// Create a new state object with some data into the transition database state.SetNonce(addr, uint64(42*i+tweak))
obj := transState.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
obj.SetBalance(big.NewInt(int64(11 * i)))
obj.SetNonce(uint64(42 * i))
if i%2 == 0 { if i%2 == 0 {
obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.BytesToHash([]byte{i, i, i, i, 0})) state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{})
state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak})
} }
if i%3 == 0 { if i%3 == 0 {
obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 0}), []byte{i, i, i, i, i, 0}) state.SetCode(addr, []byte{i, i, i, i, i, tweak})
} }
transState.UpdateStateObject(obj)
// Overwrite all the data with new values in the transition database
obj.SetBalance(big.NewInt(int64(11*i + 1)))
obj.SetNonce(uint64(42*i + 1))
if i%2 == 0 {
obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.Hash{})
obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1}))
}
if i%3 == 0 {
obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1})
}
transState.UpdateStateObject(obj)
// Create the final state object directly in the final database
obj = finalState.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
obj.SetBalance(big.NewInt(int64(11*i + 1)))
obj.SetNonce(uint64(42*i + 1))
if i%2 == 0 {
obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1}))
}
if i%3 == 0 {
obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1})
}
finalState.UpdateStateObject(obj)
} }
// Modify the transient state.
for i := byte(0); i < 255; i++ {
modify(transState, common.Address{byte(i)}, i, 0)
}
// Write modifications to trie.
transState.IntermediateRoot()
// Overwrite all the data with new values in the transient database.
for i := byte(0); i < 255; i++ {
modify(transState, common.Address{byte(i)}, i, 99)
modify(finalState, common.Address{byte(i)}, i, 99)
}
// Commit and cross check the databases.
if _, err := transState.Commit(); err != nil { if _, err := transState.Commit(); err != nil {
t.Fatalf("failed to commit transition state: %v", err) t.Fatalf("failed to commit transition state: %v", err)
} }
if _, err := finalState.Commit(); err != nil { if _, err := finalState.Commit(); err != nil {
t.Fatalf("failed to commit final state: %v", err) t.Fatalf("failed to commit final state: %v", err)
} }
// Cross check the databases to ensure they are the same
for _, key := range finalDb.Keys() { for _, key := range finalDb.Keys() {
if _, err := transDb.Get(key); err != nil { if _, err := transDb.Get(key); err != nil {
val, _ := finalDb.Get(key) val, _ := finalDb.Get(key)
@ -119,3 +114,243 @@ func TestIntermediateLeaks(t *testing.T) {
} }
} }
} }
func TestSnapshotRandom(t *testing.T) {
config := &quick.Config{MaxCount: 1000}
err := quick.Check((*snapshotTest).run, config)
if cerr, ok := err.(*quick.CheckError); ok {
test := cerr.In[0].(*snapshotTest)
t.Errorf("%v:\n%s", test.err, test)
} else if err != nil {
t.Error(err)
}
}
// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
// captured by the snapshot. Instances of this test with pseudorandom content are created
// by Generate.
//
// The test works as follows:
//
// A new state is created and all actions are applied to it. Several snapshots are taken
// in between actions. The test then reverts each snapshot. For each snapshot the actions
// leading up to it are replayed on a fresh, empty state. The behaviour of all public
// accessor methods on the reverted state must match the return value of the equivalent
// methods on the replayed state.
type snapshotTest struct {
addrs []common.Address // all account addresses
actions []testAction // modifications to the state
snapshots []int // actions indexes at which snapshot is taken
err error // failure details are reported through this field
}
type testAction struct {
name string
fn func(testAction, *StateDB)
args []int64
noAddr bool
}
// newTestAction creates a random action that changes state.
func newTestAction(addr common.Address, r *rand.Rand) testAction {
actions := []testAction{
{
name: "SetBalance",
fn: func(a testAction, s *StateDB) {
s.SetBalance(addr, big.NewInt(a.args[0]))
},
args: make([]int64, 1),
},
{
name: "AddBalance",
fn: func(a testAction, s *StateDB) {
s.AddBalance(addr, big.NewInt(a.args[0]))
},
args: make([]int64, 1),
},
{
name: "SetNonce",
fn: func(a testAction, s *StateDB) {
s.SetNonce(addr, uint64(a.args[0]))
},
args: make([]int64, 1),
},
{
name: "SetState",
fn: func(a testAction, s *StateDB) {
var key, val common.Hash
binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
s.SetState(addr, key, val)
},
args: make([]int64, 2),
},
{
name: "SetCode",
fn: func(a testAction, s *StateDB) {
code := make([]byte, 16)
binary.BigEndian.PutUint64(code, uint64(a.args[0]))
binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
s.SetCode(addr, code)
},
args: make([]int64, 2),
},
{
name: "CreateAccount",
fn: func(a testAction, s *StateDB) {
s.CreateAccount(addr)
},
},
{
name: "Delete",
fn: func(a testAction, s *StateDB) {
s.Delete(addr)
},
},
{
name: "AddRefund",
fn: func(a testAction, s *StateDB) {
s.AddRefund(big.NewInt(a.args[0]))
},
args: make([]int64, 1),
noAddr: true,
},
{
name: "AddLog",
fn: func(a testAction, s *StateDB) {
data := make([]byte, 2)
binary.BigEndian.PutUint16(data, uint16(a.args[0]))
s.AddLog(&vm.Log{Address: addr, Data: data})
},
args: make([]int64, 1),
},
}
action := actions[r.Intn(len(actions))]
var nameargs []string
if !action.noAddr {
nameargs = append(nameargs, addr.Hex())
}
for _, i := range action.args {
action.args[i] = rand.Int63n(100)
nameargs = append(nameargs, fmt.Sprint(action.args[i]))
}
action.name += strings.Join(nameargs, ", ")
return action
}
// Generate returns a new snapshot test of the given size. All randomness is
// derived from r.
func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
// Generate random actions.
addrs := make([]common.Address, 50)
for i := range addrs {
addrs[i][0] = byte(i)
}
actions := make([]testAction, size)
for i := range actions {
addr := addrs[r.Intn(len(addrs))]
actions[i] = newTestAction(addr, r)
}
// Generate snapshot indexes.
nsnapshots := int(math.Sqrt(float64(size)))
if size > 0 && nsnapshots == 0 {
nsnapshots = 1
}
snapshots := make([]int, nsnapshots)
snaplen := len(actions) / nsnapshots
for i := range snapshots {
// Try to place the snapshots some number of actions apart from each other.
snapshots[i] = (i * snaplen) + r.Intn(snaplen)
}
return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
}
func (test *snapshotTest) String() string {
out := new(bytes.Buffer)
sindex := 0
for i, action := range test.actions {
if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
sindex++
}
fmt.Fprintf(out, "%4d: %s\n", i, action.name)
}
return out.String()
}
func (test *snapshotTest) run() bool {
// Run all actions and create snapshots.
var (
db, _ = ethdb.NewMemDatabase()
state, _ = New(common.Hash{}, db)
snapshotRevs = make([]int, len(test.snapshots))
sindex = 0
)
for i, action := range test.actions {
if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
snapshotRevs[sindex] = state.Snapshot()
sindex++
}
action.fn(action, state)
}
// Revert all snapshots in reverse order. Each revert must yield a state
// that is equivalent to fresh state with all actions up the snapshot applied.
for sindex--; sindex >= 0; sindex-- {
checkstate, _ := New(common.Hash{}, db)
for _, action := range test.actions[:test.snapshots[sindex]] {
action.fn(action, checkstate)
}
state.RevertToSnapshot(snapshotRevs[sindex])
if err := test.checkEqual(state, checkstate); err != nil {
test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
return false
}
}
return true
}
// checkEqual checks that methods of state and checkstate return the same values.
func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
for _, addr := range test.addrs {
var err error
checkeq := func(op string, a, b interface{}) bool {
if err == nil && !reflect.DeepEqual(a, b) {
err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
return false
}
return true
}
// Check basic accessor methods.
checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
checkeq("IsDeleted", state.IsDeleted(addr), checkstate.IsDeleted(addr))
checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
// Check storage.
if obj := state.GetStateObject(addr); obj != nil {
obj.ForEachStorage(func(key, val common.Hash) bool {
return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key))
})
checkobj := checkstate.GetStateObject(addr)
checkobj.ForEachStorage(func(key, checkval common.Hash) bool {
return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval)
})
}
if err != nil {
return err
}
}
if state.GetRefund().Cmp(checkstate.GetRefund()) != 0 {
return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
state.GetRefund(), checkstate.GetRefund())
}
if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) {
return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{}))
}
return nil
}

View File

@ -57,7 +57,7 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) {
obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i}) obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i})
acc.code = []byte{i, i, i, i, i} acc.code = []byte{i, i, i, i, i}
} }
state.UpdateStateObject(obj) state.updateStateObject(obj)
accounts = append(accounts, acc) accounts = append(accounts, acc)
} }
root, _ := state.Commit() root, _ := state.Commit()

View File

@ -257,7 +257,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction) error {
// Make sure the account exist. Non existent accounts // Make sure the account exist. Non existent accounts
// haven't got funds and well therefor never pass. // haven't got funds and well therefor never pass.
if !currentState.HasAccount(from) { if !currentState.Exist(from) {
return ErrNonExistentAccount return ErrNonExistentAccount
} }

View File

@ -36,9 +36,9 @@ type Environment interface {
// The state database // The state database
Db() Database Db() Database
// Creates a restorable snapshot // Creates a restorable snapshot
MakeSnapshot() Database SnapshotDatabase() int
// Set database to previous snapshot // Set database to previous snapshot
SetSnapshot(Database) RevertToSnapshot(int)
// Address of the original invoker (first occurrence of the VM invoker) // Address of the original invoker (first occurrence of the VM invoker)
Origin() common.Address Origin() common.Address
// The block number this VM is invoked on // The block number this VM is invoked on

View File

@ -179,8 +179,8 @@ func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) }
//func (self *Env) PrevHash() []byte { return self.parent } //func (self *Env) PrevHash() []byte { return self.parent }
func (self *Env) Coinbase() common.Address { return common.Address{} } func (self *Env) Coinbase() common.Address { return common.Address{} }
func (self *Env) MakeSnapshot() Database { return nil } func (self *Env) SnapshotDatabase() int { return 0 }
func (self *Env) SetSnapshot(Database) {} func (self *Env) RevertToSnapshot(int) {}
func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) }
func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } func (self *Env) Difficulty() *big.Int { return big.NewInt(0) }
func (self *Env) Db() Database { return nil } func (self *Env) Db() Database { return nil }

View File

@ -86,11 +86,11 @@ func (self *Env) SetDepth(i int) { self.depth = i }
func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool { func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0 return self.state.GetBalance(from).Cmp(balance) >= 0
} }
func (self *Env) MakeSnapshot() vm.Database { func (self *Env) SnapshotDatabase() int {
return self.state.Copy() return self.state.Snapshot()
} }
func (self *Env) SetSnapshot(copy vm.Database) { func (self *Env) RevertToSnapshot(snapshot int) {
self.state.Set(copy.(*state.StateDB)) self.state.RevertToSnapshot(snapshot)
} }
func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { func (self *Env) Transfer(from, to vm.Account, amount *big.Int) {

View File

@ -89,12 +89,12 @@ func (self *VMEnv) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0 return self.state.GetBalance(from).Cmp(balance) >= 0
} }
func (self *VMEnv) MakeSnapshot() vm.Database { func (self *VMEnv) SnapshotDatabase() int {
return self.state.Copy() return self.state.Snapshot()
} }
func (self *VMEnv) SetSnapshot(copy vm.Database) { func (self *VMEnv) RevertToSnapshot(snapshot int) {
self.state.Set(copy.(*state.StateDB)) self.state.RevertToSnapshot(snapshot)
} }
func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) { func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) {

View File

@ -98,12 +98,12 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int {
} }
func (b *EthApiBackend) GetVMEnv(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header) (vm.Environment, func() error, error) { func (b *EthApiBackend) GetVMEnv(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header) (vm.Environment, func() error, error) {
stateDb := state.(EthApiState).state.Copy() statedb := state.(EthApiState).state
addr, _ := msg.From() addr, _ := msg.From()
from := stateDb.GetOrNewStateObject(addr) from := statedb.GetOrNewStateObject(addr)
from.SetBalance(common.MaxBig) from.SetBalance(common.MaxBig)
vmError := func() error { return nil } vmError := func() error { return nil }
return core.NewEnv(stateDb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil return core.NewEnv(statedb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil
} }
func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error {

View File

@ -50,14 +50,14 @@ func (self *Env) Origin() common.Address { return common.Address{} }
func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) } func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) }
//func (self *Env) PrevHash() []byte { return self.parent } //func (self *Env) PrevHash() []byte { return self.parent }
func (self *Env) Coinbase() common.Address { return common.Address{} } func (self *Env) Coinbase() common.Address { return common.Address{} }
func (self *Env) MakeSnapshot() vm.Database { return nil } func (self *Env) SnapshotDatabase() int { return 0 }
func (self *Env) SetSnapshot(vm.Database) {} func (self *Env) RevertToSnapshot(int) {}
func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) }
func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } func (self *Env) Difficulty() *big.Int { return big.NewInt(0) }
func (self *Env) Db() vm.Database { return nil } func (self *Env) Db() vm.Database { return nil }
func (self *Env) GasLimit() *big.Int { return self.gasLimit } func (self *Env) GasLimit() *big.Int { return self.gasLimit }
func (self *Env) VmType() vm.Type { return vm.StdVmTy } func (self *Env) VmType() vm.Type { return vm.StdVmTy }
func (self *Env) GetHash(n uint64) common.Hash { func (self *Env) GetHash(n uint64) common.Hash {
return common.BytesToHash(crypto.Keccak256([]byte(big.NewInt(int64(n)).String()))) return common.BytesToHash(crypto.Keccak256([]byte(big.NewInt(int64(n)).String())))
} }

View File

@ -23,7 +23,6 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -54,16 +53,13 @@ func makeTestState() (common.Hash, ethdb.Database) {
sdb, _ := ethdb.NewMemDatabase() sdb, _ := ethdb.NewMemDatabase()
st, _ := state.New(common.Hash{}, sdb) st, _ := state.New(common.Hash{}, sdb)
for i := byte(0); i < 100; i++ { for i := byte(0); i < 100; i++ {
so := st.GetOrNewStateObject(common.Address{i}) addr := common.Address{i}
for j := byte(0); j < 100; j++ { for j := byte(0); j < 100; j++ {
val := common.Hash{i, j} st.SetState(addr, common.Hash{j}, common.Hash{i, j})
so.SetState(common.Hash{j}, val)
so.SetNonce(100)
} }
so.AddBalance(big.NewInt(int64(i))) st.SetNonce(addr, 100)
so.SetCode(crypto.Keccak256Hash([]byte{i, i, i}), []byte{i, i, i}) st.AddBalance(addr, big.NewInt(int64(i)))
so.UpdateRoot(sdb) st.SetCode(addr, []byte{i, i, i})
st.UpdateStateObject(so)
} }
root, _ := st.Commit() root, _ := st.Commit()
return root, sdb return root, sdb

View File

@ -171,7 +171,7 @@ func (self *worker) pending() (*types.Block, *state.StateDB) {
self.current.receipts, self.current.receipts,
), self.current.state ), self.current.state
} }
return self.current.Block, self.current.state return self.current.Block, self.current.state.Copy()
} }
func (self *worker) start() { func (self *worker) start() {
@ -618,7 +618,7 @@ func (env *Work) commitTransactions(mux *event.TypeMux, txs *types.TransactionsB
} }
func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, gp *core.GasPool) (error, vm.Logs) { func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, gp *core.GasPool) (error, vm.Logs) {
snap := env.state.Copy() snap := env.state.Snapshot()
// this is a bit of a hack to force jit for the miners // this is a bit of a hack to force jit for the miners
config := env.config.VmConfig config := env.config.VmConfig
@ -629,7 +629,7 @@ func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, g
receipt, logs, _, err := core.ApplyTransaction(env.config, bc, gp, env.state, env.header, tx, env.header.GasUsed, config) receipt, logs, _, err := core.ApplyTransaction(env.config, bc, gp, env.state, env.header, tx, env.header.GasUsed, config)
if err != nil { if err != nil {
env.state.Set(snap) env.state.RevertToSnapshot(snap)
return err, nil return err, nil
} }
env.txs = append(env.txs, tx) env.txs = append(env.txs, tx)

View File

@ -95,14 +95,7 @@ func BenchStateTest(ruleSet RuleSet, p string, conf bconf, b *testing.B) error {
func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *testing.B) { func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *testing.B) {
b.StopTimer() b.StopTimer()
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb := makePreState(db, test.Pre)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
b.StartTimer() b.StartTimer()
RunState(ruleSet, statedb, env, test.Exec) RunState(ruleSet, statedb, env, test.Exec)
@ -134,14 +127,7 @@ func runStateTests(ruleSet RuleSet, tests map[string]VmTest, skipTests []string)
func runStateTest(ruleSet RuleSet, test VmTest) error { func runStateTest(ruleSet RuleSet, test VmTest) error {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb := makePreState(db, test.Pre)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
// XXX Yeah, yeah... // XXX Yeah, yeah...
env := make(map[string]string) env := make(map[string]string)
@ -227,7 +213,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string
} }
// Set pre compiled contracts // Set pre compiled contracts
vm.Precompiled = vm.PrecompiledContracts() vm.Precompiled = vm.PrecompiledContracts()
snapshot := statedb.Copy() snapshot := statedb.Snapshot()
gaspool := new(core.GasPool).AddGas(common.Big(env["currentGasLimit"])) gaspool := new(core.GasPool).AddGas(common.Big(env["currentGasLimit"]))
key, _ := hex.DecodeString(tx["secretKey"]) key, _ := hex.DecodeString(tx["secretKey"])
@ -237,7 +223,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string
vmenv.origin = addr vmenv.origin = addr
ret, _, err := core.ApplyMessage(vmenv, message, gaspool) ret, _, err := core.ApplyMessage(vmenv, message, gaspool)
if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || core.IsGasLimitErr(err) { if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || core.IsGasLimitErr(err) {
statedb.Set(snapshot) statedb.RevertToSnapshot(snapshot)
} }
statedb.Commit() statedb.Commit()

View File

@ -103,19 +103,25 @@ func (self Log) Topics() [][]byte {
return t return t
} }
func StateObjectFromAccount(db ethdb.Database, addr string, account Account, onDirty func(common.Address)) *state.StateObject { func makePreState(db ethdb.Database, accounts map[string]Account) *state.StateDB {
statedb, _ := state.New(common.Hash{}, db)
for addr, account := range accounts {
insertAccount(statedb, addr, account)
}
return statedb
}
func insertAccount(state *state.StateDB, saddr string, account Account) {
if common.IsHex(account.Code) { if common.IsHex(account.Code) {
account.Code = account.Code[2:] account.Code = account.Code[2:]
} }
code := common.Hex2Bytes(account.Code) addr := common.HexToAddress(saddr)
codeHash := crypto.Keccak256Hash(code) state.SetCode(addr, common.Hex2Bytes(account.Code))
obj := state.NewObject(common.HexToAddress(addr), state.Account{ state.SetNonce(addr, common.Big(account.Nonce).Uint64())
Balance: common.Big(account.Balance), state.SetBalance(addr, common.Big(account.Balance))
CodeHash: codeHash[:], for a, v := range account.Storage {
Nonce: common.Big(account.Nonce).Uint64(), state.SetState(addr, common.HexToHash(a), common.HexToHash(v))
}, onDirty) }
obj.SetCode(codeHash, code)
return obj
} }
type VmEnv struct { type VmEnv struct {
@ -229,11 +235,11 @@ func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0 return self.state.GetBalance(from).Cmp(balance) >= 0
} }
func (self *Env) MakeSnapshot() vm.Database { func (self *Env) SnapshotDatabase() int {
return self.state.Copy() return self.state.Snapshot()
} }
func (self *Env) SetSnapshot(copy vm.Database) { func (self *Env) RevertToSnapshot(snapshot int) {
self.state.Set(copy.(*state.StateDB)) self.state.RevertToSnapshot(snapshot)
} }
func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { func (self *Env) Transfer(from, to vm.Account, amount *big.Int) {

View File

@ -101,14 +101,7 @@ func BenchVmTest(p string, conf bconf, b *testing.B) error {
func benchVmTest(test VmTest, env map[string]string, b *testing.B) { func benchVmTest(test VmTest, env map[string]string, b *testing.B) {
b.StopTimer() b.StopTimer()
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb := makePreState(db, test.Pre)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
b.StartTimer() b.StartTimer()
RunVm(statedb, env, test.Exec) RunVm(statedb, env, test.Exec)
@ -152,14 +145,7 @@ func runVmTests(tests map[string]VmTest, skipTests []string) error {
func runVmTest(test VmTest) error { func runVmTest(test VmTest) error {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb := makePreState(db, test.Pre)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
// XXX Yeah, yeah... // XXX Yeah, yeah...
env := make(map[string]string) env := make(map[string]string)