// notify/transfer_plane.go

package notify
import (
itransfer "b612.me/notify/internal/transfer"
"context"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
"sync"
"time"
)
// Wire constants and tunables for the transfer data plane.
const (
// transferStreamMetadataKindKey/Value tag a stream as carrying transfer
// segments; parseTransferStreamMetadata only accepts streams with this pair.
transferStreamMetadataKindKey = "_notify.transfer_stream_kind"
transferStreamMetadataKindValue = "segment"
// transferFrameHeaderSize is the size of the big-endian uint32 length prefix
// written by buildTransferFrame and read by readTransferFrame.
transferFrameHeaderSize = 4
// NOTE(review): the aggregate limit/count are not referenced in this part of
// the file; presumably they bound frame batching per write — confirm.
transferFrameAggregateLimit = 128 * 1024
transferFrameAggregateCount = 8
// transferCommitWaitTimeout bounds how long acceptCommit waits for the
// receive session to commit.
transferCommitWaitTimeout = 30 * time.Second
// transferChecksumChunkSize is the read buffer size used by
// computeTransferChecksum.
transferChecksumChunkSize = 64 * 1024
)
// transferSendTarget bundles everything a send needs to talk to one peer:
// runtime bookkeeping and its scopes, the peer connection handles (left nil
// by ClientCommon.SendTransfer), the value encoder/decoder, and the stream
// and control-plane operations bound to that peer.
type transferSendTarget struct {
runtime *transferRuntime
// runtimeScope keys runtime bookkeeping; publicScope is the peer-facing file scope.
runtimeScope string
publicScope string
// transportGeneration is 0 for the client-side target.
transportGeneration uint64
logical *LogicalConn
transport *TransportConn
// sequenceEn must be non-nil; startTransferSendWithHooks refuses to start otherwise.
sequenceEn func(interface{}) ([]byte, error)
sequenceDe func([]byte) (interface{}, error)
// openStream opens the data stream; the send* functions issue control requests.
openStream func(context.Context, StreamOpenOptions) (Stream, error)
sendBegin func(context.Context, TransferBeginRequest) (TransferBeginResponse, error)
sendResume func(context.Context, TransferResumeRequest) (TransferResumeResponse, error)
sendCommit func(context.Context, TransferCommitRequest) (TransferCommitResponse, error)
sendAbort func(context.Context, TransferAbortRequest) (TransferAbortResponse, error)
}
// transferSendHooks are optional observers of a send's lifecycle; a nil hook
// is simply skipped.
type transferSendHooks struct {
// onNegotiated runs after begin/resume succeeds, with the offset to continue
// from and whether the resume path was taken.
onNegotiated func(nextOffset int64, resumed bool)
// onSegmentSent is not invoked in this part of the file; presumably the
// serial/windowed segment senders call it per segment — confirm.
onSegmentSent func(offset int64, sentBytes int64)
// onCommitted runs once the peer has accepted the commit.
onCommitted func()
// onAbort reports a failure with its stage ("prepare", "resume", "begin",
// "stream.open", "data", "stream.close" or "commit"), the offset supplied at
// that failure site, and the error.
onAbort func(stage string, offset int64, err error)
}
// transferSendHandle tracks one in-flight send started by
// startTransferSendWithHooks.
type transferSendHandle struct {
id string
runtime *transferRuntime
// scope is the runtime scope, normalized by newTransferSendHandle.
scope string
logical *LogicalConn
transport *TransportConn
// abortFn asks the peer to drop the transfer; cancel stops the run context.
abortFn func(context.Context, TransferAbortRequest) error
cancel context.CancelFunc
// mu guards stream and result.
mu sync.Mutex
stream Stream
result error
// done is closed exactly once (guarded by once) when the send finishes.
done chan struct{}
once sync.Once
}
// newTransferSendHandle creates the handle for a send that is about to run.
// The scope is passed through normalizeFileScope. The done channel starts
// open and is closed later, exactly once, by finish.
func newTransferSendHandle(id string, runtime *transferRuntime, scope string, logical *LogicalConn, transport *TransportConn, cancel context.CancelFunc, abortFn func(context.Context, TransferAbortRequest) error) *transferSendHandle {
	h := new(transferSendHandle)
	h.id = id
	h.runtime = runtime
	h.scope = normalizeFileScope(scope)
	h.logical = logical
	h.transport = transport
	h.abortFn = abortFn
	h.cancel = cancel
	h.done = make(chan struct{})
	return h
}
// ID reports the transfer identifier, or "" when called on a nil handle.
func (h *transferSendHandle) ID() string {
	if h != nil {
		return h.id
	}
	return ""
}
// LogicalConn reports the logical peer connection the send is bound to.
// It is nil for a nil handle and for client-side sends.
func (h *transferSendHandle) LogicalConn() *LogicalConn {
	if h != nil {
		return h.logical
	}
	return nil
}
// TransportConn reports the transport connection captured when the send
// started. It is nil for a nil handle.
func (h *transferSendHandle) TransportConn() *TransportConn {
	if h != nil {
		return h.transport
	}
	return nil
}
// Snapshot returns the runtime's current view of this send. The boolean is
// false when the handle is nil, has no runtime or ID, or the runtime no
// longer knows the transfer.
func (h *transferSendHandle) Snapshot() (TransferSnapshot, bool) {
	var empty TransferSnapshot
	if h == nil || h.runtime == nil || h.id == "" {
		return empty, false
	}
	raw, found := h.runtime.snapshot(fileTransferDirectionSend, h.scope, h.id)
	if found {
		return convertTransferSnapshot(raw), true
	}
	return empty, false
}
// Wait blocks until the send finishes or ctx is done.
//
// It returns the send's final error (nil on success) or ctx.Err(). A nil ctx
// is treated as context.Background(). A nil handle returns nil immediately.
func (h *transferSendHandle) Wait(ctx context.Context) error {
	if h == nil {
		return nil
	}
	waitCtx := ctx
	if waitCtx == nil {
		waitCtx = context.Background()
	}
	select {
	case <-waitCtx.Done():
		return waitCtx.Err()
	case <-h.done:
	}
	h.mu.Lock()
	defer h.mu.Unlock()
	return h.result
}
// Abort stops the send with cause (context.Canceled when nil). It cancels
// the run context, resets the in-flight stream, and asks the peer to drop
// its session. It returns the error from that peer abort request, if any.
//
// The handle is finished with cause before the run context is cancelled.
// Otherwise the send goroutine, which observes the cancellation almost
// immediately, would usually finish first, and Wait would report its
// context/stream error instead of cause. Finishing is idempotent, so calling
// Abort after the send has completed leaves the original result in place.
// A nil ctx is treated as context.Background(), matching Wait.
func (h *transferSendHandle) Abort(ctx context.Context, cause error) error {
	if h == nil {
		return nil
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if cause == nil {
		cause = context.Canceled
	}
	// Record cause first; the send goroutine's later finish call becomes a no-op.
	h.finish(cause)
	if h.cancel != nil {
		h.cancel()
	}
	if stream := h.streamSnapshot(); stream != nil {
		// Best effort: the send goroutine may already have closed the stream.
		_ = stream.Reset(cause)
	}
	if h.abortFn == nil {
		return nil
	}
	return h.abortFn(ctx, TransferAbortRequest{
		TransferID: h.id,
		Stage:      "abort",
		Error:      cause.Error(),
	})
}
// setStream records the stream currently carrying the send so that Abort
// can reset it. It is a no-op on a nil handle.
func (h *transferSendHandle) setStream(stream Stream) {
	if h == nil {
		return
	}
	h.mu.Lock()
	defer h.mu.Unlock()
	h.stream = stream
}
// clearStream forgets stream, but only if it is still the recorded one, so
// that a stale clear cannot drop a newer stream. No-op on a nil handle.
func (h *transferSendHandle) clearStream(stream Stream) {
	if h == nil {
		return
	}
	h.mu.Lock()
	defer h.mu.Unlock()
	if h.stream != stream {
		return
	}
	h.stream = nil
}
// streamSnapshot returns the recorded stream under the lock, or nil.
func (h *transferSendHandle) streamSnapshot() Stream {
	if h == nil {
		return nil
	}
	h.mu.Lock()
	current := h.stream
	h.mu.Unlock()
	return current
}
// finish stores err as the send's result and closes done. Only the first
// call has any effect; later calls are ignored.
func (h *transferSendHandle) finish(err error) {
	if h == nil {
		return
	}
	complete := func() {
		h.mu.Lock()
		h.result = err
		h.mu.Unlock()
		close(h.done)
	}
	h.once.Do(complete)
}
// SetTransferHandler installs fn as the callback that supplies receive
// options for transfers the peer begins. It does nothing when the client
// has no transfer state.
func (c *ClientCommon) SetTransferHandler(fn func(TransferAcceptInfo) (TransferReceiveOptions, error)) {
	if state := c.getTransferState(); state != nil {
		state.setHandler(fn)
	}
}
// SetTransferHandler installs fn as the callback that supplies receive
// options for transfers a peer begins. It does nothing when the server has
// no transfer state.
func (s *ServerCommon) SetTransferHandler(fn func(TransferAcceptInfo) (TransferReceiveOptions, error)) {
	if state := s.getTransferState(); state != nil {
		state.setHandler(fn)
	}
}
// SendTransfer starts sending opt.Source to the server. It negotiates the
// transfer, opens a dedicated stream, and returns a handle tracking the
// background send. A nil client yields errTransferControlClientNil.
func (c *ClientCommon) SendTransfer(ctx context.Context, opt TransferSendOptions) (TransferHandle, error) {
	if c == nil {
		return nil, errTransferControlClientNil
	}
	var target transferSendTarget
	target.runtime = c.getTransferRuntime()
	target.runtimeScope = clientFileScope()
	target.publicScope = clientFileScope()
	// The client side has no transport generation and no logical/transport handles.
	target.transportGeneration = 0
	target.sequenceEn = c.sequenceEn
	target.sequenceDe = c.sequenceDe
	target.openStream = func(callCtx context.Context, streamOpt StreamOpenOptions) (Stream, error) {
		return c.OpenStream(callCtx, streamOpt)
	}
	target.sendBegin = func(callCtx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) {
		return SendTransferBeginClient(callCtx, c, req)
	}
	target.sendResume = func(callCtx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) {
		return SendTransferResumeClient(callCtx, c, req)
	}
	target.sendCommit = func(callCtx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) {
		return SendTransferCommitClient(callCtx, c, req)
	}
	target.sendAbort = func(callCtx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) {
		return SendTransferAbortClient(callCtx, c, req)
	}
	return startTransferSendWithHooks(ctx, opt, target, transferSendHooks{})
}
// SendTransferLogical starts sending opt.Source to the peer behind logical.
// Streams and control requests are routed through the logical connection.
// The transport recorded on the handle is logical.CurrentTransportConn() at
// call time. It fails when s or logical is nil.
func (s *ServerCommon) SendTransferLogical(ctx context.Context, logical *LogicalConn, opt TransferSendOptions) (TransferHandle, error) {
	if s == nil {
		return nil, errTransferControlServerNil
	}
	if logical == nil {
		return nil, errTransferControlLogicalConnNil
	}
	var target transferSendTarget
	target.runtime = s.getTransferRuntime()
	target.runtimeScope = serverTransportScope(logical)
	target.publicScope = serverFileScope(logical)
	target.transportGeneration = logical.transportGenerationSnapshot()
	target.logical = logical
	target.transport = logical.CurrentTransportConn()
	target.sequenceEn = s.sequenceEn
	target.sequenceDe = s.sequenceDe
	target.openStream = func(callCtx context.Context, streamOpt StreamOpenOptions) (Stream, error) {
		return s.OpenStreamLogical(callCtx, logical, streamOpt)
	}
	target.sendBegin = func(callCtx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) {
		return SendTransferBeginLogical(callCtx, s, logical, req)
	}
	target.sendResume = func(callCtx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) {
		return SendTransferResumeLogical(callCtx, s, logical, req)
	}
	target.sendCommit = func(callCtx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) {
		return SendTransferCommitLogical(callCtx, s, logical, req)
	}
	target.sendAbort = func(callCtx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) {
		return SendTransferAbortLogical(callCtx, s, logical, req)
	}
	return startTransferSendWithHooks(ctx, opt, target, transferSendHooks{})
}
// SendTransferTransport starts sending opt.Source over one specific
// transport connection. Streams and control requests use transport, and
// runtime bookkeeping is scoped to it. It fails when s or transport is nil,
// or when transport has no logical connection.
func (s *ServerCommon) SendTransferTransport(ctx context.Context, transport *TransportConn, opt TransferSendOptions) (TransferHandle, error) {
	if s == nil {
		return nil, errTransferControlServerNil
	}
	if transport == nil {
		return nil, errTransferControlTransportNil
	}
	logical := transport.LogicalConn()
	if logical == nil {
		return nil, errTransferControlLogicalConnNil
	}
	var target transferSendTarget
	target.runtime = s.getTransferRuntime()
	target.runtimeScope = serverTransportScopeForTransport(transport)
	target.publicScope = serverFileScope(logical)
	target.transportGeneration = transport.TransportGeneration()
	target.logical = logical
	target.transport = transport
	target.sequenceEn = s.sequenceEn
	target.sequenceDe = s.sequenceDe
	target.openStream = func(callCtx context.Context, streamOpt StreamOpenOptions) (Stream, error) {
		return s.OpenStreamTransport(callCtx, transport, streamOpt)
	}
	target.sendBegin = func(callCtx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) {
		return SendTransferBeginTransport(callCtx, s, transport, req)
	}
	target.sendResume = func(callCtx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) {
		return SendTransferResumeTransport(callCtx, s, transport, req)
	}
	target.sendCommit = func(callCtx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) {
		return SendTransferCommitTransport(callCtx, s, transport, req)
	}
	target.sendAbort = func(callCtx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) {
		return SendTransferAbortTransport(callCtx, s, transport, req)
	}
	return startTransferSendWithHooks(ctx, opt, target, transferSendHooks{})
}
// startTransferSend is startTransferSendWithHooks with no hooks installed.
func startTransferSend(ctx context.Context, opt TransferSendOptions, target transferSendTarget) (TransferHandle, error) {
	var noHooks transferSendHooks
	return startTransferSendWithHooks(ctx, opt, target, noHooks)
}
// startTransferSendWithHooks negotiates a transfer with the peer and starts
// the background data phase.
//
// Steps:
//  1. Normalize opt.
//  2. Resume when the runtime already tracks a non-terminal send with this
//     ID; otherwise begin a new transfer.
//  3. Open a dedicated stream.
//  4. Launch runTransferSend in a goroutine whose result finishes the
//     returned handle.
//
// Errors before the goroutine starts are reported through hooks.onAbort with
// stage "prepare", "resume", "begin" or "stream.open". The exceptions are the
// already-completed and already-aborted cases, which return without calling
// any hook.
//
// A nil ctx is treated as context.Background() for every call, not only the
// background run. The derived run context is cancelled once the send
// goroutine exits, so it does not stay registered with ctx for ctx's
// lifetime.
func startTransferSendWithHooks(ctx context.Context, opt TransferSendOptions, target transferSendTarget, hooks transferSendHooks) (TransferHandle, error) {
	ctx = transferRunContext(ctx)
	opt, err := normalizeTransferSendOptions(opt)
	if err != nil {
		if hooks.onAbort != nil {
			hooks.onAbort("prepare", 0, err)
		}
		return nil, err
	}
	desc := opt.Descriptor
	if target.sequenceEn == nil {
		recordTransferDataFailure(target.runtime, target.runtimeScope, desc.ID, errTransferSequenceEncodeNil)
		if hooks.onAbort != nil {
			hooks.onAbort("prepare", 0, errTransferSequenceEncodeNil)
		}
		return nil, errTransferSequenceEncodeNil
	}
	// A known, unfinished send with this ID is resumed rather than begun again.
	useResume := false
	if snapshot, ok := target.runtime.snapshot(fileTransferDirectionSend, target.runtimeScope, desc.ID); ok {
		switch snapshot.State {
		case itransfer.StateDone:
			return nil, errTransferAlreadyCompleted
		case itransfer.StateAborted:
			return nil, errTransferAlreadyAborted
		default:
			useResume = true
		}
	}
	var nextOffset int64
	if useResume {
		resp, err := target.sendResume(ctx, TransferResumeRequest{TransferID: desc.ID})
		if err != nil {
			if hooks.onAbort != nil {
				hooks.onAbort("resume", 0, err)
			}
			return nil, err
		}
		nextOffset = transferControlOffset(resp.NextOffset)
	} else {
		resp, err := target.sendBegin(ctx, TransferBeginRequest{
			TransferID: desc.ID,
			Channel:    desc.Channel,
			Size:       desc.Size,
			Checksum:   desc.Checksum,
			Metadata:   cloneTransferMetadata(desc.Metadata),
		})
		if err != nil {
			if hooks.onAbort != nil {
				hooks.onAbort("begin", 0, err)
			}
			return nil, err
		}
		nextOffset = transferControlOffset(resp.NextOffset)
	}
	if hooks.onNegotiated != nil {
		hooks.onNegotiated(nextOffset, useResume)
	}
	recordTransferSendRuntimeOptions(target.runtime, target.runtimeScope, opt)
	stream, err := target.openStream(ctx, StreamOpenOptions{
		Channel:  transferChannelToStreamChannel(desc.Channel),
		Metadata: buildTransferStreamMetadata(desc.ID),
	})
	if err != nil {
		recordTransferDataFailure(target.runtime, target.runtimeScope, desc.ID, err)
		if hooks.onAbort != nil {
			hooks.onAbort("stream.open", nextOffset, err)
		}
		return nil, err
	}
	runCtx, cancel := context.WithCancel(ctx)
	handle := newTransferSendHandle(desc.ID, target.runtime, target.runtimeScope, target.logical, target.transport, cancel, func(ctx context.Context, req TransferAbortRequest) error {
		_, err := target.sendAbort(ctx, req)
		return err
	})
	handle.setStream(stream)
	go func() {
		// Release runCtx once the data phase is over. Abort may also call cancel; that is harmless.
		defer cancel()
		err := runTransferSend(runCtx, stream, opt, nextOffset, target, hooks)
		handle.clearStream(stream)
		handle.finish(err)
	}()
	return handle, nil
}
// runTransferSend is the background data phase of a send. It streams the
// segments from nextOffset, closes the stream, and commits the transfer.
//
// Data and close failures are recorded on the runtime. Every failure is
// reported through hooks.onAbort with stage "data", "stream.close" or
// "commit". On success the runtime's acked bytes are set to the descriptor
// size and hooks.onCommitted fires.
func runTransferSend(ctx context.Context, stream Stream, opt TransferSendOptions, nextOffset int64, target transferSendTarget, hooks transferSendHooks) error {
	desc := opt.Descriptor
	fail := func(stage string, offset int64, cause error) error {
		if hooks.onAbort != nil {
			hooks.onAbort(stage, offset, cause)
		}
		return cause
	}

	sendErr := sendTransferSegments(ctx, stream, target, opt, nextOffset, hooks)
	// The stream is closed even when sending failed; a close error only matters on success.
	closeErr := stream.Close()
	switch {
	case sendErr != nil:
		recordTransferDataFailure(target.runtime, target.runtimeScope, desc.ID, sendErr)
		return fail("data", nextOffset, sendErr)
	case closeErr != nil && !errors.Is(closeErr, errStreamNotFound):
		// errStreamNotFound (stream already gone) is not treated as a failure.
		recordTransferDataFailure(target.runtime, target.runtimeScope, desc.ID, closeErr)
		return fail("stream.close", desc.Size, closeErr)
	}

	started := time.Now()
	_, commitErr := target.sendCommit(ctx, TransferCommitRequest{
		TransferID: desc.ID,
		Size:       desc.Size,
		Checksum:   desc.Checksum,
	})
	if target.runtime != nil {
		target.runtime.recordCommitWaitDuration(fileTransferDirectionSend, target.runtimeScope, desc.ID, time.Since(started))
	}
	if commitErr != nil {
		return fail("commit", desc.Size, commitErr)
	}
	if target.runtime != nil {
		target.runtime.setAckedBytes(fileTransferDirectionSend, target.runtimeScope, desc.ID, desc.Size)
	}
	if hooks.onCommitted != nil {
		hooks.onCommitted()
	}
	return nil
}
// sendTransferSegments normalizes opt and records it on the runtime. It
// clamps a negative start offset to 0, rejects a start beyond the descriptor
// size with errTransferSizeMismatch, and then hands off to the windowed
// sender (Parallelism > 1) or the serial sender. A nil ctx means
// context.Background().
func sendTransferSegments(ctx context.Context, stream Stream, target transferSendTarget, opt TransferSendOptions, nextOffset int64, hooks transferSendHooks) error {
	if ctx == nil {
		ctx = context.Background()
	}
	normalized, err := normalizeTransferSendOptions(opt)
	if err != nil {
		return err
	}
	recordTransferSendRuntimeOptions(target.runtime, target.runtimeScope, normalized)
	start := nextOffset
	if start < 0 {
		start = 0
	}
	if start > normalized.Descriptor.Size {
		return errTransferSizeMismatch
	}
	if normalized.Parallelism <= 1 {
		return sendTransferSegmentsSerial(ctx, stream, target, normalized, start, hooks)
	}
	return sendTransferSegmentsWindowed(ctx, stream, target, normalized, start, hooks)
}
// recordTransferSendRuntimeOptions stores the effective chunk size,
// parallelism and in-flight limit for a send on the runtime. It is skipped
// when there is no runtime or no transfer ID.
func recordTransferSendRuntimeOptions(runtime *transferRuntime, runtimeScope string, opt TransferSendOptions) {
	id := opt.Descriptor.ID
	if runtime != nil && id != "" {
		runtime.recordSendOptions(fileTransferDirectionSend, runtimeScope, id, opt.ChunkSize, opt.Parallelism, opt.MaxInflightBytes)
	}
}
// dispatchInternalTransferControl handles message when transfer control is
// enabled and the key is one of the internal transfer control keys. It
// reports whether the message was consumed.
func (c *ClientCommon) dispatchInternalTransferControl(message Message) bool {
	state := c.getTransferState()
	handled := state.controlEnabledSnapshot() && isInternalTransferControlKey(message.Key)
	if handled {
		// The client has a single scope and no per-peer connection handles.
		dispatchInternalTransferControlMessage(&message, state, c.getTransferRuntime(), clientFileScope(), clientFileScope(), 0, nil, nil)
	}
	return handled
}
// dispatchInternalTransferControl handles message when transfer control is
// enabled and the key is an internal transfer control key. Scopes and
// generation are derived from the message's logical/transport connection.
// It reports whether the message was consumed.
func (s *ServerCommon) dispatchInternalTransferControl(message Message) bool {
	state := s.getTransferState()
	if !(state.controlEnabledSnapshot() && isInternalTransferControlKey(message.Key)) {
		return false
	}
	logical, transport := messageLogicalConnSnapshot(&message), messageTransportConnSnapshot(&message)
	runtime := s.getTransferRuntime()
	runtimeScope := transferRuntimeScopeForPeer(logical, transport)
	publicScope := transferPublicScopeForPeer(logical)
	generation := transferGenerationForPeer(logical, transport)
	dispatchInternalTransferControlMessage(&message, state, runtime, runtimeScope, publicScope, generation, logical, transport)
	return true
}
// dispatchInternalTransferControlMessage handles one inbound transfer control
// message (begin, resume, commit or abort) on the receiving side.
//
// For each key it:
//   - decodes the request with msg.Value.Orm;
//   - runs the matching transferControlPrepare* bookkeeping;
//   - delegates to the transferState accept* method;
//   - fills in TransferID/Error on the response when the accept call left
//     them empty;
//   - records the outcome with transferControlFinish*;
//   - replies via replyTransferControlIfNeeded.
//
// When decoding fails, only the decode error is replied and no
// prepare/finish bookkeeping runs. Other keys, a nil msg and a nil state are
// ignored.
func dispatchInternalTransferControlMessage(msg *Message, state *transferState, runtime *transferRuntime, runtimeScope string, publicScope string, transportGeneration uint64, logical *LogicalConn, transport *TransportConn) {
if msg == nil || state == nil {
return
}
switch msg.Key {
// Begin: rebind a live session, restore a persisted one, or create a new one.
case TransferBeginSignalKey:
var req TransferBeginRequest
resp := TransferBeginResponse{}
if err := msg.Value.Orm(&req); err != nil {
resp.Error = err.Error()
replyTransferControlIfNeeded(msg, resp)
return
}
transferControlPrepareBegin(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req)
resp, err := state.acceptBegin(runtime, req, runtimeScope, publicScope, transportGeneration, logical, transport)
if resp.TransferID == "" {
resp.TransferID = req.TransferID
}
if err != nil && resp.Error == "" {
resp.Error = err.Error()
}
transferControlFinishBegin(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, resp, transferControlResultError("begin", resp.Accepted, resp.Error, err))
replyTransferControlIfNeeded(msg, resp)
// Resume: rebind the live session or restore it from a resumable runtime snapshot.
case TransferResumeSignalKey:
var req TransferResumeRequest
resp := TransferResumeResponse{}
if err := msg.Value.Orm(&req); err != nil {
resp.Error = err.Error()
replyTransferControlIfNeeded(msg, resp)
return
}
transferControlPrepareResume(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req)
resp, err := state.acceptResume(runtime, req, publicScope, runtimeScope, logical, transport, transportGeneration)
if resp.TransferID == "" {
resp.TransferID = req.TransferID
}
if err != nil && resp.Error == "" {
resp.Error = err.Error()
}
transferControlFinishResume(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, resp, transferControlResultError("resume", resp.Accepted, resp.Error, err))
replyTransferControlIfNeeded(msg, resp)
// Commit: finalize the session (bounded by transferCommitWaitTimeout) and drop it on success.
case TransferCommitSignalKey:
var req TransferCommitRequest
resp := TransferCommitResponse{}
if err := msg.Value.Orm(&req); err != nil {
resp.Error = err.Error()
replyTransferControlIfNeeded(msg, resp)
return
}
transferControlPrepareCommit(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req)
resp, err := state.acceptCommit(runtime, req, publicScope)
if resp.TransferID == "" {
resp.TransferID = req.TransferID
}
if err != nil && resp.Error == "" {
resp.Error = err.Error()
}
transferControlFinishCommit(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, resp, transferControlResultError("commit", resp.Accepted, resp.Error, err))
replyTransferControlIfNeeded(msg, resp)
// Abort: remove and close the session; unknown IDs are still acknowledged.
case TransferAbortSignalKey:
var req TransferAbortRequest
resp := TransferAbortResponse{}
if err := msg.Value.Orm(&req); err != nil {
resp.Error = err.Error()
replyTransferControlIfNeeded(msg, resp)
return
}
transferControlPrepareAbort(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req)
resp, err := state.acceptAbort(req, publicScope)
if resp.TransferID == "" {
resp.TransferID = req.TransferID
}
if err != nil && resp.Error == "" {
resp.Error = err.Error()
}
transferControlFinishAbort(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, req, resp, transferControlResultError("abort", resp.Accepted, resp.Error, err))
replyTransferControlIfNeeded(msg, resp)
}
}
// acceptBegin answers a begin request on the receiving side. The first
// matching case wins:
//   - A live session with the same ID is rebound when its descriptor is
//     compatible; otherwise the request is rejected.
//   - A session restored from persisted state is accepted. A restore error
//     is only reported when restoration applied.
//   - Otherwise the transfer handler is consulted, its options are
//     normalized, and a new session is stored. If storing fails and a
//     compatible session now exists (presumably from a concurrent begin),
//     that session is adopted instead.
//
// Accepted responses carry the session's next offset.
func (state *transferState) acceptBegin(runtime *transferRuntime, req TransferBeginRequest, runtimeScope string, publicScope string, transportGeneration uint64, logical *LogicalConn, transport *TransportConn) (TransferBeginResponse, error) {
	id := req.TransferID
	if id == "" {
		return TransferBeginResponse{}, errTransferIDEmpty
	}
	rejected := TransferBeginResponse{TransferID: id}
	desc := transferDescriptorFromBegin(req)

	if live, found := state.load(publicScope, id); found {
		if !transferDescriptorsCompatible(live.descriptorSnapshot(), desc) {
			return rejected, fmt.Errorf("transfer descriptor mismatch")
		}
		live.updateBinding(runtimeScope, logical, transport, transportGeneration)
		return TransferBeginResponse{TransferID: id, Accepted: true, NextOffset: live.nextOffsetSnapshot()}, nil
	}

	restored, applied, restoreErr := state.restoreReceiveSession(runtime, publicScope, runtimeScope, logical, transport, transportGeneration, desc)
	if applied {
		if restoreErr != nil {
			return rejected, restoreErr
		}
		return TransferBeginResponse{TransferID: id, Accepted: true, NextOffset: restored.nextOffsetSnapshot()}, nil
	}

	opt, err := state.acceptOptions(TransferAcceptInfo{
		Descriptor:          cloneTransferDescriptor(desc),
		LogicalConn:         logical,
		TransportConn:       transport,
		TransportGeneration: transportGeneration,
	})
	if err != nil {
		return rejected, err
	}
	if opt, err = normalizeTransferReceiveOptions(desc, opt); err != nil {
		return rejected, err
	}

	fresh := newTransferReceiveSession(publicScope, runtimeScope, logical, transport, transportGeneration, opt)
	if storeErr := state.store(publicScope, id, fresh); storeErr != nil {
		if winner, found := state.load(publicScope, id); found && transferDescriptorsCompatible(winner.descriptorSnapshot(), desc) {
			winner.updateBinding(runtimeScope, logical, transport, transportGeneration)
			return TransferBeginResponse{TransferID: id, Accepted: true, NextOffset: winner.nextOffsetSnapshot()}, nil
		}
		return rejected, storeErr
	}
	return TransferBeginResponse{TransferID: id, Accepted: true, NextOffset: fresh.nextOffsetSnapshot()}, nil
}
// acceptResume answers a resume request. It uses the live session when one
// exists. Otherwise it restores the session from the runtime's resumable
// snapshot, failing with errTransferSessionNotFound when neither is
// available. The session is rebound to the caller's connection, and the
// response carries its next offset.
func (state *transferState) acceptResume(runtime *transferRuntime, req TransferResumeRequest, publicScope string, runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64) (TransferResumeResponse, error) {
	id := req.TransferID
	if id == "" {
		return TransferResumeResponse{}, errTransferIDEmpty
	}
	session, live := state.load(publicScope, id)
	if !live {
		snapshot, found := runtime.resumableSnapshot(fileTransferDirectionReceive, publicScope, id)
		if !found {
			return TransferResumeResponse{TransferID: id}, errTransferSessionNotFound
		}
		restored, ok, err := state.restoreReceiveSession(runtime, publicScope, runtimeScope, logical, transport, transportGeneration, transferDescriptorFromSnapshot(snapshot))
		switch {
		case err != nil:
			return TransferResumeResponse{TransferID: id}, err
		case !ok || restored == nil:
			return TransferResumeResponse{TransferID: id}, errTransferSessionNotFound
		}
		session = restored
	}
	session.updateBinding(runtimeScope, logical, transport, transportGeneration)
	return TransferResumeResponse{
		TransferID: id,
		Accepted:   true,
		NextOffset: session.nextOffsetSnapshot(),
	}, nil
}
// acceptCommit commits the live session for req.TransferID, waiting at most
// transferCommitWaitTimeout. On success the session is removed from the
// state. An unknown ID yields errTransferSessionNotFound.
func (state *transferState) acceptCommit(runtime *transferRuntime, req TransferCommitRequest, publicScope string) (TransferCommitResponse, error) {
	id := req.TransferID
	if id == "" {
		return TransferCommitResponse{}, errTransferIDEmpty
	}
	session, found := state.load(publicScope, id)
	if !found {
		return TransferCommitResponse{TransferID: id}, errTransferSessionNotFound
	}
	waitCtx, stop := context.WithTimeout(context.Background(), transferCommitWaitTimeout)
	defer stop()
	if err := session.commit(waitCtx, runtime, id); err != nil {
		return TransferCommitResponse{TransferID: id}, err
	}
	state.remove(publicScope, id)
	return TransferCommitResponse{TransferID: id, Accepted: true}, nil
}
// acceptAbort removes the session for req.TransferID and, if one existed,
// closes it with an error built from req.Error. The abort is acknowledged
// even when no session was found.
func (state *transferState) acceptAbort(req TransferAbortRequest, publicScope string) (TransferAbortResponse, error) {
	id := req.TransferID
	if id == "" {
		return TransferAbortResponse{}, errTransferIDEmpty
	}
	removed := state.remove(publicScope, id)
	if removed != nil {
		removed.close(transferControlAbortError(req.Error))
	}
	return TransferAbortResponse{TransferID: id, Accepted: true}, nil
}
// replyTransferControlIfNeeded sends value back to the requester, but only
// when the message expects a reply. The reply is best-effort: its error is
// discarded.
func replyTransferControlIfNeeded(msg *Message, value interface{}) {
	if msg != nil && requiresSignalReplyWait(msg.TransferMsg) {
		_ = msg.ReplyObj(value)
	}
}
// isInternalTransferControlKey reports whether key is one of the four
// internal transfer control signal keys (begin, resume, commit, abort).
func isInternalTransferControlKey(key string) bool {
	return key == TransferBeginSignalKey ||
		key == TransferResumeSignalKey ||
		key == TransferCommitSignalKey ||
		key == TransferAbortSignalKey
}
// claimInboundTransferStream offers an inbound stream to the client's
// transfer plane, using the single client scope and no connection handles.
func (c *ClientCommon) claimInboundTransferStream(stream *streamHandle) (bool, error) {
	state, runtime := c.getTransferState(), c.getTransferRuntime()
	return claimInboundTransferStream(stream, state, runtime, clientFileScope(), clientFileScope(), nil, nil, c.sequenceDe)
}
// claimInboundTransferStream offers an inbound stream from a peer to the
// server's transfer plane. Streams without a logical connection are never
// claimed.
func (s *ServerCommon) claimInboundTransferStream(logical *LogicalConn, transport *TransportConn, stream *streamHandle) (bool, error) {
	if logical == nil {
		return false, nil
	}
	state := s.getTransferState()
	runtime := s.getTransferRuntime()
	publicScope := serverFileScope(logical)
	runtimeScope := transferRuntimeScopeForPeer(logical, transport)
	return claimInboundTransferStream(stream, state, runtime, publicScope, runtimeScope, logical, transport, s.sequenceDe)
}
// claimInboundTransferStream decides whether stream carries transfer
// segments and, if so, starts receiving it in the background.
//
// It returns claimed=false for streams it does not own: nil stream or state,
// control disabled, or metadata without the transfer tag. Once the metadata
// marks a transfer stream, any problem (unknown session, beginStream
// failure) is returned with claimed=true.
//
// The receive goroutine resets the stream on errors other than
// context.Canceled/io.EOF, records a "data" failure stage on the runtime,
// and always reports completion through session.finishStream.
func claimInboundTransferStream(stream *streamHandle, state *transferState, runtime *transferRuntime, publicScope string, runtimeScope string, logical *LogicalConn, transport *TransportConn, sequenceDe func([]byte) (interface{}, error)) (bool, error) {
	if stream == nil || state == nil || !state.controlEnabledSnapshot() {
		return false, nil
	}
	transferID, isTransfer := parseTransferStreamMetadata(stream.Metadata())
	if !isTransfer {
		return false, nil
	}
	session, found := state.load(publicScope, transferID)
	if !found {
		return true, errTransferSessionNotFound
	}
	if err := session.beginStream(stream.ID(), runtimeScope, logical, transport, stream.TransportGeneration()); err != nil {
		return true, err
	}
	go func() {
		recvErr := receiveTransferStream(stream, session, transferID, sequenceDe, runtime)
		if recvErr != nil {
			benign := errors.Is(recvErr, context.Canceled) || errors.Is(recvErr, io.EOF)
			if !benign {
				_ = stream.Reset(recvErr)
			}
			if runtime != nil {
				runtime.recordFailureStage(fileTransferDirectionReceive, session.runtimeScopeSnapshot(), transferID, "data")
			}
		}
		session.finishStream(stream.ID(), recvErr)
	}()
	return true, nil
}
// receiveTransferStream reads frames from stream until a clean io.EOF. Each
// frame is decoded into a segment and written into session.
//
// It stops with an error on a nil decoder, a read or decode failure, a
// segment addressed to another transfer ID, or a session write failure.
func receiveTransferStream(stream Stream, session *transferReceiveSession, transferID string, sequenceDe func([]byte) (interface{}, error), runtime *transferRuntime) error {
	if sequenceDe == nil {
		return errTransferSequenceDecodeNil
	}
	for {
		frame, err := readTransferFrame(stream)
		switch {
		case errors.Is(err, io.EOF):
			return nil
		case err != nil:
			return err
		}
		segment, err := decodeTransferSegment(frame, sequenceDe)
		if err != nil {
			return err
		}
		if segment.TransferID != transferID {
			return fmt.Errorf("transfer id mismatch")
		}
		if err = session.writeSegment(runtime, transferID, segment.Offset, segment.Payload); err != nil {
			return err
		}
	}
}
// decodeTransferSegment decodes payload with sequenceDe and requires the
// result to be an itransfer.Segment value.
func decodeTransferSegment(payload []byte, sequenceDe func([]byte) (interface{}, error)) (itransfer.Segment, error) {
	decoded, err := sequenceDe(payload)
	if err != nil {
		return itransfer.Segment{}, err
	}
	if segment, ok := decoded.(itransfer.Segment); ok {
		return segment, nil
	}
	return itransfer.Segment{}, errors.New("invalid transfer segment payload")
}
// buildTransferFrame prefixes payload with its length as a big-endian uint32
// header of transferFrameHeaderSize bytes. The result is a new slice; payload
// is copied, not aliased.
func buildTransferFrame(payload []byte) []byte {
	frame := make([]byte, transferFrameHeaderSize, transferFrameHeaderSize+len(payload))
	binary.BigEndian.PutUint32(frame, uint32(len(payload)))
	return append(frame, payload...)
}
// writeTransferFrames writes all of data to stream, retrying after partial
// writes. A write error is returned as-is. A write that makes no progress
// without an error yields io.ErrShortWrite.
func writeTransferFrames(stream Stream, data []byte) error {
	for remaining := data; len(remaining) != 0; {
		written, err := stream.Write(remaining)
		if written > 0 {
			remaining = remaining[written:]
		}
		switch {
		case err != nil:
			return err
		case written == 0:
			return io.ErrShortWrite
		}
	}
	return nil
}
// readTransferFrame reads one length-prefixed frame from stream: a
// transferFrameHeaderSize-byte big-endian length, then that many payload
// bytes.
//
// It returns io.EOF only when the stream ends cleanly before a new header
// starts; receiveTransferStream treats that as the normal end of a transfer
// stream. A stream that ends partway through a header, or after a header but
// before the full payload, yields io.ErrUnexpectedEOF. io.ReadFull alone
// reports plain io.EOF when zero payload bytes follow a header, which would
// let a truncated frame pass as a clean end.
func readTransferFrame(stream Stream) ([]byte, error) {
	header := make([]byte, transferFrameHeaderSize)
	if n, err := io.ReadFull(stream, header); err != nil {
		if n > 0 && errors.Is(err, io.EOF) {
			err = io.ErrUnexpectedEOF
		}
		return nil, err
	}
	length := binary.BigEndian.Uint32(header)
	payload := make([]byte, int(length))
	if _, err := io.ReadFull(stream, payload); err != nil {
		// A header was read, so any EOF here means the frame was truncated.
		if errors.Is(err, io.EOF) {
			err = io.ErrUnexpectedEOF
		}
		return nil, err
	}
	return payload, nil
}
// transferConfirmProgress asks the peer, through a resume request, how far
// the transfer has progressed. It returns 0 when the target cannot resume.
func transferConfirmProgress(ctx context.Context, target transferSendTarget, transferID string) (int64, error) {
	resume := target.sendResume
	if resume == nil {
		return 0, nil
	}
	resp, err := resume(ctx, TransferResumeRequest{TransferID: transferID})
	if err != nil {
		return 0, err
	}
	return transferControlOffset(resp.NextOffset), nil
}
// normalizeTransferSendOptions validates opt and fills in defaults before a
// send.
//
// Validation and defaults:
//   - Source must be non-nil and the descriptor ID non-empty.
//   - Negative descriptor or source sizes yield errTransferSizeInvalid; a
//     descriptor size of 0 takes the source size.
//   - ChunkSize defaults to defaultFileChunkSize and Parallelism to 1.
//   - MaxInflightBytes defaults to ChunkSize*Parallelism and is never below
//     one chunk.
//   - Sequential-reader sources cannot be sent in parallel and cannot have
//     their checksum computed up front.
//
// With VerifyChecksum set and no checksum supplied, the SHA-256 of the first
// desc.Size bytes of Source is computed here.
func normalizeTransferSendOptions(opt TransferSendOptions) (TransferSendOptions, error) {
	if opt.Source == nil {
		return TransferSendOptions{}, errTransferSourceNil
	}
	desc := normalizeTransferDescriptor(opt.Descriptor)
	if desc.ID == "" {
		return TransferSendOptions{}, errTransferIDEmpty
	}
	if desc.Size < 0 {
		return TransferSendOptions{}, errTransferSizeInvalid
	}
	sourceSize := opt.Source.Size()
	if sourceSize < 0 {
		return TransferSendOptions{}, errTransferSizeInvalid
	}
	if desc.Size == 0 {
		desc.Size = sourceSize
	}
	if opt.ChunkSize <= 0 {
		opt.ChunkSize = defaultFileChunkSize
	}
	if opt.Parallelism <= 0 {
		opt.Parallelism = 1
	}
	if opt.MaxInflightBytes <= 0 {
		// Multiply in int64: ChunkSize*Parallelism in int can overflow (e.g. on 32-bit builds) and wrap.
		opt.MaxInflightBytes = int64(opt.ChunkSize) * int64(opt.Parallelism)
	}
	if opt.MaxInflightBytes < int64(opt.ChunkSize) {
		opt.MaxInflightBytes = int64(opt.ChunkSize)
	}
	if _, ok := opt.Source.(transferSequentialReaderSource); ok {
		if opt.Parallelism > 1 {
			return TransferSendOptions{}, errTransferSequentialSourceParallelism
		}
		if opt.VerifyChecksum && desc.Checksum == "" {
			return TransferSendOptions{}, errTransferSequentialSourceImplicitChecksum
		}
	}
	if opt.VerifyChecksum && desc.Checksum == "" {
		sum, err := computeTransferChecksum(opt.Source, desc.Size)
		if err != nil {
			return TransferSendOptions{}, err
		}
		desc.Checksum = sum
	}
	opt.Descriptor = desc
	return opt, nil
}
// normalizeTransferReceiveOptions checks the handler-supplied receive
// options against the sender's descriptor.
//
// A Sink is required. Any ID, channel, size (> 0) or checksum the handler
// set must agree with desc. The result's descriptor is a copy of desc, with
// the handler's metadata entries laid over it.
func normalizeTransferReceiveOptions(desc TransferDescriptor, opt TransferReceiveOptions) (TransferReceiveOptions, error) {
	if opt.Sink == nil {
		return TransferReceiveOptions{}, errTransferSinkNil
	}
	want := opt.Descriptor
	switch {
	case want.ID != "" && want.ID != desc.ID:
		return TransferReceiveOptions{}, fmt.Errorf("transfer id mismatch")
	case want.Channel != "" && normalizeTransferChannel(want.Channel) != normalizeTransferChannel(desc.Channel):
		return TransferReceiveOptions{}, fmt.Errorf("transfer channel mismatch")
	case want.Size > 0 && want.Size != desc.Size:
		return TransferReceiveOptions{}, fmt.Errorf("transfer size mismatch")
	case want.Checksum != "" && !equalChecksum(want.Checksum, desc.Checksum):
		return TransferReceiveOptions{}, fmt.Errorf("transfer checksum mismatch")
	}
	merged := cloneTransferDescriptor(desc)
	if len(want.Metadata) > 0 {
		if merged.Metadata == nil {
			merged.Metadata = make(map[string]string, len(want.Metadata))
		}
		for k, v := range want.Metadata {
			merged.Metadata[k] = v
		}
	}
	opt.Descriptor = merged
	return opt, nil
}
// normalizeTransferDescriptor returns a copy of desc with its channel
// normalized and its metadata passed through cloneTransferMetadata.
func normalizeTransferDescriptor(desc TransferDescriptor) TransferDescriptor {
	normalized := desc
	normalized.Channel = normalizeTransferChannel(desc.Channel)
	normalized.Metadata = cloneTransferMetadata(desc.Metadata)
	return normalized
}
// cloneTransferDescriptor copies desc, with metadata passed through
// cloneTransferMetadata. It is the same operation as
// normalizeTransferDescriptor, so the channel is normalized too.
func cloneTransferDescriptor(desc TransferDescriptor) TransferDescriptor {
	clone := normalizeTransferDescriptor(desc)
	return clone
}
// normalizeTransferChannel maps the empty channel to TransferChannelData and
// returns every other channel, known or custom, unchanged.
func normalizeTransferChannel(channel TransferChannel) TransferChannel {
	if channel == "" {
		return TransferChannelData
	}
	return channel
}
// transferChannelToStreamChannel routes control transfers to the stream
// control channel. Everything else (including "" and custom channels) goes
// to the data channel.
func transferChannelToStreamChannel(channel TransferChannel) StreamChannel {
	if normalizeTransferChannel(channel) == TransferChannelControl {
		return StreamControlChannel
	}
	return StreamDataChannel
}
// transferChannelToKernel maps a transfer channel onto the internal transfer
// kernel's channel. Control maps to ControlChannel; anything else maps to
// DataChannel.
func transferChannelToKernel(channel TransferChannel) itransfer.Channel {
	if normalizeTransferChannel(channel) == TransferChannelControl {
		return itransfer.ControlChannel
	}
	return itransfer.DataChannel
}
// transferDescriptorFromBegin builds the normalized descriptor a begin
// request describes. normalizeTransferDescriptor already passes Metadata
// through cloneTransferMetadata, so the request's map is handed over
// directly instead of being cloned twice.
func transferDescriptorFromBegin(req TransferBeginRequest) TransferDescriptor {
	return normalizeTransferDescriptor(TransferDescriptor{
		ID:       req.TransferID,
		Channel:  req.Channel,
		Size:     req.Size,
		Checksum: req.Checksum,
		Metadata: req.Metadata,
	})
}
// transferDescriptorsCompatible reports whether two descriptors name the
// same transfer. They must share ID, size, normalized channel and checksum
// (compared with equalChecksum). Metadata is not compared.
func transferDescriptorsCompatible(left TransferDescriptor, right TransferDescriptor) bool {
	if left.ID != right.ID || left.Size != right.Size {
		return false
	}
	if normalizeTransferChannel(left.Channel) != normalizeTransferChannel(right.Channel) {
		return false
	}
	return equalChecksum(left.Checksum, right.Checksum)
}
// buildTransferStreamMetadata returns the stream metadata that marks a
// stream as carrying segments for transferID.
func buildTransferStreamMetadata(transferID string) StreamMetadata {
	metadata := make(StreamMetadata, 2)
	metadata[transferStreamMetadataKindKey] = transferStreamMetadataKindValue
	metadata[transferMetadataIDKey] = transferID
	return metadata
}
// parseTransferStreamMetadata extracts the transfer ID from stream metadata.
// It succeeds only when the kind tag equals transferStreamMetadataKindValue
// and the ID is non-blank. The returned ID has surrounding whitespace
// trimmed.
func parseTransferStreamMetadata(metadata StreamMetadata) (string, bool) {
	if len(metadata) == 0 || metadata[transferStreamMetadataKindKey] != transferStreamMetadataKindValue {
		return "", false
	}
	if id := strings.TrimSpace(metadata[transferMetadataIDKey]); id != "" {
		return id, true
	}
	return "", false
}
// recordTransferDataFailure marks a send as failed in the "data" stage. It
// sets both the current and the failure stage and records err on the
// runtime. It does nothing without a runtime, a transfer ID, or an error.
func recordTransferDataFailure(runtime *transferRuntime, runtimeScope string, transferID string, err error) {
	if runtime != nil && transferID != "" && err != nil {
		const stage = "data"
		runtime.recordStage(fileTransferDirectionSend, runtimeScope, transferID, stage)
		runtime.recordFailureStage(fileTransferDirectionSend, runtimeScope, transferID, stage)
		runtime.fail(fileTransferDirectionSend, runtimeScope, transferID, err)
	}
}
// transferRunContext returns ctx, or context.Background() when ctx is nil.
func transferRunContext(ctx context.Context) context.Context {
	if ctx != nil {
		return ctx
	}
	return context.Background()
}
// transferRuntimeScopeForPeer picks the runtime scope for a peer. A known
// transport takes precedence over the logical connection, and
// defaultFileScope is used when neither is known.
func transferRuntimeScopeForPeer(logical *LogicalConn, transport *TransportConn) string {
	switch {
	case transport != nil:
		return serverTransportScopeForTransport(transport)
	case logical != nil:
		return serverTransportScope(logical)
	default:
		return defaultFileScope
	}
}
// transferPublicScopeForPeer returns the peer's file scope, or
// defaultFileScope when there is no logical connection.
func transferPublicScopeForPeer(logical *LogicalConn) string {
	if logical != nil {
		return serverFileScope(logical)
	}
	return defaultFileScope
}
// transferGenerationForPeer reports the transport generation for a peer. It
// prefers the explicit transport, falls back to the logical connection's
// current generation, and returns 0 when neither is known.
func transferGenerationForPeer(logical *LogicalConn, transport *TransportConn) uint64 {
	switch {
	case transport != nil:
		return transport.TransportGeneration()
	case logical != nil:
		return logical.transportGenerationSnapshot()
	default:
		return 0
	}
}
// computeTransferChecksum returns the hex-encoded SHA-256 of the first size
// bytes of reader. It reads in transferChecksumChunkSize pieces.
//
// Errors:
//   - a negative size yields errTransferSizeInvalid;
//   - a reader that ends early returns its error (e.g. io.EOF);
//   - a read that returns no bytes and no error yields io.ErrNoProgress.
func computeTransferChecksum(reader io.ReaderAt, size int64) (string, error) {
	if size < 0 {
		return "", errTransferSizeInvalid
	}
	hasher := sha256.New()
	chunk := make([]byte, transferChecksumChunkSize)
	var offset int64
	for offset < size {
		want := int64(len(chunk))
		if left := size - offset; left < want {
			want = left
		}
		n, err := reader.ReadAt(chunk[:want], offset)
		if n > 0 {
			_, _ = hasher.Write(chunk[:n]) // hash.Hash writes never fail
			offset += int64(n)
		}
		if err != nil {
			// io.EOF together with the final bytes is a successful end.
			if errors.Is(err, io.EOF) && offset == size {
				break
			}
			return "", err
		}
		if n == 0 && offset < size {
			return "", io.ErrNoProgress
		}
	}
	return hex.EncodeToString(hasher.Sum(nil)), nil
}
// equalChecksum compares two checksums after normalizeChecksum, ignoring
// case. An empty checksum equals only another empty checksum. EqualFold
// already behaves that way for empty strings, so no special case is needed.
func equalChecksum(got string, want string) bool {
	return strings.EqualFold(normalizeChecksum(got), normalizeChecksum(want))
}
// normalizeChecksum trims surrounding whitespace and strips one leading
// "sha256:" algorithm tag, leaving the bare digest.
//
// The tag is matched case-insensitively, so "sha256:", "SHA256:" and
// "Sha256:" are all accepted; previously only the first two spellings were
// stripped. Digest case is left untouched; equalChecksum compares
// case-insensitively.
func normalizeChecksum(value string) string {
	const tag = "sha256:"
	value = strings.TrimSpace(value)
	if len(value) >= len(tag) && strings.EqualFold(value[:len(tag)], tag) {
		value = value[len(tag):]
	}
	return value
}