notify/bulk_batch_sender.go
starainrt f038a89771
fix: close stream adaptive gaps and switch notify to stario v0.1.1
- make stream fast path honor adaptive soft payload limits end-to-end
  - split oversized fast-stream payloads into sequential frames before batching
  - use adaptive soft cap when encoding stream batch payloads
  - move timeout-like error detection into production code for adaptive tx
  - tune notify FrameReader read size explicitly to avoid throughput regression
  - drop local stario replace and depend on released b612.me/stario v0.1.1
2026-04-18 16:05:57 +08:00

689 lines
16 KiB
Go

package notify
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
)
// Batching limits governing the sender's flush loop.
const (
	// bulkBatchMaxPayloads caps how many queued requests may be coalesced
	// into a single flush.
	bulkBatchMaxPayloads = 64
	// bulkBatchMaxPayloadBytes caps the approximate byte size of one
	// coalesced flush; it mirrors the fast-batch plain-byte limit.
	bulkBatchMaxPayloadBytes = bulkFastBatchMaxPlainBytes
	// bulkBatchMaxFlushDelay is how long the run loop waits for additional
	// requests before flushing a batch that could still grow.
	bulkBatchMaxFlushDelay = 50 * time.Microsecond
)
// Lifecycle states stored in bulkBatchRequestState.value. Valid transitions
// are queued -> started (the flusher won the race) or queued -> canceled
// (the submitter's context expired first); see tryStart and tryCancel.
const (
	bulkBatchRequestQueued int32 = iota
	bulkBatchRequestStarted
	bulkBatchRequestCanceled
)
// bulkBatchRequestState holds the atomic lifecycle state of a request,
// arbitrating the race between the flusher starting it and the submitting
// goroutine canceling it.
type bulkBatchRequestState struct {
	value atomic.Int32
}
// bulkBatchCodec supplies the wire encoders used by the sender. Each encoder
// returns the encoded payload plus an optional release callback that recycles
// the payload buffer once it has been written.
type bulkBatchCodec struct {
	encodeSingle func(bulkFastFrame) ([]byte, func(), error)
	encodeBatch  func([]bulkFastFrame) ([]byte, func(), error)
}
// bulkBatchRequest is one submission awaiting (or undergoing) a flush.
type bulkBatchRequest struct {
	ctx             context.Context
	frames          []bulkFastFrame
	fastPathVersion uint8
	payloadOwned    bool // true: payloads belong to the request and are not cloned when queued
	deadline        time.Time
	done            chan error // buffered (cap 1) completion channel
	state           *bulkBatchRequestState
	release         func() // frees the cloned payload buffer, if one was allocated
}
// bulkBatchEncodedPayload pairs an encoded wire payload with the release
// callback that returns its buffer to the codec's pool.
type bulkBatchEncodedPayload struct {
	payload []byte
	release func()
}
// done releases the buffer backing the encoded payload, if any. It is
// idempotent: the release callback is cleared after the first call, so
// repeated invocations and calls on a nil receiver are safe no-ops.
func (p *bulkBatchEncodedPayload) done() {
	if p == nil {
		return
	}
	release := p.release
	if release == nil {
		return
	}
	p.release = nil
	release()
}
// bulkBatchSender coalesces fast-path frames from many goroutines into
// batched transport writes. A request either takes a synchronous direct path
// (when nothing is queued and flushMu is free) or is queued for the
// background run loop to batch and flush.
type bulkBatchSender struct {
	binding              *transportBinding
	codec                bulkBatchCodec
	writeTimeoutProvider func() time.Duration
	reqCh                chan bulkBatchRequest // queued submissions consumed by run()
	stopCh               chan struct{}         // closed exactly once by stop()
	doneCh               chan struct{}         // closed when run() exits
	stopOnce             sync.Once
	flushMu              sync.Mutex // serializes flushes (direct path and run loop)
	queued               atomic.Int64 // requests accepted into the queue but not yet finished
	errMu                sync.Mutex
	err                  error // first fatal error; guarded by errMu
}
// newBulkBatchSender builds a sender bound to the given transport binding and
// codec, then immediately starts its background batching goroutine. The
// request channel is buffered well beyond one batch so submitters rarely
// block just to enqueue.
func newBulkBatchSender(binding *transportBinding, codec bulkBatchCodec, writeTimeoutProvider func() time.Duration) *bulkBatchSender {
	s := &bulkBatchSender{
		binding:              binding,
		codec:                codec,
		writeTimeoutProvider: writeTimeoutProvider,
		reqCh:                make(chan bulkBatchRequest, 4*bulkBatchMaxPayloads),
		stopCh:               make(chan struct{}),
		doneCh:               make(chan struct{}),
	}
	go s.run()
	return s
}
// submitData submits one data frame for dataID at the given sequence number.
// The payload is treated as caller-owned, so it is cloned if the request has
// to be queued.
func (s *bulkBatchSender) submitData(ctx context.Context, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
	frame := bulkFastFrame{
		Type:    bulkFastPayloadTypeData,
		DataID:  dataID,
		Seq:     seq,
		Payload: payload,
	}
	return s.submitFramesOwned(ctx, []bulkFastFrame{frame}, fastPathVersion, false)
}
// submitControl submits one control frame of the given type and flags. The
// payload is treated as caller-owned, so it is cloned if the request has to
// be queued.
func (s *bulkBatchSender) submitControl(ctx context.Context, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
	frame := bulkFastFrame{
		Type:    frameType,
		Flags:   flags,
		DataID:  dataID,
		Seq:     seq,
		Payload: payload,
	}
	return s.submitFramesOwned(ctx, []bulkFastFrame{frame}, fastPathVersion, false)
}
// submitWrite chunks payload into data frames of at most chunkSize bytes
// (defaulting to defaultBulkChunkSize), packs them into batches bounded by
// bulkFastBatchMaxItems and bulkFastBatchMaxPlainBytes, and submits each
// batch in order with sequence numbers starting at startSeq. It returns the
// number of payload bytes accepted; if a batch submission fails, the count
// written before that batch is returned together with the error.
func (s *bulkBatchSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, fastPathVersion uint8, payload []byte, chunkSize int, payloadOwned bool) (int, error) {
	if s == nil {
		return 0, errTransportDetached
	}
	if len(payload) == 0 {
		return 0, nil
	}
	if chunkSize <= 0 {
		chunkSize = defaultBulkChunkSize
	}
	written := 0
	seq := startSeq
	for written < len(payload) {
		var batch [bulkFastBatchMaxItems]bulkFastFrame
		frames := batch[:0]
		batchBytes := bulkFastBatchHeaderLen
		start := written
		for written < len(payload) && len(frames) < bulkFastBatchMaxItems {
			end := written + chunkSize
			if end > len(payload) {
				end = len(payload)
			}
			frame := bulkFastFrame{
				Type:    bulkFastPayloadTypeData,
				DataID:  dataID,
				Seq:     seq,
				Payload: payload[written:end],
			}
			frameLen := bulkFastBatchFrameLen(frame)
			// Close the batch when adding this frame would exceed the plain
			// byte cap; the first frame of a batch is always accepted.
			if len(frames) > 0 && batchBytes+frameLen > bulkFastBatchMaxPlainBytes {
				break
			}
			frames = append(frames, frame)
			batchBytes += frameLen
			seq++
			written = end
		}
		if len(frames) == 0 {
			// Defensive fallback to guarantee forward progress. The inner
			// loop above accepts the first frame unconditionally, so this
			// branch is not expected to execute.
			end := written + chunkSize
			if end > len(payload) {
				end = len(payload)
			}
			frames = append(frames, bulkFastFrame{
				Type:    bulkFastPayloadTypeData,
				DataID:  dataID,
				Seq:     seq,
				Payload: payload[written:end],
			})
			seq++
			written = end
		}
		// On failure report only the bytes confirmed before this batch.
		if err := s.submitFramesOwned(ctx, frames, fastPathVersion, payloadOwned); err != nil {
			return start, err
		}
	}
	return written, nil
}
// submitFrames submits frames without transferring payload ownership; the
// payloads are cloned if the request has to be queued.
func (s *bulkBatchSender) submitFrames(ctx context.Context, frames []bulkFastFrame, fastPathVersion uint8) error {
	const payloadOwned = false
	return s.submitFramesOwned(ctx, frames, fastPathVersion, payloadOwned)
}
// submitFramesOwned is the common submission path. It first tries a direct
// synchronous write for requests that cannot join a shared super-batch, then
// falls back to cloning the frames (unless payloadOwned) and queueing them
// for the run loop. It blocks until the request completes, its context is
// done, or the sender stops.
func (s *bulkBatchSender) submitFramesOwned(ctx context.Context, frames []bulkFastFrame, fastPathVersion uint8, payloadOwned bool) error {
	if s == nil {
		return errTransportDetached
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if len(frames) == 0 {
		return nil
	}
	req := bulkBatchRequest{
		ctx:             ctx,
		frames:          frames,
		fastPathVersion: normalizeBulkFastPathVersion(fastPathVersion),
		payloadOwned:    payloadOwned,
		done:            make(chan error, 1),
		state:           &bulkBatchRequestState{},
	}
	if deadline, ok := ctx.Deadline(); ok {
		req.deadline = deadline
	}
	// Fail fast if the sender has already hit a fatal error.
	if err := s.errSnapshot(); err != nil {
		return err
	}
	if s.shouldDirectSubmit(req) {
		// submitted == true means the request was fully handled on the
		// direct path (successfully or not); false means fall through and
		// queue it.
		if submitted, err := s.tryDirectSubmit(req); submitted {
			return err
		}
	}
	// Queueing hands the frames to another goroutine, so caller-owned
	// payloads are cloned here; the clone installs req.release to recycle
	// the pooled buffer.
	req = cloneQueuedBulkBatchRequest(req)
	s.queued.Add(1)
	select {
	case <-ctx.Done():
		s.queued.Add(-1)
		if req.release != nil {
			req.release()
		}
		return normalizeStreamDeadlineError(ctx.Err())
	case <-s.stopCh:
		s.queued.Add(-1)
		if req.release != nil {
			req.release()
		}
		return s.stoppedErr()
	case s.reqCh <- req:
	}
	select {
	case err := <-req.done:
		return err
	case <-ctx.Done():
		// Cancellation only wins if the flusher has not already started the
		// request; otherwise wait for the in-flight result.
		if req.tryCancel() {
			return normalizeStreamDeadlineError(ctx.Err())
		}
		return <-req.done
	}
}
// shouldDirectSubmit reports whether the request should attempt the
// synchronous direct path. Only requests that cannot join a shared
// super-batch qualify; batchable requests go through the queue so the run
// loop can coalesce them.
func (s *bulkBatchSender) shouldDirectSubmit(req bulkBatchRequest) bool {
	if len(req.frames) == 0 {
		return false
	}
	if bulkBatchRequestSupportsSharedSuperBatch(req) {
		return false
	}
	return true
}
// tryDirectSubmit attempts to flush req synchronously, bypassing the run
// loop. It only proceeds when nothing is queued and the flush mutex is
// immediately available, so a direct write can never reorder ahead of
// already-queued requests. The first return value reports whether the
// request was handled (successfully or with the returned error); false
// means the caller must queue it instead.
func (s *bulkBatchSender) tryDirectSubmit(req bulkBatchRequest) (bool, error) {
	if s == nil {
		return true, errTransportDetached
	}
	if err := s.errSnapshot(); err != nil {
		return true, err
	}
	select {
	case <-req.ctx.Done():
		return true, normalizeStreamDeadlineError(req.ctx.Err())
	case <-s.stopCh:
		return true, s.stoppedErr()
	default:
	}
	if s.queued.Load() != 0 {
		return false, nil
	}
	if !s.flushMu.TryLock() {
		return false, nil
	}
	defer s.flushMu.Unlock()
	// Re-check under the lock: a request may have been queued between the
	// first check and acquiring the mutex.
	if s.queued.Load() != 0 {
		return false, nil
	}
	if err := s.errSnapshot(); err != nil {
		return true, err
	}
	if !req.tryStart() {
		return true, req.canceledErr()
	}
	if err := req.contextErr(); err != nil {
		return true, err
	}
	err := s.flush([]bulkBatchRequest{req})
	if err != nil {
		// A flush failure is fatal for the sender: latch the error and fail
		// everything still waiting in the queue.
		s.setErr(err)
		s.failPending(err)
		return true, err
	}
	return true, nil
}
// run is the sender's background loop: it pulls one request, optionally
// waits up to bulkBatchMaxFlushDelay to coalesce more, then flushes the
// whole batch under flushMu. The loop exits when stopCh closes or a flush
// fails, closing doneCh on the way out.
func (s *bulkBatchSender) run() {
	defer close(s.doneCh)
	for {
		req, ok := s.nextRequest()
		if !ok {
			return
		}
		batch := []bulkBatchRequest{req}
		batchBytes := bulkBatchRequestApproxBytes(req)
		timer := (*time.Timer)(nil)
		timerCh := (<-chan time.Time)(nil)
		// Only super-batchable requests are worth delaying for; others are
		// drained opportunistically without a timed wait.
		if bulkBatchShouldWaitForMore(batch, batchBytes) {
			timer = time.NewTimer(bulkBatchMaxFlushDelay)
			timerCh = timer.C
		}
	drain:
		for len(batch) < bulkBatchMaxPayloads && batchBytes < bulkBatchMaxPayloadBytes {
			if timerCh == nil {
				// No coalescing window: grab whatever is immediately
				// available, then flush.
				select {
				case <-s.stopCh:
					s.failPending(s.stoppedErr())
					return
				case next := <-s.reqCh:
					batch = append(batch, next)
					batchBytes += bulkBatchRequestApproxBytes(next)
				default:
					break drain
				}
				continue
			}
			select {
			case <-s.stopCh:
				if timer != nil {
					timer.Stop()
				}
				s.failPending(s.stoppedErr())
				return
			case next := <-s.reqCh:
				batch = append(batch, next)
				batchBytes += bulkBatchRequestApproxBytes(next)
			case <-timerCh:
				// Window elapsed; nil timerCh marks the tick as consumed so
				// the drain below does not wait for it again.
				timerCh = nil
				break drain
			}
		}
		if timer != nil {
			// Stop the timer; if it already fired and the tick was not
			// consumed (timerCh still non-nil), drain it non-blockingly.
			if !timer.Stop() && timerCh != nil {
				select {
				case <-timer.C:
				default:
				}
			}
		}
		s.flushMu.Lock()
		err := s.errSnapshot()
		// Start each request (a request losing the race to a concurrent
		// cancel is finished immediately) and drop any whose context has
		// already expired.
		active := make([]bulkBatchRequest, 0, len(batch))
		for _, item := range batch {
			if !item.tryStart() {
				s.finishRequest(item, item.canceledErr())
				continue
			}
			if itemErr := item.contextErr(); itemErr != nil {
				s.finishRequest(item, itemErr)
				continue
			}
			active = append(active, item)
		}
		if len(active) == 0 {
			s.flushMu.Unlock()
			continue
		}
		if err == nil {
			err = s.flush(active)
		}
		s.flushMu.Unlock()
		if err != nil {
			// Fatal: latch the error, fail this batch plus everything still
			// queued, then exit the loop.
			s.setErr(err)
			for _, item := range active {
				s.finishRequest(item, err)
			}
			s.failPending(err)
			return
		}
		for _, item := range active {
			s.finishRequest(item, nil)
		}
	}
}
// nextRequest blocks until a request arrives or the sender is stopped. On
// stop it fails everything still sitting in the queue and reports false.
func (s *bulkBatchSender) nextRequest() (bulkBatchRequest, bool) {
	select {
	case req := <-s.reqCh:
		return req, true
	case <-s.stopCh:
		s.failPending(s.stoppedErr())
		return bulkBatchRequest{}, false
	}
}
// contextErr returns the request context's error, normalized for stream
// deadlines, if the context is already done; otherwise it returns nil.
func (r bulkBatchRequest) contextErr() error {
	ctx := r.ctx
	if ctx == nil {
		return nil
	}
	select {
	case <-ctx.Done():
		return normalizeStreamDeadlineError(ctx.Err())
	default:
	}
	return nil
}
// tryStart attempts the queued -> started transition. Requests without
// tracked state are always considered startable.
func (r bulkBatchRequest) tryStart() bool {
	state := r.state
	if state == nil {
		return true
	}
	return state.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestStarted)
}
// tryCancel attempts the queued -> canceled transition; it fails if the
// flusher already started the request or no state is tracked.
func (r bulkBatchRequest) tryCancel() bool {
	state := r.state
	if state == nil {
		return false
	}
	return state.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestCanceled)
}
// canceledErr picks the most specific error for a canceled request: the
// context's own normalized error when available, context.Canceled otherwise.
func (r bulkBatchRequest) canceledErr() error {
	err := r.contextErr()
	if err == nil {
		err = context.Canceled
	}
	return err
}
// flush encodes the started requests into wire payloads and writes each one
// to the bound connection under the transport write lock, feeding size and
// latency observations into the adaptive payload tuner after every write.
// Callers must hold flushMu. Any returned error is fatal to the sender.
func (s *bulkBatchSender) flush(requests []bulkBatchRequest) error {
	if s == nil || s.binding == nil {
		return errTransportDetached
	}
	queue := s.binding.queueSnapshot()
	if queue == nil {
		return errTransportFrameQueueUnavailable
	}
	payloads, err := s.encodeRequests(requests)
	if err != nil {
		return err
	}
	// Release every encoded payload buffer even if a write fails midway.
	defer func() {
		for index := range payloads {
			payloads[index].done()
		}
	}()
	writeTimeout := s.transportWriteTimeout()
	for _, payload := range payloads {
		frame := payload.payload
		started := time.Now()
		err := s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error {
			return writeFramedPayloadUnlocked(conn, queue, frame)
		})
		// Report the write outcome (bytes, elapsed time, timeout budget,
		// error) so the adaptive soft payload limit can react.
		s.binding.observeBulkAdaptivePayloadWrite(len(frame), time.Since(started), writeTimeout, err)
		if err != nil {
			return err
		}
	}
	return nil
}
// encodeRequests turns started requests into one or more encoded wire
// payloads. Frames from shared-batch-capable requests are coalesced into
// batches bounded by bulkFastBatchMaxItems and a byte cap that shrinks to
// the adaptive mixed-payload limit once a batch spans multiple requests or
// data IDs; everything else (unsupported versions, oversized frames) is
// encoded frame-by-frame. Callers own the returned payloads and must call
// done() on each.
func (s *bulkBatchSender) encodeRequests(requests []bulkBatchRequest) ([]bulkBatchEncodedPayload, error) {
	if len(requests) == 0 {
		return nil, nil
	}
	payloads := make([]bulkBatchEncodedPayload, 0, len(requests))
	batch := make([]bulkFastFrame, 0, minInt(len(requests), bulkFastBatchMaxItems))
	mixedBatchLimit := s.sharedMixedPayloadLimit()
	// batchRequestIndex / batchDataID / batchMixed track whether the pending
	// batch already spans multiple requests or data streams.
	batchRequestIndex := -1
	batchDataID := uint64(0)
	batchMixed := false
	flushBatch := func() error {
		if len(batch) == 0 {
			return nil
		}
		payload, release, err := s.encodeBatch(batch)
		if err != nil {
			return err
		}
		payloads = append(payloads, bulkBatchEncodedPayload{payload: payload, release: release})
		batch = batch[:0]
		batchRequestIndex = -1
		batchDataID = 0
		batchMixed = false
		return nil
	}
	// batchBytes is reset by each caller after flushBatch, not inside it.
	batchBytes := bulkFastBatchHeaderLen
	for reqIndex, req := range requests {
		for _, frame := range req.frames {
			// Peers without shared-batch support get frames encoded
			// individually; the pending batch is flushed first to preserve
			// write ordering.
			if !bulkFastPathSupportsSharedBatch(req.fastPathVersion) {
				if err := flushBatch(); err != nil {
					return nil, err
				}
				batchBytes = bulkFastBatchHeaderLen
				payload, release, err := s.encodeSingle(frame)
				if err != nil {
					return nil, err
				}
				payloads = append(payloads, bulkBatchEncodedPayload{payload: payload, release: release})
				continue
			}
			frameLen := bulkFastBatchFrameLen(frame)
			// A frame too large to fit any batch goes out on its own.
			if frameLen+bulkFastBatchHeaderLen > bulkFastBatchMaxPlainBytes {
				if err := flushBatch(); err != nil {
					return nil, err
				}
				batchBytes = bulkFastBatchHeaderLen
				payload, release, err := s.encodeSingle(frame)
				if err != nil {
					return nil, err
				}
				payloads = append(payloads, bulkBatchEncodedPayload{payload: payload, release: release})
				continue
			}
			// Would appending this frame make the batch mixed?
			nextMixed := batchMixed
			if len(batch) > 0 && (batchRequestIndex != reqIndex || (batchDataID != 0 && batchDataID != frame.DataID)) {
				nextMixed = true
			}
			// Mixed batches honor the (typically smaller) adaptive soft cap.
			batchLimit := bulkFastBatchMaxPlainBytes
			if nextMixed && mixedBatchLimit > 0 && mixedBatchLimit < batchLimit {
				batchLimit = mixedBatchLimit
			}
			// Flush when the batch is full by count or would overflow the
			// applicable byte limit.
			if len(batch) > 0 && (len(batch) >= bulkFastBatchMaxItems || batchBytes+frameLen > batchLimit) {
				if err := flushBatch(); err != nil {
					return nil, err
				}
				batchBytes = bulkFastBatchHeaderLen
				nextMixed = false
			}
			if len(batch) == 0 {
				batchRequestIndex = reqIndex
				batchDataID = frame.DataID
				batchMixed = false
			} else if batchRequestIndex != reqIndex || (batchDataID != 0 && batchDataID != frame.DataID) {
				batchMixed = true
			}
			batch = append(batch, frame)
			batchBytes += frameLen
		}
	}
	if err := flushBatch(); err != nil {
		return nil, err
	}
	return payloads, nil
}
// bulkBatchRequestApproxBytes sums the encoded frame lengths of a request,
// approximating its contribution to a batch payload.
func bulkBatchRequestApproxBytes(req bulkBatchRequest) int {
	size := 0
	for i := range req.frames {
		size += bulkFastBatchFrameLen(req.frames[i])
	}
	return size
}
// bulkBatchRequestSupportsSharedSuperBatch reports whether every frame in
// the request is a plain data frame on a fast-path version that supports
// shared batching — the precondition for coalescing across requests.
func bulkBatchRequestSupportsSharedSuperBatch(req bulkBatchRequest) bool {
	if len(req.frames) == 0 {
		return false
	}
	if !bulkFastPathSupportsSharedBatch(req.fastPathVersion) {
		return false
	}
	for i := range req.frames {
		if req.frames[i].Type != bulkFastPayloadTypeData {
			return false
		}
	}
	return true
}
// bulkBatchShouldWaitForMore reports whether the run loop should open a
// coalescing window: the batch must be non-empty, still have room by both
// count and bytes, and consist only of super-batchable requests.
func bulkBatchShouldWaitForMore(batch []bulkBatchRequest, batchBytes int) bool {
	switch {
	case bulkBatchMaxFlushDelay <= 0:
		return false
	case len(batch) == 0:
		return false
	case len(batch) >= bulkBatchMaxPayloads:
		return false
	case batchBytes >= bulkBatchMaxPayloadBytes:
		return false
	}
	for _, req := range batch {
		if !bulkBatchRequestSupportsSharedSuperBatch(req) {
			return false
		}
	}
	return true
}
// cloneQueuedBulkBatchRequest deep-copies a request's frame payloads into a
// single pooled buffer before the request is handed to the run loop, so the
// caller may reuse its payload slices immediately after queueing. Requests
// that already own their payloads (payloadOwned) or carry no frames pass
// through untouched. The pooled buffer is returned via req.release.
func cloneQueuedBulkBatchRequest(req bulkBatchRequest) bulkBatchRequest {
	if len(req.frames) == 0 || req.payloadOwned {
		return req
	}
	clonedFrames := make([]bulkFastFrame, len(req.frames))
	totalPayload := 0
	for _, frame := range req.frames {
		totalPayload += len(frame.Payload)
	}
	var payloadBuf []byte
	if totalPayload > 0 {
		payloadBuf = getBulkAsyncWritePayload(totalPayload)
		req.release = func() {
			putBulkAsyncWritePayload(payloadBuf)
		}
	}
	// Pack every payload back-to-back into payloadBuf; each cloned frame
	// holds a subslice view of its own copy.
	offset := 0
	for index, frame := range req.frames {
		clonedFrames[index] = frame
		if len(frame.Payload) == 0 {
			clonedFrames[index].Payload = nil
			continue
		}
		next := offset + len(frame.Payload)
		clonedFrames[index].Payload = payloadBuf[offset:next]
		copy(clonedFrames[index].Payload, frame.Payload)
		offset = next
	}
	req.frames = clonedFrames
	return req
}
// encodeSingle encodes one frame via the codec's single-frame encoder,
// reporting a detached transport when the sender or encoder is missing.
func (s *bulkBatchSender) encodeSingle(frame bulkFastFrame) ([]byte, func(), error) {
	if s == nil {
		return nil, nil, errTransportDetached
	}
	if s.codec.encodeSingle == nil {
		return nil, nil, errTransportDetached
	}
	return s.codec.encodeSingle(frame)
}
// encodeBatch encodes a group of frames into one wire payload. A single
// frame, or a codec without a batch encoder, falls back to single-frame
// encoding. The returned func, when non-nil, releases the payload buffer.
func (s *bulkBatchSender) encodeBatch(frames []bulkFastFrame) ([]byte, func(), error) {
	// Guard the nil receiver before touching s.codec: encodeSingle already
	// does this, and without the check a detached (nil) sender would panic
	// here instead of returning errTransportDetached.
	if s == nil {
		return nil, nil, errTransportDetached
	}
	// Defensive: callers (flushBatch) never pass an empty slice, but
	// indexing frames[0] below must not panic if one ever does.
	if len(frames) == 0 {
		return nil, nil, errTransportDetached
	}
	if len(frames) == 1 || s.codec.encodeBatch == nil {
		return s.encodeSingle(frames[0])
	}
	return s.codec.encodeBatch(frames)
}
// stop shuts the sender down exactly once: it latches a detached-transport
// error, closes stopCh to wake the run loop and any blocked submitters, and
// then waits for the run goroutine to exit.
func (s *bulkBatchSender) stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		s.setErr(errTransportDetached)
		close(s.stopCh)
	})
	// Block until run() has drained and closed doneCh.
	<-s.doneCh
}
// failPending drains every request still sitting in reqCh and completes
// each one with err, returning as soon as the channel is empty.
func (s *bulkBatchSender) failPending(err error) {
	for {
		var req bulkBatchRequest
		select {
		case req = <-s.reqCh:
		default:
			return
		}
		s.finishRequest(req, err)
	}
}
// finishRequest completes a queued request: it drops the queue count,
// releases any cloned payload buffer, and delivers err (possibly nil) on
// the request's buffered done channel.
func (s *bulkBatchSender) finishRequest(req bulkBatchRequest, err error) {
	if s != nil {
		s.queued.Add(-1)
	}
	if release := req.release; release != nil {
		release()
	}
	req.done <- err
}
// setErr latches the first fatal sender error; subsequent errors are
// ignored so errSnapshot always reports the original cause.
func (s *bulkBatchSender) setErr(err error) {
	if s == nil || err == nil {
		return
	}
	s.errMu.Lock()
	defer s.errMu.Unlock()
	if s.err == nil {
		s.err = err
	}
}
// errSnapshot returns the latched fatal error, or errTransportDetached for
// a nil sender; nil means the sender is still healthy.
func (s *bulkBatchSender) errSnapshot() error {
	if s == nil {
		return errTransportDetached
	}
	s.errMu.Lock()
	err := s.err
	s.errMu.Unlock()
	return err
}
// stoppedErr is the error reported to callers after stopCh closes: the
// latched fatal error when one exists, errTransportDetached otherwise.
func (s *bulkBatchSender) stoppedErr() error {
	err := s.errSnapshot()
	if err == nil {
		err = errTransportDetached
	}
	return err
}
// transportWriteDeadline converts the configured write timeout into an
// absolute deadline; the zero time means no provider is configured.
func (s *bulkBatchSender) transportWriteDeadline() time.Time {
	if s == nil {
		return time.Time{}
	}
	provider := s.writeTimeoutProvider
	if provider == nil {
		return time.Time{}
	}
	return writeDeadlineFromTimeout(provider())
}
// transportWriteTimeout returns the configured write timeout, or zero when
// the sender has no timeout provider.
func (s *bulkBatchSender) transportWriteTimeout() time.Duration {
	if s == nil {
		return 0
	}
	provider := s.writeTimeoutProvider
	if provider == nil {
		return 0
	}
	return provider()
}
// sharedMixedPayloadLimit returns the binding's current adaptive soft
// payload cap for mixed batches, falling back to the static default when no
// binding is attached.
func (s *bulkBatchSender) sharedMixedPayloadLimit() int {
	if s == nil {
		return bulkAdaptiveSoftPayloadFallbackBytes
	}
	binding := s.binding
	if binding == nil {
		return bulkAdaptiveSoftPayloadFallbackBytes
	}
	return binding.bulkAdaptiveSoftPayloadBytesSnapshot()
}
// minInt returns the smaller of a and b.
func minInt(a int, b int) int {
	if b < a {
		return b
	}
	return a
}