package notify

import (
	"bytes"
	"context"
	"io"
	"os"
	"strconv"
	"sync"
	"testing"
	"time"

	itransfer "b612.me/notify/internal/transfer"
)

// memoryTransferResumeStore is an in-memory snapshot store that the tests in
// this file hand to SetTransferResumeStore. Snapshots are indexed by the
// string produced by memoryTransferResumeStoreKey and guarded by mu.
type memoryTransferResumeStore struct {
	mu        sync.Mutex
	snapshots map[string]TransferSnapshot
}

// newMemoryTransferResumeStore returns an empty store with its map allocated.
func newMemoryTransferResumeStore() *memoryTransferResumeStore {
	return &memoryTransferResumeStore{
		snapshots: make(map[string]TransferSnapshot),
	}
}

// SaveTransferSnapshot stores snapshot under its key, replacing any entry
// that already has the same key. The context is ignored and the returned
// error is always nil.
func (s *memoryTransferResumeStore) SaveTransferSnapshot(_ context.Context, snapshot TransferSnapshot) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	// Allocate lazily so a zero-value store does not panic on its first write.
	if s.snapshots == nil {
		s.snapshots = make(map[string]TransferSnapshot)
	}
	s.snapshots[memoryTransferResumeStoreKey(snapshot)] = snapshot
	return nil
}

// DeleteTransferSnapshot removes the entry that shares snapshot's key, if one
// exists. The context is ignored and the returned error is always nil.
func (s *memoryTransferResumeStore) DeleteTransferSnapshot(_ context.Context, snapshot TransferSnapshot) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.snapshots, memoryTransferResumeStoreKey(snapshot))
	return nil
}

// LoadTransferSnapshots returns a new slice holding every stored snapshot.
// The order follows map iteration and is therefore unspecified. The context
// is ignored and the returned error is always nil.
func (s *memoryTransferResumeStore) LoadTransferSnapshots(_ context.Context) ([]TransferSnapshot, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	out := make([]TransferSnapshot, 0, len(s.snapshots))
	for _, snapshot := range s.snapshots {
		out = append(out, snapshot)
	}
	return out, nil
}

// memoryTransferResumeStoreKey builds the map key for snapshot from its
// direction, runtime scope and ID, joined with "|".
//
// The direction is rendered in decimal rather than converted through a rune:
// a rune conversion yields control characters for small values and maps every
// invalid code point to U+FFFD, which would let distinct directions collide.
//
// NOTE(review): the separator is not escaped, so a scope or ID that itself
// contains "|" could produce the same key as a different pair.
func memoryTransferResumeStoreKey(snapshot TransferSnapshot) string {
	return strconv.Itoa(int(snapshot.Direction)) + "|" + snapshot.RuntimeScope + "|" + snapshot.ID
}

// transferOffsetSink is a test sink that discards written data and reports a
// fixed offset from NextOffset.
type transferOffsetSink struct {
	nextOffset int64
}

// WriteAt discards p and reports it as fully written. A negative offset is
// rejected with io.ErrShortWrite. nextOffset is not modified by writes.
func (s *transferOffsetSink) WriteAt(p []byte, off int64) (int, error) {
	if off < 0 {
		return 0, io.ErrShortWrite
	}
	return len(p), nil
}

// NextOffset returns the configured nextOffset, or 0 for a nil receiver.
func (s *transferOffsetSink) NextOffset() int64 {
	if s == nil {
		return 0
	}
	return s.nextOffset
}

