notify/client_runtime.go

524 lines
12 KiB
Go
Raw Permalink Normal View History

package notify
import (
"b612.me/stario"
"context"
"errors"
"fmt"
"math"
"net"
"sync/atomic"
"time"
)
// closeClientTransport closes the transport binding currently attached to the
// client, as reported by clientTransportBindingSnapshot. A nil binding is
// ignored by closeClientTransportBinding, so this is safe to call when no
// transport is attached.
func (c *ClientCommon) closeClientTransport() {
	current := c.clientTransportBindingSnapshot()
	c.closeClientTransportBinding(current)
}
// closeClientTransportConn closes conn on behalf of the client. Nothing
// happens when either the client or conn is nil.
//
// The Close error is deliberately discarded: this helper is used on teardown
// paths (binding shutdown, failed bootstrap) where the connection is being
// abandoned anyway and there is no caller to report the error to.
func (c *ClientCommon) closeClientTransportConn(conn net.Conn) {
	if c != nil && conn != nil {
		_ = conn.Close()
	}
}
// closeClientTransportBinding shuts a transport binding down. It first closes
// the binding's current connection and then stops the binding's background
// workers. A nil binding is ignored.
func (c *ClientCommon) closeClientTransportBinding(binding *transportBinding) {
	if binding != nil {
		conn := binding.connSnapshot()
		c.closeClientTransportConn(conn)
		binding.stopBackgroundWorkers()
	}
}
// beginClientSessionEpoch atomically advances the client's session epoch
// counter and returns the new value. Background loops are handed this epoch
// so they can later check (via isClientSessionEpochCurrent) whether they still
// belong to the running session. A nil client yields 0, which is never
// considered current.
func (c *ClientCommon) beginClientSessionEpoch() uint64 {
	var next uint64
	if c != nil {
		next = atomic.AddUint64(&c.sessionEpoch, 1)
	}
	return next
}
// currentClientSessionEpoch atomically reads the epoch most recently issued by
// beginClientSessionEpoch. It returns 0 for a nil client.
func (c *ClientCommon) currentClientSessionEpoch() uint64 {
	var current uint64
	if c != nil {
		current = atomic.LoadUint64(&c.sessionEpoch)
	}
	return current
}
// isClientSessionEpochCurrent reports whether epoch identifies the session the
// client is currently running. The zero epoch and a nil client never match.
func (c *ClientCommon) isClientSessionEpochCurrent(epoch uint64) bool {
	return c != nil && epoch != 0 && epoch == c.currentClientSessionEpoch()
}
// stopClientSessionIfCurrent stops the session with the given reason and cause
// only when epoch still identifies the current session, and reports whether it
// did so. This keeps a loop holding a stale epoch from stopping a newer
// session.
//
// NOTE(review): the epoch check and the stop are two separate steps; if a new
// session can begin between them, the newer session would be stopped. Confirm
// whether markSessionStopped guards against that.
func (c *ClientCommon) stopClientSessionIfCurrent(epoch uint64, reason string, err error) bool {
	isCurrent := c.isClientSessionEpochCurrent(epoch)
	if isCurrent {
		c.stopClientSession(reason, err)
	}
	return isCurrent
}
// setByeFromServer records, while holding c.mu, whether the server has already
// ended the session (see stopClientSessionFromServer). A nil client is
// ignored.
func (c *ClientCommon) setByeFromServer(val bool) {
	if c == nil {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.byeFromServer = val
}
// resetClientStopState clears per-session stop bookkeeping before a new
// session starts. At present that bookkeeping is only the bye-from-server
// flag.
func (c *ClientCommon) resetClientStopState() {
	const serverSaidBye = false
	c.setByeFromServer(serverSaidBye)
}
// shouldSayGoodByeOnStop reports whether the client should send its own "bye"
// message when the session stops. That is the case unless the server has
// already said goodbye. A nil client never says goodbye.
func (c *ClientCommon) shouldSayGoodByeOnStop() bool {
	if c == nil {
		return false
	}
	c.mu.Lock()
	serverSaidBye := c.byeFromServer
	c.mu.Unlock()
	return !serverSaidBye
}
// stopClientSession marks the client session stopped with the given reason and
// cause. A nil client is ignored.
func (c *ClientCommon) stopClientSession(reason string, err error) {
	if c != nil {
		c.markSessionStopped(reason, err)
	}
}
// stopClientSessionFromServer stops the session because the server ended it.
// It raises the bye-from-server flag first, which makes shouldSayGoodByeOnStop
// report false so no redundant goodbye is sent. It then marks the session
// stopped. A nil client is ignored.
func (c *ClientCommon) stopClientSessionFromServer(reason string, err error) {
	if c != nil {
		c.setByeFromServer(true)
		c.markSessionStopped(reason, err)
	}
}
// beginClientConnectAttempt reserves the right to start a new session. It
// fails with "client already run" when beginClientSessionStart refuses.
//
// On success it returns a finish callback that reports whether the start
// succeeded. A failed start (success == false) is rolled back through
// cleanupFailedClientStart; a successful one is left as is. Callers in this
// file invoke the callback once, from a defer.
func (c *ClientCommon) beginClientConnectAttempt() (func(success bool), error) {
	if ok := c.beginClientSessionStart(); !ok {
		return nil, errors.New("client already run")
	}
	finish := func(success bool) {
		if !success {
			c.cleanupFailedClientStart()
		}
	}
	return finish, nil
}
// clientCanAttachTransport reports whether a new connection can be attached to
// an existing session instead of starting a fresh one. All of the following
// must hold:
//   - the session is alive;
//   - no transport is currently attached;
//   - the session runtime exists and carries both a stop context and a
//     message queue.
func (c *ClientCommon) clientCanAttachTransport() bool {
	if c == nil {
		return false
	}
	switch {
	case !sessionIsAlive(&c.alive):
		return false
	case c.clientTransportAttachedSnapshot():
		return false
	}
	rt := c.clientSessionRuntimeSnapshot()
	return rt != nil && rt.stopCtx != nil && rt.queue != nil
}
// attachClientWithConnSource attaches conn as the transport of the already
// running session. It then bootstraps that transport: key exchange unless
// skipped, then the peer identity announcement. A bootstrap failure does not
// stop the session.
//
// If attaching fails, conn is closed here. Bootstrap failures are cleaned up
// by failClientTransportBootstrap. The connect source is recorded only after
// the bootstrap succeeds.
func (c *ClientCommon) attachClientWithConnSource(conn net.Conn, source *clientConnectSource) error {
	switch {
	case c == nil:
		return errors.New("client is nil")
	case conn == nil:
		return errors.New("conn is nil")
	}
	if attachErr := c.attachClientSessionTransport(conn); attachErr != nil {
		_ = conn.Close()
		return attachErr
	}
	rt := c.clientSessionRuntimeSnapshot()
	if bootErr := c.bootstrapClientTransportRuntime(rt, true, false); bootErr != nil {
		return bootErr
	}
	c.setClientConnectSource(source)
	return nil
}
// Connect dials addr on the given network and runs the client over the
// resulting connection.
//
// If a session is alive but has no transport attached, the new connection is
// attached to that session. Otherwise a new session is started, and an error
// is returned if the client is already running.
func (c *ClientCommon) Connect(network string, addr string) error {
	if err := c.validateSecurityConfiguration(); err != nil {
		return err
	}
	source := newClientNetworkConnectSource(network, addr)
	c.applySignalReliabilityTransportDefault(source.isUDP())
	return c.connectClientThroughSource(source, func() (net.Conn, error) {
		return source.dial(nil)
	})
}

// ConnectTimeout behaves like Connect, but dials through a connect source
// built with timeout (presumably bounding the dial; see
// newClientTimeoutConnectSource).
func (c *ClientCommon) ConnectTimeout(network string, addr string, timeout time.Duration) error {
	if err := c.validateSecurityConfiguration(); err != nil {
		return err
	}
	source := newClientTimeoutConnectSource(network, addr, timeout)
	c.applySignalReliabilityTransportDefault(source.isUDP())
	return c.connectClientThroughSource(source, func() (net.Conn, error) {
		return source.dial(nil)
	})
}

// ConnectByConn runs the client over an already established conn, either by
// attaching it to a live session or by starting a new session, as Connect
// does. conn must not be nil. If security validation fails or the client is
// already running, conn is left untouched and the error is returned.
func (c *ClientCommon) ConnectByConn(conn net.Conn) error {
	if err := c.validateSecurityConfiguration(); err != nil {
		return err
	}
	if conn == nil {
		return errors.New("conn is nil")
	}
	source := newClientConnConnectSource(conn)
	c.applySignalReliabilityTransportDefault(false)
	return c.connectClientThroughSource(source, func() (net.Conn, error) {
		return conn, nil
	})
}

// ConnectByFactory runs the client over a connection obtained from
// dialFn(ctx), either by attaching it to a live session or by starting a new
// session, as Connect does.
//
// A nil ctx is treated as context.Background(). A nil dialFn is an error, and
// so is a nil connection returned with a nil error.
func (c *ClientCommon) ConnectByFactory(ctx context.Context, dialFn func(context.Context) (net.Conn, error)) error {
	if err := c.validateSecurityConfiguration(); err != nil {
		return err
	}
	if dialFn == nil {
		return errors.New("dialFn is nil")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	source := newClientFactoryConnectSource(dialFn)
	// Apply the transport default up front, as the other Connect* entry points
	// do. Previously this happened before the dial on the attach path but only
	// after a successful dial on the start path.
	c.applySignalReliabilityTransportDefault(false)
	return c.connectClientThroughSource(source, func() (net.Conn, error) {
		conn, err := dialFn(ctx)
		if err != nil {
			return nil, err
		}
		if conn == nil {
			return nil, errors.New("conn is nil")
		}
		return conn, nil
	})
}

// connectClientThroughSource is the connect flow shared by the Connect* entry
// points.
//
// When the live session can accept a transport (clientCanAttachTransport), the
// connection produced by dial is attached to it. Otherwise a session start is
// reserved with beginClientConnectAttempt, dial is invoked, and a new session
// is started over the connection. The reservation is rolled back through the
// finish callback if any step after it fails.
func (c *ClientCommon) connectClientThroughSource(source *clientConnectSource, dial func() (net.Conn, error)) error {
	if c.clientCanAttachTransport() {
		conn, err := dial()
		if err != nil {
			return err
		}
		return c.attachClientWithConnSource(conn, source)
	}
	finish, err := c.beginClientConnectAttempt()
	if err != nil {
		return err
	}
	started := false
	defer func() {
		finish(started)
	}()
	conn, err := dial()
	if err != nil {
		return err
	}
	if err := c.startClientWithConnSource(conn, source); err != nil {
		return err
	}
	started = true
	return nil
}
// startClientWithConn starts a new session over conn. It uses a connect source
// that wraps conn itself.
func (c *ClientCommon) startClientWithConn(conn net.Conn) error {
	source := newClientConnConnectSource(conn)
	return c.startClientWithConnSource(conn, source)
}
// startClientWithConnSource starts a brand-new client session over conn. The
// steps run in this order:
//  1. create a cancellable session stop context;
//  2. open a new session epoch;
//  3. create the message queue, bound to that context;
//  4. record source as the connect source;
//  5. install the new session runtime;
//  6. clear leftover stop state;
//  7. mark the session started;
//  8. hand the runtime to clientPostInit, whose error is returned.
//
// The cancel function is kept in the runtime; it is not called here.
func (c *ClientCommon) startClientWithConnSource(conn net.Conn, source *clientConnectSource) error {
	sessionCtx, cancelSession := context.WithCancel(context.Background())
	sessionEpoch := c.beginClientSessionEpoch()
	msgQueue := stario.NewQueueCtx(sessionCtx, 4, math.MaxUint32)
	c.setClientConnectSource(source)
	sessionRT := newClientSessionRuntime(conn, sessionCtx, cancelSession, msgQueue, sessionEpoch)
	c.setClientSessionRuntime(sessionRT)
	c.resetClientStopState()
	c.markSessionStarted()
	return c.clientPostInit(sessionRT)
}
// monitorPool runs monitorPoolLoop on the calling goroutine. It uses the
// client's current stop context and returns when that loop returns.
func (c *ClientCommon) monitorPool() {
	stopCtx := c.clientStopContextSnapshot()
	c.monitorPoolLoop(stopCtx)
}
// monitorPoolLoop periodically expires stale entries in the pending-wait pool
// until stopCtx is cancelled. Every 30 seconds it calls cleanupExpired with
// noFinSyncMsgMaxKeepSeconds as the keep window.
//
// On cancellation, the pending-wait, file-ack and signal-ack pools are closed,
// but only if stopCtx is still the client's current stop context. A loop left
// over from an earlier session therefore does not tear down a newer session's
// pools. A nil stopCtx returns immediately.
func (c *ClientCommon) monitorPoolLoop(stopCtx context.Context) {
	if stopCtx == nil {
		return
	}
	const sweepInterval = 30 * time.Second
	sweep := time.NewTimer(sweepInterval)
	defer sweep.Stop()
	for {
		select {
		case <-sweep.C:
			now := time.Now()
			c.getPendingWaitPool().cleanupExpired(int64(c.noFinSyncMsgMaxKeepSeconds), now)
			// The channel was just drained, so Reset is safe; the next wait
			// starts after the sweep, matching a fresh time.After per round.
			sweep.Reset(sweepInterval)
		case <-stopCtx.Done():
			if c.clientStopContextSnapshot() == stopCtx {
				c.getPendingWaitPool().closeAll()
				c.getFileAckPool().closeAll()
				c.getSignalAckPool().closeAll()
			}
			return
		}
	}
}
// clientPostInit finishes starting a new session runtime. It proceeds in
// three steps:
//  1. launch the pool monitor, bound to the session stop context;
//  2. start the transport goroutines;
//  3. run the full bootstrap (key exchange plus peer announcement), stopping
//     the session if that bootstrap fails.
//
// A nil runtime is a no-op.
func (c *ClientCommon) clientPostInit(rt *clientSessionRuntime) error {
	if rt == nil {
		return nil
	}
	go c.monitorPoolLoop(rt.stopCtx)
	err := c.startClientTransportRuntime(rt)
	if err == nil {
		err = c.bootstrapClientTransportRuntime(rt, true, true)
	}
	return err
}
// startClientTransportRuntime launches the background goroutines that serve
// rt's transport:
//   - the heartbeat loop, only when useHeartBeat is set;
//   - the read loop;
//   - the message-dispatch loop.
//
// The heartbeat and read loops stop on rt.transportStopCtx, falling back to
// the session-wide rt.stopCtx when no transport-specific context is set. A nil
// runtime is a no-op; the returned error is currently always nil.
func (c *ClientCommon) startClientTransportRuntime(rt *clientSessionRuntime) error {
	if rt == nil {
		return nil
	}
	loopCtx := rt.transportStopCtx
	if loopCtx == nil {
		loopCtx = rt.stopCtx
	}
	if c.useHeartBeat {
		go c.heartbeatLoop(loopCtx, rt.epoch)
	}
	go c.readMessageLoop(loopCtx, rt.conn, rt.queue, rt.epoch)
	go c.loadMessageLoop(rt)
	return nil
}
// bootstrapClientTransportRuntime performs the handshake on a freshly started
// or attached transport. It runs the key exchange when runKeyExchange is set
// and skipKeyExchange is not, then announces the client's peer identity.
//
// Any failure goes through failClientTransportBootstrap with a short reason
// string. stopSessionOnFailure decides whether the whole session is stopped or
// only the transport is dropped. A nil runtime is a no-op.
func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime, runKeyExchange bool, stopSessionOnFailure bool) error {
	if rt == nil {
		return nil
	}
	fail := func(reason string, cause error) error {
		return c.failClientTransportBootstrap(rt, stopSessionOnFailure, reason, cause)
	}
	if runKeyExchange && !c.skipKeyExchange {
		if kexErr := c.keyExchangeFn(c); kexErr != nil {
			return fail("key exchange failed", kexErr)
		}
	}
	if announceErr := c.announceClientPeerIdentity(); announceErr != nil {
		return fail("peer attach failed", announceErr)
	}
	return nil
}
// failClientTransportBootstrap cleans up after a failed bootstrap and returns
// err unchanged. It retires rt and closes rt.conn. Then it does one of two
// things:
//   - if stopSessionOnFailure is set, it stops the session, but only if rt's
//     epoch is still current;
//   - otherwise it clears the runtime's transport, presumably so the session
//     can accept another transport later.
//
// With a nil client or runtime it does nothing but return err.
func (c *ClientCommon) failClientTransportBootstrap(rt *clientSessionRuntime, stopSessionOnFailure bool, reason string, err error) error {
	if c == nil || rt == nil {
		return err
	}
	c.retireClientSessionRuntime(rt, true)
	c.closeClientTransportConn(rt.conn)
	if stopSessionOnFailure {
		c.stopClientSessionIfCurrent(rt.epoch, reason, err)
	} else {
		c.clearClientSessionRuntimeTransport()
	}
	return err
}
// Heartbeat runs the heartbeat loop for the current session runtime on the
// calling goroutine and returns when that loop stops. It does nothing if no
// runtime is installed.
//
// The loop stops on the runtime's transport stop context, or on the session
// stop context when none is set. If the runtime's epoch is zero, the client's
// current epoch is used instead.
func (c *ClientCommon) Heartbeat() {
	rt := c.clientSessionRuntimeSnapshot()
	if rt == nil {
		return
	}
	stopCtx := rt.transportStopCtx
	if stopCtx == nil {
		stopCtx = rt.stopCtx
	}
	epoch := rt.epoch
	if epoch == 0 {
		epoch = c.currentClientSessionEpoch()
	}
	c.heartbeatLoop(stopCtx, epoch)
}
// heartbeatLoop sends a heartbeat every c.heartbeatPeriod until stopCtx is
// cancelled, or until handleHeartbeatResultWithSession asks it to stop.
// handleHeartbeatResultWithSession receives each send result together with the
// running consecutive-failure count it maintains. The period is re-read before
// every wait. A nil stopCtx returns immediately.
//
// NOTE(review): a non-positive heartbeatPeriod makes the timer fire
// immediately, so heartbeats would be sent back-to-back. Confirm the period is
// validated elsewhere.
func (c *ClientCommon) heartbeatLoop(stopCtx context.Context, epoch uint64) {
	if stopCtx == nil {
		return
	}
	misses := 0
	wait := time.NewTimer(c.heartbeatPeriod)
	defer wait.Stop()
	for {
		select {
		case <-stopCtx.Done():
			return
		case <-wait.C:
		}
		sendErr := c.sendHeartbeat()
		var done bool
		misses, done = c.handleHeartbeatResultWithSession(epoch, sendErr, misses)
		if done {
			return
		}
		// wait.C was drained above, so Reset starts a clean period.
		wait.Reset(c.heartbeatPeriod)
	}
}
// readMessage runs the read loop for the current session runtime on the
// calling goroutine. It does nothing if no runtime is installed.
//
// The loop stops on the runtime's transport stop context, or on the session
// stop context when none is set. If the runtime's epoch is zero, the client's
// current epoch is used instead.
func (c *ClientCommon) readMessage() {
	rt := c.clientSessionRuntimeSnapshot()
	if rt == nil {
		return
	}
	stopCtx := rt.transportStopCtx
	if stopCtx == nil {
		stopCtx = rt.stopCtx
	}
	epoch := rt.epoch
	if epoch == 0 {
		epoch = c.currentClientSessionEpoch()
	}
	c.readMessageLoop(stopCtx, rt.conn, rt.queue, epoch)
}
// readMessageLoop reads from conn through a transport binding that also wraps
// queue. Each read result is passed to
// handleTransportReadResultWithSessionDispatcher, together with the inbound
// dispatcher captured at start-up. The loop ends when that handler returns
// false or stopCtx is cancelled. Only on cancellation does it close the
// binding it created. A nil stopCtx returns immediately.
func (c *ClientCommon) readMessageLoop(stopCtx context.Context, conn net.Conn, queue *stario.StarQueue, epoch uint64) {
	if stopCtx == nil {
		return
	}
	var (
		binding    = newTransportBinding(conn, queue)
		dispatcher = c.clientInboundDispatcherSnapshot()
		readBuf    = streamReadBuffer()
	)
	for keepReading := true; keepReading; {
		// Non-blocking cancellation check before each (blocking) read.
		select {
		case <-stopCtx.Done():
			c.closeClientTransportBinding(binding)
			return
		default:
		}
		n, payload, readErr := c.readFromTransportBindingWithBuffer(binding, readBuf)
		keepReading = c.handleTransportReadResultWithSessionDispatcher(stopCtx, conn, queue, n, payload, readErr, epoch, dispatcher)
	}
}
// sayGoodBye sends the system "bye" message (ID 10010, type MSG_SYS_WAIT) via
// sendWait with a three-second timeout. It returns sendWait's error and
// discards the reply.
func (c *ClientCommon) sayGoodBye() error {
	const byeTimeout = 3 * time.Second
	bye := TransferMsg{
		ID:    10010,
		Key:   "bye",
		Value: nil,
		Type:  MSG_SYS_WAIT,
	}
	_, err := c.sendWait(bye, byeTimeout)
	return err
}
// loadMessage runs the message-dispatch loop for the current session runtime
// on the calling goroutine. It does nothing if no runtime is installed.
func (c *ClientCommon) loadMessage() {
	if rt := c.clientSessionRuntimeSnapshot(); rt != nil {
		c.loadMessageLoop(rt)
	}
}
// loadMessageLoop drains messages from the runtime's queue and hands each one
// to the inbound dispatcher, which decodes it with
// dispatchInboundTransportPayload. It runs until the stop context is
// cancelled.
//
// Inputs are resolved as follows:
//   - Stop context: rt.transportStopCtx, falling back to rt.stopCtx.
//   - Queue: rt.transport's queue when a transport binding is present,
//     otherwise rt.queue.
//   - Dispatcher: rt.inboundDispatcher; if rt has none, a private dispatcher
//     is created for this loop and closed (and waited on) when the loop
//     returns.
//
// The loop returns immediately if rt, the stop context or the queue is nil.
//
// On stop, if the whole session (rt.stopCtx) is also cancelled:
//   - rt's shared inbound dispatcher is closed and drained;
//   - a "bye" message is sent, unless the runtime suppresses it or the server
//     already said goodbye.
//
// In every case, the runtime's transport binding is then closed. Each
// dispatched task is tracked in c.wg.
func (c *ClientCommon) loadMessageLoop(rt *clientSessionRuntime) {
	if rt == nil {
		return
	}
	stopCtx := rt.transportStopCtx
	if stopCtx == nil {
		stopCtx = rt.stopCtx
	}
	if stopCtx == nil {
		return
	}
	queue := rt.queue
	if rt.transport != nil {
		queue = rt.transport.queueSnapshot()
	}
	if queue == nil {
		return
	}
	dispatcher := rt.inboundDispatcher
	if dispatcher == nil {
		dispatcher = newInboundDispatcher()
		defer dispatcher.CloseAndWait()
	}
	queueClosed := false
	for {
		// A receive from a closed channel is always ready. Selecting on it
		// after the queue closed would spin this loop at full CPU until stopCtx
		// is cancelled. Once the close has been observed, select on a nil
		// channel instead, which never becomes ready, and wait only for stop.
		restoreCh := queue.RestoreChan()
		if queueClosed {
			restoreCh = nil
		}
		select {
		case <-stopCtx.Done():
			sessionStopping := rt.stopCtx != nil && rt.stopCtx.Err() != nil
			if sessionStopping && rt.inboundDispatcher != nil {
				rt.inboundDispatcher.CloseAndWait()
			}
			if sessionStopping && !rt.runtimeShouldSuppressGoodByeOnStop() && c.shouldSayGoodByeOnStop() {
				// Best effort: the session is shutting down regardless, so a
				// failed goodbye has nowhere useful to be reported.
				_ = c.sayGoodBye()
			}
			c.closeClientTransportBinding(rt.transport)
			return
		case data, ok := <-restoreCh:
			if !ok {
				queueClosed = true
				continue
			}
			msg := data
			c.wg.Add(1)
			dispatched := dispatcher.Dispatch(clientInboundDispatchSource(), func() {
				defer c.wg.Done()
				now := time.Now()
				if err := c.dispatchInboundTransportPayload(msg.Msg, now); err != nil {
					if c.showError || c.debugMode {
						fmt.Println("client decode envelope error", err)
					}
				}
			})
			if !dispatched {
				// The task will never run, so release its wait-group slot here.
				c.wg.Done()
			}
		}
	}
}