package starnet

import (
	"context"
	"errors"
	"fmt"
	"net"
	"os"
	"strings"
	"time"

	"b612.me/starnet/internal/pingcore"
)

const (
	icmpTypeEchoReplyV4   = 0
	icmpTypeEchoRequestV4 = 8
	icmpTypeEchoRequestV6 = 128
	icmpTypeEchoReplyV6   = 129

	// icmpReadBufSz is the minimum size of the buffer used to read replies;
	// sendICMPRequest grows it when the request packet is larger.
	icmpReadBufSz = 1500

	defaultPingAttemptTimeout = 2 * time.Second
	defaultPingableCount      = 3

	// maxPingPayloadSize is the largest accepted PingOptions.PayloadSize.
	// NOTE(review): the largest ICMP payload that fits in an IPv4 datagram is
	// 65535 - 20 (IPv4 header) - 8 (ICMP header) = 65507, so this value
	// subtracts the ICMP header a second time and is 8 bytes more conservative
	// than strictly necessary. Kept unchanged to preserve existing validation.
	maxPingPayloadSize = 65499
)

// ICMP is an alias for pingcore.ICMP, the message built for each echo request.
type ICMP = pingcore.ICMP

// pingSocketSpec describes the raw socket network name, address family (4 or
// 6) and ICMP echo request/reply types used for one target.
type pingSocketSpec struct {
	network     string
	family      int
	requestType uint8
	replyType   uint8
}

// PingOptions controls ping probing behavior.
type PingOptions = pingcore.Options

// PingResult is an alias for pingcore.Result, the outcome of one echo exchange.
type PingResult = pingcore.Result

// nextPingIdentifier returns the identifier for the next echo request
// (delegates to pingcore.NextIdentifier).
func nextPingIdentifier() uint16 {
	return pingcore.NextIdentifier()
}

// pingPayload returns an echo payload of the given size (delegates to
// pingcore.Payload).
func pingPayload(size int) []byte {
	return pingcore.Payload(size)
}

// getICMP builds the ICMP message for an echo request (delegates to
// pingcore.BuildICMP).
func getICMP(seq, identifier uint16, typ uint8, payload []byte) ICMP {
	return pingcore.BuildICMP(seq, identifier, typ, payload)
}

// sendICMPRequest sends a single echo request to destAddr and waits for the
// matching reply.
//
// The wait ends at whichever comes first: timeout (measured from when the
// request was written) or ctx's deadline; cancelling ctx also aborts the
// pending read. Received packets that isExpectedEchoReply rejects are skipped.
// On success the result carries the remote address, the number of bytes read
// for the reply and the time elapsed since the write.
func sendICMPRequest(ctx context.Context, icmp ICMP, payload []byte, destAddr *net.IPAddr, sourceIP net.IP, spec pingSocketSpec, timeout time.Duration) (PingResult, error) {
	var res PingResult
	if ctx == nil {
		ctx = context.Background()
	}
	if err := ctx.Err(); err != nil {
		return res, wrapError(err, "ping context done")
	}
	if destAddr == nil || destAddr.IP == nil {
		return res, fmt.Errorf("destination ip is nil")
	}
	res.RemoteIP = destAddr.String()

	localAddr, err := localIPAddrForFamily(sourceIP, spec.family)
	if err != nil {
		return res, err
	}
	conn, err := net.DialIP(spec.network, localAddr, destAddr)
	if err != nil {
		return res, normalizePingDialError(err)
	}
	defer conn.Close()

	packet := marshalICMPPacket(icmp, payload)
	if _, err := conn.Write(packet); err != nil {
		return res, wrapError(err, "ping write request")
	}
	startedAt := time.Now()

	deadline := startedAt.Add(timeout)
	ctxDeadline, hasCtxDeadline := ctx.Deadline()
	if hasCtxDeadline && ctxDeadline.Before(deadline) {
		deadline = ctxDeadline
	}
	if err := conn.SetReadDeadline(deadline); err != nil {
		return res, wrapError(err, "ping set read deadline")
	}

	// Unblock the pending Read as soon as ctx is cancelled. The goroutine
	// exits either on ctx.Done or when doneCh is closed on return; defers run
	// in reverse order, so doneCh is closed before conn.Close.
	doneCh := make(chan struct{})
	go func() {
		select {
		case <-ctx.Done():
			_ = conn.SetReadDeadline(time.Now())
		case <-doneCh:
		}
	}()
	defer close(doneCh)

	// Raw IPv4 sockets may deliver the IP header (up to 60 bytes) in front of
	// the ICMP reply. Size the buffer so replies to large payloads are not
	// silently truncated, which would under-report RecvCount.
	const maxIPv4HeaderLen = 60
	bufSize := icmpReadBufSz
	if need := len(packet) + maxIPv4HeaderLen; need > bufSize {
		bufSize = need
	}
	recv := make([]byte, bufSize)

	for {
		n, err := conn.Read(recv)
		if err != nil {
			if ctxErr := ctx.Err(); ctxErr != nil {
				return res, wrapError(ctxErr, "ping context done")
			}
			// When the read deadline is ctx's deadline, the socket timeout can
			// fire just before ctx itself reports expiry. Classify that case as
			// a context deadline, consistent with the ctx.Err() branch above.
			if hasCtxDeadline && errors.Is(err, os.ErrDeadlineExceeded) && !time.Now().Before(ctxDeadline) {
				return res, wrapError(context.DeadlineExceeded, "ping context done")
			}
			return res, wrapError(err, "ping read reply")
		}
		if isExpectedEchoReply(recv[:n], spec.family, spec.replyType, icmp.Identifier, icmp.SequenceNum) {
			res.RecvCount = n
			res.Duration = time.Since(startedAt)
			return res, nil
		}
		// Not our reply: keep reading until a match or the deadline.
	}
}

