package notify

import (
	"context"
	"errors"
	"io"
	"time"

	itransfer "b612.me/notify/internal/transfer"
)

// transferFrameBatchWriter coalesces encoded transfer frames into one buffer
// and writes them to stream in batches.
//
// A pending batch is written when appending the next frame would push it past
// transferFrameAggregateLimit bytes, when it reaches that many bytes, or when
// it holds transferFrameAggregateCount frames. A single frame that is at
// least transferFrameAggregateLimit bytes long bypasses the buffer and is
// written on its own. Callers must call flush after the last frame so that
// any still-buffered frames reach the stream.
//
// When runtime is non-nil and transferID is non-empty, the duration of every
// successful stream write is reported via runtime.recordStreamWrite.
type transferFrameBatchWriter struct {
	stream       Stream
	runtime      *transferRuntime
	runtimeScope string
	transferID   string
	batch        []byte
	frameCount   int
}

// newTransferFrameBatchWriter returns a writer for stream whose batch buffer
// is pre-sized to transferFrameAggregateLimit bytes. runtime may be nil.
func newTransferFrameBatchWriter(stream Stream, runtime *transferRuntime, runtimeScope string, transferID string) *transferFrameBatchWriter {
	return &transferFrameBatchWriter{
		stream:       stream,
		runtime:      runtime,
		runtimeScope: runtimeScope,
		transferID:   transferID,
		batch:        make([]byte, 0, transferFrameAggregateLimit),
	}
}

// writeEncodedFrame wraps payload with buildTransferFrame and queues the
// resulting frame, writing to the stream when one of the batch thresholds is
// hit. It is a no-op on a nil receiver. The returned error comes from the
// underlying stream write, if any was performed.
func (w *transferFrameBatchWriter) writeEncodedFrame(payload []byte) error {
	if w == nil {
		return nil
	}
	frame := buildTransferFrame(payload)
	// Write out what we have rather than letting the batch exceed the limit.
	if len(w.batch) > 0 && len(w.batch)+len(frame) > transferFrameAggregateLimit {
		if err := w.flush(); err != nil {
			return err
		}
	}
	// Oversized frames are sent directly, after any pending frames so that
	// on-wire ordering is preserved.
	if len(frame) >= transferFrameAggregateLimit {
		if err := w.flush(); err != nil {
			return err
		}
		return w.writeBatch(frame)
	}
	w.batch = append(w.batch, frame...)
	w.frameCount++
	if len(w.batch) >= transferFrameAggregateLimit || w.frameCount >= transferFrameAggregateCount {
		return w.flush()
	}
	return nil
}

// flush writes any buffered frames to the stream. It is a no-op on a nil
// receiver or an empty batch. On a write error the buffered frames are kept
// (the batch is only reset after a successful write).
func (w *transferFrameBatchWriter) flush() error {
	if w == nil || len(w.batch) == 0 {
		return nil
	}
	if err := w.writeBatch(w.batch); err != nil {
		return err
	}
	// Reuse the backing array for the next batch.
	w.batch = w.batch[:0]
	w.frameCount = 0
	return nil
}

// writeBatch writes data to the stream via writeTransferFrames and, on
// success, records the write duration with the runtime (when configured).
// A nil receiver or empty data is a no-op.
func (w *transferFrameBatchWriter) writeBatch(data []byte) error {
	if w == nil || len(data) == 0 {
		return nil
	}
	start := time.Now()
	err := writeTransferFrames(w.stream, data)
	if err == nil && w.runtime != nil && w.transferID != "" {
		w.runtime.recordStreamWrite(fileTransferDirectionSend, w.runtimeScope, w.transferID, time.Since(start))
	}
	return err
}

// transferSegmentReadResult carries the outcome of one asynchronous ReadAt
// issued by sendTransferSegmentsWindowed.
//
// offset and want are the requested position and length, n and err are what
// ReadAt returned, readDuration is how long the call took, and payload holds
// the n bytes read (nil when n == 0).
type transferSegmentReadResult struct {
	offset       int64
	want         int
	n            int
	readDuration time.Duration
	payload      []byte
	err          error
}

