package whois

import (
	"context"
	"fmt"
	"io"
	"net"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/net/proxy"
)

const (
	// defWhoisServer is the bootstrap server asked for the registry WHOIS
	// server of a query extension (see resolveRegistryServerContext).
	defWhoisServer = "whois.iana.org"
	// defWhoisPort is used whenever a server reference carries no port.
	defWhoisPort = "43"
	// defTimeout is the per-query timeout used when a client has none set.
	defTimeout = 30 * time.Second
	// defReferralMaxDepth is the referral depth used when
	// QueryOptions.ReferralMaxDepth is zero.
	defReferralMaxDepth = 1
	// maxReferralMaxDepth caps caller-supplied referral depths.
	maxReferralMaxDepth = 16
	// maxWhoisResponseBytes bounds a single server response. Referral targets
	// are taken from untrusted response data, so a hostile server must not be
	// able to make the client buffer an unbounded stream.
	maxWhoisResponseBytes = 8 << 20
	// asnPrefix is the optional (case-insensitive) prefix accepted by IsASN.
	asnPrefix = "AS"
)

// DefaultClient is a ready-to-use Client built by NewClient.
var DefaultClient = NewClient()

// defaultWhoisMap pins query extensions to a WHOIS server ("host" or
// "host:port"), bypassing IANA discovery. It is only read in this file;
// presumably it is populated elsewhere in the package — TODO confirm.
var defaultWhoisMap = map[string]string{}

// whoisNegativeCache stores negativeCacheEntry values keyed by
// negativeCacheKey. It is package-level, so every Client shares it.
var whoisNegativeCache sync.Map

// Client performs WHOIS queries.
//
// extCache maps a query extension to the "host:port" IANA reported for it and
// is guarded by mu. dialer opens connections (nil means a plain net.Dialer).
// timeout bounds each query; a non-positive value means defTimeout.
// negativeCacheTTL is the default negative-cache lifetime (0 disables it).
// elapsed is not referenced in this file.
//
// The Set* methods write fields without locking, so configure a Client
// before sharing it between goroutines.
type Client struct {
	extCache         map[string]string
	mu               sync.Mutex
	dialer           proxy.Dialer
	timeout          time.Duration
	elapsed          time.Duration
	negativeCacheTTL time.Duration
}

// rawQueryResult is the unparsed outcome of a query: the text handed to the
// parser, the server (or "a -> b" referral path) it came from, and the
// detected charset ("mixed" when chained responses disagree).
type rawQueryResult struct {
	Data    string
	Server  string
	Charset string
}

// rawWhoisStep is a single server response within a referral chain.
type rawWhoisStep struct {
	Data    string
	Server  string
	Charset string
}

// negativeCacheEntry is either a cached not-found Result (Result != nil) or a
// cached error described by Code/Server/Reason (Result == nil). It is valid
// until ExpireAt.
type negativeCacheEntry struct {
	Result   *Result
	Domain   string
	Code     ErrorCode
	Server   string
	Reason   string
	ExpireAt time.Time
}

// NewClient returns a Client that dials directly with defTimeout, starts with
// an empty extension cache and has negative caching disabled.
func NewClient() *Client {
	return &Client{
		dialer:           &net.Dialer{Timeout: defTimeout},
		extCache:         make(map[string]string),
		timeout:          defTimeout,
		negativeCacheTTL: 0,
	}
}

// SetDialer sets the dialer used to reach WHOIS servers (for example a proxy
// dialer) and returns c for chaining.
func (c *Client) SetDialer(d proxy.Dialer) *Client {
	c.dialer = d
	return c
}

// SetTimeout sets the per-query timeout, which bounds connection setup as
// well as each write and read, and returns c for chaining. A non-positive
// value falls back to defTimeout.
func (c *Client) SetTimeout(t time.Duration) *Client {
	c.timeout = t
	return c
}

