starnet/trace.go

605 lines
14 KiB
Go
Raw Permalink Normal View History

package starnet
import (
"context"
"crypto/tls"
"net"
"net/http/httptrace"
"sync/atomic"
"time"
)
// traceContextKey is the unexported context key under which withTraceState
// stores a *traceState and from which getTraceState reads it. A private struct
// type cannot collide with context keys defined by other packages.
type traceContextKey struct{}
// TraceHooks defines optional callbacks for network lifecycle events.
// Hooks may be called concurrently, so implementations must be safe for
// concurrent use. A nil field means that event is not reported.
//
// The connection, DNS, dial, TLS, request-write and first-byte hooks are
// bridged from net/http/httptrace by traceState.clientTrace. That bridge drops
// httptrace DNS events while a manual lookup is active
// (traceState.beginManualDNS). It drops httptrace TLS events once
// traceState.markCustomTLS has been called. The same hooks can also be invoked
// directly through traceState's dnsStart, dnsDone, tlsHandshakeStart and
// tlsHandshakeDone helpers. The Retry* hooks are never driven by httptrace.
// They are invoked via emitRetryAttemptStart, emitRetryAttemptDone and
// emitRetryBackoff.
type TraceHooks struct {
// GetConn fires before a connection is created or taken from the idle pool.
GetConn func(TraceGetConnInfo)
// GotConn fires once a connection has been obtained.
GotConn func(TraceGotConnInfo)
// PutIdleConn fires when a connection is returned to the idle pool, or when
// returning it fails (see TracePutIdleConnInfo.Err).
PutIdleConn func(TracePutIdleConnInfo)
// DNSStart and DNSDone bracket a DNS lookup.
DNSStart func(TraceDNSStartInfo)
DNSDone func(TraceDNSDoneInfo)
// ConnectStart and ConnectDone bracket a dial. When the transport races
// several addresses they can fire more than once per request.
ConnectStart func(TraceConnectStartInfo)
ConnectDone func(TraceConnectDoneInfo)
// TLSHandshakeStart and TLSHandshakeDone bracket a TLS handshake.
TLSHandshakeStart func(TraceTLSHandshakeStartInfo)
TLSHandshakeDone func(TraceTLSHandshakeDoneInfo)
// WroteHeaderField fires for each request header field written.
WroteHeaderField func(TraceWroteHeaderFieldInfo)
// WroteHeaders fires after all request headers have been written.
WroteHeaders func()
// WroteRequest fires after the request and any body have been written.
WroteRequest func(TraceWroteRequestInfo)
// GotFirstResponseByte fires when the first response byte is read.
GotFirstResponseByte func()
// RetryAttemptStart, RetryAttemptDone and RetryBackoff report the progress of
// the retry loop; see the corresponding Trace*Info types for their payloads.
RetryAttemptStart func(TraceRetryAttemptStartInfo)
RetryAttemptDone func(TraceRetryAttemptDoneInfo)
RetryBackoff func(TraceRetryBackoffInfo)
}
// TraceGetConnInfo is passed to TraceHooks.GetConn.
type TraceGetConnInfo struct {
// Addr is the "host:port" of the target (or proxy) a connection is wanted for,
// as reported by httptrace.
Addr string
}
// TraceGotConnInfo is passed to TraceHooks.GotConn. Its fields are copied from
// httptrace.GotConnInfo.
type TraceGotConnInfo struct {
// Conn is the connection that was obtained.
Conn net.Conn
// Reused reports whether the connection was previously used for another request.
Reused bool
// WasIdle reports whether the connection was taken from the idle pool.
WasIdle bool
// IdleTime is how long the connection sat idle; meaningful when WasIdle is true.
IdleTime time.Duration
}
// TracePutIdleConnInfo is passed to TraceHooks.PutIdleConn.
type TracePutIdleConnInfo struct {
// Err is nil if the connection went back to the idle pool. Otherwise it
// explains why the connection was not pooled.
Err error
}
// TraceDNSStartInfo is passed to TraceHooks.DNSStart.
type TraceDNSStartInfo struct {
// Host is the name being resolved.
Host string
}
// TraceDNSDoneInfo is passed to TraceHooks.DNSDone.
type TraceDNSDoneInfo struct {
// Addrs holds the resolved addresses. When the event comes from httptrace this
// slice is a private copy, so a hook may retain it.
Addrs []net.IPAddr
// Coalesced reports whether the lookup was shared with another concurrent
// lookup of the same name (mirrors httptrace.DNSDoneInfo.Coalesced).
Coalesced bool
// Err is the lookup error, or nil on success.
Err error
}
// TraceConnectStartInfo is passed to TraceHooks.ConnectStart when a dial begins.
type TraceConnectStartInfo struct {
// Network and Addr are the dial arguments (e.g. "tcp", "192.0.2.1:443").
Network string
Addr string
}
// TraceConnectDoneInfo is passed to TraceHooks.ConnectDone when a dial finishes.
type TraceConnectDoneInfo struct {
// Network and Addr match the corresponding ConnectStart event.
Network string
Addr string
// Err is the dial error, or nil on success.
Err error
}
// TraceTLSHandshakeStartInfo is passed to TraceHooks.TLSHandshakeStart.
// httptrace's TLSHandshakeStart callback carries no arguments. For handshakes
// reported through it, these fields are therefore the defaults recorded with
// traceState.setDefaultTLSHandshakeInfo, and they are empty if none were set.
type TraceTLSHandshakeStartInfo struct {
Network string
Addr string
ServerName string
}
// TraceTLSHandshakeDoneInfo is passed to TraceHooks.TLSHandshakeDone. Network,
// Addr and ServerName are filled the same way as TraceTLSHandshakeStartInfo.
type TraceTLSHandshakeDoneInfo struct {
Network string
Addr string
ServerName string
// ConnectionState is the TLS state reported at the end of the handshake.
ConnectionState tls.ConnectionState
// Err is the handshake error, or nil on success.
Err error
}
// TraceWroteHeaderFieldInfo is passed to TraceHooks.WroteHeaderField once per
// request header field written.
type TraceWroteHeaderFieldInfo struct {
// Key is the header name as written.
Key string
// Values is the header's values. When the event comes from httptrace this
// slice is a private copy, so a hook may retain it.
Values []string
}
// TraceWroteRequestInfo is passed to TraceHooks.WroteRequest after the request
// and any body have been written.
type TraceWroteRequestInfo struct {
// Err is the write error, or nil on success.
Err error
}
// TraceRetryAttemptStartInfo is passed to TraceHooks.RetryAttemptStart.
type TraceRetryAttemptStartInfo struct {
// Attempt identifies the attempt. NOTE(review): confirm with the retry loop
// whether numbering is 0- or 1-based.
Attempt int
// MaxAttempts is presumably the total number of attempts allowed — confirm.
MaxAttempts int
}
// TraceRetryAttemptDoneInfo is passed to TraceHooks.RetryAttemptDone.
type TraceRetryAttemptDoneInfo struct {
Attempt int
MaxAttempts int
// StatusCode is presumably the HTTP status of the attempt's response (likely 0
// when no response was received) — confirm with the caller.
StatusCode int
// Err is the attempt's error, if any.
Err error
// WillRetry presumably reports whether another attempt follows this one.
WillRetry bool
}
// TraceRetryBackoffInfo is passed to TraceHooks.RetryBackoff.
type TraceRetryBackoffInfo struct {
Attempt int
// Delay is the backoff duration reported for this attempt.
Delay time.Duration
}
// traceState is the per-request tracing state carried in a context (see
// withTraceState and getTraceState). Every method is nil-safe: on a nil
// *traceState it does nothing or returns a zero value.
type traceState struct {
// hooks holds the user callbacks; newTraceState never builds a state with nil hooks.
hooks *TraceHooks
// customTLS is set non-zero by markCustomTLS. While it is set, clientTrace
// drops the TLS events coming from httptrace.
customTLS atomic.Uint32
// manualDNSRefs counts outstanding beginManualDNS calls not yet matched by
// endManualDNS. While it is positive, clientTrace drops the DNS events coming
// from httptrace.
manualDNSRefs atomic.Int32
// defaultTLSHandshakeInfo supplies Network/Addr/ServerName for TLS events
// bridged from httptrace. NOTE(review): unlike the atomic fields above, it is
// read and written without synchronization. Set it before the trace can fire.
defaultTLSHandshakeInfo TraceTLSHandshakeStartInfo
}
// newTraceState wraps hooks in a fresh traceState. A nil hooks yields a nil
// state, which every traceState method treats as "tracing disabled".
func newTraceState(hooks *TraceHooks) *traceState {
	var state *traceState
	if hooks != nil {
		state = &traceState{hooks: hooks}
	}
	return state
}
// setDefaultTLSHandshakeInfo records the Network/Addr/ServerName that
// clientTrace reports for TLS handshakes done by net/http, because httptrace
// supplies no such details. It does nothing on a nil receiver. The write is
// unsynchronized, so call it before the trace can fire.
func (t *traceState) setDefaultTLSHandshakeInfo(info TraceTLSHandshakeStartInfo) {
	if t != nil {
		t.defaultTLSHandshakeInfo = info
	}
}
// getDefaultTLSHandshakeInfo returns the value stored by
// setDefaultTLSHandshakeInfo. On a nil receiver it returns the zero value.
func (t *traceState) getDefaultTLSHandshakeInfo() TraceTLSHandshakeStartInfo {
	var info TraceTLSHandshakeStartInfo
	if t != nil {
		info = t.defaultTLSHandshakeInfo
	}
	return info
}
// withTraceState returns a child of ctx that carries state under
// traceContextKey. A nil state returns ctx itself, unchanged.
func withTraceState(ctx context.Context, state *traceState) context.Context {
	if state != nil {
		ctx = context.WithValue(ctx, traceContextKey{}, state)
	}
	return ctx
}
// getTraceState returns the *traceState attached by withTraceState. It returns
// nil when ctx is nil or carries no *traceState under traceContextKey.
func getTraceState(ctx context.Context) *traceState {
	if ctx != nil {
		if state, ok := ctx.Value(traceContextKey{}).(*traceState); ok {
			return state
		}
	}
	return nil
}
// needsHTTPTrace reports whether at least one hook that is fed by
// net/http/httptrace is set. The Retry* hooks are deliberately excluded,
// because they are emitted through the emitRetry* helpers instead of
// httptrace. It returns false for a nil receiver or nil hooks.
func (t *traceState) needsHTTPTrace() bool {
	if t == nil {
		return false
	}
	hooks := t.hooks
	if hooks == nil {
		return false
	}
	switch {
	case hooks.GetConn != nil, hooks.GotConn != nil, hooks.PutIdleConn != nil,
		hooks.DNSStart != nil, hooks.DNSDone != nil,
		hooks.ConnectStart != nil, hooks.ConnectDone != nil,
		hooks.TLSHandshakeStart != nil, hooks.TLSHandshakeDone != nil,
		hooks.WroteHeaderField != nil, hooks.WroteHeaders != nil, hooks.WroteRequest != nil,
		hooks.GotFirstResponseByte != nil:
		return true
	default:
		return false
	}
}
// clientTrace builds an httptrace.ClientTrace that forwards transport events
// to the configured hooks. It converts httptrace's callback arguments into the
// package's Trace*Info structs. It returns nil when no httptrace-backed hook is
// set, which includes a nil receiver.
//
// A callback is installed only for hooks that are non-nil. DNS callbacks do
// nothing while a manual DNS lookup is outstanding (usesManualDNS). TLS
// callbacks do nothing once markCustomTLS has been called (usesCustomTLS).
// Slices handed to hooks (resolved addresses, header values) are fresh copies.
func (t *traceState) clientTrace() *httptrace.ClientTrace {
	if !t.needsHTTPTrace() {
		return nil
	}
	hooks := t.hooks
	ct := new(httptrace.ClientTrace)

	// Connection acquisition and release.
	if hooks.GetConn != nil {
		ct.GetConn = func(hostPort string) {
			info := TraceGetConnInfo{Addr: hostPort}
			hooks.GetConn(info)
		}
	}
	if hooks.GotConn != nil {
		ct.GotConn = func(src httptrace.GotConnInfo) {
			info := TraceGotConnInfo{
				Conn:     src.Conn,
				Reused:   src.Reused,
				WasIdle:  src.WasIdle,
				IdleTime: src.IdleTime,
			}
			hooks.GotConn(info)
		}
	}
	if hooks.PutIdleConn != nil {
		ct.PutIdleConn = func(err error) {
			info := TracePutIdleConnInfo{Err: err}
			hooks.PutIdleConn(info)
		}
	}

	// Name resolution: skipped while a manual lookup is outstanding.
	if hooks.DNSStart != nil {
		ct.DNSStart = func(src httptrace.DNSStartInfo) {
			if !t.usesManualDNS() {
				hooks.DNSStart(TraceDNSStartInfo{Host: src.Host})
			}
		}
	}
	if hooks.DNSDone != nil {
		ct.DNSDone = func(src httptrace.DNSDoneInfo) {
			if t.usesManualDNS() {
				return
			}
			var addrs []net.IPAddr
			addrs = append(addrs, src.Addrs...)
			hooks.DNSDone(TraceDNSDoneInfo{
				Addrs:     addrs,
				Coalesced: src.Coalesced,
				Err:       src.Err,
			})
		}
	}

	// Dialing.
	if hooks.ConnectStart != nil {
		ct.ConnectStart = func(network, addr string) {
			info := TraceConnectStartInfo{Network: network, Addr: addr}
			hooks.ConnectStart(info)
		}
	}
	if hooks.ConnectDone != nil {
		ct.ConnectDone = func(network, addr string, err error) {
			info := TraceConnectDoneInfo{Network: network, Addr: addr, Err: err}
			hooks.ConnectDone(info)
		}
	}

	// TLS: httptrace reports no address details, so use the recorded defaults.
	if hooks.TLSHandshakeStart != nil {
		ct.TLSHandshakeStart = func() {
			if !t.usesCustomTLS() {
				hooks.TLSHandshakeStart(t.getDefaultTLSHandshakeInfo())
			}
		}
	}
	if hooks.TLSHandshakeDone != nil {
		ct.TLSHandshakeDone = func(cs tls.ConnectionState, err error) {
			if t.usesCustomTLS() {
				return
			}
			defaults := t.getDefaultTLSHandshakeInfo()
			hooks.TLSHandshakeDone(TraceTLSHandshakeDoneInfo{
				Network:         defaults.Network,
				Addr:            defaults.Addr,
				ServerName:      defaults.ServerName,
				ConnectionState: cs,
				Err:             err,
			})
		}
	}

	// Request writing and the first response byte.
	if hooks.WroteHeaderField != nil {
		ct.WroteHeaderField = func(key string, values []string) {
			var copied []string
			copied = append(copied, values...)
			hooks.WroteHeaderField(TraceWroteHeaderFieldInfo{Key: key, Values: copied})
		}
	}
	if hooks.WroteHeaders != nil {
		ct.WroteHeaders = hooks.WroteHeaders
	}
	if hooks.WroteRequest != nil {
		ct.WroteRequest = func(src httptrace.WroteRequestInfo) {
			info := TraceWroteRequestInfo{Err: src.Err}
			hooks.WroteRequest(info)
		}
	}
	if hooks.GotFirstResponseByte != nil {
		ct.GotFirstResponseByte = hooks.GotFirstResponseByte
	}
	return ct
}
// markCustomTLS records that TLS is handled outside net/http. From then on,
// clientTrace's httptrace TLS callbacks do nothing. It does nothing on a nil
// receiver.
func (t *traceState) markCustomTLS() {
	if t != nil {
		t.customTLS.Store(1)
	}
}
// usesCustomTLS reports whether markCustomTLS has been called. It reports
// false for a nil receiver.
func (t *traceState) usesCustomTLS() bool {
	return t != nil && t.customTLS.Load() != 0
}
// beginManualDNS marks the start of a manual DNS resolution. While any such
// resolution is outstanding, clientTrace's httptrace DNS callbacks do nothing.
// Pair each call with endManualDNS. It does nothing on a nil receiver.
func (t *traceState) beginManualDNS() {
	if t != nil {
		t.manualDNSRefs.Add(1)
	}
}
// endManualDNS undoes one beginManualDNS. The counter is not clamped, so an
// unmatched call drives it negative. It does nothing on a nil receiver.
func (t *traceState) endManualDNS() {
	if t != nil {
		t.manualDNSRefs.Add(-1)
	}
}
// usesManualDNS reports whether at least one manual DNS resolution is
// outstanding, meaning the counter is strictly positive. It reports false for
// a nil receiver.
func (t *traceState) usesManualDNS() bool {
	return t != nil && t.manualDNSRefs.Load() > 0
}
// tlsHandshakeStart passes info straight to the TLSHandshakeStart hook, if one
// is set. Unlike the httptrace bridge in clientTrace, markCustomTLS does not
// suppress this call. It is safe on a nil receiver or nil hooks.
func (t *traceState) tlsHandshakeStart(info TraceTLSHandshakeStartInfo) {
	if t == nil || t.hooks == nil {
		return
	}
	if fn := t.hooks.TLSHandshakeStart; fn != nil {
		fn(info)
	}
}
// tlsHandshakeDone passes info straight to the TLSHandshakeDone hook, if one is
// set. Unlike the httptrace bridge in clientTrace, markCustomTLS does not
// suppress this call. It is safe on a nil receiver or nil hooks.
func (t *traceState) tlsHandshakeDone(info TraceTLSHandshakeDoneInfo) {
	if t == nil || t.hooks == nil {
		return
	}
	if fn := t.hooks.TLSHandshakeDone; fn != nil {
		fn(info)
	}
}
// dnsStart passes info straight to the DNSStart hook, if one is set. Unlike the
// httptrace bridge in clientTrace, an outstanding manual DNS lookup does not
// suppress this call. It is safe on a nil receiver or nil hooks.
func (t *traceState) dnsStart(info TraceDNSStartInfo) {
	if t == nil || t.hooks == nil {
		return
	}
	if fn := t.hooks.DNSStart; fn != nil {
		fn(info)
	}
}
// dnsDone passes info straight to the DNSDone hook, if one is set. Unlike the
// httptrace bridge in clientTrace, an outstanding manual DNS lookup does not
// suppress this call. It is safe on a nil receiver or nil hooks.
func (t *traceState) dnsDone(info TraceDNSDoneInfo) {
	if t == nil || t.hooks == nil {
		return
	}
	if fn := t.hooks.DNSDone; fn != nil {
		fn(info)
	}
}
// emitRetryAttemptStart calls hooks.RetryAttemptStart with info. It does
// nothing if hooks or that callback is nil.
func emitRetryAttemptStart(hooks *TraceHooks, info TraceRetryAttemptStartInfo) {
	if hooks != nil && hooks.RetryAttemptStart != nil {
		hooks.RetryAttemptStart(info)
	}
}
// emitRetryAttemptDone calls hooks.RetryAttemptDone with info. It does nothing
// if hooks or that callback is nil.
func emitRetryAttemptDone(hooks *TraceHooks, info TraceRetryAttemptDoneInfo) {
	if hooks != nil && hooks.RetryAttemptDone != nil {
		hooks.RetryAttemptDone(info)
	}
}
// emitRetryBackoff calls hooks.RetryBackoff with info. It does nothing if
// hooks or that callback is nil.
func emitRetryBackoff(hooks *TraceHooks, info TraceRetryBackoffInfo) {
	if hooks != nil && hooks.RetryBackoff != nil {
		hooks.RetryBackoff(info)
	}
}
// traceRecorderHooks returns recorder.Hooks(). A nil recorder yields nil.
func traceRecorderHooks(recorder *TraceRecorder) *TraceHooks {
	var hooks *TraceHooks
	if recorder != nil {
		hooks = recorder.Hooks()
	}
	return hooks
}
// composeTraceHooks merges two hook sets so that each event invokes first's
// callback and then second's. If either argument is nil, the other is returned
// as-is, not copied. Otherwise a fresh *TraceHooks is returned and neither
// input is modified. For each field, a callback that is nil on one side is
// replaced by the other side's callback unchanged.
func composeTraceHooks(first, second *TraceHooks) *TraceHooks {
	if first == nil {
		return second
	}
	if second == nil {
		return first
	}

	merged := new(TraceHooks)

	// Events bridged from net/http/httptrace.
	merged.GetConn = composeTraceGetConnHook(first.GetConn, second.GetConn)
	merged.GotConn = composeTraceGotConnHook(first.GotConn, second.GotConn)
	merged.PutIdleConn = composeTracePutIdleConnHook(first.PutIdleConn, second.PutIdleConn)
	merged.DNSStart = composeTraceDNSStartHook(first.DNSStart, second.DNSStart)
	merged.DNSDone = composeTraceDNSDoneHook(first.DNSDone, second.DNSDone)
	merged.ConnectStart = composeTraceConnectStartHook(first.ConnectStart, second.ConnectStart)
	merged.ConnectDone = composeTraceConnectDoneHook(first.ConnectDone, second.ConnectDone)
	merged.TLSHandshakeStart = composeTraceTLSHandshakeStartHook(first.TLSHandshakeStart, second.TLSHandshakeStart)
	merged.TLSHandshakeDone = composeTraceTLSHandshakeDoneHook(first.TLSHandshakeDone, second.TLSHandshakeDone)
	merged.WroteHeaderField = composeTraceWroteHeaderFieldHook(first.WroteHeaderField, second.WroteHeaderField)
	merged.WroteHeaders = composeTraceSimpleHook(first.WroteHeaders, second.WroteHeaders)
	merged.WroteRequest = composeTraceWroteRequestHook(first.WroteRequest, second.WroteRequest)
	merged.GotFirstResponseByte = composeTraceSimpleHook(first.GotFirstResponseByte, second.GotFirstResponseByte)

	// Retry events.
	merged.RetryAttemptStart = composeTraceRetryAttemptStartHook(first.RetryAttemptStart, second.RetryAttemptStart)
	merged.RetryAttemptDone = composeTraceRetryAttemptDoneHook(first.RetryAttemptDone, second.RetryAttemptDone)
	merged.RetryBackoff = composeTraceRetryBackoffHook(first.RetryBackoff, second.RetryBackoff)

	return merged
}
// composeTraceInfoHook chains two single-argument trace callbacks: the result
// calls first and then second with the same value. If either callback is nil,
// the other is returned unchanged, so two nils yield nil and no wrapper closure
// is allocated.
//
// The per-event composeTrace*Hook functions below are typed, one-line wrappers
// around this helper. They replace fourteen copy-pasted bodies that differed
// only in their argument type, and their names are kept so existing call
// sites need no changes.
func composeTraceInfoHook[T any](first, second func(T)) func(T) {
	switch {
	case first == nil:
		return second
	case second == nil:
		return first
	default:
		return func(info T) {
			first(info)
			second(info)
		}
	}
}

