package main

import (
	"bufio"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"os"
	"os/user"
	"path"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	"git.wecise.com/wecise/common/logger"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/terminal"
)

// DefaultCiphers lists the cipher names that genSSHConfig appends to the
// ssh library defaults when building a client configuration.
var (
	DefaultCiphers = []string{
		"aes128-ctr",
		"aes192-ctr",
		"aes256-ctr",
		"aes128-gcm@openssh.com",
		"chacha20-poly1305@openssh.com",
		"arcfour256",
		"arcfour128",
		"arcfour",
		"aes128-cbc",
		"3des-cbc",
		"blowfish-cbc",
		"cast128-cbc",
		"aes192-cbc",
		"aes256-cbc",
	}
)

// Client is something that can open an interactive login session.
type Client interface {
	Login()
}

// defaultClient is the Client implementation backed by golang.org/x/crypto/ssh.
type defaultClient struct {
	clientConfig *ssh.ClientConfig
	node         *Node
}

// promptKeyboardInteractive answers keyboard-interactive challenges by
// prompting on the local terminal. Questions flagged for echo are read as a
// line from stdin; the others are read with terminal echo disabled.
func promptKeyboardInteractive(name, instruction string, questions []string, echos []bool) ([]string, error) {
	answers := make([]string, 0, len(questions))
	// One scanner for the whole challenge so input buffered while answering
	// one question is not thrown away before the next.
	stdin := bufio.NewScanner(os.Stdin)
	for i, q := range questions {
		fmt.Print(q)
		if echos[i] {
			stdin.Scan()
			if err := stdin.Err(); err != nil {
				return nil, err
			}
			// On EOF Text() is empty; still record an answer so the number
			// of answers always matches the number of questions.
			answers = append(answers, stdin.Text())
			continue
		}
		b, err := terminal.ReadPassword(int(syscall.Stdin))
		if err != nil {
			return nil, err
		}
		fmt.Println()
		answers = append(answers, string(b))
	}
	return answers, nil
}

// genSSHConfig builds the ssh client configuration for node.
//
// Authentication methods are offered in this order: the private key at
// node.KeyPath (or ~/.ssh/id_rsa when KeyPath is empty), the method returned
// by node.password() if non-nil, and finally keyboard-interactive prompts.
// A missing or unparsable key is logged (missing files are silent) and
// skipped. It returns nil only when the default key location is needed and
// the current user cannot be determined.
func genSSHConfig(node *Node) *defaultClient {
	keyPath := node.KeyPath
	if keyPath == "" {
		// The home directory is only needed for the default key location.
		u, err := user.Current()
		if err != nil {
			logger.Error(err)
			return nil
		}
		keyPath = path.Join(u.HomeDir, ".ssh/id_rsa")
	}

	var authMethods []ssh.AuthMethod

	pemBytes, err := ioutil.ReadFile(keyPath)
	switch {
	case err != nil && !os.IsNotExist(err):
		logger.Error(err)
	case len(pemBytes) > 0:
		var signer ssh.Signer
		if node.Passphrase != "" {
			signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(node.Passphrase))
		} else {
			signer, err = ssh.ParsePrivateKey(pemBytes)
		}
		if err != nil {
			logger.Error(err)
		} else {
			authMethods = append(authMethods, ssh.PublicKeys(signer))
		}
	}

	if password := node.password(); password != nil {
		authMethods = append(authMethods, password)
	}
	authMethods = append(authMethods, ssh.KeyboardInteractive(promptKeyboardInteractive))

	config := &ssh.ClientConfig{
		User: node.user(),
		Auth: authMethods,
		// NOTE(review): host keys are not verified, so connections are open
		// to man-in-the-middle attacks. Kept as-is for compatibility.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         time.Second * 10,
	}
	config.SetDefaults()
	config.Ciphers = append(config.Ciphers, DefaultCiphers...)

	return &defaultClient{
		clientConfig: config,
		node:         node,
	}
}

// NewClient returns a Client for node, or nil when no configuration could be
// built for it.
func NewClient(node *Node) Client {
	c := genSSHConfig(node)
	if c == nil {
		// Return an untyped nil: wrapping a nil *defaultClient in the
		// interface would make the result compare unequal to nil.
		return nil
	}
	return c
}

