Browse Source

增加了密码登录和自定义密钥路径的配置

master
小红帽 3 weeks ago
parent
commit
29264e47c4
  1. 58
      cmd/stup/main.go
  2. 50
      executor.go
  3. 22
      localclient.go
  4. 142
      ssh_executor.go
  5. 149
      sshclient.go
  6. 114
      stup.go
  7. 17
      stupfile.go
  8. 3
      version.go
  9. 15
      vssh/client.go
  10. 104
      vssh/client_test.go
  11. 7
      vssh/example_stream_test.go
  12. 174
      vssh/query.go
  13. 87
      vssh/query_test.go
  14. 56
      vssh/vssh.go
  15. 18
      vssh/vssh_test.go
  16. 1
      vsshclient.go

58
cmd/stup/main.go

@ -5,8 +5,6 @@ import (
"fmt"
"io/ioutil"
"os"
"os/user"
"path/filepath"
"regexp"
"strings"
"text/tabwriter"
@ -79,7 +77,7 @@ func networkUsage(conf *stup.Stupfile) {
fmt.Fprintf(w, "- %v\n", name)
network, _ := conf.Networks.Get(name)
for _, host := range network.Hosts {
fmt.Fprintf(w, "\t- %v\n", host)
fmt.Fprintf(w, " - %v\n", host.Address)
}
}
fmt.Fprintln(w)
@ -94,13 +92,13 @@ func cmdUsage(conf *stup.Stupfile) {
fmt.Fprintln(w, "Targets:\t")
for _, name := range conf.Targets.Names {
cmds, _ := conf.Targets.Get(name)
fmt.Fprintf(w, "- %v\t%v\n", name, strings.Join(cmds, " "))
fmt.Fprintf(w, " - %v\t%v\n", name, strings.Join(cmds, " "))
}
fmt.Fprintln(w, "\t")
fmt.Fprintln(w, "Commands:\t")
for _, name := range conf.Commands.Names {
cmd, _ := conf.Commands.Get(name)
fmt.Fprintf(w, "- %v\t%v\n", name, cmd.Desc)
fmt.Fprintf(w, " - %v\t%v\n", name, cmd.Desc)
}
fmt.Fprintln(w)
}
@ -162,19 +160,19 @@ func parseArgs(conf *stup.Stupfile) (*stup.Network, []*stup.Command, error) {
}
// Add default env variable with current network
network.Env.Set("SUP_NETWORK", args[0])
network.Env.Set("STUP_NETWORK", args[0])
// Add default nonce
network.Env.Set("SUP_TIME", time.Now().UTC().Format(time.RFC3339))
if os.Getenv("SUP_TIME") != "" {
network.Env.Set("SUP_TIME", os.Getenv("SUP_TIME"))
network.Env.Set("STUP_TIME", time.Now().UTC().Format(time.RFC3339))
if os.Getenv("STUP_TIME") != "" {
network.Env.Set("STUP_TIME", os.Getenv("STUP_TIME"))
}
// Add user
if os.Getenv("SUP_USER") != "" {
network.Env.Set("SUP_USER", os.Getenv("SUP_USER"))
if os.Getenv("STUP_USER") != "" {
network.Env.Set("STUP_USER", os.Getenv("STUP_USER"))
} else {
network.Env.Set("SUP_USER", os.Getenv("USER"))
network.Env.Set("STUP_USER", os.Getenv("USER"))
}
for _, cmd := range args[1:] {
@ -209,19 +207,6 @@ func parseArgs(conf *stup.Stupfile) (*stup.Network, []*stup.Command, error) {
return &network, commands, nil
}
func resolvePath(path string) string {
if path == "" {
return ""
}
if path[:2] == "~/" {
usr, err := user.Current()
if err == nil {
path = filepath.Join(usr.HomeDir, path[2:])
}
}
return path
}
func main() {
flag.Parse()
@ -239,7 +224,7 @@ func main() {
if stupfile == "" {
stupfile = "./Stupfile"
}
data, err := ioutil.ReadFile(resolvePath(stupfile))
data, err := ioutil.ReadFile(stup.ResolvePath(stupfile))
if err != nil {
firstErr := err
data, err = ioutil.ReadFile("./Stupfile.yml") // Alternative to ./Stupfile.
@ -306,7 +291,7 @@ func main() {
// --sshconfig flag location for ssh_config file
if sshConfig != "" {
confHosts, err := sshconfig.ParseSSHConfig(resolvePath(sshConfig))
confHosts, err := sshconfig.ParseSSHConfig(stup.ResolvePath(sshConfig))
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
@ -326,15 +311,20 @@ func main() {
for n, host := range network.Hosts {
conf, found := confMap[host.Address]
if found {
ident, err = ioutil.ReadFile(resolvePath(conf.IdentityFile))
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
if conf.IdentityFile != "" {
ident, err = ioutil.ReadFile(stup.ResolvePath(conf.IdentityFile))
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
network.Hosts[n].Identity = string(ident)
}
if conf.User != "" {
network.Hosts[n].User = conf.User
}
network.Hosts[n].User = conf.User
network.Hosts[n].Identity = string(ident)
network.Hosts[n].Address = fmt.Sprintf("%s:%d", conf.HostName, conf.Port)
network.Hosts[n].Address = fmt.Sprintf("%s:%d", host.Address, conf.Port)
}
}
}

50
executor.go

@ -0,0 +1,50 @@
package stup
import (
"bytes"
"fmt"
"io"
"os"
"strings"
"sync"
"github.com/pkg/errors"
)
var (
mtx = &sync.Mutex{}
sudoFunc = func(sudoerPassword string, in io.Writer, output *bytes.Buffer, endChan chan struct{}) {
for {
select {
case <-endChan:
default:
//TODO: Refactor it
mtx.Lock()
if output.Len() > 0 {
msg := output.String()
if strings.Contains(msg, "[sudo] ") {
_, err := in.Write([]byte(sudoerPassword + "\n"))
if err != nil && err != io.EOF {
fmt.Fprintf(os.Stderr, "%v", errors.Wrap(err, "some went wrong when trying remote sudo"))
}
}
}
mtx.Unlock()
}
}
}
execFunc = func(session Session, command string) {
err := session.Run(command)
if err != nil {
fmt.Fprintf(os.Stderr, "%v", errors.Wrap(err, "some went wrong when trying execute remote command"))
}
}
)
type Session interface {
Run(cmd string) error
Wait() error
Close() error
}

22
localclient.go

@ -1,6 +1,7 @@
package stup
import (
"bytes"
"fmt"
"io"
"os"
@ -10,10 +11,11 @@ import (
"github.com/pkg/errors"
)
// Client is a wrapper over the SSH connection/sessions.
type LocalhostClient struct {
cmd *exec.Cmd
user string
pass string
sudo bool
stdin io.WriteCloser
stdout io.Reader
stderr io.Reader
@ -27,7 +29,10 @@ func (c *LocalhostClient) Connect(_ string) error {
return err
}
c.user = u.Username
if c.user == "" {
c.user = u.Username
}
return nil
}
@ -56,10 +61,23 @@ func (c *LocalhostClient) Run(task *Task) error {
return err
}
var rsOut bytes.Buffer
endChan := make(chan struct{})
if c.sudo {
go sudoFunc(c.pass, c.stdin, &rsOut, endChan)
}
if err := c.cmd.Start(); err != nil {
fmt.Fprintf(os.Stderr, "%v", errors.Wrap(err, "some went wrong when trying execute remote command"))
return ErrTask{task, err.Error()}
}
if err := c.Wait(); err != nil {
return err
}
c.running = true
return nil
}

142
ssh_executor.go

@ -1,142 +0,0 @@
package stup
import (
"bytes"
"fmt"
"io"
"os"
"strings"
"sync"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
)
type (
SSHExecutor struct {
session *ssh.Session
sudopass string
}
)
var (
mtx = &sync.Mutex{}
sudoFunc = func(sudoerPassword string, in io.Writer, output *bytes.Buffer, endChan chan struct{}) {
for {
select {
case <-endChan:
default:
//TODO: Refactor it
mtx.Lock()
if output.Len() > 0 {
msg := output.String()
if strings.Contains(msg, "[sudo] ") {
_, err := in.Write([]byte(sudoerPassword + "\n"))
if err != nil && err != io.EOF {
fmt.Fprintf(os.Stderr, "%v", errors.Wrap(err, "some went wrong when trying remote sudo"))
}
}
}
mtx.Unlock()
}
}
}
execFunc = func(session *ssh.Session, command string) {
err := session.Run(command)
if err != nil {
fmt.Fprintf(os.Stderr, "%v", errors.Wrap(err, "some went wrong when trying execute remote command"))
}
}
)
func NewSSHExecutor(session *ssh.Session, password string) *SSHExecutor {
return &SSHExecutor{session: session, sudopass: password}
}
func (s *SSHExecutor) PrepareStreams(rsOut, rsErr io.Writer) (stdout io.Reader, stderr io.Reader, err error) {
if rsOut != nil {
stdout, err = s.session.StdoutPipe()
if err != nil {
return nil, nil, fmt.Errorf("unable to setup stdout for session: %v", err)
}
go io.Copy(rsOut, stdout)
}
if rsErr != nil {
stderr, err = s.session.StderrPipe()
if err != nil {
return nil, nil, fmt.Errorf("unable to setup stderr for session: %v", err)
}
go io.Copy(rsErr, stderr)
}
return
}
func (s *SSHExecutor) Execute(command string) error {
var err error
modes := ssh.TerminalModes{
ssh.ECHO: 0, // disable echoing
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
}
err = s.session.RequestPty("xterm", 80, 40, modes)
if err != nil {
return err
}
// Capture stdout and stderr from remote server
var rsOut bytes.Buffer
var rsErr bytes.Buffer
stdOut, stdErr, err := s.PrepareStreams(&rsOut, &rsErr)
if err != nil {
return err
}
// Set stdin to provide sudo passwd if it's necessary
stdIn, _ := s.session.StdinPipe()
endChan := make(chan struct{}) // sudo func ending chan
go sudoFunc(s.sudopass, stdIn, &rsOut, endChan)
// Execute the remote command
execFunc(s.session, command)
err = waitForReaderEmpty(stdOut, &rsOut)
if err != nil {
return err
}
err = waitForReaderEmpty(stdErr, &rsErr)
if err != nil {
return err
}
close(endChan)
return nil
}
func waitForReaderEmpty(reader io.Reader, buff *bytes.Buffer) error {
var (
b = make([]byte, 1024)
)
for {
mtx.Lock()
n, err := reader.Read(b)
mtx.Unlock()
if n == 0 || err == io.EOF {
break
}
if err != nil {
return err
}
buff.Write(b[:n])
}
return nil
}

149
sshclient.go

@ -1,6 +1,7 @@
package stup
import (
"bytes"
"fmt"
"io"
"io/ioutil"
@ -9,7 +10,6 @@ import (
"os/user"
"path/filepath"
"strings"
"sync"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
@ -22,6 +22,8 @@ type SSHClient struct {
user string
host string
pass string
ident string
sudo bool
remoteStdin io.WriteCloser
remoteStdout io.Reader
remoteStderr io.Reader
@ -83,11 +85,8 @@ func (c *SSHClient) parseHost(host string) error {
return nil
}
var initAuthMethodOnce sync.Once
var authMethod ssh.AuthMethod
// initAuthMethod initiates SSH authentication method.
func initAuthMethod() {
func (c *SSHClient) setPubKeyAuth() (ssh.AuthMethod, error) {
var signers []ssh.Signer
// If there's a running SSH Agent, try to use its Private keys.
@ -97,24 +96,38 @@ func initAuthMethod() {
signers, _ = agent.Signers()
}
// Try to read user's SSH private keys form the standard paths.
files, _ := filepath.Glob(os.Getenv("HOME") + "/.ssh/id_*")
for _, file := range files {
if strings.HasSuffix(file, ".pub") {
continue // Skip public keys.
}
data, err := ioutil.ReadFile(file)
if c.ident != "" {
data, err := ioutil.ReadFile(ResolvePath(c.ident))
if err != nil {
continue
return nil, err
}
signer, err := ssh.ParsePrivateKey(data)
if err != nil {
continue
return nil, err
}
signers = append(signers, signer)
} else {
// Try to read user's SSH private keys form the standard paths.
files, _ := filepath.Glob(os.Getenv("HOME") + "/.ssh/id_*")
for _, file := range files {
if strings.HasSuffix(file, ".pub") {
continue // Skip public keys.
}
data, err := ioutil.ReadFile(file)
if err != nil {
continue
}
signer, err := ssh.ParsePrivateKey(data)
if err != nil {
continue
}
signers = append(signers, signer)
}
}
authMethod = ssh.PublicKeys(signers...)
return ssh.PublicKeys(signers...), err
}
// SSHDialFunc can dial an ssh server and return a client
@ -130,19 +143,20 @@ func (c *SSHClient) Connect(host string) error {
// connection.
// TODO: Split Signers to its own method.
func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error {
var authMethod ssh.AuthMethod
if c.connOpened {
return fmt.Errorf("already connected")
}
initAuthMethodOnce.Do(initAuthMethod)
err := c.parseHost(host)
if err != nil {
return err
}
if c.pass != "" {
authMethod = ssh.Password(c.pass)
authMethod, err = c.setPubKeyAuth()
if err != nil {
return err
}
config := &ssh.ClientConfig{
@ -153,6 +167,11 @@ func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error {
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
if c.pass != "" {
authMethod = ssh.Password(c.pass)
config.Auth = append(config.Auth, authMethod)
}
c.conn, err = dialer("tcp", c.host, config)
if err != nil {
return ErrConnect{c.user, c.host, err.Error()}
@ -204,7 +223,7 @@ func (c *SSHClient) Run(task *Task) error {
}
}
err = NewSSHExecutor(sess, c.pass).Execute(c.env + task.Run)
err = c.execute(c.env + task.Run)
if err != nil {
return ErrTask{task, err.Error()}
}
@ -299,3 +318,93 @@ func (c *SSHClient) Signal(sig os.Signal) error {
return fmt.Errorf("%v not supported", sig)
}
}
func (c *SSHClient) prepareStreams(rsOut, rsErr io.Writer) (stdout io.Reader, stderr io.Reader, err error) {
if rsOut != nil {
stdout, err = c.sess.StdoutPipe()
if err != nil {
return nil, nil, fmt.Errorf("unable to setup stdout for session: %v", err)
}
go io.Copy(rsOut, stdout)
}
if rsErr != nil {
stderr, err = c.sess.StderrPipe()
if err != nil {
return nil, nil, fmt.Errorf("unable to setup stderr for session: %v", err)
}
go io.Copy(rsErr, stderr)
}
return
}
func (c *SSHClient) execute(command string) error {
var err error
modes := ssh.TerminalModes{
ssh.ECHO: 0, // disable echoing
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
}
err = c.sess.RequestPty("xterm", 80, 40, modes)
if err != nil {
return err
}
// Capture stdout and stderr from remote server
var rsOut bytes.Buffer
var rsErr bytes.Buffer
stdOut, stdErr, err := c.prepareStreams(&rsOut, &rsErr)
if err != nil {
return err
}
// Set stdin to provide sudo passwd if it's necessary
stdIn, _ := c.sess.StdinPipe()
endChan := make(chan struct{}) // sudo func ending chan
// Execute the remote command
execFunc(c.sess, command)
if c.sudo {
go sudoFunc(c.pass, stdIn, &rsOut, endChan)
}
err = waitForReaderEmpty(stdOut, &rsOut)
if err != nil {
return err
}
err = waitForReaderEmpty(stdErr, &rsErr)
if err != nil {
return err
}
close(endChan)
return nil
}
func waitForReaderEmpty(reader io.Reader, buff *bytes.Buffer) error {
var (
b = make([]byte, 1024)
)
for {
mtx.Lock()
n, err := reader.Read(b)
mtx.Unlock()
if n == 0 || err == io.EOF {
break
}
if err != nil {
return err
}
buff.Write(b[:n])
}
return nil
}

114
stup.go

@ -5,6 +5,8 @@ import (
"io"
"os"
"os/signal"
"os/user"
"path/filepath"
"strings"
"sync"
@ -13,8 +15,6 @@ import (
"golang.org/x/crypto/ssh"
)
const VERSION = "1.0"
type Startup struct {
conf *Stupfile
debug bool
@ -27,10 +27,23 @@ func New(conf *Stupfile) (*Startup, error) {
}, nil
}
func ResolvePath(path string) string {
if path == "" {
return ""
}
if path[:2] == "~/" {
usr, err := user.Current()
if err == nil {
path = filepath.Join(usr.HomeDir, path[2:])
}
}
return path
}
// Run runs set of commands on multiple hosts defined by network sequentially.
// TODO: This megamoth method needs a big refactor and should be split
// to multiple smaller methods.
func (stup *Startup) Run(network *Network, envVars EnvList, commands ...*Command) error {
func (s *Startup) Run(network *Network, envVars EnvList, commands ...*Command) error {
if len(commands) == 0 {
return errors.New("no commands to be run")
}
@ -40,7 +53,18 @@ func (stup *Startup) Run(network *Network, envVars EnvList, commands ...*Command
// Create clients for every host (either SSH or Localhost).
var bastion *SSHClient
if network.Bastion != "" {
bastion = &SSHClient{}
bastion = &SSHClient{
user: network.User,
}
if network.Password != "" {
bastion.pass = network.Password
}
if network.Identity != "" {
bastion.ident = network.Identity
}
if err := bastion.Connect(network.Bastion); err != nil {
return errors.Wrap(err, "connecting to bastion failed")
}
@ -50,19 +74,59 @@ func (stup *Startup) Run(network *Network, envVars EnvList, commands ...*Command
clientCh := make(chan Client, len(network.Hosts))
errCh := make(chan error, len(network.Hosts))
var password string
for i, host := range network.Hosts {
wg.Add(1)
go func(i int, host Instance) {
defer wg.Done()
var user string
var password string
var ident string
var sudo bool
addr := host.Address
if network.User != "" {
user = network.User
}
if host.User != "" {
user = host.User
}
if network.User == "" && host.User == "" {
user = "root"
}
if user == "root" {
sudo = false
}
if network.Sudo || host.Sudo {
sudo = true
}
if network.Password != "" {
password = network.Password
}
if host.Password != "" {
password = host.Password
}
// Localhost client.
if host.Address == "localhost" {
if addr == "localhost" {
if host.Sudo && host.User == "" {
user = host.User
}
local := &LocalhostClient{
env: env + `export SUP_HOST="` + host.Address + `";`,
env: env + `export SUP_HOST="localhost";`,
user: user,
pass: password,
sudo: sudo,
}
if err := local.Connect(host.Address); err != nil {
if err := local.Connect(addr); err != nil {
errCh <- errors.Wrap(err, "connecting to localhost failed")
return
}
@ -70,29 +134,31 @@ func (stup *Startup) Run(network *Network, envVars EnvList, commands ...*Command
return
}
if network.Password != "" {
password = network.Password
if network.Identity != "" {
ident = network.Inventory
}
if host.Password != "" {
password = host.Password
if host.Identity != "" {
ident = host.Identity
}
// SSH client.
remote := &SSHClient{
env: env + `export SUP_HOST="` + host.Address + `";`,
user: network.User,
env: env + `export SUP_HOST="` + addr + `";`,
user: user,
pass: password,
ident: ident,
sudo: sudo,
color: Colors[i%len(Colors)],
}
if bastion != nil {
if err := remote.ConnectWith(host.Address, bastion.DialThrough); err != nil {
if err := remote.ConnectWith(addr, bastion.DialThrough); err != nil {
errCh <- errors.Wrap(err, "connecting to remote host through bastion failed")
return
}
} else {
if err := remote.Connect(host.Address); err != nil {
if err := remote.Connect(addr); err != nil {
errCh <- errors.Wrap(err, "connecting to remote host failed")
return
}
@ -120,7 +186,7 @@ func (stup *Startup) Run(network *Network, envVars EnvList, commands ...*Command
// Run command or run multiple commands defined by target sequentially.
for _, cmd := range commands {
// Translate command into task(s).
tasks, err := stup.createTasks(cmd, clients, env)
tasks, err := s.createTasks(cmd, clients, env)
if err != nil {
return errors.Wrap(err, "creating task failed")
}
@ -134,7 +200,7 @@ func (stup *Startup) Run(network *Network, envVars EnvList, commands ...*Command
for _, c := range task.Clients {
var prefix string
var prefixLen int
if stup.prefix {
if s.prefix {
prefix, prefixLen = c.Prefix()
if len(prefix) < maxLen { // Left padding.
prefix = strings.Repeat(" ", maxLen-prefixLen) + prefix
@ -217,7 +283,7 @@ func (stup *Startup) Run(network *Network, envVars EnvList, commands ...*Command
defer wg.Done()
if err := c.Wait(); err != nil {
var prefix string
if stup.prefix {
if s.prefix {
var prefixLen int
prefix, prefixLen = c.Prefix()
if len(prefix) < maxLen { // Left padding.
@ -251,10 +317,10 @@ func (stup *Startup) Run(network *Network, envVars EnvList, commands ...*Command
return nil
}
func (sup *Startup) Debug(value bool) {
sup.debug = value
func (s *Startup) Debug(value bool) {
s.debug = value
}
func (sup *Startup) Prefix(value bool) {
sup.prefix = value
func (s *Startup) Prefix(value bool) {
s.prefix = value
}

17
stupfile.go

@ -19,11 +19,16 @@ type Stupfile struct {
Version string `yaml:"version"`
}
type Instance struct {
Address string `yaml:"address"`
type Accunt struct {
User string `yaml:"user"`
Password string `yaml:"password"`
Identity string `yaml:"identity"`
Sudo bool `yaml:"sudo"`
}
type Instance struct {
Address string `yaml:"address"`
Accunt
}
// Network is group of hosts with extra custom env vars.
@ -32,9 +37,7 @@ type Network struct {
Inventory string `yaml:"inventory"`
Hosts []Instance `yaml:"hosts"`
Bastion string `yaml:"bastion"`
User string `yaml:"user"`
Password string `yaml:"password"`
Identity string `yaml:"identity"`
Accunt
}
// Networks is a list of user-defined networks
@ -255,11 +258,11 @@ type ErrUnsupportedSupfileVersion struct {
}
func (e ErrMustUpdate) Error() string {
return fmt.Sprintf("%v\n\nPlease update sup by `go get -u github.com/pressly/sup/cmd/sup`", e.Msg)
return fmt.Sprintf("%v\n\nPlease update stup", e.Msg)
}
func (e ErrUnsupportedSupfileVersion) Error() string {
return fmt.Sprintf("%v\n\nCheck your Supfile version (available latest version: v0.5)", e.Msg)
return fmt.Sprintf("%v\n\nCheck your Stupfile version", e.Msg)
}
// NewStupfile parses configuration file and returns Supfile or error.

3
version.go

@ -0,0 +1,3 @@
package stup
const VERSION = "1.0"

15
vssh/client.go

@ -49,7 +49,6 @@ type clientStats struct {
// clientAttr represents client attributes
type clientAttr struct {
addr string
labels map[string]string
config *ssh.ClientConfig
client *ssh.Client
logger *log.Logger
@ -210,7 +209,6 @@ LOOP:
case <-done:
break LOOP
}
}
if err = session.Wait(); err != nil {
@ -326,19 +324,6 @@ func (c *clientAttr) getClient() *ssh.Client {
return c.client
}
func (c *clientAttr) labelMatch(v *visitor) bool {
if len(c.labels) < 1 {
return false
}
ok, err := exprEval(v, c.labels)
if err != nil {
return false
}
return ok
}
func (c *clientAttr) connect() {
c.Lock()
defer c.Unlock()

104
vssh/client_test.go

@ -23,44 +23,15 @@ import (
var (
rsaPrivate = []byte(`-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABlwAAAAdzc2gtcn
NhAAAAAwEAAQAAAYEA0GQrT5+y8RIn+6si0BMCrQ5IwnoLbmHo3bwqoVOZm908olK7npvh
m5P9LGOLYnElvgn83S2LV4H+zQeBci2r3N82C2L/c8E2DMYY3/eRD0zWTIkqgR8w3iXz9i
vIsN9TC2fHJe2VUC/fBD68aJRdbR0T9od/qxY2WBCbtkxlFJIK6mm3OPpadhghn3JbmiOq
MOREwZGyiw9XnHLJayDBFY0pGKaSpzh8kujBtf0nehRLy3WRLKrb6/OJgH0kVRP3JYomeG
xNWzXtE8QQnN+ROGYbx9qM0E1Tu3qhdWJfvpy/y1rRUIEjmd6BNCZKW86u9Y2cU/iwYmNC
dO4zsY5K3v5iceEyZjCsPibEbsQwpCkzwd/mIg9hZoJxF7MUSYlz8TFNpIy8VRMm3RfY2W
KnuXHRG0Tiabj2Nv8U+CDBjvVtTBVw7YgEDHOYW3jdKqYJyJoWyqUM8e5k86UUCLqTsT9l
sW6QYJP3cDu7J1QC05xl3M+5v3YaW2Bogw+nDu8zAAAFkHMouHlzKLh5AAAAB3NzaC1yc2
EAAAGBANBkK0+fsvESJ/urItATAq0OSMJ6C25h6N28KqFTmZvdPKJSu56b4ZuT/Sxji2Jx
Jb4J/N0ti1eB/s0HgXItq9zfNgti/3PBNgzGGN/3kQ9M1kyJKoEfMN4l8/YryLDfUwtnxy
XtlVAv3wQ+vGiUXW0dE/aHf6sWNlgQm7ZMZRSSCupptzj6WnYYIZ9yW5ojqjDkRMGRsosP
V5xyyWsgwRWNKRimkqc4fJLowbX9J3oUS8t1kSyq2+vziYB9JFUT9yWKJnhsTVs17RPEEJ
zfkThmG8fajNBNU7t6oXViX76cv8ta0VCBI5negTQmSlvOrvWNnFP4sGJjQnTuM7GOSt7+
YnHhMmYwrD4mxG7EMKQpM8Hf5iIPYWaCcRezFEmJc/ExTaSMvFUTJt0X2Nlip7lx0RtE4m
m49jb/FPggwY71bUwVcO2IBAxzmFt43SqmCciaFsqlDPHuZPOlFAi6k7E/ZbFukGCT93A7
uydUAtOcZdzPub92GltgaIMPpw7vMwAAAAMBAAEAAAGABy7cu1Li5SJeFHOysH9nQTXT1j
hEuppPX41D3um1ysSWeXXml7IB1c4FFQmdXVhPF7zaZXlTa0HE2aZflOL0IJnlEAFqkr/f
MBOH+fhbnK5mWJ8FwwujMJUYUqzxrv8Tqrn6CFmnIutzgX70GZq7ma496OqEwQ3z85cm9u
KtPUdHbwsT0Lf4dEeiqQ9VDvwZurOzlwSBpf9yYqcmQDYR0b9a4kmjlnYA/UNeofpG6RNY
BXxY87Qz/m8Xl0E5BmG4vDOwdpEjR7a6nQ+iM0MJ5cD03Y14jVXMEwE6MLq2bRVT+MC12m
J3Bi0r247MxrLlTr3Yt6690mnn7P/liKJr9YWI943sUYd5DmMA8s4ibmK059ApdC6ymMbK
0SfKBrH3tpo2jOvLzJ/sZrQ20XRM88C6mMdsz7EGk5jTpETd4QqC+4qfClGhnFCNpUEPzG
YejsnydiWdAdkNpsUxjFL7XunCFy90eaYogZRs8wBwSAp2MAt9HN5ZYa7qcBif0W4xAAAA
wADrMYUt40vGt2jA4QZxy4zB9Vsi/xBDcoRi9TVJ3/KcFl2scUr1c7cZuDE45C04XaZ3zm
c8c9LEWG0QOQuFpQ9UkCH7Uj6QsH8BrjUeE2NUFMGeLJOvFKzWlhtHWx4OaEYbWiAk6khe
7gj8Rw7D4G+iddk5w18TeyHwYmp/dOLsX5Lc/Czn0L8Bl89wijs9F33Yg0vrQoWtAmGtYk
1OJJHvgglRgBwrK65hWQH6bZNGF2vtYPM9EqYjsSZYNR2JQAAAAMEA9XL6tJBm/MIN6Z+g
loXtWWHZ6o4HGts/A/17WF/8Z5lbBp5eIkrAKFEwC4zAmCIdjQEOxaVcVzps+kSeZWoS7t
ijrlDCeJjOqvFnzb7YkUNGWhhLbKJK/vGsPa3T20XgtApS2NgREQrq8jQ/FPNnZZJ15rUV
Z4eg6i5lHKRTXFguBh+D3FF3P4ECs5jX5cHQPFrhmsE+jpsSTTqXBxTcb6ATCv6zamTzc3
16sfaMSRnU0Fg/D9dbx+OeHmb56b4pAAAAwQDZWWPx6fgtftO+HbjKfN2wrveUR6Mx8xxx
/J3m9uy2WNWEZ2NN6EL1x4/bk/KIcUxvVL7Kyev+f30YxSyGgjXnl5S13Uker7XtaG7lWJ
xZe9KaFo+tXOg6ThEf/IFPjcGjJxNfNwYaszzdyXoS9HmM6S0GUqbrF84IjFNCqsNtnK2I
L+Ha2sPh5OB4w+j/xdvWwdevCA11HE3MDqjN6Uq0EMKfAlEbgkqePQB+uiFhSf3laAybgm
KNj5a3Q/DLNfsAAAAbbWVocmRhZEBNLU1hY0Jvb2stUHJvLmxvY2Fs
-----END OPENSSH PRIVATE KEY-----`)
testSrvAddr = "127.0.0.1:5522"
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAaAAAABNlY2RzYS
1zaGEyLW5pc3RwMjU2AAAACG5pc3RwMjU2AAAAQQSX6mZs22fsQAv6Nd2kg7NA0eapBEkK
3+irgEu+QPypnNXl0xGEY5leAk3/zpIa0IqQL1pAdJF9Jx/SdLbi8u2LAAAAqGyonJ1sqJ
ydAAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJfqZmzbZ+xAC/o1
3aSDs0DR5qkESQrf6KuAS75A/Kmc1eXTEYRjmV4CTf/OkhrQipAvWkB0kX0nH9J0tuLy7Y
sAAAAhAO3EhXxaQHMkqPuI4XXfJFOnrxlhH0FfCtdWG+a6ByFJAAAACXJvb3RAZmluZQEC
AwQFBg==
-----END OPENSSH PRIVATE KEY-----`)
testSrvAddr = "127.0.0.1:3222"
testSrvStdin = ""
testSrvSig ssh.Signal
)
@ -155,7 +126,6 @@ func TestSessionsMaxOutRaceQueries(t *testing.T) {
if maxOutCount != 1 || nilErr != 1 {
t.Error("sessions maxout race failed")
}
}
func TestTimeout(t *testing.T) {
@ -209,62 +179,6 @@ func TestGoroutineLeak(t *testing.T) {
}
}
func TestQueryWithLabel(t *testing.T) {
vs := New()
vs.Start()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
timeout, _ := time.ParseDuration("2s")
config := GetConfigUserPass("vssh", "vssh")
labels := map[string]string{"POP": "ORD"}
vs.AddClient(testSrvAddr, config, SetLabels(labels), DisableRequestPty())
vs.Wait(100)
respChan, err := vs.RunWithLabel(ctx, "ping", "POP==LAX", timeout)
if err != nil {
t.Fatal(err)
}
// should be empty result
_, ok := <-respChan
if ok {
t.Fatal("expect to get false but got", ok)
}
respChan, err = vs.RunWithLabel(ctx, "ping", "POP==ORD", timeout)
if err != nil {
t.Fatal(err)
}
// should be return result
_, ok = <-respChan
if !ok {
t.Fatal("expect to get true but got", ok)
}
client, ok := vs.clients.get(testSrvAddr)
if !ok {
t.Error("expect test-client but not exist")
}
// persistent connection test
err = client.client.Close()
if err != nil {
t.Error("unexpected error as client should be open")
}
// wrong query
_, err = vs.RunWithLabel(ctx, "ping", "POP=ORD", timeout)
if err == nil {
t.Fatal(err)
}
}
func TestOnDemand(t *testing.T) {
vs := New().Start().OnDemand()

7
vssh/example_stream_test.go

@ -12,12 +12,9 @@ func Example_stream() {
var wg sync.WaitGroup
vs := New().Start()
config, err := GetConfigPEM("ubuntu", "myaws.pem")
if err != nil {
log.Fatal(err)
}
config := GetConfigUserPass("root", "gnu$2021")
for _, addr := range []string{"3.101.78.17:22", "13.57.12.15:22"} {
for _, addr := range []string{"104.233.130.113"} {
vs.AddClient(addr, config, SetMaxSessions(4))
}

174
vssh/query.go

@ -5,47 +5,19 @@ package vssh
import (
"context"
"errors"
"fmt"
"go/ast"
"go/parser"
"strconv"
"sync"
"time"
)
const (
idLogic = iota
idOp
idName
idValue
idEvald
opAnd = 34
opOr = 35
)
var errNotSupportOperator = errors.New("operator doesn't support")
var errQuery = errors.New("query error")
type query struct {
cmd string
ctx context.Context
respChan chan *Response
respTimeout time.Duration
stmt string
compiledQuery *visitor
limitReadOut int64
limitReadErr int64
}
type visitor struct {
idents []ident
}
type ident struct {
value string
vType int
cmd string
ctx context.Context
respChan chan *Response
respTimeout time.Duration
stmt string
limitReadOut int64
limitReadErr int64
}
func (q *query) errResp(id string, err error) {
@ -63,7 +35,7 @@ func (q *query) run(v *VSSH) {
go func() {
defer wg.Done()
if len(q.stmt) > 0 && !client.labelMatch(q.compiledQuery) {
if len(q.stmt) > 0 {
return
}
@ -89,133 +61,3 @@ func (q *query) run(v *VSSH) {
wg.Wait()
close(q.respChan)
}
func (f *visitor) Visit(n ast.Node) ast.Visitor {
switch d := n.(type) {
case *ast.Ident:
f.idents = append(f.idents, ident{d.Name, idName})
case *ast.BasicLit:
f.idents = append(f.idents, ident{d.Value, idValue})
case *ast.BinaryExpr:
if d.Op == opAnd || d.Op == opOr {
f.idents = append(f.idents, ident{d.Op.String(), idLogic})
} else {
f.idents = append(f.idents, ident{d.Op.String(), idOp})
}
}
return f
}
func parseExpr(expr string) (*visitor, error) {
e, err := parser.ParseExpr(expr)
if err != nil {
return nil, err
}
var v visitor
ast.Walk(&v, e)
return &v, nil
}
func exprEval(v *visitor, labels map[string]string) (bool, error) {
results := []ident{}
i := 0
for {
if len(v.idents[i:]) < 3 {
break
}
if v.idents[i].vType == 0 {
results = append(results, v.idents[i])
i++
continue
}
if relOpEval(v.idents[i:i+3], labels) {
results = append(results, ident{"true", idEvald})
} else {
results = append(results, ident{"false", idEvald})
}
if i+3 < len(v.idents) {
i = i + 3
} else {
break
}
}
if len(results) < 1 {
return false, errQuery
}
ok, err := binOpEval(&results)
if err != nil {
return ok, err
}
return ok, nil
}
func binOpEval(idents *[]ident) (bool, error) {
i := 0
for {
if len((*idents)) == 1 {
r, _ := strconv.ParseBool((*idents)[0].value)
return r, nil
}
if (*idents)[i].vType == 0 && (*idents)[i+1].vType == 0 {
i++
continue
}
if len((*idents)) > 2 && (*idents)[i].vType == idLogic &&
(*idents)[i+1].vType == idEvald && (*idents)[i+2].vType == idEvald {
r1, _ := strconv.ParseBool((*idents)[i+1].value)
r2, _ := strconv.ParseBool((*idents)[i+2].value)
if (*idents)[i].value == "&&" {
(*idents)[i].value = strconv.FormatBool(r1 && r2)
} else if (*idents)[i+0].value == "||" {
(*idents)[i].value = strconv.FormatBool(r1 || r2)
} else {
return false, errNotSupportOperator
}
(*idents)[i].vType = idEvald
if len((*idents)) > 2 {
(*idents) = append((*idents)[:i+1], (*idents)[i+3:]...)
i = 0
continue
} else {
tmp := (*idents)[0:1]
*idents = tmp
}
}
if len((*idents)) > 4 {
i += 2
} else {
return false, errNotSupportOperator
}
}
}
func relOpEval(idents []ident, labels map[string]string) bool {
if idents[0].value == "==" {
if _, ok := labels[idents[1].value]; ok {
return labels[idents[1].value] == idents[2].value
}
} else {
if _, ok := labels[idents[1].value]; ok {
return labels[idents[1].value] != idents[2].value
}
}
return false
}

87
vssh/query_test.go

@ -1,87 +0,0 @@
//: Copyright Verizon Media
//: Licensed under the terms of the Apache 2.0 License. See LICENSE file in the project root for terms.
package vssh
import (
"testing"
)
func TestQueryExprEval(t *testing.T) {
labels := map[string]string{"POP": "LAX", "OS": "JUNOS"}
exprTests := []struct {
expr string
expected bool
}{
{"POP==LAX", true},
{"POP!=LAX", false},
{"POP==LAX && OS==JUNOS", true},
{"POP==LAX && OS!=JUNOS", false},
{"(POP==LAX || POP==BUR) && OS==JUNOS", true},
{"OS==JUNOS && (POP==LAX || POP==BUR)", true},
{"OS!=JUNOS && (POP==LAX || POP==BUR)", false},
{"(OS==JUNOS) && (POP==LAX || POP==BUR)", true},
{"((OS==JUNOS) && (POP==LAX || POP==BUR))", true},
}
for _, x := range exprTests {
v, err := parseExpr(x.expr)
if err != nil {
t.Fatal(err)
}
ok, err := exprEval(v, labels)
if err != nil {
t.Fatal(err)
}
if ok != x.expected {
t.Fatalf("expect %t, got %t", x.expected, ok)
}
}
_, err := parseExpr("OS=JUNOS")
if err == nil {
t.Fatal("expect error but got nil")
}
v, err := parseExpr("OS")
if err != nil {
t.Fatal("expect error but got nil")
}
_, err = exprEval(v, labels)
if err == nil {
t.Fatal("expect error but got nil")
}
// not support operator
ops := []string{"&", "+", "<=", "<"}
for _, op := range ops {
v, _ := parseExpr("OS == JUNOS " + op + " POP == LAX")
_, err = exprEval(v, labels)
if err == nil {
t.Fatal("expect error but got nil")
}
}
}
func BenchmarkQueryExprEval(b *testing.B) {
labels := map[string]string{"POP": "LAX", "OS": "JUNOS"}
expr := "POP==LAX"
for i := 0; i < b.N; i++ {
v, err := parseExpr(expr)
if err != nil {
b.Fatal(err)
}
_, err = exprEval(v, labels)
if err != nil {
b.Fatal(err)
}
}
}

56
vssh/vssh.go

@ -56,7 +56,6 @@ type VSSH struct {
}
type stats struct {
queries uint64
connects uint64
processes uint64
}
@ -164,13 +163,6 @@ func DisableRequestPty() ClientOption {
}
}
// SetLabels sets labels for a client.
func SetLabels(labels map[string]string) ClientOption {
return func(c *clientAttr) {
c.labels = labels
}
}
func clientValidation(c *clientAttr) error {
if c.config == nil {
return errSSHConfig
@ -216,15 +208,9 @@ func (v *VSSH) process(ctx context.Context) {
go func() {
for {
select {
case a := <-v.actionQ:
switch b := a.(type) {
case *connect:
atomic.AddUint64(&v.stats.connects, 1)
b.run(v)
case *query:
atomic.AddUint64(&v.stats.queries, 1)
b.run(v)
}
case action := <-v.actionQ:
atomic.AddUint64(&v.stats.connects, 1)
action.run(v)
case <-v.procSig:
atomic.AddUint64(&v.stats.processes, ^uint64(0))
return
@ -305,42 +291,6 @@ func (v *VSSH) Run(ctx context.Context, cmd string, timeout time.Duration, opts
return respChan
}
// RunWithLabel runs the command on the specific clients which
// they matched with given query statement.
// labels := map[string]string {
// "POP" : "LAX",
// "OS" : "JUNOS",
// }
// // sets labels to a client
// vs.AddClient(addr, config, vssh.SetLabels(labels))
// // run the command with label
// vs.RunWithLabel(ctx, cmd, timeout, "POP == LAX || POP == DCA) && OS == JUNOS")
func (v *VSSH) RunWithLabel(ctx context.Context, cmd, queryStmt string, timeout time.Duration, opts ...RunOption) (chan *Response, error) {
vis, err := parseExpr(queryStmt)
if err != nil {
return nil, err
}
respChan := make(chan *Response, 100)
q := &query{
ctx: ctx,
cmd: cmd,
stmt: queryStmt,
compiledQuery: vis,
respChan: respChan,
respTimeout: timeout,
}
for _, opt := range opts {
opt(q)
}
v.actionQ <- q
return respChan, nil
}
// SetLimitReaderStdout sets limit for stdout reader.
// respChan := vs.Run(ctx, cmd, timeout, vssh.SetLimitReaderStdout(1024))
func SetLimitReaderStdout(n int64) RunOption {

18
vssh/vssh_test.go

@ -111,24 +111,6 @@ func TestRequestPty(t *testing.T) {
}
}
func TestSetLimitReaderStdout(t *testing.T) {
f := SetLimitReaderStdout(1024)
q := &query{}
f(q)
if q.limitReadOut != 1024 {
t.Error("expect limitReadOut 1024 but got", q.limitReadOut)
}
}
func TestSetLimitReaderStderr(t *testing.T) {
f := SetLimitReaderStderr(1024)
q := &query{}
f(q)
if q.limitReadErr != 1024 {
t.Error("expect limitReadErr 1024 but got", q.limitReadErr)
}
}
func TestGetConfigPEM(t *testing.T) {
_, err := GetConfigPEM("vssh", "notexitfile")
if err == nil {

1
vsshclient.go

@ -1 +0,0 @@
package stup
Loading…
Cancel
Save