package starnet

import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"net"
	"os"
	"strings"
	"sync/atomic"
	"time"
)

const (
	icmpTypeEchoReplyV4   = 0
	icmpTypeEchoRequestV4 = 8
	icmpTypeEchoRequestV6 = 128
	icmpTypeEchoReplyV6   = 129

	icmpHeaderLen = 8    // type(1) + code(1) + checksum(2) + identifier(2) + sequence(2)
	icmpReadBufSz = 1500 // minimum receive buffer; grown in sendICMPRequest for large payloads

	defaultPingAttemptTimeout = 2 * time.Second
	defaultPingableCount      = 3
	maxPingPayloadSize        = 65499 // 65507 - ICMP header(8)
)

// ICMP is the 8-byte header of an ICMP echo request or echo reply message.
type ICMP struct {
	Type        uint8
	Code        uint8
	CheckSum    uint16
	Identifier  uint16
	SequenceNum uint16
}

// pingSocketSpec describes how to open a raw ICMP socket for one address family.
type pingSocketSpec struct {
	network     string // net.DialIP network, e.g. "ip4:icmp" or "ip6:ipv6-icmp"
	family      int    // 4 or 6
	requestType uint8  // ICMP type of the echo request to send
	replyType   uint8  // ICMP type of the expected echo reply
}

// PingOptions controls ping probing behavior.
//
// For Pingable, a zero Count selects 3 attempts and a zero Timeout selects 2s.
// When PreferIPv4 and PreferIPv6 are both set or both unset, the resolver's
// order of targets is kept.
type PingOptions struct {
	Count       int           // ping attempts for Pingable, default 3
	Timeout     time.Duration // per-attempt timeout, default 2s
	Interval    time.Duration // delay between attempts, default 0
	Deadline    time.Time     // overall deadline for Pingable (PingWithContext callers use ctx instead)
	PreferIPv4  bool          // prefer IPv4 targets
	PreferIPv6  bool          // prefer IPv6 targets
	SourceIP    net.IP        // optional source IP for raw socket bind
	PayloadSize int           // ICMP payload bytes, default 0
}

// PingResult describes a successful echo exchange.
type PingResult struct {
	Duration  time.Duration // time from the request being written to the matching reply being read
	RecvCount int           // bytes returned by the socket read that held the matching reply (may include an IP header)
	RemoteIP  string        // target address the request was sent to
}

// pingIdentifierSeed is bumped atomically so successive pings in this process
// use differing ICMP identifiers (the derived identifier wraps at 16 bits).
var pingIdentifierSeed uint32

// nextPingIdentifier returns a 16-bit ICMP identifier derived from the process
// ID plus a per-process counter, used to match replies to this request.
func nextPingIdentifier() uint16 {
	pid := uint32(os.Getpid() & 0xffff)
	n := atomic.AddUint32(&pingIdentifierSeed, 1)
	return uint16((pid + n) & 0xffff)
}

// pingPayload returns a size-byte payload filled with the repeating byte
// pattern 0x00, 0x01, ..., 0xff, or nil when size <= 0.
func pingPayload(size int) []byte {
	if size <= 0 {
		return nil
	}
	payload := make([]byte, size)
	for i := range payload {
		payload[i] = byte(i)
	}
	return payload
}

// getICMP builds an echo header of type typ whose checksum is computed over the
// header and payload.
//
// NOTE(review): for ICMPv6 the on-wire checksum also covers an IPv6
// pseudo-header that is not included here; this presumably relies on the
// kernel filling in the ICMPv6 checksum for raw sockets — confirm on every
// supported platform.
func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
	icmp := ICMP{
		Type:        typ,
		Identifier:  identifier,
		SequenceNum: seq,
	}
	// The checksum must be computed while the CheckSum field is still zero.
	icmp.CheckSum = checkSum(marshalICMPPacket(icmp, payload))
	return icmp
}

