cmd/puppeth: accept ssh identity in the server string (#17407)
* cmd/puppeth: Accept identityfile in the server string with fallback to id_rsa * cmd/puppeth: code polishes + fix heath check double ports
This commit is contained in:
parent
1de9ada401
commit
7d38d53ae4
@ -45,33 +45,44 @@ type sshClient struct {
|
|||||||
|
|
||||||
// dial establishes an SSH connection to a remote node using the current user and
|
// dial establishes an SSH connection to a remote node using the current user and
|
||||||
// the user's configured private RSA key. If that fails, password authentication
|
// the user's configured private RSA key. If that fails, password authentication
|
||||||
// is fallen back to. The caller may override the login user via user@server:port.
|
// is fallen back to. server can be a string like user:identity@server:port.
|
||||||
func dial(server string, pubkey []byte) (*sshClient, error) {
|
func dial(server string, pubkey []byte) (*sshClient, error) {
|
||||||
// Figure out a label for the server and a logger
|
// Figure out username, identity, hostname and port
|
||||||
label := server
|
hostname := ""
|
||||||
if strings.Contains(label, ":") {
|
hostport := server
|
||||||
label = label[:strings.Index(label, ":")]
|
username := ""
|
||||||
}
|
identity := "id_rsa" // default
|
||||||
login := ""
|
|
||||||
if strings.Contains(server, "@") {
|
if strings.Contains(server, "@") {
|
||||||
login = label[:strings.Index(label, "@")]
|
prefix := server[:strings.Index(server, "@")]
|
||||||
label = label[strings.Index(label, "@")+1:]
|
if strings.Contains(prefix, ":") {
|
||||||
server = server[strings.Index(server, "@")+1:]
|
username = prefix[:strings.Index(prefix, ":")]
|
||||||
|
identity = prefix[strings.Index(prefix, ":")+1:]
|
||||||
|
} else {
|
||||||
|
username = prefix
|
||||||
|
}
|
||||||
|
hostport = server[strings.Index(server, "@")+1:]
|
||||||
}
|
}
|
||||||
logger := log.New("server", label)
|
if strings.Contains(hostport, ":") {
|
||||||
|
hostname = hostport[:strings.Index(hostport, ":")]
|
||||||
|
} else {
|
||||||
|
hostname = hostport
|
||||||
|
hostport += ":22"
|
||||||
|
}
|
||||||
|
logger := log.New("server", server)
|
||||||
logger.Debug("Attempting to establish SSH connection")
|
logger.Debug("Attempting to establish SSH connection")
|
||||||
|
|
||||||
user, err := user.Current()
|
user, err := user.Current()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if login == "" {
|
if username == "" {
|
||||||
login = user.Username
|
username = user.Username
|
||||||
}
|
}
|
||||||
// Configure the supported authentication methods (private key and password)
|
// Configure the supported authentication methods (private key and password)
|
||||||
var auths []ssh.AuthMethod
|
var auths []ssh.AuthMethod
|
||||||
|
|
||||||
path := filepath.Join(user.HomeDir, ".ssh", "id_rsa")
|
path := filepath.Join(user.HomeDir, ".ssh", identity)
|
||||||
if buf, err := ioutil.ReadFile(path); err != nil {
|
if buf, err := ioutil.ReadFile(path); err != nil {
|
||||||
log.Warn("No SSH key, falling back to passwords", "path", path, "err", err)
|
log.Warn("No SSH key, falling back to passwords", "path", path, "err", err)
|
||||||
} else {
|
} else {
|
||||||
@ -94,14 +105,14 @@ func dial(server string, pubkey []byte) (*sshClient, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
auths = append(auths, ssh.PasswordCallback(func() (string, error) {
|
auths = append(auths, ssh.PasswordCallback(func() (string, error) {
|
||||||
fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", login, server)
|
fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", username, server)
|
||||||
blob, err := terminal.ReadPassword(int(os.Stdin.Fd()))
|
blob, err := terminal.ReadPassword(int(os.Stdin.Fd()))
|
||||||
|
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
return string(blob), err
|
return string(blob), err
|
||||||
}))
|
}))
|
||||||
// Resolve the IP address of the remote server
|
// Resolve the IP address of the remote server
|
||||||
addr, err := net.LookupHost(label)
|
addr, err := net.LookupHost(hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -109,10 +120,7 @@ func dial(server string, pubkey []byte) (*sshClient, error) {
|
|||||||
return nil, errors.New("no IPs associated with domain")
|
return nil, errors.New("no IPs associated with domain")
|
||||||
}
|
}
|
||||||
// Try to dial in to the remote server
|
// Try to dial in to the remote server
|
||||||
logger.Trace("Dialing remote SSH server", "user", login)
|
logger.Trace("Dialing remote SSH server", "user", username)
|
||||||
if !strings.Contains(server, ":") {
|
|
||||||
server += ":22"
|
|
||||||
}
|
|
||||||
keycheck := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
keycheck := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||||
// If no public key is known for SSH, ask the user to confirm
|
// If no public key is known for SSH, ask the user to confirm
|
||||||
if pubkey == nil {
|
if pubkey == nil {
|
||||||
@ -139,13 +147,13 @@ func dial(server string, pubkey []byte) (*sshClient, error) {
|
|||||||
// We have a mismatch, forbid connecting
|
// We have a mismatch, forbid connecting
|
||||||
return errors.New("ssh key mismatch, readd the machine to update")
|
return errors.New("ssh key mismatch, readd the machine to update")
|
||||||
}
|
}
|
||||||
client, err := ssh.Dial("tcp", server, &ssh.ClientConfig{User: login, Auth: auths, HostKeyCallback: keycheck})
|
client, err := ssh.Dial("tcp", hostport, &ssh.ClientConfig{User: username, Auth: auths, HostKeyCallback: keycheck})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// Connection established, return our utility wrapper
|
// Connection established, return our utility wrapper
|
||||||
c := &sshClient{
|
c := &sshClient{
|
||||||
server: label,
|
server: hostname,
|
||||||
address: addr[0],
|
address: addr[0],
|
||||||
pubkey: pubkey,
|
pubkey: pubkey,
|
||||||
client: client,
|
client: client,
|
||||||
|
@ -62,14 +62,14 @@ func (w *wizard) manageServers() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeServer reads a single line from stdin and interprets it as a hostname to
|
// makeServer reads a single line from stdin and interprets it as
|
||||||
// connect to. It tries to establish a new SSH session and also executing some
|
// username:identity@hostname to connect to. It tries to establish a
|
||||||
// baseline validations.
|
// new SSH session and also executing some baseline validations.
|
||||||
//
|
//
|
||||||
// If connection succeeds, the server is added to the wizards configs!
|
// If connection succeeds, the server is added to the wizards configs!
|
||||||
func (w *wizard) makeServer() string {
|
func (w *wizard) makeServer() string {
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
fmt.Println("Please enter remote server's address:")
|
fmt.Println("What is the remote server's address ([username[:identity]@]hostname[:port])?")
|
||||||
|
|
||||||
// Read and dial the server to ensure docker is present
|
// Read and dial the server to ensure docker is present
|
||||||
input := w.readString()
|
input := w.readString()
|
||||||
|
Loading…
Reference in New Issue
Block a user