Start Up is a simple deployment tool that runs a given set of commands on multiple hosts in parallel.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

410 lines
8.5 KiB

package stup
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"os/user"
"path/filepath"
"strings"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
// SSHClient is a wrapper over an SSH connection and the single session that
// may be open on it at a time.
type SSHClient struct {
	conn         *ssh.Client    // connection established by ConnectWith
	sess         *ssh.Session   // session created by Run; closed by Wait or Close
	user         string         // remote user name
	host         string         // "host:port" address (normalized by parseHost)
	pass         string         // password offered for password auth; also handed to sudoFunc when sudo is set
	ident        string         // private key file read by setPubKeyAuth; empty means try $HOME/.ssh/id_*
	sudo         bool           // when set, execute starts sudoFunc with pass and the remote stdin
	remoteStdin  io.WriteCloser // remote command's stdin pipe, set up by Run
	remoteStdout io.Reader      // remote command's stdout pipe, set up by Run
	remoteStderr io.Reader      // remote command's stderr pipe, set up by Run
	connOpened   bool           // conn is open
	sessOpened   bool           // sess is open
	running      bool           // a command started by Run has not been waited on yet
	env          string         // shell prefix prepended to every command, e.g. `export FOO="bar"; export BAR="baz";`
	color        string         // terminal color escape that Prefix puts before the host label
}
// ErrConnect reports a failure to parse a host address or to establish the
// SSH connection to it.
type ErrConnect struct {
	User   string // remote user the connection was attempted as
	Host   string // target host, as far as it had been normalized
	Reason string // human-readable cause
}

// Error renders the failure as: Connect("user@host"): reason.
func (e ErrConnect) Error() string {
	return fmt.Sprintf("Connect(\"%s@%s\"): %s", e.User, e.Host, e.Reason)
}
// parseHost parses and normalizes a "[ssh://][user@]host[:port]" string into
// c.user and c.host.
//
// The user is everything before the last "@", so user names may themselves
// contain "@". Without a user part, a user already set on c is kept;
// otherwise the current OS user is used. c.host always ends up as a
// "host:port" address: the port defaults to 22 when absent or empty, and IPv6
// literals are accepted with or without brackets ("::1", "[::1]", "[::1]:22").
//
// It returns ErrConnect for a host containing "/" or with an empty host
// name, the user.Current error when the OS user cannot be determined, and the
// net.SplitHostPort error for an otherwise malformed address.
func (c *SSHClient) parseHost(host string) error {
	c.host = strings.TrimPrefix(host, "ssh://")

	// Split on the last "@": everything before it is the user name.
	if at := strings.LastIndex(c.host, "@"); at != -1 {
		c.user, c.host = c.host[:at], c.host[at+1:]
	}

	if c.user == "" {
		u, err := user.Current()
		if err != nil {
			return err
		}
		c.user = u.Username
	}

	if strings.Contains(c.host, "/") {
		return ErrConnect{c.user, c.host, "unexpected slash in the host URL"}
	}

	addr, err := hostWithDefaultPort(c.host, "22")
	if err != nil {
		return err
	}
	c.host = addr
	// Reject "", "user@", ":2222" and similar: there is nothing to dial.
	if h, _, _ := net.SplitHostPort(addr); h == "" {
		return ErrConnect{c.user, c.host, "empty host"}
	}
	return nil
}

