package notify import ( itransfer "b612.me/notify/internal/transfer" "context" "crypto/sha256" "encoding/binary" "encoding/hex" "errors" "fmt" "io" "strings" "sync" "time" ) const ( transferStreamMetadataKindKey = "_notify.transfer_stream_kind" transferStreamMetadataKindValue = "segment" transferFrameHeaderSize = 4 transferFrameAggregateLimit = 128 * 1024 transferFrameAggregateCount = 8 transferCommitWaitTimeout = 30 * time.Second transferChecksumChunkSize = 64 * 1024 ) type transferSendTarget struct { runtime *transferRuntime runtimeScope string publicScope string transportGeneration uint64 logical *LogicalConn transport *TransportConn sequenceEn func(interface{}) ([]byte, error) sequenceDe func([]byte) (interface{}, error) openStream func(context.Context, StreamOpenOptions) (Stream, error) sendBegin func(context.Context, TransferBeginRequest) (TransferBeginResponse, error) sendResume func(context.Context, TransferResumeRequest) (TransferResumeResponse, error) sendCommit func(context.Context, TransferCommitRequest) (TransferCommitResponse, error) sendAbort func(context.Context, TransferAbortRequest) (TransferAbortResponse, error) } type transferSendHooks struct { onNegotiated func(nextOffset int64, resumed bool) onSegmentSent func(offset int64, sentBytes int64) onCommitted func() onAbort func(stage string, offset int64, err error) } type transferSendHandle struct { id string runtime *transferRuntime scope string logical *LogicalConn transport *TransportConn abortFn func(context.Context, TransferAbortRequest) error cancel context.CancelFunc mu sync.Mutex stream Stream result error done chan struct{} once sync.Once } func newTransferSendHandle(id string, runtime *transferRuntime, scope string, logical *LogicalConn, transport *TransportConn, cancel context.CancelFunc, abortFn func(context.Context, TransferAbortRequest) error) *transferSendHandle { return &transferSendHandle{ id: id, runtime: runtime, scope: normalizeFileScope(scope), logical: logical, transport: transport, abortFn: abortFn, cancel: cancel, done: make(chan struct{}), } } func (h *transferSendHandle) ID() string { if h == nil { return "" } return h.id } func (h *transferSendHandle) LogicalConn() *LogicalConn { if h == nil { return nil } return h.logical } func (h *transferSendHandle) TransportConn() *TransportConn { if h == nil { return nil } return h.transport } func (h *transferSendHandle) Snapshot() (TransferSnapshot, bool) { if h == nil || h.runtime == nil || h.id == "" { return TransferSnapshot{}, false } snapshot, ok := h.runtime.snapshot(fileTransferDirectionSend, h.scope, h.id) if !ok { return TransferSnapshot{}, false } return convertTransferSnapshot(snapshot), true } func (h *transferSendHandle) Wait(ctx context.Context) error { if h == nil { return nil } if ctx == nil { ctx = context.Background() } select { case <-h.done: h.mu.Lock() err := h.result h.mu.Unlock() return err case <-ctx.Done(): return ctx.Err() } } func (h *transferSendHandle) Abort(ctx context.Context, cause error) error { if h == nil { return nil } if cause == nil { cause = context.Canceled } if h.cancel != nil { h.cancel() } stream := h.streamSnapshot() if stream != nil { _ = stream.Reset(cause) } var err error if h.abortFn != nil { req := TransferAbortRequest{ TransferID: h.id, Stage: "abort", Error: cause.Error(), } err = h.abortFn(ctx, req) } h.finish(cause) return err } func (h *transferSendHandle) setStream(stream Stream) { if h == nil { return } h.mu.Lock() h.stream = stream h.mu.Unlock() } func (h *transferSendHandle) clearStream(stream Stream) { if h == nil { return } h.mu.Lock() if h.stream == stream { h.stream = nil } h.mu.Unlock() } func (h *transferSendHandle) streamSnapshot() Stream { if h == nil { return nil } h.mu.Lock() defer h.mu.Unlock() return h.stream } func (h *transferSendHandle) finish(err error) { if h == nil { return } h.once.Do(func() { h.mu.Lock() h.result = err h.mu.Unlock() close(h.done) }) } func (c *ClientCommon) SetTransferHandler(fn func(TransferAcceptInfo) (TransferReceiveOptions, error)) { state := c.getTransferState() if state == nil { return } state.setHandler(fn) } func (s *ServerCommon) SetTransferHandler(fn func(TransferAcceptInfo) (TransferReceiveOptions, error)) { state := s.getTransferState() if state == nil { return } state.setHandler(fn) } func (c *ClientCommon) SendTransfer(ctx context.Context, opt TransferSendOptions) (TransferHandle, error) { if c == nil { return nil, errTransferControlClientNil } target := transferSendTarget{ runtime: c.getTransferRuntime(), runtimeScope: clientFileScope(), publicScope: clientFileScope(), transportGeneration: 0, sequenceEn: c.sequenceEn, sequenceDe: c.sequenceDe, openStream: func(ctx context.Context, opt StreamOpenOptions) (Stream, error) { return c.OpenStream(ctx, opt) }, sendBegin: func(ctx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) { return SendTransferBeginClient(ctx, c, req) }, sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) { return SendTransferResumeClient(ctx, c, req) }, sendCommit: func(ctx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) { return SendTransferCommitClient(ctx, c, req) }, sendAbort: func(ctx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) { return SendTransferAbortClient(ctx, c, req) }, } return startTransferSendWithHooks(ctx, opt, target, transferSendHooks{}) } func (s *ServerCommon) SendTransferLogical(ctx context.Context, logical *LogicalConn, opt TransferSendOptions) (TransferHandle, error) { if s == nil { return nil, errTransferControlServerNil } if logical == nil { return nil, errTransferControlLogicalConnNil } target := transferSendTarget{ runtime: s.getTransferRuntime(), runtimeScope: serverTransportScope(logical), publicScope: serverFileScope(logical), transportGeneration: logical.transportGenerationSnapshot(), logical: logical, transport: logical.CurrentTransportConn(), sequenceEn: s.sequenceEn, sequenceDe: s.sequenceDe, openStream: func(ctx context.Context, opt StreamOpenOptions) (Stream, error) { return s.OpenStreamLogical(ctx, logical, opt) }, sendBegin: func(ctx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) { return SendTransferBeginLogical(ctx, s, logical, req) }, sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) { return SendTransferResumeLogical(ctx, s, logical, req) }, sendCommit: func(ctx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) { return SendTransferCommitLogical(ctx, s, logical, req) }, sendAbort: func(ctx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) { return SendTransferAbortLogical(ctx, s, logical, req) }, } return startTransferSendWithHooks(ctx, opt, target, transferSendHooks{}) } func (s *ServerCommon) SendTransferTransport(ctx context.Context, transport *TransportConn, opt TransferSendOptions) (TransferHandle, error) { if s == nil { return nil, errTransferControlServerNil } if transport == nil { return nil, errTransferControlTransportNil } logical := transport.LogicalConn() if logical == nil { return nil, errTransferControlLogicalConnNil } target := transferSendTarget{ runtime: s.getTransferRuntime(), runtimeScope: serverTransportScopeForTransport(transport), publicScope: serverFileScope(logical), transportGeneration: transport.TransportGeneration(), logical: logical, transport: transport, sequenceEn: s.sequenceEn, sequenceDe: s.sequenceDe, openStream: func(ctx context.Context, opt StreamOpenOptions) (Stream, error) { return s.OpenStreamTransport(ctx, transport, opt) }, sendBegin: func(ctx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) { return SendTransferBeginTransport(ctx, s, transport, req) }, sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) { return SendTransferResumeTransport(ctx, s, transport, req) }, sendCommit: func(ctx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) { return SendTransferCommitTransport(ctx, s, transport, req) }, sendAbort: func(ctx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) { return SendTransferAbortTransport(ctx, s, transport, req) }, } return startTransferSendWithHooks(ctx, opt, target, transferSendHooks{}) } func startTransferSend(ctx context.Context, opt TransferSendOptions, target transferSendTarget) (TransferHandle, error) { return startTransferSendWithHooks(ctx, opt, target, transferSendHooks{}) } func startTransferSendWithHooks(ctx context.Context, opt TransferSendOptions, target transferSendTarget, hooks transferSendHooks) (TransferHandle, error) { opt, err := normalizeTransferSendOptions(opt) if err != nil { if hooks.onAbort != nil { hooks.onAbort("prepare", 0, err) } return nil, err } desc := opt.Descriptor if target.sequenceEn == nil { recordTransferDataFailure(target.runtime, target.runtimeScope, desc.ID, errTransferSequenceEncodeNil) if hooks.onAbort != nil { hooks.onAbort("prepare", 0, errTransferSequenceEncodeNil) } return nil, errTransferSequenceEncodeNil } useResume := false if snapshot, ok := target.runtime.snapshot(fileTransferDirectionSend, target.runtimeScope, desc.ID); ok { switch snapshot.State { case itransfer.StateDone: return nil, errTransferAlreadyCompleted case itransfer.StateAborted: return nil, errTransferAlreadyAborted default: useResume = true } } var nextOffset int64 if useResume { resp, err := target.sendResume(ctx, TransferResumeRequest{TransferID: desc.ID}) if err != nil { if hooks.onAbort != nil { hooks.onAbort("resume", 0, err) } return nil, err } nextOffset = transferControlOffset(resp.NextOffset) } else { resp, err := target.sendBegin(ctx, TransferBeginRequest{ TransferID: desc.ID, Channel: desc.Channel, Size: desc.Size, Checksum: desc.Checksum, Metadata: cloneTransferMetadata(desc.Metadata), }) if err != nil { if hooks.onAbort != nil { hooks.onAbort("begin", 0, err) } return nil, err } nextOffset = transferControlOffset(resp.NextOffset) } if hooks.onNegotiated != nil { hooks.onNegotiated(nextOffset, useResume) } recordTransferSendRuntimeOptions(target.runtime, target.runtimeScope, opt) stream, err := target.openStream(ctx, StreamOpenOptions{ Channel: transferChannelToStreamChannel(desc.Channel), Metadata: buildTransferStreamMetadata(desc.ID), }) if err != nil { recordTransferDataFailure(target.runtime, target.runtimeScope, desc.ID, err) if hooks.onAbort != nil { hooks.onAbort("stream.open", nextOffset, err) } return nil, err } runCtx, cancel := context.WithCancel(transferRunContext(ctx)) handle := newTransferSendHandle(desc.ID, target.runtime, target.runtimeScope, target.logical, target.transport, cancel, func(ctx context.Context, req TransferAbortRequest) error { _, err := target.sendAbort(ctx, req) return err }) handle.setStream(stream) go func() { err := runTransferSend(runCtx, stream, opt, nextOffset, target, hooks) handle.clearStream(stream) handle.finish(err) }() return handle, nil } func runTransferSend(ctx context.Context, stream Stream, opt TransferSendOptions, nextOffset int64, target transferSendTarget, hooks transferSendHooks) error { desc := opt.Descriptor sendErr := sendTransferSegments(ctx, stream, target, opt, nextOffset, hooks) closeErr := stream.Close() if sendErr != nil { recordTransferDataFailure(target.runtime, target.runtimeScope, desc.ID, sendErr) if hooks.onAbort != nil { hooks.onAbort("data", nextOffset, sendErr) } return sendErr } if closeErr != nil && !errors.Is(closeErr, errStreamNotFound) { recordTransferDataFailure(target.runtime, target.runtimeScope, desc.ID, closeErr) if hooks.onAbort != nil { hooks.onAbort("stream.close", desc.Size, closeErr) } return closeErr } commitStartedAt := time.Now() _, err := target.sendCommit(ctx, TransferCommitRequest{ TransferID: desc.ID, Size: desc.Size, Checksum: desc.Checksum, }) if target.runtime != nil { target.runtime.recordCommitWaitDuration(fileTransferDirectionSend, target.runtimeScope, desc.ID, time.Since(commitStartedAt)) } if err != nil { if hooks.onAbort != nil { hooks.onAbort("commit", desc.Size, err) } return err } if target.runtime != nil { target.runtime.setAckedBytes(fileTransferDirectionSend, target.runtimeScope, desc.ID, desc.Size) } if hooks.onCommitted != nil { hooks.onCommitted() } return nil } func sendTransferSegments(ctx context.Context, stream Stream, target transferSendTarget, opt TransferSendOptions, nextOffset int64, hooks transferSendHooks) error { if ctx == nil { ctx = context.Background() } var err error opt, err = normalizeTransferSendOptions(opt) if err != nil { return err } recordTransferSendRuntimeOptions(target.runtime, target.runtimeScope, opt) desc := opt.Descriptor if nextOffset < 0 { nextOffset = 0 } if nextOffset > desc.Size { return errTransferSizeMismatch } if opt.Parallelism > 1 { return sendTransferSegmentsWindowed(ctx, stream, target, opt, nextOffset, hooks) } return sendTransferSegmentsSerial(ctx, stream, target, opt, nextOffset, hooks) } func recordTransferSendRuntimeOptions(runtime *transferRuntime, runtimeScope string, opt TransferSendOptions) { if runtime == nil || opt.Descriptor.ID == "" { return } runtime.recordSendOptions(fileTransferDirectionSend, runtimeScope, opt.Descriptor.ID, opt.ChunkSize, opt.Parallelism, opt.MaxInflightBytes) } func (c *ClientCommon) dispatchInternalTransferControl(message Message) bool { state := c.getTransferState() if !state.controlEnabledSnapshot() || !isInternalTransferControlKey(message.Key) { return false } dispatchInternalTransferControlMessage(&message, state, c.getTransferRuntime(), clientFileScope(), clientFileScope(), 0, nil, nil) return true } func (s *ServerCommon) dispatchInternalTransferControl(message Message) bool { state := s.getTransferState() if !state.controlEnabledSnapshot() || !isInternalTransferControlKey(message.Key) { return false } logical := messageLogicalConnSnapshot(&message) transport := messageTransportConnSnapshot(&message) dispatchInternalTransferControlMessage( &message, state, s.getTransferRuntime(), transferRuntimeScopeForPeer(logical, transport), transferPublicScopeForPeer(logical), transferGenerationForPeer(logical, transport), logical, transport, ) return true } func dispatchInternalTransferControlMessage(msg *Message, state *transferState, runtime *transferRuntime, runtimeScope string, publicScope string, transportGeneration uint64, logical *LogicalConn, transport *TransportConn) { if msg == nil || state == nil { return } switch msg.Key { case TransferBeginSignalKey: var req TransferBeginRequest resp := TransferBeginResponse{} if err := msg.Value.Orm(&req); err != nil { resp.Error = err.Error() replyTransferControlIfNeeded(msg, resp) return } transferControlPrepareBegin(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req) resp, err := state.acceptBegin(runtime, req, runtimeScope, publicScope, transportGeneration, logical, transport) if resp.TransferID == "" { resp.TransferID = req.TransferID } if err != nil && resp.Error == "" { resp.Error = err.Error() } transferControlFinishBegin(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, resp, transferControlResultError("begin", resp.Accepted, resp.Error, err)) replyTransferControlIfNeeded(msg, resp) case TransferResumeSignalKey: var req TransferResumeRequest resp := TransferResumeResponse{} if err := msg.Value.Orm(&req); err != nil { resp.Error = err.Error() replyTransferControlIfNeeded(msg, resp) return } transferControlPrepareResume(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req) resp, err := state.acceptResume(runtime, req, publicScope, runtimeScope, logical, transport, transportGeneration) if resp.TransferID == "" { resp.TransferID = req.TransferID } if err != nil && resp.Error == "" { resp.Error = err.Error() } transferControlFinishResume(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, resp, transferControlResultError("resume", resp.Accepted, resp.Error, err)) replyTransferControlIfNeeded(msg, resp) case TransferCommitSignalKey: var req TransferCommitRequest resp := TransferCommitResponse{} if err := msg.Value.Orm(&req); err != nil { resp.Error = err.Error() replyTransferControlIfNeeded(msg, resp) return } transferControlPrepareCommit(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req) resp, err := state.acceptCommit(runtime, req, publicScope) if resp.TransferID == "" { resp.TransferID = req.TransferID } if err != nil && resp.Error == "" { resp.Error = err.Error() } transferControlFinishCommit(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, resp, transferControlResultError("commit", resp.Accepted, resp.Error, err)) replyTransferControlIfNeeded(msg, resp) case TransferAbortSignalKey: var req TransferAbortRequest resp := TransferAbortResponse{} if err := msg.Value.Orm(&req); err != nil { resp.Error = err.Error() replyTransferControlIfNeeded(msg, resp) return } transferControlPrepareAbort(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req) resp, err := state.acceptAbort(req, publicScope) if resp.TransferID == "" { resp.TransferID = req.TransferID } if err != nil && resp.Error == "" { resp.Error = err.Error() } transferControlFinishAbort(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, req, resp, transferControlResultError("abort", resp.Accepted, resp.Error, err)) replyTransferControlIfNeeded(msg, resp) } } func (state *transferState) acceptBegin(runtime *transferRuntime, req TransferBeginRequest, runtimeScope string, publicScope string, transportGeneration uint64, logical *LogicalConn, transport *TransportConn) (TransferBeginResponse, error) { if req.TransferID == "" { return TransferBeginResponse{}, errTransferIDEmpty } desc := transferDescriptorFromBegin(req) if session, ok := state.load(publicScope, req.TransferID); ok { existing := session.descriptorSnapshot() if !transferDescriptorsCompatible(existing, desc) { return TransferBeginResponse{TransferID: req.TransferID}, fmt.Errorf("transfer descriptor mismatch") } session.updateBinding(runtimeScope, logical, transport, transportGeneration) return TransferBeginResponse{ TransferID: req.TransferID, Accepted: true, NextOffset: session.nextOffsetSnapshot(), }, nil } if restored, ok, err := state.restoreReceiveSession(runtime, publicScope, runtimeScope, logical, transport, transportGeneration, desc); ok { if err != nil { return TransferBeginResponse{TransferID: req.TransferID}, err } return TransferBeginResponse{ TransferID: req.TransferID, Accepted: true, NextOffset: restored.nextOffsetSnapshot(), }, nil } info := TransferAcceptInfo{ Descriptor: cloneTransferDescriptor(desc), LogicalConn: logical, TransportConn: transport, TransportGeneration: transportGeneration, } opt, err := state.acceptOptions(info) if err != nil { return TransferBeginResponse{TransferID: req.TransferID}, err } opt, err = normalizeTransferReceiveOptions(desc, opt) if err != nil { return TransferBeginResponse{TransferID: req.TransferID}, err } session := newTransferReceiveSession(publicScope, runtimeScope, logical, transport, transportGeneration, opt) if err := state.store(publicScope, req.TransferID, session); err != nil { if existing, ok := state.load(publicScope, req.TransferID); ok { current := existing.descriptorSnapshot() if transferDescriptorsCompatible(current, desc) { existing.updateBinding(runtimeScope, logical, transport, transportGeneration) return TransferBeginResponse{ TransferID: req.TransferID, Accepted: true, NextOffset: existing.nextOffsetSnapshot(), }, nil } } return TransferBeginResponse{TransferID: req.TransferID}, err } return TransferBeginResponse{ TransferID: req.TransferID, Accepted: true, NextOffset: session.nextOffsetSnapshot(), }, nil } func (state *transferState) acceptResume(runtime *transferRuntime, req TransferResumeRequest, publicScope string, runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64) (TransferResumeResponse, error) { if req.TransferID == "" { return TransferResumeResponse{}, errTransferIDEmpty } session, ok := state.load(publicScope, req.TransferID) if !ok { snapshot, found := runtime.resumableSnapshot(fileTransferDirectionReceive, publicScope, req.TransferID) if !found { return TransferResumeResponse{TransferID: req.TransferID}, errTransferSessionNotFound } var err error session, ok, err = state.restoreReceiveSession(runtime, publicScope, runtimeScope, logical, transport, transportGeneration, transferDescriptorFromSnapshot(snapshot)) if err != nil { return TransferResumeResponse{TransferID: req.TransferID}, err } if !ok || session == nil { return TransferResumeResponse{TransferID: req.TransferID}, errTransferSessionNotFound } } session.updateBinding(runtimeScope, logical, transport, transportGeneration) return TransferResumeResponse{ TransferID: req.TransferID, Accepted: true, NextOffset: session.nextOffsetSnapshot(), Missing: nil, Error: "", }, nil } func (state *transferState) acceptCommit(runtime *transferRuntime, req TransferCommitRequest, publicScope string) (TransferCommitResponse, error) { if req.TransferID == "" { return TransferCommitResponse{}, errTransferIDEmpty } session, ok := state.load(publicScope, req.TransferID) if !ok { return TransferCommitResponse{TransferID: req.TransferID}, errTransferSessionNotFound } ctx, cancel := context.WithTimeout(context.Background(), transferCommitWaitTimeout) defer cancel() if err := session.commit(ctx, runtime, req.TransferID); err != nil { return TransferCommitResponse{TransferID: req.TransferID}, err } state.remove(publicScope, req.TransferID) return TransferCommitResponse{ TransferID: req.TransferID, Accepted: true, }, nil } func (state *transferState) acceptAbort(req TransferAbortRequest, publicScope string) (TransferAbortResponse, error) { if req.TransferID == "" { return TransferAbortResponse{}, errTransferIDEmpty } if session := state.remove(publicScope, req.TransferID); session != nil { session.close(transferControlAbortError(req.Error)) } return TransferAbortResponse{ TransferID: req.TransferID, Accepted: true, }, nil } func replyTransferControlIfNeeded(msg *Message, value interface{}) { if msg == nil || !requiresSignalReplyWait(msg.TransferMsg) { return } _ = msg.ReplyObj(value) } func isInternalTransferControlKey(key string) bool { switch key { case TransferBeginSignalKey, TransferResumeSignalKey, TransferCommitSignalKey, TransferAbortSignalKey: return true default: return false } } func (c *ClientCommon) claimInboundTransferStream(stream *streamHandle) (bool, error) { return claimInboundTransferStream(stream, c.getTransferState(), c.getTransferRuntime(), clientFileScope(), clientFileScope(), nil, nil, c.sequenceDe) } func (s *ServerCommon) claimInboundTransferStream(logical *LogicalConn, transport *TransportConn, stream *streamHandle) (bool, error) { if logical == nil { return false, nil } return claimInboundTransferStream( stream, s.getTransferState(), s.getTransferRuntime(), serverFileScope(logical), transferRuntimeScopeForPeer(logical, transport), logical, transport, s.sequenceDe, ) } func claimInboundTransferStream(stream *streamHandle, state *transferState, runtime *transferRuntime, publicScope string, runtimeScope string, logical *LogicalConn, transport *TransportConn, sequenceDe func([]byte) (interface{}, error)) (bool, error) { if stream == nil || state == nil || !state.controlEnabledSnapshot() { return false, nil } transferID, ok := parseTransferStreamMetadata(stream.Metadata()) if !ok { return false, nil } session, ok := state.load(publicScope, transferID) if !ok { return true, errTransferSessionNotFound } if err := session.beginStream(stream.ID(), runtimeScope, logical, transport, stream.TransportGeneration()); err != nil { return true, err } go func() { err := receiveTransferStream(stream, session, transferID, sequenceDe, runtime) if err != nil { if !errors.Is(err, context.Canceled) && !errors.Is(err, io.EOF) { _ = stream.Reset(err) } if runtime != nil { runtime.recordFailureStage(fileTransferDirectionReceive, session.runtimeScopeSnapshot(), transferID, "data") } } session.finishStream(stream.ID(), err) }() return true, nil } func receiveTransferStream(stream Stream, session *transferReceiveSession, transferID string, sequenceDe func([]byte) (interface{}, error), runtime *transferRuntime) error { if sequenceDe == nil { return errTransferSequenceDecodeNil } for { payload, err := readTransferFrame(stream) if err != nil { if errors.Is(err, io.EOF) { return nil } return err } segment, err := decodeTransferSegment(payload, sequenceDe) if err != nil { return err } if segment.TransferID != transferID { return fmt.Errorf("transfer id mismatch") } if err := session.writeSegment(runtime, transferID, segment.Offset, segment.Payload); err != nil { return err } } } func decodeTransferSegment(payload []byte, sequenceDe func([]byte) (interface{}, error)) (itransfer.Segment, error) { value, err := sequenceDe(payload) if err != nil { return itransfer.Segment{}, err } segment, ok := value.(itransfer.Segment) if !ok { return itransfer.Segment{}, errors.New("invalid transfer segment payload") } return segment, nil } func buildTransferFrame(payload []byte) []byte { frame := make([]byte, transferFrameHeaderSize+len(payload)) binary.BigEndian.PutUint32(frame[:transferFrameHeaderSize], uint32(len(payload))) copy(frame[transferFrameHeaderSize:], payload) return frame } func writeTransferFrames(stream Stream, data []byte) error { for len(data) > 0 { n, err := stream.Write(data) if n > 0 { data = data[n:] } if err != nil { return err } if n == 0 { return io.ErrShortWrite } } return nil } func readTransferFrame(stream Stream) ([]byte, error) { header := make([]byte, transferFrameHeaderSize) if _, err := io.ReadFull(stream, header); err != nil { return nil, err } length := binary.BigEndian.Uint32(header) payload := make([]byte, int(length)) if _, err := io.ReadFull(stream, payload); err != nil { return nil, err } return payload, nil } func transferConfirmProgress(ctx context.Context, target transferSendTarget, transferID string) (int64, error) { if target.sendResume == nil { return 0, nil } resp, err := target.sendResume(ctx, TransferResumeRequest{TransferID: transferID}) if err != nil { return 0, err } return transferControlOffset(resp.NextOffset), nil } func normalizeTransferSendOptions(opt TransferSendOptions) (TransferSendOptions, error) { if opt.Source == nil { return TransferSendOptions{}, errTransferSourceNil } desc := normalizeTransferDescriptor(opt.Descriptor) if desc.ID == "" { return TransferSendOptions{}, errTransferIDEmpty } if desc.Size < 0 { return TransferSendOptions{}, errTransferSizeInvalid } sourceSize := opt.Source.Size() if sourceSize < 0 { return TransferSendOptions{}, errTransferSizeInvalid } if desc.Size == 0 { desc.Size = sourceSize } if opt.ChunkSize <= 0 { opt.ChunkSize = defaultFileChunkSize } if opt.Parallelism <= 0 { opt.Parallelism = 1 } if opt.MaxInflightBytes <= 0 { opt.MaxInflightBytes = int64(opt.ChunkSize * opt.Parallelism) } if opt.MaxInflightBytes < int64(opt.ChunkSize) { opt.MaxInflightBytes = int64(opt.ChunkSize) } if _, ok := opt.Source.(transferSequentialReaderSource); ok { if opt.Parallelism > 1 { return TransferSendOptions{}, errTransferSequentialSourceParallelism } if opt.VerifyChecksum && desc.Checksum == "" { return TransferSendOptions{}, errTransferSequentialSourceImplicitChecksum } } if opt.VerifyChecksum && desc.Checksum == "" { sum, err := computeTransferChecksum(opt.Source, desc.Size) if err != nil { return TransferSendOptions{}, err } desc.Checksum = sum } opt.Descriptor = desc return opt, nil } func normalizeTransferReceiveOptions(desc TransferDescriptor, opt TransferReceiveOptions) (TransferReceiveOptions, error) { if opt.Sink == nil { return TransferReceiveOptions{}, errTransferSinkNil } if opt.Descriptor.ID != "" && opt.Descriptor.ID != desc.ID { return TransferReceiveOptions{}, fmt.Errorf("transfer id mismatch") } if opt.Descriptor.Channel != "" && normalizeTransferChannel(opt.Descriptor.Channel) != normalizeTransferChannel(desc.Channel) { return TransferReceiveOptions{}, fmt.Errorf("transfer channel mismatch") } if opt.Descriptor.Size > 0 && opt.Descriptor.Size != desc.Size { return TransferReceiveOptions{}, fmt.Errorf("transfer size mismatch") } if opt.Descriptor.Checksum != "" && !equalChecksum(opt.Descriptor.Checksum, desc.Checksum) { return TransferReceiveOptions{}, fmt.Errorf("transfer checksum mismatch") } merged := cloneTransferDescriptor(desc) if len(opt.Descriptor.Metadata) != 0 { if merged.Metadata == nil { merged.Metadata = make(map[string]string, len(opt.Descriptor.Metadata)) } for key, value := range opt.Descriptor.Metadata { merged.Metadata[key] = value } } opt.Descriptor = merged return opt, nil } func normalizeTransferDescriptor(desc TransferDescriptor) TransferDescriptor { desc.Channel = normalizeTransferChannel(desc.Channel) desc.Metadata = cloneTransferMetadata(desc.Metadata) return desc } func cloneTransferDescriptor(desc TransferDescriptor) TransferDescriptor { return normalizeTransferDescriptor(desc) } func normalizeTransferChannel(channel TransferChannel) TransferChannel { switch channel { case "", TransferChannelData: return TransferChannelData case TransferChannelControl: return TransferChannelControl default: return channel } } func transferChannelToStreamChannel(channel TransferChannel) StreamChannel { switch normalizeTransferChannel(channel) { case TransferChannelControl: return StreamControlChannel default: return StreamDataChannel } } func transferChannelToKernel(channel TransferChannel) itransfer.Channel { switch normalizeTransferChannel(channel) { case TransferChannelControl: return itransfer.ControlChannel default: return itransfer.DataChannel } } func transferDescriptorFromBegin(req TransferBeginRequest) TransferDescriptor { return normalizeTransferDescriptor(TransferDescriptor{ ID: req.TransferID, Channel: req.Channel, Size: req.Size, Checksum: req.Checksum, Metadata: cloneTransferMetadata(req.Metadata), }) } func transferDescriptorsCompatible(left TransferDescriptor, right TransferDescriptor) bool { return left.ID == right.ID && normalizeTransferChannel(left.Channel) == normalizeTransferChannel(right.Channel) && left.Size == right.Size && equalChecksum(left.Checksum, right.Checksum) } func buildTransferStreamMetadata(transferID string) StreamMetadata { return StreamMetadata{ transferStreamMetadataKindKey: transferStreamMetadataKindValue, transferMetadataIDKey: transferID, } } func parseTransferStreamMetadata(metadata StreamMetadata) (string, bool) { if len(metadata) == 0 { return "", false } if metadata[transferStreamMetadataKindKey] != transferStreamMetadataKindValue { return "", false } transferID := strings.TrimSpace(metadata[transferMetadataIDKey]) if transferID == "" { return "", false } return transferID, true } func recordTransferDataFailure(runtime *transferRuntime, runtimeScope string, transferID string, err error) { if runtime == nil || transferID == "" || err == nil { return } runtime.recordStage(fileTransferDirectionSend, runtimeScope, transferID, "data") runtime.recordFailureStage(fileTransferDirectionSend, runtimeScope, transferID, "data") runtime.fail(fileTransferDirectionSend, runtimeScope, transferID, err) } func transferRunContext(ctx context.Context) context.Context { if ctx == nil { return context.Background() } return ctx } func transferRuntimeScopeForPeer(logical *LogicalConn, transport *TransportConn) string { if transport != nil { return serverTransportScopeForTransport(transport) } if logical != nil { return serverTransportScope(logical) } return defaultFileScope } func transferPublicScopeForPeer(logical *LogicalConn) string { if logical == nil { return defaultFileScope } return serverFileScope(logical) } func transferGenerationForPeer(logical *LogicalConn, transport *TransportConn) uint64 { if transport != nil { return transport.TransportGeneration() } if logical != nil { return logical.transportGenerationSnapshot() } return 0 } func computeTransferChecksum(reader io.ReaderAt, size int64) (string, error) { if size < 0 { return "", errTransferSizeInvalid } sum := sha256.New() buf := make([]byte, transferChecksumChunkSize) for offset := int64(0); offset < size; { want := len(buf) remaining := size - offset if remaining < int64(want) { want = int(remaining) } n, err := reader.ReadAt(buf[:want], offset) if n > 0 { _, _ = sum.Write(buf[:n]) offset += int64(n) } if err != nil { if errors.Is(err, io.EOF) && offset == size { break } return "", err } if n == 0 && offset < size { return "", io.ErrNoProgress } } return hex.EncodeToString(sum.Sum(nil)), nil } func equalChecksum(got string, want string) bool { got = normalizeChecksum(got) want = normalizeChecksum(want) if got == "" || want == "" { return got == want } return strings.EqualFold(got, want) } func normalizeChecksum(value string) string { value = strings.TrimSpace(value) value = strings.TrimPrefix(value, "sha256:") value = strings.TrimPrefix(value, "SHA256:") return value }