package notify

import (
	"context"
	"net"
	"sync"
	"sync/atomic"
	"time"
)

const (
	// bulkBatchMaxPayloads caps how many requests the run loop gathers before flushing.
	bulkBatchMaxPayloads = 64
	// bulkBatchMaxPayloadBytes caps the approximate framed bytes gathered before flushing.
	bulkBatchMaxPayloadBytes = bulkFastBatchMaxPlainBytes
	// bulkBatchMaxFlushDelay bounds how long the run loop waits for more
	// shareable requests before flushing a partially filled batch.
	bulkBatchMaxFlushDelay = 50 * time.Microsecond
)

// Lifecycle of a queued request. Both transitions are CompareAndSwap operations
// out of bulkBatchRequestQueued, so exactly one of "started by the run loop" and
// "canceled by the submitter" can win.
const (
	bulkBatchRequestQueued int32 = iota
	bulkBatchRequestStarted
	bulkBatchRequestCanceled
)

// bulkBatchRequestState holds a request's lifecycle value; it is shared by
// pointer between the submitting goroutine and the run loop.
type bulkBatchRequestState struct {
	value atomic.Int32
}

// bulkBatchCodec turns frames into wire payloads. Each encoder returns the
// payload plus an optional release callback, which is invoked once the payload
// has been written or abandoned.
type bulkBatchCodec struct {
	encodeSingle func(bulkFastFrame) ([]byte, func(), error)
	encodeBatch  func([]bulkFastFrame) ([]byte, func(), error)
}

// bulkBatchRequest is one submitter's unit of work.
//
//   - payloadOwned skips the defensive payload copy made when the request is
//     queued (see cloneQueuedBulkBatchRequest).
//   - deadline is captured from ctx at submit time.
//   - done has capacity one and receives exactly one result.
//   - release returns the pooled payload copy, if one was made.
type bulkBatchRequest struct {
	ctx             context.Context
	frames          []bulkFastFrame
	fastPathVersion uint8
	payloadOwned    bool
	deadline        time.Time
	done            chan error
	state           *bulkBatchRequestState
	release         func()
}

// bulkBatchEncodedPayload is an encoded wire payload together with the
// callback that releases its backing storage.
type bulkBatchEncodedPayload struct {
	payload []byte
	release func()
}

// done invokes the release callback at most once; it is safe on a nil receiver.
func (p *bulkBatchEncodedPayload) done() {
	if p == nil || p.release == nil {
		return
	}
	p.release()
	p.release = nil
}

// bulkBatchSender serializes bulk fast-path frames onto a transport binding.
// Shareable data requests are funneled through reqCh to a single run goroutine
// that coalesces them into larger writes. Other requests are written directly by
// the caller when nothing is queued ahead of them and the write lock is free.
//
// flushMu serializes writes between the run loop and direct submitters.
// queued counts requests handed to reqCh that have not been finished yet;
// direct submission only proceeds while it is zero, so ordering is preserved.
// err is the first fatal error; once set, new submissions fail fast.
type bulkBatchSender struct {
	binding              *transportBinding
	codec                bulkBatchCodec
	writeTimeoutProvider func() time.Duration
	reqCh                chan bulkBatchRequest
	stopCh               chan struct{} // closed by stop
	doneCh               chan struct{} // closed when run returns
	stopOnce             sync.Once
	flushMu              sync.Mutex
	queued               atomic.Int64
	errMu                sync.Mutex
	err                  error
}

// newBulkBatchSender creates a sender and starts its run goroutine; call stop
// to shut it down.
func newBulkBatchSender(binding *transportBinding, codec bulkBatchCodec, writeTimeoutProvider func() time.Duration) *bulkBatchSender {
	sender := &bulkBatchSender{
		binding:              binding,
		codec:                codec,
		writeTimeoutProvider: writeTimeoutProvider,
		reqCh:                make(chan bulkBatchRequest, bulkBatchMaxPayloads*4),
		stopCh:               make(chan struct{}),
		doneCh:               make(chan struct{}),
	}
	go sender.run()
	return sender
}

