notify/transfer_send_pipeline.go

277 lines
7.1 KiB
Go
Raw Permalink Normal View History

package notify
import (
"context"
"errors"
"io"
"time"
itransfer "b612.me/notify/internal/transfer"
)
// transferFrameBatchWriter coalesces encoded transfer frames into one
// buffer so several small frames go out in a single stream write
// instead of one write per frame.
type transferFrameBatchWriter struct {
	stream       Stream           // destination stream for batched frames
	runtime      *transferRuntime // optional metrics sink; may be nil
	runtimeScope string           // scope label for runtime metrics
	transferID   string           // transfer the frames belong to
	batch        []byte           // pending frame bytes awaiting flush
	frameCount   int              // frames currently held in batch
}
// newTransferFrameBatchWriter returns a batching frame writer for the
// given transfer. The batch buffer is pre-sized to the aggregate limit
// so appends up to the flush threshold never reallocate.
func newTransferFrameBatchWriter(stream Stream, runtime *transferRuntime, runtimeScope string, transferID string) *transferFrameBatchWriter {
	w := &transferFrameBatchWriter{
		stream:       stream,
		runtime:      runtime,
		runtimeScope: runtimeScope,
		transferID:   transferID,
	}
	w.batch = make([]byte, 0, transferFrameAggregateLimit)
	return w
}
// writeEncodedFrame wraps payload as a transfer frame and appends it to
// the batch, flushing first when the pending batch cannot absorb it.
// A frame that alone reaches the aggregate limit bypasses batching and
// is written directly. Nil receivers are a no-op.
func (w *transferFrameBatchWriter) writeEncodedFrame(payload []byte) error {
	if w == nil {
		return nil
	}
	frame := buildTransferFrame(payload)
	wouldOverflow := len(w.batch) > 0 && len(w.batch)+len(frame) > transferFrameAggregateLimit
	if wouldOverflow {
		if err := w.flush(); err != nil {
			return err
		}
	}
	if len(frame) >= transferFrameAggregateLimit {
		// Oversized frame: send it on its own, never through the batch.
		if err := w.flush(); err != nil {
			return err
		}
		return w.writeBatch(frame)
	}
	w.batch = append(w.batch, frame...)
	w.frameCount++
	// Eagerly flush once either aggregation threshold is reached.
	if w.frameCount >= transferFrameAggregateCount || len(w.batch) >= transferFrameAggregateLimit {
		return w.flush()
	}
	return nil
}
// flush writes any accumulated frames to the stream and resets the
// batch, keeping its backing capacity for reuse.
func (w *transferFrameBatchWriter) flush() error {
	if w == nil {
		return nil
	}
	if len(w.batch) == 0 {
		return nil
	}
	if err := w.writeBatch(w.batch); err != nil {
		return err
	}
	w.batch, w.frameCount = w.batch[:0], 0
	return nil
}
// writeBatch pushes raw frame bytes to the stream and, on success,
// records the wall-clock write duration in the runtime metrics (only
// when both a runtime and a transfer ID are present).
func (w *transferFrameBatchWriter) writeBatch(data []byte) error {
	if w == nil || len(data) == 0 {
		return nil
	}
	startedAt := time.Now()
	if err := writeTransferFrames(w.stream, data); err != nil {
		return err
	}
	if w.runtime != nil && w.transferID != "" {
		w.runtime.recordStreamWrite(fileTransferDirectionSend, w.runtimeScope, w.transferID, time.Since(startedAt))
	}
	return nil
}
type transferSegmentReadResult struct {
offset int64
want int
n int
readDuration time.Duration
payload []byte
err error
}
// sendTransferSegmentFrame encodes one data segment and hands it to the
// batching frame writer, then updates runtime metrics and fires the
// per-segment hook. The chunk is copied, so the caller may reuse its
// read buffer immediately. An empty chunk yields io.ErrNoProgress.
func sendTransferSegmentFrame(writer *transferFrameBatchWriter, target transferSendTarget, desc TransferDescriptor, chunk []byte, offset int64, runtimeScope string, hooks transferSendHooks) error {
	if len(chunk) == 0 {
		return io.ErrNoProgress
	}
	// Detach the payload from the caller's (reused) buffer.
	payloadCopy := make([]byte, len(chunk))
	copy(payloadCopy, chunk)
	seg := itransfer.Segment{
		TransferID: desc.ID,
		Channel:    transferChannelToKernel(desc.Channel),
		Offset:     offset,
		Payload:    payloadCopy,
	}
	encoded, err := target.sequenceEn(seg)
	if err != nil {
		return err
	}
	if err = writer.writeEncodedFrame(encoded); err != nil {
		return err
	}
	if rt := target.runtime; rt != nil {
		rt.activate(fileTransferDirectionSend, runtimeScope, desc.ID)
		rt.recordStage(fileTransferDirectionSend, runtimeScope, desc.ID, "data")
		rt.recordSend(fileTransferDirectionSend, runtimeScope, desc.ID, int64(len(chunk)))
	}
	if hooks.onSegmentSent != nil {
		hooks.onSegmentSent(offset, int64(len(chunk)))
	}
	return nil
}
// sendTransferSegmentsSerial streams the source sequentially from
// nextOffset up to desc.Size, one chunk at a time, through a batching
// frame writer. Source reads are timed into the runtime metrics when a
// runtime is present. Returns ctx.Err() on cancellation and
// io.ErrNoProgress when a read yields neither data nor an error.
func sendTransferSegmentsSerial(ctx context.Context, stream Stream, target transferSendTarget, opt TransferSendOptions, nextOffset int64, hooks transferSendHooks) error {
	var (
		desc   = opt.Descriptor
		buf    = make([]byte, opt.ChunkSize)
		writer = newTransferFrameBatchWriter(stream, target.runtime, target.runtimeScope, desc.ID)
	)
	for offset := nextOffset; offset < desc.Size; {
		// Honor cancellation between chunks (Err is non-nil iff Done is closed).
		if err := ctx.Err(); err != nil {
			return err
		}
		want := opt.ChunkSize
		if remaining := desc.Size - offset; remaining < int64(want) {
			want = int(remaining)
		}
		readStartedAt := time.Now()
		n, readErr := opt.Source.ReadAt(buf[:want], offset)
		if target.runtime != nil {
			target.runtime.recordSourceRead(fileTransferDirectionSend, target.runtimeScope, desc.ID, time.Since(readStartedAt))
		}
		if n > 0 {
			if err := sendTransferSegmentFrame(writer, target, desc, buf[:n], offset, target.runtimeScope, hooks); err != nil {
				return err
			}
			offset += int64(n)
		}
		switch {
		case readErr != nil && errors.Is(readErr, io.EOF) && offset == desc.Size:
			// EOF exactly at the declared size is normal completion.
			return writer.flush()
		case readErr != nil:
			return readErr
		case n == 0:
			// No data and no error: the source is stuck.
			return io.ErrNoProgress
		}
	}
	return writer.flush()
}
// sendTransferSegmentsWindowed streams the source using up to
// opt.Parallelism concurrent ReadAt goroutines while keeping stream
// writes strictly in offset order. In-flight reads are bounded both by
// parallelism and by a byte window (opt.MaxInflightBytes, defaulting to
// chunkSize*parallelism). Falls back to the serial sender when
// parallelism <= 1.
func sendTransferSegmentsWindowed(ctx context.Context, stream Stream, target transferSendTarget, opt TransferSendOptions, nextOffset int64, hooks transferSendHooks) error {
	desc := opt.Descriptor
	chunkSize := opt.ChunkSize
	parallelism := opt.Parallelism
	if parallelism <= 1 {
		return sendTransferSegmentsSerial(ctx, stream, target, opt, nextOffset, hooks)
	}
	windowBytes := opt.MaxInflightBytes
	if windowBytes <= 0 {
		// Default window: one full chunk per worker.
		windowBytes = int64(chunkSize * parallelism)
	}
	if windowBytes < int64(chunkSize) {
		// The window must admit at least one chunk or nothing would dispatch.
		windowBytes = int64(chunkSize)
	}
	runCtx, cancel := context.WithCancel(ctx)
	defer cancel() // unblocks reader goroutines stuck sending results
	results := make(chan transferSegmentReadResult, parallelism)
	// pending holds out-of-order results keyed by offset until the
	// in-order writer cursor (nextWrite) reaches them.
	pending := make(map[int64]transferSegmentReadResult)
	writer := newTransferFrameBatchWriter(stream, target.runtime, target.runtimeScope, desc.ID)
	var nextDispatch int64 = nextOffset // next offset to hand to a reader
	var nextWrite int64 = nextOffset    // next offset the stream expects
	var activeReads int                 // reader goroutines in flight
	var reservedBytes int64             // bytes reserved by in-flight reads
	// dispatchRead launches one goroutine reading [offset, offset+want).
	dispatchRead := func(offset int64, want int) {
		activeReads++
		reservedBytes += int64(want)
		go func() {
			buf := make([]byte, want)
			readStartedAt := time.Now()
			n, err := opt.Source.ReadAt(buf, offset)
			readDuration := time.Since(readStartedAt)
			if n > 0 {
				buf = buf[:n]
			} else {
				buf = nil
			}
			result := transferSegmentReadResult{
				offset:       offset,
				want:         want,
				n:            n,
				readDuration: readDuration,
				payload:      buf,
				err:          err,
			}
			// Abandon the result on cancellation; otherwise a full
			// results channel would leak this goroutine.
			select {
			case results <- result:
			case <-runCtx.Done():
			}
		}()
	}
	// tryDispatch starts as many reads as parallelism and the byte
	// window allow, stopping at end of file.
	tryDispatch := func() {
		for nextDispatch < desc.Size && activeReads < parallelism {
			want := chunkSize
			remaining := desc.Size - nextDispatch
			if remaining < int64(want) {
				want = int(remaining)
			}
			// Respect the byte window, but always allow the first read
			// (reservedBytes == 0) so forward progress is guaranteed.
			if reservedBytes > 0 && reservedBytes+int64(want) > windowBytes {
				return
			}
			dispatchRead(nextDispatch, want)
			nextDispatch += int64(want)
		}
	}
	// consumeResult releases the result's window reservation, records
	// the read metric, writes its data in order, and classifies errors.
	consumeResult := func(result transferSegmentReadResult) error {
		if result.want > 0 {
			reservedBytes -= int64(result.want)
			if reservedBytes < 0 {
				reservedBytes = 0
			}
		}
		if target.runtime != nil {
			target.runtime.recordSourceRead(fileTransferDirectionSend, target.runtimeScope, desc.ID, result.readDuration)
		}
		if result.n > 0 {
			if err := sendTransferSegmentFrame(writer, target, desc, result.payload, result.offset, target.runtimeScope, hooks); err != nil {
				return err
			}
			nextWrite = result.offset + int64(result.n)
		}
		if result.err != nil {
			// EOF exactly at the declared size is normal completion.
			if errors.Is(result.err, io.EOF) && nextWrite == desc.Size {
				return nil
			}
			return result.err
		}
		if result.n == 0 {
			// No data and no error: the source is stuck.
			return io.ErrNoProgress
		}
		return nil
	}
	tryDispatch()
	// Main loop: prefer consuming the offset the stream needs next from
	// pending; otherwise wait for any reader to finish (or cancellation).
	for nextWrite < desc.Size || activeReads > 0 || len(pending) > 0 {
		if ready, ok := pending[nextWrite]; ok {
			delete(pending, nextWrite)
			if err := consumeResult(ready); err != nil {
				return err
			}
			tryDispatch()
			continue
		}
		select {
		case <-runCtx.Done():
			return runCtx.Err()
		case result := <-results:
			activeReads--
			pending[result.offset] = result
			tryDispatch()
		}
	}
	// Push out any frames still held in the batch writer.
	return writer.flush()
}