// hostWithDefaultPort returns addr as a "host:port" address, using port when
// addr has no port or an empty one. Bare or bracketed IPv6 literals are
// bracketed in the result; malformed addresses yield a net.SplitHostPort error.
func hostWithDefaultPort(addr, port string) (string, error) {
	h, p, splitErr := net.SplitHostPort(addr)
	if splitErr == nil {
		if p == "" {
			return net.JoinHostPort(h, port), nil
		}
		return addr, nil
	}

	bare := addr
	if strings.HasPrefix(bare, "[") && strings.HasSuffix(bare, "]") {
		bare = bare[1 : len(bare)-1]
	}
	// Several colons are only acceptable in an IPv6 literal; anything else
	// (e.g. "a:b:c") is malformed.
	if strings.Contains(bare, ":") && net.ParseIP(bare) == nil {
		return "", splitErr
	}
	joined := net.JoinHostPort(bare, port)
	// Re-validate so leftovers such as an unmatched "[" are still rejected.
	if _, _, err := net.SplitHostPort(joined); err != nil {
		return "", err
	}
	return joined, nil
}
// setPubKeyAuth builds the public-key ssh.AuthMethod offered by ConnectWith.
//
// Signers are gathered in this order:
//   - every key held by an SSH agent reachable through $SSH_AUTH_SOCK
//     (best effort: no agent, or an agent error, just contributes no keys);
//   - then, if c.ident is set, the private key in that file; a read or parse
//     failure is returned as an error;
//   - otherwise, every parseable private key matching $HOME/.ssh/id_*
//     (*.pub files and unreadable or unparseable keys, e.g. passphrase-protected
//     ones, are skipped).
//
// A missing agent is not an error. Connections without one, for example
// key-file or password-only setups, still work.
func (c *SSHClient) setPubKeyAuth() (ssh.AuthMethod, error) {
	var signers []ssh.Signer

	if sock, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
		agentSigners, err := agent.NewClient(sock).Signers()
		if err != nil || len(agentSigners) == 0 {
			sock.Close() // nothing will sign through this agent
		} else {
			// The agent signers sign through sock during the handshake, so
			// it must stay open.
			signers = append(signers, agentSigners...)
		}
	}

	if c.ident != "" {
		path := ResolvePath(c.ident)
		data, err := ioutil.ReadFile(path)
		if err != nil {
			return nil, fmt.Errorf("reading identity file %q: %w", path, err)
		}
		signer, err := ssh.ParsePrivateKey(data)
		if err != nil {
			return nil, fmt.Errorf("parsing identity file %q: %w", path, err)
		}
		return ssh.PublicKeys(append(signers, signer)...), nil
	}

	if home := os.Getenv("HOME"); home != "" {
		// Glob only fails on a malformed pattern; treat that as "no keys".
		files, _ := filepath.Glob(filepath.Join(home, ".ssh", "id_*"))
		for _, file := range files {
			if strings.HasSuffix(file, ".pub") {
				continue // public half of a key pair
			}
			data, err := ioutil.ReadFile(file)
			if err != nil {
				continue
			}
			signer, err := ssh.ParsePrivateKey(data)
			if err != nil {
				continue
			}
			signers = append(signers, signer)
		}
	}

	return ssh.PublicKeys(signers...), nil
}
// SSHDialFunc dials an SSH server at addr over the given network using config
// and returns the established client. ssh.Dial and (*SSHClient).DialThrough
// both have this signature, so either can be passed to ConnectWith.
type SSHDialFunc func(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error)
// Connect creates an SSH connection to host by dialing it directly with
// ssh.Dial. host has the form "[ssh://][user@]host[:port]"; when omitted, the
// user defaults to the current OS user and the port to 22 (see parseHost).
func (c *SSHClient) Connect(host string) error {
	return c.ConnectWith(host, ssh.Dial)
}
// ConnectWith establishes the SSH connection to host using dialer (for
// example ssh.Dial, or another client's DialThrough to hop through it).
// host is parsed by parseHost.
//
// Authentication offers the public keys from setPubKeyAuth first and, when
// c.pass is set, the password as well. Dial failures are reported as
// ErrConnect; calling it on an already connected client is an error.
//
// SECURITY: the server's host key is NOT verified (ssh.InsecureIgnoreHostKey),
// which leaves the connection open to man-in-the-middle attacks.
func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error {
	if c.connOpened {
		return fmt.Errorf("already connected")
	}
	if err := c.parseHost(host); err != nil {
		return err
	}

	pubKeys, err := c.setPubKeyAuth()
	if err != nil {
		return err
	}
	methods := []ssh.AuthMethod{pubKeys}
	if c.pass != "" {
		methods = append(methods, ssh.Password(c.pass))
	}

	c.conn, err = dialer("tcp", c.host, &ssh.ClientConfig{
		User:            c.user,
		Auth:            methods,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	})
	if err != nil {
		return ErrConnect{c.user, c.host, err.Error()}
	}
	c.connOpened = true
	return nil
}
// Run starts task.Run, prefixed with c.env, on a new session over the
// existing connection to c.host. The session's stdin, stdout and stderr pipes
// are exposed through Stdin, Stdout and Stderr.
//
// The command is handed to execute, and Run returns only after execute has
// returned. On success the session is marked open and running; call Wait to
// collect the exit status and release it. Run fails if a session is already
// open or running, and execution failures are reported as ErrTask.
func (c *SSHClient) Run(task *Task) error {
	if c.running {
		return fmt.Errorf("session already running")
	}
	if c.sessOpened {
		return fmt.Errorf("session already connected")
	}

	sess, err := c.conn.NewSession()
	if err != nil {
		return err
	}
	// Release the session on every failure path below. Only a successfully
	// started session is handed over to Wait/Close.
	started := false
	defer func() {
		if !started {
			sess.Close()
		}
	}()

	if c.remoteStdin, err = sess.StdinPipe(); err != nil {
		return err
	}
	if c.remoteStdout, err = sess.StdoutPipe(); err != nil {
		return err
	}
	if c.remoteStderr, err = sess.StderrPipe(); err != nil {
		return err
	}

	// No separate pty request for task.TTY: execute unconditionally requests
	// an "xterm" pty (echo off, 14.4kbaud) on c.sess. A second pty-req on the
	// same session may be rejected by the server.
	//
	// execute operates on c.sess, so c.sess must be the new session first.
	c.sess = sess
	if err := c.execute(c.env + task.Run); err != nil {
		return ErrTask{task, err.Error()}
	}

	started = true
	c.sessOpened = true
	c.running = true
	return nil
}
// Wait blocks until the command started by Run exits, then closes its session
// and clears the running and session-open flags. It returns the result of
// ssh.Session.Wait, or an error if no command is running.
func (c *SSHClient) Wait() error {
	if !c.running {
		return fmt.Errorf("trying to wait on stopped session")
	}
	exitErr := c.sess.Wait()
	// The session is finished either way; its Close error is ignored.
	c.sess.Close()
	c.running, c.sessOpened = false, false
	return exitErr
}
// DialThrough opens a new SSH connection to addr, tunnelled through the server
// that c is already connected to. It matches SSHDialFunc, so it can be passed
// to another client's ConnectWith to reach hosts behind c's server.
//
// It returns an error if c is not connected, if the tunnelled dial fails, or
// if the SSH handshake over the tunnel fails (the tunnel is closed then).
func (c *SSHClient) DialThrough(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
	if !c.connOpened {
		return nil, fmt.Errorf("dial through: client is not connected")
	}
	conn, err := c.conn.Dial(network, addr)
	if err != nil {
		return nil, err
	}
	clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
	if err != nil {
		// Without a completed handshake the tunnel is useless; release it.
		conn.Close()
		return nil, err
	}
	return ssh.NewClient(clientConn, chans, reqs), nil
}
// Close tears the client down: it closes any open session first, then the
// underlying connection, and resets the connection and running flags.
// It returns the connection's Close error, or an error if the connection was
// already closed (an open session is still closed in that case).
func (c *SSHClient) Close() error {
	if c.sessOpened {
		// The session's Close error is not reported.
		c.sess.Close()
		c.sessOpened = false
	}
	if c.connOpened {
		connErr := c.conn.Close()
		c.connOpened, c.running = false, false
		return connErr
	}
	return fmt.Errorf("trying to close the already closed connection")
}
// Stdin returns the remote command's stdin pipe captured by Run (nil before
// any Run).
func (c *SSHClient) Stdin() io.WriteCloser {
	return c.remoteStdin
}

