mpconn.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. package probing
  2. import (
  3. "bytes"
  4. "net"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. "trial/ping/probing/icmp"
  9. "github.com/google/uuid"
  10. )
  11. var connm sync.Mutex
  12. var conns = map[string]*mpingconn{}
  13. func MPConn(ipv4 bool, protocol string) *mpingconn {
  14. key := protocol
  15. if ipv4 {
  16. key += "4"
  17. } else {
  18. key += "6"
  19. }
  20. connm.Lock()
  21. mp := conns[key]
  22. if mp == nil {
  23. mp = newMPConn(ipv4, protocol)
  24. conns[key] = mp
  25. }
  26. connm.Unlock()
  27. return mp
  28. }
  29. type mpingconn struct {
  30. *icmp.MPacketConn
  31. mutex sync.Mutex
  32. uuid uuid.UUID
  33. pingidinfo map[int]*mpinfo
  34. }
  35. type mpinfo struct {
  36. host string
  37. ipaddr *net.IPAddr
  38. id int
  39. lastseq int
  40. size int
  41. timeout time.Duration
  42. seqpkt map[int]*icmp.Packet
  43. OnSend func(outpkt *icmp.Packet)
  44. OnRecv func(inpkt *icmp.Packet)
  45. OnRecvDup func(inpkt *icmp.Packet)
  46. OnRecvTimeout func(outpkt *icmp.Packet)
  47. }
  48. var pingid int32
  49. var pinginfomutex sync.Mutex
  50. var pinginfo = map[string]*mpinfo{}
  51. func getPingInfo(ipaddr *net.IPAddr) *mpinfo {
  52. pinginfomutex.Lock()
  53. defer pinginfomutex.Unlock()
  54. pinfo := pinginfo[ipaddr.String()]
  55. if pinfo == nil {
  56. pinfo = &mpinfo{
  57. ipaddr: ipaddr,
  58. id: int(atomic.AddInt32(&pingid, 1)),
  59. seqpkt: make(map[int]*icmp.Packet),
  60. }
  61. pinginfo[ipaddr.String()] = pinfo
  62. }
  63. return pinfo
  64. }
  65. func newMPConn(ipv4 bool, protocol string) *mpingconn {
  66. mpconn := &mpingconn{
  67. MPacketConn: &icmp.MPacketConn{
  68. IPV4: ipv4,
  69. Protocol: protocol,
  70. Source: "",
  71. Backlog: receive_buffer_count,
  72. TTL: ping_ttl,
  73. },
  74. uuid: uuid.Must(uuid.NewUUID()),
  75. pingidinfo: make(map[int]*mpinfo),
  76. }
  77. mpconn.MPacketConn.OnRecvPacket = mpconn.OnRecvPacket
  78. mpconn.MPacketConn.OnError = mpconn.OnError
  79. return mpconn
  80. }
  81. func (p *mpingconn) Listen() error {
  82. err := p.MPacketConn.Listen()
  83. if err != nil {
  84. return err
  85. }
  86. return nil
  87. }
  88. func (p *mpingconn) Close() error {
  89. err := p.MPacketConn.Close()
  90. if err != nil {
  91. return err
  92. }
  93. return nil
  94. }
  95. func (p *mpingconn) OnRecvPacket(recvpkt *icmp.Packet) {
  96. // fmt.Println("recv", recvpkt)
  97. p.mutex.Lock()
  98. defer p.mutex.Unlock()
  99. pinfo := p.pingidinfo[recvpkt.ID]
  100. if pinfo == nil {
  101. return
  102. }
  103. if !bytes.Equal(p.uuid[:], recvpkt.UUID[:]) {
  104. return
  105. }
  106. outpkt, inflight := pinfo.seqpkt[recvpkt.Seq]
  107. if inflight {
  108. // remove it from the list of sequences we're waiting for so we don't get duplicates.
  109. if outpkt.TimeoutTimer != nil {
  110. outpkt.TimeoutTimer.Stop()
  111. }
  112. delete(pinfo.seqpkt, recvpkt.Seq)
  113. if pinfo.OnRecv != nil {
  114. go pinfo.OnRecv(recvpkt)
  115. }
  116. return
  117. }
  118. if pinfo.OnRecvDup != nil {
  119. go pinfo.OnRecvDup(recvpkt)
  120. }
  121. }
  122. func (p *mpingconn) OnError(err error) {
  123. logger.Error(err)
  124. }
  125. func (p *mpingconn) Ping(pinfo *mpinfo) error {
  126. p.mutex.Lock()
  127. if _, has := p.pingidinfo[pinfo.id]; !has {
  128. p.pingidinfo[pinfo.id] = pinfo
  129. }
  130. seq := pinfo.lastseq
  131. pinfo.lastseq++
  132. outpkt := &icmp.Packet{
  133. IPAddr: pinfo.ipaddr,
  134. ID: pinfo.id,
  135. Seq: seq,
  136. UUID: p.uuid,
  137. Nbytes: pinfo.size,
  138. }
  139. pinfo.seqpkt[seq] = outpkt
  140. p.mutex.Unlock()
  141. err := p.SendPacket(outpkt)
  142. if err != nil {
  143. return err
  144. }
  145. if pinfo.OnSend != nil {
  146. pinfo.OnSend(outpkt)
  147. }
  148. if pinfo.OnRecvTimeout != nil {
  149. outpkt.TimeoutTimer = time.AfterFunc(pinfo.timeout, func() {
  150. p.mutex.Lock()
  151. outpkt := pinfo.seqpkt[seq]
  152. p.mutex.Unlock()
  153. if outpkt != nil {
  154. pinfo.OnRecvTimeout(outpkt)
  155. }
  156. })
  157. }
  158. // fmt.Println("sent", outpkt)
  159. return nil
  160. }