// checkSum computes the ICMP checksum of data (delegates to pingcore.Checksum).
func checkSum(data []byte) uint16 {
	return pingcore.Checksum(data)
}

// marshalICMP encodes an ICMP message (delegates to pingcore.Marshal).
func marshalICMP(icmp ICMP) []byte {
	return pingcore.Marshal(icmp)
}

// marshalICMPPacket encodes an ICMP message together with its payload
// (delegates to pingcore.MarshalPacket).
func marshalICMPPacket(icmp ICMP, payload []byte) []byte {
	return pingcore.MarshalPacket(icmp, payload)
}

// isExpectedEchoReply reports whether packet is the echo reply of the given
// type, identifier and sequence number (delegates to
// pingcore.IsExpectedEchoReply).
func isExpectedEchoReply(packet []byte, family int, expectedType uint8, identifier, seq uint16) bool {
	return pingcore.IsExpectedEchoReply(packet, family, expectedType, identifier, seq)
}

// candidateICMPOffsets delegates to pingcore.CandidateICMPOffsets.
func candidateICMPOffsets(packet []byte, family int) []int {
	return pingcore.CandidateICMPOffsets(packet, family)
}

// dedupOffsets delegates to pingcore.DedupOffsets.
func dedupOffsets(offsets []int) []int {
	return pingcore.DedupOffsets(offsets)
}

// socketSpecForIP picks the raw ICMP network and echo request/reply types for
// ip's family. Addresses for which To4 succeeds (including IPv4-mapped IPv6
// addresses) use ICMPv4; other valid 16-byte addresses use ICMPv6.
func socketSpecForIP(ip net.IP) (pingSocketSpec, error) {
	if ip == nil {
		return pingSocketSpec{}, wrapError(ErrInvalidIP, "ip is nil")
	}
	if ip4 := ip.To4(); ip4 != nil {
		return pingSocketSpec{
			network:     "ip4:icmp",
			family:      4,
			requestType: icmpTypeEchoRequestV4,
			replyType:   icmpTypeEchoReplyV4,
		}, nil
	}
	if ip16 := ip.To16(); ip16 != nil {
		return pingSocketSpec{
			network:     "ip6:ipv6-icmp",
			family:      6,
			requestType: icmpTypeEchoRequestV6,
			replyType:   icmpTypeEchoReplyV6,
		}, nil
	}
	return pingSocketSpec{}, wrapError(ErrInvalidIP, "invalid ip: %q", ip.String())
}

