package notify

import (
	"context"
	"net"
	"sync"
	"sync/atomic"
	"time"
)

// Batching limits and defaults used by streamBatchSender.
//
// streamBatchMaxPayloads caps how many requests run collects into one batch
// (and sizes reqCh at four times that). streamBatchWaitThreshold and
// streamBatchMaxFlushDelay are the fallbacks used by batchWaitThreshold and
// batchFlushDelay when no transportBinding is attached; otherwise the
// binding's adaptive snapshots are used. streamBatchMaxPayloadBytes is not
// referenced in this file (NOTE(review): presumably used elsewhere in the
// package — confirm before removing).
const (
	streamBatchMaxPayloads     = 64
	streamBatchMaxPayloadBytes = 2 * 1024 * 1024
	streamBatchMaxFlushDelay   = 50 * time.Microsecond
	streamBatchWaitThreshold   = 128 * 1024
)

// Lifecycle states of a submitted request. A request leaves Queued exactly
// once, either to Started (claimed by the sender for writing, see tryStart)
// or to Canceled (abandoned by its submitter, see tryCancel). Both
// transitions are compare-and-swap from Queued, so the two sides can never
// both claim the same request.
const (
	streamBatchRequestQueued int32 = iota
	streamBatchRequestStarted
	streamBatchRequestCanceled
)

// streamBatchRequestState holds a request's lifecycle state. It is referenced
// by pointer so every copy of a streamBatchRequest shares the same state.
type streamBatchRequestState struct {
	value atomic.Int32
}

// streamBatchRequest is one unit of work handed to streamBatchSender.
//
// Exactly what gets written depends on which payload fields are set: a single
// frame (frame/hasFrame), several frames (frames), and/or an already encoded
// payload (encodedPayload/hasEncoded) that is written as-is. fastPathVersion
// decides whether the frames may be coalesced into batch payloads.
// done is created with capacity 1 by submitRequest and receives exactly one
// result. deadline mirrors ctx.Deadline() (NOTE(review): not read in this
// file; confirm whether other code uses it).
type streamBatchRequest struct {
	ctx             context.Context
	frame           streamFastDataFrame
	hasFrame        bool
	encodedPayload  []byte
	hasEncoded      bool
	frames          []streamFastDataFrame
	fastPathVersion uint8
	deadline        time.Time
	done            chan error
	state           *streamBatchRequestState
}

// streamBatchSender serializes stream data writes onto a transportBinding,
// coalescing concurrently submitted requests into batched writes.
//
// A single goroutine (run) consumes reqCh. Requests whose fast-path version
// cannot be batched may instead be written directly on the submitter's
// goroutine (tryDirectSubmit). flushMu serializes the two write paths.
// queued counts requests that are being enqueued or are queued but not yet
// completed; the direct path is only taken while it is zero (presumably so a
// direct write never overtakes queued requests). err is sticky: once set,
// new submissions fail with it. stopCh is closed by stop; doneCh is closed
// when run returns.
type streamBatchSender struct {
	binding              *transportBinding
	codec                streamBatchCodec
	writeTimeoutProvider func() time.Duration
	reqCh                chan streamBatchRequest
	stopCh               chan struct{}
	doneCh               chan struct{}
	stopOnce             sync.Once
	flushMu              sync.Mutex
	queued               atomic.Int64
	errMu                sync.Mutex
	err                  error
}

// newStreamBatchSender creates a sender bound to binding and starts its run
// goroutine. writeTimeoutProvider may be nil (no write timeout). Callers must
// eventually call stop to terminate the goroutine.
func newStreamBatchSender(binding *transportBinding, codec streamBatchCodec, writeTimeoutProvider func() time.Duration) *streamBatchSender {
	sender := &streamBatchSender{
		binding:              binding,
		codec:                codec,
		writeTimeoutProvider: writeTimeoutProvider,
		reqCh:                make(chan streamBatchRequest, streamBatchMaxPayloads*4),
		stopCh:               make(chan struct{}),
		doneCh:               make(chan struct{}),
	}
	go sender.run()
	return sender
}

