feat(transport): 完成安全架构拆分并收口 stream/bulk 传输优化

- 新增 managed/external/nested 三种传输保护模式
  - 新增 peer attach 显式认证、抗重放、channel binding 和可选前向保密协商
  - 明确单连接注入与可重拨连接源的语义边界
  - 禁止 ConnectByConn 场景下 dedicated bulk 走 sidecar,auto 模式自动回退 shared
  - 修正 dedicated attach 在 bootstrap/steady profile 切换下的处理逻辑
  - 优化 shared bulk super-batch 与批量 framed write 路径
  - 降低 stream/bulk fast path 的复制和分发损耗
  - 补齐 benchmark、回归测试、运行时快照和 README 文档
This commit is contained in:
兔子 2026-04-20 16:35:44 +08:00
parent f038a89771
commit 98ef9e7fcc
Signed by: b612
GPG Key ID: 99DD2222B612B612
52 changed files with 4069 additions and 445 deletions

View File

@ -25,6 +25,21 @@
未配置时会返回 `errModernPSKRequired` 未配置时会返回 `errModernPSKRequired`
## 安全模式选择
- `UseModernPSKClient` / `UseModernPSKServer`
- bootstrap 和稳态传输都由 `notify` 自己保护
- 适合默认场景
- 支持 peer attach 显式认证、抗重放,以及在需要时协商前向保密
- `UsePSKOverExternalTransportClient` / `UsePSKOverExternalTransportServer`
- bootstrap 仍用 PSK 做认证
- 稳态阶段信任外部物理通道,不再做 `notify` 内层加密
- 适合 `tls.Conn` 或调用方自认可信的外部通道
- 不支持 `RequireForwardSecrecy`
- `UseNestedSecurityClient` / `UseNestedSecurityServer`
- 外层已有可信通道,但仍保留 `notify` 内层保护
- 适合需要“外层可信 + 内层独立保护”的场景
## 快速开始 ## 快速开始
服务端: 服务端:
@ -83,6 +98,42 @@ func main() {
} }
``` ```
## 连接入口与物理连接语义
- `Connect` / `ConnectTimeout`
- 由 `notify` 自己拨号
- 支持重连,也支持 dedicated bulk 额外 sidecar 连接
- `ConnectByFactory`
- 调用方提供 `dialFn`
- `notify` 会在需要时再次调用 `dialFn`,因此仍支持重连和 dedicated bulk
- `ConnectByConn`
- 调用方注入一个已经建立好的 `net.Conn`
- 该模式被视为“单物理连接模式”
- `OpenDedicatedBulk` 会直接返回错误
- `OpenBulk` 使用 `auto` 模式时会自动回退到 `shared`
- `ListenByListener`
- 服务端复用调用方提供的 `net.Listener`
- 适合需要和现有 listener 栈整合的场景
`dedicated bulk` 依赖额外物理连接,因此只适用于可再次拨号的 transport source。
## Peer Attach 安全策略
可通过 `SetPeerAttachSecurityConfig` 配置逻辑会话 attach 阶段的额外保护。
- `RequireExplicitAuth`
- 要求 peer attach 使用显式认证
- `RequireChannelBinding`
- 要求 attach 绑定到底层可信通道
- 启用后会隐式要求显式认证
- `ChannelBinding`
- 由调用方提供 channel binding 提取函数
- 适合外层 TLS 或其他可信通道整合
- `ReplayWindow` / `ReplayCapacity`
- 控制 attach 抗重放窗口和缓存容量
如果你选择 `UsePSKOverExternalTransport*`,并且希望 attach 阶段显式绑定到外层可信信道,建议同时配置 channel binding。
## RecordStream 说明 ## RecordStream 说明
`RecordStream` 构建在 `Stream` 之上,适合“有边界的顺序记录”场景。 `RecordStream` 构建在 `Stream` 之上,适合“有边界的顺序记录”场景。
@ -146,6 +197,9 @@ func main() {
- 共享密钥派生Argon2id - 共享密钥派生Argon2id
- 消息层加密AES-GCM - 消息层加密AES-GCM
- `stream` / `bulk` fast path 复用现代编码栈 - `stream` / `bulk` fast path 复用现代编码栈
- peer attach 显式认证 / 抗重放
- 可选 channel binding
- 可选前向保密(`UseModernPSK*` / `UseNestedSecurity*`
兼容入口仍保留,但属于历史路径: 兼容入口仍保留,但属于历史路径:

View File

@ -0,0 +1,42 @@
package notify
import "testing"
type benchmarkTransportSecurityMode string
const (
benchmarkTransportSecurityModernPSK benchmarkTransportSecurityMode = "modern_psk"
benchmarkTransportSecurityTrustedRaw benchmarkTransportSecurityMode = "trusted_raw"
)
func benchmarkApplyServerTransportSecurity(tb testing.TB, server *ServerCommon, mode benchmarkTransportSecurityMode) {
tb.Helper()
switch mode {
case benchmarkTransportSecurityModernPSK:
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
tb.Fatalf("UseModernPSKServer failed: %v", err)
}
case benchmarkTransportSecurityTrustedRaw:
if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
tb.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
}
default:
tb.Fatalf("unsupported benchmark transport security mode %q", mode)
}
}
func benchmarkApplyClientTransportSecurity(tb testing.TB, client *ClientCommon, mode benchmarkTransportSecurityMode) {
tb.Helper()
switch mode {
case benchmarkTransportSecurityModernPSK:
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
tb.Fatalf("UseModernPSKClient failed: %v", err)
}
case benchmarkTransportSecurityTrustedRaw:
if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
tb.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
}
default:
tb.Fatalf("unsupported benchmark transport security mode %q", mode)
}
}

View File

@ -161,6 +161,7 @@ var (
errBulkRangeInvalid = errors.New("bulk range is invalid") errBulkRangeInvalid = errors.New("bulk range is invalid")
errBulkBackpressureExceeded = errors.New("bulk inbound backpressure exceeded") errBulkBackpressureExceeded = errors.New("bulk inbound backpressure exceeded")
errBulkDedicatedStreamOnly = errors.New("dedicated bulk requires stream transport") errBulkDedicatedStreamOnly = errors.New("dedicated bulk requires stream transport")
errBulkDedicatedSingleConn = errors.New("dedicated bulk requires a dialable additional connection source; ConnectByConn only supports shared transport")
errBulkDedicatedActiveLimit = errors.New("dedicated bulk active limit reached") errBulkDedicatedActiveLimit = errors.New("dedicated bulk active limit reached")
) )
@ -174,6 +175,9 @@ func clientDedicatedBulkSupportError(c *ClientCommon) error {
if source := c.clientConnectSourceSnapshot(); source != nil && source.isUDP() { if source := c.clientConnectSourceSnapshot(); source != nil && source.isUDP() {
return errBulkDedicatedStreamOnly return errBulkDedicatedStreamOnly
} }
if source := c.clientConnectSourceSnapshot(); source != nil && !source.supportsAdditionalConn() {
return errBulkDedicatedSingleConn
}
return nil return nil
} }

View File

@ -418,18 +418,18 @@ func (s *bulkBatchSender) flush(requests []bulkBatchRequest) error {
} }
}() }()
writeTimeout := s.transportWriteTimeout() writeTimeout := s.transportWriteTimeout()
frames := make([][]byte, 0, len(payloads))
payloadBytes := 0
for _, payload := range payloads { for _, payload := range payloads {
frame := payload.payload frames = append(frames, payload.payload)
payloadBytes += len(payload.payload)
}
started := time.Now() started := time.Now()
err := s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error { err = s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error {
return writeFramedPayloadUnlocked(conn, queue, frame) return writeFramedPayloadBatchUnlocked(conn, queue, frames)
}) })
s.binding.observeBulkAdaptivePayloadWrite(len(frame), time.Since(started), writeTimeout, err) s.binding.observeBulkAdaptivePayloadWrite(payloadBytes, time.Since(started), writeTimeout, err)
if err != nil {
return err return err
}
}
return nil
} }
func (s *bulkBatchSender) encodeRequests(requests []bulkBatchRequest) ([]bulkBatchEncodedPayload, error) { func (s *bulkBatchSender) encodeRequests(requests []bulkBatchRequest) ([]bulkBatchEncodedPayload, error) {

View File

@ -38,7 +38,25 @@ func BenchmarkBulkTCPThroughput(b *testing.B) {
for _, tc := range cases { for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) { b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughput(b, tc.payloadSize, false) benchmarkBulkTCPThroughput(b, tc.payloadSize, false, benchmarkTransportSecurityModernPSK)
})
}
}
func BenchmarkBulkTCPThroughputTrustedRaw(b *testing.B) {
cases := []struct {
name string
payloadSize int
}{
{name: "chunk_256KiB", payloadSize: 256 * 1024},
{name: "chunk_512KiB", payloadSize: 512 * 1024},
{name: "chunk_768KiB", payloadSize: 768 * 1024},
{name: "chunk_1MiB", payloadSize: 1024 * 1024},
{name: "chunk_2MiB", payloadSize: 2 * 1024 * 1024},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughput(b, tc.payloadSize, false, benchmarkTransportSecurityTrustedRaw)
}) })
} }
} }
@ -72,7 +90,25 @@ func BenchmarkBulkTCPThroughputDedicated(b *testing.B) {
for _, tc := range cases { for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) { b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughput(b, tc.payloadSize, true) benchmarkBulkTCPThroughput(b, tc.payloadSize, true, benchmarkTransportSecurityModernPSK)
})
}
}
func BenchmarkBulkTCPThroughputDedicatedTrustedRaw(b *testing.B) {
cases := []struct {
name string
payloadSize int
}{
{name: "chunk_256KiB", payloadSize: 256 * 1024},
{name: "chunk_512KiB", payloadSize: 512 * 1024},
{name: "chunk_768KiB", payloadSize: 768 * 1024},
{name: "chunk_1MiB", payloadSize: 1024 * 1024},
{name: "chunk_2MiB", payloadSize: 2 * 1024 * 1024},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughput(b, tc.payloadSize, true, benchmarkTransportSecurityTrustedRaw)
}) })
} }
} }
@ -107,7 +143,25 @@ func BenchmarkBulkTCPThroughputConcurrent(b *testing.B) {
for _, tc := range cases { for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) { b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false) benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false, benchmarkTransportSecurityModernPSK)
})
}
}
func BenchmarkBulkTCPThroughputConcurrentTrustedRaw(b *testing.B) {
cases := []struct {
name string
payloadSize int
concurrency int
}{
{name: "bulks_2_512KiB", payloadSize: 512 * 1024, concurrency: 2},
{name: "bulks_4_512KiB", payloadSize: 512 * 1024, concurrency: 4},
{name: "bulks_2_1MiB", payloadSize: 1024 * 1024, concurrency: 2},
{name: "bulks_4_1MiB", payloadSize: 1024 * 1024, concurrency: 4},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false, benchmarkTransportSecurityTrustedRaw)
}) })
} }
} }
@ -142,18 +196,34 @@ func BenchmarkBulkTCPThroughputConcurrentDedicated(b *testing.B) {
for _, tc := range cases { for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) { b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true) benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true, benchmarkTransportSecurityModernPSK)
}) })
} }
} }
func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) { func BenchmarkBulkTCPThroughputConcurrentDedicatedTrustedRaw(b *testing.B) {
cases := []struct {
name string
payloadSize int
concurrency int
}{
{name: "bulks_2_512KiB", payloadSize: 512 * 1024, concurrency: 2},
{name: "bulks_4_512KiB", payloadSize: 512 * 1024, concurrency: 4},
{name: "bulks_2_1MiB", payloadSize: 1024 * 1024, concurrency: 2},
{name: "bulks_4_1MiB", payloadSize: 1024 * 1024, concurrency: 4},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true, benchmarkTransportSecurityTrustedRaw)
})
}
}
func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool, securityMode benchmarkTransportSecurityMode) {
b.Helper() b.Helper()
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyServerTransportSecurity(b, server, securityMode)
b.Fatalf("UseModernPSKServer failed: %v", err)
}
acceptCh := make(chan BulkAcceptInfo, 1) acceptCh := make(chan BulkAcceptInfo, 1)
server.SetBulkHandler(func(info BulkAcceptInfo) error { server.SetBulkHandler(func(info BulkAcceptInfo) error {
@ -169,9 +239,7 @@ func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
}) })
client := NewClient().(*ClientCommon) client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyClientTransportSecurity(b, client, securityMode)
b.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil { if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
b.Fatalf("client Connect failed: %v", err) b.Fatalf("client Connect failed: %v", err)
} }
@ -241,16 +309,14 @@ func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
_ = bulk.Close() _ = bulk.Close()
} }
func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, dedicated bool) { func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, dedicated bool, securityMode benchmarkTransportSecurityMode) {
b.Helper() b.Helper()
if concurrency <= 0 { if concurrency <= 0 {
b.Fatal("concurrency must be > 0") b.Fatal("concurrency must be > 0")
} }
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyServerTransportSecurity(b, server, securityMode)
b.Fatalf("UseModernPSKServer failed: %v", err)
}
acceptCh := make(chan BulkAcceptInfo, concurrency*2) acceptCh := make(chan BulkAcceptInfo, concurrency*2)
server.SetBulkHandler(func(info BulkAcceptInfo) error { server.SetBulkHandler(func(info BulkAcceptInfo) error {
@ -266,9 +332,7 @@ func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurr
}) })
client := NewClient().(*ClientCommon) client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyClientTransportSecurity(b, client, securityMode)
b.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil { if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
b.Fatalf("client Connect failed: %v", err) b.Fatalf("client Connect failed: %v", err)
} }

View File

