diff --git a/cmd/lotus/daemon.go b/cmd/lotus/daemon.go
index 5a59ec816..ec4a638b4 100644
--- a/cmd/lotus/daemon.go
+++ b/cmd/lotus/daemon.go
@@ -351,8 +351,21 @@ var DaemonCmd = &cli.Command{
 			return xerrors.Errorf("getting api endpoint: %w", err)
 		}
 
+		// Start the RPC server.
+		rpcStopper, err := node.ServeRPC(api, endpoint, int64(cctx.Int("api-max-req-size")))
+		if err != nil {
+			return fmt.Errorf("failed to start JSON-RPC API: %w", err)
+		}
+
+		// Monitor for shutdown.
+		finishCh := node.MonitorShutdown(shutdownChan,
+			node.ShutdownHandler{Component: "rpc server", StopFunc: rpcStopper},
+			node.ShutdownHandler{Component: "node", StopFunc: stop},
+		)
+		<-finishCh // fires when shutdown is complete.
+
 		// TODO: properly parse api endpoint (or make it a URL)
-		return serveRPC(api, stop, endpoint, shutdownChan, int64(cctx.Int("api-max-req-size")))
+		return nil
 	},
 	Subcommands: []*cli.Command{
 		daemonStopCmd,
diff --git a/cmd/lotus/main.go b/cmd/lotus/main.go
index c1dab8e94..63d01f891 100644
--- a/cmd/lotus/main.go
+++ b/cmd/lotus/main.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"os"
 
+	logging "github.com/ipfs/go-log/v2"
 	"github.com/mattn/go-isatty"
 	"github.com/urfave/cli/v2"
 	"go.opencensus.io/trace"
@@ -16,6 +17,8 @@ import (
 	"github.com/filecoin-project/lotus/node/repo"
 )
 
+var log = logging.Logger("main")
+
 var AdvanceBlockCmd *cli.Command
 
 func main() {
diff --git a/cmd/lotus/pprof.go b/cmd/lotus/pprof.go
deleted file mode 100644
index ea6823e48..000000000
--- a/cmd/lotus/pprof.go
+++ /dev/null
@@ -1,33 +0,0 @@
-package main
-
-import (
-	"net/http"
-	"strconv"
-)
-
-func handleFractionOpt(name string, setter func(int)) http.HandlerFunc {
-	return func(rw http.ResponseWriter, r *http.Request) {
-		if r.Method != http.MethodPost {
-			http.Error(rw, "only POST allowed", http.StatusMethodNotAllowed)
-			return
-		}
-		if err := r.ParseForm(); err != nil {
-			http.Error(rw, err.Error(), http.StatusBadRequest)
-			return
-		}
-
-		asfr := r.Form.Get("x")
-		if len(asfr) == 0 {
-			http.Error(rw, "parameter 'x' must be set", http.StatusBadRequest)
-			return
-		}
-
-		fr, err := strconv.Atoi(asfr)
-		if err != nil {
-			http.Error(rw, err.Error(), http.StatusBadRequest)
-			return
-		}
-		log.Infof("setting %s to %d", name, fr)
-		setter(fr)
-	}
-}
diff --git a/cmd/lotus/rpc.go b/node/rpc.go
similarity index 68%
rename from cmd/lotus/rpc.go
rename to node/rpc.go
index 95050d639..1b1192f71 100644
--- a/cmd/lotus/rpc.go
+++ b/node/rpc.go
@@ -1,4 +1,4 @@
-package main
+package node
 
 import (
 	"context"
@@ -6,10 +6,8 @@ import (
 	"net"
 	"net/http"
 	_ "net/http/pprof"
-	"os"
-	"os/signal"
 	"runtime"
-	"syscall"
+	"strconv"
 
 	"github.com/ipfs/go-cid"
 	logging "github.com/ipfs/go-log/v2"
@@ -25,13 +23,18 @@ import (
 	"github.com/filecoin-project/lotus/api/v0api"
 	"github.com/filecoin-project/lotus/api/v1api"
 	"github.com/filecoin-project/lotus/metrics"
-	"github.com/filecoin-project/lotus/node"
 	"github.com/filecoin-project/lotus/node/impl"
 )
 
-var log = logging.Logger("main")
+var rpclog = logging.Logger("rpc")
 
-func serveRPC(a v1api.FullNode, stop node.StopFunc, addr multiaddr.Multiaddr, shutdownCh <-chan struct{}, maxRequestSize int64) error {
+// ServeRPC starts an HTTP server listening on addr and serves it from a
+// background goroutine, using the handlers registered on http.DefaultServeMux.
+// A non-zero maxRequestSize is added as a jsonrpc.WithMaxRequestSize option.
+//
+// It fails early if addr cannot be listened on. On success it returns the
+// server's Shutdown method, which the caller invokes to stop serving.
+func ServeRPC(a v1api.FullNode, addr multiaddr.Multiaddr, maxRequestSize int64) (StopFunc, error) {
 	serverOptions := make([]jsonrpc.ServerOption, 0)
 	if maxRequestSize != 0 { // config set
 		serverOptions = append(serverOptions, jsonrpc.WithMaxRequestSize(maxRequestSize))
@@ -62,15 +65,17 @@ func serveRPC(a v1api.FullNode, stop node.StopFunc, addr multiaddr.Multiaddr, sh
 
 	http.Handle("/debug/metrics", metrics.Exporter())
 	http.Handle("/debug/pprof-set/block", handleFractionOpt("BlockProfileRate", runtime.SetBlockProfileRate))
-	http.Handle("/debug/pprof-set/mutex", handleFractionOpt("MutexProfileFraction",
-		func(x int) { runtime.SetMutexProfileFraction(x) },
-	))
+	http.Handle("/debug/pprof-set/mutex", handleFractionOpt("MutexProfileFraction", func(x int) {
+		runtime.SetMutexProfileFraction(x)
+	}))
 
+	// Start listening to the addr; if invalid or occupied, we will fail early.
 	lst, err := manet.Listen(addr)
 	if err != nil {
-		return xerrors.Errorf("could not listen: %w", err)
+		return nil, xerrors.Errorf("could not listen: %w", err)
 	}
 
+	// Instantiate the server and start listening.
 	srv := &http.Server{
 		Handler: http.DefaultServeMux,
 		BaseContext: func(listener net.Listener) context.Context {
@@ -79,35 +84,16 @@ func serveRPC(a v1api.FullNode, stop node.StopFunc, addr multiaddr.Multiaddr, sh
 		},
 	}
 
-	sigCh := make(chan os.Signal, 2)
-	shutdownDone := make(chan struct{})
 	go func() {
-		select {
-		case sig := <-sigCh:
-			log.Warnw("received shutdown", "signal", sig)
-		case <-shutdownCh:
-			log.Warn("received shutdown")
+		// Keep the Serve error local to this goroutine; assigning the outer
+		// err here would race with the return statement below.
+		err := srv.Serve(manet.NetListener(lst))
+		if err != http.ErrServerClosed {
+			rpclog.Warnf("rpc server failed: %s", err)
 		}
-
-		log.Warn("Shutting down...")
-		if err := srv.Shutdown(context.TODO()); err != nil {
-			log.Errorf("shutting down RPC server failed: %s", err)
-		}
-		if err := stop(context.TODO()); err != nil {
-			log.Errorf("graceful shutting down failed: %s", err)
-		}
-		log.Warn("Graceful shutdown successful")
-		_ = log.Sync() //nolint:errcheck
-		close(shutdownDone)
 	}()
 
-	signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
-	err = srv.Serve(manet.NetListener(lst))
-	if err == http.ErrServerClosed {
-		<-shutdownDone
-		return nil
-	}
-	return err
+	return srv.Shutdown, nil
 }
 
 func handleImport(a *impl.FullNodeAPI) func(w http.ResponseWriter, r *http.Request) {
@@ -136,3 +122,32 @@ func handleImport(a *impl.FullNodeAPI) func(w http.ResponseWriter, r *http.Reque
 		}
 	}
 }
+
+// handleFractionOpt returns a POST-only handler that parses the form value "x"
+// as an integer and passes it to setter; name is only used in the log message.
+func handleFractionOpt(name string, setter func(int)) http.HandlerFunc {
+	return func(rw http.ResponseWriter, r *http.Request) {
+		if r.Method != http.MethodPost {
+			http.Error(rw, "only POST allowed", http.StatusMethodNotAllowed)
+			return
+		}
+		if err := r.ParseForm(); err != nil {
+			http.Error(rw, err.Error(), http.StatusBadRequest)
+			return
+		}
+
+		asfr := r.Form.Get("x")
+		if len(asfr) == 0 {
+			http.Error(rw, "parameter 'x' must be set", http.StatusBadRequest)
+			return
+		}
+
+		fr, err := strconv.Atoi(asfr)
+		if err != nil {
+			http.Error(rw, err.Error(), http.StatusBadRequest)
+			return
+		}
+		rpclog.Infof("setting %s to %d", name, fr)
+		setter(fr)
+	}
+}
diff --git a/node/shutdown.go b/node/shutdown.go
new file mode 100644
index 000000000..e630031da
--- /dev/null
+++ b/node/shutdown.go
@@ -0,0 +1,58 @@
+package node
+
+import (
+	"context"
+	"os"
+	"os/signal"
+	"syscall"
+)
+
+// ShutdownHandler pairs a component name (used in log messages) with the
+// StopFunc that shuts that component down.
+type ShutdownHandler struct {
+	Component string
+	StopFunc  StopFunc
+}
+
+// MonitorShutdown manages shutdown requests, by watching signals and invoking
+// the supplied handlers in order.
+//
+// It watches SIGTERM and SIGINT OS signals, as well as the trigger channel.
+// When any of them fire, it calls the supplied handlers in order. If any of
+// them errors, it merely logs the error.
+//
+// Once the shutdown has completed, it closes the returned channel. The caller
+// can watch this channel to learn when the shutdown has finished.
+func MonitorShutdown(triggerCh <-chan struct{}, handlers ...ShutdownHandler) <-chan struct{} {
+	sigCh := make(chan os.Signal, 2)
+	out := make(chan struct{})
+
+	go func() {
+		select {
+		case sig := <-sigCh:
+			log.Warnw("received shutdown", "signal", sig)
+		case <-triggerCh:
+			log.Warn("received shutdown")
+		}
+
+		log.Warn("Shutting down...")
+
+		// Call all the handlers, logging on failure and success.
+		for _, h := range handlers {
+			if err := h.StopFunc(context.TODO()); err != nil {
+				log.Errorf("shutting down %s failed: %s", h.Component, err)
+				continue
+			}
+			log.Infof("%s shut down successfully ", h.Component)
+		}
+
+		log.Warn("Graceful shutdown successful")
+
+		// Sync all loggers.
+		_ = log.Sync() //nolint:errcheck
+		close(out)
+	}()
+
+	signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
+	return out
+}
diff --git a/node/shutdown_test.go b/node/shutdown_test.go
new file mode 100644
index 000000000..15e2af93e
--- /dev/null
+++ b/node/shutdown_test.go
@@ -0,0 +1,41 @@
+package node
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestMonitorShutdown(t *testing.T) {
+	signalCh := make(chan struct{})
+
+	// Three shutdown handlers.
+	var wg sync.WaitGroup
+	wg.Add(3)
+	h := ShutdownHandler{
+		Component: "handler",
+		StopFunc: func(_ context.Context) error {
+			wg.Done()
+			return nil
+		},
+	}
+
+	finishCh := MonitorShutdown(signalCh, h, h, h)
+
+	// Nothing here after 10ms. len() of an unbuffered channel is always 0, so
+	// poll the channel instead to detect a premature shutdown.
+	time.Sleep(10 * time.Millisecond)
+	select {
+	case <-finishCh:
+		require.Fail(t, "shutdown finished before it was triggered")
+	default:
+	}
+
+	// Now trigger the shutdown.
+	close(signalCh)
+	wg.Wait()
+	<-finishCh
+}