State snapshotting

This commit is contained in:
obscuren 2014-05-26 00:09:38 +02:00
parent 81ef40010f
commit 3ebd7f1166
5 changed files with 81 additions and 9 deletions

View File

@ -99,7 +99,21 @@ func (s *State) Cmp(other *State) bool {
}
func (s *State) Copy() *State {
return NewState(s.trie.Copy())
state := NewState(s.trie.Copy())
for k, subState := range s.states {
state.states[k] = subState.Copy()
}
return state
}
func (s *State) Snapshot() *State {
return s.Copy()
}
func (s *State) Revert(snapshot *State) {
s.trie = snapshot.trie
s.states = snapshot.states
}
func (s *State) Put(key, object []byte) {

View File

@ -81,12 +81,17 @@ func (c *StateObject) SetStorage(num *big.Int, val *ethutil.Value) {
c.SetAddr(addr, val)
}
func (c *StateObject) GetMem(num *big.Int) *ethutil.Value {
func (c *StateObject) GetStorage(num *big.Int) *ethutil.Value {
nb := ethutil.BigToBytes(num, 256)
return c.Addr(nb)
}
/* DEPRECATED */
func (c *StateObject) GetMem(num *big.Int) *ethutil.Value {
return c.GetStorage(num)
}
func (c *StateObject) GetInstr(pc *big.Int) *ethutil.Value {
if int64(len(c.script)-1) < pc.Int64() {
return ethutil.NewValue(0)

31
ethchain/state_test.go Normal file
View File

@ -0,0 +1,31 @@
package ethchain
import (
"fmt"
"github.com/ethereum/eth-go/ethdb"
"github.com/ethereum/eth-go/ethutil"
"testing"
)
func TestSnapshot(t *testing.T) {
ethutil.ReadConfig("", ethutil.LogStd, "")
db, _ := ethdb.NewMemDatabase()
state := NewState(ethutil.NewTrie(db, ""))
stateObject := NewContract([]byte("aa"), ethutil.Big1, ZeroHash256)
state.UpdateStateObject(stateObject)
stateObject.SetStorage(ethutil.Big("0"), ethutil.NewValue(42))
snapshot := state.Snapshot()
stateObject = state.GetStateObject([]byte("aa"))
stateObject.SetStorage(ethutil.Big("0"), ethutil.NewValue(43))
state.Revert(snapshot)
stateObject = state.GetStateObject([]byte("aa"))
if !stateObject.GetStorage(ethutil.Big("0")).Cmp(ethutil.NewValue(42)) {
t.Error("Expected storage 0 to be 42")
}
}

View File

@ -426,6 +426,10 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro
value := stack.Pop()
size, offset := stack.Popn()
// Snapshot the current stack so we are able to
// revert back to it later.
snapshot := vm.state.Snapshot()
// Generate a new address
addr := ethutil.CreateAddress(closure.callee.Address(), closure.callee.N())
// Create a new contract
@ -448,6 +452,9 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro
closure.Script, err = closure.Call(vm, nil, hook)
if err != nil {
stack.Push(ethutil.BigFalse)
// Revert the state as it was before.
vm.state.Revert(snapshot)
} else {
stack.Push(ethutil.BigD(addr))
@ -473,6 +480,8 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro
// Get the arguments from the memory
args := mem.Get(inOffset.Int64(), inSize.Int64())
snapshot := vm.state.Snapshot()
// Fetch the contract which will serve as the closure body
contract := vm.state.GetStateObject(addr.Bytes())
@ -495,14 +504,14 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro
if err != nil {
stack.Push(ethutil.BigFalse)
// Reset the changes applied this object
//contract.State().Reset()
vm.state.Revert(snapshot)
} else {
stack.Push(ethutil.BigTrue)
}
vm.state.UpdateStateObject(contract)
mem.Set(retOffset.Int64(), retSize.Int64(), ret)
}
} else {
ethutil.Config.Log.Debugf("Contract %x not found\n", addr.Bytes())
stack.Push(ethutil.BigFalse)

View File

@ -1,7 +1,7 @@
package ethutil
import (
_ "fmt"
"fmt"
"reflect"
"testing"
)
@ -26,7 +26,6 @@ func (db *MemDatabase) Delete(key []byte) error {
delete(db.db, string(key))
return nil
}
func (db *MemDatabase) GetKeys() []*Key { return nil }
func (db *MemDatabase) Print() {}
func (db *MemDatabase) Close() {}
func (db *MemDatabase) LastKnownTD() []byte { return nil }
@ -171,3 +170,17 @@ func TestTrieIterator(t *testing.T) {
t.Errorf("Expected cached nodes to be deleted")
}
}
func TestHashes(t *testing.T) {
_, trie := New()
trie.Update("cat", "dog")
trie.Update("ca", "dude")
trie.Update("doge", "1234567890abcdefghijklmnopqrstuvwxxzABCEFGHIJKLMNOPQRSTUVWXYZ")
trie.Update("dog", "test")
trie.Update("test", "1234567890abcdefghijklmnopqrstuvwxxzABCEFGHIJKLMNOPQRSTUVWXYZ")
fmt.Printf("%x\n", trie.Root)
trie.Delete("dog")
fmt.Printf("%x\n", trie.Root)
trie.Delete("test")
fmt.Printf("%x\n", trie.Root)
}