@ -268,6 +268,9 @@ func readBulkDedicatedRecordPooled(conn net.Conn) ([]byte, func(), error) {
func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context, timeout time.Duration) (net.Conn, error) { func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context, timeout time.Duration) (net.Conn, error) {
source := c.clientConnectSourceSnapshot() source := c.clientConnectSourceSnapshot()
if source != nil { if source != nil {
if !source.supportsAdditionalConn() {
return nil, errBulkDedicatedSingleConn
}
if source.network != "" && source.addr != "" { if source.network != "" && source.addr != "" {
if timeout > 0 { if timeout > 0 {
return transport.DialTimeout(source.network, source.addr, timeout) return transport.DialTimeout(source.network, source.addr, timeout)
@ -277,6 +280,7 @@ func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context, timeout time.D
if source.canReconnect() { if source.canReconnect() {
return source.dial(ctx) return source.dial(ctx)
} }
return nil, errClientReconnectSourceUnavailable
} }
conn := c.clientTransportConnSnapshot() conn := c.clientTransportConnSnapshot()
if conn == nil || conn.RemoteAddr() == nil { if conn == nil || conn.RemoteAddr() == nil {
@ -661,7 +665,8 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn
Value: reqPayload, Value: reqPayload,
Type: MSG_SYS_WAIT, Type: MSG_SYS_WAIT,
} }
frame, err := encodeDirectSignalFrame(stario.NewQueue(), c.sequenceEn, c.msgEn, c.SecretKey, msg) attachProfile := c.clientDedicatedBulkAttachTransportProtectionProfile()
frame, err := encodeDirectSignalFrame(stario.NewQueue(), c.sequenceEn, attachProfile.msgEn, attachProfile.secretKey, msg)
if err != nil { if err != nil {
return bulkAttachResponse{}, err return bulkAttachResponse{}, err
} }
@ -675,7 +680,7 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn
if err != nil { if err != nil {
return bulkAttachResponse{}, err return bulkAttachResponse{}, err
} }
transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, replyPayload) transfer, err := decodeDirectSignalPayload(c.sequenceDe, attachProfile.msgDe, attachProfile.secretKey, replyPayload)
if err != nil { if err != nil {
return bulkAttachResponse{}, err return bulkAttachResponse{}, err
} }
@ -685,6 +690,16 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn
return decodeBulkAttachResponse(c.sequenceDe, transfer.Value) return decodeBulkAttachResponse(c.sequenceDe, transfer.Value)
} }
func (c *ClientCommon) clientDedicatedBulkAttachTransportProtectionProfile() transportProtectionProfile {
if c == nil {
return transportProtectionProfile{}
}
if c.securityConfigured && c.securityBootstrap.msgEn != nil && c.securityBootstrap.msgDe != nil {
return c.securityBootstrap.clone()
}
return c.clientTransportProtectionSnapshot()
}
func (c *ClientCommon) readDedicatedSidecarLoop(sidecar *bulkDedicatedSidecar) { func (c *ClientCommon) readDedicatedSidecarLoop(sidecar *bulkDedicatedSidecar) {
if c == nil || sidecar == nil || sidecar.conn == nil { if c == nil || sidecar == nil || sidecar.conn == nil {
return return
@ -695,7 +710,8 @@ func (c *ClientCommon) readDedicatedSidecarLoop(sidecar *bulkDedicatedSidecar) {
c.handleClientDedicatedSidecarFailure(sidecar, err) c.handleClientDedicatedSidecarFailure(sidecar, err)
return return
} }
plain, plainRelease, err := decryptTransportPayloadCodecPooled(c.modernPSKRuntime, c.msgDe, c.SecretKey, payload, payloadRelease) profile := c.clientTransportProtectionSnapshot()
plain, plainRelease, err := decryptTransportPayloadCodecPooled(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, payload, payloadRelease)
if err != nil { if err != nil {
c.handleClientDedicatedSidecarFailure(sidecar, err) c.handleClientDedicatedSidecarFailure(sidecar, err)
return return
@ -1023,7 +1039,7 @@ func (s *ServerCommon) replyDedicatedBulkAttach(client *LogicalConn, message Mes
Type: MSG_SYS_REPLY, Type: MSG_SYS_REPLY,
} }
if message.inboundConn != nil { if message.inboundConn != nil {
return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply) return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, messageInboundTransportProtectionSnapshot(&message), reply)
} }
_, err = s.sendLogical(client, reply) _, err = s.sendLogical(client, reply)
return err return err
@ -1041,7 +1057,7 @@ func (s *ServerCommon) readDedicatedSidecarLoop(logical *LogicalConn, sidecar *b
s.handleServerDedicatedSidecarFailure(logical, sidecar, err) s.handleServerDedicatedSidecarFailure(logical, sidecar, err)
return return
} }
plain, plainRelease, err := decryptTransportPayloadCodecPooled(logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload, payloadRelease) plain, plainRelease, err := decryptTransportPayloadCodecPooled(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload, payloadRelease)
if err != nil { if err != nil {
s.handleServerDedicatedSidecarFailure(logical, sidecar, err) s.handleServerDedicatedSidecarFailure(logical, sidecar, err)
return return
@ -1171,7 +1187,8 @@ func (c *ClientCommon) dedicatedBulkLaneSender(bulk *bulkHandle) (*bulkDedicated
return nil, transportDetachedError("dedicated bulk sidecar not attached", nil) return nil, transportDetachedError("dedicated bulk sidecar not attached", nil)
} }
sender := sidecar.laneSenderWithFactory(func(conn net.Conn) *bulkDedicatedLaneSender { sender := sidecar.laneSenderWithFactory(func(conn net.Conn) *bulkDedicatedLaneSender {
laneRuntime := c.modernPSKRuntime profile := c.clientTransportProtectionSnapshot()
laneRuntime := profile.runtime
if forked, err := forkDedicatedLaneModernPSKRuntime(laneRuntime); err == nil && forked != nil { if forked, err := forkDedicatedLaneModernPSKRuntime(laneRuntime); err == nil && forked != nil {
laneRuntime = forked laneRuntime = forked
} }
@ -1198,7 +1215,7 @@ func (c *ClientCommon) sendDedicatedBulkData(ctx context.Context, bulk *bulkHand
return sender.submitData(ctx, bulk.dataIDSnapshot(), bulk.nextOutboundDataSeq(), chunk) return sender.submitData(ctx, bulk.dataIDSnapshot(), bulk.nextOutboundDataSeq(), chunk)
} }
func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte) (int, error) { func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) {
if c == nil || bulk == nil { if c == nil || bulk == nil {
return 0, errBulkClientNil return 0, errBulkClientNil
} }
@ -1206,7 +1223,7 @@ func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHan
if err != nil { if err != nil {
return 0, err return 0, err
} }
return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize) return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize, payloadOwned)
} }
func (c *ClientCommon) sendDedicatedBulkClose(ctx context.Context, bulk *bulkHandle, full bool) error { func (c *ClientCommon) sendDedicatedBulkClose(ctx context.Context, bulk *bulkHandle, full bool) error {
@ -1345,7 +1362,7 @@ func (s *ServerCommon) sendDedicatedBulkData(ctx context.Context, logical *Logic
return sender.submitData(ctx, bulk.dataIDSnapshot(), bulk.nextOutboundDataSeq(), chunk) return sender.submitData(ctx, bulk.dataIDSnapshot(), bulk.nextOutboundDataSeq(), chunk)
} }
func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, startSeq uint64, payload []byte) (int, error) { func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) {
if s == nil || bulk == nil { if s == nil || bulk == nil {
return 0, errBulkServerNil return 0, errBulkServerNil
} }
@ -1353,7 +1370,7 @@ func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *Logi
if err != nil { if err != nil {
return 0, err return 0, err
} }
return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize) return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize, payloadOwned)
} }
func (s *ServerCommon) sendDedicatedBulkClose(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, full bool) error { func (s *ServerCommon) sendDedicatedBulkClose(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, full bool) error {
@ -1419,13 +1436,14 @@ func (c *ClientCommon) encodeDedicatedBulkBatchPayload(dataID uint64, items []bu
if c == nil { if c == nil {
return nil, errBulkClientNil return nil, errBulkClientNil
} }
if runtime := c.modernPSKRuntime; runtime != nil { profile := c.clientTransportProtectionSnapshot()
if runtime := profile.runtime; runtime != nil {
return runtime.sealFilledPayload(bulkDedicatedBatchPlainLen(items), func(dst []byte) error { return runtime.sealFilledPayload(bulkDedicatedBatchPlainLen(items), func(dst []byte) error {
return writeBulkDedicatedBatchPlain(dst, dataID, items) return writeBulkDedicatedBatchPlain(dst, dataID, items)
}) })
} }
if c.fastPlainEncode != nil { if profile.fastPlainEncode != nil {
return encodeBulkDedicatedBatchPayloadFast(c.fastPlainEncode, c.SecretKey, dataID, items) return encodeBulkDedicatedBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, dataID, items)
} }
plain, err := encodeBulkDedicatedBatchPlain(dataID, items) plain, err := encodeBulkDedicatedBatchPlain(dataID, items)
if err != nil { if err != nil {
@ -1453,8 +1471,9 @@ func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooledWithRuntime(runtim
return writeBulkDedicatedBatchesPlain(dst, batches) return writeBulkDedicatedBatchesPlain(dst, batches)
}) })
} }
if c.fastPlainEncode != nil { profile := c.clientTransportProtectionSnapshot()
payload, err := encodeBulkDedicatedBatchesPayloadFast(c.fastPlainEncode, c.SecretKey, batches) if profile.fastPlainEncode != nil {
payload, err := encodeBulkDedicatedBatchesPayloadFast(profile.fastPlainEncode, profile.secretKey, batches)
return payload, nil, err return payload, nil, err
} }
plain, err := encodeBulkDedicatedBatchesPlain(batches) plain, err := encodeBulkDedicatedBatchesPlain(batches)
@ -1466,7 +1485,7 @@ func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooledWithRuntime(runtim
} }
func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooled(batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) { func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooled(batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) {
return c.encodeDedicatedBulkBatchesPayloadPooledWithRuntime(c.modernPSKRuntime, batches) return c.encodeDedicatedBulkBatchesPayloadPooledWithRuntime(c.clientTransportProtectionSnapshot().runtime, batches)
} }
func (c *ClientCommon) encodeDedicatedBulkBatchPayloadPooled(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, func(), error) { func (c *ClientCommon) encodeDedicatedBulkBatchPayloadPooled(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, func(), error) {

View File

@ -115,6 +115,76 @@ func TestSendDedicatedBulkAttachRequestKeepsCoalescedDedicatedPayloadUnread(t *t
} }
} }
func TestSendDedicatedBulkAttachRequestUsesBootstrapProtectionEvenAfterSteadySwitch(t *testing.T) {
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
client.msgID = 100
alternate, err := deriveModernPSKProtectionProfile([]byte("notify-dedicated-attach-other-secret"), integrationModernPSKOptions(), ProtectionManaged)
if err != nil {
t.Fatalf("deriveModernPSKProtectionProfile(alternate) failed: %v", err)
}
client.setClientTransportProtectionProfile(alternate)
bulk := newBulkHandle(context.Background(), newBulkRuntime("dedicated-attach-bootstrap-test"), clientFileScope(), BulkOpenRequest{
BulkID: "bulk-attach-bootstrap-test",
DataID: 1,
Dedicated: true,
AttachToken: "attach-token",
}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
bootstrap := client.clientDedicatedBulkAttachTransportProtectionProfile()
encodedResp, err := client.sequenceEn(bulkAttachResponse{Accepted: true})
if err != nil {
t.Fatalf("encode bulkAttachResponse failed: %v", err)
}
replyFrame, err := encodeDirectSignalFrame(stario.NewQueue(), client.sequenceEn, bootstrap.msgEn, bootstrap.secretKey, TransferMsg{
ID: 101,
Key: systemBulkAttachKey,
Value: encodedResp,
Type: MSG_SYS_REPLY,
})
if err != nil {
t.Fatalf("encode attach reply frame failed: %v", err)
}
conn := newBulkAttachScriptConn(replyFrame)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
resp, err := client.sendDedicatedBulkAttachRequest(ctx, conn, bulk)
if err != nil {
t.Fatalf("sendDedicatedBulkAttachRequest failed: %v", err)
}
if !resp.Accepted {
t.Fatalf("bulk attach response = %+v, want accepted", resp)
}
parsedReq := stario.NewQueue()
var reqMsg TransferMsg
if err := parsedReq.ParseMessageOwned(conn.writtenBytes(), "attach-request", func(msgq stario.MsgQueue) error {
transfer, err := decodeDirectSignalPayload(client.sequenceDe, bootstrap.msgDe, bootstrap.secretKey, msgq.Msg)
if err != nil {
return err
}
reqMsg = transfer
return nil
}); err != nil {
t.Fatalf("parse written attach request with bootstrap profile failed: %v", err)
}
if reqMsg.Key != systemBulkAttachKey || reqMsg.Type != MSG_SYS_WAIT {
t.Fatalf("attach request message mismatch: %+v", reqMsg)
}
if err := parsedReq.ParseMessageOwned(conn.writtenBytes(), "attach-request-current", func(msgq stario.MsgQueue) error {
_, err := decodeDirectSignalPayload(client.sequenceDe, alternate.msgDe, alternate.secretKey, msgq.Msg)
return err
}); !errors.Is(err, errTransportPayloadDecryptFailed) {
t.Fatalf("decode written attach request with current steady profile error = %v, want %v", err, errTransportPayloadDecryptFailed)
}
}
func TestHandleBulkAttachSystemMessageAcceptedWritesDirectReplyBeforeDedicatedHandoff(t *testing.T) { func TestHandleBulkAttachSystemMessageAcceptedWritesDirectReplyBeforeDedicatedHandoff(t *testing.T) {
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server) UseLegacySecurityServer(server)

View File

@ -83,7 +83,7 @@ func (r *bulkDedicatedLaneBatchRequest) reset() {
} }
} }
func (r *bulkDedicatedLaneBatchRequest) prepare(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool) { func (r *bulkDedicatedLaneBatchRequest) prepare(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool, borrowItems bool) {
if r == nil { if r == nil {
return return
} }
@ -94,6 +94,10 @@ func (r *bulkDedicatedLaneBatchRequest) prepare(ctx context.Context, dataID uint
if deadline, ok := ctx.Deadline(); ok { if deadline, ok := ctx.Deadline(); ok {
r.Deadline = deadline r.Deadline = deadline
} }
if borrowItems {
r.Items = items
return
}
if cap(r.Items) < len(items) { if cap(r.Items) < len(items) {
r.Items = make([]bulkDedicatedSendRequest, len(items)) r.Items = make([]bulkDedicatedSendRequest, len(items))
} else { } else {
@ -119,10 +123,10 @@ func (s *bulkDedicatedLaneSender) submitData(ctx context.Context, dataID uint64,
Seq: seq, Seq: seq,
Payload: append([]byte(nil), payload...), Payload: append([]byte(nil), payload...),
}} }}
return s.submitBatch(ctx, dataID, items, false) return s.submitBatch(ctx, dataID, items, false, false)
} }
func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int) (int, error) { func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int, payloadOwned bool) (int, error) {
if s == nil { if s == nil {
return 0, errTransportDetached return 0, errTransportDetached
} }
@ -132,6 +136,9 @@ func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64
if chunkSize <= 0 { if chunkSize <= 0 {
chunkSize = defaultBulkChunkSize chunkSize = defaultBulkChunkSize
} }
if submitted, written, err := s.tryDirectSubmitWrite(ctx, dataID, startSeq, payload, chunkSize); submitted {
return written, err
}
written := 0 written := 0
seq := startSeq seq := startSeq
for written < len(payload) { for written < len(payload) {
@ -170,7 +177,7 @@ func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64
seq++ seq++
written = end written = end
} }
if err := s.submitWriteBatch(ctx, dataID, items); err != nil { if err := s.submitWriteBatch(ctx, dataID, items, payloadOwned); err != nil {
return start, err return start, err
} }
start = written start = written
@ -178,7 +185,7 @@ func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64
return written, nil return written, nil
} }
func (s *bulkDedicatedLaneSender) submitWriteBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest) error { func (s *bulkDedicatedLaneSender) submitWriteBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, _ bool) error {
if s == nil { if s == nil {
return errTransportDetached return errTransportDetached
} }
@ -188,8 +195,7 @@ func (s *bulkDedicatedLaneSender) submitWriteBatch(ctx context.Context, dataID u
if submitted, err := s.tryDirectSubmitBatch(ctx, dataID, items); submitted { if submitted, err := s.tryDirectSubmitBatch(ctx, dataID, items); submitted {
return err return err
} }
queuedItems := copyBulkDedicatedSendRequests(items) return s.submitBatch(ctx, dataID, items, true, true)
return s.submitBatch(ctx, dataID, queuedItems, true)
} }
func (s *bulkDedicatedLaneSender) submitControl(ctx context.Context, dataID uint64, frameType uint8, flags uint8, seq uint64, payload []byte) error { func (s *bulkDedicatedLaneSender) submitControl(ctx context.Context, dataID uint64, frameType uint8, flags uint8, seq uint64, payload []byte) error {
@ -204,10 +210,10 @@ func (s *bulkDedicatedLaneSender) submitControl(ctx context.Context, dataID uint
if len(payload) > 0 { if len(payload) > 0 {
items[0].Payload = append([]byte(nil), payload...) items[0].Payload = append([]byte(nil), payload...)
} }
return s.submitBatch(ctx, dataID, items, true) return s.submitBatch(ctx, dataID, items, true, false)
} }
func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool) error { func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool, borrowItems bool) error {
if s == nil { if s == nil {
return errTransportDetached return errTransportDetached
} }
@ -218,7 +224,7 @@ func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64
return err return err
} }
req := getBulkDedicatedLaneBatchRequest() req := getBulkDedicatedLaneBatchRequest()
req.prepare(ctx, dataID, items, wait) req.prepare(ctx, dataID, items, wait, borrowItems)
s.queued.Add(1) s.queued.Add(1)
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -237,6 +243,104 @@ func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64
} }
} }
func (s *bulkDedicatedLaneSender) tryDirectSubmitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int) (bool, int, error) {
if s == nil {
return true, 0, errTransportDetached
}
if ctx == nil {
ctx = context.Background()
}
if len(payload) == 0 {
return true, 0, nil
}
if chunkSize <= 0 {
chunkSize = defaultBulkChunkSize
}
if err := s.errSnapshot(); err != nil {
return true, 0, err
}
select {
case <-ctx.Done():
return true, 0, normalizeStreamDeadlineError(ctx.Err())
case <-s.stopCh:
return true, 0, s.stoppedErr()
default:
}
if s.queued.Load() != 0 {
return false, 0, nil
}
if !s.flushMu.TryLock() {
return false, 0, nil
}
defer s.flushMu.Unlock()
if s.queued.Load() != 0 {
return false, 0, nil
}
if err := s.errSnapshot(); err != nil {
return true, 0, err
}
written := 0
seq := startSeq
deadline, _ := ctx.Deadline()
for written < len(payload) {
select {
case <-ctx.Done():
return true, written, normalizeStreamDeadlineError(ctx.Err())
case <-s.stopCh:
return true, written, s.stoppedErr()
default:
}
var itemBuf [bulkDedicatedBatchMaxItems]bulkDedicatedSendRequest
items := itemBuf[:0]
batchBytes := bulkDedicatedBatchHeaderLen
start := written
for written < len(payload) && len(items) < bulkDedicatedBatchMaxItems {
end := written + chunkSize
if end > len(payload) {
end = len(payload)
}
itemLen := bulkDedicatedSendRequestLenFromPayloadLen(end - written)
if len(items) > 0 && batchBytes+itemLen > bulkDedicatedBatchMaxPlainBytes {
break
}
items = append(items, bulkDedicatedSendRequest{
Type: bulkFastPayloadTypeData,
Seq: seq,
Payload: payload[written:end],
})
batchBytes += itemLen
seq++
written = end
}
if len(items) == 0 {
end := written + chunkSize
if end > len(payload) {
end = len(payload)
}
items = append(items, bulkDedicatedSendRequest{
Type: bulkFastPayloadTypeData,
Seq: seq,
Payload: payload[written:end],
})
seq++
written = end
}
if err := s.flush([]bulkDedicatedOutboundBatch{{
DataID: dataID,
Items: items,
}}, deadline); err != nil {
err = normalizeDedicatedBulkSendError(err)
s.setErr(err)
s.failPending(err)
if s.fail != nil {
go s.fail(err)
}
return true, start, err
}
}
return true, written, nil
}
func (s *bulkDedicatedLaneSender) tryDirectSubmitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest) (bool, error) { func (s *bulkDedicatedLaneSender) tryDirectSubmitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest) (bool, error) {
if s == nil { if s == nil {
return true, errTransportDetached return true, errTransportDetached

View File

@ -7,6 +7,57 @@ import (
"time" "time"
) )
func TestBulkDedicatedLaneBatchRequestPrepareBorrowedSharesItems(t *testing.T) {
req := getBulkDedicatedLaneBatchRequest()
defer req.recycle()
items := []bulkDedicatedSendRequest{{
Type: bulkFastPayloadTypeData,
Seq: 7,
Payload: []byte("hello"),
}}
req.prepare(context.Background(), 11, items, true, true)
if got, want := len(req.Items), 1; got != want {
t.Fatalf("prepared item count = %d, want %d", got, want)
}
if &req.Items[0] != &items[0] {
t.Fatal("prepare with borrowed items should share request items")
}
}
func TestBulkDedicatedLaneSenderTryDirectSubmitWriteFlushesWholePayload(t *testing.T) {
conn := &shortWriteBulkRecordConn{maxPerWrite: 1024}
encodeCalls := 0
sender := &bulkDedicatedLaneSender{
conn: conn,
stopCh: make(chan struct{}),
encode: func(batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) {
encodeCalls++
payload, err := encodeBulkDedicatedBatchesPlain(batches)
return payload, nil, err
},
}
payload := bytes.Repeat([]byte("a"), 3*defaultBulkChunkSize)
submitted, written, err := sender.tryDirectSubmitWrite(context.Background(), 9, 1, payload, defaultBulkChunkSize)
if err != nil {
t.Fatalf("tryDirectSubmitWrite error = %v", err)
}
if !submitted {
t.Fatal("tryDirectSubmitWrite should submit directly")
}
if got, want := written, len(payload); got != want {
t.Fatalf("written = %d, want %d", got, want)
}
if encodeCalls == 0 {
t.Fatal("encode should be called at least once")
}
if got := sender.queued.Load(); got != 0 {
t.Fatalf("queued requests = %d, want 0", got)
}
}
func TestBulkDedicatedLaneSenderCollectBatchRequestsBatchesAcrossDataIDs(t *testing.T) { func TestBulkDedicatedLaneSenderCollectBatchRequestsBatchesAcrossDataIDs(t *testing.T) {
sender := &bulkDedicatedLaneSender{ sender := &bulkDedicatedLaneSender{
reqCh: make(chan *bulkDedicatedLaneBatchRequest, 3), reqCh: make(chan *bulkDedicatedLaneBatchRequest, 3),

View File

@ -162,7 +162,8 @@ func (c *ClientCommon) tryDispatchBorrowedBulkTransportPayload(payload []byte) b
if c == nil || len(payload) == 0 { if c == nil || len(payload) == 0 {
return false return false
} }
plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(c.modernPSKRuntime, c.msgDe, c.SecretKey, payload) profile := c.clientTransportProtectionSnapshot()
plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, payload)
if err != nil { if err != nil {
if c.showError || c.debugMode { if c.showError || c.debugMode {
fmt.Println("client decode transport payload error", err) fmt.Println("client decode transport payload error", err)
@ -194,7 +195,7 @@ func (s *ServerCommon) tryDispatchBorrowedBulkTransportPayload(source interface{
if logical == nil { if logical == nil {
return false return false
} }
plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload) plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload)
if err != nil { if err != nil {
if s.showError || s.debugMode { if s.showError || s.debugMode {
fmt.Println("server decode transport payload error", err) fmt.Println("server decode transport payload error", err)

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"net" "net"
"sync" "sync"
"time" "time"
@ -132,8 +133,9 @@ func (c *ClientCommon) encodeBulkFastPayload(frame bulkFastFrame) ([]byte, error
if c == nil { if c == nil {
return nil, errBulkClientNil return nil, errBulkClientNil
} }
if c.fastPlainEncode != nil { profile := c.clientTransportProtectionSnapshot()
return encodeBulkFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame) if profile.fastPlainEncode != nil {
return encodeBulkFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame)
} }
plain, err := encodeBulkFastFramePayload(frame) plain, err := encodeBulkFastFramePayload(frame)
if err != nil { if err != nil {
@ -146,8 +148,9 @@ func (c *ClientCommon) encodeBulkFastBatchPayload(frames []bulkFastFrame) ([]byt
if c == nil { if c == nil {
return nil, errBulkClientNil return nil, errBulkClientNil
} }
if c.fastPlainEncode != nil { profile := c.clientTransportProtectionSnapshot()
return encodeBulkFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames) if profile.fastPlainEncode != nil {
return encodeBulkFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames)
} }
plain, err := encodeBulkFastBatchPlain(frames) plain, err := encodeBulkFastBatchPlain(frames)
if err != nil { if err != nil {
@ -160,11 +163,12 @@ func (c *ClientCommon) encodeBulkFastPayloadPooled(frame bulkFastFrame) ([]byte,
if c == nil { if c == nil {
return nil, nil, errBulkClientNil return nil, nil, errBulkClientNil
} }
if runtime := c.modernPSKRuntime; runtime != nil { profile := c.clientTransportProtectionSnapshot()
if runtime := profile.runtime; runtime != nil {
return encodeBulkFastFramePayloadPooled(runtime, frame) return encodeBulkFastFramePayloadPooled(runtime, frame)
} }
if c.fastPlainEncode != nil { if profile.fastPlainEncode != nil {
payload, err := encodeBulkFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame) payload, err := encodeBulkFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame)
return payload, nil, err return payload, nil, err
} }
plain, err := encodeBulkFastFramePayload(frame) plain, err := encodeBulkFastFramePayload(frame)
@ -179,11 +183,12 @@ func (c *ClientCommon) encodeBulkFastBatchPayloadPooled(frames []bulkFastFrame)
if c == nil { if c == nil {
return nil, nil, errBulkClientNil return nil, nil, errBulkClientNil
} }
if runtime := c.modernPSKRuntime; runtime != nil { profile := c.clientTransportProtectionSnapshot()
if runtime := profile.runtime; runtime != nil {
return encodeBulkFastBatchPayloadPooled(runtime, frames) return encodeBulkFastBatchPayloadPooled(runtime, frames)
} }
if c.fastPlainEncode != nil { if profile.fastPlainEncode != nil {
payload, err := encodeBulkFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames) payload, err := encodeBulkFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames)
return payload, nil, err return payload, nil, err
} }
plain, err := encodeBulkFastBatchPlain(frames) plain, err := encodeBulkFastBatchPlain(frames)
@ -460,28 +465,126 @@ func putBulkFastFrameScratch(buf []byte) {
bulkFastFrameScratchPool.Put(buf[:0]) bulkFastFrameScratchPool.Put(buf[:0])
} }
func transportFastPayloadMagic(payload []byte) string {
if len(payload) < 4 {
return ""
}
return string(payload[:4])
}
func (c *ClientCommon) decryptTransportPayloadPooled(payload []byte, release func()) ([]byte, func(), error) {
profile := c.clientTransportProtectionSnapshot()
return decryptTransportPayloadCodecPooled(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, payload, release)
}
func (s *ServerCommon) decryptTransportPayloadLogicalPooled(logical *LogicalConn, payload []byte, release func()) ([]byte, func(), error) {
if logical == nil {
if release != nil {
release()
}
return nil, nil, errTransportDetached
}
return decryptTransportPayloadCodecPooled(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload, release)
}
func (c *ClientCommon) tryDispatchBorrowedTransportPlain(plain []byte, release func()) bool {
switch transportFastPayloadMagic(plain) {
case bulkFastPayloadMagic, bulkFastBatchMagic:
owner := newBulkReadPayloadOwner(release)
matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
c.dispatchFastBulkFrameWithOwner(frame, owner)
return nil
})
if owner != nil {
owner.done()
}
if !matched {
walkErr = errBulkFastPayloadInvalid
}
if walkErr != nil && (c.showError || c.debugMode) {
fmt.Println("client decode bulk fast payload error", walkErr)
}
return true
case streamFastPayloadMagic, streamFastBatchMagic:
owner := newStreamReadPayloadOwner(release)
matched, walkErr := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
c.dispatchFastStreamDataWithOwner(frame, owner)
return nil
})
if owner != nil {
owner.done()
}
if !matched {
walkErr = errStreamFastPayloadInvalid
}
if walkErr != nil && (c.showError || c.debugMode) {
fmt.Println("client decode stream fast payload error", walkErr)
}
return true
default:
return false
}
}
func (s *ServerCommon) tryDispatchBorrowedTransportPlain(logical *LogicalConn, transport *TransportConn, conn net.Conn, plain []byte, release func()) bool {
switch transportFastPayloadMagic(plain) {
case bulkFastPayloadMagic, bulkFastBatchMagic:
owner := newBulkReadPayloadOwner(release)
matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, owner)
return nil
})
if owner != nil {
owner.done()
}
if !matched {
walkErr = errBulkFastPayloadInvalid
}
if walkErr != nil && (s.showError || s.debugMode) {
fmt.Println("server decode bulk fast payload error", walkErr)
}
return true
case streamFastPayloadMagic, streamFastBatchMagic:
owner := newStreamReadPayloadOwner(release)
matched, walkErr := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
s.dispatchFastStreamDataWithOwner(logical, transport, conn, frame, owner)
return nil
})
if owner != nil {
owner.done()
}
if !matched {
walkErr = errStreamFastPayloadInvalid
}
if walkErr != nil && (s.showError || s.debugMode) {
fmt.Println("server decode stream fast payload error", walkErr)
}
return true
default:
return false
}
}
func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error { func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error {
plain, err := c.decryptTransportPayload(payload) plain, err := c.decryptTransportPayload(payload)
if err != nil { if err != nil {
return err return err
} }
if frames, matched, err := decodeBulkFastFrames(plain); matched { return c.dispatchInboundTransportPlain(plain, now)
if err != nil { }
return err
} func (c *ClientCommon) dispatchInboundTransportPlain(plain []byte, now time.Time) error {
for _, frame := range frames { if matched, err := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
c.dispatchFastBulkFrame(frame) c.dispatchFastBulkFrame(frame)
}
return nil return nil
} }); matched {
if frames, matched, err := decodeStreamFastDataFrames(plain); matched {
if err != nil {
return err return err
} }
for _, frame := range frames { if matched, err := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
c.dispatchFastStreamData(frame) c.dispatchFastStreamData(frame)
}
return nil return nil
}); matched {
return err
} }
env, err := c.decodeEnvelopePlain(plain) env, err := c.decodeEnvelopePlain(plain)
if err != nil { if err != nil {
@ -502,23 +605,21 @@ func (s *ServerCommon) dispatchInboundTransportPayload(logical *LogicalConn, tra
if err != nil { if err != nil {
return err return err
} }
if frames, matched, err := decodeBulkFastFrames(plain); matched { return s.dispatchInboundTransportPlain(logical, transport, conn, plain, now)
if err != nil { }
return err
} func (s *ServerCommon) dispatchInboundTransportPlain(logical *LogicalConn, transport *TransportConn, conn net.Conn, plain []byte, now time.Time) error {
for _, frame := range frames { if matched, err := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
s.dispatchFastBulkFrame(logical, transport, conn, frame) s.dispatchFastBulkFrame(logical, transport, conn, frame)
}
return nil return nil
} }); matched {
if frames, matched, err := decodeStreamFastDataFrames(plain); matched {
if err != nil {
return err return err
} }
for _, frame := range frames { if matched, err := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
s.dispatchFastStreamData(logical, transport, conn, frame) s.dispatchFastStreamData(logical, transport, conn, frame)
}
return nil return nil
}); matched {
return err
} }
env, err := s.decodeEnvelopePlain(plain) env, err := s.decodeEnvelopePlain(plain)
if err != nil { if err != nil {

View File

@ -100,6 +100,176 @@ func TestBulkOpenAutoUDPFallsBackToShared(t *testing.T) {
_ = accepted.Bulk.Close() _ = accepted.Bulk.Close()
} }
func TestOpenDedicatedBulkConnectByConnRejectedAsSingleConnMode(t *testing.T) {
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
}
server.SetBulkHandler(func(info BulkAcceptInfo) error {
return nil
})
})
client := NewClient().(*ClientCommon)
if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
}
left, right := net.Pipe()
defer right.Close()
bootstrapPeerAttachConnForTest(t, server, right)
if err := client.ConnectByConn(left); err != nil {
t.Fatalf("ConnectByConn failed: %v", err)
}
defer func() {
client.setByeFromServer(true)
_ = client.Stop()
}()
_, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{
Range: BulkRange{
Offset: 0,
Length: 128,
},
})
if !errors.Is(err, errBulkDedicatedSingleConn) {
t.Fatalf("client OpenDedicatedBulk over ConnectByConn error = %v, want %v", err, errBulkDedicatedSingleConn)
}
}
func TestOpenBulkAutoConnectByConnFallsBackToShared(t *testing.T) {
acceptCh := make(chan BulkAcceptInfo, 2)
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
}
server.SetBulkHandler(func(info BulkAcceptInfo) error {
acceptCh <- info
return nil
})
})
client := NewClient().(*ClientCommon)
if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
}
left, right := net.Pipe()
defer right.Close()
bootstrapPeerAttachConnForTest(t, server, right)
if err := client.ConnectByConn(left); err != nil {
t.Fatalf("ConnectByConn failed: %v", err)
}
defer func() {
client.setByeFromServer(true)
_ = client.Stop()
}()
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
Mode: BulkOpenModeAuto,
Range: BulkRange{
Offset: 0,
Length: 128,
},
})
if err != nil {
t.Fatalf("client OpenBulk auto over ConnectByConn failed: %v", err)
}
if bulk.Snapshot().Dedicated {
t.Fatal("client OpenBulk auto over ConnectByConn should fall back to shared")
}
defer func() {
_ = bulk.Close()
}()
accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second)
if accepted.Dedicated {
t.Fatal("server accepted bulk should be shared after ConnectByConn auto fallback")
}
if _, err := bulk.Write([]byte("shared-over-single-conn")); err != nil {
t.Fatalf("client bulk Write failed: %v", err)
}
readBulkExactly(t, accepted.Bulk, "shared-over-single-conn", 2*time.Second)
select {
case extra := <-acceptCh:
t.Fatalf("unexpected extra server bulk accept: %+v", extra)
case <-time.After(300 * time.Millisecond):
}
_ = accepted.Bulk.Close()
}
func TestOpenDedicatedBulkExternalTransportDialableSourceSucceeds(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
}
acceptCh := make(chan BulkAcceptInfo, 2)
server.SetBulkHandler(func(info BulkAcceptInfo) error {
acceptCh <- info
return nil
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
t.Fatalf("server Listen failed: %v", err)
}
defer func() {
_ = server.Stop()
}()
client := NewClient().(*ClientCommon)
if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
t.Fatalf("client Connect failed: %v", err)
}
defer func() {
client.setByeFromServer(true)
_ = client.Stop()
}()
bulk, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{
ID: "external-dedicated-dialable",
Range: BulkRange{
Offset: 0,
Length: 64,
},
})
if err != nil {
t.Fatalf("client OpenDedicatedBulk failed: %v", err)
}
if !bulk.Snapshot().Dedicated {
t.Fatal("client OpenDedicatedBulk over dialable external transport should stay dedicated")
}
defer func() {
_ = bulk.Close()
}()
accepted := waitAcceptedBulkByID(t, acceptCh, bulk.ID(), 2*time.Second)
if !accepted.Dedicated {
t.Fatal("server accepted bulk should stay dedicated over dialable external transport")
}
defer func() {
_ = accepted.Bulk.Close()
}()
clientHandle := bulk.(*bulkHandle)
if clientHandle.dedicatedConnSnapshot() == nil {
t.Fatal("client dedicated sidecar conn should be attached")
}
if mainConn := client.clientTransportConnSnapshot(); mainConn != nil && clientHandle.dedicatedConnSnapshot() == mainConn {
t.Fatal("client dedicated sidecar should use an additional physical connection")
}
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionExternal {
t.Fatalf("client steady protection mode = %v, want %v", got, ProtectionExternal)
}
payload := "external-dedicated-sidecar"
if _, err := bulk.Write([]byte(payload)); err != nil {
t.Fatalf("client bulk Write failed: %v", err)
}
readBulkExactly(t, accepted.Bulk, payload, 2*time.Second)
}
func TestOpenDedicatedBulkWaitsForActiveSlotUntilContextDeadline(t *testing.T) { func TestOpenDedicatedBulkWaitsForActiveSlotUntilContextDeadline(t *testing.T) {
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {

View File

@ -38,6 +38,18 @@ type ClientCommon struct {
modernPSKRuntime *modernPSKCodecRuntime modernPSKRuntime *modernPSKCodecRuntime
handshakeRsaPubKey []byte handshakeRsaPubKey []byte
SecretKey []byte SecretKey []byte
transportProtection atomic.Pointer[transportProtectionProfile]
peerAttachSecurity atomic.Pointer[peerAttachSecurityState]
securityBootstrap transportProtectionProfile
securitySteady transportProtectionProfile
securitySteadyNegotiated transportProtectionProfile
securityAuthMode AuthMode
securityProtectionMode ProtectionMode
securityRequireForwardSecrecy bool
securityConfigured bool
peerAttachAuthenticated bool
peerAttachAuthFallback bool
peerAttachAt int64
noFinSyncMsgMaxKeepSeconds int noFinSyncMsgMaxKeepSeconds int
lastHeartbeat int64 lastHeartbeat int64
heartbeatPeriod time.Duration heartbeatPeriod time.Duration
@ -134,6 +146,8 @@ func NewClient() Client {
client.fileEventObserver = normalizeFileEventCallback(nil) client.fileEventObserver = normalizeFileEventCallback(nil)
client.stopCtx, client.stopFn = context.WithCancel(context.Background()) client.stopCtx, client.stopFn = context.WithCancel(context.Background())
client.sessionRuntime.Store(newClientSessionRuntimeBase(client.stopCtx, client.stopFn)) client.sessionRuntime.Store(newClientSessionRuntimeBase(client.stopCtx, client.stopFn))
client.setClientTransportProtectionProfile(defaultTransportProtectionProfile())
client.peerAttachSecurity.Store(defaultPeerAttachSecurityState())
bindClientStreamControl(&client) bindClientStreamControl(&client)
bindClientBulkControl(&client) bindClientBulkControl(&client)
client.getTransferState().setBuiltinHandler(client.builtinFileTransferHandler) client.getTransferState().setBuiltinHandler(client.builtinFileTransferHandler)

View File

@ -294,7 +294,7 @@ func clientBulkWriteSender(c *ClientCommon, epoch uint64) bulkWriteSender {
if err := bulk.waitDedicatedReady(ctx); err != nil { if err := bulk.waitDedicatedReady(ctx); err != nil {
return 0, err return 0, err
} }
return c.sendDedicatedBulkWrite(ctx, bulk, startSeq, payload) return c.sendDedicatedBulkWrite(ctx, bulk, startSeq, payload, payloadOwned)
} }
if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) { if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) {
return 0, errTransportDetached return 0, errTransportDetached

View File

@ -65,32 +65,40 @@ func (c *ClientCommon) RecoverTransferSnapshots(ctx context.Context) error {
} }
func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte { func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte {
return c.msgEn return c.clientTransportProtectionSnapshot().msgEn
} }
// Deprecated: SetMsgEn overrides the transport codec directly. // Deprecated: SetMsgEn overrides the transport codec directly.
// Prefer UseModernPSKClient or UseLegacySecurityClient. // Prefer UseModernPSKClient or UseLegacySecurityClient.
func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) { func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) {
c.msgEn = fn profile := c.clientTransportProtectionSnapshot()
c.fastStreamEncode = nil profile.mode = ProtectionManaged
c.fastBulkEncode = nil profile.msgEn = fn
c.fastPlainEncode = nil profile.fastStreamEncode = nil
c.modernPSKRuntime = nil profile.fastBulkEncode = nil
profile.fastPlainEncode = nil
profile.runtime = nil
c.setClientTransportProtectionProfile(profile)
c.clearClientSecurityProfiles()
c.securityReadyCheck = false c.securityReadyCheck = false
} }
func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte { func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte {
return c.msgDe return c.clientTransportProtectionSnapshot().msgDe
} }
// Deprecated: SetMsgDe overrides the transport codec directly. // Deprecated: SetMsgDe overrides the transport codec directly.
// Prefer UseModernPSKClient or UseLegacySecurityClient. // Prefer UseModernPSKClient or UseLegacySecurityClient.
func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) { func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) {
c.msgDe = fn profile := c.clientTransportProtectionSnapshot()
c.fastStreamEncode = nil profile.mode = ProtectionManaged
c.fastBulkEncode = nil profile.msgDe = fn
c.fastPlainEncode = nil profile.fastStreamEncode = nil
c.modernPSKRuntime = nil profile.fastBulkEncode = nil
profile.fastPlainEncode = nil
profile.runtime = nil
c.setClientTransportProtectionProfile(profile)
c.clearClientSecurityProfiles()
c.securityReadyCheck = false c.securityReadyCheck = false
} }
@ -103,20 +111,24 @@ func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) {
} }
func (c *ClientCommon) GetSecretKey() []byte { func (c *ClientCommon) GetSecretKey() []byte {
return c.SecretKey return c.clientTransportProtectionSnapshot().secretKey
} }
// Deprecated: SetSecretKey injects a raw transport key directly. // Deprecated: SetSecretKey injects a raw transport key directly.
// Prefer UseModernPSKClient or UseLegacySecurityClient. // Prefer UseModernPSKClient or UseLegacySecurityClient.
func (c *ClientCommon) SetSecretKey(key []byte) { func (c *ClientCommon) SetSecretKey(key []byte) {
c.SecretKey = key profile := c.clientTransportProtectionSnapshot()
profile.mode = ProtectionManaged
profile.secretKey = cloneTransportProtectionKey(key)
if len(key) == 0 { if len(key) == 0 {
c.modernPSKRuntime = nil profile.runtime = nil
} else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil { } else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil {
c.modernPSKRuntime = runtime profile.runtime = runtime
} else { } else {
c.modernPSKRuntime = nil profile.runtime = nil
} }
c.setClientTransportProtectionProfile(profile)
c.clearClientSecurityProfiles()
c.securityReadyCheck = len(key) == 0 c.securityReadyCheck = len(key) == 0
c.skipKeyExchange = true c.skipKeyExchange = true
} }

