Merge pull request #3092 from fjl/state-journal

core/state: implement reverts by journaling all changes
This commit is contained in:
Jeffrey Wilcke 2016-10-06 16:14:22 +02:00 committed by GitHub
commit 7335a70a02
27 changed files with 700 additions and 278 deletions

View File

@ -172,8 +172,9 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call ethereum.CallM
func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call ethereum.CallMsg) ([]byte, error) { func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call ethereum.CallMsg) ([]byte, error) {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot())
rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy()) rval, _, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState)
return rval, err return rval, err
} }
@ -197,8 +198,9 @@ func (b *SimulatedBackend) SuggestGasPrice(ctx context.Context) (*big.Int, error
func (b *SimulatedBackend) EstimateGas(ctx context.Context, call ethereum.CallMsg) (*big.Int, error) { func (b *SimulatedBackend) EstimateGas(ctx context.Context, call ethereum.CallMsg) (*big.Int, error) {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
defer b.pendingState.RevertToSnapshot(b.pendingState.Snapshot())
_, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState.Copy()) _, gas, err := b.callContract(ctx, call, b.pendingBlock, b.pendingState)
return gas, err return gas, err
} }

View File

@ -230,8 +230,8 @@ func (ruleSet) IsHomestead(*big.Int) bool { return true }
func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} } func (self *VMEnv) RuleSet() vm.RuleSet { return ruleSet{} }
func (self *VMEnv) Vm() vm.Vm { return self.evm } func (self *VMEnv) Vm() vm.Vm { return self.evm }
func (self *VMEnv) Db() vm.Database { return self.state } func (self *VMEnv) Db() vm.Database { return self.state }
func (self *VMEnv) MakeSnapshot() vm.Database { return self.state.Copy() } func (self *VMEnv) SnapshotDatabase() int { return self.state.Snapshot() }
func (self *VMEnv) SetSnapshot(db vm.Database) { self.state.Set(db.(*state.StateDB)) } func (self *VMEnv) RevertToSnapshot(snap int) { self.state.RevertToSnapshot(snap) }
func (self *VMEnv) Origin() common.Address { return *self.transactor } func (self *VMEnv) Origin() common.Address { return *self.transactor }
func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 } func (self *VMEnv) BlockNumber() *big.Int { return common.Big0 }
func (self *VMEnv) Coinbase() common.Address { return *self.transactor } func (self *VMEnv) Coinbase() common.Address { return *self.transactor }

View File

@ -131,7 +131,7 @@ func (b *BlockGen) AddUncheckedReceipt(receipt *types.Receipt) {
// TxNonce returns the next valid transaction nonce for the // TxNonce returns the next valid transaction nonce for the
// account at addr. It panics if the account does not exist. // account at addr. It panics if the account does not exist.
func (b *BlockGen) TxNonce(addr common.Address) uint64 { func (b *BlockGen) TxNonce(addr common.Address) uint64 {
if !b.statedb.HasAccount(addr) { if !b.statedb.Exist(addr) {
panic("account does not exist") panic("account does not exist")
} }
return b.statedb.GetNonce(addr) return b.statedb.GetNonce(addr)

View File

@ -85,7 +85,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A
createAccount = true createAccount = true
} }
snapshotPreTransfer := env.MakeSnapshot() snapshotPreTransfer := env.SnapshotDatabase()
var ( var (
from = env.Db().GetAccount(caller.Address()) from = env.Db().GetAccount(caller.Address())
to vm.Account to vm.Account
@ -129,7 +129,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A
if err != nil && (env.RuleSet().IsHomestead(env.BlockNumber()) || err != vm.CodeStoreOutOfGasError) { if err != nil && (env.RuleSet().IsHomestead(env.BlockNumber()) || err != vm.CodeStoreOutOfGasError) {
contract.UseGas(contract.Gas) contract.UseGas(contract.Gas)
env.SetSnapshot(snapshotPreTransfer) env.RevertToSnapshot(snapshotPreTransfer)
} }
return ret, addr, err return ret, addr, err
@ -144,7 +144,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA
return nil, common.Address{}, vm.DepthError return nil, common.Address{}, vm.DepthError
} }
snapshot := env.MakeSnapshot() snapshot := env.SnapshotDatabase()
var to vm.Account var to vm.Account
if !env.Db().Exist(*toAddr) { if !env.Db().Exist(*toAddr) {
@ -162,7 +162,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA
if err != nil { if err != nil {
contract.UseGas(contract.Gas) contract.UseGas(contract.Gas)
env.SetSnapshot(snapshot) env.RevertToSnapshot(snapshot)
} }
return ret, addr, err return ret, addr, err

View File

@ -52,7 +52,7 @@ func (self *StateDB) RawDump() Dump {
panic(err) panic(err)
} }
obj := NewObject(common.BytesToAddress(addr), data, nil) obj := newObject(nil, common.BytesToAddress(addr), data, nil)
account := DumpAccount{ account := DumpAccount{
Balance: data.Balance.String(), Balance: data.Balance.String(),
Nonce: data.Nonce, Nonce: data.Nonce,

117
core/state/journal.go Normal file
View File

@ -0,0 +1,117 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"math/big"
"github.com/ethereum/go-ethereum/common"
)
type journalEntry interface {
undo(*StateDB)
}
type journal []journalEntry
type (
// Changes to the account trie.
createObjectChange struct {
account *common.Address
}
resetObjectChange struct {
prev *StateObject
}
suicideChange struct {
account *common.Address
prev bool // whether account had already suicided
prevbalance *big.Int
}
// Changes to individual accounts.
balanceChange struct {
account *common.Address
prev *big.Int
}
nonceChange struct {
account *common.Address
prev uint64
}
storageChange struct {
account *common.Address
key, prevalue common.Hash
}
codeChange struct {
account *common.Address
prevcode, prevhash []byte
}
// Changes to other state values.
refundChange struct {
prev *big.Int
}
addLogChange struct {
txhash common.Hash
}
)
func (ch createObjectChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).deleted = true
delete(s.stateObjects, *ch.account)
delete(s.stateObjectsDirty, *ch.account)
}
func (ch resetObjectChange) undo(s *StateDB) {
s.setStateObject(ch.prev)
}
func (ch suicideChange) undo(s *StateDB) {
obj := s.GetStateObject(*ch.account)
if obj != nil {
obj.suicided = ch.prev
obj.setBalance(ch.prevbalance)
}
}
func (ch balanceChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setBalance(ch.prev)
}
func (ch nonceChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setNonce(ch.prev)
}
func (ch codeChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode)
}
func (ch storageChange) undo(s *StateDB) {
s.GetStateObject(*ch.account).setState(ch.key, ch.prevalue)
}
func (ch refundChange) undo(s *StateDB) {
s.refund = ch.prev
}
func (ch addLogChange) undo(s *StateDB) {
logs := s.logs[ch.txhash]
if len(logs) == 1 {
delete(s.logs, ch.txhash)
} else {
s.logs[ch.txhash] = logs[:len(logs)-1]
}
}

View File

@ -29,11 +29,8 @@ func create() (*ManagedState, *account) {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := New(common.Hash{}, db) statedb, _ := New(common.Hash{}, db)
ms := ManageState(statedb) ms := ManageState(statedb)
so := &StateObject{address: addr} ms.StateDB.SetNonce(addr, 100)
so.SetNonce(100) ms.accounts[addr] = newAccount(ms.StateDB.GetStateObject(addr))
ms.StateDB.stateObjects[addr] = so
ms.accounts[addr] = newAccount(so)
return ms, ms.accounts[addr] return ms, ms.accounts[addr]
} }

View File

