package notify import ( "context" "net" "sync" "sync/atomic" "time" ) const ( bulkBatchMaxPayloads = 16 ) const ( bulkBatchRequestQueued int32 = iota bulkBatchRequestStarted bulkBatchRequestCanceled ) type bulkBatchRequestState struct { value atomic.Int32 } type bulkBatchRequest struct { ctx context.Context payload []byte deadline time.Time done chan error state *bulkBatchRequestState } type bulkBatchSender struct { binding *transportBinding reqCh chan bulkBatchRequest stopCh chan struct{} doneCh chan struct{} stopOnce sync.Once errMu sync.Mutex err error } func newBulkBatchSender(binding *transportBinding) *bulkBatchSender { sender := &bulkBatchSender{ binding: binding, reqCh: make(chan bulkBatchRequest, bulkBatchMaxPayloads*4), stopCh: make(chan struct{}), doneCh: make(chan struct{}), } go sender.run() return sender } func (s *bulkBatchSender) submit(ctx context.Context, payload []byte) error { if s == nil { return errTransportDetached } if ctx == nil { ctx = context.Background() } req := bulkBatchRequest{ ctx: ctx, payload: payload, done: make(chan error, 1), state: &bulkBatchRequestState{}, } if deadline, ok := ctx.Deadline(); ok { req.deadline = deadline } if err := s.errSnapshot(); err != nil { return err } select { case <-ctx.Done(): return normalizeStreamDeadlineError(ctx.Err()) case <-s.stopCh: return s.stoppedErr() case s.reqCh <- req: } select { case err := <-req.done: return err case <-ctx.Done(): if req.tryCancel() { return normalizeStreamDeadlineError(ctx.Err()) } return <-req.done } } func (s *bulkBatchSender) run() { defer close(s.doneCh) for { req, ok := s.nextRequest() if !ok { return } batch := []bulkBatchRequest{req} drain: for len(batch) < bulkBatchMaxPayloads { select { case <-s.stopCh: s.failPending(s.stoppedErr()) return case next := <-s.reqCh: batch = append(batch, next) default: break drain } } active, payloads := activeBulkBatchRequests(batch) if len(active) == 0 { continue } deadline := bulkBatchRequestsEarliestDeadline(active) err := s.flush(payloads, deadline) if err != nil { s.setErr(err) for _, item := range active { item.done <- err } s.failPending(err) return } for _, item := range active { item.done <- err } } } func (s *bulkBatchSender) nextRequest() (bulkBatchRequest, bool) { select { case <-s.stopCh: s.failPending(s.stoppedErr()) return bulkBatchRequest{}, false case req := <-s.reqCh: return req, true } } func activeBulkBatchRequests(batch []bulkBatchRequest) ([]bulkBatchRequest, [][]byte) { active := make([]bulkBatchRequest, 0, len(batch)) payloads := make([][]byte, 0, len(batch)) for _, item := range batch { if !item.tryStart() { item.done <- item.canceledErr() continue } if err := item.contextErr(); err != nil { item.done <- err continue } active = append(active, item) payloads = append(payloads, item.payload) } return active, payloads } func bulkBatchRequestsEarliestDeadline(batch []bulkBatchRequest) time.Time { var deadline time.Time for _, item := range batch { if item.deadline.IsZero() { continue } if deadline.IsZero() || item.deadline.Before(deadline) { deadline = item.deadline } } return deadline } func (r bulkBatchRequest) contextErr() error { if r.ctx == nil { return nil } select { case <-r.ctx.Done(): return normalizeStreamDeadlineError(r.ctx.Err()) default: return nil } } func (r bulkBatchRequest) tryStart() bool { if r.state == nil { return true } return r.state.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestStarted) } func (r bulkBatchRequest) tryCancel() bool { if r.state == nil { return false } return r.state.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestCanceled) } func (r bulkBatchRequest) canceledErr() error { if err := r.contextErr(); err != nil { return err } return context.Canceled } func (s *bulkBatchSender) flush(payloads [][]byte, deadline time.Time) error { if s == nil || s.binding == nil { return errTransportDetached } queue := s.binding.queueSnapshot() if queue == nil { return errTransportFrameQueueUnavailable } return s.binding.withConnWriteLockDeadline(deadline, func(conn net.Conn) error { return writeFramedPayloadBatchUnlocked(conn, queue, payloads) }) } func (s *bulkBatchSender) stop() { if s == nil { return } s.stopOnce.Do(func() { s.setErr(errTransportDetached) close(s.stopCh) }) <-s.doneCh } func (s *bulkBatchSender) failPending(err error) { for { select { case item := <-s.reqCh: item.done <- err default: return } } } func (s *bulkBatchSender) setErr(err error) { if s == nil || err == nil { return } s.errMu.Lock() if s.err == nil { s.err = err } s.errMu.Unlock() } func (s *bulkBatchSender) errSnapshot() error { if s == nil { return errTransportDetached } s.errMu.Lock() defer s.errMu.Unlock() return s.err } func (s *bulkBatchSender) stoppedErr() error { if err := s.errSnapshot(); err != nil { return err } return errTransportDetached }