122 lines
2.7 KiB
Go
122 lines
2.7 KiB
Go
|
|
package notify
|
||
|
|
|
||
|
|
import (
|
||
|
|
"b612.me/notify/internal/transport"
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"net"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Kinds of clientConnectSource. The constructors in this file store one of
// these in clientConnectSource.kind to record how the connection was obtained.
const (
	// clientConnectSourceConn marks a client built from an already-open net.Conn.
	clientConnectSourceConn = "conn"

	// clientConnectSourceNetwork marks a client that dials a network/address pair.
	clientConnectSourceNetwork = "network"

	// clientConnectSourceTimeout marks a client that dials a network/address
	// pair with a dial timeout.
	clientConnectSourceTimeout = "timeout"

	// clientConnectSourceFactory marks a client whose connections come from a
	// caller-supplied dial function.
	clientConnectSourceFactory = "factory"
)
|
||
|
|
|
||
|
|
// errClientReconnectSourceUnavailable is the error clientConnectSource.dial
// returns when the source is nil or has no dial function to reconnect with.
var (
	errClientReconnectSourceUnavailable = errors.New("client reconnect source is unavailable")
)
|
||
|
|
|
||
|
|
// clientConnectSource records how a client's connection was obtained and,
// when a dial function is present, how to open a replacement connection.
type clientConnectSource struct {
	// kind is set by the constructors to one of the clientConnectSource*
	// constants.
	kind string

	// network and addr describe the endpoint. Either may be empty, e.g. for a
	// factory source or for a conn whose addresses were unavailable.
	network, addr string

	// dialFn opens a new connection; nil means the source cannot reconnect.
	dialFn func(context.Context) (net.Conn, error)
}
|
||
|
|
|
||
|
|
func newClientConnConnectSource(conn net.Conn) *clientConnectSource {
|
||
|
|
source := &clientConnectSource{kind: clientConnectSourceConn}
|
||
|
|
if conn == nil {
|
||
|
|
return source
|
||
|
|
}
|
||
|
|
if remoteAddr := conn.RemoteAddr(); remoteAddr != nil {
|
||
|
|
source.network = remoteAddr.Network()
|
||
|
|
source.addr = remoteAddr.String()
|
||
|
|
}
|
||
|
|
if source.network == "" {
|
||
|
|
if localAddr := conn.LocalAddr(); localAddr != nil {
|
||
|
|
source.network = localAddr.Network()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return source
|
||
|
|
}
|
||
|
|
|
||
|
|
func newClientNetworkConnectSource(network string, addr string) *clientConnectSource {
|
||
|
|
return &clientConnectSource{
|
||
|
|
kind: clientConnectSourceNetwork,
|
||
|
|
network: network,
|
||
|
|
addr: addr,
|
||
|
|
dialFn: func(context.Context) (net.Conn, error) {
|
||
|
|
return transport.Dial(network, addr)
|
||
|
|
},
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func newClientTimeoutConnectSource(network string, addr string, timeout time.Duration) *clientConnectSource {
|
||
|
|
return &clientConnectSource{
|
||
|
|
kind: clientConnectSourceTimeout,
|
||
|
|
network: network,
|
||
|
|
addr: addr,
|
||
|
|
dialFn: func(context.Context) (net.Conn, error) {
|
||
|
|
return transport.DialTimeout(network, addr, timeout)
|
||
|
|
},
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func newClientFactoryConnectSource(dialFn func(context.Context) (net.Conn, error)) *clientConnectSource {
|
||
|
|
return &clientConnectSource{
|
||
|
|
kind: clientConnectSourceFactory,
|
||
|
|
dialFn: dialFn,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *clientConnectSource) clone() *clientConnectSource {
|
||
|
|
if s == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
out := *s
|
||
|
|
return &out
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *clientConnectSource) canReconnect() bool {
|
||
|
|
return s != nil && s.dialFn != nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *clientConnectSource) isUDP() bool {
|
||
|
|
if s == nil {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
return transport.IsUDPNetwork(s.network)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *clientConnectSource) dial(ctx context.Context) (net.Conn, error) {
|
||
|
|
if s == nil || s.dialFn == nil {
|
||
|
|
return nil, errClientReconnectSourceUnavailable
|
||
|
|
}
|
||
|
|
if ctx == nil {
|
||
|
|
ctx = context.Background()
|
||
|
|
}
|
||
|
|
return s.dialFn(ctx)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *ClientCommon) setClientConnectSource(source *clientConnectSource) {
|
||
|
|
if c == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
if source == nil {
|
||
|
|
c.connectSource.Store(nil)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
c.connectSource.Store(source.clone())
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *ClientCommon) clientConnectSourceSnapshot() *clientConnectSource {
|
||
|
|
if c == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
if source := c.connectSource.Load(); source != nil {
|
||
|
|
return source.clone()
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|