starnet/trace.go

341 lines
7.2 KiB
Go
Raw Permalink Normal View History

package starnet
import (
"context"
"crypto/tls"
"net"
"net/http/httptrace"
"sync/atomic"
"time"
)
type (
	// traceContextKey is the private context key under which withTraceState
	// stores a *traceState and getTraceState looks it up. Because it is an
	// unexported type, keys from other packages cannot collide with it.
	traceContextKey struct{}

	// TraceHooks defines optional callbacks for network lifecycle events.
	// Any field may be left nil, in which case that event is not reported.
	// Hooks may be called concurrently.
	TraceHooks struct {
		// Connection-pool events forwarded from net/http/httptrace.
		GetConn     func(TraceGetConnInfo)
		GotConn     func(TraceGotConnInfo)
		PutIdleConn func(TracePutIdleConnInfo)

		// Name-resolution events. The events reported by net/http are
		// suppressed while a manual DNS lookup is in progress.
		DNSStart func(TraceDNSStartInfo)
		DNSDone  func(TraceDNSDoneInfo)

		// Dial events forwarded from net/http/httptrace.
		ConnectStart func(TraceConnectStartInfo)
		ConnectDone  func(TraceConnectDoneInfo)

		// TLS handshake events. The events reported by net/http are
		// suppressed once custom TLS has been marked on the trace state.
		TLSHandshakeStart func(TraceTLSHandshakeStartInfo)
		TLSHandshakeDone  func(TraceTLSHandshakeDoneInfo)

		// Request-write and first-response-byte events forwarded from
		// net/http/httptrace.
		WroteHeaderField     func(TraceWroteHeaderFieldInfo)
		WroteHeaders         func()
		WroteRequest         func(TraceWroteRequestInfo)
		GotFirstResponseByte func()

		// Retry-loop events. These are emitted directly, not through httptrace.
		RetryAttemptStart func(TraceRetryAttemptStartInfo)
		RetryAttemptDone  func(TraceRetryAttemptDoneInfo)
		RetryBackoff      func(TraceRetryBackoffInfo)
	}
)
type (
	// TraceGetConnInfo is passed to TraceHooks.GetConn.
	TraceGetConnInfo struct {
		Addr string // host:port the transport wants a connection for
	}

	// TraceGotConnInfo is passed to TraceHooks.GotConn. Its fields mirror
	// httptrace.GotConnInfo.
	TraceGotConnInfo struct {
		Conn     net.Conn
		Reused   bool
		WasIdle  bool
		IdleTime time.Duration
	}

	// TracePutIdleConnInfo is passed to TraceHooks.PutIdleConn. Err is the
	// error reported by httptrace's PutIdleConn callback.
	TracePutIdleConnInfo struct {
		Err error
	}

	// TraceDNSStartInfo is passed to TraceHooks.DNSStart.
	TraceDNSStartInfo struct {
		Host string
	}

	// TraceDNSDoneInfo is passed to TraceHooks.DNSDone. When it is built from
	// an httptrace event, Addrs is a private copy of the resolved addresses.
	TraceDNSDoneInfo struct {
		Addrs     []net.IPAddr
		Coalesced bool
		Err       error
	}

	// TraceConnectStartInfo is passed to TraceHooks.ConnectStart.
	TraceConnectStartInfo struct {
		Network string
		Addr    string
	}

	// TraceConnectDoneInfo is passed to TraceHooks.ConnectDone.
	TraceConnectDoneInfo struct {
		Network string
		Addr    string
		Err     error
	}

	// TraceTLSHandshakeStartInfo is passed to TraceHooks.TLSHandshakeStart.
	// httptrace supplies no address details, so all fields are empty when
	// the event comes from net/http's own handshake.
	TraceTLSHandshakeStartInfo struct {
		Network    string
		Addr       string
		ServerName string
	}

	// TraceTLSHandshakeDoneInfo is passed to TraceHooks.TLSHandshakeDone.
	// Network, Addr and ServerName are empty when the event comes from
	// net/http's own handshake. Only ConnectionState and Err are filled then.
	TraceTLSHandshakeDoneInfo struct {
		Network         string
		Addr            string
		ServerName      string
		ConnectionState tls.ConnectionState
		Err             error
	}

	// TraceWroteHeaderFieldInfo is passed to TraceHooks.WroteHeaderField.
	TraceWroteHeaderFieldInfo struct {
		Key    string
		Values []string
	}

	// TraceWroteRequestInfo is passed to TraceHooks.WroteRequest.
	TraceWroteRequestInfo struct {
		Err error
	}

	// TraceRetryAttemptStartInfo is passed to TraceHooks.RetryAttemptStart.
	TraceRetryAttemptStartInfo struct {
		Attempt     int
		MaxAttempts int
	}

	// TraceRetryAttemptDoneInfo is passed to TraceHooks.RetryAttemptDone.
	// WillRetry says whether another attempt follows this one.
	TraceRetryAttemptDoneInfo struct {
		Attempt     int
		MaxAttempts int
		StatusCode  int
		Err         error
		WillRetry   bool
	}

	// TraceRetryBackoffInfo is passed to TraceHooks.RetryBackoff.
	TraceRetryBackoffInfo struct {
		Attempt int
		Delay   time.Duration
	}
)
// traceState is the tracing state that withTraceState attaches to a context.
// It holds the user's hooks plus two flags. The flags let clientTrace's
// net/http-driven callbacks stay quiet while the package handles DNS or TLS
// itself. All flag accessors are nil-receiver safe.
type traceState struct {
	hooks         *TraceHooks   // user callbacks; newTraceState only builds a state for non-nil hooks
	customTLS     atomic.Uint32 // non-zero once markCustomTLS has run; suppresses httptrace TLS callbacks
	manualDNSRefs atomic.Int32  // in-progress manual DNS lookups; while > 0, httptrace DNS callbacks are suppressed
}
// newTraceState wraps hooks in a fresh traceState. A nil hooks value yields
// a nil *traceState. withTraceState and the nil-safe traceState methods treat
// that as "tracing disabled".
func newTraceState(hooks *TraceHooks) *traceState {
	var state *traceState
	if hooks != nil {
		state = &traceState{hooks: hooks}
	}
	return state
}
// withTraceState returns a child of ctx that carries state under
// traceContextKey. When state is nil, ctx is returned unchanged, so
// disabled tracing adds no context value.
func withTraceState(ctx context.Context, state *traceState) context.Context {
	if state != nil {
		ctx = context.WithValue(ctx, traceContextKey{}, state)
	}
	return ctx
}
// getTraceState returns the *traceState that withTraceState stored in ctx.
// It returns nil when ctx is nil or holds no trace state.
func getTraceState(ctx context.Context) *traceState {
	if ctx != nil {
		if found, ok := ctx.Value(traceContextKey{}).(*traceState); ok {
			return found
		}
	}
	return nil
}
// needsHTTPTrace reports whether at least one hook driven by
// net/http/httptrace is set. The retry hooks are excluded because they are
// emitted directly rather than through httptrace. A nil receiver or nil
// hooks yields false.
func (t *traceState) needsHTTPTrace() bool {
	if t == nil {
		return false
	}
	hooks := t.hooks
	if hooks == nil {
		return false
	}
	switch {
	case hooks.GetConn != nil, hooks.GotConn != nil, hooks.PutIdleConn != nil,
		hooks.DNSStart != nil, hooks.DNSDone != nil,
		hooks.ConnectStart != nil, hooks.ConnectDone != nil,
		hooks.TLSHandshakeStart != nil, hooks.TLSHandshakeDone != nil,
		hooks.WroteHeaderField != nil, hooks.WroteHeaders != nil,
		hooks.WroteRequest != nil, hooks.GotFirstResponseByte != nil:
		return true
	default:
		return false
	}
}
// clientTrace builds an httptrace.ClientTrace that forwards net/http events
// to t.hooks. Only callbacks whose hook is set are installed. It returns nil
// when t is nil or no httptrace-driven hook is set (see needsHTTPTrace).
//
// The DNS callbacks stay silent while a manual DNS lookup is in progress
// (usesManualDNS). The TLS callbacks stay silent once custom TLS has been
// marked (usesCustomTLS). In those cases the package is expected to report
// the events itself via dnsStart/dnsDone and tlsHandshakeStart/
// tlsHandshakeDone.
//
// Slices handed to hooks (DNSDone addresses, WroteHeaderField values) are
// copies. A hook may retain or modify them without aliasing slices owned by
// net/http.
func (t *traceState) clientTrace() *httptrace.ClientTrace {
	if !t.needsHTTPTrace() {
		return nil
	}
	h := t.hooks
	trace := &httptrace.ClientTrace{}
	if h.GetConn != nil {
		trace.GetConn = func(hostPort string) {
			h.GetConn(TraceGetConnInfo{Addr: hostPort})
		}
	}
	if h.GotConn != nil {
		trace.GotConn = func(info httptrace.GotConnInfo) {
			h.GotConn(TraceGotConnInfo{
				Conn:     info.Conn,
				Reused:   info.Reused,
				WasIdle:  info.WasIdle,
				IdleTime: info.IdleTime,
			})
		}
	}
	if h.PutIdleConn != nil {
		trace.PutIdleConn = func(err error) {
			h.PutIdleConn(TracePutIdleConnInfo{Err: err})
		}
	}
	if h.DNSStart != nil {
		trace.DNSStart = func(info httptrace.DNSStartInfo) {
			// A manual lookup reports its own DNS events; avoid duplicates.
			if t.usesManualDNS() {
				return
			}
			h.DNSStart(TraceDNSStartInfo{Host: info.Host})
		}
	}
	if h.DNSDone != nil {
		trace.DNSDone = func(info httptrace.DNSDoneInfo) {
			if t.usesManualDNS() {
				return
			}
			h.DNSDone(TraceDNSDoneInfo{
				Addrs:     append([]net.IPAddr(nil), info.Addrs...),
				Coalesced: info.Coalesced,
				Err:       info.Err,
			})
		}
	}
	if h.ConnectStart != nil {
		trace.ConnectStart = func(network, addr string) {
			h.ConnectStart(TraceConnectStartInfo{Network: network, Addr: addr})
		}
	}
	if h.ConnectDone != nil {
		trace.ConnectDone = func(network, addr string, err error) {
			h.ConnectDone(TraceConnectDoneInfo{Network: network, Addr: addr, Err: err})
		}
	}
	if h.TLSHandshakeStart != nil {
		trace.TLSHandshakeStart = func() {
			// A custom TLS path reports its own handshake events.
			if t.usesCustomTLS() {
				return
			}
			// httptrace provides no network/address/server name here.
			h.TLSHandshakeStart(TraceTLSHandshakeStartInfo{})
		}
	}
	if h.TLSHandshakeDone != nil {
		trace.TLSHandshakeDone = func(state tls.ConnectionState, err error) {
			if t.usesCustomTLS() {
				return
			}
			h.TLSHandshakeDone(TraceTLSHandshakeDoneInfo{
				ConnectionState: state,
				Err:             err,
			})
		}
	}
	if h.WroteHeaderField != nil {
		trace.WroteHeaderField = func(key string, values []string) {
			// Copy, as DNSDone does for Addrs. net/http may pass the
			// request's own header slice, and hooks can run concurrently or
			// retain the value. values[:0:0] keeps nil as nil and empty as
			// empty.
			h.WroteHeaderField(TraceWroteHeaderFieldInfo{
				Key:    key,
				Values: append(values[:0:0], values...),
			})
		}
	}
	if h.WroteHeaders != nil {
		trace.WroteHeaders = h.WroteHeaders
	}
	if h.WroteRequest != nil {
		trace.WroteRequest = func(info httptrace.WroteRequestInfo) {
			h.WroteRequest(TraceWroteRequestInfo{Err: info.Err})
		}
	}
	if h.GotFirstResponseByte != nil {
		trace.GotFirstResponseByte = h.GotFirstResponseByte
	}
	return trace
}
// markCustomTLS sets the custom-TLS flag. From then on, clientTrace's TLS
// handshake callbacks stop forwarding net/http events. A nil receiver is a
// no-op.
func (t *traceState) markCustomTLS() {
	if t != nil {
		t.customTLS.Store(1)
	}
}
// usesCustomTLS reports whether the custom-TLS flag is set. A nil receiver
// reports false.
func (t *traceState) usesCustomTLS() bool {
	return t != nil && t.customTLS.Load() != 0
}
// beginManualDNS records the start of a manual DNS lookup. Calls nest, and
// each one should be paired with endManualDNS. net/http's DNS events stay
// suppressed while the count is positive. A nil receiver is a no-op.
func (t *traceState) beginManualDNS() {
	if t != nil {
		t.manualDNSRefs.Add(1)
	}
}
// endManualDNS records the end of a lookup started with beginManualDNS.
//
// The counter never drops below zero, and an unmatched call is ignored.
// Otherwise a negative balance would cancel a later beginManualDNS, leave
// usesManualDNS false during that lookup, and let net/http's DNS events
// through. A nil receiver is a no-op.
func (t *traceState) endManualDNS() {
	if t == nil {
		return
	}
	// CAS loop so the clamp holds even with concurrent begin/end calls.
	for {
		refs := t.manualDNSRefs.Load()
		if refs <= 0 {
			return
		}
		if t.manualDNSRefs.CompareAndSwap(refs, refs-1) {
			return
		}
	}
}
// usesManualDNS reports whether at least one manual DNS lookup is in
// progress. A nil receiver reports false.
func (t *traceState) usesManualDNS() bool {
	return t != nil && t.manualDNSRefs.Load() > 0
}
// tlsHandshakeStart passes info to the TLSHandshakeStart hook directly,
// bypassing httptrace. It does nothing if t, its hooks, or that hook is nil.
func (t *traceState) tlsHandshakeStart(info TraceTLSHandshakeStartInfo) {
	if t == nil || t.hooks == nil {
		return
	}
	if hook := t.hooks.TLSHandshakeStart; hook != nil {
		hook(info)
	}
}
// tlsHandshakeDone passes info to the TLSHandshakeDone hook directly,
// bypassing httptrace. It does nothing if t, its hooks, or that hook is nil.
func (t *traceState) tlsHandshakeDone(info TraceTLSHandshakeDoneInfo) {
	if t == nil || t.hooks == nil {
		return
	}
	if hook := t.hooks.TLSHandshakeDone; hook != nil {
		hook(info)
	}
}
// dnsStart passes info to the DNSStart hook directly, bypassing httptrace.
// It does nothing if t, its hooks, or that hook is nil.
func (t *traceState) dnsStart(info TraceDNSStartInfo) {
	if t == nil || t.hooks == nil {
		return
	}
	if hook := t.hooks.DNSStart; hook != nil {
		hook(info)
	}
}
// dnsDone passes info to the DNSDone hook directly, bypassing httptrace.
// It does nothing if t, its hooks, or that hook is nil.
func (t *traceState) dnsDone(info TraceDNSDoneInfo) {
	if t == nil || t.hooks == nil {
		return
	}
	if hook := t.hooks.DNSDone; hook != nil {
		hook(info)
	}
}
// emitRetryAttemptStart calls hooks.RetryAttemptStart with info when both
// hooks and that callback are non-nil.
func emitRetryAttemptStart(hooks *TraceHooks, info TraceRetryAttemptStartInfo) {
	if hooks != nil && hooks.RetryAttemptStart != nil {
		hooks.RetryAttemptStart(info)
	}
}
// emitRetryAttemptDone calls hooks.RetryAttemptDone with info when both
// hooks and that callback are non-nil.
func emitRetryAttemptDone(hooks *TraceHooks, info TraceRetryAttemptDoneInfo) {
	if hooks != nil && hooks.RetryAttemptDone != nil {
		hooks.RetryAttemptDone(info)
	}
}
// emitRetryBackoff calls hooks.RetryBackoff with info when both hooks and
// that callback are non-nil.
func emitRetryBackoff(hooks *TraceHooks, info TraceRetryBackoffInfo) {
	if hooks != nil && hooks.RetryBackoff != nil {
		hooks.RetryBackoff(info)
	}
}