// SetNegativeCacheTTL sets the default lifetime of negative-cache entries
// written by this client and returns c for chaining. Zero or a negative value
// disables negative caching unless QueryOptions.NegativeCacheTTL enables it.
func (c *Client) SetNegativeCacheTTL(ttl time.Duration) *Client {
	c.negativeCacheTTL = ttl
	return c
}

// ClearNegativeCache deletes the given negative-cache keys, or every entry
// when no key is given. The cache is package-wide, so all Clients see this.
func (c *Client) ClearNegativeCache(keys ...string) {
	clearNegativeCache(keys...)
}

// Whois is WhoisContext with context.Background().
func (c *Client) Whois(domain string, servers ...string) (Result, error) {
	return c.WhoisContext(context.Background(), domain, servers...)
}

// WhoisContext queries domain with QueryAuto and the default referral depth.
// When servers are given they are tried in order (first success wins) instead
// of discovering the registry server through IANA.
func (c *Client) WhoisContext(ctx context.Context, domain string, servers ...string) (Result, error) {
	opt := QueryOptions{Level: QueryAuto, ReferralMaxDepth: defReferralMaxDepth}
	if len(servers) > 0 {
		opt.OverrideServers = normalizeServerList(servers)
		if len(opt.OverrideServers) > 0 {
			opt.OverrideServer = opt.OverrideServers[0]
		}
	}
	return c.WhoisWithOptionsContext(ctx, domain, opt)
}

// WhoisWithOptions is WhoisWithOptionsContext with context.Background().
func (c *Client) WhoisWithOptions(domain string, opt QueryOptions) (Result, error) {
	return c.WhoisWithOptionsContext(context.Background(), domain, opt)
}

// WhoisWithOptionsContext queries domain as described by opt, parses the raw
// response and attaches result metadata.
//
// When a negative-cache TTL is in effect and no override servers are set,
// "registry whois server not found" errors and not-found results are cached
// per (level, domain) and served from the cache until they expire.
func (c *Client) WhoisWithOptionsContext(ctx context.Context, domain string, opt QueryOptions) (Result, error) {
	normalizedDomain, err := normalizeQueryDomainInput(domain)
	if err != nil {
		return Result{}, err
	}
	domain = normalizedDomain

	cacheTTL := c.effectiveNegativeCacheTTL(opt)
	cacheKey := ""
	// Answers from explicitly chosen servers are not representative of the
	// default resolution path, so they bypass the negative cache entirely.
	if cacheTTL > 0 && !hasOverrideWhoisServers(opt) {
		cacheKey = negativeCacheKey(domain, opt)
		if cachedResult, cachedErr, ok := loadNegativeCache(cacheKey, time.Now()); ok {
			if cachedErr != nil {
				return Result{}, cachedErr
			}
			return cachedResult, nil
		}
	}

	raw, err := c.queryRawWithLevelContext(ctx, domain, opt)
	if err != nil {
		if cacheKey != "" && IsCode(err, ErrorCodeNoWhoisServer) {
			storeNegativeCache(cacheKey, cacheTTL, negativeCacheEntry{
				Domain:   domain,
				Code:     ErrorCodeNoWhoisServer,
				Server:   "",
				Reason:   "registry whois server not found",
				ExpireAt: time.Now().Add(cacheTTL),
			})
		}
		return Result{}, err
	}

	out, err := parse(domain, raw.Data)
	if err != nil {
		return Result{}, err
	}
	out.meta = buildResultMeta(out, "whois", raw.Server)
	out.meta.Charset = raw.Charset

	if cacheKey != "" && out.meta.ReasonCode == ErrorCodeNotFound {
		resultCopy := cloneResult(out)
		storeNegativeCache(cacheKey, cacheTTL, negativeCacheEntry{
			Result:   &resultCopy,
			Domain:   domain,
			Code:     ErrorCodeNotFound,
			Server:   out.meta.Server,
			Reason:   out.meta.Reason,
			ExpireAt: time.Now().Add(cacheTTL),
		})
	}
	return out, nil
}

