// starnet/ping.go - 376 lines, 9.3 KiB, Go.
// Captured from a repository web view (Raw / Normal View / History).
// Blame timestamp: 2021-06-04 10:49:23 +08:00

package starnet
import (
"context"
"errors"
"fmt"
// 2021-06-04 10:49:23 +08:00
"net"
"os"
"strings"
// 2021-06-04 10:49:23 +08:00
"time"
"b612.me/starnet/internal/pingcore"
// 2021-06-04 10:49:23 +08:00
)
const (
	// ICMP echo message types: RFC 792 (ICMPv4) and RFC 4443 (ICMPv6).
	icmpTypeEchoReplyV4   = 0
	icmpTypeEchoRequestV4 = 8
	icmpTypeEchoRequestV6 = 128
	icmpTypeEchoReplyV6   = 129

	// icmpReadBufSz is the default size of the buffer used to read echo
	// replies (1500 bytes, a common Ethernet MTU).
	icmpReadBufSz = 1500

	// defaultPingAttemptTimeout and defaultPingableCount are the per-attempt
	// timeout and attempt count Pingable uses when the caller leaves them unset.
	defaultPingAttemptTimeout = 2 * time.Second
	defaultPingableCount      = 3

	// maxPingPayloadSize is the largest echo payload that fits in a single
	// IPv4 datagram: 65535 (max total length) - 20 (IPv4 header without
	// options) - 8 (ICMP header) = 65507. The ICMP header is subtracted only
	// once; 65507 already excludes it. ICMPv6 permits a larger payload
	// (65535 - 8), so this bound is safe for both families.
	maxPingPayloadSize = 65535 - 20 - 8
)
type ICMP = pingcore.ICMP
2021-06-04 10:49:23 +08:00
// pingSocketSpec describes how to open and match an ICMP echo exchange for a
// single address family.
type pingSocketSpec struct {
	network string // network name handed to net.DialIP, e.g. "ip4:icmp"
	family  int    // address family: 4 or 6

	// ICMP message types for the outgoing request and the expected reply.
	requestType, replyType uint8
}
type (
	// PingOptions controls ping probing behavior.
	PingOptions = pingcore.Options
	// PingResult is the result type returned by the Ping helpers in this package.
	PingResult = pingcore.Result
)
// nextPingIdentifier returns the identifier to stamp on the next echo request,
// as supplied by pingcore.NextIdentifier.
func nextPingIdentifier() uint16 {
	id := pingcore.NextIdentifier()
	return id
}
// pingPayload returns an echo payload of the requested size, built by
// pingcore.Payload.
func pingPayload(size int) []byte {
	data := pingcore.Payload(size)
	return data
}
func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
return pingcore.BuildICMP(seq, identifier, typ, payload)
2021-06-04 10:49:23 +08:00
}
func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) {
2021-06-04 10:49:23 +08:00
var res PingResult
if ctx == nil {
ctx = context.Background()
}
if err := ctx.Err(); err != nil {
return res, wrapError(err, "ping context done")
}
if destAddr == nil || destAddr.IP == nil {
return res, fmt.Errorf("destination ip is nil")
}
2023-02-11 17:15:01 +08:00
res.RemoteIP = destAddr.String()
localAddr, err := localIPAddrForFamily(sourceIP, spec.family)
2021-06-04 10:49:23 +08:00
if err != nil {
return res, err
}
conn, err := net.DialIP(spec.network, localAddr, destAddr)
if err != nil {
return res, normalizePingDialError(err)
}
2021-06-04 10:49:23 +08:00
defer conn.Close()
packet := marshalICMPPacket(icmp, payload)
if _, err := conn.Write(packet); err != nil {
return res, wrapError(err, "ping write request")
2021-06-04 10:49:23 +08:00
}
startedAt := time.Now()
deadline := startedAt.Add(timeout)
if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
deadline = d
}
if err := conn.SetReadDeadline(deadline); err != nil {
return res, wrapError(err, "ping set read deadline")
2021-06-04 10:49:23 +08:00
}
doneCh := make(chan struct{})
go func() {
select {
case <-ctx.Done():
_ = conn.SetReadDeadline(time.Now())
case <-doneCh:
}
}()
defer close(doneCh)
2021-06-04 10:49:23 +08:00
recv := make([]byte, icmpReadBufSz)
for {
n, err := conn.Read(recv)
if err != nil {
if ctx.Err() != nil {
return res, wrapError(ctx.Err(), "ping context done")
}
return res, wrapError(err, "ping read reply")
}
if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) {
res.RecvCount = n
res.Duration = time.Since(startedAt)
return res, nil
}
}
2021-06-04 10:49:23 +08:00
}
func checkSum(data []byte) uint16 {
return pingcore.Checksum(data)
2021-06-04 10:49:23 +08:00
}
func marshalICMP(icmp ICMP) []byte {
return pingcore.Marshal(icmp)
2021-06-04 10:49:23 +08:00
}
// marshalICMPPacket serializes icmp together with payload into the bytes
// written to the wire, via pingcore.MarshalPacket.
func marshalICMPPacket(icmp ICMP, payload []byte) []byte {
	wire := pingcore.MarshalPacket(icmp, payload)
	return wire
}
// isExpectedEchoReply reports whether packet is the reply being waited for:
// an ICMP message of expectedType carrying identifier and seq for the given
// address family (4 or 6). Matching is delegated to pingcore.
func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
	matched := pingcore.IsExpectedEchoReply(packet, family, expectedType, identifier, seq)
	return matched
}
// candidateICMPOffsets returns the byte offsets within packet at which an ICMP
// message may start for the given family, as computed by pingcore
// (presumably covering packets with and without a leading IP header - verify).
func candidateICMPOffsets(packet []byte, family int) []int {
	offsets := pingcore.CandidateICMPOffsets(packet, family)
	return offsets
}
// dedupOffsets removes duplicate entries from offsets via pingcore.DedupOffsets.
func dedupOffsets(offsets []int) []int {
	unique := pingcore.DedupOffsets(offsets)
	return unique
}
// socketSpecForIP selects the raw ICMP network and echo message types that
// match the address family of ip. IPv4 addresses (including IPv4-mapped IPv6
// forms, for which To4 is non-nil) use ICMPv4; other 16-byte addresses use
// ICMPv6. A nil or malformed ip yields an ErrInvalidIP-based error.
func socketSpecForIP(ip net.IP) (pingSocketSpec, error) {
	switch {
	case ip == nil:
		return pingSocketSpec{}, wrapError(ErrInvalidIP, "ip is nil")
	case ip.To4() != nil:
		return pingSocketSpec{
			network:     "ip4:icmp",
			family:      4,
			requestType: icmpTypeEchoRequestV4,
			replyType:   icmpTypeEchoReplyV4,
		}, nil
	case ip.To16() != nil:
		return pingSocketSpec{
			network:     "ip6:ipv6-icmp",
			family:      6,
			requestType: icmpTypeEchoRequestV6,
			replyType:   icmpTypeEchoReplyV6,
		}, nil
	default:
		return pingSocketSpec{}, wrapError(ErrInvalidIP, "invalid ip: %q", ip.String())
	}
}
// localIPAddrForFamily returns the local address to bind when pinging a
// target of the given family (4 or 6). A nil sourceIP yields (nil, nil) so the
// system chooses the source. A malformed sourceIP, or one whose family does
// not match, yields an ErrInvalidIP-based error.
func localIPAddrForFamily(sourceIP net.IP, family int) (*net.IPAddr, error) {
	if sourceIP == nil {
		return nil, nil
	}
	if sourceIP.To16() == nil {
		return nil, wrapError(ErrInvalidIP, "invalid source ip: %q", sourceIP.String())
	}
	isV4 := sourceIP.To4() != nil
	switch {
	case family == 4 && !isV4:
		return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv4 target")
	case family == 6 && isV4:
		return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv6 target")
	}
	return &net.IPAddr{IP: sourceIP}, nil
}
// resolvePingTargets resolves host into candidate addresses via
// pingcore.ResolveTargets, honoring the IPv4/IPv6 preference flags. An empty
// result is reported as ErrPingNoResolvedTarget.
func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
	addrs, err := pingcore.ResolveTargets(host, preferIPv4, preferIPv6)
	switch {
	case err != nil:
		return nil, err
	case len(addrs) == 0:
		return nil, ErrPingNoResolvedTarget
	default:
		return addrs, nil
	}
}
// orderPingTargets orders targets according to the IPv4/IPv6 preference
// flags, as implemented by pingcore.OrderTargets.
func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
	ordered := pingcore.OrderTargets(targets, preferIPv4, preferIPv6)
	return ordered
}
// normalizePingDialError maps a raw-socket dial failure onto this package's
// sentinel errors. Permission problems become ErrPingPermissionDenied, and an
// unsupported network/protocol/family becomes ErrPingProtocolUnsupported (the
// original error text is kept in the message). Anything else is wrapped as a
// "ping dial" error. A nil err returns nil.
func normalizePingDialError(err error) error {
	if err == nil {
		return nil
	}
	lowered := strings.ToLower(err.Error())
	// Error text is matched as a fallback for platforms/paths where the
	// underlying cause is not exposed through the error chain.
	mentions := func(fragments ...string) bool {
		for _, fragment := range fragments {
			if strings.Contains(lowered, fragment) {
				return true
			}
		}
		return false
	}
	switch {
	case errors.Is(err, os.ErrPermission) ||
		mentions("operation not permitted", "permission denied"):
		return fmt.Errorf("%w: %v", ErrPingPermissionDenied, err)
	case mentions(
		"unknown network",
		"protocol not available",
		"address family not supported by protocol",
		"socket type not supported",
	):
		return fmt.Errorf("%w: %v", ErrPingProtocolUnsupported, err)
	default:
		return wrapError(err, "ping dial")
	}
}
// normalizePingOptions fills in defaults and validates ping options.
//
// A nil opts, or a zero Count/Timeout, falls back to defaultCount and
// defaultTimeout. The options are rejected when:
//   - Count or Interval is negative,
//   - Timeout is not positive,
//   - PayloadSize is outside [0, maxPingPayloadSize], or
//   - SourceIP is set but is not a valid IP.
// The (partially) normalized options are returned alongside any error.
func normalizePingOptions(opts *PingOptions, defaultCount int, defaultTimeout time.Duration) (PingOptions, error) {
	var out PingOptions
	if opts != nil {
		out = *opts
	}
	if out.Count == 0 {
		out.Count = defaultCount
	}
	if out.Timeout == 0 {
		out.Timeout = defaultTimeout
	}

	switch {
	case out.Count < 0:
		return out, fmt.Errorf("ping count must be >= 0")
	case out.Timeout <= 0:
		return out, wrapError(ErrPingInvalidTimeout, "timeout must be > 0")
	case out.Interval < 0:
		return out, fmt.Errorf("ping interval must be >= 0")
	case out.PayloadSize < 0 || out.PayloadSize > maxPingPayloadSize:
		return out, fmt.Errorf("ping payload size must be in [0,%d]", maxPingPayloadSize)
	case out.SourceIP != nil && out.SourceIP.To16() == nil:
		return out, wrapError(ErrInvalidIP, "invalid source ip: %q", out.SourceIP.String())
	}
	return out, nil
}
func pingOnceWithOptions(ctx context.Context, host string, seq int, opts PingOptions) (PingResult, error) {
2021-06-04 10:49:23 +08:00
var res PingResult
if ctx == nil {
ctx = context.Background()
}
if err := ctx.Err(); err != nil {
return res, wrapError(err, "ping context done")
}
targets, err := resolvePingTargets(host, opts.PreferIPv4, opts.PreferIPv6)
2021-06-04 10:49:23 +08:00
if err != nil {
return res, wrapError(err, "resolve ping target")
2021-06-04 10:49:23 +08:00
}
2023-02-11 17:15:01 +08:00
payload := pingPayload(opts.PayloadSize)
var lastErr error
for _, target := range targets {
spec, err := socketSpecForIP(target.IP)
2023-02-11 17:15:01 +08:00
if err != nil {
lastErr = err
2023-02-11 17:15:01 +08:00
continue
}
icmp := getICMP(uint16(seq), nextPingIdentifier(), spec.requestType, payload)
resp, err := sendICMPRequest(ctx, icmp, payload, target, opts.SourceIP, spec, opts.Timeout)
if err == nil {
return resp, nil
}
if errors.Is(err, ErrPingPermissionDenied) {
return res, err
}
lastErr = err
2023-02-11 17:15:01 +08:00
}
if lastErr != nil {
return res, wrapError(lastErr, "ping all resolved targets failed")
}
return res, ErrPingNoResolvedTarget
}
// PingWithContext sends one ICMP echo request to host with sequence number seq
// and waits up to timeout for the reply, returning early if ctx is canceled or
// its deadline passes. A nil ctx is treated as context.Background().
//
// timeout must be positive; otherwise an ErrPingInvalidTimeout-based error is
// returned.
func PingWithContext(ctx context.Context, host string, seq int, timeout time.Duration) (PingResult, error) {
	// Only Count and Timeout are set here, so the normalized options never
	// carry a Deadline. ctx alone bounds the overall call, which is why no
	// extra context.WithDeadline wrapping is needed.
	opts, err := normalizePingOptions(&PingOptions{
		Count:   1,
		Timeout: timeout,
	}, 1, timeout)
	if err != nil {
		return PingResult{}, err
	}
	return pingOnceWithOptions(ctx, host, seq, opts)
}
// Ping sends one ICMP echo request to ip (an address or a resolvable host
// name) with sequence number seq and waits up to timeout for the reply.
// It is PingWithContext with a background context.
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
	bg := context.Background()
	return PingWithContext(bg, ip, seq, timeout)
}
// Pingable reports whether host answers an ICMP echo request.
//
// Up to opts.Count attempts are made (default defaultPingableCount), each
// bounded by opts.Timeout (default defaultPingAttemptTimeout), waiting
// opts.Interval between failed attempts. A non-zero opts.Deadline bounds the
// whole probe. A nil opts uses the defaults.
//
// It returns true on the first successful reply. Probing stops early on a
// permission error, or when the context is canceled or past its deadline.
// On failure the last error is returned.
func Pingable(host string, opts *PingOptions) (bool, error) {
	cfg, err := normalizePingOptions(opts, defaultPingableCount, defaultPingAttemptTimeout)
	if err != nil {
		return false, err
	}

	ctx := context.Background()
	if !cfg.Deadline.IsZero() {
		var cancel context.CancelFunc
		ctx, cancel = context.WithDeadline(ctx, cfg.Deadline)
		defer cancel()
	}

	// stopRetrying reports failures that further attempts cannot fix.
	stopRetrying := func(e error) bool {
		return errors.Is(e, ErrPingPermissionDenied) ||
			errors.Is(e, context.Canceled) ||
			errors.Is(e, context.DeadlineExceeded)
	}

	var lastErr error
	for attempt := 0; attempt < cfg.Count; attempt++ {
		// Sequence numbers start at 29 and grow by one per attempt.
		_, lastErr = pingOnceWithOptions(ctx, host, 29+attempt, cfg)
		if lastErr == nil {
			return true, nil
		}
		if stopRetrying(lastErr) {
			break
		}
		if attempt == cfg.Count-1 || cfg.Interval <= 0 {
			continue
		}
		pause := time.NewTimer(cfg.Interval)
		select {
		case <-ctx.Done():
			pause.Stop()
			return false, wrapError(ctx.Err(), "pingable context done")
		case <-pause.C:
		}
	}

	if lastErr == nil {
		lastErr = ErrPingNoResolvedTarget
	}
	return false, lastErr
}
// IsIpPingable keeps backward-compatible bool-only behavior.
func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool {
if retryLimit <= 0 {
return false
}
ok, _ := Pingable(ip, &PingOptions{
Count: retryLimit,
Timeout: timeout,
})
return ok
2023-02-11 17:15:01 +08:00
}