package notify import ( "context" "encoding/binary" "errors" "fmt" "net" "sync" "time" ) var ( errBulkFastPayloadInvalid = errors.New("invalid bulk fast payload") ) var bulkFastFrameScratchPool sync.Pool const ( bulkFastPayloadMagic = "NBF1" bulkFastPayloadVersion = 1 bulkFastPayloadTypeData = 1 bulkFastPayloadTypeClose = 2 bulkFastPayloadTypeReset = 3 bulkFastPayloadTypeRelease = 4 bulkFastPayloadHeaderLen = 28 bulkFastPayloadFlagFullClose = 1 << 0 ) type bulkFastFrame struct { Type uint8 Flags uint8 DataID uint64 Seq uint64 Payload []byte } type bulkFastDataFrame = bulkFastFrame func encodeBulkFastFrameHeader(dst []byte, frameType uint8, flags uint8, dataID uint64, seq uint64, payloadLen int) error { if dataID == 0 { return errBulkDataIDEmpty } if len(dst) < bulkFastPayloadHeaderLen { return errBulkFastPayloadInvalid } copy(dst[:4], bulkFastPayloadMagic) dst[4] = bulkFastPayloadVersion dst[5] = frameType dst[6] = flags dst[7] = 0 binary.BigEndian.PutUint64(dst[8:16], dataID) binary.BigEndian.PutUint64(dst[16:24], seq) binary.BigEndian.PutUint32(dst[24:28], uint32(payloadLen)) return nil } func encodeBulkFastDataFrameHeader(dst []byte, dataID uint64, seq uint64, payloadLen int) error { return encodeBulkFastFrameHeader(dst, bulkFastPayloadTypeData, 0, dataID, seq, payloadLen) } func encodeBulkFastDataFrame(dataID uint64, seq uint64, payload []byte) ([]byte, error) { frame := make([]byte, bulkFastPayloadHeaderLen+len(payload)) if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil { return nil, err } copy(frame[bulkFastPayloadHeaderLen:], payload) return frame, nil } func encodeBulkFastControlFrame(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) { frame := make([]byte, bulkFastPayloadHeaderLen+len(payload)) if err := encodeBulkFastFrameHeader(frame, frameType, flags, dataID, seq, len(payload)); err != nil { return nil, err } copy(frame[bulkFastPayloadHeaderLen:], payload) return frame, nil } func decodeBulkFastFrame(payload []byte) (bulkFastFrame, bool, error) { if len(payload) < 4 || string(payload[:4]) != bulkFastPayloadMagic { return bulkFastFrame{}, false, nil } if len(payload) < bulkFastPayloadHeaderLen { return bulkFastFrame{}, true, errBulkFastPayloadInvalid } if payload[4] != bulkFastPayloadVersion { return bulkFastFrame{}, true, errBulkFastPayloadInvalid } switch payload[5] { case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease: default: return bulkFastFrame{}, true, errBulkFastPayloadInvalid } dataLen := int(binary.BigEndian.Uint32(payload[24:28])) if dataLen < 0 || len(payload) != bulkFastPayloadHeaderLen+dataLen { return bulkFastFrame{}, true, errBulkFastPayloadInvalid } dataID := binary.BigEndian.Uint64(payload[8:16]) if dataID == 0 { return bulkFastFrame{}, true, errBulkFastPayloadInvalid } return bulkFastFrame{ Type: payload[5], Flags: payload[6], DataID: dataID, Seq: binary.BigEndian.Uint64(payload[16:24]), Payload: payload[bulkFastPayloadHeaderLen:], }, true, nil } func decodeBulkFastDataFrame(payload []byte) (bulkFastDataFrame, bool, error) { frame, matched, err := decodeBulkFastFrame(payload) if !matched || err != nil { return frame, matched, err } if frame.Type != bulkFastPayloadTypeData { return bulkFastDataFrame{}, false, nil } return frame, true, nil } func (c *ClientCommon) encodeFastBulkDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) { return c.encodeBulkFastPayload(bulkFastFrame{ Type: bulkFastPayloadTypeData, DataID: dataID, Seq: seq, Payload: chunk, }) } func (c *ClientCommon) encodeBulkFastPayload(frame bulkFastFrame) ([]byte, error) { if c == nil { return nil, errBulkClientNil } profile := c.clientTransportProtectionSnapshot() if profile.fastPlainEncode != nil { return encodeBulkFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame) } plain, err := encodeBulkFastFramePayload(frame) if err != nil { return nil, err } return c.encryptTransportPayload(plain) } func (c *ClientCommon) encodeBulkFastBatchPayload(frames []bulkFastFrame) ([]byte, error) { if c == nil { return nil, errBulkClientNil } profile := c.clientTransportProtectionSnapshot() if profile.fastPlainEncode != nil { return encodeBulkFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames) } plain, err := encodeBulkFastBatchPlain(frames) if err != nil { return nil, err } return c.encryptTransportPayload(plain) } func (c *ClientCommon) encodeBulkFastPayloadPooled(frame bulkFastFrame) ([]byte, func(), error) { if c == nil { return nil, nil, errBulkClientNil } profile := c.clientTransportProtectionSnapshot() if runtime := profile.runtime; runtime != nil { return encodeBulkFastFramePayloadPooled(runtime, frame) } if profile.fastPlainEncode != nil { payload, err := encodeBulkFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame) return payload, nil, err } plain, err := encodeBulkFastFramePayload(frame) if err != nil { return nil, nil, err } payload, err := c.encryptTransportPayload(plain) return payload, nil, err } func (c *ClientCommon) encodeBulkFastBatchPayloadPooled(frames []bulkFastFrame) ([]byte, func(), error) { if c == nil { return nil, nil, errBulkClientNil } profile := c.clientTransportProtectionSnapshot() if runtime := profile.runtime; runtime != nil { return encodeBulkFastBatchPayloadPooled(runtime, frames) } if profile.fastPlainEncode != nil { payload, err := encodeBulkFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames) return payload, nil, err } plain, err := encodeBulkFastBatchPlain(frames) if err != nil { return nil, nil, err } payload, err := c.encryptTransportPayload(plain) return payload, nil, err } func (c *ClientCommon) sendFastBulkData(ctx context.Context, dataID uint64, seq uint64, chunk []byte, fastPathVersion uint8) error { binding := c.clientTransportBindingSnapshot() if binding == nil { return net.ErrClosed } if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil { return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk) } payload, err := c.encodeFastBulkDataPayload(dataID, seq, chunk) if err != nil { return err } return c.writePayloadToTransport(payload) } func (c *ClientCommon) sendFastBulkWrite(ctx context.Context, dataID uint64, startSeq uint64, chunkSize int, fastPathVersion uint8, payload []byte, payloadOwned bool) (int, error) { if len(payload) == 0 { return 0, nil } binding := c.clientTransportBindingSnapshot() if binding == nil { return 0, net.ErrClosed } if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil { return sender.submitWrite(ctx, dataID, startSeq, fastPathVersion, payload, chunkSize, payloadOwned) } if chunkSize <= 0 { chunkSize = defaultBulkChunkSize } written := 0 seq := startSeq for written < len(payload) { end := written + chunkSize if end > len(payload) { end = len(payload) } if err := c.sendFastBulkData(ctx, dataID, seq, payload[written:end], fastPathVersion); err != nil { return written, err } seq++ written = end } return written, nil } func (c *ClientCommon) sendFastBulkControl(ctx context.Context, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error { frame := bulkFastFrame{ Type: frameType, Flags: flags, DataID: dataID, Seq: seq, Payload: payload, } binding := c.clientTransportBindingSnapshot() if binding == nil { return net.ErrClosed } if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil { return sender.submitControl(ctx, frameType, flags, dataID, seq, fastPathVersion, payload) } encoded, err := c.encodeBulkFastPayload(frame) if err != nil { return err } return c.writePayloadToTransport(encoded) } func (c *ClientCommon) encodeBulkFastControlPayload(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) { plain, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload) if err != nil { return nil, err } return c.encryptTransportPayload(plain) } func (s *ServerCommon) encodeFastBulkDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) { return s.encodeBulkFastPayloadLogical(logical, bulkFastFrame{ Type: bulkFastPayloadTypeData, DataID: dataID, Seq: seq, Payload: chunk, }) } func (s *ServerCommon) encodeBulkFastPayloadLogical(logical *LogicalConn, frame bulkFastFrame) ([]byte, error) { if logical == nil { return nil, errTransportDetached } if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { return encodeBulkFastFramePayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frame) } plain, err := encodeBulkFastFramePayload(frame) if err != nil { return nil, err } return s.encryptTransportPayloadLogical(logical, plain) } func (s *ServerCommon) encodeBulkFastBatchPayloadLogical(logical *LogicalConn, frames []bulkFastFrame) ([]byte, error) { if logical == nil { return nil, errTransportDetached } if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { return encodeBulkFastBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frames) } plain, err := encodeBulkFastBatchPlain(frames) if err != nil { return nil, err } return s.encryptTransportPayloadLogical(logical, plain) } func (s *ServerCommon) encodeBulkFastPayloadLogicalPooled(logical *LogicalConn, frame bulkFastFrame) ([]byte, func(), error) { if logical == nil { return nil, nil, errTransportDetached } if runtime := logical.modernPSKRuntimeSnapshot(); runtime != nil { return encodeBulkFastFramePayloadPooled(runtime, frame) } if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { payload, err := encodeBulkFastFramePayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frame) return payload, nil, err } plain, err := encodeBulkFastFramePayload(frame) if err != nil { return nil, nil, err } payload, err := s.encryptTransportPayloadLogical(logical, plain) return payload, nil, err } func (s *ServerCommon) encodeBulkFastBatchPayloadLogicalPooled(logical *LogicalConn, frames []bulkFastFrame) ([]byte, func(), error) { if logical == nil { return nil, nil, errTransportDetached } if runtime := logical.modernPSKRuntimeSnapshot(); runtime != nil { return encodeBulkFastBatchPayloadPooled(runtime, frames) } if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { payload, err := encodeBulkFastBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frames) return payload, nil, err } plain, err := encodeBulkFastBatchPlain(frames) if err != nil { return nil, nil, err } payload, err := s.encryptTransportPayloadLogical(logical, plain) return payload, nil, err } func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte, fastPathVersion uint8) error { if err := s.ensureServerTransportSendReady(transport); err != nil { return err } if logical == nil && transport != nil { logical = transport.logicalConnSnapshot() } if logical == nil { return errTransportDetached } if binding := logical.transportBindingSnapshot(); binding != nil { if binding.queueSnapshot() != nil { if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil { return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk) } } } payload, err := s.encodeFastBulkDataPayloadLogical(logical, dataID, seq, chunk) if err != nil { return err } return s.writeEnvelopePayload(logical, transport, nil, payload) } func (s *ServerCommon) sendFastBulkWriteTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, startSeq uint64, chunkSize int, fastPathVersion uint8, payload []byte, payloadOwned bool) (int, error) { if len(payload) == 0 { return 0, nil } if err := s.ensureServerTransportSendReady(transport); err != nil { return 0, err } if logical == nil && transport != nil { logical = transport.logicalConnSnapshot() } if logical == nil { return 0, errTransportDetached } if binding := logical.transportBindingSnapshot(); binding != nil { if binding.queueSnapshot() != nil { if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil { return sender.submitWrite(ctx, dataID, startSeq, fastPathVersion, payload, chunkSize, payloadOwned) } } } if chunkSize <= 0 { chunkSize = defaultBulkChunkSize } written := 0 seq := startSeq for written < len(payload) { end := written + chunkSize if end > len(payload) { end = len(payload) } if err := s.sendFastBulkDataTransport(ctx, logical, transport, dataID, seq, payload[written:end], fastPathVersion); err != nil { return written, err } seq++ written = end } return written, nil } func (s *ServerCommon) sendFastBulkControlTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error { if err := s.ensureServerTransportSendReady(transport); err != nil { return err } if logical == nil && transport != nil { logical = transport.logicalConnSnapshot() } if logical == nil { return errTransportDetached } if binding := logical.transportBindingSnapshot(); binding != nil { if binding.queueSnapshot() != nil { if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil { return sender.submitControl(ctx, frameType, flags, dataID, seq, fastPathVersion, payload) } } } encoded, err := s.encodeBulkFastPayloadLogical(logical, bulkFastFrame{ Type: frameType, Flags: flags, DataID: dataID, Seq: seq, Payload: payload, }) if err != nil { return err } return s.writeEnvelopePayload(logical, transport, nil, encoded) } func (s *ServerCommon) encodeBulkFastControlPayloadLogical(logical *LogicalConn, frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) { plain, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload) if err != nil { return nil, err } return s.encryptTransportPayloadLogical(logical, plain) } func getBulkFastFrameScratch(payloadLen int) []byte { need := bulkFastPayloadHeaderLen + payloadLen if buf, ok := bulkFastFrameScratchPool.Get().([]byte); ok && cap(buf) >= need { return buf[:need] } return make([]byte, need) } func putBulkFastFrameScratch(buf []byte) { if cap(buf) == 0 || cap(buf) > 4*1024*1024 { return } bulkFastFrameScratchPool.Put(buf[:0]) } func transportFastPayloadMagic(payload []byte) string { if len(payload) < 4 { return "" } return string(payload[:4]) } func (c *ClientCommon) decryptTransportPayloadPooled(payload []byte, release func()) ([]byte, func(), error) { profile := c.clientTransportProtectionSnapshot() return decryptTransportPayloadCodecPooled(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, payload, release) } func (s *ServerCommon) decryptTransportPayloadLogicalPooled(logical *LogicalConn, payload []byte, release func()) ([]byte, func(), error) { if logical == nil { if release != nil { release() } return nil, nil, errTransportDetached } return decryptTransportPayloadCodecPooled(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload, release) } func (c *ClientCommon) tryDispatchBorrowedTransportPlain(plain []byte, release func()) bool { switch transportFastPayloadMagic(plain) { case bulkFastPayloadMagic, bulkFastBatchMagic: owner := newBulkReadPayloadOwner(release) matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error { c.dispatchFastBulkFrameWithOwner(frame, owner) return nil }) if owner != nil { owner.done() } if !matched { walkErr = errBulkFastPayloadInvalid } if walkErr != nil && (c.showError || c.debugMode) { fmt.Println("client decode bulk fast payload error", walkErr) } return true case streamFastPayloadMagic, streamFastBatchMagic: owner := newStreamReadPayloadOwner(release) matched, walkErr := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error { c.dispatchFastStreamDataWithOwner(frame, owner) return nil }) if owner != nil { owner.done() } if !matched { walkErr = errStreamFastPayloadInvalid } if walkErr != nil && (c.showError || c.debugMode) { fmt.Println("client decode stream fast payload error", walkErr) } return true default: return false } } func (s *ServerCommon) tryDispatchBorrowedTransportPlain(logical *LogicalConn, transport *TransportConn, conn net.Conn, plain []byte, release func()) bool { switch transportFastPayloadMagic(plain) { case bulkFastPayloadMagic, bulkFastBatchMagic: owner := newBulkReadPayloadOwner(release) matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error { s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, owner) return nil }) if owner != nil { owner.done() } if !matched { walkErr = errBulkFastPayloadInvalid } if walkErr != nil && (s.showError || s.debugMode) { fmt.Println("server decode bulk fast payload error", walkErr) } return true case streamFastPayloadMagic, streamFastBatchMagic: owner := newStreamReadPayloadOwner(release) matched, walkErr := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error { s.dispatchFastStreamDataWithOwner(logical, transport, conn, frame, owner) return nil }) if owner != nil { owner.done() } if !matched { walkErr = errStreamFastPayloadInvalid } if walkErr != nil && (s.showError || s.debugMode) { fmt.Println("server decode stream fast payload error", walkErr) } return true default: return false } } func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error { plain, err := c.decryptTransportPayload(payload) if err != nil { return err } return c.dispatchInboundTransportPlain(plain, now) } func (c *ClientCommon) dispatchInboundTransportPlain(plain []byte, now time.Time) error { if matched, err := walkBulkFastFrames(plain, func(frame bulkFastFrame) error { c.dispatchFastBulkFrame(frame) return nil }); matched { return err } if matched, err := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error { c.dispatchFastStreamData(frame) return nil }); matched { return err } env, err := c.decodeEnvelopePlain(plain) if err != nil { return err } c.dispatchEnvelope(env, now) return nil } func (s *ServerCommon) dispatchInboundTransportPayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte, now time.Time) error { if logical == nil && transport != nil { logical = transport.logicalConnSnapshot() } if logical == nil { return errTransportDetached } plain, err := s.decryptTransportPayloadLogical(logical, payload) if err != nil { return err } return s.dispatchInboundTransportPlain(logical, transport, conn, plain, now) } func (s *ServerCommon) dispatchInboundTransportPlain(logical *LogicalConn, transport *TransportConn, conn net.Conn, plain []byte, now time.Time) error { if matched, err := walkBulkFastFrames(plain, func(frame bulkFastFrame) error { s.dispatchFastBulkFrame(logical, transport, conn, frame) return nil }); matched { return err } if matched, err := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error { s.dispatchFastStreamData(logical, transport, conn, frame) return nil }); matched { return err } env, err := s.decodeEnvelopePlain(plain) if err != nil { return err } s.dispatchEnvelope(logical, transport, conn, env, now) return nil }