package notify import ( "context" "errors" "math" "net" "testing" "time" "b612.me/stario" ) func TestHydrateServerMessagePeerFieldsFromLogicalConn(t *testing.T) { server := NewServer().(*ServerCommon) left, right := net.Pipe() defer left.Close() defer right.Close() logical := server.bootstrapAcceptedLogical("peer-fields-message", nil, left) if logical == nil { t.Fatal("bootstrapAcceptedLogical should return logical") } client := clientConnFromLogical(logical) message := hydrateServerMessagePeerFields(Message{ NetType: NET_SERVER, LogicalConn: logical, }) if message.ClientConn != client { t.Fatal("ClientConn should alias LogicalConn after hydration") } if got := messageLogicalConnSnapshot(&message); got != logical { t.Fatal("messageLogicalConnSnapshot mismatch") } transport := messageTransportConnSnapshot(&message) if transport == nil { t.Fatal("message transport should be hydrated") } if got := transport.LogicalConn(); got != logical { t.Fatal("message transport logical conn mismatch") } } func TestServerPublishSendFileEventHydratesLogicalConnAlias(t *testing.T) { server := NewServer().(*ServerCommon) left, right := net.Pipe() defer left.Close() defer right.Close() logical := server.bootstrapAcceptedLogical("peer-fields-file", nil, left) if logical == nil { t.Fatal("bootstrapAcceptedLogical should return logical") } client := clientConnFromLogical(logical) var observed []FileEvent server.setFileEventObserver(func(event FileEvent) { observed = append(observed, event) }) server.publishSendFileEvent(FileEvent{ NetType: NET_SERVER, ClientConn: client, Kind: EnvelopeFileMeta, Packet: FilePacket{FileID: "file-1"}, }) if got, want := len(observed), 1; got != want { t.Fatalf("observed count = %d, want %d", got, want) } if observed[0].LogicalConn != logical { t.Fatal("LogicalConn should be hydrated from compatibility alias") } if observed[0].ClientConn != client { t.Fatal("ClientConn compatibility alias mismatch") } if observed[0].TransportConn == nil { t.Fatal("TransportConn should be hydrated for server file event") } } func TestMessageReplyUsesLogicalConnOnly(t *testing.T) { server := NewServer().(*ServerCommon) UseLegacySecurityServer(server) stopCtx, stopFn := context.WithCancel(context.Background()) defer stopFn() server.setServerSessionRuntime(&serverSessionRuntime{ stopCtx: stopCtx, stopFn: stopFn, queue: stario.NewQueueCtx(stopCtx, 4, math.MaxUint32), }) server.markSessionStarted() defer server.markSessionStopped("test done", nil) left, right := net.Pipe() defer left.Close() defer right.Close() client, _, _ := newRegisteredServerClientForTest(t, server, "reply-logical", left, stopCtx, stopFn) client.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey) logical := logicalConnFromClient(client) expectedReply := TransferMsg{ ID: 1, Key: "reply", Value: MsgVal("ok"), Type: MSG_SYNC_REPLY, } env, err := wrapTransferMsgEnvelope(expectedReply, server.sequenceEn) if err != nil { t.Fatalf("wrapTransferMsgEnvelope failed: %v", err) } payload, err := server.encodeEnvelopePayloadLogical(logical, env) if err != nil { t.Fatalf("encodeEnvelopePayloadLogical failed: %v", err) } want := server.serverQueueSnapshot().BuildMessage(payload) if len(want) == 0 { t.Fatal("expected framed reply payload") } errCh := make(chan error, 1) recvCh := make(chan []byte, 1) go func() { _ = right.SetReadDeadline(time.Now().Add(time.Second)) buf := make([]byte, len(want)) read := 0 for read < len(want) { n, err := right.Read(buf[read:]) if n > 0 { read += n } if err != nil { errCh <- err return } } recvCh <- buf[:read] }() message := Message{ NetType: NET_SERVER, LogicalConn: logical, TransferMsg: TransferMsg{ ID: 1, Key: "reply", Type: MSG_SYNC_ASK, }, Time: time.Now(), } if err := message.Reply(MsgVal("ok")); err != nil { t.Fatalf("Message.Reply failed: %v", err) } select { case err := <-errCh: t.Fatalf("reply read failed: %v", err) case got := <-recvCh: var decoded TransferMsg frames := 0 if err := server.serverQueueSnapshot().ParseMessageView(got, "reply-test", func(view stario.FrameView) error { frames++ env, err := server.decodeEnvelopeLogical(logical, view.Payload) if err != nil { return err } decoded, err = unwrapTransferMsgEnvelope(env, server.sequenceDe) return err }); err != nil { t.Fatalf("failed to decode framed reply payload: %v", err) } if frames != 1 { t.Fatalf("decoded frame count = %d, want 1", frames) } if decoded.ID != expectedReply.ID || decoded.Key != expectedReply.Key || decoded.Type != expectedReply.Type || string(decoded.Value) != string(expectedReply.Value) { t.Fatalf("decoded reply mismatch: got %+v want %+v", decoded, expectedReply) } case <-time.After(time.Second): t.Fatal("timed out waiting for reply payload") } } func TestTransferControlServerLogicalAndTransportValidation(t *testing.T) { server := NewServer() req := TransferBeginRequest{TransferID: "tx-validate"} if _, err := SendTransferBeginLogical(context.Background(), server, nil, req); !errors.Is(err, errTransferControlLogicalConnNil) { t.Fatalf("SendTransferBeginLogical nil logical error = %v, want %v", err, errTransferControlLogicalConnNil) } if _, err := SendTransferBeginTransport(context.Background(), server, nil, req); !errors.Is(err, errTransferControlTransportNil) { t.Fatalf("SendTransferBeginTransport nil transport error = %v, want %v", err, errTransferControlTransportNil) } } func TestTransferControlServerLogicalAndTransportBeginAPIs(t *testing.T) { server := NewServer().(*ServerCommon) if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { t.Fatalf("UseModernPSKServer failed: %v", err) } if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { t.Fatalf("server Listen failed: %v", err) } defer func() { _ = server.Stop() }() beginReqCh := make(chan TransferBeginRequest, 2) client := NewClient().(*ClientCommon) if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { t.Fatalf("UseModernPSKClient failed: %v", err) } // This test validates the explicit transfer-control link bindings rather // than the builtin transfer plane receiver, so disable the builtin control // interception on the client side first. clientTransferState := client.getTransferState() clientTransferState.mu.Lock() clientTransferState.controlEnabled = false clientTransferState.handler = nil clientTransferState.builtinHandler = nil clientTransferState.mu.Unlock() if err := BindTransferControlClient(client, TransferControlHandler{ Begin: func(_ *Message, req TransferBeginRequest) (TransferBeginResponse, error) { beginReqCh <- req resp := TransferBeginResponse{ TransferID: req.TransferID, Accepted: true, } switch req.TransferID { case "tx-logical": resp.NextOffset = 111 case "tx-transport": resp.NextOffset = 222 } return resp, nil }, }); err != nil { t.Fatalf("BindTransferControlClient failed: %v", err) } if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { t.Fatalf("client Connect failed: %v", err) } defer func() { _ = client.Stop() }() logical := waitForTransferControlLogicalConn(t, server, 2*time.Second) transport := logical.CurrentTransportConn() if transport == nil { t.Fatal("CurrentTransportConn should expose active transport") } logicalResp, err := SendTransferBeginLogical(context.Background(), server, logical, TransferBeginRequest{ TransferID: "tx-logical", Channel: TransferChannelControl, Size: 256, }) if err != nil { t.Fatalf("SendTransferBeginLogical failed: %v", err) } if !logicalResp.Accepted || logicalResp.TransferID != "tx-logical" || logicalResp.NextOffset != 111 { t.Fatalf("logical begin response mismatch: %+v", logicalResp) } select { case got := <-beginReqCh: if got.TransferID != "tx-logical" || got.Channel != TransferChannelControl || got.Size != 256 { t.Fatalf("logical begin request mismatch: %+v", got) } case <-time.After(2 * time.Second): t.Fatal("timed out waiting for logical begin request") } transportResp, err := SendTransferBeginTransport(context.Background(), server, transport, TransferBeginRequest{ TransferID: "tx-transport", Channel: TransferChannelData, Size: 512, }) if err != nil { t.Fatalf("SendTransferBeginTransport failed: %v", err) } if !transportResp.Accepted || transportResp.TransferID != "tx-transport" || transportResp.NextOffset != 222 { t.Fatalf("transport begin response mismatch: %+v", transportResp) } select { case got := <-beginReqCh: if got.TransferID != "tx-transport" || got.Channel != TransferChannelData || got.Size != 512 { t.Fatalf("transport begin request mismatch: %+v", got) } case <-time.After(2 * time.Second): t.Fatal("timed out waiting for transport begin request") } logicalSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-logical") if err != nil { t.Fatalf("GetServerTransferSnapshotByID logical failed: %v", err) } if !ok { t.Fatal("logical transfer snapshot should exist") } if got, want := logicalSnapshot.Scope, serverFileScope(logical); got != want { t.Fatalf("logical transfer scope = %q, want %q", got, want) } if got, want := logicalSnapshot.RuntimeScope, serverTransportScope(logical); got != want { t.Fatalf("logical transfer runtime scope = %q, want %q", got, want) } if got, want := logicalSnapshot.AckedBytes, int64(111); got != want { t.Fatalf("logical transfer acked bytes = %d, want %d", got, want) } transportSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-transport") if err != nil { t.Fatalf("GetServerTransferSnapshotByID transport failed: %v", err) } if !ok { t.Fatal("transport transfer snapshot should exist") } if got, want := transportSnapshot.Scope, serverFileScope(logical); got != want { t.Fatalf("transport transfer scope = %q, want %q", got, want) } if got, want := transportSnapshot.RuntimeScope, serverTransportScopeForTransport(transport); got != want { t.Fatalf("transport transfer runtime scope = %q, want %q", got, want) } if got, want := transportSnapshot.AckedBytes, int64(222); got != want { t.Fatalf("transport transfer acked bytes = %d, want %d", got, want) } }