// localIPAddrForFamily converts sourceIP into a local address for
// net.DialIP after checking it matches the target family (4 or 6).
// A nil sourceIP yields (nil, nil), letting DialIP choose the local address.
// An IPv4-mapped IPv6 source is treated as IPv4.
func localIPAddrForFamily(sourceIP net.IP, family int) (*net.IPAddr, error) {
	if sourceIP == nil {
		return nil, nil
	}
	if sourceIP.To16() == nil {
		return nil, wrapError(ErrInvalidIP, "invalid source ip: %q", sourceIP.String())
	}
	if family == 4 && sourceIP.To4() == nil {
		return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv4 target")
	}
	if family == 6 && sourceIP.To4() != nil {
		return nil, wrapError(ErrInvalidIP, "source ip family mismatch with IPv6 target")
	}
	return &net.IPAddr{IP: sourceIP}, nil
}

// resolvePingTargets resolves host into candidate addresses via
// pingcore.ResolveTargets, returning ErrPingNoResolvedTarget when nothing
// resolved.
func resolvePingTargets(host string, preferIPv4, preferIPv6 bool) ([]*net.IPAddr, error) {
	targets, err := pingcore.ResolveTargets(host, preferIPv4, preferIPv6)
	if err != nil {
		return nil, err
	}
	if len(targets) == 0 {
		return nil, ErrPingNoResolvedTarget
	}
	return targets, nil
}

// orderPingTargets delegates to pingcore.OrderTargets.
func orderPingTargets(targets []*net.IPAddr, preferIPv4, preferIPv6 bool) []*net.IPAddr {
	return pingcore.OrderTargets(targets, preferIPv4, preferIPv6)
}

// normalizePingDialError maps raw-socket dial failures onto package sentinels.
// Permission failures (os.ErrPermission or a matching message) wrap
// ErrPingPermissionDenied. Unsupported network or protocol messages wrap
// ErrPingProtocolUnsupported. Anything else is wrapped as a generic dial
// error. A nil err returns nil.
func normalizePingDialError(err error) error {
	if err == nil {
		return nil
	}
	msg := strings.ToLower(err.Error())
	if errors.Is(err, os.ErrPermission) ||
		strings.Contains(msg, "operation not permitted") ||
		strings.Contains(msg, "permission denied") {
		return fmt.Errorf("%w: %v", ErrPingPermissionDenied, err)
	}
	if strings.Contains(msg, "unknown network") ||
		strings.Contains(msg, "protocol not available") ||
		strings.Contains(msg, "address family not supported by protocol") ||
		strings.Contains(msg, "socket type not supported") {
		return fmt.Errorf("%w: %v", ErrPingProtocolUnsupported, err)
	}
	return wrapError(err, "ping dial")
}

// normalizePingOptions returns a validated copy of opts (opts itself is not
// modified). A nil opts, or a zero Count/Timeout, falls back to defaultCount
// and defaultTimeout. Negative Count or Interval, a non-positive Timeout, an
// out-of-range PayloadSize, and an invalid SourceIP are rejected.
func normalizePingOptions(opts *PingOptions, defaultCount int, defaultTimeout time.Duration) (PingOptions, error) {
	out := PingOptions{
		Count:       defaultCount,
		Timeout:     defaultTimeout,
		Interval:    0,
		PayloadSize: 0,
	}
	if opts != nil {
		out = *opts
		if out.Count == 0 {
			out.Count = defaultCount
		}
		if out.Timeout == 0 {
			out.Timeout = defaultTimeout
		}
	}
	if out.Count < 0 {
		return out, fmt.Errorf("ping count must be >= 0")
	}
	if out.Timeout <= 0 {
		return out, wrapError(ErrPingInvalidTimeout, "timeout must be > 0")
	}
	if out.Interval < 0 {
		return out, fmt.Errorf("ping interval must be >= 0")
	}
	if out.PayloadSize < 0 || out.PayloadSize > maxPingPayloadSize {
		return out, fmt.Errorf("ping payload size must be in [0,%d]", maxPingPayloadSize)
	}
	if out.SourceIP != nil && out.SourceIP.To16() == nil {
		return out, wrapError(ErrInvalidIP, "invalid source ip: %q", out.SourceIP.String())
	}
	return out, nil
}