View File

@ -8,6 +8,8 @@ import (
type clientConnAttachmentState struct { type clientConnAttachmentState struct {
maxReadTimeout time.Duration maxReadTimeout time.Duration
maxWriteTimeout time.Duration maxWriteTimeout time.Duration
authMode AuthMode
protectionMode ProtectionMode
msgEn func([]byte, []byte) []byte msgEn func([]byte, []byte) []byte
msgDe func([]byte, []byte) []byte msgDe func([]byte, []byte) []byte
fastStreamEncode transportFastStreamEncoder fastStreamEncode transportFastStreamEncoder
@ -16,6 +18,13 @@ type clientConnAttachmentState struct {
modernPSKRuntime *modernPSKCodecRuntime modernPSKRuntime *modernPSKCodecRuntime
handshakeRsaKey []byte handshakeRsaKey []byte
secretKey []byte secretKey []byte
keyMode string
sessionID []byte
forwardSecrecy bool
forwardSecrecyFallback bool
peerAttached bool
peerAttachFallback bool
peerAttachAt int64
lastHeartBeat int64 lastHeartBeat int64
} }
@ -26,6 +35,7 @@ func cloneClientConnAttachmentState(src *clientConnAttachmentState) *clientConnA
cloned := *src cloned := *src
cloned.handshakeRsaKey = cloneClientConnAttachmentBytes(src.handshakeRsaKey) cloned.handshakeRsaKey = cloneClientConnAttachmentBytes(src.handshakeRsaKey)
cloned.secretKey = cloneClientConnAttachmentBytes(src.secretKey) cloned.secretKey = cloneClientConnAttachmentBytes(src.secretKey)
cloned.sessionID = cloneClientConnAttachmentBytes(src.sessionID)
return &cloned return &cloned
} }
@ -153,6 +163,7 @@ func (c *ClientConn) applyClientConnAttachmentProfile(maxReadTimeout time.Durati
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.maxReadTimeout = maxReadTimeout state.maxReadTimeout = maxReadTimeout
state.maxWriteTimeout = maxWriteTimeout state.maxWriteTimeout = maxWriteTimeout
state.protectionMode = ProtectionManaged
state.msgEn = msgEn state.msgEn = msgEn
state.msgDe = msgDe state.msgDe = msgDe
state.modernPSKRuntime = nil state.modernPSKRuntime = nil
@ -206,6 +217,7 @@ func (c *ClientConn) clientConnMsgEnSnapshot() func([]byte, []byte) []byte {
func (c *ClientConn) setClientConnMsgEn(fn func([]byte, []byte) []byte) { func (c *ClientConn) setClientConnMsgEn(fn func([]byte, []byte) []byte) {
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.protectionMode = ProtectionManaged
state.msgEn = fn state.msgEn = fn
state.fastStreamEncode = nil state.fastStreamEncode = nil
state.fastBulkEncode = nil state.fastBulkEncode = nil
@ -223,6 +235,7 @@ func (c *ClientConn) clientConnMsgDeSnapshot() func([]byte, []byte) []byte {
func (c *ClientConn) setClientConnMsgDe(fn func([]byte, []byte) []byte) { func (c *ClientConn) setClientConnMsgDe(fn func([]byte, []byte) []byte) {
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.protectionMode = ProtectionManaged
state.msgDe = fn state.msgDe = fn
state.fastStreamEncode = nil state.fastStreamEncode = nil
state.fastBulkEncode = nil state.fastBulkEncode = nil
@ -319,6 +332,39 @@ func (c *LogicalConn) modernPSKRuntimeSnapshot() *modernPSKCodecRuntime {
return nil return nil
} }
func (c *LogicalConn) protectionModeSnapshot() ProtectionMode {
if state := c.attachmentStateRaw(); state != nil {
return state.protectionMode
}
return ProtectionManaged
}
func (c *LogicalConn) authModeSnapshot() AuthMode {
if state := c.attachmentStateRaw(); state != nil {
return state.authMode
}
return AuthNone
}
func (c *LogicalConn) peerAttachAuthenticatedSnapshot() (bool, bool, time.Time) {
if state := c.attachmentStateRaw(); state != nil {
if state.peerAttachAt == 0 {
return state.peerAttached, state.peerAttachFallback, time.Time{}
}
return state.peerAttached, state.peerAttachFallback, time.Unix(0, state.peerAttachAt)
}
return false, false, time.Time{}
}
func (c *LogicalConn) markPeerAttachAuthenticated(authMode AuthMode, fallback bool, at time.Time) {
c.updateAttachmentState(func(state *clientConnAttachmentState) {
state.authMode = authMode
state.peerAttached = true
state.peerAttachFallback = fallback
state.peerAttachAt = at.UnixNano()
})
}
func (c *LogicalConn) setModernPSKRuntime(runtime *modernPSKCodecRuntime) { func (c *LogicalConn) setModernPSKRuntime(runtime *modernPSKCodecRuntime) {
c.updateAttachmentState(func(state *clientConnAttachmentState) { c.updateAttachmentState(func(state *clientConnAttachmentState) {
state.modernPSKRuntime = runtime state.modernPSKRuntime = runtime

View File

@ -22,10 +22,14 @@ type clientConnectSource struct {
network string network string
addr string addr string
dialFn func(context.Context) (net.Conn, error) dialFn func(context.Context) (net.Conn, error)
supportsAdditional bool
} }
func newClientConnConnectSource(conn net.Conn) *clientConnectSource { func newClientConnConnectSource(conn net.Conn) *clientConnectSource {
source := &clientConnectSource{kind: clientConnectSourceConn} source := &clientConnectSource{
kind: clientConnectSourceConn,
supportsAdditional: false,
}
if conn == nil { if conn == nil {
return source return source
} }
@ -46,6 +50,7 @@ func newClientNetworkConnectSource(network string, addr string) *clientConnectSo
kind: clientConnectSourceNetwork, kind: clientConnectSourceNetwork,
network: network, network: network,
addr: addr, addr: addr,
supportsAdditional: true,
dialFn: func(context.Context) (net.Conn, error) { dialFn: func(context.Context) (net.Conn, error) {
return transport.Dial(network, addr) return transport.Dial(network, addr)
}, },
@ -57,6 +62,7 @@ func newClientTimeoutConnectSource(network string, addr string, timeout time.Dur
kind: clientConnectSourceTimeout, kind: clientConnectSourceTimeout,
network: network, network: network,
addr: addr, addr: addr,
supportsAdditional: true,
dialFn: func(context.Context) (net.Conn, error) { dialFn: func(context.Context) (net.Conn, error) {
return transport.DialTimeout(network, addr, timeout) return transport.DialTimeout(network, addr, timeout)
}, },
@ -67,6 +73,7 @@ func newClientFactoryConnectSource(dialFn func(context.Context) (net.Conn, error
return &clientConnectSource{ return &clientConnectSource{
kind: clientConnectSourceFactory, kind: clientConnectSourceFactory,
dialFn: dialFn, dialFn: dialFn,
supportsAdditional: true,
} }
} }
@ -82,6 +89,10 @@ func (s *clientConnectSource) canReconnect() bool {
return s != nil && s.dialFn != nil return s != nil && s.dialFn != nil
} }
func (s *clientConnectSource) supportsAdditionalConn() bool {
return s != nil && s.supportsAdditional
}
func (s *clientConnectSource) isUDP() bool { func (s *clientConnectSource) isUDP() bool {
if s == nil { if s == nil {
return false return false

View File

@ -31,7 +31,11 @@ func (c *ClientCommon) ExchangeKey(newKey []byte) error {
if string(data.Value) != "success" { if string(data.Value) != "success" {
return errors.New("cannot exchange new aes-key") return errors.New("cannot exchange new aes-key")
} }
c.SecretKey = newKey profile := c.clientTransportProtectionSnapshot()
profile.mode = ProtectionManaged
profile.secretKey = cloneTransportProtectionKey(newKey)
profile.runtime = nil
c.setClientTransportProtectionProfile(profile)
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
return nil return nil
} }

View File

@ -350,6 +350,8 @@ func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime,
if rt == nil { if rt == nil {
return nil return nil
} }
c.resetClientPeerAttachAuth()
c.activateClientBootstrapTransportProtection()
if runKeyExchange && !c.skipKeyExchange { if runKeyExchange && !c.skipKeyExchange {
if err := c.keyExchangeFn(c); err != nil { if err := c.keyExchangeFn(c); err != nil {
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "key exchange failed", err) return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "key exchange failed", err)
@ -358,6 +360,7 @@ func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime,
if err := c.announceClientPeerIdentity(); err != nil { if err := c.announceClientPeerIdentity(); err != nil {
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "peer attach failed", err) return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "peer attach failed", err)
} }
c.activateClientSteadyTransportProtection()
return nil return nil
} }

View File

@ -174,28 +174,37 @@ func (c *ClientCommon) dispatchTransportPayloadFast(payload []byte, release func
} }
return return
} }
if c.tryDispatchBorrowedBulkTransportPayload(payload) { plain, plainRelease, err := c.decryptTransportPayloadPooled(payload, release)
if release != nil { if err != nil {
release() if c.showError || c.debugMode {
fmt.Println("client decode transport payload error", err)
} }
return return
} }
owned := append([]byte(nil), payload...) if c.tryDispatchBorrowedTransportPlain(plain, plainRelease) {
if release != nil { return
release()
} }
if dispatcher == nil { if dispatcher == nil {
now := time.Now() now := time.Now()
if err := c.dispatchInboundTransportPayload(owned, now); err != nil && (c.showError || c.debugMode) { err := c.dispatchInboundTransportPlain(plain, now)
if plainRelease != nil {
plainRelease()
}
if err != nil && (c.showError || c.debugMode) {
fmt.Println("client decode envelope error", err) fmt.Println("client decode envelope error", err)
} }
return return
} }
owned := plain
if plainRelease != nil {
owned = append([]byte(nil), plain...)
plainRelease()
}
c.wg.Add(1) c.wg.Add(1)
if !dispatcher.Dispatch(clientInboundDispatchSource(), func() { if !dispatcher.Dispatch(clientInboundDispatchSource(), func() {
defer c.wg.Done() defer c.wg.Done()
now := time.Now() now := time.Now()
if err := c.dispatchInboundTransportPayload(owned, now); err != nil && (c.showError || c.debugMode) { if err := c.dispatchInboundTransportPlain(owned, now); err != nil && (c.showError || c.debugMode) {
fmt.Println("client decode envelope error", err) fmt.Println("client decode envelope error", err)
} }
}) { }) {

View File

@ -26,6 +26,8 @@ type Client interface {
BulkOpenTuning() BulkOpenTuning BulkOpenTuning() BulkOpenTuning
SetBulkDedicatedAttachConfig(BulkDedicatedAttachConfig) SetBulkDedicatedAttachConfig(BulkDedicatedAttachConfig)
BulkDedicatedAttachConfig() BulkDedicatedAttachConfig BulkDedicatedAttachConfig() BulkDedicatedAttachConfig
SetPeerAttachSecurityConfig(PeerAttachSecurityConfig) error
PeerAttachSecurityConfig() PeerAttachSecurityConfig
SetFileReceiveDir(dir string) error SetFileReceiveDir(dir string) error
send(msg TransferMsg) (WaitMsg, error) send(msg TransferMsg) (WaitMsg, error)
sendEnvelope(env Envelope) error sendEnvelope(env Envelope) error

View File

@ -404,6 +404,7 @@ func (c *LogicalConn) applyAttachmentProfile(maxReadTimeout time.Duration, maxWr
c.updateAttachmentState(func(state *clientConnAttachmentState) { c.updateAttachmentState(func(state *clientConnAttachmentState) {
state.maxReadTimeout = maxReadTimeout state.maxReadTimeout = maxReadTimeout
state.maxWriteTimeout = maxWriteTimeout state.maxWriteTimeout = maxWriteTimeout
state.protectionMode = ProtectionManaged
state.msgEn = msgEn state.msgEn = msgEn
state.msgDe = msgDe state.msgDe = msgDe
state.fastStreamEncode = fastStreamEncode state.fastStreamEncode = fastStreamEncode
@ -525,6 +526,7 @@ func (c *LogicalConn) runtimeSnapshot() ClientConnRuntimeSnapshot {
return ClientConnRuntimeSnapshot{} return ClientConnRuntimeSnapshot{}
} }
status := c.Status() status := c.Status()
authenticated, fallback, attachAt := c.peerAttachAuthenticatedSnapshot()
now := time.Now() now := time.Now()
snapshot := ClientConnRuntimeSnapshot{ snapshot := ClientConnRuntimeSnapshot{
ClientID: c.clientIDSnapshot(), ClientID: c.clientIDSnapshot(),
@ -536,6 +538,11 @@ func (c *LogicalConn) runtimeSnapshot() ClientConnRuntimeSnapshot {
TransportAttachCount: c.clientConnTransportAttachCountSnapshot(), TransportAttachCount: c.clientConnTransportAttachCountSnapshot(),
TransportDetachCount: c.clientConnTransportDetachCountSnapshot(), TransportDetachCount: c.clientConnTransportDetachCountSnapshot(),
LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(), LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(),
AuthMode: authModeName(c.authModeSnapshot()),
ProtectionMode: protectionModeName(c.protectionModeSnapshot()),
PeerAttachAuthenticated: authenticated,
PeerAttachAuthFallback: fallback,
LastPeerAttachAt: attachAt,
} }
if status.Err != nil { if status.Err != nil {
snapshot.Error = status.Err.Error() snapshot.Error = status.Err.Error()

View File

@ -43,6 +43,20 @@ func TestHydrateServerMessagePeerFieldsFromLogicalConn(t *testing.T) {
} }
} }
func TestHydrateServerMessagePeerFieldsZeroValueDoesNotPanic(t *testing.T) {
message := hydrateServerMessagePeerFields(Message{})
if message.LogicalConn != nil {
t.Fatal("zero-value message should not hydrate logical conn")
}
if message.ClientConn != nil {
t.Fatal("zero-value message should not hydrate client conn")
}
if message.TransportConn != nil {
t.Fatal("zero-value message should not hydrate transport conn")
}
}
func TestServerPublishSendFileEventHydratesLogicalConnAlias(t *testing.T) { func TestServerPublishSendFileEventHydratesLogicalConnAlias(t *testing.T) {
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
left, right := net.Pipe() left, right := net.Pipe()

33
msg.go
View File

@ -40,6 +40,7 @@ type Message struct {
ClientConn *ClientConn ClientConn *ClientConn
TransportConn *TransportConn TransportConn *TransportConn
ServerConn Client ServerConn Client
inboundTransportProfile *transportProtectionProfile
TransferMsg TransferMsg
Time time.Time Time time.Time
inboundConn net.Conn inboundConn net.Conn
@ -58,7 +59,7 @@ type messageLogicalTransferSender interface {
} }
type messageInboundTransferSender interface { type messageInboundTransferSender interface {
sendTransferInbound(*LogicalConn, *TransportConn, net.Conn, TransferMsg) error sendTransferInbound(*LogicalConn, *TransportConn, net.Conn, *transportProtectionProfile, TransferMsg) error
} }
func (m *Message) Reply(value MsgVal) (err error) { func (m *Message) Reply(value MsgVal) (err error) {
@ -86,7 +87,7 @@ func (m *Message) Reply(value MsgVal) (err error) {
if sender == nil { if sender == nil {
return transportDetachedErrorForPeer(logical, transport) return transportDetachedErrorForPeer(logical, transport)
} }
return sender.sendTransferInbound(logical, transport, m.inboundConn, reply) return sender.sendTransferInbound(logical, transport, m.inboundConn, messageInboundTransportProtectionSnapshot(m), reply)
} }
if transport != nil { if transport != nil {
_, err = transport.sendTransfer(reply) _, err = transport.sendTransfer(reply)
@ -123,12 +124,19 @@ func hydrateServerMessagePeerFields(message Message) Message {
if message.LogicalConn == nil { if message.LogicalConn == nil {
message.LogicalConn = logicalConnFromClient(message.ClientConn) message.LogicalConn = logicalConnFromClient(message.ClientConn)
} }
if message.ClientConn == nil { if message.LogicalConn == nil && message.TransportConn != nil {
message.LogicalConn = message.TransportConn.logicalConnSnapshot()
}
if message.ClientConn == nil && message.LogicalConn != nil {
message.ClientConn = message.LogicalConn.compatClientConn() message.ClientConn = message.LogicalConn.compatClientConn()
} }
if message.TransportConn == nil && message.LogicalConn != nil { if message.TransportConn == nil && message.LogicalConn != nil {
message.TransportConn = message.LogicalConn.CurrentTransportConn() message.TransportConn = message.LogicalConn.CurrentTransportConn()
} }
if message.inboundConn != nil && message.inboundTransportProfile == nil && message.LogicalConn != nil {
profile := message.LogicalConn.transportProtectionProfileSnapshot()
message.inboundTransportProfile = &profile
}
return message return message
} }
@ -155,3 +163,22 @@ func messageTransportConnSnapshot(message *Message) *TransportConn {
} }
return logical.CurrentTransportConn() return logical.CurrentTransportConn()
} }
func messageInboundTransportProtectionSnapshot(message *Message) *transportProtectionProfile {
if message == nil {
return nil
}
if message.inboundTransportProfile != nil {
return message.inboundTransportProfile
}
if message.inboundConn == nil {
return nil
}
logical := messageLogicalConnSnapshot(message)
if logical == nil {
return nil
}
profile := logical.transportProtectionProfileSnapshot()
message.inboundTransportProfile = &profile
return message.inboundTransportProfile
}

517
peer_attach_auth.go Normal file
View File

@ -0,0 +1,517 @@
package notify
import (
"bytes"
"crypto/hmac"
cryptorand "crypto/rand"
"crypto/sha256"
"encoding/binary"
"errors"
"net"
"sync"
"sync/atomic"
"time"
)
const (
peerAttachFeatureExplicitAuth uint64 = 1 << iota
peerAttachFeatureChannelBinding
peerAttachFeatureForwardSecrecy
)
const (
peerAttachNonceSize = 16
peerAttachReplayTTL = 30 * time.Second
)
var (
errPeerAttachAuthInvalid = errors.New("peer attach auth invalid")
errPeerAttachReplayRejected = errors.New("peer attach replay rejected")
errPeerAttachReplayWindowFull = errors.New("peer attach replay window full")
errPeerAttachExplicitAuthRequired = errors.New("peer attach explicit auth required")
errPeerAttachChannelBindingRequired = errors.New("peer attach channel binding required")
errPeerAttachChannelBindingUnavailable = errors.New("peer attach channel binding unavailable")
errPeerAttachForwardSecrecyRequired = errors.New("peer attach forward secrecy required")
)
type peerAttachAuthResult struct {
explicit bool
fallback bool
clientNonce []byte
serverNonce []byte
channelBinding []byte
clientECDHEPublicKey []byte
}
type peerAttachReplayCache struct {
mu sync.Mutex
entries map[string]time.Time
rejects atomic.Int64
overflowRejects atomic.Int64
}
func newPeerAttachNonce() ([]byte, error) {
buf := make([]byte, peerAttachNonceSize)
if _, err := cryptorand.Read(buf); err != nil {
return nil, err
}
return buf, nil
}
func appendPeerAttachAuthBytes(dst []byte, data []byte) []byte {
dst = binary.BigEndian.AppendUint32(dst, uint32(len(data)))
return append(dst, data...)
}
func appendPeerAttachAuthString(dst []byte, value string) []byte {
return appendPeerAttachAuthBytes(dst, []byte(value))
}
func appendPeerAttachAuthBool(dst []byte, value bool) []byte {
if value {
return append(dst, 1)
}
return append(dst, 0)
}
func peerAttachRequestAuthPayload(req peerAttachRequest, channelBinding []byte) []byte {
buf := make([]byte, 0, 96+len(req.PeerID)+len(channelBinding))
buf = appendPeerAttachAuthString(buf, "notify/peer-attach/request-auth/v1")
buf = binary.BigEndian.AppendUint64(buf, req.Features)
buf = appendPeerAttachAuthString(buf, req.PeerID)
buf = appendPeerAttachAuthBytes(buf, req.ClientNonce)
if supportsPeerAttachChannelBinding(req.Features) {
buf = appendPeerAttachAuthBytes(buf, channelBinding)
}
return buf
}
func peerAttachResponseAuthPayload(req peerAttachRequest, resp peerAttachResponse, channelBinding []byte) []byte {
buf := make([]byte, 0, 160+len(req.PeerID)+len(resp.PeerID)+len(resp.Error)+len(channelBinding))
buf = appendPeerAttachAuthString(buf, "notify/peer-attach/response-auth/v1")
buf = binary.BigEndian.AppendUint64(buf, req.Features)
buf = appendPeerAttachAuthString(buf, req.PeerID)
buf = appendPeerAttachAuthBytes(buf, req.ClientNonce)
buf = binary.BigEndian.AppendUint64(buf, resp.Features)
buf = appendPeerAttachAuthString(buf, resp.PeerID)
buf = appendPeerAttachAuthBool(buf, resp.Accepted)
buf = appendPeerAttachAuthBool(buf, resp.Reused)
buf = appendPeerAttachAuthString(buf, resp.Error)
buf = appendPeerAttachAuthBytes(buf, resp.ServerNonce)
if supportsPeerAttachChannelBinding(resp.Features) {
buf = appendPeerAttachAuthBytes(buf, channelBinding)
}
return buf
}
func signPeerAttachPayload(secretKey []byte, payload []byte) []byte {
if len(secretKey) == 0 {
return nil
}
mac := hmac.New(sha256.New, secretKey)
_, _ = mac.Write(payload)
return mac.Sum(nil)
}
func computePeerAttachRequestAuthTag(secretKey []byte, req peerAttachRequest, channelBinding []byte) []byte {
return signPeerAttachPayload(secretKey, peerAttachRequestAuthPayload(req, channelBinding))
}
func computePeerAttachResponseAuthTag(secretKey []byte, req peerAttachRequest, resp peerAttachResponse, channelBinding []byte) []byte {
return signPeerAttachPayload(secretKey, peerAttachResponseAuthPayload(req, resp, channelBinding))
}
func supportsExplicitPeerAttachAuth(features uint64) bool {
return features&peerAttachFeatureExplicitAuth != 0
}
func supportsPeerAttachChannelBinding(features uint64) bool {
return features&peerAttachFeatureChannelBinding != 0
}
func supportsPeerAttachForwardSecrecy(features uint64) bool {
return features&peerAttachFeatureForwardSecrecy != 0
}
func classifyPeerAttachRejectCounter(s *ServerCommon, err error) {
if s == nil || err == nil {
return
}
switch {
case errors.Is(err, errPeerAttachReplayRejected):
s.peerAttachReplay.rejects.Add(1)
case errors.Is(err, errPeerAttachReplayWindowFull):
s.peerAttachReplay.overflowRejects.Add(1)
case errors.Is(err, errPeerAttachExplicitAuthRequired), errors.Is(err, errPeerAttachChannelBindingRequired), errors.Is(err, errPeerAttachForwardSecrecyRequired):
s.peerAttachDowngradeRejectCount.Add(1)
case errors.Is(err, errPeerAttachChannelBindingUnavailable):
s.peerAttachBindingRejectCount.Add(1)
default:
s.peerAttachAuthRejectCount.Add(1)
}
}
func (c *ClientCommon) shouldUseExplicitPeerAttachAuth() bool {
if c == nil || !c.securityConfigured {
return false
}
return c.securityAuthMode == AuthPSK && len(c.securityBootstrap.secretKey) != 0
}
func resolvePeerAttachChannelBinding(provider PeerAttachChannelBindingProvider, role PeerAttachChannelBindingRole, peerID string, conn net.Conn) ([]byte, error) {
if provider == nil {
return nil, nil
}
if conn == nil {
return nil, errPeerAttachChannelBindingUnavailable
}
binding, err := provider(PeerAttachChannelBindingContext{
Role: role,
PeerID: peerID,
Conn: conn,
})
if err != nil || len(binding) == 0 {
return nil, errPeerAttachChannelBindingUnavailable
}
return bytes.Clone(binding), nil
}
func (c *ClientCommon) buildPeerAttachRequest(peerID string) (peerAttachRequest, peerAttachRequestState, error) {
cfg := c.peerAttachSecuritySnapshot()
req := peerAttachRequest{
PeerID: stringsTrimSpaceNoAlloc(peerID),
}
if !c.shouldUseExplicitPeerAttachAuth() {
if c.clientRequiresForwardSecrecy() {
return peerAttachRequest{}, peerAttachRequestState{}, errPeerAttachForwardSecrecyRequired
}
if cfg.requireExplicitAuth {
return peerAttachRequest{}, peerAttachRequestState{}, errPeerAttachExplicitAuthRequired
}
return req, peerAttachRequestState{}, nil
}
nonce, err := newPeerAttachNonce()
if err != nil {
return peerAttachRequest{}, peerAttachRequestState{}, err
}
req.Features = peerAttachFeatureExplicitAuth
requestState := peerAttachRequestState{}
var channelBinding []byte
if cfg.channelBinding != nil {
channelBinding, err = resolvePeerAttachChannelBinding(cfg.channelBinding, PeerAttachChannelBindingRoleClient, req.PeerID, c.clientTransportConnSnapshot())
if err != nil {
return peerAttachRequest{}, peerAttachRequestState{}, err
}
}
if len(channelBinding) != 0 {
req.Features |= peerAttachFeatureChannelBinding
}
if cfg.requireChannelBinding && !supportsPeerAttachChannelBinding(req.Features) {
return peerAttachRequest{}, peerAttachRequestState{}, errPeerAttachChannelBindingRequired
}
if c.clientSupportsForwardSecrecy() {
requestState.forwardSecrecy, err = newPeerAttachForwardSecrecyClientState()
if err != nil {
return peerAttachRequest{}, peerAttachRequestState{}, err
}
req.Features |= peerAttachFeatureForwardSecrecy
req.ClientECDHEPublicKey = bytes.Clone(requestState.forwardSecrecy.publicKey)
}
req.ClientNonce = nonce
req.AuthTag = computePeerAttachRequestAuthTag(c.securityBootstrap.secretKey, req, channelBinding)
return req, requestState, nil
}
func (c *ClientCommon) verifyPeerAttachResponse(req peerAttachRequest, resp peerAttachResponse, requestState peerAttachRequestState) (peerAttachResponseVerifyResult, error) {
cfg := c.peerAttachSecuritySnapshot()
baseSteady := transportProtectionProfile{}
if c != nil {
baseSteady = c.securitySteady.clone().withForwardSecrecyFallback(false)
}
result := peerAttachResponseVerifyResult{steadyProfile: baseSteady}
if c == nil || !c.shouldUseExplicitPeerAttachAuth() {
if c.clientRequiresForwardSecrecy() {
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyRequired
}
if cfg.requireExplicitAuth {
return peerAttachResponseVerifyResult{}, errPeerAttachExplicitAuthRequired
}
return result, nil
}
if !supportsExplicitPeerAttachAuth(resp.Features) {
if c.clientRequiresForwardSecrecy() {
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyRequired
}
if cfg.requireExplicitAuth {
return peerAttachResponseVerifyResult{}, errPeerAttachExplicitAuthRequired
}
if c.clientSupportsForwardSecrecy() {
result.steadyProfile = result.steadyProfile.withForwardSecrecyFallback(true)
}
result.authFallback = true
return result, nil
}
var channelBinding []byte
if supportsPeerAttachChannelBinding(req.Features) {
if cfg.channelBinding == nil {
return peerAttachResponseVerifyResult{}, errPeerAttachChannelBindingUnavailable
}
if !supportsPeerAttachChannelBinding(resp.Features) {
return peerAttachResponseVerifyResult{}, errPeerAttachChannelBindingRequired
}
var err error
channelBinding, err = resolvePeerAttachChannelBinding(cfg.channelBinding, PeerAttachChannelBindingRoleClient, req.PeerID, c.clientTransportConnSnapshot())
if err != nil {
return peerAttachResponseVerifyResult{}, err
}
} else if cfg.requireChannelBinding {
return peerAttachResponseVerifyResult{}, errPeerAttachChannelBindingRequired
}
if len(resp.ServerNonce) != peerAttachNonceSize {
return peerAttachResponseVerifyResult{}, errPeerAttachAuthInvalid
}
expected := computePeerAttachResponseAuthTag(c.securityBootstrap.secretKey, req, peerAttachResponse{
PeerID: resp.PeerID,
Accepted: resp.Accepted,
Reused: resp.Reused,
Error: resp.Error,
Features: resp.Features,
ServerNonce: resp.ServerNonce,
}, channelBinding)
if !hmac.Equal(resp.AuthTag, expected) {
return peerAttachResponseVerifyResult{}, errPeerAttachAuthInvalid
}
if requestState.forwardSecrecy == nil || !supportsPeerAttachForwardSecrecy(req.Features) {
return result, nil
}
if !supportsPeerAttachForwardSecrecy(resp.Features) {
if c.clientRequiresForwardSecrecy() {
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyRequired
}
result.steadyProfile = result.steadyProfile.withForwardSecrecyFallback(true)
return result, nil
}
if resp.KeyMode != "" && resp.KeyMode != peerAttachKeyModeECDHE {
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyInvalid
}
profile, err := derivePeerAttachForwardSecrecyTransportProfile(c.securitySteady, c.securityBootstrap.secretKey, requestState.forwardSecrecy.privateKey, resp.ServerECDHEPublicKey, req, resp)
if err != nil {
return peerAttachResponseVerifyResult{}, err
}
result.steadyProfile = profile
return result, nil
}
func (s *ServerCommon) validatePeerAttachRequestAuth(logical *LogicalConn, transport net.Conn, req peerAttachRequest) (peerAttachAuthResult, error) {
cfg := s.peerAttachSecuritySnapshot()
if !supportsExplicitPeerAttachAuth(req.Features) {
if s.serverRequiresForwardSecrecy() {
return peerAttachAuthResult{}, errPeerAttachForwardSecrecyRequired
}
if cfg.requireExplicitAuth {
return peerAttachAuthResult{}, errPeerAttachExplicitAuthRequired
}
return peerAttachAuthResult{fallback: true}, nil
}
if logical == nil {
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
}
var channelBinding []byte
if supportsPeerAttachChannelBinding(req.Features) {
if cfg.channelBinding == nil {
return peerAttachAuthResult{}, errPeerAttachChannelBindingUnavailable
}
var err error
channelBinding, err = resolvePeerAttachChannelBinding(cfg.channelBinding, PeerAttachChannelBindingRoleServer, req.PeerID, transport)
if err != nil {
return peerAttachAuthResult{}, err
}
} else if cfg.requireChannelBinding {
return peerAttachAuthResult{}, errPeerAttachChannelBindingRequired
}
secretKey := logical.secretKeySnapshot()
if len(secretKey) == 0 {
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
}
if len(req.ClientNonce) != peerAttachNonceSize {
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
}
expected := computePeerAttachRequestAuthTag(secretKey, peerAttachRequest{
PeerID: req.PeerID,
Features: req.Features,
ClientNonce: req.ClientNonce,
}, channelBinding)
if !hmac.Equal(req.AuthTag, expected) {
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
}
if supportsPeerAttachForwardSecrecy(req.Features) {
if len(req.ClientECDHEPublicKey) != peerAttachECDHEPublicKeySize {
return peerAttachAuthResult{}, errPeerAttachForwardSecrecyInvalid
}
}
if err := s.acceptPeerAttachReplay(req.PeerID, req.ClientNonce, time.Now(), cfg.replayWindow, cfg.replayCapacity); err != nil {
return peerAttachAuthResult{}, err
}
serverNonce, err := newPeerAttachNonce()
if err != nil {
return peerAttachAuthResult{}, err
}
return peerAttachAuthResult{
explicit: true,
clientNonce: bytes.Clone(req.ClientNonce),
serverNonce: serverNonce,
channelBinding: channelBinding,
clientECDHEPublicKey: bytes.Clone(req.ClientECDHEPublicKey),
}, nil
}
func (s *ServerCommon) signPeerAttachResponse(logical *LogicalConn, req peerAttachRequest, resp *peerAttachResponse, auth peerAttachAuthResult) {
if s == nil || logical == nil || resp == nil || !auth.explicit {
return
}
secretKey := logical.secretKeySnapshot()
if len(secretKey) == 0 {
return
}
resp.Features |= peerAttachFeatureExplicitAuth
if len(auth.channelBinding) != 0 {
resp.Features |= peerAttachFeatureChannelBinding
}
resp.ServerNonce = bytes.Clone(auth.serverNonce)
resp.AuthTag = computePeerAttachResponseAuthTag(secretKey, req, *resp, auth.channelBinding)
}
func (s *ServerCommon) preparePeerAttachSteadyTransportProfile(logical *LogicalConn, req peerAttachRequest, resp *peerAttachResponse, auth peerAttachAuthResult) (transportProtectionProfile, error) {
if s == nil {
return transportProtectionProfile{}, nil
}
profile := s.securitySteady.clone().withForwardSecrecyFallback(false)
if resp != nil && auth.explicit {
resp.Features |= peerAttachFeatureExplicitAuth
if len(auth.channelBinding) != 0 {
resp.Features |= peerAttachFeatureChannelBinding
}
resp.ServerNonce = bytes.Clone(auth.serverNonce)
}
if resp != nil && resp.KeyMode == "" {
resp.KeyMode = profile.keyMode
}
if !s.serverSupportsForwardSecrecy() {
return profile, nil
}
if !auth.explicit || !supportsPeerAttachForwardSecrecy(req.Features) {
if s.serverRequiresForwardSecrecy() {
return transportProtectionProfile{}, errPeerAttachForwardSecrecyRequired
}
return profile.withForwardSecrecyFallback(true), nil
}
fsState, err := newPeerAttachForwardSecrecyClientState()
if err != nil {
return transportProtectionProfile{}, err
}
if resp != nil {
resp.Features |= peerAttachFeatureForwardSecrecy
resp.KeyMode = peerAttachKeyModeECDHE
resp.ServerECDHEPublicKey = bytes.Clone(fsState.publicKey)
}
return derivePeerAttachForwardSecrecyTransportProfile(s.securitySteady, logical.secretKeySnapshot(), fsState.privateKey, auth.clientECDHEPublicKey, req, *resp)
}
func (s *ServerCommon) acceptPeerAttachReplay(peerID string, nonce []byte, now time.Time, window time.Duration, capacity int) error {
if s == nil || len(nonce) == 0 {
return nil
}
cache := &s.peerAttachReplay
key := peerID + "\x00" + string(nonce)
expireBefore := now.Add(-window)
cache.mu.Lock()
defer cache.mu.Unlock()
if cache.entries == nil {
cache.entries = make(map[string]time.Time)
}
for replayKey, seenAt := range cache.entries {
if seenAt.Before(expireBefore) {
delete(cache.entries, replayKey)
}
}
if seenAt, ok := cache.entries[key]; ok && !seenAt.Before(expireBefore) {
return errPeerAttachReplayRejected
}
if capacity > 0 && len(cache.entries) >= capacity {
return errPeerAttachReplayWindowFull
}
cache.entries[key] = now
return nil
}
func (s *ServerCommon) peerAttachReplayRejectCountSnapshot() int64 {
if s == nil {
return 0
}
return s.peerAttachReplay.rejects.Load()
}
func (s *ServerCommon) peerAttachReplayOverflowRejectCountSnapshot() int64 {
if s == nil {
return 0
}
return s.peerAttachReplay.overflowRejects.Load()
}
func (c *ClientCommon) markClientPeerAttachAuthenticated(fallback bool, at time.Time) {
if c == nil {
return
}
c.mu.Lock()
defer c.mu.Unlock()
c.peerAttachAuthenticated = true
c.peerAttachAuthFallback = fallback
c.peerAttachAt = at.UnixNano()
}
func (c *ClientCommon) resetClientPeerAttachAuth() {
if c == nil {
return
}
c.mu.Lock()
defer c.mu.Unlock()
c.peerAttachAuthenticated = false
c.peerAttachAuthFallback = false
c.peerAttachAt = 0
}
func (c *ClientCommon) clientPeerAttachAuthSnapshot() (bool, bool, time.Time) {
if c == nil {
return false, false, time.Time{}
}
c.mu.Lock()
defer c.mu.Unlock()
if c.peerAttachAt == 0 {
return c.peerAttachAuthenticated, c.peerAttachAuthFallback, time.Time{}
}
return c.peerAttachAuthenticated, c.peerAttachAuthFallback, time.Unix(0, c.peerAttachAt)
}
func stringsTrimSpaceNoAlloc(value string) string {
start := 0
for start < len(value) {
switch value[start] {
case ' ', '\t', '\n', '\r':
start++
default:
goto endStart
}
}
return ""
endStart:
end := len(value)
for end > start {
switch value[end-1] {
case ' ', '\t', '\n', '\r':
end--
default:
return value[start:end]
}
}
return value[start:end]
}

237
peer_attach_auth_test.go Normal file
View File

@ -0,0 +1,237 @@
package notify
import (
"bytes"
"errors"
"testing"
)
func newPeerAttachAuthLogicalForTest(t *testing.T, server *ServerCommon) *LogicalConn {
t.Helper()
logical := newServerLogicalConn(server, "accepted-auth", nil)
logical = server.registerAcceptedLogical(logical)
if logical == nil {
t.Fatal("registerAcceptedLogical returned nil")
}
return logical
}
func TestPeerAttachExplicitAuthHelpersRoundTrip(t *testing.T) {
secret := []byte("correct horse battery staple")
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
logical := newPeerAttachAuthLogicalForTest(t, server)
req, reqState, err := client.buildPeerAttachRequest("peer-explicit")
if err != nil {
t.Fatalf("buildPeerAttachRequest failed: %v", err)
}
if !supportsExplicitPeerAttachAuth(req.Features) {
t.Fatalf("request features = %d, want explicit auth bit", req.Features)
}
if !supportsPeerAttachForwardSecrecy(req.Features) {
t.Fatalf("request features = %d, want forward secrecy bit", req.Features)
}
if len(req.ClientNonce) != peerAttachNonceSize {
t.Fatalf("client nonce length = %d, want %d", len(req.ClientNonce), peerAttachNonceSize)
}
auth, err := server.validatePeerAttachRequestAuth(logical, nil, req)
if err != nil {
t.Fatalf("validatePeerAttachRequestAuth failed: %v", err)
}
if !auth.explicit || auth.fallback {
t.Fatalf("auth result mismatch: %+v", auth)
}
resp := peerAttachResponse{
PeerID: req.PeerID,
Accepted: true,
}
server.signPeerAttachResponse(logical, req, &resp, auth)
if !supportsExplicitPeerAttachAuth(resp.Features) {
t.Fatalf("response features = %d, want explicit auth bit", resp.Features)
}
if len(resp.ServerNonce) != peerAttachNonceSize {
t.Fatalf("server nonce length = %d, want %d", len(resp.ServerNonce), peerAttachNonceSize)
}
verifyResult, err := client.verifyPeerAttachResponse(req, resp, reqState)
if err != nil {
t.Fatalf("verifyPeerAttachResponse failed: %v", err)
}
if verifyResult.authFallback {
t.Fatal("explicit response should not be marked as fallback")
}
if !verifyResult.steadyProfile.forwardSecrecyFallback {
t.Fatal("response without fs extension should mark forward secrecy fallback")
}
resp.AuthTag[0] ^= 0xff
if verifyResult, err = client.verifyPeerAttachResponse(req, resp, reqState); !errors.Is(err, errPeerAttachAuthInvalid) {
t.Fatalf("tampered response error = %v, want %v", err, errPeerAttachAuthInvalid)
} else if verifyResult.authFallback {
t.Fatal("tampered explicit response should not be treated as fallback")
}
}
func TestPeerAttachRequestAuthRejectsReplay(t *testing.T) {
secret := []byte("correct horse battery staple")
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
logical := newPeerAttachAuthLogicalForTest(t, server)
req, _, err := client.buildPeerAttachRequest("peer-replay")
if err != nil {
t.Fatalf("buildPeerAttachRequest failed: %v", err)
}
if _, err := server.validatePeerAttachRequestAuth(logical, nil, req); err != nil {
t.Fatalf("first validatePeerAttachRequestAuth failed: %v", err)
}
if _, err := server.validatePeerAttachRequestAuth(logical, nil, req); !errors.Is(err, errPeerAttachReplayRejected) {
t.Fatalf("second validatePeerAttachRequestAuth error = %v, want %v", err, errPeerAttachReplayRejected)
}
classifyPeerAttachRejectCounter(server, errPeerAttachReplayRejected)
if got, want := server.peerAttachReplayRejectCountSnapshot(), int64(1); got != want {
t.Fatalf("replay reject count = %d, want %d", got, want)
}
}
func TestPeerAttachAuthFallbackCompatibility(t *testing.T) {
secret := []byte("correct horse battery staple")
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
logical := newPeerAttachAuthLogicalForTest(t, server)
auth, err := server.validatePeerAttachRequestAuth(logical, nil, peerAttachRequest{PeerID: "peer-fallback"})
if err != nil {
t.Fatalf("validatePeerAttachRequestAuth fallback failed: %v", err)
}
if auth.explicit || !auth.fallback {
t.Fatalf("fallback auth result mismatch: %+v", auth)
}
req, reqState, err := client.buildPeerAttachRequest("peer-fallback")
if err != nil {
t.Fatalf("buildPeerAttachRequest failed: %v", err)
}
verifyResult, err := client.verifyPeerAttachResponse(req, peerAttachResponse{
PeerID: req.PeerID,
Accepted: true,
}, reqState)
if err != nil {
t.Fatalf("verifyPeerAttachResponse fallback failed: %v", err)
}
if !verifyResult.authFallback {
t.Fatal("unsigned legacy response should be marked as fallback")
}
}
func TestPeerAttachForwardSecrecyNegotiatesDerivedSteadyProfile(t *testing.T) {
secret := []byte("correct horse battery staple")
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
logical := newPeerAttachAuthLogicalForTest(t, server)
req, reqState, err := client.buildPeerAttachRequest("peer-fs")
if err != nil {
t.Fatalf("buildPeerAttachRequest failed: %v", err)
}
auth, err := server.validatePeerAttachRequestAuth(logical, nil, req)
if err != nil {
t.Fatalf("validatePeerAttachRequestAuth failed: %v", err)
}
resp := peerAttachResponse{
PeerID: req.PeerID,
Accepted: true,
}
serverProfile, err := server.preparePeerAttachSteadyTransportProfile(logical, req, &resp, auth)
if err != nil {
t.Fatalf("preparePeerAttachSteadyTransportProfile failed: %v", err)
}
server.signPeerAttachResponse(logical, req, &resp, auth)
verifyResult, err := client.verifyPeerAttachResponse(req, resp, reqState)
if err != nil {
t.Fatalf("verifyPeerAttachResponse failed: %v", err)
}
if !supportsPeerAttachForwardSecrecy(resp.Features) {
t.Fatalf("response features = %d, want forward secrecy bit", resp.Features)
}
if resp.KeyMode != peerAttachKeyModeECDHE {
t.Fatalf("response key mode = %q, want %q", resp.KeyMode, peerAttachKeyModeECDHE)
}
if !verifyResult.steadyProfile.forwardSecrecy {
t.Fatal("client steady profile should enable forward secrecy")
}
if !serverProfile.forwardSecrecy {
t.Fatal("server steady profile should enable forward secrecy")
}
if len(verifyResult.steadyProfile.sessionID) == 0 {
t.Fatal("client session id should be populated")
}
if !bytes.Equal(verifyResult.steadyProfile.secretKey, serverProfile.secretKey) {
t.Fatal("client/server derived steady keys should match")
}
if !bytes.Equal(verifyResult.steadyProfile.sessionID, serverProfile.sessionID) {
t.Fatal("client/server session ids should match")
}
if bytes.Equal(verifyResult.steadyProfile.secretKey, client.securityBootstrap.secretKey) {
t.Fatal("derived steady key should differ from bootstrap key")
}
}
func TestPeerAttachForwardSecrecyStrictRejectsFallback(t *testing.T) {
secret := []byte("correct horse battery staple")
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
opts := testModernPSKOptions()
opts.RequireForwardSecrecy = true
if err := UseModernPSKClient(client, secret, opts); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := UseModernPSKServer(server, secret, opts); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
req, reqState, err := client.buildPeerAttachRequest("peer-fs-strict")
if err != nil {
t.Fatalf("buildPeerAttachRequest failed: %v", err)
}
_, err = client.verifyPeerAttachResponse(req, peerAttachResponse{
PeerID: req.PeerID,
Accepted: true,
Features: peerAttachFeatureExplicitAuth,
ServerNonce: make([]byte, peerAttachNonceSize),
AuthTag: computePeerAttachResponseAuthTag(client.securityBootstrap.secretKey, req, peerAttachResponse{PeerID: req.PeerID, Accepted: true, Features: peerAttachFeatureExplicitAuth, ServerNonce: make([]byte, peerAttachNonceSize)}, nil),
}, reqState)
if !errors.Is(err, errPeerAttachForwardSecrecyRequired) {
t.Fatalf("verifyPeerAttachResponse error = %v, want %v", err, errPeerAttachForwardSecrecyRequired)
}
}

139
peer_attach_policy.go Normal file
View File

@ -0,0 +1,139 @@
package notify
import (
"errors"
"net"
"time"
)
const defaultPeerAttachReplayCapacity = 4096
type PeerAttachChannelBindingRole string
const (
PeerAttachChannelBindingRoleClient PeerAttachChannelBindingRole = "client"
PeerAttachChannelBindingRoleServer PeerAttachChannelBindingRole = "server"
)
type PeerAttachChannelBindingContext struct {
Role PeerAttachChannelBindingRole
PeerID string
Conn net.Conn
}
type PeerAttachChannelBindingProvider func(PeerAttachChannelBindingContext) ([]byte, error)
type PeerAttachSecurityConfig struct {
RequireExplicitAuth bool
RequireChannelBinding bool
ReplayWindow time.Duration
ReplayCapacity int
ChannelBinding PeerAttachChannelBindingProvider
}
type peerAttachSecurityState struct {
requireExplicitAuth bool
requireChannelBinding bool
replayWindow time.Duration
replayCapacity int
channelBinding PeerAttachChannelBindingProvider
}
var errPeerAttachChannelBindingProviderNil = errors.New("peer attach channel binding provider is nil")
func DefaultPeerAttachSecurityConfig() PeerAttachSecurityConfig {
return PeerAttachSecurityConfig{
ReplayWindow: peerAttachReplayTTL,
ReplayCapacity: defaultPeerAttachReplayCapacity,
}
}
func normalizePeerAttachSecurityConfig(cfg PeerAttachSecurityConfig) (peerAttachSecurityState, error) {
if cfg.ReplayWindow <= 0 {
cfg.ReplayWindow = peerAttachReplayTTL
}
if cfg.ReplayCapacity <= 0 {
cfg.ReplayCapacity = defaultPeerAttachReplayCapacity
}
if cfg.RequireChannelBinding {
cfg.RequireExplicitAuth = true
if cfg.ChannelBinding == nil {
return peerAttachSecurityState{}, errPeerAttachChannelBindingProviderNil
}
}
return peerAttachSecurityState{
requireExplicitAuth: cfg.RequireExplicitAuth,
requireChannelBinding: cfg.RequireChannelBinding,
replayWindow: cfg.ReplayWindow,
replayCapacity: cfg.ReplayCapacity,
channelBinding: cfg.ChannelBinding,
}, nil
}
func peerAttachSecurityConfigFromState(state *peerAttachSecurityState) PeerAttachSecurityConfig {
if state == nil {
return DefaultPeerAttachSecurityConfig()
}
return PeerAttachSecurityConfig{
RequireExplicitAuth: state.requireExplicitAuth,
RequireChannelBinding: state.requireChannelBinding,
ReplayWindow: state.replayWindow,
ReplayCapacity: state.replayCapacity,
ChannelBinding: state.channelBinding,
}
}
func defaultPeerAttachSecurityState() *peerAttachSecurityState {
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
return &cfg
}
func (c *ClientCommon) peerAttachSecuritySnapshot() peerAttachSecurityState {
if c == nil {
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
return cfg
}
if state := c.peerAttachSecurity.Load(); state != nil {
return *state
}
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
return cfg
}
func (s *ServerCommon) peerAttachSecuritySnapshot() peerAttachSecurityState {
if s == nil {
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
return cfg
}
if state := s.peerAttachSecurity.Load(); state != nil {
return *state
}
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
return cfg
}
func (c *ClientCommon) SetPeerAttachSecurityConfig(cfg PeerAttachSecurityConfig) error {
state, err := normalizePeerAttachSecurityConfig(cfg)
if err != nil {
return err
}
c.peerAttachSecurity.Store(&state)
return nil
}
func (c *ClientCommon) PeerAttachSecurityConfig() PeerAttachSecurityConfig {
return peerAttachSecurityConfigFromState(c.peerAttachSecurity.Load())
}
func (s *ServerCommon) SetPeerAttachSecurityConfig(cfg PeerAttachSecurityConfig) error {
state, err := normalizePeerAttachSecurityConfig(cfg)
if err != nil {
return err
}
s.peerAttachSecurity.Store(&state)
return nil
}
func (s *ServerCommon) PeerAttachSecurityConfig() PeerAttachSecurityConfig {
return peerAttachSecurityConfigFromState(s.peerAttachSecurity.Load())
}

221
peer_attach_policy_test.go Normal file
View File

@ -0,0 +1,221 @@
package notify
import (
"bytes"
"errors"
"net"
"testing"
"time"
)
func staticPeerAttachChannelBindingProvider(material []byte) PeerAttachChannelBindingProvider {
cloned := bytes.Clone(material)
return func(PeerAttachChannelBindingContext) ([]byte, error) {
return bytes.Clone(cloned), nil
}
}
func failingPeerAttachChannelBindingProvider(PeerAttachChannelBindingContext) ([]byte, error) {
return nil, errors.New("binding unavailable")
}
func TestSetPeerAttachSecurityConfigRejectsMissingChannelBindingProvider(t *testing.T) {
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
cfg := PeerAttachSecurityConfig{RequireChannelBinding: true}
if err := client.SetPeerAttachSecurityConfig(cfg); !errors.Is(err, errPeerAttachChannelBindingProviderNil) {
t.Fatalf("client SetPeerAttachSecurityConfig error = %v, want %v", err, errPeerAttachChannelBindingProviderNil)
}
if err := server.SetPeerAttachSecurityConfig(cfg); !errors.Is(err, errPeerAttachChannelBindingProviderNil) {
t.Fatalf("server SetPeerAttachSecurityConfig error = %v, want %v", err, errPeerAttachChannelBindingProviderNil)
}
}
func TestPeerAttachRequireExplicitAuthRejectsFallbackClient(t *testing.T) {
secret := []byte("correct horse battery staple")
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
RequireExplicitAuth: true,
}); err != nil {
t.Fatalf("SetPeerAttachSecurityConfig failed: %v", err)
}
logical := newPeerAttachAuthLogicalForTest(t, server)
if _, err := server.validatePeerAttachRequestAuth(logical, nil, peerAttachRequest{PeerID: "peer-fallback"}); !errors.Is(err, errPeerAttachExplicitAuthRequired) {
t.Fatalf("validatePeerAttachRequestAuth error = %v, want %v", err, errPeerAttachExplicitAuthRequired)
}
classifyPeerAttachRejectCounter(server, errPeerAttachExplicitAuthRequired)
snapshot, snapErr := GetServerRuntimeSnapshot(server)
if snapErr != nil {
t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr)
}
if got, want := snapshot.PeerAttachDowngradeRejects, int64(1); got != want {
t.Fatalf("PeerAttachDowngradeRejects = %d, want %d", got, want)
}
if got := snapshot.PeerAttachAuthFallbacks; got != 0 {
t.Fatalf("PeerAttachAuthFallbacks = %d, want 0", got)
}
}
func TestPeerAttachChannelBindingRoundTrip(t *testing.T) {
secret := []byte("correct horse battery staple")
bindingProvider := staticPeerAttachChannelBindingProvider([]byte("tls-exporter:test"))
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
if err := UsePSKOverExternalTransportServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
}
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
RequireChannelBinding: true,
ChannelBinding: bindingProvider,
}); err != nil {
t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err)
}
server.SetLink("echo", func(msg *Message) {
_ = msg.Reply([]byte("ack"))
})
})
client := NewClient().(*ClientCommon)
if err := UsePSKOverExternalTransportClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
}
if err := client.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
RequireChannelBinding: true,
ChannelBinding: bindingProvider,
}); err != nil {
t.Fatalf("client SetPeerAttachSecurityConfig failed: %v", err)
}
left, right := net.Pipe()
defer right.Close()
bootstrapPeerAttachLogicalForTest(t, server, right)
if err := client.ConnectByConn(left); err != nil {
t.Fatalf("ConnectByConn failed: %v", err)
}
defer func() {
client.setByeFromServer(true)
_ = client.Stop()
}()
reply, err := client.SendWait("echo", []byte("ping"), time.Second)
if err != nil {
t.Fatalf("SendWait failed: %v", err)
}
if got, want := string(reply.Value), "ack"; got != want {
t.Fatalf("reply = %q, want %q", got, want)
}
serverSnapshot, err := GetServerRuntimeSnapshot(server)
if err != nil {
t.Fatalf("GetServerRuntimeSnapshot failed: %v", err)
}
if !serverSnapshot.PeerAttachRequireExplicitAuth || !serverSnapshot.PeerAttachRequireChannelBinding || !serverSnapshot.PeerAttachChannelBindingConfigured {
t.Fatalf("unexpected server peer attach policy snapshot: %+v", serverSnapshot)
}
if got, want := serverSnapshot.PeerAttachExplicitAuth, int64(1); got != want {
t.Fatalf("PeerAttachExplicitAuth = %d, want %d", got, want)
}
if serverSnapshot.PeerAttachAuthRejects != 0 || serverSnapshot.PeerAttachDowngradeRejects != 0 || serverSnapshot.PeerAttachBindingRejects != 0 {
t.Fatalf("unexpected server reject counters: %+v", serverSnapshot)
}
clientSnapshot, err := GetClientRuntimeSnapshot(client)
if err != nil {
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
}
if !clientSnapshot.PeerAttachRequireExplicitAuth || !clientSnapshot.PeerAttachRequireChannelBinding || !clientSnapshot.PeerAttachChannelBindingConfigured {
t.Fatalf("unexpected client peer attach policy snapshot: %+v", clientSnapshot)
}
}
func TestPeerAttachChannelBindingProviderFailureRejectsAttach(t *testing.T) {
secret := []byte("correct horse battery staple")
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
RequireChannelBinding: true,
ChannelBinding: failingPeerAttachChannelBindingProvider,
}); err != nil {
t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err)
}
})
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
RequireChannelBinding: true,
ChannelBinding: staticPeerAttachChannelBindingProvider([]byte("binding")),
}); err != nil {
t.Fatalf("client SetPeerAttachSecurityConfig failed: %v", err)
}
left, right := net.Pipe()
defer right.Close()
bootstrapPeerAttachLogicalForTest(t, server, right)
err := client.ConnectByConn(left)
if !errors.Is(err, errPeerAttachChannelBindingUnavailable) && (err == nil || err.Error() != errPeerAttachChannelBindingUnavailable.Error()) {
t.Fatalf("ConnectByConn error = %v, want %v", err, errPeerAttachChannelBindingUnavailable)
}
snapshot, snapErr := GetServerRuntimeSnapshot(server)
if snapErr != nil {
t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr)
}
if got, want := snapshot.PeerAttachBindingRejects, int64(1); got != want {
t.Fatalf("PeerAttachBindingRejects = %d, want %d", got, want)
}
}
func TestPeerAttachReplayCapacityRejectsOverflow(t *testing.T) {
secret := []byte("correct horse battery staple")
client := NewClient().(*ClientCommon)
server := NewServer().(*ServerCommon)
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
ReplayCapacity: 1,
}); err != nil {
t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err)
}
logical := newPeerAttachAuthLogicalForTest(t, server)
first, _, err := client.buildPeerAttachRequest("peer-one")
if err != nil {
t.Fatalf("buildPeerAttachRequest(first) failed: %v", err)
}
second, _, err := client.buildPeerAttachRequest("peer-two")
if err != nil {
t.Fatalf("buildPeerAttachRequest(second) failed: %v", err)
}
if _, err := server.validatePeerAttachRequestAuth(logical, nil, first); err != nil {
t.Fatalf("validatePeerAttachRequestAuth(first) failed: %v", err)
}
if _, err := server.validatePeerAttachRequestAuth(logical, nil, second); !errors.Is(err, errPeerAttachReplayWindowFull) {
t.Fatalf("validatePeerAttachRequestAuth(second) error = %v, want %v", err, errPeerAttachReplayWindowFull)
}
classifyPeerAttachRejectCounter(server, errPeerAttachReplayWindowFull)
snapshot, snapErr := GetServerRuntimeSnapshot(server)
if snapErr != nil {
t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr)
}
if got, want := snapshot.PeerAttachReplayCapacity, 1; got != want {
t.Fatalf("PeerAttachReplayCapacity = %d, want %d", got, want)
}
if got, want := snapshot.PeerAttachReplayOverflowRejects, int64(1); got != want {
t.Fatalf("PeerAttachReplayOverflowRejects = %d, want %d", got, want)
}
}