// submitData sends a single data frame; the payload is copied if it must be queued.
func (s *bulkBatchSender) submitData(ctx context.Context, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
	return s.submitFramesOwned(ctx, []bulkFastFrame{{
		Type:    bulkFastPayloadTypeData,
		DataID:  dataID,
		Seq:     seq,
		Payload: payload,
	}}, fastPathVersion, false)
}

// submitControl sends a single frame of an arbitrary type; the payload is
// copied if it must be queued.
func (s *bulkBatchSender) submitControl(ctx context.Context, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
	return s.submitFramesOwned(ctx, []bulkFastFrame{{
		Type:    frameType,
		Flags:   flags,
		DataID:  dataID,
		Seq:     seq,
		Payload: payload,
	}}, fastPathVersion, false)
}

// submitWrite splits payload into data frames of at most chunkSize bytes
// (defaultBulkChunkSize when chunkSize <= 0), numbered from startSeq. Frames are
// grouped into submissions bounded by bulkFastBatchMaxItems and
// bulkFastBatchMaxPlainBytes. It returns the number of payload bytes submitted.
// On error, that count covers only the submissions that completed before the
// failing one.
func (s *bulkBatchSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, fastPathVersion uint8, payload []byte, chunkSize int, payloadOwned bool) (int, error) {
	if s == nil {
		return 0, errTransportDetached
	}
	if len(payload) == 0 {
		return 0, nil
	}
	if chunkSize <= 0 {
		chunkSize = defaultBulkChunkSize
	}
	written := 0
	seq := startSeq
	for written < len(payload) {
		var batch [bulkFastBatchMaxItems]bulkFastFrame
		frames := batch[:0]
		batchBytes := bulkFastBatchHeaderLen
		start := written
		for written < len(payload) && len(frames) < bulkFastBatchMaxItems {
			end := written + chunkSize
			if end > len(payload) {
				end = len(payload)
			}
			frame := bulkFastFrame{
				Type:    bulkFastPayloadTypeData,
				DataID:  dataID,
				Seq:     seq,
				Payload: payload[written:end],
			}
			frameLen := bulkFastBatchFrameLen(frame)
			// The first frame is always accepted, even if it alone exceeds the limit.
			if len(frames) > 0 && batchBytes+frameLen > bulkFastBatchMaxPlainBytes {
				break
			}
			frames = append(frames, frame)
			batchBytes += frameLen
			seq++
			written = end
		}
		// Defensive: only reachable if bulkFastBatchMaxItems is zero.
		if len(frames) == 0 {
			end := written + chunkSize
			if end > len(payload) {
				end = len(payload)
			}
			frames = append(frames, bulkFastFrame{
				Type:    bulkFastPayloadTypeData,
				DataID:  dataID,
				Seq:     seq,
				Payload: payload[written:end],
			})
			seq++
			written = end
		}
		if err := s.submitFramesOwned(ctx, frames, fastPathVersion, payloadOwned); err != nil {
			return start, err
		}
	}
	return written, nil
}

// submitFrames sends frames whose payloads are not owned by the sender.
func (s *bulkBatchSender) submitFrames(ctx context.Context, frames []bulkFastFrame, fastPathVersion uint8) error {
	return s.submitFramesOwned(ctx, frames, fastPathVersion, false)
}

