cmd/puppeth: use geth's prompt to read input (#23718)

* cmd/puppeth: use geth's prompt to read input

* remove wizard.in

* cmd/puppeth: fix compilation errors

* reset prompt (don't exit) on receiving ctrl-c

* make promptInput spin until the user enters a value or interrupts (ctrl-d)

* make promptInput use parameter

Co-authored-by: Martin Holst Swende <martin@swende.se>
This commit is contained in:
jwasinger 2021-10-18 20:59:01 +02:00 committed by GitHub
parent c36f8fefc3
commit 60d3cc8b77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 90 deletions

View File

@ -17,7 +17,6 @@
package main package main
import ( import (
"bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -32,8 +31,10 @@ import (
"sync" "sync"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/console/prompt"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/peterh/liner"
"golang.org/x/crypto/ssh/terminal" "golang.org/x/crypto/ssh/terminal"
) )
@ -76,17 +77,27 @@ type wizard struct {
servers map[string]*sshClient // SSH connections to servers to administer servers map[string]*sshClient // SSH connections to servers to administer
services map[string][]string // Ethereum services known to be running on servers services map[string][]string // Ethereum services known to be running on servers
in *bufio.Reader // Wrapper around stdin to allow reading user input lock sync.Mutex // Lock to protect configs during concurrent service discovery
lock sync.Mutex // Lock to protect configs during concurrent service discovery }
// prompts the user for input with the given prompt string. Returns when a value is entered.
// Causes the wizard to exit if ctrl-d is pressed
func promptInput(p string) string {
for {
text, err := prompt.Stdin.PromptInput(p)
if err != nil {
if err != liner.ErrPromptAborted {
log.Crit("Failed to read user input", "err", err)
}
} else {
return text
}
}
} }
// read reads a single line from stdin, trimming if from spaces. // read reads a single line from stdin, trimming if from spaces.
func (w *wizard) read() string { func (w *wizard) read() string {
fmt.Printf("> ") text := promptInput("> ")
text, err := w.in.ReadString('\n')
if err != nil {
log.Crit("Failed to read user input", "err", err)
}
return strings.TrimSpace(text) return strings.TrimSpace(text)
} }
@ -94,11 +105,7 @@ func (w *wizard) read() string {
// non-emptyness. // non-emptyness.
func (w *wizard) readString() string { func (w *wizard) readString() string {
for { for {
fmt.Printf("> ") text := promptInput("> ")
text, err := w.in.ReadString('\n')
if err != nil {
log.Crit("Failed to read user input", "err", err)
}
if text = strings.TrimSpace(text); text != "" { if text = strings.TrimSpace(text); text != "" {
return text return text
} }
@ -108,11 +115,7 @@ func (w *wizard) readString() string {
// readDefaultString reads a single line from stdin, trimming if from spaces. If // readDefaultString reads a single line from stdin, trimming if from spaces. If
// an empty line is entered, the default value is returned. // an empty line is entered, the default value is returned.
func (w *wizard) readDefaultString(def string) string { func (w *wizard) readDefaultString(def string) string {
fmt.Printf("> ") text := promptInput("> ")
text, err := w.in.ReadString('\n')
if err != nil {
log.Crit("Failed to read user input", "err", err)
}
if text = strings.TrimSpace(text); text != "" { if text = strings.TrimSpace(text); text != "" {
return text return text
} }
@ -124,11 +127,7 @@ func (w *wizard) readDefaultString(def string) string {
// value is returned. // value is returned.
func (w *wizard) readDefaultYesNo(def bool) bool { func (w *wizard) readDefaultYesNo(def bool) bool {
for { for {
fmt.Printf("> ") text := promptInput("> ")
text, err := w.in.ReadString('\n')
if err != nil {
log.Crit("Failed to read user input", "err", err)
}
if text = strings.ToLower(strings.TrimSpace(text)); text == "" { if text = strings.ToLower(strings.TrimSpace(text)); text == "" {
return def return def
} }
@ -146,11 +145,7 @@ func (w *wizard) readDefaultYesNo(def bool) bool {
// interpret it as a URL (http, https or file). // interpret it as a URL (http, https or file).
func (w *wizard) readURL() *url.URL { func (w *wizard) readURL() *url.URL {
for { for {
fmt.Printf("> ") text := promptInput("> ")
text, err := w.in.ReadString('\n')
if err != nil {
log.Crit("Failed to read user input", "err", err)
}
uri, err := url.Parse(strings.TrimSpace(text)) uri, err := url.Parse(strings.TrimSpace(text))
if err != nil { if err != nil {
log.Error("Invalid input, expected URL", "err", err) log.Error("Invalid input, expected URL", "err", err)
@ -164,11 +159,7 @@ func (w *wizard) readURL() *url.URL {
// to parse into an integer. // to parse into an integer.
func (w *wizard) readInt() int { func (w *wizard) readInt() int {
for { for {
fmt.Printf("> ") text := promptInput("> ")
text, err := w.in.ReadString('\n')
if err != nil {
log.Crit("Failed to read user input", "err", err)
}
if text = strings.TrimSpace(text); text == "" { if text = strings.TrimSpace(text); text == "" {
continue continue
} }
@ -186,11 +177,7 @@ func (w *wizard) readInt() int {
// returned. // returned.
func (w *wizard) readDefaultInt(def int) int { func (w *wizard) readDefaultInt(def int) int {
for { for {
fmt.Printf("> ") text := promptInput("> ")
text, err := w.in.ReadString('\n')
if err != nil {
log.Crit("Failed to read user input", "err", err)
}
if text = strings.TrimSpace(text); text == "" { if text = strings.TrimSpace(text); text == "" {
return def return def
} }
@ -208,11 +195,7 @@ func (w *wizard) readDefaultInt(def int) int {
// default value is returned. // default value is returned.
func (w *wizard) readDefaultBigInt(def *big.Int) *big.Int { func (w *wizard) readDefaultBigInt(def *big.Int) *big.Int {
for { for {
fmt.Printf("> ") text := promptInput("> ")
text, err := w.in.ReadString('\n')
if err != nil {
log.Crit("Failed to read user input", "err", err)
}
if text = strings.TrimSpace(text); text == "" { if text = strings.TrimSpace(text); text == "" {
return def return def
} }
@ -225,38 +208,11 @@ func (w *wizard) readDefaultBigInt(def *big.Int) *big.Int {
} }
} }
/*
// readFloat reads a single line from stdin, trimming if from spaces, enforcing it
// to parse into a float.
func (w *wizard) readFloat() float64 {
for {
fmt.Printf("> ")
text, err := w.in.ReadString('\n')
if err != nil {
log.Crit("Failed to read user input", "err", err)
}
if text = strings.TrimSpace(text); text == "" {
continue
}
val, err := strconv.ParseFloat(strings.TrimSpace(text), 64)
if err != nil {
log.Error("Invalid input, expected float", "err", err)
continue
}
return val
}
}
*/
// readDefaultFloat reads a single line from stdin, trimming if from spaces, enforcing // readDefaultFloat reads a single line from stdin, trimming if from spaces, enforcing
// it to parse into a float. If an empty line is entered, the default value is returned. // it to parse into a float. If an empty line is entered, the default value is returned.
func (w *wizard) readDefaultFloat(def float64) float64 { func (w *wizard) readDefaultFloat(def float64) float64 {
for { for {
fmt.Printf("> ") text := promptInput("> ")
text, err := w.in.ReadString('\n')
if err != nil {
log.Crit("Failed to read user input", "err", err)
}
if text = strings.TrimSpace(text); text == "" { if text = strings.TrimSpace(text); text == "" {
return def return def
} }
@ -285,12 +241,7 @@ func (w *wizard) readPassword() string {
// it to an Ethereum address. // it to an Ethereum address.
func (w *wizard) readAddress() *common.Address { func (w *wizard) readAddress() *common.Address {
for { for {
// Read the address from the user text := promptInput("> 0x")
fmt.Printf("> 0x")
text, err := w.in.ReadString('\n')
if err != nil {
log.Crit("Failed to read user input", "err", err)
}
if text = strings.TrimSpace(text); text == "" { if text = strings.TrimSpace(text); text == "" {
return nil return nil
} }
@ -311,11 +262,7 @@ func (w *wizard) readAddress() *common.Address {
func (w *wizard) readDefaultAddress(def common.Address) common.Address { func (w *wizard) readDefaultAddress(def common.Address) common.Address {
for { for {
// Read the address from the user // Read the address from the user
fmt.Printf("> 0x") text := promptInput("> 0x")
text, err := w.in.ReadString('\n')
if err != nil {
log.Crit("Failed to read user input", "err", err)
}
if text = strings.TrimSpace(text); text == "" { if text = strings.TrimSpace(text); text == "" {
return def return def
} }
@ -334,8 +281,9 @@ func (w *wizard) readJSON() string {
var blob json.RawMessage var blob json.RawMessage
for { for {
fmt.Printf("> ") text := promptInput("> ")
if err := json.NewDecoder(w.in).Decode(&blob); err != nil { reader := strings.NewReader(text)
if err := json.NewDecoder(reader).Decode(&blob); err != nil {
log.Error("Invalid JSON, please try again", "err", err) log.Error("Invalid JSON, please try again", "err", err)
continue continue
} }
@ -351,10 +299,7 @@ func (w *wizard) readIPAddress() string {
for { for {
// Read the IP address from the user // Read the IP address from the user
fmt.Printf("> ") fmt.Printf("> ")
text, err := w.in.ReadString('\n') text := promptInput("> ")
if err != nil {
log.Crit("Failed to read user input", "err", err)
}
if text = strings.TrimSpace(text); text == "" { if text = strings.TrimSpace(text); text == "" {
return "" return ""
} }

View File

@ -17,7 +17,6 @@
package main package main
import ( import (
"bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -38,7 +37,6 @@ func makeWizard(network string) *wizard {
}, },
servers: make(map[string]*sshClient), servers: make(map[string]*sshClient),
services: make(map[string][]string), services: make(map[string][]string),
in: bufio.NewReader(os.Stdin),
} }
} }