// submitData submits one data frame and blocks until it has been written,
// ctx is done, or the sender stops or fails. An empty payload is a no-op
// that returns nil.
func (s *streamBatchSender) submitData(ctx context.Context, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
	if s == nil {
		return errTransportDetached
	}
	if len(payload) == 0 {
		return nil
	}
	return s.submitRequest(streamBatchRequest{
		ctx: ctx,
		frame: streamFastDataFrame{
			DataID:  dataID,
			Seq:     seq,
			Payload: payload,
		},
		hasFrame:        true,
		fastPathVersion: normalizeStreamFastPathVersion(fastPathVersion),
	})
}

// submitEncoded submits an already encoded payload, which is written as-is
// (never re-encoded or merged into a batch). It blocks like submitData; an
// empty payload is a no-op that returns nil.
func (s *streamBatchSender) submitEncoded(ctx context.Context, fastPathVersion uint8, payload []byte) error {
	if s == nil {
		return errTransportDetached
	}
	if len(payload) == 0 {
		return nil
	}
	return s.submitRequest(streamBatchRequest{
		ctx:             ctx,
		encodedPayload:  payload,
		hasEncoded:      true,
		fastPathVersion: normalizeStreamFastPathVersion(fastPathVersion),
	})
}

// submitFrames submits several frames as one request and blocks like
// submitData. The frames slice is copied, so the caller may reuse it once
// this returns; the frame payloads themselves are not copied. An empty slice
// is a no-op that returns nil.
func (s *streamBatchSender) submitFrames(ctx context.Context, fastPathVersion uint8, frames []streamFastDataFrame) error {
	if s == nil {
		return errTransportDetached
	}
	if len(frames) == 0 {
		return nil
	}
	queuedFrames := append([]streamFastDataFrame(nil), frames...)
	req := streamBatchRequest{
		ctx:             ctx,
		frames:          queuedFrames,
		fastPathVersion: normalizeStreamFastPathVersion(fastPathVersion),
	}
	if len(queuedFrames) == 1 {
		// A lone frame uses the single-frame form, which also makes it
		// eligible for the direct write path (see shouldDirectSubmit).
		req.frame = queuedFrames[0]
		req.frames = nil
		req.hasFrame = true
	}
	return s.submitRequest(req)
}

// submitRequest initializes req's completion state and either writes it
// directly (tryDirectSubmit) or queues it for run, then waits for the result.
//
// It returns the write error (nil on success), the sender's sticky error,
// or the normalized ctx error if ctx finishes before the request is claimed
// for writing. Once the request has been claimed, the write result is
// returned even if ctx finishes meanwhile. A request with no payload is a
// no-op returning nil; a nil ctx is treated as context.Background().
func (s *streamBatchSender) submitRequest(req streamBatchRequest) error {
	if s == nil {
		return errTransportDetached
	}
	if req.ctx == nil {
		req.ctx = context.Background()
	}
	if !req.hasFrame && !req.hasEncoded && len(req.frames) == 0 {
		return nil
	}
	req.fastPathVersion = normalizeStreamFastPathVersion(req.fastPathVersion)
	req.done = make(chan error, 1)
	req.state = &streamBatchRequestState{}
	if deadline, ok := req.ctx.Deadline(); ok {
		req.deadline = deadline
	}
	if err := s.errSnapshot(); err != nil {
		return err
	}
	if s.shouldDirectSubmit(req) {
		if submitted, err := s.tryDirectSubmit(req); submitted {
			return err
		}
	}

	// Count the request before it becomes visible on reqCh so tryDirectSubmit
	// never sees an empty queue while this request is in flight.
	s.queued.Add(1)
	select {
	case <-req.ctx.Done():
		s.queued.Add(-1)
		return normalizeStreamDeadlineError(req.ctx.Err())
	case <-s.stopCh:
		s.queued.Add(-1)
		return s.stoppedErr()
	case <-s.doneCh:
		// run has already exited (after a write failure, or after stop), so
		// nothing would ever dequeue this request.
		s.queued.Add(-1)
		return s.stoppedErr()
	case s.reqCh <- req:
	}

	select {
	case err := <-req.done:
		return err
	case <-req.ctx.Done():
		if req.tryCancel() {
			// Still queued: run (or a drain) will discard it later.
			return normalizeStreamDeadlineError(req.ctx.Err())
		}
		// Already started: the writer will report the real result.
		return <-req.done
	case <-s.doneCh:
		// run has exited. A request enqueued after its final drain would
		// otherwise wait forever, so drain reqCh here. Whichever goroutine
		// received this request (run before exiting, this drain, or a
		// concurrent one) completes it, so done is always delivered.
		s.failPending(s.stoppedErr())
		return <-req.done
	}
}