// queryRawWithLevel is queryRawWithLevelContext with context.Background(),
// returning only the response text.
func (c *Client) queryRawWithLevel(domain string, opt QueryOptions) (string, error) {
	out, err := c.queryRawWithLevelContext(context.Background(), domain, opt)
	if err != nil {
		return "", err
	}
	return out.Data, nil
}

// queryRawWithLevelContext fetches the raw WHOIS text for domain.
//
// With override servers, they are tried in order and the first answer is
// returned without following referrals. Otherwise the registry server is
// resolved, referrals are followed up to opt.ReferralMaxDepth, and opt.Level
// selects what to return:
//   - QueryRegistryOnly: the first (registry) response;
//   - QueryRegistrarOnly: the last response in the chain;
//   - QueryBoth: all responses joined with section markers;
//   - QueryAuto / anything else: all responses joined without markers.
//
// A chain of one response is always returned verbatim.
func (c *Client) queryRawWithLevelContext(ctx context.Context, domain string, opt QueryOptions) (rawQueryResult, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	normalizedDomain, err := normalizeQueryDomainInput(domain)
	if err != nil {
		return rawQueryResult{}, err
	}
	domain = normalizedDomain

	overrideServers := normalizeServerList(opt.OverrideServers)
	// Trim so a blank OverrideServer is ignored, matching
	// hasOverrideWhoisServers; otherwise it would bypass IANA discovery and
	// fail with "no valid whois server".
	if len(overrideServers) == 0 && strings.TrimSpace(opt.OverrideServer) != "" {
		overrideServers = []string{opt.OverrideServer}
	}
	if len(overrideServers) > 0 {
		return c.rawQueryFallbackContext(ctx, domain, overrideServers)
	}

	registryServer, registryPort, err := c.resolveRegistryServerContext(ctx, domain)
	if err != nil {
		return rawQueryResult{}, err
	}
	referralDepth := effectiveReferralMaxDepth(opt.ReferralMaxDepth)
	chain, err := c.queryWhoisChainContext(ctx, domain, registryServer, registryPort, referralDepth)
	if err != nil {
		return rawQueryResult{}, err
	}
	if len(chain) == 0 {
		return rawQueryResult{}, newWhoisError(ErrorCodeEmptyResponse, domain,
			net.JoinHostPort(registryServer, registryPort), "empty whois response", ErrEmptyResponse)
	}

	if len(chain) == 1 || opt.Level == QueryRegistryOnly {
		first := chain[0]
		return rawQueryResult{Data: first.Data, Server: first.Server, Charset: first.Charset}, nil
	}

	switch opt.Level {
	case QueryRegistrarOnly:
		last := chain[len(chain)-1]
		return rawQueryResult{Data: last.Data, Server: last.Server, Charset: chainCombinedCharset(chain)}, nil
	case QueryBoth:
		return rawQueryResult{
			Data:    formatWhoisChainData(chain, true),
			Server:  chainServerPath(chain),
			Charset: chainCombinedCharset(chain),
		}, nil
	default: // QueryAuto and unrecognised levels.
		return rawQueryResult{
			Data:    formatWhoisChainData(chain, false),
			Server:  chainServerPath(chain),
			Charset: chainCombinedCharset(chain),
		}, nil
	}
}

// resolveRegistryServer is resolveRegistryServerContext with
// context.Background().
func (c *Client) resolveRegistryServer(domain string) (string, string, error) {
	return c.resolveRegistryServerContext(context.Background(), domain)
}

