package notify import ( "b612.me/notify/internal/transport" "b612.me/stario" "context" cryptorand "crypto/rand" "encoding/binary" "encoding/hex" "errors" "io" "net" "sync/atomic" "time" ) const ( systemBulkAttachKey = "_notify_bulk_attach" bulkDedicatedRecordMagic = "NBR1" bulkDedicatedRecordHeaderLen = 8 bulkDedicatedAttachTimeout = 5 * time.Second ) type bulkAttachRequest struct { PeerID string BulkID string AttachToken string } type bulkAttachResponse struct { Accepted bool Error string } func newBulkAttachToken() string { var buf [16]byte if _, err := cryptorand.Read(buf[:]); err == nil { return hex.EncodeToString(buf[:]) } return "" } func decodeBulkAttachRequest(decodeFn func([]byte) (interface{}, error), data MsgVal) (bulkAttachRequest, error) { var req bulkAttachRequest if decodeFn == nil { decodeFn = Decode } raw := []byte(data) value, err := decodeFn(raw) if err != nil { return req, err } switch typed := value.(type) { case bulkAttachRequest: return typed, nil case *bulkAttachRequest: if typed == nil { return req, errors.New("bulk attach request is nil") } return *typed, nil default: return req, errors.New("invalid bulk attach payload") } } func decodeBulkAttachResponse(decodeFn func([]byte) (interface{}, error), data MsgVal) (bulkAttachResponse, error) { var resp bulkAttachResponse if decodeFn == nil { decodeFn = Decode } raw := []byte(data) value, err := decodeFn(raw) if err != nil { return resp, err } switch typed := value.(type) { case bulkAttachResponse: return typed, nil case *bulkAttachResponse: if typed == nil { return resp, errors.New("bulk attach response is nil") } return *typed, nil default: return resp, errors.New("invalid bulk attach response") } } func encodeDirectSignalFrame(queue *stario.StarQueue, sequenceEn func(interface{}) ([]byte, error), msgEn func([]byte, []byte) []byte, secretKey []byte, msg TransferMsg) ([]byte, error) { if queue == nil { queue = stario.NewQueue() } env, err := wrapTransferMsgEnvelope(msg, sequenceEn) if err != nil { return nil, err } plain, err := sequenceEn(env) if err != nil { return nil, err } payload := msgEn(secretKey, plain) if payload == nil && len(plain) != 0 { return nil, errTransportPayloadEncryptFailed } return queue.BuildMessage(payload), nil } func decodeDirectSignalPayload(sequenceDe func([]byte) (interface{}, error), msgDe func([]byte, []byte) []byte, secretKey []byte, payload []byte) (TransferMsg, error) { plain := msgDe(secretKey, payload) if plain == nil && len(payload) != 0 { return TransferMsg{}, errTransportPayloadDecryptFailed } value, err := sequenceDe(plain) if err != nil { return TransferMsg{}, err } env, ok := value.(Envelope) if !ok { return TransferMsg{}, errors.New("invalid signal envelope") } return unwrapTransferMsgEnvelope(env, sequenceDe) } func writeBulkDedicatedRecord(conn net.Conn, payload []byte) error { return writeBulkDedicatedRecordWithDeadline(conn, payload, time.Time{}) } func writeBulkDedicatedRecordWithDeadline(conn net.Conn, payload []byte, deadline time.Time) error { if conn == nil { return net.ErrClosed } return withRawConnWriteLockDeadline(conn, deadline, func(conn net.Conn) error { var header [bulkDedicatedRecordHeaderLen]byte copy(header[:4], bulkDedicatedRecordMagic) binary.BigEndian.PutUint32(header[4:8], uint32(len(payload))) buffers := net.Buffers{header[:], payload} _, err := buffers.WriteTo(conn) return err }) } func readBulkDedicatedRecord(conn net.Conn) ([]byte, error) { if conn == nil { return nil, net.ErrClosed } var header [bulkDedicatedRecordHeaderLen]byte if _, err := io.ReadFull(conn, header[:]); err != nil { return nil, err } if string(header[:4]) != bulkDedicatedRecordMagic { return nil, errBulkFastPayloadInvalid } size := int(binary.BigEndian.Uint32(header[4:8])) if size < 0 { return nil, errBulkFastPayloadInvalid } payload := make([]byte, size) if _, err := io.ReadFull(conn, payload); err != nil { return nil, err } return payload, nil } func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context) (net.Conn, error) { source := c.clientConnectSourceSnapshot() if source != nil && source.canReconnect() { return source.dial(ctx) } conn := c.clientTransportConnSnapshot() if conn == nil || conn.RemoteAddr() == nil { return nil, errClientReconnectSourceUnavailable } return transport.Dial(conn.RemoteAddr().Network(), conn.RemoteAddr().String()) } func (c *ClientCommon) attachDedicatedBulkSidecar(ctx context.Context, bulk *bulkHandle) error { if c == nil || bulk == nil || !bulk.Dedicated() || bulk.dedicatedAttachedSnapshot() { return nil } if ctx == nil { ctx = context.Background() } ctx, cancel := context.WithTimeout(ctx, bulkDedicatedAttachTimeout) defer cancel() conn, err := c.dialDedicatedBulkConn(ctx) if err != nil { return err } resp, err := c.sendDedicatedBulkAttachRequest(ctx, conn, bulk) if err != nil { _ = conn.Close() return err } if !resp.Accepted { _ = conn.Close() if resp.Error != "" { return errors.New(resp.Error) } return errors.New("bulk attach rejected") } if err := bulk.attachDedicatedConn(conn); err != nil { _ = conn.Close() return err } go c.readDedicatedBulkLoop(bulk, conn) return nil } func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn net.Conn, bulk *bulkHandle) (bulkAttachResponse, error) { if c == nil { return bulkAttachResponse{}, errBulkClientNil } if bulk == nil { return bulkAttachResponse{}, errBulkIDEmpty } defer func() { _ = conn.SetReadDeadline(time.Time{}) }() reqPayload, err := c.sequenceEn(bulkAttachRequest{ PeerID: c.ensureClientPeerIdentity(), BulkID: bulk.ID(), AttachToken: bulk.dedicatedAttachTokenSnapshot(), }) if err != nil { return bulkAttachResponse{}, err } queue := stario.NewQueue() msg := TransferMsg{ ID: atomic.AddUint64(&c.msgID, 1), Key: systemBulkAttachKey, Value: reqPayload, Type: MSG_SYS_WAIT, } frame, err := encodeDirectSignalFrame(queue, c.sequenceEn, c.msgEn, c.SecretKey, msg) if err != nil { return bulkAttachResponse{}, err } if err := writeFullToConn(conn, frame); err != nil { return bulkAttachResponse{}, err } replyCh := make(chan Message, 1) readBuf := streamReadBuffer() for { if deadline, ok := ctx.Deadline(); ok { _ = conn.SetReadDeadline(deadline) } n, err := conn.Read(readBuf) if err != nil { return bulkAttachResponse{}, err } parseErr := queue.ParseMessageOwned(readBuf[:n], "bulk-attach", func(msgq stario.MsgQueue) error { transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, msgq.Msg) if err != nil { return err } replyCh <- Message{ ServerConn: c, TransferMsg: transfer, NetType: NET_CLIENT, } return nil }) if parseErr != nil { return bulkAttachResponse{}, parseErr } select { case reply := <-replyCh: return decodeBulkAttachResponse(c.sequenceDe, reply.Value) default: } } } func (c *ClientCommon) readDedicatedBulkLoop(bulk *bulkHandle, conn net.Conn) { for { payload, err := readBulkDedicatedRecord(conn) if err != nil { handleDedicatedBulkReadError(bulk, err) return } plain, err := c.decryptTransportPayload(payload) if err != nil { _ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error()) bulk.markReset(err) return } items, err := decodeDedicatedBulkInboundItems(bulk.dataIDSnapshot(), plain) if err != nil { _ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error()) bulk.markReset(err) return } for _, item := range items { if err := dispatchDedicatedBulkInboundItem(bulk, item); err != nil { if !errors.Is(err, io.EOF) { _ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error()) bulk.markReset(err) } return } if bulk.Context().Err() != nil { return } } } } func (s *ServerCommon) handleBulkAttachSystemMessage(message Message) bool { if message.Key != systemBulkAttachKey { return false } current := messageLogicalConnSnapshot(&message) resp := bulkAttachResponse{} var ( req bulkAttachRequest logical *LogicalConn bulk *bulkHandle err error ) req, err = decodeBulkAttachRequest(s.sequenceDe, message.Value) if err == nil { logical, bulk, err = s.resolveInboundDedicatedBulk(current, req) } if err != nil { resp.Error = err.Error() } else { resp.Accepted = true } if current != nil { _ = s.replyDedicatedBulkAttach(current, message, resp) } if err == nil { if attachErr := s.finishInboundDedicatedBulkAttach(current, logical, bulk); attachErr != nil { bulk.markReset(attachErr) } } return true } func (s *ServerCommon) resolveInboundDedicatedBulk(current *LogicalConn, req bulkAttachRequest) (*LogicalConn, *bulkHandle, error) { if s == nil { return nil, nil, errBulkServerNil } if current == nil { return nil, nil, errBulkLogicalConnNil } if req.PeerID == "" || req.BulkID == "" || req.AttachToken == "" { return nil, nil, errBulkIDEmpty } logical := s.GetLogicalConn(req.PeerID) if logical == nil { return nil, nil, errBulkLogicalConnNil } runtime := s.getBulkRuntime() if runtime == nil { return nil, nil, errBulkRuntimeNil } bulk, ok := runtime.lookup(serverFileScope(logical), req.BulkID) if !ok { return nil, nil, errBulkNotFound } if !bulk.Dedicated() { return nil, nil, errors.New("bulk is not dedicated") } if bulk.dedicatedAttachTokenSnapshot() != req.AttachToken { return nil, nil, errors.New("bulk attach token mismatch") } return logical, bulk, nil } func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, logical *LogicalConn, bulk *bulkHandle) error { if current == nil || logical == nil || bulk == nil { return errBulkLogicalConnNil } conn, err := current.detachTransportForTransfer() if err != nil { return err } if err := bulk.attachDedicatedConn(conn); err != nil { if conn != nil { _ = conn.Close() } return err } go s.readDedicatedBulkLoop(logical, bulk, conn) current.markSessionStopped("bulk dedicated attach", nil) s.removeLogical(current) return nil } func (s *ServerCommon) replyDedicatedBulkAttach(client *LogicalConn, message Message, resp bulkAttachResponse) error { if s == nil || client == nil { return errBulkServerNil } encoded, err := s.sequenceEn(resp) if err != nil { return err } reply := TransferMsg{ ID: message.ID, Key: systemBulkAttachKey, Value: encoded, Type: MSG_SYS_REPLY, } if message.inboundConn != nil { return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply) } _, err = s.sendLogical(client, reply) return err } func (s *ServerCommon) readDedicatedBulkLoop(logical *LogicalConn, bulk *bulkHandle, conn net.Conn) { for { payload, err := readBulkDedicatedRecord(conn) if err != nil { handleDedicatedBulkReadError(bulk, err) return } plain, err := s.decryptTransportPayloadLogical(logical, payload) if err != nil { _ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error()) bulk.markReset(err) return } items, err := decodeDedicatedBulkInboundItems(bulk.dataIDSnapshot(), plain) if err != nil { _ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error()) bulk.markReset(err) return } for _, item := range items { if err := dispatchDedicatedBulkInboundItem(bulk, item); err != nil { if !errors.Is(err, io.EOF) { _ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error()) bulk.markReset(err) } return } if bulk.Context().Err() != nil { return } } } } func handleDedicatedBulkReadError(bulk *bulkHandle, err error) { if bulk == nil { return } if bulk.Context().Err() != nil || bulk.remoteClosedSnapshot() { return } if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { if bulk.Dedicated() || bulk.localClosedSnapshot() { bulk.markRemoteClosed() return } } bulk.markReset(transportDetachedError("dedicated bulk read error", err)) } func (c *ClientCommon) dedicatedBulkSender(bulk *bulkHandle) (*bulkDedicatedSender, error) { if c == nil || bulk == nil { return nil, errBulkClientNil } if sender := bulk.dedicatedSenderSnapshot(); sender != nil { return sender, nil } conn := bulk.dedicatedConnSnapshot() if conn == nil { return nil, transportDetachedError("dedicated bulk sidecar not attached", nil) } sender := newBulkDedicatedSender(conn, bulk.dataIDSnapshot(), c.encryptTransportPayload, func(items []bulkDedicatedSendRequest) ([]byte, error) { return c.encodeDedicatedBulkBatchPayload(bulk.dataIDSnapshot(), items) }, func(err error) { bulk.markReset(err) }) actual := bulk.installDedicatedSender(sender) if actual != sender { sender.stop() } return actual, nil } func (c *ClientCommon) sendDedicatedBulkData(ctx context.Context, bulk *bulkHandle, chunk []byte) error { if c == nil || bulk == nil { return errBulkClientNil } sender, err := c.dedicatedBulkSender(bulk) if err != nil { return err } return sender.submitData(ctx, bulk.nextOutboundDataSeq(), chunk) } func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) { if c == nil || bulk == nil { return 0, errBulkClientNil } sender, err := c.dedicatedBulkSender(bulk) if err != nil { return 0, err } return sender.submitWrite(ctx, bulk.nextOutboundDataSeq(), payload, bulk.chunkSize) } func (c *ClientCommon) sendDedicatedBulkClose(ctx context.Context, bulk *bulkHandle, full bool) error { if c == nil || bulk == nil { return errBulkClientNil } sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout) if err != nil { return err } defer cancel() flags := uint8(0) if full { flags = bulkFastPayloadFlagFullClose } sender, err := c.dedicatedBulkSender(bulk) if err != nil { return err } return sender.submitControl(sendCtx, bulkFastPayloadTypeClose, flags, 0, nil) } func (c *ClientCommon) sendDedicatedBulkReset(ctx context.Context, bulk *bulkHandle, message string) error { if c == nil || bulk == nil { return errBulkClientNil } sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout) if err != nil { return err } defer cancel() sender, err := c.dedicatedBulkSender(bulk) if err != nil { return err } return sender.submitControl(sendCtx, bulkFastPayloadTypeReset, 0, 0, []byte(message)) } func (c *ClientCommon) sendDedicatedBulkRelease(ctx context.Context, bulk *bulkHandle, bytes int64, chunks int) error { if c == nil || bulk == nil { return errBulkClientNil } payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks) if err != nil { return err } if err := bulk.waitDedicatedReady(ctx); err != nil { return err } sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout) if err != nil { return err } defer cancel() frame, err := c.encodeDedicatedBulkBatchPayload(bulk.dataIDSnapshot(), []bulkDedicatedSendRequest{{ Type: bulkFastPayloadTypeRelease, Payload: payload, }}) if err != nil { return err } conn := bulk.dedicatedConnSnapshot() if conn == nil { return transportDetachedError("dedicated bulk sidecar not attached", nil) } deadline, _ := sendCtx.Deadline() return writeBulkDedicatedRecordWithDeadline(conn, frame, deadline) } func (s *ServerCommon) dedicatedBulkSender(logical *LogicalConn, bulk *bulkHandle) (*bulkDedicatedSender, error) { if s == nil || bulk == nil { return nil, errBulkServerNil } if logical == nil { logical = bulk.LogicalConn() } if logical == nil { return nil, errBulkLogicalConnNil } if sender := bulk.dedicatedSenderSnapshot(); sender != nil { return sender, nil } conn := bulk.dedicatedConnSnapshot() if conn == nil { return nil, transportDetachedError("dedicated bulk sidecar not attached", nil) } sender := newBulkDedicatedSender(conn, bulk.dataIDSnapshot(), func(plain []byte) ([]byte, error) { return s.encryptTransportPayloadLogical(logical, plain) }, func(items []bulkDedicatedSendRequest) ([]byte, error) { return s.encodeDedicatedBulkBatchPayload(logical, bulk.dataIDSnapshot(), items) }, func(err error) { bulk.markReset(err) }) actual := bulk.installDedicatedSender(sender) if actual != sender { sender.stop() } return actual, nil } func (s *ServerCommon) sendDedicatedBulkData(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, chunk []byte) error { if s == nil || bulk == nil { return errBulkServerNil } sender, err := s.dedicatedBulkSender(logical, bulk) if err != nil { return err } return sender.submitData(ctx, bulk.nextOutboundDataSeq(), chunk) } func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, payload []byte) (int, error) { if s == nil || bulk == nil { return 0, errBulkServerNil } sender, err := s.dedicatedBulkSender(logical, bulk) if err != nil { return 0, err } return sender.submitWrite(ctx, bulk.nextOutboundDataSeq(), payload, bulk.chunkSize) } func (s *ServerCommon) sendDedicatedBulkClose(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, full bool) error { if s == nil || bulk == nil { return errBulkServerNil } sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout) if err != nil { return err } defer cancel() flags := uint8(0) if full { flags = bulkFastPayloadFlagFullClose } sender, err := s.dedicatedBulkSender(logical, bulk) if err != nil { return err } return sender.submitControl(sendCtx, bulkFastPayloadTypeClose, flags, 0, nil) } func (s *ServerCommon) sendDedicatedBulkReset(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, message string) error { if s == nil || bulk == nil { return errBulkServerNil } sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout) if err != nil { return err } defer cancel() sender, err := s.dedicatedBulkSender(logical, bulk) if err != nil { return err } return sender.submitControl(sendCtx, bulkFastPayloadTypeReset, 0, 0, []byte(message)) } func (s *ServerCommon) sendDedicatedBulkRelease(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, bytes int64, chunks int) error { if s == nil || bulk == nil { return errBulkServerNil } payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks) if err != nil { return err } if err := bulk.waitDedicatedReady(ctx); err != nil { return err } sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout) if err != nil { return err } defer cancel() frame, err := s.encodeDedicatedBulkBatchPayload(logical, bulk.dataIDSnapshot(), []bulkDedicatedSendRequest{{ Type: bulkFastPayloadTypeRelease, Payload: payload, }}) if err != nil { return err } conn := bulk.dedicatedConnSnapshot() if conn == nil { return transportDetachedError("dedicated bulk sidecar not attached", nil) } deadline, _ := sendCtx.Deadline() return writeBulkDedicatedRecordWithDeadline(conn, frame, deadline) } func (c *ClientCommon) encodeDedicatedBulkBatchPayload(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) { if c == nil { return nil, errBulkClientNil } if c.fastPlainEncode != nil { return encodeBulkDedicatedBatchPayloadFast(c.fastPlainEncode, c.SecretKey, dataID, items) } plain, err := encodeBulkDedicatedBatchPlain(dataID, items) if err != nil { return nil, err } return c.encryptTransportPayload(plain) } func (s *ServerCommon) encodeDedicatedBulkBatchPayload(logical *LogicalConn, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) { if s == nil { return nil, errBulkServerNil } if logical == nil { return nil, errBulkLogicalConnNil } if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { return encodeBulkDedicatedBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), dataID, items) } plain, err := encodeBulkDedicatedBatchPlain(dataID, items) if err != nil { return nil, err } return s.encryptTransportPayloadLogical(logical, plain) }