// Source path: notify/internal/transfer/manager.go
// Repository viewer listing at capture time: 367 lines, 9.4 KiB, Go.

package transfer
import (
"errors"
"sort"
"sync"
"time"
)
// Sentinel errors returned by Manager methods.
var (
// ErrTransferIDEmpty is returned when a descriptor or snapshot carries an empty ID.
ErrTransferIDEmpty = errors.New("transfer id is empty")
// ErrTransferExists is returned when starting a transfer whose ID is already tracked.
ErrTransferExists = errors.New("transfer already exists")
// ErrTransferNotFound is returned when an update targets an ID that is not tracked.
ErrTransferNotFound = errors.New("transfer not found")
// ErrTransferBytesInvalid is returned when a byte count argument is negative.
ErrTransferBytesInvalid = errors.New("transfer bytes must be non-negative")
)
// Manager keeps the state of transfers in memory, keyed by transfer ID.
// Methods that read or modify the tracked transfers take mu first.
type Manager struct {
// mu guards transfers and the snapshots stored in it.
mu sync.Mutex
// now supplies the timestamps recorded on snapshots; the constructors
// set it to time.Now unless another clock is provided.
now func() time.Time
// transfers maps a transfer ID to its tracked record.
transfers map[string]*managedTransfer
}
// managedTransfer is the per-transfer record stored in Manager.transfers.
type managedTransfer struct {
// snapshot is the manager-owned state; Manager methods hand out copies
// produced by cloneSnapshot rather than this value itself.
snapshot Snapshot
}
// NewManager returns an empty Manager that timestamps snapshots with the
// wall clock (time.Now).
func NewManager() *Manager {
	manager := NewManagerWithClock(time.Now)
	return manager
}
// NewManagerWithClock returns an empty Manager that obtains timestamps from
// now. A nil now is replaced by time.Now.
func NewManagerWithClock(now func() time.Time) *Manager {
	clock := now
	if clock == nil {
		clock = time.Now
	}

	m := new(Manager)
	m.now = clock
	m.transfers = map[string]*managedTransfer{}
	return m
}
// StartOutgoing registers desc as a new transfer with DirectionSend and
// initial state StateNegotiating, returning a copy of the new snapshot.
// Errors are those of start: ErrTransferIDEmpty or ErrTransferExists.
func (m *Manager) StartOutgoing(desc Descriptor) (Snapshot, error) {
	created, err := m.start(desc, DirectionSend, StateNegotiating)
	return created, err
}
// StartIncoming registers desc as a new transfer with DirectionReceive and
// initial state StatePrepared, returning a copy of the new snapshot.
// Errors are those of start: ErrTransferIDEmpty or ErrTransferExists.
func (m *Manager) StartIncoming(desc Descriptor) (Snapshot, error) {
	created, err := m.start(desc, DirectionReceive, StatePrepared)
	return created, err
}
// Activate sets the transfer's state to StateActive, whatever its current
// state is, and returns a copy of the updated snapshot.
// It returns ErrTransferNotFound when id is not tracked.
func (m *Manager) Activate(id string) (Snapshot, error) {
	markActive := func(s *Snapshot) error {
		s.State = StateActive
		return nil
	}
	return m.update(id, markActive)
}
// Pause sets the transfer's state to StatePaused unless the current state
// reports Terminal(), in which case the state is left untouched. The update
// timestamp is refreshed either way.
// It returns ErrTransferNotFound when id is not tracked.
func (m *Manager) Pause(id string) (Snapshot, error) {
	pause := func(s *Snapshot) error {
		if !s.State.Terminal() {
			s.State = StatePaused
		}
		return nil
	}
	return m.update(id, pause)
}
// Resume sets the transfer to StateActive and aligns its byte counters with
// confirmedBytes, the offset both sides agree has been transferred.
//
// For DirectionSend, AckedBytes becomes the confirmed offset and SentBytes is
// raised to it if it was lower. For DirectionReceive, ReceivedBytes is raised
// to the confirmed offset if it was lower. When Size is known (> 0) the
// confirmed offset is capped at Size before it is applied, so neither
// SentBytes nor AckedBytes can be pushed past Size and InflightBytes cannot
// report bytes beyond the end of the transfer.
//
// The state is set to StateActive without checking the previous state.
//
// It returns ErrTransferBytesInvalid when confirmedBytes is negative and
// ErrTransferNotFound when id is not tracked.
func (m *Manager) Resume(id string, confirmedBytes int64) (Snapshot, error) {
	if confirmedBytes < 0 {
		return Snapshot{}, ErrTransferBytesInvalid
	}
	return m.update(id, func(snapshot *Snapshot) error {
		// Cap once up front: applying the raw value to SentBytes while only
		// AckedBytes is capped would leave SentBytes > Size and a bogus
		// positive InflightBytes.
		confirmed := confirmedBytes
		if snapshot.Size > 0 && confirmed > snapshot.Size {
			confirmed = snapshot.Size
		}

		switch snapshot.Direction {
		case DirectionSend:
			if confirmed > snapshot.SentBytes {
				snapshot.SentBytes = confirmed
			}
			snapshot.AckedBytes = confirmed
		case DirectionReceive:
			if confirmed > snapshot.ReceivedBytes {
				snapshot.ReceivedBytes = confirmed
			}
			// ReceivedBytes may already be above Size (for example in a
			// snapshot installed through Restore), so cap it as well.
			if snapshot.Size > 0 && snapshot.ReceivedBytes > snapshot.Size {
				snapshot.ReceivedBytes = snapshot.Size
			}
		}

		snapshot.State = StateActive
		snapshot.InflightBytes = inflightBytes(*snapshot)
		return nil
	})
}
// RecordSend adds sentBytes to the transfer's SentBytes, capping the total at
// Size when Size is known (> 0), and recomputes InflightBytes. A transfer
// whose state is not terminal is switched to StateActive.
//
// It returns ErrTransferBytesInvalid when sentBytes is negative and
// ErrTransferNotFound when id is not tracked.
func (m *Manager) RecordSend(id string, sentBytes int64) (Snapshot, error) {
	if sentBytes < 0 {
		return Snapshot{}, ErrTransferBytesInvalid
	}

	apply := func(s *Snapshot) error {
		total := s.SentBytes + sentBytes
		if s.Size > 0 && total > s.Size {
			total = s.Size
		}
		s.SentBytes = total
		s.InflightBytes = inflightBytes(*s)

		if s.State.Terminal() {
			return nil
		}
		s.State = StateActive
		return nil
	}
	return m.update(id, apply)
}
// RecordReceive adds recvBytes to the transfer's ReceivedBytes, capping the
// total at Size when Size is known (> 0). A transfer whose state is not
// terminal is switched to StateActive.
//
// It returns ErrTransferBytesInvalid when recvBytes is negative and
// ErrTransferNotFound when id is not tracked.
func (m *Manager) RecordReceive(id string, recvBytes int64) (Snapshot, error) {
	if recvBytes < 0 {
		return Snapshot{}, ErrTransferBytesInvalid
	}

	apply := func(s *Snapshot) error {
		total := s.ReceivedBytes + recvBytes
		if s.Size > 0 && total > s.Size {
			total = s.Size
		}
		s.ReceivedBytes = total

		if s.State.Terminal() {
			return nil
		}
		s.State = StateActive
		return nil
	}
	return m.update(id, apply)
}
// SetAckedBytes overwrites the transfer's AckedBytes with ackedBytes, capped
// at Size when Size is known (> 0). SentBytes is raised to the acknowledged
// value if it was lower, and InflightBytes is recomputed.
//
// It returns ErrTransferBytesInvalid when ackedBytes is negative and
// ErrTransferNotFound when id is not tracked.
func (m *Manager) SetAckedBytes(id string, ackedBytes int64) (Snapshot, error) {
	if ackedBytes < 0 {
		return Snapshot{}, ErrTransferBytesInvalid
	}

	apply := func(s *Snapshot) error {
		acked := ackedBytes
		if s.Size > 0 && acked > s.Size {
			acked = s.Size
		}
		s.AckedBytes = acked
		if s.SentBytes < acked {
			s.SentBytes = acked
		}
		s.InflightBytes = inflightBytes(*s)
		return nil
	}
	return m.update(id, apply)
}
// BeginCommit sets the transfer's state to StateCommitting and returns a copy
// of the updated snapshot. It returns ErrTransferNotFound when id is not
// tracked.
func (m *Manager) BeginCommit(id string) (Snapshot, error) {
	markCommitting := func(s *Snapshot) error {
		s.State = StateCommitting
		return nil
	}
	return m.update(id, markCommitting)
}
// BeginVerify sets the transfer's state to StateVerifying and returns a copy
// of the updated snapshot. It returns ErrTransferNotFound when id is not
// tracked.
func (m *Manager) BeginVerify(id string) (Snapshot, error) {
	markVerifying := func(s *Snapshot) error {
		s.State = StateVerifying
		return nil
	}
	return m.update(id, markVerifying)
}
// Complete sets the transfer's state to StateDone, stamps CompletedAt with
// the manager clock (Unix nanoseconds) and recomputes InflightBytes. The same
// instant is used for CompletedAt and UpdatedAt.
// It returns ErrTransferNotFound when id is not tracked.
func (m *Manager) Complete(id string) (Snapshot, error) {
	finish := func(s *Snapshot, at time.Time) error {
		s.State = StateDone
		s.CompletedAt = at.UnixNano()
		s.InflightBytes = inflightBytes(*s)
		return nil
	}
	return m.updateWithTime(id, m.currentTime(), finish)
}
// Abort finishes the transfer with StateAborted. A non-nil err has its
// message stored in the snapshot's LastError (see finishWithError).
// It returns ErrTransferNotFound when id is not tracked.
func (m *Manager) Abort(id string, err error) (Snapshot, error) {
	finished, finishErr := m.finishWithError(id, StateAborted, err)
	return finished, finishErr
}
// Fail finishes the transfer with StateFailed. A non-nil err has its message
// stored in the snapshot's LastError (see finishWithError).
// It returns ErrTransferNotFound when id is not tracked.
func (m *Manager) Fail(id string, err error) (Snapshot, error) {
	finished, finishErr := m.finishWithError(id, StateFailed, err)
	return finished, finishErr
}
// RecordRetry increments the transfer's RetryCount by one and returns a copy
// of the updated snapshot. It returns ErrTransferNotFound when id is not
// tracked.
func (m *Manager) RecordRetry(id string) (Snapshot, error) {
	bump := func(s *Snapshot) error {
		s.RetryCount += 1
		return nil
	}
	return m.update(id, bump)
}
// RecordTimeout increments the transfer's TimeoutCount by one and returns a
// copy of the updated snapshot. It returns ErrTransferNotFound when id is not
// tracked.
func (m *Manager) RecordTimeout(id string) (Snapshot, error) {
	bump := func(s *Snapshot) error {
		s.TimeoutCount += 1
		return nil
	}
	return m.update(id, bump)
}
// SetStage overwrites the transfer's Stage with stage (an empty string is
// stored as-is) and returns a copy of the updated snapshot.
// It returns ErrTransferNotFound when id is not tracked.
func (m *Manager) SetStage(id string, stage string) (Snapshot, error) {
	assign := func(s *Snapshot) error {
		s.Stage = stage
		return nil
	}
	return m.update(id, assign)
}
// SetFailureStage stores stage in the transfer's LastFailureStage. A
// non-empty stage is also copied into Stage; an empty stage clears
// LastFailureStage but leaves Stage unchanged.
// It returns ErrTransferNotFound when id is not tracked.
func (m *Manager) SetFailureStage(id string, stage string) (Snapshot, error) {
	record := func(s *Snapshot) error {
		s.LastFailureStage = stage
		if stage == "" {
			return nil
		}
		s.Stage = stage
		return nil
	}
	return m.update(id, record)
}
// MergeMetadata folds metadata into the transfer's Metadata. Each entry with
// a non-empty value is set (overwriting any existing key); each entry with an
// empty value deletes that key. An empty metadata argument changes nothing in
// Metadata. The snapshot's Metadata map is allocated on demand when it is nil.
// It returns ErrTransferNotFound when id is not tracked.
func (m *Manager) MergeMetadata(id string, metadata Metadata) (Snapshot, error) {
	merge := func(s *Snapshot) error {
		if len(metadata) == 0 {
			return nil
		}
		if s.Metadata == nil {
			s.Metadata = make(Metadata, len(metadata))
		}
		for k, v := range metadata {
			if v != "" {
				s.Metadata[k] = v
				continue
			}
			// An empty value is the signal to drop the key.
			delete(s.Metadata, k)
		}
		return nil
	}
	return m.update(id, merge)
}
// RecordTelemetry accumulates the fields of delta into the matching counters
// of the transfer's snapshot. Only strictly positive delta fields are added;
// zero and negative fields are ignored.
// It returns ErrTransferNotFound when id is not tracked.
func (m *Manager) RecordTelemetry(id string, delta TelemetryDelta) (Snapshot, error) {
	accumulate := func(s *Snapshot) error {
		// Durations.
		if d := delta.SourceReadDuration; d > 0 {
			s.SourceReadDuration += d
		}
		if d := delta.StreamWriteDuration; d > 0 {
			s.StreamWriteDuration += d
		}
		if d := delta.SinkWriteDuration; d > 0 {
			s.SinkWriteDuration += d
		}
		if d := delta.SyncDuration; d > 0 {
			s.SyncDuration += d
		}
		if d := delta.VerifyDuration; d > 0 {
			s.VerifyDuration += d
		}
		if d := delta.CommitDuration; d > 0 {
			s.CommitDuration += d
		}
		if d := delta.CommitWaitDuration; d > 0 {
			s.CommitWaitDuration += d
		}

		// Operation counts.
		if c := delta.SourceReadCount; c > 0 {
			s.SourceReadCount += c
		}
		if c := delta.StreamWriteCount; c > 0 {
			s.StreamWriteCount += c
		}
		if c := delta.SinkWriteCount; c > 0 {
			s.SinkWriteCount += c
		}
		return nil
	}
	return m.update(id, accumulate)
}
// Snapshot returns a copy of the tracked snapshot for id and true, or a zero
// Snapshot and false when id is not tracked.
func (m *Manager) Snapshot(id string) (Snapshot, bool) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if entry, found := m.transfers[id]; found {
		return cloneSnapshot(entry.snapshot), true
	}
	return Snapshot{}, false
}
// Snapshots returns copies of all tracked snapshots, sorted by ascending ID.
// The result is a non-nil slice even when nothing is tracked.
func (m *Manager) Snapshots() []Snapshot {
	m.mu.Lock()
	defer m.mu.Unlock()

	result := make([]Snapshot, len(m.transfers))
	next := 0
	for _, entry := range m.transfers {
		result[next] = cloneSnapshot(entry.snapshot)
		next++
	}

	// Map iteration order is random; sort so callers see a stable order.
	byID := func(a int, b int) bool {
		return result[a].ID < result[b].ID
	}
	sort.Slice(result, byID)
	return result
}
// Restore installs snapshot under snapshot.ID, replacing any transfer already
// tracked with that ID, and returns a copy of the installed snapshot. The
// manager stores its own copy, so later changes to the caller's value do not
// affect the tracked state. The snapshot's timestamps are stored as given.
//
// It returns ErrTransferIDEmpty when snapshot.ID is empty.
func (m *Manager) Restore(snapshot Snapshot) (Snapshot, error) {
	if snapshot.ID == "" {
		return Snapshot{}, ErrTransferIDEmpty
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	// Assigning into a nil map panics; create it on demand so a Manager that
	// was not built by one of the constructors can still restore state.
	if m.transfers == nil {
		m.transfers = make(map[string]*managedTransfer)
	}
	m.transfers[snapshot.ID] = &managedTransfer{snapshot: cloneSnapshot(snapshot)}
	return cloneSnapshot(snapshot), nil
}
// start registers a new transfer described by desc with the given direction
// and initial state, and returns a copy of the new snapshot.
//
// The snapshot takes ID, Size and Checksum from desc, a normalized Channel
// (normalizeChannel), a copy of desc.Metadata (cloneMetadata), and identical
// StartedAt/UpdatedAt timestamps in Unix nanoseconds from the manager clock.
//
// It returns ErrTransferIDEmpty when desc.ID is empty and ErrTransferExists
// when a transfer with that ID is already tracked.
func (m *Manager) start(desc Descriptor, direction Direction, state State) (Snapshot, error) {
	if desc.ID == "" {
		return Snapshot{}, ErrTransferIDEmpty
	}
	// Read the clock before taking the lock so the clock function never runs
	// while mu is held.
	now := m.currentTime()

	m.mu.Lock()
	defer m.mu.Unlock()

	if _, exists := m.transfers[desc.ID]; exists {
		return Snapshot{}, ErrTransferExists
	}
	// Assigning into a nil map panics; create it on demand for a Manager
	// whose transfers map was never initialized.
	if m.transfers == nil {
		m.transfers = make(map[string]*managedTransfer)
	}

	startedAt := now.UnixNano()
	snapshot := Snapshot{
		ID:        desc.ID,
		Direction: direction,
		Channel:   normalizeChannel(desc.Channel),
		State:     state,
		Size:      desc.Size,
		Checksum:  desc.Checksum,
		Metadata:  cloneMetadata(desc.Metadata),
		StartedAt: startedAt,
		UpdatedAt: startedAt,
	}
	m.transfers[desc.ID] = &managedTransfer{snapshot: snapshot}
	return cloneSnapshot(snapshot), nil
}
// finishWithError moves the transfer to state, stamps CompletedAt with the
// manager clock (Unix nanoseconds) and recomputes InflightBytes. When err is
// non-nil its message is stored in LastError; a nil err leaves LastError as
// it was.
// It returns ErrTransferNotFound when id is not tracked.
func (m *Manager) finishWithError(id string, state State, err error) (Snapshot, error) {
	finish := func(s *Snapshot, at time.Time) error {
		s.State = state
		s.CompletedAt = at.UnixNano()
		if err != nil {
			s.LastError = err.Error()
		}
		s.InflightBytes = inflightBytes(*s)
		return nil
	}
	return m.updateWithTime(id, m.currentTime(), finish)
}
// update applies fn to the tracked snapshot for id, using the current clock
// reading as the update timestamp. It is updateWithTime for mutations that do
// not need the timestamp themselves.
func (m *Manager) update(id string, fn func(*Snapshot) error) (Snapshot, error) {
	timestamp := m.currentTime()
	withoutTime := func(s *Snapshot, _ time.Time) error {
		return fn(s)
	}
	return m.updateWithTime(id, timestamp, withoutTime)
}
// updateWithTime runs fn on the tracked snapshot for id while holding mu,
// then sets UpdatedAt to now (Unix nanoseconds) and returns a copy of the
// result.
//
// fn receives a pointer to the stored snapshot, so its changes take effect
// immediately. If fn returns an error, that error is returned with a zero
// Snapshot and UpdatedAt is not touched, but anything fn already modified
// stays modified.
//
// It returns ErrTransferNotFound when id is not tracked.
func (m *Manager) updateWithTime(id string, now time.Time, fn func(*Snapshot, time.Time) error) (Snapshot, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	entry, found := m.transfers[id]
	if !found {
		return Snapshot{}, ErrTransferNotFound
	}

	err := fn(&entry.snapshot, now)
	if err != nil {
		return Snapshot{}, err
	}

	entry.snapshot.UpdatedAt = now.UnixNano()
	return cloneSnapshot(entry.snapshot), nil
}
// currentTime returns the current instant from the manager's clock.
//
// When no clock is configured (now is nil, as in a Manager that was not built
// by NewManager or NewManagerWithClock) it falls back to time.Now instead of
// panicking on a nil function call. This matches NewManagerWithClock, which
// also treats a nil clock as time.Now.
func (m *Manager) currentTime() time.Time {
	if m.now == nil {
		return time.Now()
	}
	return m.now()
}
// inflightBytes reports how many bytes of a send-direction transfer have been
// sent but not yet acknowledged (SentBytes - AckedBytes). It returns 0 for any
// other direction and whenever SentBytes does not exceed AckedBytes.
func inflightBytes(snapshot Snapshot) int64 {
	if snapshot.Direction == DirectionSend && snapshot.SentBytes > snapshot.AckedBytes {
		return snapshot.SentBytes - snapshot.AckedBytes
	}
	return 0
}