notify/stream_fastpath.go

305 lines
10 KiB
Go
Raw Permalink Normal View History

package notify
import (
"context"
"encoding/binary"
"errors"
"io"
)
// errStreamFastPayloadInvalid reports a malformed fast-path frame: a header
// buffer that is too short, or a decoded frame with a bad version, type,
// length or data ID.
var errStreamFastPayloadInvalid = errors.New("invalid stream fast payload")

// errStreamFastDataIDEmpty reports an attempt to encode a frame whose data ID is 0.
var errStreamFastDataIDEmpty = errors.New("stream data id is empty")
// Wire constants for stream fast-path data frames.
const (
	// streamFastPayloadMagic prefixes every fast-path frame; payloads that do
	// not start with it are not treated as fast frames by the decoder.
	streamFastPayloadMagic = "NSF1"
	// streamFastPayloadVersion is the only header version the decoder accepts.
	streamFastPayloadVersion = 1
	// streamFastPayloadTypeData is the frame-type byte for data frames.
	streamFastPayloadTypeData = 1
	// streamFastPayloadHeaderLen is the fixed header size:
	// magic(4) + version(1) + type(1) + flags(1) + reserved(1) +
	// data ID(8) + sequence(8) + payload length(4).
	streamFastPayloadHeaderLen = 4 + 1 + 1 + 1 + 1 + 8 + 8 + 4
	// streamFastBatchDirectLimit: on the batch path, chunks shorter than this
	// are handed to the sender raw (submitData); longer ones are encoded first
	// and handed over with submitEncoded.
	streamFastBatchDirectLimit = 512 << 10
)
// streamFastDataFrame is a single fast-path data frame, either about to be
// encoded or just decoded from the wire.
type streamFastDataFrame struct {
// Flags is carried in header byte 6. The fastStreamEncode encoders are only
// used when it is 0; other values go through the plain-frame encoders.
Flags uint8
// DataID identifies the stream; 0 is rejected by both encoder and decoder.
DataID uint64
// Seq is the sequence number carried in header bytes 16:24.
Seq uint64
// Payload is the frame body. Frames from buildStreamFastSplitFrames and
// decodeStreamFastDataFrame alias the caller's buffer rather than copying it.
Payload []byte
}
// streamAdaptiveFramePayloadLimit returns the maximum payload one fast-path
// frame should carry on binding. The limit is the binding's adaptive soft
// payload size minus the frame header, capped so that header+payload never
// exceeds streamFastBatchMaxPlainBytes.
//
// A nil binding yields 0. When the soft size does not exceed the header, the
// result is 1.
// NOTE(review): that fallback forces 1-byte frames when splitting; confirm the
// soft size can never be that small in practice.
func streamAdaptiveFramePayloadLimit(binding *transportBinding) int {
	if binding == nil {
		return 0
	}
	soft := binding.streamAdaptiveSoftPayloadBytesSnapshot() - streamFastPayloadHeaderLen
	ceiling := streamFastBatchMaxPlainBytes - streamFastPayloadHeaderLen
	switch {
	case soft <= 0:
		return 1
	case soft > ceiling:
		return ceiling
	default:
		return soft
	}
}
// streamFastSplitFrameCount returns how many frames of at most maxPayload
// bytes are needed to carry size bytes (ceiling division). A non-positive
// size or maxPayload yields 1.
func streamFastSplitFrameCount(size int, maxPayload int) int {
	if size > 0 && maxPayload > 0 {
		count := size / maxPayload
		if size%maxPayload != 0 {
			count++
		}
		return count
	}
	return 1
}
// buildStreamFastSplitFrames cuts chunk into consecutive data frames for
// dataID of at most maxPayload bytes each. The frames are numbered startSeq,
// startSeq+1, and so on.
//
// It returns nil for an empty chunk. It returns a single frame when maxPayload
// is non-positive or chunk already fits in one frame.
//
// Frames share chunk's backing array; nothing is copied. Each Payload is capped
// with a full slice expression, so appending to one frame's Payload reallocates
// instead of overwriting the bytes that belong to the next frame (or the
// caller's spare capacity).
func buildStreamFastSplitFrames(dataID uint64, startSeq uint64, chunk []byte, maxPayload int) []streamFastDataFrame {
	if len(chunk) == 0 {
		return nil
	}
	if maxPayload <= 0 || len(chunk) <= maxPayload {
		return []streamFastDataFrame{{
			DataID:  dataID,
			Seq:     startSeq,
			Payload: chunk[:len(chunk):len(chunk)],
		}}
	}
	frames := make([]streamFastDataFrame, 0, streamFastSplitFrameCount(len(chunk), maxPayload))
	seq := startSeq
	for offset := 0; offset < len(chunk); offset += maxPayload {
		end := offset + maxPayload
		if end > len(chunk) {
			end = len(chunk)
		}
		frames = append(frames, streamFastDataFrame{
			DataID: dataID,
			Seq:    seq,
			// cap == len: an append cannot spill into the following frame.
			Payload: chunk[offset:end:end],
		})
		seq++
	}
	return frames
}
// encodeStreamFastDataFrameHeader writes the fixed fast-path data-frame header
// into dst[:streamFastPayloadHeaderLen]. Multi-byte fields are big-endian:
//
//	[0:4]   magic (streamFastPayloadMagic)
//	[4]     version (streamFastPayloadVersion)
//	[5]     frame type (streamFastPayloadTypeData)
//	[6]     flags, written as 0 (callers that carry flags overwrite this byte)
//	[7]     reserved, written as 0
//	[8:16]  data ID
//	[16:24] sequence number
//	[24:28] payload length
//
// It returns errStreamFastDataIDEmpty when dataID is 0. It returns
// errStreamFastPayloadInvalid when dst is shorter than the header or when
// payloadLen does not fit in the 32-bit length field.
func encodeStreamFastDataFrameHeader(dst []byte, dataID uint64, seq uint64, payloadLen int) error {
	if dataID == 0 {
		return errStreamFastDataIDEmpty
	}
	if len(dst) < streamFastPayloadHeaderLen {
		return errStreamFastPayloadInvalid
	}
	// uint32(payloadLen) would silently wrap a negative or >4 GiB length. That
	// produces a header whose length field disagrees with the frame, which the
	// decoder then rejects far from the real cause.
	if payloadLen < 0 || uint64(payloadLen) > uint64(^uint32(0)) {
		return errStreamFastPayloadInvalid
	}
	copy(dst[:4], streamFastPayloadMagic)
	dst[4] = streamFastPayloadVersion
	dst[5] = streamFastPayloadTypeData
	dst[6] = 0
	dst[7] = 0
	binary.BigEndian.PutUint64(dst[8:16], dataID)
	binary.BigEndian.PutUint64(dst[16:24], seq)
	binary.BigEndian.PutUint32(dst[24:28], uint32(payloadLen))
	return nil
}
func encodeStreamFastDataFrame(dataID uint64, seq uint64, payload []byte) ([]byte, error) {
frame := make([]byte, streamFastPayloadHeaderLen+len(payload))
if err := encodeStreamFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil {
return nil, err
}
copy(frame[streamFastPayloadHeaderLen:], payload)
return frame, nil
}
// encodeStreamFastFramePayload serializes frame into a newly allocated plain
// buffer. The buffer holds the header, with frame.Flags at byte 6, followed by
// a copy of frame.Payload. Header validation errors are returned unchanged.
func encodeStreamFastFramePayload(frame streamFastDataFrame) ([]byte, error) {
	// Allocate the exact final capacity up front so the append below never reallocates.
	out := make([]byte, streamFastPayloadHeaderLen, streamFastPayloadHeaderLen+len(frame.Payload))
	err := encodeStreamFastDataFrameHeader(out, frame.DataID, frame.Seq, len(frame.Payload))
	if err != nil {
		return nil, err
	}
	out[6] = frame.Flags
	return append(out, frame.Payload...), nil
}
// encodeStreamFastFramePayloadFast encodes frame through encode. The header
// (with frame.Flags at byte 6) and the payload are written directly into the
// destination buffer that encode provides, which is requested as exactly
// header+payload bytes.
// A nil encoder yields errTransportPayloadEncryptFailed.
func encodeStreamFastFramePayloadFast(encode transportFastPlainEncoder, secretKey []byte, frame streamFastDataFrame) ([]byte, error) {
	if encode == nil {
		return nil, errTransportPayloadEncryptFailed
	}
	fill := func(dst []byte) error {
		err := encodeStreamFastDataFrameHeader(dst, frame.DataID, frame.Seq, len(frame.Payload))
		if err != nil {
			return err
		}
		dst[6] = frame.Flags
		copy(dst[streamFastPayloadHeaderLen:], frame.Payload)
		return nil
	}
	return encode(secretKey, streamFastPayloadHeaderLen+len(frame.Payload), fill)
}
// decodeStreamFastDataFrame parses payload as a fast-path data frame.
//
// The boolean result reports whether payload starts with
// streamFastPayloadMagic. When it is false, payload is not a fast frame and the
// error is nil. When it is true, errStreamFastPayloadInvalid is returned for:
//   - a truncated header;
//   - an unexpected version or frame type;
//   - a length field that disagrees with len(payload);
//   - a zero data ID.
//
// Flags (byte 6) are passed through unchecked and reserved byte 7 is ignored.
// The returned Payload aliases payload; nothing is copied.
func decodeStreamFastDataFrame(payload []byte) (streamFastDataFrame, bool, error) {
	const headerLen = streamFastPayloadHeaderLen
	var none streamFastDataFrame
	magicLen := len(streamFastPayloadMagic)
	if len(payload) < magicLen || string(payload[:magicLen]) != streamFastPayloadMagic {
		return none, false, nil
	}
	if len(payload) < headerLen {
		return none, true, errStreamFastPayloadInvalid
	}
	declared := int(binary.BigEndian.Uint32(payload[24:28]))
	dataID := binary.BigEndian.Uint64(payload[8:16])
	// declared can only be negative where int is 32 bits wide.
	wellFormed := payload[4] == streamFastPayloadVersion &&
		payload[5] == streamFastPayloadTypeData &&
		declared >= 0 &&
		len(payload) == headerLen+declared &&
		dataID != 0
	if !wellFormed {
		return none, true, errStreamFastPayloadInvalid
	}
	return streamFastDataFrame{
		Flags:   payload[6],
		DataID:  dataID,
		Seq:     binary.BigEndian.Uint64(payload[16:24]),
		Payload: payload[headerLen:],
	}, true, nil
}
// encodeFastStreamPayload encodes frame for the client transport. It picks the
// first available encoder from the current protection profile:
//  1. profile.fastStreamEncode, which takes the frame fields directly and has
//     no flags argument, so it is only used when frame.Flags == 0;
//  2. profile.fastPlainEncode, fed the plain header+payload via
//     encodeStreamFastFramePayloadFast;
//  3. otherwise the plain frame is built and passed to encryptTransportPayload.
//
// A nil receiver returns errStreamClientNil, matching encodeFastStreamBatchPayload.
// The original code took the profile snapshot before checking c, so its later
// c != nil tests could never protect that call or the final
// encryptTransportPayload call.
func (c *ClientCommon) encodeFastStreamPayload(frame streamFastDataFrame) ([]byte, error) {
	if c == nil {
		return nil, errStreamClientNil
	}
	profile := c.clientTransportProtectionSnapshot()
	if profile.fastStreamEncode != nil && frame.Flags == 0 {
		return profile.fastStreamEncode(profile.secretKey, frame.DataID, frame.Seq, frame.Payload)
	}
	if profile.fastPlainEncode != nil {
		return encodeStreamFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame)
	}
	plain, err := encodeStreamFastFramePayload(frame)
	if err != nil {
		return nil, err
	}
	return c.encryptTransportPayload(plain)
}
// encodeFastStreamDataPayload encodes chunk as an unflagged data frame with
// sequence number seq for stream dataID, using encodeFastStreamPayload.
func (c *ClientCommon) encodeFastStreamDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
	var frame streamFastDataFrame
	frame.DataID = dataID
	frame.Seq = seq
	frame.Payload = chunk
	return c.encodeFastStreamPayload(frame)
}
// encodeFastStreamBatchPayload encodes frames as a single batch payload. It
// uses the profile's fastPlainEncode when present. Otherwise it builds the
// plain batch with encodeStreamFastBatchPlain and encrypts it with
// encryptTransportPayload. A nil receiver returns errStreamClientNil.
func (c *ClientCommon) encodeFastStreamBatchPayload(frames []streamFastDataFrame) ([]byte, error) {
	if c == nil {
		return nil, errStreamClientNil
	}
	profile := c.clientTransportProtectionSnapshot()
	if profile.fastPlainEncode == nil {
		plain, err := encodeStreamFastBatchPlain(frames)
		if err != nil {
			return nil, err
		}
		return c.encryptTransportPayload(plain)
	}
	return encodeStreamFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames)
}
// sendFastStreamData sends chunk as fast-path data for stream.
//
// The batch path is used when all of these hold: the client has a transport
// binding, the stream's fast-path version supports batching, and the binding
// provides a client batch sender. On that path:
//   - chunks longer than the adaptive per-frame limit are split into frames
//     with consecutive sequence numbers and handed over with submitFrames;
//   - chunks shorter than streamFastBatchDirectLimit go raw via submitData;
//   - anything else is encoded first and handed over with submitEncoded.
//
// Otherwise the chunk is encoded and written with writePayloadToTransport.
// A nil stream returns io.ErrClosedPipe.
//
// Sequence numbers are reserved before encoding, so an encode error leaves a
// gap in the outbound sequence.
// NOTE(review): confirm the receiver tolerates such gaps.
func (c *ClientCommon) sendFastStreamData(ctx context.Context, stream *streamHandle, chunk []byte) error {
	if stream == nil {
		return io.ErrClosedPipe
	}
	id := stream.dataIDSnapshot()
	version := stream.fastPathVersionSnapshot()
	binding := c.clientTransportBindingSnapshot()
	if binding != nil && streamFastPathSupportsBatch(version) {
		if sender := binding.clientStreamBatchSenderSnapshot(c); sender != nil {
			limit := streamAdaptiveFramePayloadLimit(binding)
			if limit > 0 && len(chunk) > limit {
				first := stream.reserveOutboundDataSeqs(streamFastSplitFrameCount(len(chunk), limit))
				frames := buildStreamFastSplitFrames(id, first, chunk, limit)
				return sender.submitFrames(ctx, version, frames)
			}
			seq := stream.reserveOutboundDataSeqs(1)
			if len(chunk) >= streamFastBatchDirectLimit {
				encoded, err := c.encodeFastStreamDataPayload(id, seq, chunk)
				if err != nil {
					return err
				}
				return sender.submitEncoded(ctx, version, encoded)
			}
			return sender.submitData(ctx, id, seq, version, chunk)
		}
	}
	seq := stream.reserveOutboundDataSeqs(1)
	encoded, err := c.encodeFastStreamDataPayload(id, seq, chunk)
	if err != nil {
		return err
	}
	return c.writePayloadToTransport(encoded)
}
// encodeFastStreamPayloadLogical encodes frame for logical. It picks the first
// available encoder:
//  1. logical's fastStreamEncode, which has no flags argument, so it is only
//     used when frame.Flags == 0;
//  2. logical's fastPlainEncode, via encodeStreamFastFramePayloadFast;
//  3. otherwise the plain frame is built and passed to
//     encryptTransportPayloadLogical.
//
// A nil logical returns errTransportDetached.
func (s *ServerCommon) encodeFastStreamPayloadLogical(logical *LogicalConn, frame streamFastDataFrame) ([]byte, error) {
	if logical == nil {
		return nil, errTransportDetached
	}
	streamEncode := logical.fastStreamEncodeSnapshot()
	if streamEncode != nil && frame.Flags == 0 {
		return streamEncode(logical.secretKeySnapshot(), frame.DataID, frame.Seq, frame.Payload)
	}
	plainEncode := logical.fastPlainEncodeSnapshot()
	if plainEncode != nil {
		return encodeStreamFastFramePayloadFast(plainEncode, logical.secretKeySnapshot(), frame)
	}
	plain, err := encodeStreamFastFramePayload(frame)
	if err != nil {
		return nil, err
	}
	return s.encryptTransportPayloadLogical(logical, plain)
}
// encodeFastStreamDataPayloadLogical encodes chunk as an unflagged data frame
// with sequence number seq for stream dataID on logical.
func (s *ServerCommon) encodeFastStreamDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
	frame := streamFastDataFrame{Payload: chunk, Seq: seq, DataID: dataID}
	return s.encodeFastStreamPayloadLogical(logical, frame)
}
// encodeFastStreamBatchPayloadLogical encodes frames as a single batch payload
// for logical. It uses logical's fastPlainEncode when present. Otherwise it
// builds the plain batch and encrypts it with encryptTransportPayloadLogical.
// A nil logical returns errTransportDetached.
func (s *ServerCommon) encodeFastStreamBatchPayloadLogical(logical *LogicalConn, frames []streamFastDataFrame) ([]byte, error) {
	if logical == nil {
		return nil, errTransportDetached
	}
	plainEncode := logical.fastPlainEncodeSnapshot()
	if plainEncode == nil {
		plain, err := encodeStreamFastBatchPlain(frames)
		if err != nil {
			return nil, err
		}
		return s.encryptTransportPayloadLogical(logical, plain)
	}
	return encodeStreamFastBatchPayloadFast(plainEncode, logical.secretKeySnapshot(), frames)
}
// sendFastStreamDataTransport sends chunk as fast-path data for stream on the
// server side.
//
// Preconditions, checked in this order:
//   - transport must pass ensureServerTransportSendReady;
//   - a nil stream returns io.ErrClosedPipe;
//   - a nil logical is taken from transport, and if it is still nil the call
//     fails with errTransportDetached.
//
// The batch path is used when logical's binding exists and has a queue, the
// stream's fast-path version supports batching, and a server batch sender is
// available. On that path the chunk is split (submitFrames), sent raw when
// shorter than streamFastBatchDirectLimit (submitData), or encoded first
// (submitEncoded). Otherwise it is encoded and written with writeEnvelopePayload.
//
// Sequence numbers are reserved before encoding, so an encode error leaves a
// gap in the outbound sequence.
// NOTE(review): confirm the receiver tolerates such gaps.
func (s *ServerCommon) sendFastStreamDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, stream *streamHandle, chunk []byte) error {
	if err := s.ensureServerTransportSendReady(transport); err != nil {
		return err
	}
	if stream == nil {
		return io.ErrClosedPipe
	}
	if logical == nil {
		if transport != nil {
			logical = transport.logicalConnSnapshot()
		}
		if logical == nil {
			return errTransportDetached
		}
	}
	id := stream.dataIDSnapshot()
	version := stream.fastPathVersionSnapshot()
	binding := logical.transportBindingSnapshot()
	if binding != nil && binding.queueSnapshot() != nil && streamFastPathSupportsBatch(version) {
		if sender := binding.serverStreamBatchSenderSnapshot(logical); sender != nil {
			limit := streamAdaptiveFramePayloadLimit(binding)
			if limit > 0 && len(chunk) > limit {
				first := stream.reserveOutboundDataSeqs(streamFastSplitFrameCount(len(chunk), limit))
				frames := buildStreamFastSplitFrames(id, first, chunk, limit)
				return sender.submitFrames(ctx, version, frames)
			}
			seq := stream.reserveOutboundDataSeqs(1)
			if len(chunk) >= streamFastBatchDirectLimit {
				encoded, err := s.encodeFastStreamDataPayloadLogical(logical, id, seq, chunk)
				if err != nil {
					return err
				}
				return sender.submitEncoded(ctx, version, encoded)
			}
			return sender.submitData(ctx, id, seq, version, chunk)
		}
	}
	seq := stream.reserveOutboundDataSeqs(1)
	encoded, err := s.encodeFastStreamDataPayloadLogical(logical, id, seq, chunk)
	if err != nil {
		return err
	}
	return s.writeEnvelopePayload(logical, transport, nil, encoded)
}