// shouldDirectSubmit reports whether req is a frame request whose fast-path
// version cannot be batched. Such frames are encoded one by one anyway, so
// queuing them gains nothing. Pre-encoded requests always go through the
// queue.
func (s *streamBatchSender) shouldDirectSubmit(req streamBatchRequest) bool {
	if req.hasEncoded {
		return false
	}
	if !req.hasFrame && len(req.frames) == 0 {
		return false
	}
	return !streamFastPathSupportsBatch(req.fastPathVersion)
}

// tryDirectSubmit tries to write req on the caller's goroutine, bypassing
// reqCh. It only proceeds when nothing is queued and flushMu can be acquired
// without blocking. It returns submitted=false when the caller should fall
// back to queuing; otherwise submitted=true and err is the final result. A
// write failure marks the sender failed and fails every queued request.
func (s *streamBatchSender) tryDirectSubmit(req streamBatchRequest) (bool, error) {
	if s == nil {
		return true, errTransportDetached
	}
	if err := s.errSnapshot(); err != nil {
		return true, err
	}
	select {
	case <-req.ctx.Done():
		return true, normalizeStreamDeadlineError(req.ctx.Err())
	case <-s.stopCh:
		return true, s.stoppedErr()
	default:
	}
	if s.queued.Load() != 0 {
		return false, nil
	}
	if !s.flushMu.TryLock() {
		return false, nil
	}
	defer s.flushMu.Unlock()
	// Re-check under the lock: a request may have been queued meanwhile.
	if s.queued.Load() != 0 {
		return false, nil
	}
	if err := s.errSnapshot(); err != nil {
		return true, err
	}
	if !req.tryStart() {
		return true, req.canceledErr()
	}
	if err := req.contextErr(); err != nil {
		return true, err
	}
	if err := s.flush([]streamBatchRequest{req}); err != nil {
		s.setErr(err)
		s.failPending(err)
		return true, err
	}
	return true, nil
}

// run is the sender's consumer goroutine. It repeatedly takes a request from
// reqCh, gathers a batch behind it, and writes that batch. It exits (closing
// doneCh) when stop is requested or a write fails. Every request it has
// taken, and every request still buffered in reqCh at that point, is
// completed with an error.
func (s *streamBatchSender) run() {
	defer close(s.doneCh)
	for {
		req, ok := s.nextRequest()
		if !ok {
			return
		}
		batch, stopped := s.collectBatch(req)
		if stopped {
			// These requests are already off reqCh, so failPending cannot
			// reach them; complete them here or their submitters would
			// block forever on done.
			err := s.stoppedErr()
			for _, item := range batch {
				s.finishRequest(item, err)
			}
			s.failPending(err)
			return
		}
		if err := s.writeBatch(batch); err != nil {
			s.failPending(err)
			return
		}
	}
}

// collectBatch starts a batch with first and keeps adding queued requests
// until one of these happens:
//   - the count or soft byte limit is reached;
//   - the queue is empty (after waiting up to the adaptive flush delay, but
//     only while the batch is still below the wait threshold);
//   - stop is requested.
//
// stopped reports the last case. The returned batch is owned by the caller
// either way.
func (s *streamBatchSender) collectBatch(first streamBatchRequest) (batch []streamBatchRequest, stopped bool) {
	batch = []streamBatchRequest{first}
	batchBytes := streamBatchRequestApproxBytes(first)
	softPayloadLimit := s.batchSoftPayloadLimit()
	waitThreshold := s.batchWaitThreshold()
	flushDelay := s.batchFlushDelay()

	// Only small batches are worth delaying for; a nil timerCh means "take
	// whatever is already queued and flush".
	var timerCh <-chan time.Time
	if flushDelay > 0 && batchBytes < waitThreshold && batchBytes < softPayloadLimit && len(batch) < streamBatchMaxPayloads {
		timer := time.NewTimer(flushDelay)
		defer timer.Stop()
		timerCh = timer.C
	}

	for len(batch) < streamBatchMaxPayloads && batchBytes < softPayloadLimit {
		var next streamBatchRequest
		if timerCh == nil {
			select {
			case <-s.stopCh:
				return batch, true
			case next = <-s.reqCh:
			default:
				return batch, false
			}
		} else {
			select {
			case <-s.stopCh:
				return batch, true
			case next = <-s.reqCh:
			case <-timerCh:
				return batch, false
			}
		}
		batch = append(batch, next)
		batchBytes += streamBatchRequestApproxBytes(next)
	}
	return batch, false
}

