From 98ef9e7fcc11802544b3d294c01df774be5882cf Mon Sep 17 00:00:00 2001 From: starainrt Date: Mon, 20 Apr 2026 16:35:44 +0800 Subject: [PATCH] =?UTF-8?q?feat(transport):=20=E5=AE=8C=E6=88=90=E5=AE=89?= =?UTF-8?q?=E5=85=A8=E6=9E=B6=E6=9E=84=E6=8B=86=E5=88=86=E5=B9=B6=E6=94=B6?= =?UTF-8?q?=E5=8F=A3=20stream/bulk=20=E4=BC=A0=E8=BE=93=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 managed/external/nested 三种传输保护模式 - 新增 peer attach 显式认证、抗重放、channel binding 和可选前向保密协商 - 明确单连接注入与可重拨连接源的语义边界 - 禁止 ConnectByConn 场景下 dedicated bulk 走 sidecar,auto 模式自动回退 shared - 修正 dedicated attach 在 bootstrap/steady profile 切换下的处理逻辑 - 优化 shared bulk super-batch 与批量 framed write 路径 - 降低 stream/bulk fast path 的复制和分发损耗 - 补齐 benchmark、回归测试、运行时快照和 README 文档 --- README.md | 54 +++ benchmark_transport_security_test.go | 42 +++ bulk.go | 4 + bulk_batch_sender.go | 20 +- bulk_benchmark_test.go | 100 ++++- bulk_dedicated.go | 51 ++- bulk_dedicated_attach_test.go | 70 ++++ bulk_dedicated_lane_sender.go | 124 +++++- bulk_dedicated_lane_sender_test.go | 51 +++ bulk_dispatcher.go | 5 +- bulk_fastpath.go | 177 +++++++-- bulk_transport_guard_test.go | 170 +++++++++ client.go | 152 ++++---- client_bulk.go | 2 +- client_config.go | 46 ++- client_conn_attachment.go | 68 +++- client_connect_source.go | 37 +- client_legacy_security.go | 6 +- client_runtime.go | 3 + client_transport.go | 25 +- clienttype.go | 2 + logical_conn.go | 25 +- logical_transport_peer_fields_test.go | 14 + msg.go | 39 +- peer_attach_auth.go | 517 ++++++++++++++++++++++++++ peer_attach_auth_test.go | 237 ++++++++++++ peer_attach_policy.go | 139 +++++++ peer_attach_policy_test.go | 221 +++++++++++ peer_identity.go | 78 +++- peer_identity_test.go | 10 +- security_forward_secrecy.go | 139 +++++++ security_profile.go | 408 ++++++++++++++++++++ security_psk.go | 201 +++++++--- security_psk_test.go | 173 +++++++++ server.go | 107 +++--- server_bulk.go | 2 +- server_config.go | 47 ++- server_inbound_reply_test.go | 100 +++++ server_inbound_source.go | 28 +- server_send.go | 23 +- server_session.go | 12 +- servertype.go | 2 + session_runtime_snapshot.go | 113 +++++- session_runtime_snapshot_test.go | 161 ++++++++ stream.go | 104 +++++- stream_benchmark_test.go | 128 ++++++- stream_buffer_release_test.go | 85 +++++ stream_dispatcher.go | 24 +- stream_fastpath.go | 14 +- stream_shared_batch.go | 70 +++- transport_codec.go | 43 ++- transport_write_test.go | 41 ++ 52 files changed, 4069 insertions(+), 445 deletions(-) create mode 100644 benchmark_transport_security_test.go create mode 100644 peer_attach_auth.go create mode 100644 peer_attach_auth_test.go create mode 100644 peer_attach_policy.go create mode 100644 peer_attach_policy_test.go create mode 100644 security_forward_secrecy.go create mode 100644 security_profile.go create mode 100644 stream_buffer_release_test.go diff --git a/README.md b/README.md index 204b70c..da2af05 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,21 @@ 未配置时会返回 `errModernPSKRequired`。 +## 安全模式选择 + +- `UseModernPSKClient` / `UseModernPSKServer` + - bootstrap 和稳态传输都由 `notify` 自己保护 + - 适合默认场景 + - 支持 peer attach 显式认证、抗重放,以及在需要时协商前向保密 +- `UsePSKOverExternalTransportClient` / `UsePSKOverExternalTransportServer` + - bootstrap 仍用 PSK 做认证 + - 稳态阶段信任外部物理通道,不再做 `notify` 内层加密 + - 适合 `tls.Conn` 或调用方自认可信的外部通道 + - 不支持 `RequireForwardSecrecy` +- `UseNestedSecurityClient` / `UseNestedSecurityServer` + - 外层已有可信通道,但仍保留 `notify` 内层保护 + - 适合需要“外层可信 + 内层独立保护”的场景 + ## 快速开始 服务端: @@ -83,6 +98,42 @@ func main() { } ``` +## 连接入口与物理连接语义 + +- `Connect` / `ConnectTimeout` + - 由 `notify` 自己拨号 + - 支持重连,也支持 dedicated bulk 额外 sidecar 连接 +- `ConnectByFactory` + - 调用方提供 `dialFn` + - `notify` 会在需要时再次调用 `dialFn`,因此仍支持重连和 dedicated bulk +- `ConnectByConn` + - 调用方注入一个已经建立好的 `net.Conn` + - 该模式被视为“单物理连接模式” + - `OpenDedicatedBulk` 会直接返回错误 + - `OpenBulk` 使用 `auto` 模式时会自动回退到 `shared` +- `ListenByListener` + - 服务端复用调用方提供的 `net.Listener` + - 适合需要和现有 listener 栈整合的场景 + +`dedicated bulk` 依赖额外物理连接,因此只适用于可再次拨号的 transport source。 + +## Peer Attach 安全策略 + +可通过 `SetPeerAttachSecurityConfig` 配置逻辑会话 attach 阶段的额外保护。 + +- `RequireExplicitAuth` + - 要求 peer attach 使用显式认证 +- `RequireChannelBinding` + - 要求 attach 绑定到底层可信通道 + - 启用后会隐式要求显式认证 +- `ChannelBinding` + - 由调用方提供 channel binding 提取函数 + - 适合外层 TLS 或其他可信通道整合 +- `ReplayWindow` / `ReplayCapacity` + - 控制 attach 抗重放窗口和缓存容量 + +如果你选择 `UsePSKOverExternalTransport*`,并且希望 attach 阶段显式绑定到外层可信信道,建议同时配置 channel binding。 + ## RecordStream 说明 `RecordStream` 构建在 `Stream` 之上,适合“有边界的顺序记录”场景。 @@ -146,6 +197,9 @@ func main() { - 共享密钥派生(Argon2id) - 消息层加密(AES-GCM) - `stream` / `bulk` fast path 复用现代编码栈 +- peer attach 显式认证 / 抗重放 +- 可选 channel binding +- 可选前向保密(`UseModernPSK*` / `UseNestedSecurity*`) 兼容入口仍保留,但属于历史路径: diff --git a/benchmark_transport_security_test.go b/benchmark_transport_security_test.go new file mode 100644 index 0000000..83a77d5 --- /dev/null +++ b/benchmark_transport_security_test.go @@ -0,0 +1,42 @@ +package notify + +import "testing" + +type benchmarkTransportSecurityMode string + +const ( + benchmarkTransportSecurityModernPSK benchmarkTransportSecurityMode = "modern_psk" + benchmarkTransportSecurityTrustedRaw benchmarkTransportSecurityMode = "trusted_raw" +) + +func benchmarkApplyServerTransportSecurity(tb testing.TB, server *ServerCommon, mode benchmarkTransportSecurityMode) { + tb.Helper() + switch mode { + case benchmarkTransportSecurityModernPSK: + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + tb.Fatalf("UseModernPSKServer failed: %v", err) + } + case benchmarkTransportSecurityTrustedRaw: + if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + tb.Fatalf("UsePSKOverExternalTransportServer failed: %v", err) + } + default: + tb.Fatalf("unsupported benchmark transport security mode %q", mode) + } +} + +func benchmarkApplyClientTransportSecurity(tb testing.TB, client *ClientCommon, mode benchmarkTransportSecurityMode) { + tb.Helper() + switch mode { + case benchmarkTransportSecurityModernPSK: + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + tb.Fatalf("UseModernPSKClient failed: %v", err) + } + case benchmarkTransportSecurityTrustedRaw: + if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + tb.Fatalf("UsePSKOverExternalTransportClient failed: %v", err) + } + default: + tb.Fatalf("unsupported benchmark transport security mode %q", mode) + } +} diff --git a/bulk.go b/bulk.go index 89e8168..53a5375 100644 --- a/bulk.go +++ b/bulk.go @@ -161,6 +161,7 @@ var ( errBulkRangeInvalid = errors.New("bulk range is invalid") errBulkBackpressureExceeded = errors.New("bulk inbound backpressure exceeded") errBulkDedicatedStreamOnly = errors.New("dedicated bulk requires stream transport") + errBulkDedicatedSingleConn = errors.New("dedicated bulk requires a dialable additional connection source; ConnectByConn only supports shared transport") errBulkDedicatedActiveLimit = errors.New("dedicated bulk active limit reached") ) @@ -174,6 +175,9 @@ func clientDedicatedBulkSupportError(c *ClientCommon) error { if source := c.clientConnectSourceSnapshot(); source != nil && source.isUDP() { return errBulkDedicatedStreamOnly } + if source := c.clientConnectSourceSnapshot(); source != nil && !source.supportsAdditionalConn() { + return errBulkDedicatedSingleConn + } return nil } diff --git a/bulk_batch_sender.go b/bulk_batch_sender.go index 02e697d..6146ae8 100644 --- a/bulk_batch_sender.go +++ b/bulk_batch_sender.go @@ -418,18 +418,18 @@ func (s *bulkBatchSender) flush(requests []bulkBatchRequest) error { } }() writeTimeout := s.transportWriteTimeout() + frames := make([][]byte, 0, len(payloads)) + payloadBytes := 0 for _, payload := range payloads { - frame := payload.payload - started := time.Now() - err := s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error { - return writeFramedPayloadUnlocked(conn, queue, frame) - }) - s.binding.observeBulkAdaptivePayloadWrite(len(frame), time.Since(started), writeTimeout, err) - if err != nil { - return err - } + frames = append(frames, payload.payload) + payloadBytes += len(payload.payload) } - return nil + started := time.Now() + err = s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error { + return writeFramedPayloadBatchUnlocked(conn, queue, frames) + }) + s.binding.observeBulkAdaptivePayloadWrite(payloadBytes, time.Since(started), writeTimeout, err) + return err } func (s *bulkBatchSender) encodeRequests(requests []bulkBatchRequest) ([]bulkBatchEncodedPayload, error) { diff --git a/bulk_benchmark_test.go b/bulk_benchmark_test.go index ad10758..e6a480e 100644 --- a/bulk_benchmark_test.go +++ b/bulk_benchmark_test.go @@ -38,7 +38,25 @@ func BenchmarkBulkTCPThroughput(b *testing.B) { for _, tc := range cases { b.Run(tc.name, func(b *testing.B) { - benchmarkBulkTCPThroughput(b, tc.payloadSize, false) + benchmarkBulkTCPThroughput(b, tc.payloadSize, false, benchmarkTransportSecurityModernPSK) + }) + } +} + +func BenchmarkBulkTCPThroughputTrustedRaw(b *testing.B) { + cases := []struct { + name string + payloadSize int + }{ + {name: "chunk_256KiB", payloadSize: 256 * 1024}, + {name: "chunk_512KiB", payloadSize: 512 * 1024}, + {name: "chunk_768KiB", payloadSize: 768 * 1024}, + {name: "chunk_1MiB", payloadSize: 1024 * 1024}, + {name: "chunk_2MiB", payloadSize: 2 * 1024 * 1024}, + } + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkBulkTCPThroughput(b, tc.payloadSize, false, benchmarkTransportSecurityTrustedRaw) }) } } @@ -72,7 +90,25 @@ func BenchmarkBulkTCPThroughputDedicated(b *testing.B) { for _, tc := range cases { b.Run(tc.name, func(b *testing.B) { - benchmarkBulkTCPThroughput(b, tc.payloadSize, true) + benchmarkBulkTCPThroughput(b, tc.payloadSize, true, benchmarkTransportSecurityModernPSK) + }) + } +} + +func BenchmarkBulkTCPThroughputDedicatedTrustedRaw(b *testing.B) { + cases := []struct { + name string + payloadSize int + }{ + {name: "chunk_256KiB", payloadSize: 256 * 1024}, + {name: "chunk_512KiB", payloadSize: 512 * 1024}, + {name: "chunk_768KiB", payloadSize: 768 * 1024}, + {name: "chunk_1MiB", payloadSize: 1024 * 1024}, + {name: "chunk_2MiB", payloadSize: 2 * 1024 * 1024}, + } + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkBulkTCPThroughput(b, tc.payloadSize, true, benchmarkTransportSecurityTrustedRaw) }) } } @@ -107,7 +143,25 @@ func BenchmarkBulkTCPThroughputConcurrent(b *testing.B) { for _, tc := range cases { b.Run(tc.name, func(b *testing.B) { - benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false) + benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false, benchmarkTransportSecurityModernPSK) + }) + } +} + +func BenchmarkBulkTCPThroughputConcurrentTrustedRaw(b *testing.B) { + cases := []struct { + name string + payloadSize int + concurrency int + }{ + {name: "bulks_2_512KiB", payloadSize: 512 * 1024, concurrency: 2}, + {name: "bulks_4_512KiB", payloadSize: 512 * 1024, concurrency: 4}, + {name: "bulks_2_1MiB", payloadSize: 1024 * 1024, concurrency: 2}, + {name: "bulks_4_1MiB", payloadSize: 1024 * 1024, concurrency: 4}, + } + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false, benchmarkTransportSecurityTrustedRaw) }) } } @@ -142,18 +196,34 @@ func BenchmarkBulkTCPThroughputConcurrentDedicated(b *testing.B) { for _, tc := range cases { b.Run(tc.name, func(b *testing.B) { - benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true) + benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true, benchmarkTransportSecurityModernPSK) }) } } -func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) { +func BenchmarkBulkTCPThroughputConcurrentDedicatedTrustedRaw(b *testing.B) { + cases := []struct { + name string + payloadSize int + concurrency int + }{ + {name: "bulks_2_512KiB", payloadSize: 512 * 1024, concurrency: 2}, + {name: "bulks_4_512KiB", payloadSize: 512 * 1024, concurrency: 4}, + {name: "bulks_2_1MiB", payloadSize: 1024 * 1024, concurrency: 2}, + {name: "bulks_4_1MiB", payloadSize: 1024 * 1024, concurrency: 4}, + } + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true, benchmarkTransportSecurityTrustedRaw) + }) + } +} + +func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool, securityMode benchmarkTransportSecurityMode) { b.Helper() server := NewServer().(*ServerCommon) - if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { - b.Fatalf("UseModernPSKServer failed: %v", err) - } + benchmarkApplyServerTransportSecurity(b, server, securityMode) acceptCh := make(chan BulkAcceptInfo, 1) server.SetBulkHandler(func(info BulkAcceptInfo) error { @@ -169,9 +239,7 @@ func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) { }) client := NewClient().(*ClientCommon) - if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { - b.Fatalf("UseModernPSKClient failed: %v", err) - } + benchmarkApplyClientTransportSecurity(b, client, securityMode) if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil { b.Fatalf("client Connect failed: %v", err) } @@ -241,16 +309,14 @@ func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) { _ = bulk.Close() } -func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, dedicated bool) { +func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, dedicated bool, securityMode benchmarkTransportSecurityMode) { b.Helper() if concurrency <= 0 { b.Fatal("concurrency must be > 0") } server := NewServer().(*ServerCommon) - if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { - b.Fatalf("UseModernPSKServer failed: %v", err) - } + benchmarkApplyServerTransportSecurity(b, server, securityMode) acceptCh := make(chan BulkAcceptInfo, concurrency*2) server.SetBulkHandler(func(info BulkAcceptInfo) error { @@ -266,9 +332,7 @@ func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurr }) client := NewClient().(*ClientCommon) - if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { - b.Fatalf("UseModernPSKClient failed: %v", err) - } + benchmarkApplyClientTransportSecurity(b, client, securityMode) if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil { b.Fatalf("client Connect failed: %v", err) } diff --git a/bulk_dedicated.go b/bulk_dedicated.go index 8f63ca2..b6c172f 100644 --- a/bulk_dedicated.go +++ b/bulk_dedicated.go @@ -268,6 +268,9 @@ func readBulkDedicatedRecordPooled(conn net.Conn) ([]byte, func(), error) { func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context, timeout time.Duration) (net.Conn, error) { source := c.clientConnectSourceSnapshot() if source != nil { + if !source.supportsAdditionalConn() { + return nil, errBulkDedicatedSingleConn + } if source.network != "" && source.addr != "" { if timeout > 0 { return transport.DialTimeout(source.network, source.addr, timeout) @@ -277,6 +280,7 @@ func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context, timeout time.D if source.canReconnect() { return source.dial(ctx) } + return nil, errClientReconnectSourceUnavailable } conn := c.clientTransportConnSnapshot() if conn == nil || conn.RemoteAddr() == nil { @@ -661,7 +665,8 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn Value: reqPayload, Type: MSG_SYS_WAIT, } - frame, err := encodeDirectSignalFrame(stario.NewQueue(), c.sequenceEn, c.msgEn, c.SecretKey, msg) + attachProfile := c.clientDedicatedBulkAttachTransportProtectionProfile() + frame, err := encodeDirectSignalFrame(stario.NewQueue(), c.sequenceEn, attachProfile.msgEn, attachProfile.secretKey, msg) if err != nil { return bulkAttachResponse{}, err } @@ -675,7 +680,7 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn if err != nil { return bulkAttachResponse{}, err } - transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, replyPayload) + transfer, err := decodeDirectSignalPayload(c.sequenceDe, attachProfile.msgDe, attachProfile.secretKey, replyPayload) if err != nil { return bulkAttachResponse{}, err } @@ -685,6 +690,16 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn return decodeBulkAttachResponse(c.sequenceDe, transfer.Value) } +func (c *ClientCommon) clientDedicatedBulkAttachTransportProtectionProfile() transportProtectionProfile { + if c == nil { + return transportProtectionProfile{} + } + if c.securityConfigured && c.securityBootstrap.msgEn != nil && c.securityBootstrap.msgDe != nil { + return c.securityBootstrap.clone() + } + return c.clientTransportProtectionSnapshot() +} + func (c *ClientCommon) readDedicatedSidecarLoop(sidecar *bulkDedicatedSidecar) { if c == nil || sidecar == nil || sidecar.conn == nil { return @@ -695,7 +710,8 @@ func (c *ClientCommon) readDedicatedSidecarLoop(sidecar *bulkDedicatedSidecar) { c.handleClientDedicatedSidecarFailure(sidecar, err) return } - plain, plainRelease, err := decryptTransportPayloadCodecPooled(c.modernPSKRuntime, c.msgDe, c.SecretKey, payload, payloadRelease) + profile := c.clientTransportProtectionSnapshot() + plain, plainRelease, err := decryptTransportPayloadCodecPooled(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, payload, payloadRelease) if err != nil { c.handleClientDedicatedSidecarFailure(sidecar, err) return @@ -1023,7 +1039,7 @@ func (s *ServerCommon) replyDedicatedBulkAttach(client *LogicalConn, message Mes Type: MSG_SYS_REPLY, } if message.inboundConn != nil { - return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply) + return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, messageInboundTransportProtectionSnapshot(&message), reply) } _, err = s.sendLogical(client, reply) return err @@ -1041,7 +1057,7 @@ func (s *ServerCommon) readDedicatedSidecarLoop(logical *LogicalConn, sidecar *b s.handleServerDedicatedSidecarFailure(logical, sidecar, err) return } - plain, plainRelease, err := decryptTransportPayloadCodecPooled(logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload, payloadRelease) + plain, plainRelease, err := decryptTransportPayloadCodecPooled(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload, payloadRelease) if err != nil { s.handleServerDedicatedSidecarFailure(logical, sidecar, err) return @@ -1171,7 +1187,8 @@ func (c *ClientCommon) dedicatedBulkLaneSender(bulk *bulkHandle) (*bulkDedicated return nil, transportDetachedError("dedicated bulk sidecar not attached", nil) } sender := sidecar.laneSenderWithFactory(func(conn net.Conn) *bulkDedicatedLaneSender { - laneRuntime := c.modernPSKRuntime + profile := c.clientTransportProtectionSnapshot() + laneRuntime := profile.runtime if forked, err := forkDedicatedLaneModernPSKRuntime(laneRuntime); err == nil && forked != nil { laneRuntime = forked } @@ -1198,7 +1215,7 @@ func (c *ClientCommon) sendDedicatedBulkData(ctx context.Context, bulk *bulkHand return sender.submitData(ctx, bulk.dataIDSnapshot(), bulk.nextOutboundDataSeq(), chunk) } -func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte) (int, error) { +func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) { if c == nil || bulk == nil { return 0, errBulkClientNil } @@ -1206,7 +1223,7 @@ func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHan if err != nil { return 0, err } - return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize) + return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize, payloadOwned) } func (c *ClientCommon) sendDedicatedBulkClose(ctx context.Context, bulk *bulkHandle, full bool) error { @@ -1345,7 +1362,7 @@ func (s *ServerCommon) sendDedicatedBulkData(ctx context.Context, logical *Logic return sender.submitData(ctx, bulk.dataIDSnapshot(), bulk.nextOutboundDataSeq(), chunk) } -func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, startSeq uint64, payload []byte) (int, error) { +func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) { if s == nil || bulk == nil { return 0, errBulkServerNil } @@ -1353,7 +1370,7 @@ func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *Logi if err != nil { return 0, err } - return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize) + return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize, payloadOwned) } func (s *ServerCommon) sendDedicatedBulkClose(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, full bool) error { @@ -1419,13 +1436,14 @@ func (c *ClientCommon) encodeDedicatedBulkBatchPayload(dataID uint64, items []bu if c == nil { return nil, errBulkClientNil } - if runtime := c.modernPSKRuntime; runtime != nil { + profile := c.clientTransportProtectionSnapshot() + if runtime := profile.runtime; runtime != nil { return runtime.sealFilledPayload(bulkDedicatedBatchPlainLen(items), func(dst []byte) error { return writeBulkDedicatedBatchPlain(dst, dataID, items) }) } - if c.fastPlainEncode != nil { - return encodeBulkDedicatedBatchPayloadFast(c.fastPlainEncode, c.SecretKey, dataID, items) + if profile.fastPlainEncode != nil { + return encodeBulkDedicatedBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, dataID, items) } plain, err := encodeBulkDedicatedBatchPlain(dataID, items) if err != nil { @@ -1453,8 +1471,9 @@ func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooledWithRuntime(runtim return writeBulkDedicatedBatchesPlain(dst, batches) }) } - if c.fastPlainEncode != nil { - payload, err := encodeBulkDedicatedBatchesPayloadFast(c.fastPlainEncode, c.SecretKey, batches) + profile := c.clientTransportProtectionSnapshot() + if profile.fastPlainEncode != nil { + payload, err := encodeBulkDedicatedBatchesPayloadFast(profile.fastPlainEncode, profile.secretKey, batches) return payload, nil, err } plain, err := encodeBulkDedicatedBatchesPlain(batches) @@ -1466,7 +1485,7 @@ func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooledWithRuntime(runtim } func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooled(batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) { - return c.encodeDedicatedBulkBatchesPayloadPooledWithRuntime(c.modernPSKRuntime, batches) + return c.encodeDedicatedBulkBatchesPayloadPooledWithRuntime(c.clientTransportProtectionSnapshot().runtime, batches) } func (c *ClientCommon) encodeDedicatedBulkBatchPayloadPooled(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, func(), error) { diff --git a/bulk_dedicated_attach_test.go b/bulk_dedicated_attach_test.go index 643029a..ab6b2b8 100644 --- a/bulk_dedicated_attach_test.go +++ b/bulk_dedicated_attach_test.go @@ -115,6 +115,76 @@ func TestSendDedicatedBulkAttachRequestKeepsCoalescedDedicatedPayloadUnread(t *t } } +func TestSendDedicatedBulkAttachRequestUsesBootstrapProtectionEvenAfterSteadySwitch(t *testing.T) { + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + client.msgID = 100 + + alternate, err := deriveModernPSKProtectionProfile([]byte("notify-dedicated-attach-other-secret"), integrationModernPSKOptions(), ProtectionManaged) + if err != nil { + t.Fatalf("deriveModernPSKProtectionProfile(alternate) failed: %v", err) + } + client.setClientTransportProtectionProfile(alternate) + + bulk := newBulkHandle(context.Background(), newBulkRuntime("dedicated-attach-bootstrap-test"), clientFileScope(), BulkOpenRequest{ + BulkID: "bulk-attach-bootstrap-test", + DataID: 1, + Dedicated: true, + AttachToken: "attach-token", + }, 0, nil, nil, 0, nil, nil, nil, nil, nil) + + bootstrap := client.clientDedicatedBulkAttachTransportProtectionProfile() + encodedResp, err := client.sequenceEn(bulkAttachResponse{Accepted: true}) + if err != nil { + t.Fatalf("encode bulkAttachResponse failed: %v", err) + } + replyFrame, err := encodeDirectSignalFrame(stario.NewQueue(), client.sequenceEn, bootstrap.msgEn, bootstrap.secretKey, TransferMsg{ + ID: 101, + Key: systemBulkAttachKey, + Value: encodedResp, + Type: MSG_SYS_REPLY, + }) + if err != nil { + t.Fatalf("encode attach reply frame failed: %v", err) + } + conn := newBulkAttachScriptConn(replyFrame) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + resp, err := client.sendDedicatedBulkAttachRequest(ctx, conn, bulk) + if err != nil { + t.Fatalf("sendDedicatedBulkAttachRequest failed: %v", err) + } + if !resp.Accepted { + t.Fatalf("bulk attach response = %+v, want accepted", resp) + } + + parsedReq := stario.NewQueue() + var reqMsg TransferMsg + if err := parsedReq.ParseMessageOwned(conn.writtenBytes(), "attach-request", func(msgq stario.MsgQueue) error { + transfer, err := decodeDirectSignalPayload(client.sequenceDe, bootstrap.msgDe, bootstrap.secretKey, msgq.Msg) + if err != nil { + return err + } + reqMsg = transfer + return nil + }); err != nil { + t.Fatalf("parse written attach request with bootstrap profile failed: %v", err) + } + if reqMsg.Key != systemBulkAttachKey || reqMsg.Type != MSG_SYS_WAIT { + t.Fatalf("attach request message mismatch: %+v", reqMsg) + } + + if err := parsedReq.ParseMessageOwned(conn.writtenBytes(), "attach-request-current", func(msgq stario.MsgQueue) error { + _, err := decodeDirectSignalPayload(client.sequenceDe, alternate.msgDe, alternate.secretKey, msgq.Msg) + return err + }); !errors.Is(err, errTransportPayloadDecryptFailed) { + t.Fatalf("decode written attach request with current steady profile error = %v, want %v", err, errTransportPayloadDecryptFailed) + } +} + func TestHandleBulkAttachSystemMessageAcceptedWritesDirectReplyBeforeDedicatedHandoff(t *testing.T) { server := NewServer().(*ServerCommon) UseLegacySecurityServer(server) diff --git a/bulk_dedicated_lane_sender.go b/bulk_dedicated_lane_sender.go index b563bc6..2608dc4 100644 --- a/bulk_dedicated_lane_sender.go +++ b/bulk_dedicated_lane_sender.go @@ -83,7 +83,7 @@ func (r *bulkDedicatedLaneBatchRequest) reset() { } } -func (r *bulkDedicatedLaneBatchRequest) prepare(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool) { +func (r *bulkDedicatedLaneBatchRequest) prepare(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool, borrowItems bool) { if r == nil { return } @@ -94,6 +94,10 @@ func (r *bulkDedicatedLaneBatchRequest) prepare(ctx context.Context, dataID uint if deadline, ok := ctx.Deadline(); ok { r.Deadline = deadline } + if borrowItems { + r.Items = items + return + } if cap(r.Items) < len(items) { r.Items = make([]bulkDedicatedSendRequest, len(items)) } else { @@ -119,10 +123,10 @@ func (s *bulkDedicatedLaneSender) submitData(ctx context.Context, dataID uint64, Seq: seq, Payload: append([]byte(nil), payload...), }} - return s.submitBatch(ctx, dataID, items, false) + return s.submitBatch(ctx, dataID, items, false, false) } -func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int) (int, error) { +func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int, payloadOwned bool) (int, error) { if s == nil { return 0, errTransportDetached } @@ -132,6 +136,9 @@ func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64 if chunkSize <= 0 { chunkSize = defaultBulkChunkSize } + if submitted, written, err := s.tryDirectSubmitWrite(ctx, dataID, startSeq, payload, chunkSize); submitted { + return written, err + } written := 0 seq := startSeq for written < len(payload) { @@ -170,7 +177,7 @@ func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64 seq++ written = end } - if err := s.submitWriteBatch(ctx, dataID, items); err != nil { + if err := s.submitWriteBatch(ctx, dataID, items, payloadOwned); err != nil { return start, err } start = written @@ -178,7 +185,7 @@ func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64 return written, nil } -func (s *bulkDedicatedLaneSender) submitWriteBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest) error { +func (s *bulkDedicatedLaneSender) submitWriteBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, _ bool) error { if s == nil { return errTransportDetached } @@ -188,8 +195,7 @@ func (s *bulkDedicatedLaneSender) submitWriteBatch(ctx context.Context, dataID u if submitted, err := s.tryDirectSubmitBatch(ctx, dataID, items); submitted { return err } - queuedItems := copyBulkDedicatedSendRequests(items) - return s.submitBatch(ctx, dataID, queuedItems, true) + return s.submitBatch(ctx, dataID, items, true, true) } func (s *bulkDedicatedLaneSender) submitControl(ctx context.Context, dataID uint64, frameType uint8, flags uint8, seq uint64, payload []byte) error { @@ -204,10 +210,10 @@ func (s *bulkDedicatedLaneSender) submitControl(ctx context.Context, dataID uint if len(payload) > 0 { items[0].Payload = append([]byte(nil), payload...) } - return s.submitBatch(ctx, dataID, items, true) + return s.submitBatch(ctx, dataID, items, true, false) } -func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool) error { +func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool, borrowItems bool) error { if s == nil { return errTransportDetached } @@ -218,7 +224,7 @@ func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64 return err } req := getBulkDedicatedLaneBatchRequest() - req.prepare(ctx, dataID, items, wait) + req.prepare(ctx, dataID, items, wait, borrowItems) s.queued.Add(1) select { case <-ctx.Done(): @@ -237,6 +243,104 @@ func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64 } } +func (s *bulkDedicatedLaneSender) tryDirectSubmitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int) (bool, int, error) { + if s == nil { + return true, 0, errTransportDetached + } + if ctx == nil { + ctx = context.Background() + } + if len(payload) == 0 { + return true, 0, nil + } + if chunkSize <= 0 { + chunkSize = defaultBulkChunkSize + } + if err := s.errSnapshot(); err != nil { + return true, 0, err + } + select { + case <-ctx.Done(): + return true, 0, normalizeStreamDeadlineError(ctx.Err()) + case <-s.stopCh: + return true, 0, s.stoppedErr() + default: + } + if s.queued.Load() != 0 { + return false, 0, nil + } + if !s.flushMu.TryLock() { + return false, 0, nil + } + defer s.flushMu.Unlock() + if s.queued.Load() != 0 { + return false, 0, nil + } + if err := s.errSnapshot(); err != nil { + return true, 0, err + } + written := 0 + seq := startSeq + deadline, _ := ctx.Deadline() + for written < len(payload) { + select { + case <-ctx.Done(): + return true, written, normalizeStreamDeadlineError(ctx.Err()) + case <-s.stopCh: + return true, written, s.stoppedErr() + default: + } + var itemBuf [bulkDedicatedBatchMaxItems]bulkDedicatedSendRequest + items := itemBuf[:0] + batchBytes := bulkDedicatedBatchHeaderLen + start := written + for written < len(payload) && len(items) < bulkDedicatedBatchMaxItems { + end := written + chunkSize + if end > len(payload) { + end = len(payload) + } + itemLen := bulkDedicatedSendRequestLenFromPayloadLen(end - written) + if len(items) > 0 && batchBytes+itemLen > bulkDedicatedBatchMaxPlainBytes { + break + } + items = append(items, bulkDedicatedSendRequest{ + Type: bulkFastPayloadTypeData, + Seq: seq, + Payload: payload[written:end], + }) + batchBytes += itemLen + seq++ + written = end + } + if len(items) == 0 { + end := written + chunkSize + if end > len(payload) { + end = len(payload) + } + items = append(items, bulkDedicatedSendRequest{ + Type: bulkFastPayloadTypeData, + Seq: seq, + Payload: payload[written:end], + }) + seq++ + written = end + } + if err := s.flush([]bulkDedicatedOutboundBatch{{ + DataID: dataID, + Items: items, + }}, deadline); err != nil { + err = normalizeDedicatedBulkSendError(err) + s.setErr(err) + s.failPending(err) + if s.fail != nil { + go s.fail(err) + } + return true, start, err + } + } + return true, written, nil +} + func (s *bulkDedicatedLaneSender) tryDirectSubmitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest) (bool, error) { if s == nil { return true, errTransportDetached diff --git a/bulk_dedicated_lane_sender_test.go b/bulk_dedicated_lane_sender_test.go index b29d3a6..c91fa67 100644 --- a/bulk_dedicated_lane_sender_test.go +++ b/bulk_dedicated_lane_sender_test.go @@ -7,6 +7,57 @@ import ( "time" ) +func TestBulkDedicatedLaneBatchRequestPrepareBorrowedSharesItems(t *testing.T) { + req := getBulkDedicatedLaneBatchRequest() + defer req.recycle() + + items := []bulkDedicatedSendRequest{{ + Type: bulkFastPayloadTypeData, + Seq: 7, + Payload: []byte("hello"), + }} + req.prepare(context.Background(), 11, items, true, true) + + if got, want := len(req.Items), 1; got != want { + t.Fatalf("prepared item count = %d, want %d", got, want) + } + if &req.Items[0] != &items[0] { + t.Fatal("prepare with borrowed items should share request items") + } +} + +func TestBulkDedicatedLaneSenderTryDirectSubmitWriteFlushesWholePayload(t *testing.T) { + conn := &shortWriteBulkRecordConn{maxPerWrite: 1024} + encodeCalls := 0 + sender := &bulkDedicatedLaneSender{ + conn: conn, + stopCh: make(chan struct{}), + encode: func(batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) { + encodeCalls++ + payload, err := encodeBulkDedicatedBatchesPlain(batches) + return payload, nil, err + }, + } + + payload := bytes.Repeat([]byte("a"), 3*defaultBulkChunkSize) + submitted, written, err := sender.tryDirectSubmitWrite(context.Background(), 9, 1, payload, defaultBulkChunkSize) + if err != nil { + t.Fatalf("tryDirectSubmitWrite error = %v", err) + } + if !submitted { + t.Fatal("tryDirectSubmitWrite should submit directly") + } + if got, want := written, len(payload); got != want { + t.Fatalf("written = %d, want %d", got, want) + } + if encodeCalls == 0 { + t.Fatal("encode should be called at least once") + } + if got := sender.queued.Load(); got != 0 { + t.Fatalf("queued requests = %d, want 0", got) + } +} + func TestBulkDedicatedLaneSenderCollectBatchRequestsBatchesAcrossDataIDs(t *testing.T) { sender := &bulkDedicatedLaneSender{ reqCh: make(chan *bulkDedicatedLaneBatchRequest, 3), diff --git a/bulk_dispatcher.go b/bulk_dispatcher.go index f3e2aa3..d899d57 100644 --- a/bulk_dispatcher.go +++ b/bulk_dispatcher.go @@ -162,7 +162,8 @@ func (c *ClientCommon) tryDispatchBorrowedBulkTransportPayload(payload []byte) b if c == nil || len(payload) == 0 { return false } - plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(c.modernPSKRuntime, c.msgDe, c.SecretKey, payload) + profile := c.clientTransportProtectionSnapshot() + plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, payload) if err != nil { if c.showError || c.debugMode { fmt.Println("client decode transport payload error", err) @@ -194,7 +195,7 @@ func (s *ServerCommon) tryDispatchBorrowedBulkTransportPayload(source interface{ if logical == nil { return false } - plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload) + plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload) if err != nil { if s.showError || s.debugMode { fmt.Println("server decode transport payload error", err) diff --git a/bulk_fastpath.go b/bulk_fastpath.go index f64175c..cd3cd61 100644 --- a/bulk_fastpath.go +++ b/bulk_fastpath.go @@ -4,6 +4,7 @@ import ( "context" "encoding/binary" "errors" + "fmt" "net" "sync" "time" @@ -132,8 +133,9 @@ func (c *ClientCommon) encodeBulkFastPayload(frame bulkFastFrame) ([]byte, error if c == nil { return nil, errBulkClientNil } - if c.fastPlainEncode != nil { - return encodeBulkFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame) + profile := c.clientTransportProtectionSnapshot() + if profile.fastPlainEncode != nil { + return encodeBulkFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame) } plain, err := encodeBulkFastFramePayload(frame) if err != nil { @@ -146,8 +148,9 @@ func (c *ClientCommon) encodeBulkFastBatchPayload(frames []bulkFastFrame) ([]byt if c == nil { return nil, errBulkClientNil } - if c.fastPlainEncode != nil { - return encodeBulkFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames) + profile := c.clientTransportProtectionSnapshot() + if profile.fastPlainEncode != nil { + return encodeBulkFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames) } plain, err := encodeBulkFastBatchPlain(frames) if err != nil { @@ -160,11 +163,12 @@ func (c *ClientCommon) encodeBulkFastPayloadPooled(frame bulkFastFrame) ([]byte, if c == nil { return nil, nil, errBulkClientNil } - if runtime := c.modernPSKRuntime; runtime != nil { + profile := c.clientTransportProtectionSnapshot() + if runtime := profile.runtime; runtime != nil { return encodeBulkFastFramePayloadPooled(runtime, frame) } - if c.fastPlainEncode != nil { - payload, err := encodeBulkFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame) + if profile.fastPlainEncode != nil { + payload, err := encodeBulkFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame) return payload, nil, err } plain, err := encodeBulkFastFramePayload(frame) @@ -179,11 +183,12 @@ func (c *ClientCommon) encodeBulkFastBatchPayloadPooled(frames []bulkFastFrame) if c == nil { return nil, nil, errBulkClientNil } - if runtime := c.modernPSKRuntime; runtime != nil { + profile := c.clientTransportProtectionSnapshot() + if runtime := profile.runtime; runtime != nil { return encodeBulkFastBatchPayloadPooled(runtime, frames) } - if c.fastPlainEncode != nil { - payload, err := encodeBulkFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames) + if profile.fastPlainEncode != nil { + payload, err := encodeBulkFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames) return payload, nil, err } plain, err := encodeBulkFastBatchPlain(frames) @@ -460,28 +465,126 @@ func putBulkFastFrameScratch(buf []byte) { bulkFastFrameScratchPool.Put(buf[:0]) } +func transportFastPayloadMagic(payload []byte) string { + if len(payload) < 4 { + return "" + } + return string(payload[:4]) +} + +func (c *ClientCommon) decryptTransportPayloadPooled(payload []byte, release func()) ([]byte, func(), error) { + profile := c.clientTransportProtectionSnapshot() + return decryptTransportPayloadCodecPooled(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, payload, release) +} + +func (s *ServerCommon) decryptTransportPayloadLogicalPooled(logical *LogicalConn, payload []byte, release func()) ([]byte, func(), error) { + if logical == nil { + if release != nil { + release() + } + return nil, nil, errTransportDetached + } + return decryptTransportPayloadCodecPooled(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload, release) +} + +func (c *ClientCommon) tryDispatchBorrowedTransportPlain(plain []byte, release func()) bool { + switch transportFastPayloadMagic(plain) { + case bulkFastPayloadMagic, bulkFastBatchMagic: + owner := newBulkReadPayloadOwner(release) + matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error { + c.dispatchFastBulkFrameWithOwner(frame, owner) + return nil + }) + if owner != nil { + owner.done() + } + if !matched { + walkErr = errBulkFastPayloadInvalid + } + if walkErr != nil && (c.showError || c.debugMode) { + fmt.Println("client decode bulk fast payload error", walkErr) + } + return true + case streamFastPayloadMagic, streamFastBatchMagic: + owner := newStreamReadPayloadOwner(release) + matched, walkErr := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error { + c.dispatchFastStreamDataWithOwner(frame, owner) + return nil + }) + if owner != nil { + owner.done() + } + if !matched { + walkErr = errStreamFastPayloadInvalid + } + if walkErr != nil && (c.showError || c.debugMode) { + fmt.Println("client decode stream fast payload error", walkErr) + } + return true + default: + return false + } +} + +func (s *ServerCommon) tryDispatchBorrowedTransportPlain(logical *LogicalConn, transport *TransportConn, conn net.Conn, plain []byte, release func()) bool { + switch transportFastPayloadMagic(plain) { + case bulkFastPayloadMagic, bulkFastBatchMagic: + owner := newBulkReadPayloadOwner(release) + matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error { + s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, owner) + return nil + }) + if owner != nil { + owner.done() + } + if !matched { + walkErr = errBulkFastPayloadInvalid + } + if walkErr != nil && (s.showError || s.debugMode) { + fmt.Println("server decode bulk fast payload error", walkErr) + } + return true + case streamFastPayloadMagic, streamFastBatchMagic: + owner := newStreamReadPayloadOwner(release) + matched, walkErr := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error { + s.dispatchFastStreamDataWithOwner(logical, transport, conn, frame, owner) + return nil + }) + if owner != nil { + owner.done() + } + if !matched { + walkErr = errStreamFastPayloadInvalid + } + if walkErr != nil && (s.showError || s.debugMode) { + fmt.Println("server decode stream fast payload error", walkErr) + } + return true + default: + return false + } +} + func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error { plain, err := c.decryptTransportPayload(payload) if err != nil { return err } - if frames, matched, err := decodeBulkFastFrames(plain); matched { - if err != nil { - return err - } - for _, frame := range frames { - c.dispatchFastBulkFrame(frame) - } + return c.dispatchInboundTransportPlain(plain, now) +} + +func (c *ClientCommon) dispatchInboundTransportPlain(plain []byte, now time.Time) error { + if matched, err := walkBulkFastFrames(plain, func(frame bulkFastFrame) error { + c.dispatchFastBulkFrame(frame) return nil + }); matched { + return err } - if frames, matched, err := decodeStreamFastDataFrames(plain); matched { - if err != nil { - return err - } - for _, frame := range frames { - c.dispatchFastStreamData(frame) - } + if matched, err := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error { + c.dispatchFastStreamData(frame) return nil + }); matched { + return err } env, err := c.decodeEnvelopePlain(plain) if err != nil { @@ -502,23 +605,21 @@ func (s *ServerCommon) dispatchInboundTransportPayload(logical *LogicalConn, tra if err != nil { return err } - if frames, matched, err := decodeBulkFastFrames(plain); matched { - if err != nil { - return err - } - for _, frame := range frames { - s.dispatchFastBulkFrame(logical, transport, conn, frame) - } + return s.dispatchInboundTransportPlain(logical, transport, conn, plain, now) +} + +func (s *ServerCommon) dispatchInboundTransportPlain(logical *LogicalConn, transport *TransportConn, conn net.Conn, plain []byte, now time.Time) error { + if matched, err := walkBulkFastFrames(plain, func(frame bulkFastFrame) error { + s.dispatchFastBulkFrame(logical, transport, conn, frame) return nil + }); matched { + return err } - if frames, matched, err := decodeStreamFastDataFrames(plain); matched { - if err != nil { - return err - } - for _, frame := range frames { - s.dispatchFastStreamData(logical, transport, conn, frame) - } + if matched, err := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error { + s.dispatchFastStreamData(logical, transport, conn, frame) return nil + }); matched { + return err } env, err := s.decodeEnvelopePlain(plain) if err != nil { diff --git a/bulk_transport_guard_test.go b/bulk_transport_guard_test.go index bd099ea..bd8cae2 100644 --- a/bulk_transport_guard_test.go +++ b/bulk_transport_guard_test.go @@ -100,6 +100,176 @@ func TestBulkOpenAutoUDPFallsBackToShared(t *testing.T) { _ = accepted.Bulk.Close() } +func TestOpenDedicatedBulkConnectByConnRejectedAsSingleConnMode(t *testing.T) { + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err) + } + server.SetBulkHandler(func(info BulkAcceptInfo) error { + return nil + }) + }) + + client := NewClient().(*ClientCommon) + if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err) + } + + left, right := net.Pipe() + defer right.Close() + bootstrapPeerAttachConnForTest(t, server, right) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("ConnectByConn failed: %v", err) + } + defer func() { + client.setByeFromServer(true) + _ = client.Stop() + }() + + _, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{ + Range: BulkRange{ + Offset: 0, + Length: 128, + }, + }) + if !errors.Is(err, errBulkDedicatedSingleConn) { + t.Fatalf("client OpenDedicatedBulk over ConnectByConn error = %v, want %v", err, errBulkDedicatedSingleConn) + } +} + +func TestOpenBulkAutoConnectByConnFallsBackToShared(t *testing.T) { + acceptCh := make(chan BulkAcceptInfo, 2) + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err) + } + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + }) + + client := NewClient().(*ClientCommon) + if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err) + } + + left, right := net.Pipe() + defer right.Close() + bootstrapPeerAttachConnForTest(t, server, right) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("ConnectByConn failed: %v", err) + } + defer func() { + client.setByeFromServer(true) + _ = client.Stop() + }() + + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Mode: BulkOpenModeAuto, + Range: BulkRange{ + Offset: 0, + Length: 128, + }, + }) + if err != nil { + t.Fatalf("client OpenBulk auto over ConnectByConn failed: %v", err) + } + if bulk.Snapshot().Dedicated { + t.Fatal("client OpenBulk auto over ConnectByConn should fall back to shared") + } + defer func() { + _ = bulk.Close() + }() + + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + if accepted.Dedicated { + t.Fatal("server accepted bulk should be shared after ConnectByConn auto fallback") + } + if _, err := bulk.Write([]byte("shared-over-single-conn")); err != nil { + t.Fatalf("client bulk Write failed: %v", err) + } + readBulkExactly(t, accepted.Bulk, "shared-over-single-conn", 2*time.Second) + select { + case extra := <-acceptCh: + t.Fatalf("unexpected extra server bulk accept: %+v", extra) + case <-time.After(300 * time.Millisecond): + } + _ = accepted.Bulk.Close() +} + +func TestOpenDedicatedBulkExternalTransportDialableSourceSucceeds(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err) + } + acceptCh := make(chan BulkAcceptInfo, 2) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + client.setByeFromServer(true) + _ = client.Stop() + }() + + bulk, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{ + ID: "external-dedicated-dialable", + Range: BulkRange{ + Offset: 0, + Length: 64, + }, + }) + if err != nil { + t.Fatalf("client OpenDedicatedBulk failed: %v", err) + } + if !bulk.Snapshot().Dedicated { + t.Fatal("client OpenDedicatedBulk over dialable external transport should stay dedicated") + } + defer func() { + _ = bulk.Close() + }() + + accepted := waitAcceptedBulkByID(t, acceptCh, bulk.ID(), 2*time.Second) + if !accepted.Dedicated { + t.Fatal("server accepted bulk should stay dedicated over dialable external transport") + } + defer func() { + _ = accepted.Bulk.Close() + }() + + clientHandle := bulk.(*bulkHandle) + if clientHandle.dedicatedConnSnapshot() == nil { + t.Fatal("client dedicated sidecar conn should be attached") + } + if mainConn := client.clientTransportConnSnapshot(); mainConn != nil && clientHandle.dedicatedConnSnapshot() == mainConn { + t.Fatal("client dedicated sidecar should use an additional physical connection") + } + if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionExternal { + t.Fatalf("client steady protection mode = %v, want %v", got, ProtectionExternal) + } + + payload := "external-dedicated-sidecar" + if _, err := bulk.Write([]byte(payload)); err != nil { + t.Fatalf("client bulk Write failed: %v", err) + } + readBulkExactly(t, accepted.Bulk, payload, 2*time.Second) +} + func TestOpenDedicatedBulkWaitsForActiveSlotUntilContextDeadline(t *testing.T) { server := NewServer().(*ServerCommon) if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { diff --git a/client.go b/client.go index 5d4b8fc..666b325 100644 --- a/client.go +++ b/client.go @@ -10,75 +10,87 @@ import ( ) type ClientCommon struct { - alive atomic.Value - status Status - byeFromServer bool - conn net.Conn - mu sync.Mutex - msgID uint64 - peerIdentity string - sessionEpoch uint64 - sessionOwnerState atomic.Int32 - sessionRuntime atomic.Pointer[clientSessionRuntime] - connectSource atomic.Pointer[clientConnectSource] - queue *stario.StarQueue - stopFn context.CancelFunc - stopCtx context.Context - parallelNum int - maxReadTimeout time.Duration - maxWriteTimeout time.Duration - keyExchangeFn func(c Client) error - linkFns map[string]func(message *Message) - defaultFns func(message *Message) - msgEn func([]byte, []byte) []byte - msgDe func([]byte, []byte) []byte - fastStreamEncode transportFastStreamEncoder - fastBulkEncode transportFastBulkEncoder - fastPlainEncode transportFastPlainEncoder - modernPSKRuntime *modernPSKCodecRuntime - handshakeRsaPubKey []byte - SecretKey []byte - noFinSyncMsgMaxKeepSeconds int - lastHeartbeat int64 - heartbeatPeriod time.Duration - wg stario.WaitGroup - netType NetType - showError bool - skipKeyExchange bool - useHeartBeat bool - sequenceDe func([]byte) (interface{}, error) - sequenceEn func(interface{}) ([]byte, error) - logicalSession *logicalSessionState - onFileEvent func(FileEvent) - fileEventObserver func(FileEvent) - fileTransferCfg fileTransferConfig - signalReliableCfg signalReliabilityConfig - streamRuntime *streamRuntime - recordRuntime *recordRuntime - bulkRuntime *bulkRuntime - bulkDefaultOpenMode BulkOpenMode - bulkNetworkProfile BulkNetworkProfile - bulkOpenTuning BulkOpenTuning - bulkDedicatedAttachLimit int - bulkDedicatedAttachSem chan struct{} - bulkDedicatedAttachRetry int - bulkDedicatedAttachBackoff time.Duration - bulkDedicatedDialTimeout time.Duration - bulkDedicatedHelloTimeout time.Duration - bulkDedicatedActiveLimit int - bulkDedicatedActive atomic.Int32 - bulkDedicatedActiveWait chan struct{} - bulkDedicatedLaneLimit int - bulkDedicatedSidecarMu sync.Mutex - bulkDedicatedLanes map[uint32]*bulkDedicatedLane - bulkDedicatedNextLaneID uint32 - bulkAttachAttemptCount atomic.Int64 - bulkAttachRetryCount atomic.Int64 - bulkAttachSuccessCount atomic.Int64 - bulkAttachFallbackCount atomic.Int64 - connectionRetryState *connectionRetryState - securityReadyCheck bool - debugMode bool + alive atomic.Value + status Status + byeFromServer bool + conn net.Conn + mu sync.Mutex + msgID uint64 + peerIdentity string + sessionEpoch uint64 + sessionOwnerState atomic.Int32 + sessionRuntime atomic.Pointer[clientSessionRuntime] + connectSource atomic.Pointer[clientConnectSource] + queue *stario.StarQueue + stopFn context.CancelFunc + stopCtx context.Context + parallelNum int + maxReadTimeout time.Duration + maxWriteTimeout time.Duration + keyExchangeFn func(c Client) error + linkFns map[string]func(message *Message) + defaultFns func(message *Message) + msgEn func([]byte, []byte) []byte + msgDe func([]byte, []byte) []byte + fastStreamEncode transportFastStreamEncoder + fastBulkEncode transportFastBulkEncoder + fastPlainEncode transportFastPlainEncoder + modernPSKRuntime *modernPSKCodecRuntime + handshakeRsaPubKey []byte + SecretKey []byte + transportProtection atomic.Pointer[transportProtectionProfile] + peerAttachSecurity atomic.Pointer[peerAttachSecurityState] + securityBootstrap transportProtectionProfile + securitySteady transportProtectionProfile + securitySteadyNegotiated transportProtectionProfile + securityAuthMode AuthMode + securityProtectionMode ProtectionMode + securityRequireForwardSecrecy bool + securityConfigured bool + peerAttachAuthenticated bool + peerAttachAuthFallback bool + peerAttachAt int64 + noFinSyncMsgMaxKeepSeconds int + lastHeartbeat int64 + heartbeatPeriod time.Duration + wg stario.WaitGroup + netType NetType + showError bool + skipKeyExchange bool + useHeartBeat bool + sequenceDe func([]byte) (interface{}, error) + sequenceEn func(interface{}) ([]byte, error) + logicalSession *logicalSessionState + onFileEvent func(FileEvent) + fileEventObserver func(FileEvent) + fileTransferCfg fileTransferConfig + signalReliableCfg signalReliabilityConfig + streamRuntime *streamRuntime + recordRuntime *recordRuntime + bulkRuntime *bulkRuntime + bulkDefaultOpenMode BulkOpenMode + bulkNetworkProfile BulkNetworkProfile + bulkOpenTuning BulkOpenTuning + bulkDedicatedAttachLimit int + bulkDedicatedAttachSem chan struct{} + bulkDedicatedAttachRetry int + bulkDedicatedAttachBackoff time.Duration + bulkDedicatedDialTimeout time.Duration + bulkDedicatedHelloTimeout time.Duration + bulkDedicatedActiveLimit int + bulkDedicatedActive atomic.Int32 + bulkDedicatedActiveWait chan struct{} + bulkDedicatedLaneLimit int + bulkDedicatedSidecarMu sync.Mutex + bulkDedicatedLanes map[uint32]*bulkDedicatedLane + bulkDedicatedNextLaneID uint32 + bulkAttachAttemptCount atomic.Int64 + bulkAttachRetryCount atomic.Int64 + bulkAttachSuccessCount atomic.Int64 + bulkAttachFallbackCount atomic.Int64 + connectionRetryState *connectionRetryState + securityReadyCheck bool + debugMode bool } func NewClient() Client { @@ -134,6 +146,8 @@ func NewClient() Client { client.fileEventObserver = normalizeFileEventCallback(nil) client.stopCtx, client.stopFn = context.WithCancel(context.Background()) client.sessionRuntime.Store(newClientSessionRuntimeBase(client.stopCtx, client.stopFn)) + client.setClientTransportProtectionProfile(defaultTransportProtectionProfile()) + client.peerAttachSecurity.Store(defaultPeerAttachSecurityState()) bindClientStreamControl(&client) bindClientBulkControl(&client) client.getTransferState().setBuiltinHandler(client.builtinFileTransferHandler) diff --git a/client_bulk.go b/client_bulk.go index 81a3467..15fbd83 100644 --- a/client_bulk.go +++ b/client_bulk.go @@ -294,7 +294,7 @@ func clientBulkWriteSender(c *ClientCommon, epoch uint64) bulkWriteSender { if err := bulk.waitDedicatedReady(ctx); err != nil { return 0, err } - return c.sendDedicatedBulkWrite(ctx, bulk, startSeq, payload) + return c.sendDedicatedBulkWrite(ctx, bulk, startSeq, payload, payloadOwned) } if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) { return 0, errTransportDetached diff --git a/client_config.go b/client_config.go index a0aef26..744ba8a 100644 --- a/client_config.go +++ b/client_config.go @@ -65,32 +65,40 @@ func (c *ClientCommon) RecoverTransferSnapshots(ctx context.Context) error { } func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte { - return c.msgEn + return c.clientTransportProtectionSnapshot().msgEn } // Deprecated: SetMsgEn overrides the transport codec directly. // Prefer UseModernPSKClient or UseLegacySecurityClient. func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) { - c.msgEn = fn - c.fastStreamEncode = nil - c.fastBulkEncode = nil - c.fastPlainEncode = nil - c.modernPSKRuntime = nil + profile := c.clientTransportProtectionSnapshot() + profile.mode = ProtectionManaged + profile.msgEn = fn + profile.fastStreamEncode = nil + profile.fastBulkEncode = nil + profile.fastPlainEncode = nil + profile.runtime = nil + c.setClientTransportProtectionProfile(profile) + c.clearClientSecurityProfiles() c.securityReadyCheck = false } func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte { - return c.msgDe + return c.clientTransportProtectionSnapshot().msgDe } // Deprecated: SetMsgDe overrides the transport codec directly. // Prefer UseModernPSKClient or UseLegacySecurityClient. func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) { - c.msgDe = fn - c.fastStreamEncode = nil - c.fastBulkEncode = nil - c.fastPlainEncode = nil - c.modernPSKRuntime = nil + profile := c.clientTransportProtectionSnapshot() + profile.mode = ProtectionManaged + profile.msgDe = fn + profile.fastStreamEncode = nil + profile.fastBulkEncode = nil + profile.fastPlainEncode = nil + profile.runtime = nil + c.setClientTransportProtectionProfile(profile) + c.clearClientSecurityProfiles() c.securityReadyCheck = false } @@ -103,20 +111,24 @@ func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) { } func (c *ClientCommon) GetSecretKey() []byte { - return c.SecretKey + return c.clientTransportProtectionSnapshot().secretKey } // Deprecated: SetSecretKey injects a raw transport key directly. // Prefer UseModernPSKClient or UseLegacySecurityClient. func (c *ClientCommon) SetSecretKey(key []byte) { - c.SecretKey = key + profile := c.clientTransportProtectionSnapshot() + profile.mode = ProtectionManaged + profile.secretKey = cloneTransportProtectionKey(key) if len(key) == 0 { - c.modernPSKRuntime = nil + profile.runtime = nil } else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil { - c.modernPSKRuntime = runtime + profile.runtime = runtime } else { - c.modernPSKRuntime = nil + profile.runtime = nil } + c.setClientTransportProtectionProfile(profile) + c.clearClientSecurityProfiles() c.securityReadyCheck = len(key) == 0 c.skipKeyExchange = true } diff --git a/client_conn_attachment.go b/client_conn_attachment.go index 103a28c..0ed5fdd 100644 --- a/client_conn_attachment.go +++ b/client_conn_attachment.go @@ -6,17 +6,26 @@ import ( ) type clientConnAttachmentState struct { - maxReadTimeout time.Duration - maxWriteTimeout time.Duration - msgEn func([]byte, []byte) []byte - msgDe func([]byte, []byte) []byte - fastStreamEncode transportFastStreamEncoder - fastBulkEncode transportFastBulkEncoder - fastPlainEncode transportFastPlainEncoder - modernPSKRuntime *modernPSKCodecRuntime - handshakeRsaKey []byte - secretKey []byte - lastHeartBeat int64 + maxReadTimeout time.Duration + maxWriteTimeout time.Duration + authMode AuthMode + protectionMode ProtectionMode + msgEn func([]byte, []byte) []byte + msgDe func([]byte, []byte) []byte + fastStreamEncode transportFastStreamEncoder + fastBulkEncode transportFastBulkEncoder + fastPlainEncode transportFastPlainEncoder + modernPSKRuntime *modernPSKCodecRuntime + handshakeRsaKey []byte + secretKey []byte + keyMode string + sessionID []byte + forwardSecrecy bool + forwardSecrecyFallback bool + peerAttached bool + peerAttachFallback bool + peerAttachAt int64 + lastHeartBeat int64 } func cloneClientConnAttachmentState(src *clientConnAttachmentState) *clientConnAttachmentState { @@ -26,6 +35,7 @@ func cloneClientConnAttachmentState(src *clientConnAttachmentState) *clientConnA cloned := *src cloned.handshakeRsaKey = cloneClientConnAttachmentBytes(src.handshakeRsaKey) cloned.secretKey = cloneClientConnAttachmentBytes(src.secretKey) + cloned.sessionID = cloneClientConnAttachmentBytes(src.sessionID) return &cloned } @@ -153,6 +163,7 @@ func (c *ClientConn) applyClientConnAttachmentProfile(maxReadTimeout time.Durati c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { state.maxReadTimeout = maxReadTimeout state.maxWriteTimeout = maxWriteTimeout + state.protectionMode = ProtectionManaged state.msgEn = msgEn state.msgDe = msgDe state.modernPSKRuntime = nil @@ -206,6 +217,7 @@ func (c *ClientConn) clientConnMsgEnSnapshot() func([]byte, []byte) []byte { func (c *ClientConn) setClientConnMsgEn(fn func([]byte, []byte) []byte) { c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { + state.protectionMode = ProtectionManaged state.msgEn = fn state.fastStreamEncode = nil state.fastBulkEncode = nil @@ -223,6 +235,7 @@ func (c *ClientConn) clientConnMsgDeSnapshot() func([]byte, []byte) []byte { func (c *ClientConn) setClientConnMsgDe(fn func([]byte, []byte) []byte) { c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { + state.protectionMode = ProtectionManaged state.msgDe = fn state.fastStreamEncode = nil state.fastBulkEncode = nil @@ -319,6 +332,39 @@ func (c *LogicalConn) modernPSKRuntimeSnapshot() *modernPSKCodecRuntime { return nil } +func (c *LogicalConn) protectionModeSnapshot() ProtectionMode { + if state := c.attachmentStateRaw(); state != nil { + return state.protectionMode + } + return ProtectionManaged +} + +func (c *LogicalConn) authModeSnapshot() AuthMode { + if state := c.attachmentStateRaw(); state != nil { + return state.authMode + } + return AuthNone +} + +func (c *LogicalConn) peerAttachAuthenticatedSnapshot() (bool, bool, time.Time) { + if state := c.attachmentStateRaw(); state != nil { + if state.peerAttachAt == 0 { + return state.peerAttached, state.peerAttachFallback, time.Time{} + } + return state.peerAttached, state.peerAttachFallback, time.Unix(0, state.peerAttachAt) + } + return false, false, time.Time{} +} + +func (c *LogicalConn) markPeerAttachAuthenticated(authMode AuthMode, fallback bool, at time.Time) { + c.updateAttachmentState(func(state *clientConnAttachmentState) { + state.authMode = authMode + state.peerAttached = true + state.peerAttachFallback = fallback + state.peerAttachAt = at.UnixNano() + }) +} + func (c *LogicalConn) setModernPSKRuntime(runtime *modernPSKCodecRuntime) { c.updateAttachmentState(func(state *clientConnAttachmentState) { state.modernPSKRuntime = runtime diff --git a/client_connect_source.go b/client_connect_source.go index 0c34861..5722696 100644 --- a/client_connect_source.go +++ b/client_connect_source.go @@ -18,14 +18,18 @@ const ( var errClientReconnectSourceUnavailable = errors.New("client reconnect source is unavailable") type clientConnectSource struct { - kind string - network string - addr string - dialFn func(context.Context) (net.Conn, error) + kind string + network string + addr string + dialFn func(context.Context) (net.Conn, error) + supportsAdditional bool } func newClientConnConnectSource(conn net.Conn) *clientConnectSource { - source := &clientConnectSource{kind: clientConnectSourceConn} + source := &clientConnectSource{ + kind: clientConnectSourceConn, + supportsAdditional: false, + } if conn == nil { return source } @@ -43,9 +47,10 @@ func newClientConnConnectSource(conn net.Conn) *clientConnectSource { func newClientNetworkConnectSource(network string, addr string) *clientConnectSource { return &clientConnectSource{ - kind: clientConnectSourceNetwork, - network: network, - addr: addr, + kind: clientConnectSourceNetwork, + network: network, + addr: addr, + supportsAdditional: true, dialFn: func(context.Context) (net.Conn, error) { return transport.Dial(network, addr) }, @@ -54,9 +59,10 @@ func newClientNetworkConnectSource(network string, addr string) *clientConnectSo func newClientTimeoutConnectSource(network string, addr string, timeout time.Duration) *clientConnectSource { return &clientConnectSource{ - kind: clientConnectSourceTimeout, - network: network, - addr: addr, + kind: clientConnectSourceTimeout, + network: network, + addr: addr, + supportsAdditional: true, dialFn: func(context.Context) (net.Conn, error) { return transport.DialTimeout(network, addr, timeout) }, @@ -65,8 +71,9 @@ func newClientTimeoutConnectSource(network string, addr string, timeout time.Dur func newClientFactoryConnectSource(dialFn func(context.Context) (net.Conn, error)) *clientConnectSource { return &clientConnectSource{ - kind: clientConnectSourceFactory, - dialFn: dialFn, + kind: clientConnectSourceFactory, + dialFn: dialFn, + supportsAdditional: true, } } @@ -82,6 +89,10 @@ func (s *clientConnectSource) canReconnect() bool { return s != nil && s.dialFn != nil } +func (s *clientConnectSource) supportsAdditionalConn() bool { + return s != nil && s.supportsAdditional +} + func (s *clientConnectSource) isUDP() bool { if s == nil { return false diff --git a/client_legacy_security.go b/client_legacy_security.go index 1c16e43..b1e9071 100644 --- a/client_legacy_security.go +++ b/client_legacy_security.go @@ -31,7 +31,11 @@ func (c *ClientCommon) ExchangeKey(newKey []byte) error { if string(data.Value) != "success" { return errors.New("cannot exchange new aes-key") } - c.SecretKey = newKey + profile := c.clientTransportProtectionSnapshot() + profile.mode = ProtectionManaged + profile.secretKey = cloneTransportProtectionKey(newKey) + profile.runtime = nil + c.setClientTransportProtectionProfile(profile) time.Sleep(time.Millisecond * 100) return nil } diff --git a/client_runtime.go b/client_runtime.go index cdd3a13..29ef4ec 100644 --- a/client_runtime.go +++ b/client_runtime.go @@ -350,6 +350,8 @@ func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime, if rt == nil { return nil } + c.resetClientPeerAttachAuth() + c.activateClientBootstrapTransportProtection() if runKeyExchange && !c.skipKeyExchange { if err := c.keyExchangeFn(c); err != nil { return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "key exchange failed", err) @@ -358,6 +360,7 @@ func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime, if err := c.announceClientPeerIdentity(); err != nil { return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "peer attach failed", err) } + c.activateClientSteadyTransportProtection() return nil } diff --git a/client_transport.go b/client_transport.go index 35ab280..b3cf6c6 100644 --- a/client_transport.go +++ b/client_transport.go @@ -174,28 +174,37 @@ func (c *ClientCommon) dispatchTransportPayloadFast(payload []byte, release func } return } - if c.tryDispatchBorrowedBulkTransportPayload(payload) { - if release != nil { - release() + plain, plainRelease, err := c.decryptTransportPayloadPooled(payload, release) + if err != nil { + if c.showError || c.debugMode { + fmt.Println("client decode transport payload error", err) } return } - owned := append([]byte(nil), payload...) - if release != nil { - release() + if c.tryDispatchBorrowedTransportPlain(plain, plainRelease) { + return } if dispatcher == nil { now := time.Now() - if err := c.dispatchInboundTransportPayload(owned, now); err != nil && (c.showError || c.debugMode) { + err := c.dispatchInboundTransportPlain(plain, now) + if plainRelease != nil { + plainRelease() + } + if err != nil && (c.showError || c.debugMode) { fmt.Println("client decode envelope error", err) } return } + owned := plain + if plainRelease != nil { + owned = append([]byte(nil), plain...) + plainRelease() + } c.wg.Add(1) if !dispatcher.Dispatch(clientInboundDispatchSource(), func() { defer c.wg.Done() now := time.Now() - if err := c.dispatchInboundTransportPayload(owned, now); err != nil && (c.showError || c.debugMode) { + if err := c.dispatchInboundTransportPlain(owned, now); err != nil && (c.showError || c.debugMode) { fmt.Println("client decode envelope error", err) } }) { diff --git a/clienttype.go b/clienttype.go index bace9a6..43bdd0c 100644 --- a/clienttype.go +++ b/clienttype.go @@ -26,6 +26,8 @@ type Client interface { BulkOpenTuning() BulkOpenTuning SetBulkDedicatedAttachConfig(BulkDedicatedAttachConfig) BulkDedicatedAttachConfig() BulkDedicatedAttachConfig + SetPeerAttachSecurityConfig(PeerAttachSecurityConfig) error + PeerAttachSecurityConfig() PeerAttachSecurityConfig SetFileReceiveDir(dir string) error send(msg TransferMsg) (WaitMsg, error) sendEnvelope(env Envelope) error diff --git a/logical_conn.go b/logical_conn.go index 00cc98e..63aa755 100644 --- a/logical_conn.go +++ b/logical_conn.go @@ -404,6 +404,7 @@ func (c *LogicalConn) applyAttachmentProfile(maxReadTimeout time.Duration, maxWr c.updateAttachmentState(func(state *clientConnAttachmentState) { state.maxReadTimeout = maxReadTimeout state.maxWriteTimeout = maxWriteTimeout + state.protectionMode = ProtectionManaged state.msgEn = msgEn state.msgDe = msgDe state.fastStreamEncode = fastStreamEncode @@ -525,17 +526,23 @@ func (c *LogicalConn) runtimeSnapshot() ClientConnRuntimeSnapshot { return ClientConnRuntimeSnapshot{} } status := c.Status() + authenticated, fallback, attachAt := c.peerAttachAuthenticatedSnapshot() now := time.Now() snapshot := ClientConnRuntimeSnapshot{ - ClientID: c.clientIDSnapshot(), - Alive: status.Alive, - Reason: status.Reason, - IdentityBound: c.clientConnIdentityBoundSnapshot(), - UsesStreamTransport: c.usesStreamTransportSnapshot(), - TransportGeneration: c.transportGenerationSnapshot(), - TransportAttachCount: c.clientConnTransportAttachCountSnapshot(), - TransportDetachCount: c.clientConnTransportDetachCountSnapshot(), - LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(), + ClientID: c.clientIDSnapshot(), + Alive: status.Alive, + Reason: status.Reason, + IdentityBound: c.clientConnIdentityBoundSnapshot(), + UsesStreamTransport: c.usesStreamTransportSnapshot(), + TransportGeneration: c.transportGenerationSnapshot(), + TransportAttachCount: c.clientConnTransportAttachCountSnapshot(), + TransportDetachCount: c.clientConnTransportDetachCountSnapshot(), + LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(), + AuthMode: authModeName(c.authModeSnapshot()), + ProtectionMode: protectionModeName(c.protectionModeSnapshot()), + PeerAttachAuthenticated: authenticated, + PeerAttachAuthFallback: fallback, + LastPeerAttachAt: attachAt, } if status.Err != nil { snapshot.Error = status.Err.Error() diff --git a/logical_transport_peer_fields_test.go b/logical_transport_peer_fields_test.go index 0fb5795..e34e0ae 100644 --- a/logical_transport_peer_fields_test.go +++ b/logical_transport_peer_fields_test.go @@ -43,6 +43,20 @@ func TestHydrateServerMessagePeerFieldsFromLogicalConn(t *testing.T) { } } +func TestHydrateServerMessagePeerFieldsZeroValueDoesNotPanic(t *testing.T) { + message := hydrateServerMessagePeerFields(Message{}) + + if message.LogicalConn != nil { + t.Fatal("zero-value message should not hydrate logical conn") + } + if message.ClientConn != nil { + t.Fatal("zero-value message should not hydrate client conn") + } + if message.TransportConn != nil { + t.Fatal("zero-value message should not hydrate transport conn") + } +} + func TestServerPublishSendFileEventHydratesLogicalConnAlias(t *testing.T) { server := NewServer().(*ServerCommon) left, right := net.Pipe() diff --git a/msg.go b/msg.go index 1875c4f..c77195b 100644 --- a/msg.go +++ b/msg.go @@ -37,9 +37,10 @@ type Message struct { NetType LogicalConn *LogicalConn // Deprecated: ClientConn aliases LogicalConn for compatibility. - ClientConn *ClientConn - TransportConn *TransportConn - ServerConn Client + ClientConn *ClientConn + TransportConn *TransportConn + ServerConn Client + inboundTransportProfile *transportProtectionProfile TransferMsg Time time.Time inboundConn net.Conn @@ -58,7 +59,7 @@ type messageLogicalTransferSender interface { } type messageInboundTransferSender interface { - sendTransferInbound(*LogicalConn, *TransportConn, net.Conn, TransferMsg) error + sendTransferInbound(*LogicalConn, *TransportConn, net.Conn, *transportProtectionProfile, TransferMsg) error } func (m *Message) Reply(value MsgVal) (err error) { @@ -86,7 +87,7 @@ func (m *Message) Reply(value MsgVal) (err error) { if sender == nil { return transportDetachedErrorForPeer(logical, transport) } - return sender.sendTransferInbound(logical, transport, m.inboundConn, reply) + return sender.sendTransferInbound(logical, transport, m.inboundConn, messageInboundTransportProtectionSnapshot(m), reply) } if transport != nil { _, err = transport.sendTransfer(reply) @@ -123,12 +124,19 @@ func hydrateServerMessagePeerFields(message Message) Message { if message.LogicalConn == nil { message.LogicalConn = logicalConnFromClient(message.ClientConn) } - if message.ClientConn == nil { + if message.LogicalConn == nil && message.TransportConn != nil { + message.LogicalConn = message.TransportConn.logicalConnSnapshot() + } + if message.ClientConn == nil && message.LogicalConn != nil { message.ClientConn = message.LogicalConn.compatClientConn() } if message.TransportConn == nil && message.LogicalConn != nil { message.TransportConn = message.LogicalConn.CurrentTransportConn() } + if message.inboundConn != nil && message.inboundTransportProfile == nil && message.LogicalConn != nil { + profile := message.LogicalConn.transportProtectionProfileSnapshot() + message.inboundTransportProfile = &profile + } return message } @@ -155,3 +163,22 @@ func messageTransportConnSnapshot(message *Message) *TransportConn { } return logical.CurrentTransportConn() } + +func messageInboundTransportProtectionSnapshot(message *Message) *transportProtectionProfile { + if message == nil { + return nil + } + if message.inboundTransportProfile != nil { + return message.inboundTransportProfile + } + if message.inboundConn == nil { + return nil + } + logical := messageLogicalConnSnapshot(message) + if logical == nil { + return nil + } + profile := logical.transportProtectionProfileSnapshot() + message.inboundTransportProfile = &profile + return message.inboundTransportProfile +} diff --git a/peer_attach_auth.go b/peer_attach_auth.go new file mode 100644 index 0000000..b42d408 --- /dev/null +++ b/peer_attach_auth.go @@ -0,0 +1,517 @@ +package notify + +import ( + "bytes" + "crypto/hmac" + cryptorand "crypto/rand" + "crypto/sha256" + "encoding/binary" + "errors" + "net" + "sync" + "sync/atomic" + "time" +) + +const ( + peerAttachFeatureExplicitAuth uint64 = 1 << iota + peerAttachFeatureChannelBinding + peerAttachFeatureForwardSecrecy +) + +const ( + peerAttachNonceSize = 16 + peerAttachReplayTTL = 30 * time.Second +) + +var ( + errPeerAttachAuthInvalid = errors.New("peer attach auth invalid") + errPeerAttachReplayRejected = errors.New("peer attach replay rejected") + errPeerAttachReplayWindowFull = errors.New("peer attach replay window full") + errPeerAttachExplicitAuthRequired = errors.New("peer attach explicit auth required") + errPeerAttachChannelBindingRequired = errors.New("peer attach channel binding required") + errPeerAttachChannelBindingUnavailable = errors.New("peer attach channel binding unavailable") + errPeerAttachForwardSecrecyRequired = errors.New("peer attach forward secrecy required") +) + +type peerAttachAuthResult struct { + explicit bool + fallback bool + clientNonce []byte + serverNonce []byte + channelBinding []byte + clientECDHEPublicKey []byte +} + +type peerAttachReplayCache struct { + mu sync.Mutex + entries map[string]time.Time + rejects atomic.Int64 + overflowRejects atomic.Int64 +} + +func newPeerAttachNonce() ([]byte, error) { + buf := make([]byte, peerAttachNonceSize) + if _, err := cryptorand.Read(buf); err != nil { + return nil, err + } + return buf, nil +} + +func appendPeerAttachAuthBytes(dst []byte, data []byte) []byte { + dst = binary.BigEndian.AppendUint32(dst, uint32(len(data))) + return append(dst, data...) +} + +func appendPeerAttachAuthString(dst []byte, value string) []byte { + return appendPeerAttachAuthBytes(dst, []byte(value)) +} + +func appendPeerAttachAuthBool(dst []byte, value bool) []byte { + if value { + return append(dst, 1) + } + return append(dst, 0) +} + +func peerAttachRequestAuthPayload(req peerAttachRequest, channelBinding []byte) []byte { + buf := make([]byte, 0, 96+len(req.PeerID)+len(channelBinding)) + buf = appendPeerAttachAuthString(buf, "notify/peer-attach/request-auth/v1") + buf = binary.BigEndian.AppendUint64(buf, req.Features) + buf = appendPeerAttachAuthString(buf, req.PeerID) + buf = appendPeerAttachAuthBytes(buf, req.ClientNonce) + if supportsPeerAttachChannelBinding(req.Features) { + buf = appendPeerAttachAuthBytes(buf, channelBinding) + } + return buf +} + +func peerAttachResponseAuthPayload(req peerAttachRequest, resp peerAttachResponse, channelBinding []byte) []byte { + buf := make([]byte, 0, 160+len(req.PeerID)+len(resp.PeerID)+len(resp.Error)+len(channelBinding)) + buf = appendPeerAttachAuthString(buf, "notify/peer-attach/response-auth/v1") + buf = binary.BigEndian.AppendUint64(buf, req.Features) + buf = appendPeerAttachAuthString(buf, req.PeerID) + buf = appendPeerAttachAuthBytes(buf, req.ClientNonce) + buf = binary.BigEndian.AppendUint64(buf, resp.Features) + buf = appendPeerAttachAuthString(buf, resp.PeerID) + buf = appendPeerAttachAuthBool(buf, resp.Accepted) + buf = appendPeerAttachAuthBool(buf, resp.Reused) + buf = appendPeerAttachAuthString(buf, resp.Error) + buf = appendPeerAttachAuthBytes(buf, resp.ServerNonce) + if supportsPeerAttachChannelBinding(resp.Features) { + buf = appendPeerAttachAuthBytes(buf, channelBinding) + } + return buf +} + +func signPeerAttachPayload(secretKey []byte, payload []byte) []byte { + if len(secretKey) == 0 { + return nil + } + mac := hmac.New(sha256.New, secretKey) + _, _ = mac.Write(payload) + return mac.Sum(nil) +} + +func computePeerAttachRequestAuthTag(secretKey []byte, req peerAttachRequest, channelBinding []byte) []byte { + return signPeerAttachPayload(secretKey, peerAttachRequestAuthPayload(req, channelBinding)) +} + +func computePeerAttachResponseAuthTag(secretKey []byte, req peerAttachRequest, resp peerAttachResponse, channelBinding []byte) []byte { + return signPeerAttachPayload(secretKey, peerAttachResponseAuthPayload(req, resp, channelBinding)) +} + +func supportsExplicitPeerAttachAuth(features uint64) bool { + return features&peerAttachFeatureExplicitAuth != 0 +} + +func supportsPeerAttachChannelBinding(features uint64) bool { + return features&peerAttachFeatureChannelBinding != 0 +} + +func supportsPeerAttachForwardSecrecy(features uint64) bool { + return features&peerAttachFeatureForwardSecrecy != 0 +} + +func classifyPeerAttachRejectCounter(s *ServerCommon, err error) { + if s == nil || err == nil { + return + } + switch { + case errors.Is(err, errPeerAttachReplayRejected): + s.peerAttachReplay.rejects.Add(1) + case errors.Is(err, errPeerAttachReplayWindowFull): + s.peerAttachReplay.overflowRejects.Add(1) + case errors.Is(err, errPeerAttachExplicitAuthRequired), errors.Is(err, errPeerAttachChannelBindingRequired), errors.Is(err, errPeerAttachForwardSecrecyRequired): + s.peerAttachDowngradeRejectCount.Add(1) + case errors.Is(err, errPeerAttachChannelBindingUnavailable): + s.peerAttachBindingRejectCount.Add(1) + default: + s.peerAttachAuthRejectCount.Add(1) + } +} + +func (c *ClientCommon) shouldUseExplicitPeerAttachAuth() bool { + if c == nil || !c.securityConfigured { + return false + } + return c.securityAuthMode == AuthPSK && len(c.securityBootstrap.secretKey) != 0 +} + +func resolvePeerAttachChannelBinding(provider PeerAttachChannelBindingProvider, role PeerAttachChannelBindingRole, peerID string, conn net.Conn) ([]byte, error) { + if provider == nil { + return nil, nil + } + if conn == nil { + return nil, errPeerAttachChannelBindingUnavailable + } + binding, err := provider(PeerAttachChannelBindingContext{ + Role: role, + PeerID: peerID, + Conn: conn, + }) + if err != nil || len(binding) == 0 { + return nil, errPeerAttachChannelBindingUnavailable + } + return bytes.Clone(binding), nil +} + +func (c *ClientCommon) buildPeerAttachRequest(peerID string) (peerAttachRequest, peerAttachRequestState, error) { + cfg := c.peerAttachSecuritySnapshot() + req := peerAttachRequest{ + PeerID: stringsTrimSpaceNoAlloc(peerID), + } + if !c.shouldUseExplicitPeerAttachAuth() { + if c.clientRequiresForwardSecrecy() { + return peerAttachRequest{}, peerAttachRequestState{}, errPeerAttachForwardSecrecyRequired + } + if cfg.requireExplicitAuth { + return peerAttachRequest{}, peerAttachRequestState{}, errPeerAttachExplicitAuthRequired + } + return req, peerAttachRequestState{}, nil + } + nonce, err := newPeerAttachNonce() + if err != nil { + return peerAttachRequest{}, peerAttachRequestState{}, err + } + req.Features = peerAttachFeatureExplicitAuth + requestState := peerAttachRequestState{} + var channelBinding []byte + if cfg.channelBinding != nil { + channelBinding, err = resolvePeerAttachChannelBinding(cfg.channelBinding, PeerAttachChannelBindingRoleClient, req.PeerID, c.clientTransportConnSnapshot()) + if err != nil { + return peerAttachRequest{}, peerAttachRequestState{}, err + } + } + if len(channelBinding) != 0 { + req.Features |= peerAttachFeatureChannelBinding + } + if cfg.requireChannelBinding && !supportsPeerAttachChannelBinding(req.Features) { + return peerAttachRequest{}, peerAttachRequestState{}, errPeerAttachChannelBindingRequired + } + if c.clientSupportsForwardSecrecy() { + requestState.forwardSecrecy, err = newPeerAttachForwardSecrecyClientState() + if err != nil { + return peerAttachRequest{}, peerAttachRequestState{}, err + } + req.Features |= peerAttachFeatureForwardSecrecy + req.ClientECDHEPublicKey = bytes.Clone(requestState.forwardSecrecy.publicKey) + } + req.ClientNonce = nonce + req.AuthTag = computePeerAttachRequestAuthTag(c.securityBootstrap.secretKey, req, channelBinding) + return req, requestState, nil +} + +func (c *ClientCommon) verifyPeerAttachResponse(req peerAttachRequest, resp peerAttachResponse, requestState peerAttachRequestState) (peerAttachResponseVerifyResult, error) { + cfg := c.peerAttachSecuritySnapshot() + baseSteady := transportProtectionProfile{} + if c != nil { + baseSteady = c.securitySteady.clone().withForwardSecrecyFallback(false) + } + result := peerAttachResponseVerifyResult{steadyProfile: baseSteady} + if c == nil || !c.shouldUseExplicitPeerAttachAuth() { + if c.clientRequiresForwardSecrecy() { + return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyRequired + } + if cfg.requireExplicitAuth { + return peerAttachResponseVerifyResult{}, errPeerAttachExplicitAuthRequired + } + return result, nil + } + if !supportsExplicitPeerAttachAuth(resp.Features) { + if c.clientRequiresForwardSecrecy() { + return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyRequired + } + if cfg.requireExplicitAuth { + return peerAttachResponseVerifyResult{}, errPeerAttachExplicitAuthRequired + } + if c.clientSupportsForwardSecrecy() { + result.steadyProfile = result.steadyProfile.withForwardSecrecyFallback(true) + } + result.authFallback = true + return result, nil + } + var channelBinding []byte + if supportsPeerAttachChannelBinding(req.Features) { + if cfg.channelBinding == nil { + return peerAttachResponseVerifyResult{}, errPeerAttachChannelBindingUnavailable + } + if !supportsPeerAttachChannelBinding(resp.Features) { + return peerAttachResponseVerifyResult{}, errPeerAttachChannelBindingRequired + } + var err error + channelBinding, err = resolvePeerAttachChannelBinding(cfg.channelBinding, PeerAttachChannelBindingRoleClient, req.PeerID, c.clientTransportConnSnapshot()) + if err != nil { + return peerAttachResponseVerifyResult{}, err + } + } else if cfg.requireChannelBinding { + return peerAttachResponseVerifyResult{}, errPeerAttachChannelBindingRequired + } + if len(resp.ServerNonce) != peerAttachNonceSize { + return peerAttachResponseVerifyResult{}, errPeerAttachAuthInvalid + } + expected := computePeerAttachResponseAuthTag(c.securityBootstrap.secretKey, req, peerAttachResponse{ + PeerID: resp.PeerID, + Accepted: resp.Accepted, + Reused: resp.Reused, + Error: resp.Error, + Features: resp.Features, + ServerNonce: resp.ServerNonce, + }, channelBinding) + if !hmac.Equal(resp.AuthTag, expected) { + return peerAttachResponseVerifyResult{}, errPeerAttachAuthInvalid + } + if requestState.forwardSecrecy == nil || !supportsPeerAttachForwardSecrecy(req.Features) { + return result, nil + } + if !supportsPeerAttachForwardSecrecy(resp.Features) { + if c.clientRequiresForwardSecrecy() { + return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyRequired + } + result.steadyProfile = result.steadyProfile.withForwardSecrecyFallback(true) + return result, nil + } + if resp.KeyMode != "" && resp.KeyMode != peerAttachKeyModeECDHE { + return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyInvalid + } + profile, err := derivePeerAttachForwardSecrecyTransportProfile(c.securitySteady, c.securityBootstrap.secretKey, requestState.forwardSecrecy.privateKey, resp.ServerECDHEPublicKey, req, resp) + if err != nil { + return peerAttachResponseVerifyResult{}, err + } + result.steadyProfile = profile + return result, nil +} + +func (s *ServerCommon) validatePeerAttachRequestAuth(logical *LogicalConn, transport net.Conn, req peerAttachRequest) (peerAttachAuthResult, error) { + cfg := s.peerAttachSecuritySnapshot() + if !supportsExplicitPeerAttachAuth(req.Features) { + if s.serverRequiresForwardSecrecy() { + return peerAttachAuthResult{}, errPeerAttachForwardSecrecyRequired + } + if cfg.requireExplicitAuth { + return peerAttachAuthResult{}, errPeerAttachExplicitAuthRequired + } + return peerAttachAuthResult{fallback: true}, nil + } + if logical == nil { + return peerAttachAuthResult{}, errPeerAttachAuthInvalid + } + var channelBinding []byte + if supportsPeerAttachChannelBinding(req.Features) { + if cfg.channelBinding == nil { + return peerAttachAuthResult{}, errPeerAttachChannelBindingUnavailable + } + var err error + channelBinding, err = resolvePeerAttachChannelBinding(cfg.channelBinding, PeerAttachChannelBindingRoleServer, req.PeerID, transport) + if err != nil { + return peerAttachAuthResult{}, err + } + } else if cfg.requireChannelBinding { + return peerAttachAuthResult{}, errPeerAttachChannelBindingRequired + } + secretKey := logical.secretKeySnapshot() + if len(secretKey) == 0 { + return peerAttachAuthResult{}, errPeerAttachAuthInvalid + } + if len(req.ClientNonce) != peerAttachNonceSize { + return peerAttachAuthResult{}, errPeerAttachAuthInvalid + } + expected := computePeerAttachRequestAuthTag(secretKey, peerAttachRequest{ + PeerID: req.PeerID, + Features: req.Features, + ClientNonce: req.ClientNonce, + }, channelBinding) + if !hmac.Equal(req.AuthTag, expected) { + return peerAttachAuthResult{}, errPeerAttachAuthInvalid + } + if supportsPeerAttachForwardSecrecy(req.Features) { + if len(req.ClientECDHEPublicKey) != peerAttachECDHEPublicKeySize { + return peerAttachAuthResult{}, errPeerAttachForwardSecrecyInvalid + } + } + if err := s.acceptPeerAttachReplay(req.PeerID, req.ClientNonce, time.Now(), cfg.replayWindow, cfg.replayCapacity); err != nil { + return peerAttachAuthResult{}, err + } + serverNonce, err := newPeerAttachNonce() + if err != nil { + return peerAttachAuthResult{}, err + } + return peerAttachAuthResult{ + explicit: true, + clientNonce: bytes.Clone(req.ClientNonce), + serverNonce: serverNonce, + channelBinding: channelBinding, + clientECDHEPublicKey: bytes.Clone(req.ClientECDHEPublicKey), + }, nil +} + +func (s *ServerCommon) signPeerAttachResponse(logical *LogicalConn, req peerAttachRequest, resp *peerAttachResponse, auth peerAttachAuthResult) { + if s == nil || logical == nil || resp == nil || !auth.explicit { + return + } + secretKey := logical.secretKeySnapshot() + if len(secretKey) == 0 { + return + } + resp.Features |= peerAttachFeatureExplicitAuth + if len(auth.channelBinding) != 0 { + resp.Features |= peerAttachFeatureChannelBinding + } + resp.ServerNonce = bytes.Clone(auth.serverNonce) + resp.AuthTag = computePeerAttachResponseAuthTag(secretKey, req, *resp, auth.channelBinding) +} + +func (s *ServerCommon) preparePeerAttachSteadyTransportProfile(logical *LogicalConn, req peerAttachRequest, resp *peerAttachResponse, auth peerAttachAuthResult) (transportProtectionProfile, error) { + if s == nil { + return transportProtectionProfile{}, nil + } + profile := s.securitySteady.clone().withForwardSecrecyFallback(false) + if resp != nil && auth.explicit { + resp.Features |= peerAttachFeatureExplicitAuth + if len(auth.channelBinding) != 0 { + resp.Features |= peerAttachFeatureChannelBinding + } + resp.ServerNonce = bytes.Clone(auth.serverNonce) + } + if resp != nil && resp.KeyMode == "" { + resp.KeyMode = profile.keyMode + } + if !s.serverSupportsForwardSecrecy() { + return profile, nil + } + if !auth.explicit || !supportsPeerAttachForwardSecrecy(req.Features) { + if s.serverRequiresForwardSecrecy() { + return transportProtectionProfile{}, errPeerAttachForwardSecrecyRequired + } + return profile.withForwardSecrecyFallback(true), nil + } + fsState, err := newPeerAttachForwardSecrecyClientState() + if err != nil { + return transportProtectionProfile{}, err + } + if resp != nil { + resp.Features |= peerAttachFeatureForwardSecrecy + resp.KeyMode = peerAttachKeyModeECDHE + resp.ServerECDHEPublicKey = bytes.Clone(fsState.publicKey) + } + return derivePeerAttachForwardSecrecyTransportProfile(s.securitySteady, logical.secretKeySnapshot(), fsState.privateKey, auth.clientECDHEPublicKey, req, *resp) +} + +func (s *ServerCommon) acceptPeerAttachReplay(peerID string, nonce []byte, now time.Time, window time.Duration, capacity int) error { + if s == nil || len(nonce) == 0 { + return nil + } + cache := &s.peerAttachReplay + key := peerID + "\x00" + string(nonce) + expireBefore := now.Add(-window) + cache.mu.Lock() + defer cache.mu.Unlock() + if cache.entries == nil { + cache.entries = make(map[string]time.Time) + } + for replayKey, seenAt := range cache.entries { + if seenAt.Before(expireBefore) { + delete(cache.entries, replayKey) + } + } + if seenAt, ok := cache.entries[key]; ok && !seenAt.Before(expireBefore) { + return errPeerAttachReplayRejected + } + if capacity > 0 && len(cache.entries) >= capacity { + return errPeerAttachReplayWindowFull + } + cache.entries[key] = now + return nil +} + +func (s *ServerCommon) peerAttachReplayRejectCountSnapshot() int64 { + if s == nil { + return 0 + } + return s.peerAttachReplay.rejects.Load() +} + +func (s *ServerCommon) peerAttachReplayOverflowRejectCountSnapshot() int64 { + if s == nil { + return 0 + } + return s.peerAttachReplay.overflowRejects.Load() +} + +func (c *ClientCommon) markClientPeerAttachAuthenticated(fallback bool, at time.Time) { + if c == nil { + return + } + c.mu.Lock() + defer c.mu.Unlock() + c.peerAttachAuthenticated = true + c.peerAttachAuthFallback = fallback + c.peerAttachAt = at.UnixNano() +} + +func (c *ClientCommon) resetClientPeerAttachAuth() { + if c == nil { + return + } + c.mu.Lock() + defer c.mu.Unlock() + c.peerAttachAuthenticated = false + c.peerAttachAuthFallback = false + c.peerAttachAt = 0 +} + +func (c *ClientCommon) clientPeerAttachAuthSnapshot() (bool, bool, time.Time) { + if c == nil { + return false, false, time.Time{} + } + c.mu.Lock() + defer c.mu.Unlock() + if c.peerAttachAt == 0 { + return c.peerAttachAuthenticated, c.peerAttachAuthFallback, time.Time{} + } + return c.peerAttachAuthenticated, c.peerAttachAuthFallback, time.Unix(0, c.peerAttachAt) +} + +func stringsTrimSpaceNoAlloc(value string) string { + start := 0 + for start < len(value) { + switch value[start] { + case ' ', '\t', '\n', '\r': + start++ + default: + goto endStart + } + } + return "" +endStart: + end := len(value) + for end > start { + switch value[end-1] { + case ' ', '\t', '\n', '\r': + end-- + default: + return value[start:end] + } + } + return value[start:end] +} diff --git a/peer_attach_auth_test.go b/peer_attach_auth_test.go new file mode 100644 index 0000000..d60fe3c --- /dev/null +++ b/peer_attach_auth_test.go @@ -0,0 +1,237 @@ +package notify + +import ( + "bytes" + "errors" + "testing" +) + +func newPeerAttachAuthLogicalForTest(t *testing.T, server *ServerCommon) *LogicalConn { + t.Helper() + logical := newServerLogicalConn(server, "accepted-auth", nil) + logical = server.registerAcceptedLogical(logical) + if logical == nil { + t.Fatal("registerAcceptedLogical returned nil") + } + return logical +} + +func TestPeerAttachExplicitAuthHelpersRoundTrip(t *testing.T) { + secret := []byte("correct horse battery staple") + client := NewClient().(*ClientCommon) + server := NewServer().(*ServerCommon) + if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + logical := newPeerAttachAuthLogicalForTest(t, server) + req, reqState, err := client.buildPeerAttachRequest("peer-explicit") + if err != nil { + t.Fatalf("buildPeerAttachRequest failed: %v", err) + } + if !supportsExplicitPeerAttachAuth(req.Features) { + t.Fatalf("request features = %d, want explicit auth bit", req.Features) + } + if !supportsPeerAttachForwardSecrecy(req.Features) { + t.Fatalf("request features = %d, want forward secrecy bit", req.Features) + } + if len(req.ClientNonce) != peerAttachNonceSize { + t.Fatalf("client nonce length = %d, want %d", len(req.ClientNonce), peerAttachNonceSize) + } + + auth, err := server.validatePeerAttachRequestAuth(logical, nil, req) + if err != nil { + t.Fatalf("validatePeerAttachRequestAuth failed: %v", err) + } + if !auth.explicit || auth.fallback { + t.Fatalf("auth result mismatch: %+v", auth) + } + + resp := peerAttachResponse{ + PeerID: req.PeerID, + Accepted: true, + } + server.signPeerAttachResponse(logical, req, &resp, auth) + if !supportsExplicitPeerAttachAuth(resp.Features) { + t.Fatalf("response features = %d, want explicit auth bit", resp.Features) + } + if len(resp.ServerNonce) != peerAttachNonceSize { + t.Fatalf("server nonce length = %d, want %d", len(resp.ServerNonce), peerAttachNonceSize) + } + + verifyResult, err := client.verifyPeerAttachResponse(req, resp, reqState) + if err != nil { + t.Fatalf("verifyPeerAttachResponse failed: %v", err) + } + if verifyResult.authFallback { + t.Fatal("explicit response should not be marked as fallback") + } + if !verifyResult.steadyProfile.forwardSecrecyFallback { + t.Fatal("response without fs extension should mark forward secrecy fallback") + } + + resp.AuthTag[0] ^= 0xff + if verifyResult, err = client.verifyPeerAttachResponse(req, resp, reqState); !errors.Is(err, errPeerAttachAuthInvalid) { + t.Fatalf("tampered response error = %v, want %v", err, errPeerAttachAuthInvalid) + } else if verifyResult.authFallback { + t.Fatal("tampered explicit response should not be treated as fallback") + } +} + +func TestPeerAttachRequestAuthRejectsReplay(t *testing.T) { + secret := []byte("correct horse battery staple") + client := NewClient().(*ClientCommon) + server := NewServer().(*ServerCommon) + if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + logical := newPeerAttachAuthLogicalForTest(t, server) + req, _, err := client.buildPeerAttachRequest("peer-replay") + if err != nil { + t.Fatalf("buildPeerAttachRequest failed: %v", err) + } + + if _, err := server.validatePeerAttachRequestAuth(logical, nil, req); err != nil { + t.Fatalf("first validatePeerAttachRequestAuth failed: %v", err) + } + if _, err := server.validatePeerAttachRequestAuth(logical, nil, req); !errors.Is(err, errPeerAttachReplayRejected) { + t.Fatalf("second validatePeerAttachRequestAuth error = %v, want %v", err, errPeerAttachReplayRejected) + } + classifyPeerAttachRejectCounter(server, errPeerAttachReplayRejected) + if got, want := server.peerAttachReplayRejectCountSnapshot(), int64(1); got != want { + t.Fatalf("replay reject count = %d, want %d", got, want) + } +} + +func TestPeerAttachAuthFallbackCompatibility(t *testing.T) { + secret := []byte("correct horse battery staple") + client := NewClient().(*ClientCommon) + server := NewServer().(*ServerCommon) + if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + logical := newPeerAttachAuthLogicalForTest(t, server) + auth, err := server.validatePeerAttachRequestAuth(logical, nil, peerAttachRequest{PeerID: "peer-fallback"}) + if err != nil { + t.Fatalf("validatePeerAttachRequestAuth fallback failed: %v", err) + } + if auth.explicit || !auth.fallback { + t.Fatalf("fallback auth result mismatch: %+v", auth) + } + + req, reqState, err := client.buildPeerAttachRequest("peer-fallback") + if err != nil { + t.Fatalf("buildPeerAttachRequest failed: %v", err) + } + verifyResult, err := client.verifyPeerAttachResponse(req, peerAttachResponse{ + PeerID: req.PeerID, + Accepted: true, + }, reqState) + if err != nil { + t.Fatalf("verifyPeerAttachResponse fallback failed: %v", err) + } + if !verifyResult.authFallback { + t.Fatal("unsigned legacy response should be marked as fallback") + } +} + +func TestPeerAttachForwardSecrecyNegotiatesDerivedSteadyProfile(t *testing.T) { + secret := []byte("correct horse battery staple") + client := NewClient().(*ClientCommon) + server := NewServer().(*ServerCommon) + if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + logical := newPeerAttachAuthLogicalForTest(t, server) + req, reqState, err := client.buildPeerAttachRequest("peer-fs") + if err != nil { + t.Fatalf("buildPeerAttachRequest failed: %v", err) + } + auth, err := server.validatePeerAttachRequestAuth(logical, nil, req) + if err != nil { + t.Fatalf("validatePeerAttachRequestAuth failed: %v", err) + } + + resp := peerAttachResponse{ + PeerID: req.PeerID, + Accepted: true, + } + serverProfile, err := server.preparePeerAttachSteadyTransportProfile(logical, req, &resp, auth) + if err != nil { + t.Fatalf("preparePeerAttachSteadyTransportProfile failed: %v", err) + } + server.signPeerAttachResponse(logical, req, &resp, auth) + verifyResult, err := client.verifyPeerAttachResponse(req, resp, reqState) + if err != nil { + t.Fatalf("verifyPeerAttachResponse failed: %v", err) + } + + if !supportsPeerAttachForwardSecrecy(resp.Features) { + t.Fatalf("response features = %d, want forward secrecy bit", resp.Features) + } + if resp.KeyMode != peerAttachKeyModeECDHE { + t.Fatalf("response key mode = %q, want %q", resp.KeyMode, peerAttachKeyModeECDHE) + } + if !verifyResult.steadyProfile.forwardSecrecy { + t.Fatal("client steady profile should enable forward secrecy") + } + if !serverProfile.forwardSecrecy { + t.Fatal("server steady profile should enable forward secrecy") + } + if len(verifyResult.steadyProfile.sessionID) == 0 { + t.Fatal("client session id should be populated") + } + if !bytes.Equal(verifyResult.steadyProfile.secretKey, serverProfile.secretKey) { + t.Fatal("client/server derived steady keys should match") + } + if !bytes.Equal(verifyResult.steadyProfile.sessionID, serverProfile.sessionID) { + t.Fatal("client/server session ids should match") + } + if bytes.Equal(verifyResult.steadyProfile.secretKey, client.securityBootstrap.secretKey) { + t.Fatal("derived steady key should differ from bootstrap key") + } +} + +func TestPeerAttachForwardSecrecyStrictRejectsFallback(t *testing.T) { + secret := []byte("correct horse battery staple") + client := NewClient().(*ClientCommon) + server := NewServer().(*ServerCommon) + opts := testModernPSKOptions() + opts.RequireForwardSecrecy = true + if err := UseModernPSKClient(client, secret, opts); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := UseModernPSKServer(server, secret, opts); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + req, reqState, err := client.buildPeerAttachRequest("peer-fs-strict") + if err != nil { + t.Fatalf("buildPeerAttachRequest failed: %v", err) + } + _, err = client.verifyPeerAttachResponse(req, peerAttachResponse{ + PeerID: req.PeerID, + Accepted: true, + Features: peerAttachFeatureExplicitAuth, + ServerNonce: make([]byte, peerAttachNonceSize), + AuthTag: computePeerAttachResponseAuthTag(client.securityBootstrap.secretKey, req, peerAttachResponse{PeerID: req.PeerID, Accepted: true, Features: peerAttachFeatureExplicitAuth, ServerNonce: make([]byte, peerAttachNonceSize)}, nil), + }, reqState) + if !errors.Is(err, errPeerAttachForwardSecrecyRequired) { + t.Fatalf("verifyPeerAttachResponse error = %v, want %v", err, errPeerAttachForwardSecrecyRequired) + } +} diff --git a/peer_attach_policy.go b/peer_attach_policy.go new file mode 100644 index 0000000..6780a58 --- /dev/null +++ b/peer_attach_policy.go @@ -0,0 +1,139 @@ +package notify + +import ( + "errors" + "net" + "time" +) + +const defaultPeerAttachReplayCapacity = 4096 + +type PeerAttachChannelBindingRole string + +const ( + PeerAttachChannelBindingRoleClient PeerAttachChannelBindingRole = "client" + PeerAttachChannelBindingRoleServer PeerAttachChannelBindingRole = "server" +) + +type PeerAttachChannelBindingContext struct { + Role PeerAttachChannelBindingRole + PeerID string + Conn net.Conn +} + +type PeerAttachChannelBindingProvider func(PeerAttachChannelBindingContext) ([]byte, error) + +type PeerAttachSecurityConfig struct { + RequireExplicitAuth bool + RequireChannelBinding bool + ReplayWindow time.Duration + ReplayCapacity int + ChannelBinding PeerAttachChannelBindingProvider +} + +type peerAttachSecurityState struct { + requireExplicitAuth bool + requireChannelBinding bool + replayWindow time.Duration + replayCapacity int + channelBinding PeerAttachChannelBindingProvider +} + +var errPeerAttachChannelBindingProviderNil = errors.New("peer attach channel binding provider is nil") + +func DefaultPeerAttachSecurityConfig() PeerAttachSecurityConfig { + return PeerAttachSecurityConfig{ + ReplayWindow: peerAttachReplayTTL, + ReplayCapacity: defaultPeerAttachReplayCapacity, + } +} + +func normalizePeerAttachSecurityConfig(cfg PeerAttachSecurityConfig) (peerAttachSecurityState, error) { + if cfg.ReplayWindow <= 0 { + cfg.ReplayWindow = peerAttachReplayTTL + } + if cfg.ReplayCapacity <= 0 { + cfg.ReplayCapacity = defaultPeerAttachReplayCapacity + } + if cfg.RequireChannelBinding { + cfg.RequireExplicitAuth = true + if cfg.ChannelBinding == nil { + return peerAttachSecurityState{}, errPeerAttachChannelBindingProviderNil + } + } + return peerAttachSecurityState{ + requireExplicitAuth: cfg.RequireExplicitAuth, + requireChannelBinding: cfg.RequireChannelBinding, + replayWindow: cfg.ReplayWindow, + replayCapacity: cfg.ReplayCapacity, + channelBinding: cfg.ChannelBinding, + }, nil +} + +func peerAttachSecurityConfigFromState(state *peerAttachSecurityState) PeerAttachSecurityConfig { + if state == nil { + return DefaultPeerAttachSecurityConfig() + } + return PeerAttachSecurityConfig{ + RequireExplicitAuth: state.requireExplicitAuth, + RequireChannelBinding: state.requireChannelBinding, + ReplayWindow: state.replayWindow, + ReplayCapacity: state.replayCapacity, + ChannelBinding: state.channelBinding, + } +} + +func defaultPeerAttachSecurityState() *peerAttachSecurityState { + cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig()) + return &cfg +} + +func (c *ClientCommon) peerAttachSecuritySnapshot() peerAttachSecurityState { + if c == nil { + cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig()) + return cfg + } + if state := c.peerAttachSecurity.Load(); state != nil { + return *state + } + cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig()) + return cfg +} + +func (s *ServerCommon) peerAttachSecuritySnapshot() peerAttachSecurityState { + if s == nil { + cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig()) + return cfg + } + if state := s.peerAttachSecurity.Load(); state != nil { + return *state + } + cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig()) + return cfg +} + +func (c *ClientCommon) SetPeerAttachSecurityConfig(cfg PeerAttachSecurityConfig) error { + state, err := normalizePeerAttachSecurityConfig(cfg) + if err != nil { + return err + } + c.peerAttachSecurity.Store(&state) + return nil +} + +func (c *ClientCommon) PeerAttachSecurityConfig() PeerAttachSecurityConfig { + return peerAttachSecurityConfigFromState(c.peerAttachSecurity.Load()) +} + +func (s *ServerCommon) SetPeerAttachSecurityConfig(cfg PeerAttachSecurityConfig) error { + state, err := normalizePeerAttachSecurityConfig(cfg) + if err != nil { + return err + } + s.peerAttachSecurity.Store(&state) + return nil +} + +func (s *ServerCommon) PeerAttachSecurityConfig() PeerAttachSecurityConfig { + return peerAttachSecurityConfigFromState(s.peerAttachSecurity.Load()) +} diff --git a/peer_attach_policy_test.go b/peer_attach_policy_test.go new file mode 100644 index 0000000..f14989d --- /dev/null +++ b/peer_attach_policy_test.go @@ -0,0 +1,221 @@ +package notify + +import ( + "bytes" + "errors" + "net" + "testing" + "time" +) + +func staticPeerAttachChannelBindingProvider(material []byte) PeerAttachChannelBindingProvider { + cloned := bytes.Clone(material) + return func(PeerAttachChannelBindingContext) ([]byte, error) { + return bytes.Clone(cloned), nil + } +} + +func failingPeerAttachChannelBindingProvider(PeerAttachChannelBindingContext) ([]byte, error) { + return nil, errors.New("binding unavailable") +} + +func TestSetPeerAttachSecurityConfigRejectsMissingChannelBindingProvider(t *testing.T) { + client := NewClient().(*ClientCommon) + server := NewServer().(*ServerCommon) + cfg := PeerAttachSecurityConfig{RequireChannelBinding: true} + + if err := client.SetPeerAttachSecurityConfig(cfg); !errors.Is(err, errPeerAttachChannelBindingProviderNil) { + t.Fatalf("client SetPeerAttachSecurityConfig error = %v, want %v", err, errPeerAttachChannelBindingProviderNil) + } + if err := server.SetPeerAttachSecurityConfig(cfg); !errors.Is(err, errPeerAttachChannelBindingProviderNil) { + t.Fatalf("server SetPeerAttachSecurityConfig error = %v, want %v", err, errPeerAttachChannelBindingProviderNil) + } +} + +func TestPeerAttachRequireExplicitAuthRejectsFallbackClient(t *testing.T) { + secret := []byte("correct horse battery staple") + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{ + RequireExplicitAuth: true, + }); err != nil { + t.Fatalf("SetPeerAttachSecurityConfig failed: %v", err) + } + + logical := newPeerAttachAuthLogicalForTest(t, server) + if _, err := server.validatePeerAttachRequestAuth(logical, nil, peerAttachRequest{PeerID: "peer-fallback"}); !errors.Is(err, errPeerAttachExplicitAuthRequired) { + t.Fatalf("validatePeerAttachRequestAuth error = %v, want %v", err, errPeerAttachExplicitAuthRequired) + } + classifyPeerAttachRejectCounter(server, errPeerAttachExplicitAuthRequired) + + snapshot, snapErr := GetServerRuntimeSnapshot(server) + if snapErr != nil { + t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr) + } + if got, want := snapshot.PeerAttachDowngradeRejects, int64(1); got != want { + t.Fatalf("PeerAttachDowngradeRejects = %d, want %d", got, want) + } + if got := snapshot.PeerAttachAuthFallbacks; got != 0 { + t.Fatalf("PeerAttachAuthFallbacks = %d, want 0", got) + } +} + +func TestPeerAttachChannelBindingRoundTrip(t *testing.T) { + secret := []byte("correct horse battery staple") + bindingProvider := staticPeerAttachChannelBindingProvider([]byte("tls-exporter:test")) + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + if err := UsePSKOverExternalTransportServer(server, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err) + } + if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{ + RequireChannelBinding: true, + ChannelBinding: bindingProvider, + }); err != nil { + t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err) + } + server.SetLink("echo", func(msg *Message) { + _ = msg.Reply([]byte("ack")) + }) + }) + client := NewClient().(*ClientCommon) + if err := UsePSKOverExternalTransportClient(client, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err) + } + if err := client.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{ + RequireChannelBinding: true, + ChannelBinding: bindingProvider, + }); err != nil { + t.Fatalf("client SetPeerAttachSecurityConfig failed: %v", err) + } + + left, right := net.Pipe() + defer right.Close() + bootstrapPeerAttachLogicalForTest(t, server, right) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("ConnectByConn failed: %v", err) + } + defer func() { + client.setByeFromServer(true) + _ = client.Stop() + }() + + reply, err := client.SendWait("echo", []byte("ping"), time.Second) + if err != nil { + t.Fatalf("SendWait failed: %v", err) + } + if got, want := string(reply.Value), "ack"; got != want { + t.Fatalf("reply = %q, want %q", got, want) + } + + serverSnapshot, err := GetServerRuntimeSnapshot(server) + if err != nil { + t.Fatalf("GetServerRuntimeSnapshot failed: %v", err) + } + if !serverSnapshot.PeerAttachRequireExplicitAuth || !serverSnapshot.PeerAttachRequireChannelBinding || !serverSnapshot.PeerAttachChannelBindingConfigured { + t.Fatalf("unexpected server peer attach policy snapshot: %+v", serverSnapshot) + } + if got, want := serverSnapshot.PeerAttachExplicitAuth, int64(1); got != want { + t.Fatalf("PeerAttachExplicitAuth = %d, want %d", got, want) + } + if serverSnapshot.PeerAttachAuthRejects != 0 || serverSnapshot.PeerAttachDowngradeRejects != 0 || serverSnapshot.PeerAttachBindingRejects != 0 { + t.Fatalf("unexpected server reject counters: %+v", serverSnapshot) + } + + clientSnapshot, err := GetClientRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientRuntimeSnapshot failed: %v", err) + } + if !clientSnapshot.PeerAttachRequireExplicitAuth || !clientSnapshot.PeerAttachRequireChannelBinding || !clientSnapshot.PeerAttachChannelBindingConfigured { + t.Fatalf("unexpected client peer attach policy snapshot: %+v", clientSnapshot) + } +} + +func TestPeerAttachChannelBindingProviderFailureRejectsAttach(t *testing.T) { + secret := []byte("correct horse battery staple") + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{ + RequireChannelBinding: true, + ChannelBinding: failingPeerAttachChannelBindingProvider, + }); err != nil { + t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err) + } + }) + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{ + RequireChannelBinding: true, + ChannelBinding: staticPeerAttachChannelBindingProvider([]byte("binding")), + }); err != nil { + t.Fatalf("client SetPeerAttachSecurityConfig failed: %v", err) + } + + left, right := net.Pipe() + defer right.Close() + bootstrapPeerAttachLogicalForTest(t, server, right) + + err := client.ConnectByConn(left) + if !errors.Is(err, errPeerAttachChannelBindingUnavailable) && (err == nil || err.Error() != errPeerAttachChannelBindingUnavailable.Error()) { + t.Fatalf("ConnectByConn error = %v, want %v", err, errPeerAttachChannelBindingUnavailable) + } + + snapshot, snapErr := GetServerRuntimeSnapshot(server) + if snapErr != nil { + t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr) + } + if got, want := snapshot.PeerAttachBindingRejects, int64(1); got != want { + t.Fatalf("PeerAttachBindingRejects = %d, want %d", got, want) + } +} + +func TestPeerAttachReplayCapacityRejectsOverflow(t *testing.T) { + secret := []byte("correct horse battery staple") + client := NewClient().(*ClientCommon) + server := NewServer().(*ServerCommon) + if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{ + ReplayCapacity: 1, + }); err != nil { + t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err) + } + + logical := newPeerAttachAuthLogicalForTest(t, server) + first, _, err := client.buildPeerAttachRequest("peer-one") + if err != nil { + t.Fatalf("buildPeerAttachRequest(first) failed: %v", err) + } + second, _, err := client.buildPeerAttachRequest("peer-two") + if err != nil { + t.Fatalf("buildPeerAttachRequest(second) failed: %v", err) + } + + if _, err := server.validatePeerAttachRequestAuth(logical, nil, first); err != nil { + t.Fatalf("validatePeerAttachRequestAuth(first) failed: %v", err) + } + if _, err := server.validatePeerAttachRequestAuth(logical, nil, second); !errors.Is(err, errPeerAttachReplayWindowFull) { + t.Fatalf("validatePeerAttachRequestAuth(second) error = %v, want %v", err, errPeerAttachReplayWindowFull) + } + classifyPeerAttachRejectCounter(server, errPeerAttachReplayWindowFull) + + snapshot, snapErr := GetServerRuntimeSnapshot(server) + if snapErr != nil { + t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr) + } + if got, want := snapshot.PeerAttachReplayCapacity, 1; got != want { + t.Fatalf("PeerAttachReplayCapacity = %d, want %d", got, want) + } + if got, want := snapshot.PeerAttachReplayOverflowRejects, int64(1); got != want { + t.Fatalf("PeerAttachReplayOverflowRejects = %d, want %d", got, want) + } +} diff --git a/peer_identity.go b/peer_identity.go index d363efc..795e877 100644 --- a/peer_identity.go +++ b/peer_identity.go @@ -15,14 +15,23 @@ const ( ) type peerAttachRequest struct { - PeerID string + PeerID string + Features uint64 + ClientNonce []byte + ClientECDHEPublicKey []byte + AuthTag []byte } type peerAttachResponse struct { - PeerID string - Accepted bool - Reused bool - Error string + PeerID string + Accepted bool + Reused bool + Error string + Features uint64 + KeyMode string + ServerNonce []byte + ServerECDHEPublicKey []byte + AuthTag []byte } func newClientPeerIdentity() string { @@ -108,7 +117,11 @@ func (c *ClientCommon) announceClientPeerIdentity() error { if peerID == "" { return errors.New("peer identity is empty") } - encoded, err := c.sequenceEn(peerAttachRequest{PeerID: peerID}) + req, requestState, err := c.buildPeerAttachRequest(peerID) + if err != nil { + return err + } + encoded, err := c.sequenceEn(req) if err != nil { return err } @@ -133,6 +146,12 @@ func (c *ClientCommon) announceClientPeerIdentity() error { } return errors.New("peer attach rejected") } + verifyResult, err := c.verifyPeerAttachResponse(req, resp, requestState) + if err != nil { + return err + } + c.setClientNegotiatedSteadyTransportProtection(verifyResult.steadyProfile) + c.markClientPeerAttachAuthenticated(verifyResult.authFallback, time.Now()) return nil } @@ -188,7 +207,7 @@ func (s *ServerCommon) replyPeerAttach(client *LogicalConn, message Message, res Type: MSG_SYS_REPLY, } if message.inboundConn != nil { - return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply) + return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, messageInboundTransportProtectionSnapshot(&message), reply) } _, err = s.sendLogical(client, reply) return err @@ -200,6 +219,10 @@ func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool { } message = hydrateServerMessagePeerFields(message) current := messageLogicalConnSnapshot(&message) + transport := message.inboundConn + if transport == nil && current != nil { + transport = current.transportSnapshot() + } req, err := decodePeerAttachRequest(s.sequenceDe, message.Value) if err != nil { if current != nil { @@ -210,6 +233,18 @@ func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool { } return true } + auth, err := s.validatePeerAttachRequestAuth(current, transport, req) + if err != nil { + classifyPeerAttachRejectCounter(s, err) + if current != nil { + _ = s.replyPeerAttach(current, message, peerAttachResponse{ + PeerID: req.PeerID, + Accepted: false, + Error: err.Error(), + }) + } + return true + } bound, reused, err := s.bindAcceptedClientIdentity(current, req.PeerID) if err != nil { if current != nil { @@ -221,12 +256,37 @@ func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool { } return true } - if err := s.replyPeerAttach(bound, message, peerAttachResponse{ + resp := peerAttachResponse{ PeerID: bound.ID(), Accepted: true, Reused: reused, - }); err != nil && bound != nil { + } + steadyProfile, err := s.preparePeerAttachSteadyTransportProfile(bound, req, &resp, auth) + if err != nil { + if bound != nil { + _ = s.replyPeerAttach(bound, message, peerAttachResponse{ + PeerID: req.PeerID, + Accepted: false, + Error: err.Error(), + }) + } + return true + } + s.signPeerAttachResponse(bound, req, &resp, auth) + if bound != nil { + bound.markPeerAttachAuthenticated(s.securityAuthMode, auth.fallback, time.Now()) + if auth.explicit { + s.peerAttachExplicitCount.Add(1) + } else if auth.fallback { + s.peerAttachAuthFallbackCount.Add(1) + } + } + if err := s.replyPeerAttach(bound, message, resp); err != nil && bound != nil { s.stopLogicalSession(bound, "peer attach reply failed", err) + return true + } + if bound != nil && s.securityConfigured { + bound.applyTransportProtectionProfile(steadyProfile) } return true } diff --git a/peer_identity_test.go b/peer_identity_test.go index e90ca37..154560e 100644 --- a/peer_identity_test.go +++ b/peer_identity_test.go @@ -119,6 +119,7 @@ func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) { defer serverConn.Close() logical := bootstrapPeerAttachLogicalForTest(t, server, serverConn) + originalProfile := logical.transportProtectionProfileSnapshot() message := Message{ NetType: NET_SERVER, LogicalConn: logical, @@ -131,6 +132,13 @@ func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) { Time: time.Now(), inboundConn: serverConn, } + message = hydrateServerMessagePeerFields(message) + + alternate, err := deriveModernPSKProtectionProfile([]byte("notify-peer-attach-reply-alternate"), testModernPSKOptions(), ProtectionManaged) + if err != nil { + t.Fatalf("deriveModernPSKProtectionProfile failed: %v", err) + } + logical.applyTransportProtectionProfile(alternate) done := make(chan error, 1) go func() { @@ -140,7 +148,7 @@ func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) { }) }() - env := readServerEnvelopeFromConn(t, server, logical, clientConn, time.Second) + env := readServerEnvelopeFromConnWithProfile(t, server, originalProfile, clientConn, time.Second) if env.Kind != EnvelopeSignal { t.Fatalf("reply envelope kind = %v, want %v", env.Kind, EnvelopeSignal) } diff --git a/security_forward_secrecy.go b/security_forward_secrecy.go new file mode 100644 index 0000000..197c462 --- /dev/null +++ b/security_forward_secrecy.go @@ -0,0 +1,139 @@ +package notify + +import ( + "bytes" + "crypto/ecdh" + "crypto/hmac" + cryptorand "crypto/rand" + "crypto/sha256" + "encoding/binary" + "errors" +) + +const ( + peerAttachECDHEPublicKeySize = 32 + peerAttachSessionIDSize = 16 + + peerAttachKeyModeStatic = "psk-static" + peerAttachKeyModeECDHE = "psk-ecdhe" + transportKeyModeExternal = "external" +) + +var errPeerAttachForwardSecrecyInvalid = errors.New("peer attach forward secrecy is invalid") + +type peerAttachRequestState struct { + forwardSecrecy *peerAttachForwardSecrecyClientState +} + +type peerAttachForwardSecrecyClientState struct { + privateKey *ecdh.PrivateKey + publicKey []byte +} + +type peerAttachResponseVerifyResult struct { + authFallback bool + steadyProfile transportProtectionProfile +} + +func newPeerAttachForwardSecrecyClientState() (*peerAttachForwardSecrecyClientState, error) { + curve := ecdh.X25519() + privateKey, err := curve.GenerateKey(cryptorand.Reader) + if err != nil { + return nil, err + } + publicKey := privateKey.PublicKey().Bytes() + if len(publicKey) != peerAttachECDHEPublicKeySize { + return nil, errPeerAttachForwardSecrecyInvalid + } + return &peerAttachForwardSecrecyClientState{ + privateKey: privateKey, + publicKey: bytes.Clone(publicKey), + }, nil +} + +func derivePeerAttachForwardSecrecyTransportProfile(base transportProtectionProfile, bootstrapKey []byte, localPrivateKey *ecdh.PrivateKey, peerPublicKey []byte, req peerAttachRequest, resp peerAttachResponse) (transportProtectionProfile, error) { + if len(bootstrapKey) == 0 || localPrivateKey == nil || len(peerPublicKey) != peerAttachECDHEPublicKeySize { + return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid + } + curve := ecdh.X25519() + publicKey, err := curve.NewPublicKey(peerPublicKey) + if err != nil { + return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid + } + sharedSecret, err := localPrivateKey.ECDH(publicKey) + if err != nil { + return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid + } + transcriptHash := peerAttachForwardSecrecyTranscriptHash(req, resp) + ikm := make([]byte, 0, len(sharedSecret)+len(transcriptHash)) + ikm = append(ikm, sharedSecret...) + ikm = append(ikm, transcriptHash...) + prk := hkdfExtractSHA256(bootstrapKey, ikm) + sessionKey := hkdfExpandSHA256(prk, []byte("notify/transport/session/v1"), 32) + sessionID := hkdfExpandSHA256(prk, []byte("notify/session-id/v1"), peerAttachSessionIDSize) + return deriveModernPSKSessionProtectionProfile(base, sessionKey, sessionID) +} + +func peerAttachForwardSecrecyTranscriptHash(req peerAttachRequest, resp peerAttachResponse) []byte { + buf := make([]byte, 0, 256+len(req.PeerID)+len(resp.PeerID)+len(resp.Error)+len(req.ClientECDHEPublicKey)+len(resp.ServerECDHEPublicKey)) + buf = appendPeerAttachTranscriptString(buf, "notify/peer-attach/forward-secrecy/v1") + buf = binary.BigEndian.AppendUint64(buf, req.Features) + buf = appendPeerAttachTranscriptString(buf, req.PeerID) + buf = appendPeerAttachTranscriptBytes(buf, req.ClientNonce) + buf = appendPeerAttachTranscriptBytes(buf, req.ClientECDHEPublicKey) + buf = binary.BigEndian.AppendUint64(buf, resp.Features) + buf = appendPeerAttachTranscriptString(buf, resp.PeerID) + buf = appendPeerAttachTranscriptBool(buf, resp.Accepted) + buf = appendPeerAttachTranscriptBool(buf, resp.Reused) + buf = appendPeerAttachTranscriptString(buf, resp.Error) + buf = appendPeerAttachTranscriptBytes(buf, resp.ServerNonce) + buf = appendPeerAttachTranscriptString(buf, resp.KeyMode) + buf = appendPeerAttachTranscriptBytes(buf, resp.ServerECDHEPublicKey) + sum := sha256.Sum256(buf) + return sum[:] +} + +func appendPeerAttachTranscriptBytes(dst []byte, data []byte) []byte { + dst = binary.BigEndian.AppendUint32(dst, uint32(len(data))) + return append(dst, data...) +} + +func appendPeerAttachTranscriptString(dst []byte, value string) []byte { + return appendPeerAttachTranscriptBytes(dst, []byte(value)) +} + +func appendPeerAttachTranscriptBool(dst []byte, value bool) []byte { + if value { + return append(dst, 1) + } + return append(dst, 0) +} + +func hkdfExtractSHA256(salt []byte, ikm []byte) []byte { + mac := hmac.New(sha256.New, salt) + _, _ = mac.Write(ikm) + return mac.Sum(nil) +} + +func hkdfExpandSHA256(prk []byte, info []byte, size int) []byte { + if size <= 0 { + return nil + } + out := make([]byte, 0, size) + var block []byte + for counter := byte(1); len(out) < size; counter++ { + mac := hmac.New(sha256.New, prk) + if len(block) != 0 { + _, _ = mac.Write(block) + } + _, _ = mac.Write(info) + _, _ = mac.Write([]byte{counter}) + block = mac.Sum(nil) + remaining := size - len(out) + if remaining > len(block) { + remaining = len(block) + } + out = append(out, block[:remaining]...) + } + return out +} diff --git a/security_profile.go b/security_profile.go new file mode 100644 index 0000000..9b11d9f --- /dev/null +++ b/security_profile.go @@ -0,0 +1,408 @@ +package notify + +import "bytes" + +// AuthMode describes how notify authenticates the peer during bootstrap. +type AuthMode int + +const ( + AuthNone AuthMode = iota + AuthPSK + AuthExternalPeer +) + +// ProtectionMode describes how notify protects steady-state transport payloads. +type ProtectionMode int + +const ( + ProtectionManaged ProtectionMode = iota + ProtectionExternal + ProtectionNested +) + +// SecurityOptions describes the high-level auth/protection policy. +// +// The current implementation still exposes dedicated helper constructors such +// as UseModernPSKClient/Server and UsePSKOverExternalTransportClient/Server. +type SecurityOptions struct { + AuthMode AuthMode + ProtectionMode ProtectionMode + SharedSecret []byte + RequireForwardSecrecy bool +} + +func authModeName(mode AuthMode) string { + switch mode { + case AuthNone: + return "none" + case AuthPSK: + return "psk" + case AuthExternalPeer: + return "external-peer" + default: + return "unknown" + } +} + +func protectionModeName(mode ProtectionMode) string { + switch mode { + case ProtectionManaged: + return "managed" + case ProtectionExternal: + return "external" + case ProtectionNested: + return "nested" + default: + return "unknown" + } +} + +type transportProtectionProfile struct { + mode ProtectionMode + msgEn func([]byte, []byte) []byte + msgDe func([]byte, []byte) []byte + fastStreamEncode transportFastStreamEncoder + fastBulkEncode transportFastBulkEncoder + fastPlainEncode transportFastPlainEncoder + runtime *modernPSKCodecRuntime + secretKey []byte + keyMode string + sessionID []byte + forwardSecrecy bool + forwardSecrecyFallback bool +} + +func cloneTransportProtectionKey(src []byte) []byte { + if len(src) == 0 { + return nil + } + return bytes.Clone(src) +} + +func newTransportProtectionProfile(mode ProtectionMode, bundle modernPSKTransportBundle, runtime *modernPSKCodecRuntime, secretKey []byte) transportProtectionProfile { + return transportProtectionProfile{ + mode: mode, + msgEn: bundle.msgEn, + msgDe: bundle.msgDe, + fastStreamEncode: bundle.fastStreamEncode, + fastBulkEncode: bundle.fastBulkEncode, + fastPlainEncode: bundle.fastPlainEncode, + runtime: runtime, + secretKey: cloneTransportProtectionKey(secretKey), + keyMode: defaultTransportKeyMode(mode, secretKey), + } +} + +func defaultTransportKeyMode(mode ProtectionMode, secretKey []byte) string { + if len(secretKey) == 0 { + return "" + } + switch mode { + case ProtectionManaged, ProtectionNested: + return peerAttachKeyModeStatic + case ProtectionExternal: + return transportKeyModeExternal + default: + return "" + } +} + +func cloneTransportSessionID(src []byte) []byte { + if len(src) == 0 { + return nil + } + return bytes.Clone(src) +} + +func (p transportProtectionProfile) clone() transportProtectionProfile { + p.secretKey = cloneTransportProtectionKey(p.secretKey) + p.sessionID = cloneTransportSessionID(p.sessionID) + return p +} + +func (p transportProtectionProfile) withForwardSecrecyFallback(fallback bool) transportProtectionProfile { + p = p.clone() + p.forwardSecrecy = false + p.forwardSecrecyFallback = fallback + p.sessionID = nil + return p +} + +func passthroughTransportCodec(_ []byte, payload []byte) []byte { + return payload +} + +func passthroughFastPlainEncode(_ []byte, plainLen int, fill func([]byte) error) ([]byte, error) { + if plainLen < 0 { + return nil, errTransportPayloadEncryptFailed + } + buf := make([]byte, plainLen) + if fill != nil { + if err := fill(buf); err != nil { + return nil, err + } + } + return buf, nil +} + +func buildExternalTransportBundle() modernPSKTransportBundle { + fastPlainEncode := passthroughFastPlainEncode + fastStreamEncode := func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { + return encodeStreamFastFramePayloadFast(fastPlainEncode, secretKey, streamFastDataFrame{ + DataID: dataID, + Seq: seq, + Payload: payload, + }) + } + fastBulkEncode := func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { + return encodeBulkFastFramePayloadFast(fastPlainEncode, secretKey, bulkFastFrame{ + Type: bulkFastPayloadTypeData, + DataID: dataID, + Seq: seq, + Payload: payload, + }) + } + return modernPSKTransportBundle{ + msgEn: passthroughTransportCodec, + msgDe: passthroughTransportCodec, + fastStreamEncode: fastStreamEncode, + fastBulkEncode: fastBulkEncode, + fastPlainEncode: fastPlainEncode, + } +} + +func defaultTransportProtectionProfile() transportProtectionProfile { + return newTransportProtectionProfile(ProtectionManaged, defaultModernPSKTransportBundle(), nil, nil) +} + +func (c *ClientCommon) clientTransportProtectionSnapshot() transportProtectionProfile { + if c == nil { + return transportProtectionProfile{} + } + if state := c.transportProtection.Load(); state != nil { + return *state + } + return transportProtectionProfile{ + mode: ProtectionManaged, + msgEn: c.msgEn, + msgDe: c.msgDe, + fastStreamEncode: c.fastStreamEncode, + fastBulkEncode: c.fastBulkEncode, + fastPlainEncode: c.fastPlainEncode, + runtime: c.modernPSKRuntime, + secretKey: c.SecretKey, + keyMode: defaultTransportKeyMode(ProtectionManaged, c.SecretKey), + } +} + +func (c *ClientCommon) setClientTransportProtectionProfile(profile transportProtectionProfile) { + if c == nil { + return + } + profile.secretKey = cloneTransportProtectionKey(profile.secretKey) + profile.sessionID = cloneTransportSessionID(profile.sessionID) + c.msgEn = profile.msgEn + c.msgDe = profile.msgDe + c.fastStreamEncode = profile.fastStreamEncode + c.fastBulkEncode = profile.fastBulkEncode + c.fastPlainEncode = profile.fastPlainEncode + c.modernPSKRuntime = profile.runtime + c.SecretKey = profile.secretKey + c.transportProtection.Store(&profile) +} + +func (c *ClientCommon) clearClientSecurityProfiles() { + if c == nil { + return + } + c.securityConfigured = false + c.securityAuthMode = AuthNone + c.securityProtectionMode = ProtectionManaged + c.securityBootstrap = transportProtectionProfile{} + c.securitySteady = transportProtectionProfile{} + c.securitySteadyNegotiated = transportProtectionProfile{} + c.securityRequireForwardSecrecy = false + c.peerAttachAuthenticated = false + c.peerAttachAuthFallback = false + c.peerAttachAt = 0 +} + +func (c *ClientCommon) configureClientSecurityProfiles(authMode AuthMode, protectionMode ProtectionMode, bootstrap transportProtectionProfile, steady transportProtectionProfile, requireForwardSecrecy bool) { + if c == nil { + return + } + c.securityConfigured = true + c.securityAuthMode = authMode + c.securityProtectionMode = protectionMode + c.securityBootstrap = bootstrap.clone() + c.securitySteady = steady.clone() + c.securitySteadyNegotiated = steady.clone() + c.securityRequireForwardSecrecy = requireForwardSecrecy + c.setClientTransportProtectionProfile(bootstrap) + c.securityReadyCheck = len(bootstrap.secretKey) != 0 + c.skipKeyExchange = true +} + +func (c *ClientCommon) activateClientBootstrapTransportProtection() { + if c == nil || !c.securityConfigured { + return + } + c.resetClientNegotiatedSteadyTransportProtection() + c.setClientTransportProtectionProfile(c.securityBootstrap) +} + +func (c *ClientCommon) activateClientSteadyTransportProtection() { + if c == nil || !c.securityConfigured { + return + } + c.setClientTransportProtectionProfile(c.securitySteadyNegotiated) +} + +func (c *ClientCommon) resetClientNegotiatedSteadyTransportProtection() { + if c == nil { + return + } + c.securitySteadyNegotiated = c.securitySteady.clone().withForwardSecrecyFallback(false) +} + +func (c *ClientCommon) setClientNegotiatedSteadyTransportProtection(profile transportProtectionProfile) { + if c == nil { + return + } + c.securitySteadyNegotiated = profile.clone() +} + +func (c *ClientCommon) clientSupportsForwardSecrecy() bool { + if c == nil || !c.securityConfigured { + return false + } + if c.securityAuthMode != AuthPSK { + return false + } + if c.securityProtectionMode == ProtectionExternal { + return false + } + return len(c.securityBootstrap.secretKey) != 0 +} + +func (c *ClientCommon) clientRequiresForwardSecrecy() bool { + if c == nil { + return false + } + return c.securityRequireForwardSecrecy +} + +func (s *ServerCommon) serverSupportsForwardSecrecy() bool { + if s == nil || !s.securityConfigured { + return false + } + if s.securityAuthMode != AuthPSK { + return false + } + if s.securityProtectionMode == ProtectionExternal { + return false + } + return len(s.securityBootstrap.secretKey) != 0 +} + +func (s *ServerCommon) serverRequiresForwardSecrecy() bool { + if s == nil { + return false + } + return s.securityRequireForwardSecrecy +} + +func (s *ServerCommon) setServerDefaultTransportProtectionProfile(profile transportProtectionProfile) { + if s == nil { + return + } + profile.secretKey = cloneTransportProtectionKey(profile.secretKey) + profile.sessionID = cloneTransportSessionID(profile.sessionID) + s.defaultMsgEn = profile.msgEn + s.defaultMsgDe = profile.msgDe + s.defaultFastStreamEncode = profile.fastStreamEncode + s.defaultFastBulkEncode = profile.fastBulkEncode + s.defaultFastPlainEncode = profile.fastPlainEncode + s.defaultModernPSKRuntime = profile.runtime + s.SecretKey = profile.secretKey +} + +func (s *ServerCommon) clearServerSecurityProfiles() { + if s == nil { + return + } + s.securityConfigured = false + s.securityAuthMode = AuthNone + s.securityProtectionMode = ProtectionManaged + s.securityBootstrap = transportProtectionProfile{} + s.securitySteady = transportProtectionProfile{} + s.securityRequireForwardSecrecy = false +} + +func (s *ServerCommon) configureServerSecurityProfiles(authMode AuthMode, protectionMode ProtectionMode, bootstrap transportProtectionProfile, steady transportProtectionProfile, requireForwardSecrecy bool) { + if s == nil { + return + } + s.securityConfigured = true + s.securityAuthMode = authMode + s.securityProtectionMode = protectionMode + s.securityBootstrap = bootstrap.clone() + s.securitySteady = steady.clone() + s.securityRequireForwardSecrecy = requireForwardSecrecy + s.setServerDefaultTransportProtectionProfile(bootstrap) + s.securityReadyCheck = len(bootstrap.secretKey) != 0 +} + +func (s *ServerCommon) applyLogicalSteadyTransportProtection(logical *LogicalConn) { + if s == nil || logical == nil || !s.securityConfigured { + return + } + logical.applyTransportProtectionProfile(s.securitySteady) +} + +func transportProtectionProfileFromAttachmentState(state *clientConnAttachmentState) transportProtectionProfile { + if state == nil { + return transportProtectionProfile{} + } + return transportProtectionProfile{ + mode: state.protectionMode, + msgEn: state.msgEn, + msgDe: state.msgDe, + fastStreamEncode: state.fastStreamEncode, + fastBulkEncode: state.fastBulkEncode, + fastPlainEncode: state.fastPlainEncode, + runtime: state.modernPSKRuntime, + secretKey: cloneTransportProtectionKey(state.secretKey), + keyMode: state.keyMode, + sessionID: cloneTransportSessionID(state.sessionID), + forwardSecrecy: state.forwardSecrecy, + forwardSecrecyFallback: state.forwardSecrecyFallback, + } +} + +func (c *LogicalConn) transportProtectionProfileSnapshot() transportProtectionProfile { + if c == nil { + return transportProtectionProfile{} + } + return transportProtectionProfileFromAttachmentState(c.attachmentStateSnapshot()) +} + +func (c *LogicalConn) applyTransportProtectionProfile(profile transportProtectionProfile) { + if c == nil { + return + } + c.updateAttachmentState(func(state *clientConnAttachmentState) { + state.protectionMode = profile.mode + state.msgEn = profile.msgEn + state.msgDe = profile.msgDe + state.fastStreamEncode = profile.fastStreamEncode + state.fastBulkEncode = profile.fastBulkEncode + state.fastPlainEncode = profile.fastPlainEncode + state.modernPSKRuntime = profile.runtime + state.secretKey = cloneTransportProtectionKey(profile.secretKey) + state.keyMode = profile.keyMode + state.sessionID = cloneTransportSessionID(profile.sessionID) + state.forwardSecrecy = profile.forwardSecrecy + state.forwardSecrecyFallback = profile.forwardSecrecyFallback + }) +} diff --git a/security_psk.go b/security_psk.go index 0c0c1a3..1975e42 100644 --- a/security_psk.go +++ b/security_psk.go @@ -15,9 +15,10 @@ import ( ) var ( - errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty") - errModernPSKPayload = errors.New("invalid modern psk payload") - errModernPSKRequired = errors.New("modern psk is required: call UseModernPSKClient/UseModernPSKServer or set a transport key before Connect/Listen") + errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty") + errModernPSKPayload = errors.New("invalid modern psk payload") + errModernPSKRequired = errors.New("transport security is required: call UseModernPSKClient/UseModernPSKServer, UsePSKOverExternalTransportClient/Server, or set a transport key before Connect/Listen") + errModernPSKForwardSecrecyUnsupported = errors.New("forward secrecy is unsupported for external transport protection") ) var ( @@ -47,9 +48,10 @@ var modernPSKPayloadPool sync.Pool // The current profile derives a 32-byte transport key with Argon2id and uses // AES-GCM with a per-codec nonce prefix plus a per-message counter. type ModernPSKOptions struct { - Salt []byte - AAD []byte - Argon2Params starcrypto.Argon2Params + Salt []byte + AAD []byte + Argon2Params starcrypto.Argon2Params + RequireForwardSecrecy bool } // DefaultModernPSKOptions returns the recommended settings for the current @@ -78,24 +80,17 @@ func defaultModernPSKTransportBundle() modernPSKTransportBundle { // Argon2id, and switches message protection to AES-GCM. Configure it before // calling Connect/ConnectTimeout. func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error { - key, aad, err := deriveModernPSKKey(sharedSecret, opts) + managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged) if err != nil { return err } - transport := buildModernPSKTransportBundle(aad) - runtime, err := newModernPSKCodecRuntime(key, aad) - if err != nil { - return err - } - c.SetSecretKey(key) - c.SetMsgEn(transport.msgEn) - c.SetMsgDe(transport.msgDe) if client, ok := c.(*ClientCommon); ok { - client.fastStreamEncode = transport.fastStreamEncode - client.fastBulkEncode = transport.fastBulkEncode - client.fastPlainEncode = transport.fastPlainEncode - client.modernPSKRuntime = runtime + client.configureClientSecurityProfiles(AuthPSK, ProtectionManaged, managed, managed, opts != nil && opts.RequireForwardSecrecy) + return nil } + c.SetSecretKey(managed.secretKey) + c.SetMsgEn(managed.msgEn) + c.SetMsgDe(managed.msgDe) c.SetSkipExchangeKey(true) return nil } @@ -106,24 +101,95 @@ func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) e // It derives a transport key with Argon2id and switches message protection to // AES-GCM. Configure it before calling Listen. func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error { - key, aad, err := deriveModernPSKKey(sharedSecret, opts) + managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged) if err != nil { return err } - transport := buildModernPSKTransportBundle(aad) - runtime, err := newModernPSKCodecRuntime(key, aad) - if err != nil { - return err - } - s.SetSecretKey(key) - s.SetDefaultCommEncode(transport.msgEn) - s.SetDefaultCommDecode(transport.msgDe) if server, ok := s.(*ServerCommon); ok { - server.defaultFastStreamEncode = transport.fastStreamEncode - server.defaultFastBulkEncode = transport.fastBulkEncode - server.defaultFastPlainEncode = transport.fastPlainEncode - server.defaultModernPSKRuntime = runtime + server.configureServerSecurityProfiles(AuthPSK, ProtectionManaged, managed, managed, opts != nil && opts.RequireForwardSecrecy) + return nil } + s.SetSecretKey(managed.secretKey) + s.SetDefaultCommEncode(managed.msgEn) + s.SetDefaultCommDecode(managed.msgDe) + return nil +} + +// UsePSKOverExternalTransportClient authenticates bootstrap with PSK and then +// trusts the external channel for steady-state payload protection. +func UsePSKOverExternalTransportClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error { + if opts != nil && opts.RequireForwardSecrecy { + return errModernPSKForwardSecrecyUnsupported + } + bootstrap, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged) + if err != nil { + return err + } + steady := buildExternalProtectionProfile(bootstrap.secretKey) + if client, ok := c.(*ClientCommon); ok { + client.configureClientSecurityProfiles(AuthPSK, ProtectionExternal, bootstrap, steady, false) + return nil + } + c.SetSecretKey(bootstrap.secretKey) + c.SetMsgEn(bootstrap.msgEn) + c.SetMsgDe(bootstrap.msgDe) + c.SetSkipExchangeKey(true) + return nil +} + +// UsePSKOverExternalTransportServer authenticates bootstrap with PSK and then +// trusts the external channel for steady-state payload protection. +func UsePSKOverExternalTransportServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error { + if opts != nil && opts.RequireForwardSecrecy { + return errModernPSKForwardSecrecyUnsupported + } + bootstrap, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged) + if err != nil { + return err + } + steady := buildExternalProtectionProfile(bootstrap.secretKey) + if server, ok := s.(*ServerCommon); ok { + server.configureServerSecurityProfiles(AuthPSK, ProtectionExternal, bootstrap, steady, false) + return nil + } + s.SetSecretKey(bootstrap.secretKey) + s.SetDefaultCommEncode(bootstrap.msgEn) + s.SetDefaultCommDecode(bootstrap.msgDe) + return nil +} + +// UseNestedSecurityClient keeps notify transport protection enabled even when +// the physical connection is already protected by an outer trusted channel. +func UseNestedSecurityClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error { + managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionNested) + if err != nil { + return err + } + if client, ok := c.(*ClientCommon); ok { + client.configureClientSecurityProfiles(AuthPSK, ProtectionNested, managed, managed, opts != nil && opts.RequireForwardSecrecy) + return nil + } + c.SetSecretKey(managed.secretKey) + c.SetMsgEn(managed.msgEn) + c.SetMsgDe(managed.msgDe) + c.SetSkipExchangeKey(true) + return nil +} + +// UseNestedSecurityServer keeps notify transport protection enabled even when +// the physical connection is already protected by an outer trusted channel. +func UseNestedSecurityServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error { + managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionNested) + if err != nil { + return err + } + if server, ok := s.(*ServerCommon); ok { + server.configureServerSecurityProfiles(AuthPSK, ProtectionNested, managed, managed, opts != nil && opts.RequireForwardSecrecy) + return nil + } + s.SetSecretKey(managed.secretKey) + s.SetDefaultCommEncode(managed.msgEn) + s.SetDefaultCommDecode(managed.msgDe) return nil } @@ -132,15 +198,22 @@ func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) e // // It is kept only as an explicit fallback path for existing deployments. func UseLegacySecurityClient(c Client) { + if client, ok := c.(*ClientCommon); ok { + client.clearClientSecurityProfiles() + client.setClientTransportProtectionProfile(transportProtectionProfile{ + mode: ProtectionManaged, + msgEn: defaultMsgEn, + msgDe: defaultMsgDe, + secretKey: bytes.Clone(defaultAesKey), + }) + client.securityReadyCheck = false + client.skipKeyExchange = false + client.handshakeRsaPubKey = bytes.Clone(defaultRsaPubKey) + return + } c.SetSecretKey(bytes.Clone(defaultAesKey)) c.SetMsgEn(defaultMsgEn) c.SetMsgDe(defaultMsgDe) - if client, ok := c.(*ClientCommon); ok { - client.fastStreamEncode = nil - client.fastBulkEncode = nil - client.fastPlainEncode = nil - client.modernPSKRuntime = nil - } c.SetSkipExchangeKey(false) c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey)) } @@ -150,15 +223,21 @@ func UseLegacySecurityClient(c Client) { // // It is kept only as an explicit fallback path for existing deployments. func UseLegacySecurityServer(s Server) { + if server, ok := s.(*ServerCommon); ok { + server.clearServerSecurityProfiles() + server.setServerDefaultTransportProtectionProfile(transportProtectionProfile{ + mode: ProtectionManaged, + msgEn: defaultMsgEn, + msgDe: defaultMsgDe, + secretKey: bytes.Clone(defaultAesKey), + }) + server.securityReadyCheck = false + server.handshakeRsaKey = bytes.Clone(defaultRsaKey) + return + } s.SetSecretKey(bytes.Clone(defaultAesKey)) s.SetDefaultCommEncode(defaultMsgEn) s.SetDefaultCommDecode(defaultMsgDe) - if server, ok := s.(*ServerCommon); ok { - server.defaultFastStreamEncode = nil - server.defaultFastBulkEncode = nil - server.defaultFastPlainEncode = nil - server.defaultModernPSKRuntime = nil - } s.SetRsaPrivKey(bytes.Clone(defaultRsaKey)) } @@ -174,6 +253,40 @@ func deriveModernPSKKey(sharedSecret []byte, opts *ModernPSKOptions) ([]byte, [] return key, cfg.AAD, nil } +func deriveModernPSKProtectionProfile(sharedSecret []byte, opts *ModernPSKOptions, mode ProtectionMode) (transportProtectionProfile, error) { + key, aad, err := deriveModernPSKKey(sharedSecret, opts) + if err != nil { + return transportProtectionProfile{}, err + } + transport := buildModernPSKTransportBundle(aad) + runtime, err := newModernPSKCodecRuntime(key, aad) + if err != nil { + return transportProtectionProfile{}, err + } + return newTransportProtectionProfile(mode, transport, runtime, key), nil +} + +func buildExternalProtectionProfile(secretKey []byte) transportProtectionProfile { + return newTransportProtectionProfile(ProtectionExternal, buildExternalTransportBundle(), nil, secretKey) +} + +func deriveModernPSKSessionProtectionProfile(base transportProtectionProfile, sessionKey []byte, sessionID []byte) (transportProtectionProfile, error) { + aad := bytes.Clone(defaultModernPSKAAD) + if base.runtime != nil && len(base.runtime.aad) != 0 { + aad = bytes.Clone(base.runtime.aad) + } + runtime, err := newModernPSKCodecRuntime(sessionKey, aad) + if err != nil { + return transportProtectionProfile{}, err + } + profile := newTransportProtectionProfile(base.mode, buildModernPSKTransportBundle(aad), runtime, sessionKey) + profile.keyMode = peerAttachKeyModeECDHE + profile.sessionID = cloneTransportSessionID(sessionID) + profile.forwardSecrecy = true + profile.forwardSecrecyFallback = false + return profile, nil +} + func normalizeModernPSKOptions(opts *ModernPSKOptions) ModernPSKOptions { cfg := DefaultModernPSKOptions() if opts == nil { diff --git a/security_psk_test.go b/security_psk_test.go index ca6f67b..f7879ef 100644 --- a/security_psk_test.go +++ b/security_psk_test.go @@ -3,8 +3,10 @@ package notify import ( "bytes" "errors" + "net" "reflect" "testing" + "time" "b612.me/starcrypto" ) @@ -207,6 +209,17 @@ func TestUseModernPSKRejectsEmptySecret(t *testing.T) { } } +func TestUsePSKOverExternalTransportRejectsForwardSecrecyRequirement(t *testing.T) { + opts := testModernPSKOptions() + opts.RequireForwardSecrecy = true + if err := UsePSKOverExternalTransportClient(NewClient(), []byte("secret"), opts); !errors.Is(err, errModernPSKForwardSecrecyUnsupported) { + t.Fatalf("UsePSKOverExternalTransportClient error = %v, want %v", err, errModernPSKForwardSecrecyUnsupported) + } + if err := UsePSKOverExternalTransportServer(NewServer(), []byte("secret"), opts); !errors.Is(err, errModernPSKForwardSecrecyUnsupported) { + t.Fatalf("UsePSKOverExternalTransportServer error = %v, want %v", err, errModernPSKForwardSecrecyUnsupported) + } +} + func TestModernPSKCodecRejectsLegacyPayload(t *testing.T) { key, aad, err := deriveModernPSKKey([]byte("notify-legacy-reject"), testModernPSKOptions()) if err != nil { @@ -315,6 +328,166 @@ func TestModernPSKFastBulkEncodeRoundTrip(t *testing.T) { } } +func TestExternalTransportFastStreamEncodeRoundTrip(t *testing.T) { + transport := buildExternalTransportBundle() + wire, err := transport.fastStreamEncode(nil, 23, 7, []byte("payload")) + if err != nil { + t.Fatalf("fastStreamEncode failed: %v", err) + } + plain := transport.msgDe(nil, wire) + frame, matched, err := decodeStreamFastDataFrame(plain) + if err != nil { + t.Fatalf("decodeStreamFastDataFrame failed: %v", err) + } + if !matched { + t.Fatal("decodeStreamFastDataFrame should match fast payload") + } + if frame.DataID != 23 || frame.Seq != 7 || !bytes.Equal(frame.Payload, []byte("payload")) { + t.Fatalf("frame mismatch: %+v", frame) + } +} + +func TestExternalTransportFastBulkEncodeRoundTrip(t *testing.T) { + transport := buildExternalTransportBundle() + wire, err := transport.fastBulkEncode(nil, 41, 9, []byte("payload")) + if err != nil { + t.Fatalf("fastBulkEncode failed: %v", err) + } + plain := transport.msgDe(nil, wire) + frame, matched, err := decodeBulkFastDataFrame(plain) + if err != nil { + t.Fatalf("decodeBulkFastDataFrame failed: %v", err) + } + if !matched { + t.Fatal("decodeBulkFastDataFrame should match fast payload") + } + if frame.DataID != 41 || frame.Seq != 9 || !bytes.Equal(frame.Payload, []byte("payload")) { + t.Fatalf("frame mismatch: %+v", frame) + } +} + +func TestDecryptTransportPayloadCodecPooledExternalDefersRelease(t *testing.T) { + payload := []byte("payload") + released := false + plain, release, err := decryptTransportPayloadCodecPooled(ProtectionExternal, nil, passthroughTransportCodec, nil, payload, func() { + released = true + }) + if err != nil { + t.Fatalf("decryptTransportPayloadCodecPooled failed: %v", err) + } + if released { + t.Fatal("release should not run before caller is done") + } + if !bytes.Equal(plain, payload) { + t.Fatalf("plain mismatch: got %q want %q", plain, payload) + } + if release == nil { + t.Fatal("release callback should be preserved for external mode") + } + release() + if !released { + t.Fatal("release callback should run when caller finishes") + } +} + +func TestUsePSKOverExternalTransportConnectByConnSwitchesToExternal(t *testing.T) { + client := NewClient().(*ClientCommon) + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + if err := UsePSKOverExternalTransportServer(server, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil { + t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err) + } + server.SetLink("external-roundtrip", func(msg *Message) { + _ = msg.Reply([]byte("ack:" + string(msg.Value))) + }) + }) + if err := UsePSKOverExternalTransportClient(client, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil { + t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err) + } + if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionManaged { + t.Fatalf("client bootstrap mode = %v, want %v", got, ProtectionManaged) + } + + left, right := net.Pipe() + defer right.Close() + bootstrapPeerAttachConnForTest(t, server, right) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("ConnectByConn failed: %v", err) + } + defer func() { + client.setByeFromServer(true) + _ = client.Stop() + }() + + if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionExternal { + t.Fatalf("client steady mode = %v, want %v", got, ProtectionExternal) + } + + reply, err := client.SendWait("external-roundtrip", []byte("ping"), time.Second) + if err != nil { + t.Fatalf("SendWait failed: %v", err) + } + if got, want := string(reply.Value), "ack:ping"; got != want { + t.Fatalf("reply mismatch: got %q want %q", got, want) + } + + list := server.GetLogicalConnList() + if len(list) != 1 { + t.Fatalf("logical conn count = %d, want 1", len(list)) + } + if got := list[0].protectionModeSnapshot(); got != ProtectionExternal { + t.Fatalf("server steady mode = %v, want %v", got, ProtectionExternal) + } +} + +func TestUseNestedSecurityConnectByConnKeepsNestedMode(t *testing.T) { + client := NewClient().(*ClientCommon) + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + if err := UseNestedSecurityServer(server, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil { + t.Fatalf("UseNestedSecurityServer failed: %v", err) + } + server.SetLink("nested-roundtrip", func(msg *Message) { + _ = msg.Reply([]byte("ack:" + string(msg.Value))) + }) + }) + if err := UseNestedSecurityClient(client, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil { + t.Fatalf("UseNestedSecurityClient failed: %v", err) + } + if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionNested { + t.Fatalf("client bootstrap mode = %v, want %v", got, ProtectionNested) + } + + left, right := net.Pipe() + defer right.Close() + bootstrapPeerAttachConnForTest(t, server, right) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("ConnectByConn failed: %v", err) + } + defer func() { + client.setByeFromServer(true) + _ = client.Stop() + }() + + if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionNested { + t.Fatalf("client steady mode = %v, want %v", got, ProtectionNested) + } + + reply, err := client.SendWait("nested-roundtrip", []byte("ping"), time.Second) + if err != nil { + t.Fatalf("SendWait failed: %v", err) + } + if got, want := string(reply.Value), "ack:ping"; got != want { + t.Fatalf("reply mismatch: got %q want %q", got, want) + } + + list := server.GetLogicalConnList() + if len(list) != 1 { + t.Fatalf("logical conn count = %d, want 1", len(list)) + } + if got := list[0].protectionModeSnapshot(); got != ProtectionNested { + t.Fatalf("server steady mode = %v, want %v", got, ProtectionNested) + } +} + func TestUseLegacySecurityRoundTrip(t *testing.T) { client := NewClient() server := NewServer() diff --git a/server.go b/server.go index 94c4f4f..11b7651 100644 --- a/server.go +++ b/server.go @@ -10,52 +10,65 @@ import ( ) type ServerCommon struct { - msgID uint64 - alive atomic.Value - status Status - sessionOwnerState atomic.Int32 - sessionRuntime atomic.Pointer[serverSessionRuntime] - listener net.Listener - udpListener *net.UDPConn - queue *stario.StarQueue - stopFn context.CancelFunc - stopCtx context.Context - maxReadTimeout time.Duration - maxWriteTimeout time.Duration - parallelNum int - wg stario.WaitGroup - peerRegistry *serverPeerRegistry - mu sync.RWMutex - handshakeRsaKey []byte - SecretKey []byte - defaultMsgEn func([]byte, []byte) []byte - defaultMsgDe func([]byte, []byte) []byte - defaultFastStreamEncode transportFastStreamEncoder - defaultFastBulkEncode transportFastBulkEncoder - defaultFastPlainEncode transportFastPlainEncoder - defaultModernPSKRuntime *modernPSKCodecRuntime - linkFns map[string]func(message *Message) - defaultFns func(message *Message) - noFinSyncMsgMaxKeepSeconds int64 - maxHeartbeatLostSeconds int64 - sequenceDe func([]byte) (interface{}, error) - sequenceEn func(interface{}) ([]byte, error) - logicalSession *logicalSessionState - onFileEvent func(FileEvent) - fileEventObserver func(FileEvent) - fileTransferCfg fileTransferConfig - signalReliableCfg signalReliabilityConfig - streamRuntime *streamRuntime - recordRuntime *recordRuntime - bulkRuntime *bulkRuntime - bulkOpenTuning BulkOpenTuning - bulkDedicatedSidecarMu sync.Mutex - bulkDedicatedSidecars map[*LogicalConn]map[uint32]*bulkDedicatedSidecar - connectionRetryState *connectionRetryState - detachedClientKeepSeconds int64 - securityReadyCheck bool - showError bool - debugMode bool + msgID uint64 + alive atomic.Value + status Status + sessionOwnerState atomic.Int32 + sessionRuntime atomic.Pointer[serverSessionRuntime] + listener net.Listener + udpListener *net.UDPConn + queue *stario.StarQueue + stopFn context.CancelFunc + stopCtx context.Context + maxReadTimeout time.Duration + maxWriteTimeout time.Duration + parallelNum int + wg stario.WaitGroup + peerRegistry *serverPeerRegistry + mu sync.RWMutex + handshakeRsaKey []byte + SecretKey []byte + defaultMsgEn func([]byte, []byte) []byte + defaultMsgDe func([]byte, []byte) []byte + defaultFastStreamEncode transportFastStreamEncoder + defaultFastBulkEncode transportFastBulkEncoder + defaultFastPlainEncode transportFastPlainEncoder + defaultModernPSKRuntime *modernPSKCodecRuntime + peerAttachSecurity atomic.Pointer[peerAttachSecurityState] + securityBootstrap transportProtectionProfile + securitySteady transportProtectionProfile + securityAuthMode AuthMode + securityProtectionMode ProtectionMode + securityRequireForwardSecrecy bool + securityConfigured bool + peerAttachReplay peerAttachReplayCache + peerAttachExplicitCount atomic.Int64 + peerAttachAuthFallbackCount atomic.Int64 + peerAttachAuthRejectCount atomic.Int64 + peerAttachDowngradeRejectCount atomic.Int64 + peerAttachBindingRejectCount atomic.Int64 + linkFns map[string]func(message *Message) + defaultFns func(message *Message) + noFinSyncMsgMaxKeepSeconds int64 + maxHeartbeatLostSeconds int64 + sequenceDe func([]byte) (interface{}, error) + sequenceEn func(interface{}) ([]byte, error) + logicalSession *logicalSessionState + onFileEvent func(FileEvent) + fileEventObserver func(FileEvent) + fileTransferCfg fileTransferConfig + signalReliableCfg signalReliabilityConfig + streamRuntime *streamRuntime + recordRuntime *recordRuntime + bulkRuntime *bulkRuntime + bulkOpenTuning BulkOpenTuning + bulkDedicatedSidecarMu sync.Mutex + bulkDedicatedSidecars map[*LogicalConn]map[uint32]*bulkDedicatedSidecar + connectionRetryState *connectionRetryState + detachedClientKeepSeconds int64 + securityReadyCheck bool + showError bool + debugMode bool } func NewServer() Server { @@ -93,6 +106,8 @@ func NewServer() Server { server.defaultFns = func(message *Message) { return } + server.setServerDefaultTransportProtectionProfile(defaultTransportProtectionProfile()) + server.peerAttachSecurity.Store(defaultPeerAttachSecurityState()) server.sessionRuntime.Store(newServerSessionRuntimeBase(server.stopCtx, server.stopFn)) bindServerStreamControl(&server) bindServerBulkControl(&server) diff --git a/server_bulk.go b/server_bulk.go index 75854ec..1a4b512 100644 --- a/server_bulk.go +++ b/server_bulk.go @@ -413,7 +413,7 @@ func serverBulkWriteSender(s *ServerCommon, logical *LogicalConn, transport *Tra if err := bulk.waitDedicatedReady(ctx); err != nil { return 0, err } - return s.sendDedicatedBulkWrite(ctx, bulk.LogicalConn(), bulk, startSeq, payload) + return s.sendDedicatedBulkWrite(ctx, bulk.LogicalConn(), bulk, startSeq, payload, payloadOwned) } if transport == nil { return 0, errBulkTransportNil diff --git a/server_config.go b/server_config.go index f3f6866..3e99529 100644 --- a/server_config.go +++ b/server_config.go @@ -31,22 +31,28 @@ func (s *ServerCommon) Stop() error { // Deprecated: SetDefaultCommEncode overrides the transport codec directly. // Prefer UseModernPSKServer or UseLegacySecurityServer. func (s *ServerCommon) SetDefaultCommEncode(fn func([]byte, []byte) []byte) { - s.defaultMsgEn = fn - s.defaultFastStreamEncode = nil - s.defaultFastBulkEncode = nil - s.defaultFastPlainEncode = nil - s.defaultModernPSKRuntime = nil + profile := transportProtectionProfile{ + mode: ProtectionManaged, + msgEn: fn, + msgDe: s.defaultMsgDe, + secretKey: s.SecretKey, + } + s.setServerDefaultTransportProtectionProfile(profile) + s.clearServerSecurityProfiles() s.securityReadyCheck = false } // Deprecated: SetDefaultCommDecode overrides the transport codec directly. // Prefer UseModernPSKServer or UseLegacySecurityServer. func (s *ServerCommon) SetDefaultCommDecode(fn func([]byte, []byte) []byte) { - s.defaultMsgDe = fn - s.defaultFastStreamEncode = nil - s.defaultFastBulkEncode = nil - s.defaultFastPlainEncode = nil - s.defaultModernPSKRuntime = nil + profile := transportProtectionProfile{ + mode: ProtectionManaged, + msgEn: s.defaultMsgEn, + msgDe: fn, + secretKey: s.SecretKey, + } + s.setServerDefaultTransportProtectionProfile(profile) + s.clearServerSecurityProfiles() s.securityReadyCheck = false } @@ -98,14 +104,21 @@ func (s *ServerCommon) GetSecretKey() []byte { // Deprecated: SetSecretKey injects a raw transport key directly. // Prefer UseModernPSKServer or UseLegacySecurityServer. func (s *ServerCommon) SetSecretKey(key []byte) { - s.SecretKey = key - if len(key) == 0 { - s.defaultModernPSKRuntime = nil - } else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil { - s.defaultModernPSKRuntime = runtime - } else { - s.defaultModernPSKRuntime = nil + profile := transportProtectionProfile{ + mode: ProtectionManaged, + msgEn: s.defaultMsgEn, + msgDe: s.defaultMsgDe, + secretKey: cloneTransportProtectionKey(key), } + if len(key) == 0 { + profile.runtime = nil + } else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil { + profile.runtime = runtime + } else { + profile.runtime = nil + } + s.setServerDefaultTransportProtectionProfile(profile) + s.clearServerSecurityProfiles() s.securityReadyCheck = len(key) == 0 } diff --git a/server_inbound_reply_test.go b/server_inbound_reply_test.go index 0cc3c8f..f6a4721 100644 --- a/server_inbound_reply_test.go +++ b/server_inbound_reply_test.go @@ -3,12 +3,54 @@ package notify import ( "b612.me/stario" "context" + "errors" "math" "net" "testing" "time" ) +func readServerEnvelopeFromConnWithProfile(t *testing.T, server *ServerCommon, profile transportProtectionProfile, conn net.Conn, timeout time.Duration) Envelope { + t.Helper() + + queue := stario.NewQueue() + deadline := time.Now().Add(timeout) + buf := make([]byte, 4096) + for time.Now().Before(deadline) { + if err := conn.SetReadDeadline(deadline); err != nil { + t.Fatalf("SetReadDeadline failed: %v", err) + } + n, err := conn.Read(buf) + if n > 0 { + if parseErr := queue.ParseMessage(buf[:n], "server-inbound-profile"); parseErr != nil { + t.Fatalf("ParseMessage failed: %v", parseErr) + } + select { + case msg := <-queue.RestoreChan(): + plain, decErr := decryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, msg.Msg) + if decErr != nil { + t.Fatalf("decryptTransportPayloadCodec failed: %v", decErr) + } + env, decErr := server.decodeEnvelopePlain(plain) + if decErr != nil { + t.Fatalf("decodeEnvelopePlain failed: %v", decErr) + } + return env + default: + } + } + if err == nil { + continue + } + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + break + } + t.Fatalf("conn Read failed: %v", err) + } + t.Fatal("timed out waiting for server envelope") + return Envelope{} +} + func TestMessageReplyUsesInboundConnForStaleTransport(t *testing.T) { server := NewServer().(*ServerCommon) UseLegacySecurityServer(server) @@ -85,6 +127,64 @@ func TestMessageReplyUsesInboundConnForStaleTransport(t *testing.T) { } } +func TestMessageReplyUsesInboundProtectionSnapshotAfterLogicalSwitch(t *testing.T) { + secret := []byte("correct horse battery staple") + alternate, err := deriveModernPSKProtectionProfile([]byte("notify-reply-snapshot-alternate"), testModernPSKOptions(), ProtectionManaged) + if err != nil { + t.Fatalf("deriveModernPSKProtectionProfile failed: %v", err) + } + handlerErr := make(chan error, 1) + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + server.SetLink("reply-snapshot", func(msg *Message) { + if msg == nil || msg.LogicalConn == nil { + select { + case handlerErr <- errors.New("reply-snapshot logical is nil"): + default: + } + return + } + msg.LogicalConn.applyTransportProtectionProfile(alternate) + if err := msg.Reply([]byte("ack")); err != nil { + select { + case handlerErr <- err: + default: + } + } + }) + }) + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + + left, right := net.Pipe() + defer right.Close() + bootstrapPeerAttachConnForTest(t, server, right) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("ConnectByConn failed: %v", err) + } + defer func() { + client.setByeFromServer(true) + _ = client.Stop() + }() + + reply, err := client.SendWait("reply-snapshot", []byte("ping"), time.Second) + if err != nil { + t.Fatalf("SendWait failed: %v", err) + } + if got, want := string(reply.Value), "ack"; got != want { + t.Fatalf("reply value = %q, want %q", got, want) + } + select { + case err := <-handlerErr: + t.Fatalf("reply-snapshot handler failed: %v", err) + default: + } +} + func TestServerHandleReceivedSignalReliabilityUsesInboundConnForStaleTransport(t *testing.T) { server := NewServer().(*ServerCommon) UseLegacySecurityServer(server) diff --git a/server_inbound_source.go b/server_inbound_source.go index a205710..e5aa63b 100644 --- a/server_inbound_source.go +++ b/server_inbound_source.go @@ -93,26 +93,34 @@ func (s *ServerCommon) pushTransportPayloadSourceFast(payload []byte, release fu } return true } - if s.tryDispatchBorrowedBulkTransportPayload(source, payload) { + logical, transport := s.resolveInboundSource(source) + if logical == nil { if release != nil { release() } return true } - owned := append([]byte(nil), payload...) - if release != nil { - release() + plain, plainRelease, err := s.decryptTransportPayloadLogicalPooled(logical, payload, release) + if err != nil { + if s.showError || s.debugMode { + fmt.Println("server decode transport payload error", err) + } + return true + } + inboundConn := serverInboundConn(source) + if s.tryDispatchBorrowedTransportPlain(logical, transport, inboundConn, plain, plainRelease) { + return true + } + owned := plain + if plainRelease != nil { + owned = append([]byte(nil), plain...) + plainRelease() } s.wg.Add(1) if !dispatcher.Dispatch(serverInboundDispatchSource(source), func() { defer s.wg.Done() - logical, transport := s.resolveInboundSource(source) - if logical == nil { - return - } now := time.Now() - inboundConn := serverInboundConn(source) - if err := s.dispatchInboundTransportPayload(logical, transport, inboundConn, owned, now); err != nil && (s.showError || s.debugMode) { + if err := s.dispatchInboundTransportPlain(logical, transport, inboundConn, owned, now); err != nil && (s.showError || s.debugMode) { fmt.Println("server decode envelope error", err) } }) { diff --git a/server_send.go b/server_send.go index d6f3ae0..11b94d4 100644 --- a/server_send.go +++ b/server_send.go @@ -358,16 +358,20 @@ func (s *ServerCommon) sendEnvelopeTransport(transport *TransportConn, env Envel } func (s *ServerCommon) sendEnvelopeInboundTransport(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope) error { + return s.sendEnvelopeInboundTransportWithProfile(logical, transport, conn, nil, env) +} + +func (s *ServerCommon) sendEnvelopeInboundTransportWithProfile(logical *LogicalConn, transport *TransportConn, conn net.Conn, profile *transportProtectionProfile, env Envelope) error { if logical == nil && transport != nil { logical = transport.logicalConnSnapshot() } if logical == nil { return transportDetachedErrorForPeer(logical, transport) } - if logical.msgEnSnapshot() == nil { + if profile == nil && logical.msgEnSnapshot() == nil { return transportDetachedErrorForPeer(logical, transport) } - payload, err := s.encodeEnvelopePayloadLogical(logical, env) + payload, err := s.encodeEnvelopePayloadInbound(logical, env, profile) if err != nil { return err } @@ -402,7 +406,18 @@ func (s *ServerCommon) writeControlEnvelopePayload(logical *LogicalConn, transpo return sender.submit(payload, writeDeadlineFromTimeout(logical.maxWriteTimeoutSnapshot())) } -func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, msg TransferMsg) error { +func (s *ServerCommon) encodeEnvelopePayloadInbound(logical *LogicalConn, env Envelope, profile *transportProtectionProfile) ([]byte, error) { + if profile == nil { + return s.encodeEnvelopePayloadLogical(logical, env) + } + data, err := s.encodeEnvelopePlain(env) + if err != nil { + return nil, err + } + return encryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgEn, profile.secretKey, data) +} + +func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, profile *transportProtectionProfile, msg TransferMsg) error { if logical == nil && transport != nil { logical = transport.logicalConnSnapshot() } @@ -413,7 +428,7 @@ func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *Tran if err != nil { return err } - return s.sendEnvelopeInboundTransport(logical, transport, conn, env) + return s.sendEnvelopeInboundTransportWithProfile(logical, transport, conn, profile, env) } func (s *ServerCommon) writeEnvelopePayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte) error { diff --git a/server_session.go b/server_session.go index 177e565..f20ac48 100644 --- a/server_session.go +++ b/server_session.go @@ -110,7 +110,17 @@ func (s *ServerCommon) registerAcceptedLogical(logical *LogicalConn) *LogicalCon } logical.setServer(s) logical.applyAttachmentProfile(s.maxReadTimeout, s.maxWriteTimeout, s.defaultMsgEn, s.defaultMsgDe, s.defaultFastStreamEncode, s.defaultFastBulkEncode, s.defaultFastPlainEncode, s.handshakeRsaKey, s.SecretKey) - logical.setModernPSKRuntime(s.defaultModernPSKRuntime) + logical.updateAttachmentState(func(state *clientConnAttachmentState) { + state.authMode = s.securityAuthMode + state.peerAttached = false + state.peerAttachFallback = false + state.peerAttachAt = 0 + }) + if s.securityConfigured { + logical.applyTransportProtectionProfile(s.securityBootstrap) + } else { + logical.setModernPSKRuntime(s.defaultModernPSKRuntime) + } logical.markHeartbeatNow() return s.getPeerRegistry().registerLogical(logical) } diff --git a/servertype.go b/servertype.go index f1c6ac3..fa17161 100644 --- a/servertype.go +++ b/servertype.go @@ -26,6 +26,8 @@ type Server interface { RecoverTransferSnapshots(context.Context) error SetBulkOpenTuning(BulkOpenTuning) BulkOpenTuning() BulkOpenTuning + SetPeerAttachSecurityConfig(PeerAttachSecurityConfig) error + PeerAttachSecurityConfig() PeerAttachSecurityConfig SetFileReceiveDir(dir string) error send(c *ClientConn, msg TransferMsg) (WaitMsg, error) sendEnvelope(c *ClientConn, env Envelope) error diff --git a/session_runtime_snapshot.go b/session_runtime_snapshot.go index 03c6ee7..0daa211 100644 --- a/session_runtime_snapshot.go +++ b/session_runtime_snapshot.go @@ -1,6 +1,7 @@ package notify import ( + "encoding/hex" "errors" "time" ) @@ -17,6 +18,19 @@ type ClientRuntimeSnapshot struct { ConnectNetwork string ConnectAddress string CanReconnect bool + AuthMode string + ProtectionMode string + ProtectionKeyMode string + ForwardSecrecyEnabled bool + ForwardSecrecyFallback bool + ForwardSecrecyRequired bool + TransportSessionID string + PeerAttachAuthenticated bool + PeerAttachAuthFallback bool + LastPeerAttachAt time.Time + PeerAttachRequireExplicitAuth bool + PeerAttachRequireChannelBinding bool + PeerAttachChannelBindingConfigured bool BulkNetworkProfile string BulkDefaultMode string BulkChunkSize int @@ -37,22 +51,38 @@ type ClientRuntimeSnapshot struct { } type ServerRuntimeSnapshot struct { - OwnerState string - Alive bool - ClientCount int - DetachedClientCount int - DetachedReattachableClientCount int - DetachedExpiredClientCount int - DetachedClientKeepSec int64 - TransportAttached bool - HasRuntimeListener bool - HasRuntimeUDPListener bool - HasRuntimeQueue bool - HasRuntimeStopCtx bool - BulkChunkSize int - BulkWindowBytes int - BulkMaxInFlight int - Retry ConnectionRetrySnapshot + OwnerState string + Alive bool + ClientCount int + DetachedClientCount int + DetachedReattachableClientCount int + DetachedExpiredClientCount int + DetachedClientKeepSec int64 + TransportAttached bool + HasRuntimeListener bool + HasRuntimeUDPListener bool + HasRuntimeQueue bool + HasRuntimeStopCtx bool + AuthMode string + ProtectionMode string + ForwardSecrecySupported bool + ForwardSecrecyRequired bool + PeerAttachRequireExplicitAuth bool + PeerAttachRequireChannelBinding bool + PeerAttachChannelBindingConfigured bool + PeerAttachReplayWindow time.Duration + PeerAttachReplayCapacity int + PeerAttachExplicitAuth int64 + PeerAttachAuthFallbacks int64 + PeerAttachAuthRejects int64 + PeerAttachDowngradeRejects int64 + PeerAttachBindingRejects int64 + PeerAttachReplayRejects int64 + PeerAttachReplayOverflowRejects int64 + BulkChunkSize int + BulkWindowBytes int + BulkMaxInFlight int + Retry ConnectionRetrySnapshot } type ClientConnRuntimeSnapshot struct { @@ -82,6 +112,15 @@ type ClientConnRuntimeSnapshot struct { TransportDetachRemaining time.Duration TransportDetachExpired bool ReattachEligible bool + AuthMode string + ProtectionMode string + ProtectionKeyMode string + ForwardSecrecyEnabled bool + ForwardSecrecyFallback bool + TransportSessionID string + PeerAttachAuthenticated bool + PeerAttachAuthFallback bool + LastPeerAttachAt time.Time TransportBulkAdaptiveSoftPayloadBytes int TransportStreamAdaptiveSoftPayloadBytes int TransportStreamAdaptiveWaitThresholdBytes int @@ -108,6 +147,19 @@ func (c *ClientCommon) clientRuntimeSnapshot() ClientRuntimeSnapshot { snapshot.ConnectAddress = source.addr snapshot.CanReconnect = source.canReconnect() } + snapshot.AuthMode = authModeName(c.securityAuthMode) + snapshot.ProtectionMode = protectionModeName(c.securityProtectionMode) + protection := c.clientTransportProtectionSnapshot() + snapshot.ProtectionKeyMode = protection.keyMode + snapshot.ForwardSecrecyEnabled = protection.forwardSecrecy + snapshot.ForwardSecrecyFallback = protection.forwardSecrecyFallback + snapshot.ForwardSecrecyRequired = c.clientRequiresForwardSecrecy() + snapshot.TransportSessionID = hex.EncodeToString(protection.sessionID) + snapshot.PeerAttachAuthenticated, snapshot.PeerAttachAuthFallback, snapshot.LastPeerAttachAt = c.clientPeerAttachAuthSnapshot() + peerAttachCfg := c.peerAttachSecuritySnapshot() + snapshot.PeerAttachRequireExplicitAuth = peerAttachCfg.requireExplicitAuth + snapshot.PeerAttachRequireChannelBinding = peerAttachCfg.requireChannelBinding + snapshot.PeerAttachChannelBindingConfigured = peerAttachCfg.channelBinding != nil snapshot.BulkNetworkProfile = bulkNetworkProfileName(c.BulkNetworkProfile()) snapshot.BulkDefaultMode = bulkOpenModeName(c.BulkDefaultOpenMode()) tuning := c.BulkOpenTuning() @@ -161,6 +213,23 @@ func (s *ServerCommon) serverRuntimeSnapshot() ServerRuntimeSnapshot { snapshot.HasRuntimeQueue = rt.queue != nil snapshot.HasRuntimeStopCtx = rt.stopCtx != nil } + snapshot.AuthMode = authModeName(s.securityAuthMode) + snapshot.ProtectionMode = protectionModeName(s.securityProtectionMode) + snapshot.ForwardSecrecySupported = s.serverSupportsForwardSecrecy() + snapshot.ForwardSecrecyRequired = s.serverRequiresForwardSecrecy() + peerAttachCfg := s.peerAttachSecuritySnapshot() + snapshot.PeerAttachRequireExplicitAuth = peerAttachCfg.requireExplicitAuth + snapshot.PeerAttachRequireChannelBinding = peerAttachCfg.requireChannelBinding + snapshot.PeerAttachChannelBindingConfigured = peerAttachCfg.channelBinding != nil + snapshot.PeerAttachReplayWindow = peerAttachCfg.replayWindow + snapshot.PeerAttachReplayCapacity = peerAttachCfg.replayCapacity + snapshot.PeerAttachExplicitAuth = s.peerAttachExplicitCount.Load() + snapshot.PeerAttachAuthFallbacks = s.peerAttachAuthFallbackCount.Load() + snapshot.PeerAttachAuthRejects = s.peerAttachAuthRejectCount.Load() + snapshot.PeerAttachDowngradeRejects = s.peerAttachDowngradeRejectCount.Load() + snapshot.PeerAttachBindingRejects = s.peerAttachBindingRejectCount.Load() + snapshot.PeerAttachReplayRejects = s.peerAttachReplayRejectCountSnapshot() + snapshot.PeerAttachReplayOverflowRejects = s.peerAttachReplayOverflowRejectCountSnapshot() tuning := s.BulkOpenTuning() snapshot.BulkChunkSize = tuning.ChunkSize snapshot.BulkWindowBytes = tuning.WindowBytes @@ -171,6 +240,7 @@ func (s *ServerCommon) serverRuntimeSnapshot() ServerRuntimeSnapshot { func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot { status := c.clientConnStatusSnapshot() + attachment := c.clientConnAttachmentStateSnapshot() now := time.Now() snapshot := ClientConnRuntimeSnapshot{ ClientID: c.clientConnIDSnapshot(), @@ -182,6 +252,17 @@ func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot { TransportAttachCount: c.clientConnTransportAttachCountSnapshot(), TransportDetachCount: c.clientConnTransportDetachCountSnapshot(), LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(), + AuthMode: authModeName(attachment.authMode), + ProtectionMode: protectionModeName(attachment.protectionMode), + } + snapshot.PeerAttachAuthenticated = attachment.peerAttached + snapshot.PeerAttachAuthFallback = attachment.peerAttachFallback + snapshot.ProtectionKeyMode = attachment.keyMode + snapshot.ForwardSecrecyEnabled = attachment.forwardSecrecy + snapshot.ForwardSecrecyFallback = attachment.forwardSecrecyFallback + snapshot.TransportSessionID = hex.EncodeToString(attachment.sessionID) + if attachment.peerAttachAt != 0 { + snapshot.LastPeerAttachAt = time.Unix(0, attachment.peerAttachAt) } if status.Err != nil { snapshot.Error = status.Err.Error() diff --git a/session_runtime_snapshot_test.go b/session_runtime_snapshot_test.go index 0a07d3a..6a305e0 100644 --- a/session_runtime_snapshot_test.go +++ b/session_runtime_snapshot_test.go @@ -37,6 +37,21 @@ func TestGetClientRuntimeSnapshotDefaults(t *testing.T) { if snapshot.ConnectSource != "" || snapshot.ConnectNetwork != "" || snapshot.ConnectAddress != "" || snapshot.CanReconnect { t.Fatalf("unexpected default connect source snapshot: %+v", snapshot) } + if got, want := snapshot.AuthMode, "none"; got != want { + t.Fatalf("AuthMode mismatch: got %q want %q", got, want) + } + if got, want := snapshot.ProtectionMode, "managed"; got != want { + t.Fatalf("ProtectionMode mismatch: got %q want %q", got, want) + } + if snapshot.PeerAttachAuthenticated || snapshot.PeerAttachAuthFallback { + t.Fatalf("unexpected default peer attach state: %+v", snapshot) + } + if !snapshot.LastPeerAttachAt.IsZero() { + t.Fatalf("LastPeerAttachAt mismatch: got %v want zero", snapshot.LastPeerAttachAt) + } + if snapshot.PeerAttachRequireExplicitAuth || snapshot.PeerAttachRequireChannelBinding || snapshot.PeerAttachChannelBindingConfigured { + t.Fatalf("unexpected default peer attach policy: %+v", snapshot) + } if got, want := snapshot.BulkNetworkProfile, "default"; got != want { t.Fatalf("BulkNetworkProfile mismatch: got %q want %q", got, want) } @@ -117,6 +132,24 @@ func TestGetServerRuntimeSnapshotDefaults(t *testing.T) { if !snapshot.HasRuntimeStopCtx { t.Fatalf("HasRuntimeStopCtx mismatch: got %v want true", snapshot.HasRuntimeStopCtx) } + if got, want := snapshot.AuthMode, "none"; got != want { + t.Fatalf("AuthMode mismatch: got %q want %q", got, want) + } + if got, want := snapshot.ProtectionMode, "managed"; got != want { + t.Fatalf("ProtectionMode mismatch: got %q want %q", got, want) + } + if snapshot.PeerAttachRequireExplicitAuth || snapshot.PeerAttachRequireChannelBinding || snapshot.PeerAttachChannelBindingConfigured { + t.Fatalf("unexpected default peer attach policy: %+v", snapshot) + } + if got, want := snapshot.PeerAttachReplayWindow, peerAttachReplayTTL; got != want { + t.Fatalf("PeerAttachReplayWindow mismatch: got %s want %s", got, want) + } + if got, want := snapshot.PeerAttachReplayCapacity, defaultPeerAttachReplayCapacity; got != want { + t.Fatalf("PeerAttachReplayCapacity mismatch: got %d want %d", got, want) + } + if snapshot.PeerAttachExplicitAuth != 0 || snapshot.PeerAttachAuthFallbacks != 0 || snapshot.PeerAttachAuthRejects != 0 || snapshot.PeerAttachDowngradeRejects != 0 || snapshot.PeerAttachBindingRejects != 0 || snapshot.PeerAttachReplayRejects != 0 || snapshot.PeerAttachReplayOverflowRejects != 0 { + t.Fatalf("unexpected default peer attach counters: %+v", snapshot) + } if got, want := snapshot.BulkChunkSize, defaultBulkChunkSize; got != want { t.Fatalf("BulkChunkSize mismatch: got %d want %d", got, want) } @@ -476,6 +509,134 @@ func TestGetClientConnRuntimeSnapshotExposesDetachState(t *testing.T) { if snapshot.LastHeartbeatAt.IsZero() { t.Fatal("LastHeartbeatAt should be recorded") } + if got, want := snapshot.AuthMode, "none"; got != want { + t.Fatalf("AuthMode mismatch: got %q want %q", got, want) + } + if got, want := snapshot.ProtectionMode, "managed"; got != want { + t.Fatalf("ProtectionMode mismatch: got %q want %q", got, want) + } + if snapshot.PeerAttachAuthenticated || snapshot.PeerAttachAuthFallback { + t.Fatalf("unexpected peer attach state: %+v", snapshot) + } + if !snapshot.LastPeerAttachAt.IsZero() { + t.Fatalf("LastPeerAttachAt mismatch: got %v want zero", snapshot.LastPeerAttachAt) + } +} + +func TestGetRuntimeSnapshotsIncludePeerAttachSecurityState(t *testing.T) { + secret := []byte("correct horse battery staple") + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + if err := UsePSKOverExternalTransportServer(server, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err) + } + }) + client := NewClient().(*ClientCommon) + if err := UsePSKOverExternalTransportClient(client, secret, testModernPSKOptions()); err != nil { + t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err) + } + + left, right := net.Pipe() + defer right.Close() + bootstrapPeerAttachLogicalForTest(t, server, right) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("ConnectByConn failed: %v", err) + } + defer func() { + client.setByeFromServer(true) + _ = client.Stop() + }() + deadline := time.Now().Add(time.Second) + for { + logical := server.GetLogicalConn(client.peerIdentity) + if logical != nil { + authenticated, fallback, _ := logical.peerAttachAuthenticatedSnapshot() + if authenticated && !fallback && logical.protectionModeSnapshot() == ProtectionExternal && server.peerAttachExplicitCount.Load() == 1 { + break + } + } + if time.Now().After(deadline) { + t.Fatal("peer attach security state did not converge before snapshot") + } + time.Sleep(time.Millisecond) + } + + clientSnapshot, err := GetClientRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientRuntimeSnapshot failed: %v", err) + } + if got, want := clientSnapshot.AuthMode, "psk"; got != want { + t.Fatalf("client AuthMode mismatch: got %q want %q", got, want) + } + if got, want := clientSnapshot.ProtectionMode, "external"; got != want { + t.Fatalf("client ProtectionMode mismatch: got %q want %q", got, want) + } + if clientSnapshot.PeerAttachRequireExplicitAuth || clientSnapshot.PeerAttachRequireChannelBinding || clientSnapshot.PeerAttachChannelBindingConfigured { + t.Fatalf("unexpected client peer attach policy snapshot: %+v", clientSnapshot) + } + if !clientSnapshot.PeerAttachAuthenticated || clientSnapshot.PeerAttachAuthFallback { + t.Fatalf("unexpected client peer attach state: %+v", clientSnapshot) + } + if clientSnapshot.LastPeerAttachAt.IsZero() { + t.Fatal("client LastPeerAttachAt should be recorded") + } + + serverSnapshot, err := GetServerRuntimeSnapshot(server) + if err != nil { + t.Fatalf("GetServerRuntimeSnapshot failed: %v", err) + } + if got, want := serverSnapshot.AuthMode, "psk"; got != want { + t.Fatalf("server AuthMode mismatch: got %q want %q", got, want) + } + if got, want := serverSnapshot.ProtectionMode, "external"; got != want { + t.Fatalf("server ProtectionMode mismatch: got %q want %q", got, want) + } + if serverSnapshot.PeerAttachRequireExplicitAuth || serverSnapshot.PeerAttachRequireChannelBinding || serverSnapshot.PeerAttachChannelBindingConfigured { + t.Fatalf("unexpected server peer attach policy snapshot: %+v", serverSnapshot) + } + if got, want := serverSnapshot.PeerAttachExplicitAuth, int64(1); got != want { + t.Fatalf("PeerAttachExplicitAuth mismatch: got %d want %d", got, want) + } + if serverSnapshot.PeerAttachAuthFallbacks != 0 || serverSnapshot.PeerAttachAuthRejects != 0 || serverSnapshot.PeerAttachDowngradeRejects != 0 || serverSnapshot.PeerAttachBindingRejects != 0 || serverSnapshot.PeerAttachReplayRejects != 0 || serverSnapshot.PeerAttachReplayOverflowRejects != 0 { + t.Fatalf("unexpected server peer attach counters: %+v", serverSnapshot) + } + + logical := server.GetLogicalConn(client.peerIdentity) + if logical == nil { + t.Fatal("server logical should exist after peer attach") + } + logicalSnapshot, err := GetLogicalConnRuntimeSnapshot(logical) + if err != nil { + t.Fatalf("GetLogicalConnRuntimeSnapshot failed: %v", err) + } + if got, want := logicalSnapshot.AuthMode, "psk"; got != want { + t.Fatalf("logical AuthMode mismatch: got %q want %q", got, want) + } + if got, want := logicalSnapshot.ProtectionMode, "external"; got != want { + t.Fatalf("logical ProtectionMode mismatch: got %q want %q", got, want) + } + if !logicalSnapshot.PeerAttachAuthenticated || logicalSnapshot.PeerAttachAuthFallback { + t.Fatalf("unexpected logical peer attach state: %+v", logicalSnapshot) + } + if logicalSnapshot.LastPeerAttachAt.IsZero() { + t.Fatal("logical LastPeerAttachAt should be recorded") + } + + clientConnSnapshot, err := GetClientConnRuntimeSnapshot(clientConnFromLogical(logical)) + if err != nil { + t.Fatalf("GetClientConnRuntimeSnapshot failed: %v", err) + } + if got, want := clientConnSnapshot.AuthMode, "psk"; got != want { + t.Fatalf("client conn AuthMode mismatch: got %q want %q", got, want) + } + if got, want := clientConnSnapshot.ProtectionMode, "external"; got != want { + t.Fatalf("client conn ProtectionMode mismatch: got %q want %q", got, want) + } + if !clientConnSnapshot.PeerAttachAuthenticated || clientConnSnapshot.PeerAttachAuthFallback { + t.Fatalf("unexpected client conn peer attach state: %+v", clientConnSnapshot) + } + if clientConnSnapshot.LastPeerAttachAt.IsZero() { + t.Fatal("client conn LastPeerAttachAt should be recorded") + } } func TestGetServerDetachedClientRuntimeSnapshotsFiltersAndSorts(t *testing.T) { diff --git a/stream.go b/stream.go index a607e03..ab05aad 100644 --- a/stream.go +++ b/stream.go @@ -89,6 +89,60 @@ type streamCloseSender func(context.Context, *streamHandle, bool) error type streamResetSender func(context.Context, *streamHandle, string) error type streamDataSender func(context.Context, *streamHandle, []byte) error +type streamReadChunk struct { + data []byte + release func() +} + +func (c *streamReadChunk) clear() { + if c == nil { + return + } + if c.release != nil { + c.release() + } + c.data = nil + c.release = nil +} + +type streamReadPayloadOwner struct { + refs atomic.Int32 + release func() +} + +func newStreamReadPayloadOwner(release func()) *streamReadPayloadOwner { + if release == nil { + return nil + } + owner := &streamReadPayloadOwner{release: release} + owner.refs.Store(1) + return owner +} + +func (o *streamReadPayloadOwner) retainChunk() func() { + if o == nil { + return nil + } + o.refs.Add(1) + return o.releaseChunk +} + +func (o *streamReadPayloadOwner) releaseChunk() { + if o == nil { + return + } + if o.refs.Add(-1) == 0 && o.release != nil { + o.release() + } +} + +func (o *streamReadPayloadOwner) done() { + if o == nil { + return + } + o.releaseChunk() +} + type streamHandle struct { runtime *streamRuntime runtimeScope string @@ -124,8 +178,8 @@ type streamHandle struct { remoteClosed bool peerReadClosed bool resetErr error - readQueue [][]byte - readBuf []byte + readQueue []streamReadChunk + readBuf streamReadChunk bufferedBytes int readNotify chan struct{} readDeadline time.Time @@ -339,20 +393,23 @@ func (s *streamHandle) Read(p []byte) (int, error) { for { s.mu.Lock() localReadClosed := s.localReadClosed - if len(s.readBuf) > 0 { - n := copy(p, s.readBuf) - s.readBuf = s.readBuf[n:] + if len(s.readBuf.data) > 0 { + n := copy(p, s.readBuf.data) + s.readBuf.data = s.readBuf.data[n:] s.bufferedBytes -= n if s.bufferedBytes < 0 { s.bufferedBytes = 0 } + if len(s.readBuf.data) == 0 { + s.readBuf.clear() + } s.recordReadLocked(n, time.Now()) s.mu.Unlock() return n, nil } if len(s.readQueue) > 0 { s.readBuf = s.readQueue[0] - s.readQueue[0] = nil + s.readQueue[0] = streamReadChunk{} s.readQueue = s.readQueue[1:] s.mu.Unlock() continue @@ -824,43 +881,61 @@ func (s *streamHandle) pushOwnedChunk(chunk []byte) error { return s.pushChunkWithOwnership(chunk, true) } +func (s *streamHandle) pushOwnedChunkWithRelease(chunk []byte, release func()) error { + return s.pushChunkWithOwnershipAndRelease(chunk, true, release) +} + func (s *streamHandle) pushChunkWithOwnership(chunk []byte, owned bool) error { + return s.pushChunkWithOwnershipAndRelease(chunk, owned, nil) +} + +func (s *streamHandle) pushChunkWithOwnershipAndRelease(chunk []byte, owned bool, release func()) error { if s == nil { return io.ErrClosedPipe } if len(chunk) == 0 { + if release != nil { + release() + } return nil } - stored := chunk + stored := streamReadChunk{data: chunk, release: release} if !owned { - stored = append([]byte(nil), chunk...) + stored.data = append([]byte(nil), chunk...) + if stored.release != nil { + stored.release() + stored.release = nil + } } s.mu.Lock() if s.resetErr != nil { err := s.resetErr s.mu.Unlock() + stored.clear() return err } if s.inboundQueueLimit > 0 && s.bufferedChunkCountLocked() >= s.inboundQueueLimit { err := s.markResetLocked(errStreamBackpressureExceeded) s.mu.Unlock() + stored.clear() s.notifyReadable() s.finalize() return err } - if s.inboundBytesLimit > 0 && s.bufferedBytes+len(stored) > s.inboundBytesLimit { + if s.inboundBytesLimit > 0 && s.bufferedBytes+len(stored.data) > s.inboundBytesLimit { err := s.markResetLocked(errStreamBackpressureExceeded) s.mu.Unlock() + stored.clear() s.notifyReadable() s.finalize() return err } - if len(s.readBuf) == 0 && len(s.readQueue) == 0 { + if len(s.readBuf.data) == 0 && len(s.readQueue) == 0 { s.readBuf = stored } else { s.readQueue = append(s.readQueue, stored) } - s.bufferedBytes += len(stored) + s.bufferedBytes += len(stored.data) s.notifyReadableLocked() s.mu.Unlock() return nil @@ -881,11 +956,12 @@ func (s *streamHandle) clearBufferedDataLocked() { if s == nil { return } + s.readBuf.clear() for i := range s.readQueue { - s.readQueue[i] = nil + s.readQueue[i].clear() } s.readQueue = nil - s.readBuf = nil + s.readBuf = streamReadChunk{} s.bufferedBytes = 0 } @@ -894,7 +970,7 @@ func (s *streamHandle) bufferedChunkCountLocked() int { return 0 } count := len(s.readQueue) - if len(s.readBuf) > 0 { + if len(s.readBuf.data) > 0 { count++ } return count diff --git a/stream_benchmark_test.go b/stream_benchmark_test.go index fd7118a..2762c98 100644 --- a/stream_benchmark_test.go +++ b/stream_benchmark_test.go @@ -56,7 +56,59 @@ func BenchmarkStreamTCPThroughput(b *testing.B) { for _, tc := range cases { b.Run(tc.name, func(b *testing.B) { - benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg) + benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg, benchmarkTransportSecurityModernPSK) + }) + } +} + +func BenchmarkStreamTCPThroughputTrustedRaw(b *testing.B) { + cases := []struct { + name string + payloadSize int + cfg StreamConfig + }{ + { + name: "default_64KiB", + payloadSize: 64 * 1024, + }, + { + name: "tuned_256KiB", + payloadSize: 256 * 1024, + cfg: StreamConfig{ + ChunkSize: 256 * 1024, + InboundQueueLimit: 256, + InboundBufferedBytesLimit: 32 * 1024 * 1024, + OutboundWindowBytes: 8 * 1024 * 1024, + OutboundMaxInFlightChunks: 32, + }, + }, + { + name: "tuned_512KiB", + payloadSize: 512 * 1024, + cfg: StreamConfig{ + ChunkSize: 512 * 1024, + InboundQueueLimit: 256, + InboundBufferedBytesLimit: 64 * 1024 * 1024, + OutboundWindowBytes: 16 * 1024 * 1024, + OutboundMaxInFlightChunks: 32, + }, + }, + { + name: "tuned_1MiB", + payloadSize: 1024 * 1024, + cfg: StreamConfig{ + ChunkSize: 1024 * 1024, + InboundQueueLimit: 256, + InboundBufferedBytesLimit: 64 * 1024 * 1024, + OutboundWindowBytes: 16 * 1024 * 1024, + OutboundMaxInFlightChunks: 32, + }, + }, + } + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg, benchmarkTransportSecurityTrustedRaw) }) } } @@ -108,19 +160,69 @@ func BenchmarkStreamTCPThroughputConcurrent(b *testing.B) { for _, tc := range cases { b.Run(tc.name, func(b *testing.B) { - benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg) + benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg, benchmarkTransportSecurityModernPSK) }) } } -func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfig) { +func BenchmarkStreamTCPThroughputConcurrentTrustedRaw(b *testing.B) { + cases := []struct { + name string + payloadSize int + concurrency int + cfg StreamConfig + }{ + { + name: "streams_2_512KiB", + payloadSize: 512 * 1024, + concurrency: 2, + cfg: StreamConfig{ + ChunkSize: 512 * 1024, + InboundQueueLimit: 512, + InboundBufferedBytesLimit: 128 * 1024 * 1024, + OutboundWindowBytes: 32 * 1024 * 1024, + OutboundMaxInFlightChunks: 64, + }, + }, + { + name: "streams_4_512KiB", + payloadSize: 512 * 1024, + concurrency: 4, + cfg: StreamConfig{ + ChunkSize: 512 * 1024, + InboundQueueLimit: 1024, + InboundBufferedBytesLimit: 256 * 1024 * 1024, + OutboundWindowBytes: 64 * 1024 * 1024, + OutboundMaxInFlightChunks: 128, + }, + }, + { + name: "streams_8_512KiB", + payloadSize: 512 * 1024, + concurrency: 8, + cfg: StreamConfig{ + ChunkSize: 512 * 1024, + InboundQueueLimit: 2048, + InboundBufferedBytesLimit: 512 * 1024 * 1024, + OutboundWindowBytes: 128 * 1024 * 1024, + OutboundMaxInFlightChunks: 256, + }, + }, + } + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg, benchmarkTransportSecurityTrustedRaw) + }) + } +} + +func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfig, securityMode benchmarkTransportSecurityMode) { b.Helper() server := NewServer().(*ServerCommon) server.SetStreamConfig(cfg) - if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { - b.Fatalf("UseModernPSKServer failed: %v", err) - } + benchmarkApplyServerTransportSecurity(b, server, securityMode) acceptCh := make(chan StreamAcceptInfo, 1) server.SetStreamHandler(func(info StreamAcceptInfo) error { @@ -137,9 +239,7 @@ func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfi client := NewClient().(*ClientCommon) client.SetStreamConfig(cfg) - if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { - b.Fatalf("UseModernPSKClient failed: %v", err) - } + benchmarkApplyClientTransportSecurity(b, client, securityMode) if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil { b.Fatalf("client Connect failed: %v", err) } @@ -198,7 +298,7 @@ func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfi _ = stream.Close() } -func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, cfg StreamConfig) { +func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, cfg StreamConfig, securityMode benchmarkTransportSecurityMode) { b.Helper() if concurrency <= 0 { b.Fatal("concurrency must be > 0") @@ -206,9 +306,7 @@ func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concu server := NewServer().(*ServerCommon) server.SetStreamConfig(cfg) - if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { - b.Fatalf("UseModernPSKServer failed: %v", err) - } + benchmarkApplyServerTransportSecurity(b, server, securityMode) acceptCh := make(chan StreamAcceptInfo, concurrency*2) server.SetStreamHandler(func(info StreamAcceptInfo) error { @@ -225,9 +323,7 @@ func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concu client := NewClient().(*ClientCommon) client.SetStreamConfig(cfg) - if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { - b.Fatalf("UseModernPSKClient failed: %v", err) - } + benchmarkApplyClientTransportSecurity(b, client, securityMode) if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil { b.Fatalf("client Connect failed: %v", err) } diff --git a/stream_buffer_release_test.go b/stream_buffer_release_test.go new file mode 100644 index 0000000..18a8005 --- /dev/null +++ b/stream_buffer_release_test.go @@ -0,0 +1,85 @@ +package notify + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestStreamOwnedChunkReleaseAfterRead(t *testing.T) { + stream := newStreamHandle(context.Background(), newStreamRuntime("stream-buffer-release-read"), clientFileScope(), StreamOpenRequest{ + StreamID: "stream-buffer-release-read", + DataID: 1, + }, 0, nil, nil, 0, nil, nil, nil, streamConfig{}) + + released := 0 + if err := stream.pushOwnedChunkWithRelease([]byte("hello"), func() { + released++ + }); err != nil { + t.Fatalf("pushOwnedChunkWithRelease failed: %v", err) + } + + buf := make([]byte, 5) + n, err := stream.Read(buf) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + if n != 5 || string(buf[:n]) != "hello" { + t.Fatalf("Read = %d %q, want 5 hello", n, string(buf[:n])) + } + if released != 1 { + t.Fatalf("release count = %d, want 1", released) + } +} + +func TestStreamOwnedChunkReleaseOnReset(t *testing.T) { + stream := newStreamHandle(context.Background(), newStreamRuntime("stream-buffer-release-reset"), clientFileScope(), StreamOpenRequest{ + StreamID: "stream-buffer-release-reset", + DataID: 1, + }, 0, nil, nil, 0, nil, nil, nil, streamConfig{}) + + released := 0 + if err := stream.pushOwnedChunkWithRelease([]byte("hello"), func() { + released++ + }); err != nil { + t.Fatalf("pushOwnedChunkWithRelease failed: %v", err) + } + + stream.markReset(errors.New("boom")) + if released != 1 { + t.Fatalf("release count = %d, want 1", released) + } +} + +func TestClientDispatchFastStreamDataWithOwnerReleasesAfterRead(t *testing.T) { + client := NewClient().(*ClientCommon) + runtime := client.getStreamRuntime() + if runtime == nil { + t.Fatal("client stream runtime should not be nil") + } + stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{ + StreamID: "stream-owner", + DataID: 23, + Channel: StreamDataChannel, + }, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot()) + if err := runtime.register(clientFileScope(), stream); err != nil { + t.Fatalf("register stream failed: %v", err) + } + + released := 0 + owner := newStreamReadPayloadOwner(func() { + released++ + }) + client.dispatchFastStreamDataWithOwner(streamFastDataFrame{ + DataID: 23, + Seq: 1, + Payload: []byte("fast-owner"), + }, owner) + owner.done() + + readStreamExactly(t, stream, "fast-owner", 2*time.Second) + if released != 1 { + t.Fatalf("release count = %d, want 1", released) + } +} diff --git a/stream_dispatcher.go b/stream_dispatcher.go index b55eea4..22f9775 100644 --- a/stream_dispatcher.go +++ b/stream_dispatcher.go @@ -83,6 +83,10 @@ func (s *ServerCommon) dispatchStreamEnvelope(logical *LogicalConn, transport *T } func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) { + c.dispatchFastStreamDataWithOwner(frame, nil) +} + +func (c *ClientCommon) dispatchFastStreamDataWithOwner(frame streamFastDataFrame, owner *streamReadPayloadOwner) { if frame.DataID == 0 { return } @@ -107,7 +111,13 @@ func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) { c.bestEffortRejectInboundStreamData(stream.ID(), frame.DataID, detachErr.Error()) return } - if err := stream.pushOwnedChunk(frame.Payload); err != nil { + var err error + if owner != nil { + err = stream.pushOwnedChunkWithRelease(frame.Payload, owner.retainChunk()) + } else { + err = stream.pushOwnedChunk(frame.Payload) + } + if err != nil { if c.showError || c.debugMode { fmt.Println("client stream push chunk error", err) } @@ -118,6 +128,10 @@ func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) { } func (s *ServerCommon) dispatchFastStreamData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame streamFastDataFrame) { + s.dispatchFastStreamDataWithOwner(logical, transport, conn, frame, nil) +} + +func (s *ServerCommon) dispatchFastStreamDataWithOwner(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame streamFastDataFrame, owner *streamReadPayloadOwner) { if logical == nil || frame.DataID == 0 { return } @@ -141,7 +155,13 @@ func (s *ServerCommon) dispatchFastStreamData(logical *LogicalConn, transport *T s.bestEffortRejectInboundStreamData(logical, transport, conn, stream.ID(), frame.DataID, detachErr.Error()) return } - if err := stream.pushOwnedChunk(frame.Payload); err != nil { + var err error + if owner != nil { + err = stream.pushOwnedChunkWithRelease(frame.Payload, owner.retainChunk()) + } else { + err = stream.pushOwnedChunk(frame.Payload) + } + if err != nil { if s.showError || s.debugMode { fmt.Println("server stream push chunk error", err) } diff --git a/stream_fastpath.go b/stream_fastpath.go index c234d1c..5eb6314 100644 --- a/stream_fastpath.go +++ b/stream_fastpath.go @@ -156,11 +156,12 @@ func decodeStreamFastDataFrame(payload []byte) (streamFastDataFrame, bool, error } func (c *ClientCommon) encodeFastStreamPayload(frame streamFastDataFrame) ([]byte, error) { - if c != nil && c.fastStreamEncode != nil && frame.Flags == 0 { - return c.fastStreamEncode(c.SecretKey, frame.DataID, frame.Seq, frame.Payload) + profile := c.clientTransportProtectionSnapshot() + if c != nil && profile.fastStreamEncode != nil && frame.Flags == 0 { + return profile.fastStreamEncode(profile.secretKey, frame.DataID, frame.Seq, frame.Payload) } - if c != nil && c.fastPlainEncode != nil { - return encodeStreamFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame) + if c != nil && profile.fastPlainEncode != nil { + return encodeStreamFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame) } plain, err := encodeStreamFastFramePayload(frame) if err != nil { @@ -181,8 +182,9 @@ func (c *ClientCommon) encodeFastStreamBatchPayload(frames []streamFastDataFrame if c == nil { return nil, errStreamClientNil } - if c.fastPlainEncode != nil { - return encodeStreamFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames) + profile := c.clientTransportProtectionSnapshot() + if profile.fastPlainEncode != nil { + return encodeStreamFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames) } plain, err := encodeStreamFastBatchPlain(frames) if err != nil { diff --git a/stream_shared_batch.go b/stream_shared_batch.go index 5b0fcac..29aab79 100644 --- a/stream_shared_batch.go +++ b/stream_shared_batch.go @@ -91,25 +91,24 @@ func writeStreamFastBatchPlain(dst []byte, frames []streamFastDataFrame) error { return nil } -func decodeStreamFastBatchPlain(payload []byte) ([]streamFastDataFrame, bool, error) { +func walkStreamFastBatchPlain(payload []byte, fn func(streamFastDataFrame) error) (bool, error) { if len(payload) < 4 || string(payload[:4]) != streamFastBatchMagic { - return nil, false, nil + return false, nil } if len(payload) < streamFastBatchHeaderLen { - return nil, true, errStreamFastPayloadInvalid + return true, errStreamFastPayloadInvalid } if payload[4] != streamFastBatchVersion { - return nil, true, errStreamFastPayloadInvalid + return true, errStreamFastPayloadInvalid } count := int(binary.BigEndian.Uint32(payload[8:12])) if count <= 0 { - return nil, true, errStreamFastPayloadInvalid + return true, errStreamFastPayloadInvalid } - frames := make([]streamFastDataFrame, 0, count) offset := streamFastBatchHeaderLen for index := 0; index < count; index++ { if len(payload)-offset < streamFastBatchItemHeaderLen { - return nil, true, errStreamFastPayloadInvalid + return true, errStreamFastPayloadInvalid } flags := payload[offset] dataID := binary.BigEndian.Uint64(payload[offset+4 : offset+12]) @@ -117,29 +116,62 @@ func decodeStreamFastBatchPlain(payload []byte) ([]streamFastDataFrame, bool, er payloadLen := int(binary.BigEndian.Uint32(payload[offset+20 : offset+24])) offset += streamFastBatchItemHeaderLen if dataID == 0 || payloadLen < 0 || len(payload)-offset < payloadLen { - return nil, true, errStreamFastPayloadInvalid + return true, errStreamFastPayloadInvalid + } + if fn != nil { + if err := fn(streamFastDataFrame{ + Flags: flags, + DataID: dataID, + Seq: seq, + Payload: payload[offset : offset+payloadLen], + }); err != nil { + return true, err + } } - frames = append(frames, streamFastDataFrame{ - Flags: flags, - DataID: dataID, - Seq: seq, - Payload: payload[offset : offset+payloadLen], - }) offset += payloadLen } if offset != len(payload) { - return nil, true, errStreamFastPayloadInvalid + return true, errStreamFastPayloadInvalid + } + return true, nil +} + +func decodeStreamFastBatchPlain(payload []byte) ([]streamFastDataFrame, bool, error) { + frames := make([]streamFastDataFrame, 0, 1) + matched, err := walkStreamFastBatchPlain(payload, func(frame streamFastDataFrame) error { + frames = append(frames, frame) + return nil + }) + if !matched || err != nil { + return nil, matched, err } return frames, true, nil } -func decodeStreamFastDataFrames(payload []byte) ([]streamFastDataFrame, bool, error) { - if frames, matched, err := decodeStreamFastBatchPlain(payload); matched { - return frames, true, err +func walkStreamFastFrames(payload []byte, fn func(streamFastDataFrame) error) (bool, error) { + if matched, err := walkStreamFastBatchPlain(payload, fn); matched { + return true, err } frame, matched, err := decodeStreamFastDataFrame(payload) + if !matched || err != nil { + return matched, err + } + if fn != nil { + if err := fn(frame); err != nil { + return true, err + } + } + return true, nil +} + +func decodeStreamFastDataFrames(payload []byte) ([]streamFastDataFrame, bool, error) { + frames := make([]streamFastDataFrame, 0, 1) + matched, err := walkStreamFastFrames(payload, func(frame streamFastDataFrame) error { + frames = append(frames, frame) + return nil + }) if !matched || err != nil { return nil, matched, err } - return []streamFastDataFrame{frame}, true, nil + return frames, true, nil } diff --git a/transport_codec.go b/transport_codec.go index e810a08..b0ad87f 100644 --- a/transport_codec.go +++ b/transport_codec.go @@ -9,7 +9,10 @@ var ( errTransportPayloadDecryptFailed = errors.New("transport payload decrypt failed") ) -func encryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgEn func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) { +func encryptTransportPayloadCodec(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgEn func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) { + if mode == ProtectionExternal { + return data, nil + } if runtime != nil { encoded, err := runtime.sealPlainPayload(data) if err != nil { @@ -27,7 +30,10 @@ func encryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgEn func([]b return encoded, nil } -func decryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) { +func decryptTransportPayloadCodec(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) { + if mode == ProtectionExternal { + return data, nil + } if runtime != nil { plain, err := runtime.openPayload(data) if err != nil { @@ -45,7 +51,10 @@ func decryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgDe func([]b return plain, nil } -func decryptTransportPayloadCodecPooled(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte, release func()) ([]byte, func(), error) { +func decryptTransportPayloadCodecPooled(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte, release func()) ([]byte, func(), error) { + if mode == ProtectionExternal { + return data, release, nil + } if runtime != nil { plain, plainRelease, err := runtime.openPayloadPooled(data, release) if err != nil { @@ -69,7 +78,10 @@ func decryptTransportPayloadCodecPooled(runtime *modernPSKCodecRuntime, msgDe fu return plain, nil, nil } -func decryptTransportPayloadCodecOwnedPooled(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, func(), error) { +func decryptTransportPayloadCodecOwnedPooled(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, func(), error) { + if mode == ProtectionExternal { + return data, nil, nil + } if runtime != nil { plain, plainRelease, err := runtime.openPayloadOwnedPooled(data) if err != nil { @@ -124,9 +136,14 @@ func (s *ServerCommon) encodeTransferMsg(c *ClientConn, msg TransferMsg) ([]byte if err != nil { return nil, err } + logical := logicalConnFromClient(c) msgEn := c.clientConnMsgEnSnapshot() secretKey := c.clientConnSecretKeySnapshot() - data, err = encryptTransportPayloadCodec(c.clientConnModernPSKRuntimeSnapshot(), msgEn, secretKey, data) + mode := ProtectionManaged + if logical != nil { + mode = logical.protectionModeSnapshot() + } + data, err = encryptTransportPayloadCodec(mode, c.clientConnModernPSKRuntimeSnapshot(), msgEn, secretKey, data) if err != nil { return nil, err } @@ -140,7 +157,11 @@ func (s *ServerCommon) encodeTransferMsg(c *ClientConn, msg TransferMsg) ([]byte func (s *ServerCommon) decodeTransferMsg(c *ClientConn, data []byte) (TransferMsg, error) { msgDe := c.clientConnMsgDeSnapshot() secretKey := c.clientConnSecretKeySnapshot() - plain, err := decryptTransportPayloadCodec(c.clientConnModernPSKRuntimeSnapshot(), msgDe, secretKey, data) + mode := ProtectionManaged + if logical := logicalConnFromClient(c); logical != nil { + mode = logical.protectionModeSnapshot() + } + plain, err := decryptTransportPayloadCodec(mode, c.clientConnModernPSKRuntimeSnapshot(), msgDe, secretKey, data) if err != nil { return TransferMsg{}, err } @@ -172,7 +193,8 @@ func (c *ClientCommon) encodeEnvelopePlain(env Envelope) ([]byte, error) { } func (c *ClientCommon) encryptTransportPayload(data []byte) ([]byte, error) { - return encryptTransportPayloadCodec(c.modernPSKRuntime, c.msgEn, c.SecretKey, data) + profile := c.clientTransportProtectionSnapshot() + return encryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgEn, profile.secretKey, data) } func (c *ClientCommon) encodeEnvelope(env Envelope) ([]byte, error) { @@ -196,7 +218,8 @@ func (c *ClientCommon) decodeEnvelope(data []byte) (Envelope, error) { } func (c *ClientCommon) decryptTransportPayload(data []byte) ([]byte, error) { - return decryptTransportPayloadCodec(c.modernPSKRuntime, c.msgDe, c.SecretKey, data) + profile := c.clientTransportProtectionSnapshot() + return decryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, data) } func (c *ClientCommon) decodeEnvelopePlain(data []byte) (Envelope, error) { @@ -251,7 +274,7 @@ func (s *ServerCommon) encryptTransportPayloadLogical(logical *LogicalConn, data if msgEn == nil { return nil, errTransportDetached } - return encryptTransportPayloadCodec(logical.modernPSKRuntimeSnapshot(), msgEn, secretKey, data) + return encryptTransportPayloadCodec(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), msgEn, secretKey, data) } func (s *ServerCommon) encodeEnvelopeLogical(logical *LogicalConn, env Envelope) ([]byte, error) { @@ -290,7 +313,7 @@ func (s *ServerCommon) decryptTransportPayloadLogical(logical *LogicalConn, data if msgDe == nil { return nil, errTransportDetached } - return decryptTransportPayloadCodec(logical.modernPSKRuntimeSnapshot(), msgDe, secretKey, data) + return decryptTransportPayloadCodec(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), msgDe, secretKey, data) } func (s *ServerCommon) decodeEnvelopePlain(data []byte) (Envelope, error) { diff --git a/transport_write_test.go b/transport_write_test.go index d541988..b20c5ce 100644 --- a/transport_write_test.go +++ b/transport_write_test.go @@ -143,6 +143,43 @@ func TestStreamBatchSenderRespectsBindingWriteDeadlineWhenReceiverStalls(t *test } } +func TestBulkBatchSenderFlushAggregatesAdaptivePayloadObservation(t *testing.T) { + conn := &delayedWriteConn{delay: 20 * time.Millisecond} + binding := newTransportBinding(conn, stario.NewQueue()) + sender := newTestBulkBatchSender(binding) + payloadA := bytes.Repeat([]byte("a"), 128*1024) + payloadB := bytes.Repeat([]byte("b"), 128*1024) + + err := sender.flush([]bulkBatchRequest{ + { + ctx: context.Background(), + frames: []bulkFastFrame{{ + Type: bulkFastPayloadTypeData, + DataID: 1, + Seq: 1, + Payload: payloadA, + }}, + fastPathVersion: bulkFastPathVersionV1, + }, + { + ctx: context.Background(), + frames: []bulkFastFrame{{ + Type: bulkFastPayloadTypeData, + DataID: 2, + Seq: 1, + Payload: payloadB, + }}, + fastPathVersion: bulkFastPathVersionV1, + }, + }) + if err != nil { + t.Fatalf("flush failed: %v", err) + } + if got, want := binding.bulkAdaptiveSoftPayloadBytesSnapshot(), bulkAdaptiveSoftPayloadMinBytes; got != want { + t.Fatalf("adaptive bulk soft payload = %d, want %d", got, want) + } +} + func TestControlBatchSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) { left, right := net.Pipe() defer left.Close() @@ -255,6 +292,10 @@ func (c *vectoredShortWriteConn) WriteBuffers(bufs *net.Buffers) (int64, error) return written, nil } +func (c *vectoredShortWriteConn) writeBuffers(bufs *net.Buffers) (int64, error) { + return c.WriteBuffers(bufs) +} + type unwrapVectoredConn struct { inner net.Conn }