package transfer import ( "errors" "sort" "sync" "time" ) var ( ErrTransferIDEmpty = errors.New("transfer id is empty") ErrTransferExists = errors.New("transfer already exists") ErrTransferNotFound = errors.New("transfer not found") ErrTransferBytesInvalid = errors.New("transfer bytes must be non-negative") ) type Manager struct { mu sync.Mutex now func() time.Time transfers map[string]*managedTransfer } type managedTransfer struct { snapshot Snapshot } func NewManager() *Manager { return NewManagerWithClock(time.Now) } func NewManagerWithClock(now func() time.Time) *Manager { if now == nil { now = time.Now } return &Manager{ now: now, transfers: make(map[string]*managedTransfer), } } func (m *Manager) StartOutgoing(desc Descriptor) (Snapshot, error) { return m.start(desc, DirectionSend, StateNegotiating) } func (m *Manager) StartIncoming(desc Descriptor) (Snapshot, error) { return m.start(desc, DirectionReceive, StatePrepared) } func (m *Manager) Activate(id string) (Snapshot, error) { return m.update(id, func(snapshot *Snapshot) error { snapshot.State = StateActive return nil }) } func (m *Manager) Pause(id string) (Snapshot, error) { return m.update(id, func(snapshot *Snapshot) error { if snapshot.State.Terminal() { return nil } snapshot.State = StatePaused return nil }) } func (m *Manager) Resume(id string, confirmedBytes int64) (Snapshot, error) { if confirmedBytes < 0 { return Snapshot{}, ErrTransferBytesInvalid } return m.update(id, func(snapshot *Snapshot) error { switch snapshot.Direction { case DirectionSend: if confirmedBytes > snapshot.SentBytes { snapshot.SentBytes = confirmedBytes } snapshot.AckedBytes = confirmedBytes if snapshot.Size > 0 && snapshot.AckedBytes > snapshot.Size { snapshot.AckedBytes = snapshot.Size } case DirectionReceive: if confirmedBytes > snapshot.ReceivedBytes { snapshot.ReceivedBytes = confirmedBytes } if snapshot.Size > 0 && snapshot.ReceivedBytes > snapshot.Size { snapshot.ReceivedBytes = snapshot.Size } } snapshot.State = StateActive snapshot.InflightBytes = inflightBytes(*snapshot) return nil }) } func (m *Manager) RecordSend(id string, sentBytes int64) (Snapshot, error) { if sentBytes < 0 { return Snapshot{}, ErrTransferBytesInvalid } return m.update(id, func(snapshot *Snapshot) error { snapshot.SentBytes += sentBytes if snapshot.Size > 0 && snapshot.SentBytes > snapshot.Size { snapshot.SentBytes = snapshot.Size } snapshot.InflightBytes = inflightBytes(*snapshot) if !snapshot.State.Terminal() { snapshot.State = StateActive } return nil }) } func (m *Manager) RecordReceive(id string, recvBytes int64) (Snapshot, error) { if recvBytes < 0 { return Snapshot{}, ErrTransferBytesInvalid } return m.update(id, func(snapshot *Snapshot) error { snapshot.ReceivedBytes += recvBytes if snapshot.Size > 0 && snapshot.ReceivedBytes > snapshot.Size { snapshot.ReceivedBytes = snapshot.Size } if !snapshot.State.Terminal() { snapshot.State = StateActive } return nil }) } func (m *Manager) SetAckedBytes(id string, ackedBytes int64) (Snapshot, error) { if ackedBytes < 0 { return Snapshot{}, ErrTransferBytesInvalid } return m.update(id, func(snapshot *Snapshot) error { snapshot.AckedBytes = ackedBytes if snapshot.Size > 0 && snapshot.AckedBytes > snapshot.Size { snapshot.AckedBytes = snapshot.Size } if snapshot.AckedBytes > snapshot.SentBytes { snapshot.SentBytes = snapshot.AckedBytes } snapshot.InflightBytes = inflightBytes(*snapshot) return nil }) } func (m *Manager) BeginCommit(id string) (Snapshot, error) { return m.update(id, func(snapshot *Snapshot) error { snapshot.State = StateCommitting return nil }) } func (m *Manager) BeginVerify(id string) (Snapshot, error) { return m.update(id, func(snapshot *Snapshot) error { snapshot.State = StateVerifying return nil }) } func (m *Manager) Complete(id string) (Snapshot, error) { now := m.currentTime() return m.updateWithTime(id, now, func(snapshot *Snapshot, now time.Time) error { snapshot.State = StateDone snapshot.CompletedAt = now.UnixNano() snapshot.InflightBytes = inflightBytes(*snapshot) return nil }) } func (m *Manager) Abort(id string, err error) (Snapshot, error) { return m.finishWithError(id, StateAborted, err) } func (m *Manager) Fail(id string, err error) (Snapshot, error) { return m.finishWithError(id, StateFailed, err) } func (m *Manager) RecordRetry(id string) (Snapshot, error) { return m.update(id, func(snapshot *Snapshot) error { snapshot.RetryCount++ return nil }) } func (m *Manager) RecordTimeout(id string) (Snapshot, error) { return m.update(id, func(snapshot *Snapshot) error { snapshot.TimeoutCount++ return nil }) } func (m *Manager) SetStage(id string, stage string) (Snapshot, error) { return m.update(id, func(snapshot *Snapshot) error { snapshot.Stage = stage return nil }) } func (m *Manager) SetFailureStage(id string, stage string) (Snapshot, error) { return m.update(id, func(snapshot *Snapshot) error { snapshot.LastFailureStage = stage if stage != "" { snapshot.Stage = stage } return nil }) } func (m *Manager) MergeMetadata(id string, metadata Metadata) (Snapshot, error) { return m.update(id, func(snapshot *Snapshot) error { if len(metadata) == 0 { return nil } if snapshot.Metadata == nil { snapshot.Metadata = make(Metadata, len(metadata)) } for key, value := range metadata { if value == "" { delete(snapshot.Metadata, key) continue } snapshot.Metadata[key] = value } return nil }) } func (m *Manager) RecordTelemetry(id string, delta TelemetryDelta) (Snapshot, error) { return m.update(id, func(snapshot *Snapshot) error { if delta.SourceReadDuration > 0 { snapshot.SourceReadDuration += delta.SourceReadDuration } if delta.StreamWriteDuration > 0 { snapshot.StreamWriteDuration += delta.StreamWriteDuration } if delta.SinkWriteDuration > 0 { snapshot.SinkWriteDuration += delta.SinkWriteDuration } if delta.SyncDuration > 0 { snapshot.SyncDuration += delta.SyncDuration } if delta.VerifyDuration > 0 { snapshot.VerifyDuration += delta.VerifyDuration } if delta.CommitDuration > 0 { snapshot.CommitDuration += delta.CommitDuration } if delta.CommitWaitDuration > 0 { snapshot.CommitWaitDuration += delta.CommitWaitDuration } if delta.SourceReadCount > 0 { snapshot.SourceReadCount += delta.SourceReadCount } if delta.StreamWriteCount > 0 { snapshot.StreamWriteCount += delta.StreamWriteCount } if delta.SinkWriteCount > 0 { snapshot.SinkWriteCount += delta.SinkWriteCount } return nil }) } func (m *Manager) Snapshot(id string) (Snapshot, bool) { m.mu.Lock() defer m.mu.Unlock() transfer, ok := m.transfers[id] if !ok { return Snapshot{}, false } return cloneSnapshot(transfer.snapshot), true } func (m *Manager) Snapshots() []Snapshot { m.mu.Lock() defer m.mu.Unlock() out := make([]Snapshot, 0, len(m.transfers)) for _, transfer := range m.transfers { out = append(out, cloneSnapshot(transfer.snapshot)) } sort.Slice(out, func(i int, j int) bool { return out[i].ID < out[j].ID }) return out } func (m *Manager) Restore(snapshot Snapshot) (Snapshot, error) { if snapshot.ID == "" { return Snapshot{}, ErrTransferIDEmpty } m.mu.Lock() defer m.mu.Unlock() m.transfers[snapshot.ID] = &managedTransfer{snapshot: cloneSnapshot(snapshot)} return cloneSnapshot(snapshot), nil } func (m *Manager) start(desc Descriptor, direction Direction, state State) (Snapshot, error) { if desc.ID == "" { return Snapshot{}, ErrTransferIDEmpty } now := m.currentTime() m.mu.Lock() defer m.mu.Unlock() if _, exists := m.transfers[desc.ID]; exists { return Snapshot{}, ErrTransferExists } snapshot := Snapshot{ ID: desc.ID, Direction: direction, Channel: normalizeChannel(desc.Channel), State: state, Size: desc.Size, Checksum: desc.Checksum, Metadata: cloneMetadata(desc.Metadata), StartedAt: now.UnixNano(), UpdatedAt: now.UnixNano(), } m.transfers[desc.ID] = &managedTransfer{snapshot: snapshot} return cloneSnapshot(snapshot), nil } func (m *Manager) finishWithError(id string, state State, err error) (Snapshot, error) { now := m.currentTime() return m.updateWithTime(id, now, func(snapshot *Snapshot, now time.Time) error { snapshot.State = state snapshot.CompletedAt = now.UnixNano() if err != nil { snapshot.LastError = err.Error() } snapshot.InflightBytes = inflightBytes(*snapshot) return nil }) } func (m *Manager) update(id string, fn func(*Snapshot) error) (Snapshot, error) { return m.updateWithTime(id, m.currentTime(), func(snapshot *Snapshot, _ time.Time) error { return fn(snapshot) }) } func (m *Manager) updateWithTime(id string, now time.Time, fn func(*Snapshot, time.Time) error) (Snapshot, error) { m.mu.Lock() defer m.mu.Unlock() transfer, ok := m.transfers[id] if !ok { return Snapshot{}, ErrTransferNotFound } snapshot := &transfer.snapshot if err := fn(snapshot, now); err != nil { return Snapshot{}, err } snapshot.UpdatedAt = now.UnixNano() return cloneSnapshot(*snapshot), nil } func (m *Manager) currentTime() time.Time { return m.now() } func inflightBytes(snapshot Snapshot) int64 { if snapshot.Direction != DirectionSend { return 0 } if snapshot.SentBytes <= snapshot.AckedBytes { return 0 } return snapshot.SentBytes - snapshot.AckedBytes }