notify/bulk_batch_sender.go

267 lines
5.1 KiB
Go
Raw Normal View History

package notify
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
)
// bulkBatchMaxPayloads caps how many queued payloads run coalesces into a
// single flush. The request queue in newBulkBatchSender is sized at four
// times this value.
const bulkBatchMaxPayloads = 16
// Lifecycle values stored in bulkBatchRequestState. A request begins Queued
// and leaves that state at most once, via compare-and-swap: to Started when
// the worker claims it, or to Canceled when its submitter abandons it.
const (
	bulkBatchRequestQueued   int32 = 0
	bulkBatchRequestStarted  int32 = 1
	bulkBatchRequestCanceled int32 = 2
)
// bulkBatchRequestState holds one of the bulkBatchRequest* lifecycle values.
// It is shared by pointer between the submitter and the worker so both race
// on the same atomic; its zero value equals bulkBatchRequestQueued.
type bulkBatchRequestState struct {
	value atomic.Int32 // current lifecycle value
}
// bulkBatchRequest is one payload waiting to be written by the batch worker.
type bulkBatchRequest struct {
	ctx      context.Context        // submitter's context; checked before the write
	payload  []byte                 // bytes handed to the framed batch writer
	deadline time.Time              // ctx's deadline, or zero when ctx has none
	done     chan error             // receives at most one result (buffered by submit)
	state    *bulkBatchRequestState // shared Queued/Started/Canceled flag
}
// bulkBatchSender funnels payload writes for a transport binding through a
// single worker goroutine (run), which coalesces queued payloads into batches.
type bulkBatchSender struct {
	binding *transportBinding

	// reqCh carries submitted requests to run.
	reqCh chan bulkBatchRequest

	// stopCh is closed once (guarded by stopOnce) by stop; doneCh is closed
	// when run returns.
	stopCh, doneCh chan struct{}
	stopOnce       sync.Once

	// err is the first terminal error recorded via setErr; guarded by errMu.
	errMu sync.Mutex
	err   error
}
// newBulkBatchSender creates a sender that writes through binding and starts
// its worker goroutine. Call stop to shut the worker down and wait for it.
func newBulkBatchSender(binding *transportBinding) *bulkBatchSender {
	s := new(bulkBatchSender)
	s.binding = binding
	// Room for several full batches so submitters rarely block on enqueue.
	s.reqCh = make(chan bulkBatchRequest, bulkBatchMaxPayloads*4)
	s.stopCh = make(chan struct{})
	s.doneCh = make(chan struct{})

	go s.run()
	return s
}
// submit queues payload for the batch worker and waits for its outcome.
//
// It returns nil once the payload's batch has been flushed. It returns the
// sender's terminal error if the sender has failed or been stopped, or the
// normalized context error if ctx ends before the worker claims the request.
// Once the worker has claimed (started) the request, submit waits for the
// worker's result even if ctx ends, so it never returns while the payload
// may still be mid-write. A nil ctx is treated as context.Background().
func (s *bulkBatchSender) submit(ctx context.Context, payload []byte) error {
	if s == nil {
		return errTransportDetached
	}
	if ctx == nil {
		ctx = context.Background()
	}
	// done is buffered so the worker never blocks answering a request whose
	// submitter already returned (e.g. after canceling it).
	req := bulkBatchRequest{
		ctx:     ctx,
		payload: payload,
		done:    make(chan error, 1),
		state:   &bulkBatchRequestState{},
	}
	if deadline, ok := ctx.Deadline(); ok {
		req.deadline = deadline
	}
	if err := s.errSnapshot(); err != nil {
		return err
	}
	select {
	case <-ctx.Done():
		return normalizeStreamDeadlineError(ctx.Err())
	case <-s.stopCh:
		return s.stoppedErr()
	case <-s.doneCh:
		// The worker already exited (e.g. after a flush error); do not feed
		// a queue nobody reads.
		return s.stoppedErr()
	case s.reqCh <- req:
	}
	select {
	case err := <-req.done:
		return err
	case <-s.doneCh:
		// The worker has exited. If it answered this request before exiting,
		// that answer is already buffered; prefer it.
		select {
		case err := <-req.done:
			return err
		default:
		}
		// Never claimed: it was enqueued after the worker's final drain (a
		// select with both stopCh and reqCh ready may pick either), so no
		// answer will ever arrive. Cancel it instead of blocking forever.
		if req.tryCancel() {
			return s.stoppedErr()
		}
		// Claimed requests are always answered before the worker exits.
		return <-req.done
	case <-ctx.Done():
		if req.tryCancel() {
			return normalizeStreamDeadlineError(ctx.Err())
		}
		// The worker already claimed the request; wait for its result.
		return <-req.done
	}
}
// run is the sender's single worker goroutine.
//
// It blocks for one request, then drains up to bulkBatchMaxPayloads-1 more
// without blocking. Requests that were canceled, or whose context is already
// done, are answered immediately and excluded. The remaining payloads are
// written with one flush bounded by the earliest deadline among them. Every
// request run dequeues is answered exactly once on its done channel.
//
// A flush error is recorded as the sender's terminal error, reported to the
// batch and to everything still queued, and ends the worker; closing stopCh
// ends it too. doneCh is closed on exit.
func (s *bulkBatchSender) run() {
	defer close(s.doneCh)
	for {
		req, ok := s.nextRequest()
		if !ok {
			return
		}
		batch := []bulkBatchRequest{req}
	drain:
		for len(batch) < bulkBatchMaxPayloads {
			select {
			case <-s.stopCh:
				err := s.stoppedErr()
				// These requests were already taken off reqCh, so failPending
				// cannot see them; answer them here or their submitters may
				// block forever.
				for _, item := range batch {
					item.done <- err
				}
				s.failPending(err)
				return
			case next := <-s.reqCh:
				batch = append(batch, next)
			default:
				break drain
			}
		}
		active, payloads := activeBulkBatchRequests(batch)
		if len(active) == 0 {
			continue
		}
		// One write covers the whole batch, so it honors the tightest
		// deadline among the requests in it.
		err := s.flush(payloads, bulkBatchRequestsEarliestDeadline(active))
		if err != nil {
			// Record before answering so submit's errSnapshot check rejects
			// new work as early as possible.
			s.setErr(err)
		}
		for _, item := range active {
			item.done <- err
		}
		if err != nil {
			s.failPending(err)
			return
		}
	}
}
// nextRequest blocks until a request is queued or stopCh is closed. On stop
// it fails everything still buffered in reqCh with the stop error and
// reports false.
func (s *bulkBatchSender) nextRequest() (bulkBatchRequest, bool) {
	select {
	case queued := <-s.reqCh:
		return queued, true
	case <-s.stopCh:
	}
	s.failPending(s.stoppedErr())
	return bulkBatchRequest{}, false
}
// activeBulkBatchRequests tries to claim each request in batch for writing.
// A request that cannot be claimed (its submitter canceled it) or whose
// context is already done is answered on its done channel right away and
// left out. It returns the claimed requests and their payloads, in matching
// order.
func activeBulkBatchRequests(batch []bulkBatchRequest) ([]bulkBatchRequest, [][]byte) {
	claimed := make([]bulkBatchRequest, 0, len(batch))
	bodies := make([][]byte, 0, len(batch))
	for i := range batch {
		req := batch[i]
		var rejectErr error
		if req.tryStart() {
			rejectErr = req.contextErr()
		} else {
			// canceledErr never returns nil, so an unclaimable request is
			// always rejected.
			rejectErr = req.canceledErr()
		}
		if rejectErr != nil {
			req.done <- rejectErr
			continue
		}
		claimed = append(claimed, req)
		bodies = append(bodies, req.payload)
	}
	return claimed, bodies
}
// bulkBatchRequestsEarliestDeadline returns the soonest non-zero deadline in
// batch (the first one wins on ties), or the zero time when no request has
// a deadline.
func bulkBatchRequestsEarliestDeadline(batch []bulkBatchRequest) time.Time {
	var earliest time.Time
	for i := range batch {
		d := batch[i].deadline
		if !d.IsZero() && (earliest.IsZero() || d.Before(earliest)) {
			earliest = d
		}
	}
	return earliest
}
// contextErr returns the request context's error, passed through
// normalizeStreamDeadlineError, once its Done channel is closed. It returns
// nil for a nil context or one that is still live.
func (r bulkBatchRequest) contextErr() error {
	ctx := r.ctx
	if ctx == nil {
		return nil
	}
	select {
	case <-ctx.Done():
	default:
		return nil
	}
	return normalizeStreamDeadlineError(ctx.Err())
}
// tryStart atomically moves the request from Queued to Started and reports
// whether it did. A request without shared state can never be canceled, so
// it is always startable.
func (r bulkBatchRequest) tryStart() bool {
	st := r.state
	return st == nil || st.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestStarted)
}
// tryCancel atomically moves the request from Queued to Canceled and reports
// whether it did. It fails once the worker has started the request, and
// always fails for a request without shared state.
func (r bulkBatchRequest) tryCancel() bool {
	st := r.state
	return st != nil && st.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestCanceled)
}
// canceledErr is the result delivered for a request its submitter canceled:
// the context's error when the context has ended, otherwise context.Canceled.
// It never returns nil.
func (r bulkBatchRequest) canceledErr() error {
	err := r.contextErr()
	if err == nil {
		err = context.Canceled
	}
	return err
}
// flush writes payloads as one framed batch via the binding's
// withConnWriteLockDeadline, passing deadline through unchanged
// (NOTE(review): presumably a zero deadline means "no deadline" — confirm).
// It returns errTransportDetached when there is no binding and
// errTransportFrameQueueUnavailable when the binding has no frame queue.
func (s *bulkBatchSender) flush(payloads [][]byte, deadline time.Time) error {
	if s == nil || s.binding == nil {
		return errTransportDetached
	}
	b := s.binding
	queue := b.queueSnapshot()
	if queue == nil {
		return errTransportFrameQueueUnavailable
	}
	writeBatch := func(conn net.Conn) error {
		return writeFramedPayloadBatchUnlocked(conn, queue, payloads)
	}
	return b.withConnWriteLockDeadline(deadline, writeBatch)
}
// stop signals the worker to exit and waits until it has. The first call
// records errTransportDetached as the terminal error (unless an earlier
// error is already recorded) and closes stopCh. Every call, including
// repeated ones, then blocks until run has returned. It is a no-op on a nil
// sender.
func (s *bulkBatchSender) stop() {
	if s == nil {
		return
	}
	signal := func() {
		s.setErr(errTransportDetached)
		close(s.stopCh)
	}
	s.stopOnce.Do(signal)
	<-s.doneCh
}
// failPending answers every request currently buffered in reqCh with err,
// without blocking. Requests enqueued after it returns are not touched.
func (s *bulkBatchSender) failPending(err error) {
	for drained := false; !drained; {
		select {
		case pending := <-s.reqCh:
			pending.done <- err
		default:
			drained = true
		}
	}
}
// setErr records err as the sender's terminal error unless one is already
// recorded; the first error wins. A nil sender or nil err is ignored.
func (s *bulkBatchSender) setErr(err error) {
	if s == nil || err == nil {
		return
	}
	s.errMu.Lock()
	defer s.errMu.Unlock()
	if s.err != nil {
		return
	}
	s.err = err
}
// errSnapshot returns the recorded terminal error (nil while the sender is
// healthy), or errTransportDetached for a nil sender.
func (s *bulkBatchSender) errSnapshot() error {
	if s == nil {
		return errTransportDetached
	}
	s.errMu.Lock()
	current := s.err
	s.errMu.Unlock()
	return current
}
// stoppedErr returns the error reported to requests the sender will not
// write: the recorded terminal error, or errTransportDetached if none exists.
func (s *bulkBatchSender) stoppedErr() error {
	err := s.errSnapshot()
	if err == nil {
		err = errTransportDetached
	}
	return err
}