// submitFramesOwned sends frames and blocks until one of these happens: they are
// written, ctx is done, or the sender fails or stops. When payloadOwned is false
// and the request has to be queued, payloads are copied first.
func (s *bulkBatchSender) submitFramesOwned(ctx context.Context, frames []bulkFastFrame, fastPathVersion uint8, payloadOwned bool) error {
	if s == nil {
		return errTransportDetached
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if len(frames) == 0 {
		return nil
	}
	req := bulkBatchRequest{
		ctx:             ctx,
		frames:          frames,
		fastPathVersion: normalizeBulkFastPathVersion(fastPathVersion),
		payloadOwned:    payloadOwned,
		done:            make(chan error, 1),
		state:           &bulkBatchRequestState{},
	}
	if deadline, ok := ctx.Deadline(); ok {
		req.deadline = deadline
	}
	if err := s.errSnapshot(); err != nil {
		return err
	}
	if s.shouldDirectSubmit(req) {
		if submitted, err := s.tryDirectSubmit(req); submitted {
			return err
		}
	}

	req = cloneQueuedBulkBatchRequest(req)
	s.queued.Add(1)
	enqueued := false
	var enqueueErr error
	select {
	case s.reqCh <- req:
		enqueued = true
	case <-ctx.Done():
		enqueueErr = normalizeStreamDeadlineError(ctx.Err())
	case <-s.stopCh:
		enqueueErr = s.stoppedErr()
	case <-s.doneCh:
		// The run loop already exited (stop or a fatal write error) and will
		// never receive from reqCh again.
		enqueueErr = s.stoppedErr()
	}
	if !enqueued {
		s.queued.Add(-1)
		if req.release != nil {
			req.release()
		}
		return enqueueErr
	}

	select {
	case err := <-req.done:
		return err
	case <-ctx.Done():
		if req.tryCancel() {
			// Whoever later receives req from reqCh sees the canceled state and
			// finishes it (releasing its copy) without writing it.
			return normalizeStreamDeadlineError(ctx.Err())
		}
		return <-req.done
	case <-s.doneCh:
		// The run loop has exited after finishing every request it took, so req
		// is either already finished or still sitting in reqCh. Nobody else will
		// drain the queue, so do it here instead of waiting forever.
		s.failPending(s.stoppedErr())
		return <-req.done
	}
}

// shouldDirectSubmit reports whether req should try to bypass the run loop:
// requests that cannot join a shared super-batch gain nothing from queuing.
func (s *bulkBatchSender) shouldDirectSubmit(req bulkBatchRequest) bool {
	if len(req.frames) == 0 {
		return false
	}
	return !bulkBatchRequestSupportsSharedSuperBatch(req)
}

// tryDirectSubmit writes req on the caller's goroutine when nothing is queued
// and the write lock is free. The bool reports whether req was handled; when it
// is false, the caller must queue req instead. queued is re-checked under
// flushMu so a direct write never overtakes already-queued requests.
func (s *bulkBatchSender) tryDirectSubmit(req bulkBatchRequest) (bool, error) {
	if s == nil {
		return true, errTransportDetached
	}
	if err := s.errSnapshot(); err != nil {
		return true, err
	}
	select {
	case <-req.ctx.Done():
		return true, normalizeStreamDeadlineError(req.ctx.Err())
	case <-s.stopCh:
		return true, s.stoppedErr()
	default:
	}
	if s.queued.Load() != 0 {
		return false, nil
	}
	if !s.flushMu.TryLock() {
		return false, nil
	}
	defer s.flushMu.Unlock()
	if s.queued.Load() != 0 {
		return false, nil
	}
	if err := s.errSnapshot(); err != nil {
		return true, err
	}
	if !req.tryStart() {
		return true, req.canceledErr()
	}
	if err := req.contextErr(); err != nil {
		return true, err
	}
	err := s.flush([]bulkBatchRequest{req})
	if err != nil {
		s.setErr(err)
		s.failPending(err)
		return true, err
	}
	return true, nil
}

// run is the sender's single consumer of reqCh. It gathers requests into
// batches and flushes them until the sender is stopped or a write fails. In
// either case every request it has received is finished before it returns, so
// doneCh closing means no request is stranded inside the loop.
func (s *bulkBatchSender) run() {
	defer close(s.doneCh)
	for {
		req, ok := s.nextRequest()
		if !ok {
			return
		}
		batch, ok := s.gatherBatch(req)
		if !ok {
			return
		}
		if err := s.flushGathered(batch); err != nil {
			s.failPending(err)
			return
		}
	}
}

// gatherBatch collects further queued requests behind first. Collection stops
// at the request or byte limit, or when nothing more is immediately queued. If
// the batch is shareable, it instead waits up to bulkBatchMaxFlushDelay for more.
// It returns false if the sender was stopped meanwhile; in that case every
// gathered request and everything still queued has been finished with the stop
// error.
func (s *bulkBatchSender) gatherBatch(first bulkBatchRequest) ([]bulkBatchRequest, bool) {
	batch := []bulkBatchRequest{first}
	batchBytes := bulkBatchRequestApproxBytes(first)
	var timer *time.Timer
	var timerCh <-chan time.Time
	if bulkBatchShouldWaitForMore(batch, batchBytes) {
		timer = time.NewTimer(bulkBatchMaxFlushDelay)
		timerCh = timer.C
	}
	stopped := false
gather:
	for len(batch) < bulkBatchMaxPayloads && batchBytes < bulkBatchMaxPayloadBytes {
		var next bulkBatchRequest
		if timerCh == nil {
			select {
			case <-s.stopCh:
				stopped = true
				break gather
			case next = <-s.reqCh:
			default:
				break gather
			}
		} else {
			select {
			case <-s.stopCh:
				stopped = true
				break gather
			case next = <-s.reqCh:
			case <-timerCh:
				break gather
			}
		}
		batch = append(batch, next)
		batchBytes += bulkBatchRequestApproxBytes(next)
	}
	if timer != nil {
		// The timer is discarded, so a pending tick need not be drained.
		timer.Stop()
	}
	if stopped {
		err := s.stoppedErr()
		for _, item := range batch {
			s.finishRequest(item, err)
		}
		s.failPending(err)
		return nil, false
	}
	return batch, true
}

// flushGathered starts each request in batch under flushMu and writes those
// still live. It then finishes every request in batch. A non-nil result is the
// fatal error the batch failed with; it has already been recorded via setErr.
func (s *bulkBatchSender) flushGathered(batch []bulkBatchRequest) error {
	s.flushMu.Lock()
	err := s.errSnapshot()
	active := make([]bulkBatchRequest, 0, len(batch))
	for _, item := range batch {
		if !item.tryStart() {
			s.finishRequest(item, item.canceledErr())
			continue
		}
		if itemErr := item.contextErr(); itemErr != nil {
			s.finishRequest(item, itemErr)
			continue
		}
		active = append(active, item)
	}
	if len(active) == 0 {
		s.flushMu.Unlock()
		return nil
	}
	if err == nil {
		err = s.flush(active)
	}
	s.flushMu.Unlock()
	if err != nil {
		s.setErr(err)
	}
	for _, item := range active {
		s.finishRequest(item, err)
	}
	return err
}

// nextRequest blocks for the next queued request. It returns false once the
// sender is stopped, after failing whatever is still queued.
func (s *bulkBatchSender) nextRequest() (bulkBatchRequest, bool) {
	select {
	case <-s.stopCh:
		s.failPending(s.stoppedErr())
		return bulkBatchRequest{}, false
	case req := <-s.reqCh:
		return req, true
	}
}

// contextErr returns the request context's (normalized) error if it is done.
func (r bulkBatchRequest) contextErr() error {
	if r.ctx == nil {
		return nil
	}
	select {
	case <-r.ctx.Done():
		return normalizeStreamDeadlineError(r.ctx.Err())
	default:
		return nil
	}
}

// tryStart claims the request for writing; it fails if it was canceled first.
func (r bulkBatchRequest) tryStart() bool {
	if r.state == nil {
		return true
	}
	return r.state.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestStarted)
}

