// Copyright 2017 The go-ethereum Authors
// This file is part of go-ethereum.
//
// go-ethereum is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// go-ethereum is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.

package main

import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"io/ioutil"
	"net"
	"os"
	"os/user"
	"path/filepath"
	"strings"

	"github.com/ethereum/go-ethereum/log"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
	"golang.org/x/crypto/ssh/terminal"
)

// sshClient is a small wrapper around Go's SSH client with a few utility methods
// implemented on top. Besides the live connection it remembers which host it is
// talking to, the host key that was accepted for it, and a logger created with
// the server string as context.
type sshClient struct {
	server  string // Server name or IP without port number
	address string // First IP address the server name resolved to
	pubkey  []byte // Marshalled public host key used to authenticate the server
	client  *ssh.Client
	logger  log.Logger
}

// EnvSSHAuthSock is the name of the environment variable whose value is used as
// the unix socket path when dialing the local SSH agent.
const EnvSSHAuthSock = "SSH_AUTH_SOCK"

// dial establishes an SSH connection to a remote node and verifies that the node
// can act as a puppeth target.
//
// server may take the form [user[:identity]@]host[:port]. The user defaults to
// the current OS user, identity (the name of a key file inside ~/.ssh) defaults
// to "id_rsa" and the port defaults to 22.
//
// Authentication is attempted through the SSH agent reachable via the socket
// named by SSH_AUTH_SOCK. Only if that agent cannot be dialed are the private
// key file and interactive password authentication offered instead.
//
// pubkey is the marshalled host key the server is expected to present. If it is
// nil, the user is asked on the terminal to confirm the key offered by the server
// and the accepted key is stored in the returned client. A non-nil pubkey that
// does not match the offered key aborts the connection.
func dial(server string, pubkey []byte) (*sshClient, error) {
	// Figure out username, identity, hostname and port
	var (
		username string
		identity = "id_rsa" // default
		hostport = server
	)
	if at := strings.Index(server, "@"); at >= 0 {
		username = server[:at]
		if colon := strings.Index(username, ":"); colon >= 0 {
			username, identity = username[:colon], username[colon+1:]
		}
		hostport = server[at+1:]
	}
	hostname := hostport
	if colon := strings.Index(hostport, ":"); colon >= 0 {
		hostname = hostport[:colon]
	} else {
		hostport += ":22"
	}
	logger := log.New("server", server)
	logger.Debug("Attempting to establish SSH connection")

	// Named "current" so that the os/user package is not shadowed
	current, err := user.Current()
	if err != nil {
		return nil, err
	}
	if username == "" {
		username = current.Username
	}

	// Configure the supported authentication methods (ssh agent, private key and password)
	var auths []ssh.AuthMethod

	agentConn, err := net.Dial("unix", os.Getenv(EnvSSHAuthSock))
	if err == nil {
		// The agent is only consulted while ssh.Dial authenticates below, so the
		// socket can be released when dial returns instead of being leaked.
		defer agentConn.Close()
		auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(agentConn).Signers))
	} else {
		log.Warn("Unable to dial SSH agent, falling back to private keys", "err", err)

		path := filepath.Join(current.HomeDir, ".ssh", identity)
		if buf, err := ioutil.ReadFile(path); err != nil {
			log.Warn("No SSH key, falling back to passwords", "path", path, "err", err)
		} else {
			key, err := ssh.ParsePrivateKey(buf)
			if err != nil {
				// Plain parsing failed, retry treating the key as passphrase protected
				fmt.Printf("What's the decryption password for %s? (won't be echoed)\n>", path)
				blob, readErr := terminal.ReadPassword(int(os.Stdin.Fd()))
				fmt.Println()
				if readErr != nil {
					log.Warn("Couldn't read password", "err", readErr)
				}
				if key, err = ssh.ParsePrivateKeyWithPassphrase(buf, blob); err != nil {
					log.Warn("Failed to decrypt SSH key, falling back to passwords", "path", path, "err", err)
				}
			}
			if err == nil {
				auths = append(auths, ssh.PublicKeys(key))
			}
		}
		auths = append(auths, ssh.PasswordCallback(func() (string, error) {
			fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", username, server)
			blob, err := terminal.ReadPassword(int(os.Stdin.Fd()))

			fmt.Println()
			return string(blob), err
		}))
	}
	// Resolve the IP address of the remote server
	addr, err := net.LookupHost(hostname)
	if err != nil {
		return nil, err
	}
	if len(addr) == 0 {
		return nil, errors.New("no IPs associated with domain")
	}
	// Try to dial in to the remote server
	logger.Trace("Dialing remote SSH server", "user", username)
	keycheck := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
		// If no public key is known for SSH, ask the user to confirm
		if pubkey == nil {
			fmt.Println()
			fmt.Printf("The authenticity of host '%s (%s)' can't be established.\n", hostname, remote)
			fmt.Printf("SSH key fingerprint is %s [MD5]\n", ssh.FingerprintLegacyMD5(key))
			fmt.Printf("Are you sure you want to continue connecting (yes/no)? ")

			// Use one buffered reader for the whole prompt: creating a fresh one
			// per attempt would discard input already buffered past the first line.
			stdin := bufio.NewReader(os.Stdin)
			for {
				text, err := stdin.ReadString('\n')
				if err != nil {
					return err
				}
				switch strings.TrimSpace(text) {
				case "yes":
					// Remember the accepted key so it ends up in the returned client
					pubkey = key.Marshal()
					return nil
				case "no":
					return errors.New("users says no")
				default:
					fmt.Println("Please answer 'yes' or 'no'")
				}
			}
		}
		// If a public key exists for this SSH server, check that it matches
		if bytes.Equal(pubkey, key.Marshal()) {
			return nil
		}
		// We have a mismatch, forbid connecting
		return errors.New("ssh key mismatch, readd the machine to update")
	}
	client, err := ssh.Dial("tcp", hostport, &ssh.ClientConfig{User: username, Auth: auths, HostKeyCallback: keycheck})
	if err != nil {
		return nil, err
	}
	// Connection established, return our utility wrapper
	c := &sshClient{
		server:  hostname,
		address: addr[0],
		pubkey:  pubkey,
		client:  client,
		logger:  logger,
	}
	if err := c.init(); err != nil {
		client.Close()
		return nil, err
	}
	return c, nil
}

