package probing

import (
	"bytes"
	"errors"
	"fmt"
	"net"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/google/uuid"
	"golang.org/x/net/icmp"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

// mpingconn wraps a single packetConn so that several ping targets can share
// one ICMP socket.
//
// mutex is held by listen and Close, and by ping while it updates the maps.
// done is closed by Close to stop the background goroutines started by
// listen. pingid maps a target address string to its ICMP echo ID, pingseq
// maps an echo ID to the next sequence number, and pingfcb maps an echo ID to
// the receive callback most recently passed to ping for that ID.
type mpingconn struct {
	mutex    sync.Mutex
	ipv4     bool
	protocol string
	Source   string
	Conn     packetConn
	done     chan interface{}
	pingid   map[string]int
	pingseq  map[int]int
	pingfcb  map[int]func(*Packet)
}

// mpconn is the package-level shared connection: IPv4 over "udp" with an
// empty source address.
var mpconn = &mpingconn{
	ipv4:     true,
	protocol: "udp",
	Source:   "",
	done:     make(chan interface{}),
	pingid:   make(map[string]int),
	pingseq:  make(map[int]int),
	pingfcb:  make(map[int]func(*Packet)),
}

// listen opens the shared ICMP socket on first use and starts the background
// receive goroutines. Every call returns p itself, so callers always go
// through the wrapper (and its Close) rather than the raw connection.
//
// It returns the error from icmp.ListenPacket if the socket cannot be opened.
func (p *mpingconn) listen() (packetConn, error) {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	if p.Conn != nil {
		return p, nil
	}

	var (
		conn packetConn
		err  error
	)
	if p.ipv4 {
		var c icmpv4Conn
		c.c, err = icmp.ListenPacket(ipv4Proto[p.protocol], p.Source)
		conn = &c
	} else {
		var c icmpV6Conn
		c.c, err = icmp.ListenPacket(ipv6Proto[p.protocol], p.Source)
		conn = &c
	}
	if err != nil {
		return nil, err
	}
	p.Conn = conn

	// The reader pushes raw packets into recv; the dispatcher drains it so
	// the reader never blocks on a full channel for long.
	recv := make(chan *packet, 5)
	go func() {
		// A non-timeout read error ends the reader; there is no caller left
		// to report it to at this point.
		_ = p.recvICMP(recv)
	}()
	go func() {
		for {
			select {
			case <-p.done:
				return
			case pkt := <-recv:
				// Best effort: one malformed or unrelated packet must not
				// stop the dispatch loop.
				_ = p.processPacket(pkt)
			}
		}
	}()

	return p, nil
}

// Close signals the background goroutines to stop and closes the underlying
// connection, if one was opened. Closing the done channel is guarded so that
// calling Close more than once does not panic on a double close.
func (p *mpingconn) Close() error {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	select {
	case <-p.done:
		// done is never sent on, so a successful receive means it is
		// already closed.
	default:
		close(p.done)
	}

	if p.Conn == nil {
		// listen never succeeded; nothing to close.
		return nil
	}
	return p.Conn.Close()
}

// ICMPRequestType returns the echo request type of the underlying connection.
func (p *mpingconn) ICMPRequestType() icmp.Type {
	return p.Conn.ICMPRequestType()
}

// ReadFrom delegates to the underlying connection.
func (p *mpingconn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
	return p.Conn.ReadFrom(b)
}

// WriteTo delegates to the underlying connection.
func (p *mpingconn) WriteTo(b []byte, dst net.Addr) (int, error) {
	return p.Conn.WriteTo(b, dst)
}

// SetReadDeadline delegates to the underlying connection.
func (p *mpingconn) SetReadDeadline(t time.Time) error {
	return p.Conn.SetReadDeadline(t)
}

// SetFlagTTL delegates to the underlying connection.
func (p *mpingconn) SetFlagTTL() error {
	return p.Conn.SetFlagTTL()
}

// SetTTL delegates to the underlying connection.
func (p *mpingconn) SetTTL(ttl int) {
	p.Conn.SetTTL(ttl)
}

// pingid is the counter behind newPingID.
var pingid int32

// newPingID atomically increments the package counter and returns the new
// value.
func newPingID() int {
	return int(atomic.AddInt32(&pingid, 1))
}

// ping sends one ICMP echo request to addr.
//
// The payload is the current timestamp followed by a fresh UUID, padded with
// 0x01 bytes up to size when size is larger than that prefix. The echo ID is
// allocated once per address string and reused; the sequence number
// increments per ID. onRecv is stored as the receive callback for that ID,
// and onSend (if non-nil) is invoked once the request has been written.
//
// It returns -1 and an error if the UUID cannot be generated or marshalled,
// the message cannot be marshalled, or the send fails. On success it returns
// 0 and nil.
//
// NOTE(review): timeout is currently unused and the returned duration is not
// a measured RTT; the original created and discarded a timer here. Waiting
// for the reply is presumably still to be implemented - confirm.
func (p *mpingconn) ping(addr *net.IPAddr, size int, timeout time.Duration, onSend func(*Packet), onRecv func(*Packet)) (time.Duration, error) {
	currentUUID, err := uuid.NewUUID()
	if err != nil {
		return -1, fmt.Errorf("NewUUID: %w", err)
	}
	uuidEncoded, err := currentUUID.MarshalBinary()
	if err != nil {
		return -1, fmt.Errorf("unable to marshal UUID binary: %w", err)
	}

	key := addr.String()
	p.mutex.Lock()
	pid, has := p.pingid[key]
	if !has {
		pid = newPingID()
		p.pingid[key] = pid
	}
	psq := p.pingseq[pid]
	p.pingseq[pid]++
	p.pingfcb[pid] = onRecv
	p.mutex.Unlock()

	t := append(timeToBytes(time.Now()), uuidEncoded...)
	if remainSize := size - len(t); remainSize > 0 {
		t = append(t, bytes.Repeat([]byte{1}, remainSize)...)
	}

	msg := &icmp.Message{
		Type: p.ICMPRequestType(),
		Code: 0,
		Body: &icmp.Echo{
			ID:   pid,
			Seq:  psq,
			Data: t,
		},
	}
	msgBytes, err := msg.Marshal(nil)
	if err != nil {
		return -1, err
	}

	if err := p.sendICMP(msgBytes, addr); err != nil {
		return -1, err
	}

	if onSend != nil {
		onSend(&Packet{
			Nbytes: len(msgBytes),
			IPAddr: addr,
			Seq:    psq,
			ID:     pid,
		})
	}
	return 0, nil
}

// recvICMP reads packets from the underlying connection and forwards a copy
// of each to recv until done is closed or a non-timeout read error occurs.
//
// It returns nil when stopped via done, otherwise the read error.
func (p *mpingconn) recvICMP(recv chan<- *packet) error {
	buf := make([]byte, 65536)
	for {
		select {
		case <-p.done:
			return nil
		default:
		}

		n, ttl, _, err := p.Conn.ReadFrom(buf)
		if err != nil {
			var opErr *net.OpError
			if errors.As(err, &opErr) && opErr.Timeout() {
				// Read timeout: go around again and re-check done.
				continue
			}
			return err
		}

		// buf is reused on the next read, so hand out a private copy.
		bs := make([]byte, n)
		copy(bs, buf[:n])

		select {
		case <-p.done:
			return nil
		case recv <- &packet{bytes: bs, nbytes: n, ttl: ttl}:
		}
	}
}

// sendICMP writes msgBytes to addr exactly once, retrying only while the
// kernel reports ENOBUFS. For the "udp" protocol the destination is converted
// to a *net.UDPAddr.
//
// It returns nil after a successful write, or without writing if done has
// been closed; any other write error is returned as is.
func (p *mpingconn) sendICMP(msgBytes []byte, addr *net.IPAddr) error {
	var dst net.Addr = addr
	if p.protocol == "udp" {
		dst = &net.UDPAddr{IP: addr.IP, Zone: addr.Zone}
	}

	for {
		select {
		case <-p.done:
			return nil
		default:
		}

		_, err := p.Conn.WriteTo(msgBytes, dst)
		if err == nil {
			return nil
		}
		// errors.Is sees through *net.OpError and *os.SyscallError wrappers.
		if errors.Is(err, syscall.ENOBUFS) {
			// Message reads: "buffer exhausted, send failed, resending".
			fmt.Println("缓存不够,发送失败,重发")
			continue
		}
		return err
	}
}

// processPacket parses a received packet as an ICMP message and hands echo
// replies to processEchoReply. Messages of any other type are ignored.
//
// It returns an error if the message cannot be parsed or if an echo reply
// does not carry an *icmp.Echo body.
//
// NOTE(review): p.ipaddr, p.addr and p.id are not fields of the mpingconn
// struct declared in this file; this looks carried over from a per-target
// pinger type and needs adapting before it can compile - confirm intent.
func (p *mpingconn) processPacket(recv *packet) error {
	receivedAt := time.Now()

	var proto int
	if p.ipv4 {
		proto = protocolICMP
	} else {
		proto = protocolIPv6ICMP
	}

	m, err := icmp.ParseMessage(proto, recv.bytes)
	if err != nil {
		return fmt.Errorf("error parsing icmp message: %w", err)
	}

	if m.Type != ipv4.ICMPTypeEchoReply && m.Type != ipv6.ICMPTypeEchoReply {
		// Not an echo reply, ignore it.
		return nil
	}

	inPkt := &Packet{
		Nbytes: recv.nbytes,
		IPAddr: p.ipaddr,
		Addr:   p.addr,
		TTL:    recv.ttl,
		ID:     p.id,
	}

	switch pkt := m.Body.(type) {
	case *icmp.Echo:
		return p.processEchoReply(pkt, receivedAt, inPkt)
	default:
		// Very bad, not sure how this can happen.
		return fmt.Errorf("invalid ICMP echo reply; type: '%T', '%v'", pkt, pkt)
	}
}

// processEchoReply validates an echo reply, fills in the RTT and sequence on
// inPkt, and reports it either as a duplicate or as a fresh reply.
//
// Replies rejected by matchID are dropped silently. It returns an error when
// the payload is shorter than timeSliceLength+trackerLength or when
// getPacketUUID fails.
//
// NOTE(review): awaitingSequences, PacketsRecvDuplicates, OnDuplicateRecv and
// OnRecv are not fields of the mpingconn struct declared in this file, and
// matchID, getPacketUUID and updateStatistics are not defined in the visible
// code - confirm where these are meant to live.
func (p *mpingconn) processEchoReply(pkt *icmp.Echo, receivedAt time.Time, inPkt *Packet) error {
	if !p.matchID(pkt.ID) {
		// Rejected by matchID; ignore the reply.
		return nil
	}

	if len(pkt.Data) < timeSliceLength+trackerLength {
		return fmt.Errorf("insufficient data received; got: %d %v",
			len(pkt.Data), pkt.Data)
	}

	pktUUID, err := p.getPacketUUID(pkt.Data)
	if err != nil || pktUUID == nil {
		return err
	}

	timestamp := bytesToTime(pkt.Data[:timeSliceLength])
	inPkt.Rtt = receivedAt.Sub(timestamp)
	inPkt.Seq = pkt.Seq

	// If we've already received this sequence, ignore it.
	if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight {
		p.PacketsRecvDuplicates++
		if p.OnDuplicateRecv != nil {
			p.OnDuplicateRecv(inPkt)
		}
		return nil
	}
	// Remove it from the sequences we are waiting for so a repeat of the
	// same reply is counted as a duplicate.
	delete(p.awaitingSequences[*pktUUID], pkt.Seq)

	p.updateStatistics(inPkt)
	if handler := p.OnRecv; handler != nil {
		handler(inPkt)
	}
	return nil
}