whoissdk/client.go

733 lines
18 KiB
Go
Raw Normal View History

2026-03-19 11:53:07 +08:00
package whois
import (
"context"
"fmt"
"io"
"net"
"net/url"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/net/proxy"
)
const (
	// defWhoisServer is the IANA root WHOIS server queried to discover which
	// registry server handles a given extension.
	defWhoisServer = "whois.iana.org"
	// defWhoisPort is the port assumed when a server address carries none.
	defWhoisPort = "43"
	// defTimeout is the dial timeout used by NewClient and the fallback
	// read/write deadline budget when a Client has no positive timeout.
	defTimeout = 30 * time.Second
	// defReferralMaxDepth is the number of referral hops followed when the
	// caller does not specify one.
	defReferralMaxDepth = 1
	// asnPrefix is the optional (case-insensitive) prefix accepted by IsASN.
	asnPrefix = "AS"
)

var (
	// DefaultClient is a ready-to-use Client created by NewClient.
	DefaultClient = NewClient()
	// defaultWhoisMap maps an extension directly to a WHOIS server, skipping
	// the IANA lookup. It is empty here; presumably filled elsewhere — TODO confirm.
	defaultWhoisMap = map[string]string{}
	// whoisNegativeCache stores negativeCacheEntry values keyed by
	// negativeCacheKey. It is package-level, so all Clients share it.
	whoisNegativeCache sync.Map
)
// Client performs WHOIS lookups. NewClient initializes its extension cache
// and a default dialer.
type Client struct {
// extCache maps a query extension (see getExtension) to the "host:port" of
// its registry WHOIS server as learned from IANA. Guarded by mu.
extCache map[string]string
// mu guards extCache.
mu sync.Mutex
// dialer opens TCP connections; when nil, dialContext falls back to a
// net.Dialer using timeout.
dialer proxy.Dialer
// timeout is the base for per-connection read/write deadlines; non-positive
// values fall back to defTimeout (see effectiveDeadline).
timeout time.Duration
// elapsed is not referenced in the visible code — TODO confirm its purpose.
elapsed time.Duration
// negativeCacheTTL is the default negative-cache lifetime; non-positive
// disables caching unless QueryOptions.NegativeCacheTTL supplies one.
negativeCacheTTL time.Duration
}
// rawQueryResult is the unparsed outcome of a (possibly multi-hop) query.
type rawQueryResult struct {
// Data is the response text; multi-hop results are joined by
// formatWhoisChainData.
Data string
// Server is "host:port", or a " -> "-joined path for multi-hop results.
Server string
// Charset is the charset reported by decodeWhoisPayload, or "mixed" when
// combined hops disagree (see combineCharset).
Charset string
}
// rawWhoisStep records one hop of a referral chain.
type rawWhoisStep struct {
Data string
Server string
Charset string
}
// negativeCacheEntry is what whoisNegativeCache stores for a negative outcome.
type negativeCacheEntry struct {
// Result holds the parsed not-found result; it is nil for cached
// "no whois server" errors, which loadNegativeCache rebuilds from the
// fields below.
Result *Result
Domain string
Code ErrorCode
Server string
Reason string
// ExpireAt is the instant after which the entry is discarded on read.
ExpireAt time.Time
}
// NewClient returns a Client that dials directly over TCP with the default
// timeout and has negative caching disabled (TTL 0).
func NewClient() *Client {
	c := &Client{
		timeout:          defTimeout,
		negativeCacheTTL: 0,
	}
	c.extCache = make(map[string]string)
	c.dialer = &net.Dialer{Timeout: defTimeout}
	return c
}

// SetDialer replaces the dialer used for WHOIS connections (for example a
// proxy dialer) and returns c for chaining.
func (c *Client) SetDialer(d proxy.Dialer) *Client {
	c.dialer = d
	return c
}

// SetTimeout sets the base duration for per-connection read and write
// deadlines and returns c for chaining.
func (c *Client) SetTimeout(t time.Duration) *Client {
	c.timeout = t
	return c
}

