package dns

import (
	"encoding/base64"
	"errors"
	"fmt"
	"github.com/miekg/dns"
	"io"
	"net"
	"net/http"
	"time"
)

// Result bundles the outcome of a single DNS query performed by QueryDns.
type Result struct {
	// Res is the response message returned by the client's Exchange call.
	Res *dns.Msg
	// Str is the text rendering of Res (Res.String()) followed by a
	// ";; RTT:" section giving the round-trip time in milliseconds.
	Str string
}

// DnsClient is the transport abstraction QueryDns talks to. In this file it
// is satisfied by *dns.Client (plain, TCP and TLS transports) and by
// *DoHClient (DNS over HTTPS).
type DnsClient interface {
	// Exchange sends req to address and returns the response r, the
	// round-trip time rtt, and any error that occurred.
	Exchange(req *dns.Msg, address string) (r *dns.Msg, rtt time.Duration, err error)
}

// QueryDns sends one DNS question for domain to a DNS server and returns the
// raw response together with a printable rendering of it.
//
// Parameters:
//   - domain: the name to query; it is made fully qualified with dns.Fqdn.
//   - queryType: upper-case record type mnemonic such as "A", "MX" or "TXT".
//     Unsupported values yield an error and no query is sent.
//   - serverType: transport selector. 1 = TCP, 2 = DNS over TLS ("tcp-tls"),
//     3 = DNS over HTTPS; any other value uses a default dns.Client (UDP per
//     the miekg/dns documentation).
//   - dnsServer: "host:port" for serverType 0-2, or the full endpoint URL for
//     serverType 3. When empty, a default matching the transport is used
//     (see defaultServerFor).
//
// On success Result.Res holds the response and Result.Str holds
// Res.String() followed by the round-trip time in milliseconds.
func QueryDns(domain string, queryType string, serverType int, dnsServer string) (Result, error) {
	qtype, ok := queryTypeCode(queryType)
	if !ok {
		return Result{}, errors.New("not support query type,only support A,CNAME,MX,NS,TXT,SOA,SRV,AAAA,PTR,ANY,CAA,TLSA,DS,DNSKEY,NSEC,NSEC3,NSEC3PARAM,RRSIG,SPF,SSHFP,TKEY,TSIG,URI")
	}
	if dnsServer == "" {
		dnsServer = defaultServerFor(serverType)
	}

	var c DnsClient
	switch serverType {
	case 1:
		c = &dns.Client{Net: "tcp"}
	case 2:
		c = &dns.Client{
			Net: "tcp-tls",
			Dialer: &net.Dialer{
				Resolver: net.DefaultResolver,
			},
		}
	case 3:
		c = NewDoHClient(WithTimeout(10 * time.Second))
	default:
		c = new(dns.Client)
	}

	m := new(dns.Msg)
	m.SetQuestion(dns.Fqdn(domain), qtype)

	r, rtt, err := c.Exchange(m, dnsServer)
	if err != nil {
		return Result{}, err
	}
	return Result{
		Res: r,
		Str: fmt.Sprintf("%s\n;; RTT:\n%d milliseconds", r.String(), rtt.Milliseconds()),
	}, nil
}

// defaultServerFor returns the server used when the caller passes an empty
// dnsServer. Each transport needs a different address form: a plain
// "host:port" on port 53 would fail a TLS handshake and is not a valid URL
// for the DoH client.
//
// NOTE(review): the TLS and HTTPS defaults assume the 223.5.5.5 public
// resolver serves DoT on port 853 and DoH at /dns-query — confirm against
// the provider's documentation.
func defaultServerFor(serverType int) string {
	switch serverType {
	case 2:
		return "223.5.5.5:853"
	case 3:
		return "https://223.5.5.5/dns-query"
	default:
		return "223.5.5.5:53"
	}
}

// queryTypeCode maps a record type mnemonic to its dns.Type* code. The
// second result is false when the mnemonic is not one QueryDns supports.
func queryTypeCode(queryType string) (uint16, bool) {
	switch queryType {
	case "A":
		return dns.TypeA, true
	case "CNAME":
		return dns.TypeCNAME, true
	case "MX":
		return dns.TypeMX, true
	case "NS":
		return dns.TypeNS, true
	case "TXT":
		return dns.TypeTXT, true
	case "SOA":
		return dns.TypeSOA, true
	case "SRV":
		return dns.TypeSRV, true
	case "AAAA":
		return dns.TypeAAAA, true
	case "PTR":
		return dns.TypePTR, true
	case "ANY":
		return dns.TypeANY, true
	case "CAA":
		return dns.TypeCAA, true
	case "TLSA":
		return dns.TypeTLSA, true
	case "DS":
		return dns.TypeDS, true
	case "DNSKEY":
		return dns.TypeDNSKEY, true
	case "NSEC":
		return dns.TypeNSEC, true
	case "NSEC3":
		return dns.TypeNSEC3, true
	case "NSEC3PARAM":
		return dns.TypeNSEC3PARAM, true
	case "RRSIG":
		return dns.TypeRRSIG, true
	case "SPF":
		return dns.TypeSPF, true
	case "SSHFP":
		return dns.TypeSSHFP, true
	case "TKEY":
		return dns.TypeTKEY, true
	case "TSIG":
		return dns.TypeTSIG, true
	case "URI":
		return dns.TypeURI, true
	default:
		return 0, false
	}
}

