341 lines
7.2 KiB
Go
341 lines
7.2 KiB
Go
|
|
package starnet
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"crypto/tls"
|
||
|
|
"net"
|
||
|
|
"net/http/httptrace"
|
||
|
|
"sync/atomic"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
// traceContextKey is the context key under which withTraceState stores a
// *traceState and from which getTraceState reads it back. Because it is an
// unexported empty struct type, no other package can construct an equal key.
type traceContextKey struct{}
|
||
|
|
|
||
|
|
// TraceHooks defines optional callbacks for network lifecycle events.
// Hooks may be called concurrently.
//
// Any field may be nil; nil hooks are never invoked. The connection, DNS,
// dial, TLS and request-write hooks are driven by the net/http/httptrace
// callbacks of the same name, while the Retry* hooks are not httptrace
// events and are invoked through the emitRetry* helpers.
type TraceHooks struct {
	// GetConn mirrors httptrace.ClientTrace.GetConn; Addr is the host:port
	// the transport is about to obtain a connection for.
	GetConn func(TraceGetConnInfo)

	// GotConn mirrors httptrace.ClientTrace.GotConn.
	GotConn func(TraceGotConnInfo)

	// PutIdleConn mirrors httptrace.ClientTrace.PutIdleConn.
	PutIdleConn func(TracePutIdleConnInfo)

	// DNSStart mirrors httptrace.ClientTrace.DNSStart. The httptrace event
	// is dropped while manual DNS resolution is active for the request
	// (traceState.usesManualDNS).
	DNSStart func(TraceDNSStartInfo)

	// DNSDone mirrors httptrace.ClientTrace.DNSDone, with the same
	// manual-DNS suppression as DNSStart.
	DNSDone func(TraceDNSDoneInfo)

	// ConnectStart mirrors httptrace.ClientTrace.ConnectStart.
	ConnectStart func(TraceConnectStartInfo)

	// ConnectDone mirrors httptrace.ClientTrace.ConnectDone.
	ConnectDone func(TraceConnectDoneInfo)

	// TLSHandshakeStart mirrors httptrace.ClientTrace.TLSHandshakeStart.
	// The httptrace event is dropped once the request has been flagged as
	// using a custom TLS handshake (traceState.usesCustomTLS).
	TLSHandshakeStart func(TraceTLSHandshakeStartInfo)

	// TLSHandshakeDone mirrors httptrace.ClientTrace.TLSHandshakeDone, with
	// the same custom-TLS suppression as TLSHandshakeStart.
	TLSHandshakeDone func(TraceTLSHandshakeDoneInfo)

	// WroteHeaderField mirrors httptrace.ClientTrace.WroteHeaderField.
	WroteHeaderField func(TraceWroteHeaderFieldInfo)

	// WroteHeaders is installed as-is as httptrace.ClientTrace.WroteHeaders.
	WroteHeaders func()

	// WroteRequest mirrors httptrace.ClientTrace.WroteRequest.
	WroteRequest func(TraceWroteRequestInfo)

	// GotFirstResponseByte is installed as-is as
	// httptrace.ClientTrace.GotFirstResponseByte.
	GotFirstResponseByte func()

	// RetryAttemptStart is invoked through emitRetryAttemptStart; it is not
	// an httptrace event.
	RetryAttemptStart func(TraceRetryAttemptStartInfo)

	// RetryAttemptDone is invoked through emitRetryAttemptDone.
	RetryAttemptDone func(TraceRetryAttemptDoneInfo)

	// RetryBackoff is invoked through emitRetryBackoff.
	RetryBackoff func(TraceRetryBackoffInfo)
}
|
||
|
|
|
||
|
|
// TraceGetConnInfo is passed to TraceHooks.GetConn.
type TraceGetConnInfo struct {
	// Addr is the target host:port; on the httptrace path it is the
	// hostPort argument of httptrace.ClientTrace.GetConn.
	Addr string
}
|
||
|
|
|
||
|
|
// TraceGotConnInfo is passed to TraceHooks.GotConn. On the httptrace path
// each field is copied from the matching httptrace.GotConnInfo field.
type TraceGotConnInfo struct {
	// Conn is the connection that was obtained. Per the httptrace docs it
	// is owned by the http.Transport and must not be read, written or
	// closed by the hook.
	Conn net.Conn

	// Reused mirrors httptrace.GotConnInfo.Reused (the connection served
	// an earlier request).
	Reused bool

	// WasIdle mirrors httptrace.GotConnInfo.WasIdle (the connection came
	// from the idle pool).
	WasIdle bool

	// IdleTime mirrors httptrace.GotConnInfo.IdleTime (how long the
	// connection had been idle, when WasIdle is true).
	IdleTime time.Duration
}
|
||
|
|
|
||
|
|
// TracePutIdleConnInfo is passed to TraceHooks.PutIdleConn.
type TracePutIdleConnInfo struct {
	// Err is the error httptrace reported when returning the connection to
	// the idle pool; per the httptrace docs, nil means it was returned.
	Err error
}
|
||
|
|
|
||
|
|
// TraceDNSStartInfo is passed to TraceHooks.DNSStart.
type TraceDNSStartInfo struct {
	// Host is the name being resolved (httptrace.DNSStartInfo.Host on the
	// httptrace path).
	Host string
}
|
||
|
|
|
||
|
|
// TraceDNSDoneInfo is passed to TraceHooks.DNSDone.
type TraceDNSDoneInfo struct {
	// Addrs holds the resolved addresses. On the httptrace path this is a
	// fresh copy of httptrace.DNSDoneInfo.Addrs, so the hook may keep it.
	Addrs []net.IPAddr

	// Coalesced mirrors httptrace.DNSDoneInfo.Coalesced (the lookup was
	// shared with another concurrent lookup of the same name).
	Coalesced bool

	// Err is the lookup error, if any.
	Err error
}
|
||
|
|
|
||
|
|
// TraceConnectStartInfo is passed to TraceHooks.ConnectStart.
type TraceConnectStartInfo struct {
	// Network and Addr are the dial arguments reported by
	// httptrace.ClientTrace.ConnectStart (for example "tcp" and an ip:port).
	Network string
	Addr    string
}
|
||
|
|
|
||
|
|
// TraceConnectDoneInfo is passed to TraceHooks.ConnectDone.
type TraceConnectDoneInfo struct {
	// Network and Addr identify the dial that finished, as reported by
	// httptrace.ClientTrace.ConnectDone.
	Network string
	Addr    string

	// Err is the dial error, or nil if the connection was established.
	Err error
}
|
||
|
|
|
||
|
|
// TraceTLSHandshakeStartInfo is passed to TraceHooks.TLSHandshakeStart.
//
// httptrace's TLSHandshakeStart callback carries no arguments, so when the
// event originates there every field is empty; the fields can only be
// populated when the event is emitted directly via traceState.tlsHandshakeStart.
type TraceTLSHandshakeStartInfo struct {
	// Network and Addr presumably describe the underlying connection and
	// ServerName the SNI host name — TODO confirm against the direct caller.
	Network    string
	Addr       string
	ServerName string
}
|
||
|
|
|
||
|
|
// TraceTLSHandshakeDoneInfo is passed to TraceHooks.TLSHandshakeDone.
// Network, Addr and ServerName may be empty when the event comes from the
// httptrace callback, which only supplies the connection state and error.
type TraceTLSHandshakeDoneInfo struct {
	// Network, Addr and ServerName have the same (presumed) meaning as in
	// TraceTLSHandshakeStartInfo.
	Network    string
	Addr       string
	ServerName string

	// ConnectionState is the TLS connection state reported for the
	// handshake.
	ConnectionState tls.ConnectionState

	// Err is the handshake error, or nil on success.
	Err error
}
|
||
|
|
|
||
|
|
// TraceWroteHeaderFieldInfo is passed to TraceHooks.WroteHeaderField.
type TraceWroteHeaderFieldInfo struct {
	// Key is the header name written by the transport.
	Key string

	// Values are the values written for Key. Hooks should treat the slice
	// as read-only.
	Values []string
}
|
||
|
|
|
||
|
|
// TraceWroteRequestInfo is passed to TraceHooks.WroteRequest.
type TraceWroteRequestInfo struct {
	// Err is httptrace.WroteRequestInfo.Err: any error encountered while
	// writing the request, or nil.
	Err error
}
|
||
|
|
|
||
|
|
// TraceRetryAttemptStartInfo is passed to TraceHooks.RetryAttemptStart.
type TraceRetryAttemptStartInfo struct {
	// Attempt identifies the attempt about to run and MaxAttempts is the
	// configured limit. NOTE(review): whether Attempt is 0- or 1-based is
	// decided by the caller of emitRetryAttemptStart — confirm there.
	Attempt     int
	MaxAttempts int
}
|
||
|
|
|
||
|
|
// TraceRetryAttemptDoneInfo is passed to TraceHooks.RetryAttemptDone.
type TraceRetryAttemptDoneInfo struct {
	// Attempt and MaxAttempts have the same meaning as in
	// TraceRetryAttemptStartInfo.
	Attempt     int
	MaxAttempts int

	// StatusCode is the HTTP status of the attempt's response; presumably
	// 0 when no response was received — TODO confirm with the caller.
	StatusCode int

	// Err is the attempt's error, if any.
	Err error

	// WillRetry presumably reports whether another attempt will follow.
	WillRetry bool
}
|
||
|
|
|
||
|
|
// TraceRetryBackoffInfo is passed to TraceHooks.RetryBackoff.
type TraceRetryBackoffInfo struct {
	// Attempt identifies the attempt the backoff relates to (indexing is
	// set by the retry loop — confirm), and Delay is presumably the wait
	// before the next attempt starts.
	Attempt int
	Delay   time.Duration
}
|
||
|
|
|
||
|
|
// traceState is the per-request tracing state carried in a context by
// withTraceState. A nil *traceState means tracing is disabled, and the
// traceState methods in this file all tolerate a nil receiver. It contains
// atomic values, so it must not be copied after first use; it is always
// handled as *traceState.
type traceState struct {
	// hooks are the user callbacks. newTraceState only builds a traceState
	// for non-nil hooks, but methods still check for nil.
	hooks *TraceHooks

	// customTLS is set to 1 by markCustomTLS; while non-zero, the TLS
	// callbacks installed by clientTrace return without calling the hooks.
	customTLS atomic.Uint32

	// manualDNSRefs counts beginManualDNS calls not yet matched by
	// endManualDNS; while positive, the DNS callbacks installed by
	// clientTrace return without calling the hooks.
	manualDNSRefs atomic.Int32
}
|
||
|
|
|
||
|
|
func newTraceState(hooks *TraceHooks) *traceState {
|
||
|
|
if hooks == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
return &traceState{hooks: hooks}
|
||
|
|
}
|
||
|
|
|
||
|
|
func withTraceState(ctx context.Context, state *traceState) context.Context {
|
||
|
|
if state == nil {
|
||
|
|
return ctx
|
||
|
|
}
|
||
|
|
return context.WithValue(ctx, traceContextKey{}, state)
|
||
|
|
}
|
||
|
|
|
||
|
|
func getTraceState(ctx context.Context) *traceState {
|
||
|
|
if ctx == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
state, _ := ctx.Value(traceContextKey{}).(*traceState)
|
||
|
|
return state
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *traceState) needsHTTPTrace() bool {
|
||
|
|
if t == nil || t.hooks == nil {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
h := t.hooks
|
||
|
|
return h.GetConn != nil ||
|
||
|
|
h.GotConn != nil ||
|
||
|
|
h.PutIdleConn != nil ||
|
||
|
|
h.DNSStart != nil ||
|
||
|
|
h.DNSDone != nil ||
|
||
|
|
h.ConnectStart != nil ||
|
||
|
|
h.ConnectDone != nil ||
|
||
|
|
h.TLSHandshakeStart != nil ||
|
||
|
|
h.TLSHandshakeDone != nil ||
|
||
|
|
h.WroteHeaderField != nil ||
|
||
|
|
h.WroteHeaders != nil ||
|
||
|
|
h.WroteRequest != nil ||
|
||
|
|
h.GotFirstResponseByte != nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *traceState) clientTrace() *httptrace.ClientTrace {
|
||
|
|
if !t.needsHTTPTrace() {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
h := t.hooks
|
||
|
|
trace := &httptrace.ClientTrace{}
|
||
|
|
if h.GetConn != nil {
|
||
|
|
trace.GetConn = func(hostPort string) {
|
||
|
|
h.GetConn(TraceGetConnInfo{Addr: hostPort})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if h.GotConn != nil {
|
||
|
|
trace.GotConn = func(info httptrace.GotConnInfo) {
|
||
|
|
h.GotConn(TraceGotConnInfo{
|
||
|
|
Conn: info.Conn,
|
||
|
|
Reused: info.Reused,
|
||
|
|
WasIdle: info.WasIdle,
|
||
|
|
IdleTime: info.IdleTime,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if h.PutIdleConn != nil {
|
||
|
|
trace.PutIdleConn = func(err error) {
|
||
|
|
h.PutIdleConn(TracePutIdleConnInfo{Err: err})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if h.DNSStart != nil {
|
||
|
|
trace.DNSStart = func(info httptrace.DNSStartInfo) {
|
||
|
|
if t.usesManualDNS() {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
h.DNSStart(TraceDNSStartInfo{Host: info.Host})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if h.DNSDone != nil {
|
||
|
|
trace.DNSDone = func(info httptrace.DNSDoneInfo) {
|
||
|
|
if t.usesManualDNS() {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
h.DNSDone(TraceDNSDoneInfo{
|
||
|
|
Addrs: append([]net.IPAddr(nil), info.Addrs...),
|
||
|
|
Coalesced: info.Coalesced,
|
||
|
|
Err: info.Err,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if h.ConnectStart != nil {
|
||
|
|
trace.ConnectStart = func(network, addr string) {
|
||
|
|
h.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if h.ConnectDone != nil {
|
||
|
|
trace.ConnectDone = func(network, addr string, err error) {
|
||
|
|
h.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if h.TLSHandshakeStart != nil {
|
||
|
|
trace.TLSHandshakeStart = func() {
|
||
|
|
if t.usesCustomTLS() {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
h.TLSHandshakeStart(TraceTLSHandshakeStartInfo{})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if h.TLSHandshakeDone != nil {
|
||
|
|
trace.TLSHandshakeDone = func(state tls.ConnectionState, err error) {
|
||
|
|
if t.usesCustomTLS() {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
h.TLSHandshakeDone(TraceTLSHandshakeDoneInfo{
|
||
|
|
ConnectionState: state,
|
||
|
|
Err: err,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if h.WroteHeaderField != nil {
|
||
|
|
trace.WroteHeaderField = func(key string, value []string) {
|
||
|
|
h.WroteHeaderField(TraceWroteHeaderFieldInfo{
|
||
|
|
Key: key,
|
||
|
|
Values: value,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if h.WroteHeaders != nil {
|
||
|
|
trace.WroteHeaders = h.WroteHeaders
|
||
|
|
}
|
||
|
|
if h.WroteRequest != nil {
|
||
|
|
trace.WroteRequest = func(info httptrace.WroteRequestInfo) {
|
||
|
|
h.WroteRequest(TraceWroteRequestInfo{Err: info.Err})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if h.GotFirstResponseByte != nil {
|
||
|
|
trace.GotFirstResponseByte = h.GotFirstResponseByte
|
||
|
|
}
|
||
|
|
return trace
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *traceState) markCustomTLS() {
|
||
|
|
if t == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
t.customTLS.Store(1)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *traceState) usesCustomTLS() bool {
|
||
|
|
if t == nil {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
return t.customTLS.Load() != 0
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *traceState) beginManualDNS() {
|
||
|
|
if t == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
t.manualDNSRefs.Add(1)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *traceState) endManualDNS() {
|
||
|
|
if t == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
t.manualDNSRefs.Add(-1)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *traceState) usesManualDNS() bool {
|
||
|
|
if t == nil {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
return t.manualDNSRefs.Load() > 0
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *traceState) tlsHandshakeStart(info TraceTLSHandshakeStartInfo) {
|
||
|
|
if t == nil || t.hooks == nil || t.hooks.TLSHandshakeStart == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
t.hooks.TLSHandshakeStart(info)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *traceState) tlsHandshakeDone(info TraceTLSHandshakeDoneInfo) {
|
||
|
|
if t == nil || t.hooks == nil || t.hooks.TLSHandshakeDone == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
t.hooks.TLSHandshakeDone(info)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *traceState) dnsStart(info TraceDNSStartInfo) {
|
||
|
|
if t == nil || t.hooks == nil || t.hooks.DNSStart == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
t.hooks.DNSStart(info)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *traceState) dnsDone(info TraceDNSDoneInfo) {
|
||
|
|
if t == nil || t.hooks == nil || t.hooks.DNSDone == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
t.hooks.DNSDone(info)
|
||
|
|
}
|
||
|
|
|
||
|
|
func emitRetryAttemptStart(hooks *TraceHooks, info TraceRetryAttemptStartInfo) {
|
||
|
|
if hooks == nil || hooks.RetryAttemptStart == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
hooks.RetryAttemptStart(info)
|
||
|
|
}
|
||
|
|
|
||
|
|
func emitRetryAttemptDone(hooks *TraceHooks, info TraceRetryAttemptDoneInfo) {
|
||
|
|
if hooks == nil || hooks.RetryAttemptDone == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
hooks.RetryAttemptDone(info)
|
||
|
|
}
|
||
|
|
|
||
|
|
func emitRetryBackoff(hooks *TraceHooks, info TraceRetryBackoffInfo) {
|
||
|
|
if hooks == nil || hooks.RetryBackoff == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
hooks.RetryBackoff(info)
|
||
|
|
}
|