diff --git a/les/serverpool.go b/les/serverpool.go
index aff774324..9bfa0bd72 100644
--- a/les/serverpool.go
+++ b/les/serverpool.go
@@ -166,7 +166,7 @@ func newServerPool(db ethdb.KeyValueStore, dbKey []byte, vt *lpc.ValueTracker, d
 		if oldState.Equals(sfWaitDialTimeout) && newState.IsEmpty() {
 			// dial timeout, no connection
 			s.setRedialWait(n, dialCost, dialWaitStep)
-			s.ns.SetState(n, nodestate.Flags{}, sfDialing, 0)
+			s.ns.SetStateSub(n, nodestate.Flags{}, sfDialing, 0)
 		}
 	})
 
@@ -193,10 +193,10 @@ func (s *serverPool) addPreNegFilter(input enode.Iterator, query queryFunc) enod
 		if rand.Intn(maxQueryFails*2) < int(fails) {
 			// skip pre-negotiation with increasing chance, max 50%
 			// this ensures that the client can operate even if UDP is not working at all
-			s.ns.SetState(n, sfCanDial, nodestate.Flags{}, time.Second*10)
+			s.ns.SetStateSub(n, sfCanDial, nodestate.Flags{}, time.Second*10)
 			// set canDial before resetting queried so that FillSet will not read more
 			// candidates unnecessarily
-			s.ns.SetState(n, nodestate.Flags{}, sfQueried, 0)
+			s.ns.SetStateSub(n, nodestate.Flags{}, sfQueried, 0)
 			return
 		}
 		go func() {
@@ -206,12 +206,15 @@ func (s *serverPool) addPreNegFilter(input enode.Iterator, query queryFunc) enod
 			} else {
 				atomic.StoreUint32(&s.queryFails, 0)
 			}
-			if q == 1 {
-				s.ns.SetState(n, sfCanDial, nodestate.Flags{}, time.Second*10)
-			} else {
-				s.setRedialWait(n, queryCost, queryWaitStep)
-			}
-			s.ns.SetState(n, nodestate.Flags{}, sfQueried, 0)
+			s.ns.Operation(func() {
+				// we are no longer running in the operation that the callback belongs to, start a new one because of setRedialWait
+				if q == 1 {
+					s.ns.SetStateSub(n, sfCanDial, nodestate.Flags{}, time.Second*10)
+				} else {
+					s.setRedialWait(n, queryCost, queryWaitStep)
+				}
+				s.ns.SetStateSub(n, nodestate.Flags{}, sfQueried, 0)
+			})
 		}()
 	}
 })
@@ -240,18 +243,20 @@ func (s *serverPool) start() {
 		}
 	}
 	unixTime := s.unixTime()
-	s.ns.ForEach(sfHasValue, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) {
-		s.calculateWeight(node)
-		if n, ok := s.ns.GetField(node, sfiNodeHistory).(nodeHistory); ok && n.redialWaitEnd > unixTime {
-			wait := n.redialWaitEnd - unixTime
-			lastWait := n.redialWaitEnd - n.redialWaitStart
-			if wait > lastWait {
-				// if the time until expiration is larger than the last suggested
-				// waiting time then the system clock was probably adjusted
-				wait = lastWait
+	s.ns.Operation(func() {
+		s.ns.ForEach(sfHasValue, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) {
+			s.calculateWeight(node)
+			if n, ok := s.ns.GetField(node, sfiNodeHistory).(nodeHistory); ok && n.redialWaitEnd > unixTime {
+				wait := n.redialWaitEnd - unixTime
+				lastWait := n.redialWaitEnd - n.redialWaitStart
+				if wait > lastWait {
+					// if the time until expiration is larger than the last suggested
+					// waiting time then the system clock was probably adjusted
+					wait = lastWait
+				}
+				s.ns.SetStateSub(node, sfRedialWait, nodestate.Flags{}, time.Duration(wait)*time.Second)
 			}
-			s.ns.SetState(node, sfRedialWait, nodestate.Flags{}, time.Duration(wait)*time.Second)
-		}
+		})
 	})
 }
 
@@ -261,9 +266,11 @@ func (s *serverPool) stop() {
 	if s.fillSet != nil {
 		s.fillSet.Close()
 	}
-	s.ns.ForEach(sfConnected, nodestate.Flags{}, func(n *enode.Node, state nodestate.Flags) {
-		// recalculate weight of connected nodes in order to update hasValue flag if necessary
-		s.calculateWeight(n)
+	s.ns.Operation(func() {
+		s.ns.ForEach(sfConnected, nodestate.Flags{}, func(n *enode.Node, state nodestate.Flags) {
+			// recalculate weight of connected nodes in order to update hasValue flag if necessary
+			s.calculateWeight(n)
+		})
 	})
 	s.ns.Stop()
 }
@@ -279,9 +286,11 @@ func (s *serverPool) registerPeer(p *serverPeer) {
 
 // unregisterPeer implements serverPeerSubscriber
 func (s *serverPool) unregisterPeer(p *serverPeer) {
-	s.setRedialWait(p.Node(), dialCost, dialWaitStep)
-	s.ns.SetState(p.Node(), nodestate.Flags{}, sfConnected, 0)
-	s.ns.SetField(p.Node(), sfiConnectedStats, nil)
+	s.ns.Operation(func() {
+		s.setRedialWait(p.Node(), dialCost, dialWaitStep)
+		s.ns.SetStateSub(p.Node(), nodestate.Flags{}, sfConnected, 0)
+		s.ns.SetFieldSub(p.Node(), sfiConnectedStats, nil)
+	})
 	s.vt.Unregister(p.ID())
 	p.setValueTracker(nil, nil)
 }
@@ -380,14 +389,16 @@ func (s *serverPool) serviceValue(node *enode.Node) (sessionValue, totalValue fl
 
 // updateWeight calculates the node weight and updates the nodeWeight field and the
 // hasValue flag. It also saves the node state if necessary.
+// Note: this function should run inside a NodeStateMachine operation
 func (s *serverPool) updateWeight(node *enode.Node, totalValue float64, totalDialCost uint64) {
 	weight := uint64(totalValue * nodeWeightMul / float64(totalDialCost))
 	if weight >= nodeWeightThreshold {
-		s.ns.SetState(node, sfHasValue, nodestate.Flags{}, 0)
-		s.ns.SetField(node, sfiNodeWeight, weight)
+		s.ns.SetStateSub(node, sfHasValue, nodestate.Flags{}, 0)
+		s.ns.SetFieldSub(node, sfiNodeWeight, weight)
 	} else {
-		s.ns.SetState(node, nodestate.Flags{}, sfHasValue, 0)
-		s.ns.SetField(node, sfiNodeWeight, nil)
+		s.ns.SetStateSub(node, nodestate.Flags{}, sfHasValue, 0)
+		s.ns.SetFieldSub(node, sfiNodeWeight, nil)
+		s.ns.SetFieldSub(node, sfiNodeHistory, nil)
 	}
 	s.ns.Persist(node) // saved if node history or hasValue changed
 }
@@ -400,6 +411,7 @@ func (s *serverPool) updateWeight(node *enode.Node, totalValue float64, totalDia
 // a significant amount of service value again its waiting time is quickly reduced or reset
 // to the minimum.
 // Note: node weight is also recalculated and updated by this function.
+// Note 2: this function should run inside a NodeStateMachine operation
 func (s *serverPool) setRedialWait(node *enode.Node, addDialCost int64, waitStep float64) {
 	n, _ := s.ns.GetField(node, sfiNodeHistory).(nodeHistory)
 	sessionValue, totalValue := s.serviceValue(node)
@@ -450,21 +462,22 @@ func (s *serverPool) setRedialWait(node *enode.Node, addDialCost int64, waitStep
 	if wait < waitThreshold {
 		n.redialWaitStart = unixTime
 		n.redialWaitEnd = unixTime + int64(nextTimeout)
-		s.ns.SetField(node, sfiNodeHistory, n)
-		s.ns.SetState(node, sfRedialWait, nodestate.Flags{}, wait)
+		s.ns.SetFieldSub(node, sfiNodeHistory, n)
+		s.ns.SetStateSub(node, sfRedialWait, nodestate.Flags{}, wait)
 		s.updateWeight(node, totalValue, totalDialCost)
 	} else {
 		// discard known node statistics if waiting time is very long because the node
 		// hasn't been responsive for a very long time
-		s.ns.SetField(node, sfiNodeHistory, nil)
-		s.ns.SetField(node, sfiNodeWeight, nil)
-		s.ns.SetState(node, nodestate.Flags{}, sfHasValue, 0)
+		s.ns.SetFieldSub(node, sfiNodeHistory, nil)
+		s.ns.SetFieldSub(node, sfiNodeWeight, nil)
+		s.ns.SetStateSub(node, nodestate.Flags{}, sfHasValue, 0)
 	}
 }
 
 // calculateWeight calculates and sets the node weight without altering the node history.
 // This function should be called during startup and shutdown only, otherwise setRedialWait
 // will keep the weights updated as the underlying statistics are adjusted.
+// Note: this function should run inside a NodeStateMachine operation
 func (s *serverPool) calculateWeight(node *enode.Node) {
 	n, _ := s.ns.GetField(node, sfiNodeHistory).(nodeHistory)
 	_, totalValue := s.serviceValue(node)
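The serverpool changes above all follow one pattern: multi-step updates (a flag plus a field, or setRedialWait plus a flag reset) are wrapped in a single NodeStateMachine operation so they land atomically. A minimal sketch of that pattern in isolation, using the nodestate API from the hunks below; the sfConnected flag, sfiStats field and database key prefix here are hypothetical stand-ins, not the real server pool identifiers:

    package main

    import (
        "reflect"

        "github.com/ethereum/go-ethereum/common/mclock"
        "github.com/ethereum/go-ethereum/core/rawdb"
        "github.com/ethereum/go-ethereum/p2p/enode"
        "github.com/ethereum/go-ethereum/p2p/enr"
        "github.com/ethereum/go-ethereum/p2p/nodestate"
    )

    var (
        setup       = &nodestate.Setup{}
        sfConnected = setup.NewFlag("connected")                      // hypothetical flag
        sfiStats    = setup.NewField("stats", reflect.TypeOf(int(0))) // hypothetical field
    )

    func main() {
        ns := nodestate.NewNodeStateMachine(rawdb.NewMemoryDatabase(), []byte("ns:"), mclock.System{}, setup)
        ns.Start()
        defer ns.Stop()

        node := enode.SignNull(&enr.Record{}, enode.ID{1}) // test-style node, null identity scheme

        // Group the flag reset and the field update into one atomic operation,
        // mirroring the new unregisterPeer: other top-level calls block until
        // both changes and all resulting callbacks have been processed.
        ns.Operation(func() {
            ns.SetStateSub(node, nodestate.Flags{}, sfConnected, 0)
            ns.SetFieldSub(node, sfiStats, nil)
        })
    }

Only the Sub variants may be used inside the Operation callback; as far as the opStart logic below suggests, a top-level SetState/SetField there would wait for the very operation it is part of and deadlock.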
diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go
index 7091281ae..ab28b47a1 100644
--- a/p2p/nodestate/nodestate.go
+++ b/p2p/nodestate/nodestate.go
@@ -32,34 +32,46 @@ import (
 	"github.com/ethereum/go-ethereum/rlp"
 )
 
+var (
+	ErrInvalidField = errors.New("invalid field type")
+	ErrClosed       = errors.New("already closed")
+)
+
 type (
-	// NodeStateMachine connects different system components operating on subsets of
-	// network nodes. Node states are represented by 64 bit vectors with each bit assigned
-	// to a state flag. Each state flag has a descriptor structure and the mapping is
-	// created automatically. It is possible to subscribe to subsets of state flags and
-	// receive a callback if one of the nodes has a relevant state flag changed.
-	// Callbacks can also modify further flags of the same node or other nodes. State
-	// updates only return after all immediate effects throughout the system have happened
-	// (deadlocks should be avoided by design of the implemented state logic). The caller
-	// can also add timeouts assigned to a certain node and a subset of state flags.
-	// If the timeout elapses, the flags are reset. If all relevant flags are reset then
-	// the timer is dropped. State flags with no timeout are persisted in the database
-	// if the flag descriptor enables saving. If a node has no state flags set at any
-	// moment then it is discarded.
-	//
-	// Extra node fields can also be registered so system components can also store more
-	// complex state for each node that is relevant to them, without creating a custom
-	// peer set. Fields can be shared across multiple components if they all know the
-	// field ID. Subscription to fields is also possible. Persistent fields should have
-	// an encoder and a decoder function.
+	// NodeStateMachine implements a network node-related event subscription system.
+	// It can assign binary state flags and fields of arbitrary type to each node and allows
+	// subscriptions to flag/field changes which can also modify further flags and fields,
+	// potentially triggering further subscriptions. An operation includes an initial change
+	// and all resulting subsequent changes and always ends in a consistent global state.
+	// It is initiated by a "top level" SetState/SetField call that blocks (also blocking other
+	// top-level functions) until the operation is finished. Callbacks making further changes
+	// should use the non-blocking SetStateSub/SetFieldSub functions. The tree of events
+	// resulting from the initial changes is traversed in a breadth-first order, ensuring for
+	// each subscription callback that all other callbacks caused by the same change triggering
+	// the current callback are processed before anything is triggered by the changes made in the
+	// current callback. In practice this logic ensures that all subscriptions "see" events in
+	// the logical order, callbacks are never called concurrently and "back and forth" effects
+	// are also possible. The state machine design should ensure that infinite event cycles
+	// cannot happen.
+	// The caller can also add timeouts assigned to a certain node and a subset of state flags.
+	// If the timeout elapses, the flags are reset. If all relevant flags are reset then the timer
+	// is dropped. State flags with no timeout are persisted in the database if the flag
+	// descriptor enables saving. If a node has no state flags set at any moment then it is discarded.
+	// Note: in order to avoid mutex deadlocks the callbacks should never lock a mutex that
+	// might be locked when the top level SetState/SetField functions are called. If a function
+	// potentially performs state/field changes then it is recommended to mention this fact in the
+	// function description, along with whether it should run inside an operation callback.
 	NodeStateMachine struct {
-		started, stopped    bool
+		started, closed     bool
 		lock                sync.Mutex
 		clock               mclock.Clock
 		db                  ethdb.KeyValueStore
 		dbNodeKey           []byte
 		nodes               map[enode.ID]*nodeInfo
 		offlineCallbackList []offlineCallback
+		opFlag              bool       // an operation has started
+		opWait              *sync.Cond // signaled when the operation ends
+		opPending           []func()   // pending callback list of the current operation
 
 		// Registered state flags or fields. Modifications are allowed
 		// only when the node state machine has not been started.
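To make the operation model in the doc comment concrete, here is a hedged sketch of a single operation: a top-level SetState triggers a subscription, which cascades through the non-blocking Sub variant. The sfDialed/sfActive flags are illustrative only, and opSetup is assumed to be the Setup the machine was built from (TestCallbackOrder at the end of this diff tests the same mechanism):

    // Illustrative flags; opSetup must be the Setup passed to NewNodeStateMachine.
    var (
        opSetup  = &nodestate.Setup{}
        sfDialed = opSetup.NewFlag("dialed")
        sfActive = opSetup.NewFlag("active")
    )

    func cascadeSketch(ns *nodestate.NodeStateMachine, node *enode.Node) {
        // Installed before Start: react to sfDialed by setting sfActive.
        ns.SubscribeState(sfDialed, func(n *enode.Node, oldState, newState nodestate.Flags) {
            if newState.Equals(sfDialed) {
                // We are already inside an operation here, so the
                // non-blocking Sub variant is required.
                ns.SetStateSub(n, sfActive, nodestate.Flags{}, 0)
            }
        })
        ns.Start()

        // Top level: starts an operation and blocks until the whole callback
        // tree (the sfActive change included) has been processed.
        ns.SetState(node, sfDialed, nodestate.Flags{}, 0)
    }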
@@ -128,11 +140,12 @@ type (
 
 	// nodeInfo contains node state, fields and state timeouts
 	nodeInfo struct {
-		node      *enode.Node
-		state     bitMask
-		timeouts  []*nodeStateTimeout
-		fields    []interface{}
-		db, dirty bool
+		node       *enode.Node
+		state      bitMask
+		timeouts   []*nodeStateTimeout
+		fields     []interface{}
+		fieldCount int
+		db, dirty  bool
 	}
 
 	nodeInfoEnc struct {
@@ -158,7 +171,7 @@ type (
 	}
 
 	offlineCallback struct {
-		node   *enode.Node
+		node   *nodeInfo
 		state  bitMask
 		fields []interface{}
 	}
@@ -319,10 +332,11 @@ func NewNodeStateMachine(db ethdb.KeyValueStore, dbKey []byte, clock mclock.Cloc
 		nodes:     make(map[enode.ID]*nodeInfo),
 		fields:    make([]*fieldInfo, len(setup.fields)),
 	}
+	ns.opWait = sync.NewCond(&ns.lock)
 	stateNameMap := make(map[string]int)
 	for index, flag := range setup.flags {
 		if _, ok := stateNameMap[flag.name]; ok {
-			panic("Node state flag name collision")
+			panic("Node state flag name collision: " + flag.name)
 		}
 		stateNameMap[flag.name] = index
 		if flag.persistent {
@@ -332,7 +346,7 @@ func NewNodeStateMachine(db ethdb.KeyValueStore, dbKey []byte, clock mclock.Cloc
 	fieldNameMap := make(map[string]int)
 	for index, field := range setup.fields {
 		if _, ok := fieldNameMap[field.name]; ok {
-			panic("Node field name collision")
+			panic("Node field name collision: " + field.name)
 		}
 		ns.fields[index] = &fieldInfo{fieldDefinition: field}
 		fieldNameMap[field.name] = index
@@ -357,10 +371,12 @@ func (ns *NodeStateMachine) fieldIndex(field Field) int {
 }
 
 // SubscribeState adds a node state subscription. The callback is called while the state
-// machine mutex is not held and it is allowed to make further state updates. All immediate
-// changes throughout the system are processed in the same thread/goroutine. It is the
-// responsibility of the implemented state logic to avoid deadlocks caused by the callbacks,
-// infinite toggling of flags or hazardous/non-deterministic state changes.
+// machine mutex is not held and it is allowed to make further state updates using the
+// non-blocking SetStateSub/SetFieldSub functions. All callbacks of an operation are running
+// from the thread/goroutine of the initial caller and parallel operations are not permitted.
+// Therefore the callback is never called concurrently. It is the responsibility of the
+// implemented state logic to avoid deadlocks and to reach a stable state in a finite amount
+// of steps.
 // State subscriptions should be installed before loading the node database or making the
 // first state update.
 func (ns *NodeStateMachine) SubscribeState(flags Flags, callback StateCallback) {
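Field subscriptions follow the same callback rules as state subscriptions. A sketch reusing the previous snippet's setup, assuming the SubscribeField API that the field-subscription plumbing in this diff (f.subs and the FieldCallback signature) serves; sfiScore is a hypothetical field:

    // Hypothetical field registered on the same setup.
    var sfiScore = opSetup.NewField("score", reflect.TypeOf(uint64(0)))

    func fieldSubSketch(ns *nodestate.NodeStateMachine) {
        ns.SubscribeField(sfiScore, func(n *enode.Node, state nodestate.Flags, oldValue, newValue interface{}) {
            // newValue == nil means the field was just cleared.
            if v, ok := newValue.(uint64); ok && v > 100 {
                ns.SetStateSub(n, sfActive, nodestate.Flags{}, 0) // still inside the operation
            }
        })
    }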
@@ -408,26 +424,33 @@ func (ns *NodeStateMachine) Start() {
 	if ns.db != nil {
 		ns.loadFromDb()
 	}
-	ns.lock.Unlock()
+
+	ns.opStart()
 	ns.offlineCallbacks(true)
+	ns.opFinish()
+	ns.lock.Unlock()
 }
 
 // Stop stops the state machine and saves its state if a database was supplied
 func (ns *NodeStateMachine) Stop() {
 	ns.lock.Lock()
+	defer ns.lock.Unlock()
+
+	ns.checkStarted()
+	if !ns.opStart() {
+		panic("already closed")
+	}
 	for _, node := range ns.nodes {
 		fields := make([]interface{}, len(node.fields))
 		copy(fields, node.fields)
-		ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields})
+		ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node, node.state, fields})
 	}
-	ns.stopped = true
 	if ns.db != nil {
 		ns.saveToDb()
-		ns.lock.Unlock()
-	} else {
-		ns.lock.Unlock()
 	}
 	ns.offlineCallbacks(false)
+	ns.closed = true
+	ns.opFinish()
 }
 
 // loadFromDb loads persisted node states from the database
@@ -477,6 +500,7 @@ func (ns *NodeStateMachine) decodeNode(id enode.ID, data []byte) {
 		if decode := ns.fields[i].decode; decode != nil {
 			if field, err := decode(encField); err == nil {
 				node.fields[i] = field
+				node.fieldCount++
 			} else {
 				log.Error("Failed to decode node field", "id", id, "field name", ns.fields[i].name, "error", err)
 				return
@@ -491,7 +515,7 @@ func (ns *NodeStateMachine) decodeNode(id enode.ID, data []byte) {
 	node.state = enc.State
 	fields := make([]interface{}, len(node.fields))
 	copy(fields, node.fields)
-	ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields})
+	ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node, node.state, fields})
 	log.Debug("Loaded node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup})
 }
 
@@ -505,15 +529,6 @@ func (ns *NodeStateMachine) saveNode(id enode.ID, node *nodeInfo) error {
 	for _, t := range node.timeouts {
 		storedState &= ^t.mask
 	}
-	if storedState == 0 {
-		if node.db {
-			node.db = false
-			ns.deleteNode(id)
-		}
-		node.dirty = false
-		return nil
-	}
-
 	enc := nodeInfoEnc{
 		Enr:     *node.node.Record(),
 		Version: ns.setup.Version,
@@ -537,6 +552,14 @@ func (ns *NodeStateMachine) saveNode(id enode.ID, node *nodeInfo) error {
 		enc.Fields[i] = blob
 		lastIndex = i
 	}
+	if storedState == 0 && lastIndex == -1 {
+		if node.db {
+			node.db = false
+			ns.deleteNode(id)
+		}
+		node.dirty = false
+		return nil
+	}
 	enc.Fields = enc.Fields[:lastIndex+1]
 	data, err := rlp.EncodeToBytes(&enc)
 	if err != nil {
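With Stop now closing the machine through the operation mechanism, the error returns added in the hunks below become visible to callers: top-level mutators report ErrClosed instead of silently returning. A hedged fragment, reusing the earlier illustrative flags:

    func shutdownSketch(ns *nodestate.NodeStateMachine, node *enode.Node) {
        ns.Stop()
        // After Stop, a top-level SetState fails with ErrClosed rather than
        // silently doing nothing, so shutdown races become observable.
        if err := ns.SetState(node, sfDialed, nodestate.Flags{}, 0); err == nodestate.ErrClosed {
            // machine already stopped; drop the update
        }
    }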
@@ -596,23 +619,36 @@ func (ns *NodeStateMachine) Persist(n *enode.Node) error {
 	return nil
 }
 
-// SetState updates the given node state flags and processes all resulting callbacks.
-// It only returns after all subsequent immediate changes (including those changed by the
-// callbacks) have been processed. If a flag with a timeout is set again, the operation
-// removes or replaces the existing timeout.
-func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) {
+// SetState updates the given node state flags and blocks until the operation is finished.
+// If a flag with a timeout is set again, the operation removes or replaces the existing timeout.
+func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) error {
 	ns.lock.Lock()
-	ns.checkStarted()
-	if ns.stopped {
-		ns.lock.Unlock()
-		return
-	}
+	defer ns.lock.Unlock()
+
+	if !ns.opStart() {
+		return ErrClosed
+	}
+	ns.setState(n, setFlags, resetFlags, timeout)
+	ns.opFinish()
+	return nil
+}
+
+// SetStateSub updates the given node state flags without blocking (should be called
+// from a subscription/operation callback).
+func (ns *NodeStateMachine) SetStateSub(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) {
+	ns.lock.Lock()
+	defer ns.lock.Unlock()
+
+	ns.opCheck()
+	ns.setState(n, setFlags, resetFlags, timeout)
+}
+
+func (ns *NodeStateMachine) setState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) {
+	ns.checkStarted()
 	set, reset := ns.stateMask(setFlags), ns.stateMask(resetFlags)
 	id, node := ns.updateEnode(n)
 	if node == nil {
 		if set == 0 {
-			ns.lock.Unlock()
 			return
 		}
 		node = ns.newNode(n)
@@ -627,16 +663,14 @@ func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags,
 	// even they are not existent(it's noop).
 	ns.removeTimeouts(node, set|reset)
 
-	// Register the timeout callback if the new state is not empty
-	// and timeout itself is required.
-	if timeout != 0 && newState != 0 {
+	// Register the timeout callback if required
+	if timeout != 0 && set != 0 {
 		ns.addTimeout(n, set, timeout)
 	}
 	if newState == oldState {
-		ns.lock.Unlock()
 		return
 	}
-	if newState == 0 {
+	if newState == 0 && node.fieldCount == 0 {
 		delete(ns.nodes, id)
 		if node.db {
 			ns.deleteNode(id)
@@ -646,68 +680,118 @@ func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags,
 			node.dirty = true
 		}
 	}
-	ns.lock.Unlock()
-	// call state update subscription callbacks without holding the mutex
-	for _, sub := range ns.stateSubs {
-		if changed&sub.mask != 0 {
-			sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup})
-		}
-	}
-	if newState == 0 {
-		// call field subscriptions for discarded fields
-		for i, v := range node.fields {
-			if v != nil {
-				f := ns.fields[i]
-				if len(f.subs) > 0 {
-					for _, cb := range f.subs {
-						cb(n, Flags{setup: ns.setup}, v, nil)
-					}
-				}
+	callback := func() {
+		for _, sub := range ns.stateSubs {
+			if changed&sub.mask != 0 {
+				sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup})
 			}
 		}
 	}
+	ns.opPending = append(ns.opPending, callback)
 }
 
+// opCheck checks whether an operation is active
+func (ns *NodeStateMachine) opCheck() {
+	if !ns.opFlag {
+		panic("Operation has not started")
+	}
+}
+
+// opStart waits until other operations are finished and starts a new one
+func (ns *NodeStateMachine) opStart() bool {
+	for ns.opFlag {
+		ns.opWait.Wait()
+	}
+	if ns.closed {
+		return false
+	}
+	ns.opFlag = true
+	return true
+}
+
+// opFinish finishes the current operation by running all pending callbacks.
+// Callbacks resulting from a state/field change performed in a previous callback are always
+// put at the end of the pending list and therefore processed after all callbacks resulting
+// from the previous state/field change.
+func (ns *NodeStateMachine) opFinish() {
+	for len(ns.opPending) != 0 {
+		list := ns.opPending
+		ns.lock.Unlock()
+		for _, cb := range list {
+			cb()
+		}
+		ns.lock.Lock()
+		ns.opPending = ns.opPending[len(list):]
+	}
+	ns.opPending = nil
+	ns.opFlag = false
+	ns.opWait.Signal()
+}
+
+// Operation calls the given function as an operation callback. This allows the caller
+// to start an operation with multiple initial changes. The same rules apply as for
+// subscription callbacks.
+func (ns *NodeStateMachine) Operation(fn func()) error {
+	ns.lock.Lock()
+	started := ns.opStart()
+	ns.lock.Unlock()
+	if !started {
+		return ErrClosed
+	}
+	fn()
+	ns.lock.Lock()
+	ns.opFinish()
+	ns.lock.Unlock()
+	return nil
+}
 
 // offlineCallbacks calls state update callbacks at startup or shutdown
 func (ns *NodeStateMachine) offlineCallbacks(start bool) {
 	for _, cb := range ns.offlineCallbackList {
-		for _, sub := range ns.stateSubs {
-			offState := offlineState & sub.mask
-			onState := cb.state & sub.mask
-			if offState != onState {
+		cb := cb
+		callback := func() {
+			for _, sub := range ns.stateSubs {
+				offState := offlineState & sub.mask
+				onState := cb.state & sub.mask
+				if offState == onState {
+					continue
+				}
 				if start {
-					sub.callback(cb.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup})
+					sub.callback(cb.node.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup})
 				} else {
-					sub.callback(cb.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup})
+					sub.callback(cb.node.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup})
 				}
 			}
-		}
-		for i, f := range cb.fields {
-			if f != nil && ns.fields[i].subs != nil {
+			for i, f := range cb.fields {
+				if f == nil || ns.fields[i].subs == nil {
+					continue
+				}
 				for _, fsub := range ns.fields[i].subs {
 					if start {
-						fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, nil, f)
+						fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, nil, f)
 					} else {
-						fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, f, nil)
+						fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, f, nil)
 					}
 				}
 			}
 		}
+		ns.opPending = append(ns.opPending, callback)
 	}
 	ns.offlineCallbackList = nil
 }
 
 // AddTimeout adds a node state timeout associated to the given state flag(s).
 // After the specified time interval, the relevant states will be reset.
-func (ns *NodeStateMachine) AddTimeout(n *enode.Node, flags Flags, timeout time.Duration) {
+func (ns *NodeStateMachine) AddTimeout(n *enode.Node, flags Flags, timeout time.Duration) error {
 	ns.lock.Lock()
 	defer ns.lock.Unlock()
 
 	ns.checkStarted()
-	if ns.stopped {
-		return
+	if ns.closed {
+		return ErrClosed
 	}
 	ns.addTimeout(n, ns.stateMask(flags), timeout)
+	return nil
 }
 
 // addTimeout adds a node state timeout associated to the given state flag(s).
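opCheck is what turns misuse into a panic: calling SetStateSub/SetFieldSub with no active operation panics with "Operation has not started", so code running on its own goroutine has to open an operation explicitly, exactly as the pre-negotiation query callback in les/serverpool.go above now does. A sketch with the earlier illustrative flags:

    func goroutineSketch(ns *nodestate.NodeStateMachine, node *enode.Node) {
        go func() {
            // Calling ns.SetStateSub here directly would panic via opCheck,
            // because no operation is active on this goroutine. Open one first:
            ns.Operation(func() {
                ns.SetStateSub(node, sfActive, nodestate.Flags{}, 0)
                ns.SetStateSub(node, nodestate.Flags{}, sfDialed, 0)
            })
        }()
    }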
@@ -756,13 +840,15 @@ func (ns *NodeStateMachine) removeTimeouts(node *nodeInfo, mask bitMask) {
 	}
 }
 
-// GetField retrieves the given field of the given node
+// GetField retrieves the given field of the given node. Note that when used in a
+// subscription callback the result can be out of sync with the state change represented
+// by the callback parameters so extra safety checks might be necessary.
 func (ns *NodeStateMachine) GetField(n *enode.Node, field Field) interface{} {
 	ns.lock.Lock()
 	defer ns.lock.Unlock()
 
 	ns.checkStarted()
-	if ns.stopped {
+	if ns.closed {
 		return nil
 	}
 	if _, node := ns.updateEnode(n); node != nil {
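The new GetField caveat in practice: by the time a given subscription callback runs, later changes of the same operation may already have been applied, so re-reading state or fields can run ahead of the transition the callback describes. A hedged fragment; preferring the callback parameters is one reading of the note, not code from this diff:

    func callbackReadSketch(ns *nodestate.NodeStateMachine) {
        ns.SubscribeState(sfActive, func(n *enode.Node, oldState, newState nodestate.Flags) {
            // oldState/newState describe this transition reliably. A
            // ns.GetField(n, sfiScore) call here may already reflect a later
            // change queued by the same operation, so guard against nil or
            // unexpected values instead of assuming consistency.
            if newState.IsEmpty() {
                // node lost sfActive; react using the parameters only
            }
        })
    }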
@@ -771,48 +857,80 @@ func (ns *NodeStateMachine) GetField(n *enode.Node, field Field) interface{} {
 	return nil
 }
 
-// SetField sets the given field of the given node
+// SetField sets the given field of the given node and blocks until the operation is finished
 func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface{}) error {
 	ns.lock.Lock()
-	ns.checkStarted()
-	if ns.stopped {
-		ns.lock.Unlock()
-		return nil
+	defer ns.lock.Unlock()
+
+	if !ns.opStart() {
+		return ErrClosed
 	}
-	_, node := ns.updateEnode(n)
+	err := ns.setField(n, field, value)
+	ns.opFinish()
+	return err
+}
+
+// SetFieldSub sets the given field of the given node without blocking (should be called
+// from a subscription/operation callback).
+func (ns *NodeStateMachine) SetFieldSub(n *enode.Node, field Field, value interface{}) error {
+	ns.lock.Lock()
+	defer ns.lock.Unlock()
+
+	ns.opCheck()
+	return ns.setField(n, field, value)
+}
+
+func (ns *NodeStateMachine) setField(n *enode.Node, field Field, value interface{}) error {
+	ns.checkStarted()
+	id, node := ns.updateEnode(n)
 	if node == nil {
-		ns.lock.Unlock()
-		return nil
+		if value == nil {
+			return nil
+		}
+		node = ns.newNode(n)
+		ns.nodes[id] = node
 	}
 	fieldIndex := ns.fieldIndex(field)
 	f := ns.fields[fieldIndex]
 	if value != nil && reflect.TypeOf(value) != f.ftype {
 		log.Error("Invalid field type", "type", reflect.TypeOf(value), "required", f.ftype)
-		ns.lock.Unlock()
-		return errors.New("invalid field type")
+		return ErrInvalidField
 	}
 	oldValue := node.fields[fieldIndex]
 	if value == oldValue {
-		ns.lock.Unlock()
 		return nil
 	}
-	node.fields[fieldIndex] = value
-	if f.encode != nil {
-		node.dirty = true
+	if oldValue != nil {
+		node.fieldCount--
+	}
+	if value != nil {
+		node.fieldCount++
+	}
+	node.fields[fieldIndex] = value
+	if node.state == 0 && node.fieldCount == 0 {
+		delete(ns.nodes, id)
+		if node.db {
+			ns.deleteNode(id)
+		}
+	} else {
+		if f.encode != nil {
+			node.dirty = true
+		}
 	}
 	state := node.state
-	ns.lock.Unlock()
-	if len(f.subs) > 0 {
+	callback := func() {
 		for _, cb := range f.subs {
 			cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value)
 		}
 	}
+	ns.opPending = append(ns.opPending, callback)
 	return nil
 }
 
 // ForEach calls the callback for each node having all of the required and none of the
-// disabled flags set
+// disabled flags set.
+// Note that this callback is not an operation callback but ForEach can be called from an
+// Operation callback or Operation can also be called from a ForEach callback if necessary.
 func (ns *NodeStateMachine) ForEach(requireFlags, disableFlags Flags, cb func(n *enode.Node, state Flags)) {
 	ns.lock.Lock()
 	ns.checkStarted()
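setField's new fieldCount bookkeeping changes the node lifecycle: a field alone now keeps a node alive (no flag needs to be set first), and clearing the last field of a flagless node discards the node entirely. This is exactly what the updated TestSetField below exercises; a sketch reusing the earlier setup:

    func fieldLifecycleSketch(ns *nodestate.NodeStateMachine, node *enode.Node) {
        ns.SetField(node, sfiScore, uint64(42)) // creates the node entry, no flag required
        if ns.GetField(node, sfiScore) != nil {
            // the node is kept alive by fieldCount > 0 alone
        }
        ns.SetField(node, sfiScore, nil) // last field of a flagless node: node is discarded
    }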
diff --git a/p2p/nodestate/nodestate_test.go b/p2p/nodestate/nodestate_test.go
index f6ff3ffc0..5f99a3da7 100644
--- a/p2p/nodestate/nodestate_test.go
+++ b/p2p/nodestate/nodestate_test.go
@@ -147,8 +147,13 @@ func TestSetField(t *testing.T) {
 	// Set field before setting state
 	ns.SetField(testNode(1), fields[0], "hello world")
 	field := ns.GetField(testNode(1), fields[0])
+	if field == nil {
+		t.Fatalf("Field should be set before setting states")
+	}
+	ns.SetField(testNode(1), fields[0], nil)
+	field = ns.GetField(testNode(1), fields[0])
 	if field != nil {
-		t.Fatalf("Field shouldn't be set before setting states")
+		t.Fatalf("Field should be unset")
 	}
 	// Set field after setting state
 	ns.SetState(testNode(1), flags[0], Flags{}, 0)
@@ -169,23 +174,6 @@ func TestSetField(t *testing.T) {
 	}
 }
 
-func TestUnsetField(t *testing.T) {
-	mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
-
-	s, flags, fields := testSetup([]bool{false}, []reflect.Type{reflect.TypeOf("")})
-	ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
-
-	ns.Start()
-
-	ns.SetState(testNode(1), flags[0], Flags{}, time.Second)
-	ns.SetField(testNode(1), fields[0], "hello world")
-
-	ns.SetState(testNode(1), Flags{}, flags[0], 0)
-	if field := ns.GetField(testNode(1), fields[0]); field != nil {
-		t.Fatalf("Field should be unset")
-	}
-}
-
 func TestSetState(t *testing.T) {
 	mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
 
@@ -339,6 +327,7 @@ func TestFieldSub(t *testing.T) {
 	ns2.Start()
 	check(s.OfflineFlag(), nil, uint64(100))
 	ns2.SetState(testNode(1), Flags{}, flags[0], 0)
+	ns2.SetField(testNode(1), fields[0], nil)
 	check(Flags{}, uint64(100), nil)
 	ns2.Stop()
 }
@@ -387,3 +376,34 @@ func TestDuplicatedFlags(t *testing.T) {
 	clock.Run(2 * time.Second)
 	check(flags[0], Flags{}, true)
 }
+
+func TestCallbackOrder(t *testing.T) {
+	mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
+
+	s, flags, _ := testSetup([]bool{false, false, false, false}, nil)
+	ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
+
+	ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) {
+		if newState.Equals(flags[0]) {
+			ns.SetStateSub(n, flags[1], Flags{}, 0)
+			ns.SetStateSub(n, flags[2], Flags{}, 0)
+		}
+	})
+	ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) {
+		if newState.Equals(flags[1]) {
+			ns.SetStateSub(n, flags[3], Flags{}, 0)
+		}
+	})
+	lastState := Flags{}
+	ns.SubscribeState(MergeFlags(flags[1], flags[2], flags[3]), func(n *enode.Node, oldState, newState Flags) {
+		if !oldState.Equals(lastState) {
+			t.Fatalf("Wrong callback order")
+		}
+		lastState = newState
+	})
+
+	ns.Start()
+	defer ns.Stop()
+
+	ns.SetState(testNode(1), flags[0], Flags{}, 0)
+}
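Finally, the ForEach note from above in pattern form, mirroring serverPool.start: iteration itself is not an operation, so an Operation is opened around the loop and the Sub variants are used inside. A sketch with the earlier illustrative flags:

    func rescanSketch(ns *nodestate.NodeStateMachine) {
        // ForEach callbacks are not operation callbacks by themselves, so the
        // operation is opened around the iteration, as serverPool.start does.
        ns.Operation(func() {
            ns.ForEach(sfDialed, nodestate.Flags{}, func(n *enode.Node, state nodestate.Flags) {
                ns.SetStateSub(n, sfActive, nodestate.Flags{}, 0)
            })
        })
    }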