// TestTransferResumeStorePersistsAndRecoversSnapshots records send options and
// send progress on a client's transfer runtime, checks that the configured
// store holds exactly one snapshot carrying those values, and then checks that
// a second client sharing the store exposes the same values after
// RecoverTransferSnapshots.
func TestTransferResumeStorePersistsAndRecoversSnapshots(t *testing.T) {
	store := newMemoryTransferResumeStore()
	client := NewClient().(*ClientCommon)
	client.SetTransferResumeStore(store)
	runtime := client.getTransferRuntime()
	runtime.ensureTransferDescriptor(fileTransferDirectionSend, clientFileScope(), clientFileScope(), 0, itransfer.Descriptor{
		ID:      "persist-send",
		Channel: itransfer.DataChannel,
		Size:    12,
	})
	runtime.recordSendOptions(fileTransferDirectionSend, clientFileScope(), "persist-send", 4096, 3, 12288)
	runtime.recordSend(fileTransferDirectionSend, clientFileScope(), "persist-send", 5)

	// The store must already reflect what was recorded on the runtime.
	snapshots, err := store.LoadTransferSnapshots(context.Background())
	if err != nil {
		t.Fatalf("LoadTransferSnapshots failed: %v", err)
	}
	if len(snapshots) != 1 {
		t.Fatalf("store snapshot len = %d, want 1", len(snapshots))
	}
	if got, want := snapshots[0].SentBytes, int64(5); got != want {
		t.Fatalf("stored sent bytes = %d, want %d", got, want)
	}
	if got, want := snapshots[0].ChunkSize, 4096; got != want {
		t.Fatalf("stored chunk size = %d, want %d", got, want)
	}
	if got, want := snapshots[0].Parallelism, 3; got != want {
		t.Fatalf("stored parallelism = %d, want %d", got, want)
	}
	if got, want := snapshots[0].MaxInflightBytes, int64(12288); got != want {
		t.Fatalf("stored max inflight bytes = %d, want %d", got, want)
	}

	// A fresh client that shares the store should see the same snapshot.
	recovered := NewClient().(*ClientCommon)
	recovered.SetTransferResumeStore(store)
	if err := recovered.RecoverTransferSnapshots(context.Background()); err != nil {
		t.Fatalf("RecoverTransferSnapshots failed: %v", err)
	}
	snapshot, ok, err := GetClientTransferSnapshotByIDScope(recovered, "persist-send", clientFileScope())
	if err != nil {
		t.Fatalf("GetClientTransferSnapshotByIDScope failed: %v", err)
	}
	if !ok {
		t.Fatal("recovered snapshot missing")
	}
	if got, want := snapshot.SentBytes, int64(5); got != want {
		t.Fatalf("recovered sent bytes = %d, want %d", got, want)
	}
	if got, want := snapshot.ChunkSize, 4096; got != want {
		t.Fatalf("recovered chunk size = %d, want %d", got, want)
	}
	if got, want := snapshot.Parallelism, 3; got != want {
		t.Fatalf("recovered parallelism = %d, want %d", got, want)
	}
	if got, want := snapshot.MaxInflightBytes, int64(12288); got != want {
		t.Fatalf("recovered max inflight bytes = %d, want %d", got, want)
	}
}

// TestTransferStateRestoreReceiveSessionFromRuntimeSnapshot records 4 received
// bytes for a transfer in a transfer runtime and checks that
// restoreReceiveSession reports a restored session whose next offset is 4.
func TestTransferStateRestoreReceiveSessionFromRuntimeSnapshot(t *testing.T) {
	state := newTransferState()
	state.setHandler(func(info TransferAcceptInfo) (TransferReceiveOptions, error) {
		return TransferReceiveOptions{
			Sink: &transferMemorySink{},
		}, nil
	})
	runtime := newTransferRuntime()
	runtime.ensureTransferDescriptor(fileTransferDirectionReceive, "runtime-scope", "public-scope", 0, itransfer.Descriptor{
		ID:      "restore-rx",
		Channel: itransfer.DataChannel,
		Size:    10,
	})
	runtime.recordReceive(fileTransferDirectionReceive, "runtime-scope", "restore-rx", 4)
	session, restored, err := state.restoreReceiveSession(runtime, "public-scope", "runtime-scope", nil, nil, 0, TransferDescriptor{
		ID:      "restore-rx",
		Channel: TransferChannelData,
		Size:    10,
	})
	if err != nil {
		t.Fatalf("restoreReceiveSession failed: %v", err)
	}
	if !restored {
		t.Fatal("restoreReceiveSession should restore session")
	}
	if got, want := session.nextOffsetSnapshot(), int64(4); got != want {
		t.Fatalf("restored next offset = %d, want %d", got, want)
	}
}

// TestTransferAcceptBeginUsesSinkInitialOffset installs a handler whose sink
// reports a next offset of 12 and checks that the acceptBegin response carries
// that same NextOffset.
func TestTransferAcceptBeginUsesSinkInitialOffset(t *testing.T) {
	state := newTransferState()
	state.setHandler(func(info TransferAcceptInfo) (TransferReceiveOptions, error) {
		return TransferReceiveOptions{
			Sink: &transferOffsetSink{nextOffset: 12},
		}, nil
	})
	resp, err := state.acceptBegin(nil, TransferBeginRequest{
		TransferID: "offset-begin",
		Channel:    TransferChannelData,
		Size:       32,
	}, clientFileScope(), clientFileScope(), 0, nil, nil)
	if err != nil {
		t.Fatalf("acceptBegin failed: %v", err)
	}
	if got, want := resp.NextOffset, int64(12); got != want {
		t.Fatalf("acceptBegin next offset = %d, want %d", got, want)
	}
}

