123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373 |
- package main
import (
	"bufio"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"os"
	"os/user"
	"path"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	"git.wecise.com/wecise/common/logger"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/terminal"
)
// DefaultCiphers lists extra cipher names that genSSHConfig appends after the
// golang.org/x/crypto/ssh default cipher list, presumably so that servers
// which only offer older algorithms can still be reached.
//
// NOTE(review): several entries are weak (arcfour*, 3des-cbc, the *-cbc
// modes), and some (e.g. blowfish-cbc, cast128-cbc) may not be implemented by
// x/crypto/ssh at all; unsupported names are expected to be dropped when the
// library applies its config defaults — confirm against the vendored version.
var DefaultCiphers = []string{
	// AES in CTR mode.
	"aes128-ctr", "aes192-ctr", "aes256-ctr",
	// AEAD ciphers.
	"aes128-gcm@openssh.com", "chacha20-poly1305@openssh.com",
	// RC4 variants (legacy).
	"arcfour256", "arcfour128", "arcfour",
	// CBC modes (legacy).
	"aes128-cbc", "3des-cbc", "blowfish-cbc", "cast128-cbc", "aes192-cbc", "aes256-cbc",
}
- type Client interface {
- Login()
- }
- type defaultClient struct {
- clientConfig *ssh.ClientConfig
- node *Node
- }
- func genSSHConfig(node *Node) *defaultClient {
- u, err := user.Current()
- if err != nil {
- logger.Error(err)
- return nil
- }
- var authMethods []ssh.AuthMethod
- var pemBytes []byte
- if node.KeyPath == "" {
- pemBytes, err = ioutil.ReadFile(path.Join(u.HomeDir, ".ssh/id_rsa"))
- } else {
- pemBytes, err = ioutil.ReadFile(node.KeyPath)
- }
- if err != nil && !os.IsNotExist(err) {
- logger.Error(err)
- } else if len(pemBytes) > 0 {
- var signer ssh.Signer
- if node.Passphrase != "" {
- signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(node.Passphrase))
- } else {
- signer, err = ssh.ParsePrivateKey(pemBytes)
- }
- if err != nil {
- logger.Error(err)
- } else {
- authMethods = append(authMethods, ssh.PublicKeys(signer))
- }
- }
- password := node.password()
- if password != nil {
- authMethods = append(authMethods, password)
- }
- authMethods = append(authMethods, ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
- answers := make([]string, 0, len(questions))
- for i, q := range questions {
- fmt.Print(q)
- if echos[i] {
- scan := bufio.NewScanner(os.Stdin)
- if scan.Scan() {
- answers = append(answers, scan.Text())
- }
- err := scan.Err()
- if err != nil {
- return nil, err
- }
- } else {
- b, err := terminal.ReadPassword(int(syscall.Stdin))
- if err != nil {
- return nil, err
- }
- fmt.Println()
- answers = append(answers, string(b))
- }
- }
- return answers, nil
- }))
- config := &ssh.ClientConfig{
- User: node.user(),
- Auth: authMethods,
- HostKeyCallback: ssh.InsecureIgnoreHostKey(),
- Timeout: time.Second * 10,
- }
- config.SetDefaults()
- config.Ciphers = append(config.Ciphers, DefaultCiphers...)
- return &defaultClient{
- clientConfig: config,
- node: node,
- }
- }
- func NewClient(node *Node) Client {
- return genSSHConfig(node)
- }
- func (c *defaultClient) Login() {
- host := c.node.Host
- port := strconv.Itoa(c.node.port())
- jNodes := c.node.Jump
- var client *ssh.Client
- if len(jNodes) > 0 {
- jNode := jNodes[0]
- jc := genSSHConfig(jNode)
- proxyClient, err := ssh.Dial("tcp", net.JoinHostPort(jNode.Host, strconv.Itoa(jNode.port())), jc.clientConfig)
- if err != nil {
- logger.Error(err)
- return
- }
- conn, err := proxyClient.Dial("tcp", net.JoinHostPort(host, port))
- if err != nil {
- logger.Error(err)
- return
- }
- ncc, chans, reqs, err := ssh.NewClientConn(conn, net.JoinHostPort(host, port), c.clientConfig)
- if err != nil {
- logger.Error(err)
- return
- }
- client = ssh.NewClient(ncc, chans, reqs)
- } else {
- client1, err := ssh.Dial("tcp", net.JoinHostPort(host, port), c.clientConfig)
- client = client1
- if err != nil {
- msg := err.Error()
- // use terminal password retry
- if strings.Contains(msg, "no supported methods remain") && !strings.Contains(msg, "password") {
- fmt.Printf("%s@%s's password:", c.clientConfig.User, host)
- var b []byte
- b, err = terminal.ReadPassword(int(syscall.Stdin))
- if err == nil {
- p := string(b)
- if p != "" {
- c.clientConfig.Auth = append(c.clientConfig.Auth, ssh.Password(p))
- }
- fmt.Println()
- client, err = ssh.Dial("tcp", net.JoinHostPort(host, port), c.clientConfig)
- }
- }
- }
- if err != nil {
- logger.Error(err)
- return
- }
- }
- defer client.Close()
- logger.Warnf("connect server ssh -p %d %s@%s version: %s\n", c.node.port(), c.node.user(), host, string(client.ServerVersion()))
- session, err := client.NewSession()
- if err != nil {
- logger.Error(err)
- return
- }
- defer session.Close()
- fd := int(os.Stdin.Fd())
- w, h, err := terminal.GetSize(fd)
- if err != nil {
- logger.Error(err)
- return
- }
- modes := ssh.TerminalModes{
- ssh.ECHO: 0,
- ssh.TTY_OP_ISPEED: 14400,
- ssh.TTY_OP_OSPEED: 14400,
- }
- err = session.RequestPty("xterm", h, w, modes)
- if err != nil {
- logger.Error(err)
- return
- }
- stdinPipe, err := session.StdinPipe()
- if err != nil {
- logger.Error(err)
- session.Stdin = os.Stdin
- }
- stdoutPipe, err := session.StdoutPipe()
- if err != nil {
- logger.Error(err)
- session.Stdout = os.Stdout
- }
- stderrPipe, err := session.StderrPipe()
- if err != nil {
- logger.Error(err)
- session.Stderr = os.Stderr
- }
- err = session.Shell()
- if err != nil {
- logger.Error(err)
- return
- }
- cmdidx := 0
- regxpassword := regexp.MustCompile(`.*[Pp]assword: *$`)
- regxprompt := regexp.MustCompile(`.*[%#\$\>]\s*$`)
- outputproc := func(stdoutbtr *BTReader, localstdout io.Writer, stdinPipe io.Writer) {
- as := ""
- change := false
- for {
- bs, err := stdoutbtr.ReadTimeout(10 * time.Millisecond)
- if err != nil {
- if err == io.EOF {
- return
- }
- logger.Error(err)
- return
- }
- if len(bs) > 0 {
- change = true
- s := string(bs)
- localstdout.Write([]byte(s))
- as += s
- } else if change {
- change = false
- if regxpassword.MatchString(as) {
- p := c.node.Commands[cmdidx].Password
- if p == "" {
- p = c.node.Commands[0].Password
- }
- // don't echo password, os.Stdout.Write([]byte(p + "\n"))
- stdinPipe.Write([]byte(p + "\n"))
- }
- if regxprompt.MatchString(as) {
- cmdidx++
- if cmdidx >= len(c.node.Commands) {
- localstdout.Write([]byte("exit" + "\n"))
- stdinPipe.Write([]byte("exit" + "\n"))
- } else {
- localstdout.Write([]byte(c.node.Commands[cmdidx].Cmd + "\n"))
- stdinPipe.Write([]byte(c.node.Commands[cmdidx].Cmd + "\n"))
- }
- }
- if len(as) > 1024 {
- as = as[len(as)-1024:]
- }
- }
- }
- }
- go outputproc(&BTReader{Reader: bufio.NewReader(stdoutPipe)}, os.Stdout, stdinPipe)
- go outputproc(&BTReader{Reader: bufio.NewReader(stderrPipe)}, os.Stderr, stdinPipe)
- go func() {
- for {
- bs := make([]byte, 1024)
- n, err := os.Stdin.Read(bs)
- if err != nil {
- if err == io.EOF {
- return
- }
- logger.Error(err)
- return
- }
- s := string(bs[:n])
- stdinPipe.Write([]byte(s))
- }
- }()
- // interval get terminal size
- // fix resize issue
- go func() {
- var (
- ow = w
- oh = h
- )
- for {
- cw, ch, err := terminal.GetSize(fd)
- if err != nil {
- break
- }
- if cw != ow || ch != oh {
- err = session.WindowChange(ch, cw)
- if err != nil {
- break
- }
- ow = cw
- oh = ch
- }
- time.Sleep(time.Second)
- }
- }()
- // send keepalive
- go func() {
- for {
- time.Sleep(time.Second * 10)
- client.SendRequest("nop", false, nil)
- }
- }()
- session.Wait()
- logger.Warnf("disconnected")
- }
// BTReader is a bufio.Reader whose reads can time out (see ReadTimeout).
//
// The first ReadTimeout call starts a background goroutine that owns the
// underlying reader from then on, so the embedded bufio.Reader's read methods
// must not be called directly afterwards. The lazy start is unsynchronized:
// ReadTimeout must be called from one goroutine at a time.
type BTReader struct {
	*bufio.Reader
	bufop int32       // unused; kept for compatibility
	chbs  chan []byte // chunks from the background reader; closed when it stops
	rerr  error       // error that stopped the background reader; valid once chbs is closed
}