// dialDirect connects straight to addr. When the server rejects every
// offered method and the error does not mention "password", it asks for a
// password on the terminal and retries once with it added.
func (c *defaultClient) dialDirect(addr string) (*ssh.Client, error) {
	client, err := ssh.Dial("tcp", addr, c.clientConfig)
	if err == nil {
		return client, nil
	}
	msg := err.Error()
	if !strings.Contains(msg, "no supported methods remain") || strings.Contains(msg, "password") {
		return nil, err
	}

	// use terminal password retry
	fmt.Printf("%s@%s's password:", c.clientConfig.User, c.node.Host)
	b, err := terminal.ReadPassword(int(syscall.Stdin))
	if err != nil {
		return nil, err
	}
	if p := string(b); p != "" {
		c.clientConfig.Auth = append(c.clientConfig.Auth, ssh.Password(p))
	}
	fmt.Println()
	return ssh.Dial("tcp", addr, c.clientConfig)
}

// connect opens the ssh connection to the node, going through the first
// configured jump host when there is one (further jump hosts are ignored).
// proxy is the jump-host connection, or nil for a direct connection; the
// caller must close it after closing client.
func (c *defaultClient) connect() (client *ssh.Client, proxy *ssh.Client, err error) {
	addr := net.JoinHostPort(c.node.Host, strconv.Itoa(c.node.port()))
	if len(c.node.Jump) == 0 {
		client, err = c.dialDirect(addr)
		return client, nil, err
	}

	jNode := c.node.Jump[0]
	jc := genSSHConfig(jNode)
	if jc == nil {
		return nil, nil, fmt.Errorf("no ssh config for jump host %s", jNode.Host)
	}
	proxy, err = ssh.Dial("tcp", net.JoinHostPort(jNode.Host, strconv.Itoa(jNode.port())), jc.clientConfig)
	if err != nil {
		return nil, nil, err
	}
	conn, err := proxy.Dial("tcp", addr)
	if err != nil {
		proxy.Close()
		return nil, nil, err
	}
	ncc, chans, reqs, err := ssh.NewClientConn(conn, addr, c.clientConfig)
	if err != nil {
		conn.Close()
		proxy.Close()
		return nil, nil, err
	}
	return ssh.NewClient(ncc, chans, reqs), proxy, nil
}

