main.go 12 KB


  1. package main
  2. import (
  3. "bufio"
  4. "encoding/base64"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "io/fs"
  9. "os"
  10. "os/exec"
  11. "os/signal"
  12. "os/user"
  13. "regexp"
  14. "strings"
  15. "time"
  16. "github.com/creack/pty"
  17. "github.com/wecisecode/util/logger"
  18. "github.com/wecisecode/util/merrs"
  19. "golang.org/x/term"
  20. "gopkg.in/yaml.v2"
  21. )
  22. func init() {
  23. logger.SetConsole(true)
  24. logger.SetLevel(logger.TRACE)
  25. logger.SetFormat("yyyy-MM-dd HH:mm:ss.SSS [level] msg", "\n")
  26. }
  // KV is one parsed command-line argument. Key is the option name (empty for
  // a bare positional argument) and Val is its value.
  type KV struct {
  	Key string
  	Val string
  }

  // argKeyPattern matches a "-name" argument whose value is the next argument.
  // It is compiled once here rather than once per argument.
  var argKeyPattern = regexp.MustCompile(`^\-\w+$`)

  // parseArgs converts args into key/value pairs, preserving order.
  //
  // Three forms are accepted:
  //   - "-key value": the value is the following argument, taken verbatim
  //     (even if it contains '=' or itself starts with '-');
  //   - "key=value": split at the first '=';
  //   - anything else: a positional value with an empty key.
  //
  // A trailing "-key" with no following argument yields {key, ""}.
  // args is scanned exactly as given, so when it is os.Args the program path
  // becomes the first (positional) pair.
  func parseArgs(args []string) (kvs []*KV) {
  	pending := "" // key of a "-key" still waiting for its value
  	for _, arg := range args {
  		switch {
  		case pending != "":
  			kvs = append(kvs, &KV{pending, arg})
  			pending = ""
  		case argKeyPattern.MatchString(arg):
  			pending = arg[1:]
  		default:
  			if eq := strings.IndexByte(arg, '='); eq >= 0 {
  				kvs = append(kvs, &KV{arg[:eq], arg[eq+1:]})
  			} else {
  				kvs = append(kvs, &KV{"", arg})
  			}
  		}
  	}
  	if pending != "" {
  		kvs = append(kvs, &KV{pending, ""})
  	}
  	return kvs
  }
  // usage prints the command-line synopsis to stdout, one line per entry.
  func usage() {
  	lines := []string{
  		"usage:",
  		" msh p=password|a=passcode [[c=]command [p=password|a=passcode] [x=cmd-end-regexp] [[r=regexp] [o=output|n=outputline]]...]...",
  		" a=passcode should be base64 encoded, or use p=password",
  		" debug info include: p(progress) a(arguments) m(match) 1(all)",
  	}
  	for _, line := range lines {
  		fmt.Println(line)
  	}
  }
  61. func main() {
  62. // get user
  63. u := func() string {
  64. user, err := user.Current()
  65. if err != nil {
  66. logger.Error(merrs.NewError(err))
  67. return ""
  68. }
  69. return user.Username
  70. }()
  71. // get passcode,取默认密码设置
  72. c := ""
  73. kvs := parseArgs(os.Args)
  74. for _, kv := range kvs {
  75. if kv.Key == "a" {
  76. c = kv.Val
  77. if c == "" {
  78. c = "="
  79. }
  80. break
  81. }
  82. if kv.Key == "p" {
  83. c = "=" + kv.Val
  84. break
  85. }
  86. }
  87. // get password
  88. p := c
  89. if p == "" {
  90. usage()
  91. return
  92. } else if p[0:1] == "=" {
  93. p = p[1:]
  94. } else {
  95. x, e := base64.RawStdEncoding.DecodeString(p)
  96. if e == nil {
  97. p = string(x)
  98. }
  99. // else 不是Base64编码,保持原值
  100. }
  101. // explainArgs key, val
  102. cmds := []*Command{{Cmd: "sh", Password: c, Regexps: []*Matcher{{Regexp: nil, Output: ""}}, Endregx: regxprompt}}
  103. i := 0
  104. // 掠过第一个参数,当前执行程序
  105. kvs = kvs[1:]
  106. for _, kv := range kvs {
  107. key, val := kv.Key, kv.Val
  108. switch key {
  109. case "", "cmd", "command", "c":
  110. cmds = append(cmds, &Command{Cmd: val, Password: c, Regexps: []*Matcher{{Regexp: nil, Output: ""}}, Endregx: regxprompt})
  111. case "ry", "regex_yes_no":
  112. re, err := regexp.Compile(val)
  113. if err != nil {
  114. logger.Error("arg", i, merrs.NewError(err))
  115. return
  116. } else {
  117. regxyesno = &Regexp{re}
  118. }
  119. case "rc", "regex_passcode", "regex_password":
  120. re, err := regexp.Compile(val)
  121. if err != nil {
  122. logger.Error("arg", i, merrs.NewError(err))
  123. return
  124. } else {
  125. regxpassword = &Regexp{re}
  126. }
  127. case "rp", "regex_prompt":
  128. re, err := regexp.Compile(val)
  129. if err != nil {
  130. logger.Error("arg", i, merrs.NewError(err))
  131. return
  132. } else {
  133. regxprompt = &Regexp{re}
  134. }
  135. case "password", "code", "pass", "p":
  136. cmds[len(cmds)-1].Password = "=" + val
  137. case "passcode", "b64code", "a":
  138. cmds[len(cmds)-1].Password = val
  139. case "re", "r", "regex":
  140. if val == "" {
  141. cmds[len(cmds)-1].Regexps = append(cmds[len(cmds)-1].Regexps, &Matcher{Regexp: nil})
  142. } else {
  143. re, err := regexp.Compile(val)
  144. if err != nil {
  145. logger.Error("arg", i, merrs.NewError(err))
  146. return
  147. } else {
  148. cmds[len(cmds)-1].Regexps = append(cmds[len(cmds)-1].Regexps, &Matcher{Regexp: &Regexp{re}})
  149. }
  150. }
  151. case "x", "end":
  152. if val == "" {
  153. cmds[len(cmds)-1].Endregx = regxprompt
  154. } else {
  155. re, err := regexp.Compile(val)
  156. if err != nil {
  157. logger.Error("arg", i, merrs.NewError(err))
  158. return
  159. } else {
  160. cmds[len(cmds)-1].Endregx = &Regexp{re}
  161. }
  162. }
  163. case "out", "o", "output":
  164. cmds[len(cmds)-1].Regexps[len(cmds[len(cmds)-1].Regexps)-1].Output += val
  165. case "outln", "n", "outputline":
  166. cmds[len(cmds)-1].Regexps[len(cmds[len(cmds)-1].Regexps)-1].Output += val + "\n"
  167. case "debug", "d":
  168. cmds[len(cmds)-1].Regexps[len(cmds[len(cmds)-1].Regexps)-1].Debug += val
  169. }
  170. }
  171. if strings.Index(cmds[0].Regexps[0].Debug, "a") >= 0 || strings.Index(cmds[0].Regexps[0].Debug, "1") >= 0 {
  172. //bs, _ := json.MarshalIndent(cmds, "", " ")
  173. bs, _ := yaml.Marshal(cmds)
  174. logger.Debug("arguments:\n" + string(bs))
  175. }
  176. client := &Client{
  177. User: u,
  178. Password: p,
  179. Commands: cmds,
  180. }
  181. client.Run()
  182. }
  // Regexp wraps *regexp.Regexp so that it serializes as its pattern text
  // (a JSON string or YAML scalar). A nil wrapper or nil pattern serializes
  // as null.
  type Regexp struct {
  	*regexp.Regexp
  }

  // MarshalJSON encodes the pattern as a JSON string, or null when absent.
  func (r *Regexp) MarshalJSON() ([]byte, error) {
  	var v interface{} // left nil (JSON null) when there is no pattern
  	if r != nil && r.Regexp != nil {
  		v = r.String()
  	}
  	return json.Marshal(v)
  }

  // MarshalYAML is the gopkg.in/yaml.v2 marshaler hook. It encodes the
  // pattern as a plain string, or null when absent.
  func (r *Regexp) MarshalYAML() (interface{}, error) {
  	if r != nil && r.Regexp != nil {
  		return r.String(), nil
  	}
  	return nil, nil
  }

  // Matcher is one rule of a Command's dialogue. Client.IOProc checks it
  // whenever the command's output goes quiet.
  type Matcher struct {
  	// Regexp is matched against the pending output; nil matches every time.
  	Regexp *Regexp
  	// Output is written to the shell when the rule fires (if non-empty).
  	Output string `yaml:"output"`
  	// Debug holds debug-flag characters (p, a, m, 1). The flags on a
  	// command's first matcher apply to the whole command.
  	Debug string `yaml:"debug"`
  }

  // Command is one step of the script run by Client.
  type Command struct {
  	// Cmd is written to the shell, followed by "\n", once the previous
  	// command's Endregx has matched.
  	Cmd string `yaml:"cmd"`
  	// Password answers password prompts while this command is current.
  	// It is either "=" followed by plain text, or base64.
  	Password string `yaml:"password"`
  	// Regexps are the dialogue rules; main always creates at least one.
  	Regexps []*Matcher
  	// Endregx marks the command as finished when it matches.
  	Endregx *Regexp
  }

  // Built-in prompt patterns, each anchored at the end of the pending output.
  // main replaces them when given ry=, rc= or rp= arguments.
  var (
  	regxyesno    = &Regexp{Regexp: regexp.MustCompile(`.*(\(yes\/no\)\?)\s*$`)}
  	regxpassword = &Regexp{Regexp: regexp.MustCompile(`.*([Pp]assword:)\s*$`)}
  	regxprompt   = &Regexp{Regexp: regexp.MustCompile(`.*([%#\$\>])\s*$`)}
  )

  // Client is a scripted interactive shell session; see Run.
  type Client struct {
  	User     string     // OS user name set by main; not read by Run or IOProc
  	Password string     // decoded default password set by main; not read by Run or IOProc
  	Commands []*Command // Commands[0] is the initial shell; the rest are sent in order
  }
  217. func Pipe(reader io.Reader, writer io.Writer, pfs ...func(lastbs []byte, sin string) (sout string)) {
  218. btr := NewBTReader("", reader, 30*time.Millisecond, 1024)
  219. lastbs := []byte{'\n'}
  220. for {
  221. bs, err := btr.Read()
  222. if err != nil {
  223. if err == io.EOF {
  224. return
  225. }
  226. logger.Error(merrs.NewError(err))
  227. return
  228. }
  229. if len(bs) > 0 {
  230. xbs := bs
  231. for _, pf := range pfs {
  232. if pf != nil {
  233. s := pf(lastbs, string(xbs))
  234. xbs = []byte(s)
  235. }
  236. }
  237. if writer != nil {
  238. writer.Write(xbs)
  239. }
  240. lastbs = bs
  241. }
  242. }
  243. }
  244. func (c *Client) Run() {
  245. logger.Info("msh ready")
  246. // Create arbitrary command.
  247. cmd := exec.Command("sh", "-c", "export PS1='\\h:> ';sh -i")
  248. // Start the command with a pty.
  249. ptmx, err := pty.Start(cmd)
  250. if err != nil {
  251. logger.Error(err)
  252. return
  253. }
  254. // Make sure to close the pty at the end.
  255. defer func() { _ = ptmx.Close() }() // Best effort.
  256. // Handle pty size.
  257. ch := HandlePTYSize(ptmx)
  258. if ch != nil {
  259. defer func() { signal.Stop(ch); close(ch) }() // Cleanup signals when done.
  260. }
  261. // Set stdin in raw mode.
  262. oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
  263. if err != nil {
  264. logger.Error(err)
  265. }
  266. go Pipe(os.Stdin, ptmx)
  267. e := c.IOProc(ptmx, os.Stdout, ptmx)
  268. if oldState != nil {
  269. term.Restore(int(os.Stdin.Fd()), oldState)
  270. fmt.Println()
  271. }
  272. if cmd.ProcessState.ExitCode() == -1 && e != nil {
  273. logger.Error(e)
  274. return
  275. }
  276. logger.Info("msh exit")
  277. return
  278. }
  279. func (c *Client) IOProc(cmdout io.Reader, stdout io.Writer, cmdin io.Writer) error {
  280. cmdidx := 0
  281. as := ""
  282. change := false
  283. stdoutbtr := NewBTReader("", cmdout, 50*time.Millisecond, 1024)
  284. for {
  285. bs, err := stdoutbtr.Read()
  286. if err != nil {
  287. if err == io.EOF {
  288. return nil
  289. }
  290. if _, y := err.(*fs.PathError); y {
  291. return nil
  292. }
  293. return merrs.NewError(err)
  294. }
  295. if len(bs) > 0 {
  296. change = true
  297. s := string(bs)
  298. stdout.Write([]byte(s))
  299. as += s
  300. } else if change {
  301. change = false
  302. if cmdidx < len(c.Commands) {
  303. if msi := c.Commands[cmdidx].Endregx.FindStringSubmatchIndex(as); msi != nil {
  304. // logger.Error("EndregexMatch", as)
  305. match := as[msi[0]:msi[1]]
  306. if strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "m") >= 0 || strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "1") >= 0 {
  307. logger.Trace("match end:", "'"+match+"'")
  308. }
  309. as = "" // 全清,开始新的命令
  310. cmdidx++
  311. if cmdidx >= len(c.Commands) {
  312. return nil
  313. }
  314. if strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "p") >= 0 || strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "1") >= 0 {
  315. logger.Trace("command:", c.Commands[cmdidx].Cmd)
  316. }
  317. // stdout.Write([]byte(c.Commands[cmdidx].Cmd + "\n"))
  318. cmdin.Write([]byte(c.Commands[cmdidx].Cmd + "\n"))
  319. continue
  320. }
  321. for _, regex := range c.Commands[cmdidx].Regexps {
  322. if regex == nil {
  323. continue
  324. }
  325. if regex.Regexp == nil {
  326. // like match all
  327. if regex.Output != "" {
  328. // stdout.Write([]byte(regex.Output))
  329. cmdin.Write([]byte(regex.Output))
  330. }
  331. } else if msi := regex.Regexp.FindStringSubmatchIndex(as); msi != nil {
  332. match := as[msi[0]:msi[1]]
  333. if len(msi) >= 4 {
  334. match = as[msi[2]:msi[3]]
  335. as = as[msi[3]:] // 清除已处理完的内容
  336. } else {
  337. as = as[msi[1]:] // 清除已处理完的内容
  338. }
  339. if strings.Index(regex.Debug, "m") >= 0 || strings.Index(regex.Debug, "1") >= 0 {
  340. logger.Trace("match regex:", "'"+match+"'")
  341. }
  342. if regex.Output != "" {
  343. // stdout.Write([]byte(regex.Output))
  344. cmdin.Write([]byte(regex.Output))
  345. }
  346. }
  347. }
  348. }
  349. if msi := regxyesno.FindStringSubmatchIndex(as); msi != nil {
  350. match := as[msi[0]:msi[1]]
  351. if len(msi) >= 4 {
  352. as = as[msi[3]:] // 清除已处理完的内容
  353. } else {
  354. as = as[msi[1]:] // 清除已处理完的内容
  355. }
  356. if strings.Index(c.Commands[0].Regexps[0].Debug, "m") >= 0 || strings.Index(c.Commands[0].Regexps[0].Debug, "1") >= 0 ||
  357. cmdidx < len(c.Commands) &&
  358. (strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "m") >= 0 || strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "1") >= 0) {
  359. logger.Trace("match yesno:", "'"+match+"'")
  360. }
  361. // stdout.Write([]byte("yes\n"))
  362. cmdin.Write([]byte("yes\n"))
  363. }
  364. if msi := regxpassword.FindStringSubmatchIndex(as); msi != nil {
  365. // logger.Error(as)
  366. match := as[msi[0]:msi[1]]
  367. if len(msi) >= 4 {
  368. as = as[msi[3]:] // 清除已处理完的内容
  369. } else {
  370. as = as[msi[1]:] // 清除已处理完的内容
  371. }
  372. p := c.Commands[0].Password
  373. if cmdidx < len(c.Commands) {
  374. p = c.Commands[cmdidx].Password
  375. }
  376. if strings.Index(c.Commands[0].Regexps[0].Debug, "m") >= 0 || strings.Index(c.Commands[0].Regexps[0].Debug, "1") >= 0 ||
  377. cmdidx < len(c.Commands) &&
  378. (strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "m") >= 0 || strings.Index(c.Commands[cmdidx].Regexps[0].Debug, "1") >= 0) {
  379. logger.Trace("match password:", "'"+match+"'")
  380. }
  381. if p != "" {
  382. if p[0:1] == "=" {
  383. p = p[1:]
  384. } else {
  385. x, e := base64.RawStdEncoding.DecodeString(p)
  386. if e == nil {
  387. p = string(x)
  388. }
  389. // else 不是Base64编码,保持原值
  390. }
  391. // don't echo password
  392. if c.Commands[0].Regexps[0].Debug != "" || cmdidx < len(c.Commands) && c.Commands[cmdidx].Regexps[0].Debug != "" {
  393. stdout.Write([]byte(p + "\n"))
  394. }
  395. cmdin.Write([]byte(p + "\n"))
  396. }
  397. }
  398. if len(as) > 1024 {
  399. as = as[len(as)-1024:]
  400. }
  401. }
  402. }
  403. }
  // BTReader batches the bytes of an io.Reader into time-limited chunks.
  // Read returns what arrived until either no new data came in for timeout,
  // or more than sizeout bytes have accumulated. A goroutine started by
  // NewBTReader performs the blocking reads of the underlying reader.
  type BTReader struct {
  	*bufio.Reader
  	flag    string        // caller-supplied label; not read in this file's visible code
  	chbs    chan []byte   // chunks from the reader goroutine
  	cher    chan error    // the reader goroutine's terminal error
  	timeout time.Duration // quiet period that ends a Read
  	sizeout int           // Read returns once more than this many bytes are pending
  	err     error         // sticky error returned by every later Read
  }

  // NewBTReader wraps reader in a BTReader and starts its reader goroutine.
  // flag is stored as a label. The goroutine runs until the underlying reader
  // fails (including io.EOF). While one of its chunks has not yet been
  // collected by Read, it stays blocked.
  func NewBTReader(flag string, reader io.Reader, timeout time.Duration, sizeout int) (me *BTReader) {
  	me = &BTReader{
  		Reader: bufio.NewReader(reader),
  		flag:   flag,
  		chbs:   make(chan []byte),
  		// Buffered so the goroutine can always hand over its final error
  		// and exit, even if Read is never called again.
  		cher:    make(chan error, 1),
  		timeout: timeout,
  		sizeout: sizeout,
  	}
  	go func() {
  		buf := make([]byte, me.Size())
  		for {
  			// Block for the first byte, then take whatever else the
  			// bufio.Reader already holds, without another blocking read.
  			n, err := me.Reader.Read(buf[:1])
  			if err == nil {
  				var more int
  				more, err = me.Reader.Read(buf[1 : me.Reader.Buffered()+1])
  				n += more
  			}
  			// Deliver data before any error. When the underlying reader
  			// returns its last byte together with an error, that byte has
  			// already been consumed and must not be dropped.
  			if n > 0 || err == nil {
  				chunk := make([]byte, n)
  				copy(chunk, buf[:n])
  				me.chbs <- chunk
  			}
  			if err != nil {
  				me.cher <- err
  				return
  			}
  		}
  	}()
  	return me
  }

  // Read returns the next batch of data. It waits until either:
  //   - no new data has arrived for me.timeout since the last chunk (or since
  //     the call began), or
  //   - more than me.sizeout bytes have accumulated.
  //
  // A timeout with nothing received returns an empty slice and a nil error;
  // callers such as IOProc treat this as the "output went quiet" signal.
  //
  // When the underlying reader fails:
  //   - data already collected is returned first, with a nil error;
  //   - the error is returned by that call if nothing was pending;
  //   - the error is returned by every later call.
  func (me *BTReader) Read() (rbs []byte, err error) {
  	if me.err != nil {
  		return nil, me.err
  	}
  	// Only one live timer at a time; it is stopped on every exit path.
  	t := time.NewTimer(me.timeout)
  	defer func() { t.Stop() }()
  	for {
  		select {
  		case me.err = <-me.cher:
  			if len(rbs) > 0 {
  				// Return the pending data now; the error comes next call.
  				return rbs, nil
  			}
  			return nil, me.err
  		case abs := <-me.chbs:
  			rbs = append(rbs, abs...)
  			if len(rbs) > me.sizeout {
  				return rbs, nil
  			}
  			// Restart the quiet-period timer from this chunk.
  			t.Stop()
  			t = time.NewTimer(me.timeout)
  		case <-t.C:
  			return rbs, nil
  		}
  	}
  }
  471. }