starnet/ping.go
starainrt b5bd7595a1
1. 优化ping功能
2. 新增重试机制
3. 优化错误处理逻辑
2026-03-19 16:42:45 +08:00

543 lines
13 KiB
Go

package starnet
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"os"
"strings"
"sync/atomic"
"time"
)
// ICMP message type numbers, buffer sizes and ping defaults shared by the
// helpers in this file.
const (
// ICMPv4 echo reply / echo request message types.
icmpTypeEchoReplyV4 = 0
icmpTypeEchoRequestV4 = 8
// ICMPv6 echo request / echo reply message types.
icmpTypeEchoRequestV6 = 128
icmpTypeEchoReplyV6 = 129
// icmpHeaderLen is the size in bytes of the header written by marshalICMPPacket.
icmpHeaderLen = 8
// icmpReadBufSz is the baseline receive buffer size in bytes for reading replies.
icmpReadBufSz = 1500
// defaultPingAttemptTimeout is the per-attempt timeout Pingable uses when none is given.
defaultPingAttemptTimeout = 2 * time.Second
// defaultPingableCount is the number of attempts Pingable makes when none is given.
defaultPingableCount = 3
// maxPingPayloadSize is the largest PayloadSize accepted by normalizePingOptions.
maxPingPayloadSize = 65499 // 65507 - ICMP header(8)
)
// ICMP holds the header fields of an ICMP echo message. marshalICMPPacket
// serializes them in declaration order, with the 16-bit fields in big-endian
// byte order.
type ICMP struct {
Type uint8 // message type (echo request or echo reply)
Code uint8 // message code; getICMP always sets 0
CheckSum uint16 // checksum computed by checkSum over header and payload
Identifier uint16 // identifier used to match a reply to its request
SequenceNum uint16 // sequence number used to match a reply to its request
}
// pingSocketSpec describes how to ping one address family: which network
// string to dial and which ICMP message types to send and expect.
type pingSocketSpec struct {
network string // network name passed to net.DialIP, e.g. "ip4:icmp"
family int // 4 for IPv4, 6 for IPv6
requestType uint8 // ICMP type of the outgoing echo request
replyType uint8 // ICMP type expected in the echo reply
}
// PingOptions controls ping probing behavior.
//
// Zero values for Count and Timeout are replaced with caller-supplied defaults
// by normalizePingOptions; negative Count or Interval, a non-positive Timeout,
// an out-of-range PayloadSize and an invalid SourceIP are rejected there.
type PingOptions struct {
Count int // ping attempts for Pingable, default 3
Timeout time.Duration // per-attempt timeout, default 2s
Interval time.Duration // delay between attempts, default 0
Deadline time.Time // overall deadline applied by Pingable; zero means none
PreferIPv4 bool // try IPv4 targets first; ignored when equal to PreferIPv6
PreferIPv6 bool // try IPv6 targets first; ignored when equal to PreferIPv4
SourceIP net.IP // optional source IP for raw socket bind
PayloadSize int // ICMP payload bytes, default 0, at most maxPingPayloadSize
}
// PingResult reports the outcome of a single echo request.
type PingResult struct {
Duration time.Duration // time from just after the request was written until the matching reply was read
RecvCount int // number of bytes read for the matching reply packet (not a reply count)
RemoteIP string // string form of the destination address that was probed
}
// pingIdentifierSeed is a process-wide counter that is atomically incremented
// for every identifier handed out by nextPingIdentifier.
var pingIdentifierSeed uint32