// sendICMPRequest sends one echo request to destAddr over a raw socket and
// waits for the matching echo reply (reply type, identifier and sequence).
//
// The wait ends at whichever comes first: timeout after the request was
// written, ctx's deadline, or ctx cancellation. Unrelated ICMP messages read
// in the meantime are skipped. Errors from ctx are wrapped with
// "ping context done"; dial errors go through normalizePingDialError.
func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) {
	var res PingResult
	if ctx == nil {
		ctx = context.Background()
	}
	if err := ctx.Err(); err != nil {
		return res, wrapError(err, "ping context done")
	}
	if destAddr == nil || destAddr.IP == nil {
		return res, fmt.Errorf("destination ip is nil")
	}
	res.RemoteIP = destAddr.String()

	localAddr, err := localIPAddrForFamily(sourceIP, spec.family)
	if err != nil {
		return res, err
	}
	conn, err := net.DialIP(spec.network, localAddr, destAddr)
	if err != nil {
		return res, normalizePingDialError(err)
	}
	defer conn.Close()

	packet := marshalICMPPacket(icmp, payload)
	if _, err := conn.Write(packet); err != nil {
		return res, wrapError(err, "ping write request")
	}

	tStart := time.Now()
	deadline := tStart.Add(timeout)
	if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
		deadline = d
	}
	if err := conn.SetReadDeadline(deadline); err != nil {
		return res, wrapError(err, "ping set read deadline")
	}

	// Unblock a pending Read as soon as ctx is canceled. A context that can
	// never be canceled (Done() == nil) needs no watcher goroutine.
	if ctxDone := ctx.Done(); ctxDone != nil {
		stop := make(chan struct{})
		defer close(stop)
		go func() {
			select {
			case <-ctxDone:
				_ = conn.SetReadDeadline(time.Now())
			case <-stop:
			}
		}()
	}

	// The reply echoes the payload and may be preceded by an IP header, so a
	// fixed MTU-sized buffer would truncate replies to large payloads.
	const maxIPHeaderLen = 60 // largest IPv4 header (IHL=15); also covers IPv6's fixed 40
	bufSize := maxIPHeaderLen + icmpHeaderLen + len(payload)
	if bufSize < icmpReadBufSz {
		bufSize = icmpReadBufSz
	}
	recv := make([]byte, bufSize)
	for {
		n, err := conn.Read(recv)
		if err != nil {
			if ctxErr := ctx.Err(); ctxErr != nil {
				return res, wrapError(ctxErr, "ping context done")
			}
			return res, wrapError(err, "ping read reply")
		}
		if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) {
			res.RecvCount = n
			res.Duration = time.Since(tStart)
			return res, nil
		}
		// Not our reply (different type, identifier or sequence): keep
		// reading until the deadline.
	}
}

// checkSum computes the Internet checksum: the one's complement of the
// one's-complement sum of big-endian 16-bit words, with an odd trailing byte
// padded by a zero low byte.
func checkSum(data []byte) uint16 {
	var sum uint32
	i := 0
	for ; i+1 < len(data); i += 2 {
		sum += uint32(data[i])<<8 | uint32(data[i+1])
	}
	if i < len(data) {
		sum += uint32(data[i]) << 8
	}
	// Fold carries back into the low 16 bits.
	for sum>>16 != 0 {
		sum = sum&0xffff + sum>>16
	}
	return uint16(^sum)
}

// marshalICMP serializes just the 8-byte header.
func marshalICMP(icmp ICMP) []byte {
	return marshalICMPPacket(icmp, nil)
}

// marshalICMPPacket serializes the header (multi-byte fields big-endian)
// followed by payload.
func marshalICMPPacket(icmp ICMP, payload []byte) []byte {
	buf := make([]byte, icmpHeaderLen+len(payload))
	buf[0] = icmp.Type
	buf[1] = icmp.Code
	binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum)
	binary.BigEndian.PutUint16(buf[4:], icmp.Identifier)
	binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum)
	copy(buf[icmpHeaderLen:], payload)
	return buf
}