@ -66,6 +66,7 @@ func (self Storage) Copy() Storage {
type StateObject struct { type StateObject struct {
address common.Address // Ethereum address of this account address common.Address // Ethereum address of this account
data Account data Account
db *StateDB
// DB error. // DB error.
// State objects are used by the consensus core and VM which are // State objects are used by the consensus core and VM which are
@ -82,10 +83,10 @@ type StateObject struct {
dirtyStorage Storage // Storage entries that need to be flushed to disk dirtyStorage Storage // Storage entries that need to be flushed to disk
// Cache flags. // Cache flags.
// When an object is marked for deletion it will be delete from the trie // When an object is marked suicided it will be delete from the trie
// during the "update" phase of the state transition // during the "update" phase of the state transition.
dirtyCode bool // true if the code was updated dirtyCode bool // true if the code was updated
remove bool suicided bool
deleted bool deleted bool
onDirty func(addr common.Address) // Callback method to mark a state object newly dirty onDirty func(addr common.Address) // Callback method to mark a state object newly dirty
} }
@ -99,15 +100,15 @@ type Account struct {
CodeHash []byte CodeHash []byte
} }
// NewObject creates a state object. // newObject creates a state object.
func NewObject(address common.Address, data Account, onDirty func(addr common.Address)) *StateObject { func newObject(db *StateDB, address common.Address, data Account, onDirty func(addr common.Address)) *StateObject {
if data.Balance == nil { if data.Balance == nil {
data.Balance = new(big.Int) data.Balance = new(big.Int)
} }
if data.CodeHash == nil { if data.CodeHash == nil {
data.CodeHash = emptyCodeHash data.CodeHash = emptyCodeHash
} }
return &StateObject{address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty} return &StateObject{db: db, address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty}
} }
// EncodeRLP implements rlp.Encoder. // EncodeRLP implements rlp.Encoder.
@ -122,8 +123,8 @@ func (self *StateObject) setError(err error) {
} }
} }
func (self *StateObject) MarkForDeletion() { func (self *StateObject) markSuicided() {
self.remove = true self.suicided = true
if self.onDirty != nil { if self.onDirty != nil {
self.onDirty(self.Address()) self.onDirty(self.Address())
self.onDirty = nil self.onDirty = nil
@ -152,10 +153,13 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash
return value return value
} }
// Load from DB in case it is missing. // Load from DB in case it is missing.
tr := self.getTrie(db) if enc := self.getTrie(db).Get(key[:]); len(enc) > 0 {
var ret []byte _, content, _, err := rlp.Split(enc)
rlp.DecodeBytes(tr.Get(key[:]), &ret) if err != nil {
value = common.BytesToHash(ret) self.setError(err)
}
value.SetBytes(content)
}
if (value != common.Hash{}) { if (value != common.Hash{}) {
self.cachedStorage[key] = value self.cachedStorage[key] = value
} }
@ -163,7 +167,16 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash
} }
// SetState updates a value in account storage. // SetState updates a value in account storage.
func (self *StateObject) SetState(key, value common.Hash) { func (self *StateObject) SetState(db trie.Database, key, value common.Hash) {
self.db.journal = append(self.db.journal, storageChange{
account: &self.address,
key: key,
prevalue: self.GetState(db, key),
})
self.setState(key, value)
}
func (self *StateObject) setState(key, value common.Hash) {
self.cachedStorage[key] = value self.cachedStorage[key] = value
self.dirtyStorage[key] = value self.dirtyStorage[key] = value
@ -189,7 +202,7 @@ func (self *StateObject) updateTrie(db trie.Database) {
} }
// UpdateRoot sets the trie root to the current root hash of // UpdateRoot sets the trie root to the current root hash of
func (self *StateObject) UpdateRoot(db trie.Database) { func (self *StateObject) updateRoot(db trie.Database) {
self.updateTrie(db) self.updateTrie(db)
self.data.Root = self.trie.Hash() self.data.Root = self.trie.Hash()
} }
@ -199,7 +212,6 @@ func (self *StateObject) UpdateRoot(db trie.Database) {
func (self *StateObject) CommitTrie(db trie.Database, dbw trie.DatabaseWriter) error { func (self *StateObject) CommitTrie(db trie.Database, dbw trie.DatabaseWriter) error {
self.updateTrie(db) self.updateTrie(db)
if self.dbErr != nil { if self.dbErr != nil {
fmt.Println("dbErr:", self.dbErr)
return self.dbErr return self.dbErr
} }
root, err := self.trie.CommitTo(dbw) root, err := self.trie.CommitTo(dbw)
@ -232,6 +244,14 @@ func (c *StateObject) SubBalance(amount *big.Int) {
} }
func (self *StateObject) SetBalance(amount *big.Int) { func (self *StateObject) SetBalance(amount *big.Int) {
self.db.journal = append(self.db.journal, balanceChange{
account: &self.address,
prev: new(big.Int).Set(self.data.Balance),
})
self.setBalance(amount)
}
func (self *StateObject) setBalance(amount *big.Int) {
self.data.Balance = amount self.data.Balance = amount
if self.onDirty != nil { if self.onDirty != nil {
self.onDirty(self.Address()) self.onDirty(self.Address())
@ -242,13 +262,13 @@ func (self *StateObject) SetBalance(amount *big.Int) {
// Return the gas back to the origin. Used by the Virtual machine or Closures // Return the gas back to the origin. Used by the Virtual machine or Closures
func (c *StateObject) ReturnGas(gas, price *big.Int) {} func (c *StateObject) ReturnGas(gas, price *big.Int) {}
func (self *StateObject) Copy(db trie.Database, onDirty func(addr common.Address)) *StateObject { func (self *StateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *StateObject {
stateObject := NewObject(self.address, self.data, onDirty) stateObject := newObject(db, self.address, self.data, onDirty)
stateObject.trie = self.trie stateObject.trie = self.trie
stateObject.code = self.code stateObject.code = self.code
stateObject.dirtyStorage = self.dirtyStorage.Copy() stateObject.dirtyStorage = self.dirtyStorage.Copy()
stateObject.cachedStorage = self.dirtyStorage.Copy() stateObject.cachedStorage = self.dirtyStorage.Copy()
stateObject.remove = self.remove stateObject.suicided = self.suicided
stateObject.dirtyCode = self.dirtyCode stateObject.dirtyCode = self.dirtyCode
stateObject.deleted = self.deleted stateObject.deleted = self.deleted
return stateObject return stateObject
@ -280,6 +300,16 @@ func (self *StateObject) Code(db trie.Database) []byte {
} }
func (self *StateObject) SetCode(codeHash common.Hash, code []byte) { func (self *StateObject) SetCode(codeHash common.Hash, code []byte) {
prevcode := self.Code(self.db.db)
self.db.journal = append(self.db.journal, codeChange{
account: &self.address,
prevhash: self.CodeHash(),
prevcode: prevcode,
})
self.setCode(codeHash, code)
}
func (self *StateObject) setCode(codeHash common.Hash, code []byte) {
self.code = code self.code = code
self.data.CodeHash = codeHash[:] self.data.CodeHash = codeHash[:]
self.dirtyCode = true self.dirtyCode = true
@ -290,6 +320,14 @@ func (self *StateObject) SetCode(codeHash common.Hash, code []byte) {
} }
func (self *StateObject) SetNonce(nonce uint64) { func (self *StateObject) SetNonce(nonce uint64) {
self.db.journal = append(self.db.journal, nonceChange{
account: &self.address,
prev: self.data.Nonce,
})
self.setNonce(nonce)
}
func (self *StateObject) setNonce(nonce uint64) {
self.data.Nonce = nonce self.data.Nonce = nonce
if self.onDirty != nil { if self.onDirty != nil {
self.onDirty(self.Address()) self.onDirty(self.Address())
@ -322,7 +360,7 @@ func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) {
cb(h, value) cb(h, value)
} }
it := self.trie.Iterator() it := self.getTrie(self.db.db).Iterator()
for it.Next() { for it.Next() {
// ignore cached values // ignore cached values
key := common.BytesToHash(self.trie.GetKey(it.Key)) key := common.BytesToHash(self.trie.GetKey(it.Key))

View File

@ -46,8 +46,8 @@ func (s *StateSuite) TestDump(c *checker.C) {
obj3.SetBalance(big.NewInt(44)) obj3.SetBalance(big.NewInt(44))
// write some of them to the trie // write some of them to the trie
s.state.UpdateStateObject(obj1) s.state.updateStateObject(obj1)
s.state.UpdateStateObject(obj2) s.state.updateStateObject(obj2)
s.state.Commit() s.state.Commit()
// check that dump contains the state objects that are in trie // check that dump contains the state objects that are in trie
@ -116,12 +116,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) {
// set initial state object value // set initial state object value
s.state.SetState(stateobjaddr, storageaddr, data1) s.state.SetState(stateobjaddr, storageaddr, data1)
// get snapshot of current state // get snapshot of current state
snapshot := s.state.Copy() snapshot := s.state.Snapshot()
// set new state object value // set new state object value
s.state.SetState(stateobjaddr, storageaddr, data2) s.state.SetState(stateobjaddr, storageaddr, data2)
// restore snapshot // restore snapshot
s.state.Set(snapshot) s.state.RevertToSnapshot(snapshot)
// get state storage value // get state storage value
res := s.state.GetState(stateobjaddr, storageaddr) res := s.state.GetState(stateobjaddr, storageaddr)
@ -129,6 +129,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) {
c.Assert(data1, checker.DeepEquals, res) c.Assert(data1, checker.DeepEquals, res)
} }
func TestSnapshotEmpty(t *testing.T) {
db, _ := ethdb.NewMemDatabase()
state, _ := New(common.Hash{}, db)
state.RevertToSnapshot(state.Snapshot())
}
// use testing instead of checker because checker does not support // use testing instead of checker because checker does not support
// printing/logging in tests (-check.vv does not work) // printing/logging in tests (-check.vv does not work)
func TestSnapshot2(t *testing.T) { func TestSnapshot2(t *testing.T) {
@ -150,9 +156,9 @@ func TestSnapshot2(t *testing.T) {
so0.SetBalance(big.NewInt(42)) so0.SetBalance(big.NewInt(42))
so0.SetNonce(43) so0.SetNonce(43)
so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'}) so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'})
so0.remove = false so0.suicided = false
so0.deleted = false so0.deleted = false
state.SetStateObject(so0) state.setStateObject(so0)
root, _ := state.Commit() root, _ := state.Commit()
state.Reset(root) state.Reset(root)
@ -162,17 +168,17 @@ func TestSnapshot2(t *testing.T) {
so1.SetBalance(big.NewInt(52)) so1.SetBalance(big.NewInt(52))
so1.SetNonce(53) so1.SetNonce(53)
so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'}) so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'})
so1.remove = true so1.suicided = true
so1.deleted = true so1.deleted = true
state.SetStateObject(so1) state.setStateObject(so1)
so1 = state.GetStateObject(stateobjaddr1) so1 = state.GetStateObject(stateobjaddr1)
if so1 != nil { if so1 != nil {
t.Fatalf("deleted object not nil when getting") t.Fatalf("deleted object not nil when getting")
} }
snapshot := state.Copy() snapshot := state.Snapshot()
state.Set(snapshot) state.RevertToSnapshot(snapshot)
so0Restored := state.GetStateObject(stateobjaddr0) so0Restored := state.GetStateObject(stateobjaddr0)
// Update lazily-loaded values before comparing. // Update lazily-loaded values before comparing.
@ -222,8 +228,8 @@ func compareStateObjects(so0, so1 *StateObject, t *testing.T) {
} }
} }
if so0.remove != so1.remove { if so0.suicided != so1.suicided {
t.Fatalf("Remove mismatch: have %v, want %v", so0.remove, so1.remove) t.Fatalf("suicided mismatch: have %v, want %v", so0.suicided, so1.suicided)
} }
if so0.deleted != so1.deleted { if so0.deleted != so1.deleted {
t.Fatalf("Deleted mismatch: have %v, want %v", so0.deleted, so1.deleted) t.Fatalf("Deleted mismatch: have %v, want %v", so0.deleted, so1.deleted)

View File

@ -20,6 +20,7 @@ package state
import ( import (
"fmt" "fmt"
"math/big" "math/big"
"sort"
"sync" "sync"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
@ -40,12 +41,17 @@ var StartingNonce uint64
const ( const (
// Number of past tries to keep. The arbitrarily chosen value here // Number of past tries to keep. The arbitrarily chosen value here
// is max uncle depth + 1. // is max uncle depth + 1.
maxJournalLength = 8 maxTrieCacheLength = 8
// Number of codehash->size associations to keep. // Number of codehash->size associations to keep.
codeSizeCacheSize = 100000 codeSizeCacheSize = 100000
) )
type revision struct {
id int
journalIndex int
}
// StateDBs within the ethereum protocol are used to store anything // StateDBs within the ethereum protocol are used to store anything
// within the merkle trie. StateDBs take care of caching and storing // within the merkle trie. StateDBs take care of caching and storing
// nested states. It's the general query interface to retrieve: // nested states. It's the general query interface to retrieve:
@ -69,6 +75,12 @@ type StateDB struct {
logs map[common.Hash]vm.Logs logs map[common.Hash]vm.Logs
logSize uint logSize uint
// Journal of state modifications. This is the backbone of
// Snapshot and RevertToSnapshot.
journal journal
validRevisions []revision
nextRevisionId int
lock sync.Mutex lock sync.Mutex
} }
@ -124,12 +136,12 @@ func (self *StateDB) Reset(root common.Hash) error {
self.trie = tr self.trie = tr
self.stateObjects = make(map[common.Address]*StateObject) self.stateObjects = make(map[common.Address]*StateObject)
self.stateObjectsDirty = make(map[common.Address]struct{}) self.stateObjectsDirty = make(map[common.Address]struct{})
self.refund = new(big.Int)
self.thash = common.Hash{} self.thash = common.Hash{}
self.bhash = common.Hash{} self.bhash = common.Hash{}
self.txIndex = 0 self.txIndex = 0
self.logs = make(map[common.Hash]vm.Logs) self.logs = make(map[common.Hash]vm.Logs)
self.logSize = 0 self.logSize = 0
self.clearJournalAndRefund()
return nil return nil
} }
@ -150,7 +162,7 @@ func (self *StateDB) pushTrie(t *trie.SecureTrie) {
self.lock.Lock() self.lock.Lock()
defer self.lock.Unlock() defer self.lock.Unlock()
if len(self.pastTries) >= maxJournalLength { if len(self.pastTries) >= maxTrieCacheLength {
copy(self.pastTries, self.pastTries[1:]) copy(self.pastTries, self.pastTries[1:])
self.pastTries[len(self.pastTries)-1] = t self.pastTries[len(self.pastTries)-1] = t
} else { } else {
@ -165,6 +177,8 @@ func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) {
} }
func (self *StateDB) AddLog(log *vm.Log) { func (self *StateDB) AddLog(log *vm.Log) {
self.journal = append(self.journal, addLogChange{txhash: self.thash})
log.TxHash = self.thash log.TxHash = self.thash
log.BlockHash = self.bhash log.BlockHash = self.bhash
log.TxIndex = uint(self.txIndex) log.TxIndex = uint(self.txIndex)
@ -186,13 +200,12 @@ func (self *StateDB) Logs() vm.Logs {
} }
func (self *StateDB) AddRefund(gas *big.Int) { func (self *StateDB) AddRefund(gas *big.Int) {
self.journal = append(self.journal, refundChange{prev: new(big.Int).Set(self.refund)})
self.refund.Add(self.refund, gas) self.refund.Add(self.refund, gas)
} }
func (self *StateDB) HasAccount(addr common.Address) bool { // Exist reports whether the given account address exists in the state.
return self.GetStateObject(addr) != nil // Notably this also returns true for suicided accounts.
}
func (self *StateDB) Exist(addr common.Address) bool { func (self *StateDB) Exist(addr common.Address) bool {
return self.GetStateObject(addr) != nil return self.GetStateObject(addr) != nil
} }
@ -207,7 +220,6 @@ func (self *StateDB) GetBalance(addr common.Address) *big.Int {
if stateObject != nil { if stateObject != nil {
return stateObject.Balance() return stateObject.Balance()
} }
return common.Big0 return common.Big0
} }
@ -263,10 +275,10 @@ func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash {
return common.Hash{} return common.Hash{}
} }
func (self *StateDB) IsDeleted(addr common.Address) bool { func (self *StateDB) HasSuicided(addr common.Address) bool {
stateObject := self.GetStateObject(addr) stateObject := self.GetStateObject(addr)
if stateObject != nil { if stateObject != nil {
return stateObject.remove return stateObject.suicided
} }
return false return false
} }
@ -282,6 +294,13 @@ func (self *StateDB) AddBalance(addr common.Address, amount *big.Int) {
} }
} }
func (self *StateDB) SetBalance(addr common.Address, amount *big.Int) {
stateObject := self.GetOrNewStateObject(addr)
if stateObject != nil {
stateObject.SetBalance(amount)
}
}
func (self *StateDB) SetNonce(addr common.Address, nonce uint64) { func (self *StateDB) SetNonce(addr common.Address, nonce uint64) {
stateObject := self.GetOrNewStateObject(addr) stateObject := self.GetOrNewStateObject(addr)
if stateObject != nil { if stateObject != nil {
@ -299,27 +318,36 @@ func (self *StateDB) SetCode(addr common.Address, code []byte) {
func (self *StateDB) SetState(addr common.Address, key common.Hash, value common.Hash) { func (self *StateDB) SetState(addr common.Address, key common.Hash, value common.Hash) {
stateObject := self.GetOrNewStateObject(addr) stateObject := self.GetOrNewStateObject(addr)
if stateObject != nil { if stateObject != nil {
stateObject.SetState(key, value) stateObject.SetState(self.db, key, value)
} }
} }
func (self *StateDB) Delete(addr common.Address) bool { // Suicide marks the given account as suicided.
// This clears the account balance.
//
// The account's state object is still available until the state is committed,
// GetStateObject will return a non-nil account after Suicide.
func (self *StateDB) Suicide(addr common.Address) bool {
stateObject := self.GetStateObject(addr) stateObject := self.GetStateObject(addr)
if stateObject != nil { if stateObject == nil {
stateObject.MarkForDeletion() return false
}
self.journal = append(self.journal, suicideChange{
account: &addr,
prev: stateObject.suicided,
prevbalance: new(big.Int).Set(stateObject.Balance()),
})
stateObject.markSuicided()
stateObject.data.Balance = new(big.Int) stateObject.data.Balance = new(big.Int)
return true return true
} }
return false
}
// //
// Setting, updating & deleting state object methods // Setting, updating & deleting state object methods
// //
// Update the given state object and apply it to state trie // updateStateObject writes the given object to the trie.
func (self *StateDB) UpdateStateObject(stateObject *StateObject) { func (self *StateDB) updateStateObject(stateObject *StateObject) {
addr := stateObject.Address() addr := stateObject.Address()
data, err := rlp.EncodeToBytes(stateObject) data, err := rlp.EncodeToBytes(stateObject)
if err != nil { if err != nil {
@ -328,10 +356,9 @@ func (self *StateDB) UpdateStateObject(stateObject *StateObject) {
self.trie.Update(addr[:], data) self.trie.Update(addr[:], data)
} }
// Delete the given state object and delete it from the state trie // deleteStateObject removes the given object from the state trie.
func (self *StateDB) DeleteStateObject(stateObject *StateObject) { func (self *StateDB) deleteStateObject(stateObject *StateObject) {
stateObject.deleted = true stateObject.deleted = true
addr := stateObject.Address() addr := stateObject.Address()
self.trie.Delete(addr[:]) self.trie.Delete(addr[:])
} }
@ -357,12 +384,12 @@ func (self *StateDB) GetStateObject(addr common.Address) (stateObject *StateObje
return nil return nil
} }
// Insert into the live set. // Insert into the live set.
obj := NewObject(addr, data, self.MarkStateObjectDirty) obj := newObject(self, addr, data, self.MarkStateObjectDirty)
self.SetStateObject(obj) self.setStateObject(obj)
return obj return obj
} }
func (self *StateDB) SetStateObject(object *StateObject) { func (self *StateDB) setStateObject(object *StateObject) {
self.stateObjects[object.Address()] = object self.stateObjects[object.Address()] = object
} }
@ -370,52 +397,55 @@ func (self *StateDB) SetStateObject(object *StateObject) {
func (self *StateDB) GetOrNewStateObject(addr common.Address) *StateObject { func (self *StateDB) GetOrNewStateObject(addr common.Address) *StateObject {
stateObject := self.GetStateObject(addr) stateObject := self.GetStateObject(addr)
if stateObject == nil || stateObject.deleted { if stateObject == nil || stateObject.deleted {
stateObject = self.CreateStateObject(addr) stateObject, _ = self.createObject(addr)
} }
return stateObject return stateObject
} }
// NewStateObject create a state object whether it exist in the trie or not
func (self *StateDB) newStateObject(addr common.Address) *StateObject {
if glog.V(logger.Core) {
glog.Infof("(+) %x\n", addr)
}
obj := NewObject(addr, Account{}, self.MarkStateObjectDirty)
obj.SetNonce(StartingNonce) // sets the object to dirty
self.stateObjects[addr] = obj
return obj
}
// MarkStateObjectDirty adds the specified object to the dirty map to avoid costly // MarkStateObjectDirty adds the specified object to the dirty map to avoid costly
// state object cache iteration to find a handful of modified ones. // state object cache iteration to find a handful of modified ones.
func (self *StateDB) MarkStateObjectDirty(addr common.Address) { func (self *StateDB) MarkStateObjectDirty(addr common.Address) {
self.stateObjectsDirty[addr] = struct{}{} self.stateObjectsDirty[addr] = struct{}{}
} }
// Creates creates a new state object and takes ownership. // createObject creates a new state object. If there is an existing account with
func (self *StateDB) CreateStateObject(addr common.Address) *StateObject { // the given address, it is overwritten and returned as the second return value.
// Get previous (if any) func (self *StateDB) createObject(addr common.Address) (newobj, prev *StateObject) {
so := self.GetStateObject(addr) prev = self.GetStateObject(addr)
// Create a new one newobj = newObject(self, addr, Account{}, self.MarkStateObjectDirty)
newSo := self.newStateObject(addr) newobj.setNonce(StartingNonce) // sets the object to dirty
if prev == nil {
// If it existed set the balance to the new account if glog.V(logger.Core) {
if so != nil { glog.Infof("(+) %x\n", addr)
newSo.data.Balance = so.data.Balance }
} self.journal = append(self.journal, createObjectChange{account: &addr})
} else {
return newSo self.journal = append(self.journal, resetObjectChange{prev: prev})
}
self.setStateObject(newobj)
return newobj, prev
} }
// CreateAccount explicitly creates a state object. If a state object with the address
// already exists the balance is carried over to the new account.
//
// CreateAccount is called during the EVM CREATE operation. The situation might arise that
// a contract does the following:
//
// 1. sends funds to sha(account ++ (nonce + 1))
// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1)
//
// Carrying over the balance ensures that Ether doesn't disappear.
func (self *StateDB) CreateAccount(addr common.Address) vm.Account { func (self *StateDB) CreateAccount(addr common.Address) vm.Account {
return self.CreateStateObject(addr) new, prev := self.createObject(addr)
if prev != nil {
new.setBalance(prev.data.Balance)
}
return new
} }
// // Copy creates a deep, independent copy of the state.
// Setting, copying of the state methods // Snapshots of the copied state cannot be applied to the copy.
//
func (self *StateDB) Copy() *StateDB { func (self *StateDB) Copy() *StateDB {
self.lock.Lock() self.lock.Lock()
defer self.lock.Unlock() defer self.lock.Unlock()
@ -434,7 +464,7 @@ func (self *StateDB) Copy() *StateDB {
} }
// Copy the dirty states and logs // Copy the dirty states and logs
for addr, _ := range self.stateObjectsDirty { for addr, _ := range self.stateObjectsDirty {
state.stateObjects[addr] = self.stateObjects[addr].Copy(self.db, state.MarkStateObjectDirty) state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state, state.MarkStateObjectDirty)
state.stateObjectsDirty[addr] = struct{}{} state.stateObjectsDirty[addr] = struct{}{}
} }
for hash, logs := range self.logs { for hash, logs := range self.logs {
@ -444,21 +474,38 @@ func (self *StateDB) Copy() *StateDB {
return state return state
} }
func (self *StateDB) Set(state *StateDB) { // Snapshot returns an identifier for the current revision of the state.
self.lock.Lock() func (self *StateDB) Snapshot() int {
defer self.lock.Unlock() id := self.nextRevisionId
self.nextRevisionId++
self.db = state.db self.validRevisions = append(self.validRevisions, revision{id, len(self.journal)})
self.trie = state.trie return id
self.pastTries = state.pastTries
self.stateObjects = state.stateObjects
self.stateObjectsDirty = state.stateObjectsDirty
self.codeSizeCache = state.codeSizeCache
self.refund = state.refund
self.logs = state.logs
self.logSize = state.logSize
} }
// RevertToSnapshot reverts all state changes made since the given revision.
func (self *StateDB) RevertToSnapshot(revid int) {
// Find the snapshot in the stack of valid snapshots.
idx := sort.Search(len(self.validRevisions), func(i int) bool {
return self.validRevisions[i].id >= revid
})
if idx == len(self.validRevisions) || self.validRevisions[idx].id != revid {
panic(fmt.Errorf("revision id %v cannot be reverted", revid))
}
snapshot := self.validRevisions[idx].journalIndex
// Replay the journal to undo changes.
for i := len(self.journal) - 1; i >= snapshot; i-- {
self.journal[i].undo(self)
}
self.journal = self.journal[:snapshot]
// Remove invalidated snapshots from the stack.
self.validRevisions = self.validRevisions[:idx]
}
// GetRefund returns the current value of the refund counter.
// The return value must not be modified by the caller and will become
// invalid at the next call to AddRefund.
func (self *StateDB) GetRefund() *big.Int { func (self *StateDB) GetRefund() *big.Int {
return self.refund return self.refund
} }
@ -467,16 +514,17 @@ func (self *StateDB) GetRefund() *big.Int {
// It is called in between transactions to get the root hash that // It is called in between transactions to get the root hash that
// goes into transaction receipts. // goes into transaction receipts.
func (s *StateDB) IntermediateRoot() common.Hash { func (s *StateDB) IntermediateRoot() common.Hash {
s.refund = new(big.Int)
for addr, _ := range s.stateObjectsDirty { for addr, _ := range s.stateObjectsDirty {
stateObject := s.stateObjects[addr] stateObject := s.stateObjects[addr]
if stateObject.remove { if stateObject.suicided {
s.DeleteStateObject(stateObject) s.deleteStateObject(stateObject)
} else { } else {
stateObject.UpdateRoot(s.db) stateObject.updateRoot(s.db)
s.UpdateStateObject(stateObject) s.updateStateObject(stateObject)
} }
} }
// Invalidate journal because reverting across transactions is not allowed.
s.clearJournalAndRefund()
return s.trie.Hash() return s.trie.Hash()
} }
@ -486,15 +534,15 @@ func (s *StateDB) IntermediateRoot() common.Hash {
// DeleteSuicides should not be used for consensus related updates // DeleteSuicides should not be used for consensus related updates
// under any circumstances. // under any circumstances.
func (s *StateDB) DeleteSuicides() { func (s *StateDB) DeleteSuicides() {
// Reset refund so that any used-gas calculations can use // Reset refund so that any used-gas calculations can use this method.
// this method. s.clearJournalAndRefund()
s.refund = new(big.Int)
for addr, _ := range s.stateObjectsDirty { for addr, _ := range s.stateObjectsDirty {
stateObject := s.stateObjects[addr] stateObject := s.stateObjects[addr]
// If the object has been removed by a suicide // If the object has been removed by a suicide
// flag the object as deleted. // flag the object as deleted.
if stateObject.remove { if stateObject.suicided {
stateObject.deleted = true stateObject.deleted = true
} }
delete(s.stateObjectsDirty, addr) delete(s.stateObjectsDirty, addr)
@ -516,15 +564,21 @@ func (s *StateDB) CommitBatch() (root common.Hash, batch ethdb.Batch) {
return root, batch return root, batch
} }
func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) { func (s *StateDB) clearJournalAndRefund() {
s.journal = nil
s.validRevisions = s.validRevisions[:0]
s.refund = new(big.Int) s.refund = new(big.Int)
}
func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) {
defer s.clearJournalAndRefund()
// Commit objects to the trie. // Commit objects to the trie.
for addr, stateObject := range s.stateObjects { for addr, stateObject := range s.stateObjects {
if stateObject.remove { if stateObject.suicided {
// If the object has been removed, don't bother syncing it // If the object has been removed, don't bother syncing it
// and just mark it for deletion in the trie. // and just mark it for deletion in the trie.
s.DeleteStateObject(stateObject) s.deleteStateObject(stateObject)
} else if _, ok := s.stateObjectsDirty[addr]; ok { } else if _, ok := s.stateObjectsDirty[addr]; ok {
// Write any contract code associated with the state object // Write any contract code associated with the state object
if stateObject.code != nil && stateObject.dirtyCode { if stateObject.code != nil && stateObject.dirtyCode {
@ -538,7 +592,7 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error)
return common.Hash{}, err return common.Hash{}, err
} }
// Update the object in the main account trie. // Update the object in the main account trie.
s.UpdateStateObject(stateObject) s.updateStateObject(stateObject)
} }
delete(s.stateObjectsDirty, addr) delete(s.stateObjectsDirty, addr)
} }
@ -549,7 +603,3 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error)
} }
return root, err return root, err
} }
func (self *StateDB) Refunds() *big.Int {
return self.refund
}

View File

@ -17,11 +17,19 @@
package state package state
import ( import (
"bytes"
"encoding/binary"
"fmt"
"math"
"math/big" "math/big"
"math/rand"
"reflect"
"strings"
"testing" "testing"
"testing/quick"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
) )
@ -34,16 +42,16 @@ func TestUpdateLeaks(t *testing.T) {
// Update it with some accounts // Update it with some accounts
for i := byte(0); i < 255; i++ { for i := byte(0); i < 255; i++ {
obj := state.GetOrNewStateObject(common.BytesToAddress([]byte{i})) addr := common.BytesToAddress([]byte{i})
obj.AddBalance(big.NewInt(int64(11 * i))) state.AddBalance(addr, big.NewInt(int64(11*i)))
obj.SetNonce(uint64(42 * i)) state.SetNonce(addr, uint64(42*i))
if i%2 == 0 { if i%2 == 0 {
obj.SetState(common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i})) state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
} }
if i%3 == 0 { if i%3 == 0 {
obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i}) state.SetCode(addr, []byte{i, i, i, i, i})
} }
state.UpdateStateObject(obj) state.IntermediateRoot()
} }
// Ensure that no data was leaked into the database // Ensure that no data was leaked into the database
for _, key := range db.Keys() { for _, key := range db.Keys() {
@ -61,51 +69,38 @@ func TestIntermediateLeaks(t *testing.T) {
transState, _ := New(common.Hash{}, transDb) transState, _ := New(common.Hash{}, transDb)
finalState, _ := New(common.Hash{}, finalDb) finalState, _ := New(common.Hash{}, finalDb)
// Update the states with some objects modify := func(state *StateDB, addr common.Address, i, tweak byte) {
state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
state.SetNonce(addr, uint64(42*i+tweak))
if i%2 == 0 {
state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{})
state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak})
}
if i%3 == 0 {
state.SetCode(addr, []byte{i, i, i, i, i, tweak})
}
}
// Modify the transient state.
for i := byte(0); i < 255; i++ { for i := byte(0); i < 255; i++ {
// Create a new state object with some data into the transition database modify(transState, common.Address{byte(i)}, i, 0)
obj := transState.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
obj.SetBalance(big.NewInt(int64(11 * i)))
obj.SetNonce(uint64(42 * i))
if i%2 == 0 {
obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.BytesToHash([]byte{i, i, i, i, 0}))
} }
if i%3 == 0 { // Write modifications to trie.
obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 0}), []byte{i, i, i, i, i, 0}) transState.IntermediateRoot()
}
transState.UpdateStateObject(obj)
// Overwrite all the data with new values in the transition database // Overwrite all the data with new values in the transient database.
obj.SetBalance(big.NewInt(int64(11*i + 1))) for i := byte(0); i < 255; i++ {
obj.SetNonce(uint64(42*i + 1)) modify(transState, common.Address{byte(i)}, i, 99)
if i%2 == 0 { modify(finalState, common.Address{byte(i)}, i, 99)
obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.Hash{})
obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1}))
} }
if i%3 == 0 {
obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1})
}
transState.UpdateStateObject(obj)
// Create the final state object directly in the final database // Commit and cross check the databases.
obj = finalState.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
obj.SetBalance(big.NewInt(int64(11*i + 1)))
obj.SetNonce(uint64(42*i + 1))
if i%2 == 0 {
obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1}))
}
if i%3 == 0 {
obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1})
}
finalState.UpdateStateObject(obj)
}
if _, err := transState.Commit(); err != nil { if _, err := transState.Commit(); err != nil {
t.Fatalf("failed to commit transition state: %v", err) t.Fatalf("failed to commit transition state: %v", err)
} }
if _, err := finalState.Commit(); err != nil { if _, err := finalState.Commit(); err != nil {
t.Fatalf("failed to commit final state: %v", err) t.Fatalf("failed to commit final state: %v", err)
} }
// Cross check the databases to ensure they are the same
for _, key := range finalDb.Keys() { for _, key := range finalDb.Keys() {
if _, err := transDb.Get(key); err != nil { if _, err := transDb.Get(key); err != nil {
val, _ := finalDb.Get(key) val, _ := finalDb.Get(key)
@ -119,3 +114,243 @@ func TestIntermediateLeaks(t *testing.T) {
} }
} }
} }
func TestSnapshotRandom(t *testing.T) {
config := &quick.Config{MaxCount: 1000}
err := quick.Check((*snapshotTest).run, config)
if cerr, ok := err.(*quick.CheckError); ok {
test := cerr.In[0].(*snapshotTest)
t.Errorf("%v:\n%s", test.err, test)
} else if err != nil {
t.Error(err)
}
}
// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
// captured by the snapshot. Instances of this test with pseudorandom content are created
// by Generate.
//
// The test works as follows:
//
// A new state is created and all actions are applied to it. Several snapshots are taken
// in between actions. The test then reverts each snapshot. For each snapshot the actions
// leading up to it are replayed on a fresh, empty state. The behaviour of all public
// accessor methods on the reverted state must match the return value of the equivalent
// methods on the replayed state.
type snapshotTest struct {
addrs []common.Address // all account addresses
actions []testAction // modifications to the state
snapshots []int // actions indexes at which snapshot is taken
err error // failure details are reported through this field
}
type testAction struct {
name string
fn func(testAction, *StateDB)
args []int64
noAddr bool
}
// newTestAction creates a random action that changes state.
func newTestAction(addr common.Address, r *rand.Rand) testAction {
actions := []testAction{
{
name: "SetBalance",
fn: func(a testAction, s *StateDB) {
s.SetBalance(addr, big.NewInt(a.args[0]))
},
args: make([]int64, 1),
},
{
name: "AddBalance",
fn: func(a testAction, s *StateDB) {
s.AddBalance(addr, big.NewInt(a.args[0]))
},
args: make([]int64, 1),
},
{
name: "SetNonce",
fn: func(a testAction, s *StateDB) {
s.SetNonce(addr, uint64(a.args[0]))
},
args: make([]int64, 1),
},
{
name: "SetState",
fn: func(a testAction, s *StateDB) {
var key, val common.Hash
binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
s.SetState(addr, key, val)
},
args: make([]int64, 2),
},
{
name: "SetCode",
fn: func(a testAction, s *StateDB) {
code := make([]byte, 16)
binary.BigEndian.PutUint64(code, uint64(a.args[0]))
binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
s.SetCode(addr, code)
},
args: make([]int64, 2),
},
{
name: "CreateAccount",
fn: func(a testAction, s *StateDB) {
s.CreateAccount(addr)
},
},
{
name: "Suicide",
fn: func(a testAction, s *StateDB) {
s.Suicide(addr)
},
},
{
name: "AddRefund",
fn: func(a testAction, s *StateDB) {
s.AddRefund(big.NewInt(a.args[0]))
},
args: make([]int64, 1),
noAddr: true,
},
{
name: "AddLog",
fn: func(a testAction, s *StateDB) {
data := make([]byte, 2)
binary.BigEndian.PutUint16(data, uint16(a.args[0]))
s.AddLog(&vm.Log{Address: addr, Data: data})
},
args: make([]int64, 1),
},
}
action := actions[r.Intn(len(actions))]
var nameargs []string
if !action.noAddr {
nameargs = append(nameargs, addr.Hex())
}
for _, i := range action.args {
action.args[i] = rand.Int63n(100)
nameargs = append(nameargs, fmt.Sprint(action.args[i]))
}
action.name += strings.Join(nameargs, ", ")
return action
}
// Generate returns a new snapshot test of the given size. All randomness is
// derived from r.
func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
// Generate random actions.
addrs := make([]common.Address, 50)
for i := range addrs {
addrs[i][0] = byte(i)
}
actions := make([]testAction, size)
for i := range actions {
addr := addrs[r.Intn(len(addrs))]
actions[i] = newTestAction(addr, r)
}
// Generate snapshot indexes.
nsnapshots := int(math.Sqrt(float64(size)))
if size > 0 && nsnapshots == 0 {
nsnapshots = 1
}
snapshots := make([]int, nsnapshots)
snaplen := len(actions) / nsnapshots
for i := range snapshots {
// Try to place the snapshots some number of actions apart from each other.
snapshots[i] = (i * snaplen) + r.Intn(snaplen)
}
return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
}
func (test *snapshotTest) String() string {
out := new(bytes.Buffer)
sindex := 0
for i, action := range test.actions {
if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
sindex++
}
fmt.Fprintf(out, "%4d: %s\n", i, action.name)
}
return out.String()
}
func (test *snapshotTest) run() bool {
// Run all actions and create snapshots.
var (
db, _ = ethdb.NewMemDatabase()
state, _ = New(common.Hash{}, db)
snapshotRevs = make([]int, len(test.snapshots))
sindex = 0
)
for i, action := range test.actions {
if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
snapshotRevs[sindex] = state.Snapshot()
sindex++
}
action.fn(action, state)
}
// Revert all snapshots in reverse order. Each revert must yield a state
// that is equivalent to fresh state with all actions up the snapshot applied.
for sindex--; sindex >= 0; sindex-- {
checkstate, _ := New(common.Hash{}, db)
for _, action := range test.actions[:test.snapshots[sindex]] {
action.fn(action, checkstate)
}
state.RevertToSnapshot(snapshotRevs[sindex])
if err := test.checkEqual(state, checkstate); err != nil {
test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
return false
}
}
return true
}
// checkEqual checks that methods of state and checkstate return the same values.
func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
for _, addr := range test.addrs {
var err error
checkeq := func(op string, a, b interface{}) bool {
if err == nil && !reflect.DeepEqual(a, b) {
err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
return false
}
return true
}
// Check basic accessor methods.
checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr))
checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
// Check storage.
if obj := state.GetStateObject(addr); obj != nil {
obj.ForEachStorage(func(key, val common.Hash) bool {
return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key))
})
checkobj := checkstate.GetStateObject(addr)
checkobj.ForEachStorage(func(key, checkval common.Hash) bool {
return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval)
})
}
if err != nil {
return err
}
}
if state.GetRefund().Cmp(checkstate.GetRefund()) != 0 {
return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
state.GetRefund(), checkstate.GetRefund())
}
if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) {
return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{}))
}
return nil
}

