package notify

import (
	"context"
	"errors"
	"math"
	"net"
	"testing"
	"time"

	"b612.me/stario"
)

// readServerEnvelopeFromConnWithProfile reads raw bytes from conn, feeds them
// through a fresh stario queue until one complete message is available, then
// decrypts it with the supplied transport protection profile and decodes the
// plaintext into an Envelope via server.decodeEnvelopePlain.
//
// The test fails (t.Fatalf) on any deadline, parse, decrypt, decode or
// non-timeout read error, and fails with "timed out waiting for server
// envelope" when no complete message arrives before timeout elapses.
func readServerEnvelopeFromConnWithProfile(t *testing.T, server *ServerCommon, profile transportProtectionProfile, conn net.Conn, timeout time.Duration) Envelope {
	t.Helper()
	queue := stario.NewQueue()
	deadline := time.Now().Add(timeout)
	// The deadline is fixed for the whole call, so arming it once is enough.
	if err := conn.SetReadDeadline(deadline); err != nil {
		t.Fatalf("SetReadDeadline failed: %v", err)
	}
	buf := make([]byte, 4096)
	for time.Now().Before(deadline) {
		n, err := conn.Read(buf)
		if n > 0 {
			if parseErr := queue.ParseMessage(buf[:n], "server-inbound-profile"); parseErr != nil {
				t.Fatalf("ParseMessage failed: %v", parseErr)
			}
			// Non-blocking receive: if the bytes so far do not yet form a
			// complete message, keep reading. (Assumes ParseMessage delivers
			// completed messages to RestoreChan synchronously — TODO confirm.)
			select {
			case msg := <-queue.RestoreChan():
				plain, decErr := decryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, msg.Msg)
				if decErr != nil {
					t.Fatalf("decryptTransportPayloadCodec failed: %v", decErr)
				}
				env, decodeErr := server.decodeEnvelopePlain(plain)
				if decodeErr != nil {
					t.Fatalf("decodeEnvelopePlain failed: %v", decodeErr)
				}
				return env
			default:
			}
		}
		if err == nil {
			continue
		}
		// errors.As also recognises timeout errors that have been wrapped,
		// which a bare type assertion would misreport as a read failure.
		var netErr net.Error
		if errors.As(err, &netErr) && netErr.Timeout() {
			break
		}
		t.Fatalf("conn Read failed: %v", err)
	}
	t.Fatal("timed out waiting for server envelope")
	return Envelope{}
}

// TestMessageReplyUsesInboundConnForStaleTransport builds a Message whose
// TransportConn snapshot was taken for a mismatched (generation+1) transport
// generation and whose inboundConn is one end of a net.Pipe. It checks that
// Message.Reply still succeeds and that a MSG_SYNC_REPLY carrying the original
// ID, Key and the "ok" value is read from the other end of the pipe.
func TestMessageReplyUsesInboundConnForStaleTransport(t *testing.T) {
	server := NewServer().(*ServerCommon)
	UseLegacySecurityServer(server)
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	server.setServerSessionRuntime(&serverSessionRuntime{
		stopCtx: stopCtx,
		stopFn:  stopFn,
		queue:   stario.NewQueueCtx(stopCtx, 4, math.MaxUint32),
	})
	server.markSessionStarted()
	defer server.markSessionStopped("test done", nil)

	firstLeft, firstRight := net.Pipe()
	defer firstLeft.Close()
	defer firstRight.Close()

	logical, _, _ := newRegisteredServerLogicalForTest(t, server, "reply-inbound-stale", firstLeft, stopCtx, stopFn)
	client := clientConnFromLogical(logical)
	client.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey)

	// Requesting generation+1 deliberately produces a snapshot that is not current.
	staleTransport := logical.transportConnSnapshotForInbound(firstLeft, nil, logical.transportGenerationSnapshot()+1, true)
	if staleTransport == nil {
		t.Fatal("stale transport snapshot should exist")
	}
	if staleTransport.IsCurrent() {
		t.Fatal("stale transport should not be current for mismatched generation")
	}

	message := Message{
		NetType:       NET_SERVER,
		LogicalConn:   logical,
		TransportConn: staleTransport,
		TransferMsg: TransferMsg{
			ID:   11,
			Key:  "reply-inbound",
			Type: MSG_SYNC_ASK,
		},
		Time:        time.Now(),
		inboundConn: firstLeft,
	}

	// net.Pipe has no internal buffering, so the write inside Reply must run
	// concurrently with the read below.
	done := make(chan error, 1)
	go func() {
		done <- message.Reply(MsgVal("ok"))
	}()

	env := readServerEnvelopeFromConn(t, server, logical, firstRight, time.Second)
	select {
	case err := <-done:
		if err != nil {
			t.Fatalf("Message.Reply failed: %v", err)
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for Message.Reply to finish")
	}

	transfer, err := unwrapTransferMsgEnvelope(env, server.sequenceDe)
	if err != nil {
		t.Fatalf("unwrapTransferMsgEnvelope failed: %v", err)
	}
	if transfer.Type != MSG_SYNC_REPLY {
		t.Fatalf("reply type = %v, want %v", transfer.Type, MSG_SYNC_REPLY)
	}
	if transfer.ID != 11 {
		t.Fatalf("reply id = %d, want %d", transfer.ID, 11)
	}
	if transfer.Key != "reply-inbound" {
		t.Fatalf("reply key = %q, want %q", transfer.Key, "reply-inbound")
	}
	if string(transfer.Value) != "ok" {
		t.Fatalf("reply value = %q, want %q", string(transfer.Value), "ok")
	}
}