// sendTransferSegmentFrame encodes chunk as the segment at offset of the
// transfer described by desc (via target.sequenceEn) and queues it on writer.
//
// chunk is copied into the segment, so callers may reuse the buffer after the
// call returns. An empty chunk yields io.ErrNoProgress.
//
// After the frame is queued, the runtime (if any) is marked active and
// credited with len(chunk) sent bytes, and hooks.onSegmentSent (if set) is
// invoked. The frame may still be sitting in writer's batch at that point;
// it is not guaranteed to have reached the stream yet.
func sendTransferSegmentFrame(writer *transferFrameBatchWriter, target transferSendTarget, desc TransferDescriptor, chunk []byte, offset int64, runtimeScope string, hooks transferSendHooks) error {
	if len(chunk) == 0 {
		return io.ErrNoProgress
	}
	segment := itransfer.Segment{
		TransferID: desc.ID,
		Channel:    transferChannelToKernel(desc.Channel),
		Offset:     offset,
		Payload:    append([]byte(nil), chunk...),
	}
	payload, err := target.sequenceEn(segment)
	if err != nil {
		return err
	}
	if err := writer.writeEncodedFrame(payload); err != nil {
		return err
	}
	if target.runtime != nil {
		target.runtime.activate(fileTransferDirectionSend, runtimeScope, desc.ID)
		target.runtime.recordStage(fileTransferDirectionSend, runtimeScope, desc.ID, "data")
		target.runtime.recordSend(fileTransferDirectionSend, runtimeScope, desc.ID, int64(len(chunk)))
	}
	if hooks.onSegmentSent != nil {
		hooks.onSegmentSent(offset, int64(len(chunk)))
	}
	return nil
}

// sendTransferSegmentsSerial reads opt.Source sequentially from nextOffset up
// to opt.Descriptor.Size in opt.ChunkSize pieces and sends each piece as a
// segment frame, flushing the batch writer at the end.
//
// A short read with a nil error simply resumes at offset+n. io.EOF is
// accepted only when it arrives exactly at the descriptor size; any other
// read error is returned. A read that returns neither data nor an error
// yields io.ErrNoProgress. Cancellation of ctx is checked before each read.
func sendTransferSegmentsSerial(ctx context.Context, stream Stream, target transferSendTarget, opt TransferSendOptions, nextOffset int64, hooks transferSendHooks) error {
	desc := opt.Descriptor
	chunkSize := opt.ChunkSize
	// One buffer is reused for every read; sendTransferSegmentFrame copies it.
	buf := make([]byte, chunkSize)
	writer := newTransferFrameBatchWriter(stream, target.runtime, target.runtimeScope, desc.ID)
	for offset := nextOffset; offset < desc.Size; {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		want := chunkSize
		remaining := desc.Size - offset
		if remaining < int64(want) {
			want = int(remaining)
		}
		readStartedAt := time.Now()
		n, err := opt.Source.ReadAt(buf[:want], offset)
		if target.runtime != nil {
			target.runtime.recordSourceRead(fileTransferDirectionSend, target.runtimeScope, desc.ID, time.Since(readStartedAt))
		}
		if n > 0 {
			if sendErr := sendTransferSegmentFrame(writer, target, desc, buf[:n], offset, target.runtimeScope, hooks); sendErr != nil {
				return sendErr
			}
			offset += int64(n)
		}
		if err != nil {
			if errors.Is(err, io.EOF) && offset == desc.Size {
				return writer.flush()
			}
			return err
		}
		if n == 0 {
			return io.ErrNoProgress
		}
	}
	return writer.flush()
}