// DoHMediaType is the media type DoHClient places in the Accept header of
// its DNS-over-HTTPS requests.
const DoHMediaType = "application/dns-message"

// clientOptions holds the settings that ClientOption functions can adjust
// before NewDoHClient builds a DoHClient.
type clientOptions struct {
	Timeout time.Duration // Timeout for one DNS query; passed to the underlying http.Client
}

// ClientOption is a functional option applied to a clientOptions value by
// NewDoHClient. NewDoHClient does not act on the returned error.
type ClientOption func(*clientOptions) error

// WithTimeout builds a ClientOption that stores t as the per-query timeout.
// The resulting option never fails: it always returns a nil error.
func WithTimeout(t time.Duration) ClientOption {
	var setTimeout ClientOption = func(cfg *clientOptions) error {
		cfg.Timeout = t
		return nil
	}
	return setTimeout
}

// DoHClient is a DNS client that exchanges messages with a server over HTTP
// (DNS over HTTPS). Create one with NewDoHClient.
type DoHClient struct {
	// opt holds the options collected by NewDoHClient.
	opt *clientOptions
	// cli is the HTTP client used to send each query.
	cli *http.Client
}

// NewDoHClient builds a DoHClient, applying opts in order to a fresh set of
// options. The resulting Timeout is handed to the underlying http.Client; if
// no option sets it, the timeout is zero, which net/http treats as
// "no timeout".
//
// Nil options are skipped rather than causing a panic. Errors returned by
// options are discarded because this constructor has no error result.
func NewDoHClient(opts ...ClientOption) *DoHClient {
	o := new(clientOptions)
	for _, apply := range opts {
		if apply == nil {
			continue
		}
		// Deliberately ignored: there is no way to report it to the caller
		// without changing the constructor's signature.
		_ = apply(o)
	}
	return &DoHClient{
		opt: o,
		cli: &http.Client{
			Timeout: o.Timeout,
		},
	}
}

// Exchange sends req to the DNS-over-HTTPS endpoint at address with an HTTP
// GET request and returns the decoded response.
//
// address is the full endpoint URL (for example "https://host/dns-query").
// The packed message is appended as the base64url "dns" query parameter; a
// query string already present in address is preserved.
//
// The message is packed with ID 0 according to RFC 8484 (cache friendly).
// req.Id is restored immediately after packing, so the caller's message is
// left unchanged, and the original ID is copied into the response.
//
// Returns:
//   - r: the response message. It is nil on every error except an unpack
//     failure, where the partially decoded message is returned with err.
//   - rtt: time elapsed from the start of the call until the response was
//     unpacked; zero when r is nil.
//   - err: a packing, URL, transport or body-read error, an error carrying
//     the response body when the HTTP status is not 200, or the unpack error.
func (c *DoHClient) Exchange(req *dns.Msg, address string) (r *dns.Msg, rtt time.Duration, err error) {
	begin := time.Now()

	// Pack with ID 0, then put the caller's ID back straight away so req is
	// never left mutated, whichever path returns.
	origID := req.Id
	req.Id = 0
	buf, err := req.Pack()
	req.Id = origID
	if err != nil {
		return nil, 0, err
	}

	hreq, err := http.NewRequest(http.MethodGet, address, nil)
	if err != nil {
		// An unparsable address used to cause a nil-pointer panic below.
		return nil, 0, fmt.Errorf("building DoH request for %q: %w", address, err)
	}
	// base64url output needs no further escaping, so RawQuery is set directly.
	query := "dns=" + base64.RawURLEncoding.EncodeToString(buf)
	if hreq.URL.RawQuery != "" {
		query = hreq.URL.RawQuery + "&" + query
	}
	hreq.URL.RawQuery = query
	hreq.Header.Set("User-Agent", "B612 DoH Client")
	hreq.Header.Add("Accept", DoHMediaType)

	resp, err := c.cli.Do(hreq)
	if err != nil {
		return nil, 0, err
	}
	defer resp.Body.Close()

	content, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, 0, err
	}
	if resp.StatusCode != http.StatusOK {
		return nil, 0, errors.New("DoH query failed: " + string(content))
	}

	r = new(dns.Msg)
	err = r.Unpack(content)
	r.Id = origID
	rtt = time.Since(begin)
	return r, rtt, err
}