core/state: track all accounts in canon state

This change introduces a global, per-state cache that keeps account data
in the canon state. Thanks to @karalabe for lots of fixes.
This commit is contained in:
Felix Lange 2016-09-22 21:04:58 +02:00
parent e859f36967
commit a59a93f476
17 changed files with 417 additions and 339 deletions

View File

@ -135,12 +135,9 @@ func (b *SimulatedBackend) StorageAt(ctx context.Context, contract common.Addres
return nil, errBlockNumberUnsupported return nil, errBlockNumberUnsupported
} }
statedb, _ := b.blockchain.State() statedb, _ := b.blockchain.State()
if obj := statedb.GetStateObject(contract); obj != nil { val := statedb.GetState(contract, key)
val := obj.GetState(key)
return val[:], nil return val[:], nil
} }
return nil, nil
}
// TransactionReceipt returns the receipt of a transaction. // TransactionReceipt returns the receipt of a transaction.
func (b *SimulatedBackend) TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) { func (b *SimulatedBackend) TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) {

View File

@ -93,6 +93,7 @@ type BlockChain struct {
currentBlock *types.Block // Current head of the block chain currentBlock *types.Block // Current head of the block chain
currentFastBlock *types.Block // Current head of the fast-sync chain (may be above the block chain!) currentFastBlock *types.Block // Current head of the fast-sync chain (may be above the block chain!)
stateCache *state.StateDB // State database to reuse between imports (contains state cache)
bodyCache *lru.Cache // Cache for the most recent block bodies bodyCache *lru.Cache // Cache for the most recent block bodies
bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format
blockCache *lru.Cache // Cache for the most recent entire blocks blockCache *lru.Cache // Cache for the most recent entire blocks
@ -196,7 +197,15 @@ func (self *BlockChain) loadLastState() error {
self.currentFastBlock = block self.currentFastBlock = block
} }
} }
// Issue a status log and return // Initialize a statedb cache to ensure singleton account bloom filter generation
statedb, err := state.New(self.currentBlock.Root(), self.chainDb)
if err != nil {
return err
}
self.stateCache = statedb
self.stateCache.GetAccount(common.Address{})
// Issue a status log for the user
headerTd := self.GetTd(currentHeader.Hash(), currentHeader.Number.Uint64()) headerTd := self.GetTd(currentHeader.Hash(), currentHeader.Number.Uint64())
blockTd := self.GetTd(self.currentBlock.Hash(), self.currentBlock.NumberU64()) blockTd := self.GetTd(self.currentBlock.Hash(), self.currentBlock.NumberU64())
fastTd := self.GetTd(self.currentFastBlock.Hash(), self.currentFastBlock.NumberU64()) fastTd := self.GetTd(self.currentFastBlock.Hash(), self.currentFastBlock.NumberU64())
@ -826,7 +835,6 @@ func (self *BlockChain) InsertChain(chain types.Blocks) (int, error) {
tstart = time.Now() tstart = time.Now()
nonceChecked = make([]bool, len(chain)) nonceChecked = make([]bool, len(chain))
statedb *state.StateDB
) )
// Start the parallel nonce verifier. // Start the parallel nonce verifier.
@ -893,29 +901,30 @@ func (self *BlockChain) InsertChain(chain types.Blocks) (int, error) {
// Create a new statedb using the parent block and report an // Create a new statedb using the parent block and report an
// error if it fails. // error if it fails.
if statedb == nil { switch {
statedb, err = state.New(self.GetBlock(block.ParentHash(), block.NumberU64()-1).Root(), self.chainDb) case i == 0:
} else { err = self.stateCache.Reset(self.GetBlock(block.ParentHash(), block.NumberU64()-1).Root())
err = statedb.Reset(chain[i-1].Root()) default:
err = self.stateCache.Reset(chain[i-1].Root())
} }
if err != nil { if err != nil {
reportBlock(block, err) reportBlock(block, err)
return i, err return i, err
} }
// Process block using the parent state as reference point. // Process block using the parent state as reference point.
receipts, logs, usedGas, err := self.processor.Process(block, statedb, self.config.VmConfig) receipts, logs, usedGas, err := self.processor.Process(block, self.stateCache, self.config.VmConfig)
if err != nil { if err != nil {
reportBlock(block, err) reportBlock(block, err)
return i, err return i, err
} }
// Validate the state using the default validator // Validate the state using the default validator
err = self.Validator().ValidateState(block, self.GetBlock(block.ParentHash(), block.NumberU64()-1), statedb, receipts, usedGas) err = self.Validator().ValidateState(block, self.GetBlock(block.ParentHash(), block.NumberU64()-1), self.stateCache, receipts, usedGas)
if err != nil { if err != nil {
reportBlock(block, err) reportBlock(block, err)
return i, err return i, err
} }
// Write state changes to database // Write state changes to database
_, err = statedb.Commit() _, err = self.stateCache.Commit()
if err != nil { if err != nil {
return i, err return i, err
} }

View File

@ -79,7 +79,7 @@ func ExampleGenerateChain() {
evmux := &event.TypeMux{} evmux := &event.TypeMux{}
blockchain, _ := NewBlockChain(db, MakeChainConfig(), FakePow{}, evmux) blockchain, _ := NewBlockChain(db, MakeChainConfig(), FakePow{}, evmux)
if i, err := blockchain.InsertChain(chain); err != nil { if i, err := blockchain.InsertChain(chain); err != nil {
fmt.Printf("insert error (block %d): %v\n", i, err) fmt.Printf("insert error (block %d): %v\n", chain[i].NumberU64(), err)
return return
} }

View File

@ -21,9 +21,10 @@ import (
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp"
) )
type Account struct { type DumpAccount struct {
Balance string `json:"balance"` Balance string `json:"balance"`
Nonce uint64 `json:"nonce"` Nonce uint64 `json:"nonce"`
Root string `json:"root"` Root string `json:"root"`
@ -32,40 +33,41 @@ type Account struct {
Storage map[string]string `json:"storage"` Storage map[string]string `json:"storage"`
} }
type World struct { type Dump struct {
Root string `json:"root"` Root string `json:"root"`
Accounts map[string]Account `json:"accounts"` Accounts map[string]DumpAccount `json:"accounts"`
} }
func (self *StateDB) RawDump() World { func (self *StateDB) RawDump() Dump {
world := World{ dump := Dump{
Root: common.Bytes2Hex(self.trie.Root()), Root: common.Bytes2Hex(self.trie.Root()),
Accounts: make(map[string]Account), Accounts: make(map[string]DumpAccount),
} }
it := self.trie.Iterator() it := self.trie.Iterator()
for it.Next() { for it.Next() {
addr := self.trie.GetKey(it.Key) addr := self.trie.GetKey(it.Key)
stateObject, err := DecodeObject(common.BytesToAddress(addr), self.db, it.Value) var data Account
if err != nil { if err := rlp.DecodeBytes(it.Value, &data); err != nil {
panic(err) panic(err)
} }
account := Account{ obj := NewObject(common.BytesToAddress(addr), data, nil)
Balance: stateObject.balance.String(), account := DumpAccount{
Nonce: stateObject.nonce, Balance: data.Balance.String(),
Root: common.Bytes2Hex(stateObject.Root()), Nonce: data.Nonce,
CodeHash: common.Bytes2Hex(stateObject.codeHash), Root: common.Bytes2Hex(data.Root[:]),
Code: common.Bytes2Hex(stateObject.Code()), CodeHash: common.Bytes2Hex(data.CodeHash),
Code: common.Bytes2Hex(obj.Code(self.db)),
Storage: make(map[string]string), Storage: make(map[string]string),
} }
storageIt := stateObject.trie.Iterator() storageIt := obj.getTrie(self.db).Iterator()
for storageIt.Next() { for storageIt.Next() {
account.Storage[common.Bytes2Hex(self.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value) account.Storage[common.Bytes2Hex(self.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value)
} }
world.Accounts[common.Bytes2Hex(addr)] = account dump.Accounts[common.Bytes2Hex(addr)] = account
} }
return world return dump
} }
func (self *StateDB) Dump() []byte { func (self *StateDB) Dump() []byte {
@ -76,12 +78,3 @@ func (self *StateDB) Dump() []byte {
return json return json
} }
// Debug stuff
func (self *StateObject) CreateOutputForDiff() {
fmt.Printf("%x %x %x %x\n", self.Address(), self.Root(), self.balance.Bytes(), self.nonce)
it := self.trie.Iterator()
for it.Next() {
fmt.Printf("%x %x\n", it.Key, it.Value)
}
}

View File

@ -33,14 +33,14 @@ type ManagedState struct {
mu sync.RWMutex mu sync.RWMutex
accounts map[string]*account accounts map[common.Address]*account
} }
// ManagedState returns a new managed state with the statedb as it's backing layer // ManagedState returns a new managed state with the statedb as it's backing layer
func ManageState(statedb *StateDB) *ManagedState { func ManageState(statedb *StateDB) *ManagedState {
return &ManagedState{ return &ManagedState{
StateDB: statedb.Copy(), StateDB: statedb.Copy(),
accounts: make(map[string]*account), accounts: make(map[common.Address]*account),
} }
} }
@ -103,7 +103,7 @@ func (ms *ManagedState) SetNonce(addr common.Address, nonce uint64) {
so := ms.GetOrNewStateObject(addr) so := ms.GetOrNewStateObject(addr)
so.SetNonce(nonce) so.SetNonce(nonce)
ms.accounts[addr.Str()] = newAccount(so) ms.accounts[addr] = newAccount(so)
} }
// HasAccount returns whether the given address is managed or not // HasAccount returns whether the given address is managed or not
@ -114,29 +114,28 @@ func (ms *ManagedState) HasAccount(addr common.Address) bool {
} }
func (ms *ManagedState) hasAccount(addr common.Address) bool { func (ms *ManagedState) hasAccount(addr common.Address) bool {
_, ok := ms.accounts[addr.Str()] _, ok := ms.accounts[addr]
return ok return ok
} }
// populate the managed state // populate the managed state
func (ms *ManagedState) getAccount(addr common.Address) *account { func (ms *ManagedState) getAccount(addr common.Address) *account {
straddr := addr.Str() if account, ok := ms.accounts[addr]; !ok {
if account, ok := ms.accounts[straddr]; !ok {
so := ms.GetOrNewStateObject(addr) so := ms.GetOrNewStateObject(addr)
ms.accounts[straddr] = newAccount(so) ms.accounts[addr] = newAccount(so)
} else { } else {
// Always make sure the state account nonce isn't actually higher // Always make sure the state account nonce isn't actually higher
// than the tracked one. // than the tracked one.
so := ms.StateDB.GetStateObject(addr) so := ms.StateDB.GetStateObject(addr)
if so != nil && uint64(len(account.nonces))+account.nstart < so.nonce { if so != nil && uint64(len(account.nonces))+account.nstart < so.Nonce() {
ms.accounts[straddr] = newAccount(so) ms.accounts[addr] = newAccount(so)
} }
} }
return ms.accounts[straddr] return ms.accounts[addr]
} }
func newAccount(so *StateObject) *account { func newAccount(so *StateObject) *account {
return &account{so, so.nonce, nil} return &account{so, so.Nonce(), nil}
} }

View File

@ -29,11 +29,12 @@ func create() (*ManagedState, *account) {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := New(common.Hash{}, db) statedb, _ := New(common.Hash{}, db)
ms := ManageState(statedb) ms := ManageState(statedb)
so := &StateObject{address: addr, nonce: 100} so := &StateObject{address: addr}
ms.StateDB.stateObjects[addr.Str()] = so so.SetNonce(100)
ms.accounts[addr.Str()] = newAccount(so) ms.StateDB.stateObjects[addr] = so
ms.accounts[addr] = newAccount(so)
return ms, ms.accounts[addr.Str()] return ms, ms.accounts[addr]
} }
func TestNewNonce(t *testing.T) { func TestNewNonce(t *testing.T) {
@ -92,7 +93,7 @@ func TestRemoteNonceChange(t *testing.T) {
account.nonces = append(account.nonces, nn...) account.nonces = append(account.nonces, nn...)
nonce := ms.NewNonce(addr) nonce := ms.NewNonce(addr)
ms.StateDB.stateObjects[addr.Str()].nonce = 200 ms.StateDB.stateObjects[addr].data.Nonce = 200
nonce = ms.NewNonce(addr) nonce = ms.NewNonce(addr)
if nonce != 200 { if nonce != 200 {
t.Error("expected nonce after remote update to be", 201, "got", nonce) t.Error("expected nonce after remote update to be", 201, "got", nonce)
@ -100,7 +101,7 @@ func TestRemoteNonceChange(t *testing.T) {
ms.NewNonce(addr) ms.NewNonce(addr)
ms.NewNonce(addr) ms.NewNonce(addr)
ms.NewNonce(addr) ms.NewNonce(addr)
ms.StateDB.stateObjects[addr.Str()].nonce = 200 ms.StateDB.stateObjects[addr].data.Nonce = 200
nonce = ms.NewNonce(addr) nonce = ms.NewNonce(addr)
if nonce != 204 { if nonce != 204 {
t.Error("expected nonce after remote update to be", 201, "got", nonce) t.Error("expected nonce after remote update to be", 201, "got", nonce)

View File

@ -57,108 +57,163 @@ func (self Storage) Copy() Storage {
return cpy return cpy
} }
// StateObject represents an Ethereum account which is being modified.
//
// The usage pattern is as follows:
// First you need to obtain a state object.
// Account values can be accessed and modified through the object.
// Finally, call CommitTrie to write the modified storage trie into a database.
type StateObject struct { type StateObject struct {
db trie.Database // State database for storing state changes address common.Address // Ethereum address of this account
trie *trie.SecureTrie data Account
// Address belonging to this account // DB error.
address common.Address // State objects are used by the consensus core and VM which are
// The balance of the account // unable to deal with database-level errors. Any error that occurs
balance *big.Int // during a database read is memoized here and will eventually be returned
// The nonce of the account // by StateDB.Commit.
nonce uint64 dbErr error
// The code hash if code is present (i.e. a contract)
codeHash []byte
// The code for this account
code Code
// Cached storage (flushed when updated)
storage Storage
// Mark for deletion // Write caches.
trie *trie.SecureTrie // storage trie, which becomes non-nil on first access
code Code // contract bytecode, which gets set when code is loaded
storage Storage // Cached storage (flushed when updated)
// Cache flags.
// When an object is marked for deletion it will be delete from the trie // When an object is marked for deletion it will be delete from the trie
// during the "update" phase of the state transition // during the "update" phase of the state transition
dirtyCode bool // true if the code was updated
remove bool remove bool
deleted bool deleted bool
dirty bool onDirty func(addr common.Address) // Callback method to mark a state object newly dirty
} }
func NewStateObject(address common.Address, db trie.Database) *StateObject { // Account is the Ethereum consensus representation of accounts.
object := &StateObject{ // These objects are stored in the main account trie.
db: db, type Account struct {
address: address, Nonce uint64
balance: new(big.Int), Balance *big.Int
dirty: true, Root common.Hash // merkle root of the storage trie
codeHash: emptyCodeHash, CodeHash []byte
storage: make(Storage),
codeSize *int
}
// NewObject creates a state object.
func NewObject(address common.Address, data Account, onDirty func(addr common.Address)) *StateObject {
if data.Balance == nil {
data.Balance = new(big.Int)
}
if data.CodeHash == nil {
data.CodeHash = emptyCodeHash
}
return &StateObject{address: address, data: data, storage: make(Storage), onDirty: onDirty}
}
// EncodeRLP implements rlp.Encoder.
func (c *StateObject) EncodeRLP(w io.Writer) error {
return rlp.Encode(w, c.data)
}
// setError remembers the first non-nil error it is called with.
func (self *StateObject) setError(err error) {
if self.dbErr == nil {
self.dbErr = err
} }
object.trie, _ = trie.NewSecure(common.Hash{}, db)
return object
} }
func (self *StateObject) MarkForDeletion() { func (self *StateObject) MarkForDeletion() {
self.remove = true self.remove = true
self.dirty = true if self.onDirty != nil {
self.onDirty(self.Address())
self.onDirty = nil
}
if glog.V(logger.Core) { if glog.V(logger.Core) {
glog.Infof("%x: #%d %v X\n", self.Address(), self.nonce, self.balance) glog.Infof("%x: #%d %v X\n", self.Address(), self.Nonce(), self.Balance())
} }
} }
func (c *StateObject) getAddr(addr common.Hash) common.Hash { func (c *StateObject) getTrie(db trie.Database) *trie.SecureTrie {
var ret []byte if c.trie == nil {
rlp.DecodeBytes(c.trie.Get(addr[:]), &ret) var err error
return common.BytesToHash(ret) c.trie, err = trie.NewSecure(c.data.Root, db)
}
func (c *StateObject) setAddr(addr, value common.Hash) {
v, err := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00"))
if err != nil { if err != nil {
// if RLPing failed we better panic and not fail silently. This would be considered a consensus issue c.trie, _ = trie.NewSecure(common.Hash{}, db)
panic(err) c.setError(fmt.Errorf("can't create storage trie: %v", err))
} }
c.trie.Update(addr[:], v) }
return c.trie
} }
func (self *StateObject) Storage() Storage { // GetState returns a value in account storage.
return self.storage func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash {
}
func (self *StateObject) GetState(key common.Hash) common.Hash {
value, exists := self.storage[key] value, exists := self.storage[key]
if !exists { if exists {
value = self.getAddr(key) return value
}
// Load from DB in case it is missing.
tr := self.getTrie(db)
var ret []byte
rlp.DecodeBytes(tr.Get(key[:]), &ret)
value = common.BytesToHash(ret)
if (value != common.Hash{}) { if (value != common.Hash{}) {
self.storage[key] = value self.storage[key] = value
} }
}
return value return value
} }
// SetState updates a value in account storage.
func (self *StateObject) SetState(key, value common.Hash) { func (self *StateObject) SetState(key, value common.Hash) {
self.storage[key] = value self.storage[key] = value
self.dirty = true if self.onDirty != nil {
self.onDirty(self.Address())
self.onDirty = nil
}
} }
// Update updates the current cached storage to the trie // updateTrie writes cached storage modifications into the object's storage trie.
func (self *StateObject) Update() { func (self *StateObject) updateTrie(db trie.Database) {
tr := self.getTrie(db)
for key, value := range self.storage { for key, value := range self.storage {
if (value == common.Hash{}) { if (value == common.Hash{}) {
self.trie.Delete(key[:]) tr.Delete(key[:])
continue continue
} }
self.setAddr(key, value) // Encoding []byte cannot fail, ok to ignore the error.
v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00"))
tr.Update(key[:], v)
} }
} }
// UpdateRoot sets the trie root to the current root hash of
func (self *StateObject) UpdateRoot(db trie.Database) {
self.updateTrie(db)
self.data.Root = self.trie.Hash()
}
// CommitTrie the storage trie of the object to dwb.
// This updates the trie root.
func (self *StateObject) CommitTrie(db trie.Database, dbw trie.DatabaseWriter) error {
self.updateTrie(db)
if self.dbErr != nil {
fmt.Println("dbErr:", self.dbErr)
return self.dbErr
}
root, err := self.trie.CommitTo(dbw)
if err == nil {
self.data.Root = root
}
return err
}
func (c *StateObject) AddBalance(amount *big.Int) { func (c *StateObject) AddBalance(amount *big.Int) {
if amount.Cmp(common.Big0) == 0 { if amount.Cmp(common.Big0) == 0 {
return return
} }
c.SetBalance(new(big.Int).Add(c.balance, amount)) c.SetBalance(new(big.Int).Add(c.Balance(), amount))
if glog.V(logger.Core) { if glog.V(logger.Core) {
glog.Infof("%x: #%d %v (+ %v)\n", c.Address(), c.nonce, c.balance, amount) glog.Infof("%x: #%d %v (+ %v)\n", c.Address(), c.Nonce(), c.Balance(), amount)
} }
} }
@ -166,37 +221,32 @@ func (c *StateObject) SubBalance(amount *big.Int) {
if amount.Cmp(common.Big0) == 0 { if amount.Cmp(common.Big0) == 0 {
return return
} }
c.SetBalance(new(big.Int).Sub(c.balance, amount)) c.SetBalance(new(big.Int).Sub(c.Balance(), amount))
if glog.V(logger.Core) { if glog.V(logger.Core) {
glog.Infof("%x: #%d %v (- %v)\n", c.Address(), c.nonce, c.balance, amount) glog.Infof("%x: #%d %v (- %v)\n", c.Address(), c.Nonce(), c.Balance(), amount)
} }
} }
func (c *StateObject) SetBalance(amount *big.Int) { func (self *StateObject) SetBalance(amount *big.Int) {
c.balance = amount self.data.Balance = amount
c.dirty = true if self.onDirty != nil {
self.onDirty(self.Address())
self.onDirty = nil
} }
func (c *StateObject) St() Storage {
return c.storage
} }
// Return the gas back to the origin. Used by the Virtual machine or Closures // Return the gas back to the origin. Used by the Virtual machine or Closures
func (c *StateObject) ReturnGas(gas, price *big.Int) {} func (c *StateObject) ReturnGas(gas, price *big.Int) {}
func (self *StateObject) Copy() *StateObject { func (self *StateObject) Copy(db trie.Database, onDirty func(addr common.Address)) *StateObject {
stateObject := NewStateObject(self.Address(), self.db) stateObject := NewObject(self.address, self.data, onDirty)
stateObject.balance.Set(self.balance)
stateObject.codeHash = common.CopyBytes(self.codeHash)
stateObject.nonce = self.nonce
stateObject.trie = self.trie stateObject.trie = self.trie
stateObject.code = self.code stateObject.code = self.code
stateObject.storage = self.storage.Copy() stateObject.storage = self.storage.Copy()
stateObject.remove = self.remove stateObject.remove = self.remove
stateObject.dirty = self.dirty stateObject.dirtyCode = self.dirtyCode
stateObject.deleted = self.deleted stateObject.deleted = self.deleted
return stateObject return stateObject
} }
@ -204,40 +254,66 @@ func (self *StateObject) Copy() *StateObject {
// Attribute accessors // Attribute accessors
// //
func (self *StateObject) Balance() *big.Int {
return self.balance
}
// Returns the address of the contract/account // Returns the address of the contract/account
func (c *StateObject) Address() common.Address { func (c *StateObject) Address() common.Address {
return c.address return c.address
} }
func (self *StateObject) Trie() *trie.SecureTrie { // Code returns the contract code associated with this object, if any.
return self.trie func (self *StateObject) Code(db trie.Database) []byte {
} if self.code != nil {
func (self *StateObject) Root() []byte {
return self.trie.Root()
}
func (self *StateObject) Code() []byte {
return self.code return self.code
} }
if bytes.Equal(self.CodeHash(), emptyCodeHash) {
return nil
}
code, err := db.Get(self.CodeHash())
if err != nil {
self.setError(fmt.Errorf("can't load code hash %x: %v", self.CodeHash(), err))
}
self.code = code
return code
}
// CodeSize returns the size of the contract code associated with this object.
func (self *StateObject) CodeSize(db trie.Database) int {
if self.data.codeSize == nil {
self.data.codeSize = new(int)
*self.data.codeSize = len(self.Code(db))
}
return *self.data.codeSize
}
func (self *StateObject) SetCode(code []byte) { func (self *StateObject) SetCode(code []byte) {
self.code = code self.code = code
self.codeHash = crypto.Keccak256(code) self.data.CodeHash = crypto.Keccak256(code)
self.dirty = true self.data.codeSize = new(int)
*self.data.codeSize = len(code)
self.dirtyCode = true
if self.onDirty != nil {
self.onDirty(self.Address())
self.onDirty = nil
}
} }
func (self *StateObject) SetNonce(nonce uint64) { func (self *StateObject) SetNonce(nonce uint64) {
self.nonce = nonce self.data.Nonce = nonce
self.dirty = true if self.onDirty != nil {
self.onDirty(self.Address())
self.onDirty = nil
}
}
func (self *StateObject) CodeHash() []byte {
return self.data.CodeHash
}
func (self *StateObject) Balance() *big.Int {
return self.data.Balance
} }
func (self *StateObject) Nonce() uint64 { func (self *StateObject) Nonce() uint64 {
return self.nonce return self.data.Nonce
} }
// Never called, but must be present to allow StateObject to be used // Never called, but must be present to allow StateObject to be used
@ -262,39 +338,3 @@ func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) {
} }
} }
} }
type extStateObject struct {
Nonce uint64
Balance *big.Int
Root common.Hash
CodeHash []byte
}
// EncodeRLP implements rlp.Encoder.
func (c *StateObject) EncodeRLP(w io.Writer) error {
return rlp.Encode(w, []interface{}{c.nonce, c.balance, c.Root(), c.codeHash})
}
// DecodeObject decodes an RLP-encoded state object.
func DecodeObject(address common.Address, db trie.Database, data []byte) (*StateObject, error) {
var (
obj = &StateObject{address: address, db: db, storage: make(Storage)}
ext extStateObject
err error
)
if err = rlp.DecodeBytes(data, &ext); err != nil {
return nil, err
}
if obj.trie, err = trie.NewSecure(ext.Root, db); err != nil {
return nil, err
}
if !bytes.Equal(ext.CodeHash, emptyCodeHash) {
if obj.code, err = db.Get(ext.CodeHash); err != nil {
return nil, fmt.Errorf("can't get code for hash %x: %v", ext.CodeHash, err)
}
}
obj.nonce = ext.Nonce
obj.balance = ext.Balance
obj.codeHash = ext.CodeHash
return obj, nil
}

View File

@ -146,23 +146,23 @@ func TestSnapshot2(t *testing.T) {
// db, trie are already non-empty values // db, trie are already non-empty values
so0 := state.GetStateObject(stateobjaddr0) so0 := state.GetStateObject(stateobjaddr0)
so0.balance = big.NewInt(42) so0.SetBalance(big.NewInt(42))
so0.nonce = 43 so0.SetNonce(43)
so0.SetCode([]byte{'c', 'a', 'f', 'e'}) so0.SetCode([]byte{'c', 'a', 'f', 'e'})
so0.remove = false so0.remove = false
so0.deleted = false so0.deleted = false
so0.dirty = true
state.SetStateObject(so0) state.SetStateObject(so0)
state.Commit()
root, _ := state.Commit()
state.Reset(root)
// and one with deleted == true // and one with deleted == true
so1 := state.GetStateObject(stateobjaddr1) so1 := state.GetStateObject(stateobjaddr1)
so1.balance = big.NewInt(52) so1.SetBalance(big.NewInt(52))
so1.nonce = 53 so1.SetNonce(53)
so1.SetCode([]byte{'c', 'a', 'f', 'e', '2'}) so1.SetCode([]byte{'c', 'a', 'f', 'e', '2'})
so1.remove = true so1.remove = true
so1.deleted = true so1.deleted = true
so1.dirty = true
state.SetStateObject(so1) state.SetStateObject(so1)
so1 = state.GetStateObject(stateobjaddr1) so1 = state.GetStateObject(stateobjaddr1)
@ -174,41 +174,50 @@ func TestSnapshot2(t *testing.T) {
state.Set(snapshot) state.Set(snapshot)
so0Restored := state.GetStateObject(stateobjaddr0) so0Restored := state.GetStateObject(stateobjaddr0)
so0Restored.GetState(storageaddr) // Update lazily-loaded values before comparing.
so1Restored := state.GetStateObject(stateobjaddr1) so0Restored.GetState(db, storageaddr)
so0Restored.Code(db)
// non-deleted is equal (restored) // non-deleted is equal (restored)
compareStateObjects(so0Restored, so0, t) compareStateObjects(so0Restored, so0, t)
// deleted should be nil, both before and after restore of state copy // deleted should be nil, both before and after restore of state copy
so1Restored := state.GetStateObject(stateobjaddr1)
if so1Restored != nil { if so1Restored != nil {
t.Fatalf("deleted object not nil after restoring snapshot") t.Fatalf("deleted object not nil after restoring snapshot: %+v", so1Restored)
} }
} }
func compareStateObjects(so0, so1 *StateObject, t *testing.T) { func compareStateObjects(so0, so1 *StateObject, t *testing.T) {
if so0.address != so1.address { if so0.Address() != so1.Address() {
t.Fatalf("Address mismatch: have %v, want %v", so0.address, so1.address) t.Fatalf("Address mismatch: have %v, want %v", so0.address, so1.address)
} }
if so0.balance.Cmp(so1.balance) != 0 { if so0.Balance().Cmp(so1.Balance()) != 0 {
t.Fatalf("Balance mismatch: have %v, want %v", so0.balance, so1.balance) t.Fatalf("Balance mismatch: have %v, want %v", so0.Balance(), so1.Balance())
} }
if so0.nonce != so1.nonce { if so0.Nonce() != so1.Nonce() {
t.Fatalf("Nonce mismatch: have %v, want %v", so0.nonce, so1.nonce) t.Fatalf("Nonce mismatch: have %v, want %v", so0.Nonce(), so1.Nonce())
} }
if !bytes.Equal(so0.codeHash, so1.codeHash) { if so0.data.Root != so1.data.Root {
t.Fatalf("CodeHash mismatch: have %v, want %v", so0.codeHash, so1.codeHash) t.Errorf("Root mismatch: have %x, want %x", so0.data.Root[:], so1.data.Root[:])
}
if !bytes.Equal(so0.CodeHash(), so1.CodeHash()) {
t.Fatalf("CodeHash mismatch: have %v, want %v", so0.CodeHash(), so1.CodeHash())
} }
if !bytes.Equal(so0.code, so1.code) { if !bytes.Equal(so0.code, so1.code) {
t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code) t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code)
} }
if len(so1.storage) != len(so0.storage) {
t.Errorf("Storage size mismatch: have %d, want %d", len(so1.storage), len(so0.storage))
}
for k, v := range so1.storage { for k, v := range so1.storage {
if so0.storage[k] != v { if so0.storage[k] != v {
t.Fatalf("Storage key %s mismatch: have %v, want %v", k, so0.storage[k], v) t.Errorf("Storage key %x mismatch: have %v, want %v", k, so0.storage[k], v)
} }
} }
for k, v := range so0.storage { for k, v := range so0.storage {
if so1.storage[k] != v { if so1.storage[k] != v {
t.Fatalf("Storage key %s mismatch: have %v, want none.", k, v) t.Errorf("Storage key %x mismatch: have %v, want none.", k, v)
} }
} }
@ -218,7 +227,4 @@ func compareStateObjects(so0, so1 *StateObject, t *testing.T) {
if so0.deleted != so1.deleted { if so0.deleted != so1.deleted {
t.Fatalf("Deleted mismatch: have %v, want %v", so0.deleted, so1.deleted) t.Fatalf("Deleted mismatch: have %v, want %v", so0.deleted, so1.deleted)
} }
if so0.dirty != so1.dirty {
t.Fatalf("Dirty mismatch: have %v, want %v", so0.dirty, so1.dirty)
}
} }

View File

@ -43,8 +43,14 @@ type StateDB struct {
db ethdb.Database db ethdb.Database
trie *trie.SecureTrie trie *trie.SecureTrie
stateObjects map[string]*StateObject // This map caches canon state accounts.
all map[common.Address]Account
// This map holds 'live' objects, which will get modified while processing a state transition.
stateObjects map[common.Address]*StateObject
stateObjectsDirty map[common.Address]struct{}
// The refund counter, also used by state transitioning.
refund *big.Int refund *big.Int
thash, bhash common.Hash thash, bhash common.Hash
@ -62,7 +68,9 @@ func New(root common.Hash, db ethdb.Database) (*StateDB, error) {
return &StateDB{ return &StateDB{
db: db, db: db,
trie: tr, trie: tr,
stateObjects: make(map[string]*StateObject), all: make(map[common.Address]Account),
stateObjects: make(map[common.Address]*StateObject),
stateObjectsDirty: make(map[common.Address]struct{}),
refund: new(big.Int), refund: new(big.Int),
logs: make(map[common.Hash]vm.Logs), logs: make(map[common.Hash]vm.Logs),
}, nil }, nil
@ -71,19 +79,21 @@ func New(root common.Hash, db ethdb.Database) (*StateDB, error) {
// Reset clears out all emphemeral state objects from the state db, but keeps // Reset clears out all emphemeral state objects from the state db, but keeps
// the underlying state trie to avoid reloading data for the next operations. // the underlying state trie to avoid reloading data for the next operations.
func (self *StateDB) Reset(root common.Hash) error { func (self *StateDB) Reset(root common.Hash) error {
var ( tr, err := trie.NewSecure(root, self.db)
err error if err != nil {
tr = self.trie
)
if self.trie.Hash() != root {
if tr, err = trie.NewSecure(root, self.db); err != nil {
return err return err
} }
all := self.all
if self.trie.Hash() != root {
// The root has changed, invalidate canon state.
all = make(map[common.Address]Account)
} }
*self = StateDB{ *self = StateDB{
db: self.db, db: self.db,
trie: tr, trie: tr,
stateObjects: make(map[string]*StateObject), all: all,
stateObjects: make(map[common.Address]*StateObject),
stateObjectsDirty: make(map[common.Address]struct{}),
refund: new(big.Int), refund: new(big.Int),
logs: make(map[common.Hash]vm.Logs), logs: make(map[common.Hash]vm.Logs),
} }
@ -137,7 +147,7 @@ func (self *StateDB) GetAccount(addr common.Address) vm.Account {
func (self *StateDB) GetBalance(addr common.Address) *big.Int { func (self *StateDB) GetBalance(addr common.Address) *big.Int {
stateObject := self.GetStateObject(addr) stateObject := self.GetStateObject(addr)
if stateObject != nil { if stateObject != nil {
return stateObject.balance return stateObject.Balance()
} }
return common.Big0 return common.Big0
@ -146,7 +156,7 @@ func (self *StateDB) GetBalance(addr common.Address) *big.Int {
func (self *StateDB) GetNonce(addr common.Address) uint64 { func (self *StateDB) GetNonce(addr common.Address) uint64 {
stateObject := self.GetStateObject(addr) stateObject := self.GetStateObject(addr)
if stateObject != nil { if stateObject != nil {
return stateObject.nonce return stateObject.Nonce()
} }
return StartingNonce return StartingNonce
@ -155,18 +165,24 @@ func (self *StateDB) GetNonce(addr common.Address) uint64 {
func (self *StateDB) GetCode(addr common.Address) []byte { func (self *StateDB) GetCode(addr common.Address) []byte {
stateObject := self.GetStateObject(addr) stateObject := self.GetStateObject(addr)
if stateObject != nil { if stateObject != nil {
return stateObject.code return stateObject.Code(self.db)
}
return nil
} }
return nil func (self *StateDB) GetCodeSize(addr common.Address) int {
stateObject := self.GetStateObject(addr)
if stateObject != nil {
return stateObject.CodeSize(self.db)
}
return 0
} }
func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash { func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash {
stateObject := self.GetStateObject(a) stateObject := self.GetStateObject(a)
if stateObject != nil { if stateObject != nil {
return stateObject.GetState(b) return stateObject.GetState(self.db, b)
} }
return common.Hash{} return common.Hash{}
} }
@ -214,8 +230,7 @@ func (self *StateDB) Delete(addr common.Address) bool {
stateObject := self.GetStateObject(addr) stateObject := self.GetStateObject(addr)
if stateObject != nil { if stateObject != nil {
stateObject.MarkForDeletion() stateObject.MarkForDeletion()
stateObject.balance = new(big.Int) stateObject.data.Balance = new(big.Int)
return true return true
} }
@ -242,35 +257,47 @@ func (self *StateDB) DeleteStateObject(stateObject *StateObject) {
addr := stateObject.Address() addr := stateObject.Address()
self.trie.Delete(addr[:]) self.trie.Delete(addr[:])
//delete(self.stateObjects, addr.Str())
} }
// Retrieve a state object given my the address. Nil if not found // Retrieve a state object given my the address. Returns nil if not found.
func (self *StateDB) GetStateObject(addr common.Address) (stateObject *StateObject) { func (self *StateDB) GetStateObject(addr common.Address) (stateObject *StateObject) {
stateObject = self.stateObjects[addr.Str()] // Prefer 'live' objects.
if stateObject != nil { if obj := self.stateObjects[addr]; obj != nil {
if stateObject.deleted { if obj.deleted {
stateObject = nil
}
return stateObject
}
data := self.trie.Get(addr[:])
if len(data) == 0 {
return nil return nil
} }
stateObject, err := DecodeObject(addr, self.db, data) return obj
if err != nil { }
// Use cached account data from the canon state if possible.
if data, ok := self.all[addr]; ok {
obj := NewObject(addr, data, self.MarkStateObjectDirty)
self.SetStateObject(obj)
return obj
}
// Load the object from the database.
enc := self.trie.Get(addr[:])
if len(enc) == 0 {
return nil
}
var data Account
if err := rlp.DecodeBytes(enc, &data); err != nil {
glog.Errorf("can't decode object at %x: %v", addr[:], err) glog.Errorf("can't decode object at %x: %v", addr[:], err)
return nil return nil
} }
self.SetStateObject(stateObject) // Update the all cache. Content in DB always corresponds
return stateObject // to the current head state so this is ok to do here.
// The object we just loaded has no storage trie and code yet.
self.all[addr] = data
// Insert into the live set.
obj := NewObject(addr, data, self.MarkStateObjectDirty)
self.SetStateObject(obj)
return obj
} }
func (self *StateDB) SetStateObject(object *StateObject) { func (self *StateDB) SetStateObject(object *StateObject) {
self.stateObjects[object.Address().Str()] = object self.stateObjects[object.Address()] = object
} }
// Retrieve a state object or create a new state object if nil // Retrieve a state object or create a new state object if nil
@ -288,15 +315,19 @@ func (self *StateDB) newStateObject(addr common.Address) *StateObject {
if glog.V(logger.Core) { if glog.V(logger.Core) {
glog.Infof("(+) %x\n", addr) glog.Infof("(+) %x\n", addr)
} }
obj := NewObject(addr, Account{}, self.MarkStateObjectDirty)
stateObject := NewStateObject(addr, self.db) obj.SetNonce(StartingNonce) // sets the object to dirty
stateObject.SetNonce(StartingNonce) self.stateObjects[addr] = obj
self.stateObjects[addr.Str()] = stateObject return obj
return stateObject
} }
// Creates creates a new state object and takes ownership. This is different from "NewStateObject" // MarkStateObjectDirty adds the specified object to the dirty map to avoid costly
// state object cache iteration to find a handful of modified ones.
func (self *StateDB) MarkStateObjectDirty(addr common.Address) {
self.stateObjectsDirty[addr] = struct{}{}
}
// Creates creates a new state object and takes ownership.
func (self *StateDB) CreateStateObject(addr common.Address) *StateObject { func (self *StateDB) CreateStateObject(addr common.Address) *StateObject {
// Get previous (if any) // Get previous (if any)
so := self.GetStateObject(addr) so := self.GetStateObject(addr)
@ -305,7 +336,7 @@ func (self *StateDB) CreateStateObject(addr common.Address) *StateObject {
// If it existed set the balance to the new account // If it existed set the balance to the new account
if so != nil { if so != nil {
newSo.balance = so.balance newSo.data.Balance = so.data.Balance
} }
return newSo return newSo
@ -320,29 +351,34 @@ func (self *StateDB) CreateAccount(addr common.Address) vm.Account {
// //
func (self *StateDB) Copy() *StateDB { func (self *StateDB) Copy() *StateDB {
// ignore error - we assume state-to-be-copied always exists // Copy all the basic fields, initialize the memory ones
state, _ := New(common.Hash{}, self.db) state := &StateDB{
state.trie = self.trie db: self.db,
for k, stateObject := range self.stateObjects { trie: self.trie,
if stateObject.dirty { all: self.all,
state.stateObjects[k] = stateObject.Copy() stateObjects: make(map[common.Address]*StateObject, len(self.stateObjectsDirty)),
stateObjectsDirty: make(map[common.Address]struct{}, len(self.stateObjectsDirty)),
refund: new(big.Int).Set(self.refund),
logs: make(map[common.Hash]vm.Logs, len(self.logs)),
logSize: self.logSize,
} }
// Copy the dirty states and logs
for addr, _ := range self.stateObjectsDirty {
state.stateObjects[addr] = self.stateObjects[addr].Copy(self.db, state.MarkStateObjectDirty)
state.stateObjectsDirty[addr] = struct{}{}
} }
state.refund.Set(self.refund)
for hash, logs := range self.logs { for hash, logs := range self.logs {
state.logs[hash] = make(vm.Logs, len(logs)) state.logs[hash] = make(vm.Logs, len(logs))
copy(state.logs[hash], logs) copy(state.logs[hash], logs)
} }
state.logSize = self.logSize
return state return state
} }
func (self *StateDB) Set(state *StateDB) { func (self *StateDB) Set(state *StateDB) {
self.trie = state.trie self.trie = state.trie
self.stateObjects = state.stateObjects self.stateObjects = state.stateObjects
self.stateObjectsDirty = state.stateObjectsDirty
self.all = state.all
self.refund = state.refund self.refund = state.refund
self.logs = state.logs self.logs = state.logs
@ -358,16 +394,15 @@ func (self *StateDB) GetRefund() *big.Int {
// goes into transaction receipts. // goes into transaction receipts.
func (s *StateDB) IntermediateRoot() common.Hash { func (s *StateDB) IntermediateRoot() common.Hash {
s.refund = new(big.Int) s.refund = new(big.Int)
for _, stateObject := range s.stateObjects { for addr, _ := range s.stateObjectsDirty {
if stateObject.dirty { stateObject := s.stateObjects[addr]
if stateObject.remove { if stateObject.remove {
s.DeleteStateObject(stateObject) s.DeleteStateObject(stateObject)
} else { } else {
stateObject.Update() stateObject.UpdateRoot(s.db)
s.UpdateStateObject(stateObject) s.UpdateStateObject(stateObject)
} }
} }
}
return s.trie.Hash() return s.trie.Hash()
} }
@ -380,15 +415,15 @@ func (s *StateDB) DeleteSuicides() {
// Reset refund so that any used-gas calculations can use // Reset refund so that any used-gas calculations can use
// this method. // this method.
s.refund = new(big.Int) s.refund = new(big.Int)
for _, stateObject := range s.stateObjects { for addr, _ := range s.stateObjectsDirty {
if stateObject.dirty { stateObject := s.stateObjects[addr]
// If the object has been removed by a suicide // If the object has been removed by a suicide
// flag the object as deleted. // flag the object as deleted.
if stateObject.remove { if stateObject.remove {
stateObject.deleted = true stateObject.deleted = true
} }
stateObject.dirty = false delete(s.stateObjectsDirty, addr)
}
} }
} }
@ -407,46 +442,44 @@ func (s *StateDB) CommitBatch() (root common.Hash, batch ethdb.Batch) {
return root, batch return root, batch
} }
func (s *StateDB) commit(db trie.DatabaseWriter) (common.Hash, error) { func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) {
s.refund = new(big.Int) s.refund = new(big.Int)
defer func() {
if err != nil {
// Committing failed, any updates to the canon state are invalid.
s.all = make(map[common.Address]Account)
}
}()
for _, stateObject := range s.stateObjects { // Commit objects to the trie.
for addr, stateObject := range s.stateObjects {
if stateObject.remove { if stateObject.remove {
// If the object has been removed, don't bother syncing it // If the object has been removed, don't bother syncing it
// and just mark it for deletion in the trie. // and just mark it for deletion in the trie.
s.DeleteStateObject(stateObject) s.DeleteStateObject(stateObject)
} else { delete(s.all, addr)
} else if _, ok := s.stateObjectsDirty[addr]; ok {
// Write any contract code associated with the state object // Write any contract code associated with the state object
if len(stateObject.code) > 0 { if stateObject.code != nil && stateObject.dirtyCode {
if err := db.Put(stateObject.codeHash, stateObject.code); err != nil { if err := dbw.Put(stateObject.CodeHash(), stateObject.code); err != nil {
return common.Hash{}, err return common.Hash{}, err
} }
stateObject.dirtyCode = false
} }
// Write any storage changes in the state object to its trie. // Write any storage changes in the state object to its storage trie.
stateObject.Update() if err := stateObject.CommitTrie(s.db, dbw); err != nil {
// Commit the trie of the object to the batch.
// This updates the trie root internally, so
// getting the root hash of the storage trie
// through UpdateStateObject is fast.
if _, err := stateObject.trie.CommitTo(db); err != nil {
return common.Hash{}, err return common.Hash{}, err
} }
// Update the object in the account trie. // Update the object in the main account trie.
s.UpdateStateObject(stateObject) s.UpdateStateObject(stateObject)
s.all[addr] = stateObject.data
} }
stateObject.dirty = false delete(s.stateObjectsDirty, addr)
} }
return s.trie.CommitTo(db) // Write trie changes.
return s.trie.CommitTo(dbw)
} }
func (self *StateDB) Refunds() *big.Int { func (self *StateDB) Refunds() *big.Int {
return self.refund return self.refund
} }
// Debug stuff
func (self *StateDB) CreateOutputForDiff() {
for _, stateObject := range self.stateObjects {
stateObject.CreateOutputForDiff()
}
}

View File

@ -94,6 +94,7 @@ type Database interface {
GetNonce(common.Address) uint64 GetNonce(common.Address) uint64
SetNonce(common.Address, uint64) SetNonce(common.Address, uint64)
GetCodeSize(common.Address) int
GetCode(common.Address) []byte GetCode(common.Address) []byte
SetCode(common.Address, []byte) SetCode(common.Address, []byte)

View File

@ -363,7 +363,7 @@ func opCalldataCopy(instr instruction, pc *uint64, env Environment, contract *Co
func opExtCodeSize(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *Stack) { func opExtCodeSize(instr instruction, pc *uint64, env Environment, contract *Contract, memory *Memory, stack *Stack) {
addr := common.BigToAddress(stack.pop()) addr := common.BigToAddress(stack.pop())
l := big.NewInt(int64(len(env.Db().GetCode(addr)))) l := big.NewInt(int64(env.Db().GetCodeSize(addr)))
stack.push(l) stack.push(l)
} }

View File

@ -288,14 +288,14 @@ func NewPublicDebugAPI(eth *Ethereum) *PublicDebugAPI {
} }
// DumpBlock retrieves the entire state of the database at a given block. // DumpBlock retrieves the entire state of the database at a given block.
func (api *PublicDebugAPI) DumpBlock(number uint64) (state.World, error) { func (api *PublicDebugAPI) DumpBlock(number uint64) (state.Dump, error) {
block := api.eth.BlockChain().GetBlockByNumber(number) block := api.eth.BlockChain().GetBlockByNumber(number)
if block == nil { if block == nil {
return state.World{}, fmt.Errorf("block #%d not found", number) return state.Dump{}, fmt.Errorf("block #%d not found", number)
} }
stateDb, err := state.New(block.Root(), api.eth.ChainDb()) stateDb, err := state.New(block.Root(), api.eth.ChainDb())
if err != nil { if err != nil {
return state.World{}, err return state.Dump{}, err
} }
return stateDb.RawDump(), nil return stateDb.RawDump(), nil
} }

View File

@ -1280,8 +1280,8 @@ func (api *PrivateDebugAPI) ChaindbProperty(property string) (string, error) {
} }
// SetHead rewinds the head of the blockchain to a previous block. // SetHead rewinds the head of the blockchain to a previous block.
func (api *PrivateDebugAPI) SetHead(number uint64) { func (api *PrivateDebugAPI) SetHead(number rpc.HexNumber) {
api.b.SetHead(number) api.b.SetHead(uint64(number.Int64()))
} }
// PublicNetAPI offers network related RPC methods // PublicNetAPI offers network related RPC methods

View File

@ -62,7 +62,7 @@ func makeTestState() (common.Hash, ethdb.Database) {
} }
so.AddBalance(big.NewInt(int64(i))) so.AddBalance(big.NewInt(int64(i)))
so.SetCode([]byte{i, i, i}) so.SetCode([]byte{i, i, i})
so.Update() so.UpdateRoot(sdb)
st.UpdateStateObject(so) st.UpdateStateObject(so)
} }
root, _ := st.Commit() root, _ := st.Commit()

View File

@ -97,7 +97,7 @@ func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *test
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, db)
for addr, account := range test.Pre { for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account) obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj) statedb.SetStateObject(obj)
for a, v := range account.Storage { for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v)) obj.SetState(common.HexToHash(a), common.HexToHash(v))
@ -136,7 +136,7 @@ func runStateTest(ruleSet RuleSet, test VmTest) error {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, db)
for addr, account := range test.Pre { for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account) obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj) statedb.SetStateObject(obj)
for a, v := range account.Storage { for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v)) obj.SetState(common.HexToHash(a), common.HexToHash(v))
@ -187,7 +187,7 @@ func runStateTest(ruleSet RuleSet, test VmTest) error {
} }
for addr, value := range account.Storage { for addr, value := range account.Storage {
v := obj.GetState(common.HexToHash(addr)) v := statedb.GetState(obj.Address(), common.HexToHash(addr))
vexp := common.HexToHash(value) vexp := common.HexToHash(value)
if v != vexp { if v != vexp {

View File

@ -103,16 +103,17 @@ func (self Log) Topics() [][]byte {
return t return t
} }
func StateObjectFromAccount(db ethdb.Database, addr string, account Account) *state.StateObject { func StateObjectFromAccount(db ethdb.Database, addr string, account Account, onDirty func(common.Address)) *state.StateObject {
obj := state.NewStateObject(common.HexToAddress(addr), db)
obj.SetBalance(common.Big(account.Balance))
if common.IsHex(account.Code) { if common.IsHex(account.Code) {
account.Code = account.Code[2:] account.Code = account.Code[2:]
} }
obj.SetCode(common.Hex2Bytes(account.Code)) code := common.Hex2Bytes(account.Code)
obj.SetNonce(common.Big(account.Nonce).Uint64()) obj := state.NewObject(common.HexToAddress(addr), state.Account{
Balance: common.Big(account.Balance),
CodeHash: crypto.Keccak256(code),
Nonce: common.Big(account.Nonce).Uint64(),
}, onDirty)
obj.SetCode(code)
return obj return obj
} }

View File

@ -103,7 +103,7 @@ func benchVmTest(test VmTest, env map[string]string, b *testing.B) {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, db)
for addr, account := range test.Pre { for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account) obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj) statedb.SetStateObject(obj)
for a, v := range account.Storage { for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v)) obj.SetState(common.HexToHash(a), common.HexToHash(v))
@ -154,7 +154,7 @@ func runVmTest(test VmTest) error {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, db)
for addr, account := range test.Pre { for addr, account := range test.Pre {
obj := StateObjectFromAccount(db, addr, account) obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty)
statedb.SetStateObject(obj) statedb.SetStateObject(obj)
for a, v := range account.Storage { for a, v := range account.Storage {
obj.SetState(common.HexToHash(a), common.HexToHash(v)) obj.SetState(common.HexToHash(a), common.HexToHash(v))
@ -205,11 +205,9 @@ func runVmTest(test VmTest) error {
if obj == nil { if obj == nil {
continue continue
} }
for addr, value := range account.Storage { for addr, value := range account.Storage {
v := obj.GetState(common.HexToHash(addr)) v := statedb.GetState(obj.Address(), common.HexToHash(addr))
vexp := common.HexToHash(value) vexp := common.HexToHash(value)
if v != vexp { if v != vexp {
return fmt.Errorf("(%x: %s) storage failed. Expected %x, got %x (%v %v)\n", obj.Address().Bytes()[0:4], addr, vexp, v, vexp.Big(), v.Big()) return fmt.Errorf("(%x: %s) storage failed. Expected %x, got %x (%v %v)\n", obj.Address().Bytes()[0:4], addr, vexp, v, vexp.Big(), v.Big())
} }