// composeTraceGetConnHook chains two GetConn hooks (first, then second).
func composeTraceGetConnHook(first, second func(TraceGetConnInfo)) func(TraceGetConnInfo) {
	return composeTraceInfoHook(first, second)
}

// composeTraceGotConnHook chains two GotConn hooks (first, then second).
func composeTraceGotConnHook(first, second func(TraceGotConnInfo)) func(TraceGotConnInfo) {
	return composeTraceInfoHook(first, second)
}

// composeTracePutIdleConnHook chains two PutIdleConn hooks (first, then second).
func composeTracePutIdleConnHook(first, second func(TracePutIdleConnInfo)) func(TracePutIdleConnInfo) {
	return composeTraceInfoHook(first, second)
}

// composeTraceDNSStartHook chains two DNSStart hooks (first, then second).
func composeTraceDNSStartHook(first, second func(TraceDNSStartInfo)) func(TraceDNSStartInfo) {
	return composeTraceInfoHook(first, second)
}

// composeTraceDNSDoneHook chains two DNSDone hooks (first, then second).
func composeTraceDNSDoneHook(first, second func(TraceDNSDoneInfo)) func(TraceDNSDoneInfo) {
	return composeTraceInfoHook(first, second)
}

// composeTraceConnectStartHook chains two ConnectStart hooks (first, then second).
func composeTraceConnectStartHook(first, second func(TraceConnectStartInfo)) func(TraceConnectStartInfo) {
	return composeTraceInfoHook(first, second)
}

