From 4f760f28074b97500a26420bbfabb9d11cd6ac6b Mon Sep 17 00:00:00 2001 From: starainrt Date: Thu, 16 Apr 2026 17:27:48 +0800 Subject: [PATCH] =?UTF-8?q?=20fix:=20=E4=BF=AE=E5=A4=8D=20dedicated=20bulk?= =?UTF-8?q?=20attach=20=E7=AB=9E=E6=80=81=E5=B9=B6=E4=BC=98=E5=8C=96=20sho?= =?UTF-8?q?rt=20write=20=E8=A1=A5=E5=86=99=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 客户端 dedicated attach 回复改为精确读取单帧,避免 attach reply 与后续 NBR1 数据粘连后被误解析 - 服务端 accepted attach 改为先 detach transport,再直接回 attach reply,随后立即切入 dedicated bulk read loop - transport 读循环在 stop 或 transport ownership 失效后不再继续上推已读数据,避免 handoff 后首包被旧 reader 吃掉 - dedicated bulk record 写路径改为 full-write,消除 short write 导致的 invalid bulk fast payload - 优化 vectored write 补写策略:先尝试一次 writev,未写完时直接顺序补完剩余 buffers,减少重复 WriteTo 开销 - 放宽 vectored write 能力识别,支持通过 UnwrapConn/WriteBuffers 命中 fast path - 修复 dedicated batch 排队路径 payload 复用问题,改为深拷贝 queued items - 补齐 dedicated attach、short write、payload clone、transport stop/handoff 等回归测试 --- bulk_dedicated.go | 146 ++++++++++++++++------- bulk_dedicated_attach_test.go | 217 ++++++++++++++++++++++++++++++++++ bulk_dedicated_batch.go | 17 ++- bulk_dedicated_batch_test.go | 33 ++++++ bulk_dedicated_record_test.go | 66 +++++++++++ client_conn_session_test.go | 18 +++ client_conn_transport.go | 51 ++++++++ transport_write.go | 80 ++++++++++++- transport_write_test.go | 109 +++++++++++++++++ 9 files changed, 690 insertions(+), 47 deletions(-) create mode 100644 bulk_dedicated_attach_test.go create mode 100644 bulk_dedicated_record_test.go diff --git a/bulk_dedicated.go b/bulk_dedicated.go index da4118d..d3bf44d 100644 --- a/bulk_dedicated.go +++ b/bulk_dedicated.go @@ -3,11 +3,13 @@ package notify import ( "b612.me/notify/internal/transport" "b612.me/stario" + "bytes" "context" cryptorand "crypto/rand" "encoding/binary" "encoding/hex" "errors" + "fmt" "io" "net" "sync/atomic" @@ -19,8 +21,17 @@ const ( bulkDedicatedRecordMagic = "NBR1" bulkDedicatedRecordHeaderLen = 8 bulkDedicatedAttachTimeout = 5 * time.Second + + bulkDedicatedAttachFrameMagicSize = 8 + bulkDedicatedAttachFrameHeaderLen = 14 + bulkDedicatedAttachFrameVersionOffset = 12 + bulkDedicatedAttachFrameFlagsOffset = 13 + bulkDedicatedAttachFrameVersionV1 = 1 + bulkDedicatedAttachFrameFlagsNone = 0 ) +var bulkDedicatedAttachFrameMagic = [bulkDedicatedAttachFrameMagicSize]byte{11, 27, 19, 96, 12, 25, 2, 20} + type bulkAttachRequest struct { PeerID string BulkID string @@ -121,6 +132,35 @@ func decodeDirectSignalPayload(sequenceDe func([]byte) (interface{}, error), msg return unwrapTransferMsgEnvelope(env, sequenceDe) } +func readDirectSignalFramePayload(conn net.Conn) ([]byte, error) { + if conn == nil { + return nil, net.ErrClosed + } + var header [bulkDedicatedAttachFrameHeaderLen]byte + if _, err := io.ReadFull(conn, header[:]); err != nil { + return nil, err + } + if !bytes.Equal(header[:bulkDedicatedAttachFrameMagicSize], bulkDedicatedAttachFrameMagic[:]) { + return nil, stario.ErrQueueDataFormat + } + if got := header[bulkDedicatedAttachFrameVersionOffset]; got != bulkDedicatedAttachFrameVersionV1 { + return nil, stario.ErrQueueUnsupportedVersion + } + if got := header[bulkDedicatedAttachFrameFlagsOffset]; got != bulkDedicatedAttachFrameFlagsNone { + return nil, stario.ErrQueueUnsupportedFlags + } + length := binary.BigEndian.Uint32(header[bulkDedicatedAttachFrameMagicSize : bulkDedicatedAttachFrameMagicSize+4]) + maxInt := int(^uint(0) >> 1) + if uint64(length) > uint64(maxInt) { + return nil, stario.ErrQueueMessageTooLarge + } + payload := make([]byte, int(length)) + if _, err := io.ReadFull(conn, payload); err != nil { + return nil, err + } + return payload, nil +} + func writeBulkDedicatedRecord(conn net.Conn, payload []byte) error { return writeBulkDedicatedRecordWithDeadline(conn, payload, time.Time{}) } @@ -133,9 +173,7 @@ func writeBulkDedicatedRecordWithDeadline(conn net.Conn, payload []byte, deadlin var header [bulkDedicatedRecordHeaderLen]byte copy(header[:4], bulkDedicatedRecordMagic) binary.BigEndian.PutUint32(header[4:8], uint32(len(payload))) - buffers := net.Buffers{header[:], payload} - _, err := buffers.WriteTo(conn) - return err + return writeNetBuffersFullUnlocked(conn, net.Buffers{header[:], payload}) }) } @@ -148,7 +186,7 @@ func readBulkDedicatedRecord(conn net.Conn) ([]byte, error) { return nil, err } if string(header[:4]) != bulkDedicatedRecordMagic { - return nil, errBulkFastPayloadInvalid + return nil, fmt.Errorf("%w: record magic=%x", errBulkFastPayloadInvalid, header[:4]) } size := int(binary.BigEndian.Uint32(header[4:8])) if size < 0 { @@ -224,51 +262,34 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn if err != nil { return bulkAttachResponse{}, err } - queue := stario.NewQueue() msg := TransferMsg{ ID: atomic.AddUint64(&c.msgID, 1), Key: systemBulkAttachKey, Value: reqPayload, Type: MSG_SYS_WAIT, } - frame, err := encodeDirectSignalFrame(queue, c.sequenceEn, c.msgEn, c.SecretKey, msg) + frame, err := encodeDirectSignalFrame(stario.NewQueue(), c.sequenceEn, c.msgEn, c.SecretKey, msg) if err != nil { return bulkAttachResponse{}, err } if err := writeFullToConn(conn, frame); err != nil { return bulkAttachResponse{}, err } - replyCh := make(chan Message, 1) - readBuf := streamReadBuffer() - for { - if deadline, ok := ctx.Deadline(); ok { - _ = conn.SetReadDeadline(deadline) - } - n, err := conn.Read(readBuf) - if err != nil { - return bulkAttachResponse{}, err - } - parseErr := queue.ParseMessageOwned(readBuf[:n], "bulk-attach", func(msgq stario.MsgQueue) error { - transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, msgq.Msg) - if err != nil { - return err - } - replyCh <- Message{ - ServerConn: c, - TransferMsg: transfer, - NetType: NET_CLIENT, - } - return nil - }) - if parseErr != nil { - return bulkAttachResponse{}, parseErr - } - select { - case reply := <-replyCh: - return decodeBulkAttachResponse(c.sequenceDe, reply.Value) - default: - } + if deadline, ok := ctx.Deadline(); ok { + _ = conn.SetReadDeadline(deadline) } + replyPayload, err := readDirectSignalFramePayload(conn) + if err != nil { + return bulkAttachResponse{}, err + } + transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, replyPayload) + if err != nil { + return bulkAttachResponse{}, err + } + if transfer.Key != systemBulkAttachKey || transfer.Type != MSG_SYS_REPLY || transfer.ID != msg.ID { + return bulkAttachResponse{}, errors.New("invalid bulk attach reply") + } + return decodeBulkAttachResponse(c.sequenceDe, transfer.Value) } func (c *ClientCommon) readDedicatedBulkLoop(bulk *bulkHandle, conn net.Conn) { @@ -323,14 +344,13 @@ func (s *ServerCommon) handleBulkAttachSystemMessage(message Message) bool { } if err != nil { resp.Error = err.Error() - } else { - resp.Accepted = true + if current != nil { + _ = s.replyDedicatedBulkAttach(current, message, resp) + } + return true } if current != nil { - _ = s.replyDedicatedBulkAttach(current, message, resp) - } - if err == nil { - if attachErr := s.finishInboundDedicatedBulkAttach(current, logical, bulk); attachErr != nil { + if attachErr := s.finishInboundDedicatedBulkAttach(current, logical, bulk, message); attachErr != nil { bulk.markReset(attachErr) } } @@ -368,7 +388,7 @@ func (s *ServerCommon) resolveInboundDedicatedBulk(current *LogicalConn, req bul return logical, bulk, nil } -func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, logical *LogicalConn, bulk *bulkHandle) error { +func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, logical *LogicalConn, bulk *bulkHandle, message Message) error { if current == nil || logical == nil || bulk == nil { return errBulkLogicalConnNil } @@ -376,18 +396,56 @@ func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, lo if err != nil { return err } - if err := bulk.attachDedicatedConn(conn); err != nil { + fail := func(reason string, err error) error { if conn != nil { _ = conn.Close() } + current.markSessionStopped(reason, err) + s.removeLogical(current) return err } + if err := s.replyDedicatedBulkAttachDetached(current, conn, message, bulkAttachResponse{Accepted: true}); err != nil { + return fail("bulk dedicated attach reply failed", err) + } + if err := bulk.attachDedicatedConn(conn); err != nil { + return fail("bulk dedicated attach failed", err) + } go s.readDedicatedBulkLoop(logical, bulk, conn) current.markSessionStopped("bulk dedicated attach", nil) s.removeLogical(current) return nil } +func (s *ServerCommon) replyDedicatedBulkAttachDetached(client *LogicalConn, conn net.Conn, message Message, resp bulkAttachResponse) error { + if s == nil || client == nil { + return errBulkServerNil + } + if conn == nil { + return net.ErrClosed + } + msgEn := client.msgEnSnapshot() + if msgEn == nil { + return errTransportPayloadEncryptFailed + } + encoded, err := s.sequenceEn(resp) + if err != nil { + return err + } + reply := TransferMsg{ + ID: message.ID, + Key: systemBulkAttachKey, + Value: encoded, + Type: MSG_SYS_REPLY, + } + frame, err := encodeDirectSignalFrame(stario.NewQueue(), s.sequenceEn, msgEn, client.secretKeySnapshot(), reply) + if err != nil { + return err + } + return withRawConnWriteLockDeadline(conn, writeDeadlineFromTimeout(client.maxWriteTimeoutSnapshot()), func(conn net.Conn) error { + return writeFullToConnUnlocked(conn, frame) + }) +} + func (s *ServerCommon) replyDedicatedBulkAttach(client *LogicalConn, message Message, resp bulkAttachResponse) error { if s == nil || client == nil { return errBulkServerNil diff --git a/bulk_dedicated_attach_test.go b/bulk_dedicated_attach_test.go new file mode 100644 index 0000000..80f01ad --- /dev/null +++ b/bulk_dedicated_attach_test.go @@ -0,0 +1,217 @@ +package notify + +import ( + "bytes" + "context" + "encoding/binary" + "net" + "testing" + "time" + + "b612.me/stario" +) + +type bulkAttachScriptConn struct { + readBuf *bytes.Reader + writeBuf bytes.Buffer +} + +func newBulkAttachScriptConn(inbound []byte) *bulkAttachScriptConn { + return &bulkAttachScriptConn{ + readBuf: bytes.NewReader(append([]byte(nil), inbound...)), + } +} + +func (c *bulkAttachScriptConn) Read(p []byte) (int, error) { return c.readBuf.Read(p) } +func (c *bulkAttachScriptConn) Write(p []byte) (int, error) { return c.writeBuf.Write(p) } +func (c *bulkAttachScriptConn) Close() error { return nil } +func (c *bulkAttachScriptConn) LocalAddr() net.Addr { return bulkAttachTestAddr("local") } +func (c *bulkAttachScriptConn) RemoteAddr() net.Addr { return bulkAttachTestAddr("remote") } +func (c *bulkAttachScriptConn) SetDeadline(time.Time) error { return nil } +func (c *bulkAttachScriptConn) SetReadDeadline(time.Time) error { + return nil +} +func (c *bulkAttachScriptConn) SetWriteDeadline(time.Time) error { + return nil +} + +func (c *bulkAttachScriptConn) writtenBytes() []byte { + return append([]byte(nil), c.writeBuf.Bytes()...) +} + +type bulkAttachTestAddr string + +func (a bulkAttachTestAddr) Network() string { return "tcp" } +func (a bulkAttachTestAddr) String() string { return string(a) } + +func encodeDedicatedRecordForAttachTest(payload []byte) []byte { + out := make([]byte, bulkDedicatedRecordHeaderLen+len(payload)) + copy(out[:4], bulkDedicatedRecordMagic) + binary.BigEndian.PutUint32(out[4:8], uint32(len(payload))) + copy(out[bulkDedicatedRecordHeaderLen:], payload) + return out +} + +func TestSendDedicatedBulkAttachRequestKeepsCoalescedDedicatedPayloadUnread(t *testing.T) { + client := NewClient().(*ClientCommon) + UseLegacySecurityClient(client) + client.msgID = 100 + + bulk := newBulkHandle(context.Background(), newBulkRuntime("dedicated-attach-test"), clientFileScope(), BulkOpenRequest{ + BulkID: "bulk-attach-test", + DataID: 1, + Dedicated: true, + AttachToken: "attach-token", + }, 0, nil, nil, 0, nil, nil, nil, nil, nil) + + encodedResp, err := client.sequenceEn(bulkAttachResponse{Accepted: true}) + if err != nil { + t.Fatalf("encode bulkAttachResponse failed: %v", err) + } + replyFrame, err := encodeDirectSignalFrame(stario.NewQueue(), client.sequenceEn, client.msgEn, client.SecretKey, TransferMsg{ + ID: 101, + Key: systemBulkAttachKey, + Value: encodedResp, + Type: MSG_SYS_REPLY, + }) + if err != nil { + t.Fatalf("encode attach reply frame failed: %v", err) + } + dedicatedPayload := []byte("dedicated-tail-bytes") + conn := newBulkAttachScriptConn(append(replyFrame, encodeDedicatedRecordForAttachTest(dedicatedPayload)...)) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + resp, err := client.sendDedicatedBulkAttachRequest(ctx, conn, bulk) + if err != nil { + t.Fatalf("sendDedicatedBulkAttachRequest failed: %v", err) + } + if !resp.Accepted { + t.Fatalf("bulk attach response = %+v, want accepted", resp) + } + + parsedReq := stario.NewQueue() + var reqMsg TransferMsg + if err := parsedReq.ParseMessageOwned(conn.writtenBytes(), "attach-request", func(msgq stario.MsgQueue) error { + transfer, err := decodeDirectSignalPayload(client.sequenceDe, client.msgDe, client.SecretKey, msgq.Msg) + if err != nil { + return err + } + reqMsg = transfer + return nil + }); err != nil { + t.Fatalf("parse written attach request failed: %v", err) + } + if reqMsg.Key != systemBulkAttachKey || reqMsg.Type != MSG_SYS_WAIT { + t.Fatalf("attach request message mismatch: %+v", reqMsg) + } + readPayload, err := readBulkDedicatedRecord(conn) + if err != nil { + t.Fatalf("readBulkDedicatedRecord after attach failed: %v", err) + } + if !bytes.Equal(readPayload, dedicatedPayload) { + t.Fatalf("dedicated payload mismatch: got %q want %q", string(readPayload), string(dedicatedPayload)) + } +} + +func TestHandleBulkAttachSystemMessageAcceptedWritesDirectReplyBeforeDedicatedHandoff(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + sidecarLeft, sidecarRight := net.Pipe() + defer sidecarRight.Close() + + current := server.bootstrapAcceptedLogical("dedicated-attach-current", nil, sidecarLeft) + if current == nil { + t.Fatal("bootstrapAcceptedLogical(current) should return logical") + } + target := server.bootstrapAcceptedLogical("dedicated-attach-target", nil, nil) + if target == nil { + t.Fatal("bootstrapAcceptedLogical(target) should return logical") + } + + bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{ + BulkID: "server-dedicated-attach-test", + DataID: 7, + Dedicated: true, + AttachToken: "attach-token", + }, 0, target, nil, 0, nil, nil, nil, nil, nil) + if err := server.getBulkRuntime().register(serverFileScope(target), bulk); err != nil { + t.Fatalf("register bulk runtime failed: %v", err) + } + + reqPayload, err := server.sequenceEn(bulkAttachRequest{ + PeerID: target.ID(), + BulkID: bulk.ID(), + AttachToken: "attach-token", + }) + if err != nil { + t.Fatalf("encode bulkAttachRequest failed: %v", err) + } + msg := Message{ + NetType: NET_SERVER, + LogicalConn: current, + ClientConn: current.compatClientConn(), + TransferMsg: TransferMsg{ + ID: 42, + Key: systemBulkAttachKey, + Value: reqPayload, + Type: MSG_SYS_WAIT, + }, + inboundConn: sidecarLeft, + Time: time.Now(), + } + + type attachReplyResult struct { + transfer TransferMsg + resp bulkAttachResponse + err error + } + replyCh := make(chan attachReplyResult, 1) + go func() { + _ = sidecarRight.SetReadDeadline(time.Now().Add(time.Second)) + replyPayload, err := readDirectSignalFramePayload(sidecarRight) + if err != nil { + replyCh <- attachReplyResult{err: err} + return + } + transfer, err := decodeDirectSignalPayload(server.sequenceDe, current.msgDeSnapshot(), current.secretKeySnapshot(), replyPayload) + if err != nil { + replyCh <- attachReplyResult{err: err} + return + } + resp, err := decodeBulkAttachResponse(server.sequenceDe, transfer.Value) + replyCh <- attachReplyResult{transfer: transfer, resp: resp, err: err} + }() + + if !server.handleBulkAttachSystemMessage(msg) { + t.Fatal("handleBulkAttachSystemMessage should accept dedicated attach message") + } + + var result attachReplyResult + select { + case result = <-replyCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for direct attach reply") + } + if result.err != nil { + t.Fatalf("read direct attach reply failed: %v", result.err) + } + transfer := result.transfer + if transfer.ID != msg.ID || transfer.Key != systemBulkAttachKey || transfer.Type != MSG_SYS_REPLY { + t.Fatalf("attach reply mismatch: %+v", transfer) + } + resp := result.resp + if !resp.Accepted || resp.Error != "" { + t.Fatalf("bulk attach response = %+v, want accepted", resp) + } + if got := bulk.dedicatedConnSnapshot(); got != sidecarLeft { + t.Fatalf("dedicated conn mismatch: got %v want %v", got, sidecarLeft) + } + if current.transportAttachedSnapshot() { + t.Fatal("attach sidecar logical transport should be detached after handoff") + } + if got := server.GetLogicalConn(current.ID()); got != nil { + t.Fatalf("attach sidecar logical should be removed after handoff, got %+v", got) + } +} diff --git a/bulk_dedicated_batch.go b/bulk_dedicated_batch.go index 6d3b21d..50490c0 100644 --- a/bulk_dedicated_batch.go +++ b/bulk_dedicated_batch.go @@ -165,8 +165,7 @@ func (s *bulkDedicatedSender) submitWriteBatch(ctx context.Context, items []bulk if submitted, err := s.tryDirectSubmitBatch(ctx, items); submitted { return err } - queuedItems := make([]bulkDedicatedSendRequest, len(items)) - copy(queuedItems, items) + queuedItems := cloneBulkDedicatedSendRequests(items) return s.submitBatch(ctx, queuedItems, true) } @@ -458,6 +457,20 @@ func (s *bulkDedicatedSender) stoppedErr() error { return errTransportDetached } +func cloneBulkDedicatedSendRequests(items []bulkDedicatedSendRequest) []bulkDedicatedSendRequest { + if len(items) == 0 { + return nil + } + cloned := make([]bulkDedicatedSendRequest, len(items)) + for i, item := range items { + cloned[i] = item + if len(item.Payload) > 0 { + cloned[i].Payload = append([]byte(nil), item.Payload...) + } + } + return cloned +} + func bulkDedicatedSendRequestLen(req bulkDedicatedSendRequest) int { return bulkDedicatedSendRequestLenFromPayloadLen(len(req.Payload)) } diff --git a/bulk_dedicated_batch_test.go b/bulk_dedicated_batch_test.go index a09ca36..bfc8ffb 100644 --- a/bulk_dedicated_batch_test.go +++ b/bulk_dedicated_batch_test.go @@ -1,6 +1,7 @@ package notify import ( + "bytes" "context" "errors" "net" @@ -8,6 +9,38 @@ import ( "time" ) +func TestCloneBulkDedicatedSendRequestsDeepCopiesPayload(t *testing.T) { + src := []bulkDedicatedSendRequest{ + { + Type: bulkFastPayloadTypeData, + Seq: 1, + Payload: []byte("payload-a"), + }, + { + Type: bulkFastPayloadTypeReset, + Seq: 2, + Payload: []byte("payload-b"), + }, + } + cloned := cloneBulkDedicatedSendRequests(src) + if len(cloned) != len(src) { + t.Fatalf("clone length = %d, want %d", len(cloned), len(src)) + } + if &cloned[0] == &src[0] { + t.Fatal("request clone should not alias source slice elements") + } + if len(cloned[0].Payload) == 0 || len(src[0].Payload) == 0 { + t.Fatal("payload should not be empty") + } + if &cloned[0].Payload[0] == &src[0].Payload[0] { + t.Fatal("payload clone should not alias source bytes") + } + src[0].Payload[0] = 'X' + if bytes.Equal(cloned[0].Payload, src[0].Payload) { + t.Fatal("mutating source payload should not affect cloned payload") + } +} + func TestBulkDedicatedBatchPlainRoundTrip(t *testing.T) { releasePayload, err := encodeBulkDedicatedReleasePayload(4096, 2) if err != nil { diff --git a/bulk_dedicated_record_test.go b/bulk_dedicated_record_test.go new file mode 100644 index 0000000..9ed334b --- /dev/null +++ b/bulk_dedicated_record_test.go @@ -0,0 +1,66 @@ +package notify + +import ( + "bytes" + "encoding/binary" + "io" + "net" + "testing" + "time" +) + +type shortWriteBulkRecordConn struct { + maxPerWrite int + buf bytes.Buffer +} + +func (c *shortWriteBulkRecordConn) Read([]byte) (int, error) { return 0, io.EOF } + +func (c *shortWriteBulkRecordConn) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + n := c.maxPerWrite + if n <= 0 || n > len(p) { + n = len(p) + } + _, _ = c.buf.Write(p[:n]) + return n, nil +} + +func (c *shortWriteBulkRecordConn) Close() error { return nil } +func (c *shortWriteBulkRecordConn) LocalAddr() net.Addr { return shortWriteBulkRecordAddr("local") } +func (c *shortWriteBulkRecordConn) RemoteAddr() net.Addr { return shortWriteBulkRecordAddr("remote") } +func (c *shortWriteBulkRecordConn) SetDeadline(time.Time) error { return nil } +func (c *shortWriteBulkRecordConn) SetReadDeadline(time.Time) error { + return nil +} +func (c *shortWriteBulkRecordConn) SetWriteDeadline(time.Time) error { + return nil +} + +type shortWriteBulkRecordAddr string + +func (a shortWriteBulkRecordAddr) Network() string { return "tcp" } +func (a shortWriteBulkRecordAddr) String() string { return string(a) } + +func TestWriteBulkDedicatedRecordWithDeadlineHandlesShortWrite(t *testing.T) { + conn := &shortWriteBulkRecordConn{maxPerWrite: 3} + payload := []byte("abcdefghijklmnopqrstuvwxyz") + if err := writeBulkDedicatedRecordWithDeadline(conn, payload, time.Time{}); err != nil { + t.Fatalf("writeBulkDedicatedRecordWithDeadline failed: %v", err) + } + raw := conn.buf.Bytes() + if got, want := len(raw), bulkDedicatedRecordHeaderLen+len(payload); got != want { + t.Fatalf("record length = %d, want %d", got, want) + } + if got := string(raw[:4]); got != bulkDedicatedRecordMagic { + t.Fatalf("record magic = %q, want %q", got, bulkDedicatedRecordMagic) + } + if got, want := int(binary.BigEndian.Uint32(raw[4:8])), len(payload); got != want { + t.Fatalf("record payload length = %d, want %d", got, want) + } + if got := raw[bulkDedicatedRecordHeaderLen:]; !bytes.Equal(got, payload) { + t.Fatalf("record payload mismatch") + } +} diff --git a/client_conn_session_test.go b/client_conn_session_test.go index 643f5fe..c083641 100644 --- a/client_conn_session_test.go +++ b/client_conn_session_test.go @@ -332,6 +332,24 @@ func TestLogicalDetachTransportForTransferKeepsHandoffConnAlive(t *testing.T) { } } +func TestLogicalHandleTUTransportReadResultWithSessionDropsDataAfterTransportStop(t *testing.T) { + server := NewServer().(*ServerCommon) + left, right := net.Pipe() + defer right.Close() + + stopCtx, stopFn := context.WithCancel(context.Background()) + logical, _, _ := newRegisteredServerLogicalForTest(t, server, "logical-stop-read-drop", left, stopCtx, stopFn) + if logical == nil { + t.Fatal("logical should not be nil") + } + generation := logical.transportGenerationSnapshot() + stopFn() + + if logical.handleTUTransportReadResultWithSession(stopCtx, left, generation, len([]byte("late-data")), []byte("late-data"), nil) { + t.Fatal("handleTUTransportReadResultWithSession should stop after transport stop") + } +} + func TestClientConnTransportBindingSnapshotUsesRuntimeBinding(t *testing.T) { client := &ClientConn{} left, right := net.Pipe() diff --git a/client_conn_transport.go b/client_conn_transport.go index a33b7aa..fd967a0 100644 --- a/client_conn_transport.go +++ b/client_conn_transport.go @@ -69,6 +69,12 @@ func (c *LogicalConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []by } func (c *LogicalConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool { + if transportReadShouldStop(stopCtx) || !c.ownsTransportRead(conn, generation) { + if c.shouldCloseTransportOnStop(conn) { + _ = conn.Close() + } + return false + } if err == os.ErrDeadlineExceeded { if num != 0 { c.pushServerOwnedTransportMessage(data[:num], conn, generation) @@ -95,6 +101,30 @@ func (c *LogicalConn) handleTUTransportReadResultWithSession(stopCtx context.Con return true } +func transportReadShouldStop(stopCtx context.Context) bool { + select { + case <-sessionStopChan(stopCtx): + return true + default: + return false + } +} + +func (c *LogicalConn) ownsTransportRead(conn net.Conn, generation uint64) bool { + if c == nil { + return false + } + rt := c.clientConnSessionRuntimeSnapshot() + if rt == nil || !rt.transportAttached || rt.transportGeneration != generation { + return false + } + current := rt.tuConn + if rt.transport != nil && rt.transport.connSnapshot() != nil { + current = rt.transport.connSnapshot() + } + return current == conn +} + func (c *LogicalConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) { if c == nil || len(data) == 0 { return @@ -163,6 +193,12 @@ func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Cont if logical := c.LogicalConn(); logical != nil { return logical.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err) } + if transportReadShouldStop(stopCtx) || !c.ownsTransportRead(conn, generation) { + if c.shouldCloseClientConnTransportOnStop(conn) { + _ = conn.Close() + } + return false + } if err == os.ErrDeadlineExceeded { if num != 0 { c.pushServerOwnedTransportMessage(data[:num], conn, generation) @@ -189,6 +225,21 @@ func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Cont return true } +func (c *ClientConn) ownsTransportRead(conn net.Conn, generation uint64) bool { + if c == nil { + return false + } + rt := c.clientConnSessionRuntimeSnapshot() + if rt == nil || !rt.transportAttached || rt.transportGeneration != generation { + return false + } + current := rt.tuConn + if rt.transport != nil && rt.transport.connSnapshot() != nil { + current = rt.transport.connSnapshot() + } + return current == conn +} + func (c *ClientConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) { if logical := c.LogicalConn(); logical != nil { logical.pushServerOwnedTransportMessage(data, conn, generation) diff --git a/transport_write.go b/transport_write.go index 7471647..87e80c0 100644 --- a/transport_write.go +++ b/transport_write.go @@ -13,6 +13,14 @@ import ( var transportConnWriteLocks sync.Map var errTransportFrameQueueUnavailable = errors.New("transport frame queue is unavailable") +type vectoredBuffersWriter interface { + WriteBuffers(*net.Buffers) (int64, error) +} + +type vectoredConnUnwrapper interface { + UnwrapConn() net.Conn +} + func writeFullToConn(conn net.Conn, data []byte) error { if conn == nil { return net.ErrClosed @@ -26,8 +34,15 @@ func writeFullToConnUnlocked(conn net.Conn, data []byte) error { if conn == nil { return net.ErrClosed } + return writeFullToWriterUnlocked(conn, data) +} + +func writeFullToWriterUnlocked(writer io.Writer, data []byte) error { + if writer == nil { + return io.ErrClosedPipe + } for len(data) > 0 { - n, err := conn.Write(data) + n, err := writer.Write(data) if n > 0 { data = data[n:] } @@ -41,6 +56,69 @@ func writeFullToConnUnlocked(conn net.Conn, data []byte) error { return nil } +func writeNetBuffersFullUnlocked(conn net.Conn, buffers net.Buffers) error { + if conn == nil { + return net.ErrClosed + } + writer, writeFn := vectoredWriteStrategy(conn) + if writeFn == nil { + return writeRemainingBuffersUnlocked(conn, buffers) + } + n, err := writeFn(&buffers) + if err != nil { + return err + } + if len(buffers) == 0 { + return nil + } + if n == 0 { + return io.ErrNoProgress + } + return writeRemainingBuffersUnlocked(writer, buffers) +} + +func vectoredWriteStrategy(conn net.Conn) (io.Writer, func(*net.Buffers) (int64, error)) { + current := conn + for depth := 0; depth < 8 && current != nil; depth++ { + if writer, ok := current.(vectoredBuffersWriter); ok { + target := current + return target, writer.WriteBuffers + } + switch target := current.(type) { + case *net.TCPConn: + return target, func(bufs *net.Buffers) (int64, error) { + return bufs.WriteTo(target) + } + case *net.UnixConn: + return target, func(bufs *net.Buffers) (int64, error) { + return bufs.WriteTo(target) + } + } + unwrapper, ok := current.(vectoredConnUnwrapper) + if !ok { + break + } + next := unwrapper.UnwrapConn() + if next == nil || next == current { + break + } + current = next + } + return nil, nil +} + +func writeRemainingBuffersUnlocked(writer io.Writer, buffers net.Buffers) error { + for _, part := range buffers { + if len(part) == 0 { + continue + } + if err := writeFullToWriterUnlocked(writer, part); err != nil { + return err + } + } + return nil +} + func withRawConnWriteLock(conn net.Conn, fn func(net.Conn) error) error { return withRawConnWriteLockDeadline(conn, time.Time{}, fn) } diff --git a/transport_write_test.go b/transport_write_test.go index 4b56f78..0764185 100644 --- a/transport_write_test.go +++ b/transport_write_test.go @@ -2,8 +2,10 @@ package notify import ( "b612.me/stario" + "bytes" "context" "errors" + "io" "net" "sync" "sync/atomic" @@ -146,6 +148,113 @@ func (c *blockingPacketWriteConn) Write(p []byte) (int, error) { return len(p), nil } +type vectoredShortWriteConn struct { + steps []int64 + idx int + buf bytes.Buffer + writes int + writev int +} + +func (c *vectoredShortWriteConn) Read([]byte) (int, error) { return 0, io.EOF } +func (c *vectoredShortWriteConn) Write(p []byte) (int, error) { + c.writes++ + return c.buf.Write(p) +} +func (c *vectoredShortWriteConn) Close() error { return nil } +func (c *vectoredShortWriteConn) LocalAddr() net.Addr { return nil } +func (c *vectoredShortWriteConn) RemoteAddr() net.Addr { return nil } +func (c *vectoredShortWriteConn) SetDeadline(time.Time) error { return nil } +func (c *vectoredShortWriteConn) SetReadDeadline(time.Time) error { return nil } +func (c *vectoredShortWriteConn) SetWriteDeadline(time.Time) error { return nil } + +func (c *vectoredShortWriteConn) WriteBuffers(bufs *net.Buffers) (int64, error) { + c.writev++ + if c.idx >= len(c.steps) { + return 0, io.ErrNoProgress + } + remaining := c.steps[c.idx] + c.idx++ + written := int64(0) + for len(*bufs) > 0 && remaining > 0 { + part := (*bufs)[0] + if len(part) == 0 { + (*bufs)[0] = nil + *bufs = (*bufs)[1:] + continue + } + n := int64(len(part)) + if n > remaining { + n = remaining + } + _, _ = c.buf.Write(part[:n]) + written += n + remaining -= n + if n == int64(len(part)) { + (*bufs)[0] = nil + *bufs = (*bufs)[1:] + continue + } + (*bufs)[0] = part[n:] + break + } + return written, nil +} + +type unwrapVectoredConn struct { + inner net.Conn +} + +func (c *unwrapVectoredConn) Read(p []byte) (int, error) { return c.inner.Read(p) } +func (c *unwrapVectoredConn) Write(p []byte) (int, error) { return c.inner.Write(p) } +func (c *unwrapVectoredConn) Close() error { return c.inner.Close() } +func (c *unwrapVectoredConn) LocalAddr() net.Addr { return c.inner.LocalAddr() } +func (c *unwrapVectoredConn) RemoteAddr() net.Addr { return c.inner.RemoteAddr() } +func (c *unwrapVectoredConn) SetDeadline(t time.Time) error { return c.inner.SetDeadline(t) } +func (c *unwrapVectoredConn) SetReadDeadline(t time.Time) error { return c.inner.SetReadDeadline(t) } +func (c *unwrapVectoredConn) SetWriteDeadline(t time.Time) error { return c.inner.SetWriteDeadline(t) } +func (c *unwrapVectoredConn) UnwrapConn() net.Conn { return c.inner } + +func TestWriteNetBuffersFullUnlockedFallsBackToDirectWritesAfterFirstPartialVectoredWrite(t *testing.T) { + conn := &vectoredShortWriteConn{steps: []int64{3}} + header := []byte("head") + payload := []byte("payload") + if err := writeNetBuffersFullUnlocked(conn, net.Buffers{header, payload}); err != nil { + t.Fatalf("writeNetBuffersFullUnlocked failed: %v", err) + } + if got, want := conn.writev, 1; got != want { + t.Fatalf("vectored write calls = %d, want %d", got, want) + } + if got, want := conn.writes, 2; got != want { + t.Fatalf("fallback direct writes = %d, want %d", got, want) + } + if got, want := conn.buf.String(), "headpayload"; got != want { + t.Fatalf("written bytes = %q, want %q", got, want) + } +} + +func TestWriteNetBuffersFullUnlockedReturnsNoProgressWhenVectoredWriteDoesNotAdvance(t *testing.T) { + conn := &vectoredShortWriteConn{steps: []int64{0}} + err := writeNetBuffersFullUnlocked(conn, net.Buffers{[]byte("head"), []byte("payload")}) + if !errors.Is(err, io.ErrNoProgress) { + t.Fatalf("writeNetBuffersFullUnlocked error = %v, want %v", err, io.ErrNoProgress) + } +} + +func TestWriteNetBuffersFullUnlockedUsesUnwrappedVectoredConn(t *testing.T) { + inner := &vectoredShortWriteConn{steps: []int64{100}} + conn := &unwrapVectoredConn{inner: inner} + if err := writeNetBuffersFullUnlocked(conn, net.Buffers{[]byte("head"), []byte("payload")}); err != nil { + t.Fatalf("writeNetBuffersFullUnlocked failed: %v", err) + } + if got, want := inner.writev, 1; got != want { + t.Fatalf("unwrapped vectored write calls = %d, want %d", got, want) + } + if got := inner.writes; got != 0 { + t.Fatalf("unexpected fallback direct writes = %d, want 0", got) + } +} + func TestBulkBatchSenderSkipsQueuedCanceledRequest(t *testing.T) { conn := newBlockingPacketWriteConn() binding := newTransportBinding(conn, stario.NewQueue())