package notify import ( "b612.me/stario" "context" "errors" "fmt" "math" "net" "sync/atomic" "time" ) func (c *ClientCommon) closeClientTransport() { c.closeClientTransportBinding(c.clientTransportBindingSnapshot()) } func (c *ClientCommon) closeClientTransportConn(conn net.Conn) { if c == nil || conn == nil { return } _ = conn.Close() } func (c *ClientCommon) closeClientTransportBinding(binding *transportBinding) { if binding == nil { return } c.closeClientTransportConn(binding.connSnapshot()) binding.stopBackgroundWorkers() } func (c *ClientCommon) beginClientSessionEpoch() uint64 { if c == nil { return 0 } return atomic.AddUint64(&c.sessionEpoch, 1) } func (c *ClientCommon) currentClientSessionEpoch() uint64 { if c == nil { return 0 } return atomic.LoadUint64(&c.sessionEpoch) } func (c *ClientCommon) isClientSessionEpochCurrent(epoch uint64) bool { if c == nil || epoch == 0 { return false } return c.currentClientSessionEpoch() == epoch } func (c *ClientCommon) stopClientSessionIfCurrent(epoch uint64, reason string, err error) bool { if !c.isClientSessionEpochCurrent(epoch) { return false } c.stopClientSession(reason, err) return true } func (c *ClientCommon) setByeFromServer(val bool) { if c == nil { return } c.mu.Lock() c.byeFromServer = val c.mu.Unlock() } func (c *ClientCommon) resetClientStopState() { c.setByeFromServer(false) } func (c *ClientCommon) shouldSayGoodByeOnStop() bool { if c == nil { return false } c.mu.Lock() defer c.mu.Unlock() return !c.byeFromServer } func (c *ClientCommon) stopClientSession(reason string, err error) { if c == nil { return } c.markSessionStopped(reason, err) } func (c *ClientCommon) stopClientSessionFromServer(reason string, err error) { if c == nil { return } c.setByeFromServer(true) c.markSessionStopped(reason, err) } func (c *ClientCommon) beginClientConnectAttempt() (func(success bool), error) { if !c.beginClientSessionStart() { return nil, errors.New("client already run") } return func(success bool) { if success { return } c.cleanupFailedClientStart() }, nil } func (c *ClientCommon) clientCanAttachTransport() bool { if c == nil { return false } if !sessionIsAlive(&c.alive) { return false } if c.clientTransportAttachedSnapshot() { return false } rt := c.clientSessionRuntimeSnapshot() if rt == nil { return false } return rt.stopCtx != nil && rt.queue != nil } func (c *ClientCommon) attachClientWithConnSource(conn net.Conn, source *clientConnectSource) error { if c == nil { return errors.New("client is nil") } if conn == nil { return errors.New("conn is nil") } if err := c.attachClientSessionTransport(conn); err != nil { _ = conn.Close() return err } if err := c.bootstrapClientTransportRuntime(c.clientSessionRuntimeSnapshot(), true, false); err != nil { return err } c.setClientConnectSource(source) return nil } func (c *ClientCommon) Connect(network string, addr string) error { if err := c.validateSecurityConfiguration(); err != nil { return err } source := newClientNetworkConnectSource(network, addr) c.applySignalReliabilityTransportDefault(source.isUDP()) if c.clientCanAttachTransport() { conn, err := source.dial(nil) if err != nil { return err } return c.attachClientWithConnSource(conn, source) } finish, err := c.beginClientConnectAttempt() if err != nil { return err } started := false defer func() { finish(started) }() conn, err := source.dial(nil) if err != nil { return err } if err := c.startClientWithConnSource(conn, source); err != nil { return err } started = true return nil } func (c *ClientCommon) ConnectTimeout(network string, addr string, timeout time.Duration) error { if err := c.validateSecurityConfiguration(); err != nil { return err } source := newClientTimeoutConnectSource(network, addr, timeout) c.applySignalReliabilityTransportDefault(source.isUDP()) if c.clientCanAttachTransport() { conn, err := source.dial(nil) if err != nil { return err } return c.attachClientWithConnSource(conn, source) } finish, err := c.beginClientConnectAttempt() if err != nil { return err } started := false defer func() { finish(started) }() conn, err := source.dial(nil) if err != nil { return err } if err := c.startClientWithConnSource(conn, source); err != nil { return err } started = true return nil } func (c *ClientCommon) ConnectByConn(conn net.Conn) error { if err := c.validateSecurityConfiguration(); err != nil { return err } if conn == nil { return errors.New("conn is nil") } source := newClientConnConnectSource(conn) c.applySignalReliabilityTransportDefault(false) if c.clientCanAttachTransport() { return c.attachClientWithConnSource(conn, source) } finish, err := c.beginClientConnectAttempt() if err != nil { return err } started := false defer func() { finish(started) }() if err := c.startClientWithConnSource(conn, source); err != nil { return err } started = true return nil } func (c *ClientCommon) ConnectByFactory(ctx context.Context, dialFn func(context.Context) (net.Conn, error)) error { if err := c.validateSecurityConfiguration(); err != nil { return err } if dialFn == nil { return errors.New("dialFn is nil") } if ctx == nil { ctx = context.Background() } source := newClientFactoryConnectSource(dialFn) if c.clientCanAttachTransport() { c.applySignalReliabilityTransportDefault(false) conn, err := dialFn(ctx) if err != nil { return err } if conn == nil { return errors.New("conn is nil") } return c.attachClientWithConnSource(conn, source) } finish, err := c.beginClientConnectAttempt() if err != nil { return err } started := false defer func() { finish(started) }() conn, err := dialFn(ctx) if err != nil { return err } if conn == nil { return errors.New("conn is nil") } c.applySignalReliabilityTransportDefault(false) if err := c.startClientWithConnSource(conn, source); err != nil { return err } started = true return nil } func (c *ClientCommon) startClientWithConn(conn net.Conn) error { return c.startClientWithConnSource(conn, newClientConnConnectSource(conn)) } func (c *ClientCommon) startClientWithConnSource(conn net.Conn, source *clientConnectSource) error { stopCtx, stopFn := context.WithCancel(context.Background()) epoch := c.beginClientSessionEpoch() queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) c.setClientConnectSource(source) rt := newClientSessionRuntime(conn, stopCtx, stopFn, queue, epoch) c.setClientSessionRuntime(rt) c.resetClientStopState() c.markSessionStarted() return c.clientPostInit(rt) } func (c *ClientCommon) monitorPool() { c.monitorPoolLoop(c.clientStopContextSnapshot()) } func (c *ClientCommon) monitorPoolLoop(stopCtx context.Context) { if stopCtx == nil { return } for { select { case <-stopCtx.Done(): if c.clientStopContextSnapshot() == stopCtx { c.getPendingWaitPool().closeAll() c.getFileAckPool().closeAll() c.getSignalAckPool().closeAll() } return case <-time.After(time.Second * 30): } now := time.Now() c.getPendingWaitPool().cleanupExpired(int64(c.noFinSyncMsgMaxKeepSeconds), now) } } func (c *ClientCommon) clientPostInit(rt *clientSessionRuntime) error { if rt == nil { return nil } go c.monitorPoolLoop(rt.stopCtx) if err := c.startClientTransportRuntime(rt); err != nil { return err } return c.bootstrapClientTransportRuntime(rt, true, true) } func (c *ClientCommon) startClientTransportRuntime(rt *clientSessionRuntime) error { if rt == nil { return nil } transportStopCtx := rt.transportStopCtx if transportStopCtx == nil { transportStopCtx = rt.stopCtx } if c.useHeartBeat { go c.heartbeatLoop(transportStopCtx, rt.epoch) } go c.readMessageLoop(transportStopCtx, rt.conn, rt.queue, rt.epoch) go c.loadMessageLoop(rt) return nil } func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime, runKeyExchange bool, stopSessionOnFailure bool) error { if rt == nil { return nil } c.resetClientPeerAttachAuth() c.activateClientBootstrapTransportProtection() if runKeyExchange && !c.skipKeyExchange { if err := c.keyExchangeFn(c); err != nil { return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "key exchange failed", err) } } if err := c.announceClientPeerIdentity(); err != nil { return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "peer attach failed", err) } c.activateClientSteadyTransportProtection() return nil } func (c *ClientCommon) failClientTransportBootstrap(rt *clientSessionRuntime, stopSessionOnFailure bool, reason string, err error) error { if c == nil || rt == nil { return err } c.retireClientSessionRuntime(rt, true) c.closeClientTransportConn(rt.conn) if stopSessionOnFailure { c.stopClientSessionIfCurrent(rt.epoch, reason, err) return err } c.clearClientSessionRuntimeTransport() return err } func (c *ClientCommon) Heartbeat() { rt := c.clientSessionRuntimeSnapshot() if rt == nil { return } epoch := rt.epoch if epoch == 0 { epoch = c.currentClientSessionEpoch() } transportStopCtx := rt.transportStopCtx if transportStopCtx == nil { transportStopCtx = rt.stopCtx } c.heartbeatLoop(transportStopCtx, epoch) } func (c *ClientCommon) heartbeatLoop(stopCtx context.Context, epoch uint64) { if stopCtx == nil { return } failedCount := 0 for { select { case <-stopCtx.Done(): return case <-time.After(c.heartbeatPeriod): } err := c.sendHeartbeat() var stop bool failedCount, stop = c.handleHeartbeatResultWithSession(epoch, err, failedCount) if stop { return } } } func (c *ClientCommon) readMessage() { rt := c.clientSessionRuntimeSnapshot() if rt == nil { return } epoch := rt.epoch if epoch == 0 { epoch = c.currentClientSessionEpoch() } transportStopCtx := rt.transportStopCtx if transportStopCtx == nil { transportStopCtx = rt.stopCtx } c.readMessageLoop(transportStopCtx, rt.conn, rt.queue, epoch) } func (c *ClientCommon) readMessageLoop(stopCtx context.Context, conn net.Conn, queue *stario.StarQueue, epoch uint64) { if stopCtx == nil { return } binding := newTransportBinding(conn, queue) dispatcher := c.clientInboundDispatcherSnapshot() if conn != nil && queue != nil && !isPacketTransportConn(conn) { reader := newTransportFrameReader(conn, queue) for { select { case <-stopCtx.Done(): c.closeClientTransportBinding(binding) return default: } payload, release, err := c.readTransportPayloadPooled(conn, reader) if !c.handleTransportPayloadReadResultWithSession(stopCtx, binding, payload, release, err, epoch, dispatcher) { return } } } buf := streamReadBuffer() for { select { case <-stopCtx.Done(): c.closeClientTransportBinding(binding) return default: } readNum, data, err := c.readFromTransportBindingWithBuffer(binding, buf) if !c.handleTransportReadResultWithSessionDispatcher(stopCtx, conn, queue, readNum, data, err, epoch, dispatcher) { return } } } func (c *ClientCommon) sayGoodBye() error { _, err := c.sendWait(TransferMsg{ ID: 10010, Key: "bye", Value: nil, Type: MSG_SYS_WAIT, }, time.Second*3) return err } func (c *ClientCommon) loadMessage() { rt := c.clientSessionRuntimeSnapshot() if rt == nil { return } c.loadMessageLoop(rt) } func (c *ClientCommon) loadMessageLoop(rt *clientSessionRuntime) { if rt == nil { return } stopCtx := rt.transportStopCtx if stopCtx == nil { stopCtx = rt.stopCtx } if stopCtx == nil { return } queue := rt.queue if rt.transport != nil { queue = rt.transport.queueSnapshot() } if queue == nil { return } dispatcher := rt.inboundDispatcher if dispatcher == nil { dispatcher = newInboundDispatcher() defer dispatcher.CloseAndWait() } for { select { case <-stopCtx.Done(): sessionStopping := rt.stopCtx != nil && rt.stopCtx.Err() != nil if sessionStopping && rt.inboundDispatcher != nil { rt.inboundDispatcher.CloseAndWait() } if sessionStopping && !rt.runtimeShouldSuppressGoodByeOnStop() && c.shouldSayGoodByeOnStop() { c.sayGoodBye() } c.closeClientTransportBinding(rt.transport) return case data, ok := <-queue.RestoreChan(): if !ok { continue } msg := data c.wg.Add(1) if !dispatcher.Dispatch(clientInboundDispatchSource(), func() { defer c.wg.Done() now := time.Now() if err := c.dispatchInboundTransportPayload(msg.Msg, now); err != nil { if c.showError || c.debugMode { fmt.Println("client decode envelope error", err) } } }) { c.wg.Done() } } } }