// composeTraceConnectDoneHook chains two ConnectDone hooks (first, then second).
func composeTraceConnectDoneHook(first, second func(TraceConnectDoneInfo)) func(TraceConnectDoneInfo) {
	return composeTraceInfoHook(first, second)
}

// composeTraceTLSHandshakeStartHook chains two TLSHandshakeStart hooks (first, then second).
func composeTraceTLSHandshakeStartHook(first, second func(TraceTLSHandshakeStartInfo)) func(TraceTLSHandshakeStartInfo) {
	return composeTraceInfoHook(first, second)
}

// composeTraceTLSHandshakeDoneHook chains two TLSHandshakeDone hooks (first, then second).
func composeTraceTLSHandshakeDoneHook(first, second func(TraceTLSHandshakeDoneInfo)) func(TraceTLSHandshakeDoneInfo) {
	return composeTraceInfoHook(first, second)
}

// composeTraceWroteHeaderFieldHook chains two WroteHeaderField hooks (first, then second).
func composeTraceWroteHeaderFieldHook(first, second func(TraceWroteHeaderFieldInfo)) func(TraceWroteHeaderFieldInfo) {
	return composeTraceInfoHook(first, second)
}

// composeTraceWroteRequestHook chains two WroteRequest hooks (first, then second).
func composeTraceWroteRequestHook(first, second func(TraceWroteRequestInfo)) func(TraceWroteRequestInfo) {
	return composeTraceInfoHook(first, second)
}

// composeTraceRetryAttemptStartHook chains two RetryAttemptStart hooks (first, then second).
func composeTraceRetryAttemptStartHook(first, second func(TraceRetryAttemptStartInfo)) func(TraceRetryAttemptStartInfo) {
	return composeTraceInfoHook(first, second)
}

// composeTraceRetryAttemptDoneHook chains two RetryAttemptDone hooks (first, then second).
func composeTraceRetryAttemptDoneHook(first, second func(TraceRetryAttemptDoneInfo)) func(TraceRetryAttemptDoneInfo) {
	return composeTraceInfoHook(first, second)
}

// composeTraceRetryBackoffHook chains two RetryBackoff hooks (first, then second).
func composeTraceRetryBackoffHook(first, second func(TraceRetryBackoffInfo)) func(TraceRetryBackoffInfo) {
	return composeTraceInfoHook(first, second)
}
// composeTraceSimpleHook chains two no-argument callbacks: the result runs
// first and then second. If either is nil, the other is returned unchanged,
// so two nils yield nil.
func composeTraceSimpleHook(first, second func()) func() {
	if first == nil {
		return second
	}
	if second == nil {
		return first
	}
	return func() {
		first()
		second()
	}
}