// tryCancel withdraws the request; it fails if writing already started.
func (r bulkBatchRequest) tryCancel() bool {
	if r.state == nil {
		return false
	}
	return r.state.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestCanceled)
}

// canceledErr is the error reported for a request that lost the start race.
func (r bulkBatchRequest) canceledErr() error {
	if err := r.contextErr(); err != nil {
		return err
	}
	return context.Canceled
}

// flush encodes requests and writes all resulting payloads in one locked write
// on the binding's connection. It reports the write to the adaptive sizing
// observer. Encoded payloads are always released before returning.
func (s *bulkBatchSender) flush(requests []bulkBatchRequest) error {
	if s == nil || s.binding == nil {
		return errTransportDetached
	}
	queue := s.binding.queueSnapshot()
	if queue == nil {
		return errTransportFrameQueueUnavailable
	}
	payloads, err := s.encodeRequests(requests)
	if err != nil {
		return err
	}
	defer func() {
		for index := range payloads {
			payloads[index].done()
		}
	}()
	writeTimeout := s.transportWriteTimeout()
	frames := make([][]byte, 0, len(payloads))
	payloadBytes := 0
	for _, payload := range payloads {
		frames = append(frames, payload.payload)
		payloadBytes += len(payload.payload)
	}
	started := time.Now()
	err = s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error {
		return writeFramedPayloadBatchUnlocked(conn, queue, frames)
	})
	s.binding.observeBulkAdaptivePayloadWrite(payloadBytes, time.Since(started), writeTimeout, err)
	return err
}

