mpconn.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. package probing
  import (
	"encoding/binary"
	"errors"
	"fmt"
	"net"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
	"trial/ping/probing/icmp"
)
  // recvPkt is one raw datagram read from the shared ICMP socket by recvICMP
// and queued on recvchan for processRecvPacket.
type recvPkt struct {
	recvtime    time.Time // wall-clock time recorded by recvICMP for this datagram
	addr        net.Addr  // source address reported by ReadFrom
	bytes       []byte    // private copy of the datagram (not aliased to the read buffer)
	nbytes, ttl int       // datagram length and TTL as reported by ReadFrom
}
  21. type mpingconn struct {
  22. mutex sync.Mutex
  23. ipv4 bool
  24. protocol string
  25. Source string
  26. Conn icmp.PacketConn
  27. done chan interface{}
  28. recvchan chan *recvPkt
  29. pinghostinfo map[string]*mpinfo
  30. pingidinfo map[int]*mpinfo
  31. OnError func(error)
  32. }
  33. type mpinfo struct {
  34. host string
  35. ipaddr *net.IPAddr
  36. id int
  37. lastseq int
  38. size int
  39. timeout time.Duration
  40. seqpkt map[int]chan *Packet
  41. OnSend func(*Packet)
  42. OnRecv func(*Packet)
  43. OnRecvDup func(*Packet)
  44. OnRecvError func(error)
  45. }
  46. var mpconn = newMPConn()
  47. func newMPConn() *mpingconn {
  48. mpconn := &mpingconn{
  49. ipv4: true,
  50. protocol: "udp",
  51. Source: "",
  52. done: make(chan interface{}),
  53. recvchan: make(chan *recvPkt, receive_buffer_count),
  54. pinghostinfo: make(map[string]*mpinfo),
  55. pingidinfo: make(map[int]*mpinfo),
  56. }
  57. return mpconn
  58. }
  59. func (p *mpingconn) listen() (icmp.PacketConn, error) {
  60. p.mutex.Lock()
  61. defer p.mutex.Unlock()
  62. if p.Conn != nil {
  63. return p.Conn, nil
  64. }
  65. conn, err := icmp.Listen(p.ipv4, p.protocol, p.Source)
  66. if err != nil {
  67. return nil, err
  68. }
  69. p.Conn = conn
  70. conn.SetTTL(64)
  71. if err := conn.SetFlagTTL(); err != nil {
  72. return nil, err
  73. }
  74. go p.recvICMP()
  75. go p.processRecvPacket()
  76. return p, nil
  77. }
  78. func (p *mpingconn) Close() error {
  79. p.mutex.Lock()
  80. defer p.mutex.Unlock()
  81. open := true
  82. select {
  83. case _, open = <-p.done:
  84. default:
  85. }
  86. if open {
  87. close(p.done)
  88. }
  89. return p.Conn.Close()
  90. }
  91. func (p *mpingconn) ICMPRequestType() icmp.Type {
  92. return p.Conn.ICMPRequestType()
  93. }
  94. func (p *mpingconn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
  95. return p.Conn.ReadFrom(b)
  96. }
  97. func (p *mpingconn) WriteTo(b []byte, dst net.Addr) (int, error) {
  98. return p.Conn.WriteTo(b, dst)
  99. }
  100. func (p *mpingconn) SetReadDeadline(t time.Time) error {
  101. return p.Conn.SetReadDeadline(t)
  102. }
  103. func (p *mpingconn) SetFlagTTL() error {
  104. return p.Conn.SetFlagTTL()
  105. }
  106. func (p *mpingconn) SetTTL(ttl int) {
  107. p.Conn.SetTTL(ttl)
  108. }
  var (
	// pingid is the last echo identifier handed out; newPingInfo increments it
	// atomically.
	pingid int32

	// ETIMEDOUT is a sentinel error with the message "timeout". It is not used
	// in this part of the file; presumably callers elsewhere return it when a
	// reply does not arrive in time — confirm before relying on that.
	ETIMEDOUT error = fmt.Errorf("timeout")
)
  111. func newPingInfo(host string, ipaddr *net.IPAddr, size int, timeout time.Duration) *mpinfo {
  112. return &mpinfo{
  113. host: host,
  114. ipaddr: ipaddr,
  115. id: int(atomic.AddInt32(&pingid, 1)),
  116. seqpkt: make(map[int]chan *Packet),
  117. size: size,
  118. timeout: timeout,
  119. }
  120. }
  121. func (p *mpingconn) ping(host string, ipaddr *net.IPAddr, size int, timeout time.Duration, onSend func(*Packet), onRecv func(*Packet)) error {
  122. p.mutex.Lock()
  123. var pinfo *mpinfo
  124. var has bool
  125. if pinfo, has = p.pinghostinfo[ipaddr.String()]; !has {
  126. pinfo = newPingInfo(host, ipaddr, size, timeout)
  127. p.pinghostinfo[ipaddr.String()] = pinfo
  128. p.pingidinfo[pinfo.id] = pinfo
  129. }
  130. seq := pinfo.lastseq
  131. pinfo.lastseq++
  132. recvpkt := make(chan *Packet, 1)
  133. pinfo.seqpkt[seq] = recvpkt
  134. p.mutex.Unlock()
  135. msgBytes, err := icmp.BuildEchoRequestMessage(pinfo.id, seq, size, p.ICMPRequestType())
  136. if err != nil {
  137. return err
  138. }
  139. err = p.sendICMP(msgBytes, ipaddr)
  140. if err != nil {
  141. return err
  142. }
  143. outpkt := &Packet{
  144. Nbytes: len(msgBytes),
  145. Host: host,
  146. IPAddr: ipaddr,
  147. Seq: seq,
  148. ID: pinfo.id,
  149. }
  150. if onSend != nil {
  151. onSend(outpkt)
  152. }
  153. go func(onRecv func(*Packet), recvpkt chan *Packet) {
  154. t := time.NewTimer(timeout)
  155. select {
  156. case <-t.C:
  157. case inpkt := <-recvpkt:
  158. if onRecv != nil {
  159. onRecv(inpkt)
  160. }
  161. case <-p.done:
  162. }
  163. // clear();
  164. }(onRecv, recvpkt)
  165. return nil
  166. }
  167. func (p *mpingconn) sendICMP(msgBytes []byte, addr *net.IPAddr) error {
  168. var dst net.Addr = addr
  169. if p.protocol == "udp" {
  170. dst = &net.UDPAddr{IP: addr.IP, Zone: addr.Zone}
  171. }
  172. for {
  173. select {
  174. case <-p.done:
  175. return nil
  176. default:
  177. }
  178. if _, err := p.Conn.WriteTo(msgBytes, dst); err != nil {
  179. if neterr, ok := err.(*net.OpError); ok {
  180. if neterr.Err == syscall.ENOBUFS {
  181. if p.OnError != nil {
  182. p.OnError(neterr.Err)
  183. } else {
  184. fmt.Println("缓存不够,发送失败,重发")
  185. }
  186. continue
  187. }
  188. }
  189. return err
  190. } else {
  191. return nil
  192. }
  193. }
  194. }
  195. func (p *mpingconn) recvICMP() error {
  196. bytes := make([]byte, 65536)
  197. for {
  198. select {
  199. case <-p.done:
  200. return nil
  201. default:
  202. var n, ttl int
  203. var addr net.Addr
  204. var err error
  205. n, ttl, addr, err = p.Conn.ReadFrom(bytes)
  206. if err != nil {
  207. if neterr, ok := err.(*net.OpError); ok {
  208. if neterr.Timeout() {
  209. // Read timeout
  210. continue
  211. }
  212. }
  213. return err
  214. }
  215. bs := make([]byte, n)
  216. copy(bs, bytes[:n])
  217. select {
  218. case <-p.done:
  219. return nil
  220. case p.recvchan <- &recvPkt{recvtime: time.Now(), addr: addr, bytes: bs, nbytes: n, ttl: ttl}:
  221. }
  222. }
  223. }
  224. }
  225. func (p *mpingconn) processRecvPacket() {
  226. for pkt := range p.recvchan {
  227. if len(p.recvchan) > cap(p.recvchan)*9/10 {
  228. fmt.Printf("receive buffer full")
  229. }
  230. err := p.processPacket(pkt)
  231. if err != nil {
  232. if p.OnError != nil {
  233. p.OnError(err)
  234. } else {
  235. fmt.Println(err)
  236. }
  237. }
  238. }
  239. }
  // count is not read or written by any live code in this part of the file;
