package main import ( "bufio" "encoding/base64" "encoding/json" "fmt" "io" "io/fs" "os" "os/exec" "os/signal" "os/user" "reflect" "regexp" "strings" "syscall" "time" "git.wecise.com/wecise/common/matrix/logger" "git.wecise.com/wecise/common/matrix/util" "github.com/creack/pty" "golang.org/x/term" "gopkg.in/yaml.v2" ) func init() { logger.SetConsole(true) logger.SetLevel(logger.TRACE) logger.SetFormat("yyyy-MM-dd HH:mm:ss.SSS [level] msg", "\n") } type KV struct{ Key, Val string } func parseArgs(args []string) (kvs []*KV) { argk := "" argv := "" for _, arg := range args { if argk != "" { argv = arg } else if regexp.MustCompile(`^\-\w+$`).MatchString(arg) { argk = arg[1:] continue } else { kv := strings.SplitN(arg, "=", 2) if len(kv) == 2 { argk = kv[0] argv = kv[1] } else { argk = "" argv = arg } } kvs = append(kvs, &KV{argk, argv}) argk, argv = "", "" } if argk != "" { kvs = append(kvs, &KV{argk, argv}) } return } func usage() { fmt.Println("usage:") fmt.Println(" msh p=password|a=passcode [[c=]command [p=password|a=passcode] [x=cmd-end-regexp] [[r=regexp] [o=output|n=outputline]]...]...") fmt.Println(" a=passcode should be base64 encoded, or use p=password") fmt.Println(" debug info include: p(progress) a(argments) m(match) 1(all)") } func main() { // get user u := func() string { user, err := user.Current() if err != nil { logger.Error(util.ErrorWithSourceLine(err)) return "" } return user.Username }() // get passcode,取默认密码设置 c := "" kvs := parseArgs(os.Args) for _, kv := range kvs { if kv.Key == "a" { c = kv.Val if c == "" { c = "=" } break } if kv.Key == "p" { c = "=" + kv.Val break } } // get password p := c if p == "" { usage() return } else if p[0:1] == "=" { p = p[1:] } else { x, e := base64.RawStdEncoding.DecodeString(p) if e == nil { p = string(x) } // else 不是Base64编码,保持原值 } // explainArgs key, val cmds := []*Command{{Cmd: "sh", Password: c, Regexps: []*Matcher{{Regexp: nil, Output: ""}}, Endregx: regxprompt}} i := 0 // 掠过第一个参数,当前执行程序 kvs = kvs[1:] for _, kv := range kvs { key, val := kv.Key, kv.Val switch key { case "", "cmd", "command", "c": cmds = append(cmds, &Command{Cmd: val, Password: c, Regexps: []*Matcher{{Regexp: nil, Output: ""}}, Endregx: regxprompt}) case "ry", "regex_yes_no": re, err := regexp.Compile(val) if err != nil { logger.Error("arg", i, util.ErrorWithSourceLine(err)) return } else { regxyesno = &Regexp{re} } case "rc", "regex_passcode", "regex_password": re, err := regexp.Compile(val) if err != nil { logger.Error("arg", i, util.ErrorWithSourceLine(err)) return } else { regxpassword = &Regexp{re} } case "rp", "regex_prompt": re, err := regexp.Compile(val) if err != nil { logger.Error("arg", i, util.ErrorWithSourceLine(err)) return } else { regxprompt = &Regexp{re} } case "password", "code", "pass", "p": cmds[len(cmds)-1].Password = "=" + val case "passcode", "b64code", "a": cmds[len(cmds)-1].Password = val case "re", "r", "regex": if val == "" { cmds[len(cmds)-1].Regexps = append(cmds[len(cmds)-1].Regexps, &Matcher{Regexp: nil}) } else { re, err := regexp.Compile(val) if err != nil { logger.Error("arg", i, util.ErrorWithSourceLine(err)) return } else { cmds[len(cmds)-1].Regexps = append(cmds[len(cmds)-1].Regexps, &Matcher{Regexp: &Regexp{re}}) } } case "x", "end": if val == "" { cmds[len(cmds)-1].Endregx = regxprompt } else { re, err := regexp.Compile(val) if err != nil { logger.Error("arg", i, util.ErrorWithSourceLine(err)) return } else { cmds[len(cmds)-1].Endregx = &Regexp{re} } } case "out", "o", "output": cmds[len(cmds)-1].Regexps[len(cmds[len(cmds)-1].Regexps)-1].Output += val case "outln", "n", "outputline": cmds[len(cmds)-1].Regexps[len(cmds[len(cmds)-1].Regexps)-1].Output += val + "\n" case "debug", "d": cmds[len(cmds)-1].Regexps[len(cmds[len(cmds)-1].Regexps)-1].Debug += val } } if strings.Index(cmds[0].Regexps[0].Debug, "a") >= 0 || strings.Index(cmds[0].Regexps[0].Debug, "1") >= 0 { //bs, _ := json.MarshalIndent(cmds, "", " ") bs, _ := yaml.Marshal(cmds) logger.Debug("arguments:\n" + string(bs)) } client := &Client{ User: u, Password: p, Commands: cmds, } client.Run() } type Regexp struct { *regexp.Regexp } func (r *Regexp) MarshalJSON() ([]byte, error) { if r == nil || r.Regexp == nil { return json.Marshal(nil) } return json.Marshal(r.String()) } func (r *Regexp) MarshalYAML() (interface{}, error) { if r == nil || r.Regexp == nil { return nil, nil } return r.String(), nil } type Matcher struct { Regexp *Regexp Output string `yaml:"output"` Debug string `yaml:"debug"` } type Command struct { Cmd string `yaml:"cmd"` Password string `yaml:"password"` Regexps []*Matcher Endregx *Regexp } var regxyesno = &Regexp{regexp.MustCompile(`.*(\(yes\/no\)\?)\s*$`)} var regxpassword = &Regexp{regexp.MustCompile(`.*([Pp]assword:)\s*$`)} var regxprompt = &Regexp{regexp.MustCompile(`.*([%#\$\>])\s*$`)} type Client struct { User string Password string Commands []*Command } func Pipe(reader io.Reader, writer io.Writer, pfs ...func(lastbs []byte, sin string) (sout string)) { btr := NewBTReader("", reader, 30*time.Millisecond, 1024) lastbs := []byte{'\n'} for { bs, err := btr.Read() if err != nil { if err == io.EOF { return } logger.Error(util.ErrorWithSourceLine(err)) return } if len(bs) > 0 { xbs := bs for _, pf := range pfs { if pf != nil { s := pf(lastbs, string(xbs)) xbs = []byte(s) } } if writer != nil { writer.Write(xbs) } lastbs = bs } } } func (c *Client) Run() { logger.Info("msh ready") // Create arbitrary command. cmd := exec.Command("sh", "-c", "export PS1='\\h:> ';sh -i") // Start the command with a pty. ptmx, err := pty.Start(cmd) if err != nil { logger.Error(err) return } // Make sure to close the pty at the end. defer func() { _ = ptmx.Close() }() // Best effort. // Handle pty size. ch := make(chan os.Signal, 1) signal.Notify(ch, syscall.SIGWINCH) go func() { for range ch { if err := pty.InheritSize(os.Stdin, ptmx); err != nil { logger.Error("error resizing pty:", err) } } }() ch <- syscall.SIGWINCH // Initial resize. defer func() { signal.Stop(ch); close(ch) }() // Cleanup signals when done. // Set stdin in raw mode. oldState, err := term.MakeRaw(int(os.Stdin.Fd())) if err != nil { logger.Error(err) } go Pipe(os.Stdin, ptmx) e := c.IOProc(ptmx, os.Stdout, ptmx) if oldState != nil { term.Restore(int(os.Stdin.Fd()), oldState) fmt.Println() } if cmd.ProcessState.ExitCode() == -1 && e != nil { logger.Error(e) return } logger.Info("msh exit") return } func (c *Client) IOProc(cmdout io.Reader, stdout io.Writer, cmdin io.Writer) error { cmdidx := 0 as := "" change := false stdoutbtr := NewBTReader("", cmdout, 50*time.Millisecond, 1024) for { bs, err := stdoutbtr.Read() if err != nil { if err == io.EOF { return nil } if _, y := err.(*fs.PathError); y { return nil } return util.ErrorWithSourceLine(err, reflect.TypeOf(err)) } if len(bs) > 0 { change = true s := string(bs) stdout.Write([]byte(s)) as += s } else if change { change = false if cmdidx < len(c.Commands) { if msi := c.Commands[cmdidx].Endregx.FindStringSubmatchIndex(as); msi != nil { // logger.Error("EndregexMatch", as) match := as[msi[0]:msi[1]] if strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "m") >= 0 || strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "1") >= 0 { logger.Trace("match end:", "'"+match+"'") } as = "" // 全清,开始新的命令 cmdidx++ if cmdidx >= len(c.Commands) { return nil } if strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "p") >= 0 || strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "1") >= 0 { logger.Trace("command:", c.Commands[cmdidx].Cmd) } // stdout.Write([]byte(c.Commands[cmdidx].Cmd + "\n")) cmdin.Write([]byte(c.Commands[cmdidx].Cmd + "\n")) continue } for _, regex := range c.Commands[cmdidx].Regexps { if regex == nil { continue } if regex.Regexp == nil { // like match all if regex.Output != "" { // stdout.Write([]byte(regex.Output)) cmdin.Write([]byte(regex.Output)) } } else if msi := regex.Regexp.FindStringSubmatchIndex(as); msi != nil { match := as[msi[0]:msi[1]] if len(msi) >= 4 { match = as[msi[2]:msi[3]] as = as[msi[3]:] // 清除已处理完的内容 } else { as = as[msi[1]:] // 清除已处理完的内容 } if strings.Index(regex.Debug, "m") >= 0 || strings.Index(regex.Debug, "1") >= 0 { logger.Trace("match regex:", "'"+match+"'") } if regex.Output != "" { // stdout.Write([]byte(regex.Output)) cmdin.Write([]byte(regex.Output)) } } } } if msi := regxyesno.FindStringSubmatchIndex(as); msi != nil { match := as[msi[0]:msi[1]] if len(msi) >= 4 { as = as[msi[3]:] // 清除已处理完的内容 } else { as = as[msi[1]:] // 清除已处理完的内容 } if strings.Index(c.Commands[0].Regexps[0].Debug, "m") >= 0 || strings.Index(c.Commands[0].Regexps[0].Debug, "1") >= 0 || cmdidx < len(c.Commands) && (strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "m") >= 0 || strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "1") >= 0) { logger.Trace("match yesno:", "'"+match+"'") } // stdout.Write([]byte("yes\n")) cmdin.Write([]byte("yes\n")) } if msi := regxpassword.FindStringSubmatchIndex(as); msi != nil { // logger.Error(as) match := as[msi[0]:msi[1]] if len(msi) >= 4 { as = as[msi[3]:] // 清除已处理完的内容 } else { as = as[msi[1]:] // 清除已处理完的内容 } p := c.Commands[0].Password if cmdidx < len(c.Commands) { p = c.Commands[cmdidx].Password } if strings.Index(c.Commands[0].Regexps[0].Debug, "m") >= 0 || strings.Index(c.Commands[0].Regexps[0].Debug, "1") >= 0 || cmdidx < len(c.Commands) && (strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "m") >= 0 || strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "1") >= 0) { logger.Trace("match password:", "'"+match+"'") } if p != "" { if p[0:1] == "=" { p = p[1:] } else { x, e := base64.RawStdEncoding.DecodeString(p) if e == nil { p = string(x) } // else 不是Base64编码,保持原值 } // don't echo password if c.Commands[0].Regexps[0].Debug != "" || cmdidx < len(c.Commands) && c.Commands[cmdidx].Regexps[0].Debug != "" { stdout.Write([]byte(p + "\n")) } cmdin.Write([]byte(p + "\n")) } } if len(as) > 1024 { as = as[len(as)-1024:] } } } } type BTReader struct { *bufio.Reader flag string chbs chan []byte cher chan error timeout time.Duration sizeout int err error } func NewBTReader(flag string, reader io.Reader, timeout time.Duration, sizeout int) (me *BTReader) { me = &BTReader{Reader: bufio.NewReader(reader), flag: flag, chbs: make(chan []byte), cher: make(chan error), timeout: timeout, sizeout: sizeout, } go func() { bs := make([]byte, me.Size()) for { n, err := me.Reader.Read(bs[:1]) if err != nil { me.cher <- err return } x, err := me.Reader.Read(bs[1 : me.Reader.Buffered()+1]) if err != nil { me.cher <- err return } n += x abs := make([]byte, n) copy(abs, bs[:n]) me.chbs <- abs } }() return } // 指定时间内没有新数据进入,且有积累数据,或积累数据超过指定数量,即返回 func (me *BTReader) Read() (rbs []byte, err error) { if me.err != nil { return nil, me.err } for { t := time.NewTimer(me.timeout) select { case me.err = <-me.cher: if len(rbs) > 0 { // 返回最后的数据,下次读时才返回错误 return rbs, nil } return nil, me.err case abs := <-me.chbs: rbs = append(rbs, abs...) if len(rbs) > me.sizeout { return } t.Stop() t.Reset(me.timeout) case <-t.C: if len(rbs) == 0 { t.Stop() t.Reset(me.timeout) } return } } }