// notify/client_conn.go

package notify
import (
"b612.me/starcrypto"
"fmt"
"net"
"sync/atomic"
"time"
)
// clientConnTransportDetachState is the record kept for the most recent
// transport detach of a ClientConn: the transport generation current at the
// time, the caller-supplied reason, the error text (left empty when no error
// was supplied) and the moment the record was created.
type clientConnTransportDetachState struct {
	Generation  uint64
	Reason, Err string
	At          time.Time
}
// Normalized detach kinds produced by classifyClientConnTransportDetachReason
// from the free-form reason string stored in a detach record.
const (
	// clientConnTransportDetachKindHeartbeatTimeout classifies the
	// "heartbeat timeout" reason.
	clientConnTransportDetachKindHeartbeatTimeout = "heartbeat_timeout"
	// clientConnTransportDetachKindOther classifies any other non-empty reason.
	clientConnTransportDetachKindOther = "other"
	// clientConnTransportDetachKindReadError classifies the "read error" reason.
	clientConnTransportDetachKindReadError = "read_error"
)
// ClientConn is the handle for one client connection.
// NOTE(review): "held by a Server for one client" is inferred from the server
// back-reference and the ClientID/ClientAddr fields — confirm against the
// code that constructs ClientConn.
//
// Mutable state lives behind atomic fields. Several methods (Server,
// markClientConnIdentityBound, markClientConnStreamTransport,
// markClientConnTransportAttached, markClientConnTransportDetached,
// clearClientConnTransportDetachState) delegate to the LogicalConn stored in
// logicalView when one is present, and use the fields below only otherwise.
type ClientConn struct {
// alive is an untyped atomic value; presumably the liveness flag read by
// clientConnAliveSnapshot (not shown here) — TODO confirm the stored type.
alive atomic.Value
// status presumably backs clientConnStatusSnapshot / Status() — confirm.
status Status
// logicalView, when non-nil, is the LogicalConn that many methods delegate to.
logicalView atomic.Pointer[LogicalConn]
// logicalState presumably backs ensureLogicalConnState — confirm.
logicalState atomic.Pointer[logicalConnState]
// runtimeState: usage not visible in this section of the file.
runtimeState atomic.Pointer[logicalConnRuntimeState]
// transportState presumably backs ensureClientConnTransportState (transport
// generation, attach/detach counters, last detach record) — confirm.
transportState atomic.Pointer[clientConnTransportState]
// sessionRuntime presumably backs clientConnSessionRuntimeSnapshot — confirm.
sessionRuntime atomic.Pointer[clientConnSessionRuntime]
// attachment: usage not visible in this section of the file.
attachment atomic.Pointer[clientConnAttachmentState]
// identityBound is set by markClientConnIdentityBound when neither a
// LogicalConn nor a logical state object is available.
identityBound atomic.Bool
// ClientID and ClientAddr are exported identification fields; where they are
// populated is not visible in this section of the file.
ClientID string
ClientAddr net.Addr
// server is the owning Server. It is the fallback for Server(), the target of
// the "bye" exchange in sayGoodByeForTU, and the source of the detached-client
// keep window.
server Server
}
// Status is the value returned by ClientConn.Status.
// NOTE(review): the field meanings below are inferred from their names —
// confirm against the code that writes ClientConn status.
type Status struct {
// Alive presumably reports whether the connection is still alive.
Alive bool
// Reason presumably describes why the status last changed (for example,
// why the connection stopped).
Reason string
// Err presumably carries the error associated with Reason, if any.
Err error
}
// readTUMessage runs the transport read loop for c. When c is backed by a
// LogicalConn the work is handed to it. Otherwise the current session runtime
// snapshot, if there is one, drives readTUMessageLoop.
func (c *ClientConn) readTUMessage() {
	logical := c.LogicalConn()
	if logical == nil {
		if rt := c.clientConnSessionRuntimeSnapshot(); rt != nil {
			c.readTUMessageLoop(rt)
		}
		return
	}
	logical.readTUMessage()
}
// readTUMessageLoop reads from rt's transport connection and passes each read
// result to handleTUTransportReadResultWithSession, until that handler returns
// false or the session's stop channel fires.
//
// The stop context is rt.transportStopCtx, falling back to rt.stopCtx. With
// neither set, the loop does not run. When the stop channel fires, the
// connection is closed if shouldCloseClientConnTransportOnStop allows it.
// Once the loop has been entered, closeClientConnSessionRuntimeTransportDone(rt)
// runs on every exit path.
// NOTE(review): the early returns (nil rt, no stop context) skip that done
// signal — confirm no caller waits on it in those cases.
//
// When c is backed by a LogicalConn, the call is delegated to it instead.
func (c *ClientConn) readTUMessageLoop(rt *clientConnSessionRuntime) {
	if logical := c.LogicalConn(); logical != nil {
		logical.readTUMessageLoop(rt)
		return
	}
	if rt == nil {
		return
	}
	ctx := rt.transportStopCtx
	if ctx == nil {
		ctx = rt.stopCtx
	}
	if ctx == nil {
		return
	}
	transport, gen := rt.tuConn, rt.transportGeneration
	defer closeClientConnSessionRuntimeTransportDone(rt)
	scratch := streamReadBuffer()
	for {
		// Non-blocking stop check before every read.
		select {
		case <-sessionStopChan(ctx):
			if c.shouldCloseClientConnTransportOnStop(transport) {
				_ = transport.Close() // best-effort close on shutdown
			}
			return
		default:
		}
		n, payload, readErr := c.readFromTUTransportConnWithBuffer(transport, scratch)
		if keepGoing := c.handleTUTransportReadResultWithSession(ctx, transport, gen, n, payload, readErr); !keepGoing {
			return
		}
	}
}
// rsaDecode handles a legacy key-change request. It decodes the handshake RSA
// private key (clientConnHandshakeRsaKeySnapshot, empty passphrase) and
// RSA-decrypts message.Value with it. On any error it prints the error and
// replies "failed". On success it replies "success" and then installs the
// decrypted bytes as c's secret key.
// NOTE(review): the reply is sent before the new key is installed, presumably
// so the acknowledgement still travels under the old key — do not reorder
// without checking the peer.
//
// Deprecated: rsaDecode exists only for the legacy MSG_KEY_CHANGE flow.
func (c *ClientConn) rsaDecode(message Message) {
	fail := func(cause error) {
		fmt.Println(cause)
		message.Reply([]byte("failed"))
	}
	key, err := starcrypto.DecodeRsaPrivateKey(c.clientConnHandshakeRsaKeySnapshot(), "")
	if err != nil {
		fail(err)
		return
	}
	plain, err := starcrypto.RSADecrypt(key, message.Value)
	if err != nil {
		fail(err)
		return
	}
	message.Reply([]byte("success"))
	c.setClientConnSecretKey(plain)
}
// sayGoodByeForTU sends a "bye" message to the peer and waits up to three
// seconds for the reply. It first tries SendWaitLogical on c's LogicalConn. If
// that fails, it retries as a raw MSG_SYS_WAIT TransferMsg (ID 10010) through
// sendWait.
//
// It returns errTransportDetached when c or its server is nil, nil when the
// first attempt succeeds, and otherwise the fallback attempt's error (the
// first attempt's error is discarded).
func (c *ClientConn) sayGoodByeForTU() error {
	if c == nil || c.server == nil {
		return errTransportDetached
	}
	const byeTimeout = 3 * time.Second
	if _, err := c.server.SendWaitLogical(c.LogicalConn(), "bye", nil, byeTimeout); err == nil {
		return nil
	}
	fallback := TransferMsg{ID: 10010, Key: "bye", Value: nil, Type: MSG_SYS_WAIT}
	_, err := c.server.sendWait(c, fallback, byeTimeout)
	return err
}
// GetSecretKey returns the per-connection transport key reported by
// clientConnSecretKeySnapshot.
func (c *ClientConn) GetSecretKey() []byte {
	key := c.clientConnSecretKeySnapshot()
	return key
}