// SetNegativeCacheTTL sets the default lifetime of cached negative results
// (not-found answers and missing registry servers). A non-positive value
// disables caching unless a query supplies its own TTL. It returns c.
func (c *Client) SetNegativeCacheTTL(ttl time.Duration) *Client {
	c.negativeCacheTTL = ttl
	return c
}

// ClearNegativeCache removes the given keys from the package-wide negative
// cache, or every entry when called without keys. The cache is shared by all
// Clients, so this affects them as well.
func (c *Client) ClearNegativeCache(keys ...string) {
	clearNegativeCache(keys...)
}
func (c *Client) Whois(domain string, servers ...string) (Result, error) {
return c.WhoisContext(context.Background(), domain, servers...)
}
func (c *Client) WhoisContext(ctx context.Context, domain string, servers ...string) (Result, error) {
opt := QueryOptions{Level: QueryAuto, ReferralMaxDepth: defReferralMaxDepth}
if len(servers) > 0 {
opt.OverrideServers = normalizeServerList(servers)
if len(opt.OverrideServers) > 0 {
opt.OverrideServer = opt.OverrideServers[0]
}
}
return c.WhoisWithOptionsContext(ctx, domain, opt)
}
func (c *Client) WhoisWithOptions(domain string, opt QueryOptions) (Result, error) {
return c.WhoisWithOptionsContext(context.Background(), domain, opt)
}
// WhoisWithOptionsContext looks up domain according to opt and parses the
// response into a Result.
//
// Negative caching applies when effectiveNegativeCacheTTL is positive and opt
// pins no servers. In that case two outcomes are stored under
// negativeCacheKey: "no registry whois server" errors and parsed results
// whose ReasonCode is ErrorCodeNotFound. A live cache entry is returned
// without any network access.
func (c *Client) WhoisWithOptionsContext(ctx context.Context, domain string, opt QueryOptions) (Result, error) {
	name, err := normalizeQueryDomainInput(domain)
	if err != nil {
		return Result{}, err
	}

	ttl := c.effectiveNegativeCacheTTL(opt)
	key := ""
	// Lookups against caller-pinned servers never read or write the cache.
	if ttl > 0 && !hasOverrideWhoisServers(opt) {
		key = negativeCacheKey(name, opt)
		hit, hitErr, found := loadNegativeCache(key, time.Now())
		if found && hitErr != nil {
			return Result{}, hitErr
		}
		if found {
			return hit, nil
		}
	}

	raw, err := c.queryRawWithLevelContext(ctx, name, opt)
	if err != nil {
		if key != "" && IsCode(err, ErrorCodeNoWhoisServer) {
			storeNegativeCache(key, ttl, negativeCacheEntry{
				Domain:   name,
				Code:     ErrorCodeNoWhoisServer,
				Server:   "",
				Reason:   "registry whois server not found",
				ExpireAt: time.Now().Add(ttl),
			})
		}
		return Result{}, err
	}

	res, err := parse(name, raw.Data)
	if err != nil {
		return Result{}, err
	}
	res.meta = buildResultMeta(res, "whois", raw.Server)
	res.meta.Charset = raw.Charset

	if key != "" && res.meta.ReasonCode == ErrorCodeNotFound {
		snapshot := cloneResult(res)
		storeNegativeCache(key, ttl, negativeCacheEntry{
			Result:   &snapshot,
			Domain:   name,
			Code:     ErrorCodeNotFound,
			Server:   res.meta.Server,
			Reason:   res.meta.Reason,
			ExpireAt: time.Now().Add(ttl),
		})
	}
	return res, nil
}
// queryRawWithLevel is queryRawWithLevelContext with a background context,
// returning only the response text.
func (c *Client) queryRawWithLevel(domain string, opt QueryOptions) (string, error) {
	out, err := c.queryRawWithLevelContext(context.Background(), domain, opt)
	if err != nil {
		return "", err
	}
	return out.Data, nil
}