// sendTransferSegmentsWindowed is like sendTransferSegmentsSerial but overlaps
// up to opt.Parallelism source reads, each in its own goroutine, while still
// emitting segment frames strictly in offset order. With a parallelism of 1
// or less it delegates to sendTransferSegmentsSerial.
//
// Bytes that have been dispatched but not yet written (in flight or parked in
// pending) are bounded by opt.MaxInflightBytes. When that is <= 0 it defaults
// to ChunkSize*Parallelism, and it is never smaller than one chunk.
//
// EOF/error semantics match the serial path: io.EOF is accepted only exactly
// at the descriptor size, and a read returning neither data nor an error
// yields io.ErrNoProgress.
//
// NOTE(review): on an early error return, the deferred cancel releases any
// reader goroutine blocked on delivering its result, but goroutines still
// inside Source.ReadAt keep running until that call returns.
func sendTransferSegmentsWindowed(ctx context.Context, stream Stream, target transferSendTarget, opt TransferSendOptions, nextOffset int64, hooks transferSendHooks) error {
	desc := opt.Descriptor
	chunkSize := opt.ChunkSize
	parallelism := opt.Parallelism
	if parallelism <= 1 {
		return sendTransferSegmentsSerial(ctx, stream, target, opt, nextOffset, hooks)
	}
	windowBytes := opt.MaxInflightBytes
	if windowBytes <= 0 {
		// Multiply in int64 so a large chunk size cannot overflow int.
		windowBytes = int64(chunkSize) * int64(parallelism)
	}
	if windowBytes < int64(chunkSize) {
		windowBytes = int64(chunkSize)
	}

	runCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	results := make(chan transferSegmentReadResult, parallelism)
	// Completed reads waiting for their turn, keyed by their dispatch offset.
	pending := make(map[int64]transferSegmentReadResult)
	writer := newTransferFrameBatchWriter(stream, target.runtime, target.runtimeScope, desc.ID)

	nextDispatch := nextOffset // next offset to hand to a reader goroutine
	nextWrite := nextOffset    // next offset that must be emitted
	var activeReads int
	var reservedBytes int64 // bytes dispatched but not yet consumed

	// dispatchRead starts a goroutine that reads want bytes at offset into a
	// fresh buffer and delivers the result, or gives up delivering once
	// runCtx is cancelled.
	dispatchRead := func(offset int64, want int) {
		activeReads++
		reservedBytes += int64(want)
		go func() {
			buf := make([]byte, want)
			readStartedAt := time.Now()
			n, err := opt.Source.ReadAt(buf, offset)
			readDuration := time.Since(readStartedAt)
			if n > 0 {
				buf = buf[:n]
			} else {
				buf = nil
			}
			result := transferSegmentReadResult{
				offset:       offset,
				want:         want,
				n:            n,
				readDuration: readDuration,
				payload:      buf,
				err:          err,
			}
			select {
			case results <- result:
			case <-runCtx.Done():
			}
		}()
	}

	// tryDispatch starts reads until the source range is fully dispatched,
	// parallelism reads are outstanding, or the byte window is full. The
	// reservedBytes > 0 guard always lets at least one read start.
	tryDispatch := func() {
		for nextDispatch < desc.Size && activeReads < parallelism {
			want := chunkSize
			remaining := desc.Size - nextDispatch
			if remaining < int64(want) {
				want = int(remaining)
			}
			if reservedBytes > 0 && reservedBytes+int64(want) > windowBytes {
				return
			}
			dispatchRead(nextDispatch, want)
			nextDispatch += int64(want)
		}
	}

	// consumeResult emits the result whose offset equals nextWrite, releases
	// its window reservation, and advances nextWrite past the bytes read.
	consumeResult := func(result transferSegmentReadResult) error {
		if result.want > 0 {
			reservedBytes -= int64(result.want)
			if reservedBytes < 0 {
				reservedBytes = 0
			}
		}
		if target.runtime != nil {
			target.runtime.recordSourceRead(fileTransferDirectionSend, target.runtimeScope, desc.ID, result.readDuration)
		}
		if result.n > 0 {
			if err := sendTransferSegmentFrame(writer, target, desc, result.payload, result.offset, target.runtimeScope, hooks); err != nil {
				return err
			}
			nextWrite = result.offset + int64(result.n)
		}
		if result.err != nil {
			if errors.Is(result.err, io.EOF) && nextWrite == desc.Size {
				return nil
			}
			return result.err
		}
		if result.n == 0 {
			return io.ErrNoProgress
		}
		if result.n < result.want {
			// io.ReaderAt requires a non-nil error for short reads, but a
			// reader that breaks that rule would otherwise stall this loop
			// forever. pending is keyed by dispatch offset, so nothing would
			// ever appear at the new nextWrite. Re-read the missing tail,
			// mirroring how the serial path resumes at offset+n. This may
			// briefly run one read beyond parallelism.
			dispatchRead(nextWrite, result.want-result.n)
		}
		return nil
	}

	tryDispatch()
	for nextWrite < desc.Size || activeReads > 0 || len(pending) > 0 {
		if ready, ok := pending[nextWrite]; ok {
			delete(pending, nextWrite)
			if err := consumeResult(ready); err != nil {
				return err
			}
			tryDispatch()
			continue
		}
		select {
		case <-runCtx.Done():
			return runCtx.Err()
		case result := <-results:
			activeReads--
			pending[result.offset] = result
			tryDispatch()
		}
	}
	return writer.flush()
}