fix: 修复 dedicated bulk attach 竞态并优化 short write 补写路径

- 客户端 dedicated attach 回复改为精确读取单帧,避免 attach reply 与后续 NBR1 数据粘连后被误解析
  - 服务端 accepted attach 改为先 detach transport,再直接回 attach reply,随后立即切入 dedicated bulk read loop
  - transport 读循环在 stop 或 transport ownership 失效后不再继续上推已读数据,避免 handoff 后首包被旧 reader 吃掉
  - dedicated bulk record 写路径改为 full-write,消除 short write 导致的 invalid bulk fast payload
  - 优化 vectored write 补写策略:先尝试一次 writev,未写完时直接顺序补完剩余 buffers,减少重复 WriteTo 开销
  - 放宽 vectored write 能力识别,支持通过 UnwrapConn/WriteBuffers 命中 fast path
  - 修复 dedicated batch 排队路径 payload 复用问题,改为深拷贝 queued items
  - 补齐 dedicated attach、short write、payload clone、transport stop/handoff 等回归测试
This commit is contained in:
兔子 2026-04-16 17:27:48 +08:00
parent 7ed3dd5b37
commit 4f760f2807
Signed by: b612
GPG Key ID: 99DD2222B612B612
9 changed files with 690 additions and 47 deletions

View File