// Stdout returns the remote command's stdout pipe captured by Run (nil before
// any Run).
func (c *SSHClient) Stdout() io.Reader {
	return c.remoteStdout
}

// Stderr returns the remote command's stderr pipe captured by Run (nil before
// any Run).
func (c *SSHClient) Stderr() io.Reader {
	return c.remoteStderr
}
// Prefix returns the "user@host | " label for this client wrapped in c.color
// and ResetColor. The second result is the label's length in bytes, without
// the color escape codes.
func (c *SSHClient) Prefix() (string, int) {
	label := c.user + "@" + c.host + " | "
	colored := c.color + label + ResetColor
	return colored, len(label)
}
// Write sends p to the remote command's stdin pipe, which makes SSHClient an
// io.Writer. The stdin pipe is set up by Run; calling Write before that
// dereferences a nil pipe.
func (c *SSHClient) Write(p []byte) (int, error) {
	stdin := c.remoteStdin
	return stdin.Write(p)
}

// WriteClose closes the remote command's stdin pipe, which typically signals
// EOF to the command.
func (c *SSHClient) WriteClose() error {
	stdin := c.remoteStdin
	return stdin.Close()
}
// Signal forwards sig to the remote command of the open session.
//
// Only os.Interrupt is supported. It writes a Ctrl-C byte (0x03) to the remote
// stdin, which a remote pty typically delivers as SIGINT. It then sends an SSH
// "signal" request for SIGINT and returns that request's result. Any other
// signal returns an error without touching the session.
func (c *SSHClient) Signal(sig os.Signal) error {
	if !c.sessOpened {
		return fmt.Errorf("session is not open")
	}
	if sig != os.Interrupt {
		return fmt.Errorf("%v not supported", sig)
	}
	// The Ctrl-C write is best effort; its error is ignored.
	c.remoteStdin.Write([]byte{0x03})
	return c.sess.Signal(ssh.SIGINT)
}
// prepareStreams attaches the current session's (c.sess) stdout and/or stderr
// pipes and starts a background io.Copy from each into the matching writer.
// A nil writer skips that stream, and the corresponding returned reader is nil.
// Pipe setup failures are returned with context.
//
// NOTE(review): each returned reader is the same reader its io.Copy goroutine
// is draining, so a caller that also reads it competes with that goroutine
// for the data. The goroutines are not tracked. Each ends when its io.Copy
// returns (EOF or error), and copy errors are dropped.
func (c *SSHClient) prepareStreams(rsOut, rsErr io.Writer) (io.Reader, io.Reader, error) {
	var outPipe, errPipe io.Reader
	if rsOut != nil {
		pipe, err := c.sess.StdoutPipe()
		if err != nil {
			return nil, nil, fmt.Errorf("unable to setup stdout for session: %v", err)
		}
		go io.Copy(rsOut, pipe)
		outPipe = pipe
	}
	if rsErr != nil {
		pipe, err := c.sess.StderrPipe()
		if err != nil {
			return nil, nil, fmt.Errorf("unable to setup stderr for session: %v", err)
		}
		go io.Copy(rsErr, pipe)
		errPipe = pipe
	}
	return outPipe, errPipe, nil
}
// execute runs command on the current session c.sess and blocks until that
// session's stdout and then stderr have been drained. It takes no session
// argument, so callers must point c.sess at the intended session first.
//
// Steps, in order:
//   - unconditionally request an "xterm" pseudo-terminal with echo disabled;
//   - wire stdout/stderr into the local buffers rsOut/rsErr via prepareStreams;
//   - take the stdin pipe;
//   - hand the command to execFunc (defined elsewhere; presumably it starts the
//     command — confirm);
//   - when c.sudo is set, start sudoFunc in the background with c.pass, the
//     stdin pipe and &rsOut (per the comment below, to supply a sudo password).
//
// The captured output buffers are never returned to the caller.
//
// NOTE(review): prepareStreams already runs io.Copy goroutines draining
// stdOut/stdErr into rsOut/rsErr. waitForReaderEmpty below reads the same
// readers into the same buffers, and sudoFunc also gets &rsOut. Nothing
// visible here synchronizes access to these bytes.Buffers (mtx only guards
// the Read calls inside waitForReaderEmpty). This looks like a data race;
// verify with -race.
// NOTE(review): endChan is closed only on the success path. On an early error
// return, a running sudoFunc goroutine is never signaled. Confirm whether it
// can leak.
func (c *SSHClient) execute(command string) error {
var err error
modes := ssh.TerminalModes{
ssh.ECHO: 0, // disable echoing
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
}
// NOTE(review): ssh.Session.RequestPty takes (term, h, w, modes), so this asks
// for 80 rows by 40 columns — confirm the order is intended.
err = c.sess.RequestPty("xterm", 80, 40, modes)
if err != nil {
return err
}
// Capture stdout and stderr from remote server
var rsOut bytes.Buffer
var rsErr bytes.Buffer
stdOut, stdErr, err := c.prepareStreams(&rsOut, &rsErr)
if err != nil {
return err
}
// Set stdin to provide sudo passwd if it's necessary
// NOTE(review): the StdinPipe error is discarded, so stdIn may be nil here.
stdIn, _ := c.sess.StdinPipe()
endChan := make(chan struct{}) // sudo func ending chan
// Execute the remote command
// NOTE(review): any result execFunc returns is ignored here — confirm it
// cannot fail silently.
execFunc(c.sess, command)
if c.sudo {
go sudoFunc(c.pass, stdIn, &rsOut, endChan)
}
// Drain stdout first, then stderr; the first read error aborts execute.
err = waitForReaderEmpty(stdOut, &rsOut)
if err != nil {
return err
}
err = waitForReaderEmpty(stdErr, &rsErr)
if err != nil {
return err
}
// Signal the sudo goroutine (if any) that output has been drained.
close(endChan)
return nil
}
// waitForReaderEmpty reads reader until it is exhausted, appending everything
// read to buff. Each Read call is made while holding the package-level mtx.
//
// Reading stops at io.EOF, and also at a Read that returns no data and no
// error. Bytes returned together with an error (including io.EOF) are kept,
// as the io.Reader contract requires. It returns nil when reader is nil or
// has been drained, and otherwise the first non-EOF read error.
func waitForReaderEmpty(reader io.Reader, buff *bytes.Buffer) error {
	if reader == nil {
		return nil
	}
	chunk := make([]byte, 1024)
	for {
		mtx.Lock()
		n, err := reader.Read(chunk)
		mtx.Unlock()
		// Process returned bytes before looking at err.
		if n > 0 {
			buff.Write(chunk[:n])
		}
		switch {
		case err == io.EOF:
			return nil
		case err != nil:
			return err
		case n == 0:
			return nil
		}
	}
}