package probing

import (
	"encoding/binary"
	"errors"
	"fmt"
	"net"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
	"trial/ping/probing/icmp"
)

// recvPkt is one raw datagram read from the shared connection, queued on
// mpingconn.recvchan for parsing.
type recvPkt struct {
	recvtime time.Time // time the read returned
	addr     net.Addr  // source address reported by ReadFrom
	bytes    []byte    // private copy of the received payload
	nbytes   int       // number of bytes read
	ttl      int       // TTL reported by ReadFrom
}

// mpingconn multiplexes echo requests for many targets over one shared ICMP
// connection. It also satisfies icmp.PacketConn by delegating to Conn.
type mpingconn struct {
	mutex    sync.Mutex // guards Conn setup, both maps and every mpinfo reachable from them
	ipv4     bool
	protocol string
	Source   string
	Conn     icmp.PacketConn

	done     chan interface{} // closed by Close to stop background goroutines
	recvchan chan *recvPkt    // raw packets handed from recvICMP to processRecvPacket

	pinghostinfo map[string]*mpinfo // keyed by ipaddr.String()
	pingidinfo   map[int]*mpinfo    // keyed by mpinfo.id

	// OnError, when set, receives errors raised by the background goroutines
	// and ENOBUFS notifications from sendICMP.
	OnError func(error)
}

// mpinfo is the per-target state: identity, the next sequence number and the
// reply channels of requests that are still in flight.
type mpinfo struct {
	host    string
	ipaddr  *net.IPAddr
	id      int
	lastseq int // next sequence number to hand out
	size    int
	timeout time.Duration
	seqpkt  map[int]chan *Packet // in-flight requests keyed by sequence number

	// NOTE(review): the callbacks below are not referenced anywhere in this
	// file — confirm whether other files in the package use them.
	OnSend      func(*Packet)
	OnRecv      func(*Packet)
	OnRecvDup   func(*Packet)
	OnRecvError func(error)
}

// mpconn is the package-wide shared connection.
var mpconn = newMPConn()

// newMPConn returns an unopened IPv4/UDP mpingconn; the socket itself is
// created lazily by listen.
func newMPConn() *mpingconn {
	return &mpingconn{
		ipv4:         true,
		protocol:     "udp",
		Source:       "",
		done:         make(chan interface{}),
		recvchan:     make(chan *recvPkt, receive_buffer_count),
		pinghostinfo: make(map[string]*mpinfo),
		pingidinfo:   make(map[int]*mpinfo),
	}
}

// listen opens the shared connection on first use and starts the receive and
// processing goroutines. Every successful call returns p itself, so callers
// always go through the wrapper (whose Close also signals done) rather than
// the raw socket.
//
// If configuring the new socket fails it is closed and not cached, so a later
// call retries from scratch.
func (p *mpingconn) listen() (icmp.PacketConn, error) {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	if p.Conn != nil {
		return p, nil
	}

	conn, err := icmp.Listen(p.ipv4, p.protocol, p.Source)
	if err != nil {
		return nil, err
	}
	conn.SetTTL(64)
	if err := conn.SetFlagTTL(); err != nil {
		// Best-effort cleanup; the SetFlagTTL error is the one worth reporting.
		_ = conn.Close()
		return nil, err
	}
	// Publish only a fully configured connection.
	p.Conn = conn

	go func() {
		if err := p.recvICMP(); err != nil {
			p.reportError(err)
		}
	}()
	go p.processRecvPacket()
	return p, nil
}

// reportError hands err to OnError, or prints it to stdout when no handler is
// installed.
func (p *mpingconn) reportError(err error) {
	if p.OnError != nil {
		p.OnError(err)
		return
	}
	fmt.Println(err)
}

// Close signals the background goroutines to stop and closes the underlying
// connection. The done channel is closed at most once, and calling Close
// before listen has opened a connection is a no-op on the socket.
//
// Conn is left set, so a later listen returns the closed connection rather
// than opening a new one.
func (p *mpingconn) Close() error {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	select {
	case <-p.done:
		// Already closed.
	default:
		close(p.done)
	}

	if p.Conn == nil {
		return nil
	}
	return p.Conn.Close()
}

// ICMPRequestType delegates to the underlying connection.
func (p *mpingconn) ICMPRequestType() icmp.Type {
	return p.Conn.ICMPRequestType()
}

// ReadFrom delegates to the underlying connection.
func (p *mpingconn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
	return p.Conn.ReadFrom(b)
}

// WriteTo delegates to the underlying connection.
func (p *mpingconn) WriteTo(b []byte, dst net.Addr) (int, error) {
	return p.Conn.WriteTo(b, dst)
}