// queryRawWithLevelContext fetches the unparsed WHOIS text for domain.
//
// If opt names override servers, they are tried in order and no referrals
// are followed. Otherwise the registry server is resolved and its referral
// chain is walked (up to the effective referral depth). The result is then
// shaped by opt.Level:
//   - QueryRegistryOnly: the registry answer only (referrals are not fetched).
//   - QueryRegistrarOnly: the last hop's answer, server and charset.
//   - QueryBoth: all hops with section markers; Server is the hop path.
//   - QueryAuto / other: all hops joined without markers.
//
// A single-hop chain is always returned as that hop.
func (c *Client) queryRawWithLevelContext(ctx context.Context, domain string, opt QueryOptions) (rawQueryResult, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	normalizedDomain, err := normalizeQueryDomainInput(domain)
	if err != nil {
		return rawQueryResult{}, err
	}
	domain = normalizedDomain

	overrideServers := normalizeServerList(opt.OverrideServers)
	if len(overrideServers) == 0 && opt.OverrideServer != "" {
		overrideServers = []string{opt.OverrideServer}
	}
	if len(overrideServers) > 0 {
		return c.rawQueryFallbackContext(ctx, domain, overrideServers)
	}

	registryServer, registryPort, err := c.resolveRegistryServerContext(ctx, domain)
	if err != nil {
		return rawQueryResult{}, err
	}
	referralDepth := effectiveReferralMaxDepth(opt.ReferralMaxDepth)
	if opt.Level == QueryRegistryOnly {
		// Only the registry answer is returned, so following referrals would
		// only add network round trips whose data is thrown away.
		referralDepth = 0
	}
	chain, err := c.queryWhoisChainContext(ctx, domain, registryServer, registryPort, referralDepth)
	if err != nil {
		return rawQueryResult{}, err
	}
	if len(chain) == 0 {
		return rawQueryResult{}, newWhoisError(ErrorCodeEmptyResponse, domain, net.JoinHostPort(registryServer, registryPort), "empty whois response", ErrEmptyResponse)
	}

	first := chain[0]
	if len(chain) == 1 || opt.Level == QueryRegistryOnly {
		return rawQueryResult{Data: first.Data, Server: first.Server, Charset: first.Charset}, nil
	}

	switch opt.Level {
	case QueryRegistrarOnly:
		// Charset must describe the returned data, which is the last hop only.
		last := chain[len(chain)-1]
		return rawQueryResult{Data: last.Data, Server: last.Server, Charset: last.Charset}, nil
	case QueryBoth:
		return rawQueryResult{
			Data:    formatWhoisChainData(chain, true),
			Server:  chainServerPath(chain),
			Charset: chainCombinedCharset(chain),
		}, nil
	default: // QueryAuto and any unrecognized level.
		return rawQueryResult{
			Data:    formatWhoisChainData(chain, false),
			Server:  chainServerPath(chain),
			Charset: chainCombinedCharset(chain),
		}, nil
	}
}
// resolveRegistryServer is resolveRegistryServerContext with a background
// context.
func (c *Client) resolveRegistryServer(domain string) (string, string, error) {
	return c.resolveRegistryServerContext(context.Background(), domain)
}