// nextPingIdentifier returns the ICMP identifier for the next echo request.
// It is the low 16 bits of the process ID plus an atomically incremented
// counter, so consecutive calls yield consecutive values modulo 65536.
func nextPingIdentifier() uint16 {
	counter := atomic.AddUint32(&pingIdentifierSeed, 1)
	base := uint32(os.Getpid()) & 0xffff
	// Converting to uint16 keeps only the low 16 bits of the sum.
	return uint16(base + counter)
}
// pingPayload builds an ICMP payload of size bytes whose content is the
// repeating byte pattern 0, 1, ..., 255, 0, 1, ... It returns nil when size
// is zero or negative.
func pingPayload(size int) []byte {
	if size < 1 {
		return nil
	}
	data := make([]byte, 0, size)
	for n := 0; n < size; n++ {
		// byte(n) wraps every 256 values, producing the repeating pattern.
		data = append(data, byte(n))
	}
	return data
}
// getICMP builds an ICMP header of the given type with code 0 and fills in
// its checksum. The checksum is computed over the marshaled header (with the
// checksum field still zero) followed by payload.
func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
	header := ICMP{
		Type:        typ,
		Identifier:  identifier,
		SequenceNum: seq,
	}
	// Code and CheckSum are left at zero while the checksum is calculated.
	header.CheckSum = checkSum(marshalICMPPacket(header, payload))
	return header
}
// sendICMPRequest writes one ICMP echo request to destAddr and waits for the
// matching echo reply.
//
// The socket is opened with net.DialIP on spec.network and is bound to
// sourceIP when one is given. The wait ends at the earlier of timeout
// (measured from just after the request is written) and ctx's deadline;
// cancelling ctx interrupts a blocked read. Packets that do not carry the
// expected reply type, identifier and sequence number are ignored and reading
// continues until the deadline.
//
// On success the result holds the elapsed time since the write, the number of
// bytes read for the matching packet and the destination address. RemoteIP is
// filled in before dialing, so it is also set on errors returned after the
// destination check. A nil ctx is treated as context.Background().
func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) {
	var res PingResult
	if ctx == nil {
		ctx = context.Background()
	}
	if err := ctx.Err(); err != nil {
		return res, wrapError(err, "ping context done")
	}
	if destAddr == nil || destAddr.IP == nil {
		return res, fmt.Errorf("destination ip is nil")
	}
	res.RemoteIP = destAddr.String()

	localAddr, err := localIPAddrForFamily(sourceIP, spec.family)
	if err != nil {
		return res, err
	}
	conn, err := net.DialIP(spec.network, localAddr, destAddr)
	if err != nil {
		return res, normalizePingDialError(err)
	}
	defer conn.Close()

	packet := marshalICMPPacket(icmp, payload)
	if _, err := conn.Write(packet); err != nil {
		return res, wrapError(err, "ping write request")
	}

	tStart := time.Now()
	deadline := tStart.Add(timeout)
	if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
		deadline = d
	}
	if err := conn.SetReadDeadline(deadline); err != nil {
		return res, wrapError(err, "ping set read deadline")
	}

	// Watch for cancellation: moving the read deadline to "now" unblocks a
	// pending Read. The watcher exits when doneCh is closed on return.
	doneCh := make(chan struct{})
	go func() {
		select {
		case <-ctx.Done():
			// Best effort; the connection may already be closed.
			_ = conn.SetReadDeadline(time.Now())
		case <-doneCh:
		}
	}()
	defer close(doneCh)

	// An echo reply mirrors the request, so size the buffer for the whole
	// reply plus the largest IPv4 header that may precede it. A fixed
	// icmpReadBufSz buffer is too small for large payloads: the reply would be
	// cut short (or the read could fail, depending on the platform) and
	// RecvCount would be capped at the buffer size.
	const maxIPHeaderLen = 60
	bufSize := icmpReadBufSz
	if need := maxIPHeaderLen + len(packet); need > bufSize {
		bufSize = need
	}
	recv := make([]byte, bufSize)

	for {
		n, err := conn.Read(recv)
		if err != nil {
			// Prefer reporting the context error when it caused the failure.
			if ctx.Err() != nil {
				return res, wrapError(ctx.Err(), "ping context done")
			}
			return res, wrapError(err, "ping read reply")
		}
		if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) {
			res.RecvCount = n
			res.Duration = time.Since(tStart)
			return res, nil
		}
	}
}
// checkSum returns the 16-bit one's-complement checksum of data. The bytes are
// summed as big-endian 16-bit words, a trailing odd byte is treated as the
// high byte of a final word, the carries are folded back into the low 16 bits
// and the result is inverted.
func checkSum(data []byte) uint16 {
	var acc uint32
	pos := 0
	for ; pos+1 < len(data); pos += 2 {
		acc += uint32(data[pos])<<8 | uint32(data[pos+1])
	}
	if pos < len(data) {
		// Odd length: pad the last byte with a zero low byte.
		acc += uint32(data[pos]) << 8
	}
	// Fold any carry above 16 bits back into the sum.
	for acc > 0xffff {
		acc = (acc >> 16) + (acc & 0xffff)
	}
	return ^uint16(acc)
}
// marshalICMP serializes an ICMP header without any payload by delegating to
// marshalICMPPacket with a nil payload.
func marshalICMP(icmp ICMP) []byte {
return marshalICMPPacket(icmp, nil)
}
// marshalICMPPacket serializes icmp followed by payload into a new byte
// slice. The header occupies the first icmpHeaderLen bytes: type, code, then
// checksum, identifier and sequence number as big-endian 16-bit values.
func marshalICMPPacket(icmp ICMP, payload []byte) []byte {
	// Reserve room for the payload up front so the final append never reallocates.
	packet := make([]byte, icmpHeaderLen, icmpHeaderLen+len(payload))
	packet[0], packet[1] = icmp.Type, icmp.Code
	binary.BigEndian.PutUint16(packet[2:4], icmp.CheckSum)
	binary.BigEndian.PutUint16(packet[4:6], icmp.Identifier)
	binary.BigEndian.PutUint16(packet[6:8], icmp.SequenceNum)
	return append(packet, payload...)
}
// isExpectedEchoReply reports whether packet contains an ICMP header with
// type expectedType, code 0 and the given identifier and sequence number at
// any of the offsets proposed by candidateICMPOffsets.
func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
	for _, start := range candidateICMPOffsets(packet, family) {
		end := start + icmpHeaderLen
		if start < 0 || end > len(packet) {
			// Not enough bytes for a full header at this offset.
			continue
		}
		hdr := packet[start:end]
		matches := hdr[0] == expectedType &&
			hdr[1] == 0 &&
			binary.BigEndian.Uint16(hdr[4:6]) == identifier &&
			binary.BigEndian.Uint16(hdr[6:8]) == seq
		if matches {
			return true
		}
	}
	return false
}
// candidateICMPOffsets lists the byte offsets within packet at which an ICMP
// header may start. Offset 0 is always included. If the first nibble looks
// like an IPv4 version, the header length from the IHL field is added; if it
// looks like IPv6, offset 40 is added. The fixed offset for the socket family
// (20 for IPv4, 40 for IPv6) is added as well when the packet is long enough.
// Duplicates are removed while keeping the first-seen order.
func candidateICMPOffsets(packet []byte, family int) []int {
	const (
		ipv4MinHeaderLen = 20
		ipv6HeaderLen    = 40
	)
	size := len(packet)
	candidates := []int{0}
	if size == 0 {
		return candidates
	}

	switch version := packet[0] >> 4; {
	case version == 4 && size >= ipv4MinHeaderLen:
		headerLen := 4 * int(packet[0]&0x0f)
		if headerLen >= ipv4MinHeaderLen && headerLen+icmpHeaderLen <= size {
			candidates = append(candidates, headerLen)
		}
	case version == 6 && size >= ipv6HeaderLen+icmpHeaderLen:
		candidates = append(candidates, ipv6HeaderLen)
	}

	// Per the original author's note: some platforms/kernels may hand back
	// replies with a header in front of the ICMP message, so the common fixed
	// offsets are conservatively tried as well.
	switch {
	case family == 4 && size >= ipv4MinHeaderLen+icmpHeaderLen:
		candidates = append(candidates, ipv4MinHeaderLen)
	case family == 6 && size >= ipv6HeaderLen+icmpHeaderLen:
		candidates = append(candidates, ipv6HeaderLen)
	}
	return dedupOffsets(candidates)
}
// dedupOffsets removes repeated values from offsets while preserving the
// order of first occurrence. Slices with fewer than two elements are returned
// unchanged; otherwise a new slice is returned.
func dedupOffsets(offsets []int) []int {
	if len(offsets) < 2 {
		return offsets
	}
	seen := make(map[int]bool, len(offsets))
	unique := make([]int, 0, len(offsets))
	for i := 0; i < len(offsets); i++ {
		v := offsets[i]
		if !seen[v] {
			seen[v] = true
			unique = append(unique, v)
		}
	}
	return unique
}
// socketSpecForIP selects the socket network and ICMP echo message types for
// ip. Addresses convertible to 4 bytes use ICMPv4 over "ip4:icmp"; other valid
// 16-byte addresses use ICMPv6 over "ip6:ipv6-icmp". A nil or malformed ip
// yields an error wrapping ErrInvalidIP.
func socketSpecForIP(ip net.IP) (pingSocketSpec, error) {
	var none pingSocketSpec
	switch {
	case ip == nil:
		return none, wrapError(ErrInvalidIP, "ip is nil")
	case ip.To4() != nil:
		spec := pingSocketSpec{
			network:     "ip4:icmp",
			family:      4,
			requestType: icmpTypeEchoRequestV4,
			replyType:   icmpTypeEchoReplyV4,
		}
		return spec, nil
	case ip.To16() != nil:
		spec := pingSocketSpec{
			network:     "ip6:ipv6-icmp",
			family:      6,
			requestType: icmpTypeEchoRequestV6,
			replyType:   icmpTypeEchoReplyV6,
		}
		return spec, nil
	default:
		return none, wrapError(ErrInvalidIP, "invalid ip: %q", ip.String())
	}
}
// localIPAddrForFamily converts the optional sourceIP into the local address
// used when dialing. A nil sourceIP returns (nil, nil), meaning no explicit
// bind address. A malformed address, or one whose family differs from family
// (4 or 6), yields an error wrapping ErrInvalidIP.
func localIPAddrForFamily(sourceIP net.IP, family int) (*net.IPAddr, error) {
	if sourceIP == nil {
		return nil, nil
	}
	if sourceIP.To16() == nil {
		return nil, wrapError(ErrInvalidIP, "invalid source ip: %q", sourceIP.String())
	}
	sourceIsV4 := sourceIP.To4() != nil
	switch {
	case family == 4 && !sourceIsV4:
		return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv4 target")
	case family == 6 && sourceIsV4:
		return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv6 target")
	}
	return &net.IPAddr{IP: sourceIP}, nil
}
// resolvePingTargets turns host into the list of addresses to probe. An IP
// literal is returned as the only target without any lookup. Otherwise host is
// resolved once for "ip4" and once for "ip6"; the successful results are
// ordered by orderPingTargets according to the preference flags. If neither
// lookup produces an address, the IPv4 lookup error is returned when present,
// else the IPv6 one, else ErrPingNoResolvedTarget.
func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
	if literal := net.ParseIP(host); literal != nil {
		return []*net.IPAddr{{IP: literal}}, nil
	}

	var (
		resolved []*net.IPAddr
		firstErr error
	)
	for _, network := range [...]string{"ip4", "ip6"} {
		addr, err := net.ResolveIPAddr(network, host)
		if err == nil && addr != nil && addr.IP != nil {
			resolved = append(resolved, addr)
			continue
		}
		// Remember the earliest lookup error so IPv4 takes precedence.
		if firstErr == nil {
			firstErr = err
		}
	}

	if len(resolved) > 0 {
		return orderPingTargets(resolved, preferIPv4, preferIPv6), nil
	}
	if firstErr != nil {
		return nil, firstErr
	}
	return nil, ErrPingNoResolvedTarget
}
// orderPingTargets reorders targets so that the preferred address family
// comes first, keeping the relative order within each family. When there is
// at most one target, or both preference flags have the same value, targets
// is returned as is. Otherwise a new slice is returned and entries that are
// nil or have a nil IP are dropped.
func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
	if preferIPv4 == preferIPv6 || len(targets) <= 1 {
		return targets
	}

	// Partition the usable targets by family in a single pass.
	var v4, other []*net.IPAddr
	for _, addr := range targets {
		if addr == nil || addr.IP == nil {
			continue
		}
		if addr.IP.To4() != nil {
			v4 = append(v4, addr)
		} else {
			other = append(other, addr)
		}
	}

	ordered := make([]*net.IPAddr, 0, len(targets))
	if preferIPv4 {
		ordered = append(ordered, v4...)
		return append(ordered, other...)
	}
	ordered = append(ordered, other...)
	return append(ordered, v4...)
}
// normalizePingDialError maps a socket dial error onto the package's ping
// sentinels. Permission failures (os.ErrPermission or a matching message) are
// wrapped with ErrPingPermissionDenied, unsupported network/protocol messages
// with ErrPingProtocolUnsupported, and anything else is wrapped with the
// "ping dial" context. A nil err returns nil. Message matching is
// case-insensitive.
func normalizePingDialError(err error) error {
	if err == nil {
		return nil
	}
	lowered := strings.ToLower(err.Error())
	containsAny := func(needles ...string) bool {
		for _, needle := range needles {
			if strings.Contains(lowered, needle) {
				return true
			}
		}
		return false
	}

	switch {
	case errors.Is(err, os.ErrPermission),
		containsAny("operation not permitted", "permission denied"):
		return fmt.Errorf("%w: %v", ErrPingPermissionDenied, err)
	case containsAny(
		"unknown network",
		"protocol not available",
		"address family not supported by protocol",
		"socket type not supported",
	):
		return fmt.Errorf("%w: %v", ErrPingProtocolUnsupported, err)
	default:
		return wrapError(err, "ping dial")
	}
}
// normalizePingOptions returns a validated copy of opts. A nil opts is treated
// as the zero value; a zero Count or Timeout is replaced by defaultCount or
// defaultTimeout. An error is returned (together with the partially
// normalized options) when Count or Interval is negative, Timeout is not
// positive, PayloadSize is outside [0, maxPingPayloadSize], or SourceIP is
// set but malformed.
func normalizePingOptions(opts *PingOptions, defaultCount int, defaultTimeout time.Duration) (PingOptions, error) {
	var cfg PingOptions
	if opts != nil {
		cfg = *opts
	}
	if cfg.Count == 0 {
		cfg.Count = defaultCount
	}
	if cfg.Timeout == 0 {
		cfg.Timeout = defaultTimeout
	}

	switch {
	case cfg.Count < 0:
		return cfg, fmt.Errorf("ping count must be >= 0")
	case cfg.Timeout <= 0:
		return cfg, wrapError(ErrPingInvalidTimeout, "timeout must be > 0")
	case cfg.Interval < 0:
		return cfg, fmt.Errorf("ping interval must be >= 0")
	case cfg.PayloadSize < 0 || cfg.PayloadSize > maxPingPayloadSize:
		return cfg, fmt.Errorf("ping payload size must be in [0,%d]", maxPingPayloadSize)
	case cfg.SourceIP != nil && cfg.SourceIP.To16() == nil:
		return cfg, wrapError(ErrInvalidIP, "invalid source ip: %q", cfg.SourceIP.String())
	}
	return cfg, nil
}
// pingOnceWithOptions performs a single ping attempt against host. The host
// is resolved to one or more targets, which are tried in order until one
// answers. A permission error aborts immediately; any other failure moves on
// to the next target, and the last failure is returned wrapped if none
// succeeds. seq is truncated to 16 bits for the ICMP sequence number. A nil
// ctx is treated as context.Background().
func pingOnceWithOptions(ctx context.Context, host string, seq int, opts PingOptions) (PingResult, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	if ctxErr := ctx.Err(); ctxErr != nil {
		return PingResult{}, wrapError(ctxErr, "ping context done")
	}

	addrs, err := resolvePingTargets(host, opts.PreferIPv4, opts.PreferIPv6)
	if err != nil {
		return PingResult{}, wrapError(err, "resolve ping target")
	}

	var (
		body    = pingPayload(opts.PayloadSize)
		failure error
	)
	for _, addr := range addrs {
		spec, specErr := socketSpecForIP(addr.IP)
		if specErr != nil {
			failure = specErr
			continue
		}
		request := getICMP(uint16(seq), nextPingIdentifier(), spec.requestType, body)
		reply, sendErr := sendICMPRequest(ctx, request, body, addr, opts.SourceIP, spec, opts.Timeout)
		switch {
		case sendErr == nil:
			return reply, nil
		case errors.Is(sendErr, ErrPingPermissionDenied):
			// Permission problems are usually unrelated to the address
			// family, so trying the remaining targets is of little use.
			return PingResult{}, sendErr
		}
		failure = sendErr
	}

	if failure == nil {
		return PingResult{}, ErrPingNoResolvedTarget
	}
	return PingResult{}, wrapError(failure, "ping all resolved targets failed")
}
// PingWithContext sends one ICMP echo request to host and waits for the reply,
// honoring cancellation and any deadline carried by ctx.
//
// seq is used as the ICMP sequence number and timeout bounds the wait for the
// reply. timeout must be positive: a zero or negative value is rejected by
// option validation with an error wrapping ErrPingInvalidTimeout.
func PingWithContext(ctx context.Context, host string, seq int, timeout time.Duration) (PingResult, error) {
	opts, err := normalizePingOptions(&PingOptions{
		Count:   1,
		Timeout: timeout,
	}, 1, timeout)
	if err != nil {
		return PingResult{}, err
	}
	// The options are built here without a Deadline and normalization does not
	// set one, so there is no options-level deadline to apply; an overall
	// deadline is supplied through ctx instead.
	return pingOnceWithOptions(ctx, host, seq, opts)
}
// Ping sends one ICMP echo request to ip with sequence number seq and waits
// up to timeout for the reply. It is PingWithContext called with
// context.Background().
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
return PingWithContext(context.Background(), ip, seq, timeout)
}
// Pingable reports whether host answers an ICMP echo request, retrying
// according to opts. A nil opts uses defaultPingableCount attempts with
// defaultPingAttemptTimeout each. It returns true on the first successful
// attempt. Retrying stops early on a permission error or when the overall
// deadline (opts.Deadline) is reached; otherwise the error of the last
// attempt is returned.
func Pingable(host string, opts *PingOptions) (bool, error) {
	cfg, err := normalizePingOptions(opts, defaultPingableCount, defaultPingAttemptTimeout)
	if err != nil {
		return false, err
	}

	ctx := context.Background()
	if !cfg.Deadline.IsZero() {
		deadlineCtx, cancel := context.WithDeadline(ctx, cfg.Deadline)
		defer cancel()
		ctx = deadlineCtx
	}

	// Sequence numbers start at this base and increase by one per attempt.
	const firstSeq = 29

	// Reported only if no attempt was made at all.
	var lastErr error = ErrPingNoResolvedTarget
	for attempt := 0; attempt < cfg.Count; attempt++ {
		_, pingErr := pingOnceWithOptions(ctx, host, firstSeq+attempt, cfg)
		if pingErr == nil {
			return true, nil
		}
		lastErr = pingErr

		// These failures will not be cured by another attempt.
		fatal := errors.Is(pingErr, ErrPingPermissionDenied) ||
			errors.Is(pingErr, context.Canceled) ||
			errors.Is(pingErr, context.DeadlineExceeded)
		if fatal {
			break
		}

		// No pause after the final attempt or when no interval is configured.
		if attempt >= cfg.Count-1 || cfg.Interval <= 0 {
			continue
		}
		wait := time.NewTimer(cfg.Interval)
		select {
		case <-wait.C:
		case <-ctx.Done():
			wait.Stop()
			return false, wrapError(ctx.Err(), "pingable context done")
		}
	}
	return false, lastErr
}
// IsIpPingable keeps backward-compatible bool-only behavior.
func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool {
if retryLimit <= 0 {
return false
}
ok, _ := Pingable(ip, &PingOptions{
Count: retryLimit,
Timeout: timeout,
})
return ok
}