// Source file: notify/logical_transport_peer_fields_test.go (Go, 331 lines, 10 KiB)

package notify
import (
"context"
"errors"
"math"
"net"
"testing"
"time"
"b612.me/stario"
)
// TestHydrateServerMessagePeerFieldsFromLogicalConn checks that hydrating a
// server-side Message that only carries LogicalConn fills in the remaining
// peer fields: ClientConn must equal clientConnFromLogical(logical), the
// logical snapshot must return the same logical conn, and the transport
// snapshot must be non-nil and point back at that logical conn.
func TestHydrateServerMessagePeerFieldsFromLogicalConn(t *testing.T) {
	srv := NewServer().(*ServerCommon)

	connA, connB := net.Pipe()
	defer connA.Close()
	defer connB.Close()

	logical := srv.bootstrapAcceptedLogical("peer-fields-message", nil, connA)
	if logical == nil {
		t.Fatal("bootstrapAcceptedLogical should return logical")
	}
	wantClient := clientConnFromLogical(logical)

	// Only NetType and LogicalConn are supplied; everything else must come
	// from hydration.
	input := Message{
		NetType:     NET_SERVER,
		LogicalConn: logical,
	}
	hydrated := hydrateServerMessagePeerFields(input)

	if hydrated.ClientConn != wantClient {
		t.Fatal("ClientConn should alias LogicalConn after hydration")
	}

	snapshot := messageLogicalConnSnapshot(&hydrated)
	if snapshot != logical {
		t.Fatal("messageLogicalConnSnapshot mismatch")
	}

	transport := messageTransportConnSnapshot(&hydrated)
	switch {
	case transport == nil:
		t.Fatal("message transport should be hydrated")
	case transport.LogicalConn() != logical:
		t.Fatal("message transport logical conn mismatch")
	}
}
// TestServerPublishSendFileEventHydratesLogicalConnAlias publishes a send-file
// event that only sets the ClientConn compatibility alias and asserts that the
// registered observer receives exactly one event whose LogicalConn and
// TransportConn have been filled in while ClientConn is left unchanged.
func TestServerPublishSendFileEventHydratesLogicalConnAlias(t *testing.T) {
	srv := NewServer().(*ServerCommon)

	connA, connB := net.Pipe()
	defer connA.Close()
	defer connB.Close()

	logical := srv.bootstrapAcceptedLogical("peer-fields-file", nil, connA)
	if logical == nil {
		t.Fatal("bootstrapAcceptedLogical should return logical")
	}
	alias := clientConnFromLogical(logical)

	// Collect every event handed to the observer so the count can be checked.
	var seen []FileEvent
	srv.setFileEventObserver(func(ev FileEvent) {
		seen = append(seen, ev)
	})

	outgoing := FileEvent{
		NetType:    NET_SERVER,
		ClientConn: alias,
		Kind:       EnvelopeFileMeta,
		Packet:     FilePacket{FileID: "file-1"},
	}
	srv.publishSendFileEvent(outgoing)

	if got, want := len(seen), 1; got != want {
		t.Fatalf("observed count = %d, want %d", got, want)
	}

	first := seen[0]
	switch {
	case first.LogicalConn != logical:
		t.Fatal("LogicalConn should be hydrated from compatibility alias")
	case first.ClientConn != alias:
		t.Fatal("ClientConn compatibility alias mismatch")
	case first.TransportConn == nil:
		t.Fatal("TransportConn should be hydrated for server file event")
	}
}
// TestMessageReplyUsesLogicalConnOnly builds a Message that sets LogicalConn
// but not ClientConn, calls Reply on it, and verifies that the peer end of an
// in-memory pipe receives exactly one frame that decodes to the expected
// MSG_SYNC_REPLY TransferMsg.
func TestMessageReplyUsesLogicalConnOnly(t *testing.T) {
	srv := NewServer().(*ServerCommon)
	UseLegacySecurityServer(srv)

	// Install a session runtime by hand and mark the session started; the
	// server is never put through Listen in this test.
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	srv.setServerSessionRuntime(&serverSessionRuntime{
		stopCtx: stopCtx,
		stopFn:  stopFn,
		queue:   stario.NewQueueCtx(stopCtx, 4, math.MaxUint32),
	})
	srv.markSessionStarted()
	defer srv.markSessionStopped("test done", nil)

	serverSide, peerSide := net.Pipe()
	defer serverSide.Close()
	defer peerSide.Close()

	registered, _, _ := newRegisteredServerClientForTest(t, srv, "reply-logical", serverSide, stopCtx, stopFn)
	registered.applyClientConnAttachmentProfile(0, 0, srv.defaultMsgEn, srv.defaultMsgDe, srv.handshakeRsaKey, srv.SecretKey)
	logical := logicalConnFromClient(registered)

	// Pre-compute the framed bytes of the expected reply so the reader knows
	// how many bytes to wait for.
	wantReply := TransferMsg{
		ID:    1,
		Key:   "reply",
		Value: MsgVal("ok"),
		Type:  MSG_SYNC_REPLY,
	}
	env, err := wrapTransferMsgEnvelope(wantReply, srv.sequenceEn)
	if err != nil {
		t.Fatalf("wrapTransferMsgEnvelope failed: %v", err)
	}
	payload, err := srv.encodeEnvelopePayloadLogical(logical, env)
	if err != nil {
		t.Fatalf("encodeEnvelopePayloadLogical failed: %v", err)
	}
	wantFrame := srv.serverQueueSnapshot().BuildMessage(payload)
	if len(wantFrame) == 0 {
		t.Fatal("expected framed reply payload")
	}

	// Read the reply on the peer side of the pipe in the background. Both
	// channels are buffered so the goroutine can finish even if the test has
	// already failed.
	readErrCh := make(chan error, 1)
	frameCh := make(chan []byte, 1)
	go func() {
		_ = peerSide.SetReadDeadline(time.Now().Add(time.Second))
		buf := make([]byte, len(wantFrame))
		filled := 0
		for filled < len(buf) {
			n, readErr := peerSide.Read(buf[filled:])
			if n > 0 {
				filled += n
			}
			if readErr != nil {
				readErrCh <- readErr
				return
			}
		}
		frameCh <- buf[:filled]
	}()

	ask := Message{
		NetType:     NET_SERVER,
		LogicalConn: logical,
		TransferMsg: TransferMsg{
			ID:   1,
			Key:  "reply",
			Type: MSG_SYNC_ASK,
		},
		Time: time.Now(),
	}
	if err := ask.Reply(MsgVal("ok")); err != nil {
		t.Fatalf("Message.Reply failed: %v", err)
	}

	// Only the frame case falls through; the other two cases abort the test.
	var got []byte
	select {
	case readErr := <-readErrCh:
		t.Fatalf("reply read failed: %v", readErr)
	case got = <-frameCh:
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for reply payload")
	}

	var (
		decoded    TransferMsg
		frameCount int
	)
	onFrame := func(view stario.FrameView) error {
		frameCount++
		inner, err := srv.decodeEnvelopeLogical(logical, view.Payload)
		if err != nil {
			return err
		}
		decoded, err = unwrapTransferMsgEnvelope(inner, srv.sequenceDe)
		return err
	}
	if err := srv.serverQueueSnapshot().ParseMessageView(got, "reply-test", onFrame); err != nil {
		t.Fatalf("failed to decode framed reply payload: %v", err)
	}
	if frameCount != 1 {
		t.Fatalf("decoded frame count = %d, want 1", frameCount)
	}

	matches := decoded.ID == wantReply.ID &&
		decoded.Key == wantReply.Key &&
		decoded.Type == wantReply.Type &&
		string(decoded.Value) == string(wantReply.Value)
	if !matches {
		t.Fatalf("decoded reply mismatch: got %+v want %+v", decoded, wantReply)
	}
}
// TestTransferControlServerLogicalAndTransportValidation asserts that the
// server-side begin helpers reject a nil target: SendTransferBeginLogical must
// fail with errTransferControlLogicalConnNil and SendTransferBeginTransport
// with errTransferControlTransportNil.
func TestTransferControlServerLogicalAndTransportValidation(t *testing.T) {
	srv := NewServer()
	ctx := context.Background()
	begin := TransferBeginRequest{TransferID: "tx-validate"}

	_, logicalErr := SendTransferBeginLogical(ctx, srv, nil, begin)
	if !errors.Is(logicalErr, errTransferControlLogicalConnNil) {
		t.Fatalf("SendTransferBeginLogical nil logical error = %v, want %v", logicalErr, errTransferControlLogicalConnNil)
	}

	_, transportErr := SendTransferBeginTransport(ctx, srv, nil, begin)
	if !errors.Is(transportErr, errTransferControlTransportNil) {
		t.Fatalf("SendTransferBeginTransport nil transport error = %v, want %v", transportErr, errTransferControlTransportNil)
	}
}
// TestTransferControlServerLogicalAndTransportBeginAPIs connects a client to a
// server listening on 127.0.0.1 and drives SendTransferBeginLogical and
// SendTransferBeginTransport from the server side. For each call it checks the
// response returned to the server, the request observed by the client's Begin
// handler, and the Scope, RuntimeScope and AckedBytes of the server transfer
// snapshot.
func TestTransferControlServerLogicalAndTransportBeginAPIs(t *testing.T) {
	srv := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(srv, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	if err := srv.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() {
		_ = srv.Stop()
	}()

	// Buffered for both begin requests so the handler never blocks on send.
	seenBegin := make(chan TransferBeginRequest, 2)

	cli := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(cli, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}

	// The explicit transfer-control link bindings are under test here, not the
	// builtin transfer plane receiver, so switch off the builtin control
	// interception on the client before binding the handler.
	state := cli.getTransferState()
	state.mu.Lock()
	state.controlEnabled = false
	state.handler = nil
	state.builtinHandler = nil
	state.mu.Unlock()

	// The handler records each request and answers with a NextOffset that
	// identifies which API the request came through.
	onBegin := func(_ *Message, req TransferBeginRequest) (TransferBeginResponse, error) {
		seenBegin <- req
		reply := TransferBeginResponse{
			TransferID: req.TransferID,
			Accepted:   true,
		}
		if req.TransferID == "tx-logical" {
			reply.NextOffset = 111
		} else if req.TransferID == "tx-transport" {
			reply.NextOffset = 222
		}
		return reply, nil
	}
	if err := BindTransferControlClient(cli, TransferControlHandler{Begin: onBegin}); err != nil {
		t.Fatalf("BindTransferControlClient failed: %v", err)
	}

	if err := cli.Connect("tcp", srv.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() {
		_ = cli.Stop()
	}()

	logical := waitForTransferControlLogicalConn(t, srv, 2*time.Second)
	transport := logical.CurrentTransportConn()
	if transport == nil {
		t.Fatal("CurrentTransportConn should expose active transport")
	}

	// nextBegin waits up to two seconds for the client handler to report a
	// request and fails the test with timeoutMsg otherwise.
	nextBegin := func(timeoutMsg string) TransferBeginRequest {
		select {
		case req := <-seenBegin:
			return req
		case <-time.After(2 * time.Second):
			t.Fatal(timeoutMsg)
		}
		return TransferBeginRequest{}
	}

	// Begin addressed by logical connection.
	logicalReq := TransferBeginRequest{
		TransferID: "tx-logical",
		Channel:    TransferChannelControl,
		Size:       256,
	}
	logicalResp, err := SendTransferBeginLogical(context.Background(), srv, logical, logicalReq)
	if err != nil {
		t.Fatalf("SendTransferBeginLogical failed: %v", err)
	}
	logicalRespOK := logicalResp.Accepted && logicalResp.TransferID == "tx-logical" && logicalResp.NextOffset == 111
	if !logicalRespOK {
		t.Fatalf("logical begin response mismatch: %+v", logicalResp)
	}
	gotLogical := nextBegin("timed out waiting for logical begin request")
	logicalReqOK := gotLogical.TransferID == "tx-logical" && gotLogical.Channel == TransferChannelControl && gotLogical.Size == 256
	if !logicalReqOK {
		t.Fatalf("logical begin request mismatch: %+v", gotLogical)
	}

	// Begin addressed by transport connection.
	transportReq := TransferBeginRequest{
		TransferID: "tx-transport",
		Channel:    TransferChannelData,
		Size:       512,
	}
	transportResp, err := SendTransferBeginTransport(context.Background(), srv, transport, transportReq)
	if err != nil {
		t.Fatalf("SendTransferBeginTransport failed: %v", err)
	}
	transportRespOK := transportResp.Accepted && transportResp.TransferID == "tx-transport" && transportResp.NextOffset == 222
	if !transportRespOK {
		t.Fatalf("transport begin response mismatch: %+v", transportResp)
	}
	gotTransport := nextBegin("timed out waiting for transport begin request")
	transportReqOK := gotTransport.TransferID == "tx-transport" && gotTransport.Channel == TransferChannelData && gotTransport.Size == 512
	if !transportReqOK {
		t.Fatalf("transport begin request mismatch: %+v", gotTransport)
	}

	// Snapshot recorded for the logical-addressed transfer.
	logicalSnap, found, err := GetServerTransferSnapshotByID(srv, "tx-logical")
	if err != nil {
		t.Fatalf("GetServerTransferSnapshotByID logical failed: %v", err)
	}
	if !found {
		t.Fatal("logical transfer snapshot should exist")
	}
	if want := serverFileScope(logical); logicalSnap.Scope != want {
		t.Fatalf("logical transfer scope = %q, want %q", logicalSnap.Scope, want)
	}
	if want := serverTransportScope(logical); logicalSnap.RuntimeScope != want {
		t.Fatalf("logical transfer runtime scope = %q, want %q", logicalSnap.RuntimeScope, want)
	}
	if acked := logicalSnap.AckedBytes; acked != 111 {
		t.Fatalf("logical transfer acked bytes = %d, want %d", acked, int64(111))
	}

	// Snapshot recorded for the transport-addressed transfer.
	transportSnap, found, err := GetServerTransferSnapshotByID(srv, "tx-transport")
	if err != nil {
		t.Fatalf("GetServerTransferSnapshotByID transport failed: %v", err)
	}
	if !found {
		t.Fatal("transport transfer snapshot should exist")
	}
	if want := serverFileScope(logical); transportSnap.Scope != want {
		t.Fatalf("transport transfer scope = %q, want %q", transportSnap.Scope, want)
	}
	if want := serverTransportScopeForTransport(transport); transportSnap.RuntimeScope != want {
		t.Fatalf("transport transfer runtime scope = %q, want %q", transportSnap.RuntimeScope, want)
	}
	if acked := transportSnap.AckedBytes; acked != 222 {
		t.Fatalf("transport transfer acked bytes = %d, want %d", acked, int64(222))
	}
}