4 Commits

Author SHA1 Message Date
b612 98ef9e7fcc feat(transport): 完成安全架构拆分并收口 stream/bulk 传输优化
- 新增 managed/external/nested 三种传输保护模式
  - 新增 peer attach 显式认证、抗重放、channel binding 和可选前向保密协商
  - 明确单连接注入与可重拨连接源的语义边界
  - 禁止 ConnectByConn 场景下 dedicated bulk 走 sidecar,auto 模式自动回退 shared
  - 修正 dedicated attach 在 bootstrap/steady profile 切换下的处理逻辑
  - 优化 shared bulk super-batch 与批量 framed write 路径
  - 降低 stream/bulk fast path 的复制和分发损耗
  - 补齐 benchmark、回归测试、运行时快照和 README 文档
2026-04-20 16:35:44 +08:00
b612 f038a89771 fix: close stream adaptive gaps and switch notify to stario v0.1.1
- make stream fast path honor adaptive soft payload limits end-to-end
  - split oversized fast-stream payloads into sequential frames before batching
  - use adaptive soft cap when encoding stream batch payloads
  - move timeout-like error detection into production code for adaptive tx
  - tune notify FrameReader read size explicitly to avoid throughput regression
  - drop local stario replace and depend on released b612.me/stario v0.1.1
2026-04-18 16:05:57 +08:00
b612 4f760f2807 fix: 修复 dedicated bulk attach 竞态并优化 short write 补写路径
- 客户端 dedicated attach 回复改为精确读取单帧,避免 attach reply 与后续 NBR1 数据粘连后被误解析
  - 服务端 accepted attach 改为先 detach transport,再直接回 attach reply,随后立即切入 dedicated bulk read loop
  - transport 读循环在 stop 或 transport ownership 失效后不再继续上推已读数据,避免 handoff 后首包被旧 reader 吃掉
  - dedicated bulk record 写路径改为 full-write,消除 short write 导致的 invalid bulk fast payload
  - 优化 vectored write 补写策略:先尝试一次 writev,未写完时直接顺序补完剩余 buffers,减少重复 WriteTo 开销
  - 放宽 vectored write 能力识别,支持通过 UnwrapConn/WriteBuffers 命中 fast path
  - 修复 dedicated batch 排队路径 payload 复用问题,改为深拷贝 queued items
  - 补齐 dedicated attach、short write、payload clone、transport stop/handoff 等回归测试
2026-04-16 17:27:48 +08:00
b612 7ed3dd5b37 feat: 完善 RecordStream 的协议协商、运行观测与文档说明
- 将 RecordStream 出站路径收敛为单 writer loop
  - 支持在 batch header 中 piggyback AckSeq,保留独立 ack 作为兼容回退
  - 增加 record stream 打开阶段能力协商,支持 mixed-version peer 自动降级
  - 补充 RecordSnapshot 与 diagnostics summary 的 record-plane 观测项
  - 增加 batch/ack/error frame、piggyback ack、barrier 等待拆分与 apply backlog 指标
  - 收紧 TransportConn detach 后的 runtime snapshot 语义
  - 补充 README 中的 RecordStream 语义、兼容行为与诊断快照说明
  - 补充相关单测与 race 回归验证
2026-04-15 19:52:45 +08:00
108 changed files with 18544 additions and 1331 deletions
+100
View File
@@ -9,6 +9,7 @@
- 记录流数据面:`OpenRecordStream` - 记录流数据面:`OpenRecordStream`
- 批量数据面:`OpenBulk``shared` / `dedicated` - 批量数据面:`OpenBulk``shared` / `dedicated`
- 文件传输内核:transfer control / progress / resume - 文件传输内核:transfer control / progress / resume
- 观测面:runtime snapshot / diagnostics summary
- 会话模型:`LogicalConn`(逻辑会话)与 `TransportConn`(物理承载)分离 - 会话模型:`LogicalConn`(逻辑会话)与 `TransportConn`(物理承载)分离
## 版本要求 ## 版本要求
@@ -24,6 +25,21 @@
未配置时会返回 `errModernPSKRequired` 未配置时会返回 `errModernPSKRequired`
## 安全模式选择
- `UseModernPSKClient` / `UseModernPSKServer`
- bootstrap 和稳态传输都由 `notify` 自己保护
- 适合默认场景
- 支持 peer attach 显式认证、抗重放,以及在需要时协商前向保密
- `UsePSKOverExternalTransportClient` / `UsePSKOverExternalTransportServer`
- bootstrap 仍用 PSK 做认证
- 稳态阶段信任外部物理通道,不再做 `notify` 内层加密
- 适合 `tls.Conn` 或调用方自认可信的外部通道
- 不支持 `RequireForwardSecrecy`
- `UseNestedSecurityClient` / `UseNestedSecurityServer`
- 外层已有可信通道,但仍保留 `notify` 内层保护
- 适合需要“外层可信 + 内层独立保护”的场景
## 快速开始 ## 快速开始
服务端: 服务端:
@@ -82,6 +98,87 @@ func main() {
} }
``` ```
## 连接入口与物理连接语义
- `Connect` / `ConnectTimeout`
-`notify` 自己拨号
- 支持重连,也支持 dedicated bulk 额外 sidecar 连接
- `ConnectByFactory`
- 调用方提供 `dialFn`
- `notify` 会在需要时再次调用 `dialFn`,因此仍支持重连和 dedicated bulk
- `ConnectByConn`
- 调用方注入一个已经建立好的 `net.Conn`
- 该模式被视为“单物理连接模式”
- `OpenDedicatedBulk` 会直接返回错误
- `OpenBulk` 使用 `auto` 模式时会自动回退到 `shared`
- `ListenByListener`
- 服务端复用调用方提供的 `net.Listener`
- 适合需要和现有 listener 栈整合的场景
`dedicated bulk` 依赖额外物理连接,因此只适用于可再次拨号的 transport source。
## Peer Attach 安全策略
可通过 `SetPeerAttachSecurityConfig` 配置逻辑会话 attach 阶段的额外保护。
- `RequireExplicitAuth`
- 要求 peer attach 使用显式认证
- `RequireChannelBinding`
- 要求 attach 绑定到底层可信通道
- 启用后会隐式要求显式认证
- `ChannelBinding`
- 由调用方提供 channel binding 提取函数
- 适合外层 TLS 或其他可信通道整合
- `ReplayWindow` / `ReplayCapacity`
- 控制 attach 抗重放窗口和缓存容量
如果你选择 `UsePSKOverExternalTransport*`,并且希望 attach 阶段显式绑定到外层可信信道,建议同时配置 channel binding。
## RecordStream 说明
`RecordStream` 构建在 `Stream` 之上,适合“有边界的顺序记录”场景。
- 写入入口:`OpenRecordStream``WriteRecord`
- 接收入口:`ReadRecord`
- 确认入口:`AckRecord`
- 检查点:`Barrier``BarrierTo`
- 错误回包:`RecordFailure`
确认语义:
- `AckRecord` 表示“该序号及其之前的连续记录已完成 apply”,不是“已收到”
- `Barrier` / `BarrierTo` 等待的是对端 `apply-complete` 的最大连续序号
- `RecordFailure` 会返回 `FailedSeq``Code``Retryable``Message`
兼容与传输:
- record stream 在打开阶段协商 batch ack 能力
- 双端都支持时,累计 `AckSeq` 会随 batch header piggyback 发送
- 对端不支持时,自动回退到独立 ack frame
- mixed-version peer 可以互通,不要求双方同时升级
## 诊断快照
顶层诊断入口:
- `GetClientDiagnosticsSnapshot`
- `GetServerDiagnosticsSnapshot`
快照内容:
- 会话运行态:client / server runtime
- 数据面快照:`StreamSnapshot``BulkSnapshot``RecordSnapshot`
- 文件传输快照:`TransferSnapshot`
- 汇总视图:`DiagnosticsSummary`
`RecordSnapshot` / `DiagnosticsSummary.RecordTelemetry` 当前覆盖:
- batch / ack / error frame 收发计数
- piggyback ack 命中计数
- barrier 等待时间拆分:`flush` / `apply`
- `outstanding records/bytes`
- `pending apply / pending ack / peak pending apply`
## 传输与 IPC ## 传输与 IPC
- `tcp` - `tcp`
@@ -100,6 +197,9 @@ func main() {
- 共享密钥派生(Argon2id - 共享密钥派生(Argon2id
- 消息层加密(AES-GCM - 消息层加密(AES-GCM
- `stream` / `bulk` fast path 复用现代编码栈 - `stream` / `bulk` fast path 复用现代编码栈
- peer attach 显式认证 / 抗重放
- 可选 channel binding
- 可选前向保密(`UseModernPSK*` / `UseNestedSecurity*`
兼容入口仍保留,但属于历史路径: 兼容入口仍保留,但属于历史路径:
+16
View File
@@ -0,0 +1,16 @@
package notify
import (
"os"
"testing"
)
const benchmarkTCPListenAddrEnv = "NOTIFY_BENCH_TCP_LISTEN_ADDR"
func benchmarkTCPListenAddr(tb testing.TB) string {
tb.Helper()
if addr := os.Getenv(benchmarkTCPListenAddrEnv); addr != "" {
return addr
}
return "127.0.0.1:0"
}
+378
View File
@@ -0,0 +1,378 @@
package notify
import (
"context"
"errors"
"io"
"math/rand"
"net"
"os"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
)
const (
benchmarkTCPProxyEnableEnv = "NOTIFY_BENCH_TCP_PROXY_ENABLE"
benchmarkTCPProxyListenAddrEnv = "NOTIFY_BENCH_TCP_PROXY_LISTEN_ADDR"
benchmarkTCPProxyRateMbitEnv = "NOTIFY_BENCH_TCP_PROXY_RATE_MBIT"
benchmarkTCPProxyDelayMsEnv = "NOTIFY_BENCH_TCP_PROXY_DELAY_MS"
benchmarkTCPProxyLossPctEnv = "NOTIFY_BENCH_TCP_PROXY_LOSS_PCT"
benchmarkTCPProxyLossPenaltyMsEnv = "NOTIFY_BENCH_TCP_PROXY_LOSS_PENALTY_MS"
benchmarkTCPProxySeedEnv = "NOTIFY_BENCH_TCP_PROXY_SEED"
benchmarkTCPProxyDefaultListenAddr = "127.0.0.1:0"
)
const benchmarkTCPProxyDefaultLossPenalty = 120 * time.Millisecond
type benchmarkTCPProxyConfig struct {
Enabled bool
ListenAddr string
RateBytesPS float64
Delay time.Duration
LossPct float64
LossPenalty time.Duration
Seed int64
}
type benchmarkTCPProxy struct {
listener net.Listener
target string
cfg benchmarkTCPProxyConfig
closed atomic.Bool
wg sync.WaitGroup
}
func benchmarkTCPDialAddr(tb testing.TB, targetAddr string) string {
tb.Helper()
if targetAddr == "" {
return targetAddr
}
cfg := benchmarkTCPProxyConfigFromEnv(tb)
if !cfg.Enabled {
return targetAddr
}
proxy, err := startBenchmarkTCPProxy(targetAddr, cfg)
if err != nil {
tb.Fatalf("start benchmark tcp proxy failed: %v", err)
}
tb.Cleanup(proxy.stop)
if proxy.listener == nil || proxy.listener.Addr() == nil {
tb.Fatalf("benchmark tcp proxy listener is nil")
}
return proxy.listener.Addr().String()
}
func benchmarkTCPProxyConfigFromEnv(tb testing.TB) benchmarkTCPProxyConfig {
tb.Helper()
enabled, ok, err := parseBoolEnv(benchmarkTCPProxyEnableEnv)
if err != nil {
tb.Fatalf("invalid %s: %v", benchmarkTCPProxyEnableEnv, err)
}
if !ok || !enabled {
return benchmarkTCPProxyConfig{}
}
cfg := benchmarkTCPProxyConfig{
Enabled: true,
ListenAddr: benchmarkTCPProxyDefaultListenAddr,
Seed: 1,
}
if listen := os.Getenv(benchmarkTCPProxyListenAddrEnv); listen != "" {
cfg.ListenAddr = listen
}
rateMbit, ok, err := parseFloatEnv(benchmarkTCPProxyRateMbitEnv)
if err != nil {
tb.Fatalf("invalid %s: %v", benchmarkTCPProxyRateMbitEnv, err)
}
if ok && rateMbit > 0 {
cfg.RateBytesPS = rateMbit * 1000 * 1000 / 8
}
delayMs, ok, err := parseIntEnv(benchmarkTCPProxyDelayMsEnv)
if err != nil {
tb.Fatalf("invalid %s: %v", benchmarkTCPProxyDelayMsEnv, err)
}
if ok && delayMs > 0 {
cfg.Delay = time.Duration(delayMs) * time.Millisecond
}
lossPct, ok, err := parseFloatEnv(benchmarkTCPProxyLossPctEnv)
if err != nil {
tb.Fatalf("invalid %s: %v", benchmarkTCPProxyLossPctEnv, err)
}
if ok {
if lossPct < 0 || lossPct > 100 {
tb.Fatalf("%s must be in [0, 100], got %.4f", benchmarkTCPProxyLossPctEnv, lossPct)
}
cfg.LossPct = lossPct
}
lossPenaltyMs, ok, err := parseIntEnv(benchmarkTCPProxyLossPenaltyMsEnv)
if err != nil {
tb.Fatalf("invalid %s: %v", benchmarkTCPProxyLossPenaltyMsEnv, err)
}
if ok && lossPenaltyMs > 0 {
cfg.LossPenalty = time.Duration(lossPenaltyMs) * time.Millisecond
}
if cfg.LossPct > 0 && cfg.LossPenalty <= 0 {
cfg.LossPenalty = benchmarkTCPProxyDefaultLossPenalty
}
seed, ok, err := parseInt64Env(benchmarkTCPProxySeedEnv)
if err != nil {
tb.Fatalf("invalid %s: %v", benchmarkTCPProxySeedEnv, err)
}
if ok {
cfg.Seed = seed
}
return cfg
}
func parseBoolEnv(name string) (value bool, ok bool, err error) {
raw := os.Getenv(name)
if raw == "" {
return false, false, nil
}
value, err = strconv.ParseBool(raw)
return value, true, err
}
func parseFloatEnv(name string) (value float64, ok bool, err error) {
raw := os.Getenv(name)
if raw == "" {
return 0, false, nil
}
value, err = strconv.ParseFloat(raw, 64)
return value, true, err
}
func parseIntEnv(name string) (value int, ok bool, err error) {
raw := os.Getenv(name)
if raw == "" {
return 0, false, nil
}
value, err = strconv.Atoi(raw)
return value, true, err
}
func parseInt64Env(name string) (value int64, ok bool, err error) {
raw := os.Getenv(name)
if raw == "" {
return 0, false, nil
}
value, err = strconv.ParseInt(raw, 10, 64)
return value, true, err
}
func startBenchmarkTCPProxy(targetAddr string, cfg benchmarkTCPProxyConfig) (*benchmarkTCPProxy, error) {
if targetAddr == "" {
return nil, errors.New("benchmark tcp proxy target addr is empty")
}
listenAddr := cfg.ListenAddr
if listenAddr == "" {
listenAddr = benchmarkTCPProxyDefaultListenAddr
}
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
return nil, err
}
proxy := &benchmarkTCPProxy{
listener: listener,
target: targetAddr,
cfg: cfg,
}
proxy.wg.Add(1)
go proxy.runAcceptLoop()
return proxy, nil
}
func (p *benchmarkTCPProxy) runAcceptLoop() {
defer p.wg.Done()
for {
conn, err := p.listener.Accept()
if err != nil {
if p.closed.Load() || errors.Is(err, net.ErrClosed) {
return
}
continue
}
p.wg.Add(1)
go func(clientConn net.Conn) {
defer p.wg.Done()
p.handleConn(clientConn)
}(conn)
}
}
func (p *benchmarkTCPProxy) handleConn(clientConn net.Conn) {
if p == nil || clientConn == nil {
return
}
targetConn, err := net.Dial("tcp", p.target)
if err != nil {
_ = clientConn.Close()
return
}
if tcpConn, ok := clientConn.(*net.TCPConn); ok {
_ = tcpConn.SetNoDelay(true)
}
if tcpConn, ok := targetConn.(*net.TCPConn); ok {
_ = tcpConn.SetNoDelay(true)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errCh := make(chan error, 2)
leftRng := rand.New(rand.NewSource(p.cfg.Seed + 101))
rightRng := rand.New(rand.NewSource(p.cfg.Seed + 202))
go func() {
errCh <- benchmarkRelayStream(ctx, targetConn, clientConn, p.cfg, leftRng)
}()
go func() {
errCh <- benchmarkRelayStream(ctx, clientConn, targetConn, p.cfg, rightRng)
}()
_ = <-errCh
cancel()
_ = clientConn.Close()
_ = targetConn.Close()
<-errCh
}
func (p *benchmarkTCPProxy) stop() {
if p == nil {
return
}
if p.closed.CompareAndSwap(false, true) {
if p.listener != nil {
_ = p.listener.Close()
}
}
p.wg.Wait()
}
func benchmarkRelayStream(ctx context.Context, dst net.Conn, src net.Conn, cfg benchmarkTCPProxyConfig, rng *rand.Rand) error {
if dst == nil || src == nil {
return net.ErrClosed
}
chunkCh := make(chan benchmarkRelayChunk, 128)
readerErrCh := make(chan error, 1)
go benchmarkRelayReader(ctx, src, chunkCh, readerErrCh)
scheduler := benchmarkRelayScheduler{rateBytesPS: cfg.RateBytesPS}
for chunk := range chunkCh {
sendAt := scheduler.schedule(chunk.readAt, len(chunk.payload), cfg.Delay)
if cfg.LossPct > 0 && cfg.LossPenalty > 0 && rng != nil && rng.Float64()*100 < cfg.LossPct {
sendAt = sendAt.Add(cfg.LossPenalty)
}
if waitErr := benchmarkSleepUntilContext(ctx, sendAt); waitErr != nil {
return waitErr
}
if writeErr := writeAllToConn(dst, chunk.payload); writeErr != nil {
return writeErr
}
}
readerErr := <-readerErrCh
if errors.Is(readerErr, io.EOF) {
return nil
}
return readerErr
}
type benchmarkRelayChunk struct {
readAt time.Time
payload []byte
}
func benchmarkRelayReader(ctx context.Context, src net.Conn, chunkCh chan<- benchmarkRelayChunk, errCh chan<- error) {
defer close(chunkCh)
if src == nil {
errCh <- net.ErrClosed
return
}
buf := make([]byte, 64*1024)
for {
n, err := src.Read(buf)
if n > 0 {
chunk := benchmarkRelayChunk{
readAt: time.Now(),
payload: append([]byte(nil), buf[:n]...),
}
select {
case <-ctx.Done():
errCh <- ctx.Err()
return
case chunkCh <- chunk:
}
}
if err != nil {
errCh <- err
return
}
}
}
type benchmarkRelayScheduler struct {
rateBytesPS float64
nextAt time.Time
}
func (s *benchmarkRelayScheduler) schedule(readAt time.Time, bytes int, delay time.Duration) time.Time {
if s == nil {
return readAt
}
sendAt := readAt
if delay > 0 {
sendAt = sendAt.Add(delay)
}
if !s.nextAt.IsZero() && s.nextAt.After(sendAt) {
sendAt = s.nextAt
}
if s.rateBytesPS > 0 && bytes > 0 {
duration := time.Duration(float64(bytes) / s.rateBytesPS * float64(time.Second))
if duration <= 0 {
duration = time.Nanosecond
}
s.nextAt = sendAt.Add(duration)
} else {
s.nextAt = sendAt
}
return sendAt
}
func benchmarkSleepUntilContext(ctx context.Context, at time.Time) error {
wait := time.Until(at)
if wait <= 0 {
return nil
}
return benchmarkSleepContext(ctx, wait)
}
func benchmarkSleepContext(ctx context.Context, d time.Duration) error {
if d <= 0 {
return nil
}
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
func writeAllToConn(conn net.Conn, payload []byte) error {
for len(payload) > 0 {
n, err := conn.Write(payload)
if n > 0 {
payload = payload[n:]
}
if err != nil {
return err
}
if n == 0 {
return io.ErrNoProgress
}
}
return nil
}
+42
View File
@@ -0,0 +1,42 @@
package notify
import "testing"
type benchmarkTransportSecurityMode string
const (
benchmarkTransportSecurityModernPSK benchmarkTransportSecurityMode = "modern_psk"
benchmarkTransportSecurityTrustedRaw benchmarkTransportSecurityMode = "trusted_raw"
)
func benchmarkApplyServerTransportSecurity(tb testing.TB, server *ServerCommon, mode benchmarkTransportSecurityMode) {
tb.Helper()
switch mode {
case benchmarkTransportSecurityModernPSK:
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
tb.Fatalf("UseModernPSKServer failed: %v", err)
}
case benchmarkTransportSecurityTrustedRaw:
if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
tb.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
}
default:
tb.Fatalf("unsupported benchmark transport security mode %q", mode)
}
}
func benchmarkApplyClientTransportSecurity(tb testing.TB, client *ClientCommon, mode benchmarkTransportSecurityMode) {
tb.Helper()
switch mode {
case benchmarkTransportSecurityModernPSK:
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
tb.Fatalf("UseModernPSKClient failed: %v", err)
}
case benchmarkTransportSecurityTrustedRaw:
if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
tb.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
}
default:
tb.Fatalf("unsupported benchmark transport security mode %q", mode)
}
}
+1111 -126
View File
File diff suppressed because it is too large Load Diff
+484 -62
View File
@@ -9,7 +9,9 @@ import (
) )
const ( const (
bulkBatchMaxPayloads = 16 bulkBatchMaxPayloads = 64
bulkBatchMaxPayloadBytes = bulkFastBatchMaxPlainBytes
bulkBatchMaxFlushDelay = 50 * time.Microsecond
) )
const ( const (
@@ -22,48 +24,161 @@ type bulkBatchRequestState struct {
value atomic.Int32 value atomic.Int32
} }
type bulkBatchCodec struct {
encodeSingle func(bulkFastFrame) ([]byte, func(), error)
encodeBatch func([]bulkFastFrame) ([]byte, func(), error)
}
type bulkBatchRequest struct { type bulkBatchRequest struct {
ctx context.Context ctx context.Context
payload []byte frames []bulkFastFrame
deadline time.Time fastPathVersion uint8
done chan error payloadOwned bool
state *bulkBatchRequestState deadline time.Time
done chan error
state *bulkBatchRequestState
release func()
}
type bulkBatchEncodedPayload struct {
payload []byte
release func()
}
func (p *bulkBatchEncodedPayload) done() {
if p == nil || p.release == nil {
return
}
p.release()
p.release = nil
} }
type bulkBatchSender struct { type bulkBatchSender struct {
binding *transportBinding binding *transportBinding
reqCh chan bulkBatchRequest codec bulkBatchCodec
stopCh chan struct{} writeTimeoutProvider func() time.Duration
doneCh chan struct{} reqCh chan bulkBatchRequest
stopCh chan struct{}
doneCh chan struct{}
stopOnce sync.Once stopOnce sync.Once
flushMu sync.Mutex
queued atomic.Int64
errMu sync.Mutex errMu sync.Mutex
err error err error
} }
func newBulkBatchSender(binding *transportBinding) *bulkBatchSender { func newBulkBatchSender(binding *transportBinding, codec bulkBatchCodec, writeTimeoutProvider func() time.Duration) *bulkBatchSender {
sender := &bulkBatchSender{ sender := &bulkBatchSender{
binding: binding, binding: binding,
reqCh: make(chan bulkBatchRequest, bulkBatchMaxPayloads*4), codec: codec,
stopCh: make(chan struct{}), writeTimeoutProvider: writeTimeoutProvider,
doneCh: make(chan struct{}), reqCh: make(chan bulkBatchRequest, bulkBatchMaxPayloads*4),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
} }
go sender.run() go sender.run()
return sender return sender
} }
func (s *bulkBatchSender) submit(ctx context.Context, payload []byte) error { func (s *bulkBatchSender) submitData(ctx context.Context, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
return s.submitFramesOwned(ctx, []bulkFastFrame{{
Type: bulkFastPayloadTypeData,
DataID: dataID,
Seq: seq,
Payload: payload,
}}, fastPathVersion, false)
}
func (s *bulkBatchSender) submitControl(ctx context.Context, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
return s.submitFramesOwned(ctx, []bulkFastFrame{{
Type: frameType,
Flags: flags,
DataID: dataID,
Seq: seq,
Payload: payload,
}}, fastPathVersion, false)
}
func (s *bulkBatchSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, fastPathVersion uint8, payload []byte, chunkSize int, payloadOwned bool) (int, error) {
if s == nil {
return 0, errTransportDetached
}
if len(payload) == 0 {
return 0, nil
}
if chunkSize <= 0 {
chunkSize = defaultBulkChunkSize
}
written := 0
seq := startSeq
for written < len(payload) {
var batch [bulkFastBatchMaxItems]bulkFastFrame
frames := batch[:0]
batchBytes := bulkFastBatchHeaderLen
start := written
for written < len(payload) && len(frames) < bulkFastBatchMaxItems {
end := written + chunkSize
if end > len(payload) {
end = len(payload)
}
frame := bulkFastFrame{
Type: bulkFastPayloadTypeData,
DataID: dataID,
Seq: seq,
Payload: payload[written:end],
}
frameLen := bulkFastBatchFrameLen(frame)
if len(frames) > 0 && batchBytes+frameLen > bulkFastBatchMaxPlainBytes {
break
}
frames = append(frames, frame)
batchBytes += frameLen
seq++
written = end
}
if len(frames) == 0 {
end := written + chunkSize
if end > len(payload) {
end = len(payload)
}
frames = append(frames, bulkFastFrame{
Type: bulkFastPayloadTypeData,
DataID: dataID,
Seq: seq,
Payload: payload[written:end],
})
seq++
written = end
}
if err := s.submitFramesOwned(ctx, frames, fastPathVersion, payloadOwned); err != nil {
return start, err
}
}
return written, nil
}
func (s *bulkBatchSender) submitFrames(ctx context.Context, frames []bulkFastFrame, fastPathVersion uint8) error {
return s.submitFramesOwned(ctx, frames, fastPathVersion, false)
}
func (s *bulkBatchSender) submitFramesOwned(ctx context.Context, frames []bulkFastFrame, fastPathVersion uint8, payloadOwned bool) error {
if s == nil { if s == nil {
return errTransportDetached return errTransportDetached
} }
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
if len(frames) == 0 {
return nil
}
req := bulkBatchRequest{ req := bulkBatchRequest{
ctx: ctx, ctx: ctx,
payload: payload, frames: frames,
done: make(chan error, 1), fastPathVersion: normalizeBulkFastPathVersion(fastPathVersion),
state: &bulkBatchRequestState{}, payloadOwned: payloadOwned,
done: make(chan error, 1),
state: &bulkBatchRequestState{},
} }
if deadline, ok := ctx.Deadline(); ok { if deadline, ok := ctx.Deadline(); ok {
req.deadline = deadline req.deadline = deadline
@@ -71,10 +186,25 @@ func (s *bulkBatchSender) submit(ctx context.Context, payload []byte) error {
if err := s.errSnapshot(); err != nil { if err := s.errSnapshot(); err != nil {
return err return err
} }
if s.shouldDirectSubmit(req) {
if submitted, err := s.tryDirectSubmit(req); submitted {
return err
}
}
req = cloneQueuedBulkBatchRequest(req)
s.queued.Add(1)
select { select {
case <-ctx.Done(): case <-ctx.Done():
s.queued.Add(-1)
if req.release != nil {
req.release()
}
return normalizeStreamDeadlineError(ctx.Err()) return normalizeStreamDeadlineError(ctx.Err())
case <-s.stopCh: case <-s.stopCh:
s.queued.Add(-1)
if req.release != nil {
req.release()
}
return s.stoppedErr() return s.stoppedErr()
case s.reqCh <- req: case s.reqCh <- req:
} }
@@ -89,6 +219,55 @@ func (s *bulkBatchSender) submit(ctx context.Context, payload []byte) error {
} }
} }
func (s *bulkBatchSender) shouldDirectSubmit(req bulkBatchRequest) bool {
if len(req.frames) == 0 {
return false
}
return !bulkBatchRequestSupportsSharedSuperBatch(req)
}
func (s *bulkBatchSender) tryDirectSubmit(req bulkBatchRequest) (bool, error) {
if s == nil {
return true, errTransportDetached
}
if err := s.errSnapshot(); err != nil {
return true, err
}
select {
case <-req.ctx.Done():
return true, normalizeStreamDeadlineError(req.ctx.Err())
case <-s.stopCh:
return true, s.stoppedErr()
default:
}
if s.queued.Load() != 0 {
return false, nil
}
if !s.flushMu.TryLock() {
return false, nil
}
defer s.flushMu.Unlock()
if s.queued.Load() != 0 {
return false, nil
}
if err := s.errSnapshot(); err != nil {
return true, err
}
if !req.tryStart() {
return true, req.canceledErr()
}
if err := req.contextErr(); err != nil {
return true, err
}
err := s.flush([]bulkBatchRequest{req})
if err != nil {
s.setErr(err)
s.failPending(err)
return true, err
}
return true, nil
}
func (s *bulkBatchSender) run() { func (s *bulkBatchSender) run() {
defer close(s.doneCh) defer close(s.doneCh)
for { for {
@@ -97,34 +276,83 @@ func (s *bulkBatchSender) run() {
return return
} }
batch := []bulkBatchRequest{req} batch := []bulkBatchRequest{req}
batchBytes := bulkBatchRequestApproxBytes(req)
timer := (*time.Timer)(nil)
timerCh := (<-chan time.Time)(nil)
if bulkBatchShouldWaitForMore(batch, batchBytes) {
timer = time.NewTimer(bulkBatchMaxFlushDelay)
timerCh = timer.C
}
drain: drain:
for len(batch) < bulkBatchMaxPayloads { for len(batch) < bulkBatchMaxPayloads && batchBytes < bulkBatchMaxPayloadBytes {
if timerCh == nil {
select {
case <-s.stopCh:
s.failPending(s.stoppedErr())
return
case next := <-s.reqCh:
batch = append(batch, next)
batchBytes += bulkBatchRequestApproxBytes(next)
default:
break drain
}
continue
}
select { select {
case <-s.stopCh: case <-s.stopCh:
if timer != nil {
timer.Stop()
}
s.failPending(s.stoppedErr()) s.failPending(s.stoppedErr())
return return
case next := <-s.reqCh: case next := <-s.reqCh:
batch = append(batch, next) batch = append(batch, next)
default: batchBytes += bulkBatchRequestApproxBytes(next)
case <-timerCh:
timerCh = nil
break drain break drain
} }
} }
active, payloads := activeBulkBatchRequests(batch) if timer != nil {
if !timer.Stop() && timerCh != nil {
select {
case <-timer.C:
default:
}
}
}
s.flushMu.Lock()
err := s.errSnapshot()
active := make([]bulkBatchRequest, 0, len(batch))
for _, item := range batch {
if !item.tryStart() {
s.finishRequest(item, item.canceledErr())
continue
}
if itemErr := item.contextErr(); itemErr != nil {
s.finishRequest(item, itemErr)
continue
}
active = append(active, item)
}
if len(active) == 0 { if len(active) == 0 {
s.flushMu.Unlock()
continue continue
} }
deadline := bulkBatchRequestsEarliestDeadline(active) if err == nil {
err := s.flush(payloads, deadline) err = s.flush(active)
}
s.flushMu.Unlock()
if err != nil { if err != nil {
s.setErr(err) s.setErr(err)
for _, item := range active { for _, item := range active {
item.done <- err s.finishRequest(item, err)
} }
s.failPending(err) s.failPending(err)
return return
} }
for _, item := range active { for _, item := range active {
item.done <- err s.finishRequest(item, nil)
} }
} }
} }
@@ -139,37 +367,6 @@ func (s *bulkBatchSender) nextRequest() (bulkBatchRequest, bool) {
} }
} }
func activeBulkBatchRequests(batch []bulkBatchRequest) ([]bulkBatchRequest, [][]byte) {
active := make([]bulkBatchRequest, 0, len(batch))
payloads := make([][]byte, 0, len(batch))
for _, item := range batch {
if !item.tryStart() {
item.done <- item.canceledErr()
continue
}
if err := item.contextErr(); err != nil {
item.done <- err
continue
}
active = append(active, item)
payloads = append(payloads, item.payload)
}
return active, payloads
}
func bulkBatchRequestsEarliestDeadline(batch []bulkBatchRequest) time.Time {
var deadline time.Time
for _, item := range batch {
if item.deadline.IsZero() {
continue
}
if deadline.IsZero() || item.deadline.Before(deadline) {
deadline = item.deadline
}
}
return deadline
}
func (r bulkBatchRequest) contextErr() error { func (r bulkBatchRequest) contextErr() error {
if r.ctx == nil { if r.ctx == nil {
return nil return nil
@@ -203,7 +400,7 @@ func (r bulkBatchRequest) canceledErr() error {
return context.Canceled return context.Canceled
} }
func (s *bulkBatchSender) flush(payloads [][]byte, deadline time.Time) error { func (s *bulkBatchSender) flush(requests []bulkBatchRequest) error {
if s == nil || s.binding == nil { if s == nil || s.binding == nil {
return errTransportDetached return errTransportDetached
} }
@@ -211,9 +408,196 @@ func (s *bulkBatchSender) flush(payloads [][]byte, deadline time.Time) error {
if queue == nil { if queue == nil {
return errTransportFrameQueueUnavailable return errTransportFrameQueueUnavailable
} }
return s.binding.withConnWriteLockDeadline(deadline, func(conn net.Conn) error { payloads, err := s.encodeRequests(requests)
return writeFramedPayloadBatchUnlocked(conn, queue, payloads) if err != nil {
return err
}
defer func() {
for index := range payloads {
payloads[index].done()
}
}()
writeTimeout := s.transportWriteTimeout()
frames := make([][]byte, 0, len(payloads))
payloadBytes := 0
for _, payload := range payloads {
frames = append(frames, payload.payload)
payloadBytes += len(payload.payload)
}
started := time.Now()
err = s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error {
return writeFramedPayloadBatchUnlocked(conn, queue, frames)
}) })
s.binding.observeBulkAdaptivePayloadWrite(payloadBytes, time.Since(started), writeTimeout, err)
return err
}
func (s *bulkBatchSender) encodeRequests(requests []bulkBatchRequest) ([]bulkBatchEncodedPayload, error) {
if len(requests) == 0 {
return nil, nil
}
payloads := make([]bulkBatchEncodedPayload, 0, len(requests))
batch := make([]bulkFastFrame, 0, minInt(len(requests), bulkFastBatchMaxItems))
mixedBatchLimit := s.sharedMixedPayloadLimit()
batchRequestIndex := -1
batchDataID := uint64(0)
batchMixed := false
flushBatch := func() error {
if len(batch) == 0 {
return nil
}
payload, release, err := s.encodeBatch(batch)
if err != nil {
return err
}
payloads = append(payloads, bulkBatchEncodedPayload{payload: payload, release: release})
batch = batch[:0]
batchRequestIndex = -1
batchDataID = 0
batchMixed = false
return nil
}
batchBytes := bulkFastBatchHeaderLen
for reqIndex, req := range requests {
for _, frame := range req.frames {
if !bulkFastPathSupportsSharedBatch(req.fastPathVersion) {
if err := flushBatch(); err != nil {
return nil, err
}
batchBytes = bulkFastBatchHeaderLen
payload, release, err := s.encodeSingle(frame)
if err != nil {
return nil, err
}
payloads = append(payloads, bulkBatchEncodedPayload{payload: payload, release: release})
continue
}
frameLen := bulkFastBatchFrameLen(frame)
if frameLen+bulkFastBatchHeaderLen > bulkFastBatchMaxPlainBytes {
if err := flushBatch(); err != nil {
return nil, err
}
batchBytes = bulkFastBatchHeaderLen
payload, release, err := s.encodeSingle(frame)
if err != nil {
return nil, err
}
payloads = append(payloads, bulkBatchEncodedPayload{payload: payload, release: release})
continue
}
nextMixed := batchMixed
if len(batch) > 0 && (batchRequestIndex != reqIndex || (batchDataID != 0 && batchDataID != frame.DataID)) {
nextMixed = true
}
batchLimit := bulkFastBatchMaxPlainBytes
if nextMixed && mixedBatchLimit > 0 && mixedBatchLimit < batchLimit {
batchLimit = mixedBatchLimit
}
if len(batch) > 0 && (len(batch) >= bulkFastBatchMaxItems || batchBytes+frameLen > batchLimit) {
if err := flushBatch(); err != nil {
return nil, err
}
batchBytes = bulkFastBatchHeaderLen
nextMixed = false
}
if len(batch) == 0 {
batchRequestIndex = reqIndex
batchDataID = frame.DataID
batchMixed = false
} else if batchRequestIndex != reqIndex || (batchDataID != 0 && batchDataID != frame.DataID) {
batchMixed = true
}
batch = append(batch, frame)
batchBytes += frameLen
}
}
if err := flushBatch(); err != nil {
return nil, err
}
return payloads, nil
}
func bulkBatchRequestApproxBytes(req bulkBatchRequest) int {
total := 0
for _, frame := range req.frames {
total += bulkFastBatchFrameLen(frame)
}
return total
}
func bulkBatchRequestSupportsSharedSuperBatch(req bulkBatchRequest) bool {
if len(req.frames) == 0 || !bulkFastPathSupportsSharedBatch(req.fastPathVersion) {
return false
}
for _, frame := range req.frames {
switch frame.Type {
case bulkFastPayloadTypeData:
default:
return false
}
}
return true
}
func bulkBatchShouldWaitForMore(batch []bulkBatchRequest, batchBytes int) bool {
if bulkBatchMaxFlushDelay <= 0 || len(batch) == 0 {
return false
}
if len(batch) >= bulkBatchMaxPayloads || batchBytes >= bulkBatchMaxPayloadBytes {
return false
}
for _, req := range batch {
if !bulkBatchRequestSupportsSharedSuperBatch(req) {
return false
}
}
return true
}
func cloneQueuedBulkBatchRequest(req bulkBatchRequest) bulkBatchRequest {
if len(req.frames) == 0 || req.payloadOwned {
return req
}
clonedFrames := make([]bulkFastFrame, len(req.frames))
totalPayload := 0
for _, frame := range req.frames {
totalPayload += len(frame.Payload)
}
var payloadBuf []byte
if totalPayload > 0 {
payloadBuf = getBulkAsyncWritePayload(totalPayload)
req.release = func() {
putBulkAsyncWritePayload(payloadBuf)
}
}
offset := 0
for index, frame := range req.frames {
clonedFrames[index] = frame
if len(frame.Payload) == 0 {
clonedFrames[index].Payload = nil
continue
}
next := offset + len(frame.Payload)
clonedFrames[index].Payload = payloadBuf[offset:next]
copy(clonedFrames[index].Payload, frame.Payload)
offset = next
}
req.frames = clonedFrames
return req
}
func (s *bulkBatchSender) encodeSingle(frame bulkFastFrame) ([]byte, func(), error) {
if s == nil || s.codec.encodeSingle == nil {
return nil, nil, errTransportDetached
}
return s.codec.encodeSingle(frame)
}
func (s *bulkBatchSender) encodeBatch(frames []bulkFastFrame) ([]byte, func(), error) {
if len(frames) == 1 || s.codec.encodeBatch == nil {
return s.encodeSingle(frames[0])
}
return s.codec.encodeBatch(frames)
} }
func (s *bulkBatchSender) stop() { func (s *bulkBatchSender) stop() {
@@ -231,13 +615,23 @@ func (s *bulkBatchSender) failPending(err error) {
for { for {
select { select {
case item := <-s.reqCh: case item := <-s.reqCh:
item.done <- err s.finishRequest(item, err)
default: default:
return return
} }
} }
} }
func (s *bulkBatchSender) finishRequest(req bulkBatchRequest, err error) {
if s != nil {
s.queued.Add(-1)
}
if req.release != nil {
req.release()
}
req.done <- err
}
func (s *bulkBatchSender) setErr(err error) { func (s *bulkBatchSender) setErr(err error) {
if s == nil || err == nil { if s == nil || err == nil {
return return
@@ -264,3 +658,31 @@ func (s *bulkBatchSender) stoppedErr() error {
} }
return errTransportDetached return errTransportDetached
} }
func (s *bulkBatchSender) transportWriteDeadline() time.Time {
if s == nil || s.writeTimeoutProvider == nil {
return time.Time{}
}
return writeDeadlineFromTimeout(s.writeTimeoutProvider())
}
func (s *bulkBatchSender) transportWriteTimeout() time.Duration {
if s == nil || s.writeTimeoutProvider == nil {
return 0
}
return s.writeTimeoutProvider()
}
func (s *bulkBatchSender) sharedMixedPayloadLimit() int {
if s == nil || s.binding == nil {
return bulkAdaptiveSoftPayloadFallbackBytes
}
return s.binding.bulkAdaptiveSoftPayloadBytesSnapshot()
}
func minInt(a int, b int) int {
if a < b {
return a
}
return b
}
+86 -22
View File
@@ -38,7 +38,25 @@ func BenchmarkBulkTCPThroughput(b *testing.B) {
for _, tc := range cases { for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) { b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughput(b, tc.payloadSize, false) benchmarkBulkTCPThroughput(b, tc.payloadSize, false, benchmarkTransportSecurityModernPSK)
})
}
}
func BenchmarkBulkTCPThroughputTrustedRaw(b *testing.B) {
cases := []struct {
name string
payloadSize int
}{
{name: "chunk_256KiB", payloadSize: 256 * 1024},
{name: "chunk_512KiB", payloadSize: 512 * 1024},
{name: "chunk_768KiB", payloadSize: 768 * 1024},
{name: "chunk_1MiB", payloadSize: 1024 * 1024},
{name: "chunk_2MiB", payloadSize: 2 * 1024 * 1024},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughput(b, tc.payloadSize, false, benchmarkTransportSecurityTrustedRaw)
}) })
} }
} }
@@ -72,7 +90,25 @@ func BenchmarkBulkTCPThroughputDedicated(b *testing.B) {
for _, tc := range cases { for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) { b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughput(b, tc.payloadSize, true) benchmarkBulkTCPThroughput(b, tc.payloadSize, true, benchmarkTransportSecurityModernPSK)
})
}
}
func BenchmarkBulkTCPThroughputDedicatedTrustedRaw(b *testing.B) {
cases := []struct {
name string
payloadSize int
}{
{name: "chunk_256KiB", payloadSize: 256 * 1024},
{name: "chunk_512KiB", payloadSize: 512 * 1024},
{name: "chunk_768KiB", payloadSize: 768 * 1024},
{name: "chunk_1MiB", payloadSize: 1024 * 1024},
{name: "chunk_2MiB", payloadSize: 2 * 1024 * 1024},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughput(b, tc.payloadSize, true, benchmarkTransportSecurityTrustedRaw)
}) })
} }
} }
@@ -107,7 +143,25 @@ func BenchmarkBulkTCPThroughputConcurrent(b *testing.B) {
for _, tc := range cases { for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) { b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false) benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false, benchmarkTransportSecurityModernPSK)
})
}
}
func BenchmarkBulkTCPThroughputConcurrentTrustedRaw(b *testing.B) {
cases := []struct {
name string
payloadSize int
concurrency int
}{
{name: "bulks_2_512KiB", payloadSize: 512 * 1024, concurrency: 2},
{name: "bulks_4_512KiB", payloadSize: 512 * 1024, concurrency: 4},
{name: "bulks_2_1MiB", payloadSize: 1024 * 1024, concurrency: 2},
{name: "bulks_4_1MiB", payloadSize: 1024 * 1024, concurrency: 4},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false, benchmarkTransportSecurityTrustedRaw)
}) })
} }
} }
@@ -142,18 +196,34 @@ func BenchmarkBulkTCPThroughputConcurrentDedicated(b *testing.B) {
for _, tc := range cases { for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) { b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true) benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true, benchmarkTransportSecurityModernPSK)
}) })
} }
} }
func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) { func BenchmarkBulkTCPThroughputConcurrentDedicatedTrustedRaw(b *testing.B) {
cases := []struct {
name string
payloadSize int
concurrency int
}{
{name: "bulks_2_512KiB", payloadSize: 512 * 1024, concurrency: 2},
{name: "bulks_4_512KiB", payloadSize: 512 * 1024, concurrency: 4},
{name: "bulks_2_1MiB", payloadSize: 1024 * 1024, concurrency: 2},
{name: "bulks_4_1MiB", payloadSize: 1024 * 1024, concurrency: 4},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true, benchmarkTransportSecurityTrustedRaw)
})
}
}
func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool, securityMode benchmarkTransportSecurityMode) {
b.Helper() b.Helper()
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyServerTransportSecurity(b, server, securityMode)
b.Fatalf("UseModernPSKServer failed: %v", err)
}
acceptCh := make(chan BulkAcceptInfo, 1) acceptCh := make(chan BulkAcceptInfo, 1)
server.SetBulkHandler(func(info BulkAcceptInfo) error { server.SetBulkHandler(func(info BulkAcceptInfo) error {
@@ -161,7 +231,7 @@ func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
return nil return nil
}) })
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { if err := server.Listen("tcp", benchmarkTCPListenAddr(b)); err != nil {
b.Fatalf("server Listen failed: %v", err) b.Fatalf("server Listen failed: %v", err)
} }
b.Cleanup(func() { b.Cleanup(func() {
@@ -169,10 +239,8 @@ func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
}) })
client := NewClient().(*ClientCommon) client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyClientTransportSecurity(b, client, securityMode)
b.Fatalf("UseModernPSKClient failed: %v", err) if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
b.Fatalf("client Connect failed: %v", err) b.Fatalf("client Connect failed: %v", err)
} }
b.Cleanup(func() { b.Cleanup(func() {
@@ -241,16 +309,14 @@ func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
_ = bulk.Close() _ = bulk.Close()
} }
func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, dedicated bool) { func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, dedicated bool, securityMode benchmarkTransportSecurityMode) {
b.Helper() b.Helper()
if concurrency <= 0 { if concurrency <= 0 {
b.Fatal("concurrency must be > 0") b.Fatal("concurrency must be > 0")
} }
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyServerTransportSecurity(b, server, securityMode)
b.Fatalf("UseModernPSKServer failed: %v", err)
}
acceptCh := make(chan BulkAcceptInfo, concurrency*2) acceptCh := make(chan BulkAcceptInfo, concurrency*2)
server.SetBulkHandler(func(info BulkAcceptInfo) error { server.SetBulkHandler(func(info BulkAcceptInfo) error {
@@ -258,7 +324,7 @@ func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurr
return nil return nil
}) })
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { if err := server.Listen("tcp", benchmarkTCPListenAddr(b)); err != nil {
b.Fatalf("server Listen failed: %v", err) b.Fatalf("server Listen failed: %v", err)
} }
b.Cleanup(func() { b.Cleanup(func() {
@@ -266,10 +332,8 @@ func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurr
}) })
client := NewClient().(*ClientCommon) client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyClientTransportSecurity(b, client, securityMode)
b.Fatalf("UseModernPSKClient failed: %v", err) if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
b.Fatalf("client Connect failed: %v", err) b.Fatalf("client Connect failed: %v", err)
} }
b.Cleanup(func() { b.Cleanup(func() {
+119
View File
@@ -0,0 +1,119 @@
package notify
import (
"context"
"errors"
"testing"
"time"
)
func TestBulkOwnedChunkReleaseAfterRead(t *testing.T) {
bulk := newBulkHandle(context.Background(), newBulkRuntime("buffer-release-read"), clientFileScope(), BulkOpenRequest{
BulkID: "buffer-release-read",
DataID: 1,
}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
released := 0
if err := bulk.pushOwnedChunkWithReleaseNoReset([]byte("hello"), func() {
released++
}); err != nil {
t.Fatalf("pushOwnedChunkWithReleaseNoReset failed: %v", err)
}
buf := make([]byte, 5)
n, err := bulk.Read(buf)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if n != 5 || string(buf[:n]) != "hello" {
t.Fatalf("Read = %d %q, want 5 hello", n, string(buf[:n]))
}
if released != 1 {
t.Fatalf("release count = %d, want 1", released)
}
}
func TestBulkOwnedChunkReleaseOnReset(t *testing.T) {
bulk := newBulkHandle(context.Background(), newBulkRuntime("buffer-release-reset"), clientFileScope(), BulkOpenRequest{
BulkID: "buffer-release-reset",
DataID: 1,
}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
released := 0
if err := bulk.pushOwnedChunkWithReleaseNoReset([]byte("hello"), func() {
released++
}); err != nil {
t.Fatalf("pushOwnedChunkWithReleaseNoReset failed: %v", err)
}
bulk.markReset(errors.New("boom"))
if released != 1 {
t.Fatalf("release count = %d, want 1", released)
}
}
func TestBulkReadDoesNotBlockOnAsyncWindowRelease(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
releaseStarted := make(chan struct{})
releaseUnblock := make(chan struct{})
bulk := newBulkHandle(ctx, newBulkRuntime("buffer-release-async"), clientFileScope(), BulkOpenRequest{
BulkID: "buffer-release-async",
DataID: 1,
ChunkSize: 4,
WindowBytes: 4,
MaxInFlight: 1,
}, 0, nil, nil, 0, nil, nil, nil, nil, func(_ *bulkHandle, bytes int64, chunks int) error {
if bytes != 4 || chunks != 1 {
t.Fatalf("release = (%d,%d), want (4,1)", bytes, chunks)
}
close(releaseStarted)
<-releaseUnblock
return nil
})
if err := bulk.pushOwnedChunk([]byte("ping")); err != nil {
t.Fatalf("pushOwnedChunk failed: %v", err)
}
buf := make([]byte, 4)
doneCh := make(chan error, 1)
go func() {
n, err := bulk.Read(buf)
if err != nil {
doneCh <- err
return
}
if got, want := n, 4; got != want {
doneCh <- errors.New("unexpected read size")
return
}
doneCh <- nil
}()
select {
case err := <-doneCh:
if err != nil {
t.Fatalf("Read failed: %v", err)
}
case <-time.After(200 * time.Millisecond):
t.Fatal("Read should not block on async release sender")
}
select {
case <-releaseStarted:
case <-time.After(time.Second):
t.Fatal("window release sender did not start")
}
close(releaseUnblock)
cancel()
if bulk.releaseWorkerDone != nil {
select {
case <-bulk.releaseWorkerDone:
case <-time.After(time.Second):
t.Fatal("release worker did not exit")
}
}
}
+370 -43
View File
@@ -7,22 +7,25 @@ import (
) )
type BulkOpenRequest struct { type BulkOpenRequest struct {
BulkID string BulkID string
DataID uint64 DataID uint64
Range BulkRange FastPathVersion uint8
Metadata BulkMetadata Range BulkRange
ReadTimeout time.Duration Metadata BulkMetadata
WriteTimeout time.Duration ReadTimeout time.Duration
Dedicated bool WriteTimeout time.Duration
AttachToken string Dedicated bool
ChunkSize int DedicatedLaneID uint32
WindowBytes int AttachToken string
MaxInFlight int ChunkSize int
WindowBytes int
MaxInFlight int
} }
type BulkOpenResponse struct { type BulkOpenResponse struct {
BulkID string BulkID string
DataID uint64 DataID uint64
FastPathVersion uint8
Accepted bool Accepted bool
Dedicated bool Dedicated bool
AttachToken string AttachToken string
@@ -53,6 +56,18 @@ type BulkResetResponse struct {
Error string Error string
} }
type BulkReadyRequest struct {
BulkID string
DataID uint64
Error string
}
type BulkReadyResponse struct {
BulkID string
Accepted bool
Error string
}
type BulkReleaseRequest struct { type BulkReleaseRequest struct {
BulkID string BulkID string
DataID uint64 DataID uint64
@@ -73,6 +88,9 @@ func bindClientBulkControl(c *ClientCommon) {
c.SetLink(BulkResetSignalKey, func(msg *Message) { c.SetLink(BulkResetSignalKey, func(msg *Message) {
c.handleInboundBulkReset(msg) c.handleInboundBulkReset(msg)
}) })
c.SetLink(BulkReadySignalKey, func(msg *Message) {
c.handleInboundBulkReady(msg)
})
c.SetLink(BulkReleaseSignalKey, func(msg *Message) { c.SetLink(BulkReleaseSignalKey, func(msg *Message) {
c.handleInboundBulkRelease(msg) c.handleInboundBulkRelease(msg)
}) })
@@ -91,14 +109,188 @@ func bindServerBulkControl(s *ServerCommon) {
s.SetLink(BulkResetSignalKey, func(msg *Message) { s.SetLink(BulkResetSignalKey, func(msg *Message) {
s.handleInboundBulkReset(msg) s.handleInboundBulkReset(msg)
}) })
s.SetLink(BulkReadySignalKey, func(msg *Message) {
s.handleInboundBulkReady(msg)
})
s.SetLink(BulkReleaseSignalKey, func(msg *Message) { s.SetLink(BulkReleaseSignalKey, func(msg *Message) {
s.handleInboundBulkRelease(msg) s.handleInboundBulkRelease(msg)
}) })
} }
func bulkAcceptInfoFromClientBulk(bulk *bulkHandle) BulkAcceptInfo {
if bulk == nil {
return BulkAcceptInfo{}
}
return BulkAcceptInfo{
ID: bulk.ID(),
Range: bulk.Range(),
Metadata: bulk.Metadata(),
Dedicated: bulk.Dedicated(),
TransportGeneration: bulk.TransportGeneration(),
Bulk: bulk,
}
}
func bulkAcceptInfoFromServerBulk(bulk *bulkHandle, logical *LogicalConn, transport *TransportConn) BulkAcceptInfo {
if bulk == nil {
return BulkAcceptInfo{}
}
if logical == nil {
logical = bulk.LogicalConn()
}
if transport == nil {
transport = bulk.TransportConn()
}
return BulkAcceptInfo{
ID: bulk.ID(),
Range: bulk.Range(),
Metadata: bulk.Metadata(),
Dedicated: bulk.Dedicated(),
LogicalConn: logical,
TransportConn: transport,
TransportGeneration: bulk.TransportGeneration(),
Bulk: bulk,
}
}
func dispatchBulkAccept(handler func(BulkAcceptInfo) error, bulk *bulkHandle, info BulkAcceptInfo) error {
if bulk == nil {
return errBulkNotFound
}
if !bulk.markAcceptDispatched() {
return nil
}
var dispatchErr error
defer func() {
bulk.finishAcceptDispatch(dispatchErr)
}()
if handler == nil {
dispatchErr = errBulkHandlerNotConfigured
bulk.markReset(dispatchErr)
return dispatchErr
}
if err := handler(info); err != nil {
dispatchErr = err
bulk.markReset(dispatchErr)
return dispatchErr
}
return nil
}
func (c *ClientCommon) clientBulkAcceptReadyNotifier(bulk *bulkHandle) func(error) {
return func(readyErr error) {
if c == nil || bulk == nil {
return
}
req := BulkReadyRequest{
BulkID: bulk.ID(),
DataID: bulk.dataIDSnapshot(),
}
if readyErr != nil {
req.Error = readyErr.Error()
}
ctx, cancel := context.WithTimeout(context.Background(), defaultBulkAcceptReadyTimeout)
defer cancel()
if _, err := sendBulkReadyClient(ctx, c, req); err != nil && bulk.Context().Err() == nil {
bulk.markReset(err)
}
}
}
func sendBulkReadyServer(ctx context.Context, s *ServerCommon, logical *LogicalConn, transport *TransportConn, req BulkReadyRequest) error {
if s == nil {
return errBulkServerNil
}
if transport != nil {
if _, err := sendBulkReadyServerTransport(ctx, s, transport, req); err == nil {
return nil
} else if !errors.Is(err, errTransportDetached) && !errors.Is(err, errBulkTransportNil) {
return err
}
}
if logical == nil {
return errBulkLogicalConnNil
}
_, err := sendBulkReadyServerLogical(ctx, s, logical, req)
return err
}
func (s *ServerCommon) serverBulkAcceptReadyNotifier(bulk *bulkHandle, logical *LogicalConn, transport *TransportConn) func(error) {
return func(readyErr error) {
if s == nil || bulk == nil {
return
}
req := BulkReadyRequest{
BulkID: bulk.ID(),
DataID: bulk.dataIDSnapshot(),
}
if readyErr != nil {
req.Error = readyErr.Error()
}
ctx, cancel := context.WithTimeout(context.Background(), defaultBulkAcceptReadyTimeout)
defer cancel()
if err := sendBulkReadyServer(ctx, s, logical, transport, req); err != nil && bulk.Context().Err() == nil {
bulk.markReset(err)
}
}
}
func (c *ClientCommon) startClientBulkAcceptDispatch(bulk *bulkHandle) {
if c == nil || bulk == nil {
return
}
bulk.setAcceptNotify(c.clientBulkAcceptReadyNotifier(bulk))
go func() {
_ = c.dispatchClientBulkAccept(bulk)
}()
}
func (s *ServerCommon) startServerBulkAcceptDispatch(bulk *bulkHandle, logical *LogicalConn, transport *TransportConn) {
if s == nil || bulk == nil {
return
}
if logical == nil {
logical = bulk.LogicalConn()
}
if transport == nil {
transport = bulk.TransportConn()
}
bulk.setAcceptNotify(s.serverBulkAcceptReadyNotifier(bulk, logical, transport))
go func() {
_ = s.dispatchServerBulkAccept(bulk, logical, transport)
}()
}
func (c *ClientCommon) dispatchClientBulkAccept(bulk *bulkHandle) error {
if c == nil {
return errBulkClientNil
}
runtime := c.getBulkRuntime()
if runtime == nil {
return errBulkRuntimeNil
}
return dispatchBulkAccept(runtime.handlerSnapshot(), bulk, bulkAcceptInfoFromClientBulk(bulk))
}
func (s *ServerCommon) dispatchServerBulkAccept(bulk *bulkHandle, logical *LogicalConn, transport *TransportConn) error {
if s == nil {
return errBulkServerNil
}
runtime := s.getBulkRuntime()
if runtime == nil {
return errBulkRuntimeNil
}
return dispatchBulkAccept(runtime.handlerSnapshot(), bulk, bulkAcceptInfoFromServerBulk(bulk, logical, transport))
}
func (c *ClientCommon) handleInboundBulkOpen(msg *Message) { func (c *ClientCommon) handleInboundBulkOpen(msg *Message) {
req, err := decodeBulkOpenRequest(msg) req, err := decodeBulkOpenRequest(msg)
resp := BulkOpenResponse{BulkID: req.BulkID, DataID: req.DataID, Dedicated: req.Dedicated} resp := BulkOpenResponse{
BulkID: req.BulkID,
DataID: req.DataID,
FastPathVersion: negotiateBulkFastPathVersion(req.FastPathVersion),
Dedicated: req.Dedicated,
}
if err != nil { if err != nil {
resp.Error = err.Error() resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp) replyBulkControlIfNeeded(msg, resp)
@@ -133,13 +325,6 @@ func (c *ClientCommon) handleInboundBulkOpen(msg *Message) {
replyBulkControlIfNeeded(msg, resp) replyBulkControlIfNeeded(msg, resp)
return return
} }
handler := runtime.handlerSnapshot()
if handler == nil {
bulk.markReset(errBulkHandlerNotConfigured)
resp.Error = errBulkHandlerNotConfigured.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
if req.Dedicated { if req.Dedicated {
if err := c.attachDedicatedBulkSidecar(context.Background(), bulk); err != nil { if err := c.attachDedicatedBulkSidecar(context.Background(), bulk); err != nil {
bulk.markReset(err) bulk.markReset(err)
@@ -147,17 +332,14 @@ func (c *ClientCommon) handleInboundBulkOpen(msg *Message) {
replyBulkControlIfNeeded(msg, resp) replyBulkControlIfNeeded(msg, resp)
return return
} }
resp.Accepted = true
resp.DataID = bulk.dataIDSnapshot()
resp.TransportGeneration = bulk.TransportGeneration()
replyBulkControlIfNeeded(msg, resp)
c.startClientBulkAcceptDispatch(bulk)
return
} }
info := BulkAcceptInfo{ if err := c.dispatchClientBulkAccept(bulk); err != nil {
ID: bulk.ID(),
Range: bulk.Range(),
Metadata: bulk.Metadata(),
Dedicated: bulk.Dedicated(),
TransportGeneration: bulk.TransportGeneration(),
Bulk: bulk,
}
if err := handler(info); err != nil {
bulk.markReset(err)
resp.Error = err.Error() resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp) replyBulkControlIfNeeded(msg, resp)
return return
@@ -170,7 +352,12 @@ func (c *ClientCommon) handleInboundBulkOpen(msg *Message) {
func (s *ServerCommon) handleInboundBulkOpen(msg *Message) { func (s *ServerCommon) handleInboundBulkOpen(msg *Message) {
req, err := decodeBulkOpenRequest(msg) req, err := decodeBulkOpenRequest(msg)
resp := BulkOpenResponse{BulkID: req.BulkID, DataID: req.DataID, Dedicated: req.Dedicated} resp := BulkOpenResponse{
BulkID: req.BulkID,
DataID: req.DataID,
FastPathVersion: negotiateBulkFastPathVersion(req.FastPathVersion),
Dedicated: req.Dedicated,
}
if err != nil { if err != nil {
resp.Error = err.Error() resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp) replyBulkControlIfNeeded(msg, resp)
@@ -218,25 +405,24 @@ func (s *ServerCommon) handleInboundBulkOpen(msg *Message) {
replyBulkControlIfNeeded(msg, resp) replyBulkControlIfNeeded(msg, resp)
return return
} }
handler := runtime.handlerSnapshot() s.attachServerDedicatedSidecarIfExists(logical, bulk)
if handler == nil { if runtime.handlerSnapshot() == nil {
bulk.markReset(errBulkHandlerNotConfigured) bulk.markReset(errBulkHandlerNotConfigured)
resp.Error = errBulkHandlerNotConfigured.Error() resp.Error = errBulkHandlerNotConfigured.Error()
replyBulkControlIfNeeded(msg, resp) replyBulkControlIfNeeded(msg, resp)
return return
} }
info := BulkAcceptInfo{ if req.Dedicated {
ID: bulk.ID(), resp.Accepted = true
Range: bulk.Range(), resp.DataID = bulk.dataIDSnapshot()
Metadata: bulk.Metadata(), resp.TransportGeneration = bulk.TransportGeneration()
Dedicated: bulk.Dedicated(), replyBulkControlIfNeeded(msg, resp)
LogicalConn: logical, if bulk.dedicatedAttachedSnapshot() {
TransportConn: transport, s.startServerBulkAcceptDispatch(bulk, logical, messageTransportConnSnapshot(msg))
TransportGeneration: bulk.TransportGeneration(), }
Bulk: bulk, return
} }
if err := handler(info); err != nil { if err := s.dispatchServerBulkAccept(bulk, logical, transport); err != nil {
bulk.markReset(err)
resp.Error = err.Error() resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp) replyBulkControlIfNeeded(msg, resp)
return return
@@ -357,6 +543,41 @@ func (c *ClientCommon) handleInboundBulkRelease(msg *Message) {
bulk.releaseOutboundWindow(req.Bytes, req.Chunks) bulk.releaseOutboundWindow(req.Bytes, req.Chunks)
} }
func (c *ClientCommon) handleInboundBulkReady(msg *Message) {
req, err := decodeBulkReadyRequest(msg)
resp := BulkReadyResponse{BulkID: req.BulkID}
if err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
runtime := c.getBulkRuntime()
if runtime == nil {
resp.Error = errBulkRuntimeNil.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
bulk, ok := runtime.lookup(clientFileScope(), req.BulkID)
if !ok && req.DataID != 0 {
bulk, ok = runtime.lookupByDataID(clientFileScope(), req.DataID)
}
if !ok {
resp.Error = errBulkNotFound.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
if resp.BulkID == "" {
resp.BulkID = bulk.ID()
}
readyErr := bulkReadyRemoteError(req.Error)
bulk.markAcceptReady(readyErr)
if readyErr != nil {
bulk.markReset(readyErr)
}
resp.Accepted = true
replyBulkControlIfNeeded(msg, resp)
}
func (s *ServerCommon) handleInboundBulkReset(msg *Message) { func (s *ServerCommon) handleInboundBulkReset(msg *Message) {
req, err := decodeBulkResetRequest(msg) req, err := decodeBulkResetRequest(msg)
resp := BulkResetResponse{BulkID: req.BulkID} resp := BulkResetResponse{BulkID: req.BulkID}
@@ -390,6 +611,43 @@ func (s *ServerCommon) handleInboundBulkReset(msg *Message) {
replyBulkControlIfNeeded(msg, resp) replyBulkControlIfNeeded(msg, resp)
} }
func (s *ServerCommon) handleInboundBulkReady(msg *Message) {
req, err := decodeBulkReadyRequest(msg)
resp := BulkReadyResponse{BulkID: req.BulkID}
if err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
runtime := s.getBulkRuntime()
if runtime == nil {
resp.Error = errBulkRuntimeNil.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
logical := messageLogicalConnSnapshot(msg)
scope := serverFileScope(logical)
bulk, ok := runtime.lookup(scope, req.BulkID)
if !ok && req.DataID != 0 {
bulk, ok = runtime.lookupByDataID(scope, req.DataID)
}
if !ok {
resp.Error = errBulkNotFound.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
if resp.BulkID == "" {
resp.BulkID = bulk.ID()
}
readyErr := bulkReadyRemoteError(req.Error)
bulk.markAcceptReady(readyErr)
if readyErr != nil {
bulk.markReset(readyErr)
}
resp.Accepted = true
replyBulkControlIfNeeded(msg, resp)
}
func (s *ServerCommon) handleInboundBulkRelease(msg *Message) { func (s *ServerCommon) handleInboundBulkRelease(msg *Message) {
req, err := decodeBulkReleaseRequest(msg) req, err := decodeBulkReleaseRequest(msg)
if err != nil { if err != nil {
@@ -625,11 +883,26 @@ func decodeBulkReleaseRequest(msg *Message) (BulkReleaseRequest, error) {
return req, nil return req, nil
} }
func decodeBulkReadyRequest(msg *Message) (BulkReadyRequest, error) {
var req BulkReadyRequest
if msg == nil {
return BulkReadyRequest{}, errBulkIDEmpty
}
if err := msg.Value.Orm(&req); err != nil {
return BulkReadyRequest{}, err
}
if req.BulkID == "" && req.DataID == 0 {
return BulkReadyRequest{}, errBulkIDEmpty
}
return req, nil
}
func decodeBulkOpenResponse(msg Message) (BulkOpenResponse, error) { func decodeBulkOpenResponse(msg Message) (BulkOpenResponse, error) {
var resp BulkOpenResponse var resp BulkOpenResponse
if err := msg.Value.Orm(&resp); err != nil { if err := msg.Value.Orm(&resp); err != nil {
return BulkOpenResponse{}, err return BulkOpenResponse{}, err
} }
resp.FastPathVersion = normalizeBulkFastPathVersion(resp.FastPathVersion)
return resp, bulkControlResultError("open", resp.Accepted, resp.Error, nil) return resp, bulkControlResultError("open", resp.Accepted, resp.Error, nil)
} }
@@ -649,6 +922,14 @@ func decodeBulkResetResponse(msg Message) (BulkResetResponse, error) {
return resp, bulkControlResultError("reset", resp.Accepted, resp.Error, nil) return resp, bulkControlResultError("reset", resp.Accepted, resp.Error, nil)
} }
func decodeBulkReadyResponse(msg Message) (BulkReadyResponse, error) {
var resp BulkReadyResponse
if err := msg.Value.Orm(&resp); err != nil {
return BulkReadyResponse{}, err
}
return resp, bulkControlResultError("ready", resp.Accepted, resp.Error, nil)
}
func bulkControlResultError(op string, accepted bool, message string, callErr error) error { func bulkControlResultError(op string, accepted bool, message string, callErr error) error {
if callErr != nil { if callErr != nil {
return callErr return callErr
@@ -697,6 +978,52 @@ func bulkRemoteResetError(message string) error {
return errors.New(message) return errors.New(message)
} }
func bulkReadyRemoteError(message string) error {
if message == "" {
return nil
}
return bulkControlMessageError(message)
}
func sendBulkReadyClient(ctx context.Context, c Client, req BulkReadyRequest) (BulkReadyResponse, error) {
if c == nil {
return BulkReadyResponse{}, errBulkClientNil
}
msg, err := c.SendObjCtx(ctx, BulkReadySignalKey, req)
if err != nil {
return BulkReadyResponse{}, err
}
return decodeBulkReadyResponse(msg)
}
func sendBulkReadyServerLogical(ctx context.Context, s Server, logical *LogicalConn, req BulkReadyRequest) (BulkReadyResponse, error) {
if s == nil {
return BulkReadyResponse{}, errBulkServerNil
}
if logical == nil {
return BulkReadyResponse{}, errBulkLogicalConnNil
}
msg, err := s.SendObjCtxLogical(ctx, logical, BulkReadySignalKey, req)
if err != nil {
return BulkReadyResponse{}, err
}
return decodeBulkReadyResponse(msg)
}
func sendBulkReadyServerTransport(ctx context.Context, s Server, transport *TransportConn, req BulkReadyRequest) (BulkReadyResponse, error) {
if s == nil {
return BulkReadyResponse{}, errBulkServerNil
}
if transport == nil {
return BulkReadyResponse{}, errBulkTransportNil
}
msg, err := s.SendObjCtxTransport(ctx, transport, BulkReadySignalKey, req)
if err != nil {
return BulkReadyResponse{}, err
}
return decodeBulkReadyResponse(msg)
}
func bulkTransportGeneration(logical *LogicalConn, transport *TransportConn) uint64 { func bulkTransportGeneration(logical *LogicalConn, transport *TransportConn) uint64 {
return streamTransportGeneration(logical, transport) return streamTransportGeneration(logical, transport)
} }
+1019 -186
View File
File diff suppressed because it is too large Load Diff
+537
View File
@@ -0,0 +1,537 @@
package notify
import (
"bytes"
"context"
"encoding/binary"
"errors"
"net"
"testing"
"time"
"b612.me/stario"
)
type bulkAttachScriptConn struct {
readBuf *bytes.Reader
writeBuf bytes.Buffer
}
func newBulkAttachScriptConn(inbound []byte) *bulkAttachScriptConn {
return &bulkAttachScriptConn{
readBuf: bytes.NewReader(append([]byte(nil), inbound...)),
}
}
func (c *bulkAttachScriptConn) Read(p []byte) (int, error) { return c.readBuf.Read(p) }
func (c *bulkAttachScriptConn) Write(p []byte) (int, error) { return c.writeBuf.Write(p) }
func (c *bulkAttachScriptConn) Close() error { return nil }
func (c *bulkAttachScriptConn) LocalAddr() net.Addr { return bulkAttachTestAddr("local") }
func (c *bulkAttachScriptConn) RemoteAddr() net.Addr { return bulkAttachTestAddr("remote") }
func (c *bulkAttachScriptConn) SetDeadline(time.Time) error { return nil }
func (c *bulkAttachScriptConn) SetReadDeadline(time.Time) error {
return nil
}
func (c *bulkAttachScriptConn) SetWriteDeadline(time.Time) error {
return nil
}
func (c *bulkAttachScriptConn) writtenBytes() []byte {
return append([]byte(nil), c.writeBuf.Bytes()...)
}
type bulkAttachTestAddr string
func (a bulkAttachTestAddr) Network() string { return "tcp" }
func (a bulkAttachTestAddr) String() string { return string(a) }
func encodeDedicatedRecordForAttachTest(payload []byte) []byte {
out := make([]byte, bulkDedicatedRecordHeaderLen+len(payload))
copy(out[:4], bulkDedicatedRecordMagic)
binary.BigEndian.PutUint32(out[4:8], uint32(len(payload)))
copy(out[bulkDedicatedRecordHeaderLen:], payload)
return out
}
func TestSendDedicatedBulkAttachRequestKeepsCoalescedDedicatedPayloadUnread(t *testing.T) {
client := NewClient().(*ClientCommon)
UseLegacySecurityClient(client)
client.msgID = 100
bulk := newBulkHandle(context.Background(), newBulkRuntime("dedicated-attach-test"), clientFileScope(), BulkOpenRequest{
BulkID: "bulk-attach-test",
DataID: 1,
Dedicated: true,
AttachToken: "attach-token",
}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
encodedResp, err := client.sequenceEn(bulkAttachResponse{Accepted: true})
if err != nil {
t.Fatalf("encode bulkAttachResponse failed: %v", err)
}
replyFrame, err := encodeDirectSignalFrame(stario.NewQueue(), client.sequenceEn, client.msgEn, client.SecretKey, TransferMsg{
ID: 101,
Key: systemBulkAttachKey,
Value: encodedResp,
Type: MSG_SYS_REPLY,
})
if err != nil {
t.Fatalf("encode attach reply frame failed: %v", err)
}
dedicatedPayload := []byte("dedicated-tail-bytes")
conn := newBulkAttachScriptConn(append(replyFrame, encodeDedicatedRecordForAttachTest(dedicatedPayload)...))
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
resp, err := client.sendDedicatedBulkAttachRequest(ctx, conn, bulk)
if err != nil {
t.Fatalf("sendDedicatedBulkAttachRequest failed: %v", err)
}
if !resp.Accepted {
t.Fatalf("bulk attach response = %+v, want accepted", resp)
}
parsedReq := stario.NewQueue()
var reqMsg TransferMsg
if err := parsedReq.ParseMessageOwned(conn.writtenBytes(), "attach-request", func(msgq stario.MsgQueue) error {
transfer, err := decodeDirectSignalPayload(client.sequenceDe, client.msgDe, client.SecretKey, msgq.Msg)
if err != nil {
return err
}
reqMsg = transfer
return nil
}); err != nil {
t.Fatalf("parse written attach request failed: %v", err)
}
if reqMsg.Key != systemBulkAttachKey || reqMsg.Type != MSG_SYS_WAIT {
t.Fatalf("attach request message mismatch: %+v", reqMsg)
}
readPayload, err := readBulkDedicatedRecord(conn)
if err != nil {
t.Fatalf("readBulkDedicatedRecord after attach failed: %v", err)
}
if !bytes.Equal(readPayload, dedicatedPayload) {
t.Fatalf("dedicated payload mismatch: got %q want %q", string(readPayload), string(dedicatedPayload))
}
}
func TestSendDedicatedBulkAttachRequestUsesBootstrapProtectionEvenAfterSteadySwitch(t *testing.T) {
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
client.msgID = 100
alternate, err := deriveModernPSKProtectionProfile([]byte("notify-dedicated-attach-other-secret"), integrationModernPSKOptions(), ProtectionManaged)
if err != nil {
t.Fatalf("deriveModernPSKProtectionProfile(alternate) failed: %v", err)
}
client.setClientTransportProtectionProfile(alternate)
bulk := newBulkHandle(context.Background(), newBulkRuntime("dedicated-attach-bootstrap-test"), clientFileScope(), BulkOpenRequest{
BulkID: "bulk-attach-bootstrap-test",
DataID: 1,
Dedicated: true,
AttachToken: "attach-token",
}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
bootstrap := client.clientDedicatedBulkAttachTransportProtectionProfile()
encodedResp, err := client.sequenceEn(bulkAttachResponse{Accepted: true})
if err != nil {
t.Fatalf("encode bulkAttachResponse failed: %v", err)
}
replyFrame, err := encodeDirectSignalFrame(stario.NewQueue(), client.sequenceEn, bootstrap.msgEn, bootstrap.secretKey, TransferMsg{
ID: 101,
Key: systemBulkAttachKey,
Value: encodedResp,
Type: MSG_SYS_REPLY,
})
if err != nil {
t.Fatalf("encode attach reply frame failed: %v", err)
}
conn := newBulkAttachScriptConn(replyFrame)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
resp, err := client.sendDedicatedBulkAttachRequest(ctx, conn, bulk)
if err != nil {
t.Fatalf("sendDedicatedBulkAttachRequest failed: %v", err)
}
if !resp.Accepted {
t.Fatalf("bulk attach response = %+v, want accepted", resp)
}
parsedReq := stario.NewQueue()
var reqMsg TransferMsg
if err := parsedReq.ParseMessageOwned(conn.writtenBytes(), "attach-request", func(msgq stario.MsgQueue) error {
transfer, err := decodeDirectSignalPayload(client.sequenceDe, bootstrap.msgDe, bootstrap.secretKey, msgq.Msg)
if err != nil {
return err
}
reqMsg = transfer
return nil
}); err != nil {
t.Fatalf("parse written attach request with bootstrap profile failed: %v", err)
}
if reqMsg.Key != systemBulkAttachKey || reqMsg.Type != MSG_SYS_WAIT {
t.Fatalf("attach request message mismatch: %+v", reqMsg)
}
if err := parsedReq.ParseMessageOwned(conn.writtenBytes(), "attach-request-current", func(msgq stario.MsgQueue) error {
_, err := decodeDirectSignalPayload(client.sequenceDe, alternate.msgDe, alternate.secretKey, msgq.Msg)
return err
}); !errors.Is(err, errTransportPayloadDecryptFailed) {
t.Fatalf("decode written attach request with current steady profile error = %v, want %v", err, errTransportPayloadDecryptFailed)
}
}
func TestHandleBulkAttachSystemMessageAcceptedWritesDirectReplyBeforeDedicatedHandoff(t *testing.T) {
server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server)
sidecarLeft, sidecarRight := net.Pipe()
defer sidecarRight.Close()
current := server.bootstrapAcceptedLogical("dedicated-attach-current", nil, sidecarLeft)
if current == nil {
t.Fatal("bootstrapAcceptedLogical(current) should return logical")
}
target := server.bootstrapAcceptedLogical("dedicated-attach-target", nil, nil)
if target == nil {
t.Fatal("bootstrapAcceptedLogical(target) should return logical")
}
bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{
BulkID: "server-dedicated-attach-test",
DataID: 7,
Dedicated: true,
AttachToken: "attach-token",
}, 0, target, nil, 0, nil, nil, nil, nil, nil)
bulk.markAcceptHandled()
if err := server.getBulkRuntime().register(serverFileScope(target), bulk); err != nil {
t.Fatalf("register bulk runtime failed: %v", err)
}
reqPayload, err := server.sequenceEn(bulkAttachRequest{
PeerID: target.ID(),
BulkID: bulk.ID(),
AttachToken: "attach-token",
})
if err != nil {
t.Fatalf("encode bulkAttachRequest failed: %v", err)
}
msg := Message{
NetType: NET_SERVER,
LogicalConn: current,
ClientConn: current.compatClientConn(),
TransferMsg: TransferMsg{
ID: 42,
Key: systemBulkAttachKey,
Value: reqPayload,
Type: MSG_SYS_WAIT,
},
inboundConn: sidecarLeft,
Time: time.Now(),
}
type attachReplyResult struct {
transfer TransferMsg
resp bulkAttachResponse
err error
}
replyCh := make(chan attachReplyResult, 1)
go func() {
_ = sidecarRight.SetReadDeadline(time.Now().Add(time.Second))
replyPayload, err := readDirectSignalFramePayload(sidecarRight)
if err != nil {
replyCh <- attachReplyResult{err: err}
return
}
transfer, err := decodeDirectSignalPayload(server.sequenceDe, current.msgDeSnapshot(), current.secretKeySnapshot(), replyPayload)
if err != nil {
replyCh <- attachReplyResult{err: err}
return
}
resp, err := decodeBulkAttachResponse(server.sequenceDe, transfer.Value)
replyCh <- attachReplyResult{transfer: transfer, resp: resp, err: err}
}()
if !server.handleBulkAttachSystemMessage(msg) {
t.Fatal("handleBulkAttachSystemMessage should accept dedicated attach message")
}
var result attachReplyResult
select {
case result = <-replyCh:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for direct attach reply")
}
if result.err != nil {
t.Fatalf("read direct attach reply failed: %v", result.err)
}
transfer := result.transfer
if transfer.ID != msg.ID || transfer.Key != systemBulkAttachKey || transfer.Type != MSG_SYS_REPLY {
t.Fatalf("attach reply mismatch: %+v", transfer)
}
resp := result.resp
if !resp.Accepted || resp.Error != "" {
t.Fatalf("bulk attach response = %+v, want accepted", resp)
}
if got := bulk.dedicatedConnSnapshot(); got != sidecarLeft {
t.Fatalf("dedicated conn mismatch: got %v want %v", got, sidecarLeft)
}
if current.transportAttachedSnapshot() {
t.Fatal("attach sidecar logical transport should be detached after handoff")
}
if got := server.GetLogicalConn(current.ID()); got != nil {
t.Fatalf("attach sidecar logical should be removed after handoff, got %+v", got)
}
}
func TestHandleBulkAttachSystemMessageDoesNotExposeSharedSidecarBeforeReplyCompletes(t *testing.T) {
server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server)
sidecarLeft, sidecarRight := net.Pipe()
defer sidecarRight.Close()
current := server.bootstrapAcceptedLogical("dedicated-attach-current-blocked", nil, sidecarLeft)
if current == nil {
t.Fatal("bootstrapAcceptedLogical(current) should return logical")
}
target := server.bootstrapAcceptedLogical("dedicated-attach-target-blocked", nil, nil)
if target == nil {
t.Fatal("bootstrapAcceptedLogical(target) should return logical")
}
currentBulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{
BulkID: "server-dedicated-current",
DataID: 17,
Dedicated: true,
AttachToken: "attach-token",
}, 0, target, nil, 0, nil, nil, nil, nil, nil)
currentBulk.markAcceptHandled()
if err := server.getBulkRuntime().register(serverFileScope(target), currentBulk); err != nil {
t.Fatalf("register current bulk runtime failed: %v", err)
}
pendingBulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{
BulkID: "server-dedicated-pending",
DataID: 18,
Dedicated: true,
AttachToken: "attach-token-2",
}, 0, target, nil, 0, nil, nil, nil, nil, nil)
pendingBulk.markAcceptHandled()
if err := server.getBulkRuntime().register(serverFileScope(target), pendingBulk); err != nil {
t.Fatalf("register pending bulk runtime failed: %v", err)
}
reqPayload, err := server.sequenceEn(bulkAttachRequest{
PeerID: target.ID(),
BulkID: currentBulk.ID(),
AttachToken: "attach-token",
})
if err != nil {
t.Fatalf("encode bulkAttachRequest failed: %v", err)
}
msg := Message{
NetType: NET_SERVER,
LogicalConn: current,
ClientConn: current.compatClientConn(),
TransferMsg: TransferMsg{
ID: 77,
Key: systemBulkAttachKey,
Value: reqPayload,
Type: MSG_SYS_WAIT,
},
inboundConn: sidecarLeft,
Time: time.Now(),
}
done := make(chan struct{})
go func() {
defer close(done)
_ = server.handleBulkAttachSystemMessage(msg)
}()
time.Sleep(50 * time.Millisecond)
if got := pendingBulk.dedicatedConnSnapshot(); got != nil {
t.Fatal("pending dedicated bulk should not observe shared sidecar before attach reply is fully sent")
}
if got := server.serverDedicatedSidecarSnapshot(target); got != nil {
t.Fatal("server dedicated sidecar should not be published before attach reply is fully sent")
}
type attachReplyResult struct {
transfer TransferMsg
resp bulkAttachResponse
err error
}
replyCh := make(chan attachReplyResult, 1)
go func() {
_ = sidecarRight.SetReadDeadline(time.Now().Add(time.Second))
replyPayload, err := readDirectSignalFramePayload(sidecarRight)
if err != nil {
replyCh <- attachReplyResult{err: err}
return
}
transfer, err := decodeDirectSignalPayload(server.sequenceDe, current.msgDeSnapshot(), current.secretKeySnapshot(), replyPayload)
if err != nil {
replyCh <- attachReplyResult{err: err}
return
}
resp, err := decodeBulkAttachResponse(server.sequenceDe, transfer.Value)
replyCh <- attachReplyResult{transfer: transfer, resp: resp, err: err}
}()
select {
case result := <-replyCh:
if result.err != nil {
t.Fatalf("read direct attach reply failed: %v", result.err)
}
if !result.resp.Accepted {
t.Fatalf("bulk attach response = %+v, want accepted", result.resp)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for direct attach reply")
}
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for handleBulkAttachSystemMessage to finish")
}
if got := pendingBulk.dedicatedConnSnapshot(); got != sidecarLeft {
t.Fatalf("pending dedicated bulk conn mismatch after reply: got %v want %v", got, sidecarLeft)
}
if got := server.serverDedicatedSidecarSnapshot(target); got == nil || got.conn != sidecarLeft {
t.Fatal("server dedicated sidecar should be published after attach reply completes")
}
}
func TestBulkAttachResponseErrorCarriesStructuredFields(t *testing.T) {
resp := toBulkAttachResponseError(&bulkAttachError{
Code: bulkAttachErrorCodeTokenMismatch,
Retryable: false,
Message: "bulk attach token mismatch",
FailedSeq: 7,
FailedBulk: "bulk-1",
}, "fallback-bulk")
if resp.Accepted {
t.Fatalf("Accepted = %v, want false", resp.Accepted)
}
if got, want := resp.Code, string(bulkAttachErrorCodeTokenMismatch); got != want {
t.Fatalf("Code = %q, want %q", got, want)
}
if got, want := resp.Retryable, false; got != want {
t.Fatalf("Retryable = %v, want %v", got, want)
}
if got, want := resp.FailedSeq, uint64(7); got != want {
t.Fatalf("FailedSeq = %d, want %d", got, want)
}
if got, want := resp.FailedBulk, "bulk-1"; got != want {
t.Fatalf("FailedBulk = %q, want %q", got, want)
}
}
func TestResolveInboundDedicatedBulkRejectsAlreadyAttachedAfterDataStarted(t *testing.T) {
server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server)
currentLeft, currentRight := net.Pipe()
defer currentRight.Close()
current := server.bootstrapAcceptedLogical("dedicated-attach-current", nil, currentLeft)
if current == nil {
t.Fatal("bootstrapAcceptedLogical(current) should return logical")
}
target := server.bootstrapAcceptedLogical("dedicated-attach-target", nil, nil)
if target == nil {
t.Fatal("bootstrapAcceptedLogical(target) should return logical")
}
bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{
BulkID: "server-dedicated-attach-test",
DataID: 7,
Dedicated: true,
AttachToken: "attach-token",
}, 0, target, nil, 0, nil, nil, nil, nil, nil)
bulk.markAcceptHandled()
if err := server.getBulkRuntime().register(serverFileScope(target), bulk); err != nil {
t.Fatalf("register bulk runtime failed: %v", err)
}
attachedLeft, attachedRight := net.Pipe()
defer attachedRight.Close()
if err := bulk.attachDedicatedConn(attachedLeft); err != nil {
t.Fatalf("attachDedicatedConn failed: %v", err)
}
bulk.markDedicatedDataStarted()
_, _, err := server.resolveInboundDedicatedBulk(current, bulkAttachRequest{
PeerID: target.ID(),
BulkID: bulk.ID(),
AttachToken: "attach-token",
})
if err == nil {
t.Fatal("resolveInboundDedicatedBulk should reject duplicate attach after data started")
}
var attachErr *bulkAttachError
if !errors.As(err, &attachErr) || attachErr == nil {
t.Fatalf("resolveInboundDedicatedBulk error type = %T, want *bulkAttachError", err)
}
if got, want := attachErr.Code, bulkAttachErrorCodeAlreadyAttached; got != want {
t.Fatalf("attach error code = %q, want %q", got, want)
}
if attachErr.Retryable {
t.Fatalf("attach error retryable = true, want false")
}
}
func TestResolveInboundDedicatedBulkAllowsReattachBeforeDataStarts(t *testing.T) {
server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server)
currentLeft, currentRight := net.Pipe()
defer currentRight.Close()
current := server.bootstrapAcceptedLogical("dedicated-attach-current", nil, currentLeft)
if current == nil {
t.Fatal("bootstrapAcceptedLogical(current) should return logical")
}
target := server.bootstrapAcceptedLogical("dedicated-attach-target", nil, nil)
if target == nil {
t.Fatal("bootstrapAcceptedLogical(target) should return logical")
}
bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{
BulkID: "server-dedicated-attach-test",
DataID: 7,
Dedicated: true,
AttachToken: "attach-token",
}, 0, target, nil, 0, nil, nil, nil, nil, nil)
bulk.markAcceptHandled()
if err := server.getBulkRuntime().register(serverFileScope(target), bulk); err != nil {
t.Fatalf("register bulk runtime failed: %v", err)
}
attachedLeft, attachedRight := net.Pipe()
defer attachedRight.Close()
if err := bulk.attachDedicatedConn(attachedLeft); err != nil {
t.Fatalf("attachDedicatedConn failed: %v", err)
}
resolvedLogical, resolvedBulk, err := server.resolveInboundDedicatedBulk(current, bulkAttachRequest{
PeerID: target.ID(),
BulkID: bulk.ID(),
AttachToken: "attach-token",
})
if err != nil {
t.Fatalf("resolveInboundDedicatedBulk failed: %v", err)
}
if resolvedLogical != target {
t.Fatalf("resolved logical mismatch: got %v want %v", resolvedLogical, target)
}
if resolvedBulk != bulk {
t.Fatalf("resolved bulk mismatch: got %v want %v", resolvedBulk, bulk)
}
}
+411 -30
View File
@@ -12,14 +12,18 @@ import (
) )
const ( const (
bulkDedicatedBatchMagic = "NBD2" bulkDedicatedBatchMagic = "NBD2"
bulkDedicatedBatchVersion = 1 bulkDedicatedBatchVersion = 1
bulkDedicatedBatchHeaderLen = 20 bulkDedicatedBatchHeaderLen = 20
bulkDedicatedBatchItemHeaderLen = 16 bulkDedicatedBatchItemHeaderLen = 16
bulkDedicatedBatchMaxItems = 32 bulkDedicatedSuperBatchMagic = "NBD3"
bulkDedicatedBatchMaxPlainBytes = 8 * 1024 * 1024 bulkDedicatedSuperBatchVersion = 1
bulkDedicatedSendQueueSize = bulkDedicatedBatchMaxItems bulkDedicatedSuperBatchHeaderLen = 12
bulkDedicatedReleasePayloadLen = 12 bulkDedicatedSuperBatchGroupHeaderLen = 12
bulkDedicatedBatchMaxItems = 64
bulkDedicatedBatchMaxPlainBytes = 16 * 1024 * 1024
bulkDedicatedSendQueueSize = bulkDedicatedBatchMaxItems
bulkDedicatedReleasePayloadLen = 12
) )
const ( const (
@@ -46,6 +50,16 @@ type bulkDedicatedSendRequest struct {
Payload []byte Payload []byte
} }
type bulkDedicatedOutboundBatch struct {
DataID uint64
Items []bulkDedicatedSendRequest
}
type bulkDedicatedInboundBatch struct {
DataID uint64
Items []bulkDedicatedBatchItem
}
type bulkDedicatedBatchRequest struct { type bulkDedicatedBatchRequest struct {
Ctx context.Context Ctx context.Context
Items []bulkDedicatedSendRequest Items []bulkDedicatedSendRequest
@@ -165,8 +179,7 @@ func (s *bulkDedicatedSender) submitWriteBatch(ctx context.Context, items []bulk
if submitted, err := s.tryDirectSubmitBatch(ctx, items); submitted { if submitted, err := s.tryDirectSubmitBatch(ctx, items); submitted {
return err return err
} }
queuedItems := make([]bulkDedicatedSendRequest, len(items)) queuedItems := copyBulkDedicatedSendRequests(items)
copy(queuedItems, items)
return s.submitBatch(ctx, queuedItems, true) return s.submitBatch(ctx, queuedItems, true)
} }
@@ -458,6 +471,37 @@ func (s *bulkDedicatedSender) stoppedErr() error {
return errTransportDetached return errTransportDetached
} }
func cloneBulkDedicatedSendRequests(items []bulkDedicatedSendRequest) []bulkDedicatedSendRequest {
if len(items) == 0 {
return nil
}
cloned := make([]bulkDedicatedSendRequest, len(items))
for i, item := range items {
cloned[i] = item
if len(item.Payload) > 0 {
cloned[i].Payload = append([]byte(nil), item.Payload...)
}
}
return cloned
}
func copyBulkDedicatedSendRequests(items []bulkDedicatedSendRequest) []bulkDedicatedSendRequest {
if len(items) == 0 {
return nil
}
copied := make([]bulkDedicatedSendRequest, len(items))
copy(copied, items)
return copied
}
func bulkDedicatedSendRequestsLen(items []bulkDedicatedSendRequest) int {
total := 0
for _, item := range items {
total += bulkDedicatedSendRequestLen(item)
}
return total
}
func bulkDedicatedSendRequestLen(req bulkDedicatedSendRequest) int { func bulkDedicatedSendRequestLen(req bulkDedicatedSendRequest) int {
return bulkDedicatedSendRequestLenFromPayloadLen(len(req.Payload)) return bulkDedicatedSendRequestLenFromPayloadLen(len(req.Payload))
} }
@@ -466,6 +510,83 @@ func bulkDedicatedSendRequestLenFromPayloadLen(payloadLen int) int {
return bulkDedicatedBatchItemHeaderLen + payloadLen return bulkDedicatedBatchItemHeaderLen + payloadLen
} }
func bulkDedicatedBatchesPlainLen(batches []bulkDedicatedOutboundBatch) int {
switch len(batches) {
case 0:
return 0
case 1:
return bulkDedicatedBatchPlainLen(batches[0].Items)
default:
total := bulkDedicatedSuperBatchHeaderLen
for _, batch := range batches {
total += bulkDedicatedSuperBatchGroupHeaderLen + bulkDedicatedSendRequestsLen(batch.Items)
}
return total
}
}
func encodeBulkDedicatedBatchesPlain(batches []bulkDedicatedOutboundBatch) ([]byte, error) {
if len(batches) == 0 {
return nil, errBulkFastPayloadInvalid
}
total := bulkDedicatedBatchesPlainLen(batches)
buf := make([]byte, total)
if err := writeBulkDedicatedBatchesPlain(buf, batches); err != nil {
return nil, err
}
return buf, nil
}
func writeBulkDedicatedBatchesPlain(buf []byte, batches []bulkDedicatedOutboundBatch) error {
switch len(batches) {
case 0:
return errBulkFastPayloadInvalid
case 1:
return writeBulkDedicatedBatchPlain(buf, batches[0].DataID, batches[0].Items)
default:
return writeBulkDedicatedSuperBatchPlain(buf, batches)
}
}
func writeBulkDedicatedSuperBatchPlain(buf []byte, batches []bulkDedicatedOutboundBatch) error {
if len(batches) <= 1 {
return errBulkFastPayloadInvalid
}
if len(buf) != bulkDedicatedBatchesPlainLen(batches) || len(buf) > bulkDedicatedBatchMaxPlainBytes {
return errBulkFastPayloadInvalid
}
copy(buf[:4], bulkDedicatedSuperBatchMagic)
buf[4] = bulkDedicatedSuperBatchVersion
binary.BigEndian.PutUint32(buf[8:12], uint32(len(batches)))
offset := bulkDedicatedSuperBatchHeaderLen
totalItems := 0
for _, batch := range batches {
if batch.DataID == 0 || len(batch.Items) == 0 {
return errBulkFastPayloadInvalid
}
totalItems += len(batch.Items)
if totalItems > bulkDedicatedBatchMaxItems {
return errBulkFastPayloadInvalid
}
binary.BigEndian.PutUint64(buf[offset:offset+8], batch.DataID)
binary.BigEndian.PutUint32(buf[offset+8:offset+12], uint32(len(batch.Items)))
offset += bulkDedicatedSuperBatchGroupHeaderLen
for _, item := range batch.Items {
buf[offset] = item.Type
buf[offset+1] = item.Flags
binary.BigEndian.PutUint64(buf[offset+4:offset+12], item.Seq)
binary.BigEndian.PutUint32(buf[offset+12:offset+16], uint32(len(item.Payload)))
offset += bulkDedicatedBatchItemHeaderLen
copy(buf[offset:offset+len(item.Payload)], item.Payload)
offset += len(item.Payload)
}
}
if offset != len(buf) {
return errBulkFastPayloadInvalid
}
return nil
}
func encodeBulkDedicatedReleasePayload(bytes int64, chunks int) ([]byte, error) { func encodeBulkDedicatedReleasePayload(bytes int64, chunks int) ([]byte, error) {
if bytes <= 0 && chunks <= 0 { if bytes <= 0 && chunks <= 0 {
return nil, errBulkFastPayloadInvalid return nil, errBulkFastPayloadInvalid
@@ -492,24 +613,26 @@ func decodeBulkDedicatedReleasePayload(payload []byte) (int64, int, error) {
} }
func encodeBulkDedicatedBatchPlain(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) { func encodeBulkDedicatedBatchPlain(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
if dataID == 0 || len(items) == 0 { return encodeBulkDedicatedBatchesPlain([]bulkDedicatedOutboundBatch{{
return nil, errBulkFastPayloadInvalid DataID: dataID,
} Items: items,
total := bulkDedicatedBatchPlainLen(items) }})
buf := make([]byte, total)
if err := writeBulkDedicatedBatchPlain(buf, dataID, items); err != nil {
return nil, err
}
return buf, nil
} }
func encodeBulkDedicatedBatchPayloadFast(encode transportFastPlainEncoder, secretKey []byte, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) { func encodeBulkDedicatedBatchPayloadFast(encode transportFastPlainEncoder, secretKey []byte, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
return encodeBulkDedicatedBatchesPayloadFast(encode, secretKey, []bulkDedicatedOutboundBatch{{
DataID: dataID,
Items: items,
}})
}
func encodeBulkDedicatedBatchesPayloadFast(encode transportFastPlainEncoder, secretKey []byte, batches []bulkDedicatedOutboundBatch) ([]byte, error) {
if encode == nil { if encode == nil {
return nil, errTransportPayloadEncryptFailed return nil, errTransportPayloadEncryptFailed
} }
plainLen := bulkDedicatedBatchPlainLen(items) plainLen := bulkDedicatedBatchesPlainLen(batches)
return encode(secretKey, plainLen, func(dst []byte) error { return encode(secretKey, plainLen, func(dst []byte) error {
return writeBulkDedicatedBatchPlain(dst, dataID, items) return writeBulkDedicatedBatchesPlain(dst, batches)
}) })
} }
@@ -593,29 +716,268 @@ func decodeBulkDedicatedBatchPlain(payload []byte) (uint64, []bulkDedicatedBatch
return dataID, items, true, nil return dataID, items, true, nil
} }
func decodeDedicatedBulkInboundItems(expectedDataID uint64, plain []byte) ([]bulkDedicatedBatchItem, error) { func decodeBulkDedicatedSuperBatchPlain(payload []byte) ([]bulkDedicatedInboundBatch, bool, error) {
if len(payload) < 4 || string(payload[:4]) != bulkDedicatedSuperBatchMagic {
return nil, false, nil
}
if len(payload) < bulkDedicatedSuperBatchHeaderLen {
return nil, true, errBulkFastPayloadInvalid
}
if payload[4] != bulkDedicatedSuperBatchVersion {
return nil, true, errBulkFastPayloadInvalid
}
groupCount := int(binary.BigEndian.Uint32(payload[8:12]))
if groupCount <= 0 {
return nil, true, errBulkFastPayloadInvalid
}
batches := make([]bulkDedicatedInboundBatch, 0, groupCount)
offset := bulkDedicatedSuperBatchHeaderLen
totalItems := 0
for i := 0; i < groupCount; i++ {
if len(payload)-offset < bulkDedicatedSuperBatchGroupHeaderLen {
return nil, true, errBulkFastPayloadInvalid
}
dataID := binary.BigEndian.Uint64(payload[offset : offset+8])
count := int(binary.BigEndian.Uint32(payload[offset+8 : offset+12]))
offset += bulkDedicatedSuperBatchGroupHeaderLen
if dataID == 0 || count <= 0 {
return nil, true, errBulkFastPayloadInvalid
}
totalItems += count
if totalItems > bulkDedicatedBatchMaxItems {
return nil, true, errBulkFastPayloadInvalid
}
items := make([]bulkDedicatedBatchItem, 0, count)
for j := 0; j < count; j++ {
if len(payload)-offset < bulkDedicatedBatchItemHeaderLen {
return nil, true, errBulkFastPayloadInvalid
}
itemType := payload[offset]
switch itemType {
case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
default:
return nil, true, errBulkFastPayloadInvalid
}
flags := payload[offset+1]
seq := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
dataLen := int(binary.BigEndian.Uint32(payload[offset+12 : offset+16]))
offset += bulkDedicatedBatchItemHeaderLen
if dataLen < 0 || len(payload)-offset < dataLen {
return nil, true, errBulkFastPayloadInvalid
}
items = append(items, bulkDedicatedBatchItem{
Type: itemType,
Flags: flags,
Seq: seq,
Payload: payload[offset : offset+dataLen],
})
offset += dataLen
}
batches = append(batches, bulkDedicatedInboundBatch{
DataID: dataID,
Items: items,
})
}
if offset != len(payload) {
return nil, true, errBulkFastPayloadInvalid
}
return batches, true, nil
}
func decodeDedicatedBulkInboundBatches(plain []byte) ([]bulkDedicatedInboundBatch, error) {
if batches, matched, err := decodeBulkDedicatedSuperBatchPlain(plain); matched {
if err != nil {
return nil, err
}
return batches, nil
}
if dataID, items, matched, err := decodeBulkDedicatedBatchPlain(plain); matched { if dataID, items, matched, err := decodeBulkDedicatedBatchPlain(plain); matched {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if expectedDataID == 0 || dataID != expectedDataID { return []bulkDedicatedInboundBatch{{
return nil, errBulkFastPayloadInvalid DataID: dataID,
} Items: items,
return items, nil }}, nil
} }
frame, matched, err := decodeBulkFastFrame(plain) frame, matched, err := decodeBulkFastFrame(plain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !matched || expectedDataID == 0 || frame.DataID != expectedDataID { if !matched || frame.DataID == 0 {
return nil, errBulkFastPayloadInvalid return nil, errBulkFastPayloadInvalid
} }
return []bulkDedicatedBatchItem{{ return []bulkDedicatedInboundBatch{{
DataID: frame.DataID,
Items: []bulkDedicatedBatchItem{{
Type: frame.Type,
Flags: frame.Flags,
Seq: frame.Seq,
Payload: frame.Payload,
}},
}}, nil
}
func decodeDedicatedBulkInboundPayload(plain []byte) (uint64, []bulkDedicatedBatchItem, error) {
batches, err := decodeDedicatedBulkInboundBatches(plain)
if err != nil {
return 0, nil, err
}
if len(batches) != 1 {
return 0, nil, errBulkFastPayloadInvalid
}
return batches[0].DataID, batches[0].Items, nil
}
func decodeDedicatedBulkInboundItems(expectedDataID uint64, plain []byte) ([]bulkDedicatedBatchItem, error) {
dataID, items, err := decodeDedicatedBulkInboundPayload(plain)
if err != nil {
return nil, err
}
if expectedDataID == 0 || dataID != expectedDataID {
return nil, errBulkFastPayloadInvalid
}
return items, nil
}
func hasBulkDedicatedMagic(payload []byte, magic string) bool {
return len(payload) >= 4 &&
payload[0] == magic[0] &&
payload[1] == magic[1] &&
payload[2] == magic[2] &&
payload[3] == magic[3]
}
func walkDedicatedBulkInboundPayload(plain []byte, visit func(dataID uint64, item bulkDedicatedBatchItem) error) error {
if visit == nil {
return errBulkFastPayloadInvalid
}
if hasBulkDedicatedMagic(plain, bulkDedicatedSuperBatchMagic) {
return walkDedicatedBulkInboundSuperBatchPlain(plain, visit)
}
if hasBulkDedicatedMagic(plain, bulkDedicatedBatchMagic) {
return walkDedicatedBulkInboundBatchPlain(plain, visit)
}
frame, matched, err := decodeBulkFastFrame(plain)
if err != nil {
return err
}
if !matched || frame.DataID == 0 {
return errBulkFastPayloadInvalid
}
return visit(frame.DataID, bulkDedicatedBatchItem{
Type: frame.Type, Type: frame.Type,
Flags: frame.Flags, Flags: frame.Flags,
Seq: frame.Seq, Seq: frame.Seq,
Payload: frame.Payload, Payload: frame.Payload,
}}, nil })
}
func walkDedicatedBulkInboundBatchPlain(payload []byte, visit func(dataID uint64, item bulkDedicatedBatchItem) error) error {
if len(payload) < bulkDedicatedBatchHeaderLen {
return errBulkFastPayloadInvalid
}
if payload[4] != bulkDedicatedBatchVersion {
return errBulkFastPayloadInvalid
}
dataID := binary.BigEndian.Uint64(payload[8:16])
count := int(binary.BigEndian.Uint32(payload[16:20]))
if dataID == 0 || count <= 0 {
return errBulkFastPayloadInvalid
}
offset := bulkDedicatedBatchHeaderLen
for i := 0; i < count; i++ {
if len(payload)-offset < bulkDedicatedBatchItemHeaderLen {
return errBulkFastPayloadInvalid
}
itemType := payload[offset]
switch itemType {
case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
default:
return errBulkFastPayloadInvalid
}
flags := payload[offset+1]
seq := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
dataLen := int(binary.BigEndian.Uint32(payload[offset+12 : offset+16]))
offset += bulkDedicatedBatchItemHeaderLen
if dataLen < 0 || len(payload)-offset < dataLen {
return errBulkFastPayloadInvalid
}
if err := visit(dataID, bulkDedicatedBatchItem{
Type: itemType,
Flags: flags,
Seq: seq,
Payload: payload[offset : offset+dataLen],
}); err != nil {
return err
}
offset += dataLen
}
if offset != len(payload) {
return errBulkFastPayloadInvalid
}
return nil
}
func walkDedicatedBulkInboundSuperBatchPlain(payload []byte, visit func(dataID uint64, item bulkDedicatedBatchItem) error) error {
if len(payload) < bulkDedicatedSuperBatchHeaderLen {
return errBulkFastPayloadInvalid
}
if payload[4] != bulkDedicatedSuperBatchVersion {
return errBulkFastPayloadInvalid
}
groupCount := int(binary.BigEndian.Uint32(payload[8:12]))
if groupCount <= 0 {
return errBulkFastPayloadInvalid
}
offset := bulkDedicatedSuperBatchHeaderLen
totalItems := 0
for i := 0; i < groupCount; i++ {
if len(payload)-offset < bulkDedicatedSuperBatchGroupHeaderLen {
return errBulkFastPayloadInvalid
}
dataID := binary.BigEndian.Uint64(payload[offset : offset+8])
count := int(binary.BigEndian.Uint32(payload[offset+8 : offset+12]))
offset += bulkDedicatedSuperBatchGroupHeaderLen
if dataID == 0 || count <= 0 {
return errBulkFastPayloadInvalid
}
totalItems += count
if totalItems > bulkDedicatedBatchMaxItems {
return errBulkFastPayloadInvalid
}
for j := 0; j < count; j++ {
if len(payload)-offset < bulkDedicatedBatchItemHeaderLen {
return errBulkFastPayloadInvalid
}
itemType := payload[offset]
switch itemType {
case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
default:
return errBulkFastPayloadInvalid
}
flags := payload[offset+1]
seq := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
dataLen := int(binary.BigEndian.Uint32(payload[offset+12 : offset+16]))
offset += bulkDedicatedBatchItemHeaderLen
if dataLen < 0 || len(payload)-offset < dataLen {
return errBulkFastPayloadInvalid
}
if err := visit(dataID, bulkDedicatedBatchItem{
Type: itemType,
Flags: flags,
Seq: seq,
Payload: payload[offset : offset+dataLen],
}); err != nil {
return err
}
offset += dataLen
}
}
if offset != len(payload) {
return errBulkFastPayloadInvalid
}
return nil
} }
func normalizeDedicatedBulkSendError(err error) error { func normalizeDedicatedBulkSendError(err error) error {
@@ -630,13 +992,23 @@ func normalizeDedicatedBulkSendError(err error) error {
} }
func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchItem) error { func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchItem) error {
return dispatchDedicatedBulkInboundItemWithRelease(bulk, item, nil)
}
func dispatchDedicatedBulkInboundItemWithRelease(bulk *bulkHandle, item bulkDedicatedBatchItem, release func()) error {
if bulk == nil { if bulk == nil {
if release != nil {
release()
}
return io.ErrClosedPipe return io.ErrClosedPipe
} }
switch item.Type { switch item.Type {
case bulkFastPayloadTypeData: case bulkFastPayloadTypeData:
return bulk.pushOwnedChunkNoReset(item.Payload) return bulk.pushOwnedChunkWithReleaseNoReset(item.Payload, release)
case bulkFastPayloadTypeClose: case bulkFastPayloadTypeClose:
if release != nil {
release()
}
if item.Flags&bulkFastPayloadFlagFullClose != 0 { if item.Flags&bulkFastPayloadFlagFullClose != 0 {
bulk.markPeerClosed() bulk.markPeerClosed()
return nil return nil
@@ -644,6 +1016,9 @@ func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchI
bulk.markRemoteClosed() bulk.markRemoteClosed()
return nil return nil
case bulkFastPayloadTypeReset: case bulkFastPayloadTypeReset:
if release != nil {
release()
}
resetErr := errBulkReset resetErr := errBulkReset
if len(item.Payload) > 0 { if len(item.Payload) > 0 {
resetErr = bulkRemoteResetError(string(item.Payload)) resetErr = bulkRemoteResetError(string(item.Payload))
@@ -651,6 +1026,9 @@ func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchI
bulk.markReset(bulkResetError(resetErr)) bulk.markReset(bulkResetError(resetErr))
return nil return nil
case bulkFastPayloadTypeRelease: case bulkFastPayloadTypeRelease:
if release != nil {
release()
}
bytes, chunks, err := decodeBulkDedicatedReleasePayload(item.Payload) bytes, chunks, err := decodeBulkDedicatedReleasePayload(item.Payload)
if err != nil { if err != nil {
return err return err
@@ -658,6 +1036,9 @@ func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchI
bulk.releaseOutboundWindow(bytes, chunks) bulk.releaseOutboundWindow(bytes, chunks)
return nil return nil
default: default:
if release != nil {
release()
}
return errBulkFastPayloadInvalid return errBulkFastPayloadInvalid
} }
} }
+33
View File
@@ -1,6 +1,7 @@
package notify package notify
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"net" "net"
@@ -8,6 +9,38 @@ import (
"time" "time"
) )
func TestCloneBulkDedicatedSendRequestsDeepCopiesPayload(t *testing.T) {
src := []bulkDedicatedSendRequest{
{
Type: bulkFastPayloadTypeData,
Seq: 1,
Payload: []byte("payload-a"),
},
{
Type: bulkFastPayloadTypeReset,
Seq: 2,
Payload: []byte("payload-b"),
},
}
cloned := cloneBulkDedicatedSendRequests(src)
if len(cloned) != len(src) {
t.Fatalf("clone length = %d, want %d", len(cloned), len(src))
}
if &cloned[0] == &src[0] {
t.Fatal("request clone should not alias source slice elements")
}
if len(cloned[0].Payload) == 0 || len(src[0].Payload) == 0 {
t.Fatal("payload should not be empty")
}
if &cloned[0].Payload[0] == &src[0].Payload[0] {
t.Fatal("payload clone should not alias source bytes")
}
src[0].Payload[0] = 'X'
if bytes.Equal(cloned[0].Payload, src[0].Payload) {
t.Fatal("mutating source payload should not affect cloned payload")
}
}
func TestBulkDedicatedBatchPlainRoundTrip(t *testing.T) { func TestBulkDedicatedBatchPlainRoundTrip(t *testing.T) {
releasePayload, err := encodeBulkDedicatedReleasePayload(4096, 2) releasePayload, err := encodeBulkDedicatedReleasePayload(4096, 2)
if err != nil { if err != nil {
+743
View File
@@ -0,0 +1,743 @@
package notify
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
)
const bulkDedicatedLaneMicroBatchWait = 50 * time.Microsecond
type bulkDedicatedLaneBatchRequest struct {
Ctx context.Context
DataID uint64
Items []bulkDedicatedSendRequest
Deadline time.Time
wait bool
resultCh chan error
state bulkDedicatedRequestState
aborted atomic.Bool
}
var bulkDedicatedLaneBatchRequestPool sync.Pool
type bulkDedicatedLaneSender struct {
conn net.Conn
encode func([]bulkDedicatedOutboundBatch) ([]byte, func(), error)
fail func(error)
reqCh chan *bulkDedicatedLaneBatchRequest
stopCh chan struct{}
doneCh chan struct{}
stopOnce sync.Once
flushMu sync.Mutex
queued atomic.Int64
errMu sync.Mutex
err error
}
func newBulkDedicatedLaneSender(conn net.Conn, encode func([]bulkDedicatedOutboundBatch) ([]byte, func(), error), fail func(error)) *bulkDedicatedLaneSender {
sender := &bulkDedicatedLaneSender{
conn: conn,
encode: encode,
fail: fail,
reqCh: make(chan *bulkDedicatedLaneBatchRequest, bulkDedicatedSendQueueSize),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
go sender.run()
return sender
}
func getBulkDedicatedLaneBatchRequest() *bulkDedicatedLaneBatchRequest {
if pooled, ok := bulkDedicatedLaneBatchRequestPool.Get().(*bulkDedicatedLaneBatchRequest); ok && pooled != nil {
pooled.reset()
return pooled
}
req := &bulkDedicatedLaneBatchRequest{
resultCh: make(chan error, 1),
}
req.reset()
return req
}
func (r *bulkDedicatedLaneBatchRequest) reset() {
if r == nil {
return
}
r.Ctx = nil
r.DataID = 0
r.Deadline = time.Time{}
r.wait = false
r.Items = r.Items[:0]
r.state.value.Store(bulkDedicatedRequestQueued)
r.aborted.Store(false)
if r.resultCh != nil {
select {
case <-r.resultCh:
default:
}
}
}
func (r *bulkDedicatedLaneBatchRequest) prepare(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool, borrowItems bool) {
if r == nil {
return
}
r.reset()
r.Ctx = ctx
r.DataID = dataID
r.wait = wait
if deadline, ok := ctx.Deadline(); ok {
r.Deadline = deadline
}
if borrowItems {
r.Items = items
return
}
if cap(r.Items) < len(items) {
r.Items = make([]bulkDedicatedSendRequest, len(items))
} else {
r.Items = r.Items[:len(items)]
}
copy(r.Items, items)
}
func (r *bulkDedicatedLaneBatchRequest) recycle() {
if r == nil {
return
}
r.reset()
bulkDedicatedLaneBatchRequestPool.Put(r)
}
func (s *bulkDedicatedLaneSender) submitData(ctx context.Context, dataID uint64, seq uint64, payload []byte) error {
if s == nil {
return errTransportDetached
}
items := []bulkDedicatedSendRequest{{
Type: bulkFastPayloadTypeData,
Seq: seq,
Payload: append([]byte(nil), payload...),
}}
return s.submitBatch(ctx, dataID, items, false, false)
}
func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int, payloadOwned bool) (int, error) {
if s == nil {
return 0, errTransportDetached
}
if len(payload) == 0 {
return 0, nil
}
if chunkSize <= 0 {
chunkSize = defaultBulkChunkSize
}
if submitted, written, err := s.tryDirectSubmitWrite(ctx, dataID, startSeq, payload, chunkSize); submitted {
return written, err
}
written := 0
seq := startSeq
for written < len(payload) {
var itemBuf [bulkDedicatedBatchMaxItems]bulkDedicatedSendRequest
items := itemBuf[:0]
batchBytes := bulkDedicatedBatchHeaderLen
start := written
for written < len(payload) && len(items) < bulkDedicatedBatchMaxItems {
end := written + chunkSize
if end > len(payload) {
end = len(payload)
}
itemLen := bulkDedicatedSendRequestLenFromPayloadLen(end - written)
if len(items) > 0 && batchBytes+itemLen > bulkDedicatedBatchMaxPlainBytes {
break
}
items = append(items, bulkDedicatedSendRequest{
Type: bulkFastPayloadTypeData,
Seq: seq,
Payload: payload[written:end],
})
batchBytes += itemLen
seq++
written = end
}
if len(items) == 0 {
end := written + chunkSize
if end > len(payload) {
end = len(payload)
}
items = append(items, bulkDedicatedSendRequest{
Type: bulkFastPayloadTypeData,
Seq: seq,
Payload: payload[written:end],
})
seq++
written = end
}
if err := s.submitWriteBatch(ctx, dataID, items, payloadOwned); err != nil {
return start, err
}
start = written
}
return written, nil
}
func (s *bulkDedicatedLaneSender) submitWriteBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, _ bool) error {
if s == nil {
return errTransportDetached
}
if len(items) == 0 {
return nil
}
if submitted, err := s.tryDirectSubmitBatch(ctx, dataID, items); submitted {
return err
}
return s.submitBatch(ctx, dataID, items, true, true)
}
func (s *bulkDedicatedLaneSender) submitControl(ctx context.Context, dataID uint64, frameType uint8, flags uint8, seq uint64, payload []byte) error {
if s == nil {
return errTransportDetached
}
items := []bulkDedicatedSendRequest{{
Type: frameType,
Flags: flags,
Seq: seq,
}}
if len(payload) > 0 {
items[0].Payload = append([]byte(nil), payload...)
}
return s.submitBatch(ctx, dataID, items, true, false)
}
func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool, borrowItems bool) error {
if s == nil {
return errTransportDetached
}
if ctx == nil {
ctx = context.Background()
}
if err := s.errSnapshot(); err != nil {
return err
}
req := getBulkDedicatedLaneBatchRequest()
req.prepare(ctx, dataID, items, wait, borrowItems)
s.queued.Add(1)
select {
case <-ctx.Done():
s.queued.Add(-1)
req.recycle()
return normalizeStreamDeadlineError(ctx.Err())
case <-s.stopCh:
s.queued.Add(-1)
req.recycle()
return s.stoppedErr()
case s.reqCh <- req:
if !wait {
return nil
}
return s.waitAck(req)
}
}
func (s *bulkDedicatedLaneSender) tryDirectSubmitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int) (bool, int, error) {
if s == nil {
return true, 0, errTransportDetached
}
if ctx == nil {
ctx = context.Background()
}
if len(payload) == 0 {
return true, 0, nil
}
if chunkSize <= 0 {
chunkSize = defaultBulkChunkSize
}
if err := s.errSnapshot(); err != nil {
return true, 0, err
}
select {
case <-ctx.Done():
return true, 0, normalizeStreamDeadlineError(ctx.Err())
case <-s.stopCh:
return true, 0, s.stoppedErr()
default:
}
if s.queued.Load() != 0 {
return false, 0, nil
}
if !s.flushMu.TryLock() {
return false, 0, nil
}
defer s.flushMu.Unlock()
if s.queued.Load() != 0 {
return false, 0, nil
}
if err := s.errSnapshot(); err != nil {
return true, 0, err
}
written := 0
seq := startSeq
deadline, _ := ctx.Deadline()
for written < len(payload) {
select {
case <-ctx.Done():
return true, written, normalizeStreamDeadlineError(ctx.Err())
case <-s.stopCh:
return true, written, s.stoppedErr()
default:
}
var itemBuf [bulkDedicatedBatchMaxItems]bulkDedicatedSendRequest
items := itemBuf[:0]
batchBytes := bulkDedicatedBatchHeaderLen
start := written
for written < len(payload) && len(items) < bulkDedicatedBatchMaxItems {
end := written + chunkSize
if end > len(payload) {
end = len(payload)
}
itemLen := bulkDedicatedSendRequestLenFromPayloadLen(end - written)
if len(items) > 0 && batchBytes+itemLen > bulkDedicatedBatchMaxPlainBytes {
break
}
items = append(items, bulkDedicatedSendRequest{
Type: bulkFastPayloadTypeData,
Seq: seq,
Payload: payload[written:end],
})
batchBytes += itemLen
seq++
written = end
}
if len(items) == 0 {
end := written + chunkSize
if end > len(payload) {
end = len(payload)
}
items = append(items, bulkDedicatedSendRequest{
Type: bulkFastPayloadTypeData,
Seq: seq,
Payload: payload[written:end],
})
seq++
written = end
}
if err := s.flush([]bulkDedicatedOutboundBatch{{
DataID: dataID,
Items: items,
}}, deadline); err != nil {
err = normalizeDedicatedBulkSendError(err)
s.setErr(err)
s.failPending(err)
if s.fail != nil {
go s.fail(err)
}
return true, start, err
}
}
return true, written, nil
}
func (s *bulkDedicatedLaneSender) tryDirectSubmitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest) (bool, error) {
if s == nil {
return true, errTransportDetached
}
if ctx == nil {
ctx = context.Background()
}
if len(items) == 0 {
return true, nil
}
if err := s.errSnapshot(); err != nil {
return true, err
}
select {
case <-ctx.Done():
return true, normalizeStreamDeadlineError(ctx.Err())
case <-s.stopCh:
return true, s.stoppedErr()
default:
}
if s.queued.Load() != 0 {
return false, nil
}
if !s.flushMu.TryLock() {
return false, nil
}
defer s.flushMu.Unlock()
if s.queued.Load() != 0 {
return false, nil
}
if err := s.errSnapshot(); err != nil {
return true, err
}
select {
case <-ctx.Done():
return true, normalizeStreamDeadlineError(ctx.Err())
case <-s.stopCh:
return true, s.stoppedErr()
default:
}
deadline, _ := ctx.Deadline()
if err := s.flush([]bulkDedicatedOutboundBatch{{
DataID: dataID,
Items: items,
}}, deadline); err != nil {
err = normalizeDedicatedBulkSendError(err)
s.setErr(err)
s.failPending(err)
if s.fail != nil {
go s.fail(err)
}
return true, err
}
return true, nil
}
func (s *bulkDedicatedLaneSender) waitAck(req *bulkDedicatedLaneBatchRequest) error {
if s == nil {
return errTransportDetached
}
if req == nil {
return errTransportDetached
}
ctx := req.Ctx
if ctx == nil {
ctx = context.Background()
}
select {
case err := <-req.resultCh:
req.recycle()
return normalizeDedicatedBulkSendError(err)
case <-ctx.Done():
if req.tryCancel() {
req.aborted.Store(true)
return normalizeStreamDeadlineError(ctx.Err())
}
err := <-req.resultCh
req.recycle()
return normalizeDedicatedBulkSendError(err)
}
}
func (s *bulkDedicatedLaneSender) stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
s.setErr(errTransportDetached)
close(s.stopCh)
})
<-s.doneCh
}
func (s *bulkDedicatedLaneSender) run() {
defer close(s.doneCh)
var carry *bulkDedicatedLaneBatchRequest
for {
req, ok := s.nextRequest(carry)
carry = nil
if !ok {
return
}
if !req.tryStart() {
s.finishRequest(req, req.canceledErr())
continue
}
if err := req.contextErr(); err != nil {
s.finishRequest(req, err)
continue
}
batchReqs := []*bulkDedicatedLaneBatchRequest{req}
batches := []bulkDedicatedOutboundBatch{{
DataID: req.DataID,
Items: req.Items,
}}
batchBytes := bulkDedicatedBatchesPlainLen(batches)
deadline := req.Deadline
s.flushMu.Lock()
err := s.errSnapshot()
if err == nil {
carry, err = s.collectBatchRequests(&batchReqs, &batches, &batchBytes, &deadline)
if err == nil {
err = s.flush(batches, deadline)
}
}
s.flushMu.Unlock()
if err != nil {
err = normalizeDedicatedBulkSendError(err)
s.setErr(err)
s.finishBatchRequests(batchReqs, err)
if carry != nil {
s.finishRequest(carry, err)
carry = nil
}
s.failPending(err)
if s.fail != nil {
go s.fail(err)
}
return
}
s.finishBatchRequests(batchReqs, nil)
}
}
func appendBulkDedicatedLaneBatch(batches *[]bulkDedicatedOutboundBatch, req *bulkDedicatedLaneBatchRequest) {
if req == nil {
return
}
if len(*batches) > 0 && (*batches)[len(*batches)-1].DataID == req.DataID {
last := &(*batches)[len(*batches)-1]
last.Items = append(last.Items, req.Items...)
return
}
*batches = append(*batches, bulkDedicatedOutboundBatch{
DataID: req.DataID,
Items: req.Items,
})
}
func bulkDedicatedLaneBatchItemCount(batches []bulkDedicatedOutboundBatch) int {
total := 0
for _, batch := range batches {
total += len(batch.Items)
}
return total
}
func bulkDedicatedLaneNextBatchBytes(batches []bulkDedicatedOutboundBatch, req *bulkDedicatedLaneBatchRequest, currentBytes int) int {
reqBytes := bulkDedicatedSendRequestsLen(req.Items)
if len(batches) == 0 {
return bulkDedicatedBatchHeaderLen + reqBytes
}
last := batches[len(batches)-1]
if last.DataID == req.DataID {
return currentBytes + reqBytes
}
if len(batches) == 1 {
return bulkDedicatedSuperBatchHeaderLen +
bulkDedicatedSuperBatchGroupHeaderLen + bulkDedicatedSendRequestsLen(batches[0].Items) +
bulkDedicatedSuperBatchGroupHeaderLen + reqBytes
}
return currentBytes + bulkDedicatedSuperBatchGroupHeaderLen + reqBytes
}
func (r *bulkDedicatedLaneBatchRequest) contextErr() error {
if r.Ctx == nil {
return nil
}
select {
case <-r.Ctx.Done():
return normalizeStreamDeadlineError(r.Ctx.Err())
default:
return nil
}
}
func (r *bulkDedicatedLaneBatchRequest) tryStart() bool {
if r == nil {
return false
}
return r.state.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestStarted)
}
func (r *bulkDedicatedLaneBatchRequest) tryCancel() bool {
if r == nil {
return false
}
return r.state.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestCanceled)
}
func (r *bulkDedicatedLaneBatchRequest) canceledErr() error {
if r == nil {
return context.Canceled
}
if err := r.contextErr(); err != nil {
return err
}
return context.Canceled
}
func (s *bulkDedicatedLaneSender) nextRequest(carry *bulkDedicatedLaneBatchRequest) (*bulkDedicatedLaneBatchRequest, bool) {
if carry != nil {
select {
case <-s.stopCh:
err := s.stoppedErr()
s.finishRequest(carry, err)
s.failPending(err)
return nil, false
default:
return carry, true
}
}
select {
case <-s.stopCh:
s.failPending(s.stoppedErr())
return nil, false
case req := <-s.reqCh:
return req, true
}
}
func (s *bulkDedicatedLaneSender) collectBatchRequests(batchReqs *[]*bulkDedicatedLaneBatchRequest, batches *[]bulkDedicatedOutboundBatch, batchBytes *int, deadline *time.Time) (*bulkDedicatedLaneBatchRequest, error) {
if s == nil || bulkDedicatedLaneBatchItemCount(*batches) >= bulkDedicatedBatchMaxItems || *batchBytes >= bulkDedicatedBatchMaxPlainBytes {
return nil, nil
}
wait := bulkDedicatedLaneMicroBatchWait
var (
timer *time.Timer
timerCh <-chan time.Time
waited bool
)
if wait > 0 {
timer = time.NewTimer(wait)
timerCh = timer.C
defer timer.Stop()
}
for {
if bulkDedicatedLaneBatchItemCount(*batches) >= bulkDedicatedBatchMaxItems || *batchBytes >= bulkDedicatedBatchMaxPlainBytes {
return nil, nil
}
var (
req *bulkDedicatedLaneBatchRequest
ok bool
)
select {
case <-s.stopCh:
return nil, s.stoppedErr()
case req = <-s.reqCh:
ok = true
default:
}
if !ok {
if waited || timerCh == nil {
return nil, nil
}
waited = true
select {
case <-s.stopCh:
return nil, s.stoppedErr()
case req = <-s.reqCh:
case <-timerCh:
return nil, nil
}
}
if err := req.contextErr(); err != nil {
if req.tryCancel() {
s.finishRequest(req, err)
continue
}
s.finishRequest(req, req.canceledErr())
continue
}
nextItems := bulkDedicatedLaneBatchItemCount(*batches) + len(req.Items)
if nextItems > bulkDedicatedBatchMaxItems {
return req, nil
}
nextBytes := bulkDedicatedLaneNextBatchBytes(*batches, req, *batchBytes)
if nextBytes > bulkDedicatedBatchMaxPlainBytes {
return req, nil
}
if !req.tryStart() {
s.finishRequest(req, req.canceledErr())
continue
}
if err := req.contextErr(); err != nil {
s.finishRequest(req, err)
continue
}
*batchReqs = append(*batchReqs, req)
appendBulkDedicatedLaneBatch(batches, req)
*batchBytes = nextBytes
*deadline = earlierDeadline(*deadline, req.Deadline)
}
}
func earlierDeadline(current time.Time, next time.Time) time.Time {
if current.IsZero() {
return next
}
if next.IsZero() || current.Before(next) {
return current
}
return next
}
func (s *bulkDedicatedLaneSender) flush(batches []bulkDedicatedOutboundBatch, deadline time.Time) error {
if s == nil || s.conn == nil {
return errTransportDetached
}
payload, release, err := s.encode(batches)
if err != nil {
return err
}
if release != nil {
defer release()
}
return writeBulkDedicatedRecordWithDeadline(s.conn, payload, deadline)
}
func (s *bulkDedicatedLaneSender) finishRequest(req *bulkDedicatedLaneBatchRequest, err error) {
if s != nil {
s.queued.Add(-1)
}
if req == nil {
return
}
if req.wait && req.resultCh != nil {
if req.aborted.Load() {
req.recycle()
return
}
req.resultCh <- err
return
}
req.recycle()
}
func (s *bulkDedicatedLaneSender) finishBatchRequests(reqs []*bulkDedicatedLaneBatchRequest, err error) {
for _, req := range reqs {
s.finishRequest(req, err)
}
}
func (s *bulkDedicatedLaneSender) failPending(err error) {
for {
select {
case item := <-s.reqCh:
s.finishRequest(item, err)
default:
return
}
}
}
func (s *bulkDedicatedLaneSender) setErr(err error) {
if s == nil || err == nil {
return
}
s.errMu.Lock()
if s.err == nil {
s.err = err
}
s.errMu.Unlock()
}
func (s *bulkDedicatedLaneSender) errSnapshot() error {
if s == nil {
return errTransportDetached
}
s.errMu.Lock()
defer s.errMu.Unlock()
return s.err
}
func (s *bulkDedicatedLaneSender) stoppedErr() error {
if err := s.errSnapshot(); err != nil {
return err
}
return errTransportDetached
}
+196
View File
@@ -0,0 +1,196 @@
package notify
import (
"bytes"
"context"
"testing"
"time"
)
func TestBulkDedicatedLaneBatchRequestPrepareBorrowedSharesItems(t *testing.T) {
req := getBulkDedicatedLaneBatchRequest()
defer req.recycle()
items := []bulkDedicatedSendRequest{{
Type: bulkFastPayloadTypeData,
Seq: 7,
Payload: []byte("hello"),
}}
req.prepare(context.Background(), 11, items, true, true)
if got, want := len(req.Items), 1; got != want {
t.Fatalf("prepared item count = %d, want %d", got, want)
}
if &req.Items[0] != &items[0] {
t.Fatal("prepare with borrowed items should share request items")
}
}
func TestBulkDedicatedLaneSenderTryDirectSubmitWriteFlushesWholePayload(t *testing.T) {
conn := &shortWriteBulkRecordConn{maxPerWrite: 1024}
encodeCalls := 0
sender := &bulkDedicatedLaneSender{
conn: conn,
stopCh: make(chan struct{}),
encode: func(batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) {
encodeCalls++
payload, err := encodeBulkDedicatedBatchesPlain(batches)
return payload, nil, err
},
}
payload := bytes.Repeat([]byte("a"), 3*defaultBulkChunkSize)
submitted, written, err := sender.tryDirectSubmitWrite(context.Background(), 9, 1, payload, defaultBulkChunkSize)
if err != nil {
t.Fatalf("tryDirectSubmitWrite error = %v", err)
}
if !submitted {
t.Fatal("tryDirectSubmitWrite should submit directly")
}
if got, want := written, len(payload); got != want {
t.Fatalf("written = %d, want %d", got, want)
}
if encodeCalls == 0 {
t.Fatal("encode should be called at least once")
}
if got := sender.queued.Load(); got != 0 {
t.Fatalf("queued requests = %d, want 0", got)
}
}
func TestBulkDedicatedLaneSenderCollectBatchRequestsBatchesAcrossDataIDs(t *testing.T) {
sender := &bulkDedicatedLaneSender{
reqCh: make(chan *bulkDedicatedLaneBatchRequest, 3),
stopCh: make(chan struct{}),
}
first := &bulkDedicatedLaneBatchRequest{
Ctx: context.Background(),
DataID: 7,
Items: []bulkDedicatedSendRequest{{
Type: bulkFastPayloadTypeData,
Seq: 1,
Payload: []byte("a"),
}},
Deadline: time.Now().Add(2 * time.Second),
}
if !first.tryStart() {
t.Fatal("first request should start")
}
second := &bulkDedicatedLaneBatchRequest{
Ctx: context.Background(),
DataID: 7,
Items: []bulkDedicatedSendRequest{{
Type: bulkFastPayloadTypeData,
Seq: 2,
Payload: []byte("b"),
}},
Deadline: time.Now().Add(time.Second),
}
third := &bulkDedicatedLaneBatchRequest{
Ctx: context.Background(),
DataID: 8,
Items: []bulkDedicatedSendRequest{{
Type: bulkFastPayloadTypeData,
Seq: 3,
Payload: []byte("c"),
}},
Deadline: time.Now().Add(3 * time.Second),
}
sender.reqCh <- second
sender.reqCh <- third
batchReqs := []*bulkDedicatedLaneBatchRequest{first}
batches := []bulkDedicatedOutboundBatch{{
DataID: first.DataID,
Items: first.Items,
}}
batchBytes := bulkDedicatedBatchesPlainLen(batches)
deadline := first.Deadline
carry, err := sender.collectBatchRequests(&batchReqs, &batches, &batchBytes, &deadline)
if err != nil {
t.Fatalf("collectBatchRequests error = %v", err)
}
if carry != nil {
t.Fatalf("carry request = %+v, want nil", carry)
}
if len(batchReqs) != 3 {
t.Fatalf("batched request count = %d, want 3", len(batchReqs))
}
if len(batches) != 2 {
t.Fatalf("batched group count = %d, want 2", len(batches))
}
if got, want := batches[0].DataID, uint64(7); got != want {
t.Fatalf("first batch dataID = %d, want %d", got, want)
}
if got, want := len(batches[0].Items), 2; got != want {
t.Fatalf("first batch item count = %d, want %d", got, want)
}
if got, want := batches[1].DataID, uint64(8); got != want {
t.Fatalf("second batch dataID = %d, want %d", got, want)
}
if got, want := len(batches[1].Items), 1; got != want {
t.Fatalf("second batch item count = %d, want %d", got, want)
}
if got, want := batches[0].Items[0].Seq, uint64(1); got != want {
t.Fatalf("first batch seq = %d, want %d", got, want)
}
if got, want := batches[0].Items[1].Seq, uint64(2); got != want {
t.Fatalf("second batch seq = %d, want %d", got, want)
}
if got, want := batches[1].Items[0].Seq, uint64(3); got != want {
t.Fatalf("third batch seq = %d, want %d", got, want)
}
if !deadline.Equal(second.Deadline) {
t.Fatalf("merged deadline = %v, want %v", deadline, second.Deadline)
}
}
func TestDedicatedBulkSuperBatchRoundTrip(t *testing.T) {
batches := []bulkDedicatedOutboundBatch{
{
DataID: 7,
Items: []bulkDedicatedSendRequest{
{Type: bulkFastPayloadTypeData, Seq: 1, Payload: []byte("alpha")},
{Type: bulkFastPayloadTypeData, Seq: 2, Payload: []byte("beta")},
},
},
{
DataID: 8,
Items: []bulkDedicatedSendRequest{
{Type: bulkFastPayloadTypeClose, Flags: bulkFastPayloadFlagFullClose, Seq: 3},
{Type: bulkFastPayloadTypeRelease, Seq: 4, Payload: []byte{1, 2, 3, 4}},
},
},
}
plain, err := encodeBulkDedicatedBatchesPlain(batches)
if err != nil {
t.Fatalf("encodeBulkDedicatedBatchesPlain error = %v", err)
}
got, err := decodeDedicatedBulkInboundBatches(plain)
if err != nil {
t.Fatalf("decodeDedicatedBulkInboundBatches error = %v", err)
}
if len(got) != 2 {
t.Fatalf("decoded batch count = %d, want 2", len(got))
}
if got[0].DataID != 7 || got[1].DataID != 8 {
t.Fatalf("decoded dataIDs = %d,%d, want 7,8", got[0].DataID, got[1].DataID)
}
if len(got[0].Items) != 2 || len(got[1].Items) != 2 {
t.Fatalf("decoded item counts = %d,%d, want 2,2", len(got[0].Items), len(got[1].Items))
}
if !bytes.Equal(got[0].Items[0].Payload, []byte("alpha")) || !bytes.Equal(got[0].Items[1].Payload, []byte("beta")) {
t.Fatalf("decoded first batch payloads mismatch: %+v", got[0].Items)
}
if got[1].Items[0].Type != bulkFastPayloadTypeClose || got[1].Items[0].Flags != bulkFastPayloadFlagFullClose {
t.Fatalf("decoded close item mismatch: %+v", got[1].Items[0])
}
if got[1].Items[1].Type != bulkFastPayloadTypeRelease || !bytes.Equal(got[1].Items[1].Payload, []byte{1, 2, 3, 4}) {
t.Fatalf("decoded release item mismatch: %+v", got[1].Items[1])
}
}
+66
View File
@@ -0,0 +1,66 @@
package notify
import (
"bytes"
"encoding/binary"
"io"
"net"
"testing"
"time"
)
type shortWriteBulkRecordConn struct {
maxPerWrite int
buf bytes.Buffer
}
func (c *shortWriteBulkRecordConn) Read([]byte) (int, error) { return 0, io.EOF }
func (c *shortWriteBulkRecordConn) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
n := c.maxPerWrite
if n <= 0 || n > len(p) {
n = len(p)
}
_, _ = c.buf.Write(p[:n])
return n, nil
}
func (c *shortWriteBulkRecordConn) Close() error { return nil }
func (c *shortWriteBulkRecordConn) LocalAddr() net.Addr { return shortWriteBulkRecordAddr("local") }
func (c *shortWriteBulkRecordConn) RemoteAddr() net.Addr { return shortWriteBulkRecordAddr("remote") }
func (c *shortWriteBulkRecordConn) SetDeadline(time.Time) error { return nil }
func (c *shortWriteBulkRecordConn) SetReadDeadline(time.Time) error {
return nil
}
func (c *shortWriteBulkRecordConn) SetWriteDeadline(time.Time) error {
return nil
}
type shortWriteBulkRecordAddr string
func (a shortWriteBulkRecordAddr) Network() string { return "tcp" }
func (a shortWriteBulkRecordAddr) String() string { return string(a) }
func TestWriteBulkDedicatedRecordWithDeadlineHandlesShortWrite(t *testing.T) {
conn := &shortWriteBulkRecordConn{maxPerWrite: 3}
payload := []byte("abcdefghijklmnopqrstuvwxyz")
if err := writeBulkDedicatedRecordWithDeadline(conn, payload, time.Time{}); err != nil {
t.Fatalf("writeBulkDedicatedRecordWithDeadline failed: %v", err)
}
raw := conn.buf.Bytes()
if got, want := len(raw), bulkDedicatedRecordHeaderLen+len(payload); got != want {
t.Fatalf("record length = %d, want %d", got, want)
}
if got := string(raw[:4]); got != bulkDedicatedRecordMagic {
t.Fatalf("record magic = %q, want %q", got, bulkDedicatedRecordMagic)
}
if got, want := int(binary.BigEndian.Uint32(raw[4:8])), len(payload); got != want {
t.Fatalf("record payload length = %d, want %d", got, want)
}
if got := raw[bulkDedicatedRecordHeaderLen:]; !bytes.Equal(got, payload) {
t.Fatalf("record payload mismatch")
}
}
+451
View File
@@ -0,0 +1,451 @@
package notify
import (
"context"
"net"
"sync"
)
type bulkDedicatedSidecar struct {
laneID uint32
conn net.Conn
closeOnce sync.Once
senderMu sync.Mutex
sender *bulkDedicatedLaneSender
}
type bulkDedicatedLane struct {
id uint32
activeBulks int
sidecar *bulkDedicatedSidecar
attachFlight *bulkDedicatedAttachFlight
}
type bulkDedicatedAttachFlight struct {
done chan struct{}
once sync.Once
err error
}
func normalizeBulkDedicatedLaneID(laneID uint32) uint32 {
if laneID == 0 {
return 1
}
return laneID
}
func newBulkDedicatedSidecar(conn net.Conn, laneID uint32) *bulkDedicatedSidecar {
if conn == nil {
return nil
}
return &bulkDedicatedSidecar{
laneID: normalizeBulkDedicatedLaneID(laneID),
conn: conn,
}
}
func newBulkDedicatedAttachFlight() *bulkDedicatedAttachFlight {
return &bulkDedicatedAttachFlight{
done: make(chan struct{}),
}
}
func (f *bulkDedicatedAttachFlight) finish(err error) {
if f == nil {
return
}
f.once.Do(func() {
f.err = err
close(f.done)
})
}
func (f *bulkDedicatedAttachFlight) wait(ctx context.Context) error {
if f == nil {
return nil
}
if ctx == nil {
ctx = context.Background()
}
select {
case <-ctx.Done():
return ctx.Err()
case <-f.done:
return f.err
}
}
func (s *bulkDedicatedSidecar) close() {
if s == nil {
return
}
s.closeOnce.Do(func() {
if sender := s.laneSenderSnapshot(); sender != nil {
sender.stop()
}
if s.conn != nil {
_ = s.conn.Close()
}
})
}
func (s *bulkDedicatedSidecar) laneSenderSnapshot() *bulkDedicatedLaneSender {
if s == nil {
return nil
}
s.senderMu.Lock()
defer s.senderMu.Unlock()
return s.sender
}
func (s *bulkDedicatedSidecar) laneSenderWithFactory(factory func(net.Conn) *bulkDedicatedLaneSender) *bulkDedicatedLaneSender {
if s == nil || factory == nil {
return nil
}
s.senderMu.Lock()
defer s.senderMu.Unlock()
if s.sender != nil {
return s.sender
}
if s.conn == nil {
return nil
}
s.sender = factory(s.conn)
return s.sender
}
func (c *ClientCommon) clientDedicatedSidecarSnapshot() *bulkDedicatedSidecar {
if c == nil {
return nil
}
c.bulkDedicatedSidecarMu.Lock()
defer c.bulkDedicatedSidecarMu.Unlock()
return firstClientDedicatedSidecarLocked(c.bulkDedicatedLanes)
}
func firstClientDedicatedSidecarLocked(lanes map[uint32]*bulkDedicatedLane) *bulkDedicatedSidecar {
var (
selected *bulkDedicatedSidecar
bestID uint32
)
for laneID, lane := range lanes {
if lane == nil || lane.sidecar == nil {
continue
}
if selected == nil || laneID < bestID {
selected = lane.sidecar
bestID = laneID
}
}
return selected
}
func (c *ClientCommon) reserveBulkDedicatedLane() uint32 {
if c == nil {
return normalizeBulkDedicatedLaneID(0)
}
c.bulkDedicatedSidecarMu.Lock()
defer c.bulkDedicatedSidecarMu.Unlock()
if c.bulkDedicatedLanes == nil {
c.bulkDedicatedLanes = make(map[uint32]*bulkDedicatedLane)
}
limit := c.bulkDedicatedLaneLimitSnapshot()
var best *bulkDedicatedLane
for _, lane := range c.bulkDedicatedLanes {
if lane == nil {
continue
}
if best == nil || lane.activeBulks < best.activeBulks || (lane.activeBulks == best.activeBulks && lane.id < best.id) {
best = lane
}
}
if best == nil || ((limit <= 0 || len(c.bulkDedicatedLanes) < limit) && best.activeBulks > 0) {
c.bulkDedicatedNextLaneID++
laneID := normalizeBulkDedicatedLaneID(c.bulkDedicatedNextLaneID)
best = &bulkDedicatedLane{id: laneID}
c.bulkDedicatedLanes[laneID] = best
}
best.activeBulks++
return best.id
}
func (c *ClientCommon) releaseBulkDedicatedLane(laneID uint32) {
if c == nil {
return
}
laneID = normalizeBulkDedicatedLaneID(laneID)
c.bulkDedicatedSidecarMu.Lock()
defer c.bulkDedicatedSidecarMu.Unlock()
lane := c.bulkDedicatedLanes[laneID]
if lane == nil {
return
}
if lane.activeBulks > 0 {
lane.activeBulks--
}
if lane.activeBulks == 0 && lane.sidecar == nil && lane.attachFlight == nil {
delete(c.bulkDedicatedLanes, laneID)
}
}
func (c *ClientCommon) clientDedicatedSidecarSnapshotForLane(laneID uint32) *bulkDedicatedSidecar {
if c == nil {
return nil
}
laneID = normalizeBulkDedicatedLaneID(laneID)
c.bulkDedicatedSidecarMu.Lock()
defer c.bulkDedicatedSidecarMu.Unlock()
if lane := c.bulkDedicatedLanes[laneID]; lane != nil {
return lane.sidecar
}
return nil
}
func (c *ClientCommon) beginClientDedicatedSidecarAttach(laneID uint32) (*bulkDedicatedSidecar, *bulkDedicatedAttachFlight, bool) {
if c == nil {
return nil, nil, false
}
laneID = normalizeBulkDedicatedLaneID(laneID)
c.bulkDedicatedSidecarMu.Lock()
defer c.bulkDedicatedSidecarMu.Unlock()
if c.bulkDedicatedLanes == nil {
c.bulkDedicatedLanes = make(map[uint32]*bulkDedicatedLane)
}
lane := c.bulkDedicatedLanes[laneID]
if lane == nil {
lane = &bulkDedicatedLane{id: laneID}
c.bulkDedicatedLanes[laneID] = lane
}
if lane.sidecar != nil {
return lane.sidecar, nil, false
}
if lane.attachFlight != nil {
return nil, lane.attachFlight, false
}
flight := newBulkDedicatedAttachFlight()
lane.attachFlight = flight
return nil, flight, true
}
func (c *ClientCommon) finishClientDedicatedSidecarAttach(laneID uint32, flight *bulkDedicatedAttachFlight, err error) {
if c == nil || flight == nil {
return
}
laneID = normalizeBulkDedicatedLaneID(laneID)
c.bulkDedicatedSidecarMu.Lock()
if lane := c.bulkDedicatedLanes[laneID]; lane != nil && lane.attachFlight == flight {
lane.attachFlight = nil
if lane.activeBulks == 0 && lane.sidecar == nil {
delete(c.bulkDedicatedLanes, laneID)
}
}
c.bulkDedicatedSidecarMu.Unlock()
flight.finish(err)
}
func (c *ClientCommon) installClientDedicatedSidecar(laneID uint32, sidecar *bulkDedicatedSidecar) (*bulkDedicatedSidecar, bool) {
if c == nil || sidecar == nil {
return nil, false
}
laneID = normalizeBulkDedicatedLaneID(laneID)
c.bulkDedicatedSidecarMu.Lock()
defer c.bulkDedicatedSidecarMu.Unlock()
if c.bulkDedicatedLanes == nil {
c.bulkDedicatedLanes = make(map[uint32]*bulkDedicatedLane)
}
lane := c.bulkDedicatedLanes[laneID]
if lane == nil {
lane = &bulkDedicatedLane{id: laneID}
c.bulkDedicatedLanes[laneID] = lane
}
if lane.sidecar != nil {
return lane.sidecar, false
}
lane.sidecar = sidecar
return sidecar, true
}
func (c *ClientCommon) clearClientDedicatedSidecar(laneID uint32, sidecar *bulkDedicatedSidecar) bool {
if c == nil || sidecar == nil {
return false
}
laneID = normalizeBulkDedicatedLaneID(laneID)
c.bulkDedicatedSidecarMu.Lock()
defer c.bulkDedicatedSidecarMu.Unlock()
lane := c.bulkDedicatedLanes[laneID]
if lane == nil || lane.sidecar != sidecar {
return false
}
lane.sidecar = nil
if lane.activeBulks == 0 && lane.attachFlight == nil {
delete(c.bulkDedicatedLanes, laneID)
}
return true
}
func (c *ClientCommon) closeClientDedicatedSidecar() {
if c == nil {
return
}
c.bulkDedicatedSidecarMu.Lock()
lanes := c.bulkDedicatedLanes
c.bulkDedicatedLanes = make(map[uint32]*bulkDedicatedLane)
c.bulkDedicatedSidecarMu.Unlock()
for _, lane := range lanes {
if lane == nil {
continue
}
if lane.sidecar != nil {
lane.sidecar.close()
}
if lane.attachFlight != nil {
lane.attachFlight.finish(errServiceShutdown)
}
}
}
func (c *ClientCommon) handleClientDedicatedSidecarFailure(sidecar *bulkDedicatedSidecar, err error) {
if c == nil || sidecar == nil {
return
}
if !c.clearClientDedicatedSidecar(sidecar.laneID, sidecar) {
return
}
runtime := c.getBulkRuntime()
if runtime != nil && sidecar.conn != nil {
runtime.handleDedicatedReadErrorByConn(clientFileScope(), sidecar.conn, err)
}
sidecar.close()
}
func (s *ServerCommon) serverDedicatedSidecarSnapshot(logical *LogicalConn) *bulkDedicatedSidecar {
if s == nil || logical == nil {
return nil
}
s.bulkDedicatedSidecarMu.Lock()
defer s.bulkDedicatedSidecarMu.Unlock()
return firstServerDedicatedSidecarLocked(s.bulkDedicatedSidecars[logical])
}
func firstServerDedicatedSidecarLocked(lanes map[uint32]*bulkDedicatedSidecar) *bulkDedicatedSidecar {
var (
selected *bulkDedicatedSidecar
bestID uint32
)
for laneID, sidecar := range lanes {
if sidecar == nil {
continue
}
if selected == nil || laneID < bestID {
selected = sidecar
bestID = laneID
}
}
return selected
}
func (s *ServerCommon) serverDedicatedSidecarSnapshotForLane(logical *LogicalConn, laneID uint32) *bulkDedicatedSidecar {
if s == nil || logical == nil {
return nil
}
laneID = normalizeBulkDedicatedLaneID(laneID)
s.bulkDedicatedSidecarMu.Lock()
defer s.bulkDedicatedSidecarMu.Unlock()
if lanes := s.bulkDedicatedSidecars[logical]; lanes != nil {
return lanes[laneID]
}
return nil
}
func (s *ServerCommon) installServerDedicatedSidecar(logical *LogicalConn, laneID uint32, sidecar *bulkDedicatedSidecar) *bulkDedicatedSidecar {
if s == nil || logical == nil || sidecar == nil {
return nil
}
laneID = normalizeBulkDedicatedLaneID(laneID)
s.bulkDedicatedSidecarMu.Lock()
defer s.bulkDedicatedSidecarMu.Unlock()
lanes := s.bulkDedicatedSidecars[logical]
if lanes == nil {
lanes = make(map[uint32]*bulkDedicatedSidecar)
s.bulkDedicatedSidecars[logical] = lanes
}
prev := lanes[laneID]
lanes[laneID] = sidecar
return prev
}
func (s *ServerCommon) clearServerDedicatedSidecar(logical *LogicalConn, laneID uint32, sidecar *bulkDedicatedSidecar) bool {
if s == nil || logical == nil || sidecar == nil {
return false
}
laneID = normalizeBulkDedicatedLaneID(laneID)
s.bulkDedicatedSidecarMu.Lock()
defer s.bulkDedicatedSidecarMu.Unlock()
lanes := s.bulkDedicatedSidecars[logical]
if lanes == nil || lanes[laneID] != sidecar {
return false
}
delete(lanes, laneID)
if len(lanes) == 0 {
delete(s.bulkDedicatedSidecars, logical)
}
return true
}
func (s *ServerCommon) closeServerDedicatedSidecar(logical *LogicalConn) {
if s == nil || logical == nil {
return
}
s.bulkDedicatedSidecarMu.Lock()
lanes := s.bulkDedicatedSidecars[logical]
delete(s.bulkDedicatedSidecars, logical)
s.bulkDedicatedSidecarMu.Unlock()
for _, sidecar := range lanes {
if sidecar != nil {
sidecar.close()
}
}
}
func (s *ServerCommon) closeAllServerDedicatedSidecars() {
if s == nil {
return
}
s.bulkDedicatedSidecarMu.Lock()
lanesByLogical := s.bulkDedicatedSidecars
s.bulkDedicatedSidecars = make(map[*LogicalConn]map[uint32]*bulkDedicatedSidecar)
s.bulkDedicatedSidecarMu.Unlock()
for _, lanes := range lanesByLogical {
for _, sidecar := range lanes {
if sidecar != nil {
sidecar.close()
}
}
}
}
func (s *ServerCommon) handleServerDedicatedSidecarFailure(logical *LogicalConn, sidecar *bulkDedicatedSidecar, err error) {
if s == nil || logical == nil || sidecar == nil {
return
}
if !s.clearServerDedicatedSidecar(logical, sidecar.laneID, sidecar) {
return
}
runtime := s.getBulkRuntime()
if runtime != nil && sidecar.conn != nil {
runtime.handleDedicatedReadErrorByConn(serverFileScope(logical), sidecar.conn, err)
}
sidecar.close()
}
func (s *ServerCommon) attachServerDedicatedSidecarIfExists(logical *LogicalConn, bulk *bulkHandle) {
if s == nil || logical == nil || bulk == nil || !bulk.Dedicated() {
return
}
sidecar := s.serverDedicatedSidecarSnapshotForLane(logical, bulk.dedicatedLaneIDSnapshot())
if sidecar == nil || sidecar.conn == nil {
return
}
_ = bulk.attachDedicatedConnShared(sidecar.conn)
}
+106 -4
View File
@@ -12,6 +12,10 @@ import (
const bulkDispatchRejectTimeout = 300 * time.Millisecond const bulkDispatchRejectTimeout = 300 * time.Millisecond
func (c *ClientCommon) dispatchFastBulkFrame(frame bulkFastFrame) { func (c *ClientCommon) dispatchFastBulkFrame(frame bulkFastFrame) {
c.dispatchFastBulkFrameWithOwner(frame, nil)
}
func (c *ClientCommon) dispatchFastBulkFrameWithOwner(frame bulkFastFrame, owner *bulkReadPayloadOwner) {
if frame.DataID == 0 { if frame.DataID == 0 {
return return
} }
@@ -38,7 +42,13 @@ func (c *ClientCommon) dispatchFastBulkFrame(frame bulkFastFrame) {
} }
switch frame.Type { switch frame.Type {
case bulkFastPayloadTypeData: case bulkFastPayloadTypeData:
if err := bulk.pushOwnedChunk(frame.Payload); err != nil { var err error
if owner != nil {
err = bulk.pushChunkWithOwnershipOptionsAndRelease(frame.Payload, true, true, owner.retainChunk())
} else {
err = bulk.pushOwnedChunk(frame.Payload)
}
if err != nil {
if c.showError || c.debugMode { if c.showError || c.debugMode {
fmt.Println("client bulk push chunk error", err) fmt.Println("client bulk push chunk error", err)
} }
@@ -58,14 +68,28 @@ func (c *ClientCommon) dispatchFastBulkFrame(frame bulkFastFrame) {
resetErr = bulkRemoteResetError(string(frame.Payload)) resetErr = bulkRemoteResetError(string(frame.Payload))
} }
bulk.markReset(bulkResetError(resetErr)) bulk.markReset(bulkResetError(resetErr))
case bulkFastPayloadTypeRelease:
bytes, chunks, err := decodeBulkDedicatedReleasePayload(frame.Payload)
if err != nil {
if c.showError || c.debugMode {
fmt.Println("client bulk release decode error", err)
}
c.bestEffortRejectInboundBulkData(bulk.ID(), frame.DataID, err.Error())
return
}
bulk.releaseOutboundWindow(bytes, chunks)
} }
} }
func (c *ClientCommon) dispatchFastBulkData(frame bulkFastDataFrame) { func (c *ClientCommon) dispatchFastBulkData(frame bulkFastDataFrame) {
c.dispatchFastBulkFrame(frame) c.dispatchFastBulkFrameWithOwner(frame, nil)
} }
func (s *ServerCommon) dispatchFastBulkFrame(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastFrame) { func (s *ServerCommon) dispatchFastBulkFrame(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastFrame) {
s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, nil)
}
func (s *ServerCommon) dispatchFastBulkFrameWithOwner(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastFrame, owner *bulkReadPayloadOwner) {
if logical == nil || frame.DataID == 0 { if logical == nil || frame.DataID == 0 {
return return
} }
@@ -91,7 +115,13 @@ func (s *ServerCommon) dispatchFastBulkFrame(logical *LogicalConn, transport *Tr
} }
switch frame.Type { switch frame.Type {
case bulkFastPayloadTypeData: case bulkFastPayloadTypeData:
if err := bulk.pushOwnedChunk(frame.Payload); err != nil { var err error
if owner != nil {
err = bulk.pushChunkWithOwnershipOptionsAndRelease(frame.Payload, true, true, owner.retainChunk())
} else {
err = bulk.pushOwnedChunk(frame.Payload)
}
if err != nil {
if s.showError || s.debugMode { if s.showError || s.debugMode {
fmt.Println("server bulk push chunk error", err) fmt.Println("server bulk push chunk error", err)
} }
@@ -111,11 +141,83 @@ func (s *ServerCommon) dispatchFastBulkFrame(logical *LogicalConn, transport *Tr
resetErr = bulkRemoteResetError(string(frame.Payload)) resetErr = bulkRemoteResetError(string(frame.Payload))
} }
bulk.markReset(bulkResetError(resetErr)) bulk.markReset(bulkResetError(resetErr))
case bulkFastPayloadTypeRelease:
bytes, chunks, err := decodeBulkDedicatedReleasePayload(frame.Payload)
if err != nil {
if s.showError || s.debugMode {
fmt.Println("server bulk release decode error", err)
}
s.bestEffortRejectInboundBulkData(logical, transport, conn, bulk.ID(), frame.DataID, err.Error())
return
}
bulk.releaseOutboundWindow(bytes, chunks)
} }
} }
func (s *ServerCommon) dispatchFastBulkData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastDataFrame) { func (s *ServerCommon) dispatchFastBulkData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastDataFrame) {
s.dispatchFastBulkFrame(logical, transport, conn, frame) s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, nil)
}
func (c *ClientCommon) tryDispatchBorrowedBulkTransportPayload(payload []byte) bool {
if c == nil || len(payload) == 0 {
return false
}
profile := c.clientTransportProtectionSnapshot()
plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, payload)
if err != nil {
if c.showError || c.debugMode {
fmt.Println("client decode transport payload error", err)
}
return true
}
owner := newBulkReadPayloadOwner(plainRelease)
matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
c.dispatchFastBulkFrameWithOwner(frame, owner)
return nil
})
if owner != nil {
owner.done()
}
if !matched {
return false
}
if walkErr != nil && (c.showError || c.debugMode) {
fmt.Println("client decode bulk fast payload error", walkErr)
}
return true
}
func (s *ServerCommon) tryDispatchBorrowedBulkTransportPayload(source interface{}, payload []byte) bool {
if s == nil || len(payload) == 0 {
return false
}
logical, transport := s.resolveInboundSource(source)
if logical == nil {
return false
}
plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload)
if err != nil {
if s.showError || s.debugMode {
fmt.Println("server decode transport payload error", err)
}
return true
}
conn := serverInboundConn(source)
owner := newBulkReadPayloadOwner(plainRelease)
matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, owner)
return nil
})
if owner != nil {
owner.done()
}
if !matched {
return false
}
if walkErr != nil && (s.showError || s.debugMode) {
fmt.Println("server decode bulk fast payload error", walkErr)
}
return true
} }
func (c *ClientCommon) bestEffortRejectInboundBulkData(bulkID string, dataID uint64, message string) { func (c *ClientCommon) bestEffortRejectInboundBulkData(bulkID string, dataID uint64, message string) {
+8 -1
View File
@@ -308,7 +308,11 @@ func newBulkBenchmarkClient(tb testing.TB, network string, server bulkBenchmarkS
tb.Fatalf("UseSignalReliabilityClient failed: %v", err) tb.Fatalf("UseSignalReliabilityClient failed: %v", err)
} }
} }
if err := client.Connect(network, server.addr); err != nil { dialAddr := server.addr
if network == "tcp" {
dialAddr = benchmarkTCPDialAddr(tb, dialAddr)
}
if err := client.Connect(network, dialAddr); err != nil {
tb.Fatalf("client Connect failed: %v", err) tb.Fatalf("client Connect failed: %v", err)
} }
tb.Cleanup(func() { tb.Cleanup(func() {
@@ -333,6 +337,9 @@ func bulkBenchmarkListenAddr(tb testing.TB, network string) string {
case "unix": case "unix":
return filepath.Join(tb.TempDir(), "notify-bulk.sock") return filepath.Join(tb.TempDir(), "notify-bulk.sock")
case "udp", "tcp": case "udp", "tcp":
if network == "tcp" {
return benchmarkTCPListenAddr(tb)
}
return "127.0.0.1:0" return "127.0.0.1:0"
default: default:
tb.Fatalf("unsupported benchmark network %q", network) tb.Fatalf("unsupported benchmark network %q", network)
+398 -48
View File
@@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"net" "net"
"sync" "sync"
"time" "time"
@@ -120,32 +121,149 @@ func decodeBulkFastDataFrame(payload []byte) (bulkFastDataFrame, bool, error) {
} }
func (c *ClientCommon) encodeFastBulkDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) { func (c *ClientCommon) encodeFastBulkDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
if c != nil && c.fastBulkEncode != nil { return c.encodeBulkFastPayload(bulkFastFrame{
return c.fastBulkEncode(c.SecretKey, dataID, seq, chunk) Type: bulkFastPayloadTypeData,
} DataID: dataID,
scratch := getBulkFastFrameScratch(len(chunk)) Seq: seq,
defer putBulkFastFrameScratch(scratch) Payload: chunk,
frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)] })
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil {
return nil, err
}
copy(frame[bulkFastPayloadHeaderLen:], chunk)
return c.encryptTransportPayload(frame)
} }
func (c *ClientCommon) sendFastBulkData(ctx context.Context, dataID uint64, seq uint64, chunk []byte) error { func (c *ClientCommon) encodeBulkFastPayload(frame bulkFastFrame) ([]byte, error) {
if c == nil {
return nil, errBulkClientNil
}
profile := c.clientTransportProtectionSnapshot()
if profile.fastPlainEncode != nil {
return encodeBulkFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame)
}
plain, err := encodeBulkFastFramePayload(frame)
if err != nil {
return nil, err
}
return c.encryptTransportPayload(plain)
}
func (c *ClientCommon) encodeBulkFastBatchPayload(frames []bulkFastFrame) ([]byte, error) {
if c == nil {
return nil, errBulkClientNil
}
profile := c.clientTransportProtectionSnapshot()
if profile.fastPlainEncode != nil {
return encodeBulkFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames)
}
plain, err := encodeBulkFastBatchPlain(frames)
if err != nil {
return nil, err
}
return c.encryptTransportPayload(plain)
}
func (c *ClientCommon) encodeBulkFastPayloadPooled(frame bulkFastFrame) ([]byte, func(), error) {
if c == nil {
return nil, nil, errBulkClientNil
}
profile := c.clientTransportProtectionSnapshot()
if runtime := profile.runtime; runtime != nil {
return encodeBulkFastFramePayloadPooled(runtime, frame)
}
if profile.fastPlainEncode != nil {
payload, err := encodeBulkFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame)
return payload, nil, err
}
plain, err := encodeBulkFastFramePayload(frame)
if err != nil {
return nil, nil, err
}
payload, err := c.encryptTransportPayload(plain)
return payload, nil, err
}
func (c *ClientCommon) encodeBulkFastBatchPayloadPooled(frames []bulkFastFrame) ([]byte, func(), error) {
if c == nil {
return nil, nil, errBulkClientNil
}
profile := c.clientTransportProtectionSnapshot()
if runtime := profile.runtime; runtime != nil {
return encodeBulkFastBatchPayloadPooled(runtime, frames)
}
if profile.fastPlainEncode != nil {
payload, err := encodeBulkFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames)
return payload, nil, err
}
plain, err := encodeBulkFastBatchPlain(frames)
if err != nil {
return nil, nil, err
}
payload, err := c.encryptTransportPayload(plain)
return payload, nil, err
}
func (c *ClientCommon) sendFastBulkData(ctx context.Context, dataID uint64, seq uint64, chunk []byte, fastPathVersion uint8) error {
binding := c.clientTransportBindingSnapshot()
if binding == nil {
return net.ErrClosed
}
if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil {
return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk)
}
payload, err := c.encodeFastBulkDataPayload(dataID, seq, chunk) payload, err := c.encodeFastBulkDataPayload(dataID, seq, chunk)
if err != nil { if err != nil {
return err return err
} }
return c.writePayloadToTransport(payload)
}
func (c *ClientCommon) sendFastBulkWrite(ctx context.Context, dataID uint64, startSeq uint64, chunkSize int, fastPathVersion uint8, payload []byte, payloadOwned bool) (int, error) {
if len(payload) == 0 {
return 0, nil
}
binding := c.clientTransportBindingSnapshot()
if binding == nil {
return 0, net.ErrClosed
}
if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil {
return sender.submitWrite(ctx, dataID, startSeq, fastPathVersion, payload, chunkSize, payloadOwned)
}
if chunkSize <= 0 {
chunkSize = defaultBulkChunkSize
}
written := 0
seq := startSeq
for written < len(payload) {
end := written + chunkSize
if end > len(payload) {
end = len(payload)
}
if err := c.sendFastBulkData(ctx, dataID, seq, payload[written:end], fastPathVersion); err != nil {
return written, err
}
seq++
written = end
}
return written, nil
}
func (c *ClientCommon) sendFastBulkControl(ctx context.Context, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
frame := bulkFastFrame{
Type: frameType,
Flags: flags,
DataID: dataID,
Seq: seq,
Payload: payload,
}
binding := c.clientTransportBindingSnapshot() binding := c.clientTransportBindingSnapshot()
if binding == nil { if binding == nil {
return net.ErrClosed return net.ErrClosed
} }
if sender := binding.bulkBatchSenderSnapshot(); sender != nil { if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil {
return sender.submit(ctx, payload) return sender.submitControl(ctx, frameType, flags, dataID, seq, fastPathVersion, payload)
} }
return c.writePayloadToTransport(payload) encoded, err := c.encodeBulkFastPayload(frame)
if err != nil {
return err
}
return c.writePayloadToTransport(encoded)
} }
func (c *ClientCommon) encodeBulkFastControlPayload(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) { func (c *ClientCommon) encodeBulkFastControlPayload(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
@@ -157,22 +275,81 @@ func (c *ClientCommon) encodeBulkFastControlPayload(frameType uint8, flags uint8
} }
func (s *ServerCommon) encodeFastBulkDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) { func (s *ServerCommon) encodeFastBulkDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
if logical != nil { return s.encodeBulkFastPayloadLogical(logical, bulkFastFrame{
if fastBulkEncode := logical.fastBulkEncodeSnapshot(); fastBulkEncode != nil { Type: bulkFastPayloadTypeData,
return fastBulkEncode(logical.secretKeySnapshot(), dataID, seq, chunk) DataID: dataID,
} Seq: seq,
} Payload: chunk,
scratch := getBulkFastFrameScratch(len(chunk)) })
defer putBulkFastFrameScratch(scratch)
frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)]
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil {
return nil, err
}
copy(frame[bulkFastPayloadHeaderLen:], chunk)
return s.encryptTransportPayloadLogical(logical, frame)
} }
func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte) error { func (s *ServerCommon) encodeBulkFastPayloadLogical(logical *LogicalConn, frame bulkFastFrame) ([]byte, error) {
if logical == nil {
return nil, errTransportDetached
}
if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
return encodeBulkFastFramePayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frame)
}
plain, err := encodeBulkFastFramePayload(frame)
if err != nil {
return nil, err
}
return s.encryptTransportPayloadLogical(logical, plain)
}
func (s *ServerCommon) encodeBulkFastBatchPayloadLogical(logical *LogicalConn, frames []bulkFastFrame) ([]byte, error) {
if logical == nil {
return nil, errTransportDetached
}
if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
return encodeBulkFastBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frames)
}
plain, err := encodeBulkFastBatchPlain(frames)
if err != nil {
return nil, err
}
return s.encryptTransportPayloadLogical(logical, plain)
}
func (s *ServerCommon) encodeBulkFastPayloadLogicalPooled(logical *LogicalConn, frame bulkFastFrame) ([]byte, func(), error) {
if logical == nil {
return nil, nil, errTransportDetached
}
if runtime := logical.modernPSKRuntimeSnapshot(); runtime != nil {
return encodeBulkFastFramePayloadPooled(runtime, frame)
}
if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
payload, err := encodeBulkFastFramePayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frame)
return payload, nil, err
}
plain, err := encodeBulkFastFramePayload(frame)
if err != nil {
return nil, nil, err
}
payload, err := s.encryptTransportPayloadLogical(logical, plain)
return payload, nil, err
}
func (s *ServerCommon) encodeBulkFastBatchPayloadLogicalPooled(logical *LogicalConn, frames []bulkFastFrame) ([]byte, func(), error) {
if logical == nil {
return nil, nil, errTransportDetached
}
if runtime := logical.modernPSKRuntimeSnapshot(); runtime != nil {
return encodeBulkFastBatchPayloadPooled(runtime, frames)
}
if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
payload, err := encodeBulkFastBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frames)
return payload, nil, err
}
plain, err := encodeBulkFastBatchPlain(frames)
if err != nil {
return nil, nil, err
}
payload, err := s.encryptTransportPayloadLogical(logical, plain)
return payload, nil, err
}
func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte, fastPathVersion uint8) error {
if err := s.ensureServerTransportSendReady(transport); err != nil { if err := s.ensureServerTransportSendReady(transport); err != nil {
return err return err
} }
@@ -182,18 +359,87 @@ func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *L
if logical == nil { if logical == nil {
return errTransportDetached return errTransportDetached
} }
if binding := logical.transportBindingSnapshot(); binding != nil {
if binding.queueSnapshot() != nil {
if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil {
return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk)
}
}
}
payload, err := s.encodeFastBulkDataPayloadLogical(logical, dataID, seq, chunk) payload, err := s.encodeFastBulkDataPayloadLogical(logical, dataID, seq, chunk)
if err != nil { if err != nil {
return err return err
} }
return s.writeEnvelopePayload(logical, transport, nil, payload)
}
func (s *ServerCommon) sendFastBulkWriteTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, startSeq uint64, chunkSize int, fastPathVersion uint8, payload []byte, payloadOwned bool) (int, error) {
if len(payload) == 0 {
return 0, nil
}
if err := s.ensureServerTransportSendReady(transport); err != nil {
return 0, err
}
if logical == nil && transport != nil {
logical = transport.logicalConnSnapshot()
}
if logical == nil {
return 0, errTransportDetached
}
if binding := logical.transportBindingSnapshot(); binding != nil { if binding := logical.transportBindingSnapshot(); binding != nil {
if binding.queueSnapshot() != nil { if binding.queueSnapshot() != nil {
if sender := binding.bulkBatchSenderSnapshot(); sender != nil { if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil {
return sender.submit(ctx, payload) return sender.submitWrite(ctx, dataID, startSeq, fastPathVersion, payload, chunkSize, payloadOwned)
} }
} }
} }
return s.writeEnvelopePayload(logical, transport, nil, payload) if chunkSize <= 0 {
chunkSize = defaultBulkChunkSize
}
written := 0
seq := startSeq
for written < len(payload) {
end := written + chunkSize
if end > len(payload) {
end = len(payload)
}
if err := s.sendFastBulkDataTransport(ctx, logical, transport, dataID, seq, payload[written:end], fastPathVersion); err != nil {
return written, err
}
seq++
written = end
}
return written, nil
}
func (s *ServerCommon) sendFastBulkControlTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
if err := s.ensureServerTransportSendReady(transport); err != nil {
return err
}
if logical == nil && transport != nil {
logical = transport.logicalConnSnapshot()
}
if logical == nil {
return errTransportDetached
}
if binding := logical.transportBindingSnapshot(); binding != nil {
if binding.queueSnapshot() != nil {
if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil {
return sender.submitControl(ctx, frameType, flags, dataID, seq, fastPathVersion, payload)
}
}
}
encoded, err := s.encodeBulkFastPayloadLogical(logical, bulkFastFrame{
Type: frameType,
Flags: flags,
DataID: dataID,
Seq: seq,
Payload: payload,
})
if err != nil {
return err
}
return s.writeEnvelopePayload(logical, transport, nil, encoded)
} }
func (s *ServerCommon) encodeBulkFastControlPayloadLogical(logical *LogicalConn, frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) { func (s *ServerCommon) encodeBulkFastControlPayloadLogical(logical *LogicalConn, frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
@@ -219,24 +465,126 @@ func putBulkFastFrameScratch(buf []byte) {
bulkFastFrameScratchPool.Put(buf[:0]) bulkFastFrameScratchPool.Put(buf[:0])
} }
func transportFastPayloadMagic(payload []byte) string {
if len(payload) < 4 {
return ""
}
return string(payload[:4])
}
func (c *ClientCommon) decryptTransportPayloadPooled(payload []byte, release func()) ([]byte, func(), error) {
profile := c.clientTransportProtectionSnapshot()
return decryptTransportPayloadCodecPooled(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, payload, release)
}
func (s *ServerCommon) decryptTransportPayloadLogicalPooled(logical *LogicalConn, payload []byte, release func()) ([]byte, func(), error) {
if logical == nil {
if release != nil {
release()
}
return nil, nil, errTransportDetached
}
return decryptTransportPayloadCodecPooled(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload, release)
}
func (c *ClientCommon) tryDispatchBorrowedTransportPlain(plain []byte, release func()) bool {
switch transportFastPayloadMagic(plain) {
case bulkFastPayloadMagic, bulkFastBatchMagic:
owner := newBulkReadPayloadOwner(release)
matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
c.dispatchFastBulkFrameWithOwner(frame, owner)
return nil
})
if owner != nil {
owner.done()
}
if !matched {
walkErr = errBulkFastPayloadInvalid
}
if walkErr != nil && (c.showError || c.debugMode) {
fmt.Println("client decode bulk fast payload error", walkErr)
}
return true
case streamFastPayloadMagic, streamFastBatchMagic:
owner := newStreamReadPayloadOwner(release)
matched, walkErr := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
c.dispatchFastStreamDataWithOwner(frame, owner)
return nil
})
if owner != nil {
owner.done()
}
if !matched {
walkErr = errStreamFastPayloadInvalid
}
if walkErr != nil && (c.showError || c.debugMode) {
fmt.Println("client decode stream fast payload error", walkErr)
}
return true
default:
return false
}
}
func (s *ServerCommon) tryDispatchBorrowedTransportPlain(logical *LogicalConn, transport *TransportConn, conn net.Conn, plain []byte, release func()) bool {
switch transportFastPayloadMagic(plain) {
case bulkFastPayloadMagic, bulkFastBatchMagic:
owner := newBulkReadPayloadOwner(release)
matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, owner)
return nil
})
if owner != nil {
owner.done()
}
if !matched {
walkErr = errBulkFastPayloadInvalid
}
if walkErr != nil && (s.showError || s.debugMode) {
fmt.Println("server decode bulk fast payload error", walkErr)
}
return true
case streamFastPayloadMagic, streamFastBatchMagic:
owner := newStreamReadPayloadOwner(release)
matched, walkErr := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
s.dispatchFastStreamDataWithOwner(logical, transport, conn, frame, owner)
return nil
})
if owner != nil {
owner.done()
}
if !matched {
walkErr = errStreamFastPayloadInvalid
}
if walkErr != nil && (s.showError || s.debugMode) {
fmt.Println("server decode stream fast payload error", walkErr)
}
return true
default:
return false
}
}
func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error { func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error {
plain, err := c.decryptTransportPayload(payload) plain, err := c.decryptTransportPayload(payload)
if err != nil { if err != nil {
return err return err
} }
if frame, matched, err := decodeBulkFastFrame(plain); matched { return c.dispatchInboundTransportPlain(plain, now)
if err != nil { }
return err
} func (c *ClientCommon) dispatchInboundTransportPlain(plain []byte, now time.Time) error {
if matched, err := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
c.dispatchFastBulkFrame(frame) c.dispatchFastBulkFrame(frame)
return nil return nil
}); matched {
return err
} }
if frame, matched, err := decodeStreamFastDataFrame(plain); matched { if matched, err := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
if err != nil {
return err
}
c.dispatchFastStreamData(frame) c.dispatchFastStreamData(frame)
return nil return nil
}); matched {
return err
} }
env, err := c.decodeEnvelopePlain(plain) env, err := c.decodeEnvelopePlain(plain)
if err != nil { if err != nil {
@@ -257,19 +605,21 @@ func (s *ServerCommon) dispatchInboundTransportPayload(logical *LogicalConn, tra
if err != nil { if err != nil {
return err return err
} }
if frame, matched, err := decodeBulkFastFrame(plain); matched { return s.dispatchInboundTransportPlain(logical, transport, conn, plain, now)
if err != nil { }
return err
} func (s *ServerCommon) dispatchInboundTransportPlain(logical *LogicalConn, transport *TransportConn, conn net.Conn, plain []byte, now time.Time) error {
if matched, err := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
s.dispatchFastBulkFrame(logical, transport, conn, frame) s.dispatchFastBulkFrame(logical, transport, conn, frame)
return nil return nil
}); matched {
return err
} }
if frame, matched, err := decodeStreamFastDataFrame(plain); matched { if matched, err := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
if err != nil {
return err
}
s.dispatchFastStreamData(logical, transport, conn, frame) s.dispatchFastStreamData(logical, transport, conn, frame)
return nil return nil
}); matched {
return err
} }
env, err := s.decodeEnvelopePlain(plain) env, err := s.decodeEnvelopePlain(plain)
if err != nil { if err != nil {
+158 -13
View File
@@ -2,7 +2,7 @@ package notify
import ( import (
"fmt" "fmt"
"strconv" "net"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -16,14 +16,14 @@ type bulkRuntime struct {
mu sync.RWMutex mu sync.RWMutex
handler func(BulkAcceptInfo) error handler func(BulkAcceptInfo) error
bulks map[string]*bulkHandle bulks map[string]*bulkHandle
data map[string]*bulkHandle data map[string]map[uint64]*bulkHandle
} }
func newBulkRuntime(rolePrefix string) *bulkRuntime { func newBulkRuntime(rolePrefix string) *bulkRuntime {
return &bulkRuntime{ return &bulkRuntime{
rolePrefix: rolePrefix, rolePrefix: rolePrefix,
bulks: make(map[string]*bulkHandle), bulks: make(map[string]*bulkHandle),
data: make(map[string]*bulkHandle), data: make(map[string]map[uint64]*bulkHandle),
} }
} }
@@ -66,8 +66,8 @@ func (r *bulkRuntime) register(scope string, bulk *bulkHandle) error {
if bulk == nil || bulk.id == "" { if bulk == nil || bulk.id == "" {
return errBulkIDEmpty return errBulkIDEmpty
} }
scope = normalizeFileScope(scope)
key := bulkRuntimeKey(scope, bulk.id) key := bulkRuntimeKey(scope, bulk.id)
dataKey := bulkRuntimeDataKey(scope, bulk.dataID)
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if _, ok := r.bulks[key]; ok { if _, ok := r.bulks[key]; ok {
@@ -76,11 +76,16 @@ func (r *bulkRuntime) register(scope string, bulk *bulkHandle) error {
if bulk.dataID == 0 { if bulk.dataID == 0 {
return errBulkDataIDEmpty return errBulkDataIDEmpty
} }
if _, ok := r.data[dataKey]; ok { dataScope := r.data[scope]
if dataScope == nil {
dataScope = make(map[uint64]*bulkHandle)
r.data[scope] = dataScope
}
if _, ok := dataScope[bulk.dataID]; ok {
return errBulkAlreadyExists return errBulkAlreadyExists
} }
r.bulks[key] = bulk r.bulks[key] = bulk
r.data[dataKey] = bulk dataScope[bulk.dataID] = bulk
return nil return nil
} }
@@ -99,10 +104,14 @@ func (r *bulkRuntime) lookupByDataID(scope string, dataID uint64) (*bulkHandle,
if r == nil || dataID == 0 { if r == nil || dataID == 0 {
return nil, false return nil, false
} }
key := bulkRuntimeDataKey(scope, dataID) scope = normalizeFileScope(scope)
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
bulk, ok := r.data[key] dataScope := r.data[scope]
if dataScope == nil {
return nil, false
}
bulk, ok := dataScope[dataID]
return bulk, ok return bulk, ok
} }
@@ -110,11 +119,17 @@ func (r *bulkRuntime) remove(scope string, bulkID string) {
if r == nil || bulkID == "" { if r == nil || bulkID == "" {
return return
} }
scope = normalizeFileScope(scope)
key := bulkRuntimeKey(scope, bulkID) key := bulkRuntimeKey(scope, bulkID)
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if bulk := r.bulks[key]; bulk != nil && bulk.dataID != 0 { if bulk := r.bulks[key]; bulk != nil && bulk.dataID != 0 {
delete(r.data, bulkRuntimeDataKey(scope, bulk.dataID)) if dataScope := r.data[scope]; dataScope != nil {
delete(dataScope, bulk.dataID)
if len(dataScope) == 0 {
delete(r.data, scope)
}
}
} }
delete(r.bulks, key) delete(r.bulks, key)
} }
@@ -149,6 +164,140 @@ func (r *bulkRuntime) closeMatching(match func(string) bool, err error) {
} }
} }
func (r *bulkRuntime) resetDedicatedByConn(scope string, conn net.Conn, err error) {
if r == nil || conn == nil {
return
}
scope = normalizeFileScope(scope)
resetErr := bulkRuntimeCloseError(err)
prefix := scope + "\x00"
r.mu.RLock()
bulks := make([]*bulkHandle, 0, len(r.bulks))
for key, bulk := range r.bulks {
if bulk == nil {
continue
}
if !strings.HasPrefix(key, prefix) {
continue
}
if !bulk.Dedicated() {
continue
}
if bulk.dedicatedConnSnapshot() != conn {
continue
}
bulks = append(bulks, bulk)
}
r.mu.RUnlock()
for _, bulk := range bulks {
bulk.markReset(resetErr)
}
}
func (r *bulkRuntime) handleDedicatedReadErrorByConn(scope string, conn net.Conn, err error) {
if r == nil || conn == nil {
return
}
scope = normalizeFileScope(scope)
prefix := scope + "\x00"
r.mu.RLock()
bulks := make([]*bulkHandle, 0, len(r.bulks))
for key, bulk := range r.bulks {
if bulk == nil {
continue
}
if !strings.HasPrefix(key, prefix) {
continue
}
if !bulk.Dedicated() {
continue
}
if bulk.dedicatedConnSnapshot() != conn {
continue
}
bulks = append(bulks, bulk)
}
r.mu.RUnlock()
for _, bulk := range bulks {
handleDedicatedBulkReadError(bulk, err)
}
}
func (r *bulkRuntime) attachSharedDedicatedConn(scope string, laneID uint32, conn net.Conn) {
if r == nil || conn == nil {
return
}
scope = normalizeFileScope(scope)
laneID = normalizeBulkDedicatedLaneID(laneID)
prefix := scope + "\x00"
r.mu.RLock()
bulks := make([]*bulkHandle, 0, len(r.bulks))
for key, bulk := range r.bulks {
if bulk == nil {
continue
}
if !strings.HasPrefix(key, prefix) {
continue
}
if !bulk.Dedicated() {
continue
}
if bulk.dedicatedLaneIDSnapshot() != laneID {
continue
}
bulks = append(bulks, bulk)
}
r.mu.RUnlock()
for _, bulk := range bulks {
current := bulk.dedicatedConnSnapshot()
switch {
case current == conn:
_ = bulk.attachDedicatedConnShared(conn)
case current == nil:
_ = bulk.attachDedicatedConnShared(conn)
default:
oldConn, oldSender, err := bulk.replaceDedicatedConnShared(conn)
if err != nil {
bulk.markReset(err)
continue
}
if oldSender != nil {
oldSender.stop()
}
if oldConn != nil {
_ = oldConn.Close()
}
}
}
}
func (r *bulkRuntime) dedicatedBulksForConn(scope string, conn net.Conn) []*bulkHandle {
if r == nil || conn == nil {
return nil
}
scope = normalizeFileScope(scope)
prefix := scope + "\x00"
r.mu.RLock()
bulks := make([]*bulkHandle, 0, len(r.bulks))
for key, bulk := range r.bulks {
if bulk == nil {
continue
}
if !strings.HasPrefix(key, prefix) {
continue
}
if !bulk.Dedicated() {
continue
}
if bulk.dedicatedConnSnapshot() != conn {
continue
}
bulks = append(bulks, bulk)
}
r.mu.RUnlock()
return bulks
}
func (r *bulkRuntime) snapshots() []BulkSnapshot { func (r *bulkRuntime) snapshots() []BulkSnapshot {
if r == nil { if r == nil {
return nil return nil
@@ -170,10 +319,6 @@ func bulkRuntimeKey(scope string, bulkID string) string {
return normalizeFileScope(scope) + "\x00" + bulkID return normalizeFileScope(scope) + "\x00" + bulkID
} }
func bulkRuntimeDataKey(scope string, dataID uint64) string {
return normalizeFileScope(scope) + "\x01" + strconv.FormatUint(dataID, 10)
}
func bulkRuntimeCloseError(err error) error { func bulkRuntimeCloseError(err error) error {
if err != nil { if err != nil {
return err return err
+228
View File
@@ -0,0 +1,228 @@
package notify
import (
"encoding/binary"
)
const (
bulkFastPathVersionV1 = 1
bulkFastPathVersionV2 = 2
bulkFastPathVersionCurrent = bulkFastPathVersionV2
)
const (
bulkFastBatchMagic = "NBF2"
bulkFastBatchVersion = 1
bulkFastBatchHeaderLen = 12
bulkFastBatchItemHeaderLen = 24
bulkFastBatchMaxItems = 64
bulkFastBatchMaxPlainBytes = 8 * 1024 * 1024
)
func normalizeBulkFastPathVersion(version uint8) uint8 {
if version < bulkFastPathVersionV1 {
return bulkFastPathVersionV1
}
if version > bulkFastPathVersionCurrent {
return bulkFastPathVersionCurrent
}
return version
}
func negotiateBulkFastPathVersion(version uint8) uint8 {
return normalizeBulkFastPathVersion(version)
}
func bulkFastPathSupportsSharedBatch(version uint8) bool {
return normalizeBulkFastPathVersion(version) >= bulkFastPathVersionV2
}
func bulkFastBatchFrameLen(frame bulkFastFrame) int {
return bulkFastBatchItemHeaderLen + len(frame.Payload)
}
func bulkFastBatchPlainLen(frames []bulkFastFrame) int {
total := bulkFastBatchHeaderLen
for _, frame := range frames {
total += bulkFastBatchFrameLen(frame)
}
return total
}
func encodeBulkFastFramePayload(frame bulkFastFrame) ([]byte, error) {
return encodeBulkFastControlFrame(frame.Type, frame.Flags, frame.DataID, frame.Seq, frame.Payload)
}
func encodeBulkFastFramePayloadFast(encode transportFastPlainEncoder, secretKey []byte, frame bulkFastFrame) ([]byte, error) {
if encode == nil {
return nil, errTransportPayloadEncryptFailed
}
plainLen := bulkFastPayloadHeaderLen + len(frame.Payload)
return encode(secretKey, plainLen, func(dst []byte) error {
if err := encodeBulkFastFrameHeader(dst, frame.Type, frame.Flags, frame.DataID, frame.Seq, len(frame.Payload)); err != nil {
return err
}
copy(dst[bulkFastPayloadHeaderLen:], frame.Payload)
return nil
})
}
func encodeBulkFastFramePayloadPooled(runtime *modernPSKCodecRuntime, frame bulkFastFrame) ([]byte, func(), error) {
if runtime == nil {
return nil, nil, errTransportPayloadEncryptFailed
}
plainLen := bulkFastPayloadHeaderLen + len(frame.Payload)
return runtime.sealFilledPayloadPooled(plainLen, func(dst []byte) error {
if err := encodeBulkFastFrameHeader(dst, frame.Type, frame.Flags, frame.DataID, frame.Seq, len(frame.Payload)); err != nil {
return err
}
copy(dst[bulkFastPayloadHeaderLen:], frame.Payload)
return nil
})
}
func encodeBulkFastBatchPlain(frames []bulkFastFrame) ([]byte, error) {
if len(frames) == 0 {
return nil, errBulkFastPayloadInvalid
}
buf := make([]byte, bulkFastBatchPlainLen(frames))
if err := writeBulkFastBatchPlain(buf, frames); err != nil {
return nil, err
}
return buf, nil
}
func encodeBulkFastBatchPayloadFast(encode transportFastPlainEncoder, secretKey []byte, frames []bulkFastFrame) ([]byte, error) {
if encode == nil {
return nil, errTransportPayloadEncryptFailed
}
plainLen := bulkFastBatchPlainLen(frames)
return encode(secretKey, plainLen, func(dst []byte) error {
return writeBulkFastBatchPlain(dst, frames)
})
}
func encodeBulkFastBatchPayloadPooled(runtime *modernPSKCodecRuntime, frames []bulkFastFrame) ([]byte, func(), error) {
if runtime == nil {
return nil, nil, errTransportPayloadEncryptFailed
}
return runtime.sealFilledPayloadPooled(bulkFastBatchPlainLen(frames), func(dst []byte) error {
return writeBulkFastBatchPlain(dst, frames)
})
}
func writeBulkFastBatchPlain(dst []byte, frames []bulkFastFrame) error {
if len(frames) == 0 || len(dst) != bulkFastBatchPlainLen(frames) {
return errBulkFastPayloadInvalid
}
copy(dst[:4], bulkFastBatchMagic)
dst[4] = bulkFastBatchVersion
binary.BigEndian.PutUint32(dst[8:12], uint32(len(frames)))
offset := bulkFastBatchHeaderLen
for _, frame := range frames {
if frame.DataID == 0 {
return errBulkFastPayloadInvalid
}
dst[offset] = frame.Type
dst[offset+1] = frame.Flags
binary.BigEndian.PutUint64(dst[offset+4:offset+12], frame.DataID)
binary.BigEndian.PutUint64(dst[offset+12:offset+20], frame.Seq)
binary.BigEndian.PutUint32(dst[offset+20:offset+24], uint32(len(frame.Payload)))
offset += bulkFastBatchItemHeaderLen
copy(dst[offset:offset+len(frame.Payload)], frame.Payload)
offset += len(frame.Payload)
}
return nil
}
func walkBulkFastBatchPlain(payload []byte, fn func(bulkFastFrame) error) (bool, error) {
if len(payload) < 4 || string(payload[:4]) != bulkFastBatchMagic {
return false, nil
}
if len(payload) < bulkFastBatchHeaderLen {
return true, errBulkFastPayloadInvalid
}
if payload[4] != bulkFastBatchVersion {
return true, errBulkFastPayloadInvalid
}
count := int(binary.BigEndian.Uint32(payload[8:12]))
if count <= 0 {
return true, errBulkFastPayloadInvalid
}
offset := bulkFastBatchHeaderLen
for index := 0; index < count; index++ {
if len(payload)-offset < bulkFastBatchItemHeaderLen {
return true, errBulkFastPayloadInvalid
}
frameType := payload[offset]
switch frameType {
case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
default:
return true, errBulkFastPayloadInvalid
}
flags := payload[offset+1]
dataID := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
seq := binary.BigEndian.Uint64(payload[offset+12 : offset+20])
payloadLen := int(binary.BigEndian.Uint32(payload[offset+20 : offset+24]))
offset += bulkFastBatchItemHeaderLen
if dataID == 0 || payloadLen < 0 || len(payload)-offset < payloadLen {
return true, errBulkFastPayloadInvalid
}
if fn != nil {
if err := fn(bulkFastFrame{
Type: frameType,
Flags: flags,
DataID: dataID,
Seq: seq,
Payload: payload[offset : offset+payloadLen],
}); err != nil {
return true, err
}
}
offset += payloadLen
}
if offset != len(payload) {
return true, errBulkFastPayloadInvalid
}
return true, nil
}
func decodeBulkFastBatchPlain(payload []byte) ([]bulkFastFrame, bool, error) {
frames := make([]bulkFastFrame, 0, 1)
matched, err := walkBulkFastBatchPlain(payload, func(frame bulkFastFrame) error {
frames = append(frames, frame)
return nil
})
if !matched || err != nil {
return nil, matched, err
}
return frames, true, nil
}
func walkBulkFastFrames(payload []byte, fn func(bulkFastFrame) error) (bool, error) {
if matched, err := walkBulkFastBatchPlain(payload, fn); matched {
return true, err
}
frame, matched, err := decodeBulkFastFrame(payload)
if !matched || err != nil {
return matched, err
}
if fn != nil {
if err := fn(frame); err != nil {
return true, err
}
}
return true, nil
}
func decodeBulkFastFrames(payload []byte) ([]bulkFastFrame, bool, error) {
frames := make([]bulkFastFrame, 0, 1)
matched, err := walkBulkFastFrames(payload, func(frame bulkFastFrame) error {
frames = append(frames, frame)
return nil
})
if !matched || err != nil {
return nil, matched, err
}
return frames, true, nil
}
+458
View File
@@ -0,0 +1,458 @@
package notify
import (
"context"
"testing"
"time"
)
func TestBulkFastBatchPlainRoundTrip(t *testing.T) {
releasePayload, err := encodeBulkDedicatedReleasePayload(4096, 2)
if err != nil {
t.Fatalf("encodeBulkDedicatedReleasePayload failed: %v", err)
}
frames := []bulkFastFrame{
{
Type: bulkFastPayloadTypeData,
DataID: 11,
Seq: 7,
Payload: []byte("alpha"),
},
{
Type: bulkFastPayloadTypeRelease,
DataID: 12,
Seq: 0,
Payload: releasePayload,
},
}
wire, err := encodeBulkFastBatchPlain(frames)
if err != nil {
t.Fatalf("encodeBulkFastBatchPlain failed: %v", err)
}
decoded, matched, err := decodeBulkFastBatchPlain(wire)
if err != nil {
t.Fatalf("decodeBulkFastBatchPlain failed: %v", err)
}
if !matched {
t.Fatal("decodeBulkFastBatchPlain should match encoded batch")
}
if got, want := len(decoded), len(frames); got != want {
t.Fatalf("decoded frame count = %d, want %d", got, want)
}
for index := range frames {
if got, want := decoded[index].Type, frames[index].Type; got != want {
t.Fatalf("frame %d type = %d, want %d", index, got, want)
}
if got, want := decoded[index].DataID, frames[index].DataID; got != want {
t.Fatalf("frame %d dataID = %d, want %d", index, got, want)
}
if got, want := decoded[index].Seq, frames[index].Seq; got != want {
t.Fatalf("frame %d seq = %d, want %d", index, got, want)
}
if got, want := string(decoded[index].Payload), string(frames[index].Payload); got != want {
t.Fatalf("frame %d payload = %q, want %q", index, got, want)
}
}
}
func TestBulkBatchSenderEncodeRequestsCoalescesSharedFastV2Frames(t *testing.T) {
var (
singleCalls int
batchCalls [][]bulkFastFrame
)
sender := &bulkBatchSender{
codec: bulkBatchCodec{
encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) {
singleCalls++
return []byte("single"), nil, nil
},
encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) {
cloned := make([]bulkFastFrame, len(frames))
copy(cloned, frames)
batchCalls = append(batchCalls, cloned)
return []byte("batch"), nil, nil
},
},
}
payloads, err := sender.encodeRequests([]bulkBatchRequest{
{
frames: []bulkFastFrame{{
Type: bulkFastPayloadTypeData,
DataID: 101,
Seq: 1,
Payload: []byte("a"),
}},
fastPathVersion: bulkFastPathVersionV2,
},
{
frames: []bulkFastFrame{{
Type: bulkFastPayloadTypeRelease,
DataID: 101,
Seq: 0,
Payload: []byte("rel"),
}},
fastPathVersion: bulkFastPathVersionV2,
},
})
if err != nil {
t.Fatalf("encodeRequests failed: %v", err)
}
if got, want := len(payloads), 1; got != want {
t.Fatalf("payload count = %d, want %d", got, want)
}
if singleCalls != 0 {
t.Fatalf("single encode calls = %d, want 0", singleCalls)
}
if got, want := len(batchCalls), 1; got != want {
t.Fatalf("batch encode calls = %d, want %d", got, want)
}
if got, want := len(batchCalls[0]), 2; got != want {
t.Fatalf("batched item count = %d, want %d", got, want)
}
}
func TestBulkBatchSenderEncodeRequestsCoalescesAcrossRequestsAndBulks(t *testing.T) {
var (
singleCalls int
batchCalls [][]bulkFastFrame
)
sender := &bulkBatchSender{
codec: bulkBatchCodec{
encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) {
singleCalls++
return []byte("single"), nil, nil
},
encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) {
cloned := make([]bulkFastFrame, len(frames))
copy(cloned, frames)
batchCalls = append(batchCalls, cloned)
return []byte("batch"), nil, nil
},
},
}
payloads, err := sender.encodeRequests([]bulkBatchRequest{
{
frames: []bulkFastFrame{{
Type: bulkFastPayloadTypeData,
DataID: 101,
Seq: 1,
Payload: []byte("bulk-a"),
}},
fastPathVersion: bulkFastPathVersionV2,
},
{
frames: []bulkFastFrame{{
Type: bulkFastPayloadTypeData,
DataID: 202,
Seq: 1,
Payload: []byte("bulk-b"),
}},
fastPathVersion: bulkFastPathVersionV2,
},
})
if err != nil {
t.Fatalf("encodeRequests failed: %v", err)
}
if got, want := len(payloads), 1; got != want {
t.Fatalf("payload count = %d, want %d", got, want)
}
if singleCalls != 0 {
t.Fatalf("single encode calls = %d, want 0", singleCalls)
}
if got, want := len(batchCalls), 1; got != want {
t.Fatalf("batch encode calls = %d, want %d", got, want)
}
if got, want := len(batchCalls[0]), 2; got != want {
t.Fatalf("batched item count = %d, want %d", got, want)
}
if got, want := batchCalls[0][0].DataID, uint64(101); got != want {
t.Fatalf("first batched dataID = %d, want %d", got, want)
}
if got, want := batchCalls[0][1].DataID, uint64(202); got != want {
t.Fatalf("second batched dataID = %d, want %d", got, want)
}
}
func TestBulkBatchSenderEncodeRequestsSplitsLargeCrossBulkSuperBatch(t *testing.T) {
var batchCalls [][]bulkFastFrame
sender := &bulkBatchSender{
codec: bulkBatchCodec{
encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) {
return []byte("single"), nil, nil
},
encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) {
cloned := make([]bulkFastFrame, len(frames))
copy(cloned, frames)
batchCalls = append(batchCalls, cloned)
return []byte("batch"), nil, nil
},
},
}
payload := make([]byte, 768*1024)
payloads, err := sender.encodeRequests([]bulkBatchRequest{
{
frames: []bulkFastFrame{
{
Type: bulkFastPayloadTypeData,
DataID: 101,
Seq: 1,
Payload: payload,
},
{
Type: bulkFastPayloadTypeData,
DataID: 101,
Seq: 2,
Payload: payload,
},
},
fastPathVersion: bulkFastPathVersionV2,
},
{
frames: []bulkFastFrame{
{
Type: bulkFastPayloadTypeData,
DataID: 202,
Seq: 1,
Payload: payload,
},
{
Type: bulkFastPayloadTypeData,
DataID: 202,
Seq: 2,
Payload: payload,
},
},
fastPathVersion: bulkFastPathVersionV2,
},
})
if err != nil {
t.Fatalf("encodeRequests failed: %v", err)
}
if got, want := len(payloads), 2; got != want {
t.Fatalf("payload count = %d, want %d", got, want)
}
if got, want := len(batchCalls), 2; got != want {
t.Fatalf("batch encode calls = %d, want %d", got, want)
}
if got, want := len(batchCalls[0]), 2; got != want {
t.Fatalf("first payload frame count = %d, want %d", got, want)
}
if got, want := len(batchCalls[1]), 2; got != want {
t.Fatalf("second payload frame count = %d, want %d", got, want)
}
if got, want := batchCalls[0][0].DataID, uint64(101); got != want {
t.Fatalf("first payload dataID = %d, want %d", got, want)
}
if got, want := batchCalls[1][0].DataID, uint64(202); got != want {
t.Fatalf("second payload dataID = %d, want %d", got, want)
}
}
func TestBulkReleaseFastRoundTripTransport(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
acceptCh := make(chan BulkAcceptInfo, 1)
server.SetBulkHandler(func(info BulkAcceptInfo) error {
acceptCh <- info
return nil
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
t.Fatalf("server Listen failed: %v", err)
}
defer func() {
_ = server.Stop()
}()
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
t.Fatalf("client Connect failed: %v", err)
}
defer func() {
_ = client.Stop()
}()
const chunkSize = 64 * 1024
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
Range: BulkRange{Offset: 0, Length: chunkSize},
ChunkSize: chunkSize,
WindowBytes: chunkSize,
MaxInFlight: 1,
})
if err != nil {
t.Fatalf("client OpenBulk failed: %v", err)
}
accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second)
clientHandle := bulk.(*bulkHandle)
serverHandle := accepted.Bulk.(*bulkHandle)
if got, want := clientHandle.FastPathVersion(), uint8(bulkFastPathVersionV2); got != want {
t.Fatalf("client fast path version = %d, want %d", got, want)
}
if got, want := serverHandle.FastPathVersion(), uint8(bulkFastPathVersionV2); got != want {
t.Fatalf("server fast path version = %d, want %d", got, want)
}
clientHandle.mu.Lock()
clientHandle.outboundAvailBytes = 0
clientHandle.outboundInFlight = 1
clientHandle.mu.Unlock()
releaseFn := serverBulkReleaseSender(server, accepted.LogicalConn, accepted.TransportConn)
if err := releaseFn(serverHandle, chunkSize, 1); err != nil {
t.Fatalf("server bulk release sender failed: %v", err)
}
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
clientHandle.mu.Lock()
avail := clientHandle.outboundAvailBytes
inFlight := clientHandle.outboundInFlight
clientHandle.mu.Unlock()
if avail == chunkSize && inFlight == 0 {
return
}
time.Sleep(10 * time.Millisecond)
}
clientHandle.mu.Lock()
avail := clientHandle.outboundAvailBytes
inFlight := clientHandle.outboundInFlight
clientHandle.mu.Unlock()
t.Fatalf("client outbound window not released by fast path: avail=%d inFlight=%d", avail, inFlight)
}
func TestBulkBatchSenderEncodeRequestsResetsBatchBytesAfterFlushBoundary(t *testing.T) {
var (
singleCalls int
batchCalls [][]bulkFastFrame
)
sender := &bulkBatchSender{
codec: bulkBatchCodec{
encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) {
singleCalls++
return []byte("single"), nil, nil
},
encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) {
cloned := make([]bulkFastFrame, len(frames))
copy(cloned, frames)
batchCalls = append(batchCalls, cloned)
return []byte("batch"), nil, nil
},
},
}
largePayload := make([]byte, bulkFastBatchMaxPlainBytes-bulkFastBatchHeaderLen-bulkFastBatchItemHeaderLen-128)
payloads, err := sender.encodeRequests([]bulkBatchRequest{
{
frames: []bulkFastFrame{{
Type: bulkFastPayloadTypeData,
DataID: 101,
Seq: 1,
Payload: largePayload,
}},
fastPathVersion: bulkFastPathVersionV2,
},
{
frames: []bulkFastFrame{{
Type: bulkFastPayloadTypeData,
DataID: 101,
Seq: 2,
Payload: []byte("sep"),
}},
fastPathVersion: bulkFastPathVersionV1,
},
{
frames: []bulkFastFrame{{
Type: bulkFastPayloadTypeData,
DataID: 202,
Seq: 1,
Payload: []byte("a"),
}},
fastPathVersion: bulkFastPathVersionV2,
},
{
frames: []bulkFastFrame{{
Type: bulkFastPayloadTypeData,
DataID: 202,
Seq: 2,
Payload: []byte("b"),
}},
fastPathVersion: bulkFastPathVersionV2,
},
})
if err != nil {
t.Fatalf("encodeRequests failed: %v", err)
}
if got, want := len(payloads), 3; got != want {
t.Fatalf("payload count = %d, want %d", got, want)
}
if got, want := singleCalls, 2; got != want {
t.Fatalf("single encode calls = %d, want %d", got, want)
}
if got, want := len(batchCalls), 1; got != want {
t.Fatalf("batch encode calls = %d, want %d", got, want)
}
if got, want := len(batchCalls[0]), 2; got != want {
t.Fatalf("post-flush batched frame count = %d, want %d", got, want)
}
}
func TestBulkBatchSenderEncodeRequestsUsesBindingAdaptiveSoftLimit(t *testing.T) {
binding := &transportBinding{}
binding.observeBulkAdaptivePayloadWrite(8*1024*1024, 640*time.Millisecond, 0, nil)
var batchCalls [][]bulkFastFrame
sender := &bulkBatchSender{
binding: binding,
codec: bulkBatchCodec{
encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) {
return []byte("single"), nil, nil
},
encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) {
cloned := make([]bulkFastFrame, len(frames))
copy(cloned, frames)
batchCalls = append(batchCalls, cloned)
return []byte("batch"), nil, nil
},
},
}
payload := make([]byte, 96*1024)
payloads, err := sender.encodeRequests([]bulkBatchRequest{
{
frames: []bulkFastFrame{
{Type: bulkFastPayloadTypeData, DataID: 101, Seq: 1, Payload: payload},
{Type: bulkFastPayloadTypeData, DataID: 101, Seq: 2, Payload: payload},
},
fastPathVersion: bulkFastPathVersionV2,
},
{
frames: []bulkFastFrame{
{Type: bulkFastPayloadTypeData, DataID: 202, Seq: 1, Payload: payload},
{Type: bulkFastPayloadTypeData, DataID: 202, Seq: 2, Payload: payload},
},
fastPathVersion: bulkFastPathVersionV2,
},
})
if err != nil {
t.Fatalf("encodeRequests failed: %v", err)
}
if got, want := len(payloads), 2; got != want {
t.Fatalf("payload count = %d, want %d", got, want)
}
if got, want := len(batchCalls), 2; got != want {
t.Fatalf("batch encode calls = %d, want %d", got, want)
}
for index, want := range []uint64{101, 202} {
if got := batchCalls[index][0].DataID; got != want {
t.Fatalf("payload %d first dataID = %d, want %d", index, got, want)
}
if got, wantItems := len(batchCalls[index]), 2; got != wantItems {
t.Fatalf("payload %d item count = %d, want %d", index, got, wantItems)
}
}
}
+50 -43
View File
@@ -7,49 +7,56 @@ import (
) )
type BulkSnapshot struct { type BulkSnapshot struct {
ID string ID string
DataID uint64 DataID uint64
Scope string FastPathVersion uint8
Range BulkRange Scope string
Metadata BulkMetadata Range BulkRange
BindingOwner string Metadata BulkMetadata
BindingAlive bool BindingOwner string
BindingCurrent bool BindingAlive bool
BindingReason string BindingCurrent bool
BindingError string BindingReason string
Dedicated bool BindingError string
DedicatedAttached bool BindingBulkAdaptiveSoftPayloadBytes int
SessionEpoch uint64 Dedicated bool
LogicalClientID string DedicatedLaneID uint32
TransportGeneration uint64 DedicatedAttached bool
TransportAttached bool DedicatedAttachState string
TransportHasRuntimeConn bool DedicatedAttachAttempts uint32
TransportCurrent bool DedicatedAttachLastCode string
TransportDetachReason string DedicatedDataStarted bool
TransportDetachKind string SessionEpoch uint64
TransportDetachGeneration uint64 LogicalClientID string
TransportDetachError string TransportGeneration uint64
TransportDetachedAt time.Time TransportAttached bool
ReattachEligible bool TransportHasRuntimeConn bool
LocalClosed bool TransportCurrent bool
LocalReadClosed bool TransportDetachReason string
RemoteClosed bool TransportDetachKind string
PeerReadClosed bool TransportDetachGeneration uint64
BufferedChunks int TransportDetachError string
BufferedBytes int TransportDetachedAt time.Time
ReadTimeout time.Duration ReattachEligible bool
WriteTimeout time.Duration LocalClosed bool
ChunkSize int LocalReadClosed bool
WindowBytes int RemoteClosed bool
MaxInFlight int PeerReadClosed bool
BytesRead int64 BufferedChunks int
BytesWritten int64 BufferedBytes int
ReadCalls int64 ReadTimeout time.Duration
WriteCalls int64 WriteTimeout time.Duration
OpenedAt time.Time ChunkSize int
LastReadAt time.Time WindowBytes int
LastWriteAt time.Time MaxInFlight int
ResetError string BytesRead int64
BytesWritten int64
ReadCalls int64
WriteCalls int64
OpenedAt time.Time
LastReadAt time.Time
LastWriteAt time.Time
ResetError string
} }
type clientBulkSnapshotReader interface { type clientBulkSnapshotReader interface {
+2 -2
View File
@@ -88,7 +88,7 @@ func BenchmarkDedicatedWireLocalhostThroughput(b *testing.B) {
func benchmarkDedicatedWireLocalhostThroughput(b *testing.B, key []byte, encode transportFastPlainEncoder, payloadSize int) { func benchmarkDedicatedWireLocalhostThroughput(b *testing.B, key []byte, encode transportFastPlainEncoder, payloadSize int) {
b.Helper() b.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", benchmarkTCPListenAddr(b))
if err != nil { if err != nil {
b.Fatalf("net.Listen failed: %v", err) b.Fatalf("net.Listen failed: %v", err)
} }
@@ -107,7 +107,7 @@ func benchmarkDedicatedWireLocalhostThroughput(b *testing.B, key []byte, encode
acceptCh <- conn acceptCh <- conn
}() }()
clientConn, err := net.Dial("tcp", listener.Addr().String()) clientConn, err := net.Dial("tcp", benchmarkTCPDialAddr(b, listener.Addr().String()))
if err != nil { if err != nil {
b.Fatalf("net.Dial failed: %v", err) b.Fatalf("net.Dial failed: %v", err)
} }
+320 -3
View File
@@ -109,6 +109,9 @@ func TestBulkOpenRoundTripTCP(t *testing.T) {
if !clientSnapshots[0].BindingAlive || !clientSnapshots[0].BindingCurrent || !clientSnapshots[0].TransportAttached || !clientSnapshots[0].TransportCurrent { if !clientSnapshots[0].BindingAlive || !clientSnapshots[0].BindingCurrent || !clientSnapshots[0].TransportAttached || !clientSnapshots[0].TransportCurrent {
t.Fatalf("client bulk binding snapshot mismatch: %+v", clientSnapshots[0]) t.Fatalf("client bulk binding snapshot mismatch: %+v", clientSnapshots[0])
} }
if got, want := clientSnapshots[0].BindingBulkAdaptiveSoftPayloadBytes, bulkAdaptiveSoftPayloadStartBytes; got != want {
t.Fatalf("client bulk BindingBulkAdaptiveSoftPayloadBytes = %d, want %d", got, want)
}
serverSnapshots, err := GetServerBulkSnapshots(server) serverSnapshots, err := GetServerBulkSnapshots(server)
if err != nil { if err != nil {
t.Fatalf("GetServerBulkSnapshots failed: %v", err) t.Fatalf("GetServerBulkSnapshots failed: %v", err)
@@ -119,6 +122,9 @@ func TestBulkOpenRoundTripTCP(t *testing.T) {
if got, want := serverSnapshots[0].BindingOwner, "server-transport"; got != want { if got, want := serverSnapshots[0].BindingOwner, "server-transport"; got != want {
t.Fatalf("server bulk BindingOwner = %q, want %q", got, want) t.Fatalf("server bulk BindingOwner = %q, want %q", got, want)
} }
if got, want := serverSnapshots[0].BindingBulkAdaptiveSoftPayloadBytes, bulkAdaptiveSoftPayloadStartBytes; got != want {
t.Fatalf("server bulk BindingBulkAdaptiveSoftPayloadBytes = %d, want %d", got, want)
}
if !serverSnapshots[0].BindingAlive || !serverSnapshots[0].BindingCurrent || !serverSnapshots[0].TransportAttached || !serverSnapshots[0].TransportCurrent { if !serverSnapshots[0].BindingAlive || !serverSnapshots[0].BindingCurrent || !serverSnapshots[0].TransportAttached || !serverSnapshots[0].TransportCurrent {
t.Fatalf("server bulk binding snapshot mismatch: %+v", serverSnapshots[0]) t.Fatalf("server bulk binding snapshot mismatch: %+v", serverSnapshots[0])
} }
@@ -135,6 +141,314 @@ func TestBulkOpenRoundTripTCP(t *testing.T) {
waitForBulkContextDone(t, bulk.Context(), 2*time.Second) waitForBulkContextDone(t, bulk.Context(), 2*time.Second)
} }
func TestDedicatedBulkOpenUnblocksSynchronousReadHandler(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
readCh := make(chan string, 1)
server.SetBulkHandler(func(info BulkAcceptInfo) error {
defer func() {
_ = info.Bulk.Close()
}()
buf := make([]byte, 5)
if _, err := io.ReadFull(info.Bulk, buf); err != nil {
return err
}
readCh <- string(buf)
return nil
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
t.Fatalf("server Listen failed: %v", err)
}
defer func() {
_ = server.Stop()
}()
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
t.Fatalf("client Connect failed: %v", err)
}
defer func() {
_ = client.Stop()
}()
type openResult struct {
bulk Bulk
err error
}
openCh := make(chan openResult, 1)
go func() {
bulk, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{
ID: "sync-read-handler",
Range: BulkRange{
Offset: 0,
Length: 5,
},
})
openCh <- openResult{bulk: bulk, err: err}
}()
var bulk Bulk
select {
case result := <-openCh:
if result.err != nil {
t.Fatalf("client OpenDedicatedBulk failed: %v", result.err)
}
bulk = result.bulk
case <-time.After(2 * time.Second):
t.Fatal("client OpenDedicatedBulk timed out while remote handler was synchronously reading")
}
defer func() {
if bulk != nil {
_ = bulk.Close()
}
}()
if _, err := bulk.Write([]byte("hello")); err != nil {
t.Fatalf("client dedicated bulk Write failed: %v", err)
}
select {
case got := <-readCh:
if got != "hello" {
t.Fatalf("server handler read %q, want %q", got, "hello")
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for synchronous handler read")
}
}
func TestDedicatedBulkOpenUnblocksOnBlockingFirstWrite(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
payload := strings.Repeat("w", 16)
writeDone := make(chan error, 1)
server.SetBulkHandler(func(info BulkAcceptInfo) error {
_, err := io.WriteString(info.Bulk, payload)
if err == nil {
err = info.Bulk.CloseWrite()
}
writeDone <- err
return err
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
t.Fatalf("server Listen failed: %v", err)
}
defer func() {
_ = server.Stop()
}()
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
t.Fatalf("client Connect failed: %v", err)
}
defer func() {
_ = client.Stop()
}()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
bulk, err := client.OpenDedicatedBulk(ctx, BulkOpenOptions{
ID: "blocking-write-ready",
Range: BulkRange{
Offset: 0,
Length: int64(len(payload)),
},
ChunkSize: 4,
WindowBytes: 4,
MaxInFlight: 1,
})
if err != nil {
t.Fatalf("client OpenDedicatedBulk failed: %v", err)
}
readBulkExactly(t, bulk, payload, 2*time.Second)
waitForBulkReadEOF(t, bulk, 2*time.Second)
select {
case err := <-writeDone:
if err != nil {
t.Fatalf("server handler write failed: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for blocking first write to finish")
}
if err := bulk.Close(); err != nil {
t.Fatalf("client dedicated bulk Close failed: %v", err)
}
waitForBulkContextDone(t, bulk.Context(), 2*time.Second)
}
func TestServerOpenBulkLogicalDedicatedUnblocksOnBlockingFirstRead(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
t.Fatalf("server Listen failed: %v", err)
}
defer func() {
_ = server.Stop()
}()
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
readCh := make(chan string, 1)
client.SetBulkHandler(func(info BulkAcceptInfo) error {
defer func() {
_ = info.Bulk.Close()
}()
buf := make([]byte, 5)
if _, err := io.ReadFull(info.Bulk, buf); err != nil {
return err
}
readCh <- string(buf)
return nil
})
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
t.Fatalf("client Connect failed: %v", err)
}
defer func() {
_ = client.Stop()
}()
logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
bulk, err := server.OpenBulkLogical(ctx, logical, BulkOpenOptions{
ID: "server-blocking-read-ready",
Range: BulkRange{
Offset: 0,
Length: 5,
},
Dedicated: true,
})
if err != nil {
t.Fatalf("server OpenBulkLogical dedicated failed: %v", err)
}
if _, err := bulk.Write([]byte("hello")); err != nil {
t.Fatalf("server dedicated bulk Write failed: %v", err)
}
select {
case got := <-readCh:
if got != "hello" {
t.Fatalf("client handler read %q, want %q", got, "hello")
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for blocking first read to finish")
}
if err := bulk.Close(); err != nil {
t.Fatalf("server dedicated bulk Close failed: %v", err)
}
waitForBulkContextDone(t, bulk.Context(), 2*time.Second)
}
func TestDedicatedBulkOpenReturnsHandlerFailureAfterAccepted(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
server.SetBulkHandler(func(info BulkAcceptInfo) error {
time.Sleep(80 * time.Millisecond)
return errors.New("dedicated handler failed after accept")
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
t.Fatalf("server Listen failed: %v", err)
}
defer func() {
_ = server.Stop()
}()
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
t.Fatalf("client Connect failed: %v", err)
}
defer func() {
_ = client.Stop()
}()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, err := client.OpenDedicatedBulk(ctx, BulkOpenOptions{
ID: "accepted-then-fail",
Range: BulkRange{
Offset: 0,
Length: 1,
},
})
if err == nil || !strings.Contains(err.Error(), "dedicated handler failed after accept") {
t.Fatalf("client OpenDedicatedBulk error = %v, want dedicated handler failure after accept", err)
}
}
func TestServerOpenBulkLogicalDedicatedReturnsHandlerFailureAfterAccepted(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
t.Fatalf("server Listen failed: %v", err)
}
defer func() {
_ = server.Stop()
}()
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
client.SetBulkHandler(func(info BulkAcceptInfo) error {
time.Sleep(80 * time.Millisecond)
return errors.New("client dedicated handler failed after accept")
})
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
t.Fatalf("client Connect failed: %v", err)
}
defer func() {
_ = client.Stop()
}()
logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, err := server.OpenBulkLogical(ctx, logical, BulkOpenOptions{
ID: "server-accepted-then-fail",
Range: BulkRange{
Offset: 0,
Length: 1,
},
Dedicated: true,
})
if err == nil || !strings.Contains(err.Error(), "client dedicated handler failed after accept") {
t.Fatalf("server OpenBulkLogical dedicated error = %v, want client dedicated handler failure after accept", err)
}
}
func TestBulkOpenRoundTripServerLogicalTCP(t *testing.T) { func TestBulkOpenRoundTripServerLogicalTCP(t *testing.T) {
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
@@ -532,15 +846,18 @@ func TestDedicatedBulkWritePrefersClosedPipeOverContextCanceled(t *testing.T) {
MaxInFlight: 4, MaxInFlight: 4,
}, 0, nil, nil, 0, nil, nil, func(context.Context, *bulkHandle, []byte) error { }, 0, nil, nil, 0, nil, nil, func(context.Context, *bulkHandle, []byte) error {
return nil return nil
}, func(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) { }, func(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) {
bulk.markPeerClosed() bulk.markPeerClosed()
<-ctx.Done() <-ctx.Done()
return 0, ctx.Err() return 0, ctx.Err()
}, nil) }, nil)
_, err := bulk.Write([]byte("abcdefgh")) _, err := bulk.Write([]byte("abcdefgh"))
if !errors.Is(err, io.ErrClosedPipe) { if err != nil && !errors.Is(err, io.ErrClosedPipe) {
t.Fatalf("bulk Write error = %v, want %v", err, io.ErrClosedPipe) t.Fatalf("bulk Write error = %v, want nil or %v", err, io.ErrClosedPipe)
}
if err := bulk.waitPendingAsyncWrites(context.Background()); err != nil && !errors.Is(err, io.ErrClosedPipe) {
t.Fatalf("bulk waitPendingAsyncWrites error = %v, want nil or %v", err, io.ErrClosedPipe)
} }
} }
File diff suppressed because it is too large Load Diff
+122 -63
View File
@@ -10,74 +10,118 @@ import (
) )
type ClientCommon struct { type ClientCommon struct {
alive atomic.Value alive atomic.Value
status Status status Status
byeFromServer bool byeFromServer bool
conn net.Conn conn net.Conn
mu sync.Mutex mu sync.Mutex
msgID uint64 msgID uint64
peerIdentity string peerIdentity string
sessionEpoch uint64 sessionEpoch uint64
sessionOwnerState atomic.Int32 sessionOwnerState atomic.Int32
sessionRuntime atomic.Pointer[clientSessionRuntime] sessionRuntime atomic.Pointer[clientSessionRuntime]
connectSource atomic.Pointer[clientConnectSource] connectSource atomic.Pointer[clientConnectSource]
queue *stario.StarQueue queue *stario.StarQueue
stopFn context.CancelFunc stopFn context.CancelFunc
stopCtx context.Context stopCtx context.Context
parallelNum int parallelNum int
maxReadTimeout time.Duration maxReadTimeout time.Duration
maxWriteTimeout time.Duration maxWriteTimeout time.Duration
keyExchangeFn func(c Client) error keyExchangeFn func(c Client) error
linkFns map[string]func(message *Message) linkFns map[string]func(message *Message)
defaultFns func(message *Message) defaultFns func(message *Message)
msgEn func([]byte, []byte) []byte msgEn func([]byte, []byte) []byte
msgDe func([]byte, []byte) []byte msgDe func([]byte, []byte) []byte
fastStreamEncode transportFastStreamEncoder fastStreamEncode transportFastStreamEncoder
fastBulkEncode transportFastBulkEncoder fastBulkEncode transportFastBulkEncoder
fastPlainEncode transportFastPlainEncoder fastPlainEncode transportFastPlainEncoder
handshakeRsaPubKey []byte modernPSKRuntime *modernPSKCodecRuntime
SecretKey []byte handshakeRsaPubKey []byte
noFinSyncMsgMaxKeepSeconds int SecretKey []byte
lastHeartbeat int64 transportProtection atomic.Pointer[transportProtectionProfile]
heartbeatPeriod time.Duration peerAttachSecurity atomic.Pointer[peerAttachSecurityState]
wg stario.WaitGroup securityBootstrap transportProtectionProfile
netType NetType securitySteady transportProtectionProfile
showError bool securitySteadyNegotiated transportProtectionProfile
skipKeyExchange bool securityAuthMode AuthMode
useHeartBeat bool securityProtectionMode ProtectionMode
sequenceDe func([]byte) (interface{}, error) securityRequireForwardSecrecy bool
sequenceEn func(interface{}) ([]byte, error) securityConfigured bool
logicalSession *logicalSessionState peerAttachAuthenticated bool
onFileEvent func(FileEvent) peerAttachAuthFallback bool
fileEventObserver func(FileEvent) peerAttachAt int64
fileTransferCfg fileTransferConfig noFinSyncMsgMaxKeepSeconds int
signalReliableCfg signalReliabilityConfig lastHeartbeat int64
streamRuntime *streamRuntime heartbeatPeriod time.Duration
recordRuntime *recordRuntime wg stario.WaitGroup
bulkRuntime *bulkRuntime netType NetType
connectionRetryState *connectionRetryState showError bool
securityReadyCheck bool skipKeyExchange bool
debugMode bool useHeartBeat bool
sequenceDe func([]byte) (interface{}, error)
sequenceEn func(interface{}) ([]byte, error)
logicalSession *logicalSessionState
onFileEvent func(FileEvent)
fileEventObserver func(FileEvent)
fileTransferCfg fileTransferConfig
signalReliableCfg signalReliabilityConfig
streamRuntime *streamRuntime
recordRuntime *recordRuntime
bulkRuntime *bulkRuntime
bulkDefaultOpenMode BulkOpenMode
bulkNetworkProfile BulkNetworkProfile
bulkOpenTuning BulkOpenTuning
bulkDedicatedAttachLimit int
bulkDedicatedAttachSem chan struct{}
bulkDedicatedAttachRetry int
bulkDedicatedAttachBackoff time.Duration
bulkDedicatedDialTimeout time.Duration
bulkDedicatedHelloTimeout time.Duration
bulkDedicatedActiveLimit int
bulkDedicatedActive atomic.Int32
bulkDedicatedActiveWait chan struct{}
bulkDedicatedLaneLimit int
bulkDedicatedSidecarMu sync.Mutex
bulkDedicatedLanes map[uint32]*bulkDedicatedLane
bulkDedicatedNextLaneID uint32
bulkAttachAttemptCount atomic.Int64
bulkAttachRetryCount atomic.Int64
bulkAttachSuccessCount atomic.Int64
bulkAttachFallbackCount atomic.Int64
connectionRetryState *connectionRetryState
securityReadyCheck bool
debugMode bool
} }
func NewClient() Client { func NewClient() Client {
transport := defaultModernPSKTransportBundle() transport := defaultModernPSKTransportBundle()
var client = ClientCommon{ var client = ClientCommon{
maxReadTimeout: 0, maxReadTimeout: 0,
maxWriteTimeout: 0, maxWriteTimeout: 0,
peerIdentity: newClientPeerIdentity(), peerIdentity: newClientPeerIdentity(),
sequenceEn: encode, sequenceEn: encode,
sequenceDe: Decode, sequenceDe: Decode,
keyExchangeFn: aesRsaHello, keyExchangeFn: aesRsaHello,
SecretKey: nil, SecretKey: nil,
handshakeRsaPubKey: defaultRsaPubKey, handshakeRsaPubKey: defaultRsaPubKey,
msgEn: transport.msgEn, msgEn: transport.msgEn,
msgDe: transport.msgDe, msgDe: transport.msgDe,
fastStreamEncode: transport.fastStreamEncode, fastStreamEncode: transport.fastStreamEncode,
fastBulkEncode: transport.fastBulkEncode, fastBulkEncode: transport.fastBulkEncode,
fastPlainEncode: transport.fastPlainEncode, fastPlainEncode: transport.fastPlainEncode,
skipKeyExchange: true, skipKeyExchange: true,
securityReadyCheck: true, securityReadyCheck: true,
bulkDefaultOpenMode: BulkOpenModeShared,
bulkNetworkProfile: BulkNetworkProfileDefault,
bulkOpenTuning: defaultBulkOpenTuning(),
bulkDedicatedAttachLimit: defaultBulkDedicatedAttachLimit,
bulkDedicatedAttachRetry: defaultBulkDedicatedAttachRetry,
bulkDedicatedAttachBackoff: defaultBulkDedicatedAttachBackoff,
bulkDedicatedDialTimeout: defaultBulkDedicatedDialTimeout,
bulkDedicatedHelloTimeout: defaultBulkDedicatedHelloTimeout,
bulkDedicatedActiveLimit: defaultBulkDedicatedActiveLimit,
bulkDedicatedActiveWait: make(chan struct{}),
bulkDedicatedLaneLimit: defaultBulkDedicatedLaneLimit,
} }
client.alive.Store(false) client.alive.Store(false)
client.useHeartBeat = true client.useHeartBeat = true
@@ -93,13 +137,28 @@ func NewClient() Client {
client.streamRuntime = newStreamRuntime("cstrm") client.streamRuntime = newStreamRuntime("cstrm")
client.recordRuntime = newRecordRuntime() client.recordRuntime = newRecordRuntime()
client.bulkRuntime = newBulkRuntime("cblk") client.bulkRuntime = newBulkRuntime("cblk")
client.bulkDedicatedLanes = make(map[uint32]*bulkDedicatedLane)
if client.bulkDedicatedAttachLimit > 0 {
client.bulkDedicatedAttachSem = make(chan struct{}, client.bulkDedicatedAttachLimit)
}
client.connectionRetryState = newConnectionRetryState() client.connectionRetryState = newConnectionRetryState()
client.onFileEvent = normalizeFileEventCallback(nil) client.onFileEvent = normalizeFileEventCallback(nil)
client.fileEventObserver = normalizeFileEventCallback(nil) client.fileEventObserver = normalizeFileEventCallback(nil)
client.stopCtx, client.stopFn = context.WithCancel(context.Background()) client.stopCtx, client.stopFn = context.WithCancel(context.Background())
client.sessionRuntime.Store(newClientSessionRuntimeBase(client.stopCtx, client.stopFn)) client.sessionRuntime.Store(newClientSessionRuntimeBase(client.stopCtx, client.stopFn))
client.setClientTransportProtectionProfile(defaultTransportProtectionProfile())
client.peerAttachSecurity.Store(defaultPeerAttachSecurityState())
bindClientStreamControl(&client) bindClientStreamControl(&client)
bindClientBulkControl(&client) bindClientBulkControl(&client)
client.getTransferState().setBuiltinHandler(client.builtinFileTransferHandler) client.getTransferState().setBuiltinHandler(client.builtinFileTransferHandler)
return &client return &client
} }
func (c *ClientCommon) maxWriteTimeoutSnapshot() time.Duration {
if c == nil {
return 0
}
c.mu.Lock()
defer c.mu.Unlock()
return c.maxWriteTimeout
}
+160 -15
View File
@@ -1,6 +1,9 @@
package notify package notify
import "context" import (
"context"
"errors"
)
func (c *ClientCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) { func (c *ClientCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) {
runtime := c.getBulkRuntime() runtime := c.getBulkRuntime()
@@ -10,10 +13,67 @@ func (c *ClientCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) {
runtime.setHandler(fn) runtime.setHandler(fn)
} }
func (c *ClientCommon) OpenSharedBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) {
opt.Mode = BulkOpenModeShared
opt.Dedicated = false
return c.OpenBulk(ctx, opt)
}
func (c *ClientCommon) OpenDedicatedBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) {
opt.Mode = BulkOpenModeDedicated
opt.Dedicated = true
return c.OpenBulk(ctx, opt)
}
func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) { func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) {
if normalizeBulkOpenMode(opt.Mode) == BulkOpenModeDefault && !opt.Dedicated {
opt.Mode = c.bulkDefaultOpenModeSnapshot()
}
opt = normalizeBulkOpenOptions(opt)
switch opt.Mode {
case BulkOpenModeDedicated:
opt.Dedicated = true
return c.openBulkWithDedicatedMode(ctx, opt)
case BulkOpenModeAuto:
// Auto mode prefers dedicated path and falls back to shared if dedicated fails.
if err := clientDedicatedBulkSupportError(c); err == nil {
dedicatedOpt := opt
dedicatedOpt.Mode = BulkOpenModeDedicated
dedicatedOpt.Dedicated = true
bulk, dedicatedErr := c.openBulkWithDedicatedMode(ctx, dedicatedOpt)
if dedicatedErr == nil {
return bulk, nil
}
sharedOpt := opt
sharedOpt.Mode = BulkOpenModeShared
sharedOpt.Dedicated = false
sharedBulk, sharedErr := c.openBulkWithDedicatedMode(ctx, sharedOpt)
if sharedErr == nil {
c.bulkAttachFallbackCount.Add(1)
return sharedBulk, nil
}
return nil, errors.Join(dedicatedErr, sharedErr)
}
opt.Mode = BulkOpenModeShared
opt.Dedicated = false
c.bulkAttachFallbackCount.Add(1)
return c.openBulkWithDedicatedMode(ctx, opt)
case BulkOpenModeShared, BulkOpenModeDefault:
opt.Mode = BulkOpenModeShared
opt.Dedicated = false
return c.openBulkWithDedicatedMode(ctx, opt)
default:
opt.Mode = BulkOpenModeShared
opt.Dedicated = false
return c.openBulkWithDedicatedMode(ctx, opt)
}
}
func (c *ClientCommon) openBulkWithDedicatedMode(ctx context.Context, opt BulkOpenOptions) (Bulk, error) {
if c == nil { if c == nil {
return nil, errBulkClientNil return nil, errBulkClientNil
} }
opt = applyBulkOpenTuningDefaults(opt, c.bulkOpenTuningSnapshot())
runtime := c.getBulkRuntime() runtime := c.getBulkRuntime()
if runtime == nil { if runtime == nil {
return nil, errBulkRuntimeNil return nil, errBulkRuntimeNil
@@ -33,6 +93,66 @@ func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk,
if _, exists := runtime.lookup(clientFileScope(), req.BulkID); exists { if _, exists := runtime.lookup(clientFileScope(), req.BulkID); exists {
return nil, errBulkAlreadyExists return nil, errBulkAlreadyExists
} }
if req.Dedicated {
req.DedicatedLaneID = c.reserveBulkDedicatedLane()
if req.DataID == 0 {
req.DataID = runtime.nextDataID()
}
if req.AttachToken == "" {
req.AttachToken = newBulkAttachToken()
}
bulk := newBulkHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, 0, clientBulkCloseSender(c), clientBulkResetSender(c), clientBulkDataSender(c, c.currentClientSessionEpoch()), clientBulkWriteSender(c, c.currentClientSessionEpoch()), clientBulkReleaseSender(c))
bulk.setClientSnapshotOwner(c)
bulk.markAcceptHandled()
if err := runtime.register(clientFileScope(), bulk); err != nil {
c.releaseBulkDedicatedLane(req.DedicatedLaneID)
return nil, err
}
resp, err := sendBulkOpenClient(ctx, c, req)
if err != nil {
bulk.markReset(err)
return nil, err
}
if resp.DataID != 0 && resp.DataID != req.DataID {
err = errBulkAlreadyExists
_, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
Error: "bulk dedicated data id mismatch",
})
bulk.markReset(err)
return nil, err
}
if resp.TransportGeneration != 0 {
bulk.transportGeneration = resp.TransportGeneration
}
if resp.FastPathVersion != 0 {
bulk.fastPathVersion = normalizeBulkFastPathVersion(resp.FastPathVersion)
}
if resp.AttachToken != "" {
req.AttachToken = resp.AttachToken
bulk.setDedicatedAttachToken(resp.AttachToken)
}
if err := c.attachDedicatedBulkSidecar(ctx, bulk); err != nil {
_, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
Error: err.Error(),
})
bulk.markReset(err)
return nil, err
}
if err := bulk.waitAcceptReady(ctx); err != nil {
_, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
Error: err.Error(),
})
bulk.markReset(err)
return nil, err
}
return bulk, nil
}
resp, err := sendBulkOpenClient(ctx, c, req) resp, err := sendBulkOpenClient(ctx, c, req)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -40,6 +160,9 @@ func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk,
if resp.DataID != 0 { if resp.DataID != 0 {
req.DataID = resp.DataID req.DataID = resp.DataID
} }
if resp.FastPathVersion != 0 {
req.FastPathVersion = resp.FastPathVersion
}
req.Dedicated = resp.Dedicated req.Dedicated = resp.Dedicated
if resp.AttachToken != "" { if resp.AttachToken != "" {
req.AttachToken = resp.AttachToken req.AttachToken = resp.AttachToken
@@ -49,7 +172,9 @@ func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk,
} }
bulk := newBulkHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientBulkCloseSender(c), clientBulkResetSender(c), clientBulkDataSender(c, c.currentClientSessionEpoch()), clientBulkWriteSender(c, c.currentClientSessionEpoch()), clientBulkReleaseSender(c)) bulk := newBulkHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientBulkCloseSender(c), clientBulkResetSender(c), clientBulkDataSender(c, c.currentClientSessionEpoch()), clientBulkWriteSender(c, c.currentClientSessionEpoch()), clientBulkReleaseSender(c))
bulk.setClientSnapshotOwner(c) bulk.setClientSnapshotOwner(c)
bulk.markAcceptHandled()
if err := runtime.register(clientFileScope(), bulk); err != nil { if err := runtime.register(clientFileScope(), bulk); err != nil {
c.releaseBulkDedicatedLane(req.DedicatedLaneID)
_, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{ _, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{
BulkID: req.BulkID, BulkID: req.BulkID,
DataID: req.DataID, DataID: req.DataID,
@@ -78,15 +203,16 @@ func clientBulkRequest(runtime *bulkRuntime, opt BulkOpenOptions) BulkOpenReques
id = runtime.nextID() id = runtime.nextID()
} }
return normalizeBulkOpenRequest(BulkOpenRequest{ return normalizeBulkOpenRequest(BulkOpenRequest{
BulkID: id, BulkID: id,
Range: opt.Range, FastPathVersion: bulkFastPathVersionCurrent,
Metadata: cloneBulkMetadata(opt.Metadata), Range: opt.Range,
ReadTimeout: opt.ReadTimeout, Metadata: cloneBulkMetadata(opt.Metadata),
WriteTimeout: opt.WriteTimeout, ReadTimeout: opt.ReadTimeout,
Dedicated: opt.Dedicated, WriteTimeout: opt.WriteTimeout,
ChunkSize: opt.ChunkSize, Dedicated: opt.Dedicated,
WindowBytes: opt.WindowBytes, ChunkSize: opt.ChunkSize,
MaxInFlight: opt.MaxInFlight, WindowBytes: opt.WindowBytes,
MaxInFlight: opt.MaxInFlight,
}) })
} }
@@ -148,12 +274,12 @@ func clientBulkDataSender(c *ClientCommon, epoch uint64) bulkDataSender {
if dataID == 0 { if dataID == 0 {
return errBulkDataPathNotReady return errBulkDataPathNotReady
} }
return c.sendFastBulkData(ctx, dataID, bulk.nextOutboundDataSeq(), chunk) return c.sendFastBulkData(ctx, dataID, bulk.nextOutboundDataSeq(), chunk, bulk.fastPathVersionSnapshot())
} }
} }
func clientBulkWriteSender(c *ClientCommon, epoch uint64) bulkWriteSender { func clientBulkWriteSender(c *ClientCommon, epoch uint64) bulkWriteSender {
return func(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) { return func(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) {
if c == nil { if c == nil {
return 0, errBulkClientNil return 0, errBulkClientNil
} }
@@ -168,12 +294,19 @@ func clientBulkWriteSender(c *ClientCommon, epoch uint64) bulkWriteSender {
if err := bulk.waitDedicatedReady(ctx); err != nil { if err := bulk.waitDedicatedReady(ctx); err != nil {
return 0, err return 0, err
} }
return c.sendDedicatedBulkWrite(ctx, bulk, payload) return c.sendDedicatedBulkWrite(ctx, bulk, startSeq, payload, payloadOwned)
} }
if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) { if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) {
return 0, errTransportDetached return 0, errTransportDetached
} }
return 0, nil if bulk == nil {
return 0, errBulkRuntimeNil
}
dataID := bulk.dataIDSnapshot()
if dataID == 0 {
return 0, errBulkDataPathNotReady
}
return c.sendFastBulkWrite(ctx, dataID, startSeq, bulk.chunkSize, bulk.fastPathVersionSnapshot(), payload, payloadOwned)
} }
} }
@@ -185,8 +318,20 @@ func clientBulkReleaseSender(c *ClientCommon) bulkReleaseSender {
if bytes <= 0 && chunks <= 0 { if bytes <= 0 && chunks <= 0 {
return nil return nil
} }
ctx, cancel, err := bulk.newWriteContext(bulk.Context(), bulk.writeTimeout)
if err != nil {
return err
}
defer cancel()
if bulk.Dedicated() { if bulk.Dedicated() {
return c.sendDedicatedBulkRelease(context.Background(), bulk, bytes, chunks) return c.sendDedicatedBulkRelease(ctx, bulk, bytes, chunks)
}
if bulk.fastPathVersionSnapshot() >= bulkFastPathVersionV2 {
payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks)
if err != nil {
return err
}
return c.sendFastBulkControl(ctx, bulkFastPayloadTypeRelease, 0, bulk.dataIDSnapshot(), 0, bulk.fastPathVersionSnapshot(), payload)
} }
return sendBulkReleaseClient(c, BulkReleaseRequest{ return sendBulkReleaseClient(c, BulkReleaseRequest{
BulkID: bulk.ID(), BulkID: bulk.ID(),
+364
View File
@@ -0,0 +1,364 @@
package notify
import "time"
func defaultBulkOpenTuning() BulkOpenTuning {
return normalizeBulkOpenTuning(BulkOpenTuning{
ChunkSize: defaultBulkChunkSize,
WindowBytes: defaultBulkOpenWindowBytes,
MaxInFlight: defaultBulkOpenMaxInFlight,
})
}
func normalizeBulkOpenTuning(tuning BulkOpenTuning) BulkOpenTuning {
if tuning.ChunkSize <= 0 {
tuning.ChunkSize = defaultBulkChunkSize
}
if tuning.WindowBytes <= 0 {
tuning.WindowBytes = defaultBulkOpenWindowBytes
}
if tuning.WindowBytes < tuning.ChunkSize {
tuning.WindowBytes = tuning.ChunkSize
}
if tuning.MaxInFlight <= 0 {
tuning.MaxInFlight = defaultBulkOpenMaxInFlight
}
return tuning
}
func applyBulkOpenTuningDefaults(opt BulkOpenOptions, tuning BulkOpenTuning) BulkOpenOptions {
tuning = normalizeBulkOpenTuning(tuning)
if opt.ChunkSize <= 0 {
opt.ChunkSize = tuning.ChunkSize
}
if opt.WindowBytes <= 0 {
opt.WindowBytes = tuning.WindowBytes
}
if opt.MaxInFlight <= 0 {
opt.MaxInFlight = tuning.MaxInFlight
}
return opt
}
func normalizeBulkNetworkProfile(profile BulkNetworkProfile) BulkNetworkProfile {
switch profile {
case BulkNetworkProfileDefault, BulkNetworkProfileLAN, BulkNetworkProfileWAN, BulkNetworkProfileConstrained:
return profile
default:
return BulkNetworkProfileDefault
}
}
func bulkNetworkProfileName(profile BulkNetworkProfile) string {
switch normalizeBulkNetworkProfile(profile) {
case BulkNetworkProfileLAN:
return "lan"
case BulkNetworkProfileWAN:
return "wan"
case BulkNetworkProfileConstrained:
return "constrained"
case BulkNetworkProfileDefault:
fallthrough
default:
return "default"
}
}
func bulkNetworkProfilePreset(profile BulkNetworkProfile) (BulkOpenMode, BulkDedicatedAttachConfig, BulkOpenTuning) {
switch normalizeBulkNetworkProfile(profile) {
case BulkNetworkProfileLAN:
return BulkOpenModeAuto, BulkDedicatedAttachConfig{
AttachLimit: defaultBulkDedicatedAttachLimit,
ActiveLimit: defaultBulkDedicatedActiveLimit,
LaneLimit: 4,
Retry: defaultBulkDedicatedAttachRetry,
Backoff: defaultBulkDedicatedAttachBackoff,
DialTimeout: defaultBulkDedicatedDialTimeout,
HelloTimeout: defaultBulkDedicatedHelloTimeout,
}, BulkOpenTuning{
ChunkSize: 1024 * 1024,
WindowBytes: 32 * 1024 * 1024,
MaxInFlight: 64,
}
case BulkNetworkProfileWAN:
return BulkOpenModeAuto, BulkDedicatedAttachConfig{
AttachLimit: 2,
ActiveLimit: defaultBulkDedicatedActiveLimit,
LaneLimit: 2,
Retry: 4,
Backoff: 250 * time.Millisecond,
DialTimeout: 15 * time.Second,
HelloTimeout: 20 * time.Second,
}, BulkOpenTuning{
ChunkSize: 512 * 1024,
WindowBytes: 8 * 1024 * 1024,
MaxInFlight: 16,
}
case BulkNetworkProfileConstrained:
return BulkOpenModeShared, BulkDedicatedAttachConfig{
AttachLimit: 1,
ActiveLimit: 2,
LaneLimit: 1,
Retry: 5,
Backoff: 400 * time.Millisecond,
DialTimeout: 20 * time.Second,
HelloTimeout: 30 * time.Second,
}, BulkOpenTuning{
ChunkSize: 128 * 1024,
WindowBytes: 1 * 1024 * 1024,
MaxInFlight: 8,
}
default:
return BulkOpenModeShared, BulkDedicatedAttachConfig{
AttachLimit: defaultBulkDedicatedAttachLimit,
ActiveLimit: defaultBulkDedicatedActiveLimit,
LaneLimit: defaultBulkDedicatedLaneLimit,
Retry: defaultBulkDedicatedAttachRetry,
Backoff: defaultBulkDedicatedAttachBackoff,
DialTimeout: defaultBulkDedicatedDialTimeout,
HelloTimeout: defaultBulkDedicatedHelloTimeout,
}, defaultBulkOpenTuning()
}
}
func normalizeBulkDedicatedAttachConfig(cfg BulkDedicatedAttachConfig) BulkDedicatedAttachConfig {
if cfg.AttachLimit < 0 {
cfg.AttachLimit = 0
}
if cfg.ActiveLimit < 0 {
cfg.ActiveLimit = 0
}
if cfg.LaneLimit < 0 {
cfg.LaneLimit = 0
}
if cfg.Retry < 0 {
cfg.Retry = 0
}
if cfg.Backoff <= 0 {
cfg.Backoff = defaultBulkDedicatedAttachBackoff
}
if cfg.DialTimeout <= 0 {
cfg.DialTimeout = defaultBulkDedicatedDialTimeout
}
if cfg.HelloTimeout <= 0 {
cfg.HelloTimeout = defaultBulkDedicatedHelloTimeout
}
return cfg
}
func (c *ClientCommon) setBulkDedicatedAttachSemaphoreLocked(limit int) {
if limit <= 0 {
c.bulkDedicatedAttachSem = nil
return
}
c.bulkDedicatedAttachSem = make(chan struct{}, limit)
}
func (c *ClientCommon) applyBulkDedicatedAttachConfigLocked(cfg BulkDedicatedAttachConfig) {
cfg = normalizeBulkDedicatedAttachConfig(cfg)
c.bulkDedicatedAttachLimit = cfg.AttachLimit
c.setBulkDedicatedAttachSemaphoreLocked(cfg.AttachLimit)
c.bulkDedicatedActiveLimit = cfg.ActiveLimit
c.notifyBulkDedicatedActiveWaitersLocked()
c.bulkDedicatedLaneLimit = cfg.LaneLimit
c.bulkDedicatedAttachRetry = cfg.Retry
c.bulkDedicatedAttachBackoff = cfg.Backoff
c.bulkDedicatedDialTimeout = cfg.DialTimeout
c.bulkDedicatedHelloTimeout = cfg.HelloTimeout
}
func (c *ClientCommon) ensureBulkDedicatedActiveWaitLocked() chan struct{} {
if c == nil {
return nil
}
if c.bulkDedicatedActiveWait == nil {
c.bulkDedicatedActiveWait = make(chan struct{})
}
return c.bulkDedicatedActiveWait
}
func (c *ClientCommon) notifyBulkDedicatedActiveWaitersLocked() {
if c == nil {
return
}
waitCh := c.ensureBulkDedicatedActiveWaitLocked()
close(waitCh)
c.bulkDedicatedActiveWait = make(chan struct{})
}
func (c *ClientCommon) applyBulkOpenTuningLocked(tuning BulkOpenTuning) {
c.bulkOpenTuning = normalizeBulkOpenTuning(tuning)
}
func (c *ClientCommon) bulkDefaultOpenModeSnapshot() BulkOpenMode {
if c == nil {
return BulkOpenModeShared
}
c.mu.Lock()
defer c.mu.Unlock()
mode := normalizeBulkOpenMode(c.bulkDefaultOpenMode)
if mode == BulkOpenModeDefault {
mode = BulkOpenModeShared
}
return mode
}
func (c *ClientCommon) bulkDedicatedAttachSemaphoreSnapshot() chan struct{} {
if c == nil {
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
return c.bulkDedicatedAttachSem
}
func (c *ClientCommon) bulkDedicatedActiveLimitSnapshot() int {
if c == nil {
return 0
}
c.mu.Lock()
defer c.mu.Unlock()
if c.bulkDedicatedActiveLimit < 0 {
return 0
}
return c.bulkDedicatedActiveLimit
}
func (c *ClientCommon) bulkDedicatedActiveWaitSnapshot() chan struct{} {
if c == nil {
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
return c.ensureBulkDedicatedActiveWaitLocked()
}
func (c *ClientCommon) bulkDedicatedLaneLimitSnapshot() int {
if c == nil {
return defaultBulkDedicatedLaneLimit
}
c.mu.Lock()
defer c.mu.Unlock()
limit := c.bulkDedicatedLaneLimit
if limit < 0 {
return 0
}
if limit == 0 {
return 0
}
return limit
}
func (c *ClientCommon) SetBulkDefaultOpenMode(mode BulkOpenMode) {
if c == nil {
return
}
mode = normalizeBulkOpenMode(mode)
if mode == BulkOpenModeDefault {
mode = BulkOpenModeShared
}
c.mu.Lock()
c.bulkDefaultOpenMode = mode
c.mu.Unlock()
}
func (c *ClientCommon) BulkDefaultOpenMode() BulkOpenMode {
return c.bulkDefaultOpenModeSnapshot()
}
func (c *ClientCommon) bulkOpenTuningSnapshot() BulkOpenTuning {
if c == nil {
return defaultBulkOpenTuning()
}
c.mu.Lock()
defer c.mu.Unlock()
return normalizeBulkOpenTuning(c.bulkOpenTuning)
}
func (c *ClientCommon) SetBulkOpenTuning(tuning BulkOpenTuning) {
if c == nil {
return
}
c.mu.Lock()
c.applyBulkOpenTuningLocked(tuning)
c.mu.Unlock()
}
func (c *ClientCommon) BulkOpenTuning() BulkOpenTuning {
return c.bulkOpenTuningSnapshot()
}
func (c *ClientCommon) SetBulkDedicatedAttachConfig(cfg BulkDedicatedAttachConfig) {
if c == nil {
return
}
c.mu.Lock()
c.applyBulkDedicatedAttachConfigLocked(cfg)
c.mu.Unlock()
}
func (c *ClientCommon) BulkDedicatedAttachConfig() BulkDedicatedAttachConfig {
if c == nil {
return normalizeBulkDedicatedAttachConfig(BulkDedicatedAttachConfig{})
}
c.mu.Lock()
defer c.mu.Unlock()
return normalizeBulkDedicatedAttachConfig(BulkDedicatedAttachConfig{
AttachLimit: c.bulkDedicatedAttachLimit,
ActiveLimit: c.bulkDedicatedActiveLimit,
LaneLimit: c.bulkDedicatedLaneLimit,
Retry: c.bulkDedicatedAttachRetry,
Backoff: c.bulkDedicatedAttachBackoff,
DialTimeout: c.bulkDedicatedDialTimeout,
HelloTimeout: c.bulkDedicatedHelloTimeout,
})
}
func (c *ClientCommon) SetBulkNetworkProfile(profile BulkNetworkProfile) {
if c == nil {
return
}
profile = normalizeBulkNetworkProfile(profile)
mode, cfg, tuning := bulkNetworkProfilePreset(profile)
c.mu.Lock()
c.bulkNetworkProfile = profile
c.bulkDefaultOpenMode = mode
c.applyBulkDedicatedAttachConfigLocked(cfg)
c.applyBulkOpenTuningLocked(tuning)
c.mu.Unlock()
}
func (c *ClientCommon) BulkNetworkProfile() BulkNetworkProfile {
if c == nil {
return BulkNetworkProfileDefault
}
c.mu.Lock()
defer c.mu.Unlock()
return normalizeBulkNetworkProfile(c.bulkNetworkProfile)
}
func (s *ServerCommon) applyBulkOpenTuningLocked(tuning BulkOpenTuning) {
s.bulkOpenTuning = normalizeBulkOpenTuning(tuning)
}
func (s *ServerCommon) bulkOpenTuningSnapshot() BulkOpenTuning {
if s == nil {
return defaultBulkOpenTuning()
}
s.mu.RLock()
defer s.mu.RUnlock()
return normalizeBulkOpenTuning(s.bulkOpenTuning)
}
func (s *ServerCommon) SetBulkOpenTuning(tuning BulkOpenTuning) {
if s == nil {
return
}
s.mu.Lock()
s.applyBulkOpenTuningLocked(tuning)
s.mu.Unlock()
}
func (s *ServerCommon) BulkOpenTuning() BulkOpenTuning {
return s.bulkOpenTuningSnapshot()
}
+33 -12
View File
@@ -65,30 +65,40 @@ func (c *ClientCommon) RecoverTransferSnapshots(ctx context.Context) error {
} }
func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte { func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte {
return c.msgEn return c.clientTransportProtectionSnapshot().msgEn
} }
// Deprecated: SetMsgEn overrides the transport codec directly. // Deprecated: SetMsgEn overrides the transport codec directly.
// Prefer UseModernPSKClient or UseLegacySecurityClient. // Prefer UseModernPSKClient or UseLegacySecurityClient.
func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) { func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) {
c.msgEn = fn profile := c.clientTransportProtectionSnapshot()
c.fastStreamEncode = nil profile.mode = ProtectionManaged
c.fastBulkEncode = nil profile.msgEn = fn
c.fastPlainEncode = nil profile.fastStreamEncode = nil
profile.fastBulkEncode = nil
profile.fastPlainEncode = nil
profile.runtime = nil
c.setClientTransportProtectionProfile(profile)
c.clearClientSecurityProfiles()
c.securityReadyCheck = false c.securityReadyCheck = false
} }
func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte { func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte {
return c.msgDe return c.clientTransportProtectionSnapshot().msgDe
} }
// Deprecated: SetMsgDe overrides the transport codec directly. // Deprecated: SetMsgDe overrides the transport codec directly.
// Prefer UseModernPSKClient or UseLegacySecurityClient. // Prefer UseModernPSKClient or UseLegacySecurityClient.
func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) { func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) {
c.msgDe = fn profile := c.clientTransportProtectionSnapshot()
c.fastStreamEncode = nil profile.mode = ProtectionManaged
c.fastBulkEncode = nil profile.msgDe = fn
c.fastPlainEncode = nil profile.fastStreamEncode = nil
profile.fastBulkEncode = nil
profile.fastPlainEncode = nil
profile.runtime = nil
c.setClientTransportProtectionProfile(profile)
c.clearClientSecurityProfiles()
c.securityReadyCheck = false c.securityReadyCheck = false
} }
@@ -101,13 +111,24 @@ func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) {
} }
func (c *ClientCommon) GetSecretKey() []byte { func (c *ClientCommon) GetSecretKey() []byte {
return c.SecretKey return c.clientTransportProtectionSnapshot().secretKey
} }
// Deprecated: SetSecretKey injects a raw transport key directly. // Deprecated: SetSecretKey injects a raw transport key directly.
// Prefer UseModernPSKClient or UseLegacySecurityClient. // Prefer UseModernPSKClient or UseLegacySecurityClient.
func (c *ClientCommon) SetSecretKey(key []byte) { func (c *ClientCommon) SetSecretKey(key []byte) {
c.SecretKey = key profile := c.clientTransportProtectionSnapshot()
profile.mode = ProtectionManaged
profile.secretKey = cloneTransportProtectionKey(key)
if len(key) == 0 {
profile.runtime = nil
} else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil {
profile.runtime = runtime
} else {
profile.runtime = nil
}
c.setClientTransportProtectionProfile(profile)
c.clearClientSecurityProfiles()
c.securityReadyCheck = len(key) == 0 c.securityReadyCheck = len(key) == 0
c.skipKeyExchange = true c.skipKeyExchange = true
} }
+17
View File
@@ -72,6 +72,23 @@ func (c *ClientConn) readTUMessageLoop(rt *clientConnSessionRuntime) {
conn := rt.tuConn conn := rt.tuConn
generation := rt.transportGeneration generation := rt.transportGeneration
defer closeClientConnSessionRuntimeTransportDone(rt) defer closeClientConnSessionRuntimeTransportDone(rt)
if conn != nil && !isPacketTransportConn(conn) {
reader := newTransportFrameReader(conn, nil)
for {
select {
case <-sessionStopChan(stopCtx):
if c.shouldCloseClientConnTransportOnStop(conn) {
_ = conn.Close()
}
return
default:
}
payload, release, err := c.readTUTransportPayloadPooled(conn, reader)
if !c.handleTUTransportPayloadReadResultWithSessionPooled(stopCtx, conn, generation, payload, release, err) {
return
}
}
}
buf := streamReadBuffer() buf := streamReadBuffer()
for { for {
select { select {
+118 -10
View File
@@ -6,16 +6,26 @@ import (
) )
type clientConnAttachmentState struct { type clientConnAttachmentState struct {
maxReadTimeout time.Duration maxReadTimeout time.Duration
maxWriteTimeout time.Duration maxWriteTimeout time.Duration
msgEn func([]byte, []byte) []byte authMode AuthMode
msgDe func([]byte, []byte) []byte protectionMode ProtectionMode
fastStreamEncode transportFastStreamEncoder msgEn func([]byte, []byte) []byte
fastBulkEncode transportFastBulkEncoder msgDe func([]byte, []byte) []byte
fastPlainEncode transportFastPlainEncoder fastStreamEncode transportFastStreamEncoder
handshakeRsaKey []byte fastBulkEncode transportFastBulkEncoder
secretKey []byte fastPlainEncode transportFastPlainEncoder
lastHeartBeat int64 modernPSKRuntime *modernPSKCodecRuntime
handshakeRsaKey []byte
secretKey []byte
keyMode string
sessionID []byte
forwardSecrecy bool
forwardSecrecyFallback bool
peerAttached bool
peerAttachFallback bool
peerAttachAt int64
lastHeartBeat int64
} }
func cloneClientConnAttachmentState(src *clientConnAttachmentState) *clientConnAttachmentState { func cloneClientConnAttachmentState(src *clientConnAttachmentState) *clientConnAttachmentState {
@@ -25,6 +35,7 @@ func cloneClientConnAttachmentState(src *clientConnAttachmentState) *clientConnA
cloned := *src cloned := *src
cloned.handshakeRsaKey = cloneClientConnAttachmentBytes(src.handshakeRsaKey) cloned.handshakeRsaKey = cloneClientConnAttachmentBytes(src.handshakeRsaKey)
cloned.secretKey = cloneClientConnAttachmentBytes(src.secretKey) cloned.secretKey = cloneClientConnAttachmentBytes(src.secretKey)
cloned.sessionID = cloneClientConnAttachmentBytes(src.sessionID)
return &cloned return &cloned
} }
@@ -152,8 +163,10 @@ func (c *ClientConn) applyClientConnAttachmentProfile(maxReadTimeout time.Durati
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.maxReadTimeout = maxReadTimeout state.maxReadTimeout = maxReadTimeout
state.maxWriteTimeout = maxWriteTimeout state.maxWriteTimeout = maxWriteTimeout
state.protectionMode = ProtectionManaged
state.msgEn = msgEn state.msgEn = msgEn
state.msgDe = msgDe state.msgDe = msgDe
state.modernPSKRuntime = nil
state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey) state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey)
state.secretKey = cloneClientConnAttachmentBytes(secretKey) state.secretKey = cloneClientConnAttachmentBytes(secretKey)
}) })
@@ -204,10 +217,12 @@ func (c *ClientConn) clientConnMsgEnSnapshot() func([]byte, []byte) []byte {
func (c *ClientConn) setClientConnMsgEn(fn func([]byte, []byte) []byte) { func (c *ClientConn) setClientConnMsgEn(fn func([]byte, []byte) []byte) {
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.protectionMode = ProtectionManaged
state.msgEn = fn state.msgEn = fn
state.fastStreamEncode = nil state.fastStreamEncode = nil
state.fastBulkEncode = nil state.fastBulkEncode = nil
state.fastPlainEncode = nil state.fastPlainEncode = nil
state.modernPSKRuntime = nil
}) })
} }
@@ -220,10 +235,12 @@ func (c *ClientConn) clientConnMsgDeSnapshot() func([]byte, []byte) []byte {
func (c *ClientConn) setClientConnMsgDe(fn func([]byte, []byte) []byte) { func (c *ClientConn) setClientConnMsgDe(fn func([]byte, []byte) []byte) {
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.protectionMode = ProtectionManaged
state.msgDe = fn state.msgDe = fn
state.fastStreamEncode = nil state.fastStreamEncode = nil
state.fastBulkEncode = nil state.fastBulkEncode = nil
state.fastPlainEncode = nil state.fastPlainEncode = nil
state.modernPSKRuntime = nil
}) })
} }
@@ -286,6 +303,97 @@ func (c *ClientConn) setClientConnSecretKey(key []byte) {
}) })
} }
func (c *LogicalConn) attachmentStateRaw() *clientConnAttachmentState {
if c == nil {
return nil
}
if state := c.attachment.Load(); state != nil {
if client := c.compatClientConn(); client != nil {
client.attachment.Store(state)
}
return state
}
if client := c.compatClientConn(); client != nil {
if state := client.attachment.Load(); state != nil {
if c.attachment.CompareAndSwap(nil, state) {
client.attachment.Store(state)
return state
}
return c.attachmentStateRaw()
}
}
return nil
}
func (c *LogicalConn) modernPSKRuntimeSnapshot() *modernPSKCodecRuntime {
if state := c.attachmentStateRaw(); state != nil {
return state.modernPSKRuntime
}
return nil
}
func (c *LogicalConn) protectionModeSnapshot() ProtectionMode {
if state := c.attachmentStateRaw(); state != nil {
return state.protectionMode
}
return ProtectionManaged
}
func (c *LogicalConn) authModeSnapshot() AuthMode {
if state := c.attachmentStateRaw(); state != nil {
return state.authMode
}
return AuthNone
}
func (c *LogicalConn) peerAttachAuthenticatedSnapshot() (bool, bool, time.Time) {
if state := c.attachmentStateRaw(); state != nil {
if state.peerAttachAt == 0 {
return state.peerAttached, state.peerAttachFallback, time.Time{}
}
return state.peerAttached, state.peerAttachFallback, time.Unix(0, state.peerAttachAt)
}
return false, false, time.Time{}
}
func (c *LogicalConn) markPeerAttachAuthenticated(authMode AuthMode, fallback bool, at time.Time) {
c.updateAttachmentState(func(state *clientConnAttachmentState) {
state.authMode = authMode
state.peerAttached = true
state.peerAttachFallback = fallback
state.peerAttachAt = at.UnixNano()
})
}
func (c *LogicalConn) setModernPSKRuntime(runtime *modernPSKCodecRuntime) {
c.updateAttachmentState(func(state *clientConnAttachmentState) {
state.modernPSKRuntime = runtime
})
}
func (c *ClientConn) clientConnAttachmentStateRaw() *clientConnAttachmentState {
if c == nil {
return nil
}
if logical := c.logicalView.Load(); logical != nil {
return logical.attachmentStateRaw()
}
return c.attachment.Load()
}
func (c *ClientConn) clientConnModernPSKRuntimeSnapshot() *modernPSKCodecRuntime {
if state := c.clientConnAttachmentStateRaw(); state != nil {
return state.modernPSKRuntime
}
return nil
}
func (c *ClientConn) setClientConnModernPSKRuntime(runtime *modernPSKCodecRuntime) {
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.modernPSKRuntime = runtime
})
}
func (c *ClientConn) clientConnLastHeartbeatUnixSnapshot() int64 { func (c *ClientConn) clientConnLastHeartbeatUnixSnapshot() int64 {
if c == nil { if c == nil {
return 0 return 0
+18
View File
@@ -332,6 +332,24 @@ func TestLogicalDetachTransportForTransferKeepsHandoffConnAlive(t *testing.T) {
} }
} }
func TestLogicalHandleTUTransportReadResultWithSessionDropsDataAfterTransportStop(t *testing.T) {
server := NewServer().(*ServerCommon)
left, right := net.Pipe()
defer right.Close()
stopCtx, stopFn := context.WithCancel(context.Background())
logical, _, _ := newRegisteredServerLogicalForTest(t, server, "logical-stop-read-drop", left, stopCtx, stopFn)
if logical == nil {
t.Fatal("logical should not be nil")
}
generation := logical.transportGenerationSnapshot()
stopFn()
if logical.handleTUTransportReadResultWithSession(stopCtx, left, generation, len([]byte("late-data")), []byte("late-data"), nil) {
t.Fatal("handleTUTransportReadResultWithSession should stop after transport stop")
}
}
func TestClientConnTransportBindingSnapshotUsesRuntimeBinding(t *testing.T) { func TestClientConnTransportBindingSnapshotUsesRuntimeBinding(t *testing.T) {
client := &ClientConn{} client := &ClientConn{}
left, right := net.Pipe() left, right := net.Pipe()
+220
View File
@@ -1,6 +1,7 @@
package notify package notify
import ( import (
"b612.me/stario"
"context" "context"
"net" "net"
"os" "os"
@@ -15,6 +16,10 @@ type serverInboundSourcePusher interface {
pushMessageSource([]byte, interface{}) pushMessageSource([]byte, interface{})
} }
type serverInboundSourceFastPusher interface {
pushTransportPayloadSourceFast([]byte, func(), interface{}) bool
}
func (c *LogicalConn) readTUMessage() { func (c *LogicalConn) readTUMessage() {
rt := c.clientConnSessionRuntimeSnapshot() rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil { if rt == nil {
@@ -37,6 +42,23 @@ func (c *LogicalConn) readTUMessageLoop(rt *clientConnSessionRuntime) {
conn := rt.tuConn conn := rt.tuConn
generation := rt.transportGeneration generation := rt.transportGeneration
defer closeClientConnSessionRuntimeTransportDone(rt) defer closeClientConnSessionRuntimeTransportDone(rt)
if conn != nil && !isPacketTransportConn(conn) {
reader := newTransportFrameReader(conn, nil)
for {
select {
case <-sessionStopChan(stopCtx):
if c.shouldCloseTransportOnStop(conn) {
_ = conn.Close()
}
return
default:
}
payload, release, err := c.readTUTransportPayloadPooled(conn, reader)
if !c.handleTUTransportPayloadReadResultWithSessionPooled(stopCtx, conn, generation, payload, release, err) {
return
}
}
}
buf := streamReadBuffer() buf := streamReadBuffer()
for { for {
select { select {
@@ -54,6 +76,55 @@ func (c *LogicalConn) readTUMessageLoop(rt *clientConnSessionRuntime) {
} }
} }
func (c *LogicalConn) readTUTransportPayloadPooled(conn net.Conn, reader *stario.FrameReader) ([]byte, func(), error) {
if reader == nil {
return nil, nil, net.ErrClosed
}
if conn == nil {
return nil, nil, net.ErrClosed
}
if timeout := c.clientConnMaxReadTimeoutSnapshot(); timeout > 0 {
_ = conn.SetReadDeadline(time.Now().Add(timeout))
}
return reader.NextPooled()
}
func (c *LogicalConn) handleTUTransportPayloadReadResultWithSessionPooled(stopCtx context.Context, conn net.Conn, generation uint64, payload []byte, release func(), err error) bool {
if transportReadShouldStop(stopCtx) || !c.ownsTransportRead(conn, generation) {
if release != nil {
release()
}
if c.shouldCloseTransportOnStop(conn) {
_ = conn.Close()
}
return false
}
if err == os.ErrDeadlineExceeded {
return true
}
if err != nil {
if release != nil {
release()
}
select {
case <-sessionStopChan(stopCtx):
if c.shouldCloseTransportOnStop(conn) {
_ = conn.Close()
}
return false
default:
}
if detacher, ok := c.Server().(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() {
detacher.detachLogicalSessionTransport(c, "read error", err)
return false
}
c.stopServerOwnedSession("read error", err)
return false
}
c.pushServerOwnedTransportPayload(payload, release, conn, generation)
return true
}
func (c *LogicalConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byte) (int, []byte, error) { func (c *LogicalConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byte) (int, []byte, error) {
if len(data) == 0 { if len(data) == 0 {
data = streamReadBuffer() data = streamReadBuffer()
@@ -69,6 +140,12 @@ func (c *LogicalConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []by
} }
func (c *LogicalConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool { func (c *LogicalConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool {
if transportReadShouldStop(stopCtx) || !c.ownsTransportRead(conn, generation) {
if c.shouldCloseTransportOnStop(conn) {
_ = conn.Close()
}
return false
}
if err == os.ErrDeadlineExceeded { if err == os.ErrDeadlineExceeded {
if num != 0 { if num != 0 {
c.pushServerOwnedTransportMessage(data[:num], conn, generation) c.pushServerOwnedTransportMessage(data[:num], conn, generation)
@@ -95,6 +172,30 @@ func (c *LogicalConn) handleTUTransportReadResultWithSession(stopCtx context.Con
return true return true
} }
func transportReadShouldStop(stopCtx context.Context) bool {
select {
case <-sessionStopChan(stopCtx):
return true
default:
return false
}
}
func (c *LogicalConn) ownsTransportRead(conn net.Conn, generation uint64) bool {
if c == nil {
return false
}
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil || !rt.transportAttached || rt.transportGeneration != generation {
return false
}
current := rt.tuConn
if rt.transport != nil && rt.transport.connSnapshot() != nil {
current = rt.transport.connSnapshot()
}
return current == conn
}
func (c *LogicalConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) { func (c *LogicalConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) {
if c == nil || len(data) == 0 { if c == nil || len(data) == 0 {
return return
@@ -110,6 +211,29 @@ func (c *LogicalConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn
server.pushMessage(data, c.clientConnIDSnapshot()) server.pushMessage(data, c.clientConnIDSnapshot())
} }
func (c *LogicalConn) pushServerOwnedTransportPayload(payload []byte, release func(), conn net.Conn, generation uint64) {
if c == nil || len(payload) == 0 {
if release != nil {
release()
}
return
}
server := c.Server()
if server == nil {
if release != nil {
release()
}
return
}
if pusher, ok := server.(serverInboundSourceFastPusher); ok {
pusher.pushTransportPayloadSourceFast(payload, release, newServerInboundSource(c, conn, nil, generation))
return
}
if release != nil {
release()
}
}
func (c *LogicalConn) shouldCloseTransportOnStop(conn net.Conn) bool { func (c *LogicalConn) shouldCloseTransportOnStop(conn net.Conn) bool {
if c == nil || conn == nil { if c == nil || conn == nil {
return false return false
@@ -155,14 +279,75 @@ func (c *ClientConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byt
return num, data, err return num, data, err
} }
func (c *ClientConn) readTUTransportPayloadPooled(conn net.Conn, reader *stario.FrameReader) ([]byte, func(), error) {
if logical := c.LogicalConn(); logical != nil {
return logical.readTUTransportPayloadPooled(conn, reader)
}
if reader == nil {
return nil, nil, net.ErrClosed
}
if conn == nil {
return nil, nil, net.ErrClosed
}
if timeout := c.clientConnMaxReadTimeoutSnapshot(); timeout > 0 {
_ = conn.SetReadDeadline(time.Now().Add(timeout))
}
return reader.NextPooled()
}
func (c *ClientConn) handleTUTransportReadResult(num int, data []byte, err error) bool { func (c *ClientConn) handleTUTransportReadResult(num int, data []byte, err error) bool {
return c.handleTUTransportReadResultWithSession(c.clientConnTransportStopContextSnapshot(), c.clientConnTransportSnapshot(), c.clientConnTransportGenerationSnapshot(), num, data, err) return c.handleTUTransportReadResultWithSession(c.clientConnTransportStopContextSnapshot(), c.clientConnTransportSnapshot(), c.clientConnTransportGenerationSnapshot(), num, data, err)
} }
func (c *ClientConn) handleTUTransportPayloadReadResultWithSessionPooled(stopCtx context.Context, conn net.Conn, generation uint64, payload []byte, release func(), err error) bool {
if logical := c.LogicalConn(); logical != nil {
return logical.handleTUTransportPayloadReadResultWithSessionPooled(stopCtx, conn, generation, payload, release, err)
}
if transportReadShouldStop(stopCtx) || !c.ownsTransportRead(conn, generation) {
if release != nil {
release()
}
if c.shouldCloseClientConnTransportOnStop(conn) {
_ = conn.Close()
}
return false
}
if err == os.ErrDeadlineExceeded {
return true
}
if err != nil {
if release != nil {
release()
}
select {
case <-sessionStopChan(stopCtx):
if c.shouldCloseClientConnTransportOnStop(conn) {
_ = conn.Close()
}
return false
default:
}
if detacher, ok := c.server.(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() {
detacher.detachLogicalSessionTransport(logicalConnFromClient(c), "read error", err)
return false
}
c.stopServerOwnedSession("read error", err)
return false
}
c.pushServerOwnedTransportPayload(payload, release, conn, generation)
return true
}
func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool { func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool {
if logical := c.LogicalConn(); logical != nil { if logical := c.LogicalConn(); logical != nil {
return logical.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err) return logical.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err)
} }
if transportReadShouldStop(stopCtx) || !c.ownsTransportRead(conn, generation) {
if c.shouldCloseClientConnTransportOnStop(conn) {
_ = conn.Close()
}
return false
}
if err == os.ErrDeadlineExceeded { if err == os.ErrDeadlineExceeded {
if num != 0 { if num != 0 {
c.pushServerOwnedTransportMessage(data[:num], conn, generation) c.pushServerOwnedTransportMessage(data[:num], conn, generation)
@@ -189,6 +374,21 @@ func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Cont
return true return true
} }
func (c *ClientConn) ownsTransportRead(conn net.Conn, generation uint64) bool {
if c == nil {
return false
}
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil || !rt.transportAttached || rt.transportGeneration != generation {
return false
}
current := rt.tuConn
if rt.transport != nil && rt.transport.connSnapshot() != nil {
current = rt.transport.connSnapshot()
}
return current == conn
}
func (c *ClientConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) { func (c *ClientConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) {
if logical := c.LogicalConn(); logical != nil { if logical := c.LogicalConn(); logical != nil {
logical.pushServerOwnedTransportMessage(data, conn, generation) logical.pushServerOwnedTransportMessage(data, conn, generation)
@@ -204,6 +404,26 @@ func (c *ClientConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn,
c.server.pushMessage(data, c.clientConnIDSnapshot()) c.server.pushMessage(data, c.clientConnIDSnapshot())
} }
func (c *ClientConn) pushServerOwnedTransportPayload(payload []byte, release func(), conn net.Conn, generation uint64) {
if logical := c.LogicalConn(); logical != nil {
logical.pushServerOwnedTransportPayload(payload, release, conn, generation)
return
}
if c == nil || c.server == nil || len(payload) == 0 {
if release != nil {
release()
}
return
}
if pusher, ok := c.server.(serverInboundSourceFastPusher); ok {
pusher.pushTransportPayloadSourceFast(payload, release, newServerInboundSource(logicalConnFromClient(c), conn, nil, generation))
return
}
if release != nil {
release()
}
}
func (c *ClientConn) shouldCloseClientConnTransportOnStop(conn net.Conn) bool { func (c *ClientConn) shouldCloseClientConnTransportOnStop(conn net.Conn) bool {
if logical := c.LogicalConn(); logical != nil { if logical := c.LogicalConn(); logical != nil {
return logical.shouldCloseTransportOnStop(conn) return logical.shouldCloseTransportOnStop(conn)
+24 -13
View File
@@ -18,14 +18,18 @@ const (
var errClientReconnectSourceUnavailable = errors.New("client reconnect source is unavailable") var errClientReconnectSourceUnavailable = errors.New("client reconnect source is unavailable")
type clientConnectSource struct { type clientConnectSource struct {
kind string kind string
network string network string
addr string addr string
dialFn func(context.Context) (net.Conn, error) dialFn func(context.Context) (net.Conn, error)
supportsAdditional bool
} }
func newClientConnConnectSource(conn net.Conn) *clientConnectSource { func newClientConnConnectSource(conn net.Conn) *clientConnectSource {
source := &clientConnectSource{kind: clientConnectSourceConn} source := &clientConnectSource{
kind: clientConnectSourceConn,
supportsAdditional: false,
}
if conn == nil { if conn == nil {
return source return source
} }
@@ -43,9 +47,10 @@ func newClientConnConnectSource(conn net.Conn) *clientConnectSource {
func newClientNetworkConnectSource(network string, addr string) *clientConnectSource { func newClientNetworkConnectSource(network string, addr string) *clientConnectSource {
return &clientConnectSource{ return &clientConnectSource{
kind: clientConnectSourceNetwork, kind: clientConnectSourceNetwork,
network: network, network: network,
addr: addr, addr: addr,
supportsAdditional: true,
dialFn: func(context.Context) (net.Conn, error) { dialFn: func(context.Context) (net.Conn, error) {
return transport.Dial(network, addr) return transport.Dial(network, addr)
}, },
@@ -54,9 +59,10 @@ func newClientNetworkConnectSource(network string, addr string) *clientConnectSo
func newClientTimeoutConnectSource(network string, addr string, timeout time.Duration) *clientConnectSource { func newClientTimeoutConnectSource(network string, addr string, timeout time.Duration) *clientConnectSource {
return &clientConnectSource{ return &clientConnectSource{
kind: clientConnectSourceTimeout, kind: clientConnectSourceTimeout,
network: network, network: network,
addr: addr, addr: addr,
supportsAdditional: true,
dialFn: func(context.Context) (net.Conn, error) { dialFn: func(context.Context) (net.Conn, error) {
return transport.DialTimeout(network, addr, timeout) return transport.DialTimeout(network, addr, timeout)
}, },
@@ -65,8 +71,9 @@ func newClientTimeoutConnectSource(network string, addr string, timeout time.Dur
func newClientFactoryConnectSource(dialFn func(context.Context) (net.Conn, error)) *clientConnectSource { func newClientFactoryConnectSource(dialFn func(context.Context) (net.Conn, error)) *clientConnectSource {
return &clientConnectSource{ return &clientConnectSource{
kind: clientConnectSourceFactory, kind: clientConnectSourceFactory,
dialFn: dialFn, dialFn: dialFn,
supportsAdditional: true,
} }
} }
@@ -82,6 +89,10 @@ func (s *clientConnectSource) canReconnect() bool {
return s != nil && s.dialFn != nil return s != nil && s.dialFn != nil
} }
func (s *clientConnectSource) supportsAdditionalConn() bool {
return s != nil && s.supportsAdditional
}
func (s *clientConnectSource) isUDP() bool { func (s *clientConnectSource) isUDP() bool {
if s == nil { if s == nil {
return false return false
+5 -1
View File
@@ -31,7 +31,11 @@ func (c *ClientCommon) ExchangeKey(newKey []byte) error {
if string(data.Value) != "success" { if string(data.Value) != "success" {
return errors.New("cannot exchange new aes-key") return errors.New("cannot exchange new aes-key")
} }
c.SecretKey = newKey profile := c.clientTransportProtectionSnapshot()
profile.mode = ProtectionManaged
profile.secretKey = cloneTransportProtectionKey(newKey)
profile.runtime = nil
c.setClientTransportProtectionProfile(profile)
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
return nil return nil
} }
+2
View File
@@ -24,6 +24,7 @@ func (c *ClientCommon) OpenRecordStream(ctx context.Context, opt RecordOpenOptio
_ = stream.Reset(err) _ = stream.Reset(err)
return nil, err return nil, err
} }
bindRecordRuntime(record, c.getRecordRuntime())
return record, nil return record, nil
} }
@@ -51,6 +52,7 @@ func (c *ClientCommon) claimInboundRecordStream(stream *streamHandle) (bool, err
if err != nil { if err != nil {
return true, err return true, err
} }
bindRecordRuntime(record, runtime)
info := RecordAcceptInfo{ info := RecordAcceptInfo{
ID: stream.ID(), ID: stream.ID(),
Metadata: stream.Metadata(), Metadata: stream.Metadata(),
+18
View File
@@ -350,6 +350,8 @@ func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime,
if rt == nil { if rt == nil {
return nil return nil
} }
c.resetClientPeerAttachAuth()
c.activateClientBootstrapTransportProtection()
if runKeyExchange && !c.skipKeyExchange { if runKeyExchange && !c.skipKeyExchange {
if err := c.keyExchangeFn(c); err != nil { if err := c.keyExchangeFn(c); err != nil {
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "key exchange failed", err) return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "key exchange failed", err)
@@ -358,6 +360,7 @@ func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime,
if err := c.announceClientPeerIdentity(); err != nil { if err := c.announceClientPeerIdentity(); err != nil {
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "peer attach failed", err) return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "peer attach failed", err)
} }
c.activateClientSteadyTransportProtection()
return nil return nil
} }
@@ -433,6 +436,21 @@ func (c *ClientCommon) readMessageLoop(stopCtx context.Context, conn net.Conn, q
} }
binding := newTransportBinding(conn, queue) binding := newTransportBinding(conn, queue)
dispatcher := c.clientInboundDispatcherSnapshot() dispatcher := c.clientInboundDispatcherSnapshot()
if conn != nil && queue != nil && !isPacketTransportConn(conn) {
reader := newTransportFrameReader(conn, queue)
for {
select {
case <-stopCtx.Done():
c.closeClientTransportBinding(binding)
return
default:
}
payload, release, err := c.readTransportPayloadPooled(conn, reader)
if !c.handleTransportPayloadReadResultWithSession(stopCtx, binding, payload, release, err, epoch, dispatcher) {
return
}
}
}
buf := streamReadBuffer() buf := streamReadBuffer()
for { for {
select { select {
+2 -2
View File
@@ -320,7 +320,7 @@ func TestSetClientSessionRuntimeStopsOldBindingWorkersOnReattach(t *testing.T) {
defer oldLeft.Close() defer oldLeft.Close()
defer oldRight.Close() defer oldRight.Close()
oldBinding := newTransportBinding(oldLeft, queue) oldBinding := newTransportBinding(oldLeft, queue)
oldSender := oldBinding.bulkBatchSenderSnapshot() oldSender := oldBinding.clientBulkBatchSenderSnapshot(client)
client.setClientSessionRuntime(&clientSessionRuntime{ client.setClientSessionRuntime(&clientSessionRuntime{
transport: oldBinding, transport: oldBinding,
@@ -345,7 +345,7 @@ func TestSetClientSessionRuntimeStopsOldBindingWorkersOnReattach(t *testing.T) {
epoch: 2, epoch: 2,
}) })
err := oldSender.submit(context.Background(), []byte("payload")) err := oldSender.submitData(context.Background(), 1, 1, bulkFastPathVersionV1, []byte("payload"))
if err != errTransportDetached { if err != errTransportDetached {
t.Fatalf("old sender submit after reattach = %v, want %v", err, errTransportDetached) t.Fatalf("old sender submit after reattach = %v, want %v", err, errTransportDetached)
} }
+13 -6
View File
@@ -35,6 +35,12 @@ func (c *ClientCommon) OpenStream(ctx context.Context, opt StreamOpenOptions) (S
if resp.DataID != 0 { if resp.DataID != 0 {
req.DataID = resp.DataID req.DataID = resp.DataID
} }
if resp.FastPathVersion != 0 {
req.FastPathVersion = resp.FastPathVersion
} else {
req.FastPathVersion = streamFastPathVersionV1
}
req.Metadata = mergeStreamMetadata(req.Metadata, resp.Metadata)
stream := newStreamHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientStreamCloseSender(c), clientStreamResetSender(c), clientStreamDataSender(c, c.currentClientSessionEpoch()), runtime.configSnapshot()) stream := newStreamHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientStreamCloseSender(c), clientStreamResetSender(c), clientStreamDataSender(c, c.currentClientSessionEpoch()), runtime.configSnapshot())
stream.setClientSnapshotOwner(c) stream.setClientSnapshotOwner(c)
stream.setAddrSnapshot(c.clientStreamAddrSnapshot()) stream.setAddrSnapshot(c.clientStreamAddrSnapshot())
@@ -65,11 +71,12 @@ func clientStreamRequest(runtime *streamRuntime, opt StreamOpenOptions) StreamOp
id = runtime.nextID() id = runtime.nextID()
} }
return normalizeStreamOpenRequest(StreamOpenRequest{ return normalizeStreamOpenRequest(StreamOpenRequest{
StreamID: id, StreamID: id,
Channel: opt.Channel, FastPathVersion: streamFastPathVersionCurrent,
Metadata: cloneStreamMetadata(opt.Metadata), Channel: opt.Channel,
ReadTimeout: opt.ReadTimeout, Metadata: cloneStreamMetadata(opt.Metadata),
WriteTimeout: opt.WriteTimeout, ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
}) })
} }
@@ -109,7 +116,7 @@ func clientStreamDataSender(c *ClientCommon, epoch uint64) streamDataSender {
} }
} }
if dataID := stream.dataIDSnapshot(); dataID != 0 { if dataID := stream.dataIDSnapshot(); dataID != 0 {
return c.sendFastStreamData(dataID, stream.nextOutboundDataSeq(), chunk) return c.sendFastStreamData(ctx, stream, chunk)
} }
return c.sendEnvelope(newStreamDataEnvelope(stream.ID(), chunk)) return c.sendEnvelope(newStreamDataEnvelope(stream.ID(), chunk))
} }
+84 -14
View File
@@ -130,24 +130,94 @@ func (c *ClientCommon) handleTransportReadResultWithSessionDispatcher(stopCtx co
return true return true
} }
func (c *ClientCommon) readTransportPayloadPooled(conn net.Conn, reader *stario.FrameReader) ([]byte, func(), error) {
if reader == nil {
return nil, nil, net.ErrClosed
}
if conn == nil {
return nil, nil, net.ErrClosed
}
if c.maxReadTimeout.Seconds() != 0 {
_ = conn.SetReadDeadline(time.Now().Add(c.maxReadTimeout))
}
return reader.NextPooled()
}
func (c *ClientCommon) handleTransportPayloadReadResultWithSession(stopCtx context.Context, binding *transportBinding, payload []byte, release func(), err error, epoch uint64, dispatcher *inboundDispatcher) bool {
if err == os.ErrDeadlineExceeded {
return true
}
if err != nil {
if release != nil {
release()
}
if c.showError || c.debugMode {
fmt.Println("client read error", err)
}
select {
case <-sessionStopChan(stopCtx):
c.closeClientTransportBinding(binding)
return false
default:
}
c.stopClientSessionIfCurrent(epoch, "client read error", err)
return false
}
c.dispatchTransportPayloadFast(payload, release, dispatcher)
return true
}
func (c *ClientCommon) dispatchTransportPayloadFast(payload []byte, release func(), dispatcher *inboundDispatcher) {
if len(payload) == 0 {
if release != nil {
release()
}
return
}
plain, plainRelease, err := c.decryptTransportPayloadPooled(payload, release)
if err != nil {
if c.showError || c.debugMode {
fmt.Println("client decode transport payload error", err)
}
return
}
if c.tryDispatchBorrowedTransportPlain(plain, plainRelease) {
return
}
if dispatcher == nil {
now := time.Now()
err := c.dispatchInboundTransportPlain(plain, now)
if plainRelease != nil {
plainRelease()
}
if err != nil && (c.showError || c.debugMode) {
fmt.Println("client decode envelope error", err)
}
return
}
owned := plain
if plainRelease != nil {
owned = append([]byte(nil), plain...)
plainRelease()
}
c.wg.Add(1)
if !dispatcher.Dispatch(clientInboundDispatchSource(), func() {
defer c.wg.Done()
now := time.Now()
if err := c.dispatchInboundTransportPlain(owned, now); err != nil && (c.showError || c.debugMode) {
fmt.Println("client decode envelope error", err)
}
}) {
c.wg.Done()
}
}
func (c *ClientCommon) pushMessageFast(queue *stario.StarQueue, data []byte, dispatcher *inboundDispatcher) bool { func (c *ClientCommon) pushMessageFast(queue *stario.StarQueue, data []byte, dispatcher *inboundDispatcher) bool {
if queue == nil || dispatcher == nil || len(data) == 0 { if queue == nil || dispatcher == nil || len(data) == 0 {
return false return false
} }
if err := queue.ParseMessageOwned(data, "b612", func(msg stario.MsgQueue) error { if err := queue.ParseMessageView(data, "b612", func(frame stario.FrameView) error {
payload := msg.Msg c.dispatchTransportPayloadFast(frame.Payload, nil, dispatcher)
c.wg.Add(1)
if !dispatcher.Dispatch(clientInboundDispatchSource(), func() {
defer c.wg.Done()
now := time.Now()
if err := c.dispatchInboundTransportPayload(payload, now); err != nil {
if c.showError || c.debugMode {
fmt.Println("client decode envelope error", err)
}
}
}) {
c.wg.Done()
}
return nil return nil
}); err != nil && (c.showError || c.debugMode) { }); err != nil && (c.showError || c.debugMode) {
fmt.Println("client parse inbound frame error", err) fmt.Println("client parse inbound frame error", err)
+12
View File
@@ -18,6 +18,16 @@ type Client interface {
SetStreamConfig(StreamConfig) SetStreamConfig(StreamConfig)
SetTransferResumeStore(TransferResumeStore) SetTransferResumeStore(TransferResumeStore)
RecoverTransferSnapshots(context.Context) error RecoverTransferSnapshots(context.Context) error
SetBulkNetworkProfile(BulkNetworkProfile)
BulkNetworkProfile() BulkNetworkProfile
SetBulkDefaultOpenMode(BulkOpenMode)
BulkDefaultOpenMode() BulkOpenMode
SetBulkOpenTuning(BulkOpenTuning)
BulkOpenTuning() BulkOpenTuning
SetBulkDedicatedAttachConfig(BulkDedicatedAttachConfig)
BulkDedicatedAttachConfig() BulkDedicatedAttachConfig
SetPeerAttachSecurityConfig(PeerAttachSecurityConfig) error
PeerAttachSecurityConfig() PeerAttachSecurityConfig
SetFileReceiveDir(dir string) error SetFileReceiveDir(dir string) error
send(msg TransferMsg) (WaitMsg, error) send(msg TransferMsg) (WaitMsg, error)
sendEnvelope(env Envelope) error sendEnvelope(env Envelope) error
@@ -77,6 +87,8 @@ type Client interface {
OpenStream(ctx context.Context, opt StreamOpenOptions) (Stream, error) OpenStream(ctx context.Context, opt StreamOpenOptions) (Stream, error)
OpenRecordStream(ctx context.Context, opt RecordOpenOptions) (RecordStream, error) OpenRecordStream(ctx context.Context, opt RecordOpenOptions) (RecordStream, error)
OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error)
OpenSharedBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error)
OpenDedicatedBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error)
SendTransfer(ctx context.Context, opt TransferSendOptions) (TransferHandle, error) SendTransfer(ctx context.Context, opt TransferSendOptions) (TransferHandle, error)
SendFile(ctx context.Context, filePath string) error SendFile(ctx context.Context, filePath string) error
} }
+2
View File
@@ -126,6 +126,8 @@ func init() {
RegisterName("b612.me/notify.BulkCloseResponse", BulkCloseResponse{}) RegisterName("b612.me/notify.BulkCloseResponse", BulkCloseResponse{})
RegisterName("b612.me/notify.BulkResetRequest", BulkResetRequest{}) RegisterName("b612.me/notify.BulkResetRequest", BulkResetRequest{})
RegisterName("b612.me/notify.BulkResetResponse", BulkResetResponse{}) RegisterName("b612.me/notify.BulkResetResponse", BulkResetResponse{})
RegisterName("b612.me/notify.BulkReadyRequest", BulkReadyRequest{})
RegisterName("b612.me/notify.BulkReadyResponse", BulkReadyResponse{})
RegisterName("b612.me/notify.BulkReleaseRequest", BulkReleaseRequest{}) RegisterName("b612.me/notify.BulkReleaseRequest", BulkReleaseRequest{})
RegisterName("b612.me/notify.bulkAttachRequest", bulkAttachRequest{}) RegisterName("b612.me/notify.bulkAttachRequest", bulkAttachRequest{})
RegisterName("b612.me/notify.bulkAttachResponse", bulkAttachResponse{}) RegisterName("b612.me/notify.bulkAttachResponse", bulkAttachResponse{})
+113 -1
View File
@@ -34,10 +34,36 @@ type DiagnosticsTransferTelemetrySummary struct {
CommitWaitRatio float64 CommitWaitRatio float64
} }
type DiagnosticsRecordTelemetrySummary struct {
BatchFramesSent int64
AckFramesSent int64
ErrorFramesSent int64
BatchFramesReceived int64
AckFramesReceived int64
ErrorFramesReceived int64
FrameSendCount int64
FrameReceiveCount int64
PiggybackAckSent int64
PiggybackAckReceived int64
BarrierCount int64
BarrierFlushWaitDuration time.Duration
BarrierApplyWaitDuration time.Duration
OutstandingRecords int
OutstandingBytes int
PendingApplyRecords int
PendingAckRecords int
PeakPendingApplyRecords int
}
type DiagnosticsSummary struct { type DiagnosticsSummary struct {
LogicalCount int LogicalCount int
CurrentTransportCount int CurrentTransportCount int
BulkAttachAttempts int64
BulkAttachRetries int64
BulkAttachSuccess int64
BulkAutoFallbacks int64
StreamCount int StreamCount int
ActiveStreamCount int ActiveStreamCount int
StaleStreamCount int StaleStreamCount int
@@ -49,6 +75,11 @@ type DiagnosticsSummary struct {
StaleBulkCount int StaleBulkCount int
ResetBulkCount int ResetBulkCount int
RecordCount int
ActiveRecordCount int
StaleRecordCount int
ResetRecordCount int
TransferCount int TransferCount int
ActiveTransferCount int ActiveTransferCount int
PausedTransferCount int PausedTransferCount int
@@ -58,6 +89,8 @@ type DiagnosticsSummary struct {
StreamResetCauses DiagnosticsResetCauseSummary StreamResetCauses DiagnosticsResetCauseSummary
BulkResetCauses DiagnosticsResetCauseSummary BulkResetCauses DiagnosticsResetCauseSummary
RecordResetCauses DiagnosticsResetCauseSummary
RecordTelemetry DiagnosticsRecordTelemetrySummary
TransferTelemetry DiagnosticsTransferTelemetrySummary TransferTelemetry DiagnosticsTransferTelemetrySummary
} }
@@ -65,6 +98,7 @@ type ClientDiagnosticsSnapshot struct {
Runtime ClientRuntimeSnapshot Runtime ClientRuntimeSnapshot
Streams []StreamSnapshot Streams []StreamSnapshot
Bulks []BulkSnapshot Bulks []BulkSnapshot
Records []RecordSnapshot
Transfers []TransferSnapshot Transfers []TransferSnapshot
Summary DiagnosticsSummary Summary DiagnosticsSummary
} }
@@ -75,6 +109,7 @@ type ServerDiagnosticsSnapshot struct {
CurrentTransports []TransportConnRuntimeSnapshot CurrentTransports []TransportConnRuntimeSnapshot
Streams []StreamSnapshot Streams []StreamSnapshot
Bulks []BulkSnapshot Bulks []BulkSnapshot
Records []RecordSnapshot
Transfers []TransferSnapshot Transfers []TransferSnapshot
Summary DiagnosticsSummary Summary DiagnosticsSummary
} }
@@ -100,6 +135,10 @@ func GetClientDiagnosticsSnapshot(c Client) (ClientDiagnosticsSnapshot, error) {
if err != nil { if err != nil {
return ClientDiagnosticsSnapshot{}, err return ClientDiagnosticsSnapshot{}, err
} }
records, err := GetClientRecordSnapshots(c)
if err != nil {
return ClientDiagnosticsSnapshot{}, err
}
transfers, err := GetClientTransferSnapshots(c) transfers, err := GetClientTransferSnapshots(c)
if err != nil { if err != nil {
return ClientDiagnosticsSnapshot{}, err return ClientDiagnosticsSnapshot{}, err
@@ -108,6 +147,7 @@ func GetClientDiagnosticsSnapshot(c Client) (ClientDiagnosticsSnapshot, error) {
Runtime: runtime, Runtime: runtime,
Streams: streams, Streams: streams,
Bulks: bulks, Bulks: bulks,
Records: records,
Transfers: transfers, Transfers: transfers,
} }
snapshot.Summary = summarizeClientDiagnosticsSnapshot(snapshot) snapshot.Summary = summarizeClientDiagnosticsSnapshot(snapshot)
@@ -138,6 +178,10 @@ func GetServerDiagnosticsSnapshot(s Server) (ServerDiagnosticsSnapshot, error) {
if err != nil { if err != nil {
return ServerDiagnosticsSnapshot{}, err return ServerDiagnosticsSnapshot{}, err
} }
records, err := GetServerRecordSnapshots(s)
if err != nil {
return ServerDiagnosticsSnapshot{}, err
}
transfers, err := GetServerTransferSnapshots(s) transfers, err := GetServerTransferSnapshots(s)
if err != nil { if err != nil {
return ServerDiagnosticsSnapshot{}, err return ServerDiagnosticsSnapshot{}, err
@@ -148,6 +192,7 @@ func GetServerDiagnosticsSnapshot(s Server) (ServerDiagnosticsSnapshot, error) {
CurrentTransports: transports, CurrentTransports: transports,
Streams: streams, Streams: streams,
Bulks: bulks, Bulks: bulks,
Records: records,
Transfers: transfers, Transfers: transfers,
} }
snapshot.Summary = summarizeServerDiagnosticsSnapshot(snapshot) snapshot.Summary = summarizeServerDiagnosticsSnapshot(snapshot)
@@ -196,13 +241,18 @@ func serverCurrentTransportRuntimeSnapshots(s Server) ([]TransportConnRuntimeSna
func summarizeClientDiagnosticsSnapshot(snapshot ClientDiagnosticsSnapshot) DiagnosticsSummary { func summarizeClientDiagnosticsSnapshot(snapshot ClientDiagnosticsSnapshot) DiagnosticsSummary {
summary := DiagnosticsSummary{ summary := DiagnosticsSummary{
LogicalCount: diagnosticsLogicalCountFromClientRuntime(snapshot.Runtime), LogicalCount: diagnosticsLogicalCountFromClientRuntime(snapshot.Runtime),
BulkAttachAttempts: snapshot.Runtime.BulkAttachAttempts,
BulkAttachRetries: snapshot.Runtime.BulkAttachRetries,
BulkAttachSuccess: snapshot.Runtime.BulkAttachSuccess,
BulkAutoFallbacks: snapshot.Runtime.BulkAutoFallbacks,
} }
if snapshot.Runtime.TransportAttached { if snapshot.Runtime.TransportAttached {
summary.CurrentTransportCount = 1 summary.CurrentTransportCount = 1
} }
summarizeStreamSnapshots(&summary, snapshot.Streams) summarizeStreamSnapshots(&summary, snapshot.Streams)
summarizeBulkSnapshots(&summary, snapshot.Bulks) summarizeBulkSnapshots(&summary, snapshot.Bulks)
summarizeRecordSnapshots(&summary, snapshot.Records)
summarizeTransferSnapshots(&summary, snapshot.Transfers) summarizeTransferSnapshots(&summary, snapshot.Transfers)
return summary return summary
} }
@@ -214,6 +264,7 @@ func summarizeServerDiagnosticsSnapshot(snapshot ServerDiagnosticsSnapshot) Diag
} }
summarizeStreamSnapshots(&summary, snapshot.Streams) summarizeStreamSnapshots(&summary, snapshot.Streams)
summarizeBulkSnapshots(&summary, snapshot.Bulks) summarizeBulkSnapshots(&summary, snapshot.Bulks)
summarizeRecordSnapshots(&summary, snapshot.Records)
summarizeTransferSnapshots(&summary, snapshot.Transfers) summarizeTransferSnapshots(&summary, snapshot.Transfers)
return summary return summary
} }
@@ -266,6 +317,27 @@ func summarizeBulkSnapshots(summary *DiagnosticsSummary, snapshots []BulkSnapsho
} }
} }
func summarizeRecordSnapshots(summary *DiagnosticsSummary, snapshots []RecordSnapshot) {
if summary == nil {
return
}
summary.RecordCount = len(snapshots)
for _, snapshot := range snapshots {
switch {
case snapshot.ResetError != "":
summary.ResetRecordCount++
accumulateDiagnosticsResetCause(&summary.RecordResetCauses, snapshot.ResetError, "")
case recordSnapshotFinished(snapshot):
case recordSnapshotBoundActive(snapshot):
summary.ActiveRecordCount++
default:
summary.StaleRecordCount++
}
accumulateDiagnosticsRecordTelemetry(&summary.RecordTelemetry, snapshot)
}
finalizeDiagnosticsRecordTelemetry(&summary.RecordTelemetry)
}
func summarizeTransferSnapshots(summary *DiagnosticsSummary, snapshots []TransferSnapshot) { func summarizeTransferSnapshots(summary *DiagnosticsSummary, snapshots []TransferSnapshot) {
if summary == nil { if summary == nil {
return return
@@ -297,6 +369,10 @@ func bulkSnapshotFinished(snapshot BulkSnapshot) bool {
return snapshot.ResetError == "" && snapshot.LocalClosed && snapshot.RemoteClosed return snapshot.ResetError == "" && snapshot.LocalClosed && snapshot.RemoteClosed
} }
func recordSnapshotFinished(snapshot RecordSnapshot) bool {
return snapshot.ResetError == "" && snapshot.LocalClosed && snapshot.RemoteClosed
}
func streamSnapshotBoundActive(snapshot StreamSnapshot) bool { func streamSnapshotBoundActive(snapshot StreamSnapshot) bool {
return snapshot.BindingCurrent && snapshot.TransportAttached && snapshot.TransportCurrent return snapshot.BindingCurrent && snapshot.TransportAttached && snapshot.TransportCurrent
} }
@@ -305,6 +381,10 @@ func bulkSnapshotBoundActive(snapshot BulkSnapshot) bool {
return snapshot.BindingCurrent && snapshot.TransportAttached && snapshot.TransportCurrent return snapshot.BindingCurrent && snapshot.TransportAttached && snapshot.TransportCurrent
} }
func recordSnapshotBoundActive(snapshot RecordSnapshot) bool {
return snapshot.BindingCurrent && snapshot.TransportAttached && snapshot.TransportCurrent
}
func accumulateDiagnosticsResetCause(summary *DiagnosticsResetCauseSummary, resetError string, backpressureError string) { func accumulateDiagnosticsResetCause(summary *DiagnosticsResetCauseSummary, resetError string, backpressureError string) {
if summary == nil || resetError == "" { if summary == nil || resetError == "" {
return return
@@ -362,6 +442,38 @@ func finalizeDiagnosticsTransferTelemetry(summary *DiagnosticsTransferTelemetryS
summary.CommitWaitRatio = durationRatio(summary.CommitWaitDuration, summary.ObservedDuration) summary.CommitWaitRatio = durationRatio(summary.CommitWaitDuration, summary.ObservedDuration)
} }
func accumulateDiagnosticsRecordTelemetry(summary *DiagnosticsRecordTelemetrySummary, snapshot RecordSnapshot) {
if summary == nil {
return
}
summary.BatchFramesSent += snapshot.BatchFramesSent
summary.AckFramesSent += snapshot.AckFramesSent
summary.ErrorFramesSent += snapshot.ErrorFramesSent
summary.BatchFramesReceived += snapshot.BatchFramesReceived
summary.AckFramesReceived += snapshot.AckFramesReceived
summary.ErrorFramesReceived += snapshot.ErrorFramesReceived
summary.PiggybackAckSent += snapshot.PiggybackAckSent
summary.PiggybackAckReceived += snapshot.PiggybackAckReceived
summary.BarrierCount += snapshot.BarrierCount
summary.BarrierFlushWaitDuration += snapshot.BarrierFlushWaitDuration
summary.BarrierApplyWaitDuration += snapshot.BarrierApplyWaitDuration
summary.OutstandingRecords += snapshot.OutstandingRecords
summary.OutstandingBytes += snapshot.OutstandingBytes
summary.PendingApplyRecords += snapshot.PendingApplyRecords
summary.PendingAckRecords += snapshot.PendingAckRecords
if snapshot.PeakPendingApplyRecords > summary.PeakPendingApplyRecords {
summary.PeakPendingApplyRecords = snapshot.PeakPendingApplyRecords
}
}
func finalizeDiagnosticsRecordTelemetry(summary *DiagnosticsRecordTelemetrySummary) {
if summary == nil {
return
}
summary.FrameSendCount = summary.BatchFramesSent + summary.AckFramesSent + summary.ErrorFramesSent
summary.FrameReceiveCount = summary.BatchFramesReceived + summary.AckFramesReceived + summary.ErrorFramesReceived
}
func sortClientConnRuntimeSnapshots(src []ClientConnRuntimeSnapshot) { func sortClientConnRuntimeSnapshots(src []ClientConnRuntimeSnapshot) {
sort.Slice(src, func(i, j int) bool { sort.Slice(src, func(i, j int) bool {
if src[i].ClientID != src[j].ClientID { if src[i].ClientID != src[j].ClientID {
+253 -1
View File
@@ -20,7 +20,7 @@ func TestGetClientDiagnosticsSnapshotDefaults(t *testing.T) {
if got, want := snapshot.Runtime.OwnerState, "idle"; got != want { if got, want := snapshot.Runtime.OwnerState, "idle"; got != want {
t.Fatalf("Runtime.OwnerState = %q, want %q", got, want) t.Fatalf("Runtime.OwnerState = %q, want %q", got, want)
} }
if len(snapshot.Streams) != 0 || len(snapshot.Bulks) != 0 || len(snapshot.Transfers) != 0 { if len(snapshot.Streams) != 0 || len(snapshot.Bulks) != 0 || len(snapshot.Records) != 0 || len(snapshot.Transfers) != 0 {
t.Fatalf("default diagnostics should be empty: %+v", snapshot) t.Fatalf("default diagnostics should be empty: %+v", snapshot)
} }
if snapshot.Summary != (DiagnosticsSummary{}) { if snapshot.Summary != (DiagnosticsSummary{}) {
@@ -130,6 +130,137 @@ func TestGetClientDiagnosticsSnapshotAggregatesActiveState(t *testing.T) {
_ = bulk.Close() _ = bulk.Close()
} }
func TestGetDiagnosticsSnapshotAggregatesActiveRecordState(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
recordAcceptCh := make(chan RecordAcceptInfo, 1)
recordReleaseCh := make(chan struct{})
recordHandlerDone := make(chan error, 1)
server.SetRecordStreamHandler(func(info RecordAcceptInfo) error {
recordAcceptCh <- info
msg, err := info.RecordStream.ReadRecord(context.Background())
if err != nil {
recordHandlerDone <- err
return err
}
if string(msg.Payload) != "diag-record" {
err = errors.New("unexpected record payload")
recordHandlerDone <- err
return err
}
if err := info.RecordStream.AckRecord(msg.Seq); err != nil {
recordHandlerDone <- err
return err
}
<-recordReleaseCh
err = info.RecordStream.Close()
recordHandlerDone <- err
return err
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
t.Fatalf("server Listen failed: %v", err)
}
defer func() {
_ = server.Stop()
}()
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
t.Fatalf("client Connect failed: %v", err)
}
defer func() {
_ = client.Stop()
}()
record, err := client.OpenRecordStream(context.Background(), RecordOpenOptions{
Stream: StreamOpenOptions{ID: "diag-client-record"},
})
if err != nil {
t.Fatalf("client OpenRecordStream failed: %v", err)
}
select {
case <-recordAcceptCh:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting accepted record stream")
}
if _, err := record.WriteRecord(context.Background(), []byte("diag-record")); err != nil {
t.Fatalf("WriteRecord failed: %v", err)
}
if _, err := record.Barrier(context.Background()); err != nil {
t.Fatalf("Barrier failed: %v", err)
}
clientSnapshot, err := GetClientDiagnosticsSnapshot(client)
if err != nil {
t.Fatalf("GetClientDiagnosticsSnapshot failed: %v", err)
}
if got, want := len(clientSnapshot.Records), 1; got != want {
t.Fatalf("client record snapshot count = %d, want %d", got, want)
}
if got, want := clientSnapshot.Summary.RecordCount, 1; got != want {
t.Fatalf("client RecordCount = %d, want %d", got, want)
}
if got, want := clientSnapshot.Summary.ActiveRecordCount, 1; got != want {
t.Fatalf("client ActiveRecordCount = %d, want %d", got, want)
}
clientRecord := clientSnapshot.Records[0]
if got := clientRecord.BatchFramesSent; got < 1 {
t.Fatalf("client BatchFramesSent = %d, want >= 1", got)
}
if got := clientRecord.AckFramesReceived; got < 1 {
t.Fatalf("client AckFramesReceived = %d, want >= 1", got)
}
if got := clientRecord.BarrierCount; got < 1 {
t.Fatalf("client BarrierCount = %d, want >= 1", got)
}
if got := clientSnapshot.Summary.RecordTelemetry.FrameSendCount; got < 1 {
t.Fatalf("client RecordTelemetry.FrameSendCount = %d, want >= 1", got)
}
serverSnapshot, err := GetServerDiagnosticsSnapshot(server)
if err != nil {
t.Fatalf("GetServerDiagnosticsSnapshot failed: %v", err)
}
if got, want := len(serverSnapshot.Records), 1; got != want {
t.Fatalf("server record snapshot count = %d, want %d", got, want)
}
if got, want := serverSnapshot.Summary.RecordCount, 1; got != want {
t.Fatalf("server RecordCount = %d, want %d", got, want)
}
if got, want := serverSnapshot.Summary.ActiveRecordCount, 1; got != want {
t.Fatalf("server ActiveRecordCount = %d, want %d", got, want)
}
serverRecord := serverSnapshot.Records[0]
if got := serverRecord.BatchFramesReceived; got < 1 {
t.Fatalf("server BatchFramesReceived = %d, want >= 1", got)
}
if got := serverRecord.AckFramesSent; got < 1 {
t.Fatalf("server AckFramesSent = %d, want >= 1", got)
}
if got := serverSnapshot.Summary.RecordTelemetry.FrameReceiveCount; got < 1 {
t.Fatalf("server RecordTelemetry.FrameReceiveCount = %d, want >= 1", got)
}
close(recordReleaseCh)
select {
case err := <-recordHandlerDone:
if err != nil {
t.Fatalf("record handler failed: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting record handler completion")
}
_ = record.Close()
}
func TestGetServerDiagnosticsSnapshotAggregatesStaleAndResetState(t *testing.T) { func TestGetServerDiagnosticsSnapshotAggregatesStaleAndResetState(t *testing.T) {
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
@@ -323,6 +454,127 @@ func TestDiagnosticsSummaryClassifiesResetCauses(t *testing.T) {
} }
} }
func TestDiagnosticsSummaryAggregatesRecordTelemetry(t *testing.T) {
summary := summarizeClientDiagnosticsSnapshot(ClientDiagnosticsSnapshot{
Records: []RecordSnapshot{
{
ID: "record-active",
BindingCurrent: true,
TransportAttached: true,
TransportCurrent: true,
OutstandingRecords: 3,
OutstandingBytes: 4096,
PendingApplyRecords: 2,
PendingAckRecords: 1,
PeakPendingApplyRecords: 5,
BatchFramesSent: 10,
AckFramesSent: 4,
ErrorFramesSent: 1,
BatchFramesReceived: 8,
AckFramesReceived: 3,
ErrorFramesReceived: 0,
PiggybackAckSent: 6,
PiggybackAckReceived: 2,
BarrierCount: 4,
BarrierFlushWaitDuration: 10 * time.Millisecond,
BarrierApplyWaitDuration: 30 * time.Millisecond,
},
{
ID: "record-reset",
ResetError: errTransportDetached.Error(),
OutstandingRecords: 1,
OutstandingBytes: 512,
PendingApplyRecords: 3,
PendingAckRecords: 2,
PeakPendingApplyRecords: 7,
BatchFramesSent: 2,
AckFramesSent: 1,
ErrorFramesSent: 1,
BatchFramesReceived: 1,
AckFramesReceived: 1,
ErrorFramesReceived: 1,
PiggybackAckSent: 1,
PiggybackAckReceived: 1,
BarrierCount: 1,
BarrierFlushWaitDuration: 5 * time.Millisecond,
BarrierApplyWaitDuration: 15 * time.Millisecond,
},
},
})
if got, want := summary.RecordCount, 2; got != want {
t.Fatalf("RecordCount = %d, want %d", got, want)
}
if got, want := summary.ActiveRecordCount, 1; got != want {
t.Fatalf("ActiveRecordCount = %d, want %d", got, want)
}
if got, want := summary.ResetRecordCount, 1; got != want {
t.Fatalf("ResetRecordCount = %d, want %d", got, want)
}
if got, want := summary.RecordResetCauses.Total, 1; got != want {
t.Fatalf("RecordResetCauses.Total = %d, want %d", got, want)
}
if got, want := summary.RecordResetCauses.TransportDetached, 1; got != want {
t.Fatalf("RecordResetCauses.TransportDetached = %d, want %d", got, want)
}
telemetry := summary.RecordTelemetry
if got, want := telemetry.BatchFramesSent, int64(12); got != want {
t.Fatalf("BatchFramesSent = %d, want %d", got, want)
}
if got, want := telemetry.AckFramesSent, int64(5); got != want {
t.Fatalf("AckFramesSent = %d, want %d", got, want)
}
if got, want := telemetry.ErrorFramesSent, int64(2); got != want {
t.Fatalf("ErrorFramesSent = %d, want %d", got, want)
}
if got, want := telemetry.BatchFramesReceived, int64(9); got != want {
t.Fatalf("BatchFramesReceived = %d, want %d", got, want)
}
if got, want := telemetry.AckFramesReceived, int64(4); got != want {
t.Fatalf("AckFramesReceived = %d, want %d", got, want)
}
if got, want := telemetry.ErrorFramesReceived, int64(1); got != want {
t.Fatalf("ErrorFramesReceived = %d, want %d", got, want)
}
if got, want := telemetry.FrameSendCount, int64(19); got != want {
t.Fatalf("FrameSendCount = %d, want %d", got, want)
}
if got, want := telemetry.FrameReceiveCount, int64(14); got != want {
t.Fatalf("FrameReceiveCount = %d, want %d", got, want)
}
if got, want := telemetry.PiggybackAckSent, int64(7); got != want {
t.Fatalf("PiggybackAckSent = %d, want %d", got, want)
}
if got, want := telemetry.PiggybackAckReceived, int64(3); got != want {
t.Fatalf("PiggybackAckReceived = %d, want %d", got, want)
}
if got, want := telemetry.BarrierCount, int64(5); got != want {
t.Fatalf("BarrierCount = %d, want %d", got, want)
}
if got, want := telemetry.BarrierFlushWaitDuration, 15*time.Millisecond; got != want {
t.Fatalf("BarrierFlushWaitDuration = %v, want %v", got, want)
}
if got, want := telemetry.BarrierApplyWaitDuration, 45*time.Millisecond; got != want {
t.Fatalf("BarrierApplyWaitDuration = %v, want %v", got, want)
}
if got, want := telemetry.OutstandingRecords, 4; got != want {
t.Fatalf("OutstandingRecords = %d, want %d", got, want)
}
if got, want := telemetry.OutstandingBytes, 4608; got != want {
t.Fatalf("OutstandingBytes = %d, want %d", got, want)
}
if got, want := telemetry.PendingApplyRecords, 5; got != want {
t.Fatalf("PendingApplyRecords = %d, want %d", got, want)
}
if got, want := telemetry.PendingAckRecords, 3; got != want {
t.Fatalf("PendingAckRecords = %d, want %d", got, want)
}
if got, want := telemetry.PeakPendingApplyRecords, 7; got != want {
t.Fatalf("PeakPendingApplyRecords = %d, want %d", got, want)
}
}
func TestDiagnosticsSummaryAggregatesTransferTelemetry(t *testing.T) { func TestDiagnosticsSummaryAggregatesTransferTelemetry(t *testing.T) {
summary := summarizeClientDiagnosticsSnapshot(ClientDiagnosticsSnapshot{ summary := summarizeClientDiagnosticsSnapshot(ClientDiagnosticsSnapshot{
Transfers: []TransferSnapshot{ Transfers: []TransferSnapshot{
+1 -1
View File
@@ -4,7 +4,7 @@ go 1.24.0
require ( require (
b612.me/starcrypto v1.0.2 b612.me/starcrypto v1.0.2
b612.me/stario v0.1.0 b612.me/stario v0.1.1
github.com/Microsoft/go-winio v0.6.2 github.com/Microsoft/go-winio v0.6.2
) )
+2 -2
View File
@@ -1,7 +1,7 @@
b612.me/starcrypto v1.0.2 h1:6f8YHNMHZPwxDSRxY2OJeMP4ExKa/cakLIO04f0gLhE= b612.me/starcrypto v1.0.2 h1:6f8YHNMHZPwxDSRxY2OJeMP4ExKa/cakLIO04f0gLhE=
b612.me/starcrypto v1.0.2/go.mod h1:I7oYTmQgnVPj5S5yKwoTyqkItq1HgF9XdJT/v3qs5QE= b612.me/starcrypto v1.0.2/go.mod h1:I7oYTmQgnVPj5S5yKwoTyqkItq1HgF9XdJT/v3qs5QE=
b612.me/stario v0.1.0 h1:V1uA7fLYzgTadOXpnyPaFC3z0MAKFIM/RKXzZUDXvL4= b612.me/stario v0.1.1 h1:WIQy5DdK2Tkk+PIRORaVb76f4KY+64UvWChXNI7hSVY=
b612.me/stario v0.1.0/go.mod h1:7kjE69oFqNrca0P72L5+ZbTV09QGJ2N3bBY3qeFXOGc= b612.me/stario v0.1.1/go.mod h1:qMMqjaMhKmdfFn5T0oleL5L4FpFWEHsuIrT878hyo7I=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+24 -9
View File
@@ -404,11 +404,13 @@ func (c *LogicalConn) applyAttachmentProfile(maxReadTimeout time.Duration, maxWr
c.updateAttachmentState(func(state *clientConnAttachmentState) { c.updateAttachmentState(func(state *clientConnAttachmentState) {
state.maxReadTimeout = maxReadTimeout state.maxReadTimeout = maxReadTimeout
state.maxWriteTimeout = maxWriteTimeout state.maxWriteTimeout = maxWriteTimeout
state.protectionMode = ProtectionManaged
state.msgEn = msgEn state.msgEn = msgEn
state.msgDe = msgDe state.msgDe = msgDe
state.fastStreamEncode = fastStreamEncode state.fastStreamEncode = fastStreamEncode
state.fastBulkEncode = fastBulkEncode state.fastBulkEncode = fastBulkEncode
state.fastPlainEncode = fastPlainEncode state.fastPlainEncode = fastPlainEncode
state.modernPSKRuntime = nil
state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey) state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey)
state.secretKey = cloneClientConnAttachmentBytes(secretKey) state.secretKey = cloneClientConnAttachmentBytes(secretKey)
}) })
@@ -524,17 +526,23 @@ func (c *LogicalConn) runtimeSnapshot() ClientConnRuntimeSnapshot {
return ClientConnRuntimeSnapshot{} return ClientConnRuntimeSnapshot{}
} }
status := c.Status() status := c.Status()
authenticated, fallback, attachAt := c.peerAttachAuthenticatedSnapshot()
now := time.Now() now := time.Now()
snapshot := ClientConnRuntimeSnapshot{ snapshot := ClientConnRuntimeSnapshot{
ClientID: c.clientIDSnapshot(), ClientID: c.clientIDSnapshot(),
Alive: status.Alive, Alive: status.Alive,
Reason: status.Reason, Reason: status.Reason,
IdentityBound: c.clientConnIdentityBoundSnapshot(), IdentityBound: c.clientConnIdentityBoundSnapshot(),
UsesStreamTransport: c.usesStreamTransportSnapshot(), UsesStreamTransport: c.usesStreamTransportSnapshot(),
TransportGeneration: c.transportGenerationSnapshot(), TransportGeneration: c.transportGenerationSnapshot(),
TransportAttachCount: c.clientConnTransportAttachCountSnapshot(), TransportAttachCount: c.clientConnTransportAttachCountSnapshot(),
TransportDetachCount: c.clientConnTransportDetachCountSnapshot(), TransportDetachCount: c.clientConnTransportDetachCountSnapshot(),
LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(), LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(),
AuthMode: authModeName(c.authModeSnapshot()),
ProtectionMode: protectionModeName(c.protectionModeSnapshot()),
PeerAttachAuthenticated: authenticated,
PeerAttachAuthFallback: fallback,
LastPeerAttachAt: attachAt,
} }
if status.Err != nil { if status.Err != nil {
snapshot.Error = status.Err.Error() snapshot.Error = status.Err.Error()
@@ -553,6 +561,12 @@ func (c *LogicalConn) runtimeSnapshot() ClientConnRuntimeSnapshot {
snapshot.HasRuntimeConn = c.transportSnapshot() != nil snapshot.HasRuntimeConn = c.transportSnapshot() != nil
snapshot.HasRuntimeStopCtx = rt.stopCtx != nil snapshot.HasRuntimeStopCtx = rt.stopCtx != nil
} }
if binding := c.transportBindingSnapshot(); binding != nil {
snapshot.TransportBulkAdaptiveSoftPayloadBytes = binding.bulkAdaptiveSoftPayloadBytesSnapshot()
snapshot.TransportStreamAdaptiveSoftPayloadBytes = binding.streamAdaptiveSoftPayloadBytesSnapshot()
snapshot.TransportStreamAdaptiveWaitThresholdBytes = binding.streamAdaptiveWaitThresholdBytesSnapshot()
snapshot.TransportStreamAdaptiveFlushDelay = binding.streamAdaptiveFlushDelaySnapshot()
}
if detach := c.transportDetachSnapshot(); detach != nil { if detach := c.transportDetachSnapshot(); detach != nil {
snapshot.TransportDetachReason = detach.Reason snapshot.TransportDetachReason = detach.Reason
snapshot.TransportDetachKind = classifyClientConnTransportDetachReason(detach.Reason) snapshot.TransportDetachKind = classifyClientConnTransportDetachReason(detach.Reason)
@@ -816,6 +830,7 @@ func (c *LogicalConn) applyClientConnAttachmentProfile(maxReadTimeout time.Durat
state.maxWriteTimeout = maxWriteTimeout state.maxWriteTimeout = maxWriteTimeout
state.msgEn = msgEn state.msgEn = msgEn
state.msgDe = msgDe state.msgDe = msgDe
state.modernPSKRuntime = nil
state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey) state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey)
state.secretKey = cloneClientConnAttachmentBytes(secretKey) state.secretKey = cloneClientConnAttachmentBytes(secretKey)
}) })
+14
View File
@@ -43,6 +43,20 @@ func TestHydrateServerMessagePeerFieldsFromLogicalConn(t *testing.T) {
} }
} }
func TestHydrateServerMessagePeerFieldsZeroValueDoesNotPanic(t *testing.T) {
message := hydrateServerMessagePeerFields(Message{})
if message.LogicalConn != nil {
t.Fatal("zero-value message should not hydrate logical conn")
}
if message.ClientConn != nil {
t.Fatal("zero-value message should not hydrate client conn")
}
if message.TransportConn != nil {
t.Fatal("zero-value message should not hydrate transport conn")
}
}
func TestServerPublishSendFileEventHydratesLogicalConnAlias(t *testing.T) { func TestServerPublishSendFileEventHydratesLogicalConnAlias(t *testing.T) {
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
left, right := net.Pipe() left, right := net.Pipe()
+33 -6
View File
@@ -37,9 +37,10 @@ type Message struct {
NetType NetType
LogicalConn *LogicalConn LogicalConn *LogicalConn
// Deprecated: ClientConn aliases LogicalConn for compatibility. // Deprecated: ClientConn aliases LogicalConn for compatibility.
ClientConn *ClientConn ClientConn *ClientConn
TransportConn *TransportConn TransportConn *TransportConn
ServerConn Client ServerConn Client
inboundTransportProfile *transportProtectionProfile
TransferMsg TransferMsg
Time time.Time Time time.Time
inboundConn net.Conn inboundConn net.Conn
@@ -58,7 +59,7 @@ type messageLogicalTransferSender interface {
} }
type messageInboundTransferSender interface { type messageInboundTransferSender interface {
sendTransferInbound(*LogicalConn, *TransportConn, net.Conn, TransferMsg) error sendTransferInbound(*LogicalConn, *TransportConn, net.Conn, *transportProtectionProfile, TransferMsg) error
} }
func (m *Message) Reply(value MsgVal) (err error) { func (m *Message) Reply(value MsgVal) (err error) {
@@ -86,7 +87,7 @@ func (m *Message) Reply(value MsgVal) (err error) {
if sender == nil { if sender == nil {
return transportDetachedErrorForPeer(logical, transport) return transportDetachedErrorForPeer(logical, transport)
} }
return sender.sendTransferInbound(logical, transport, m.inboundConn, reply) return sender.sendTransferInbound(logical, transport, m.inboundConn, messageInboundTransportProtectionSnapshot(m), reply)
} }
if transport != nil { if transport != nil {
_, err = transport.sendTransfer(reply) _, err = transport.sendTransfer(reply)
@@ -123,12 +124,19 @@ func hydrateServerMessagePeerFields(message Message) Message {
if message.LogicalConn == nil { if message.LogicalConn == nil {
message.LogicalConn = logicalConnFromClient(message.ClientConn) message.LogicalConn = logicalConnFromClient(message.ClientConn)
} }
if message.ClientConn == nil { if message.LogicalConn == nil && message.TransportConn != nil {
message.LogicalConn = message.TransportConn.logicalConnSnapshot()
}
if message.ClientConn == nil && message.LogicalConn != nil {
message.ClientConn = message.LogicalConn.compatClientConn() message.ClientConn = message.LogicalConn.compatClientConn()
} }
if message.TransportConn == nil && message.LogicalConn != nil { if message.TransportConn == nil && message.LogicalConn != nil {
message.TransportConn = message.LogicalConn.CurrentTransportConn() message.TransportConn = message.LogicalConn.CurrentTransportConn()
} }
if message.inboundConn != nil && message.inboundTransportProfile == nil && message.LogicalConn != nil {
profile := message.LogicalConn.transportProtectionProfileSnapshot()
message.inboundTransportProfile = &profile
}
return message return message
} }
@@ -155,3 +163,22 @@ func messageTransportConnSnapshot(message *Message) *TransportConn {
} }
return logical.CurrentTransportConn() return logical.CurrentTransportConn()
} }
func messageInboundTransportProtectionSnapshot(message *Message) *transportProtectionProfile {
if message == nil {
return nil
}
if message.inboundTransportProfile != nil {
return message.inboundTransportProfile
}
if message.inboundConn == nil {
return nil
}
logical := messageLogicalConnSnapshot(message)
if logical == nil {
return nil
}
profile := logical.transportProtectionProfileSnapshot()
message.inboundTransportProfile = &profile
return message.inboundTransportProfile
}
+517
View File
@@ -0,0 +1,517 @@
package notify
import (
"bytes"
"crypto/hmac"
cryptorand "crypto/rand"
"crypto/sha256"
"encoding/binary"
"errors"
"net"
"sync"
"sync/atomic"
"time"
)
const (
peerAttachFeatureExplicitAuth uint64 = 1 << iota
peerAttachFeatureChannelBinding
peerAttachFeatureForwardSecrecy
)
const (
peerAttachNonceSize = 16
peerAttachReplayTTL = 30 * time.Second
)
var (
errPeerAttachAuthInvalid = errors.New("peer attach auth invalid")
errPeerAttachReplayRejected = errors.New("peer attach replay rejected")
errPeerAttachReplayWindowFull = errors.New("peer attach replay window full")
errPeerAttachExplicitAuthRequired = errors.New("peer attach explicit auth required")
errPeerAttachChannelBindingRequired = errors.New("peer attach channel binding required")
errPeerAttachChannelBindingUnavailable = errors.New("peer attach channel binding unavailable")
errPeerAttachForwardSecrecyRequired = errors.New("peer attach forward secrecy required")
)
type peerAttachAuthResult struct {
explicit bool
fallback bool
clientNonce []byte
serverNonce []byte
channelBinding []byte
clientECDHEPublicKey []byte
}
type peerAttachReplayCache struct {
mu sync.Mutex
entries map[string]time.Time
rejects atomic.Int64
overflowRejects atomic.Int64
}
func newPeerAttachNonce() ([]byte, error) {
buf := make([]byte, peerAttachNonceSize)
if _, err := cryptorand.Read(buf); err != nil {
return nil, err
}
return buf, nil
}
func appendPeerAttachAuthBytes(dst []byte, data []byte) []byte {
dst = binary.BigEndian.AppendUint32(dst, uint32(len(data)))
return append(dst, data...)
}
func appendPeerAttachAuthString(dst []byte, value string) []byte {
return appendPeerAttachAuthBytes(dst, []byte(value))
}
func appendPeerAttachAuthBool(dst []byte, value bool) []byte {
if value {
return append(dst, 1)
}
return append(dst, 0)
}
func peerAttachRequestAuthPayload(req peerAttachRequest, channelBinding []byte) []byte {
buf := make([]byte, 0, 96+len(req.PeerID)+len(channelBinding))
buf = appendPeerAttachAuthString(buf, "notify/peer-attach/request-auth/v1")
buf = binary.BigEndian.AppendUint64(buf, req.Features)
buf = appendPeerAttachAuthString(buf, req.PeerID)
buf = appendPeerAttachAuthBytes(buf, req.ClientNonce)
if supportsPeerAttachChannelBinding(req.Features) {
buf = appendPeerAttachAuthBytes(buf, channelBinding)
}
return buf
}
func peerAttachResponseAuthPayload(req peerAttachRequest, resp peerAttachResponse, channelBinding []byte) []byte {
buf := make([]byte, 0, 160+len(req.PeerID)+len(resp.PeerID)+len(resp.Error)+len(channelBinding))
buf = appendPeerAttachAuthString(buf, "notify/peer-attach/response-auth/v1")
buf = binary.BigEndian.AppendUint64(buf, req.Features)
buf = appendPeerAttachAuthString(buf, req.PeerID)
buf = appendPeerAttachAuthBytes(buf, req.ClientNonce)
buf = binary.BigEndian.AppendUint64(buf, resp.Features)
buf = appendPeerAttachAuthString(buf, resp.PeerID)
buf = appendPeerAttachAuthBool(buf, resp.Accepted)
buf = appendPeerAttachAuthBool(buf, resp.Reused)
buf = appendPeerAttachAuthString(buf, resp.Error)
buf = appendPeerAttachAuthBytes(buf, resp.ServerNonce)
if supportsPeerAttachChannelBinding(resp.Features) {
buf = appendPeerAttachAuthBytes(buf, channelBinding)
}
return buf
}
func signPeerAttachPayload(secretKey []byte, payload []byte) []byte {
if len(secretKey) == 0 {
return nil
}
mac := hmac.New(sha256.New, secretKey)
_, _ = mac.Write(payload)
return mac.Sum(nil)
}
func computePeerAttachRequestAuthTag(secretKey []byte, req peerAttachRequest, channelBinding []byte) []byte {
return signPeerAttachPayload(secretKey, peerAttachRequestAuthPayload(req, channelBinding))
}
func computePeerAttachResponseAuthTag(secretKey []byte, req peerAttachRequest, resp peerAttachResponse, channelBinding []byte) []byte {
return signPeerAttachPayload(secretKey, peerAttachResponseAuthPayload(req, resp, channelBinding))
}
func supportsExplicitPeerAttachAuth(features uint64) bool {
return features&peerAttachFeatureExplicitAuth != 0
}
func supportsPeerAttachChannelBinding(features uint64) bool {
return features&peerAttachFeatureChannelBinding != 0
}
func supportsPeerAttachForwardSecrecy(features uint64) bool {
return features&peerAttachFeatureForwardSecrecy != 0
}
func classifyPeerAttachRejectCounter(s *ServerCommon, err error) {
if s == nil || err == nil {
return
}
switch {
case errors.Is(err, errPeerAttachReplayRejected):
s.peerAttachReplay.rejects.Add(1)
case errors.Is(err, errPeerAttachReplayWindowFull):
s.peerAttachReplay.overflowRejects.Add(1)
case errors.Is(err, errPeerAttachExplicitAuthRequired), errors.Is(err, errPeerAttachChannelBindingRequired), errors.Is(err, errPeerAttachForwardSecrecyRequired):
s.peerAttachDowngradeRejectCount.Add(1)
case errors.Is(err, errPeerAttachChannelBindingUnavailable):
s.peerAttachBindingRejectCount.Add(1)
default:
s.peerAttachAuthRejectCount.Add(1)
}
}
func (c *ClientCommon) shouldUseExplicitPeerAttachAuth() bool {
if c == nil || !c.securityConfigured {
return false
}
return c.securityAuthMode == AuthPSK && len(c.securityBootstrap.secretKey) != 0
}
func resolvePeerAttachChannelBinding(provider PeerAttachChannelBindingProvider, role PeerAttachChannelBindingRole, peerID string, conn net.Conn) ([]byte, error) {
if provider == nil {
return nil, nil
}
if conn == nil {
return nil, errPeerAttachChannelBindingUnavailable
}
binding, err := provider(PeerAttachChannelBindingContext{
Role: role,
PeerID: peerID,
Conn: conn,
})
if err != nil || len(binding) == 0 {
return nil, errPeerAttachChannelBindingUnavailable
}
return bytes.Clone(binding), nil
}
func (c *ClientCommon) buildPeerAttachRequest(peerID string) (peerAttachRequest, peerAttachRequestState, error) {
cfg := c.peerAttachSecuritySnapshot()
req := peerAttachRequest{
PeerID: stringsTrimSpaceNoAlloc(peerID),
}
if !c.shouldUseExplicitPeerAttachAuth() {
if c.clientRequiresForwardSecrecy() {
return peerAttachRequest{}, peerAttachRequestState{}, errPeerAttachForwardSecrecyRequired
}
if cfg.requireExplicitAuth {
return peerAttachRequest{}, peerAttachRequestState{}, errPeerAttachExplicitAuthRequired
}
return req, peerAttachRequestState{}, nil
}
nonce, err := newPeerAttachNonce()
if err != nil {
return peerAttachRequest{}, peerAttachRequestState{}, err
}
req.Features = peerAttachFeatureExplicitAuth
requestState := peerAttachRequestState{}
var channelBinding []byte
if cfg.channelBinding != nil {
channelBinding, err = resolvePeerAttachChannelBinding(cfg.channelBinding, PeerAttachChannelBindingRoleClient, req.PeerID, c.clientTransportConnSnapshot())
if err != nil {
return peerAttachRequest{}, peerAttachRequestState{}, err
}
}
if len(channelBinding) != 0 {
req.Features |= peerAttachFeatureChannelBinding
}
if cfg.requireChannelBinding && !supportsPeerAttachChannelBinding(req.Features) {
return peerAttachRequest{}, peerAttachRequestState{}, errPeerAttachChannelBindingRequired
}
if c.clientSupportsForwardSecrecy() {
requestState.forwardSecrecy, err = newPeerAttachForwardSecrecyClientState()
if err != nil {
return peerAttachRequest{}, peerAttachRequestState{}, err
}
req.Features |= peerAttachFeatureForwardSecrecy
req.ClientECDHEPublicKey = bytes.Clone(requestState.forwardSecrecy.publicKey)
}
req.ClientNonce = nonce
req.AuthTag = computePeerAttachRequestAuthTag(c.securityBootstrap.secretKey, req, channelBinding)
return req, requestState, nil
}
func (c *ClientCommon) verifyPeerAttachResponse(req peerAttachRequest, resp peerAttachResponse, requestState peerAttachRequestState) (peerAttachResponseVerifyResult, error) {
cfg := c.peerAttachSecuritySnapshot()
baseSteady := transportProtectionProfile{}
if c != nil {
baseSteady = c.securitySteady.clone().withForwardSecrecyFallback(false)
}
result := peerAttachResponseVerifyResult{steadyProfile: baseSteady}
if c == nil || !c.shouldUseExplicitPeerAttachAuth() {
if c.clientRequiresForwardSecrecy() {
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyRequired
}
if cfg.requireExplicitAuth {
return peerAttachResponseVerifyResult{}, errPeerAttachExplicitAuthRequired
}
return result, nil
}
if !supportsExplicitPeerAttachAuth(resp.Features) {
if c.clientRequiresForwardSecrecy() {
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyRequired
}
if cfg.requireExplicitAuth {
return peerAttachResponseVerifyResult{}, errPeerAttachExplicitAuthRequired
}
if c.clientSupportsForwardSecrecy() {
result.steadyProfile = result.steadyProfile.withForwardSecrecyFallback(true)
}
result.authFallback = true
return result, nil
}
var channelBinding []byte
if supportsPeerAttachChannelBinding(req.Features) {
if cfg.channelBinding == nil {
return peerAttachResponseVerifyResult{}, errPeerAttachChannelBindingUnavailable
}
if !supportsPeerAttachChannelBinding(resp.Features) {
return peerAttachResponseVerifyResult{}, errPeerAttachChannelBindingRequired
}
var err error
channelBinding, err = resolvePeerAttachChannelBinding(cfg.channelBinding, PeerAttachChannelBindingRoleClient, req.PeerID, c.clientTransportConnSnapshot())
if err != nil {
return peerAttachResponseVerifyResult{}, err
}
} else if cfg.requireChannelBinding {
return peerAttachResponseVerifyResult{}, errPeerAttachChannelBindingRequired
}
if len(resp.ServerNonce) != peerAttachNonceSize {
return peerAttachResponseVerifyResult{}, errPeerAttachAuthInvalid
}
expected := computePeerAttachResponseAuthTag(c.securityBootstrap.secretKey, req, peerAttachResponse{
PeerID: resp.PeerID,
Accepted: resp.Accepted,
Reused: resp.Reused,
Error: resp.Error,
Features: resp.Features,
ServerNonce: resp.ServerNonce,
}, channelBinding)
if !hmac.Equal(resp.AuthTag, expected) {
return peerAttachResponseVerifyResult{}, errPeerAttachAuthInvalid
}
if requestState.forwardSecrecy == nil || !supportsPeerAttachForwardSecrecy(req.Features) {
return result, nil
}
if !supportsPeerAttachForwardSecrecy(resp.Features) {
if c.clientRequiresForwardSecrecy() {
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyRequired
}
result.steadyProfile = result.steadyProfile.withForwardSecrecyFallback(true)
return result, nil
}
if resp.KeyMode != "" && resp.KeyMode != peerAttachKeyModeECDHE {
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyInvalid
}
profile, err := derivePeerAttachForwardSecrecyTransportProfile(c.securitySteady, c.securityBootstrap.secretKey, requestState.forwardSecrecy.privateKey, resp.ServerECDHEPublicKey, req, resp)
if err != nil {
return peerAttachResponseVerifyResult{}, err
}
result.steadyProfile = profile
return result, nil
}
func (s *ServerCommon) validatePeerAttachRequestAuth(logical *LogicalConn, transport net.Conn, req peerAttachRequest) (peerAttachAuthResult, error) {
cfg := s.peerAttachSecuritySnapshot()
if !supportsExplicitPeerAttachAuth(req.Features) {
if s.serverRequiresForwardSecrecy() {
return peerAttachAuthResult{}, errPeerAttachForwardSecrecyRequired
}
if cfg.requireExplicitAuth {
return peerAttachAuthResult{}, errPeerAttachExplicitAuthRequired
}
return peerAttachAuthResult{fallback: true}, nil
}
if logical == nil {
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
}
var channelBinding []byte
if supportsPeerAttachChannelBinding(req.Features) {
if cfg.channelBinding == nil {
return peerAttachAuthResult{}, errPeerAttachChannelBindingUnavailable
}
var err error
channelBinding, err = resolvePeerAttachChannelBinding(cfg.channelBinding, PeerAttachChannelBindingRoleServer, req.PeerID, transport)
if err != nil {
return peerAttachAuthResult{}, err
}
} else if cfg.requireChannelBinding {
return peerAttachAuthResult{}, errPeerAttachChannelBindingRequired
}
secretKey := logical.secretKeySnapshot()
if len(secretKey) == 0 {
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
}
if len(req.ClientNonce) != peerAttachNonceSize {
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
}
expected := computePeerAttachRequestAuthTag(secretKey, peerAttachRequest{
PeerID: req.PeerID,
Features: req.Features,
ClientNonce: req.ClientNonce,
}, channelBinding)
if !hmac.Equal(req.AuthTag, expected) {
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
}
if supportsPeerAttachForwardSecrecy(req.Features) {
if len(req.ClientECDHEPublicKey) != peerAttachECDHEPublicKeySize {
return peerAttachAuthResult{}, errPeerAttachForwardSecrecyInvalid
}
}
if err := s.acceptPeerAttachReplay(req.PeerID, req.ClientNonce, time.Now(), cfg.replayWindow, cfg.replayCapacity); err != nil {
return peerAttachAuthResult{}, err
}
serverNonce, err := newPeerAttachNonce()
if err != nil {
return peerAttachAuthResult{}, err
}
return peerAttachAuthResult{
explicit: true,
clientNonce: bytes.Clone(req.ClientNonce),
serverNonce: serverNonce,
channelBinding: channelBinding,
clientECDHEPublicKey: bytes.Clone(req.ClientECDHEPublicKey),
}, nil
}
func (s *ServerCommon) signPeerAttachResponse(logical *LogicalConn, req peerAttachRequest, resp *peerAttachResponse, auth peerAttachAuthResult) {
if s == nil || logical == nil || resp == nil || !auth.explicit {
return
}
secretKey := logical.secretKeySnapshot()
if len(secretKey) == 0 {
return
}
resp.Features |= peerAttachFeatureExplicitAuth
if len(auth.channelBinding) != 0 {
resp.Features |= peerAttachFeatureChannelBinding
}
resp.ServerNonce = bytes.Clone(auth.serverNonce)
resp.AuthTag = computePeerAttachResponseAuthTag(secretKey, req, *resp, auth.channelBinding)
}
func (s *ServerCommon) preparePeerAttachSteadyTransportProfile(logical *LogicalConn, req peerAttachRequest, resp *peerAttachResponse, auth peerAttachAuthResult) (transportProtectionProfile, error) {
if s == nil {
return transportProtectionProfile{}, nil
}
profile := s.securitySteady.clone().withForwardSecrecyFallback(false)
if resp != nil && auth.explicit {
resp.Features |= peerAttachFeatureExplicitAuth
if len(auth.channelBinding) != 0 {
resp.Features |= peerAttachFeatureChannelBinding
}
resp.ServerNonce = bytes.Clone(auth.serverNonce)
}
if resp != nil && resp.KeyMode == "" {
resp.KeyMode = profile.keyMode
}
if !s.serverSupportsForwardSecrecy() {
return profile, nil
}
if !auth.explicit || !supportsPeerAttachForwardSecrecy(req.Features) {
if s.serverRequiresForwardSecrecy() {
return transportProtectionProfile{}, errPeerAttachForwardSecrecyRequired
}
return profile.withForwardSecrecyFallback(true), nil
}
fsState, err := newPeerAttachForwardSecrecyClientState()
if err != nil {
return transportProtectionProfile{}, err
}
if resp != nil {
resp.Features |= peerAttachFeatureForwardSecrecy
resp.KeyMode = peerAttachKeyModeECDHE
resp.ServerECDHEPublicKey = bytes.Clone(fsState.publicKey)
}
return derivePeerAttachForwardSecrecyTransportProfile(s.securitySteady, logical.secretKeySnapshot(), fsState.privateKey, auth.clientECDHEPublicKey, req, *resp)
}
func (s *ServerCommon) acceptPeerAttachReplay(peerID string, nonce []byte, now time.Time, window time.Duration, capacity int) error {
if s == nil || len(nonce) == 0 {
return nil
}
cache := &s.peerAttachReplay
key := peerID + "\x00" + string(nonce)
expireBefore := now.Add(-window)
cache.mu.Lock()
defer cache.mu.Unlock()
if cache.entries == nil {
cache.entries = make(map[string]time.Time)
}
for replayKey, seenAt := range cache.entries {
if seenAt.Before(expireBefore) {
delete(cache.entries, replayKey)
}
}
if seenAt, ok := cache.entries[key]; ok && !seenAt.Before(expireBefore) {
return errPeerAttachReplayRejected
}
if capacity > 0 && len(cache.entries) >= capacity {
return errPeerAttachReplayWindowFull
}
cache.entries[key] = now
return nil
}
func (s *ServerCommon) peerAttachReplayRejectCountSnapshot() int64 {
if s == nil {
return 0
}
return s.peerAttachReplay.rejects.Load()
}
func (s *ServerCommon) peerAttachReplayOverflowRejectCountSnapshot() int64 {
if s == nil {
return 0
}
return s.peerAttachReplay.overflowRejects.Load()
}
func (c *ClientCommon) markClientPeerAttachAuthenticated(fallback bool, at time.Time) {
if c == nil {
return
}
c.mu.Lock()
defer c.mu.Unlock()
c.peerAttachAuthenticated = true
c.peerAttachAuthFallback = fallback
c.peerAttachAt = at.UnixNano()
}
func (c *ClientCommon) resetClientPeerAttachAuth() {
if c == nil {
return
}
c.mu.Lock()
defer c.mu.Unlock()
c.peerAttachAuthenticated = false
c.peerAttachAuthFallback = false
c.peerAttachAt = 0
}
func (c *ClientCommon) clientPeerAttachAuthSnapshot() (bool, bool, time.Time) {
if c == nil {
return false, false, time.Time{}
}
c.mu.Lock()
defer c.mu.Unlock()
if c.peerAttachAt == 0 {
return c.peerAttachAuthenticated, c.peerAttachAuthFallback, time.Time{}
}
return c.peerAttachAuthenticated, c.peerAttachAuthFallback, time.Unix(0, c.peerAttachAt)
}
func stringsTrimSpaceNoAlloc(value string) string {
start := 0
for start < len(value) {
switch value[start] {
case ' ', '\t', '\n', '\r':
start++
default:
goto endStart
}
}
return ""
endStart:
end := len(value)
for end > start {
switch value[end-1] {
case ' ', '\t', '\n', '\r':
end--
default:
return value[start:end]
}
}
return value[start:end]
}
+237
View File
@@ -0,0 +1,237 @@
package notify
import (
"bytes"
"errors"
"testing"
)
func newPeerAttachAuthLogicalForTest(t *testing.T, server *ServerCommon) *LogicalConn {
t.Helper()
logical := newServerLogicalConn(server, "accepted-auth", nil)
logical = server.registerAcceptedLogical(logical)
if logical == nil {
t.Fatal("registerAcceptedLogical returned nil")
}
return logical
}
func TestPeerAttachExplicitAuthHelpersRoundTrip(t *testing.T) {
secret := []byte("correct horse battery staple")
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
logical := newPeerAttachAuthLogicalForTest(t, server)
req, reqState, err := client.buildPeerAttachRequest("peer-explicit")
if err != nil {
t.Fatalf("buildPeerAttachRequest failed: %v", err)
}
if !supportsExplicitPeerAttachAuth(req.Features) {
t.Fatalf("request features = %d, want explicit auth bit", req.Features)
}
if !supportsPeerAttachForwardSecrecy(req.Features) {
t.Fatalf("request features = %d, want forward secrecy bit", req.Features)
}
if len(req.ClientNonce) != peerAttachNonceSize {
t.Fatalf("client nonce length = %d, want %d", len(req.ClientNonce), peerAttachNonceSize)
}
auth, err := server.validatePeerAttachRequestAuth(logical, nil, req)
if err != nil {
t.Fatalf("validatePeerAttachRequestAuth failed: %v", err)
}
if !auth.explicit || auth.fallback {
t.Fatalf("auth result mismatch: %+v", auth)
}
resp := peerAttachResponse{
PeerID: req.PeerID,
Accepted: true,
}
server.signPeerAttachResponse(logical, req, &resp, auth)
if !supportsExplicitPeerAttachAuth(resp.Features) {
t.Fatalf("response features = %d, want explicit auth bit", resp.Features)
}
if len(resp.ServerNonce) != peerAttachNonceSize {
t.Fatalf("server nonce length = %d, want %d", len(resp.ServerNonce), peerAttachNonceSize)
}
verifyResult, err := client.verifyPeerAttachResponse(req, resp, reqState)
if err != nil {
t.Fatalf("verifyPeerAttachResponse failed: %v", err)
}
if verifyResult.authFallback {
t.Fatal("explicit response should not be marked as fallback")
}
if !verifyResult.steadyProfile.forwardSecrecyFallback {
t.Fatal("response without fs extension should mark forward secrecy fallback")
}
resp.AuthTag[0] ^= 0xff
if verifyResult, err = client.verifyPeerAttachResponse(req, resp, reqState); !errors.Is(err, errPeerAttachAuthInvalid) {
t.Fatalf("tampered response error = %v, want %v", err, errPeerAttachAuthInvalid)
} else if verifyResult.authFallback {
t.Fatal("tampered explicit response should not be treated as fallback")
}
}
func TestPeerAttachRequestAuthRejectsReplay(t *testing.T) {
secret := []byte("correct horse battery staple")
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
logical := newPeerAttachAuthLogicalForTest(t, server)
req, _, err := client.buildPeerAttachRequest("peer-replay")
if err != nil {
t.Fatalf("buildPeerAttachRequest failed: %v", err)
}
if _, err := server.validatePeerAttachRequestAuth(logical, nil, req); err != nil {
t.Fatalf("first validatePeerAttachRequestAuth failed: %v", err)
}
if _, err := server.validatePeerAttachRequestAuth(logical, nil, req); !errors.Is(err, errPeerAttachReplayRejected) {
t.Fatalf("second validatePeerAttachRequestAuth error = %v, want %v", err, errPeerAttachReplayRejected)
}
classifyPeerAttachRejectCounter(server, errPeerAttachReplayRejected)
if got, want := server.peerAttachReplayRejectCountSnapshot(), int64(1); got != want {
t.Fatalf("replay reject count = %d, want %d", got, want)
}
}
func TestPeerAttachAuthFallbackCompatibility(t *testing.T) {
secret := []byte("correct horse battery staple")
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
logical := newPeerAttachAuthLogicalForTest(t, server)
auth, err := server.validatePeerAttachRequestAuth(logical, nil, peerAttachRequest{PeerID: "peer-fallback"})
if err != nil {
t.Fatalf("validatePeerAttachRequestAuth fallback failed: %v", err)
}
if auth.explicit || !auth.fallback {
t.Fatalf("fallback auth result mismatch: %+v", auth)
}
req, reqState, err := client.buildPeerAttachRequest("peer-fallback")
if err != nil {
t.Fatalf("buildPeerAttachRequest failed: %v", err)
}
verifyResult, err := client.verifyPeerAttachResponse(req, peerAttachResponse{
PeerID: req.PeerID,
Accepted: true,
}, reqState)
if err != nil {
t.Fatalf("verifyPeerAttachResponse fallback failed: %v", err)
}
if !verifyResult.authFallback {
t.Fatal("unsigned legacy response should be marked as fallback")
}
}
func TestPeerAttachForwardSecrecyNegotiatesDerivedSteadyProfile(t *testing.T) {
secret := []byte("correct horse battery staple")
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
logical := newPeerAttachAuthLogicalForTest(t, server)
req, reqState, err := client.buildPeerAttachRequest("peer-fs")
if err != nil {
t.Fatalf("buildPeerAttachRequest failed: %v", err)
}
auth, err := server.validatePeerAttachRequestAuth(logical, nil, req)
if err != nil {
t.Fatalf("validatePeerAttachRequestAuth failed: %v", err)
}
resp := peerAttachResponse{
PeerID: req.PeerID,
Accepted: true,
}
serverProfile, err := server.preparePeerAttachSteadyTransportProfile(logical, req, &resp, auth)
if err != nil {
t.Fatalf("preparePeerAttachSteadyTransportProfile failed: %v", err)
}
server.signPeerAttachResponse(logical, req, &resp, auth)
verifyResult, err := client.verifyPeerAttachResponse(req, resp, reqState)
if err != nil {
t.Fatalf("verifyPeerAttachResponse failed: %v", err)
}
if !supportsPeerAttachForwardSecrecy(resp.Features) {
t.Fatalf("response features = %d, want forward secrecy bit", resp.Features)
}
if resp.KeyMode != peerAttachKeyModeECDHE {
t.Fatalf("response key mode = %q, want %q", resp.KeyMode, peerAttachKeyModeECDHE)
}
if !verifyResult.steadyProfile.forwardSecrecy {
t.Fatal("client steady profile should enable forward secrecy")
}
if !serverProfile.forwardSecrecy {
t.Fatal("server steady profile should enable forward secrecy")
}
if len(verifyResult.steadyProfile.sessionID) == 0 {
t.Fatal("client session id should be populated")
}
if !bytes.Equal(verifyResult.steadyProfile.secretKey, serverProfile.secretKey) {
t.Fatal("client/server derived steady keys should match")
}
if !bytes.Equal(verifyResult.steadyProfile.sessionID, serverProfile.sessionID) {
t.Fatal("client/server session ids should match")
}
if bytes.Equal(verifyResult.steadyProfile.secretKey, client.securityBootstrap.secretKey) {
t.Fatal("derived steady key should differ from bootstrap key")
}
}
func TestPeerAttachForwardSecrecyStrictRejectsFallback(t *testing.T) {
secret := []byte("correct horse battery staple")
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
opts := testModernPSKOptions()
opts.RequireForwardSecrecy = true
if err := UseModernPSKClient(client, secret, opts); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := UseModernPSKServer(server, secret, opts); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
req, reqState, err := client.buildPeerAttachRequest("peer-fs-strict")
if err != nil {
t.Fatalf("buildPeerAttachRequest failed: %v", err)
}
_, err = client.verifyPeerAttachResponse(req, peerAttachResponse{
PeerID: req.PeerID,
Accepted: true,
Features: peerAttachFeatureExplicitAuth,
ServerNonce: make([]byte, peerAttachNonceSize),
AuthTag: computePeerAttachResponseAuthTag(client.securityBootstrap.secretKey, req, peerAttachResponse{PeerID: req.PeerID, Accepted: true, Features: peerAttachFeatureExplicitAuth, ServerNonce: make([]byte, peerAttachNonceSize)}, nil),
}, reqState)
if !errors.Is(err, errPeerAttachForwardSecrecyRequired) {
t.Fatalf("verifyPeerAttachResponse error = %v, want %v", err, errPeerAttachForwardSecrecyRequired)
}
}
+139
View File
@@ -0,0 +1,139 @@
package notify
import (
"errors"
"net"
"time"
)
const defaultPeerAttachReplayCapacity = 4096
type PeerAttachChannelBindingRole string
const (
PeerAttachChannelBindingRoleClient PeerAttachChannelBindingRole = "client"
PeerAttachChannelBindingRoleServer PeerAttachChannelBindingRole = "server"
)
type PeerAttachChannelBindingContext struct {
Role PeerAttachChannelBindingRole
PeerID string
Conn net.Conn
}
type PeerAttachChannelBindingProvider func(PeerAttachChannelBindingContext) ([]byte, error)
type PeerAttachSecurityConfig struct {
RequireExplicitAuth bool
RequireChannelBinding bool
ReplayWindow time.Duration
ReplayCapacity int
ChannelBinding PeerAttachChannelBindingProvider
}
type peerAttachSecurityState struct {
requireExplicitAuth bool
requireChannelBinding bool
replayWindow time.Duration
replayCapacity int
channelBinding PeerAttachChannelBindingProvider
}
var errPeerAttachChannelBindingProviderNil = errors.New("peer attach channel binding provider is nil")
func DefaultPeerAttachSecurityConfig() PeerAttachSecurityConfig {
return PeerAttachSecurityConfig{
ReplayWindow: peerAttachReplayTTL,
ReplayCapacity: defaultPeerAttachReplayCapacity,
}
}
func normalizePeerAttachSecurityConfig(cfg PeerAttachSecurityConfig) (peerAttachSecurityState, error) {
if cfg.ReplayWindow <= 0 {
cfg.ReplayWindow = peerAttachReplayTTL
}
if cfg.ReplayCapacity <= 0 {
cfg.ReplayCapacity = defaultPeerAttachReplayCapacity
}
if cfg.RequireChannelBinding {
cfg.RequireExplicitAuth = true
if cfg.ChannelBinding == nil {
return peerAttachSecurityState{}, errPeerAttachChannelBindingProviderNil
}
}
return peerAttachSecurityState{
requireExplicitAuth: cfg.RequireExplicitAuth,
requireChannelBinding: cfg.RequireChannelBinding,
replayWindow: cfg.ReplayWindow,
replayCapacity: cfg.ReplayCapacity,
channelBinding: cfg.ChannelBinding,
}, nil
}
func peerAttachSecurityConfigFromState(state *peerAttachSecurityState) PeerAttachSecurityConfig {
if state == nil {
return DefaultPeerAttachSecurityConfig()
}
return PeerAttachSecurityConfig{
RequireExplicitAuth: state.requireExplicitAuth,
RequireChannelBinding: state.requireChannelBinding,
ReplayWindow: state.replayWindow,
ReplayCapacity: state.replayCapacity,
ChannelBinding: state.channelBinding,
}
}
func defaultPeerAttachSecurityState() *peerAttachSecurityState {
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
return &cfg
}
func (c *ClientCommon) peerAttachSecuritySnapshot() peerAttachSecurityState {
if c == nil {
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
return cfg
}
if state := c.peerAttachSecurity.Load(); state != nil {
return *state
}
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
return cfg
}
func (s *ServerCommon) peerAttachSecuritySnapshot() peerAttachSecurityState {
if s == nil {
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
return cfg
}
if state := s.peerAttachSecurity.Load(); state != nil {
return *state
}
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
return cfg
}
func (c *ClientCommon) SetPeerAttachSecurityConfig(cfg PeerAttachSecurityConfig) error {
state, err := normalizePeerAttachSecurityConfig(cfg)
if err != nil {
return err
}
c.peerAttachSecurity.Store(&state)
return nil
}
func (c *ClientCommon) PeerAttachSecurityConfig() PeerAttachSecurityConfig {
return peerAttachSecurityConfigFromState(c.peerAttachSecurity.Load())
}
func (s *ServerCommon) SetPeerAttachSecurityConfig(cfg PeerAttachSecurityConfig) error {
state, err := normalizePeerAttachSecurityConfig(cfg)
if err != nil {
return err
}
s.peerAttachSecurity.Store(&state)
return nil
}
func (s *ServerCommon) PeerAttachSecurityConfig() PeerAttachSecurityConfig {
return peerAttachSecurityConfigFromState(s.peerAttachSecurity.Load())
}
+221
View File
@@ -0,0 +1,221 @@
package notify
import (
"bytes"
"errors"
"net"
"testing"
"time"
)
func staticPeerAttachChannelBindingProvider(material []byte) PeerAttachChannelBindingProvider {
cloned := bytes.Clone(material)
return func(PeerAttachChannelBindingContext) ([]byte, error) {
return bytes.Clone(cloned), nil
}
}
func failingPeerAttachChannelBindingProvider(PeerAttachChannelBindingContext) ([]byte, error) {
return nil, errors.New("binding unavailable")
}
func TestSetPeerAttachSecurityConfigRejectsMissingChannelBindingProvider(t *testing.T) {
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
cfg := PeerAttachSecurityConfig{RequireChannelBinding: true}
if err := client.SetPeerAttachSecurityConfig(cfg); !errors.Is(err, errPeerAttachChannelBindingProviderNil) {
t.Fatalf("client SetPeerAttachSecurityConfig error = %v, want %v", err, errPeerAttachChannelBindingProviderNil)
}
if err := server.SetPeerAttachSecurityConfig(cfg); !errors.Is(err, errPeerAttachChannelBindingProviderNil) {
t.Fatalf("server SetPeerAttachSecurityConfig error = %v, want %v", err, errPeerAttachChannelBindingProviderNil)
}
}
func TestPeerAttachRequireExplicitAuthRejectsFallbackClient(t *testing.T) {
secret := []byte("correct horse battery staple")
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
RequireExplicitAuth: true,
}); err != nil {
t.Fatalf("SetPeerAttachSecurityConfig failed: %v", err)
}
logical := newPeerAttachAuthLogicalForTest(t, server)
if _, err := server.validatePeerAttachRequestAuth(logical, nil, peerAttachRequest{PeerID: "peer-fallback"}); !errors.Is(err, errPeerAttachExplicitAuthRequired) {
t.Fatalf("validatePeerAttachRequestAuth error = %v, want %v", err, errPeerAttachExplicitAuthRequired)
}
classifyPeerAttachRejectCounter(server, errPeerAttachExplicitAuthRequired)
snapshot, snapErr := GetServerRuntimeSnapshot(server)
if snapErr != nil {
t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr)
}
if got, want := snapshot.PeerAttachDowngradeRejects, int64(1); got != want {
t.Fatalf("PeerAttachDowngradeRejects = %d, want %d", got, want)
}
if got := snapshot.PeerAttachAuthFallbacks; got != 0 {
t.Fatalf("PeerAttachAuthFallbacks = %d, want 0", got)
}
}
func TestPeerAttachChannelBindingRoundTrip(t *testing.T) {
secret := []byte("correct horse battery staple")
bindingProvider := staticPeerAttachChannelBindingProvider([]byte("tls-exporter:test"))
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
if err := UsePSKOverExternalTransportServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
}
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
RequireChannelBinding: true,
ChannelBinding: bindingProvider,
}); err != nil {
t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err)
}
server.SetLink("echo", func(msg *Message) {
_ = msg.Reply([]byte("ack"))
})
})
client := NewClient().(*ClientCommon)
if err := UsePSKOverExternalTransportClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
}
if err := client.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
RequireChannelBinding: true,
ChannelBinding: bindingProvider,
}); err != nil {
t.Fatalf("client SetPeerAttachSecurityConfig failed: %v", err)
}
left, right := net.Pipe()
defer right.Close()
bootstrapPeerAttachLogicalForTest(t, server, right)
if err := client.ConnectByConn(left); err != nil {
t.Fatalf("ConnectByConn failed: %v", err)
}
defer func() {
client.setByeFromServer(true)
_ = client.Stop()
}()
reply, err := client.SendWait("echo", []byte("ping"), time.Second)
if err != nil {
t.Fatalf("SendWait failed: %v", err)
}
if got, want := string(reply.Value), "ack"; got != want {
t.Fatalf("reply = %q, want %q", got, want)
}
serverSnapshot, err := GetServerRuntimeSnapshot(server)
if err != nil {
t.Fatalf("GetServerRuntimeSnapshot failed: %v", err)
}
if !serverSnapshot.PeerAttachRequireExplicitAuth || !serverSnapshot.PeerAttachRequireChannelBinding || !serverSnapshot.PeerAttachChannelBindingConfigured {
t.Fatalf("unexpected server peer attach policy snapshot: %+v", serverSnapshot)
}
if got, want := serverSnapshot.PeerAttachExplicitAuth, int64(1); got != want {
t.Fatalf("PeerAttachExplicitAuth = %d, want %d", got, want)
}
if serverSnapshot.PeerAttachAuthRejects != 0 || serverSnapshot.PeerAttachDowngradeRejects != 0 || serverSnapshot.PeerAttachBindingRejects != 0 {
t.Fatalf("unexpected server reject counters: %+v", serverSnapshot)
}
clientSnapshot, err := GetClientRuntimeSnapshot(client)
if err != nil {
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
}
if !clientSnapshot.PeerAttachRequireExplicitAuth || !clientSnapshot.PeerAttachRequireChannelBinding || !clientSnapshot.PeerAttachChannelBindingConfigured {
t.Fatalf("unexpected client peer attach policy snapshot: %+v", clientSnapshot)
}
}
func TestPeerAttachChannelBindingProviderFailureRejectsAttach(t *testing.T) {
secret := []byte("correct horse battery staple")
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
RequireChannelBinding: true,
ChannelBinding: failingPeerAttachChannelBindingProvider,
}); err != nil {
t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err)
}
})
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
RequireChannelBinding: true,
ChannelBinding: staticPeerAttachChannelBindingProvider([]byte("binding")),
}); err != nil {
t.Fatalf("client SetPeerAttachSecurityConfig failed: %v", err)
}
left, right := net.Pipe()
defer right.Close()
bootstrapPeerAttachLogicalForTest(t, server, right)
err := client.ConnectByConn(left)
if !errors.Is(err, errPeerAttachChannelBindingUnavailable) && (err == nil || err.Error() != errPeerAttachChannelBindingUnavailable.Error()) {
t.Fatalf("ConnectByConn error = %v, want %v", err, errPeerAttachChannelBindingUnavailable)
}
snapshot, snapErr := GetServerRuntimeSnapshot(server)
if snapErr != nil {
t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr)
}
if got, want := snapshot.PeerAttachBindingRejects, int64(1); got != want {
t.Fatalf("PeerAttachBindingRejects = %d, want %d", got, want)
}
}
func TestPeerAttachReplayCapacityRejectsOverflow(t *testing.T) {
secret := []byte("correct horse battery staple")
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
ReplayCapacity: 1,
}); err != nil {
t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err)
}
logical := newPeerAttachAuthLogicalForTest(t, server)
first, _, err := client.buildPeerAttachRequest("peer-one")
if err != nil {
t.Fatalf("buildPeerAttachRequest(first) failed: %v", err)
}
second, _, err := client.buildPeerAttachRequest("peer-two")
if err != nil {
t.Fatalf("buildPeerAttachRequest(second) failed: %v", err)
}
if _, err := server.validatePeerAttachRequestAuth(logical, nil, first); err != nil {
t.Fatalf("validatePeerAttachRequestAuth(first) failed: %v", err)
}
if _, err := server.validatePeerAttachRequestAuth(logical, nil, second); !errors.Is(err, errPeerAttachReplayWindowFull) {
t.Fatalf("validatePeerAttachRequestAuth(second) error = %v, want %v", err, errPeerAttachReplayWindowFull)
}
classifyPeerAttachRejectCounter(server, errPeerAttachReplayWindowFull)
snapshot, snapErr := GetServerRuntimeSnapshot(server)
if snapErr != nil {
t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr)
}
if got, want := snapshot.PeerAttachReplayCapacity, 1; got != want {
t.Fatalf("PeerAttachReplayCapacity = %d, want %d", got, want)
}
if got, want := snapshot.PeerAttachReplayOverflowRejects, int64(1); got != want {
t.Fatalf("PeerAttachReplayOverflowRejects = %d, want %d", got, want)
}
}
+69 -9
View File
@@ -15,14 +15,23 @@ const (
) )
type peerAttachRequest struct { type peerAttachRequest struct {
PeerID string PeerID string
Features uint64
ClientNonce []byte
ClientECDHEPublicKey []byte
AuthTag []byte
} }
type peerAttachResponse struct { type peerAttachResponse struct {
PeerID string PeerID string
Accepted bool Accepted bool
Reused bool Reused bool
Error string Error string
Features uint64
KeyMode string
ServerNonce []byte
ServerECDHEPublicKey []byte
AuthTag []byte
} }
func newClientPeerIdentity() string { func newClientPeerIdentity() string {
@@ -108,7 +117,11 @@ func (c *ClientCommon) announceClientPeerIdentity() error {
if peerID == "" { if peerID == "" {
return errors.New("peer identity is empty") return errors.New("peer identity is empty")
} }
encoded, err := c.sequenceEn(peerAttachRequest{PeerID: peerID}) req, requestState, err := c.buildPeerAttachRequest(peerID)
if err != nil {
return err
}
encoded, err := c.sequenceEn(req)
if err != nil { if err != nil {
return err return err
} }
@@ -133,6 +146,12 @@ func (c *ClientCommon) announceClientPeerIdentity() error {
} }
return errors.New("peer attach rejected") return errors.New("peer attach rejected")
} }
verifyResult, err := c.verifyPeerAttachResponse(req, resp, requestState)
if err != nil {
return err
}
c.setClientNegotiatedSteadyTransportProtection(verifyResult.steadyProfile)
c.markClientPeerAttachAuthenticated(verifyResult.authFallback, time.Now())
return nil return nil
} }
@@ -188,7 +207,7 @@ func (s *ServerCommon) replyPeerAttach(client *LogicalConn, message Message, res
Type: MSG_SYS_REPLY, Type: MSG_SYS_REPLY,
} }
if message.inboundConn != nil { if message.inboundConn != nil {
return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply) return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, messageInboundTransportProtectionSnapshot(&message), reply)
} }
_, err = s.sendLogical(client, reply) _, err = s.sendLogical(client, reply)
return err return err
@@ -200,6 +219,10 @@ func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool {
} }
message = hydrateServerMessagePeerFields(message) message = hydrateServerMessagePeerFields(message)
current := messageLogicalConnSnapshot(&message) current := messageLogicalConnSnapshot(&message)
transport := message.inboundConn
if transport == nil && current != nil {
transport = current.transportSnapshot()
}
req, err := decodePeerAttachRequest(s.sequenceDe, message.Value) req, err := decodePeerAttachRequest(s.sequenceDe, message.Value)
if err != nil { if err != nil {
if current != nil { if current != nil {
@@ -210,6 +233,18 @@ func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool {
} }
return true return true
} }
auth, err := s.validatePeerAttachRequestAuth(current, transport, req)
if err != nil {
classifyPeerAttachRejectCounter(s, err)
if current != nil {
_ = s.replyPeerAttach(current, message, peerAttachResponse{
PeerID: req.PeerID,
Accepted: false,
Error: err.Error(),
})
}
return true
}
bound, reused, err := s.bindAcceptedClientIdentity(current, req.PeerID) bound, reused, err := s.bindAcceptedClientIdentity(current, req.PeerID)
if err != nil { if err != nil {
if current != nil { if current != nil {
@@ -221,12 +256,37 @@ func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool {
} }
return true return true
} }
if err := s.replyPeerAttach(bound, message, peerAttachResponse{ resp := peerAttachResponse{
PeerID: bound.ID(), PeerID: bound.ID(),
Accepted: true, Accepted: true,
Reused: reused, Reused: reused,
}); err != nil && bound != nil { }
steadyProfile, err := s.preparePeerAttachSteadyTransportProfile(bound, req, &resp, auth)
if err != nil {
if bound != nil {
_ = s.replyPeerAttach(bound, message, peerAttachResponse{
PeerID: req.PeerID,
Accepted: false,
Error: err.Error(),
})
}
return true
}
s.signPeerAttachResponse(bound, req, &resp, auth)
if bound != nil {
bound.markPeerAttachAuthenticated(s.securityAuthMode, auth.fallback, time.Now())
if auth.explicit {
s.peerAttachExplicitCount.Add(1)
} else if auth.fallback {
s.peerAttachAuthFallbackCount.Add(1)
}
}
if err := s.replyPeerAttach(bound, message, resp); err != nil && bound != nil {
s.stopLogicalSession(bound, "peer attach reply failed", err) s.stopLogicalSession(bound, "peer attach reply failed", err)
return true
}
if bound != nil && s.securityConfigured {
bound.applyTransportProtectionProfile(steadyProfile)
} }
return true return true
} }
+9 -1
View File
@@ -119,6 +119,7 @@ func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) {
defer serverConn.Close() defer serverConn.Close()
logical := bootstrapPeerAttachLogicalForTest(t, server, serverConn) logical := bootstrapPeerAttachLogicalForTest(t, server, serverConn)
originalProfile := logical.transportProtectionProfileSnapshot()
message := Message{ message := Message{
NetType: NET_SERVER, NetType: NET_SERVER,
LogicalConn: logical, LogicalConn: logical,
@@ -131,6 +132,13 @@ func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) {
Time: time.Now(), Time: time.Now(),
inboundConn: serverConn, inboundConn: serverConn,
} }
message = hydrateServerMessagePeerFields(message)
alternate, err := deriveModernPSKProtectionProfile([]byte("notify-peer-attach-reply-alternate"), testModernPSKOptions(), ProtectionManaged)
if err != nil {
t.Fatalf("deriveModernPSKProtectionProfile failed: %v", err)
}
logical.applyTransportProtectionProfile(alternate)
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
@@ -140,7 +148,7 @@ func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) {
}) })
}() }()
env := readServerEnvelopeFromConn(t, server, logical, clientConn, time.Second) env := readServerEnvelopeFromConnWithProfile(t, server, originalProfile, clientConn, time.Second)
if env.Kind != EnvelopeSignal { if env.Kind != EnvelopeSignal {
t.Fatalf("reply envelope kind = %v, want %v", env.Kind, EnvelopeSignal) t.Fatalf("reply envelope kind = %v, want %v", env.Kind, EnvelopeSignal)
} }
+2 -2
View File
@@ -41,7 +41,7 @@ func BenchmarkRawTCPLocalhostThroughput(b *testing.B) {
func benchmarkRawTCPLocalhostThroughput(b *testing.B, payloadSize int) { func benchmarkRawTCPLocalhostThroughput(b *testing.B, payloadSize int) {
b.Helper() b.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", benchmarkTCPListenAddr(b))
if err != nil { if err != nil {
b.Fatalf("net.Listen failed: %v", err) b.Fatalf("net.Listen failed: %v", err)
} }
@@ -60,7 +60,7 @@ func benchmarkRawTCPLocalhostThroughput(b *testing.B, payloadSize int) {
acceptCh <- conn acceptCh <- conn
}() }()
clientConn, err := net.Dial("tcp", listener.Addr().String()) clientConn, err := net.Dial("tcp", benchmarkTCPDialAddr(b, listener.Addr().String()))
if err != nil { if err != nil {
b.Fatalf("net.Dial failed: %v", err) b.Fatalf("net.Dial failed: %v", err)
} }
+117 -33
View File
@@ -6,14 +6,16 @@ import (
) )
const ( const (
recordFrameMagic = "NRS1" recordFrameMagic = "NRS1"
recordFrameVersion = 1 recordFrameVersionV1 = 1
recordFrameTypeBatch uint8 = 1 recordFrameVersionV2 = 2
recordFrameTypeAck uint8 = 2 recordFrameTypeBatch uint8 = 1
recordFrameTypeError uint8 = 3 recordFrameTypeAck uint8 = 2
recordFrameHeaderSize = 8 recordFrameTypeError uint8 = 3
recordBatchHeaderSize = 10 recordFrameHeaderSize = 8
recordErrorHeaderSize = 16 recordBatchHeaderV1Size = 10
recordBatchHeaderV2Size = 18
recordErrorHeaderSize = 16
) )
var ( var (
@@ -27,6 +29,7 @@ type recordOutboundMessage struct {
} }
type recordFrame struct { type recordFrame struct {
Version uint8
Type uint8 Type uint8
Batch []recordOutboundMessage Batch []recordOutboundMessage
AckSeq uint64 AckSeq uint64
@@ -34,7 +37,7 @@ type recordFrame struct {
Retryable bool Retryable bool
} }
func encodeRecordBatchFrame(batch []recordOutboundMessage) ([]byte, error) { func encodeRecordBatchFrame(batch []recordOutboundMessage, ackSeq uint64, useV2 bool) ([]byte, error) {
if len(batch) == 0 { if len(batch) == 0 {
return nil, nil return nil, nil
} }
@@ -42,7 +45,13 @@ func encodeRecordBatchFrame(batch []recordOutboundMessage) ([]byte, error) {
if firstSeq == 0 { if firstSeq == 0 {
return nil, errRecordSeqInvalid return nil, errRecordSeqInvalid
} }
size := recordFrameHeaderSize + recordBatchHeaderSize version := uint8(recordFrameVersionV1)
batchHeaderSize := recordBatchHeaderV1Size
if useV2 {
version = recordFrameVersionV2
batchHeaderSize = recordBatchHeaderV2Size
}
size := recordFrameHeaderSize + batchHeaderSize
for index, item := range batch { for index, item := range batch {
wantSeq := firstSeq + uint64(index) wantSeq := firstSeq + uint64(index)
if item.Seq != wantSeq { if item.Seq != wantSeq {
@@ -52,11 +61,14 @@ func encodeRecordBatchFrame(batch []recordOutboundMessage) ([]byte, error) {
} }
frame := make([]byte, size) frame := make([]byte, size)
copy(frame[:4], recordFrameMagic) copy(frame[:4], recordFrameMagic)
frame[4] = recordFrameVersion frame[4] = version
frame[5] = recordFrameTypeBatch frame[5] = recordFrameTypeBatch
binary.BigEndian.PutUint16(frame[8:10], uint16(len(batch))) binary.BigEndian.PutUint16(frame[8:10], uint16(len(batch)))
binary.BigEndian.PutUint64(frame[10:18], firstSeq) binary.BigEndian.PutUint64(frame[10:18], firstSeq)
offset := recordFrameHeaderSize + recordBatchHeaderSize offset := recordFrameHeaderSize + batchHeaderSize
if useV2 {
binary.BigEndian.PutUint64(frame[18:26], ackSeq)
}
for _, item := range batch { for _, item := range batch {
binary.BigEndian.PutUint32(frame[offset:offset+4], uint32(len(item.Payload))) binary.BigEndian.PutUint32(frame[offset:offset+4], uint32(len(item.Payload)))
offset += 4 offset += 4
@@ -69,7 +81,7 @@ func encodeRecordBatchFrame(batch []recordOutboundMessage) ([]byte, error) {
func encodeRecordAckFrame(ackSeq uint64) ([]byte, error) { func encodeRecordAckFrame(ackSeq uint64) ([]byte, error) {
frame := make([]byte, recordFrameHeaderSize+8) frame := make([]byte, recordFrameHeaderSize+8)
copy(frame[:4], recordFrameMagic) copy(frame[:4], recordFrameMagic)
frame[4] = recordFrameVersion frame[4] = recordFrameVersionV1
frame[5] = recordFrameTypeAck frame[5] = recordFrameTypeAck
binary.BigEndian.PutUint64(frame[8:16], ackSeq) binary.BigEndian.PutUint64(frame[8:16], ackSeq)
return frame, nil return frame, nil
@@ -83,7 +95,7 @@ func encodeRecordErrorFrame(failure RecordFailure) ([]byte, error) {
msgBytes := []byte(failure.Message) msgBytes := []byte(failure.Message)
frame := make([]byte, recordFrameHeaderSize+recordErrorHeaderSize+len(codeBytes)+len(msgBytes)) frame := make([]byte, recordFrameHeaderSize+recordErrorHeaderSize+len(codeBytes)+len(msgBytes))
copy(frame[:4], recordFrameMagic) copy(frame[:4], recordFrameMagic)
frame[4] = recordFrameVersion frame[4] = recordFrameVersionV1
frame[5] = recordFrameTypeError frame[5] = recordFrameTypeError
if failure.Retryable { if failure.Retryable {
frame[6] = 1 frame[6] = 1
@@ -102,30 +114,62 @@ func decodeRecordFrame(payload []byte) (recordFrame, error) {
if len(payload) < recordFrameHeaderSize || string(payload[:4]) != recordFrameMagic { if len(payload) < recordFrameHeaderSize || string(payload[:4]) != recordFrameMagic {
return recordFrame{}, errRecordFrameInvalid return recordFrame{}, errRecordFrameInvalid
} }
if payload[4] != recordFrameVersion { version := payload[4]
return recordFrame{}, errRecordFrameInvalid
}
frameType := payload[5] frameType := payload[5]
switch frameType { switch version {
case recordFrameTypeBatch: case recordFrameVersionV1:
return decodeRecordBatchFrame(payload) switch frameType {
case recordFrameTypeAck: case recordFrameTypeBatch:
if len(payload) != recordFrameHeaderSize+8 { return decodeRecordBatchFrameV1(payload)
case recordFrameTypeAck:
if len(payload) != recordFrameHeaderSize+8 {
return recordFrame{}, errRecordFrameInvalid
}
return recordFrame{
Version: recordFrameVersionV1,
Type: recordFrameTypeAck,
AckSeq: binary.BigEndian.Uint64(payload[8:16]),
}, nil
case recordFrameTypeError:
frame, err := decodeRecordErrorFrame(payload)
if err != nil {
return recordFrame{}, err
}
frame.Version = recordFrameVersionV1
return frame, nil
default:
return recordFrame{}, errRecordFrameInvalid
}
case recordFrameVersionV2:
switch frameType {
case recordFrameTypeBatch:
return decodeRecordBatchFrameV2(payload)
case recordFrameTypeAck:
if len(payload) != recordFrameHeaderSize+8 {
return recordFrame{}, errRecordFrameInvalid
}
return recordFrame{
Version: recordFrameVersionV2,
Type: recordFrameTypeAck,
AckSeq: binary.BigEndian.Uint64(payload[8:16]),
}, nil
case recordFrameTypeError:
frame, err := decodeRecordErrorFrame(payload)
if err != nil {
return recordFrame{}, err
}
frame.Version = recordFrameVersionV2
return frame, nil
default:
return recordFrame{}, errRecordFrameInvalid return recordFrame{}, errRecordFrameInvalid
} }
return recordFrame{
Type: recordFrameTypeAck,
AckSeq: binary.BigEndian.Uint64(payload[8:16]),
}, nil
case recordFrameTypeError:
return decodeRecordErrorFrame(payload)
default: default:
return recordFrame{}, errRecordFrameInvalid return recordFrame{}, errRecordFrameInvalid
} }
} }
func decodeRecordBatchFrame(payload []byte) (recordFrame, error) { func decodeRecordBatchFrameV1(payload []byte) (recordFrame, error) {
if len(payload) < recordFrameHeaderSize+recordBatchHeaderSize { if len(payload) < recordFrameHeaderSize+recordBatchHeaderV1Size {
return recordFrame{}, errRecordFrameInvalid return recordFrame{}, errRecordFrameInvalid
} }
count := int(binary.BigEndian.Uint16(payload[8:10])) count := int(binary.BigEndian.Uint16(payload[8:10]))
@@ -133,7 +177,7 @@ func decodeRecordBatchFrame(payload []byte) (recordFrame, error) {
if count <= 0 || firstSeq == 0 { if count <= 0 || firstSeq == 0 {
return recordFrame{}, errRecordFrameInvalid return recordFrame{}, errRecordFrameInvalid
} }
offset := recordFrameHeaderSize + recordBatchHeaderSize offset := recordFrameHeaderSize + recordBatchHeaderV1Size
batch := make([]recordOutboundMessage, 0, count) batch := make([]recordOutboundMessage, 0, count)
for index := 0; index < count; index++ { for index := 0; index < count; index++ {
if offset+4 > len(payload) { if offset+4 > len(payload) {
@@ -155,8 +199,48 @@ func decodeRecordBatchFrame(payload []byte) (recordFrame, error) {
return recordFrame{}, errRecordFrameInvalid return recordFrame{}, errRecordFrameInvalid
} }
return recordFrame{ return recordFrame{
Type: recordFrameTypeBatch, Version: recordFrameVersionV1,
Batch: batch, Type: recordFrameTypeBatch,
Batch: batch,
}, nil
}
func decodeRecordBatchFrameV2(payload []byte) (recordFrame, error) {
if len(payload) < recordFrameHeaderSize+recordBatchHeaderV2Size {
return recordFrame{}, errRecordFrameInvalid
}
count := int(binary.BigEndian.Uint16(payload[8:10]))
firstSeq := binary.BigEndian.Uint64(payload[10:18])
ackSeq := binary.BigEndian.Uint64(payload[18:26])
if count <= 0 || firstSeq == 0 {
return recordFrame{}, errRecordFrameInvalid
}
offset := recordFrameHeaderSize + recordBatchHeaderV2Size
batch := make([]recordOutboundMessage, 0, count)
for index := 0; index < count; index++ {
if offset+4 > len(payload) {
return recordFrame{}, errRecordFrameInvalid
}
itemLen := int(binary.BigEndian.Uint32(payload[offset : offset+4]))
offset += 4
if itemLen < 0 || offset+itemLen > len(payload) {
return recordFrame{}, errRecordFrameInvalid
}
item := recordOutboundMessage{
Seq: firstSeq + uint64(index),
Payload: append([]byte(nil), payload[offset:offset+itemLen]...),
}
offset += itemLen
batch = append(batch, item)
}
if offset != len(payload) {
return recordFrame{}, errRecordFrameInvalid
}
return recordFrame{
Version: recordFrameVersionV2,
Type: recordFrameTypeBatch,
Batch: batch,
AckSeq: ackSeq,
}, nil }, nil
} }
+73
View File
@@ -0,0 +1,73 @@
package notify
import "testing"
func TestEncodeDecodeRecordBatchFrameV1(t *testing.T) {
batch := []recordOutboundMessage{
{Seq: 7, Payload: []byte("alpha")},
{Seq: 8, Payload: []byte("beta")},
}
payload, err := encodeRecordBatchFrame(batch, 0, false)
if err != nil {
t.Fatalf("encodeRecordBatchFrame v1 failed: %v", err)
}
frame, err := decodeRecordFrame(payload)
if err != nil {
t.Fatalf("decodeRecordFrame v1 failed: %v", err)
}
if got, want := frame.Version, uint8(recordFrameVersionV1); got != want {
t.Fatalf("frame version = %d, want %d", got, want)
}
if got, want := frame.Type, recordFrameTypeBatch; got != want {
t.Fatalf("frame type = %d, want %d", got, want)
}
if frame.AckSeq != 0 {
t.Fatalf("frame ack seq = %d, want 0", frame.AckSeq)
}
if got, want := len(frame.Batch), len(batch); got != want {
t.Fatalf("batch len = %d, want %d", got, want)
}
for i := range batch {
if got, want := frame.Batch[i].Seq, batch[i].Seq; got != want {
t.Fatalf("batch[%d].seq = %d, want %d", i, got, want)
}
if got, want := string(frame.Batch[i].Payload), string(batch[i].Payload); got != want {
t.Fatalf("batch[%d].payload = %q, want %q", i, got, want)
}
}
}
func TestEncodeDecodeRecordBatchFrameV2CarriesAckSeq(t *testing.T) {
batch := []recordOutboundMessage{
{Seq: 11, Payload: []byte("alpha")},
{Seq: 12, Payload: []byte("beta")},
}
payload, err := encodeRecordBatchFrame(batch, 9, true)
if err != nil {
t.Fatalf("encodeRecordBatchFrame v2 failed: %v", err)
}
frame, err := decodeRecordFrame(payload)
if err != nil {
t.Fatalf("decodeRecordFrame v2 failed: %v", err)
}
if got, want := frame.Version, uint8(recordFrameVersionV2); got != want {
t.Fatalf("frame version = %d, want %d", got, want)
}
if got, want := frame.Type, recordFrameTypeBatch; got != want {
t.Fatalf("frame type = %d, want %d", got, want)
}
if got, want := frame.AckSeq, uint64(9); got != want {
t.Fatalf("frame ack seq = %d, want %d", got, want)
}
if got, want := len(frame.Batch), len(batch); got != want {
t.Fatalf("batch len = %d, want %d", got, want)
}
for i := range batch {
if got, want := frame.Batch[i].Seq, batch[i].Seq; got != want {
t.Fatalf("batch[%d].seq = %d, want %d", i, got, want)
}
if got, want := string(frame.Batch[i].Payload), string(batch[i].Payload); got != want {
t.Fatalf("batch[%d].payload = %q, want %q", i, got, want)
}
}
}
+48
View File
@@ -0,0 +1,48 @@
package notify
const (
recordStreamMetadataCapBatchAckKey = "_notify.record_cap_batch_ack"
recordStreamMetadataUseBatchAckKey = "_notify.record_use_batch_ack"
recordStreamMetadataEnabledValue = "1"
)
func advertiseRecordStreamOpenMetadata(metadata StreamMetadata) StreamMetadata {
metadata = cloneStreamMetadata(metadata)
if metadata == nil {
metadata = make(StreamMetadata, 1)
}
metadata[recordStreamMetadataCapBatchAckKey] = recordStreamMetadataEnabledValue
return metadata
}
func negotiateRecordStreamOpenMetadata(channel StreamChannel, metadata StreamMetadata) (StreamMetadata, StreamMetadata) {
metadata = cloneStreamMetadata(metadata)
if normalizeStreamChannel(channel) != StreamRecordChannel {
return metadata, nil
}
if metadata[recordStreamMetadataCapBatchAckKey] != recordStreamMetadataEnabledValue {
return metadata, nil
}
metadata[recordStreamMetadataUseBatchAckKey] = recordStreamMetadataEnabledValue
return metadata, StreamMetadata{
recordStreamMetadataUseBatchAckKey: recordStreamMetadataEnabledValue,
}
}
func mergeStreamMetadata(base StreamMetadata, overlay StreamMetadata) StreamMetadata {
if len(base) == 0 && len(overlay) == 0 {
return nil
}
merged := cloneStreamMetadata(base)
if merged == nil {
merged = make(StreamMetadata, len(overlay))
}
for key, value := range overlay {
merged[key] = value
}
return merged
}
func recordStreamUseBatchAck(metadata StreamMetadata) bool {
return metadata[recordStreamMetadataUseBatchAckKey] == recordStreamMetadataEnabledValue
}
+78
View File
@@ -0,0 +1,78 @@
package notify
import (
"context"
"net"
"testing"
"time"
)
func TestNegotiateRecordStreamOpenMetadataEnablesBatchAck(t *testing.T) {
reqMetadata, respMetadata := negotiateRecordStreamOpenMetadata(StreamRecordChannel, StreamMetadata{
recordStreamMetadataCapBatchAckKey: recordStreamMetadataEnabledValue,
})
if !recordStreamUseBatchAck(reqMetadata) {
t.Fatal("request metadata should enable batch ack")
}
if !recordStreamUseBatchAck(respMetadata) {
t.Fatal("response metadata should enable batch ack")
}
}
func TestNegotiateRecordStreamOpenMetadataKeepsFallbackWithoutCapability(t *testing.T) {
reqMetadata, respMetadata := negotiateRecordStreamOpenMetadata(StreamRecordChannel, nil)
if recordStreamUseBatchAck(reqMetadata) {
t.Fatalf("request metadata should keep fallback mode: %+v", reqMetadata)
}
if recordStreamUseBatchAck(respMetadata) {
t.Fatalf("response metadata should keep fallback mode: %+v", respMetadata)
}
}
func TestOpenRecordStreamNegotiatesBatchAck(t *testing.T) {
server := NewServer().(*ServerCommon)
secret := []byte("0123456789abcdef0123456789abcdef")
server = newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
server.SetSecretKey(secret)
})
acceptedCh := make(chan RecordAcceptInfo, 1)
server.SetRecordStreamHandler(func(info RecordAcceptInfo) error {
acceptedCh <- info
return nil
})
client := NewClient().(*ClientCommon)
client.SetSecretKey(secret)
left, right := net.Pipe()
defer right.Close()
bootstrapPeerAttachConnForTest(t, server, right)
if err := client.ConnectByConn(left); err != nil {
t.Fatalf("client ConnectByConn failed: %v", err)
}
defer func() {
client.setByeFromServer(true)
_ = client.Stop()
}()
record, err := client.OpenRecordStream(context.Background(), RecordOpenOptions{})
if err != nil {
t.Fatalf("OpenRecordStream failed: %v", err)
}
defer func() {
_ = record.Close()
}()
if !recordStreamUseBatchAck(record.Metadata()) {
t.Fatalf("client record stream metadata should negotiate batch ack: %+v", record.Metadata())
}
select {
case accepted := <-acceptedCh:
if !recordStreamUseBatchAck(accepted.Metadata) {
t.Fatalf("accepted record metadata should negotiate batch ack: %+v", accepted.Metadata)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting accepted record stream")
}
}
+125 -2
View File
@@ -1,14 +1,20 @@
package notify package notify
import "sync" import (
"strconv"
"sync"
)
type recordRuntime struct { type recordRuntime struct {
mu sync.RWMutex mu sync.RWMutex
handler func(RecordAcceptInfo) error handler func(RecordAcceptInfo) error
records map[string]*recordStream
} }
func newRecordRuntime() *recordRuntime { func newRecordRuntime() *recordRuntime {
return &recordRuntime{} return &recordRuntime{
records: make(map[string]*recordStream),
}
} }
func (r *recordRuntime) setHandler(fn func(RecordAcceptInfo) error) { func (r *recordRuntime) setHandler(fn func(RecordAcceptInfo) error) {
@@ -42,3 +48,120 @@ func (s *ServerCommon) getRecordRuntime() *recordRuntime {
} }
return s.recordRuntime return s.recordRuntime
} }
func (r *recordRuntime) register(record *recordStream) {
if r == nil || record == nil {
return
}
key := record.runtimeRegistryKey()
if key == "" {
return
}
r.mu.Lock()
r.records[key] = record
r.mu.Unlock()
}
func (r *recordRuntime) remove(key string) {
if r == nil || key == "" {
return
}
r.mu.Lock()
delete(r.records, key)
r.mu.Unlock()
}
func (r *recordRuntime) snapshots() []RecordSnapshot {
if r == nil {
return nil
}
r.mu.RLock()
records := make([]*recordStream, 0, len(r.records))
for _, record := range r.records {
if record == nil {
continue
}
records = append(records, record)
}
r.mu.RUnlock()
snapshots := make([]RecordSnapshot, 0, len(records))
for _, record := range records {
snapshots = append(snapshots, record.snapshot())
}
sortRecordSnapshots(snapshots)
return snapshots
}
func bindRecordRuntime(record RecordStream, runtime *recordRuntime) {
if runtime == nil || record == nil {
return
}
rs, ok := record.(*recordStream)
if !ok {
return
}
rs.bindRuntime(runtime)
}
func (r *recordStream) bindRuntime(runtime *recordRuntime) {
if r == nil || runtime == nil {
return
}
key := r.runtimeRegistryKey()
if key == "" {
return
}
r.mu.Lock()
r.runtime = runtime
r.runtimeKey = key
r.mu.Unlock()
runtime.register(r)
r.runtimeWatchOnce.Do(func() {
go func() {
streamCtx := r.stream.Context()
if streamCtx == nil {
<-r.ctx.Done()
} else {
select {
case <-r.ctx.Done():
case <-streamCtx.Done():
}
}
r.detachRuntime()
}()
})
}
func (r *recordStream) detachRuntime() {
if r == nil {
return
}
r.runtimeDetachOnce.Do(func() {
r.mu.Lock()
runtime := r.runtime
key := r.runtimeKey
r.runtime = nil
r.runtimeKey = ""
r.mu.Unlock()
if runtime != nil {
runtime.remove(key)
}
})
}
func (r *recordStream) runtimeRegistryKey() string {
if r == nil || r.stream == nil {
return ""
}
scope := ""
dataID := uint64(0)
if stream, ok := r.stream.(*streamHandle); ok {
scope = normalizeFileScope(stream.runtimeScope)
dataID = stream.dataID
}
key := scope + "\x00" + r.stream.ID()
if dataID != 0 {
key += "\x01" + strconv.FormatUint(dataID, 10)
}
return key
}
+252
View File
@@ -0,0 +1,252 @@
package notify
import (
"errors"
"io"
"sort"
"time"
)
type RecordSnapshot struct {
ID string
DataID uint64
Scope string
Metadata StreamMetadata
UseBatchAck bool
BindingOwner string
BindingAlive bool
BindingCurrent bool
BindingReason string
BindingError string
SessionEpoch uint64
LogicalClientID string
LocalAddress string
RemoteAddress string
TransportGeneration uint64
TransportAttached bool
TransportHasRuntimeConn bool
TransportCurrent bool
TransportDetachReason string
TransportDetachKind string
TransportDetachGeneration uint64
TransportDetachError string
TransportDetachedAt time.Time
ReattachEligible bool
LocalClosed bool
LocalReadClosed bool
RemoteClosed bool
PeerReadClosed bool
OutboundClosed bool
NextOutboundSeq uint64
EnqueuedOutboundSeq uint64
FlushedOutboundSeq uint64
AckedOutboundSeq uint64
OutstandingRecords int
OutstandingBytes int
InboundReceivedSeq uint64
InboundAppliedSeq uint64
InboundAckSentSeq uint64
PendingApplyRecords int
PendingAckRecords int
PeakPendingApplyRecords int
BatchFramesSent int64
AckFramesSent int64
ErrorFramesSent int64
BatchFramesReceived int64
AckFramesReceived int64
ErrorFramesReceived int64
PiggybackAckSent int64
PiggybackAckReceived int64
BarrierCount int64
BarrierFlushWaitDuration time.Duration
BarrierApplyWaitDuration time.Duration
OpenedAt time.Time
LastReadAt time.Time
LastWriteAt time.Time
StreamResetError string
ReadError string
TerminalError string
ResetError string
}
type clientRecordSnapshotReader interface {
clientRecordSnapshots() []RecordSnapshot
}
type serverRecordSnapshotReader interface {
serverRecordSnapshots() []RecordSnapshot
}
var (
errClientRecordSnapshotNil = errors.New("client record snapshot target is nil")
errServerRecordSnapshotNil = errors.New("server record snapshot target is nil")
errClientRecordSnapshotUnsupported = errors.New("client record snapshot target type is unsupported")
errServerRecordSnapshotUnsupported = errors.New("server record snapshot target type is unsupported")
)
func GetClientRecordSnapshots(c Client) ([]RecordSnapshot, error) {
if c == nil {
return nil, errClientRecordSnapshotNil
}
reader, ok := any(c).(clientRecordSnapshotReader)
if !ok {
return nil, errClientRecordSnapshotUnsupported
}
return reader.clientRecordSnapshots(), nil
}
func GetServerRecordSnapshots(s Server) ([]RecordSnapshot, error) {
if s == nil {
return nil, errServerRecordSnapshotNil
}
reader, ok := any(s).(serverRecordSnapshotReader)
if !ok {
return nil, errServerRecordSnapshotUnsupported
}
return reader.serverRecordSnapshots(), nil
}
func (c *ClientCommon) clientRecordSnapshots() []RecordSnapshot {
return recordSnapshotsFromRuntime(c.getRecordRuntime())
}
func (s *ServerCommon) serverRecordSnapshots() []RecordSnapshot {
return recordSnapshotsFromRuntime(s.getRecordRuntime())
}
func recordSnapshotsFromRuntime(runtime *recordRuntime) []RecordSnapshot {
if runtime == nil {
return nil
}
return runtime.snapshots()
}
func sortRecordSnapshots(src []RecordSnapshot) {
sort.Slice(src, func(i, j int) bool {
if src[i].Scope != src[j].Scope {
return src[i].Scope < src[j].Scope
}
if src[i].ID != src[j].ID {
return src[i].ID < src[j].ID
}
if src[i].DataID != src[j].DataID {
return src[i].DataID < src[j].DataID
}
return src[i].TransportGeneration < src[j].TransportGeneration
})
}
func (r *recordStream) snapshot() RecordSnapshot {
if r == nil {
return RecordSnapshot{}
}
snapshot := RecordSnapshot{}
if stream, ok := r.stream.(*streamHandle); ok {
snapshot = recordSnapshotFromStreamSnapshot(stream.snapshot())
} else if r.stream != nil {
snapshot.ID = r.stream.ID()
snapshot.Metadata = cloneStreamMetadata(r.stream.Metadata())
snapshot.TransportGeneration = r.stream.TransportGeneration()
if addr := r.stream.LocalAddr(); addr != nil {
snapshot.LocalAddress = addr.String()
}
if addr := r.stream.RemoteAddr(); addr != nil {
snapshot.RemoteAddress = addr.String()
}
if logical := r.stream.LogicalConn(); logical != nil {
snapshot.LogicalClientID = logical.ID()
}
}
snapshot.UseBatchAck = r.useBatchAck
snapshot.BatchFramesSent = r.obs.batchFramesSent.Load()
snapshot.AckFramesSent = r.obs.ackFramesSent.Load()
snapshot.ErrorFramesSent = r.obs.errorFramesSent.Load()
snapshot.BatchFramesReceived = r.obs.batchFramesReceived.Load()
snapshot.AckFramesReceived = r.obs.ackFramesReceived.Load()
snapshot.ErrorFramesReceived = r.obs.errorFramesReceived.Load()
snapshot.PiggybackAckSent = r.obs.piggybackAckSent.Load()
snapshot.PiggybackAckReceived = r.obs.piggybackAckReceived.Load()
snapshot.BarrierCount = r.obs.barrierCount.Load()
snapshot.BarrierFlushWaitDuration = time.Duration(r.obs.barrierFlushWaitNanos.Load())
snapshot.BarrierApplyWaitDuration = time.Duration(r.obs.barrierApplyWaitNanos.Load())
r.mu.Lock()
snapshot.OutboundClosed = r.outboundClosed
snapshot.NextOutboundSeq = r.nextOutboundSeq
snapshot.EnqueuedOutboundSeq = r.enqueuedOutboundSeq
snapshot.FlushedOutboundSeq = r.flushedOutboundSeq
snapshot.AckedOutboundSeq = r.ackedOutboundSeq
snapshot.OutstandingRecords = r.outstandingRecords
snapshot.OutstandingBytes = r.outstandingBytes
snapshot.InboundReceivedSeq = r.inboundReceivedSeq
snapshot.InboundAppliedSeq = r.inboundAppliedSeq
snapshot.InboundAckSentSeq = r.inboundAckSentSeq
snapshot.PendingApplyRecords = recordPendingCount(r.inboundReceivedSeq, r.inboundAppliedSeq)
snapshot.PendingAckRecords = recordPendingCount(r.inboundAppliedSeq, r.inboundAckSentSeq)
snapshot.PeakPendingApplyRecords = r.maxPendingApply
if r.readErr != nil && !errors.Is(r.readErr, io.EOF) {
snapshot.ReadError = r.readErr.Error()
}
if r.terminalErr != nil {
snapshot.TerminalError = r.terminalErr.Error()
}
r.mu.Unlock()
switch {
case snapshot.TerminalError != "":
snapshot.ResetError = snapshot.TerminalError
case snapshot.StreamResetError != "":
snapshot.ResetError = snapshot.StreamResetError
case snapshot.ReadError != "":
snapshot.ResetError = snapshot.ReadError
}
return snapshot
}
func recordSnapshotFromStreamSnapshot(stream StreamSnapshot) RecordSnapshot {
return RecordSnapshot{
ID: stream.ID,
DataID: stream.DataID,
Scope: stream.Scope,
Metadata: cloneStreamMetadata(stream.Metadata),
BindingOwner: stream.BindingOwner,
BindingAlive: stream.BindingAlive,
BindingCurrent: stream.BindingCurrent,
BindingReason: stream.BindingReason,
BindingError: stream.BindingError,
SessionEpoch: stream.SessionEpoch,
LogicalClientID: stream.LogicalClientID,
LocalAddress: stream.LocalAddress,
RemoteAddress: stream.RemoteAddress,
TransportGeneration: stream.TransportGeneration,
TransportAttached: stream.TransportAttached,
TransportHasRuntimeConn: stream.TransportHasRuntimeConn,
TransportCurrent: stream.TransportCurrent,
TransportDetachReason: stream.TransportDetachReason,
TransportDetachKind: stream.TransportDetachKind,
TransportDetachGeneration: stream.TransportDetachGeneration,
TransportDetachError: stream.TransportDetachError,
TransportDetachedAt: stream.TransportDetachedAt,
ReattachEligible: stream.ReattachEligible,
LocalClosed: stream.LocalClosed,
LocalReadClosed: stream.LocalReadClosed,
RemoteClosed: stream.RemoteClosed,
PeerReadClosed: stream.PeerReadClosed,
OpenedAt: stream.OpenedAt,
LastReadAt: stream.LastReadAt,
LastWriteAt: stream.LastWriteAt,
StreamResetError: stream.ResetError,
}
}
func recordPendingCount(high uint64, low uint64) int {
if high <= low {
return 0
}
diff := high - low
maxInt := uint64(^uint(0) >> 1)
if diff > maxInt {
return int(maxInt)
}
return int(diff)
}
+231 -109
View File
@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
@@ -103,25 +104,47 @@ type recordConfig struct {
type recordFlushRequest struct { type recordFlushRequest struct {
targetSeq uint64 targetSeq uint64
forceAck bool
done chan error done chan error
} }
type recordObservability struct {
batchFramesSent atomic.Int64
ackFramesSent atomic.Int64
errorFramesSent atomic.Int64
batchFramesReceived atomic.Int64
ackFramesReceived atomic.Int64
errorFramesReceived atomic.Int64
piggybackAckSent atomic.Int64
piggybackAckReceived atomic.Int64
barrierCount atomic.Int64
barrierFlushWaitNanos atomic.Int64
barrierApplyWaitNanos atomic.Int64
}
type recordStream struct { type recordStream struct {
stream Stream stream Stream
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
cfg recordConfig cfg recordConfig
writeMu sync.Mutex writeMu sync.Mutex
sendCh chan recordOutboundMessage sendCh chan recordOutboundMessage
flushCh chan recordFlushRequest flushCh chan recordFlushRequest
recvCh chan RecordMessage recvCh chan RecordMessage
ackCh chan struct{} ackCh chan struct{}
readerCh chan struct{} readerCh chan struct{}
useBatchAck bool
obs recordObservability
mu sync.Mutex mu sync.Mutex
stateNotify chan struct{} stateNotify chan struct{}
runtime *recordRuntime
runtimeKey string
runtimeWatchOnce sync.Once
runtimeDetachOnce sync.Once
nextOutboundSeq uint64 nextOutboundSeq uint64
enqueuedOutboundSeq uint64 enqueuedOutboundSeq uint64
flushedOutboundSeq uint64 flushedOutboundSeq uint64
@@ -135,6 +158,7 @@ type recordStream struct {
inboundAppliedSeq uint64 inboundAppliedSeq uint64
inboundApplied map[uint64]struct{} inboundApplied map[uint64]struct{}
inboundAckSentSeq uint64 inboundAckSentSeq uint64
maxPendingApply int
remoteClosed bool remoteClosed bool
readErr error readErr error
@@ -193,6 +217,7 @@ func recordConfigFromOptions(opt RecordOpenOptions) recordConfig {
func normalizeRecordStreamOpenOptions(opt StreamOpenOptions) StreamOpenOptions { func normalizeRecordStreamOpenOptions(opt StreamOpenOptions) StreamOpenOptions {
opt.Channel = StreamRecordChannel opt.Channel = StreamRecordChannel
opt.Metadata = advertiseRecordStreamOpenMetadata(opt.Metadata)
return opt return opt
} }
@@ -207,22 +232,22 @@ func WrapStreamAsRecord(stream Stream, opt RecordOpenOptions) (RecordStream, err
} }
ctx, cancel := context.WithCancel(parent) ctx, cancel := context.WithCancel(parent)
record := &recordStream{ record := &recordStream{
stream: stream, stream: stream,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
cfg: recordConfigFromOptions(opt), cfg: recordConfigFromOptions(opt),
sendCh: make(chan recordOutboundMessage, opt.MaxBatchRecords*2), sendCh: make(chan recordOutboundMessage, opt.MaxBatchRecords*2),
flushCh: make(chan recordFlushRequest), flushCh: make(chan recordFlushRequest),
recvCh: make(chan RecordMessage, opt.InboundQueueLimit), recvCh: make(chan RecordMessage, opt.InboundQueueLimit),
ackCh: make(chan struct{}, 1), ackCh: make(chan struct{}, 1),
readerCh: make(chan struct{}), readerCh: make(chan struct{}),
useBatchAck: recordStreamUseBatchAck(stream.Metadata()),
stateNotify: make(chan struct{}), stateNotify: make(chan struct{}),
outstandingSizes: make(map[uint64]int), outstandingSizes: make(map[uint64]int),
inboundApplied: make(map[uint64]struct{}), inboundApplied: make(map[uint64]struct{}),
} }
go record.sendLoop() go record.writerLoop()
go record.ackLoop()
go record.readLoop() go record.readLoop()
return record, nil return record, nil
} }
@@ -360,13 +385,20 @@ func (r *recordStream) BarrierTo(ctx context.Context, target uint64) (uint64, er
if target > current { if target > current {
return 0, errRecordSeqInvalid return 0, errRecordSeqInvalid
} }
if err := r.Flush(ctx); err != nil { r.obs.barrierCount.Add(1)
flushStart := time.Now()
err := r.Flush(ctx)
r.obs.barrierFlushWaitNanos.Add(time.Since(flushStart).Nanoseconds())
if err != nil {
return 0, err return 0, err
} }
if target == 0 { if target == 0 {
return 0, nil return 0, nil
} }
if err := r.waitAckedAtLeast(ctx, target); err != nil { applyStart := time.Now()
err = r.waitAckedAtLeast(ctx, target)
r.obs.barrierApplyWaitNanos.Add(time.Since(applyStart).Nanoseconds())
if err != nil {
return 0, err return 0, err
} }
return target, nil return target, nil
@@ -520,54 +552,118 @@ func (r *recordStream) waitAckedAtLeast(ctx context.Context, target uint64) erro
} }
} }
func (r *recordStream) sendLoop() { func (r *recordStream) writerLoop() {
var ( var (
batch []recordOutboundMessage batch []recordOutboundMessage
batches int batches int
bytes int bytes int
timer *time.Timer batchTimer *time.Timer
timerCh <-chan time.Time batchTimerCh <-chan time.Time
ackTimer *time.Timer
ackTimerCh <-chan time.Time
) )
stopTimer := func() { stopBatchTimer := func() {
if timer == nil { if batchTimer == nil {
return return
} }
if !timer.Stop() { if !batchTimer.Stop() {
select { select {
case <-timer.C: case <-batchTimer.C:
default: default:
} }
} }
timerCh = nil batchTimerCh = nil
} }
flush := func() error { stopAckTimer := func() {
if len(batch) == 0 { if ackTimer == nil {
return
}
if !ackTimer.Stop() {
select {
case <-ackTimer.C:
default:
}
}
ackTimerCh = nil
}
scheduleAck := func(hasPendingBatch bool, force bool) (uint64, bool) {
ackSeq := r.pendingAckSeq()
if ackSeq == 0 {
stopAckTimer()
return 0, false
}
if force {
stopAckTimer()
return ackSeq, true
}
if hasPendingBatch && r.useBatchAck {
stopAckTimer()
return 0, false
}
if r.shouldSendAckNow() || r.cfg.AckDelay <= 0 {
stopAckTimer()
return ackSeq, true
}
if ackTimer == nil {
ackTimer = time.NewTimer(r.cfg.AckDelay)
} else {
ackTimer.Reset(r.cfg.AckDelay)
}
ackTimerCh = ackTimer.C
return 0, false
}
sendStandaloneAck := func(ackSeq uint64) error {
if ackSeq == 0 {
return nil return nil
} }
payload, err := encodeRecordBatchFrame(batch) payload, err := encodeRecordAckFrame(ackSeq)
if err != nil { if err != nil {
return err return err
} }
if err := r.writePayloadFrame(payload); err != nil { if err := r.writePayloadFrame(payload); err != nil {
return err return err
} }
r.obs.ackFramesSent.Add(1)
r.markAckSent(ackSeq)
return nil
}
flushBatch := func() error {
if len(batch) == 0 {
return nil
}
ackSeq := r.pendingAckSeq()
payload, err := encodeRecordBatchFrame(batch, ackSeq, r.useBatchAck)
if err != nil {
return err
}
if err := r.writePayloadFrame(payload); err != nil {
return err
}
r.obs.batchFramesSent.Add(1)
if r.useBatchAck && ackSeq != 0 {
r.obs.piggybackAckSent.Add(1)
r.markAckSent(ackSeq)
}
r.markFlushed(batch[len(batch)-1].Seq) r.markFlushed(batch[len(batch)-1].Seq)
batch = nil batch = nil
batches = 0 batches = 0
bytes = 0 bytes = 0
stopTimer() stopBatchTimer()
if ackSeq, sendNow := scheduleAck(false, false); sendNow {
return sendStandaloneAck(ackSeq)
}
return nil return nil
} }
flushUntil := func(target uint64) error { flushUntil := func(target uint64) error {
for { for {
if target == 0 { if target == 0 {
return flush() return flushBatch()
} }
if r.flushedAtLeast(target) { if r.flushedAtLeast(target) {
return nil return nil
} }
if len(batch) > 0 && batch[len(batch)-1].Seq >= target { if len(batch) > 0 && batch[len(batch)-1].Seq >= target {
if err := flush(); err != nil { if err := flushBatch(); err != nil {
return err return err
} }
if r.flushedAtLeast(target) { if r.flushedAtLeast(target) {
@@ -583,7 +679,7 @@ func (r *recordStream) sendLoop() {
batches++ batches++
bytes += len(req.Payload) bytes += len(req.Payload)
if batches >= r.cfg.MaxBatchRecords || bytes >= r.cfg.MaxBatchBytes { if batches >= r.cfg.MaxBatchRecords || bytes >= r.cfg.MaxBatchBytes {
if err := flush(); err != nil { if err := flushBatch(); err != nil {
return err return err
} }
} }
@@ -598,72 +694,54 @@ func (r *recordStream) sendLoop() {
batches++ batches++
bytes += len(req.Payload) bytes += len(req.Payload)
if len(batch) == 1 && r.cfg.MaxBatchDelay > 0 { if len(batch) == 1 && r.cfg.MaxBatchDelay > 0 {
if timer == nil { if batchTimer == nil {
timer = time.NewTimer(r.cfg.MaxBatchDelay) batchTimer = time.NewTimer(r.cfg.MaxBatchDelay)
} else { } else {
timer.Reset(r.cfg.MaxBatchDelay) batchTimer.Reset(r.cfg.MaxBatchDelay)
} }
timerCh = timer.C batchTimerCh = batchTimer.C
} }
if batches >= r.cfg.MaxBatchRecords || bytes >= r.cfg.MaxBatchBytes { if batches >= r.cfg.MaxBatchRecords || bytes >= r.cfg.MaxBatchBytes {
if err := flush(); err != nil { if err := flushBatch(); err != nil {
r.setTerminalError(err)
return
}
}
case req := <-r.flushCh:
req.done <- flushUntil(req.targetSeq)
case <-timerCh:
if err := flush(); err != nil {
r.setTerminalError(err)
return
}
}
}
}
func (r *recordStream) ackLoop() {
var (
timer *time.Timer
timerCh <-chan time.Time
)
stopTimer := func() {
if timer == nil {
return
}
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timerCh = nil
}
for {
select {
case <-r.ctx.Done():
return
case <-r.ackCh:
if r.shouldSendAckNow() {
stopTimer()
if err := r.flushAckNow(); err != nil {
r.setTerminalError(err) r.setTerminalError(err)
return return
} }
continue continue
} }
if timer == nil { if ackSeq, sendNow := scheduleAck(len(batch) > 0, false); sendNow {
timer = time.NewTimer(r.cfg.AckDelay) if err := sendStandaloneAck(ackSeq); err != nil {
} else { r.setTerminalError(err)
timer.Reset(r.cfg.AckDelay) return
}
} }
timerCh = timer.C case req := <-r.flushCh:
case <-timerCh: err := flushUntil(req.targetSeq)
stopTimer() if err == nil && req.forceAck {
if err := r.flushAckNow(); err != nil { if ackSeq, sendNow := scheduleAck(len(batch) > 0, true); sendNow {
err = sendStandaloneAck(ackSeq)
}
}
req.done <- err
case <-batchTimerCh:
if err := flushBatch(); err != nil {
r.setTerminalError(err) r.setTerminalError(err)
return return
} }
case <-r.ackCh:
if ackSeq, sendNow := scheduleAck(len(batch) > 0, false); sendNow {
if err := sendStandaloneAck(ackSeq); err != nil {
r.setTerminalError(err)
return
}
}
case <-ackTimerCh:
stopAckTimer()
if ackSeq, sendNow := scheduleAck(len(batch) > 0, true); sendNow {
if err := sendStandaloneAck(ackSeq); err != nil {
r.setTerminalError(err)
return
}
}
} }
} }
} }
@@ -694,6 +772,15 @@ func (r *recordStream) readLoop() {
} }
switch frame.Type { switch frame.Type {
case recordFrameTypeBatch: case recordFrameTypeBatch:
r.obs.batchFramesReceived.Add(1)
if frame.AckSeq != 0 {
r.obs.piggybackAckReceived.Add(1)
if err := r.handleAckFrame(frame.AckSeq); err != nil {
r.setReadError(err)
_ = r.stream.Reset(err)
return
}
}
if err := r.handleBatchFrame(frame.Batch); err != nil { if err := r.handleBatchFrame(frame.Batch); err != nil {
_ = r.sendFailureFrame(RecordFailure{ _ = r.sendFailureFrame(RecordFailure{
FailedSeq: r.nextInboundFailureSeq(), FailedSeq: r.nextInboundFailureSeq(),
@@ -705,12 +792,14 @@ func (r *recordStream) readLoop() {
return return
} }
case recordFrameTypeAck: case recordFrameTypeAck:
r.obs.ackFramesReceived.Add(1)
if err := r.handleAckFrame(frame.AckSeq); err != nil { if err := r.handleAckFrame(frame.AckSeq); err != nil {
r.setReadError(err) r.setReadError(err)
_ = r.stream.Reset(err) _ = r.stream.Reset(err)
return return
} }
case recordFrameTypeError: case recordFrameTypeError:
r.obs.errorFramesReceived.Add(1)
r.setReadError(frame.Failure) r.setReadError(frame.Failure)
return return
default: default:
@@ -732,6 +821,7 @@ func (r *recordStream) handleBatchFrame(batch []recordOutboundMessage) error {
} }
lastSeq := batch[len(batch)-1].Seq lastSeq := batch[len(batch)-1].Seq
r.inboundReceivedSeq = lastSeq r.inboundReceivedSeq = lastSeq
r.updatePendingApplyLocked()
r.signalStateLocked() r.signalStateLocked()
r.mu.Unlock() r.mu.Unlock()
for _, item := range batch { for _, item := range batch {
@@ -889,23 +979,21 @@ func (r *recordStream) shouldSendAckNow() bool {
return r.inboundAppliedSeq > r.inboundAckSentSeq && int(r.inboundAppliedSeq-r.inboundAckSentSeq) >= r.cfg.AckEveryRecords return r.inboundAppliedSeq > r.inboundAckSentSeq && int(r.inboundAppliedSeq-r.inboundAckSentSeq) >= r.cfg.AckEveryRecords
} }
func (r *recordStream) flushAckNow() error { func (r *recordStream) pendingAckSeq() uint64 {
if r == nil { if r == nil {
return errRecordStreamNil return 0
} }
r.mu.Lock() r.mu.Lock()
ackSeq := r.inboundAppliedSeq defer r.mu.Unlock()
if ackSeq <= r.inboundAckSentSeq { if r.inboundAppliedSeq <= r.inboundAckSentSeq {
r.mu.Unlock() return 0
return nil
} }
r.mu.Unlock() return r.inboundAppliedSeq
payload, err := encodeRecordAckFrame(ackSeq) }
if err != nil {
return err func (r *recordStream) markAckSent(ackSeq uint64) {
} if r == nil || ackSeq == 0 {
if err := r.writePayloadFrame(payload); err != nil { return
return err
} }
r.mu.Lock() r.mu.Lock()
if ackSeq > r.inboundAckSentSeq { if ackSeq > r.inboundAckSentSeq {
@@ -913,7 +1001,27 @@ func (r *recordStream) flushAckNow() error {
r.signalStateLocked() r.signalStateLocked()
} }
r.mu.Unlock() r.mu.Unlock()
return nil }
func (r *recordStream) flushAckNow() error {
if r == nil {
return errRecordStreamNil
}
req := recordFlushRequest{
forceAck: true,
done: make(chan error, 1),
}
select {
case <-r.ctx.Done():
return r.streamError()
case r.flushCh <- req:
}
select {
case <-r.ctx.Done():
return r.streamError()
case err := <-req.done:
return err
}
} }
func (r *recordStream) sendFailureFrame(failure RecordFailure) error { func (r *recordStream) sendFailureFrame(failure RecordFailure) error {
@@ -921,7 +1029,11 @@ func (r *recordStream) sendFailureFrame(failure RecordFailure) error {
if err != nil { if err != nil {
return err return err
} }
return r.writePayloadFrame(payload) if err := r.writePayloadFrame(payload); err != nil {
return err
}
r.obs.errorFramesSent.Add(1)
return nil
} }
func (r *recordStream) writePayloadFrame(payload []byte) error { func (r *recordStream) writePayloadFrame(payload []byte) error {
@@ -972,3 +1084,13 @@ func (r *recordStream) signalStateLocked() {
close(r.stateNotify) close(r.stateNotify)
r.stateNotify = make(chan struct{}) r.stateNotify = make(chan struct{})
} }
func (r *recordStream) updatePendingApplyLocked() {
if r == nil {
return
}
pending := recordPendingCount(r.inboundReceivedSeq, r.inboundAppliedSeq)
if pending > r.maxPendingApply {
r.maxPendingApply = pending
}
}
+134
View File
@@ -5,6 +5,7 @@ import (
"errors" "errors"
"net" "net"
"strings" "strings"
"sync/atomic"
"testing" "testing"
"time" "time"
) )
@@ -14,6 +15,29 @@ type releaseP0TestAddr string
func (a releaseP0TestAddr) Network() string { return "tcp" } func (a releaseP0TestAddr) Network() string { return "tcp" }
func (a releaseP0TestAddr) String() string { return string(a) } func (a releaseP0TestAddr) String() string { return string(a) }
type closeInspectConn struct {
closeFn func()
closed atomic.Bool
}
func (c *closeInspectConn) Read([]byte) (int, error) { return 0, net.ErrClosed }
func (c *closeInspectConn) Write(p []byte) (int, error) { return len(p), nil }
func (c *closeInspectConn) LocalAddr() net.Addr { return releaseP0TestAddr("local") }
func (c *closeInspectConn) RemoteAddr() net.Addr { return releaseP0TestAddr("remote") }
func (c *closeInspectConn) SetDeadline(time.Time) error { return nil }
func (c *closeInspectConn) SetReadDeadline(time.Time) error { return nil }
func (c *closeInspectConn) SetWriteDeadline(time.Time) error { return nil }
func (c *closeInspectConn) Close() error {
if c == nil {
return nil
}
if c.closed.CompareAndSwap(false, true) && c.closeFn != nil {
c.closeFn()
}
return nil
}
func TestGetLogicalConnRuntimeSnapshotWithoutCompatClient(t *testing.T) { func TestGetLogicalConnRuntimeSnapshotWithoutCompatClient(t *testing.T) {
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
logical := &LogicalConn{server: server} logical := &LogicalConn{server: server}
@@ -101,6 +125,116 @@ func TestHandleDedicatedBulkReadErrorPreservesUnderlyingCause(t *testing.T) {
} }
} }
func TestHandleClientDedicatedSidecarFailureMarksBulkBeforeClosingConn(t *testing.T) {
client := NewClient().(*ClientCommon)
runtime := client.getBulkRuntime()
if runtime == nil {
t.Fatal("client bulk runtime should not be nil")
}
bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{
BulkID: "sidecar-order",
DataID: 7,
Dedicated: true,
Range: BulkRange{
Length: 1,
},
}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
var closeObservedErr error
conn := &closeInspectConn{
closeFn: func() {
closeObservedErr = bulk.resetErrSnapshot()
},
}
if err := bulk.attachDedicatedConnShared(conn); err != nil {
t.Fatalf("attachDedicatedConnShared failed: %v", err)
}
if err := runtime.register(clientFileScope(), bulk); err != nil {
t.Fatalf("register bulk failed: %v", err)
}
sidecar := newBulkDedicatedSidecar(conn, 1)
client.installClientDedicatedSidecar(1, sidecar)
client.handleClientDedicatedSidecarFailure(sidecar, errors.New("boom sidecar"))
if !errors.Is(closeObservedErr, errTransportDetached) {
t.Fatalf("closeObservedErr = %v, want transport detached", closeObservedErr)
}
if !strings.Contains(closeObservedErr.Error(), "dedicated bulk read error") || !strings.Contains(closeObservedErr.Error(), "boom sidecar") {
t.Fatalf("closeObservedErr detail = %q, want dedicated read error and cause", closeObservedErr.Error())
}
}
func TestCleanupClientSessionResourcesMarksBulkBeforeClosingSidecar(t *testing.T) {
client := NewClient().(*ClientCommon)
runtime := client.getBulkRuntime()
if runtime == nil {
t.Fatal("client bulk runtime should not be nil")
}
bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{
BulkID: "cleanup-order",
DataID: 9,
Dedicated: true,
Range: BulkRange{
Length: 1,
},
}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
var closeObservedErr error
conn := &closeInspectConn{
closeFn: func() {
closeObservedErr = bulk.resetErrSnapshot()
},
}
if err := bulk.attachDedicatedConnShared(conn); err != nil {
t.Fatalf("attachDedicatedConnShared failed: %v", err)
}
if err := runtime.register(clientFileScope(), bulk); err != nil {
t.Fatalf("register bulk failed: %v", err)
}
client.installClientDedicatedSidecar(1, newBulkDedicatedSidecar(conn, 1))
client.cleanupClientSessionResources()
if !errors.Is(closeObservedErr, errServiceShutdown) {
t.Fatalf("closeObservedErr = %v, want %v", closeObservedErr, errServiceShutdown)
}
}
func TestBestEffortRejectInboundDedicatedDataUsesDedicatedResetRecord(t *testing.T) {
server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server)
logical := server.bootstrapAcceptedLogical("dedicated-reject", nil, nil)
if logical == nil {
t.Fatal("bootstrapAcceptedLogical should return logical")
}
conn := newBulkAttachScriptConn(nil)
server.bestEffortRejectInboundDedicatedData(logical, conn, 42, "unknown data id")
recordConn := newBulkAttachScriptConn(conn.writtenBytes())
payload, err := readBulkDedicatedRecord(recordConn)
if err != nil {
t.Fatalf("readBulkDedicatedRecord failed: %v", err)
}
plain, err := server.decryptTransportPayloadLogical(logical, payload)
if err != nil {
t.Fatalf("decryptTransportPayloadLogical failed: %v", err)
}
items, err := decodeDedicatedBulkInboundItems(42, plain)
if err != nil {
t.Fatalf("decodeDedicatedBulkInboundItems failed: %v", err)
}
if len(items) != 1 {
t.Fatalf("decoded items = %d, want 1", len(items))
}
if items[0].Type != bulkFastPayloadTypeReset {
t.Fatalf("reset item type = %d, want %d", items[0].Type, bulkFastPayloadTypeReset)
}
if got, want := string(items[0].Payload), "unknown data id"; got != want {
t.Fatalf("reset payload = %q, want %q", got, want)
}
}
func TestRegisterAcceptedLogicalWithoutCompatClient(t *testing.T) { func TestRegisterAcceptedLogicalWithoutCompatClient(t *testing.T) {
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server) UseLegacySecurityServer(server)
+139
View File
@@ -0,0 +1,139 @@
package notify
import (
"bytes"
"crypto/ecdh"
"crypto/hmac"
cryptorand "crypto/rand"
"crypto/sha256"
"encoding/binary"
"errors"
)
const (
peerAttachECDHEPublicKeySize = 32
peerAttachSessionIDSize = 16
peerAttachKeyModeStatic = "psk-static"
peerAttachKeyModeECDHE = "psk-ecdhe"
transportKeyModeExternal = "external"
)
var errPeerAttachForwardSecrecyInvalid = errors.New("peer attach forward secrecy is invalid")
type peerAttachRequestState struct {
forwardSecrecy *peerAttachForwardSecrecyClientState
}
type peerAttachForwardSecrecyClientState struct {
privateKey *ecdh.PrivateKey
publicKey []byte
}
type peerAttachResponseVerifyResult struct {
authFallback bool
steadyProfile transportProtectionProfile
}
func newPeerAttachForwardSecrecyClientState() (*peerAttachForwardSecrecyClientState, error) {
curve := ecdh.X25519()
privateKey, err := curve.GenerateKey(cryptorand.Reader)
if err != nil {
return nil, err
}
publicKey := privateKey.PublicKey().Bytes()
if len(publicKey) != peerAttachECDHEPublicKeySize {
return nil, errPeerAttachForwardSecrecyInvalid
}
return &peerAttachForwardSecrecyClientState{
privateKey: privateKey,
publicKey: bytes.Clone(publicKey),
}, nil
}
func derivePeerAttachForwardSecrecyTransportProfile(base transportProtectionProfile, bootstrapKey []byte, localPrivateKey *ecdh.PrivateKey, peerPublicKey []byte, req peerAttachRequest, resp peerAttachResponse) (transportProtectionProfile, error) {
if len(bootstrapKey) == 0 || localPrivateKey == nil || len(peerPublicKey) != peerAttachECDHEPublicKeySize {
return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid
}
curve := ecdh.X25519()
publicKey, err := curve.NewPublicKey(peerPublicKey)
if err != nil {
return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid
}
sharedSecret, err := localPrivateKey.ECDH(publicKey)
if err != nil {
return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid
}
transcriptHash := peerAttachForwardSecrecyTranscriptHash(req, resp)
ikm := make([]byte, 0, len(sharedSecret)+len(transcriptHash))
ikm = append(ikm, sharedSecret...)
ikm = append(ikm, transcriptHash...)
prk := hkdfExtractSHA256(bootstrapKey, ikm)
sessionKey := hkdfExpandSHA256(prk, []byte("notify/transport/session/v1"), 32)
sessionID := hkdfExpandSHA256(prk, []byte("notify/session-id/v1"), peerAttachSessionIDSize)
return deriveModernPSKSessionProtectionProfile(base, sessionKey, sessionID)
}
func peerAttachForwardSecrecyTranscriptHash(req peerAttachRequest, resp peerAttachResponse) []byte {
buf := make([]byte, 0, 256+len(req.PeerID)+len(resp.PeerID)+len(resp.Error)+len(req.ClientECDHEPublicKey)+len(resp.ServerECDHEPublicKey))
buf = appendPeerAttachTranscriptString(buf, "notify/peer-attach/forward-secrecy/v1")
buf = binary.BigEndian.AppendUint64(buf, req.Features)
buf = appendPeerAttachTranscriptString(buf, req.PeerID)
buf = appendPeerAttachTranscriptBytes(buf, req.ClientNonce)
buf = appendPeerAttachTranscriptBytes(buf, req.ClientECDHEPublicKey)
buf = binary.BigEndian.AppendUint64(buf, resp.Features)
buf = appendPeerAttachTranscriptString(buf, resp.PeerID)
buf = appendPeerAttachTranscriptBool(buf, resp.Accepted)
buf = appendPeerAttachTranscriptBool(buf, resp.Reused)
buf = appendPeerAttachTranscriptString(buf, resp.Error)
buf = appendPeerAttachTranscriptBytes(buf, resp.ServerNonce)
buf = appendPeerAttachTranscriptString(buf, resp.KeyMode)
buf = appendPeerAttachTranscriptBytes(buf, resp.ServerECDHEPublicKey)
sum := sha256.Sum256(buf)
return sum[:]
}
func appendPeerAttachTranscriptBytes(dst []byte, data []byte) []byte {
dst = binary.BigEndian.AppendUint32(dst, uint32(len(data)))
return append(dst, data...)
}
func appendPeerAttachTranscriptString(dst []byte, value string) []byte {
return appendPeerAttachTranscriptBytes(dst, []byte(value))
}
func appendPeerAttachTranscriptBool(dst []byte, value bool) []byte {
if value {
return append(dst, 1)
}
return append(dst, 0)
}
func hkdfExtractSHA256(salt []byte, ikm []byte) []byte {
mac := hmac.New(sha256.New, salt)
_, _ = mac.Write(ikm)
return mac.Sum(nil)
}
func hkdfExpandSHA256(prk []byte, info []byte, size int) []byte {
if size <= 0 {
return nil
}
out := make([]byte, 0, size)
var block []byte
for counter := byte(1); len(out) < size; counter++ {
mac := hmac.New(sha256.New, prk)
if len(block) != 0 {
_, _ = mac.Write(block)
}
_, _ = mac.Write(info)
_, _ = mac.Write([]byte{counter})
block = mac.Sum(nil)
remaining := size - len(out)
if remaining > len(block) {
remaining = len(block)
}
out = append(out, block[:remaining]...)
}
return out
}
+408
View File
@@ -0,0 +1,408 @@
package notify
import "bytes"
// AuthMode describes how notify authenticates the peer during bootstrap.
type AuthMode int
const (
AuthNone AuthMode = iota
AuthPSK
AuthExternalPeer
)
// ProtectionMode describes how notify protects steady-state transport payloads.
type ProtectionMode int
const (
ProtectionManaged ProtectionMode = iota
ProtectionExternal
ProtectionNested
)
// SecurityOptions describes the high-level auth/protection policy.
//
// The current implementation still exposes dedicated helper constructors such
// as UseModernPSKClient/Server and UsePSKOverExternalTransportClient/Server.
type SecurityOptions struct {
AuthMode AuthMode
ProtectionMode ProtectionMode
SharedSecret []byte
RequireForwardSecrecy bool
}
func authModeName(mode AuthMode) string {
switch mode {
case AuthNone:
return "none"
case AuthPSK:
return "psk"
case AuthExternalPeer:
return "external-peer"
default:
return "unknown"
}
}
func protectionModeName(mode ProtectionMode) string {
switch mode {
case ProtectionManaged:
return "managed"
case ProtectionExternal:
return "external"
case ProtectionNested:
return "nested"
default:
return "unknown"
}
}
type transportProtectionProfile struct {
mode ProtectionMode
msgEn func([]byte, []byte) []byte
msgDe func([]byte, []byte) []byte
fastStreamEncode transportFastStreamEncoder
fastBulkEncode transportFastBulkEncoder
fastPlainEncode transportFastPlainEncoder
runtime *modernPSKCodecRuntime
secretKey []byte
keyMode string
sessionID []byte
forwardSecrecy bool
forwardSecrecyFallback bool
}
func cloneTransportProtectionKey(src []byte) []byte {
if len(src) == 0 {
return nil
}
return bytes.Clone(src)
}
func newTransportProtectionProfile(mode ProtectionMode, bundle modernPSKTransportBundle, runtime *modernPSKCodecRuntime, secretKey []byte) transportProtectionProfile {
return transportProtectionProfile{
mode: mode,
msgEn: bundle.msgEn,
msgDe: bundle.msgDe,
fastStreamEncode: bundle.fastStreamEncode,
fastBulkEncode: bundle.fastBulkEncode,
fastPlainEncode: bundle.fastPlainEncode,
runtime: runtime,
secretKey: cloneTransportProtectionKey(secretKey),
keyMode: defaultTransportKeyMode(mode, secretKey),
}
}
func defaultTransportKeyMode(mode ProtectionMode, secretKey []byte) string {
if len(secretKey) == 0 {
return ""
}
switch mode {
case ProtectionManaged, ProtectionNested:
return peerAttachKeyModeStatic
case ProtectionExternal:
return transportKeyModeExternal
default:
return ""
}
}
func cloneTransportSessionID(src []byte) []byte {
if len(src) == 0 {
return nil
}
return bytes.Clone(src)
}
func (p transportProtectionProfile) clone() transportProtectionProfile {
p.secretKey = cloneTransportProtectionKey(p.secretKey)
p.sessionID = cloneTransportSessionID(p.sessionID)
return p
}
func (p transportProtectionProfile) withForwardSecrecyFallback(fallback bool) transportProtectionProfile {
p = p.clone()
p.forwardSecrecy = false
p.forwardSecrecyFallback = fallback
p.sessionID = nil
return p
}
func passthroughTransportCodec(_ []byte, payload []byte) []byte {
return payload
}
func passthroughFastPlainEncode(_ []byte, plainLen int, fill func([]byte) error) ([]byte, error) {
if plainLen < 0 {
return nil, errTransportPayloadEncryptFailed
}
buf := make([]byte, plainLen)
if fill != nil {
if err := fill(buf); err != nil {
return nil, err
}
}
return buf, nil
}
func buildExternalTransportBundle() modernPSKTransportBundle {
fastPlainEncode := passthroughFastPlainEncode
fastStreamEncode := func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
return encodeStreamFastFramePayloadFast(fastPlainEncode, secretKey, streamFastDataFrame{
DataID: dataID,
Seq: seq,
Payload: payload,
})
}
fastBulkEncode := func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
return encodeBulkFastFramePayloadFast(fastPlainEncode, secretKey, bulkFastFrame{
Type: bulkFastPayloadTypeData,
DataID: dataID,
Seq: seq,
Payload: payload,
})
}
return modernPSKTransportBundle{
msgEn: passthroughTransportCodec,
msgDe: passthroughTransportCodec,
fastStreamEncode: fastStreamEncode,
fastBulkEncode: fastBulkEncode,
fastPlainEncode: fastPlainEncode,
}
}
func defaultTransportProtectionProfile() transportProtectionProfile {
return newTransportProtectionProfile(ProtectionManaged, defaultModernPSKTransportBundle(), nil, nil)
}
func (c *ClientCommon) clientTransportProtectionSnapshot() transportProtectionProfile {
if c == nil {
return transportProtectionProfile{}
}
if state := c.transportProtection.Load(); state != nil {
return *state
}
return transportProtectionProfile{
mode: ProtectionManaged,
msgEn: c.msgEn,
msgDe: c.msgDe,
fastStreamEncode: c.fastStreamEncode,
fastBulkEncode: c.fastBulkEncode,
fastPlainEncode: c.fastPlainEncode,
runtime: c.modernPSKRuntime,
secretKey: c.SecretKey,
keyMode: defaultTransportKeyMode(ProtectionManaged, c.SecretKey),
}
}
func (c *ClientCommon) setClientTransportProtectionProfile(profile transportProtectionProfile) {
if c == nil {
return
}
profile.secretKey = cloneTransportProtectionKey(profile.secretKey)
profile.sessionID = cloneTransportSessionID(profile.sessionID)
c.msgEn = profile.msgEn
c.msgDe = profile.msgDe
c.fastStreamEncode = profile.fastStreamEncode
c.fastBulkEncode = profile.fastBulkEncode
c.fastPlainEncode = profile.fastPlainEncode
c.modernPSKRuntime = profile.runtime
c.SecretKey = profile.secretKey
c.transportProtection.Store(&profile)
}
func (c *ClientCommon) clearClientSecurityProfiles() {
if c == nil {
return
}
c.securityConfigured = false
c.securityAuthMode = AuthNone
c.securityProtectionMode = ProtectionManaged
c.securityBootstrap = transportProtectionProfile{}
c.securitySteady = transportProtectionProfile{}
c.securitySteadyNegotiated = transportProtectionProfile{}
c.securityRequireForwardSecrecy = false
c.peerAttachAuthenticated = false
c.peerAttachAuthFallback = false
c.peerAttachAt = 0
}
func (c *ClientCommon) configureClientSecurityProfiles(authMode AuthMode, protectionMode ProtectionMode, bootstrap transportProtectionProfile, steady transportProtectionProfile, requireForwardSecrecy bool) {
if c == nil {
return
}
c.securityConfigured = true
c.securityAuthMode = authMode
c.securityProtectionMode = protectionMode
c.securityBootstrap = bootstrap.clone()
c.securitySteady = steady.clone()
c.securitySteadyNegotiated = steady.clone()
c.securityRequireForwardSecrecy = requireForwardSecrecy
c.setClientTransportProtectionProfile(bootstrap)
c.securityReadyCheck = len(bootstrap.secretKey) != 0
c.skipKeyExchange = true
}
func (c *ClientCommon) activateClientBootstrapTransportProtection() {
if c == nil || !c.securityConfigured {
return
}
c.resetClientNegotiatedSteadyTransportProtection()
c.setClientTransportProtectionProfile(c.securityBootstrap)
}
func (c *ClientCommon) activateClientSteadyTransportProtection() {
if c == nil || !c.securityConfigured {
return
}
c.setClientTransportProtectionProfile(c.securitySteadyNegotiated)
}
func (c *ClientCommon) resetClientNegotiatedSteadyTransportProtection() {
if c == nil {
return
}
c.securitySteadyNegotiated = c.securitySteady.clone().withForwardSecrecyFallback(false)
}
func (c *ClientCommon) setClientNegotiatedSteadyTransportProtection(profile transportProtectionProfile) {
if c == nil {
return
}
c.securitySteadyNegotiated = profile.clone()
}
func (c *ClientCommon) clientSupportsForwardSecrecy() bool {
if c == nil || !c.securityConfigured {
return false
}
if c.securityAuthMode != AuthPSK {
return false
}
if c.securityProtectionMode == ProtectionExternal {
return false
}
return len(c.securityBootstrap.secretKey) != 0
}
func (c *ClientCommon) clientRequiresForwardSecrecy() bool {
if c == nil {
return false
}
return c.securityRequireForwardSecrecy
}
func (s *ServerCommon) serverSupportsForwardSecrecy() bool {
if s == nil || !s.securityConfigured {
return false
}
if s.securityAuthMode != AuthPSK {
return false
}
if s.securityProtectionMode == ProtectionExternal {
return false
}
return len(s.securityBootstrap.secretKey) != 0
}
func (s *ServerCommon) serverRequiresForwardSecrecy() bool {
if s == nil {
return false
}
return s.securityRequireForwardSecrecy
}
func (s *ServerCommon) setServerDefaultTransportProtectionProfile(profile transportProtectionProfile) {
if s == nil {
return
}
profile.secretKey = cloneTransportProtectionKey(profile.secretKey)
profile.sessionID = cloneTransportSessionID(profile.sessionID)
s.defaultMsgEn = profile.msgEn
s.defaultMsgDe = profile.msgDe
s.defaultFastStreamEncode = profile.fastStreamEncode
s.defaultFastBulkEncode = profile.fastBulkEncode
s.defaultFastPlainEncode = profile.fastPlainEncode
s.defaultModernPSKRuntime = profile.runtime
s.SecretKey = profile.secretKey
}
func (s *ServerCommon) clearServerSecurityProfiles() {
if s == nil {
return
}
s.securityConfigured = false
s.securityAuthMode = AuthNone
s.securityProtectionMode = ProtectionManaged
s.securityBootstrap = transportProtectionProfile{}
s.securitySteady = transportProtectionProfile{}
s.securityRequireForwardSecrecy = false
}
func (s *ServerCommon) configureServerSecurityProfiles(authMode AuthMode, protectionMode ProtectionMode, bootstrap transportProtectionProfile, steady transportProtectionProfile, requireForwardSecrecy bool) {
if s == nil {
return
}
s.securityConfigured = true
s.securityAuthMode = authMode
s.securityProtectionMode = protectionMode
s.securityBootstrap = bootstrap.clone()
s.securitySteady = steady.clone()
s.securityRequireForwardSecrecy = requireForwardSecrecy
s.setServerDefaultTransportProtectionProfile(bootstrap)
s.securityReadyCheck = len(bootstrap.secretKey) != 0
}
func (s *ServerCommon) applyLogicalSteadyTransportProtection(logical *LogicalConn) {
if s == nil || logical == nil || !s.securityConfigured {
return
}
logical.applyTransportProtectionProfile(s.securitySteady)
}
func transportProtectionProfileFromAttachmentState(state *clientConnAttachmentState) transportProtectionProfile {
if state == nil {
return transportProtectionProfile{}
}
return transportProtectionProfile{
mode: state.protectionMode,
msgEn: state.msgEn,
msgDe: state.msgDe,
fastStreamEncode: state.fastStreamEncode,
fastBulkEncode: state.fastBulkEncode,
fastPlainEncode: state.fastPlainEncode,
runtime: state.modernPSKRuntime,
secretKey: cloneTransportProtectionKey(state.secretKey),
keyMode: state.keyMode,
sessionID: cloneTransportSessionID(state.sessionID),
forwardSecrecy: state.forwardSecrecy,
forwardSecrecyFallback: state.forwardSecrecyFallback,
}
}
func (c *LogicalConn) transportProtectionProfileSnapshot() transportProtectionProfile {
if c == nil {
return transportProtectionProfile{}
}
return transportProtectionProfileFromAttachmentState(c.attachmentStateSnapshot())
}
func (c *LogicalConn) applyTransportProtectionProfile(profile transportProtectionProfile) {
if c == nil {
return
}
c.updateAttachmentState(func(state *clientConnAttachmentState) {
state.protectionMode = profile.mode
state.msgEn = profile.msgEn
state.msgDe = profile.msgDe
state.fastStreamEncode = profile.fastStreamEncode
state.fastBulkEncode = profile.fastBulkEncode
state.fastPlainEncode = profile.fastPlainEncode
state.modernPSKRuntime = profile.runtime
state.secretKey = cloneTransportProtectionKey(profile.secretKey)
state.keyMode = profile.keyMode
state.sessionID = cloneTransportSessionID(profile.sessionID)
state.forwardSecrecy = profile.forwardSecrecy
state.forwardSecrecyFallback = profile.forwardSecrecyFallback
})
}
+313 -51
View File
@@ -15,9 +15,10 @@ import (
) )
var ( var (
errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty") errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty")
errModernPSKPayload = errors.New("invalid modern psk payload") errModernPSKPayload = errors.New("invalid modern psk payload")
errModernPSKRequired = errors.New("modern psk is required: call UseModernPSKClient/UseModernPSKServer or set a transport key before Connect/Listen") errModernPSKRequired = errors.New("transport security is required: call UseModernPSKClient/UseModernPSKServer, UsePSKOverExternalTransportClient/Server, or set a transport key before Connect/Listen")
errModernPSKForwardSecrecyUnsupported = errors.New("forward secrecy is unsupported for external transport protection")
) )
var ( var (
@@ -40,14 +41,17 @@ type modernPSKTransportBundle struct {
fastPlainEncode transportFastPlainEncoder fastPlainEncode transportFastPlainEncoder
} }
var modernPSKPayloadPool sync.Pool
// ModernPSKOptions configures the modern PSK transport profile. // ModernPSKOptions configures the modern PSK transport profile.
// //
// The current profile derives a 32-byte transport key with Argon2id and uses // The current profile derives a 32-byte transport key with Argon2id and uses
// AES-GCM with a per-codec nonce prefix plus a per-message counter. // AES-GCM with a per-codec nonce prefix plus a per-message counter.
type ModernPSKOptions struct { type ModernPSKOptions struct {
Salt []byte Salt []byte
AAD []byte AAD []byte
Argon2Params starcrypto.Argon2Params Argon2Params starcrypto.Argon2Params
RequireForwardSecrecy bool
} }
// DefaultModernPSKOptions returns the recommended settings for the current // DefaultModernPSKOptions returns the recommended settings for the current
@@ -76,19 +80,17 @@ func defaultModernPSKTransportBundle() modernPSKTransportBundle {
// Argon2id, and switches message protection to AES-GCM. Configure it before // Argon2id, and switches message protection to AES-GCM. Configure it before
// calling Connect/ConnectTimeout. // calling Connect/ConnectTimeout.
func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error { func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
key, aad, err := deriveModernPSKKey(sharedSecret, opts) managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
if err != nil { if err != nil {
return err return err
} }
transport := buildModernPSKTransportBundle(aad)
c.SetSecretKey(key)
c.SetMsgEn(transport.msgEn)
c.SetMsgDe(transport.msgDe)
if client, ok := c.(*ClientCommon); ok { if client, ok := c.(*ClientCommon); ok {
client.fastStreamEncode = transport.fastStreamEncode client.configureClientSecurityProfiles(AuthPSK, ProtectionManaged, managed, managed, opts != nil && opts.RequireForwardSecrecy)
client.fastBulkEncode = transport.fastBulkEncode return nil
client.fastPlainEncode = transport.fastPlainEncode
} }
c.SetSecretKey(managed.secretKey)
c.SetMsgEn(managed.msgEn)
c.SetMsgDe(managed.msgDe)
c.SetSkipExchangeKey(true) c.SetSkipExchangeKey(true)
return nil return nil
} }
@@ -99,19 +101,95 @@ func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) e
// It derives a transport key with Argon2id and switches message protection to // It derives a transport key with Argon2id and switches message protection to
// AES-GCM. Configure it before calling Listen. // AES-GCM. Configure it before calling Listen.
func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error { func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
key, aad, err := deriveModernPSKKey(sharedSecret, opts) managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
if err != nil { if err != nil {
return err return err
} }
transport := buildModernPSKTransportBundle(aad)
s.SetSecretKey(key)
s.SetDefaultCommEncode(transport.msgEn)
s.SetDefaultCommDecode(transport.msgDe)
if server, ok := s.(*ServerCommon); ok { if server, ok := s.(*ServerCommon); ok {
server.defaultFastStreamEncode = transport.fastStreamEncode server.configureServerSecurityProfiles(AuthPSK, ProtectionManaged, managed, managed, opts != nil && opts.RequireForwardSecrecy)
server.defaultFastBulkEncode = transport.fastBulkEncode return nil
server.defaultFastPlainEncode = transport.fastPlainEncode
} }
s.SetSecretKey(managed.secretKey)
s.SetDefaultCommEncode(managed.msgEn)
s.SetDefaultCommDecode(managed.msgDe)
return nil
}
// UsePSKOverExternalTransportClient authenticates bootstrap with PSK and then
// trusts the external channel for steady-state payload protection.
func UsePSKOverExternalTransportClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
if opts != nil && opts.RequireForwardSecrecy {
return errModernPSKForwardSecrecyUnsupported
}
bootstrap, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
if err != nil {
return err
}
steady := buildExternalProtectionProfile(bootstrap.secretKey)
if client, ok := c.(*ClientCommon); ok {
client.configureClientSecurityProfiles(AuthPSK, ProtectionExternal, bootstrap, steady, false)
return nil
}
c.SetSecretKey(bootstrap.secretKey)
c.SetMsgEn(bootstrap.msgEn)
c.SetMsgDe(bootstrap.msgDe)
c.SetSkipExchangeKey(true)
return nil
}
// UsePSKOverExternalTransportServer authenticates bootstrap with PSK and then
// trusts the external channel for steady-state payload protection.
func UsePSKOverExternalTransportServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
if opts != nil && opts.RequireForwardSecrecy {
return errModernPSKForwardSecrecyUnsupported
}
bootstrap, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
if err != nil {
return err
}
steady := buildExternalProtectionProfile(bootstrap.secretKey)
if server, ok := s.(*ServerCommon); ok {
server.configureServerSecurityProfiles(AuthPSK, ProtectionExternal, bootstrap, steady, false)
return nil
}
s.SetSecretKey(bootstrap.secretKey)
s.SetDefaultCommEncode(bootstrap.msgEn)
s.SetDefaultCommDecode(bootstrap.msgDe)
return nil
}
// UseNestedSecurityClient keeps notify transport protection enabled even when
// the physical connection is already protected by an outer trusted channel.
func UseNestedSecurityClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionNested)
if err != nil {
return err
}
if client, ok := c.(*ClientCommon); ok {
client.configureClientSecurityProfiles(AuthPSK, ProtectionNested, managed, managed, opts != nil && opts.RequireForwardSecrecy)
return nil
}
c.SetSecretKey(managed.secretKey)
c.SetMsgEn(managed.msgEn)
c.SetMsgDe(managed.msgDe)
c.SetSkipExchangeKey(true)
return nil
}
// UseNestedSecurityServer keeps notify transport protection enabled even when
// the physical connection is already protected by an outer trusted channel.
func UseNestedSecurityServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionNested)
if err != nil {
return err
}
if server, ok := s.(*ServerCommon); ok {
server.configureServerSecurityProfiles(AuthPSK, ProtectionNested, managed, managed, opts != nil && opts.RequireForwardSecrecy)
return nil
}
s.SetSecretKey(managed.secretKey)
s.SetDefaultCommEncode(managed.msgEn)
s.SetDefaultCommDecode(managed.msgDe)
return nil return nil
} }
@@ -120,14 +198,22 @@ func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) e
// //
// It is kept only as an explicit fallback path for existing deployments. // It is kept only as an explicit fallback path for existing deployments.
func UseLegacySecurityClient(c Client) { func UseLegacySecurityClient(c Client) {
if client, ok := c.(*ClientCommon); ok {
client.clearClientSecurityProfiles()
client.setClientTransportProtectionProfile(transportProtectionProfile{
mode: ProtectionManaged,
msgEn: defaultMsgEn,
msgDe: defaultMsgDe,
secretKey: bytes.Clone(defaultAesKey),
})
client.securityReadyCheck = false
client.skipKeyExchange = false
client.handshakeRsaPubKey = bytes.Clone(defaultRsaPubKey)
return
}
c.SetSecretKey(bytes.Clone(defaultAesKey)) c.SetSecretKey(bytes.Clone(defaultAesKey))
c.SetMsgEn(defaultMsgEn) c.SetMsgEn(defaultMsgEn)
c.SetMsgDe(defaultMsgDe) c.SetMsgDe(defaultMsgDe)
if client, ok := c.(*ClientCommon); ok {
client.fastStreamEncode = nil
client.fastBulkEncode = nil
client.fastPlainEncode = nil
}
c.SetSkipExchangeKey(false) c.SetSkipExchangeKey(false)
c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey)) c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey))
} }
@@ -137,14 +223,21 @@ func UseLegacySecurityClient(c Client) {
// //
// It is kept only as an explicit fallback path for existing deployments. // It is kept only as an explicit fallback path for existing deployments.
func UseLegacySecurityServer(s Server) { func UseLegacySecurityServer(s Server) {
if server, ok := s.(*ServerCommon); ok {
server.clearServerSecurityProfiles()
server.setServerDefaultTransportProtectionProfile(transportProtectionProfile{
mode: ProtectionManaged,
msgEn: defaultMsgEn,
msgDe: defaultMsgDe,
secretKey: bytes.Clone(defaultAesKey),
})
server.securityReadyCheck = false
server.handshakeRsaKey = bytes.Clone(defaultRsaKey)
return
}
s.SetSecretKey(bytes.Clone(defaultAesKey)) s.SetSecretKey(bytes.Clone(defaultAesKey))
s.SetDefaultCommEncode(defaultMsgEn) s.SetDefaultCommEncode(defaultMsgEn)
s.SetDefaultCommDecode(defaultMsgDe) s.SetDefaultCommDecode(defaultMsgDe)
if server, ok := s.(*ServerCommon); ok {
server.defaultFastStreamEncode = nil
server.defaultFastBulkEncode = nil
server.defaultFastPlainEncode = nil
}
s.SetRsaPrivKey(bytes.Clone(defaultRsaKey)) s.SetRsaPrivKey(bytes.Clone(defaultRsaKey))
} }
@@ -160,6 +253,40 @@ func deriveModernPSKKey(sharedSecret []byte, opts *ModernPSKOptions) ([]byte, []
return key, cfg.AAD, nil return key, cfg.AAD, nil
} }
func deriveModernPSKProtectionProfile(sharedSecret []byte, opts *ModernPSKOptions, mode ProtectionMode) (transportProtectionProfile, error) {
key, aad, err := deriveModernPSKKey(sharedSecret, opts)
if err != nil {
return transportProtectionProfile{}, err
}
transport := buildModernPSKTransportBundle(aad)
runtime, err := newModernPSKCodecRuntime(key, aad)
if err != nil {
return transportProtectionProfile{}, err
}
return newTransportProtectionProfile(mode, transport, runtime, key), nil
}
func buildExternalProtectionProfile(secretKey []byte) transportProtectionProfile {
return newTransportProtectionProfile(ProtectionExternal, buildExternalTransportBundle(), nil, secretKey)
}
func deriveModernPSKSessionProtectionProfile(base transportProtectionProfile, sessionKey []byte, sessionID []byte) (transportProtectionProfile, error) {
aad := bytes.Clone(defaultModernPSKAAD)
if base.runtime != nil && len(base.runtime.aad) != 0 {
aad = bytes.Clone(base.runtime.aad)
}
runtime, err := newModernPSKCodecRuntime(sessionKey, aad)
if err != nil {
return transportProtectionProfile{}, err
}
profile := newTransportProtectionProfile(base.mode, buildModernPSKTransportBundle(aad), runtime, sessionKey)
profile.keyMode = peerAttachKeyModeECDHE
profile.sessionID = cloneTransportSessionID(sessionID)
profile.forwardSecrecy = true
profile.forwardSecrecyFallback = false
return profile, nil
}
func normalizeModernPSKOptions(opts *ModernPSKOptions) ModernPSKOptions { func normalizeModernPSKOptions(opts *ModernPSKOptions) ModernPSKOptions {
cfg := DefaultModernPSKOptions() cfg := DefaultModernPSKOptions()
if opts == nil { if opts == nil {
@@ -185,14 +312,14 @@ func buildModernPSKCodecs(aad []byte) (func([]byte, []byte) []byte, func([]byte,
func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle { func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle {
aadCopy := bytes.Clone(aad) aadCopy := bytes.Clone(aad)
cache := &modernPSKCodecCache{} cache := newModernPSKCodecCache(aadCopy)
msgEn := func(key []byte, plain []byte) []byte { msgEn := func(key []byte, plain []byte) []byte {
runtime, err := cache.runtimeForKey(key) runtime, err := cache.runtimeForKey(key)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return nil return nil
} }
out, err := runtime.sealPlainPayload(aadCopy, plain) out, err := runtime.sealPlainPayload(plain)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return nil return nil
@@ -214,9 +341,7 @@ func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle {
log.Print(err) log.Print(err)
return nil return nil
} }
nonce := encrypted[len(modernPSKMagic):headerLen] plain, err := runtime.openPayload(encrypted)
ciphertext := encrypted[headerLen:]
plain, err := runtime.aead.Open(make([]byte, 0, len(ciphertext)), nonce, ciphertext, aadCopy)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return nil return nil
@@ -228,21 +353,21 @@ func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return runtime.sealStreamFastPayload(aadCopy, dataID, seq, payload) return runtime.sealStreamFastPayload(dataID, seq, payload)
} }
fastBulkEncode := func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { fastBulkEncode := func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
runtime, err := cache.runtimeForKey(key) runtime, err := cache.runtimeForKey(key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return runtime.sealBulkFastPayload(aadCopy, dataID, seq, payload) return runtime.sealBulkFastPayload(dataID, seq, payload)
} }
fastPlainEncode := func(key []byte, plainLen int, fill func([]byte) error) ([]byte, error) { fastPlainEncode := func(key []byte, plainLen int, fill func([]byte) error) ([]byte, error) {
runtime, err := cache.runtimeForKey(key) runtime, err := cache.runtimeForKey(key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return runtime.sealFilledPayload(aadCopy, plainLen, fill) return runtime.sealFilledPayload(plainLen, fill)
} }
return modernPSKTransportBundle{ return modernPSKTransportBundle{
msgEn: msgEn, msgEn: msgEn,
@@ -269,16 +394,23 @@ func (s *ServerCommon) validateSecurityConfiguration() error {
type modernPSKCodecCache struct { type modernPSKCodecCache struct {
mu sync.Mutex mu sync.Mutex
aad []byte
key []byte key []byte
runtime *modernPSKCodecRuntime runtime *modernPSKCodecRuntime
} }
type modernPSKCodecRuntime struct { type modernPSKCodecRuntime struct {
aead cipher.AEAD aead cipher.AEAD
key []byte
aad []byte
prefix [modernPSKNonceSize - 8]byte prefix [modernPSKNonceSize - 8]byte
seq atomic.Uint64 seq atomic.Uint64
} }
func newModernPSKCodecCache(aad []byte) *modernPSKCodecCache {
return &modernPSKCodecCache{aad: bytes.Clone(aad)}
}
func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime, error) { func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime, error) {
if c == nil { if c == nil {
return nil, errModernPSKSecretEmpty return nil, errModernPSKSecretEmpty
@@ -288,7 +420,7 @@ func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime,
if c.runtime != nil && bytes.Equal(c.key, key) { if c.runtime != nil && bytes.Equal(c.key, key) {
return c.runtime, nil return c.runtime, nil
} }
runtime, err := newModernPSKCodecRuntime(key) runtime, err := newModernPSKCodecRuntime(key, c.aad)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -297,7 +429,7 @@ func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime,
return runtime, nil return runtime, nil
} }
func newModernPSKCodecRuntime(key []byte) (*modernPSKCodecRuntime, error) { func newModernPSKCodecRuntime(key []byte, aad []byte) (*modernPSKCodecRuntime, error) {
if len(key) == 0 { if len(key) == 0 {
return nil, errModernPSKSecretEmpty return nil, errModernPSKSecretEmpty
} }
@@ -311,6 +443,8 @@ func newModernPSKCodecRuntime(key []byte) (*modernPSKCodecRuntime, error) {
} }
runtime := &modernPSKCodecRuntime{ runtime := &modernPSKCodecRuntime{
aead: aead, aead: aead,
key: bytes.Clone(key),
aad: bytes.Clone(aad),
} }
if _, err := cryptorand.Read(runtime.prefix[:]); err != nil { if _, err := cryptorand.Read(runtime.prefix[:]); err != nil {
return nil, err return nil, err
@@ -318,6 +452,13 @@ func newModernPSKCodecRuntime(key []byte) (*modernPSKCodecRuntime, error) {
return runtime, nil return runtime, nil
} }
func (r *modernPSKCodecRuntime) fork() (*modernPSKCodecRuntime, error) {
if r == nil {
return nil, errModernPSKSecretEmpty
}
return newModernPSKCodecRuntime(r.key, r.aad)
}
func (r *modernPSKCodecRuntime) nextNonce() [modernPSKNonceSize]byte { func (r *modernPSKCodecRuntime) nextNonce() [modernPSKNonceSize]byte {
var nonce [modernPSKNonceSize]byte var nonce [modernPSKNonceSize]byte
if r == nil { if r == nil {
@@ -328,8 +469,8 @@ func (r *modernPSKCodecRuntime) nextNonce() [modernPSKNonceSize]byte {
return nonce return nonce
} }
func (r *modernPSKCodecRuntime) sealStreamFastPayload(aad []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { func (r *modernPSKCodecRuntime) sealStreamFastPayload(dataID uint64, seq uint64, payload []byte) ([]byte, error) {
return r.sealFilledPayload(aad, streamFastPayloadHeaderLen+len(payload), func(frame []byte) error { return r.sealFilledPayload(streamFastPayloadHeaderLen+len(payload), func(frame []byte) error {
if err := encodeStreamFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil { if err := encodeStreamFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil {
return err return err
} }
@@ -338,11 +479,11 @@ func (r *modernPSKCodecRuntime) sealStreamFastPayload(aad []byte, dataID uint64,
}) })
} }
func (r *modernPSKCodecRuntime) sealBulkFastPayload(aad []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { func (r *modernPSKCodecRuntime) sealBulkFastPayload(dataID uint64, seq uint64, payload []byte) ([]byte, error) {
if r == nil { if r == nil {
return nil, errTransportPayloadEncryptFailed return nil, errTransportPayloadEncryptFailed
} }
return r.sealFilledPayload(aad, bulkFastPayloadHeaderLen+len(payload), func(frame []byte) error { return r.sealFilledPayload(bulkFastPayloadHeaderLen+len(payload), func(frame []byte) error {
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil { if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil {
return err return err
} }
@@ -351,14 +492,14 @@ func (r *modernPSKCodecRuntime) sealBulkFastPayload(aad []byte, dataID uint64, s
}) })
} }
func (r *modernPSKCodecRuntime) sealPlainPayload(aad []byte, plain []byte) ([]byte, error) { func (r *modernPSKCodecRuntime) sealPlainPayload(plain []byte) ([]byte, error) {
return r.sealFilledPayload(aad, len(plain), func(dst []byte) error { return r.sealFilledPayload(len(plain), func(dst []byte) error {
copy(dst, plain) copy(dst, plain)
return nil return nil
}) })
} }
func (r *modernPSKCodecRuntime) sealFilledPayload(aad []byte, plainLen int, fill func([]byte) error) ([]byte, error) { func (r *modernPSKCodecRuntime) sealFilledPayload(plainLen int, fill func([]byte) error) ([]byte, error) {
if r == nil { if r == nil {
return nil, errTransportPayloadEncryptFailed return nil, errTransportPayloadEncryptFailed
} }
@@ -368,6 +509,35 @@ func (r *modernPSKCodecRuntime) sealFilledPayload(aad []byte, plainLen int, fill
nonce := r.nextNonce() nonce := r.nextNonce()
headerLen := len(modernPSKMagic) + modernPSKNonceSize headerLen := len(modernPSKMagic) + modernPSKNonceSize
out := make([]byte, headerLen+plainLen+r.aead.Overhead()) out := make([]byte, headerLen+plainLen+r.aead.Overhead())
sealed, err := r.sealInto(out, headerLen, nonce, plainLen, fill)
if err != nil {
return nil, err
}
return out[:headerLen+len(sealed)], nil
}
func (r *modernPSKCodecRuntime) sealFilledPayloadPooled(plainLen int, fill func([]byte) error) ([]byte, func(), error) {
if r == nil {
return nil, nil, errTransportPayloadEncryptFailed
}
if plainLen < 0 {
return nil, nil, errTransportPayloadEncryptFailed
}
nonce := r.nextNonce()
headerLen := len(modernPSKMagic) + modernPSKNonceSize
totalLen := headerLen + plainLen + r.aead.Overhead()
out := getModernPSKPayloadBuffer(totalLen)
sealed, err := r.sealInto(out, headerLen, nonce, plainLen, fill)
if err != nil {
putModernPSKPayloadBuffer(out)
return nil, nil, err
}
return out[:headerLen+len(sealed)], func() {
putModernPSKPayloadBuffer(out)
}, nil
}
func (r *modernPSKCodecRuntime) sealInto(out []byte, headerLen int, nonce [modernPSKNonceSize]byte, plainLen int, fill func([]byte) error) ([]byte, error) {
copy(out[:len(modernPSKMagic)], modernPSKMagic) copy(out[:len(modernPSKMagic)], modernPSKMagic)
copy(out[len(modernPSKMagic):headerLen], nonce[:]) copy(out[len(modernPSKMagic):headerLen], nonce[:])
frame := out[headerLen : headerLen+plainLen] frame := out[headerLen : headerLen+plainLen]
@@ -376,6 +546,98 @@ func (r *modernPSKCodecRuntime) sealFilledPayload(aad []byte, plainLen int, fill
return nil, err return nil, err
} }
} }
sealed := r.aead.Seal(frame[:0], nonce[:], frame, aad) return r.aead.Seal(frame[:0], nonce[:], frame, r.aad), nil
return out[:headerLen+len(sealed)], nil }
func (r *modernPSKCodecRuntime) openPayload(encrypted []byte) ([]byte, error) {
if r == nil {
return nil, errTransportPayloadDecryptFailed
}
headerLen := len(modernPSKMagic) + modernPSKNonceSize
if len(encrypted) < headerLen {
return nil, errModernPSKPayload
}
if !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) {
return nil, errModernPSKPayload
}
nonce := encrypted[len(modernPSKMagic):headerLen]
ciphertext := encrypted[headerLen:]
return r.aead.Open(make([]byte, 0, len(ciphertext)), nonce, ciphertext, r.aad)
}
func (r *modernPSKCodecRuntime) openPayloadPooled(encrypted []byte, release func()) ([]byte, func(), error) {
if r == nil {
if release != nil {
release()
}
return nil, nil, errTransportPayloadDecryptFailed
}
headerLen := len(modernPSKMagic) + modernPSKNonceSize
if len(encrypted) < headerLen {
if release != nil {
release()
}
return nil, nil, errModernPSKPayload
}
if !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) {
if release != nil {
release()
}
return nil, nil, errModernPSKPayload
}
nonce := encrypted[len(modernPSKMagic):headerLen]
ciphertext := encrypted[headerLen:]
plain, err := r.aead.Open(ciphertext[:0], nonce, ciphertext, r.aad)
if err != nil {
if release != nil {
release()
}
return nil, nil, err
}
return plain, release, nil
}
func (r *modernPSKCodecRuntime) openPayloadOwnedPooled(encrypted []byte) ([]byte, func(), error) {
if r == nil {
return nil, nil, errTransportPayloadDecryptFailed
}
headerLen := len(modernPSKMagic) + modernPSKNonceSize
if len(encrypted) < headerLen {
return nil, nil, errModernPSKPayload
}
if !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) {
return nil, nil, errModernPSKPayload
}
nonce := encrypted[len(modernPSKMagic):headerLen]
ciphertext := encrypted[headerLen:]
plainLen := len(ciphertext) - r.aead.Overhead()
if plainLen < 0 {
return nil, nil, errModernPSKPayload
}
out := getModernPSKPayloadBuffer(plainLen)
plain, err := r.aead.Open(out[:0], nonce, ciphertext, r.aad)
if err != nil {
putModernPSKPayloadBuffer(out)
return nil, nil, err
}
return plain, func() {
putModernPSKPayloadBuffer(out)
}, nil
}
func getModernPSKPayloadBuffer(size int) []byte {
if size <= 0 {
return nil
}
if pooled, ok := modernPSKPayloadPool.Get().([]byte); ok && cap(pooled) >= size {
return pooled[:size]
}
return make([]byte, size)
}
func putModernPSKPayloadBuffer(buf []byte) {
if cap(buf) == 0 || cap(buf) > 32*1024*1024 {
return
}
modernPSKPayloadPool.Put(buf[:0])
} }
+202
View File
@@ -3,8 +3,10 @@ package notify
import ( import (
"bytes" "bytes"
"errors" "errors"
"net"
"reflect" "reflect"
"testing" "testing"
"time"
"b612.me/starcrypto" "b612.me/starcrypto"
) )
@@ -83,6 +85,12 @@ func TestDefaultConstructorsUseModernTransportAfterSetSecretKey(t *testing.T) {
sharedKey := []byte("0123456789abcdef0123456789abcdef") sharedKey := []byte("0123456789abcdef0123456789abcdef")
client.SetSecretKey(sharedKey) client.SetSecretKey(sharedKey)
server.SetSecretKey(sharedKey) server.SetSecretKey(sharedKey)
if client.modernPSKRuntime == nil {
t.Fatal("client modernPSKRuntime should be installed after SetSecretKey")
}
if server.defaultModernPSKRuntime == nil {
t.Fatal("server defaultModernPSKRuntime should be installed after SetSecretKey")
}
plain := []byte("notify default modern transport") plain := []byte("notify default modern transport")
wire := client.msgEn(client.SecretKey, plain) wire := client.msgEn(client.SecretKey, plain)
@@ -92,6 +100,29 @@ func TestDefaultConstructorsUseModernTransportAfterSetSecretKey(t *testing.T) {
} }
} }
func TestCustomCodecOverridesClearModernRuntime(t *testing.T) {
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
sharedKey := []byte("0123456789abcdef0123456789abcdef")
client.SetSecretKey(sharedKey)
server.SetSecretKey(sharedKey)
if client.modernPSKRuntime == nil || server.defaultModernPSKRuntime == nil {
t.Fatal("modern runtimes should be installed before override")
}
client.SetMsgEn(defaultMsgEn)
client.SetMsgDe(defaultMsgDe)
server.SetDefaultCommEncode(defaultMsgEn)
server.SetDefaultCommDecode(defaultMsgDe)
if client.modernPSKRuntime != nil {
t.Fatal("client modernPSKRuntime should be cleared after custom codec override")
}
if server.defaultModernPSKRuntime != nil {
t.Fatal("server defaultModernPSKRuntime should be cleared after custom codec override")
}
}
func TestDefaultConstructorsDecodeSignalEnvelopeWithModernTransport(t *testing.T) { func TestDefaultConstructorsDecodeSignalEnvelopeWithModernTransport(t *testing.T) {
client := NewClient().(*ClientCommon) client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
@@ -178,6 +209,17 @@ func TestUseModernPSKRejectsEmptySecret(t *testing.T) {
} }
} }
func TestUsePSKOverExternalTransportRejectsForwardSecrecyRequirement(t *testing.T) {
opts := testModernPSKOptions()
opts.RequireForwardSecrecy = true
if err := UsePSKOverExternalTransportClient(NewClient(), []byte("secret"), opts); !errors.Is(err, errModernPSKForwardSecrecyUnsupported) {
t.Fatalf("UsePSKOverExternalTransportClient error = %v, want %v", err, errModernPSKForwardSecrecyUnsupported)
}
if err := UsePSKOverExternalTransportServer(NewServer(), []byte("secret"), opts); !errors.Is(err, errModernPSKForwardSecrecyUnsupported) {
t.Fatalf("UsePSKOverExternalTransportServer error = %v, want %v", err, errModernPSKForwardSecrecyUnsupported)
}
}
func TestModernPSKCodecRejectsLegacyPayload(t *testing.T) { func TestModernPSKCodecRejectsLegacyPayload(t *testing.T) {
key, aad, err := deriveModernPSKKey([]byte("notify-legacy-reject"), testModernPSKOptions()) key, aad, err := deriveModernPSKKey([]byte("notify-legacy-reject"), testModernPSKOptions())
if err != nil { if err != nil {
@@ -286,6 +328,166 @@ func TestModernPSKFastBulkEncodeRoundTrip(t *testing.T) {
} }
} }
func TestExternalTransportFastStreamEncodeRoundTrip(t *testing.T) {
transport := buildExternalTransportBundle()
wire, err := transport.fastStreamEncode(nil, 23, 7, []byte("payload"))
if err != nil {
t.Fatalf("fastStreamEncode failed: %v", err)
}
plain := transport.msgDe(nil, wire)
frame, matched, err := decodeStreamFastDataFrame(plain)
if err != nil {
t.Fatalf("decodeStreamFastDataFrame failed: %v", err)
}
if !matched {
t.Fatal("decodeStreamFastDataFrame should match fast payload")
}
if frame.DataID != 23 || frame.Seq != 7 || !bytes.Equal(frame.Payload, []byte("payload")) {
t.Fatalf("frame mismatch: %+v", frame)
}
}
func TestExternalTransportFastBulkEncodeRoundTrip(t *testing.T) {
transport := buildExternalTransportBundle()
wire, err := transport.fastBulkEncode(nil, 41, 9, []byte("payload"))
if err != nil {
t.Fatalf("fastBulkEncode failed: %v", err)
}
plain := transport.msgDe(nil, wire)
frame, matched, err := decodeBulkFastDataFrame(plain)
if err != nil {
t.Fatalf("decodeBulkFastDataFrame failed: %v", err)
}
if !matched {
t.Fatal("decodeBulkFastDataFrame should match fast payload")
}
if frame.DataID != 41 || frame.Seq != 9 || !bytes.Equal(frame.Payload, []byte("payload")) {
t.Fatalf("frame mismatch: %+v", frame)
}
}
func TestDecryptTransportPayloadCodecPooledExternalDefersRelease(t *testing.T) {
payload := []byte("payload")
released := false
plain, release, err := decryptTransportPayloadCodecPooled(ProtectionExternal, nil, passthroughTransportCodec, nil, payload, func() {
released = true
})
if err != nil {
t.Fatalf("decryptTransportPayloadCodecPooled failed: %v", err)
}
if released {
t.Fatal("release should not run before caller is done")
}
if !bytes.Equal(plain, payload) {
t.Fatalf("plain mismatch: got %q want %q", plain, payload)
}
if release == nil {
t.Fatal("release callback should be preserved for external mode")
}
release()
if !released {
t.Fatal("release callback should run when caller finishes")
}
}
func TestUsePSKOverExternalTransportConnectByConnSwitchesToExternal(t *testing.T) {
client := NewClient().(*ClientCommon)
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
if err := UsePSKOverExternalTransportServer(server, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
}
server.SetLink("external-roundtrip", func(msg *Message) {
_ = msg.Reply([]byte("ack:" + string(msg.Value)))
})
})
if err := UsePSKOverExternalTransportClient(client, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
}
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionManaged {
t.Fatalf("client bootstrap mode = %v, want %v", got, ProtectionManaged)
}
left, right := net.Pipe()
defer right.Close()
bootstrapPeerAttachConnForTest(t, server, right)
if err := client.ConnectByConn(left); err != nil {
t.Fatalf("ConnectByConn failed: %v", err)
}
defer func() {
client.setByeFromServer(true)
_ = client.Stop()
}()
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionExternal {
t.Fatalf("client steady mode = %v, want %v", got, ProtectionExternal)
}
reply, err := client.SendWait("external-roundtrip", []byte("ping"), time.Second)
if err != nil {
t.Fatalf("SendWait failed: %v", err)
}
if got, want := string(reply.Value), "ack:ping"; got != want {
t.Fatalf("reply mismatch: got %q want %q", got, want)
}
list := server.GetLogicalConnList()
if len(list) != 1 {
t.Fatalf("logical conn count = %d, want 1", len(list))
}
if got := list[0].protectionModeSnapshot(); got != ProtectionExternal {
t.Fatalf("server steady mode = %v, want %v", got, ProtectionExternal)
}
}
func TestUseNestedSecurityConnectByConnKeepsNestedMode(t *testing.T) {
client := NewClient().(*ClientCommon)
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
if err := UseNestedSecurityServer(server, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
t.Fatalf("UseNestedSecurityServer failed: %v", err)
}
server.SetLink("nested-roundtrip", func(msg *Message) {
_ = msg.Reply([]byte("ack:" + string(msg.Value)))
})
})
if err := UseNestedSecurityClient(client, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
t.Fatalf("UseNestedSecurityClient failed: %v", err)
}
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionNested {
t.Fatalf("client bootstrap mode = %v, want %v", got, ProtectionNested)
}
left, right := net.Pipe()
defer right.Close()
bootstrapPeerAttachConnForTest(t, server, right)
if err := client.ConnectByConn(left); err != nil {
t.Fatalf("ConnectByConn failed: %v", err)
}
defer func() {
client.setByeFromServer(true)
_ = client.Stop()
}()
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionNested {
t.Fatalf("client steady mode = %v, want %v", got, ProtectionNested)
}
reply, err := client.SendWait("nested-roundtrip", []byte("ping"), time.Second)
if err != nil {
t.Fatalf("SendWait failed: %v", err)
}
if got, want := string(reply.Value), "ack:ping"; got != want {
t.Fatalf("reply mismatch: got %q want %q", got, want)
}
list := server.GetLogicalConnList()
if len(list) != 1 {
t.Fatalf("logical conn count = %d, want 1", len(list))
}
if got := list[0].protectionModeSnapshot(); got != ProtectionNested {
t.Fatalf("server steady mode = %v, want %v", got, ProtectionNested)
}
}
func TestUseLegacySecurityRoundTrip(t *testing.T) { func TestUseLegacySecurityRoundTrip(t *testing.T) {
client := NewClient() client := NewClient()
server := NewServer() server := NewServer()
+63 -42
View File
@@ -10,48 +10,65 @@ import (
) )
type ServerCommon struct { type ServerCommon struct {
msgID uint64 msgID uint64
alive atomic.Value alive atomic.Value
status Status status Status
sessionOwnerState atomic.Int32 sessionOwnerState atomic.Int32
sessionRuntime atomic.Pointer[serverSessionRuntime] sessionRuntime atomic.Pointer[serverSessionRuntime]
listener net.Listener listener net.Listener
udpListener *net.UDPConn udpListener *net.UDPConn
queue *stario.StarQueue queue *stario.StarQueue
stopFn context.CancelFunc stopFn context.CancelFunc
stopCtx context.Context stopCtx context.Context
maxReadTimeout time.Duration maxReadTimeout time.Duration
maxWriteTimeout time.Duration maxWriteTimeout time.Duration
parallelNum int parallelNum int
wg stario.WaitGroup wg stario.WaitGroup
peerRegistry *serverPeerRegistry peerRegistry *serverPeerRegistry
mu sync.RWMutex mu sync.RWMutex
handshakeRsaKey []byte handshakeRsaKey []byte
SecretKey []byte SecretKey []byte
defaultMsgEn func([]byte, []byte) []byte defaultMsgEn func([]byte, []byte) []byte
defaultMsgDe func([]byte, []byte) []byte defaultMsgDe func([]byte, []byte) []byte
defaultFastStreamEncode transportFastStreamEncoder defaultFastStreamEncode transportFastStreamEncoder
defaultFastBulkEncode transportFastBulkEncoder defaultFastBulkEncode transportFastBulkEncoder
defaultFastPlainEncode transportFastPlainEncoder defaultFastPlainEncode transportFastPlainEncoder
linkFns map[string]func(message *Message) defaultModernPSKRuntime *modernPSKCodecRuntime
defaultFns func(message *Message) peerAttachSecurity atomic.Pointer[peerAttachSecurityState]
noFinSyncMsgMaxKeepSeconds int64 securityBootstrap transportProtectionProfile
maxHeartbeatLostSeconds int64 securitySteady transportProtectionProfile
sequenceDe func([]byte) (interface{}, error) securityAuthMode AuthMode
sequenceEn func(interface{}) ([]byte, error) securityProtectionMode ProtectionMode
logicalSession *logicalSessionState securityRequireForwardSecrecy bool
onFileEvent func(FileEvent) securityConfigured bool
fileEventObserver func(FileEvent) peerAttachReplay peerAttachReplayCache
fileTransferCfg fileTransferConfig peerAttachExplicitCount atomic.Int64
signalReliableCfg signalReliabilityConfig peerAttachAuthFallbackCount atomic.Int64
streamRuntime *streamRuntime peerAttachAuthRejectCount atomic.Int64
recordRuntime *recordRuntime peerAttachDowngradeRejectCount atomic.Int64
bulkRuntime *bulkRuntime peerAttachBindingRejectCount atomic.Int64
connectionRetryState *connectionRetryState linkFns map[string]func(message *Message)
detachedClientKeepSeconds int64 defaultFns func(message *Message)
securityReadyCheck bool noFinSyncMsgMaxKeepSeconds int64
showError bool maxHeartbeatLostSeconds int64
debugMode bool sequenceDe func([]byte) (interface{}, error)
sequenceEn func(interface{}) ([]byte, error)
logicalSession *logicalSessionState
onFileEvent func(FileEvent)
fileEventObserver func(FileEvent)
fileTransferCfg fileTransferConfig
signalReliableCfg signalReliabilityConfig
streamRuntime *streamRuntime
recordRuntime *recordRuntime
bulkRuntime *bulkRuntime
bulkOpenTuning BulkOpenTuning
bulkDedicatedSidecarMu sync.Mutex
bulkDedicatedSidecars map[*LogicalConn]map[uint32]*bulkDedicatedSidecar
connectionRetryState *connectionRetryState
detachedClientKeepSeconds int64
securityReadyCheck bool
showError bool
debugMode bool
} }
func NewServer() Server { func NewServer() Server {
@@ -81,12 +98,16 @@ func NewServer() Server {
server.streamRuntime = newStreamRuntime("sstrm") server.streamRuntime = newStreamRuntime("sstrm")
server.recordRuntime = newRecordRuntime() server.recordRuntime = newRecordRuntime()
server.bulkRuntime = newBulkRuntime("sblk") server.bulkRuntime = newBulkRuntime("sblk")
server.bulkOpenTuning = defaultBulkOpenTuning()
server.bulkDedicatedSidecars = make(map[*LogicalConn]map[uint32]*bulkDedicatedSidecar)
server.connectionRetryState = newConnectionRetryState() server.connectionRetryState = newConnectionRetryState()
server.onFileEvent = normalizeFileEventCallback(nil) server.onFileEvent = normalizeFileEventCallback(nil)
server.fileEventObserver = normalizeFileEventCallback(nil) server.fileEventObserver = normalizeFileEventCallback(nil)
server.defaultFns = func(message *Message) { server.defaultFns = func(message *Message) {
return return
} }
server.setServerDefaultTransportProtectionProfile(defaultTransportProtectionProfile())
server.peerAttachSecurity.Store(defaultPeerAttachSecurityState())
server.sessionRuntime.Store(newServerSessionRuntimeBase(server.stopCtx, server.stopFn)) server.sessionRuntime.Store(newServerSessionRuntimeBase(server.stopCtx, server.stopFn))
bindServerStreamControl(&server) bindServerStreamControl(&server)
bindServerBulkControl(&server) bindServerBulkControl(&server)
+182 -17
View File
@@ -1,6 +1,9 @@
package notify package notify
import "context" import (
"context"
"errors"
)
func (s *ServerCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) { func (s *ServerCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) {
runtime := s.getBulkRuntime() runtime := s.getBulkRuntime()
@@ -11,9 +14,48 @@ func (s *ServerCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) {
} }
func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn, opt BulkOpenOptions) (Bulk, error) { func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn, opt BulkOpenOptions) (Bulk, error) {
opt = normalizeBulkOpenOptions(opt)
switch opt.Mode {
case BulkOpenModeDedicated:
opt.Dedicated = true
return s.openBulkLogicalWithMode(ctx, logical, opt)
case BulkOpenModeAuto:
if err := logicalDedicatedBulkSupportError(logical); err == nil {
dedicatedOpt := opt
dedicatedOpt.Mode = BulkOpenModeDedicated
dedicatedOpt.Dedicated = true
bulk, dedicatedErr := s.openBulkLogicalWithMode(ctx, logical, dedicatedOpt)
if dedicatedErr == nil {
return bulk, nil
}
sharedOpt := opt
sharedOpt.Mode = BulkOpenModeShared
sharedOpt.Dedicated = false
sharedBulk, sharedErr := s.openBulkLogicalWithMode(ctx, logical, sharedOpt)
if sharedErr == nil {
return sharedBulk, nil
}
return nil, errors.Join(dedicatedErr, sharedErr)
}
opt.Mode = BulkOpenModeShared
opt.Dedicated = false
return s.openBulkLogicalWithMode(ctx, logical, opt)
case BulkOpenModeShared, BulkOpenModeDefault:
opt.Mode = BulkOpenModeShared
opt.Dedicated = false
return s.openBulkLogicalWithMode(ctx, logical, opt)
default:
opt.Mode = BulkOpenModeShared
opt.Dedicated = false
return s.openBulkLogicalWithMode(ctx, logical, opt)
}
}
func (s *ServerCommon) openBulkLogicalWithMode(ctx context.Context, logical *LogicalConn, opt BulkOpenOptions) (Bulk, error) {
if s == nil { if s == nil {
return nil, errBulkServerNil return nil, errBulkServerNil
} }
opt = applyBulkOpenTuningDefaults(opt, s.bulkOpenTuningSnapshot())
if logical == nil { if logical == nil {
return nil, errBulkLogicalConnNil return nil, errBulkLogicalConnNil
} }
@@ -42,17 +84,44 @@ func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn
req.AttachToken = newBulkAttachToken() req.AttachToken = newBulkAttachToken()
} }
bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, logical.CurrentTransportConn(), logical.transportGenerationSnapshot(), serverBulkCloseSender(s, logical, nil), serverBulkResetSender(s, logical, nil), serverBulkDataSender(s, logical.CurrentTransportConn()), serverBulkWriteSender(s, logical, logical.CurrentTransportConn()), serverBulkReleaseSender(s, logical, logical.CurrentTransportConn())) bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, logical.CurrentTransportConn(), logical.transportGenerationSnapshot(), serverBulkCloseSender(s, logical, nil), serverBulkResetSender(s, logical, nil), serverBulkDataSender(s, logical.CurrentTransportConn()), serverBulkWriteSender(s, logical, logical.CurrentTransportConn()), serverBulkReleaseSender(s, logical, logical.CurrentTransportConn()))
bulk.markAcceptHandled()
if err := runtime.register(scope, bulk); err != nil { if err := runtime.register(scope, bulk); err != nil {
return nil, err return nil, err
} }
s.attachServerDedicatedSidecarIfExists(logical, bulk)
resp, err := sendBulkOpenServerLogical(ctx, s, logical, req) resp, err := sendBulkOpenServerLogical(ctx, s, logical, req)
if err != nil { if err != nil {
runtime.remove(scope, bulk.ID()) bulk.markReset(err)
return nil, err
}
if resp.DataID != 0 && resp.DataID != req.DataID {
err = errBulkAlreadyExists
_, _ = sendBulkResetServerLogical(context.Background(), s, logical, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
Error: "bulk dedicated data id mismatch",
})
bulk.markReset(err)
return nil, err return nil, err
} }
if resp.TransportGeneration != 0 { if resp.TransportGeneration != 0 {
bulk.transportGeneration = resp.TransportGeneration bulk.transportGeneration = resp.TransportGeneration
} }
if resp.FastPathVersion != 0 {
bulk.fastPathVersion = normalizeBulkFastPathVersion(resp.FastPathVersion)
}
if resp.AttachToken != "" {
bulk.setDedicatedAttachToken(resp.AttachToken)
}
if err := bulk.waitAcceptReady(ctx); err != nil {
_, _ = sendBulkResetServerLogical(context.Background(), s, logical, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
Error: err.Error(),
})
bulk.markReset(err)
return nil, err
}
return bulk, nil return bulk, nil
} }
resp, err := sendBulkOpenServerLogical(ctx, s, logical, req) resp, err := sendBulkOpenServerLogical(ctx, s, logical, req)
@@ -62,6 +131,9 @@ func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn
if resp.DataID != 0 { if resp.DataID != 0 {
req.DataID = resp.DataID req.DataID = resp.DataID
} }
if resp.FastPathVersion != 0 {
req.FastPathVersion = resp.FastPathVersion
}
req.Dedicated = resp.Dedicated req.Dedicated = resp.Dedicated
if resp.AttachToken != "" { if resp.AttachToken != "" {
req.AttachToken = resp.AttachToken req.AttachToken = resp.AttachToken
@@ -71,6 +143,7 @@ func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn
} }
transport := logical.CurrentTransportConn() transport := logical.CurrentTransportConn()
bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverBulkCloseSender(s, logical, nil), serverBulkResetSender(s, logical, nil), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport)) bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverBulkCloseSender(s, logical, nil), serverBulkResetSender(s, logical, nil), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport))
bulk.markAcceptHandled()
if err := runtime.register(scope, bulk); err != nil { if err := runtime.register(scope, bulk); err != nil {
_, _ = sendBulkResetServerLogical(context.Background(), s, logical, BulkResetRequest{ _, _ = sendBulkResetServerLogical(context.Background(), s, logical, BulkResetRequest{
BulkID: req.BulkID, BulkID: req.BulkID,
@@ -79,13 +152,53 @@ func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn
}) })
return nil, err return nil, err
} }
s.attachServerDedicatedSidecarIfExists(logical, bulk)
return bulk, nil return bulk, nil
} }
func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *TransportConn, opt BulkOpenOptions) (Bulk, error) { func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *TransportConn, opt BulkOpenOptions) (Bulk, error) {
opt = normalizeBulkOpenOptions(opt)
switch opt.Mode {
case BulkOpenModeDedicated:
opt.Dedicated = true
return s.openBulkTransportWithMode(ctx, transport, opt)
case BulkOpenModeAuto:
if err := transportDedicatedBulkSupportError(transport); err == nil {
dedicatedOpt := opt
dedicatedOpt.Mode = BulkOpenModeDedicated
dedicatedOpt.Dedicated = true
bulk, dedicatedErr := s.openBulkTransportWithMode(ctx, transport, dedicatedOpt)
if dedicatedErr == nil {
return bulk, nil
}
sharedOpt := opt
sharedOpt.Mode = BulkOpenModeShared
sharedOpt.Dedicated = false
sharedBulk, sharedErr := s.openBulkTransportWithMode(ctx, transport, sharedOpt)
if sharedErr == nil {
return sharedBulk, nil
}
return nil, errors.Join(dedicatedErr, sharedErr)
}
opt.Mode = BulkOpenModeShared
opt.Dedicated = false
return s.openBulkTransportWithMode(ctx, transport, opt)
case BulkOpenModeShared, BulkOpenModeDefault:
opt.Mode = BulkOpenModeShared
opt.Dedicated = false
return s.openBulkTransportWithMode(ctx, transport, opt)
default:
opt.Mode = BulkOpenModeShared
opt.Dedicated = false
return s.openBulkTransportWithMode(ctx, transport, opt)
}
}
func (s *ServerCommon) openBulkTransportWithMode(ctx context.Context, transport *TransportConn, opt BulkOpenOptions) (Bulk, error) {
if s == nil { if s == nil {
return nil, errBulkServerNil return nil, errBulkServerNil
} }
opt = applyBulkOpenTuningDefaults(opt, s.bulkOpenTuningSnapshot())
if transport == nil { if transport == nil {
return nil, errBulkTransportNil return nil, errBulkTransportNil
} }
@@ -118,17 +231,44 @@ func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *Transpo
req.AttachToken = newBulkAttachToken() req.AttachToken = newBulkAttachToken()
} }
bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, transport.TransportGeneration(), serverBulkCloseSender(s, logical, transport), serverBulkResetSender(s, logical, transport), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport)) bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, transport.TransportGeneration(), serverBulkCloseSender(s, logical, transport), serverBulkResetSender(s, logical, transport), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport))
bulk.markAcceptHandled()
if err := runtime.register(scope, bulk); err != nil { if err := runtime.register(scope, bulk); err != nil {
return nil, err return nil, err
} }
s.attachServerDedicatedSidecarIfExists(logical, bulk)
resp, err := sendBulkOpenServerTransport(ctx, s, transport, req) resp, err := sendBulkOpenServerTransport(ctx, s, transport, req)
if err != nil { if err != nil {
runtime.remove(scope, bulk.ID()) bulk.markReset(err)
return nil, err
}
if resp.DataID != 0 && resp.DataID != req.DataID {
err = errBulkAlreadyExists
_, _ = sendBulkResetServerTransport(context.Background(), s, transport, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
Error: "bulk dedicated data id mismatch",
})
bulk.markReset(err)
return nil, err return nil, err
} }
if resp.TransportGeneration != 0 { if resp.TransportGeneration != 0 {
bulk.transportGeneration = resp.TransportGeneration bulk.transportGeneration = resp.TransportGeneration
} }
if resp.FastPathVersion != 0 {
bulk.fastPathVersion = normalizeBulkFastPathVersion(resp.FastPathVersion)
}
if resp.AttachToken != "" {
bulk.setDedicatedAttachToken(resp.AttachToken)
}
if err := bulk.waitAcceptReady(ctx); err != nil {
_, _ = sendBulkResetServerTransport(context.Background(), s, transport, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
Error: err.Error(),
})
bulk.markReset(err)
return nil, err
}
return bulk, nil return bulk, nil
} }
resp, err := sendBulkOpenServerTransport(ctx, s, transport, req) resp, err := sendBulkOpenServerTransport(ctx, s, transport, req)
@@ -138,6 +278,9 @@ func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *Transpo
if resp.DataID != 0 { if resp.DataID != 0 {
req.DataID = resp.DataID req.DataID = resp.DataID
} }
if resp.FastPathVersion != 0 {
req.FastPathVersion = resp.FastPathVersion
}
req.Dedicated = resp.Dedicated req.Dedicated = resp.Dedicated
if resp.AttachToken != "" { if resp.AttachToken != "" {
req.AttachToken = resp.AttachToken req.AttachToken = resp.AttachToken
@@ -146,6 +289,7 @@ func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *Transpo
return nil, errBulkDataIDEmpty return nil, errBulkDataIDEmpty
} }
bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverBulkCloseSender(s, logical, transport), serverBulkResetSender(s, logical, transport), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport)) bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverBulkCloseSender(s, logical, transport), serverBulkResetSender(s, logical, transport), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport))
bulk.markAcceptHandled()
if err := runtime.register(scope, bulk); err != nil { if err := runtime.register(scope, bulk); err != nil {
_, _ = sendBulkResetServerTransport(context.Background(), s, transport, BulkResetRequest{ _, _ = sendBulkResetServerTransport(context.Background(), s, transport, BulkResetRequest{
BulkID: req.BulkID, BulkID: req.BulkID,
@@ -154,6 +298,7 @@ func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *Transpo
}) })
return nil, err return nil, err
} }
s.attachServerDedicatedSidecarIfExists(logical, bulk)
return bulk, nil return bulk, nil
} }
@@ -164,15 +309,16 @@ func serverBulkRequest(runtime *bulkRuntime, opt BulkOpenOptions) BulkOpenReques
id = runtime.nextID() id = runtime.nextID()
} }
return normalizeBulkOpenRequest(BulkOpenRequest{ return normalizeBulkOpenRequest(BulkOpenRequest{
BulkID: id, BulkID: id,
Range: opt.Range, FastPathVersion: bulkFastPathVersionCurrent,
Metadata: cloneBulkMetadata(opt.Metadata), Range: opt.Range,
ReadTimeout: opt.ReadTimeout, Metadata: cloneBulkMetadata(opt.Metadata),
WriteTimeout: opt.WriteTimeout, ReadTimeout: opt.ReadTimeout,
Dedicated: opt.Dedicated, WriteTimeout: opt.WriteTimeout,
ChunkSize: opt.ChunkSize, Dedicated: opt.Dedicated,
WindowBytes: opt.WindowBytes, ChunkSize: opt.ChunkSize,
MaxInFlight: opt.MaxInFlight, WindowBytes: opt.WindowBytes,
MaxInFlight: opt.MaxInFlight,
}) })
} }
@@ -247,12 +393,12 @@ func serverBulkDataSender(s *ServerCommon, transport *TransportConn) bulkDataSen
if dataID == 0 { if dataID == 0 {
return errBulkDataPathNotReady return errBulkDataPathNotReady
} }
return s.sendFastBulkDataTransport(ctx, bulk.LogicalConn(), transport, dataID, bulk.nextOutboundDataSeq(), chunk) return s.sendFastBulkDataTransport(ctx, bulk.LogicalConn(), transport, dataID, bulk.nextOutboundDataSeq(), chunk, bulk.fastPathVersionSnapshot())
} }
} }
func serverBulkWriteSender(s *ServerCommon, logical *LogicalConn, transport *TransportConn) bulkWriteSender { func serverBulkWriteSender(s *ServerCommon, logical *LogicalConn, transport *TransportConn) bulkWriteSender {
return func(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) { return func(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) {
if s == nil { if s == nil {
return 0, errBulkServerNil return 0, errBulkServerNil
} }
@@ -267,7 +413,7 @@ func serverBulkWriteSender(s *ServerCommon, logical *LogicalConn, transport *Tra
if err := bulk.waitDedicatedReady(ctx); err != nil { if err := bulk.waitDedicatedReady(ctx); err != nil {
return 0, err return 0, err
} }
return s.sendDedicatedBulkWrite(ctx, bulk.LogicalConn(), bulk, payload) return s.sendDedicatedBulkWrite(ctx, bulk.LogicalConn(), bulk, startSeq, payload, payloadOwned)
} }
if transport == nil { if transport == nil {
return 0, errBulkTransportNil return 0, errBulkTransportNil
@@ -275,7 +421,14 @@ func serverBulkWriteSender(s *ServerCommon, logical *LogicalConn, transport *Tra
if !transport.IsCurrent() { if !transport.IsCurrent() {
return 0, errTransportDetached return 0, errTransportDetached
} }
return 0, nil if bulk == nil {
return 0, errBulkRuntimeNil
}
dataID := bulk.dataIDSnapshot()
if dataID == 0 {
return 0, errBulkDataPathNotReady
}
return s.sendFastBulkWriteTransport(ctx, bulk.LogicalConn(), transport, dataID, startSeq, bulk.chunkSize, bulk.fastPathVersionSnapshot(), payload, payloadOwned)
} }
} }
@@ -287,8 +440,20 @@ func serverBulkReleaseSender(s *ServerCommon, logical *LogicalConn, transport *T
if bytes <= 0 && chunks <= 0 { if bytes <= 0 && chunks <= 0 {
return nil return nil
} }
ctx, cancel, err := bulk.newWriteContext(bulk.Context(), bulk.writeTimeout)
if err != nil {
return err
}
defer cancel()
if bulk.Dedicated() { if bulk.Dedicated() {
return s.sendDedicatedBulkRelease(context.Background(), logical, bulk, bytes, chunks) return s.sendDedicatedBulkRelease(ctx, logical, bulk, bytes, chunks)
}
if transport != nil && transport.IsCurrent() && bulk.fastPathVersionSnapshot() >= bulkFastPathVersionV2 {
payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks)
if err != nil {
return err
}
return s.sendFastBulkControlTransport(ctx, logical, transport, bulkFastPayloadTypeRelease, 0, bulk.dataIDSnapshot(), 0, bulk.fastPathVersionSnapshot(), payload)
} }
req := BulkReleaseRequest{ req := BulkReleaseRequest{
BulkID: bulk.ID(), BulkID: bulk.ID(),
+31 -9
View File
@@ -31,20 +31,28 @@ func (s *ServerCommon) Stop() error {
// Deprecated: SetDefaultCommEncode overrides the transport codec directly. // Deprecated: SetDefaultCommEncode overrides the transport codec directly.
// Prefer UseModernPSKServer or UseLegacySecurityServer. // Prefer UseModernPSKServer or UseLegacySecurityServer.
func (s *ServerCommon) SetDefaultCommEncode(fn func([]byte, []byte) []byte) { func (s *ServerCommon) SetDefaultCommEncode(fn func([]byte, []byte) []byte) {
s.defaultMsgEn = fn profile := transportProtectionProfile{
s.defaultFastStreamEncode = nil mode: ProtectionManaged,
s.defaultFastBulkEncode = nil msgEn: fn,
s.defaultFastPlainEncode = nil msgDe: s.defaultMsgDe,
secretKey: s.SecretKey,
}
s.setServerDefaultTransportProtectionProfile(profile)
s.clearServerSecurityProfiles()
s.securityReadyCheck = false s.securityReadyCheck = false
} }
// Deprecated: SetDefaultCommDecode overrides the transport codec directly. // Deprecated: SetDefaultCommDecode overrides the transport codec directly.
// Prefer UseModernPSKServer or UseLegacySecurityServer. // Prefer UseModernPSKServer or UseLegacySecurityServer.
func (s *ServerCommon) SetDefaultCommDecode(fn func([]byte, []byte) []byte) { func (s *ServerCommon) SetDefaultCommDecode(fn func([]byte, []byte) []byte) {
s.defaultMsgDe = fn profile := transportProtectionProfile{
s.defaultFastStreamEncode = nil mode: ProtectionManaged,
s.defaultFastBulkEncode = nil msgEn: s.defaultMsgEn,
s.defaultFastPlainEncode = nil msgDe: fn,
secretKey: s.SecretKey,
}
s.setServerDefaultTransportProtectionProfile(profile)
s.clearServerSecurityProfiles()
s.securityReadyCheck = false s.securityReadyCheck = false
} }
@@ -96,7 +104,21 @@ func (s *ServerCommon) GetSecretKey() []byte {
// Deprecated: SetSecretKey injects a raw transport key directly. // Deprecated: SetSecretKey injects a raw transport key directly.
// Prefer UseModernPSKServer or UseLegacySecurityServer. // Prefer UseModernPSKServer or UseLegacySecurityServer.
func (s *ServerCommon) SetSecretKey(key []byte) { func (s *ServerCommon) SetSecretKey(key []byte) {
s.SecretKey = key profile := transportProtectionProfile{
mode: ProtectionManaged,
msgEn: s.defaultMsgEn,
msgDe: s.defaultMsgDe,
secretKey: cloneTransportProtectionKey(key),
}
if len(key) == 0 {
profile.runtime = nil
} else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil {
profile.runtime = runtime
} else {
profile.runtime = nil
}
s.setServerDefaultTransportProtectionProfile(profile)
s.clearServerSecurityProfiles()
s.securityReadyCheck = len(key) == 0 s.securityReadyCheck = len(key) == 0
} }
+100
View File
@@ -3,12 +3,54 @@ package notify
import ( import (
"b612.me/stario" "b612.me/stario"
"context" "context"
"errors"
"math" "math"
"net" "net"
"testing" "testing"
"time" "time"
) )
func readServerEnvelopeFromConnWithProfile(t *testing.T, server *ServerCommon, profile transportProtectionProfile, conn net.Conn, timeout time.Duration) Envelope {
t.Helper()
queue := stario.NewQueue()
deadline := time.Now().Add(timeout)
buf := make([]byte, 4096)
for time.Now().Before(deadline) {
if err := conn.SetReadDeadline(deadline); err != nil {
t.Fatalf("SetReadDeadline failed: %v", err)
}
n, err := conn.Read(buf)
if n > 0 {
if parseErr := queue.ParseMessage(buf[:n], "server-inbound-profile"); parseErr != nil {
t.Fatalf("ParseMessage failed: %v", parseErr)
}
select {
case msg := <-queue.RestoreChan():
plain, decErr := decryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, msg.Msg)
if decErr != nil {
t.Fatalf("decryptTransportPayloadCodec failed: %v", decErr)
}
env, decErr := server.decodeEnvelopePlain(plain)
if decErr != nil {
t.Fatalf("decodeEnvelopePlain failed: %v", decErr)
}
return env
default:
}
}
if err == nil {
continue
}
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
break
}
t.Fatalf("conn Read failed: %v", err)
}
t.Fatal("timed out waiting for server envelope")
return Envelope{}
}
func TestMessageReplyUsesInboundConnForStaleTransport(t *testing.T) { func TestMessageReplyUsesInboundConnForStaleTransport(t *testing.T) {
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server) UseLegacySecurityServer(server)
@@ -85,6 +127,64 @@ func TestMessageReplyUsesInboundConnForStaleTransport(t *testing.T) {
} }
} }
func TestMessageReplyUsesInboundProtectionSnapshotAfterLogicalSwitch(t *testing.T) {
secret := []byte("correct horse battery staple")
alternate, err := deriveModernPSKProtectionProfile([]byte("notify-reply-snapshot-alternate"), testModernPSKOptions(), ProtectionManaged)
if err != nil {
t.Fatalf("deriveModernPSKProtectionProfile failed: %v", err)
}
handlerErr := make(chan error, 1)
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
server.SetLink("reply-snapshot", func(msg *Message) {
if msg == nil || msg.LogicalConn == nil {
select {
case handlerErr <- errors.New("reply-snapshot logical is nil"):
default:
}
return
}
msg.LogicalConn.applyTransportProtectionProfile(alternate)
if err := msg.Reply([]byte("ack")); err != nil {
select {
case handlerErr <- err:
default:
}
}
})
})
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
left, right := net.Pipe()
defer right.Close()
bootstrapPeerAttachConnForTest(t, server, right)
if err := client.ConnectByConn(left); err != nil {
t.Fatalf("ConnectByConn failed: %v", err)
}
defer func() {
client.setByeFromServer(true)
_ = client.Stop()
}()
reply, err := client.SendWait("reply-snapshot", []byte("ping"), time.Second)
if err != nil {
t.Fatalf("SendWait failed: %v", err)
}
if got, want := string(reply.Value), "ack"; got != want {
t.Fatalf("reply value = %q, want %q", got, want)
}
select {
case err := <-handlerErr:
t.Fatalf("reply-snapshot handler failed: %v", err)
default:
}
}
func TestServerHandleReceivedSignalReliabilityUsesInboundConnForStaleTransport(t *testing.T) { func TestServerHandleReceivedSignalReliabilityUsesInboundConnForStaleTransport(t *testing.T) {
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server) UseLegacySecurityServer(server)
+63 -20
View File
@@ -59,26 +59,8 @@ func (s *ServerCommon) pushMessageSourceFast(queue *stario.StarQueue, data []byt
if queue == nil || dispatcher == nil || len(data) == 0 { if queue == nil || dispatcher == nil || len(data) == 0 {
return false return false
} }
if err := queue.ParseMessageOwned(data, source, func(msg stario.MsgQueue) error { if err := queue.ParseMessageView(data, source, func(frame stario.FrameView) error {
payload := msg.Msg s.pushTransportPayloadSourceFast(frame.Payload, nil, frame.Conn)
source := msg.Conn
s.wg.Add(1)
if !dispatcher.Dispatch(serverInboundDispatchSource(source), func() {
defer s.wg.Done()
logical, transport := s.resolveInboundSource(source)
if logical == nil {
return
}
now := time.Now()
inboundConn := serverInboundConn(source)
if err := s.dispatchInboundTransportPayload(logical, transport, inboundConn, payload, now); err != nil {
if s.showError || s.debugMode {
fmt.Println("server decode envelope error", err)
}
}
}) {
s.wg.Done()
}
return nil return nil
}); err != nil && (s.showError || s.debugMode) { }); err != nil && (s.showError || s.debugMode) {
fmt.Println("server parse inbound frame error", err) fmt.Println("server parse inbound frame error", err)
@@ -86,6 +68,67 @@ func (s *ServerCommon) pushMessageSourceFast(queue *stario.StarQueue, data []byt
return true return true
} }
func (s *ServerCommon) pushTransportPayloadSourceFast(payload []byte, release func(), source interface{}) bool {
dispatcher := s.serverInboundDispatcherSnapshot()
if len(payload) == 0 {
if release != nil {
release()
}
return false
}
if dispatcher == nil {
queue := s.serverQueueSnapshot()
if queue == nil {
if release != nil {
release()
}
return false
}
frame := queue.BuildMessage(payload)
if release != nil {
release()
}
if err := queue.ParseMessage(frame, source); err != nil && (s.showError || s.debugMode) {
fmt.Println("server enqueue inbound frame error", err)
}
return true
}
logical, transport := s.resolveInboundSource(source)
if logical == nil {
if release != nil {
release()
}
return true
}
plain, plainRelease, err := s.decryptTransportPayloadLogicalPooled(logical, payload, release)
if err != nil {
if s.showError || s.debugMode {
fmt.Println("server decode transport payload error", err)
}
return true
}
inboundConn := serverInboundConn(source)
if s.tryDispatchBorrowedTransportPlain(logical, transport, inboundConn, plain, plainRelease) {
return true
}
owned := plain
if plainRelease != nil {
owned = append([]byte(nil), plain...)
plainRelease()
}
s.wg.Add(1)
if !dispatcher.Dispatch(serverInboundDispatchSource(source), func() {
defer s.wg.Done()
now := time.Now()
if err := s.dispatchInboundTransportPlain(logical, transport, inboundConn, owned, now); err != nil && (s.showError || s.debugMode) {
fmt.Println("server decode envelope error", err)
}
}) {
s.wg.Done()
}
return true
}
func serverInboundConn(source interface{}) net.Conn { func serverInboundConn(source interface{}) net.Conn {
switch data := source.(type) { switch data := source.(type) {
case net.Conn: case net.Conn:
+1
View File
@@ -130,6 +130,7 @@ func (s *ServerCommon) removeLogical(logical *LogicalConn) {
s.getFileAckPool().closeScopeFamily(scope) s.getFileAckPool().closeScopeFamily(scope)
s.getSignalAckPool().closeScopeFamily(scope) s.getSignalAckPool().closeScopeFamily(scope)
s.getReceivedSignalCache().closeScope(scope) s.getReceivedSignalCache().closeScope(scope)
s.closeServerDedicatedSidecar(logical)
s.getPeerRegistry().removeLogical(logical) s.getPeerRegistry().removeLogical(logical)
} }
+3
View File
@@ -24,6 +24,7 @@ func (s *ServerCommon) OpenRecordStreamLogical(ctx context.Context, logical *Log
_ = stream.Reset(err) _ = stream.Reset(err)
return nil, err return nil, err
} }
bindRecordRuntime(record, s.getRecordRuntime())
return record, nil return record, nil
} }
@@ -41,6 +42,7 @@ func (s *ServerCommon) OpenRecordStreamTransport(ctx context.Context, transport
_ = stream.Reset(err) _ = stream.Reset(err)
return nil, err return nil, err
} }
bindRecordRuntime(record, s.getRecordRuntime())
return record, nil return record, nil
} }
@@ -68,6 +70,7 @@ func (s *ServerCommon) claimInboundRecordStream(logical *LogicalConn, transport
if err != nil { if err != nil {
return true, err return true, err
} }
bindRecordRuntime(record, runtime)
info := RecordAcceptInfo{ info := RecordAcceptInfo{
ID: stream.ID(), ID: stream.ID(),
Metadata: stream.Metadata(), Metadata: stream.Metadata(),
+19 -4
View File
@@ -358,16 +358,20 @@ func (s *ServerCommon) sendEnvelopeTransport(transport *TransportConn, env Envel
} }
func (s *ServerCommon) sendEnvelopeInboundTransport(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope) error { func (s *ServerCommon) sendEnvelopeInboundTransport(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope) error {
return s.sendEnvelopeInboundTransportWithProfile(logical, transport, conn, nil, env)
}
func (s *ServerCommon) sendEnvelopeInboundTransportWithProfile(logical *LogicalConn, transport *TransportConn, conn net.Conn, profile *transportProtectionProfile, env Envelope) error {
if logical == nil && transport != nil { if logical == nil && transport != nil {
logical = transport.logicalConnSnapshot() logical = transport.logicalConnSnapshot()
} }
if logical == nil { if logical == nil {
return transportDetachedErrorForPeer(logical, transport) return transportDetachedErrorForPeer(logical, transport)
} }
if logical.msgEnSnapshot() == nil { if profile == nil && logical.msgEnSnapshot() == nil {
return transportDetachedErrorForPeer(logical, transport) return transportDetachedErrorForPeer(logical, transport)
} }
payload, err := s.encodeEnvelopePayloadLogical(logical, env) payload, err := s.encodeEnvelopePayloadInbound(logical, env, profile)
if err != nil { if err != nil {
return err return err
} }
@@ -402,7 +406,18 @@ func (s *ServerCommon) writeControlEnvelopePayload(logical *LogicalConn, transpo
return sender.submit(payload, writeDeadlineFromTimeout(logical.maxWriteTimeoutSnapshot())) return sender.submit(payload, writeDeadlineFromTimeout(logical.maxWriteTimeoutSnapshot()))
} }
func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, msg TransferMsg) error { func (s *ServerCommon) encodeEnvelopePayloadInbound(logical *LogicalConn, env Envelope, profile *transportProtectionProfile) ([]byte, error) {
if profile == nil {
return s.encodeEnvelopePayloadLogical(logical, env)
}
data, err := s.encodeEnvelopePlain(env)
if err != nil {
return nil, err
}
return encryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgEn, profile.secretKey, data)
}
func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, profile *transportProtectionProfile, msg TransferMsg) error {
if logical == nil && transport != nil { if logical == nil && transport != nil {
logical = transport.logicalConnSnapshot() logical = transport.logicalConnSnapshot()
} }
@@ -413,7 +428,7 @@ func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *Tran
if err != nil { if err != nil {
return err return err
} }
return s.sendEnvelopeInboundTransport(logical, transport, conn, env) return s.sendEnvelopeInboundTransportWithProfile(logical, transport, conn, profile, env)
} }
func (s *ServerCommon) writeEnvelopePayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte) error { func (s *ServerCommon) writeEnvelopePayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte) error {
+13
View File
@@ -57,6 +57,7 @@ func (s *ServerCommon) detachClientSessionTransport(client *ClientConn, reason s
if runtime := s.getBulkRuntime(); runtime != nil { if runtime := s.getBulkRuntime(); runtime != nil {
runtime.closeScope(serverFileScope(client), errTransportDetached) runtime.closeScope(serverFileScope(client), errTransportDetached)
} }
s.closeServerDedicatedSidecar(logicalConnFromClient(client))
client.detachServerOwnedTransport() client.detachServerOwnedTransport()
} }
@@ -80,6 +81,7 @@ func (s *ServerCommon) detachLogicalSessionTransport(logical *LogicalConn, reaso
if runtime := s.getBulkRuntime(); runtime != nil { if runtime := s.getBulkRuntime(); runtime != nil {
runtime.closeScope(serverFileScope(logical), errTransportDetached) runtime.closeScope(serverFileScope(logical), errTransportDetached)
} }
s.closeServerDedicatedSidecar(logical)
logical.detachServerOwnedTransport() logical.detachServerOwnedTransport()
} }
@@ -108,6 +110,17 @@ func (s *ServerCommon) registerAcceptedLogical(logical *LogicalConn) *LogicalCon
} }
logical.setServer(s) logical.setServer(s)
logical.applyAttachmentProfile(s.maxReadTimeout, s.maxWriteTimeout, s.defaultMsgEn, s.defaultMsgDe, s.defaultFastStreamEncode, s.defaultFastBulkEncode, s.defaultFastPlainEncode, s.handshakeRsaKey, s.SecretKey) logical.applyAttachmentProfile(s.maxReadTimeout, s.maxWriteTimeout, s.defaultMsgEn, s.defaultMsgDe, s.defaultFastStreamEncode, s.defaultFastBulkEncode, s.defaultFastPlainEncode, s.handshakeRsaKey, s.SecretKey)
logical.updateAttachmentState(func(state *clientConnAttachmentState) {
state.authMode = s.securityAuthMode
state.peerAttached = false
state.peerAttachFallback = false
state.peerAttachAt = 0
})
if s.securityConfigured {
logical.applyTransportProtectionProfile(s.securityBootstrap)
} else {
logical.setModernPSKRuntime(s.defaultModernPSKRuntime)
}
logical.markHeartbeatNow() logical.markHeartbeatNow()
return s.getPeerRegistry().registerLogical(logical) return s.getPeerRegistry().registerLogical(logical)
} }
+19 -6
View File
@@ -33,6 +33,12 @@ func (s *ServerCommon) OpenStreamLogical(ctx context.Context, logical *LogicalCo
if resp.DataID != 0 { if resp.DataID != 0 {
req.DataID = resp.DataID req.DataID = resp.DataID
} }
if resp.FastPathVersion != 0 {
req.FastPathVersion = resp.FastPathVersion
} else {
req.FastPathVersion = streamFastPathVersionV1
}
req.Metadata = mergeStreamMetadata(req.Metadata, resp.Metadata)
transport := logical.CurrentTransportConn() transport := logical.CurrentTransportConn()
stream := newStreamHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverStreamCloseSender(s, logical, nil), serverStreamResetSender(s, logical, nil), serverStreamDataSender(s, transport), runtime.configSnapshot()) stream := newStreamHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverStreamCloseSender(s, logical, nil), serverStreamResetSender(s, logical, nil), serverStreamDataSender(s, transport), runtime.configSnapshot())
if err := runtime.register(scope, stream); err != nil { if err := runtime.register(scope, stream); err != nil {
@@ -72,6 +78,12 @@ func (s *ServerCommon) OpenStreamTransport(ctx context.Context, transport *Trans
if resp.DataID != 0 { if resp.DataID != 0 {
req.DataID = resp.DataID req.DataID = resp.DataID
} }
if resp.FastPathVersion != 0 {
req.FastPathVersion = resp.FastPathVersion
} else {
req.FastPathVersion = streamFastPathVersionV1
}
req.Metadata = mergeStreamMetadata(req.Metadata, resp.Metadata)
stream := newStreamHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverStreamCloseSender(s, logical, transport), serverStreamResetSender(s, logical, transport), serverStreamDataSender(s, transport), runtime.configSnapshot()) stream := newStreamHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverStreamCloseSender(s, logical, transport), serverStreamResetSender(s, logical, transport), serverStreamDataSender(s, transport), runtime.configSnapshot())
if err := runtime.register(scope, stream); err != nil { if err := runtime.register(scope, stream); err != nil {
_, _ = sendStreamResetServerTransport(context.Background(), s, transport, StreamResetRequest{ _, _ = sendStreamResetServerTransport(context.Background(), s, transport, StreamResetRequest{
@@ -89,11 +101,12 @@ func serverStreamRequest(runtime *streamRuntime, opt StreamOpenOptions) StreamOp
id = runtime.nextID() id = runtime.nextID()
} }
return normalizeStreamOpenRequest(StreamOpenRequest{ return normalizeStreamOpenRequest(StreamOpenRequest{
StreamID: id, StreamID: id,
Channel: opt.Channel, FastPathVersion: streamFastPathVersionCurrent,
Metadata: cloneStreamMetadata(opt.Metadata), Channel: opt.Channel,
ReadTimeout: opt.ReadTimeout, Metadata: cloneStreamMetadata(opt.Metadata),
WriteTimeout: opt.WriteTimeout, ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
}) })
} }
@@ -146,7 +159,7 @@ func serverStreamDataSender(s *ServerCommon, transport *TransportConn) streamDat
} }
} }
if dataID := stream.dataIDSnapshot(); dataID != 0 { if dataID := stream.dataIDSnapshot(); dataID != 0 {
return s.sendFastStreamDataTransport(stream.LogicalConn(), transport, dataID, stream.nextOutboundDataSeq(), chunk) return s.sendFastStreamDataTransport(ctx, stream.LogicalConn(), transport, stream, chunk)
} }
return s.sendEnvelopeTransport(transport, newStreamDataEnvelope(stream.ID(), chunk)) return s.sendEnvelopeTransport(transport, newStreamDataEnvelope(stream.ID(), chunk))
} }
+4
View File
@@ -24,6 +24,10 @@ type Server interface {
SetStreamConfig(StreamConfig) SetStreamConfig(StreamConfig)
SetTransferResumeStore(TransferResumeStore) SetTransferResumeStore(TransferResumeStore)
RecoverTransferSnapshots(context.Context) error RecoverTransferSnapshots(context.Context) error
SetBulkOpenTuning(BulkOpenTuning)
BulkOpenTuning() BulkOpenTuning
SetPeerAttachSecurityConfig(PeerAttachSecurityConfig) error
PeerAttachSecurityConfig() PeerAttachSecurityConfig
SetFileReceiveDir(dir string) error SetFileReceiveDir(dir string) error
send(c *ClientConn, msg TransferMsg) (WaitMsg, error) send(c *ClientConn, msg TransferMsg) (WaitMsg, error)
sendEnvelope(c *ClientConn, env Envelope) error sendEnvelope(c *ClientConn, env Envelope) error
+185 -51
View File
@@ -1,68 +1,130 @@
package notify package notify
import ( import (
"encoding/hex"
"errors" "errors"
"time" "time"
) )
type ClientRuntimeSnapshot struct { type ClientRuntimeSnapshot struct {
OwnerState string OwnerState string
Alive bool Alive bool
SessionEpoch uint64 SessionEpoch uint64
TransportAttached bool TransportAttached bool
HasRuntimeConn bool HasRuntimeConn bool
HasRuntimeQueue bool HasRuntimeQueue bool
HasRuntimeStopCtx bool HasRuntimeStopCtx bool
ConnectSource string ConnectSource string
ConnectNetwork string ConnectNetwork string
ConnectAddress string ConnectAddress string
CanReconnect bool CanReconnect bool
Retry ConnectionRetrySnapshot AuthMode string
ProtectionMode string
ProtectionKeyMode string
ForwardSecrecyEnabled bool
ForwardSecrecyFallback bool
ForwardSecrecyRequired bool
TransportSessionID string
PeerAttachAuthenticated bool
PeerAttachAuthFallback bool
LastPeerAttachAt time.Time
PeerAttachRequireExplicitAuth bool
PeerAttachRequireChannelBinding bool
PeerAttachChannelBindingConfigured bool
BulkNetworkProfile string
BulkDefaultMode string
BulkChunkSize int
BulkWindowBytes int
BulkMaxInFlight int
BulkAttachLimit int
BulkActiveLimit int
BulkLaneLimit int
BulkAttachAttempts int64
BulkAttachRetries int64
BulkAttachSuccess int64
BulkAutoFallbacks int64
TransportBulkAdaptiveSoftPayloadBytes int
TransportStreamAdaptiveSoftPayloadBytes int
TransportStreamAdaptiveWaitThresholdBytes int
TransportStreamAdaptiveFlushDelay time.Duration
Retry ConnectionRetrySnapshot
} }
type ServerRuntimeSnapshot struct { type ServerRuntimeSnapshot struct {
OwnerState string OwnerState string
Alive bool Alive bool
ClientCount int ClientCount int
DetachedClientCount int DetachedClientCount int
DetachedReattachableClientCount int DetachedReattachableClientCount int
DetachedExpiredClientCount int DetachedExpiredClientCount int
DetachedClientKeepSec int64 DetachedClientKeepSec int64
TransportAttached bool TransportAttached bool
HasRuntimeListener bool HasRuntimeListener bool
HasRuntimeUDPListener bool HasRuntimeUDPListener bool
HasRuntimeQueue bool HasRuntimeQueue bool
HasRuntimeStopCtx bool HasRuntimeStopCtx bool
Retry ConnectionRetrySnapshot AuthMode string
ProtectionMode string
ForwardSecrecySupported bool
ForwardSecrecyRequired bool
PeerAttachRequireExplicitAuth bool
PeerAttachRequireChannelBinding bool
PeerAttachChannelBindingConfigured bool
PeerAttachReplayWindow time.Duration
PeerAttachReplayCapacity int
PeerAttachExplicitAuth int64
PeerAttachAuthFallbacks int64
PeerAttachAuthRejects int64
PeerAttachDowngradeRejects int64
PeerAttachBindingRejects int64
PeerAttachReplayRejects int64
PeerAttachReplayOverflowRejects int64
BulkChunkSize int
BulkWindowBytes int
BulkMaxInFlight int
Retry ConnectionRetrySnapshot
} }
type ClientConnRuntimeSnapshot struct { type ClientConnRuntimeSnapshot struct {
ClientID string ClientID string
RemoteAddress string RemoteAddress string
Alive bool Alive bool
Reason string Reason string
Error string Error string
IdentityBound bool IdentityBound bool
UsesStreamTransport bool UsesStreamTransport bool
TransportGeneration uint64 TransportGeneration uint64
TransportAttached bool TransportAttached bool
HasRuntimeConn bool HasRuntimeConn bool
HasRuntimeStopCtx bool HasRuntimeStopCtx bool
TransportAttachCount uint64 TransportAttachCount uint64
TransportDetachCount uint64 TransportDetachCount uint64
LastTransportAttachAt time.Time LastTransportAttachAt time.Time
DetachedClientKeepSec int64 DetachedClientKeepSec int64
LastHeartbeatAt time.Time LastHeartbeatAt time.Time
TransportDetachReason string TransportDetachReason string
TransportDetachKind string TransportDetachKind string
TransportDetachGeneration uint64 TransportDetachGeneration uint64
TransportDetachError string TransportDetachError string
TransportDetachedAt time.Time TransportDetachedAt time.Time
TransportDetachHasExpiry bool TransportDetachHasExpiry bool
TransportDetachExpiry time.Time TransportDetachExpiry time.Time
TransportDetachRemaining time.Duration TransportDetachRemaining time.Duration
TransportDetachExpired bool TransportDetachExpired bool
ReattachEligible bool ReattachEligible bool
AuthMode string
ProtectionMode string
ProtectionKeyMode string
ForwardSecrecyEnabled bool
ForwardSecrecyFallback bool
TransportSessionID string
PeerAttachAuthenticated bool
PeerAttachAuthFallback bool
LastPeerAttachAt time.Time
TransportBulkAdaptiveSoftPayloadBytes int
TransportStreamAdaptiveSoftPayloadBytes int
TransportStreamAdaptiveWaitThresholdBytes int
TransportStreamAdaptiveFlushDelay time.Duration
} }
func (c *ClientCommon) clientRuntimeSnapshot() ClientRuntimeSnapshot { func (c *ClientCommon) clientRuntimeSnapshot() ClientRuntimeSnapshot {
@@ -85,6 +147,39 @@ func (c *ClientCommon) clientRuntimeSnapshot() ClientRuntimeSnapshot {
snapshot.ConnectAddress = source.addr snapshot.ConnectAddress = source.addr
snapshot.CanReconnect = source.canReconnect() snapshot.CanReconnect = source.canReconnect()
} }
snapshot.AuthMode = authModeName(c.securityAuthMode)
snapshot.ProtectionMode = protectionModeName(c.securityProtectionMode)
protection := c.clientTransportProtectionSnapshot()
snapshot.ProtectionKeyMode = protection.keyMode
snapshot.ForwardSecrecyEnabled = protection.forwardSecrecy
snapshot.ForwardSecrecyFallback = protection.forwardSecrecyFallback
snapshot.ForwardSecrecyRequired = c.clientRequiresForwardSecrecy()
snapshot.TransportSessionID = hex.EncodeToString(protection.sessionID)
snapshot.PeerAttachAuthenticated, snapshot.PeerAttachAuthFallback, snapshot.LastPeerAttachAt = c.clientPeerAttachAuthSnapshot()
peerAttachCfg := c.peerAttachSecuritySnapshot()
snapshot.PeerAttachRequireExplicitAuth = peerAttachCfg.requireExplicitAuth
snapshot.PeerAttachRequireChannelBinding = peerAttachCfg.requireChannelBinding
snapshot.PeerAttachChannelBindingConfigured = peerAttachCfg.channelBinding != nil
snapshot.BulkNetworkProfile = bulkNetworkProfileName(c.BulkNetworkProfile())
snapshot.BulkDefaultMode = bulkOpenModeName(c.BulkDefaultOpenMode())
tuning := c.BulkOpenTuning()
cfg := c.BulkDedicatedAttachConfig()
snapshot.BulkChunkSize = tuning.ChunkSize
snapshot.BulkWindowBytes = tuning.WindowBytes
snapshot.BulkMaxInFlight = tuning.MaxInFlight
snapshot.BulkAttachLimit = cfg.AttachLimit
snapshot.BulkActiveLimit = cfg.ActiveLimit
snapshot.BulkLaneLimit = cfg.LaneLimit
snapshot.BulkAttachAttempts = c.bulkAttachAttemptCount.Load()
snapshot.BulkAttachRetries = c.bulkAttachRetryCount.Load()
snapshot.BulkAttachSuccess = c.bulkAttachSuccessCount.Load()
snapshot.BulkAutoFallbacks = c.bulkAttachFallbackCount.Load()
if binding := c.clientTransportBindingSnapshot(); binding != nil {
snapshot.TransportBulkAdaptiveSoftPayloadBytes = binding.bulkAdaptiveSoftPayloadBytesSnapshot()
snapshot.TransportStreamAdaptiveSoftPayloadBytes = binding.streamAdaptiveSoftPayloadBytesSnapshot()
snapshot.TransportStreamAdaptiveWaitThresholdBytes = binding.streamAdaptiveWaitThresholdBytesSnapshot()
snapshot.TransportStreamAdaptiveFlushDelay = binding.streamAdaptiveFlushDelaySnapshot()
}
snapshot.Retry = c.connectionRetrySnapshot() snapshot.Retry = c.connectionRetrySnapshot()
return snapshot return snapshot
} }
@@ -118,12 +213,34 @@ func (s *ServerCommon) serverRuntimeSnapshot() ServerRuntimeSnapshot {
snapshot.HasRuntimeQueue = rt.queue != nil snapshot.HasRuntimeQueue = rt.queue != nil
snapshot.HasRuntimeStopCtx = rt.stopCtx != nil snapshot.HasRuntimeStopCtx = rt.stopCtx != nil
} }
snapshot.AuthMode = authModeName(s.securityAuthMode)
snapshot.ProtectionMode = protectionModeName(s.securityProtectionMode)
snapshot.ForwardSecrecySupported = s.serverSupportsForwardSecrecy()
snapshot.ForwardSecrecyRequired = s.serverRequiresForwardSecrecy()
peerAttachCfg := s.peerAttachSecuritySnapshot()
snapshot.PeerAttachRequireExplicitAuth = peerAttachCfg.requireExplicitAuth
snapshot.PeerAttachRequireChannelBinding = peerAttachCfg.requireChannelBinding
snapshot.PeerAttachChannelBindingConfigured = peerAttachCfg.channelBinding != nil
snapshot.PeerAttachReplayWindow = peerAttachCfg.replayWindow
snapshot.PeerAttachReplayCapacity = peerAttachCfg.replayCapacity
snapshot.PeerAttachExplicitAuth = s.peerAttachExplicitCount.Load()
snapshot.PeerAttachAuthFallbacks = s.peerAttachAuthFallbackCount.Load()
snapshot.PeerAttachAuthRejects = s.peerAttachAuthRejectCount.Load()
snapshot.PeerAttachDowngradeRejects = s.peerAttachDowngradeRejectCount.Load()
snapshot.PeerAttachBindingRejects = s.peerAttachBindingRejectCount.Load()
snapshot.PeerAttachReplayRejects = s.peerAttachReplayRejectCountSnapshot()
snapshot.PeerAttachReplayOverflowRejects = s.peerAttachReplayOverflowRejectCountSnapshot()
tuning := s.BulkOpenTuning()
snapshot.BulkChunkSize = tuning.ChunkSize
snapshot.BulkWindowBytes = tuning.WindowBytes
snapshot.BulkMaxInFlight = tuning.MaxInFlight
snapshot.Retry = s.connectionRetrySnapshot() snapshot.Retry = s.connectionRetrySnapshot()
return snapshot return snapshot
} }
func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot { func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot {
status := c.clientConnStatusSnapshot() status := c.clientConnStatusSnapshot()
attachment := c.clientConnAttachmentStateSnapshot()
now := time.Now() now := time.Now()
snapshot := ClientConnRuntimeSnapshot{ snapshot := ClientConnRuntimeSnapshot{
ClientID: c.clientConnIDSnapshot(), ClientID: c.clientConnIDSnapshot(),
@@ -135,6 +252,17 @@ func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot {
TransportAttachCount: c.clientConnTransportAttachCountSnapshot(), TransportAttachCount: c.clientConnTransportAttachCountSnapshot(),
TransportDetachCount: c.clientConnTransportDetachCountSnapshot(), TransportDetachCount: c.clientConnTransportDetachCountSnapshot(),
LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(), LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(),
AuthMode: authModeName(attachment.authMode),
ProtectionMode: protectionModeName(attachment.protectionMode),
}
snapshot.PeerAttachAuthenticated = attachment.peerAttached
snapshot.PeerAttachAuthFallback = attachment.peerAttachFallback
snapshot.ProtectionKeyMode = attachment.keyMode
snapshot.ForwardSecrecyEnabled = attachment.forwardSecrecy
snapshot.ForwardSecrecyFallback = attachment.forwardSecrecyFallback
snapshot.TransportSessionID = hex.EncodeToString(attachment.sessionID)
if attachment.peerAttachAt != 0 {
snapshot.LastPeerAttachAt = time.Unix(0, attachment.peerAttachAt)
} }
if status.Err != nil { if status.Err != nil {
snapshot.Error = status.Err.Error() snapshot.Error = status.Err.Error()
@@ -153,6 +281,12 @@ func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot {
snapshot.HasRuntimeConn = c.clientConnTransportSnapshot() != nil snapshot.HasRuntimeConn = c.clientConnTransportSnapshot() != nil
snapshot.HasRuntimeStopCtx = rt.stopCtx != nil snapshot.HasRuntimeStopCtx = rt.stopCtx != nil
} }
if binding := c.clientConnTransportBindingSnapshot(); binding != nil {
snapshot.TransportBulkAdaptiveSoftPayloadBytes = binding.bulkAdaptiveSoftPayloadBytesSnapshot()
snapshot.TransportStreamAdaptiveSoftPayloadBytes = binding.streamAdaptiveSoftPayloadBytesSnapshot()
snapshot.TransportStreamAdaptiveWaitThresholdBytes = binding.streamAdaptiveWaitThresholdBytesSnapshot()
snapshot.TransportStreamAdaptiveFlushDelay = binding.streamAdaptiveFlushDelaySnapshot()
}
if detach := c.clientConnTransportDetachSnapshot(); detach != nil { if detach := c.clientConnTransportDetachSnapshot(); detach != nil {
snapshot.TransportDetachReason = detach.Reason snapshot.TransportDetachReason = detach.Reason
snapshot.TransportDetachKind = c.clientConnTransportDetachKindSnapshot() snapshot.TransportDetachKind = c.clientConnTransportDetachKindSnapshot()
+333
View File
@@ -37,6 +37,60 @@ func TestGetClientRuntimeSnapshotDefaults(t *testing.T) {
if snapshot.ConnectSource != "" || snapshot.ConnectNetwork != "" || snapshot.ConnectAddress != "" || snapshot.CanReconnect { if snapshot.ConnectSource != "" || snapshot.ConnectNetwork != "" || snapshot.ConnectAddress != "" || snapshot.CanReconnect {
t.Fatalf("unexpected default connect source snapshot: %+v", snapshot) t.Fatalf("unexpected default connect source snapshot: %+v", snapshot)
} }
if got, want := snapshot.AuthMode, "none"; got != want {
t.Fatalf("AuthMode mismatch: got %q want %q", got, want)
}
if got, want := snapshot.ProtectionMode, "managed"; got != want {
t.Fatalf("ProtectionMode mismatch: got %q want %q", got, want)
}
if snapshot.PeerAttachAuthenticated || snapshot.PeerAttachAuthFallback {
t.Fatalf("unexpected default peer attach state: %+v", snapshot)
}
if !snapshot.LastPeerAttachAt.IsZero() {
t.Fatalf("LastPeerAttachAt mismatch: got %v want zero", snapshot.LastPeerAttachAt)
}
if snapshot.PeerAttachRequireExplicitAuth || snapshot.PeerAttachRequireChannelBinding || snapshot.PeerAttachChannelBindingConfigured {
t.Fatalf("unexpected default peer attach policy: %+v", snapshot)
}
if got, want := snapshot.BulkNetworkProfile, "default"; got != want {
t.Fatalf("BulkNetworkProfile mismatch: got %q want %q", got, want)
}
if got, want := snapshot.BulkDefaultMode, "shared"; got != want {
t.Fatalf("BulkDefaultMode mismatch: got %q want %q", got, want)
}
if got, want := snapshot.BulkChunkSize, defaultBulkChunkSize; got != want {
t.Fatalf("BulkChunkSize mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkWindowBytes, defaultBulkOpenWindowBytes; got != want {
t.Fatalf("BulkWindowBytes mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkMaxInFlight, defaultBulkOpenMaxInFlight; got != want {
t.Fatalf("BulkMaxInFlight mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkAttachLimit, defaultBulkDedicatedAttachLimit; got != want {
t.Fatalf("BulkAttachLimit mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkActiveLimit, defaultBulkDedicatedActiveLimit; got != want {
t.Fatalf("BulkActiveLimit mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkLaneLimit, defaultBulkDedicatedLaneLimit; got != want {
t.Fatalf("BulkLaneLimit mismatch: got %d want %d", got, want)
}
if snapshot.BulkAttachAttempts != 0 || snapshot.BulkAttachRetries != 0 || snapshot.BulkAttachSuccess != 0 || snapshot.BulkAutoFallbacks != 0 {
t.Fatalf("unexpected default bulk attach counters: %+v", snapshot)
}
if snapshot.TransportBulkAdaptiveSoftPayloadBytes != 0 {
t.Fatalf("TransportBulkAdaptiveSoftPayloadBytes mismatch: got %d want 0", snapshot.TransportBulkAdaptiveSoftPayloadBytes)
}
if snapshot.TransportStreamAdaptiveSoftPayloadBytes != 0 {
t.Fatalf("TransportStreamAdaptiveSoftPayloadBytes mismatch: got %d want 0", snapshot.TransportStreamAdaptiveSoftPayloadBytes)
}
if snapshot.TransportStreamAdaptiveWaitThresholdBytes != 0 {
t.Fatalf("TransportStreamAdaptiveWaitThresholdBytes mismatch: got %d want 0", snapshot.TransportStreamAdaptiveWaitThresholdBytes)
}
if snapshot.TransportStreamAdaptiveFlushDelay != 0 {
t.Fatalf("TransportStreamAdaptiveFlushDelay mismatch: got %s want 0", snapshot.TransportStreamAdaptiveFlushDelay)
}
if snapshot.Retry != (ConnectionRetrySnapshot{}) { if snapshot.Retry != (ConnectionRetrySnapshot{}) {
t.Fatalf("Retry snapshot mismatch: %+v", snapshot.Retry) t.Fatalf("Retry snapshot mismatch: %+v", snapshot.Retry)
} }
@@ -78,11 +132,162 @@ func TestGetServerRuntimeSnapshotDefaults(t *testing.T) {
if !snapshot.HasRuntimeStopCtx { if !snapshot.HasRuntimeStopCtx {
t.Fatalf("HasRuntimeStopCtx mismatch: got %v want true", snapshot.HasRuntimeStopCtx) t.Fatalf("HasRuntimeStopCtx mismatch: got %v want true", snapshot.HasRuntimeStopCtx)
} }
if got, want := snapshot.AuthMode, "none"; got != want {
t.Fatalf("AuthMode mismatch: got %q want %q", got, want)
}
if got, want := snapshot.ProtectionMode, "managed"; got != want {
t.Fatalf("ProtectionMode mismatch: got %q want %q", got, want)
}
if snapshot.PeerAttachRequireExplicitAuth || snapshot.PeerAttachRequireChannelBinding || snapshot.PeerAttachChannelBindingConfigured {
t.Fatalf("unexpected default peer attach policy: %+v", snapshot)
}
if got, want := snapshot.PeerAttachReplayWindow, peerAttachReplayTTL; got != want {
t.Fatalf("PeerAttachReplayWindow mismatch: got %s want %s", got, want)
}
if got, want := snapshot.PeerAttachReplayCapacity, defaultPeerAttachReplayCapacity; got != want {
t.Fatalf("PeerAttachReplayCapacity mismatch: got %d want %d", got, want)
}
if snapshot.PeerAttachExplicitAuth != 0 || snapshot.PeerAttachAuthFallbacks != 0 || snapshot.PeerAttachAuthRejects != 0 || snapshot.PeerAttachDowngradeRejects != 0 || snapshot.PeerAttachBindingRejects != 0 || snapshot.PeerAttachReplayRejects != 0 || snapshot.PeerAttachReplayOverflowRejects != 0 {
t.Fatalf("unexpected default peer attach counters: %+v", snapshot)
}
if got, want := snapshot.BulkChunkSize, defaultBulkChunkSize; got != want {
t.Fatalf("BulkChunkSize mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkWindowBytes, defaultBulkOpenWindowBytes; got != want {
t.Fatalf("BulkWindowBytes mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkMaxInFlight, defaultBulkOpenMaxInFlight; got != want {
t.Fatalf("BulkMaxInFlight mismatch: got %d want %d", got, want)
}
if snapshot.Retry != (ConnectionRetrySnapshot{}) { if snapshot.Retry != (ConnectionRetrySnapshot{}) {
t.Fatalf("Retry snapshot mismatch: %+v", snapshot.Retry) t.Fatalf("Retry snapshot mismatch: %+v", snapshot.Retry)
} }
} }
func TestGetClientRuntimeSnapshotIncludesBulkAttachStats(t *testing.T) {
client := NewClient().(*ClientCommon)
client.SetBulkNetworkProfile(BulkNetworkProfileWAN)
client.SetBulkDefaultOpenMode(BulkOpenModeAuto)
client.bulkAttachAttemptCount.Store(11)
client.bulkAttachRetryCount.Store(5)
client.bulkAttachSuccessCount.Store(6)
client.bulkAttachFallbackCount.Store(3)
snapshot, err := GetClientRuntimeSnapshot(client)
if err != nil {
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
}
if got, want := snapshot.BulkNetworkProfile, "wan"; got != want {
t.Fatalf("BulkNetworkProfile mismatch: got %q want %q", got, want)
}
if got, want := snapshot.BulkDefaultMode, "auto"; got != want {
t.Fatalf("BulkDefaultMode mismatch: got %q want %q", got, want)
}
if got, want := snapshot.BulkChunkSize, 512*1024; got != want {
t.Fatalf("BulkChunkSize mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkWindowBytes, 8*1024*1024; got != want {
t.Fatalf("BulkWindowBytes mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkMaxInFlight, 16; got != want {
t.Fatalf("BulkMaxInFlight mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkAttachLimit, 2; got != want {
t.Fatalf("BulkAttachLimit mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkActiveLimit, 4096; got != want {
t.Fatalf("BulkActiveLimit mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkLaneLimit, 2; got != want {
t.Fatalf("BulkLaneLimit mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkAttachAttempts, int64(11); got != want {
t.Fatalf("BulkAttachAttempts mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkAttachRetries, int64(5); got != want {
t.Fatalf("BulkAttachRetries mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkAttachSuccess, int64(6); got != want {
t.Fatalf("BulkAttachSuccess mismatch: got %d want %d", got, want)
}
if got, want := snapshot.BulkAutoFallbacks, int64(3); got != want {
t.Fatalf("BulkAutoFallbacks mismatch: got %d want %d", got, want)
}
}
func TestGetClientRuntimeSnapshotIncludesAdaptiveBulkSoftPayload(t *testing.T) {
client := NewClient().(*ClientCommon)
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
left, right := net.Pipe()
defer left.Close()
defer right.Close()
binding := newTransportBinding(left, queue)
binding.observeBulkAdaptivePayloadWrite(8*1024*1024, 640*time.Millisecond, 0, nil)
client.setClientSessionRuntime(&clientSessionRuntime{
transport: binding,
stopCtx: stopCtx,
stopFn: stopFn,
queue: queue,
epoch: 1,
})
client.markSessionStarted()
snapshot, err := GetClientRuntimeSnapshot(client)
if err != nil {
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
}
if got, want := snapshot.TransportBulkAdaptiveSoftPayloadBytes, bulkAdaptiveSoftPayloadMinBytes; got != want {
t.Fatalf("TransportBulkAdaptiveSoftPayloadBytes mismatch: got %d want %d", got, want)
}
if got, want := snapshot.TransportStreamAdaptiveSoftPayloadBytes, streamAdaptiveSoftPayloadStartBytes; got != want {
t.Fatalf("TransportStreamAdaptiveSoftPayloadBytes mismatch: got %d want %d", got, want)
}
if got, want := snapshot.TransportStreamAdaptiveWaitThresholdBytes, streamBatchWaitThreshold; got != want {
t.Fatalf("TransportStreamAdaptiveWaitThresholdBytes mismatch: got %d want %d", got, want)
}
if got, want := snapshot.TransportStreamAdaptiveFlushDelay, streamBatchMaxFlushDelay; got != want {
t.Fatalf("TransportStreamAdaptiveFlushDelay mismatch: got %s want %s", got, want)
}
}
func TestGetClientRuntimeSnapshotIncludesAdaptiveStreamTuning(t *testing.T) {
client := NewClient().(*ClientCommon)
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
left, right := net.Pipe()
defer left.Close()
defer right.Close()
binding := newTransportBinding(left, queue)
binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil)
client.setClientSessionRuntime(&clientSessionRuntime{
transport: binding,
stopCtx: stopCtx,
stopFn: stopFn,
queue: queue,
epoch: 1,
})
client.markSessionStarted()
snapshot, err := GetClientRuntimeSnapshot(client)
if err != nil {
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
}
if got, want := snapshot.TransportStreamAdaptiveSoftPayloadBytes, streamAdaptiveSoftPayloadMinBytes; got != want {
t.Fatalf("TransportStreamAdaptiveSoftPayloadBytes mismatch: got %d want %d", got, want)
}
if got, want := snapshot.TransportStreamAdaptiveWaitThresholdBytes, streamAdaptiveWaitThresholdMinBytes; got != want {
t.Fatalf("TransportStreamAdaptiveWaitThresholdBytes mismatch: got %d want %d", got, want)
}
if got := snapshot.TransportStreamAdaptiveFlushDelay; got != 0 {
t.Fatalf("TransportStreamAdaptiveFlushDelay mismatch: got %s want 0", got)
}
}
func TestGetRuntimeSnapshotRejectsNil(t *testing.T) { func TestGetRuntimeSnapshotRejectsNil(t *testing.T) {
if _, err := GetClientRuntimeSnapshot(nil); !errors.Is(err, errClientRuntimeSnapshotNil) { if _, err := GetClientRuntimeSnapshot(nil); !errors.Is(err, errClientRuntimeSnapshotNil) {
t.Fatalf("GetClientRuntimeSnapshot nil error = %v, want %v", err, errClientRuntimeSnapshotNil) t.Fatalf("GetClientRuntimeSnapshot nil error = %v, want %v", err, errClientRuntimeSnapshotNil)
@@ -304,6 +509,134 @@ func TestGetClientConnRuntimeSnapshotExposesDetachState(t *testing.T) {
if snapshot.LastHeartbeatAt.IsZero() { if snapshot.LastHeartbeatAt.IsZero() {
t.Fatal("LastHeartbeatAt should be recorded") t.Fatal("LastHeartbeatAt should be recorded")
} }
if got, want := snapshot.AuthMode, "none"; got != want {
t.Fatalf("AuthMode mismatch: got %q want %q", got, want)
}
if got, want := snapshot.ProtectionMode, "managed"; got != want {
t.Fatalf("ProtectionMode mismatch: got %q want %q", got, want)
}
if snapshot.PeerAttachAuthenticated || snapshot.PeerAttachAuthFallback {
t.Fatalf("unexpected peer attach state: %+v", snapshot)
}
if !snapshot.LastPeerAttachAt.IsZero() {
t.Fatalf("LastPeerAttachAt mismatch: got %v want zero", snapshot.LastPeerAttachAt)
}
}
func TestGetRuntimeSnapshotsIncludePeerAttachSecurityState(t *testing.T) {
secret := []byte("correct horse battery staple")
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
if err := UsePSKOverExternalTransportServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
}
})
client := NewClient().(*ClientCommon)
if err := UsePSKOverExternalTransportClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
}
left, right := net.Pipe()
defer right.Close()
bootstrapPeerAttachLogicalForTest(t, server, right)
if err := client.ConnectByConn(left); err != nil {
t.Fatalf("ConnectByConn failed: %v", err)
}
defer func() {
client.setByeFromServer(true)
_ = client.Stop()
}()
deadline := time.Now().Add(time.Second)
for {
logical := server.GetLogicalConn(client.peerIdentity)
if logical != nil {
authenticated, fallback, _ := logical.peerAttachAuthenticatedSnapshot()
if authenticated && !fallback && logical.protectionModeSnapshot() == ProtectionExternal && server.peerAttachExplicitCount.Load() == 1 {
break
}
}
if time.Now().After(deadline) {
t.Fatal("peer attach security state did not converge before snapshot")
}
time.Sleep(time.Millisecond)
}
clientSnapshot, err := GetClientRuntimeSnapshot(client)
if err != nil {
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
}
if got, want := clientSnapshot.AuthMode, "psk"; got != want {
t.Fatalf("client AuthMode mismatch: got %q want %q", got, want)
}
if got, want := clientSnapshot.ProtectionMode, "external"; got != want {
t.Fatalf("client ProtectionMode mismatch: got %q want %q", got, want)
}
if clientSnapshot.PeerAttachRequireExplicitAuth || clientSnapshot.PeerAttachRequireChannelBinding || clientSnapshot.PeerAttachChannelBindingConfigured {
t.Fatalf("unexpected client peer attach policy snapshot: %+v", clientSnapshot)
}
if !clientSnapshot.PeerAttachAuthenticated || clientSnapshot.PeerAttachAuthFallback {
t.Fatalf("unexpected client peer attach state: %+v", clientSnapshot)
}
if clientSnapshot.LastPeerAttachAt.IsZero() {
t.Fatal("client LastPeerAttachAt should be recorded")
}
serverSnapshot, err := GetServerRuntimeSnapshot(server)
if err != nil {
t.Fatalf("GetServerRuntimeSnapshot failed: %v", err)
}
if got, want := serverSnapshot.AuthMode, "psk"; got != want {
t.Fatalf("server AuthMode mismatch: got %q want %q", got, want)
}
if got, want := serverSnapshot.ProtectionMode, "external"; got != want {
t.Fatalf("server ProtectionMode mismatch: got %q want %q", got, want)
}
if serverSnapshot.PeerAttachRequireExplicitAuth || serverSnapshot.PeerAttachRequireChannelBinding || serverSnapshot.PeerAttachChannelBindingConfigured {
t.Fatalf("unexpected server peer attach policy snapshot: %+v", serverSnapshot)
}
if got, want := serverSnapshot.PeerAttachExplicitAuth, int64(1); got != want {
t.Fatalf("PeerAttachExplicitAuth mismatch: got %d want %d", got, want)
}
if serverSnapshot.PeerAttachAuthFallbacks != 0 || serverSnapshot.PeerAttachAuthRejects != 0 || serverSnapshot.PeerAttachDowngradeRejects != 0 || serverSnapshot.PeerAttachBindingRejects != 0 || serverSnapshot.PeerAttachReplayRejects != 0 || serverSnapshot.PeerAttachReplayOverflowRejects != 0 {
t.Fatalf("unexpected server peer attach counters: %+v", serverSnapshot)
}
logical := server.GetLogicalConn(client.peerIdentity)
if logical == nil {
t.Fatal("server logical should exist after peer attach")
}
logicalSnapshot, err := GetLogicalConnRuntimeSnapshot(logical)
if err != nil {
t.Fatalf("GetLogicalConnRuntimeSnapshot failed: %v", err)
}
if got, want := logicalSnapshot.AuthMode, "psk"; got != want {
t.Fatalf("logical AuthMode mismatch: got %q want %q", got, want)
}
if got, want := logicalSnapshot.ProtectionMode, "external"; got != want {
t.Fatalf("logical ProtectionMode mismatch: got %q want %q", got, want)
}
if !logicalSnapshot.PeerAttachAuthenticated || logicalSnapshot.PeerAttachAuthFallback {
t.Fatalf("unexpected logical peer attach state: %+v", logicalSnapshot)
}
if logicalSnapshot.LastPeerAttachAt.IsZero() {
t.Fatal("logical LastPeerAttachAt should be recorded")
}
clientConnSnapshot, err := GetClientConnRuntimeSnapshot(clientConnFromLogical(logical))
if err != nil {
t.Fatalf("GetClientConnRuntimeSnapshot failed: %v", err)
}
if got, want := clientConnSnapshot.AuthMode, "psk"; got != want {
t.Fatalf("client conn AuthMode mismatch: got %q want %q", got, want)
}
if got, want := clientConnSnapshot.ProtectionMode, "external"; got != want {
t.Fatalf("client conn ProtectionMode mismatch: got %q want %q", got, want)
}
if !clientConnSnapshot.PeerAttachAuthenticated || clientConnSnapshot.PeerAttachAuthFallback {
t.Fatalf("unexpected client conn peer attach state: %+v", clientConnSnapshot)
}
if clientConnSnapshot.LastPeerAttachAt.IsZero() {
t.Fatal("client conn LastPeerAttachAt should be recorded")
}
} }
func TestGetServerDetachedClientRuntimeSnapshotsFiltersAndSorts(t *testing.T) { func TestGetServerDetachedClientRuntimeSnapshotsFiltersAndSorts(t *testing.T) {
+2
View File
@@ -140,6 +140,7 @@ func (c *ClientCommon) cleanupClientSessionResources() {
if runtime := c.getBulkRuntime(); runtime != nil { if runtime := c.getBulkRuntime(); runtime != nil {
runtime.closeAll(errServiceShutdown) runtime.closeAll(errServiceShutdown)
} }
c.closeClientDedicatedSidecar()
} }
func (s *ServerCommon) cleanupServerSessionResources() { func (s *ServerCommon) cleanupServerSessionResources() {
@@ -158,4 +159,5 @@ func (s *ServerCommon) cleanupServerSessionResources() {
if runtime := s.getBulkRuntime(); runtime != nil { if runtime := s.getBulkRuntime(); runtime != nil {
runtime.closeAll(errServiceShutdown) runtime.closeAll(errServiceShutdown)
} }
s.closeAllServerDedicatedSidecars()
} }
+2 -2
View File
@@ -76,7 +76,7 @@ func startSignalRoundTripServerForBenchmark(b *testing.B) (*ServerCommon, string
server.SetLink("signal-roundtrip", func(msg *Message) { server.SetLink("signal-roundtrip", func(msg *Message) {
_ = msg.Reply([]byte("ack:" + string(msg.Value))) _ = msg.Reply([]byte("ack:" + string(msg.Value)))
}) })
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { if err := server.Listen("tcp", benchmarkTCPListenAddr(b)); err != nil {
if benchmarkListenPermissionDenied(err) { if benchmarkListenPermissionDenied(err) {
b.Skipf("tcp benchmark requires local listen permission: %v", err) b.Skipf("tcp benchmark requires local listen permission: %v", err)
} }
@@ -102,7 +102,7 @@ func newSignalRoundTripBenchmarkClient(b *testing.B, addr string) *ClientCommon
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
b.Fatalf("UseModernPSKClient failed: %v", err) b.Fatalf("UseModernPSKClient failed: %v", err)
} }
if err := client.Connect("tcp", addr); err != nil { if err := client.Connect("tcp", benchmarkTCPDialAddr(b, addr)); err != nil {
b.Fatalf("client Connect failed: %v", err) b.Fatalf("client Connect failed: %v", err)
} }
return client return client
+30 -14
View File
@@ -3,20 +3,24 @@ package notify
import "time" import "time"
type snapshotBindingDiagnostics struct { type snapshotBindingDiagnostics struct {
BindingOwner string BindingOwner string
BindingAlive bool BindingAlive bool
BindingCurrent bool BindingCurrent bool
BindingReason string BindingReason string
BindingError string BindingError string
TransportAttached bool BindingBulkAdaptiveSoftPayloadBytes int
TransportHasRuntimeConn bool BindingStreamAdaptiveSoftPayloadBytes int
TransportCurrent bool BindingStreamAdaptiveWaitThresholdBytes int
TransportDetachReason string BindingStreamAdaptiveFlushDelay time.Duration
TransportDetachKind string TransportAttached bool
TransportDetachError string TransportHasRuntimeConn bool
TransportDetachGeneration uint64 TransportCurrent bool
TransportDetachedAt time.Time TransportDetachReason string
ReattachEligible bool TransportDetachKind string
TransportDetachError string
TransportDetachGeneration uint64
TransportDetachedAt time.Time
ReattachEligible bool
} }
func snapshotBindingDiagnosticsFromClient(c *ClientCommon, sessionEpoch uint64) snapshotBindingDiagnostics { func snapshotBindingDiagnosticsFromClient(c *ClientCommon, sessionEpoch uint64) snapshotBindingDiagnostics {
@@ -36,6 +40,12 @@ func snapshotBindingDiagnosticsFromClient(c *ClientCommon, sessionEpoch uint64)
diag.TransportAttached = c.clientTransportAttachedSnapshot() diag.TransportAttached = c.clientTransportAttachedSnapshot()
diag.TransportHasRuntimeConn = c.clientTransportConnSnapshot() != nil diag.TransportHasRuntimeConn = c.clientTransportConnSnapshot() != nil
diag.TransportCurrent = diag.BindingCurrent && diag.TransportAttached diag.TransportCurrent = diag.BindingCurrent && diag.TransportAttached
if binding := c.clientTransportBindingSnapshot(); binding != nil {
diag.BindingBulkAdaptiveSoftPayloadBytes = binding.bulkAdaptiveSoftPayloadBytesSnapshot()
diag.BindingStreamAdaptiveSoftPayloadBytes = binding.streamAdaptiveSoftPayloadBytesSnapshot()
diag.BindingStreamAdaptiveWaitThresholdBytes = binding.streamAdaptiveWaitThresholdBytesSnapshot()
diag.BindingStreamAdaptiveFlushDelay = binding.streamAdaptiveFlushDelaySnapshot()
}
return diag return diag
} }
@@ -72,5 +82,11 @@ func snapshotBindingDiagnosticsFromLogical(logical *LogicalConn, transport *Tran
diag.TransportCurrent = runtime.TransportAttached diag.TransportCurrent = runtime.TransportAttached
} }
diag.BindingCurrent = diag.BindingAlive && diag.TransportCurrent diag.BindingCurrent = diag.BindingAlive && diag.TransportCurrent
if binding := logical.transportBindingSnapshot(); binding != nil {
diag.BindingBulkAdaptiveSoftPayloadBytes = binding.bulkAdaptiveSoftPayloadBytesSnapshot()
diag.BindingStreamAdaptiveSoftPayloadBytes = binding.streamAdaptiveSoftPayloadBytesSnapshot()
diag.BindingStreamAdaptiveWaitThresholdBytes = binding.streamAdaptiveWaitThresholdBytesSnapshot()
diag.BindingStreamAdaptiveFlushDelay = binding.streamAdaptiveFlushDelaySnapshot()
}
return diag return diag
} }
+171 -31
View File
@@ -7,6 +7,7 @@ import (
"net" "net"
"os" "os"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
@@ -88,12 +89,67 @@ type streamCloseSender func(context.Context, *streamHandle, bool) error
type streamResetSender func(context.Context, *streamHandle, string) error type streamResetSender func(context.Context, *streamHandle, string) error
type streamDataSender func(context.Context, *streamHandle, []byte) error type streamDataSender func(context.Context, *streamHandle, []byte) error
type streamReadChunk struct {
data []byte
release func()
}
func (c *streamReadChunk) clear() {
if c == nil {
return
}
if c.release != nil {
c.release()
}
c.data = nil
c.release = nil
}
type streamReadPayloadOwner struct {
refs atomic.Int32
release func()
}
func newStreamReadPayloadOwner(release func()) *streamReadPayloadOwner {
if release == nil {
return nil
}
owner := &streamReadPayloadOwner{release: release}
owner.refs.Store(1)
return owner
}
func (o *streamReadPayloadOwner) retainChunk() func() {
if o == nil {
return nil
}
o.refs.Add(1)
return o.releaseChunk
}
func (o *streamReadPayloadOwner) releaseChunk() {
if o == nil {
return
}
if o.refs.Add(-1) == 0 && o.release != nil {
o.release()
}
}
func (o *streamReadPayloadOwner) done() {
if o == nil {
return
}
o.releaseChunk()
}
type streamHandle struct { type streamHandle struct {
runtime *streamRuntime runtime *streamRuntime
runtimeScope string runtimeScope string
id string id string
dataID uint64 dataID uint64
outboundSeq uint64 fastPathVersion uint8
outboundSeq atomic.Uint64
channel StreamChannel channel StreamChannel
metadata StreamMetadata metadata StreamMetadata
sessionEpoch uint64 sessionEpoch uint64
@@ -122,8 +178,8 @@ type streamHandle struct {
remoteClosed bool remoteClosed bool
peerReadClosed bool peerReadClosed bool
resetErr error resetErr error
readQueue [][]byte readQueue []streamReadChunk
readBuf []byte readBuf streamReadChunk
bufferedBytes int bufferedBytes int
readNotify chan struct{} readNotify chan struct{}
readDeadline time.Time readDeadline time.Time
@@ -132,6 +188,9 @@ type streamHandle struct {
writeDeadlineOverride bool writeDeadlineOverride bool
readDeadlineNotify chan struct{} readDeadlineNotify chan struct{}
writeDeadlineNotify chan struct{} writeDeadlineNotify chan struct{}
writeWaitSeq uint64
writeWaitCancel context.CancelFunc
writeWaitChanged chan struct{}
bytesRead int64 bytesRead int64
bytesWritten int64 bytesWritten int64
readCalls int64 readCalls int64
@@ -157,6 +216,7 @@ func newStreamHandle(parent context.Context, runtime *streamRuntime, runtimeScop
runtimeScope: runtimeScope, runtimeScope: runtimeScope,
id: req.StreamID, id: req.StreamID,
dataID: req.DataID, dataID: req.DataID,
fastPathVersion: normalizeStreamFastPathVersion(req.FastPathVersion),
channel: normalizeStreamChannel(req.Channel), channel: normalizeStreamChannel(req.Channel),
metadata: cloneStreamMetadata(req.Metadata), metadata: cloneStreamMetadata(req.Metadata),
sessionEpoch: sessionEpoch, sessionEpoch: sessionEpoch,
@@ -224,13 +284,25 @@ func (s *streamHandle) dataIDSnapshot() uint64 {
} }
func (s *streamHandle) nextOutboundDataSeq() uint64 { func (s *streamHandle) nextOutboundDataSeq() uint64 {
return s.reserveOutboundDataSeqs(1)
}
func (s *streamHandle) reserveOutboundDataSeqs(count int) uint64 {
if s == nil { if s == nil {
return 0 return 0
} }
s.mu.Lock() if count <= 0 {
defer s.mu.Unlock() count = 1
s.outboundSeq++ }
return s.outboundSeq end := s.outboundSeq.Add(uint64(count))
return end - uint64(count) + 1
}
func (s *streamHandle) fastPathVersionSnapshot() uint8 {
if s == nil {
return streamFastPathVersionV1
}
return normalizeStreamFastPathVersion(s.fastPathVersion)
} }
func (s *streamHandle) Channel() StreamChannel { func (s *streamHandle) Channel() StreamChannel {
@@ -321,20 +393,23 @@ func (s *streamHandle) Read(p []byte) (int, error) {
for { for {
s.mu.Lock() s.mu.Lock()
localReadClosed := s.localReadClosed localReadClosed := s.localReadClosed
if len(s.readBuf) > 0 { if len(s.readBuf.data) > 0 {
n := copy(p, s.readBuf) n := copy(p, s.readBuf.data)
s.readBuf = s.readBuf[n:] s.readBuf.data = s.readBuf.data[n:]
s.bufferedBytes -= n s.bufferedBytes -= n
if s.bufferedBytes < 0 { if s.bufferedBytes < 0 {
s.bufferedBytes = 0 s.bufferedBytes = 0
} }
if len(s.readBuf.data) == 0 {
s.readBuf.clear()
}
s.recordReadLocked(n, time.Now()) s.recordReadLocked(n, time.Now())
s.mu.Unlock() s.mu.Unlock()
return n, nil return n, nil
} }
if len(s.readQueue) > 0 { if len(s.readQueue) > 0 {
s.readBuf = s.readQueue[0] s.readBuf = s.readQueue[0]
s.readQueue[0] = nil s.readQueue[0] = streamReadChunk{}
s.readQueue = s.readQueue[1:] s.readQueue = s.readQueue[1:]
s.mu.Unlock() s.mu.Unlock()
continue continue
@@ -377,6 +452,7 @@ func (s *streamHandle) Write(p []byte) (int, error) {
sendDataFn := s.sendDataFn sendDataFn := s.sendDataFn
chunkSize := s.chunkSize chunkSize := s.chunkSize
writeTimeout := s.writeTimeout writeTimeout := s.writeTimeout
writeDeadlineOverride := s.writeDeadlineOverride
streamCtx := s.ctx streamCtx := s.ctx
runtime := s.runtime runtime := s.runtime
s.mu.Unlock() s.mu.Unlock()
@@ -399,6 +475,20 @@ func (s *streamHandle) Write(p []byte) (int, error) {
end = len(p) end = len(p)
} }
chunk := p[written:end] chunk := p[written:end]
if !writeDeadlineOverride && writeTimeout <= 0 {
if tryAcquireStreamOutboundBudget(runtime, len(chunk)) {
err := sendDataFn(streamCtx, s, chunk)
releaseStreamOutboundBudget(runtime, len(chunk))
if err != nil {
if written > 0 {
s.recordWrite(written, time.Now())
}
return written, s.normalizeWriteError(err)
}
written = end
continue
}
}
sendCtx, cancel, deadlineChanged, err := s.newWriteContext(streamCtx, writeTimeout) sendCtx, cancel, deadlineChanged, err := s.newWriteContext(streamCtx, writeTimeout)
if err != nil { if err != nil {
if written > 0 { if written > 0 {
@@ -464,7 +554,15 @@ func (s *streamHandle) SetWriteDeadline(deadline time.Time) error {
s.writeDeadline = deadline s.writeDeadline = deadline
s.writeDeadlineOverride = true s.writeDeadlineOverride = true
signalStreamDeadlineChangeLocked(&s.writeDeadlineNotify) signalStreamDeadlineChangeLocked(&s.writeDeadlineNotify)
waitCancel := s.writeWaitCancel
if s.writeWaitChanged != nil {
close(s.writeWaitChanged)
s.writeWaitChanged = nil
}
s.mu.Unlock() s.mu.Unlock()
if waitCancel != nil {
waitCancel()
}
return nil return nil
} }
@@ -535,7 +633,6 @@ func (s *streamHandle) newWriteContext(parent context.Context, writeTimeout time
} }
s.mu.Lock() s.mu.Lock()
deadline := s.effectiveWriteDeadlineLocked(time.Now(), writeTimeout) deadline := s.effectiveWriteDeadlineLocked(time.Now(), writeTimeout)
deadlineNotify := s.writeDeadlineNotify
s.mu.Unlock() s.mu.Unlock()
if !deadline.IsZero() && !deadline.After(time.Now()) { if !deadline.IsZero() && !deadline.After(time.Now()) {
return nil, func() {}, nil, os.ErrDeadlineExceeded return nil, func() {}, nil, os.ErrDeadlineExceeded
@@ -548,19 +645,20 @@ func (s *streamHandle) newWriteContext(parent context.Context, writeTimeout time
baseCtx, baseCancel = context.WithCancel(parent) baseCtx, baseCancel = context.WithCancel(parent)
} }
changed := make(chan struct{}) changed := make(chan struct{})
done := make(chan struct{}) s.mu.Lock()
go func() { s.writeWaitSeq++
defer close(done) waitSeq := s.writeWaitSeq
select { s.writeWaitCancel = baseCancel
case <-baseCtx.Done(): s.writeWaitChanged = changed
case <-deadlineNotify: s.mu.Unlock()
close(changed)
baseCancel()
}
}()
cancel := func() { cancel := func() {
baseCancel() baseCancel()
<-done s.mu.Lock()
if s.writeWaitSeq == waitSeq {
s.writeWaitCancel = nil
s.writeWaitChanged = nil
}
s.mu.Unlock()
} }
return baseCtx, cancel, changed, nil return baseCtx, cancel, changed, nil
} }
@@ -783,39 +881,61 @@ func (s *streamHandle) pushOwnedChunk(chunk []byte) error {
return s.pushChunkWithOwnership(chunk, true) return s.pushChunkWithOwnership(chunk, true)
} }
func (s *streamHandle) pushOwnedChunkWithRelease(chunk []byte, release func()) error {
return s.pushChunkWithOwnershipAndRelease(chunk, true, release)
}
func (s *streamHandle) pushChunkWithOwnership(chunk []byte, owned bool) error { func (s *streamHandle) pushChunkWithOwnership(chunk []byte, owned bool) error {
return s.pushChunkWithOwnershipAndRelease(chunk, owned, nil)
}
func (s *streamHandle) pushChunkWithOwnershipAndRelease(chunk []byte, owned bool, release func()) error {
if s == nil { if s == nil {
return io.ErrClosedPipe return io.ErrClosedPipe
} }
if len(chunk) == 0 { if len(chunk) == 0 {
if release != nil {
release()
}
return nil return nil
} }
stored := chunk stored := streamReadChunk{data: chunk, release: release}
if !owned { if !owned {
stored = append([]byte(nil), chunk...) stored.data = append([]byte(nil), chunk...)
if stored.release != nil {
stored.release()
stored.release = nil
}
} }
s.mu.Lock() s.mu.Lock()
if s.resetErr != nil { if s.resetErr != nil {
err := s.resetErr err := s.resetErr
s.mu.Unlock() s.mu.Unlock()
stored.clear()
return err return err
} }
if s.inboundQueueLimit > 0 && s.bufferedChunkCountLocked() >= s.inboundQueueLimit { if s.inboundQueueLimit > 0 && s.bufferedChunkCountLocked() >= s.inboundQueueLimit {
err := s.markResetLocked(errStreamBackpressureExceeded) err := s.markResetLocked(errStreamBackpressureExceeded)
s.mu.Unlock() s.mu.Unlock()
stored.clear()
s.notifyReadable() s.notifyReadable()
s.finalize() s.finalize()
return err return err
} }
if s.inboundBytesLimit > 0 && s.bufferedBytes+len(stored) > s.inboundBytesLimit { if s.inboundBytesLimit > 0 && s.bufferedBytes+len(stored.data) > s.inboundBytesLimit {
err := s.markResetLocked(errStreamBackpressureExceeded) err := s.markResetLocked(errStreamBackpressureExceeded)
s.mu.Unlock() s.mu.Unlock()
stored.clear()
s.notifyReadable() s.notifyReadable()
s.finalize() s.finalize()
return err return err
} }
s.readQueue = append(s.readQueue, stored) if len(s.readBuf.data) == 0 && len(s.readQueue) == 0 {
s.bufferedBytes += len(stored) s.readBuf = stored
} else {
s.readQueue = append(s.readQueue, stored)
}
s.bufferedBytes += len(stored.data)
s.notifyReadableLocked() s.notifyReadableLocked()
s.mu.Unlock() s.mu.Unlock()
return nil return nil
@@ -836,11 +956,12 @@ func (s *streamHandle) clearBufferedDataLocked() {
if s == nil { if s == nil {
return return
} }
s.readBuf.clear()
for i := range s.readQueue { for i := range s.readQueue {
s.readQueue[i] = nil s.readQueue[i].clear()
} }
s.readQueue = nil s.readQueue = nil
s.readBuf = nil s.readBuf = streamReadChunk{}
s.bufferedBytes = 0 s.bufferedBytes = 0
} }
@@ -849,7 +970,7 @@ func (s *streamHandle) bufferedChunkCountLocked() int {
return 0 return 0
} }
count := len(s.readQueue) count := len(s.readQueue)
if len(s.readBuf) > 0 { if len(s.readBuf.data) > 0 {
count++ count++
} }
return count return count
@@ -917,6 +1038,10 @@ func (s *streamHandle) snapshot() StreamSnapshot {
snapshot.BindingCurrent = diag.BindingCurrent snapshot.BindingCurrent = diag.BindingCurrent
snapshot.BindingReason = diag.BindingReason snapshot.BindingReason = diag.BindingReason
snapshot.BindingError = diag.BindingError snapshot.BindingError = diag.BindingError
snapshot.BindingBulkAdaptiveSoftPayloadBytes = diag.BindingBulkAdaptiveSoftPayloadBytes
snapshot.BindingStreamAdaptiveSoftPayloadBytes = diag.BindingStreamAdaptiveSoftPayloadBytes
snapshot.BindingStreamAdaptiveWaitThresholdBytes = diag.BindingStreamAdaptiveWaitThresholdBytes
snapshot.BindingStreamAdaptiveFlushDelay = diag.BindingStreamAdaptiveFlushDelay
snapshot.TransportAttached = diag.TransportAttached snapshot.TransportAttached = diag.TransportAttached
snapshot.TransportHasRuntimeConn = diag.TransportHasRuntimeConn snapshot.TransportHasRuntimeConn = diag.TransportHasRuntimeConn
snapshot.TransportCurrent = diag.TransportCurrent snapshot.TransportCurrent = diag.TransportCurrent
@@ -1057,8 +1182,23 @@ func acquireStreamOutboundBudget(runtime *streamRuntime, ctx context.Context, si
return runtime.acquireOutbound(ctx, size) return runtime.acquireOutbound(ctx, size)
} }
func tryAcquireStreamOutboundBudget(runtime *streamRuntime, size int) bool {
if runtime == nil {
return true
}
return runtime.tryAcquireOutbound(size)
}
func releaseStreamOutboundBudget(runtime *streamRuntime, size int) {
if runtime == nil {
return
}
runtime.releaseOutbound(size)
}
func normalizeStreamOpenRequest(req StreamOpenRequest) StreamOpenRequest { func normalizeStreamOpenRequest(req StreamOpenRequest) StreamOpenRequest {
req.Channel = normalizeStreamChannel(req.Channel) req.Channel = normalizeStreamChannel(req.Channel)
req.FastPathVersion = normalizeStreamFastPathVersion(req.FastPathVersion)
req.Metadata = cloneStreamMetadata(req.Metadata) req.Metadata = cloneStreamMetadata(req.Metadata)
return req return req
} }
+6
View File
@@ -0,0 +1,6 @@
package notify
type streamBatchCodec struct {
encodeSingle func(streamFastDataFrame) ([]byte, error)
encodeBatch func([]streamFastDataFrame) ([]byte, error)
}
+582
View File
@@ -0,0 +1,582 @@
package notify
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
)
const (
streamBatchMaxPayloads = 64
streamBatchMaxPayloadBytes = 2 * 1024 * 1024
streamBatchMaxFlushDelay = 50 * time.Microsecond
streamBatchWaitThreshold = 128 * 1024
)
const (
streamBatchRequestQueued int32 = iota
streamBatchRequestStarted
streamBatchRequestCanceled
)
type streamBatchRequestState struct {
value atomic.Int32
}
type streamBatchRequest struct {
ctx context.Context
frame streamFastDataFrame
hasFrame bool
encodedPayload []byte
hasEncoded bool
frames []streamFastDataFrame
fastPathVersion uint8
deadline time.Time
done chan error
state *streamBatchRequestState
}
type streamBatchSender struct {
binding *transportBinding
codec streamBatchCodec
writeTimeoutProvider func() time.Duration
reqCh chan streamBatchRequest
stopCh chan struct{}
doneCh chan struct{}
stopOnce sync.Once
flushMu sync.Mutex
queued atomic.Int64
errMu sync.Mutex
err error
}
func newStreamBatchSender(binding *transportBinding, codec streamBatchCodec, writeTimeoutProvider func() time.Duration) *streamBatchSender {
sender := &streamBatchSender{
binding: binding,
codec: codec,
writeTimeoutProvider: writeTimeoutProvider,
reqCh: make(chan streamBatchRequest, streamBatchMaxPayloads*4),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
go sender.run()
return sender
}
func (s *streamBatchSender) submitData(ctx context.Context, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error {
if s == nil {
return errTransportDetached
}
if len(payload) == 0 {
return nil
}
return s.submitRequest(streamBatchRequest{
ctx: ctx,
frame: streamFastDataFrame{
DataID: dataID,
Seq: seq,
Payload: payload,
},
hasFrame: true,
fastPathVersion: normalizeStreamFastPathVersion(fastPathVersion),
})
}
func (s *streamBatchSender) submitEncoded(ctx context.Context, fastPathVersion uint8, payload []byte) error {
if s == nil {
return errTransportDetached
}
if len(payload) == 0 {
return nil
}
return s.submitRequest(streamBatchRequest{
ctx: ctx,
encodedPayload: payload,
hasEncoded: true,
fastPathVersion: normalizeStreamFastPathVersion(fastPathVersion),
})
}
func (s *streamBatchSender) submitFrames(ctx context.Context, fastPathVersion uint8, frames []streamFastDataFrame) error {
if s == nil {
return errTransportDetached
}
if len(frames) == 0 {
return nil
}
queuedFrames := append([]streamFastDataFrame(nil), frames...)
req := streamBatchRequest{
ctx: ctx,
frames: queuedFrames,
fastPathVersion: normalizeStreamFastPathVersion(fastPathVersion),
}
if len(queuedFrames) == 1 {
req.frame = queuedFrames[0]
req.frames = nil
req.hasFrame = true
}
return s.submitRequest(req)
}
func (s *streamBatchSender) submitRequest(req streamBatchRequest) error {
if s == nil {
return errTransportDetached
}
if req.ctx == nil {
req.ctx = context.Background()
}
if !req.hasFrame && !req.hasEncoded && len(req.frames) == 0 {
return nil
}
req.fastPathVersion = normalizeStreamFastPathVersion(req.fastPathVersion)
req.done = make(chan error, 1)
req.state = &streamBatchRequestState{}
if deadline, ok := req.ctx.Deadline(); ok {
req.deadline = deadline
}
if err := s.errSnapshot(); err != nil {
return err
}
if s.shouldDirectSubmit(req) {
if submitted, err := s.tryDirectSubmit(req); submitted {
return err
}
}
s.queued.Add(1)
select {
case <-req.ctx.Done():
s.queued.Add(-1)
return normalizeStreamDeadlineError(req.ctx.Err())
case <-s.stopCh:
s.queued.Add(-1)
return s.stoppedErr()
case s.reqCh <- req:
}
select {
case err := <-req.done:
return err
case <-req.ctx.Done():
if req.tryCancel() {
return normalizeStreamDeadlineError(req.ctx.Err())
}
return <-req.done
}
}
func (s *streamBatchSender) shouldDirectSubmit(req streamBatchRequest) bool {
if req.hasEncoded {
return false
}
if !req.hasFrame && len(req.frames) == 0 {
return false
}
return !streamFastPathSupportsBatch(req.fastPathVersion)
}
func (s *streamBatchSender) tryDirectSubmit(req streamBatchRequest) (bool, error) {
if s == nil {
return true, errTransportDetached
}
if err := s.errSnapshot(); err != nil {
return true, err
}
select {
case <-req.ctx.Done():
return true, normalizeStreamDeadlineError(req.ctx.Err())
case <-s.stopCh:
return true, s.stoppedErr()
default:
}
if s.queued.Load() != 0 {
return false, nil
}
if !s.flushMu.TryLock() {
return false, nil
}
defer s.flushMu.Unlock()
if s.queued.Load() != 0 {
return false, nil
}
if err := s.errSnapshot(); err != nil {
return true, err
}
if !req.tryStart() {
return true, req.canceledErr()
}
if err := req.contextErr(); err != nil {
return true, err
}
if err := s.flush([]streamBatchRequest{req}); err != nil {
s.setErr(err)
s.failPending(err)
return true, err
}
return true, nil
}
func (s *streamBatchSender) run() {
defer close(s.doneCh)
for {
req, ok := s.nextRequest()
if !ok {
return
}
batch := []streamBatchRequest{req}
batchBytes := streamBatchRequestApproxBytes(req)
softPayloadLimit := s.batchSoftPayloadLimit()
waitThreshold := s.batchWaitThreshold()
flushDelay := s.batchFlushDelay()
timer := (*time.Timer)(nil)
timerCh := (<-chan time.Time)(nil)
if flushDelay > 0 && batchBytes < waitThreshold && batchBytes < softPayloadLimit && len(batch) < streamBatchMaxPayloads {
timer = time.NewTimer(flushDelay)
timerCh = timer.C
}
drain:
for len(batch) < streamBatchMaxPayloads && batchBytes < softPayloadLimit {
if timerCh == nil {
select {
case <-s.stopCh:
s.failPending(s.stoppedErr())
return
case next := <-s.reqCh:
batch = append(batch, next)
batchBytes += streamBatchRequestApproxBytes(next)
default:
break drain
}
continue
}
select {
case <-s.stopCh:
if timer != nil {
timer.Stop()
}
s.failPending(s.stoppedErr())
return
case next := <-s.reqCh:
batch = append(batch, next)
batchBytes += streamBatchRequestApproxBytes(next)
case <-timerCh:
timerCh = nil
break drain
}
}
if timer != nil {
if !timer.Stop() && timerCh != nil {
select {
case <-timer.C:
default:
}
}
}
s.flushMu.Lock()
err := s.errSnapshot()
active := make([]streamBatchRequest, 0, len(batch))
for _, item := range batch {
if !item.tryStart() {
s.finishRequest(item, item.canceledErr())
continue
}
if itemErr := item.contextErr(); itemErr != nil {
s.finishRequest(item, itemErr)
continue
}
active = append(active, item)
}
if len(active) == 0 {
s.flushMu.Unlock()
continue
}
if err == nil {
err = s.flush(active)
}
s.flushMu.Unlock()
if err != nil {
s.setErr(err)
for _, item := range active {
s.finishRequest(item, err)
}
s.failPending(err)
return
}
for _, item := range active {
s.finishRequest(item, nil)
}
}
}
func (s *streamBatchSender) nextRequest() (streamBatchRequest, bool) {
select {
case <-s.stopCh:
s.failPending(s.stoppedErr())
return streamBatchRequest{}, false
case req := <-s.reqCh:
return req, true
}
}
func (s *streamBatchSender) flush(requests []streamBatchRequest) error {
if s == nil || s.binding == nil {
return errTransportDetached
}
queue := s.binding.queueSnapshot()
if queue == nil {
return errTransportFrameQueueUnavailable
}
payloads, err := s.encodeRequests(requests)
if err != nil {
return err
}
writeTimeout := s.transportWriteTimeout()
payloadBytes := 0
for _, payload := range payloads {
payloadBytes += len(payload)
}
started := time.Now()
err = s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error {
return writeFramedPayloadBatchUnlocked(conn, queue, payloads)
})
s.binding.observeStreamAdaptivePayloadWrite(payloadBytes, time.Since(started), writeTimeout, err)
return err
}
func (s *streamBatchSender) transportWriteTimeout() time.Duration {
if s == nil || s.writeTimeoutProvider == nil {
return 0
}
return s.writeTimeoutProvider()
}
func (s *streamBatchSender) batchSoftPayloadLimit() int {
if s == nil || s.binding == nil {
return streamAdaptiveSoftPayloadFallbackBytes
}
return s.binding.streamAdaptiveSoftPayloadBytesSnapshot()
}
func (s *streamBatchSender) batchWaitThreshold() int {
if s == nil || s.binding == nil {
return streamBatchWaitThreshold
}
return s.binding.streamAdaptiveWaitThresholdBytesSnapshot()
}
func (s *streamBatchSender) batchFlushDelay() time.Duration {
if s == nil || s.binding == nil {
return streamBatchMaxFlushDelay
}
return s.binding.streamAdaptiveFlushDelaySnapshot()
}
func (s *streamBatchSender) batchPlainPayloadLimit() int {
limit := s.batchSoftPayloadLimit()
if limit <= streamFastBatchHeaderLen {
return streamFastBatchHeaderLen + 1
}
return minInt(limit, streamFastBatchMaxPlainBytes)
}
func (s *streamBatchSender) encodeRequests(requests []streamBatchRequest) ([][]byte, error) {
if len(requests) == 0 {
return nil, nil
}
payloads := make([][]byte, 0, len(requests))
batchPlainLimit := s.batchPlainPayloadLimit()
var batch []streamFastDataFrame
flushBatch := func() error {
if len(batch) == 0 {
return nil
}
payload, err := s.encodeBatch(batch)
if err != nil {
return err
}
payloads = append(payloads, payload)
batch = batch[:0]
return nil
}
batchBytes := streamFastBatchHeaderLen
appendFrame := func(frame streamFastDataFrame, fastPathVersion uint8) error {
if !streamFastPathSupportsBatch(fastPathVersion) {
if err := flushBatch(); err != nil {
return err
}
batchBytes = streamFastBatchHeaderLen
payload, err := s.encodeSingle(frame)
if err != nil {
return err
}
payloads = append(payloads, payload)
return nil
}
frameLen := streamFastBatchFrameLen(frame)
if frameLen+streamFastBatchHeaderLen > batchPlainLimit {
if err := flushBatch(); err != nil {
return err
}
batchBytes = streamFastBatchHeaderLen
payload, err := s.encodeSingle(frame)
if err != nil {
return err
}
payloads = append(payloads, payload)
return nil
}
if len(batch) > 0 && (len(batch) >= streamFastBatchMaxItems || batchBytes+frameLen > batchPlainLimit) {
if err := flushBatch(); err != nil {
return err
}
batchBytes = streamFastBatchHeaderLen
}
if batch == nil {
batch = make([]streamFastDataFrame, 0, minInt(len(requests), streamFastBatchMaxItems))
}
batch = append(batch, frame)
batchBytes += frameLen
return nil
}
for _, req := range requests {
if req.hasFrame {
if err := appendFrame(req.frame, req.fastPathVersion); err != nil {
return nil, err
}
}
for _, frame := range req.frames {
if err := appendFrame(frame, req.fastPathVersion); err != nil {
return nil, err
}
}
if req.hasEncoded {
if err := flushBatch(); err != nil {
return nil, err
}
batchBytes = streamFastBatchHeaderLen
payloads = append(payloads, req.encodedPayload)
}
}
if err := flushBatch(); err != nil {
return nil, err
}
return payloads, nil
}
func streamBatchRequestApproxBytes(req streamBatchRequest) int {
total := 0
if req.hasFrame {
total += streamFastBatchFrameLen(req.frame)
}
for _, frame := range req.frames {
total += streamFastBatchFrameLen(frame)
}
if req.hasEncoded {
total += len(req.encodedPayload)
}
return total
}
func (s *streamBatchSender) encodeSingle(frame streamFastDataFrame) ([]byte, error) {
if s == nil || s.codec.encodeSingle == nil {
return nil, errTransportDetached
}
return s.codec.encodeSingle(frame)
}
func (s *streamBatchSender) encodeBatch(frames []streamFastDataFrame) ([]byte, error) {
if len(frames) == 1 || s.codec.encodeBatch == nil {
return s.encodeSingle(frames[0])
}
return s.codec.encodeBatch(frames)
}
func (s *streamBatchSender) finishRequest(req streamBatchRequest, err error) {
if s != nil {
s.queued.Add(-1)
}
req.done <- err
}
func (s *streamBatchSender) stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
s.setErr(errTransportDetached)
close(s.stopCh)
})
<-s.doneCh
}
func (s *streamBatchSender) failPending(err error) {
for {
select {
case item := <-s.reqCh:
s.finishRequest(item, err)
default:
return
}
}
}
func (s *streamBatchSender) setErr(err error) {
if s == nil || err == nil {
return
}
s.errMu.Lock()
if s.err == nil {
s.err = err
}
s.errMu.Unlock()
}
func (s *streamBatchSender) errSnapshot() error {
if s == nil {
return errTransportDetached
}
s.errMu.Lock()
defer s.errMu.Unlock()
return s.err
}
func (s *streamBatchSender) stoppedErr() error {
if err := s.errSnapshot(); err != nil {
return err
}
return errTransportDetached
}
func (r streamBatchRequest) contextErr() error {
if r.ctx == nil {
return nil
}
select {
case <-r.ctx.Done():
return normalizeStreamDeadlineError(r.ctx.Err())
default:
return nil
}
}
func (r streamBatchRequest) tryStart() bool {
if r.state == nil {
return true
}
return r.state.value.CompareAndSwap(streamBatchRequestQueued, streamBatchRequestStarted)
}
func (r streamBatchRequest) tryCancel() bool {
if r.state == nil {
return false
}
return r.state.value.CompareAndSwap(streamBatchRequestQueued, streamBatchRequestCanceled)
}
func (r streamBatchRequest) canceledErr() error {
if err := r.contextErr(); err != nil {
return err
}
return context.Canceled
}
+116 -20
View File
@@ -56,7 +56,59 @@ func BenchmarkStreamTCPThroughput(b *testing.B) {
for _, tc := range cases { for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) { b.Run(tc.name, func(b *testing.B) {
benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg) benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg, benchmarkTransportSecurityModernPSK)
})
}
}
func BenchmarkStreamTCPThroughputTrustedRaw(b *testing.B) {
cases := []struct {
name string
payloadSize int
cfg StreamConfig
}{
{
name: "default_64KiB",
payloadSize: 64 * 1024,
},
{
name: "tuned_256KiB",
payloadSize: 256 * 1024,
cfg: StreamConfig{
ChunkSize: 256 * 1024,
InboundQueueLimit: 256,
InboundBufferedBytesLimit: 32 * 1024 * 1024,
OutboundWindowBytes: 8 * 1024 * 1024,
OutboundMaxInFlightChunks: 32,
},
},
{
name: "tuned_512KiB",
payloadSize: 512 * 1024,
cfg: StreamConfig{
ChunkSize: 512 * 1024,
InboundQueueLimit: 256,
InboundBufferedBytesLimit: 64 * 1024 * 1024,
OutboundWindowBytes: 16 * 1024 * 1024,
OutboundMaxInFlightChunks: 32,
},
},
{
name: "tuned_1MiB",
payloadSize: 1024 * 1024,
cfg: StreamConfig{
ChunkSize: 1024 * 1024,
InboundQueueLimit: 256,
InboundBufferedBytesLimit: 64 * 1024 * 1024,
OutboundWindowBytes: 16 * 1024 * 1024,
OutboundMaxInFlightChunks: 32,
},
},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg, benchmarkTransportSecurityTrustedRaw)
}) })
} }
} }
@@ -108,19 +160,69 @@ func BenchmarkStreamTCPThroughputConcurrent(b *testing.B) {
for _, tc := range cases { for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) { b.Run(tc.name, func(b *testing.B) {
benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg) benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg, benchmarkTransportSecurityModernPSK)
}) })
} }
} }
func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfig) { func BenchmarkStreamTCPThroughputConcurrentTrustedRaw(b *testing.B) {
cases := []struct {
name string
payloadSize int
concurrency int
cfg StreamConfig
}{
{
name: "streams_2_512KiB",
payloadSize: 512 * 1024,
concurrency: 2,
cfg: StreamConfig{
ChunkSize: 512 * 1024,
InboundQueueLimit: 512,
InboundBufferedBytesLimit: 128 * 1024 * 1024,
OutboundWindowBytes: 32 * 1024 * 1024,
OutboundMaxInFlightChunks: 64,
},
},
{
name: "streams_4_512KiB",
payloadSize: 512 * 1024,
concurrency: 4,
cfg: StreamConfig{
ChunkSize: 512 * 1024,
InboundQueueLimit: 1024,
InboundBufferedBytesLimit: 256 * 1024 * 1024,
OutboundWindowBytes: 64 * 1024 * 1024,
OutboundMaxInFlightChunks: 128,
},
},
{
name: "streams_8_512KiB",
payloadSize: 512 * 1024,
concurrency: 8,
cfg: StreamConfig{
ChunkSize: 512 * 1024,
InboundQueueLimit: 2048,
InboundBufferedBytesLimit: 512 * 1024 * 1024,
OutboundWindowBytes: 128 * 1024 * 1024,
OutboundMaxInFlightChunks: 256,
},
},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg, benchmarkTransportSecurityTrustedRaw)
})
}
}
func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfig, securityMode benchmarkTransportSecurityMode) {
b.Helper() b.Helper()
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
server.SetStreamConfig(cfg) server.SetStreamConfig(cfg)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyServerTransportSecurity(b, server, securityMode)
b.Fatalf("UseModernPSKServer failed: %v", err)
}
acceptCh := make(chan StreamAcceptInfo, 1) acceptCh := make(chan StreamAcceptInfo, 1)
server.SetStreamHandler(func(info StreamAcceptInfo) error { server.SetStreamHandler(func(info StreamAcceptInfo) error {
@@ -128,7 +230,7 @@ func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfi
return nil return nil
}) })
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { if err := server.Listen("tcp", benchmarkTCPListenAddr(b)); err != nil {
b.Fatalf("server Listen failed: %v", err) b.Fatalf("server Listen failed: %v", err)
} }
b.Cleanup(func() { b.Cleanup(func() {
@@ -137,10 +239,8 @@ func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfi
client := NewClient().(*ClientCommon) client := NewClient().(*ClientCommon)
client.SetStreamConfig(cfg) client.SetStreamConfig(cfg)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyClientTransportSecurity(b, client, securityMode)
b.Fatalf("UseModernPSKClient failed: %v", err) if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
b.Fatalf("client Connect failed: %v", err) b.Fatalf("client Connect failed: %v", err)
} }
b.Cleanup(func() { b.Cleanup(func() {
@@ -198,7 +298,7 @@ func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfi
_ = stream.Close() _ = stream.Close()
} }
func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, cfg StreamConfig) { func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, cfg StreamConfig, securityMode benchmarkTransportSecurityMode) {
b.Helper() b.Helper()
if concurrency <= 0 { if concurrency <= 0 {
b.Fatal("concurrency must be > 0") b.Fatal("concurrency must be > 0")
@@ -206,9 +306,7 @@ func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concu
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
server.SetStreamConfig(cfg) server.SetStreamConfig(cfg)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyServerTransportSecurity(b, server, securityMode)
b.Fatalf("UseModernPSKServer failed: %v", err)
}
acceptCh := make(chan StreamAcceptInfo, concurrency*2) acceptCh := make(chan StreamAcceptInfo, concurrency*2)
server.SetStreamHandler(func(info StreamAcceptInfo) error { server.SetStreamHandler(func(info StreamAcceptInfo) error {
@@ -216,7 +314,7 @@ func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concu
return nil return nil
}) })
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { if err := server.Listen("tcp", benchmarkTCPListenAddr(b)); err != nil {
b.Fatalf("server Listen failed: %v", err) b.Fatalf("server Listen failed: %v", err)
} }
b.Cleanup(func() { b.Cleanup(func() {
@@ -225,10 +323,8 @@ func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concu
client := NewClient().(*ClientCommon) client := NewClient().(*ClientCommon)
client.SetStreamConfig(cfg) client.SetStreamConfig(cfg)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyClientTransportSecurity(b, client, securityMode)
b.Fatalf("UseModernPSKClient failed: %v", err) if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
b.Fatalf("client Connect failed: %v", err) b.Fatalf("client Connect failed: %v", err)
} }
b.Cleanup(func() { b.Cleanup(func() {
+85
View File
@@ -0,0 +1,85 @@
package notify
import (
"context"
"errors"
"testing"
"time"
)
func TestStreamOwnedChunkReleaseAfterRead(t *testing.T) {
stream := newStreamHandle(context.Background(), newStreamRuntime("stream-buffer-release-read"), clientFileScope(), StreamOpenRequest{
StreamID: "stream-buffer-release-read",
DataID: 1,
}, 0, nil, nil, 0, nil, nil, nil, streamConfig{})
released := 0
if err := stream.pushOwnedChunkWithRelease([]byte("hello"), func() {
released++
}); err != nil {
t.Fatalf("pushOwnedChunkWithRelease failed: %v", err)
}
buf := make([]byte, 5)
n, err := stream.Read(buf)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if n != 5 || string(buf[:n]) != "hello" {
t.Fatalf("Read = %d %q, want 5 hello", n, string(buf[:n]))
}
if released != 1 {
t.Fatalf("release count = %d, want 1", released)
}
}
func TestStreamOwnedChunkReleaseOnReset(t *testing.T) {
stream := newStreamHandle(context.Background(), newStreamRuntime("stream-buffer-release-reset"), clientFileScope(), StreamOpenRequest{
StreamID: "stream-buffer-release-reset",
DataID: 1,
}, 0, nil, nil, 0, nil, nil, nil, streamConfig{})
released := 0
if err := stream.pushOwnedChunkWithRelease([]byte("hello"), func() {
released++
}); err != nil {
t.Fatalf("pushOwnedChunkWithRelease failed: %v", err)
}
stream.markReset(errors.New("boom"))
if released != 1 {
t.Fatalf("release count = %d, want 1", released)
}
}
func TestClientDispatchFastStreamDataWithOwnerReleasesAfterRead(t *testing.T) {
client := NewClient().(*ClientCommon)
runtime := client.getStreamRuntime()
if runtime == nil {
t.Fatal("client stream runtime should not be nil")
}
stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
StreamID: "stream-owner",
DataID: 23,
Channel: StreamDataChannel,
}, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot())
if err := runtime.register(clientFileScope(), stream); err != nil {
t.Fatalf("register stream failed: %v", err)
}
released := 0
owner := newStreamReadPayloadOwner(func() {
released++
})
client.dispatchFastStreamDataWithOwner(streamFastDataFrame{
DataID: 23,
Seq: 1,
Payload: []byte("fast-owner"),
}, owner)
owner.done()
readStreamExactly(t, stream, "fast-owner", 2*time.Second)
if released != 1 {
t.Fatalf("release count = %d, want 1", released)
}
}
+15 -6
View File
@@ -7,19 +7,22 @@ import (
) )
type StreamOpenRequest struct { type StreamOpenRequest struct {
StreamID string StreamID string
DataID uint64 DataID uint64
Channel StreamChannel FastPathVersion uint8
Metadata StreamMetadata Channel StreamChannel
ReadTimeout time.Duration Metadata StreamMetadata
WriteTimeout time.Duration ReadTimeout time.Duration
WriteTimeout time.Duration
} }
type StreamOpenResponse struct { type StreamOpenResponse struct {
StreamID string StreamID string
DataID uint64 DataID uint64
FastPathVersion uint8
Accepted bool Accepted bool
TransportGeneration uint64 TransportGeneration uint64
Metadata StreamMetadata
Error string Error string
} }
@@ -91,10 +94,13 @@ func (c *ClientCommon) handleInboundStreamOpen(msg *Message) {
return return
} }
scope := clientFileScope() scope := clientFileScope()
req.FastPathVersion = negotiateStreamFastPathVersion(req.FastPathVersion)
resp.FastPathVersion = req.FastPathVersion
if req.DataID == 0 { if req.DataID == 0 {
req.DataID = runtime.nextDataID() req.DataID = runtime.nextDataID()
resp.DataID = req.DataID resp.DataID = req.DataID
} }
req.Metadata, resp.Metadata = negotiateRecordStreamOpenMetadata(req.Channel, req.Metadata)
stream := newStreamHandle(c.clientStopContextSnapshot(), runtime, scope, req, c.currentClientSessionEpoch(), nil, nil, 0, clientStreamCloseSender(c), clientStreamResetSender(c), clientStreamDataSender(c, c.currentClientSessionEpoch()), runtime.configSnapshot()) stream := newStreamHandle(c.clientStopContextSnapshot(), runtime, scope, req, c.currentClientSessionEpoch(), nil, nil, 0, clientStreamCloseSender(c), clientStreamResetSender(c), clientStreamDataSender(c, c.currentClientSessionEpoch()), runtime.configSnapshot())
stream.setClientSnapshotOwner(c) stream.setClientSnapshotOwner(c)
stream.setAddrSnapshot(c.clientStreamAddrSnapshot()) stream.setAddrSnapshot(c.clientStreamAddrSnapshot())
@@ -176,10 +182,13 @@ func (s *ServerCommon) handleInboundStreamOpen(msg *Message) {
} }
transport := messageTransportConnSnapshot(msg) transport := messageTransportConnSnapshot(msg)
scope := serverFileScope(logical) scope := serverFileScope(logical)
req.FastPathVersion = negotiateStreamFastPathVersion(req.FastPathVersion)
resp.FastPathVersion = req.FastPathVersion
if req.DataID == 0 { if req.DataID == 0 {
req.DataID = runtime.nextDataID() req.DataID = runtime.nextDataID()
resp.DataID = req.DataID resp.DataID = req.DataID
} }
req.Metadata, resp.Metadata = negotiateRecordStreamOpenMetadata(req.Channel, req.Metadata)
stream := newStreamHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, streamTransportGeneration(logical, transport), serverStreamCloseSender(s, logical, transport), serverStreamResetSender(s, logical, transport), serverStreamDataSender(s, transport), runtime.configSnapshot()) stream := newStreamHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, streamTransportGeneration(logical, transport), serverStreamCloseSender(s, logical, transport), serverStreamResetSender(s, logical, transport), serverStreamDataSender(s, transport), runtime.configSnapshot())
if err := runtime.register(scope, stream); err != nil { if err := runtime.register(scope, stream); err != nil {
resp.Error = err.Error() resp.Error = err.Error()
+22 -2
View File
@@ -83,6 +83,10 @@ func (s *ServerCommon) dispatchStreamEnvelope(logical *LogicalConn, transport *T
} }
func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) { func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) {
c.dispatchFastStreamDataWithOwner(frame, nil)
}
func (c *ClientCommon) dispatchFastStreamDataWithOwner(frame streamFastDataFrame, owner *streamReadPayloadOwner) {
if frame.DataID == 0 { if frame.DataID == 0 {
return return
} }
@@ -107,7 +111,13 @@ func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) {
c.bestEffortRejectInboundStreamData(stream.ID(), frame.DataID, detachErr.Error()) c.bestEffortRejectInboundStreamData(stream.ID(), frame.DataID, detachErr.Error())
return return
} }
if err := stream.pushOwnedChunk(frame.Payload); err != nil { var err error
if owner != nil {
err = stream.pushOwnedChunkWithRelease(frame.Payload, owner.retainChunk())
} else {
err = stream.pushOwnedChunk(frame.Payload)
}
if err != nil {
if c.showError || c.debugMode { if c.showError || c.debugMode {
fmt.Println("client stream push chunk error", err) fmt.Println("client stream push chunk error", err)
} }
@@ -118,6 +128,10 @@ func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) {
} }
func (s *ServerCommon) dispatchFastStreamData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame streamFastDataFrame) { func (s *ServerCommon) dispatchFastStreamData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame streamFastDataFrame) {
s.dispatchFastStreamDataWithOwner(logical, transport, conn, frame, nil)
}
func (s *ServerCommon) dispatchFastStreamDataWithOwner(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame streamFastDataFrame, owner *streamReadPayloadOwner) {
if logical == nil || frame.DataID == 0 { if logical == nil || frame.DataID == 0 {
return return
} }
@@ -141,7 +155,13 @@ func (s *ServerCommon) dispatchFastStreamData(logical *LogicalConn, transport *T
s.bestEffortRejectInboundStreamData(logical, transport, conn, stream.ID(), frame.DataID, detachErr.Error()) s.bestEffortRejectInboundStreamData(logical, transport, conn, stream.ID(), frame.DataID, detachErr.Error())
return return
} }
if err := stream.pushOwnedChunk(frame.Payload); err != nil { var err error
if owner != nil {
err = stream.pushOwnedChunkWithRelease(frame.Payload, owner.retainChunk())
} else {
err = stream.pushOwnedChunk(frame.Payload)
}
if err != nil {
if s.showError || s.debugMode { if s.showError || s.debugMode {
fmt.Println("server stream push chunk error", err) fmt.Println("server stream push chunk error", err)
} }
+189 -12
View File
@@ -1,8 +1,10 @@
package notify package notify
import ( import (
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"io"
) )
var ( var (
@@ -15,6 +17,7 @@ const (
streamFastPayloadVersion = 1 streamFastPayloadVersion = 1
streamFastPayloadTypeData = 1 streamFastPayloadTypeData = 1
streamFastPayloadHeaderLen = 28 streamFastPayloadHeaderLen = 28
streamFastBatchDirectLimit = 512 * 1024
) )
type streamFastDataFrame struct { type streamFastDataFrame struct {
@@ -24,6 +27,56 @@ type streamFastDataFrame struct {
Payload []byte Payload []byte
} }
func streamAdaptiveFramePayloadLimit(binding *transportBinding) int {
if binding == nil {
return 0
}
limit := binding.streamAdaptiveSoftPayloadBytesSnapshot() - streamFastPayloadHeaderLen
if limit <= 0 {
return 1
}
maxPayload := streamFastBatchMaxPlainBytes - streamFastPayloadHeaderLen
if limit > maxPayload {
return maxPayload
}
return limit
}
func streamFastSplitFrameCount(size int, maxPayload int) int {
if size <= 0 || maxPayload <= 0 {
return 1
}
return (size + maxPayload - 1) / maxPayload
}
func buildStreamFastSplitFrames(dataID uint64, startSeq uint64, chunk []byte, maxPayload int) []streamFastDataFrame {
if len(chunk) == 0 {
return nil
}
if maxPayload <= 0 || len(chunk) <= maxPayload {
return []streamFastDataFrame{{
DataID: dataID,
Seq: startSeq,
Payload: chunk,
}}
}
frames := make([]streamFastDataFrame, 0, streamFastSplitFrameCount(len(chunk), maxPayload))
seq := startSeq
for offset := 0; offset < len(chunk); offset += maxPayload {
end := offset + maxPayload
if end > len(chunk) {
end = len(chunk)
}
frames = append(frames, streamFastDataFrame{
DataID: dataID,
Seq: seq,
Payload: chunk[offset:end],
})
seq++
}
return frames
}
func encodeStreamFastDataFrameHeader(dst []byte, dataID uint64, seq uint64, payloadLen int) error { func encodeStreamFastDataFrameHeader(dst []byte, dataID uint64, seq uint64, payloadLen int) error {
if dataID == 0 { if dataID == 0 {
return errStreamFastDataIDEmpty return errStreamFastDataIDEmpty
@@ -51,6 +104,31 @@ func encodeStreamFastDataFrame(dataID uint64, seq uint64, payload []byte) ([]byt
return frame, nil return frame, nil
} }
func encodeStreamFastFramePayload(frame streamFastDataFrame) ([]byte, error) {
framePayload := make([]byte, streamFastPayloadHeaderLen+len(frame.Payload))
if err := encodeStreamFastDataFrameHeader(framePayload, frame.DataID, frame.Seq, len(frame.Payload)); err != nil {
return nil, err
}
framePayload[6] = frame.Flags
copy(framePayload[streamFastPayloadHeaderLen:], frame.Payload)
return framePayload, nil
}
func encodeStreamFastFramePayloadFast(encode transportFastPlainEncoder, secretKey []byte, frame streamFastDataFrame) ([]byte, error) {
if encode == nil {
return nil, errTransportPayloadEncryptFailed
}
plainLen := streamFastPayloadHeaderLen + len(frame.Payload)
return encode(secretKey, plainLen, func(dst []byte) error {
if err := encodeStreamFastDataFrameHeader(dst, frame.DataID, frame.Seq, len(frame.Payload)); err != nil {
return err
}
dst[6] = frame.Flags
copy(dst[streamFastPayloadHeaderLen:], frame.Payload)
return nil
})
}
func decodeStreamFastDataFrame(payload []byte) (streamFastDataFrame, bool, error) { func decodeStreamFastDataFrame(payload []byte) (streamFastDataFrame, bool, error) {
if len(payload) < 4 || string(payload[:4]) != streamFastPayloadMagic { if len(payload) < 4 || string(payload[:4]) != streamFastPayloadMagic {
return streamFastDataFrame{}, false, nil return streamFastDataFrame{}, false, nil
@@ -77,18 +155,68 @@ func decodeStreamFastDataFrame(payload []byte) (streamFastDataFrame, bool, error
}, true, nil }, true, nil
} }
func (c *ClientCommon) encodeFastStreamDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) { func (c *ClientCommon) encodeFastStreamPayload(frame streamFastDataFrame) ([]byte, error) {
if c != nil && c.fastStreamEncode != nil { profile := c.clientTransportProtectionSnapshot()
return c.fastStreamEncode(c.SecretKey, dataID, seq, chunk) if c != nil && profile.fastStreamEncode != nil && frame.Flags == 0 {
return profile.fastStreamEncode(profile.secretKey, frame.DataID, frame.Seq, frame.Payload)
} }
plain, err := encodeStreamFastDataFrame(dataID, seq, chunk) if c != nil && profile.fastPlainEncode != nil {
return encodeStreamFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame)
}
plain, err := encodeStreamFastFramePayload(frame)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return c.encryptTransportPayload(plain) return c.encryptTransportPayload(plain)
} }
func (c *ClientCommon) sendFastStreamData(dataID uint64, seq uint64, chunk []byte) error { func (c *ClientCommon) encodeFastStreamDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
return c.encodeFastStreamPayload(streamFastDataFrame{
DataID: dataID,
Seq: seq,
Payload: chunk,
})
}
func (c *ClientCommon) encodeFastStreamBatchPayload(frames []streamFastDataFrame) ([]byte, error) {
if c == nil {
return nil, errStreamClientNil
}
profile := c.clientTransportProtectionSnapshot()
if profile.fastPlainEncode != nil {
return encodeStreamFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames)
}
plain, err := encodeStreamFastBatchPlain(frames)
if err != nil {
return nil, err
}
return c.encryptTransportPayload(plain)
}
func (c *ClientCommon) sendFastStreamData(ctx context.Context, stream *streamHandle, chunk []byte) error {
if stream == nil {
return io.ErrClosedPipe
}
dataID := stream.dataIDSnapshot()
fastPathVersion := stream.fastPathVersionSnapshot()
if binding := c.clientTransportBindingSnapshot(); binding != nil && streamFastPathSupportsBatch(fastPathVersion) {
if sender := binding.clientStreamBatchSenderSnapshot(c); sender != nil {
if maxPayload := streamAdaptiveFramePayloadLimit(binding); maxPayload > 0 && len(chunk) > maxPayload {
startSeq := stream.reserveOutboundDataSeqs(streamFastSplitFrameCount(len(chunk), maxPayload))
return sender.submitFrames(ctx, fastPathVersion, buildStreamFastSplitFrames(dataID, startSeq, chunk, maxPayload))
}
seq := stream.reserveOutboundDataSeqs(1)
if len(chunk) < streamFastBatchDirectLimit {
return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk)
}
payload, err := c.encodeFastStreamDataPayload(dataID, seq, chunk)
if err != nil {
return err
}
return sender.submitEncoded(ctx, fastPathVersion, payload)
}
}
seq := stream.reserveOutboundDataSeqs(1)
payload, err := c.encodeFastStreamDataPayload(dataID, seq, chunk) payload, err := c.encodeFastStreamDataPayload(dataID, seq, chunk)
if err != nil { if err != nil {
return err return err
@@ -96,29 +224,78 @@ func (c *ClientCommon) sendFastStreamData(dataID uint64, seq uint64, chunk []byt
return c.writePayloadToTransport(payload) return c.writePayloadToTransport(payload)
} }
func (s *ServerCommon) encodeFastStreamDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) { func (s *ServerCommon) encodeFastStreamPayloadLogical(logical *LogicalConn, frame streamFastDataFrame) ([]byte, error) {
if logical != nil { if logical == nil {
if fastStreamEncode := logical.fastStreamEncodeSnapshot(); fastStreamEncode != nil { return nil, errTransportDetached
return fastStreamEncode(logical.secretKeySnapshot(), dataID, seq, chunk)
}
} }
plain, err := encodeStreamFastDataFrame(dataID, seq, chunk) if fastStreamEncode := logical.fastStreamEncodeSnapshot(); fastStreamEncode != nil && frame.Flags == 0 {
return fastStreamEncode(logical.secretKeySnapshot(), frame.DataID, frame.Seq, frame.Payload)
}
if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
return encodeStreamFastFramePayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frame)
}
plain, err := encodeStreamFastFramePayload(frame)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return s.encryptTransportPayloadLogical(logical, plain) return s.encryptTransportPayloadLogical(logical, plain)
} }
func (s *ServerCommon) sendFastStreamDataTransport(logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte) error { func (s *ServerCommon) encodeFastStreamDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
return s.encodeFastStreamPayloadLogical(logical, streamFastDataFrame{
DataID: dataID,
Seq: seq,
Payload: chunk,
})
}
func (s *ServerCommon) encodeFastStreamBatchPayloadLogical(logical *LogicalConn, frames []streamFastDataFrame) ([]byte, error) {
if logical == nil {
return nil, errTransportDetached
}
if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
return encodeStreamFastBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frames)
}
plain, err := encodeStreamFastBatchPlain(frames)
if err != nil {
return nil, err
}
return s.encryptTransportPayloadLogical(logical, plain)
}
func (s *ServerCommon) sendFastStreamDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, stream *streamHandle, chunk []byte) error {
if err := s.ensureServerTransportSendReady(transport); err != nil { if err := s.ensureServerTransportSendReady(transport); err != nil {
return err return err
} }
if stream == nil {
return io.ErrClosedPipe
}
if logical == nil && transport != nil { if logical == nil && transport != nil {
logical = transport.logicalConnSnapshot() logical = transport.logicalConnSnapshot()
} }
if logical == nil { if logical == nil {
return errTransportDetached return errTransportDetached
} }
dataID := stream.dataIDSnapshot()
fastPathVersion := stream.fastPathVersionSnapshot()
if binding := logical.transportBindingSnapshot(); binding != nil && binding.queueSnapshot() != nil && streamFastPathSupportsBatch(fastPathVersion) {
if sender := binding.serverStreamBatchSenderSnapshot(logical); sender != nil {
if maxPayload := streamAdaptiveFramePayloadLimit(binding); maxPayload > 0 && len(chunk) > maxPayload {
startSeq := stream.reserveOutboundDataSeqs(streamFastSplitFrameCount(len(chunk), maxPayload))
return sender.submitFrames(ctx, fastPathVersion, buildStreamFastSplitFrames(dataID, startSeq, chunk, maxPayload))
}
seq := stream.reserveOutboundDataSeqs(1)
if len(chunk) < streamFastBatchDirectLimit {
return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk)
}
payload, err := s.encodeFastStreamDataPayloadLogical(logical, dataID, seq, chunk)
if err != nil {
return err
}
return sender.submitEncoded(ctx, fastPathVersion, payload)
}
}
seq := stream.reserveOutboundDataSeqs(1)
payload, err := s.encodeFastStreamDataPayloadLogical(logical, dataID, seq, chunk) payload, err := s.encodeFastStreamDataPayloadLogical(logical, dataID, seq, chunk)
if err != nil { if err != nil {
return err return err
+414
View File
@@ -4,6 +4,8 @@ import (
"b612.me/stario" "b612.me/stario"
"context" "context"
"math" "math"
"sync"
"sync/atomic"
"testing" "testing"
"time" "time"
) )
@@ -31,6 +33,46 @@ func TestStreamFastDataFrameRoundTrip(t *testing.T) {
} }
} }
func TestStreamFastBatchPlainRoundTrip(t *testing.T) {
frames := []streamFastDataFrame{
{
DataID: 11,
Seq: 7,
Payload: []byte("alpha"),
},
{
DataID: 12,
Seq: 8,
Payload: []byte("beta"),
},
}
wire, err := encodeStreamFastBatchPlain(frames)
if err != nil {
t.Fatalf("encodeStreamFastBatchPlain failed: %v", err)
}
decoded, matched, err := decodeStreamFastBatchPlain(wire)
if err != nil {
t.Fatalf("decodeStreamFastBatchPlain failed: %v", err)
}
if !matched {
t.Fatal("decodeStreamFastBatchPlain should match encoded batch")
}
if got, want := len(decoded), len(frames); got != want {
t.Fatalf("decoded frame count = %d, want %d", got, want)
}
for index := range frames {
if got, want := decoded[index].DataID, frames[index].DataID; got != want {
t.Fatalf("frame %d data id = %d, want %d", index, got, want)
}
if got, want := decoded[index].Seq, frames[index].Seq; got != want {
t.Fatalf("frame %d seq = %d, want %d", index, got, want)
}
if got, want := string(decoded[index].Payload), string(frames[index].Payload); got != want {
t.Fatalf("frame %d payload = %q, want %q", index, got, want)
}
}
}
func TestClientDispatchInboundTransportPayloadFastStream(t *testing.T) { func TestClientDispatchInboundTransportPayloadFastStream(t *testing.T) {
client := NewClient().(*ClientCommon) client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
@@ -60,6 +102,165 @@ func TestClientDispatchInboundTransportPayloadFastStream(t *testing.T) {
readStreamExactly(t, stream, "fast-payload", 2*time.Second) readStreamExactly(t, stream, "fast-payload", 2*time.Second)
} }
func TestClientDispatchInboundTransportPayloadFastStreamBatch(t *testing.T) {
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
runtime := client.getStreamRuntime()
if runtime == nil {
t.Fatal("client stream runtime should not be nil")
}
streamA := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
StreamID: "fast-client-a",
DataID: 23,
FastPathVersion: streamFastPathVersionCurrent,
Channel: StreamDataChannel,
}, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot())
streamB := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
StreamID: "fast-client-b",
DataID: 24,
FastPathVersion: streamFastPathVersionCurrent,
Channel: StreamDataChannel,
}, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot())
if err := runtime.register(clientFileScope(), streamA); err != nil {
t.Fatalf("register streamA failed: %v", err)
}
if err := runtime.register(clientFileScope(), streamB); err != nil {
t.Fatalf("register streamB failed: %v", err)
}
payload, err := client.encodeFastStreamBatchPayload([]streamFastDataFrame{
{DataID: 23, Seq: 1, Payload: []byte("fast-a")},
{DataID: 24, Seq: 2, Payload: []byte("fast-b")},
})
if err != nil {
t.Fatalf("encodeFastStreamBatchPayload failed: %v", err)
}
if err := client.dispatchInboundTransportPayload(payload, time.Now()); err != nil {
t.Fatalf("dispatchInboundTransportPayload failed: %v", err)
}
readStreamExactly(t, streamA, "fast-a", 2*time.Second)
readStreamExactly(t, streamB, "fast-b", 2*time.Second)
}
func TestStreamBatchSenderEncodeRequestsCoalescesFastV2Frames(t *testing.T) {
var (
singleCalls int
batchCalls [][]streamFastDataFrame
)
sender := &streamBatchSender{
codec: streamBatchCodec{
encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
singleCalls++
return []byte("single"), nil
},
encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
cloned := make([]streamFastDataFrame, len(frames))
copy(cloned, frames)
batchCalls = append(batchCalls, cloned)
return []byte("batch"), nil
},
},
}
payloads, err := sender.encodeRequests([]streamBatchRequest{
{
frames: []streamFastDataFrame{{
DataID: 101,
Seq: 1,
Payload: []byte("a"),
}},
fastPathVersion: streamFastPathVersionV2,
},
{
frames: []streamFastDataFrame{{
DataID: 102,
Seq: 2,
Payload: []byte("b"),
}},
fastPathVersion: streamFastPathVersionV2,
},
})
if err != nil {
t.Fatalf("encodeRequests failed: %v", err)
}
if got, want := len(payloads), 1; got != want {
t.Fatalf("payload count = %d, want %d", got, want)
}
if singleCalls != 0 {
t.Fatalf("single encode calls = %d, want 0", singleCalls)
}
if got, want := len(batchCalls), 1; got != want {
t.Fatalf("batch encode calls = %d, want %d", got, want)
}
if got, want := len(batchCalls[0]), 2; got != want {
t.Fatalf("batched frame count = %d, want %d", got, want)
}
}
func TestStreamBatchSenderEncodeRequestsFlushesBeforePreEncodedPayload(t *testing.T) {
var (
singleCalls int
batchCalls [][]streamFastDataFrame
)
sender := &streamBatchSender{
codec: streamBatchCodec{
encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
singleCalls++
return append([]byte("single-"), frame.Payload...), nil
},
encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
cloned := make([]streamFastDataFrame, len(frames))
copy(cloned, frames)
batchCalls = append(batchCalls, cloned)
return []byte("batch-2"), nil
},
},
}
payloads, err := sender.encodeRequests([]streamBatchRequest{
{
frames: []streamFastDataFrame{
{DataID: 101, Seq: 1, Payload: []byte("a")},
{DataID: 102, Seq: 2, Payload: []byte("b")},
},
fastPathVersion: streamFastPathVersionV2,
},
{
encodedPayload: []byte("raw"),
hasEncoded: true,
fastPathVersion: streamFastPathVersionV2,
},
{
frames: []streamFastDataFrame{
{DataID: 103, Seq: 3, Payload: []byte("c")},
},
fastPathVersion: streamFastPathVersionV2,
},
})
if err != nil {
t.Fatalf("encodeRequests failed: %v", err)
}
if got, want := len(payloads), 3; got != want {
t.Fatalf("payload count = %d, want %d", got, want)
}
if got, want := string(payloads[0]), "batch-2"; got != want {
t.Fatalf("first payload = %q, want %q", got, want)
}
if got, want := string(payloads[1]), "raw"; got != want {
t.Fatalf("second payload = %q, want %q", got, want)
}
if got, want := string(payloads[2]), "single-c"; got != want {
t.Fatalf("third payload = %q, want %q", got, want)
}
if got, want := len(batchCalls), 1; got != want {
t.Fatalf("batch encode calls = %d, want %d", got, want)
}
if singleCalls != 1 {
t.Fatalf("single encode calls = %d, want %d", singleCalls, 1)
}
}
func TestClientPushMessageFastDispatchesDirectWithRuntimeDispatcher(t *testing.T) { func TestClientPushMessageFastDispatchesDirectWithRuntimeDispatcher(t *testing.T) {
client := NewClient().(*ClientCommon) client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
@@ -110,3 +311,216 @@ func TestClientPushMessageFastDispatchesDirectWithRuntimeDispatcher(t *testing.T
default: default:
} }
} }
func TestStreamBatchSenderEncodeRequestsResetsBatchBytesAfterFlushBoundary(t *testing.T) {
var (
singleCalls int
batchCalls [][]streamFastDataFrame
)
sender := &streamBatchSender{
codec: streamBatchCodec{
encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
singleCalls++
return []byte("single"), nil
},
encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
cloned := make([]streamFastDataFrame, len(frames))
copy(cloned, frames)
batchCalls = append(batchCalls, cloned)
return []byte("batch"), nil
},
},
}
largePayload := make([]byte, streamFastBatchMaxPlainBytes-streamFastBatchHeaderLen-streamFastBatchItemHeaderLen-128)
payloads, err := sender.encodeRequests([]streamBatchRequest{
{
frames: []streamFastDataFrame{{
DataID: 101,
Seq: 1,
Payload: largePayload,
}},
fastPathVersion: streamFastPathVersionV2,
},
{
encodedPayload: []byte("raw"),
hasEncoded: true,
fastPathVersion: streamFastPathVersionV2,
},
{
frames: []streamFastDataFrame{{
DataID: 202,
Seq: 1,
Payload: []byte("a"),
}},
fastPathVersion: streamFastPathVersionV2,
},
{
frames: []streamFastDataFrame{{
DataID: 202,
Seq: 2,
Payload: []byte("b"),
}},
fastPathVersion: streamFastPathVersionV2,
},
})
if err != nil {
t.Fatalf("encodeRequests failed: %v", err)
}
if got, want := len(payloads), 3; got != want {
t.Fatalf("payload count = %d, want %d", got, want)
}
if got, want := singleCalls, 1; got != want {
t.Fatalf("single encode calls = %d, want %d", got, want)
}
if got, want := len(batchCalls), 1; got != want {
t.Fatalf("batch encode calls = %d, want %d", got, want)
}
if got, want := len(batchCalls[0]), 2; got != want {
t.Fatalf("post-flush batched frame count = %d, want %d", got, want)
}
}
func TestStreamBatchSenderEncodeRequestsUsesAdaptiveSoftLimit(t *testing.T) {
binding := &transportBinding{}
binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil)
var (
singleCalls int
batchCalls int
)
sender := &streamBatchSender{
binding: binding,
codec: streamBatchCodec{
encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
singleCalls++
return []byte("single"), nil
},
encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
batchCalls++
return []byte("batch"), nil
},
},
}
payload := make([]byte, 160*1024)
payloads, err := sender.encodeRequests([]streamBatchRequest{
{
frames: []streamFastDataFrame{{
DataID: 101,
Seq: 1,
Payload: payload,
}},
fastPathVersion: streamFastPathVersionV2,
},
{
frames: []streamFastDataFrame{{
DataID: 102,
Seq: 2,
Payload: payload,
}},
fastPathVersion: streamFastPathVersionV2,
},
})
if err != nil {
t.Fatalf("encodeRequests failed: %v", err)
}
if got, want := len(payloads), 2; got != want {
t.Fatalf("payload count = %d, want %d", got, want)
}
if got, want := singleCalls, 2; got != want {
t.Fatalf("single encode calls = %d, want %d", got, want)
}
if batchCalls != 0 {
t.Fatalf("batch encode calls = %d, want 0", batchCalls)
}
}
func TestClientSendFastStreamDataSplitsLargeChunkWhenAdaptiveSoftLimitShrinks(t *testing.T) {
binding := newTransportBinding(&delayedWriteConn{}, stario.NewQueue())
binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil)
var (
mu sync.Mutex
singleFrames []streamFastDataFrame
batchFrames [][]streamFastDataFrame
)
binding.streamSender = newStreamBatchSender(binding, streamBatchCodec{
encodeSingle: func(frame streamFastDataFrame) ([]byte, error) {
mu.Lock()
singleFrames = append(singleFrames, streamFastDataFrame{
DataID: frame.DataID,
Seq: frame.Seq,
Payload: append([]byte(nil), frame.Payload...),
})
mu.Unlock()
return []byte{1}, nil
},
encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) {
cloned := make([]streamFastDataFrame, len(frames))
for index := range frames {
cloned[index] = streamFastDataFrame{
DataID: frames[index].DataID,
Seq: frames[index].Seq,
Payload: append([]byte(nil), frames[index].Payload...),
}
}
mu.Lock()
batchFrames = append(batchFrames, cloned)
mu.Unlock()
return []byte{2}, nil
},
}, nil)
defer binding.stopBackgroundWorkers()
client := NewClient().(*ClientCommon)
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
client.setClientSessionRuntime(&clientSessionRuntime{
transport: binding,
transportAttached: true,
stopCtx: stopCtx,
stopFn: stopFn,
queue: binding.queueSnapshot(),
inboundDispatcher: newInboundDispatcher(),
suppressGoodByeOnStop: &atomic.Bool{},
})
stream := &streamHandle{
dataID: 41,
fastPathVersion: streamFastPathVersionV2,
}
chunk := make([]byte, streamFastBatchDirectLimit+128*1024)
for index := range chunk {
chunk[index] = byte(index)
}
if err := client.sendFastStreamData(context.Background(), stream, chunk); err != nil {
t.Fatalf("sendFastStreamData failed: %v", err)
}
expectedFrames := streamFastSplitFrameCount(len(chunk), streamAdaptiveFramePayloadLimit(binding))
if got, want := int(stream.outboundSeq.Load()), expectedFrames; got != want {
t.Fatalf("reserved seq count = %d, want %d", got, want)
}
mu.Lock()
defer mu.Unlock()
if len(batchFrames) != 0 {
t.Fatalf("batch encode calls = %d, want 0", len(batchFrames))
}
if got, want := len(singleFrames), expectedFrames; got != want {
t.Fatalf("single frame count = %d, want %d", got, want)
}
rebuilt := make([]byte, 0, len(chunk))
for index, frame := range singleFrames {
if got, want := frame.DataID, uint64(41); got != want {
t.Fatalf("frame %d data id = %d, want %d", index, got, want)
}
if got, want := frame.Seq, uint64(index+1); got != want {
t.Fatalf("frame %d seq = %d, want %d", index, got, want)
}
rebuilt = append(rebuilt, frame.Payload...)
}
if string(rebuilt) != string(chunk) {
t.Fatal("rebuilt payload does not match original chunk")
}
}
+165 -55
View File
@@ -3,15 +3,19 @@ package notify
import ( import (
"context" "context"
"sync" "sync"
"sync/atomic"
) )
type streamFlowController struct { type streamFlowController struct {
mu sync.Mutex mu sync.Mutex
queue []*streamFlowRequest
inFlightBytes int queue []*streamFlowRequest
inFlightChunks int
windowBytes int inFlightBytes atomic.Int64
maxChunks int inFlightChunks atomic.Int64
windowBytes atomic.Int64
maxChunks atomic.Int64
waiters atomic.Int32
} }
type streamFlowRequest struct { type streamFlowRequest struct {
@@ -22,10 +26,10 @@ type streamFlowRequest struct {
func newStreamFlowController(cfg streamConfig) *streamFlowController { func newStreamFlowController(cfg streamConfig) *streamFlowController {
cfg = normalizeStreamConfig(cfg) cfg = normalizeStreamConfig(cfg)
return &streamFlowController{ controller := &streamFlowController{}
windowBytes: cfg.OutboundWindowBytes, controller.windowBytes.Store(int64(cfg.OutboundWindowBytes))
maxChunks: cfg.OutboundMaxInFlightChunks, controller.maxChunks.Store(int64(cfg.OutboundMaxInFlightChunks))
} return controller
} }
func (c *streamFlowController) applyConfig(cfg streamConfig) { func (c *streamFlowController) applyConfig(cfg streamConfig) {
@@ -33,9 +37,12 @@ func (c *streamFlowController) applyConfig(cfg streamConfig) {
return return
} }
cfg = normalizeStreamConfig(cfg) cfg = normalizeStreamConfig(cfg)
c.windowBytes.Store(int64(cfg.OutboundWindowBytes))
c.maxChunks.Store(int64(cfg.OutboundMaxInFlightChunks))
if c.waiters.Load() == 0 {
return
}
c.mu.Lock() c.mu.Lock()
c.windowBytes = cfg.OutboundWindowBytes
c.maxChunks = cfg.OutboundMaxInFlightChunks
c.drainLocked() c.drainLocked()
c.mu.Unlock() c.mu.Unlock()
} }
@@ -47,58 +54,32 @@ func (c *streamFlowController) acquire(ctx context.Context, size int) (func(), e
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
if c.tryAcquire(size) {
return c.releaseFunc(size), nil
}
req := &streamFlowRequest{ req := &streamFlowRequest{
size: size, size: size,
ready: make(chan struct{}), ready: make(chan struct{}),
} }
c.mu.Lock() c.mu.Lock()
if c.tryAcquireLocked(size) {
c.mu.Unlock()
return c.releaseFunc(size), nil
}
c.queue = append(c.queue, req) c.queue = append(c.queue, req)
c.waiters.Add(1)
c.drainLocked() c.drainLocked()
c.mu.Unlock() c.mu.Unlock()
select { select {
case <-req.ready: case <-req.ready:
released := false return c.releaseFunc(size), nil
return func() {
c.mu.Lock()
if released {
c.mu.Unlock()
return
}
released = true
c.inFlightBytes -= size
if c.inFlightBytes < 0 {
c.inFlightBytes = 0
}
if c.inFlightChunks > 0 {
c.inFlightChunks--
}
c.drainLocked()
c.mu.Unlock()
}, nil
case <-ctx.Done(): case <-ctx.Done():
c.mu.Lock() c.mu.Lock()
if req.admitted { if req.admitted {
c.mu.Unlock() c.mu.Unlock()
released := false return c.releaseFunc(size), nil
return func() {
c.mu.Lock()
if released {
c.mu.Unlock()
return
}
released = true
c.inFlightBytes -= size
if c.inFlightBytes < 0 {
c.inFlightBytes = 0
}
if c.inFlightChunks > 0 {
c.inFlightChunks--
}
c.drainLocked()
c.mu.Unlock()
}, nil
} }
c.removeLocked(req) c.removeLocked(req)
c.drainLocked() c.drainLocked()
@@ -107,6 +88,128 @@ func (c *streamFlowController) acquire(ctx context.Context, size int) (func(), e
} }
} }
func (c *streamFlowController) tryAcquire(size int) bool {
if c == nil || size <= 0 {
return true
}
if c.waiters.Load() != 0 {
return false
}
return c.tryAcquireCAS(size)
}
func (c *streamFlowController) tryAcquireLocked(size int) bool {
if c == nil || size <= 0 {
return true
}
if len(c.queue) != 0 {
return false
}
return c.tryAcquireCAS(size)
}
func (c *streamFlowController) tryAcquireCAS(size int) bool {
if c == nil || size <= 0 {
return true
}
size64 := int64(size)
for {
window := c.windowBytes.Load()
maxChunks := c.maxChunks.Load()
inFlightBytes := c.inFlightBytes.Load()
inFlightChunks := c.inFlightChunks.Load()
if maxChunks > 0 && inFlightChunks >= maxChunks {
return false
}
if window > 0 && inFlightBytes+size64 > window {
if !(inFlightBytes == 0 && inFlightChunks == 0) {
return false
}
}
if !c.inFlightBytes.CompareAndSwap(inFlightBytes, inFlightBytes+size64) {
continue
}
if c.addChunksCAS(1, maxChunks) {
return true
}
c.subBytesCAS(size64)
return false
}
}
func (c *streamFlowController) addChunksCAS(delta int64, maxChunks int64) bool {
if c == nil || delta <= 0 {
return true
}
for {
current := c.inFlightChunks.Load()
if maxChunks > 0 && current+delta > maxChunks {
return false
}
if c.inFlightChunks.CompareAndSwap(current, current+delta) {
return true
}
}
}
func (c *streamFlowController) subBytesCAS(delta int64) {
if c == nil || delta <= 0 {
return
}
for {
current := c.inFlightBytes.Load()
next := current - delta
if next < 0 {
next = 0
}
if c.inFlightBytes.CompareAndSwap(current, next) {
return
}
}
}
func (c *streamFlowController) subChunksCAS(delta int64) {
if c == nil || delta <= 0 {
return
}
for {
current := c.inFlightChunks.Load()
next := current - delta
if next < 0 {
next = 0
}
if c.inFlightChunks.CompareAndSwap(current, next) {
return
}
}
}
func (c *streamFlowController) releaseFunc(size int) func() {
released := false
return func() {
if released {
return
}
released = true
c.release(size)
}
}
func (c *streamFlowController) release(size int) {
if c == nil || size <= 0 {
return
}
c.subBytesCAS(int64(size))
c.subChunksCAS(1)
if c.waiters.Load() == 0 {
return
}
c.mu.Lock()
c.drainLocked()
c.mu.Unlock()
}
func (c *streamFlowController) removeLocked(req *streamFlowRequest) { func (c *streamFlowController) removeLocked(req *streamFlowRequest) {
if c == nil || req == nil { if c == nil || req == nil {
return return
@@ -118,6 +221,7 @@ func (c *streamFlowController) removeLocked(req *streamFlowRequest) {
copy(c.queue[i:], c.queue[i+1:]) copy(c.queue[i:], c.queue[i+1:])
c.queue[len(c.queue)-1] = nil c.queue[len(c.queue)-1] = nil
c.queue = c.queue[:len(c.queue)-1] c.queue = c.queue[:len(c.queue)-1]
c.waiters.Add(-1)
return return
} }
} }
@@ -132,18 +236,17 @@ func (c *streamFlowController) drainLocked() {
c.queue = c.queue[1:] c.queue = c.queue[1:]
continue continue
} }
if c.maxChunks > 0 && c.inFlightChunks >= c.maxChunks { if !c.canAdmitLocked(req.size) {
return return
} }
if !c.canAdmitLocked(req.size) { if !c.tryAcquireCAS(req.size) {
return return
} }
copy(c.queue[0:], c.queue[1:]) copy(c.queue[0:], c.queue[1:])
c.queue[len(c.queue)-1] = nil c.queue[len(c.queue)-1] = nil
c.queue = c.queue[:len(c.queue)-1] c.queue = c.queue[:len(c.queue)-1]
c.waiters.Add(-1)
req.admitted = true req.admitted = true
c.inFlightBytes += req.size
c.inFlightChunks++
close(req.ready) close(req.ready)
} }
} }
@@ -155,11 +258,18 @@ func (c *streamFlowController) canAdmitLocked(size int) bool {
if size <= 0 { if size <= 0 {
return true return true
} }
if c.windowBytes <= 0 { window := c.windowBytes.Load()
chunks := c.inFlightChunks.Load()
bytes := c.inFlightBytes.Load()
maxChunks := c.maxChunks.Load()
if maxChunks > 0 && chunks >= maxChunks {
return false
}
if window <= 0 {
return true return true
} }
if c.inFlightBytes+size <= c.windowBytes { if bytes+int64(size) <= window {
return true return true
} }
return c.inFlightBytes == 0 && c.inFlightChunks == 0 return bytes == 0 && chunks == 0
} }
+48
View File
@@ -105,3 +105,51 @@ func TestStreamFlowControllerAdmitsRequestsFIFO(t *testing.T) {
t.Fatalf("second admitted request = %d, want 3", second) t.Fatalf("second admitted request = %d, want 3", second)
} }
} }
func TestStreamFlowControllerTryAcquireDoesNotBypassQueuedWaiter(t *testing.T) {
controller := newStreamFlowController(streamConfig{
ChunkSize: 4,
InboundQueueLimit: 1,
InboundBufferedBytesLimit: 4,
OutboundWindowBytes: 4,
OutboundMaxInFlightChunks: 1,
})
releaseFirst, err := controller.acquire(context.Background(), 4)
if err != nil {
t.Fatalf("first acquire failed: %v", err)
}
defer releaseFirst()
waiterReady := make(chan struct{})
waiterAcquired := make(chan func(), 1)
go func() {
close(waiterReady)
releaseSecond, err := controller.acquire(context.Background(), 4)
if err != nil {
t.Errorf("waiter acquire failed: %v", err)
return
}
waiterAcquired <- releaseSecond
}()
<-waiterReady
time.Sleep(20 * time.Millisecond)
if controller.tryAcquire(4) {
t.Fatal("tryAcquire should not bypass queued waiter")
}
releaseFirst()
var releaseSecond func()
select {
case releaseSecond = <-waiterAcquired:
case <-time.After(time.Second):
t.Fatal("timed out waiting for queued waiter to acquire")
}
if releaseSecond == nil {
t.Fatal("queued waiter returned nil release func")
}
releaseSecond()
}
+37 -13
View File
@@ -3,7 +3,6 @@ package notify
import ( import (
"context" "context"
"fmt" "fmt"
"strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -17,7 +16,7 @@ type streamRuntime struct {
mu sync.RWMutex mu sync.RWMutex
handler func(StreamAcceptInfo) error handler func(StreamAcceptInfo) error
streams map[string]*streamHandle streams map[string]*streamHandle
data map[string]*streamHandle data map[string]map[uint64]*streamHandle
cfg streamConfig cfg streamConfig
flow *streamFlowController flow *streamFlowController
} }
@@ -27,7 +26,7 @@ func newStreamRuntime(rolePrefix string) *streamRuntime {
return &streamRuntime{ return &streamRuntime{
rolePrefix: rolePrefix, rolePrefix: rolePrefix,
streams: make(map[string]*streamHandle), streams: make(map[string]*streamHandle),
data: make(map[string]*streamHandle), data: make(map[string]map[uint64]*streamHandle),
cfg: cfg, cfg: cfg,
flow: newStreamFlowController(cfg), flow: newStreamFlowController(cfg),
} }
@@ -72,18 +71,23 @@ func (r *streamRuntime) register(scope string, stream *streamHandle) error {
if stream == nil || stream.id == "" { if stream == nil || stream.id == "" {
return errStreamIDEmpty return errStreamIDEmpty
} }
scope = normalizeFileScope(scope)
key := streamRuntimeKey(scope, stream.id) key := streamRuntimeKey(scope, stream.id)
dataKey := streamRuntimeDataKey(scope, stream.dataID)
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if _, ok := r.streams[key]; ok { if _, ok := r.streams[key]; ok {
return errStreamAlreadyExists return errStreamAlreadyExists
} }
if stream.dataID != 0 { if stream.dataID != 0 {
if _, ok := r.data[dataKey]; ok { dataScope := r.data[scope]
if dataScope == nil {
dataScope = make(map[uint64]*streamHandle)
r.data[scope] = dataScope
}
if _, ok := dataScope[stream.dataID]; ok {
return errStreamAlreadyExists return errStreamAlreadyExists
} }
r.data[dataKey] = stream dataScope[stream.dataID] = stream
} }
r.streams[key] = stream r.streams[key] = stream
return nil return nil
@@ -104,10 +108,14 @@ func (r *streamRuntime) lookupByDataID(scope string, dataID uint64) (*streamHand
if r == nil || dataID == 0 { if r == nil || dataID == 0 {
return nil, false return nil, false
} }
key := streamRuntimeDataKey(scope, dataID) scope = normalizeFileScope(scope)
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
stream, ok := r.data[key] dataScope := r.data[scope]
if dataScope == nil {
return nil, false
}
stream, ok := dataScope[dataID]
return stream, ok return stream, ok
} }
@@ -115,11 +123,17 @@ func (r *streamRuntime) remove(scope string, streamID string) {
if r == nil || streamID == "" { if r == nil || streamID == "" {
return return
} }
scope = normalizeFileScope(scope)
key := streamRuntimeKey(scope, streamID) key := streamRuntimeKey(scope, streamID)
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if stream := r.streams[key]; stream != nil && stream.dataID != 0 { if stream := r.streams[key]; stream != nil && stream.dataID != 0 {
delete(r.data, streamRuntimeDataKey(scope, stream.dataID)) if dataScope := r.data[scope]; dataScope != nil {
delete(dataScope, stream.dataID)
if len(dataScope) == 0 {
delete(r.data, scope)
}
}
} }
delete(r.streams, key) delete(r.streams, key)
} }
@@ -131,6 +145,20 @@ func (r *streamRuntime) acquireOutbound(ctx context.Context, size int) (func(),
return r.flow.acquire(ctx, size) return r.flow.acquire(ctx, size)
} }
func (r *streamRuntime) tryAcquireOutbound(size int) bool {
if r == nil || r.flow == nil {
return true
}
return r.flow.tryAcquire(size)
}
func (r *streamRuntime) releaseOutbound(size int) {
if r == nil || r.flow == nil {
return
}
r.flow.release(size)
}
func (r *streamRuntime) snapshots() []StreamSnapshot { func (r *streamRuntime) snapshots() []StreamSnapshot {
if r == nil { if r == nil {
return nil return nil
@@ -182,10 +210,6 @@ func streamRuntimeKey(scope string, streamID string) string {
return normalizeFileScope(scope) + "\x00" + streamID return normalizeFileScope(scope) + "\x00" + streamID
} }
func streamRuntimeDataKey(scope string, dataID uint64) string {
return normalizeFileScope(scope) + "\x01" + strconv.FormatUint(dataID, 10)
}
func (c *ClientCommon) getStreamRuntime() *streamRuntime { func (c *ClientCommon) getStreamRuntime() *streamRuntime {
if c == nil { if c == nil {
return nil return nil
+177
View File
@@ -0,0 +1,177 @@
package notify
import "encoding/binary"
const (
streamFastPathVersionV1 = 1
streamFastPathVersionV2 = 2
streamFastPathVersionCurrent = streamFastPathVersionV2
)
const (
streamFastBatchMagic = "NSB1"
streamFastBatchVersion = 1
streamFastBatchHeaderLen = 12
streamFastBatchItemHeaderLen = 24
streamFastBatchMaxItems = 64
streamFastBatchMaxPlainBytes = 8 * 1024 * 1024
)
func normalizeStreamFastPathVersion(version uint8) uint8 {
if version < streamFastPathVersionV1 {
return streamFastPathVersionV1
}
if version > streamFastPathVersionCurrent {
return streamFastPathVersionCurrent
}
return version
}
func negotiateStreamFastPathVersion(version uint8) uint8 {
return normalizeStreamFastPathVersion(version)
}
func streamFastPathSupportsBatch(version uint8) bool {
return normalizeStreamFastPathVersion(version) >= streamFastPathVersionV2
}
func streamFastBatchFrameLen(frame streamFastDataFrame) int {
return streamFastBatchItemHeaderLen + len(frame.Payload)
}
func streamFastBatchPlainLen(frames []streamFastDataFrame) int {
total := streamFastBatchHeaderLen
for _, frame := range frames {
total += streamFastBatchFrameLen(frame)
}
return total
}
func encodeStreamFastBatchPlain(frames []streamFastDataFrame) ([]byte, error) {
if len(frames) == 0 {
return nil, errStreamFastPayloadInvalid
}
buf := make([]byte, streamFastBatchPlainLen(frames))
if err := writeStreamFastBatchPlain(buf, frames); err != nil {
return nil, err
}
return buf, nil
}
func encodeStreamFastBatchPayloadFast(encode transportFastPlainEncoder, secretKey []byte, frames []streamFastDataFrame) ([]byte, error) {
if encode == nil {
return nil, errTransportPayloadEncryptFailed
}
plainLen := streamFastBatchPlainLen(frames)
return encode(secretKey, plainLen, func(dst []byte) error {
return writeStreamFastBatchPlain(dst, frames)
})
}
func writeStreamFastBatchPlain(dst []byte, frames []streamFastDataFrame) error {
if len(frames) == 0 || len(dst) != streamFastBatchPlainLen(frames) {
return errStreamFastPayloadInvalid
}
copy(dst[:4], streamFastBatchMagic)
dst[4] = streamFastBatchVersion
binary.BigEndian.PutUint32(dst[8:12], uint32(len(frames)))
offset := streamFastBatchHeaderLen
for _, frame := range frames {
if frame.DataID == 0 {
return errStreamFastPayloadInvalid
}
dst[offset] = frame.Flags
binary.BigEndian.PutUint64(dst[offset+4:offset+12], frame.DataID)
binary.BigEndian.PutUint64(dst[offset+12:offset+20], frame.Seq)
binary.BigEndian.PutUint32(dst[offset+20:offset+24], uint32(len(frame.Payload)))
offset += streamFastBatchItemHeaderLen
copy(dst[offset:offset+len(frame.Payload)], frame.Payload)
offset += len(frame.Payload)
}
return nil
}
func walkStreamFastBatchPlain(payload []byte, fn func(streamFastDataFrame) error) (bool, error) {
if len(payload) < 4 || string(payload[:4]) != streamFastBatchMagic {
return false, nil
}
if len(payload) < streamFastBatchHeaderLen {
return true, errStreamFastPayloadInvalid
}
if payload[4] != streamFastBatchVersion {
return true, errStreamFastPayloadInvalid
}
count := int(binary.BigEndian.Uint32(payload[8:12]))
if count <= 0 {
return true, errStreamFastPayloadInvalid
}
offset := streamFastBatchHeaderLen
for index := 0; index < count; index++ {
if len(payload)-offset < streamFastBatchItemHeaderLen {
return true, errStreamFastPayloadInvalid
}
flags := payload[offset]
dataID := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
seq := binary.BigEndian.Uint64(payload[offset+12 : offset+20])
payloadLen := int(binary.BigEndian.Uint32(payload[offset+20 : offset+24]))
offset += streamFastBatchItemHeaderLen
if dataID == 0 || payloadLen < 0 || len(payload)-offset < payloadLen {
return true, errStreamFastPayloadInvalid
}
if fn != nil {
if err := fn(streamFastDataFrame{
Flags: flags,
DataID: dataID,
Seq: seq,
Payload: payload[offset : offset+payloadLen],
}); err != nil {
return true, err
}
}
offset += payloadLen
}
if offset != len(payload) {
return true, errStreamFastPayloadInvalid
}
return true, nil
}
func decodeStreamFastBatchPlain(payload []byte) ([]streamFastDataFrame, bool, error) {
frames := make([]streamFastDataFrame, 0, 1)
matched, err := walkStreamFastBatchPlain(payload, func(frame streamFastDataFrame) error {
frames = append(frames, frame)
return nil
})
if !matched || err != nil {
return nil, matched, err
}
return frames, true, nil
}
func walkStreamFastFrames(payload []byte, fn func(streamFastDataFrame) error) (bool, error) {
if matched, err := walkStreamFastBatchPlain(payload, fn); matched {
return true, err
}
frame, matched, err := decodeStreamFastDataFrame(payload)
if !matched || err != nil {
return matched, err
}
if fn != nil {
if err := fn(frame); err != nil {
return true, err
}
}
return true, nil
}
func decodeStreamFastDataFrames(payload []byte) ([]streamFastDataFrame, bool, error) {
frames := make([]streamFastDataFrame, 0, 1)
matched, err := walkStreamFastFrames(payload, func(frame streamFastDataFrame) error {
frames = append(frames, frame)
return nil
})
if !matched || err != nil {
return nil, matched, err
}
return frames, true, nil
}
+46 -42
View File
@@ -7,48 +7,52 @@ import (
) )
type StreamSnapshot struct { type StreamSnapshot struct {
ID string ID string
DataID uint64 DataID uint64
Scope string Scope string
Channel StreamChannel Channel StreamChannel
Metadata StreamMetadata Metadata StreamMetadata
BindingOwner string BindingOwner string
BindingAlive bool BindingAlive bool
BindingCurrent bool BindingCurrent bool
BindingReason string BindingReason string
BindingError string BindingError string
SessionEpoch uint64 BindingBulkAdaptiveSoftPayloadBytes int
LogicalClientID string BindingStreamAdaptiveSoftPayloadBytes int
LocalAddress string BindingStreamAdaptiveWaitThresholdBytes int
RemoteAddress string BindingStreamAdaptiveFlushDelay time.Duration
TransportGeneration uint64 SessionEpoch uint64
TransportAttached bool LogicalClientID string
TransportHasRuntimeConn bool LocalAddress string
TransportCurrent bool RemoteAddress string
TransportDetachReason string TransportGeneration uint64
TransportDetachKind string TransportAttached bool
TransportDetachGeneration uint64 TransportHasRuntimeConn bool
TransportDetachError string TransportCurrent bool
TransportDetachedAt time.Time TransportDetachReason string
ReattachEligible bool TransportDetachKind string
LocalClosed bool TransportDetachGeneration uint64
LocalReadClosed bool TransportDetachError string
RemoteClosed bool TransportDetachedAt time.Time
PeerReadClosed bool ReattachEligible bool
BufferedChunks int LocalClosed bool
BufferedBytes int LocalReadClosed bool
ReadTimeout time.Duration RemoteClosed bool
WriteTimeout time.Duration PeerReadClosed bool
BytesRead int64 BufferedChunks int
BytesWritten int64 BufferedBytes int
ReadCalls int64 ReadTimeout time.Duration
WriteCalls int64 WriteTimeout time.Duration
OpenedAt time.Time BytesRead int64
LastReadAt time.Time BytesWritten int64
LastWriteAt time.Time ReadCalls int64
ReadDeadline time.Time WriteCalls int64
WriteDeadline time.Time OpenedAt time.Time
ResetError string LastReadAt time.Time
LastWriteAt time.Time
ReadDeadline time.Time
WriteDeadline time.Time
ResetError string
} }
type clientStreamSnapshotReader interface { type clientStreamSnapshotReader interface {

Some files were not shown because too many files have changed in this diff Show More