ipld-eth-server/vendor/github.com/pressly/sup/ssh.go
2018-09-17 13:25:19 -05:00

295 lines
6.4 KiB
Go

package sup
import (
"fmt"
"io"
"io/ioutil"
"net"
"os"
"os/user"
"path/filepath"
"strings"
"sync"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
// SSHClient is a wrapper over the SSH connection/sessions.
// It holds a single connection and at most one session at a time.
type SSHClient struct {
// conn is the underlying SSH connection, assigned by ConnectWith.
conn *ssh.Client
// sess is the session of the command started by Run.
sess *ssh.Session
// user is the remote login name; parseHost fills it in when it is empty.
user string
// host is the remote address; parseHost normalizes it to "host:port".
host string
// remoteStdin, remoteStdout and remoteStderr are the pipes of the session
// started by Run.
remoteStdin io.WriteCloser
remoteStdout io.Reader
remoteStderr io.Reader
// connOpened is set by a successful ConnectWith and cleared by Close.
connOpened bool
// sessOpened is set by a successful Run and cleared by Wait or Close.
sessOpened bool
// running is set by a successful Run and cleared by Wait or Close.
running bool
// env is prepended verbatim to the command passed to the session in Run.
env string //export FOO="bar"; export BAR="baz";
// color is prepended to the prefix string returned by Prefix.
color string
}
type ErrConnect struct {
User string
Host string
Reason string
}
func (e ErrConnect) Error() string {
return fmt.Sprintf(`Connect("%v@%v"): %v`, e.User, e.Host, e.Reason)
}
// parseHost splits a "[ssh://][user@]host[:port]" string into c.user and
// c.host, normalizing as it goes:
//
//   - a leading "ssh://" is dropped (only when something follows it),
//   - text before the first "@" becomes c.user,
//   - an empty c.user falls back to the current OS user,
//   - a host without ":" gets the default port ":22".
//
// It returns the error from user.Current if the current user cannot be
// determined, or an ErrConnect if the host part contains a slash. c.host is
// updated in place, so it holds the partially normalized value on failure.
func (c *SSHClient) parseHost(host string) error {
	const scheme = "ssh://"

	c.host = host
	if strings.HasPrefix(c.host, scheme) && len(c.host) > len(scheme) {
		c.host = c.host[len(scheme):]
	}

	// Everything before the first "@" is the login name.
	if i := strings.IndexByte(c.host, '@'); i >= 0 {
		c.user, c.host = c.host[:i], c.host[i+1:]
	}

	// No explicit user anywhere: use whoever is running this process.
	if c.user == "" {
		current, err := user.Current()
		if err != nil {
			return err
		}
		c.user = current.Username
	}

	if strings.Contains(c.host, "/") {
		return ErrConnect{
			User:   c.user,
			Host:   c.host,
			Reason: "unexpected slash in the host URL",
		}
	}

	// Any ":" is taken to mean a port was given; otherwise default to 22.
	if !strings.Contains(c.host, ":") {
		c.host += ":22"
	}
	return nil
}
var (
	// initAuthMethodOnce guards the one-time call to initAuthMethod made
	// from ConnectWith.
	initAuthMethodOnce sync.Once

	// authMethod is the auth method built by initAuthMethod and placed in the
	// client config of every connection.
	authMethod ssh.AuthMethod
)
// initAuthMethod builds the package-level public-key auth method.
//
// Signers are gathered from two sources, in this order: the SSH agent
// listening on $SSH_AUTH_SOCK (if it can be dialed), and every file matching
// $HOME/.ssh/id_* that is not a ".pub" file and parses as a private key.
// Any failure along the way is skipped silently; the result may therefore
// carry no signers at all.
func initAuthMethod() {
	var keys []ssh.Signer

	// If there's a running SSH agent, start with the keys it holds.
	// The agent connection is not closed here.
	// NOTE(review): presumably agent-backed signers need it later to sign — confirm.
	if agentConn, dialErr := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); dialErr == nil {
		keys, _ = agent.NewClient(agentConn).Signers()
	}

	// Then read the user's SSH private keys from the standard paths.
	keyPaths, _ := filepath.Glob(os.Getenv("HOME") + "/.ssh/id_*")
	for _, keyPath := range keyPaths {
		if strings.HasSuffix(keyPath, ".pub") {
			continue // Public half of a key pair; nothing to sign with.
		}

		raw, readErr := ioutil.ReadFile(keyPath)
		if readErr != nil {
			continue
		}

		// Files that do not parse as a private key are ignored.
		if key, parseErr := ssh.ParsePrivateKey(raw); parseErr == nil {
			keys = append(keys, key)
		}
	}

	authMethod = ssh.PublicKeys(keys...)
}
// SSHDialFunc dials the SSH server at addr over the given network using
// config and returns a client for it. ConnectWith takes one to establish its
// connection; ssh.Dial and (*SSHClient).DialThrough both have this signature.
type SSHDialFunc func(net, addr string, config *ssh.ClientConfig) (*ssh.Client, error)
// Connect opens an SSH connection to host by dialing it directly with
// ssh.Dial. The host string is handed to ConnectWith, which parses it; the
// accepted form is "[ssh://][user@]host[:port]".
func (c *SSHClient) Connect(host string) error {
	var direct SSHDialFunc = ssh.Dial
	return c.ConnectWith(host, direct)
}
// ConnectWith opens an SSH connection to host, using dialer to establish it.
//
// The package auth method is initialized on first use, host is normalized by
// parseHost, and the dialer is called with network "tcp". It returns an error
// if c is already connected or host cannot be parsed, and an ErrConnect
// carrying the dialer's message if dialing fails.
// TODO: Split Signers to its own method.
func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error {
	if c.connOpened {
		return fmt.Errorf("Already connected")
	}

	initAuthMethodOnce.Do(initAuthMethod)

	if parseErr := c.parseHost(host); parseErr != nil {
		return parseErr
	}

	// NOTE(review): no HostKeyCallback is set on this config. Newer
	// golang.org/x/crypto/ssh releases refuse to handshake without one —
	// confirm against the vendored version.
	clientConfig := &ssh.ClientConfig{
		User: c.user,
		Auth: []ssh.AuthMethod{authMethod},
	}

	client, dialErr := dialer("tcp", c.host, clientConfig)
	c.conn = client // stored whatever the outcome, as before
	if dialErr != nil {
		return ErrConnect{User: c.user, Host: c.host, Reason: dialErr.Error()}
	}

	c.connOpened = true
	return nil
}
// Run starts the task.Run command remotely on c.host in a new SSH session.
// It does not wait for the command to finish; use Wait for that.
//
// c.env is prepended verbatim to the command. When task.TTY is set, a pseudo
// terminal is requested before the command is started. On success the session
// and its stdin, stdout and stderr pipes are stored on c and c is marked as
// running.
//
// It returns an error if a session is already running or open, or if the
// session or one of its pipes cannot be created, and an ErrTask if the pty
// request or the command start fails. On any failure after the session has
// been created, that session is closed so it does not linger on the
// connection, and c's session and pipe fields are left untouched.
func (c *SSHClient) Run(task *Task) error {
	if c.running {
		return fmt.Errorf("Session already running")
	}
	if c.sessOpened {
		return fmt.Errorf("Session already connected")
	}

	sess, err := c.conn.NewSession()
	if err != nil {
		return err
	}

	// Until the command has actually started, every early return must release
	// the session; otherwise it stays open for as long as the connection does.
	started := false
	defer func() {
		if !started {
			// Best-effort cleanup: the error that caused the failure is the
			// one worth reporting, not a secondary close error.
			_ = sess.Close()
		}
	}()

	stdin, err := sess.StdinPipe()
	if err != nil {
		return err
	}

	stdout, err := sess.StdoutPipe()
	if err != nil {
		return err
	}

	stderr, err := sess.StderrPipe()
	if err != nil {
		return err
	}

	if task.TTY {
		// Set up terminal modes
		modes := ssh.TerminalModes{
			ssh.ECHO:          0,     // disable echoing
			ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
			ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
		}
		// Request pseudo terminal
		if err := sess.RequestPty("xterm", 80, 40, modes); err != nil {
			return ErrTask{task, fmt.Sprintf("request for pseudo terminal failed: %s", err)}
		}
	}

	// Start the remote command.
	if err := sess.Start(c.env + task.Run); err != nil {
		return ErrTask{task, err.Error()}
	}

	// Publish the session only once it is fully usable.
	started = true
	c.sess = sess
	c.remoteStdin = stdin
	c.remoteStdout = stdout
	c.remoteStderr = stderr
	c.sessOpened = true
	c.running = true
	return nil
}
// Wait blocks until the remote command started by Run exits, then closes the
// SSH session and clears the running/open flags. It returns the result of
// waiting on the session, or an error if no command is running. An error from
// closing the session is not reported.
func (c *SSHClient) Wait() error {
	if !c.running {
		return fmt.Errorf("Trying to wait on stopped session")
	}

	waitErr := c.sess.Wait()

	// The session is finished either way; release it and reset the state.
	c.sess.Close()
	c.running, c.sessOpened = false, false

	return waitErr
}
// DialThrough connects to addr from the SSH server c is connected to and
// performs an SSH client handshake over that connection using config.
// Its signature matches SSHDialFunc, so it can be passed to ConnectWith.
func (c *SSHClient) DialThrough(net, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
	// Open the transport on the remote side of the existing connection.
	tunnel, err := c.conn.Dial(net, addr)
	if err != nil {
		return nil, err
	}

	// Layer a new SSH client connection on top of it.
	clientConn, newChans, globalReqs, err := ssh.NewClientConn(tunnel, addr, config)
	if err != nil {
		return nil, err
	}

	return ssh.NewClient(clientConn, newChans, globalReqs), nil
}
// Close closes the open session, if any, and then the underlying SSH
// connection. It returns the connection's close error, or an error if the
// connection was not open. An open session is closed even in that case, and
// an error from closing the session is not reported.
func (c *SSHClient) Close() error {
	if c.sessOpened {
		c.sess.Close()
		c.sessOpened = false
	}

	if c.connOpened {
		closeErr := c.conn.Close()
		c.connOpened, c.running = false, false
		return closeErr
	}

	return fmt.Errorf("Trying to close the already closed connection")
}
// Stdin returns the writer connected to the remote command's standard input.
// The pipe is assigned by Run.
func (c *SSHClient) Stdin() io.WriteCloser {
return c.remoteStdin
}
// Stderr returns the reader connected to the remote command's standard error.
// The pipe is assigned by Run.
func (c *SSHClient) Stderr() io.Reader {
return c.remoteStderr
}
// Stdout returns the reader connected to the remote command's standard output.
// The pipe is assigned by Run.
func (c *SSHClient) Stdout() io.Reader {
return c.remoteStdout
}
// Prefix returns the "user@host | " label for this client wrapped in the
// client's color and ResetColor, together with the length in bytes of the
// uncolored label.
func (c *SSHClient) Prefix() (string, int) {
	plain := c.user + "@" + c.host + " | "
	colored := c.color + plain + ResetColor
	return colored, len(plain)
}
// Write writes p to the remote command's standard input and returns the
// result of that write.
func (c *SSHClient) Write(p []byte) (n int, err error) {
return c.remoteStdin.Write(p)
}
// WriteClose closes the remote command's standard input.
func (c *SSHClient) WriteClose() error {
return c.remoteStdin.Close()
}
// Signal forwards sig to the remote command. Only os.Interrupt is supported;
// any other signal yields an error, as does calling Signal when no session is
// open.
func (c *SSHClient) Signal(sig os.Signal) error {
	if !c.sessOpened {
		return fmt.Errorf("session is not open")
	}

	if sig != os.Interrupt {
		return fmt.Errorf("%v not supported", sig)
	}

	// TODO: Sending the signal through the session alone was reported by the
	// original author not to work; writing \x03 (Ctrl-C) to the remote stdin
	// did. That sounds like something to be fixed/resolved upstream in the
	// golang.org/x/crypto/ssh pkg.
	// https://github.com/golang/go/issues/4115#issuecomment-66070418
	// The result of this write is deliberately not checked.
	c.remoteStdin.Write([]byte("\x03"))

	return c.sess.Signal(ssh.SIGINT)
}