package notify

import (
	"context"
	"encoding/binary"
	"errors"
	"math"
	"net"
	"sync"
	"time"
)

var (
	// errBulkFastPayloadInvalid is returned when a bulk fast frame cannot be
	// encoded into, or decoded from, the wire format described below.
	errBulkFastPayloadInvalid = errors.New("invalid bulk fast payload")
)

// bulkFastFrameScratchPool recycles scratch buffers handed out by
// getBulkFastFrameScratch and returned via putBulkFastFrameScratch.
//
// NOTE(review): storing a []byte (rather than *[]byte) in a sync.Pool boxes
// the slice header on every Put (staticcheck SA6002). Changing the element
// type is only safe if no other file touches this pool directly — confirm
// before switching.
var bulkFastFrameScratchPool sync.Pool

// Wire format of a bulk fast frame (all integers big-endian):
//
//	[0:4)   magic "NBF1"
//	[4]     version (bulkFastPayloadVersion)
//	[5]     frame type (one of bulkFastPayloadType*)
//	[6]     flags (bulkFastPayloadFlag* bits)
//	[7]     reserved; written as 0 and not checked on decode
//	[8:16)  data ID (must be non-zero)
//	[16:24) sequence number
//	[24:28) payload length as uint32
//	[28:)   payload
const (
	bulkFastPayloadMagic   = "NBF1"
	bulkFastPayloadVersion = 1

	bulkFastPayloadTypeData    = 1
	bulkFastPayloadTypeClose   = 2
	bulkFastPayloadTypeReset   = 3
	bulkFastPayloadTypeRelease = 4

	bulkFastPayloadHeaderLen = 28

	bulkFastPayloadFlagFullClose = 1 << 0
)

// bulkFastFrame is the in-memory form of one bulk fast frame. A frame produced
// by decodeBulkFastFrame has a Payload that aliases the decoded buffer.
type bulkFastFrame struct {
	Type    uint8
	Flags   uint8
	DataID  uint64
	Seq     uint64
	Payload []byte
}

// bulkFastDataFrame is an alias of bulkFastFrame returned by
// decodeBulkFastDataFrame.
type bulkFastDataFrame = bulkFastFrame

// encodeBulkFastFrameHeader writes the bulkFastPayloadHeaderLen-byte header
// into the front of dst.
//
// It returns errBulkDataIDEmpty when dataID is zero. It returns
// errBulkFastPayloadInvalid when dst is shorter than the header, or when
// payloadLen cannot be represented in the 32-bit length field. dst is not
// modified when an error is returned.
func encodeBulkFastFrameHeader(dst []byte, frameType uint8, flags uint8, dataID uint64, seq uint64, payloadLen int) error {
	if dataID == 0 {
		return errBulkDataIDEmpty
	}
	if len(dst) < bulkFastPayloadHeaderLen {
		return errBulkFastPayloadInvalid
	}
	// The length field is a uint32. Without this check a negative or >4 GiB
	// length would be silently truncated into a header the decoder rejects.
	if payloadLen < 0 || uint64(payloadLen) > math.MaxUint32 {
		return errBulkFastPayloadInvalid
	}
	copy(dst[:4], bulkFastPayloadMagic)
	dst[4] = bulkFastPayloadVersion
	dst[5] = frameType
	dst[6] = flags
	dst[7] = 0
	binary.BigEndian.PutUint64(dst[8:16], dataID)
	binary.BigEndian.PutUint64(dst[16:24], seq)
	binary.BigEndian.PutUint32(dst[24:28], uint32(payloadLen))
	return nil
}

// encodeBulkFastDataFrameHeader writes a data-frame header (type
// bulkFastPayloadTypeData, no flags) into dst.
func encodeBulkFastDataFrameHeader(dst []byte, dataID uint64, seq uint64, payloadLen int) error {
	return encodeBulkFastFrameHeader(dst, bulkFastPayloadTypeData, 0, dataID, seq, payloadLen)
}

// encodeBulkFastDataFrame returns a newly allocated plaintext data frame
// (header followed by a copy of payload).
func encodeBulkFastDataFrame(dataID uint64, seq uint64, payload []byte) ([]byte, error) {
	// A data frame is simply a frame of type Data with no flags.
	return encodeBulkFastControlFrame(bulkFastPayloadTypeData, 0, dataID, seq, payload)
}

