331 lines
10 KiB
Go
331 lines
10 KiB
Go
|
|
package notify
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"math"
|
||
|
|
"net"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"b612.me/stario"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestHydrateServerMessagePeerFieldsFromLogicalConn(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
left, right := net.Pipe()
|
||
|
|
defer left.Close()
|
||
|
|
defer right.Close()
|
||
|
|
|
||
|
|
logical := server.bootstrapAcceptedLogical("peer-fields-message", nil, left)
|
||
|
|
if logical == nil {
|
||
|
|
t.Fatal("bootstrapAcceptedLogical should return logical")
|
||
|
|
}
|
||
|
|
client := clientConnFromLogical(logical)
|
||
|
|
|
||
|
|
message := hydrateServerMessagePeerFields(Message{
|
||
|
|
NetType: NET_SERVER,
|
||
|
|
LogicalConn: logical,
|
||
|
|
})
|
||
|
|
|
||
|
|
if message.ClientConn != client {
|
||
|
|
t.Fatal("ClientConn should alias LogicalConn after hydration")
|
||
|
|
}
|
||
|
|
if got := messageLogicalConnSnapshot(&message); got != logical {
|
||
|
|
t.Fatal("messageLogicalConnSnapshot mismatch")
|
||
|
|
}
|
||
|
|
transport := messageTransportConnSnapshot(&message)
|
||
|
|
if transport == nil {
|
||
|
|
t.Fatal("message transport should be hydrated")
|
||
|
|
}
|
||
|
|
if got := transport.LogicalConn(); got != logical {
|
||
|
|
t.Fatal("message transport logical conn mismatch")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestServerPublishSendFileEventHydratesLogicalConnAlias(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
left, right := net.Pipe()
|
||
|
|
defer left.Close()
|
||
|
|
defer right.Close()
|
||
|
|
|
||
|
|
logical := server.bootstrapAcceptedLogical("peer-fields-file", nil, left)
|
||
|
|
if logical == nil {
|
||
|
|
t.Fatal("bootstrapAcceptedLogical should return logical")
|
||
|
|
}
|
||
|
|
client := clientConnFromLogical(logical)
|
||
|
|
|
||
|
|
var observed []FileEvent
|
||
|
|
server.setFileEventObserver(func(event FileEvent) {
|
||
|
|
observed = append(observed, event)
|
||
|
|
})
|
||
|
|
|
||
|
|
server.publishSendFileEvent(FileEvent{
|
||
|
|
NetType: NET_SERVER,
|
||
|
|
ClientConn: client,
|
||
|
|
Kind: EnvelopeFileMeta,
|
||
|
|
Packet: FilePacket{FileID: "file-1"},
|
||
|
|
})
|
||
|
|
|
||
|
|
if got, want := len(observed), 1; got != want {
|
||
|
|
t.Fatalf("observed count = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if observed[0].LogicalConn != logical {
|
||
|
|
t.Fatal("LogicalConn should be hydrated from compatibility alias")
|
||
|
|
}
|
||
|
|
if observed[0].ClientConn != client {
|
||
|
|
t.Fatal("ClientConn compatibility alias mismatch")
|
||
|
|
}
|
||
|
|
if observed[0].TransportConn == nil {
|
||
|
|
t.Fatal("TransportConn should be hydrated for server file event")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestMessageReplyUsesLogicalConnOnly(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
UseLegacySecurityServer(server)
|
||
|
|
|
||
|
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||
|
|
defer stopFn()
|
||
|
|
server.setServerSessionRuntime(&serverSessionRuntime{
|
||
|
|
stopCtx: stopCtx,
|
||
|
|
stopFn: stopFn,
|
||
|
|
queue: stario.NewQueueCtx(stopCtx, 4, math.MaxUint32),
|
||
|
|
})
|
||
|
|
server.markSessionStarted()
|
||
|
|
defer server.markSessionStopped("test done", nil)
|
||
|
|
|
||
|
|
left, right := net.Pipe()
|
||
|
|
defer left.Close()
|
||
|
|
defer right.Close()
|
||
|
|
|
||
|
|
client, _, _ := newRegisteredServerClientForTest(t, server, "reply-logical", left, stopCtx, stopFn)
|
||
|
|
client.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey)
|
||
|
|
logical := logicalConnFromClient(client)
|
||
|
|
|
||
|
|
expectedReply := TransferMsg{
|
||
|
|
ID: 1,
|
||
|
|
Key: "reply",
|
||
|
|
Value: MsgVal("ok"),
|
||
|
|
Type: MSG_SYNC_REPLY,
|
||
|
|
}
|
||
|
|
env, err := wrapTransferMsgEnvelope(expectedReply, server.sequenceEn)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("wrapTransferMsgEnvelope failed: %v", err)
|
||
|
|
}
|
||
|
|
payload, err := server.encodeEnvelopePayloadLogical(logical, env)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("encodeEnvelopePayloadLogical failed: %v", err)
|
||
|
|
}
|
||
|
|
want := server.serverQueueSnapshot().BuildMessage(payload)
|
||
|
|
if len(want) == 0 {
|
||
|
|
t.Fatal("expected framed reply payload")
|
||
|
|
}
|
||
|
|
|
||
|
|
errCh := make(chan error, 1)
|
||
|
|
recvCh := make(chan []byte, 1)
|
||
|
|
go func() {
|
||
|
|
_ = right.SetReadDeadline(time.Now().Add(time.Second))
|
||
|
|
buf := make([]byte, len(want))
|
||
|
|
read := 0
|
||
|
|
for read < len(want) {
|
||
|
|
n, err := right.Read(buf[read:])
|
||
|
|
if n > 0 {
|
||
|
|
read += n
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
errCh <- err
|
||
|
|
return
|
||
|
|
}
|
||
|
|
}
|
||
|
|
recvCh <- buf[:read]
|
||
|
|
}()
|
||
|
|
|
||
|
|
message := Message{
|
||
|
|
NetType: NET_SERVER,
|
||
|
|
LogicalConn: logical,
|
||
|
|
TransferMsg: TransferMsg{
|
||
|
|
ID: 1,
|
||
|
|
Key: "reply",
|
||
|
|
Type: MSG_SYNC_ASK,
|
||
|
|
},
|
||
|
|
Time: time.Now(),
|
||
|
|
}
|
||
|
|
if err := message.Reply(MsgVal("ok")); err != nil {
|
||
|
|
t.Fatalf("Message.Reply failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
select {
|
||
|
|
case err := <-errCh:
|
||
|
|
t.Fatalf("reply read failed: %v", err)
|
||
|
|
case got := <-recvCh:
|
||
|
|
var decoded TransferMsg
|
||
|
|
frames := 0
|
||
|
|
if err := server.serverQueueSnapshot().ParseMessageView(got, "reply-test", func(view stario.FrameView) error {
|
||
|
|
frames++
|
||
|
|
env, err := server.decodeEnvelopeLogical(logical, view.Payload)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
decoded, err = unwrapTransferMsgEnvelope(env, server.sequenceDe)
|
||
|
|
return err
|
||
|
|
}); err != nil {
|
||
|
|
t.Fatalf("failed to decode framed reply payload: %v", err)
|
||
|
|
}
|
||
|
|
if frames != 1 {
|
||
|
|
t.Fatalf("decoded frame count = %d, want 1", frames)
|
||
|
|
}
|
||
|
|
if decoded.ID != expectedReply.ID || decoded.Key != expectedReply.Key || decoded.Type != expectedReply.Type || string(decoded.Value) != string(expectedReply.Value) {
|
||
|
|
t.Fatalf("decoded reply mismatch: got %+v want %+v", decoded, expectedReply)
|
||
|
|
}
|
||
|
|
case <-time.After(time.Second):
|
||
|
|
t.Fatal("timed out waiting for reply payload")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTransferControlServerLogicalAndTransportValidation(t *testing.T) {
|
||
|
|
server := NewServer()
|
||
|
|
req := TransferBeginRequest{TransferID: "tx-validate"}
|
||
|
|
|
||
|
|
if _, err := SendTransferBeginLogical(context.Background(), server, nil, req); !errors.Is(err, errTransferControlLogicalConnNil) {
|
||
|
|
t.Fatalf("SendTransferBeginLogical nil logical error = %v, want %v", err, errTransferControlLogicalConnNil)
|
||
|
|
}
|
||
|
|
if _, err := SendTransferBeginTransport(context.Background(), server, nil, req); !errors.Is(err, errTransferControlTransportNil) {
|
||
|
|
t.Fatalf("SendTransferBeginTransport nil transport error = %v, want %v", err, errTransferControlTransportNil)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTransferControlServerLogicalAndTransportBeginAPIs(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||
|
|
t.Fatalf("server Listen failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = server.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
beginReqCh := make(chan TransferBeginRequest, 2)
|
||
|
|
client := NewClient().(*ClientCommon)
|
||
|
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||
|
|
}
|
||
|
|
// This test validates the explicit transfer-control link bindings rather
|
||
|
|
// than the builtin transfer plane receiver, so disable the builtin control
|
||
|
|
// interception on the client side first.
|
||
|
|
clientTransferState := client.getTransferState()
|
||
|
|
clientTransferState.mu.Lock()
|
||
|
|
clientTransferState.controlEnabled = false
|
||
|
|
clientTransferState.handler = nil
|
||
|
|
clientTransferState.builtinHandler = nil
|
||
|
|
clientTransferState.mu.Unlock()
|
||
|
|
if err := BindTransferControlClient(client, TransferControlHandler{
|
||
|
|
Begin: func(_ *Message, req TransferBeginRequest) (TransferBeginResponse, error) {
|
||
|
|
beginReqCh <- req
|
||
|
|
resp := TransferBeginResponse{
|
||
|
|
TransferID: req.TransferID,
|
||
|
|
Accepted: true,
|
||
|
|
}
|
||
|
|
switch req.TransferID {
|
||
|
|
case "tx-logical":
|
||
|
|
resp.NextOffset = 111
|
||
|
|
case "tx-transport":
|
||
|
|
resp.NextOffset = 222
|
||
|
|
}
|
||
|
|
return resp, nil
|
||
|
|
},
|
||
|
|
}); err != nil {
|
||
|
|
t.Fatalf("BindTransferControlClient failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||
|
|
t.Fatalf("client Connect failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = client.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
|
||
|
|
transport := logical.CurrentTransportConn()
|
||
|
|
if transport == nil {
|
||
|
|
t.Fatal("CurrentTransportConn should expose active transport")
|
||
|
|
}
|
||
|
|
|
||
|
|
logicalResp, err := SendTransferBeginLogical(context.Background(), server, logical, TransferBeginRequest{
|
||
|
|
TransferID: "tx-logical",
|
||
|
|
Channel: TransferChannelControl,
|
||
|
|
Size: 256,
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("SendTransferBeginLogical failed: %v", err)
|
||
|
|
}
|
||
|
|
if !logicalResp.Accepted || logicalResp.TransferID != "tx-logical" || logicalResp.NextOffset != 111 {
|
||
|
|
t.Fatalf("logical begin response mismatch: %+v", logicalResp)
|
||
|
|
}
|
||
|
|
|
||
|
|
select {
|
||
|
|
case got := <-beginReqCh:
|
||
|
|
if got.TransferID != "tx-logical" || got.Channel != TransferChannelControl || got.Size != 256 {
|
||
|
|
t.Fatalf("logical begin request mismatch: %+v", got)
|
||
|
|
}
|
||
|
|
case <-time.After(2 * time.Second):
|
||
|
|
t.Fatal("timed out waiting for logical begin request")
|
||
|
|
}
|
||
|
|
|
||
|
|
transportResp, err := SendTransferBeginTransport(context.Background(), server, transport, TransferBeginRequest{
|
||
|
|
TransferID: "tx-transport",
|
||
|
|
Channel: TransferChannelData,
|
||
|
|
Size: 512,
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("SendTransferBeginTransport failed: %v", err)
|
||
|
|
}
|
||
|
|
if !transportResp.Accepted || transportResp.TransferID != "tx-transport" || transportResp.NextOffset != 222 {
|
||
|
|
t.Fatalf("transport begin response mismatch: %+v", transportResp)
|
||
|
|
}
|
||
|
|
|
||
|
|
select {
|
||
|
|
case got := <-beginReqCh:
|
||
|
|
if got.TransferID != "tx-transport" || got.Channel != TransferChannelData || got.Size != 512 {
|
||
|
|
t.Fatalf("transport begin request mismatch: %+v", got)
|
||
|
|
}
|
||
|
|
case <-time.After(2 * time.Second):
|
||
|
|
t.Fatal("timed out waiting for transport begin request")
|
||
|
|
}
|
||
|
|
|
||
|
|
logicalSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-logical")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("GetServerTransferSnapshotByID logical failed: %v", err)
|
||
|
|
}
|
||
|
|
if !ok {
|
||
|
|
t.Fatal("logical transfer snapshot should exist")
|
||
|
|
}
|
||
|
|
if got, want := logicalSnapshot.Scope, serverFileScope(logical); got != want {
|
||
|
|
t.Fatalf("logical transfer scope = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := logicalSnapshot.RuntimeScope, serverTransportScope(logical); got != want {
|
||
|
|
t.Fatalf("logical transfer runtime scope = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := logicalSnapshot.AckedBytes, int64(111); got != want {
|
||
|
|
t.Fatalf("logical transfer acked bytes = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
|
||
|
|
transportSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-transport")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("GetServerTransferSnapshotByID transport failed: %v", err)
|
||
|
|
}
|
||
|
|
if !ok {
|
||
|
|
t.Fatal("transport transfer snapshot should exist")
|
||
|
|
}
|
||
|
|
if got, want := transportSnapshot.Scope, serverFileScope(logical); got != want {
|
||
|
|
t.Fatalf("transport transfer scope = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := transportSnapshot.RuntimeScope, serverTransportScopeForTransport(transport); got != want {
|
||
|
|
t.Fatalf("transport transfer runtime scope = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := transportSnapshot.AckedBytes, int64(222); got != want {
|
||
|
|
t.Fatalf("transport transfer acked bytes = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
}
|