package notify import ( "context" "encoding/binary" "errors" "net" "sync" "time" ) var ( errBulkFastPayloadInvalid = errors.New("invalid bulk fast payload") ) var bulkFastFrameScratchPool sync.Pool const ( bulkFastPayloadMagic = "NBF1" bulkFastPayloadVersion = 1 bulkFastPayloadTypeData = 1 bulkFastPayloadTypeClose = 2 bulkFastPayloadTypeReset = 3 bulkFastPayloadTypeRelease = 4 bulkFastPayloadHeaderLen = 28 bulkFastPayloadFlagFullClose = 1 << 0 ) type bulkFastFrame struct { Type uint8 Flags uint8 DataID uint64 Seq uint64 Payload []byte } type bulkFastDataFrame = bulkFastFrame func encodeBulkFastFrameHeader(dst []byte, frameType uint8, flags uint8, dataID uint64, seq uint64, payloadLen int) error { if dataID == 0 { return errBulkDataIDEmpty } if len(dst) < bulkFastPayloadHeaderLen { return errBulkFastPayloadInvalid } copy(dst[:4], bulkFastPayloadMagic) dst[4] = bulkFastPayloadVersion dst[5] = frameType dst[6] = flags dst[7] = 0 binary.BigEndian.PutUint64(dst[8:16], dataID) binary.BigEndian.PutUint64(dst[16:24], seq) binary.BigEndian.PutUint32(dst[24:28], uint32(payloadLen)) return nil } func encodeBulkFastDataFrameHeader(dst []byte, dataID uint64, seq uint64, payloadLen int) error { return encodeBulkFastFrameHeader(dst, bulkFastPayloadTypeData, 0, dataID, seq, payloadLen) } func encodeBulkFastDataFrame(dataID uint64, seq uint64, payload []byte) ([]byte, error) { frame := make([]byte, bulkFastPayloadHeaderLen+len(payload)) if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil { return nil, err } copy(frame[bulkFastPayloadHeaderLen:], payload) return frame, nil } func encodeBulkFastControlFrame(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) { frame := make([]byte, bulkFastPayloadHeaderLen+len(payload)) if err := encodeBulkFastFrameHeader(frame, frameType, flags, dataID, seq, len(payload)); err != nil { return nil, err } copy(frame[bulkFastPayloadHeaderLen:], payload) return frame, nil } func decodeBulkFastFrame(payload []byte) (bulkFastFrame, bool, error) { if len(payload) < 4 || string(payload[:4]) != bulkFastPayloadMagic { return bulkFastFrame{}, false, nil } if len(payload) < bulkFastPayloadHeaderLen { return bulkFastFrame{}, true, errBulkFastPayloadInvalid } if payload[4] != bulkFastPayloadVersion { return bulkFastFrame{}, true, errBulkFastPayloadInvalid } switch payload[5] { case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease: default: return bulkFastFrame{}, true, errBulkFastPayloadInvalid } dataLen := int(binary.BigEndian.Uint32(payload[24:28])) if dataLen < 0 || len(payload) != bulkFastPayloadHeaderLen+dataLen { return bulkFastFrame{}, true, errBulkFastPayloadInvalid } dataID := binary.BigEndian.Uint64(payload[8:16]) if dataID == 0 { return bulkFastFrame{}, true, errBulkFastPayloadInvalid } return bulkFastFrame{ Type: payload[5], Flags: payload[6], DataID: dataID, Seq: binary.BigEndian.Uint64(payload[16:24]), Payload: payload[bulkFastPayloadHeaderLen:], }, true, nil } func decodeBulkFastDataFrame(payload []byte) (bulkFastDataFrame, bool, error) { frame, matched, err := decodeBulkFastFrame(payload) if !matched || err != nil { return frame, matched, err } if frame.Type != bulkFastPayloadTypeData { return bulkFastDataFrame{}, false, nil } return frame, true, nil } func (c *ClientCommon) encodeFastBulkDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) { if c != nil && c.fastBulkEncode != nil { return c.fastBulkEncode(c.SecretKey, dataID, seq, chunk) } scratch := getBulkFastFrameScratch(len(chunk)) defer putBulkFastFrameScratch(scratch) frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)] if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil { return nil, err } copy(frame[bulkFastPayloadHeaderLen:], chunk) return c.encryptTransportPayload(frame) } func (c *ClientCommon) sendFastBulkData(ctx context.Context, dataID uint64, seq uint64, chunk []byte) error { payload, err := c.encodeFastBulkDataPayload(dataID, seq, chunk) if err != nil { return err } binding := c.clientTransportBindingSnapshot() if binding == nil { return net.ErrClosed } if sender := binding.bulkBatchSenderSnapshot(); sender != nil { return sender.submit(ctx, payload) } return c.writePayloadToTransport(payload) } func (c *ClientCommon) encodeBulkFastControlPayload(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) { plain, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload) if err != nil { return nil, err } return c.encryptTransportPayload(plain) } func (s *ServerCommon) encodeFastBulkDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) { if logical != nil { if fastBulkEncode := logical.fastBulkEncodeSnapshot(); fastBulkEncode != nil { return fastBulkEncode(logical.secretKeySnapshot(), dataID, seq, chunk) } } scratch := getBulkFastFrameScratch(len(chunk)) defer putBulkFastFrameScratch(scratch) frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)] if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil { return nil, err } copy(frame[bulkFastPayloadHeaderLen:], chunk) return s.encryptTransportPayloadLogical(logical, frame) } func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte) error { if err := s.ensureServerTransportSendReady(transport); err != nil { return err } if logical == nil && transport != nil { logical = transport.logicalConnSnapshot() } if logical == nil { return errTransportDetached } payload, err := s.encodeFastBulkDataPayloadLogical(logical, dataID, seq, chunk) if err != nil { return err } if binding := logical.transportBindingSnapshot(); binding != nil { if binding.queueSnapshot() != nil { if sender := binding.bulkBatchSenderSnapshot(); sender != nil { return sender.submit(ctx, payload) } } } return s.writeEnvelopePayload(logical, transport, nil, payload) } func (s *ServerCommon) encodeBulkFastControlPayloadLogical(logical *LogicalConn, frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) { plain, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload) if err != nil { return nil, err } return s.encryptTransportPayloadLogical(logical, plain) } func getBulkFastFrameScratch(payloadLen int) []byte { need := bulkFastPayloadHeaderLen + payloadLen if buf, ok := bulkFastFrameScratchPool.Get().([]byte); ok && cap(buf) >= need { return buf[:need] } return make([]byte, need) } func putBulkFastFrameScratch(buf []byte) { if cap(buf) == 0 || cap(buf) > 4*1024*1024 { return } bulkFastFrameScratchPool.Put(buf[:0]) } func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error { plain, err := c.decryptTransportPayload(payload) if err != nil { return err } if frame, matched, err := decodeBulkFastFrame(plain); matched { if err != nil { return err } c.dispatchFastBulkFrame(frame) return nil } if frame, matched, err := decodeStreamFastDataFrame(plain); matched { if err != nil { return err } c.dispatchFastStreamData(frame) return nil } env, err := c.decodeEnvelopePlain(plain) if err != nil { return err } c.dispatchEnvelope(env, now) return nil } func (s *ServerCommon) dispatchInboundTransportPayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte, now time.Time) error { if logical == nil && transport != nil { logical = transport.logicalConnSnapshot() } if logical == nil { return errTransportDetached } plain, err := s.decryptTransportPayloadLogical(logical, payload) if err != nil { return err } if frame, matched, err := decodeBulkFastFrame(plain); matched { if err != nil { return err } s.dispatchFastBulkFrame(logical, transport, conn, frame) return nil } if frame, matched, err := decodeStreamFastDataFrame(plain); matched { if err != nil { return err } s.dispatchFastStreamData(logical, transport, conn, frame) return nil } env, err := s.decodeEnvelopePlain(plain) if err != nil { return err } s.dispatchEnvelope(logical, transport, conn, env, now) return nil }