// encodeBulkFastControlFrame returns a newly allocated plaintext frame of the
// given type and flags (header followed by a copy of payload).
func encodeBulkFastControlFrame(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
	frame := make([]byte, bulkFastPayloadHeaderLen+len(payload))
	if err := encodeBulkFastFrameHeader(frame, frameType, flags, dataID, seq, len(payload)); err != nil {
		return nil, err
	}
	copy(frame[bulkFastPayloadHeaderLen:], payload)
	return frame, nil
}

// decodeBulkFastFrame parses a single plaintext bulk fast frame.
//
// matched is false (with a nil error) when payload does not start with the
// bulk fast magic, so the caller can try other decoders. When the magic is
// present, matched is true and err reports any of these conditions:
//   - a short header or unknown version;
//   - an unknown frame type;
//   - a length field that disagrees with len(payload);
//   - a zero data ID.
//
// The returned frame's Payload aliases payload; it is not copied.
func decodeBulkFastFrame(payload []byte) (bulkFastFrame, bool, error) {
	if len(payload) < 4 || string(payload[:4]) != bulkFastPayloadMagic {
		return bulkFastFrame{}, false, nil
	}
	if len(payload) < bulkFastPayloadHeaderLen || payload[4] != bulkFastPayloadVersion {
		return bulkFastFrame{}, true, errBulkFastPayloadInvalid
	}
	switch payload[5] {
	case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
	default:
		return bulkFastFrame{}, true, errBulkFastPayloadInvalid
	}
	// Compare in uint64 so a large declared length can neither overflow nor go
	// negative when converted to int on 32-bit platforms.
	declaredLen := binary.BigEndian.Uint32(payload[24:28])
	if uint64(len(payload)-bulkFastPayloadHeaderLen) != uint64(declaredLen) {
		return bulkFastFrame{}, true, errBulkFastPayloadInvalid
	}
	dataID := binary.BigEndian.Uint64(payload[8:16])
	if dataID == 0 {
		return bulkFastFrame{}, true, errBulkFastPayloadInvalid
	}
	return bulkFastFrame{
		Type:    payload[5],
		Flags:   payload[6],
		DataID:  dataID,
		Seq:     binary.BigEndian.Uint64(payload[16:24]),
		Payload: payload[bulkFastPayloadHeaderLen:],
	}, true, nil
}

// decodeBulkFastDataFrame is decodeBulkFastFrame restricted to data frames.
// A well-formed frame of any other type yields matched == false and a nil
// error.
func decodeBulkFastDataFrame(payload []byte) (bulkFastDataFrame, bool, error) {
	frame, matched, err := decodeBulkFastFrame(payload)
	if !matched || err != nil {
		return frame, matched, err
	}
	if frame.Type != bulkFastPayloadTypeData {
		return bulkFastDataFrame{}, false, nil
	}
	return frame, true, nil
}

// encodeFastBulkDataPayload encodes chunk as a data frame for dataID/seq,
// ready for the client transport.
func (c *ClientCommon) encodeFastBulkDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
	return c.encodeBulkFastPayload(bulkFastFrame{
		Type:    bulkFastPayloadTypeData,
		DataID:  dataID,
		Seq:     seq,
		Payload: chunk,
	})
}

// encodeBulkFastPayload encodes one frame for the client transport.
//
// When c.fastPlainEncode is configured, the frame goes through
// encodeBulkFastFramePayloadFast with c.SecretKey. Otherwise it is serialized
// with encodeBulkFastFramePayload and passed through encryptTransportPayload.
// It returns errBulkClientNil for a nil receiver.
func (c *ClientCommon) encodeBulkFastPayload(frame bulkFastFrame) ([]byte, error) {
	if c == nil {
		return nil, errBulkClientNil
	}
	if c.fastPlainEncode != nil {
		return encodeBulkFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame)
	}
	plain, err := encodeBulkFastFramePayload(frame)
	if err != nil {
		return nil, err
	}
	return c.encryptTransportPayload(plain)
}

// encodeBulkFastBatchPayload is the multi-frame counterpart of
// encodeBulkFastPayload.
func (c *ClientCommon) encodeBulkFastBatchPayload(frames []bulkFastFrame) ([]byte, error) {
	if c == nil {
		return nil, errBulkClientNil
	}
	if c.fastPlainEncode != nil {
		return encodeBulkFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames)
	}
	plain, err := encodeBulkFastBatchPlain(frames)
	if err != nil {
		return nil, err
	}
	return c.encryptTransportPayload(plain)
}

