Watch workers for shutdown and drop closed workers from the scheduler.

* Worker gains Closing(ctx), which returns a channel signalling worker
  shutdown; LocalWorker returns a channel that never fires.
* The scheduler starts runWorkerWatcher (sched_watch.go), which watches the
  closing channel of every worker and reports dropped workers on
  workerClosing; runSched then removes them via schedDropWorker.
* schedDropWorker passes the worker ID to its log.Warnf call (the format has
  two verbs) and ignores worker IDs that are not registered instead of
  dereferencing a nil handle.

Whitespace in this patch was reconstructed using gofmt conventions, and the
index lines are those of the originally submitted patch; if context lines do
not match, apply with `git apply --ignore-whitespace`.

diff --git a/localworker.go b/localworker.go
index f3f12e8c1..a92f01a89 100644
--- a/localworker.go
+++ b/localworker.go
@@ -211,6 +211,12 @@ func (l *LocalWorker) Info(context.Context) (storiface.WorkerInfo, error) {
 	}, nil
 }
 
+// Closing returns a channel signalling worker shutdown. For a LocalWorker this
+// is a fresh channel that is never closed or sent on.
+func (l *LocalWorker) Closing(ctx context.Context) (<-chan struct{}, error) {
+	return make(chan struct{}), nil
+}
+
 func (l *LocalWorker) Close() error {
 	return nil
 }
diff --git a/manager.go b/manager.go
index f4408b6e0..065370ed6 100644
--- a/manager.go
+++ b/manager.go
@@ -37,6 +37,9 @@ type Worker interface {
 
 	Info(context.Context) (storiface.WorkerInfo, error)
 
+	// returns channel signalling worker shutdown
+	Closing(context.Context) (<-chan struct{}, error)
+
 	Close() error
 }
 