// TestMessageReplyUsesInboundProtectionSnapshotAfterLogicalSwitch runs a
// modern-PSK server whose "reply-snapshot" handler applies a different
// protection profile to the logical connection before calling Reply. It checks
// that the client still receives the "ack" reply and that the handler reported
// no error.
func TestMessageReplyUsesInboundProtectionSnapshotAfterLogicalSwitch(t *testing.T) {
	secret := []byte("correct horse battery staple")
	alternate, err := deriveModernPSKProtectionProfile([]byte("notify-reply-snapshot-alternate"), testModernPSKOptions(), ProtectionManaged)
	if err != nil {
		t.Fatalf("deriveModernPSKProtectionProfile failed: %v", err)
	}

	// Buffered with capacity 1: only the first handler error is kept and
	// later ones are dropped rather than blocking the handler.
	handlerErr := make(chan error, 1)
	reportHandlerErr := func(err error) {
		select {
		case handlerErr <- err:
		default:
		}
	}

	server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
		if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
			t.Fatalf("UseModernPSKServer failed: %v", err)
		}
		server.SetLink("reply-snapshot", func(msg *Message) {
			if msg == nil || msg.LogicalConn == nil {
				reportHandlerErr(errors.New("reply-snapshot logical is nil"))
				return
			}
			// Switch the logical connection's protection profile before
			// replying; the reply must still be readable by the client.
			msg.LogicalConn.applyTransportProtectionProfile(alternate)
			if err := msg.Reply([]byte("ack")); err != nil {
				reportHandlerErr(err)
			}
		})
	})

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}

	left, right := net.Pipe()
	defer right.Close()
	bootstrapPeerAttachConnForTest(t, server, right)
	if err := client.ConnectByConn(left); err != nil {
		t.Fatalf("ConnectByConn failed: %v", err)
	}
	defer func() {
		client.setByeFromServer(true)
		_ = client.Stop()
	}()

	reply, err := client.SendWait("reply-snapshot", []byte("ping"), time.Second)
	if err != nil {
		t.Fatalf("SendWait failed: %v", err)
	}
	if got, want := string(reply.Value), "ack"; got != want {
		t.Fatalf("reply value = %q, want %q", got, want)
	}

	select {
	case err := <-handlerErr:
		t.Fatalf("reply-snapshot handler failed: %v", err)
	default:
	}
}