// SetReadDeadline delegates to the underlying connection.
func (p *mpingconn) SetReadDeadline(t time.Time) error {
	return p.Conn.SetReadDeadline(t)
}

// SetFlagTTL delegates to the underlying connection.
func (p *mpingconn) SetFlagTTL() error {
	return p.Conn.SetFlagTTL()
}

// SetTTL delegates to the underlying connection.
func (p *mpingconn) SetTTL(ttl int) {
	p.Conn.SetTTL(ttl)
}

// pingid is the source of unique mpinfo ids; updated atomically.
var pingid int32

// ETIMEDOUT is a sentinel timeout error. It is not referenced in this file.
var ETIMEDOUT error = errors.New("timeout")

// newPingInfo allocates the per-target state and assigns it a fresh id.
func newPingInfo(host string, ipaddr *net.IPAddr, size int, timeout time.Duration) *mpinfo {
	return &mpinfo{
		host:    host,
		ipaddr:  ipaddr,
		id:      int(atomic.AddInt32(&pingid, 1)),
		seqpkt:  make(map[int]chan *Packet),
		size:    size,
		timeout: timeout,
	}
}

// ping sends one echo request of the given size to ipaddr and returns once it
// has been written. onSend, if non-nil, is called synchronously after the
// write. onRecv, if non-nil, is called from a background goroutine when the
// matching reply arrives within timeout.
//
// The in-flight entry for the request is removed when the request cannot be
// built or sent, when it times out, and when the connection is closed, so
// unanswered requests do not accumulate in the per-target map.
func (p *mpingconn) ping(host string, ipaddr *net.IPAddr, size int, timeout time.Duration, onSend func(*Packet), onRecv func(*Packet)) error {
	key := ipaddr.String()

	p.mutex.Lock()
	pinfo, known := p.pinghostinfo[key]
	if !known {
		pinfo = newPingInfo(host, ipaddr, size, timeout)
		p.pinghostinfo[key] = pinfo
		p.pingidinfo[pinfo.id] = pinfo
	}
	seq := pinfo.lastseq
	pinfo.lastseq++
	// Buffered so processEchoReply never blocks, even if nobody is waiting any more.
	recvpkt := make(chan *Packet, 1)
	pinfo.seqpkt[seq] = recvpkt
	p.mutex.Unlock()

	// forget drops the in-flight entry; harmless if the reply path already removed it.
	forget := func() {
		p.mutex.Lock()
		delete(pinfo.seqpkt, seq)
		p.mutex.Unlock()
	}

	msgBytes, err := icmp.BuildEchoRequestMessage(pinfo.id, seq, size, p.ICMPRequestType())
	if err != nil {
		forget()
		return err
	}
	if err := p.sendICMP(msgBytes, ipaddr); err != nil {
		forget()
		return err
	}

	if onSend != nil {
		onSend(&Packet{
			Nbytes: len(msgBytes),
			Host:   host,
			IPAddr: ipaddr,
			Seq:    seq,
			ID:     pinfo.id,
		})
	}

	go func() {
		t := time.NewTimer(timeout)
		defer t.Stop()

		select {
		case inpkt := <-recvpkt:
			if onRecv != nil {
				onRecv(inpkt)
			}
		case <-t.C:
			forget()
		case <-p.done:
			forget()
		}
	}()
	return nil
}

// sendICMP writes msgBytes to addr, retrying for as long as the write fails
// with ENOBUFS. It returns nil without sending if the connection has been
// closed, and returns any other write error unchanged.
func (p *mpingconn) sendICMP(msgBytes []byte, addr *net.IPAddr) error {
	var dst net.Addr = addr
	if p.protocol == "udp" {
		dst = &net.UDPAddr{IP: addr.IP, Zone: addr.Zone}
	}

	for {
		select {
		case <-p.done:
			return nil
		default:
		}

		_, err := p.Conn.WriteTo(msgBytes, dst)
		if err == nil {
			return nil
		}
		// errors.Is unwraps the error chain, so ENOBUFS is recognised even when
		// it is wrapped (for example in *net.OpError around *os.SyscallError);
		// a direct comparison against OpError.Err misses the wrapped form.
		if !errors.Is(err, syscall.ENOBUFS) {
			return err
		}
		if p.OnError != nil {
			p.OnError(syscall.ENOBUFS)
		} else {
			// Message text: "insufficient buffer space, send failed, retrying".
			fmt.Println("缓存不够,发送失败,重发")
		}
	}
}

