diff --git a/cmd/lotus-worker/tasks.go b/cmd/lotus-worker/tasks.go index 7b4462073..af9a9af70 100644 --- a/cmd/lotus-worker/tasks.go +++ b/cmd/lotus-worker/tasks.go @@ -45,32 +45,52 @@ var tasksEnableCmd = &cli.Command{ Name: "enable", Usage: "Enable a task type", ArgsUsage: "[" + settableStr + "]", - Action: taskAction(api.Worker.TaskEnable), + Flags: []cli.Flag{ + &cli.BoolFlag{ + Name: "all", + Usage: "Enable all task types", + }, + }, + Action: taskAction(api.Worker.TaskEnable), } var tasksDisableCmd = &cli.Command{ Name: "disable", Usage: "Disable a task type", ArgsUsage: "[" + settableStr + "]", - Action: taskAction(api.Worker.TaskDisable), + Flags: []cli.Flag{ + &cli.BoolFlag{ + Name: "all", + Usage: "Disable all task types", + }, + }, + Action: taskAction(api.Worker.TaskDisable), } func taskAction(tf func(a api.Worker, ctx context.Context, tt sealtasks.TaskType) error) func(cctx *cli.Context) error { return func(cctx *cli.Context) error { - if cctx.NArg() != 1 { - return xerrors.Errorf("expected 1 argument") + allFlag := cctx.Bool("all") + + if cctx.NArg() == 1 && allFlag { + return xerrors.Errorf("Cannot use --all flag with task type argument") + } + + if cctx.NArg() == 0 && !allFlag { + return xerrors.Errorf("Expected 1 argument or use --all flag") } var tt sealtasks.TaskType - for taskType := range allowSetting { - if taskType.Short() == cctx.Args().First() { - tt = taskType - break + if cctx.NArg() == 1 { + for taskType := range allowSetting { + if taskType.Short() == cctx.Args().First() { + tt = taskType + break + } } - } - if tt == "" { - return xerrors.Errorf("unknown task type '%s'", cctx.Args().First()) + if tt == "" { + return xerrors.Errorf("unknown task type '%s'", cctx.Args().First()) + } } api, closer, err := lcli.GetWorkerAPI(cctx) @@ -81,6 +101,15 @@ func taskAction(tf func(a api.Worker, ctx context.Context, tt sealtasks.TaskType ctx := lcli.ReqContext(cctx) + if allFlag { + for taskType := range allowSetting { + if err := tf(api, ctx, taskType); err != nil { + return err + } + } + return nil + } + return tf(api, ctx, tt) } }