// resolveRegistryServerContext returns the registry WHOIS host and port for
// domain. It consults, in order, defaultWhoisMap, the client's extension
// cache, and finally whois.iana.org, caching the IANA answer per extension.
func (c *Client) resolveRegistryServerContext(ctx context.Context, domain string) (string, string, error) {
	ext := getExtension(domain)
	if v, ok := defaultWhoisMap[ext]; ok {
		if host, port := normalizeWhoisServer(v); host != "" {
			return host, port, nil
		}
	}

	c.mu.Lock()
	cached, ok := c.extCache[ext]
	c.mu.Unlock()
	if ok {
		if host, port, err := net.SplitHostPort(cached); err == nil {
			return host, port, nil
		}
		if host, port := normalizeWhoisServer(cached); host != "" {
			return host, port, nil
		}
	}

	result, _, err := c.rawQueryContext(ctx, ext, defWhoisServer, defWhoisPort)
	if err != nil {
		return "", "", fmt.Errorf("whois: query for whois server failed: %w", err)
	}
	server, port := getServer(result)
	if server == "" {
		return "", "", newWhoisError(ErrorCodeNoWhoisServer, domain,
			net.JoinHostPort(defWhoisServer, defWhoisPort), "registry whois server not found", ErrNoWhoisServer)
	}

	c.mu.Lock()
	// A zero-value Client has no map yet; writing to a nil map would panic.
	if c.extCache == nil {
		c.extCache = make(map[string]string)
	}
	c.extCache[ext] = net.JoinHostPort(server, port)
	c.mu.Unlock()
	return server, port, nil
}

// rawQueryFallback is rawQueryFallbackContext with context.Background(),
// returning only the response text.
func (c *Client) rawQueryFallback(domain string, servers []string) (string, error) {
	out, err := c.rawQueryFallbackContext(context.Background(), domain, servers)
	if err != nil {
		return "", err
	}
	return out.Data, nil
}

// rawQuery is rawQueryContext with context.Background(), dropping the charset.
func (c *Client) rawQuery(domain, server, port string) (string, error) {
	data, _, err := c.rawQueryContext(context.Background(), domain, server, port)
	return data, err
}

// rawQueryFallbackContext queries each usable server in order and returns the
// first successful response. If no entry yields a host it returns a
// no-whois-server error; if all fail it wraps the last failure.
func (c *Client) rawQueryFallbackContext(ctx context.Context, domain string, servers []string) (rawQueryResult, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	var lastErr error
	for _, server := range normalizeServerList(servers) {
		host, port := normalizeWhoisServer(server)
		if host == "" {
			continue
		}
		data, charset, err := c.rawQueryContext(ctx, domain, host, port)
		if err == nil {
			return rawQueryResult{
				Data:    data,
				Server:  net.JoinHostPort(host, port),
				Charset: charset,
			}, nil
		}
		lastErr = fmt.Errorf("%s:%s: %w", host, port, err)
	}
	if lastErr == nil {
		return rawQueryResult{}, newWhoisError(ErrorCodeNoWhoisServer, domain, "", "no valid whois server", ErrNoWhoisServer)
	}
	return rawQueryResult{}, fmt.Errorf("whois: all whois servers failed: %w", lastErr)
}