// recvICMP reads datagrams from the connection until it is closed, queueing a
// private copy of each on recvchan. Read timeouts are retried. It returns nil
// on shutdown (including the read error caused by Close) and the read error
// otherwise.
func (p *mpingconn) recvICMP() error {
	buf := make([]byte, 65536)
	for {
		select {
		case <-p.done:
			return nil
		default:
		}

		n, ttl, addr, err := p.Conn.ReadFrom(buf)
		if err != nil {
			var neterr net.Error
			if errors.As(err, &neterr) && neterr.Timeout() {
				// Read timeout.
				continue
			}
			select {
			case <-p.done:
				// Close shut the socket under us; not worth reporting.
				return nil
			default:
			}
			return err
		}

		// buf is reused by the next read, so hand off a copy.
		payload := make([]byte, n)
		copy(payload, buf[:n])

		select {
		case <-p.done:
			return nil
		case p.recvchan <- &recvPkt{recvtime: time.Now(), addr: addr, bytes: payload, nbytes: n, ttl: ttl}:
		}
	}
}

// processRecvPacket drains recvchan and parses each packet, reporting parse
// errors through reportError. It returns when the connection is closed; any
// packets still queued at that point are dropped.
func (p *mpingconn) processRecvPacket() {
	for {
		select {
		case <-p.done:
			return
		case pkt, ok := <-p.recvchan:
			if !ok {
				return
			}
			// Warn once the queue is more than 90% full.
			if len(p.recvchan) > cap(p.recvchan)*9/10 {
				fmt.Printf("receive buffer full")
			}
			if err := p.processPacket(pkt); err != nil {
				p.reportError(err)
			}
		}
	}
}

// count is only mentioned in a commented-out debug print in this file.
var count = 0

// processPacket parses one raw packet and forwards echo replies to
// processEchoReply. Messages of any other ICMP type are ignored.
func (p *mpingconn) processPacket(recv *recvPkt) error {
	proto := protocolIPv6ICMP
	if p.ipv4 {
		proto = protocolICMP
	}

	// fmt.Println(count, "from", recv.addr.String(), "bytes", recv.bytes)
	m, err := icmp.ParseMessage(proto, recv.bytes)
	if err != nil {
		return fmt.Errorf("error parsing icmp message: %w", err)
	}

	if m.Type != ipv4.ICMPTypeEchoReply && m.Type != ipv6.ICMPTypeEchoReply {
		// Not an echo reply, ignore it.
		return nil
	}

	switch pkt := m.Body.(type) {
	case *icmp.Echo:
		return p.processEchoReply(pkt, recv)
	default:
		// Very bad, not sure how this can happen.
		return fmt.Errorf("invalid ICMP echo reply; type: '%T', '%v'", pkt, pkt)
	}
}

// processEchoReply matches an echo reply against the in-flight requests and
// delivers a Packet to the waiting ping goroutine.
//
// The first 24 payload bytes are read as three big-endian 64-bit values: send
// time in Unix nanoseconds, full sequence number and full id. Replies that are
// too short, whose low 16 bits of id/seq disagree with the ICMP header, or
// that match no in-flight request are dropped silently.
func (p *mpingconn) processEchoReply(pkt *icmp.Echo, recv *recvPkt) error {
	if len(pkt.Data) < 24 {
		return nil
	}

	sendtime := int64(binary.BigEndian.Uint64(pkt.Data[:8]))
	fullseq := int(binary.BigEndian.Uint64(pkt.Data[8:16]))
	fullid := int(binary.BigEndian.Uint64(pkt.Data[16:24]))

	if fullid%65536 != pkt.ID || fullseq%65536 != pkt.Seq {
		return nil
	}

	// Look up the target and claim the reply channel under one lock. Removing
	// the entry here means a duplicate reply finds nothing to deliver to.
	p.mutex.Lock()
	pinfo := p.pingidinfo[fullid]
	var waiter chan *Packet
	if pinfo != nil {
		waiter = pinfo.seqpkt[fullseq]
		delete(pinfo.seqpkt, fullseq)
	}
	p.mutex.Unlock()

	if waiter == nil {
		return nil
	}

	// fmt.Printf("%s %d bytes from %s: icmp_seq=%d time=%v\n",
	// 	time.Now().Format("15:04:05.000"), inPkt.Nbytes, inPkt.IPAddr, inPkt.Seq, inPkt.Rtt)
	waiter <- &Packet{
		Host:   pinfo.host,
		IPAddr: pinfo.ipaddr,
		ID:     pinfo.id,
		Seq:    fullseq,
		Nbytes: recv.nbytes,
		TTL:    recv.ttl,
		Rtt:    recv.recvtime.Sub(time.Unix(0, sendtime)),
	}
	return nil
}