package main import ( "bufio" "encoding/base64" "fmt" "io" "io/ioutil" "net" "os" "os/user" "path" "regexp" "strconv" "strings" "syscall" "time" "git.wecise.com/wecise/odbserver/matrix/logger" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/terminal" ) var ( DefaultCiphers = []string{ "aes128-ctr", "aes192-ctr", "aes256-ctr", "aes128-gcm@openssh.com", "chacha20-poly1305@openssh.com", "arcfour256", "arcfour128", "arcfour", "aes128-cbc", "3des-cbc", "blowfish-cbc", "cast128-cbc", "aes192-cbc", "aes256-cbc", } ) var regxyesno = &Regexp{regexp.MustCompile(`.*(\(yes\/no\)\?)\s*$`)} var regxpassword = &Regexp{regexp.MustCompile(`.*([Pp]assword:)\s*$`)} var regxprompt = &Regexp{regexp.MustCompile(`.*([%#\$\>])\s*$`)} type Client interface { Login() } type defaultClient struct { clientConfig *ssh.ClientConfig node *Node } func genSSHConfig(node *Node) *defaultClient { u, err := user.Current() if err != nil { logger.Error(err) return nil } var authMethods []ssh.AuthMethod var pemBytes []byte if node.KeyPath == "" { pemBytes, err = ioutil.ReadFile(path.Join(u.HomeDir, ".ssh/id_rsa")) } else { pemBytes, err = ioutil.ReadFile(node.KeyPath) } if err != nil && !os.IsNotExist(err) { logger.Error(err) } else if len(pemBytes) > 0 { var signer ssh.Signer if node.Passphrase != "" { signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(node.Passphrase)) } else { signer, err = ssh.ParsePrivateKey(pemBytes) } if err != nil { logger.Error(err) } else { authMethods = append(authMethods, ssh.PublicKeys(signer)) } } password := node.password() if password != nil { authMethods = append(authMethods, password) } authMethods = append(authMethods, ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) { answers := make([]string, 0, len(questions)) for i, q := range questions { fmt.Print(q) if echos[i] { scan := bufio.NewScanner(os.Stdin) if scan.Scan() { answers = append(answers, scan.Text()) } err := scan.Err() if err != nil { return nil, err } } else { b, err := terminal.ReadPassword(int(syscall.Stdin)) if err != nil { return nil, err } fmt.Println() answers = append(answers, string(b)) } } return answers, nil })) config := &ssh.ClientConfig{ User: node.user(), Auth: authMethods, HostKeyCallback: ssh.InsecureIgnoreHostKey(), Timeout: time.Second * 10, } config.SetDefaults() config.Ciphers = append(config.Ciphers, DefaultCiphers...) return &defaultClient{ clientConfig: config, node: node, } } func NewClient(node *Node) Client { return genSSHConfig(node) } func (c *defaultClient) Login() { host := c.node.Host port := strconv.Itoa(c.node.port()) jNodes := c.node.Jump var client *ssh.Client if len(jNodes) > 0 { jNode := jNodes[0] jc := genSSHConfig(jNode) proxyClient, err := ssh.Dial("tcp", net.JoinHostPort(jNode.Host, strconv.Itoa(jNode.port())), jc.clientConfig) if err != nil { logger.Error(err) return } conn, err := proxyClient.Dial("tcp", net.JoinHostPort(host, port)) if err != nil { logger.Error(err) return } ncc, chans, reqs, err := ssh.NewClientConn(conn, net.JoinHostPort(host, port), c.clientConfig) if err != nil { logger.Error(err) return } client = ssh.NewClient(ncc, chans, reqs) } else { client1, err := ssh.Dial("tcp", net.JoinHostPort(host, port), c.clientConfig) client = client1 if err != nil { msg := err.Error() // use terminal password retry if strings.Contains(msg, "no supported methods remain") && !strings.Contains(msg, "password") { fmt.Printf("%s@%s's password:", c.clientConfig.User, host) var b []byte b, err = terminal.ReadPassword(int(syscall.Stdin)) if err == nil { p := string(b) if p != "" { c.clientConfig.Auth = append(c.clientConfig.Auth, ssh.Password(p)) } fmt.Println() client, err = ssh.Dial("tcp", net.JoinHostPort(host, port), c.clientConfig) } } } if err != nil { logger.Error(err) return } } defer client.Close() if strings.Index(c.node.Commands[0].Regexps[0].Debug, "p") >= 0 || strings.Index(c.node.Commands[0].Regexps[0].Debug, "1") >= 0 { logger.Trace("connect server ssh", fmt.Sprintf("-p %d %s@%s version: %s", c.node.port(), c.node.user(), host, string(client.ServerVersion()))) } session, err := client.NewSession() if err != nil { logger.Error(err) return } defer session.Close() fd := int(os.Stdin.Fd()) w, h, err := terminal.GetSize(fd) if err != nil { logger.Error(err) return } modes := ssh.TerminalModes{ ssh.ECHO: 0, ssh.TTY_OP_ISPEED: 14400, ssh.TTY_OP_OSPEED: 14400, } err = session.RequestPty("xterm", h, w, modes) if err != nil { logger.Error(err) return } stdinPipe, err := session.StdinPipe() if err != nil { logger.Error(err) session.Stdin = os.Stdin } stdoutPipe, err := session.StdoutPipe() if err != nil { logger.Error(err) session.Stdout = os.Stdout } stderrPipe, err := session.StderrPipe() if err != nil { logger.Error(err) session.Stderr = os.Stderr } err = session.Shell() if err != nil { logger.Error(err) return } cmdidx := 0 outputproc := func(stdoutbtr *BTReader, localstdout io.Writer, stdinPipe io.Writer) { as := "" change := false for { bs, err := stdoutbtr.ReadTimeout(10 * time.Millisecond) if err != nil { if err == io.EOF { return } logger.Error(err) return } if len(bs) > 0 { change = true s := string(bs) localstdout.Write([]byte(s)) as += s } else if change { change = false if cmdidx < len(c.node.Commands) { if msi := c.node.Commands[cmdidx].Endregx.FindStringSubmatchIndex(as); msi != nil { match := as[msi[0]:msi[1]] if strings.Index(c.node.Commands[cmdidx].Regexps[0].Debug, "m") >= 0 || strings.Index(c.node.Commands[cmdidx].Regexps[0].Debug, "1") >= 0 { logger.Trace("match end:", "'"+match+"'") } as = "" // 全清,开始新的命令 cmdidx++ if cmdidx >= len(c.node.Commands) { continue } if strings.Index(c.node.Commands[cmdidx].Regexps[0].Debug, "p") >= 0 || strings.Index(c.node.Commands[cmdidx].Regexps[0].Debug, "1") >= 0 { logger.Trace("command:", c.node.Commands[cmdidx].Cmd) } localstdout.Write([]byte(c.node.Commands[cmdidx].Cmd + "\n")) stdinPipe.Write([]byte(c.node.Commands[cmdidx].Cmd + "\n")) continue } for _, regex := range c.node.Commands[cmdidx].Regexps { if regex == nil { continue } if regex.Regexp == nil { // like match all if regex.Output != "" { localstdout.Write([]byte(regex.Output)) stdinPipe.Write([]byte(regex.Output)) } } else if msi := regex.Regexp.FindStringSubmatchIndex(as); msi != nil { match := as[msi[0]:msi[1]] if len(msi) >= 4 { match = as[msi[2]:msi[3]] as = as[msi[3]:] // 清除已处理完的内容 } else { as = as[msi[1]:] // 清除已处理完的内容 } if strings.Index(regex.Debug, "m") >= 0 || strings.Index(regex.Debug, "1") >= 0 { logger.Trace("match regex:", "'"+match+"'") } if regex.Output != "" { localstdout.Write([]byte(regex.Output)) stdinPipe.Write([]byte(regex.Output)) } } } } if msi := regxyesno.FindStringSubmatchIndex(as); msi != nil { match := as[msi[0]:msi[1]] if len(msi) >= 4 { as = as[msi[3]:] // 清除已处理完的内容 } else { as = as[msi[1]:] // 清除已处理完的内容 } if strings.Index(c.node.Commands[0].Regexps[0].Debug, "m") >= 0 || strings.Index(c.node.Commands[0].Regexps[0].Debug, "1") >= 0 || cmdidx < len(c.node.Commands) && (strings.Index(c.node.Commands[cmdidx].Regexps[0].Debug, "m") >= 0 || strings.Index(c.node.Commands[cmdidx].Regexps[0].Debug, "1") >= 0) { logger.Trace("match yesno:", "'"+match+"'") } os.Stdout.Write([]byte("yes\n")) stdinPipe.Write([]byte("yes\n")) } if msi := regxpassword.FindStringSubmatchIndex(as); msi != nil { match := as[msi[0]:msi[1]] if len(msi) >= 4 { as = as[msi[3]:] // 清除已处理完的内容 } else { as = as[msi[1]:] // 清除已处理完的内容 } p := c.node.Commands[0].Password if cmdidx < len(c.node.Commands) { p = c.node.Commands[cmdidx].Password } if strings.Index(c.node.Commands[0].Regexps[0].Debug, "m") >= 0 || strings.Index(c.node.Commands[0].Regexps[0].Debug, "1") >= 0 || cmdidx < len(c.node.Commands) && (strings.Index(c.node.Commands[cmdidx].Regexps[0].Debug, "m") >= 0 || strings.Index(c.node.Commands[cmdidx].Regexps[0].Debug, "1") >= 0) { logger.Trace("match password:", "'"+match+"'") } if p != "" { if p[0:1] == "=" { p = p[1:] } else { x, e := base64.RawStdEncoding.DecodeString(p) if e == nil { p = string(x) } // else 不是Base64编码,保持原值 } // don't echo password if c.node.Commands[0].Regexps[0].Debug != "" || cmdidx < len(c.node.Commands) && c.node.Commands[cmdidx].Regexps[0].Debug != "" { os.Stdout.Write([]byte(p + "\n")) } stdinPipe.Write([]byte(p + "\n")) } } if len(as) > 1024 { as = as[len(as)-1024:] } } } } go outputproc(&BTReader{Reader: bufio.NewReader(stdoutPipe)}, os.Stdout, stdinPipe) go outputproc(&BTReader{Reader: bufio.NewReader(stderrPipe)}, os.Stderr, stdinPipe) go func() { for { bs := make([]byte, 1024) n, err := os.Stdin.Read(bs) if err != nil { if err == io.EOF { return } logger.Error(err) return } s := string(bs[:n]) stdinPipe.Write([]byte(s)) } }() // interval get terminal size // fix resize issue go func() { var ( ow = w oh = h ) for { cw, ch, err := terminal.GetSize(fd) if err != nil { break } if cw != ow || ch != oh { err = session.WindowChange(ch, cw) if err != nil { break } ow = cw oh = ch } time.Sleep(time.Second) } }() // send keepalive go func() { for { time.Sleep(time.Second * 10) client.SendRequest("nop", false, nil) } }() session.Wait() if strings.Index(c.node.Commands[0].Regexps[0].Debug, "p") >= 0 || strings.Index(c.node.Commands[0].Regexps[0].Debug, "1") >= 0 { logger.Trace("disconnected") } } type BTReader struct { *bufio.Reader bufop int32 chbs chan []byte } func (me *BTReader) ReadTimeout(d time.Duration) (rbs []byte, err error) { if me.chbs == nil { me.chbs = make(chan []byte) go func() { n := 0 bs := make([]byte, me.Size()) for { _, err = me.ReadByte() if err != nil { return } err = me.UnreadByte() if err != nil { return } n, err = me.Read(bs[0:me.Buffered()]) if err != nil { return } me.chbs <- bs[0:n] } }() } t := time.NewTimer(d) select { case rbs = <-me.chbs: return case <-t.C: return } }