// encodeBulkFastPayloadPooled behaves like encodeBulkFastPayload, but prefers
// the pooled encoder when c.modernPSKRuntime is set.
//
// The returned func is nil on the non-pooled paths. When it is non-nil,
// callers presumably must invoke it once the payload is no longer needed —
// confirm against encodeBulkFastFramePayloadPooled.
func (c *ClientCommon) encodeBulkFastPayloadPooled(frame bulkFastFrame) ([]byte, func(), error) {
	if c == nil {
		return nil, nil, errBulkClientNil
	}
	if runtime := c.modernPSKRuntime; runtime != nil {
		return encodeBulkFastFramePayloadPooled(runtime, frame)
	}
	if c.fastPlainEncode != nil {
		payload, err := encodeBulkFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame)
		return payload, nil, err
	}
	plain, err := encodeBulkFastFramePayload(frame)
	if err != nil {
		return nil, nil, err
	}
	payload, err := c.encryptTransportPayload(plain)
	return payload, nil, err
}

// encodeBulkFastBatchPayloadPooled is the multi-frame counterpart of
// encodeBulkFastPayloadPooled; the release func is nil on non-pooled paths.
func (c *ClientCommon) encodeBulkFastBatchPayloadPooled(frames []bulkFastFrame) ([]byte, func(), error) {
	if c == nil {
		return nil, nil, errBulkClientNil
	}
	if runtime := c.modernPSKRuntime; runtime != nil {
		return encodeBulkFastBatchPayloadPooled(runtime, frames)
	}
	if c.fastPlainEncode != nil {
		payload, err := encodeBulkFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames)
		return payload, nil, err
	}
	plain, err := encodeBulkFastBatchPlain(frames)
	if err != nil {
		return nil, nil, err
	}
	payload, err := c.encryptTransportPayload(plain)
	return payload, nil, err
}

// sendFastBulkData sends one data chunk for dataID/seq.
//
// It returns net.ErrClosed when the client has no transport binding. When the
// binding has a batch sender, the chunk is submitted to it. Otherwise the
// chunk is encoded and written directly to the transport.
func (c *ClientCommon) sendFastBulkData(ctx context.Context, dataID uint64, seq uint64, chunk []byte, fastPathVersion uint8) error {
	binding := c.clientTransportBindingSnapshot()
	if binding == nil {
		return net.ErrClosed
	}
	if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil {
		return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk)
	}
	payload, err := c.encodeFastBulkDataPayload(dataID, seq, chunk)
	if err != nil {
		return err
	}
	return c.writePayloadToTransport(payload)
}

// sendFastBulkWrite sends payload as consecutive data frames starting at
// startSeq and returns the number of payload bytes sent.
//
// With a batch sender the whole write, including chunkSize and payloadOwned,
// is handed to it. Otherwise payload is split into chunkSize pieces
// (defaultBulkChunkSize when chunkSize <= 0) and sent one at a time, stopping
// early on error or ctx cancellation.
func (c *ClientCommon) sendFastBulkWrite(ctx context.Context, dataID uint64, startSeq uint64, chunkSize int, fastPathVersion uint8, payload []byte, payloadOwned bool) (int, error) {
	if len(payload) == 0 {
		return 0, nil
	}
	binding := c.clientTransportBindingSnapshot()
	if binding == nil {
		return 0, net.ErrClosed
	}
	if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil {
		return sender.submitWrite(ctx, dataID, startSeq, fastPathVersion, payload, chunkSize, payloadOwned)
	}
	return writeBulkFastChunks(ctx, startSeq, chunkSize, payload, func(seq uint64, chunk []byte) error {
		return c.sendFastBulkData(ctx, dataID, seq, chunk, fastPathVersion)
	})
}

// sendFastBulkControl sends a control frame (close/reset/release or similar)
// for dataID/seq.
//
// It uses the binding's batch sender when present; otherwise it encodes the
// frame and writes it directly. It returns net.ErrClosed without a binding.
func (c *ClientCommon) sendFastBulkControl(ctx context.Context, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
	binding := c.clientTransportBindingSnapshot()
	if binding == nil {
		return net.ErrClosed
	}
	if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil {
		return sender.submitControl(ctx, frameType, flags, dataID, seq, fastPathVersion, payload)
	}
	encoded, err := c.encodeBulkFastPayload(bulkFastFrame{
		Type:    frameType,
		Flags:   flags,
		DataID:  dataID,
		Seq:     seq,
		Payload: payload,
	})
	if err != nil {
		return err
	}
	return c.writePayloadToTransport(encoded)
}

