// notify/transfer_resume_store_test.go — tests for transfer resume-store
// persistence, receive-session restoration, and checkpoint-based resume.
package notify
import (
"bytes"
"context"
"io"
"os"
"strconv"
"sync"
"testing"
"time"
itransfer "b612.me/notify/internal/transfer"
)
// memoryTransferResumeStore is an in-memory TransferResumeStore used by the
// tests in this file: snapshots are held in a map guarded by mu, keyed by
// memoryTransferResumeStoreKey.
type memoryTransferResumeStore struct {
	mu        sync.Mutex
	snapshots map[string]TransferSnapshot
}
// newMemoryTransferResumeStore returns an empty in-memory resume store ready
// for use; the snapshot map is pre-initialized so writes never hit a nil map.
func newMemoryTransferResumeStore() *memoryTransferResumeStore {
	store := new(memoryTransferResumeStore)
	store.snapshots = make(map[string]TransferSnapshot)
	return store
}
// SaveTransferSnapshot stores snapshot under its composite key, replacing any
// previous entry for the same transfer. It never fails.
func (s *memoryTransferResumeStore) SaveTransferSnapshot(_ context.Context, snapshot TransferSnapshot) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.snapshots[memoryTransferResumeStoreKey(snapshot)] = snapshot
	return nil
}
// DeleteTransferSnapshot removes the stored entry matching snapshot's key, if
// present. Deleting a missing entry is a no-op; it never fails.
func (s *memoryTransferResumeStore) DeleteTransferSnapshot(_ context.Context, snapshot TransferSnapshot) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.snapshots, memoryTransferResumeStoreKey(snapshot))
	return nil
}
// LoadTransferSnapshots returns a copy of every stored snapshot. Map iteration
// order is random, so callers must not rely on the slice ordering.
func (s *memoryTransferResumeStore) LoadTransferSnapshots(_ context.Context) ([]TransferSnapshot, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	result := make([]TransferSnapshot, 0, len(s.snapshots))
	for _, snap := range s.snapshots {
		result = append(result, snap)
	}
	return result, nil
}
// memoryTransferResumeStoreKey derives the map key for a snapshot from its
// direction, runtime scope, and ID, joined with "|".
//
// The direction is rendered as its decimal value via strconv.Itoa. The
// previous string(rune(snapshot.Direction)) form is the conversion go vet
// flags: it produces the code point's UTF-8 bytes, and any invalid code point
// (e.g. the surrogate range) collapses to U+FFFD, so distinct directions
// could in principle map to the same key. The key is only used internally by
// this store, so the format change is invisible to callers.
func memoryTransferResumeStoreKey(snapshot TransferSnapshot) string {
	return strconv.Itoa(int(snapshot.Direction)) + "|" + snapshot.RuntimeScope + "|" + snapshot.ID
}
type transferOffsetSink struct {
nextOffset int64
}
func (s *transferOffsetSink) WriteAt(p []byte, off int64) (int, error) {
if off < 0 {
return 0, io.ErrShortWrite
}
return len(p), nil
}
func (s *transferOffsetSink) NextOffset() int64 {
if s == nil {
return 0
}
return s.nextOffset
}
// TestTransferResumeStorePersistsAndRecoversSnapshots checks that send-side
// progress recorded through a client's transfer runtime is persisted into the
// configured resume store, and that a brand-new client wired to the same
// store recovers an identical snapshot.
func TestTransferResumeStorePersistsAndRecoversSnapshots(t *testing.T) {
	resumeStore := newMemoryTransferResumeStore()

	// Sender: register a descriptor, then record options and 5 sent bytes.
	sender := NewClient().(*ClientCommon)
	sender.SetTransferResumeStore(resumeStore)
	rt := sender.getTransferRuntime()
	rt.ensureTransferDescriptor(fileTransferDirectionSend, clientFileScope(), clientFileScope(), 0, itransfer.Descriptor{
		ID:      "persist-send",
		Channel: itransfer.DataChannel,
		Size:    12,
	})
	rt.recordSendOptions(fileTransferDirectionSend, clientFileScope(), "persist-send", 4096, 3, 12288)
	rt.recordSend(fileTransferDirectionSend, clientFileScope(), "persist-send", 5)

	// The store must now hold exactly one snapshot carrying that state.
	persisted, err := resumeStore.LoadTransferSnapshots(context.Background())
	if err != nil {
		t.Fatalf("LoadTransferSnapshots failed: %v", err)
	}
	if len(persisted) != 1 {
		t.Fatalf("store snapshot len = %d, want 1", len(persisted))
	}
	if persisted[0].SentBytes != 5 {
		t.Fatalf("stored sent bytes = %d, want %d", persisted[0].SentBytes, int64(5))
	}
	if persisted[0].ChunkSize != 4096 {
		t.Fatalf("stored chunk size = %d, want %d", persisted[0].ChunkSize, 4096)
	}
	if persisted[0].Parallelism != 3 {
		t.Fatalf("stored parallelism = %d, want %d", persisted[0].Parallelism, 3)
	}
	if persisted[0].MaxInflightBytes != 12288 {
		t.Fatalf("stored max inflight bytes = %d, want %d", persisted[0].MaxInflightBytes, int64(12288))
	}

	// A fresh client sharing the store must recover the same snapshot.
	fresh := NewClient().(*ClientCommon)
	fresh.SetTransferResumeStore(resumeStore)
	if err := fresh.RecoverTransferSnapshots(context.Background()); err != nil {
		t.Fatalf("RecoverTransferSnapshots failed: %v", err)
	}
	snap, ok, err := GetClientTransferSnapshotByIDScope(fresh, "persist-send", clientFileScope())
	if err != nil {
		t.Fatalf("GetClientTransferSnapshotByIDScope failed: %v", err)
	}
	if !ok {
		t.Fatal("recovered snapshot missing")
	}
	if snap.SentBytes != 5 {
		t.Fatalf("recovered sent bytes = %d, want %d", snap.SentBytes, int64(5))
	}
	if snap.ChunkSize != 4096 {
		t.Fatalf("recovered chunk size = %d, want %d", snap.ChunkSize, 4096)
	}
	if snap.Parallelism != 3 {
		t.Fatalf("recovered parallelism = %d, want %d", snap.Parallelism, 3)
	}
	if snap.MaxInflightBytes != 12288 {
		t.Fatalf("recovered max inflight bytes = %d, want %d", snap.MaxInflightBytes, int64(12288))
	}
}
// TestTransferStateRestoreReceiveSessionFromRuntimeSnapshot checks that a
// receive session can be rebuilt purely from runtime bookkeeping and resumes
// at the byte offset the runtime recorded.
func TestTransferStateRestoreReceiveSessionFromRuntimeSnapshot(t *testing.T) {
	st := newTransferState()
	st.setHandler(func(info TransferAcceptInfo) (TransferReceiveOptions, error) {
		opts := TransferReceiveOptions{Sink: &transferMemorySink{}}
		return opts, nil
	})

	// Runtime knows about a 10-byte transfer with 4 bytes already received.
	rt := newTransferRuntime()
	rt.ensureTransferDescriptor(fileTransferDirectionReceive, "runtime-scope", "public-scope", 0, itransfer.Descriptor{
		ID:      "restore-rx",
		Channel: itransfer.DataChannel,
		Size:    10,
	})
	rt.recordReceive(fileTransferDirectionReceive, "runtime-scope", "restore-rx", 4)

	session, restored, err := st.restoreReceiveSession(rt, "public-scope", "runtime-scope", nil, nil, 0, TransferDescriptor{
		ID:      "restore-rx",
		Channel: TransferChannelData,
		Size:    10,
	})
	if err != nil {
		t.Fatalf("restoreReceiveSession failed: %v", err)
	}
	if !restored {
		t.Fatal("restoreReceiveSession should restore session")
	}
	if off := session.nextOffsetSnapshot(); off != 4 {
		t.Fatalf("restored next offset = %d, want %d", off, int64(4))
	}
}
// TestTransferAcceptBeginUsesSinkInitialOffset checks that acceptBegin seeds
// the negotiated resume position from the handler-provided sink's NextOffset.
func TestTransferAcceptBeginUsesSinkInitialOffset(t *testing.T) {
	st := newTransferState()
	st.setHandler(func(info TransferAcceptInfo) (TransferReceiveOptions, error) {
		return TransferReceiveOptions{Sink: &transferOffsetSink{nextOffset: 12}}, nil
	})

	req := TransferBeginRequest{
		TransferID: "offset-begin",
		Channel:    TransferChannelData,
		Size:       32,
	}
	resp, err := st.acceptBegin(nil, req, clientFileScope(), clientFileScope(), 0, nil, nil)
	if err != nil {
		t.Fatalf("acceptBegin failed: %v", err)
	}
	if resp.NextOffset != 12 {
		t.Fatalf("acceptBegin next offset = %d, want %d", resp.NextOffset, int64(12))
	}
}
// TestTransferStateRestoreFileReceiveSessionUsesCheckpointOffset exercises the
// end-to-end file-receive resume flow: one pool partially receives a file
// (leaving an on-disk checkpoint), then a fresh pool plus transfer state
// restore the session, resume at the checkpointed offset (not the stale
// runtime-recorded offset), finish the payload, and commit — after which the
// checkpoint file must be removed and the final file must match the payload.
func TestTransferStateRestoreFileReceiveSessionUsesCheckpointOffset(t *testing.T) {
	receiveDir := t.TempDir()
	data := bytes.Repeat([]byte("restore-file-transfer-"), 2048)
	checksum := transferTestChecksum(data)
	packet := FilePacket{
		FileID:   "restore-file",
		Name:     "payload.bin",
		Size:     int64(len(data)),
		Mode:     0o600,
		ModTime:  time.Now().UnixNano(),
		Checksum: checksum,
	}
	// First pool: accept metadata, then deliver roughly one third of the
	// payload so that a checkpoint file gets written under receiveDir.
	originalPool := newFileReceivePool()
	if err := originalPool.setDir(receiveDir); err != nil {
		t.Fatalf("setDir failed: %v", err)
	}
	now := time.Now()
	if _, err := originalPool.onMeta(clientFileScope(), packet, now); err != nil {
		t.Fatalf("onMeta failed: %v", err)
	}
	partial := len(data) / 3
	if _, err := originalPool.onChunk(clientFileScope(), FilePacket{
		FileID: packet.FileID,
		Offset: 0,
		Chunk:  append([]byte(nil), data[:partial]...),
	}, now.Add(10*time.Millisecond)); err != nil {
		t.Fatalf("onChunk failed: %v", err)
	}
	// Capture the checkpoint path and verify it exists before restoring.
	originalPool.mu.Lock()
	checkpointPath := originalPool.checkpointPathLocked(clientFileScope(), packet.FileID)
	originalPool.mu.Unlock()
	if _, err := os.Stat(checkpointPath); err != nil {
		t.Fatalf("checkpoint missing before restore: %v", err)
	}
	// Runtime bookkeeping deliberately records only half of the bytes that
	// were actually written (partial/2); the assertion further down expects
	// the restore to resume at the checkpoint's offset, not this stale value.
	runtime := newTransferRuntime()
	runtime.ensureTransferDescriptor(fileTransferDirectionReceive, clientFileScope(), clientFileScope(), 0, itransfer.Descriptor{
		ID:       packet.FileID,
		Channel:  itransfer.DataChannel,
		Size:     int64(len(data)),
		Checksum: checksum,
		Metadata: itransfer.Metadata{
			fileTransferMetadataKindKey:    fileTransferMetadataKindValue,
			fileTransferMetadataNameKey:    packet.Name,
			fileTransferMetadataModeKey:    strconv.FormatUint(uint64(packet.Mode), 10),
			fileTransferMetadataModTimeKey: strconv.FormatInt(packet.ModTime, 10),
		},
	})
	runtime.recordReceive(fileTransferDirectionReceive, clientFileScope(), packet.FileID, int64(partial/2))
	// Second pool pointing at the same directory picks up the checkpoint.
	restoredPool := newFileReceivePool()
	if err := restoredPool.setDir(receiveDir); err != nil {
		t.Fatalf("restored setDir failed: %v", err)
	}
	state := newTransferState()
	// Builtin handler routes the restored transfer into a file sink backed by
	// the new pool.
	state.setBuiltinHandler(func(info TransferAcceptInfo) (TransferReceiveOptions, bool, error) {
		sink, err := newFileTransferReceiveSink(restoredPool, clientFileScope(), packet, nil)
		if err != nil {
			return TransferReceiveOptions{}, true, err
		}
		return TransferReceiveOptions{
			Descriptor: cloneTransferDescriptor(info.Descriptor),
			Sink:       sink,
		}, true, nil
	})
	// Public descriptor mirrors the runtime descriptor registered above.
	desc := TransferDescriptor{
		ID:       packet.FileID,
		Channel:  TransferChannelData,
		Size:     int64(len(data)),
		Checksum: checksum,
		Metadata: map[string]string{
			fileTransferMetadataKindKey:    fileTransferMetadataKindValue,
			fileTransferMetadataNameKey:    packet.Name,
			fileTransferMetadataModeKey:    strconv.FormatUint(uint64(packet.Mode), 10),
			fileTransferMetadataModTimeKey: strconv.FormatInt(packet.ModTime, 10),
		},
	}
	session, restored, err := state.restoreReceiveSession(runtime, clientFileScope(), clientFileScope(), nil, nil, 0, desc)
	if err != nil {
		t.Fatalf("restoreReceiveSession failed: %v", err)
	}
	if !restored {
		t.Fatal("restoreReceiveSession should restore file session")
	}
	// The session must resume exactly where the checkpoint left off.
	if got, want := session.nextOffsetSnapshot(), int64(partial); got != want {
		t.Fatalf("restored file next offset = %d, want %d", got, want)
	}
	// Deliver the remaining bytes and finalize the transfer.
	if err := session.writeSegment(runtime, packet.FileID, int64(partial), data[partial:]); err != nil {
		t.Fatalf("writeSegment failed: %v", err)
	}
	if err := session.commit(context.Background(), runtime, packet.FileID); err != nil {
		t.Fatalf("commit failed: %v", err)
	}
	restoredPool.mu.Lock()
	completed := restoredPool.completed[fileReceiveKey(clientFileScope(), packet.FileID)]
	restoredPool.mu.Unlock()
	if completed == nil {
		t.Fatal("completed file session missing after commit")
	}
	// The reassembled file must match the original payload byte-for-byte.
	received, err := os.ReadFile(completed.finalPath)
	if err != nil {
		t.Fatalf("ReadFile failed: %v", err)
	}
	if !bytes.Equal(received, data) {
		t.Fatalf("restored file content mismatch: got %d want %d", len(received), len(data))
	}
	// Commit must clean up the checkpoint file.
	if _, err := os.Stat(checkpointPath); !os.IsNotExist(err) {
		t.Fatalf("checkpoint should be removed after commit, stat err = %v", err)
	}
}