package notify import ( "context" "encoding/binary" "errors" "io" ) var ( errStreamFastPayloadInvalid = errors.New("invalid stream fast payload") errStreamFastDataIDEmpty = errors.New("stream data id is empty") ) const ( streamFastPayloadMagic = "NSF1" streamFastPayloadVersion = 1 streamFastPayloadTypeData = 1 streamFastPayloadHeaderLen = 28 streamFastBatchDirectLimit = 512 * 1024 ) type streamFastDataFrame struct { Flags uint8 DataID uint64 Seq uint64 Payload []byte } func streamAdaptiveFramePayloadLimit(binding *transportBinding) int { if binding == nil { return 0 } limit := binding.streamAdaptiveSoftPayloadBytesSnapshot() - streamFastPayloadHeaderLen if limit <= 0 { return 1 } maxPayload := streamFastBatchMaxPlainBytes - streamFastPayloadHeaderLen if limit > maxPayload { return maxPayload } return limit } func streamFastSplitFrameCount(size int, maxPayload int) int { if size <= 0 || maxPayload <= 0 { return 1 } return (size + maxPayload - 1) / maxPayload } func buildStreamFastSplitFrames(dataID uint64, startSeq uint64, chunk []byte, maxPayload int) []streamFastDataFrame { if len(chunk) == 0 { return nil } if maxPayload <= 0 || len(chunk) <= maxPayload { return []streamFastDataFrame{{ DataID: dataID, Seq: startSeq, Payload: chunk, }} } frames := make([]streamFastDataFrame, 0, streamFastSplitFrameCount(len(chunk), maxPayload)) seq := startSeq for offset := 0; offset < len(chunk); offset += maxPayload { end := offset + maxPayload if end > len(chunk) { end = len(chunk) } frames = append(frames, streamFastDataFrame{ DataID: dataID, Seq: seq, Payload: chunk[offset:end], }) seq++ } return frames } func encodeStreamFastDataFrameHeader(dst []byte, dataID uint64, seq uint64, payloadLen int) error { if dataID == 0 { return errStreamFastDataIDEmpty } if len(dst) < streamFastPayloadHeaderLen { return errStreamFastPayloadInvalid } copy(dst[:4], streamFastPayloadMagic) dst[4] = streamFastPayloadVersion dst[5] = streamFastPayloadTypeData dst[6] = 0 dst[7] = 0 binary.BigEndian.PutUint64(dst[8:16], dataID) binary.BigEndian.PutUint64(dst[16:24], seq) binary.BigEndian.PutUint32(dst[24:28], uint32(payloadLen)) return nil } func encodeStreamFastDataFrame(dataID uint64, seq uint64, payload []byte) ([]byte, error) { frame := make([]byte, streamFastPayloadHeaderLen+len(payload)) if err := encodeStreamFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil { return nil, err } copy(frame[streamFastPayloadHeaderLen:], payload) return frame, nil } func encodeStreamFastFramePayload(frame streamFastDataFrame) ([]byte, error) { framePayload := make([]byte, streamFastPayloadHeaderLen+len(frame.Payload)) if err := encodeStreamFastDataFrameHeader(framePayload, frame.DataID, frame.Seq, len(frame.Payload)); err != nil { return nil, err } framePayload[6] = frame.Flags copy(framePayload[streamFastPayloadHeaderLen:], frame.Payload) return framePayload, nil } func encodeStreamFastFramePayloadFast(encode transportFastPlainEncoder, secretKey []byte, frame streamFastDataFrame) ([]byte, error) { if encode == nil { return nil, errTransportPayloadEncryptFailed } plainLen := streamFastPayloadHeaderLen + len(frame.Payload) return encode(secretKey, plainLen, func(dst []byte) error { if err := encodeStreamFastDataFrameHeader(dst, frame.DataID, frame.Seq, len(frame.Payload)); err != nil { return err } dst[6] = frame.Flags copy(dst[streamFastPayloadHeaderLen:], frame.Payload) return nil }) } func decodeStreamFastDataFrame(payload []byte) (streamFastDataFrame, bool, error) { if len(payload) < 4 || string(payload[:4]) != streamFastPayloadMagic { return streamFastDataFrame{}, false, nil } if len(payload) < streamFastPayloadHeaderLen { return streamFastDataFrame{}, true, errStreamFastPayloadInvalid } if payload[4] != streamFastPayloadVersion || payload[5] != streamFastPayloadTypeData { return streamFastDataFrame{}, true, errStreamFastPayloadInvalid } dataLen := int(binary.BigEndian.Uint32(payload[24:28])) if dataLen < 0 || len(payload) != streamFastPayloadHeaderLen+dataLen { return streamFastDataFrame{}, true, errStreamFastPayloadInvalid } dataID := binary.BigEndian.Uint64(payload[8:16]) if dataID == 0 { return streamFastDataFrame{}, true, errStreamFastPayloadInvalid } return streamFastDataFrame{ Flags: payload[6], DataID: dataID, Seq: binary.BigEndian.Uint64(payload[16:24]), Payload: payload[streamFastPayloadHeaderLen:], }, true, nil } func (c *ClientCommon) encodeFastStreamPayload(frame streamFastDataFrame) ([]byte, error) { profile := c.clientTransportProtectionSnapshot() if c != nil && profile.fastStreamEncode != nil && frame.Flags == 0 { return profile.fastStreamEncode(profile.secretKey, frame.DataID, frame.Seq, frame.Payload) } if c != nil && profile.fastPlainEncode != nil { return encodeStreamFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame) } plain, err := encodeStreamFastFramePayload(frame) if err != nil { return nil, err } return c.encryptTransportPayload(plain) } func (c *ClientCommon) encodeFastStreamDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) { return c.encodeFastStreamPayload(streamFastDataFrame{ DataID: dataID, Seq: seq, Payload: chunk, }) } func (c *ClientCommon) encodeFastStreamBatchPayload(frames []streamFastDataFrame) ([]byte, error) { if c == nil { return nil, errStreamClientNil } profile := c.clientTransportProtectionSnapshot() if profile.fastPlainEncode != nil { return encodeStreamFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames) } plain, err := encodeStreamFastBatchPlain(frames) if err != nil { return nil, err } return c.encryptTransportPayload(plain) } func (c *ClientCommon) sendFastStreamData(ctx context.Context, stream *streamHandle, chunk []byte) error { if stream == nil { return io.ErrClosedPipe } dataID := stream.dataIDSnapshot() fastPathVersion := stream.fastPathVersionSnapshot() if binding := c.clientTransportBindingSnapshot(); binding != nil && streamFastPathSupportsBatch(fastPathVersion) { if sender := binding.clientStreamBatchSenderSnapshot(c); sender != nil { if maxPayload := streamAdaptiveFramePayloadLimit(binding); maxPayload > 0 && len(chunk) > maxPayload { startSeq := stream.reserveOutboundDataSeqs(streamFastSplitFrameCount(len(chunk), maxPayload)) return sender.submitFrames(ctx, fastPathVersion, buildStreamFastSplitFrames(dataID, startSeq, chunk, maxPayload)) } seq := stream.reserveOutboundDataSeqs(1) if len(chunk) < streamFastBatchDirectLimit { return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk) } payload, err := c.encodeFastStreamDataPayload(dataID, seq, chunk) if err != nil { return err } return sender.submitEncoded(ctx, fastPathVersion, payload) } } seq := stream.reserveOutboundDataSeqs(1) payload, err := c.encodeFastStreamDataPayload(dataID, seq, chunk) if err != nil { return err } return c.writePayloadToTransport(payload) } func (s *ServerCommon) encodeFastStreamPayloadLogical(logical *LogicalConn, frame streamFastDataFrame) ([]byte, error) { if logical == nil { return nil, errTransportDetached } if fastStreamEncode := logical.fastStreamEncodeSnapshot(); fastStreamEncode != nil && frame.Flags == 0 { return fastStreamEncode(logical.secretKeySnapshot(), frame.DataID, frame.Seq, frame.Payload) } if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { return encodeStreamFastFramePayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frame) } plain, err := encodeStreamFastFramePayload(frame) if err != nil { return nil, err } return s.encryptTransportPayloadLogical(logical, plain) } func (s *ServerCommon) encodeFastStreamDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) { return s.encodeFastStreamPayloadLogical(logical, streamFastDataFrame{ DataID: dataID, Seq: seq, Payload: chunk, }) } func (s *ServerCommon) encodeFastStreamBatchPayloadLogical(logical *LogicalConn, frames []streamFastDataFrame) ([]byte, error) { if logical == nil { return nil, errTransportDetached } if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { return encodeStreamFastBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frames) } plain, err := encodeStreamFastBatchPlain(frames) if err != nil { return nil, err } return s.encryptTransportPayloadLogical(logical, plain) } func (s *ServerCommon) sendFastStreamDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, stream *streamHandle, chunk []byte) error { if err := s.ensureServerTransportSendReady(transport); err != nil { return err } if stream == nil { return io.ErrClosedPipe } if logical == nil && transport != nil { logical = transport.logicalConnSnapshot() } if logical == nil { return errTransportDetached } dataID := stream.dataIDSnapshot() fastPathVersion := stream.fastPathVersionSnapshot() if binding := logical.transportBindingSnapshot(); binding != nil && binding.queueSnapshot() != nil && streamFastPathSupportsBatch(fastPathVersion) { if sender := binding.serverStreamBatchSenderSnapshot(logical); sender != nil { if maxPayload := streamAdaptiveFramePayloadLimit(binding); maxPayload > 0 && len(chunk) > maxPayload { startSeq := stream.reserveOutboundDataSeqs(streamFastSplitFrameCount(len(chunk), maxPayload)) return sender.submitFrames(ctx, fastPathVersion, buildStreamFastSplitFrames(dataID, startSeq, chunk, maxPayload)) } seq := stream.reserveOutboundDataSeqs(1) if len(chunk) < streamFastBatchDirectLimit { return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk) } payload, err := s.encodeFastStreamDataPayloadLogical(logical, dataID, seq, chunk) if err != nil { return err } return sender.submitEncoded(ctx, fastPathVersion, payload) } } seq := stream.reserveOutboundDataSeqs(1) payload, err := s.encodeFastStreamDataPayloadLogical(logical, dataID, seq, chunk) if err != nil { return err } return s.writeEnvelopePayload(logical, transport, nil, payload) }