View File

@ -57,7 +57,7 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) {
obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i}) obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i})
acc.code = []byte{i, i, i, i, i} acc.code = []byte{i, i, i, i, i}
} }
state.UpdateStateObject(obj) state.updateStateObject(obj)
accounts = append(accounts, acc) accounts = append(accounts, acc)
} }
root, _ := state.Commit() root, _ := state.Commit()

View File

@ -257,7 +257,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction) error {
// Make sure the account exist. Non existent accounts // Make sure the account exist. Non existent accounts
// haven't got funds and well therefor never pass. // haven't got funds and well therefor never pass.
if !currentState.HasAccount(from) { if !currentState.Exist(from) {
return ErrNonExistentAccount return ErrNonExistentAccount
} }

View File

@ -36,9 +36,9 @@ type Environment interface {
// The state database // The state database
Db() Database Db() Database
// Creates a restorable snapshot // Creates a restorable snapshot
MakeSnapshot() Database SnapshotDatabase() int
// Set database to previous snapshot // Set database to previous snapshot
SetSnapshot(Database) RevertToSnapshot(int)
// Address of the original invoker (first occurrence of the VM invoker) // Address of the original invoker (first occurrence of the VM invoker)
Origin() common.Address Origin() common.Address
// The block number this VM is invoked on // The block number this VM is invoked on
@ -105,9 +105,12 @@ type Database interface {
GetState(common.Address, common.Hash) common.Hash GetState(common.Address, common.Hash) common.Hash
SetState(common.Address, common.Hash, common.Hash) SetState(common.Address, common.Hash, common.Hash)
Delete(common.Address) bool Suicide(common.Address) bool
HasSuicided(common.Address) bool
// Exist reports whether the given account exists in state.
// Notably this should also return true for suicided accounts.
Exist(common.Address) bool Exist(common.Address) bool
IsDeleted(common.Address) bool
} }
// Account represents a contract or basic ethereum account. // Account represents a contract or basic ethereum account.

View File

@ -614,7 +614,7 @@ func opSuicide(instr instruction, pc *uint64, env Environment, contract *Contrac
balance := env.Db().GetBalance(contract.Address()) balance := env.Db().GetBalance(contract.Address())
env.Db().AddBalance(common.BigToAddress(stack.pop()), balance) env.Db().AddBalance(common.BigToAddress(stack.pop()), balance)
env.Db().Delete(contract.Address()) env.Db().Suicide(contract.Address())
} }
// following functions are used by the instruction jump table // following functions are used by the instruction jump table

