package notify import ( "context" "encoding/binary" "errors" "io" "net" "sync" "sync/atomic" "time" ) const ( bulkDedicatedBatchMagic = "NBD2" bulkDedicatedBatchVersion = 1 bulkDedicatedBatchHeaderLen = 20 bulkDedicatedBatchItemHeaderLen = 16 bulkDedicatedBatchMaxItems = 32 bulkDedicatedBatchMaxPlainBytes = 8 * 1024 * 1024 bulkDedicatedSendQueueSize = bulkDedicatedBatchMaxItems bulkDedicatedReleasePayloadLen = 12 ) const ( bulkDedicatedRequestQueued int32 = iota bulkDedicatedRequestStarted bulkDedicatedRequestCanceled ) type bulkDedicatedRequestState struct { value atomic.Int32 } type bulkDedicatedBatchItem struct { Type uint8 Flags uint8 Seq uint64 Payload []byte } type bulkDedicatedSendRequest struct { Type uint8 Flags uint8 Seq uint64 Payload []byte } type bulkDedicatedBatchRequest struct { Ctx context.Context Items []bulkDedicatedSendRequest Deadline time.Time Ack chan error State *bulkDedicatedRequestState } type bulkDedicatedSender struct { conn net.Conn dataID uint64 encrypt func([]byte) ([]byte, error) encodeBatch func([]bulkDedicatedSendRequest) ([]byte, error) fail func(error) reqCh chan bulkDedicatedBatchRequest stopCh chan struct{} doneCh chan struct{} stopOnce sync.Once flushMu sync.Mutex queued atomic.Int64 errMu sync.Mutex err error } func newBulkDedicatedSender(conn net.Conn, dataID uint64, encrypt func([]byte) ([]byte, error), encodeBatch func([]bulkDedicatedSendRequest) ([]byte, error), fail func(error)) *bulkDedicatedSender { sender := &bulkDedicatedSender{ conn: conn, dataID: dataID, encrypt: encrypt, encodeBatch: encodeBatch, fail: fail, reqCh: make(chan bulkDedicatedBatchRequest, bulkDedicatedSendQueueSize), stopCh: make(chan struct{}), doneCh: make(chan struct{}), } go sender.run() return sender } func (s *bulkDedicatedSender) submitData(ctx context.Context, seq uint64, payload []byte) error { if s == nil { return errTransportDetached } items := []bulkDedicatedSendRequest{{ Type: bulkFastPayloadTypeData, Seq: seq, Payload: append([]byte(nil), payload...), }} return s.submitBatch(ctx, items, false) } func (s *bulkDedicatedSender) submitWrite(ctx context.Context, startSeq uint64, payload []byte, chunkSize int) (int, error) { if s == nil { return 0, errTransportDetached } if len(payload) == 0 { return 0, nil } if chunkSize <= 0 { chunkSize = defaultBulkChunkSize } written := 0 seq := startSeq for written < len(payload) { var itemBuf [bulkDedicatedBatchMaxItems]bulkDedicatedSendRequest items := itemBuf[:0] batchBytes := bulkDedicatedBatchHeaderLen start := written for written < len(payload) && len(items) < bulkDedicatedBatchMaxItems { end := written + chunkSize if end > len(payload) { end = len(payload) } itemLen := bulkDedicatedSendRequestLenFromPayloadLen(end - written) if len(items) > 0 && batchBytes+itemLen > bulkDedicatedBatchMaxPlainBytes { break } items = append(items, bulkDedicatedSendRequest{ Type: bulkFastPayloadTypeData, Seq: seq, Payload: payload[written:end], }) batchBytes += itemLen seq++ written = end } if len(items) == 0 { end := written + chunkSize if end > len(payload) { end = len(payload) } items = append(items, bulkDedicatedSendRequest{ Type: bulkFastPayloadTypeData, Seq: seq, Payload: payload[written:end], }) seq++ written = end } if err := s.submitWriteBatch(ctx, items); err != nil { return start, err } start = written } return written, nil } func (s *bulkDedicatedSender) submitWriteBatch(ctx context.Context, items []bulkDedicatedSendRequest) error { if s == nil { return errTransportDetached } if len(items) == 0 { return nil } if submitted, err := s.tryDirectSubmitBatch(ctx, items); submitted { return err } queuedItems := make([]bulkDedicatedSendRequest, len(items)) copy(queuedItems, items) return s.submitBatch(ctx, queuedItems, true) } func (s *bulkDedicatedSender) submitControl(ctx context.Context, frameType uint8, flags uint8, seq uint64, payload []byte) error { if s == nil { return errTransportDetached } items := []bulkDedicatedSendRequest{{ Type: frameType, Flags: flags, Seq: seq, }} if len(payload) > 0 { items[0].Payload = append([]byte(nil), payload...) } return s.submitBatch(ctx, items, true) } func (s *bulkDedicatedSender) submitBatch(ctx context.Context, items []bulkDedicatedSendRequest, wait bool) error { if s == nil { return errTransportDetached } if ctx == nil { ctx = context.Background() } if err := s.errSnapshot(); err != nil { return err } req := bulkDedicatedBatchRequest{ Ctx: ctx, Items: items, State: &bulkDedicatedRequestState{}, } if deadline, ok := ctx.Deadline(); ok { req.Deadline = deadline } if wait { req.Ack = make(chan error, 1) } s.queued.Add(1) select { case <-ctx.Done(): s.queued.Add(-1) return normalizeStreamDeadlineError(ctx.Err()) case <-s.stopCh: s.queued.Add(-1) return s.stoppedErr() case s.reqCh <- req: if !wait { return nil } return s.waitAck(req) } } func (s *bulkDedicatedSender) tryDirectSubmitBatch(ctx context.Context, items []bulkDedicatedSendRequest) (bool, error) { if s == nil { return true, errTransportDetached } if ctx == nil { ctx = context.Background() } if len(items) == 0 { return true, nil } if err := s.errSnapshot(); err != nil { return true, err } select { case <-ctx.Done(): return true, normalizeStreamDeadlineError(ctx.Err()) case <-s.stopCh: return true, s.stoppedErr() default: } if s.queued.Load() != 0 { return false, nil } if !s.flushMu.TryLock() { return false, nil } defer s.flushMu.Unlock() if s.queued.Load() != 0 { return false, nil } if err := s.errSnapshot(); err != nil { return true, err } select { case <-ctx.Done(): return true, normalizeStreamDeadlineError(ctx.Err()) case <-s.stopCh: return true, s.stoppedErr() default: } deadline, _ := ctx.Deadline() if err := s.flush(items, deadline); err != nil { err = normalizeDedicatedBulkSendError(err) s.setErr(err) s.failPending(err) if s.fail != nil { go s.fail(err) } return true, err } return true, nil } func (s *bulkDedicatedSender) waitAck(req bulkDedicatedBatchRequest) error { if s == nil { return errTransportDetached } ctx := req.Ctx if ctx == nil { ctx = context.Background() } select { case err := <-req.Ack: return normalizeDedicatedBulkSendError(err) case <-ctx.Done(): if req.tryCancel() { return normalizeStreamDeadlineError(ctx.Err()) } return normalizeDedicatedBulkSendError(<-req.Ack) } } func (s *bulkDedicatedSender) stop() { if s == nil { return } s.stopOnce.Do(func() { s.setErr(errTransportDetached) close(s.stopCh) }) <-s.doneCh } func (s *bulkDedicatedSender) run() { defer close(s.doneCh) for { req, ok := s.nextRequest() if !ok { return } if !req.tryStart() { s.finishRequest(req, req.canceledErr()) continue } if err := req.contextErr(); err != nil { s.finishRequest(req, err) continue } s.flushMu.Lock() err := s.errSnapshot() if err == nil { err = s.flush(req.Items, req.Deadline) } s.flushMu.Unlock() if err != nil { err = normalizeDedicatedBulkSendError(err) s.setErr(err) s.finishRequest(req, err) s.failPending(err) if s.fail != nil { go s.fail(err) } return } s.finishRequest(req, nil) } } func (r bulkDedicatedBatchRequest) contextErr() error { if r.Ctx == nil { return nil } select { case <-r.Ctx.Done(): return normalizeStreamDeadlineError(r.Ctx.Err()) default: return nil } } func (r bulkDedicatedBatchRequest) tryStart() bool { if r.State == nil { return true } return r.State.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestStarted) } func (r bulkDedicatedBatchRequest) tryCancel() bool { if r.State == nil { return false } return r.State.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestCanceled) } func (r bulkDedicatedBatchRequest) canceledErr() error { if err := r.contextErr(); err != nil { return err } return context.Canceled } func (s *bulkDedicatedSender) nextRequest() (bulkDedicatedBatchRequest, bool) { select { case <-s.stopCh: s.failPending(s.stoppedErr()) return bulkDedicatedBatchRequest{}, false case req := <-s.reqCh: return req, true } } func (s *bulkDedicatedSender) flush(batch []bulkDedicatedSendRequest, deadline time.Time) error { if s == nil || s.conn == nil { return errTransportDetached } var ( payload []byte err error ) if s.encodeBatch != nil { payload, err = s.encodeBatch(batch) } else { plain, plainErr := encodeBulkDedicatedBatchPlain(s.dataID, batch) if plainErr != nil { return plainErr } payload, err = s.encrypt(plain) } if err != nil { return err } return writeBulkDedicatedRecordWithDeadline(s.conn, payload, deadline) } func (s *bulkDedicatedSender) ack(req bulkDedicatedBatchRequest, err error) { if req.Ack != nil { req.Ack <- err } } func (s *bulkDedicatedSender) finishRequest(req bulkDedicatedBatchRequest, err error) { if s != nil { s.queued.Add(-1) } s.ack(req, err) } func (s *bulkDedicatedSender) failPending(err error) { for { select { case item := <-s.reqCh: s.finishRequest(item, err) default: return } } } func (s *bulkDedicatedSender) setErr(err error) { if s == nil || err == nil { return } s.errMu.Lock() if s.err == nil { s.err = err } s.errMu.Unlock() } func (s *bulkDedicatedSender) errSnapshot() error { if s == nil { return errTransportDetached } s.errMu.Lock() defer s.errMu.Unlock() return s.err } func (s *bulkDedicatedSender) stoppedErr() error { if err := s.errSnapshot(); err != nil { return err } return errTransportDetached } func bulkDedicatedSendRequestLen(req bulkDedicatedSendRequest) int { return bulkDedicatedSendRequestLenFromPayloadLen(len(req.Payload)) } func bulkDedicatedSendRequestLenFromPayloadLen(payloadLen int) int { return bulkDedicatedBatchItemHeaderLen + payloadLen } func encodeBulkDedicatedReleasePayload(bytes int64, chunks int) ([]byte, error) { if bytes <= 0 && chunks <= 0 { return nil, errBulkFastPayloadInvalid } if chunks < 0 { return nil, errBulkFastPayloadInvalid } payload := make([]byte, bulkDedicatedReleasePayloadLen) binary.BigEndian.PutUint64(payload[:8], uint64(bytes)) binary.BigEndian.PutUint32(payload[8:12], uint32(chunks)) return payload, nil } func decodeBulkDedicatedReleasePayload(payload []byte) (int64, int, error) { if len(payload) != bulkDedicatedReleasePayloadLen { return 0, 0, errBulkFastPayloadInvalid } bytes := int64(binary.BigEndian.Uint64(payload[:8])) chunks := int(binary.BigEndian.Uint32(payload[8:12])) if bytes <= 0 && chunks <= 0 { return 0, 0, errBulkFastPayloadInvalid } return bytes, chunks, nil } func encodeBulkDedicatedBatchPlain(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) { if dataID == 0 || len(items) == 0 { return nil, errBulkFastPayloadInvalid } total := bulkDedicatedBatchPlainLen(items) buf := make([]byte, total) if err := writeBulkDedicatedBatchPlain(buf, dataID, items); err != nil { return nil, err } return buf, nil } func encodeBulkDedicatedBatchPayloadFast(encode transportFastPlainEncoder, secretKey []byte, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) { if encode == nil { return nil, errTransportPayloadEncryptFailed } plainLen := bulkDedicatedBatchPlainLen(items) return encode(secretKey, plainLen, func(dst []byte) error { return writeBulkDedicatedBatchPlain(dst, dataID, items) }) } func bulkDedicatedBatchPlainLen(items []bulkDedicatedSendRequest) int { total := bulkDedicatedBatchHeaderLen for _, item := range items { total += bulkDedicatedSendRequestLen(item) } return total } func writeBulkDedicatedBatchPlain(buf []byte, dataID uint64, items []bulkDedicatedSendRequest) error { if dataID == 0 || len(items) == 0 { return errBulkFastPayloadInvalid } if len(buf) != bulkDedicatedBatchPlainLen(items) { return errBulkFastPayloadInvalid } copy(buf[:4], bulkDedicatedBatchMagic) buf[4] = bulkDedicatedBatchVersion binary.BigEndian.PutUint64(buf[8:16], dataID) binary.BigEndian.PutUint32(buf[16:20], uint32(len(items))) offset := bulkDedicatedBatchHeaderLen for _, item := range items { buf[offset] = item.Type buf[offset+1] = item.Flags binary.BigEndian.PutUint64(buf[offset+4:offset+12], item.Seq) binary.BigEndian.PutUint32(buf[offset+12:offset+16], uint32(len(item.Payload))) offset += bulkDedicatedBatchItemHeaderLen copy(buf[offset:offset+len(item.Payload)], item.Payload) offset += len(item.Payload) } return nil } func decodeBulkDedicatedBatchPlain(payload []byte) (uint64, []bulkDedicatedBatchItem, bool, error) { if len(payload) < 4 || string(payload[:4]) != bulkDedicatedBatchMagic { return 0, nil, false, nil } if len(payload) < bulkDedicatedBatchHeaderLen { return 0, nil, true, errBulkFastPayloadInvalid } if payload[4] != bulkDedicatedBatchVersion { return 0, nil, true, errBulkFastPayloadInvalid } dataID := binary.BigEndian.Uint64(payload[8:16]) count := int(binary.BigEndian.Uint32(payload[16:20])) if dataID == 0 || count <= 0 { return 0, nil, true, errBulkFastPayloadInvalid } items := make([]bulkDedicatedBatchItem, 0, count) offset := bulkDedicatedBatchHeaderLen for i := 0; i < count; i++ { if len(payload)-offset < bulkDedicatedBatchItemHeaderLen { return 0, nil, true, errBulkFastPayloadInvalid } itemType := payload[offset] switch itemType { case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease: default: return 0, nil, true, errBulkFastPayloadInvalid } flags := payload[offset+1] seq := binary.BigEndian.Uint64(payload[offset+4 : offset+12]) dataLen := int(binary.BigEndian.Uint32(payload[offset+12 : offset+16])) offset += bulkDedicatedBatchItemHeaderLen if dataLen < 0 || len(payload)-offset < dataLen { return 0, nil, true, errBulkFastPayloadInvalid } items = append(items, bulkDedicatedBatchItem{ Type: itemType, Flags: flags, Seq: seq, Payload: payload[offset : offset+dataLen], }) offset += dataLen } if offset != len(payload) { return 0, nil, true, errBulkFastPayloadInvalid } return dataID, items, true, nil } func decodeDedicatedBulkInboundItems(expectedDataID uint64, plain []byte) ([]bulkDedicatedBatchItem, error) { if dataID, items, matched, err := decodeBulkDedicatedBatchPlain(plain); matched { if err != nil { return nil, err } if expectedDataID == 0 || dataID != expectedDataID { return nil, errBulkFastPayloadInvalid } return items, nil } frame, matched, err := decodeBulkFastFrame(plain) if err != nil { return nil, err } if !matched || expectedDataID == 0 || frame.DataID != expectedDataID { return nil, errBulkFastPayloadInvalid } return []bulkDedicatedBatchItem{{ Type: frame.Type, Flags: frame.Flags, Seq: frame.Seq, Payload: frame.Payload, }}, nil } func normalizeDedicatedBulkSendError(err error) error { switch { case err == nil: return nil case errors.Is(err, net.ErrClosed): return errTransportDetached default: return normalizeStreamDeadlineError(err) } } func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchItem) error { if bulk == nil { return io.ErrClosedPipe } switch item.Type { case bulkFastPayloadTypeData: return bulk.pushOwnedChunkNoReset(item.Payload) case bulkFastPayloadTypeClose: if item.Flags&bulkFastPayloadFlagFullClose != 0 { bulk.markPeerClosed() return nil } bulk.markRemoteClosed() return nil case bulkFastPayloadTypeReset: resetErr := errBulkReset if len(item.Payload) > 0 { resetErr = bulkRemoteResetError(string(item.Payload)) } bulk.markReset(bulkResetError(resetErr)) return nil case bulkFastPayloadTypeRelease: bytes, chunks, err := decodeBulkDedicatedReleasePayload(item.Payload) if err != nil { return err } bulk.releaseOutboundWindow(bytes, chunks) return nil default: return errBulkFastPayloadInvalid } }