// writeBatch claims each request in batch (requests canceled or expired in
// the meantime are completed with their own error) and writes the rest under
// flushMu. It then completes the written requests with the result. A non-nil
// return means the sender is now failed (the error is recorded via setErr)
// and run must exit.
func (s *streamBatchSender) writeBatch(batch []streamBatchRequest) error {
	s.flushMu.Lock()
	err := s.errSnapshot()
	active := make([]streamBatchRequest, 0, len(batch))
	for _, item := range batch {
		if !item.tryStart() {
			s.finishRequest(item, item.canceledErr())
			continue
		}
		if itemErr := item.contextErr(); itemErr != nil {
			s.finishRequest(item, itemErr)
			continue
		}
		active = append(active, item)
	}
	if len(active) == 0 {
		s.flushMu.Unlock()
		return nil
	}
	if err == nil {
		err = s.flush(active)
	}
	s.flushMu.Unlock()

	if err != nil {
		s.setErr(err)
	}
	for _, item := range active {
		s.finishRequest(item, err)
	}
	return err
}

// nextRequest blocks for the next queued request. Once stop is requested it
// fails the buffered requests and reports false. Because select picks
// randomly among ready cases, a buffered request may still be returned first.
func (s *streamBatchSender) nextRequest() (streamBatchRequest, bool) {
	select {
	case <-s.stopCh:
		s.failPending(s.stoppedErr())
		return streamBatchRequest{}, false
	case req := <-s.reqCh:
		return req, true
	}
}

// flush encodes requests and writes all resulting payloads under the
// binding's connection write lock, with a deadline derived from the write
// timeout. The payload size, elapsed time, timeout and result are reported
// to the binding's adaptive observer.
func (s *streamBatchSender) flush(requests []streamBatchRequest) error {
	if s == nil || s.binding == nil {
		return errTransportDetached
	}
	queue := s.binding.queueSnapshot()
	if queue == nil {
		return errTransportFrameQueueUnavailable
	}
	payloads, err := s.encodeRequests(requests)
	if err != nil {
		return err
	}
	writeTimeout := s.transportWriteTimeout()
	payloadBytes := 0
	for _, payload := range payloads {
		payloadBytes += len(payload)
	}
	started := time.Now()
	err = s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error {
		return writeFramedPayloadBatchUnlocked(conn, queue, payloads)
	})
	s.binding.observeStreamAdaptivePayloadWrite(payloadBytes, time.Since(started), writeTimeout, err)
	return err
}

// transportWriteTimeout returns the current write timeout, or 0 when no
// provider is configured.
func (s *streamBatchSender) transportWriteTimeout() time.Duration {
	if s == nil || s.writeTimeoutProvider == nil {
		return 0
	}
	return s.writeTimeoutProvider()
}

// batchSoftPayloadLimit returns the binding's adaptive soft payload limit, or
// streamAdaptiveSoftPayloadFallbackBytes when no binding is attached.
func (s *streamBatchSender) batchSoftPayloadLimit() int {
	if s == nil || s.binding == nil {
		return streamAdaptiveSoftPayloadFallbackBytes
	}
	return s.binding.streamAdaptiveSoftPayloadBytesSnapshot()
}

// batchWaitThreshold returns the binding's adaptive wait threshold, or
// streamBatchWaitThreshold when no binding is attached.
func (s *streamBatchSender) batchWaitThreshold() int {
	if s == nil || s.binding == nil {
		return streamBatchWaitThreshold
	}
	return s.binding.streamAdaptiveWaitThresholdBytesSnapshot()
}

