// notify/stream_batch_sender.go
package notify
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
)
// Batching limits and defaults for the stream batch sender.
const (
	// streamBatchMaxPayloads caps the number of requests coalesced into one flush.
	streamBatchMaxPayloads = 64
	// streamBatchMaxPayloadBytes is a 2 MiB payload-size cap.
	// NOTE(review): not referenced in this file — confirm its use elsewhere.
	streamBatchMaxPayloadBytes = 2 * 1024 * 1024
	// streamBatchMaxFlushDelay is the fallback coalescing window the run loop
	// waits for more requests before flushing a small batch.
	// NOTE(review): 50µs is a very short window — confirm this was not meant
	// to be milliseconds.
	streamBatchMaxFlushDelay = 50 * time.Microsecond
	// streamBatchWaitThreshold is the fallback byte size at or above which a
	// batch skips the coalescing delay (see run / batchWaitThreshold).
	streamBatchWaitThreshold = 128 * 1024
)
// Lifecycle states stored in streamBatchRequestState.value. Transitions are
// one-way CAS from Queued to exactly one of Started or Canceled.
const (
	streamBatchRequestQueued int32 = iota // not yet claimed by a flusher or canceler
	streamBatchRequestStarted             // claimed via tryStart; flush result is authoritative
	streamBatchRequestCanceled            // claimed via tryCancel; submitter returned early
)
// streamBatchRequestState tracks a request's lifecycle (queued → started or
// canceled) with an atomic CAS so the submitting goroutine and the run loop
// can race on the same request safely.
type streamBatchRequestState struct {
	value atomic.Int32
}
// streamBatchRequest is one unit of work queued to the sender. At least one
// payload form is populated: a single frame (hasFrame), a slice of frames,
// and/or an already-encoded payload (hasEncoded); encodeRequests handles all
// forms present on one request, in that order.
type streamBatchRequest struct {
	ctx             context.Context
	frame           streamFastDataFrame // single-frame payload, valid when hasFrame
	hasFrame        bool
	encodedPayload  []byte // pre-encoded wire payload, valid when hasEncoded
	hasEncoded      bool
	frames          []streamFastDataFrame // multi-frame payload
	fastPathVersion uint8                 // normalized fast-path protocol version
	deadline        time.Time             // copied from ctx's deadline at submit time
	done            chan error            // buffered (cap 1) completion channel
	state           *streamBatchRequestState // CAS-guarded lifecycle state; nil for ad-hoc requests
}
// streamBatchSender coalesces outbound stream frames into batched transport
// writes. A single background goroutine (run) drains reqCh; flushMu serializes
// flushes between that goroutine and the direct-submit fast path.
type streamBatchSender struct {
	binding              *transportBinding
	codec                streamBatchCodec
	writeTimeoutProvider func() time.Duration
	reqCh                chan streamBatchRequest // pending requests for the run loop
	stopCh               chan struct{}           // closed exactly once by stop()
	doneCh               chan struct{}           // closed when run() exits
	stopOnce             sync.Once
	flushMu              sync.Mutex   // serializes flush() calls
	queued               atomic.Int64 // requests accepted into reqCh but not yet finished
	errMu                sync.Mutex   // guards err
	err                  error        // first fatal error, latched by setErr
}
// newStreamBatchSender constructs a sender bound to the given transport
// binding and codec, and immediately starts its background batching goroutine.
// The write-timeout provider is consulted on every flush.
func newStreamBatchSender(binding *transportBinding, codec streamBatchCodec, writeTimeoutProvider func() time.Duration) *streamBatchSender {
	s := &streamBatchSender{
		binding:              binding,
		codec:                codec,
		writeTimeoutProvider: writeTimeoutProvider,
		// Generous buffer so submitters rarely block on the queue itself.
		reqCh:  make(chan streamBatchRequest, 4*streamBatchMaxPayloads),
		stopCh: make(chan struct{}),
		doneCh: make(chan struct{}),
	}
	go s.run()
	return s
}
// submitData queues a single data frame for transmission and blocks until it
// is flushed, fails, or ctx is canceled. Empty payloads are a no-op.
func (s *streamBatchSender) submitData(ctx context.Context, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
	switch {
	case s == nil:
		return errTransportDetached
	case len(payload) == 0:
		return nil
	}
	req := streamBatchRequest{
		ctx: ctx,
		frame: streamFastDataFrame{
			DataID:  dataID,
			Seq:     seq,
			Payload: payload,
		},
		hasFrame:        true,
		fastPathVersion: normalizeStreamFastPathVersion(fastPathVersion),
	}
	return s.submitRequest(req)
}
// submitEncoded queues an already-encoded wire payload for transmission and
// blocks until it is flushed, fails, or ctx is canceled. Empty payloads are a
// no-op.
func (s *streamBatchSender) submitEncoded(ctx context.Context, fastPathVersion uint8, payload []byte) error {
	switch {
	case s == nil:
		return errTransportDetached
	case len(payload) == 0:
		return nil
	}
	req := streamBatchRequest{
		ctx:             ctx,
		encodedPayload:  payload,
		hasEncoded:      true,
		fastPathVersion: normalizeStreamFastPathVersion(fastPathVersion),
	}
	return s.submitRequest(req)
}
// submitFrames queues a slice of data frames for transmission and blocks
// until they are flushed, fail, or ctx is canceled. Empty slices are a no-op.
//
// Fix: the original always cloned the caller's slice and then discarded the
// clone in the single-frame case; now the clone is made only when it is
// actually retained.
func (s *streamBatchSender) submitFrames(ctx context.Context, fastPathVersion uint8, frames []streamFastDataFrame) error {
	if s == nil {
		return errTransportDetached
	}
	if len(frames) == 0 {
		return nil
	}
	req := streamBatchRequest{
		ctx:             ctx,
		fastPathVersion: normalizeStreamFastPathVersion(fastPathVersion),
	}
	if len(frames) == 1 {
		// Single frame: copy it by value; no slice clone needed.
		req.frame = frames[0]
		req.hasFrame = true
	} else {
		// Clone so the caller may reuse its slice after we return.
		req.frames = append([]streamFastDataFrame(nil), frames...)
	}
	return s.submitRequest(req)
}
// submitRequest validates and finalizes req, dispatches it (directly or via
// the run loop's queue), and blocks until the request completes or its
// context is canceled.
func (s *streamBatchSender) submitRequest(req streamBatchRequest) error {
	if s == nil {
		return errTransportDetached
	}
	if req.ctx == nil {
		req.ctx = context.Background()
	}
	// Nothing to send: treat as a successful no-op.
	if !req.hasFrame && !req.hasEncoded && len(req.frames) == 0 {
		return nil
	}
	req.fastPathVersion = normalizeStreamFastPathVersion(req.fastPathVersion)
	// Buffered so the run loop can always deliver a result, even after the
	// submitter has stopped listening following a successful cancel.
	req.done = make(chan error, 1)
	req.state = &streamBatchRequestState{}
	if deadline, ok := req.ctx.Deadline(); ok {
		req.deadline = deadline
	}
	if err := s.errSnapshot(); err != nil {
		return err
	}
	// Non-batchable requests may bypass the queue when the sender is idle.
	if s.shouldDirectSubmit(req) {
		if submitted, err := s.tryDirectSubmit(req); submitted {
			return err
		}
	}
	s.queued.Add(1)
	select {
	case <-req.ctx.Done():
		s.queued.Add(-1)
		return normalizeStreamDeadlineError(req.ctx.Err())
	case <-s.stopCh:
		s.queued.Add(-1)
		return s.stoppedErr()
	case s.reqCh <- req:
	}
	select {
	case err := <-req.done:
		return err
	case <-req.ctx.Done():
		// Cancellation wins only if the run loop has not yet claimed the
		// request; otherwise wait for the in-flight flush result.
		if req.tryCancel() {
			return normalizeStreamDeadlineError(req.ctx.Err())
		}
		return <-req.done
	}
}
// shouldDirectSubmit reports whether req is eligible for the inline
// fast path: it must carry frames (not a pre-encoded payload) whose
// fast-path version has no batch support, so queuing gains nothing.
func (s *streamBatchSender) shouldDirectSubmit(req streamBatchRequest) bool {
	if req.hasEncoded {
		return false
	}
	hasPayload := req.hasFrame || len(req.frames) > 0
	return hasPayload && !streamFastPathSupportsBatch(req.fastPathVersion)
}
// tryDirectSubmit attempts to flush a single request inline, bypassing the
// run loop, when nothing is queued and flushMu is free. The first return
// value reports whether the attempt settled the request; when false the
// caller must fall back to the queued path.
func (s *streamBatchSender) tryDirectSubmit(req streamBatchRequest) (bool, error) {
	if s == nil {
		return true, errTransportDetached
	}
	if err := s.errSnapshot(); err != nil {
		return true, err
	}
	select {
	case <-req.ctx.Done():
		return true, normalizeStreamDeadlineError(req.ctx.Err())
	case <-s.stopCh:
		return true, s.stoppedErr()
	default:
	}
	// Preserve ordering: never jump ahead of already-queued requests.
	if s.queued.Load() != 0 {
		return false, nil
	}
	if !s.flushMu.TryLock() {
		return false, nil
	}
	defer s.flushMu.Unlock()
	// Re-check under the lock: a request may have been queued meanwhile.
	if s.queued.Load() != 0 {
		return false, nil
	}
	if err := s.errSnapshot(); err != nil {
		return true, err
	}
	if !req.tryStart() {
		return true, req.canceledErr()
	}
	if err := req.contextErr(); err != nil {
		return true, err
	}
	if err := s.flush([]streamBatchRequest{req}); err != nil {
		// A flush failure is fatal for the sender: latch it and drain the queue.
		s.setErr(err)
		s.failPending(err)
		return true, err
	}
	return true, nil
}
// run is the sender's background loop: it pulls one request, opportunistically
// coalesces more into a batch (bounded by request count, soft byte limit, and
// an adaptive flush delay), then flushes under flushMu. It exits on stop or
// on the first flush error.
func (s *streamBatchSender) run() {
	defer close(s.doneCh)
	for {
		req, ok := s.nextRequest()
		if !ok {
			return
		}
		batch := []streamBatchRequest{req}
		batchBytes := streamBatchRequestApproxBytes(req)
		softPayloadLimit := s.batchSoftPayloadLimit()
		waitThreshold := s.batchWaitThreshold()
		flushDelay := s.batchFlushDelay()
		timer := (*time.Timer)(nil)
		timerCh := (<-chan time.Time)(nil)
		// Only arm the coalescing timer when the first request is small
		// enough that waiting for more work is likely to pay off.
		if flushDelay > 0 && batchBytes < waitThreshold && batchBytes < softPayloadLimit && len(batch) < streamBatchMaxPayloads {
			timer = time.NewTimer(flushDelay)
			timerCh = timer.C
		}
	drain:
		for len(batch) < streamBatchMaxPayloads && batchBytes < softPayloadLimit {
			if timerCh == nil {
				// No timer armed: take whatever is immediately available,
				// then flush without waiting.
				select {
				case <-s.stopCh:
					s.failPending(s.stoppedErr())
					return
				case next := <-s.reqCh:
					batch = append(batch, next)
					batchBytes += streamBatchRequestApproxBytes(next)
				default:
					break drain
				}
				continue
			}
			// Timer armed: block for more requests until the delay elapses.
			select {
			case <-s.stopCh:
				if timer != nil {
					timer.Stop()
				}
				s.failPending(s.stoppedErr())
				return
			case next := <-s.reqCh:
				batch = append(batch, next)
				batchBytes += streamBatchRequestApproxBytes(next)
			case <-timerCh:
				timerCh = nil
				break drain
			}
		}
		if timer != nil {
			// Drain the timer channel if it fired without being received
			// (timerCh != nil means the <-timerCh case never consumed it).
			if !timer.Stop() && timerCh != nil {
				select {
				case <-timer.C:
				default:
				}
			}
		}
		s.flushMu.Lock()
		err := s.errSnapshot()
		// Claim each request; drop those canceled before we got here.
		active := make([]streamBatchRequest, 0, len(batch))
		for _, item := range batch {
			if !item.tryStart() {
				s.finishRequest(item, item.canceledErr())
				continue
			}
			if itemErr := item.contextErr(); itemErr != nil {
				s.finishRequest(item, itemErr)
				continue
			}
			active = append(active, item)
		}
		if len(active) == 0 {
			s.flushMu.Unlock()
			continue
		}
		if err == nil {
			err = s.flush(active)
		}
		s.flushMu.Unlock()
		if err != nil {
			// Fatal: latch the error, fail the batch and everything queued,
			// then exit the loop for good.
			s.setErr(err)
			for _, item := range active {
				s.finishRequest(item, err)
			}
			s.failPending(err)
			return
		}
		for _, item := range active {
			s.finishRequest(item, nil)
		}
	}
}
// nextRequest blocks for the next queued request. It returns ok=false when
// the sender is stopping, after failing everything still queued.
func (s *streamBatchSender) nextRequest() (streamBatchRequest, bool) {
	select {
	case req := <-s.reqCh:
		return req, true
	case <-s.stopCh:
		s.failPending(s.stoppedErr())
		var zero streamBatchRequest
		return zero, false
	}
}
// flush encodes the given requests into wire payloads and writes them to the
// transport connection under the connection write lock, recording
// adaptive-payload telemetry for the write regardless of outcome.
// Callers hold flushMu (see run and tryDirectSubmit).
func (s *streamBatchSender) flush(requests []streamBatchRequest) error {
	if s == nil || s.binding == nil {
		return errTransportDetached
	}
	queue := s.binding.queueSnapshot()
	if queue == nil {
		return errTransportFrameQueueUnavailable
	}
	payloads, err := s.encodeRequests(requests)
	if err != nil {
		return err
	}
	writeTimeout := s.transportWriteTimeout()
	payloadBytes := 0
	for _, payload := range payloads {
		payloadBytes += len(payload)
	}
	started := time.Now()
	err = s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error {
		return writeFramedPayloadBatchUnlocked(conn, queue, payloads)
	})
	s.binding.observeStreamAdaptivePayloadWrite(payloadBytes, time.Since(started), writeTimeout, err)
	return err
}
// transportWriteTimeout returns the configured write timeout, or zero when no
// provider is available.
func (s *streamBatchSender) transportWriteTimeout() time.Duration {
	if s != nil && s.writeTimeoutProvider != nil {
		return s.writeTimeoutProvider()
	}
	return 0
}
// batchSoftPayloadLimit returns the adaptive soft payload byte limit from the
// binding, falling back to the package default without one.
func (s *streamBatchSender) batchSoftPayloadLimit() int {
	if s != nil && s.binding != nil {
		return s.binding.streamAdaptiveSoftPayloadBytesSnapshot()
	}
	return streamAdaptiveSoftPayloadFallbackBytes
}
// batchWaitThreshold returns the adaptive byte size at or above which a batch
// skips the coalescing delay, falling back to the package default.
func (s *streamBatchSender) batchWaitThreshold() int {
	if s != nil && s.binding != nil {
		return s.binding.streamAdaptiveWaitThresholdBytesSnapshot()
	}
	return streamBatchWaitThreshold
}
// batchFlushDelay returns the adaptive coalescing delay from the binding,
// falling back to the package default without one.
func (s *streamBatchSender) batchFlushDelay() time.Duration {
	if s != nil && s.binding != nil {
		return s.binding.streamAdaptiveFlushDelaySnapshot()
	}
	return streamBatchMaxFlushDelay
}
// batchPlainPayloadLimit derives the plain (pre-encoding) byte budget for one
// batch from the soft payload limit, clamped to the wire format's maximum and
// guaranteed to exceed the batch header size by at least one byte.
func (s *streamBatchSender) batchPlainPayloadLimit() int {
	limit := s.batchSoftPayloadLimit()
	if limit > streamFastBatchHeaderLen {
		return minInt(limit, streamFastBatchMaxPlainBytes)
	}
	return streamFastBatchHeaderLen + 1
}
// encodeRequests turns a flushed batch of requests into wire payloads,
// preserving request order. Batchable frames are packed together up to the
// plain-payload limit and item cap; frames whose fast-path version lacks
// batch support, oversized frames, and pre-encoded payloads each become
// standalone payloads.
func (s *streamBatchSender) encodeRequests(requests []streamBatchRequest) ([][]byte, error) {
	if len(requests) == 0 {
		return nil, nil
	}
	payloads := make([][]byte, 0, len(requests))
	batchPlainLimit := s.batchPlainPayloadLimit()
	var batch []streamFastDataFrame
	// flushBatch encodes any accumulated frames into one payload and resets
	// the batch, reusing its backing array.
	flushBatch := func() error {
		if len(batch) == 0 {
			return nil
		}
		payload, err := s.encodeBatch(batch)
		if err != nil {
			return err
		}
		payloads = append(payloads, payload)
		batch = batch[:0]
		return nil
	}
	// Running byte size of the open batch, including the batch header.
	batchBytes := streamFastBatchHeaderLen
	appendFrame := func(frame streamFastDataFrame, fastPathVersion uint8) error {
		if !streamFastPathSupportsBatch(fastPathVersion) {
			// This frame cannot be batched: close the open batch (to keep
			// ordering) and encode the frame on its own.
			if err := flushBatch(); err != nil {
				return err
			}
			batchBytes = streamFastBatchHeaderLen
			payload, err := s.encodeSingle(frame)
			if err != nil {
				return err
			}
			payloads = append(payloads, payload)
			return nil
		}
		frameLen := streamFastBatchFrameLen(frame)
		if frameLen+streamFastBatchHeaderLen > batchPlainLimit {
			// Frame too large to ever fit in any batch: encode it alone.
			if err := flushBatch(); err != nil {
				return err
			}
			batchBytes = streamFastBatchHeaderLen
			payload, err := s.encodeSingle(frame)
			if err != nil {
				return err
			}
			payloads = append(payloads, payload)
			return nil
		}
		// Close the open batch first when adding this frame would exceed the
		// item cap or the byte budget.
		if len(batch) > 0 && (len(batch) >= streamFastBatchMaxItems || batchBytes+frameLen > batchPlainLimit) {
			if err := flushBatch(); err != nil {
				return err
			}
			batchBytes = streamFastBatchHeaderLen
		}
		// Lazily allocate the batch on first use; flushBatch keeps it non-nil
		// afterwards, so this runs at most once.
		if batch == nil {
			batch = make([]streamFastDataFrame, 0, minInt(len(requests), streamFastBatchMaxItems))
		}
		batch = append(batch, frame)
		batchBytes += frameLen
		return nil
	}
	for _, req := range requests {
		if req.hasFrame {
			if err := appendFrame(req.frame, req.fastPathVersion); err != nil {
				return nil, err
			}
		}
		for _, frame := range req.frames {
			if err := appendFrame(frame, req.fastPathVersion); err != nil {
				return nil, err
			}
		}
		if req.hasEncoded {
			// Pre-encoded payloads pass through verbatim, after closing any
			// open batch to preserve ordering.
			if err := flushBatch(); err != nil {
				return nil, err
			}
			batchBytes = streamFastBatchHeaderLen
			payloads = append(payloads, req.encodedPayload)
		}
	}
	// Flush whatever remains in the final open batch.
	if err := flushBatch(); err != nil {
		return nil, err
	}
	return payloads, nil
}
// streamBatchRequestApproxBytes estimates the wire size contributed by req,
// summing every payload form the request carries.
func streamBatchRequestApproxBytes(req streamBatchRequest) int {
	var total int
	if req.hasFrame {
		total = streamFastBatchFrameLen(req.frame)
	}
	for _, f := range req.frames {
		total += streamFastBatchFrameLen(f)
	}
	if req.hasEncoded {
		total += len(req.encodedPayload)
	}
	return total
}
// encodeSingle encodes one frame via the codec's single-frame encoder,
// failing when the sender or the encoder is absent.
func (s *streamBatchSender) encodeSingle(frame streamFastDataFrame) ([]byte, error) {
	if s == nil {
		return nil, errTransportDetached
	}
	if s.codec.encodeSingle == nil {
		return nil, errTransportDetached
	}
	return s.codec.encodeSingle(frame)
}
// encodeBatch encodes a slice of frames, using the codec's batch encoder when
// available and falling back to single-frame encoding for one-element slices
// or codecs without batch support.
//
// Fix: the original dereferenced a nil receiver (s.codec) and indexed
// frames[0] on an empty slice, both panics; guards now match the sibling
// encodeSingle. (Current callers only pass non-empty slices — see flushBatch.)
func (s *streamBatchSender) encodeBatch(frames []streamFastDataFrame) ([]byte, error) {
	if s == nil {
		return nil, errTransportDetached
	}
	if len(frames) == 0 {
		return nil, nil
	}
	if len(frames) == 1 || s.codec.encodeBatch == nil {
		return s.encodeSingle(frames[0])
	}
	return s.codec.encodeBatch(frames)
}
// finishRequest decrements the in-flight counter and delivers the result to
// the submitter. done is buffered (cap 1), so the send never blocks even if
// the submitter already returned after a successful cancel.
func (s *streamBatchSender) finishRequest(req streamBatchRequest, err error) {
	if s != nil {
		s.queued.Add(-1)
	}
	req.done <- err
}
// stop shuts the sender down exactly once: it latches errTransportDetached,
// closes stopCh to wake the run loop and any submitters, then waits for the
// run goroutine to exit. Safe to call multiple times and on a nil receiver.
func (s *streamBatchSender) stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		s.setErr(errTransportDetached)
		close(s.stopCh)
	})
	<-s.doneCh
}
// failPending non-blockingly drains every request still queued in reqCh and
// completes each with err. Used during shutdown and after a fatal flush error.
func (s *streamBatchSender) failPending(err error) {
	for {
		select {
		case item := <-s.reqCh:
			s.finishRequest(item, err)
		default:
			return
		}
	}
}
// setErr latches the first fatal error; later calls are ignored so the
// original cause is preserved. Nil errors and nil receivers are no-ops.
func (s *streamBatchSender) setErr(err error) {
	if s == nil || err == nil {
		return
	}
	s.errMu.Lock()
	defer s.errMu.Unlock()
	if s.err == nil {
		s.err = err
	}
}
// errSnapshot returns the latched fatal error, if any. A nil receiver reports
// errTransportDetached.
func (s *streamBatchSender) errSnapshot() error {
	if s == nil {
		return errTransportDetached
	}
	s.errMu.Lock()
	err := s.err
	s.errMu.Unlock()
	return err
}
// stoppedErr is the error reported to callers after shutdown: the latched
// fatal error when one exists, otherwise errTransportDetached.
func (s *streamBatchSender) stoppedErr() error {
	err := s.errSnapshot()
	if err == nil {
		err = errTransportDetached
	}
	return err
}
// contextErr reports the request context's normalized error if the context is
// already done, and nil otherwise (including for a nil context).
func (r streamBatchRequest) contextErr() error {
	if r.ctx == nil {
		return nil
	}
	select {
	case <-r.ctx.Done():
	default:
		return nil
	}
	return normalizeStreamDeadlineError(r.ctx.Err())
}
// tryStart atomically claims the request for flushing; it loses the race if
// the submitter canceled first. A nil state (request built outside
// submitRequest) is always startable.
func (r streamBatchRequest) tryStart() bool {
	if r.state == nil {
		return true
	}
	return r.state.value.CompareAndSwap(streamBatchRequestQueued, streamBatchRequestStarted)
}
// tryCancel atomically marks the request canceled; it fails if a flusher has
// already claimed the request via tryStart (or if there is no state to claim).
func (r streamBatchRequest) tryCancel() bool {
	if r.state == nil {
		return false
	}
	return r.state.value.CompareAndSwap(streamBatchRequestQueued, streamBatchRequestCanceled)
}
// canceledErr returns the context's normalized error when the context is
// done, falling back to context.Canceled for cancels with a live context.
func (r streamBatchRequest) canceledErr() error {
	if err := r.contextErr(); err != nil {
		return err
	}
	return context.Canceled
}