fix: close stream adaptive gaps and switch notify to stario v0.1.1

- make stream fast path honor adaptive soft payload limits end-to-end
  - split oversized fast-stream payloads into sequential frames before batching
  - use adaptive soft cap when encoding stream batch payloads
  - move timeout-like error detection into production code for adaptive tx
  - tune notify FrameReader read size explicitly to avoid throughput regression
  - drop local stario replace and depend on released b612.me/stario v0.1.1
This commit is contained in:
兔子 2026-04-18 16:05:57 +08:00
parent 4f760f2807
commit f038a89771
Signed by: b612
GPG Key ID: 99DD2222B612B612
76 changed files with 12656 additions and 906 deletions

16
benchmark_listen_test.go Normal file
View File

@ -0,0 +1,16 @@
package notify
import (
"os"
"testing"
)
// benchmarkTCPListenAddrEnv overrides the TCP listen address used by benchmarks.
const benchmarkTCPListenAddrEnv = "NOTIFY_BENCH_TCP_LISTEN_ADDR"

// benchmarkTCPListenAddr returns the address benchmark servers should listen
// on: the NOTIFY_BENCH_TCP_LISTEN_ADDR environment variable when set,
// otherwise an ephemeral loopback port.
func benchmarkTCPListenAddr(tb testing.TB) string {
	tb.Helper()
	addr := os.Getenv(benchmarkTCPListenAddrEnv)
	if addr == "" {
		return "127.0.0.1:0"
	}
	return addr
}

378
benchmark_tcp_proxy_test.go Normal file
View File

@ -0,0 +1,378 @@
package notify
import (
"context"
"errors"
"io"
"math/rand"
"net"
"os"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
)
// Environment variable names that configure the optional benchmark TCP proxy
// (enable flag, listen address, rate shaping, added delay, simulated loss and
// its penalty, RNG seed), plus the default listen address used when none is
// supplied.
const (
	benchmarkTCPProxyEnableEnv        = "NOTIFY_BENCH_TCP_PROXY_ENABLE"
	benchmarkTCPProxyListenAddrEnv    = "NOTIFY_BENCH_TCP_PROXY_LISTEN_ADDR"
	benchmarkTCPProxyRateMbitEnv      = "NOTIFY_BENCH_TCP_PROXY_RATE_MBIT"
	benchmarkTCPProxyDelayMsEnv       = "NOTIFY_BENCH_TCP_PROXY_DELAY_MS"
	benchmarkTCPProxyLossPctEnv       = "NOTIFY_BENCH_TCP_PROXY_LOSS_PCT"
	benchmarkTCPProxyLossPenaltyMsEnv = "NOTIFY_BENCH_TCP_PROXY_LOSS_PENALTY_MS"
	benchmarkTCPProxySeedEnv          = "NOTIFY_BENCH_TCP_PROXY_SEED"
	benchmarkTCPProxyDefaultListenAddr = "127.0.0.1:0"
)

// benchmarkTCPProxyDefaultLossPenalty is the stall applied to a "lost" chunk
// when loss simulation is enabled but no explicit penalty is configured.
const benchmarkTCPProxyDefaultLossPenalty = 120 * time.Millisecond
// benchmarkTCPProxyConfig holds the shaping parameters parsed from the
// NOTIFY_BENCH_TCP_PROXY_* environment. The zero value means "proxy disabled".
type benchmarkTCPProxyConfig struct {
	Enabled     bool          // proxy is active
	ListenAddr  string        // address the proxy listens on
	RateBytesPS float64       // throughput cap in bytes per second (0 = unlimited)
	Delay       time.Duration // fixed delay added to each relayed chunk
	LossPct     float64       // percentage in [0,100] of chunks that incur LossPenalty
	LossPenalty time.Duration // extra stall applied to a "lost" chunk
	Seed        int64         // base seed for the per-direction RNGs
}
// benchmarkTCPProxy is a traffic-shaping TCP proxy placed between a benchmark
// client and server. Goroutines are tracked by wg; stop() closes the listener
// once and waits for them.
type benchmarkTCPProxy struct {
	listener net.Listener
	target   string // address proxied connections dial
	cfg      benchmarkTCPProxyConfig
	closed   atomic.Bool    // set exactly once by stop()
	wg       sync.WaitGroup // accept loop plus per-connection handlers
}
// benchmarkTCPDialAddr returns the address a benchmark client should dial.
// When the benchmark TCP proxy is enabled via the environment, a shaping
// proxy is started in front of targetAddr and its address is returned;
// otherwise targetAddr is returned unchanged. The proxy is stopped via
// tb.Cleanup.
func benchmarkTCPDialAddr(tb testing.TB, targetAddr string) string {
	tb.Helper()
	if targetAddr == "" {
		return targetAddr
	}
	cfg := benchmarkTCPProxyConfigFromEnv(tb)
	if !cfg.Enabled {
		return targetAddr
	}
	proxy, startErr := startBenchmarkTCPProxy(targetAddr, cfg)
	if startErr != nil {
		tb.Fatalf("start benchmark tcp proxy failed: %v", startErr)
	}
	tb.Cleanup(proxy.stop)
	listener := proxy.listener
	if listener == nil || listener.Addr() == nil {
		tb.Fatalf("benchmark tcp proxy listener is nil")
	}
	return listener.Addr().String()
}
// benchmarkTCPProxyConfigFromEnv builds a proxy configuration from the
// NOTIFY_BENCH_TCP_PROXY_* environment variables. It returns a zero
// (disabled) config unless the enable flag parses as true; any malformed
// value aborts the benchmark via tb.Fatalf.
func benchmarkTCPProxyConfigFromEnv(tb testing.TB) benchmarkTCPProxyConfig {
	tb.Helper()
	enabled, present, err := parseBoolEnv(benchmarkTCPProxyEnableEnv)
	if err != nil {
		tb.Fatalf("invalid %s: %v", benchmarkTCPProxyEnableEnv, err)
	}
	if !present || !enabled {
		return benchmarkTCPProxyConfig{}
	}
	cfg := benchmarkTCPProxyConfig{
		Enabled:    true,
		ListenAddr: benchmarkTCPProxyDefaultListenAddr,
		Seed:       1,
	}
	if listen := os.Getenv(benchmarkTCPProxyListenAddrEnv); listen != "" {
		cfg.ListenAddr = listen
	}
	if rateMbit, ok, err := parseFloatEnv(benchmarkTCPProxyRateMbitEnv); err != nil {
		tb.Fatalf("invalid %s: %v", benchmarkTCPProxyRateMbitEnv, err)
	} else if ok && rateMbit > 0 {
		// Convert megabits/second to bytes/second.
		cfg.RateBytesPS = rateMbit * 1000 * 1000 / 8
	}
	if delayMs, ok, err := parseIntEnv(benchmarkTCPProxyDelayMsEnv); err != nil {
		tb.Fatalf("invalid %s: %v", benchmarkTCPProxyDelayMsEnv, err)
	} else if ok && delayMs > 0 {
		cfg.Delay = time.Duration(delayMs) * time.Millisecond
	}
	if lossPct, ok, err := parseFloatEnv(benchmarkTCPProxyLossPctEnv); err != nil {
		tb.Fatalf("invalid %s: %v", benchmarkTCPProxyLossPctEnv, err)
	} else if ok {
		if lossPct < 0 || lossPct > 100 {
			tb.Fatalf("%s must be in [0, 100], got %.4f", benchmarkTCPProxyLossPctEnv, lossPct)
		}
		cfg.LossPct = lossPct
	}
	if lossPenaltyMs, ok, err := parseIntEnv(benchmarkTCPProxyLossPenaltyMsEnv); err != nil {
		tb.Fatalf("invalid %s: %v", benchmarkTCPProxyLossPenaltyMsEnv, err)
	} else if ok && lossPenaltyMs > 0 {
		cfg.LossPenalty = time.Duration(lossPenaltyMs) * time.Millisecond
	}
	if cfg.LossPct > 0 && cfg.LossPenalty <= 0 {
		// Loss must stall the stream for some duration to be observable.
		cfg.LossPenalty = benchmarkTCPProxyDefaultLossPenalty
	}
	if seed, ok, err := parseInt64Env(benchmarkTCPProxySeedEnv); err != nil {
		tb.Fatalf("invalid %s: %v", benchmarkTCPProxySeedEnv, err)
	} else if ok {
		cfg.Seed = seed
	}
	return cfg
}
// parseBoolEnv reads the named environment variable as a boolean. ok reports
// whether the variable was set to a non-empty value; err carries any
// strconv.ParseBool failure.
func parseBoolEnv(name string) (value bool, ok bool, err error) {
	raw, found := os.LookupEnv(name)
	if !found || raw == "" {
		return false, false, nil
	}
	parsed, parseErr := strconv.ParseBool(raw)
	return parsed, true, parseErr
}
// parseFloatEnv reads the named environment variable as a float64. ok reports
// whether the variable was set to a non-empty value; err carries any
// strconv.ParseFloat failure.
func parseFloatEnv(name string) (float64, bool, error) {
	raw := os.Getenv(name)
	if raw == "" {
		return 0, false, nil
	}
	parsed, parseErr := strconv.ParseFloat(raw, 64)
	return parsed, true, parseErr
}
// parseIntEnv reads the named environment variable as an int. The boolean
// reports whether the variable was set to a non-empty value; the error
// carries any strconv.Atoi failure.
func parseIntEnv(name string) (int, bool, error) {
	raw := os.Getenv(name)
	if raw == "" {
		return 0, false, nil
	}
	parsed, parseErr := strconv.Atoi(raw)
	return parsed, true, parseErr
}
// parseInt64Env reads the named environment variable as a base-10 int64.
// The boolean reports whether the variable was set to a non-empty value; the
// error carries any strconv.ParseInt failure.
func parseInt64Env(name string) (int64, bool, error) {
	raw, found := os.LookupEnv(name)
	if !found || raw == "" {
		return 0, false, nil
	}
	parsed, parseErr := strconv.ParseInt(raw, 10, 64)
	return parsed, true, parseErr
}
// startBenchmarkTCPProxy listens on cfg.ListenAddr (falling back to the
// default loopback address) and begins relaying accepted connections to
// targetAddr with the configured shaping applied.
func startBenchmarkTCPProxy(targetAddr string, cfg benchmarkTCPProxyConfig) (*benchmarkTCPProxy, error) {
	if targetAddr == "" {
		return nil, errors.New("benchmark tcp proxy target addr is empty")
	}
	addr := cfg.ListenAddr
	if addr == "" {
		addr = benchmarkTCPProxyDefaultListenAddr
	}
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return nil, err
	}
	proxy := &benchmarkTCPProxy{
		listener: ln,
		target:   targetAddr,
		cfg:      cfg,
	}
	proxy.wg.Add(1)
	go proxy.runAcceptLoop()
	return proxy, nil
}
// runAcceptLoop accepts client connections until the listener is closed and
// spawns one handler goroutine per connection, each tracked by the wait group.
func (p *benchmarkTCPProxy) runAcceptLoop() {
	defer p.wg.Done()
	for {
		conn, acceptErr := p.listener.Accept()
		if acceptErr != nil {
			// Exit once the proxy is stopped; other accept errors are retried.
			if p.closed.Load() || errors.Is(acceptErr, net.ErrClosed) {
				return
			}
			continue
		}
		p.wg.Add(1)
		go func(c net.Conn) {
			defer p.wg.Done()
			p.handleConn(c)
		}(conn)
	}
}
// handleConn relays a single accepted client connection to the target in both
// directions, applying the configured shaping. It blocks until one relay
// direction finishes, then closes both connections and waits for the other
// relay goroutine so neither leaks.
func (p *benchmarkTCPProxy) handleConn(clientConn net.Conn) {
	if p == nil || clientConn == nil {
		return
	}
	targetConn, err := net.Dial("tcp", p.target)
	if err != nil {
		_ = clientConn.Close()
		return
	}
	// Disable Nagle on both legs so the proxy's shaping, not kernel
	// coalescing, dominates observed timing.
	if tcpConn, ok := clientConn.(*net.TCPConn); ok {
		_ = tcpConn.SetNoDelay(true)
	}
	if tcpConn, ok := targetConn.(*net.TCPConn); ok {
		_ = tcpConn.SetNoDelay(true)
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	errCh := make(chan error, 2)
	// Independent, deterministic RNG streams per direction for loss decisions.
	leftRng := rand.New(rand.NewSource(p.cfg.Seed + 101))
	rightRng := rand.New(rand.NewSource(p.cfg.Seed + 202))
	go func() {
		errCh <- benchmarkRelayStream(ctx, targetConn, clientConn, p.cfg, leftRng)
	}()
	go func() {
		errCh <- benchmarkRelayStream(ctx, clientConn, targetConn, p.cfg, rightRng)
	}()
	// First direction to finish (error or clean EOF) tears down the pair.
	_ = <-errCh
	cancel()
	_ = clientConn.Close()
	_ = targetConn.Close()
	// Wait for the second relay goroutine before returning.
	<-errCh
}
// stop closes the listener exactly once and waits for the accept loop and
// all connection handlers to exit. It is safe on a nil receiver and safe to
// call repeatedly.
func (p *benchmarkTCPProxy) stop() {
	if p == nil {
		return
	}
	if p.closed.CompareAndSwap(false, true) && p.listener != nil {
		_ = p.listener.Close()
	}
	p.wg.Wait()
}
// benchmarkRelayStream copies data from src to dst, holding each chunk until
// the scheduler's send time (rate limit plus fixed delay), with an extra
// randomized "loss" penalty. It returns nil on clean EOF from src, otherwise
// the first relay error encountered.
func benchmarkRelayStream(ctx context.Context, dst net.Conn, src net.Conn, cfg benchmarkTCPProxyConfig, rng *rand.Rand) error {
	if dst == nil || src == nil {
		return net.ErrClosed
	}
	// A dedicated reader goroutine decouples src reads from the shaped
	// (possibly sleeping) writes to dst.
	chunkCh := make(chan benchmarkRelayChunk, 128)
	readerErrCh := make(chan error, 1)
	go benchmarkRelayReader(ctx, src, chunkCh, readerErrCh)
	scheduler := benchmarkRelayScheduler{rateBytesPS: cfg.RateBytesPS}
	for chunk := range chunkCh {
		sendAt := scheduler.schedule(chunk.readAt, len(chunk.payload), cfg.Delay)
		// Simulated loss: delay the chunk by the penalty rather than dropping
		// it, since TCP would retransmit anyway.
		if cfg.LossPct > 0 && cfg.LossPenalty > 0 && rng != nil && rng.Float64()*100 < cfg.LossPct {
			sendAt = sendAt.Add(cfg.LossPenalty)
		}
		if waitErr := benchmarkSleepUntilContext(ctx, sendAt); waitErr != nil {
			return waitErr
		}
		if writeErr := writeAllToConn(dst, chunk.payload); writeErr != nil {
			return writeErr
		}
	}
	// chunkCh closed: surface the reader's terminal error, treating EOF as
	// a clean end of stream.
	readerErr := <-readerErrCh
	if errors.Is(readerErr, io.EOF) {
		return nil
	}
	return readerErr
}
// benchmarkRelayChunk is one read from the source connection together with
// the instant it was read, which anchors the shaping schedule.
type benchmarkRelayChunk struct {
	readAt  time.Time
	payload []byte // owned copy of the bytes read
}
// benchmarkRelayReader reads from src and forwards an owned copy of each
// chunk on chunkCh until a read error (including io.EOF) or context
// cancellation, then reports the terminal error on errCh. chunkCh is closed
// on return so the consuming relay loop terminates.
func benchmarkRelayReader(ctx context.Context, src net.Conn, chunkCh chan<- benchmarkRelayChunk, errCh chan<- error) {
	defer close(chunkCh)
	if src == nil {
		errCh <- net.ErrClosed
		return
	}
	buf := make([]byte, 64*1024)
	for {
		n, err := src.Read(buf)
		if n > 0 {
			chunk := benchmarkRelayChunk{
				readAt: time.Now(),
				// Copy out of buf so the next Read does not clobber queued data.
				payload: append([]byte(nil), buf[:n]...),
			}
			select {
			case <-ctx.Done():
				errCh <- ctx.Err()
				return
			case chunkCh <- chunk:
			}
		}
		// Per io.Reader convention the error is handled after consuming n.
		if err != nil {
			errCh <- err
			return
		}
	}
}
type benchmarkRelayScheduler struct {
rateBytesPS float64
nextAt time.Time
}
func (s *benchmarkRelayScheduler) schedule(readAt time.Time, bytes int, delay time.Duration) time.Time {
if s == nil {
return readAt
}
sendAt := readAt
if delay > 0 {
sendAt = sendAt.Add(delay)
}
if !s.nextAt.IsZero() && s.nextAt.After(sendAt) {
sendAt = s.nextAt
}
if s.rateBytesPS > 0 && bytes > 0 {
duration := time.Duration(float64(bytes) / s.rateBytesPS * float64(time.Second))
if duration <= 0 {
duration = time.Nanosecond
}
s.nextAt = sendAt.Add(duration)
} else {
s.nextAt = sendAt
}
return sendAt
}
func benchmarkSleepUntilContext(ctx context.Context, at time.Time) error {
wait := time.Until(at)
if wait <= 0 {
return nil
}
return benchmarkSleepContext(ctx, wait)
}
func benchmarkSleepContext(ctx context.Context, d time.Duration) error {
if d <= 0 {
return nil
}
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
func writeAllToConn(conn net.Conn, payload []byte) error {
for len(payload) > 0 {
n, err := conn.Write(payload)
if n > 0 {
payload = payload[n:]
}
if err != nil {
return err
}
if n == 0 {
return io.ErrNoProgress
}
}
return nil
}

1233
bulk.go

File diff suppressed because it is too large Load Diff

View File

@ -9,7 +9,9 @@ import (
)
const (
bulkBatchMaxPayloads = 16
bulkBatchMaxPayloads = 64
bulkBatchMaxPayloadBytes = bulkFastBatchMaxPlainBytes
bulkBatchMaxFlushDelay = 50 * time.Microsecond
)
const (
@ -22,48 +24,161 @@ type bulkBatchRequestState struct {
value atomic.Int32
}
type bulkBatchCodec struct {
encodeSingle func(bulkFastFrame) ([]byte, func(), error)
encodeBatch func([]bulkFastFrame) ([]byte, func(), error)
}
type bulkBatchRequest struct {
ctx context.Context
payload []byte
deadline time.Time
done chan error
state *bulkBatchRequestState
ctx context.Context
frames []bulkFastFrame
fastPathVersion uint8
payloadOwned bool
deadline time.Time
done chan error
state *bulkBatchRequestState
release func()
}
type bulkBatchEncodedPayload struct {
payload []byte
release func()
}
func (p *bulkBatchEncodedPayload) done() {
if p == nil || p.release == nil {
return
}
p.release()
p.release = nil
}
type bulkBatchSender struct {
binding *transportBinding
reqCh chan bulkBatchRequest
stopCh chan struct{}
doneCh chan struct{}
binding *transportBinding
codec bulkBatchCodec
writeTimeoutProvider func() time.Duration
reqCh chan bulkBatchRequest
stopCh chan struct{}
doneCh chan struct{}
stopOnce sync.Once
flushMu sync.Mutex
queued atomic.Int64
errMu sync.Mutex
err error
}
func newBulkBatchSender(binding *transportBinding) *bulkBatchSender {
func newBulkBatchSender(binding *transportBinding, codec bulkBatchCodec, writeTimeoutProvider func() time.Duration) *bulkBatchSender {
sender := &bulkBatchSender{
binding: binding,
reqCh: make(chan bulkBatchRequest, bulkBatchMaxPayloads*4),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
binding: binding,
codec: codec,
writeTimeoutProvider: writeTimeoutProvider,
reqCh: make(chan bulkBatchRequest, bulkBatchMaxPayloads*4),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
go sender.run()
return sender
}
func (s *bulkBatchSender) submit(ctx context.Context, payload []byte) error {
func (s *bulkBatchSender) submitData(ctx context.Context, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
return s.submitFramesOwned(ctx, []bulkFastFrame{{
Type: bulkFastPayloadTypeData,
DataID: dataID,
Seq: seq,
Payload: payload,
}}, fastPathVersion, false)
}
func (s *bulkBatchSender) submitControl(ctx context.Context, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
return s.submitFramesOwned(ctx, []bulkFastFrame{{
Type: frameType,
Flags: flags,
DataID: dataID,
Seq: seq,
Payload: payload,
}}, fastPathVersion, false)
}
func (s *bulkBatchSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, fastPathVersion uint8, payload []byte, chunkSize int, payloadOwned bool) (int, error) {
if s == nil {
return 0, errTransportDetached
}
if len(payload) == 0 {
return 0, nil
}
if chunkSize <= 0 {
chunkSize = defaultBulkChunkSize
}
written := 0
seq := startSeq
for written < len(payload) {
var batch [bulkFastBatchMaxItems]bulkFastFrame
frames := batch[:0]
batchBytes := bulkFastBatchHeaderLen
start := written
for written < len(payload) && len(frames) < bulkFastBatchMaxItems {
end := written + chunkSize
if end > len(payload) {
end = len(payload)
}
frame := bulkFastFrame{
Type: bulkFastPayloadTypeData,
DataID: dataID,
Seq: seq,
Payload: payload[written:end],
}
frameLen := bulkFastBatchFrameLen(frame)
if len(frames) > 0 && batchBytes+frameLen > bulkFastBatchMaxPlainBytes {
break
}
frames = append(frames, frame)
batchBytes += frameLen
seq++
written = end
}
if len(frames) == 0 {
end := written + chunkSize
if end > len(payload) {
end = len(payload)
}
frames = append(frames, bulkFastFrame{
Type: bulkFastPayloadTypeData,
DataID: dataID,
Seq: seq,
Payload: payload[written:end],
})
seq++
written = end
}
if err := s.submitFramesOwned(ctx, frames, fastPathVersion, payloadOwned); err != nil {
return start, err
}
}
return written, nil
}
func (s *bulkBatchSender) submitFrames(ctx context.Context, frames []bulkFastFrame, fastPathVersion uint8) error {
return s.submitFramesOwned(ctx, frames, fastPathVersion, false)
}
func (s *bulkBatchSender) submitFramesOwned(ctx context.Context, frames []bulkFastFrame, fastPathVersion uint8, payloadOwned bool) error {
if s == nil {
return errTransportDetached
}
if ctx == nil {
ctx = context.Background()
}
if len(frames) == 0 {
return nil
}
req := bulkBatchRequest{
ctx: ctx,
payload: payload,
done: make(chan error, 1),
state: &bulkBatchRequestState{},
ctx: ctx,
frames: frames,
fastPathVersion: normalizeBulkFastPathVersion(fastPathVersion),
payloadOwned: payloadOwned,
done: make(chan error, 1),
state: &bulkBatchRequestState{},
}
if deadline, ok := ctx.Deadline(); ok {
req.deadline = deadline
@ -71,10 +186,25 @@ func (s *bulkBatchSender) submit(ctx context.Context, payload []byte) error {
if err := s.errSnapshot(); err != nil {
return err
}
if s.shouldDirectSubmit(req) {
if submitted, err := s.tryDirectSubmit(req); submitted {
return err
}
}
req = cloneQueuedBulkBatchRequest(req)
s.queued.Add(1)
select {
case <-ctx.Done():
s.queued.Add(-1)
if req.release != nil {
req.release()
}
return normalizeStreamDeadlineError(ctx.Err())
case <-s.stopCh:
s.queued.Add(-1)
if req.release != nil {
req.release()
}
return s.stoppedErr()
case s.reqCh <- req:
}
@ -89,6 +219,55 @@ func (s *bulkBatchSender) submit(ctx context.Context, payload []byte) error {
}
}
func (s *bulkBatchSender) shouldDirectSubmit(req bulkBatchRequest) bool {
if len(req.frames) == 0 {
return false
}
return !bulkBatchRequestSupportsSharedSuperBatch(req)
}
func (s *bulkBatchSender) tryDirectSubmit(req bulkBatchRequest) (bool, error) {
if s == nil {
return true, errTransportDetached
}
if err := s.errSnapshot(); err != nil {
return true, err
}
select {
case <-req.ctx.Done():
return true, normalizeStreamDeadlineError(req.ctx.Err())
case <-s.stopCh:
return true, s.stoppedErr()
default:
}
if s.queued.Load() != 0 {
return false, nil
}
if !s.flushMu.TryLock() {
return false, nil
}
defer s.flushMu.Unlock()
if s.queued.Load() != 0 {
return false, nil
}
if err := s.errSnapshot(); err != nil {
return true, err
}
if !req.tryStart() {
return true, req.canceledErr()
}
if err := req.contextErr(); err != nil {
return true, err
}
err := s.flush([]bulkBatchRequest{req})
if err != nil {
s.setErr(err)
s.failPending(err)
return true, err
}
return true, nil
}
func (s *bulkBatchSender) run() {
defer close(s.doneCh)
for {
@ -97,34 +276,83 @@ func (s *bulkBatchSender) run() {
return
}
batch := []bulkBatchRequest{req}
batchBytes := bulkBatchRequestApproxBytes(req)
timer := (*time.Timer)(nil)
timerCh := (<-chan time.Time)(nil)
if bulkBatchShouldWaitForMore(batch, batchBytes) {
timer = time.NewTimer(bulkBatchMaxFlushDelay)
timerCh = timer.C
}
drain:
for len(batch) < bulkBatchMaxPayloads {
for len(batch) < bulkBatchMaxPayloads && batchBytes < bulkBatchMaxPayloadBytes {
if timerCh == nil {
select {
case <-s.stopCh:
s.failPending(s.stoppedErr())
return
case next := <-s.reqCh:
batch = append(batch, next)
batchBytes += bulkBatchRequestApproxBytes(next)
default:
break drain
}
continue
}
select {
case <-s.stopCh:
if timer != nil {
timer.Stop()
}
s.failPending(s.stoppedErr())
return
case next := <-s.reqCh:
batch = append(batch, next)
default:
batchBytes += bulkBatchRequestApproxBytes(next)
case <-timerCh:
timerCh = nil
break drain
}
}
active, payloads := activeBulkBatchRequests(batch)
if timer != nil {
if !timer.Stop() && timerCh != nil {
select {
case <-timer.C:
default:
}
}
}
s.flushMu.Lock()
err := s.errSnapshot()
active := make([]bulkBatchRequest, 0, len(batch))
for _, item := range batch {
if !item.tryStart() {
s.finishRequest(item, item.canceledErr())
continue
}
if itemErr := item.contextErr(); itemErr != nil {
s.finishRequest(item, itemErr)
continue
}
active = append(active, item)
}
if len(active) == 0 {
s.flushMu.Unlock()
continue
}
deadline := bulkBatchRequestsEarliestDeadline(active)
err := s.flush(payloads, deadline)
if err == nil {
err = s.flush(active)
}
s.flushMu.Unlock()
if err != nil {
s.setErr(err)
for _, item := range active {
item.done <- err
s.finishRequest(item, err)
}
s.failPending(err)
return
}
for _, item := range active {
item.done <- err
s.finishRequest(item, nil)
}
}
}
@ -139,37 +367,6 @@ func (s *bulkBatchSender) nextRequest() (bulkBatchRequest, bool) {
}
}
func activeBulkBatchRequests(batch []bulkBatchRequest) ([]bulkBatchRequest, [][]byte) {
active := make([]bulkBatchRequest, 0, len(batch))
payloads := make([][]byte, 0, len(batch))
for _, item := range batch {
if !item.tryStart() {
item.done <- item.canceledErr()
continue
}
if err := item.contextErr(); err != nil {
item.done <- err
continue
}
active = append(active, item)
payloads = append(payloads, item.payload)
}
return active, payloads
}
func bulkBatchRequestsEarliestDeadline(batch []bulkBatchRequest) time.Time {
var deadline time.Time
for _, item := range batch {
if item.deadline.IsZero() {
continue
}
if deadline.IsZero() || item.deadline.Before(deadline) {
deadline = item.deadline
}
}
return deadline
}
func (r bulkBatchRequest) contextErr() error {
if r.ctx == nil {
return nil
@ -203,7 +400,7 @@ func (r bulkBatchRequest) canceledErr() error {
return context.Canceled
}
func (s *bulkBatchSender) flush(payloads [][]byte, deadline time.Time) error {
func (s *bulkBatchSender) flush(requests []bulkBatchRequest) error {
if s == nil || s.binding == nil {
return errTransportDetached
}
@ -211,9 +408,196 @@ func (s *bulkBatchSender) flush(payloads [][]byte, deadline time.Time) error {
if queue == nil {
return errTransportFrameQueueUnavailable
}
return s.binding.withConnWriteLockDeadline(deadline, func(conn net.Conn) error {
return writeFramedPayloadBatchUnlocked(conn, queue, payloads)
})
payloads, err := s.encodeRequests(requests)
if err != nil {
return err
}
defer func() {
for index := range payloads {
payloads[index].done()
}
}()
writeTimeout := s.transportWriteTimeout()
for _, payload := range payloads {
frame := payload.payload
started := time.Now()
err := s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error {
return writeFramedPayloadUnlocked(conn, queue, frame)
})
s.binding.observeBulkAdaptivePayloadWrite(len(frame), time.Since(started), writeTimeout, err)
if err != nil {
return err
}
}
return nil
}
// encodeRequests turns a flushed set of requests into encoded wire payloads.
// Frames from shared-batch-capable requests are coalesced into batch payloads
// bounded by bulkFastBatchMaxItems and bulkFastBatchMaxPlainBytes; once a
// batch mixes frames from different requests or data streams, the adaptive
// soft payload limit also applies. Frames that cannot batch (unsupported
// fast-path version, or too large for any batch) are encoded singly.
// Each returned payload carries the codec's release hook.
func (s *bulkBatchSender) encodeRequests(requests []bulkBatchRequest) ([]bulkBatchEncodedPayload, error) {
	if len(requests) == 0 {
		return nil, nil
	}
	payloads := make([]bulkBatchEncodedPayload, 0, len(requests))
	batch := make([]bulkFastFrame, 0, minInt(len(requests), bulkFastBatchMaxItems))
	mixedBatchLimit := s.sharedMixedPayloadLimit()
	// State of the batch currently being accumulated.
	batchRequestIndex := -1      // index of the request the batch started from
	batchDataID := uint64(0)     // DataID the batch started with
	batchMixed := false          // batch already spans requests or data IDs
	flushBatch := func() error {
		if len(batch) == 0 {
			return nil
		}
		payload, release, err := s.encodeBatch(batch)
		if err != nil {
			return err
		}
		payloads = append(payloads, bulkBatchEncodedPayload{payload: payload, release: release})
		batch = batch[:0]
		batchRequestIndex = -1
		batchDataID = 0
		batchMixed = false
		// NOTE: batchBytes is reset by the call sites, not here.
		return nil
	}
	batchBytes := bulkFastBatchHeaderLen
	for reqIndex, req := range requests {
		for _, frame := range req.frames {
			// Peers on old fast-path versions cannot share batches: flush the
			// current batch and encode this frame on its own.
			if !bulkFastPathSupportsSharedBatch(req.fastPathVersion) {
				if err := flushBatch(); err != nil {
					return nil, err
				}
				batchBytes = bulkFastBatchHeaderLen
				payload, release, err := s.encodeSingle(frame)
				if err != nil {
					return nil, err
				}
				payloads = append(payloads, bulkBatchEncodedPayload{payload: payload, release: release})
				continue
			}
			frameLen := bulkFastBatchFrameLen(frame)
			// Frames too large for any batch are encoded singly as well.
			if frameLen+bulkFastBatchHeaderLen > bulkFastBatchMaxPlainBytes {
				if err := flushBatch(); err != nil {
					return nil, err
				}
				batchBytes = bulkFastBatchHeaderLen
				payload, release, err := s.encodeSingle(frame)
				if err != nil {
					return nil, err
				}
				payloads = append(payloads, bulkBatchEncodedPayload{payload: payload, release: release})
				continue
			}
			// Would appending this frame make the batch "mixed"?
			nextMixed := batchMixed
			if len(batch) > 0 && (batchRequestIndex != reqIndex || (batchDataID != 0 && batchDataID != frame.DataID)) {
				nextMixed = true
			}
			// Mixed batches are additionally capped by the adaptive soft limit.
			batchLimit := bulkFastBatchMaxPlainBytes
			if nextMixed && mixedBatchLimit > 0 && mixedBatchLimit < batchLimit {
				batchLimit = mixedBatchLimit
			}
			if len(batch) > 0 && (len(batch) >= bulkFastBatchMaxItems || batchBytes+frameLen > batchLimit) {
				if err := flushBatch(); err != nil {
					return nil, err
				}
				batchBytes = bulkFastBatchHeaderLen
				nextMixed = false
			}
			if len(batch) == 0 {
				batchRequestIndex = reqIndex
				batchDataID = frame.DataID
				batchMixed = false
			} else if batchRequestIndex != reqIndex || (batchDataID != 0 && batchDataID != frame.DataID) {
				batchMixed = true
			}
			batch = append(batch, frame)
			batchBytes += frameLen
		}
	}
	// Flush whatever remains after the last request.
	if err := flushBatch(); err != nil {
		return nil, err
	}
	return payloads, nil
}
// bulkBatchRequestApproxBytes estimates the encoded size of a request by
// summing the batch-frame length of each of its frames; the enclosing batch
// header is not included.
func bulkBatchRequestApproxBytes(req bulkBatchRequest) int {
	total := 0
	for _, frame := range req.frames {
		total += bulkFastBatchFrameLen(frame)
	}
	return total
}
// bulkBatchRequestSupportsSharedSuperBatch reports whether the request may be
// coalesced with other requests into one super-batch: it must have frames,
// use a fast-path version that supports shared batching, and carry only
// plain data frames.
func bulkBatchRequestSupportsSharedSuperBatch(req bulkBatchRequest) bool {
	if len(req.frames) == 0 || !bulkFastPathSupportsSharedBatch(req.fastPathVersion) {
		return false
	}
	for _, frame := range req.frames {
		switch frame.Type {
		case bulkFastPayloadTypeData:
		default:
			// Any non-data frame disqualifies the request from shared batching.
			return false
		}
	}
	return true
}
// bulkBatchShouldWaitForMore reports whether the sender loop should briefly
// wait (up to bulkBatchMaxFlushDelay) for more requests before flushing.
// It declines when flush-delay batching is disabled, when the batch is
// already full by count or bytes, or when any queued request cannot join a
// shared super-batch.
func bulkBatchShouldWaitForMore(batch []bulkBatchRequest, batchBytes int) bool {
	if bulkBatchMaxFlushDelay <= 0 || len(batch) == 0 {
		return false
	}
	if len(batch) >= bulkBatchMaxPayloads || batchBytes >= bulkBatchMaxPayloadBytes {
		return false
	}
	for _, req := range batch {
		if !bulkBatchRequestSupportsSharedSuperBatch(req) {
			return false
		}
	}
	return true
}
// cloneQueuedBulkBatchRequest copies a request's frame payloads into a single
// pooled buffer before the request is parked on the queue, so the caller's
// payload slices can be reused immediately. Requests without frames, or whose
// payloads are already owned (payloadOwned), are returned unchanged. The
// pooled buffer is handed back to the pool via req.release.
func cloneQueuedBulkBatchRequest(req bulkBatchRequest) bulkBatchRequest {
	if len(req.frames) == 0 || req.payloadOwned {
		return req
	}
	clonedFrames := make([]bulkFastFrame, len(req.frames))
	totalPayload := 0
	for _, frame := range req.frames {
		totalPayload += len(frame.Payload)
	}
	var payloadBuf []byte
	if totalPayload > 0 {
		// One pooled allocation backs every cloned frame payload.
		payloadBuf = getBulkAsyncWritePayload(totalPayload)
		req.release = func() {
			putBulkAsyncWritePayload(payloadBuf)
		}
	}
	offset := 0
	for index, frame := range req.frames {
		clonedFrames[index] = frame
		if len(frame.Payload) == 0 {
			clonedFrames[index].Payload = nil
			continue
		}
		// Carve this frame's slice out of the shared buffer and copy into it.
		next := offset + len(frame.Payload)
		clonedFrames[index].Payload = payloadBuf[offset:next]
		copy(clonedFrames[index].Payload, frame.Payload)
		offset = next
	}
	req.frames = clonedFrames
	return req
}
// encodeSingle encodes one frame via the configured codec hook. A nil sender
// or missing hook reports errTransportDetached.
func (s *bulkBatchSender) encodeSingle(frame bulkFastFrame) ([]byte, func(), error) {
	if s == nil || s.codec.encodeSingle == nil {
		return nil, nil, errTransportDetached
	}
	return s.codec.encodeSingle(frame)
}
// encodeBatch encodes frames via the batch codec hook, degrading to
// encodeSingle when there is exactly one frame or no batch hook is set.
// NOTE(review): an empty frames slice would panic on frames[0] when the
// batch hook is nil; callers appear to guard against empty batches — confirm
// before relying on this elsewhere.
func (s *bulkBatchSender) encodeBatch(frames []bulkFastFrame) ([]byte, func(), error) {
	if len(frames) == 1 || s.codec.encodeBatch == nil {
		return s.encodeSingle(frames[0])
	}
	return s.codec.encodeBatch(frames)
}
func (s *bulkBatchSender) stop() {
@ -231,13 +615,23 @@ func (s *bulkBatchSender) failPending(err error) {
for {
select {
case item := <-s.reqCh:
item.done <- err
s.finishRequest(item, err)
default:
return
}
}
}
// finishRequest retires a dequeued request: it decrements the queued counter,
// releases any cloned payload buffer back to its pool, and delivers err
// (possibly nil) on the request's done channel.
func (s *bulkBatchSender) finishRequest(req bulkBatchRequest, err error) {
	if s != nil {
		s.queued.Add(-1)
	}
	if req.release != nil {
		req.release()
	}
	req.done <- err
}
func (s *bulkBatchSender) setErr(err error) {
if s == nil || err == nil {
return
@ -264,3 +658,31 @@ func (s *bulkBatchSender) stoppedErr() error {
}
return errTransportDetached
}
// transportWriteDeadline converts the configured write-timeout provider's
// value into an absolute deadline; the zero time means "no deadline".
func (s *bulkBatchSender) transportWriteDeadline() time.Time {
	if s == nil || s.writeTimeoutProvider == nil {
		return time.Time{}
	}
	return writeDeadlineFromTimeout(s.writeTimeoutProvider())
}
// transportWriteTimeout returns the current write timeout from the provider,
// or 0 (no timeout) when the sender or provider is absent.
func (s *bulkBatchSender) transportWriteTimeout() time.Duration {
	if s == nil || s.writeTimeoutProvider == nil {
		return 0
	}
	return s.writeTimeoutProvider()
}
// sharedMixedPayloadLimit returns the adaptive soft payload cap applied to
// mixed batches, falling back to the static default when no transport
// binding is available.
func (s *bulkBatchSender) sharedMixedPayloadLimit() int {
	if s == nil || s.binding == nil {
		return bulkAdaptiveSoftPayloadFallbackBytes
	}
	return s.binding.bulkAdaptiveSoftPayloadBytesSnapshot()
}
// minInt returns the smaller of a and b.
func minInt(a int, b int) int {
	if b < a {
		return b
	}
	return a
}

View File

@ -161,7 +161,7 @@ func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
return nil
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
if err := server.Listen("tcp", benchmarkTCPListenAddr(b)); err != nil {
b.Fatalf("server Listen failed: %v", err)
}
b.Cleanup(func() {
@ -172,7 +172,7 @@ func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
b.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
b.Fatalf("client Connect failed: %v", err)
}
b.Cleanup(func() {
@ -258,7 +258,7 @@ func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurr
return nil
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
if err := server.Listen("tcp", benchmarkTCPListenAddr(b)); err != nil {
b.Fatalf("server Listen failed: %v", err)
}
b.Cleanup(func() {
@ -269,7 +269,7 @@ func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurr
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
b.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
b.Fatalf("client Connect failed: %v", err)
}
b.Cleanup(func() {

119
bulk_buffer_release_test.go Normal file
View File

@ -0,0 +1,119 @@
package notify
import (
"context"
"errors"
"testing"
"time"
)
// TestBulkOwnedChunkReleaseAfterRead verifies that the release callback
// attached to an owned chunk fires exactly once after the chunk's bytes have
// been fully consumed by Read.
func TestBulkOwnedChunkReleaseAfterRead(t *testing.T) {
	bulk := newBulkHandle(context.Background(), newBulkRuntime("buffer-release-read"), clientFileScope(), BulkOpenRequest{
		BulkID: "buffer-release-read",
		DataID: 1,
	}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
	released := 0
	if err := bulk.pushOwnedChunkWithReleaseNoReset([]byte("hello"), func() {
		released++
	}); err != nil {
		t.Fatalf("pushOwnedChunkWithReleaseNoReset failed: %v", err)
	}
	// Drain the whole chunk so the release becomes eligible.
	buf := make([]byte, 5)
	n, err := bulk.Read(buf)
	if err != nil {
		t.Fatalf("Read failed: %v", err)
	}
	if n != 5 || string(buf[:n]) != "hello" {
		t.Fatalf("Read = %d %q, want 5 hello", n, string(buf[:n]))
	}
	if released != 1 {
		t.Fatalf("release count = %d, want 1", released)
	}
}
// TestBulkOwnedChunkReleaseOnReset verifies that resetting the bulk handle
// releases any still-buffered owned chunk exactly once, so pooled buffers
// are not leaked on error paths.
func TestBulkOwnedChunkReleaseOnReset(t *testing.T) {
	bulk := newBulkHandle(context.Background(), newBulkRuntime("buffer-release-reset"), clientFileScope(), BulkOpenRequest{
		BulkID: "buffer-release-reset",
		DataID: 1,
	}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
	released := 0
	if err := bulk.pushOwnedChunkWithReleaseNoReset([]byte("hello"), func() {
		released++
	}); err != nil {
		t.Fatalf("pushOwnedChunkWithReleaseNoReset failed: %v", err)
	}
	// The chunk is never read; markReset must still trigger the release.
	bulk.markReset(errors.New("boom"))
	if released != 1 {
		t.Fatalf("release count = %d, want 1", released)
	}
}
// TestBulkReadDoesNotBlockOnAsyncWindowRelease verifies that Read returns as
// soon as data is consumed even while the window-release sender callback is
// blocked, i.e. release notifications are dispatched asynchronously from the
// read path.
// NOTE(review): the callback calls t.Fatalf and likely runs on a background
// worker goroutine — testing.T.Fatalf is only valid from the test goroutine;
// confirm, or switch the callback to t.Errorf + return.
func TestBulkReadDoesNotBlockOnAsyncWindowRelease(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	releaseStarted := make(chan struct{})
	releaseUnblock := make(chan struct{})
	bulk := newBulkHandle(ctx, newBulkRuntime("buffer-release-async"), clientFileScope(), BulkOpenRequest{
		BulkID:      "buffer-release-async",
		DataID:      1,
		ChunkSize:   4,
		WindowBytes: 4,
		MaxInFlight: 1,
	}, 0, nil, nil, 0, nil, nil, nil, nil, func(_ *bulkHandle, bytes int64, chunks int) error {
		if bytes != 4 || chunks != 1 {
			t.Fatalf("release = (%d,%d), want (4,1)", bytes, chunks)
		}
		close(releaseStarted)
		// Stall the release sender; the Read below must not wait on this.
		<-releaseUnblock
		return nil
	})
	if err := bulk.pushOwnedChunk([]byte("ping")); err != nil {
		t.Fatalf("pushOwnedChunk failed: %v", err)
	}
	buf := make([]byte, 4)
	doneCh := make(chan error, 1)
	go func() {
		n, err := bulk.Read(buf)
		if err != nil {
			doneCh <- err
			return
		}
		if got, want := n, 4; got != want {
			doneCh <- errors.New("unexpected read size")
			return
		}
		doneCh <- nil
	}()
	// Read must complete quickly even though the release callback is stalled.
	select {
	case err := <-doneCh:
		if err != nil {
			t.Fatalf("Read failed: %v", err)
		}
	case <-time.After(200 * time.Millisecond):
		t.Fatal("Read should not block on async release sender")
	}
	select {
	case <-releaseStarted:
	case <-time.After(time.Second):
		t.Fatal("window release sender did not start")
	}
	close(releaseUnblock)
	cancel()
	// Ensure the background release worker shuts down cleanly.
	if bulk.releaseWorkerDone != nil {
		select {
		case <-bulk.releaseWorkerDone:
		case <-time.After(time.Second):
			t.Fatal("release worker did not exit")
		}
	}
}

View File

@ -7,22 +7,25 @@ import (
)
type BulkOpenRequest struct {
BulkID string
DataID uint64
Range BulkRange
Metadata BulkMetadata
ReadTimeout time.Duration
WriteTimeout time.Duration
Dedicated bool
AttachToken string
ChunkSize int
WindowBytes int
MaxInFlight int
BulkID string
DataID uint64
FastPathVersion uint8
Range BulkRange
Metadata BulkMetadata
ReadTimeout time.Duration
WriteTimeout time.Duration
Dedicated bool
DedicatedLaneID uint32
AttachToken string
ChunkSize int
WindowBytes int
MaxInFlight int
}
type BulkOpenResponse struct {
BulkID string
DataID uint64
FastPathVersion uint8
Accepted bool
Dedicated bool
AttachToken string
@ -53,6 +56,18 @@ type BulkResetResponse struct {
Error string
}
type BulkReadyRequest struct {
BulkID string
DataID uint64
Error string
}
type BulkReadyResponse struct {
BulkID string
Accepted bool
Error string
}
type BulkReleaseRequest struct {
BulkID string
DataID uint64
@ -73,6 +88,9 @@ func bindClientBulkControl(c *ClientCommon) {
c.SetLink(BulkResetSignalKey, func(msg *Message) {
c.handleInboundBulkReset(msg)
})
c.SetLink(BulkReadySignalKey, func(msg *Message) {
c.handleInboundBulkReady(msg)
})
c.SetLink(BulkReleaseSignalKey, func(msg *Message) {
c.handleInboundBulkRelease(msg)
})
@ -91,14 +109,188 @@ func bindServerBulkControl(s *ServerCommon) {
s.SetLink(BulkResetSignalKey, func(msg *Message) {
s.handleInboundBulkReset(msg)
})
s.SetLink(BulkReadySignalKey, func(msg *Message) {
s.handleInboundBulkReady(msg)
})
s.SetLink(BulkReleaseSignalKey, func(msg *Message) {
s.handleInboundBulkRelease(msg)
})
}
func bulkAcceptInfoFromClientBulk(bulk *bulkHandle) BulkAcceptInfo {
if bulk == nil {
return BulkAcceptInfo{}
}
return BulkAcceptInfo{
ID: bulk.ID(),
Range: bulk.Range(),
Metadata: bulk.Metadata(),
Dedicated: bulk.Dedicated(),
TransportGeneration: bulk.TransportGeneration(),
Bulk: bulk,
}
}
func bulkAcceptInfoFromServerBulk(bulk *bulkHandle, logical *LogicalConn, transport *TransportConn) BulkAcceptInfo {
if bulk == nil {
return BulkAcceptInfo{}
}
if logical == nil {
logical = bulk.LogicalConn()
}
if transport == nil {
transport = bulk.TransportConn()
}
return BulkAcceptInfo{
ID: bulk.ID(),
Range: bulk.Range(),
Metadata: bulk.Metadata(),
Dedicated: bulk.Dedicated(),
LogicalConn: logical,
TransportConn: transport,
TransportGeneration: bulk.TransportGeneration(),
Bulk: bulk,
}
}
func dispatchBulkAccept(handler func(BulkAcceptInfo) error, bulk *bulkHandle, info BulkAcceptInfo) error {
if bulk == nil {
return errBulkNotFound
}
if !bulk.markAcceptDispatched() {
return nil
}
var dispatchErr error
defer func() {
bulk.finishAcceptDispatch(dispatchErr)
}()
if handler == nil {
dispatchErr = errBulkHandlerNotConfigured
bulk.markReset(dispatchErr)
return dispatchErr
}
if err := handler(info); err != nil {
dispatchErr = err
bulk.markReset(dispatchErr)
return dispatchErr
}
return nil
}
func (c *ClientCommon) clientBulkAcceptReadyNotifier(bulk *bulkHandle) func(error) {
return func(readyErr error) {
if c == nil || bulk == nil {
return
}
req := BulkReadyRequest{
BulkID: bulk.ID(),
DataID: bulk.dataIDSnapshot(),
}
if readyErr != nil {
req.Error = readyErr.Error()
}
ctx, cancel := context.WithTimeout(context.Background(), defaultBulkAcceptReadyTimeout)
defer cancel()
if _, err := sendBulkReadyClient(ctx, c, req); err != nil && bulk.Context().Err() == nil {
bulk.markReset(err)
}
}
}
func sendBulkReadyServer(ctx context.Context, s *ServerCommon, logical *LogicalConn, transport *TransportConn, req BulkReadyRequest) error {
if s == nil {
return errBulkServerNil
}
if transport != nil {
if _, err := sendBulkReadyServerTransport(ctx, s, transport, req); err == nil {
return nil
} else if !errors.Is(err, errTransportDetached) && !errors.Is(err, errBulkTransportNil) {
return err
}
}
if logical == nil {
return errBulkLogicalConnNil
}
_, err := sendBulkReadyServerLogical(ctx, s, logical, req)
return err
}
func (s *ServerCommon) serverBulkAcceptReadyNotifier(bulk *bulkHandle, logical *LogicalConn, transport *TransportConn) func(error) {
return func(readyErr error) {
if s == nil || bulk == nil {
return
}
req := BulkReadyRequest{
BulkID: bulk.ID(),
DataID: bulk.dataIDSnapshot(),
}
if readyErr != nil {
req.Error = readyErr.Error()
}
ctx, cancel := context.WithTimeout(context.Background(), defaultBulkAcceptReadyTimeout)
defer cancel()
if err := sendBulkReadyServer(ctx, s, logical, transport, req); err != nil && bulk.Context().Err() == nil {
bulk.markReset(err)
}
}
}
func (c *ClientCommon) startClientBulkAcceptDispatch(bulk *bulkHandle) {
if c == nil || bulk == nil {
return
}
bulk.setAcceptNotify(c.clientBulkAcceptReadyNotifier(bulk))
go func() {
_ = c.dispatchClientBulkAccept(bulk)
}()
}
func (s *ServerCommon) startServerBulkAcceptDispatch(bulk *bulkHandle, logical *LogicalConn, transport *TransportConn) {
if s == nil || bulk == nil {
return
}
if logical == nil {
logical = bulk.LogicalConn()
}
if transport == nil {
transport = bulk.TransportConn()
}
bulk.setAcceptNotify(s.serverBulkAcceptReadyNotifier(bulk, logical, transport))
go func() {
_ = s.dispatchServerBulkAccept(bulk, logical, transport)
}()
}
func (c *ClientCommon) dispatchClientBulkAccept(bulk *bulkHandle) error {
if c == nil {
return errBulkClientNil
}
runtime := c.getBulkRuntime()
if runtime == nil {
return errBulkRuntimeNil
}
return dispatchBulkAccept(runtime.handlerSnapshot(), bulk, bulkAcceptInfoFromClientBulk(bulk))
}
func (s *ServerCommon) dispatchServerBulkAccept(bulk *bulkHandle, logical *LogicalConn, transport *TransportConn) error {
if s == nil {
return errBulkServerNil
}
runtime := s.getBulkRuntime()
if runtime == nil {
return errBulkRuntimeNil
}
return dispatchBulkAccept(runtime.handlerSnapshot(), bulk, bulkAcceptInfoFromServerBulk(bulk, logical, transport))
}
func (c *ClientCommon) handleInboundBulkOpen(msg *Message) {
req, err := decodeBulkOpenRequest(msg)
resp := BulkOpenResponse{BulkID: req.BulkID, DataID: req.DataID, Dedicated: req.Dedicated}
resp := BulkOpenResponse{
BulkID: req.BulkID,
DataID: req.DataID,
FastPathVersion: negotiateBulkFastPathVersion(req.FastPathVersion),
Dedicated: req.Dedicated,
}
if err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
@ -133,13 +325,6 @@ func (c *ClientCommon) handleInboundBulkOpen(msg *Message) {
replyBulkControlIfNeeded(msg, resp)
return
}
handler := runtime.handlerSnapshot()
if handler == nil {
bulk.markReset(errBulkHandlerNotConfigured)
resp.Error = errBulkHandlerNotConfigured.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
if req.Dedicated {
if err := c.attachDedicatedBulkSidecar(context.Background(), bulk); err != nil {
bulk.markReset(err)
@ -147,17 +332,14 @@ func (c *ClientCommon) handleInboundBulkOpen(msg *Message) {
replyBulkControlIfNeeded(msg, resp)
return
}
resp.Accepted = true
resp.DataID = bulk.dataIDSnapshot()
resp.TransportGeneration = bulk.TransportGeneration()
replyBulkControlIfNeeded(msg, resp)
c.startClientBulkAcceptDispatch(bulk)
return
}
info := BulkAcceptInfo{
ID: bulk.ID(),
Range: bulk.Range(),
Metadata: bulk.Metadata(),
Dedicated: bulk.Dedicated(),
TransportGeneration: bulk.TransportGeneration(),
Bulk: bulk,
}
if err := handler(info); err != nil {
bulk.markReset(err)
if err := c.dispatchClientBulkAccept(bulk); err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
@ -170,7 +352,12 @@ func (c *ClientCommon) handleInboundBulkOpen(msg *Message) {
func (s *ServerCommon) handleInboundBulkOpen(msg *Message) {
req, err := decodeBulkOpenRequest(msg)
resp := BulkOpenResponse{BulkID: req.BulkID, DataID: req.DataID, Dedicated: req.Dedicated}
resp := BulkOpenResponse{
BulkID: req.BulkID,
DataID: req.DataID,
FastPathVersion: negotiateBulkFastPathVersion(req.FastPathVersion),
Dedicated: req.Dedicated,
}
if err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
@ -218,25 +405,24 @@ func (s *ServerCommon) handleInboundBulkOpen(msg *Message) {
replyBulkControlIfNeeded(msg, resp)
return
}
handler := runtime.handlerSnapshot()
if handler == nil {
s.attachServerDedicatedSidecarIfExists(logical, bulk)
if runtime.handlerSnapshot() == nil {
bulk.markReset(errBulkHandlerNotConfigured)
resp.Error = errBulkHandlerNotConfigured.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
info := BulkAcceptInfo{
ID: bulk.ID(),
Range: bulk.Range(),
Metadata: bulk.Metadata(),
Dedicated: bulk.Dedicated(),
LogicalConn: logical,
TransportConn: transport,
TransportGeneration: bulk.TransportGeneration(),
Bulk: bulk,
if req.Dedicated {
resp.Accepted = true
resp.DataID = bulk.dataIDSnapshot()
resp.TransportGeneration = bulk.TransportGeneration()
replyBulkControlIfNeeded(msg, resp)
if bulk.dedicatedAttachedSnapshot() {
s.startServerBulkAcceptDispatch(bulk, logical, messageTransportConnSnapshot(msg))
}
return
}
if err := handler(info); err != nil {
bulk.markReset(err)
if err := s.dispatchServerBulkAccept(bulk, logical, transport); err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
@ -357,6 +543,41 @@ func (c *ClientCommon) handleInboundBulkRelease(msg *Message) {
bulk.releaseOutboundWindow(req.Bytes, req.Chunks)
}
func (c *ClientCommon) handleInboundBulkReady(msg *Message) {
req, err := decodeBulkReadyRequest(msg)
resp := BulkReadyResponse{BulkID: req.BulkID}
if err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
runtime := c.getBulkRuntime()
if runtime == nil {
resp.Error = errBulkRuntimeNil.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
bulk, ok := runtime.lookup(clientFileScope(), req.BulkID)
if !ok && req.DataID != 0 {
bulk, ok = runtime.lookupByDataID(clientFileScope(), req.DataID)
}
if !ok {
resp.Error = errBulkNotFound.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
if resp.BulkID == "" {
resp.BulkID = bulk.ID()
}
readyErr := bulkReadyRemoteError(req.Error)
bulk.markAcceptReady(readyErr)
if readyErr != nil {
bulk.markReset(readyErr)
}
resp.Accepted = true
replyBulkControlIfNeeded(msg, resp)
}
func (s *ServerCommon) handleInboundBulkReset(msg *Message) {
req, err := decodeBulkResetRequest(msg)
resp := BulkResetResponse{BulkID: req.BulkID}
@ -390,6 +611,43 @@ func (s *ServerCommon) handleInboundBulkReset(msg *Message) {
replyBulkControlIfNeeded(msg, resp)
}
func (s *ServerCommon) handleInboundBulkReady(msg *Message) {
req, err := decodeBulkReadyRequest(msg)
resp := BulkReadyResponse{BulkID: req.BulkID}
if err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
runtime := s.getBulkRuntime()
if runtime == nil {
resp.Error = errBulkRuntimeNil.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
logical := messageLogicalConnSnapshot(msg)
scope := serverFileScope(logical)
bulk, ok := runtime.lookup(scope, req.BulkID)
if !ok && req.DataID != 0 {
bulk, ok = runtime.lookupByDataID(scope, req.DataID)
}
if !ok {
resp.Error = errBulkNotFound.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
if resp.BulkID == "" {
resp.BulkID = bulk.ID()
}
readyErr := bulkReadyRemoteError(req.Error)
bulk.markAcceptReady(readyErr)
if readyErr != nil {
bulk.markReset(readyErr)
}
resp.Accepted = true
replyBulkControlIfNeeded(msg, resp)
}
func (s *ServerCommon) handleInboundBulkRelease(msg *Message) {
req, err := decodeBulkReleaseRequest(msg)
if err != nil {
@ -625,11 +883,26 @@ func decodeBulkReleaseRequest(msg *Message) (BulkReleaseRequest, error) {
return req, nil
}
func decodeBulkReadyRequest(msg *Message) (BulkReadyRequest, error) {
var req BulkReadyRequest
if msg == nil {
return BulkReadyRequest{}, errBulkIDEmpty
}
if err := msg.Value.Orm(&req); err != nil {
return BulkReadyRequest{}, err
}
if req.BulkID == "" && req.DataID == 0 {
return BulkReadyRequest{}, errBulkIDEmpty
}
return req, nil
}
func decodeBulkOpenResponse(msg Message) (BulkOpenResponse, error) {
var resp BulkOpenResponse
if err := msg.Value.Orm(&resp); err != nil {
return BulkOpenResponse{}, err
}
resp.FastPathVersion = normalizeBulkFastPathVersion(resp.FastPathVersion)
return resp, bulkControlResultError("open", resp.Accepted, resp.Error, nil)
}
@ -649,6 +922,14 @@ func decodeBulkResetResponse(msg Message) (BulkResetResponse, error) {
return resp, bulkControlResultError("reset", resp.Accepted, resp.Error, nil)
}
func decodeBulkReadyResponse(msg Message) (BulkReadyResponse, error) {
var resp BulkReadyResponse
if err := msg.Value.Orm(&resp); err != nil {
return BulkReadyResponse{}, err
}
return resp, bulkControlResultError("ready", resp.Accepted, resp.Error, nil)
}
func bulkControlResultError(op string, accepted bool, message string, callErr error) error {
if callErr != nil {
return callErr
@ -697,6 +978,52 @@ func bulkRemoteResetError(message string) error {
return errors.New(message)
}
func bulkReadyRemoteError(message string) error {
if message == "" {
return nil
}
return bulkControlMessageError(message)
}
func sendBulkReadyClient(ctx context.Context, c Client, req BulkReadyRequest) (BulkReadyResponse, error) {
if c == nil {
return BulkReadyResponse{}, errBulkClientNil
}
msg, err := c.SendObjCtx(ctx, BulkReadySignalKey, req)
if err != nil {
return BulkReadyResponse{}, err
}
return decodeBulkReadyResponse(msg)
}
func sendBulkReadyServerLogical(ctx context.Context, s Server, logical *LogicalConn, req BulkReadyRequest) (BulkReadyResponse, error) {
if s == nil {
return BulkReadyResponse{}, errBulkServerNil
}
if logical == nil {
return BulkReadyResponse{}, errBulkLogicalConnNil
}
msg, err := s.SendObjCtxLogical(ctx, logical, BulkReadySignalKey, req)
if err != nil {
return BulkReadyResponse{}, err
}
return decodeBulkReadyResponse(msg)
}
func sendBulkReadyServerTransport(ctx context.Context, s Server, transport *TransportConn, req BulkReadyRequest) (BulkReadyResponse, error) {
if s == nil {
return BulkReadyResponse{}, errBulkServerNil
}
if transport == nil {
return BulkReadyResponse{}, errBulkTransportNil
}
msg, err := s.SendObjCtxTransport(ctx, transport, BulkReadySignalKey, req)
if err != nil {
return BulkReadyResponse{}, err
}
return decodeBulkReadyResponse(msg)
}
func bulkTransportGeneration(logical *LogicalConn, transport *TransportConn) uint64 {
return streamTransportGeneration(logical, transport)
}

File diff suppressed because it is too large Load Diff

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/binary"
"errors"
"net"
"testing"
"time"
@ -136,6 +137,7 @@ func TestHandleBulkAttachSystemMessageAcceptedWritesDirectReplyBeforeDedicatedHa
Dedicated: true,
AttachToken: "attach-token",
}, 0, target, nil, 0, nil, nil, nil, nil, nil)
bulk.markAcceptHandled()
if err := server.getBulkRuntime().register(serverFileScope(target), bulk); err != nil {
t.Fatalf("register bulk runtime failed: %v", err)
}
@ -215,3 +217,251 @@ func TestHandleBulkAttachSystemMessageAcceptedWritesDirectReplyBeforeDedicatedHa
t.Fatalf("attach sidecar logical should be removed after handoff, got %+v", got)
}
}
func TestHandleBulkAttachSystemMessageDoesNotExposeSharedSidecarBeforeReplyCompletes(t *testing.T) {
server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server)
sidecarLeft, sidecarRight := net.Pipe()
defer sidecarRight.Close()
current := server.bootstrapAcceptedLogical("dedicated-attach-current-blocked", nil, sidecarLeft)
if current == nil {
t.Fatal("bootstrapAcceptedLogical(current) should return logical")
}
target := server.bootstrapAcceptedLogical("dedicated-attach-target-blocked", nil, nil)
if target == nil {
t.Fatal("bootstrapAcceptedLogical(target) should return logical")
}
currentBulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{
BulkID: "server-dedicated-current",
DataID: 17,
Dedicated: true,
AttachToken: "attach-token",
}, 0, target, nil, 0, nil, nil, nil, nil, nil)
currentBulk.markAcceptHandled()
if err := server.getBulkRuntime().register(serverFileScope(target), currentBulk); err != nil {
t.Fatalf("register current bulk runtime failed: %v", err)
}
pendingBulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{
BulkID: "server-dedicated-pending",
DataID: 18,
Dedicated: true,
AttachToken: "attach-token-2",
}, 0, target, nil, 0, nil, nil, nil, nil, nil)
pendingBulk.markAcceptHandled()
if err := server.getBulkRuntime().register(serverFileScope(target), pendingBulk); err != nil {
t.Fatalf("register pending bulk runtime failed: %v", err)
}
reqPayload, err := server.sequenceEn(bulkAttachRequest{
PeerID: target.ID(),
BulkID: currentBulk.ID(),
AttachToken: "attach-token",
})
if err != nil {
t.Fatalf("encode bulkAttachRequest failed: %v", err)
}
msg := Message{
NetType: NET_SERVER,
LogicalConn: current,
ClientConn: current.compatClientConn(),
TransferMsg: TransferMsg{
ID: 77,
Key: systemBulkAttachKey,
Value: reqPayload,
Type: MSG_SYS_WAIT,
},
inboundConn: sidecarLeft,
Time: time.Now(),
}
done := make(chan struct{})
go func() {
defer close(done)
_ = server.handleBulkAttachSystemMessage(msg)
}()
time.Sleep(50 * time.Millisecond)
if got := pendingBulk.dedicatedConnSnapshot(); got != nil {
t.Fatal("pending dedicated bulk should not observe shared sidecar before attach reply is fully sent")
}
if got := server.serverDedicatedSidecarSnapshot(target); got != nil {
t.Fatal("server dedicated sidecar should not be published before attach reply is fully sent")
}
type attachReplyResult struct {
transfer TransferMsg
resp bulkAttachResponse
err error
}
replyCh := make(chan attachReplyResult, 1)
go func() {
_ = sidecarRight.SetReadDeadline(time.Now().Add(time.Second))
replyPayload, err := readDirectSignalFramePayload(sidecarRight)
if err != nil {
replyCh <- attachReplyResult{err: err}
return
}
transfer, err := decodeDirectSignalPayload(server.sequenceDe, current.msgDeSnapshot(), current.secretKeySnapshot(), replyPayload)
if err != nil {
replyCh <- attachReplyResult{err: err}
return
}
resp, err := decodeBulkAttachResponse(server.sequenceDe, transfer.Value)
replyCh <- attachReplyResult{transfer: transfer, resp: resp, err: err}
}()
select {
case result := <-replyCh:
if result.err != nil {
t.Fatalf("read direct attach reply failed: %v", result.err)
}
if !result.resp.Accepted {
t.Fatalf("bulk attach response = %+v, want accepted", result.resp)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for direct attach reply")
}
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for handleBulkAttachSystemMessage to finish")
}
if got := pendingBulk.dedicatedConnSnapshot(); got != sidecarLeft {
t.Fatalf("pending dedicated bulk conn mismatch after reply: got %v want %v", got, sidecarLeft)
}
if got := server.serverDedicatedSidecarSnapshot(target); got == nil || got.conn != sidecarLeft {
t.Fatal("server dedicated sidecar should be published after attach reply completes")
}
}
func TestBulkAttachResponseErrorCarriesStructuredFields(t *testing.T) {
resp := toBulkAttachResponseError(&bulkAttachError{
Code: bulkAttachErrorCodeTokenMismatch,
Retryable: false,
Message: "bulk attach token mismatch",
FailedSeq: 7,
FailedBulk: "bulk-1",
}, "fallback-bulk")
if resp.Accepted {
t.Fatalf("Accepted = %v, want false", resp.Accepted)
}
if got, want := resp.Code, string(bulkAttachErrorCodeTokenMismatch); got != want {
t.Fatalf("Code = %q, want %q", got, want)
}
if got, want := resp.Retryable, false; got != want {
t.Fatalf("Retryable = %v, want %v", got, want)
}
if got, want := resp.FailedSeq, uint64(7); got != want {
t.Fatalf("FailedSeq = %d, want %d", got, want)
}
if got, want := resp.FailedBulk, "bulk-1"; got != want {
t.Fatalf("FailedBulk = %q, want %q", got, want)
}
}
func TestResolveInboundDedicatedBulkRejectsAlreadyAttachedAfterDataStarted(t *testing.T) {
server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server)
currentLeft, currentRight := net.Pipe()
defer currentRight.Close()
current := server.bootstrapAcceptedLogical("dedicated-attach-current", nil, currentLeft)
if current == nil {
t.Fatal("bootstrapAcceptedLogical(current) should return logical")
}
target := server.bootstrapAcceptedLogical("dedicated-attach-target", nil, nil)
if target == nil {
t.Fatal("bootstrapAcceptedLogical(target) should return logical")
}
bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{
BulkID: "server-dedicated-attach-test",
DataID: 7,
Dedicated: true,
AttachToken: "attach-token",
}, 0, target, nil, 0, nil, nil, nil, nil, nil)
bulk.markAcceptHandled()
if err := server.getBulkRuntime().register(serverFileScope(target), bulk); err != nil {
t.Fatalf("register bulk runtime failed: %v", err)
}
attachedLeft, attachedRight := net.Pipe()
defer attachedRight.Close()
if err := bulk.attachDedicatedConn(attachedLeft); err != nil {
t.Fatalf("attachDedicatedConn failed: %v", err)
}
bulk.markDedicatedDataStarted()
_, _, err := server.resolveInboundDedicatedBulk(current, bulkAttachRequest{
PeerID: target.ID(),
BulkID: bulk.ID(),
AttachToken: "attach-token",
})
if err == nil {
t.Fatal("resolveInboundDedicatedBulk should reject duplicate attach after data started")
}
var attachErr *bulkAttachError
if !errors.As(err, &attachErr) || attachErr == nil {
t.Fatalf("resolveInboundDedicatedBulk error type = %T, want *bulkAttachError", err)
}
if got, want := attachErr.Code, bulkAttachErrorCodeAlreadyAttached; got != want {
t.Fatalf("attach error code = %q, want %q", got, want)
}
if attachErr.Retryable {
t.Fatalf("attach error retryable = true, want false")
}
}
func TestResolveInboundDedicatedBulkAllowsReattachBeforeDataStarts(t *testing.T) {
server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server)
currentLeft, currentRight := net.Pipe()
defer currentRight.Close()
current := server.bootstrapAcceptedLogical("dedicated-attach-current", nil, currentLeft)
if current == nil {
t.Fatal("bootstrapAcceptedLogical(current) should return logical")
}
target := server.bootstrapAcceptedLogical("dedicated-attach-target", nil, nil)
if target == nil {
t.Fatal("bootstrapAcceptedLogical(target) should return logical")
}
bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{
BulkID: "server-dedicated-attach-test",
DataID: 7,
Dedicated: true,
AttachToken: "attach-token",
}, 0, target, nil, 0, nil, nil, nil, nil, nil)
bulk.markAcceptHandled()
if err := server.getBulkRuntime().register(serverFileScope(target), bulk); err != nil {
t.Fatalf("register bulk runtime failed: %v", err)
}
attachedLeft, attachedRight := net.Pipe()
defer attachedRight.Close()
if err := bulk.attachDedicatedConn(attachedLeft); err != nil {
t.Fatalf("attachDedicatedConn failed: %v", err)
}
resolvedLogical, resolvedBulk, err := server.resolveInboundDedicatedBulk(current, bulkAttachRequest{
PeerID: target.ID(),
BulkID: bulk.ID(),
AttachToken: "attach-token",
})
if err != nil {
t.Fatalf("resolveInboundDedicatedBulk failed: %v", err)
}
if resolvedLogical != target {
t.Fatalf("resolved logical mismatch: got %v want %v", resolvedLogical, target)
}
if resolvedBulk != bulk {
t.Fatalf("resolved bulk mismatch: got %v want %v", resolvedBulk, bulk)
}
}

View File

@ -12,14 +12,18 @@ import (
)
const (
bulkDedicatedBatchMagic = "NBD2"
bulkDedicatedBatchVersion = 1
bulkDedicatedBatchHeaderLen = 20
bulkDedicatedBatchItemHeaderLen = 16
bulkDedicatedBatchMaxItems = 32
bulkDedicatedBatchMaxPlainBytes = 8 * 1024 * 1024
bulkDedicatedSendQueueSize = bulkDedicatedBatchMaxItems
bulkDedicatedReleasePayloadLen = 12
bulkDedicatedBatchMagic = "NBD2"
bulkDedicatedBatchVersion = 1
bulkDedicatedBatchHeaderLen = 20
bulkDedicatedBatchItemHeaderLen = 16
bulkDedicatedSuperBatchMagic = "NBD3"
bulkDedicatedSuperBatchVersion = 1
bulkDedicatedSuperBatchHeaderLen = 12
bulkDedicatedSuperBatchGroupHeaderLen = 12
bulkDedicatedBatchMaxItems = 64
bulkDedicatedBatchMaxPlainBytes = 16 * 1024 * 1024
bulkDedicatedSendQueueSize = bulkDedicatedBatchMaxItems
bulkDedicatedReleasePayloadLen = 12
)
const (
@ -46,6 +50,16 @@ type bulkDedicatedSendRequest struct {
Payload []byte
}
type bulkDedicatedOutboundBatch struct {
DataID uint64
Items []bulkDedicatedSendRequest
}
type bulkDedicatedInboundBatch struct {
DataID uint64
Items []bulkDedicatedBatchItem
}
type bulkDedicatedBatchRequest struct {
Ctx context.Context
Items []bulkDedicatedSendRequest
@ -165,7 +179,7 @@ func (s *bulkDedicatedSender) submitWriteBatch(ctx context.Context, items []bulk
if submitted, err := s.tryDirectSubmitBatch(ctx, items); submitted {
return err
}
queuedItems := cloneBulkDedicatedSendRequests(items)
queuedItems := copyBulkDedicatedSendRequests(items)
return s.submitBatch(ctx, queuedItems, true)
}
@ -471,6 +485,23 @@ func cloneBulkDedicatedSendRequests(items []bulkDedicatedSendRequest) []bulkDedi
return cloned
}
func copyBulkDedicatedSendRequests(items []bulkDedicatedSendRequest) []bulkDedicatedSendRequest {
if len(items) == 0 {
return nil
}
copied := make([]bulkDedicatedSendRequest, len(items))
copy(copied, items)
return copied
}
func bulkDedicatedSendRequestsLen(items []bulkDedicatedSendRequest) int {
total := 0
for _, item := range items {
total += bulkDedicatedSendRequestLen(item)
}
return total
}
func bulkDedicatedSendRequestLen(req bulkDedicatedSendRequest) int {
return bulkDedicatedSendRequestLenFromPayloadLen(len(req.Payload))
}
@ -479,6 +510,83 @@ func bulkDedicatedSendRequestLenFromPayloadLen(payloadLen int) int {
return bulkDedicatedBatchItemHeaderLen + payloadLen
}
func bulkDedicatedBatchesPlainLen(batches []bulkDedicatedOutboundBatch) int {
switch len(batches) {
case 0:
return 0
case 1:
return bulkDedicatedBatchPlainLen(batches[0].Items)
default:
total := bulkDedicatedSuperBatchHeaderLen
for _, batch := range batches {
total += bulkDedicatedSuperBatchGroupHeaderLen + bulkDedicatedSendRequestsLen(batch.Items)
}
return total
}
}
func encodeBulkDedicatedBatchesPlain(batches []bulkDedicatedOutboundBatch) ([]byte, error) {
if len(batches) == 0 {
return nil, errBulkFastPayloadInvalid
}
total := bulkDedicatedBatchesPlainLen(batches)
buf := make([]byte, total)
if err := writeBulkDedicatedBatchesPlain(buf, batches); err != nil {
return nil, err
}
return buf, nil
}
func writeBulkDedicatedBatchesPlain(buf []byte, batches []bulkDedicatedOutboundBatch) error {
switch len(batches) {
case 0:
return errBulkFastPayloadInvalid
case 1:
return writeBulkDedicatedBatchPlain(buf, batches[0].DataID, batches[0].Items)
default:
return writeBulkDedicatedSuperBatchPlain(buf, batches)
}
}
func writeBulkDedicatedSuperBatchPlain(buf []byte, batches []bulkDedicatedOutboundBatch) error {
if len(batches) <= 1 {
return errBulkFastPayloadInvalid
}
if len(buf) != bulkDedicatedBatchesPlainLen(batches) || len(buf) > bulkDedicatedBatchMaxPlainBytes {
return errBulkFastPayloadInvalid
}
copy(buf[:4], bulkDedicatedSuperBatchMagic)
buf[4] = bulkDedicatedSuperBatchVersion
binary.BigEndian.PutUint32(buf[8:12], uint32(len(batches)))
offset := bulkDedicatedSuperBatchHeaderLen
totalItems := 0
for _, batch := range batches {
if batch.DataID == 0 || len(batch.Items) == 0 {
return errBulkFastPayloadInvalid
}
totalItems += len(batch.Items)
if totalItems > bulkDedicatedBatchMaxItems {
return errBulkFastPayloadInvalid
}
binary.BigEndian.PutUint64(buf[offset:offset+8], batch.DataID)
binary.BigEndian.PutUint32(buf[offset+8:offset+12], uint32(len(batch.Items)))
offset += bulkDedicatedSuperBatchGroupHeaderLen
for _, item := range batch.Items {
buf[offset] = item.Type
buf[offset+1] = item.Flags
binary.BigEndian.PutUint64(buf[offset+4:offset+12], item.Seq)
binary.BigEndian.PutUint32(buf[offset+12:offset+16], uint32(len(item.Payload)))
offset += bulkDedicatedBatchItemHeaderLen
copy(buf[offset:offset+len(item.Payload)], item.Payload)
offset += len(item.Payload)
}
}
if offset != len(buf) {
return errBulkFastPayloadInvalid
}
return nil
}
func encodeBulkDedicatedReleasePayload(bytes int64, chunks int) ([]byte, error) {
if bytes <= 0 && chunks <= 0 {
return nil, errBulkFastPayloadInvalid
@ -505,24 +613,26 @@ func decodeBulkDedicatedReleasePayload(payload []byte) (int64, int, error) {
}
func encodeBulkDedicatedBatchPlain(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
if dataID == 0 || len(items) == 0 {
return nil, errBulkFastPayloadInvalid
}
total := bulkDedicatedBatchPlainLen(items)
buf := make([]byte, total)
if err := writeBulkDedicatedBatchPlain(buf, dataID, items); err != nil {
return nil, err
}
return buf, nil
return encodeBulkDedicatedBatchesPlain([]bulkDedicatedOutboundBatch{{
DataID: dataID,
Items: items,
}})
}
func encodeBulkDedicatedBatchPayloadFast(encode transportFastPlainEncoder, secretKey []byte, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
return encodeBulkDedicatedBatchesPayloadFast(encode, secretKey, []bulkDedicatedOutboundBatch{{
DataID: dataID,
Items: items,
}})
}
func encodeBulkDedicatedBatchesPayloadFast(encode transportFastPlainEncoder, secretKey []byte, batches []bulkDedicatedOutboundBatch) ([]byte, error) {
if encode == nil {
return nil, errTransportPayloadEncryptFailed
}
plainLen := bulkDedicatedBatchPlainLen(items)
plainLen := bulkDedicatedBatchesPlainLen(batches)
return encode(secretKey, plainLen, func(dst []byte) error {
return writeBulkDedicatedBatchPlain(dst, dataID, items)
return writeBulkDedicatedBatchesPlain(dst, batches)
})
}
@ -606,29 +716,268 @@ func decodeBulkDedicatedBatchPlain(payload []byte) (uint64, []bulkDedicatedBatch
return dataID, items, true, nil
}
func decodeDedicatedBulkInboundItems(expectedDataID uint64, plain []byte) ([]bulkDedicatedBatchItem, error) {
func decodeBulkDedicatedSuperBatchPlain(payload []byte) ([]bulkDedicatedInboundBatch, bool, error) {
if len(payload) < 4 || string(payload[:4]) != bulkDedicatedSuperBatchMagic {
return nil, false, nil
}
if len(payload) < bulkDedicatedSuperBatchHeaderLen {
return nil, true, errBulkFastPayloadInvalid
}
if payload[4] != bulkDedicatedSuperBatchVersion {
return nil, true, errBulkFastPayloadInvalid
}
groupCount := int(binary.BigEndian.Uint32(payload[8:12]))
if groupCount <= 0 {
return nil, true, errBulkFastPayloadInvalid
}
batches := make([]bulkDedicatedInboundBatch, 0, groupCount)
offset := bulkDedicatedSuperBatchHeaderLen
totalItems := 0
for i := 0; i < groupCount; i++ {
if len(payload)-offset < bulkDedicatedSuperBatchGroupHeaderLen {
return nil, true, errBulkFastPayloadInvalid
}
dataID := binary.BigEndian.Uint64(payload[offset : offset+8])
count := int(binary.BigEndian.Uint32(payload[offset+8 : offset+12]))
offset += bulkDedicatedSuperBatchGroupHeaderLen
if dataID == 0 || count <= 0 {
return nil, true, errBulkFastPayloadInvalid
}
totalItems += count
if totalItems > bulkDedicatedBatchMaxItems {
return nil, true, errBulkFastPayloadInvalid
}
items := make([]bulkDedicatedBatchItem, 0, count)
for j := 0; j < count; j++ {
if len(payload)-offset < bulkDedicatedBatchItemHeaderLen {
return nil, true, errBulkFastPayloadInvalid
}
itemType := payload[offset]
switch itemType {
case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
default:
return nil, true, errBulkFastPayloadInvalid
}
flags := payload[offset+1]
seq := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
dataLen := int(binary.BigEndian.Uint32(payload[offset+12 : offset+16]))
offset += bulkDedicatedBatchItemHeaderLen
if dataLen < 0 || len(payload)-offset < dataLen {
return nil, true, errBulkFastPayloadInvalid
}
items = append(items, bulkDedicatedBatchItem{
Type: itemType,
Flags: flags,
Seq: seq,
Payload: payload[offset : offset+dataLen],
})
offset += dataLen
}
batches = append(batches, bulkDedicatedInboundBatch{
DataID: dataID,
Items: items,
})
}
if offset != len(payload) {
return nil, true, errBulkFastPayloadInvalid
}
return batches, true, nil
}
func decodeDedicatedBulkInboundBatches(plain []byte) ([]bulkDedicatedInboundBatch, error) {
if batches, matched, err := decodeBulkDedicatedSuperBatchPlain(plain); matched {
if err != nil {
return nil, err
}
return batches, nil
}
if dataID, items, matched, err := decodeBulkDedicatedBatchPlain(plain); matched {
if err != nil {
return nil, err
}
if expectedDataID == 0 || dataID != expectedDataID {
return nil, errBulkFastPayloadInvalid
}
return items, nil
return []bulkDedicatedInboundBatch{{
DataID: dataID,
Items: items,
}}, nil
}
frame, matched, err := decodeBulkFastFrame(plain)
if err != nil {
return nil, err
}
if !matched || expectedDataID == 0 || frame.DataID != expectedDataID {
if !matched || frame.DataID == 0 {
return nil, errBulkFastPayloadInvalid
}
return []bulkDedicatedBatchItem{{
return []bulkDedicatedInboundBatch{{
DataID: frame.DataID,
Items: []bulkDedicatedBatchItem{{
Type: frame.Type,
Flags: frame.Flags,
Seq: frame.Seq,
Payload: frame.Payload,
}},
}}, nil
}
func decodeDedicatedBulkInboundPayload(plain []byte) (uint64, []bulkDedicatedBatchItem, error) {
batches, err := decodeDedicatedBulkInboundBatches(plain)
if err != nil {
return 0, nil, err
}
if len(batches) != 1 {
return 0, nil, errBulkFastPayloadInvalid
}
return batches[0].DataID, batches[0].Items, nil
}
func decodeDedicatedBulkInboundItems(expectedDataID uint64, plain []byte) ([]bulkDedicatedBatchItem, error) {
dataID, items, err := decodeDedicatedBulkInboundPayload(plain)
if err != nil {
return nil, err
}
if expectedDataID == 0 || dataID != expectedDataID {
return nil, errBulkFastPayloadInvalid
}
return items, nil
}
func hasBulkDedicatedMagic(payload []byte, magic string) bool {
return len(payload) >= 4 &&
payload[0] == magic[0] &&
payload[1] == magic[1] &&
payload[2] == magic[2] &&
payload[3] == magic[3]
}
func walkDedicatedBulkInboundPayload(plain []byte, visit func(dataID uint64, item bulkDedicatedBatchItem) error) error {
if visit == nil {
return errBulkFastPayloadInvalid
}
if hasBulkDedicatedMagic(plain, bulkDedicatedSuperBatchMagic) {
return walkDedicatedBulkInboundSuperBatchPlain(plain, visit)
}
if hasBulkDedicatedMagic(plain, bulkDedicatedBatchMagic) {
return walkDedicatedBulkInboundBatchPlain(plain, visit)
}
frame, matched, err := decodeBulkFastFrame(plain)
if err != nil {
return err
}
if !matched || frame.DataID == 0 {
return errBulkFastPayloadInvalid
}
return visit(frame.DataID, bulkDedicatedBatchItem{
Type: frame.Type,
Flags: frame.Flags,
Seq: frame.Seq,
Payload: frame.Payload,
}}, nil
})
}
// walkDedicatedBulkInboundBatchPlain parses a single-dataID batch payload
// (fixed header, then `count` item records) and invokes visit for each item.
// Any structural mismatch — short buffer, bad version, unknown item type, or
// trailing bytes — yields errBulkFastPayloadInvalid.
func walkDedicatedBulkInboundBatchPlain(payload []byte, visit func(dataID uint64, item bulkDedicatedBatchItem) error) error {
	if len(payload) < bulkDedicatedBatchHeaderLen {
		return errBulkFastPayloadInvalid
	}
	// Byte 4 carries the batch format version (bytes 0-3 are the magic).
	if payload[4] != bulkDedicatedBatchVersion {
		return errBulkFastPayloadInvalid
	}
	dataID := binary.BigEndian.Uint64(payload[8:16])
	count := int(binary.BigEndian.Uint32(payload[16:20]))
	if dataID == 0 || count <= 0 {
		return errBulkFastPayloadInvalid
	}
	offset := bulkDedicatedBatchHeaderLen
	for i := 0; i < count; i++ {
		if len(payload)-offset < bulkDedicatedBatchItemHeaderLen {
			return errBulkFastPayloadInvalid
		}
		itemType := payload[offset]
		// Only the known fast-payload item types are accepted.
		switch itemType {
		case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
		default:
			return errBulkFastPayloadInvalid
		}
		flags := payload[offset+1]
		seq := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
		dataLen := int(binary.BigEndian.Uint32(payload[offset+12 : offset+16]))
		offset += bulkDedicatedBatchItemHeaderLen
		if dataLen < 0 || len(payload)-offset < dataLen {
			return errBulkFastPayloadInvalid
		}
		// The item payload aliases the input buffer (no copy is made here).
		if err := visit(dataID, bulkDedicatedBatchItem{
			Type:    itemType,
			Flags:   flags,
			Seq:     seq,
			Payload: payload[offset : offset+dataLen],
		}); err != nil {
			return err
		}
		offset += dataLen
	}
	// Trailing bytes after the last item indicate a malformed batch.
	if offset != len(payload) {
		return errBulkFastPayloadInvalid
	}
	return nil
}
// walkDedicatedBulkInboundSuperBatchPlain parses a super-batch payload — a
// header followed by `groupCount` per-dataID groups, each holding `count` item
// records — and invokes visit for every item. The aggregate item count is
// capped at bulkDedicatedBatchMaxItems; any structural mismatch yields
// errBulkFastPayloadInvalid.
func walkDedicatedBulkInboundSuperBatchPlain(payload []byte, visit func(dataID uint64, item bulkDedicatedBatchItem) error) error {
	if len(payload) < bulkDedicatedSuperBatchHeaderLen {
		return errBulkFastPayloadInvalid
	}
	// Byte 4 carries the super-batch format version.
	if payload[4] != bulkDedicatedSuperBatchVersion {
		return errBulkFastPayloadInvalid
	}
	groupCount := int(binary.BigEndian.Uint32(payload[8:12]))
	if groupCount <= 0 {
		return errBulkFastPayloadInvalid
	}
	offset := bulkDedicatedSuperBatchHeaderLen
	totalItems := 0
	for i := 0; i < groupCount; i++ {
		if len(payload)-offset < bulkDedicatedSuperBatchGroupHeaderLen {
			return errBulkFastPayloadInvalid
		}
		dataID := binary.BigEndian.Uint64(payload[offset : offset+8])
		count := int(binary.BigEndian.Uint32(payload[offset+8 : offset+12]))
		offset += bulkDedicatedSuperBatchGroupHeaderLen
		if dataID == 0 || count <= 0 {
			return errBulkFastPayloadInvalid
		}
		// Enforce the global item cap before walking the group so a hostile
		// header cannot force excessive iteration.
		totalItems += count
		if totalItems > bulkDedicatedBatchMaxItems {
			return errBulkFastPayloadInvalid
		}
		for j := 0; j < count; j++ {
			if len(payload)-offset < bulkDedicatedBatchItemHeaderLen {
				return errBulkFastPayloadInvalid
			}
			itemType := payload[offset]
			switch itemType {
			case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
			default:
				return errBulkFastPayloadInvalid
			}
			flags := payload[offset+1]
			seq := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
			dataLen := int(binary.BigEndian.Uint32(payload[offset+12 : offset+16]))
			offset += bulkDedicatedBatchItemHeaderLen
			if dataLen < 0 || len(payload)-offset < dataLen {
				return errBulkFastPayloadInvalid
			}
			// The item payload aliases the input buffer (no copy is made).
			if err := visit(dataID, bulkDedicatedBatchItem{
				Type:    itemType,
				Flags:   flags,
				Seq:     seq,
				Payload: payload[offset : offset+dataLen],
			}); err != nil {
				return err
			}
			offset += dataLen
		}
	}
	// Trailing bytes after the last group indicate a malformed payload.
	if offset != len(payload) {
		return errBulkFastPayloadInvalid
	}
	return nil
}
func normalizeDedicatedBulkSendError(err error) error {
@ -643,13 +992,23 @@ func normalizeDedicatedBulkSendError(err error) error {
}
// dispatchDedicatedBulkInboundItem dispatches item to bulk without a release
// callback; see dispatchDedicatedBulkInboundItemWithRelease.
func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchItem) error {
	return dispatchDedicatedBulkInboundItemWithRelease(bulk, item, nil)
}
func dispatchDedicatedBulkInboundItemWithRelease(bulk *bulkHandle, item bulkDedicatedBatchItem, release func()) error {
if bulk == nil {
if release != nil {
release()
}
return io.ErrClosedPipe
}
switch item.Type {
case bulkFastPayloadTypeData:
return bulk.pushOwnedChunkNoReset(item.Payload)
return bulk.pushOwnedChunkWithReleaseNoReset(item.Payload, release)
case bulkFastPayloadTypeClose:
if release != nil {
release()
}
if item.Flags&bulkFastPayloadFlagFullClose != 0 {
bulk.markPeerClosed()
return nil
@ -657,6 +1016,9 @@ func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchI
bulk.markRemoteClosed()
return nil
case bulkFastPayloadTypeReset:
if release != nil {
release()
}
resetErr := errBulkReset
if len(item.Payload) > 0 {
resetErr = bulkRemoteResetError(string(item.Payload))
@ -664,6 +1026,9 @@ func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchI
bulk.markReset(bulkResetError(resetErr))
return nil
case bulkFastPayloadTypeRelease:
if release != nil {
release()
}
bytes, chunks, err := decodeBulkDedicatedReleasePayload(item.Payload)
if err != nil {
return err
@ -671,6 +1036,9 @@ func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchI
bulk.releaseOutboundWindow(bytes, chunks)
return nil
default:
if release != nil {
release()
}
return errBulkFastPayloadInvalid
}
}

View File

@ -0,0 +1,639 @@
package notify
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
)
// bulkDedicatedLaneMicroBatchWait bounds how long the sender lingers waiting
// for additional requests before flushing an in-progress batch.
const bulkDedicatedLaneMicroBatchWait = 50 * time.Microsecond

// bulkDedicatedLaneBatchRequest is one queued (and poolable) send request for
// a lane sender.
type bulkDedicatedLaneBatchRequest struct {
	Ctx      context.Context // request context; nil is treated as background
	DataID   uint64
	Items    []bulkDedicatedSendRequest
	Deadline time.Time // context deadline snapshot, zero when none
	wait     bool      // caller blocks on resultCh when true
	resultCh chan error
	state    bulkDedicatedRequestState
	aborted  atomic.Bool // set when the waiter gave up; the result is dropped
}
// bulkDedicatedLaneBatchRequestPool recycles batch requests to avoid a heap
// allocation per send.
var bulkDedicatedLaneBatchRequestPool sync.Pool

// bulkDedicatedLaneSender serializes batched record writes onto a dedicated
// lane connection via a single background goroutine (see run).
type bulkDedicatedLaneSender struct {
	conn     net.Conn
	encode   func([]bulkDedicatedOutboundBatch) ([]byte, func(), error) // batch encoder; second result is an optional release hook
	fail     func(error)                                                // invoked (in a goroutine) on fatal send failure
	reqCh    chan *bulkDedicatedLaneBatchRequest
	stopCh   chan struct{}
	stopOnce sync.Once
	doneCh   chan struct{}
	flushMu  sync.Mutex   // guards flush; also taken by tryDirectSubmitBatch
	queued   atomic.Int64 // in-flight request count, used by the direct fast path
	errMu    sync.Mutex
	err      error // first fatal error, sticky
}
// newBulkDedicatedLaneSender builds a lane sender over conn and starts its
// background run loop. encode turns outbound batches into a wire payload and
// fail is notified asynchronously on fatal errors.
func newBulkDedicatedLaneSender(conn net.Conn, encode func([]bulkDedicatedOutboundBatch) ([]byte, func(), error), fail func(error)) *bulkDedicatedLaneSender {
	sender := &bulkDedicatedLaneSender{
		conn:   conn,
		encode: encode,
		fail:   fail,
		reqCh:  make(chan *bulkDedicatedLaneBatchRequest, bulkDedicatedSendQueueSize),
		stopCh: make(chan struct{}),
		doneCh: make(chan struct{}),
	}
	go sender.run()
	return sender
}
// getBulkDedicatedLaneBatchRequest returns a reset batch request, reusing a
// pooled instance when one is available.
func getBulkDedicatedLaneBatchRequest() *bulkDedicatedLaneBatchRequest {
	pooled, ok := bulkDedicatedLaneBatchRequestPool.Get().(*bulkDedicatedLaneBatchRequest)
	if !ok || pooled == nil {
		pooled = &bulkDedicatedLaneBatchRequest{
			resultCh: make(chan error, 1),
		}
	}
	pooled.reset()
	return pooled
}
// reset clears the request for reuse: fields are zeroed, the state machine
// returns to queued, and any stale ack is drained from resultCh.
func (r *bulkDedicatedLaneBatchRequest) reset() {
	if r == nil {
		return
	}
	r.Ctx = nil
	r.DataID = 0
	r.Deadline = time.Time{}
	r.wait = false
	r.Items = r.Items[:0] // keep the backing array for reuse
	r.state.value.Store(bulkDedicatedRequestQueued)
	r.aborted.Store(false)
	if r.resultCh != nil {
		// Non-blocking drain of a leftover result from a prior use.
		select {
		case <-r.resultCh:
		default:
		}
	}
}
// prepare resets the request and populates it for a new submission, copying
// items into the request-owned slice and snapshotting the context deadline.
func (r *bulkDedicatedLaneBatchRequest) prepare(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool) {
	if r == nil {
		return
	}
	r.reset()
	r.Ctx = ctx
	r.DataID = dataID
	r.wait = wait
	if deadline, ok := ctx.Deadline(); ok {
		r.Deadline = deadline
	}
	// Reuse the existing backing array when it is large enough.
	if cap(r.Items) < len(items) {
		r.Items = make([]bulkDedicatedSendRequest, len(items))
	} else {
		r.Items = r.Items[:len(items)]
	}
	copy(r.Items, items)
}
// recycle resets the request and returns it to the shared pool.
func (r *bulkDedicatedLaneBatchRequest) recycle() {
	if r == nil {
		return
	}
	r.reset()
	bulkDedicatedLaneBatchRequestPool.Put(r)
}
// submitData enqueues a single data frame for dataID without waiting for the
// write to complete. The payload is copied so the caller may reuse its buffer.
func (s *bulkDedicatedLaneSender) submitData(ctx context.Context, dataID uint64, seq uint64, payload []byte) error {
	if s == nil {
		return errTransportDetached
	}
	items := []bulkDedicatedSendRequest{{
		Type:    bulkFastPayloadTypeData,
		Seq:     seq,
		Payload: append([]byte(nil), payload...),
	}}
	return s.submitBatch(ctx, dataID, items, false)
}
// submitWrite splits payload into chunk-sized data items, packs them into
// size- and count-bounded batches, and submits each batch for dataID with
// sequence numbers starting at startSeq. It returns the number of payload
// bytes submitted; on error the count covers only fully submitted batches.
//
// Fix: removed a dead `start = written` at the loop tail — `start` is
// re-declared at the top of every iteration, so the assignment had no effect.
func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int) (int, error) {
	if s == nil {
		return 0, errTransportDetached
	}
	if len(payload) == 0 {
		return 0, nil
	}
	if chunkSize <= 0 {
		chunkSize = defaultBulkChunkSize
	}
	written := 0
	seq := startSeq
	for written < len(payload) {
		// Stack scratch keeps item headers off the heap for the common case.
		var itemBuf [bulkDedicatedBatchMaxItems]bulkDedicatedSendRequest
		items := itemBuf[:0]
		batchBytes := bulkDedicatedBatchHeaderLen
		start := written // bytes committed before this batch, reported on error
		for written < len(payload) && len(items) < bulkDedicatedBatchMaxItems {
			end := written + chunkSize
			if end > len(payload) {
				end = len(payload)
			}
			itemLen := bulkDedicatedSendRequestLenFromPayloadLen(end - written)
			// Stop growing once the plain-size cap would be exceeded, but
			// always accept the first item so progress is guaranteed.
			if len(items) > 0 && batchBytes+itemLen > bulkDedicatedBatchMaxPlainBytes {
				break
			}
			items = append(items, bulkDedicatedSendRequest{
				Type:    bulkFastPayloadTypeData,
				Seq:     seq,
				Payload: payload[written:end],
			})
			batchBytes += itemLen
			seq++
			written = end
		}
		if len(items) == 0 {
			// Defensive: ensure at least one item is sent per outer iteration.
			end := written + chunkSize
			if end > len(payload) {
				end = len(payload)
			}
			items = append(items, bulkDedicatedSendRequest{
				Type:    bulkFastPayloadTypeData,
				Seq:     seq,
				Payload: payload[written:end],
			})
			seq++
			written = end
		}
		if err := s.submitWriteBatch(ctx, dataID, items); err != nil {
			return start, err
		}
	}
	return written, nil
}
// submitWriteBatch sends items for dataID, preferring the lock-free direct
// path; when that path is busy, the items are copied and queued (the caller's
// slice may alias a stack buffer — see submitWrite).
func (s *bulkDedicatedLaneSender) submitWriteBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest) error {
	if s == nil {
		return errTransportDetached
	}
	if len(items) == 0 {
		return nil
	}
	if submitted, err := s.tryDirectSubmitBatch(ctx, dataID, items); submitted {
		return err
	}
	// Queued path: copy the items since the queue outlives the caller's frame.
	queuedItems := copyBulkDedicatedSendRequests(items)
	return s.submitBatch(ctx, dataID, queuedItems, true)
}
// submitControl enqueues a single control frame (close/reset/release/...) for
// dataID and waits for the write acknowledgment. The payload is copied.
func (s *bulkDedicatedLaneSender) submitControl(ctx context.Context, dataID uint64, frameType uint8, flags uint8, seq uint64, payload []byte) error {
	if s == nil {
		return errTransportDetached
	}
	items := []bulkDedicatedSendRequest{{
		Type:  frameType,
		Flags: flags,
		Seq:   seq,
	}}
	if len(payload) > 0 {
		items[0].Payload = append([]byte(nil), payload...)
	}
	return s.submitBatch(ctx, dataID, items, true)
}
// submitBatch queues a prepared request on the run loop. When wait is true it
// blocks until the write is acknowledged (or the context/sender dies); when
// false it returns as soon as the request is enqueued. The queued counter is
// incremented before the enqueue attempt and rolled back on failure.
func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool) error {
	if s == nil {
		return errTransportDetached
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if err := s.errSnapshot(); err != nil {
		return err
	}
	req := getBulkDedicatedLaneBatchRequest()
	req.prepare(ctx, dataID, items, wait)
	s.queued.Add(1)
	select {
	case <-ctx.Done():
		s.queued.Add(-1)
		req.recycle()
		return normalizeStreamDeadlineError(ctx.Err())
	case <-s.stopCh:
		s.queued.Add(-1)
		req.recycle()
		return s.stoppedErr()
	case s.reqCh <- req:
		// Once enqueued, the run loop owns the queued decrement (finishRequest).
		if !wait {
			return nil
		}
		return s.waitAck(req)
	}
}
// tryDirectSubmitBatch attempts to flush items on the caller's goroutine,
// bypassing the queue. It only proceeds when nothing is queued and the flush
// lock is immediately available; otherwise it reports submitted=false so the
// caller falls back to the queued path. When submitted=true, err is the final
// outcome and a failure also poisons the sender.
func (s *bulkDedicatedLaneSender) tryDirectSubmitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest) (bool, error) {
	if s == nil {
		return true, errTransportDetached
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if len(items) == 0 {
		return true, nil
	}
	if err := s.errSnapshot(); err != nil {
		return true, err
	}
	select {
	case <-ctx.Done():
		return true, normalizeStreamDeadlineError(ctx.Err())
	case <-s.stopCh:
		return true, s.stoppedErr()
	default:
	}
	// Never jump ahead of queued requests: that would reorder writes.
	if s.queued.Load() != 0 {
		return false, nil
	}
	if !s.flushMu.TryLock() {
		return false, nil
	}
	defer s.flushMu.Unlock()
	// Re-check under the lock; a request may have been queued meanwhile.
	if s.queued.Load() != 0 {
		return false, nil
	}
	if err := s.errSnapshot(); err != nil {
		return true, err
	}
	select {
	case <-ctx.Done():
		return true, normalizeStreamDeadlineError(ctx.Err())
	case <-s.stopCh:
		return true, s.stoppedErr()
	default:
	}
	deadline, _ := ctx.Deadline()
	if err := s.flush([]bulkDedicatedOutboundBatch{{
		DataID: dataID,
		Items:  items,
	}}, deadline); err != nil {
		// A direct-path failure is fatal for the sender, same as in run().
		err = normalizeDedicatedBulkSendError(err)
		s.setErr(err)
		s.failPending(err)
		if s.fail != nil {
			go s.fail(err)
		}
		return true, err
	}
	return true, nil
}
// waitAck blocks until the run loop acknowledges req or the request context
// ends. If cancellation wins the race (tryCancel succeeds) the request is
// marked aborted and left for the run loop to recycle; otherwise the ack that
// is already in flight is consumed.
func (s *bulkDedicatedLaneSender) waitAck(req *bulkDedicatedLaneBatchRequest) error {
	if s == nil {
		return errTransportDetached
	}
	if req == nil {
		return errTransportDetached
	}
	ctx := req.Ctx
	if ctx == nil {
		ctx = context.Background()
	}
	select {
	case err := <-req.resultCh:
		req.recycle()
		return normalizeDedicatedBulkSendError(err)
	case <-ctx.Done():
		if req.tryCancel() {
			// We won the race: the run loop will observe aborted and recycle.
			req.aborted.Store(true)
			return normalizeStreamDeadlineError(ctx.Err())
		}
		// The run loop already started the request; wait for its result.
		err := <-req.resultCh
		req.recycle()
		return normalizeDedicatedBulkSendError(err)
	}
}
// stop poisons the sender, signals the run loop to exit, and waits for it to
// finish. Safe to call multiple times.
func (s *bulkDedicatedLaneSender) stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		s.setErr(errTransportDetached)
		close(s.stopCh)
	})
	<-s.doneCh
}
// run is the sender's single writer goroutine: it pulls requests, micro-batches
// compatible follow-ups (collectBatchRequests), flushes them as one record, and
// acknowledges every participant. A flush failure poisons the sender, fails all
// pending requests, notifies fail, and terminates the loop.
func (s *bulkDedicatedLaneSender) run() {
	defer close(s.doneCh)
	// carry holds a request accepted by collectBatchRequests that did not fit
	// into the previous batch; it leads the next iteration.
	var carry *bulkDedicatedLaneBatchRequest
	for {
		req, ok := s.nextRequest(carry)
		carry = nil
		if !ok {
			return
		}
		if !req.tryStart() {
			// Lost the race to a canceling waiter.
			s.finishRequest(req, req.canceledErr())
			continue
		}
		if err := req.contextErr(); err != nil {
			s.finishRequest(req, err)
			continue
		}
		batchReqs := []*bulkDedicatedLaneBatchRequest{req}
		batches := []bulkDedicatedOutboundBatch{{
			DataID: req.DataID,
			Items:  req.Items,
		}}
		batchBytes := bulkDedicatedBatchesPlainLen(batches)
		deadline := req.Deadline
		s.flushMu.Lock()
		err := s.errSnapshot()
		if err == nil {
			carry, err = s.collectBatchRequests(&batchReqs, &batches, &batchBytes, &deadline)
			if err == nil {
				err = s.flush(batches, deadline)
			}
		}
		s.flushMu.Unlock()
		if err != nil {
			err = normalizeDedicatedBulkSendError(err)
			s.setErr(err)
			s.finishBatchRequests(batchReqs, err)
			if carry != nil {
				s.finishRequest(carry, err)
				carry = nil
			}
			s.failPending(err)
			if s.fail != nil {
				go s.fail(err)
			}
			return
		}
		s.finishBatchRequests(batchReqs, nil)
	}
}
// appendBulkDedicatedLaneBatch merges req's items into the last batch when it
// shares the same DataID, otherwise appends a new batch.
func appendBulkDedicatedLaneBatch(batches *[]bulkDedicatedOutboundBatch, req *bulkDedicatedLaneBatchRequest) {
	if req == nil {
		return
	}
	if len(*batches) > 0 && (*batches)[len(*batches)-1].DataID == req.DataID {
		last := &(*batches)[len(*batches)-1]
		last.Items = append(last.Items, req.Items...)
		return
	}
	*batches = append(*batches, bulkDedicatedOutboundBatch{
		DataID: req.DataID,
		Items:  req.Items,
	})
}
// bulkDedicatedLaneBatchItemCount returns the total number of items held
// across all batches.
func bulkDedicatedLaneBatchItemCount(batches []bulkDedicatedOutboundBatch) int {
	count := 0
	for i := range batches {
		count += len(batches[i].Items)
	}
	return count
}
// bulkDedicatedLaneNextBatchBytes predicts the plain payload size after adding
// req to batches. Adding a second distinct DataID switches the encoding from
// the single-batch format to the super-batch format, which re-bases the size.
func bulkDedicatedLaneNextBatchBytes(batches []bulkDedicatedOutboundBatch, req *bulkDedicatedLaneBatchRequest, currentBytes int) int {
	reqBytes := bulkDedicatedSendRequestsLen(req.Items)
	if len(batches) == 0 {
		return bulkDedicatedBatchHeaderLen + reqBytes
	}
	last := batches[len(batches)-1]
	if last.DataID == req.DataID {
		// Same group: only the item bytes are added.
		return currentBytes + reqBytes
	}
	if len(batches) == 1 {
		// Second DataID: recompute under the super-batch layout.
		return bulkDedicatedSuperBatchHeaderLen +
			bulkDedicatedSuperBatchGroupHeaderLen + bulkDedicatedSendRequestsLen(batches[0].Items) +
			bulkDedicatedSuperBatchGroupHeaderLen + reqBytes
	}
	// Already super-batch: add one group header plus the item bytes.
	return currentBytes + bulkDedicatedSuperBatchGroupHeaderLen + reqBytes
}
// contextErr returns the normalized deadline error when the request context is
// already done, or nil otherwise (including for requests without a context).
//
// Fix: added the nil-receiver guard for consistency with the sibling methods
// tryStart, tryCancel, and canceledErr on this type.
func (r *bulkDedicatedLaneBatchRequest) contextErr() error {
	if r == nil || r.Ctx == nil {
		return nil
	}
	select {
	case <-r.Ctx.Done():
		return normalizeStreamDeadlineError(r.Ctx.Err())
	default:
		return nil
	}
}
// tryStart atomically moves the request from queued to started; it reports
// false when a canceling waiter already claimed the request.
func (r *bulkDedicatedLaneBatchRequest) tryStart() bool {
	if r == nil {
		return false
	}
	return r.state.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestStarted)
}
// tryCancel atomically moves the request from queued to canceled; it reports
// false when the run loop already started processing the request.
func (r *bulkDedicatedLaneBatchRequest) tryCancel() bool {
	if r == nil {
		return false
	}
	return r.state.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestCanceled)
}
// canceledErr returns the most specific cancellation error available: the
// context's own error when it is done, otherwise context.Canceled.
func (r *bulkDedicatedLaneBatchRequest) canceledErr() error {
	if r == nil {
		return context.Canceled
	}
	if err := r.contextErr(); err != nil {
		return err
	}
	return context.Canceled
}
// nextRequest yields the next request for the run loop — a carried-over
// request first, then the queue — and reports ok=false when the sender is
// stopping (after failing the carry and draining pending requests).
func (s *bulkDedicatedLaneSender) nextRequest(carry *bulkDedicatedLaneBatchRequest) (*bulkDedicatedLaneBatchRequest, bool) {
	if carry != nil {
		select {
		case <-s.stopCh:
			err := s.stoppedErr()
			s.finishRequest(carry, err)
			s.failPending(err)
			return nil, false
		default:
			return carry, true
		}
	}
	select {
	case <-s.stopCh:
		s.failPending(s.stoppedErr())
		return nil, false
	case req := <-s.reqCh:
		return req, true
	}
}
// collectBatchRequests opportunistically drains more queued requests into the
// current batch, bounded by bulkDedicatedBatchMaxItems and
// bulkDedicatedBatchMaxPlainBytes, waiting at most one micro-batch interval
// when the queue is momentarily empty. A request that no longer fits is
// returned as the carry for the next iteration; canceled requests are finished
// inline. It merges each accepted request's deadline via earlierDeadline.
// Called with flushMu held by run().
func (s *bulkDedicatedLaneSender) collectBatchRequests(batchReqs *[]*bulkDedicatedLaneBatchRequest, batches *[]bulkDedicatedOutboundBatch, batchBytes *int, deadline *time.Time) (*bulkDedicatedLaneBatchRequest, error) {
	if s == nil || bulkDedicatedLaneBatchItemCount(*batches) >= bulkDedicatedBatchMaxItems || *batchBytes >= bulkDedicatedBatchMaxPlainBytes {
		return nil, nil
	}
	wait := bulkDedicatedLaneMicroBatchWait
	var (
		timer   *time.Timer
		timerCh <-chan time.Time
		waited  bool
	)
	if wait > 0 {
		timer = time.NewTimer(wait)
		timerCh = timer.C
		defer timer.Stop()
	}
	for {
		if bulkDedicatedLaneBatchItemCount(*batches) >= bulkDedicatedBatchMaxItems || *batchBytes >= bulkDedicatedBatchMaxPlainBytes {
			return nil, nil
		}
		var (
			req *bulkDedicatedLaneBatchRequest
			ok  bool
		)
		// Non-blocking poll first; only wait (once) when the queue is empty.
		select {
		case <-s.stopCh:
			return nil, s.stoppedErr()
		case req = <-s.reqCh:
			ok = true
		default:
		}
		if !ok {
			if waited || timerCh == nil {
				return nil, nil
			}
			waited = true
			select {
			case <-s.stopCh:
				return nil, s.stoppedErr()
			case req = <-s.reqCh:
			case <-timerCh:
				return nil, nil
			}
		}
		if err := req.contextErr(); err != nil {
			// Finish expired requests without touching the batch.
			if req.tryCancel() {
				s.finishRequest(req, err)
				continue
			}
			s.finishRequest(req, req.canceledErr())
			continue
		}
		// Check both caps BEFORE claiming the request so an overflowing
		// request can be carried into the next batch untouched.
		nextItems := bulkDedicatedLaneBatchItemCount(*batches) + len(req.Items)
		if nextItems > bulkDedicatedBatchMaxItems {
			return req, nil
		}
		nextBytes := bulkDedicatedLaneNextBatchBytes(*batches, req, *batchBytes)
		if nextBytes > bulkDedicatedBatchMaxPlainBytes {
			return req, nil
		}
		if !req.tryStart() {
			s.finishRequest(req, req.canceledErr())
			continue
		}
		if err := req.contextErr(); err != nil {
			s.finishRequest(req, err)
			continue
		}
		*batchReqs = append(*batchReqs, req)
		appendBulkDedicatedLaneBatch(batches, req)
		*batchBytes = nextBytes
		*deadline = earlierDeadline(*deadline, req.Deadline)
	}
}
func earlierDeadline(current time.Time, next time.Time) time.Time {
if current.IsZero() {
return next
}
if next.IsZero() || current.Before(next) {
return current
}
return next
}
// flush encodes batches into one wire payload and writes it to the lane
// connection honoring deadline. The encoder's release hook (when present) is
// invoked after the write.
func (s *bulkDedicatedLaneSender) flush(batches []bulkDedicatedOutboundBatch, deadline time.Time) error {
	if s == nil || s.conn == nil {
		return errTransportDetached
	}
	payload, release, err := s.encode(batches)
	if err != nil {
		return err
	}
	if release != nil {
		defer release()
	}
	return writeBulkDedicatedRecordWithDeadline(s.conn, payload, deadline)
}
// finishRequest completes one request: the queued counter is decremented, then
// either the waiter is acknowledged (and is responsible for recycling) or the
// request is recycled here. Aborted requests are recycled silently because the
// waiter already returned.
func (s *bulkDedicatedLaneSender) finishRequest(req *bulkDedicatedLaneBatchRequest, err error) {
	if s != nil {
		s.queued.Add(-1)
	}
	if req == nil {
		return
	}
	if req.wait && req.resultCh != nil {
		if req.aborted.Load() {
			req.recycle()
			return
		}
		// resultCh is buffered(1); the waiting caller recycles after receiving.
		req.resultCh <- err
		return
	}
	req.recycle()
}
// finishBatchRequests completes every request in reqs with the same result.
func (s *bulkDedicatedLaneSender) finishBatchRequests(reqs []*bulkDedicatedLaneBatchRequest, err error) {
	for _, req := range reqs {
		s.finishRequest(req, err)
	}
}
// failPending drains every request currently sitting in the queue and fails
// it with err; it returns once the queue is momentarily empty.
func (s *bulkDedicatedLaneSender) failPending(err error) {
	for {
		select {
		case item := <-s.reqCh:
			s.finishRequest(item, err)
		default:
			return
		}
	}
}
// setErr records err as the sender's sticky fatal error; only the first
// non-nil error is kept.
func (s *bulkDedicatedLaneSender) setErr(err error) {
	if s == nil || err == nil {
		return
	}
	s.errMu.Lock()
	if s.err == nil {
		s.err = err
	}
	s.errMu.Unlock()
}
// errSnapshot returns the sender's sticky fatal error, or nil when healthy.
func (s *bulkDedicatedLaneSender) errSnapshot() error {
	if s == nil {
		return errTransportDetached
	}
	s.errMu.Lock()
	defer s.errMu.Unlock()
	return s.err
}
// stoppedErr returns the sticky fatal error if any, falling back to
// errTransportDetached for a sender stopped without a recorded error.
func (s *bulkDedicatedLaneSender) stoppedErr() error {
	if err := s.errSnapshot(); err != nil {
		return err
	}
	return errTransportDetached
}

View File

@ -0,0 +1,145 @@
package notify
import (
"bytes"
"context"
"testing"
"time"
)
// TestBulkDedicatedLaneSenderCollectBatchRequestsBatchesAcrossDataIDs checks
// that collectBatchRequests merges a same-dataID request into the existing
// batch, appends a new group for a different dataID, and keeps the earliest
// deadline of all merged requests.
func TestBulkDedicatedLaneSenderCollectBatchRequestsBatchesAcrossDataIDs(t *testing.T) {
	sender := &bulkDedicatedLaneSender{
		reqCh:  make(chan *bulkDedicatedLaneBatchRequest, 3),
		stopCh: make(chan struct{}),
	}
	// first seeds the batch; second shares its dataID, third uses a new one.
	first := &bulkDedicatedLaneBatchRequest{
		Ctx:    context.Background(),
		DataID: 7,
		Items: []bulkDedicatedSendRequest{{
			Type:    bulkFastPayloadTypeData,
			Seq:     1,
			Payload: []byte("a"),
		}},
		Deadline: time.Now().Add(2 * time.Second),
	}
	if !first.tryStart() {
		t.Fatal("first request should start")
	}
	second := &bulkDedicatedLaneBatchRequest{
		Ctx:    context.Background(),
		DataID: 7,
		Items: []bulkDedicatedSendRequest{{
			Type:    bulkFastPayloadTypeData,
			Seq:     2,
			Payload: []byte("b"),
		}},
		Deadline: time.Now().Add(time.Second),
	}
	third := &bulkDedicatedLaneBatchRequest{
		Ctx:    context.Background(),
		DataID: 8,
		Items: []bulkDedicatedSendRequest{{
			Type:    bulkFastPayloadTypeData,
			Seq:     3,
			Payload: []byte("c"),
		}},
		Deadline: time.Now().Add(3 * time.Second),
	}
	sender.reqCh <- second
	sender.reqCh <- third
	batchReqs := []*bulkDedicatedLaneBatchRequest{first}
	batches := []bulkDedicatedOutboundBatch{{
		DataID: first.DataID,
		Items:  first.Items,
	}}
	batchBytes := bulkDedicatedBatchesPlainLen(batches)
	deadline := first.Deadline
	carry, err := sender.collectBatchRequests(&batchReqs, &batches, &batchBytes, &deadline)
	if err != nil {
		t.Fatalf("collectBatchRequests error = %v", err)
	}
	if carry != nil {
		t.Fatalf("carry request = %+v, want nil", carry)
	}
	if len(batchReqs) != 3 {
		t.Fatalf("batched request count = %d, want 3", len(batchReqs))
	}
	if len(batches) != 2 {
		t.Fatalf("batched group count = %d, want 2", len(batches))
	}
	if got, want := batches[0].DataID, uint64(7); got != want {
		t.Fatalf("first batch dataID = %d, want %d", got, want)
	}
	if got, want := len(batches[0].Items), 2; got != want {
		t.Fatalf("first batch item count = %d, want %d", got, want)
	}
	if got, want := batches[1].DataID, uint64(8); got != want {
		t.Fatalf("second batch dataID = %d, want %d", got, want)
	}
	if got, want := len(batches[1].Items), 1; got != want {
		t.Fatalf("second batch item count = %d, want %d", got, want)
	}
	if got, want := batches[0].Items[0].Seq, uint64(1); got != want {
		t.Fatalf("first batch seq = %d, want %d", got, want)
	}
	if got, want := batches[0].Items[1].Seq, uint64(2); got != want {
		t.Fatalf("second batch seq = %d, want %d", got, want)
	}
	if got, want := batches[1].Items[0].Seq, uint64(3); got != want {
		t.Fatalf("third batch seq = %d, want %d", got, want)
	}
	// second carries the earliest deadline of the three requests.
	if !deadline.Equal(second.Deadline) {
		t.Fatalf("merged deadline = %v, want %v", deadline, second.Deadline)
	}
}
// TestDedicatedBulkSuperBatchRoundTrip encodes a two-group super batch (data
// items plus close/release control items) and verifies the decoder restores
// dataIDs, item counts, payloads, types, and flags.
func TestDedicatedBulkSuperBatchRoundTrip(t *testing.T) {
	batches := []bulkDedicatedOutboundBatch{
		{
			DataID: 7,
			Items: []bulkDedicatedSendRequest{
				{Type: bulkFastPayloadTypeData, Seq: 1, Payload: []byte("alpha")},
				{Type: bulkFastPayloadTypeData, Seq: 2, Payload: []byte("beta")},
			},
		},
		{
			DataID: 8,
			Items: []bulkDedicatedSendRequest{
				{Type: bulkFastPayloadTypeClose, Flags: bulkFastPayloadFlagFullClose, Seq: 3},
				{Type: bulkFastPayloadTypeRelease, Seq: 4, Payload: []byte{1, 2, 3, 4}},
			},
		},
	}
	plain, err := encodeBulkDedicatedBatchesPlain(batches)
	if err != nil {
		t.Fatalf("encodeBulkDedicatedBatchesPlain error = %v", err)
	}
	got, err := decodeDedicatedBulkInboundBatches(plain)
	if err != nil {
		t.Fatalf("decodeDedicatedBulkInboundBatches error = %v", err)
	}
	if len(got) != 2 {
		t.Fatalf("decoded batch count = %d, want 2", len(got))
	}
	if got[0].DataID != 7 || got[1].DataID != 8 {
		t.Fatalf("decoded dataIDs = %d,%d, want 7,8", got[0].DataID, got[1].DataID)
	}
	if len(got[0].Items) != 2 || len(got[1].Items) != 2 {
		t.Fatalf("decoded item counts = %d,%d, want 2,2", len(got[0].Items), len(got[1].Items))
	}
	if !bytes.Equal(got[0].Items[0].Payload, []byte("alpha")) || !bytes.Equal(got[0].Items[1].Payload, []byte("beta")) {
		t.Fatalf("decoded first batch payloads mismatch: %+v", got[0].Items)
	}
	if got[1].Items[0].Type != bulkFastPayloadTypeClose || got[1].Items[0].Flags != bulkFastPayloadFlagFullClose {
		t.Fatalf("decoded close item mismatch: %+v", got[1].Items[0])
	}
	if got[1].Items[1].Type != bulkFastPayloadTypeRelease || !bytes.Equal(got[1].Items[1].Payload, []byte{1, 2, 3, 4}) {
		t.Fatalf("decoded release item mismatch: %+v", got[1].Items[1])
	}
}

451
bulk_dedicated_sidecar.go Normal file
View File

@ -0,0 +1,451 @@
package notify
import (
"context"
"net"
"sync"
)
// bulkDedicatedSidecar pairs a dedicated-lane connection with its lazily
// created lane sender.
type bulkDedicatedSidecar struct {
	laneID    uint32
	conn      net.Conn
	closeOnce sync.Once
	senderMu  sync.Mutex // guards sender creation/snapshot
	sender    *bulkDedicatedLaneSender
}

// bulkDedicatedLane tracks one client-side lane: its sidecar (when attached),
// an in-flight attach, and how many bulks currently use it.
type bulkDedicatedLane struct {
	id           uint32
	activeBulks  int
	sidecar      *bulkDedicatedSidecar
	attachFlight *bulkDedicatedAttachFlight
}

// bulkDedicatedAttachFlight is a single-flight latch for an in-progress
// sidecar attach: finish publishes the result and wait blocks on it.
type bulkDedicatedAttachFlight struct {
	done chan struct{}
	once sync.Once
	err  error
}
// normalizeBulkDedicatedLaneID maps the reserved lane ID 0 to the default
// lane 1 and leaves every other ID unchanged.
func normalizeBulkDedicatedLaneID(laneID uint32) uint32 {
	if laneID != 0 {
		return laneID
	}
	return 1
}
// newBulkDedicatedSidecar wraps conn as a sidecar on the normalized lane ID.
// A nil connection yields a nil sidecar.
func newBulkDedicatedSidecar(conn net.Conn, laneID uint32) *bulkDedicatedSidecar {
	if conn == nil {
		return nil
	}
	sidecar := &bulkDedicatedSidecar{
		conn:   conn,
		laneID: normalizeBulkDedicatedLaneID(laneID),
	}
	return sidecar
}
// newBulkDedicatedAttachFlight returns an unfinished attach flight.
func newBulkDedicatedAttachFlight() *bulkDedicatedAttachFlight {
	return &bulkDedicatedAttachFlight{
		done: make(chan struct{}),
	}
}
// finish publishes the attach result exactly once and releases all waiters.
func (f *bulkDedicatedAttachFlight) finish(err error) {
	if f == nil {
		return
	}
	f.once.Do(func() {
		f.err = err
		close(f.done)
	})
}
// wait blocks until the flight finishes or ctx is done, returning the attach
// result or the context error. A nil flight waits for nothing.
func (f *bulkDedicatedAttachFlight) wait(ctx context.Context) error {
	if f == nil {
		return nil
	}
	if ctx == nil {
		ctx = context.Background()
	}
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-f.done:
		return f.err
	}
}
// close stops the lane sender (if one exists) and closes the connection.
// Idempotent.
func (s *bulkDedicatedSidecar) close() {
	if s == nil {
		return
	}
	s.closeOnce.Do(func() {
		if sender := s.laneSenderSnapshot(); sender != nil {
			sender.stop()
		}
		if s.conn != nil {
			_ = s.conn.Close() // best-effort; sidecar is going away regardless
		}
	})
}
// laneSenderSnapshot returns the current lane sender, or nil when none has
// been created yet.
func (s *bulkDedicatedSidecar) laneSenderSnapshot() *bulkDedicatedLaneSender {
	if s == nil {
		return nil
	}
	s.senderMu.Lock()
	defer s.senderMu.Unlock()
	return s.sender
}
// laneSenderWithFactory returns the sidecar's lane sender, creating it via
// factory on first use. Creation happens under senderMu so at most one sender
// is ever built per sidecar.
func (s *bulkDedicatedSidecar) laneSenderWithFactory(factory func(net.Conn) *bulkDedicatedLaneSender) *bulkDedicatedLaneSender {
	if s == nil || factory == nil {
		return nil
	}
	s.senderMu.Lock()
	defer s.senderMu.Unlock()
	if s.sender != nil {
		return s.sender
	}
	if s.conn == nil {
		return nil
	}
	s.sender = factory(s.conn)
	return s.sender
}
// clientDedicatedSidecarSnapshot returns the attached sidecar with the lowest
// lane ID, or nil when none is attached.
func (c *ClientCommon) clientDedicatedSidecarSnapshot() *bulkDedicatedSidecar {
	if c == nil {
		return nil
	}
	c.bulkDedicatedSidecarMu.Lock()
	defer c.bulkDedicatedSidecarMu.Unlock()
	return firstClientDedicatedSidecarLocked(c.bulkDedicatedLanes)
}
// firstClientDedicatedSidecarLocked picks the sidecar with the lowest lane ID
// from lanes (map iteration order is random, so the minimum is tracked
// explicitly). Caller must hold bulkDedicatedSidecarMu.
func firstClientDedicatedSidecarLocked(lanes map[uint32]*bulkDedicatedLane) *bulkDedicatedSidecar {
	var (
		selected *bulkDedicatedSidecar
		bestID   uint32
	)
	for laneID, lane := range lanes {
		if lane == nil || lane.sidecar == nil {
			continue
		}
		if selected == nil || laneID < bestID {
			selected = lane.sidecar
			bestID = laneID
		}
	}
	return selected
}
// reserveBulkDedicatedLane picks (or creates) the lane a new bulk should use
// and increments its active count. It prefers the least-loaded lane (ties go
// to the lowest ID) but opens a fresh lane while the configured lane limit
// allows it and every existing lane is already busy.
func (c *ClientCommon) reserveBulkDedicatedLane() uint32 {
	if c == nil {
		return normalizeBulkDedicatedLaneID(0)
	}
	c.bulkDedicatedSidecarMu.Lock()
	defer c.bulkDedicatedSidecarMu.Unlock()
	if c.bulkDedicatedLanes == nil {
		c.bulkDedicatedLanes = make(map[uint32]*bulkDedicatedLane)
	}
	limit := c.bulkDedicatedLaneLimitSnapshot()
	var best *bulkDedicatedLane
	for _, lane := range c.bulkDedicatedLanes {
		if lane == nil {
			continue
		}
		if best == nil || lane.activeBulks < best.activeBulks || (lane.activeBulks == best.activeBulks && lane.id < best.id) {
			best = lane
		}
	}
	// Open a new lane when none exists, or when capacity remains and the
	// least-loaded lane is still in use.
	if best == nil || ((limit <= 0 || len(c.bulkDedicatedLanes) < limit) && best.activeBulks > 0) {
		c.bulkDedicatedNextLaneID++
		laneID := normalizeBulkDedicatedLaneID(c.bulkDedicatedNextLaneID)
		best = &bulkDedicatedLane{id: laneID}
		c.bulkDedicatedLanes[laneID] = best
	}
	best.activeBulks++
	return best.id
}
// releaseBulkDedicatedLane decrements a lane's active count and removes the
// lane entry once it is idle with no sidecar and no attach in flight.
func (c *ClientCommon) releaseBulkDedicatedLane(laneID uint32) {
	if c == nil {
		return
	}
	laneID = normalizeBulkDedicatedLaneID(laneID)
	c.bulkDedicatedSidecarMu.Lock()
	defer c.bulkDedicatedSidecarMu.Unlock()
	lane := c.bulkDedicatedLanes[laneID]
	if lane == nil {
		return
	}
	if lane.activeBulks > 0 {
		lane.activeBulks--
	}
	if lane.activeBulks == 0 && lane.sidecar == nil && lane.attachFlight == nil {
		delete(c.bulkDedicatedLanes, laneID)
	}
}
// clientDedicatedSidecarSnapshotForLane returns the sidecar attached to the
// given (normalized) lane, or nil.
func (c *ClientCommon) clientDedicatedSidecarSnapshotForLane(laneID uint32) *bulkDedicatedSidecar {
	if c == nil {
		return nil
	}
	laneID = normalizeBulkDedicatedLaneID(laneID)
	c.bulkDedicatedSidecarMu.Lock()
	defer c.bulkDedicatedSidecarMu.Unlock()
	if lane := c.bulkDedicatedLanes[laneID]; lane != nil {
		return lane.sidecar
	}
	return nil
}
// beginClientDedicatedSidecarAttach starts (or joins) a sidecar attach for the
// lane. It returns exactly one of: an already-attached sidecar; an existing
// flight to wait on (owner=false); or a fresh flight the caller now owns
// (owner=true) and must complete via finishClientDedicatedSidecarAttach.
func (c *ClientCommon) beginClientDedicatedSidecarAttach(laneID uint32) (*bulkDedicatedSidecar, *bulkDedicatedAttachFlight, bool) {
	if c == nil {
		return nil, nil, false
	}
	laneID = normalizeBulkDedicatedLaneID(laneID)
	c.bulkDedicatedSidecarMu.Lock()
	defer c.bulkDedicatedSidecarMu.Unlock()
	if c.bulkDedicatedLanes == nil {
		c.bulkDedicatedLanes = make(map[uint32]*bulkDedicatedLane)
	}
	lane := c.bulkDedicatedLanes[laneID]
	if lane == nil {
		lane = &bulkDedicatedLane{id: laneID}
		c.bulkDedicatedLanes[laneID] = lane
	}
	if lane.sidecar != nil {
		return lane.sidecar, nil, false
	}
	if lane.attachFlight != nil {
		// Another goroutine is already attaching; join its flight.
		return nil, lane.attachFlight, false
	}
	flight := newBulkDedicatedAttachFlight()
	lane.attachFlight = flight
	return nil, flight, true
}
// finishClientDedicatedSidecarAttach clears the lane's attach flight (if it is
// still this one), prunes the lane when it ended up unused, and then publishes
// err to all flight waiters.
func (c *ClientCommon) finishClientDedicatedSidecarAttach(laneID uint32, flight *bulkDedicatedAttachFlight, err error) {
	if c == nil || flight == nil {
		return
	}
	laneID = normalizeBulkDedicatedLaneID(laneID)
	c.bulkDedicatedSidecarMu.Lock()
	if lane := c.bulkDedicatedLanes[laneID]; lane != nil && lane.attachFlight == flight {
		lane.attachFlight = nil
		if lane.activeBulks == 0 && lane.sidecar == nil {
			delete(c.bulkDedicatedLanes, laneID)
		}
	}
	c.bulkDedicatedSidecarMu.Unlock()
	// Release waiters outside the lock.
	flight.finish(err)
}
// installClientDedicatedSidecar attaches sidecar to the lane unless one is
// already attached; it returns the effective sidecar and whether the install
// took effect (false means the caller's sidecar lost the race).
func (c *ClientCommon) installClientDedicatedSidecar(laneID uint32, sidecar *bulkDedicatedSidecar) (*bulkDedicatedSidecar, bool) {
	if c == nil || sidecar == nil {
		return nil, false
	}
	laneID = normalizeBulkDedicatedLaneID(laneID)
	c.bulkDedicatedSidecarMu.Lock()
	defer c.bulkDedicatedSidecarMu.Unlock()
	if c.bulkDedicatedLanes == nil {
		c.bulkDedicatedLanes = make(map[uint32]*bulkDedicatedLane)
	}
	lane := c.bulkDedicatedLanes[laneID]
	if lane == nil {
		lane = &bulkDedicatedLane{id: laneID}
		c.bulkDedicatedLanes[laneID] = lane
	}
	if lane.sidecar != nil {
		return lane.sidecar, false
	}
	lane.sidecar = sidecar
	return sidecar, true
}
// clearClientDedicatedSidecar detaches sidecar from the lane only if it is
// still the attached one, pruning the lane when it becomes fully idle. It
// reports whether the detach happened.
func (c *ClientCommon) clearClientDedicatedSidecar(laneID uint32, sidecar *bulkDedicatedSidecar) bool {
	if c == nil || sidecar == nil {
		return false
	}
	laneID = normalizeBulkDedicatedLaneID(laneID)
	c.bulkDedicatedSidecarMu.Lock()
	defer c.bulkDedicatedSidecarMu.Unlock()
	lane := c.bulkDedicatedLanes[laneID]
	if lane == nil || lane.sidecar != sidecar {
		return false
	}
	lane.sidecar = nil
	if lane.activeBulks == 0 && lane.attachFlight == nil {
		delete(c.bulkDedicatedLanes, laneID)
	}
	return true
}
// closeClientDedicatedSidecar detaches every lane under the lock, then closes
// all sidecars and fails all pending attach flights outside the lock.
func (c *ClientCommon) closeClientDedicatedSidecar() {
	if c == nil {
		return
	}
	c.bulkDedicatedSidecarMu.Lock()
	lanes := c.bulkDedicatedLanes
	c.bulkDedicatedLanes = make(map[uint32]*bulkDedicatedLane)
	c.bulkDedicatedSidecarMu.Unlock()
	for _, lane := range lanes {
		if lane == nil {
			continue
		}
		if lane.sidecar != nil {
			lane.sidecar.close()
		}
		if lane.attachFlight != nil {
			lane.attachFlight.finish(errServiceShutdown)
		}
	}
}
// handleClientDedicatedSidecarFailure detaches the failed sidecar (first
// detacher wins), propagates the error to the bulk runtime for its connection,
// and closes the sidecar.
func (c *ClientCommon) handleClientDedicatedSidecarFailure(sidecar *bulkDedicatedSidecar, err error) {
	if c == nil || sidecar == nil {
		return
	}
	if !c.clearClientDedicatedSidecar(sidecar.laneID, sidecar) {
		return
	}
	runtime := c.getBulkRuntime()
	if runtime != nil && sidecar.conn != nil {
		runtime.handleDedicatedReadErrorByConn(clientFileScope(), sidecar.conn, err)
	}
	sidecar.close()
}
// serverDedicatedSidecarSnapshot returns the logical connection's sidecar with
// the lowest lane ID, or nil.
func (s *ServerCommon) serverDedicatedSidecarSnapshot(logical *LogicalConn) *bulkDedicatedSidecar {
	if s == nil || logical == nil {
		return nil
	}
	s.bulkDedicatedSidecarMu.Lock()
	defer s.bulkDedicatedSidecarMu.Unlock()
	return firstServerDedicatedSidecarLocked(s.bulkDedicatedSidecars[logical])
}
// firstServerDedicatedSidecarLocked picks the sidecar with the lowest lane ID
// (map iteration order is random, so the minimum is tracked explicitly).
// Caller must hold bulkDedicatedSidecarMu.
func firstServerDedicatedSidecarLocked(lanes map[uint32]*bulkDedicatedSidecar) *bulkDedicatedSidecar {
	var (
		selected *bulkDedicatedSidecar
		bestID   uint32
	)
	for laneID, sidecar := range lanes {
		if sidecar == nil {
			continue
		}
		if selected == nil || laneID < bestID {
			selected = sidecar
			bestID = laneID
		}
	}
	return selected
}
// serverDedicatedSidecarSnapshotForLane returns the sidecar bound to the
// logical connection's (normalized) lane, or nil.
func (s *ServerCommon) serverDedicatedSidecarSnapshotForLane(logical *LogicalConn, laneID uint32) *bulkDedicatedSidecar {
	if s == nil || logical == nil {
		return nil
	}
	laneID = normalizeBulkDedicatedLaneID(laneID)
	s.bulkDedicatedSidecarMu.Lock()
	defer s.bulkDedicatedSidecarMu.Unlock()
	if lanes := s.bulkDedicatedSidecars[logical]; lanes != nil {
		return lanes[laneID]
	}
	return nil
}
// installServerDedicatedSidecar binds sidecar to the logical connection's lane
// (unconditionally replacing any previous binding) and returns the replaced
// sidecar, or nil.
func (s *ServerCommon) installServerDedicatedSidecar(logical *LogicalConn, laneID uint32, sidecar *bulkDedicatedSidecar) *bulkDedicatedSidecar {
	if s == nil || logical == nil || sidecar == nil {
		return nil
	}
	laneID = normalizeBulkDedicatedLaneID(laneID)
	s.bulkDedicatedSidecarMu.Lock()
	defer s.bulkDedicatedSidecarMu.Unlock()
	lanes := s.bulkDedicatedSidecars[logical]
	if lanes == nil {
		lanes = make(map[uint32]*bulkDedicatedSidecar)
		s.bulkDedicatedSidecars[logical] = lanes
	}
	prev := lanes[laneID]
	lanes[laneID] = sidecar
	return prev
}
// clearServerDedicatedSidecar removes sidecar from logical's lane table, but
// only when it is still the registered entry for that lane. It reports
// whether anything was removed and drops the per-logical table once empty.
func (s *ServerCommon) clearServerDedicatedSidecar(logical *LogicalConn, laneID uint32, sidecar *bulkDedicatedSidecar) bool {
	if s == nil || logical == nil || sidecar == nil {
		return false
	}
	lane := normalizeBulkDedicatedLaneID(laneID)
	s.bulkDedicatedSidecarMu.Lock()
	lanes := s.bulkDedicatedSidecars[logical]
	removed := lanes != nil && lanes[lane] == sidecar
	if removed {
		delete(lanes, lane)
		if len(lanes) == 0 {
			// Garbage-collect the empty lane table for this logical conn.
			delete(s.bulkDedicatedSidecars, logical)
		}
	}
	s.bulkDedicatedSidecarMu.Unlock()
	return removed
}
// closeServerDedicatedSidecar detaches every dedicated sidecar registered
// for logical under the lock, then closes them with the lock released.
func (s *ServerCommon) closeServerDedicatedSidecar(logical *LogicalConn) {
	if s == nil || logical == nil {
		return
	}
	s.bulkDedicatedSidecarMu.Lock()
	detached := s.bulkDedicatedSidecars[logical]
	delete(s.bulkDedicatedSidecars, logical)
	s.bulkDedicatedSidecarMu.Unlock()
	// close() is invoked outside the mutex to avoid holding it across I/O.
	for _, sc := range detached {
		if sc == nil {
			continue
		}
		sc.close()
	}
}
// closeAllServerDedicatedSidecars swaps out the whole sidecar registry under
// the lock and then closes every detached sidecar with the lock released.
func (s *ServerCommon) closeAllServerDedicatedSidecars() {
	if s == nil {
		return
	}
	s.bulkDedicatedSidecarMu.Lock()
	detached := s.bulkDedicatedSidecars
	s.bulkDedicatedSidecars = make(map[*LogicalConn]map[uint32]*bulkDedicatedSidecar)
	s.bulkDedicatedSidecarMu.Unlock()
	for _, lanes := range detached {
		for _, sc := range lanes {
			if sc == nil {
				continue
			}
			sc.close()
		}
	}
}
// handleServerDedicatedSidecarFailure tears down a server-side dedicated
// sidecar after a read failure. Stale sidecars (already replaced or removed
// from the registry) are ignored so a late failure cannot double-close.
func (s *ServerCommon) handleServerDedicatedSidecarFailure(logical *LogicalConn, sidecar *bulkDedicatedSidecar, err error) {
	if s == nil || logical == nil || sidecar == nil {
		return
	}
	if !s.clearServerDedicatedSidecar(logical, sidecar.laneID, sidecar) {
		return
	}
	if rt := s.getBulkRuntime(); rt != nil && sidecar.conn != nil {
		rt.handleDedicatedReadErrorByConn(serverFileScope(logical), sidecar.conn, err)
	}
	sidecar.close()
}
// attachServerDedicatedSidecarIfExists binds an already-registered dedicated
// sidecar connection (for bulk's lane) to bulk, if one exists. No-op for
// non-dedicated bulks or when no sidecar connection is available.
func (s *ServerCommon) attachServerDedicatedSidecarIfExists(logical *LogicalConn, bulk *bulkHandle) {
	if s == nil || logical == nil || bulk == nil || !bulk.Dedicated() {
		return
	}
	sidecar := s.serverDedicatedSidecarSnapshotForLane(logical, bulk.dedicatedLaneIDSnapshot())
	if sidecar == nil || sidecar.conn == nil {
		return
	}
	// Best-effort attach; the error is deliberately dropped here.
	_ = bulk.attachDedicatedConnShared(sidecar.conn)
}

View File

@ -12,6 +12,10 @@ import (
const bulkDispatchRejectTimeout = 300 * time.Millisecond
// dispatchFastBulkFrame dispatches a single fast bulk frame with no payload
// owner (the frame payload is not backed by a pooled read buffer).
func (c *ClientCommon) dispatchFastBulkFrame(frame bulkFastFrame) {
	c.dispatchFastBulkFrameWithOwner(frame, nil)
}
func (c *ClientCommon) dispatchFastBulkFrameWithOwner(frame bulkFastFrame, owner *bulkReadPayloadOwner) {
if frame.DataID == 0 {
return
}
@ -38,7 +42,13 @@ func (c *ClientCommon) dispatchFastBulkFrame(frame bulkFastFrame) {
}
switch frame.Type {
case bulkFastPayloadTypeData:
if err := bulk.pushOwnedChunk(frame.Payload); err != nil {
var err error
if owner != nil {
err = bulk.pushChunkWithOwnershipOptionsAndRelease(frame.Payload, true, true, owner.retainChunk())
} else {
err = bulk.pushOwnedChunk(frame.Payload)
}
if err != nil {
if c.showError || c.debugMode {
fmt.Println("client bulk push chunk error", err)
}
@ -58,14 +68,28 @@ func (c *ClientCommon) dispatchFastBulkFrame(frame bulkFastFrame) {
resetErr = bulkRemoteResetError(string(frame.Payload))
}
bulk.markReset(bulkResetError(resetErr))
case bulkFastPayloadTypeRelease:
bytes, chunks, err := decodeBulkDedicatedReleasePayload(frame.Payload)
if err != nil {
if c.showError || c.debugMode {
fmt.Println("client bulk release decode error", err)
}
c.bestEffortRejectInboundBulkData(bulk.ID(), frame.DataID, err.Error())
return
}
bulk.releaseOutboundWindow(bytes, chunks)
}
}
func (c *ClientCommon) dispatchFastBulkData(frame bulkFastDataFrame) {
c.dispatchFastBulkFrame(frame)
c.dispatchFastBulkFrameWithOwner(frame, nil)
}
// dispatchFastBulkFrame dispatches a single fast bulk frame on the server
// side with no payload owner (frame payload not backed by a pooled buffer).
func (s *ServerCommon) dispatchFastBulkFrame(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastFrame) {
	s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, nil)
}
func (s *ServerCommon) dispatchFastBulkFrameWithOwner(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastFrame, owner *bulkReadPayloadOwner) {
if logical == nil || frame.DataID == 0 {
return
}
@ -91,7 +115,13 @@ func (s *ServerCommon) dispatchFastBulkFrame(logical *LogicalConn, transport *Tr
}
switch frame.Type {
case bulkFastPayloadTypeData:
if err := bulk.pushOwnedChunk(frame.Payload); err != nil {
var err error
if owner != nil {
err = bulk.pushChunkWithOwnershipOptionsAndRelease(frame.Payload, true, true, owner.retainChunk())
} else {
err = bulk.pushOwnedChunk(frame.Payload)
}
if err != nil {
if s.showError || s.debugMode {
fmt.Println("server bulk push chunk error", err)
}
@ -111,11 +141,82 @@ func (s *ServerCommon) dispatchFastBulkFrame(logical *LogicalConn, transport *Tr
resetErr = bulkRemoteResetError(string(frame.Payload))
}
bulk.markReset(bulkResetError(resetErr))
case bulkFastPayloadTypeRelease:
bytes, chunks, err := decodeBulkDedicatedReleasePayload(frame.Payload)
if err != nil {
if s.showError || s.debugMode {
fmt.Println("server bulk release decode error", err)
}
s.bestEffortRejectInboundBulkData(logical, transport, conn, bulk.ID(), frame.DataID, err.Error())
return
}
bulk.releaseOutboundWindow(bytes, chunks)
}
}
func (s *ServerCommon) dispatchFastBulkData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastDataFrame) {
s.dispatchFastBulkFrame(logical, transport, conn, frame)
s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, nil)
}
// tryDispatchBorrowedBulkTransportPayload attempts to treat an encrypted
// inbound transport payload as bulk fast-path frame(s). It returns true when
// the payload was consumed — including on decrypt failure, which is logged
// and swallowed — and false when the decrypted payload is not a bulk fast
// frame and must be handled by the regular dispatch path.
func (c *ClientCommon) tryDispatchBorrowedBulkTransportPayload(payload []byte) bool {
	if c == nil || len(payload) == 0 {
		return false
	}
	plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(c.modernPSKRuntime, c.msgDe, c.SecretKey, payload)
	if err != nil {
		if c.showError || c.debugMode {
			fmt.Println("client decode transport payload error", err)
		}
		return true
	}
	// The owner reference-counts chunk views into the (presumably pooled)
	// plaintext buffer so it can be released once every dispatched frame is
	// finished with it — TODO confirm against newBulkReadPayloadOwner.
	owner := newBulkReadPayloadOwner(plainRelease)
	matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
		c.dispatchFastBulkFrameWithOwner(frame, owner)
		return nil
	})
	if owner != nil {
		// Drop the walk's own reference; any chunk retains taken during
		// dispatch keep the buffer alive until their consumers release it.
		owner.done()
	}
	if !matched {
		return false
	}
	if walkErr != nil && (c.showError || c.debugMode) {
		fmt.Println("client decode bulk fast payload error", walkErr)
	}
	return true
}
// tryDispatchBorrowedBulkTransportPayload is the server-side counterpart of
// the client helper: it decrypts payload with the logical connection's codec
// state and, when the plaintext is bulk fast frame(s), dispatches them with
// shared ownership of the decrypted buffer. Returns true when the payload was
// consumed (including on decrypt failure, which is logged and swallowed) and
// false when it must be handled by the regular dispatch path.
func (s *ServerCommon) tryDispatchBorrowedBulkTransportPayload(source interface{}, payload []byte) bool {
	if s == nil || len(payload) == 0 {
		return false
	}
	logical, transport := s.resolveInboundSource(source)
	if logical == nil {
		return false
	}
	plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload)
	if err != nil {
		if s.showError || s.debugMode {
			fmt.Println("server decode transport payload error", err)
		}
		return true
	}
	conn := serverInboundConn(source)
	// Owner reference-counts chunk views into the decrypted buffer so it is
	// released only after every dispatched frame is done with it.
	owner := newBulkReadPayloadOwner(plainRelease)
	matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
		s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, owner)
		return nil
	})
	if owner != nil {
		// Drop the walk's own reference to the buffer.
		owner.done()
	}
	if !matched {
		return false
	}
	if walkErr != nil && (s.showError || s.debugMode) {
		fmt.Println("server decode bulk fast payload error", walkErr)
	}
	return true
}
func (c *ClientCommon) bestEffortRejectInboundBulkData(bulkID string, dataID uint64, message string) {

View File

@ -308,7 +308,11 @@ func newBulkBenchmarkClient(tb testing.TB, network string, server bulkBenchmarkS
tb.Fatalf("UseSignalReliabilityClient failed: %v", err)
}
}
if err := client.Connect(network, server.addr); err != nil {
dialAddr := server.addr
if network == "tcp" {
dialAddr = benchmarkTCPDialAddr(tb, dialAddr)
}
if err := client.Connect(network, dialAddr); err != nil {
tb.Fatalf("client Connect failed: %v", err)
}
tb.Cleanup(func() {
@ -333,6 +337,9 @@ func bulkBenchmarkListenAddr(tb testing.TB, network string) string {
case "unix":
return filepath.Join(tb.TempDir(), "notify-bulk.sock")
case "udp", "tcp":
if network == "tcp" {
return benchmarkTCPListenAddr(tb)
}
return "127.0.0.1:0"
default:
tb.Fatalf("unsupported benchmark network %q", network)

View File

@ -120,32 +120,145 @@ func decodeBulkFastDataFrame(payload []byte) (bulkFastDataFrame, bool, error) {
}
func (c *ClientCommon) encodeFastBulkDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
if c != nil && c.fastBulkEncode != nil {
return c.fastBulkEncode(c.SecretKey, dataID, seq, chunk)
}
scratch := getBulkFastFrameScratch(len(chunk))
defer putBulkFastFrameScratch(scratch)
frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)]
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil {
return nil, err
}
copy(frame[bulkFastPayloadHeaderLen:], chunk)
return c.encryptTransportPayload(frame)
return c.encodeBulkFastPayload(bulkFastFrame{
Type: bulkFastPayloadTypeData,
DataID: dataID,
Seq: seq,
Payload: chunk,
})
}
func (c *ClientCommon) sendFastBulkData(ctx context.Context, dataID uint64, seq uint64, chunk []byte) error {
// encodeBulkFastPayload encodes one bulk fast frame and encrypts it for the
// client transport, preferring the one-pass fast plain encoder when bound.
func (c *ClientCommon) encodeBulkFastPayload(frame bulkFastFrame) ([]byte, error) {
	if c == nil {
		return nil, errBulkClientNil
	}
	if encode := c.fastPlainEncode; encode != nil {
		return encodeBulkFastFramePayloadFast(encode, c.SecretKey, frame)
	}
	// Slow path: build the plaintext frame first, then encrypt it.
	plain, err := encodeBulkFastFramePayload(frame)
	if err != nil {
		return nil, err
	}
	return c.encryptTransportPayload(plain)
}
// encodeBulkFastBatchPayload encodes a batch of bulk fast frames and
// encrypts it for the client transport, preferring the one-pass fast plain
// encoder when bound.
func (c *ClientCommon) encodeBulkFastBatchPayload(frames []bulkFastFrame) ([]byte, error) {
	if c == nil {
		return nil, errBulkClientNil
	}
	if encode := c.fastPlainEncode; encode != nil {
		return encodeBulkFastBatchPayloadFast(encode, c.SecretKey, frames)
	}
	// Slow path: build the plaintext batch first, then encrypt it.
	plain, err := encodeBulkFastBatchPlain(frames)
	if err != nil {
		return nil, err
	}
	return c.encryptTransportPayload(plain)
}
// encodeBulkFastPayloadPooled encodes and encrypts a single bulk fast frame,
// preferring codec paths that can return a pooled output buffer. The second
// return value is the buffer release func; it is nil on paths that do not
// pool their output.
func (c *ClientCommon) encodeBulkFastPayloadPooled(frame bulkFastFrame) ([]byte, func(), error) {
	if c == nil {
		return nil, nil, errBulkClientNil
	}
	// Modern PSK codec: seal directly into a pooled buffer.
	if runtime := c.modernPSKRuntime; runtime != nil {
		return encodeBulkFastFramePayloadPooled(runtime, frame)
	}
	// Fast plain encoder: one-pass encode+encrypt, no pooled release.
	if c.fastPlainEncode != nil {
		payload, err := encodeBulkFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame)
		return payload, nil, err
	}
	// Slow path: encode plaintext, then encrypt separately.
	plain, err := encodeBulkFastFramePayload(frame)
	if err != nil {
		return nil, nil, err
	}
	payload, err := c.encryptTransportPayload(plain)
	return payload, nil, err
}
// encodeBulkFastBatchPayloadPooled encodes and encrypts a batch of bulk fast
// frames, preferring codec paths that can return a pooled output buffer. The
// second return value is the buffer release func; it is nil on paths that do
// not pool their output.
func (c *ClientCommon) encodeBulkFastBatchPayloadPooled(frames []bulkFastFrame) ([]byte, func(), error) {
	if c == nil {
		return nil, nil, errBulkClientNil
	}
	// Modern PSK codec: seal directly into a pooled buffer.
	if runtime := c.modernPSKRuntime; runtime != nil {
		return encodeBulkFastBatchPayloadPooled(runtime, frames)
	}
	// Fast plain encoder: one-pass encode+encrypt, no pooled release.
	if c.fastPlainEncode != nil {
		payload, err := encodeBulkFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames)
		return payload, nil, err
	}
	// Slow path: encode plaintext batch, then encrypt separately.
	plain, err := encodeBulkFastBatchPlain(frames)
	if err != nil {
		return nil, nil, err
	}
	payload, err := c.encryptTransportPayload(plain)
	return payload, nil, err
}
// sendFastBulkData sends one bulk data chunk over the client transport. It
// routes through the bound batch sender when one exists (so frames can be
// coalesced); otherwise it encodes and writes a single frame directly.
func (c *ClientCommon) sendFastBulkData(ctx context.Context, dataID uint64, seq uint64, chunk []byte, fastPathVersion uint8) error {
	binding := c.clientTransportBindingSnapshot()
	if binding == nil {
		return net.ErrClosed
	}
	if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil {
		return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk)
	}
	payload, err := c.encodeFastBulkDataPayload(dataID, seq, chunk)
	if err != nil {
		return err
	}
	return c.writePayloadToTransport(payload)
}
// sendFastBulkWrite sends payload as a run of sequential bulk data frames
// starting at startSeq, returning how many bytes were sent. When a batch
// sender is bound the whole write is handed off to it; otherwise payload is
// split into chunkSize pieces (defaultBulkChunkSize if chunkSize <= 0) and
// sent frame by frame.
func (c *ClientCommon) sendFastBulkWrite(ctx context.Context, dataID uint64, startSeq uint64, chunkSize int, fastPathVersion uint8, payload []byte, payloadOwned bool) (int, error) {
	if len(payload) == 0 {
		return 0, nil
	}
	binding := c.clientTransportBindingSnapshot()
	if binding == nil {
		return 0, net.ErrClosed
	}
	if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil {
		return sender.submitWrite(ctx, dataID, startSeq, fastPathVersion, payload, chunkSize, payloadOwned)
	}
	if chunkSize <= 0 {
		chunkSize = defaultBulkChunkSize
	}
	offset := 0
	for seq := startSeq; offset < len(payload); seq++ {
		next := offset + chunkSize
		if next > len(payload) {
			next = len(payload)
		}
		if err := c.sendFastBulkData(ctx, dataID, seq, payload[offset:next], fastPathVersion); err != nil {
			// Report the bytes successfully sent before the failure.
			return offset, err
		}
		offset = next
	}
	return offset, nil
}
func (c *ClientCommon) sendFastBulkControl(ctx context.Context, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
frame := bulkFastFrame{
Type: frameType,
Flags: flags,
DataID: dataID,
Seq: seq,
Payload: payload,
}
binding := c.clientTransportBindingSnapshot()
if binding == nil {
return net.ErrClosed
}
if sender := binding.bulkBatchSenderSnapshot(); sender != nil {
return sender.submit(ctx, payload)
if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil {
return sender.submitControl(ctx, frameType, flags, dataID, seq, fastPathVersion, payload)
}
return c.writePayloadToTransport(payload)
encoded, err := c.encodeBulkFastPayload(frame)
if err != nil {
return err
}
return c.writePayloadToTransport(encoded)
}
func (c *ClientCommon) encodeBulkFastControlPayload(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
@ -157,22 +270,81 @@ func (c *ClientCommon) encodeBulkFastControlPayload(frameType uint8, flags uint8
}
func (s *ServerCommon) encodeFastBulkDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
if logical != nil {
if fastBulkEncode := logical.fastBulkEncodeSnapshot(); fastBulkEncode != nil {
return fastBulkEncode(logical.secretKeySnapshot(), dataID, seq, chunk)
}
}
scratch := getBulkFastFrameScratch(len(chunk))
defer putBulkFastFrameScratch(scratch)
frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)]
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil {
return nil, err
}
copy(frame[bulkFastPayloadHeaderLen:], chunk)
return s.encryptTransportPayloadLogical(logical, frame)
return s.encodeBulkFastPayloadLogical(logical, bulkFastFrame{
Type: bulkFastPayloadTypeData,
DataID: dataID,
Seq: seq,
Payload: chunk,
})
}
func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte) error {
// encodeBulkFastPayloadLogical encodes one bulk fast frame and encrypts it
// with the codec state of logical, preferring the one-pass fast plain
// encoder when the connection has one bound.
func (s *ServerCommon) encodeBulkFastPayloadLogical(logical *LogicalConn, frame bulkFastFrame) ([]byte, error) {
	if logical == nil {
		return nil, errTransportDetached
	}
	if encode := logical.fastPlainEncodeSnapshot(); encode != nil {
		return encodeBulkFastFramePayloadFast(encode, logical.secretKeySnapshot(), frame)
	}
	// Slow path: build the plaintext frame first, then encrypt it.
	plain, err := encodeBulkFastFramePayload(frame)
	if err != nil {
		return nil, err
	}
	return s.encryptTransportPayloadLogical(logical, plain)
}
// encodeBulkFastBatchPayloadLogical encodes a batch of bulk fast frames and
// encrypts it with the codec state of logical, preferring the one-pass fast
// plain encoder when the connection has one bound.
func (s *ServerCommon) encodeBulkFastBatchPayloadLogical(logical *LogicalConn, frames []bulkFastFrame) ([]byte, error) {
	if logical == nil {
		return nil, errTransportDetached
	}
	if encode := logical.fastPlainEncodeSnapshot(); encode != nil {
		return encodeBulkFastBatchPayloadFast(encode, logical.secretKeySnapshot(), frames)
	}
	// Slow path: build the plaintext batch first, then encrypt it.
	plain, err := encodeBulkFastBatchPlain(frames)
	if err != nil {
		return nil, err
	}
	return s.encryptTransportPayloadLogical(logical, plain)
}
// encodeBulkFastPayloadLogicalPooled encodes and encrypts one bulk fast
// frame with logical's codec state, preferring paths that can return a
// pooled output buffer. The second return value is the buffer release func;
// it is nil on paths that do not pool their output.
func (s *ServerCommon) encodeBulkFastPayloadLogicalPooled(logical *LogicalConn, frame bulkFastFrame) ([]byte, func(), error) {
	if logical == nil {
		return nil, nil, errTransportDetached
	}
	// Modern PSK codec: seal directly into a pooled buffer.
	if runtime := logical.modernPSKRuntimeSnapshot(); runtime != nil {
		return encodeBulkFastFramePayloadPooled(runtime, frame)
	}
	// Fast plain encoder: one-pass encode+encrypt, no pooled release.
	if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
		payload, err := encodeBulkFastFramePayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frame)
		return payload, nil, err
	}
	// Slow path: encode plaintext, then encrypt separately.
	plain, err := encodeBulkFastFramePayload(frame)
	if err != nil {
		return nil, nil, err
	}
	payload, err := s.encryptTransportPayloadLogical(logical, plain)
	return payload, nil, err
}
// encodeBulkFastBatchPayloadLogicalPooled encodes and encrypts a batch of
// bulk fast frames with logical's codec state, preferring paths that can
// return a pooled output buffer. The second return value is the buffer
// release func; it is nil on paths that do not pool their output.
func (s *ServerCommon) encodeBulkFastBatchPayloadLogicalPooled(logical *LogicalConn, frames []bulkFastFrame) ([]byte, func(), error) {
	if logical == nil {
		return nil, nil, errTransportDetached
	}
	// Modern PSK codec: seal directly into a pooled buffer.
	if runtime := logical.modernPSKRuntimeSnapshot(); runtime != nil {
		return encodeBulkFastBatchPayloadPooled(runtime, frames)
	}
	// Fast plain encoder: one-pass encode+encrypt, no pooled release.
	if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
		payload, err := encodeBulkFastBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frames)
		return payload, nil, err
	}
	// Slow path: encode plaintext batch, then encrypt separately.
	plain, err := encodeBulkFastBatchPlain(frames)
	if err != nil {
		return nil, nil, err
	}
	payload, err := s.encryptTransportPayloadLogical(logical, plain)
	return payload, nil, err
}
func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte, fastPathVersion uint8) error {
if err := s.ensureServerTransportSendReady(transport); err != nil {
return err
}
@ -182,18 +354,87 @@ func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *L
if logical == nil {
return errTransportDetached
}
if binding := logical.transportBindingSnapshot(); binding != nil {
if binding.queueSnapshot() != nil {
if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil {
return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk)
}
}
}
payload, err := s.encodeFastBulkDataPayloadLogical(logical, dataID, seq, chunk)
if err != nil {
return err
}
return s.writeEnvelopePayload(logical, transport, nil, payload)
}
func (s *ServerCommon) sendFastBulkWriteTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, startSeq uint64, chunkSize int, fastPathVersion uint8, payload []byte, payloadOwned bool) (int, error) {
if len(payload) == 0 {
return 0, nil
}
if err := s.ensureServerTransportSendReady(transport); err != nil {
return 0, err
}
if logical == nil && transport != nil {
logical = transport.logicalConnSnapshot()
}
if logical == nil {
return 0, errTransportDetached
}
if binding := logical.transportBindingSnapshot(); binding != nil {
if binding.queueSnapshot() != nil {
if sender := binding.bulkBatchSenderSnapshot(); sender != nil {
return sender.submit(ctx, payload)
if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil {
return sender.submitWrite(ctx, dataID, startSeq, fastPathVersion, payload, chunkSize, payloadOwned)
}
}
}
return s.writeEnvelopePayload(logical, transport, nil, payload)
if chunkSize <= 0 {
chunkSize = defaultBulkChunkSize
}
written := 0
seq := startSeq
for written < len(payload) {
end := written + chunkSize
if end > len(payload) {
end = len(payload)
}
if err := s.sendFastBulkDataTransport(ctx, logical, transport, dataID, seq, payload[written:end], fastPathVersion); err != nil {
return written, err
}
seq++
written = end
}
return written, nil
}
// sendFastBulkControlTransport sends one bulk fast control frame over the
// server transport. It resolves the logical connection from transport when
// needed, prefers the bound batch sender (only when a send queue exists),
// and otherwise encodes the frame and writes it as an envelope payload.
func (s *ServerCommon) sendFastBulkControlTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
	if err := s.ensureServerTransportSendReady(transport); err != nil {
		return err
	}
	if logical == nil && transport != nil {
		logical = transport.logicalConnSnapshot()
	}
	if logical == nil {
		return errTransportDetached
	}
	if binding := logical.transportBindingSnapshot(); binding != nil {
		// Batch sender is only usable when the binding has a send queue.
		if binding.queueSnapshot() != nil {
			if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil {
				return sender.submitControl(ctx, frameType, flags, dataID, seq, fastPathVersion, payload)
			}
		}
	}
	encoded, err := s.encodeBulkFastPayloadLogical(logical, bulkFastFrame{
		Type:    frameType,
		Flags:   flags,
		DataID:  dataID,
		Seq:     seq,
		Payload: payload,
	})
	if err != nil {
		return err
	}
	return s.writeEnvelopePayload(logical, transport, nil, encoded)
}
func (s *ServerCommon) encodeBulkFastControlPayloadLogical(logical *LogicalConn, frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
@ -224,18 +465,22 @@ func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.
if err != nil {
return err
}
if frame, matched, err := decodeBulkFastFrame(plain); matched {
if frames, matched, err := decodeBulkFastFrames(plain); matched {
if err != nil {
return err
}
c.dispatchFastBulkFrame(frame)
for _, frame := range frames {
c.dispatchFastBulkFrame(frame)
}
return nil
}
if frame, matched, err := decodeStreamFastDataFrame(plain); matched {
if frames, matched, err := decodeStreamFastDataFrames(plain); matched {
if err != nil {
return err
}
c.dispatchFastStreamData(frame)
for _, frame := range frames {
c.dispatchFastStreamData(frame)
}
return nil
}
env, err := c.decodeEnvelopePlain(plain)
@ -257,18 +502,22 @@ func (s *ServerCommon) dispatchInboundTransportPayload(logical *LogicalConn, tra
if err != nil {
return err
}
if frame, matched, err := decodeBulkFastFrame(plain); matched {
if frames, matched, err := decodeBulkFastFrames(plain); matched {
if err != nil {
return err
}
s.dispatchFastBulkFrame(logical, transport, conn, frame)
for _, frame := range frames {
s.dispatchFastBulkFrame(logical, transport, conn, frame)
}
return nil
}
if frame, matched, err := decodeStreamFastDataFrame(plain); matched {
if frames, matched, err := decodeStreamFastDataFrames(plain); matched {
if err != nil {
return err
}
s.dispatchFastStreamData(logical, transport, conn, frame)
for _, frame := range frames {
s.dispatchFastStreamData(logical, transport, conn, frame)
}
return nil
}
env, err := s.decodeEnvelopePlain(plain)

View File

@ -2,7 +2,7 @@ package notify
import (
"fmt"
"strconv"
"net"
"strings"
"sync"
"sync/atomic"
@ -16,14 +16,14 @@ type bulkRuntime struct {
mu sync.RWMutex
handler func(BulkAcceptInfo) error
bulks map[string]*bulkHandle
data map[string]*bulkHandle
data map[string]map[uint64]*bulkHandle
}
func newBulkRuntime(rolePrefix string) *bulkRuntime {
return &bulkRuntime{
rolePrefix: rolePrefix,
bulks: make(map[string]*bulkHandle),
data: make(map[string]*bulkHandle),
data: make(map[string]map[uint64]*bulkHandle),
}
}
@ -66,8 +66,8 @@ func (r *bulkRuntime) register(scope string, bulk *bulkHandle) error {
if bulk == nil || bulk.id == "" {
return errBulkIDEmpty
}
scope = normalizeFileScope(scope)
key := bulkRuntimeKey(scope, bulk.id)
dataKey := bulkRuntimeDataKey(scope, bulk.dataID)
r.mu.Lock()
defer r.mu.Unlock()
if _, ok := r.bulks[key]; ok {
@ -76,11 +76,16 @@ func (r *bulkRuntime) register(scope string, bulk *bulkHandle) error {
if bulk.dataID == 0 {
return errBulkDataIDEmpty
}
if _, ok := r.data[dataKey]; ok {
dataScope := r.data[scope]
if dataScope == nil {
dataScope = make(map[uint64]*bulkHandle)
r.data[scope] = dataScope
}
if _, ok := dataScope[bulk.dataID]; ok {
return errBulkAlreadyExists
}
r.bulks[key] = bulk
r.data[dataKey] = bulk
dataScope[bulk.dataID] = bulk
return nil
}
@ -99,10 +104,14 @@ func (r *bulkRuntime) lookupByDataID(scope string, dataID uint64) (*bulkHandle,
if r == nil || dataID == 0 {
return nil, false
}
key := bulkRuntimeDataKey(scope, dataID)
scope = normalizeFileScope(scope)
r.mu.RLock()
defer r.mu.RUnlock()
bulk, ok := r.data[key]
dataScope := r.data[scope]
if dataScope == nil {
return nil, false
}
bulk, ok := dataScope[dataID]
return bulk, ok
}
@ -110,11 +119,17 @@ func (r *bulkRuntime) remove(scope string, bulkID string) {
if r == nil || bulkID == "" {
return
}
scope = normalizeFileScope(scope)
key := bulkRuntimeKey(scope, bulkID)
r.mu.Lock()
defer r.mu.Unlock()
if bulk := r.bulks[key]; bulk != nil && bulk.dataID != 0 {
delete(r.data, bulkRuntimeDataKey(scope, bulk.dataID))
if dataScope := r.data[scope]; dataScope != nil {
delete(dataScope, bulk.dataID)
if len(dataScope) == 0 {
delete(r.data, scope)
}
}
}
delete(r.bulks, key)
}
@ -149,6 +164,140 @@ func (r *bulkRuntime) closeMatching(match func(string) bool, err error) {
}
}
// resetDedicatedByConn marks every dedicated bulk in scope that is currently
// bound to conn as reset, using a normalized close error. Matches are
// collected under the read lock and reset afterwards, so markReset never
// runs while holding the runtime mutex.
func (r *bulkRuntime) resetDedicatedByConn(scope string, conn net.Conn, err error) {
	if r == nil || conn == nil {
		return
	}
	scope = normalizeFileScope(scope)
	resetErr := bulkRuntimeCloseError(err)
	// Runtime keys have the form "<scope>\x00<bulkID>" (see bulkRuntimeKey),
	// so a prefix match selects exactly this scope's bulks.
	prefix := scope + "\x00"
	r.mu.RLock()
	bulks := make([]*bulkHandle, 0, len(r.bulks))
	for key, bulk := range r.bulks {
		if bulk == nil {
			continue
		}
		if !strings.HasPrefix(key, prefix) {
			continue
		}
		if !bulk.Dedicated() {
			continue
		}
		if bulk.dedicatedConnSnapshot() != conn {
			continue
		}
		bulks = append(bulks, bulk)
	}
	r.mu.RUnlock()
	for _, bulk := range bulks {
		bulk.markReset(resetErr)
	}
}
// handleDedicatedReadErrorByConn routes a dedicated-connection read error to
// every dedicated bulk in scope bound to conn. Matches are collected under
// the read lock and handled afterwards, so the error handler never runs
// while holding the runtime mutex.
func (r *bulkRuntime) handleDedicatedReadErrorByConn(scope string, conn net.Conn, err error) {
	if r == nil || conn == nil {
		return
	}
	scope = normalizeFileScope(scope)
	// Runtime keys have the form "<scope>\x00<bulkID>" (see bulkRuntimeKey).
	prefix := scope + "\x00"
	r.mu.RLock()
	bulks := make([]*bulkHandle, 0, len(r.bulks))
	for key, bulk := range r.bulks {
		if bulk == nil {
			continue
		}
		if !strings.HasPrefix(key, prefix) {
			continue
		}
		if !bulk.Dedicated() {
			continue
		}
		if bulk.dedicatedConnSnapshot() != conn {
			continue
		}
		bulks = append(bulks, bulk)
	}
	r.mu.RUnlock()
	for _, bulk := range bulks {
		handleDedicatedBulkReadError(bulk, err)
	}
}
// attachSharedDedicatedConn (re)binds conn to every dedicated bulk in scope
// that uses laneID. Bulks already on conn, or on no connection, are attached
// in place; bulks on a different connection are migrated: the old sender is
// stopped and the old connection closed. Matches are collected under the
// read lock and mutated afterwards.
func (r *bulkRuntime) attachSharedDedicatedConn(scope string, laneID uint32, conn net.Conn) {
	if r == nil || conn == nil {
		return
	}
	scope = normalizeFileScope(scope)
	laneID = normalizeBulkDedicatedLaneID(laneID)
	prefix := scope + "\x00"
	r.mu.RLock()
	matched := make([]*bulkHandle, 0, len(r.bulks))
	for key, bulk := range r.bulks {
		if bulk == nil || !strings.HasPrefix(key, prefix) {
			continue
		}
		if !bulk.Dedicated() || bulk.dedicatedLaneIDSnapshot() != laneID {
			continue
		}
		matched = append(matched, bulk)
	}
	r.mu.RUnlock()
	for _, bulk := range matched {
		current := bulk.dedicatedConnSnapshot()
		if current == nil || current == conn {
			_ = bulk.attachDedicatedConnShared(conn)
			continue
		}
		// Migrate off the previous dedicated connection.
		oldConn, oldSender, err := bulk.replaceDedicatedConnShared(conn)
		if err != nil {
			bulk.markReset(err)
			continue
		}
		if oldSender != nil {
			oldSender.stop()
		}
		if oldConn != nil {
			_ = oldConn.Close()
		}
	}
}
// dedicatedBulksForConn lists the dedicated bulks in scope that are currently
// bound to conn. The snapshot is taken under the read lock.
func (r *bulkRuntime) dedicatedBulksForConn(scope string, conn net.Conn) []*bulkHandle {
	if r == nil || conn == nil {
		return nil
	}
	prefix := normalizeFileScope(scope) + "\x00"
	r.mu.RLock()
	defer r.mu.RUnlock()
	matched := make([]*bulkHandle, 0, len(r.bulks))
	for key, bulk := range r.bulks {
		if bulk == nil || !strings.HasPrefix(key, prefix) {
			continue
		}
		if !bulk.Dedicated() || bulk.dedicatedConnSnapshot() != conn {
			continue
		}
		matched = append(matched, bulk)
	}
	return matched
}
func (r *bulkRuntime) snapshots() []BulkSnapshot {
if r == nil {
return nil
@ -170,10 +319,6 @@ func bulkRuntimeKey(scope string, bulkID string) string {
return normalizeFileScope(scope) + "\x00" + bulkID
}
func bulkRuntimeDataKey(scope string, dataID uint64) string {
return normalizeFileScope(scope) + "\x01" + strconv.FormatUint(dataID, 10)
}
func bulkRuntimeCloseError(err error) error {
if err != nil {
return err

228
bulk_shared_batch.go Normal file
View File

@ -0,0 +1,228 @@
package notify
import (
"encoding/binary"
)
// Bulk fast-path protocol versions; v2 adds the shared batch framing
// (see bulkFastPathSupportsSharedBatch).
const (
	bulkFastPathVersionV1      = 1
	bulkFastPathVersionV2      = 2
	bulkFastPathVersionCurrent = bulkFastPathVersionV2
)

// Wire-format constants for the shared bulk fast batch: a 12-byte batch
// header ("NBF2" magic + version + item count) followed by one 24-byte item
// header plus payload per frame. The max-items/max-bytes bounds cap batch
// composition (enforcement happens outside this file — verify at senders).
const (
	bulkFastBatchMagic         = "NBF2"
	bulkFastBatchVersion       = 1
	bulkFastBatchHeaderLen     = 12
	bulkFastBatchItemHeaderLen = 24
	bulkFastBatchMaxItems      = 64
	bulkFastBatchMaxPlainBytes = 8 * 1024 * 1024
)
// normalizeBulkFastPathVersion clamps version into the supported fast-path
// range [V1, Current].
func normalizeBulkFastPathVersion(version uint8) uint8 {
	switch {
	case version < bulkFastPathVersionV1:
		return bulkFastPathVersionV1
	case version > bulkFastPathVersionCurrent:
		return bulkFastPathVersionCurrent
	default:
		return version
	}
}
// negotiateBulkFastPathVersion selects the effective fast-path version for a
// peer-advertised version by clamping it into the supported range.
func negotiateBulkFastPathVersion(version uint8) uint8 {
	return normalizeBulkFastPathVersion(version)
}
// bulkFastPathSupportsSharedBatch reports whether the (normalized) version
// understands the NBF2 shared batch framing introduced in v2.
func bulkFastPathSupportsSharedBatch(version uint8) bool {
	return normalizeBulkFastPathVersion(version) >= bulkFastPathVersionV2
}
// bulkFastBatchFrameLen returns the encoded size of one batch item: the
// fixed item header plus the payload bytes.
func bulkFastBatchFrameLen(frame bulkFastFrame) int {
	return bulkFastBatchItemHeaderLen + len(frame.Payload)
}
// bulkFastBatchPlainLen returns the total plaintext size of an NBF2 batch
// carrying frames: the batch header plus every item's header and payload.
func bulkFastBatchPlainLen(frames []bulkFastFrame) int {
	size := bulkFastBatchHeaderLen
	for i := range frames {
		size += bulkFastBatchFrameLen(frames[i])
	}
	return size
}
// encodeBulkFastFramePayload encodes a single (non-batched) fast frame into
// its plaintext wire form.
func encodeBulkFastFramePayload(frame bulkFastFrame) ([]byte, error) {
	return encodeBulkFastControlFrame(frame.Type, frame.Flags, frame.DataID, frame.Seq, frame.Payload)
}
// encodeBulkFastFramePayloadFast encodes and encrypts a single fast frame in
// one pass: the encoder provides the plaintext destination buffer and this
// fill callback writes the frame header and payload directly into it.
func encodeBulkFastFramePayloadFast(encode transportFastPlainEncoder, secretKey []byte, frame bulkFastFrame) ([]byte, error) {
	if encode == nil {
		return nil, errTransportPayloadEncryptFailed
	}
	plainLen := bulkFastPayloadHeaderLen + len(frame.Payload)
	return encode(secretKey, plainLen, func(dst []byte) error {
		if err := encodeBulkFastFrameHeader(dst, frame.Type, frame.Flags, frame.DataID, frame.Seq, len(frame.Payload)); err != nil {
			return err
		}
		copy(dst[bulkFastPayloadHeaderLen:], frame.Payload)
		return nil
	})
}
// encodeBulkFastFramePayloadPooled seals a single fast frame through the
// modern PSK codec into a pooled buffer: the runtime provides the plaintext
// destination and this fill callback writes header and payload into it. The
// returned func releases the pooled buffer.
func encodeBulkFastFramePayloadPooled(runtime *modernPSKCodecRuntime, frame bulkFastFrame) ([]byte, func(), error) {
	if runtime == nil {
		return nil, nil, errTransportPayloadEncryptFailed
	}
	plainLen := bulkFastPayloadHeaderLen + len(frame.Payload)
	return runtime.sealFilledPayloadPooled(plainLen, func(dst []byte) error {
		if err := encodeBulkFastFrameHeader(dst, frame.Type, frame.Flags, frame.DataID, frame.Seq, len(frame.Payload)); err != nil {
			return err
		}
		copy(dst[bulkFastPayloadHeaderLen:], frame.Payload)
		return nil
	})
}
// encodeBulkFastBatchPlain builds the plaintext NBF2 batch encoding of
// frames into a freshly allocated buffer. Empty batches are rejected.
func encodeBulkFastBatchPlain(frames []bulkFastFrame) ([]byte, error) {
	if len(frames) == 0 {
		return nil, errBulkFastPayloadInvalid
	}
	wire := make([]byte, bulkFastBatchPlainLen(frames))
	err := writeBulkFastBatchPlain(wire, frames)
	if err != nil {
		return nil, err
	}
	return wire, nil
}
// encodeBulkFastBatchPayloadFast encodes and encrypts a frame batch in one
// pass: the encoder provides the plaintext destination buffer and the fill
// callback serializes the NBF2 batch directly into it.
func encodeBulkFastBatchPayloadFast(encode transportFastPlainEncoder, secretKey []byte, frames []bulkFastFrame) ([]byte, error) {
	if encode == nil {
		return nil, errTransportPayloadEncryptFailed
	}
	plainLen := bulkFastBatchPlainLen(frames)
	return encode(secretKey, plainLen, func(dst []byte) error {
		return writeBulkFastBatchPlain(dst, frames)
	})
}
// encodeBulkFastBatchPayloadPooled seals a frame batch through the modern
// PSK codec into a pooled buffer; the fill callback serializes the NBF2
// batch directly into the runtime-provided plaintext region. The returned
// func releases the pooled buffer.
func encodeBulkFastBatchPayloadPooled(runtime *modernPSKCodecRuntime, frames []bulkFastFrame) ([]byte, func(), error) {
	if runtime == nil {
		return nil, nil, errTransportPayloadEncryptFailed
	}
	return runtime.sealFilledPayloadPooled(bulkFastBatchPlainLen(frames), func(dst []byte) error {
		return writeBulkFastBatchPlain(dst, frames)
	})
}
// writeBulkFastBatchPlain serializes frames as an NBF2 batch into dst, which
// must be exactly bulkFastBatchPlainLen(frames) bytes. Frames with a zero
// DataID are rejected with errBulkFastPayloadInvalid.
//
// Fix: the reserved header bytes (batch bytes 5-7 and item bytes 2-3) were
// previously never written. dst is filled in place by the fast/pooled encode
// paths, where the buffer comes from the caller or a pool and may hold stale
// data, so unwritten bytes would leak prior buffer contents onto the wire
// and make encoding nondeterministic. They are now zeroed explicitly; the
// decoder (walkBulkFastBatchPlain) ignores them, so this stays compatible.
func writeBulkFastBatchPlain(dst []byte, frames []bulkFastFrame) error {
	if len(frames) == 0 || len(dst) != bulkFastBatchPlainLen(frames) {
		return errBulkFastPayloadInvalid
	}
	copy(dst[:4], bulkFastBatchMagic)
	dst[4] = bulkFastBatchVersion
	dst[5], dst[6], dst[7] = 0, 0, 0 // reserved batch-header bytes
	binary.BigEndian.PutUint32(dst[8:12], uint32(len(frames)))
	offset := bulkFastBatchHeaderLen
	for _, frame := range frames {
		if frame.DataID == 0 {
			return errBulkFastPayloadInvalid
		}
		dst[offset] = frame.Type
		dst[offset+1] = frame.Flags
		dst[offset+2], dst[offset+3] = 0, 0 // reserved item-header bytes
		binary.BigEndian.PutUint64(dst[offset+4:offset+12], frame.DataID)
		binary.BigEndian.PutUint64(dst[offset+12:offset+20], frame.Seq)
		binary.BigEndian.PutUint32(dst[offset+20:offset+24], uint32(len(frame.Payload)))
		offset += bulkFastBatchItemHeaderLen
		copy(dst[offset:offset+len(frame.Payload)], frame.Payload)
		offset += len(frame.Payload)
	}
	return nil
}
// walkBulkFastBatchPlain parses payload as an NBF2 batch and invokes fn for
// each item in order. The bool result reports whether payload looked like a
// batch at all (magic matched); when it did, any structural problem is
// reported as errBulkFastPayloadInvalid. Frame payloads handed to fn alias
// the input buffer — callers must copy or retain ownership as needed.
func walkBulkFastBatchPlain(payload []byte, fn func(bulkFastFrame) error) (bool, error) {
	// Not a batch at all: let the caller try the legacy single-frame form.
	if len(payload) < 4 || string(payload[:4]) != bulkFastBatchMagic {
		return false, nil
	}
	if len(payload) < bulkFastBatchHeaderLen {
		return true, errBulkFastPayloadInvalid
	}
	if payload[4] != bulkFastBatchVersion {
		return true, errBulkFastPayloadInvalid
	}
	count := int(binary.BigEndian.Uint32(payload[8:12]))
	if count <= 0 {
		return true, errBulkFastPayloadInvalid
	}
	offset := bulkFastBatchHeaderLen
	for index := 0; index < count; index++ {
		// Each item needs a full 24-byte header before its payload.
		if len(payload)-offset < bulkFastBatchItemHeaderLen {
			return true, errBulkFastPayloadInvalid
		}
		frameType := payload[offset]
		switch frameType {
		case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
		default:
			return true, errBulkFastPayloadInvalid
		}
		flags := payload[offset+1]
		dataID := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
		seq := binary.BigEndian.Uint64(payload[offset+12 : offset+20])
		payloadLen := int(binary.BigEndian.Uint32(payload[offset+20 : offset+24]))
		offset += bulkFastBatchItemHeaderLen
		// payloadLen < 0 guards the uint32->int conversion on 32-bit builds.
		if dataID == 0 || payloadLen < 0 || len(payload)-offset < payloadLen {
			return true, errBulkFastPayloadInvalid
		}
		if fn != nil {
			if err := fn(bulkFastFrame{
				Type:    frameType,
				Flags:   flags,
				DataID:  dataID,
				Seq:     seq,
				Payload: payload[offset : offset+payloadLen],
			}); err != nil {
				return true, err
			}
		}
		offset += payloadLen
	}
	// Trailing bytes after the declared item count are a framing error.
	if offset != len(payload) {
		return true, errBulkFastPayloadInvalid
	}
	return true, nil
}
// decodeBulkFastBatchPlain collects every frame of a plain fast-path batch
// into a slice. The bool result reports whether payload carried the batch
// magic; err is non-nil only for matched-but-malformed input. The returned
// frame payloads alias the input buffer.
func decodeBulkFastBatchPlain(payload []byte) ([]bulkFastFrame, bool, error) {
	collected := make([]bulkFastFrame, 0, 1)
	matched, walkErr := walkBulkFastBatchPlain(payload, func(frame bulkFastFrame) error {
		collected = append(collected, frame)
		return nil
	})
	if walkErr != nil || !matched {
		return nil, matched, walkErr
	}
	return collected, true, nil
}
// walkBulkFastFrames visits each fast-path frame in payload, accepting both
// the batched and the single-frame wire encodings. It returns whether the
// payload matched either encoding, plus the first decode or callback error.
func walkBulkFastFrames(payload []byte, fn func(bulkFastFrame) error) (bool, error) {
	// The batched encoding wins whenever its magic prefix is present.
	if batched, batchErr := walkBulkFastBatchPlain(payload, fn); batched {
		return true, batchErr
	}
	single, matched, decodeErr := decodeBulkFastFrame(payload)
	switch {
	case !matched:
		return matched, decodeErr
	case decodeErr != nil:
		return true, decodeErr
	case fn == nil:
		return true, nil
	}
	return true, fn(single)
}
// decodeBulkFastFrames collects every fast-path frame in payload, whether it
// uses the batched or the single-frame encoding. The bool result reports
// whether payload matched at all; err is non-nil only for malformed input.
func decodeBulkFastFrames(payload []byte) ([]bulkFastFrame, bool, error) {
	collected := make([]bulkFastFrame, 0, 1)
	matched, walkErr := walkBulkFastFrames(payload, func(frame bulkFastFrame) error {
		collected = append(collected, frame)
		return nil
	})
	if walkErr != nil || !matched {
		return nil, matched, walkErr
	}
	return collected, true, nil
}

458
bulk_shared_batch_test.go Normal file
View File

@ -0,0 +1,458 @@
package notify
import (
"context"
"testing"
"time"
)
// TestBulkFastBatchPlainRoundTrip encodes a two-frame fast-path batch (one
// data frame and one dedicated window-release frame) and verifies that
// decoding the wire bytes reproduces every field of every frame.
func TestBulkFastBatchPlainRoundTrip(t *testing.T) {
	// Build a realistic release payload via the production encoder.
	releasePayload, err := encodeBulkDedicatedReleasePayload(4096, 2)
	if err != nil {
		t.Fatalf("encodeBulkDedicatedReleasePayload failed: %v", err)
	}
	frames := []bulkFastFrame{
		{
			Type:    bulkFastPayloadTypeData,
			DataID:  11,
			Seq:     7,
			Payload: []byte("alpha"),
		},
		{
			Type:    bulkFastPayloadTypeRelease,
			DataID:  12,
			Seq:     0,
			Payload: releasePayload,
		},
	}
	wire, err := encodeBulkFastBatchPlain(frames)
	if err != nil {
		t.Fatalf("encodeBulkFastBatchPlain failed: %v", err)
	}
	decoded, matched, err := decodeBulkFastBatchPlain(wire)
	if err != nil {
		t.Fatalf("decodeBulkFastBatchPlain failed: %v", err)
	}
	if !matched {
		t.Fatal("decodeBulkFastBatchPlain should match encoded batch")
	}
	if got, want := len(decoded), len(frames); got != want {
		t.Fatalf("decoded frame count = %d, want %d", got, want)
	}
	// Field-by-field comparison of each decoded frame against its input.
	for index := range frames {
		if got, want := decoded[index].Type, frames[index].Type; got != want {
			t.Fatalf("frame %d type = %d, want %d", index, got, want)
		}
		if got, want := decoded[index].DataID, frames[index].DataID; got != want {
			t.Fatalf("frame %d dataID = %d, want %d", index, got, want)
		}
		if got, want := decoded[index].Seq, frames[index].Seq; got != want {
			t.Fatalf("frame %d seq = %d, want %d", index, got, want)
		}
		if got, want := string(decoded[index].Payload), string(frames[index].Payload); got != want {
			t.Fatalf("frame %d payload = %q, want %q", index, got, want)
		}
	}
}
// TestBulkBatchSenderEncodeRequestsCoalescesSharedFastV2Frames verifies that
// two fast-path-v2 requests for the same bulk are folded into a single batch
// payload (one encodeBatch call, zero encodeSingle calls).
func TestBulkBatchSenderEncodeRequestsCoalescesSharedFastV2Frames(t *testing.T) {
	var (
		singleEncodes int
		batchEncodes  [][]bulkFastFrame
	)
	sender := &bulkBatchSender{
		codec: bulkBatchCodec{
			encodeSingle: func(bulkFastFrame) ([]byte, func(), error) {
				singleEncodes++
				return []byte("single"), nil, nil
			},
			encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) {
				// Record a snapshot; the sender may reuse the slice.
				batchEncodes = append(batchEncodes, append([]bulkFastFrame(nil), frames...))
				return []byte("batch"), nil, nil
			},
		},
	}
	requests := []bulkBatchRequest{
		{
			frames: []bulkFastFrame{{
				Type:    bulkFastPayloadTypeData,
				DataID:  101,
				Seq:     1,
				Payload: []byte("a"),
			}},
			fastPathVersion: bulkFastPathVersionV2,
		},
		{
			frames: []bulkFastFrame{{
				Type:    bulkFastPayloadTypeRelease,
				DataID:  101,
				Seq:     0,
				Payload: []byte("rel"),
			}},
			fastPathVersion: bulkFastPathVersionV2,
		},
	}
	payloads, err := sender.encodeRequests(requests)
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if len(payloads) != 1 {
		t.Fatalf("payload count = %d, want %d", len(payloads), 1)
	}
	if singleEncodes != 0 {
		t.Fatalf("single encode calls = %d, want 0", singleEncodes)
	}
	if len(batchEncodes) != 1 {
		t.Fatalf("batch encode calls = %d, want %d", len(batchEncodes), 1)
	}
	if len(batchEncodes[0]) != 2 {
		t.Fatalf("batched item count = %d, want %d", len(batchEncodes[0]), 2)
	}
}
// TestBulkBatchSenderEncodeRequestsCoalescesAcrossRequestsAndBulks verifies
// that v2 requests belonging to two DIFFERENT bulks (dataIDs 101 and 202)
// are still merged into one batch payload, preserving request order.
func TestBulkBatchSenderEncodeRequestsCoalescesAcrossRequestsAndBulks(t *testing.T) {
	var (
		singleCalls int
		batchCalls  [][]bulkFastFrame
	)
	sender := &bulkBatchSender{
		codec: bulkBatchCodec{
			encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) {
				singleCalls++
				return []byte("single"), nil, nil
			},
			encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) {
				// Snapshot the frames; the sender may reuse the slice.
				cloned := make([]bulkFastFrame, len(frames))
				copy(cloned, frames)
				batchCalls = append(batchCalls, cloned)
				return []byte("batch"), nil, nil
			},
		},
	}
	payloads, err := sender.encodeRequests([]bulkBatchRequest{
		{
			frames: []bulkFastFrame{{
				Type:    bulkFastPayloadTypeData,
				DataID:  101,
				Seq:     1,
				Payload: []byte("bulk-a"),
			}},
			fastPathVersion: bulkFastPathVersionV2,
		},
		{
			frames: []bulkFastFrame{{
				Type:    bulkFastPayloadTypeData,
				DataID:  202,
				Seq:     1,
				Payload: []byte("bulk-b"),
			}},
			fastPathVersion: bulkFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if got, want := len(payloads), 1; got != want {
		t.Fatalf("payload count = %d, want %d", got, want)
	}
	if singleCalls != 0 {
		t.Fatalf("single encode calls = %d, want 0", singleCalls)
	}
	if got, want := len(batchCalls), 1; got != want {
		t.Fatalf("batch encode calls = %d, want %d", got, want)
	}
	if got, want := len(batchCalls[0]), 2; got != want {
		t.Fatalf("batched item count = %d, want %d", got, want)
	}
	// Order within the merged batch must follow request order.
	if got, want := batchCalls[0][0].DataID, uint64(101); got != want {
		t.Fatalf("first batched dataID = %d, want %d", got, want)
	}
	if got, want := batchCalls[0][1].DataID, uint64(202); got != want {
		t.Fatalf("second batched dataID = %d, want %d", got, want)
	}
}
// TestBulkBatchSenderEncodeRequestsSplitsLargeCrossBulkSuperBatch verifies
// that when two bulks each contribute 2x768KiB of frames, the sender splits
// the super-batch into two payloads — presumably at the batch byte limit —
// keeping each bulk's frames together (payload 1 = dataID 101, payload 2 =
// dataID 202).
func TestBulkBatchSenderEncodeRequestsSplitsLargeCrossBulkSuperBatch(t *testing.T) {
	var batchCalls [][]bulkFastFrame
	sender := &bulkBatchSender{
		codec: bulkBatchCodec{
			encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) {
				return []byte("single"), nil, nil
			},
			encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) {
				// Snapshot the frames; the sender may reuse the slice.
				cloned := make([]bulkFastFrame, len(frames))
				copy(cloned, frames)
				batchCalls = append(batchCalls, cloned)
				return []byte("batch"), nil, nil
			},
		},
	}
	// 768KiB per frame forces the combined batch over the soft byte cap.
	payload := make([]byte, 768*1024)
	payloads, err := sender.encodeRequests([]bulkBatchRequest{
		{
			frames: []bulkFastFrame{
				{
					Type:    bulkFastPayloadTypeData,
					DataID:  101,
					Seq:     1,
					Payload: payload,
				},
				{
					Type:    bulkFastPayloadTypeData,
					DataID:  101,
					Seq:     2,
					Payload: payload,
				},
			},
			fastPathVersion: bulkFastPathVersionV2,
		},
		{
			frames: []bulkFastFrame{
				{
					Type:    bulkFastPayloadTypeData,
					DataID:  202,
					Seq:     1,
					Payload: payload,
				},
				{
					Type:    bulkFastPayloadTypeData,
					DataID:  202,
					Seq:     2,
					Payload: payload,
				},
			},
			fastPathVersion: bulkFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if got, want := len(payloads), 2; got != want {
		t.Fatalf("payload count = %d, want %d", got, want)
	}
	if got, want := len(batchCalls), 2; got != want {
		t.Fatalf("batch encode calls = %d, want %d", got, want)
	}
	if got, want := len(batchCalls[0]), 2; got != want {
		t.Fatalf("first payload frame count = %d, want %d", got, want)
	}
	if got, want := len(batchCalls[1]), 2; got != want {
		t.Fatalf("second payload frame count = %d, want %d", got, want)
	}
	if got, want := batchCalls[0][0].DataID, uint64(101); got != want {
		t.Fatalf("first payload dataID = %d, want %d", got, want)
	}
	if got, want := batchCalls[1][0].DataID, uint64(202); got != want {
		t.Fatalf("second payload dataID = %d, want %d", got, want)
	}
}
// TestBulkReleaseFastRoundTripTransport is an end-to-end check that a
// server-sent fast-path window release reaches the client and restores its
// outbound send window. It artificially drains the client window, fires the
// server's release sender, and polls until the window reopens.
func TestBulkReleaseFastRoundTripTransport(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	acceptCh := make(chan BulkAcceptInfo, 1)
	server.SetBulkHandler(func(info BulkAcceptInfo) error {
		acceptCh <- info
		return nil
	})
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() {
		_ = server.Stop()
	}()
	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() {
		_ = client.Stop()
	}()
	const chunkSize = 64 * 1024
	bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
		Range:       BulkRange{Offset: 0, Length: chunkSize},
		ChunkSize:   chunkSize,
		WindowBytes: chunkSize,
		MaxInFlight: 1,
	})
	if err != nil {
		t.Fatalf("client OpenBulk failed: %v", err)
	}
	accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second)
	clientHandle := bulk.(*bulkHandle)
	serverHandle := accepted.Bulk.(*bulkHandle)
	// Both sides must have negotiated the v2 fast path for this test.
	if got, want := clientHandle.FastPathVersion(), uint8(bulkFastPathVersionV2); got != want {
		t.Fatalf("client fast path version = %d, want %d", got, want)
	}
	if got, want := serverHandle.FastPathVersion(), uint8(bulkFastPathVersionV2); got != want {
		t.Fatalf("server fast path version = %d, want %d", got, want)
	}
	// Simulate an exhausted outbound window with one chunk in flight.
	clientHandle.mu.Lock()
	clientHandle.outboundAvailBytes = 0
	clientHandle.outboundInFlight = 1
	clientHandle.mu.Unlock()
	releaseFn := serverBulkReleaseSender(server, accepted.LogicalConn, accepted.TransportConn)
	if err := releaseFn(serverHandle, chunkSize, 1); err != nil {
		t.Fatalf("server bulk release sender failed: %v", err)
	}
	// Poll until the release frame restores the full window (or time out).
	deadline := time.Now().Add(2 * time.Second)
	for time.Now().Before(deadline) {
		clientHandle.mu.Lock()
		avail := clientHandle.outboundAvailBytes
		inFlight := clientHandle.outboundInFlight
		clientHandle.mu.Unlock()
		if avail == chunkSize && inFlight == 0 {
			return
		}
		time.Sleep(10 * time.Millisecond)
	}
	clientHandle.mu.Lock()
	avail := clientHandle.outboundAvailBytes
	inFlight := clientHandle.outboundInFlight
	clientHandle.mu.Unlock()
	t.Fatalf("client outbound window not released by fast path: avail=%d inFlight=%d", avail, inFlight)
}
// TestBulkBatchSenderEncodeRequestsResetsBatchBytesAfterFlushBoundary checks
// that the sender's running byte counter resets after a flush: a near-cap v2
// frame flushes alone, a v1 frame forces single encoding (2 single calls:
// the large v2 frame and the v1 frame — confirm against sender internals),
// and the two small v2 frames after the boundary coalesce into one batch.
func TestBulkBatchSenderEncodeRequestsResetsBatchBytesAfterFlushBoundary(t *testing.T) {
	var (
		singleCalls int
		batchCalls  [][]bulkFastFrame
	)
	sender := &bulkBatchSender{
		codec: bulkBatchCodec{
			encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) {
				singleCalls++
				return []byte("single"), nil, nil
			},
			encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) {
				// Snapshot the frames; the sender may reuse the slice.
				cloned := make([]bulkFastFrame, len(frames))
				copy(cloned, frames)
				batchCalls = append(batchCalls, cloned)
				return []byte("batch"), nil, nil
			},
		},
	}
	// Sized to leave only 128 bytes of headroom under the plain batch cap.
	largePayload := make([]byte, bulkFastBatchMaxPlainBytes-bulkFastBatchHeaderLen-bulkFastBatchItemHeaderLen-128)
	payloads, err := sender.encodeRequests([]bulkBatchRequest{
		{
			frames: []bulkFastFrame{{
				Type:    bulkFastPayloadTypeData,
				DataID:  101,
				Seq:     1,
				Payload: largePayload,
			}},
			fastPathVersion: bulkFastPathVersionV2,
		},
		{
			frames: []bulkFastFrame{{
				Type:    bulkFastPayloadTypeData,
				DataID:  101,
				Seq:     2,
				Payload: []byte("sep"),
			}},
			// v1 cannot join a v2 batch, forcing a flush boundary here.
			fastPathVersion: bulkFastPathVersionV1,
		},
		{
			frames: []bulkFastFrame{{
				Type:    bulkFastPayloadTypeData,
				DataID:  202,
				Seq:     1,
				Payload: []byte("a"),
			}},
			fastPathVersion: bulkFastPathVersionV2,
		},
		{
			frames: []bulkFastFrame{{
				Type:    bulkFastPayloadTypeData,
				DataID:  202,
				Seq:     2,
				Payload: []byte("b"),
			}},
			fastPathVersion: bulkFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if got, want := len(payloads), 3; got != want {
		t.Fatalf("payload count = %d, want %d", got, want)
	}
	if got, want := singleCalls, 2; got != want {
		t.Fatalf("single encode calls = %d, want %d", got, want)
	}
	if got, want := len(batchCalls), 1; got != want {
		t.Fatalf("batch encode calls = %d, want %d", got, want)
	}
	if got, want := len(batchCalls[0]), 2; got != want {
		t.Fatalf("post-flush batched frame count = %d, want %d", got, want)
	}
}
// TestBulkBatchSenderEncodeRequestsUsesBindingAdaptiveSoftLimit verifies
// that the sender honors the binding's adaptive soft payload cap rather
// than the static batch limit: after observing a slow large write (8MiB in
// 640ms), the shrunken cap splits 4x96KiB of frames into two payloads, one
// per bulk.
func TestBulkBatchSenderEncodeRequestsUsesBindingAdaptiveSoftLimit(t *testing.T) {
	binding := &transportBinding{}
	// Feed the adaptive estimator a slow write so the soft cap shrinks.
	binding.observeBulkAdaptivePayloadWrite(8*1024*1024, 640*time.Millisecond, 0, nil)
	var batchCalls [][]bulkFastFrame
	sender := &bulkBatchSender{
		binding: binding,
		codec: bulkBatchCodec{
			encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) {
				return []byte("single"), nil, nil
			},
			encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) {
				// Snapshot the frames; the sender may reuse the slice.
				cloned := make([]bulkFastFrame, len(frames))
				copy(cloned, frames)
				batchCalls = append(batchCalls, cloned)
				return []byte("batch"), nil, nil
			},
		},
	}
	payload := make([]byte, 96*1024)
	payloads, err := sender.encodeRequests([]bulkBatchRequest{
		{
			frames: []bulkFastFrame{
				{Type: bulkFastPayloadTypeData, DataID: 101, Seq: 1, Payload: payload},
				{Type: bulkFastPayloadTypeData, DataID: 101, Seq: 2, Payload: payload},
			},
			fastPathVersion: bulkFastPathVersionV2,
		},
		{
			frames: []bulkFastFrame{
				{Type: bulkFastPayloadTypeData, DataID: 202, Seq: 1, Payload: payload},
				{Type: bulkFastPayloadTypeData, DataID: 202, Seq: 2, Payload: payload},
			},
			fastPathVersion: bulkFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if got, want := len(payloads), 2; got != want {
		t.Fatalf("payload count = %d, want %d", got, want)
	}
	if got, want := len(batchCalls), 2; got != want {
		t.Fatalf("batch encode calls = %d, want %d", got, want)
	}
	// Each payload holds exactly one bulk's pair of frames, in order.
	for index, want := range []uint64{101, 202} {
		if got := batchCalls[index][0].DataID; got != want {
			t.Fatalf("payload %d first dataID = %d, want %d", index, got, want)
		}
		if got, wantItems := len(batchCalls[index]), 2; got != wantItems {
			t.Fatalf("payload %d item count = %d, want %d", index, got, wantItems)
		}
	}
}

View File

@ -7,49 +7,56 @@ import (
)
type BulkSnapshot struct {
ID string
DataID uint64
Scope string
Range BulkRange
Metadata BulkMetadata
BindingOwner string
BindingAlive bool
BindingCurrent bool
BindingReason string
BindingError string
Dedicated bool
DedicatedAttached bool
SessionEpoch uint64
LogicalClientID string
TransportGeneration uint64
TransportAttached bool
TransportHasRuntimeConn bool
TransportCurrent bool
TransportDetachReason string
TransportDetachKind string
TransportDetachGeneration uint64
TransportDetachError string
TransportDetachedAt time.Time
ReattachEligible bool
LocalClosed bool
LocalReadClosed bool
RemoteClosed bool
PeerReadClosed bool
BufferedChunks int
BufferedBytes int
ReadTimeout time.Duration
WriteTimeout time.Duration
ChunkSize int
WindowBytes int
MaxInFlight int
BytesRead int64
BytesWritten int64
ReadCalls int64
WriteCalls int64
OpenedAt time.Time
LastReadAt time.Time
LastWriteAt time.Time
ResetError string
ID string
DataID uint64
FastPathVersion uint8
Scope string
Range BulkRange
Metadata BulkMetadata
BindingOwner string
BindingAlive bool
BindingCurrent bool
BindingReason string
BindingError string
BindingBulkAdaptiveSoftPayloadBytes int
Dedicated bool
DedicatedLaneID uint32
DedicatedAttached bool
DedicatedAttachState string
DedicatedAttachAttempts uint32
DedicatedAttachLastCode string
DedicatedDataStarted bool
SessionEpoch uint64
LogicalClientID string
TransportGeneration uint64
TransportAttached bool
TransportHasRuntimeConn bool
TransportCurrent bool
TransportDetachReason string
TransportDetachKind string
TransportDetachGeneration uint64
TransportDetachError string
TransportDetachedAt time.Time
ReattachEligible bool
LocalClosed bool
LocalReadClosed bool
RemoteClosed bool
PeerReadClosed bool
BufferedChunks int
BufferedBytes int
ReadTimeout time.Duration
WriteTimeout time.Duration
ChunkSize int
WindowBytes int
MaxInFlight int
BytesRead int64
BytesWritten int64
ReadCalls int64
WriteCalls int64
OpenedAt time.Time
LastReadAt time.Time
LastWriteAt time.Time
ResetError string
}
type clientBulkSnapshotReader interface {

View File

@ -88,7 +88,7 @@ func BenchmarkDedicatedWireLocalhostThroughput(b *testing.B) {
func benchmarkDedicatedWireLocalhostThroughput(b *testing.B, key []byte, encode transportFastPlainEncoder, payloadSize int) {
b.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
listener, err := net.Listen("tcp", benchmarkTCPListenAddr(b))
if err != nil {
b.Fatalf("net.Listen failed: %v", err)
}
@ -107,7 +107,7 @@ func benchmarkDedicatedWireLocalhostThroughput(b *testing.B, key []byte, encode
acceptCh <- conn
}()
clientConn, err := net.Dial("tcp", listener.Addr().String())
clientConn, err := net.Dial("tcp", benchmarkTCPDialAddr(b, listener.Addr().String()))
if err != nil {
b.Fatalf("net.Dial failed: %v", err)
}

View File

@ -109,6 +109,9 @@ func TestBulkOpenRoundTripTCP(t *testing.T) {
if !clientSnapshots[0].BindingAlive || !clientSnapshots[0].BindingCurrent || !clientSnapshots[0].TransportAttached || !clientSnapshots[0].TransportCurrent {
t.Fatalf("client bulk binding snapshot mismatch: %+v", clientSnapshots[0])
}
if got, want := clientSnapshots[0].BindingBulkAdaptiveSoftPayloadBytes, bulkAdaptiveSoftPayloadStartBytes; got != want {
t.Fatalf("client bulk BindingBulkAdaptiveSoftPayloadBytes = %d, want %d", got, want)
}
serverSnapshots, err := GetServerBulkSnapshots(server)
if err != nil {
t.Fatalf("GetServerBulkSnapshots failed: %v", err)
@ -119,6 +122,9 @@ func TestBulkOpenRoundTripTCP(t *testing.T) {
if got, want := serverSnapshots[0].BindingOwner, "server-transport"; got != want {
t.Fatalf("server bulk BindingOwner = %q, want %q", got, want)
}
if got, want := serverSnapshots[0].BindingBulkAdaptiveSoftPayloadBytes, bulkAdaptiveSoftPayloadStartBytes; got != want {
t.Fatalf("server bulk BindingBulkAdaptiveSoftPayloadBytes = %d, want %d", got, want)
}
if !serverSnapshots[0].BindingAlive || !serverSnapshots[0].BindingCurrent || !serverSnapshots[0].TransportAttached || !serverSnapshots[0].TransportCurrent {
t.Fatalf("server bulk binding snapshot mismatch: %+v", serverSnapshots[0])
}
@ -135,6 +141,314 @@ func TestBulkOpenRoundTripTCP(t *testing.T) {
waitForBulkContextDone(t, bulk.Context(), 2*time.Second)
}
// TestDedicatedBulkOpenUnblocksSynchronousReadHandler ensures that
// OpenDedicatedBulk completes even while the remote bulk handler is blocked
// in a synchronous read — i.e. the open handshake must not deadlock against
// a handler that reads before the opener ever writes.
func TestDedicatedBulkOpenUnblocksSynchronousReadHandler(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	readCh := make(chan string, 1)
	// Handler blocks in ReadFull immediately, before any client write.
	server.SetBulkHandler(func(info BulkAcceptInfo) error {
		defer func() {
			_ = info.Bulk.Close()
		}()
		buf := make([]byte, 5)
		if _, err := io.ReadFull(info.Bulk, buf); err != nil {
			return err
		}
		readCh <- string(buf)
		return nil
	})
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() {
		_ = server.Stop()
	}()
	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() {
		_ = client.Stop()
	}()
	type openResult struct {
		bulk Bulk
		err  error
	}
	// Open in a goroutine so a handshake deadlock surfaces as a timeout
	// instead of hanging the test binary.
	openCh := make(chan openResult, 1)
	go func() {
		bulk, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{
			ID: "sync-read-handler",
			Range: BulkRange{
				Offset: 0,
				Length: 5,
			},
		})
		openCh <- openResult{bulk: bulk, err: err}
	}()
	var bulk Bulk
	select {
	case result := <-openCh:
		if result.err != nil {
			t.Fatalf("client OpenDedicatedBulk failed: %v", result.err)
		}
		bulk = result.bulk
	case <-time.After(2 * time.Second):
		t.Fatal("client OpenDedicatedBulk timed out while remote handler was synchronously reading")
	}
	defer func() {
		if bulk != nil {
			_ = bulk.Close()
		}
	}()
	if _, err := bulk.Write([]byte("hello")); err != nil {
		t.Fatalf("client dedicated bulk Write failed: %v", err)
	}
	select {
	case got := <-readCh:
		if got != "hello" {
			t.Fatalf("server handler read %q, want %q", got, "hello")
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for synchronous handler read")
	}
}
// TestDedicatedBulkOpenUnblocksOnBlockingFirstWrite ensures the dedicated
// open completes even when the remote handler's very first action is a write
// that exceeds the tiny flow-control window (16 bytes vs a 4-byte window),
// so the handler blocks until the client starts reading.
func TestDedicatedBulkOpenUnblocksOnBlockingFirstWrite(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	payload := strings.Repeat("w", 16)
	writeDone := make(chan error, 1)
	// Handler writes first (blocking on window), then half-closes.
	server.SetBulkHandler(func(info BulkAcceptInfo) error {
		_, err := io.WriteString(info.Bulk, payload)
		if err == nil {
			err = info.Bulk.CloseWrite()
		}
		writeDone <- err
		return err
	})
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() {
		_ = server.Stop()
	}()
	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() {
		_ = client.Stop()
	}()
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	// 4-byte chunk/window with one in flight forces the server write to
	// stall until the client drains.
	bulk, err := client.OpenDedicatedBulk(ctx, BulkOpenOptions{
		ID: "blocking-write-ready",
		Range: BulkRange{
			Offset: 0,
			Length: int64(len(payload)),
		},
		ChunkSize:   4,
		WindowBytes: 4,
		MaxInFlight: 1,
	})
	if err != nil {
		t.Fatalf("client OpenDedicatedBulk failed: %v", err)
	}
	readBulkExactly(t, bulk, payload, 2*time.Second)
	waitForBulkReadEOF(t, bulk, 2*time.Second)
	select {
	case err := <-writeDone:
		if err != nil {
			t.Fatalf("server handler write failed: %v", err)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for blocking first write to finish")
	}
	if err := bulk.Close(); err != nil {
		t.Fatalf("client dedicated bulk Close failed: %v", err)
	}
	waitForBulkContextDone(t, bulk.Context(), 2*time.Second)
}
// TestServerOpenBulkLogicalDedicatedUnblocksOnBlockingFirstRead mirrors the
// client-side test in the server->client direction: a server-initiated
// dedicated bulk open must complete while the CLIENT's handler is blocked in
// a synchronous first read.
func TestServerOpenBulkLogicalDedicatedUnblocksOnBlockingFirstRead(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() {
		_ = server.Stop()
	}()
	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	readCh := make(chan string, 1)
	// Client handler blocks in ReadFull before the server writes anything.
	client.SetBulkHandler(func(info BulkAcceptInfo) error {
		defer func() {
			_ = info.Bulk.Close()
		}()
		buf := make([]byte, 5)
		if _, err := io.ReadFull(info.Bulk, buf); err != nil {
			return err
		}
		readCh <- string(buf)
		return nil
	})
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() {
		_ = client.Stop()
	}()
	logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	bulk, err := server.OpenBulkLogical(ctx, logical, BulkOpenOptions{
		ID: "server-blocking-read-ready",
		Range: BulkRange{
			Offset: 0,
			Length: 5,
		},
		Dedicated: true,
	})
	if err != nil {
		t.Fatalf("server OpenBulkLogical dedicated failed: %v", err)
	}
	if _, err := bulk.Write([]byte("hello")); err != nil {
		t.Fatalf("server dedicated bulk Write failed: %v", err)
	}
	select {
	case got := <-readCh:
		if got != "hello" {
			t.Fatalf("client handler read %q, want %q", got, "hello")
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for blocking first read to finish")
	}
	if err := bulk.Close(); err != nil {
		t.Fatalf("server dedicated bulk Close failed: %v", err)
	}
	waitForBulkContextDone(t, bulk.Context(), 2*time.Second)
}
// TestDedicatedBulkOpenReturnsHandlerFailureAfterAccepted verifies that a
// server handler error raised AFTER the bulk was accepted (the sleep lets
// the accept complete first) is still propagated back to the client's
// OpenDedicatedBulk call instead of being swallowed.
func TestDedicatedBulkOpenReturnsHandlerFailureAfterAccepted(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	server.SetBulkHandler(func(info BulkAcceptInfo) error {
		time.Sleep(80 * time.Millisecond)
		return errors.New("dedicated handler failed after accept")
	})
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() {
		_ = server.Stop()
	}()
	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() {
		_ = client.Stop()
	}()
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	_, err := client.OpenDedicatedBulk(ctx, BulkOpenOptions{
		ID: "accepted-then-fail",
		Range: BulkRange{
			Offset: 0,
			Length: 1,
		},
	})
	// The handler's message must surface in the open error.
	if err == nil || !strings.Contains(err.Error(), "dedicated handler failed after accept") {
		t.Fatalf("client OpenDedicatedBulk error = %v, want dedicated handler failure after accept", err)
	}
}
// TestServerOpenBulkLogicalDedicatedReturnsHandlerFailureAfterAccepted is
// the server-initiated mirror of the previous test: a CLIENT handler error
// raised after accept must propagate back to the server's OpenBulkLogical
// call.
func TestServerOpenBulkLogicalDedicatedReturnsHandlerFailureAfterAccepted(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() {
		_ = server.Stop()
	}()
	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	client.SetBulkHandler(func(info BulkAcceptInfo) error {
		time.Sleep(80 * time.Millisecond)
		return errors.New("client dedicated handler failed after accept")
	})
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() {
		_ = client.Stop()
	}()
	logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	_, err := server.OpenBulkLogical(ctx, logical, BulkOpenOptions{
		ID: "server-accepted-then-fail",
		Range: BulkRange{
			Offset: 0,
			Length: 1,
		},
		Dedicated: true,
	})
	// The client handler's message must surface in the server-side error.
	if err == nil || !strings.Contains(err.Error(), "client dedicated handler failed after accept") {
		t.Fatalf("server OpenBulkLogical dedicated error = %v, want client dedicated handler failure after accept", err)
	}
}
func TestBulkOpenRoundTripServerLogicalTCP(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
@ -532,15 +846,18 @@ func TestDedicatedBulkWritePrefersClosedPipeOverContextCanceled(t *testing.T) {
MaxInFlight: 4,
}, 0, nil, nil, 0, nil, nil, func(context.Context, *bulkHandle, []byte) error {
return nil
}, func(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) {
}, func(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) {
bulk.markPeerClosed()
<-ctx.Done()
return 0, ctx.Err()
}, nil)
_, err := bulk.Write([]byte("abcdefgh"))
if !errors.Is(err, io.ErrClosedPipe) {
t.Fatalf("bulk Write error = %v, want %v", err, io.ErrClosedPipe)
if err != nil && !errors.Is(err, io.ErrClosedPipe) {
t.Fatalf("bulk Write error = %v, want nil or %v", err, io.ErrClosedPipe)
}
if err := bulk.waitPendingAsyncWrites(context.Background()); err != nil && !errors.Is(err, io.ErrClosedPipe) {
t.Fatalf("bulk waitPendingAsyncWrites error = %v, want nil or %v", err, io.ErrClosedPipe)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -35,6 +35,7 @@ type ClientCommon struct {
fastStreamEncode transportFastStreamEncoder
fastBulkEncode transportFastBulkEncoder
fastPlainEncode transportFastPlainEncoder
modernPSKRuntime *modernPSKCodecRuntime
handshakeRsaPubKey []byte
SecretKey []byte
noFinSyncMsgMaxKeepSeconds int
@ -55,6 +56,26 @@ type ClientCommon struct {
streamRuntime *streamRuntime
recordRuntime *recordRuntime
bulkRuntime *bulkRuntime
bulkDefaultOpenMode BulkOpenMode
bulkNetworkProfile BulkNetworkProfile
bulkOpenTuning BulkOpenTuning
bulkDedicatedAttachLimit int
bulkDedicatedAttachSem chan struct{}
bulkDedicatedAttachRetry int
bulkDedicatedAttachBackoff time.Duration
bulkDedicatedDialTimeout time.Duration
bulkDedicatedHelloTimeout time.Duration
bulkDedicatedActiveLimit int
bulkDedicatedActive atomic.Int32
bulkDedicatedActiveWait chan struct{}
bulkDedicatedLaneLimit int
bulkDedicatedSidecarMu sync.Mutex
bulkDedicatedLanes map[uint32]*bulkDedicatedLane
bulkDedicatedNextLaneID uint32
bulkAttachAttemptCount atomic.Int64
bulkAttachRetryCount atomic.Int64
bulkAttachSuccessCount atomic.Int64
bulkAttachFallbackCount atomic.Int64
connectionRetryState *connectionRetryState
securityReadyCheck bool
debugMode bool
@ -63,21 +84,32 @@ type ClientCommon struct {
func NewClient() Client {
transport := defaultModernPSKTransportBundle()
var client = ClientCommon{
maxReadTimeout: 0,
maxWriteTimeout: 0,
peerIdentity: newClientPeerIdentity(),
sequenceEn: encode,
sequenceDe: Decode,
keyExchangeFn: aesRsaHello,
SecretKey: nil,
handshakeRsaPubKey: defaultRsaPubKey,
msgEn: transport.msgEn,
msgDe: transport.msgDe,
fastStreamEncode: transport.fastStreamEncode,
fastBulkEncode: transport.fastBulkEncode,
fastPlainEncode: transport.fastPlainEncode,
skipKeyExchange: true,
securityReadyCheck: true,
maxReadTimeout: 0,
maxWriteTimeout: 0,
peerIdentity: newClientPeerIdentity(),
sequenceEn: encode,
sequenceDe: Decode,
keyExchangeFn: aesRsaHello,
SecretKey: nil,
handshakeRsaPubKey: defaultRsaPubKey,
msgEn: transport.msgEn,
msgDe: transport.msgDe,
fastStreamEncode: transport.fastStreamEncode,
fastBulkEncode: transport.fastBulkEncode,
fastPlainEncode: transport.fastPlainEncode,
skipKeyExchange: true,
securityReadyCheck: true,
bulkDefaultOpenMode: BulkOpenModeShared,
bulkNetworkProfile: BulkNetworkProfileDefault,
bulkOpenTuning: defaultBulkOpenTuning(),
bulkDedicatedAttachLimit: defaultBulkDedicatedAttachLimit,
bulkDedicatedAttachRetry: defaultBulkDedicatedAttachRetry,
bulkDedicatedAttachBackoff: defaultBulkDedicatedAttachBackoff,
bulkDedicatedDialTimeout: defaultBulkDedicatedDialTimeout,
bulkDedicatedHelloTimeout: defaultBulkDedicatedHelloTimeout,
bulkDedicatedActiveLimit: defaultBulkDedicatedActiveLimit,
bulkDedicatedActiveWait: make(chan struct{}),
bulkDedicatedLaneLimit: defaultBulkDedicatedLaneLimit,
}
client.alive.Store(false)
client.useHeartBeat = true
@ -93,6 +125,10 @@ func NewClient() Client {
client.streamRuntime = newStreamRuntime("cstrm")
client.recordRuntime = newRecordRuntime()
client.bulkRuntime = newBulkRuntime("cblk")
client.bulkDedicatedLanes = make(map[uint32]*bulkDedicatedLane)
if client.bulkDedicatedAttachLimit > 0 {
client.bulkDedicatedAttachSem = make(chan struct{}, client.bulkDedicatedAttachLimit)
}
client.connectionRetryState = newConnectionRetryState()
client.onFileEvent = normalizeFileEventCallback(nil)
client.fileEventObserver = normalizeFileEventCallback(nil)
@ -103,3 +139,12 @@ func NewClient() Client {
client.getTransferState().setBuiltinHandler(client.builtinFileTransferHandler)
return &client
}
// maxWriteTimeoutSnapshot returns the client's configured maximum write
// timeout, read under the client mutex. A nil receiver reports zero.
func (c *ClientCommon) maxWriteTimeoutSnapshot() time.Duration {
	if c == nil {
		return 0
	}
	c.mu.Lock()
	timeout := c.maxWriteTimeout
	c.mu.Unlock()
	return timeout
}

View File

@ -1,6 +1,9 @@
package notify
import "context"
import (
"context"
"errors"
)
func (c *ClientCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) {
runtime := c.getBulkRuntime()
@ -10,10 +13,67 @@ func (c *ClientCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) {
runtime.setHandler(fn)
}
// OpenSharedBulk opens a bulk stream over the shared transport connection,
// overriding any default open mode carried in opt.
func (c *ClientCommon) OpenSharedBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) {
	opt.Dedicated = false
	opt.Mode = BulkOpenModeShared
	return c.OpenBulk(ctx, opt)
}
// OpenDedicatedBulk opens a bulk stream on its own dedicated transport
// connection, overriding any default open mode carried in opt.
func (c *ClientCommon) OpenDedicatedBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) {
	opt.Dedicated = true
	opt.Mode = BulkOpenModeDedicated
	return c.OpenBulk(ctx, opt)
}
// OpenBulk opens a bulk stream according to opt.Mode: dedicated, shared, or
// auto (dedicated first, shared as a fallback). An unset mode without the
// Dedicated flag resolves to the client's configured default open mode.
// The fallback counter is incremented whenever a shared open substitutes
// for a dedicated one: immediately when dedicated transport is unsupported,
// or after a successful shared open following a dedicated failure.
func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) {
	if !opt.Dedicated && normalizeBulkOpenMode(opt.Mode) == BulkOpenModeDefault {
		opt.Mode = c.bulkDefaultOpenModeSnapshot()
	}
	opt = normalizeBulkOpenOptions(opt)
	openShared := func() (Bulk, error) {
		opt.Mode = BulkOpenModeShared
		opt.Dedicated = false
		return c.openBulkWithDedicatedMode(ctx, opt)
	}
	switch opt.Mode {
	case BulkOpenModeDedicated:
		opt.Dedicated = true
		return c.openBulkWithDedicatedMode(ctx, opt)
	case BulkOpenModeAuto:
		if supportErr := clientDedicatedBulkSupportError(c); supportErr != nil {
			// Dedicated transport unavailable: count the fallback and go
			// straight to the shared path.
			c.bulkAttachFallbackCount.Add(1)
			return openShared()
		}
		dedicatedOpt := opt
		dedicatedOpt.Mode = BulkOpenModeDedicated
		dedicatedOpt.Dedicated = true
		bulk, dedicatedErr := c.openBulkWithDedicatedMode(ctx, dedicatedOpt)
		if dedicatedErr == nil {
			return bulk, nil
		}
		sharedBulk, sharedErr := openShared()
		if sharedErr != nil {
			// Report both failures so callers see why neither path worked.
			return nil, errors.Join(dedicatedErr, sharedErr)
		}
		c.bulkAttachFallbackCount.Add(1)
		return sharedBulk, nil
	default:
		// Shared, Default, and any unknown mode all take the shared path.
		return openShared()
	}
}
func (c *ClientCommon) openBulkWithDedicatedMode(ctx context.Context, opt BulkOpenOptions) (Bulk, error) {
if c == nil {
return nil, errBulkClientNil
}
opt = applyBulkOpenTuningDefaults(opt, c.bulkOpenTuningSnapshot())
runtime := c.getBulkRuntime()
if runtime == nil {
return nil, errBulkRuntimeNil
@ -33,6 +93,66 @@ func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk,
if _, exists := runtime.lookup(clientFileScope(), req.BulkID); exists {
return nil, errBulkAlreadyExists
}
if req.Dedicated {
req.DedicatedLaneID = c.reserveBulkDedicatedLane()
if req.DataID == 0 {
req.DataID = runtime.nextDataID()
}
if req.AttachToken == "" {
req.AttachToken = newBulkAttachToken()
}
bulk := newBulkHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, 0, clientBulkCloseSender(c), clientBulkResetSender(c), clientBulkDataSender(c, c.currentClientSessionEpoch()), clientBulkWriteSender(c, c.currentClientSessionEpoch()), clientBulkReleaseSender(c))
bulk.setClientSnapshotOwner(c)
bulk.markAcceptHandled()
if err := runtime.register(clientFileScope(), bulk); err != nil {
c.releaseBulkDedicatedLane(req.DedicatedLaneID)
return nil, err
}
resp, err := sendBulkOpenClient(ctx, c, req)
if err != nil {
bulk.markReset(err)
return nil, err
}
if resp.DataID != 0 && resp.DataID != req.DataID {
err = errBulkAlreadyExists
_, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
Error: "bulk dedicated data id mismatch",
})
bulk.markReset(err)
return nil, err
}
if resp.TransportGeneration != 0 {
bulk.transportGeneration = resp.TransportGeneration
}
if resp.FastPathVersion != 0 {
bulk.fastPathVersion = normalizeBulkFastPathVersion(resp.FastPathVersion)
}
if resp.AttachToken != "" {
req.AttachToken = resp.AttachToken
bulk.setDedicatedAttachToken(resp.AttachToken)
}
if err := c.attachDedicatedBulkSidecar(ctx, bulk); err != nil {
_, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
Error: err.Error(),
})
bulk.markReset(err)
return nil, err
}
if err := bulk.waitAcceptReady(ctx); err != nil {
_, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
Error: err.Error(),
})
bulk.markReset(err)
return nil, err
}
return bulk, nil
}
resp, err := sendBulkOpenClient(ctx, c, req)
if err != nil {
return nil, err
@ -40,6 +160,9 @@ func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk,
if resp.DataID != 0 {
req.DataID = resp.DataID
}
if resp.FastPathVersion != 0 {
req.FastPathVersion = resp.FastPathVersion
}
req.Dedicated = resp.Dedicated
if resp.AttachToken != "" {
req.AttachToken = resp.AttachToken
@ -49,7 +172,9 @@ func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk,
}
bulk := newBulkHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientBulkCloseSender(c), clientBulkResetSender(c), clientBulkDataSender(c, c.currentClientSessionEpoch()), clientBulkWriteSender(c, c.currentClientSessionEpoch()), clientBulkReleaseSender(c))
bulk.setClientSnapshotOwner(c)
bulk.markAcceptHandled()
if err := runtime.register(clientFileScope(), bulk); err != nil {
c.releaseBulkDedicatedLane(req.DedicatedLaneID)
_, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
@ -78,15 +203,16 @@ func clientBulkRequest(runtime *bulkRuntime, opt BulkOpenOptions) BulkOpenReques
id = runtime.nextID()
}
return normalizeBulkOpenRequest(BulkOpenRequest{
BulkID: id,
Range: opt.Range,
Metadata: cloneBulkMetadata(opt.Metadata),
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
Dedicated: opt.Dedicated,
ChunkSize: opt.ChunkSize,
WindowBytes: opt.WindowBytes,
MaxInFlight: opt.MaxInFlight,
BulkID: id,
FastPathVersion: bulkFastPathVersionCurrent,
Range: opt.Range,
Metadata: cloneBulkMetadata(opt.Metadata),
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
Dedicated: opt.Dedicated,
ChunkSize: opt.ChunkSize,
WindowBytes: opt.WindowBytes,
MaxInFlight: opt.MaxInFlight,
})
}
@ -148,12 +274,12 @@ func clientBulkDataSender(c *ClientCommon, epoch uint64) bulkDataSender {
if dataID == 0 {
return errBulkDataPathNotReady
}
return c.sendFastBulkData(ctx, dataID, bulk.nextOutboundDataSeq(), chunk)
return c.sendFastBulkData(ctx, dataID, bulk.nextOutboundDataSeq(), chunk, bulk.fastPathVersionSnapshot())
}
}
func clientBulkWriteSender(c *ClientCommon, epoch uint64) bulkWriteSender {
return func(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) {
return func(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) {
if c == nil {
return 0, errBulkClientNil
}
@ -168,12 +294,19 @@ func clientBulkWriteSender(c *ClientCommon, epoch uint64) bulkWriteSender {
if err := bulk.waitDedicatedReady(ctx); err != nil {
return 0, err
}
return c.sendDedicatedBulkWrite(ctx, bulk, payload)
return c.sendDedicatedBulkWrite(ctx, bulk, startSeq, payload)
}
if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) {
return 0, errTransportDetached
}
return 0, nil
if bulk == nil {
return 0, errBulkRuntimeNil
}
dataID := bulk.dataIDSnapshot()
if dataID == 0 {
return 0, errBulkDataPathNotReady
}
return c.sendFastBulkWrite(ctx, dataID, startSeq, bulk.chunkSize, bulk.fastPathVersionSnapshot(), payload, payloadOwned)
}
}
@ -185,8 +318,20 @@ func clientBulkReleaseSender(c *ClientCommon) bulkReleaseSender {
if bytes <= 0 && chunks <= 0 {
return nil
}
ctx, cancel, err := bulk.newWriteContext(bulk.Context(), bulk.writeTimeout)
if err != nil {
return err
}
defer cancel()
if bulk.Dedicated() {
return c.sendDedicatedBulkRelease(context.Background(), bulk, bytes, chunks)
return c.sendDedicatedBulkRelease(ctx, bulk, bytes, chunks)
}
if bulk.fastPathVersionSnapshot() >= bulkFastPathVersionV2 {
payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks)
if err != nil {
return err
}
return c.sendFastBulkControl(ctx, bulkFastPayloadTypeRelease, 0, bulk.dataIDSnapshot(), 0, bulk.fastPathVersionSnapshot(), payload)
}
return sendBulkReleaseClient(c, BulkReleaseRequest{
BulkID: bulk.ID(),

364
client_bulk_config.go Normal file
View File

@ -0,0 +1,364 @@
package notify
import "time"
// defaultBulkOpenTuning returns the package-default bulk open tuning,
// passed through normalization for consistency with user-supplied values.
func defaultBulkOpenTuning() BulkOpenTuning {
	tuning := BulkOpenTuning{
		ChunkSize:   defaultBulkChunkSize,
		WindowBytes: defaultBulkOpenWindowBytes,
		MaxInFlight: defaultBulkOpenMaxInFlight,
	}
	return normalizeBulkOpenTuning(tuning)
}
// normalizeBulkOpenTuning clamps tuning to usable values: non-positive chunk
// size, window, and in-flight counts fall back to package defaults, and the
// window is never allowed to be smaller than one chunk.
func normalizeBulkOpenTuning(tuning BulkOpenTuning) BulkOpenTuning {
	out := tuning
	if out.ChunkSize <= 0 {
		out.ChunkSize = defaultBulkChunkSize
	}
	if out.WindowBytes <= 0 {
		out.WindowBytes = defaultBulkOpenWindowBytes
	}
	if out.WindowBytes < out.ChunkSize {
		out.WindowBytes = out.ChunkSize
	}
	if out.MaxInFlight <= 0 {
		out.MaxInFlight = defaultBulkOpenMaxInFlight
	}
	return out
}
// applyBulkOpenTuningDefaults fills any unset (non-positive) fields of opt
// from the supplied tuning, which is normalized first.
func applyBulkOpenTuningDefaults(opt BulkOpenOptions, tuning BulkOpenTuning) BulkOpenOptions {
	defaults := normalizeBulkOpenTuning(tuning)
	if opt.ChunkSize <= 0 {
		opt.ChunkSize = defaults.ChunkSize
	}
	if opt.WindowBytes <= 0 {
		opt.WindowBytes = defaults.WindowBytes
	}
	if opt.MaxInFlight <= 0 {
		opt.MaxInFlight = defaults.MaxInFlight
	}
	return opt
}
// normalizeBulkNetworkProfile maps any unknown profile value to the default.
func normalizeBulkNetworkProfile(profile BulkNetworkProfile) BulkNetworkProfile {
	switch profile {
	case BulkNetworkProfileDefault, BulkNetworkProfileLAN, BulkNetworkProfileWAN, BulkNetworkProfileConstrained:
		return profile
	}
	return BulkNetworkProfileDefault
}
// bulkNetworkProfileName returns the human-readable name of a bulk network
// profile after normalization. The original `case Default: fallthrough` into
// `default` was redundant; the default branch covers it.
func bulkNetworkProfileName(profile BulkNetworkProfile) string {
	switch normalizeBulkNetworkProfile(profile) {
	case BulkNetworkProfileLAN:
		return "lan"
	case BulkNetworkProfileWAN:
		return "wan"
	case BulkNetworkProfileConstrained:
		return "constrained"
	default:
		// normalizeBulkNetworkProfile only yields Default beyond the cases above.
		return "default"
	}
}
// bulkNetworkProfilePreset returns the preset triple (default open mode,
// dedicated attach config, open tuning) for a bulk network profile. Unknown
// profiles normalize to the default preset.
func bulkNetworkProfilePreset(profile BulkNetworkProfile) (BulkOpenMode, BulkDedicatedAttachConfig, BulkOpenTuning) {
	switch normalizeBulkNetworkProfile(profile) {
	case BulkNetworkProfileLAN:
		// LAN: auto mode, more lanes, large chunks and a deep window.
		return BulkOpenModeAuto, BulkDedicatedAttachConfig{
			AttachLimit:  defaultBulkDedicatedAttachLimit,
			ActiveLimit:  defaultBulkDedicatedActiveLimit,
			LaneLimit:    4,
			Retry:        defaultBulkDedicatedAttachRetry,
			Backoff:      defaultBulkDedicatedAttachBackoff,
			DialTimeout:  defaultBulkDedicatedDialTimeout,
			HelloTimeout: defaultBulkDedicatedHelloTimeout,
		}, BulkOpenTuning{
			ChunkSize:   1024 * 1024,
			WindowBytes: 32 * 1024 * 1024,
			MaxInFlight: 64,
		}
	case BulkNetworkProfileWAN:
		// WAN: auto mode with longer timeouts/backoff and a smaller window.
		return BulkOpenModeAuto, BulkDedicatedAttachConfig{
			AttachLimit:  2,
			ActiveLimit:  defaultBulkDedicatedActiveLimit,
			LaneLimit:    2,
			Retry:        4,
			Backoff:      250 * time.Millisecond,
			DialTimeout:  15 * time.Second,
			HelloTimeout: 20 * time.Second,
		}, BulkOpenTuning{
			ChunkSize:   512 * 1024,
			WindowBytes: 8 * 1024 * 1024,
			MaxInFlight: 16,
		}
	case BulkNetworkProfileConstrained:
		// Constrained: shared mode only, minimal concurrency and window.
		return BulkOpenModeShared, BulkDedicatedAttachConfig{
			AttachLimit:  1,
			ActiveLimit:  2,
			LaneLimit:    1,
			Retry:        5,
			Backoff:      400 * time.Millisecond,
			DialTimeout:  20 * time.Second,
			HelloTimeout: 30 * time.Second,
		}, BulkOpenTuning{
			ChunkSize:   128 * 1024,
			WindowBytes: 1 * 1024 * 1024,
			MaxInFlight: 8,
		}
	default:
		// Default: shared mode with all package defaults.
		return BulkOpenModeShared, BulkDedicatedAttachConfig{
			AttachLimit:  defaultBulkDedicatedAttachLimit,
			ActiveLimit:  defaultBulkDedicatedActiveLimit,
			LaneLimit:    defaultBulkDedicatedLaneLimit,
			Retry:        defaultBulkDedicatedAttachRetry,
			Backoff:      defaultBulkDedicatedAttachBackoff,
			DialTimeout:  defaultBulkDedicatedDialTimeout,
			HelloTimeout: defaultBulkDedicatedHelloTimeout,
		}, defaultBulkOpenTuning()
	}
}
// normalizeBulkDedicatedAttachConfig clamps negative limits/retry counts to
// zero and substitutes package defaults for non-positive backoff and timeout
// values.
func normalizeBulkDedicatedAttachConfig(cfg BulkDedicatedAttachConfig) BulkDedicatedAttachConfig {
	if cfg.AttachLimit < 0 {
		cfg.AttachLimit = 0
	}
	if cfg.ActiveLimit < 0 {
		cfg.ActiveLimit = 0
	}
	if cfg.LaneLimit < 0 {
		cfg.LaneLimit = 0
	}
	if cfg.Retry < 0 {
		cfg.Retry = 0
	}
	if cfg.Backoff <= 0 {
		cfg.Backoff = defaultBulkDedicatedAttachBackoff
	}
	if cfg.DialTimeout <= 0 {
		cfg.DialTimeout = defaultBulkDedicatedDialTimeout
	}
	if cfg.HelloTimeout <= 0 {
		cfg.HelloTimeout = defaultBulkDedicatedHelloTimeout
	}
	return cfg
}
// setBulkDedicatedAttachSemaphoreLocked rebuilds the dedicated-attach
// semaphore for the given limit; non-positive limits clear it. Caller must
// hold c.mu.
func (c *ClientCommon) setBulkDedicatedAttachSemaphoreLocked(limit int) {
	if limit > 0 {
		c.bulkDedicatedAttachSem = make(chan struct{}, limit)
		return
	}
	c.bulkDedicatedAttachSem = nil
}
// applyBulkDedicatedAttachConfigLocked installs a normalized dedicated
// attach configuration. Caller must hold c.mu.
func (c *ClientCommon) applyBulkDedicatedAttachConfigLocked(cfg BulkDedicatedAttachConfig) {
	cfg = normalizeBulkDedicatedAttachConfig(cfg)
	c.bulkDedicatedAttachLimit = cfg.AttachLimit
	// Rebuild the attach semaphore so its capacity matches the new limit.
	c.setBulkDedicatedAttachSemaphoreLocked(cfg.AttachLimit)
	c.bulkDedicatedActiveLimit = cfg.ActiveLimit
	// Wake active-limit waiters so they re-evaluate against the new limit.
	c.notifyBulkDedicatedActiveWaitersLocked()
	c.bulkDedicatedLaneLimit = cfg.LaneLimit
	c.bulkDedicatedAttachRetry = cfg.Retry
	c.bulkDedicatedAttachBackoff = cfg.Backoff
	c.bulkDedicatedDialTimeout = cfg.DialTimeout
	c.bulkDedicatedHelloTimeout = cfg.HelloTimeout
}
// ensureBulkDedicatedActiveWaitLocked lazily creates and returns the channel
// used to signal dedicated-active waiters. Caller must hold c.mu; a nil
// receiver returns nil.
func (c *ClientCommon) ensureBulkDedicatedActiveWaitLocked() chan struct{} {
	if c == nil {
		return nil
	}
	if waitCh := c.bulkDedicatedActiveWait; waitCh != nil {
		return waitCh
	}
	c.bulkDedicatedActiveWait = make(chan struct{})
	return c.bulkDedicatedActiveWait
}
// notifyBulkDedicatedActiveWaitersLocked broadcasts to all current waiters by
// closing the wait channel, then installs a fresh channel for future waiters.
// Caller must hold c.mu.
func (c *ClientCommon) notifyBulkDedicatedActiveWaitersLocked() {
	if c == nil {
		return
	}
	close(c.ensureBulkDedicatedActiveWaitLocked())
	c.bulkDedicatedActiveWait = make(chan struct{})
}
// applyBulkOpenTuningLocked stores the normalized tuning. Caller must hold c.mu.
func (c *ClientCommon) applyBulkOpenTuningLocked(tuning BulkOpenTuning) {
	c.bulkOpenTuning = normalizeBulkOpenTuning(tuning)
}
// bulkDefaultOpenModeSnapshot returns the configured default bulk open mode
// under the client mutex, collapsing Default to Shared; a nil receiver
// reports Shared.
func (c *ClientCommon) bulkDefaultOpenModeSnapshot() BulkOpenMode {
	if c == nil {
		return BulkOpenModeShared
	}
	c.mu.Lock()
	mode := normalizeBulkOpenMode(c.bulkDefaultOpenMode)
	c.mu.Unlock()
	if mode == BulkOpenModeDefault {
		return BulkOpenModeShared
	}
	return mode
}
// bulkDedicatedAttachSemaphoreSnapshot returns the current attach semaphore
// channel under the client mutex; nil when disabled or receiver is nil.
func (c *ClientCommon) bulkDedicatedAttachSemaphoreSnapshot() chan struct{} {
	if c == nil {
		return nil
	}
	c.mu.Lock()
	sem := c.bulkDedicatedAttachSem
	c.mu.Unlock()
	return sem
}
// bulkDedicatedActiveLimitSnapshot returns the dedicated active limit under
// the client mutex; negative configured values and a nil receiver report 0.
func (c *ClientCommon) bulkDedicatedActiveLimitSnapshot() int {
	if c == nil {
		return 0
	}
	c.mu.Lock()
	limit := c.bulkDedicatedActiveLimit
	c.mu.Unlock()
	if limit < 0 {
		return 0
	}
	return limit
}
// bulkDedicatedActiveWaitSnapshot returns (creating if needed) the channel
// that signals dedicated-active waiters; nil for a nil receiver.
func (c *ClientCommon) bulkDedicatedActiveWaitSnapshot() chan struct{} {
	if c == nil {
		return nil
	}
	c.mu.Lock()
	waitCh := c.ensureBulkDedicatedActiveWaitLocked()
	c.mu.Unlock()
	return waitCh
}
// bulkDedicatedLaneLimitSnapshot returns the configured dedicated lane limit
// under the client mutex. Negative and zero configured values both collapse
// to 0 (the original had two separate branches for `< 0` and `== 0` with the
// same result; merged into one `<= 0` check). A nil receiver reports the
// package default.
func (c *ClientCommon) bulkDedicatedLaneLimitSnapshot() int {
	if c == nil {
		return defaultBulkDedicatedLaneLimit
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.bulkDedicatedLaneLimit <= 0 {
		return 0
	}
	return c.bulkDedicatedLaneLimit
}
// SetBulkDefaultOpenMode sets the mode used when OpenBulk is called without
// an explicit mode. Default is stored as Shared; no-op on a nil receiver.
func (c *ClientCommon) SetBulkDefaultOpenMode(mode BulkOpenMode) {
	if c == nil {
		return
	}
	normalized := normalizeBulkOpenMode(mode)
	if normalized == BulkOpenModeDefault {
		normalized = BulkOpenModeShared
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.bulkDefaultOpenMode = normalized
}
// BulkDefaultOpenMode reports the mode OpenBulk uses when none is specified.
func (c *ClientCommon) BulkDefaultOpenMode() BulkOpenMode {
	return c.bulkDefaultOpenModeSnapshot()
}
// bulkOpenTuningSnapshot returns the normalized bulk open tuning; a nil
// receiver reports the package default.
func (c *ClientCommon) bulkOpenTuningSnapshot() BulkOpenTuning {
	if c == nil {
		return defaultBulkOpenTuning()
	}
	c.mu.Lock()
	tuning := c.bulkOpenTuning
	c.mu.Unlock()
	return normalizeBulkOpenTuning(tuning)
}
// SetBulkOpenTuning installs new bulk open tuning; no-op on a nil receiver.
func (c *ClientCommon) SetBulkOpenTuning(tuning BulkOpenTuning) {
	if c == nil {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.applyBulkOpenTuningLocked(tuning)
}
// BulkOpenTuning reports the current (normalized) bulk open tuning.
func (c *ClientCommon) BulkOpenTuning() BulkOpenTuning {
	return c.bulkOpenTuningSnapshot()
}
// SetBulkDedicatedAttachConfig installs a dedicated attach configuration;
// no-op on a nil receiver.
func (c *ClientCommon) SetBulkDedicatedAttachConfig(cfg BulkDedicatedAttachConfig) {
	if c == nil {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.applyBulkDedicatedAttachConfigLocked(cfg)
}
// BulkDedicatedAttachConfig reports the current dedicated attach
// configuration, normalized; a nil receiver reports the normalized zero
// config.
func (c *ClientCommon) BulkDedicatedAttachConfig() BulkDedicatedAttachConfig {
	if c == nil {
		return normalizeBulkDedicatedAttachConfig(BulkDedicatedAttachConfig{})
	}
	c.mu.Lock()
	cfg := BulkDedicatedAttachConfig{
		AttachLimit:  c.bulkDedicatedAttachLimit,
		ActiveLimit:  c.bulkDedicatedActiveLimit,
		LaneLimit:    c.bulkDedicatedLaneLimit,
		Retry:        c.bulkDedicatedAttachRetry,
		Backoff:      c.bulkDedicatedAttachBackoff,
		DialTimeout:  c.bulkDedicatedDialTimeout,
		HelloTimeout: c.bulkDedicatedHelloTimeout,
	}
	c.mu.Unlock()
	return normalizeBulkDedicatedAttachConfig(cfg)
}
// SetBulkNetworkProfile applies a named network profile preset: it sets the
// default open mode, the dedicated attach config, and the open tuning in one
// locked update. No-op on a nil receiver.
func (c *ClientCommon) SetBulkNetworkProfile(profile BulkNetworkProfile) {
	if c == nil {
		return
	}
	normalized := normalizeBulkNetworkProfile(profile)
	mode, attachCfg, tuning := bulkNetworkProfilePreset(normalized)
	c.mu.Lock()
	defer c.mu.Unlock()
	c.bulkNetworkProfile = normalized
	c.bulkDefaultOpenMode = mode
	c.applyBulkDedicatedAttachConfigLocked(attachCfg)
	c.applyBulkOpenTuningLocked(tuning)
}
// BulkNetworkProfile reports the currently applied network profile,
// normalized; a nil receiver reports the default profile.
func (c *ClientCommon) BulkNetworkProfile() BulkNetworkProfile {
	if c == nil {
		return BulkNetworkProfileDefault
	}
	c.mu.Lock()
	profile := c.bulkNetworkProfile
	c.mu.Unlock()
	return normalizeBulkNetworkProfile(profile)
}
// applyBulkOpenTuningLocked stores the normalized tuning. Caller must hold s.mu.
func (s *ServerCommon) applyBulkOpenTuningLocked(tuning BulkOpenTuning) {
	s.bulkOpenTuning = normalizeBulkOpenTuning(tuning)
}
// bulkOpenTuningSnapshot returns the server's normalized bulk open tuning;
// a nil receiver reports the package default.
func (s *ServerCommon) bulkOpenTuningSnapshot() BulkOpenTuning {
	if s == nil {
		return defaultBulkOpenTuning()
	}
	s.mu.RLock()
	tuning := s.bulkOpenTuning
	s.mu.RUnlock()
	return normalizeBulkOpenTuning(tuning)
}
// SetBulkOpenTuning installs new bulk open tuning on the server; no-op on a
// nil receiver.
func (s *ServerCommon) SetBulkOpenTuning(tuning BulkOpenTuning) {
	if s == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.applyBulkOpenTuningLocked(tuning)
}
// BulkOpenTuning reports the server's current (normalized) bulk open tuning.
func (s *ServerCommon) BulkOpenTuning() BulkOpenTuning {
	return s.bulkOpenTuningSnapshot()
}

View File

@ -75,6 +75,7 @@ func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) {
c.fastStreamEncode = nil
c.fastBulkEncode = nil
c.fastPlainEncode = nil
c.modernPSKRuntime = nil
c.securityReadyCheck = false
}
@ -89,6 +90,7 @@ func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) {
c.fastStreamEncode = nil
c.fastBulkEncode = nil
c.fastPlainEncode = nil
c.modernPSKRuntime = nil
c.securityReadyCheck = false
}
@ -108,6 +110,13 @@ func (c *ClientCommon) GetSecretKey() []byte {
// Prefer UseModernPSKClient or UseLegacySecurityClient.
func (c *ClientCommon) SetSecretKey(key []byte) {
	c.SecretKey = key
	// Keep the modern PSK runtime in sync with the key: an empty key or a
	// failed runtime construction both clear it.
	if len(key) == 0 {
		c.modernPSKRuntime = nil
	} else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil {
		c.modernPSKRuntime = runtime
	} else {
		c.modernPSKRuntime = nil
	}
	// With a key set, security is not "ready" until checked, and the key
	// exchange is skipped entirely (pre-shared key path).
	c.securityReadyCheck = len(key) == 0
	c.skipKeyExchange = true
}

View File

@ -72,6 +72,23 @@ func (c *ClientConn) readTUMessageLoop(rt *clientConnSessionRuntime) {
conn := rt.tuConn
generation := rt.transportGeneration
defer closeClientConnSessionRuntimeTransportDone(rt)
if conn != nil && !isPacketTransportConn(conn) {
reader := newTransportFrameReader(conn, nil)
for {
select {
case <-sessionStopChan(stopCtx):
if c.shouldCloseClientConnTransportOnStop(conn) {
_ = conn.Close()
}
return
default:
}
payload, release, err := c.readTUTransportPayloadPooled(conn, reader)
if !c.handleTUTransportPayloadReadResultWithSessionPooled(stopCtx, conn, generation, payload, release, err) {
return
}
}
}
buf := streamReadBuffer()
for {
select {

View File

@ -13,6 +13,7 @@ type clientConnAttachmentState struct {
fastStreamEncode transportFastStreamEncoder
fastBulkEncode transportFastBulkEncoder
fastPlainEncode transportFastPlainEncoder
modernPSKRuntime *modernPSKCodecRuntime
handshakeRsaKey []byte
secretKey []byte
lastHeartBeat int64
@ -154,6 +155,7 @@ func (c *ClientConn) applyClientConnAttachmentProfile(maxReadTimeout time.Durati
state.maxWriteTimeout = maxWriteTimeout
state.msgEn = msgEn
state.msgDe = msgDe
state.modernPSKRuntime = nil
state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey)
state.secretKey = cloneClientConnAttachmentBytes(secretKey)
})
@ -208,6 +210,7 @@ func (c *ClientConn) setClientConnMsgEn(fn func([]byte, []byte) []byte) {
state.fastStreamEncode = nil
state.fastBulkEncode = nil
state.fastPlainEncode = nil
state.modernPSKRuntime = nil
})
}
@ -224,6 +227,7 @@ func (c *ClientConn) setClientConnMsgDe(fn func([]byte, []byte) []byte) {
state.fastStreamEncode = nil
state.fastBulkEncode = nil
state.fastPlainEncode = nil
state.modernPSKRuntime = nil
})
}
@ -286,6 +290,64 @@ func (c *ClientConn) setClientConnSecretKey(key []byte) {
})
}
// attachmentStateRaw returns the attachment state shared between this
// logical conn and its compat ClientConn view, mirroring the pointer into
// whichever side is missing it. Returns nil when neither side has state yet.
func (c *LogicalConn) attachmentStateRaw() *clientConnAttachmentState {
	if c == nil {
		return nil
	}
	// Fast path: this side already holds the state; re-store it on the compat
	// client so both sides observe the same pointer.
	if state := c.attachment.Load(); state != nil {
		if client := c.compatClientConn(); client != nil {
			client.attachment.Store(state)
		}
		return state
	}
	// Slow path: adopt the compat client's state via CAS. On a lost race,
	// retry so the CAS winner's pointer is the one returned and mirrored.
	if client := c.compatClientConn(); client != nil {
		if state := client.attachment.Load(); state != nil {
			if c.attachment.CompareAndSwap(nil, state) {
				client.attachment.Store(state)
				return state
			}
			return c.attachmentStateRaw()
		}
	}
	return nil
}
// modernPSKRuntimeSnapshot returns the modern PSK codec runtime from the
// shared attachment state, or nil when no state is attached.
func (c *LogicalConn) modernPSKRuntimeSnapshot() *modernPSKCodecRuntime {
	state := c.attachmentStateRaw()
	if state == nil {
		return nil
	}
	return state.modernPSKRuntime
}
// setModernPSKRuntime stores the modern PSK codec runtime on the shared
// attachment state.
func (c *LogicalConn) setModernPSKRuntime(runtime *modernPSKCodecRuntime) {
	c.updateAttachmentState(func(state *clientConnAttachmentState) {
		state.modernPSKRuntime = runtime
	})
}
// clientConnAttachmentStateRaw returns the attachment state, preferring the
// logical view's shared state when this conn is bound to one; nil receiver
// returns nil.
func (c *ClientConn) clientConnAttachmentStateRaw() *clientConnAttachmentState {
	if c == nil {
		return nil
	}
	logical := c.logicalView.Load()
	if logical != nil {
		return logical.attachmentStateRaw()
	}
	return c.attachment.Load()
}
// clientConnModernPSKRuntimeSnapshot returns the modern PSK codec runtime
// from the attachment state, or nil when none is attached.
func (c *ClientConn) clientConnModernPSKRuntimeSnapshot() *modernPSKCodecRuntime {
	state := c.clientConnAttachmentStateRaw()
	if state == nil {
		return nil
	}
	return state.modernPSKRuntime
}
// setClientConnModernPSKRuntime stores the modern PSK codec runtime on the
// conn's attachment state.
func (c *ClientConn) setClientConnModernPSKRuntime(runtime *modernPSKCodecRuntime) {
	c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
		state.modernPSKRuntime = runtime
	})
}
func (c *ClientConn) clientConnLastHeartbeatUnixSnapshot() int64 {
if c == nil {
return 0

View File

@ -1,6 +1,7 @@
package notify
import (
"b612.me/stario"
"context"
"net"
"os"
@ -15,6 +16,10 @@ type serverInboundSourcePusher interface {
pushMessageSource([]byte, interface{})
}
// serverInboundSourceFastPusher is implemented by servers that can accept a
// pooled transport payload together with its release callback and an inbound
// source; it reports whether the payload was accepted.
type serverInboundSourceFastPusher interface {
	pushTransportPayloadSourceFast([]byte, func(), interface{}) bool
}
func (c *LogicalConn) readTUMessage() {
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil {
@ -37,6 +42,23 @@ func (c *LogicalConn) readTUMessageLoop(rt *clientConnSessionRuntime) {
conn := rt.tuConn
generation := rt.transportGeneration
defer closeClientConnSessionRuntimeTransportDone(rt)
if conn != nil && !isPacketTransportConn(conn) {
reader := newTransportFrameReader(conn, nil)
for {
select {
case <-sessionStopChan(stopCtx):
if c.shouldCloseTransportOnStop(conn) {
_ = conn.Close()
}
return
default:
}
payload, release, err := c.readTUTransportPayloadPooled(conn, reader)
if !c.handleTUTransportPayloadReadResultWithSessionPooled(stopCtx, conn, generation, payload, release, err) {
return
}
}
}
buf := streamReadBuffer()
for {
select {
@ -54,6 +76,55 @@ func (c *LogicalConn) readTUMessageLoop(rt *clientConnSessionRuntime) {
}
}
// readTUTransportPayloadPooled arms the configured read deadline (when
// positive) and pulls the next pooled frame from reader. The caller must
// invoke the returned release func once the payload is no longer referenced.
func (c *LogicalConn) readTUTransportPayloadPooled(conn net.Conn, reader *stario.FrameReader) ([]byte, func(), error) {
	if reader == nil || conn == nil {
		return nil, nil, net.ErrClosed
	}
	timeout := c.clientConnMaxReadTimeoutSnapshot()
	if timeout > 0 {
		_ = conn.SetReadDeadline(time.Now().Add(timeout))
	}
	return reader.NextPooled()
}
// handleTUTransportPayloadReadResultWithSessionPooled processes one pooled
// frame read from the server-owned transport. It returns false when the read
// loop must stop (session stopped, read ownership lost, or a fatal read
// error). release, when non-nil, is invoked on every path that does not hand
// the payload off.
//
// Fix: read-deadline errors from net.Conn usually arrive wrapped (e.g. in
// *net.OpError) and never compare equal to the bare os.ErrDeadlineExceeded
// sentinel, which previously sent idle timeouts down the fatal-error path.
// Any net.Error with Timeout() is now treated as an idle poll, and the pooled
// buffer is released on that path as well.
func (c *LogicalConn) handleTUTransportPayloadReadResultWithSessionPooled(stopCtx context.Context, conn net.Conn, generation uint64, payload []byte, release func(), err error) bool {
	if transportReadShouldStop(stopCtx) || !c.ownsTransportRead(conn, generation) {
		if release != nil {
			release()
		}
		if c.shouldCloseTransportOnStop(conn) {
			_ = conn.Close()
		}
		return false
	}
	if err == os.ErrDeadlineExceeded {
		if release != nil {
			release()
		}
		return true
	}
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		// Wrapped deadline error: keep the loop alive and re-arm next read.
		if release != nil {
			release()
		}
		return true
	}
	if err != nil {
		if release != nil {
			release()
		}
		select {
		case <-sessionStopChan(stopCtx):
			if c.shouldCloseTransportOnStop(conn) {
				_ = conn.Close()
			}
			return false
		default:
		}
		// Preserve the logical peer across transport loss when supported.
		if detacher, ok := c.Server().(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() {
			detacher.detachLogicalSessionTransport(c, "read error", err)
			return false
		}
		c.stopServerOwnedSession("read error", err)
		return false
	}
	// Success: ownership of payload (and release) transfers to the pusher.
	c.pushServerOwnedTransportPayload(payload, release, conn, generation)
	return true
}
func (c *LogicalConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byte) (int, []byte, error) {
if len(data) == 0 {
data = streamReadBuffer()
@ -140,6 +211,29 @@ func (c *LogicalConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn
server.pushMessage(data, c.clientConnIDSnapshot())
}
// pushServerOwnedTransportPayload hands a pooled transport payload to the
// owning server's fast inbound path. On every path where the payload is not
// handed off, release (when non-nil) is invoked so the buffer can return to
// its pool.
func (c *LogicalConn) pushServerOwnedTransportPayload(payload []byte, release func(), conn net.Conn, generation uint64) {
	drop := func() {
		if release != nil {
			release()
		}
	}
	if c == nil || len(payload) == 0 {
		drop()
		return
	}
	server := c.Server()
	if server == nil {
		drop()
		return
	}
	pusher, ok := server.(serverInboundSourceFastPusher)
	if !ok {
		drop()
		return
	}
	pusher.pushTransportPayloadSourceFast(payload, release, newServerInboundSource(c, conn, nil, generation))
}
func (c *LogicalConn) shouldCloseTransportOnStop(conn net.Conn) bool {
if c == nil || conn == nil {
return false
@ -185,10 +279,65 @@ func (c *ClientConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byt
return num, data, err
}
// readTUTransportPayloadPooled reads the next pooled frame for this conn,
// delegating to the logical view when one is attached; otherwise it arms the
// configured read deadline (when positive) and reads directly.
func (c *ClientConn) readTUTransportPayloadPooled(conn net.Conn, reader *stario.FrameReader) ([]byte, func(), error) {
	if logical := c.LogicalConn(); logical != nil {
		return logical.readTUTransportPayloadPooled(conn, reader)
	}
	if reader == nil || conn == nil {
		return nil, nil, net.ErrClosed
	}
	timeout := c.clientConnMaxReadTimeoutSnapshot()
	if timeout > 0 {
		_ = conn.SetReadDeadline(time.Now().Add(timeout))
	}
	return reader.NextPooled()
}
// handleTUTransportReadResult adapts a raw read result to the session-aware
// handler using the conn's current transport snapshot state.
func (c *ClientConn) handleTUTransportReadResult(num int, data []byte, err error) bool {
	return c.handleTUTransportReadResultWithSession(c.clientConnTransportStopContextSnapshot(), c.clientConnTransportSnapshot(), c.clientConnTransportGenerationSnapshot(), num, data, err)
}
// handleTUTransportPayloadReadResultWithSessionPooled processes one pooled
// frame read for this conn, delegating to the logical view when attached.
// It returns false when the read loop must stop. release, when non-nil, is
// invoked on every path that does not hand the payload off.
//
// Fix: read-deadline errors from net.Conn usually arrive wrapped (e.g. in
// *net.OpError) and never compare equal to the bare os.ErrDeadlineExceeded
// sentinel, which previously sent idle timeouts down the fatal-error path.
// Any net.Error with Timeout() is now treated as an idle poll, and the pooled
// buffer is released on that path as well.
func (c *ClientConn) handleTUTransportPayloadReadResultWithSessionPooled(stopCtx context.Context, conn net.Conn, generation uint64, payload []byte, release func(), err error) bool {
	if logical := c.LogicalConn(); logical != nil {
		return logical.handleTUTransportPayloadReadResultWithSessionPooled(stopCtx, conn, generation, payload, release, err)
	}
	if transportReadShouldStop(stopCtx) || !c.ownsTransportRead(conn, generation) {
		if release != nil {
			release()
		}
		if c.shouldCloseClientConnTransportOnStop(conn) {
			_ = conn.Close()
		}
		return false
	}
	if err == os.ErrDeadlineExceeded {
		if release != nil {
			release()
		}
		return true
	}
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		// Wrapped deadline error: keep the loop alive and re-arm next read.
		if release != nil {
			release()
		}
		return true
	}
	if err != nil {
		if release != nil {
			release()
		}
		select {
		case <-sessionStopChan(stopCtx):
			if c.shouldCloseClientConnTransportOnStop(conn) {
				_ = conn.Close()
			}
			return false
		default:
		}
		// Preserve the logical peer across transport loss when supported.
		if detacher, ok := c.server.(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() {
			detacher.detachLogicalSessionTransport(logicalConnFromClient(c), "read error", err)
			return false
		}
		c.stopServerOwnedSession("read error", err)
		return false
	}
	// Success: ownership of payload (and release) transfers to the pusher.
	c.pushServerOwnedTransportPayload(payload, release, conn, generation)
	return true
}
func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool {
if logical := c.LogicalConn(); logical != nil {
return logical.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err)
@ -255,6 +404,26 @@ func (c *ClientConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn,
c.server.pushMessage(data, c.clientConnIDSnapshot())
}
// pushServerOwnedTransportPayload forwards a pooled transport payload to the
// server's fast inbound path, delegating to the logical view when attached.
// On every path where the payload is not handed off, release (when non-nil)
// is invoked so the buffer can return to its pool.
func (c *ClientConn) pushServerOwnedTransportPayload(payload []byte, release func(), conn net.Conn, generation uint64) {
	if logical := c.LogicalConn(); logical != nil {
		logical.pushServerOwnedTransportPayload(payload, release, conn, generation)
		return
	}
	drop := func() {
		if release != nil {
			release()
		}
	}
	if c == nil || c.server == nil || len(payload) == 0 {
		drop()
		return
	}
	pusher, ok := c.server.(serverInboundSourceFastPusher)
	if !ok {
		drop()
		return
	}
	pusher.pushTransportPayloadSourceFast(payload, release, newServerInboundSource(logicalConnFromClient(c), conn, nil, generation))
}
func (c *ClientConn) shouldCloseClientConnTransportOnStop(conn net.Conn) bool {
if logical := c.LogicalConn(); logical != nil {
return logical.shouldCloseTransportOnStop(conn)

View File

@ -433,6 +433,21 @@ func (c *ClientCommon) readMessageLoop(stopCtx context.Context, conn net.Conn, q
}
binding := newTransportBinding(conn, queue)
dispatcher := c.clientInboundDispatcherSnapshot()
if conn != nil && queue != nil && !isPacketTransportConn(conn) {
reader := newTransportFrameReader(conn, queue)
for {
select {
case <-stopCtx.Done():
c.closeClientTransportBinding(binding)
return
default:
}
payload, release, err := c.readTransportPayloadPooled(conn, reader)
if !c.handleTransportPayloadReadResultWithSession(stopCtx, binding, payload, release, err, epoch, dispatcher) {
return
}
}
}
buf := streamReadBuffer()
for {
select {

View File

@ -320,7 +320,7 @@ func TestSetClientSessionRuntimeStopsOldBindingWorkersOnReattach(t *testing.T) {
defer oldLeft.Close()
defer oldRight.Close()
oldBinding := newTransportBinding(oldLeft, queue)
oldSender := oldBinding.bulkBatchSenderSnapshot()
oldSender := oldBinding.clientBulkBatchSenderSnapshot(client)
client.setClientSessionRuntime(&clientSessionRuntime{
transport: oldBinding,
@ -345,7 +345,7 @@ func TestSetClientSessionRuntimeStopsOldBindingWorkersOnReattach(t *testing.T) {
epoch: 2,
})
err := oldSender.submit(context.Background(), []byte("payload"))
err := oldSender.submitData(context.Background(), 1, 1, bulkFastPathVersionV1, []byte("payload"))
if err != errTransportDetached {
t.Fatalf("old sender submit after reattach = %v, want %v", err, errTransportDetached)
}

View File

@ -35,6 +35,11 @@ func (c *ClientCommon) OpenStream(ctx context.Context, opt StreamOpenOptions) (S
if resp.DataID != 0 {
req.DataID = resp.DataID
}
if resp.FastPathVersion != 0 {
req.FastPathVersion = resp.FastPathVersion
} else {
req.FastPathVersion = streamFastPathVersionV1
}
req.Metadata = mergeStreamMetadata(req.Metadata, resp.Metadata)
stream := newStreamHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientStreamCloseSender(c), clientStreamResetSender(c), clientStreamDataSender(c, c.currentClientSessionEpoch()), runtime.configSnapshot())
stream.setClientSnapshotOwner(c)
@ -66,11 +71,12 @@ func clientStreamRequest(runtime *streamRuntime, opt StreamOpenOptions) StreamOp
id = runtime.nextID()
}
return normalizeStreamOpenRequest(StreamOpenRequest{
StreamID: id,
Channel: opt.Channel,
Metadata: cloneStreamMetadata(opt.Metadata),
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
StreamID: id,
FastPathVersion: streamFastPathVersionCurrent,
Channel: opt.Channel,
Metadata: cloneStreamMetadata(opt.Metadata),
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
})
}
@ -110,7 +116,7 @@ func clientStreamDataSender(c *ClientCommon, epoch uint64) streamDataSender {
}
}
if dataID := stream.dataIDSnapshot(); dataID != 0 {
return c.sendFastStreamData(dataID, stream.nextOutboundDataSeq(), chunk)
return c.sendFastStreamData(ctx, stream, chunk)
}
return c.sendEnvelope(newStreamDataEnvelope(stream.ID(), chunk))
}

View File

@ -130,24 +130,85 @@ func (c *ClientCommon) handleTransportReadResultWithSessionDispatcher(stopCtx co
return true
}
// readTransportPayloadPooled reads the next framed payload from conn via the
// pooled frame reader, arming the configured read deadline first. The
// returned release func, when non-nil, must be called once the payload is no
// longer referenced.
//
// Fix: compare the timeout duration directly (`> 0`) instead of the float
// `Seconds() != 0`; this matches the `timeout > 0` checks in the sibling
// server-side read paths and avoids arming a deadline for nonsensical
// negative values.
func (c *ClientCommon) readTransportPayloadPooled(conn net.Conn, reader *stario.FrameReader) ([]byte, func(), error) {
	if reader == nil || conn == nil {
		return nil, nil, net.ErrClosed
	}
	if c.maxReadTimeout > 0 {
		_ = conn.SetReadDeadline(time.Now().Add(c.maxReadTimeout))
	}
	return reader.NextPooled()
}
// handleTransportPayloadReadResultWithSession processes one pooled frame read
// on the client transport. It returns false when the read loop must stop.
// release, when non-nil, is invoked on every path that does not hand the
// payload off.
//
// Fix: read-deadline errors from net.Conn usually arrive wrapped (e.g. in
// *net.OpError) and never compare equal to the bare os.ErrDeadlineExceeded
// sentinel, which previously sent idle timeouts down the fatal-error path.
// Any net.Error with Timeout() is now treated as an idle poll, and the pooled
// buffer is released on that path as well.
func (c *ClientCommon) handleTransportPayloadReadResultWithSession(stopCtx context.Context, binding *transportBinding, payload []byte, release func(), err error, epoch uint64, dispatcher *inboundDispatcher) bool {
	if err == os.ErrDeadlineExceeded {
		if release != nil {
			release()
		}
		return true
	}
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		// Wrapped deadline error: keep the loop alive and re-arm next read.
		if release != nil {
			release()
		}
		return true
	}
	if err != nil {
		if release != nil {
			release()
		}
		if c.showError || c.debugMode {
			fmt.Println("client read error", err)
		}
		select {
		case <-sessionStopChan(stopCtx):
			c.closeClientTransportBinding(binding)
			return false
		default:
		}
		c.stopClientSessionIfCurrent(epoch, "client read error", err)
		return false
	}
	// Success: ownership of payload (and release) transfers to dispatch.
	c.dispatchTransportPayloadFast(payload, release, dispatcher)
	return true
}
// dispatchTransportPayloadFast routes one pooled inbound payload. Bulk fast
// frames may borrow the pooled buffer synchronously; everything else is
// copied so the pooled buffer can be released before (possibly asynchronous)
// dispatch. release, when non-nil, is always invoked exactly once.
func (c *ClientCommon) dispatchTransportPayloadFast(payload []byte, release func(), dispatcher *inboundDispatcher) {
	if len(payload) == 0 {
		if release != nil {
			release()
		}
		return
	}
	// Fast bulk path: the handler consumed the borrowed buffer synchronously.
	if c.tryDispatchBorrowedBulkTransportPayload(payload) {
		if release != nil {
			release()
		}
		return
	}
	// Copy before releasing: the pooled buffer must not outlive release().
	owned := append([]byte(nil), payload...)
	if release != nil {
		release()
	}
	if dispatcher == nil {
		// No dispatcher: decode inline on the read goroutine.
		now := time.Now()
		if err := c.dispatchInboundTransportPayload(owned, now); err != nil && (c.showError || c.debugMode) {
			fmt.Println("client decode envelope error", err)
		}
		return
	}
	c.wg.Add(1)
	if !dispatcher.Dispatch(clientInboundDispatchSource(), func() {
		defer c.wg.Done()
		now := time.Now()
		if err := c.dispatchInboundTransportPayload(owned, now); err != nil && (c.showError || c.debugMode) {
			fmt.Println("client decode envelope error", err)
		}
	}) {
		// Dispatch rejected the task; balance the WaitGroup add above.
		c.wg.Done()
	}
}
func (c *ClientCommon) pushMessageFast(queue *stario.StarQueue, data []byte, dispatcher *inboundDispatcher) bool {
if queue == nil || dispatcher == nil || len(data) == 0 {
return false
}
if err := queue.ParseMessageOwned(data, "b612", func(msg stario.MsgQueue) error {
payload := msg.Msg
c.wg.Add(1)
if !dispatcher.Dispatch(clientInboundDispatchSource(), func() {
defer c.wg.Done()
now := time.Now()
if err := c.dispatchInboundTransportPayload(payload, now); err != nil {
if c.showError || c.debugMode {
fmt.Println("client decode envelope error", err)
}
}
}) {
c.wg.Done()
}
if err := queue.ParseMessageView(data, "b612", func(frame stario.FrameView) error {
c.dispatchTransportPayloadFast(frame.Payload, nil, dispatcher)
return nil
}); err != nil && (c.showError || c.debugMode) {
fmt.Println("client parse inbound frame error", err)

View File

@ -18,6 +18,14 @@ type Client interface {
SetStreamConfig(StreamConfig)
SetTransferResumeStore(TransferResumeStore)
RecoverTransferSnapshots(context.Context) error
SetBulkNetworkProfile(BulkNetworkProfile)
BulkNetworkProfile() BulkNetworkProfile
SetBulkDefaultOpenMode(BulkOpenMode)
BulkDefaultOpenMode() BulkOpenMode
SetBulkOpenTuning(BulkOpenTuning)
BulkOpenTuning() BulkOpenTuning
SetBulkDedicatedAttachConfig(BulkDedicatedAttachConfig)
BulkDedicatedAttachConfig() BulkDedicatedAttachConfig
SetFileReceiveDir(dir string) error
send(msg TransferMsg) (WaitMsg, error)
sendEnvelope(env Envelope) error
@ -77,6 +85,8 @@ type Client interface {
OpenStream(ctx context.Context, opt StreamOpenOptions) (Stream, error)
OpenRecordStream(ctx context.Context, opt RecordOpenOptions) (RecordStream, error)
OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error)
OpenSharedBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error)
OpenDedicatedBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error)
SendTransfer(ctx context.Context, opt TransferSendOptions) (TransferHandle, error)
SendFile(ctx context.Context, filePath string) error
}

View File

@ -126,6 +126,8 @@ func init() {
RegisterName("b612.me/notify.BulkCloseResponse", BulkCloseResponse{})
RegisterName("b612.me/notify.BulkResetRequest", BulkResetRequest{})
RegisterName("b612.me/notify.BulkResetResponse", BulkResetResponse{})
RegisterName("b612.me/notify.BulkReadyRequest", BulkReadyRequest{})
RegisterName("b612.me/notify.BulkReadyResponse", BulkReadyResponse{})
RegisterName("b612.me/notify.BulkReleaseRequest", BulkReleaseRequest{})
RegisterName("b612.me/notify.bulkAttachRequest", bulkAttachRequest{})
RegisterName("b612.me/notify.bulkAttachResponse", bulkAttachResponse{})

View File

@ -59,6 +59,11 @@ type DiagnosticsSummary struct {
LogicalCount int
CurrentTransportCount int
BulkAttachAttempts int64
BulkAttachRetries int64
BulkAttachSuccess int64
BulkAutoFallbacks int64
StreamCount int
ActiveStreamCount int
StaleStreamCount int
@ -236,7 +241,11 @@ func serverCurrentTransportRuntimeSnapshots(s Server) ([]TransportConnRuntimeSna
func summarizeClientDiagnosticsSnapshot(snapshot ClientDiagnosticsSnapshot) DiagnosticsSummary {
summary := DiagnosticsSummary{
LogicalCount: diagnosticsLogicalCountFromClientRuntime(snapshot.Runtime),
LogicalCount: diagnosticsLogicalCountFromClientRuntime(snapshot.Runtime),
BulkAttachAttempts: snapshot.Runtime.BulkAttachAttempts,
BulkAttachRetries: snapshot.Runtime.BulkAttachRetries,
BulkAttachSuccess: snapshot.Runtime.BulkAttachSuccess,
BulkAutoFallbacks: snapshot.Runtime.BulkAutoFallbacks,
}
if snapshot.Runtime.TransportAttached {
summary.CurrentTransportCount = 1

2
go.mod
View File

@ -4,7 +4,7 @@ go 1.24.0
require (
b612.me/starcrypto v1.0.2
b612.me/stario v0.1.0
b612.me/stario v0.1.1
github.com/Microsoft/go-winio v0.6.2
)

4
go.sum
View File

@ -1,7 +1,7 @@
b612.me/starcrypto v1.0.2 h1:6f8YHNMHZPwxDSRxY2OJeMP4ExKa/cakLIO04f0gLhE=
b612.me/starcrypto v1.0.2/go.mod h1:I7oYTmQgnVPj5S5yKwoTyqkItq1HgF9XdJT/v3qs5QE=
b612.me/stario v0.1.0 h1:V1uA7fLYzgTadOXpnyPaFC3z0MAKFIM/RKXzZUDXvL4=
b612.me/stario v0.1.0/go.mod h1:7kjE69oFqNrca0P72L5+ZbTV09QGJ2N3bBY3qeFXOGc=
b612.me/stario v0.1.1 h1:WIQy5DdK2Tkk+PIRORaVb76f4KY+64UvWChXNI7hSVY=
b612.me/stario v0.1.1/go.mod h1:qMMqjaMhKmdfFn5T0oleL5L4FpFWEHsuIrT878hyo7I=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=

View File

@ -409,6 +409,7 @@ func (c *LogicalConn) applyAttachmentProfile(maxReadTimeout time.Duration, maxWr
state.fastStreamEncode = fastStreamEncode
state.fastBulkEncode = fastBulkEncode
state.fastPlainEncode = fastPlainEncode
state.modernPSKRuntime = nil
state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey)
state.secretKey = cloneClientConnAttachmentBytes(secretKey)
})
@ -553,6 +554,12 @@ func (c *LogicalConn) runtimeSnapshot() ClientConnRuntimeSnapshot {
snapshot.HasRuntimeConn = c.transportSnapshot() != nil
snapshot.HasRuntimeStopCtx = rt.stopCtx != nil
}
if binding := c.transportBindingSnapshot(); binding != nil {
snapshot.TransportBulkAdaptiveSoftPayloadBytes = binding.bulkAdaptiveSoftPayloadBytesSnapshot()
snapshot.TransportStreamAdaptiveSoftPayloadBytes = binding.streamAdaptiveSoftPayloadBytesSnapshot()
snapshot.TransportStreamAdaptiveWaitThresholdBytes = binding.streamAdaptiveWaitThresholdBytesSnapshot()
snapshot.TransportStreamAdaptiveFlushDelay = binding.streamAdaptiveFlushDelaySnapshot()
}
if detach := c.transportDetachSnapshot(); detach != nil {
snapshot.TransportDetachReason = detach.Reason
snapshot.TransportDetachKind = classifyClientConnTransportDetachReason(detach.Reason)
@ -816,6 +823,7 @@ func (c *LogicalConn) applyClientConnAttachmentProfile(maxReadTimeout time.Durat
state.maxWriteTimeout = maxWriteTimeout
state.msgEn = msgEn
state.msgDe = msgDe
state.modernPSKRuntime = nil
state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey)
state.secretKey = cloneClientConnAttachmentBytes(secretKey)
})

View File

@ -41,7 +41,7 @@ func BenchmarkRawTCPLocalhostThroughput(b *testing.B) {
func benchmarkRawTCPLocalhostThroughput(b *testing.B, payloadSize int) {
b.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
listener, err := net.Listen("tcp", benchmarkTCPListenAddr(b))
if err != nil {
b.Fatalf("net.Listen failed: %v", err)
}
@ -60,7 +60,7 @@ func benchmarkRawTCPLocalhostThroughput(b *testing.B, payloadSize int) {
acceptCh <- conn
}()
clientConn, err := net.Dial("tcp", listener.Addr().String())
clientConn, err := net.Dial("tcp", benchmarkTCPDialAddr(b, listener.Addr().String()))
if err != nil {
b.Fatalf("net.Dial failed: %v", err)
}

View File

@ -5,6 +5,7 @@ import (
"errors"
"net"
"strings"
"sync/atomic"
"testing"
"time"
)
@ -14,6 +15,29 @@ type releaseP0TestAddr string
func (a releaseP0TestAddr) Network() string { return "tcp" }
func (a releaseP0TestAddr) String() string { return string(a) }
type closeInspectConn struct {
closeFn func()
closed atomic.Bool
}
func (c *closeInspectConn) Read([]byte) (int, error) { return 0, net.ErrClosed }
func (c *closeInspectConn) Write(p []byte) (int, error) { return len(p), nil }
func (c *closeInspectConn) LocalAddr() net.Addr { return releaseP0TestAddr("local") }
func (c *closeInspectConn) RemoteAddr() net.Addr { return releaseP0TestAddr("remote") }
func (c *closeInspectConn) SetDeadline(time.Time) error { return nil }
func (c *closeInspectConn) SetReadDeadline(time.Time) error { return nil }
func (c *closeInspectConn) SetWriteDeadline(time.Time) error { return nil }
func (c *closeInspectConn) Close() error {
if c == nil {
return nil
}
if c.closed.CompareAndSwap(false, true) && c.closeFn != nil {
c.closeFn()
}
return nil
}
func TestGetLogicalConnRuntimeSnapshotWithoutCompatClient(t *testing.T) {
server := NewServer().(*ServerCommon)
logical := &LogicalConn{server: server}
@ -101,6 +125,116 @@ func TestHandleDedicatedBulkReadErrorPreservesUnderlyingCause(t *testing.T) {
}
}
// TestHandleClientDedicatedSidecarFailureMarksBulkBeforeClosingConn verifies
// teardown ordering on dedicated-sidecar failure: by the time the sidecar's
// conn is closed, the bulk handle must already carry a reset error that wraps
// errTransportDetached and mentions the original read-error cause. The
// closeInspectConn callback snapshots the bulk's reset error at Close time,
// which is what makes the mark-before-close ordering observable.
func TestHandleClientDedicatedSidecarFailureMarksBulkBeforeClosingConn(t *testing.T) {
	client := NewClient().(*ClientCommon)
	runtime := client.getBulkRuntime()
	if runtime == nil {
		t.Fatal("client bulk runtime should not be nil")
	}
	// Minimal dedicated bulk handle; all senders/callbacks are nil because
	// only reset-state bookkeeping matters here.
	bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{
		BulkID:    "sidecar-order",
		DataID:    7,
		Dedicated: true,
		Range: BulkRange{
			Length: 1,
		},
	}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
	var closeObservedErr error
	conn := &closeInspectConn{
		// Capture the reset error at the instant Close runs.
		closeFn: func() {
			closeObservedErr = bulk.resetErrSnapshot()
		},
	}
	if err := bulk.attachDedicatedConnShared(conn); err != nil {
		t.Fatalf("attachDedicatedConnShared failed: %v", err)
	}
	if err := runtime.register(clientFileScope(), bulk); err != nil {
		t.Fatalf("register bulk failed: %v", err)
	}
	sidecar := newBulkDedicatedSidecar(conn, 1)
	client.installClientDedicatedSidecar(1, sidecar)
	client.handleClientDedicatedSidecarFailure(sidecar, errors.New("boom sidecar"))
	if !errors.Is(closeObservedErr, errTransportDetached) {
		t.Fatalf("closeObservedErr = %v, want transport detached", closeObservedErr)
	}
	// The reset detail must name both the failure kind and the root cause.
	if !strings.Contains(closeObservedErr.Error(), "dedicated bulk read error") || !strings.Contains(closeObservedErr.Error(), "boom sidecar") {
		t.Fatalf("closeObservedErr detail = %q, want dedicated read error and cause", closeObservedErr.Error())
	}
}
// TestCleanupClientSessionResourcesMarksBulkBeforeClosingSidecar verifies the
// shutdown path's ordering: cleanupClientSessionResources must mark the bulk
// with errServiceShutdown BEFORE closing the dedicated sidecar conn. As in
// the sidecar-failure test, the closeInspectConn callback snapshots the reset
// error at Close time to make the ordering observable.
func TestCleanupClientSessionResourcesMarksBulkBeforeClosingSidecar(t *testing.T) {
	client := NewClient().(*ClientCommon)
	runtime := client.getBulkRuntime()
	if runtime == nil {
		t.Fatal("client bulk runtime should not be nil")
	}
	// Minimal dedicated bulk handle; senders/callbacks are irrelevant here.
	bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{
		BulkID:    "cleanup-order",
		DataID:    9,
		Dedicated: true,
		Range: BulkRange{
			Length: 1,
		},
	}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
	var closeObservedErr error
	conn := &closeInspectConn{
		// Capture the reset error at the instant Close runs.
		closeFn: func() {
			closeObservedErr = bulk.resetErrSnapshot()
		},
	}
	if err := bulk.attachDedicatedConnShared(conn); err != nil {
		t.Fatalf("attachDedicatedConnShared failed: %v", err)
	}
	if err := runtime.register(clientFileScope(), bulk); err != nil {
		t.Fatalf("register bulk failed: %v", err)
	}
	client.installClientDedicatedSidecar(1, newBulkDedicatedSidecar(conn, 1))
	client.cleanupClientSessionResources()
	if !errors.Is(closeObservedErr, errServiceShutdown) {
		t.Fatalf("closeObservedErr = %v, want %v", closeObservedErr, errServiceShutdown)
	}
}
// TestBestEffortRejectInboundDedicatedDataUsesDedicatedResetRecord verifies
// that rejecting inbound dedicated data emits a properly framed dedicated
// record on the wire: the bytes written to the conn are read back as a
// dedicated record, decrypted with the logical conn's transport codec, and
// decoded into exactly one reset item carrying the rejection reason.
func TestBestEffortRejectInboundDedicatedDataUsesDedicatedResetRecord(t *testing.T) {
	server := NewServer().(*ServerCommon)
	UseLegacySecurityServer(server)
	logical := server.bootstrapAcceptedLogical("dedicated-reject", nil, nil)
	if logical == nil {
		t.Fatal("bootstrapAcceptedLogical should return logical")
	}
	conn := newBulkAttachScriptConn(nil)
	server.bestEffortRejectInboundDedicatedData(logical, conn, 42, "unknown data id")
	// Replay what the server wrote through a fresh script conn so the record
	// reader can consume it as inbound bytes.
	recordConn := newBulkAttachScriptConn(conn.writtenBytes())
	payload, err := readBulkDedicatedRecord(recordConn)
	if err != nil {
		t.Fatalf("readBulkDedicatedRecord failed: %v", err)
	}
	plain, err := server.decryptTransportPayloadLogical(logical, payload)
	if err != nil {
		t.Fatalf("decryptTransportPayloadLogical failed: %v", err)
	}
	items, err := decodeDedicatedBulkInboundItems(42, plain)
	if err != nil {
		t.Fatalf("decodeDedicatedBulkInboundItems failed: %v", err)
	}
	if len(items) != 1 {
		t.Fatalf("decoded items = %d, want 1", len(items))
	}
	if items[0].Type != bulkFastPayloadTypeReset {
		t.Fatalf("reset item type = %d, want %d", items[0].Type, bulkFastPayloadTypeReset)
	}
	if got, want := string(items[0].Payload), "unknown data id"; got != want {
		t.Fatalf("reset payload = %q, want %q", got, want)
	}
}
func TestRegisterAcceptedLogicalWithoutCompatClient(t *testing.T) {
server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server)

View File

@ -40,6 +40,8 @@ type modernPSKTransportBundle struct {
fastPlainEncode transportFastPlainEncoder
}
var modernPSKPayloadPool sync.Pool
// ModernPSKOptions configures the modern PSK transport profile.
//
// The current profile derives a 32-byte transport key with Argon2id and uses
@ -81,6 +83,10 @@ func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) e
return err
}
transport := buildModernPSKTransportBundle(aad)
runtime, err := newModernPSKCodecRuntime(key, aad)
if err != nil {
return err
}
c.SetSecretKey(key)
c.SetMsgEn(transport.msgEn)
c.SetMsgDe(transport.msgDe)
@ -88,6 +94,7 @@ func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) e
client.fastStreamEncode = transport.fastStreamEncode
client.fastBulkEncode = transport.fastBulkEncode
client.fastPlainEncode = transport.fastPlainEncode
client.modernPSKRuntime = runtime
}
c.SetSkipExchangeKey(true)
return nil
@ -104,6 +111,10 @@ func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) e
return err
}
transport := buildModernPSKTransportBundle(aad)
runtime, err := newModernPSKCodecRuntime(key, aad)
if err != nil {
return err
}
s.SetSecretKey(key)
s.SetDefaultCommEncode(transport.msgEn)
s.SetDefaultCommDecode(transport.msgDe)
@ -111,6 +122,7 @@ func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) e
server.defaultFastStreamEncode = transport.fastStreamEncode
server.defaultFastBulkEncode = transport.fastBulkEncode
server.defaultFastPlainEncode = transport.fastPlainEncode
server.defaultModernPSKRuntime = runtime
}
return nil
}
@ -127,6 +139,7 @@ func UseLegacySecurityClient(c Client) {
client.fastStreamEncode = nil
client.fastBulkEncode = nil
client.fastPlainEncode = nil
client.modernPSKRuntime = nil
}
c.SetSkipExchangeKey(false)
c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey))
@ -144,6 +157,7 @@ func UseLegacySecurityServer(s Server) {
server.defaultFastStreamEncode = nil
server.defaultFastBulkEncode = nil
server.defaultFastPlainEncode = nil
server.defaultModernPSKRuntime = nil
}
s.SetRsaPrivKey(bytes.Clone(defaultRsaKey))
}
@ -185,14 +199,14 @@ func buildModernPSKCodecs(aad []byte) (func([]byte, []byte) []byte, func([]byte,
func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle {
aadCopy := bytes.Clone(aad)
cache := &modernPSKCodecCache{}
cache := newModernPSKCodecCache(aadCopy)
msgEn := func(key []byte, plain []byte) []byte {
runtime, err := cache.runtimeForKey(key)
if err != nil {
log.Print(err)
return nil
}
out, err := runtime.sealPlainPayload(aadCopy, plain)
out, err := runtime.sealPlainPayload(plain)
if err != nil {
log.Print(err)
return nil
@ -214,9 +228,7 @@ func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle {
log.Print(err)
return nil
}
nonce := encrypted[len(modernPSKMagic):headerLen]
ciphertext := encrypted[headerLen:]
plain, err := runtime.aead.Open(make([]byte, 0, len(ciphertext)), nonce, ciphertext, aadCopy)
plain, err := runtime.openPayload(encrypted)
if err != nil {
log.Print(err)
return nil
@ -228,21 +240,21 @@ func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle {
if err != nil {
return nil, err
}
return runtime.sealStreamFastPayload(aadCopy, dataID, seq, payload)
return runtime.sealStreamFastPayload(dataID, seq, payload)
}
fastBulkEncode := func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
runtime, err := cache.runtimeForKey(key)
if err != nil {
return nil, err
}
return runtime.sealBulkFastPayload(aadCopy, dataID, seq, payload)
return runtime.sealBulkFastPayload(dataID, seq, payload)
}
fastPlainEncode := func(key []byte, plainLen int, fill func([]byte) error) ([]byte, error) {
runtime, err := cache.runtimeForKey(key)
if err != nil {
return nil, err
}
return runtime.sealFilledPayload(aadCopy, plainLen, fill)
return runtime.sealFilledPayload(plainLen, fill)
}
return modernPSKTransportBundle{
msgEn: msgEn,
@ -269,16 +281,23 @@ func (s *ServerCommon) validateSecurityConfiguration() error {
type modernPSKCodecCache struct {
mu sync.Mutex
aad []byte
key []byte
runtime *modernPSKCodecRuntime
}
type modernPSKCodecRuntime struct {
aead cipher.AEAD
key []byte
aad []byte
prefix [modernPSKNonceSize - 8]byte
seq atomic.Uint64
}
// newModernPSKCodecCache builds an empty codec cache bound to a private copy
// of the supplied AAD, so later mutation of the caller's slice cannot leak in.
func newModernPSKCodecCache(aad []byte) *modernPSKCodecCache {
	cache := &modernPSKCodecCache{}
	cache.aad = bytes.Clone(aad)
	return cache
}
func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime, error) {
if c == nil {
return nil, errModernPSKSecretEmpty
@ -288,7 +307,7 @@ func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime,
if c.runtime != nil && bytes.Equal(c.key, key) {
return c.runtime, nil
}
runtime, err := newModernPSKCodecRuntime(key)
runtime, err := newModernPSKCodecRuntime(key, c.aad)
if err != nil {
return nil, err
}
@ -297,7 +316,7 @@ func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime,
return runtime, nil
}
func newModernPSKCodecRuntime(key []byte) (*modernPSKCodecRuntime, error) {
func newModernPSKCodecRuntime(key []byte, aad []byte) (*modernPSKCodecRuntime, error) {
if len(key) == 0 {
return nil, errModernPSKSecretEmpty
}
@ -311,6 +330,8 @@ func newModernPSKCodecRuntime(key []byte) (*modernPSKCodecRuntime, error) {
}
runtime := &modernPSKCodecRuntime{
aead: aead,
key: bytes.Clone(key),
aad: bytes.Clone(aad),
}
if _, err := cryptorand.Read(runtime.prefix[:]); err != nil {
return nil, err
@ -318,6 +339,13 @@ func newModernPSKCodecRuntime(key []byte) (*modernPSKCodecRuntime, error) {
return runtime, nil
}
// fork derives a fresh codec runtime sharing this runtime's key and AAD but
// with an independent nonce prefix/sequence. A nil receiver reports
// errModernPSKSecretEmpty.
func (r *modernPSKCodecRuntime) fork() (*modernPSKCodecRuntime, error) {
	if r != nil {
		return newModernPSKCodecRuntime(r.key, r.aad)
	}
	return nil, errModernPSKSecretEmpty
}
func (r *modernPSKCodecRuntime) nextNonce() [modernPSKNonceSize]byte {
var nonce [modernPSKNonceSize]byte
if r == nil {
@ -328,8 +356,8 @@ func (r *modernPSKCodecRuntime) nextNonce() [modernPSKNonceSize]byte {
return nonce
}
func (r *modernPSKCodecRuntime) sealStreamFastPayload(aad []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
return r.sealFilledPayload(aad, streamFastPayloadHeaderLen+len(payload), func(frame []byte) error {
func (r *modernPSKCodecRuntime) sealStreamFastPayload(dataID uint64, seq uint64, payload []byte) ([]byte, error) {
return r.sealFilledPayload(streamFastPayloadHeaderLen+len(payload), func(frame []byte) error {
if err := encodeStreamFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil {
return err
}
@ -338,11 +366,11 @@ func (r *modernPSKCodecRuntime) sealStreamFastPayload(aad []byte, dataID uint64,
})
}
func (r *modernPSKCodecRuntime) sealBulkFastPayload(aad []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
func (r *modernPSKCodecRuntime) sealBulkFastPayload(dataID uint64, seq uint64, payload []byte) ([]byte, error) {
if r == nil {
return nil, errTransportPayloadEncryptFailed
}
return r.sealFilledPayload(aad, bulkFastPayloadHeaderLen+len(payload), func(frame []byte) error {
return r.sealFilledPayload(bulkFastPayloadHeaderLen+len(payload), func(frame []byte) error {
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil {
return err
}
@ -351,14 +379,14 @@ func (r *modernPSKCodecRuntime) sealBulkFastPayload(aad []byte, dataID uint64, s
})
}
func (r *modernPSKCodecRuntime) sealPlainPayload(aad []byte, plain []byte) ([]byte, error) {
return r.sealFilledPayload(aad, len(plain), func(dst []byte) error {
// sealPlainPayload encrypts plain into a fresh modern-PSK framed payload by
// copying it into the plaintext frame that sealFilledPayload provides.
func (r *modernPSKCodecRuntime) sealPlainPayload(plain []byte) ([]byte, error) {
	fill := func(dst []byte) error {
		copy(dst, plain)
		return nil
	}
	return r.sealFilledPayload(len(plain), fill)
}
func (r *modernPSKCodecRuntime) sealFilledPayload(aad []byte, plainLen int, fill func([]byte) error) ([]byte, error) {
func (r *modernPSKCodecRuntime) sealFilledPayload(plainLen int, fill func([]byte) error) ([]byte, error) {
if r == nil {
return nil, errTransportPayloadEncryptFailed
}
@ -368,6 +396,35 @@ func (r *modernPSKCodecRuntime) sealFilledPayload(aad []byte, plainLen int, fill
nonce := r.nextNonce()
headerLen := len(modernPSKMagic) + modernPSKNonceSize
out := make([]byte, headerLen+plainLen+r.aead.Overhead())
sealed, err := r.sealInto(out, headerLen, nonce, plainLen, fill)
if err != nil {
return nil, err
}
return out[:headerLen+len(sealed)], nil
}
// sealFilledPayloadPooled seals a fill-produced plaintext of plainLen bytes
// into a pooled output buffer. On success it returns the framed ciphertext
// together with a release func that returns the buffer to the pool; the
// caller must invoke release once done with the bytes. On failure the buffer
// is returned to the pool here.
func (r *modernPSKCodecRuntime) sealFilledPayloadPooled(plainLen int, fill func([]byte) error) ([]byte, func(), error) {
	if r == nil || plainLen < 0 {
		return nil, nil, errTransportPayloadEncryptFailed
	}
	nonce := r.nextNonce()
	headerLen := len(modernPSKMagic) + modernPSKNonceSize
	buf := getModernPSKPayloadBuffer(headerLen + plainLen + r.aead.Overhead())
	sealed, err := r.sealInto(buf, headerLen, nonce, plainLen, fill)
	if err != nil {
		putModernPSKPayloadBuffer(buf)
		return nil, nil, err
	}
	release := func() {
		putModernPSKPayloadBuffer(buf)
	}
	return buf[:headerLen+len(sealed)], release, nil
}
func (r *modernPSKCodecRuntime) sealInto(out []byte, headerLen int, nonce [modernPSKNonceSize]byte, plainLen int, fill func([]byte) error) ([]byte, error) {
copy(out[:len(modernPSKMagic)], modernPSKMagic)
copy(out[len(modernPSKMagic):headerLen], nonce[:])
frame := out[headerLen : headerLen+plainLen]
@ -376,6 +433,98 @@ func (r *modernPSKCodecRuntime) sealFilledPayload(aad []byte, plainLen int, fill
return nil, err
}
}
sealed := r.aead.Seal(frame[:0], nonce[:], frame, aad)
return out[:headerLen+len(sealed)], nil
return r.aead.Seal(frame[:0], nonce[:], frame, r.aad), nil
}
// openPayload authenticates and decrypts a modern-PSK framed payload
// (magic || nonce || ciphertext), returning the plaintext in a freshly
// allocated buffer. Malformed frames report errModernPSKPayload.
func (r *modernPSKCodecRuntime) openPayload(encrypted []byte) ([]byte, error) {
	if r == nil {
		return nil, errTransportPayloadDecryptFailed
	}
	headerLen := len(modernPSKMagic) + modernPSKNonceSize
	if len(encrypted) < headerLen || !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) {
		return nil, errModernPSKPayload
	}
	nonce := encrypted[len(modernPSKMagic):headerLen]
	ciphertext := encrypted[headerLen:]
	dst := make([]byte, 0, len(ciphertext))
	return r.aead.Open(dst, nonce, ciphertext, r.aad)
}
// openPayloadPooled decrypts encrypted in place, reusing its backing buffer,
// and on success forwards ownership of release to the caller. On any failure
// release is invoked here and a nil release is returned, so the caller never
// double-frees.
func (r *modernPSKCodecRuntime) openPayloadPooled(encrypted []byte, release func()) ([]byte, func(), error) {
	fail := func(cause error) ([]byte, func(), error) {
		if release != nil {
			release()
		}
		return nil, nil, cause
	}
	if r == nil {
		return fail(errTransportPayloadDecryptFailed)
	}
	headerLen := len(modernPSKMagic) + modernPSKNonceSize
	if len(encrypted) < headerLen {
		return fail(errModernPSKPayload)
	}
	if !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) {
		return fail(errModernPSKPayload)
	}
	nonce := encrypted[len(modernPSKMagic):headerLen]
	ciphertext := encrypted[headerLen:]
	// In-place decrypt: plaintext is shorter than ciphertext, so ciphertext[:0]
	// is a safe destination within the same backing array.
	plain, err := r.aead.Open(ciphertext[:0], nonce, ciphertext, r.aad)
	if err != nil {
		return fail(err)
	}
	return plain, release, nil
}
// openPayloadOwnedPooled decrypts a modern-PSK frame into a pooled scratch
// buffer owned by this call (encrypted is left untouched). On success the
// returned release func must be called to return the scratch buffer to the
// pool; on failure the buffer is recycled here.
func (r *modernPSKCodecRuntime) openPayloadOwnedPooled(encrypted []byte) ([]byte, func(), error) {
	if r == nil {
		return nil, nil, errTransportPayloadDecryptFailed
	}
	headerLen := len(modernPSKMagic) + modernPSKNonceSize
	if len(encrypted) < headerLen || !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) {
		return nil, nil, errModernPSKPayload
	}
	nonce := encrypted[len(modernPSKMagic):headerLen]
	ciphertext := encrypted[headerLen:]
	plainLen := len(ciphertext) - r.aead.Overhead()
	if plainLen < 0 {
		// Frame too short to even hold the AEAD tag.
		return nil, nil, errModernPSKPayload
	}
	scratch := getModernPSKPayloadBuffer(plainLen)
	plain, err := r.aead.Open(scratch[:0], nonce, ciphertext, r.aad)
	if err != nil {
		putModernPSKPayloadBuffer(scratch)
		return nil, nil, err
	}
	release := func() {
		putModernPSKPayloadBuffer(scratch)
	}
	return plain, release, nil
}
// getModernPSKPayloadBuffer returns a scratch buffer of exactly size bytes,
// reusing a pooled buffer when its capacity suffices. size <= 0 yields nil.
func getModernPSKPayloadBuffer(size int) []byte {
	if size <= 0 {
		return nil
	}
	pooled, ok := modernPSKPayloadPool.Get().([]byte)
	if ok && cap(pooled) >= size {
		return pooled[:size]
	}
	return make([]byte, size)
}
// putModernPSKPayloadBuffer recycles buf into the payload pool. Zero-capacity
// buffers and outliers above 32 MiB are dropped so the pool never pins huge
// allocations.
func putModernPSKPayloadBuffer(buf []byte) {
	const maxPooledCap = 32 * 1024 * 1024
	c := cap(buf)
	if c == 0 || c > maxPooledCap {
		return
	}
	modernPSKPayloadPool.Put(buf[:0])
}

View File

@ -83,6 +83,12 @@ func TestDefaultConstructorsUseModernTransportAfterSetSecretKey(t *testing.T) {
sharedKey := []byte("0123456789abcdef0123456789abcdef")
client.SetSecretKey(sharedKey)
server.SetSecretKey(sharedKey)
if client.modernPSKRuntime == nil {
t.Fatal("client modernPSKRuntime should be installed after SetSecretKey")
}
if server.defaultModernPSKRuntime == nil {
t.Fatal("server defaultModernPSKRuntime should be installed after SetSecretKey")
}
plain := []byte("notify default modern transport")
wire := client.msgEn(client.SecretKey, plain)
@ -92,6 +98,29 @@ func TestDefaultConstructorsUseModernTransportAfterSetSecretKey(t *testing.T) {
}
}
func TestCustomCodecOverridesClearModernRuntime(t *testing.T) {
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
sharedKey := []byte("0123456789abcdef0123456789abcdef")
client.SetSecretKey(sharedKey)
server.SetSecretKey(sharedKey)
if client.modernPSKRuntime == nil || server.defaultModernPSKRuntime == nil {
t.Fatal("modern runtimes should be installed before override")
}
client.SetMsgEn(defaultMsgEn)
client.SetMsgDe(defaultMsgDe)
server.SetDefaultCommEncode(defaultMsgEn)
server.SetDefaultCommDecode(defaultMsgDe)
if client.modernPSKRuntime != nil {
t.Fatal("client modernPSKRuntime should be cleared after custom codec override")
}
if server.defaultModernPSKRuntime != nil {
t.Fatal("server defaultModernPSKRuntime should be cleared after custom codec override")
}
}
func TestDefaultConstructorsDecodeSignalEnvelopeWithModernTransport(t *testing.T) {
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)

View File

@ -33,6 +33,7 @@ type ServerCommon struct {
defaultFastStreamEncode transportFastStreamEncoder
defaultFastBulkEncode transportFastBulkEncoder
defaultFastPlainEncode transportFastPlainEncoder
defaultModernPSKRuntime *modernPSKCodecRuntime
linkFns map[string]func(message *Message)
defaultFns func(message *Message)
noFinSyncMsgMaxKeepSeconds int64
@ -47,6 +48,9 @@ type ServerCommon struct {
streamRuntime *streamRuntime
recordRuntime *recordRuntime
bulkRuntime *bulkRuntime
bulkOpenTuning BulkOpenTuning
bulkDedicatedSidecarMu sync.Mutex
bulkDedicatedSidecars map[*LogicalConn]map[uint32]*bulkDedicatedSidecar
connectionRetryState *connectionRetryState
detachedClientKeepSeconds int64
securityReadyCheck bool
@ -81,6 +85,8 @@ func NewServer() Server {
server.streamRuntime = newStreamRuntime("sstrm")
server.recordRuntime = newRecordRuntime()
server.bulkRuntime = newBulkRuntime("sblk")
server.bulkOpenTuning = defaultBulkOpenTuning()
server.bulkDedicatedSidecars = make(map[*LogicalConn]map[uint32]*bulkDedicatedSidecar)
server.connectionRetryState = newConnectionRetryState()
server.onFileEvent = normalizeFileEventCallback(nil)
server.fileEventObserver = normalizeFileEventCallback(nil)

View File

@ -1,6 +1,9 @@
package notify
import "context"
import (
"context"
"errors"
)
func (s *ServerCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) {
runtime := s.getBulkRuntime()
@ -11,9 +14,48 @@ func (s *ServerCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) {
}
func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn, opt BulkOpenOptions) (Bulk, error) {
opt = normalizeBulkOpenOptions(opt)
switch opt.Mode {
case BulkOpenModeDedicated:
opt.Dedicated = true
return s.openBulkLogicalWithMode(ctx, logical, opt)
case BulkOpenModeAuto:
if err := logicalDedicatedBulkSupportError(logical); err == nil {
dedicatedOpt := opt
dedicatedOpt.Mode = BulkOpenModeDedicated
dedicatedOpt.Dedicated = true
bulk, dedicatedErr := s.openBulkLogicalWithMode(ctx, logical, dedicatedOpt)
if dedicatedErr == nil {
return bulk, nil
}
sharedOpt := opt
sharedOpt.Mode = BulkOpenModeShared
sharedOpt.Dedicated = false
sharedBulk, sharedErr := s.openBulkLogicalWithMode(ctx, logical, sharedOpt)
if sharedErr == nil {
return sharedBulk, nil
}
return nil, errors.Join(dedicatedErr, sharedErr)
}
opt.Mode = BulkOpenModeShared
opt.Dedicated = false
return s.openBulkLogicalWithMode(ctx, logical, opt)
case BulkOpenModeShared, BulkOpenModeDefault:
opt.Mode = BulkOpenModeShared
opt.Dedicated = false
return s.openBulkLogicalWithMode(ctx, logical, opt)
default:
opt.Mode = BulkOpenModeShared
opt.Dedicated = false
return s.openBulkLogicalWithMode(ctx, logical, opt)
}
}
func (s *ServerCommon) openBulkLogicalWithMode(ctx context.Context, logical *LogicalConn, opt BulkOpenOptions) (Bulk, error) {
if s == nil {
return nil, errBulkServerNil
}
opt = applyBulkOpenTuningDefaults(opt, s.bulkOpenTuningSnapshot())
if logical == nil {
return nil, errBulkLogicalConnNil
}
@ -42,17 +84,44 @@ func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn
req.AttachToken = newBulkAttachToken()
}
bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, logical.CurrentTransportConn(), logical.transportGenerationSnapshot(), serverBulkCloseSender(s, logical, nil), serverBulkResetSender(s, logical, nil), serverBulkDataSender(s, logical.CurrentTransportConn()), serverBulkWriteSender(s, logical, logical.CurrentTransportConn()), serverBulkReleaseSender(s, logical, logical.CurrentTransportConn()))
bulk.markAcceptHandled()
if err := runtime.register(scope, bulk); err != nil {
return nil, err
}
s.attachServerDedicatedSidecarIfExists(logical, bulk)
resp, err := sendBulkOpenServerLogical(ctx, s, logical, req)
if err != nil {
runtime.remove(scope, bulk.ID())
bulk.markReset(err)
return nil, err
}
if resp.DataID != 0 && resp.DataID != req.DataID {
err = errBulkAlreadyExists
_, _ = sendBulkResetServerLogical(context.Background(), s, logical, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
Error: "bulk dedicated data id mismatch",
})
bulk.markReset(err)
return nil, err
}
if resp.TransportGeneration != 0 {
bulk.transportGeneration = resp.TransportGeneration
}
if resp.FastPathVersion != 0 {
bulk.fastPathVersion = normalizeBulkFastPathVersion(resp.FastPathVersion)
}
if resp.AttachToken != "" {
bulk.setDedicatedAttachToken(resp.AttachToken)
}
if err := bulk.waitAcceptReady(ctx); err != nil {
_, _ = sendBulkResetServerLogical(context.Background(), s, logical, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
Error: err.Error(),
})
bulk.markReset(err)
return nil, err
}
return bulk, nil
}
resp, err := sendBulkOpenServerLogical(ctx, s, logical, req)
@ -62,6 +131,9 @@ func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn
if resp.DataID != 0 {
req.DataID = resp.DataID
}
if resp.FastPathVersion != 0 {
req.FastPathVersion = resp.FastPathVersion
}
req.Dedicated = resp.Dedicated
if resp.AttachToken != "" {
req.AttachToken = resp.AttachToken
@ -71,6 +143,7 @@ func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn
}
transport := logical.CurrentTransportConn()
bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverBulkCloseSender(s, logical, nil), serverBulkResetSender(s, logical, nil), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport))
bulk.markAcceptHandled()
if err := runtime.register(scope, bulk); err != nil {
_, _ = sendBulkResetServerLogical(context.Background(), s, logical, BulkResetRequest{
BulkID: req.BulkID,
@ -79,13 +152,53 @@ func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn
})
return nil, err
}
s.attachServerDedicatedSidecarIfExists(logical, bulk)
return bulk, nil
}
// OpenBulkTransport opens a bulk channel bound to a specific transport
// connection. BulkOpenModeAuto attempts dedicated first (when the transport
// supports it) and falls back to shared, joining both errors if neither
// works. Shared, default, and unrecognized modes all take the shared path.
func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *TransportConn, opt BulkOpenOptions) (Bulk, error) {
	opt = normalizeBulkOpenOptions(opt)
	switch opt.Mode {
	case BulkOpenModeDedicated:
		opt.Dedicated = true
		return s.openBulkTransportWithMode(ctx, transport, opt)
	case BulkOpenModeAuto:
		if transportDedicatedBulkSupportError(transport) != nil {
			// Transport cannot host a dedicated channel: degrade to shared.
			break
		}
		dedicatedOpt := opt
		dedicatedOpt.Mode = BulkOpenModeDedicated
		dedicatedOpt.Dedicated = true
		bulk, dedicatedErr := s.openBulkTransportWithMode(ctx, transport, dedicatedOpt)
		if dedicatedErr == nil {
			return bulk, nil
		}
		// Dedicated attach failed: retry once in shared mode, surfacing both
		// causes when that fails as well.
		sharedOpt := opt
		sharedOpt.Mode = BulkOpenModeShared
		sharedOpt.Dedicated = false
		sharedBulk, sharedErr := s.openBulkTransportWithMode(ctx, transport, sharedOpt)
		if sharedErr == nil {
			return sharedBulk, nil
		}
		return nil, errors.Join(dedicatedErr, sharedErr)
	}
	opt.Mode = BulkOpenModeShared
	opt.Dedicated = false
	return s.openBulkTransportWithMode(ctx, transport, opt)
}
func (s *ServerCommon) openBulkTransportWithMode(ctx context.Context, transport *TransportConn, opt BulkOpenOptions) (Bulk, error) {
if s == nil {
return nil, errBulkServerNil
}
opt = applyBulkOpenTuningDefaults(opt, s.bulkOpenTuningSnapshot())
if transport == nil {
return nil, errBulkTransportNil
}
@ -118,17 +231,44 @@ func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *Transpo
req.AttachToken = newBulkAttachToken()
}
bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, transport.TransportGeneration(), serverBulkCloseSender(s, logical, transport), serverBulkResetSender(s, logical, transport), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport))
bulk.markAcceptHandled()
if err := runtime.register(scope, bulk); err != nil {
return nil, err
}
s.attachServerDedicatedSidecarIfExists(logical, bulk)
resp, err := sendBulkOpenServerTransport(ctx, s, transport, req)
if err != nil {
runtime.remove(scope, bulk.ID())
bulk.markReset(err)
return nil, err
}
if resp.DataID != 0 && resp.DataID != req.DataID {
err = errBulkAlreadyExists
_, _ = sendBulkResetServerTransport(context.Background(), s, transport, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
Error: "bulk dedicated data id mismatch",
})
bulk.markReset(err)
return nil, err
}
if resp.TransportGeneration != 0 {
bulk.transportGeneration = resp.TransportGeneration
}
if resp.FastPathVersion != 0 {
bulk.fastPathVersion = normalizeBulkFastPathVersion(resp.FastPathVersion)
}
if resp.AttachToken != "" {
bulk.setDedicatedAttachToken(resp.AttachToken)
}
if err := bulk.waitAcceptReady(ctx); err != nil {
_, _ = sendBulkResetServerTransport(context.Background(), s, transport, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
Error: err.Error(),
})
bulk.markReset(err)
return nil, err
}
return bulk, nil
}
resp, err := sendBulkOpenServerTransport(ctx, s, transport, req)
@ -138,6 +278,9 @@ func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *Transpo
if resp.DataID != 0 {
req.DataID = resp.DataID
}
if resp.FastPathVersion != 0 {
req.FastPathVersion = resp.FastPathVersion
}
req.Dedicated = resp.Dedicated
if resp.AttachToken != "" {
req.AttachToken = resp.AttachToken
@ -146,6 +289,7 @@ func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *Transpo
return nil, errBulkDataIDEmpty
}
bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverBulkCloseSender(s, logical, transport), serverBulkResetSender(s, logical, transport), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport))
bulk.markAcceptHandled()
if err := runtime.register(scope, bulk); err != nil {
_, _ = sendBulkResetServerTransport(context.Background(), s, transport, BulkResetRequest{
BulkID: req.BulkID,
@ -154,6 +298,7 @@ func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *Transpo
})
return nil, err
}
s.attachServerDedicatedSidecarIfExists(logical, bulk)
return bulk, nil
}
@ -164,15 +309,16 @@ func serverBulkRequest(runtime *bulkRuntime, opt BulkOpenOptions) BulkOpenReques
id = runtime.nextID()
}
return normalizeBulkOpenRequest(BulkOpenRequest{
BulkID: id,
Range: opt.Range,
Metadata: cloneBulkMetadata(opt.Metadata),
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
Dedicated: opt.Dedicated,
ChunkSize: opt.ChunkSize,
WindowBytes: opt.WindowBytes,
MaxInFlight: opt.MaxInFlight,
BulkID: id,
FastPathVersion: bulkFastPathVersionCurrent,
Range: opt.Range,
Metadata: cloneBulkMetadata(opt.Metadata),
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
Dedicated: opt.Dedicated,
ChunkSize: opt.ChunkSize,
WindowBytes: opt.WindowBytes,
MaxInFlight: opt.MaxInFlight,
})
}
@ -247,12 +393,12 @@ func serverBulkDataSender(s *ServerCommon, transport *TransportConn) bulkDataSen
if dataID == 0 {
return errBulkDataPathNotReady
}
return s.sendFastBulkDataTransport(ctx, bulk.LogicalConn(), transport, dataID, bulk.nextOutboundDataSeq(), chunk)
return s.sendFastBulkDataTransport(ctx, bulk.LogicalConn(), transport, dataID, bulk.nextOutboundDataSeq(), chunk, bulk.fastPathVersionSnapshot())
}
}
func serverBulkWriteSender(s *ServerCommon, logical *LogicalConn, transport *TransportConn) bulkWriteSender {
return func(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) {
return func(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) {
if s == nil {
return 0, errBulkServerNil
}
@ -267,7 +413,7 @@ func serverBulkWriteSender(s *ServerCommon, logical *LogicalConn, transport *Tra
if err := bulk.waitDedicatedReady(ctx); err != nil {
return 0, err
}
return s.sendDedicatedBulkWrite(ctx, bulk.LogicalConn(), bulk, payload)
return s.sendDedicatedBulkWrite(ctx, bulk.LogicalConn(), bulk, startSeq, payload)
}
if transport == nil {
return 0, errBulkTransportNil
@ -275,7 +421,14 @@ func serverBulkWriteSender(s *ServerCommon, logical *LogicalConn, transport *Tra
if !transport.IsCurrent() {
return 0, errTransportDetached
}
return 0, nil
if bulk == nil {
return 0, errBulkRuntimeNil
}
dataID := bulk.dataIDSnapshot()
if dataID == 0 {
return 0, errBulkDataPathNotReady
}
return s.sendFastBulkWriteTransport(ctx, bulk.LogicalConn(), transport, dataID, startSeq, bulk.chunkSize, bulk.fastPathVersionSnapshot(), payload, payloadOwned)
}
}
@ -287,8 +440,20 @@ func serverBulkReleaseSender(s *ServerCommon, logical *LogicalConn, transport *T
if bytes <= 0 && chunks <= 0 {
return nil
}
ctx, cancel, err := bulk.newWriteContext(bulk.Context(), bulk.writeTimeout)
if err != nil {
return err
}
defer cancel()
if bulk.Dedicated() {
return s.sendDedicatedBulkRelease(context.Background(), logical, bulk, bytes, chunks)
return s.sendDedicatedBulkRelease(ctx, logical, bulk, bytes, chunks)
}
if transport != nil && transport.IsCurrent() && bulk.fastPathVersionSnapshot() >= bulkFastPathVersionV2 {
payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks)
if err != nil {
return err
}
return s.sendFastBulkControlTransport(ctx, logical, transport, bulkFastPayloadTypeRelease, 0, bulk.dataIDSnapshot(), 0, bulk.fastPathVersionSnapshot(), payload)
}
req := BulkReleaseRequest{
BulkID: bulk.ID(),

View File

@ -35,6 +35,7 @@ func (s *ServerCommon) SetDefaultCommEncode(fn func([]byte, []byte) []byte) {
s.defaultFastStreamEncode = nil
s.defaultFastBulkEncode = nil
s.defaultFastPlainEncode = nil
s.defaultModernPSKRuntime = nil
s.securityReadyCheck = false
}
@ -45,6 +46,7 @@ func (s *ServerCommon) SetDefaultCommDecode(fn func([]byte, []byte) []byte) {
s.defaultFastStreamEncode = nil
s.defaultFastBulkEncode = nil
s.defaultFastPlainEncode = nil
s.defaultModernPSKRuntime = nil
s.securityReadyCheck = false
}
@ -97,6 +99,13 @@ func (s *ServerCommon) GetSecretKey() []byte {
// Prefer UseModernPSKServer or UseLegacySecurityServer.
// SetSecretKey installs the legacy pre-shared secret and refreshes the
// derived modern-PSK codec runtime. The runtime is cleared first and only
// re-established when a non-empty key successfully builds a codec; any
// construction error leaves it nil. securityReadyCheck is true only for an
// empty key.
func (s *ServerCommon) SetSecretKey(key []byte) {
	s.SecretKey = key
	s.defaultModernPSKRuntime = nil
	if len(key) > 0 {
		if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil {
			s.defaultModernPSKRuntime = runtime
		}
	}
	s.securityReadyCheck = len(key) == 0
}

View File

@ -59,26 +59,8 @@ func (s *ServerCommon) pushMessageSourceFast(queue *stario.StarQueue, data []byt
if queue == nil || dispatcher == nil || len(data) == 0 {
return false
}
if err := queue.ParseMessageOwned(data, source, func(msg stario.MsgQueue) error {
payload := msg.Msg
source := msg.Conn
s.wg.Add(1)
if !dispatcher.Dispatch(serverInboundDispatchSource(source), func() {
defer s.wg.Done()
logical, transport := s.resolveInboundSource(source)
if logical == nil {
return
}
now := time.Now()
inboundConn := serverInboundConn(source)
if err := s.dispatchInboundTransportPayload(logical, transport, inboundConn, payload, now); err != nil {
if s.showError || s.debugMode {
fmt.Println("server decode envelope error", err)
}
}
}) {
s.wg.Done()
}
if err := queue.ParseMessageView(data, source, func(frame stario.FrameView) error {
s.pushTransportPayloadSourceFast(frame.Payload, nil, frame.Conn)
return nil
}); err != nil && (s.showError || s.debugMode) {
fmt.Println("server parse inbound frame error", err)
@ -86,6 +68,59 @@ func (s *ServerCommon) pushMessageSourceFast(queue *stario.StarQueue, data []byt
return true
}
// pushTransportPayloadSourceFast routes one raw transport payload from the
// fast read path toward the inbound dispatcher.
//
// Ownership contract: the payload is only guaranteed valid until release is
// invoked (it appears to be borrowed from the reader's buffer — note the copy
// below before async use), so every exit path calls release (when non-nil)
// exactly once. Reports true when the payload was consumed (dispatched,
// enqueued, or taken by the bulk fast path) and false when it was dropped.
func (s *ServerCommon) pushTransportPayloadSourceFast(payload []byte, release func(), source interface{}) bool {
	dispatcher := s.serverInboundDispatcherSnapshot()
	if len(payload) == 0 {
		// Nothing to deliver; still honor the release contract.
		if release != nil {
			release()
		}
		return false
	}
	if dispatcher == nil {
		// No dispatcher: fall back to re-framing the payload onto the queue
		// path, if a queue exists.
		queue := s.serverQueueSnapshot()
		if queue == nil {
			if release != nil {
				release()
			}
			return false
		}
		frame := queue.BuildMessage(payload)
		// frame is built from payload before release, so the borrowed
		// buffer can be returned prior to parsing.
		if release != nil {
			release()
		}
		if err := queue.ParseMessage(frame, source); err != nil && (s.showError || s.debugMode) {
			fmt.Println("server enqueue inbound frame error", err)
		}
		return true
	}
	// The bulk fast path may consume the borrowed payload synchronously,
	// avoiding the defensive copy below.
	if s.tryDispatchBorrowedBulkTransportPayload(source, payload) {
		if release != nil {
			release()
		}
		return true
	}
	// The dispatcher runs the closure asynchronously, so take an owned copy
	// before releasing the borrowed buffer.
	owned := append([]byte(nil), payload...)
	if release != nil {
		release()
	}
	s.wg.Add(1)
	if !dispatcher.Dispatch(serverInboundDispatchSource(source), func() {
		defer s.wg.Done()
		logical, transport := s.resolveInboundSource(source)
		if logical == nil {
			// Unknown source: silently drop, matching the queue path.
			return
		}
		now := time.Now()
		inboundConn := serverInboundConn(source)
		if err := s.dispatchInboundTransportPayload(logical, transport, inboundConn, owned, now); err != nil && (s.showError || s.debugMode) {
			fmt.Println("server decode envelope error", err)
		}
	}) {
		// Dispatch refused the task; balance the WaitGroup increment.
		s.wg.Done()
	}
	return true
}
func serverInboundConn(source interface{}) net.Conn {
switch data := source.(type) {
case net.Conn:

View File

@ -130,6 +130,7 @@ func (s *ServerCommon) removeLogical(logical *LogicalConn) {
s.getFileAckPool().closeScopeFamily(scope)
s.getSignalAckPool().closeScopeFamily(scope)
s.getReceivedSignalCache().closeScope(scope)
s.closeServerDedicatedSidecar(logical)
s.getPeerRegistry().removeLogical(logical)
}

View File

@ -57,6 +57,7 @@ func (s *ServerCommon) detachClientSessionTransport(client *ClientConn, reason s
if runtime := s.getBulkRuntime(); runtime != nil {
runtime.closeScope(serverFileScope(client), errTransportDetached)
}
s.closeServerDedicatedSidecar(logicalConnFromClient(client))
client.detachServerOwnedTransport()
}
@ -80,6 +81,7 @@ func (s *ServerCommon) detachLogicalSessionTransport(logical *LogicalConn, reaso
if runtime := s.getBulkRuntime(); runtime != nil {
runtime.closeScope(serverFileScope(logical), errTransportDetached)
}
s.closeServerDedicatedSidecar(logical)
logical.detachServerOwnedTransport()
}
@ -108,6 +110,7 @@ func (s *ServerCommon) registerAcceptedLogical(logical *LogicalConn) *LogicalCon
}
logical.setServer(s)
logical.applyAttachmentProfile(s.maxReadTimeout, s.maxWriteTimeout, s.defaultMsgEn, s.defaultMsgDe, s.defaultFastStreamEncode, s.defaultFastBulkEncode, s.defaultFastPlainEncode, s.handshakeRsaKey, s.SecretKey)
logical.setModernPSKRuntime(s.defaultModernPSKRuntime)
logical.markHeartbeatNow()
return s.getPeerRegistry().registerLogical(logical)
}

View File

@ -33,6 +33,11 @@ func (s *ServerCommon) OpenStreamLogical(ctx context.Context, logical *LogicalCo
if resp.DataID != 0 {
req.DataID = resp.DataID
}
if resp.FastPathVersion != 0 {
req.FastPathVersion = resp.FastPathVersion
} else {
req.FastPathVersion = streamFastPathVersionV1
}
req.Metadata = mergeStreamMetadata(req.Metadata, resp.Metadata)
transport := logical.CurrentTransportConn()
stream := newStreamHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverStreamCloseSender(s, logical, nil), serverStreamResetSender(s, logical, nil), serverStreamDataSender(s, transport), runtime.configSnapshot())
@ -73,6 +78,11 @@ func (s *ServerCommon) OpenStreamTransport(ctx context.Context, transport *Trans
if resp.DataID != 0 {
req.DataID = resp.DataID
}
if resp.FastPathVersion != 0 {
req.FastPathVersion = resp.FastPathVersion
} else {
req.FastPathVersion = streamFastPathVersionV1
}
req.Metadata = mergeStreamMetadata(req.Metadata, resp.Metadata)
stream := newStreamHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverStreamCloseSender(s, logical, transport), serverStreamResetSender(s, logical, transport), serverStreamDataSender(s, transport), runtime.configSnapshot())
if err := runtime.register(scope, stream); err != nil {
@ -91,11 +101,12 @@ func serverStreamRequest(runtime *streamRuntime, opt StreamOpenOptions) StreamOp
id = runtime.nextID()
}
return normalizeStreamOpenRequest(StreamOpenRequest{
StreamID: id,
Channel: opt.Channel,
Metadata: cloneStreamMetadata(opt.Metadata),
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
StreamID: id,
FastPathVersion: streamFastPathVersionCurrent,
Channel: opt.Channel,
Metadata: cloneStreamMetadata(opt.Metadata),
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
})
}
@ -148,7 +159,7 @@ func serverStreamDataSender(s *ServerCommon, transport *TransportConn) streamDat
}
}
if dataID := stream.dataIDSnapshot(); dataID != 0 {
return s.sendFastStreamDataTransport(stream.LogicalConn(), transport, dataID, stream.nextOutboundDataSeq(), chunk)
return s.sendFastStreamDataTransport(ctx, stream.LogicalConn(), transport, stream, chunk)
}
return s.sendEnvelopeTransport(transport, newStreamDataEnvelope(stream.ID(), chunk))
}

View File

@ -24,6 +24,8 @@ type Server interface {
SetStreamConfig(StreamConfig)
SetTransferResumeStore(TransferResumeStore)
RecoverTransferSnapshots(context.Context) error
SetBulkOpenTuning(BulkOpenTuning)
BulkOpenTuning() BulkOpenTuning
SetFileReceiveDir(dir string) error
send(c *ClientConn, msg TransferMsg) (WaitMsg, error)
sendEnvelope(c *ClientConn, env Envelope) error

View File

@ -6,18 +6,34 @@ import (
)
type ClientRuntimeSnapshot struct {
OwnerState string
Alive bool
SessionEpoch uint64
TransportAttached bool
HasRuntimeConn bool
HasRuntimeQueue bool
HasRuntimeStopCtx bool
ConnectSource string
ConnectNetwork string
ConnectAddress string
CanReconnect bool
Retry ConnectionRetrySnapshot
OwnerState string
Alive bool
SessionEpoch uint64
TransportAttached bool
HasRuntimeConn bool
HasRuntimeQueue bool
HasRuntimeStopCtx bool
ConnectSource string
ConnectNetwork string
ConnectAddress string
CanReconnect bool
BulkNetworkProfile string
BulkDefaultMode string
BulkChunkSize int
BulkWindowBytes int
BulkMaxInFlight int
BulkAttachLimit int
BulkActiveLimit int
BulkLaneLimit int
BulkAttachAttempts int64
BulkAttachRetries int64
BulkAttachSuccess int64
BulkAutoFallbacks int64
TransportBulkAdaptiveSoftPayloadBytes int
TransportStreamAdaptiveSoftPayloadBytes int
TransportStreamAdaptiveWaitThresholdBytes int
TransportStreamAdaptiveFlushDelay time.Duration
Retry ConnectionRetrySnapshot
}
type ServerRuntimeSnapshot struct {
@ -33,36 +49,43 @@ type ServerRuntimeSnapshot struct {
HasRuntimeUDPListener bool
HasRuntimeQueue bool
HasRuntimeStopCtx bool
BulkChunkSize int
BulkWindowBytes int
BulkMaxInFlight int
Retry ConnectionRetrySnapshot
}
type ClientConnRuntimeSnapshot struct {
ClientID string
RemoteAddress string
Alive bool
Reason string
Error string
IdentityBound bool
UsesStreamTransport bool
TransportGeneration uint64
TransportAttached bool
HasRuntimeConn bool
HasRuntimeStopCtx bool
TransportAttachCount uint64
TransportDetachCount uint64
LastTransportAttachAt time.Time
DetachedClientKeepSec int64
LastHeartbeatAt time.Time
TransportDetachReason string
TransportDetachKind string
TransportDetachGeneration uint64
TransportDetachError string
TransportDetachedAt time.Time
TransportDetachHasExpiry bool
TransportDetachExpiry time.Time
TransportDetachRemaining time.Duration
TransportDetachExpired bool
ReattachEligible bool
ClientID string
RemoteAddress string
Alive bool
Reason string
Error string
IdentityBound bool
UsesStreamTransport bool
TransportGeneration uint64
TransportAttached bool
HasRuntimeConn bool
HasRuntimeStopCtx bool
TransportAttachCount uint64
TransportDetachCount uint64
LastTransportAttachAt time.Time
DetachedClientKeepSec int64
LastHeartbeatAt time.Time
TransportDetachReason string
TransportDetachKind string
TransportDetachGeneration uint64
TransportDetachError string
TransportDetachedAt time.Time
TransportDetachHasExpiry bool
TransportDetachExpiry time.Time
TransportDetachRemaining time.Duration
TransportDetachExpired bool
ReattachEligible bool
TransportBulkAdaptiveSoftPayloadBytes int
TransportStreamAdaptiveSoftPayloadBytes int
TransportStreamAdaptiveWaitThresholdBytes int
TransportStreamAdaptiveFlushDelay time.Duration
}
func (c *ClientCommon) clientRuntimeSnapshot() ClientRuntimeSnapshot {
@ -85,6 +108,26 @@ func (c *ClientCommon) clientRuntimeSnapshot() ClientRuntimeSnapshot {
snapshot.ConnectAddress = source.addr
snapshot.CanReconnect = source.canReconnect()
}
snapshot.BulkNetworkProfile = bulkNetworkProfileName(c.BulkNetworkProfile())
snapshot.BulkDefaultMode = bulkOpenModeName(c.BulkDefaultOpenMode())
tuning := c.BulkOpenTuning()
cfg := c.BulkDedicatedAttachConfig()
snapshot.BulkChunkSize = tuning.ChunkSize
snapshot.BulkWindowBytes = tuning.WindowBytes
snapshot.BulkMaxInFlight = tuning.MaxInFlight
snapshot.BulkAttachLimit = cfg.AttachLimit
snapshot.BulkActiveLimit = cfg.ActiveLimit
snapshot.BulkLaneLimit = cfg.LaneLimit
snapshot.BulkAttachAttempts = c.bulkAttachAttemptCount.Load()
snapshot.BulkAttachRetries = c.bulkAttachRetryCount.Load()
snapshot.BulkAttachSuccess = c.bulkAttachSuccessCount.Load()
snapshot.BulkAutoFallbacks = c.bulkAttachFallbackCount.Load()
if binding := c.clientTransportBindingSnapshot(); binding != nil {
snapshot.TransportBulkAdaptiveSoftPayloadBytes = binding.bulkAdaptiveSoftPayloadBytesSnapshot()
snapshot.TransportStreamAdaptiveSoftPayloadBytes = binding.streamAdaptiveSoftPayloadBytesSnapshot()
snapshot.TransportStreamAdaptiveWaitThresholdBytes = binding.streamAdaptiveWaitThresholdBytesSnapshot()
snapshot.TransportStreamAdaptiveFlushDelay = binding.streamAdaptiveFlushDelaySnapshot()
}
snapshot.Retry = c.connectionRetrySnapshot()
return snapshot
}
@ -118,6 +161,10 @@ func (s *ServerCommon) serverRuntimeSnapshot() ServerRuntimeSnapshot {
snapshot.HasRuntimeQueue = rt.queue != nil
snapshot.HasRuntimeStopCtx = rt.stopCtx != nil
}
tuning := s.BulkOpenTuning()
snapshot.BulkChunkSize = tuning.ChunkSize
snapshot.BulkWindowBytes = tuning.WindowBytes
snapshot.BulkMaxInFlight = tuning.MaxInFlight
snapshot.Retry = s.connectionRetrySnapshot()
return snapshot
}
@ -153,6 +200,12 @@ func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot {
snapshot.HasRuntimeConn = c.clientConnTransportSnapshot() != nil
snapshot.HasRuntimeStopCtx = rt.stopCtx != nil
}
if binding := c.clientConnTransportBindingSnapshot(); binding != nil {
snapshot.TransportBulkAdaptiveSoftPayloadBytes = binding.bulkAdaptiveSoftPayloadBytesSnapshot()
snapshot.TransportStreamAdaptiveSoftPayloadBytes = binding.streamAdaptiveSoftPayloadBytesSnapshot()
snapshot.TransportStreamAdaptiveWaitThresholdBytes = binding.streamAdaptiveWaitThresholdBytesSnapshot()
snapshot.TransportStreamAdaptiveFlushDelay = binding.streamAdaptiveFlushDelaySnapshot()
}
if detach := c.clientConnTransportDetachSnapshot(); detach != nil {
snapshot.TransportDetachReason = detach.Reason
snapshot.TransportDetachKind = c.clientConnTransportDetachKindSnapshot()

View File

@ -37,6 +37,45 @@ func TestGetClientRuntimeSnapshotDefaults(t *testing.T) {
if snapshot.ConnectSource != "" || snapshot.ConnectNetwork != "" || snapshot.ConnectAddress != "" || snapshot.CanReconnect {
t.Fatalf("unexpected default connect source snapshot: %+v", snapshot)
}
if got, want := snapshot.BulkNetworkProfile, "default"; got != want {
t.Fatalf("BulkNetworkProfile mismatch: got %q want %q", got, want)
}
if got, want := snapshot.BulkDefaultMode, "shared"; got != want {
t.Fatalf("BulkDefaultMode mismatch: got %q want %q", got, want)
}
if got, want := snapshot.BulkChunkSize, defaultBulkChunkSize; got != want {
t.Fatalf("BulkChunkSize mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkWindowBytes, defaultBulkOpenWindowBytes; got != want {
t.Fatalf("BulkWindowBytes mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkMaxInFlight, defaultBulkOpenMaxInFlight; got != want {
t.Fatalf("BulkMaxInFlight mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkAttachLimit, defaultBulkDedicatedAttachLimit; got != want {
t.Fatalf("BulkAttachLimit mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkActiveLimit, defaultBulkDedicatedActiveLimit; got != want {
t.Fatalf("BulkActiveLimit mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkLaneLimit, defaultBulkDedicatedLaneLimit; got != want {
t.Fatalf("BulkLaneLimit mismatch: got %d want %d", got, want)
}
if snapshot.BulkAttachAttempts != 0 || snapshot.BulkAttachRetries != 0 || snapshot.BulkAttachSuccess != 0 || snapshot.BulkAutoFallbacks != 0 {
t.Fatalf("unexpected default bulk attach counters: %+v", snapshot)
}
if snapshot.TransportBulkAdaptiveSoftPayloadBytes != 0 {
t.Fatalf("TransportBulkAdaptiveSoftPayloadBytes mismatch: got %d want 0", snapshot.TransportBulkAdaptiveSoftPayloadBytes)
}
if snapshot.TransportStreamAdaptiveSoftPayloadBytes != 0 {
t.Fatalf("TransportStreamAdaptiveSoftPayloadBytes mismatch: got %d want 0", snapshot.TransportStreamAdaptiveSoftPayloadBytes)
}
if snapshot.TransportStreamAdaptiveWaitThresholdBytes != 0 {
t.Fatalf("TransportStreamAdaptiveWaitThresholdBytes mismatch: got %d want 0", snapshot.TransportStreamAdaptiveWaitThresholdBytes)
}
if snapshot.TransportStreamAdaptiveFlushDelay != 0 {
t.Fatalf("TransportStreamAdaptiveFlushDelay mismatch: got %s want 0", snapshot.TransportStreamAdaptiveFlushDelay)
}
if snapshot.Retry != (ConnectionRetrySnapshot{}) {
t.Fatalf("Retry snapshot mismatch: %+v", snapshot.Retry)
}
@ -78,11 +117,144 @@ func TestGetServerRuntimeSnapshotDefaults(t *testing.T) {
if !snapshot.HasRuntimeStopCtx {
t.Fatalf("HasRuntimeStopCtx mismatch: got %v want true", snapshot.HasRuntimeStopCtx)
}
if got, want := snapshot.BulkChunkSize, defaultBulkChunkSize; got != want {
t.Fatalf("BulkChunkSize mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkWindowBytes, defaultBulkOpenWindowBytes; got != want {
t.Fatalf("BulkWindowBytes mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkMaxInFlight, defaultBulkOpenMaxInFlight; got != want {
t.Fatalf("BulkMaxInFlight mismatch: got %d want %d", got, want)
}
if snapshot.Retry != (ConnectionRetrySnapshot{}) {
t.Fatalf("Retry snapshot mismatch: %+v", snapshot.Retry)
}
}
// TestGetClientRuntimeSnapshotIncludesBulkAttachStats verifies that the client
// runtime snapshot reflects the configured bulk network profile / default open
// mode, the WAN-profile tuning values, and the raw bulk-attach counters seeded
// directly on the client.
func TestGetClientRuntimeSnapshotIncludesBulkAttachStats(t *testing.T) {
	client := NewClient().(*ClientCommon)
	client.SetBulkNetworkProfile(BulkNetworkProfileWAN)
	client.SetBulkDefaultOpenMode(BulkOpenModeAuto)
	// Seed the attach counters so the snapshot has non-zero values to copy.
	client.bulkAttachAttemptCount.Store(11)
	client.bulkAttachRetryCount.Store(5)
	client.bulkAttachSuccessCount.Store(6)
	client.bulkAttachFallbackCount.Store(3)
	snapshot, err := GetClientRuntimeSnapshot(client)
	if err != nil {
		t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
	}
	if got, want := snapshot.BulkNetworkProfile, "wan"; got != want {
		t.Fatalf("BulkNetworkProfile mismatch: got %q want %q", got, want)
	}
	if got, want := snapshot.BulkDefaultMode, "auto"; got != want {
		t.Fatalf("BulkDefaultMode mismatch: got %q want %q", got, want)
	}
	// WAN-profile tuning expectations (chunk/window/in-flight sizing).
	if got, want := snapshot.BulkChunkSize, 512*1024; got != want {
		t.Fatalf("BulkChunkSize mismatch: got %d want %d", got, want)
	}
	if got, want := snapshot.BulkWindowBytes, 8*1024*1024; got != want {
		t.Fatalf("BulkWindowBytes mismatch: got %d want %d", got, want)
	}
	if got, want := snapshot.BulkMaxInFlight, 16; got != want {
		t.Fatalf("BulkMaxInFlight mismatch: got %d want %d", got, want)
	}
	if got, want := snapshot.BulkAttachLimit, 2; got != want {
		t.Fatalf("BulkAttachLimit mismatch: got %d want %d", got, want)
	}
	if got, want := snapshot.BulkActiveLimit, 4096; got != want {
		t.Fatalf("BulkActiveLimit mismatch: got %d want %d", got, want)
	}
	if got, want := snapshot.BulkLaneLimit, 2; got != want {
		t.Fatalf("BulkLaneLimit mismatch: got %d want %d", got, want)
	}
	// Counters must round-trip through the snapshot unchanged.
	if got, want := snapshot.BulkAttachAttempts, int64(11); got != want {
		t.Fatalf("BulkAttachAttempts mismatch: got %d want %d", got, want)
	}
	if got, want := snapshot.BulkAttachRetries, int64(5); got != want {
		t.Fatalf("BulkAttachRetries mismatch: got %d want %d", got, want)
	}
	if got, want := snapshot.BulkAttachSuccess, int64(6); got != want {
		t.Fatalf("BulkAttachSuccess mismatch: got %d want %d", got, want)
	}
	if got, want := snapshot.BulkAutoFallbacks, int64(3); got != want {
		t.Fatalf("BulkAutoFallbacks mismatch: got %d want %d", got, want)
	}
}
// TestGetClientRuntimeSnapshotIncludesAdaptiveBulkSoftPayload verifies that a
// slow observed bulk write (8 MiB in 640 ms) drives the transport binding's
// adaptive bulk soft-payload cap down to its minimum, and that the snapshot
// surfaces that value while the stream-side adaptive fields stay at their
// untouched defaults.
func TestGetClientRuntimeSnapshotIncludesAdaptiveBulkSoftPayload(t *testing.T) {
	client := NewClient().(*ClientCommon)
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
	// net.Pipe provides an in-memory conn; no real network traffic occurs.
	left, right := net.Pipe()
	defer left.Close()
	defer right.Close()
	binding := newTransportBinding(left, queue)
	// Feed one slow bulk write observation so the adaptive cap shrinks.
	binding.observeBulkAdaptivePayloadWrite(8*1024*1024, 640*time.Millisecond, 0, nil)
	client.setClientSessionRuntime(&clientSessionRuntime{
		transport: binding,
		stopCtx:   stopCtx,
		stopFn:    stopFn,
		queue:     queue,
		epoch:     1,
	})
	client.markSessionStarted()
	snapshot, err := GetClientRuntimeSnapshot(client)
	if err != nil {
		t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
	}
	if got, want := snapshot.TransportBulkAdaptiveSoftPayloadBytes, bulkAdaptiveSoftPayloadMinBytes; got != want {
		t.Fatalf("TransportBulkAdaptiveSoftPayloadBytes mismatch: got %d want %d", got, want)
	}
	// Stream-side adaptive values were never observed and must stay default.
	if got, want := snapshot.TransportStreamAdaptiveSoftPayloadBytes, streamAdaptiveSoftPayloadStartBytes; got != want {
		t.Fatalf("TransportStreamAdaptiveSoftPayloadBytes mismatch: got %d want %d", got, want)
	}
	if got, want := snapshot.TransportStreamAdaptiveWaitThresholdBytes, streamBatchWaitThreshold; got != want {
		t.Fatalf("TransportStreamAdaptiveWaitThresholdBytes mismatch: got %d want %d", got, want)
	}
	if got, want := snapshot.TransportStreamAdaptiveFlushDelay, streamBatchMaxFlushDelay; got != want {
		t.Fatalf("TransportStreamAdaptiveFlushDelay mismatch: got %s want %s", got, want)
	}
}
// TestGetClientRuntimeSnapshotIncludesAdaptiveStreamTuning verifies that a
// slow observed stream write (2 MiB in 640 ms) drives the stream-side adaptive
// tuning down to its floors — soft payload and wait threshold at their minima,
// flush delay at zero — and that the snapshot reports those values.
func TestGetClientRuntimeSnapshotIncludesAdaptiveStreamTuning(t *testing.T) {
	client := NewClient().(*ClientCommon)
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
	// net.Pipe provides an in-memory conn; no real network traffic occurs.
	left, right := net.Pipe()
	defer left.Close()
	defer right.Close()
	binding := newTransportBinding(left, queue)
	// Feed one slow stream write observation so the adaptive tuning shrinks.
	binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil)
	client.setClientSessionRuntime(&clientSessionRuntime{
		transport: binding,
		stopCtx:   stopCtx,
		stopFn:    stopFn,
		queue:     queue,
		epoch:     1,
	})
	client.markSessionStarted()
	snapshot, err := GetClientRuntimeSnapshot(client)
	if err != nil {
		t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
	}
	if got, want := snapshot.TransportStreamAdaptiveSoftPayloadBytes, streamAdaptiveSoftPayloadMinBytes; got != want {
		t.Fatalf("TransportStreamAdaptiveSoftPayloadBytes mismatch: got %d want %d", got, want)
	}
	if got, want := snapshot.TransportStreamAdaptiveWaitThresholdBytes, streamAdaptiveWaitThresholdMinBytes; got != want {
		t.Fatalf("TransportStreamAdaptiveWaitThresholdBytes mismatch: got %d want %d", got, want)
	}
	if got := snapshot.TransportStreamAdaptiveFlushDelay; got != 0 {
		t.Fatalf("TransportStreamAdaptiveFlushDelay mismatch: got %s want 0", got)
	}
}
func TestGetRuntimeSnapshotRejectsNil(t *testing.T) {
if _, err := GetClientRuntimeSnapshot(nil); !errors.Is(err, errClientRuntimeSnapshotNil) {
t.Fatalf("GetClientRuntimeSnapshot nil error = %v, want %v", err, errClientRuntimeSnapshotNil)

View File

@ -140,6 +140,7 @@ func (c *ClientCommon) cleanupClientSessionResources() {
if runtime := c.getBulkRuntime(); runtime != nil {
runtime.closeAll(errServiceShutdown)
}
c.closeClientDedicatedSidecar()
}
func (s *ServerCommon) cleanupServerSessionResources() {
@ -158,4 +159,5 @@ func (s *ServerCommon) cleanupServerSessionResources() {
if runtime := s.getBulkRuntime(); runtime != nil {
runtime.closeAll(errServiceShutdown)
}
s.closeAllServerDedicatedSidecars()
}

View File

@ -76,7 +76,7 @@ func startSignalRoundTripServerForBenchmark(b *testing.B) (*ServerCommon, string
server.SetLink("signal-roundtrip", func(msg *Message) {
_ = msg.Reply([]byte("ack:" + string(msg.Value)))
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
if err := server.Listen("tcp", benchmarkTCPListenAddr(b)); err != nil {
if benchmarkListenPermissionDenied(err) {
b.Skipf("tcp benchmark requires local listen permission: %v", err)
}
@ -102,7 +102,7 @@ func newSignalRoundTripBenchmarkClient(b *testing.B, addr string) *ClientCommon
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
b.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", addr); err != nil {
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, addr)); err != nil {
b.Fatalf("client Connect failed: %v", err)
}
return client

View File

@ -3,20 +3,24 @@ package notify
import "time"
type snapshotBindingDiagnostics struct {
BindingOwner string
BindingAlive bool
BindingCurrent bool
BindingReason string
BindingError string
TransportAttached bool
TransportHasRuntimeConn bool
TransportCurrent bool
TransportDetachReason string
TransportDetachKind string
TransportDetachError string
TransportDetachGeneration uint64
TransportDetachedAt time.Time
ReattachEligible bool
BindingOwner string
BindingAlive bool
BindingCurrent bool
BindingReason string
BindingError string
BindingBulkAdaptiveSoftPayloadBytes int
BindingStreamAdaptiveSoftPayloadBytes int
BindingStreamAdaptiveWaitThresholdBytes int
BindingStreamAdaptiveFlushDelay time.Duration
TransportAttached bool
TransportHasRuntimeConn bool
TransportCurrent bool
TransportDetachReason string
TransportDetachKind string
TransportDetachError string
TransportDetachGeneration uint64
TransportDetachedAt time.Time
ReattachEligible bool
}
func snapshotBindingDiagnosticsFromClient(c *ClientCommon, sessionEpoch uint64) snapshotBindingDiagnostics {
@ -36,6 +40,12 @@ func snapshotBindingDiagnosticsFromClient(c *ClientCommon, sessionEpoch uint64)
diag.TransportAttached = c.clientTransportAttachedSnapshot()
diag.TransportHasRuntimeConn = c.clientTransportConnSnapshot() != nil
diag.TransportCurrent = diag.BindingCurrent && diag.TransportAttached
if binding := c.clientTransportBindingSnapshot(); binding != nil {
diag.BindingBulkAdaptiveSoftPayloadBytes = binding.bulkAdaptiveSoftPayloadBytesSnapshot()
diag.BindingStreamAdaptiveSoftPayloadBytes = binding.streamAdaptiveSoftPayloadBytesSnapshot()
diag.BindingStreamAdaptiveWaitThresholdBytes = binding.streamAdaptiveWaitThresholdBytesSnapshot()
diag.BindingStreamAdaptiveFlushDelay = binding.streamAdaptiveFlushDelaySnapshot()
}
return diag
}
@ -72,5 +82,11 @@ func snapshotBindingDiagnosticsFromLogical(logical *LogicalConn, transport *Tran
diag.TransportCurrent = runtime.TransportAttached
}
diag.BindingCurrent = diag.BindingAlive && diag.TransportCurrent
if binding := logical.transportBindingSnapshot(); binding != nil {
diag.BindingBulkAdaptiveSoftPayloadBytes = binding.bulkAdaptiveSoftPayloadBytesSnapshot()
diag.BindingStreamAdaptiveSoftPayloadBytes = binding.streamAdaptiveSoftPayloadBytesSnapshot()
diag.BindingStreamAdaptiveWaitThresholdBytes = binding.streamAdaptiveWaitThresholdBytesSnapshot()
diag.BindingStreamAdaptiveFlushDelay = binding.streamAdaptiveFlushDelaySnapshot()
}
return diag
}

100
stream.go
View File

@ -7,6 +7,7 @@ import (
"net"
"os"
"sync"
"sync/atomic"
"time"
)
@ -93,7 +94,8 @@ type streamHandle struct {
runtimeScope string
id string
dataID uint64
outboundSeq uint64
fastPathVersion uint8
outboundSeq atomic.Uint64
channel StreamChannel
metadata StreamMetadata
sessionEpoch uint64
@ -132,6 +134,9 @@ type streamHandle struct {
writeDeadlineOverride bool
readDeadlineNotify chan struct{}
writeDeadlineNotify chan struct{}
writeWaitSeq uint64
writeWaitCancel context.CancelFunc
writeWaitChanged chan struct{}
bytesRead int64
bytesWritten int64
readCalls int64
@ -157,6 +162,7 @@ func newStreamHandle(parent context.Context, runtime *streamRuntime, runtimeScop
runtimeScope: runtimeScope,
id: req.StreamID,
dataID: req.DataID,
fastPathVersion: normalizeStreamFastPathVersion(req.FastPathVersion),
channel: normalizeStreamChannel(req.Channel),
metadata: cloneStreamMetadata(req.Metadata),
sessionEpoch: sessionEpoch,
@ -224,13 +230,25 @@ func (s *streamHandle) dataIDSnapshot() uint64 {
}
func (s *streamHandle) nextOutboundDataSeq() uint64 {
return s.reserveOutboundDataSeqs(1)
}
func (s *streamHandle) reserveOutboundDataSeqs(count int) uint64 {
if s == nil {
return 0
}
s.mu.Lock()
defer s.mu.Unlock()
s.outboundSeq++
return s.outboundSeq
if count <= 0 {
count = 1
}
end := s.outboundSeq.Add(uint64(count))
return end - uint64(count) + 1
}
// fastPathVersionSnapshot reports the stream's negotiated fast-path protocol
// version, normalized to a supported value. A nil receiver yields the baseline
// v1 version so callers never observe an invalid version.
// NOTE(review): fastPathVersion is read without holding s.mu — presumed safe
// because it is assigned once at construction; confirm there are no later writers.
func (s *streamHandle) fastPathVersionSnapshot() uint8 {
	if s == nil {
		return streamFastPathVersionV1
	}
	return normalizeStreamFastPathVersion(s.fastPathVersion)
}
func (s *streamHandle) Channel() StreamChannel {
@ -377,6 +395,7 @@ func (s *streamHandle) Write(p []byte) (int, error) {
sendDataFn := s.sendDataFn
chunkSize := s.chunkSize
writeTimeout := s.writeTimeout
writeDeadlineOverride := s.writeDeadlineOverride
streamCtx := s.ctx
runtime := s.runtime
s.mu.Unlock()
@ -399,6 +418,20 @@ func (s *streamHandle) Write(p []byte) (int, error) {
end = len(p)
}
chunk := p[written:end]
if !writeDeadlineOverride && writeTimeout <= 0 {
if tryAcquireStreamOutboundBudget(runtime, len(chunk)) {
err := sendDataFn(streamCtx, s, chunk)
releaseStreamOutboundBudget(runtime, len(chunk))
if err != nil {
if written > 0 {
s.recordWrite(written, time.Now())
}
return written, s.normalizeWriteError(err)
}
written = end
continue
}
}
sendCtx, cancel, deadlineChanged, err := s.newWriteContext(streamCtx, writeTimeout)
if err != nil {
if written > 0 {
@ -464,7 +497,15 @@ func (s *streamHandle) SetWriteDeadline(deadline time.Time) error {
s.writeDeadline = deadline
s.writeDeadlineOverride = true
signalStreamDeadlineChangeLocked(&s.writeDeadlineNotify)
waitCancel := s.writeWaitCancel
if s.writeWaitChanged != nil {
close(s.writeWaitChanged)
s.writeWaitChanged = nil
}
s.mu.Unlock()
if waitCancel != nil {
waitCancel()
}
return nil
}
@ -535,7 +576,6 @@ func (s *streamHandle) newWriteContext(parent context.Context, writeTimeout time
}
s.mu.Lock()
deadline := s.effectiveWriteDeadlineLocked(time.Now(), writeTimeout)
deadlineNotify := s.writeDeadlineNotify
s.mu.Unlock()
if !deadline.IsZero() && !deadline.After(time.Now()) {
return nil, func() {}, nil, os.ErrDeadlineExceeded
@ -548,19 +588,20 @@ func (s *streamHandle) newWriteContext(parent context.Context, writeTimeout time
baseCtx, baseCancel = context.WithCancel(parent)
}
changed := make(chan struct{})
done := make(chan struct{})
go func() {
defer close(done)
select {
case <-baseCtx.Done():
case <-deadlineNotify:
close(changed)
baseCancel()
}
}()
s.mu.Lock()
s.writeWaitSeq++
waitSeq := s.writeWaitSeq
s.writeWaitCancel = baseCancel
s.writeWaitChanged = changed
s.mu.Unlock()
cancel := func() {
baseCancel()
<-done
s.mu.Lock()
if s.writeWaitSeq == waitSeq {
s.writeWaitCancel = nil
s.writeWaitChanged = nil
}
s.mu.Unlock()
}
return baseCtx, cancel, changed, nil
}
@ -814,7 +855,11 @@ func (s *streamHandle) pushChunkWithOwnership(chunk []byte, owned bool) error {
s.finalize()
return err
}
s.readQueue = append(s.readQueue, stored)
if len(s.readBuf) == 0 && len(s.readQueue) == 0 {
s.readBuf = stored
} else {
s.readQueue = append(s.readQueue, stored)
}
s.bufferedBytes += len(stored)
s.notifyReadableLocked()
s.mu.Unlock()
@ -917,6 +962,10 @@ func (s *streamHandle) snapshot() StreamSnapshot {
snapshot.BindingCurrent = diag.BindingCurrent
snapshot.BindingReason = diag.BindingReason
snapshot.BindingError = diag.BindingError
snapshot.BindingBulkAdaptiveSoftPayloadBytes = diag.BindingBulkAdaptiveSoftPayloadBytes
snapshot.BindingStreamAdaptiveSoftPayloadBytes = diag.BindingStreamAdaptiveSoftPayloadBytes
snapshot.BindingStreamAdaptiveWaitThresholdBytes = diag.BindingStreamAdaptiveWaitThresholdBytes
snapshot.BindingStreamAdaptiveFlushDelay = diag.BindingStreamAdaptiveFlushDelay
snapshot.TransportAttached = diag.TransportAttached
snapshot.TransportHasRuntimeConn = diag.TransportHasRuntimeConn
snapshot.TransportCurrent = diag.TransportCurrent
@ -1057,8 +1106,23 @@ func acquireStreamOutboundBudget(runtime *streamRuntime, ctx context.Context, si
return runtime.acquireOutbound(ctx, size)
}
// tryAcquireStreamOutboundBudget attempts a non-blocking reservation of size
// bytes from the runtime's outbound budget. A nil runtime means budgeting is
// disabled, so the acquisition trivially succeeds.
func tryAcquireStreamOutboundBudget(runtime *streamRuntime, size int) bool {
	if runtime == nil {
		return true
	}
	return runtime.tryAcquireOutbound(size)
}

// releaseStreamOutboundBudget returns a previously acquired reservation of
// size bytes to the runtime's outbound budget; a nil runtime is a no-op.
func releaseStreamOutboundBudget(runtime *streamRuntime, size int) {
	if runtime == nil {
		return
	}
	runtime.releaseOutbound(size)
}

// normalizeStreamOpenRequest canonicalizes an open request in place: channel
// and fast-path version are clamped to supported values and the metadata is
// replaced with a clone (via cloneStreamMetadata) so the stored request does
// not alias the caller's map.
func normalizeStreamOpenRequest(req StreamOpenRequest) StreamOpenRequest {
	req.Channel = normalizeStreamChannel(req.Channel)
	req.FastPathVersion = normalizeStreamFastPathVersion(req.FastPathVersion)
	req.Metadata = cloneStreamMetadata(req.Metadata)
	return req
}

6
stream_batch_codec.go Normal file
View File

@ -0,0 +1,6 @@
package notify
// streamBatchCodec supplies the payload encoders used by streamBatchSender.
// encodeSingle wraps one fast-path data frame into a wire payload; encodeBatch
// coalesces several frames into one batch payload. A nil encodeBatch makes the
// sender fall back to single-frame encoding; a nil encodeSingle causes the
// sender to report the transport as detached.
type streamBatchCodec struct {
	encodeSingle func(streamFastDataFrame) ([]byte, error)
	encodeBatch  func([]streamFastDataFrame) ([]byte, error)
}

582
stream_batch_sender.go Normal file
View File

@ -0,0 +1,582 @@
package notify
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
)
const (
	// streamBatchMaxPayloads caps how many requests one flush batch may absorb;
	// it also sizes the sender's request channel (4x) in newStreamBatchSender.
	streamBatchMaxPayloads = 64
	// streamBatchMaxPayloadBytes is the hard upper bound for batched payload
	// bytes. NOTE(review): not referenced in this chunk — confirm where it is
	// enforced.
	streamBatchMaxPayloadBytes = 2 * 1024 * 1024
	// streamBatchMaxFlushDelay is the fallback linger delay used when the
	// binding provides no adaptive flush delay.
	streamBatchMaxFlushDelay = 50 * time.Microsecond
	// streamBatchWaitThreshold is the fallback byte threshold above which the
	// run loop stops lingering for more requests before flushing.
	streamBatchWaitThreshold = 128 * 1024
)
// Lifecycle states for a streamBatchRequest. Transitions happen via CAS
// (tryStart / tryCancel) so exactly one of "started" or "canceled" wins the
// race between the run loop and a canceling submitter.
const (
	streamBatchRequestQueued int32 = iota
	streamBatchRequestStarted
	streamBatchRequestCanceled
)
// streamBatchRequestState holds the CAS-guarded lifecycle state of one request
// (queued -> started | canceled). It is shared by pointer between the
// submitting goroutine and the sender's run loop.
type streamBatchRequestState struct {
	value atomic.Int32
}

// streamBatchRequest is one unit of work queued to the batch sender. Exactly
// one payload form is expected to be populated: a single frame (hasFrame), a
// pre-encoded wire payload (hasEncoded), or a slice of frames.
type streamBatchRequest struct {
	ctx             context.Context // submitter context; may abort a still-queued request
	frame           streamFastDataFrame
	hasFrame        bool
	encodedPayload  []byte
	hasEncoded      bool
	frames          []streamFastDataFrame
	fastPathVersion uint8     // normalized version; governs batch eligibility
	deadline        time.Time // copied from ctx at submit time; NOTE(review): not read in this chunk — confirm use
	done            chan error // buffered(1); receives the final result exactly once
	state           *streamBatchRequestState
}
// streamBatchSender coalesces outbound fast-path stream payloads for one
// transport binding and writes them in batches from a dedicated goroutine.
// The first write failure is latched in err and all later submissions fail
// fast with it.
type streamBatchSender struct {
	binding              *transportBinding
	codec                streamBatchCodec
	writeTimeoutProvider func() time.Duration // consulted per flush; nil means no timeout
	reqCh                chan streamBatchRequest
	stopCh               chan struct{}
	doneCh               chan struct{} // closed when run() exits
	stopOnce             sync.Once
	flushMu              sync.Mutex   // serializes flush() between the run loop and direct submits
	queued               atomic.Int64 // in-flight request count; gates the direct-submit fast path
	errMu                sync.Mutex
	err                  error // first fatal error, latched by setErr
}
// newStreamBatchSender builds a sender bound to the given transport binding
// and codec and immediately starts its background flush loop.
// writeTimeoutProvider is consulted on every flush for the transport write
// timeout; nil is treated as "no timeout".
func newStreamBatchSender(binding *transportBinding, codec streamBatchCodec, writeTimeoutProvider func() time.Duration) *streamBatchSender {
	sender := &streamBatchSender{
		binding:              binding,
		codec:                codec,
		writeTimeoutProvider: writeTimeoutProvider,
		// Buffer several batches' worth of requests so submitters rarely block.
		reqCh:  make(chan streamBatchRequest, streamBatchMaxPayloads*4),
		stopCh: make(chan struct{}),
		doneCh: make(chan struct{}),
	}
	go sender.run()
	return sender
}
// submitData queues a single fast-path data frame (dataID/seq/payload) for
// batched transmission. An empty payload is accepted and dropped as a no-op;
// a nil sender reports the transport as detached.
func (s *streamBatchSender) submitData(ctx context.Context, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
	if s == nil {
		return errTransportDetached
	}
	if len(payload) == 0 {
		return nil
	}
	req := streamBatchRequest{
		ctx:             ctx,
		hasFrame:        true,
		fastPathVersion: normalizeStreamFastPathVersion(fastPathVersion),
	}
	req.frame = streamFastDataFrame{
		DataID:  dataID,
		Seq:     seq,
		Payload: payload,
	}
	return s.submitRequest(req)
}
// submitEncoded queues an already-encoded (encrypted) wire payload for
// transmission. Empty payloads are a no-op; a nil sender reports the
// transport as detached.
func (s *streamBatchSender) submitEncoded(ctx context.Context, fastPathVersion uint8, payload []byte) error {
	if s == nil {
		return errTransportDetached
	}
	if len(payload) == 0 {
		return nil
	}
	req := streamBatchRequest{
		ctx:             ctx,
		encodedPayload:  payload,
		hasEncoded:      true,
		fastPathVersion: normalizeStreamFastPathVersion(fastPathVersion),
	}
	return s.submitRequest(req)
}
// submitFrames queues a sequence of fast-path data frames for transmission.
// The slice is copied so the caller may reuse its backing array; a one-frame
// slice is collapsed into the cheaper single-frame request form.
func (s *streamBatchSender) submitFrames(ctx context.Context, fastPathVersion uint8, frames []streamFastDataFrame) error {
	if s == nil {
		return errTransportDetached
	}
	if len(frames) == 0 {
		return nil
	}
	version := normalizeStreamFastPathVersion(fastPathVersion)
	cloned := make([]streamFastDataFrame, len(frames))
	copy(cloned, frames)
	if len(cloned) == 1 {
		return s.submitRequest(streamBatchRequest{
			ctx:             ctx,
			frame:           cloned[0],
			hasFrame:        true,
			fastPathVersion: version,
		})
	}
	return s.submitRequest(streamBatchRequest{
		ctx:             ctx,
		frames:          cloned,
		fastPathVersion: version,
	})
}
// submitRequest validates, enqueues, and waits for completion of one request.
// It fails fast on a latched sender error, optionally attempts a lock-free
// direct flush for non-batchable traffic, and otherwise blocks until the run
// loop delivers the result or the caller's context wins a cancellation race.
func (s *streamBatchSender) submitRequest(req streamBatchRequest) error {
	if s == nil {
		return errTransportDetached
	}
	if req.ctx == nil {
		req.ctx = context.Background()
	}
	// Requests carrying no payload in any of the three forms are no-ops.
	if !req.hasFrame && !req.hasEncoded && len(req.frames) == 0 {
		return nil
	}
	req.fastPathVersion = normalizeStreamFastPathVersion(req.fastPathVersion)
	req.done = make(chan error, 1)
	req.state = &streamBatchRequestState{}
	if deadline, ok := req.ctx.Deadline(); ok {
		req.deadline = deadline
	}
	if err := s.errSnapshot(); err != nil {
		return err
	}
	// Non-batchable frames may bypass the queue entirely when it is idle.
	if s.shouldDirectSubmit(req) {
		if submitted, err := s.tryDirectSubmit(req); submitted {
			return err
		}
	}
	s.queued.Add(1)
	select {
	case <-req.ctx.Done():
		s.queued.Add(-1)
		return normalizeStreamDeadlineError(req.ctx.Err())
	case <-s.stopCh:
		s.queued.Add(-1)
		return s.stoppedErr()
	case s.reqCh <- req:
	}
	select {
	case err := <-req.done:
		return err
	case <-req.ctx.Done():
		// Only a successful queued->canceled CAS lets us abandon the request;
		// otherwise the run loop already started it and will deliver a result.
		if req.tryCancel() {
			return normalizeStreamDeadlineError(req.ctx.Err())
		}
		return <-req.done
	}
}
// shouldDirectSubmit reports whether req may try to bypass the batching
// queue: only frame-carrying requests whose fast-path version cannot be
// batched qualify.
func (s *streamBatchSender) shouldDirectSubmit(req streamBatchRequest) bool {
	switch {
	case req.hasEncoded:
		return false
	case !req.hasFrame && len(req.frames) == 0:
		return false
	default:
		return !streamFastPathSupportsBatch(req.fastPathVersion)
	}
}
// tryDirectSubmit attempts to flush req immediately, bypassing the run loop.
// It only proceeds when nothing is queued and the flush lock is free; any
// contention returns (false, nil) so the caller falls back to the queue. The
// bool reports whether the request was fully handled here.
func (s *streamBatchSender) tryDirectSubmit(req streamBatchRequest) (bool, error) {
	if s == nil {
		return true, errTransportDetached
	}
	if err := s.errSnapshot(); err != nil {
		return true, err
	}
	select {
	case <-req.ctx.Done():
		return true, normalizeStreamDeadlineError(req.ctx.Err())
	case <-s.stopCh:
		return true, s.stoppedErr()
	default:
	}
	// Fall back to the queue while other requests are in flight so ordering
	// relative to them is preserved.
	if s.queued.Load() != 0 {
		return false, nil
	}
	if !s.flushMu.TryLock() {
		return false, nil
	}
	defer s.flushMu.Unlock()
	// Re-check under the lock: a request may have been queued meanwhile.
	if s.queued.Load() != 0 {
		return false, nil
	}
	if err := s.errSnapshot(); err != nil {
		return true, err
	}
	if !req.tryStart() {
		return true, req.canceledErr()
	}
	if err := req.contextErr(); err != nil {
		return true, err
	}
	if err := s.flush([]streamBatchRequest{req}); err != nil {
		s.setErr(err)
		s.failPending(err)
		return true, err
	}
	return true, nil
}
// run is the sender's background loop. Each iteration blocks for one request,
// opportunistically coalesces more (optionally lingering up to the adaptive
// flush delay while the batch stays small), then flushes under flushMu. A
// flush error is latched, propagated to every affected request, and
// terminates the loop.
func (s *streamBatchSender) run() {
	defer close(s.doneCh)
	for {
		req, ok := s.nextRequest()
		if !ok {
			return
		}
		batch := []streamBatchRequest{req}
		batchBytes := streamBatchRequestApproxBytes(req)
		softPayloadLimit := s.batchSoftPayloadLimit()
		waitThreshold := s.batchWaitThreshold()
		flushDelay := s.batchFlushDelay()
		timer := (*time.Timer)(nil)
		timerCh := (<-chan time.Time)(nil)
		// Linger only while the batch is still small enough to benefit from
		// absorbing additional requests.
		if flushDelay > 0 && batchBytes < waitThreshold && batchBytes < softPayloadLimit && len(batch) < streamBatchMaxPayloads {
			timer = time.NewTimer(flushDelay)
			timerCh = timer.C
		}
	drain:
		for len(batch) < streamBatchMaxPayloads && batchBytes < softPayloadLimit {
			if timerCh == nil {
				// No linger timer: absorb whatever is immediately available.
				select {
				case <-s.stopCh:
					s.failPending(s.stoppedErr())
					return
				case next := <-s.reqCh:
					batch = append(batch, next)
					batchBytes += streamBatchRequestApproxBytes(next)
				default:
					break drain
				}
				continue
			}
			select {
			case <-s.stopCh:
				if timer != nil {
					timer.Stop()
				}
				s.failPending(s.stoppedErr())
				return
			case next := <-s.reqCh:
				batch = append(batch, next)
				batchBytes += streamBatchRequestApproxBytes(next)
			case <-timerCh:
				timerCh = nil
				break drain
			}
		}
		if timer != nil {
			// Drain the timer channel if Stop raced with expiry and the
			// expiry was not already consumed above.
			if !timer.Stop() && timerCh != nil {
				select {
				case <-timer.C:
				default:
				}
			}
		}
		s.flushMu.Lock()
		err := s.errSnapshot()
		// Start each request exactly once; canceled or dead-context requests
		// are completed without being flushed.
		active := make([]streamBatchRequest, 0, len(batch))
		for _, item := range batch {
			if !item.tryStart() {
				s.finishRequest(item, item.canceledErr())
				continue
			}
			if itemErr := item.contextErr(); itemErr != nil {
				s.finishRequest(item, itemErr)
				continue
			}
			active = append(active, item)
		}
		if len(active) == 0 {
			s.flushMu.Unlock()
			continue
		}
		if err == nil {
			err = s.flush(active)
		}
		s.flushMu.Unlock()
		if err != nil {
			s.setErr(err)
			for _, item := range active {
				s.finishRequest(item, err)
			}
			s.failPending(err)
			return
		}
		for _, item := range active {
			s.finishRequest(item, nil)
		}
	}
}
// nextRequest blocks for the next queued request. A false return means the
// sender was stopped; any still-buffered requests are failed before returning.
func (s *streamBatchSender) nextRequest() (streamBatchRequest, bool) {
	select {
	case <-s.stopCh:
		s.failPending(s.stoppedErr())
		return streamBatchRequest{}, false
	case req := <-s.reqCh:
		return req, true
	}
}
// flush encodes the active requests into wire payloads and writes them to the
// transport connection as one framed batch under the connection write lock.
// The write's byte count, duration, and outcome feed the binding's adaptive
// stream-payload tuning.
func (s *streamBatchSender) flush(requests []streamBatchRequest) error {
	if s == nil || s.binding == nil {
		return errTransportDetached
	}
	queue := s.binding.queueSnapshot()
	if queue == nil {
		return errTransportFrameQueueUnavailable
	}
	payloads, err := s.encodeRequests(requests)
	if err != nil {
		return err
	}
	writeTimeout := s.transportWriteTimeout()
	payloadBytes := 0
	for _, payload := range payloads {
		payloadBytes += len(payload)
	}
	started := time.Now()
	err = s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error {
		return writeFramedPayloadBatchUnlocked(conn, queue, payloads)
	})
	s.binding.observeStreamAdaptivePayloadWrite(payloadBytes, time.Since(started), writeTimeout, err)
	return err
}
// transportWriteTimeout returns the per-flush transport write timeout, or 0
// when no provider is configured.
func (s *streamBatchSender) transportWriteTimeout() time.Duration {
	if s != nil && s.writeTimeoutProvider != nil {
		return s.writeTimeoutProvider()
	}
	return 0
}

// batchSoftPayloadLimit returns the binding's adaptive soft cap on batched
// payload bytes, falling back to the static default without a binding.
func (s *streamBatchSender) batchSoftPayloadLimit() int {
	if s != nil && s.binding != nil {
		return s.binding.streamAdaptiveSoftPayloadBytesSnapshot()
	}
	return streamAdaptiveSoftPayloadFallbackBytes
}

// batchWaitThreshold returns the adaptive byte threshold beyond which the run
// loop stops lingering for more requests, with a static fallback.
func (s *streamBatchSender) batchWaitThreshold() int {
	if s != nil && s.binding != nil {
		return s.binding.streamAdaptiveWaitThresholdBytesSnapshot()
	}
	return streamBatchWaitThreshold
}

// batchFlushDelay returns the adaptive linger delay used while coalescing
// requests, with a static fallback.
func (s *streamBatchSender) batchFlushDelay() time.Duration {
	if s != nil && s.binding != nil {
		return s.binding.streamAdaptiveFlushDelaySnapshot()
	}
	return streamBatchMaxFlushDelay
}
// batchPlainPayloadLimit derives the plaintext byte limit for one coalesced
// batch from the adaptive soft cap, clamped to the wire-format maximum and
// kept strictly above the batch header size.
func (s *streamBatchSender) batchPlainPayloadLimit() int {
	soft := s.batchSoftPayloadLimit()
	if soft <= streamFastBatchHeaderLen {
		return streamFastBatchHeaderLen + 1
	}
	if soft > streamFastBatchMaxPlainBytes {
		return streamFastBatchMaxPlainBytes
	}
	return soft
}
// encodeRequests turns a slice of requests into one or more wire payloads.
// Batchable frames are coalesced until the item cap or the adaptive plaintext
// byte limit is hit; non-batchable frames, oversized frames, and pre-encoded
// payloads are emitted individually, flushing any open batch first so the
// original submission order is preserved on the wire.
func (s *streamBatchSender) encodeRequests(requests []streamBatchRequest) ([][]byte, error) {
	if len(requests) == 0 {
		return nil, nil
	}
	payloads := make([][]byte, 0, len(requests))
	batchPlainLimit := s.batchPlainPayloadLimit()
	var batch []streamFastDataFrame
	// flushBatch encodes and appends the currently accumulated batch, if any,
	// keeping the slice's capacity for reuse.
	flushBatch := func() error {
		if len(batch) == 0 {
			return nil
		}
		payload, err := s.encodeBatch(batch)
		if err != nil {
			return err
		}
		payloads = append(payloads, payload)
		batch = batch[:0]
		return nil
	}
	// batchBytes tracks the running plaintext size, batch header included.
	batchBytes := streamFastBatchHeaderLen
	appendFrame := func(frame streamFastDataFrame, fastPathVersion uint8) error {
		// Versions without batch support always go out as single frames.
		if !streamFastPathSupportsBatch(fastPathVersion) {
			if err := flushBatch(); err != nil {
				return err
			}
			batchBytes = streamFastBatchHeaderLen
			payload, err := s.encodeSingle(frame)
			if err != nil {
				return err
			}
			payloads = append(payloads, payload)
			return nil
		}
		frameLen := streamFastBatchFrameLen(frame)
		// A frame too large for any batch is sent on its own.
		if frameLen+streamFastBatchHeaderLen > batchPlainLimit {
			if err := flushBatch(); err != nil {
				return err
			}
			batchBytes = streamFastBatchHeaderLen
			payload, err := s.encodeSingle(frame)
			if err != nil {
				return err
			}
			payloads = append(payloads, payload)
			return nil
		}
		// Close the open batch when adding this frame would overflow it.
		if len(batch) > 0 && (len(batch) >= streamFastBatchMaxItems || batchBytes+frameLen > batchPlainLimit) {
			if err := flushBatch(); err != nil {
				return err
			}
			batchBytes = streamFastBatchHeaderLen
		}
		if batch == nil {
			batch = make([]streamFastDataFrame, 0, minInt(len(requests), streamFastBatchMaxItems))
		}
		batch = append(batch, frame)
		batchBytes += frameLen
		return nil
	}
	for _, req := range requests {
		if req.hasFrame {
			if err := appendFrame(req.frame, req.fastPathVersion); err != nil {
				return nil, err
			}
		}
		for _, frame := range req.frames {
			if err := appendFrame(frame, req.fastPathVersion); err != nil {
				return nil, err
			}
		}
		// Pre-encoded payloads pass through untouched, after flushing the
		// open batch to keep ordering.
		if req.hasEncoded {
			if err := flushBatch(); err != nil {
				return nil, err
			}
			batchBytes = streamFastBatchHeaderLen
			payloads = append(payloads, req.encodedPayload)
		}
	}
	if err := flushBatch(); err != nil {
		return nil, err
	}
	return payloads, nil
}
// streamBatchRequestApproxBytes estimates the plaintext bytes a request will
// contribute to a batch: header-inclusive length per frame, raw length for a
// pre-encoded payload.
func streamBatchRequestApproxBytes(req streamBatchRequest) int {
	size := 0
	if req.hasEncoded {
		size += len(req.encodedPayload)
	}
	if req.hasFrame {
		size += streamFastBatchFrameLen(req.frame)
	}
	for _, frame := range req.frames {
		size += streamFastBatchFrameLen(frame)
	}
	return size
}
// encodeSingle encodes one frame via the codec's single-frame encoder; a nil
// sender or missing encoder reports the transport as detached.
func (s *streamBatchSender) encodeSingle(frame streamFastDataFrame) ([]byte, error) {
	if s == nil || s.codec.encodeSingle == nil {
		return nil, errTransportDetached
	}
	return s.codec.encodeSingle(frame)
}

// encodeBatch encodes a group of frames, degrading to the single-frame
// encoder for one-frame batches or when no batch encoder is configured.
// NOTE(review): frames[0] is read unconditionally on the fallback path, so
// callers must never pass an empty slice (encodeRequests guards this via
// flushBatch's len check).
func (s *streamBatchSender) encodeBatch(frames []streamFastDataFrame) ([]byte, error) {
	if len(frames) == 1 || s.codec.encodeBatch == nil {
		return s.encodeSingle(frames[0])
	}
	return s.codec.encodeBatch(frames)
}
// finishRequest delivers the final result for req and releases its slot in
// the queued counter. done is buffered(1), so the send never blocks even if
// the submitter already abandoned the request via cancellation.
func (s *streamBatchSender) finishRequest(req streamBatchRequest, err error) {
	if s != nil {
		s.queued.Add(-1)
	}
	req.done <- err
}

// stop shuts the sender down exactly once: it latches a detached-transport
// error, signals the run loop via stopCh, and blocks until the loop exits.
func (s *streamBatchSender) stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		s.setErr(errTransportDetached)
		close(s.stopCh)
	})
	<-s.doneCh
}
// failPending drains every request still buffered in reqCh and completes each
// with err. Requests submitted after this drain returns are rejected by the
// stopCh / latched-error checks in submitRequest.
func (s *streamBatchSender) failPending(err error) {
	for {
		select {
		case item := <-s.reqCh:
			s.finishRequest(item, err)
		default:
			return
		}
	}
}
// setErr latches err as the sender's first fatal error; nil errors and any
// error after the first are ignored.
func (s *streamBatchSender) setErr(err error) {
	if s == nil || err == nil {
		return
	}
	s.errMu.Lock()
	defer s.errMu.Unlock()
	if s.err == nil {
		s.err = err
	}
}

// errSnapshot returns the latched fatal error, or errTransportDetached for a
// nil sender.
func (s *streamBatchSender) errSnapshot() error {
	if s == nil {
		return errTransportDetached
	}
	s.errMu.Lock()
	snapshot := s.err
	s.errMu.Unlock()
	return snapshot
}

// stoppedErr reports why the sender refuses work after shutdown: the latched
// error when present, otherwise a generic detached-transport error.
func (s *streamBatchSender) stoppedErr() error {
	if err := s.errSnapshot(); err != nil {
		return err
	}
	return errTransportDetached
}
// contextErr reports the request context's cancellation error, normalized for
// stream deadlines, without blocking; nil when the context is absent or live.
func (r streamBatchRequest) contextErr() error {
	if r.ctx == nil {
		return nil
	}
	select {
	case <-r.ctx.Done():
		return normalizeStreamDeadlineError(r.ctx.Err())
	default:
	}
	return nil
}

// tryStart atomically claims the request for processing; requests without
// state are always considered startable.
func (r streamBatchRequest) tryStart() bool {
	if r.state == nil {
		return true
	}
	return r.state.value.CompareAndSwap(streamBatchRequestQueued, streamBatchRequestStarted)
}

// tryCancel atomically withdraws a still-queued request; requests without
// state cannot be canceled.
func (r streamBatchRequest) tryCancel() bool {
	if r.state == nil {
		return false
	}
	return r.state.value.CompareAndSwap(streamBatchRequestQueued, streamBatchRequestCanceled)
}

// canceledErr picks the most specific error for an abandoned request: the
// context's error when available, otherwise context.Canceled.
func (r streamBatchRequest) canceledErr() error {
	if err := r.contextErr(); err != nil {
		return err
	}
	return context.Canceled
}

View File

@ -128,7 +128,7 @@ func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfi
return nil
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
if err := server.Listen("tcp", benchmarkTCPListenAddr(b)); err != nil {
b.Fatalf("server Listen failed: %v", err)
}
b.Cleanup(func() {
@ -140,7 +140,7 @@ func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfi
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
b.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
b.Fatalf("client Connect failed: %v", err)
}
b.Cleanup(func() {
@ -216,7 +216,7 @@ func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concu
return nil
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
if err := server.Listen("tcp", benchmarkTCPListenAddr(b)); err != nil {
b.Fatalf("server Listen failed: %v", err)
}
b.Cleanup(func() {
@ -228,7 +228,7 @@ func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concu
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
b.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
b.Fatalf("client Connect failed: %v", err)
}
b.Cleanup(func() {

View File

@ -7,17 +7,19 @@ import (
)
type StreamOpenRequest struct {
StreamID string
DataID uint64
Channel StreamChannel
Metadata StreamMetadata
ReadTimeout time.Duration
WriteTimeout time.Duration
StreamID string
DataID uint64
FastPathVersion uint8
Channel StreamChannel
Metadata StreamMetadata
ReadTimeout time.Duration
WriteTimeout time.Duration
}
type StreamOpenResponse struct {
StreamID string
DataID uint64
FastPathVersion uint8
Accepted bool
TransportGeneration uint64
Metadata StreamMetadata
@ -92,6 +94,8 @@ func (c *ClientCommon) handleInboundStreamOpen(msg *Message) {
return
}
scope := clientFileScope()
req.FastPathVersion = negotiateStreamFastPathVersion(req.FastPathVersion)
resp.FastPathVersion = req.FastPathVersion
if req.DataID == 0 {
req.DataID = runtime.nextDataID()
resp.DataID = req.DataID
@ -178,6 +182,8 @@ func (s *ServerCommon) handleInboundStreamOpen(msg *Message) {
}
transport := messageTransportConnSnapshot(msg)
scope := serverFileScope(logical)
req.FastPathVersion = negotiateStreamFastPathVersion(req.FastPathVersion)
resp.FastPathVersion = req.FastPathVersion
if req.DataID == 0 {
req.DataID = runtime.nextDataID()
resp.DataID = req.DataID

View File

@ -1,8 +1,10 @@
package notify
import (
"context"
"encoding/binary"
"errors"
"io"
)
var (
@ -15,6 +17,7 @@ const (
streamFastPayloadVersion = 1
streamFastPayloadTypeData = 1
streamFastPayloadHeaderLen = 28
streamFastBatchDirectLimit = 512 * 1024
)
type streamFastDataFrame struct {
@ -24,6 +27,56 @@ type streamFastDataFrame struct {
Payload []byte
}
// streamAdaptiveFramePayloadLimit computes the per-frame payload budget from
// a binding's adaptive soft cap: the cap minus the frame header, floored at 1
// and ceilinged at the wire-format maximum. A nil binding yields 0, meaning
// no limit is available.
func streamAdaptiveFramePayloadLimit(binding *transportBinding) int {
	if binding == nil {
		return 0
	}
	budget := binding.streamAdaptiveSoftPayloadBytesSnapshot() - streamFastPayloadHeaderLen
	switch ceiling := streamFastBatchMaxPlainBytes - streamFastPayloadHeaderLen; {
	case budget <= 0:
		return 1
	case budget > ceiling:
		return ceiling
	default:
		return budget
	}
}
// streamFastSplitFrameCount returns how many frames of at most maxPayload
// bytes are needed to carry size bytes; degenerate inputs (non-positive size
// or limit) count as a single frame.
func streamFastSplitFrameCount(size int, maxPayload int) int {
	if size <= 0 || maxPayload <= 0 {
		return 1
	}
	count := size / maxPayload
	if size%maxPayload != 0 {
		count++
	}
	return count
}
// buildStreamFastSplitFrames slices chunk into consecutive fast-path data
// frames of at most maxPayload bytes each, assigning sequence numbers from
// startSeq upward. Frame payloads alias chunk's backing array (no copies).
// An empty chunk yields nil; a non-positive maxPayload or an already-small
// chunk yields a single frame.
func buildStreamFastSplitFrames(dataID uint64, startSeq uint64, chunk []byte, maxPayload int) []streamFastDataFrame {
	if len(chunk) == 0 {
		return nil
	}
	if maxPayload <= 0 || len(chunk) <= maxPayload {
		single := streamFastDataFrame{
			DataID:  dataID,
			Seq:     startSeq,
			Payload: chunk,
		}
		return []streamFastDataFrame{single}
	}
	frames := make([]streamFastDataFrame, 0, streamFastSplitFrameCount(len(chunk), maxPayload))
	for offset, seq := 0, startSeq; offset < len(chunk); seq++ {
		next := offset + maxPayload
		if next > len(chunk) {
			next = len(chunk)
		}
		frames = append(frames, streamFastDataFrame{
			DataID:  dataID,
			Seq:     seq,
			Payload: chunk[offset:next],
		})
		offset = next
	}
	return frames
}
func encodeStreamFastDataFrameHeader(dst []byte, dataID uint64, seq uint64, payloadLen int) error {
if dataID == 0 {
return errStreamFastDataIDEmpty
@ -51,6 +104,31 @@ func encodeStreamFastDataFrame(dataID uint64, seq uint64, payload []byte) ([]byt
return frame, nil
}
// encodeStreamFastFramePayload serializes one fast-path data frame (header,
// flags byte, then payload) into a fresh plaintext buffer; the caller is
// responsible for encrypting it before transmission.
func encodeStreamFastFramePayload(frame streamFastDataFrame) ([]byte, error) {
	framePayload := make([]byte, streamFastPayloadHeaderLen+len(frame.Payload))
	if err := encodeStreamFastDataFrameHeader(framePayload, frame.DataID, frame.Seq, len(frame.Payload)); err != nil {
		return nil, err
	}
	// Header byte 6 carries the frame flags; assumed from the shared wire
	// layout — confirm against encodeStreamFastDataFrameHeader's definition.
	framePayload[6] = frame.Flags
	copy(framePayload[streamFastPayloadHeaderLen:], frame.Payload)
	return framePayload, nil
}

// encodeStreamFastFramePayloadFast serializes the same wire layout directly
// into an encoder-provided destination buffer, avoiding the intermediate
// plaintext allocation. A nil encoder reports payload encryption failure.
func encodeStreamFastFramePayloadFast(encode transportFastPlainEncoder, secretKey []byte, frame streamFastDataFrame) ([]byte, error) {
	if encode == nil {
		return nil, errTransportPayloadEncryptFailed
	}
	plainLen := streamFastPayloadHeaderLen + len(frame.Payload)
	return encode(secretKey, plainLen, func(dst []byte) error {
		if err := encodeStreamFastDataFrameHeader(dst, frame.DataID, frame.Seq, len(frame.Payload)); err != nil {
			return err
		}
		dst[6] = frame.Flags
		copy(dst[streamFastPayloadHeaderLen:], frame.Payload)
		return nil
	})
}
func decodeStreamFastDataFrame(payload []byte) (streamFastDataFrame, bool, error) {
if len(payload) < 4 || string(payload[:4]) != streamFastPayloadMagic {
return streamFastDataFrame{}, false, nil
@ -77,18 +155,66 @@ func decodeStreamFastDataFrame(payload []byte) (streamFastDataFrame, bool, error
}, true, nil
}
func (c *ClientCommon) encodeFastStreamDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
if c != nil && c.fastStreamEncode != nil {
return c.fastStreamEncode(c.SecretKey, dataID, seq, chunk)
func (c *ClientCommon) encodeFastStreamPayload(frame streamFastDataFrame) ([]byte, error) {
if c != nil && c.fastStreamEncode != nil && frame.Flags == 0 {
return c.fastStreamEncode(c.SecretKey, frame.DataID, frame.Seq, frame.Payload)
}
plain, err := encodeStreamFastDataFrame(dataID, seq, chunk)
if c != nil && c.fastPlainEncode != nil {
return encodeStreamFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame)
}
plain, err := encodeStreamFastFramePayload(frame)
if err != nil {
return nil, err
}
return c.encryptTransportPayload(plain)
}
func (c *ClientCommon) sendFastStreamData(dataID uint64, seq uint64, chunk []byte) error {
func (c *ClientCommon) encodeFastStreamDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
return c.encodeFastStreamPayload(streamFastDataFrame{
DataID: dataID,
Seq: seq,
Payload: chunk,
})
}
func (c *ClientCommon) encodeFastStreamBatchPayload(frames []streamFastDataFrame) ([]byte, error) {
if c == nil {
return nil, errStreamClientNil
}
if c.fastPlainEncode != nil {
return encodeStreamFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames)
}
plain, err := encodeStreamFastBatchPlain(frames)
if err != nil {
return nil, err
}
return c.encryptTransportPayload(plain)
}
func (c *ClientCommon) sendFastStreamData(ctx context.Context, stream *streamHandle, chunk []byte) error {
if stream == nil {
return io.ErrClosedPipe
}
dataID := stream.dataIDSnapshot()
fastPathVersion := stream.fastPathVersionSnapshot()
if binding := c.clientTransportBindingSnapshot(); binding != nil && streamFastPathSupportsBatch(fastPathVersion) {
if sender := binding.clientStreamBatchSenderSnapshot(c); sender != nil {
if maxPayload := streamAdaptiveFramePayloadLimit(binding); maxPayload > 0 && len(chunk) > maxPayload {
startSeq := stream.reserveOutboundDataSeqs(streamFastSplitFrameCount(len(chunk), maxPayload))
return sender.submitFrames(ctx, fastPathVersion, buildStreamFastSplitFrames(dataID, startSeq, chunk, maxPayload))
}
seq := stream.reserveOutboundDataSeqs(1)
if len(chunk) < streamFastBatchDirectLimit {
return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk)
}
payload, err := c.encodeFastStreamDataPayload(dataID, seq, chunk)
if err != nil {
return err
}
return sender.submitEncoded(ctx, fastPathVersion, payload)
}
}
seq := stream.reserveOutboundDataSeqs(1)
payload, err := c.encodeFastStreamDataPayload(dataID, seq, chunk)
if err != nil {
return err
@ -96,29 +222,78 @@ func (c *ClientCommon) sendFastStreamData(dataID uint64, seq uint64, chunk []byt
return c.writePayloadToTransport(payload)
}
func (s *ServerCommon) encodeFastStreamDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
if logical != nil {
if fastStreamEncode := logical.fastStreamEncodeSnapshot(); fastStreamEncode != nil {
return fastStreamEncode(logical.secretKeySnapshot(), dataID, seq, chunk)
}
func (s *ServerCommon) encodeFastStreamPayloadLogical(logical *LogicalConn, frame streamFastDataFrame) ([]byte, error) {
if logical == nil {
return nil, errTransportDetached
}
plain, err := encodeStreamFastDataFrame(dataID, seq, chunk)
if fastStreamEncode := logical.fastStreamEncodeSnapshot(); fastStreamEncode != nil && frame.Flags == 0 {
return fastStreamEncode(logical.secretKeySnapshot(), frame.DataID, frame.Seq, frame.Payload)
}
if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
return encodeStreamFastFramePayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frame)
}
plain, err := encodeStreamFastFramePayload(frame)
if err != nil {
return nil, err
}
return s.encryptTransportPayloadLogical(logical, plain)
}
func (s *ServerCommon) sendFastStreamDataTransport(logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte) error {
func (s *ServerCommon) encodeFastStreamDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
return s.encodeFastStreamPayloadLogical(logical, streamFastDataFrame{
DataID: dataID,
Seq: seq,
Payload: chunk,
})
}
func (s *ServerCommon) encodeFastStreamBatchPayloadLogical(logical *LogicalConn, frames []streamFastDataFrame) ([]byte, error) {
if logical == nil {
return nil, errTransportDetached
}
if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
return encodeStreamFastBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frames)
}
plain, err := encodeStreamFastBatchPlain(frames)
if err != nil {
return nil, err
}
return s.encryptTransportPayloadLogical(logical, plain)
}
func (s *ServerCommon) sendFastStreamDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, stream *streamHandle, chunk []byte) error {
if err := s.ensureServerTransportSendReady(transport); err != nil {
return err
}
if stream == nil {
return io.ErrClosedPipe
}
if logical == nil && transport != nil {
logical = transport.logicalConnSnapshot()
}
if logical == nil {
return errTransportDetached
}
dataID := stream.dataIDSnapshot()
fastPathVersion := stream.fastPathVersionSnapshot()
if binding := logical.transportBindingSnapshot(); binding != nil && binding.queueSnapshot() != nil && streamFastPathSupportsBatch(fastPathVersion) {
if sender := binding.serverStreamBatchSenderSnapshot(logical); sender != nil {
if maxPayload := streamAdaptiveFramePayloadLimit(binding); maxPayload > 0 && len(chunk) > maxPayload {
startSeq := stream.reserveOutboundDataSeqs(streamFastSplitFrameCount(len(chunk), maxPayload))
return sender.submitFrames(ctx, fastPathVersion, buildStreamFastSplitFrames(dataID, startSeq, chunk, maxPayload))
}
seq := stream.reserveOutboundDataSeqs(1)
if len(chunk) < streamFastBatchDirectLimit {
return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk)
}
payload, err := s.encodeFastStreamDataPayloadLogical(logical, dataID, seq, chunk)
if err != nil {
return err
}
return sender.submitEncoded(ctx, fastPathVersion, payload)
}
}
seq := stream.reserveOutboundDataSeqs(1)
payload, err := s.encodeFastStreamDataPayloadLogical(logical, dataID, seq, chunk)
if err != nil {
return err

View File

@ -4,6 +4,8 @@ import (
"b612.me/stario"
"context"
"math"
"sync"
"sync/atomic"
"testing"
"time"
)
@ -31,6 +33,46 @@ func TestStreamFastDataFrameRoundTrip(t *testing.T) {
}
}
// TestStreamFastBatchPlainRoundTrip verifies that a multi-frame plaintext
// batch survives an encode/decode round trip with data IDs, sequence numbers,
// and payloads intact, and that the decoder recognizes the batch format.
func TestStreamFastBatchPlainRoundTrip(t *testing.T) {
	frames := []streamFastDataFrame{
		{
			DataID:  11,
			Seq:     7,
			Payload: []byte("alpha"),
		},
		{
			DataID:  12,
			Seq:     8,
			Payload: []byte("beta"),
		},
	}
	wire, err := encodeStreamFastBatchPlain(frames)
	if err != nil {
		t.Fatalf("encodeStreamFastBatchPlain failed: %v", err)
	}
	decoded, matched, err := decodeStreamFastBatchPlain(wire)
	if err != nil {
		t.Fatalf("decodeStreamFastBatchPlain failed: %v", err)
	}
	if !matched {
		t.Fatal("decodeStreamFastBatchPlain should match encoded batch")
	}
	if got, want := len(decoded), len(frames); got != want {
		t.Fatalf("decoded frame count = %d, want %d", got, want)
	}
	// Compare field by field so a failure pinpoints the mismatching frame.
	for index := range frames {
		if got, want := decoded[index].DataID, frames[index].DataID; got != want {
			t.Fatalf("frame %d data id = %d, want %d", index, got, want)
		}
		if got, want := decoded[index].Seq, frames[index].Seq; got != want {
			t.Fatalf("frame %d seq = %d, want %d", index, got, want)
		}
		if got, want := string(decoded[index].Payload), string(frames[index].Payload); got != want {
			t.Fatalf("frame %d payload = %q, want %q", index, got, want)
		}
	}
}
func TestClientDispatchInboundTransportPayloadFastStream(t *testing.T) {
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
@ -60,6 +102,165 @@ func TestClientDispatchInboundTransportPayloadFastStream(t *testing.T) {
readStreamExactly(t, stream, "fast-payload", 2*time.Second)
}
// TestClientDispatchInboundTransportPayloadFastStreamBatch registers two
// fast-path streams with distinct data IDs, feeds a single encoded batch
// payload through the client's inbound dispatch, and checks each stream
// receives only its own frame.
func TestClientDispatchInboundTransportPayloadFastStreamBatch(t *testing.T) {
	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	runtime := client.getStreamRuntime()
	if runtime == nil {
		t.Fatal("client stream runtime should not be nil")
	}
	// Two handles on the same scope so the batch has to demultiplex by DataID.
	streamA := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
		StreamID:        "fast-client-a",
		DataID:          23,
		FastPathVersion: streamFastPathVersionCurrent,
		Channel:         StreamDataChannel,
	}, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot())
	streamB := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
		StreamID:        "fast-client-b",
		DataID:          24,
		FastPathVersion: streamFastPathVersionCurrent,
		Channel:         StreamDataChannel,
	}, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot())
	if err := runtime.register(clientFileScope(), streamA); err != nil {
		t.Fatalf("register streamA failed: %v", err)
	}
	if err := runtime.register(clientFileScope(), streamB); err != nil {
		t.Fatalf("register streamB failed: %v", err)
	}
	payload, err := client.encodeFastStreamBatchPayload([]streamFastDataFrame{
		{DataID: 23, Seq: 1, Payload: []byte("fast-a")},
		{DataID: 24, Seq: 2, Payload: []byte("fast-b")},
	})
	if err != nil {
		t.Fatalf("encodeFastStreamBatchPayload failed: %v", err)
	}
	if err := client.dispatchInboundTransportPayload(payload, time.Now()); err != nil {
		t.Fatalf("dispatchInboundTransportPayload failed: %v", err)
	}
	// Each stream must observe exactly its own payload within the deadline.
	readStreamExactly(t, streamA, "fast-a", 2*time.Second)
	readStreamExactly(t, streamB, "fast-b", 2*time.Second)
}
// TestStreamBatchSenderEncodeRequestsCoalescesFastV2Frames checks that two
// single-frame v2 requests are merged into one batch-encoded payload: the
// batch codec is invoked once with both frames and the single-frame codec is
// never used.
func TestStreamBatchSenderEncodeRequestsCoalescesFastV2Frames(t *testing.T) {
	var (
		singleCalls int
		batchCalls  [][]streamFastDataFrame
	)
	// Stub codec records how frames were routed (single vs. batch).
	sender := &streamBatchSender{
		codec: streamBatchCodec{
			encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
				singleCalls++
				return []byte("single"), nil
			},
			encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
				cloned := make([]streamFastDataFrame, len(frames))
				copy(cloned, frames)
				batchCalls = append(batchCalls, cloned)
				return []byte("batch"), nil
			},
		},
	}
	payloads, err := sender.encodeRequests([]streamBatchRequest{
		{
			frames: []streamFastDataFrame{{
				DataID:  101,
				Seq:     1,
				Payload: []byte("a"),
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			frames: []streamFastDataFrame{{
				DataID:  102,
				Seq:     2,
				Payload: []byte("b"),
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if got, want := len(payloads), 1; got != want {
		t.Fatalf("payload count = %d, want %d", got, want)
	}
	if singleCalls != 0 {
		t.Fatalf("single encode calls = %d, want 0", singleCalls)
	}
	if got, want := len(batchCalls), 1; got != want {
		t.Fatalf("batch encode calls = %d, want %d", got, want)
	}
	// Both frames must land in the one batch invocation.
	if got, want := len(batchCalls[0]), 2; got != want {
		t.Fatalf("batched frame count = %d, want %d", got, want)
	}
}
// TestStreamBatchSenderEncodeRequestsFlushesBeforePreEncodedPayload checks
// ordering around a pre-encoded request: the pending batch is flushed first,
// the raw payload passes through untouched, and a trailing lone frame is
// encoded via the single-frame codec.
func TestStreamBatchSenderEncodeRequestsFlushesBeforePreEncodedPayload(t *testing.T) {
	var (
		singleCalls int
		batchCalls  [][]streamFastDataFrame
	)
	sender := &streamBatchSender{
		codec: streamBatchCodec{
			encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
				singleCalls++
				// Tag the output so the test can tell which frame went single.
				return append([]byte("single-"), frame.Payload...), nil
			},
			encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
				cloned := make([]streamFastDataFrame, len(frames))
				copy(cloned, frames)
				batchCalls = append(batchCalls, cloned)
				return []byte("batch-2"), nil
			},
		},
	}
	payloads, err := sender.encodeRequests([]streamBatchRequest{
		{
			frames: []streamFastDataFrame{
				{DataID: 101, Seq: 1, Payload: []byte("a")},
				{DataID: 102, Seq: 2, Payload: []byte("b")},
			},
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			// Pre-encoded request: must be emitted as-is, after the flush.
			encodedPayload:  []byte("raw"),
			hasEncoded:      true,
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			frames: []streamFastDataFrame{
				{DataID: 103, Seq: 3, Payload: []byte("c")},
			},
			fastPathVersion: streamFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if got, want := len(payloads), 3; got != want {
		t.Fatalf("payload count = %d, want %d", got, want)
	}
	if got, want := string(payloads[0]), "batch-2"; got != want {
		t.Fatalf("first payload = %q, want %q", got, want)
	}
	if got, want := string(payloads[1]), "raw"; got != want {
		t.Fatalf("second payload = %q, want %q", got, want)
	}
	if got, want := string(payloads[2]), "single-c"; got != want {
		t.Fatalf("third payload = %q, want %q", got, want)
	}
	if got, want := len(batchCalls), 1; got != want {
		t.Fatalf("batch encode calls = %d, want %d", got, want)
	}
	if singleCalls != 1 {
		t.Fatalf("single encode calls = %d, want %d", singleCalls, 1)
	}
}
func TestClientPushMessageFastDispatchesDirectWithRuntimeDispatcher(t *testing.T) {
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
@ -110,3 +311,216 @@ func TestClientPushMessageFastDispatchesDirectWithRuntimeDispatcher(t *testing.T
default:
}
}
// TestStreamBatchSenderEncodeRequestsResetsBatchBytesAfterFlushBoundary fills
// the batch size budget with one near-limit frame, forces a flush boundary
// via a pre-encoded request, and then confirms the byte accumulator was reset:
// the two small trailing frames batch together instead of being pushed out
// individually by stale accounting.
func TestStreamBatchSenderEncodeRequestsResetsBatchBytesAfterFlushBoundary(t *testing.T) {
	var (
		singleCalls int
		batchCalls  [][]streamFastDataFrame
	)
	sender := &streamBatchSender{
		codec: streamBatchCodec{
			encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
				singleCalls++
				return []byte("single"), nil
			},
			encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
				cloned := make([]streamFastDataFrame, len(frames))
				copy(cloned, frames)
				batchCalls = append(batchCalls, cloned)
				return []byte("batch"), nil
			},
		},
	}
	// Just under the plain-batch cap, so it alone nearly exhausts the budget.
	largePayload := make([]byte, streamFastBatchMaxPlainBytes-streamFastBatchHeaderLen-streamFastBatchItemHeaderLen-128)
	payloads, err := sender.encodeRequests([]streamBatchRequest{
		{
			frames: []streamFastDataFrame{{
				DataID:  101,
				Seq:     1,
				Payload: largePayload,
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			encodedPayload:  []byte("raw"),
			hasEncoded:      true,
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			frames: []streamFastDataFrame{{
				DataID:  202,
				Seq:     1,
				Payload: []byte("a"),
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			frames: []streamFastDataFrame{{
				DataID:  202,
				Seq:     2,
				Payload: []byte("b"),
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if got, want := len(payloads), 3; got != want {
		t.Fatalf("payload count = %d, want %d", got, want)
	}
	// The oversized lone frame goes through the single-frame codec.
	if got, want := singleCalls, 1; got != want {
		t.Fatalf("single encode calls = %d, want %d", got, want)
	}
	if got, want := len(batchCalls), 1; got != want {
		t.Fatalf("batch encode calls = %d, want %d", got, want)
	}
	// Both post-flush frames share one batch — proves the reset happened.
	if got, want := len(batchCalls[0]), 2; got != want {
		t.Fatalf("post-flush batched frame count = %d, want %d", got, want)
	}
}
// TestStreamBatchSenderEncodeRequestsUsesAdaptiveSoftLimit shrinks the
// adaptive soft payload cap by reporting a slow large write to the binding,
// then checks that two 160 KiB frames no longer fit in a shared batch and are
// each encoded through the single-frame codec instead.
func TestStreamBatchSenderEncodeRequestsUsesAdaptiveSoftLimit(t *testing.T) {
	binding := &transportBinding{}
	// A 2 MiB write taking 640 ms should drive the soft limit down.
	binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil)
	var (
		singleCalls int
		batchCalls  int
	)
	sender := &streamBatchSender{
		binding: binding,
		codec: streamBatchCodec{
			encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
				singleCalls++
				return []byte("single"), nil
			},
			encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
				batchCalls++
				return []byte("batch"), nil
			},
		},
	}
	payload := make([]byte, 160*1024)
	payloads, err := sender.encodeRequests([]streamBatchRequest{
		{
			frames: []streamFastDataFrame{{
				DataID:  101,
				Seq:     1,
				Payload: payload,
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			frames: []streamFastDataFrame{{
				DataID:  102,
				Seq:     2,
				Payload: payload,
			}},
			fastPathVersion: streamFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("encodeRequests failed: %v", err)
	}
	if got, want := len(payloads), 2; got != want {
		t.Fatalf("payload count = %d, want %d", got, want)
	}
	if got, want := singleCalls, 2; got != want {
		t.Fatalf("single encode calls = %d, want %d", got, want)
	}
	if batchCalls != 0 {
		t.Fatalf("batch encode calls = %d, want 0", batchCalls)
	}
}
// TestClientSendFastStreamDataSplitsLargeChunkWhenAdaptiveSoftLimitShrinks
// verifies the fast-stream send path honors a shrunken adaptive soft limit:
// one oversized chunk is split into the expected number of sequential frames
// (contiguous Seq values) whose payloads reassemble to the original bytes,
// and no batch encoding is attempted.
func TestClientSendFastStreamDataSplitsLargeChunkWhenAdaptiveSoftLimitShrinks(t *testing.T) {
	binding := newTransportBinding(&delayedWriteConn{}, stario.NewQueue())
	// Report a slow 2 MiB write so the adaptive soft payload cap shrinks.
	binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil)
	var (
		mu           sync.Mutex
		singleFrames []streamFastDataFrame
		batchFrames  [][]streamFastDataFrame
	)
	// Capturing codec: deep-copies frames since the sender may reuse buffers.
	binding.streamSender = newStreamBatchSender(binding, streamBatchCodec{
		encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
			mu.Lock()
			singleFrames = append(singleFrames, streamFastDataFrame{
				DataID:  frame.DataID,
				Seq:     frame.Seq,
				Payload: append([]byte(nil), frame.Payload...),
			})
			mu.Unlock()
			return []byte{1}, nil
		},
		encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
			cloned := make([]streamFastDataFrame, len(frames))
			for index := range frames {
				cloned[index] = streamFastDataFrame{
					DataID:  frames[index].DataID,
					Seq:     frames[index].Seq,
					Payload: append([]byte(nil), frames[index].Payload...),
				}
			}
			mu.Lock()
			batchFrames = append(batchFrames, cloned)
			mu.Unlock()
			return []byte{2}, nil
		},
	}, nil)
	defer binding.stopBackgroundWorkers()
	client := NewClient().(*ClientCommon)
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()
	client.setClientSessionRuntime(&clientSessionRuntime{
		transport:             binding,
		transportAttached:     true,
		stopCtx:               stopCtx,
		stopFn:                stopFn,
		queue:                 binding.queueSnapshot(),
		inboundDispatcher:     newInboundDispatcher(),
		suppressGoodByeOnStop: &atomic.Bool{},
	})
	stream := &streamHandle{
		dataID:          41,
		fastPathVersion: streamFastPathVersionV2,
	}
	// Larger than the direct-send limit, so splitting is mandatory.
	chunk := make([]byte, streamFastBatchDirectLimit+128*1024)
	for index := range chunk {
		chunk[index] = byte(index)
	}
	if err := client.sendFastStreamData(context.Background(), stream, chunk); err != nil {
		t.Fatalf("sendFastStreamData failed: %v", err)
	}
	expectedFrames := streamFastSplitFrameCount(len(chunk), streamAdaptiveFramePayloadLimit(binding))
	if got, want := int(stream.outboundSeq.Load()), expectedFrames; got != want {
		t.Fatalf("reserved seq count = %d, want %d", got, want)
	}
	mu.Lock()
	defer mu.Unlock()
	if len(batchFrames) != 0 {
		t.Fatalf("batch encode calls = %d, want 0", len(batchFrames))
	}
	if got, want := len(singleFrames), expectedFrames; got != want {
		t.Fatalf("single frame count = %d, want %d", got, want)
	}
	// Reassemble payloads in frame order; Seq must be 1..expectedFrames.
	rebuilt := make([]byte, 0, len(chunk))
	for index, frame := range singleFrames {
		if got, want := frame.DataID, uint64(41); got != want {
			t.Fatalf("frame %d data id = %d, want %d", index, got, want)
		}
		if got, want := frame.Seq, uint64(index+1); got != want {
			t.Fatalf("frame %d seq = %d, want %d", index, got, want)
		}
		rebuilt = append(rebuilt, frame.Payload...)
	}
	if string(rebuilt) != string(chunk) {
		t.Fatal("rebuilt payload does not match original chunk")
	}
}

View File

@ -3,15 +3,19 @@ package notify
import (
"context"
"sync"
"sync/atomic"
)
type streamFlowController struct {
mu sync.Mutex
queue []*streamFlowRequest
inFlightBytes int
inFlightChunks int
windowBytes int
maxChunks int
mu sync.Mutex
queue []*streamFlowRequest
inFlightBytes atomic.Int64
inFlightChunks atomic.Int64
windowBytes atomic.Int64
maxChunks atomic.Int64
waiters atomic.Int32
}
type streamFlowRequest struct {
@ -22,10 +26,10 @@ type streamFlowRequest struct {
func newStreamFlowController(cfg streamConfig) *streamFlowController {
cfg = normalizeStreamConfig(cfg)
return &streamFlowController{
windowBytes: cfg.OutboundWindowBytes,
maxChunks: cfg.OutboundMaxInFlightChunks,
}
controller := &streamFlowController{}
controller.windowBytes.Store(int64(cfg.OutboundWindowBytes))
controller.maxChunks.Store(int64(cfg.OutboundMaxInFlightChunks))
return controller
}
func (c *streamFlowController) applyConfig(cfg streamConfig) {
@ -33,9 +37,12 @@ func (c *streamFlowController) applyConfig(cfg streamConfig) {
return
}
cfg = normalizeStreamConfig(cfg)
c.windowBytes.Store(int64(cfg.OutboundWindowBytes))
c.maxChunks.Store(int64(cfg.OutboundMaxInFlightChunks))
if c.waiters.Load() == 0 {
return
}
c.mu.Lock()
c.windowBytes = cfg.OutboundWindowBytes
c.maxChunks = cfg.OutboundMaxInFlightChunks
c.drainLocked()
c.mu.Unlock()
}
@ -47,58 +54,32 @@ func (c *streamFlowController) acquire(ctx context.Context, size int) (func(), e
if ctx == nil {
ctx = context.Background()
}
if c.tryAcquire(size) {
return c.releaseFunc(size), nil
}
req := &streamFlowRequest{
size: size,
ready: make(chan struct{}),
}
c.mu.Lock()
if c.tryAcquireLocked(size) {
c.mu.Unlock()
return c.releaseFunc(size), nil
}
c.queue = append(c.queue, req)
c.waiters.Add(1)
c.drainLocked()
c.mu.Unlock()
select {
case <-req.ready:
released := false
return func() {
c.mu.Lock()
if released {
c.mu.Unlock()
return
}
released = true
c.inFlightBytes -= size
if c.inFlightBytes < 0 {
c.inFlightBytes = 0
}
if c.inFlightChunks > 0 {
c.inFlightChunks--
}
c.drainLocked()
c.mu.Unlock()
}, nil
return c.releaseFunc(size), nil
case <-ctx.Done():
c.mu.Lock()
if req.admitted {
c.mu.Unlock()
released := false
return func() {
c.mu.Lock()
if released {
c.mu.Unlock()
return
}
released = true
c.inFlightBytes -= size
if c.inFlightBytes < 0 {
c.inFlightBytes = 0
}
if c.inFlightChunks > 0 {
c.inFlightChunks--
}
c.drainLocked()
c.mu.Unlock()
}, nil
return c.releaseFunc(size), nil
}
c.removeLocked(req)
c.drainLocked()
@ -107,6 +88,128 @@ func (c *streamFlowController) acquire(ctx context.Context, size int) (func(), e
}
}
// tryAcquire is the lock-free fast path for admitting an outbound chunk of
// size bytes. It refuses whenever any waiter is queued so that queued
// requests keep strict FIFO priority over newcomers.
func (c *streamFlowController) tryAcquire(size int) bool {
	switch {
	case c == nil, size <= 0:
		// Nothing to account for; trivially admitted.
		return true
	case c.waiters.Load() != 0:
		// Never jump ahead of a queued waiter.
		return false
	default:
		return c.tryAcquireCAS(size)
	}
}

// tryAcquireLocked is the same admission check for callers that already hold
// c.mu; it defers to the queue length (authoritative under the lock) instead
// of the waiters counter.
func (c *streamFlowController) tryAcquireLocked(size int) bool {
	switch {
	case c == nil, size <= 0:
		return true
	case len(c.queue) != 0:
		// Queued requests are served first.
		return false
	default:
		return c.tryAcquireCAS(size)
	}
}
// tryAcquireCAS attempts to reserve size bytes plus one chunk slot using only
// atomic compare-and-swap, with no mutex. Returns true when both the byte and
// chunk reservations succeed; on a partial success the byte reservation is
// rolled back before returning false.
func (c *streamFlowController) tryAcquireCAS(size int) bool {
	if c == nil || size <= 0 {
		return true
	}
	size64 := int64(size)
	for {
		// Snapshot limits and usage; stale reads are corrected by the CAS retry.
		window := c.windowBytes.Load()
		maxChunks := c.maxChunks.Load()
		inFlightBytes := c.inFlightBytes.Load()
		inFlightChunks := c.inFlightChunks.Load()
		if maxChunks > 0 && inFlightChunks >= maxChunks {
			return false
		}
		if window > 0 && inFlightBytes+size64 > window {
			// An over-window request is still admitted when the controller is
			// completely idle, so a single chunk larger than the window can
			// make progress instead of stalling forever.
			if !(inFlightBytes == 0 && inFlightChunks == 0) {
				return false
			}
		}
		if !c.inFlightBytes.CompareAndSwap(inFlightBytes, inFlightBytes+size64) {
			// Lost the race on bytes; re-read everything and retry.
			continue
		}
		if c.addChunksCAS(1, maxChunks) {
			return true
		}
		// Chunk slot was lost after bytes were reserved: undo the byte
		// reservation so accounting stays balanced.
		c.subBytesCAS(size64)
		return false
	}
}
// addChunksCAS atomically raises the in-flight chunk count by delta, failing
// (without any change) when the result would exceed maxChunks (> 0 means the
// cap is enforced; <= 0 means unlimited).
func (c *streamFlowController) addChunksCAS(delta int64, maxChunks int64) bool {
	if c == nil || delta <= 0 {
		return true
	}
	for {
		prev := c.inFlightChunks.Load()
		next := prev + delta
		if maxChunks > 0 && next > maxChunks {
			return false
		}
		if c.inFlightChunks.CompareAndSwap(prev, next) {
			return true
		}
		// Lost the race; reload and try again.
	}
}
// subBytesCAS atomically lowers the in-flight byte count by delta, clamping
// the result at zero so double-releases never drive it negative.
func (c *streamFlowController) subBytesCAS(delta int64) {
	if c == nil || delta <= 0 {
		return
	}
	for {
		prev := c.inFlightBytes.Load()
		next := prev - delta
		if next < 0 {
			next = 0
		}
		if c.inFlightBytes.CompareAndSwap(prev, next) {
			return
		}
	}
}

// subChunksCAS atomically lowers the in-flight chunk count by delta, clamped
// at zero, mirroring subBytesCAS for the chunk dimension.
func (c *streamFlowController) subChunksCAS(delta int64) {
	if c == nil || delta <= 0 {
		return
	}
	for {
		prev := c.inFlightChunks.Load()
		next := prev - delta
		if next < 0 {
			next = 0
		}
		if c.inFlightChunks.CompareAndSwap(prev, next) {
			return
		}
	}
}
// releaseFunc wraps release(size) so the returned closure is idempotent:
// only the first invocation gives the reservation back.
// NOTE(review): the done flag is a plain bool, so the closure itself is not
// safe for concurrent calls — presumably each reservation is released from a
// single goroutine; confirm at call sites.
func (c *streamFlowController) releaseFunc(size int) func() {
	done := false
	return func() {
		if done {
			return
		}
		done = true
		c.release(size)
	}
}
// release returns a size-byte, one-chunk reservation to the controller and,
// only when waiters are queued, takes the mutex to drain the wait queue.
// The waiter check keeps the common (no-waiter) release completely lock-free.
func (c *streamFlowController) release(size int) {
	if c == nil || size <= 0 {
		return
	}
	c.subBytesCAS(int64(size))
	c.subChunksCAS(1)
	if c.waiters.Load() == 0 {
		return
	}
	c.mu.Lock()
	c.drainLocked()
	c.mu.Unlock()
}
func (c *streamFlowController) removeLocked(req *streamFlowRequest) {
if c == nil || req == nil {
return
@ -118,6 +221,7 @@ func (c *streamFlowController) removeLocked(req *streamFlowRequest) {
copy(c.queue[i:], c.queue[i+1:])
c.queue[len(c.queue)-1] = nil
c.queue = c.queue[:len(c.queue)-1]
c.waiters.Add(-1)
return
}
}
@ -132,18 +236,17 @@ func (c *streamFlowController) drainLocked() {
c.queue = c.queue[1:]
continue
}
if c.maxChunks > 0 && c.inFlightChunks >= c.maxChunks {
if !c.canAdmitLocked(req.size) {
return
}
if !c.canAdmitLocked(req.size) {
if !c.tryAcquireCAS(req.size) {
return
}
copy(c.queue[0:], c.queue[1:])
c.queue[len(c.queue)-1] = nil
c.queue = c.queue[:len(c.queue)-1]
c.waiters.Add(-1)
req.admitted = true
c.inFlightBytes += req.size
c.inFlightChunks++
close(req.ready)
}
}
@ -155,11 +258,18 @@ func (c *streamFlowController) canAdmitLocked(size int) bool {
if size <= 0 {
return true
}
if c.windowBytes <= 0 {
window := c.windowBytes.Load()
chunks := c.inFlightChunks.Load()
bytes := c.inFlightBytes.Load()
maxChunks := c.maxChunks.Load()
if maxChunks > 0 && chunks >= maxChunks {
return false
}
if window <= 0 {
return true
}
if c.inFlightBytes+size <= c.windowBytes {
if bytes+int64(size) <= window {
return true
}
return c.inFlightBytes == 0 && c.inFlightChunks == 0
return bytes == 0 && chunks == 0
}

View File

@ -105,3 +105,51 @@ func TestStreamFlowControllerAdmitsRequestsFIFO(t *testing.T) {
t.Fatalf("second admitted request = %d, want 3", second)
}
}
// TestStreamFlowControllerTryAcquireDoesNotBypassQueuedWaiter saturates a
// one-chunk window, parks a blocking waiter, and checks that the lock-free
// tryAcquire fast path yields to that waiter instead of stealing the capacity
// freed by the first release.
func TestStreamFlowControllerTryAcquireDoesNotBypassQueuedWaiter(t *testing.T) {
	controller := newStreamFlowController(streamConfig{
		ChunkSize:                 4,
		InboundQueueLimit:         1,
		InboundBufferedBytesLimit: 4,
		OutboundWindowBytes:       4,
		OutboundMaxInFlightChunks: 1,
	})
	releaseFirst, err := controller.acquire(context.Background(), 4)
	if err != nil {
		t.Fatalf("first acquire failed: %v", err)
	}
	defer releaseFirst()
	waiterReady := make(chan struct{})
	waiterAcquired := make(chan func(), 1)
	go func() {
		close(waiterReady)
		// Blocks: the window and chunk cap are both exhausted.
		releaseSecond, err := controller.acquire(context.Background(), 4)
		if err != nil {
			t.Errorf("waiter acquire failed: %v", err)
			return
		}
		waiterAcquired <- releaseSecond
	}()
	<-waiterReady
	// Small sleep to let the goroutine actually enqueue itself.
	time.Sleep(20 * time.Millisecond)
	if controller.tryAcquire(4) {
		t.Fatal("tryAcquire should not bypass queued waiter")
	}
	releaseFirst()
	var releaseSecond func()
	select {
	case releaseSecond = <-waiterAcquired:
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for queued waiter to acquire")
	}
	if releaseSecond == nil {
		t.Fatal("queued waiter returned nil release func")
	}
	releaseSecond()
}

View File

@ -3,7 +3,6 @@ package notify
import (
"context"
"fmt"
"strconv"
"strings"
"sync"
"sync/atomic"
@ -17,7 +16,7 @@ type streamRuntime struct {
mu sync.RWMutex
handler func(StreamAcceptInfo) error
streams map[string]*streamHandle
data map[string]*streamHandle
data map[string]map[uint64]*streamHandle
cfg streamConfig
flow *streamFlowController
}
@ -27,7 +26,7 @@ func newStreamRuntime(rolePrefix string) *streamRuntime {
return &streamRuntime{
rolePrefix: rolePrefix,
streams: make(map[string]*streamHandle),
data: make(map[string]*streamHandle),
data: make(map[string]map[uint64]*streamHandle),
cfg: cfg,
flow: newStreamFlowController(cfg),
}
@ -72,18 +71,23 @@ func (r *streamRuntime) register(scope string, stream *streamHandle) error {
if stream == nil || stream.id == "" {
return errStreamIDEmpty
}
scope = normalizeFileScope(scope)
key := streamRuntimeKey(scope, stream.id)
dataKey := streamRuntimeDataKey(scope, stream.dataID)
r.mu.Lock()
defer r.mu.Unlock()
if _, ok := r.streams[key]; ok {
return errStreamAlreadyExists
}
if stream.dataID != 0 {
if _, ok := r.data[dataKey]; ok {
dataScope := r.data[scope]
if dataScope == nil {
dataScope = make(map[uint64]*streamHandle)
r.data[scope] = dataScope
}
if _, ok := dataScope[stream.dataID]; ok {
return errStreamAlreadyExists
}
r.data[dataKey] = stream
dataScope[stream.dataID] = stream
}
r.streams[key] = stream
return nil
@ -104,10 +108,14 @@ func (r *streamRuntime) lookupByDataID(scope string, dataID uint64) (*streamHand
if r == nil || dataID == 0 {
return nil, false
}
key := streamRuntimeDataKey(scope, dataID)
scope = normalizeFileScope(scope)
r.mu.RLock()
defer r.mu.RUnlock()
stream, ok := r.data[key]
dataScope := r.data[scope]
if dataScope == nil {
return nil, false
}
stream, ok := dataScope[dataID]
return stream, ok
}
@ -115,11 +123,17 @@ func (r *streamRuntime) remove(scope string, streamID string) {
if r == nil || streamID == "" {
return
}
scope = normalizeFileScope(scope)
key := streamRuntimeKey(scope, streamID)
r.mu.Lock()
defer r.mu.Unlock()
if stream := r.streams[key]; stream != nil && stream.dataID != 0 {
delete(r.data, streamRuntimeDataKey(scope, stream.dataID))
if dataScope := r.data[scope]; dataScope != nil {
delete(dataScope, stream.dataID)
if len(dataScope) == 0 {
delete(r.data, scope)
}
}
}
delete(r.streams, key)
}
@ -131,6 +145,20 @@ func (r *streamRuntime) acquireOutbound(ctx context.Context, size int) (func(),
return r.flow.acquire(ctx, size)
}
// tryAcquireOutbound forwards the non-blocking admission check to the flow
// controller; a runtime without one imposes no limit.
func (r *streamRuntime) tryAcquireOutbound(size int) bool {
	if r != nil && r.flow != nil {
		return r.flow.tryAcquire(size)
	}
	return true
}

// releaseOutbound returns a previously acquired outbound reservation; it is
// a no-op when no flow controller is configured.
func (r *streamRuntime) releaseOutbound(size int) {
	if r != nil && r.flow != nil {
		r.flow.release(size)
	}
}
func (r *streamRuntime) snapshots() []StreamSnapshot {
if r == nil {
return nil
@ -182,10 +210,6 @@ func streamRuntimeKey(scope string, streamID string) string {
return normalizeFileScope(scope) + "\x00" + streamID
}
func streamRuntimeDataKey(scope string, dataID uint64) string {
return normalizeFileScope(scope) + "\x01" + strconv.FormatUint(dataID, 10)
}
func (c *ClientCommon) getStreamRuntime() *streamRuntime {
if c == nil {
return nil

145
stream_shared_batch.go Normal file
View File

@ -0,0 +1,145 @@
package notify
import "encoding/binary"
// Fast-path protocol versions. V2 adds batch framing on top of V1's
// single-frame payloads.
const (
	streamFastPathVersionV1      = 1
	streamFastPathVersionV2      = 2
	streamFastPathVersionCurrent = streamFastPathVersionV2
)

// Wire constants for the plain (pre-encryption) batch format:
// a 12-byte header ("NSB1", version, frame count) followed by items that each
// carry a 24-byte header plus their payload.
const (
	streamFastBatchMagic         = "NSB1"
	streamFastBatchVersion       = 1
	streamFastBatchHeaderLen     = 12
	streamFastBatchItemHeaderLen = 24
	streamFastBatchMaxItems      = 64
	streamFastBatchMaxPlainBytes = 8 * 1024 * 1024
)

// normalizeStreamFastPathVersion clamps version into the supported
// [V1, Current] range.
func normalizeStreamFastPathVersion(version uint8) uint8 {
	switch {
	case version < streamFastPathVersionV1:
		return streamFastPathVersionV1
	case version > streamFastPathVersionCurrent:
		return streamFastPathVersionCurrent
	default:
		return version
	}
}

// negotiateStreamFastPathVersion picks the effective version for a peer's
// advertised value; today negotiation is a plain clamp.
func negotiateStreamFastPathVersion(version uint8) uint8 {
	return normalizeStreamFastPathVersion(version)
}

// streamFastPathSupportsBatch reports whether the (normalized) version can
// carry batch frames, i.e. is V2 or newer.
func streamFastPathSupportsBatch(version uint8) bool {
	return normalizeStreamFastPathVersion(version) >= streamFastPathVersionV2
}
// streamFastBatchFrameLen is the wire size of one batch item: its fixed item
// header plus the payload bytes.
func streamFastBatchFrameLen(frame streamFastDataFrame) int {
	return len(frame.Payload) + streamFastBatchItemHeaderLen
}

// streamFastBatchPlainLen is the total plain wire size of a batch: the batch
// header plus every item's frame length.
func streamFastBatchPlainLen(frames []streamFastDataFrame) int {
	total := streamFastBatchHeaderLen
	for index := range frames {
		total += streamFastBatchFrameLen(frames[index])
	}
	return total
}
// encodeStreamFastBatchPlain serializes frames into a freshly allocated plain
// batch buffer. An empty frame list is rejected as an invalid payload.
func encodeStreamFastBatchPlain(frames []streamFastDataFrame) ([]byte, error) {
	if len(frames) == 0 {
		return nil, errStreamFastPayloadInvalid
	}
	wire := make([]byte, streamFastBatchPlainLen(frames))
	if err := writeStreamFastBatchPlain(wire, frames); err != nil {
		return nil, err
	}
	return wire, nil
}
// encodeStreamFastBatchPayloadFast serializes frames through a caller-supplied
// fast plain encoder, letting the encoder allocate the (typically pooled and
// encrypted) destination of exactly plainLen bytes and writing the batch
// directly into it — avoiding an intermediate plain buffer.
func encodeStreamFastBatchPayloadFast(encode transportFastPlainEncoder, secretKey []byte, frames []streamFastDataFrame) ([]byte, error) {
	if encode == nil {
		return nil, errTransportPayloadEncryptFailed
	}
	plainLen := streamFastBatchPlainLen(frames)
	return encode(secretKey, plainLen, func(dst []byte) error {
		// dst is guaranteed by the encoder contract to be plainLen bytes.
		return writeStreamFastBatchPlain(dst, frames)
	})
}
// writeStreamFastBatchPlain serializes frames into dst, which must be exactly
// streamFastBatchPlainLen(frames) bytes. Layout (big-endian):
//
//	header: magic "NSB1" [0:4], version [4], bytes 5-7 left zero (reserved),
//	        frame count uint32 [8:12]
//	item:   flags [0], bytes 1-3 left zero (reserved), DataID uint64 [4:12],
//	        Seq uint64 [12:20], payload length uint32 [20:24], then payload
//
// A frame with DataID 0 is rejected as invalid.
func writeStreamFastBatchPlain(dst []byte, frames []streamFastDataFrame) error {
	if len(frames) == 0 || len(dst) != streamFastBatchPlainLen(frames) {
		return errStreamFastPayloadInvalid
	}
	copy(dst[:4], streamFastBatchMagic)
	dst[4] = streamFastBatchVersion
	binary.BigEndian.PutUint32(dst[8:12], uint32(len(frames)))
	offset := streamFastBatchHeaderLen
	for _, frame := range frames {
		if frame.DataID == 0 {
			// DataID 0 is the "unset" sentinel; never valid on the wire.
			return errStreamFastPayloadInvalid
		}
		dst[offset] = frame.Flags
		binary.BigEndian.PutUint64(dst[offset+4:offset+12], frame.DataID)
		binary.BigEndian.PutUint64(dst[offset+12:offset+20], frame.Seq)
		binary.BigEndian.PutUint32(dst[offset+20:offset+24], uint32(len(frame.Payload)))
		offset += streamFastBatchItemHeaderLen
		copy(dst[offset:offset+len(frame.Payload)], frame.Payload)
		offset += len(frame.Payload)
	}
	return nil
}
// decodeStreamFastBatchPlain parses a plain batch payload produced by
// writeStreamFastBatchPlain. It returns (frames, true, nil) on success,
// (nil, false, nil) when the magic does not match (i.e. the payload is not a
// batch at all), and (nil, true, err) for a recognized but malformed batch.
// Returned frame payloads alias the input slice; callers must not mutate
// payload while frames are in use.
func decodeStreamFastBatchPlain(payload []byte) ([]streamFastDataFrame, bool, error) {
	if len(payload) < 4 || string(payload[:4]) != streamFastBatchMagic {
		return nil, false, nil
	}
	if len(payload) < streamFastBatchHeaderLen {
		return nil, true, errStreamFastPayloadInvalid
	}
	if payload[4] != streamFastBatchVersion {
		return nil, true, errStreamFastPayloadInvalid
	}
	count := int(binary.BigEndian.Uint32(payload[8:12]))
	if count <= 0 {
		return nil, true, errStreamFastPayloadInvalid
	}
	// Every item consumes at least its fixed header, so a count larger than
	// the remaining bytes can possibly hold is guaranteed to fail parsing.
	// Rejecting here (with the same error the loop would produce) prevents an
	// attacker-controlled uint32 count from driving a multi-gigabyte make().
	if count > (len(payload)-streamFastBatchHeaderLen)/streamFastBatchItemHeaderLen {
		return nil, true, errStreamFastPayloadInvalid
	}
	frames := make([]streamFastDataFrame, 0, count)
	offset := streamFastBatchHeaderLen
	for index := 0; index < count; index++ {
		if len(payload)-offset < streamFastBatchItemHeaderLen {
			return nil, true, errStreamFastPayloadInvalid
		}
		flags := payload[offset]
		dataID := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
		seq := binary.BigEndian.Uint64(payload[offset+12 : offset+20])
		payloadLen := int(binary.BigEndian.Uint32(payload[offset+20 : offset+24]))
		offset += streamFastBatchItemHeaderLen
		if dataID == 0 || payloadLen < 0 || len(payload)-offset < payloadLen {
			return nil, true, errStreamFastPayloadInvalid
		}
		frames = append(frames, streamFastDataFrame{
			Flags:   flags,
			DataID:  dataID,
			Seq:     seq,
			// Zero-copy view into the input buffer.
			Payload: payload[offset : offset+payloadLen],
		})
		offset += payloadLen
	}
	// Trailing garbage after the declared frames is invalid.
	if offset != len(payload) {
		return nil, true, errStreamFastPayloadInvalid
	}
	return frames, true, nil
}
// decodeStreamFastDataFrames decodes either wire form into a uniform frame
// slice: a batch payload yields its frames, while a single-frame payload is
// wrapped in a one-element slice. The matched flag mirrors the underlying
// decoders' magic detection.
func decodeStreamFastDataFrames(payload []byte) ([]streamFastDataFrame, bool, error) {
	// Batch format first: its magic is unambiguous.
	if frames, matched, err := decodeStreamFastBatchPlain(payload); matched {
		return frames, true, err
	}
	single, matched, err := decodeStreamFastDataFrame(payload)
	if err != nil || !matched {
		return nil, matched, err
	}
	return []streamFastDataFrame{single}, true, nil
}

View File

@ -7,48 +7,52 @@ import (
)
type StreamSnapshot struct {
ID string
DataID uint64
Scope string
Channel StreamChannel
Metadata StreamMetadata
BindingOwner string
BindingAlive bool
BindingCurrent bool
BindingReason string
BindingError string
SessionEpoch uint64
LogicalClientID string
LocalAddress string
RemoteAddress string
TransportGeneration uint64
TransportAttached bool
TransportHasRuntimeConn bool
TransportCurrent bool
TransportDetachReason string
TransportDetachKind string
TransportDetachGeneration uint64
TransportDetachError string
TransportDetachedAt time.Time
ReattachEligible bool
LocalClosed bool
LocalReadClosed bool
RemoteClosed bool
PeerReadClosed bool
BufferedChunks int
BufferedBytes int
ReadTimeout time.Duration
WriteTimeout time.Duration
BytesRead int64
BytesWritten int64
ReadCalls int64
WriteCalls int64
OpenedAt time.Time
LastReadAt time.Time
LastWriteAt time.Time
ReadDeadline time.Time
WriteDeadline time.Time
ResetError string
ID string
DataID uint64
Scope string
Channel StreamChannel
Metadata StreamMetadata
BindingOwner string
BindingAlive bool
BindingCurrent bool
BindingReason string
BindingError string
BindingBulkAdaptiveSoftPayloadBytes int
BindingStreamAdaptiveSoftPayloadBytes int
BindingStreamAdaptiveWaitThresholdBytes int
BindingStreamAdaptiveFlushDelay time.Duration
SessionEpoch uint64
LogicalClientID string
LocalAddress string
RemoteAddress string
TransportGeneration uint64
TransportAttached bool
TransportHasRuntimeConn bool
TransportCurrent bool
TransportDetachReason string
TransportDetachKind string
TransportDetachGeneration uint64
TransportDetachError string
TransportDetachedAt time.Time
ReattachEligible bool
LocalClosed bool
LocalReadClosed bool
RemoteClosed bool
PeerReadClosed bool
BufferedChunks int
BufferedBytes int
ReadTimeout time.Duration
WriteTimeout time.Duration
BytesRead int64
BytesWritten int64
ReadCalls int64
WriteCalls int64
OpenedAt time.Time
LastReadAt time.Time
LastWriteAt time.Time
ReadDeadline time.Time
WriteDeadline time.Time
ResetError string
}
type clientStreamSnapshotReader interface {

View File

@ -15,9 +15,14 @@ type transportBinding struct {
queue *stario.StarQueue
writeMu sync.Mutex
adaptiveTx adaptiveTxState
controlMu sync.Mutex
controlSender *controlBatchSender
streamMu sync.Mutex
streamSender *streamBatchSender
bulkMu sync.Mutex
bulkSender *bulkBatchSender
}
@ -71,7 +76,7 @@ func (b *transportBinding) withConnWriteLockDeadline(deadline time.Time, fn func
return fn(conn)
}
func (b *transportBinding) bulkBatchSenderSnapshot() *bulkBatchSender {
func (b *transportBinding) bulkBatchSenderSnapshotWithCodec(codec bulkBatchCodec, writeTimeout func() time.Duration) *bulkBatchSender {
if b == nil {
return nil
}
@ -80,10 +85,39 @@ func (b *transportBinding) bulkBatchSenderSnapshot() *bulkBatchSender {
if b.bulkSender != nil {
return b.bulkSender
}
b.bulkSender = newBulkBatchSender(b)
b.bulkSender = newBulkBatchSender(b, codec, writeTimeout)
return b.bulkSender
}
// clientBulkBatchSenderSnapshot lazily materializes the binding's bulk batch
// sender wired to the client's pooled bulk encoders and its write-timeout
// snapshot.
func (b *transportBinding) clientBulkBatchSenderSnapshot(c *ClientCommon) *bulkBatchSender {
	if b == nil || c == nil {
		return nil
	}
	codec := bulkBatchCodec{
		encodeSingle: c.encodeBulkFastPayloadPooled,
		encodeBatch:  c.encodeBulkFastBatchPayloadPooled,
	}
	return b.bulkBatchSenderSnapshotWithCodec(codec, c.maxWriteTimeoutSnapshot)
}

// serverBulkBatchSenderSnapshot is the server-side counterpart: the codec
// closes over the logical connection so the pooled encoders run with its
// session keys; it bails out when the logical conn's server is not a
// *ServerCommon.
func (b *transportBinding) serverBulkBatchSenderSnapshot(logical *LogicalConn) *bulkBatchSender {
	if b == nil || logical == nil {
		return nil
	}
	common, ok := logical.Server().(*ServerCommon)
	if !ok || common == nil {
		return nil
	}
	codec := bulkBatchCodec{
		encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) {
			return common.encodeBulkFastPayloadLogicalPooled(logical, frame)
		},
		encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) {
			return common.encodeBulkFastBatchPayloadLogicalPooled(logical, frames)
		},
	}
	return b.bulkBatchSenderSnapshotWithCodec(codec, logical.maxWriteTimeoutSnapshot)
}
func (b *transportBinding) controlBatchSenderSnapshot() *controlBatchSender {
if b == nil {
return nil
@ -97,6 +131,48 @@ func (b *transportBinding) controlBatchSenderSnapshot() *controlBatchSender {
return b.controlSender
}
// streamBatchSenderSnapshotWithCodec returns the binding's stream batch
// sender, creating it on first use under streamMu. The codec and writeTimeout
// supplied by the first caller win; later calls reuse the existing sender.
func (b *transportBinding) streamBatchSenderSnapshotWithCodec(codec streamBatchCodec, writeTimeout func() time.Duration) *streamBatchSender {
	if b == nil {
		return nil
	}
	b.streamMu.Lock()
	defer b.streamMu.Unlock()
	if b.streamSender == nil {
		b.streamSender = newStreamBatchSender(b, codec, writeTimeout)
	}
	return b.streamSender
}
// clientStreamBatchSenderSnapshot lazily builds the stream batch sender with
// the client's fast-stream encoders and write-timeout snapshot.
func (b *transportBinding) clientStreamBatchSenderSnapshot(c *ClientCommon) *streamBatchSender {
	if b == nil || c == nil {
		return nil
	}
	codec := streamBatchCodec{
		encodeSingle: c.encodeFastStreamPayload,
		encodeBatch:  c.encodeFastStreamBatchPayload,
	}
	return b.streamBatchSenderSnapshotWithCodec(codec, c.maxWriteTimeoutSnapshot)
}

// serverStreamBatchSenderSnapshot is the server-side counterpart: the codec
// closes over the logical connection so encoding uses its session context;
// returns nil unless the logical conn's server is a *ServerCommon.
func (b *transportBinding) serverStreamBatchSenderSnapshot(logical *LogicalConn) *streamBatchSender {
	if b == nil || logical == nil {
		return nil
	}
	common, ok := logical.Server().(*ServerCommon)
	if !ok || common == nil {
		return nil
	}
	codec := streamBatchCodec{
		encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
			return common.encodeFastStreamPayloadLogical(logical, frame)
		},
		encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
			return common.encodeFastStreamBatchPayloadLogical(logical, frames)
		},
	}
	return b.streamBatchSenderSnapshotWithCodec(codec, logical.maxWriteTimeoutSnapshot)
}
func (b *transportBinding) stopBackgroundWorkers() {
if b == nil {
return
@ -104,12 +180,18 @@ func (b *transportBinding) stopBackgroundWorkers() {
b.controlMu.Lock()
controlSender := b.controlSender
b.controlMu.Unlock()
b.streamMu.Lock()
streamSender := b.streamSender
b.streamMu.Unlock()
b.bulkMu.Lock()
bulkSender := b.bulkSender
b.bulkMu.Unlock()
if controlSender != nil {
controlSender.stop()
}
if streamSender != nil {
streamSender.stop()
}
if bulkSender != nil {
bulkSender.stop()
}

View File

@ -0,0 +1,383 @@
package notify
import (
"sync"
"time"
)
const (
bulkAdaptiveSoftPayloadMinBytes = 256 * 1024
bulkAdaptiveSoftPayloadFallbackBytes = 2 * 1024 * 1024
bulkAdaptiveSoftPayloadStartBytes = bulkFastBatchMaxPlainBytes
bulkAdaptiveSoftPayloadTargetFlush = 4 * time.Millisecond
bulkAdaptiveSoftPayloadSlowFlush = 16 * time.Millisecond
bulkAdaptiveSoftPayloadMinSampleBytes = 256 * 1024
bulkAdaptiveSoftPayloadGrowSuccesses = 8
streamAdaptiveSoftPayloadMinBytes = 256 * 1024
streamAdaptiveSoftPayloadFallbackBytes = streamBatchMaxPayloadBytes
streamAdaptiveSoftPayloadStartBytes = streamBatchMaxPayloadBytes
streamAdaptiveSoftPayloadTargetFlush = 4 * time.Millisecond
streamAdaptiveSoftPayloadSlowFlush = 16 * time.Millisecond
streamAdaptiveSoftPayloadMinSampleBytes = 64 * 1024
streamAdaptiveSoftPayloadGrowSuccesses = 8
streamAdaptiveWaitThresholdMinBytes = 32 * 1024
streamAdaptiveFlushDelayMid = 25 * time.Microsecond
)
// bulkAdaptiveSoftPayloadSteps is the ascending ladder of permitted bulk
// soft payload sizes; the controller only ever moves between these steps.
var bulkAdaptiveSoftPayloadSteps = [...]int{
	256 * 1024,
	512 * 1024,
	1024 * 1024,
	2 * 1024 * 1024,
	4 * 1024 * 1024,
	bulkFastBatchMaxPlainBytes,
}
// streamAdaptiveSoftPayloadSteps is the ascending ladder of permitted
// stream soft payload sizes; the controller only moves between these steps.
var streamAdaptiveSoftPayloadSteps = [...]int{
	256 * 1024,
	512 * 1024,
	1024 * 1024,
	streamBatchMaxPayloadBytes,
}
// adaptiveTxState holds per-binding adaptive transmit state for the bulk
// and stream classes: the current soft payload cap, an EWMA of observed
// goodput, and a count of consecutive "good" flushes that gates growth.
// All fields are guarded by mu.
type adaptiveTxState struct {
	mu sync.Mutex
	// current bulk soft cap in bytes; 0 means "not yet initialized".
	bulkSoftPayloadBytes int
	// EWMA of bulk goodput samples, in bytes per second.
	bulkGoodputBytesPerS float64
	// consecutive fast bulk flushes since the last reset.
	bulkGrowStreak int
	streamSoftPayloadBytes int
	streamGoodputBytesPerS float64
	streamGrowStreak       int
}
// bulkAdaptiveSoftPayloadBytesSnapshot reports the current adaptive soft
// payload cap for bulk traffic; a nil binding yields the static fallback.
func (b *transportBinding) bulkAdaptiveSoftPayloadBytesSnapshot() int {
	if b != nil {
		return b.adaptiveTx.bulkSoftPayloadBytesSnapshot()
	}
	return bulkAdaptiveSoftPayloadFallbackBytes
}
// observeBulkAdaptivePayloadWrite forwards one bulk write sample (size,
// duration, configured timeout, outcome) to the adaptive controller.
// Nil bindings drop the sample.
func (b *transportBinding) observeBulkAdaptivePayloadWrite(payloadBytes int, elapsed time.Duration, timeout time.Duration, err error) {
	if b != nil {
		b.adaptiveTx.observeBulkPayloadWrite(payloadBytes, elapsed, timeout, err)
	}
}
// streamAdaptiveSoftPayloadBytesSnapshot reports the current adaptive soft
// payload cap for stream traffic; a nil binding yields the static fallback.
func (b *transportBinding) streamAdaptiveSoftPayloadBytesSnapshot() int {
	if b != nil {
		return b.adaptiveTx.streamSoftPayloadBytesSnapshot()
	}
	return streamAdaptiveSoftPayloadFallbackBytes
}
// streamAdaptiveWaitThresholdBytesSnapshot reports the adaptive stream
// batching wait threshold; a nil binding yields the static default.
func (b *transportBinding) streamAdaptiveWaitThresholdBytesSnapshot() int {
	if b != nil {
		return b.adaptiveTx.streamWaitThresholdBytesSnapshot()
	}
	return streamBatchWaitThreshold
}
// streamAdaptiveFlushDelaySnapshot reports the adaptive stream batching
// flush delay; a nil binding yields the static default.
func (b *transportBinding) streamAdaptiveFlushDelaySnapshot() time.Duration {
	if b != nil {
		return b.adaptiveTx.streamFlushDelaySnapshot()
	}
	return streamBatchMaxFlushDelay
}
// observeStreamAdaptivePayloadWrite forwards one stream write sample (size,
// duration, configured timeout, outcome) to the adaptive controller.
// Nil bindings drop the sample.
func (b *transportBinding) observeStreamAdaptivePayloadWrite(payloadBytes int, elapsed time.Duration, timeout time.Duration, err error) {
	if b != nil {
		b.adaptiveTx.observeStreamPayloadWrite(payloadBytes, elapsed, timeout, err)
	}
}
// bulkSoftPayloadBytesSnapshot returns the normalized bulk soft cap under
// the state lock; a nil state reports the aggressive starting size.
func (s *adaptiveTxState) bulkSoftPayloadBytesSnapshot() int {
	if s == nil {
		return bulkAdaptiveSoftPayloadStartBytes
	}
	s.mu.Lock()
	size := s.bulkSoftPayloadBytesLocked()
	s.mu.Unlock()
	return size
}
// observeBulkPayloadWrite updates the bulk soft payload cap from one write
// sample. Timeout-like failures, or writes consuming >= 3/4 of the allowed
// timeout, shrink the cap immediately; other errors only reset the grow
// streak; slow-but-successful flushes shrink toward the goodput target;
// sustained fast flushes grow the cap one step at a time.
func (s *adaptiveTxState) observeBulkPayloadWrite(payloadBytes int, elapsed time.Duration, timeout time.Duration, err error) {
	if s == nil || payloadBytes <= 0 {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	current := s.bulkSoftPayloadBytesLocked()
	// Fold the sample into the goodput EWMA first so even bad flushes
	// inform the target budget.
	target, hasSample := s.observeBulkGoodputLocked(payloadBytes, elapsed)
	nearTimeout := timeout > 0 && elapsed >= (timeout*3)/4
	if isTimeoutLikeError(err) || nearTimeout {
		s.bulkGrowStreak = 0
		if hasSample && target < current {
			s.bulkSoftPayloadBytes = target
			return
		}
		// No usable sample (or target not smaller): back off one step.
		s.bulkSoftPayloadBytes = previousBulkAdaptiveSoftPayloadStep(current)
		return
	}
	if err != nil {
		// Non-timeout failure: keep the cap, just restart the streak.
		s.bulkGrowStreak = 0
		return
	}
	if !hasSample {
		return
	}
	if elapsed >= bulkAdaptiveSoftPayloadSlowFlush {
		// Successful but slow: shrink to the goodput target, or one step
		// when the target is not below the current cap.
		s.bulkGrowStreak = 0
		if target < current {
			s.bulkSoftPayloadBytes = target
			return
		}
		s.bulkSoftPayloadBytes = previousBulkAdaptiveSoftPayloadStep(current)
		return
	}
	if target > current {
		// Fast flush with headroom: grow only after a sustained streak.
		s.bulkGrowStreak++
		if s.bulkGrowStreak >= bulkAdaptiveSoftPayloadGrowSuccesses {
			s.bulkSoftPayloadBytes = nextBulkAdaptiveSoftPayloadStep(current, target)
			s.bulkGrowStreak = 0
		}
		return
	}
	if target < current && elapsed >= bulkAdaptiveSoftPayloadTargetFlush*2 {
		// Goodput no longer supports the cap: shrink straight to target.
		s.bulkSoftPayloadBytes = target
		s.bulkGrowStreak = 0
		return
	}
	s.bulkGrowStreak = 0
}
// bulkSoftPayloadBytesLocked lazily seeds the bulk soft cap at the
// aggressive start size and returns it normalized to a configured step.
// Caller must hold s.mu.
func (s *adaptiveTxState) bulkSoftPayloadBytesLocked() int {
	size := s.bulkSoftPayloadBytes
	if size == 0 {
		size = bulkAdaptiveSoftPayloadStartBytes
		s.bulkSoftPayloadBytes = size
	}
	return normalizeBulkAdaptiveSoftPayloadBytes(size)
}
// streamSoftPayloadBytesSnapshot returns the normalized stream soft cap
// under the state lock; a nil state reports the aggressive starting size.
func (s *adaptiveTxState) streamSoftPayloadBytesSnapshot() int {
	if s == nil {
		return streamAdaptiveSoftPayloadStartBytes
	}
	s.mu.Lock()
	size := s.streamSoftPayloadBytesLocked()
	s.mu.Unlock()
	return size
}
// streamWaitThresholdBytesSnapshot derives the stream batching wait
// threshold from the current soft cap; a nil state reports the default.
func (s *adaptiveTxState) streamWaitThresholdBytesSnapshot() int {
	if s == nil {
		return streamBatchWaitThreshold
	}
	s.mu.Lock()
	soft := s.streamSoftPayloadBytesLocked()
	s.mu.Unlock()
	return streamAdaptiveWaitThresholdBytesForSoftPayload(soft)
}
// streamFlushDelaySnapshot derives the stream batching flush delay from
// the current soft cap; a nil state reports the default delay.
func (s *adaptiveTxState) streamFlushDelaySnapshot() time.Duration {
	if s == nil {
		return streamBatchMaxFlushDelay
	}
	s.mu.Lock()
	soft := s.streamSoftPayloadBytesLocked()
	s.mu.Unlock()
	return streamAdaptiveFlushDelayForSoftPayload(soft)
}
// observeStreamPayloadWrite updates the stream soft payload cap from one
// write sample, mirroring the bulk policy: timeout-like failures and
// near-timeout writes shrink immediately, other errors reset the grow
// streak, slow successes shrink toward the goodput target, and sustained
// fast flushes grow the cap one step at a time.
func (s *adaptiveTxState) observeStreamPayloadWrite(payloadBytes int, elapsed time.Duration, timeout time.Duration, err error) {
	if s == nil || payloadBytes <= 0 {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	current := s.streamSoftPayloadBytesLocked()
	// Fold the sample into the goodput EWMA first so even bad flushes
	// inform the target budget.
	target, hasSample := s.observeStreamGoodputLocked(payloadBytes, elapsed)
	nearTimeout := timeout > 0 && elapsed >= (timeout*3)/4
	if isTimeoutLikeError(err) || nearTimeout {
		s.streamGrowStreak = 0
		if hasSample && target < current {
			s.streamSoftPayloadBytes = target
			return
		}
		// No usable sample (or target not smaller): back off one step.
		s.streamSoftPayloadBytes = previousStreamAdaptiveSoftPayloadStep(current)
		return
	}
	if err != nil {
		// Non-timeout failure: keep the cap, just restart the streak.
		s.streamGrowStreak = 0
		return
	}
	if !hasSample {
		return
	}
	if elapsed >= streamAdaptiveSoftPayloadSlowFlush {
		// Successful but slow: shrink to the goodput target, or one step
		// when the target is not below the current cap.
		s.streamGrowStreak = 0
		if target < current {
			s.streamSoftPayloadBytes = target
			return
		}
		s.streamSoftPayloadBytes = previousStreamAdaptiveSoftPayloadStep(current)
		return
	}
	if target > current {
		// Fast flush with headroom: grow only after a sustained streak.
		s.streamGrowStreak++
		if s.streamGrowStreak >= streamAdaptiveSoftPayloadGrowSuccesses {
			s.streamSoftPayloadBytes = nextStreamAdaptiveSoftPayloadStep(current, target)
			s.streamGrowStreak = 0
		}
		return
	}
	if target < current && elapsed >= streamAdaptiveSoftPayloadTargetFlush*2 {
		// Goodput no longer supports the cap: shrink straight to target.
		s.streamSoftPayloadBytes = target
		s.streamGrowStreak = 0
		return
	}
	s.streamGrowStreak = 0
}
// streamSoftPayloadBytesLocked lazily seeds the stream soft cap at the
// aggressive start size and returns it normalized to a configured step.
// Caller must hold s.mu.
func (s *adaptiveTxState) streamSoftPayloadBytesLocked() int {
	size := s.streamSoftPayloadBytes
	if size == 0 {
		size = streamAdaptiveSoftPayloadStartBytes
		s.streamSoftPayloadBytes = size
	}
	return normalizeStreamAdaptiveSoftPayloadBytes(size)
}
// observeBulkGoodputLocked folds one write sample into the bulk goodput
// EWMA and returns the normalized byte budget that fits the target flush
// duration. The bool reports whether the sample was usable (big enough,
// positive duration). Caller must hold s.mu.
func (s *adaptiveTxState) observeBulkGoodputLocked(payloadBytes int, elapsed time.Duration) (int, bool) {
	if elapsed <= 0 || payloadBytes < bulkAdaptiveSoftPayloadMinSampleBytes {
		return 0, false
	}
	sample := float64(payloadBytes) / elapsed.Seconds()
	if sample <= 0 {
		return 0, false
	}
	// Exponentially weighted moving average; the first sample seeds it.
	const alpha = 0.25
	if s.bulkGoodputBytesPerS > 0 {
		s.bulkGoodputBytesPerS = s.bulkGoodputBytesPerS*(1-alpha) + sample*alpha
	} else {
		s.bulkGoodputBytesPerS = sample
	}
	target := int(s.bulkGoodputBytesPerS * bulkAdaptiveSoftPayloadTargetFlush.Seconds())
	return normalizeBulkAdaptiveSoftPayloadBytes(target), true
}
// observeStreamGoodputLocked folds one write sample into the stream
// goodput EWMA and returns the normalized byte budget that fits the target
// flush duration. The bool reports whether the sample was usable (big
// enough, positive duration). Caller must hold s.mu.
func (s *adaptiveTxState) observeStreamGoodputLocked(payloadBytes int, elapsed time.Duration) (int, bool) {
	if elapsed <= 0 || payloadBytes < streamAdaptiveSoftPayloadMinSampleBytes {
		return 0, false
	}
	sample := float64(payloadBytes) / elapsed.Seconds()
	if sample <= 0 {
		return 0, false
	}
	// Exponentially weighted moving average; the first sample seeds it.
	const alpha = 0.25
	if s.streamGoodputBytesPerS > 0 {
		s.streamGoodputBytesPerS = s.streamGoodputBytesPerS*(1-alpha) + sample*alpha
	} else {
		s.streamGoodputBytesPerS = sample
	}
	target := int(s.streamGoodputBytesPerS * streamAdaptiveSoftPayloadTargetFlush.Seconds())
	return normalizeStreamAdaptiveSoftPayloadBytes(target), true
}
// normalizeBulkAdaptiveSoftPayloadBytes rounds size up to the nearest
// configured bulk step, clamping to [min step, start size].
func normalizeBulkAdaptiveSoftPayloadBytes(size int) int {
	if size <= bulkAdaptiveSoftPayloadMinBytes {
		return bulkAdaptiveSoftPayloadMinBytes
	}
	for _, step := range bulkAdaptiveSoftPayloadSteps {
		if step >= size {
			return step
		}
	}
	return bulkAdaptiveSoftPayloadStartBytes
}
// previousBulkAdaptiveSoftPayloadStep returns the largest configured bulk
// step strictly below current, or the minimum when already at the bottom.
func previousBulkAdaptiveSoftPayloadStep(current int) int {
	current = normalizeBulkAdaptiveSoftPayloadBytes(current)
	for index := len(bulkAdaptiveSoftPayloadSteps); index > 0; index-- {
		if step := bulkAdaptiveSoftPayloadSteps[index-1]; step < current {
			return step
		}
	}
	return bulkAdaptiveSoftPayloadMinBytes
}
// nextBulkAdaptiveSoftPayloadStep advances current by one configured bulk
// step, but never past the (normalized) target.
func nextBulkAdaptiveSoftPayloadStep(current int, target int) int {
	current = normalizeBulkAdaptiveSoftPayloadBytes(current)
	target = normalizeBulkAdaptiveSoftPayloadBytes(target)
	for _, step := range bulkAdaptiveSoftPayloadSteps {
		if step <= current {
			continue
		}
		if target < step {
			return target
		}
		return step
	}
	return bulkAdaptiveSoftPayloadStartBytes
}
// normalizeStreamAdaptiveSoftPayloadBytes rounds size up to the nearest
// configured stream step, clamping to [min step, start size].
func normalizeStreamAdaptiveSoftPayloadBytes(size int) int {
	if size <= streamAdaptiveSoftPayloadMinBytes {
		return streamAdaptiveSoftPayloadMinBytes
	}
	for _, step := range streamAdaptiveSoftPayloadSteps {
		if step >= size {
			return step
		}
	}
	return streamAdaptiveSoftPayloadStartBytes
}
// previousStreamAdaptiveSoftPayloadStep returns the largest configured
// stream step strictly below current, or the minimum at the bottom.
func previousStreamAdaptiveSoftPayloadStep(current int) int {
	current = normalizeStreamAdaptiveSoftPayloadBytes(current)
	for index := len(streamAdaptiveSoftPayloadSteps); index > 0; index-- {
		if step := streamAdaptiveSoftPayloadSteps[index-1]; step < current {
			return step
		}
	}
	return streamAdaptiveSoftPayloadMinBytes
}
// nextStreamAdaptiveSoftPayloadStep advances current by one configured
// stream step, but never past the (normalized) target.
func nextStreamAdaptiveSoftPayloadStep(current int, target int) int {
	current = normalizeStreamAdaptiveSoftPayloadBytes(current)
	target = normalizeStreamAdaptiveSoftPayloadBytes(target)
	for _, step := range streamAdaptiveSoftPayloadSteps {
		if step <= current {
			continue
		}
		if target < step {
			return target
		}
		return step
	}
	return streamAdaptiveSoftPayloadStartBytes
}
// streamAdaptiveWaitThresholdBytesForSoftPayload derives the stream batch
// wait threshold as 1/16 of the (normalized) soft cap, clamped to
// [streamAdaptiveWaitThresholdMinBytes, streamBatchWaitThreshold].
func streamAdaptiveWaitThresholdBytesForSoftPayload(size int) int {
	threshold := normalizeStreamAdaptiveSoftPayloadBytes(size) / 16
	switch {
	case threshold < streamAdaptiveWaitThresholdMinBytes:
		return streamAdaptiveWaitThresholdMinBytes
	case threshold > streamBatchWaitThreshold:
		return streamBatchWaitThreshold
	default:
		return threshold
	}
}
// streamAdaptiveFlushDelayForSoftPayload maps the (normalized) stream
// soft cap to a batching flush delay: full delay at the top step, a short
// mid delay at >= 1 MiB, and no delay once the cap collapses below that.
func streamAdaptiveFlushDelayForSoftPayload(size int) time.Duration {
	size = normalizeStreamAdaptiveSoftPayloadBytes(size)
	if size >= streamAdaptiveSoftPayloadStartBytes {
		return streamBatchMaxFlushDelay
	}
	if size >= 1024*1024 {
		return streamAdaptiveFlushDelayMid
	}
	return 0
}

View File

@ -0,0 +1,160 @@
package notify
import (
"b612.me/stario"
"net"
"testing"
"time"
)
// Verifies a fresh binding starts the bulk soft payload cap at the
// aggressive top step rather than ramping up from the minimum.
func TestTransportBindingAdaptiveBulkSoftPayloadStartsAggressive(t *testing.T) {
	binding := &transportBinding{}
	got := binding.bulkAdaptiveSoftPayloadBytesSnapshot()
	if want := bulkAdaptiveSoftPayloadStartBytes; got != want {
		t.Fatalf("adaptive bulk soft payload = %d, want %d", got, want)
	}
}
// Verifies a single very slow large bulk write collapses the soft cap to
// the configured minimum step.
func TestTransportBindingAdaptiveBulkSoftPayloadShrinksAfterSlowWrite(t *testing.T) {
	binding := &transportBinding{}
	binding.observeBulkAdaptivePayloadWrite(8*1024*1024, 640*time.Millisecond, 0, nil)
	got := binding.bulkAdaptiveSoftPayloadBytesSnapshot()
	if want := bulkAdaptiveSoftPayloadMinBytes; got != want {
		t.Fatalf("adaptive bulk soft payload = %d, want %d", got, want)
	}
}
// Verifies sustained fast bulk writes walk the soft cap back up to the
// top step after an earlier collapse.
func TestTransportBindingAdaptiveBulkSoftPayloadRecoversAfterGoodWrites(t *testing.T) {
	binding := &transportBinding{}
	binding.observeBulkAdaptivePayloadWrite(8*1024*1024, 640*time.Millisecond, 0, nil)
	// Enough fast samples to climb every remaining rung of the ladder.
	for i := bulkAdaptiveSoftPayloadGrowSuccesses * (len(bulkAdaptiveSoftPayloadSteps) - 1); i > 0; i-- {
		binding.observeBulkAdaptivePayloadWrite(8*1024*1024, 6*time.Millisecond, 0, nil)
	}
	got := binding.bulkAdaptiveSoftPayloadBytesSnapshot()
	if want := bulkAdaptiveSoftPayloadStartBytes; got != want {
		t.Fatalf("adaptive bulk soft payload = %d, want %d", got, want)
	}
}
type delayedWriteConn struct {
delay time.Duration
}
func (c *delayedWriteConn) Read([]byte) (int, error) { return 0, net.ErrClosed }
func (c *delayedWriteConn) Write(p []byte) (int, error) { time.Sleep(c.delay); return len(p), nil }
func (c *delayedWriteConn) Close() error { return nil }
func (c *delayedWriteConn) LocalAddr() net.Addr { return nil }
func (c *delayedWriteConn) RemoteAddr() net.Addr { return nil }
func (c *delayedWriteConn) SetDeadline(time.Time) error { return nil }
func (c *delayedWriteConn) SetReadDeadline(time.Time) error { return nil }
func (c *delayedWriteConn) SetWriteDeadline(time.Time) error { return nil }
// TestBulkBatchSenderFlushAdaptsToSlowWrites drives a flush through a conn
// whose writes sleep 20ms and checks the binding's adaptive bulk soft cap
// shrinks below the aggressive start size as a result.
func TestBulkBatchSenderFlushAdaptsToSlowWrites(t *testing.T) {
	binding := newTransportBinding(&delayedWriteConn{delay: 20 * time.Millisecond}, stario.NewQueue())
	sender := newTestBulkBatchSender(binding)
	defer sender.stop()
	// Two multi-frame v2 requests so the flush batches a sizable payload.
	payload := make([]byte, 512*1024)
	err := sender.flush([]bulkBatchRequest{
		{
			frames: []bulkFastFrame{
				{Type: bulkFastPayloadTypeData, DataID: 101, Seq: 1, Payload: payload},
				{Type: bulkFastPayloadTypeData, DataID: 101, Seq: 2, Payload: payload},
			},
			fastPathVersion: bulkFastPathVersionV2,
		},
		{
			frames: []bulkFastFrame{
				{Type: bulkFastPayloadTypeData, DataID: 202, Seq: 1, Payload: payload},
				{Type: bulkFastPayloadTypeData, DataID: 202, Seq: 2, Payload: payload},
			},
			fastPathVersion: bulkFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("flush failed: %v", err)
	}
	if got := binding.bulkAdaptiveSoftPayloadBytesSnapshot(); got >= bulkAdaptiveSoftPayloadStartBytes {
		t.Fatalf("adaptive bulk soft payload = %d, want smaller than %d after slow write", got, bulkAdaptiveSoftPayloadStartBytes)
	}
}
// Verifies a fresh binding starts stream batching fully aggressive: top
// soft cap, default wait threshold, and default flush delay.
func TestTransportBindingAdaptiveStreamStartsAggressive(t *testing.T) {
	binding := &transportBinding{}
	soft := binding.streamAdaptiveSoftPayloadBytesSnapshot()
	if soft != streamAdaptiveSoftPayloadStartBytes {
		t.Fatalf("adaptive stream soft payload = %d, want %d", soft, streamAdaptiveSoftPayloadStartBytes)
	}
	threshold := binding.streamAdaptiveWaitThresholdBytesSnapshot()
	if threshold != streamBatchWaitThreshold {
		t.Fatalf("adaptive stream wait threshold = %d, want %d", threshold, streamBatchWaitThreshold)
	}
	delay := binding.streamAdaptiveFlushDelaySnapshot()
	if delay != streamBatchMaxFlushDelay {
		t.Fatalf("adaptive stream flush delay = %s, want %s", delay, streamBatchMaxFlushDelay)
	}
}
// Verifies one very slow stream write collapses the soft cap, wait
// threshold, and flush delay to their most conservative values.
func TestTransportBindingAdaptiveStreamShrinksAfterSlowWrite(t *testing.T) {
	binding := &transportBinding{}
	binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil)
	soft := binding.streamAdaptiveSoftPayloadBytesSnapshot()
	if soft != streamAdaptiveSoftPayloadMinBytes {
		t.Fatalf("adaptive stream soft payload = %d, want %d", soft, streamAdaptiveSoftPayloadMinBytes)
	}
	threshold := binding.streamAdaptiveWaitThresholdBytesSnapshot()
	if threshold != streamAdaptiveWaitThresholdMinBytes {
		t.Fatalf("adaptive stream wait threshold = %d, want %d", threshold, streamAdaptiveWaitThresholdMinBytes)
	}
	if delay := binding.streamAdaptiveFlushDelaySnapshot(); delay != 0 {
		t.Fatalf("adaptive stream flush delay = %s, want 0", delay)
	}
}
// Verifies sustained fast stream writes restore the soft cap, wait
// threshold, and flush delay to their aggressive defaults.
func TestTransportBindingAdaptiveStreamRecoversAfterGoodWrites(t *testing.T) {
	binding := &transportBinding{}
	binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil)
	// Enough fast samples to climb every remaining rung of the ladder.
	for i := streamAdaptiveSoftPayloadGrowSuccesses * (len(streamAdaptiveSoftPayloadSteps) - 1); i > 0; i-- {
		binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 6*time.Millisecond, 0, nil)
	}
	soft := binding.streamAdaptiveSoftPayloadBytesSnapshot()
	if soft != streamAdaptiveSoftPayloadStartBytes {
		t.Fatalf("adaptive stream soft payload = %d, want %d", soft, streamAdaptiveSoftPayloadStartBytes)
	}
	threshold := binding.streamAdaptiveWaitThresholdBytesSnapshot()
	if threshold != streamBatchWaitThreshold {
		t.Fatalf("adaptive stream wait threshold = %d, want %d", threshold, streamBatchWaitThreshold)
	}
	delay := binding.streamAdaptiveFlushDelaySnapshot()
	if delay != streamBatchMaxFlushDelay {
		t.Fatalf("adaptive stream flush delay = %s, want %s", delay, streamBatchMaxFlushDelay)
	}
}
// TestStreamBatchSenderFlushAdaptsToSlowWrites drives a stream flush
// through a conn whose writes sleep 20ms and checks both the adaptive
// stream soft cap and the derived wait threshold shrink below their
// aggressive defaults.
func TestStreamBatchSenderFlushAdaptsToSlowWrites(t *testing.T) {
	binding := newTransportBinding(&delayedWriteConn{delay: 20 * time.Millisecond}, stario.NewQueue())
	sender := newTestStreamBatchSender(binding, nil)
	defer sender.stop()
	// Two v2 data frames so the flush batches a sizable payload.
	payload := make([]byte, 512*1024)
	err := sender.flush([]streamBatchRequest{
		{
			frame: streamFastDataFrame{
				DataID:  101,
				Seq:     1,
				Payload: payload,
			},
			hasFrame:        true,
			fastPathVersion: streamFastPathVersionV2,
		},
		{
			frame: streamFastDataFrame{
				DataID:  202,
				Seq:     1,
				Payload: payload,
			},
			hasFrame:        true,
			fastPathVersion: streamFastPathVersionV2,
		},
	})
	if err != nil {
		t.Fatalf("flush failed: %v", err)
	}
	if got := binding.streamAdaptiveSoftPayloadBytesSnapshot(); got >= streamAdaptiveSoftPayloadStartBytes {
		t.Fatalf("adaptive stream soft payload = %d, want smaller than %d after slow write", got, streamAdaptiveSoftPayloadStartBytes)
	}
	if got := binding.streamAdaptiveWaitThresholdBytesSnapshot(); got >= streamBatchWaitThreshold {
		t.Fatalf("adaptive stream wait threshold = %d, want smaller than %d after slow write", got, streamBatchWaitThreshold)
	}
}

View File

@ -9,12 +9,93 @@ var (
errTransportPayloadDecryptFailed = errors.New("transport payload decrypt failed")
)
// encryptTransportPayloadCodec seals data for the wire. A modern PSK
// runtime, when present, is used exclusively; otherwise the legacy msgEn
// callback encrypts with secretKey. A nil msgEn, a runtime failure, or a
// nil ciphertext for non-empty input all map to
// errTransportPayloadEncryptFailed.
func encryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgEn func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) {
	if runtime == nil {
		if msgEn == nil {
			return nil, errTransportPayloadEncryptFailed
		}
		encoded := msgEn(secretKey, data)
		if len(data) != 0 && encoded == nil {
			return nil, errTransportPayloadEncryptFailed
		}
		return encoded, nil
	}
	encoded, err := runtime.sealPlainPayload(data)
	if err != nil {
		return nil, errTransportPayloadEncryptFailed
	}
	return encoded, nil
}
// decryptTransportPayloadCodec opens a wire payload. A modern PSK runtime,
// when present, is used exclusively; otherwise the legacy msgDe callback
// decrypts with secretKey. A nil msgDe, a runtime failure, or a nil
// plaintext for non-empty input all map to errTransportPayloadDecryptFailed.
func decryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) {
	if runtime == nil {
		if msgDe == nil {
			return nil, errTransportPayloadDecryptFailed
		}
		plain := msgDe(secretKey, data)
		if len(data) != 0 && plain == nil {
			return nil, errTransportPayloadDecryptFailed
		}
		return plain, nil
	}
	plain, err := runtime.openPayload(data)
	if err != nil {
		return nil, errTransportPayloadDecryptFailed
	}
	return plain, nil
}
// decryptTransportPayloadCodecPooled decrypts data that lives in a pooled
// buffer. release (may be nil) returns that buffer to its pool: the modern
// runtime takes ownership of it via openPayloadPooled, while the legacy
// msgDe path calls it here. The returned func, when non-nil, releases the
// plaintext buffer.
// NOTE(review): the legacy path invokes release() before returning plain,
// which assumes msgDe copies rather than aliases data — confirm against
// the msgDe implementations in use.
func decryptTransportPayloadCodecPooled(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte, release func()) ([]byte, func(), error) {
	if runtime != nil {
		plain, plainRelease, err := runtime.openPayloadPooled(data, release)
		if err != nil {
			return nil, nil, errTransportPayloadDecryptFailed
		}
		return plain, plainRelease, nil
	}
	if msgDe == nil {
		// Still return the input buffer to its pool before failing.
		if release != nil {
			release()
		}
		return nil, nil, errTransportPayloadDecryptFailed
	}
	plain := msgDe(secretKey, data)
	if release != nil {
		release()
	}
	if plain == nil && len(data) != 0 {
		return nil, nil, errTransportPayloadDecryptFailed
	}
	return plain, nil, nil
}
// decryptTransportPayloadCodecOwnedPooled decrypts a payload whose input
// buffer the caller owns. With a modern runtime the plaintext may come
// from a pool and the returned func must be called to recycle it; the
// legacy msgDe path returns a plain slice and a nil release func.
func decryptTransportPayloadCodecOwnedPooled(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, func(), error) {
	if runtime == nil {
		if msgDe == nil {
			return nil, nil, errTransportPayloadDecryptFailed
		}
		plain := msgDe(secretKey, data)
		if len(data) != 0 && plain == nil {
			return nil, nil, errTransportPayloadDecryptFailed
		}
		return plain, nil, nil
	}
	plain, plainRelease, err := runtime.openPayloadOwnedPooled(data)
	if err != nil {
		return nil, nil, errTransportPayloadDecryptFailed
	}
	return plain, plainRelease, nil
}
func (c *ClientCommon) encodeTransferMsg(msg TransferMsg) ([]byte, error) {
data, err := c.sequenceEn(msg)
if err != nil {
return nil, err
}
data = c.msgEn(c.SecretKey, data)
data, err = c.encryptTransportPayload(data)
if err != nil {
return nil, err
}
queue := c.clientQueueSnapshot()
if queue == nil {
return nil, errClientSessionQueueUnavailable
@ -23,7 +104,11 @@ func (c *ClientCommon) encodeTransferMsg(msg TransferMsg) ([]byte, error) {
}
func (c *ClientCommon) decodeTransferMsg(data []byte) (TransferMsg, error) {
msg, err := c.sequenceDe(c.msgDe(c.SecretKey, data))
plain, err := c.decryptTransportPayload(data)
if err != nil {
return TransferMsg{}, err
}
msg, err := c.sequenceDe(plain)
if err != nil {
return TransferMsg{}, err
}
@ -41,7 +126,10 @@ func (s *ServerCommon) encodeTransferMsg(c *ClientConn, msg TransferMsg) ([]byte
}
msgEn := c.clientConnMsgEnSnapshot()
secretKey := c.clientConnSecretKeySnapshot()
data = msgEn(secretKey, data)
data, err = encryptTransportPayloadCodec(c.clientConnModernPSKRuntimeSnapshot(), msgEn, secretKey, data)
if err != nil {
return nil, err
}
queue := s.serverQueueSnapshot()
if queue == nil {
return nil, errServerSessionQueueUnavailable
@ -52,7 +140,11 @@ func (s *ServerCommon) encodeTransferMsg(c *ClientConn, msg TransferMsg) ([]byte
func (s *ServerCommon) decodeTransferMsg(c *ClientConn, data []byte) (TransferMsg, error) {
msgDe := c.clientConnMsgDeSnapshot()
secretKey := c.clientConnSecretKeySnapshot()
msg, err := s.sequenceDe(msgDe(secretKey, data))
plain, err := decryptTransportPayloadCodec(c.clientConnModernPSKRuntimeSnapshot(), msgDe, secretKey, data)
if err != nil {
return TransferMsg{}, err
}
msg, err := s.sequenceDe(plain)
if err != nil {
return TransferMsg{}, err
}
@ -80,11 +172,7 @@ func (c *ClientCommon) encodeEnvelopePlain(env Envelope) ([]byte, error) {
}
func (c *ClientCommon) encryptTransportPayload(data []byte) ([]byte, error) {
encoded := c.msgEn(c.SecretKey, data)
if encoded == nil && len(data) != 0 {
return nil, errTransportPayloadEncryptFailed
}
return encoded, nil
return encryptTransportPayloadCodec(c.modernPSKRuntime, c.msgEn, c.SecretKey, data)
}
func (c *ClientCommon) encodeEnvelope(env Envelope) ([]byte, error) {
@ -108,11 +196,7 @@ func (c *ClientCommon) decodeEnvelope(data []byte) (Envelope, error) {
}
func (c *ClientCommon) decryptTransportPayload(data []byte) ([]byte, error) {
plain := c.msgDe(c.SecretKey, data)
if plain == nil && len(data) != 0 {
return nil, errTransportPayloadDecryptFailed
}
return plain, nil
return decryptTransportPayloadCodec(c.modernPSKRuntime, c.msgDe, c.SecretKey, data)
}
func (c *ClientCommon) decodeEnvelopePlain(data []byte) (Envelope, error) {
@ -167,11 +251,7 @@ func (s *ServerCommon) encryptTransportPayloadLogical(logical *LogicalConn, data
if msgEn == nil {
return nil, errTransportDetached
}
encoded := msgEn(secretKey, data)
if encoded == nil && len(data) != 0 {
return nil, errTransportPayloadEncryptFailed
}
return encoded, nil
return encryptTransportPayloadCodec(logical.modernPSKRuntimeSnapshot(), msgEn, secretKey, data)
}
func (s *ServerCommon) encodeEnvelopeLogical(logical *LogicalConn, env Envelope) ([]byte, error) {
@ -210,11 +290,7 @@ func (s *ServerCommon) decryptTransportPayloadLogical(logical *LogicalConn, data
if msgDe == nil {
return nil, errTransportDetached
}
plain := msgDe(secretKey, data)
if plain == nil && len(data) != 0 {
return nil, errTransportPayloadDecryptFailed
}
return plain, nil
return decryptTransportPayloadCodec(logical.modernPSKRuntimeSnapshot(), msgDe, secretKey, data)
}
func (s *ServerCommon) decodeEnvelopePlain(data []byte) (Envelope, error) {

View File

@ -1,6 +1,7 @@
package notify
import (
"b612.me/stario"
"context"
"errors"
"net"
@ -16,7 +17,7 @@ type TransportConn struct {
}
const (
transportStreamReadBufferSize = 256 * 1024
transportStreamReadBufferSize = 1024 * 1024
transportPacketReadBufferSize = 64 * 1024
)
@ -28,6 +29,17 @@ func packetReadBuffer() []byte {
return make([]byte, transportPacketReadBufferSize)
}
// newTransportFrameReader wraps conn in a stario FrameReader bound to
// queue, enlarging the reader's buffer to transportStreamReadBufferSize
// whenever that exceeds stario's default — explicit read-size tuning so
// large stream frames need fewer reads. Returns nil if stario returns a
// nil reader.
func newTransportFrameReader(conn net.Conn, queue *stario.StarQueue) *stario.FrameReader {
	reader := stario.NewFrameReader(conn, queue)
	if reader == nil {
		return nil
	}
	if transportStreamReadBufferSize > stario.DefaultFrameReaderBufferSize {
		reader.SetReadBufferSize(transportStreamReadBufferSize)
	}
	return reader
}
type TransportConnRuntimeSnapshot struct {
ClientID string
RemoteAddress string

View File

@ -62,18 +62,72 @@ func TestWriteFullToConnSerializesConcurrentWriters(t *testing.T) {
}
}
// newTestBulkBatchSender builds a bulk batch sender with the plain test
// codec and no write-timeout override.
func newTestBulkBatchSender(binding *transportBinding) *bulkBatchSender {
	return newTestBulkBatchSenderWithWriteTimeout(binding, nil)
}
// newTestBulkBatchSenderWithWriteTimeout builds a bulk batch sender whose
// codec copies single frames verbatim and batch-encodes with the plain
// (unencrypted) bulk fast-batch format. writeTimeout may be nil.
func newTestBulkBatchSenderWithWriteTimeout(binding *transportBinding, writeTimeout func() time.Duration) *bulkBatchSender {
	return newBulkBatchSender(binding, bulkBatchCodec{
		encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) {
			// Copy so the sender never aliases the caller's payload.
			return append([]byte(nil), frame.Payload...), nil, nil
		},
		encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) {
			payload, err := encodeBulkFastBatchPlain(frames)
			return payload, nil, err
		},
	}, writeTimeout)
}
// newTestStreamBatchSender builds a stream batch sender whose codec uses
// the plain (unencrypted) stream fast-frame and fast-batch encoders.
// writeTimeout may be nil.
func newTestStreamBatchSender(binding *transportBinding, writeTimeout func() time.Duration) *streamBatchSender {
	return newStreamBatchSender(binding, streamBatchCodec{
		encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
			return encodeStreamFastFramePayload(frame)
		},
		encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
			return encodeStreamFastBatchPlain(frames)
		},
	}, writeTimeout)
}
func TestBulkBatchSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) {
left, right := net.Pipe()
defer left.Close()
defer right.Close()
sender := newBulkBatchSender(newTransportBinding(left, stario.NewQueue()))
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
sender := newTestBulkBatchSenderWithWriteTimeout(newTransportBinding(left, stario.NewQueue()), func() time.Duration {
return 50 * time.Millisecond
})
errCh := make(chan error, 1)
go func() {
errCh <- sender.submit(ctx, []byte("payload"))
errCh <- sender.submitData(context.Background(), 1, 1, bulkFastPathVersionV1, []byte("payload"))
}()
select {
case err := <-errCh:
if err == nil {
t.Fatal("sender.submit should fail when receiver stalls")
}
if !isTimeoutLikeError(err) {
t.Fatalf("sender.submit error = %v, want timeout-like error", err)
}
case <-time.After(time.Second):
t.Fatal("sender.submit should not hang when receiver stalls")
}
}
func TestStreamBatchSenderRespectsBindingWriteDeadlineWhenReceiverStalls(t *testing.T) {
left, right := net.Pipe()
defer left.Close()
defer right.Close()
sender := newTestStreamBatchSender(newTransportBinding(left, stario.NewQueue()), func() time.Duration {
return 50 * time.Millisecond
})
errCh := make(chan error, 1)
go func() {
errCh <- sender.submitData(context.Background(), 1, 1, streamFastPathVersionV2, []byte("payload"))
}()
select {
@ -258,12 +312,12 @@ func TestWriteNetBuffersFullUnlockedUsesUnwrappedVectoredConn(t *testing.T) {
func TestBulkBatchSenderSkipsQueuedCanceledRequest(t *testing.T) {
conn := newBlockingPacketWriteConn()
binding := newTransportBinding(conn, stario.NewQueue())
sender := newBulkBatchSender(binding)
sender := newTestBulkBatchSender(binding)
defer sender.stop()
firstErrCh := make(chan error, 1)
go func() {
firstErrCh <- sender.submit(context.Background(), []byte("first"))
firstErrCh <- sender.submitData(context.Background(), 1, 1, bulkFastPathVersionV1, []byte("first"))
}()
select {
@ -275,7 +329,7 @@ func TestBulkBatchSenderSkipsQueuedCanceledRequest(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
secondErrCh := make(chan error, 1)
go func() {
secondErrCh <- sender.submit(ctx, []byte("second"))
secondErrCh <- sender.submitData(ctx, 1, 2, bulkFastPathVersionV1, []byte("second"))
}()
time.Sleep(20 * time.Millisecond)
cancel()
@ -306,16 +360,48 @@ func TestBulkBatchSenderSkipsQueuedCanceledRequest(t *testing.T) {
}
}
// Verifies a shareable v2 bulk data frame is queued for super-batching
// rather than bypassing the batcher via direct submit.
func TestBulkBatchSenderDoesNotDirectSubmitShareableV2Data(t *testing.T) {
	sender := &bulkBatchSender{}
	frame := bulkFastFrame{
		Type:    bulkFastPayloadTypeData,
		DataID:  1,
		Seq:     1,
		Payload: make([]byte, 256*1024),
	}
	req := bulkBatchRequest{
		frames:          []bulkFastFrame{frame},
		fastPathVersion: bulkFastPathVersionV2,
	}
	if sender.shouldDirectSubmit(req) {
		t.Fatal("shareable v2 shared bulk data should queue for super-batch instead of direct submit")
	}
}
// Verifies an unbatchable control frame (a release) skips the batcher and
// is submitted directly.
func TestBulkBatchSenderDirectSubmitsUnbatchableRequest(t *testing.T) {
	sender := &bulkBatchSender{}
	frame := bulkFastFrame{
		Type:    bulkFastPayloadTypeRelease,
		DataID:  1,
		Seq:     0,
		Payload: []byte("rel"),
	}
	req := bulkBatchRequest{
		frames:          []bulkFastFrame{frame},
		fastPathVersion: bulkFastPathVersionV2,
	}
	if !sender.shouldDirectSubmit(req) {
		t.Fatal("unbatchable shared bulk control request should still direct submit")
	}
}
func TestBulkBatchSenderReturnsFlushResultAfterStartedContextCancel(t *testing.T) {
conn := newBlockingPacketWriteConn()
binding := newTransportBinding(conn, stario.NewQueue())
sender := newBulkBatchSender(binding)
sender := newTestBulkBatchSender(binding)
defer sender.stop()
ctx, cancel := context.WithCancel(context.Background())
errCh := make(chan error, 1)
go func() {
errCh <- sender.submit(ctx, []byte("payload"))
errCh <- sender.submitData(ctx, 1, 1, bulkFastPathVersionV1, []byte("payload"))
}()
select {
@ -346,10 +432,18 @@ func TestBulkBatchSenderReturnsFlushResultAfterStartedContextCancel(t *testing.T
func TestTransportBindingStopBackgroundWorkersStopsSharedSender(t *testing.T) {
binding := newTransportBinding(newBlockingPacketWriteConn(), stario.NewQueue())
sender := binding.bulkBatchSenderSnapshot()
sender := binding.bulkBatchSenderSnapshotWithCodec(bulkBatchCodec{
encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) {
return append([]byte(nil), frame.Payload...), nil, nil
},
encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) {
payload, err := encodeBulkFastBatchPlain(frames)
return payload, nil, err
},
}, nil)
binding.stopBackgroundWorkers()
err := sender.submit(context.Background(), []byte("payload"))
err := sender.submitData(context.Background(), 1, 1, bulkFastPathVersionV1, []byte("payload"))
if !errors.Is(err, errTransportDetached) {
t.Fatalf("sender.submit after stop = %v, want %v", err, errTransportDetached)
}