// batchFlushDelay returns the binding's adaptive flush delay, or
// streamBatchMaxFlushDelay when no binding is attached.
func (s *streamBatchSender) batchFlushDelay() time.Duration {
	if s == nil || s.binding == nil {
		return streamBatchMaxFlushDelay
	}
	return s.binding.streamAdaptiveFlushDelaySnapshot()
}

// batchPlainPayloadLimit returns the maximum plain size of one batch payload.
// It is the soft payload limit, capped at streamFastBatchMaxPlainBytes and
// kept at least one byte larger than the batch header.
func (s *streamBatchSender) batchPlainPayloadLimit() int {
	limit := s.batchSoftPayloadLimit()
	if limit <= streamFastBatchHeaderLen {
		return streamFastBatchHeaderLen + 1
	}
	return minInt(limit, streamFastBatchMaxPlainBytes)
}

// encodeRequests turns requests into wire payloads, preserving request order.
//
// Consecutive batchable frames are coalesced into batch payloads of at most
// streamFastBatchMaxItems frames and batchPlainPayloadLimit bytes. These
// frames are encoded individually:
//   - frames whose fast-path version does not support batching;
//   - frames too large to fit in a batch on their own;
//   - all frames, when the codec has no batch encoder.
//
// Pre-encoded payloads are emitted as-is, after any pending batch.
func (s *streamBatchSender) encodeRequests(requests []streamBatchRequest) ([][]byte, error) {
	if len(requests) == 0 {
		return nil, nil
	}
	payloads := make([][]byte, 0, len(requests))
	batchPlainLimit := s.batchPlainPayloadLimit()
	// Without a batch encoder, a multi-frame batch cannot be represented as
	// one payload, so every frame must go out on its own.
	canBatch := s != nil && s.codec.encodeBatch != nil

	var batch []streamFastDataFrame
	batchBytes := streamFastBatchHeaderLen

	// flushBatch emits the pending batch (if any) and starts a new one.
	flushBatch := func() error {
		if len(batch) == 0 {
			return nil
		}
		payload, err := s.encodeBatch(batch)
		if err != nil {
			return err
		}
		payloads = append(payloads, payload)
		batch = batch[:0]
		batchBytes = streamFastBatchHeaderLen
		return nil
	}

	// appendSingle emits frame as its own payload, after any pending batch
	// so ordering is preserved.
	appendSingle := func(frame streamFastDataFrame) error {
		if err := flushBatch(); err != nil {
			return err
		}
		payload, err := s.encodeSingle(frame)
		if err != nil {
			return err
		}
		payloads = append(payloads, payload)
		return nil
	}

	appendFrame := func(frame streamFastDataFrame, fastPathVersion uint8) error {
		if !canBatch || !streamFastPathSupportsBatch(fastPathVersion) {
			return appendSingle(frame)
		}
		frameLen := streamFastBatchFrameLen(frame)
		if frameLen+streamFastBatchHeaderLen > batchPlainLimit {
			return appendSingle(frame)
		}
		if len(batch) > 0 && (len(batch) >= streamFastBatchMaxItems || batchBytes+frameLen > batchPlainLimit) {
			if err := flushBatch(); err != nil {
				return err
			}
		}
		if batch == nil {
			batch = make([]streamFastDataFrame, 0, minInt(len(requests), streamFastBatchMaxItems))
		}
		batch = append(batch, frame)
		batchBytes += frameLen
		return nil
	}

	for _, req := range requests {
		if req.hasFrame {
			if err := appendFrame(req.frame, req.fastPathVersion); err != nil {
				return nil, err
			}
		}
		for _, frame := range req.frames {
			if err := appendFrame(frame, req.fastPathVersion); err != nil {
				return nil, err
			}
		}
		if req.hasEncoded {
			if err := flushBatch(); err != nil {
				return nil, err
			}
			payloads = append(payloads, req.encodedPayload)
		}
	}
	if err := flushBatch(); err != nil {
		return nil, err
	}
	return payloads, nil
}

