package notify import ( "b612.me/stario" "context" "errors" "math" "net" "os" "testing" "time" ) func TestServerLogicalAndTransportLookupAPIs(t *testing.T) { server := NewServer().(*ServerCommon) left, right := net.Pipe() defer left.Close() defer right.Close() logical := server.bootstrapAcceptedLogical("logical-lookup", nil, left) if logical == nil { t.Fatal("bootstrapAcceptedLogical should return logical") } if got := server.GetLogicalConn(logical.ClientID); got != logical { t.Fatalf("GetLogicalConn mismatch: got %+v want %+v", got, logical) } transportByID := server.GetCurrentTransportConn(logical.ClientID) if transportByID == nil { t.Fatal("GetCurrentTransportConn should expose current transport") } transportByLogical := server.GetCurrentTransportConnByLogical(logical) if transportByLogical == nil { t.Fatal("GetCurrentTransportConnByLogical should expose current transport") } if got, want := transportByID.ClientID(), logical.ClientID; got != want { t.Fatalf("transport client id mismatch: got %q want %q", got, want) } if got, want := transportByID.TransportGeneration(), transportByLogical.TransportGeneration(); got != want { t.Fatalf("transport generation mismatch: got %d want %d", got, want) } if !transportByID.IsCurrent() || !transportByLogical.IsCurrent() { t.Fatal("lookup transports should be current") } list := server.GetCurrentTransportConnList() if len(list) != 1 { t.Fatalf("GetCurrentTransportConnList len = %d, want 1", len(list)) } if got, want := list[0].ClientID(), logical.ClientID; got != want { t.Fatalf("transport list client id mismatch: got %q want %q", got, want) } } func TestServerSendLogicalAndTransportAPIs(t *testing.T) { server := NewServer().(*ServerCommon) UseLegacySecurityServer(server) stopCtx, stopFn := context.WithCancel(context.Background()) defer stopFn() server.setServerSessionRuntime(&serverSessionRuntime{ stopCtx: stopCtx, stopFn: stopFn, queue: stario.NewQueueCtx(stopCtx, 4, math.MaxUint32), }) server.markSessionStarted() defer server.markSessionStopped("test done", nil) left, right := net.Pipe() defer left.Close() defer right.Close() logical, _, _ := newRegisteredServerLogicalForTest(t, server, "api-send", left, stopCtx, stopFn) logical.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey) type readResult struct { msg TransferMsg err error } readOneAsync := func() <-chan readResult { t.Helper() ch := make(chan readResult, 1) go func() { _ = right.SetReadDeadline(time.Now().Add(time.Second)) reader := stario.NewFrameReader(right, nil) payload, err := reader.Next() if err != nil { ch <- readResult{err: err} return } env, err := server.decodeEnvelopeLogical(logical, payload) if err != nil { ch <- readResult{err: err} return } msg, err := unwrapTransferMsgEnvelope(env, server.sequenceDe) ch <- readResult{msg: msg, err: err} }() return ch } logicalRead := readOneAsync() if err := server.SendLogical(logical, "logical", MsgVal("payload")); err != nil { t.Fatalf("SendLogical failed: %v", err) } if got := <-logicalRead; got.err != nil { t.Fatalf("SendLogical decode failed: %v", got.err) } else if got.msg.Key != "logical" || got.msg.Type != MSG_ASYNC || string(got.msg.Value) != "payload" { t.Fatalf("SendLogical decoded message mismatch: %+v", got.msg) } transport := server.GetCurrentTransportConn(logical.ClientID) if transport == nil { t.Fatal("GetCurrentTransportConn should expose current transport") } transportRead := readOneAsync() if err := server.SendTransport(transport, "transport", MsgVal("payload")); err != nil { t.Fatalf("SendTransport failed: %v", err) } if got := <-transportRead; got.err != nil { t.Fatalf("SendTransport decode failed: %v", got.err) } else if got.msg.Key != "transport" || got.msg.Type != MSG_ASYNC || string(got.msg.Value) != "payload" { t.Fatalf("SendTransport decoded message mismatch: %+v", got.msg) } } func TestServerSendFileTransportRejectsStaleTransport(t *testing.T) { server := NewServer().(*ServerCommon) UseLegacySecurityServer(server) server.markSessionStarted() defer server.markSessionStopped("test done", nil) firstLeft, firstRight := net.Pipe() defer firstRight.Close() stopCtx, stopFn := context.WithCancel(context.Background()) defer stopFn() logical, _, _ := newRegisteredServerLogicalForTest(t, server, "api-file-stale", firstLeft, stopCtx, stopFn) logical.applyClientConnAttachmentProfile(0, 100*time.Millisecond, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey) staleTransport := server.GetCurrentTransportConn(logical.ClientID) if staleTransport == nil { t.Fatal("initial transport should exist") } secondLeft, secondRight := net.Pipe() defer secondLeft.Close() defer secondRight.Close() if err := logical.attachClientConnSessionTransport(secondLeft); err != nil { t.Fatalf("attachClientConnSessionTransport failed: %v", err) } file, err := os.CreateTemp(t.TempDir(), "notify-send-file-*") if err != nil { t.Fatalf("CreateTemp failed: %v", err) } if _, err := file.WriteString("payload"); err != nil { t.Fatalf("WriteString failed: %v", err) } if err := file.Close(); err != nil { t.Fatalf("Close temp file failed: %v", err) } err = server.SendFileTransport(context.Background(), staleTransport, file.Name()) if !errors.Is(err, errTransportDetached) { t.Fatalf("SendFileTransport stale error = %v, want errors.Is(..., %v)", err, errTransportDetached) } }