notify/transfer_state.go

562 lines
14 KiB
Go
Raw Permalink Normal View History

package notify
import (
"context"
"fmt"
"io"
"sync"
"time"
)
type (
	// transferReceiveHandler produces the receive options for an accepted
	// transfer, or an error to refuse it.
	transferReceiveHandler func(TransferAcceptInfo) (TransferReceiveOptions, error)

	// transferBuiltinReceiveHandler is consulted by acceptOptions before the
	// regular transferReceiveHandler. Its bool result reports whether it
	// handled the transfer; when it is false and the error is nil, the
	// regular handler is tried next.
	transferBuiltinReceiveHandler func(TransferAcceptInfo) (TransferReceiveOptions, bool, error)

	// transferState holds the receive handlers and the registry of live
	// receive sessions, keyed by transferSessionKey(scope, transferID).
	transferState struct {
		// mu guards every field below.
		mu             sync.RWMutex
		controlEnabled bool
		handler        transferReceiveHandler
		builtinHandler transferBuiltinReceiveHandler
		receives       map[string]*transferReceiveSession
	}
)
// transferReceiveSession is the receive side of a single transfer: the sink
// being written, the connection the session is currently bound to, write
// progress, and the state of the data stream currently feeding it. The
// methods in this file read and write every field below while holding mu.
type transferReceiveSession struct {
	descriptor       TransferDescriptor // normalized descriptor of the transfer
	sink             TransferWriterAt   // destination for received bytes
	syncOnCheckpoint bool               // sync the sink after each segment write
	verifyChecksum   bool               // verify descriptor checksum during commit

	publicScope  string         // normalized scope given at creation
	runtimeScope string         // normalized scope used for runtime metrics
	logical      *LogicalConn   // current logical connection binding
	transport    *TransportConn // current transport connection binding
	transportGen uint64         // transport generation of the current binding

	nextOffset int64 // next byte offset the sink expects
	closed     bool  // set once by close; later writes/commits fail

	streamID     string        // ID of the current (or most recent) stream
	streamActive bool          // true while streamDone is still open
	streamDone   chan struct{} // closed when the current stream finishes
	streamErr    error         // outcome recorded for the stream

	mu sync.Mutex // guards all fields above
}
// newTransferState returns a transferState with an empty session registry.
// No handlers are installed, and transfer control stays disabled until one is.
func newTransferState() *transferState {
	state := new(transferState)
	state.receives = make(map[string]*transferReceiveSession)
	return state
}
// transferSessionKey builds the registry key for a receive session: the
// normalized scope and the transfer ID joined by '|'. closeScope matches
// sessions by this "scope|" prefix.
func transferSessionKey(scope string, transferID string) string {
	const separator = "|"
	return normalizeFileScope(scope) + separator + transferID
}
// setHandler installs fn as the regular receive handler and enables transfer
// control. Control is enabled even when fn is nil. It is a no-op on a nil
// receiver.
func (s *transferState) setHandler(fn transferReceiveHandler) {
	if s == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.handler, s.controlEnabled = fn, true
}
// setBuiltinHandler installs (or, with nil, clears) the builtin receive
// handler. It is a no-op on a nil receiver.
func (s *transferState) setBuiltinHandler(fn transferBuiltinReceiveHandler) {
	if s == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.builtinHandler = fn
	// Installing a handler turns control on; clearing it leaves the current
	// controlEnabled value untouched.
	if fn != nil {
		s.controlEnabled = true
	}
}
// controlEnabledSnapshot reports whether transfer control has been enabled.
// A nil receiver reports false.
func (s *transferState) controlEnabledSnapshot() bool {
	var enabled bool
	if s != nil {
		s.mu.RLock()
		defer s.mu.RUnlock()
		enabled = s.controlEnabled
	}
	return enabled
}
// handlerSnapshot returns the currently installed regular receive handler,
// or nil when none is set or the receiver is nil.
func (s *transferState) handlerSnapshot() transferReceiveHandler {
	var fn transferReceiveHandler
	if s != nil {
		s.mu.RLock()
		defer s.mu.RUnlock()
		fn = s.handler
	}
	return fn
}
// builtinHandlerSnapshot returns the currently installed builtin receive
// handler, or nil when none is set or the receiver is nil.
func (s *transferState) builtinHandlerSnapshot() transferBuiltinReceiveHandler {
	var fn transferBuiltinReceiveHandler
	if s != nil {
		s.mu.RLock()
		defer s.mu.RUnlock()
		fn = s.builtinHandler
	}
	return fn
}
// acceptOptions resolves the receive options for an incoming transfer.
//
// The builtin handler, if installed, is asked first. Its answer is final
// when it reports the transfer as handled or returns an error. Otherwise
// the regular handler decides. If no regular handler is installed, the
// result is errTransferHandlerNotConfigured.
func (s *transferState) acceptOptions(info TransferAcceptInfo) (TransferReceiveOptions, error) {
	builtin := s.builtinHandlerSnapshot()
	if builtin != nil {
		opt, handled, err := builtin(info)
		if err != nil || handled {
			return opt, err
		}
	}
	if handler := s.handlerSnapshot(); handler != nil {
		return handler(info)
	}
	return TransferReceiveOptions{}, errTransferHandlerNotConfigured
}
// load looks up the receive session registered under (scope, transferID).
// A nil receiver or an empty transferID always misses.
func (s *transferState) load(scope string, transferID string) (*transferReceiveSession, bool) {
	if s == nil || transferID == "" {
		return nil, false
	}
	key := transferSessionKey(scope, transferID)
	s.mu.RLock()
	defer s.mu.RUnlock()
	found, ok := s.receives[key]
	return found, ok
}
// store registers session under (scope, transferID).
//
// It never replaces an existing entry; in that case it returns
// errTransferSessionExists. A nil receiver, nil session or empty transferID
// is rejected with errTransferSessionNotFound.
func (s *transferState) store(scope string, transferID string, session *transferReceiveSession) error {
	if s == nil || session == nil || transferID == "" {
		return errTransferSessionNotFound
	}
	key := transferSessionKey(scope, transferID)
	s.mu.Lock()
	defer s.mu.Unlock()
	_, taken := s.receives[key]
	if !taken {
		s.receives[key] = session
		return nil
	}
	return errTransferSessionExists
}
// remove unregisters the session stored under (scope, transferID) and
// returns it, or nil when there was none. The session itself is not closed.
func (s *transferState) remove(scope string, transferID string) *transferReceiveSession {
	if s == nil || transferID == "" {
		return nil
	}
	key := transferSessionKey(scope, transferID)
	s.mu.Lock()
	defer s.mu.Unlock()
	removed := s.receives[key]
	delete(s.receives, key)
	return removed
}
// closeAll unregisters every receive session and closes each one with err.
func (s *transferState) closeAll(err error) {
	matchEverything := func(string) bool { return true }
	s.closeMatching(matchEverything, err)
}
// closeScope unregisters and closes (with err) every session registered
// under scope. These are the sessions whose key starts with the normalized
// scope followed by '|', the layout produced by transferSessionKey.
func (s *transferState) closeScope(scope string, err error) {
	prefix := normalizeFileScope(scope) + "|"
	s.closeMatching(func(key string) bool {
		return len(key) >= len(prefix) && key[:len(prefix)] == prefix
	}, err)
}
// closeMatching unregisters every session whose key satisfies match, then
// closes each removed session with err. The closes run after s.mu has been
// released, so sink Close calls never execute while the registry lock is
// held. A nil receiver or nil match does nothing.
func (s *transferState) closeMatching(match func(string) bool, err error) {
	if s == nil || match == nil {
		return
	}

	var victims []*transferReceiveSession
	s.mu.Lock()
	for key, session := range s.receives {
		if match(key) {
			victims = append(victims, session)
			delete(s.receives, key)
		}
	}
	s.mu.Unlock()

	for _, victim := range victims {
		victim.close(err)
	}
}
// newTransferReceiveSession builds a receive session from accepted options.
//
// A zero transportGeneration is resolved from transport and, failing that,
// from logical. Both scopes are stored normalized. The starting offset is
// whatever the sink reports via transferReceiveInitialOffset.
func newTransferReceiveSession(scope string, runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64, opt TransferReceiveOptions) *transferReceiveSession {
	gen := transportGeneration
	if gen == 0 && transport != nil {
		gen = transport.TransportGeneration()
	}
	if gen == 0 && logical != nil {
		gen = logical.transportGenerationSnapshot()
	}

	session := new(transferReceiveSession)
	session.descriptor = normalizeTransferDescriptor(opt.Descriptor)
	session.sink = opt.Sink
	session.syncOnCheckpoint = opt.SyncOnCheckpoint
	session.verifyChecksum = opt.VerifyChecksum
	session.publicScope = normalizeFileScope(scope)
	session.runtimeScope = normalizeFileScope(runtimeScope)
	session.logical = logical
	session.transport = transport
	session.transportGen = gen
	session.nextOffset = transferReceiveInitialOffset(opt.Sink)
	return session
}
// transferReceiveInitialOffset reports the offset at which writing into sink
// should start. Sinks that implement transferReceiveOffsetProvider supply
// it. Anything else, including a nil sink, starts at 0, and so does a
// negative reported offset.
func transferReceiveInitialOffset(sink TransferWriterAt) int64 {
	// A comma-ok assertion on a nil interface value yields ok == false.
	provider, ok := sink.(transferReceiveOffsetProvider)
	if !ok {
		return 0
	}
	if offset := provider.NextOffset(); offset >= 0 {
		return offset
	}
	return 0
}
// descriptorSnapshot returns a copy of the session's descriptor, made with
// cloneTransferDescriptor. A nil session yields the zero descriptor.
func (s *transferReceiveSession) descriptorSnapshot() TransferDescriptor {
	var desc TransferDescriptor
	if s != nil {
		s.mu.Lock()
		defer s.mu.Unlock()
		desc = cloneTransferDescriptor(s.descriptor)
	}
	return desc
}
// nextOffsetSnapshot returns the next byte offset the session expects, or 0
// for a nil session.
func (s *transferReceiveSession) nextOffsetSnapshot() int64 {
	var offset int64
	if s != nil {
		s.mu.Lock()
		defer s.mu.Unlock()
		offset = s.nextOffset
	}
	return offset
}
// setNextOffset overrides the next expected write offset. Negative values
// are clamped to 0. The offset may move backwards, and the closed flag is
// not consulted.
func (s *transferReceiveSession) setNextOffset(nextOffset int64) {
	if s == nil {
		return
	}
	clamped := nextOffset
	if clamped < 0 {
		clamped = 0
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.nextOffset = clamped
}
// updateBinding re-points the session at a new runtime scope and connection
// pair. A zero transportGeneration is resolved from transport, then from
// logical. Stream state and the closed flag are left untouched.
func (s *transferReceiveSession) updateBinding(runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64) {
	if s == nil {
		return
	}
	gen := transportGeneration
	if gen == 0 && transport != nil {
		gen = transport.TransportGeneration()
	}
	if gen == 0 && logical != nil {
		gen = logical.transportGenerationSnapshot()
	}
	scope := normalizeFileScope(runtimeScope)

	s.mu.Lock()
	defer s.mu.Unlock()
	s.runtimeScope, s.logical, s.transport, s.transportGen = scope, logical, transport, gen
}
// beginStream makes streamID the session's active data stream. It also
// rebinds the session to the given scope and connections; a zero
// transportGeneration is resolved from transport, then logical.
//
// Errors:
//   - errTransferSessionNotFound for a nil session.
//   - errStreamIDEmpty for an empty streamID.
//   - io.ErrClosedPipe once the session is closed.
//   - errTransferStreamAlreadyActive while the previous stream's done channel
//     is still open.
func (s *transferReceiveSession) beginStream(streamID string, runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64) error {
	switch {
	case s == nil:
		return errTransferSessionNotFound
	case streamID == "":
		return errStreamIDEmpty
	}

	gen := transportGeneration
	if gen == 0 && transport != nil {
		gen = transport.TransportGeneration()
	}
	if gen == 0 && logical != nil {
		gen = logical.transportGenerationSnapshot()
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	if s.closed {
		return io.ErrClosedPipe
	}
	if prev := s.streamDone; prev != nil {
		select {
		case <-prev:
			// The previous stream already finished and may be replaced.
		default:
			return errTransferStreamAlreadyActive
		}
	}

	s.runtimeScope = normalizeFileScope(runtimeScope)
	s.logical, s.transport, s.transportGen = logical, transport, gen
	s.streamID = streamID
	s.streamDone, s.streamErr, s.streamActive = make(chan struct{}), nil, true
	return nil
}
// finishStream records err as the outcome of stream streamID and wakes any
// waitStream callers by closing the stream's done channel.
//
// Calls for a stream that is not the current one are ignored. Only the
// first terminal outcome is kept. Once the done channel has been closed,
// either by an earlier finishStream or by close(err), further results are
// dropped. Without this rule, the error observed by waitStream (which
// re-reads streamErr after the channel closes) would depend on timing. A
// late nil could also erase the error recorded by close.
func (s *transferReceiveSession) finishStream(streamID string, err error) {
	if s == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	// streamActive is true exactly while streamDone is open (set together in
	// beginStream, cleared right after close(streamDone)), so !streamActive
	// means this stream already has a final outcome.
	if s.streamID != streamID || s.streamDone == nil || !s.streamActive {
		return
	}
	s.streamErr = err
	close(s.streamDone)
	s.streamActive = false
}
// waitStream blocks until the current stream finishes or ctx is done, and
// returns the stream's recorded error (or ctx.Err() on cancellation). A nil
// ctx is treated as context.Background(). If no stream was ever started,
// the stored stream error is returned immediately.
func (s *transferReceiveSession) waitStream(ctx context.Context) error {
	if s == nil {
		return errTransferSessionNotFound
	}
	if ctx == nil {
		ctx = context.Background()
	}

	s.mu.Lock()
	done, result := s.streamDone, s.streamErr
	s.mu.Unlock()
	if done == nil {
		// beginStream has never run, so there is nothing to wait for.
		return result
	}

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-done:
	}
	// Re-read under the lock: the outcome is stored before done is closed.
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.streamErr
}
// runtimeScopeSnapshot returns the session's normalized runtime scope, or
// defaultFileScope for a nil session.
func (s *transferReceiveSession) runtimeScopeSnapshot() string {
	if s == nil {
		return defaultFileScope
	}
	s.mu.Lock()
	scope := s.runtimeScope
	s.mu.Unlock()
	return normalizeFileScope(scope)
}
// writeSegment writes one data segment, starting at segOffset, into the sink.
//
// Offset handling:
//   - Segments must arrive in order.
//   - Bytes below the session's nextOffset are trimmed, and a segment that
//     is entirely below it is a silent no-op.
//   - A segment starting beyond nextOffset fails with errTransferSegmentOffset.
//
// The sink write runs without holding s.mu. When syncOnCheckpoint is set
// and the sink implements TransferSyncer, the sink is synced after the
// write. On success nextOffset advances (never backwards). When runtime is
// non-nil, write/sync timings and receive progress are recorded against
// the session's runtime scope.
func (s *transferReceiveSession) writeSegment(runtime *transferRuntime, transferID string, segOffset int64, payload []byte) error {
	if s == nil {
		return errTransferSessionNotFound
	}

	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		return io.ErrClosedPipe
	}
	want := s.nextOffset
	at := segOffset
	if at < want {
		skip := want - at
		if skip >= int64(len(payload)) {
			s.mu.Unlock()
			return nil // the whole segment was already written
		}
		payload, at = payload[skip:], want
	}
	if at != want {
		s.mu.Unlock()
		return fmt.Errorf("%w: got %d want %d", errTransferSegmentOffset, at, want)
	}
	sink, syncEach, scope := s.sink, s.syncOnCheckpoint, s.runtimeScope
	s.mu.Unlock()

	started := time.Now()
	written, err := sink.WriteAt(payload, at)
	elapsed := time.Since(started)
	if runtime != nil {
		runtime.recordSinkWrite(fileTransferDirectionReceive, scope, transferID, elapsed)
	}
	if err != nil {
		return err
	}
	if written != len(payload) {
		return io.ErrShortWrite
	}

	if syncEach {
		if syncer, ok := sink.(TransferSyncer); ok {
			syncStarted := time.Now()
			syncErr := syncer.Sync(context.Background())
			if runtime != nil {
				runtime.recordSyncDuration(fileTransferDirectionReceive, scope, transferID, time.Since(syncStarted))
			}
			if syncErr != nil {
				return syncErr
			}
		}
	}

	end := at + int64(written)
	s.mu.Lock()
	if end > s.nextOffset {
		s.nextOffset = end
	}
	s.mu.Unlock()

	if runtime == nil {
		return nil
	}
	runtime.activate(fileTransferDirectionReceive, scope, transferID)
	runtime.recordStage(fileTransferDirectionReceive, scope, transferID, "data")
	runtime.recordReceive(fileTransferDirectionReceive, scope, transferID, int64(written))
	return nil
}
// commit finalizes a fully received transfer.
//
// It first waits for the active stream, bounded by ctx (a nil ctx means
// context.Background()), and returns that stream's error. It then fails
// with io.ErrClosedPipe if the session is closed. After that, in order:
//  1. When the descriptor size is non-negative, the received byte count
//     must equal it (errTransferSizeMismatch).
//  2. A TransferSyncer sink is synced.
//  3. When checksum verification is on and the descriptor carries a
//     checksum, the sink must be an io.ReaderAt
//     (errTransferChecksumUnsupported) and its content must match
//     (errTransferChecksumMismatch).
//  4. A TransferCommitter sink is committed.
//
// Phase timings are reported to runtime when it is non-nil. Only full
// success closes the session (with a nil error); on failure it stays open.
func (s *transferReceiveSession) commit(ctx context.Context, runtime *transferRuntime, transferID string) error {
	if s == nil {
		return errTransferSessionNotFound
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if err := s.waitStream(ctx); err != nil {
		return err
	}

	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		return io.ErrClosedPipe
	}
	var (
		desc   = cloneTransferDescriptor(s.descriptor)
		sink   = s.sink
		got    = s.nextOffset
		verify = s.verifyChecksum
		scope  = s.runtimeScope
	)
	s.mu.Unlock()

	if desc.Size >= 0 && got != desc.Size {
		return fmt.Errorf("%w: got %d want %d", errTransferSizeMismatch, got, desc.Size)
	}

	if syncer, ok := sink.(TransferSyncer); ok {
		t0 := time.Now()
		syncErr := syncer.Sync(ctx)
		if runtime != nil {
			runtime.recordSyncDuration(fileTransferDirectionReceive, scope, transferID, time.Since(t0))
		}
		if syncErr != nil {
			return syncErr
		}
	}

	if verify && desc.Checksum != "" {
		readerAt, ok := sink.(io.ReaderAt)
		if !ok {
			return errTransferChecksumUnsupported
		}
		t0 := time.Now()
		sum, sumErr := computeTransferChecksum(readerAt, got)
		if runtime != nil {
			runtime.recordVerifyDuration(fileTransferDirectionReceive, scope, transferID, time.Since(t0))
		}
		switch {
		case sumErr != nil:
			return sumErr
		case sum != "" && !equalChecksum(sum, desc.Checksum):
			return errTransferChecksumMismatch
		}
	}

	if committer, ok := sink.(TransferCommitter); ok {
		t0 := time.Now()
		commitErr := committer.Commit(ctx)
		if runtime != nil {
			runtime.recordCommitDuration(fileTransferDirectionReceive, scope, transferID, time.Since(t0))
		}
		if commitErr != nil {
			return commitErr
		}
	}

	s.close(nil)
	return nil
}
// close marks the session closed. Only the first call has any effect.
//
// A non-nil err is stored as the stream error. If a stream is active, its
// done channel is closed so waitStream callers wake up. The sink is then
// released outside the lock: through CloseWithError(err) when it implements
// transferCloseWithError, otherwise through Close when it is a
// TransferCloser. Sink close errors are deliberately ignored (best-effort
// release).
func (s *transferReceiveSession) close(err error) {
	if s == nil {
		return
	}

	var sink TransferWriterAt
	s.mu.Lock()
	alreadyClosed := s.closed
	if !alreadyClosed {
		s.closed = true
		if err != nil {
			s.streamErr = err
		}
		if s.streamActive && s.streamDone != nil {
			close(s.streamDone)
			s.streamActive = false
		}
		sink = s.sink
	}
	s.mu.Unlock()
	if alreadyClosed {
		return
	}

	// Case order mirrors the preference: CloseWithError beats plain Close.
	switch closer := sink.(type) {
	case transferCloseWithError:
		_ = closer.CloseWithError(err)
	case TransferCloser:
		_ = closer.Close()
	}
}
// restoreReceiveSession tries to resume a receive from a resumable runtime
// snapshot.
//
// The bool result reports whether a compatible snapshot was found. It is
// false, with no error, when runtime is nil, desc has no ID, no snapshot
// exists, or the snapshot's descriptor is not compatible with desc.
//
// Otherwise the receive handlers are asked for fresh options, and a new
// session is built. It resumes at the larger of the sink-reported offset
// and the snapshot's received byte count, and is registered. If a session
// is already registered under the same key, that existing session is
// returned instead.
func (s *transferState) restoreReceiveSession(runtime *transferRuntime, publicScope string, runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64, desc TransferDescriptor) (*transferReceiveSession, bool, error) {
	if runtime == nil || desc.ID == "" {
		return nil, false, nil
	}
	snapshot, found := runtime.resumableSnapshot(fileTransferDirectionReceive, publicScope, desc.ID)
	if !found || !transferDescriptorsCompatible(desc, transferDescriptorFromSnapshot(snapshot)) {
		return nil, false, nil
	}

	opt, err := s.acceptOptions(TransferAcceptInfo{
		Descriptor:          cloneTransferDescriptor(desc),
		LogicalConn:         logical,
		TransportConn:       transport,
		TransportGeneration: transportGeneration,
	})
	if err != nil {
		return nil, true, err
	}
	if opt, err = normalizeTransferReceiveOptions(desc, opt); err != nil {
		// NOTE(review): any sink the handler opened is not released on this
		// path; confirm whether normalizeTransferReceiveOptions or the caller
		// is responsible for closing it.
		return nil, true, err
	}

	session := newTransferReceiveSession(publicScope, runtimeScope, logical, transport, transportGeneration, opt)
	// Resume from whichever is further along: what the sink reports holding,
	// or what the runtime snapshot recorded.
	// NOTE(review): if the sink reports a smaller offset than the snapshot,
	// bytes in between may be absent from the sink; confirm the snapshot only
	// advances once data is durable.
	resumeAt := snapshot.ReceivedBytes
	if sinkOffset := transferReceiveInitialOffset(opt.Sink); sinkOffset > resumeAt {
		resumeAt = sinkOffset
	}
	session.setNextOffset(resumeAt)

	if storeErr := s.store(publicScope, desc.ID, session); storeErr != nil {
		// A session may already be registered under this key (e.g. a
		// concurrent restore); prefer that one.
		// NOTE(review): the freshly built session and its sink are dropped
		// without being closed here; closing with an error might clean up
		// a partial file shared with the existing session, so confirm the
		// sink's close semantics before adding a close.
		if existing, ok := s.load(publicScope, desc.ID); ok {
			return existing, true, nil
		}
		return nil, true, storeErr
	}
	return session, true, nil
}
// transferDescriptorFromSnapshot rebuilds a normalized TransferDescriptor
// from a runtime snapshot. It copies the ID, channel, size and checksum,
// plus a copy of the metadata made by cloneTransferMetadata.
func transferDescriptorFromSnapshot(snapshot TransferSnapshot) TransferDescriptor {
	var desc TransferDescriptor
	desc.ID = snapshot.ID
	desc.Channel = snapshot.Channel
	desc.Size = snapshot.Size
	desc.Checksum = snapshot.Checksum
	desc.Metadata = cloneTransferMetadata(snapshot.Metadata)
	return normalizeTransferDescriptor(desc)
}