// ReadTimeout waits up to d for the next chunk of data.
//
// It returns:
//   - (chunk, nil) when data arrived. The chunk is freshly allocated and
//     owned by the caller;
//   - (nil, nil) when d elapsed with no data;
//   - (nil, err) once the underlying reader has failed (io.EOF included).
//     Every later call returns the same error.
func (me *BTReader) ReadTimeout(d time.Duration) (rbs []byte, err error) {
	if me.chbs == nil {
		me.chbs = make(chan []byte)
		go me.readLoop()
	}
	t := time.NewTimer(d)
	defer t.Stop()
	select {
	case bs, ok := <-me.chbs:
		if !ok {
			// rerr was written before the channel was closed, so reading it
			// here is race-free.
			return nil, me.rerr
		}
		return bs, nil
	case <-t.C:
		return nil, nil
	}
}

// readLoop runs in its own goroutine. It blocks until data is buffered, sends
// everything currently buffered as one chunk, and repeats. When the
// underlying reader fails it records the error in rerr and closes chbs.
// NOTE(review): if the consumer stops calling ReadTimeout while data is
// pending, this goroutine stays blocked on the send.
func (me *BTReader) readLoop() {
	defer close(me.chbs)
	for {
		// ReadByte blocks until at least one byte is available; unreading it
		// lets the following Read return it with the rest of the buffer.
		if _, err := me.ReadByte(); err != nil {
			me.rerr = err
			return
		}
		if err := me.UnreadByte(); err != nil {
			me.rerr = err
			return
		}
		// A new slice per chunk: the consumer may still be using the last one.
		bs := make([]byte, me.Buffered())
		n, err := me.Read(bs)
		if n > 0 {
			me.chbs <- bs[:n]
		}
		if err != nil {
			me.rerr = err
			return
		}
	}
}
|