|
|
@ -1,6 +1,7 @@ |
|
|
|
package stup |
|
|
|
|
|
|
|
import ( |
|
|
|
"bytes" |
|
|
|
"fmt" |
|
|
|
"io" |
|
|
|
"io/ioutil" |
|
|
@ -9,7 +10,6 @@ import ( |
|
|
|
"os/user" |
|
|
|
"path/filepath" |
|
|
|
"strings" |
|
|
|
"sync" |
|
|
|
|
|
|
|
"golang.org/x/crypto/ssh" |
|
|
|
"golang.org/x/crypto/ssh/agent" |
|
|
@ -22,6 +22,8 @@ type SSHClient struct { |
|
|
|
user string |
|
|
|
host string |
|
|
|
pass string |
|
|
|
ident string |
|
|
|
sudo bool |
|
|
|
remoteStdin io.WriteCloser |
|
|
|
remoteStdout io.Reader |
|
|
|
remoteStderr io.Reader |
|
|
@ -83,11 +85,8 @@ func (c *SSHClient) parseHost(host string) error { |
|
|
|
return nil |
|
|
|
} |
|
|
|
|
|
|
|
var initAuthMethodOnce sync.Once |
|
|
|
var authMethod ssh.AuthMethod |
|
|
|
|
|
|
|
// initAuthMethod initiates SSH authentication method.
|
|
|
|
func initAuthMethod() { |
|
|
|
func (c *SSHClient) setPubKeyAuth() (ssh.AuthMethod, error) { |
|
|
|
var signers []ssh.Signer |
|
|
|
|
|
|
|
// If there's a running SSH Agent, try to use its Private keys.
|
|
|
@ -97,24 +96,38 @@ func initAuthMethod() { |
|
|
|
signers, _ = agent.Signers() |
|
|
|
} |
|
|
|
|
|
|
|
// Try to read user's SSH private keys form the standard paths.
|
|
|
|
files, _ := filepath.Glob(os.Getenv("HOME") + "/.ssh/id_*") |
|
|
|
for _, file := range files { |
|
|
|
if strings.HasSuffix(file, ".pub") { |
|
|
|
continue // Skip public keys.
|
|
|
|
} |
|
|
|
data, err := ioutil.ReadFile(file) |
|
|
|
if c.ident != "" { |
|
|
|
data, err := ioutil.ReadFile(ResolvePath(c.ident)) |
|
|
|
if err != nil { |
|
|
|
continue |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
|
|
|
|
signer, err := ssh.ParsePrivateKey(data) |
|
|
|
if err != nil { |
|
|
|
continue |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
signers = append(signers, signer) |
|
|
|
} else { |
|
|
|
// Try to read user's SSH private keys form the standard paths.
|
|
|
|
files, _ := filepath.Glob(os.Getenv("HOME") + "/.ssh/id_*") |
|
|
|
for _, file := range files { |
|
|
|
if strings.HasSuffix(file, ".pub") { |
|
|
|
continue // Skip public keys.
|
|
|
|
} |
|
|
|
data, err := ioutil.ReadFile(file) |
|
|
|
if err != nil { |
|
|
|
continue |
|
|
|
} |
|
|
|
signer, err := ssh.ParsePrivateKey(data) |
|
|
|
if err != nil { |
|
|
|
continue |
|
|
|
} |
|
|
|
signers = append(signers, signer) |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
authMethod = ssh.PublicKeys(signers...) |
|
|
|
|
|
|
|
return ssh.PublicKeys(signers...), err |
|
|
|
} |
|
|
|
|
|
|
|
// SSHDialFunc can dial an ssh server and return a client
|
|
|
@ -130,19 +143,20 @@ func (c *SSHClient) Connect(host string) error { |
|
|
|
// connection.
|
|
|
|
// TODO: Split Signers to its own method.
|
|
|
|
func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error { |
|
|
|
var authMethod ssh.AuthMethod |
|
|
|
|
|
|
|
if c.connOpened { |
|
|
|
return fmt.Errorf("already connected") |
|
|
|
} |
|
|
|
|
|
|
|
initAuthMethodOnce.Do(initAuthMethod) |
|
|
|
|
|
|
|
err := c.parseHost(host) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
|
|
|
|
if c.pass != "" { |
|
|
|
authMethod = ssh.Password(c.pass) |
|
|
|
authMethod, err = c.setPubKeyAuth() |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
|
|
|
|
config := &ssh.ClientConfig{ |
|
|
@ -153,6 +167,11 @@ func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error { |
|
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(), |
|
|
|
} |
|
|
|
|
|
|
|
if c.pass != "" { |
|
|
|
authMethod = ssh.Password(c.pass) |
|
|
|
config.Auth = append(config.Auth, authMethod) |
|
|
|
} |
|
|
|
|
|
|
|
c.conn, err = dialer("tcp", c.host, config) |
|
|
|
if err != nil { |
|
|
|
return ErrConnect{c.user, c.host, err.Error()} |
|
|
@ -204,7 +223,7 @@ func (c *SSHClient) Run(task *Task) error { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
err = NewSSHExecutor(sess, c.pass).Execute(c.env + task.Run) |
|
|
|
err = c.execute(c.env + task.Run) |
|
|
|
if err != nil { |
|
|
|
return ErrTask{task, err.Error()} |
|
|
|
} |
|
|
@ -299,3 +318,93 @@ func (c *SSHClient) Signal(sig os.Signal) error { |
|
|
|
return fmt.Errorf("%v not supported", sig) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func (c *SSHClient) prepareStreams(rsOut, rsErr io.Writer) (stdout io.Reader, stderr io.Reader, err error) { |
|
|
|
if rsOut != nil { |
|
|
|
stdout, err = c.sess.StdoutPipe() |
|
|
|
if err != nil { |
|
|
|
return nil, nil, fmt.Errorf("unable to setup stdout for session: %v", err) |
|
|
|
} |
|
|
|
go io.Copy(rsOut, stdout) |
|
|
|
} |
|
|
|
|
|
|
|
if rsErr != nil { |
|
|
|
stderr, err = c.sess.StderrPipe() |
|
|
|
if err != nil { |
|
|
|
return nil, nil, fmt.Errorf("unable to setup stderr for session: %v", err) |
|
|
|
} |
|
|
|
go io.Copy(rsErr, stderr) |
|
|
|
} |
|
|
|
|
|
|
|
return |
|
|
|
} |
|
|
|
|
|
|
|
func (c *SSHClient) execute(command string) error { |
|
|
|
var err error |
|
|
|
|
|
|
|
modes := ssh.TerminalModes{ |
|
|
|
ssh.ECHO: 0, // disable echoing
|
|
|
|
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
|
|
|
|
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
|
|
|
|
} |
|
|
|
|
|
|
|
err = c.sess.RequestPty("xterm", 80, 40, modes) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
|
|
|
|
// Capture stdout and stderr from remote server
|
|
|
|
var rsOut bytes.Buffer |
|
|
|
var rsErr bytes.Buffer |
|
|
|
|
|
|
|
stdOut, stdErr, err := c.prepareStreams(&rsOut, &rsErr) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
|
|
|
|
// Set stdin to provide sudo passwd if it's necessary
|
|
|
|
stdIn, _ := c.sess.StdinPipe() |
|
|
|
|
|
|
|
endChan := make(chan struct{}) // sudo func ending chan
|
|
|
|
|
|
|
|
// Execute the remote command
|
|
|
|
execFunc(c.sess, command) |
|
|
|
|
|
|
|
if c.sudo { |
|
|
|
go sudoFunc(c.pass, stdIn, &rsOut, endChan) |
|
|
|
} |
|
|
|
|
|
|
|
err = waitForReaderEmpty(stdOut, &rsOut) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
|
|
|
|
err = waitForReaderEmpty(stdErr, &rsErr) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
|
|
|
|
close(endChan) |
|
|
|
|
|
|
|
return nil |
|
|
|
} |
|
|
|
|
|
|
|
func waitForReaderEmpty(reader io.Reader, buff *bytes.Buffer) error { |
|
|
|
var ( |
|
|
|
b = make([]byte, 1024) |
|
|
|
) |
|
|
|
|
|
|
|
for { |
|
|
|
mtx.Lock() |
|
|
|
n, err := reader.Read(b) |
|
|
|
mtx.Unlock() |
|
|
|
if n == 0 || err == io.EOF { |
|
|
|
break |
|
|
|
} |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
buff.Write(b[:n]) |
|
|
|
} |
|
|
|
return nil |
|
|
|
} |
|
|
|