// isExpectedEchoReply reports whether packet holds an ICMP message of
// expectedType with code 0 and the given identifier and sequence number. The
// header is looked for at every offset from candidateICMPOffsets because a raw
// socket read may or may not include the IP header.
func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
	for _, off := range candidateICMPOffsets(packet, family) {
		if off < 0 || off+icmpHeaderLen > len(packet) {
			continue
		}
		hdr := packet[off : off+icmpHeaderLen]
		if hdr[0] == expectedType && hdr[1] == 0 &&
			binary.BigEndian.Uint16(hdr[4:6]) == identifier &&
			binary.BigEndian.Uint16(hdr[6:8]) == seq {
			return true
		}
	}
	return false
}

// candidateICMPOffsets returns the deduplicated offsets at which an ICMP header
// might start in packet. These are 0 (no IP header), the IHL-derived length of
// a leading IPv4 header, 40 for a leading IPv6 header, and a fixed 20/40
// fallback for the socket's family.
func candidateICMPOffsets(packet []byte, family int) []int {
	offsets := []int{0}
	if len(packet) == 0 {
		return offsets
	}
	switch ver := packet[0] >> 4; {
	case ver == 4 && len(packet) >= 20:
		if ihl := int(packet[0]&0x0f) * 4; ihl >= 20 && ihl <= len(packet)-icmpHeaderLen {
			offsets = append(offsets, ihl)
		}
	case ver == 6 && len(packet) >= 40+icmpHeaderLen:
		offsets = append(offsets, 40)
	}
	// Some platforms/kernels may hand back replies with leading headers still
	// attached; as a conservative fallback also try the common fixed offsets.
	if family == 4 && len(packet) >= 20+icmpHeaderLen {
		offsets = append(offsets, 20)
	}
	if family == 6 && len(packet) >= 40+icmpHeaderLen {
		offsets = append(offsets, 40)
	}
	return dedupOffsets(offsets)
}

// dedupOffsets removes duplicate values, keeping first-occurrence order.
func dedupOffsets(offsets []int) []int {
	if len(offsets) <= 1 {
		return offsets
	}
	seen := make(map[int]struct{}, len(offsets))
	out := make([]int, 0, len(offsets))
	for _, off := range offsets {
		if _, dup := seen[off]; dup {
			continue
		}
		seen[off] = struct{}{}
		out = append(out, off)
	}
	return out
}

// socketSpecForIP returns the raw socket parameters for ip's address family.
// Addresses with a 4-byte form (To4 != nil) use ICMPv4; other 16-byte
// addresses use ICMPv6. Nil or malformed IPs yield an ErrInvalidIP error.
func socketSpecForIP(ip net.IP) (pingSocketSpec, error) {
	if ip == nil {
		return pingSocketSpec{}, wrapError(ErrInvalidIP, "ip is nil")
	}
	if ip4 := ip.To4(); ip4 != nil {
		return pingSocketSpec{
			network:     "ip4:icmp",
			family:      4,
			requestType: icmpTypeEchoRequestV4,
			replyType:   icmpTypeEchoReplyV4,
		}, nil
	}
	if ip16 := ip.To16(); ip16 != nil {
		return pingSocketSpec{
			network:     "ip6:ipv6-icmp",
			family:      6,
			requestType: icmpTypeEchoRequestV6,
			replyType:   icmpTypeEchoReplyV6,
		}, nil
	}
	return pingSocketSpec{}, wrapError(ErrInvalidIP, "invalid ip: %q", ip.String())
}

// localIPAddrForFamily returns the local address to bind for DialIP. It
// returns nil when sourceIP is nil so the kernel picks one. It returns an
// ErrInvalidIP error when sourceIP is malformed or its family differs from the
// target's.
func localIPAddrForFamily(sourceIP net.IP, family int) (*net.IPAddr, error) {
	if sourceIP == nil {
		return nil, nil
	}
	if sourceIP.To16() == nil {
		return nil, wrapError(ErrInvalidIP, "invalid source ip: %q", sourceIP.String())
	}
	if family == 4 && sourceIP.To4() == nil {
		return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv4 target")
	}
	if family == 6 && sourceIP.To4() != nil {
		return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv6 target")
	}
	return &net.IPAddr{IP: sourceIP}, nil
}

