package starnet import ( "context" "crypto/tls" "net" "net/http/httptrace" "sync/atomic" "time" ) type traceContextKey struct{} // TraceHooks defines optional callbacks for network lifecycle events. // Hooks may be called concurrently. type TraceHooks struct { GetConn func(TraceGetConnInfo) GotConn func(TraceGotConnInfo) PutIdleConn func(TracePutIdleConnInfo) DNSStart func(TraceDNSStartInfo) DNSDone func(TraceDNSDoneInfo) ConnectStart func(TraceConnectStartInfo) ConnectDone func(TraceConnectDoneInfo) TLSHandshakeStart func(TraceTLSHandshakeStartInfo) TLSHandshakeDone func(TraceTLSHandshakeDoneInfo) WroteHeaderField func(TraceWroteHeaderFieldInfo) WroteHeaders func() WroteRequest func(TraceWroteRequestInfo) GotFirstResponseByte func() RetryAttemptStart func(TraceRetryAttemptStartInfo) RetryAttemptDone func(TraceRetryAttemptDoneInfo) RetryBackoff func(TraceRetryBackoffInfo) } type TraceGetConnInfo struct { Addr string } type TraceGotConnInfo struct { Conn net.Conn Reused bool WasIdle bool IdleTime time.Duration } type TracePutIdleConnInfo struct { Err error } type TraceDNSStartInfo struct { Host string } type TraceDNSDoneInfo struct { Addrs []net.IPAddr Coalesced bool Err error } type TraceConnectStartInfo struct { Network string Addr string } type TraceConnectDoneInfo struct { Network string Addr string Err error } type TraceTLSHandshakeStartInfo struct { Network string Addr string ServerName string } type TraceTLSHandshakeDoneInfo struct { Network string Addr string ServerName string ConnectionState tls.ConnectionState Err error } type TraceWroteHeaderFieldInfo struct { Key string Values []string } type TraceWroteRequestInfo struct { Err error } type TraceRetryAttemptStartInfo struct { Attempt int MaxAttempts int } type TraceRetryAttemptDoneInfo struct { Attempt int MaxAttempts int StatusCode int Err error WillRetry bool } type TraceRetryBackoffInfo struct { Attempt int Delay time.Duration } type traceState struct { hooks *TraceHooks customTLS atomic.Uint32 manualDNSRefs atomic.Int32 } func newTraceState(hooks *TraceHooks) *traceState { if hooks == nil { return nil } return &traceState{hooks: hooks} } func withTraceState(ctx context.Context, state *traceState) context.Context { if state == nil { return ctx } return context.WithValue(ctx, traceContextKey{}, state) } func getTraceState(ctx context.Context) *traceState { if ctx == nil { return nil } state, _ := ctx.Value(traceContextKey{}).(*traceState) return state } func (t *traceState) needsHTTPTrace() bool { if t == nil || t.hooks == nil { return false } h := t.hooks return h.GetConn != nil || h.GotConn != nil || h.PutIdleConn != nil || h.DNSStart != nil || h.DNSDone != nil || h.ConnectStart != nil || h.ConnectDone != nil || h.TLSHandshakeStart != nil || h.TLSHandshakeDone != nil || h.WroteHeaderField != nil || h.WroteHeaders != nil || h.WroteRequest != nil || h.GotFirstResponseByte != nil } func (t *traceState) clientTrace() *httptrace.ClientTrace { if !t.needsHTTPTrace() { return nil } h := t.hooks trace := &httptrace.ClientTrace{} if h.GetConn != nil { trace.GetConn = func(hostPort string) { h.GetConn(TraceGetConnInfo{Addr: hostPort}) } } if h.GotConn != nil { trace.GotConn = func(info httptrace.GotConnInfo) { h.GotConn(TraceGotConnInfo{ Conn: info.Conn, Reused: info.Reused, WasIdle: info.WasIdle, IdleTime: info.IdleTime, }) } } if h.PutIdleConn != nil { trace.PutIdleConn = func(err error) { h.PutIdleConn(TracePutIdleConnInfo{Err: err}) } } if h.DNSStart != nil { trace.DNSStart = func(info httptrace.DNSStartInfo) { if t.usesManualDNS() { return } h.DNSStart(TraceDNSStartInfo{Host: info.Host}) } } if h.DNSDone != nil { trace.DNSDone = func(info httptrace.DNSDoneInfo) { if t.usesManualDNS() { return } h.DNSDone(TraceDNSDoneInfo{ Addrs: append([]net.IPAddr(nil), info.Addrs...), Coalesced: info.Coalesced, Err: info.Err, }) } } if h.ConnectStart != nil { trace.ConnectStart = func(network, addr string) { h.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr}) } } if h.ConnectDone != nil { trace.ConnectDone = func(network, addr string, err error) { h.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err}) } } if h.TLSHandshakeStart != nil { trace.TLSHandshakeStart = func() { if t.usesCustomTLS() { return } h.TLSHandshakeStart(TraceTLSHandshakeStartInfo{}) } } if h.TLSHandshakeDone != nil { trace.TLSHandshakeDone = func(state tls.ConnectionState, err error) { if t.usesCustomTLS() { return } h.TLSHandshakeDone(TraceTLSHandshakeDoneInfo{ ConnectionState: state, Err: err, }) } } if h.WroteHeaderField != nil { trace.WroteHeaderField = func(key string, value []string) { h.WroteHeaderField(TraceWroteHeaderFieldInfo{ Key: key, Values: value, }) } } if h.WroteHeaders != nil { trace.WroteHeaders = h.WroteHeaders } if h.WroteRequest != nil { trace.WroteRequest = func(info httptrace.WroteRequestInfo) { h.WroteRequest(TraceWroteRequestInfo{Err: info.Err}) } } if h.GotFirstResponseByte != nil { trace.GotFirstResponseByte = h.GotFirstResponseByte } return trace } func (t *traceState) markCustomTLS() { if t == nil { return } t.customTLS.Store(1) } func (t *traceState) usesCustomTLS() bool { if t == nil { return false } return t.customTLS.Load() != 0 } func (t *traceState) beginManualDNS() { if t == nil { return } t.manualDNSRefs.Add(1) } func (t *traceState) endManualDNS() { if t == nil { return } t.manualDNSRefs.Add(-1) } func (t *traceState) usesManualDNS() bool { if t == nil { return false } return t.manualDNSRefs.Load() > 0 } func (t *traceState) tlsHandshakeStart(info TraceTLSHandshakeStartInfo) { if t == nil || t.hooks == nil || t.hooks.TLSHandshakeStart == nil { return } t.hooks.TLSHandshakeStart(info) } func (t *traceState) tlsHandshakeDone(info TraceTLSHandshakeDoneInfo) { if t == nil || t.hooks == nil || t.hooks.TLSHandshakeDone == nil { return } t.hooks.TLSHandshakeDone(info) } func (t *traceState) dnsStart(info TraceDNSStartInfo) { if t == nil || t.hooks == nil || t.hooks.DNSStart == nil { return } t.hooks.DNSStart(info) } func (t *traceState) dnsDone(info TraceDNSDoneInfo) { if t == nil || t.hooks == nil || t.hooks.DNSDone == nil { return } t.hooks.DNSDone(info) } func emitRetryAttemptStart(hooks *TraceHooks, info TraceRetryAttemptStartInfo) { if hooks == nil || hooks.RetryAttemptStart == nil { return } hooks.RetryAttemptStart(info) } func emitRetryAttemptDone(hooks *TraceHooks, info TraceRetryAttemptDoneInfo) { if hooks == nil || hooks.RetryAttemptDone == nil { return } hooks.RetryAttemptDone(info) } func emitRetryBackoff(hooks *TraceHooks, info TraceRetryBackoffInfo) { if hooks == nil || hooks.RetryBackoff == nil { return } hooks.RetryBackoff(info) }