package probing import ( "bytes" "net" "sync" "sync/atomic" "time" "trial/ping/probing/icmp" "github.com/google/uuid" ) var connm sync.Mutex var conns = map[string]*mpingconn{} func MPConn(ipv4 bool, protocol string) *mpingconn { key := protocol if ipv4 { key += "4" } else { key += "6" } connm.Lock() mp := conns[key] if mp == nil { mp = newMPConn(ipv4, protocol) conns[key] = mp } connm.Unlock() return mp } type mpingconn struct { *icmp.MPacketConn mutex sync.Mutex uuid uuid.UUID pingidinfo map[int]*mpinfo } type mpinfo struct { host string ipaddr *net.IPAddr id int lastseq int size int timeout time.Duration seqpkt map[int]*icmp.Packet OnSend func(outpkt *icmp.Packet) OnRecv func(inpkt *icmp.Packet) OnRecvDup func(inpkt *icmp.Packet) OnRecvTimeout func(outpkt *icmp.Packet) } var pingid int32 var pinginfomutex sync.Mutex var pinginfo = map[string]*mpinfo{} func getPingInfo(ipaddr *net.IPAddr) *mpinfo { pinginfomutex.Lock() defer pinginfomutex.Unlock() pinfo := pinginfo[ipaddr.String()] if pinfo == nil { pinfo = &mpinfo{ ipaddr: ipaddr, id: int(atomic.AddInt32(&pingid, 1)), seqpkt: make(map[int]*icmp.Packet), } pinginfo[ipaddr.String()] = pinfo } return pinfo } func newMPConn(ipv4 bool, protocol string) *mpingconn { mpconn := &mpingconn{ MPacketConn: &icmp.MPacketConn{ IPV4: ipv4, Protocol: protocol, Source: "", Backlog: receive_buffer_count, TTL: ping_ttl, }, uuid: uuid.Must(uuid.NewUUID()), pingidinfo: make(map[int]*mpinfo), } mpconn.MPacketConn.OnRecvPacket = mpconn.OnRecvPacket mpconn.MPacketConn.OnError = mpconn.OnError return mpconn } func (p *mpingconn) Listen() error { err := p.MPacketConn.Listen() if err != nil { return err } return nil } func (p *mpingconn) Close() error { err := p.MPacketConn.Close() if err != nil { return err } return nil } func (p *mpingconn) OnRecvPacket(recvpkt *icmp.Packet) { // fmt.Println("recv", recvpkt) p.mutex.Lock() defer p.mutex.Unlock() pinfo := p.pingidinfo[recvpkt.ID] if pinfo == nil { return } if !bytes.Equal(p.uuid[:], recvpkt.UUID[:]) { return } outpkt, inflight := pinfo.seqpkt[recvpkt.Seq] if inflight { // remove it from the list of sequences we're waiting for so we don't get duplicates. if outpkt.TimeoutTimer != nil { outpkt.TimeoutTimer.Stop() } delete(pinfo.seqpkt, recvpkt.Seq) if pinfo.OnRecv != nil { go pinfo.OnRecv(recvpkt) } return } if pinfo.OnRecvDup != nil { go pinfo.OnRecvDup(recvpkt) } } func (p *mpingconn) OnError(err error) { logger.Error(err) } func (p *mpingconn) Ping(pinfo *mpinfo) error { p.mutex.Lock() if _, has := p.pingidinfo[pinfo.id]; !has { p.pingidinfo[pinfo.id] = pinfo } seq := pinfo.lastseq pinfo.lastseq++ outpkt := &icmp.Packet{ IPAddr: pinfo.ipaddr, ID: pinfo.id, Seq: seq, UUID: p.uuid, Nbytes: pinfo.size, } pinfo.seqpkt[seq] = outpkt p.mutex.Unlock() err := p.SendPacket(outpkt) if err != nil { return err } if pinfo.OnSend != nil { pinfo.OnSend(outpkt) } if pinfo.OnRecvTimeout != nil { outpkt.TimeoutTimer = time.AfterFunc(pinfo.timeout, func() { p.mutex.Lock() outpkt := pinfo.seqpkt[seq] p.mutex.Unlock() if outpkt != nil { pinfo.OnRecvTimeout(outpkt) } }) } // fmt.Println("sent", outpkt) return nil }