// encodeBulkFastControlPayload builds a plaintext frame with
// encodeBulkFastControlFrame and encrypts it with encryptTransportPayload.
//
// NOTE(review): unlike encodeBulkFastPayload this never takes the
// fastPlainEncode path and has no nil-receiver check — confirm that is
// intended.
func (c *ClientCommon) encodeBulkFastControlPayload(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
	plain, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload)
	if err != nil {
		return nil, err
	}
	return c.encryptTransportPayload(plain)
}

// encodeFastBulkDataPayloadLogical encodes chunk as a data frame for
// dataID/seq, ready for the given logical connection.
func (s *ServerCommon) encodeFastBulkDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
	return s.encodeBulkFastPayloadLogical(logical, bulkFastFrame{
		Type:    bulkFastPayloadTypeData,
		DataID:  dataID,
		Seq:     seq,
		Payload: chunk,
	})
}

// encodeBulkFastPayloadLogical encodes one frame for a server-side logical
// connection.
//
// When the connection has a fastPlainEncode, the frame goes through
// encodeBulkFastFramePayloadFast with the connection's secret key. Otherwise
// it is serialized and passed through encryptTransportPayloadLogical. It
// returns errTransportDetached for a nil logical connection.
func (s *ServerCommon) encodeBulkFastPayloadLogical(logical *LogicalConn, frame bulkFastFrame) ([]byte, error) {
	if logical == nil {
		return nil, errTransportDetached
	}
	if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
		return encodeBulkFastFramePayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frame)
	}
	plain, err := encodeBulkFastFramePayload(frame)
	if err != nil {
		return nil, err
	}
	return s.encryptTransportPayloadLogical(logical, plain)
}

// encodeBulkFastBatchPayloadLogical is the multi-frame counterpart of
// encodeBulkFastPayloadLogical.
func (s *ServerCommon) encodeBulkFastBatchPayloadLogical(logical *LogicalConn, frames []bulkFastFrame) ([]byte, error) {
	if logical == nil {
		return nil, errTransportDetached
	}
	if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
		return encodeBulkFastBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frames)
	}
	plain, err := encodeBulkFastBatchPlain(frames)
	if err != nil {
		return nil, err
	}
	return s.encryptTransportPayloadLogical(logical, plain)
}

// encodeBulkFastPayloadLogicalPooled behaves like encodeBulkFastPayloadLogical,
// but prefers the pooled encoder when the connection has a modern PSK runtime.
// The release func is nil on the non-pooled paths.
func (s *ServerCommon) encodeBulkFastPayloadLogicalPooled(logical *LogicalConn, frame bulkFastFrame) ([]byte, func(), error) {
	if logical == nil {
		return nil, nil, errTransportDetached
	}
	if runtime := logical.modernPSKRuntimeSnapshot(); runtime != nil {
		return encodeBulkFastFramePayloadPooled(runtime, frame)
	}
	if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
		payload, err := encodeBulkFastFramePayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frame)
		return payload, nil, err
	}
	plain, err := encodeBulkFastFramePayload(frame)
	if err != nil {
		return nil, nil, err
	}
	payload, err := s.encryptTransportPayloadLogical(logical, plain)
	return payload, nil, err
}

// encodeBulkFastBatchPayloadLogicalPooled is the multi-frame counterpart of
// encodeBulkFastPayloadLogicalPooled.
func (s *ServerCommon) encodeBulkFastBatchPayloadLogicalPooled(logical *LogicalConn, frames []bulkFastFrame) ([]byte, func(), error) {
	if logical == nil {
		return nil, nil, errTransportDetached
	}
	if runtime := logical.modernPSKRuntimeSnapshot(); runtime != nil {
		return encodeBulkFastBatchPayloadPooled(runtime, frames)
	}
	if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
		payload, err := encodeBulkFastBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frames)
		return payload, nil, err
	}
	plain, err := encodeBulkFastBatchPlain(frames)
	if err != nil {
		return nil, nil, err
	}
	payload, err := s.encryptTransportPayloadLogical(logical, plain)
	return payload, nil, err
}