diff --git a/sched.go b/sched.go
index aaf8b9e26..3822a8683 100644
--- a/sched.go
+++ b/sched.go
@@ -33,6 +33,10 @@ type scheduler struct {
 	workers    map[WorkerID]*workerHandle
 
 	newWorkers chan *workerHandle
+
+	watchClosing  chan WorkerID
+	workerClosing chan WorkerID
+
 	schedule   chan *workerRequest
 	workerFree chan WorkerID
 	closing    chan struct{}
@@ -47,10 +51,14 @@ func newScheduler(spt abi.RegisteredProof) *scheduler {
 		nextWorker: 0,
 		workers:    map[WorkerID]*workerHandle{},
 
-		newWorkers: make(chan *workerHandle),
-		schedule:   make(chan *workerRequest),
-		workerFree: make(chan WorkerID),
-		closing:    make(chan struct{}),
+		newWorkers: make(chan *workerHandle),
+
+		watchClosing:  make(chan WorkerID),
+		workerClosing: make(chan WorkerID),
+
+		schedule:   make(chan *workerRequest),
+		workerFree: make(chan WorkerID),
+		closing:    make(chan struct{}),
 
 		schedQueue: list.New(),
 	}
@@ -128,12 +136,14 @@ type workerHandle struct {
 }
 
 func (sh *scheduler) runSched() {
+	go sh.runWorkerWatcher()
+
 	for {
 		select {
 		case w := <-sh.newWorkers:
-			wid := sh.schedNewWorker(w)
-
-			sh.onWorkerFreed(wid)
+			sh.schedNewWorker(w)
+		case wid := <-sh.workerClosing:
+			sh.schedDropWorker(wid)
 		case req := <-sh.schedule:
 			scheduled, err := sh.maybeSchedRequest(req)
 			if err != nil {
@@ -155,10 +165,18 @@ func (sh *scheduler) runSched() {
 }
 
 func (sh *scheduler) onWorkerFreed(wid WorkerID) {
+	sh.workersLk.Lock()
+	w, ok := sh.workers[wid]
+	sh.workersLk.Unlock()
+	if !ok {
+		log.Warnf("onWorkerFreed on invalid worker %d", wid)
+		return
+	}
+
 	for e := sh.schedQueue.Front(); e != nil; e = e.Next() {
 		req := e.Value.(*workerRequest)
 
-		ok, err := req.sel.Ok(req.ctx, req.taskType, sh.workers[wid])
+		ok, err := req.sel.Ok(req.ctx, req.taskType, w)
 		if err != nil {
 			log.Errorf("onWorkerFreed req.sel.Ok error: %+v", err)
 			continue
@@ -411,15 +429,46 @@ func (a *activeResources) utilization(wr storiface.WorkerResources) float64 {
 	return max
 }
 
-func (sh *scheduler) schedNewWorker(w *workerHandle) WorkerID {
+// schedNewWorker registers w under the next WorkerID, announces that ID to the
+// worker watcher on watchClosing (giving up if the scheduler is closing), and
+// then calls onWorkerFreed for the new worker.
+func (sh *scheduler) schedNewWorker(w *workerHandle) {
 	sh.workersLk.Lock()
-	defer sh.workersLk.Unlock()
 
 	id := sh.nextWorker
 	sh.workers[id] = w
 	sh.nextWorker++
 
-	return id
+	sh.workersLk.Unlock()
+
+	select {
+	case sh.watchClosing <- id:
+	case <-sh.closing:
+		return
+	}
+
+	sh.onWorkerFreed(id)
+}
+
+// schedDropWorker removes worker wid from the scheduler and closes it in the
+// background. Worker IDs that are not registered are logged and ignored.
+func (sh *scheduler) schedDropWorker(wid WorkerID) {
+	sh.workersLk.Lock()
+	defer sh.workersLk.Unlock()
+
+	w, ok := sh.workers[wid]
+	if !ok {
+		log.Warnf("schedDropWorker on invalid worker %d", wid)
+		return
+	}
+	delete(sh.workers, wid)
+
+	// Close in a goroutine so a slow Close cannot stall the scheduler loop.
+	go func() {
+		if err := w.w.Close(); err != nil {
+			log.Warnf("closing worker %d: %+v", wid, err)
+		}
+	}()
 }
 
 func (sh *scheduler) schedClose() {
diff --git a/sched_watch.go b/sched_watch.go
new file mode 100644
index 000000000..c2716aae9
--- /dev/null
+++ b/sched_watch.go
@@ -0,0 +1,110 @@
+package sectorstorage
+
+import (
+	"context"
+	"reflect"
+)
+
+// runWorkerWatcher waits on the closing channel of every registered worker and
+// reports each worker whose channel fires on sh.workerClosing. It returns when
+// sh.closing fires or sh.watchClosing is closed.
+//
+// The set of watched channels changes at runtime, so the loop is built on
+// reflect.Select: case 0 is sh.closing, case 1 is sh.watchClosing (IDs of new
+// workers to watch) and every later case is one worker's closing channel.
+func (sh *scheduler) runWorkerWatcher() {
+	ctx, cancel := context.WithCancel(context.TODO())
+	defer cancel()
+
+	// nilch is a nil channel value: a receive on it never proceeds, so it marks
+	// a slot in cases as unused and available for the next worker.
+	nilch := reflect.ValueOf(new(chan struct{})).Elem()
+
+	cases := []reflect.SelectCase{
+		{
+			Dir:  reflect.SelectRecv,
+			Chan: reflect.ValueOf(sh.closing),
+		},
+		{
+			Dir:  reflect.SelectRecv,
+			Chan: reflect.ValueOf(sh.watchClosing),
+		},
+	}
+
+	// caseToWorker maps an index in cases to the worker watched by that case.
+	caseToWorker := map[int]WorkerID{}
+
+	for {
+		n, rv, ok := reflect.Select(cases)
+
+		switch {
+		case n == 0: // sh.closing
+			return
+		case n == 1: // sh.watchClosing
+			if !ok {
+				log.Errorf("watchClosing channel closed")
+				return
+			}
+
+			wid, ok := rv.Interface().(WorkerID)
+			if !ok {
+				panic("got a non-WorkerID message")
+			}
+
+			sh.workersLk.Lock()
+			workerClosing, err := sh.workers[wid].w.Closing(ctx)
+			sh.workersLk.Unlock()
+			if err != nil {
+				// Without a closing channel the worker cannot be watched;
+				// report it as closing right away.
+				log.Errorf("getting worker closing channel: %+v", err)
+				select {
+				case sh.workerClosing <- wid:
+				case <-sh.closing:
+					return
+				}
+
+				continue
+			}
+
+			// Reuse a slot freed by a dropped worker, or grow cases by one.
+			toSet := -1
+			for i, sc := range cases {
+				if sc.Chan == nilch {
+					toSet = i
+					break
+				}
+			}
+			if toSet == -1 {
+				toSet = len(cases)
+				cases = append(cases, reflect.SelectCase{})
+			}
+
+			cases[toSet] = reflect.SelectCase{
+				Dir:  reflect.SelectRecv,
+				Chan: reflect.ValueOf(workerClosing),
+			}
+
+			caseToWorker[toSet] = wid
+		default: // a worker's closing channel
+			wid := caseToWorker[n]
+
+			// Free the slot so this case cannot be selected again.
+			delete(caseToWorker, n)
+			cases[n] = reflect.SelectCase{
+				Dir:  reflect.SelectRecv,
+				Chan: nilch,
+			}
+
+			log.Warnf("worker %d dropped", wid)
+			// NOTE(review): this send and the send on sh.watchClosing in
+			// schedNewWorker are both unbuffered; if the two coincide neither
+			// side is receiving. TODO confirm this cannot deadlock.
+			select {
+			case sh.workerClosing <- wid:
+			case <-sh.closing:
+				return
+			}
+		}
+	}
+}