@ -3,11 +3,13 @@ package notify
import ( import (
"b612.me/notify/internal/transport" "b612.me/notify/internal/transport"
"b612.me/stario" "b612.me/stario"
"bytes"
"context" "context"
cryptorand "crypto/rand" cryptorand "crypto/rand"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
"sync/atomic" "sync/atomic"
@ -19,8 +21,17 @@ const (
bulkDedicatedRecordMagic = "NBR1" bulkDedicatedRecordMagic = "NBR1"
bulkDedicatedRecordHeaderLen = 8 bulkDedicatedRecordHeaderLen = 8
bulkDedicatedAttachTimeout = 5 * time.Second bulkDedicatedAttachTimeout = 5 * time.Second
bulkDedicatedAttachFrameMagicSize = 8
bulkDedicatedAttachFrameHeaderLen = 14
bulkDedicatedAttachFrameVersionOffset = 12
bulkDedicatedAttachFrameFlagsOffset = 13
bulkDedicatedAttachFrameVersionV1 = 1
bulkDedicatedAttachFrameFlagsNone = 0
) )
var bulkDedicatedAttachFrameMagic = [bulkDedicatedAttachFrameMagicSize]byte{11, 27, 19, 96, 12, 25, 2, 20}
type bulkAttachRequest struct { type bulkAttachRequest struct {
PeerID string PeerID string
BulkID string BulkID string
@ -121,6 +132,35 @@ func decodeDirectSignalPayload(sequenceDe func([]byte) (interface{}, error), msg
return unwrapTransferMsgEnvelope(env, sequenceDe) return unwrapTransferMsgEnvelope(env, sequenceDe)
} }
func readDirectSignalFramePayload(conn net.Conn) ([]byte, error) {
if conn == nil {
return nil, net.ErrClosed
}
var header [bulkDedicatedAttachFrameHeaderLen]byte
if _, err := io.ReadFull(conn, header[:]); err != nil {
return nil, err
}
if !bytes.Equal(header[:bulkDedicatedAttachFrameMagicSize], bulkDedicatedAttachFrameMagic[:]) {
return nil, stario.ErrQueueDataFormat
}
if got := header[bulkDedicatedAttachFrameVersionOffset]; got != bulkDedicatedAttachFrameVersionV1 {
return nil, stario.ErrQueueUnsupportedVersion
}
if got := header[bulkDedicatedAttachFrameFlagsOffset]; got != bulkDedicatedAttachFrameFlagsNone {
return nil, stario.ErrQueueUnsupportedFlags
}
length := binary.BigEndian.Uint32(header[bulkDedicatedAttachFrameMagicSize : bulkDedicatedAttachFrameMagicSize+4])
maxInt := int(^uint(0) >> 1)
if uint64(length) > uint64(maxInt) {
return nil, stario.ErrQueueMessageTooLarge
}
payload := make([]byte, int(length))
if _, err := io.ReadFull(conn, payload); err != nil {
return nil, err
}
return payload, nil
}
func writeBulkDedicatedRecord(conn net.Conn, payload []byte) error { func writeBulkDedicatedRecord(conn net.Conn, payload []byte) error {
return writeBulkDedicatedRecordWithDeadline(conn, payload, time.Time{}) return writeBulkDedicatedRecordWithDeadline(conn, payload, time.Time{})
} }
@ -133,9 +173,7 @@ func writeBulkDedicatedRecordWithDeadline(conn net.Conn, payload []byte, deadlin
var header [bulkDedicatedRecordHeaderLen]byte var header [bulkDedicatedRecordHeaderLen]byte
copy(header[:4], bulkDedicatedRecordMagic) copy(header[:4], bulkDedicatedRecordMagic)
binary.BigEndian.PutUint32(header[4:8], uint32(len(payload))) binary.BigEndian.PutUint32(header[4:8], uint32(len(payload)))
buffers := net.Buffers{header[:], payload} return writeNetBuffersFullUnlocked(conn, net.Buffers{header[:], payload})
_, err := buffers.WriteTo(conn)
return err
}) })
} }
@ -148,7 +186,7 @@ func readBulkDedicatedRecord(conn net.Conn) ([]byte, error) {
return nil, err return nil, err
} }
if string(header[:4]) != bulkDedicatedRecordMagic { if string(header[:4]) != bulkDedicatedRecordMagic {
return nil, errBulkFastPayloadInvalid return nil, fmt.Errorf("%w: record magic=%x", errBulkFastPayloadInvalid, header[:4])
} }
size := int(binary.BigEndian.Uint32(header[4:8])) size := int(binary.BigEndian.Uint32(header[4:8]))
if size < 0 { if size < 0 {
@ -224,51 +262,34 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn
if err != nil { if err != nil {
return bulkAttachResponse{}, err return bulkAttachResponse{}, err
} }
queue := stario.NewQueue()
msg := TransferMsg{ msg := TransferMsg{
ID: atomic.AddUint64(&c.msgID, 1), ID: atomic.AddUint64(&c.msgID, 1),
Key: systemBulkAttachKey, Key: systemBulkAttachKey,
Value: reqPayload, Value: reqPayload,
Type: MSG_SYS_WAIT, Type: MSG_SYS_WAIT,
} }
frame, err := encodeDirectSignalFrame(queue, c.sequenceEn, c.msgEn, c.SecretKey, msg) frame, err := encodeDirectSignalFrame(stario.NewQueue(), c.sequenceEn, c.msgEn, c.SecretKey, msg)
if err != nil { if err != nil {
return bulkAttachResponse{}, err return bulkAttachResponse{}, err
} }
if err := writeFullToConn(conn, frame); err != nil { if err := writeFullToConn(conn, frame); err != nil {
return bulkAttachResponse{}, err return bulkAttachResponse{}, err
} }
replyCh := make(chan Message, 1)
readBuf := streamReadBuffer()
for {
if deadline, ok := ctx.Deadline(); ok { if deadline, ok := ctx.Deadline(); ok {
_ = conn.SetReadDeadline(deadline) _ = conn.SetReadDeadline(deadline)
} }
n, err := conn.Read(readBuf) replyPayload, err := readDirectSignalFramePayload(conn)
if err != nil { if err != nil {
return bulkAttachResponse{}, err return bulkAttachResponse{}, err
} }
parseErr := queue.ParseMessageOwned(readBuf[:n], "bulk-attach", func(msgq stario.MsgQueue) error { transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, replyPayload)
transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, msgq.Msg)
if err != nil { if err != nil {
return err return bulkAttachResponse{}, err
}
replyCh <- Message{
ServerConn: c,
TransferMsg: transfer,
NetType: NET_CLIENT,
}
return nil
})
if parseErr != nil {
return bulkAttachResponse{}, parseErr
}
select {
case reply := <-replyCh:
return decodeBulkAttachResponse(c.sequenceDe, reply.Value)
default:
} }
if transfer.Key != systemBulkAttachKey || transfer.Type != MSG_SYS_REPLY || transfer.ID != msg.ID {
return bulkAttachResponse{}, errors.New("invalid bulk attach reply")
} }
return decodeBulkAttachResponse(c.sequenceDe, transfer.Value)
} }
func (c *ClientCommon) readDedicatedBulkLoop(bulk *bulkHandle, conn net.Conn) { func (c *ClientCommon) readDedicatedBulkLoop(bulk *bulkHandle, conn net.Conn) {
@ -323,14 +344,13 @@ func (s *ServerCommon) handleBulkAttachSystemMessage(message Message) bool {
} }
if err != nil { if err != nil {
resp.Error = err.Error() resp.Error = err.Error()
} else {
resp.Accepted = true
}
if current != nil { if current != nil {
_ = s.replyDedicatedBulkAttach(current, message, resp) _ = s.replyDedicatedBulkAttach(current, message, resp)
} }
if err == nil { return true
if attachErr := s.finishInboundDedicatedBulkAttach(current, logical, bulk); attachErr != nil { }
if current != nil {
if attachErr := s.finishInboundDedicatedBulkAttach(current, logical, bulk, message); attachErr != nil {
bulk.markReset(attachErr) bulk.markReset(attachErr)
} }
} }
@ -368,7 +388,7 @@ func (s *ServerCommon) resolveInboundDedicatedBulk(current *LogicalConn, req bul
return logical, bulk, nil return logical, bulk, nil
} }
func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, logical *LogicalConn, bulk *bulkHandle) error { func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, logical *LogicalConn, bulk *bulkHandle, message Message) error {
if current == nil || logical == nil || bulk == nil { if current == nil || logical == nil || bulk == nil {
return errBulkLogicalConnNil return errBulkLogicalConnNil
} }
@ -376,18 +396,56 @@ func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, lo
if err != nil { if err != nil {
return err return err
} }
if err := bulk.attachDedicatedConn(conn); err != nil { fail := func(reason string, err error) error {
if conn != nil { if conn != nil {
_ = conn.Close() _ = conn.Close()
} }
current.markSessionStopped(reason, err)
s.removeLogical(current)
return err return err
} }
if err := s.replyDedicatedBulkAttachDetached(current, conn, message, bulkAttachResponse{Accepted: true}); err != nil {
return fail("bulk dedicated attach reply failed", err)
}
if err := bulk.attachDedicatedConn(conn); err != nil {
return fail("bulk dedicated attach failed", err)
}
go s.readDedicatedBulkLoop(logical, bulk, conn) go s.readDedicatedBulkLoop(logical, bulk, conn)
current.markSessionStopped("bulk dedicated attach", nil) current.markSessionStopped("bulk dedicated attach", nil)
s.removeLogical(current) s.removeLogical(current)
return nil return nil
} }
func (s *ServerCommon) replyDedicatedBulkAttachDetached(client *LogicalConn, conn net.Conn, message Message, resp bulkAttachResponse) error {
if s == nil || client == nil {
return errBulkServerNil
}
if conn == nil {
return net.ErrClosed
}
msgEn := client.msgEnSnapshot()
if msgEn == nil {
return errTransportPayloadEncryptFailed
}
encoded, err := s.sequenceEn(resp)
if err != nil {
return err
}
reply := TransferMsg{
ID: message.ID,
Key: systemBulkAttachKey,
Value: encoded,
Type: MSG_SYS_REPLY,
}
frame, err := encodeDirectSignalFrame(stario.NewQueue(), s.sequenceEn, msgEn, client.secretKeySnapshot(), reply)
if err != nil {
return err
}
return withRawConnWriteLockDeadline(conn, writeDeadlineFromTimeout(client.maxWriteTimeoutSnapshot()), func(conn net.Conn) error {
return writeFullToConnUnlocked(conn, frame)
})
}
func (s *ServerCommon) replyDedicatedBulkAttach(client *LogicalConn, message Message, resp bulkAttachResponse) error { func (s *ServerCommon) replyDedicatedBulkAttach(client *LogicalConn, message Message, resp bulkAttachResponse) error {
if s == nil || client == nil { if s == nil || client == nil {
return errBulkServerNil return errBulkServerNil

View File

@ -0,0 +1,217 @@
package notify
import (
"bytes"
"context"
"encoding/binary"
"net"
"testing"
"time"
"b612.me/stario"
)
type bulkAttachScriptConn struct {
readBuf *bytes.Reader
writeBuf bytes.Buffer
}
func newBulkAttachScriptConn(inbound []byte) *bulkAttachScriptConn {
return &bulkAttachScriptConn{
readBuf: bytes.NewReader(append([]byte(nil), inbound...)),
}
}
func (c *bulkAttachScriptConn) Read(p []byte) (int, error) { return c.readBuf.Read(p) }
func (c *bulkAttachScriptConn) Write(p []byte) (int, error) { return c.writeBuf.Write(p) }
func (c *bulkAttachScriptConn) Close() error { return nil }
func (c *bulkAttachScriptConn) LocalAddr() net.Addr { return bulkAttachTestAddr("local") }
func (c *bulkAttachScriptConn) RemoteAddr() net.Addr { return bulkAttachTestAddr("remote") }
func (c *bulkAttachScriptConn) SetDeadline(time.Time) error { return nil }
func (c *bulkAttachScriptConn) SetReadDeadline(time.Time) error {
return nil
}
func (c *bulkAttachScriptConn) SetWriteDeadline(time.Time) error {
return nil
}
func (c *bulkAttachScriptConn) writtenBytes() []byte {
return append([]byte(nil), c.writeBuf.Bytes()...)
}
type bulkAttachTestAddr string
func (a bulkAttachTestAddr) Network() string { return "tcp" }
func (a bulkAttachTestAddr) String() string { return string(a) }
func encodeDedicatedRecordForAttachTest(payload []byte) []byte {
out := make([]byte, bulkDedicatedRecordHeaderLen+len(payload))
copy(out[:4], bulkDedicatedRecordMagic)
binary.BigEndian.PutUint32(out[4:8], uint32(len(payload)))
copy(out[bulkDedicatedRecordHeaderLen:], payload)
return out
}
func TestSendDedicatedBulkAttachRequestKeepsCoalescedDedicatedPayloadUnread(t *testing.T) {
client := NewClient().(*ClientCommon)
UseLegacySecurityClient(client)
client.msgID = 100
bulk := newBulkHandle(context.Background(), newBulkRuntime("dedicated-attach-test"), clientFileScope(), BulkOpenRequest{
BulkID: "bulk-attach-test",
DataID: 1,
Dedicated: true,
AttachToken: "attach-token",
}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
encodedResp, err := client.sequenceEn(bulkAttachResponse{Accepted: true})
if err != nil {
t.Fatalf("encode bulkAttachResponse failed: %v", err)
}
replyFrame, err := encodeDirectSignalFrame(stario.NewQueue(), client.sequenceEn, client.msgEn, client.SecretKey, TransferMsg{
ID: 101,
Key: systemBulkAttachKey,
Value: encodedResp,
Type: MSG_SYS_REPLY,
})
if err != nil {
t.Fatalf("encode attach reply frame failed: %v", err)
}
dedicatedPayload := []byte("dedicated-tail-bytes")
conn := newBulkAttachScriptConn(append(replyFrame, encodeDedicatedRecordForAttachTest(dedicatedPayload)...))
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
resp, err := client.sendDedicatedBulkAttachRequest(ctx, conn, bulk)
if err != nil {
t.Fatalf("sendDedicatedBulkAttachRequest failed: %v", err)
}
if !resp.Accepted {
t.Fatalf("bulk attach response = %+v, want accepted", resp)
}
parsedReq := stario.NewQueue()
var reqMsg TransferMsg
if err := parsedReq.ParseMessageOwned(conn.writtenBytes(), "attach-request", func(msgq stario.MsgQueue) error {
transfer, err := decodeDirectSignalPayload(client.sequenceDe, client.msgDe, client.SecretKey, msgq.Msg)
if err != nil {
return err
}
reqMsg = transfer
return nil
}); err != nil {
t.Fatalf("parse written attach request failed: %v", err)
}
if reqMsg.Key != systemBulkAttachKey || reqMsg.Type != MSG_SYS_WAIT {
t.Fatalf("attach request message mismatch: %+v", reqMsg)
}
readPayload, err := readBulkDedicatedRecord(conn)
if err != nil {
t.Fatalf("readBulkDedicatedRecord after attach failed: %v", err)
}
if !bytes.Equal(readPayload, dedicatedPayload) {
t.Fatalf("dedicated payload mismatch: got %q want %q", string(readPayload), string(dedicatedPayload))
}
}
func TestHandleBulkAttachSystemMessageAcceptedWritesDirectReplyBeforeDedicatedHandoff(t *testing.T) {
server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server)
sidecarLeft, sidecarRight := net.Pipe()
defer sidecarRight.Close()
current := server.bootstrapAcceptedLogical("dedicated-attach-current", nil, sidecarLeft)
if current == nil {
t.Fatal("bootstrapAcceptedLogical(current) should return logical")
}
target := server.bootstrapAcceptedLogical("dedicated-attach-target", nil, nil)
if target == nil {
t.Fatal("bootstrapAcceptedLogical(target) should return logical")
}
bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{
BulkID: "server-dedicated-attach-test",
DataID: 7,
Dedicated: true,
AttachToken: "attach-token",
}, 0, target, nil, 0, nil, nil, nil, nil, nil)
if err := server.getBulkRuntime().register(serverFileScope(target), bulk); err != nil {
t.Fatalf("register bulk runtime failed: %v", err)
}
reqPayload, err := server.sequenceEn(bulkAttachRequest{
PeerID: target.ID(),
BulkID: bulk.ID(),
AttachToken: "attach-token",
})
if err != nil {
t.Fatalf("encode bulkAttachRequest failed: %v", err)
}
msg := Message{
NetType: NET_SERVER,
LogicalConn: current,
ClientConn: current.compatClientConn(),
TransferMsg: TransferMsg{
ID: 42,
Key: systemBulkAttachKey,
Value: reqPayload,
Type: MSG_SYS_WAIT,
},
inboundConn: sidecarLeft,
Time: time.Now(),
}
type attachReplyResult struct {
transfer TransferMsg
resp bulkAttachResponse
err error
}
replyCh := make(chan attachReplyResult, 1)
go func() {
_ = sidecarRight.SetReadDeadline(time.Now().Add(time.Second))
replyPayload, err := readDirectSignalFramePayload(sidecarRight)
if err != nil {
replyCh <- attachReplyResult{err: err}
return
}
transfer, err := decodeDirectSignalPayload(server.sequenceDe, current.msgDeSnapshot(), current.secretKeySnapshot(), replyPayload)
if err != nil {
replyCh <- attachReplyResult{err: err}
return
}
resp, err := decodeBulkAttachResponse(server.sequenceDe, transfer.Value)
replyCh <- attachReplyResult{transfer: transfer, resp: resp, err: err}
}()
if !server.handleBulkAttachSystemMessage(msg) {
t.Fatal("handleBulkAttachSystemMessage should accept dedicated attach message")
}
var result attachReplyResult
select {
case result = <-replyCh:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for direct attach reply")
}
if result.err != nil {
t.Fatalf("read direct attach reply failed: %v", result.err)
}
transfer := result.transfer
if transfer.ID != msg.ID || transfer.Key != systemBulkAttachKey || transfer.Type != MSG_SYS_REPLY {
t.Fatalf("attach reply mismatch: %+v", transfer)
}
resp := result.resp
if !resp.Accepted || resp.Error != "" {
t.Fatalf("bulk attach response = %+v, want accepted", resp)
}
if got := bulk.dedicatedConnSnapshot(); got != sidecarLeft {
t.Fatalf("dedicated conn mismatch: got %v want %v", got, sidecarLeft)
}
if current.transportAttachedSnapshot() {
t.Fatal("attach sidecar logical transport should be detached after handoff")
}
if got := server.GetLogicalConn(current.ID()); got != nil {
t.Fatalf("attach sidecar logical should be removed after handoff, got %+v", got)
}
}

View File

@ -165,8 +165,7 @@ func (s *bulkDedicatedSender) submitWriteBatch(ctx context.Context, items []bulk
if submitted, err := s.tryDirectSubmitBatch(ctx, items); submitted { if submitted, err := s.tryDirectSubmitBatch(ctx, items); submitted {
return err return err
} }
queuedItems := make([]bulkDedicatedSendRequest, len(items)) queuedItems := cloneBulkDedicatedSendRequests(items)
copy(queuedItems, items)
return s.submitBatch(ctx, queuedItems, true) return s.submitBatch(ctx, queuedItems, true)
} }
@ -458,6 +457,20 @@ func (s *bulkDedicatedSender) stoppedErr() error {
return errTransportDetached return errTransportDetached
} }
func cloneBulkDedicatedSendRequests(items []bulkDedicatedSendRequest) []bulkDedicatedSendRequest {
if len(items) == 0 {
return nil
}
cloned := make([]bulkDedicatedSendRequest, len(items))
for i, item := range items {
cloned[i] = item
if len(item.Payload) > 0 {
cloned[i].Payload = append([]byte(nil), item.Payload...)
}
}
return cloned
}
func bulkDedicatedSendRequestLen(req bulkDedicatedSendRequest) int { func bulkDedicatedSendRequestLen(req bulkDedicatedSendRequest) int {
return bulkDedicatedSendRequestLenFromPayloadLen(len(req.Payload)) return bulkDedicatedSendRequestLenFromPayloadLen(len(req.Payload))
} }

View File

@ -1,6 +1,7 @@
package notify package notify
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"net" "net"
@ -8,6 +9,38 @@ import (
"time" "time"
) )
func TestCloneBulkDedicatedSendRequestsDeepCopiesPayload(t *testing.T) {
src := []bulkDedicatedSendRequest{
{
Type: bulkFastPayloadTypeData,
Seq: 1,
Payload: []byte("payload-a"),
},
{
Type: bulkFastPayloadTypeReset,
Seq: 2,
Payload: []byte("payload-b"),
},
}
cloned := cloneBulkDedicatedSendRequests(src)
if len(cloned) != len(src) {
t.Fatalf("clone length = %d, want %d", len(cloned), len(src))
}
if &cloned[0] == &src[0] {
t.Fatal("request clone should not alias source slice elements")
}
if len(cloned[0].Payload) == 0 || len(src[0].Payload) == 0 {
t.Fatal("payload should not be empty")
}
if &cloned[0].Payload[0] == &src[0].Payload[0] {
t.Fatal("payload clone should not alias source bytes")
}
src[0].Payload[0] = 'X'
if bytes.Equal(cloned[0].Payload, src[0].Payload) {
t.Fatal("mutating source payload should not affect cloned payload")
}
}
func TestBulkDedicatedBatchPlainRoundTrip(t *testing.T) { func TestBulkDedicatedBatchPlainRoundTrip(t *testing.T) {
releasePayload, err := encodeBulkDedicatedReleasePayload(4096, 2) releasePayload, err := encodeBulkDedicatedReleasePayload(4096, 2)
if err != nil { if err != nil {

View File

@ -0,0 +1,66 @@
package notify
import (
"bytes"
"encoding/binary"
"io"
"net"
"testing"
"time"
)
type shortWriteBulkRecordConn struct {
maxPerWrite int
buf bytes.Buffer
}
func (c *shortWriteBulkRecordConn) Read([]byte) (int, error) { return 0, io.EOF }
func (c *shortWriteBulkRecordConn) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
n := c.maxPerWrite
if n <= 0 || n > len(p) {
n = len(p)
}
_, _ = c.buf.Write(p[:n])
return n, nil
}
func (c *shortWriteBulkRecordConn) Close() error { return nil }
func (c *shortWriteBulkRecordConn) LocalAddr() net.Addr { return shortWriteBulkRecordAddr("local") }
func (c *shortWriteBulkRecordConn) RemoteAddr() net.Addr { return shortWriteBulkRecordAddr("remote") }
func (c *shortWriteBulkRecordConn) SetDeadline(time.Time) error { return nil }
func (c *shortWriteBulkRecordConn) SetReadDeadline(time.Time) error {
return nil
}
func (c *shortWriteBulkRecordConn) SetWriteDeadline(time.Time) error {
return nil
}
type shortWriteBulkRecordAddr string
func (a shortWriteBulkRecordAddr) Network() string { return "tcp" }
func (a shortWriteBulkRecordAddr) String() string { return string(a) }
func TestWriteBulkDedicatedRecordWithDeadlineHandlesShortWrite(t *testing.T) {
conn := &shortWriteBulkRecordConn{maxPerWrite: 3}
payload := []byte("abcdefghijklmnopqrstuvwxyz")
if err := writeBulkDedicatedRecordWithDeadline(conn, payload, time.Time{}); err != nil {
t.Fatalf("writeBulkDedicatedRecordWithDeadline failed: %v", err)
}
raw := conn.buf.Bytes()
if got, want := len(raw), bulkDedicatedRecordHeaderLen+len(payload); got != want {
t.Fatalf("record length = %d, want %d", got, want)
}
if got := string(raw[:4]); got != bulkDedicatedRecordMagic {
t.Fatalf("record magic = %q, want %q", got, bulkDedicatedRecordMagic)
}
if got, want := int(binary.BigEndian.Uint32(raw[4:8])), len(payload); got != want {
t.Fatalf("record payload length = %d, want %d", got, want)
}
if got := raw[bulkDedicatedRecordHeaderLen:]; !bytes.Equal(got, payload) {
t.Fatalf("record payload mismatch")
}
}

View File

@ -332,6 +332,24 @@ func TestLogicalDetachTransportForTransferKeepsHandoffConnAlive(t *testing.T) {
} }
} }
func TestLogicalHandleTUTransportReadResultWithSessionDropsDataAfterTransportStop(t *testing.T) {
server := NewServer().(*ServerCommon)
left, right := net.Pipe()
defer right.Close()
stopCtx, stopFn := context.WithCancel(context.Background())
logical, _, _ := newRegisteredServerLogicalForTest(t, server, "logical-stop-read-drop", left, stopCtx, stopFn)
if logical == nil {
t.Fatal("logical should not be nil")
}
generation := logical.transportGenerationSnapshot()
stopFn()
if logical.handleTUTransportReadResultWithSession(stopCtx, left, generation, len([]byte("late-data")), []byte("late-data"), nil) {
t.Fatal("handleTUTransportReadResultWithSession should stop after transport stop")
}
}
func TestClientConnTransportBindingSnapshotUsesRuntimeBinding(t *testing.T) { func TestClientConnTransportBindingSnapshotUsesRuntimeBinding(t *testing.T) {
client := &ClientConn{} client := &ClientConn{}
left, right := net.Pipe() left, right := net.Pipe()

View File

@ -69,6 +69,12 @@ func (c *LogicalConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []by
} }
func (c *LogicalConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool { func (c *LogicalConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool {
if transportReadShouldStop(stopCtx) || !c.ownsTransportRead(conn, generation) {
if c.shouldCloseTransportOnStop(conn) {
_ = conn.Close()
}
return false
}
if err == os.ErrDeadlineExceeded { if err == os.ErrDeadlineExceeded {
if num != 0 { if num != 0 {
c.pushServerOwnedTransportMessage(data[:num], conn, generation) c.pushServerOwnedTransportMessage(data[:num], conn, generation)
@ -95,6 +101,30 @@ func (c *LogicalConn) handleTUTransportReadResultWithSession(stopCtx context.Con
return true return true
} }
func transportReadShouldStop(stopCtx context.Context) bool {
select {
case <-sessionStopChan(stopCtx):
return true
default:
return false
}
}
func (c *LogicalConn) ownsTransportRead(conn net.Conn, generation uint64) bool {
if c == nil {
return false
}
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil || !rt.transportAttached || rt.transportGeneration != generation {
return false
}
current := rt.tuConn
if rt.transport != nil && rt.transport.connSnapshot() != nil {
current = rt.transport.connSnapshot()
}
return current == conn
}
func (c *LogicalConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) { func (c *LogicalConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) {
if c == nil || len(data) == 0 { if c == nil || len(data) == 0 {
return return
@ -163,6 +193,12 @@ func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Cont
if logical := c.LogicalConn(); logical != nil { if logical := c.LogicalConn(); logical != nil {
return logical.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err) return logical.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err)
} }
if transportReadShouldStop(stopCtx) || !c.ownsTransportRead(conn, generation) {
if c.shouldCloseClientConnTransportOnStop(conn) {
_ = conn.Close()
}
return false
}
if err == os.ErrDeadlineExceeded { if err == os.ErrDeadlineExceeded {
if num != 0 { if num != 0 {
c.pushServerOwnedTransportMessage(data[:num], conn, generation) c.pushServerOwnedTransportMessage(data[:num], conn, generation)
@ -189,6 +225,21 @@ func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Cont
return true return true
} }
func (c *ClientConn) ownsTransportRead(conn net.Conn, generation uint64) bool {
if c == nil {
return false
}
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil || !rt.transportAttached || rt.transportGeneration != generation {
return false
}
current := rt.tuConn
if rt.transport != nil && rt.transport.connSnapshot() != nil {
current = rt.transport.connSnapshot()
}
return current == conn
}
func (c *ClientConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) { func (c *ClientConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) {
if logical := c.LogicalConn(); logical != nil { if logical := c.LogicalConn(); logical != nil {
logical.pushServerOwnedTransportMessage(data, conn, generation) logical.pushServerOwnedTransportMessage(data, conn, generation)

View File

@ -13,6 +13,14 @@ import (
var transportConnWriteLocks sync.Map var transportConnWriteLocks sync.Map
var errTransportFrameQueueUnavailable = errors.New("transport frame queue is unavailable") var errTransportFrameQueueUnavailable = errors.New("transport frame queue is unavailable")
type vectoredBuffersWriter interface {
WriteBuffers(*net.Buffers) (int64, error)
}
type vectoredConnUnwrapper interface {
UnwrapConn() net.Conn
}
func writeFullToConn(conn net.Conn, data []byte) error { func writeFullToConn(conn net.Conn, data []byte) error {
if conn == nil { if conn == nil {
return net.ErrClosed return net.ErrClosed
@ -26,8 +34,15 @@ func writeFullToConnUnlocked(conn net.Conn, data []byte) error {
if conn == nil { if conn == nil {
return net.ErrClosed return net.ErrClosed
} }
return writeFullToWriterUnlocked(conn, data)
}
func writeFullToWriterUnlocked(writer io.Writer, data []byte) error {
if writer == nil {
return io.ErrClosedPipe
}
for len(data) > 0 { for len(data) > 0 {
n, err := conn.Write(data) n, err := writer.Write(data)
if n > 0 { if n > 0 {
data = data[n:] data = data[n:]
} }
@ -41,6 +56,69 @@ func writeFullToConnUnlocked(conn net.Conn, data []byte) error {
return nil return nil
} }
func writeNetBuffersFullUnlocked(conn net.Conn, buffers net.Buffers) error {
if conn == nil {
return net.ErrClosed
}
writer, writeFn := vectoredWriteStrategy(conn)
if writeFn == nil {
return writeRemainingBuffersUnlocked(conn, buffers)
}
n, err := writeFn(&buffers)
if err != nil {
return err
}
if len(buffers) == 0 {
return nil
}
if n == 0 {
return io.ErrNoProgress
}
return writeRemainingBuffersUnlocked(writer, buffers)
}
func vectoredWriteStrategy(conn net.Conn) (io.Writer, func(*net.Buffers) (int64, error)) {
current := conn
for depth := 0; depth < 8 && current != nil; depth++ {
if writer, ok := current.(vectoredBuffersWriter); ok {
target := current
return target, writer.WriteBuffers
}
switch target := current.(type) {
case *net.TCPConn:
return target, func(bufs *net.Buffers) (int64, error) {
return bufs.WriteTo(target)
}
case *net.UnixConn:
return target, func(bufs *net.Buffers) (int64, error) {
return bufs.WriteTo(target)
}
}
unwrapper, ok := current.(vectoredConnUnwrapper)
if !ok {
break
}
next := unwrapper.UnwrapConn()
if next == nil || next == current {
break
}
current = next
}
return nil, nil
}
func writeRemainingBuffersUnlocked(writer io.Writer, buffers net.Buffers) error {
for _, part := range buffers {
if len(part) == 0 {
continue
}
if err := writeFullToWriterUnlocked(writer, part); err != nil {
return err
}
}
return nil
}
func withRawConnWriteLock(conn net.Conn, fn func(net.Conn) error) error { func withRawConnWriteLock(conn net.Conn, fn func(net.Conn) error) error {
return withRawConnWriteLockDeadline(conn, time.Time{}, fn) return withRawConnWriteLockDeadline(conn, time.Time{}, fn)
} }

View File

@ -2,8 +2,10 @@ package notify
import ( import (
"b612.me/stario" "b612.me/stario"
"bytes"
"context" "context"
"errors" "errors"
"io"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -146,6 +148,113 @@ func (c *blockingPacketWriteConn) Write(p []byte) (int, error) {
return len(p), nil return len(p), nil
} }
type vectoredShortWriteConn struct {
steps []int64
idx int
buf bytes.Buffer
writes int
writev int
}
func (c *vectoredShortWriteConn) Read([]byte) (int, error) { return 0, io.EOF }
func (c *vectoredShortWriteConn) Write(p []byte) (int, error) {
c.writes++
return c.buf.Write(p)
}
func (c *vectoredShortWriteConn) Close() error { return nil }
func (c *vectoredShortWriteConn) LocalAddr() net.Addr { return nil }
func (c *vectoredShortWriteConn) RemoteAddr() net.Addr { return nil }
func (c *vectoredShortWriteConn) SetDeadline(time.Time) error { return nil }
func (c *vectoredShortWriteConn) SetReadDeadline(time.Time) error { return nil }
func (c *vectoredShortWriteConn) SetWriteDeadline(time.Time) error { return nil }
func (c *vectoredShortWriteConn) WriteBuffers(bufs *net.Buffers) (int64, error) {
c.writev++
if c.idx >= len(c.steps) {
return 0, io.ErrNoProgress
}
remaining := c.steps[c.idx]
c.idx++
written := int64(0)
for len(*bufs) > 0 && remaining > 0 {
part := (*bufs)[0]
if len(part) == 0 {
(*bufs)[0] = nil
*bufs = (*bufs)[1:]
continue
}
n := int64(len(part))
if n > remaining {
n = remaining
}
_, _ = c.buf.Write(part[:n])
written += n
remaining -= n
if n == int64(len(part)) {
(*bufs)[0] = nil
*bufs = (*bufs)[1:]
continue
}
(*bufs)[0] = part[n:]
break
}
return written, nil
}
type unwrapVectoredConn struct {
inner net.Conn
}
func (c *unwrapVectoredConn) Read(p []byte) (int, error) { return c.inner.Read(p) }
func (c *unwrapVectoredConn) Write(p []byte) (int, error) { return c.inner.Write(p) }
func (c *unwrapVectoredConn) Close() error { return c.inner.Close() }
func (c *unwrapVectoredConn) LocalAddr() net.Addr { return c.inner.LocalAddr() }
func (c *unwrapVectoredConn) RemoteAddr() net.Addr { return c.inner.RemoteAddr() }
func (c *unwrapVectoredConn) SetDeadline(t time.Time) error { return c.inner.SetDeadline(t) }
func (c *unwrapVectoredConn) SetReadDeadline(t time.Time) error { return c.inner.SetReadDeadline(t) }
func (c *unwrapVectoredConn) SetWriteDeadline(t time.Time) error { return c.inner.SetWriteDeadline(t) }
func (c *unwrapVectoredConn) UnwrapConn() net.Conn { return c.inner }
func TestWriteNetBuffersFullUnlockedFallsBackToDirectWritesAfterFirstPartialVectoredWrite(t *testing.T) {
conn := &vectoredShortWriteConn{steps: []int64{3}}
header := []byte("head")
payload := []byte("payload")
if err := writeNetBuffersFullUnlocked(conn, net.Buffers{header, payload}); err != nil {
t.Fatalf("writeNetBuffersFullUnlocked failed: %v", err)
}
if got, want := conn.writev, 1; got != want {
t.Fatalf("vectored write calls = %d, want %d", got, want)
}
if got, want := conn.writes, 2; got != want {
t.Fatalf("fallback direct writes = %d, want %d", got, want)
}
if got, want := conn.buf.String(), "headpayload"; got != want {
t.Fatalf("written bytes = %q, want %q", got, want)
}
}
func TestWriteNetBuffersFullUnlockedReturnsNoProgressWhenVectoredWriteDoesNotAdvance(t *testing.T) {
conn := &vectoredShortWriteConn{steps: []int64{0}}
err := writeNetBuffersFullUnlocked(conn, net.Buffers{[]byte("head"), []byte("payload")})
if !errors.Is(err, io.ErrNoProgress) {
t.Fatalf("writeNetBuffersFullUnlocked error = %v, want %v", err, io.ErrNoProgress)
}
}
func TestWriteNetBuffersFullUnlockedUsesUnwrappedVectoredConn(t *testing.T) {
inner := &vectoredShortWriteConn{steps: []int64{100}}
conn := &unwrapVectoredConn{inner: inner}
if err := writeNetBuffersFullUnlocked(conn, net.Buffers{[]byte("head"), []byte("payload")}); err != nil {
t.Fatalf("writeNetBuffersFullUnlocked failed: %v", err)
}
if got, want := inner.writev, 1; got != want {
t.Fatalf("unwrapped vectored write calls = %d, want %d", got, want)
}
if got := inner.writes; got != 0 {
t.Fatalf("unexpected fallback direct writes = %d, want 0", got)
}
}
func TestBulkBatchSenderSkipsQueuedCanceledRequest(t *testing.T) { func TestBulkBatchSenderSkipsQueuedCanceledRequest(t *testing.T) {
conn := newBlockingPacketWriteConn() conn := newBlockingPacketWriteConn()
binding := newTransportBinding(conn, stario.NewQueue()) binding := newTransportBinding(conn, stario.NewQueue())