package probing

import (
	"bytes"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"trial/ping/probing/icmp"

	"git.wecise.com/wecise/common/logger"
	"github.com/google/uuid"
)

// connm guards conns.
var connm sync.Mutex

// conns caches the shared connections created by MPConn, keyed by protocol
// name followed by "4" or "6".
var conns = map[string]*mpingconn{}

// MPConn returns the shared connection for the given IP version and protocol,
// creating and caching it on first use. Repeated calls with the same arguments
// return the same *mpingconn.
func MPConn(ipv4 bool, protocol string) *mpingconn {
	key := protocol
	if ipv4 {
		key += "4"
	} else {
		key += "6"
	}

	connm.Lock()
	// Deferred so the lock is released even if newMPConn panics
	// (it calls uuid.Must, which panics on error).
	defer connm.Unlock()

	mp := conns[key]
	if mp == nil {
		mp = newMPConn(ipv4, protocol)
		conns[key] = mp
	}
	return mp
}

// mpingconn wraps an icmp.MPacketConn and routes received packets to the
// mpinfo registered under the packet's ID.
type mpingconn struct {
	*icmp.MPacketConn
	// mutex guards pingidinfo and the lastseq/seqpkt fields of every mpinfo
	// sent through this connection.
	mutex sync.Mutex
	// uuid is stamped on every outgoing packet; received packets carrying a
	// different UUID are ignored.
	uuid uuid.UUID
	// pingidinfo maps a ping id to the target registered by Ping.
	pingidinfo map[int]*mpinfo
}

// mpinfo holds the per-target state and callbacks for pinging one address.
type mpinfo struct {
	host   string      // not used in this file
	ipaddr *net.IPAddr // destination of outgoing packets
	id     int         // packet ID; also the key in mpingconn.pingidinfo
	// lastseq is the sequence number the next outgoing packet will use.
	lastseq int
	size    int           // copied into icmp.Packet.Nbytes on send
	timeout time.Duration // not used in this file
	// seqpkt holds packets handed to SendPacket that have not been answered
	// yet, keyed by sequence number.
	seqpkt map[int]*icmp.Packet

	OnSend      func(*icmp.Packet) // called synchronously by Ping after a successful send
	OnRecv      func(*icmp.Packet) // called in a new goroutine for the first reply to a sequence
	OnRecvDup   func(*icmp.Packet) // called in a new goroutine for a reply whose sequence is not in flight
	OnRecvError func(error)        // not invoked in this file
}

// pingid is the last id handed out by getPingInfo.
// NOTE(review): the counter is never wrapped; if icmp.Packet.ID is carried in
// a 16-bit field on the wire, ids above 65535 would not round-trip — confirm
// against the icmp package.
var pingid int32

// pinginfomutex guards pinginfo.
var pinginfomutex sync.Mutex

// pinginfo caches one mpinfo per address, keyed by ipaddr.String().
var pinginfo = map[string]*mpinfo{}

// getPingInfo returns the mpinfo for ipaddr, creating it with a fresh id and
// an empty in-flight map the first time the address is seen.
func getPingInfo(ipaddr *net.IPAddr) *mpinfo {
	key := ipaddr.String()

	pinginfomutex.Lock()
	defer pinginfomutex.Unlock()

	pinfo := pinginfo[key]
	if pinfo == nil {
		pinfo = &mpinfo{
			ipaddr: ipaddr,
			id:     int(atomic.AddInt32(&pingid, 1)),
			seqpkt: make(map[int]*icmp.Packet),
		}
		pinginfo[key] = pinfo
	}
	return pinfo
}