// resolvePingTargets turns host into at most one IPv4 and one IPv6 target.
// An IP literal is used as-is; otherwise each family is resolved separately
// and the results are ordered by orderPingTargets. When neither family
// resolves, the IPv4 error (else the IPv6 error, else ErrPingNoResolvedTarget)
// is returned.
//
// NOTE(review): net.ResolveIPAddr takes no context, so DNS resolution is not
// bounded by the caller's ctx or PingOptions.Deadline.
func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
	if parsed := net.ParseIP(host); parsed != nil {
		return []*net.IPAddr{{IP: parsed}}, nil
	}
	var targets []*net.IPAddr
	var err4, err6 error
	if ip4, e := net.ResolveIPAddr("ip4", host); e == nil && ip4 != nil && ip4.IP != nil {
		targets = append(targets, ip4)
	} else {
		err4 = e
	}
	if ip6, e := net.ResolveIPAddr("ip6", host); e == nil && ip6 != nil && ip6.IP != nil {
		targets = append(targets, ip6)
	} else {
		err6 = e
	}
	if len(targets) > 0 {
		return orderPingTargets(targets, preferIPv4, preferIPv6), nil
	}
	if err4 != nil {
		return nil, err4
	}
	if err6 != nil {
		return nil, err6
	}
	return nil, ErrPingNoResolvedTarget
}

// orderPingTargets stably moves the preferred family to the front.
//
// targets is returned untouched when it has at most one entry or when the
// preference is ambiguous (both or neither flag set). Otherwise nil entries
// are dropped.
func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
	if len(targets) <= 1 || preferIPv4 == preferIPv6 {
		return targets
	}
	ordered := make([]*net.IPAddr, 0, len(targets))
	// First pass collects the preferred family, second pass the other one.
	for _, wantPreferred := range []bool{true, false} {
		for _, t := range targets {
			if t == nil || t.IP == nil {
				continue
			}
			isV4 := t.IP.To4() != nil
			if (isV4 == preferIPv4) == wantPreferred {
				ordered = append(ordered, t)
			}
		}
	}
	return ordered
}

// normalizePingDialError maps DialIP failures to package sentinels.
//
// Permission failures map to ErrPingPermissionDenied, and unsupported
// network/protocol/family failures map to ErrPingProtocolUnsupported. Anything
// else is wrapped with "ping dial". The sentinels are attached with %w, so
// errors.Is works on the result.
func normalizePingDialError(err error) error {
	if err == nil {
		return nil
	}
	msg := strings.ToLower(err.Error())
	if errors.Is(err, os.ErrPermission) ||
		strings.Contains(msg, "operation not permitted") ||
		strings.Contains(msg, "permission denied") {
		return fmt.Errorf("%w: %v", ErrPingPermissionDenied, err)
	}
	if strings.Contains(msg, "unknown network") ||
		strings.Contains(msg, "protocol not available") ||
		strings.Contains(msg, "address family not supported by protocol") ||
		strings.Contains(msg, "socket type not supported") {
		return fmt.Errorf("%w: %v", ErrPingProtocolUnsupported, err)
	}
	return wrapError(err, "ping dial")
}

// normalizePingOptions copies opts (nil means all defaults), substitutes
// defaultCount / defaultTimeout for a zero Count / Timeout, and validates the
// result.
//
// Validation fails for a negative Count, a non-positive Timeout, a negative
// Interval, a PayloadSize outside [0, maxPingPayloadSize], or a malformed
// SourceIP.
func normalizePingOptions(opts *PingOptions, defaultCount int, defaultTimeout time.Duration) (PingOptions, error) {
	out := PingOptions{
		Count:   defaultCount,
		Timeout: defaultTimeout,
	}
	if opts != nil {
		out = *opts
		if out.Count == 0 {
			out.Count = defaultCount
		}
		if out.Timeout == 0 {
			out.Timeout = defaultTimeout
		}
	}
	if out.Count < 0 {
		return out, fmt.Errorf("ping count must be >= 0")
	}
	if out.Timeout <= 0 {
		return out, wrapError(ErrPingInvalidTimeout, "timeout must be > 0")
	}
	if out.Interval < 0 {
		return out, fmt.Errorf("ping interval must be >= 0")
	}
	if out.PayloadSize < 0 || out.PayloadSize > maxPingPayloadSize {
		return out, fmt.Errorf("ping payload size must be in [0,%d]", maxPingPayloadSize)
	}
	if out.SourceIP != nil && out.SourceIP.To16() == nil {
		return out, wrapError(ErrInvalidIP, "invalid source ip: %q", out.SourceIP.String())
	}
	return out, nil
}

