package notify

import (
	"net"
	"sync"
	"time"

	"b612.me/stario"
)

// transportBinding models the currently attached physical transport for a
// logical session. The binding can be swapped later without forcing callers to
// reach into raw conn fields directly.
//
// Besides the raw conn/queue pair, a binding owns the mutex that serializes
// writes made through withConnWriteLock/withConnWriteLockDeadline, plus three
// lazily created batch senders (control, stream, bulk), each guarded by its own
// mutex.
type transportBinding struct {
	conn  net.Conn
	queue *stario.StarQueue

	// writeMu is held for the whole duration of withConnWriteLockDeadline.
	writeMu    sync.Mutex
	adaptiveTx adaptiveTxState

	// controlMu guards controlSender.
	controlMu     sync.Mutex
	controlSender *controlBatchSender

	// streamMu guards streamSender.
	streamMu     sync.Mutex
	streamSender *streamBatchSender

	// bulkMu guards bulkSender.
	bulkMu     sync.Mutex
	bulkSender *bulkBatchSender
}

// newTransportBinding wraps conn and queue in a new binding. It returns nil
// when both conn and queue are nil, so "no transport" is represented by a nil
// binding rather than an empty one.
func newTransportBinding(conn net.Conn, queue *stario.StarQueue) *transportBinding {
	if conn == nil && queue == nil {
		return nil
	}
	return &transportBinding{
		conn:  conn,
		queue: queue,
	}
}

// connSnapshot returns the bound net.Conn, or nil when b is nil.
func (b *transportBinding) connSnapshot() net.Conn {
	if b == nil {
		return nil
	}
	return b.conn
}

// queueSnapshot returns the bound queue, or nil when b is nil.
func (b *transportBinding) queueSnapshot() *stario.StarQueue {
	if b == nil {
		return nil
	}
	return b.queue
}

// withConnWriteLock runs fn with the bound conn while holding writeMu, leaving
// the connection's write deadline untouched. It is withConnWriteLockDeadline
// with a zero deadline.
func (b *transportBinding) withConnWriteLock(fn func(net.Conn) error) error {
	return b.withConnWriteLockDeadline(time.Time{}, fn)
}

// withConnWriteLockDeadline runs fn with the bound conn while holding writeMu,
// so writers that go through this method never run concurrently.
//
// If deadline is non-zero it is installed with SetWriteDeadline before fn
// runs and reset to the zero time after fn returns (still under writeMu,
// because deferred calls run in LIFO order before the unlock).
//
// Returns net.ErrClosed when b or its conn is nil, the SetWriteDeadline error
// if installing the deadline fails, and otherwise whatever fn returns.
func (b *transportBinding) withConnWriteLockDeadline(deadline time.Time, fn func(net.Conn) error) error {
	if b == nil {
		return net.ErrClosed
	}
	b.writeMu.Lock()
	defer b.writeMu.Unlock()

	conn := b.connSnapshot()
	if conn == nil {
		return net.ErrClosed
	}
	if !deadline.IsZero() {
		if err := conn.SetWriteDeadline(deadline); err != nil {
			return err
		}
		// The reset is best-effort: its error is deliberately dropped so it
		// can never replace the result of fn, which is what the caller sees.
		defer func() { _ = conn.SetWriteDeadline(time.Time{}) }()
	}
	return fn(conn)
}

// bulkBatchSenderSnapshotWithCodec returns the binding's bulk batch sender,
// creating it on first use from codec and writeTimeout. Once a sender exists,
// later calls return that same sender and their codec/writeTimeout arguments
// are ignored. Returns nil when b is nil.
func (b *transportBinding) bulkBatchSenderSnapshotWithCodec(codec bulkBatchCodec, writeTimeout func() time.Duration) *bulkBatchSender {
	if b == nil {
		return nil
	}
	b.bulkMu.Lock()
	defer b.bulkMu.Unlock()
	if b.bulkSender != nil {
		return b.bulkSender
	}
	b.bulkSender = newBulkBatchSender(b, codec, writeTimeout)
	return b.bulkSender
}

// clientBulkBatchSenderSnapshot returns the bulk batch sender configured with
// the client's pooled bulk-fast encoders and its maxWriteTimeoutSnapshot.
// Returns nil when b or c is nil.
func (b *transportBinding) clientBulkBatchSenderSnapshot(c *ClientCommon) *bulkBatchSender {
	if b == nil || c == nil {
		return nil
	}
	return b.bulkBatchSenderSnapshotWithCodec(bulkBatchCodec{
		encodeSingle: c.encodeBulkFastPayloadPooled,
		encodeBatch:  c.encodeBulkFastBatchPayloadPooled,
	}, c.maxWriteTimeoutSnapshot)
}

