notify/server_dual_api_test.go
starainrt 09d972c7b7
feat(notify): 重构通信内核并补齐 stream/bulk/record/transfer 能力
- 引入 LogicalConn/TransportConn 分层,ClientConn 保留兼容适配层
  - 新增 Stream、Bulk、RecordStream 三条数据面能力及对应控制路径
  - 完成 transfer/file 传输内核与状态快照、诊断能力
  - 补齐 reconnect、inbound dispatcher、modern psk 等基础模块
  - 增加大规模回归、并发与基准测试覆盖
  - 更新依赖库
2026-04-15 15:24:36 +08:00

169 lines
5.5 KiB
Go

package notify
import (
"b612.me/stario"
"context"
"errors"
"math"
"net"
"os"
"testing"
"time"
)
// TestServerLogicalAndTransportLookupAPIs bootstraps one accepted logical
// connection and checks that the lookup APIs agree on it:
//   - GetLogicalConn returns the same logical pointer.
//   - GetCurrentTransportConn and GetCurrentTransportConnByLogical both return
//     a current transport with matching client ID and generation.
//   - GetCurrentTransportConnList contains exactly that one transport.
func TestServerLogicalAndTransportLookupAPIs(t *testing.T) {
	srv := NewServer().(*ServerCommon)

	connA, connB := net.Pipe()
	defer connA.Close()
	defer connB.Close()

	lc := srv.bootstrapAcceptedLogical("logical-lookup", nil, connA)
	if lc == nil {
		t.Fatal("bootstrapAcceptedLogical should return logical")
	}

	if found := srv.GetLogicalConn(lc.ClientID); found != lc {
		t.Fatalf("GetLogicalConn mismatch: got %+v want %+v", found, lc)
	}

	byID := srv.GetCurrentTransportConn(lc.ClientID)
	if byID == nil {
		t.Fatal("GetCurrentTransportConn should expose current transport")
	}
	byLogical := srv.GetCurrentTransportConnByLogical(lc)
	if byLogical == nil {
		t.Fatal("GetCurrentTransportConnByLogical should expose current transport")
	}

	idGot := byID.ClientID()
	idWant := lc.ClientID
	if idGot != idWant {
		t.Fatalf("transport client id mismatch: got %q want %q", idGot, idWant)
	}

	// Both lookup paths must resolve to the same transport generation.
	genByID := byID.TransportGeneration()
	genByLogical := byLogical.TransportGeneration()
	if genByID != genByLogical {
		t.Fatalf("transport generation mismatch: got %d want %d", genByID, genByLogical)
	}

	if !(byID.IsCurrent() && byLogical.IsCurrent()) {
		t.Fatal("lookup transports should be current")
	}

	transports := srv.GetCurrentTransportConnList()
	if count := len(transports); count != 1 {
		t.Fatalf("GetCurrentTransportConnList len = %d, want 1", count)
	}
	listedID, expectedID := transports[0].ClientID(), lc.ClientID
	if listedID != expectedID {
		t.Fatalf("transport list client id mismatch: got %q want %q", listedID, expectedID)
	}
}
// TestServerSendLogicalAndTransportAPIs sends one message through
// SendLogical and one through SendTransport. For each, it reads the single
// frame that appears on the peer end of the pipe and decodes it. It then
// checks that the result is an MSG_ASYNC message with the expected key and
// the value "payload".
func TestServerSendLogicalAndTransportAPIs(t *testing.T) {
	srv := NewServer().(*ServerCommon)
	UseLegacySecurityServer(srv)

	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	srv.setServerSessionRuntime(&serverSessionRuntime{
		stopCtx: stopCtx,
		stopFn:  stopFn,
		queue:   stario.NewQueueCtx(stopCtx, 4, math.MaxUint32),
	})
	srv.markSessionStarted()
	defer srv.markSessionStopped("test done", nil)

	serverSide, peerSide := net.Pipe()
	defer serverSide.Close()
	defer peerSide.Close()

	lc, _, _ := newRegisteredServerLogicalForTest(t, srv, "api-send", serverSide, stopCtx, stopFn)
	lc.applyClientConnAttachmentProfile(0, 0, srv.defaultMsgEn, srv.defaultMsgDe, srv.handshakeRsaKey, srv.SecretKey)

	type decoded struct {
		msg TransferMsg
		err error
	}

	// decodeNext reads exactly one frame from the peer end and unwraps it
	// back into a TransferMsg. The one-second read deadline keeps a missing
	// frame from blocking forever.
	decodeNext := func() decoded {
		_ = peerSide.SetReadDeadline(time.Now().Add(time.Second))
		payload, err := stario.NewFrameReader(peerSide, nil).Next()
		if err != nil {
			return decoded{err: err}
		}
		env, err := srv.decodeEnvelopeLogical(lc, payload)
		if err != nil {
			return decoded{err: err}
		}
		msg, err := unwrapTransferMsgEnvelope(env, srv.sequenceDe)
		return decoded{msg: msg, err: err}
	}

	// receiveFrame arms a background reader before the send. The pipe is
	// synchronous, so the writer needs a reader on the other end.
	receiveFrame := func() <-chan decoded {
		out := make(chan decoded, 1)
		go func() { out <- decodeNext() }()
		return out
	}

	// expect waits for the pending decode and validates the message produced
	// by the named send API.
	expect := func(api, key string, pending <-chan decoded) {
		t.Helper()
		res := <-pending
		if res.err != nil {
			t.Fatalf("%s decode failed: %v", api, res.err)
		}
		if res.msg.Key != key || res.msg.Type != MSG_ASYNC || string(res.msg.Value) != "payload" {
			t.Fatalf("%s decoded message mismatch: %+v", api, res.msg)
		}
	}

	pending := receiveFrame()
	if err := srv.SendLogical(lc, "logical", MsgVal("payload")); err != nil {
		t.Fatalf("SendLogical failed: %v", err)
	}
	expect("SendLogical", "logical", pending)

	tc := srv.GetCurrentTransportConn(lc.ClientID)
	if tc == nil {
		t.Fatal("GetCurrentTransportConn should expose current transport")
	}

	pending = receiveFrame()
	if err := srv.SendTransport(tc, "transport", MsgVal("payload")); err != nil {
		t.Fatalf("SendTransport failed: %v", err)
	}
	expect("SendTransport", "transport", pending)
}
func TestServerSendFileTransportRejectsStaleTransport(t *testing.T) {
server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server)
server.markSessionStarted()
defer server.markSessionStopped("test done", nil)
firstLeft, firstRight := net.Pipe()
defer firstRight.Close()
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
logical, _, _ := newRegisteredServerLogicalForTest(t, server, "api-file-stale", firstLeft, stopCtx, stopFn)
logical.applyClientConnAttachmentProfile(0, 100*time.Millisecond, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey)
staleTransport := server.GetCurrentTransportConn(logical.ClientID)
if staleTransport == nil {
t.Fatal("initial transport should exist")
}
secondLeft, secondRight := net.Pipe()
defer secondLeft.Close()
defer secondRight.Close()
if err := logical.attachClientConnSessionTransport(secondLeft); err != nil {
t.Fatalf("attachClientConnSessionTransport failed: %v", err)
}
file, err := os.CreateTemp(t.TempDir(), "notify-send-file-*")
if err != nil {
t.Fatalf("CreateTemp failed: %v", err)
}
if _, err := file.WriteString("payload"); err != nil {
t.Fatalf("WriteString failed: %v", err)
}
if err := file.Close(); err != nil {
t.Fatalf("Close temp file failed: %v", err)
}
err = server.SendFileTransport(context.Background(), staleTransport, file.Name())
if !errors.Is(err, errTransportDetached) {
t.Fatalf("SendFileTransport stale error = %v, want errors.Is(..., %v)", err, errTransportDetached)
}
}