packetconn.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. package icmp
  import (
	"encoding/binary"
	"errors"
	"fmt"
	"net"
	"sync"
	"syscall"
	"time"

	"golang.org/x/net/icmp"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)
  // IANA protocol numbers handed to icmp.ParseMessage when decoding replies:
  // 1 is ICMP for IPv4, 58 is ICMPv6.
const (
	protocolIPv6ICMP = 58
	protocolICMP     = 1
)

// ENOLISTENER is returned when an operation needs the socket but Listen has
// not opened one yet.
var ENOLISTENER = fmt.Errorf("no listener")
  18. type Type icmp.Type
  19. type PacketConn interface {
  20. Close() error
  21. ICMPRequestType() Type
  22. ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error)
  23. SetFlagTTL() error
  24. SetReadDeadline(t time.Time) error
  25. WriteTo(b []byte, dst net.Addr) (int, error)
  26. SetTTL(ttl int)
  27. }
  // Network strings passed to icmp.ListenPacket, keyed by MPacketConn.Protocol:
  // "icmp" selects a raw ICMP socket, "udp" a non-privileged datagram-oriented
  // ICMP socket (see the icmp.ListenPacket documentation).
var ipv4Proto = map[string]string{
	"icmp": "ip4:icmp",
	"udp":  "udp4",
}

var ipv6Proto = map[string]string{
	"icmp": "ip6:ipv6-icmp",
	"udp":  "udp6",
}
  32. func Listen(ipv4 bool, protocol string, source string) (conn PacketConn, err error) {
  33. if ipv4 {
  34. var c icmpv4Conn
  35. c.c, err = icmp.ListenPacket(ipv4Proto[protocol], source)
  36. conn = &c
  37. } else {
  38. var c icmpV6Conn
  39. c.c, err = icmp.ListenPacket(ipv6Proto[protocol], source)
  40. conn = &c
  41. }
  42. return
  43. }
  // Packet represents a received and processed ICMP echo packet. SendPacket
// also takes one to describe the echo request to build.
type Packet struct {
	// IPAddr is the address of the host being pinged (for a reply, the
	// address it was received from; nil if that could not be determined).
	IPAddr *net.IPAddr
	// ID is the echo identifier. For replies it is the full value decoded
	// from the payload; only its low 16 bits travel in the ICMP header.
	ID int
	// Seq is the echo sequence number. For replies it is the full value
	// decoded from the payload; only its low 16 bits travel in the header.
	Seq int
	// TTL is the Time To Live reported for the received packet.
	TTL int
	// Nbytes is the message size in bytes: the size requested from
	// BuildEchoRequestMessage when sending, or the number of bytes read from
	// the socket for a reply.
	Nbytes int
	// Rtt is the round-trip time: the reply's receive time minus the send
	// timestamp carried in the echo payload.
	Rtt time.Duration
}
  // recvPkt is one datagram handed from recvICMP to processRecvPacket.