// encodeRequests converts the frames of requests, in order, into wire payloads.
// Consecutive frames of shared-batch capable requests are packed into batch
// payloads bounded by bulkFastBatchMaxItems and bulkFastBatchMaxPlainBytes. A
// batch becomes "mixed" when it would combine frames from different requests or
// different data IDs. A mixed batch is further capped by sharedMixedPayloadLimit
// when that limit is positive and smaller. All other frames are encoded singly,
// as are all frames when the codec has no batch encoder. On error, every payload
// encoded so far is released and nil is returned.
func (s *bulkBatchSender) encodeRequests(requests []bulkBatchRequest) ([]bulkBatchEncodedPayload, error) {
	if len(requests) == 0 {
		return nil, nil
	}
	if s == nil {
		return nil, errTransportDetached
	}
	payloads := make([]bulkBatchEncodedPayload, 0, len(requests))
	batch := make([]bulkFastFrame, 0, minInt(len(requests), bulkFastBatchMaxItems))
	mixedBatchLimit := s.sharedMixedPayloadLimit()
	canBatch := s.codec.encodeBatch != nil

	batchBytes := bulkFastBatchHeaderLen
	batchRequestIndex := -1
	batchDataID := uint64(0)
	batchMixed := false

	// fail releases everything encoded so far; the caller never sees it.
	fail := func(err error) ([]bulkBatchEncodedPayload, error) {
		for index := range payloads {
			payloads[index].done()
		}
		return nil, err
	}
	emit := func(payload []byte, release func(), err error) error {
		if err != nil {
			return err
		}
		payloads = append(payloads, bulkBatchEncodedPayload{payload: payload, release: release})
		return nil
	}
	flushBatch := func() error {
		if len(batch) == 0 {
			return nil
		}
		if len(batch) > 1 && !canBatch {
			// Without a batch encoder each frame needs its own payload; passing
			// the slice to encodeBatch would otherwise lose all but one frame.
			for _, frame := range batch {
				if err := emit(s.encodeSingle(frame)); err != nil {
					return err
				}
			}
		} else if err := emit(s.encodeBatch(batch)); err != nil {
			return err
		}
		batch = batch[:0]
		batchBytes = bulkFastBatchHeaderLen
		batchRequestIndex = -1
		batchDataID = 0
		batchMixed = false
		return nil
	}

	for reqIndex, req := range requests {
		shared := bulkFastPathSupportsSharedBatch(req.fastPathVersion)
		for _, frame := range req.frames {
			frameLen := bulkFastBatchFrameLen(frame)
			if !shared || frameLen+bulkFastBatchHeaderLen > bulkFastBatchMaxPlainBytes {
				// This frame cannot share a batch: emit what is pending, then it alone.
				if err := flushBatch(); err != nil {
					return fail(err)
				}
				if err := emit(s.encodeSingle(frame)); err != nil {
					return fail(err)
				}
				continue
			}
			joinsOther := len(batch) > 0 &&
				(batchRequestIndex != reqIndex || (batchDataID != 0 && batchDataID != frame.DataID))
			batchLimit := bulkFastBatchMaxPlainBytes
			if (batchMixed || joinsOther) && mixedBatchLimit > 0 && mixedBatchLimit < batchLimit {
				batchLimit = mixedBatchLimit
			}
			if len(batch) > 0 && (len(batch) >= bulkFastBatchMaxItems || batchBytes+frameLen > batchLimit) {
				if err := flushBatch(); err != nil {
					return fail(err)
				}
			}
			if len(batch) == 0 {
				batchRequestIndex = reqIndex
				batchDataID = frame.DataID
				batchMixed = false
			} else if joinsOther {
				batchMixed = true
			}
			batch = append(batch, frame)
			batchBytes += frameLen
		}
	}
	if err := flushBatch(); err != nil {
		return fail(err)
	}
	return payloads, nil
}