// newMPConn builds an mpingconn with a new UUID and wires its OnRecvPacket and
// OnRecvError methods into the underlying icmp.MPacketConn. It panics if a
// UUID cannot be generated (uuid.Must).
func newMPConn(ipv4 bool, protocol string) *mpingconn {
	mpconn := &mpingconn{
		MPacketConn: &icmp.MPacketConn{
			IPV4:     ipv4,
			Protocol: protocol,
			Source:   "",
			Backlog:  receive_buffer_count,
			TTL:      ping_ttl,
		},
		uuid:       uuid.Must(uuid.NewUUID()),
		pingidinfo: make(map[int]*mpinfo),
	}
	mpconn.MPacketConn.OnRecvPacket = mpconn.OnRecvPacket
	mpconn.MPacketConn.OnRecvError = mpconn.OnRecvError
	return mpconn
}

// Listen calls Listen on the underlying icmp.MPacketConn and returns its error.
func (p *mpingconn) Listen() error {
	return p.MPacketConn.Listen()
}

// Close calls Close on the underlying icmp.MPacketConn and returns its error.
func (p *mpingconn) Close() error {
	return p.MPacketConn.Close()
}

// OnRecvPacket dispatches a received packet to the callbacks of the mpinfo
// registered under its ID.
//
// Packets with an unknown ID or a UUID different from this connection's are
// dropped silently. The first packet for an in-flight sequence triggers
// OnRecv; any other packet triggers OnRecvDup. Callbacks run in their own
// goroutines, so they do not execute while p.mutex is held.
func (p *mpingconn) OnRecvPacket(recvpkt *icmp.Packet) {
	p.mutex.Lock()
	defer p.mutex.Unlock()

	pinfo := p.pingidinfo[recvpkt.ID]
	if pinfo == nil {
		return
	}
	if !bytes.Equal(p.uuid[:], recvpkt.UUID[:]) {
		return
	}

	if _, inflight := pinfo.seqpkt[recvpkt.Seq]; !inflight {
		if pinfo.OnRecvDup != nil {
			go pinfo.OnRecvDup(recvpkt)
		}
		return
	}

	// Remove it from the set of sequences we are waiting for, so a repeated
	// reply is reported as a duplicate instead of a second OnRecv.
	delete(pinfo.seqpkt, recvpkt.Seq)
	if pinfo.OnRecv != nil {
		go pinfo.OnRecv(recvpkt)
	}
}

// OnRecvError logs a receive error reported by the underlying connection.
func (p *mpingconn) OnRecvError(err error) {
	logger.Error(err)
}

// Ping sends one packet to pinfo.ipaddr using the next sequence number of
// pinfo.
//
// The target is registered under its id on first use, and the packet is
// recorded as in flight before it is sent, so that a reply arriving
// immediately can be matched. If SendPacket fails, the in-flight record is
// removed again and the error is returned; OnSend is only called after a
// successful send.
//
// NOTE(review): lastseq grows without wrapping; if icmp.Packet.Seq is carried
// in a 16-bit field on the wire, replies after 65535 pings would not match —
// confirm against the icmp package.
func (p *mpingconn) Ping(pinfo *mpinfo) error {
	p.mutex.Lock()
	if _, has := p.pingidinfo[pinfo.id]; !has {
		p.pingidinfo[pinfo.id] = pinfo
	}
	seq := pinfo.lastseq
	pinfo.lastseq++
	outpkt := &icmp.Packet{
		IPAddr: pinfo.ipaddr,
		ID:     pinfo.id,
		Seq:    seq,
		UUID:   p.uuid,
		Nbytes: pinfo.size,
	}
	pinfo.seqpkt[seq] = outpkt
	p.mutex.Unlock()

	if err := p.SendPacket(outpkt); err != nil {
		// The send failed, so no reply is expected for this sequence; drop
		// the record so it does not linger as in flight. The pointer check
		// avoids removing an entry that is no longer ours.
		p.mutex.Lock()
		if pinfo.seqpkt[seq] == outpkt {
			delete(pinfo.seqpkt, seq)
		}
		p.mutex.Unlock()
		return err
	}

	if pinfo.OnSend != nil {
		pinfo.OnSend(outpkt)
	}
	return nil
}