// rawQueryContext sends one query to server:port and returns the decoded,
// trimmed response (with a query-time footer appended) and its charset.
//
// Cancelling ctx closes the connection, and the call then returns ctx.Err().
// An empty response yields an empty-response error; a response larger than
// maxWhoisResponseBytes is rejected.
func (c *Client) rawQueryContext(ctx context.Context, domain, server, port string) (string, string, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	start := time.Now()

	// ARIN queries are prefixed with a record-type flag: "a + " for ASNs and
	// "n + " for everything else.
	queryText := domain
	if strings.EqualFold(server, "whois.arin.net") {
		if IsASN(domain) {
			queryText = "a + " + domain
		} else {
			queryText = "n + " + domain
		}
	}

	serverAddr := net.JoinHostPort(server, port)
	conn, err := c.dialContext(ctx, "tcp", serverAddr)
	if err != nil {
		if ctx.Err() != nil {
			return "", "", ctx.Err()
		}
		return "", "", wrapNetworkError(domain, serverAddr, err, "connect failed")
	}
	defer conn.Close()

	// Unblock the write/read below as soon as ctx is cancelled. The watcher
	// exits when this function returns (cancelWatchDone is closed by defer).
	cancelWatchDone := make(chan struct{})
	defer close(cancelWatchDone)
	go func() {
		select {
		case <-ctx.Done():
			_ = conn.SetDeadline(time.Now())
			_ = conn.Close()
		case <-cancelWatchDone:
		}
	}()

	if err := conn.SetWriteDeadline(c.effectiveDeadline(ctx)); err != nil {
		return "", "", wrapNetworkError(domain, serverAddr, err, "set write deadline failed")
	}
	if _, err = conn.Write([]byte(queryText + "\r\n")); err != nil {
		if ctx.Err() != nil {
			return "", "", ctx.Err()
		}
		return "", "", wrapNetworkError(domain, serverAddr, err, "write failed")
	}

	if err := conn.SetReadDeadline(c.effectiveDeadline(ctx)); err != nil {
		return "", "", wrapNetworkError(domain, serverAddr, err, "set read deadline failed")
	}
	// Read one byte past the cap so an oversized response is detected instead
	// of being silently truncated.
	buf, err := io.ReadAll(io.LimitReader(conn, maxWhoisResponseBytes+1))
	if err != nil {
		if ctx.Err() != nil {
			return "", "", ctx.Err()
		}
		return "", "", wrapNetworkError(domain, serverAddr, err, "read failed")
	}
	if len(buf) > maxWhoisResponseBytes {
		return "", "", wrapNetworkError(domain, serverAddr,
			fmt.Errorf("response exceeds %d bytes", maxWhoisResponseBytes), "read failed")
	}

	decoded, charset := decodeWhoisPayload(buf)
	data := strings.TrimSpace(decoded)
	if data == "" {
		return "", charset, newWhoisError(ErrorCodeEmptyResponse, domain, serverAddr, "empty whois response", ErrEmptyResponse)
	}
	data = fmt.Sprintf("%s\n\n%% Query time: %d msec\n%% WHEN: %s\n",
		data, time.Since(start).Milliseconds(), start.Format("Mon Jan 02 15:04:05 MST 2006"))
	return data, charset, nil
}

// dialContext opens a connection through the configured dialer (a plain
// net.Dialer when none is set). Connection setup is bounded by ctx and by the
// client timeout. Dialers without DialContext run in a goroutine; if the
// context ends first, a connection that arrives later is closed.
func (c *Client) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	// Apply the client timeout to the dial too: the default net.Dialer's
	// Timeout is fixed when NewClient runs, and custom dialers may have none.
	ctx, cancel := context.WithDeadline(ctx, c.effectiveDeadline(ctx))
	defer cancel()

	d := c.dialer
	if d == nil {
		d = &net.Dialer{Timeout: c.timeout}
	}
	if cd, ok := d.(interface {
		DialContext(context.Context, string, string) (net.Conn, error)
	}); ok {
		return cd.DialContext(ctx, network, address)
	}

	type dialRet struct {
		conn net.Conn
		err  error
	}
	ch := make(chan dialRet)
	abandon := make(chan struct{})
	go func() {
		conn, err := d.Dial(network, address)
		select {
		case <-abandon:
			if conn != nil {
				_ = conn.Close()
			}
		case ch <- dialRet{conn: conn, err: err}:
		}
	}()

	select {
	case <-ctx.Done():
		close(abandon)
		return nil, ctx.Err()
	case ret := <-ch:
		return ret.conn, ret.err
	}
}

// effectiveDeadline returns now+timeout (defTimeout when the client timeout
// is non-positive), or ctx's deadline if that is earlier.
func (c *Client) effectiveDeadline(ctx context.Context) time.Time {
	timeout := c.timeout
	if timeout <= 0 {
		timeout = defTimeout
	}
	deadline := time.Now().Add(timeout)
	if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
		return d
	}
	return deadline
}