type recvPkt struct {
	recvtime    time.Time // taken right after the datagram was read
	addr        net.Addr  // sender as reported by ReadFrom
	bytes       []byte    // private copy of the bytes read
	nbytes, ttl int       // length read and TTL reported by ReadFrom
}
  66. type MPacketConn struct {
  67. IPV4 bool
  68. Protocol string
  69. Source string
  70. Backlog int
  71. OnRecvPacket func(addr net.Addr, pkt *Packet)
  72. OnRecvError func(error)
  73. mutex sync.Mutex
  74. conn PacketConn
  75. done chan interface{}
  76. recvbuf chan *recvPkt
  77. }
  78. func (mp *MPacketConn) Listen() error {
  79. conn, err := Listen(mp.IPV4, mp.Protocol, mp.Source)
  80. if err != nil {
  81. return err
  82. }
  83. conn.SetTTL(64)
  84. if err := conn.SetFlagTTL(); err != nil {
  85. return err
  86. }
  87. mp.done = make(chan interface{})
  88. mp.recvbuf = make(chan *recvPkt, mp.Backlog)
  89. mp.conn = conn
  90. go mp.recvICMP()
  91. go mp.processRecvPacket()
  92. return nil
  93. }
  94. func (mp *MPacketConn) Close() error {
  95. mp.mutex.Lock()
  96. defer mp.mutex.Unlock()
  97. open := true
  98. select {
  99. case _, open = <-mp.done:
  100. default:
  101. }
  102. if open {
  103. close(mp.done)
  104. }
  105. return mp.conn.Close()
  106. }
  107. func (mp *MPacketConn) recvICMP() {
  108. bytes := make([]byte, 65536)
  109. for {
  110. select {
  111. case <-mp.done:
  112. return
  113. default:
  114. var n, ttl int
  115. var addr net.Addr
  116. var err error
  117. n, ttl, addr, err = mp.conn.ReadFrom(bytes)
  118. if err != nil {
  119. if neterr, ok := err.(*net.OpError); ok {
  120. if neterr.Timeout() {
  121. // Read timeout
  122. continue
  123. }
  124. }
  125. if mp.OnRecvError != nil {
  126. mp.OnRecvError(err)
  127. } else {
  128. fmt.Println(err)
  129. }
  130. }
  131. bs := make([]byte, n)
  132. copy(bs, bytes[:n])
  133. select {
  134. case <-mp.done:
  135. return
  136. case mp.recvbuf <- &recvPkt{recvtime: time.Now(), addr: addr, bytes: bs, nbytes: n, ttl: ttl}:
  137. }
  138. }
  139. }
  140. }
  141. func (mp *MPacketConn) SendPacket(pkt *Packet, addr *net.IPAddr) error {
  142. if mp.conn == nil {
  143. return ENOLISTENER
  144. }
  145. msgBytes, err := BuildEchoRequestMessage(pkt.ID, pkt.Seq, pkt.Nbytes, mp.conn.ICMPRequestType())
  146. if err != nil {
  147. return err
  148. }
  149. var dst net.Addr = addr
  150. if mp.Protocol == "udp" {
  151. dst = &net.UDPAddr{IP: addr.IP, Zone: addr.Zone}
  152. }
  153. for {
  154. select {
  155. case <-mp.done:
  156. return nil
  157. default:
  158. }
  159. if _, err := mp.conn.WriteTo(msgBytes, dst); err != nil {
  160. if neterr, ok := err.(*net.OpError); ok {
  161. if neterr.Err == syscall.ENOBUFS {
  162. if mp.OnRecvError != nil {
  163. mp.OnRecvError(neterr.Err)
  164. } else {
  165. fmt.Println("缓存不够,发送失败,重发")
  166. }
  167. continue
  168. }
  169. }
  170. return err
  171. } else {
  172. return nil
  173. }
  174. }
  175. }
  176. func (mp *MPacketConn) processRecvPacket() {
  177. for pkt := range mp.recvbuf {
  178. err := mp.processPacket(pkt)
  179. if err != nil {
  180. if mp.OnRecvError != nil {
  181. mp.OnRecvError(err)
  182. } else {
  183. fmt.Println(err)
  184. }
  185. }
  186. }
  187. }
  // count is not read or written by any live code in this file; it is kept
// in case other files in the package use it.
var count int
  189. func (mp *MPacketConn) processPacket(recv *recvPkt) error {
  190. var proto int
  191. if mp.IPV4 {
  192. proto = protocolICMP
  193. } else {
  194. proto = protocolIPv6ICMP
  195. }
  196. // fmt.Println(count, "from", recv.addr.String(), "bytes", recv.bytes)
  197. var m *icmp.Message
  198. var err error
  199. if m, err = icmp.ParseMessage(proto, recv.bytes); err != nil {
  200. return fmt.Errorf("error parsing icmp message: %w", err)
  201. }
  202. if m.Type != ipv4.ICMPTypeEchoReply && m.Type != ipv6.ICMPTypeEchoReply {
  203. // Not an echo reply, ignore it
  204. return nil
  205. }
  206. switch pkt := m.Body.(type) {
  207. case *icmp.Echo:
  208. return mp.processEchoReply(pkt, recv)
  209. default:
  210. // Very bad, not sure how this can happen
  211. return fmt.Errorf("invalid ICMP echo reply; type: '%T', '%v'", pkt, pkt)
  212. }
  213. }
  214. func (mp *MPacketConn) processEchoReply(pkt *icmp.Echo, recv *recvPkt) error {
  215. if len(pkt.Data) < 24 {
  216. return nil
  217. }
  218. sendtime := int64(binary.BigEndian.Uint64(pkt.Data[:8]))
  219. fullseq := int(binary.BigEndian.Uint64(pkt.Data[8:16]))
  220. fullid := int(binary.BigEndian.Uint64(pkt.Data[16:24]))
  221. if fullid%65536 != pkt.ID || fullseq%65536 != pkt.Seq {
  222. return nil
  223. }
  224. inPkt := &Packet{
  225. Addr: recv.addr.String(),
  226. ID: fullid,
  227. Seq: fullseq,
  228. Nbytes: recv.nbytes,
  229. TTL: recv.ttl,
  230. Rtt: recv.recvtime.Sub(time.Unix(0, sendtime)),
  231. }
  232. // fmt.Printf("%s %d bytes from %s: icmp_seq=%d time=%v\n",
  233. // time.Now().Format("15:04:05.000"), inPkt.Nbytes, inPkt.IPAddr, inPkt.Seq, inPkt.Rtt)
  234. p.mutex.Lock()
  235. chpkt, inflight := pinfo.seqpkt[fullseq]
  236. if inflight {
  237. // remove it from the list of sequences we're waiting for so we don't get duplicates.
  238. delete(pinfo.seqpkt, fullseq)
  239. }
  240. p.mutex.Unlock()
  241. if chpkt != nil {
  242. chpkt <- inPkt
  243. }
  244. return nil
  245. }