// pingOnceWithOptions resolves host and sends one echo request (sequence
// number seq, truncated to 16 bits) to each resolved target in turn. It
// returns the first successful result. A permission error is returned
// immediately, because other targets would fail the same way. Otherwise the
// last error is returned, wrapped.
func pingOnceWithOptions(ctx context.Context, host string, seq int, opts PingOptions) (PingResult, error) {
	var res PingResult
	if ctx == nil {
		ctx = context.Background()
	}
	if err := ctx.Err(); err != nil {
		return res, wrapError(err, "ping context done")
	}
	targets, err := resolvePingTargets(host, opts.PreferIPv4, opts.PreferIPv6)
	if err != nil {
		return res, wrapError(err, "resolve ping target")
	}
	payload := pingPayload(opts.PayloadSize)

	var lastErr error
	for _, target := range targets {
		spec, err := socketSpecForIP(target.IP)
		if err != nil {
			lastErr = err
			continue
		}
		icmp := getICMP(uint16(seq), nextPingIdentifier(), spec.requestType, payload)
		resp, err := sendICMPRequest(ctx, icmp, payload, target, opts.SourceIP, spec, opts.Timeout)
		if err == nil {
			return resp, nil
		}
		if errors.Is(err, ErrPingPermissionDenied) {
			return res, err
		}
		lastErr = err
	}
	if lastErr != nil {
		return res, wrapError(lastErr, "ping all resolved targets failed")
	}
	return res, ErrPingNoResolvedTarget
}

// PingWithContext sends one ICMP echo request with context cancel support.
// timeout bounds the wait for the reply and must be > 0.
func PingWithContext(ctx context.Context, host string, seq int, timeout time.Duration) (PingResult, error) {
	opts, err := normalizePingOptions(&PingOptions{
		Count:   1,
		Timeout: timeout,
	}, 1, timeout)
	if err != nil {
		return PingResult{}, err
	}
	if !opts.Deadline.IsZero() {
		var cancel context.CancelFunc
		ctx, cancel = context.WithDeadline(ctx, opts.Deadline)
		defer cancel()
	}
	return pingOnceWithOptions(ctx, host, seq, opts)
}

// Ping sends one ICMP echo request to ip (a host name or address) and waits
// up to timeout for the reply.
func Ping(ip string, seq int, timeout time.Duration) (PingResult, error) {
	return PingWithContext(context.Background(), ip, seq, timeout)
}

// Pingable checks host reachability with retry options.
func Pingable(host string, opts *PingOptions) (bool, error) { cfg, err := normalizePingOptions(opts, defaultPingableCount, defaultPingAttemptTimeout) if err != nil { return false, err } ctx := context.Background() if !cfg.Deadline.IsZero() { var cancel context.CancelFunc ctx, cancel = context.WithDeadline(ctx, cfg.Deadline) defer cancel() } var lastErr error for index := 0; index < cfg.Count; index++ { _, err := pingOnceWithOptions(ctx, host, 29+index, cfg) if err == nil { return true, nil } lastErr = err if errors.Is(err, ErrPingPermissionDenied) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { break } if index < cfg.Count-1 && cfg.Interval > 0 { timer := time.NewTimer(cfg.Interval) select { case <-ctx.Done(): timer.Stop() return false, wrapError(ctx.Err(), "pingable context done") case <-timer.C: } } } if lastErr == nil { lastErr = ErrPingNoResolvedTarget } return false, lastErr } // IsIpPingable keeps backward-compatible bool-only behavior. func IsIpPingable(ip string, timeout time.Duration, retryLimit int) bool { if retryLimit <= 0 { return false } ok, _ := Pingable(ip, &PingOptions{ Count: retryLimit, Timeout: timeout, }) return ok }