// sendFastBulkDataTransport sends one data chunk for dataID/seq from the
// server side. The steps are:
//  1. Verify the transport with ensureServerTransportSendReady.
//  2. Resolve logical from transport when logical is nil; fail with
//     errTransportDetached if neither yields one.
//  3. Submit to the binding's batch sender when the binding has a queue and
//     a sender.
//  4. Otherwise encode the frame and write it with writeEnvelopePayload.
func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte, fastPathVersion uint8) error {
	if err := s.ensureServerTransportSendReady(transport); err != nil {
		return err
	}
	if logical == nil && transport != nil {
		logical = transport.logicalConnSnapshot()
	}
	if logical == nil {
		return errTransportDetached
	}
	if binding := logical.transportBindingSnapshot(); binding != nil && binding.queueSnapshot() != nil {
		if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil {
			return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk)
		}
	}
	payload, err := s.encodeFastBulkDataPayloadLogical(logical, dataID, seq, chunk)
	if err != nil {
		return err
	}
	return s.writeEnvelopePayload(logical, transport, nil, payload)
}

// sendFastBulkWriteTransport sends payload as consecutive data frames starting
// at startSeq and returns the number of payload bytes sent.
//
// With a queued batch sender the whole write is handed to it. Otherwise
// payload is split into chunkSize pieces (defaultBulkChunkSize when
// chunkSize <= 0), each sent via sendFastBulkDataTransport, stopping early on
// error or ctx cancellation.
func (s *ServerCommon) sendFastBulkWriteTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, startSeq uint64, chunkSize int, fastPathVersion uint8, payload []byte, payloadOwned bool) (int, error) {
	if len(payload) == 0 {
		return 0, nil
	}
	if err := s.ensureServerTransportSendReady(transport); err != nil {
		return 0, err
	}
	if logical == nil && transport != nil {
		logical = transport.logicalConnSnapshot()
	}
	if logical == nil {
		return 0, errTransportDetached
	}
	if binding := logical.transportBindingSnapshot(); binding != nil && binding.queueSnapshot() != nil {
		if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil {
			return sender.submitWrite(ctx, dataID, startSeq, fastPathVersion, payload, chunkSize, payloadOwned)
		}
	}
	return writeBulkFastChunks(ctx, startSeq, chunkSize, payload, func(seq uint64, chunk []byte) error {
		return s.sendFastBulkDataTransport(ctx, logical, transport, dataID, seq, chunk, fastPathVersion)
	})
}

// sendFastBulkControlTransport sends a control frame for dataID/seq from the
// server side. It applies the same readiness check, logical-connection
// resolution and batch-sender preference as sendFastBulkDataTransport.
func (s *ServerCommon) sendFastBulkControlTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
	if err := s.ensureServerTransportSendReady(transport); err != nil {
		return err
	}
	if logical == nil && transport != nil {
		logical = transport.logicalConnSnapshot()
	}
	if logical == nil {
		return errTransportDetached
	}
	if binding := logical.transportBindingSnapshot(); binding != nil && binding.queueSnapshot() != nil {
		if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil {
			return sender.submitControl(ctx, frameType, flags, dataID, seq, fastPathVersion, payload)
		}
	}
	encoded, err := s.encodeBulkFastPayloadLogical(logical, bulkFastFrame{
		Type:    frameType,
		Flags:   flags,
		DataID:  dataID,
		Seq:     seq,
		Payload: payload,
	})
	if err != nil {
		return err
	}
	return s.writeEnvelopePayload(logical, transport, nil, encoded)
}

// encodeBulkFastControlPayloadLogical builds a plaintext frame with
// encodeBulkFastControlFrame and encrypts it for logical.
//
// NOTE(review): unlike encodeBulkFastPayloadLogical this never takes the
// fastPlainEncode path — confirm that is intended.
func (s *ServerCommon) encodeBulkFastControlPayloadLogical(logical *LogicalConn, frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
	plain, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload)
	if err != nil {
		return nil, err
	}
	return s.encryptTransportPayloadLogical(logical, plain)
}