// TestServerHandleReceivedSignalReliabilityUsesInboundConnForStaleTransport
// enables signal reliability, then passes a stale (generation+1) transport
// snapshot together with an inbound pipe end to
// handleReceivedSignalReliabilityTransport. It checks that the first receive is
// not flagged as a duplicate and that an EnvelopeSignalAck for signal ID 22 is
// read from the other end of the pipe.
func TestServerHandleReceivedSignalReliabilityUsesInboundConnForStaleTransport(t *testing.T) {
	server := NewServer().(*ServerCommon)
	UseLegacySecurityServer(server)
	if err := UseSignalReliabilityServer(server, &SignalReliabilityOptions{Enabled: true}); err != nil {
		t.Fatalf("UseSignalReliabilityServer failed: %v", err)
	}
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	server.setServerSessionRuntime(&serverSessionRuntime{
		stopCtx: stopCtx,
		stopFn:  stopFn,
		queue:   stario.NewQueueCtx(stopCtx, 4, math.MaxUint32),
	})

	firstLeft, firstRight := net.Pipe()
	defer firstLeft.Close()
	defer firstRight.Close()

	logical, _, _ := newRegisteredServerLogicalForTest(t, server, "signal-ack-inbound-stale", firstLeft, stopCtx, stopFn)
	client := clientConnFromLogical(logical)
	client.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey)

	// Requesting generation+1 deliberately produces a snapshot that is not current.
	staleTransport := logical.transportConnSnapshotForInbound(firstLeft, nil, logical.transportGenerationSnapshot()+1, true)
	if staleTransport == nil {
		t.Fatal("stale transport snapshot should exist")
	}
	if staleTransport.IsCurrent() {
		t.Fatal("stale transport should not be current for mismatched generation")
	}

	// The handler writes its ack to the unbuffered pipe, so it runs in a
	// goroutine while this goroutine reads the other end.
	done := make(chan bool, 1)
	go func() {
		done <- server.handleReceivedSignalReliabilityTransport(logical, staleTransport, firstLeft, TransferMsg{
			ID:   22,
			Key:  "signal-reliable",
			Type: MSG_ASYNC,
		})
	}()

	env := readServerEnvelopeFromConn(t, server, logical, firstRight, time.Second)
	select {
	case duplicate := <-done:
		if duplicate {
			t.Fatal("first reliable signal receive should not be duplicate")
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for signal reliability handler to finish")
	}

	if env.Kind != EnvelopeSignalAck {
		t.Fatalf("ack envelope kind = %v, want %v", env.Kind, EnvelopeSignalAck)
	}
	if env.ID != 22 {
		t.Fatalf("ack signal id = %d, want %d", env.ID, 22)
	}
}

// TestServerDispatchFileEnvelopeUsesInboundConnForStaleTransportAck passes a
// stale (generation+1) transport snapshot and an inbound pipe end to
// dispatchFileEnvelope with a file meta envelope for "file-1". It checks that
// an EnvelopeAck for FileID "file-1", stage "meta", offset 0 is read from the
// other end of the pipe.
func TestServerDispatchFileEnvelopeUsesInboundConnForStaleTransportAck(t *testing.T) {
	server := NewServer().(*ServerCommon)
	UseLegacySecurityServer(server)
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	server.setServerSessionRuntime(&serverSessionRuntime{
		stopCtx: stopCtx,
		stopFn:  stopFn,
		queue:   stario.NewQueueCtx(stopCtx, 4, math.MaxUint32),
	})

	firstLeft, firstRight := net.Pipe()
	defer firstLeft.Close()
	defer firstRight.Close()

	logical, _, _ := newRegisteredServerLogicalForTest(t, server, "file-ack-inbound-stale", firstLeft, stopCtx, stopFn)
	client := clientConnFromLogical(logical)
	client.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey)

	// Requesting generation+1 deliberately produces a snapshot that is not current.
	staleTransport := logical.transportConnSnapshotForInbound(firstLeft, nil, logical.transportGenerationSnapshot()+1, true)
	if staleTransport == nil {
		t.Fatal("stale transport snapshot should exist")
	}
	if staleTransport.IsCurrent() {
		t.Fatal("stale transport should not be current for mismatched generation")
	}

	// Dispatch runs concurrently because its ack write blocks until the pipe
	// is read below.
	done := make(chan struct{})
	go func() {
		server.dispatchFileEnvelope(logical, staleTransport, firstLeft, newFileMetaEnvelope("file-1", "demo.bin", 4, "", 0, 0), time.Now())
		close(done)
	}()

	env := readServerEnvelopeFromConn(t, server, logical, firstRight, time.Second)
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for file dispatch to finish")
	}

	if env.Kind != EnvelopeAck {
		t.Fatalf("file ack envelope kind = %v, want %v", env.Kind, EnvelopeAck)
	}
	if env.File.FileID != "file-1" {
		t.Fatalf("file ack file id = %q, want %q", env.File.FileID, "file-1")
	}
	if env.File.Stage != "meta" {
		t.Fatalf("file ack stage = %q, want %q", env.File.Stage, "meta")
	}
	if env.File.Offset != 0 {
		t.Fatalf("file ack offset = %d, want %d", env.File.Offset, 0)
	}
}