// serverBulkBatchSenderSnapshot returns the bulk batch sender configured to
// encode through the owning *ServerCommon, with logical passed to every encode
// call, and to use logical's maxWriteTimeoutSnapshot. Returns nil when b or
// logical is nil, or when logical.Server() is not a non-nil *ServerCommon.
func (b *transportBinding) serverBulkBatchSenderSnapshot(logical *LogicalConn) *bulkBatchSender {
	if b == nil || logical == nil {
		return nil
	}
	server := logical.Server()
	common, ok := server.(*ServerCommon)
	// The nil check also rejects a typed-nil *ServerCommon stored in the interface.
	if !ok || common == nil {
		return nil
	}
	return b.bulkBatchSenderSnapshotWithCodec(bulkBatchCodec{
		encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) {
			return common.encodeBulkFastPayloadLogicalPooled(logical, frame)
		},
		encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) {
			return common.encodeBulkFastBatchPayloadLogicalPooled(logical, frames)
		},
	}, logical.maxWriteTimeoutSnapshot)
}

// controlBatchSenderSnapshot returns the binding's control batch sender,
// creating it with newControlBatchSender on first use. Returns nil when b is
// nil.
func (b *transportBinding) controlBatchSenderSnapshot() *controlBatchSender {
	if b == nil {
		return nil
	}
	b.controlMu.Lock()
	defer b.controlMu.Unlock()
	if b.controlSender != nil {
		return b.controlSender
	}
	b.controlSender = newControlBatchSender(b)
	return b.controlSender
}

// streamBatchSenderSnapshotWithCodec returns the binding's stream batch
// sender, creating it on first use from codec and writeTimeout. Once a sender
// exists, later calls return that same sender and their codec/writeTimeout
// arguments are ignored. Returns nil when b is nil.
func (b *transportBinding) streamBatchSenderSnapshotWithCodec(codec streamBatchCodec, writeTimeout func() time.Duration) *streamBatchSender {
	if b == nil {
		return nil
	}
	b.streamMu.Lock()
	defer b.streamMu.Unlock()
	if b.streamSender != nil {
		return b.streamSender
	}
	b.streamSender = newStreamBatchSender(b, codec, writeTimeout)
	return b.streamSender
}

// clientStreamBatchSenderSnapshot returns the stream batch sender configured
// with the client's fast-stream encoders and its maxWriteTimeoutSnapshot.
// Returns nil when b or c is nil.
func (b *transportBinding) clientStreamBatchSenderSnapshot(c *ClientCommon) *streamBatchSender {
	if b == nil || c == nil {
		return nil
	}
	return b.streamBatchSenderSnapshotWithCodec(streamBatchCodec{
		encodeSingle: c.encodeFastStreamPayload,
		encodeBatch:  c.encodeFastStreamBatchPayload,
	}, c.maxWriteTimeoutSnapshot)
}

// serverStreamBatchSenderSnapshot returns the stream batch sender configured
// to encode through the owning *ServerCommon, with logical passed to every
// encode call, and to use logical's maxWriteTimeoutSnapshot. Returns nil when
// b or logical is nil, or when logical.Server() is not a non-nil
// *ServerCommon.
func (b *transportBinding) serverStreamBatchSenderSnapshot(logical *LogicalConn) *streamBatchSender {
	if b == nil || logical == nil {
		return nil
	}
	server := logical.Server()
	common, ok := server.(*ServerCommon)
	// The nil check also rejects a typed-nil *ServerCommon stored in the interface.
	if !ok || common == nil {
		return nil
	}
	return b.streamBatchSenderSnapshotWithCodec(streamBatchCodec{
		encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
			return common.encodeFastStreamPayloadLogical(logical, frame)
		},
		encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
			return common.encodeFastStreamBatchPayloadLogical(logical, frames)
		},
	}, logical.maxWriteTimeoutSnapshot)
}

// stopBackgroundWorkers calls stop on each of the control, stream and bulk
// senders that has been created so far. Each sender pointer is read under its
// own mutex, and every stop call happens after all three mutexes have been
// released, so stop never runs while a binding mutex is held.
//
// The stopped senders remain assigned to the binding, so later *Snapshot calls
// return those same instances rather than creating new ones.
//
// NOTE(review): a sender first created concurrently, after its pointer was
// read here, is not stopped by this call. Confirm callers never request
// senders once teardown has begun.
func (b *transportBinding) stopBackgroundWorkers() {
	if b == nil {
		return
	}
	b.controlMu.Lock()
	controlSender := b.controlSender
	b.controlMu.Unlock()

	b.streamMu.Lock()
	streamSender := b.streamSender
	b.streamMu.Unlock()

	b.bulkMu.Lock()
	bulkSender := b.bulkSender
	b.bulkMu.Unlock()

	if controlSender != nil {
		controlSender.stop()
	}
	if streamSender != nil {
		streamSender.stop()
	}
	if bulkSender != nil {
		bulkSender.stop()
	}
}