- 新增 managed/external/nested 三种传输保护模式 - 新增 peer attach 显式认证、抗重放、channel binding 和可选前向保密协商 - 明确单连接注入与可重拨连接源的语义边界 - 禁止 ConnectByConn 场景下 dedicated bulk 走 sidecar,auto 模式自动回退 shared - 修正 dedicated attach 在 bootstrap/steady profile 切换下的处理逻辑 - 优化 shared bulk super-batch 与批量 framed write 路径 - 降低 stream/bulk fast path 的复制和分发损耗 - 补齐 benchmark、回归测试、运行时快照和 README 文档
542 lines
12 KiB
Go
542 lines
12 KiB
Go
package notify
|
|
|
|
import (
|
|
"b612.me/stario"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"net"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
func (c *ClientCommon) closeClientTransport() {
|
|
c.closeClientTransportBinding(c.clientTransportBindingSnapshot())
|
|
}
|
|
|
|
func (c *ClientCommon) closeClientTransportConn(conn net.Conn) {
|
|
if c == nil || conn == nil {
|
|
return
|
|
}
|
|
_ = conn.Close()
|
|
}
|
|
|
|
func (c *ClientCommon) closeClientTransportBinding(binding *transportBinding) {
|
|
if binding == nil {
|
|
return
|
|
}
|
|
c.closeClientTransportConn(binding.connSnapshot())
|
|
binding.stopBackgroundWorkers()
|
|
}
|
|
|
|
func (c *ClientCommon) beginClientSessionEpoch() uint64 {
|
|
if c == nil {
|
|
return 0
|
|
}
|
|
return atomic.AddUint64(&c.sessionEpoch, 1)
|
|
}
|
|
|
|
func (c *ClientCommon) currentClientSessionEpoch() uint64 {
|
|
if c == nil {
|
|
return 0
|
|
}
|
|
return atomic.LoadUint64(&c.sessionEpoch)
|
|
}
|
|
|
|
func (c *ClientCommon) isClientSessionEpochCurrent(epoch uint64) bool {
|
|
if c == nil || epoch == 0 {
|
|
return false
|
|
}
|
|
return c.currentClientSessionEpoch() == epoch
|
|
}
|
|
|
|
func (c *ClientCommon) stopClientSessionIfCurrent(epoch uint64, reason string, err error) bool {
|
|
if !c.isClientSessionEpochCurrent(epoch) {
|
|
return false
|
|
}
|
|
c.stopClientSession(reason, err)
|
|
return true
|
|
}
|
|
|
|
func (c *ClientCommon) setByeFromServer(val bool) {
|
|
if c == nil {
|
|
return
|
|
}
|
|
c.mu.Lock()
|
|
c.byeFromServer = val
|
|
c.mu.Unlock()
|
|
}
|
|
|
|
func (c *ClientCommon) resetClientStopState() {
|
|
c.setByeFromServer(false)
|
|
}
|
|
|
|
func (c *ClientCommon) shouldSayGoodByeOnStop() bool {
|
|
if c == nil {
|
|
return false
|
|
}
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return !c.byeFromServer
|
|
}
|
|
|
|
func (c *ClientCommon) stopClientSession(reason string, err error) {
|
|
if c == nil {
|
|
return
|
|
}
|
|
c.markSessionStopped(reason, err)
|
|
}
|
|
|
|
func (c *ClientCommon) stopClientSessionFromServer(reason string, err error) {
|
|
if c == nil {
|
|
return
|
|
}
|
|
c.setByeFromServer(true)
|
|
c.markSessionStopped(reason, err)
|
|
}
|
|
|
|
func (c *ClientCommon) beginClientConnectAttempt() (func(success bool), error) {
|
|
if !c.beginClientSessionStart() {
|
|
return nil, errors.New("client already run")
|
|
}
|
|
return func(success bool) {
|
|
if success {
|
|
return
|
|
}
|
|
c.cleanupFailedClientStart()
|
|
}, nil
|
|
}
|
|
|
|
func (c *ClientCommon) clientCanAttachTransport() bool {
|
|
if c == nil {
|
|
return false
|
|
}
|
|
if !sessionIsAlive(&c.alive) {
|
|
return false
|
|
}
|
|
if c.clientTransportAttachedSnapshot() {
|
|
return false
|
|
}
|
|
rt := c.clientSessionRuntimeSnapshot()
|
|
if rt == nil {
|
|
return false
|
|
}
|
|
return rt.stopCtx != nil && rt.queue != nil
|
|
}
|
|
|
|
func (c *ClientCommon) attachClientWithConnSource(conn net.Conn, source *clientConnectSource) error {
|
|
if c == nil {
|
|
return errors.New("client is nil")
|
|
}
|
|
if conn == nil {
|
|
return errors.New("conn is nil")
|
|
}
|
|
if err := c.attachClientSessionTransport(conn); err != nil {
|
|
_ = conn.Close()
|
|
return err
|
|
}
|
|
if err := c.bootstrapClientTransportRuntime(c.clientSessionRuntimeSnapshot(), true, false); err != nil {
|
|
return err
|
|
}
|
|
c.setClientConnectSource(source)
|
|
return nil
|
|
}
|
|
|
|
func (c *ClientCommon) Connect(network string, addr string) error {
|
|
if err := c.validateSecurityConfiguration(); err != nil {
|
|
return err
|
|
}
|
|
source := newClientNetworkConnectSource(network, addr)
|
|
c.applySignalReliabilityTransportDefault(source.isUDP())
|
|
if c.clientCanAttachTransport() {
|
|
conn, err := source.dial(nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return c.attachClientWithConnSource(conn, source)
|
|
}
|
|
finish, err := c.beginClientConnectAttempt()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
started := false
|
|
defer func() {
|
|
finish(started)
|
|
}()
|
|
conn, err := source.dial(nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := c.startClientWithConnSource(conn, source); err != nil {
|
|
return err
|
|
}
|
|
started = true
|
|
return nil
|
|
}
|
|
|
|
func (c *ClientCommon) ConnectTimeout(network string, addr string, timeout time.Duration) error {
|
|
if err := c.validateSecurityConfiguration(); err != nil {
|
|
return err
|
|
}
|
|
source := newClientTimeoutConnectSource(network, addr, timeout)
|
|
c.applySignalReliabilityTransportDefault(source.isUDP())
|
|
if c.clientCanAttachTransport() {
|
|
conn, err := source.dial(nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return c.attachClientWithConnSource(conn, source)
|
|
}
|
|
finish, err := c.beginClientConnectAttempt()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
started := false
|
|
defer func() {
|
|
finish(started)
|
|
}()
|
|
conn, err := source.dial(nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := c.startClientWithConnSource(conn, source); err != nil {
|
|
return err
|
|
}
|
|
started = true
|
|
return nil
|
|
}
|
|
|
|
func (c *ClientCommon) ConnectByConn(conn net.Conn) error {
|
|
if err := c.validateSecurityConfiguration(); err != nil {
|
|
return err
|
|
}
|
|
if conn == nil {
|
|
return errors.New("conn is nil")
|
|
}
|
|
source := newClientConnConnectSource(conn)
|
|
c.applySignalReliabilityTransportDefault(false)
|
|
if c.clientCanAttachTransport() {
|
|
return c.attachClientWithConnSource(conn, source)
|
|
}
|
|
finish, err := c.beginClientConnectAttempt()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
started := false
|
|
defer func() {
|
|
finish(started)
|
|
}()
|
|
if err := c.startClientWithConnSource(conn, source); err != nil {
|
|
return err
|
|
}
|
|
started = true
|
|
return nil
|
|
}
|
|
|
|
func (c *ClientCommon) ConnectByFactory(ctx context.Context, dialFn func(context.Context) (net.Conn, error)) error {
|
|
if err := c.validateSecurityConfiguration(); err != nil {
|
|
return err
|
|
}
|
|
if dialFn == nil {
|
|
return errors.New("dialFn is nil")
|
|
}
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
source := newClientFactoryConnectSource(dialFn)
|
|
if c.clientCanAttachTransport() {
|
|
c.applySignalReliabilityTransportDefault(false)
|
|
conn, err := dialFn(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if conn == nil {
|
|
return errors.New("conn is nil")
|
|
}
|
|
return c.attachClientWithConnSource(conn, source)
|
|
}
|
|
finish, err := c.beginClientConnectAttempt()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
started := false
|
|
defer func() {
|
|
finish(started)
|
|
}()
|
|
conn, err := dialFn(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if conn == nil {
|
|
return errors.New("conn is nil")
|
|
}
|
|
c.applySignalReliabilityTransportDefault(false)
|
|
if err := c.startClientWithConnSource(conn, source); err != nil {
|
|
return err
|
|
}
|
|
started = true
|
|
return nil
|
|
}
|
|
|
|
func (c *ClientCommon) startClientWithConn(conn net.Conn) error {
|
|
return c.startClientWithConnSource(conn, newClientConnConnectSource(conn))
|
|
}
|
|
|
|
func (c *ClientCommon) startClientWithConnSource(conn net.Conn, source *clientConnectSource) error {
|
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
|
epoch := c.beginClientSessionEpoch()
|
|
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
|
|
c.setClientConnectSource(source)
|
|
rt := newClientSessionRuntime(conn, stopCtx, stopFn, queue, epoch)
|
|
c.setClientSessionRuntime(rt)
|
|
c.resetClientStopState()
|
|
c.markSessionStarted()
|
|
return c.clientPostInit(rt)
|
|
}
|
|
|
|
func (c *ClientCommon) monitorPool() {
|
|
c.monitorPoolLoop(c.clientStopContextSnapshot())
|
|
}
|
|
|
|
func (c *ClientCommon) monitorPoolLoop(stopCtx context.Context) {
|
|
if stopCtx == nil {
|
|
return
|
|
}
|
|
for {
|
|
select {
|
|
case <-stopCtx.Done():
|
|
if c.clientStopContextSnapshot() == stopCtx {
|
|
c.getPendingWaitPool().closeAll()
|
|
c.getFileAckPool().closeAll()
|
|
c.getSignalAckPool().closeAll()
|
|
}
|
|
return
|
|
case <-time.After(time.Second * 30):
|
|
}
|
|
now := time.Now()
|
|
c.getPendingWaitPool().cleanupExpired(int64(c.noFinSyncMsgMaxKeepSeconds), now)
|
|
}
|
|
}
|
|
|
|
func (c *ClientCommon) clientPostInit(rt *clientSessionRuntime) error {
|
|
if rt == nil {
|
|
return nil
|
|
}
|
|
go c.monitorPoolLoop(rt.stopCtx)
|
|
if err := c.startClientTransportRuntime(rt); err != nil {
|
|
return err
|
|
}
|
|
return c.bootstrapClientTransportRuntime(rt, true, true)
|
|
}
|
|
|
|
func (c *ClientCommon) startClientTransportRuntime(rt *clientSessionRuntime) error {
|
|
if rt == nil {
|
|
return nil
|
|
}
|
|
transportStopCtx := rt.transportStopCtx
|
|
if transportStopCtx == nil {
|
|
transportStopCtx = rt.stopCtx
|
|
}
|
|
if c.useHeartBeat {
|
|
go c.heartbeatLoop(transportStopCtx, rt.epoch)
|
|
}
|
|
go c.readMessageLoop(transportStopCtx, rt.conn, rt.queue, rt.epoch)
|
|
go c.loadMessageLoop(rt)
|
|
return nil
|
|
}
|
|
|
|
func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime, runKeyExchange bool, stopSessionOnFailure bool) error {
|
|
if rt == nil {
|
|
return nil
|
|
}
|
|
c.resetClientPeerAttachAuth()
|
|
c.activateClientBootstrapTransportProtection()
|
|
if runKeyExchange && !c.skipKeyExchange {
|
|
if err := c.keyExchangeFn(c); err != nil {
|
|
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "key exchange failed", err)
|
|
}
|
|
}
|
|
if err := c.announceClientPeerIdentity(); err != nil {
|
|
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "peer attach failed", err)
|
|
}
|
|
c.activateClientSteadyTransportProtection()
|
|
return nil
|
|
}
|
|
|
|
func (c *ClientCommon) failClientTransportBootstrap(rt *clientSessionRuntime, stopSessionOnFailure bool, reason string, err error) error {
|
|
if c == nil || rt == nil {
|
|
return err
|
|
}
|
|
c.retireClientSessionRuntime(rt, true)
|
|
c.closeClientTransportConn(rt.conn)
|
|
if stopSessionOnFailure {
|
|
c.stopClientSessionIfCurrent(rt.epoch, reason, err)
|
|
return err
|
|
}
|
|
c.clearClientSessionRuntimeTransport()
|
|
return err
|
|
}
|
|
|
|
func (c *ClientCommon) Heartbeat() {
|
|
rt := c.clientSessionRuntimeSnapshot()
|
|
if rt == nil {
|
|
return
|
|
}
|
|
epoch := rt.epoch
|
|
if epoch == 0 {
|
|
epoch = c.currentClientSessionEpoch()
|
|
}
|
|
transportStopCtx := rt.transportStopCtx
|
|
if transportStopCtx == nil {
|
|
transportStopCtx = rt.stopCtx
|
|
}
|
|
c.heartbeatLoop(transportStopCtx, epoch)
|
|
}
|
|
|
|
func (c *ClientCommon) heartbeatLoop(stopCtx context.Context, epoch uint64) {
|
|
if stopCtx == nil {
|
|
return
|
|
}
|
|
failedCount := 0
|
|
for {
|
|
select {
|
|
case <-stopCtx.Done():
|
|
return
|
|
case <-time.After(c.heartbeatPeriod):
|
|
}
|
|
err := c.sendHeartbeat()
|
|
var stop bool
|
|
failedCount, stop = c.handleHeartbeatResultWithSession(epoch, err, failedCount)
|
|
if stop {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *ClientCommon) readMessage() {
|
|
rt := c.clientSessionRuntimeSnapshot()
|
|
if rt == nil {
|
|
return
|
|
}
|
|
epoch := rt.epoch
|
|
if epoch == 0 {
|
|
epoch = c.currentClientSessionEpoch()
|
|
}
|
|
transportStopCtx := rt.transportStopCtx
|
|
if transportStopCtx == nil {
|
|
transportStopCtx = rt.stopCtx
|
|
}
|
|
c.readMessageLoop(transportStopCtx, rt.conn, rt.queue, epoch)
|
|
}
|
|
|
|
func (c *ClientCommon) readMessageLoop(stopCtx context.Context, conn net.Conn, queue *stario.StarQueue, epoch uint64) {
|
|
if stopCtx == nil {
|
|
return
|
|
}
|
|
binding := newTransportBinding(conn, queue)
|
|
dispatcher := c.clientInboundDispatcherSnapshot()
|
|
if conn != nil && queue != nil && !isPacketTransportConn(conn) {
|
|
reader := newTransportFrameReader(conn, queue)
|
|
for {
|
|
select {
|
|
case <-stopCtx.Done():
|
|
c.closeClientTransportBinding(binding)
|
|
return
|
|
default:
|
|
}
|
|
payload, release, err := c.readTransportPayloadPooled(conn, reader)
|
|
if !c.handleTransportPayloadReadResultWithSession(stopCtx, binding, payload, release, err, epoch, dispatcher) {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
buf := streamReadBuffer()
|
|
for {
|
|
select {
|
|
case <-stopCtx.Done():
|
|
c.closeClientTransportBinding(binding)
|
|
return
|
|
default:
|
|
}
|
|
readNum, data, err := c.readFromTransportBindingWithBuffer(binding, buf)
|
|
if !c.handleTransportReadResultWithSessionDispatcher(stopCtx, conn, queue, readNum, data, err, epoch, dispatcher) {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *ClientCommon) sayGoodBye() error {
|
|
_, err := c.sendWait(TransferMsg{
|
|
ID: 10010,
|
|
Key: "bye",
|
|
Value: nil,
|
|
Type: MSG_SYS_WAIT,
|
|
}, time.Second*3)
|
|
return err
|
|
}
|
|
|
|
func (c *ClientCommon) loadMessage() {
|
|
rt := c.clientSessionRuntimeSnapshot()
|
|
if rt == nil {
|
|
return
|
|
}
|
|
c.loadMessageLoop(rt)
|
|
}
|
|
|
|
func (c *ClientCommon) loadMessageLoop(rt *clientSessionRuntime) {
|
|
if rt == nil {
|
|
return
|
|
}
|
|
stopCtx := rt.transportStopCtx
|
|
if stopCtx == nil {
|
|
stopCtx = rt.stopCtx
|
|
}
|
|
if stopCtx == nil {
|
|
return
|
|
}
|
|
queue := rt.queue
|
|
if rt.transport != nil {
|
|
queue = rt.transport.queueSnapshot()
|
|
}
|
|
if queue == nil {
|
|
return
|
|
}
|
|
dispatcher := rt.inboundDispatcher
|
|
if dispatcher == nil {
|
|
dispatcher = newInboundDispatcher()
|
|
defer dispatcher.CloseAndWait()
|
|
}
|
|
for {
|
|
select {
|
|
case <-stopCtx.Done():
|
|
sessionStopping := rt.stopCtx != nil && rt.stopCtx.Err() != nil
|
|
if sessionStopping && rt.inboundDispatcher != nil {
|
|
rt.inboundDispatcher.CloseAndWait()
|
|
}
|
|
if sessionStopping && !rt.runtimeShouldSuppressGoodByeOnStop() && c.shouldSayGoodByeOnStop() {
|
|
c.sayGoodBye()
|
|
}
|
|
c.closeClientTransportBinding(rt.transport)
|
|
return
|
|
case data, ok := <-queue.RestoreChan():
|
|
if !ok {
|
|
continue
|
|
}
|
|
msg := data
|
|
c.wg.Add(1)
|
|
if !dispatcher.Dispatch(clientInboundDispatchSource(), func() {
|
|
defer c.wg.Done()
|
|
now := time.Now()
|
|
if err := c.dispatchInboundTransportPayload(msg.Msg, now); err != nil {
|
|
if c.showError || c.debugMode {
|
|
fmt.Println("client decode envelope error", err)
|
|
}
|
|
}
|
|
}) {
|
|
c.wg.Done()
|
|
}
|
|
}
|
|
}
|
|
}
|