diff --git a/les/client.go b/les/client.go index 9f0afc96c..d08c9feba 100644 --- a/les/client.go +++ b/les/client.go @@ -81,7 +81,7 @@ func New(stack *node.Node, config *ethconfig.Config) (*LightEthereum, error) { if err != nil { return nil, err } - lespayDb, err := stack.OpenDatabase("lespay", 0, 0, "eth/db/lespay") + lesDb, err := stack.OpenDatabase("les.client", 0, 0, "eth/db/les.client") if err != nil { return nil, err } @@ -99,6 +99,7 @@ func New(stack *node.Node, config *ethconfig.Config) (*LightEthereum, error) { chainConfig: chainConfig, iConfig: light.DefaultClientIndexerConfig, chainDb: chainDb, + lesDb: lesDb, closeCh: make(chan struct{}), }, peers: peers, @@ -108,13 +109,13 @@ func New(stack *node.Node, config *ethconfig.Config) (*LightEthereum, error) { engine: ethconfig.CreateConsensusEngine(stack, chainConfig, &config.Ethash, nil, false, chainDb), bloomRequests: make(chan chan *bloombits.Retrieval), bloomIndexer: core.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations), - valueTracker: vfc.NewValueTracker(lespayDb, &mclock.System{}, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000)), + valueTracker: vfc.NewValueTracker(lesDb, &mclock.System{}, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000)), p2pServer: stack.Server(), p2pConfig: &stack.Config().P2P, } peers.subscribe((*vtSubscription)(leth.valueTracker)) - leth.serverPool = newServerPool(lespayDb, []byte("serverpool:"), leth.valueTracker, time.Second, nil, &mclock.System{}, config.UltraLightServers) + leth.serverPool = newServerPool(lesDb, []byte("serverpool:"), leth.valueTracker, time.Second, nil, &mclock.System{}, config.UltraLightServers) peers.subscribe(leth.serverPool) leth.dialCandidates = leth.serverPool.dialIterator @@ -331,6 +332,7 @@ func (s *LightEthereum) Stop() error { s.eventMux.Stop() rawdb.PopUncleanShutdownMarker(s.chainDb) s.chainDb.Close() + s.lesDb.Close() s.wg.Wait() log.Info("Light ethereum stopped") return nil diff --git a/les/clientpool.go b/les/clientpool.go index 96c0f0f99..4e1499bf5 100644 --- a/les/clientpool.go +++ b/les/clientpool.go @@ -105,7 +105,7 @@ type clientInfo struct { } // newClientPool creates a new client pool -func newClientPool(ns *nodestate.NodeStateMachine, lespayDb ethdb.Database, minCap uint64, connectedBias time.Duration, clock mclock.Clock, removePeer func(enode.ID)) *clientPool { +func newClientPool(ns *nodestate.NodeStateMachine, lesDb ethdb.Database, minCap uint64, connectedBias time.Duration, clock mclock.Clock, removePeer func(enode.ID)) *clientPool { pool := &clientPool{ ns: ns, BalanceTrackerSetup: balanceTrackerSetup, @@ -115,7 +115,7 @@ func newClientPool(ns *nodestate.NodeStateMachine, lespayDb ethdb.Database, minC connectedBias: connectedBias, removePeer: removePeer, } - pool.bt = vfs.NewBalanceTracker(ns, balanceTrackerSetup, lespayDb, clock, &utils.Expirer{}, &utils.Expirer{}) + pool.bt = vfs.NewBalanceTracker(ns, balanceTrackerSetup, lesDb, clock, &utils.Expirer{}, &utils.Expirer{}) pool.pp = vfs.NewPriorityPool(ns, priorityPoolSetup, clock, minCap, connectedBias, 4) // set default expiration constants used by tests diff --git a/les/commons.go b/les/commons.go index a2fce1dc9..d090fc21f 100644 --- a/les/commons.go +++ b/les/commons.go @@ -51,7 +51,7 @@ type lesCommons struct { config *ethconfig.Config chainConfig *params.ChainConfig iConfig *light.IndexerConfig - chainDb ethdb.Database + chainDb, lesDb ethdb.Database chainReader chainReader chtIndexer, bloomTrieIndexer *core.ChainIndexer oracle *checkpointoracle.CheckpointOracle diff --git a/les/server.go b/les/server.go index 351a53f69..e34647f29 100644 --- a/les/server.go +++ b/les/server.go @@ -85,6 +85,10 @@ type LesServer struct { } func NewLesServer(node *node.Node, e ethBackend, config *ethconfig.Config) (*LesServer, error) { + lesDb, err := node.OpenDatabase("les.server", 0, 0, "eth/db/les.server") + if err != nil { + return nil, err + } ns := nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, serverSetup) // Calculate the number of threads used to service the light client // requests based on the user-specified value. @@ -99,6 +103,7 @@ func NewLesServer(node *node.Node, e ethBackend, config *ethconfig.Config) (*Les chainConfig: e.BlockChain().Config(), iConfig: light.DefaultServerIndexerConfig, chainDb: e.ChainDb(), + lesDb: lesDb, chainReader: e.BlockChain(), chtIndexer: light.NewChtIndexer(e.ChainDb(), nil, params.CHTFrequency, params.HelperTrieProcessConfirmations, true), bloomTrieIndexer: light.NewBloomTrieIndexer(e.ChainDb(), nil, params.BloomBitsBlocks, params.BloomTrieFrequency, true), @@ -136,7 +141,7 @@ func NewLesServer(node *node.Node, e ethBackend, config *ethconfig.Config) (*Les srv.maxCapacity = totalRecharge } srv.fcManager.SetCapacityLimits(srv.minCapacity, srv.maxCapacity, srv.minCapacity*2) - srv.clientPool = newClientPool(ns, srv.chainDb, srv.minCapacity, defaultConnectedBias, mclock.System{}, srv.dropClient) + srv.clientPool = newClientPool(ns, lesDb, srv.minCapacity, defaultConnectedBias, mclock.System{}, srv.dropClient) srv.clientPool.setDefaultFactors(vfs.PriceFactors{TimeFactor: 0, CapacityFactor: 1, RequestFactor: 1}, vfs.PriceFactors{TimeFactor: 0, CapacityFactor: 1, RequestFactor: 1}) checkpoint := srv.latestLocalCheckpoint() @@ -222,6 +227,7 @@ func (s *LesServer) Stop() error { // Note, bloom trie indexer is closed by parent bloombits indexer. s.chtIndexer.Close() + s.lesDb.Close() s.wg.Wait() log.Info("Les server stopped") diff --git a/les/vflux/server/balance_test.go b/les/vflux/server/balance_test.go index 67e194437..6c817aa26 100644 --- a/les/vflux/server/balance_test.go +++ b/les/vflux/server/balance_test.go @@ -17,6 +17,7 @@ package server import ( + "math" "math/rand" "reflect" "testing" @@ -69,7 +70,9 @@ func (b *balanceTestSetup) newNode(capacity uint64) *NodeBalance { node := enode.SignNull(&enr.Record{}, enode.ID{}) b.ns.SetState(node, testFlag, nodestate.Flags{}, 0) b.ns.SetField(node, btTestSetup.connAddressField, "") - b.ns.SetField(node, ppTestSetup.CapacityField, capacity) + if capacity != 0 { + b.ns.SetField(node, ppTestSetup.CapacityField, capacity) + } n, _ := b.ns.GetField(node, btTestSetup.BalanceField).(*NodeBalance) return n } @@ -398,3 +401,71 @@ func TestCallback(t *testing.T) { case <-time.NewTimer(time.Millisecond * 100).C: } } + +func TestBalancePersistence(t *testing.T) { + clock := &mclock.Simulated{} + ns := nodestate.NewNodeStateMachine(nil, nil, clock, testSetup) + db := memorydb.New() + posExp := &utils.Expirer{} + negExp := &utils.Expirer{} + posExp.SetRate(clock.Now(), math.Log(2)/float64(time.Hour*2)) // halves every two hours + negExp.SetRate(clock.Now(), math.Log(2)/float64(time.Hour)) // halves every hour + bt := NewBalanceTracker(ns, btTestSetup, db, clock, posExp, negExp) + ns.Start() + bts := &balanceTestSetup{ + clock: clock, + ns: ns, + bt: bt, + } + var nb *NodeBalance + exp := func(expPos, expNeg uint64) { + pos, neg := nb.GetBalance() + if pos != expPos { + t.Fatalf("Positive balance incorrect, want %v, got %v", expPos, pos) + } + if neg != expNeg { + t.Fatalf("Positive balance incorrect, want %v, got %v", expPos, pos) + } + } + expTotal := func(expTotal uint64) { + total := bt.TotalTokenAmount() + if total != expTotal { + t.Fatalf("Total token amount incorrect, want %v, got %v", expTotal, total) + } + } + + expTotal(0) + nb = bts.newNode(0) + expTotal(0) + nb.SetBalance(16000000000, 16000000000) + exp(16000000000, 16000000000) + expTotal(16000000000) + clock.Run(time.Hour * 2) + exp(8000000000, 4000000000) + expTotal(8000000000) + bt.Stop() + ns.Stop() + + clock = &mclock.Simulated{} + ns = nodestate.NewNodeStateMachine(nil, nil, clock, testSetup) + posExp = &utils.Expirer{} + negExp = &utils.Expirer{} + posExp.SetRate(clock.Now(), math.Log(2)/float64(time.Hour*2)) // halves every two hours + negExp.SetRate(clock.Now(), math.Log(2)/float64(time.Hour)) // halves every hour + bt = NewBalanceTracker(ns, btTestSetup, db, clock, posExp, negExp) + ns.Start() + bts = &balanceTestSetup{ + clock: clock, + ns: ns, + bt: bt, + } + expTotal(8000000000) + nb = bts.newNode(0) + exp(8000000000, 4000000000) + expTotal(8000000000) + clock.Run(time.Hour * 2) + exp(4000000000, 1000000000) + expTotal(4000000000) + bt.Stop() + ns.Stop() +} diff --git a/les/vflux/server/balance_tracker.go b/les/vflux/server/balance_tracker.go index edd4b288d..1708019de 100644 --- a/les/vflux/server/balance_tracker.go +++ b/les/vflux/server/balance_tracker.go @@ -99,6 +99,10 @@ func NewBalanceTracker(ns *nodestate.NodeStateMachine, setup BalanceTrackerSetup balanceTimer: utils.NewUpdateTimer(clock, time.Second*10), quit: make(chan struct{}), } + posOffset, negOffset := bt.ndb.getExpiration() + posExp.SetLogOffset(clock.Now(), posOffset) + negExp.SetLogOffset(clock.Now(), negOffset) + bt.ndb.forEachBalance(false, func(id enode.ID, balance utils.ExpiredValue) bool { bt.inactive.AddExp(balance) return true @@ -177,7 +181,7 @@ func (bt *BalanceTracker) TotalTokenAmount() uint64 { bt.balanceTimer.Update(func(_ time.Duration) bool { bt.active = utils.ExpiredValue{} bt.ns.ForEach(nodestate.Flags{}, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) { - if n, ok := bt.ns.GetField(node, bt.BalanceField).(*NodeBalance); ok { + if n, ok := bt.ns.GetField(node, bt.BalanceField).(*NodeBalance); ok && n.active { pos, _ := n.GetRawBalance() bt.active.AddExp(pos) }