mpconn.go 6.0 KB


  1. package probing
  import (
	"bytes"
	"errors"
	"fmt"
	"net"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/google/uuid"
	"golang.org/x/net/icmp"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)
  15. type mpingconn struct {
  16. mutex sync.Mutex
  17. ipv4 bool
  18. protocol string
  19. Source string
  20. Conn packetConn
  21. done chan interface{}
  22. pingid map[string]int
  23. pingseq map[int]int
  24. pingfcb map[int]func(*Packet)
  25. }
  26. var mpconn = &mpingconn{
  27. ipv4: true,
  28. protocol: "udp",
  29. Source: "",
  30. done: make(chan interface{}),
  31. pingid: make(map[string]int),
  32. pingseq: make(map[int]int),
  33. pingfcb: make(map[int]func(*Packet)),
  34. }
  35. func (p *mpingconn) listen() (packetConn, error) {
  36. p.mutex.Lock()
  37. defer p.mutex.Unlock()
  38. if p.Conn != nil {
  39. return p.Conn, nil
  40. }
  41. var (
  42. conn packetConn
  43. err error
  44. )
  45. if p.ipv4 {
  46. var c icmpv4Conn
  47. c.c, err = icmp.ListenPacket(ipv4Proto[p.protocol], p.Source)
  48. conn = &c
  49. } else {
  50. var c icmpV6Conn
  51. c.c, err = icmp.ListenPacket(ipv6Proto[p.protocol], p.Source)
  52. conn = &c
  53. }
  54. if err != nil {
  55. return nil, err
  56. }
  57. p.Conn = conn
  58. go func() {
  59. p.recvICMP()
  60. }()
  61. return p, nil
  62. }
  63. func (p *mpingconn) Close() error {
  64. p.mutex.Lock()
  65. defer p.mutex.Unlock()
  66. open := true
  67. select {
  68. case _, open = <-p.done:
  69. default:
  70. }
  71. if open {
  72. close(p.done)
  73. }
  74. return p.Conn.Close()
  75. }
  76. func (p *mpingconn) ICMPRequestType() icmp.Type {
  77. return p.Conn.ICMPRequestType()
  78. }
  79. func (p *mpingconn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
  80. return p.Conn.ReadFrom(b)
  81. }
  82. func (p *mpingconn) WriteTo(b []byte, dst net.Addr) (int, error) {
  83. return p.Conn.WriteTo(b, dst)
  84. }
  85. func (p *mpingconn) SetReadDeadline(t time.Time) error {
  86. return p.Conn.SetReadDeadline(t)
  87. }
  88. func (p *mpingconn) SetFlagTTL() error {
  89. return p.Conn.SetFlagTTL()
  90. }
  91. func (p *mpingconn) SetTTL(ttl int) {
  92. p.Conn.SetTTL(ttl)
  93. }
  // pingid is the last value produced by the echo-ID counter; it is only
// modified atomically, by newPingID.
var pingid int32