// TestTransferStateRestoreFileReceiveSessionUsesCheckpointOffset writes the
// first third of a file through one receive pool, then restores the transfer
// through a second pool that points at the same directory. The restored
// session is expected to resume at the length of the chunk already written,
// even though the runtime recorded only half that many bytes. The remainder
// is then written and committed, the final file must match the source data,
// and the checkpoint file must be gone.
func TestTransferStateRestoreFileReceiveSessionUsesCheckpointOffset(t *testing.T) {
	receiveDir := t.TempDir()
	data := bytes.Repeat([]byte("restore-file-transfer-"), 2048)
	checksum := transferTestChecksum(data)
	packet := FilePacket{
		FileID:   "restore-file",
		Name:     "payload.bin",
		Size:     int64(len(data)),
		Mode:     0o600,
		ModTime:  time.Now().UnixNano(),
		Checksum: checksum,
	}

	// Phase 1: start the transfer in the original pool and write one chunk.
	originalPool := newFileReceivePool()
	if err := originalPool.setDir(receiveDir); err != nil {
		t.Fatalf("setDir failed: %v", err)
	}
	now := time.Now()
	if _, err := originalPool.onMeta(clientFileScope(), packet, now); err != nil {
		t.Fatalf("onMeta failed: %v", err)
	}
	partial := len(data) / 3
	if _, err := originalPool.onChunk(clientFileScope(), FilePacket{
		FileID: packet.FileID,
		Offset: 0,
		Chunk:  append([]byte(nil), data[:partial]...),
	}, now.Add(10*time.Millisecond)); err != nil {
		t.Fatalf("onChunk failed: %v", err)
	}
	// checkpointPathLocked is called with the pool mutex held.
	originalPool.mu.Lock()
	checkpointPath := originalPool.checkpointPathLocked(clientFileScope(), packet.FileID)
	originalPool.mu.Unlock()
	if _, err := os.Stat(checkpointPath); err != nil {
		t.Fatalf("checkpoint missing before restore: %v", err)
	}

	// Phase 2: the runtime deliberately records less progress (partial/2) than
	// was written to disk, so the assertion below can tell the sources apart.
	runtime := newTransferRuntime()
	runtime.ensureTransferDescriptor(fileTransferDirectionReceive, clientFileScope(), clientFileScope(), 0, itransfer.Descriptor{
		ID:       packet.FileID,
		Channel:  itransfer.DataChannel,
		Size:     int64(len(data)),
		Checksum: checksum,
		Metadata: itransfer.Metadata{
			fileTransferMetadataKindKey:    fileTransferMetadataKindValue,
			fileTransferMetadataNameKey:    packet.Name,
			fileTransferMetadataModeKey:    strconv.FormatUint(uint64(packet.Mode), 10),
			fileTransferMetadataModTimeKey: strconv.FormatInt(packet.ModTime, 10),
		},
	})
	runtime.recordReceive(fileTransferDirectionReceive, clientFileScope(), packet.FileID, int64(partial/2))

	// Phase 3: restore through a new pool over the same directory.
	restoredPool := newFileReceivePool()
	if err := restoredPool.setDir(receiveDir); err != nil {
		t.Fatalf("restored setDir failed: %v", err)
	}
	state := newTransferState()
	state.setBuiltinHandler(func(info TransferAcceptInfo) (TransferReceiveOptions, bool, error) {
		sink, err := newFileTransferReceiveSink(restoredPool, clientFileScope(), packet, nil)
		if err != nil {
			return TransferReceiveOptions{}, true, err
		}
		return TransferReceiveOptions{
			Descriptor: cloneTransferDescriptor(info.Descriptor),
			Sink:       sink,
		}, true, nil
	})
	desc := TransferDescriptor{
		ID:       packet.FileID,
		Channel:  TransferChannelData,
		Size:     int64(len(data)),
		Checksum: checksum,
		Metadata: map[string]string{
			fileTransferMetadataKindKey:    fileTransferMetadataKindValue,
			fileTransferMetadataNameKey:    packet.Name,
			fileTransferMetadataModeKey:    strconv.FormatUint(uint64(packet.Mode), 10),
			fileTransferMetadataModTimeKey: strconv.FormatInt(packet.ModTime, 10),
		},
	}
	session, restored, err := state.restoreReceiveSession(runtime, clientFileScope(), clientFileScope(), nil, nil, 0, desc)
	if err != nil {
		t.Fatalf("restoreReceiveSession failed: %v", err)
	}
	if !restored {
		t.Fatal("restoreReceiveSession should restore file session")
	}
	if got, want := session.nextOffsetSnapshot(), int64(partial); got != want {
		t.Fatalf("restored file next offset = %d, want %d", got, want)
	}

	// Phase 4: finish the transfer and verify the result on disk.
	if err := session.writeSegment(runtime, packet.FileID, int64(partial), data[partial:]); err != nil {
		t.Fatalf("writeSegment failed: %v", err)
	}
	if err := session.commit(context.Background(), runtime, packet.FileID); err != nil {
		t.Fatalf("commit failed: %v", err)
	}
	restoredPool.mu.Lock()
	completed := restoredPool.completed[fileReceiveKey(clientFileScope(), packet.FileID)]
	restoredPool.mu.Unlock()
	if completed == nil {
		t.Fatal("completed file session missing after commit")
	}
	received, err := os.ReadFile(completed.finalPath)
	if err != nil {
		t.Fatalf("ReadFile failed: %v", err)
	}
	if !bytes.Equal(received, data) {
		t.Fatalf("restored file content mismatch: got %d want %d", len(received), len(data))
	}
	if _, err := os.Stat(checkpointPath); !os.IsNotExist(err) {
		t.Fatalf("checkpoint should be removed after commit, stat err = %v", err)
	}
}