packetconn.go 6.3 KB


  1. package icmp
import (
	"encoding/binary"
	"errors"
	"fmt"
	"net"
	"sync"
	"syscall"
	"time"

	"github.com/google/uuid"
	"golang.org/x/net/icmp"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)
  14. const (
  15. protocolICMP = 1
  16. protocolIPv6ICMP = 58
  17. )
  18. var ENOLISTENER = fmt.Errorf("no listener")
// Type is a defined type whose underlying type is the message-type interface
// icmp.Type from golang.org/x/net/icmp. PacketConn.ICMPRequestType returns it
// to tell SendPacket which echo-request type to build.
type Type icmp.Type
  20. type PacketConn interface {
  21. Close() error
  22. ICMPRequestType() Type
  23. ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error)
  24. SetFlagTTL() error
  25. SetReadDeadline(t time.Time) error
  26. WriteTo(b []byte, dst net.Addr) (int, error)
  27. SetTTL(ttl int)
  28. }
  29. var (
  30. ipv4Proto = map[string]string{"icmp": "ip4:icmp", "udp": "udp4"}
  31. ipv6Proto = map[string]string{"icmp": "ip6:ipv6-icmp", "udp": "udp6"}
  32. )
// Packet represents a received and processed ICMP echo packet.
//
// NOTE(review): SendPacket also takes a *Packet as the outgoing request
// (it reads IPAddr and calls BuildEchoRequestMessage on it), so the same type
// appears to describe both requests and replies.
type Packet struct {
	// IPAddr is the address of the host being pinged. For received replies
	// it is derived from the source address reported by ReadFrom.
	IPAddr *net.IPAddr
	// ID is the ICMP identifier. For received replies it is decoded from
	// payload bytes [16:24) rather than taken from the ICMP header.
	ID int
	// Seq is the ICMP sequence number. For received replies it is decoded
	// from payload bytes [8:16).
	Seq int
	// UUID identifies the request; for received replies it is decoded from
	// payload bytes [24:40).
	UUID uuid.UUID
	// TTL is the Time To Live on the packet.
	TTL int
	// Nbytes is the number of bytes in the message (for replies, the number
	// of bytes returned by ReadFrom).
	Nbytes int
	// SendTime is when the request was sent; for received replies it is
	// decoded from payload bytes [0:8) as Unix nanoseconds.
	SendTime time.Time
	// Rtt is the round-trip time it took to ping.
	Rtt time.Duration
	// TimeoutTimer is not referenced in this file; presumably the sender uses
	// it to time out a request that gets no reply — TODO confirm.
	TimeoutTimer *time.Timer
}
// recvPkt is one raw packet read by recvICMP and queued on
// MPacketConn.recvbuf for processRecvPacket.
type recvPkt struct {
	recvtime time.Time // receive timestamp; Rtt is measured against it
	addr     net.Addr  // source address reported by ReadFrom
	bytes    []byte    // private copy of the received bytes
	nbytes   int       // number of bytes read (len(bytes))
	ttl      int       // TTL reported by ReadFrom
}
// MPacketConn is an ICMP echo endpoint. Listen opens a socket and starts a
// receive goroutine and a processing goroutine. SendPacket writes echo
// requests, and decoded echo replies are delivered to OnRecvPacket. The
// exported fields are read by Listen, so set them before calling it.
type MPacketConn struct {
	// IPV4 selects IPv4 (true) or IPv6 (false) sockets and reply parsing.
	IPV4 bool
	// Protocol is the key ("icmp" or "udp") looked up in ipv4Proto or
	// ipv6Proto. With "udp", SendPacket addresses destinations as
	// *net.UDPAddr.
	Protocol string
	// Source is the local address passed to icmp.ListenPacket.
	Source string
	// Backlog is the capacity of the channel between the receive and the
	// processing goroutines.
	Backlog int
	// TTL is passed to PacketConn.SetTTL when the socket is opened.
	TTL int
	// OnRecvPacket, if set, is called from the processing goroutine for each
	// decoded echo reply.
	OnRecvPacket func(pkt *Packet)
	// OnError, if set, receives read, send-buffer (ENOBUFS) and parse
	// errors. When nil, read errors are printed and the others are ignored.
	OnError func(error)

	mutex   sync.Mutex       // held by Listen and Close while they update conn/done/recvbuf
	conn    PacketConn       // open socket; nil when not listening
	done    chan interface{} // closed by Close to stop the goroutines
	recvbuf chan *recvPkt    // packets handed from recvICMP to processRecvPacket
}
  74. func (mp *MPacketConn) Listen() error {
  75. mp.mutex.Lock()
  76. defer mp.mutex.Unlock()
  77. if mp.conn != nil {
  78. return nil
  79. }
  80. conn, err := mp.listen()
  81. if err != nil {
  82. return err
  83. }
  84. mp.done = make(chan interface{})
  85. mp.recvbuf = make(chan *recvPkt, mp.Backlog)
  86. mp.conn = conn
  87. go mp.recvICMP()
  88. go mp.processRecvPacket()
  89. return nil
  90. }
  91. func (mp *MPacketConn) listen() (conn PacketConn, err error) {
  92. if mp.IPV4 {
  93. var c icmpv4Conn
  94. c.c, err = icmp.ListenPacket(ipv4Proto[mp.Protocol], mp.Source)
  95. conn = &c
  96. } else {
  97. var c icmpV6Conn
  98. c.c, err = icmp.ListenPacket(ipv6Proto[mp.Protocol], mp.Source)
  99. conn = &c
  100. }
  101. if err != nil {
  102. return nil, err
  103. }
  104. conn.SetTTL(mp.TTL)
  105. if err := conn.SetFlagTTL(); err != nil {
  106. conn.Close()
  107. return nil, err
  108. }
  109. return conn, nil
  110. }
  111. func (mp *MPacketConn) Close() error {
  112. mp.mutex.Lock()
  113. defer mp.mutex.Unlock()
  114. open := true
  115. select {
  116. case _, open = <-mp.done:
  117. default:
  118. }
  119. if open {
  120. close(mp.done)
  121. }
  122. if mp.conn != nil {
  123. mp.conn.Close()
  124. mp.conn = nil
  125. }
  126. return nil
  127. }
  128. func (mp *MPacketConn) recvICMP() {
  129. bytes := make([]byte, 65536)
  130. for {
  131. select {
  132. case <-mp.done:
  133. return
  134. default:
  135. conn := mp.conn
  136. if conn == nil {
  137. return
  138. }
  139. var n, ttl int
  140. var addr net.Addr
  141. var err error
  142. n, ttl, addr, err = conn.ReadFrom(bytes)
  143. if err != nil {
  144. if neterr, ok := err.(*net.OpError); ok {
  145. if neterr.Timeout() {
  146. // Read timeout
  147. continue
  148. }
  149. }
  150. if mp.OnError != nil {
  151. mp.OnError(err)
  152. } else {
  153. fmt.Println("ReadFrom Error:", err)
  154. }
  155. }
  156. bs := make([]byte, n)
  157. copy(bs, bytes[:n])
  158. select {
  159. case <-mp.done:
  160. return
  161. case mp.recvbuf <- &recvPkt{recvtime: time.Now(), addr: addr, bytes: bs, nbytes: n, ttl: ttl}:
  162. }
  163. }
  164. }
  165. }
  166. func (mp *MPacketConn) SendPacket(pkt *Packet) error {
  167. conn := mp.conn
  168. if conn == nil {
  169. return ENOLISTENER
  170. }
  171. var dst net.Addr = pkt.IPAddr
  172. if mp.Protocol == "udp" {
  173. dst = &net.UDPAddr{IP: pkt.IPAddr.IP, Zone: pkt.IPAddr.Zone}
  174. }
  175. for {
  176. select {
  177. case <-mp.done:
  178. return nil
  179. default:
  180. }
  181. msgBytes, err := pkt.BuildEchoRequestMessage(conn.ICMPRequestType())
  182. if err != nil {
  183. return err
  184. }
  185. if _, err := conn.WriteTo(msgBytes, dst); err != nil {
  186. if neterr, ok := err.(*net.OpError); ok {
  187. if neterr.Err == syscall.ENOBUFS {
  188. if mp.OnError != nil {
  189. mp.OnError(neterr.Err)
  190. } else {
  191. // 运行时默认忽略缓存不够错误
  192. // fmt.Println("缓存不够,发送失败,重发")
  193. }
  194. continue
  195. }
  196. }
  197. return err
  198. } else {
  199. return nil
  200. }
  201. }
  202. }
  203. var max_receive_buffer_used = 0
  204. func MaxReceiveBufferUsed() int {
  205. return max_receive_buffer_used
  206. }
  207. func (mp *MPacketConn) processRecvPacket() {
  208. for pkt := range mp.recvbuf {
  209. if len(mp.recvbuf) > max_receive_buffer_used {
  210. max_receive_buffer_used = len(mp.recvbuf)
  211. }
  212. err := mp.processPacket(pkt)
  213. if err != nil {
  214. if mp.OnError != nil {
  215. mp.OnError(err)
  216. } else {
  217. // 运行时默认忽略接收数据格式不符错误
  218. // fmt.Println(err)
  219. }
  220. }
  221. }
  222. }
  223. var count = 0
  224. func (mp *MPacketConn) processPacket(recv *recvPkt) error {
  225. var proto int
  226. if mp.IPV4 {
  227. proto = protocolICMP
  228. } else {
  229. proto = protocolIPv6ICMP
  230. }
  231. // fmt.Println(count, "from", recv.addr.String(), "bytes", recv.bytes)
  232. var m *icmp.Message
  233. var err error
  234. if m, err = icmp.ParseMessage(proto, recv.bytes); err != nil {
  235. return fmt.Errorf("error parsing icmp message: %w", err)
  236. }
  237. if m.Type != ipv4.ICMPTypeEchoReply && m.Type != ipv6.ICMPTypeEchoReply {
  238. // Not an echo reply, ignore it
  239. return nil
  240. }
  241. switch pkt := m.Body.(type) {
  242. case *icmp.Echo:
  243. return mp.processEchoReply(pkt, recv)
  244. default:
  245. // Very bad, not sure how this can happen
  246. return fmt.Errorf("invalid ICMP echo reply; type: '%T', '%v'", pkt, pkt)
  247. }
  248. }
  249. func (mp *MPacketConn) processEchoReply(pkt *icmp.Echo, recv *recvPkt) error {
  250. if len(pkt.Data) < 40 {
  251. return nil
  252. }
  253. sendtime := int64(binary.BigEndian.Uint64(pkt.Data[:8]))
  254. fullseq := int(binary.BigEndian.Uint64(pkt.Data[8:16]))
  255. fullid := int(binary.BigEndian.Uint64(pkt.Data[16:24]))
  256. pktuuid := uuid.Must(uuid.FromBytes(pkt.Data[24:40]))
  257. // Linux 下 UDP 方式,接收的 EchoReply.ID 与发送的 Echo.ID 是不一致的
  258. // if fullid%65536 != pkt.ID || fullseq%65536 != pkt.Seq {
  259. // return nil
  260. // }
  261. // fmt.Printf("%s %d bytes from %s: icmp_seq=%d time=%v\n",
  262. // time.Now().Format("15:04:05.000"), recv.nbytes, recv.addr, fullseq, recv.recvtime.Sub(time.Unix(0, sendtime)))
  263. if mp.OnRecvPacket != nil {
  264. mp.OnRecvPacket(&Packet{
  265. IPAddr: netAddrToIPAddr(recv.addr),
  266. ID: fullid,
  267. Seq: fullseq,
  268. UUID: pktuuid,
  269. Nbytes: recv.nbytes,
  270. TTL: recv.ttl,
  271. SendTime: time.Unix(0, sendtime),
  272. Rtt: recv.recvtime.Sub(time.Unix(0, sendtime)),
  273. })
  274. }
  275. return nil
  276. }