// bulkBatchRequestApproxBytes sums the batch-framed length of req's frames.
func bulkBatchRequestApproxBytes(req bulkBatchRequest) int {
	total := 0
	for _, frame := range req.frames {
		total += bulkFastBatchFrameLen(frame)
	}
	return total
}

// bulkBatchRequestSupportsSharedSuperBatch reports whether req is non-empty,
// uses a fast-path version that supports shared batches, and carries only data
// frames.
func bulkBatchRequestSupportsSharedSuperBatch(req bulkBatchRequest) bool {
	if len(req.frames) == 0 || !bulkFastPathSupportsSharedBatch(req.fastPathVersion) {
		return false
	}
	for _, frame := range req.frames {
		if frame.Type != bulkFastPayloadTypeData {
			return false
		}
	}
	return true
}

// bulkBatchShouldWaitForMore reports whether the run loop should delay flushing.
// It waits only when a delay is configured, the batch is below both limits, and
// every request in it can share a super-batch.
func bulkBatchShouldWaitForMore(batch []bulkBatchRequest, batchBytes int) bool {
	if bulkBatchMaxFlushDelay <= 0 || len(batch) == 0 {
		return false
	}
	if len(batch) >= bulkBatchMaxPayloads || batchBytes >= bulkBatchMaxPayloadBytes {
		return false
	}
	for _, req := range batch {
		if !bulkBatchRequestSupportsSharedSuperBatch(req) {
			return false
		}
	}
	return true
}

// cloneQueuedBulkBatchRequest copies req's payloads into one pooled buffer so
// the caller may reuse its own buffers while the request waits in the queue. The
// copy is returned to the pool through req.release. Requests with payloadOwned
// set, or with no frames, are returned unchanged.
func cloneQueuedBulkBatchRequest(req bulkBatchRequest) bulkBatchRequest {
	if len(req.frames) == 0 || req.payloadOwned {
		return req
	}
	clonedFrames := make([]bulkFastFrame, len(req.frames))
	totalPayload := 0
	for _, frame := range req.frames {
		totalPayload += len(frame.Payload)
	}
	var payloadBuf []byte
	if totalPayload > 0 {
		payloadBuf = getBulkAsyncWritePayload(totalPayload)
		req.release = func() {
			putBulkAsyncWritePayload(payloadBuf)
		}
	}
	offset := 0
	for index, frame := range req.frames {
		clonedFrames[index] = frame
		if len(frame.Payload) == 0 {
			clonedFrames[index].Payload = nil
			continue
		}
		next := offset + len(frame.Payload)
		clonedFrames[index].Payload = payloadBuf[offset:next]
		copy(clonedFrames[index].Payload, frame.Payload)
		offset = next
	}
	req.frames = clonedFrames
	return req
}

