notify/bulk_dedicated.go

724 lines
20 KiB
Go
Raw Permalink Normal View History

package notify
import (
	"context"
	cryptorand "crypto/rand"
	"crypto/subtle"
	"encoding/binary"
	"encoding/hex"
	"errors"
	"io"
	"net"
	"sync/atomic"
	"time"

	"b612.me/notify/internal/transport"
	"b612.me/stario"
)
const (
	// systemBulkAttachKey is the TransferMsg key of the attach handshake exchanged on a
	// freshly dialed dedicated bulk connection (request and reply use the same key).
	systemBulkAttachKey = "_notify_bulk_attach"

	// bulkDedicatedRecordMagic is the 4-byte prefix of every record on a dedicated bulk connection.
	bulkDedicatedRecordMagic = "NBR1"
	// bulkDedicatedRecordHeaderLen is the record header size: the 4 magic bytes followed by a
	// 4-byte big-endian payload length.
	bulkDedicatedRecordHeaderLen = 8

	// bulkDedicatedAttachTimeout bounds the client-side dial plus attach handshake.
	bulkDedicatedAttachTimeout = 5 * time.Second
)
// bulkAttachRequest is the payload of the systemBulkAttachKey message that the client sends
// as the first message on a newly dialed dedicated bulk connection.
//
// NOTE(review): values are serialized with sequenceEn/sequenceDe, so the field names and
// order are likely part of the wire format; keep them stable.
type bulkAttachRequest struct {
// PeerID identifies the client; the server looks up the owning logical connection with it.
PeerID string
// BulkID identifies the bulk to attach within that peer's scope.
BulkID string
// AttachToken must equal the bulk's dedicated attach token for the server to accept.
AttachToken string
}
// bulkAttachResponse is the server's reply to a bulkAttachRequest.
//
// NOTE(review): serialized with sequenceEn/sequenceDe; field names and order are likely part
// of the wire format, keep them stable.
type bulkAttachResponse struct {
// Accepted is true when the server bound the connection to the requested bulk.
Accepted bool
// Error carries the rejection reason text when Accepted is false (may be empty).
Error string
}
// newBulkAttachToken returns 16 bytes from crypto/rand encoded as 32 lowercase hex
// characters. If the system random source fails it returns "", which the server rejects
// as an empty attach token.
func newBulkAttachToken() string {
	raw := make([]byte, 16)
	if _, err := cryptorand.Read(raw); err != nil {
		return ""
	}
	return hex.EncodeToString(raw)
}
// decodeBulkAttachRequest decodes a bulkAttachRequest from a message payload.
// decodeFn defaults to Decode when nil. Both a value and a non-nil pointer result are
// accepted; a nil pointer or any other decoded type is reported as an error.
func decodeBulkAttachRequest(decodeFn func([]byte) (interface{}, error), data MsgVal) (bulkAttachRequest, error) {
	decode := decodeFn
	if decode == nil {
		decode = Decode
	}
	decoded, err := decode([]byte(data))
	if err != nil {
		return bulkAttachRequest{}, err
	}
	if req, ok := decoded.(bulkAttachRequest); ok {
		return req, nil
	}
	ptr, ok := decoded.(*bulkAttachRequest)
	if !ok {
		return bulkAttachRequest{}, errors.New("invalid bulk attach payload")
	}
	if ptr == nil {
		return bulkAttachRequest{}, errors.New("bulk attach request is nil")
	}
	return *ptr, nil
}
// decodeBulkAttachResponse decodes a bulkAttachResponse from a reply payload.
// decodeFn defaults to Decode when nil. Both a value and a non-nil pointer result are
// accepted; a nil pointer or any other decoded type is reported as an error.
func decodeBulkAttachResponse(decodeFn func([]byte) (interface{}, error), data MsgVal) (bulkAttachResponse, error) {
	decode := decodeFn
	if decode == nil {
		decode = Decode
	}
	decoded, err := decode([]byte(data))
	if err != nil {
		return bulkAttachResponse{}, err
	}
	if resp, ok := decoded.(bulkAttachResponse); ok {
		return resp, nil
	}
	ptr, ok := decoded.(*bulkAttachResponse)
	if !ok {
		return bulkAttachResponse{}, errors.New("invalid bulk attach response")
	}
	if ptr == nil {
		return bulkAttachResponse{}, errors.New("bulk attach response is nil")
	}
	return *ptr, nil
}
// encodeDirectSignalFrame turns msg into a single framed message that can be written
// directly to a raw connection. The steps are:
//  1. wrap msg in an envelope;
//  2. serialize the envelope with sequenceEn;
//  3. encrypt the result with msgEn under secretKey;
//  4. frame it with queue.
//
// A fresh stario queue is used when queue is nil.
func encodeDirectSignalFrame(queue *stario.StarQueue, sequenceEn func(interface{}) ([]byte, error), msgEn func([]byte, []byte) []byte, secretKey []byte, msg TransferMsg) ([]byte, error) {
	framer := queue
	if framer == nil {
		framer = stario.NewQueue()
	}
	env, err := wrapTransferMsgEnvelope(msg, sequenceEn)
	if err != nil {
		return nil, err
	}
	plain, err := sequenceEn(env)
	if err != nil {
		return nil, err
	}
	sealed := msgEn(secretKey, plain)
	// A nil result only counts as an encryption failure when there was something to encrypt.
	if len(plain) != 0 && sealed == nil {
		return nil, errTransportPayloadEncryptFailed
	}
	return framer.BuildMessage(sealed), nil
}
// decodeDirectSignalPayload reverses encodeDirectSignalFrame for one already-deframed
// payload. It decrypts with msgDe under secretKey, deserializes with sequenceDe, and
// unwraps the resulting Envelope back into a TransferMsg.
func decodeDirectSignalPayload(sequenceDe func([]byte) (interface{}, error), msgDe func([]byte, []byte) []byte, secretKey []byte, payload []byte) (TransferMsg, error) {
	plain := msgDe(secretKey, payload)
	// A nil result only counts as a decryption failure for a non-empty payload.
	if len(payload) != 0 && plain == nil {
		return TransferMsg{}, errTransportPayloadDecryptFailed
	}
	decoded, err := sequenceDe(plain)
	if err != nil {
		return TransferMsg{}, err
	}
	if env, ok := decoded.(Envelope); ok {
		return unwrapTransferMsgEnvelope(env, sequenceDe)
	}
	return TransferMsg{}, errors.New("invalid signal envelope")
}
// writeBulkDedicatedRecord writes one framed record to conn, passing the zero time as the
// write deadline.
func writeBulkDedicatedRecord(conn net.Conn, payload []byte) error {
	var noDeadline time.Time
	return writeBulkDedicatedRecordWithDeadline(conn, payload, noDeadline)
}
// writeBulkDedicatedRecordWithDeadline writes one dedicated-bulk record to conn. The record
// consists of bulkDedicatedRecordMagic, the payload length as a big-endian uint32, and the
// payload itself. The write runs inside withRawConnWriteLockDeadline with the given deadline.
// Header and payload are handed to a single net.Buffers write.
//
// It returns net.ErrClosed for a nil conn. It returns an error, without writing anything, if
// the payload is too long for the 32-bit length field.
func writeBulkDedicatedRecordWithDeadline(conn net.Conn, payload []byte, deadline time.Time) error {
	if conn == nil {
		return net.ErrClosed
	}
	// The length field is 32 bits wide. Without this check a payload of 4 GiB or more would
	// have its length silently truncated by the uint32 conversion below, and the reader would
	// then desynchronize on everything that follows.
	const maxRecordPayload = 1<<32 - 1
	if uint64(len(payload)) > maxRecordPayload {
		return errors.New("bulk dedicated record payload too large")
	}
	return withRawConnWriteLockDeadline(conn, deadline, func(conn net.Conn) error {
		var header [bulkDedicatedRecordHeaderLen]byte
		copy(header[:4], bulkDedicatedRecordMagic)
		binary.BigEndian.PutUint32(header[4:8], uint32(len(payload)))
		buffers := net.Buffers{header[:], payload}
		_, err := buffers.WriteTo(conn)
		return err
	})
}
// readBulkDedicatedRecord reads one record written by writeBulkDedicatedRecordWithDeadline.
// It reads the 8-byte header, checks the magic, and reads exactly the declared number of
// payload bytes.
//
// It returns:
//   - net.ErrClosed for a nil conn;
//   - errBulkFastPayloadInvalid for a bad magic or an unrepresentable length;
//   - io.EOF only when the stream ends cleanly before a new record starts. A stream that
//     ends anywhere inside a record (header or payload) yields io.ErrUnexpectedEOF, so
//     callers such as handleDedicatedBulkReadError do not mistake truncation for a clean
//     close.
func readBulkDedicatedRecord(conn net.Conn) ([]byte, error) {
	if conn == nil {
		return nil, net.ErrClosed
	}
	var header [bulkDedicatedRecordHeaderLen]byte
	if _, err := io.ReadFull(conn, header[:]); err != nil {
		return nil, err
	}
	if string(header[:4]) != bulkDedicatedRecordMagic {
		return nil, errBulkFastPayloadInvalid
	}
	size := int(binary.BigEndian.Uint32(header[4:8]))
	// Only reachable where int is 32 bits and the length exceeds math.MaxInt32.
	if size < 0 {
		return nil, errBulkFastPayloadInvalid
	}
	payload := make([]byte, size)
	if _, err := io.ReadFull(conn, payload); err != nil {
		// io.ReadFull reports io.EOF when zero payload bytes arrived. A header has already
		// been consumed at this point, so the record is truncated, not cleanly ended.
		if errors.Is(err, io.EOF) {
			return nil, io.ErrUnexpectedEOF
		}
		return nil, err
	}
	return payload, nil
}
// dialDedicatedBulkConn opens a new raw connection to the server for a dedicated bulk sidecar.
// When the client's connect source can reconnect, that source dials (receiving ctx).
// Otherwise the function dials the network/address of the current transport connection's
// remote end. It returns errClientReconnectSourceUnavailable when neither option exists.
//
// NOTE(review): the fallback transport.Dial call does not receive ctx, so the caller's attach
// timeout is not applied to that dial.
func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context) (net.Conn, error) {
	if source := c.clientConnectSourceSnapshot(); source != nil && source.canReconnect() {
		return source.dial(ctx)
	}
	current := c.clientTransportConnSnapshot()
	if current == nil {
		return nil, errClientReconnectSourceUnavailable
	}
	remote := current.RemoteAddr()
	if remote == nil {
		return nil, errClientReconnectSourceUnavailable
	}
	return transport.Dial(remote.Network(), remote.String())
}
// attachDedicatedBulkSidecar gives a dedicated bulk its own connection. It dials a new
// connection, performs the attach handshake on it, binds it to bulk, and starts the client
// read loop in a new goroutine.
//
// It is a no-op (returning nil) for a nil receiver or bulk, a non-dedicated bulk, or a bulk
// that is already attached. The dial and the handshake together are bounded by
// bulkDedicatedAttachTimeout. Any failure after a successful dial closes the new connection.
// A rejection by the server is returned as an error carrying the server's message, or
// "bulk attach rejected" when the server gave none.
func (c *ClientCommon) attachDedicatedBulkSidecar(ctx context.Context, bulk *bulkHandle) error {
	if c == nil || bulk == nil || !bulk.Dedicated() || bulk.dedicatedAttachedSnapshot() {
		return nil
	}
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	attachCtx, cancel := context.WithTimeout(parent, bulkDedicatedAttachTimeout)
	defer cancel()

	conn, err := c.dialDedicatedBulkConn(attachCtx)
	if err != nil {
		return err
	}
	closeWith := func(cause error) error {
		_ = conn.Close()
		return cause
	}

	resp, err := c.sendDedicatedBulkAttachRequest(attachCtx, conn, bulk)
	switch {
	case err != nil:
		return closeWith(err)
	case !resp.Accepted && resp.Error != "":
		return closeWith(errors.New(resp.Error))
	case !resp.Accepted:
		return closeWith(errors.New("bulk attach rejected"))
	}
	if err := bulk.attachDedicatedConn(conn); err != nil {
		return closeWith(err)
	}
	go c.readDedicatedBulkLoop(bulk, conn)
	return nil
}
// sendDedicatedBulkAttachRequest performs the client side of the attach handshake on a
// freshly dialed sidecar connection. It writes a MSG_SYS_WAIT TransferMsg keyed
// systemBulkAttachKey carrying the peer identity, bulk ID and attach token. It then reads
// from conn until a frame is decoded, and returns that frame's value decoded as a
// bulkAttachResponse.
//
// Reads are bounded by ctx's deadline, if it has one. The read deadline is cleared on return
// so the dedicated read loop that later owns conn is not affected by it.
//
// NOTE(review): bytes that arrive in the same conn.Read as the reply (for example a first
// dedicated record written immediately after the server accepts) remain in the local stario
// queue and are discarded when this function returns. Confirm that the server cannot write
// records before the client has observed the reply.
func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn net.Conn, bulk *bulkHandle) (bulkAttachResponse, error) {
	if c == nil {
		return bulkAttachResponse{}, errBulkClientNil
	}
	if bulk == nil {
		return bulkAttachResponse{}, errBulkIDEmpty
	}
	if conn == nil {
		return bulkAttachResponse{}, net.ErrClosed
	}
	if ctx == nil {
		ctx = context.Background()
	}
	defer func() {
		_ = conn.SetReadDeadline(time.Time{})
	}()
	reqPayload, err := c.sequenceEn(bulkAttachRequest{
		PeerID:      c.ensureClientPeerIdentity(),
		BulkID:      bulk.ID(),
		AttachToken: bulk.dedicatedAttachTokenSnapshot(),
	})
	if err != nil {
		return bulkAttachResponse{}, err
	}
	queue := stario.NewQueue()
	msg := TransferMsg{
		ID:    atomic.AddUint64(&c.msgID, 1),
		Key:   systemBulkAttachKey,
		Value: reqPayload,
		Type:  MSG_SYS_WAIT,
	}
	frame, err := encodeDirectSignalFrame(queue, c.sequenceEn, c.msgEn, c.SecretKey, msg)
	if err != nil {
		return bulkAttachResponse{}, err
	}
	if err := writeFullToConn(conn, frame); err != nil {
		return bulkAttachResponse{}, err
	}
	// The deadline does not change between reads, so apply it once.
	if deadline, ok := ctx.Deadline(); ok {
		_ = conn.SetReadDeadline(deadline)
	}
	replyCh := make(chan Message, 1)
	readBuf := streamReadBuffer()
	for {
		// A deadline-less ctx cannot interrupt a blocked Read, but stop between reads once it is done.
		if err := ctx.Err(); err != nil {
			return bulkAttachResponse{}, err
		}
		n, err := conn.Read(readBuf)
		if err != nil {
			return bulkAttachResponse{}, err
		}
		parseErr := queue.ParseMessageOwned(readBuf[:n], "bulk-attach", func(msgq stario.MsgQueue) error {
			transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, msgq.Msg)
			if err != nil {
				return err
			}
			// Keep only the first decoded frame. replyCh is drained only after
			// ParseMessageOwned returns, so a blocking send for a second frame from the same
			// read would block forever.
			select {
			case replyCh <- Message{
				ServerConn:  c,
				TransferMsg: transfer,
				NetType:     NET_CLIENT,
			}:
			default:
			}
			return nil
		})
		if parseErr != nil {
			return bulkAttachResponse{}, parseErr
		}
		select {
		case reply := <-replyCh:
			return decodeBulkAttachResponse(c.sequenceDe, reply.Value)
		default:
		}
	}
}
// readDedicatedBulkLoop consumes records from bulk's dedicated connection on the client side.
// Each record is decrypted, decoded into inbound items, and the items are dispatched in order.
// The loop stops when:
//   - a read fails; the error is classified by handleDedicatedBulkReadError;
//   - decryption, decoding or dispatch fails; a best-effort reset is sent to the peer and the
//     bulk is marked reset;
//   - dispatch returns io.EOF, which ends the loop quietly;
//   - the bulk's context is done after an item is dispatched.
func (c *ClientCommon) readDedicatedBulkLoop(bulk *bulkHandle, conn net.Conn) {
	abort := func(cause error) {
		// The reset notification is best effort; the bulk is reset locally either way.
		_ = c.sendDedicatedBulkReset(context.Background(), bulk, cause.Error())
		bulk.markReset(cause)
	}
	for {
		record, err := readBulkDedicatedRecord(conn)
		if err != nil {
			handleDedicatedBulkReadError(bulk, err)
			return
		}
		plain, err := c.decryptTransportPayload(record)
		if err != nil {
			abort(err)
			return
		}
		items, err := decodeDedicatedBulkInboundItems(bulk.dataIDSnapshot(), plain)
		if err != nil {
			abort(err)
			return
		}
		for _, item := range items {
			dispatchErr := dispatchDedicatedBulkInboundItem(bulk, item)
			if dispatchErr != nil {
				if !errors.Is(dispatchErr, io.EOF) {
					abort(dispatchErr)
				}
				return
			}
			if bulk.Context().Err() != nil {
				return
			}
		}
	}
}
// handleBulkAttachSystemMessage handles the attach handshake that a client sends as the first
// message on a dedicated sidecar connection. It returns false when message is not an attach
// message, and true otherwise: the message is consumed whether the attach succeeds or not.
//
// The request is validated by resolveInboundDedicatedBulk and the outcome is replied on the
// inbound connection. Only after the reply was sent successfully is the connection handed to
// the bulk by finishInboundDedicatedBulkAttach; a failure there resets the bulk.
func (s *ServerCommon) handleBulkAttachSystemMessage(message Message) bool {
	if message.Key != systemBulkAttachKey {
		return false
	}
	current := messageLogicalConnSnapshot(&message)
	var (
		logical *LogicalConn
		bulk    *bulkHandle
	)
	req, err := decodeBulkAttachRequest(s.sequenceDe, message.Value)
	if err == nil {
		logical, bulk, err = s.resolveInboundDedicatedBulk(current, req)
	}
	resp := bulkAttachResponse{Accepted: err == nil}
	if err != nil {
		resp.Error = err.Error()
	}
	if current == nil {
		// resolveInboundDedicatedBulk rejects a nil current, so there is no one to reply to
		// and nothing to attach.
		return true
	}
	if replyErr := s.replyDedicatedBulkAttach(current, message, resp); replyErr != nil {
		// The client did not learn that the attach was accepted and will abandon this
		// connection. Binding it to the bulk anyway would let the read loop observe the
		// close and mark the bulk as cleanly remote-closed, so leave the bulk unattached.
		return true
	}
	if err == nil {
		if attachErr := s.finishInboundDedicatedBulkAttach(current, logical, bulk); attachErr != nil {
			bulk.markReset(attachErr)
		}
	}
	return true
}
// resolveInboundDedicatedBulk validates an attach request that arrived on current and finds
// the dedicated bulk it targets. It returns the logical connection that owns the bulk (looked
// up by req.PeerID, not current itself) together with the bulk.
//
// Errors:
//   - errBulkServerNil for a nil receiver;
//   - errBulkLogicalConnNil when current is nil or the peer is unknown;
//   - errBulkIDEmpty when any request field is empty;
//   - errBulkRuntimeNil or errBulkNotFound when the bulk cannot be located;
//   - a plain error when the bulk is not dedicated or the attach token does not match.
func (s *ServerCommon) resolveInboundDedicatedBulk(current *LogicalConn, req bulkAttachRequest) (*LogicalConn, *bulkHandle, error) {
	if s == nil {
		return nil, nil, errBulkServerNil
	}
	if current == nil {
		return nil, nil, errBulkLogicalConnNil
	}
	if req.PeerID == "" || req.BulkID == "" || req.AttachToken == "" {
		return nil, nil, errBulkIDEmpty
	}
	logical := s.GetLogicalConn(req.PeerID)
	if logical == nil {
		return nil, nil, errBulkLogicalConnNil
	}
	runtime := s.getBulkRuntime()
	if runtime == nil {
		return nil, nil, errBulkRuntimeNil
	}
	bulk, ok := runtime.lookup(serverFileScope(logical), req.BulkID)
	if !ok {
		return nil, nil, errBulkNotFound
	}
	if !bulk.Dedicated() {
		return nil, nil, errors.New("bulk is not dedicated")
	}
	// The token is the secret that authorizes this raw connection to take over the bulk.
	// Compare it in constant time so that response timing does not leak how much of it
	// matched.
	expected := []byte(bulk.dedicatedAttachTokenSnapshot())
	if subtle.ConstantTimeCompare(expected, []byte(req.AttachToken)) != 1 {
		return nil, nil, errors.New("bulk attach token mismatch")
	}
	return logical, bulk, nil
}
// finishInboundDedicatedBulkAttach moves current's transport connection onto bulk, so that
// the connection becomes the bulk's dedicated connection. current is the logical connection
// the attach request arrived on.
//
// On success it:
//  1. starts the server read loop for the connection, serving logical;
//  2. marks current's session stopped;
//  3. removes current from the server.
//
// If binding fails, the detached connection is closed. It returns errBulkLogicalConnNil if
// any argument is nil.
func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, logical *LogicalConn, bulk *bulkHandle) error {
	if current == nil || logical == nil || bulk == nil {
		return errBulkLogicalConnNil
	}
	detached, err := current.detachTransportForTransfer()
	if err != nil {
		return err
	}
	if attachErr := bulk.attachDedicatedConn(detached); attachErr != nil {
		if detached != nil {
			_ = detached.Close()
		}
		return attachErr
	}
	go s.readDedicatedBulkLoop(logical, bulk, detached)
	current.markSessionStopped("bulk dedicated attach", nil)
	s.removeLogical(current)
	return nil
}
// replyDedicatedBulkAttach sends resp back to client as a MSG_SYS_REPLY keyed
// systemBulkAttachKey, reusing the request's message ID. When the request came in on a
// specific inbound connection, the reply goes out on that same connection. Otherwise it is
// sent through the logical connection. It returns errBulkServerNil if the receiver or
// client is nil.
func (s *ServerCommon) replyDedicatedBulkAttach(client *LogicalConn, message Message, resp bulkAttachResponse) error {
	if s == nil || client == nil {
		return errBulkServerNil
	}
	body, err := s.sequenceEn(resp)
	if err != nil {
		return err
	}
	reply := TransferMsg{
		ID:    message.ID,
		Key:   systemBulkAttachKey,
		Value: body,
		Type:  MSG_SYS_REPLY,
	}
	if message.inboundConn == nil {
		_, sendErr := s.sendLogical(client, reply)
		return sendErr
	}
	return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply)
}
// readDedicatedBulkLoop consumes records from bulk's dedicated connection on the server side,
// on behalf of logical. Each record is decrypted with logical's transport settings, decoded
// into inbound items, and the items are dispatched in order. The loop stops when:
//   - a read fails; the error is classified by handleDedicatedBulkReadError;
//   - decryption, decoding or dispatch fails; a best-effort reset is sent to the peer and the
//     bulk is marked reset;
//   - dispatch returns io.EOF, which ends the loop quietly;
//   - the bulk's context is done after an item is dispatched.
func (s *ServerCommon) readDedicatedBulkLoop(logical *LogicalConn, bulk *bulkHandle, conn net.Conn) {
	abort := func(cause error) {
		// The reset notification is best effort; the bulk is reset locally either way.
		_ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, cause.Error())
		bulk.markReset(cause)
	}
	for {
		record, err := readBulkDedicatedRecord(conn)
		if err != nil {
			handleDedicatedBulkReadError(bulk, err)
			return
		}
		plain, err := s.decryptTransportPayloadLogical(logical, record)
		if err != nil {
			abort(err)
			return
		}
		items, err := decodeDedicatedBulkInboundItems(bulk.dataIDSnapshot(), plain)
		if err != nil {
			abort(err)
			return
		}
		for _, item := range items {
			dispatchErr := dispatchDedicatedBulkInboundItem(bulk, item)
			if dispatchErr != nil {
				if !errors.Is(dispatchErr, io.EOF) {
					abort(dispatchErr)
				}
				return
			}
			if bulk.Context().Err() != nil {
				return
			}
		}
	}
}
// handleDedicatedBulkReadError classifies a read failure on a dedicated bulk connection.
// It does nothing if bulk is nil, if the bulk's context is already done, or if the remote
// side has already closed it.
//
// io.EOF or net.ErrClosed counts as a clean remote close when the bulk is dedicated or
// already closed locally. Every other case resets the bulk with a transport-detached error
// that wraps err.
func handleDedicatedBulkReadError(bulk *bulkHandle, err error) {
	switch {
	case bulk == nil:
		return
	case bulk.Context().Err() != nil, bulk.remoteClosedSnapshot():
		return
	}
	cleanEnd := errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)
	if cleanEnd && (bulk.Dedicated() || bulk.localClosedSnapshot()) {
		bulk.markRemoteClosed()
		return
	}
	bulk.markReset(transportDetachedError("dedicated bulk read error", err))
}
// dedicatedBulkSender returns the client's batching sender for bulk's dedicated connection.
// An already installed sender is returned directly. Otherwise a new one is built over the
// dedicated connection and offered to installDedicatedSender. If installDedicatedSender
// hands back a different sender, that one is returned and the newly built sender is
// stopped. Sender failures mark the bulk reset.
func (c *ClientCommon) dedicatedBulkSender(bulk *bulkHandle) (*bulkDedicatedSender, error) {
	if c == nil || bulk == nil {
		return nil, errBulkClientNil
	}
	if existing := bulk.dedicatedSenderSnapshot(); existing != nil {
		return existing, nil
	}
	conn := bulk.dedicatedConnSnapshot()
	if conn == nil {
		return nil, transportDetachedError("dedicated bulk sidecar not attached", nil)
	}
	encodeBatch := func(items []bulkDedicatedSendRequest) ([]byte, error) {
		return c.encodeDedicatedBulkBatchPayload(bulk.dataIDSnapshot(), items)
	}
	onFailure := func(err error) {
		bulk.markReset(err)
	}
	fresh := newBulkDedicatedSender(conn, bulk.dataIDSnapshot(), c.encryptTransportPayload, encodeBatch, onFailure)
	installed := bulk.installDedicatedSender(fresh)
	if installed != fresh {
		fresh.stop()
	}
	return installed, nil
}
// sendDedicatedBulkData submits one data chunk to bulk's dedicated sender, tagged with the
// next outbound data sequence number. The sequence number is taken only after a sender has
// been obtained.
func (c *ClientCommon) sendDedicatedBulkData(ctx context.Context, bulk *bulkHandle, chunk []byte) error {
	if c == nil || bulk == nil {
		return errBulkClientNil
	}
	sender, err := c.dedicatedBulkSender(bulk)
	if err != nil {
		return err
	}
	seq := bulk.nextOutboundDataSeq()
	return sender.submitData(ctx, seq, chunk)
}
// sendDedicatedBulkWrite submits payload to bulk's dedicated sender as a write tagged with the
// next outbound data sequence number, using bulk.chunkSize as the chunk size. It returns
// whatever count and error the sender's submitWrite reports.
func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) {
	if c == nil || bulk == nil {
		return 0, errBulkClientNil
	}
	sender, err := c.dedicatedBulkSender(bulk)
	if err != nil {
		return 0, err
	}
	seq := bulk.nextOutboundDataSeq()
	return sender.submitWrite(ctx, seq, payload, bulk.chunkSize)
}
// sendDedicatedBulkClose submits a close control frame through bulk's dedicated sender. When
// full is true, the frame carries bulkFastPayloadFlagFullClose. The submit uses a context
// derived from ctx and bulk.writeTimeout.
func (c *ClientCommon) sendDedicatedBulkClose(ctx context.Context, bulk *bulkHandle, full bool) error {
	if c == nil || bulk == nil {
		return errBulkClientNil
	}
	sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
	if err != nil {
		return err
	}
	defer cancel()
	var flags uint8
	if full {
		flags = bulkFastPayloadFlagFullClose
	}
	sender, err := c.dedicatedBulkSender(bulk)
	if err != nil {
		return err
	}
	return sender.submitControl(sendCtx, bulkFastPayloadTypeClose, flags, 0, nil)
}
// sendDedicatedBulkReset submits a reset control frame carrying message as its payload through
// bulk's dedicated sender. The submit uses a context derived from ctx and bulk.writeTimeout.
func (c *ClientCommon) sendDedicatedBulkReset(ctx context.Context, bulk *bulkHandle, message string) error {
	if c == nil || bulk == nil {
		return errBulkClientNil
	}
	sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
	if err != nil {
		return err
	}
	defer cancel()
	sender, err := c.dedicatedBulkSender(bulk)
	if err != nil {
		return err
	}
	reason := []byte(message)
	return sender.submitControl(sendCtx, bulkFastPayloadTypeReset, 0, 0, reason)
}
// sendDedicatedBulkRelease tells the peer that bytes/chunks of bulk data have been released.
// The counts are encoded by encodeBulkDedicatedReleasePayload. The function waits for the
// dedicated connection to become ready, encodes a single-item release batch, and writes it
// straight to the dedicated connection, bypassing the batching sender. The write deadline
// comes from a context derived from ctx and bulk.writeTimeout.
func (c *ClientCommon) sendDedicatedBulkRelease(ctx context.Context, bulk *bulkHandle, bytes int64, chunks int) error {
	if c == nil || bulk == nil {
		return errBulkClientNil
	}
	releasePayload, err := encodeBulkDedicatedReleasePayload(bytes, chunks)
	if err != nil {
		return err
	}
	if err = bulk.waitDedicatedReady(ctx); err != nil {
		return err
	}
	sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
	if err != nil {
		return err
	}
	defer cancel()
	batch := []bulkDedicatedSendRequest{{
		Type:    bulkFastPayloadTypeRelease,
		Payload: releasePayload,
	}}
	record, err := c.encodeDedicatedBulkBatchPayload(bulk.dataIDSnapshot(), batch)
	if err != nil {
		return err
	}
	conn := bulk.dedicatedConnSnapshot()
	if conn == nil {
		return transportDetachedError("dedicated bulk sidecar not attached", nil)
	}
	deadline, _ := sendCtx.Deadline()
	return writeBulkDedicatedRecordWithDeadline(conn, record, deadline)
}
// dedicatedBulkSender returns the server's batching sender for bulk's dedicated connection.
// logical falls back to bulk.LogicalConn() when nil; errBulkLogicalConnNil is returned if
// neither is available.
//
// An already installed sender is returned directly. Otherwise a new one is built that
// encrypts and encodes with the owning logical connection, and it is offered to
// installDedicatedSender. If installDedicatedSender hands back a different sender, that one
// is returned and the newly built sender is stopped. Sender failures mark the bulk reset.
func (s *ServerCommon) dedicatedBulkSender(logical *LogicalConn, bulk *bulkHandle) (*bulkDedicatedSender, error) {
	if s == nil || bulk == nil {
		return nil, errBulkServerNil
	}
	owner := logical
	if owner == nil {
		owner = bulk.LogicalConn()
	}
	if owner == nil {
		return nil, errBulkLogicalConnNil
	}
	if existing := bulk.dedicatedSenderSnapshot(); existing != nil {
		return existing, nil
	}
	conn := bulk.dedicatedConnSnapshot()
	if conn == nil {
		return nil, transportDetachedError("dedicated bulk sidecar not attached", nil)
	}
	encrypt := func(plain []byte) ([]byte, error) {
		return s.encryptTransportPayloadLogical(owner, plain)
	}
	encodeBatch := func(items []bulkDedicatedSendRequest) ([]byte, error) {
		return s.encodeDedicatedBulkBatchPayload(owner, bulk.dataIDSnapshot(), items)
	}
	onFailure := func(err error) {
		bulk.markReset(err)
	}
	fresh := newBulkDedicatedSender(conn, bulk.dataIDSnapshot(), encrypt, encodeBatch, onFailure)
	installed := bulk.installDedicatedSender(fresh)
	if installed != fresh {
		fresh.stop()
	}
	return installed, nil
}
// sendDedicatedBulkData submits one data chunk for logical to bulk's dedicated sender, tagged
// with the next outbound data sequence number. The sequence number is taken only after a
// sender has been obtained.
func (s *ServerCommon) sendDedicatedBulkData(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, chunk []byte) error {
	if s == nil || bulk == nil {
		return errBulkServerNil
	}
	sender, err := s.dedicatedBulkSender(logical, bulk)
	if err != nil {
		return err
	}
	seq := bulk.nextOutboundDataSeq()
	return sender.submitData(ctx, seq, chunk)
}
// sendDedicatedBulkWrite submits payload for logical to bulk's dedicated sender as a write
// tagged with the next outbound data sequence number, using bulk.chunkSize as the chunk size.
// It returns whatever count and error the sender's submitWrite reports.
func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, payload []byte) (int, error) {
	if s == nil || bulk == nil {
		return 0, errBulkServerNil
	}
	sender, err := s.dedicatedBulkSender(logical, bulk)
	if err != nil {
		return 0, err
	}
	seq := bulk.nextOutboundDataSeq()
	return sender.submitWrite(ctx, seq, payload, bulk.chunkSize)
}
// sendDedicatedBulkClose submits a close control frame for logical through bulk's dedicated
// sender. When full is true, the frame carries bulkFastPayloadFlagFullClose. The submit uses
// a context derived from ctx and bulk.writeTimeout.
func (s *ServerCommon) sendDedicatedBulkClose(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, full bool) error {
	if s == nil || bulk == nil {
		return errBulkServerNil
	}
	sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
	if err != nil {
		return err
	}
	defer cancel()
	var flags uint8
	if full {
		flags = bulkFastPayloadFlagFullClose
	}
	sender, err := s.dedicatedBulkSender(logical, bulk)
	if err != nil {
		return err
	}
	return sender.submitControl(sendCtx, bulkFastPayloadTypeClose, flags, 0, nil)
}
// sendDedicatedBulkReset submits a reset control frame for logical, carrying message as its
// payload, through bulk's dedicated sender. The submit uses a context derived from ctx and
// bulk.writeTimeout.
func (s *ServerCommon) sendDedicatedBulkReset(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, message string) error {
	if s == nil || bulk == nil {
		return errBulkServerNil
	}
	sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
	if err != nil {
		return err
	}
	defer cancel()
	sender, err := s.dedicatedBulkSender(logical, bulk)
	if err != nil {
		return err
	}
	reason := []byte(message)
	return sender.submitControl(sendCtx, bulkFastPayloadTypeReset, 0, 0, reason)
}
// sendDedicatedBulkRelease tells the peer, on behalf of logical, that bytes/chunks of bulk
// data have been released. The counts are encoded by encodeBulkDedicatedReleasePayload. The
// function waits for the dedicated connection to become ready, encodes a single-item release
// batch with logical's settings, and writes it straight to the dedicated connection,
// bypassing the batching sender. The write deadline comes from a context derived from ctx
// and bulk.writeTimeout.
func (s *ServerCommon) sendDedicatedBulkRelease(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, bytes int64, chunks int) error {
	if s == nil || bulk == nil {
		return errBulkServerNil
	}
	releasePayload, err := encodeBulkDedicatedReleasePayload(bytes, chunks)
	if err != nil {
		return err
	}
	if err = bulk.waitDedicatedReady(ctx); err != nil {
		return err
	}
	sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
	if err != nil {
		return err
	}
	defer cancel()
	batch := []bulkDedicatedSendRequest{{
		Type:    bulkFastPayloadTypeRelease,
		Payload: releasePayload,
	}}
	record, err := s.encodeDedicatedBulkBatchPayload(logical, bulk.dataIDSnapshot(), batch)
	if err != nil {
		return err
	}
	conn := bulk.dedicatedConnSnapshot()
	if conn == nil {
		return transportDetachedError("dedicated bulk sidecar not attached", nil)
	}
	deadline, _ := sendCtx.Deadline()
	return writeBulkDedicatedRecordWithDeadline(conn, record, deadline)
}
// encodeDedicatedBulkBatchPayload turns a batch of send requests for dataID into the record
// payload for a dedicated bulk connection on the client side.
//
// When the client has a fastPlainEncode hook, the batch is handed to
// encodeBulkDedicatedBatchPayloadFast together with SecretKey. Otherwise the batch is encoded
// as plaintext and passed through encryptTransportPayload.
func (c *ClientCommon) encodeDedicatedBulkBatchPayload(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
	if c == nil {
		return nil, errBulkClientNil
	}
	if fast := c.fastPlainEncode; fast != nil {
		return encodeBulkDedicatedBatchPayloadFast(fast, c.SecretKey, dataID, items)
	}
	plain, err := encodeBulkDedicatedBatchPlain(dataID, items)
	if err != nil {
		return nil, err
	}
	return c.encryptTransportPayload(plain)
}
// encodeDedicatedBulkBatchPayload turns a batch of send requests for dataID into the record
// payload for a dedicated bulk connection, using the per-connection settings of logical.
//
// When logical has a fastPlainEncode hook, the batch is handed to
// encodeBulkDedicatedBatchPayloadFast with logical's secret key. Otherwise the batch is
// encoded as plaintext and encrypted via encryptTransportPayloadLogical. It returns
// errBulkServerNil or errBulkLogicalConnNil for a nil receiver or logical connection.
func (s *ServerCommon) encodeDedicatedBulkBatchPayload(logical *LogicalConn, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
	switch {
	case s == nil:
		return nil, errBulkServerNil
	case logical == nil:
		return nil, errBulkLogicalConnNil
	}
	if fast := logical.fastPlainEncodeSnapshot(); fast != nil {
		return encodeBulkDedicatedBatchPayloadFast(fast, logical.secretKeySnapshot(), dataID, items)
	}
	plain, err := encodeBulkDedicatedBatchPlain(dataID, items)
	if err != nil {
		return nil, err
	}
	return s.encryptTransportPayloadLogical(logical, plain)
}