diff --git a/chain/state/statetree.go b/chain/state/statetree.go index 656f9749e..812d626ad 100644 --- a/chain/state/statetree.go +++ b/chain/state/statetree.go @@ -26,16 +26,71 @@ type StateTree struct { root *hamt.Node Store cbor.IpldStore - // Maps ID addresses to actors. - actorcache map[address.Address]*types.Actor - snapshots []cid.Cid + snaps *stateSnaps +} + +type stateSnaps struct { + layers []map[address.Address]streeOp +} + +type streeOp struct { + Act types.Actor + Delete bool +} + +func newStateSnaps() *stateSnaps { + return &stateSnaps{ + layers: []map[address.Address]streeOp{make(map[address.Address]streeOp)}, + } +} + +func (ss *stateSnaps) addLayer() { + ss.layers = append(ss.layers, make(map[address.Address]streeOp)) +} + +func (ss *stateSnaps) dropLayer() { + ss.layers[len(ss.layers)-1] = nil // allow it to be GCed + ss.layers = ss.layers[:len(ss.layers)-1] +} + +func (ss *stateSnaps) mergeLastLayer() { + last := ss.layers[len(ss.layers)-1] + nextLast := ss.layers[len(ss.layers)-2] + + for k, v := range last { + nextLast[k] = v + } + + ss.dropLayer() +} + +func (ss *stateSnaps) getActor(addr address.Address) (*types.Actor, error) { + for i := len(ss.layers) - 1; i >= 0; i-- { + act, ok := ss.layers[i][addr] + if ok { + if act.Delete { + return nil, types.ErrActorNotFound + } + + return &act.Act, nil + } + } + return nil, nil +} + +func (ss *stateSnaps) setActor(addr address.Address, act *types.Actor) { + ss.layers[len(ss.layers)-1][addr] = streeOp{Act: *act} +} + +func (ss *stateSnaps) deleteActor(addr address.Address) { + ss.layers[len(ss.layers)-1][addr] = streeOp{Delete: true} } func NewStateTree(cst cbor.IpldStore) (*StateTree, error) { return &StateTree{ - root: hamt.NewNode(cst, hamt.UseTreeBitWidth(5)), - Store: cst, - actorcache: make(map[address.Address]*types.Actor), + root: hamt.NewNode(cst, hamt.UseTreeBitWidth(5)), + Store: cst, + snaps: newStateSnaps(), }, nil } @@ -47,9 +102,9 @@ func LoadStateTree(cst cbor.IpldStore, c cid.Cid) (*StateTree, error) { } return &StateTree{ - root: nd, - Store: cst, - actorcache: make(map[address.Address]*types.Actor), + root: nd, + Store: cst, + snaps: newStateSnaps(), }, nil } @@ -60,16 +115,8 @@ func (st *StateTree) SetActor(addr address.Address, act *types.Actor) error { } addr = iaddr - cact, ok := st.actorcache[addr] - if ok { - if act == cact { - return nil - } - } - - st.actorcache[addr] = act - - return st.root.Set(context.TODO(), string(addr.Bytes()), act) + st.snaps.setActor(addr, act) + return nil } // `LookupID` gets the ID address of this actor's `addr` stored in the `InitActor`. @@ -111,9 +158,13 @@ func (st *StateTree) GetActor(addr address.Address) (*types.Actor, error) { } addr = iaddr - cact, ok := st.actorcache[addr] - if ok { - return cact, nil + snapAct, err := st.snaps.getActor(addr) + if err != nil { + return nil, err + } + + if snapAct != nil { + return snapAct, nil } var act types.Actor @@ -125,7 +176,7 @@ func (st *StateTree) GetActor(addr address.Address) (*types.Actor, error) { return nil, xerrors.Errorf("hamt find failed: %w", err) } - st.actorcache[addr] = &act + st.snaps.setActor(addr, &act) return &act, nil } @@ -145,22 +196,32 @@ func (st *StateTree) DeleteActor(addr address.Address) error { addr = iaddr - delete(st.actorcache, addr) - - if err := st.root.Delete(context.TODO(), string(addr.Bytes())); err != nil { - return xerrors.Errorf("failed to delete actor: %w", err) + _, err = st.GetActor(addr) + if err != nil { + return err } + st.snaps.deleteActor(addr) + return nil } func (st *StateTree) Flush(ctx context.Context) (cid.Cid, error) { ctx, span := trace.StartSpan(ctx, "stateTree.Flush") defer span.End() + if len(st.snaps.layers) != 1 { + return cid.Undef, xerrors.Errorf("tried to flush state tree with snapshots on the stack") + } - for addr, act := range st.actorcache { - if err := st.root.Set(ctx, string(addr.Bytes()), act); err != nil { - return cid.Undef, err + for addr, sto := range st.snaps.layers[0] { + if sto.Delete { + if err := st.root.Delete(ctx, string(addr.Bytes())); err != nil { + return cid.Undef, err + } + } else { + if err := st.root.Set(ctx, string(addr.Bytes()), &sto.Act); err != nil { + return cid.Undef, err + } } } @@ -175,17 +236,13 @@ func (st *StateTree) Snapshot(ctx context.Context) error { ctx, span := trace.StartSpan(ctx, "stateTree.SnapShot") defer span.End() - ss, err := st.Flush(ctx) - if err != nil { - return err - } + st.snaps.addLayer() - st.snapshots = append(st.snapshots, ss) return nil } func (st *StateTree) ClearSnapshot() { - st.snapshots = st.snapshots[:len(st.snapshots)-1] + st.snaps.mergeLastLayer() } func (st *StateTree) RegisterNewAddress(addr address.Address) (address.Address, error) { @@ -226,14 +283,9 @@ func (a *AdtStore) Context() context.Context { var _ adt.Store = (*AdtStore)(nil) func (st *StateTree) Revert() error { - revTo := st.snapshots[len(st.snapshots)-1] - nd, err := hamt.LoadNode(context.Background(), st.Store, revTo, hamt.UseTreeBitWidth(5)) - if err != nil { - return err - } - st.actorcache = make(map[address.Address]*types.Actor) + st.snaps.dropLayer() + st.snaps.addLayer() - st.root = nd return nil } diff --git a/chain/state/statetree_test.go b/chain/state/statetree_test.go index a316cd063..83d11305a 100644 --- a/chain/state/statetree_test.go +++ b/chain/state/statetree_test.go @@ -143,8 +143,8 @@ func TestSetCache(t *testing.T) { t.Fatal(err) } - if outact.Nonce != act.Nonce { - t.Error("nonce didn't match") + if outact.Nonce == 1 { + t.Error("nonce should not have updated") } } @@ -206,6 +206,8 @@ func TestSnapshots(t *testing.T) { st.ClearSnapshot() } + st.ClearSnapshot() + if _, err := st.Flush(ctx); err != nil { t.Fatal(err) } diff --git a/chain/vm/invoker.go b/chain/vm/invoker.go index 2f3dd3472..a66eb5833 100644 --- a/chain/vm/invoker.go +++ b/chain/vm/invoker.go @@ -29,7 +29,6 @@ import ( "github.com/filecoin-project/specs-actors/actors/util/adt" "github.com/filecoin-project/lotus/chain/actors/aerrors" - "github.com/filecoin-project/lotus/chain/types" ) type invoker struct { @@ -37,7 +36,7 @@ type invoker struct { builtInState map[cid.Cid]reflect.Type } -type invokeFunc func(act *types.Actor, rt runtime.Runtime, params []byte) ([]byte, aerrors.ActorError) +type invokeFunc func(rt runtime.Runtime, params []byte) ([]byte, aerrors.ActorError) type nativeCode []invokeFunc func NewInvoker() *invoker { @@ -61,17 +60,17 @@ func NewInvoker() *invoker { return inv } -func (inv *invoker) Invoke(act *types.Actor, rt runtime.Runtime, method abi.MethodNum, params []byte) ([]byte, aerrors.ActorError) { +func (inv *invoker) Invoke(codeCid cid.Cid, rt runtime.Runtime, method abi.MethodNum, params []byte) ([]byte, aerrors.ActorError) { - code, ok := inv.builtInCode[act.Code] + code, ok := inv.builtInCode[codeCid] if !ok { - log.Errorf("no code for actor %s (Addr: %s)", act.Code, rt.Message().Receiver()) - return nil, aerrors.Newf(exitcode.SysErrorIllegalActor, "no code for actor %s(%d)(%s)", act.Code, method, hex.EncodeToString(params)) + log.Errorf("no code for actor %s (Addr: %s)", codeCid, rt.Message().Receiver()) + return nil, aerrors.Newf(exitcode.SysErrorIllegalActor, "no code for actor %s(%d)(%s)", codeCid, method, hex.EncodeToString(params)) } if method >= abi.MethodNum(len(code)) || code[method] == nil { return nil, aerrors.Newf(exitcode.SysErrInvalidMethod, "no method %d on actor", method) } - return code[method](act, rt, params) + return code[method](rt, params) } @@ -137,7 +136,7 @@ func (*invoker) transform(instance Invokee) (nativeCode, error) { paramT := meth.Type().In(1).Elem() param := reflect.New(paramT) - inBytes := in[2].Interface().([]byte) + inBytes := in[1].Interface().([]byte) if len(inBytes) > 0 { if err := DecodeParams(inBytes, param.Interface()); err != nil { aerr := aerrors.Absorb(err, 1, "failed to decode parameters") @@ -149,7 +148,7 @@ func (*invoker) transform(instance Invokee) (nativeCode, error) { } } } - rt := in[1].Interface().(*Runtime) + rt := in[0].Interface().(*Runtime) rval, aerror := rt.shimCall(func() interface{} { ret := meth.Call([]reflect.Value{ reflect.ValueOf(rt), diff --git a/chain/vm/invoker_test.go b/chain/vm/invoker_test.go index c94b7216f..b46b445a2 100644 --- a/chain/vm/invoker_test.go +++ b/chain/vm/invoker_test.go @@ -84,7 +84,7 @@ func TestInvokerBasic(t *testing.T) { bParam, err := actors.SerializeParams(&basicParams{B: 1}) assert.NoError(t, err) - _, aerr := code[0](nil, &Runtime{}, bParam) + _, aerr := code[0](&Runtime{}, bParam) assert.Equal(t, exitcode.ExitCode(1), aerrors.RetCode(aerr), "return code should be 1") if aerrors.IsFatal(aerr) { @@ -96,14 +96,14 @@ func TestInvokerBasic(t *testing.T) { bParam, err := actors.SerializeParams(&basicParams{B: 2}) assert.NoError(t, err) - _, aerr := code[10](nil, &Runtime{}, bParam) + _, aerr := code[10](&Runtime{}, bParam) assert.Equal(t, exitcode.ExitCode(12), aerrors.RetCode(aerr), "return code should be 12") if aerrors.IsFatal(aerr) { t.Fatal("err should not be fatal") } } - _, aerr := code[1](nil, &Runtime{}, []byte{99}) + _, aerr := code[1](&Runtime{}, []byte{99}) if aerrors.IsFatal(aerr) { t.Fatal("err should not be fatal") } diff --git a/chain/vm/vm.go b/chain/vm/vm.go index aa7a4d6da..5e4034d1e 100644 --- a/chain/vm/vm.go +++ b/chain/vm/vm.go @@ -599,7 +599,7 @@ func (vm *VM) Invoke(act *types.Actor, rt *Runtime, method abi.MethodNum, params defer func() { rt.ctx = oldCtx }() - ret, err := vm.inv.Invoke(act, rt, method, params) + ret, err := vm.inv.Invoke(act.Code, rt, method, params) if err != nil { return nil, err } @@ -611,13 +611,10 @@ func (vm *VM) SetInvoker(i *invoker) { } func (vm *VM) incrementNonce(addr address.Address) error { - a, err := vm.cstate.GetActor(addr) - if err != nil { - return xerrors.Errorf("nonce increment of sender failed") - } - - a.Nonce++ - return nil + return vm.cstate.MutateActor(addr, func(a *types.Actor) error { + a.Nonce++ + return nil + }) } func (vm *VM) transfer(from, to address.Address, amt types.BigInt) error { @@ -643,6 +640,15 @@ func (vm *VM) transfer(from, to address.Address, amt types.BigInt) error { return err } depositFunds(t, amt) + + if err := vm.cstate.SetActor(from, f); err != nil { + return err + } + + if err := vm.cstate.SetActor(to, t); err != nil { + return err + } + return nil } @@ -651,16 +657,13 @@ func (vm *VM) transferToGasHolder(addr address.Address, gasHolder *types.Actor, return xerrors.Errorf("attempted to transfer negative value to gas holder") } - a, err := vm.cstate.GetActor(addr) - if err != nil { - return xerrors.Errorf("transfer to gas holder failed when retrieving sender actor") - } - - if err := deductFunds(a, amt); err != nil { - return err - } - depositFunds(gasHolder, amt) - return nil + return vm.cstate.MutateActor(addr, func(a *types.Actor) error { + if err := deductFunds(a, amt); err != nil { + return err + } + depositFunds(gasHolder, amt) + return nil + }) } func (vm *VM) transferFromGasHolder(addr address.Address, gasHolder *types.Actor, amt types.BigInt) error { @@ -668,16 +671,13 @@ func (vm *VM) transferFromGasHolder(addr address.Address, gasHolder *types.Actor return xerrors.Errorf("attempted to transfer negative value from gas holder") } - a, err := vm.cstate.GetActor(addr) - if err != nil { - return xerrors.Errorf("transfer from gas holder failed when retrieving receiver actor") - } - - if err := deductFunds(gasHolder, amt); err != nil { - return err - } - depositFunds(a, amt) - return nil + return vm.cstate.MutateActor(addr, func(a *types.Actor) error { + if err := deductFunds(gasHolder, amt); err != nil { + return err + } + depositFunds(a, amt) + return nil + }) } func deductFunds(act *types.Actor, amt types.BigInt) error {