packetconn.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. package icmp
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "net"
  6. "sync"
  7. "syscall"
  8. "time"
  9. "golang.org/x/net/icmp"
  10. "golang.org/x/net/ipv4"
  11. "golang.org/x/net/ipv6"
  12. )
  13. const (
  14. protocolICMP = 1
  15. protocolIPv6ICMP = 58
  16. )
  17. var ENOLISTENER = fmt.Errorf("no listener")
// Type is this package's named type over icmp.Type from golang.org/x/net/icmp.
// PacketConn.ICMPRequestType returns it, and SendPacket forwards that value to
// BuildEchoRequestMessage (presumably to pick the IPv4 vs IPv6 echo-request
// type — confirm in BuildEchoRequestMessage).
type Type icmp.Type
  19. type PacketConn interface {
  20. Close() error
  21. ICMPRequestType() Type
  22. ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error)
  23. SetFlagTTL() error
  24. SetReadDeadline(t time.Time) error
  25. WriteTo(b []byte, dst net.Addr) (int, error)
  26. SetTTL(ttl int)
  27. }
  28. var (
  29. ipv4Proto = map[string]string{"icmp": "ip4:icmp", "udp": "udp4"}
  30. ipv6Proto = map[string]string{"icmp": "ip6:ipv6-icmp", "udp": "udp6"}
  31. )
  32. func listen(ipv4 bool, protocol string, source string) (conn PacketConn, err error) {
  33. if ipv4 {
  34. var c icmpv4Conn
  35. c.c, err = icmp.ListenPacket(ipv4Proto[protocol], source)
  36. conn = &c
  37. } else {
  38. var c icmpV6Conn
  39. c.c, err = icmp.ListenPacket(ipv6Proto[protocol], source)
  40. conn = &c
  41. }
  42. return
  43. }
  44. // Packet represents a received and processed ICMP echo packet.
  45. type Packet struct {
  46. // IPAddr is the address of the host being pinged.
  47. IPAddr *net.IPAddr
  48. // ID is the ICMP identifier.
  49. ID int
  50. // Seq is the ICMP sequence number.
  51. Seq int
  52. // TTL is the Time To Live on the packet.
  53. TTL int
  54. // NBytes is the number of bytes in the message.
  55. Nbytes int
  56. // Rtt is the round-trip time it took to ping.
  57. Rtt time.Duration
  58. }