// init probes the remote server for the tooling puppeth relies on. It runs the
// version command of docker and then of docker-compose, stopping at the first
// one that fails. If the failing command produced output, that output is folded
// into the returned error; otherwise the raw execution error is returned.
func (client *sshClient) init() error {
	probes := []struct {
		notice  string // debug message logged before the probe
		command string // command executed on the remote server
		failure string // error format used when the command printed something
	}{
		{"Verifying if docker is available", "docker version", "docker configured incorrectly: %s"},
		{"Verifying if docker-compose is available", "docker-compose version", "docker-compose configured incorrectly: %s"},
	}
	for _, probe := range probes {
		client.logger.Debug(probe.notice)

		out, err := client.Run(probe.command)
		if err == nil {
			continue
		}
		if len(out) == 0 {
			return err
		}
		return fmt.Errorf(probe.failure, out)
	}
	return nil
}

// Close shuts down the wrapped SSH client and returns whatever error that
// client reports while closing.
func (client *sshClient) Close() error {
	conn := client.client
	return conn.Close()
}

// Run executes cmd on the remote server inside a dedicated session and returns
// the command's combined stdout and stderr together with its error status. If
// the session cannot be opened, no output is returned.
func (client *sshClient) Run(cmd string) ([]byte, error) {
	// Every command gets its own session, torn down when we return
	sess, err := client.client.NewSession()
	if err != nil {
		return nil, err
	}
	defer sess.Close()

	client.logger.Trace("Running command on remote server", "cmd", cmd)

	// Hand back whatever the command printed along with its outcome
	output, err := sess.CombinedOutput(cmd)
	return output, err
}

// Stream executes cmd on the remote server inside a dedicated session, wiring
// the command's stdout and stderr directly to those of the local process. The
// returned error is that of opening the session or of running the command.
func (client *sshClient) Stream(cmd string) error {
	// Every command gets its own session, torn down when we return
	sess, err := client.client.NewSession()
	if err != nil {
		return err
	}
	defer sess.Close()

	// Pass remote output straight through to the local terminal
	sess.Stdout, sess.Stderr = os.Stdout, os.Stderr

	client.logger.Trace("Streaming command on remote server", "cmd", cmd)
	return sess.Run(cmd)
}

// Upload copies the set of files to a remote server via SCP, creating any non-
// existing folders in the mean time.
//
// files maps each destination path to the content to write there. For every
// entry the directory part of the path is announced as a folder (mode 0755) and
// the base name as a file (mode 0644) to a remote "/usr/bin/scp -v -tr ./".
// The combined output of that remote command is returned along with its error
// status. If the session or its stdin pipe cannot be set up, no output is
// returned.
func (client *sshClient) Upload(files map[string][]byte) ([]byte, error) {
	// Establish a single command session
	session, err := client.client.NewSession()
	if err != nil {
		return nil, err
	}
	defer session.Close()

	// Obtain the stdin pipe before the command is started. Requesting it from
	// the streaming goroutine would race with CombinedOutput starting the
	// session, after which the pipe can no longer be obtained and the goroutine
	// would end up calling Close on a nil writer.
	stdin, err := session.StdinPipe()
	if err != nil {
		return nil, err
	}
	// Create a goroutine that streams the SCP content
	go func() {
		// Closing the pipe signals the end of the transfer to the remote side
		defer stdin.Close()

		// NOTE(review): write errors are not checked here; presumably a failed
		// transfer surfaces through the remote command's result below.
		for file, content := range files {
			client.logger.Trace("Uploading file to server", "file", file, "bytes", len(content))

			fmt.Fprintln(stdin, "D0755", 0, filepath.Dir(file))             // Ensure the folder exists
			fmt.Fprintln(stdin, "C0644", len(content), filepath.Base(file)) // Create the actual file
			stdin.Write(content)                                            // Stream the data content
			fmt.Fprint(stdin, "\x00")                                       // Transfer end with \x00
			fmt.Fprintln(stdin, "E")                                        // Leave directory (simpler)
		}
	}()
	return session.CombinedOutput("/usr/bin/scp -v -tr ./")
}