mpconn.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. package probing
  2. import (
  3. "bytes"
  4. "net"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. "trial/ping/probing/icmp"
  9. "github.com/google/uuid"
  10. )
  11. var connm sync.Mutex
  12. var conns = map[string]*mpingconn{}
  13. func MPConn(ipv4 bool, protocol string) *mpingconn {
  14. key := protocol
  15. if ipv4 {
  16. key += "4"
  17. } else {
  18. key += "6"
  19. }
  20. connm.Lock()
  21. mp := conns[key]
  22. if mp == nil {
  23. mp = newMPConn(ipv4, protocol)
  24. conns[key] = mp
  25. }
  26. connm.Unlock()
  27. return mp
  28. }
  // mpingconn wraps an icmp.MPacketConn that is shared by every ping target
// using the same protocol and address family (see MPConn). Replies are routed
// back to the target registered under their echo ID.
type mpingconn struct {
	*icmp.MPacketConn
	// mutex guards pingidinfo and the per-target sequence state (lastseq,
	// seqpkt) that Ping and OnRecvPacket modify through it.
	mutex sync.Mutex
	// uuid is generated once per connection and copied into every outgoing
	// packet; OnRecvPacket drops replies whose UUID does not match.
	uuid uuid.UUID
	// pingidinfo maps an echo ID to the target that Ping registered under it.
	pingidinfo map[int]*mpinfo
}
  // mpinfo holds the ping state for one destination. getPingInfo keeps one
// instance per destination address, and Ping registers it with a connection
// under its id the first time it is used there.
type mpinfo struct {
	// host is not set or read in this file; presumably the caller-supplied
	// host name — TODO confirm against callers.
	host string
	// ipaddr is the destination every echo request for this target is sent to.
	ipaddr *net.IPAddr
	// id is the echo identifier, allocated from the package-wide pingid counter.
	id int
	// lastseq is the sequence number the next Ping call will use; Ping
	// increments it after each request is built.
	lastseq int
	// size is copied into each outgoing packet's Nbytes.
	size int
	// timeout is not read in this file — NOTE(review): presumably used by
	// reply-timeout handling elsewhere in the package; confirm.
	timeout time.Duration
	// seqpkt holds sent-but-unanswered packets keyed by sequence number.
	// OnRecvPacket removes an entry on the first reply, so any later reply
	// with the same sequence is reported as a duplicate.
	seqpkt map[int]*icmp.Packet
	// OnSend, if set, runs synchronously inside Ping after a successful send.
	OnSend func(*icmp.Packet)
	// OnRecv, if set, runs in a new goroutine for the first reply to an
	// in-flight sequence.
	OnRecv func(*icmp.Packet)
	// OnRecvDup, if set, runs in a new goroutine for a reply whose sequence
	// is not currently in flight.
	OnRecvDup func(*icmp.Packet)
	// OnRecvError is not invoked in this file — NOTE(review): confirm where
	// it is called.
	OnRecvError func(error)
}
  48. var pingid int32
  49. var pinginfomutex sync.Mutex
  50. var pinginfo = map[string]*mpinfo{}
  51. func getPingInfo(ipaddr *net.IPAddr) *mpinfo {
  52. pinginfomutex.Lock()
  53. defer pinginfomutex.Unlock()
  54. pinfo := pinginfo[ipaddr.String()]
  55. if pinfo == nil {
  56. pinfo = &mpinfo{
  57. ipaddr: ipaddr,
  58. id: int(atomic.AddInt32(&pingid, 1)),
  59. seqpkt: make(map[int]*icmp.Packet),
  60. }
  61. pinginfo[ipaddr.String()] = pinfo
  62. }
  63. return pinfo
  64. }
  65. func newMPConn(ipv4 bool, protocol string) *mpingconn {
  66. mpconn := &mpingconn{
  67. MPacketConn: &icmp.MPacketConn{
  68. IPV4: ipv4,
  69. Protocol: protocol,
  70. Source: "",
  71. Backlog: receive_buffer_count,
  72. TTL: ping_ttl,
  73. },
  74. uuid: uuid.Must(uuid.NewUUID()),
  75. pingidinfo: make(map[int]*mpinfo),
  76. }
  77. mpconn.MPacketConn.OnRecvPacket = mpconn.OnRecvPacket
  78. mpconn.MPacketConn.OnRecvError = mpconn.OnRecvError
  79. return mpconn
  80. }
  81. func (p *mpingconn) Listen() error {
  82. err := p.MPacketConn.Listen()
  83. if err != nil {
  84. return err
  85. }
  86. return nil
  87. }
  88. func (p *mpingconn) Close() error {
  89. err := p.MPacketConn.Close()
  90. if err != nil {
  91. return err
  92. }
  93. return nil
  94. }
  95. func (p *mpingconn) OnRecvPacket(recvpkt *icmp.Packet) {
  96. // fmt.Println("recv", recvpkt)
  97. p.mutex.Lock()
  98. defer p.mutex.Unlock()
  99. pinfo := p.pingidinfo[recvpkt.ID]
  100. if pinfo == nil {
  101. return
  102. }
  103. if !bytes.Equal(p.uuid[:], recvpkt.UUID[:]) {
  104. return
  105. }
  106. _, inflight := pinfo.seqpkt[recvpkt.Seq]
  107. if inflight {
  108. // remove it from the list of sequences we're waiting for so we don't get duplicates.
  109. delete(pinfo.seqpkt, recvpkt.Seq)
  110. if pinfo.OnRecv != nil {
  111. go pinfo.OnRecv(recvpkt)
  112. }
  113. return
  114. }
  115. if pinfo.OnRecvDup != nil {
  116. go pinfo.OnRecvDup(recvpkt)
  117. }
  118. }
  119. func (p *mpingconn) OnRecvError(err error) {
  120. logger.Error(err)
  121. }
  122. func (p *mpingconn) Ping(pinfo *mpinfo) error {
  123. p.mutex.Lock()
  124. if _, has := p.pingidinfo[pinfo.id]; !has {
  125. p.pingidinfo[pinfo.id] = pinfo
  126. }
  127. seq := pinfo.lastseq
  128. pinfo.lastseq++
  129. outpkt := &icmp.Packet{
  130. IPAddr: pinfo.ipaddr,
  131. ID: pinfo.id,
  132. Seq: seq,
  133. UUID: p.uuid,
  134. Nbytes: pinfo.size,
  135. }
  136. pinfo.seqpkt[seq] = outpkt
  137. p.mutex.Unlock()
  138. err := p.SendPacket(outpkt)
  139. if err != nil {
  140. return err
  141. }
  142. if pinfo.OnSend != nil {
  143. pinfo.OnSend(outpkt)
  144. }
  145. // fmt.Println("sent", outpkt)
  146. return nil
  147. }