package notify

import (
	"context"
	"errors"
	"net"
	"time"

	"b612.me/notify/internal/transport"
)

// Kinds of connect source, recorded in clientConnectSource.kind.
const (
	clientConnectSourceConn    = "conn"
	clientConnectSourceNetwork = "network"
	clientConnectSourceTimeout = "timeout"
	clientConnectSourceFactory = "factory"
)

// errClientReconnectSourceUnavailable is returned by dial when the source is
// nil or carries no dial function.
var errClientReconnectSourceUnavailable = errors.New("client reconnect source is unavailable")

// clientConnectSource describes where a client connection came from and, when
// possible, how to open a fresh one.
type clientConnectSource struct {
	kind    string // one of the clientConnectSource* constants
	network string // network name, e.g. as reported by net.Addr.Network
	addr    string // address string; empty when unknown
	// dialFn opens a new connection; nil means the source cannot reconnect.
	dialFn func(context.Context) (net.Conn, error)
}

// newClientConnConnectSource records the addressing of an already-established
// conn. The result has no dial function, so it cannot be used to reconnect.
// The network is taken from the remote address, falling back to the local
// address when the remote one yields no network name.
func newClientConnConnectSource(conn net.Conn) *clientConnectSource {
	src := &clientConnectSource{kind: clientConnectSourceConn}
	if conn == nil {
		return src
	}
	if remote := conn.RemoteAddr(); remote != nil {
		src.network, src.addr = remote.Network(), remote.String()
	}
	if src.network != "" {
		return src
	}
	if local := conn.LocalAddr(); local != nil {
		src.network = local.Network()
	}
	return src
}

// newClientNetworkConnectSource builds a source that reconnects via
// transport.Dial(network, addr). The context passed to its dial function is
// not forwarded to transport.Dial.
func newClientNetworkConnectSource(network string, addr string) *clientConnectSource {
	dialer := func(context.Context) (net.Conn, error) {
		return transport.Dial(network, addr)
	}
	return &clientConnectSource{
		kind:    clientConnectSourceNetwork,
		network: network,
		addr:    addr,
		dialFn:  dialer,
	}
}

// newClientTimeoutConnectSource builds a source that reconnects via
// transport.DialTimeout(network, addr, timeout). The context passed to its
// dial function is not forwarded to transport.DialTimeout.
func newClientTimeoutConnectSource(network string, addr string, timeout time.Duration) *clientConnectSource {
	dialer := func(context.Context) (net.Conn, error) {
		return transport.DialTimeout(network, addr, timeout)
	}
	return &clientConnectSource{
		kind:    clientConnectSourceTimeout,
		network: network,
		addr:    addr,
		dialFn:  dialer,
	}
}

// newClientFactoryConnectSource wraps a caller-supplied dial function. Network
// and address are left empty because they are not known here.
func newClientFactoryConnectSource(dialFn func(context.Context) (net.Conn, error)) *clientConnectSource {
	return &clientConnectSource{kind: clientConnectSourceFactory, dialFn: dialFn}
}

// clone returns a shallow copy of s, or nil when s is nil.
func (s *clientConnectSource) clone() *clientConnectSource {
	if s == nil {
		return nil
	}
	dup := *s
	return &dup
}

// canReconnect reports whether s has a dial function to open new connections.
func (s *clientConnectSource) canReconnect() bool {
	if s == nil {
		return false
	}
	return s.dialFn != nil
}

// isUDP reports whether the recorded network is a UDP network according to
// transport.IsUDPNetwork. A nil source is never UDP.
func (s *clientConnectSource) isUDP() bool {
	return s != nil && transport.IsUDPNetwork(s.network)
}
func (s *clientConnectSource) dial(ctx context.Context) (net.Conn, error) { if s == nil || s.dialFn == nil { return nil, errClientReconnectSourceUnavailable } if ctx == nil { ctx = context.Background() } return s.dialFn(ctx) } func (c *ClientCommon) setClientConnectSource(source *clientConnectSource) { if c == nil { return } if source == nil { c.connectSource.Store(nil) return } c.connectSource.Store(source.clone()) } func (c *ClientCommon) clientConnectSourceSnapshot() *clientConnectSource { if c == nil { return nil } if source := c.connectSource.Load(); source != nil { return source.clone() } return nil }