fix: 修复 dedicated bulk attach 竞态并优化 short write 补写路径
- 客户端 dedicated attach 回复改为精确读取单帧,避免 attach reply 与后续 NBR1 数据粘连后被误解析 - 服务端 accepted attach 改为先 detach transport,再直接回 attach reply,随后立即切入 dedicated bulk read loop - transport 读循环在 stop 或 transport ownership 失效后不再继续上推已读数据,避免 handoff 后首包被旧 reader 吃掉 - dedicated bulk record 写路径改为 full-write,消除 short write 导致的 invalid bulk fast payload - 优化 vectored write 补写策略:先尝试一次 writev,未写完时直接顺序补完剩余 buffers,减少重复 WriteTo 开销 - 放宽 vectored write 能力识别,支持通过 UnwrapConn/WriteBuffers 命中 fast path - 修复 dedicated batch 排队路径 payload 复用问题,改为深拷贝 queued items - 补齐 dedicated attach、short write、payload clone、transport stop/handoff 等回归测试
This commit is contained in:
parent
7ed3dd5b37
commit
4f760f2807
@ -3,11 +3,13 @@ package notify
|
||||
import (
|
||||
"b612.me/notify/internal/transport"
|
||||
"b612.me/stario"
|
||||
"bytes"
|
||||
"context"
|
||||
cryptorand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
@ -19,8 +21,17 @@ const (
|
||||
bulkDedicatedRecordMagic = "NBR1"
|
||||
bulkDedicatedRecordHeaderLen = 8
|
||||
bulkDedicatedAttachTimeout = 5 * time.Second
|
||||
|
||||
bulkDedicatedAttachFrameMagicSize = 8
|
||||
bulkDedicatedAttachFrameHeaderLen = 14
|
||||
bulkDedicatedAttachFrameVersionOffset = 12
|
||||
bulkDedicatedAttachFrameFlagsOffset = 13
|
||||
bulkDedicatedAttachFrameVersionV1 = 1
|
||||
bulkDedicatedAttachFrameFlagsNone = 0
|
||||
)
|
||||
|
||||
var bulkDedicatedAttachFrameMagic = [bulkDedicatedAttachFrameMagicSize]byte{11, 27, 19, 96, 12, 25, 2, 20}
|
||||
|
||||
type bulkAttachRequest struct {
|
||||
PeerID string
|
||||
BulkID string
|
||||
@ -121,6 +132,35 @@ func decodeDirectSignalPayload(sequenceDe func([]byte) (interface{}, error), msg
|
||||
return unwrapTransferMsgEnvelope(env, sequenceDe)
|
||||
}
|
||||
|
||||
func readDirectSignalFramePayload(conn net.Conn) ([]byte, error) {
|
||||
if conn == nil {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
var header [bulkDedicatedAttachFrameHeaderLen]byte
|
||||
if _, err := io.ReadFull(conn, header[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !bytes.Equal(header[:bulkDedicatedAttachFrameMagicSize], bulkDedicatedAttachFrameMagic[:]) {
|
||||
return nil, stario.ErrQueueDataFormat
|
||||
}
|
||||
if got := header[bulkDedicatedAttachFrameVersionOffset]; got != bulkDedicatedAttachFrameVersionV1 {
|
||||
return nil, stario.ErrQueueUnsupportedVersion
|
||||
}
|
||||
if got := header[bulkDedicatedAttachFrameFlagsOffset]; got != bulkDedicatedAttachFrameFlagsNone {
|
||||
return nil, stario.ErrQueueUnsupportedFlags
|
||||
}
|
||||
length := binary.BigEndian.Uint32(header[bulkDedicatedAttachFrameMagicSize : bulkDedicatedAttachFrameMagicSize+4])
|
||||
maxInt := int(^uint(0) >> 1)
|
||||
if uint64(length) > uint64(maxInt) {
|
||||
return nil, stario.ErrQueueMessageTooLarge
|
||||
}
|
||||
payload := make([]byte, int(length))
|
||||
if _, err := io.ReadFull(conn, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func writeBulkDedicatedRecord(conn net.Conn, payload []byte) error {
|
||||
return writeBulkDedicatedRecordWithDeadline(conn, payload, time.Time{})
|
||||
}
|
||||
@ -133,9 +173,7 @@ func writeBulkDedicatedRecordWithDeadline(conn net.Conn, payload []byte, deadlin
|
||||
var header [bulkDedicatedRecordHeaderLen]byte
|
||||
copy(header[:4], bulkDedicatedRecordMagic)
|
||||
binary.BigEndian.PutUint32(header[4:8], uint32(len(payload)))
|
||||
buffers := net.Buffers{header[:], payload}
|
||||
_, err := buffers.WriteTo(conn)
|
||||
return err
|
||||
return writeNetBuffersFullUnlocked(conn, net.Buffers{header[:], payload})
|
||||
})
|
||||
}
|
||||
|
||||
@ -148,7 +186,7 @@ func readBulkDedicatedRecord(conn net.Conn) ([]byte, error) {
|
||||
return nil, err
|
||||
}
|
||||
if string(header[:4]) != bulkDedicatedRecordMagic {
|
||||
return nil, errBulkFastPayloadInvalid
|
||||
return nil, fmt.Errorf("%w: record magic=%x", errBulkFastPayloadInvalid, header[:4])
|
||||
}
|
||||
size := int(binary.BigEndian.Uint32(header[4:8]))
|
||||
if size < 0 {
|
||||
@ -224,51 +262,34 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn
|
||||
if err != nil {
|
||||
return bulkAttachResponse{}, err
|
||||
}
|
||||
queue := stario.NewQueue()
|
||||
msg := TransferMsg{
|
||||
ID: atomic.AddUint64(&c.msgID, 1),
|
||||
Key: systemBulkAttachKey,
|
||||
Value: reqPayload,
|
||||
Type: MSG_SYS_WAIT,
|
||||
}
|
||||
frame, err := encodeDirectSignalFrame(queue, c.sequenceEn, c.msgEn, c.SecretKey, msg)
|
||||
frame, err := encodeDirectSignalFrame(stario.NewQueue(), c.sequenceEn, c.msgEn, c.SecretKey, msg)
|
||||
if err != nil {
|
||||
return bulkAttachResponse{}, err
|
||||
}
|
||||
if err := writeFullToConn(conn, frame); err != nil {
|
||||
return bulkAttachResponse{}, err
|
||||
}
|
||||
replyCh := make(chan Message, 1)
|
||||
readBuf := streamReadBuffer()
|
||||
for {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
_ = conn.SetReadDeadline(deadline)
|
||||
}
|
||||
n, err := conn.Read(readBuf)
|
||||
if err != nil {
|
||||
return bulkAttachResponse{}, err
|
||||
}
|
||||
parseErr := queue.ParseMessageOwned(readBuf[:n], "bulk-attach", func(msgq stario.MsgQueue) error {
|
||||
transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, msgq.Msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
replyCh <- Message{
|
||||
ServerConn: c,
|
||||
TransferMsg: transfer,
|
||||
NetType: NET_CLIENT,
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if parseErr != nil {
|
||||
return bulkAttachResponse{}, parseErr
|
||||
}
|
||||
select {
|
||||
case reply := <-replyCh:
|
||||
return decodeBulkAttachResponse(c.sequenceDe, reply.Value)
|
||||
default:
|
||||
}
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
_ = conn.SetReadDeadline(deadline)
|
||||
}
|
||||
replyPayload, err := readDirectSignalFramePayload(conn)
|
||||
if err != nil {
|
||||
return bulkAttachResponse{}, err
|
||||
}
|
||||
transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, replyPayload)
|
||||
if err != nil {
|
||||
return bulkAttachResponse{}, err
|
||||
}
|
||||
if transfer.Key != systemBulkAttachKey || transfer.Type != MSG_SYS_REPLY || transfer.ID != msg.ID {
|
||||
return bulkAttachResponse{}, errors.New("invalid bulk attach reply")
|
||||
}
|
||||
return decodeBulkAttachResponse(c.sequenceDe, transfer.Value)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) readDedicatedBulkLoop(bulk *bulkHandle, conn net.Conn) {
|
||||
@ -323,14 +344,13 @@ func (s *ServerCommon) handleBulkAttachSystemMessage(message Message) bool {
|
||||
}
|
||||
if err != nil {
|
||||
resp.Error = err.Error()
|
||||
} else {
|
||||
resp.Accepted = true
|
||||
if current != nil {
|
||||
_ = s.replyDedicatedBulkAttach(current, message, resp)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if current != nil {
|
||||
_ = s.replyDedicatedBulkAttach(current, message, resp)
|
||||
}
|
||||
if err == nil {
|
||||
if attachErr := s.finishInboundDedicatedBulkAttach(current, logical, bulk); attachErr != nil {
|
||||
if attachErr := s.finishInboundDedicatedBulkAttach(current, logical, bulk, message); attachErr != nil {
|
||||
bulk.markReset(attachErr)
|
||||
}
|
||||
}
|
||||
@ -368,7 +388,7 @@ func (s *ServerCommon) resolveInboundDedicatedBulk(current *LogicalConn, req bul
|
||||
return logical, bulk, nil
|
||||
}
|
||||
|
||||
func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, logical *LogicalConn, bulk *bulkHandle) error {
|
||||
func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, logical *LogicalConn, bulk *bulkHandle, message Message) error {
|
||||
if current == nil || logical == nil || bulk == nil {
|
||||
return errBulkLogicalConnNil
|
||||
}
|
||||
@ -376,18 +396,56 @@ func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, lo
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := bulk.attachDedicatedConn(conn); err != nil {
|
||||
fail := func(reason string, err error) error {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
current.markSessionStopped(reason, err)
|
||||
s.removeLogical(current)
|
||||
return err
|
||||
}
|
||||
if err := s.replyDedicatedBulkAttachDetached(current, conn, message, bulkAttachResponse{Accepted: true}); err != nil {
|
||||
return fail("bulk dedicated attach reply failed", err)
|
||||
}
|
||||
if err := bulk.attachDedicatedConn(conn); err != nil {
|
||||
return fail("bulk dedicated attach failed", err)
|
||||
}
|
||||
go s.readDedicatedBulkLoop(logical, bulk, conn)
|
||||
current.markSessionStopped("bulk dedicated attach", nil)
|
||||
s.removeLogical(current)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerCommon) replyDedicatedBulkAttachDetached(client *LogicalConn, conn net.Conn, message Message, resp bulkAttachResponse) error {
|
||||
if s == nil || client == nil {
|
||||
return errBulkServerNil
|
||||
}
|
||||
if conn == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
msgEn := client.msgEnSnapshot()
|
||||
if msgEn == nil {
|
||||
return errTransportPayloadEncryptFailed
|
||||
}
|
||||
encoded, err := s.sequenceEn(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
reply := TransferMsg{
|
||||
ID: message.ID,
|
||||
Key: systemBulkAttachKey,
|
||||
Value: encoded,
|
||||
Type: MSG_SYS_REPLY,
|
||||
}
|
||||
frame, err := encodeDirectSignalFrame(stario.NewQueue(), s.sequenceEn, msgEn, client.secretKeySnapshot(), reply)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return withRawConnWriteLockDeadline(conn, writeDeadlineFromTimeout(client.maxWriteTimeoutSnapshot()), func(conn net.Conn) error {
|
||||
return writeFullToConnUnlocked(conn, frame)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ServerCommon) replyDedicatedBulkAttach(client *LogicalConn, message Message, resp bulkAttachResponse) error {
|
||||
if s == nil || client == nil {
|
||||
return errBulkServerNil
|
||||
|
||||
217
bulk_dedicated_attach_test.go
Normal file
217
bulk_dedicated_attach_test.go
Normal file
@ -0,0 +1,217 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"b612.me/stario"
|
||||
)
|
||||
|
||||
type bulkAttachScriptConn struct {
|
||||
readBuf *bytes.Reader
|
||||
writeBuf bytes.Buffer
|
||||
}
|
||||
|
||||
func newBulkAttachScriptConn(inbound []byte) *bulkAttachScriptConn {
|
||||
return &bulkAttachScriptConn{
|
||||
readBuf: bytes.NewReader(append([]byte(nil), inbound...)),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *bulkAttachScriptConn) Read(p []byte) (int, error) { return c.readBuf.Read(p) }
|
||||
func (c *bulkAttachScriptConn) Write(p []byte) (int, error) { return c.writeBuf.Write(p) }
|
||||
func (c *bulkAttachScriptConn) Close() error { return nil }
|
||||
func (c *bulkAttachScriptConn) LocalAddr() net.Addr { return bulkAttachTestAddr("local") }
|
||||
func (c *bulkAttachScriptConn) RemoteAddr() net.Addr { return bulkAttachTestAddr("remote") }
|
||||
func (c *bulkAttachScriptConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *bulkAttachScriptConn) SetReadDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (c *bulkAttachScriptConn) SetWriteDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *bulkAttachScriptConn) writtenBytes() []byte {
|
||||
return append([]byte(nil), c.writeBuf.Bytes()...)
|
||||
}
|
||||
|
||||
type bulkAttachTestAddr string
|
||||
|
||||
func (a bulkAttachTestAddr) Network() string { return "tcp" }
|
||||
func (a bulkAttachTestAddr) String() string { return string(a) }
|
||||
|
||||
func encodeDedicatedRecordForAttachTest(payload []byte) []byte {
|
||||
out := make([]byte, bulkDedicatedRecordHeaderLen+len(payload))
|
||||
copy(out[:4], bulkDedicatedRecordMagic)
|
||||
binary.BigEndian.PutUint32(out[4:8], uint32(len(payload)))
|
||||
copy(out[bulkDedicatedRecordHeaderLen:], payload)
|
||||
return out
|
||||
}
|
||||
|
||||
func TestSendDedicatedBulkAttachRequestKeepsCoalescedDedicatedPayloadUnread(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
UseLegacySecurityClient(client)
|
||||
client.msgID = 100
|
||||
|
||||
bulk := newBulkHandle(context.Background(), newBulkRuntime("dedicated-attach-test"), clientFileScope(), BulkOpenRequest{
|
||||
BulkID: "bulk-attach-test",
|
||||
DataID: 1,
|
||||
Dedicated: true,
|
||||
AttachToken: "attach-token",
|
||||
}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
|
||||
|
||||
encodedResp, err := client.sequenceEn(bulkAttachResponse{Accepted: true})
|
||||
if err != nil {
|
||||
t.Fatalf("encode bulkAttachResponse failed: %v", err)
|
||||
}
|
||||
replyFrame, err := encodeDirectSignalFrame(stario.NewQueue(), client.sequenceEn, client.msgEn, client.SecretKey, TransferMsg{
|
||||
ID: 101,
|
||||
Key: systemBulkAttachKey,
|
||||
Value: encodedResp,
|
||||
Type: MSG_SYS_REPLY,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("encode attach reply frame failed: %v", err)
|
||||
}
|
||||
dedicatedPayload := []byte("dedicated-tail-bytes")
|
||||
conn := newBulkAttachScriptConn(append(replyFrame, encodeDedicatedRecordForAttachTest(dedicatedPayload)...))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
resp, err := client.sendDedicatedBulkAttachRequest(ctx, conn, bulk)
|
||||
if err != nil {
|
||||
t.Fatalf("sendDedicatedBulkAttachRequest failed: %v", err)
|
||||
}
|
||||
if !resp.Accepted {
|
||||
t.Fatalf("bulk attach response = %+v, want accepted", resp)
|
||||
}
|
||||
|
||||
parsedReq := stario.NewQueue()
|
||||
var reqMsg TransferMsg
|
||||
if err := parsedReq.ParseMessageOwned(conn.writtenBytes(), "attach-request", func(msgq stario.MsgQueue) error {
|
||||
transfer, err := decodeDirectSignalPayload(client.sequenceDe, client.msgDe, client.SecretKey, msgq.Msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
reqMsg = transfer
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("parse written attach request failed: %v", err)
|
||||
}
|
||||
if reqMsg.Key != systemBulkAttachKey || reqMsg.Type != MSG_SYS_WAIT {
|
||||
t.Fatalf("attach request message mismatch: %+v", reqMsg)
|
||||
}
|
||||
readPayload, err := readBulkDedicatedRecord(conn)
|
||||
if err != nil {
|
||||
t.Fatalf("readBulkDedicatedRecord after attach failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(readPayload, dedicatedPayload) {
|
||||
t.Fatalf("dedicated payload mismatch: got %q want %q", string(readPayload), string(dedicatedPayload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleBulkAttachSystemMessageAcceptedWritesDirectReplyBeforeDedicatedHandoff(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
UseLegacySecurityServer(server)
|
||||
|
||||
sidecarLeft, sidecarRight := net.Pipe()
|
||||
defer sidecarRight.Close()
|
||||
|
||||
current := server.bootstrapAcceptedLogical("dedicated-attach-current", nil, sidecarLeft)
|
||||
if current == nil {
|
||||
t.Fatal("bootstrapAcceptedLogical(current) should return logical")
|
||||
}
|
||||
target := server.bootstrapAcceptedLogical("dedicated-attach-target", nil, nil)
|
||||
if target == nil {
|
||||
t.Fatal("bootstrapAcceptedLogical(target) should return logical")
|
||||
}
|
||||
|
||||
bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{
|
||||
BulkID: "server-dedicated-attach-test",
|
||||
DataID: 7,
|
||||
Dedicated: true,
|
||||
AttachToken: "attach-token",
|
||||
}, 0, target, nil, 0, nil, nil, nil, nil, nil)
|
||||
if err := server.getBulkRuntime().register(serverFileScope(target), bulk); err != nil {
|
||||
t.Fatalf("register bulk runtime failed: %v", err)
|
||||
}
|
||||
|
||||
reqPayload, err := server.sequenceEn(bulkAttachRequest{
|
||||
PeerID: target.ID(),
|
||||
BulkID: bulk.ID(),
|
||||
AttachToken: "attach-token",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("encode bulkAttachRequest failed: %v", err)
|
||||
}
|
||||
msg := Message{
|
||||
NetType: NET_SERVER,
|
||||
LogicalConn: current,
|
||||
ClientConn: current.compatClientConn(),
|
||||
TransferMsg: TransferMsg{
|
||||
ID: 42,
|
||||
Key: systemBulkAttachKey,
|
||||
Value: reqPayload,
|
||||
Type: MSG_SYS_WAIT,
|
||||
},
|
||||
inboundConn: sidecarLeft,
|
||||
Time: time.Now(),
|
||||
}
|
||||
|
||||
type attachReplyResult struct {
|
||||
transfer TransferMsg
|
||||
resp bulkAttachResponse
|
||||
err error
|
||||
}
|
||||
replyCh := make(chan attachReplyResult, 1)
|
||||
go func() {
|
||||
_ = sidecarRight.SetReadDeadline(time.Now().Add(time.Second))
|
||||
replyPayload, err := readDirectSignalFramePayload(sidecarRight)
|
||||
if err != nil {
|
||||
replyCh <- attachReplyResult{err: err}
|
||||
return
|
||||
}
|
||||
transfer, err := decodeDirectSignalPayload(server.sequenceDe, current.msgDeSnapshot(), current.secretKeySnapshot(), replyPayload)
|
||||
if err != nil {
|
||||
replyCh <- attachReplyResult{err: err}
|
||||
return
|
||||
}
|
||||
resp, err := decodeBulkAttachResponse(server.sequenceDe, transfer.Value)
|
||||
replyCh <- attachReplyResult{transfer: transfer, resp: resp, err: err}
|
||||
}()
|
||||
|
||||
if !server.handleBulkAttachSystemMessage(msg) {
|
||||
t.Fatal("handleBulkAttachSystemMessage should accept dedicated attach message")
|
||||
}
|
||||
|
||||
var result attachReplyResult
|
||||
select {
|
||||
case result = <-replyCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for direct attach reply")
|
||||
}
|
||||
if result.err != nil {
|
||||
t.Fatalf("read direct attach reply failed: %v", result.err)
|
||||
}
|
||||
transfer := result.transfer
|
||||
if transfer.ID != msg.ID || transfer.Key != systemBulkAttachKey || transfer.Type != MSG_SYS_REPLY {
|
||||
t.Fatalf("attach reply mismatch: %+v", transfer)
|
||||
}
|
||||
resp := result.resp
|
||||
if !resp.Accepted || resp.Error != "" {
|
||||
t.Fatalf("bulk attach response = %+v, want accepted", resp)
|
||||
}
|
||||
if got := bulk.dedicatedConnSnapshot(); got != sidecarLeft {
|
||||
t.Fatalf("dedicated conn mismatch: got %v want %v", got, sidecarLeft)
|
||||
}
|
||||
if current.transportAttachedSnapshot() {
|
||||
t.Fatal("attach sidecar logical transport should be detached after handoff")
|
||||
}
|
||||
if got := server.GetLogicalConn(current.ID()); got != nil {
|
||||
t.Fatalf("attach sidecar logical should be removed after handoff, got %+v", got)
|
||||
}
|
||||
}
|
||||
@ -165,8 +165,7 @@ func (s *bulkDedicatedSender) submitWriteBatch(ctx context.Context, items []bulk
|
||||
if submitted, err := s.tryDirectSubmitBatch(ctx, items); submitted {
|
||||
return err
|
||||
}
|
||||
queuedItems := make([]bulkDedicatedSendRequest, len(items))
|
||||
copy(queuedItems, items)
|
||||
queuedItems := cloneBulkDedicatedSendRequests(items)
|
||||
return s.submitBatch(ctx, queuedItems, true)
|
||||
}
|
||||
|
||||
@ -458,6 +457,20 @@ func (s *bulkDedicatedSender) stoppedErr() error {
|
||||
return errTransportDetached
|
||||
}
|
||||
|
||||
func cloneBulkDedicatedSendRequests(items []bulkDedicatedSendRequest) []bulkDedicatedSendRequest {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
cloned := make([]bulkDedicatedSendRequest, len(items))
|
||||
for i, item := range items {
|
||||
cloned[i] = item
|
||||
if len(item.Payload) > 0 {
|
||||
cloned[i].Payload = append([]byte(nil), item.Payload...)
|
||||
}
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func bulkDedicatedSendRequestLen(req bulkDedicatedSendRequest) int {
|
||||
return bulkDedicatedSendRequestLenFromPayloadLen(len(req.Payload))
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
@ -8,6 +9,38 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCloneBulkDedicatedSendRequestsDeepCopiesPayload(t *testing.T) {
|
||||
src := []bulkDedicatedSendRequest{
|
||||
{
|
||||
Type: bulkFastPayloadTypeData,
|
||||
Seq: 1,
|
||||
Payload: []byte("payload-a"),
|
||||
},
|
||||
{
|
||||
Type: bulkFastPayloadTypeReset,
|
||||
Seq: 2,
|
||||
Payload: []byte("payload-b"),
|
||||
},
|
||||
}
|
||||
cloned := cloneBulkDedicatedSendRequests(src)
|
||||
if len(cloned) != len(src) {
|
||||
t.Fatalf("clone length = %d, want %d", len(cloned), len(src))
|
||||
}
|
||||
if &cloned[0] == &src[0] {
|
||||
t.Fatal("request clone should not alias source slice elements")
|
||||
}
|
||||
if len(cloned[0].Payload) == 0 || len(src[0].Payload) == 0 {
|
||||
t.Fatal("payload should not be empty")
|
||||
}
|
||||
if &cloned[0].Payload[0] == &src[0].Payload[0] {
|
||||
t.Fatal("payload clone should not alias source bytes")
|
||||
}
|
||||
src[0].Payload[0] = 'X'
|
||||
if bytes.Equal(cloned[0].Payload, src[0].Payload) {
|
||||
t.Fatal("mutating source payload should not affect cloned payload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBulkDedicatedBatchPlainRoundTrip(t *testing.T) {
|
||||
releasePayload, err := encodeBulkDedicatedReleasePayload(4096, 2)
|
||||
if err != nil {
|
||||
|
||||
66
bulk_dedicated_record_test.go
Normal file
66
bulk_dedicated_record_test.go
Normal file
@ -0,0 +1,66 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type shortWriteBulkRecordConn struct {
|
||||
maxPerWrite int
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
func (c *shortWriteBulkRecordConn) Read([]byte) (int, error) { return 0, io.EOF }
|
||||
|
||||
func (c *shortWriteBulkRecordConn) Write(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
n := c.maxPerWrite
|
||||
if n <= 0 || n > len(p) {
|
||||
n = len(p)
|
||||
}
|
||||
_, _ = c.buf.Write(p[:n])
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *shortWriteBulkRecordConn) Close() error { return nil }
|
||||
func (c *shortWriteBulkRecordConn) LocalAddr() net.Addr { return shortWriteBulkRecordAddr("local") }
|
||||
func (c *shortWriteBulkRecordConn) RemoteAddr() net.Addr { return shortWriteBulkRecordAddr("remote") }
|
||||
func (c *shortWriteBulkRecordConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *shortWriteBulkRecordConn) SetReadDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (c *shortWriteBulkRecordConn) SetWriteDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type shortWriteBulkRecordAddr string
|
||||
|
||||
func (a shortWriteBulkRecordAddr) Network() string { return "tcp" }
|
||||
func (a shortWriteBulkRecordAddr) String() string { return string(a) }
|
||||
|
||||
func TestWriteBulkDedicatedRecordWithDeadlineHandlesShortWrite(t *testing.T) {
|
||||
conn := &shortWriteBulkRecordConn{maxPerWrite: 3}
|
||||
payload := []byte("abcdefghijklmnopqrstuvwxyz")
|
||||
if err := writeBulkDedicatedRecordWithDeadline(conn, payload, time.Time{}); err != nil {
|
||||
t.Fatalf("writeBulkDedicatedRecordWithDeadline failed: %v", err)
|
||||
}
|
||||
raw := conn.buf.Bytes()
|
||||
if got, want := len(raw), bulkDedicatedRecordHeaderLen+len(payload); got != want {
|
||||
t.Fatalf("record length = %d, want %d", got, want)
|
||||
}
|
||||
if got := string(raw[:4]); got != bulkDedicatedRecordMagic {
|
||||
t.Fatalf("record magic = %q, want %q", got, bulkDedicatedRecordMagic)
|
||||
}
|
||||
if got, want := int(binary.BigEndian.Uint32(raw[4:8])), len(payload); got != want {
|
||||
t.Fatalf("record payload length = %d, want %d", got, want)
|
||||
}
|
||||
if got := raw[bulkDedicatedRecordHeaderLen:]; !bytes.Equal(got, payload) {
|
||||
t.Fatalf("record payload mismatch")
|
||||
}
|
||||
}
|
||||
@ -332,6 +332,24 @@ func TestLogicalDetachTransportForTransferKeepsHandoffConnAlive(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogicalHandleTUTransportReadResultWithSessionDropsDataAfterTransportStop(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
left, right := net.Pipe()
|
||||
defer right.Close()
|
||||
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
logical, _, _ := newRegisteredServerLogicalForTest(t, server, "logical-stop-read-drop", left, stopCtx, stopFn)
|
||||
if logical == nil {
|
||||
t.Fatal("logical should not be nil")
|
||||
}
|
||||
generation := logical.transportGenerationSnapshot()
|
||||
stopFn()
|
||||
|
||||
if logical.handleTUTransportReadResultWithSession(stopCtx, left, generation, len([]byte("late-data")), []byte("late-data"), nil) {
|
||||
t.Fatal("handleTUTransportReadResultWithSession should stop after transport stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientConnTransportBindingSnapshotUsesRuntimeBinding(t *testing.T) {
|
||||
client := &ClientConn{}
|
||||
left, right := net.Pipe()
|
||||
|
||||
@ -69,6 +69,12 @@ func (c *LogicalConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []by
|
||||
}
|
||||
|
||||
func (c *LogicalConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool {
|
||||
if transportReadShouldStop(stopCtx) || !c.ownsTransportRead(conn, generation) {
|
||||
if c.shouldCloseTransportOnStop(conn) {
|
||||
_ = conn.Close()
|
||||
}
|
||||
return false
|
||||
}
|
||||
if err == os.ErrDeadlineExceeded {
|
||||
if num != 0 {
|
||||
c.pushServerOwnedTransportMessage(data[:num], conn, generation)
|
||||
@ -95,6 +101,30 @@ func (c *LogicalConn) handleTUTransportReadResultWithSession(stopCtx context.Con
|
||||
return true
|
||||
}
|
||||
|
||||
func transportReadShouldStop(stopCtx context.Context) bool {
|
||||
select {
|
||||
case <-sessionStopChan(stopCtx):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LogicalConn) ownsTransportRead(conn net.Conn, generation uint64) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
rt := c.clientConnSessionRuntimeSnapshot()
|
||||
if rt == nil || !rt.transportAttached || rt.transportGeneration != generation {
|
||||
return false
|
||||
}
|
||||
current := rt.tuConn
|
||||
if rt.transport != nil && rt.transport.connSnapshot() != nil {
|
||||
current = rt.transport.connSnapshot()
|
||||
}
|
||||
return current == conn
|
||||
}
|
||||
|
||||
func (c *LogicalConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) {
|
||||
if c == nil || len(data) == 0 {
|
||||
return
|
||||
@ -163,6 +193,12 @@ func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Cont
|
||||
if logical := c.LogicalConn(); logical != nil {
|
||||
return logical.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err)
|
||||
}
|
||||
if transportReadShouldStop(stopCtx) || !c.ownsTransportRead(conn, generation) {
|
||||
if c.shouldCloseClientConnTransportOnStop(conn) {
|
||||
_ = conn.Close()
|
||||
}
|
||||
return false
|
||||
}
|
||||
if err == os.ErrDeadlineExceeded {
|
||||
if num != 0 {
|
||||
c.pushServerOwnedTransportMessage(data[:num], conn, generation)
|
||||
@ -189,6 +225,21 @@ func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Cont
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *ClientConn) ownsTransportRead(conn net.Conn, generation uint64) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
rt := c.clientConnSessionRuntimeSnapshot()
|
||||
if rt == nil || !rt.transportAttached || rt.transportGeneration != generation {
|
||||
return false
|
||||
}
|
||||
current := rt.tuConn
|
||||
if rt.transport != nil && rt.transport.connSnapshot() != nil {
|
||||
current = rt.transport.connSnapshot()
|
||||
}
|
||||
return current == conn
|
||||
}
|
||||
|
||||
func (c *ClientConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) {
|
||||
if logical := c.LogicalConn(); logical != nil {
|
||||
logical.pushServerOwnedTransportMessage(data, conn, generation)
|
||||
|
||||
@ -13,6 +13,14 @@ import (
|
||||
var transportConnWriteLocks sync.Map
|
||||
var errTransportFrameQueueUnavailable = errors.New("transport frame queue is unavailable")
|
||||
|
||||
type vectoredBuffersWriter interface {
|
||||
WriteBuffers(*net.Buffers) (int64, error)
|
||||
}
|
||||
|
||||
type vectoredConnUnwrapper interface {
|
||||
UnwrapConn() net.Conn
|
||||
}
|
||||
|
||||
func writeFullToConn(conn net.Conn, data []byte) error {
|
||||
if conn == nil {
|
||||
return net.ErrClosed
|
||||
@ -26,8 +34,15 @@ func writeFullToConnUnlocked(conn net.Conn, data []byte) error {
|
||||
if conn == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return writeFullToWriterUnlocked(conn, data)
|
||||
}
|
||||
|
||||
func writeFullToWriterUnlocked(writer io.Writer, data []byte) error {
|
||||
if writer == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
for len(data) > 0 {
|
||||
n, err := conn.Write(data)
|
||||
n, err := writer.Write(data)
|
||||
if n > 0 {
|
||||
data = data[n:]
|
||||
}
|
||||
@ -41,6 +56,69 @@ func writeFullToConnUnlocked(conn net.Conn, data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeNetBuffersFullUnlocked(conn net.Conn, buffers net.Buffers) error {
|
||||
if conn == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
writer, writeFn := vectoredWriteStrategy(conn)
|
||||
if writeFn == nil {
|
||||
return writeRemainingBuffersUnlocked(conn, buffers)
|
||||
}
|
||||
n, err := writeFn(&buffers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(buffers) == 0 {
|
||||
return nil
|
||||
}
|
||||
if n == 0 {
|
||||
return io.ErrNoProgress
|
||||
}
|
||||
return writeRemainingBuffersUnlocked(writer, buffers)
|
||||
}
|
||||
|
||||
func vectoredWriteStrategy(conn net.Conn) (io.Writer, func(*net.Buffers) (int64, error)) {
|
||||
current := conn
|
||||
for depth := 0; depth < 8 && current != nil; depth++ {
|
||||
if writer, ok := current.(vectoredBuffersWriter); ok {
|
||||
target := current
|
||||
return target, writer.WriteBuffers
|
||||
}
|
||||
switch target := current.(type) {
|
||||
case *net.TCPConn:
|
||||
return target, func(bufs *net.Buffers) (int64, error) {
|
||||
return bufs.WriteTo(target)
|
||||
}
|
||||
case *net.UnixConn:
|
||||
return target, func(bufs *net.Buffers) (int64, error) {
|
||||
return bufs.WriteTo(target)
|
||||
}
|
||||
}
|
||||
unwrapper, ok := current.(vectoredConnUnwrapper)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
next := unwrapper.UnwrapConn()
|
||||
if next == nil || next == current {
|
||||
break
|
||||
}
|
||||
current = next
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func writeRemainingBuffersUnlocked(writer io.Writer, buffers net.Buffers) error {
|
||||
for _, part := range buffers {
|
||||
if len(part) == 0 {
|
||||
continue
|
||||
}
|
||||
if err := writeFullToWriterUnlocked(writer, part); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func withRawConnWriteLock(conn net.Conn, fn func(net.Conn) error) error {
|
||||
return withRawConnWriteLockDeadline(conn, time.Time{}, fn)
|
||||
}
|
||||
|
||||
@ -2,8 +2,10 @@ package notify
|
||||
|
||||
import (
|
||||
"b612.me/stario"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -146,6 +148,113 @@ func (c *blockingPacketWriteConn) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
type vectoredShortWriteConn struct {
|
||||
steps []int64
|
||||
idx int
|
||||
buf bytes.Buffer
|
||||
writes int
|
||||
writev int
|
||||
}
|
||||
|
||||
func (c *vectoredShortWriteConn) Read([]byte) (int, error) { return 0, io.EOF }
|
||||
func (c *vectoredShortWriteConn) Write(p []byte) (int, error) {
|
||||
c.writes++
|
||||
return c.buf.Write(p)
|
||||
}
|
||||
func (c *vectoredShortWriteConn) Close() error { return nil }
|
||||
func (c *vectoredShortWriteConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *vectoredShortWriteConn) RemoteAddr() net.Addr { return nil }
|
||||
func (c *vectoredShortWriteConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *vectoredShortWriteConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *vectoredShortWriteConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
func (c *vectoredShortWriteConn) WriteBuffers(bufs *net.Buffers) (int64, error) {
|
||||
c.writev++
|
||||
if c.idx >= len(c.steps) {
|
||||
return 0, io.ErrNoProgress
|
||||
}
|
||||
remaining := c.steps[c.idx]
|
||||
c.idx++
|
||||
written := int64(0)
|
||||
for len(*bufs) > 0 && remaining > 0 {
|
||||
part := (*bufs)[0]
|
||||
if len(part) == 0 {
|
||||
(*bufs)[0] = nil
|
||||
*bufs = (*bufs)[1:]
|
||||
continue
|
||||
}
|
||||
n := int64(len(part))
|
||||
if n > remaining {
|
||||
n = remaining
|
||||
}
|
||||
_, _ = c.buf.Write(part[:n])
|
||||
written += n
|
||||
remaining -= n
|
||||
if n == int64(len(part)) {
|
||||
(*bufs)[0] = nil
|
||||
*bufs = (*bufs)[1:]
|
||||
continue
|
||||
}
|
||||
(*bufs)[0] = part[n:]
|
||||
break
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
type unwrapVectoredConn struct {
|
||||
inner net.Conn
|
||||
}
|
||||
|
||||
func (c *unwrapVectoredConn) Read(p []byte) (int, error) { return c.inner.Read(p) }
|
||||
func (c *unwrapVectoredConn) Write(p []byte) (int, error) { return c.inner.Write(p) }
|
||||
func (c *unwrapVectoredConn) Close() error { return c.inner.Close() }
|
||||
func (c *unwrapVectoredConn) LocalAddr() net.Addr { return c.inner.LocalAddr() }
|
||||
func (c *unwrapVectoredConn) RemoteAddr() net.Addr { return c.inner.RemoteAddr() }
|
||||
func (c *unwrapVectoredConn) SetDeadline(t time.Time) error { return c.inner.SetDeadline(t) }
|
||||
func (c *unwrapVectoredConn) SetReadDeadline(t time.Time) error { return c.inner.SetReadDeadline(t) }
|
||||
func (c *unwrapVectoredConn) SetWriteDeadline(t time.Time) error { return c.inner.SetWriteDeadline(t) }
|
||||
func (c *unwrapVectoredConn) UnwrapConn() net.Conn { return c.inner }
|
||||
|
||||
func TestWriteNetBuffersFullUnlockedFallsBackToDirectWritesAfterFirstPartialVectoredWrite(t *testing.T) {
|
||||
conn := &vectoredShortWriteConn{steps: []int64{3}}
|
||||
header := []byte("head")
|
||||
payload := []byte("payload")
|
||||
if err := writeNetBuffersFullUnlocked(conn, net.Buffers{header, payload}); err != nil {
|
||||
t.Fatalf("writeNetBuffersFullUnlocked failed: %v", err)
|
||||
}
|
||||
if got, want := conn.writev, 1; got != want {
|
||||
t.Fatalf("vectored write calls = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := conn.writes, 2; got != want {
|
||||
t.Fatalf("fallback direct writes = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := conn.buf.String(), "headpayload"; got != want {
|
||||
t.Fatalf("written bytes = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteNetBuffersFullUnlockedReturnsNoProgressWhenVectoredWriteDoesNotAdvance(t *testing.T) {
|
||||
conn := &vectoredShortWriteConn{steps: []int64{0}}
|
||||
err := writeNetBuffersFullUnlocked(conn, net.Buffers{[]byte("head"), []byte("payload")})
|
||||
if !errors.Is(err, io.ErrNoProgress) {
|
||||
t.Fatalf("writeNetBuffersFullUnlocked error = %v, want %v", err, io.ErrNoProgress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteNetBuffersFullUnlockedUsesUnwrappedVectoredConn(t *testing.T) {
|
||||
inner := &vectoredShortWriteConn{steps: []int64{100}}
|
||||
conn := &unwrapVectoredConn{inner: inner}
|
||||
if err := writeNetBuffersFullUnlocked(conn, net.Buffers{[]byte("head"), []byte("payload")}); err != nil {
|
||||
t.Fatalf("writeNetBuffersFullUnlocked failed: %v", err)
|
||||
}
|
||||
if got, want := inner.writev, 1; got != want {
|
||||
t.Fatalf("unwrapped vectored write calls = %d, want %d", got, want)
|
||||
}
|
||||
if got := inner.writes; got != 0 {
|
||||
t.Fatalf("unexpected fallback direct writes = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBulkBatchSenderSkipsQueuedCanceledRequest(t *testing.T) {
|
||||
conn := newBlockingPacketWriteConn()
|
||||
binding := newTransportBinding(conn, stario.NewQueue())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user