// resolveRegistryServerContext returns the host and port of the registry
// WHOIS server for domain's extension (see getExtension).
//
// Sources are checked in order: defaultWhoisMap, the per-client extCache,
// and finally a live query to IANA. A successful IANA answer is cached in
// extCache. If IANA names no server, an ErrorCodeNoWhoisServer error is
// returned.
func (c *Client) resolveRegistryServerContext(ctx context.Context, domain string) (string, string, error) {
	ext := getExtension(domain)
	if v, ok := defaultWhoisMap[ext]; ok {
		if host, port := normalizeWhoisServer(v); host != "" {
			return host, port, nil
		}
	}

	c.mu.Lock()
	cached, ok := c.extCache[ext] // reading a nil map is safe
	c.mu.Unlock()
	if ok {
		if host, port, err := net.SplitHostPort(cached); err == nil {
			return host, port, nil
		}
		if host, port := normalizeWhoisServer(cached); host != "" {
			return host, port, nil
		}
	}

	result, _, err := c.rawQueryContext(ctx, ext, defWhoisServer, defWhoisPort)
	if err != nil {
		return "", "", fmt.Errorf("whois: query for whois server failed: %w", err)
	}
	server, port := getServer(result)
	if server == "" {
		return "", "", newWhoisError(ErrorCodeNoWhoisServer, domain, net.JoinHostPort(defWhoisServer, defWhoisPort), "registry whois server not found", ErrNoWhoisServer)
	}

	c.mu.Lock()
	if c.extCache == nil {
		// A Client not built by NewClient has no map yet; assigning into a
		// nil map would panic.
		c.extCache = make(map[string]string)
	}
	c.extCache[ext] = net.JoinHostPort(server, port)
	c.mu.Unlock()
	return server, port, nil
}
// rawQueryFallback is rawQueryFallbackContext with a background context,
// returning only the response text.
func (c *Client) rawQueryFallback(domain string, servers []string) (string, error) {
	res, err := c.rawQueryFallbackContext(context.Background(), domain, servers)
	if err != nil {
		return "", err
	}
	return res.Data, nil
}

// rawQuery is rawQueryContext with a background context, discarding the
// charset.
func (c *Client) rawQuery(domain, server, port string) (string, error) {
	text, _, err := c.rawQueryContext(context.Background(), domain, server, port)
	return text, err
}