// presumably a leftover debug counter. Verify no other file in the package
// uses it before removing it.
var count int
  241. func (p *mpingconn) processPacket(recv *recvPkt) error {
  242. var proto int
  243. if p.ipv4 {
  244. proto = protocolICMP
  245. } else {
  246. proto = protocolIPv6ICMP
  247. }
  248. // fmt.Println(count, "from", recv.addr.String(), "bytes", recv.bytes)
  249. var m *icmp.Message
  250. var err error
  251. if m, err = icmp.ParseMessage(proto, recv.bytes); err != nil {
  252. return fmt.Errorf("error parsing icmp message: %w", err)
  253. }
  254. if m.Type != ipv4.ICMPTypeEchoReply && m.Type != ipv6.ICMPTypeEchoReply {
  255. // Not an echo reply, ignore it
  256. return nil
  257. }
  258. switch pkt := m.Body.(type) {
  259. case *icmp.Echo:
  260. return p.processEchoReply(pkt, recv)
  261. default:
  262. // Very bad, not sure how this can happen
  263. return fmt.Errorf("invalid ICMP echo reply; type: '%T', '%v'", pkt, pkt)
  264. }
  265. }
  266. func (p *mpingconn) processEchoReply(pkt *icmp.Echo, recv *recvPkt) error {
  267. if len(pkt.Data) < 24 {
  268. return nil
  269. }
  270. sendtime := int64(binary.BigEndian.Uint64(pkt.Data[:8]))
  271. fullseq := int(binary.BigEndian.Uint64(pkt.Data[8:16]))
  272. fullid := int(binary.BigEndian.Uint64(pkt.Data[16:24]))
  273. if fullid%65536 != pkt.ID || fullseq%65536 != pkt.Seq {
  274. return nil
  275. }
  276. p.mutex.Lock()
  277. pinfo := p.pingidinfo[fullid]
  278. p.mutex.Unlock()
  279. if pinfo == nil {
  280. return nil
  281. }
  282. inPkt := &Packet{
  283. Host: pinfo.host,
  284. IPAddr: pinfo.ipaddr,
  285. ID: pinfo.id,
  286. Seq: fullseq,
  287. Nbytes: recv.nbytes,
  288. TTL: recv.ttl,
  289. Rtt: recv.recvtime.Sub(time.Unix(0, sendtime)),
  290. }
  291. // fmt.Printf("%s %d bytes from %s: icmp_seq=%d time=%v\n",
  292. // time.Now().Format("15:04:05.000"), inPkt.Nbytes, inPkt.IPAddr, inPkt.Seq, inPkt.Rtt)
  293. p.mutex.Lock()
  294. chpkt, inflight := pinfo.seqpkt[fullseq]
  295. if inflight {
  296. // remove it from the list of sequences we're waiting for so we don't get duplicates.
  297. delete(pinfo.seqpkt, fullseq)
  298. }
  299. p.mutex.Unlock()
  300. if chpkt != nil {
  301. chpkt <- inPkt
  302. }
  303. return nil
  304. }