mpconn.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. package probing
  2. import (
  3. "fmt"
  4. "net"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. "trial/ping/probing/icmp"
  9. "git.wecise.com/wecise/common/logger"
  10. )
  11. var connm sync.Mutex
  12. var conns = map[string]*mpingconn{}
  13. func MPConn(ipv4 bool, protocol string) *mpingconn {
  14. key := protocol
  15. if ipv4 {
  16. key += "4"
  17. } else {
  18. key += "6"
  19. }
  20. connm.Lock()
  21. mp := conns[key]
  22. if mp == nil {
  23. mp = newMPConn(ipv4, protocol)
  24. conns[key] = mp
  25. }
  26. connm.Unlock()
  27. return mp
  28. }
  29. type mpingconn struct {
  30. *icmp.MPacketConn
  31. mutex sync.Mutex
  32. pingidinfo map[int]*mpinfo
  33. }
  34. type mpinfo struct {
  35. host string
  36. ipaddr *net.IPAddr
  37. id int
  38. lastseq int
  39. size int
  40. timeout time.Duration
  41. seqpkt map[int]*icmp.Packet
  42. OnSend func(*icmp.Packet)
  43. OnRecv func(*icmp.Packet)
  44. OnRecvDup func(*icmp.Packet)
  45. OnRecvError func(error)
  46. }
  // Package-level ping state.
var (
	// pingid is the most recently issued ICMP echo ID. It is only advanced with
	// atomic.AddInt32 (see newPingInfo).
	pingid int32

	// ETIMEDOUT is a sentinel error whose message is "timeout".
	// NOTE(review): not returned by the visible code; presumably reported when a
	// reply does not arrive within mpinfo.timeout — confirm at the call sites.
	ETIMEDOUT error = fmt.Errorf("timeout")
)
  49. func newPingInfo(host string, ipaddr *net.IPAddr, size int, timeout time.Duration) *mpinfo {
  50. return &mpinfo{
  51. host: host,
  52. ipaddr: ipaddr,
  53. id: int(atomic.AddInt32(&pingid, 1)),
  54. seqpkt: make(map[int]*icmp.Packet),
  55. size: size,
  56. timeout: timeout,
  57. }
  58. }
  59. func newMPConn(ipv4 bool, protocol string) *mpingconn {
  60. mpconn := &mpingconn{
  61. MPacketConn: &icmp.MPacketConn{
  62. IPV4: ipv4,
  63. Protocol: protocol,
  64. Source: "",
  65. Backlog: 10,
  66. TTL: 64,
  67. },
  68. pingidinfo: make(map[int]*mpinfo),
  69. }
  70. mpconn.MPacketConn.OnRecvPacket = mpconn.OnRecvPacket
  71. mpconn.MPacketConn.OnRecvError = mpconn.OnRecvError
  72. return mpconn
  73. }
  74. func (p *mpingconn) Listen() error {
  75. err := p.MPacketConn.Listen()
  76. if err != nil {
  77. return err
  78. }
  79. return nil
  80. }
  81. func (p *mpingconn) Close() error {
  82. err := p.MPacketConn.Close()
  83. if err != nil {
  84. return err
  85. }
  86. return nil
  87. }
  88. func (p *mpingconn) OnRecvPacket(recvpkt *icmp.Packet) {
  89. // fmt.Println("recv", recvpkt)
  90. p.mutex.Lock()
  91. defer p.mutex.Unlock()
  92. pinfo := p.pingidinfo[recvpkt.ID]
  93. if pinfo == nil {
  94. return
  95. }
  96. _, inflight := pinfo.seqpkt[recvpkt.Seq]
  97. if inflight {
  98. // remove it from the list of sequences we're waiting for so we don't get duplicates.
  99. delete(pinfo.seqpkt, recvpkt.Seq)
  100. if pinfo.OnRecv != nil {
  101. go pinfo.OnRecv(recvpkt)
  102. }
  103. return
  104. }
  105. if pinfo.OnRecvDup != nil {
  106. go pinfo.OnRecvDup(recvpkt)
  107. }
  108. }
  109. func (p *mpingconn) OnRecvError(err error) {
  110. logger.Error(err)
  111. }
  112. func (p *mpingconn) Ping(pinfo *mpinfo) error {
  113. p.mutex.Lock()
  114. if _, has := p.pingidinfo[pinfo.id]; !has {
  115. p.pingidinfo[pinfo.id] = pinfo
  116. }
  117. seq := pinfo.lastseq
  118. pinfo.lastseq++
  119. outpkt := &icmp.Packet{
  120. IPAddr: pinfo.ipaddr,
  121. ID: pinfo.id,
  122. Seq: seq,
  123. Nbytes: pinfo.size,
  124. }
  125. pinfo.seqpkt[seq] = outpkt
  126. p.mutex.Unlock()
  127. err := p.SendPacket(outpkt)
  128. if err != nil {
  129. return err
  130. }
  131. if pinfo.OnSend != nil {
  132. pinfo.OnSend(outpkt)
  133. }
  134. // fmt.Println("sent", outpkt)
  135. return nil
  136. }