// newPingID returns the next ICMP echo identifier.
//
// The echo ID field is 16 bits wide, so the counter is truncated to that
// range. Stored IDs therefore equal the IDs that replies carry. The sequence
// wraps as ..., 65534, 65535, 0, 1, ... and repeats once more than 65536
// IDs have been handed out.
func newPingID() int {
	return int(uint16(atomic.AddInt32(&pingid, 1)))
}
  98. func (p *mpingconn) ping(addr *net.IPAddr, size int, timeout time.Duration, onSend func(*Packet), onRecv func(*Packet)) (time.Duration, error) {
  99. currentUUID, err := uuid.NewUUID()
  100. if err != nil {
  101. return -1, fmt.Errorf("NewUUID: %w", err)
  102. }
  103. uuidEncoded, err := currentUUID.MarshalBinary()
  104. if err != nil {
  105. return -1, fmt.Errorf("unable to marshal UUID binary: %w", err)
  106. }
  107. p.mutex.Lock()
  108. var pid int
  109. var has bool
  110. if pid, has = p.pingid[addr.String()]; !has {
  111. pid = newPingID()
  112. p.pingid[addr.String()] = pid
  113. }
  114. var psq = p.pingseq[pid]
  115. p.pingseq[pid]++
  116. p.pingfcb[pid] = onRecv
  117. p.mutex.Unlock()
  118. t := append(timeToBytes(time.Now()), uuidEncoded...)
  119. if remainSize := size - len(t); remainSize > 0 {
  120. t = append(t, bytes.Repeat([]byte{1}, remainSize)...)
  121. }
  122. body := &icmp.Echo{
  123. ID: pid,
  124. Seq: psq,
  125. Data: t,
  126. }
  127. msg := &icmp.Message{
  128. Type: p.ICMPRequestType(),
  129. Code: 0,
  130. Body: body,
  131. }
  132. msgBytes, err := msg.Marshal(nil)
  133. if err != nil {
  134. return -1, err
  135. }
  136. err = p.sendICMP(msgBytes, addr)
  137. if err != nil {
  138. return -1, err
  139. }
  140. onSend(&Packet{
  141. Nbytes: len(msgBytes),
  142. IPAddr: addr,
  143. Seq: psq,
  144. ID: pid,
  145. })
  146. time.NewTimer(timeout)
  147. return nil
  148. }
  149. func (p *mpingconn) recvICMP(recv chan<- *packet) error {
  150. bytes := make([]byte, 65536)
  151. for {
  152. select {
  153. case <-p.done:
  154. return nil
  155. default:
  156. var n, ttl int
  157. var err error
  158. n, ttl, _, err = p.Conn.ReadFrom(bytes)
  159. if err != nil {
  160. if neterr, ok := err.(*net.OpError); ok {
  161. if neterr.Timeout() {
  162. // Read timeout
  163. continue
  164. }
  165. }
  166. return err
  167. }
  168. bs := make([]byte, n)
  169. copy(bs, bytes[:n])
  170. select {
  171. case <-p.done:
  172. return nil
  173. case recv <- &packet{bytes: bs, nbytes: n, ttl: ttl}:
  174. }
  175. }
  176. }
  177. }
  178. func (p *mpingconn) sendICMP(msgBytes []byte, addr *net.IPAddr) error {
  179. var dst net.Addr = addr
  180. if p.protocol == "udp" {
  181. dst = &net.UDPAddr{IP: addr.IP, Zone: addr.Zone}
  182. }
  183. for {
  184. select {
  185. case <-p.done:
  186. return nil
  187. default:
  188. }
  189. for {
  190. if _, err := p.Conn.WriteTo(msgBytes, dst); err != nil {
  191. if neterr, ok := err.(*net.OpError); ok {
  192. if neterr.Err == syscall.ENOBUFS {
  193. fmt.Println("缓存不够,发送失败,重发")
  194. continue
  195. }
  196. }
  197. return err
  198. }
  199. }
  200. }
  201. }
  202. func (p *mpingconn) processPacket(recv *packet) error {
  203. receivedAt := time.Now()
  204. var proto int
  205. if p.ipv4 {
  206. proto = protocolICMP
  207. } else {
  208. proto = protocolIPv6ICMP
  209. }
  210. var m *icmp.Message
  211. var err error
  212. if m, err = icmp.ParseMessage(proto, recv.bytes); err != nil {
  213. return fmt.Errorf("error parsing icmp message: %w", err)
  214. }
  215. if m.Type != ipv4.ICMPTypeEchoReply && m.Type != ipv6.ICMPTypeEchoReply {
  216. // Not an echo reply, ignore it
  217. return nil
  218. }
  219. inPkt := &Packet{
  220. Nbytes: recv.nbytes,
  221. IPAddr: p.ipaddr,
  222. Addr: p.addr,
  223. TTL: recv.ttl,
  224. ID: p.id,
  225. }
  226. switch pkt := m.Body.(type) {
  227. case *icmp.Echo:
  228. return p.processEchoReply(pkt, receivedAt, inPkt)
  229. default:
  230. // Very bad, not sure how this can happen
  231. return fmt.Errorf("invalid ICMP echo reply; type: '%T', '%v'", pkt, pkt)
  232. }
  233. }
  234. func (p *mpingconn) processEchoReply(pkt *icmp.Echo, receivedAt time.Time, inPkt *Packet) error {
  235. if !p.matchID(pkt.ID) {
  236. // lastpingtimemutex.Lock()
  237. // ap := idping[pkt.ID]
  238. // lastpingtimemutex.Unlock()
  239. // println(fmt.Sprintf("%#v%s%#v", ap, "\n", p))
  240. return nil
  241. }
  242. if len(pkt.Data) < timeSliceLength+trackerLength {
  243. return fmt.Errorf("insufficient data received; got: %d %v",
  244. len(pkt.Data), pkt.Data)
  245. }
  246. pktUUID, err := p.getPacketUUID(pkt.Data)
  247. if err != nil || pktUUID == nil {
  248. return err
  249. }
  250. timestamp := bytesToTime(pkt.Data[:timeSliceLength])
  251. inPkt.Rtt = receivedAt.Sub(timestamp)
  252. inPkt.Seq = pkt.Seq
  253. // If we've already received this sequence, ignore it.
  254. if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight {
  255. p.PacketsRecvDuplicates++
  256. if p.OnDuplicateRecv != nil {
  257. p.OnDuplicateRecv(inPkt)
  258. }
  259. return nil
  260. }
  261. // remove it from the list of sequences we're waiting for so we don't get duplicates.
  262. delete(p.awaitingSequences[*pktUUID], pkt.Seq)
  263. p.updateStatistics(inPkt)
  264. handler := p.OnRecv
  265. if handler != nil {
  266. handler(inPkt)
  267. }
  268. return nil
  269. }