// combineCharset merges two charset labels case-insensitively: an empty label
// defers to the other, equal labels are kept, and differing ones give "mixed".
func combineCharset(a, b string) string {
	a = strings.TrimSpace(strings.ToLower(a))
	b = strings.TrimSpace(strings.ToLower(b))
	switch {
	case a == "" && b == "":
		return ""
	case a == "":
		return b
	case b == "":
		return a
	case a == b:
		return a
	default:
		return "mixed"
	}
}

// getExtension returns the key used to discover a query's registry server.
// IP literals, with any CIDR prefix length removed, are returned whole.
// Anything else yields the part after the last ".", cut at the first "/".
func getExtension(domain string) string {
	// Strip a CIDR suffix before the IP check so "192.0.2.0/24" is seen as an
	// IP literal; splitting it on "." first would yield "0".
	base := domain
	if i := strings.IndexByte(base, '/'); i >= 0 {
		base = base[:i]
	}
	if net.ParseIP(base) != nil {
		return base
	}

	ext := domain
	if i := strings.LastIndexByte(ext, '.'); i >= 0 {
		ext = ext[i+1:]
	}
	if i := strings.IndexByte(ext, '/'); i >= 0 {
		ext = ext[:i]
	}
	return ext
}

// getServer returns the first WHOIS server referenced in data under one of
// the keys "registrar whois server", "whois server", "whois",
// "referralserver" or "refer" (matched case-insensitively), normalized by
// normalizeWhoisServer. It returns empty strings when none is found.
func getServer(data string) (string, string) {
	for _, raw := range strings.Split(data, "\n") {
		line := strings.TrimSpace(raw)
		if line == "" {
			continue
		}
		key, val, ok := splitKV(line)
		if !ok {
			continue
		}
		switch strings.ToLower(strings.TrimSpace(key)) {
		case "registrar whois server", "whois server", "whois", "referralserver", "refer":
			if host, port := normalizeWhoisServer(val); host != "" {
				return host, port
			}
		}
	}
	return "", ""
}

// normalizeServerList trims entries, drops empty ones and removes
// case-insensitive duplicates, keeping the first occurrence and input order.
func normalizeServerList(servers []string) []string {
	out := make([]string, 0, len(servers))
	seen := make(map[string]struct{}, len(servers))
	for _, s := range servers {
		s = strings.TrimSpace(s)
		if s == "" {
			continue
		}
		k := strings.ToLower(s)
		if _, ok := seen[k]; ok {
			continue
		}
		seen[k] = struct{}{}
		out = append(out, s)
	}
	return out
}

// normalizeWhoisServer turns a server reference into a lowercase host and a
// port. It accepts a bare host, "host:port", "[ipv6]:port", angle-bracketed
// values, trailing words, and whois://, rwhois:// or whois:/rwhois: prefixed
// forms. The port defaults to defWhoisPort; an empty host means the
// reference is unusable.
func normalizeWhoisServer(raw string) (string, string) {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return "", ""
	}
	raw = strings.Trim(raw, "<>")
	if fields := strings.Fields(raw); len(fields) > 0 {
		raw = fields[0]
	}
	if strings.Contains(raw, "://") {
		if u, err := url.Parse(raw); err == nil && u.Host != "" {
			raw = u.Host
		}
	}
	raw = strings.TrimPrefix(raw, "whois:")
	raw = strings.TrimPrefix(raw, "rwhois:")
	raw = strings.Trim(raw, "/")
	if raw == "" {
		return "", ""
	}

	if host, port, err := net.SplitHostPort(raw); err == nil {
		host = strings.ToLower(strings.TrimSpace(host))
		if host == "" {
			// e.g. ":43" — a port with no host is not a usable server.
			return "", ""
		}
		port = strings.TrimSpace(port)
		if port == "" {
			// e.g. "whois.example.com:" — fall back to the standard port.
			port = defWhoisPort
		}
		return host, port
	}

	// No separable port: a bare host name or IP literal. Remove brackets from
	// a bracketed IPv6 literal so net.JoinHostPort does not double them.
	host := strings.ToLower(raw)
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	if host == "" {
		return "", ""
	}
	return host, defWhoisPort
}

