package net import ( "fmt" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" "net" "sync" "sync/atomic" "time" ) const ( readBufferSize = 1500 maxSeq = 0xffff ) type Pinger struct { icmpID int conn4 *icmp.PacketConn conn6 *icmp.PacketConn connOnce4 sync.Once connOnce6 sync.Once connErr4 error connErr6 error sequenceNum4 uint32 sequenceNum6 uint32 pending sync.Map closed chan struct{} } type pendingKey struct { proto string seq int } func NewPinger(icmpID int) *Pinger { return &Pinger{ icmpID: icmpID & 0xffff, closed: make(chan struct{}), } } func (p *Pinger) Close() error { close(p.closed) var err error if p.conn4 != nil { if e := p.conn4.Close(); e != nil { err = e } } if p.conn6 != nil { if e := p.conn6.Close(); e != nil { err = e } } return err } func (p *Pinger) Ping(ip string, timeout time.Duration) error { var conn *icmp.PacketConn var proto string var sequence *uint32 dest, err := net.ResolveIPAddr("ip", ip) if err != nil { return fmt.Errorf("resolve IP address error: %w", err) } if dest.IP.To4() != nil { // IPv4处理 p.connOnce4.Do(func() { p.conn4, p.connErr4 = icmp.ListenPacket("ip4:icmp", "0.0.0.0") if p.connErr4 == nil { go p.receiveLoop(p.conn4, "ip4") } }) if p.connErr4 != nil { return fmt.Errorf("ICMPv4 connection error: %w", p.connErr4) } conn = p.conn4 proto = "ip4" sequence = &p.sequenceNum4 } else { // IPv6处理 p.connOnce6.Do(func() { p.conn6, p.connErr6 = icmp.ListenPacket("ip6:ipv6-icmp", "::") if p.connErr6 == nil { go p.receiveLoop(p.conn6, "ip6") } }) if p.connErr6 != nil { return fmt.Errorf("ICMPv6 connection error: %w", p.connErr6) } conn = p.conn6 proto = "ip6" sequence = &p.sequenceNum6 } seq := int(atomic.AddUint32(sequence, 1) & maxSeq) key := pendingKey{proto, seq} resultChan := make(chan struct{}) p.pending.Store(key, resultChan) defer p.pending.Delete(key) var msgType icmp.Type if proto == "ip4" { msgType = ipv4.ICMPTypeEcho } else { msgType = ipv6.ICMPTypeEchoRequest } msg := &icmp.Message{ Type: msgType, Code: 0, Body: &icmp.Echo{ ID: p.icmpID, Seq: seq, Data: []byte("HELLO-PING"), }, } packet, err := msg.Marshal(nil) if err != nil { return fmt.Errorf("marshal error: %w", err) } if _, err := conn.WriteTo(packet, dest); err != nil { return fmt.Errorf("write error: %w", err) } select { case <-resultChan: return nil case <-time.After(timeout): return fmt.Errorf("timeout") case <-p.closed: return fmt.Errorf("pinger closed") } } func (p *Pinger) receiveLoop(conn *icmp.PacketConn, proto string) { buffer := make([]byte, readBufferSize) for { select { case <-p.closed: return default: n, _, err := conn.ReadFrom(buffer) if err != nil { continue } var expectedType icmp.Type var protocol int if proto == "ip4" { expectedType = ipv4.ICMPTypeEchoReply protocol = 1 // ICMPv4协议号 } else { expectedType = ipv6.ICMPTypeEchoReply protocol = 58 // ICMPv6协议号 } msg, err := icmp.ParseMessage(protocol, buffer[:n]) if err != nil { continue } if msg.Type != expectedType { continue } echo, ok := msg.Body.(*icmp.Echo) if !ok || echo.ID != p.icmpID { continue } key := pendingKey{proto, echo.Seq} if ch, exists := p.pending.LoadAndDelete(key); exists { close(ch.(chan struct{})) } } } }