// recvPkt is one raw message handed from recvICMP to processRecvPacket
// through MPacketConn.recvbuf.
type recvPkt struct {
	// recvtime is the local time recorded by recvICMP after the read; it is
	// the end point of the Rtt computed in processEchoReply.
	recvtime time.Time
	// addr is the source address returned by PacketConn.ReadFrom.
	addr net.Addr
	// bytes is a private copy of the message, so recvICMP can reuse its
	// read buffer for the next read.
	bytes []byte
	// nbytes is the byte count returned by ReadFrom.
	nbytes int
	// ttl is the TTL that ReadFrom reported for the message.
	ttl int
}
// MPacketConn wraps a PacketConn with two background goroutines. Listen
// opens the connection and starts recvICMP, which reads and queues raw
// messages, and processRecvPacket, which parses them and reports echo
// replies. Listen reads the exported configuration fields, so set them
// before calling it.
type MPacketConn struct {
	// IPV4 selects the IPv4 networks in ipv4Proto; false selects IPv6.
	IPV4 bool
	// Protocol is a key of ipv4Proto/ipv6Proto ("icmp" or "udp"). With
	// "udp", SendPacket addresses destinations as *net.UDPAddr.
	Protocol string
	// Source is the local address passed to icmp.ListenPacket.
	Source string
	// Backlog is the capacity of the recvbuf queue between the goroutines.
	Backlog int
	// TTL is passed to the connection's SetTTL by Listen.
	TTL int
	// OnRecvPacket, if set, is called for each echo reply whose payload
	// matches its ICMP ID and sequence number. It runs on the
	// processRecvPacket goroutine.
	OnRecvPacket func(pkt *Packet)
	// OnRecvError, if set, receives read errors, parse errors and ENOBUFS
	// send retries; when nil they are printed with fmt.Println. It may be
	// called from the background goroutines and from SendPacket's caller.
	OnRecvError func(error)
	// mutex serializes Listen and Close.
	mutex sync.Mutex
	// conn is the open connection; nil until Listen succeeds.
	conn PacketConn
	// done is closed by Close to tell the goroutines and SendPacket to stop.
	done chan interface{}
	// recvbuf carries raw messages from recvICMP to processRecvPacket.
	recvbuf chan *recvPkt
}
  79. func (mp *MPacketConn) Listen() error {
  80. mp.mutex.Lock()
  81. defer mp.mutex.Unlock()
  82. if mp.conn != nil {
  83. return nil
  84. }
  85. conn, err := listen(mp.IPV4, mp.Protocol, mp.Source)
  86. if err != nil {
  87. return err
  88. }
  89. conn.SetTTL(mp.TTL)
  90. if err := conn.SetFlagTTL(); err != nil {
  91. return err
  92. }
  93. mp.done = make(chan interface{})
  94. mp.recvbuf = make(chan *recvPkt, mp.Backlog)
  95. mp.conn = conn
  96. go mp.recvICMP()
  97. go mp.processRecvPacket()
  98. return nil
  99. }
  100. func (mp *MPacketConn) Close() error {
  101. mp.mutex.Lock()
  102. defer mp.mutex.Unlock()
  103. open := true
  104. select {
  105. case _, open = <-mp.done:
  106. default:
  107. }
  108. if open {
  109. close(mp.done)
  110. }
  111. return mp.conn.Close()
  112. }
  113. func (mp *MPacketConn) recvICMP() {
  114. bytes := make([]byte, 65536)
  115. for {
  116. select {
  117. case <-mp.done:
  118. return
  119. default:
  120. var n, ttl int
  121. var addr net.Addr
  122. var err error
  123. n, ttl, addr, err = mp.conn.ReadFrom(bytes)
  124. if err != nil {
  125. if neterr, ok := err.(*net.OpError); ok {
  126. if neterr.Timeout() {
  127. // Read timeout
  128. continue
  129. }
  130. }
  131. if mp.OnRecvError != nil {
  132. mp.OnRecvError(err)
  133. } else {
  134. fmt.Println(err)
  135. }
  136. }
  137. bs := make([]byte, n)
  138. copy(bs, bytes[:n])
  139. select {
  140. case <-mp.done:
  141. return
  142. case mp.recvbuf <- &recvPkt{recvtime: time.Now(), addr: addr, bytes: bs, nbytes: n, ttl: ttl}:
  143. }
  144. }
  145. }
  146. }
  147. func (mp *MPacketConn) SendPacket(pkt *Packet) error {
  148. if mp.conn == nil {
  149. return ENOLISTENER
  150. }
  151. msgBytes, err := BuildEchoRequestMessage(pkt.ID, pkt.Seq, pkt.Nbytes, mp.conn.ICMPRequestType())
  152. if err != nil {
  153. return err
  154. }
  155. var dst net.Addr = pkt.IPAddr
  156. if mp.Protocol == "udp" {
  157. dst = &net.UDPAddr{IP: pkt.IPAddr.IP, Zone: pkt.IPAddr.Zone}
  158. }
  159. for {
  160. select {
  161. case <-mp.done:
  162. return nil
  163. default:
  164. }
  165. if _, err := mp.conn.WriteTo(msgBytes, dst); err != nil {
  166. if neterr, ok := err.(*net.OpError); ok {
  167. if neterr.Err == syscall.ENOBUFS {
  168. if mp.OnRecvError != nil {
  169. mp.OnRecvError(neterr.Err)
  170. } else {
  171. fmt.Println("缓存不够,发送失败,重发")
  172. }
  173. continue
  174. }
  175. }
  176. return err
  177. } else {
  178. return nil
  179. }
  180. }
  181. }
  182. func (mp *MPacketConn) processRecvPacket() {
  183. for pkt := range mp.recvbuf {
  184. err := mp.processPacket(pkt)
  185. if err != nil {
  186. if mp.OnRecvError != nil {
  187. mp.OnRecvError(err)
  188. } else {
  189. fmt.Println(err)
  190. }
  191. }
  192. }
  193. }
  194. var count = 0
  195. func (mp *MPacketConn) processPacket(recv *recvPkt) error {
  196. var proto int
  197. if mp.IPV4 {
  198. proto = protocolICMP
  199. } else {
  200. proto = protocolIPv6ICMP
  201. }
  202. // fmt.Println(count, "from", recv.addr.String(), "bytes", recv.bytes)
  203. var m *icmp.Message
  204. var err error
  205. if m, err = icmp.ParseMessage(proto, recv.bytes); err != nil {
  206. return fmt.Errorf("error parsing icmp message: %w", err)
  207. }
  208. if m.Type != ipv4.ICMPTypeEchoReply && m.Type != ipv6.ICMPTypeEchoReply {
  209. // Not an echo reply, ignore it
  210. return nil
  211. }
  212. switch pkt := m.Body.(type) {
  213. case *icmp.Echo:
  214. return mp.processEchoReply(pkt, recv)
  215. default:
  216. // Very bad, not sure how this can happen
  217. return fmt.Errorf("invalid ICMP echo reply; type: '%T', '%v'", pkt, pkt)
  218. }
  219. }
  220. func (mp *MPacketConn) processEchoReply(pkt *icmp.Echo, recv *recvPkt) error {
  221. if len(pkt.Data) < 24 {
  222. return nil
  223. }
  224. sendtime := int64(binary.BigEndian.Uint64(pkt.Data[:8]))
  225. fullseq := int(binary.BigEndian.Uint64(pkt.Data[8:16]))
  226. fullid := int(binary.BigEndian.Uint64(pkt.Data[16:24]))
  227. if fullid%65536 != pkt.ID || fullseq%65536 != pkt.Seq {
  228. return nil
  229. }
  230. // fmt.Printf("%s %d bytes from %s: icmp_seq=%d time=%v\n",
  231. // time.Now().Format("15:04:05.000"), recv.nbytes, recv.addr, fullseq, recv.recvtime.Sub(time.Unix(0, sendtime)))
  232. if mp.OnRecvPacket != nil {
  233. mp.OnRecvPacket(&Packet{
  234. IPAddr: netAddrToIPAddr(recv.addr),
  235. ID: fullid,
  236. Seq: fullseq,
  237. Nbytes: recv.nbytes,
  238. TTL: recv.ttl,
  239. Rtt: recv.recvtime.Sub(time.Unix(0, sendtime)),
  240. })
  241. }
  242. return nil
  243. }