// encodeSingle encodes one frame with the codec's single-frame encoder.
func (s *bulkBatchSender) encodeSingle(frame bulkFastFrame) ([]byte, func(), error) {
	if s == nil || s.codec.encodeSingle == nil {
		return nil, nil, errTransportDetached
	}
	return s.codec.encodeSingle(frame)
}

// encodeBatch encodes frames as one payload. A single frame is encoded with
// encodeSingle. Multiple frames require a batch encoder; without one, an error
// is returned rather than silently encoding only the first frame.
func (s *bulkBatchSender) encodeBatch(frames []bulkFastFrame) ([]byte, func(), error) {
	if len(frames) == 1 {
		return s.encodeSingle(frames[0])
	}
	if s == nil || s.codec.encodeBatch == nil {
		return nil, nil, errTransportDetached
	}
	return s.codec.encodeBatch(frames)
}

// stop marks the sender detached, signals the run loop, and waits for it to
// exit. It is idempotent and safe on a nil receiver.
func (s *bulkBatchSender) stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		s.setErr(errTransportDetached)
		close(s.stopCh)
	})
	<-s.doneCh
}

// failPending finishes every request currently buffered in reqCh with err.
func (s *bulkBatchSender) failPending(err error) {
	for {
		select {
		case item := <-s.reqCh:
			s.finishRequest(item, err)
		default:
			return
		}
	}
}

// finishRequest undoes the queued accounting for req and releases its payload
// copy, then delivers err. req.done has capacity one, so this must be called
// at most once per request.
func (s *bulkBatchSender) finishRequest(req bulkBatchRequest, err error) {
	if s != nil {
		s.queued.Add(-1)
	}
	if req.release != nil {
		req.release()
	}
	req.done <- err
}

// setErr records err as the sender's sticky error unless one is already set.
func (s *bulkBatchSender) setErr(err error) {
	if s == nil || err == nil {
		return
	}
	s.errMu.Lock()
	if s.err == nil {
		s.err = err
	}
	s.errMu.Unlock()
}

// errSnapshot returns the sticky error, or errTransportDetached for a nil sender.
func (s *bulkBatchSender) errSnapshot() error {
	if s == nil {
		return errTransportDetached
	}
	s.errMu.Lock()
	defer s.errMu.Unlock()
	return s.err
}

// stoppedErr is the error reported once the sender stops: the sticky error if
// set, otherwise errTransportDetached.
func (s *bulkBatchSender) stoppedErr() error {
	if err := s.errSnapshot(); err != nil {
		return err
	}
	return errTransportDetached
}

// transportWriteDeadline converts the provider's current timeout into a deadline.
func (s *bulkBatchSender) transportWriteDeadline() time.Time {
	if s == nil || s.writeTimeoutProvider == nil {
		return time.Time{}
	}
	return writeDeadlineFromTimeout(s.writeTimeoutProvider())
}

// transportWriteTimeout returns the provider's current timeout, or 0 if none.
func (s *bulkBatchSender) transportWriteTimeout() time.Duration {
	if s == nil || s.writeTimeoutProvider == nil {
		return 0
	}
	return s.writeTimeoutProvider()
}

// sharedMixedPayloadLimit is the byte cap applied to batches that mix frames
// from different requests or data IDs.
func (s *bulkBatchSender) sharedMixedPayloadLimit() int {
	if s == nil || s.binding == nil {
		return bulkAdaptiveSoftPayloadFallbackBytes
	}
	return s.binding.bulkAdaptiveSoftPayloadBytesSnapshot()
}

// minInt returns the smaller of a and b.
func minInt(a int, b int) int {
	if a < b {
		return a
	}
	return b
}