// rawQueryFallbackContext queries each usable server in servers (after
// normalizeServerList) in order and returns the first successful response.
// If every attempt fails, the last failure is wrapped. If no entry yields a
// host at all, an ErrorCodeNoWhoisServer error is returned.
func (c *Client) rawQueryFallbackContext(ctx context.Context, domain string, servers []string) (rawQueryResult, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	var failure error
	for _, candidate := range normalizeServerList(servers) {
		host, port := normalizeWhoisServer(candidate)
		if host == "" {
			continue
		}
		data, charset, err := c.rawQueryContext(ctx, domain, host, port)
		if err != nil {
			failure = fmt.Errorf("%s:%s: %w", host, port, err)
			continue
		}
		return rawQueryResult{
			Data:    data,
			Server:  net.JoinHostPort(host, port),
			Charset: charset,
		}, nil
	}
	if failure != nil {
		return rawQueryResult{}, fmt.Errorf("whois: all whois servers failed: %w", failure)
	}
	return rawQueryResult{}, newWhoisError(ErrorCodeNoWhoisServer, domain, "", "no valid whois server", ErrNoWhoisServer)
}
// rawQueryContext sends one WHOIS query for domain to server:port.
//
// It returns the decoded, whitespace-trimmed response with query-timing
// lines appended, plus the charset reported by decodeWhoisPayload. Queries to
// whois.arin.net are prefixed with ARIN's record-type flag: "a + " for ASNs,
// "n + " otherwise.
//
// A watcher goroutine closes the connection as soon as ctx is done. For that
// reason, whenever any step fails after ctx has ended, ctx.Err() is returned
// instead of a wrapped network error, so callers can use errors.Is on
// context.Canceled / context.DeadlineExceeded.
func (c *Client) rawQueryContext(ctx context.Context, domain, server, port string) (string, string, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	start := time.Now()
	queryText := domain
	if strings.EqualFold(server, "whois.arin.net") {
		if IsASN(domain) {
			queryText = "a + " + domain
		} else {
			queryText = "n + " + domain
		}
	}
	serverAddr := net.JoinHostPort(server, port)

	// fail prefers the context error: after cancellation the watcher has
	// closed conn, so any later call fails with a generic closed-conn error.
	fail := func(err error, what string) (string, string, error) {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return "", "", ctxErr
		}
		return "", "", wrapNetworkError(domain, serverAddr, err, what)
	}

	conn, err := c.dialContext(ctx, "tcp", serverAddr)
	if err != nil {
		return fail(err, "connect failed")
	}
	defer conn.Close()

	watchDone := make(chan struct{})
	defer close(watchDone)
	go func() {
		select {
		case <-ctx.Done():
			_ = conn.SetDeadline(time.Now())
			_ = conn.Close()
		case <-watchDone:
		}
	}()

	if err := conn.SetWriteDeadline(c.effectiveDeadline(ctx)); err != nil {
		return fail(err, "set write deadline failed")
	}
	if _, err := conn.Write([]byte(queryText + "\r\n")); err != nil {
		return fail(err, "write failed")
	}
	if err := conn.SetReadDeadline(c.effectiveDeadline(ctx)); err != nil {
		return fail(err, "set read deadline failed")
	}
	buf, err := io.ReadAll(conn)
	if err != nil {
		return fail(err, "read failed")
	}

	decoded, charset := decodeWhoisPayload(buf)
	data := strings.TrimSpace(decoded)
	if data == "" {
		return "", charset, newWhoisError(ErrorCodeEmptyResponse, domain, serverAddr, "empty whois response", ErrEmptyResponse)
	}
	data = fmt.Sprintf("%s\n\n%% Query time: %d msec\n%% WHEN: %s\n",
		data, time.Since(start).Milliseconds(), start.Format("Mon Jan 02 15:04:05 MST 2006"))
	return data, charset, nil
}
// dialContext opens a connection through c.dialer, or through a plain
// net.Dialer when none is set. The dial is bounded by both ctx and the client
// timeout.
//
// The client timeout (non-positive falls back to defTimeout, as in
// effectiveDeadline) is applied as a context deadline. This makes SetTimeout
// also limit connection setup; NewClient's dialer otherwise keeps a fixed
// defTimeout regardless of SetTimeout.
//
// Dialers that implement DialContext receive the derived context directly.
// For plain proxy.Dialer implementations, Dial runs in a goroutine; if the
// context ends first, the late connection is closed when it arrives.
func (c *Client) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
	d := c.dialer
	if d == nil {
		d = &net.Dialer{Timeout: c.timeout}
	}

	timeout := c.timeout
	if timeout <= 0 {
		timeout = defTimeout
	}
	ctx, cancel := context.WithTimeout(ctx, timeout)
	// Cancelling after a successful dial is fine: per the net.Dialer contract,
	// context expiry after connecting does not affect the connection.
	defer cancel()

	if cd, ok := d.(interface {
		DialContext(context.Context, string, string) (net.Conn, error)
	}); ok {
		return cd.DialContext(ctx, network, address)
	}

	type dialRet struct {
		conn net.Conn
		err  error
	}
	ch := make(chan dialRet)
	abandon := make(chan struct{})
	go func() {
		conn, err := d.Dial(network, address)
		select {
		case <-abandon:
			if conn != nil {
				_ = conn.Close()
			}
		case ch <- dialRet{conn: conn, err: err}:
		}
	}()
	select {
	case <-ctx.Done():
		close(abandon)
		return nil, ctx.Err()
	case ret := <-ch:
		return ret.conn, ret.err
	}
}
// effectiveDeadline returns now plus the client timeout (defTimeout when the
// timeout is not positive), or ctx's deadline if that comes strictly earlier.
func (c *Client) effectiveDeadline(ctx context.Context) time.Time {
	budget := defTimeout
	if c.timeout > 0 {
		budget = c.timeout
	}
	limit := time.Now().Add(budget)
	ctxDeadline, has := ctx.Deadline()
	if !has || !ctxDeadline.Before(limit) {
		return limit
	}
	return ctxDeadline
}
// combineCharset merges two charset labels after lower-casing and trimming
// them. An empty label defers to the other, equal labels collapse to one,
// and two different non-empty labels yield "mixed".
func combineCharset(a, b string) string {
	left := strings.TrimSpace(strings.ToLower(a))
	right := strings.TrimSpace(strings.ToLower(b))
	if left == right || right == "" {
		return left
	}
	if left == "" {
		return right
	}
	return "mixed"
}
// getExtension returns the key used to locate the registry WHOIS server for
// domain. Anything from the first '/' onward is removed first, so CIDR
// notation and path suffixes are ignored. An IP literal is then returned
// whole; otherwise the text after the last '.' is returned (the whole input
// when it has no dot, e.g. "AS15169").
//
// Stripping the '/' suffix before IP detection matters: "192.0.2.0/24" is
// not a parseable IP, and splitting it on dots yielded the bogus key "0".
func getExtension(domain string) string {
	host := domain
	if i := strings.IndexByte(host, '/'); i >= 0 {
		host = host[:i]
	}
	if net.ParseIP(host) != nil {
		return host
	}
	if i := strings.LastIndexByte(host, '.'); i >= 0 {
		return host[i+1:]
	}
	return host
}
// getServer scans WHOIS response text line by line and returns the
// host/port of the first referral field that normalizes to a usable server.
// Recognized keys (case-insensitive): "registrar whois server",
// "whois server", "whois", "referralserver" and "refer". It returns empty
// strings when none is found.
func getServer(data string) (string, string) {
	for _, line := range strings.Split(data, "\n") {
		trimmed := strings.TrimSpace(line)
		if trimmed == "" {
			continue
		}
		key, val, ok := splitKV(trimmed)
		if !ok {
			continue
		}
		switch strings.ToLower(strings.TrimSpace(key)) {
		case "registrar whois server", "whois server", "whois", "referralserver", "refer":
		default:
			continue
		}
		if host, port := normalizeWhoisServer(val); host != "" {
			return host, port
		}
	}
	return "", ""
}
// normalizeServerList trims each entry, drops blanks, and removes
// case-insensitive duplicates while keeping the first spelling and the
// original order. The result is never nil.
func normalizeServerList(servers []string) []string {
	result := make([]string, 0, len(servers))
	seen := make(map[string]bool, len(servers))
	for i := range servers {
		entry := strings.TrimSpace(servers[i])
		key := strings.ToLower(entry)
		if entry == "" || seen[key] {
			continue
		}
		seen[key] = true
		result = append(result, entry)
	}
	return result
}
// normalizeWhoisServer turns a loosely formatted server reference into a
// lower-cased host and a port (defWhoisPort when none is given). It accepts
// bare hosts, "host:port", bracketed IPv6, URLs such as
// "rwhois://host:4321", and "whois:"/"rwhois:" prefixes. Angle brackets and
// anything after the first whitespace are ignored.
//
// It returns ("", "") when no host can be extracted. Callers treat an empty
// host as "not a server". In particular, ":43" is rejected and "host:" is
// read as "host" on the default port; previously both were returned verbatim
// as hosts, which net.JoinHostPort turned into undialable addresses.
func normalizeWhoisServer(raw string) (string, string) {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return "", ""
	}
	raw = strings.Trim(raw, "<>")
	if fields := strings.Fields(raw); len(fields) > 0 {
		raw = fields[0]
	}
	if strings.Contains(raw, "://") {
		if u, err := url.Parse(raw); err == nil && u.Host != "" {
			raw = u.Host
		}
	}
	raw = strings.TrimPrefix(raw, "whois:")
	raw = strings.TrimPrefix(raw, "rwhois:")
	raw = strings.Trim(raw, "/")
	if raw == "" {
		return "", ""
	}
	if host, port, err := net.SplitHostPort(raw); err == nil {
		host = strings.TrimSpace(strings.ToLower(host))
		port = strings.TrimSpace(port)
		switch {
		case host == "":
			return "", ""
		case port == "":
			return host, defWhoisPort
		default:
			return host, port
		}
	}
	// SplitHostPort can still fail on a single colon with stray brackets;
	// accept "host:digits" in that case.
	if strings.Count(raw, ":") == 1 {
		v := strings.SplitN(raw, ":", 2)
		host := strings.TrimSpace(strings.ToLower(v[0]))
		port := strings.TrimSpace(v[1])
		if host != "" && port != "" {
			if _, err := strconv.Atoi(port); err == nil {
				return host, port
			}
		}
	}
	return strings.ToLower(raw), defWhoisPort
}
// effectiveReferralMaxDepth maps a requested referral depth to the one
// actually used: negative disables referrals (0), zero selects
// defReferralMaxDepth, and values above 16 are capped at 16.
func effectiveReferralMaxDepth(v int) int {
	switch {
	case v < 0:
		return 0
	case v == 0:
		return defReferralMaxDepth
	case v > 16:
		return 16
	default:
		return v
	}
}
// queryWhoisChainContext queries server:port for domain and then follows
// referrals (see getServer) for at most referralMaxDepth extra hops.
//
// The walk stops early when a response names no referral, or when the
// referral points at an already-queried address (compared case-insensitively).
// Only a failure on the very first query is returned as an error. A failure
// on a later hop is swallowed, and the steps collected so far are returned
// with a nil error.
func (c *Client) queryWhoisChainContext(ctx context.Context, domain, server, port string, referralMaxDepth int) ([]rawWhoisStep, error) {
	chain := make([]rawWhoisStep, 0, 2)
	queried := make(map[string]struct{}, referralMaxDepth+1)
	curHost, curPort := server, port
	for {
		queried[strings.ToLower(net.JoinHostPort(curHost, curPort))] = struct{}{}

		data, charset, err := c.rawQueryContext(ctx, domain, curHost, curPort)
		if err != nil {
			if len(chain) > 0 {
				return chain, nil
			}
			return nil, err
		}
		chain = append(chain, rawWhoisStep{
			Data:    data,
			Server:  net.JoinHostPort(curHost, curPort),
			Charset: charset,
		})

		// The first step is the registry itself, so hops = len(chain) - 1.
		if len(chain) > referralMaxDepth {
			return chain, nil
		}
		nextHost, nextPort := getServer(data)
		if nextHost == "" {
			return chain, nil
		}
		if _, dup := queried[strings.ToLower(net.JoinHostPort(nextHost, nextPort))]; dup {
			return chain, nil
		}
		curHost, curPort = nextHost, nextPort
	}
}
func formatWhoisChainData(chain []rawWhoisStep, withRegistrarMarker bool) string {
if len(chain) == 0 {
return ""
}
if len(chain) == 1 {
return chain[0].Data
}
parts := make([]string, 0, len(chain))
parts = append(parts, chain[0].Data)
for i := 1; i < len(chain); i++ {
if withRegistrarMarker {
if i == 1 {
parts = append(parts, "----- REGISTRAR WHOIS -----")
} else {
parts = append(parts, fmt.Sprintf("----- REFERRAL WHOIS #%d -----", i+1))
}
}
parts = append(parts, chain[i].Data)
}
sep := "\n"
if withRegistrarMarker {
sep = "\n\n"
}
return strings.Join(parts, sep)
}
func chainServerPath(chain []rawWhoisStep) string {
if len(chain) == 0 {
return ""
}
parts := make([]string, 0, len(chain))
for _, step := range chain {
s := strings.TrimSpace(step.Server)
if s == "" {
continue
}
parts = append(parts, s)
}
return strings.Join(parts, " -> ")
}
// chainCombinedCharset folds the Charset of every step through
// combineCharset. It returns "" for an empty chain and "mixed" as soon as
// two non-empty labels differ.
func chainCombinedCharset(chain []rawWhoisStep) string {
	combined := ""
	for i := range chain {
		combined = combineCharset(combined, chain[i].Charset)
	}
	return combined
}
// effectiveNegativeCacheTTL returns the negative-cache lifetime for a query.
// A positive opt.NegativeCacheTTL wins; a negative one disables caching.
// Zero defers to the client's TTL when c is non-nil and that TTL is
// positive; in every other case it returns 0 (caching off).
func (c *Client) effectiveNegativeCacheTTL(opt QueryOptions) time.Duration {
	if ttl := opt.NegativeCacheTTL; ttl != 0 {
		if ttl > 0 {
			return ttl
		}
		return 0
	}
	if c == nil || c.negativeCacheTTL <= 0 {
		return 0
	}
	return c.negativeCacheTTL
}