// writeBulkFastChunks is the unbatched fallback shared by the client and
// server bulk write paths.
//
// It splits payload into pieces of at most chunkSize bytes
// (defaultBulkChunkSize when chunkSize <= 0) and hands each piece to send,
// with sequence numbers counting up from startSeq. It stops at the first send
// error, or before the next piece once ctx is done. It returns the number of
// payload bytes that send accepted.
func writeBulkFastChunks(ctx context.Context, startSeq uint64, chunkSize int, payload []byte, send func(seq uint64, chunk []byte) error) (int, error) {
	if chunkSize <= 0 {
		chunkSize = defaultBulkChunkSize
	}
	written := 0
	for seq := startSeq; written < len(payload); seq++ {
		if ctx != nil {
			if err := ctx.Err(); err != nil {
				return written, err
			}
		}
		// Bound by the remaining length first so written+n can never overflow.
		n := len(payload) - written
		if n > chunkSize {
			n = chunkSize
		}
		if err := send(seq, payload[written:written+n]); err != nil {
			return written, err
		}
		written += n
	}
	return written, nil
}

// getBulkFastFrameScratch returns a buffer of length
// bulkFastPayloadHeaderLen+payloadLen. It reuses a pooled buffer when one with
// enough capacity is available; a pooled buffer that is too small is dropped
// and a fresh one is allocated.
func getBulkFastFrameScratch(payloadLen int) []byte {
	need := bulkFastPayloadHeaderLen + payloadLen
	if buf, ok := bulkFastFrameScratchPool.Get().([]byte); ok && cap(buf) >= need {
		return buf[:need]
	}
	return make([]byte, need)
}

// putBulkFastFrameScratch returns buf to the scratch pool. Zero-capacity
// buffers and buffers larger than 4 MiB are not retained, so one oversized
// frame does not pin a large allocation in the pool.
func putBulkFastFrameScratch(buf []byte) {
	if cap(buf) == 0 || cap(buf) > 4*1024*1024 {
		return
	}
	bulkFastFrameScratchPool.Put(buf[:0])
}

// dispatchInboundTransportPayload decrypts payload and routes its contents.
// Decoders are tried in order:
//  1. bulk fast frames, passed to dispatchFastBulkFrame;
//  2. stream fast data frames, passed to dispatchFastStreamData;
//  3. otherwise an envelope, passed to dispatchEnvelope along with now.
//
// A decode error from the first decoder that matches is returned unchanged.
func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error {
	plain, err := c.decryptTransportPayload(payload)
	if err != nil {
		return err
	}
	if frames, matched, err := decodeBulkFastFrames(plain); matched {
		if err != nil {
			return err
		}
		for _, frame := range frames {
			c.dispatchFastBulkFrame(frame)
		}
		return nil
	}
	if frames, matched, err := decodeStreamFastDataFrames(plain); matched {
		if err != nil {
			return err
		}
		for _, frame := range frames {
			c.dispatchFastStreamData(frame)
		}
		return nil
	}
	env, err := c.decodeEnvelopePlain(plain)
	if err != nil {
		return err
	}
	c.dispatchEnvelope(env, now)
	return nil
}

// dispatchInboundTransportPayload is the server-side counterpart of
// ClientCommon.dispatchInboundTransportPayload.
//
// It resolves logical from transport when nil, returning errTransportDetached
// if neither yields one. It then decrypts with that connection and routes
// bulk fast frames, stream fast data frames or an envelope to the matching
// dispatch method.
func (s *ServerCommon) dispatchInboundTransportPayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte, now time.Time) error {
	if logical == nil && transport != nil {
		logical = transport.logicalConnSnapshot()
	}
	if logical == nil {
		return errTransportDetached
	}
	plain, err := s.decryptTransportPayloadLogical(logical, payload)
	if err != nil {
		return err
	}
	if frames, matched, err := decodeBulkFastFrames(plain); matched {
		if err != nil {
			return err
		}
		for _, frame := range frames {
			s.dispatchFastBulkFrame(logical, transport, conn, frame)
		}
		return nil
	}
	if frames, matched, err := decodeStreamFastDataFrames(plain); matched {
		if err != nil {
			return err
		}
		for _, frame := range frames {
			s.dispatchFastStreamData(logical, transport, conn, frame)
		}
		return nil
	}
	env, err := s.decodeEnvelopePlain(plain)
	if err != nil {
		return err
	}
	s.dispatchEnvelope(logical, transport, conn, env, now)
	return nil
}