diff --git a/ethchain/state.go b/ethchain/state.go index 6ec6916f4..e209e0e2f 100644 --- a/ethchain/state.go +++ b/ethchain/state.go @@ -99,7 +99,21 @@ func (s *State) Cmp(other *State) bool { } func (s *State) Copy() *State { - return NewState(s.trie.Copy()) + state := NewState(s.trie.Copy()) + for k, subState := range s.states { + state.states[k] = subState.Copy() + } + + return state +} + +func (s *State) Snapshot() *State { + return s.Copy() +} + +func (s *State) Revert(snapshot *State) { + s.trie = snapshot.trie + s.states = snapshot.states } func (s *State) Put(key, object []byte) { diff --git a/ethchain/state_object.go b/ethchain/state_object.go index 4d615e2fe..3e9c6df40 100644 --- a/ethchain/state_object.go +++ b/ethchain/state_object.go @@ -81,12 +81,17 @@ func (c *StateObject) SetStorage(num *big.Int, val *ethutil.Value) { c.SetAddr(addr, val) } -func (c *StateObject) GetMem(num *big.Int) *ethutil.Value { +func (c *StateObject) GetStorage(num *big.Int) *ethutil.Value { nb := ethutil.BigToBytes(num, 256) return c.Addr(nb) } +/* DEPRECATED */ +func (c *StateObject) GetMem(num *big.Int) *ethutil.Value { + return c.GetStorage(num) +} + func (c *StateObject) GetInstr(pc *big.Int) *ethutil.Value { if int64(len(c.script)-1) < pc.Int64() { return ethutil.NewValue(0) diff --git a/ethchain/state_test.go b/ethchain/state_test.go new file mode 100644 index 000000000..4cc3fdf75 --- /dev/null +++ b/ethchain/state_test.go @@ -0,0 +1,31 @@ +package ethchain + +import ( + "fmt" + "github.com/ethereum/eth-go/ethdb" + "github.com/ethereum/eth-go/ethutil" + "testing" +) + +func TestSnapshot(t *testing.T) { + ethutil.ReadConfig("", ethutil.LogStd, "") + + db, _ := ethdb.NewMemDatabase() + state := NewState(ethutil.NewTrie(db, "")) + + stateObject := NewContract([]byte("aa"), ethutil.Big1, ZeroHash256) + state.UpdateStateObject(stateObject) + stateObject.SetStorage(ethutil.Big("0"), ethutil.NewValue(42)) + + snapshot := state.Snapshot() + + stateObject = state.GetStateObject([]byte("aa")) + stateObject.SetStorage(ethutil.Big("0"), ethutil.NewValue(43)) + + state.Revert(snapshot) + + stateObject = state.GetStateObject([]byte("aa")) + if !stateObject.GetStorage(ethutil.Big("0")).Cmp(ethutil.NewValue(42)) { + t.Error("Expected storage 0 to be 42") + } +} diff --git a/ethchain/vm.go b/ethchain/vm.go index e067a9c96..e025920f3 100644 --- a/ethchain/vm.go +++ b/ethchain/vm.go @@ -426,6 +426,10 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro value := stack.Pop() size, offset := stack.Popn() + // Snapshot the current stack so we are able to + // revert back to it later. + snapshot := vm.state.Snapshot() + // Generate a new address addr := ethutil.CreateAddress(closure.callee.Address(), closure.callee.N()) // Create a new contract @@ -448,6 +452,9 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro closure.Script, err = closure.Call(vm, nil, hook) if err != nil { stack.Push(ethutil.BigFalse) + + // Revert the state as it was before. + vm.state.Revert(snapshot) } else { stack.Push(ethutil.BigD(addr)) @@ -473,6 +480,8 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro // Get the arguments from the memory args := mem.Get(inOffset.Int64(), inSize.Int64()) + snapshot := vm.state.Snapshot() + // Fetch the contract which will serve as the closure body contract := vm.state.GetStateObject(addr.Bytes()) @@ -495,14 +504,14 @@ func (vm *Vm) RunClosure(closure *Closure, hook DebugHook) (ret []byte, err erro if err != nil { stack.Push(ethutil.BigFalse) // Reset the changes applied this object - //contract.State().Reset() + vm.state.Revert(snapshot) } else { stack.Push(ethutil.BigTrue) + + vm.state.UpdateStateObject(contract) + + mem.Set(retOffset.Int64(), retSize.Int64(), ret) } - - vm.state.UpdateStateObject(contract) - - mem.Set(retOffset.Int64(), retSize.Int64(), ret) } else { ethutil.Config.Log.Debugf("Contract %x not found\n", addr.Bytes()) stack.Push(ethutil.BigFalse) diff --git a/ethutil/trie_test.go b/ethutil/trie_test.go index 0be512d9f..d74d129ac 100644 --- a/ethutil/trie_test.go +++ b/ethutil/trie_test.go @@ -1,7 +1,7 @@ package ethutil import ( - _ "fmt" + "fmt" "reflect" "testing" ) @@ -26,7 +26,6 @@ func (db *MemDatabase) Delete(key []byte) error { delete(db.db, string(key)) return nil } -func (db *MemDatabase) GetKeys() []*Key { return nil } func (db *MemDatabase) Print() {} func (db *MemDatabase) Close() {} func (db *MemDatabase) LastKnownTD() []byte { return nil } @@ -171,3 +170,17 @@ func TestTrieIterator(t *testing.T) { t.Errorf("Expected cached nodes to be deleted") } } + +func TestHashes(t *testing.T) { + _, trie := New() + trie.Update("cat", "dog") + trie.Update("ca", "dude") + trie.Update("doge", "1234567890abcdefghijklmnopqrstuvwxxzABCEFGHIJKLMNOPQRSTUVWXYZ") + trie.Update("dog", "test") + trie.Update("test", "1234567890abcdefghijklmnopqrstuvwxxzABCEFGHIJKLMNOPQRSTUVWXYZ") + fmt.Printf("%x\n", trie.Root) + trie.Delete("dog") + fmt.Printf("%x\n", trie.Root) + trie.Delete("test") + fmt.Printf("%x\n", trie.Root) +}