// streamBatchRequestApproxBytes estimates req's contribution to a batch. Each
// frame is counted by its batch frame length; an encoded payload by its
// byte length.
func streamBatchRequestApproxBytes(req streamBatchRequest) int {
	total := 0
	if req.hasFrame {
		total += streamFastBatchFrameLen(req.frame)
	}
	for _, frame := range req.frames {
		total += streamFastBatchFrameLen(frame)
	}
	if req.hasEncoded {
		total += len(req.encodedPayload)
	}
	return total
}

// encodeSingle encodes one frame as a standalone payload.
func (s *streamBatchSender) encodeSingle(frame streamFastDataFrame) ([]byte, error) {
	if s == nil || s.codec.encodeSingle == nil {
		return nil, errTransportDetached
	}
	return s.codec.encodeSingle(frame)
}

// encodeBatch encodes frames as one payload. A single frame uses the
// single-frame encoding. Several frames require the codec's batch encoder.
// If it is missing, an error is returned instead of silently dropping every
// frame after the first.
func (s *streamBatchSender) encodeBatch(frames []streamFastDataFrame) ([]byte, error) {
	if len(frames) == 1 {
		return s.encodeSingle(frames[0])
	}
	if s == nil || s.codec.encodeBatch == nil {
		return nil, errTransportDetached
	}
	return s.codec.encodeBatch(frames)
}

// finishRequest delivers err as req's result and releases its queued slot.
// done has capacity 1, so this never blocks as long as each queued request
// is finished exactly once.
func (s *streamBatchSender) finishRequest(req streamBatchRequest, err error) {
	if s != nil {
		s.queued.Add(-1)
	}
	req.done <- err
}

// stop marks the sender failed with errTransportDetached, signals run to
// exit, and waits until it has. It is safe to call more than once.
func (s *streamBatchSender) stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		s.setErr(errTransportDetached)
		close(s.stopCh)
	})
	<-s.doneCh
}

// failPending completes every request currently buffered in reqCh with err,
// without blocking.
func (s *streamBatchSender) failPending(err error) {
	for {
		select {
		case item := <-s.reqCh:
			s.finishRequest(item, err)
		default:
			return
		}
	}
}

// setErr records err as the sender's sticky error. Only the first non-nil
// error is kept.
func (s *streamBatchSender) setErr(err error) {
	if s == nil || err == nil {
		return
	}
	s.errMu.Lock()
	if s.err == nil {
		s.err = err
	}
	s.errMu.Unlock()
}

// errSnapshot returns the sticky error, or nil while the sender is healthy.
func (s *streamBatchSender) errSnapshot() error {
	if s == nil {
		return errTransportDetached
	}
	s.errMu.Lock()
	defer s.errMu.Unlock()
	return s.err
}

// stoppedErr returns the sticky error, defaulting to errTransportDetached.
func (s *streamBatchSender) stoppedErr() error {
	if err := s.errSnapshot(); err != nil {
		return err
	}
	return errTransportDetached
}

// contextErr returns the normalized ctx error if ctx is already done, and
// nil otherwise (including for a nil ctx). It never blocks.
func (r streamBatchRequest) contextErr() error {
	if r.ctx == nil {
		return nil
	}
	select {
	case <-r.ctx.Done():
		return normalizeStreamDeadlineError(r.ctx.Err())
	default:
		return nil
	}
}

// tryStart claims the request for writing. It fails if the submitter already
// canceled it. A request without state is always startable.
func (r streamBatchRequest) tryStart() bool {
	if r.state == nil {
		return true
	}
	return r.state.value.CompareAndSwap(streamBatchRequestQueued, streamBatchRequestStarted)
}

// tryCancel abandons a still-queued request. It fails if the request was
// already started, or has no state.
func (r streamBatchRequest) tryCancel() bool {
	if r.state == nil {
		return false
	}
	return r.state.value.CompareAndSwap(streamBatchRequestQueued, streamBatchRequestCanceled)
}

// canceledErr is the result reported for a request that could not be
// started: its ctx error if any, otherwise context.Canceled.
func (r streamBatchRequest) canceledErr() error {
	if err := r.contextErr(); err != nil {
		return err
	}
	return context.Canceled
}