// Login opens a shell on the node and drives it until the session ends.
//
// Remote output is copied to the local stdout/stderr and local stdin is
// forwarded to the remote shell. Whenever output goes quiet, the recent
// output is inspected: a password prompt is answered with the password of
// the current command (falling back to the first command's password), and a
// shell prompt advances to the next entry of node.Commands and sends its Cmd,
// or "exit" once the list is exhausted. Errors are logged, not returned.
func (c *defaultClient) Login() {
	host := c.node.Host

	client, proxy, err := c.connect()
	if err != nil {
		logger.Error(err)
		return
	}
	if proxy != nil {
		// Registered first so it runs last, after client is closed.
		defer proxy.Close()
	}
	defer client.Close()

	logger.Warnf("connect server ssh -p %d %s@%s version: %s\n", c.node.port(), c.node.user(), host, string(client.ServerVersion()))

	session, err := client.NewSession()
	if err != nil {
		logger.Error(err)
		return
	}
	defer session.Close()

	fd := int(os.Stdin.Fd())
	w, h, err := terminal.GetSize(fd)
	if err != nil {
		logger.Error(err)
		return
	}

	modes := ssh.TerminalModes{
		ssh.ECHO:          0,
		ssh.TTY_OP_ISPEED: 14400,
		ssh.TTY_OP_OSPEED: 14400,
	}
	if err = session.RequestPty("xterm", h, w, modes); err != nil {
		logger.Error(err)
		return
	}

	// Every pipe is required below; a nil pipe would panic on first use, so
	// a failure here ends the login.
	stdinPipe, err := session.StdinPipe()
	if err != nil {
		logger.Error(err)
		return
	}
	stdoutPipe, err := session.StdoutPipe()
	if err != nil {
		logger.Error(err)
		return
	}
	stderrPipe, err := session.StderrPipe()
	if err != nil {
		logger.Error(err)
		return
	}

	if err = session.Shell(); err != nil {
		logger.Error(err)
		return
	}

	// done is closed when Login returns so the helper goroutines stop.
	done := make(chan struct{})
	defer close(done)

	regxpassword := regexp.MustCompile(`.*[Pp]assword: *$`)
	regxprompt := regexp.MustCompile(`.*[%#\$\>]\s*$`)

	var (
		mu     sync.Mutex // guards cmdidx, shared by the stdout and stderr watchers
		cmdidx int
	)
	commands := c.node.Commands

	// sendRemote writes s to the remote shell and logs a failed write.
	sendRemote := func(s string) {
		if _, err := io.WriteString(stdinPipe, s); err != nil {
			logger.Error(err)
		}
	}

	// respond reacts to the tail of the remote output once it has gone quiet.
	respond := func(recent string, local io.Writer) {
		mu.Lock()
		defer mu.Unlock()

		if len(commands) > 0 && regxpassword.MatchString(recent) {
			p := ""
			if cmdidx < len(commands) {
				p = commands[cmdidx].Password
			}
			if p == "" {
				p = commands[0].Password
			}
			// don't echo password locally
			sendRemote(p + "\n")
		}
		if regxprompt.MatchString(recent) {
			cmdidx++
			line := "exit" + "\n"
			if cmdidx < len(commands) {
				line = commands[cmdidx].Cmd + "\n"
			}
			_, _ = io.WriteString(local, line) // local echo is best-effort
			sendRemote(line)
		}
	}

	// watch mirrors one remote stream to local and triggers respond whenever
	// the stream pauses after producing output.
	watch := func(src *BTReader, local io.Writer) {
		recent := ""
		changed := false
		for {
			bs, err := src.ReadTimeout(10 * time.Millisecond)
			if err != nil {
				if err != io.EOF {
					logger.Error(err)
				}
				return
			}
			if len(bs) > 0 {
				changed = true
				_, _ = local.Write(bs) // local echo is best-effort
				recent += string(bs)
				continue
			}
			if !changed {
				continue
			}
			changed = false
			respond(recent, local)
			// Only the tail matters for prompt detection.
			if len(recent) > 1024 {
				recent = recent[len(recent)-1024:]
			}
		}
	}
	go watch(&BTReader{Reader: bufio.NewReader(stdoutPipe)}, os.Stdout)
	go watch(&BTReader{Reader: bufio.NewReader(stderrPipe)}, os.Stderr)

	// Forward local keystrokes to the remote shell. The goroutine ends on a
	// stdin error or as soon as the remote side stops accepting input.
	go func() {
		buf := make([]byte, 1024)
		for {
			n, err := os.Stdin.Read(buf)
			if n > 0 {
				if _, werr := stdinPipe.Write(buf[:n]); werr != nil {
					return
				}
			}
			if err != nil {
				if err != io.EOF {
					logger.Error(err)
				}
				return
			}
		}
	}()

	// Poll the terminal size once a second and propagate changes
	// (fixes the resize issue).
	go func() {
		ow, oh := w, h
		tick := time.NewTicker(time.Second)
		defer tick.Stop()
		for {
			cw, ch, err := terminal.GetSize(fd)
			if err != nil {
				return
			}
			if cw != ow || ch != oh {
				if err = session.WindowChange(ch, cw); err != nil {
					return
				}
				ow, oh = cw, ch
			}
			select {
			case <-done:
				return
			case <-tick.C:
			}
		}
	}()

	// Send a keepalive request every ten seconds until the session ends.
	go func() {
		tick := time.NewTicker(time.Second * 10)
		defer tick.Stop()
		for {
			select {
			case <-done:
				return
			case <-tick.C:
				if _, _, err := client.SendRequest("nop", false, nil); err != nil {
					return
				}
			}
		}
	}()

	// The shell's exit status is not reported; only the disconnect is logged.
	_ = session.Wait()
	logger.Warnf("disconnected")
}

// BTReader adds a read-with-timeout operation on top of a bufio.Reader.
// A BTReader must be driven by a single goroutine.
type BTReader struct {
	*bufio.Reader
	bufop int32
	chbs  chan []byte
	err   error // terminal read error; written before chbs is closed
}

// pump moves data from the underlying reader to chbs one buffered chunk at a
// time. On a read error it records the error and closes chbs.
func (r *BTReader) pump() {
	defer close(r.chbs)
	for {
		// Block until at least one byte is buffered.
		if _, err := r.Peek(1); err != nil {
			r.err = err
			return
		}
		// A fresh slice per chunk: the receiver may still be using the
		// previous one while the next read is in progress.
		chunk := make([]byte, r.Buffered())
		n, err := r.Read(chunk)
		if err != nil {
			r.err = err
			return
		}
		r.chbs <- chunk[:n]
	}
}

// ReadTimeout returns the next chunk of available data, waiting at most d.
// It returns (nil, nil) when nothing arrived in time, and (nil, err) once the
// underlying reader has failed; io.EOF marks the end of the stream and is
// returned on every subsequent call.
func (r *BTReader) ReadTimeout(d time.Duration) (rbs []byte, err error) {
	if r.chbs == nil {
		r.chbs = make(chan []byte)
		go r.pump()
	}
	t := time.NewTimer(d)
	defer t.Stop()
	select {
	case bs, ok := <-r.chbs:
		if !ok {
			// Closing chbs happens after r.err is set, so this read is safe.
			return nil, r.err
		}
		return bs, nil
	case <-t.C:
		return nil, nil
	}
}