View File

@ -16,6 +16,10 @@ const (
type peerAttachRequest struct { type peerAttachRequest struct {
PeerID string PeerID string
Features uint64
ClientNonce []byte
ClientECDHEPublicKey []byte
AuthTag []byte
} }
type peerAttachResponse struct { type peerAttachResponse struct {
@ -23,6 +27,11 @@ type peerAttachResponse struct {
Accepted bool Accepted bool
Reused bool Reused bool
Error string Error string
Features uint64
KeyMode string
ServerNonce []byte
ServerECDHEPublicKey []byte
AuthTag []byte
} }
func newClientPeerIdentity() string { func newClientPeerIdentity() string {
@ -108,7 +117,11 @@ func (c *ClientCommon) announceClientPeerIdentity() error {
if peerID == "" { if peerID == "" {
return errors.New("peer identity is empty") return errors.New("peer identity is empty")
} }
encoded, err := c.sequenceEn(peerAttachRequest{PeerID: peerID}) req, requestState, err := c.buildPeerAttachRequest(peerID)
if err != nil {
return err
}
encoded, err := c.sequenceEn(req)
if err != nil { if err != nil {
return err return err
} }
@ -133,6 +146,12 @@ func (c *ClientCommon) announceClientPeerIdentity() error {
} }
return errors.New("peer attach rejected") return errors.New("peer attach rejected")
} }
verifyResult, err := c.verifyPeerAttachResponse(req, resp, requestState)
if err != nil {
return err
}
c.setClientNegotiatedSteadyTransportProtection(verifyResult.steadyProfile)
c.markClientPeerAttachAuthenticated(verifyResult.authFallback, time.Now())
return nil return nil
} }
@ -188,7 +207,7 @@ func (s *ServerCommon) replyPeerAttach(client *LogicalConn, message Message, res
Type: MSG_SYS_REPLY, Type: MSG_SYS_REPLY,
} }
if message.inboundConn != nil { if message.inboundConn != nil {
return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply) return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, messageInboundTransportProtectionSnapshot(&message), reply)
} }
_, err = s.sendLogical(client, reply) _, err = s.sendLogical(client, reply)
return err return err
@ -200,6 +219,10 @@ func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool {
} }
message = hydrateServerMessagePeerFields(message) message = hydrateServerMessagePeerFields(message)
current := messageLogicalConnSnapshot(&message) current := messageLogicalConnSnapshot(&message)
transport := message.inboundConn
if transport == nil && current != nil {
transport = current.transportSnapshot()
}
req, err := decodePeerAttachRequest(s.sequenceDe, message.Value) req, err := decodePeerAttachRequest(s.sequenceDe, message.Value)
if err != nil { if err != nil {
if current != nil { if current != nil {
@ -210,6 +233,18 @@ func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool {
} }
return true return true
} }
auth, err := s.validatePeerAttachRequestAuth(current, transport, req)
if err != nil {
classifyPeerAttachRejectCounter(s, err)
if current != nil {
_ = s.replyPeerAttach(current, message, peerAttachResponse{
PeerID: req.PeerID,
Accepted: false,
Error: err.Error(),
})
}
return true
}
bound, reused, err := s.bindAcceptedClientIdentity(current, req.PeerID) bound, reused, err := s.bindAcceptedClientIdentity(current, req.PeerID)
if err != nil { if err != nil {
if current != nil { if current != nil {
@ -221,12 +256,37 @@ func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool {
} }
return true return true
} }
if err := s.replyPeerAttach(bound, message, peerAttachResponse{ resp := peerAttachResponse{
PeerID: bound.ID(), PeerID: bound.ID(),
Accepted: true, Accepted: true,
Reused: reused, Reused: reused,
}); err != nil && bound != nil { }
steadyProfile, err := s.preparePeerAttachSteadyTransportProfile(bound, req, &resp, auth)
if err != nil {
if bound != nil {
_ = s.replyPeerAttach(bound, message, peerAttachResponse{
PeerID: req.PeerID,
Accepted: false,
Error: err.Error(),
})
}
return true
}
s.signPeerAttachResponse(bound, req, &resp, auth)
if bound != nil {
bound.markPeerAttachAuthenticated(s.securityAuthMode, auth.fallback, time.Now())
if auth.explicit {
s.peerAttachExplicitCount.Add(1)
} else if auth.fallback {
s.peerAttachAuthFallbackCount.Add(1)
}
}
if err := s.replyPeerAttach(bound, message, resp); err != nil && bound != nil {
s.stopLogicalSession(bound, "peer attach reply failed", err) s.stopLogicalSession(bound, "peer attach reply failed", err)
return true
}
if bound != nil && s.securityConfigured {
bound.applyTransportProtectionProfile(steadyProfile)
} }
return true return true
} }