// SetSecretKey installs key as c's per-connection transport key.
//
// Deprecated: SetSecretKey injects a raw per-connection transport key directly.
func (c *ClientConn) SetSecretKey(key []byte) {
	c.setClientConnSecretKey(key)
}
// GetMsgEn returns the per-connection encoding function of the transport
// codec, as reported by clientConnMsgEnSnapshot.
func (c *ClientConn) GetMsgEn() func([]byte, []byte) []byte {
	encode := c.clientConnMsgEnSnapshot()
	return encode
}

// SetMsgEn replaces the per-connection encoding function with fn.
//
// Deprecated: SetMsgEn overrides the per-connection transport codec directly.
func (c *ClientConn) SetMsgEn(fn func([]byte, []byte) []byte) {
	c.setClientConnMsgEn(fn)
}
// GetMsgDe returns the per-connection decoding function of the transport
// codec, as reported by clientConnMsgDeSnapshot.
func (c *ClientConn) GetMsgDe() func([]byte, []byte) []byte {
	decode := c.clientConnMsgDeSnapshot()
	return decode
}

// SetMsgDe replaces the per-connection decoding function with fn.
//
// Deprecated: SetMsgDe overrides the per-connection transport codec directly.
func (c *ClientConn) SetMsgDe(fn func([]byte, []byte) []byte) {
	c.setClientConnMsgDe(fn)
}
// StopMonitorChan returns the channel that sessionStopChan derives from c's
// current stop context (clientConnStopContextSnapshot).
func (c *ClientConn) StopMonitorChan() <-chan struct{} {
	ctx := c.clientConnStopContextSnapshot()
	return sessionStopChan(ctx)
}
// Status returns the status snapshot of c, as produced by
// clientConnStatusSnapshot.
func (c *ClientConn) Status() Status {
	snapshot := c.clientConnStatusSnapshot()
	return snapshot
}
// Server returns the Server that c belongs to. A non-nil server reported by
// the attached LogicalConn takes precedence. Otherwise the server stored on c
// is returned. A nil receiver yields nil.
func (c *ClientConn) Server() Server {
	if c == nil {
		// Reading c.server through a nil receiver would panic, so answer
		// explicitly instead of falling through.
		return nil
	}
	if logical := c.logicalView.Load(); logical != nil {
		if server := logical.Server(); server != nil {
			return server
		}
	}
	return c.server
}
// GetRemoteAddr returns the remote network address reported by
// clientConnRemoteAddrSnapshot.
func (c *ClientConn) GetRemoteAddr() net.Addr {
	addr := c.clientConnRemoteAddrSnapshot()
	return addr
}
// markClientConnIdentityBound records that c's peer identity is bound.
//
// With a LogicalConn attached, the flag is set there. Otherwise the flag is
// written into the logical peer state, and that state is then passed to
// syncLegacyLogicalFieldsFromState. If no logical state can be obtained, the
// flag goes into c.identityBound. A nil receiver is a no-op.
func (c *ClientConn) markClientConnIdentityBound() {
	if c == nil {
		return
	}
	if logical := c.logicalView.Load(); logical != nil {
		logical.markIdentityBound()
		return
	}
	if state := c.ensureLogicalConnState(); state != nil {
		state.updatePeer(func(peer *logicalConnPeerState) {
			peer.identityBound = true
		})
		c.syncLegacyLogicalFieldsFromState(state)
		return
	}
	c.identityBound.Store(true)
}
// clientConnIdentityBoundSnapshot reports the identityBound flag of c's
// logical peer state snapshot. A nil receiver reports false.
func (c *ClientConn) clientConnIdentityBoundSnapshot() bool {
	if c != nil {
		return c.clientConnLogicalPeerStateSnapshot().identityBound
	}
	return false
}
// markClientConnStreamTransport flags c as using a stream transport, either on
// the attached LogicalConn or on c's transport state. It is a no-op for a nil
// receiver or when no transport state is available.
func (c *ClientConn) markClientConnStreamTransport() {
	if c == nil {
		return
	}
	if logical := c.logicalView.Load(); logical != nil {
		logical.markStreamTransport()
		return
	}
	if state := c.ensureClientConnTransportState(); state != nil {
		state.streamTransport.Store(true)
	}
}
// clientConnUsesStreamTransportSnapshot reports the streamTransport flag of
// c's transport state, or false when no transport state is available.
func (c *ClientConn) clientConnUsesStreamTransportSnapshot() bool {
	if state := c.ensureClientConnTransportState(); state != nil {
		return state.streamTransport.Load()
	}
	return false
}
// shouldPreserveLogicalPeerOnTransportLoss reports whether c's logical peer
// should outlive the loss of its transport. Only an identity-bound peer on a
// stream transport qualifies. A nil receiver never does.
func (c *ClientConn) shouldPreserveLogicalPeerOnTransportLoss() bool {
	return c != nil &&
		c.clientConnIdentityBoundSnapshot() &&
		c.clientConnUsesStreamTransportSnapshot()
}
// markClientConnTransportAttached records a new transport attachment and
// returns the new transport generation. On c's own transport state it bumps
// the generation and the attach counter and stamps the attach time (Unix
// nanoseconds). With a LogicalConn attached it delegates instead. It returns 0
// for a nil receiver or when no transport state is available.
func (c *ClientConn) markClientConnTransportAttached() uint64 {
	if c == nil {
		return 0
	}
	if logical := c.logicalView.Load(); logical != nil {
		return logical.markTransportAttached()
	}
	state := c.ensureClientConnTransportState()
	if state != nil {
		next := state.transportGen.Add(1)
		state.attachCount.Add(1)
		state.lastAttachAt.Store(time.Now().UnixNano())
		return next
	}
	return 0
}
// clientConnTransportGenerationSnapshot returns the current transport
// generation, or 0 when no transport state is available.
func (c *ClientConn) clientConnTransportGenerationSnapshot() uint64 {
	if state := c.ensureClientConnTransportState(); state != nil {
		return state.transportGen.Load()
	}
	return 0
}
// clientConnTransportAttachCountSnapshot returns how many transport attaches
// have been recorded, or 0 when no transport state is available.
func (c *ClientConn) clientConnTransportAttachCountSnapshot() uint64 {
	if state := c.ensureClientConnTransportState(); state != nil {
		return state.attachCount.Load()
	}
	return 0
}
// markClientConnTransportDetached records that c's transport was detached for
// reason. The error text is kept when err is non-nil. With a LogicalConn
// attached it delegates. Otherwise it increments the detach counter and stores
// a fresh detach record (current generation, reason, error text, time.Now()) on
// c's transport state. It is a no-op for a nil receiver or when no transport
// state is available.
func (c *ClientConn) markClientConnTransportDetached(reason string, err error) {
	if c == nil {
		return
	}
	if logical := c.logicalView.Load(); logical != nil {
		logical.markTransportDetached(reason, err)
		return
	}
	state := c.ensureClientConnTransportState()
	if state == nil {
		return
	}
	detachState := &clientConnTransportDetachState{
		// Take the generation from the same state snapshot that is being
		// updated, so the record cannot mix values from two state objects.
		Generation: state.transportGen.Load(),
		Reason:     reason,
		At:         time.Now(),
	}
	if err != nil {
		detachState.Err = err.Error()
	}
	state.detachCount.Add(1)
	state.transportDetach.Store(detachState)
}
// clientConnTransportDetachCountSnapshot returns how many transport detaches
// have been recorded, or 0 when no transport state is available.
func (c *ClientConn) clientConnTransportDetachCountSnapshot() uint64 {
	if state := c.ensureClientConnTransportState(); state != nil {
		return state.detachCount.Load()
	}
	return 0
}
// clearClientConnTransportDetachState drops the recorded detach state, on the
// attached LogicalConn if there is one, otherwise on c itself. A nil receiver
// is a no-op.
func (c *ClientConn) clearClientConnTransportDetachState() {
	if c == nil {
		return
	}
	logical := c.logicalView.Load()
	if logical == nil {
		c.setClientConnTransportDetachState(nil)
		return
	}
	logical.clearTransportDetachState()
}
// clientConnTransportDetachSnapshot returns the stored detach record passed
// through cloneClientConnTransportDetachState (presumably a defensive copy —
// confirm). It returns nil when no transport state is available.
func (c *ClientConn) clientConnTransportDetachSnapshot() *clientConnTransportDetachState {
	state := c.ensureClientConnTransportState()
	if state == nil {
		return nil
	}
	stored := state.transportDetach.Load()
	return cloneClientConnTransportDetachState(stored)
}
// clientConnLogicalTransportDetachedSnapshot reports whether c is a live,
// identity-bound, stream-transport peer whose transport is currently not
// attached, i.e. whether the logical peer survives without a transport. The
// checks run in that order and stop at the first failure. A nil receiver
// reports false.
func (c *ClientConn) clientConnLogicalTransportDetachedSnapshot() bool {
	return c != nil &&
		c.clientConnIdentityBoundSnapshot() &&
		c.clientConnUsesStreamTransportSnapshot() &&
		c.clientConnAliveSnapshot() &&
		!c.clientConnTransportAttachedSnapshot()
}
// clientConnLastTransportAttachedAtSnapshot returns when a transport was last
// attached. It returns the zero time.Time when no attach has been stamped or
// no transport state is available.
func (c *ClientConn) clientConnLastTransportAttachedAtSnapshot() time.Time {
	state := c.ensureClientConnTransportState()
	if state == nil {
		return time.Time{}
	}
	if nanos := state.lastAttachAt.Load(); nanos != 0 {
		return time.Unix(0, nanos)
	}
	return time.Time{}
}
func classifyClientConnTransportDetachReason(reason string) string {
switch reason {
case "":
return ""
case "read error":
return clientConnTransportDetachKindReadError
case "heartbeat timeout":
return clientConnTransportDetachKindHeartbeatTimeout
default:
return clientConnTransportDetachKindOther
}
}
// clientConnTransportDetachKindSnapshot returns the normalized kind of the
// recorded detach reason (see classifyClientConnTransportDetachReason), or ""
// for a nil receiver or when nothing is recorded.
func (c *ClientConn) clientConnTransportDetachKindSnapshot() string {
	if c == nil {
		return ""
	}
	if detach := c.clientConnTransportDetachSnapshot(); detach != nil {
		return classifyClientConnTransportDetachReason(detach.Reason)
	}
	return ""
}
// clientConnTransportDetachGenerationSnapshot returns the transport generation
// stored in the detach record. A stored 0 is replaced by the current transport
// generation. It returns 0 for a nil receiver or when nothing is recorded.
func (c *ClientConn) clientConnTransportDetachGenerationSnapshot() uint64 {
	if c == nil {
		return 0
	}
	detach := c.clientConnTransportDetachSnapshot()
	switch {
	case detach == nil:
		return 0
	case detach.Generation != 0:
		return detach.Generation
	default:
		return c.clientConnTransportGenerationSnapshot()
	}
}
// clientConnTransportDetachExpirySnapshot returns the instant at which the
// current detach stops being kept: detach time + server.DetachedClientKeepSec()
// seconds. The bool is false, and no expiry applies, when c is nil, no
// timestamped detach is recorded, c has no server, or the keep window is not
// positive.
func (c *ClientConn) clientConnTransportDetachExpirySnapshot() (time.Time, bool) {
	var none time.Time
	if c == nil {
		return none, false
	}
	detach := c.clientConnTransportDetachSnapshot()
	if detach == nil || detach.At.IsZero() || c.server == nil {
		return none, false
	}
	keepSec := c.server.DetachedClientKeepSec()
	if keepSec <= 0 {
		return none, false
	}
	window := time.Duration(keepSec) * time.Second
	return detach.At.Add(window), true
}
// clientConnTransportDetachExpiredSnapshot reports whether c is logically
// detached and, at now, has reached or passed its detach expiry. Without an
// expiry (see clientConnTransportDetachExpirySnapshot) it never expires.
func (c *ClientConn) clientConnTransportDetachExpiredSnapshot(now time.Time) bool {
	if c == nil || !c.clientConnLogicalTransportDetachedSnapshot() {
		return false
	}
	if expiry, ok := c.clientConnTransportDetachExpirySnapshot(); ok {
		return !now.Before(expiry)
	}
	return false
}
// clientConnTransportDetachRemainingSnapshot returns how long, measured from
// now, c may still stay logically detached before its expiry. It returns 0
// when c is not logically detached, has no expiry, or the expiry has already
// been reached.
func (c *ClientConn) clientConnTransportDetachRemainingSnapshot(now time.Time) time.Duration {
	if c == nil || !c.clientConnLogicalTransportDetachedSnapshot() {
		return 0
	}
	expiry, ok := c.clientConnTransportDetachExpirySnapshot()
	if ok && now.Before(expiry) {
		return expiry.Sub(now)
	}
	return 0
}
// clientConnReattachEligibleSnapshot reports whether, at now, a new transport
// may re-attach to c's surviving logical peer. That requires c to be logically
// detached (alive, identity-bound, stream transport, no transport attached) and
// its detach expiry, if any, not yet reached. Without an expiry the peer stays
// eligible.
//
// The detached state is evaluated exactly once. Re-reading alive/attached, or
// re-running the detached check inside the expiry test, could observe a
// concurrent re-attach halfway through and wrongly report eligibility.
func (c *ClientConn) clientConnReattachEligibleSnapshot(now time.Time) bool {
	if c == nil || !c.clientConnLogicalTransportDetachedSnapshot() {
		return false
	}
	expiry, ok := c.clientConnTransportDetachExpirySnapshot()
	return !ok || now.Before(expiry)
}