132 lines
3.9 KiB
Go
132 lines
3.9 KiB
Go
|
|
package notify
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"context"
|
||
|
|
"os"
|
||
|
|
"path/filepath"
|
||
|
|
"sync"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestSendFileUsesTransferKernelAndBuiltinFileReceiver(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||
|
|
}
|
||
|
|
receiveDir := t.TempDir()
|
||
|
|
if err := server.SetFileReceiveDir(receiveDir); err != nil {
|
||
|
|
t.Fatalf("SetFileReceiveDir failed: %v", err)
|
||
|
|
}
|
||
|
|
var serverMu sync.Mutex
|
||
|
|
var serverEvents []FileEvent
|
||
|
|
server.SetFileHandler(func(event FileEvent) {
|
||
|
|
serverMu.Lock()
|
||
|
|
serverEvents = append(serverEvents, event)
|
||
|
|
serverMu.Unlock()
|
||
|
|
})
|
||
|
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||
|
|
t.Fatalf("server Listen failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() { _ = server.Stop() }()
|
||
|
|
|
||
|
|
client := NewClient().(*ClientCommon)
|
||
|
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||
|
|
}
|
||
|
|
var clientMu sync.Mutex
|
||
|
|
var clientEvents []FileEvent
|
||
|
|
client.setFileEventObserver(func(event FileEvent) {
|
||
|
|
clientMu.Lock()
|
||
|
|
clientEvents = append(clientEvents, event)
|
||
|
|
clientMu.Unlock()
|
||
|
|
})
|
||
|
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||
|
|
t.Fatalf("client Connect failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() { _ = client.Stop() }()
|
||
|
|
|
||
|
|
payload := bytes.Repeat([]byte("send-file-transfer-kernel-"), 1024)
|
||
|
|
sendPath := filepath.Join(t.TempDir(), "payload.bin")
|
||
|
|
if err := os.WriteFile(sendPath, payload, 0o600); err != nil {
|
||
|
|
t.Fatalf("WriteFile failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := client.SendFile(context.Background(), sendPath); err != nil {
|
||
|
|
t.Fatalf("SendFile failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
receivedPath := waitForSingleFileInDir(t, receiveDir, 2*time.Second)
|
||
|
|
received, err := os.ReadFile(receivedPath)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("ReadFile failed: %v", err)
|
||
|
|
}
|
||
|
|
if !bytes.Equal(received, payload) {
|
||
|
|
t.Fatalf("received payload mismatch: got %d want %d", len(received), len(payload))
|
||
|
|
}
|
||
|
|
|
||
|
|
clientSnapshots, err := GetClientTransferSnapshots(client)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("GetClientTransferSnapshots failed: %v", err)
|
||
|
|
}
|
||
|
|
serverSnapshots, err := GetServerTransferSnapshots(server)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("GetServerTransferSnapshots failed: %v", err)
|
||
|
|
}
|
||
|
|
if !containsFileTransferSnapshot(clientSnapshots) {
|
||
|
|
t.Fatalf("client snapshots do not contain file transfer metadata: %+v", clientSnapshots)
|
||
|
|
}
|
||
|
|
if !containsFileTransferSnapshot(serverSnapshots) {
|
||
|
|
t.Fatalf("server snapshots do not contain file transfer metadata: %+v", serverSnapshots)
|
||
|
|
}
|
||
|
|
|
||
|
|
clientMu.Lock()
|
||
|
|
serverMu.Lock()
|
||
|
|
defer clientMu.Unlock()
|
||
|
|
defer serverMu.Unlock()
|
||
|
|
if !containsFileEventKind(clientEvents, EnvelopeFileMeta) || !containsFileEventKind(clientEvents, EnvelopeFileEnd) {
|
||
|
|
t.Fatalf("client file events missing meta/end: %+v", clientEvents)
|
||
|
|
}
|
||
|
|
if !containsFileEventKind(serverEvents, EnvelopeFileMeta) || !containsFileEventKind(serverEvents, EnvelopeFileEnd) {
|
||
|
|
t.Fatalf("server file events missing meta/end: %+v", serverEvents)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// waitForSingleFileInDir polls dir every 20ms until it contains a
// non-directory entry, then returns that entry's full path. If several such
// entries exist, it returns the first in os.ReadDir order (sorted by
// filename). The directory is always scanned at least once, and once more
// after the final sleep, even when timeout is zero or negative.
//
// If no file appears within timeout, the test fails via t.Fatalf. The
// message includes the last os.ReadDir error, if any, so an unreadable
// directory is distinguishable from a sender that is merely slow.
//
// NOTE(review): the path is returned as soon as the entry is visible, which
// may be before the writer has finished with it; callers that compare
// contents should tolerate that.
func waitForSingleFileInDir(t *testing.T, dir string, timeout time.Duration) string {
	t.Helper()
	deadline := time.Now().Add(timeout)
	var lastErr error
	for {
		entries, err := os.ReadDir(dir)
		lastErr = err
		// Only trust a clean read: os.ReadDir can return partial entries
		// together with an error.
		if err == nil {
			for _, entry := range entries {
				if !entry.IsDir() {
					return filepath.Join(dir, entry.Name())
				}
			}
		}
		// Test the deadline after scanning so a file that shows up during the
		// last sleep is still picked up.
		if !time.Now().Before(deadline) {
			break
		}
		time.Sleep(20 * time.Millisecond)
	}
	if lastErr != nil {
		t.Fatalf("timed out waiting for received file in %s (last ReadDir error: %v)", dir, lastErr)
	}
	t.Fatalf("timed out waiting for received file in %s", dir)
	return ""
}
|
||
|
|
|
||
|
|
func containsFileTransferSnapshot(list []TransferSnapshot) bool {
|
||
|
|
for _, snapshot := range list {
|
||
|
|
if snapshot.Metadata[fileTransferMetadataKindKey] == fileTransferMetadataKindValue && snapshot.State == TransferStateDone {
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
|
||
|
|
func containsFileEventKind(list []FileEvent, kind EnvelopeKind) bool {
|
||
|
|
for _, event := range list {
|
||
|
|
if event.Kind == kind {
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return false
|
||
|
|
}
|