cmd/puppeth: accept ssh identity in the server string (#17407)

* cmd/puppeth: Accept identityfile in the server string with fallback to id_rsa

* cmd/puppeth: code polishes + fix heath check double ports
This commit is contained in:
Nilesh Trivedi 2018-08-20 19:24:38 +05:30 committed by Péter Szilágyi
parent 1de9ada401
commit 7d38d53ae4
2 changed files with 34 additions and 26 deletions

View File

@ -45,33 +45,44 @@ type sshClient struct {
// dial establishes an SSH connection to a remote node using the current user and // dial establishes an SSH connection to a remote node using the current user and
// the user's configured private RSA key. If that fails, password authentication // the user's configured private RSA key. If that fails, password authentication
// is fallen back to. The caller may override the login user via user@server:port. // is fallen back to. server can be a string like user:identity@server:port.
func dial(server string, pubkey []byte) (*sshClient, error) { func dial(server string, pubkey []byte) (*sshClient, error) {
// Figure out a label for the server and a logger // Figure out username, identity, hostname and port
label := server hostname := ""
if strings.Contains(label, ":") { hostport := server
label = label[:strings.Index(label, ":")] username := ""
} identity := "id_rsa" // default
login := ""
if strings.Contains(server, "@") { if strings.Contains(server, "@") {
login = label[:strings.Index(label, "@")] prefix := server[:strings.Index(server, "@")]
label = label[strings.Index(label, "@")+1:] if strings.Contains(prefix, ":") {
server = server[strings.Index(server, "@")+1:] username = prefix[:strings.Index(prefix, ":")]
identity = prefix[strings.Index(prefix, ":")+1:]
} else {
username = prefix
}
hostport = server[strings.Index(server, "@")+1:]
} }
logger := log.New("server", label) if strings.Contains(hostport, ":") {
hostname = hostport[:strings.Index(hostport, ":")]
} else {
hostname = hostport
hostport += ":22"
}
logger := log.New("server", server)
logger.Debug("Attempting to establish SSH connection") logger.Debug("Attempting to establish SSH connection")
user, err := user.Current() user, err := user.Current()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if login == "" { if username == "" {
login = user.Username username = user.Username
} }
// Configure the supported authentication methods (private key and password) // Configure the supported authentication methods (private key and password)
var auths []ssh.AuthMethod var auths []ssh.AuthMethod
path := filepath.Join(user.HomeDir, ".ssh", "id_rsa") path := filepath.Join(user.HomeDir, ".ssh", identity)
if buf, err := ioutil.ReadFile(path); err != nil { if buf, err := ioutil.ReadFile(path); err != nil {
log.Warn("No SSH key, falling back to passwords", "path", path, "err", err) log.Warn("No SSH key, falling back to passwords", "path", path, "err", err)
} else { } else {
@ -94,14 +105,14 @@ func dial(server string, pubkey []byte) (*sshClient, error) {
} }
} }
auths = append(auths, ssh.PasswordCallback(func() (string, error) { auths = append(auths, ssh.PasswordCallback(func() (string, error) {
fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", login, server) fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", username, server)
blob, err := terminal.ReadPassword(int(os.Stdin.Fd())) blob, err := terminal.ReadPassword(int(os.Stdin.Fd()))
fmt.Println() fmt.Println()
return string(blob), err return string(blob), err
})) }))
// Resolve the IP address of the remote server // Resolve the IP address of the remote server
addr, err := net.LookupHost(label) addr, err := net.LookupHost(hostname)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -109,10 +120,7 @@ func dial(server string, pubkey []byte) (*sshClient, error) {
return nil, errors.New("no IPs associated with domain") return nil, errors.New("no IPs associated with domain")
} }
// Try to dial in to the remote server // Try to dial in to the remote server
logger.Trace("Dialing remote SSH server", "user", login) logger.Trace("Dialing remote SSH server", "user", username)
if !strings.Contains(server, ":") {
server += ":22"
}
keycheck := func(hostname string, remote net.Addr, key ssh.PublicKey) error { keycheck := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
// If no public key is known for SSH, ask the user to confirm // If no public key is known for SSH, ask the user to confirm
if pubkey == nil { if pubkey == nil {
@ -139,13 +147,13 @@ func dial(server string, pubkey []byte) (*sshClient, error) {
// We have a mismatch, forbid connecting // We have a mismatch, forbid connecting
return errors.New("ssh key mismatch, readd the machine to update") return errors.New("ssh key mismatch, readd the machine to update")
} }
client, err := ssh.Dial("tcp", server, &ssh.ClientConfig{User: login, Auth: auths, HostKeyCallback: keycheck}) client, err := ssh.Dial("tcp", hostport, &ssh.ClientConfig{User: username, Auth: auths, HostKeyCallback: keycheck})
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Connection established, return our utility wrapper // Connection established, return our utility wrapper
c := &sshClient{ c := &sshClient{
server: label, server: hostname,
address: addr[0], address: addr[0],
pubkey: pubkey, pubkey: pubkey,
client: client, client: client,

View File

@ -62,14 +62,14 @@ func (w *wizard) manageServers() {
} }
} }
// makeServer reads a single line from stdin and interprets it as a hostname to // makeServer reads a single line from stdin and interprets it as
// connect to. It tries to establish a new SSH session and also executing some // username:identity@hostname to connect to. It tries to establish a
// baseline validations. // new SSH session and also executing some baseline validations.
// //
// If connection succeeds, the server is added to the wizards configs! // If connection succeeds, the server is added to the wizards configs!
func (w *wizard) makeServer() string { func (w *wizard) makeServer() string {
fmt.Println() fmt.Println()
fmt.Println("Please enter remote server's address:") fmt.Println("What is the remote server's address ([username[:identity]@]hostname[:port])?")
// Read and dial the server to ensure docker is present // Read and dial the server to ensure docker is present
input := w.readString() input := w.readString()