diff --git a/cmd/puppeth/wizard.go b/cmd/puppeth/wizard.go index 83536506c..c0edc5401 100644 --- a/cmd/puppeth/wizard.go +++ b/cmd/puppeth/wizard.go @@ -17,7 +17,6 @@ package main import ( - "bufio" "encoding/json" "fmt" "io/ioutil" @@ -32,8 +31,10 @@ import ( "sync" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/console/prompt" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/log" + "github.com/peterh/liner" "golang.org/x/crypto/ssh/terminal" ) @@ -76,17 +77,27 @@ type wizard struct { servers map[string]*sshClient // SSH connections to servers to administer services map[string][]string // Ethereum services known to be running on servers - in *bufio.Reader // Wrapper around stdin to allow reading user input - lock sync.Mutex // Lock to protect configs during concurrent service discovery + lock sync.Mutex // Lock to protect configs during concurrent service discovery +} + +// prompts the user for input with the given prompt string. Returns when a value is entered. +// Causes the wizard to exit if ctrl-d is pressed +func promptInput(p string) string { + for { + text, err := prompt.Stdin.PromptInput(p) + if err != nil { + if err != liner.ErrPromptAborted { + log.Crit("Failed to read user input", "err", err) + } + } else { + return text + } + } } // read reads a single line from stdin, trimming if from spaces. func (w *wizard) read() string { - fmt.Printf("> ") - text, err := w.in.ReadString('\n') - if err != nil { - log.Crit("Failed to read user input", "err", err) - } + text := promptInput("> ") return strings.TrimSpace(text) } @@ -94,11 +105,7 @@ func (w *wizard) read() string { // non-emptyness. func (w *wizard) readString() string { for { - fmt.Printf("> ") - text, err := w.in.ReadString('\n') - if err != nil { - log.Crit("Failed to read user input", "err", err) - } + text := promptInput("> ") if text = strings.TrimSpace(text); text != "" { return text } @@ -108,11 +115,7 @@ func (w *wizard) readString() string { // readDefaultString reads a single line from stdin, trimming if from spaces. If // an empty line is entered, the default value is returned. func (w *wizard) readDefaultString(def string) string { - fmt.Printf("> ") - text, err := w.in.ReadString('\n') - if err != nil { - log.Crit("Failed to read user input", "err", err) - } + text := promptInput("> ") if text = strings.TrimSpace(text); text != "" { return text } @@ -124,11 +127,7 @@ func (w *wizard) readDefaultString(def string) string { // value is returned. func (w *wizard) readDefaultYesNo(def bool) bool { for { - fmt.Printf("> ") - text, err := w.in.ReadString('\n') - if err != nil { - log.Crit("Failed to read user input", "err", err) - } + text := promptInput("> ") if text = strings.ToLower(strings.TrimSpace(text)); text == "" { return def } @@ -146,11 +145,7 @@ func (w *wizard) readDefaultYesNo(def bool) bool { // interpret it as a URL (http, https or file). func (w *wizard) readURL() *url.URL { for { - fmt.Printf("> ") - text, err := w.in.ReadString('\n') - if err != nil { - log.Crit("Failed to read user input", "err", err) - } + text := promptInput("> ") uri, err := url.Parse(strings.TrimSpace(text)) if err != nil { log.Error("Invalid input, expected URL", "err", err) @@ -164,11 +159,7 @@ func (w *wizard) readURL() *url.URL { // to parse into an integer. func (w *wizard) readInt() int { for { - fmt.Printf("> ") - text, err := w.in.ReadString('\n') - if err != nil { - log.Crit("Failed to read user input", "err", err) - } + text := promptInput("> ") if text = strings.TrimSpace(text); text == "" { continue } @@ -186,11 +177,7 @@ func (w *wizard) readInt() int { // returned. func (w *wizard) readDefaultInt(def int) int { for { - fmt.Printf("> ") - text, err := w.in.ReadString('\n') - if err != nil { - log.Crit("Failed to read user input", "err", err) - } + text := promptInput("> ") if text = strings.TrimSpace(text); text == "" { return def } @@ -208,11 +195,7 @@ func (w *wizard) readDefaultInt(def int) int { // default value is returned. func (w *wizard) readDefaultBigInt(def *big.Int) *big.Int { for { - fmt.Printf("> ") - text, err := w.in.ReadString('\n') - if err != nil { - log.Crit("Failed to read user input", "err", err) - } + text := promptInput("> ") if text = strings.TrimSpace(text); text == "" { return def } @@ -225,38 +208,11 @@ func (w *wizard) readDefaultBigInt(def *big.Int) *big.Int { } } -/* -// readFloat reads a single line from stdin, trimming if from spaces, enforcing it -// to parse into a float. -func (w *wizard) readFloat() float64 { - for { - fmt.Printf("> ") - text, err := w.in.ReadString('\n') - if err != nil { - log.Crit("Failed to read user input", "err", err) - } - if text = strings.TrimSpace(text); text == "" { - continue - } - val, err := strconv.ParseFloat(strings.TrimSpace(text), 64) - if err != nil { - log.Error("Invalid input, expected float", "err", err) - continue - } - return val - } -} -*/ - // readDefaultFloat reads a single line from stdin, trimming if from spaces, enforcing // it to parse into a float. If an empty line is entered, the default value is returned. func (w *wizard) readDefaultFloat(def float64) float64 { for { - fmt.Printf("> ") - text, err := w.in.ReadString('\n') - if err != nil { - log.Crit("Failed to read user input", "err", err) - } + text := promptInput("> ") if text = strings.TrimSpace(text); text == "" { return def } @@ -285,12 +241,7 @@ func (w *wizard) readPassword() string { // it to an Ethereum address. func (w *wizard) readAddress() *common.Address { for { - // Read the address from the user - fmt.Printf("> 0x") - text, err := w.in.ReadString('\n') - if err != nil { - log.Crit("Failed to read user input", "err", err) - } + text := promptInput("> 0x") if text = strings.TrimSpace(text); text == "" { return nil } @@ -311,11 +262,7 @@ func (w *wizard) readAddress() *common.Address { func (w *wizard) readDefaultAddress(def common.Address) common.Address { for { // Read the address from the user - fmt.Printf("> 0x") - text, err := w.in.ReadString('\n') - if err != nil { - log.Crit("Failed to read user input", "err", err) - } + text := promptInput("> 0x") if text = strings.TrimSpace(text); text == "" { return def } @@ -334,8 +281,9 @@ func (w *wizard) readJSON() string { var blob json.RawMessage for { - fmt.Printf("> ") - if err := json.NewDecoder(w.in).Decode(&blob); err != nil { + text := promptInput("> ") + reader := strings.NewReader(text) + if err := json.NewDecoder(reader).Decode(&blob); err != nil { log.Error("Invalid JSON, please try again", "err", err) continue } @@ -351,10 +299,7 @@ func (w *wizard) readIPAddress() string { for { // Read the IP address from the user fmt.Printf("> ") - text, err := w.in.ReadString('\n') - if err != nil { - log.Crit("Failed to read user input", "err", err) - } + text := promptInput("> ") if text = strings.TrimSpace(text); text == "" { return "" } diff --git a/cmd/puppeth/wizard_intro.go b/cmd/puppeth/wizard_intro.go index 75fb04b76..8610b908d 100644 --- a/cmd/puppeth/wizard_intro.go +++ b/cmd/puppeth/wizard_intro.go @@ -17,7 +17,6 @@ package main import ( - "bufio" "encoding/json" "fmt" "io/ioutil" @@ -38,7 +37,6 @@ func makeWizard(network string) *wizard { }, servers: make(map[string]*sshClient), services: make(map[string][]string), - in: bufio.NewReader(os.Stdin), } }