package notify

import (
	"context"
	"fmt"
	"io"
	"strings"
	"sync"
	"time"
)

// transferReceiveHandler is the user-supplied callback that decides how an
// incoming transfer is received (sink, checkpoint syncing, checksum checks).
type transferReceiveHandler func(TransferAcceptInfo) (TransferReceiveOptions, error)

// transferBuiltinReceiveHandler is consulted before transferReceiveHandler.
// Its bool result reports whether it handled the transfer; when it returns
// false with a nil error, the user handler is tried next.
type transferBuiltinReceiveHandler func(TransferAcceptInfo) (TransferReceiveOptions, bool, error)

// transferState holds the receive handlers and the in-flight receive sessions,
// keyed by transferSessionKey(scope, transferID). mu guards every field.
type transferState struct {
	mu             sync.RWMutex
	controlEnabled bool
	handler        transferReceiveHandler
	builtinHandler transferBuiltinReceiveHandler
	receives       map[string]*transferReceiveSession
}

// transferReceiveSession tracks one inbound transfer: the sink bytes are
// written to, how far the sink has been written (nextOffset), which connection
// currently feeds it, and the state of the current data stream, if any.
// mu guards every field.
type transferReceiveSession struct {
	descriptor       TransferDescriptor
	sink             TransferWriterAt
	syncOnCheckpoint bool
	verifyChecksum   bool
	publicScope      string
	runtimeScope     string
	logical          *LogicalConn
	transport        *TransportConn
	transportGen     uint64
	nextOffset       int64
	closed           bool
	streamID         string
	streamActive     bool
	streamDone       chan struct{}
	streamErr        error
	mu               sync.Mutex
}

// newTransferState returns an empty transferState with no handlers installed.
func newTransferState() *transferState {
	return &transferState{
		receives: make(map[string]*transferReceiveSession),
	}
}

// transferSessionKey builds the receives-map key for a transfer: the
// normalized scope and the transfer ID joined by '|'.
func transferSessionKey(scope string, transferID string) string {
	return normalizeFileScope(scope) + "|" + transferID
}

// setHandler installs the user receive handler and enables transfer control.
// Control is enabled even when fn is nil.
func (s *transferState) setHandler(fn transferReceiveHandler) {
	if s == nil {
		return
	}
	s.mu.Lock()
	s.controlEnabled = true
	s.handler = fn
	s.mu.Unlock()
}

// setBuiltinHandler installs (or clears, when fn is nil) the builtin receive
// handler. Only a non-nil handler enables transfer control; clearing it never
// disables control.
func (s *transferState) setBuiltinHandler(fn transferBuiltinReceiveHandler) {
	if s == nil {
		return
	}
	s.mu.Lock()
	if fn != nil {
		s.controlEnabled = true
	}
	s.builtinHandler = fn
	s.mu.Unlock()
}

// controlEnabledSnapshot reports whether any handler registration has enabled
// transfer control. A nil state reports false.
func (s *transferState) controlEnabledSnapshot() bool {
	if s == nil {
		return false
	}
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.controlEnabled
}

// handlerSnapshot returns the current user receive handler, or nil.
func (s *transferState) handlerSnapshot() transferReceiveHandler {
	if s == nil {
		return nil
	}
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.handler
}

// builtinHandlerSnapshot returns the current builtin receive handler, or nil.
func (s *transferState) builtinHandlerSnapshot() transferBuiltinReceiveHandler {
	if s == nil {
		return nil
	}
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.builtinHandler
}

// acceptOptions decides how to receive the transfer described by info.
// The builtin handler runs first; if it reports the transfer as handled, or
// fails, its result is returned as-is. Otherwise the user handler decides.
// errTransferHandlerNotConfigured is returned when neither applies.
// Handlers are invoked without holding s.mu.
func (s *transferState) acceptOptions(info TransferAcceptInfo) (TransferReceiveOptions, error) {
	if builtin := s.builtinHandlerSnapshot(); builtin != nil {
		opt, handled, err := builtin(info)
		if handled || err != nil {
			return opt, err
		}
	}
	handler := s.handlerSnapshot()
	if handler == nil {
		return TransferReceiveOptions{}, errTransferHandlerNotConfigured
	}
	return handler(info)
}

// load returns the session registered for (scope, transferID), if any.
func (s *transferState) load(scope string, transferID string) (*transferReceiveSession, bool) {
	if s == nil || transferID == "" {
		return nil, false
	}
	s.mu.RLock()
	session, ok := s.receives[transferSessionKey(scope, transferID)]
	s.mu.RUnlock()
	return session, ok
}