// effectiveReferralMaxDepth maps a requested referral depth to the one used:
// negative disables referrals, zero means defReferralMaxDepth, and larger
// values are capped at maxReferralMaxDepth.
func effectiveReferralMaxDepth(v int) int {
	switch {
	case v < 0:
		return 0
	case v == 0:
		return defReferralMaxDepth
	case v > maxReferralMaxDepth:
		return maxReferralMaxDepth
	default:
		return v
	}
}

// queryWhoisChainContext queries server:port and then follows up to
// referralMaxDepth referrals found by getServer, never revisiting an address.
// A failure of the first query is returned as an error. A failed referral
// ends the chain and the responses gathered so far are returned (best
// effort).
func (c *Client) queryWhoisChainContext(ctx context.Context, domain, server, port string, referralMaxDepth int) ([]rawWhoisStep, error) {
	steps := make([]rawWhoisStep, 0, 2)
	visited := make(map[string]struct{}, referralMaxDepth+1)
	host := server
	p := port
	for {
		addr := strings.ToLower(net.JoinHostPort(host, p))
		if _, seen := visited[addr]; seen {
			break
		}
		visited[addr] = struct{}{}

		data, charset, err := c.rawQueryContext(ctx, domain, host, p)
		if err != nil {
			if len(steps) == 0 {
				return nil, err
			}
			return steps, nil
		}
		steps = append(steps, rawWhoisStep{
			Data:    data,
			Server:  net.JoinHostPort(host, p),
			Charset: charset,
		})

		// steps[0] is the registry; every later step is one referral.
		if len(steps)-1 >= referralMaxDepth {
			break
		}
		refHost, refPort := getServer(data)
		if refHost == "" {
			break
		}
		refAddr := strings.ToLower(net.JoinHostPort(refHost, refPort))
		if _, seen := visited[refAddr]; seen {
			break
		}
		host, p = refHost, refPort
	}
	return steps, nil
}

// formatWhoisChainData joins chain responses in order. With
// withRegistrarMarker, responses are separated by blank lines and each one
// after the first is preceded by a "REGISTRAR WHOIS" (second) or
// "REFERRAL WHOIS #n" (later) header; otherwise they are joined by newlines.
func formatWhoisChainData(chain []rawWhoisStep, withRegistrarMarker bool) string {
	switch len(chain) {
	case 0:
		return ""
	case 1:
		return chain[0].Data
	}

	parts := make([]string, 0, 2*len(chain))
	parts = append(parts, chain[0].Data)
	for i := 1; i < len(chain); i++ {
		if withRegistrarMarker {
			if i == 1 {
				parts = append(parts, "----- REGISTRAR WHOIS -----")
			} else {
				parts = append(parts, fmt.Sprintf("----- REFERRAL WHOIS #%d -----", i+1))
			}
		}
		parts = append(parts, chain[i].Data)
	}
	sep := "\n"
	if withRegistrarMarker {
		sep = "\n\n"
	}
	return strings.Join(parts, sep)
}

// chainServerPath renders the non-empty step servers as "a -> b -> c".
func chainServerPath(chain []rawWhoisStep) string {
	if len(chain) == 0 {
		return ""
	}
	parts := make([]string, 0, len(chain))
	for _, step := range chain {
		if s := strings.TrimSpace(step.Server); s != "" {
			parts = append(parts, s)
		}
	}
	return strings.Join(parts, " -> ")
}

// chainCombinedCharset folds every step's charset through combineCharset.
func chainCombinedCharset(chain []rawWhoisStep) string {
	charset := ""
	for _, step := range chain {
		charset = combineCharset(charset, step.Charset)
	}
	return charset
}