// hasOverrideWhoisServers reports whether opt pins at least one non-blank
// server, either via OverrideServer or OverrideServers.
func hasOverrideWhoisServers(opt QueryOptions) bool {
	nonBlank := func(s string) bool { return strings.TrimSpace(s) != "" }
	if nonBlank(opt.OverrideServer) {
		return true
	}
	for _, s := range opt.OverrideServers {
		if nonBlank(s) {
			return true
		}
	}
	return false
}

// negativeCacheKey builds the negative-cache key from the query level and
// the trimmed, lower-cased domain.
func negativeCacheKey(domain string, opt QueryOptions) string {
	normalized := strings.ToLower(strings.TrimSpace(domain))
	return fmt.Sprintf("whois-neg|level=%d|domain=%s", opt.Level, normalized)
}
// loadNegativeCache looks up key in the package-wide negative cache.
// The final bool reports a hit. On a hit, it returns either a private copy of
// the cached not-found Result, or a freshly built error for entries that hold
// no Result. Entries of an unexpected type, or whose ExpireAt is not after
// now, are deleted and reported as misses.
func loadNegativeCache(key string, now time.Time) (Result, error, bool) {
	stored, found := whoisNegativeCache.Load(key)
	if !found {
		return Result{}, nil, false
	}
	entry, valid := stored.(negativeCacheEntry)
	if !valid || !entry.ExpireAt.After(now) {
		whoisNegativeCache.Delete(key)
		return Result{}, nil, false
	}
	if entry.Result != nil {
		return cloneResult(*entry.Result), nil, true
	}
	cause := ErrNoWhoisServer
	if entry.Code == ErrorCodeNotFound {
		cause = ErrNotFound
	}
	return Result{}, newWhoisError(entry.Code, entry.Domain, entry.Server, entry.Reason, cause), true
}