// store registers session under (scope, transferID).
// It returns errTransferSessionNotFound for a nil state, a nil session or an
// empty transfer ID, and errTransferSessionExists if the key is already taken.
func (s *transferState) store(scope string, transferID string, session *transferReceiveSession) error {
	if s == nil || session == nil || transferID == "" {
		return errTransferSessionNotFound
	}
	key := transferSessionKey(scope, transferID)
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.receives == nil {
		// Lazily initialise so a zero-value transferState does not panic on
		// the map write below.
		s.receives = make(map[string]*transferReceiveSession)
	}
	if _, exists := s.receives[key]; exists {
		return errTransferSessionExists
	}
	s.receives[key] = session
	return nil
}

// remove unregisters and returns the session for (scope, transferID), or nil
// if none was registered. The session itself is not closed.
func (s *transferState) remove(scope string, transferID string) *transferReceiveSession {
	if s == nil || transferID == "" {
		return nil
	}
	key := transferSessionKey(scope, transferID)
	s.mu.Lock()
	session := s.receives[key]
	delete(s.receives, key)
	s.mu.Unlock()
	return session
}

// closeAll unregisters and closes every session with err.
func (s *transferState) closeAll(err error) {
	s.closeMatching(func(string) bool { return true }, err)
}

// closeScope unregisters and closes, with err, every session registered under
// scope.
func (s *transferState) closeScope(scope string, err error) {
	// Every key in scope starts with the normalized scope followed by '|'.
	prefix := transferSessionKey(scope, "")
	s.closeMatching(func(key string) bool {
		return strings.HasPrefix(key, prefix)
	}, err)
}

// closeMatching unregisters every session whose key satisfies match, then
// closes those sessions with err. Sessions are closed after s.mu is released,
// so sink close calls never run under the state lock.
func (s *transferState) closeMatching(match func(string) bool, err error) {
	if s == nil || match == nil {
		return
	}
	s.mu.Lock()
	sessions := make([]*transferReceiveSession, 0, len(s.receives))
	for key, session := range s.receives {
		if !match(key) {
			continue
		}
		sessions = append(sessions, session)
		delete(s.receives, key)
	}
	s.mu.Unlock()
	for _, session := range sessions {
		session.close(err)
	}
}

// resolveReceiveTransportGeneration returns generation unless it is zero. In
// that case it falls back to the transport's generation, then to the logical
// connection's generation snapshot.
func resolveReceiveTransportGeneration(generation uint64, logical *LogicalConn, transport *TransportConn) uint64 {
	if generation == 0 && transport != nil {
		generation = transport.TransportGeneration()
	}
	if generation == 0 && logical != nil {
		generation = logical.transportGenerationSnapshot()
	}
	return generation
}

// newTransferReceiveSession builds an unregistered receive session from opt.
// A zero transportGeneration is resolved from transport/logical. The starting
// offset comes from the sink when it reports one (see
// transferReceiveInitialOffset).
func newTransferReceiveSession(scope string, runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64, opt TransferReceiveOptions) *transferReceiveSession {
	return &transferReceiveSession{
		descriptor:       normalizeTransferDescriptor(opt.Descriptor),
		sink:             opt.Sink,
		syncOnCheckpoint: opt.SyncOnCheckpoint,
		verifyChecksum:   opt.VerifyChecksum,
		publicScope:      normalizeFileScope(scope),
		runtimeScope:     normalizeFileScope(runtimeScope),
		logical:          logical,
		transport:        transport,
		transportGen:     resolveReceiveTransportGeneration(transportGeneration, logical, transport),
		nextOffset:       transferReceiveInitialOffset(opt.Sink),
	}
}

// transferReceiveInitialOffset returns the offset a sink reports via
// transferReceiveOffsetProvider. It returns 0 for a nil sink, for a sink
// without that method, or when the reported offset is negative.
func transferReceiveInitialOffset(sink TransferWriterAt) int64 {
	if sink == nil {
		return 0
	}
	provider, ok := sink.(transferReceiveOffsetProvider)
	if !ok {
		return 0
	}
	offset := provider.NextOffset()
	if offset < 0 {
		return 0
	}
	return offset
}