// pingOnceWithOptions resolves host and sends one echo request (sequence seq,
// truncated to 16 bits) to each resolved target in order, returning the first
// success. A permission error aborts immediately; otherwise the last failure
// is returned wrapped.
func pingOnceWithOptions(ctx context.Context, host string, seq int, opts PingOptions) (PingResult, error) {
	var res PingResult
	if ctx == nil {
		ctx = context.Background()
	}
	if err := ctx.Err(); err != nil {
		return res, wrapError(err, "ping context done")
	}
	targets, err := resolvePingTargets(host, opts.PreferIPv4, opts.PreferIPv6)
	if err != nil {
		return res, wrapError(err, "resolve ping target")
	}
	payload := pingPayload(opts.PayloadSize)
	var lastErr error
	for _, target := range targets {
		spec, err := socketSpecForIP(target.IP)
		if err != nil {
			lastErr = err
			continue
		}
		icmp := getICMP(uint16(seq), nextPingIdentifier(), spec.requestType, payload)
		resp, err := sendICMPRequest(ctx, icmp, payload, target, opts.SourceIP, spec, opts.Timeout)
		if err == nil {
			return resp, nil
		}
		// Permission problems usually don't depend on the address family, so
		// trying the remaining targets is pointless.
		if errors.Is(err, ErrPingPermissionDenied) {
			return res, err
		}
		lastErr = err
	}
	if lastErr != nil {
		return res, wrapError(lastErr, "ping all resolved targets failed")
	}
	return res, ErrPingNoResolvedTarget
}

// PingWithContext sends one ICMP echo request with context cancel support.
//
// timeout bounds the wait for the reply (ctx's deadline wins if earlier) and
// must be > 0, otherwise an ErrPingInvalidTimeout error is returned. A nil
// ctx is treated as context.Background().
func PingWithContext(ctx context.Context, host string, seq int, timeout time.Duration) (PingResult, error) {
	opts, err := normalizePingOptions(&PingOptions{Count: 1, Timeout: timeout}, 1, timeout)
	if err != nil {
		return PingResult{}, err
	}
	return pingOnceWithOptions(ctx, host, seq, opts)
}

// Ping sends one ICMP echo request.
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
	return PingWithContext(context.Background(), ip, seq, timeout)
}

// Pingable checks host reachability with retry options.
func Pingable(host string, opts *PingOptions) (bool, error) { cfg, err := normalizePingOptions(opts, defaultPingableCount, defaultPingAttemptTimeout) if err != nil { return false, err } ctx := context.Background() if !cfg.Deadline.IsZero() { var cancel context.CancelFunc ctx, cancel = context.WithDeadline(ctx, cfg.Deadline) defer cancel() } var lastErr error for i := 0; i < cfg.Count; i++ { _, err := pingOnceWithOptions(ctx, host, 29+i, cfg) if err == nil { return true, nil } lastErr = err if errors.Is(err, ErrPingPermissionDenied) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { break } if i < cfg.Count-1 && cfg.Interval > 0 { timer := time.NewTimer(cfg.Interval) select { case <-ctx.Done(): timer.Stop() return false, wrapError(ctx.Err(), "pingable context done") case <-timer.C: } } } if lastErr == nil { lastErr = ErrPingNoResolvedTarget } return false, lastErr } // IsIpPingable keeps backward-compatible bool-only behavior. func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool { if retryLimit <= 0 { return false } ok, _ := Pingable(ip, &PingOptions{ Count: retryLimit, Timeout: timeout, }) return ok }