// storeNegativeCache saves entry under key unless ttl is not positive.
// A zero ExpireAt is filled in as now+ttl. Any Result is cloned first, so
// later changes to the caller's value do not reach the cache.
func storeNegativeCache(key string, ttl time.Duration, entry negativeCacheEntry) {
	if ttl <= 0 {
		return
	}
	if entry.ExpireAt.IsZero() {
		entry.ExpireAt = time.Now().Add(ttl)
	}
	if src := entry.Result; src != nil {
		snapshot := cloneResult(*src)
		entry.Result = &snapshot
	}
	whoisNegativeCache.Store(key, entry)
}
// clearNegativeCache deletes the given keys (trimmed; blank keys are
// ignored) from the package-wide negative cache. With no keys, it empties
// the cache.
func clearNegativeCache(keys ...string) {
	if len(keys) > 0 {
		for _, k := range keys {
			if trimmed := strings.TrimSpace(k); trimmed != "" {
				whoisNegativeCache.Delete(trimmed)
			}
		}
		return
	}
	whoisNegativeCache.Range(func(k, _ interface{}) bool {
		whoisNegativeCache.Delete(k)
		return true
	})
}

// cloneResult returns a copy of in whose statusRaw, nsServers and nsIps
// slices have their own backing arrays (empty slices become nil). Other
// reference-typed fields of Result, if any, remain shared — TODO confirm.
func cloneResult(in Result) Result {
	dup := in
	dup.statusRaw = append([]string(nil), in.statusRaw...)
	dup.nsServers = append([]string(nil), in.nsServers...)
	dup.nsIps = append([]string(nil), in.nsIps...)
	return dup
}
// IsASN reports whether s is an autonomous system number. It accepts a bare
// decimal ("15169") or one with a case-insensitive "AS" prefix ("AS15169",
// "as15169").
//
// Only unsigned values that fit in 32 bits are accepted, since ASNs are
// 32-bit per RFC 6793. Signed or out-of-range inputs such as "-1", "AS+5" or
// "AS4294967296" are rejected; the previous strconv.Atoi check accepted them.
func IsASN(s string) bool {
	digits := strings.TrimPrefix(strings.ToUpper(s), asnPrefix)
	_, err := strconv.ParseUint(digits, 10, 32)
	return err == nil
}