notify/bulk_fastpath.go

631 lines
20 KiB
Go
Raw Permalink Normal View History

package notify
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"sync"
"time"
)
// errBulkFastPayloadInvalid is returned when a buffer carries the bulk
// fast-path magic but fails structural validation, or when a frame header
// cannot be encoded into the supplied destination.
var errBulkFastPayloadInvalid = errors.New("invalid bulk fast payload")

// bulkFastFrameScratchPool recycles the scratch buffers handed out by
// getBulkFastFrameScratch and returned through putBulkFastFrameScratch.
var bulkFastFrameScratchPool sync.Pool

// Bulk fast-path frame header, big-endian:
//
//	[0:4]   magic "NBF1"
//	[4]     version
//	[5]     frame type (bulkFastPayloadType*)
//	[6]     flags (bulkFastPayloadFlag*)
//	[7]     reserved, written as zero
//	[8:16]  data ID
//	[16:24] sequence
//	[24:28] payload length
const (
	bulkFastPayloadMagic   = "NBF1"
	bulkFastPayloadVersion = 1

	// Frame types stored in header byte 5.
	bulkFastPayloadTypeData    = 1
	bulkFastPayloadTypeClose   = 2
	bulkFastPayloadTypeReset   = 3
	bulkFastPayloadTypeRelease = 4

	// magic + version + type + flags + reserved + data ID + sequence + length.
	bulkFastPayloadHeaderLen = 4 + 1 + 1 + 1 + 1 + 8 + 8 + 4

	// Flag bits stored in header byte 6.
	bulkFastPayloadFlagFullClose = 1 << 0
)
// bulkFastFrame is the in-memory form of one bulk fast-path frame.
// encodeBulkFastFrameHeader serialises its header as
// magic(4) | version(1) | type(1) | flags(1) | reserved(1) | dataID(8) | seq(8) | payloadLen(4),
// big-endian, bulkFastPayloadHeaderLen bytes in total, followed by Payload.
type bulkFastFrame struct {
// Type is header byte 5; decodeBulkFastFrame accepts only the
// bulkFastPayloadType* constants.
Type uint8
// Flags is header byte 6, e.g. bulkFastPayloadFlagFullClose.
Flags uint8
// DataID is header bytes [8:16]; zero is rejected by both the header encoder
// and decodeBulkFastFrame.
DataID uint64
// Seq is carried verbatim in header bytes [16:24].
Seq uint64
// Payload is the frame body. After decodeBulkFastFrame it aliases the input
// buffer rather than owning a copy.
Payload []byte
}
// bulkFastDataFrame is a type alias (the same type, not a distinct one) for
// bulkFastFrame; it is the result type of decodeBulkFastDataFrame.
type bulkFastDataFrame = bulkFastFrame
// encodeBulkFastFrameHeader writes the fixed bulkFastPayloadHeaderLen-byte
// frame header into dst[:bulkFastPayloadHeaderLen]. Only the header is
// written; the caller places payloadLen bytes of payload after it.
//
// Errors:
//   - errBulkDataIDEmpty when dataID is zero;
//   - errBulkFastPayloadInvalid when dst is shorter than the header, or when
//     payloadLen does not fit the 32-bit length field (negative or greater
//     than 0xFFFFFFFF). Without this check the uint32 conversion below would
//     silently wrap, and the resulting frame would fail the decoder's
//     header-length == buffer-length consistency check.
func encodeBulkFastFrameHeader(dst []byte, frameType uint8, flags uint8, dataID uint64, seq uint64, payloadLen int) error {
	if dataID == 0 {
		return errBulkDataIDEmpty
	}
	if len(dst) < bulkFastPayloadHeaderLen {
		return errBulkFastPayloadInvalid
	}
	if payloadLen < 0 || uint64(payloadLen) > uint64(^uint32(0)) {
		return errBulkFastPayloadInvalid
	}
	copy(dst[:4], bulkFastPayloadMagic)
	dst[4] = bulkFastPayloadVersion
	dst[5] = frameType
	dst[6] = flags
	dst[7] = 0 // reserved
	binary.BigEndian.PutUint64(dst[8:16], dataID)
	binary.BigEndian.PutUint64(dst[16:24], seq)
	binary.BigEndian.PutUint32(dst[24:28], uint32(payloadLen))
	return nil
}
// encodeBulkFastDataFrameHeader is encodeBulkFastFrameHeader specialised for
// data frames: the type byte is bulkFastPayloadTypeData and no flags are set.
// It has the same error behaviour as encodeBulkFastFrameHeader.
func encodeBulkFastDataFrameHeader(dst []byte, dataID uint64, seq uint64, payloadLen int) error {
	const noFlags uint8 = 0
	return encodeBulkFastFrameHeader(dst, bulkFastPayloadTypeData, noFlags, dataID, seq, payloadLen)
}
// encodeBulkFastDataFrame allocates a complete data frame: the header
// followed by a copy of payload. The result never aliases payload. When the
// header cannot be encoded (for example dataID is zero) it returns a nil
// slice and the header encoder's error.
func encodeBulkFastDataFrame(dataID uint64, seq uint64, payload []byte) ([]byte, error) {
	buf := make([]byte, bulkFastPayloadHeaderLen, bulkFastPayloadHeaderLen+len(payload))
	err := encodeBulkFastDataFrameHeader(buf, dataID, seq, len(payload))
	if err != nil {
		return nil, err
	}
	return append(buf, payload...), nil
}
// encodeBulkFastControlFrame allocates a complete frame of the given type and
// flags: header followed by a copy of payload. Despite the name it does not
// restrict frameType. On a header encoding error it returns a nil slice.
func encodeBulkFastControlFrame(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
	bodyOffset := bulkFastPayloadHeaderLen
	out := make([]byte, bodyOffset+len(payload))
	if err := encodeBulkFastFrameHeader(out[:bodyOffset], frameType, flags, dataID, seq, len(payload)); err != nil {
		return nil, err
	}
	copy(out[bodyOffset:], payload)
	return out, nil
}
// decodeBulkFastFrame parses one bulk fast-path frame.
//
// The boolean result reports whether payload starts with the bulk fast-path
// magic. When it does not, the result is (zero, false, nil) so callers can try
// other decoders. Once the magic matches, every structural problem yields
// (zero, true, errBulkFastPayloadInvalid):
//   - a buffer shorter than the header;
//   - an unsupported version;
//   - an unknown frame type;
//   - a length field that disagrees with the buffer size;
//   - a zero data ID.
//
// The returned frame's Payload aliases payload; nothing is copied.
func decodeBulkFastFrame(payload []byte) (bulkFastFrame, bool, error) {
	var zero bulkFastFrame
	magicLen := len(bulkFastPayloadMagic)
	if len(payload) < magicLen || string(payload[:magicLen]) != bulkFastPayloadMagic {
		return zero, false, nil
	}
	if len(payload) < bulkFastPayloadHeaderLen {
		return zero, true, errBulkFastPayloadInvalid
	}
	version, frameType, flags := payload[4], payload[5], payload[6]
	if version != bulkFastPayloadVersion {
		return zero, true, errBulkFastPayloadInvalid
	}
	switch frameType {
	case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
		// Known frame type.
	default:
		return zero, true, errBulkFastPayloadInvalid
	}
	bodyLen := int(binary.BigEndian.Uint32(payload[24:28]))
	// bodyLen can only be negative where int is 32 bits wide.
	if bodyLen < 0 || len(payload)-bulkFastPayloadHeaderLen != bodyLen {
		return zero, true, errBulkFastPayloadInvalid
	}
	dataID := binary.BigEndian.Uint64(payload[8:16])
	if dataID == 0 {
		return zero, true, errBulkFastPayloadInvalid
	}
	frame := bulkFastFrame{
		Type:    frameType,
		Flags:   flags,
		DataID:  dataID,
		Seq:     binary.BigEndian.Uint64(payload[16:24]),
		Payload: payload[bulkFastPayloadHeaderLen:],
	}
	return frame, true, nil
}
// decodeBulkFastDataFrame is decodeBulkFastFrame restricted to data frames.
// Buffers without the bulk magic, and well-formed frames of any non-data
// type, are both reported as not matched: (zero, false, nil). Malformed bulk
// frames return the decoder's result unchanged (matched=true with an error).
func decodeBulkFastDataFrame(payload []byte) (bulkFastDataFrame, bool, error) {
	frame, matched, err := decodeBulkFastFrame(payload)
	switch {
	case !matched, err != nil:
		return frame, matched, err
	case frame.Type == bulkFastPayloadTypeData:
		return frame, true, nil
	default:
		return bulkFastDataFrame{}, false, nil
	}
}
// encodeFastBulkDataPayload wraps chunk in a flag-less data frame for
// dataID/seq and encodes it for the client transport via
// encodeBulkFastPayload.
func (c *ClientCommon) encodeFastBulkDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
	var frame bulkFastFrame
	frame.Type = bulkFastPayloadTypeData
	frame.DataID = dataID
	frame.Seq = seq
	frame.Payload = chunk
	return c.encodeBulkFastPayload(frame)
}
// encodeBulkFastPayload encodes a single frame for the client transport.
//
// When the current protection profile supplies fastPlainEncode, encoding is
// delegated to encodeBulkFastFramePayloadFast with the profile's secret key.
// Otherwise the frame is encoded to plaintext with encodeBulkFastFramePayload
// and then passed through encryptTransportPayload. A nil receiver returns
// errBulkClientNil.
func (c *ClientCommon) encodeBulkFastPayload(frame bulkFastFrame) ([]byte, error) {
	if c == nil {
		return nil, errBulkClientNil
	}
	profile := c.clientTransportProtectionSnapshot()
	if encode := profile.fastPlainEncode; encode != nil {
		return encodeBulkFastFramePayloadFast(encode, profile.secretKey, frame)
	}
	plain, err := encodeBulkFastFramePayload(frame)
	if err != nil {
		return nil, err
	}
	return c.encryptTransportPayload(plain)
}
// encodeBulkFastBatchPayload encodes several frames as one batch payload for
// the client transport. It follows the same strategy as encodeBulkFastPayload:
// the profile's fastPlainEncode when present, otherwise plaintext batch
// encoding followed by encryptTransportPayload. A nil receiver returns
// errBulkClientNil.
func (c *ClientCommon) encodeBulkFastBatchPayload(frames []bulkFastFrame) ([]byte, error) {
	if c == nil {
		return nil, errBulkClientNil
	}
	profile := c.clientTransportProtectionSnapshot()
	if profile.fastPlainEncode == nil {
		plain, err := encodeBulkFastBatchPlain(frames)
		if err != nil {
			return nil, err
		}
		return c.encryptTransportPayload(plain)
	}
	return encodeBulkFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames)
}
// encodeBulkFastPayloadPooled is encodeBulkFastPayload with an optional
// release callback for pooled output.
//
// Path selection, in order:
//  1. A profile runtime: encodeBulkFastFramePayloadPooled supplies the
//     payload and the release callback.
//  2. A profile fastPlainEncode: encodeBulkFastFramePayloadFast, with a nil
//     release.
//  3. Otherwise: plaintext encoding then encryptTransportPayload, with a nil
//     release.
//
// A nil receiver returns errBulkClientNil.
func (c *ClientCommon) encodeBulkFastPayloadPooled(frame bulkFastFrame) ([]byte, func(), error) {
	if c == nil {
		return nil, nil, errBulkClientNil
	}
	profile := c.clientTransportProtectionSnapshot()
	switch {
	case profile.runtime != nil:
		return encodeBulkFastFramePayloadPooled(profile.runtime, frame)
	case profile.fastPlainEncode != nil:
		encoded, err := encodeBulkFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame)
		return encoded, nil, err
	}
	plain, err := encodeBulkFastFramePayload(frame)
	if err != nil {
		return nil, nil, err
	}
	encoded, err := c.encryptTransportPayload(plain)
	return encoded, nil, err
}
// encodeBulkFastBatchPayloadPooled is the batch counterpart of
// encodeBulkFastPayloadPooled.
//
// Path selection, in order:
//  1. A profile runtime: the pooled encoder supplies the release callback.
//  2. A profile fastPlainEncode: nil release.
//  3. Otherwise: plaintext batch encoding then encryptTransportPayload, with a
//     nil release.
//
// A nil receiver returns errBulkClientNil.
func (c *ClientCommon) encodeBulkFastBatchPayloadPooled(frames []bulkFastFrame) ([]byte, func(), error) {
	if c == nil {
		return nil, nil, errBulkClientNil
	}
	profile := c.clientTransportProtectionSnapshot()
	switch {
	case profile.runtime != nil:
		return encodeBulkFastBatchPayloadPooled(profile.runtime, frames)
	case profile.fastPlainEncode != nil:
		encoded, err := encodeBulkFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames)
		return encoded, nil, err
	}
	plain, err := encodeBulkFastBatchPlain(frames)
	if err != nil {
		return nil, nil, err
	}
	encoded, err := c.encryptTransportPayload(plain)
	return encoded, nil, err
}
// sendFastBulkData sends one data chunk for dataID/seq.
//
// When the current transport binding has a bulk batch sender, the chunk is
// submitted to it together with ctx and fastPathVersion. Otherwise the chunk
// is encoded as a standalone data frame and written with
// writePayloadToTransport; ctx and fastPathVersion are not consulted on that
// path. A missing binding yields net.ErrClosed.
func (c *ClientCommon) sendFastBulkData(ctx context.Context, dataID uint64, seq uint64, chunk []byte, fastPathVersion uint8) error {
	binding := c.clientTransportBindingSnapshot()
	if binding == nil {
		return net.ErrClosed
	}
	sender := binding.clientBulkBatchSenderSnapshot(c)
	if sender == nil {
		encoded, err := c.encodeFastBulkDataPayload(dataID, seq, chunk)
		if err != nil {
			return err
		}
		return c.writePayloadToTransport(encoded)
	}
	return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk)
}
// sendFastBulkWrite sends payload as consecutive data chunks starting at
// startSeq and returns the number of payload bytes handed off.
//
// An empty payload is a no-op returning (0, nil). A missing transport binding
// yields net.ErrClosed.
//
// When the binding has a bulk batch sender, the whole write is delegated to
// sender.submitWrite; chunkSize and payloadOwned are passed through unchanged.
//
// Otherwise payload is split into chunks of at most chunkSize bytes
// (defaultBulkChunkSize when chunkSize <= 0). Each chunk is sent with
// sendFastBulkData using consecutive sequence numbers. On error the returned
// count covers only the chunks that completed before the failure.
//
// In that direct path ctx is checked before each chunk, so a cancelled or
// expired context stops a large write and returns ctx.Err(). Previously the
// remaining chunks were pushed regardless. A nil ctx is tolerated and simply
// skips the check.
func (c *ClientCommon) sendFastBulkWrite(ctx context.Context, dataID uint64, startSeq uint64, chunkSize int, fastPathVersion uint8, payload []byte, payloadOwned bool) (int, error) {
	if len(payload) == 0 {
		return 0, nil
	}
	binding := c.clientTransportBindingSnapshot()
	if binding == nil {
		return 0, net.ErrClosed
	}
	if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil {
		return sender.submitWrite(ctx, dataID, startSeq, fastPathVersion, payload, chunkSize, payloadOwned)
	}
	if chunkSize <= 0 {
		chunkSize = defaultBulkChunkSize
	}
	written := 0
	seq := startSeq
	for written < len(payload) {
		if ctx != nil {
			if err := ctx.Err(); err != nil {
				return written, err
			}
		}
		// Size the chunk from the remaining length so that written+chunkSize
		// cannot overflow int for very large chunkSize values.
		n := len(payload) - written
		if n > chunkSize {
			n = chunkSize
		}
		end := written + n
		if err := c.sendFastBulkData(ctx, dataID, seq, payload[written:end], fastPathVersion); err != nil {
			return written, err
		}
		seq++
		written = end
	}
	return written, nil
}
// sendFastBulkControl sends a single frame of the given type and flags
// (typically a close/reset/release control frame) for dataID/seq.
//
// With a bulk batch sender on the current binding, the frame fields are
// submitted to sender.submitControl. Otherwise the frame is encoded with
// encodeBulkFastPayload and written with writePayloadToTransport. A missing
// binding yields net.ErrClosed.
func (c *ClientCommon) sendFastBulkControl(ctx context.Context, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
	binding := c.clientTransportBindingSnapshot()
	if binding == nil {
		return net.ErrClosed
	}
	if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil {
		return sender.submitControl(ctx, frameType, flags, dataID, seq, fastPathVersion, payload)
	}
	encoded, err := c.encodeBulkFastPayload(bulkFastFrame{
		Type:    frameType,
		Flags:   flags,
		DataID:  dataID,
		Seq:     seq,
		Payload: payload,
	})
	if err != nil {
		return err
	}
	return c.writePayloadToTransport(encoded)
}
// encodeBulkFastControlPayload builds a plaintext frame with
// encodeBulkFastControlFrame and encrypts it with encryptTransportPayload.
// Unlike encodeBulkFastPayload it never takes the profile's fastPlainEncode
// path.
func (c *ClientCommon) encodeBulkFastControlPayload(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
	frame, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload)
	if err != nil {
		return nil, err
	}
	return c.encryptTransportPayload(frame)
}
// encodeFastBulkDataPayloadLogical wraps chunk in a flag-less data frame for
// dataID/seq and encodes it for logical via encodeBulkFastPayloadLogical.
func (s *ServerCommon) encodeFastBulkDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
	frame := bulkFastFrame{Type: bulkFastPayloadTypeData, DataID: dataID, Seq: seq, Payload: chunk}
	return s.encodeBulkFastPayloadLogical(logical, frame)
}
// encodeBulkFastPayloadLogical encodes a single frame for the given logical
// connection.
//
// If the connection exposes a fastPlainEncode, encoding is delegated to
// encodeBulkFastFramePayloadFast with the connection's secret key. Otherwise
// the frame is encoded to plaintext and encrypted with
// encryptTransportPayloadLogical. A nil logical returns errTransportDetached.
func (s *ServerCommon) encodeBulkFastPayloadLogical(logical *LogicalConn, frame bulkFastFrame) ([]byte, error) {
	if logical == nil {
		return nil, errTransportDetached
	}
	fastEncode := logical.fastPlainEncodeSnapshot()
	if fastEncode == nil {
		plain, err := encodeBulkFastFramePayload(frame)
		if err != nil {
			return nil, err
		}
		return s.encryptTransportPayloadLogical(logical, plain)
	}
	return encodeBulkFastFramePayloadFast(fastEncode, logical.secretKeySnapshot(), frame)
}
// encodeBulkFastBatchPayloadLogical encodes several frames as one batch
// payload for logical. It uses the connection's fastPlainEncode when present,
// otherwise plaintext batch encoding followed by
// encryptTransportPayloadLogical. A nil logical returns errTransportDetached.
func (s *ServerCommon) encodeBulkFastBatchPayloadLogical(logical *LogicalConn, frames []bulkFastFrame) ([]byte, error) {
	if logical == nil {
		return nil, errTransportDetached
	}
	fastEncode := logical.fastPlainEncodeSnapshot()
	if fastEncode != nil {
		return encodeBulkFastBatchPayloadFast(fastEncode, logical.secretKeySnapshot(), frames)
	}
	plain, err := encodeBulkFastBatchPlain(frames)
	if err != nil {
		return nil, err
	}
	return s.encryptTransportPayloadLogical(logical, plain)
}
// encodeBulkFastPayloadLogicalPooled encodes a single frame for logical and
// may return a release callback for pooled output.
//
// Path selection, in order:
//  1. A modern PSK runtime: encodeBulkFastFramePayloadPooled supplies the
//     release callback.
//  2. A fastPlainEncode: nil release.
//  3. Otherwise: plaintext encoding then encryptTransportPayloadLogical, with
//     a nil release.
//
// A nil logical returns errTransportDetached.
func (s *ServerCommon) encodeBulkFastPayloadLogicalPooled(logical *LogicalConn, frame bulkFastFrame) ([]byte, func(), error) {
	if logical == nil {
		return nil, nil, errTransportDetached
	}
	if rt := logical.modernPSKRuntimeSnapshot(); rt != nil {
		return encodeBulkFastFramePayloadPooled(rt, frame)
	}
	if fastEncode := logical.fastPlainEncodeSnapshot(); fastEncode != nil {
		out, err := encodeBulkFastFramePayloadFast(fastEncode, logical.secretKeySnapshot(), frame)
		return out, nil, err
	}
	plain, err := encodeBulkFastFramePayload(frame)
	if err != nil {
		return nil, nil, err
	}
	out, err := s.encryptTransportPayloadLogical(logical, plain)
	return out, nil, err
}
// encodeBulkFastBatchPayloadLogicalPooled is the batch counterpart of
// encodeBulkFastPayloadLogicalPooled.
//
// Path selection, in order:
//  1. A modern PSK runtime: the pooled encoder supplies the release callback.
//  2. A fastPlainEncode: nil release.
//  3. Otherwise: plaintext batch encoding then encryptTransportPayloadLogical,
//     with a nil release.
//
// A nil logical returns errTransportDetached.
func (s *ServerCommon) encodeBulkFastBatchPayloadLogicalPooled(logical *LogicalConn, frames []bulkFastFrame) ([]byte, func(), error) {
	if logical == nil {
		return nil, nil, errTransportDetached
	}
	if rt := logical.modernPSKRuntimeSnapshot(); rt != nil {
		return encodeBulkFastBatchPayloadPooled(rt, frames)
	}
	if fastEncode := logical.fastPlainEncodeSnapshot(); fastEncode != nil {
		out, err := encodeBulkFastBatchPayloadFast(fastEncode, logical.secretKeySnapshot(), frames)
		return out, nil, err
	}
	plain, err := encodeBulkFastBatchPlain(frames)
	if err != nil {
		return nil, nil, err
	}
	out, err := s.encryptTransportPayloadLogical(logical, plain)
	return out, nil, err
}
// sendFastBulkDataTransport sends one data chunk for dataID/seq on the server
// side.
//
// ensureServerTransportSendReady(transport) must succeed first. When logical
// is nil it is resolved from transport; if that still yields nil,
// errTransportDetached is returned.
//
// If the logical connection's binding has a queue and a server bulk batch
// sender, the chunk is submitted to that sender. Otherwise it is encoded as a
// standalone data frame and written with writeEnvelopePayload.
func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte, fastPathVersion uint8) error {
	if err := s.ensureServerTransportSendReady(transport); err != nil {
		return err
	}
	if logical == nil {
		if transport != nil {
			logical = transport.logicalConnSnapshot()
		}
		if logical == nil {
			return errTransportDetached
		}
	}
	if binding := logical.transportBindingSnapshot(); binding != nil && binding.queueSnapshot() != nil {
		if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil {
			return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk)
		}
	}
	encoded, err := s.encodeFastBulkDataPayloadLogical(logical, dataID, seq, chunk)
	if err != nil {
		return err
	}
	return s.writeEnvelopePayload(logical, transport, nil, encoded)
}
// sendFastBulkWriteTransport sends payload as consecutive data chunks starting
// at startSeq on the server side and returns the number of payload bytes
// handed off.
//
// An empty payload is a no-op returning (0, nil). Otherwise
// ensureServerTransportSendReady(transport) must succeed. A nil logical is
// resolved from transport; errTransportDetached is returned if none is
// available.
//
// When the logical connection's binding has a queue and a server bulk batch
// sender, the whole write is delegated to sender.submitWrite; chunkSize and
// payloadOwned are passed through unchanged.
//
// Otherwise payload is split into chunks of at most chunkSize bytes
// (defaultBulkChunkSize when chunkSize <= 0). Each chunk is sent with
// sendFastBulkDataTransport using consecutive sequence numbers. On error the
// returned count covers only the chunks that completed before the failure.
//
// In that direct path ctx is checked before each chunk, so a cancelled or
// expired context stops a large write and returns ctx.Err(). Previously the
// remaining chunks were pushed regardless. A nil ctx is tolerated and simply
// skips the check.
func (s *ServerCommon) sendFastBulkWriteTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, startSeq uint64, chunkSize int, fastPathVersion uint8, payload []byte, payloadOwned bool) (int, error) {
	if len(payload) == 0 {
		return 0, nil
	}
	if err := s.ensureServerTransportSendReady(transport); err != nil {
		return 0, err
	}
	if logical == nil && transport != nil {
		logical = transport.logicalConnSnapshot()
	}
	if logical == nil {
		return 0, errTransportDetached
	}
	if binding := logical.transportBindingSnapshot(); binding != nil {
		if binding.queueSnapshot() != nil {
			if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil {
				return sender.submitWrite(ctx, dataID, startSeq, fastPathVersion, payload, chunkSize, payloadOwned)
			}
		}
	}
	if chunkSize <= 0 {
		chunkSize = defaultBulkChunkSize
	}
	written := 0
	seq := startSeq
	for written < len(payload) {
		if ctx != nil {
			if err := ctx.Err(); err != nil {
				return written, err
			}
		}
		// Size the chunk from the remaining length so that written+chunkSize
		// cannot overflow int for very large chunkSize values.
		n := len(payload) - written
		if n > chunkSize {
			n = chunkSize
		}
		end := written + n
		if err := s.sendFastBulkDataTransport(ctx, logical, transport, dataID, seq, payload[written:end], fastPathVersion); err != nil {
			return written, err
		}
		seq++
		written = end
	}
	return written, nil
}
// sendFastBulkControlTransport sends a single frame of the given type and
// flags for dataID/seq on the server side.
//
// ensureServerTransportSendReady(transport) must succeed first. A nil logical
// is resolved from transport; errTransportDetached is returned if none is
// available.
//
// If the binding has a queue and a server bulk batch sender, the frame fields
// go to sender.submitControl. Otherwise the frame is encoded with
// encodeBulkFastPayloadLogical and written with writeEnvelopePayload.
func (s *ServerCommon) sendFastBulkControlTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
	if err := s.ensureServerTransportSendReady(transport); err != nil {
		return err
	}
	if logical == nil {
		if transport != nil {
			logical = transport.logicalConnSnapshot()
		}
		if logical == nil {
			return errTransportDetached
		}
	}
	if binding := logical.transportBindingSnapshot(); binding != nil && binding.queueSnapshot() != nil {
		if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil {
			return sender.submitControl(ctx, frameType, flags, dataID, seq, fastPathVersion, payload)
		}
	}
	var frame bulkFastFrame
	frame.Type = frameType
	frame.Flags = flags
	frame.DataID = dataID
	frame.Seq = seq
	frame.Payload = payload
	encoded, err := s.encodeBulkFastPayloadLogical(logical, frame)
	if err != nil {
		return err
	}
	return s.writeEnvelopePayload(logical, transport, nil, encoded)
}
// encodeBulkFastControlPayloadLogical builds a plaintext frame with
// encodeBulkFastControlFrame and encrypts it for logical with
// encryptTransportPayloadLogical. It never uses the connection's
// fastPlainEncode path.
func (s *ServerCommon) encodeBulkFastControlPayloadLogical(logical *LogicalConn, frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
	frame, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload)
	if err != nil {
		return nil, err
	}
	return s.encryptTransportPayloadLogical(logical, frame)
}
// getBulkFastFrameScratch returns a buffer of length
// bulkFastPayloadHeaderLen+payloadLen. It reuses a pooled buffer when one of
// sufficient capacity is available. A pooled buffer that is too small is
// dropped rather than returned to the pool. Reused buffers may contain stale
// bytes; the caller is expected to overwrite the whole length.
func getBulkFastFrameScratch(payloadLen int) []byte {
	size := bulkFastPayloadHeaderLen + payloadLen
	pooled, ok := bulkFastFrameScratchPool.Get().([]byte)
	if !ok || cap(pooled) < size {
		return make([]byte, size)
	}
	return pooled[:size]
}
// putBulkFastFrameScratch returns buf to the scratch pool, truncated to zero
// length. Buffers with no capacity or more than 4 MiB of capacity are not
// pooled (presumably to avoid retaining oversized allocations).
//
// NOTE(review): storing a slice rather than a pointer in sync.Pool boxes the
// slice header on every Put (staticcheck SA6002).
func putBulkFastFrameScratch(buf []byte) {
	const maxPooledCap = 4 * 1024 * 1024
	if size := cap(buf); size > 0 && size <= maxPooledCap {
		bulkFastFrameScratchPool.Put(buf[:0])
	}
}
// transportFastPayloadMagic returns the first four bytes of payload as a
// string, or "" when payload is shorter than four bytes. Callers switch on it
// to select a fast-path decoder.
func transportFastPayloadMagic(payload []byte) string {
	const magicLen = 4
	if len(payload) >= magicLen {
		return string(payload[:magicLen])
	}
	return ""
}
// decryptTransportPayloadPooled decrypts payload using a snapshot of the
// client's current protection profile (mode, runtime, message decoder and
// secret key). payload and release are forwarded unchanged to
// decryptTransportPayloadCodecPooled, whose results are returned as-is.
func (c *ClientCommon) decryptTransportPayloadPooled(payload []byte, release func()) ([]byte, func(), error) {
	p := c.clientTransportProtectionSnapshot()
	mode, rt, msgDe, key := p.mode, p.runtime, p.msgDe, p.secretKey
	return decryptTransportPayloadCodecPooled(mode, rt, msgDe, key, payload, release)
}
// decryptTransportPayloadLogicalPooled decrypts payload using logical's
// current protection mode, modern PSK runtime, message decoder and secret
// key, forwarding payload and release to decryptTransportPayloadCodecPooled.
//
// When logical is nil, release (if non-nil) is invoked before returning
// errTransportDetached.
func (s *ServerCommon) decryptTransportPayloadLogicalPooled(logical *LogicalConn, payload []byte, release func()) ([]byte, func(), error) {
	if logical != nil {
		return decryptTransportPayloadCodecPooled(
			logical.protectionModeSnapshot(),
			logical.modernPSKRuntimeSnapshot(),
			logical.msgDeSnapshot(),
			logical.secretKeySnapshot(),
			payload,
			release,
		)
	}
	if release != nil {
		release()
	}
	return nil, nil, errTransportDetached
}
// tryDispatchBorrowedTransportPlain dispatches plain directly when it starts
// with a bulk or stream fast-path magic, and reports whether it did so.
//
// For those magics:
//   - release is wrapped in a read-payload owner that is handed to every
//     per-frame dispatcher;
//   - owner.done() is called once the walk finishes;
//   - a walk that does not recognize the buffer is treated as
//     errBulkFastPayloadInvalid / errStreamFastPayloadInvalid;
//   - decode errors are not returned; they are printed only when showError
//     or debugMode is set.
//
// For any other magic it returns false without calling release.
func (c *ClientCommon) tryDispatchBorrowedTransportPlain(plain []byte, release func()) bool {
	var (
		walkErr error
		label   string
	)
	switch transportFastPayloadMagic(plain) {
	case bulkFastPayloadMagic, bulkFastBatchMagic:
		owner := newBulkReadPayloadOwner(release)
		matched, err := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
			c.dispatchFastBulkFrameWithOwner(frame, owner)
			return nil
		})
		if owner != nil {
			owner.done()
		}
		if !matched {
			err = errBulkFastPayloadInvalid
		}
		walkErr, label = err, "client decode bulk fast payload error"
	case streamFastPayloadMagic, streamFastBatchMagic:
		owner := newStreamReadPayloadOwner(release)
		matched, err := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
			c.dispatchFastStreamDataWithOwner(frame, owner)
			return nil
		})
		if owner != nil {
			owner.done()
		}
		if !matched {
			err = errStreamFastPayloadInvalid
		}
		walkErr, label = err, "client decode stream fast payload error"
	default:
		return false
	}
	if walkErr != nil && (c.showError || c.debugMode) {
		fmt.Println(label, walkErr)
	}
	return true
}
// tryDispatchBorrowedTransportPlain is the server-side counterpart of the
// client dispatcher. When plain starts with a bulk or stream fast-path magic,
// its frames are dispatched with logical, transport and conn, and the
// function returns true.
//
// For those magics:
//   - release is wrapped in a read-payload owner shared by the frame
//     dispatchers;
//   - owner.done() runs after the walk;
//   - an unrecognized buffer is reported as errBulkFastPayloadInvalid /
//     errStreamFastPayloadInvalid;
//   - decode errors are only printed, and only when showError or debugMode is
//     set.
//
// Any other magic returns false without calling release.
func (s *ServerCommon) tryDispatchBorrowedTransportPlain(logical *LogicalConn, transport *TransportConn, conn net.Conn, plain []byte, release func()) bool {
	var (
		walkErr error
		label   string
	)
	switch transportFastPayloadMagic(plain) {
	case bulkFastPayloadMagic, bulkFastBatchMagic:
		owner := newBulkReadPayloadOwner(release)
		matched, err := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
			s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, owner)
			return nil
		})
		if owner != nil {
			owner.done()
		}
		if !matched {
			err = errBulkFastPayloadInvalid
		}
		walkErr, label = err, "server decode bulk fast payload error"
	case streamFastPayloadMagic, streamFastBatchMagic:
		owner := newStreamReadPayloadOwner(release)
		matched, err := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
			s.dispatchFastStreamDataWithOwner(logical, transport, conn, frame, owner)
			return nil
		})
		if owner != nil {
			owner.done()
		}
		if !matched {
			err = errStreamFastPayloadInvalid
		}
		walkErr, label = err, "server decode stream fast payload error"
	default:
		return false
	}
	if walkErr != nil && (s.showError || s.debugMode) {
		fmt.Println(label, walkErr)
	}
	return true
}
// dispatchInboundTransportPayload decrypts an inbound transport payload and
// hands the plaintext to dispatchInboundTransportPlain. A decryption error is
// returned without dispatching.
func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error {
	plain, err := c.decryptTransportPayload(payload)
	if err == nil {
		err = c.dispatchInboundTransportPlain(plain, now)
	}
	return err
}
// dispatchInboundTransportPlain routes a decrypted inbound payload. It tries
// three decoders in order:
//  1. bulk fast-path frames;
//  2. stream fast-path frames;
//  3. a regular envelope, dispatched with now.
//
// Once a fast-path walker recognizes the buffer, its error (possibly nil) is
// returned and later decoders are not tried.
func (c *ClientCommon) dispatchInboundTransportPlain(plain []byte, now time.Time) error {
	onBulk := func(frame bulkFastFrame) error {
		c.dispatchFastBulkFrame(frame)
		return nil
	}
	if matched, err := walkBulkFastFrames(plain, onBulk); matched {
		return err
	}
	onStream := func(frame streamFastDataFrame) error {
		c.dispatchFastStreamData(frame)
		return nil
	}
	if matched, err := walkStreamFastFrames(plain, onStream); matched {
		return err
	}
	env, err := c.decodeEnvelopePlain(plain)
	if err != nil {
		return err
	}
	c.dispatchEnvelope(env, now)
	return nil
}
// dispatchInboundTransportPayload decrypts an inbound payload for logical and
// hands the plaintext to dispatchInboundTransportPlain. A nil logical is
// resolved from transport; when neither yields a connection,
// errTransportDetached is returned. A decryption error is returned without
// dispatching.
func (s *ServerCommon) dispatchInboundTransportPayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte, now time.Time) error {
	if logical == nil {
		if transport == nil {
			return errTransportDetached
		}
		if logical = transport.logicalConnSnapshot(); logical == nil {
			return errTransportDetached
		}
	}
	plain, err := s.decryptTransportPayloadLogical(logical, payload)
	if err != nil {
		return err
	}
	return s.dispatchInboundTransportPlain(logical, transport, conn, plain, now)
}
// dispatchInboundTransportPlain routes a decrypted inbound payload for the
// given connection. It tries three decoders in order:
//  1. bulk fast-path frames;
//  2. stream fast-path frames;
//  3. a regular envelope, dispatched with now.
//
// Once a fast-path walker recognizes the buffer, its error (possibly nil) is
// returned and later decoders are skipped.
func (s *ServerCommon) dispatchInboundTransportPlain(logical *LogicalConn, transport *TransportConn, conn net.Conn, plain []byte, now time.Time) error {
	onBulk := func(frame bulkFastFrame) error {
		s.dispatchFastBulkFrame(logical, transport, conn, frame)
		return nil
	}
	if matched, err := walkBulkFastFrames(plain, onBulk); matched {
		return err
	}
	onStream := func(frame streamFastDataFrame) error {
		s.dispatchFastStreamData(logical, transport, conn, frame)
		return nil
	}
	if matched, err := walkStreamFastFrames(plain, onStream); matched {
		return err
	}
	env, err := s.decodeEnvelopePlain(plain)
	if err != nil {
		return err
	}
	s.dispatchEnvelope(logical, transport, conn, env, now)
	return nil
}