169 lines
5.5 KiB
Go
169 lines
5.5 KiB
Go
|
|
package notify
|
||
|
|
|
||
|
|
import (
|
||
|
|
"b612.me/stario"
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"math"
|
||
|
|
"net"
|
||
|
|
"os"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestServerLogicalAndTransportLookupAPIs(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
left, right := net.Pipe()
|
||
|
|
defer left.Close()
|
||
|
|
defer right.Close()
|
||
|
|
|
||
|
|
logical := server.bootstrapAcceptedLogical("logical-lookup", nil, left)
|
||
|
|
if logical == nil {
|
||
|
|
t.Fatal("bootstrapAcceptedLogical should return logical")
|
||
|
|
}
|
||
|
|
|
||
|
|
if got := server.GetLogicalConn(logical.ClientID); got != logical {
|
||
|
|
t.Fatalf("GetLogicalConn mismatch: got %+v want %+v", got, logical)
|
||
|
|
}
|
||
|
|
|
||
|
|
transportByID := server.GetCurrentTransportConn(logical.ClientID)
|
||
|
|
if transportByID == nil {
|
||
|
|
t.Fatal("GetCurrentTransportConn should expose current transport")
|
||
|
|
}
|
||
|
|
transportByLogical := server.GetCurrentTransportConnByLogical(logical)
|
||
|
|
if transportByLogical == nil {
|
||
|
|
t.Fatal("GetCurrentTransportConnByLogical should expose current transport")
|
||
|
|
}
|
||
|
|
if got, want := transportByID.ClientID(), logical.ClientID; got != want {
|
||
|
|
t.Fatalf("transport client id mismatch: got %q want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := transportByID.TransportGeneration(), transportByLogical.TransportGeneration(); got != want {
|
||
|
|
t.Fatalf("transport generation mismatch: got %d want %d", got, want)
|
||
|
|
}
|
||
|
|
if !transportByID.IsCurrent() || !transportByLogical.IsCurrent() {
|
||
|
|
t.Fatal("lookup transports should be current")
|
||
|
|
}
|
||
|
|
|
||
|
|
list := server.GetCurrentTransportConnList()
|
||
|
|
if len(list) != 1 {
|
||
|
|
t.Fatalf("GetCurrentTransportConnList len = %d, want 1", len(list))
|
||
|
|
}
|
||
|
|
if got, want := list[0].ClientID(), logical.ClientID; got != want {
|
||
|
|
t.Fatalf("transport list client id mismatch: got %q want %q", got, want)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestServerSendLogicalAndTransportAPIs(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
UseLegacySecurityServer(server)
|
||
|
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||
|
|
defer stopFn()
|
||
|
|
server.setServerSessionRuntime(&serverSessionRuntime{
|
||
|
|
stopCtx: stopCtx,
|
||
|
|
stopFn: stopFn,
|
||
|
|
queue: stario.NewQueueCtx(stopCtx, 4, math.MaxUint32),
|
||
|
|
})
|
||
|
|
server.markSessionStarted()
|
||
|
|
defer server.markSessionStopped("test done", nil)
|
||
|
|
|
||
|
|
left, right := net.Pipe()
|
||
|
|
defer left.Close()
|
||
|
|
defer right.Close()
|
||
|
|
|
||
|
|
logical, _, _ := newRegisteredServerLogicalForTest(t, server, "api-send", left, stopCtx, stopFn)
|
||
|
|
logical.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey)
|
||
|
|
|
||
|
|
type readResult struct {
|
||
|
|
msg TransferMsg
|
||
|
|
err error
|
||
|
|
}
|
||
|
|
readOneAsync := func() <-chan readResult {
|
||
|
|
t.Helper()
|
||
|
|
ch := make(chan readResult, 1)
|
||
|
|
go func() {
|
||
|
|
_ = right.SetReadDeadline(time.Now().Add(time.Second))
|
||
|
|
reader := stario.NewFrameReader(right, nil)
|
||
|
|
payload, err := reader.Next()
|
||
|
|
if err != nil {
|
||
|
|
ch <- readResult{err: err}
|
||
|
|
return
|
||
|
|
}
|
||
|
|
env, err := server.decodeEnvelopeLogical(logical, payload)
|
||
|
|
if err != nil {
|
||
|
|
ch <- readResult{err: err}
|
||
|
|
return
|
||
|
|
}
|
||
|
|
msg, err := unwrapTransferMsgEnvelope(env, server.sequenceDe)
|
||
|
|
ch <- readResult{msg: msg, err: err}
|
||
|
|
}()
|
||
|
|
return ch
|
||
|
|
}
|
||
|
|
|
||
|
|
logicalRead := readOneAsync()
|
||
|
|
if err := server.SendLogical(logical, "logical", MsgVal("payload")); err != nil {
|
||
|
|
t.Fatalf("SendLogical failed: %v", err)
|
||
|
|
}
|
||
|
|
if got := <-logicalRead; got.err != nil {
|
||
|
|
t.Fatalf("SendLogical decode failed: %v", got.err)
|
||
|
|
} else if got.msg.Key != "logical" || got.msg.Type != MSG_ASYNC || string(got.msg.Value) != "payload" {
|
||
|
|
t.Fatalf("SendLogical decoded message mismatch: %+v", got.msg)
|
||
|
|
}
|
||
|
|
|
||
|
|
transport := server.GetCurrentTransportConn(logical.ClientID)
|
||
|
|
if transport == nil {
|
||
|
|
t.Fatal("GetCurrentTransportConn should expose current transport")
|
||
|
|
}
|
||
|
|
transportRead := readOneAsync()
|
||
|
|
if err := server.SendTransport(transport, "transport", MsgVal("payload")); err != nil {
|
||
|
|
t.Fatalf("SendTransport failed: %v", err)
|
||
|
|
}
|
||
|
|
if got := <-transportRead; got.err != nil {
|
||
|
|
t.Fatalf("SendTransport decode failed: %v", got.err)
|
||
|
|
} else if got.msg.Key != "transport" || got.msg.Type != MSG_ASYNC || string(got.msg.Value) != "payload" {
|
||
|
|
t.Fatalf("SendTransport decoded message mismatch: %+v", got.msg)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestServerSendFileTransportRejectsStaleTransport(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
UseLegacySecurityServer(server)
|
||
|
|
server.markSessionStarted()
|
||
|
|
defer server.markSessionStopped("test done", nil)
|
||
|
|
|
||
|
|
firstLeft, firstRight := net.Pipe()
|
||
|
|
defer firstRight.Close()
|
||
|
|
|
||
|
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||
|
|
defer stopFn()
|
||
|
|
logical, _, _ := newRegisteredServerLogicalForTest(t, server, "api-file-stale", firstLeft, stopCtx, stopFn)
|
||
|
|
logical.applyClientConnAttachmentProfile(0, 100*time.Millisecond, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey)
|
||
|
|
|
||
|
|
staleTransport := server.GetCurrentTransportConn(logical.ClientID)
|
||
|
|
if staleTransport == nil {
|
||
|
|
t.Fatal("initial transport should exist")
|
||
|
|
}
|
||
|
|
|
||
|
|
secondLeft, secondRight := net.Pipe()
|
||
|
|
defer secondLeft.Close()
|
||
|
|
defer secondRight.Close()
|
||
|
|
if err := logical.attachClientConnSessionTransport(secondLeft); err != nil {
|
||
|
|
t.Fatalf("attachClientConnSessionTransport failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
file, err := os.CreateTemp(t.TempDir(), "notify-send-file-*")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("CreateTemp failed: %v", err)
|
||
|
|
}
|
||
|
|
if _, err := file.WriteString("payload"); err != nil {
|
||
|
|
t.Fatalf("WriteString failed: %v", err)
|
||
|
|
}
|
||
|
|
if err := file.Close(); err != nil {
|
||
|
|
t.Fatalf("Close temp file failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
err = server.SendFileTransport(context.Background(), staleTransport, file.Name())
|
||
|
|
if !errors.Is(err, errTransportDetached) {
|
||
|
|
t.Fatalf("SendFileTransport stale error = %v, want errors.Is(..., %v)", err, errTransportDetached)
|
||
|
|
}
|
||
|
|
}
|