View File

@ -119,6 +119,7 @@ func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) {
defer serverConn.Close() defer serverConn.Close()
logical := bootstrapPeerAttachLogicalForTest(t, server, serverConn) logical := bootstrapPeerAttachLogicalForTest(t, server, serverConn)
originalProfile := logical.transportProtectionProfileSnapshot()
message := Message{ message := Message{
NetType: NET_SERVER, NetType: NET_SERVER,
LogicalConn: logical, LogicalConn: logical,
@ -131,6 +132,13 @@ func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) {
Time: time.Now(), Time: time.Now(),
inboundConn: serverConn, inboundConn: serverConn,
} }
message = hydrateServerMessagePeerFields(message)
alternate, err := deriveModernPSKProtectionProfile([]byte("notify-peer-attach-reply-alternate"), testModernPSKOptions(), ProtectionManaged)
if err != nil {
t.Fatalf("deriveModernPSKProtectionProfile failed: %v", err)
}
logical.applyTransportProtectionProfile(alternate)
done := make(chan error, 1) done := make(chan error, 1)
go func() { go func() {
@ -140,7 +148,7 @@ func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) {
}) })
}() }()
env := readServerEnvelopeFromConn(t, server, logical, clientConn, time.Second) env := readServerEnvelopeFromConnWithProfile(t, server, originalProfile, clientConn, time.Second)
if env.Kind != EnvelopeSignal { if env.Kind != EnvelopeSignal {
t.Fatalf("reply envelope kind = %v, want %v", env.Kind, EnvelopeSignal) t.Fatalf("reply envelope kind = %v, want %v", env.Kind, EnvelopeSignal)
} }

139
security_forward_secrecy.go Normal file
View File

@ -0,0 +1,139 @@
package notify
import (
"bytes"
"crypto/ecdh"
"crypto/hmac"
cryptorand "crypto/rand"
"crypto/sha256"
"encoding/binary"
"errors"
)
const (
peerAttachECDHEPublicKeySize = 32
peerAttachSessionIDSize = 16
peerAttachKeyModeStatic = "psk-static"
peerAttachKeyModeECDHE = "psk-ecdhe"
transportKeyModeExternal = "external"
)
var errPeerAttachForwardSecrecyInvalid = errors.New("peer attach forward secrecy is invalid")
type peerAttachRequestState struct {
forwardSecrecy *peerAttachForwardSecrecyClientState
}
type peerAttachForwardSecrecyClientState struct {
privateKey *ecdh.PrivateKey
publicKey []byte
}
type peerAttachResponseVerifyResult struct {
authFallback bool
steadyProfile transportProtectionProfile
}
func newPeerAttachForwardSecrecyClientState() (*peerAttachForwardSecrecyClientState, error) {
curve := ecdh.X25519()
privateKey, err := curve.GenerateKey(cryptorand.Reader)
if err != nil {
return nil, err
}
publicKey := privateKey.PublicKey().Bytes()
if len(publicKey) != peerAttachECDHEPublicKeySize {
return nil, errPeerAttachForwardSecrecyInvalid
}
return &peerAttachForwardSecrecyClientState{
privateKey: privateKey,
publicKey: bytes.Clone(publicKey),
}, nil
}
func derivePeerAttachForwardSecrecyTransportProfile(base transportProtectionProfile, bootstrapKey []byte, localPrivateKey *ecdh.PrivateKey, peerPublicKey []byte, req peerAttachRequest, resp peerAttachResponse) (transportProtectionProfile, error) {
if len(bootstrapKey) == 0 || localPrivateKey == nil || len(peerPublicKey) != peerAttachECDHEPublicKeySize {
return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid
}
curve := ecdh.X25519()
publicKey, err := curve.NewPublicKey(peerPublicKey)
if err != nil {
return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid
}
sharedSecret, err := localPrivateKey.ECDH(publicKey)
if err != nil {
return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid
}
transcriptHash := peerAttachForwardSecrecyTranscriptHash(req, resp)
ikm := make([]byte, 0, len(sharedSecret)+len(transcriptHash))
ikm = append(ikm, sharedSecret...)
ikm = append(ikm, transcriptHash...)
prk := hkdfExtractSHA256(bootstrapKey, ikm)
sessionKey := hkdfExpandSHA256(prk, []byte("notify/transport/session/v1"), 32)
sessionID := hkdfExpandSHA256(prk, []byte("notify/session-id/v1"), peerAttachSessionIDSize)
return deriveModernPSKSessionProtectionProfile(base, sessionKey, sessionID)
}
func peerAttachForwardSecrecyTranscriptHash(req peerAttachRequest, resp peerAttachResponse) []byte {
buf := make([]byte, 0, 256+len(req.PeerID)+len(resp.PeerID)+len(resp.Error)+len(req.ClientECDHEPublicKey)+len(resp.ServerECDHEPublicKey))
buf = appendPeerAttachTranscriptString(buf, "notify/peer-attach/forward-secrecy/v1")
buf = binary.BigEndian.AppendUint64(buf, req.Features)
buf = appendPeerAttachTranscriptString(buf, req.PeerID)
buf = appendPeerAttachTranscriptBytes(buf, req.ClientNonce)
buf = appendPeerAttachTranscriptBytes(buf, req.ClientECDHEPublicKey)
buf = binary.BigEndian.AppendUint64(buf, resp.Features)
buf = appendPeerAttachTranscriptString(buf, resp.PeerID)
buf = appendPeerAttachTranscriptBool(buf, resp.Accepted)
buf = appendPeerAttachTranscriptBool(buf, resp.Reused)
buf = appendPeerAttachTranscriptString(buf, resp.Error)
buf = appendPeerAttachTranscriptBytes(buf, resp.ServerNonce)
buf = appendPeerAttachTranscriptString(buf, resp.KeyMode)
buf = appendPeerAttachTranscriptBytes(buf, resp.ServerECDHEPublicKey)
sum := sha256.Sum256(buf)
return sum[:]
}
func appendPeerAttachTranscriptBytes(dst []byte, data []byte) []byte {
dst = binary.BigEndian.AppendUint32(dst, uint32(len(data)))
return append(dst, data...)
}
func appendPeerAttachTranscriptString(dst []byte, value string) []byte {
return appendPeerAttachTranscriptBytes(dst, []byte(value))
}
func appendPeerAttachTranscriptBool(dst []byte, value bool) []byte {
if value {
return append(dst, 1)
}
return append(dst, 0)
}
func hkdfExtractSHA256(salt []byte, ikm []byte) []byte {
mac := hmac.New(sha256.New, salt)
_, _ = mac.Write(ikm)
return mac.Sum(nil)
}
func hkdfExpandSHA256(prk []byte, info []byte, size int) []byte {
if size <= 0 {
return nil
}
out := make([]byte, 0, size)
var block []byte
for counter := byte(1); len(out) < size; counter++ {
mac := hmac.New(sha256.New, prk)
if len(block) != 0 {
_, _ = mac.Write(block)
}
_, _ = mac.Write(info)
_, _ = mac.Write([]byte{counter})
block = mac.Sum(nil)
remaining := size - len(out)
if remaining > len(block) {
remaining = len(block)
}
out = append(out, block[:remaining]...)
}
return out
}

408
security_profile.go Normal file
View File

@ -0,0 +1,408 @@
package notify
import "bytes"
// AuthMode describes how notify authenticates the peer during bootstrap.
type AuthMode int
const (
AuthNone AuthMode = iota
AuthPSK
AuthExternalPeer
)
// ProtectionMode describes how notify protects steady-state transport payloads.
type ProtectionMode int
const (
ProtectionManaged ProtectionMode = iota
ProtectionExternal
ProtectionNested
)
// SecurityOptions describes the high-level auth/protection policy.
//
// The current implementation still exposes dedicated helper constructors such
// as UseModernPSKClient/Server and UsePSKOverExternalTransportClient/Server.
type SecurityOptions struct {
AuthMode AuthMode
ProtectionMode ProtectionMode
SharedSecret []byte
RequireForwardSecrecy bool
}
func authModeName(mode AuthMode) string {
switch mode {
case AuthNone:
return "none"
case AuthPSK:
return "psk"
case AuthExternalPeer:
return "external-peer"
default:
return "unknown"
}
}
func protectionModeName(mode ProtectionMode) string {
switch mode {
case ProtectionManaged:
return "managed"
case ProtectionExternal:
return "external"
case ProtectionNested:
return "nested"
default:
return "unknown"
}
}
type transportProtectionProfile struct {
mode ProtectionMode
msgEn func([]byte, []byte) []byte
msgDe func([]byte, []byte) []byte
fastStreamEncode transportFastStreamEncoder
fastBulkEncode transportFastBulkEncoder
fastPlainEncode transportFastPlainEncoder
runtime *modernPSKCodecRuntime
secretKey []byte
keyMode string
sessionID []byte
forwardSecrecy bool
forwardSecrecyFallback bool
}
func cloneTransportProtectionKey(src []byte) []byte {
if len(src) == 0 {
return nil
}
return bytes.Clone(src)
}
func newTransportProtectionProfile(mode ProtectionMode, bundle modernPSKTransportBundle, runtime *modernPSKCodecRuntime, secretKey []byte) transportProtectionProfile {
return transportProtectionProfile{
mode: mode,
msgEn: bundle.msgEn,
msgDe: bundle.msgDe,
fastStreamEncode: bundle.fastStreamEncode,
fastBulkEncode: bundle.fastBulkEncode,
fastPlainEncode: bundle.fastPlainEncode,
runtime: runtime,
secretKey: cloneTransportProtectionKey(secretKey),
keyMode: defaultTransportKeyMode(mode, secretKey),
}
}
func defaultTransportKeyMode(mode ProtectionMode, secretKey []byte) string {
if len(secretKey) == 0 {
return ""
}
switch mode {
case ProtectionManaged, ProtectionNested:
return peerAttachKeyModeStatic
case ProtectionExternal:
return transportKeyModeExternal
default:
return ""
}
}
func cloneTransportSessionID(src []byte) []byte {
if len(src) == 0 {
return nil
}
return bytes.Clone(src)
}
func (p transportProtectionProfile) clone() transportProtectionProfile {
p.secretKey = cloneTransportProtectionKey(p.secretKey)
p.sessionID = cloneTransportSessionID(p.sessionID)
return p
}
func (p transportProtectionProfile) withForwardSecrecyFallback(fallback bool) transportProtectionProfile {
p = p.clone()
p.forwardSecrecy = false
p.forwardSecrecyFallback = fallback
p.sessionID = nil
return p
}
func passthroughTransportCodec(_ []byte, payload []byte) []byte {
return payload
}
func passthroughFastPlainEncode(_ []byte, plainLen int, fill func([]byte) error) ([]byte, error) {
if plainLen < 0 {
return nil, errTransportPayloadEncryptFailed
}
buf := make([]byte, plainLen)
if fill != nil {
if err := fill(buf); err != nil {
return nil, err
}
}
return buf, nil
}
func buildExternalTransportBundle() modernPSKTransportBundle {
fastPlainEncode := passthroughFastPlainEncode
fastStreamEncode := func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
return encodeStreamFastFramePayloadFast(fastPlainEncode, secretKey, streamFastDataFrame{
DataID: dataID,
Seq: seq,
Payload: payload,
})
}
fastBulkEncode := func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
return encodeBulkFastFramePayloadFast(fastPlainEncode, secretKey, bulkFastFrame{
Type: bulkFastPayloadTypeData,
DataID: dataID,
Seq: seq,
Payload: payload,
})
}
return modernPSKTransportBundle{
msgEn: passthroughTransportCodec,
msgDe: passthroughTransportCodec,
fastStreamEncode: fastStreamEncode,
fastBulkEncode: fastBulkEncode,
fastPlainEncode: fastPlainEncode,
}
}
func defaultTransportProtectionProfile() transportProtectionProfile {
return newTransportProtectionProfile(ProtectionManaged, defaultModernPSKTransportBundle(), nil, nil)
}
func (c *ClientCommon) clientTransportProtectionSnapshot() transportProtectionProfile {
if c == nil {
return transportProtectionProfile{}
}
if state := c.transportProtection.Load(); state != nil {
return *state
}
return transportProtectionProfile{
mode: ProtectionManaged,
msgEn: c.msgEn,
msgDe: c.msgDe,
fastStreamEncode: c.fastStreamEncode,
fastBulkEncode: c.fastBulkEncode,
fastPlainEncode: c.fastPlainEncode,
runtime: c.modernPSKRuntime,
secretKey: c.SecretKey,
keyMode: defaultTransportKeyMode(ProtectionManaged, c.SecretKey),
}
}
func (c *ClientCommon) setClientTransportProtectionProfile(profile transportProtectionProfile) {
if c == nil {
return
}
profile.secretKey = cloneTransportProtectionKey(profile.secretKey)
profile.sessionID = cloneTransportSessionID(profile.sessionID)
c.msgEn = profile.msgEn
c.msgDe = profile.msgDe
c.fastStreamEncode = profile.fastStreamEncode
c.fastBulkEncode = profile.fastBulkEncode
c.fastPlainEncode = profile.fastPlainEncode
c.modernPSKRuntime = profile.runtime
c.SecretKey = profile.secretKey
c.transportProtection.Store(&profile)
}
func (c *ClientCommon) clearClientSecurityProfiles() {
if c == nil {
return
}
c.securityConfigured = false
c.securityAuthMode = AuthNone
c.securityProtectionMode = ProtectionManaged
c.securityBootstrap = transportProtectionProfile{}
c.securitySteady = transportProtectionProfile{}
c.securitySteadyNegotiated = transportProtectionProfile{}
c.securityRequireForwardSecrecy = false
c.peerAttachAuthenticated = false
c.peerAttachAuthFallback = false
c.peerAttachAt = 0
}
func (c *ClientCommon) configureClientSecurityProfiles(authMode AuthMode, protectionMode ProtectionMode, bootstrap transportProtectionProfile, steady transportProtectionProfile, requireForwardSecrecy bool) {
if c == nil {
return
}
c.securityConfigured = true
c.securityAuthMode = authMode
c.securityProtectionMode = protectionMode
c.securityBootstrap = bootstrap.clone()
c.securitySteady = steady.clone()
c.securitySteadyNegotiated = steady.clone()
c.securityRequireForwardSecrecy = requireForwardSecrecy
c.setClientTransportProtectionProfile(bootstrap)
c.securityReadyCheck = len(bootstrap.secretKey) != 0
c.skipKeyExchange = true
}
func (c *ClientCommon) activateClientBootstrapTransportProtection() {
if c == nil || !c.securityConfigured {
return
}
c.resetClientNegotiatedSteadyTransportProtection()
c.setClientTransportProtectionProfile(c.securityBootstrap)
}
func (c *ClientCommon) activateClientSteadyTransportProtection() {
if c == nil || !c.securityConfigured {
return
}
c.setClientTransportProtectionProfile(c.securitySteadyNegotiated)
}
func (c *ClientCommon) resetClientNegotiatedSteadyTransportProtection() {
if c == nil {
return
}
c.securitySteadyNegotiated = c.securitySteady.clone().withForwardSecrecyFallback(false)
}
func (c *ClientCommon) setClientNegotiatedSteadyTransportProtection(profile transportProtectionProfile) {
if c == nil {
return
}
c.securitySteadyNegotiated = profile.clone()
}
func (c *ClientCommon) clientSupportsForwardSecrecy() bool {
if c == nil || !c.securityConfigured {
return false
}
if c.securityAuthMode != AuthPSK {
return false
}
if c.securityProtectionMode == ProtectionExternal {
return false
}
return len(c.securityBootstrap.secretKey) != 0
}
func (c *ClientCommon) clientRequiresForwardSecrecy() bool {
if c == nil {
return false
}
return c.securityRequireForwardSecrecy
}
func (s *ServerCommon) serverSupportsForwardSecrecy() bool {
if s == nil || !s.securityConfigured {
return false
}
if s.securityAuthMode != AuthPSK {
return false
}
if s.securityProtectionMode == ProtectionExternal {
return false
}
return len(s.securityBootstrap.secretKey) != 0
}
func (s *ServerCommon) serverRequiresForwardSecrecy() bool {
if s == nil {
return false
}
return s.securityRequireForwardSecrecy
}
func (s *ServerCommon) setServerDefaultTransportProtectionProfile(profile transportProtectionProfile) {
if s == nil {
return
}
profile.secretKey = cloneTransportProtectionKey(profile.secretKey)
profile.sessionID = cloneTransportSessionID(profile.sessionID)
s.defaultMsgEn = profile.msgEn
s.defaultMsgDe = profile.msgDe
s.defaultFastStreamEncode = profile.fastStreamEncode
s.defaultFastBulkEncode = profile.fastBulkEncode
s.defaultFastPlainEncode = profile.fastPlainEncode
s.defaultModernPSKRuntime = profile.runtime
s.SecretKey = profile.secretKey
}
func (s *ServerCommon) clearServerSecurityProfiles() {
if s == nil {
return
}
s.securityConfigured = false
s.securityAuthMode = AuthNone
s.securityProtectionMode = ProtectionManaged
s.securityBootstrap = transportProtectionProfile{}
s.securitySteady = transportProtectionProfile{}
s.securityRequireForwardSecrecy = false
}
func (s *ServerCommon) configureServerSecurityProfiles(authMode AuthMode, protectionMode ProtectionMode, bootstrap transportProtectionProfile, steady transportProtectionProfile, requireForwardSecrecy bool) {
if s == nil {
return
}
s.securityConfigured = true
s.securityAuthMode = authMode
s.securityProtectionMode = protectionMode
s.securityBootstrap = bootstrap.clone()
s.securitySteady = steady.clone()
s.securityRequireForwardSecrecy = requireForwardSecrecy
s.setServerDefaultTransportProtectionProfile(bootstrap)
s.securityReadyCheck = len(bootstrap.secretKey) != 0
}
func (s *ServerCommon) applyLogicalSteadyTransportProtection(logical *LogicalConn) {
if s == nil || logical == nil || !s.securityConfigured {
return
}
logical.applyTransportProtectionProfile(s.securitySteady)
}
func transportProtectionProfileFromAttachmentState(state *clientConnAttachmentState) transportProtectionProfile {
if state == nil {
return transportProtectionProfile{}
}
return transportProtectionProfile{
mode: state.protectionMode,
msgEn: state.msgEn,
msgDe: state.msgDe,
fastStreamEncode: state.fastStreamEncode,
fastBulkEncode: state.fastBulkEncode,
fastPlainEncode: state.fastPlainEncode,
runtime: state.modernPSKRuntime,
secretKey: cloneTransportProtectionKey(state.secretKey),
keyMode: state.keyMode,
sessionID: cloneTransportSessionID(state.sessionID),
forwardSecrecy: state.forwardSecrecy,
forwardSecrecyFallback: state.forwardSecrecyFallback,
}
}
func (c *LogicalConn) transportProtectionProfileSnapshot() transportProtectionProfile {
if c == nil {
return transportProtectionProfile{}
}
return transportProtectionProfileFromAttachmentState(c.attachmentStateSnapshot())
}
func (c *LogicalConn) applyTransportProtectionProfile(profile transportProtectionProfile) {
if c == nil {
return
}
c.updateAttachmentState(func(state *clientConnAttachmentState) {
state.protectionMode = profile.mode
state.msgEn = profile.msgEn
state.msgDe = profile.msgDe
state.fastStreamEncode = profile.fastStreamEncode
state.fastBulkEncode = profile.fastBulkEncode
state.fastPlainEncode = profile.fastPlainEncode
state.modernPSKRuntime = profile.runtime
state.secretKey = cloneTransportProtectionKey(profile.secretKey)
state.keyMode = profile.keyMode
state.sessionID = cloneTransportSessionID(profile.sessionID)
state.forwardSecrecy = profile.forwardSecrecy
state.forwardSecrecyFallback = profile.forwardSecrecyFallback
})
}

View File