// effectiveNegativeCacheTTL returns the negative-cache TTL for a query: a
// positive opt.NegativeCacheTTL wins, a negative one disables caching, and
// zero defers to the client's TTL (0 for a nil client).
func (c *Client) effectiveNegativeCacheTTL(opt QueryOptions) time.Duration {
	if opt.NegativeCacheTTL > 0 {
		return opt.NegativeCacheTTL
	}
	if opt.NegativeCacheTTL < 0 || c == nil {
		return 0
	}
	if c.negativeCacheTTL > 0 {
		return c.negativeCacheTTL
	}
	return 0
}

// hasOverrideWhoisServers reports whether opt names at least one non-blank
// override server.
func hasOverrideWhoisServers(opt QueryOptions) bool {
	if strings.TrimSpace(opt.OverrideServer) != "" {
		return true
	}
	for _, s := range opt.OverrideServers {
		if strings.TrimSpace(s) != "" {
			return true
		}
	}
	return false
}

// negativeCacheKey builds the cache key from the query level and the
// trimmed, lowercased domain.
func negativeCacheKey(domain string, opt QueryOptions) string {
	return fmt.Sprintf("whois-neg|level=%d|domain=%s", opt.Level, strings.ToLower(strings.TrimSpace(domain)))
}

// loadNegativeCache returns the live entry for key. A cached Result is
// returned as a copy; an error entry is rebuilt as a whois error. Expired or
// malformed entries are deleted and reported as misses (ok == false).
func loadNegativeCache(key string, now time.Time) (Result, error, bool) {
	v, ok := whoisNegativeCache.Load(key)
	if !ok {
		return Result{}, nil, false
	}
	entry, ok := v.(negativeCacheEntry)
	if !ok {
		whoisNegativeCache.Delete(key)
		return Result{}, nil, false
	}
	if !entry.ExpireAt.After(now) {
		whoisNegativeCache.Delete(key)
		return Result{}, nil, false
	}
	if entry.Result != nil {
		return cloneResult(*entry.Result), nil, true
	}
	cause := ErrNoWhoisServer
	if entry.Code == ErrorCodeNotFound {
		cause = ErrNotFound
	}
	return Result{}, newWhoisError(entry.Code, entry.Domain, entry.Server, entry.Reason, cause), true
}

// storeNegativeCache saves entry under key when ttl is positive, filling a
// missing ExpireAt from ttl and storing a private copy of any Result.
func storeNegativeCache(key string, ttl time.Duration, entry negativeCacheEntry) {
	if ttl <= 0 {
		return
	}
	if entry.ExpireAt.IsZero() {
		entry.ExpireAt = time.Now().Add(ttl)
	}
	if entry.Result != nil {
		r := cloneResult(*entry.Result)
		entry.Result = &r
	}
	whoisNegativeCache.Store(key, entry)
}

// clearNegativeCache deletes the given (trimmed, non-empty) keys, or every
// entry when called without keys.
func clearNegativeCache(keys ...string) {
	if len(keys) == 0 {
		whoisNegativeCache.Range(func(k, _ interface{}) bool {
			whoisNegativeCache.Delete(k)
			return true
		})
		return
	}
	for _, key := range keys {
		key = strings.TrimSpace(key)
		if key == "" {
			continue
		}
		whoisNegativeCache.Delete(key)
	}
}

// cloneResult copies in, detaching the statusRaw, nsServers and nsIps slices
// so the copy can be handed out without sharing backing arrays. Other fields
// are copied shallowly.
func cloneResult(in Result) Result {
	out := in
	out.statusRaw = append([]string(nil), in.statusRaw...)
	out.nsServers = append([]string(nil), in.nsServers...)
	out.nsIps = append([]string(nil), in.nsIps...)
	return out
}

// IsASN reports whether s is an autonomous system number, written bare
// ("15169") or with a case-insensitive "AS" prefix ("AS15169"). The number
// must be an unsigned decimal that fits in 32 bits, so signed values such as
// "AS-1" or "+5" are rejected.
func IsASN(s string) bool {
	s = strings.TrimPrefix(strings.ToUpper(s), asnPrefix)
	_, err := strconv.ParseUint(s, 10, 32)
	return err == nil
}