// descriptorSnapshot returns a copy of the session's transfer descriptor.
func (s *transferReceiveSession) descriptorSnapshot() TransferDescriptor {
	if s == nil {
		return TransferDescriptor{}
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	return cloneTransferDescriptor(s.descriptor)
}

// nextOffsetSnapshot returns the next byte offset the session expects.
func (s *transferReceiveSession) nextOffsetSnapshot() int64 {
	if s == nil {
		return 0
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.nextOffset
}

// setNextOffset overwrites the expected next offset; negative values become 0.
func (s *transferReceiveSession) setNextOffset(nextOffset int64) {
	if s == nil {
		return
	}
	if nextOffset < 0 {
		nextOffset = 0
	}
	s.mu.Lock()
	s.nextOffset = nextOffset
	s.mu.Unlock()
}

// bindLocked records the connection currently feeding the session.
// The caller must hold s.mu.
func (s *transferReceiveSession) bindLocked(runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64) {
	s.runtimeScope = normalizeFileScope(runtimeScope)
	s.logical = logical
	s.transport = transport
	s.transportGen = transportGeneration
}

// updateBinding rebinds the session to a (possibly new) connection. A zero
// transportGeneration is resolved from transport/logical.
func (s *transferReceiveSession) updateBinding(runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64) {
	if s == nil {
		return
	}
	// Resolve outside the lock: it calls into the connection objects.
	transportGeneration = resolveReceiveTransportGeneration(transportGeneration, logical, transport)
	s.mu.Lock()
	s.bindLocked(runtimeScope, logical, transport, transportGeneration)
	s.mu.Unlock()
}

// beginStream starts a new data stream for the session and binds it to the
// given connection.
//
// It fails with:
//   - errStreamIDEmpty when streamID is empty;
//   - io.ErrClosedPipe when the session is closed;
//   - errTransferStreamAlreadyActive while a previous stream has not finished.
func (s *transferReceiveSession) beginStream(streamID string, runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64) error {
	if s == nil {
		return errTransferSessionNotFound
	}
	if streamID == "" {
		return errStreamIDEmpty
	}
	// Resolve outside the lock: it calls into the connection objects.
	transportGeneration = resolveReceiveTransportGeneration(transportGeneration, logical, transport)
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.closed {
		return io.ErrClosedPipe
	}
	if s.streamDone != nil {
		select {
		case <-s.streamDone:
			// Previous stream finished; a new one may start.
		default:
			return errTransferStreamAlreadyActive
		}
	}
	s.bindLocked(runtimeScope, logical, transport, transportGeneration)
	s.streamID = streamID
	s.streamDone = make(chan struct{})
	s.streamErr = nil
	s.streamActive = true
	return nil
}

// finishStream marks the stream identified by streamID as finished with err
// and releases waitStream callers. Calls for any stream other than the current
// one are ignored.
//
// Once the stream has already completed (through an earlier finishStream or
// through close), err only fills in a missing error. It never replaces an
// error already recorded, so an abort reason set by close(err) survives a
// later, possibly nil, report from the stream goroutine.
func (s *transferReceiveSession) finishStream(streamID string, err error) {
	if s == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.streamID != streamID || s.streamDone == nil {
		return
	}
	if s.streamActive {
		s.streamErr = err
		close(s.streamDone)
		s.streamActive = false
		return
	}
	if s.streamErr == nil {
		s.streamErr = err
	}
}

// waitStream blocks until the current stream finishes or ctx is done, and
// returns the recorded stream error (or ctx.Err()). If no stream was ever
// begun, it returns the recorded error immediately (nil unless close recorded
// one). A nil ctx is treated as context.Background().
func (s *transferReceiveSession) waitStream(ctx context.Context) error {
	if s == nil {
		return errTransferSessionNotFound
	}
	if ctx == nil {
		ctx = context.Background()
	}
	s.mu.Lock()
	done := s.streamDone
	err := s.streamErr
	s.mu.Unlock()
	if done == nil {
		return err
	}
	select {
	case <-done:
		// Re-read under the lock: the error is recorded before done is closed.
		s.mu.Lock()
		err = s.streamErr
		s.mu.Unlock()
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}

// runtimeScopeSnapshot returns the normalized runtime scope the session is
// currently bound to, or defaultFileScope for a nil session.
func (s *transferReceiveSession) runtimeScopeSnapshot() string {
	if s == nil {
		return defaultFileScope
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	return normalizeFileScope(s.runtimeScope)
}

// writeSegment writes payload, which starts at segOffset, to the sink.
//
// Offset handling:
//   - A segment starting before nextOffset has its already-written prefix
//     trimmed; a segment lying entirely before nextOffset is a no-op.
//   - A segment starting after nextOffset is rejected with
//     errTransferSegmentOffset.
//
// When syncOnCheckpoint is set and the sink implements TransferSyncer, the
// sink is synced after each write. Timings and progress are reported to
// runtime when it is non-nil. The sink I/O runs without holding s.mu.
func (s *transferReceiveSession) writeSegment(runtime *transferRuntime, transferID string, segOffset int64, payload []byte) error {
	if s == nil {
		return errTransferSessionNotFound
	}
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		return io.ErrClosedPipe
	}
	offset := segOffset
	nextOffset := s.nextOffset
	if offset < nextOffset {
		// Overlap with data already written (e.g. a retransmission): skip it.
		trim := nextOffset - offset
		if trim >= int64(len(payload)) {
			s.mu.Unlock()
			return nil
		}
		payload = payload[trim:]
		offset = nextOffset
	}
	if offset != nextOffset {
		s.mu.Unlock()
		return fmt.Errorf("%w: got %d want %d", errTransferSegmentOffset, offset, nextOffset)
	}
	sink := s.sink
	syncOnCheckpoint := s.syncOnCheckpoint
	runtimeScope := s.runtimeScope
	s.mu.Unlock()

	writeStartedAt := time.Now()
	n, err := sink.WriteAt(payload, offset)
	writeDuration := time.Since(writeStartedAt)
	if runtime != nil {
		runtime.recordSinkWrite(fileTransferDirectionReceive, runtimeScope, transferID, writeDuration)
	}
	if err != nil {
		return err
	}
	if n != len(payload) {
		return io.ErrShortWrite
	}
	if syncOnCheckpoint {
		if syncer, ok := sink.(TransferSyncer); ok {
			syncStartedAt := time.Now()
			err := syncer.Sync(context.Background())
			if runtime != nil {
				runtime.recordSyncDuration(fileTransferDirectionReceive, runtimeScope, transferID, time.Since(syncStartedAt))
			}
			if err != nil {
				return err
			}
		}
	}

	s.mu.Lock()
	// Only advance: never move nextOffset backwards.
	if s.nextOffset < offset+int64(n) {
		s.nextOffset = offset + int64(n)
	}
	s.mu.Unlock()
	if runtime != nil {
		runtime.activate(fileTransferDirectionReceive, runtimeScope, transferID)
		runtime.recordStage(fileTransferDirectionReceive, runtimeScope, transferID, "data")
		runtime.recordReceive(fileTransferDirectionReceive, runtimeScope, transferID, int64(n))
	}
	return nil
}

// commit finalises the transfer. The steps run in order:
//  1. Wait for the current stream to finish.
//  2. Check the received size against the descriptor (when desc.Size >= 0).
//  3. Sync the sink if it implements TransferSyncer.
//  4. Verify the checksum when enabled and known; this requires the sink to
//     implement io.ReaderAt, otherwise errTransferChecksumUnsupported.
//  5. Commit the sink if it implements TransferCommitter.
//  6. Close the session.
//
// Any failing step returns its error and leaves the session open.
// A nil ctx is treated as context.Background().
func (s *transferReceiveSession) commit(ctx context.Context, runtime *transferRuntime, transferID string) error {
	if s == nil {
		return errTransferSessionNotFound
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if err := s.waitStream(ctx); err != nil {
		return err
	}
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		return io.ErrClosedPipe
	}
	desc := cloneTransferDescriptor(s.descriptor)
	sink := s.sink
	received := s.nextOffset
	verifyChecksum := s.verifyChecksum
	runtimeScope := s.runtimeScope
	s.mu.Unlock()

	if desc.Size >= 0 && received != desc.Size {
		return fmt.Errorf("%w: got %d want %d", errTransferSizeMismatch, received, desc.Size)
	}
	if syncer, ok := sink.(TransferSyncer); ok {
		syncStartedAt := time.Now()
		err := syncer.Sync(ctx)
		if runtime != nil {
			runtime.recordSyncDuration(fileTransferDirectionReceive, runtimeScope, transferID, time.Since(syncStartedAt))
		}
		if err != nil {
			return err
		}
	}
	if verifyChecksum && desc.Checksum != "" {
		reader, ok := sink.(io.ReaderAt)
		if !ok {
			return errTransferChecksumUnsupported
		}
		verifyStartedAt := time.Now()
		sum, err := computeTransferChecksum(reader, received)
		if runtime != nil {
			runtime.recordVerifyDuration(fileTransferDirectionReceive, runtimeScope, transferID, time.Since(verifyStartedAt))
		}
		if err != nil {
			return err
		}
		if sum != "" && !equalChecksum(sum, desc.Checksum) {
			return errTransferChecksumMismatch
		}
	}
	if committer, ok := sink.(TransferCommitter); ok {
		commitStartedAt := time.Now()
		err := committer.Commit(ctx)
		if runtime != nil {
			runtime.recordCommitDuration(fileTransferDirectionReceive, runtimeScope, transferID, time.Since(commitStartedAt))
		}
		if err != nil {
			return err
		}
	}
	s.close(nil)
	return nil
}

// close marks the session closed and releases any waitStream callers. It then
// closes the sink: via CloseWithError(err) when the sink implements
// transferCloseWithError, otherwise via Close when it implements
// TransferCloser. A non-nil err is recorded as the stream error. Repeated
// calls are no-ops.
func (s *transferReceiveSession) close(err error) {
	if s == nil {
		return
	}
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		return
	}
	s.closed = true
	if err != nil {
		s.streamErr = err
	}
	if s.streamActive && s.streamDone != nil {
		close(s.streamDone)
		s.streamActive = false
	}
	sink := s.sink
	s.mu.Unlock()

	// Sink close errors are discarded: close has no error return to report them.
	if closer, ok := sink.(transferCloseWithError); ok {
		_ = closer.CloseWithError(err)
		return
	}
	if closer, ok := sink.(TransferCloser); ok {
		_ = closer.Close()
	}
}

// restoreReceiveSession tries to resume a receive of desc from the runtime's
// resumable snapshot.
//
// Return values:
//   - handled == false: runtime is nil, desc.ID is empty, no resumable
//     snapshot exists, or the snapshot's descriptor is incompatible with
//     desc. The caller should proceed without a restored session.
//   - handled == true: the session is returned, or the error from accepting
//     or normalising the receive options.
//
// The restored session resumes from the larger of the sink-reported offset and
// the snapshot's received byte count. If another session for the same transfer
// was registered concurrently, that existing session is returned instead.
func (s *transferState) restoreReceiveSession(runtime *transferRuntime, publicScope string, runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64, desc TransferDescriptor) (*transferReceiveSession, bool, error) {
	if runtime == nil || desc.ID == "" {
		return nil, false, nil
	}
	snapshot, ok := runtime.resumableSnapshot(fileTransferDirectionReceive, publicScope, desc.ID)
	if !ok {
		return nil, false, nil
	}
	if !transferDescriptorsCompatible(desc, transferDescriptorFromSnapshot(snapshot)) {
		return nil, false, nil
	}
	info := TransferAcceptInfo{
		Descriptor:          cloneTransferDescriptor(desc),
		LogicalConn:         logical,
		TransportConn:       transport,
		TransportGeneration: transportGeneration,
	}
	opt, err := s.acceptOptions(info)
	if err != nil {
		return nil, true, err
	}
	opt, err = normalizeTransferReceiveOptions(desc, opt)
	if err != nil {
		return nil, true, err
	}
	session := newTransferReceiveSession(publicScope, runtimeScope, logical, transport, transportGeneration, opt)
	// newTransferReceiveSession already seeded nextOffset from the sink. Reuse
	// that value rather than querying the sink a second time, and resume from
	// the snapshot instead if it is at least as far along.
	if session.nextOffsetSnapshot() <= snapshot.ReceivedBytes {
		session.setNextOffset(snapshot.ReceivedBytes)
	}
	if err := s.store(publicScope, desc.ID, session); err != nil {
		if existing, ok := s.load(publicScope, desc.ID); ok {
			// NOTE(review): the freshly accepted opt.Sink is neither stored nor
			// closed on this path. If accept handlers allocate a new sink per
			// call this leaks it; confirm sink ownership before closing here.
			return existing, true, nil
		}
		return nil, true, err
	}
	return session, true, nil
}

// transferDescriptorFromSnapshot rebuilds a normalized TransferDescriptor from
// a runtime snapshot, copying its metadata.
func transferDescriptorFromSnapshot(snapshot TransferSnapshot) TransferDescriptor {
	return normalizeTransferDescriptor(TransferDescriptor{
		ID:       snapshot.ID,
		Channel:  snapshot.Channel,
		Size:     snapshot.Size,
		Checksum: snapshot.Checksum,
		Metadata: cloneTransferMetadata(snapshot.Metadata),
	})
}