View File

@ -425,7 +425,7 @@ func jitCalculateGasAndSize(env Environment, contract *Contract, instr instructi
} }
gas.Set(g) gas.Set(g)
case SUICIDE: case SUICIDE:
if !statedb.IsDeleted(contract.Address()) { if !statedb.HasSuicided(contract.Address()) {
statedb.AddRefund(params.SuicideRefundGas) statedb.AddRefund(params.SuicideRefundGas)
} }
case MLOAD: case MLOAD:

View File

@ -179,8 +179,8 @@ func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) }
//func (self *Env) PrevHash() []byte { return self.parent } //func (self *Env) PrevHash() []byte { return self.parent }
func (self *Env) Coinbase() common.Address { return common.Address{} } func (self *Env) Coinbase() common.Address { return common.Address{} }
func (self *Env) MakeSnapshot() Database { return nil } func (self *Env) SnapshotDatabase() int { return 0 }
func (self *Env) SetSnapshot(Database) {} func (self *Env) RevertToSnapshot(int) {}
func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) }
func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } func (self *Env) Difficulty() *big.Int { return big.NewInt(0) }
func (self *Env) Db() Database { return nil } func (self *Env) Db() Database { return nil }

View File

@ -86,11 +86,11 @@ func (self *Env) SetDepth(i int) { self.depth = i }
func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool { func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0 return self.state.GetBalance(from).Cmp(balance) >= 0
} }
func (self *Env) MakeSnapshot() vm.Database { func (self *Env) SnapshotDatabase() int {
return self.state.Copy() return self.state.Snapshot()
} }
func (self *Env) SetSnapshot(copy vm.Database) { func (self *Env) RevertToSnapshot(snapshot int) {
self.state.Set(copy.(*state.StateDB)) self.state.RevertToSnapshot(snapshot)
} }
func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { func (self *Env) Transfer(from, to vm.Account, amount *big.Int) {

View File

@ -303,7 +303,7 @@ func calculateGasAndSize(env Environment, contract *Contract, caller ContractRef
} }
gas.Set(g) gas.Set(g)
case SUICIDE: case SUICIDE:
if !statedb.IsDeleted(contract.Address()) { if !statedb.HasSuicided(contract.Address()) {
statedb.AddRefund(params.SuicideRefundGas) statedb.AddRefund(params.SuicideRefundGas)
} }
case MLOAD: case MLOAD:

View File

@ -89,12 +89,12 @@ func (self *VMEnv) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0 return self.state.GetBalance(from).Cmp(balance) >= 0
} }
func (self *VMEnv) MakeSnapshot() vm.Database { func (self *VMEnv) SnapshotDatabase() int {
return self.state.Copy() return self.state.Snapshot()
} }
func (self *VMEnv) SetSnapshot(copy vm.Database) { func (self *VMEnv) RevertToSnapshot(snapshot int) {
self.state.Set(copy.(*state.StateDB)) self.state.RevertToSnapshot(snapshot)
} }
func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) { func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) {

View File

@ -98,12 +98,12 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int {
} }
func (b *EthApiBackend) GetVMEnv(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header) (vm.Environment, func() error, error) { func (b *EthApiBackend) GetVMEnv(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header) (vm.Environment, func() error, error) {
stateDb := state.(EthApiState).state.Copy() statedb := state.(EthApiState).state
addr, _ := msg.From() addr, _ := msg.From()
from := stateDb.GetOrNewStateObject(addr) from := statedb.GetOrNewStateObject(addr)
from.SetBalance(common.MaxBig) from.SetBalance(common.MaxBig)
vmError := func() error { return nil } vmError := func() error { return nil }
return core.NewEnv(stateDb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil return core.NewEnv(statedb, b.eth.chainConfig, b.eth.blockchain, msg, header, b.eth.chainConfig.VmConfig), vmError, nil
} }
func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error {

View File

@ -51,8 +51,8 @@ func (self *Env) BlockNumber() *big.Int { return big.NewInt(0) }
//func (self *Env) PrevHash() []byte { return self.parent } //func (self *Env) PrevHash() []byte { return self.parent }
func (self *Env) Coinbase() common.Address { return common.Address{} } func (self *Env) Coinbase() common.Address { return common.Address{} }
func (self *Env) MakeSnapshot() vm.Database { return nil } func (self *Env) SnapshotDatabase() int { return 0 }
func (self *Env) SetSnapshot(vm.Database) {} func (self *Env) RevertToSnapshot(int) {}
func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) } func (self *Env) Time() *big.Int { return big.NewInt(time.Now().Unix()) }
func (self *Env) Difficulty() *big.Int { return big.NewInt(0) } func (self *Env) Difficulty() *big.Int { return big.NewInt(0) }
func (self *Env) Db() vm.Database { return nil } func (self *Env) Db() vm.Database { return nil }