@ -17,7 +17,8 @@ import (
var ( var (
errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty") errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty")
errModernPSKPayload = errors.New("invalid modern psk payload") errModernPSKPayload = errors.New("invalid modern psk payload")
errModernPSKRequired = errors.New("modern psk is required: call UseModernPSKClient/UseModernPSKServer or set a transport key before Connect/Listen") errModernPSKRequired = errors.New("transport security is required: call UseModernPSKClient/UseModernPSKServer, UsePSKOverExternalTransportClient/Server, or set a transport key before Connect/Listen")
errModernPSKForwardSecrecyUnsupported = errors.New("forward secrecy is unsupported for external transport protection")
) )
var ( var (
@ -50,6 +51,7 @@ type ModernPSKOptions struct {
Salt []byte Salt []byte
AAD []byte AAD []byte
Argon2Params starcrypto.Argon2Params Argon2Params starcrypto.Argon2Params
RequireForwardSecrecy bool
} }
// DefaultModernPSKOptions returns the recommended settings for the current // DefaultModernPSKOptions returns the recommended settings for the current
@ -78,24 +80,17 @@ func defaultModernPSKTransportBundle() modernPSKTransportBundle {
// Argon2id, and switches message protection to AES-GCM. Configure it before // Argon2id, and switches message protection to AES-GCM. Configure it before
// calling Connect/ConnectTimeout. // calling Connect/ConnectTimeout.
func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error { func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
key, aad, err := deriveModernPSKKey(sharedSecret, opts) managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
if err != nil { if err != nil {
return err return err
} }
transport := buildModernPSKTransportBundle(aad)
runtime, err := newModernPSKCodecRuntime(key, aad)
if err != nil {
return err
}
c.SetSecretKey(key)
c.SetMsgEn(transport.msgEn)
c.SetMsgDe(transport.msgDe)
if client, ok := c.(*ClientCommon); ok { if client, ok := c.(*ClientCommon); ok {
client.fastStreamEncode = transport.fastStreamEncode client.configureClientSecurityProfiles(AuthPSK, ProtectionManaged, managed, managed, opts != nil && opts.RequireForwardSecrecy)
client.fastBulkEncode = transport.fastBulkEncode return nil
client.fastPlainEncode = transport.fastPlainEncode
client.modernPSKRuntime = runtime
} }
c.SetSecretKey(managed.secretKey)
c.SetMsgEn(managed.msgEn)
c.SetMsgDe(managed.msgDe)
c.SetSkipExchangeKey(true) c.SetSkipExchangeKey(true)
return nil return nil
} }
@ -106,24 +101,95 @@ func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) e
// It derives a transport key with Argon2id and switches message protection to // It derives a transport key with Argon2id and switches message protection to
// AES-GCM. Configure it before calling Listen. // AES-GCM. Configure it before calling Listen.
func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error { func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
key, aad, err := deriveModernPSKKey(sharedSecret, opts) managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
if err != nil { if err != nil {
return err return err
} }
transport := buildModernPSKTransportBundle(aad)
runtime, err := newModernPSKCodecRuntime(key, aad)
if err != nil {
return err
}
s.SetSecretKey(key)
s.SetDefaultCommEncode(transport.msgEn)
s.SetDefaultCommDecode(transport.msgDe)
if server, ok := s.(*ServerCommon); ok { if server, ok := s.(*ServerCommon); ok {
server.defaultFastStreamEncode = transport.fastStreamEncode server.configureServerSecurityProfiles(AuthPSK, ProtectionManaged, managed, managed, opts != nil && opts.RequireForwardSecrecy)
server.defaultFastBulkEncode = transport.fastBulkEncode return nil
server.defaultFastPlainEncode = transport.fastPlainEncode
server.defaultModernPSKRuntime = runtime
} }
s.SetSecretKey(managed.secretKey)
s.SetDefaultCommEncode(managed.msgEn)
s.SetDefaultCommDecode(managed.msgDe)
return nil
}
// UsePSKOverExternalTransportClient authenticates bootstrap with PSK and then
// trusts the external channel for steady-state payload protection.
func UsePSKOverExternalTransportClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
if opts != nil && opts.RequireForwardSecrecy {
return errModernPSKForwardSecrecyUnsupported
}
bootstrap, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
if err != nil {
return err
}
steady := buildExternalProtectionProfile(bootstrap.secretKey)
if client, ok := c.(*ClientCommon); ok {
client.configureClientSecurityProfiles(AuthPSK, ProtectionExternal, bootstrap, steady, false)
return nil
}
c.SetSecretKey(bootstrap.secretKey)
c.SetMsgEn(bootstrap.msgEn)
c.SetMsgDe(bootstrap.msgDe)
c.SetSkipExchangeKey(true)
return nil
}
// UsePSKOverExternalTransportServer authenticates bootstrap with PSK and then
// trusts the external channel for steady-state payload protection.
func UsePSKOverExternalTransportServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
if opts != nil && opts.RequireForwardSecrecy {
return errModernPSKForwardSecrecyUnsupported
}
bootstrap, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
if err != nil {
return err
}
steady := buildExternalProtectionProfile(bootstrap.secretKey)
if server, ok := s.(*ServerCommon); ok {
server.configureServerSecurityProfiles(AuthPSK, ProtectionExternal, bootstrap, steady, false)
return nil
}
s.SetSecretKey(bootstrap.secretKey)
s.SetDefaultCommEncode(bootstrap.msgEn)
s.SetDefaultCommDecode(bootstrap.msgDe)
return nil
}
// UseNestedSecurityClient keeps notify transport protection enabled even when
// the physical connection is already protected by an outer trusted channel.
func UseNestedSecurityClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionNested)
if err != nil {
return err
}
if client, ok := c.(*ClientCommon); ok {
client.configureClientSecurityProfiles(AuthPSK, ProtectionNested, managed, managed, opts != nil && opts.RequireForwardSecrecy)
return nil
}
c.SetSecretKey(managed.secretKey)
c.SetMsgEn(managed.msgEn)
c.SetMsgDe(managed.msgDe)
c.SetSkipExchangeKey(true)
return nil
}
// UseNestedSecurityServer keeps notify transport protection enabled even when
// the physical connection is already protected by an outer trusted channel.
func UseNestedSecurityServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionNested)
if err != nil {
return err
}
if server, ok := s.(*ServerCommon); ok {
server.configureServerSecurityProfiles(AuthPSK, ProtectionNested, managed, managed, opts != nil && opts.RequireForwardSecrecy)
return nil
}
s.SetSecretKey(managed.secretKey)
s.SetDefaultCommEncode(managed.msgEn)
s.SetDefaultCommDecode(managed.msgDe)
return nil return nil
} }
@ -132,15 +198,22 @@ func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) e
// //
// It is kept only as an explicit fallback path for existing deployments. // It is kept only as an explicit fallback path for existing deployments.
func UseLegacySecurityClient(c Client) { func UseLegacySecurityClient(c Client) {
if client, ok := c.(*ClientCommon); ok {
client.clearClientSecurityProfiles()
client.setClientTransportProtectionProfile(transportProtectionProfile{
mode: ProtectionManaged,
msgEn: defaultMsgEn,
msgDe: defaultMsgDe,
secretKey: bytes.Clone(defaultAesKey),
})
client.securityReadyCheck = false
client.skipKeyExchange = false
client.handshakeRsaPubKey = bytes.Clone(defaultRsaPubKey)
return
}
c.SetSecretKey(bytes.Clone(defaultAesKey)) c.SetSecretKey(bytes.Clone(defaultAesKey))
c.SetMsgEn(defaultMsgEn) c.SetMsgEn(defaultMsgEn)
c.SetMsgDe(defaultMsgDe) c.SetMsgDe(defaultMsgDe)
if client, ok := c.(*ClientCommon); ok {
client.fastStreamEncode = nil
client.fastBulkEncode = nil
client.fastPlainEncode = nil
client.modernPSKRuntime = nil
}
c.SetSkipExchangeKey(false) c.SetSkipExchangeKey(false)
c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey)) c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey))
} }
@ -150,15 +223,21 @@ func UseLegacySecurityClient(c Client) {
// //
// It is kept only as an explicit fallback path for existing deployments. // It is kept only as an explicit fallback path for existing deployments.
func UseLegacySecurityServer(s Server) { func UseLegacySecurityServer(s Server) {
if server, ok := s.(*ServerCommon); ok {
server.clearServerSecurityProfiles()
server.setServerDefaultTransportProtectionProfile(transportProtectionProfile{
mode: ProtectionManaged,
msgEn: defaultMsgEn,
msgDe: defaultMsgDe,
secretKey: bytes.Clone(defaultAesKey),
})
server.securityReadyCheck = false
server.handshakeRsaKey = bytes.Clone(defaultRsaKey)
return
}
s.SetSecretKey(bytes.Clone(defaultAesKey)) s.SetSecretKey(bytes.Clone(defaultAesKey))
s.SetDefaultCommEncode(defaultMsgEn) s.SetDefaultCommEncode(defaultMsgEn)
s.SetDefaultCommDecode(defaultMsgDe) s.SetDefaultCommDecode(defaultMsgDe)
if server, ok := s.(*ServerCommon); ok {
server.defaultFastStreamEncode = nil
server.defaultFastBulkEncode = nil
server.defaultFastPlainEncode = nil
server.defaultModernPSKRuntime = nil
}
s.SetRsaPrivKey(bytes.Clone(defaultRsaKey)) s.SetRsaPrivKey(bytes.Clone(defaultRsaKey))
} }
@ -174,6 +253,40 @@ func deriveModernPSKKey(sharedSecret []byte, opts *ModernPSKOptions) ([]byte, []
return key, cfg.AAD, nil return key, cfg.AAD, nil
} }
func deriveModernPSKProtectionProfile(sharedSecret []byte, opts *ModernPSKOptions, mode ProtectionMode) (transportProtectionProfile, error) {
key, aad, err := deriveModernPSKKey(sharedSecret, opts)
if err != nil {
return transportProtectionProfile{}, err
}
transport := buildModernPSKTransportBundle(aad)
runtime, err := newModernPSKCodecRuntime(key, aad)
if err != nil {
return transportProtectionProfile{}, err
}
return newTransportProtectionProfile(mode, transport, runtime, key), nil
}
func buildExternalProtectionProfile(secretKey []byte) transportProtectionProfile {
return newTransportProtectionProfile(ProtectionExternal, buildExternalTransportBundle(), nil, secretKey)
}
func deriveModernPSKSessionProtectionProfile(base transportProtectionProfile, sessionKey []byte, sessionID []byte) (transportProtectionProfile, error) {
aad := bytes.Clone(defaultModernPSKAAD)
if base.runtime != nil && len(base.runtime.aad) != 0 {
aad = bytes.Clone(base.runtime.aad)
}
runtime, err := newModernPSKCodecRuntime(sessionKey, aad)
if err != nil {
return transportProtectionProfile{}, err
}
profile := newTransportProtectionProfile(base.mode, buildModernPSKTransportBundle(aad), runtime, sessionKey)
profile.keyMode = peerAttachKeyModeECDHE
profile.sessionID = cloneTransportSessionID(sessionID)
profile.forwardSecrecy = true
profile.forwardSecrecyFallback = false
return profile, nil
}
func normalizeModernPSKOptions(opts *ModernPSKOptions) ModernPSKOptions { func normalizeModernPSKOptions(opts *ModernPSKOptions) ModernPSKOptions {
cfg := DefaultModernPSKOptions() cfg := DefaultModernPSKOptions()
if opts == nil { if opts == nil {

View File

@ -3,8 +3,10 @@ package notify
import ( import (
"bytes" "bytes"
"errors" "errors"
"net"
"reflect" "reflect"
"testing" "testing"
"time"
"b612.me/starcrypto" "b612.me/starcrypto"
) )
@ -207,6 +209,17 @@ func TestUseModernPSKRejectsEmptySecret(t *testing.T) {
} }
} }
func TestUsePSKOverExternalTransportRejectsForwardSecrecyRequirement(t *testing.T) {
opts := testModernPSKOptions()
opts.RequireForwardSecrecy = true
if err := UsePSKOverExternalTransportClient(NewClient(), []byte("secret"), opts); !errors.Is(err, errModernPSKForwardSecrecyUnsupported) {
t.Fatalf("UsePSKOverExternalTransportClient error = %v, want %v", err, errModernPSKForwardSecrecyUnsupported)
}
if err := UsePSKOverExternalTransportServer(NewServer(), []byte("secret"), opts); !errors.Is(err, errModernPSKForwardSecrecyUnsupported) {
t.Fatalf("UsePSKOverExternalTransportServer error = %v, want %v", err, errModernPSKForwardSecrecyUnsupported)
}
}
func TestModernPSKCodecRejectsLegacyPayload(t *testing.T) { func TestModernPSKCodecRejectsLegacyPayload(t *testing.T) {
key, aad, err := deriveModernPSKKey([]byte("notify-legacy-reject"), testModernPSKOptions()) key, aad, err := deriveModernPSKKey([]byte("notify-legacy-reject"), testModernPSKOptions())
if err != nil { if err != nil {
@ -315,6 +328,166 @@ func TestModernPSKFastBulkEncodeRoundTrip(t *testing.T) {
} }
} }
func TestExternalTransportFastStreamEncodeRoundTrip(t *testing.T) {
transport := buildExternalTransportBundle()
wire, err := transport.fastStreamEncode(nil, 23, 7, []byte("payload"))
if err != nil {
t.Fatalf("fastStreamEncode failed: %v", err)
}
plain := transport.msgDe(nil, wire)
frame, matched, err := decodeStreamFastDataFrame(plain)
if err != nil {
t.Fatalf("decodeStreamFastDataFrame failed: %v", err)
}
if !matched {
t.Fatal("decodeStreamFastDataFrame should match fast payload")
}
if frame.DataID != 23 || frame.Seq != 7 || !bytes.Equal(frame.Payload, []byte("payload")) {
t.Fatalf("frame mismatch: %+v", frame)
}
}
func TestExternalTransportFastBulkEncodeRoundTrip(t *testing.T) {
transport := buildExternalTransportBundle()
wire, err := transport.fastBulkEncode(nil, 41, 9, []byte("payload"))
if err != nil {
t.Fatalf("fastBulkEncode failed: %v", err)
}
plain := transport.msgDe(nil, wire)
frame, matched, err := decodeBulkFastDataFrame(plain)
if err != nil {
t.Fatalf("decodeBulkFastDataFrame failed: %v", err)
}
if !matched {
t.Fatal("decodeBulkFastDataFrame should match fast payload")
}
if frame.DataID != 41 || frame.Seq != 9 || !bytes.Equal(frame.Payload, []byte("payload")) {
t.Fatalf("frame mismatch: %+v", frame)
}
}
func TestDecryptTransportPayloadCodecPooledExternalDefersRelease(t *testing.T) {
payload := []byte("payload")
released := false
plain, release, err := decryptTransportPayloadCodecPooled(ProtectionExternal, nil, passthroughTransportCodec, nil, payload, func() {
released = true
})
if err != nil {
t.Fatalf("decryptTransportPayloadCodecPooled failed: %v", err)
}
if released {
t.Fatal("release should not run before caller is done")
}
if !bytes.Equal(plain, payload) {
t.Fatalf("plain mismatch: got %q want %q", plain, payload)
}
if release == nil {
t.Fatal("release callback should be preserved for external mode")
}
release()
if !released {
t.Fatal("release callback should run when caller finishes")
}
}
func TestUsePSKOverExternalTransportConnectByConnSwitchesToExternal(t *testing.T) {
client := NewClient().(*ClientCommon)
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
if err := UsePSKOverExternalTransportServer(server, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
}
server.SetLink("external-roundtrip", func(msg *Message) {
_ = msg.Reply([]byte("ack:" + string(msg.Value)))
})
})
if err := UsePSKOverExternalTransportClient(client, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
}
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionManaged {
t.Fatalf("client bootstrap mode = %v, want %v", got, ProtectionManaged)
}
left, right := net.Pipe()
defer right.Close()
bootstrapPeerAttachConnForTest(t, server, right)
if err := client.ConnectByConn(left); err != nil {
t.Fatalf("ConnectByConn failed: %v", err)
}
defer func() {
client.setByeFromServer(true)
_ = client.Stop()
}()
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionExternal {
t.Fatalf("client steady mode = %v, want %v", got, ProtectionExternal)
}
reply, err := client.SendWait("external-roundtrip", []byte("ping"), time.Second)
if err != nil {
t.Fatalf("SendWait failed: %v", err)
}
if got, want := string(reply.Value), "ack:ping"; got != want {
t.Fatalf("reply mismatch: got %q want %q", got, want)
}
list := server.GetLogicalConnList()
if len(list) != 1 {
t.Fatalf("logical conn count = %d, want 1", len(list))
}
if got := list[0].protectionModeSnapshot(); got != ProtectionExternal {
t.Fatalf("server steady mode = %v, want %v", got, ProtectionExternal)
}
}
func TestUseNestedSecurityConnectByConnKeepsNestedMode(t *testing.T) {
client := NewClient().(*ClientCommon)
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
if err := UseNestedSecurityServer(server, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
t.Fatalf("UseNestedSecurityServer failed: %v", err)
}
server.SetLink("nested-roundtrip", func(msg *Message) {
_ = msg.Reply([]byte("ack:" + string(msg.Value)))
})
})
if err := UseNestedSecurityClient(client, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
t.Fatalf("UseNestedSecurityClient failed: %v", err)
}
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionNested {
t.Fatalf("client bootstrap mode = %v, want %v", got, ProtectionNested)
}
left, right := net.Pipe()
defer right.Close()
bootstrapPeerAttachConnForTest(t, server, right)
if err := client.ConnectByConn(left); err != nil {
t.Fatalf("ConnectByConn failed: %v", err)
}
defer func() {
client.setByeFromServer(true)
_ = client.Stop()
}()
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionNested {
t.Fatalf("client steady mode = %v, want %v", got, ProtectionNested)
}
reply, err := client.SendWait("nested-roundtrip", []byte("ping"), time.Second)
if err != nil {
t.Fatalf("SendWait failed: %v", err)
}
if got, want := string(reply.Value), "ack:ping"; got != want {
t.Fatalf("reply mismatch: got %q want %q", got, want)
}
list := server.GetLogicalConnList()
if len(list) != 1 {
t.Fatalf("logical conn count = %d, want 1", len(list))
}
if got := list[0].protectionModeSnapshot(); got != ProtectionNested {
t.Fatalf("server steady mode = %v, want %v", got, ProtectionNested)
}
}
func TestUseLegacySecurityRoundTrip(t *testing.T) { func TestUseLegacySecurityRoundTrip(t *testing.T) {
client := NewClient() client := NewClient()
server := NewServer() server := NewServer()

View File

@ -34,6 +34,19 @@ type ServerCommon struct {
defaultFastBulkEncode transportFastBulkEncoder defaultFastBulkEncode transportFastBulkEncoder
defaultFastPlainEncode transportFastPlainEncoder defaultFastPlainEncode transportFastPlainEncoder
defaultModernPSKRuntime *modernPSKCodecRuntime defaultModernPSKRuntime *modernPSKCodecRuntime
peerAttachSecurity atomic.Pointer[peerAttachSecurityState]
securityBootstrap transportProtectionProfile
securitySteady transportProtectionProfile
securityAuthMode AuthMode
securityProtectionMode ProtectionMode
securityRequireForwardSecrecy bool
securityConfigured bool
peerAttachReplay peerAttachReplayCache
peerAttachExplicitCount atomic.Int64
peerAttachAuthFallbackCount atomic.Int64
peerAttachAuthRejectCount atomic.Int64
peerAttachDowngradeRejectCount atomic.Int64
peerAttachBindingRejectCount atomic.Int64
linkFns map[string]func(message *Message) linkFns map[string]func(message *Message)
defaultFns func(message *Message) defaultFns func(message *Message)
noFinSyncMsgMaxKeepSeconds int64 noFinSyncMsgMaxKeepSeconds int64
@ -93,6 +106,8 @@ func NewServer() Server {
server.defaultFns = func(message *Message) { server.defaultFns = func(message *Message) {
return return
} }
server.setServerDefaultTransportProtectionProfile(defaultTransportProtectionProfile())
server.peerAttachSecurity.Store(defaultPeerAttachSecurityState())
server.sessionRuntime.Store(newServerSessionRuntimeBase(server.stopCtx, server.stopFn)) server.sessionRuntime.Store(newServerSessionRuntimeBase(server.stopCtx, server.stopFn))
bindServerStreamControl(&server) bindServerStreamControl(&server)
bindServerBulkControl(&server) bindServerBulkControl(&server)

View File

@ -413,7 +413,7 @@ func serverBulkWriteSender(s *ServerCommon, logical *LogicalConn, transport *Tra
if err := bulk.waitDedicatedReady(ctx); err != nil { if err := bulk.waitDedicatedReady(ctx); err != nil {
return 0, err return 0, err
} }
return s.sendDedicatedBulkWrite(ctx, bulk.LogicalConn(), bulk, startSeq, payload) return s.sendDedicatedBulkWrite(ctx, bulk.LogicalConn(), bulk, startSeq, payload, payloadOwned)
} }
if transport == nil { if transport == nil {
return 0, errBulkTransportNil return 0, errBulkTransportNil

View File

@ -31,22 +31,28 @@ func (s *ServerCommon) Stop() error {
// Deprecated: SetDefaultCommEncode overrides the transport codec directly. // Deprecated: SetDefaultCommEncode overrides the transport codec directly.
// Prefer UseModernPSKServer or UseLegacySecurityServer. // Prefer UseModernPSKServer or UseLegacySecurityServer.
func (s *ServerCommon) SetDefaultCommEncode(fn func([]byte, []byte) []byte) { func (s *ServerCommon) SetDefaultCommEncode(fn func([]byte, []byte) []byte) {
s.defaultMsgEn = fn profile := transportProtectionProfile{
s.defaultFastStreamEncode = nil mode: ProtectionManaged,
s.defaultFastBulkEncode = nil msgEn: fn,
s.defaultFastPlainEncode = nil msgDe: s.defaultMsgDe,
s.defaultModernPSKRuntime = nil secretKey: s.SecretKey,
}
s.setServerDefaultTransportProtectionProfile(profile)
s.clearServerSecurityProfiles()
s.securityReadyCheck = false s.securityReadyCheck = false
} }
// Deprecated: SetDefaultCommDecode overrides the transport codec directly. // Deprecated: SetDefaultCommDecode overrides the transport codec directly.
// Prefer UseModernPSKServer or UseLegacySecurityServer. // Prefer UseModernPSKServer or UseLegacySecurityServer.
func (s *ServerCommon) SetDefaultCommDecode(fn func([]byte, []byte) []byte) { func (s *ServerCommon) SetDefaultCommDecode(fn func([]byte, []byte) []byte) {
s.defaultMsgDe = fn profile := transportProtectionProfile{
s.defaultFastStreamEncode = nil mode: ProtectionManaged,
s.defaultFastBulkEncode = nil msgEn: s.defaultMsgEn,
s.defaultFastPlainEncode = nil msgDe: fn,
s.defaultModernPSKRuntime = nil secretKey: s.SecretKey,
}
s.setServerDefaultTransportProtectionProfile(profile)
s.clearServerSecurityProfiles()
s.securityReadyCheck = false s.securityReadyCheck = false
} }
@ -98,14 +104,21 @@ func (s *ServerCommon) GetSecretKey() []byte {
// Deprecated: SetSecretKey injects a raw transport key directly. // Deprecated: SetSecretKey injects a raw transport key directly.
// Prefer UseModernPSKServer or UseLegacySecurityServer. // Prefer UseModernPSKServer or UseLegacySecurityServer.
func (s *ServerCommon) SetSecretKey(key []byte) { func (s *ServerCommon) SetSecretKey(key []byte) {
s.SecretKey = key profile := transportProtectionProfile{
if len(key) == 0 { mode: ProtectionManaged,
s.defaultModernPSKRuntime = nil msgEn: s.defaultMsgEn,
} else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil { msgDe: s.defaultMsgDe,
s.defaultModernPSKRuntime = runtime secretKey: cloneTransportProtectionKey(key),
} else {
s.defaultModernPSKRuntime = nil
} }
if len(key) == 0 {
profile.runtime = nil
} else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil {
profile.runtime = runtime
} else {
profile.runtime = nil
}
s.setServerDefaultTransportProtectionProfile(profile)
s.clearServerSecurityProfiles()
s.securityReadyCheck = len(key) == 0 s.securityReadyCheck = len(key) == 0
} }

View File

@ -3,12 +3,54 @@ package notify
import ( import (
"b612.me/stario" "b612.me/stario"
"context" "context"
"errors"
"math" "math"
"net" "net"
"testing" "testing"
"time" "time"
) )
func readServerEnvelopeFromConnWithProfile(t *testing.T, server *ServerCommon, profile transportProtectionProfile, conn net.Conn, timeout time.Duration) Envelope {
t.Helper()
queue := stario.NewQueue()
deadline := time.Now().Add(timeout)
buf := make([]byte, 4096)
for time.Now().Before(deadline) {
if err := conn.SetReadDeadline(deadline); err != nil {
t.Fatalf("SetReadDeadline failed: %v", err)
}
n, err := conn.Read(buf)
if n > 0 {
if parseErr := queue.ParseMessage(buf[:n], "server-inbound-profile"); parseErr != nil {
t.Fatalf("ParseMessage failed: %v", parseErr)
}
select {
case msg := <-queue.RestoreChan():
plain, decErr := decryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, msg.Msg)
if decErr != nil {
t.Fatalf("decryptTransportPayloadCodec failed: %v", decErr)
}
env, decErr := server.decodeEnvelopePlain(plain)
if decErr != nil {
t.Fatalf("decodeEnvelopePlain failed: %v", decErr)
}
return env
default:
}
}
if err == nil {
continue
}
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
break
}
t.Fatalf("conn Read failed: %v", err)
}
t.Fatal("timed out waiting for server envelope")
return Envelope{}
}
func TestMessageReplyUsesInboundConnForStaleTransport(t *testing.T) { func TestMessageReplyUsesInboundConnForStaleTransport(t *testing.T) {
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server) UseLegacySecurityServer(server)
@ -85,6 +127,64 @@ func TestMessageReplyUsesInboundConnForStaleTransport(t *testing.T) {
} }
} }
func TestMessageReplyUsesInboundProtectionSnapshotAfterLogicalSwitch(t *testing.T) {
secret := []byte("correct horse battery staple")
alternate, err := deriveModernPSKProtectionProfile([]byte("notify-reply-snapshot-alternate"), testModernPSKOptions(), ProtectionManaged)
if err != nil {
t.Fatalf("deriveModernPSKProtectionProfile failed: %v", err)
}
handlerErr := make(chan error, 1)
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
server.SetLink("reply-snapshot", func(msg *Message) {
if msg == nil || msg.LogicalConn == nil {
select {
case handlerErr <- errors.New("reply-snapshot logical is nil"):
default:
}
return
}
msg.LogicalConn.applyTransportProtectionProfile(alternate)
if err := msg.Reply([]byte("ack")); err != nil {
select {
case handlerErr <- err:
default:
}
}
})
})
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
left, right := net.Pipe()
defer right.Close()
bootstrapPeerAttachConnForTest(t, server, right)
if err := client.ConnectByConn(left); err != nil {
t.Fatalf("ConnectByConn failed: %v", err)
}
defer func() {
client.setByeFromServer(true)
_ = client.Stop()
}()
reply, err := client.SendWait("reply-snapshot", []byte("ping"), time.Second)
if err != nil {
t.Fatalf("SendWait failed: %v", err)
}
if got, want := string(reply.Value), "ack"; got != want {
t.Fatalf("reply value = %q, want %q", got, want)
}
select {
case err := <-handlerErr:
t.Fatalf("reply-snapshot handler failed: %v", err)
default:
}
}
func TestServerHandleReceivedSignalReliabilityUsesInboundConnForStaleTransport(t *testing.T) { func TestServerHandleReceivedSignalReliabilityUsesInboundConnForStaleTransport(t *testing.T) {
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server) UseLegacySecurityServer(server)

View File

@ -93,26 +93,34 @@ func (s *ServerCommon) pushTransportPayloadSourceFast(payload []byte, release fu
} }
return true return true
} }
if s.tryDispatchBorrowedBulkTransportPayload(source, payload) { logical, transport := s.resolveInboundSource(source)
if logical == nil {
if release != nil { if release != nil {
release() release()
} }
return true return true
} }
owned := append([]byte(nil), payload...) plain, plainRelease, err := s.decryptTransportPayloadLogicalPooled(logical, payload, release)
if release != nil { if err != nil {
release() if s.showError || s.debugMode {
fmt.Println("server decode transport payload error", err)
}
return true
}
inboundConn := serverInboundConn(source)
if s.tryDispatchBorrowedTransportPlain(logical, transport, inboundConn, plain, plainRelease) {
return true
}
owned := plain
if plainRelease != nil {
owned = append([]byte(nil), plain...)
plainRelease()
} }
s.wg.Add(1) s.wg.Add(1)
if !dispatcher.Dispatch(serverInboundDispatchSource(source), func() { if !dispatcher.Dispatch(serverInboundDispatchSource(source), func() {
defer s.wg.Done() defer s.wg.Done()
logical, transport := s.resolveInboundSource(source)
if logical == nil {
return
}
now := time.Now() now := time.Now()
inboundConn := serverInboundConn(source) if err := s.dispatchInboundTransportPlain(logical, transport, inboundConn, owned, now); err != nil && (s.showError || s.debugMode) {
if err := s.dispatchInboundTransportPayload(logical, transport, inboundConn, owned, now); err != nil && (s.showError || s.debugMode) {
fmt.Println("server decode envelope error", err) fmt.Println("server decode envelope error", err)
} }
}) { }) {

View File

@ -358,16 +358,20 @@ func (s *ServerCommon) sendEnvelopeTransport(transport *TransportConn, env Envel
} }
func (s *ServerCommon) sendEnvelopeInboundTransport(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope) error { func (s *ServerCommon) sendEnvelopeInboundTransport(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope) error {
return s.sendEnvelopeInboundTransportWithProfile(logical, transport, conn, nil, env)
}
func (s *ServerCommon) sendEnvelopeInboundTransportWithProfile(logical *LogicalConn, transport *TransportConn, conn net.Conn, profile *transportProtectionProfile, env Envelope) error {
if logical == nil && transport != nil { if logical == nil && transport != nil {
logical = transport.logicalConnSnapshot() logical = transport.logicalConnSnapshot()
} }
if logical == nil { if logical == nil {
return transportDetachedErrorForPeer(logical, transport) return transportDetachedErrorForPeer(logical, transport)
} }
if logical.msgEnSnapshot() == nil { if profile == nil && logical.msgEnSnapshot() == nil {
return transportDetachedErrorForPeer(logical, transport) return transportDetachedErrorForPeer(logical, transport)
} }
payload, err := s.encodeEnvelopePayloadLogical(logical, env) payload, err := s.encodeEnvelopePayloadInbound(logical, env, profile)
if err != nil { if err != nil {
return err return err
} }
@ -402,7 +406,18 @@ func (s *ServerCommon) writeControlEnvelopePayload(logical *LogicalConn, transpo
return sender.submit(payload, writeDeadlineFromTimeout(logical.maxWriteTimeoutSnapshot())) return sender.submit(payload, writeDeadlineFromTimeout(logical.maxWriteTimeoutSnapshot()))
} }
func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, msg TransferMsg) error { func (s *ServerCommon) encodeEnvelopePayloadInbound(logical *LogicalConn, env Envelope, profile *transportProtectionProfile) ([]byte, error) {
if profile == nil {
return s.encodeEnvelopePayloadLogical(logical, env)
}
data, err := s.encodeEnvelopePlain(env)
if err != nil {
return nil, err
}
return encryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgEn, profile.secretKey, data)
}
func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, profile *transportProtectionProfile, msg TransferMsg) error {
if logical == nil && transport != nil { if logical == nil && transport != nil {
logical = transport.logicalConnSnapshot() logical = transport.logicalConnSnapshot()
} }
@ -413,7 +428,7 @@ func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *Tran
if err != nil { if err != nil {
return err return err
} }
return s.sendEnvelopeInboundTransport(logical, transport, conn, env) return s.sendEnvelopeInboundTransportWithProfile(logical, transport, conn, profile, env)
} }
func (s *ServerCommon) writeEnvelopePayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte) error { func (s *ServerCommon) writeEnvelopePayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte) error {

View File

@ -110,7 +110,17 @@ func (s *ServerCommon) registerAcceptedLogical(logical *LogicalConn) *LogicalCon
} }
logical.setServer(s) logical.setServer(s)
logical.applyAttachmentProfile(s.maxReadTimeout, s.maxWriteTimeout, s.defaultMsgEn, s.defaultMsgDe, s.defaultFastStreamEncode, s.defaultFastBulkEncode, s.defaultFastPlainEncode, s.handshakeRsaKey, s.SecretKey) logical.applyAttachmentProfile(s.maxReadTimeout, s.maxWriteTimeout, s.defaultMsgEn, s.defaultMsgDe, s.defaultFastStreamEncode, s.defaultFastBulkEncode, s.defaultFastPlainEncode, s.handshakeRsaKey, s.SecretKey)
logical.updateAttachmentState(func(state *clientConnAttachmentState) {
state.authMode = s.securityAuthMode
state.peerAttached = false
state.peerAttachFallback = false
state.peerAttachAt = 0
})
if s.securityConfigured {
logical.applyTransportProtectionProfile(s.securityBootstrap)
} else {
logical.setModernPSKRuntime(s.defaultModernPSKRuntime) logical.setModernPSKRuntime(s.defaultModernPSKRuntime)
}
logical.markHeartbeatNow() logical.markHeartbeatNow()
return s.getPeerRegistry().registerLogical(logical) return s.getPeerRegistry().registerLogical(logical)
} }

