543 lines
13 KiB
Go
543 lines
13 KiB
Go
package starnet
|
|
|
|
import (
|
|
"context"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// ICMP echo message type codes (RFC 792 for IPv4, RFC 4443 for ICMPv6).
const (
	icmpTypeEchoRequestV4 = 8
	icmpTypeEchoReplyV4   = 0
	icmpTypeEchoRequestV6 = 128
	icmpTypeEchoReplyV6   = 129
)

// Wire and buffer sizes.
const (
	// icmpHeaderLen is the size of the fixed echo header: type, code,
	// checksum, identifier and sequence number.
	icmpHeaderLen = 8
	// icmpReadBufSz is the size of the buffer each reply read uses.
	icmpReadBufSz = 1500
	// maxPingPayloadSize is the largest PayloadSize accepted by
	// normalizePingOptions: 65507 minus the 8-byte ICMP header.
	maxPingPayloadSize = 65507 - icmpHeaderLen
)

// Defaults Pingable applies when PingOptions is nil or leaves a field zero.
const (
	defaultPingAttemptTimeout = 2 * time.Second
	defaultPingableCount      = 3
)
|
|
|
|
// ICMP is the fixed 8-byte ICMP header this package sends for echo requests.
// marshalICMPPacket serializes the fields in declaration order, with the
// 16-bit fields in big-endian (network) byte order.
type ICMP struct {
	// Type and Code identify the ICMP message kind.
	Type, Code uint8
	// CheckSum is the Internet checksum, computed by getICMP over the
	// header (checksum field zeroed) plus payload.
	CheckSum uint16
	// Identifier and SequenceNum are echoed back in the reply and are used
	// to match a reply to its request.
	Identifier, SequenceNum uint16
}
|
|
|
|
// pingSocketSpec bundles the per-address-family parameters needed to send an
// echo request over a raw IP socket and to recognize its reply.
type pingSocketSpec struct {
	network string // network name passed to net.DialIP, e.g. "ip4:icmp"
	family  int    // address family: 4 or 6

	// ICMP type codes of the outgoing echo request and the expected reply.
	requestType, replyType uint8
}
|
|
|
|
// PingOptions controls ping probing behavior for Pingable.
//
// A nil *PingOptions, or a zero Count or Timeout, selects the caller's
// defaults (3 attempts and 2s per attempt for Pingable). Options are
// rejected when Count or Interval is negative, Timeout is not positive
// after defaulting, PayloadSize is outside [0, 65499], or SourceIP is
// non-nil but not a valid IP address.
type PingOptions struct {
	Count       int           // ping attempts for Pingable; 0 selects the default (3)
	Timeout     time.Duration // per-attempt reply timeout; 0 selects the default (2s)
	Interval    time.Duration // delay between Pingable attempts; default 0 (no delay)
	Deadline    time.Time     // overall deadline for Pingable; zero means none (PingWithContext never uses it; pass a ctx deadline instead)
	PreferIPv4  bool          // order IPv4 targets first when a host resolves to both families
	PreferIPv6  bool          // order IPv6 targets first; setting both Prefer flags means no preference
	SourceIP    net.IP        // optional local address to bind; must match the target's family
	PayloadSize int           // ICMP payload bytes after the 8-byte header; default 0
}
|
|
|
|
// PingResult describes the outcome of a single ICMP echo exchange.
type PingResult struct {
	// Duration is the time from just after the request was written until
	// the matching reply was read.
	Duration time.Duration
	// RecvCount is the number of bytes returned by the read that carried the
	// matching reply (this can include a leading IP header).
	RecvCount int
	// RemoteIP is the string form of the destination address that was pinged.
	RemoteIP string
}
|
|
|
|
// pingIdentifierSeed is a process-wide counter bumped atomically by every
// nextPingIdentifier call, so successive calls yield different identifiers
// (until the 16-bit result wraps around).
var pingIdentifierSeed uint32

// nextPingIdentifier returns an ICMP echo identifier: the low 16 bits of the
// process ID plus the incremented counter, computed modulo 2^16.
func nextPingIdentifier() uint16 {
	bump := atomic.AddUint32(&pingIdentifierSeed, 1)
	// uint16 conversions keep the low 16 bits; uint16 addition wraps mod 2^16.
	return uint16(os.Getpid()) + uint16(bump)
}
|
|
|
|
// pingPayload returns size bytes of filler data where byte i holds i mod 256,
// or nil when size is zero or negative.
func pingPayload(size int) []byte {
	if size < 1 {
		return nil
	}
	buf := make([]byte, size)
	for i := range buf {
		buf[i] = byte(i) // truncation to 8 bits is intentional
	}
	return buf
}
|
|
|
|
func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
|
|
icmp := ICMP{
|
|
Type: typ,
|
|
Code: 0,
|
|
CheckSum: 0,
|
|
Identifier: identifier,
|
|
SequenceNum: seq,
|
|
}
|
|
buf := marshalICMPPacket(icmp, payload)
|
|
icmp.CheckSum = checkSum(buf)
|
|
return icmp
|
|
}
|
|
|
|
func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) {
|
|
var res PingResult
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
if err := ctx.Err(); err != nil {
|
|
return res, wrapError(err, "ping context done")
|
|
}
|
|
if destAddr == nil || destAddr.IP == nil {
|
|
return res, fmt.Errorf("destination ip is nil")
|
|
}
|
|
res.RemoteIP = destAddr.String()
|
|
|
|
localAddr, err := localIPAddrForFamily(sourceIP, spec.family)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
|
|
conn, err := net.DialIP(spec.network, localAddr, destAddr)
|
|
if err != nil {
|
|
return res, normalizePingDialError(err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
packet := marshalICMPPacket(icmp, payload)
|
|
if _, err := conn.Write(packet); err != nil {
|
|
return res, wrapError(err, "ping write request")
|
|
}
|
|
|
|
tStart := time.Now()
|
|
deadline := tStart.Add(timeout)
|
|
if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
|
|
deadline = d
|
|
}
|
|
if err := conn.SetReadDeadline(deadline); err != nil {
|
|
return res, wrapError(err, "ping set read deadline")
|
|
}
|
|
|
|
doneCh := make(chan struct{})
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = conn.SetReadDeadline(time.Now())
|
|
case <-doneCh:
|
|
}
|
|
}()
|
|
defer close(doneCh)
|
|
|
|
recv := make([]byte, icmpReadBufSz)
|
|
for {
|
|
n, err := conn.Read(recv)
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return res, wrapError(ctx.Err(), "ping context done")
|
|
}
|
|
return res, wrapError(err, "ping read reply")
|
|
}
|
|
if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) {
|
|
res.RecvCount = n
|
|
res.Duration = time.Since(tStart)
|
|
return res, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// checkSum returns the RFC 1071 Internet checksum of data. This is the one's
// complement of the one's-complement sum of data's big-endian 16-bit words.
// An odd trailing byte is treated as the high byte of a final word whose low
// byte is zero.
func checkSum(data []byte) uint16 {
	var acc uint32
	n := len(data)
	for i := 0; i+1 < n; i += 2 {
		acc += uint32(binary.BigEndian.Uint16(data[i:]))
	}
	if n%2 == 1 {
		acc += uint32(data[n-1]) << 8
	}
	// Fold carries from the high half back into the low 16 bits.
	for acc > 0xffff {
		acc = (acc >> 16) + (acc & 0xffff)
	}
	return ^uint16(acc)
}
|
|
|
|
func marshalICMP(icmp ICMP) []byte {
|
|
return marshalICMPPacket(icmp, nil)
|
|
}
|
|
|
|
func marshalICMPPacket(icmp ICMP, payload []byte) []byte {
|
|
buf := make([]byte, icmpHeaderLen+len(payload))
|
|
buf[0] = icmp.Type
|
|
buf[1] = icmp.Code
|
|
binary.BigEndian.PutUint16(buf[2:], icmp.CheckSum)
|
|
binary.BigEndian.PutUint16(buf[4:], icmp.Identifier)
|
|
binary.BigEndian.PutUint16(buf[6:], icmp.SequenceNum)
|
|
copy(buf[icmpHeaderLen:], payload)
|
|
return buf
|
|
}
|
|
|
|
func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
|
|
for _, off := range candidateICMPOffsets(packet, family) {
|
|
if off < 0 || off+icmpHeaderLen > len(packet) {
|
|
continue
|
|
}
|
|
if packet[off] != expectedType || packet[off+1] != 0 {
|
|
continue
|
|
}
|
|
if binary.BigEndian.Uint16(packet[off+4:off+6]) != identifier {
|
|
continue
|
|
}
|
|
if binary.BigEndian.Uint16(packet[off+6:off+8]) != seq {
|
|
continue
|
|
}
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func candidateICMPOffsets(packet []byte, family int) []int {
|
|
offsets := []int{0}
|
|
if len(packet) == 0 {
|
|
return offsets
|
|
}
|
|
|
|
ver := packet[0] >> 4
|
|
if ver == 4 && len(packet) >= 20 {
|
|
ihl := int(packet[0]&0x0f) * 4
|
|
if ihl >= 20 && ihl <= len(packet)-icmpHeaderLen {
|
|
offsets = append(offsets, ihl)
|
|
}
|
|
} else if ver == 6 && len(packet) >= 40+icmpHeaderLen {
|
|
offsets = append(offsets, 40)
|
|
}
|
|
|
|
// 某些平台/内核可能回包含链路层头部,保守再尝试常见偏移。
|
|
if family == 4 && len(packet) >= 20+icmpHeaderLen {
|
|
offsets = append(offsets, 20)
|
|
}
|
|
if family == 6 && len(packet) >= 40+icmpHeaderLen {
|
|
offsets = append(offsets, 40)
|
|
}
|
|
|
|
return dedupOffsets(offsets)
|
|
}
|
|
|
|
// dedupOffsets removes repeated values and keeps the first occurrence of
// each in its original order. Inputs with fewer than two elements are
// returned as-is; otherwise a new slice is returned. A linear scan is used
// because the lists here hold only a few entries.
func dedupOffsets(offsets []int) []int {
	if len(offsets) < 2 {
		return offsets
	}
	uniq := make([]int, 0, len(offsets))
	for _, off := range offsets {
		seen := false
		for _, kept := range uniq {
			if kept == off {
				seen = true
				break
			}
		}
		if !seen {
			uniq = append(uniq, off)
		}
	}
	return uniq
}
|
|
|
|
func socketSpecForIP(ip net.IP) (pingSocketSpec, error) {
|
|
if ip == nil {
|
|
return pingSocketSpec{}, wrapError(ErrInvalidIP, "ip is nil")
|
|
}
|
|
|
|
if ip4 := ip.To4(); ip4 != nil {
|
|
return pingSocketSpec{
|
|
network: "ip4:icmp",
|
|
family: 4,
|
|
requestType: icmpTypeEchoRequestV4,
|
|
replyType: icmpTypeEchoReplyV4,
|
|
}, nil
|
|
}
|
|
|
|
if ip16 := ip.To16(); ip16 != nil {
|
|
return pingSocketSpec{
|
|
network: "ip6:ipv6-icmp",
|
|
family: 6,
|
|
requestType: icmpTypeEchoRequestV6,
|
|
replyType: icmpTypeEchoReplyV6,
|
|
}, nil
|
|
}
|
|
|
|
return pingSocketSpec{}, wrapError(ErrInvalidIP, "invalid ip: %q", ip.String())
|
|
}
|
|
|
|
func localIPAddrForFamily(sourceIP net.IP, family int) (*net.IPAddr, error) {
|
|
if sourceIP == nil {
|
|
return nil, nil
|
|
}
|
|
if sourceIP.To16() == nil {
|
|
return nil, wrapError(ErrInvalidIP, "invalid source ip: %q", sourceIP.String())
|
|
}
|
|
if family == 4 && sourceIP.To4() == nil {
|
|
return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv4 target")
|
|
}
|
|
if family == 6 && sourceIP.To4() != nil {
|
|
return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv6 target")
|
|
}
|
|
return &net.IPAddr{IP: sourceIP}, nil
|
|
}
|
|
|
|
func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
|
|
if parsed := net.ParseIP(host); parsed != nil {
|
|
return []*net.IPAddr{{IP: parsed}}, nil
|
|
}
|
|
|
|
var targets []*net.IPAddr
|
|
var err4 error
|
|
var err6 error
|
|
|
|
if ip4, e := net.ResolveIPAddr("ip4", host); e == nil && ip4 != nil && ip4.IP != nil {
|
|
targets = append(targets, ip4)
|
|
} else {
|
|
err4 = e
|
|
}
|
|
|
|
if ip6, e := net.ResolveIPAddr("ip6", host); e == nil && ip6 != nil && ip6.IP != nil {
|
|
targets = append(targets, ip6)
|
|
} else {
|
|
err6 = e
|
|
}
|
|
|
|
if len(targets) > 0 {
|
|
return orderPingTargets(targets, preferIPv4, preferIPv6), nil
|
|
}
|
|
|
|
if err4 != nil {
|
|
return nil, err4
|
|
}
|
|
if err6 != nil {
|
|
return nil, err6
|
|
}
|
|
return nil, ErrPingNoResolvedTarget
|
|
}
|
|
|
|
// orderPingTargets puts the preferred address family first and keeps the
// relative order within each family.
//
// targets is returned unchanged when it has at most one entry, or when
// preferIPv4 == preferIPv6 (no preference, or contradictory preferences).
// Otherwise a new slice is returned, with nil entries and entries whose IP
// is nil dropped.
func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
	if len(targets) <= 1 || preferIPv4 == preferIPv6 {
		return targets
	}

	var v4, v6 []*net.IPAddr
	for _, t := range targets {
		if t == nil || t.IP == nil {
			continue
		}
		if t.IP.To4() != nil {
			v4 = append(v4, t)
		} else {
			v6 = append(v6, t)
		}
	}

	first, second := v6, v4
	if preferIPv4 {
		first, second = v4, v6
	}
	ordered := make([]*net.IPAddr, 0, len(targets))
	ordered = append(ordered, first...)
	return append(ordered, second...)
}
|
|
|
|
func normalizePingDialError(err error) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
msg := strings.ToLower(err.Error())
|
|
if errors.Is(err, os.ErrPermission) ||
|
|
strings.Contains(msg, "operation not permitted") ||
|
|
strings.Contains(msg, "permission denied") {
|
|
return fmt.Errorf("%w: %v", ErrPingPermissionDenied, err)
|
|
}
|
|
|
|
if strings.Contains(msg, "unknown network") ||
|
|
strings.Contains(msg, "protocol not available") ||
|
|
strings.Contains(msg, "address family not supported by protocol") ||
|
|
strings.Contains(msg, "socket type not supported") {
|
|
return fmt.Errorf("%w: %v", ErrPingProtocolUnsupported, err)
|
|
}
|
|
|
|
return wrapError(err, "ping dial")
|
|
}
|
|
|
|
func normalizePingOptions(opts *PingOptions, defaultCount int, defaultTimeout time.Duration) (PingOptions, error) {
|
|
out := PingOptions{
|
|
Count: defaultCount,
|
|
Timeout: defaultTimeout,
|
|
Interval: 0,
|
|
PayloadSize: 0,
|
|
}
|
|
if opts != nil {
|
|
out = *opts
|
|
if out.Count == 0 {
|
|
out.Count = defaultCount
|
|
}
|
|
if out.Timeout == 0 {
|
|
out.Timeout = defaultTimeout
|
|
}
|
|
}
|
|
|
|
if out.Count < 0 {
|
|
return out, fmt.Errorf("ping count must be >= 0")
|
|
}
|
|
if out.Timeout <= 0 {
|
|
return out, wrapError(ErrPingInvalidTimeout, "timeout must be > 0")
|
|
}
|
|
if out.Interval < 0 {
|
|
return out, fmt.Errorf("ping interval must be >= 0")
|
|
}
|
|
if out.PayloadSize < 0 || out.PayloadSize > maxPingPayloadSize {
|
|
return out, fmt.Errorf("ping payload size must be in [0,%d]", maxPingPayloadSize)
|
|
}
|
|
if out.SourceIP != nil && out.SourceIP.To16() == nil {
|
|
return out, wrapError(ErrInvalidIP, "invalid source ip: %q", out.SourceIP.String())
|
|
}
|
|
|
|
return out, nil
|
|
}
|
|
|
|
func pingOnceWithOptions(ctx context.Context, host string, seq int, opts PingOptions) (PingResult, error) {
|
|
var res PingResult
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
if err := ctx.Err(); err != nil {
|
|
return res, wrapError(err, "ping context done")
|
|
}
|
|
|
|
targets, err := resolvePingTargets(host, opts.PreferIPv4, opts.PreferIPv6)
|
|
if err != nil {
|
|
return res, wrapError(err, "resolve ping target")
|
|
}
|
|
|
|
payload := pingPayload(opts.PayloadSize)
|
|
var lastErr error
|
|
for _, target := range targets {
|
|
spec, err := socketSpecForIP(target.IP)
|
|
if err != nil {
|
|
lastErr = err
|
|
continue
|
|
}
|
|
|
|
icmp := getICMP(uint16(seq), nextPingIdentifier(), spec.requestType, payload)
|
|
resp, err := sendICMPRequest(ctx, icmp, payload, target, opts.SourceIP, spec, opts.Timeout)
|
|
if err == nil {
|
|
return resp, nil
|
|
}
|
|
|
|
// 权限问题通常与地址族无关,继续重试意义不大。
|
|
if errors.Is(err, ErrPingPermissionDenied) {
|
|
return res, err
|
|
}
|
|
lastErr = err
|
|
}
|
|
|
|
if lastErr != nil {
|
|
return res, wrapError(lastErr, "ping all resolved targets failed")
|
|
}
|
|
return res, ErrPingNoResolvedTarget
|
|
}
|
|
|
|
// PingWithContext sends one ICMP echo request with context cancel support.
|
|
func PingWithContext(ctx context.Context, host string, seq int, timeout time.Duration) (PingResult, error) {
|
|
opts, err := normalizePingOptions(&PingOptions{
|
|
Count: 1,
|
|
Timeout: timeout,
|
|
}, 1, timeout)
|
|
if err != nil {
|
|
return PingResult{}, err
|
|
}
|
|
|
|
if !opts.Deadline.IsZero() {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithDeadline(ctx, opts.Deadline)
|
|
defer cancel()
|
|
}
|
|
return pingOnceWithOptions(ctx, host, seq, opts)
|
|
}
|
|
|
|
// Ping sends one ICMP echo request.
|
|
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
|
|
return PingWithContext(context.Background(), ip, seq, timeout)
|
|
}
|
|
|
|
// Pingable checks host reachability with retry options.
|
|
func Pingable(host string, opts *PingOptions) (bool, error) {
|
|
cfg, err := normalizePingOptions(opts, defaultPingableCount, defaultPingAttemptTimeout)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
ctx := context.Background()
|
|
if !cfg.Deadline.IsZero() {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithDeadline(ctx, cfg.Deadline)
|
|
defer cancel()
|
|
}
|
|
|
|
var lastErr error
|
|
for i := 0; i < cfg.Count; i++ {
|
|
_, err := pingOnceWithOptions(ctx, host, 29+i, cfg)
|
|
if err == nil {
|
|
return true, nil
|
|
}
|
|
lastErr = err
|
|
|
|
if errors.Is(err, ErrPingPermissionDenied) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
|
break
|
|
}
|
|
|
|
if i < cfg.Count-1 && cfg.Interval > 0 {
|
|
timer := time.NewTimer(cfg.Interval)
|
|
select {
|
|
case <-ctx.Done():
|
|
timer.Stop()
|
|
return false, wrapError(ctx.Err(), "pingable context done")
|
|
case <-timer.C:
|
|
}
|
|
}
|
|
}
|
|
|
|
if lastErr == nil {
|
|
lastErr = ErrPingNoResolvedTarget
|
|
}
|
|
return false, lastErr
|
|
}
|
|
|
|
// IsIpPingable keeps backward-compatible bool-only behavior.
|
|
func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool {
|
|
if retryLimit <= 0 {
|
|
return false
|
|
}
|
|
ok, _ := Pingable(ip, &PingOptions{
|
|
Count: retryLimit,
|
|
Timeout: timeout,
|
|
})
|
|
return ok
|
|
}
|