View File

@ -23,7 +23,6 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -54,16 +53,13 @@ func makeTestState() (common.Hash, ethdb.Database) {
sdb, _ := ethdb.NewMemDatabase() sdb, _ := ethdb.NewMemDatabase()
st, _ := state.New(common.Hash{}, sdb) st, _ := state.New(common.Hash{}, sdb)
for i := byte(0); i < 100; i++ { for i := byte(0); i < 100; i++ {
so := st.GetOrNewStateObject(common.Address{i}) addr := common.Address{i}
for j := byte(0); j < 100; j++ { for j := byte(0); j < 100; j++ {
val := common.Hash{i, j} st.SetState(addr, common.Hash{j}, common.Hash{i, j})
so.SetState(common.Hash{j}, val)
so.SetNonce(100)
} }
so.AddBalance(big.NewInt(int64(i))) st.SetNonce(addr, 100)
so.SetCode(crypto.Keccak256Hash([]byte{i, i, i}), []byte{i, i, i}) st.AddBalance(addr, big.NewInt(int64(i)))
so.UpdateRoot(sdb) st.SetCode(addr, []byte{i, i, i})
st.UpdateStateObject(so)
} }
root, _ := st.Commit() root, _ := st.Commit()
return root, sdb return root, sdb

View File

@ -171,7 +171,7 @@ func (self *worker) pending() (*types.Block, *state.StateDB) {
self.current.receipts, self.current.receipts,
), self.current.state ), self.current.state
} }
return self.current.Block, self.current.state return self.current.Block, self.current.state.Copy()
} }
func (self *worker) start() { func (self *worker) start() {
@ -618,7 +618,7 @@ func (env *Work) commitTransactions(mux *event.TypeMux, txs *types.TransactionsB
} }
func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, gp *core.GasPool) (error, vm.Logs) { func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, gp *core.GasPool) (error, vm.Logs) {
snap := env.state.Copy() snap := env.state.Snapshot()
// this is a bit of a hack to force jit for the miners // this is a bit of a hack to force jit for the miners
config := env.config.VmConfig config := env.config.VmConfig
@ -629,7 +629,7 @@ func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, g
receipt, logs, _, err := core.ApplyTransaction(env.config, bc, gp, env.state, env.header, tx, env.header.GasUsed, config) receipt, logs, _, err := core.ApplyTransaction(env.config, bc, gp, env.state, env.header, tx, env.header.GasUsed, config)
if err != nil { if err != nil {
env.state.Set(snap) env.state.RevertToSnapshot(snap)
return err, nil return err, nil
} }
env.txs = append(env.txs, tx) env.txs = append(env.txs, tx)

View File

@ -95,14 +95,7 @@ func BenchStateTest(ruleSet RuleSet, p string, conf bconf, b *testing.B) error {
func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *testing.B) { func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *testing.B) {
b.StopTimer() b.StopTimer()
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb := makePreState(db, test.Pre)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
b.StartTimer() b.StartTimer()
RunState(ruleSet, statedb, env, test.Exec) RunState(ruleSet, statedb, env, test.Exec)
@ -134,14 +127,7 @@ func runStateTests(ruleSet RuleSet, tests map[string]VmTest, skipTests []string)
func runStateTest(ruleSet RuleSet, test VmTest) error { func runStateTest(ruleSet RuleSet, test VmTest) error {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb := makePreState(db, test.Pre)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
// XXX Yeah, yeah... // XXX Yeah, yeah...
env := make(map[string]string) env := make(map[string]string)
@ -227,7 +213,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string
} }
// Set pre compiled contracts // Set pre compiled contracts
vm.Precompiled = vm.PrecompiledContracts() vm.Precompiled = vm.PrecompiledContracts()
snapshot := statedb.Copy() snapshot := statedb.Snapshot()
gaspool := new(core.GasPool).AddGas(common.Big(env["currentGasLimit"])) gaspool := new(core.GasPool).AddGas(common.Big(env["currentGasLimit"]))
key, _ := hex.DecodeString(tx["secretKey"]) key, _ := hex.DecodeString(tx["secretKey"])
@ -237,7 +223,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string
vmenv.origin = addr vmenv.origin = addr
ret, _, err := core.ApplyMessage(vmenv, message, gaspool) ret, _, err := core.ApplyMessage(vmenv, message, gaspool)
if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || core.IsGasLimitErr(err) { if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || core.IsGasLimitErr(err) {
statedb.Set(snapshot) statedb.RevertToSnapshot(snapshot)
} }
statedb.Commit() statedb.Commit()

View File

@ -103,19 +103,25 @@ func (self Log) Topics() [][]byte {
return t return t
} }
func StateObjectFromAccount(db ethdb.Database, addr string, account Account, onDirty func(common.Address)) *state.StateObject { func makePreState(db ethdb.Database, accounts map[string]Account) *state.StateDB {
statedb, _ := state.New(common.Hash{}, db)
for addr, account := range accounts {
insertAccount(statedb, addr, account)
}
return statedb
}
func insertAccount(state *state.StateDB, saddr string, account Account) {
if common.IsHex(account.Code) { if common.IsHex(account.Code) {
account.Code = account.Code[2:] account.Code = account.Code[2:]
} }
code := common.Hex2Bytes(account.Code) addr := common.HexToAddress(saddr)
codeHash := crypto.Keccak256Hash(code) state.SetCode(addr, common.Hex2Bytes(account.Code))
obj := state.NewObject(common.HexToAddress(addr), state.Account{ state.SetNonce(addr, common.Big(account.Nonce).Uint64())
Balance: common.Big(account.Balance), state.SetBalance(addr, common.Big(account.Balance))
CodeHash: codeHash[:], for a, v := range account.Storage {
Nonce: common.Big(account.Nonce).Uint64(), state.SetState(addr, common.HexToHash(a), common.HexToHash(v))
}, onDirty) }
obj.SetCode(codeHash, code)
return obj
} }
type VmEnv struct { type VmEnv struct {
@ -229,11 +235,11 @@ func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool {
return self.state.GetBalance(from).Cmp(balance) >= 0 return self.state.GetBalance(from).Cmp(balance) >= 0
} }
func (self *Env) MakeSnapshot() vm.Database { func (self *Env) SnapshotDatabase() int {
return self.state.Copy() return self.state.Snapshot()
} }
func (self *Env) SetSnapshot(copy vm.Database) { func (self *Env) RevertToSnapshot(snapshot int) {
self.state.Set(copy.(*state.StateDB)) self.state.RevertToSnapshot(snapshot)
} }
func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { func (self *Env) Transfer(from, to vm.Account, amount *big.Int) {

View File

@ -101,14 +101,7 @@ func BenchVmTest(p string, conf bconf, b *testing.B) error {
func benchVmTest(test VmTest, env map[string]string, b *testing.B) { func benchVmTest(test VmTest, env map[string]string, b *testing.B) {
b.StopTimer() b.StopTimer()
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb := makePreState(db, test.Pre)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
b.StartTimer() b.StartTimer()
RunVm(statedb, env, test.Exec) RunVm(statedb, env, test.Exec)
@ -152,14 +145,7 @@ func runVmTests(tests map[string]VmTest, skipTests []string) error {
func runVmTest(test VmTest) error { func runVmTest(test VmTest) error {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb := makePreState(db, test.Pre)
for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj)
for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v))
}
}
// XXX Yeah, yeah... // XXX Yeah, yeah...
env := make(map[string]string) env := make(map[string]string)