package starnet import ( "context" "crypto/tls" "net" "net/http/httptrace" "sync/atomic" "time" ) type traceContextKey struct{} // TraceHooks defines optional callbacks for network lifecycle events. // Hooks may be called concurrently. type TraceHooks struct { GetConn func(TraceGetConnInfo) GotConn func(TraceGotConnInfo) PutIdleConn func(TracePutIdleConnInfo) DNSStart func(TraceDNSStartInfo) DNSDone func(TraceDNSDoneInfo) ConnectStart func(TraceConnectStartInfo) ConnectDone func(TraceConnectDoneInfo) TLSHandshakeStart func(TraceTLSHandshakeStartInfo) TLSHandshakeDone func(TraceTLSHandshakeDoneInfo) WroteHeaderField func(TraceWroteHeaderFieldInfo) WroteHeaders func() WroteRequest func(TraceWroteRequestInfo) GotFirstResponseByte func() RetryAttemptStart func(TraceRetryAttemptStartInfo) RetryAttemptDone func(TraceRetryAttemptDoneInfo) RetryBackoff func(TraceRetryBackoffInfo) } type TraceGetConnInfo struct { Addr string } type TraceGotConnInfo struct { Conn net.Conn Reused bool WasIdle bool IdleTime time.Duration } type TracePutIdleConnInfo struct { Err error } type TraceDNSStartInfo struct { Host string } type TraceDNSDoneInfo struct { Addrs []net.IPAddr Coalesced bool Err error } type TraceConnectStartInfo struct { Network string Addr string } type TraceConnectDoneInfo struct { Network string Addr string Err error } type TraceTLSHandshakeStartInfo struct { Network string Addr string ServerName string } type TraceTLSHandshakeDoneInfo struct { Network string Addr string ServerName string ConnectionState tls.ConnectionState Err error } type TraceWroteHeaderFieldInfo struct { Key string Values []string } type TraceWroteRequestInfo struct { Err error } type TraceRetryAttemptStartInfo struct { Attempt int MaxAttempts int } type TraceRetryAttemptDoneInfo struct { Attempt int MaxAttempts int StatusCode int Err error WillRetry bool } type TraceRetryBackoffInfo struct { Attempt int Delay time.Duration } type traceState struct { hooks *TraceHooks customTLS atomic.Uint32 manualDNSRefs atomic.Int32 defaultTLSHandshakeInfo TraceTLSHandshakeStartInfo } func newTraceState(hooks *TraceHooks) *traceState { if hooks == nil { return nil } return &traceState{hooks: hooks} } func (t *traceState) setDefaultTLSHandshakeInfo(info TraceTLSHandshakeStartInfo) { if t == nil { return } t.defaultTLSHandshakeInfo = info } func (t *traceState) getDefaultTLSHandshakeInfo() TraceTLSHandshakeStartInfo { if t == nil { return TraceTLSHandshakeStartInfo{} } return t.defaultTLSHandshakeInfo } func withTraceState(ctx context.Context, state *traceState) context.Context { if state == nil { return ctx } return context.WithValue(ctx, traceContextKey{}, state) } func getTraceState(ctx context.Context) *traceState { if ctx == nil { return nil } state, _ := ctx.Value(traceContextKey{}).(*traceState) return state } func (t *traceState) needsHTTPTrace() bool { if t == nil || t.hooks == nil { return false } h := t.hooks return h.GetConn != nil || h.GotConn != nil || h.PutIdleConn != nil || h.DNSStart != nil || h.DNSDone != nil || h.ConnectStart != nil || h.ConnectDone != nil || h.TLSHandshakeStart != nil || h.TLSHandshakeDone != nil || h.WroteHeaderField != nil || h.WroteHeaders != nil || h.WroteRequest != nil || h.GotFirstResponseByte != nil } func (t *traceState) clientTrace() *httptrace.ClientTrace { if !t.needsHTTPTrace() { return nil } h := t.hooks trace := &httptrace.ClientTrace{} if h.GetConn != nil { trace.GetConn = func(hostPort string) { h.GetConn(TraceGetConnInfo{Addr: hostPort}) } } if h.GotConn != nil { trace.GotConn = func(info httptrace.GotConnInfo) { h.GotConn(TraceGotConnInfo{ Conn: info.Conn, Reused: info.Reused, WasIdle: info.WasIdle, IdleTime: info.IdleTime, }) } } if h.PutIdleConn != nil { trace.PutIdleConn = func(err error) { h.PutIdleConn(TracePutIdleConnInfo{Err: err}) } } if h.DNSStart != nil { trace.DNSStart = func(info httptrace.DNSStartInfo) { if t.usesManualDNS() { return } h.DNSStart(TraceDNSStartInfo{Host: info.Host}) } } if h.DNSDone != nil { trace.DNSDone = func(info httptrace.DNSDoneInfo) { if t.usesManualDNS() { return } h.DNSDone(TraceDNSDoneInfo{ Addrs: append([]net.IPAddr(nil), info.Addrs...), Coalesced: info.Coalesced, Err: info.Err, }) } } if h.ConnectStart != nil { trace.ConnectStart = func(network, addr string) { h.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr}) } } if h.ConnectDone != nil { trace.ConnectDone = func(network, addr string, err error) { h.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err}) } } if h.TLSHandshakeStart != nil { trace.TLSHandshakeStart = func() { if t.usesCustomTLS() { return } h.TLSHandshakeStart(t.getDefaultTLSHandshakeInfo()) } } if h.TLSHandshakeDone != nil { trace.TLSHandshakeDone = func(state tls.ConnectionState, err error) { if t.usesCustomTLS() { return } info := t.getDefaultTLSHandshakeInfo() h.TLSHandshakeDone(TraceTLSHandshakeDoneInfo{ Network: info.Network, Addr: info.Addr, ServerName: info.ServerName, ConnectionState: state, Err: err, }) } } if h.WroteHeaderField != nil { trace.WroteHeaderField = func(key string, value []string) { h.WroteHeaderField(TraceWroteHeaderFieldInfo{ Key: key, Values: append([]string(nil), value...), }) } } if h.WroteHeaders != nil { trace.WroteHeaders = h.WroteHeaders } if h.WroteRequest != nil { trace.WroteRequest = func(info httptrace.WroteRequestInfo) { h.WroteRequest(TraceWroteRequestInfo{Err: info.Err}) } } if h.GotFirstResponseByte != nil { trace.GotFirstResponseByte = h.GotFirstResponseByte } return trace } func (t *traceState) markCustomTLS() { if t == nil { return } t.customTLS.Store(1) } func (t *traceState) usesCustomTLS() bool { if t == nil { return false } return t.customTLS.Load() != 0 } func (t *traceState) beginManualDNS() { if t == nil { return } t.manualDNSRefs.Add(1) } func (t *traceState) endManualDNS() { if t == nil { return } t.manualDNSRefs.Add(-1) } func (t *traceState) usesManualDNS() bool { if t == nil { return false } return t.manualDNSRefs.Load() > 0 } func (t *traceState) tlsHandshakeStart(info TraceTLSHandshakeStartInfo) { if t == nil || t.hooks == nil || t.hooks.TLSHandshakeStart == nil { return } t.hooks.TLSHandshakeStart(info) } func (t *traceState) tlsHandshakeDone(info TraceTLSHandshakeDoneInfo) { if t == nil || t.hooks == nil || t.hooks.TLSHandshakeDone == nil { return } t.hooks.TLSHandshakeDone(info) } func (t *traceState) dnsStart(info TraceDNSStartInfo) { if t == nil || t.hooks == nil || t.hooks.DNSStart == nil { return } t.hooks.DNSStart(info) } func (t *traceState) dnsDone(info TraceDNSDoneInfo) { if t == nil || t.hooks == nil || t.hooks.DNSDone == nil { return } t.hooks.DNSDone(info) } func emitRetryAttemptStart(hooks *TraceHooks, info TraceRetryAttemptStartInfo) { if hooks == nil || hooks.RetryAttemptStart == nil { return } hooks.RetryAttemptStart(info) } func emitRetryAttemptDone(hooks *TraceHooks, info TraceRetryAttemptDoneInfo) { if hooks == nil || hooks.RetryAttemptDone == nil { return } hooks.RetryAttemptDone(info) } func emitRetryBackoff(hooks *TraceHooks, info TraceRetryBackoffInfo) { if hooks == nil || hooks.RetryBackoff == nil { return } hooks.RetryBackoff(info) } func traceRecorderHooks(recorder *TraceRecorder) *TraceHooks { if recorder == nil { return nil } return recorder.Hooks() } func composeTraceHooks(first, second *TraceHooks) *TraceHooks { switch { case first == nil: return second case second == nil: return first } return &TraceHooks{ GetConn: composeTraceGetConnHook(first.GetConn, second.GetConn), GotConn: composeTraceGotConnHook(first.GotConn, second.GotConn), PutIdleConn: composeTracePutIdleConnHook(first.PutIdleConn, second.PutIdleConn), DNSStart: composeTraceDNSStartHook(first.DNSStart, second.DNSStart), DNSDone: composeTraceDNSDoneHook(first.DNSDone, second.DNSDone), ConnectStart: composeTraceConnectStartHook(first.ConnectStart, second.ConnectStart), ConnectDone: composeTraceConnectDoneHook(first.ConnectDone, second.ConnectDone), TLSHandshakeStart: composeTraceTLSHandshakeStartHook(first.TLSHandshakeStart, second.TLSHandshakeStart), TLSHandshakeDone: composeTraceTLSHandshakeDoneHook(first.TLSHandshakeDone, second.TLSHandshakeDone), WroteHeaderField: composeTraceWroteHeaderFieldHook(first.WroteHeaderField, second.WroteHeaderField), WroteHeaders: composeTraceSimpleHook(first.WroteHeaders, second.WroteHeaders), WroteRequest: composeTraceWroteRequestHook(first.WroteRequest, second.WroteRequest), GotFirstResponseByte: composeTraceSimpleHook(first.GotFirstResponseByte, second.GotFirstResponseByte), RetryAttemptStart: composeTraceRetryAttemptStartHook(first.RetryAttemptStart, second.RetryAttemptStart), RetryAttemptDone: composeTraceRetryAttemptDoneHook(first.RetryAttemptDone, second.RetryAttemptDone), RetryBackoff: composeTraceRetryBackoffHook(first.RetryBackoff, second.RetryBackoff), } } func composeTraceGetConnHook(first, second func(TraceGetConnInfo)) func(TraceGetConnInfo) { switch { case first == nil: return second case second == nil: return first default: return func(info TraceGetConnInfo) { first(info) second(info) } } } func composeTraceGotConnHook(first, second func(TraceGotConnInfo)) func(TraceGotConnInfo) { switch { case first == nil: return second case second == nil: return first default: return func(info TraceGotConnInfo) { first(info) second(info) } } } func composeTracePutIdleConnHook(first, second func(TracePutIdleConnInfo)) func(TracePutIdleConnInfo) { switch { case first == nil: return second case second == nil: return first default: return func(info TracePutIdleConnInfo) { first(info) second(info) } } } func composeTraceDNSStartHook(first, second func(TraceDNSStartInfo)) func(TraceDNSStartInfo) { switch { case first == nil: return second case second == nil: return first default: return func(info TraceDNSStartInfo) { first(info) second(info) } } } func composeTraceDNSDoneHook(first, second func(TraceDNSDoneInfo)) func(TraceDNSDoneInfo) { switch { case first == nil: return second case second == nil: return first default: return func(info TraceDNSDoneInfo) { first(info) second(info) } } } func composeTraceConnectStartHook(first, second func(TraceConnectStartInfo)) func(TraceConnectStartInfo) { switch { case first == nil: return second case second == nil: return first default: return func(info TraceConnectStartInfo) { first(info) second(info) } } } func composeTraceConnectDoneHook(first, second func(TraceConnectDoneInfo)) func(TraceConnectDoneInfo) { switch { case first == nil: return second case second == nil: return first default: return func(info TraceConnectDoneInfo) { first(info) second(info) } } } func composeTraceTLSHandshakeStartHook(first, second func(TraceTLSHandshakeStartInfo)) func(TraceTLSHandshakeStartInfo) { switch { case first == nil: return second case second == nil: return first default: return func(info TraceTLSHandshakeStartInfo) { first(info) second(info) } } } func composeTraceTLSHandshakeDoneHook(first, second func(TraceTLSHandshakeDoneInfo)) func(TraceTLSHandshakeDoneInfo) { switch { case first == nil: return second case second == nil: return first default: return func(info TraceTLSHandshakeDoneInfo) { first(info) second(info) } } } func composeTraceWroteHeaderFieldHook(first, second func(TraceWroteHeaderFieldInfo)) func(TraceWroteHeaderFieldInfo) { switch { case first == nil: return second case second == nil: return first default: return func(info TraceWroteHeaderFieldInfo) { first(info) second(info) } } } func composeTraceWroteRequestHook(first, second func(TraceWroteRequestInfo)) func(TraceWroteRequestInfo) { switch { case first == nil: return second case second == nil: return first default: return func(info TraceWroteRequestInfo) { first(info) second(info) } } } func composeTraceRetryAttemptStartHook(first, second func(TraceRetryAttemptStartInfo)) func(TraceRetryAttemptStartInfo) { switch { case first == nil: return second case second == nil: return first default: return func(info TraceRetryAttemptStartInfo) { first(info) second(info) } } } func composeTraceRetryAttemptDoneHook(first, second func(TraceRetryAttemptDoneInfo)) func(TraceRetryAttemptDoneInfo) { switch { case first == nil: return second case second == nil: return first default: return func(info TraceRetryAttemptDoneInfo) { first(info) second(info) } } } func composeTraceRetryBackoffHook(first, second func(TraceRetryBackoffInfo)) func(TraceRetryBackoffInfo) { switch { case first == nil: return second case second == nil: return first default: return func(info TraceRetryBackoffInfo) { first(info) second(info) } } } func composeTraceSimpleHook(first, second func()) func() { switch { case first == nil: return second case second == nil: return first default: return func() { first() second() } } }