View File

@ -26,6 +26,8 @@ type Server interface {
RecoverTransferSnapshots(context.Context) error RecoverTransferSnapshots(context.Context) error
SetBulkOpenTuning(BulkOpenTuning) SetBulkOpenTuning(BulkOpenTuning)
BulkOpenTuning() BulkOpenTuning BulkOpenTuning() BulkOpenTuning
SetPeerAttachSecurityConfig(PeerAttachSecurityConfig) error
PeerAttachSecurityConfig() PeerAttachSecurityConfig
SetFileReceiveDir(dir string) error SetFileReceiveDir(dir string) error
send(c *ClientConn, msg TransferMsg) (WaitMsg, error) send(c *ClientConn, msg TransferMsg) (WaitMsg, error)
sendEnvelope(c *ClientConn, env Envelope) error sendEnvelope(c *ClientConn, env Envelope) error

View File

@ -1,6 +1,7 @@
package notify package notify
import ( import (
"encoding/hex"
"errors" "errors"
"time" "time"
) )
@ -17,6 +18,19 @@ type ClientRuntimeSnapshot struct {
ConnectNetwork string ConnectNetwork string
ConnectAddress string ConnectAddress string
CanReconnect bool CanReconnect bool
AuthMode string
ProtectionMode string
ProtectionKeyMode string
ForwardSecrecyEnabled bool
ForwardSecrecyFallback bool
ForwardSecrecyRequired bool
TransportSessionID string
PeerAttachAuthenticated bool
PeerAttachAuthFallback bool
LastPeerAttachAt time.Time
PeerAttachRequireExplicitAuth bool
PeerAttachRequireChannelBinding bool
PeerAttachChannelBindingConfigured bool
BulkNetworkProfile string BulkNetworkProfile string
BulkDefaultMode string BulkDefaultMode string
BulkChunkSize int BulkChunkSize int
@ -49,6 +63,22 @@ type ServerRuntimeSnapshot struct {
HasRuntimeUDPListener bool HasRuntimeUDPListener bool
HasRuntimeQueue bool HasRuntimeQueue bool
HasRuntimeStopCtx bool HasRuntimeStopCtx bool
AuthMode string
ProtectionMode string
ForwardSecrecySupported bool
ForwardSecrecyRequired bool
PeerAttachRequireExplicitAuth bool
PeerAttachRequireChannelBinding bool
PeerAttachChannelBindingConfigured bool
PeerAttachReplayWindow time.Duration
PeerAttachReplayCapacity int
PeerAttachExplicitAuth int64
PeerAttachAuthFallbacks int64
PeerAttachAuthRejects int64
PeerAttachDowngradeRejects int64
PeerAttachBindingRejects int64
PeerAttachReplayRejects int64
PeerAttachReplayOverflowRejects int64
BulkChunkSize int BulkChunkSize int
BulkWindowBytes int BulkWindowBytes int
BulkMaxInFlight int BulkMaxInFlight int
@ -82,6 +112,15 @@ type ClientConnRuntimeSnapshot struct {
TransportDetachRemaining time.Duration TransportDetachRemaining time.Duration
TransportDetachExpired bool TransportDetachExpired bool
ReattachEligible bool ReattachEligible bool
AuthMode string
ProtectionMode string
ProtectionKeyMode string
ForwardSecrecyEnabled bool
ForwardSecrecyFallback bool
TransportSessionID string
PeerAttachAuthenticated bool
PeerAttachAuthFallback bool
LastPeerAttachAt time.Time
TransportBulkAdaptiveSoftPayloadBytes int TransportBulkAdaptiveSoftPayloadBytes int
TransportStreamAdaptiveSoftPayloadBytes int TransportStreamAdaptiveSoftPayloadBytes int
TransportStreamAdaptiveWaitThresholdBytes int TransportStreamAdaptiveWaitThresholdBytes int
@ -108,6 +147,19 @@ func (c *ClientCommon) clientRuntimeSnapshot() ClientRuntimeSnapshot {
snapshot.ConnectAddress = source.addr snapshot.ConnectAddress = source.addr
snapshot.CanReconnect = source.canReconnect() snapshot.CanReconnect = source.canReconnect()
} }
snapshot.AuthMode = authModeName(c.securityAuthMode)
snapshot.ProtectionMode = protectionModeName(c.securityProtectionMode)
protection := c.clientTransportProtectionSnapshot()
snapshot.ProtectionKeyMode = protection.keyMode
snapshot.ForwardSecrecyEnabled = protection.forwardSecrecy
snapshot.ForwardSecrecyFallback = protection.forwardSecrecyFallback
snapshot.ForwardSecrecyRequired = c.clientRequiresForwardSecrecy()
snapshot.TransportSessionID = hex.EncodeToString(protection.sessionID)
snapshot.PeerAttachAuthenticated, snapshot.PeerAttachAuthFallback, snapshot.LastPeerAttachAt = c.clientPeerAttachAuthSnapshot()
peerAttachCfg := c.peerAttachSecuritySnapshot()
snapshot.PeerAttachRequireExplicitAuth = peerAttachCfg.requireExplicitAuth
snapshot.PeerAttachRequireChannelBinding = peerAttachCfg.requireChannelBinding
snapshot.PeerAttachChannelBindingConfigured = peerAttachCfg.channelBinding != nil
snapshot.BulkNetworkProfile = bulkNetworkProfileName(c.BulkNetworkProfile()) snapshot.BulkNetworkProfile = bulkNetworkProfileName(c.BulkNetworkProfile())
snapshot.BulkDefaultMode = bulkOpenModeName(c.BulkDefaultOpenMode()) snapshot.BulkDefaultMode = bulkOpenModeName(c.BulkDefaultOpenMode())
tuning := c.BulkOpenTuning() tuning := c.BulkOpenTuning()
@ -161,6 +213,23 @@ func (s *ServerCommon) serverRuntimeSnapshot() ServerRuntimeSnapshot {
snapshot.HasRuntimeQueue = rt.queue != nil snapshot.HasRuntimeQueue = rt.queue != nil
snapshot.HasRuntimeStopCtx = rt.stopCtx != nil snapshot.HasRuntimeStopCtx = rt.stopCtx != nil
} }
snapshot.AuthMode = authModeName(s.securityAuthMode)
snapshot.ProtectionMode = protectionModeName(s.securityProtectionMode)
snapshot.ForwardSecrecySupported = s.serverSupportsForwardSecrecy()
snapshot.ForwardSecrecyRequired = s.serverRequiresForwardSecrecy()
peerAttachCfg := s.peerAttachSecuritySnapshot()
snapshot.PeerAttachRequireExplicitAuth = peerAttachCfg.requireExplicitAuth
snapshot.PeerAttachRequireChannelBinding = peerAttachCfg.requireChannelBinding
snapshot.PeerAttachChannelBindingConfigured = peerAttachCfg.channelBinding != nil
snapshot.PeerAttachReplayWindow = peerAttachCfg.replayWindow
snapshot.PeerAttachReplayCapacity = peerAttachCfg.replayCapacity
snapshot.PeerAttachExplicitAuth = s.peerAttachExplicitCount.Load()
snapshot.PeerAttachAuthFallbacks = s.peerAttachAuthFallbackCount.Load()
snapshot.PeerAttachAuthRejects = s.peerAttachAuthRejectCount.Load()
snapshot.PeerAttachDowngradeRejects = s.peerAttachDowngradeRejectCount.Load()
snapshot.PeerAttachBindingRejects = s.peerAttachBindingRejectCount.Load()
snapshot.PeerAttachReplayRejects = s.peerAttachReplayRejectCountSnapshot()
snapshot.PeerAttachReplayOverflowRejects = s.peerAttachReplayOverflowRejectCountSnapshot()
tuning := s.BulkOpenTuning() tuning := s.BulkOpenTuning()
snapshot.BulkChunkSize = tuning.ChunkSize snapshot.BulkChunkSize = tuning.ChunkSize
snapshot.BulkWindowBytes = tuning.WindowBytes snapshot.BulkWindowBytes = tuning.WindowBytes
@ -171,6 +240,7 @@ func (s *ServerCommon) serverRuntimeSnapshot() ServerRuntimeSnapshot {
func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot { func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot {
status := c.clientConnStatusSnapshot() status := c.clientConnStatusSnapshot()
attachment := c.clientConnAttachmentStateSnapshot()
now := time.Now() now := time.Now()
snapshot := ClientConnRuntimeSnapshot{ snapshot := ClientConnRuntimeSnapshot{
ClientID: c.clientConnIDSnapshot(), ClientID: c.clientConnIDSnapshot(),
@ -182,6 +252,17 @@ func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot {
TransportAttachCount: c.clientConnTransportAttachCountSnapshot(), TransportAttachCount: c.clientConnTransportAttachCountSnapshot(),
TransportDetachCount: c.clientConnTransportDetachCountSnapshot(), TransportDetachCount: c.clientConnTransportDetachCountSnapshot(),
LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(), LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(),
AuthMode: authModeName(attachment.authMode),
ProtectionMode: protectionModeName(attachment.protectionMode),
}
snapshot.PeerAttachAuthenticated = attachment.peerAttached
snapshot.PeerAttachAuthFallback = attachment.peerAttachFallback
snapshot.ProtectionKeyMode = attachment.keyMode
snapshot.ForwardSecrecyEnabled = attachment.forwardSecrecy
snapshot.ForwardSecrecyFallback = attachment.forwardSecrecyFallback
snapshot.TransportSessionID = hex.EncodeToString(attachment.sessionID)
if attachment.peerAttachAt != 0 {
snapshot.LastPeerAttachAt = time.Unix(0, attachment.peerAttachAt)
} }
if status.Err != nil { if status.Err != nil {
snapshot.Error = status.Err.Error() snapshot.Error = status.Err.Error()

View File

@ -37,6 +37,21 @@ func TestGetClientRuntimeSnapshotDefaults(t *testing.T) {
if snapshot.ConnectSource != "" || snapshot.ConnectNetwork != "" || snapshot.ConnectAddress != "" || snapshot.CanReconnect { if snapshot.ConnectSource != "" || snapshot.ConnectNetwork != "" || snapshot.ConnectAddress != "" || snapshot.CanReconnect {
t.Fatalf("unexpected default connect source snapshot: %+v", snapshot) t.Fatalf("unexpected default connect source snapshot: %+v", snapshot)
} }
if got, want := snapshot.AuthMode, "none"; got != want {
t.Fatalf("AuthMode mismatch: got %q want %q", got, want)
}
if got, want := snapshot.ProtectionMode, "managed"; got != want {
t.Fatalf("ProtectionMode mismatch: got %q want %q", got, want)
}
if snapshot.PeerAttachAuthenticated || snapshot.PeerAttachAuthFallback {
t.Fatalf("unexpected default peer attach state: %+v", snapshot)
}
if !snapshot.LastPeerAttachAt.IsZero() {
t.Fatalf("LastPeerAttachAt mismatch: got %v want zero", snapshot.LastPeerAttachAt)
}
if snapshot.PeerAttachRequireExplicitAuth || snapshot.PeerAttachRequireChannelBinding || snapshot.PeerAttachChannelBindingConfigured {
t.Fatalf("unexpected default peer attach policy: %+v", snapshot)
}
if got, want := snapshot.BulkNetworkProfile, "default"; got != want { if got, want := snapshot.BulkNetworkProfile, "default"; got != want {
t.Fatalf("BulkNetworkProfile mismatch: got %q want %q", got, want) t.Fatalf("BulkNetworkProfile mismatch: got %q want %q", got, want)
} }
@ -117,6 +132,24 @@ func TestGetServerRuntimeSnapshotDefaults(t *testing.T) {
if !snapshot.HasRuntimeStopCtx { if !snapshot.HasRuntimeStopCtx {
t.Fatalf("HasRuntimeStopCtx mismatch: got %v want true", snapshot.HasRuntimeStopCtx) t.Fatalf("HasRuntimeStopCtx mismatch: got %v want true", snapshot.HasRuntimeStopCtx)
} }
if got, want := snapshot.AuthMode, "none"; got != want {
t.Fatalf("AuthMode mismatch: got %q want %q", got, want)
}
if got, want := snapshot.ProtectionMode, "managed"; got != want {
t.Fatalf("ProtectionMode mismatch: got %q want %q", got, want)
}
if snapshot.PeerAttachRequireExplicitAuth || snapshot.PeerAttachRequireChannelBinding || snapshot.PeerAttachChannelBindingConfigured {
t.Fatalf("unexpected default peer attach policy: %+v", snapshot)
}
if got, want := snapshot.PeerAttachReplayWindow, peerAttachReplayTTL; got != want {
t.Fatalf("PeerAttachReplayWindow mismatch: got %s want %s", got, want)
}
if got, want := snapshot.PeerAttachReplayCapacity, defaultPeerAttachReplayCapacity; got != want {
t.Fatalf("PeerAttachReplayCapacity mismatch: got %d want %d", got, want)
}
if snapshot.PeerAttachExplicitAuth != 0 || snapshot.PeerAttachAuthFallbacks != 0 || snapshot.PeerAttachAuthRejects != 0 || snapshot.PeerAttachDowngradeRejects != 0 || snapshot.PeerAttachBindingRejects != 0 || snapshot.PeerAttachReplayRejects != 0 || snapshot.PeerAttachReplayOverflowRejects != 0 {
t.Fatalf("unexpected default peer attach counters: %+v", snapshot)
}
if got, want := snapshot.BulkChunkSize, defaultBulkChunkSize; got != want { if got, want := snapshot.BulkChunkSize, defaultBulkChunkSize; got != want {
t.Fatalf("BulkChunkSize mismatch: got %d want %d", got, want) t.Fatalf("BulkChunkSize mismatch: got %d want %d", got, want)
} }
@ -476,6 +509,134 @@ func TestGetClientConnRuntimeSnapshotExposesDetachState(t *testing.T) {
if snapshot.LastHeartbeatAt.IsZero() { if snapshot.LastHeartbeatAt.IsZero() {
t.Fatal("LastHeartbeatAt should be recorded") t.Fatal("LastHeartbeatAt should be recorded")
} }
if got, want := snapshot.AuthMode, "none"; got != want {
t.Fatalf("AuthMode mismatch: got %q want %q", got, want)
}
if got, want := snapshot.ProtectionMode, "managed"; got != want {
t.Fatalf("ProtectionMode mismatch: got %q want %q", got, want)
}
if snapshot.PeerAttachAuthenticated || snapshot.PeerAttachAuthFallback {
t.Fatalf("unexpected peer attach state: %+v", snapshot)
}
if !snapshot.LastPeerAttachAt.IsZero() {
t.Fatalf("LastPeerAttachAt mismatch: got %v want zero", snapshot.LastPeerAttachAt)
}
}
func TestGetRuntimeSnapshotsIncludePeerAttachSecurityState(t *testing.T) {
secret := []byte("correct horse battery staple")
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
if err := UsePSKOverExternalTransportServer(server, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
}
})
client := NewClient().(*ClientCommon)
if err := UsePSKOverExternalTransportClient(client, secret, testModernPSKOptions()); err != nil {
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
}
left, right := net.Pipe()
defer right.Close()
bootstrapPeerAttachLogicalForTest(t, server, right)
if err := client.ConnectByConn(left); err != nil {
t.Fatalf("ConnectByConn failed: %v", err)
}
defer func() {
client.setByeFromServer(true)
_ = client.Stop()
}()
deadline := time.Now().Add(time.Second)
for {
logical := server.GetLogicalConn(client.peerIdentity)
if logical != nil {
authenticated, fallback, _ := logical.peerAttachAuthenticatedSnapshot()
if authenticated && !fallback && logical.protectionModeSnapshot() == ProtectionExternal && server.peerAttachExplicitCount.Load() == 1 {
break
}
}
if time.Now().After(deadline) {
t.Fatal("peer attach security state did not converge before snapshot")
}
time.Sleep(time.Millisecond)
}
clientSnapshot, err := GetClientRuntimeSnapshot(client)
if err != nil {
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
}
if got, want := clientSnapshot.AuthMode, "psk"; got != want {
t.Fatalf("client AuthMode mismatch: got %q want %q", got, want)
}
if got, want := clientSnapshot.ProtectionMode, "external"; got != want {
t.Fatalf("client ProtectionMode mismatch: got %q want %q", got, want)
}
if clientSnapshot.PeerAttachRequireExplicitAuth || clientSnapshot.PeerAttachRequireChannelBinding || clientSnapshot.PeerAttachChannelBindingConfigured {
t.Fatalf("unexpected client peer attach policy snapshot: %+v", clientSnapshot)
}
if !clientSnapshot.PeerAttachAuthenticated || clientSnapshot.PeerAttachAuthFallback {
t.Fatalf("unexpected client peer attach state: %+v", clientSnapshot)
}
if clientSnapshot.LastPeerAttachAt.IsZero() {
t.Fatal("client LastPeerAttachAt should be recorded")
}
serverSnapshot, err := GetServerRuntimeSnapshot(server)
if err != nil {
t.Fatalf("GetServerRuntimeSnapshot failed: %v", err)
}
if got, want := serverSnapshot.AuthMode, "psk"; got != want {
t.Fatalf("server AuthMode mismatch: got %q want %q", got, want)
}
if got, want := serverSnapshot.ProtectionMode, "external"; got != want {
t.Fatalf("server ProtectionMode mismatch: got %q want %q", got, want)
}
if serverSnapshot.PeerAttachRequireExplicitAuth || serverSnapshot.PeerAttachRequireChannelBinding || serverSnapshot.PeerAttachChannelBindingConfigured {
t.Fatalf("unexpected server peer attach policy snapshot: %+v", serverSnapshot)
}
if got, want := serverSnapshot.PeerAttachExplicitAuth, int64(1); got != want {
t.Fatalf("PeerAttachExplicitAuth mismatch: got %d want %d", got, want)
}
if serverSnapshot.PeerAttachAuthFallbacks != 0 || serverSnapshot.PeerAttachAuthRejects != 0 || serverSnapshot.PeerAttachDowngradeRejects != 0 || serverSnapshot.PeerAttachBindingRejects != 0 || serverSnapshot.PeerAttachReplayRejects != 0 || serverSnapshot.PeerAttachReplayOverflowRejects != 0 {
t.Fatalf("unexpected server peer attach counters: %+v", serverSnapshot)
}
logical := server.GetLogicalConn(client.peerIdentity)
if logical == nil {
t.Fatal("server logical should exist after peer attach")
}
logicalSnapshot, err := GetLogicalConnRuntimeSnapshot(logical)
if err != nil {
t.Fatalf("GetLogicalConnRuntimeSnapshot failed: %v", err)
}
if got, want := logicalSnapshot.AuthMode, "psk"; got != want {
t.Fatalf("logical AuthMode mismatch: got %q want %q", got, want)
}
if got, want := logicalSnapshot.ProtectionMode, "external"; got != want {
t.Fatalf("logical ProtectionMode mismatch: got %q want %q", got, want)
}
if !logicalSnapshot.PeerAttachAuthenticated || logicalSnapshot.PeerAttachAuthFallback {
t.Fatalf("unexpected logical peer attach state: %+v", logicalSnapshot)
}
if logicalSnapshot.LastPeerAttachAt.IsZero() {
t.Fatal("logical LastPeerAttachAt should be recorded")
}
clientConnSnapshot, err := GetClientConnRuntimeSnapshot(clientConnFromLogical(logical))
if err != nil {
t.Fatalf("GetClientConnRuntimeSnapshot failed: %v", err)
}
if got, want := clientConnSnapshot.AuthMode, "psk"; got != want {
t.Fatalf("client conn AuthMode mismatch: got %q want %q", got, want)
}
if got, want := clientConnSnapshot.ProtectionMode, "external"; got != want {
t.Fatalf("client conn ProtectionMode mismatch: got %q want %q", got, want)
}
if !clientConnSnapshot.PeerAttachAuthenticated || clientConnSnapshot.PeerAttachAuthFallback {
t.Fatalf("unexpected client conn peer attach state: %+v", clientConnSnapshot)
}
if clientConnSnapshot.LastPeerAttachAt.IsZero() {
t.Fatal("client conn LastPeerAttachAt should be recorded")
}
} }
func TestGetServerDetachedClientRuntimeSnapshotsFiltersAndSorts(t *testing.T) { func TestGetServerDetachedClientRuntimeSnapshotsFiltersAndSorts(t *testing.T) {

104
stream.go
View File

@ -89,6 +89,60 @@ type streamCloseSender func(context.Context, *streamHandle, bool) error
type streamResetSender func(context.Context, *streamHandle, string) error type streamResetSender func(context.Context, *streamHandle, string) error
type streamDataSender func(context.Context, *streamHandle, []byte) error type streamDataSender func(context.Context, *streamHandle, []byte) error
type streamReadChunk struct {
data []byte
release func()
}
func (c *streamReadChunk) clear() {
if c == nil {
return
}
if c.release != nil {
c.release()
}
c.data = nil
c.release = nil
}
type streamReadPayloadOwner struct {
refs atomic.Int32
release func()
}
func newStreamReadPayloadOwner(release func()) *streamReadPayloadOwner {
if release == nil {
return nil
}
owner := &streamReadPayloadOwner{release: release}
owner.refs.Store(1)
return owner
}
func (o *streamReadPayloadOwner) retainChunk() func() {
if o == nil {
return nil
}
o.refs.Add(1)
return o.releaseChunk
}
func (o *streamReadPayloadOwner) releaseChunk() {
if o == nil {
return
}
if o.refs.Add(-1) == 0 && o.release != nil {
o.release()
}
}
func (o *streamReadPayloadOwner) done() {
if o == nil {
return
}
o.releaseChunk()
}
type streamHandle struct { type streamHandle struct {
runtime *streamRuntime runtime *streamRuntime
runtimeScope string runtimeScope string
@ -124,8 +178,8 @@ type streamHandle struct {
remoteClosed bool remoteClosed bool
peerReadClosed bool peerReadClosed bool
resetErr error resetErr error
readQueue [][]byte readQueue []streamReadChunk
readBuf []byte readBuf streamReadChunk
bufferedBytes int bufferedBytes int
readNotify chan struct{} readNotify chan struct{}
readDeadline time.Time readDeadline time.Time
@ -339,20 +393,23 @@ func (s *streamHandle) Read(p []byte) (int, error) {
for { for {
s.mu.Lock() s.mu.Lock()
localReadClosed := s.localReadClosed localReadClosed := s.localReadClosed
if len(s.readBuf) > 0 { if len(s.readBuf.data) > 0 {
n := copy(p, s.readBuf) n := copy(p, s.readBuf.data)
s.readBuf = s.readBuf[n:] s.readBuf.data = s.readBuf.data[n:]
s.bufferedBytes -= n s.bufferedBytes -= n
if s.bufferedBytes < 0 { if s.bufferedBytes < 0 {
s.bufferedBytes = 0 s.bufferedBytes = 0
} }
if len(s.readBuf.data) == 0 {
s.readBuf.clear()
}
s.recordReadLocked(n, time.Now()) s.recordReadLocked(n, time.Now())
s.mu.Unlock() s.mu.Unlock()
return n, nil return n, nil
} }
if len(s.readQueue) > 0 { if len(s.readQueue) > 0 {
s.readBuf = s.readQueue[0] s.readBuf = s.readQueue[0]
s.readQueue[0] = nil s.readQueue[0] = streamReadChunk{}
s.readQueue = s.readQueue[1:] s.readQueue = s.readQueue[1:]
s.mu.Unlock() s.mu.Unlock()
continue continue
@ -824,43 +881,61 @@ func (s *streamHandle) pushOwnedChunk(chunk []byte) error {
return s.pushChunkWithOwnership(chunk, true) return s.pushChunkWithOwnership(chunk, true)
} }
func (s *streamHandle) pushOwnedChunkWithRelease(chunk []byte, release func()) error {
return s.pushChunkWithOwnershipAndRelease(chunk, true, release)
}
func (s *streamHandle) pushChunkWithOwnership(chunk []byte, owned bool) error { func (s *streamHandle) pushChunkWithOwnership(chunk []byte, owned bool) error {
return s.pushChunkWithOwnershipAndRelease(chunk, owned, nil)
}
func (s *streamHandle) pushChunkWithOwnershipAndRelease(chunk []byte, owned bool, release func()) error {
if s == nil { if s == nil {
return io.ErrClosedPipe return io.ErrClosedPipe
} }
if len(chunk) == 0 { if len(chunk) == 0 {
if release != nil {
release()
}
return nil return nil
} }
stored := chunk stored := streamReadChunk{data: chunk, release: release}
if !owned { if !owned {
stored = append([]byte(nil), chunk...) stored.data = append([]byte(nil), chunk...)
if stored.release != nil {
stored.release()
stored.release = nil
}
} }
s.mu.Lock() s.mu.Lock()
if s.resetErr != nil { if s.resetErr != nil {
err := s.resetErr err := s.resetErr
s.mu.Unlock() s.mu.Unlock()
stored.clear()
return err return err
} }
if s.inboundQueueLimit > 0 && s.bufferedChunkCountLocked() >= s.inboundQueueLimit { if s.inboundQueueLimit > 0 && s.bufferedChunkCountLocked() >= s.inboundQueueLimit {
err := s.markResetLocked(errStreamBackpressureExceeded) err := s.markResetLocked(errStreamBackpressureExceeded)
s.mu.Unlock() s.mu.Unlock()
stored.clear()
s.notifyReadable() s.notifyReadable()
s.finalize() s.finalize()
return err return err
} }
if s.inboundBytesLimit > 0 && s.bufferedBytes+len(stored) > s.inboundBytesLimit { if s.inboundBytesLimit > 0 && s.bufferedBytes+len(stored.data) > s.inboundBytesLimit {
err := s.markResetLocked(errStreamBackpressureExceeded) err := s.markResetLocked(errStreamBackpressureExceeded)
s.mu.Unlock() s.mu.Unlock()
stored.clear()
s.notifyReadable() s.notifyReadable()
s.finalize() s.finalize()
return err return err
} }
if len(s.readBuf) == 0 && len(s.readQueue) == 0 { if len(s.readBuf.data) == 0 && len(s.readQueue) == 0 {
s.readBuf = stored s.readBuf = stored
} else { } else {
s.readQueue = append(s.readQueue, stored) s.readQueue = append(s.readQueue, stored)
} }
s.bufferedBytes += len(stored) s.bufferedBytes += len(stored.data)
s.notifyReadableLocked() s.notifyReadableLocked()
s.mu.Unlock() s.mu.Unlock()
return nil return nil
@ -881,11 +956,12 @@ func (s *streamHandle) clearBufferedDataLocked() {
if s == nil { if s == nil {
return return
} }
s.readBuf.clear()
for i := range s.readQueue { for i := range s.readQueue {
s.readQueue[i] = nil s.readQueue[i].clear()
} }
s.readQueue = nil s.readQueue = nil
s.readBuf = nil s.readBuf = streamReadChunk{}
s.bufferedBytes = 0 s.bufferedBytes = 0
} }
@ -894,7 +970,7 @@ func (s *streamHandle) bufferedChunkCountLocked() int {
return 0 return 0
} }
count := len(s.readQueue) count := len(s.readQueue)
if len(s.readBuf) > 0 { if len(s.readBuf.data) > 0 {
count++ count++
} }
return count return count

View File

@ -56,7 +56,59 @@ func BenchmarkStreamTCPThroughput(b *testing.B) {
for _, tc := range cases { for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) { b.Run(tc.name, func(b *testing.B) {
benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg) benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg, benchmarkTransportSecurityModernPSK)
})
}
}
func BenchmarkStreamTCPThroughputTrustedRaw(b *testing.B) {
cases := []struct {
name string
payloadSize int
cfg StreamConfig
}{
{
name: "default_64KiB",
payloadSize: 64 * 1024,
},
{
name: "tuned_256KiB",
payloadSize: 256 * 1024,
cfg: StreamConfig{
ChunkSize: 256 * 1024,
InboundQueueLimit: 256,
InboundBufferedBytesLimit: 32 * 1024 * 1024,
OutboundWindowBytes: 8 * 1024 * 1024,
OutboundMaxInFlightChunks: 32,
},
},
{
name: "tuned_512KiB",
payloadSize: 512 * 1024,
cfg: StreamConfig{
ChunkSize: 512 * 1024,
InboundQueueLimit: 256,
InboundBufferedBytesLimit: 64 * 1024 * 1024,
OutboundWindowBytes: 16 * 1024 * 1024,
OutboundMaxInFlightChunks: 32,
},
},
{
name: "tuned_1MiB",
payloadSize: 1024 * 1024,
cfg: StreamConfig{
ChunkSize: 1024 * 1024,
InboundQueueLimit: 256,
InboundBufferedBytesLimit: 64 * 1024 * 1024,
OutboundWindowBytes: 16 * 1024 * 1024,
OutboundMaxInFlightChunks: 32,
},
},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg, benchmarkTransportSecurityTrustedRaw)
}) })
} }
} }
@ -108,19 +160,69 @@ func BenchmarkStreamTCPThroughputConcurrent(b *testing.B) {
for _, tc := range cases { for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) { b.Run(tc.name, func(b *testing.B) {
benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg) benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg, benchmarkTransportSecurityModernPSK)
}) })
} }
} }
func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfig) { func BenchmarkStreamTCPThroughputConcurrentTrustedRaw(b *testing.B) {
cases := []struct {
name string
payloadSize int
concurrency int
cfg StreamConfig
}{
{
name: "streams_2_512KiB",
payloadSize: 512 * 1024,
concurrency: 2,
cfg: StreamConfig{
ChunkSize: 512 * 1024,
InboundQueueLimit: 512,
InboundBufferedBytesLimit: 128 * 1024 * 1024,
OutboundWindowBytes: 32 * 1024 * 1024,
OutboundMaxInFlightChunks: 64,
},
},
{
name: "streams_4_512KiB",
payloadSize: 512 * 1024,
concurrency: 4,
cfg: StreamConfig{
ChunkSize: 512 * 1024,
InboundQueueLimit: 1024,
InboundBufferedBytesLimit: 256 * 1024 * 1024,
OutboundWindowBytes: 64 * 1024 * 1024,
OutboundMaxInFlightChunks: 128,
},
},
{
name: "streams_8_512KiB",
payloadSize: 512 * 1024,
concurrency: 8,
cfg: StreamConfig{
ChunkSize: 512 * 1024,
InboundQueueLimit: 2048,
InboundBufferedBytesLimit: 512 * 1024 * 1024,
OutboundWindowBytes: 128 * 1024 * 1024,
OutboundMaxInFlightChunks: 256,
},
},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg, benchmarkTransportSecurityTrustedRaw)
})
}
}
func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfig, securityMode benchmarkTransportSecurityMode) {
b.Helper() b.Helper()
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
server.SetStreamConfig(cfg) server.SetStreamConfig(cfg)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyServerTransportSecurity(b, server, securityMode)
b.Fatalf("UseModernPSKServer failed: %v", err)
}
acceptCh := make(chan StreamAcceptInfo, 1) acceptCh := make(chan StreamAcceptInfo, 1)
server.SetStreamHandler(func(info StreamAcceptInfo) error { server.SetStreamHandler(func(info StreamAcceptInfo) error {
@ -137,9 +239,7 @@ func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfi
client := NewClient().(*ClientCommon) client := NewClient().(*ClientCommon)
client.SetStreamConfig(cfg) client.SetStreamConfig(cfg)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyClientTransportSecurity(b, client, securityMode)
b.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil { if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
b.Fatalf("client Connect failed: %v", err) b.Fatalf("client Connect failed: %v", err)
} }
@ -198,7 +298,7 @@ func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfi
_ = stream.Close() _ = stream.Close()
} }
func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, cfg StreamConfig) { func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, cfg StreamConfig, securityMode benchmarkTransportSecurityMode) {
b.Helper() b.Helper()
if concurrency <= 0 { if concurrency <= 0 {
b.Fatal("concurrency must be > 0") b.Fatal("concurrency must be > 0")
@ -206,9 +306,7 @@ func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concu
server := NewServer().(*ServerCommon) server := NewServer().(*ServerCommon)
server.SetStreamConfig(cfg) server.SetStreamConfig(cfg)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyServerTransportSecurity(b, server, securityMode)
b.Fatalf("UseModernPSKServer failed: %v", err)
}
acceptCh := make(chan StreamAcceptInfo, concurrency*2) acceptCh := make(chan StreamAcceptInfo, concurrency*2)
server.SetStreamHandler(func(info StreamAcceptInfo) error { server.SetStreamHandler(func(info StreamAcceptInfo) error {
@ -225,9 +323,7 @@ func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concu
client := NewClient().(*ClientCommon) client := NewClient().(*ClientCommon)
client.SetStreamConfig(cfg) client.SetStreamConfig(cfg)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { benchmarkApplyClientTransportSecurity(b, client, securityMode)
b.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil { if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
b.Fatalf("client Connect failed: %v", err) b.Fatalf("client Connect failed: %v", err)
} }

View File

@ -0,0 +1,85 @@
package notify
import (
"context"
"errors"
"testing"
"time"
)
func TestStreamOwnedChunkReleaseAfterRead(t *testing.T) {
stream := newStreamHandle(context.Background(), newStreamRuntime("stream-buffer-release-read"), clientFileScope(), StreamOpenRequest{
StreamID: "stream-buffer-release-read",
DataID: 1,
}, 0, nil, nil, 0, nil, nil, nil, streamConfig{})
released := 0
if err := stream.pushOwnedChunkWithRelease([]byte("hello"), func() {
released++
}); err != nil {
t.Fatalf("pushOwnedChunkWithRelease failed: %v", err)
}
buf := make([]byte, 5)
n, err := stream.Read(buf)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if n != 5 || string(buf[:n]) != "hello" {
t.Fatalf("Read = %d %q, want 5 hello", n, string(buf[:n]))
}
if released != 1 {
t.Fatalf("release count = %d, want 1", released)
}
}
func TestStreamOwnedChunkReleaseOnReset(t *testing.T) {
stream := newStreamHandle(context.Background(), newStreamRuntime("stream-buffer-release-reset"), clientFileScope(), StreamOpenRequest{
StreamID: "stream-buffer-release-reset",
DataID: 1,
}, 0, nil, nil, 0, nil, nil, nil, streamConfig{})
released := 0
if err := stream.pushOwnedChunkWithRelease([]byte("hello"), func() {
released++
}); err != nil {
t.Fatalf("pushOwnedChunkWithRelease failed: %v", err)
}
stream.markReset(errors.New("boom"))
if released != 1 {
t.Fatalf("release count = %d, want 1", released)
}
}
func TestClientDispatchFastStreamDataWithOwnerReleasesAfterRead(t *testing.T) {
client := NewClient().(*ClientCommon)
runtime := client.getStreamRuntime()
if runtime == nil {
t.Fatal("client stream runtime should not be nil")
}
stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
StreamID: "stream-owner",
DataID: 23,
Channel: StreamDataChannel,
}, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot())
if err := runtime.register(clientFileScope(), stream); err != nil {
t.Fatalf("register stream failed: %v", err)
}
released := 0
owner := newStreamReadPayloadOwner(func() {
released++
})
client.dispatchFastStreamDataWithOwner(streamFastDataFrame{
DataID: 23,
Seq: 1,
Payload: []byte("fast-owner"),
}, owner)
owner.done()
readStreamExactly(t, stream, "fast-owner", 2*time.Second)
if released != 1 {
t.Fatalf("release count = %d, want 1", released)
}
}

View File

@ -83,6 +83,10 @@ func (s *ServerCommon) dispatchStreamEnvelope(logical *LogicalConn, transport *T
} }
func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) { func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) {
c.dispatchFastStreamDataWithOwner(frame, nil)
}
func (c *ClientCommon) dispatchFastStreamDataWithOwner(frame streamFastDataFrame, owner *streamReadPayloadOwner) {
if frame.DataID == 0 { if frame.DataID == 0 {
return return
} }
@ -107,7 +111,13 @@ func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) {
c.bestEffortRejectInboundStreamData(stream.ID(), frame.DataID, detachErr.Error()) c.bestEffortRejectInboundStreamData(stream.ID(), frame.DataID, detachErr.Error())
return return
} }
if err := stream.pushOwnedChunk(frame.Payload); err != nil { var err error
if owner != nil {
err = stream.pushOwnedChunkWithRelease(frame.Payload, owner.retainChunk())
} else {
err = stream.pushOwnedChunk(frame.Payload)
}
if err != nil {
if c.showError || c.debugMode { if c.showError || c.debugMode {
fmt.Println("client stream push chunk error", err) fmt.Println("client stream push chunk error", err)
} }
@ -118,6 +128,10 @@ func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) {
} }
func (s *ServerCommon) dispatchFastStreamData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame streamFastDataFrame) { func (s *ServerCommon) dispatchFastStreamData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame streamFastDataFrame) {
s.dispatchFastStreamDataWithOwner(logical, transport, conn, frame, nil)
}
func (s *ServerCommon) dispatchFastStreamDataWithOwner(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame streamFastDataFrame, owner *streamReadPayloadOwner) {
if logical == nil || frame.DataID == 0 { if logical == nil || frame.DataID == 0 {
return return
} }
@ -141,7 +155,13 @@ func (s *ServerCommon) dispatchFastStreamData(logical *LogicalConn, transport *T
s.bestEffortRejectInboundStreamData(logical, transport, conn, stream.ID(), frame.DataID, detachErr.Error()) s.bestEffortRejectInboundStreamData(logical, transport, conn, stream.ID(), frame.DataID, detachErr.Error())
return return
} }
if err := stream.pushOwnedChunk(frame.Payload); err != nil { var err error
if owner != nil {
err = stream.pushOwnedChunkWithRelease(frame.Payload, owner.retainChunk())
} else {
err = stream.pushOwnedChunk(frame.Payload)
}
if err != nil {
if s.showError || s.debugMode { if s.showError || s.debugMode {
fmt.Println("server stream push chunk error", err) fmt.Println("server stream push chunk error", err)
} }

View File

@ -156,11 +156,12 @@ func decodeStreamFastDataFrame(payload []byte) (streamFastDataFrame, bool, error
} }
func (c *ClientCommon) encodeFastStreamPayload(frame streamFastDataFrame) ([]byte, error) { func (c *ClientCommon) encodeFastStreamPayload(frame streamFastDataFrame) ([]byte, error) {
if c != nil && c.fastStreamEncode != nil && frame.Flags == 0 { profile := c.clientTransportProtectionSnapshot()
return c.fastStreamEncode(c.SecretKey, frame.DataID, frame.Seq, frame.Payload) if c != nil && profile.fastStreamEncode != nil && frame.Flags == 0 {
return profile.fastStreamEncode(profile.secretKey, frame.DataID, frame.Seq, frame.Payload)
} }
if c != nil && c.fastPlainEncode != nil { if c != nil && profile.fastPlainEncode != nil {
return encodeStreamFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame) return encodeStreamFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame)
} }
plain, err := encodeStreamFastFramePayload(frame) plain, err := encodeStreamFastFramePayload(frame)
if err != nil { if err != nil {
@ -181,8 +182,9 @@ func (c *ClientCommon) encodeFastStreamBatchPayload(frames []streamFastDataFrame
if c == nil { if c == nil {
return nil, errStreamClientNil return nil, errStreamClientNil
} }
if c.fastPlainEncode != nil { profile := c.clientTransportProtectionSnapshot()
return encodeStreamFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames) if profile.fastPlainEncode != nil {
return encodeStreamFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames)
} }
plain, err := encodeStreamFastBatchPlain(frames) plain, err := encodeStreamFastBatchPlain(frames)
if err != nil { if err != nil {

View File

@ -91,25 +91,24 @@ func writeStreamFastBatchPlain(dst []byte, frames []streamFastDataFrame) error {
return nil return nil
} }
func decodeStreamFastBatchPlain(payload []byte) ([]streamFastDataFrame, bool, error) { func walkStreamFastBatchPlain(payload []byte, fn func(streamFastDataFrame) error) (bool, error) {
if len(payload) < 4 || string(payload[:4]) != streamFastBatchMagic { if len(payload) < 4 || string(payload[:4]) != streamFastBatchMagic {
return nil, false, nil return false, nil
} }
if len(payload) < streamFastBatchHeaderLen { if len(payload) < streamFastBatchHeaderLen {
return nil, true, errStreamFastPayloadInvalid return true, errStreamFastPayloadInvalid
} }
if payload[4] != streamFastBatchVersion { if payload[4] != streamFastBatchVersion {
return nil, true, errStreamFastPayloadInvalid return true, errStreamFastPayloadInvalid
} }
count := int(binary.BigEndian.Uint32(payload[8:12])) count := int(binary.BigEndian.Uint32(payload[8:12]))
if count <= 0 { if count <= 0 {
return nil, true, errStreamFastPayloadInvalid return true, errStreamFastPayloadInvalid
} }
frames := make([]streamFastDataFrame, 0, count)
offset := streamFastBatchHeaderLen offset := streamFastBatchHeaderLen
for index := 0; index < count; index++ { for index := 0; index < count; index++ {
if len(payload)-offset < streamFastBatchItemHeaderLen { if len(payload)-offset < streamFastBatchItemHeaderLen {
return nil, true, errStreamFastPayloadInvalid return true, errStreamFastPayloadInvalid
} }
flags := payload[offset] flags := payload[offset]
dataID := binary.BigEndian.Uint64(payload[offset+4 : offset+12]) dataID := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
@ -117,29 +116,62 @@ func decodeStreamFastBatchPlain(payload []byte) ([]streamFastDataFrame, bool, er
payloadLen := int(binary.BigEndian.Uint32(payload[offset+20 : offset+24])) payloadLen := int(binary.BigEndian.Uint32(payload[offset+20 : offset+24]))
offset += streamFastBatchItemHeaderLen offset += streamFastBatchItemHeaderLen
if dataID == 0 || payloadLen < 0 || len(payload)-offset < payloadLen { if dataID == 0 || payloadLen < 0 || len(payload)-offset < payloadLen {
return nil, true, errStreamFastPayloadInvalid return true, errStreamFastPayloadInvalid
} }
frames = append(frames, streamFastDataFrame{ if fn != nil {
if err := fn(streamFastDataFrame{
Flags: flags, Flags: flags,
DataID: dataID, DataID: dataID,
Seq: seq, Seq: seq,
Payload: payload[offset : offset+payloadLen], Payload: payload[offset : offset+payloadLen],
}) }); err != nil {
return true, err
}
}
offset += payloadLen offset += payloadLen
} }
if offset != len(payload) { if offset != len(payload) {
return nil, true, errStreamFastPayloadInvalid return true, errStreamFastPayloadInvalid
}
return true, nil
}
func decodeStreamFastBatchPlain(payload []byte) ([]streamFastDataFrame, bool, error) {
frames := make([]streamFastDataFrame, 0, 1)
matched, err := walkStreamFastBatchPlain(payload, func(frame streamFastDataFrame) error {
frames = append(frames, frame)
return nil
})
if !matched || err != nil {
return nil, matched, err
} }
return frames, true, nil return frames, true, nil
} }
func decodeStreamFastDataFrames(payload []byte) ([]streamFastDataFrame, bool, error) { func walkStreamFastFrames(payload []byte, fn func(streamFastDataFrame) error) (bool, error) {
if frames, matched, err := decodeStreamFastBatchPlain(payload); matched { if matched, err := walkStreamFastBatchPlain(payload, fn); matched {
return frames, true, err return true, err
} }
frame, matched, err := decodeStreamFastDataFrame(payload) frame, matched, err := decodeStreamFastDataFrame(payload)
if !matched || err != nil {
return matched, err
}
if fn != nil {
if err := fn(frame); err != nil {
return true, err
}
}
return true, nil
}
func decodeStreamFastDataFrames(payload []byte) ([]streamFastDataFrame, bool, error) {
frames := make([]streamFastDataFrame, 0, 1)
matched, err := walkStreamFastFrames(payload, func(frame streamFastDataFrame) error {
frames = append(frames, frame)
return nil
})
if !matched || err != nil { if !matched || err != nil {
return nil, matched, err return nil, matched, err
} }
return []streamFastDataFrame{frame}, true, nil return frames, true, nil
} }

View File

@ -9,7 +9,10 @@ var (
errTransportPayloadDecryptFailed = errors.New("transport payload decrypt failed") errTransportPayloadDecryptFailed = errors.New("transport payload decrypt failed")
) )
func encryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgEn func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) { func encryptTransportPayloadCodec(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgEn func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) {
if mode == ProtectionExternal {
return data, nil
}
if runtime != nil { if runtime != nil {
encoded, err := runtime.sealPlainPayload(data) encoded, err := runtime.sealPlainPayload(data)
if err != nil { if err != nil {
@ -27,7 +30,10 @@ func encryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgEn func([]b
return encoded, nil return encoded, nil
} }
func decryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) { func decryptTransportPayloadCodec(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) {
if mode == ProtectionExternal {
return data, nil
}
if runtime != nil { if runtime != nil {
plain, err := runtime.openPayload(data) plain, err := runtime.openPayload(data)
if err != nil { if err != nil {
@ -45,7 +51,10 @@ func decryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgDe func([]b
return plain, nil return plain, nil
} }
func decryptTransportPayloadCodecPooled(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte, release func()) ([]byte, func(), error) { func decryptTransportPayloadCodecPooled(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte, release func()) ([]byte, func(), error) {
if mode == ProtectionExternal {
return data, release, nil
}
if runtime != nil { if runtime != nil {
plain, plainRelease, err := runtime.openPayloadPooled(data, release) plain, plainRelease, err := runtime.openPayloadPooled(data, release)
if err != nil { if err != nil {
@ -69,7 +78,10 @@ func decryptTransportPayloadCodecPooled(runtime *modernPSKCodecRuntime, msgDe fu
return plain, nil, nil return plain, nil, nil
} }
func decryptTransportPayloadCodecOwnedPooled(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, func(), error) { func decryptTransportPayloadCodecOwnedPooled(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, func(), error) {
if mode == ProtectionExternal {
return data, nil, nil
}
if runtime != nil { if runtime != nil {
plain, plainRelease, err := runtime.openPayloadOwnedPooled(data) plain, plainRelease, err := runtime.openPayloadOwnedPooled(data)
if err != nil { if err != nil {
@ -124,9 +136,14 @@ func (s *ServerCommon) encodeTransferMsg(c *ClientConn, msg TransferMsg) ([]byte
if err != nil { if err != nil {
return nil, err return nil, err
} }
logical := logicalConnFromClient(c)
msgEn := c.clientConnMsgEnSnapshot() msgEn := c.clientConnMsgEnSnapshot()
secretKey := c.clientConnSecretKeySnapshot() secretKey := c.clientConnSecretKeySnapshot()
data, err = encryptTransportPayloadCodec(c.clientConnModernPSKRuntimeSnapshot(), msgEn, secretKey, data) mode := ProtectionManaged
if logical != nil {
mode = logical.protectionModeSnapshot()
}
data, err = encryptTransportPayloadCodec(mode, c.clientConnModernPSKRuntimeSnapshot(), msgEn, secretKey, data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -140,7 +157,11 @@ func (s *ServerCommon) encodeTransferMsg(c *ClientConn, msg TransferMsg) ([]byte
func (s *ServerCommon) decodeTransferMsg(c *ClientConn, data []byte) (TransferMsg, error) { func (s *ServerCommon) decodeTransferMsg(c *ClientConn, data []byte) (TransferMsg, error) {
msgDe := c.clientConnMsgDeSnapshot() msgDe := c.clientConnMsgDeSnapshot()
secretKey := c.clientConnSecretKeySnapshot() secretKey := c.clientConnSecretKeySnapshot()
plain, err := decryptTransportPayloadCodec(c.clientConnModernPSKRuntimeSnapshot(), msgDe, secretKey, data) mode := ProtectionManaged
if logical := logicalConnFromClient(c); logical != nil {
mode = logical.protectionModeSnapshot()
}
plain, err := decryptTransportPayloadCodec(mode, c.clientConnModernPSKRuntimeSnapshot(), msgDe, secretKey, data)
if err != nil { if err != nil {
return TransferMsg{}, err return TransferMsg{}, err
} }
@ -172,7 +193,8 @@ func (c *ClientCommon) encodeEnvelopePlain(env Envelope) ([]byte, error) {
} }
func (c *ClientCommon) encryptTransportPayload(data []byte) ([]byte, error) { func (c *ClientCommon) encryptTransportPayload(data []byte) ([]byte, error) {
return encryptTransportPayloadCodec(c.modernPSKRuntime, c.msgEn, c.SecretKey, data) profile := c.clientTransportProtectionSnapshot()
return encryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgEn, profile.secretKey, data)
} }
func (c *ClientCommon) encodeEnvelope(env Envelope) ([]byte, error) { func (c *ClientCommon) encodeEnvelope(env Envelope) ([]byte, error) {
@ -196,7 +218,8 @@ func (c *ClientCommon) decodeEnvelope(data []byte) (Envelope, error) {
} }
func (c *ClientCommon) decryptTransportPayload(data []byte) ([]byte, error) { func (c *ClientCommon) decryptTransportPayload(data []byte) ([]byte, error) {
return decryptTransportPayloadCodec(c.modernPSKRuntime, c.msgDe, c.SecretKey, data) profile := c.clientTransportProtectionSnapshot()
return decryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, data)
} }
func (c *ClientCommon) decodeEnvelopePlain(data []byte) (Envelope, error) { func (c *ClientCommon) decodeEnvelopePlain(data []byte) (Envelope, error) {
@ -251,7 +274,7 @@ func (s *ServerCommon) encryptTransportPayloadLogical(logical *LogicalConn, data
if msgEn == nil { if msgEn == nil {
return nil, errTransportDetached return nil, errTransportDetached
} }
return encryptTransportPayloadCodec(logical.modernPSKRuntimeSnapshot(), msgEn, secretKey, data) return encryptTransportPayloadCodec(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), msgEn, secretKey, data)
} }
func (s *ServerCommon) encodeEnvelopeLogical(logical *LogicalConn, env Envelope) ([]byte, error) { func (s *ServerCommon) encodeEnvelopeLogical(logical *LogicalConn, env Envelope) ([]byte, error) {
@ -290,7 +313,7 @@ func (s *ServerCommon) decryptTransportPayloadLogical(logical *LogicalConn, data
if msgDe == nil { if msgDe == nil {
return nil, errTransportDetached return nil, errTransportDetached
} }
return decryptTransportPayloadCodec(logical.modernPSKRuntimeSnapshot(), msgDe, secretKey, data) return decryptTransportPayloadCodec(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), msgDe, secretKey, data)
} }
func (s *ServerCommon) decodeEnvelopePlain(data []byte) (Envelope, error) { func (s *ServerCommon) decodeEnvelopePlain(data []byte) (Envelope, error) {

View File

@ -143,6 +143,43 @@ func TestStreamBatchSenderRespectsBindingWriteDeadlineWhenReceiverStalls(t *test
} }
} }
func TestBulkBatchSenderFlushAggregatesAdaptivePayloadObservation(t *testing.T) {
conn := &delayedWriteConn{delay: 20 * time.Millisecond}
binding := newTransportBinding(conn, stario.NewQueue())
sender := newTestBulkBatchSender(binding)
payloadA := bytes.Repeat([]byte("a"), 128*1024)
payloadB := bytes.Repeat([]byte("b"), 128*1024)
err := sender.flush([]bulkBatchRequest{
{
ctx: context.Background(),
frames: []bulkFastFrame{{
Type: bulkFastPayloadTypeData,
DataID: 1,
Seq: 1,
Payload: payloadA,
}},
fastPathVersion: bulkFastPathVersionV1,
},
{
ctx: context.Background(),
frames: []bulkFastFrame{{
Type: bulkFastPayloadTypeData,
DataID: 2,
Seq: 1,
Payload: payloadB,
}},
fastPathVersion: bulkFastPathVersionV1,
},
})
if err != nil {
t.Fatalf("flush failed: %v", err)
}
if got, want := binding.bulkAdaptiveSoftPayloadBytesSnapshot(), bulkAdaptiveSoftPayloadMinBytes; got != want {
t.Fatalf("adaptive bulk soft payload = %d, want %d", got, want)
}
}
func TestControlBatchSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) { func TestControlBatchSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) {
left, right := net.Pipe() left, right := net.Pipe()
defer left.Close() defer left.Close()
@ -255,6 +292,10 @@ func (c *vectoredShortWriteConn) WriteBuffers(bufs *net.Buffers) (int64, error)
return written, nil return written, nil
} }
func (c *vectoredShortWriteConn) writeBuffers(bufs *net.Buffers) (int64, error) {
return c.WriteBuffers(bufs)
}
type unwrapVectoredConn struct { type unwrapVectoredConn struct {
inner net.Conn inner net.Conn
} }