notify/file_transfer_adapter_test.go

132 lines
3.9 KiB
Go
Raw Normal View History

package notify
import (
"bytes"
"context"
"os"
"path/filepath"
"sync"
"testing"
"time"
)
func TestSendFileUsesTransferKernelAndBuiltinFileReceiver(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
receiveDir := t.TempDir()
if err := server.SetFileReceiveDir(receiveDir); err != nil {
t.Fatalf("SetFileReceiveDir failed: %v", err)
}
var serverMu sync.Mutex
var serverEvents []FileEvent
server.SetFileHandler(func(event FileEvent) {
serverMu.Lock()
serverEvents = append(serverEvents, event)
serverMu.Unlock()
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
t.Fatalf("server Listen failed: %v", err)
}
defer func() { _ = server.Stop() }()
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
var clientMu sync.Mutex
var clientEvents []FileEvent
client.setFileEventObserver(func(event FileEvent) {
clientMu.Lock()
clientEvents = append(clientEvents, event)
clientMu.Unlock()
})
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
t.Fatalf("client Connect failed: %v", err)
}
defer func() { _ = client.Stop() }()
payload := bytes.Repeat([]byte("send-file-transfer-kernel-"), 1024)
sendPath := filepath.Join(t.TempDir(), "payload.bin")
if err := os.WriteFile(sendPath, payload, 0o600); err != nil {
t.Fatalf("WriteFile failed: %v", err)
}
if err := client.SendFile(context.Background(), sendPath); err != nil {
t.Fatalf("SendFile failed: %v", err)
}
receivedPath := waitForSingleFileInDir(t, receiveDir, 2*time.Second)
received, err := os.ReadFile(receivedPath)
if err != nil {
t.Fatalf("ReadFile failed: %v", err)
}
if !bytes.Equal(received, payload) {
t.Fatalf("received payload mismatch: got %d want %d", len(received), len(payload))
}
clientSnapshots, err := GetClientTransferSnapshots(client)
if err != nil {
t.Fatalf("GetClientTransferSnapshots failed: %v", err)
}
serverSnapshots, err := GetServerTransferSnapshots(server)
if err != nil {
t.Fatalf("GetServerTransferSnapshots failed: %v", err)
}
if !containsFileTransferSnapshot(clientSnapshots) {
t.Fatalf("client snapshots do not contain file transfer metadata: %+v", clientSnapshots)
}
if !containsFileTransferSnapshot(serverSnapshots) {
t.Fatalf("server snapshots do not contain file transfer metadata: %+v", serverSnapshots)
}
clientMu.Lock()
serverMu.Lock()
defer clientMu.Unlock()
defer serverMu.Unlock()
if !containsFileEventKind(clientEvents, EnvelopeFileMeta) || !containsFileEventKind(clientEvents, EnvelopeFileEnd) {
t.Fatalf("client file events missing meta/end: %+v", clientEvents)
}
if !containsFileEventKind(serverEvents, EnvelopeFileMeta) || !containsFileEventKind(serverEvents, EnvelopeFileEnd) {
t.Fatalf("server file events missing meta/end: %+v", serverEvents)
}
}
func waitForSingleFileInDir(t *testing.T, dir string, timeout time.Duration) string {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
entries, err := os.ReadDir(dir)
if err == nil {
for _, entry := range entries {
if entry.IsDir() {
continue
}
return filepath.Join(dir, entry.Name())
}
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("timed out waiting for received file in %s", dir)
return ""
}
func containsFileTransferSnapshot(list []TransferSnapshot) bool {
for _, snapshot := range list {
if snapshot.Metadata[fileTransferMetadataKindKey] == fileTransferMetadataKindValue && snapshot.State == TransferStateDone {
return true
}
}
return false
}
func containsFileEventKind(list []FileEvent, kind EnvelopeKind) bool {
for _, event := range list {
if event.Kind == kind {
return true
}
}
return false
}