feat(transport): 完成安全架构拆分并收口 stream/bulk 传输优化
- 新增 managed/external/nested 三种传输保护模式 - 新增 peer attach 显式认证、抗重放、channel binding 和可选前向保密协商 - 明确单连接注入与可重拨连接源的语义边界 - 禁止 ConnectByConn 场景下 dedicated bulk 走 sidecar,auto 模式自动回退 shared - 修正 dedicated attach 在 bootstrap/steady profile 切换下的处理逻辑 - 优化 shared bulk super-batch 与批量 framed write 路径 - 降低 stream/bulk fast path 的复制和分发损耗 - 补齐 benchmark、回归测试、运行时快照和 README 文档
This commit is contained in:
parent
f038a89771
commit
98ef9e7fcc
54
README.md
54
README.md
@ -25,6 +25,21 @@
|
||||
|
||||
未配置时会返回 `errModernPSKRequired`。
|
||||
|
||||
## 安全模式选择
|
||||
|
||||
- `UseModernPSKClient` / `UseModernPSKServer`
|
||||
- bootstrap 和稳态传输都由 `notify` 自己保护
|
||||
- 适合默认场景
|
||||
- 支持 peer attach 显式认证、抗重放,以及在需要时协商前向保密
|
||||
- `UsePSKOverExternalTransportClient` / `UsePSKOverExternalTransportServer`
|
||||
- bootstrap 仍用 PSK 做认证
|
||||
- 稳态阶段信任外部物理通道,不再做 `notify` 内层加密
|
||||
- 适合 `tls.Conn` 或调用方自认可信的外部通道
|
||||
- 不支持 `RequireForwardSecrecy`
|
||||
- `UseNestedSecurityClient` / `UseNestedSecurityServer`
|
||||
- 外层已有可信通道,但仍保留 `notify` 内层保护
|
||||
- 适合需要“外层可信 + 内层独立保护”的场景
|
||||
|
||||
## 快速开始
|
||||
|
||||
服务端:
|
||||
@ -83,6 +98,42 @@ func main() {
|
||||
}
|
||||
```
|
||||
|
||||
## 连接入口与物理连接语义
|
||||
|
||||
- `Connect` / `ConnectTimeout`
|
||||
- 由 `notify` 自己拨号
|
||||
- 支持重连,也支持 dedicated bulk 额外 sidecar 连接
|
||||
- `ConnectByFactory`
|
||||
- 调用方提供 `dialFn`
|
||||
- `notify` 会在需要时再次调用 `dialFn`,因此仍支持重连和 dedicated bulk
|
||||
- `ConnectByConn`
|
||||
- 调用方注入一个已经建立好的 `net.Conn`
|
||||
- 该模式被视为“单物理连接模式”
|
||||
- `OpenDedicatedBulk` 会直接返回错误
|
||||
- `OpenBulk` 使用 `auto` 模式时会自动回退到 `shared`
|
||||
- `ListenByListener`
|
||||
- 服务端复用调用方提供的 `net.Listener`
|
||||
- 适合需要和现有 listener 栈整合的场景
|
||||
|
||||
`dedicated bulk` 依赖额外物理连接,因此只适用于可再次拨号的 transport source。
|
||||
|
||||
## Peer Attach 安全策略
|
||||
|
||||
可通过 `SetPeerAttachSecurityConfig` 配置逻辑会话 attach 阶段的额外保护。
|
||||
|
||||
- `RequireExplicitAuth`
|
||||
- 要求 peer attach 使用显式认证
|
||||
- `RequireChannelBinding`
|
||||
- 要求 attach 绑定到底层可信通道
|
||||
- 启用后会隐式要求显式认证
|
||||
- `ChannelBinding`
|
||||
- 由调用方提供 channel binding 提取函数
|
||||
- 适合外层 TLS 或其他可信通道整合
|
||||
- `ReplayWindow` / `ReplayCapacity`
|
||||
- 控制 attach 抗重放窗口和缓存容量
|
||||
|
||||
如果你选择 `UsePSKOverExternalTransport*`,并且希望 attach 阶段显式绑定到外层可信信道,建议同时配置 channel binding。
|
||||
|
||||
## RecordStream 说明
|
||||
|
||||
`RecordStream` 构建在 `Stream` 之上,适合“有边界的顺序记录”场景。
|
||||
@ -146,6 +197,9 @@ func main() {
|
||||
- 共享密钥派生(Argon2id)
|
||||
- 消息层加密(AES-GCM)
|
||||
- `stream` / `bulk` fast path 复用现代编码栈
|
||||
- peer attach 显式认证 / 抗重放
|
||||
- 可选 channel binding
|
||||
- 可选前向保密(`UseModernPSK*` / `UseNestedSecurity*`)
|
||||
|
||||
兼容入口仍保留,但属于历史路径:
|
||||
|
||||
|
||||
42
benchmark_transport_security_test.go
Normal file
42
benchmark_transport_security_test.go
Normal file
@ -0,0 +1,42 @@
|
||||
package notify
|
||||
|
||||
import "testing"
|
||||
|
||||
type benchmarkTransportSecurityMode string
|
||||
|
||||
const (
|
||||
benchmarkTransportSecurityModernPSK benchmarkTransportSecurityMode = "modern_psk"
|
||||
benchmarkTransportSecurityTrustedRaw benchmarkTransportSecurityMode = "trusted_raw"
|
||||
)
|
||||
|
||||
func benchmarkApplyServerTransportSecurity(tb testing.TB, server *ServerCommon, mode benchmarkTransportSecurityMode) {
|
||||
tb.Helper()
|
||||
switch mode {
|
||||
case benchmarkTransportSecurityModernPSK:
|
||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
tb.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
case benchmarkTransportSecurityTrustedRaw:
|
||||
if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
tb.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
|
||||
}
|
||||
default:
|
||||
tb.Fatalf("unsupported benchmark transport security mode %q", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkApplyClientTransportSecurity(tb testing.TB, client *ClientCommon, mode benchmarkTransportSecurityMode) {
|
||||
tb.Helper()
|
||||
switch mode {
|
||||
case benchmarkTransportSecurityModernPSK:
|
||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
tb.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
case benchmarkTransportSecurityTrustedRaw:
|
||||
if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
tb.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
|
||||
}
|
||||
default:
|
||||
tb.Fatalf("unsupported benchmark transport security mode %q", mode)
|
||||
}
|
||||
}
|
||||
4
bulk.go
4
bulk.go
@ -161,6 +161,7 @@ var (
|
||||
errBulkRangeInvalid = errors.New("bulk range is invalid")
|
||||
errBulkBackpressureExceeded = errors.New("bulk inbound backpressure exceeded")
|
||||
errBulkDedicatedStreamOnly = errors.New("dedicated bulk requires stream transport")
|
||||
errBulkDedicatedSingleConn = errors.New("dedicated bulk requires a dialable additional connection source; ConnectByConn only supports shared transport")
|
||||
errBulkDedicatedActiveLimit = errors.New("dedicated bulk active limit reached")
|
||||
)
|
||||
|
||||
@ -174,6 +175,9 @@ func clientDedicatedBulkSupportError(c *ClientCommon) error {
|
||||
if source := c.clientConnectSourceSnapshot(); source != nil && source.isUDP() {
|
||||
return errBulkDedicatedStreamOnly
|
||||
}
|
||||
if source := c.clientConnectSourceSnapshot(); source != nil && !source.supportsAdditionalConn() {
|
||||
return errBulkDedicatedSingleConn
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -418,18 +418,18 @@ func (s *bulkBatchSender) flush(requests []bulkBatchRequest) error {
|
||||
}
|
||||
}()
|
||||
writeTimeout := s.transportWriteTimeout()
|
||||
frames := make([][]byte, 0, len(payloads))
|
||||
payloadBytes := 0
|
||||
for _, payload := range payloads {
|
||||
frame := payload.payload
|
||||
started := time.Now()
|
||||
err := s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error {
|
||||
return writeFramedPayloadUnlocked(conn, queue, frame)
|
||||
})
|
||||
s.binding.observeBulkAdaptivePayloadWrite(len(frame), time.Since(started), writeTimeout, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
frames = append(frames, payload.payload)
|
||||
payloadBytes += len(payload.payload)
|
||||
}
|
||||
return nil
|
||||
started := time.Now()
|
||||
err = s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error {
|
||||
return writeFramedPayloadBatchUnlocked(conn, queue, frames)
|
||||
})
|
||||
s.binding.observeBulkAdaptivePayloadWrite(payloadBytes, time.Since(started), writeTimeout, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *bulkBatchSender) encodeRequests(requests []bulkBatchRequest) ([]bulkBatchEncodedPayload, error) {
|
||||
|
||||
@ -38,7 +38,25 @@ func BenchmarkBulkTCPThroughput(b *testing.B) {
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkBulkTCPThroughput(b, tc.payloadSize, false)
|
||||
benchmarkBulkTCPThroughput(b, tc.payloadSize, false, benchmarkTransportSecurityModernPSK)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBulkTCPThroughputTrustedRaw(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
payloadSize int
|
||||
}{
|
||||
{name: "chunk_256KiB", payloadSize: 256 * 1024},
|
||||
{name: "chunk_512KiB", payloadSize: 512 * 1024},
|
||||
{name: "chunk_768KiB", payloadSize: 768 * 1024},
|
||||
{name: "chunk_1MiB", payloadSize: 1024 * 1024},
|
||||
{name: "chunk_2MiB", payloadSize: 2 * 1024 * 1024},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkBulkTCPThroughput(b, tc.payloadSize, false, benchmarkTransportSecurityTrustedRaw)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -72,7 +90,25 @@ func BenchmarkBulkTCPThroughputDedicated(b *testing.B) {
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkBulkTCPThroughput(b, tc.payloadSize, true)
|
||||
benchmarkBulkTCPThroughput(b, tc.payloadSize, true, benchmarkTransportSecurityModernPSK)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBulkTCPThroughputDedicatedTrustedRaw(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
payloadSize int
|
||||
}{
|
||||
{name: "chunk_256KiB", payloadSize: 256 * 1024},
|
||||
{name: "chunk_512KiB", payloadSize: 512 * 1024},
|
||||
{name: "chunk_768KiB", payloadSize: 768 * 1024},
|
||||
{name: "chunk_1MiB", payloadSize: 1024 * 1024},
|
||||
{name: "chunk_2MiB", payloadSize: 2 * 1024 * 1024},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkBulkTCPThroughput(b, tc.payloadSize, true, benchmarkTransportSecurityTrustedRaw)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -107,7 +143,25 @@ func BenchmarkBulkTCPThroughputConcurrent(b *testing.B) {
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false)
|
||||
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false, benchmarkTransportSecurityModernPSK)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBulkTCPThroughputConcurrentTrustedRaw(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
payloadSize int
|
||||
concurrency int
|
||||
}{
|
||||
{name: "bulks_2_512KiB", payloadSize: 512 * 1024, concurrency: 2},
|
||||
{name: "bulks_4_512KiB", payloadSize: 512 * 1024, concurrency: 4},
|
||||
{name: "bulks_2_1MiB", payloadSize: 1024 * 1024, concurrency: 2},
|
||||
{name: "bulks_4_1MiB", payloadSize: 1024 * 1024, concurrency: 4},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false, benchmarkTransportSecurityTrustedRaw)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -142,18 +196,34 @@ func BenchmarkBulkTCPThroughputConcurrentDedicated(b *testing.B) {
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true)
|
||||
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true, benchmarkTransportSecurityModernPSK)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
|
||||
func BenchmarkBulkTCPThroughputConcurrentDedicatedTrustedRaw(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
payloadSize int
|
||||
concurrency int
|
||||
}{
|
||||
{name: "bulks_2_512KiB", payloadSize: 512 * 1024, concurrency: 2},
|
||||
{name: "bulks_4_512KiB", payloadSize: 512 * 1024, concurrency: 4},
|
||||
{name: "bulks_2_1MiB", payloadSize: 1024 * 1024, concurrency: 2},
|
||||
{name: "bulks_4_1MiB", payloadSize: 1024 * 1024, concurrency: 4},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true, benchmarkTransportSecurityTrustedRaw)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool, securityMode benchmarkTransportSecurityMode) {
|
||||
b.Helper()
|
||||
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
b.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
benchmarkApplyServerTransportSecurity(b, server, securityMode)
|
||||
|
||||
acceptCh := make(chan BulkAcceptInfo, 1)
|
||||
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||
@ -169,9 +239,7 @@ func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
|
||||
})
|
||||
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
b.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
benchmarkApplyClientTransportSecurity(b, client, securityMode)
|
||||
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
|
||||
b.Fatalf("client Connect failed: %v", err)
|
||||
}
|
||||
@ -241,16 +309,14 @@ func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
|
||||
_ = bulk.Close()
|
||||
}
|
||||
|
||||
func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, dedicated bool) {
|
||||
func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, dedicated bool, securityMode benchmarkTransportSecurityMode) {
|
||||
b.Helper()
|
||||
if concurrency <= 0 {
|
||||
b.Fatal("concurrency must be > 0")
|
||||
}
|
||||
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
b.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
benchmarkApplyServerTransportSecurity(b, server, securityMode)
|
||||
|
||||
acceptCh := make(chan BulkAcceptInfo, concurrency*2)
|
||||
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||
@ -266,9 +332,7 @@ func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurr
|
||||
})
|
||||
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
b.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
benchmarkApplyClientTransportSecurity(b, client, securityMode)
|
||||
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
|
||||
b.Fatalf("client Connect failed: %v", err)
|
||||
}
|
||||
|
||||
@ -268,6 +268,9 @@ func readBulkDedicatedRecordPooled(conn net.Conn) ([]byte, func(), error) {
|
||||
func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context, timeout time.Duration) (net.Conn, error) {
|
||||
source := c.clientConnectSourceSnapshot()
|
||||
if source != nil {
|
||||
if !source.supportsAdditionalConn() {
|
||||
return nil, errBulkDedicatedSingleConn
|
||||
}
|
||||
if source.network != "" && source.addr != "" {
|
||||
if timeout > 0 {
|
||||
return transport.DialTimeout(source.network, source.addr, timeout)
|
||||
@ -277,6 +280,7 @@ func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context, timeout time.D
|
||||
if source.canReconnect() {
|
||||
return source.dial(ctx)
|
||||
}
|
||||
return nil, errClientReconnectSourceUnavailable
|
||||
}
|
||||
conn := c.clientTransportConnSnapshot()
|
||||
if conn == nil || conn.RemoteAddr() == nil {
|
||||
@ -661,7 +665,8 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn
|
||||
Value: reqPayload,
|
||||
Type: MSG_SYS_WAIT,
|
||||
}
|
||||
frame, err := encodeDirectSignalFrame(stario.NewQueue(), c.sequenceEn, c.msgEn, c.SecretKey, msg)
|
||||
attachProfile := c.clientDedicatedBulkAttachTransportProtectionProfile()
|
||||
frame, err := encodeDirectSignalFrame(stario.NewQueue(), c.sequenceEn, attachProfile.msgEn, attachProfile.secretKey, msg)
|
||||
if err != nil {
|
||||
return bulkAttachResponse{}, err
|
||||
}
|
||||
@ -675,7 +680,7 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn
|
||||
if err != nil {
|
||||
return bulkAttachResponse{}, err
|
||||
}
|
||||
transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, replyPayload)
|
||||
transfer, err := decodeDirectSignalPayload(c.sequenceDe, attachProfile.msgDe, attachProfile.secretKey, replyPayload)
|
||||
if err != nil {
|
||||
return bulkAttachResponse{}, err
|
||||
}
|
||||
@ -685,6 +690,16 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn
|
||||
return decodeBulkAttachResponse(c.sequenceDe, transfer.Value)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientDedicatedBulkAttachTransportProtectionProfile() transportProtectionProfile {
|
||||
if c == nil {
|
||||
return transportProtectionProfile{}
|
||||
}
|
||||
if c.securityConfigured && c.securityBootstrap.msgEn != nil && c.securityBootstrap.msgDe != nil {
|
||||
return c.securityBootstrap.clone()
|
||||
}
|
||||
return c.clientTransportProtectionSnapshot()
|
||||
}
|
||||
|
||||
func (c *ClientCommon) readDedicatedSidecarLoop(sidecar *bulkDedicatedSidecar) {
|
||||
if c == nil || sidecar == nil || sidecar.conn == nil {
|
||||
return
|
||||
@ -695,7 +710,8 @@ func (c *ClientCommon) readDedicatedSidecarLoop(sidecar *bulkDedicatedSidecar) {
|
||||
c.handleClientDedicatedSidecarFailure(sidecar, err)
|
||||
return
|
||||
}
|
||||
plain, plainRelease, err := decryptTransportPayloadCodecPooled(c.modernPSKRuntime, c.msgDe, c.SecretKey, payload, payloadRelease)
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
plain, plainRelease, err := decryptTransportPayloadCodecPooled(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, payload, payloadRelease)
|
||||
if err != nil {
|
||||
c.handleClientDedicatedSidecarFailure(sidecar, err)
|
||||
return
|
||||
@ -1023,7 +1039,7 @@ func (s *ServerCommon) replyDedicatedBulkAttach(client *LogicalConn, message Mes
|
||||
Type: MSG_SYS_REPLY,
|
||||
}
|
||||
if message.inboundConn != nil {
|
||||
return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply)
|
||||
return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, messageInboundTransportProtectionSnapshot(&message), reply)
|
||||
}
|
||||
_, err = s.sendLogical(client, reply)
|
||||
return err
|
||||
@ -1041,7 +1057,7 @@ func (s *ServerCommon) readDedicatedSidecarLoop(logical *LogicalConn, sidecar *b
|
||||
s.handleServerDedicatedSidecarFailure(logical, sidecar, err)
|
||||
return
|
||||
}
|
||||
plain, plainRelease, err := decryptTransportPayloadCodecPooled(logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload, payloadRelease)
|
||||
plain, plainRelease, err := decryptTransportPayloadCodecPooled(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload, payloadRelease)
|
||||
if err != nil {
|
||||
s.handleServerDedicatedSidecarFailure(logical, sidecar, err)
|
||||
return
|
||||
@ -1171,7 +1187,8 @@ func (c *ClientCommon) dedicatedBulkLaneSender(bulk *bulkHandle) (*bulkDedicated
|
||||
return nil, transportDetachedError("dedicated bulk sidecar not attached", nil)
|
||||
}
|
||||
sender := sidecar.laneSenderWithFactory(func(conn net.Conn) *bulkDedicatedLaneSender {
|
||||
laneRuntime := c.modernPSKRuntime
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
laneRuntime := profile.runtime
|
||||
if forked, err := forkDedicatedLaneModernPSKRuntime(laneRuntime); err == nil && forked != nil {
|
||||
laneRuntime = forked
|
||||
}
|
||||
@ -1198,7 +1215,7 @@ func (c *ClientCommon) sendDedicatedBulkData(ctx context.Context, bulk *bulkHand
|
||||
return sender.submitData(ctx, bulk.dataIDSnapshot(), bulk.nextOutboundDataSeq(), chunk)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte) (int, error) {
|
||||
func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) {
|
||||
if c == nil || bulk == nil {
|
||||
return 0, errBulkClientNil
|
||||
}
|
||||
@ -1206,7 +1223,7 @@ func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHan
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize)
|
||||
return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize, payloadOwned)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendDedicatedBulkClose(ctx context.Context, bulk *bulkHandle, full bool) error {
|
||||
@ -1345,7 +1362,7 @@ func (s *ServerCommon) sendDedicatedBulkData(ctx context.Context, logical *Logic
|
||||
return sender.submitData(ctx, bulk.dataIDSnapshot(), bulk.nextOutboundDataSeq(), chunk)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, startSeq uint64, payload []byte) (int, error) {
|
||||
func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) {
|
||||
if s == nil || bulk == nil {
|
||||
return 0, errBulkServerNil
|
||||
}
|
||||
@ -1353,7 +1370,7 @@ func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *Logi
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize)
|
||||
return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize, payloadOwned)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendDedicatedBulkClose(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, full bool) error {
|
||||
@ -1419,13 +1436,14 @@ func (c *ClientCommon) encodeDedicatedBulkBatchPayload(dataID uint64, items []bu
|
||||
if c == nil {
|
||||
return nil, errBulkClientNil
|
||||
}
|
||||
if runtime := c.modernPSKRuntime; runtime != nil {
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
if runtime := profile.runtime; runtime != nil {
|
||||
return runtime.sealFilledPayload(bulkDedicatedBatchPlainLen(items), func(dst []byte) error {
|
||||
return writeBulkDedicatedBatchPlain(dst, dataID, items)
|
||||
})
|
||||
}
|
||||
if c.fastPlainEncode != nil {
|
||||
return encodeBulkDedicatedBatchPayloadFast(c.fastPlainEncode, c.SecretKey, dataID, items)
|
||||
if profile.fastPlainEncode != nil {
|
||||
return encodeBulkDedicatedBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, dataID, items)
|
||||
}
|
||||
plain, err := encodeBulkDedicatedBatchPlain(dataID, items)
|
||||
if err != nil {
|
||||
@ -1453,8 +1471,9 @@ func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooledWithRuntime(runtim
|
||||
return writeBulkDedicatedBatchesPlain(dst, batches)
|
||||
})
|
||||
}
|
||||
if c.fastPlainEncode != nil {
|
||||
payload, err := encodeBulkDedicatedBatchesPayloadFast(c.fastPlainEncode, c.SecretKey, batches)
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
if profile.fastPlainEncode != nil {
|
||||
payload, err := encodeBulkDedicatedBatchesPayloadFast(profile.fastPlainEncode, profile.secretKey, batches)
|
||||
return payload, nil, err
|
||||
}
|
||||
plain, err := encodeBulkDedicatedBatchesPlain(batches)
|
||||
@ -1466,7 +1485,7 @@ func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooledWithRuntime(runtim
|
||||
}
|
||||
|
||||
func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooled(batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) {
|
||||
return c.encodeDedicatedBulkBatchesPayloadPooledWithRuntime(c.modernPSKRuntime, batches)
|
||||
return c.encodeDedicatedBulkBatchesPayloadPooledWithRuntime(c.clientTransportProtectionSnapshot().runtime, batches)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) encodeDedicatedBulkBatchPayloadPooled(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, func(), error) {
|
||||
|
||||
@ -115,6 +115,76 @@ func TestSendDedicatedBulkAttachRequestKeepsCoalescedDedicatedPayloadUnread(t *t
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendDedicatedBulkAttachRequestUsesBootstrapProtectionEvenAfterSteadySwitch(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
client.msgID = 100
|
||||
|
||||
alternate, err := deriveModernPSKProtectionProfile([]byte("notify-dedicated-attach-other-secret"), integrationModernPSKOptions(), ProtectionManaged)
|
||||
if err != nil {
|
||||
t.Fatalf("deriveModernPSKProtectionProfile(alternate) failed: %v", err)
|
||||
}
|
||||
client.setClientTransportProtectionProfile(alternate)
|
||||
|
||||
bulk := newBulkHandle(context.Background(), newBulkRuntime("dedicated-attach-bootstrap-test"), clientFileScope(), BulkOpenRequest{
|
||||
BulkID: "bulk-attach-bootstrap-test",
|
||||
DataID: 1,
|
||||
Dedicated: true,
|
||||
AttachToken: "attach-token",
|
||||
}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
|
||||
|
||||
bootstrap := client.clientDedicatedBulkAttachTransportProtectionProfile()
|
||||
encodedResp, err := client.sequenceEn(bulkAttachResponse{Accepted: true})
|
||||
if err != nil {
|
||||
t.Fatalf("encode bulkAttachResponse failed: %v", err)
|
||||
}
|
||||
replyFrame, err := encodeDirectSignalFrame(stario.NewQueue(), client.sequenceEn, bootstrap.msgEn, bootstrap.secretKey, TransferMsg{
|
||||
ID: 101,
|
||||
Key: systemBulkAttachKey,
|
||||
Value: encodedResp,
|
||||
Type: MSG_SYS_REPLY,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("encode attach reply frame failed: %v", err)
|
||||
}
|
||||
conn := newBulkAttachScriptConn(replyFrame)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
resp, err := client.sendDedicatedBulkAttachRequest(ctx, conn, bulk)
|
||||
if err != nil {
|
||||
t.Fatalf("sendDedicatedBulkAttachRequest failed: %v", err)
|
||||
}
|
||||
if !resp.Accepted {
|
||||
t.Fatalf("bulk attach response = %+v, want accepted", resp)
|
||||
}
|
||||
|
||||
parsedReq := stario.NewQueue()
|
||||
var reqMsg TransferMsg
|
||||
if err := parsedReq.ParseMessageOwned(conn.writtenBytes(), "attach-request", func(msgq stario.MsgQueue) error {
|
||||
transfer, err := decodeDirectSignalPayload(client.sequenceDe, bootstrap.msgDe, bootstrap.secretKey, msgq.Msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
reqMsg = transfer
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("parse written attach request with bootstrap profile failed: %v", err)
|
||||
}
|
||||
if reqMsg.Key != systemBulkAttachKey || reqMsg.Type != MSG_SYS_WAIT {
|
||||
t.Fatalf("attach request message mismatch: %+v", reqMsg)
|
||||
}
|
||||
|
||||
if err := parsedReq.ParseMessageOwned(conn.writtenBytes(), "attach-request-current", func(msgq stario.MsgQueue) error {
|
||||
_, err := decodeDirectSignalPayload(client.sequenceDe, alternate.msgDe, alternate.secretKey, msgq.Msg)
|
||||
return err
|
||||
}); !errors.Is(err, errTransportPayloadDecryptFailed) {
|
||||
t.Fatalf("decode written attach request with current steady profile error = %v, want %v", err, errTransportPayloadDecryptFailed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleBulkAttachSystemMessageAcceptedWritesDirectReplyBeforeDedicatedHandoff(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
UseLegacySecurityServer(server)
|
||||
|
||||
@ -83,7 +83,7 @@ func (r *bulkDedicatedLaneBatchRequest) reset() {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *bulkDedicatedLaneBatchRequest) prepare(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool) {
|
||||
func (r *bulkDedicatedLaneBatchRequest) prepare(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool, borrowItems bool) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
@ -94,6 +94,10 @@ func (r *bulkDedicatedLaneBatchRequest) prepare(ctx context.Context, dataID uint
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
r.Deadline = deadline
|
||||
}
|
||||
if borrowItems {
|
||||
r.Items = items
|
||||
return
|
||||
}
|
||||
if cap(r.Items) < len(items) {
|
||||
r.Items = make([]bulkDedicatedSendRequest, len(items))
|
||||
} else {
|
||||
@ -119,10 +123,10 @@ func (s *bulkDedicatedLaneSender) submitData(ctx context.Context, dataID uint64,
|
||||
Seq: seq,
|
||||
Payload: append([]byte(nil), payload...),
|
||||
}}
|
||||
return s.submitBatch(ctx, dataID, items, false)
|
||||
return s.submitBatch(ctx, dataID, items, false, false)
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int) (int, error) {
|
||||
func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int, payloadOwned bool) (int, error) {
|
||||
if s == nil {
|
||||
return 0, errTransportDetached
|
||||
}
|
||||
@ -132,6 +136,9 @@ func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64
|
||||
if chunkSize <= 0 {
|
||||
chunkSize = defaultBulkChunkSize
|
||||
}
|
||||
if submitted, written, err := s.tryDirectSubmitWrite(ctx, dataID, startSeq, payload, chunkSize); submitted {
|
||||
return written, err
|
||||
}
|
||||
written := 0
|
||||
seq := startSeq
|
||||
for written < len(payload) {
|
||||
@ -170,7 +177,7 @@ func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64
|
||||
seq++
|
||||
written = end
|
||||
}
|
||||
if err := s.submitWriteBatch(ctx, dataID, items); err != nil {
|
||||
if err := s.submitWriteBatch(ctx, dataID, items, payloadOwned); err != nil {
|
||||
return start, err
|
||||
}
|
||||
start = written
|
||||
@ -178,7 +185,7 @@ func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedLaneSender) submitWriteBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest) error {
|
||||
func (s *bulkDedicatedLaneSender) submitWriteBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, _ bool) error {
|
||||
if s == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
@ -188,8 +195,7 @@ func (s *bulkDedicatedLaneSender) submitWriteBatch(ctx context.Context, dataID u
|
||||
if submitted, err := s.tryDirectSubmitBatch(ctx, dataID, items); submitted {
|
||||
return err
|
||||
}
|
||||
queuedItems := copyBulkDedicatedSendRequests(items)
|
||||
return s.submitBatch(ctx, dataID, queuedItems, true)
|
||||
return s.submitBatch(ctx, dataID, items, true, true)
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedLaneSender) submitControl(ctx context.Context, dataID uint64, frameType uint8, flags uint8, seq uint64, payload []byte) error {
|
||||
@ -204,10 +210,10 @@ func (s *bulkDedicatedLaneSender) submitControl(ctx context.Context, dataID uint
|
||||
if len(payload) > 0 {
|
||||
items[0].Payload = append([]byte(nil), payload...)
|
||||
}
|
||||
return s.submitBatch(ctx, dataID, items, true)
|
||||
return s.submitBatch(ctx, dataID, items, true, false)
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool) error {
|
||||
func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool, borrowItems bool) error {
|
||||
if s == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
@ -218,7 +224,7 @@ func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64
|
||||
return err
|
||||
}
|
||||
req := getBulkDedicatedLaneBatchRequest()
|
||||
req.prepare(ctx, dataID, items, wait)
|
||||
req.prepare(ctx, dataID, items, wait, borrowItems)
|
||||
s.queued.Add(1)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@ -237,6 +243,104 @@ func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64
|
||||
}
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedLaneSender) tryDirectSubmitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int) (bool, int, error) {
|
||||
if s == nil {
|
||||
return true, 0, errTransportDetached
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return true, 0, nil
|
||||
}
|
||||
if chunkSize <= 0 {
|
||||
chunkSize = defaultBulkChunkSize
|
||||
}
|
||||
if err := s.errSnapshot(); err != nil {
|
||||
return true, 0, err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return true, 0, normalizeStreamDeadlineError(ctx.Err())
|
||||
case <-s.stopCh:
|
||||
return true, 0, s.stoppedErr()
|
||||
default:
|
||||
}
|
||||
if s.queued.Load() != 0 {
|
||||
return false, 0, nil
|
||||
}
|
||||
if !s.flushMu.TryLock() {
|
||||
return false, 0, nil
|
||||
}
|
||||
defer s.flushMu.Unlock()
|
||||
if s.queued.Load() != 0 {
|
||||
return false, 0, nil
|
||||
}
|
||||
if err := s.errSnapshot(); err != nil {
|
||||
return true, 0, err
|
||||
}
|
||||
written := 0
|
||||
seq := startSeq
|
||||
deadline, _ := ctx.Deadline()
|
||||
for written < len(payload) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return true, written, normalizeStreamDeadlineError(ctx.Err())
|
||||
case <-s.stopCh:
|
||||
return true, written, s.stoppedErr()
|
||||
default:
|
||||
}
|
||||
var itemBuf [bulkDedicatedBatchMaxItems]bulkDedicatedSendRequest
|
||||
items := itemBuf[:0]
|
||||
batchBytes := bulkDedicatedBatchHeaderLen
|
||||
start := written
|
||||
for written < len(payload) && len(items) < bulkDedicatedBatchMaxItems {
|
||||
end := written + chunkSize
|
||||
if end > len(payload) {
|
||||
end = len(payload)
|
||||
}
|
||||
itemLen := bulkDedicatedSendRequestLenFromPayloadLen(end - written)
|
||||
if len(items) > 0 && batchBytes+itemLen > bulkDedicatedBatchMaxPlainBytes {
|
||||
break
|
||||
}
|
||||
items = append(items, bulkDedicatedSendRequest{
|
||||
Type: bulkFastPayloadTypeData,
|
||||
Seq: seq,
|
||||
Payload: payload[written:end],
|
||||
})
|
||||
batchBytes += itemLen
|
||||
seq++
|
||||
written = end
|
||||
}
|
||||
if len(items) == 0 {
|
||||
end := written + chunkSize
|
||||
if end > len(payload) {
|
||||
end = len(payload)
|
||||
}
|
||||
items = append(items, bulkDedicatedSendRequest{
|
||||
Type: bulkFastPayloadTypeData,
|
||||
Seq: seq,
|
||||
Payload: payload[written:end],
|
||||
})
|
||||
seq++
|
||||
written = end
|
||||
}
|
||||
if err := s.flush([]bulkDedicatedOutboundBatch{{
|
||||
DataID: dataID,
|
||||
Items: items,
|
||||
}}, deadline); err != nil {
|
||||
err = normalizeDedicatedBulkSendError(err)
|
||||
s.setErr(err)
|
||||
s.failPending(err)
|
||||
if s.fail != nil {
|
||||
go s.fail(err)
|
||||
}
|
||||
return true, start, err
|
||||
}
|
||||
}
|
||||
return true, written, nil
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedLaneSender) tryDirectSubmitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest) (bool, error) {
|
||||
if s == nil {
|
||||
return true, errTransportDetached
|
||||
|
||||
@ -7,6 +7,57 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBulkDedicatedLaneBatchRequestPrepareBorrowedSharesItems(t *testing.T) {
|
||||
req := getBulkDedicatedLaneBatchRequest()
|
||||
defer req.recycle()
|
||||
|
||||
items := []bulkDedicatedSendRequest{{
|
||||
Type: bulkFastPayloadTypeData,
|
||||
Seq: 7,
|
||||
Payload: []byte("hello"),
|
||||
}}
|
||||
req.prepare(context.Background(), 11, items, true, true)
|
||||
|
||||
if got, want := len(req.Items), 1; got != want {
|
||||
t.Fatalf("prepared item count = %d, want %d", got, want)
|
||||
}
|
||||
if &req.Items[0] != &items[0] {
|
||||
t.Fatal("prepare with borrowed items should share request items")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBulkDedicatedLaneSenderTryDirectSubmitWriteFlushesWholePayload(t *testing.T) {
|
||||
conn := &shortWriteBulkRecordConn{maxPerWrite: 1024}
|
||||
encodeCalls := 0
|
||||
sender := &bulkDedicatedLaneSender{
|
||||
conn: conn,
|
||||
stopCh: make(chan struct{}),
|
||||
encode: func(batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) {
|
||||
encodeCalls++
|
||||
payload, err := encodeBulkDedicatedBatchesPlain(batches)
|
||||
return payload, nil, err
|
||||
},
|
||||
}
|
||||
|
||||
payload := bytes.Repeat([]byte("a"), 3*defaultBulkChunkSize)
|
||||
submitted, written, err := sender.tryDirectSubmitWrite(context.Background(), 9, 1, payload, defaultBulkChunkSize)
|
||||
if err != nil {
|
||||
t.Fatalf("tryDirectSubmitWrite error = %v", err)
|
||||
}
|
||||
if !submitted {
|
||||
t.Fatal("tryDirectSubmitWrite should submit directly")
|
||||
}
|
||||
if got, want := written, len(payload); got != want {
|
||||
t.Fatalf("written = %d, want %d", got, want)
|
||||
}
|
||||
if encodeCalls == 0 {
|
||||
t.Fatal("encode should be called at least once")
|
||||
}
|
||||
if got := sender.queued.Load(); got != 0 {
|
||||
t.Fatalf("queued requests = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBulkDedicatedLaneSenderCollectBatchRequestsBatchesAcrossDataIDs(t *testing.T) {
|
||||
sender := &bulkDedicatedLaneSender{
|
||||
reqCh: make(chan *bulkDedicatedLaneBatchRequest, 3),
|
||||
|
||||
@ -162,7 +162,8 @@ func (c *ClientCommon) tryDispatchBorrowedBulkTransportPayload(payload []byte) b
|
||||
if c == nil || len(payload) == 0 {
|
||||
return false
|
||||
}
|
||||
plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(c.modernPSKRuntime, c.msgDe, c.SecretKey, payload)
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, payload)
|
||||
if err != nil {
|
||||
if c.showError || c.debugMode {
|
||||
fmt.Println("client decode transport payload error", err)
|
||||
@ -194,7 +195,7 @@ func (s *ServerCommon) tryDispatchBorrowedBulkTransportPayload(source interface{
|
||||
if logical == nil {
|
||||
return false
|
||||
}
|
||||
plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload)
|
||||
plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload)
|
||||
if err != nil {
|
||||
if s.showError || s.debugMode {
|
||||
fmt.Println("server decode transport payload error", err)
|
||||
|
||||
177
bulk_fastpath.go
177
bulk_fastpath.go
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
@ -132,8 +133,9 @@ func (c *ClientCommon) encodeBulkFastPayload(frame bulkFastFrame) ([]byte, error
|
||||
if c == nil {
|
||||
return nil, errBulkClientNil
|
||||
}
|
||||
if c.fastPlainEncode != nil {
|
||||
return encodeBulkFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame)
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
if profile.fastPlainEncode != nil {
|
||||
return encodeBulkFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame)
|
||||
}
|
||||
plain, err := encodeBulkFastFramePayload(frame)
|
||||
if err != nil {
|
||||
@ -146,8 +148,9 @@ func (c *ClientCommon) encodeBulkFastBatchPayload(frames []bulkFastFrame) ([]byt
|
||||
if c == nil {
|
||||
return nil, errBulkClientNil
|
||||
}
|
||||
if c.fastPlainEncode != nil {
|
||||
return encodeBulkFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames)
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
if profile.fastPlainEncode != nil {
|
||||
return encodeBulkFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames)
|
||||
}
|
||||
plain, err := encodeBulkFastBatchPlain(frames)
|
||||
if err != nil {
|
||||
@ -160,11 +163,12 @@ func (c *ClientCommon) encodeBulkFastPayloadPooled(frame bulkFastFrame) ([]byte,
|
||||
if c == nil {
|
||||
return nil, nil, errBulkClientNil
|
||||
}
|
||||
if runtime := c.modernPSKRuntime; runtime != nil {
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
if runtime := profile.runtime; runtime != nil {
|
||||
return encodeBulkFastFramePayloadPooled(runtime, frame)
|
||||
}
|
||||
if c.fastPlainEncode != nil {
|
||||
payload, err := encodeBulkFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame)
|
||||
if profile.fastPlainEncode != nil {
|
||||
payload, err := encodeBulkFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame)
|
||||
return payload, nil, err
|
||||
}
|
||||
plain, err := encodeBulkFastFramePayload(frame)
|
||||
@ -179,11 +183,12 @@ func (c *ClientCommon) encodeBulkFastBatchPayloadPooled(frames []bulkFastFrame)
|
||||
if c == nil {
|
||||
return nil, nil, errBulkClientNil
|
||||
}
|
||||
if runtime := c.modernPSKRuntime; runtime != nil {
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
if runtime := profile.runtime; runtime != nil {
|
||||
return encodeBulkFastBatchPayloadPooled(runtime, frames)
|
||||
}
|
||||
if c.fastPlainEncode != nil {
|
||||
payload, err := encodeBulkFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames)
|
||||
if profile.fastPlainEncode != nil {
|
||||
payload, err := encodeBulkFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames)
|
||||
return payload, nil, err
|
||||
}
|
||||
plain, err := encodeBulkFastBatchPlain(frames)
|
||||
@ -460,28 +465,126 @@ func putBulkFastFrameScratch(buf []byte) {
|
||||
bulkFastFrameScratchPool.Put(buf[:0])
|
||||
}
|
||||
|
||||
func transportFastPayloadMagic(payload []byte) string {
|
||||
if len(payload) < 4 {
|
||||
return ""
|
||||
}
|
||||
return string(payload[:4])
|
||||
}
|
||||
|
||||
func (c *ClientCommon) decryptTransportPayloadPooled(payload []byte, release func()) ([]byte, func(), error) {
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
return decryptTransportPayloadCodecPooled(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, payload, release)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) decryptTransportPayloadLogicalPooled(logical *LogicalConn, payload []byte, release func()) ([]byte, func(), error) {
|
||||
if logical == nil {
|
||||
if release != nil {
|
||||
release()
|
||||
}
|
||||
return nil, nil, errTransportDetached
|
||||
}
|
||||
return decryptTransportPayloadCodecPooled(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload, release)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) tryDispatchBorrowedTransportPlain(plain []byte, release func()) bool {
|
||||
switch transportFastPayloadMagic(plain) {
|
||||
case bulkFastPayloadMagic, bulkFastBatchMagic:
|
||||
owner := newBulkReadPayloadOwner(release)
|
||||
matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
|
||||
c.dispatchFastBulkFrameWithOwner(frame, owner)
|
||||
return nil
|
||||
})
|
||||
if owner != nil {
|
||||
owner.done()
|
||||
}
|
||||
if !matched {
|
||||
walkErr = errBulkFastPayloadInvalid
|
||||
}
|
||||
if walkErr != nil && (c.showError || c.debugMode) {
|
||||
fmt.Println("client decode bulk fast payload error", walkErr)
|
||||
}
|
||||
return true
|
||||
case streamFastPayloadMagic, streamFastBatchMagic:
|
||||
owner := newStreamReadPayloadOwner(release)
|
||||
matched, walkErr := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
|
||||
c.dispatchFastStreamDataWithOwner(frame, owner)
|
||||
return nil
|
||||
})
|
||||
if owner != nil {
|
||||
owner.done()
|
||||
}
|
||||
if !matched {
|
||||
walkErr = errStreamFastPayloadInvalid
|
||||
}
|
||||
if walkErr != nil && (c.showError || c.debugMode) {
|
||||
fmt.Println("client decode stream fast payload error", walkErr)
|
||||
}
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerCommon) tryDispatchBorrowedTransportPlain(logical *LogicalConn, transport *TransportConn, conn net.Conn, plain []byte, release func()) bool {
|
||||
switch transportFastPayloadMagic(plain) {
|
||||
case bulkFastPayloadMagic, bulkFastBatchMagic:
|
||||
owner := newBulkReadPayloadOwner(release)
|
||||
matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
|
||||
s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, owner)
|
||||
return nil
|
||||
})
|
||||
if owner != nil {
|
||||
owner.done()
|
||||
}
|
||||
if !matched {
|
||||
walkErr = errBulkFastPayloadInvalid
|
||||
}
|
||||
if walkErr != nil && (s.showError || s.debugMode) {
|
||||
fmt.Println("server decode bulk fast payload error", walkErr)
|
||||
}
|
||||
return true
|
||||
case streamFastPayloadMagic, streamFastBatchMagic:
|
||||
owner := newStreamReadPayloadOwner(release)
|
||||
matched, walkErr := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
|
||||
s.dispatchFastStreamDataWithOwner(logical, transport, conn, frame, owner)
|
||||
return nil
|
||||
})
|
||||
if owner != nil {
|
||||
owner.done()
|
||||
}
|
||||
if !matched {
|
||||
walkErr = errStreamFastPayloadInvalid
|
||||
}
|
||||
if walkErr != nil && (s.showError || s.debugMode) {
|
||||
fmt.Println("server decode stream fast payload error", walkErr)
|
||||
}
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error {
|
||||
plain, err := c.decryptTransportPayload(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if frames, matched, err := decodeBulkFastFrames(plain); matched {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, frame := range frames {
|
||||
c.dispatchFastBulkFrame(frame)
|
||||
}
|
||||
return c.dispatchInboundTransportPlain(plain, now)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) dispatchInboundTransportPlain(plain []byte, now time.Time) error {
|
||||
if matched, err := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
|
||||
c.dispatchFastBulkFrame(frame)
|
||||
return nil
|
||||
}); matched {
|
||||
return err
|
||||
}
|
||||
if frames, matched, err := decodeStreamFastDataFrames(plain); matched {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, frame := range frames {
|
||||
c.dispatchFastStreamData(frame)
|
||||
}
|
||||
if matched, err := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
|
||||
c.dispatchFastStreamData(frame)
|
||||
return nil
|
||||
}); matched {
|
||||
return err
|
||||
}
|
||||
env, err := c.decodeEnvelopePlain(plain)
|
||||
if err != nil {
|
||||
@ -502,23 +605,21 @@ func (s *ServerCommon) dispatchInboundTransportPayload(logical *LogicalConn, tra
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if frames, matched, err := decodeBulkFastFrames(plain); matched {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, frame := range frames {
|
||||
s.dispatchFastBulkFrame(logical, transport, conn, frame)
|
||||
}
|
||||
return s.dispatchInboundTransportPlain(logical, transport, conn, plain, now)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) dispatchInboundTransportPlain(logical *LogicalConn, transport *TransportConn, conn net.Conn, plain []byte, now time.Time) error {
|
||||
if matched, err := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
|
||||
s.dispatchFastBulkFrame(logical, transport, conn, frame)
|
||||
return nil
|
||||
}); matched {
|
||||
return err
|
||||
}
|
||||
if frames, matched, err := decodeStreamFastDataFrames(plain); matched {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, frame := range frames {
|
||||
s.dispatchFastStreamData(logical, transport, conn, frame)
|
||||
}
|
||||
if matched, err := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
|
||||
s.dispatchFastStreamData(logical, transport, conn, frame)
|
||||
return nil
|
||||
}); matched {
|
||||
return err
|
||||
}
|
||||
env, err := s.decodeEnvelopePlain(plain)
|
||||
if err != nil {
|
||||
|
||||
@ -100,6 +100,176 @@ func TestBulkOpenAutoUDPFallsBackToShared(t *testing.T) {
|
||||
_ = accepted.Bulk.Close()
|
||||
}
|
||||
|
||||
func TestOpenDedicatedBulkConnectByConnRejectedAsSingleConnMode(t *testing.T) {
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
|
||||
}
|
||||
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
|
||||
}
|
||||
|
||||
left, right := net.Pipe()
|
||||
defer right.Close()
|
||||
bootstrapPeerAttachConnForTest(t, server, right)
|
||||
if err := client.ConnectByConn(left); err != nil {
|
||||
t.Fatalf("ConnectByConn failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
client.setByeFromServer(true)
|
||||
_ = client.Stop()
|
||||
}()
|
||||
|
||||
_, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{
|
||||
Range: BulkRange{
|
||||
Offset: 0,
|
||||
Length: 128,
|
||||
},
|
||||
})
|
||||
if !errors.Is(err, errBulkDedicatedSingleConn) {
|
||||
t.Fatalf("client OpenDedicatedBulk over ConnectByConn error = %v, want %v", err, errBulkDedicatedSingleConn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenBulkAutoConnectByConnFallsBackToShared(t *testing.T) {
|
||||
acceptCh := make(chan BulkAcceptInfo, 2)
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
|
||||
}
|
||||
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||
acceptCh <- info
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
|
||||
}
|
||||
|
||||
left, right := net.Pipe()
|
||||
defer right.Close()
|
||||
bootstrapPeerAttachConnForTest(t, server, right)
|
||||
if err := client.ConnectByConn(left); err != nil {
|
||||
t.Fatalf("ConnectByConn failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
client.setByeFromServer(true)
|
||||
_ = client.Stop()
|
||||
}()
|
||||
|
||||
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
|
||||
Mode: BulkOpenModeAuto,
|
||||
Range: BulkRange{
|
||||
Offset: 0,
|
||||
Length: 128,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("client OpenBulk auto over ConnectByConn failed: %v", err)
|
||||
}
|
||||
if bulk.Snapshot().Dedicated {
|
||||
t.Fatal("client OpenBulk auto over ConnectByConn should fall back to shared")
|
||||
}
|
||||
defer func() {
|
||||
_ = bulk.Close()
|
||||
}()
|
||||
|
||||
accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second)
|
||||
if accepted.Dedicated {
|
||||
t.Fatal("server accepted bulk should be shared after ConnectByConn auto fallback")
|
||||
}
|
||||
if _, err := bulk.Write([]byte("shared-over-single-conn")); err != nil {
|
||||
t.Fatalf("client bulk Write failed: %v", err)
|
||||
}
|
||||
readBulkExactly(t, accepted.Bulk, "shared-over-single-conn", 2*time.Second)
|
||||
select {
|
||||
case extra := <-acceptCh:
|
||||
t.Fatalf("unexpected extra server bulk accept: %+v", extra)
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
_ = accepted.Bulk.Close()
|
||||
}
|
||||
|
||||
func TestOpenDedicatedBulkExternalTransportDialableSourceSucceeds(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
|
||||
}
|
||||
acceptCh := make(chan BulkAcceptInfo, 2)
|
||||
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||
acceptCh <- info
|
||||
return nil
|
||||
})
|
||||
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||||
t.Fatalf("server Listen failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = server.Stop()
|
||||
}()
|
||||
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
|
||||
}
|
||||
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||||
t.Fatalf("client Connect failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
client.setByeFromServer(true)
|
||||
_ = client.Stop()
|
||||
}()
|
||||
|
||||
bulk, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{
|
||||
ID: "external-dedicated-dialable",
|
||||
Range: BulkRange{
|
||||
Offset: 0,
|
||||
Length: 64,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("client OpenDedicatedBulk failed: %v", err)
|
||||
}
|
||||
if !bulk.Snapshot().Dedicated {
|
||||
t.Fatal("client OpenDedicatedBulk over dialable external transport should stay dedicated")
|
||||
}
|
||||
defer func() {
|
||||
_ = bulk.Close()
|
||||
}()
|
||||
|
||||
accepted := waitAcceptedBulkByID(t, acceptCh, bulk.ID(), 2*time.Second)
|
||||
if !accepted.Dedicated {
|
||||
t.Fatal("server accepted bulk should stay dedicated over dialable external transport")
|
||||
}
|
||||
defer func() {
|
||||
_ = accepted.Bulk.Close()
|
||||
}()
|
||||
|
||||
clientHandle := bulk.(*bulkHandle)
|
||||
if clientHandle.dedicatedConnSnapshot() == nil {
|
||||
t.Fatal("client dedicated sidecar conn should be attached")
|
||||
}
|
||||
if mainConn := client.clientTransportConnSnapshot(); mainConn != nil && clientHandle.dedicatedConnSnapshot() == mainConn {
|
||||
t.Fatal("client dedicated sidecar should use an additional physical connection")
|
||||
}
|
||||
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionExternal {
|
||||
t.Fatalf("client steady protection mode = %v, want %v", got, ProtectionExternal)
|
||||
}
|
||||
|
||||
payload := "external-dedicated-sidecar"
|
||||
if _, err := bulk.Write([]byte(payload)); err != nil {
|
||||
t.Fatalf("client bulk Write failed: %v", err)
|
||||
}
|
||||
readBulkExactly(t, accepted.Bulk, payload, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestOpenDedicatedBulkWaitsForActiveSlotUntilContextDeadline(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
|
||||
152
client.go
152
client.go
@ -10,75 +10,87 @@ import (
|
||||
)
|
||||
|
||||
type ClientCommon struct {
|
||||
alive atomic.Value
|
||||
status Status
|
||||
byeFromServer bool
|
||||
conn net.Conn
|
||||
mu sync.Mutex
|
||||
msgID uint64
|
||||
peerIdentity string
|
||||
sessionEpoch uint64
|
||||
sessionOwnerState atomic.Int32
|
||||
sessionRuntime atomic.Pointer[clientSessionRuntime]
|
||||
connectSource atomic.Pointer[clientConnectSource]
|
||||
queue *stario.StarQueue
|
||||
stopFn context.CancelFunc
|
||||
stopCtx context.Context
|
||||
parallelNum int
|
||||
maxReadTimeout time.Duration
|
||||
maxWriteTimeout time.Duration
|
||||
keyExchangeFn func(c Client) error
|
||||
linkFns map[string]func(message *Message)
|
||||
defaultFns func(message *Message)
|
||||
msgEn func([]byte, []byte) []byte
|
||||
msgDe func([]byte, []byte) []byte
|
||||
fastStreamEncode transportFastStreamEncoder
|
||||
fastBulkEncode transportFastBulkEncoder
|
||||
fastPlainEncode transportFastPlainEncoder
|
||||
modernPSKRuntime *modernPSKCodecRuntime
|
||||
handshakeRsaPubKey []byte
|
||||
SecretKey []byte
|
||||
noFinSyncMsgMaxKeepSeconds int
|
||||
lastHeartbeat int64
|
||||
heartbeatPeriod time.Duration
|
||||
wg stario.WaitGroup
|
||||
netType NetType
|
||||
showError bool
|
||||
skipKeyExchange bool
|
||||
useHeartBeat bool
|
||||
sequenceDe func([]byte) (interface{}, error)
|
||||
sequenceEn func(interface{}) ([]byte, error)
|
||||
logicalSession *logicalSessionState
|
||||
onFileEvent func(FileEvent)
|
||||
fileEventObserver func(FileEvent)
|
||||
fileTransferCfg fileTransferConfig
|
||||
signalReliableCfg signalReliabilityConfig
|
||||
streamRuntime *streamRuntime
|
||||
recordRuntime *recordRuntime
|
||||
bulkRuntime *bulkRuntime
|
||||
bulkDefaultOpenMode BulkOpenMode
|
||||
bulkNetworkProfile BulkNetworkProfile
|
||||
bulkOpenTuning BulkOpenTuning
|
||||
bulkDedicatedAttachLimit int
|
||||
bulkDedicatedAttachSem chan struct{}
|
||||
bulkDedicatedAttachRetry int
|
||||
bulkDedicatedAttachBackoff time.Duration
|
||||
bulkDedicatedDialTimeout time.Duration
|
||||
bulkDedicatedHelloTimeout time.Duration
|
||||
bulkDedicatedActiveLimit int
|
||||
bulkDedicatedActive atomic.Int32
|
||||
bulkDedicatedActiveWait chan struct{}
|
||||
bulkDedicatedLaneLimit int
|
||||
bulkDedicatedSidecarMu sync.Mutex
|
||||
bulkDedicatedLanes map[uint32]*bulkDedicatedLane
|
||||
bulkDedicatedNextLaneID uint32
|
||||
bulkAttachAttemptCount atomic.Int64
|
||||
bulkAttachRetryCount atomic.Int64
|
||||
bulkAttachSuccessCount atomic.Int64
|
||||
bulkAttachFallbackCount atomic.Int64
|
||||
connectionRetryState *connectionRetryState
|
||||
securityReadyCheck bool
|
||||
debugMode bool
|
||||
alive atomic.Value
|
||||
status Status
|
||||
byeFromServer bool
|
||||
conn net.Conn
|
||||
mu sync.Mutex
|
||||
msgID uint64
|
||||
peerIdentity string
|
||||
sessionEpoch uint64
|
||||
sessionOwnerState atomic.Int32
|
||||
sessionRuntime atomic.Pointer[clientSessionRuntime]
|
||||
connectSource atomic.Pointer[clientConnectSource]
|
||||
queue *stario.StarQueue
|
||||
stopFn context.CancelFunc
|
||||
stopCtx context.Context
|
||||
parallelNum int
|
||||
maxReadTimeout time.Duration
|
||||
maxWriteTimeout time.Duration
|
||||
keyExchangeFn func(c Client) error
|
||||
linkFns map[string]func(message *Message)
|
||||
defaultFns func(message *Message)
|
||||
msgEn func([]byte, []byte) []byte
|
||||
msgDe func([]byte, []byte) []byte
|
||||
fastStreamEncode transportFastStreamEncoder
|
||||
fastBulkEncode transportFastBulkEncoder
|
||||
fastPlainEncode transportFastPlainEncoder
|
||||
modernPSKRuntime *modernPSKCodecRuntime
|
||||
handshakeRsaPubKey []byte
|
||||
SecretKey []byte
|
||||
transportProtection atomic.Pointer[transportProtectionProfile]
|
||||
peerAttachSecurity atomic.Pointer[peerAttachSecurityState]
|
||||
securityBootstrap transportProtectionProfile
|
||||
securitySteady transportProtectionProfile
|
||||
securitySteadyNegotiated transportProtectionProfile
|
||||
securityAuthMode AuthMode
|
||||
securityProtectionMode ProtectionMode
|
||||
securityRequireForwardSecrecy bool
|
||||
securityConfigured bool
|
||||
peerAttachAuthenticated bool
|
||||
peerAttachAuthFallback bool
|
||||
peerAttachAt int64
|
||||
noFinSyncMsgMaxKeepSeconds int
|
||||
lastHeartbeat int64
|
||||
heartbeatPeriod time.Duration
|
||||
wg stario.WaitGroup
|
||||
netType NetType
|
||||
showError bool
|
||||
skipKeyExchange bool
|
||||
useHeartBeat bool
|
||||
sequenceDe func([]byte) (interface{}, error)
|
||||
sequenceEn func(interface{}) ([]byte, error)
|
||||
logicalSession *logicalSessionState
|
||||
onFileEvent func(FileEvent)
|
||||
fileEventObserver func(FileEvent)
|
||||
fileTransferCfg fileTransferConfig
|
||||
signalReliableCfg signalReliabilityConfig
|
||||
streamRuntime *streamRuntime
|
||||
recordRuntime *recordRuntime
|
||||
bulkRuntime *bulkRuntime
|
||||
bulkDefaultOpenMode BulkOpenMode
|
||||
bulkNetworkProfile BulkNetworkProfile
|
||||
bulkOpenTuning BulkOpenTuning
|
||||
bulkDedicatedAttachLimit int
|
||||
bulkDedicatedAttachSem chan struct{}
|
||||
bulkDedicatedAttachRetry int
|
||||
bulkDedicatedAttachBackoff time.Duration
|
||||
bulkDedicatedDialTimeout time.Duration
|
||||
bulkDedicatedHelloTimeout time.Duration
|
||||
bulkDedicatedActiveLimit int
|
||||
bulkDedicatedActive atomic.Int32
|
||||
bulkDedicatedActiveWait chan struct{}
|
||||
bulkDedicatedLaneLimit int
|
||||
bulkDedicatedSidecarMu sync.Mutex
|
||||
bulkDedicatedLanes map[uint32]*bulkDedicatedLane
|
||||
bulkDedicatedNextLaneID uint32
|
||||
bulkAttachAttemptCount atomic.Int64
|
||||
bulkAttachRetryCount atomic.Int64
|
||||
bulkAttachSuccessCount atomic.Int64
|
||||
bulkAttachFallbackCount atomic.Int64
|
||||
connectionRetryState *connectionRetryState
|
||||
securityReadyCheck bool
|
||||
debugMode bool
|
||||
}
|
||||
|
||||
func NewClient() Client {
|
||||
@ -134,6 +146,8 @@ func NewClient() Client {
|
||||
client.fileEventObserver = normalizeFileEventCallback(nil)
|
||||
client.stopCtx, client.stopFn = context.WithCancel(context.Background())
|
||||
client.sessionRuntime.Store(newClientSessionRuntimeBase(client.stopCtx, client.stopFn))
|
||||
client.setClientTransportProtectionProfile(defaultTransportProtectionProfile())
|
||||
client.peerAttachSecurity.Store(defaultPeerAttachSecurityState())
|
||||
bindClientStreamControl(&client)
|
||||
bindClientBulkControl(&client)
|
||||
client.getTransferState().setBuiltinHandler(client.builtinFileTransferHandler)
|
||||
|
||||
@ -294,7 +294,7 @@ func clientBulkWriteSender(c *ClientCommon, epoch uint64) bulkWriteSender {
|
||||
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return c.sendDedicatedBulkWrite(ctx, bulk, startSeq, payload)
|
||||
return c.sendDedicatedBulkWrite(ctx, bulk, startSeq, payload, payloadOwned)
|
||||
}
|
||||
if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) {
|
||||
return 0, errTransportDetached
|
||||
|
||||
@ -65,32 +65,40 @@ func (c *ClientCommon) RecoverTransferSnapshots(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte {
|
||||
return c.msgEn
|
||||
return c.clientTransportProtectionSnapshot().msgEn
|
||||
}
|
||||
|
||||
// Deprecated: SetMsgEn overrides the transport codec directly.
|
||||
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||
func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) {
|
||||
c.msgEn = fn
|
||||
c.fastStreamEncode = nil
|
||||
c.fastBulkEncode = nil
|
||||
c.fastPlainEncode = nil
|
||||
c.modernPSKRuntime = nil
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
profile.mode = ProtectionManaged
|
||||
profile.msgEn = fn
|
||||
profile.fastStreamEncode = nil
|
||||
profile.fastBulkEncode = nil
|
||||
profile.fastPlainEncode = nil
|
||||
profile.runtime = nil
|
||||
c.setClientTransportProtectionProfile(profile)
|
||||
c.clearClientSecurityProfiles()
|
||||
c.securityReadyCheck = false
|
||||
}
|
||||
|
||||
func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte {
|
||||
return c.msgDe
|
||||
return c.clientTransportProtectionSnapshot().msgDe
|
||||
}
|
||||
|
||||
// Deprecated: SetMsgDe overrides the transport codec directly.
|
||||
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||
func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) {
|
||||
c.msgDe = fn
|
||||
c.fastStreamEncode = nil
|
||||
c.fastBulkEncode = nil
|
||||
c.fastPlainEncode = nil
|
||||
c.modernPSKRuntime = nil
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
profile.mode = ProtectionManaged
|
||||
profile.msgDe = fn
|
||||
profile.fastStreamEncode = nil
|
||||
profile.fastBulkEncode = nil
|
||||
profile.fastPlainEncode = nil
|
||||
profile.runtime = nil
|
||||
c.setClientTransportProtectionProfile(profile)
|
||||
c.clearClientSecurityProfiles()
|
||||
c.securityReadyCheck = false
|
||||
}
|
||||
|
||||
@ -103,20 +111,24 @@ func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) {
|
||||
}
|
||||
|
||||
func (c *ClientCommon) GetSecretKey() []byte {
|
||||
return c.SecretKey
|
||||
return c.clientTransportProtectionSnapshot().secretKey
|
||||
}
|
||||
|
||||
// Deprecated: SetSecretKey injects a raw transport key directly.
|
||||
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||
func (c *ClientCommon) SetSecretKey(key []byte) {
|
||||
c.SecretKey = key
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
profile.mode = ProtectionManaged
|
||||
profile.secretKey = cloneTransportProtectionKey(key)
|
||||
if len(key) == 0 {
|
||||
c.modernPSKRuntime = nil
|
||||
profile.runtime = nil
|
||||
} else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil {
|
||||
c.modernPSKRuntime = runtime
|
||||
profile.runtime = runtime
|
||||
} else {
|
||||
c.modernPSKRuntime = nil
|
||||
profile.runtime = nil
|
||||
}
|
||||
c.setClientTransportProtectionProfile(profile)
|
||||
c.clearClientSecurityProfiles()
|
||||
c.securityReadyCheck = len(key) == 0
|
||||
c.skipKeyExchange = true
|
||||
}
|
||||
|
||||
@ -6,17 +6,26 @@ import (
|
||||
)
|
||||
|
||||
type clientConnAttachmentState struct {
|
||||
maxReadTimeout time.Duration
|
||||
maxWriteTimeout time.Duration
|
||||
msgEn func([]byte, []byte) []byte
|
||||
msgDe func([]byte, []byte) []byte
|
||||
fastStreamEncode transportFastStreamEncoder
|
||||
fastBulkEncode transportFastBulkEncoder
|
||||
fastPlainEncode transportFastPlainEncoder
|
||||
modernPSKRuntime *modernPSKCodecRuntime
|
||||
handshakeRsaKey []byte
|
||||
secretKey []byte
|
||||
lastHeartBeat int64
|
||||
maxReadTimeout time.Duration
|
||||
maxWriteTimeout time.Duration
|
||||
authMode AuthMode
|
||||
protectionMode ProtectionMode
|
||||
msgEn func([]byte, []byte) []byte
|
||||
msgDe func([]byte, []byte) []byte
|
||||
fastStreamEncode transportFastStreamEncoder
|
||||
fastBulkEncode transportFastBulkEncoder
|
||||
fastPlainEncode transportFastPlainEncoder
|
||||
modernPSKRuntime *modernPSKCodecRuntime
|
||||
handshakeRsaKey []byte
|
||||
secretKey []byte
|
||||
keyMode string
|
||||
sessionID []byte
|
||||
forwardSecrecy bool
|
||||
forwardSecrecyFallback bool
|
||||
peerAttached bool
|
||||
peerAttachFallback bool
|
||||
peerAttachAt int64
|
||||
lastHeartBeat int64
|
||||
}
|
||||
|
||||
func cloneClientConnAttachmentState(src *clientConnAttachmentState) *clientConnAttachmentState {
|
||||
@ -26,6 +35,7 @@ func cloneClientConnAttachmentState(src *clientConnAttachmentState) *clientConnA
|
||||
cloned := *src
|
||||
cloned.handshakeRsaKey = cloneClientConnAttachmentBytes(src.handshakeRsaKey)
|
||||
cloned.secretKey = cloneClientConnAttachmentBytes(src.secretKey)
|
||||
cloned.sessionID = cloneClientConnAttachmentBytes(src.sessionID)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
@ -153,6 +163,7 @@ func (c *ClientConn) applyClientConnAttachmentProfile(maxReadTimeout time.Durati
|
||||
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.maxReadTimeout = maxReadTimeout
|
||||
state.maxWriteTimeout = maxWriteTimeout
|
||||
state.protectionMode = ProtectionManaged
|
||||
state.msgEn = msgEn
|
||||
state.msgDe = msgDe
|
||||
state.modernPSKRuntime = nil
|
||||
@ -206,6 +217,7 @@ func (c *ClientConn) clientConnMsgEnSnapshot() func([]byte, []byte) []byte {
|
||||
|
||||
func (c *ClientConn) setClientConnMsgEn(fn func([]byte, []byte) []byte) {
|
||||
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.protectionMode = ProtectionManaged
|
||||
state.msgEn = fn
|
||||
state.fastStreamEncode = nil
|
||||
state.fastBulkEncode = nil
|
||||
@ -223,6 +235,7 @@ func (c *ClientConn) clientConnMsgDeSnapshot() func([]byte, []byte) []byte {
|
||||
|
||||
func (c *ClientConn) setClientConnMsgDe(fn func([]byte, []byte) []byte) {
|
||||
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.protectionMode = ProtectionManaged
|
||||
state.msgDe = fn
|
||||
state.fastStreamEncode = nil
|
||||
state.fastBulkEncode = nil
|
||||
@ -319,6 +332,39 @@ func (c *LogicalConn) modernPSKRuntimeSnapshot() *modernPSKCodecRuntime {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *LogicalConn) protectionModeSnapshot() ProtectionMode {
|
||||
if state := c.attachmentStateRaw(); state != nil {
|
||||
return state.protectionMode
|
||||
}
|
||||
return ProtectionManaged
|
||||
}
|
||||
|
||||
func (c *LogicalConn) authModeSnapshot() AuthMode {
|
||||
if state := c.attachmentStateRaw(); state != nil {
|
||||
return state.authMode
|
||||
}
|
||||
return AuthNone
|
||||
}
|
||||
|
||||
func (c *LogicalConn) peerAttachAuthenticatedSnapshot() (bool, bool, time.Time) {
|
||||
if state := c.attachmentStateRaw(); state != nil {
|
||||
if state.peerAttachAt == 0 {
|
||||
return state.peerAttached, state.peerAttachFallback, time.Time{}
|
||||
}
|
||||
return state.peerAttached, state.peerAttachFallback, time.Unix(0, state.peerAttachAt)
|
||||
}
|
||||
return false, false, time.Time{}
|
||||
}
|
||||
|
||||
func (c *LogicalConn) markPeerAttachAuthenticated(authMode AuthMode, fallback bool, at time.Time) {
|
||||
c.updateAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.authMode = authMode
|
||||
state.peerAttached = true
|
||||
state.peerAttachFallback = fallback
|
||||
state.peerAttachAt = at.UnixNano()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *LogicalConn) setModernPSKRuntime(runtime *modernPSKCodecRuntime) {
|
||||
c.updateAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.modernPSKRuntime = runtime
|
||||
|
||||
@ -18,14 +18,18 @@ const (
|
||||
var errClientReconnectSourceUnavailable = errors.New("client reconnect source is unavailable")
|
||||
|
||||
type clientConnectSource struct {
|
||||
kind string
|
||||
network string
|
||||
addr string
|
||||
dialFn func(context.Context) (net.Conn, error)
|
||||
kind string
|
||||
network string
|
||||
addr string
|
||||
dialFn func(context.Context) (net.Conn, error)
|
||||
supportsAdditional bool
|
||||
}
|
||||
|
||||
func newClientConnConnectSource(conn net.Conn) *clientConnectSource {
|
||||
source := &clientConnectSource{kind: clientConnectSourceConn}
|
||||
source := &clientConnectSource{
|
||||
kind: clientConnectSourceConn,
|
||||
supportsAdditional: false,
|
||||
}
|
||||
if conn == nil {
|
||||
return source
|
||||
}
|
||||
@ -43,9 +47,10 @@ func newClientConnConnectSource(conn net.Conn) *clientConnectSource {
|
||||
|
||||
func newClientNetworkConnectSource(network string, addr string) *clientConnectSource {
|
||||
return &clientConnectSource{
|
||||
kind: clientConnectSourceNetwork,
|
||||
network: network,
|
||||
addr: addr,
|
||||
kind: clientConnectSourceNetwork,
|
||||
network: network,
|
||||
addr: addr,
|
||||
supportsAdditional: true,
|
||||
dialFn: func(context.Context) (net.Conn, error) {
|
||||
return transport.Dial(network, addr)
|
||||
},
|
||||
@ -54,9 +59,10 @@ func newClientNetworkConnectSource(network string, addr string) *clientConnectSo
|
||||
|
||||
func newClientTimeoutConnectSource(network string, addr string, timeout time.Duration) *clientConnectSource {
|
||||
return &clientConnectSource{
|
||||
kind: clientConnectSourceTimeout,
|
||||
network: network,
|
||||
addr: addr,
|
||||
kind: clientConnectSourceTimeout,
|
||||
network: network,
|
||||
addr: addr,
|
||||
supportsAdditional: true,
|
||||
dialFn: func(context.Context) (net.Conn, error) {
|
||||
return transport.DialTimeout(network, addr, timeout)
|
||||
},
|
||||
@ -65,8 +71,9 @@ func newClientTimeoutConnectSource(network string, addr string, timeout time.Dur
|
||||
|
||||
func newClientFactoryConnectSource(dialFn func(context.Context) (net.Conn, error)) *clientConnectSource {
|
||||
return &clientConnectSource{
|
||||
kind: clientConnectSourceFactory,
|
||||
dialFn: dialFn,
|
||||
kind: clientConnectSourceFactory,
|
||||
dialFn: dialFn,
|
||||
supportsAdditional: true,
|
||||
}
|
||||
}
|
||||
|
||||
@ -82,6 +89,10 @@ func (s *clientConnectSource) canReconnect() bool {
|
||||
return s != nil && s.dialFn != nil
|
||||
}
|
||||
|
||||
func (s *clientConnectSource) supportsAdditionalConn() bool {
|
||||
return s != nil && s.supportsAdditional
|
||||
}
|
||||
|
||||
func (s *clientConnectSource) isUDP() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
|
||||
@ -31,7 +31,11 @@ func (c *ClientCommon) ExchangeKey(newKey []byte) error {
|
||||
if string(data.Value) != "success" {
|
||||
return errors.New("cannot exchange new aes-key")
|
||||
}
|
||||
c.SecretKey = newKey
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
profile.mode = ProtectionManaged
|
||||
profile.secretKey = cloneTransportProtectionKey(newKey)
|
||||
profile.runtime = nil
|
||||
c.setClientTransportProtectionProfile(profile)
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -350,6 +350,8 @@ func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime,
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
c.resetClientPeerAttachAuth()
|
||||
c.activateClientBootstrapTransportProtection()
|
||||
if runKeyExchange && !c.skipKeyExchange {
|
||||
if err := c.keyExchangeFn(c); err != nil {
|
||||
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "key exchange failed", err)
|
||||
@ -358,6 +360,7 @@ func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime,
|
||||
if err := c.announceClientPeerIdentity(); err != nil {
|
||||
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "peer attach failed", err)
|
||||
}
|
||||
c.activateClientSteadyTransportProtection()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -174,28 +174,37 @@ func (c *ClientCommon) dispatchTransportPayloadFast(payload []byte, release func
|
||||
}
|
||||
return
|
||||
}
|
||||
if c.tryDispatchBorrowedBulkTransportPayload(payload) {
|
||||
if release != nil {
|
||||
release()
|
||||
plain, plainRelease, err := c.decryptTransportPayloadPooled(payload, release)
|
||||
if err != nil {
|
||||
if c.showError || c.debugMode {
|
||||
fmt.Println("client decode transport payload error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
owned := append([]byte(nil), payload...)
|
||||
if release != nil {
|
||||
release()
|
||||
if c.tryDispatchBorrowedTransportPlain(plain, plainRelease) {
|
||||
return
|
||||
}
|
||||
if dispatcher == nil {
|
||||
now := time.Now()
|
||||
if err := c.dispatchInboundTransportPayload(owned, now); err != nil && (c.showError || c.debugMode) {
|
||||
err := c.dispatchInboundTransportPlain(plain, now)
|
||||
if plainRelease != nil {
|
||||
plainRelease()
|
||||
}
|
||||
if err != nil && (c.showError || c.debugMode) {
|
||||
fmt.Println("client decode envelope error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
owned := plain
|
||||
if plainRelease != nil {
|
||||
owned = append([]byte(nil), plain...)
|
||||
plainRelease()
|
||||
}
|
||||
c.wg.Add(1)
|
||||
if !dispatcher.Dispatch(clientInboundDispatchSource(), func() {
|
||||
defer c.wg.Done()
|
||||
now := time.Now()
|
||||
if err := c.dispatchInboundTransportPayload(owned, now); err != nil && (c.showError || c.debugMode) {
|
||||
if err := c.dispatchInboundTransportPlain(owned, now); err != nil && (c.showError || c.debugMode) {
|
||||
fmt.Println("client decode envelope error", err)
|
||||
}
|
||||
}) {
|
||||
|
||||
@ -26,6 +26,8 @@ type Client interface {
|
||||
BulkOpenTuning() BulkOpenTuning
|
||||
SetBulkDedicatedAttachConfig(BulkDedicatedAttachConfig)
|
||||
BulkDedicatedAttachConfig() BulkDedicatedAttachConfig
|
||||
SetPeerAttachSecurityConfig(PeerAttachSecurityConfig) error
|
||||
PeerAttachSecurityConfig() PeerAttachSecurityConfig
|
||||
SetFileReceiveDir(dir string) error
|
||||
send(msg TransferMsg) (WaitMsg, error)
|
||||
sendEnvelope(env Envelope) error
|
||||
|
||||
@ -404,6 +404,7 @@ func (c *LogicalConn) applyAttachmentProfile(maxReadTimeout time.Duration, maxWr
|
||||
c.updateAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.maxReadTimeout = maxReadTimeout
|
||||
state.maxWriteTimeout = maxWriteTimeout
|
||||
state.protectionMode = ProtectionManaged
|
||||
state.msgEn = msgEn
|
||||
state.msgDe = msgDe
|
||||
state.fastStreamEncode = fastStreamEncode
|
||||
@ -525,17 +526,23 @@ func (c *LogicalConn) runtimeSnapshot() ClientConnRuntimeSnapshot {
|
||||
return ClientConnRuntimeSnapshot{}
|
||||
}
|
||||
status := c.Status()
|
||||
authenticated, fallback, attachAt := c.peerAttachAuthenticatedSnapshot()
|
||||
now := time.Now()
|
||||
snapshot := ClientConnRuntimeSnapshot{
|
||||
ClientID: c.clientIDSnapshot(),
|
||||
Alive: status.Alive,
|
||||
Reason: status.Reason,
|
||||
IdentityBound: c.clientConnIdentityBoundSnapshot(),
|
||||
UsesStreamTransport: c.usesStreamTransportSnapshot(),
|
||||
TransportGeneration: c.transportGenerationSnapshot(),
|
||||
TransportAttachCount: c.clientConnTransportAttachCountSnapshot(),
|
||||
TransportDetachCount: c.clientConnTransportDetachCountSnapshot(),
|
||||
LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(),
|
||||
ClientID: c.clientIDSnapshot(),
|
||||
Alive: status.Alive,
|
||||
Reason: status.Reason,
|
||||
IdentityBound: c.clientConnIdentityBoundSnapshot(),
|
||||
UsesStreamTransport: c.usesStreamTransportSnapshot(),
|
||||
TransportGeneration: c.transportGenerationSnapshot(),
|
||||
TransportAttachCount: c.clientConnTransportAttachCountSnapshot(),
|
||||
TransportDetachCount: c.clientConnTransportDetachCountSnapshot(),
|
||||
LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(),
|
||||
AuthMode: authModeName(c.authModeSnapshot()),
|
||||
ProtectionMode: protectionModeName(c.protectionModeSnapshot()),
|
||||
PeerAttachAuthenticated: authenticated,
|
||||
PeerAttachAuthFallback: fallback,
|
||||
LastPeerAttachAt: attachAt,
|
||||
}
|
||||
if status.Err != nil {
|
||||
snapshot.Error = status.Err.Error()
|
||||
|
||||
@ -43,6 +43,20 @@ func TestHydrateServerMessagePeerFieldsFromLogicalConn(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHydrateServerMessagePeerFieldsZeroValueDoesNotPanic(t *testing.T) {
|
||||
message := hydrateServerMessagePeerFields(Message{})
|
||||
|
||||
if message.LogicalConn != nil {
|
||||
t.Fatal("zero-value message should not hydrate logical conn")
|
||||
}
|
||||
if message.ClientConn != nil {
|
||||
t.Fatal("zero-value message should not hydrate client conn")
|
||||
}
|
||||
if message.TransportConn != nil {
|
||||
t.Fatal("zero-value message should not hydrate transport conn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerPublishSendFileEventHydratesLogicalConnAlias(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
left, right := net.Pipe()
|
||||
|
||||
39
msg.go
39
msg.go
@ -37,9 +37,10 @@ type Message struct {
|
||||
NetType
|
||||
LogicalConn *LogicalConn
|
||||
// Deprecated: ClientConn aliases LogicalConn for compatibility.
|
||||
ClientConn *ClientConn
|
||||
TransportConn *TransportConn
|
||||
ServerConn Client
|
||||
ClientConn *ClientConn
|
||||
TransportConn *TransportConn
|
||||
ServerConn Client
|
||||
inboundTransportProfile *transportProtectionProfile
|
||||
TransferMsg
|
||||
Time time.Time
|
||||
inboundConn net.Conn
|
||||
@ -58,7 +59,7 @@ type messageLogicalTransferSender interface {
|
||||
}
|
||||
|
||||
type messageInboundTransferSender interface {
|
||||
sendTransferInbound(*LogicalConn, *TransportConn, net.Conn, TransferMsg) error
|
||||
sendTransferInbound(*LogicalConn, *TransportConn, net.Conn, *transportProtectionProfile, TransferMsg) error
|
||||
}
|
||||
|
||||
func (m *Message) Reply(value MsgVal) (err error) {
|
||||
@ -86,7 +87,7 @@ func (m *Message) Reply(value MsgVal) (err error) {
|
||||
if sender == nil {
|
||||
return transportDetachedErrorForPeer(logical, transport)
|
||||
}
|
||||
return sender.sendTransferInbound(logical, transport, m.inboundConn, reply)
|
||||
return sender.sendTransferInbound(logical, transport, m.inboundConn, messageInboundTransportProtectionSnapshot(m), reply)
|
||||
}
|
||||
if transport != nil {
|
||||
_, err = transport.sendTransfer(reply)
|
||||
@ -123,12 +124,19 @@ func hydrateServerMessagePeerFields(message Message) Message {
|
||||
if message.LogicalConn == nil {
|
||||
message.LogicalConn = logicalConnFromClient(message.ClientConn)
|
||||
}
|
||||
if message.ClientConn == nil {
|
||||
if message.LogicalConn == nil && message.TransportConn != nil {
|
||||
message.LogicalConn = message.TransportConn.logicalConnSnapshot()
|
||||
}
|
||||
if message.ClientConn == nil && message.LogicalConn != nil {
|
||||
message.ClientConn = message.LogicalConn.compatClientConn()
|
||||
}
|
||||
if message.TransportConn == nil && message.LogicalConn != nil {
|
||||
message.TransportConn = message.LogicalConn.CurrentTransportConn()
|
||||
}
|
||||
if message.inboundConn != nil && message.inboundTransportProfile == nil && message.LogicalConn != nil {
|
||||
profile := message.LogicalConn.transportProtectionProfileSnapshot()
|
||||
message.inboundTransportProfile = &profile
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
@ -155,3 +163,22 @@ func messageTransportConnSnapshot(message *Message) *TransportConn {
|
||||
}
|
||||
return logical.CurrentTransportConn()
|
||||
}
|
||||
|
||||
func messageInboundTransportProtectionSnapshot(message *Message) *transportProtectionProfile {
|
||||
if message == nil {
|
||||
return nil
|
||||
}
|
||||
if message.inboundTransportProfile != nil {
|
||||
return message.inboundTransportProfile
|
||||
}
|
||||
if message.inboundConn == nil {
|
||||
return nil
|
||||
}
|
||||
logical := messageLogicalConnSnapshot(message)
|
||||
if logical == nil {
|
||||
return nil
|
||||
}
|
||||
profile := logical.transportProtectionProfileSnapshot()
|
||||
message.inboundTransportProfile = &profile
|
||||
return message.inboundTransportProfile
|
||||
}
|
||||
|
||||
517
peer_attach_auth.go
Normal file
517
peer_attach_auth.go
Normal file
@ -0,0 +1,517 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
cryptorand "crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
peerAttachFeatureExplicitAuth uint64 = 1 << iota
|
||||
peerAttachFeatureChannelBinding
|
||||
peerAttachFeatureForwardSecrecy
|
||||
)
|
||||
|
||||
const (
|
||||
peerAttachNonceSize = 16
|
||||
peerAttachReplayTTL = 30 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
errPeerAttachAuthInvalid = errors.New("peer attach auth invalid")
|
||||
errPeerAttachReplayRejected = errors.New("peer attach replay rejected")
|
||||
errPeerAttachReplayWindowFull = errors.New("peer attach replay window full")
|
||||
errPeerAttachExplicitAuthRequired = errors.New("peer attach explicit auth required")
|
||||
errPeerAttachChannelBindingRequired = errors.New("peer attach channel binding required")
|
||||
errPeerAttachChannelBindingUnavailable = errors.New("peer attach channel binding unavailable")
|
||||
errPeerAttachForwardSecrecyRequired = errors.New("peer attach forward secrecy required")
|
||||
)
|
||||
|
||||
type peerAttachAuthResult struct {
|
||||
explicit bool
|
||||
fallback bool
|
||||
clientNonce []byte
|
||||
serverNonce []byte
|
||||
channelBinding []byte
|
||||
clientECDHEPublicKey []byte
|
||||
}
|
||||
|
||||
type peerAttachReplayCache struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]time.Time
|
||||
rejects atomic.Int64
|
||||
overflowRejects atomic.Int64
|
||||
}
|
||||
|
||||
func newPeerAttachNonce() ([]byte, error) {
|
||||
buf := make([]byte, peerAttachNonceSize)
|
||||
if _, err := cryptorand.Read(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func appendPeerAttachAuthBytes(dst []byte, data []byte) []byte {
|
||||
dst = binary.BigEndian.AppendUint32(dst, uint32(len(data)))
|
||||
return append(dst, data...)
|
||||
}
|
||||
|
||||
func appendPeerAttachAuthString(dst []byte, value string) []byte {
|
||||
return appendPeerAttachAuthBytes(dst, []byte(value))
|
||||
}
|
||||
|
||||
func appendPeerAttachAuthBool(dst []byte, value bool) []byte {
|
||||
if value {
|
||||
return append(dst, 1)
|
||||
}
|
||||
return append(dst, 0)
|
||||
}
|
||||
|
||||
func peerAttachRequestAuthPayload(req peerAttachRequest, channelBinding []byte) []byte {
|
||||
buf := make([]byte, 0, 96+len(req.PeerID)+len(channelBinding))
|
||||
buf = appendPeerAttachAuthString(buf, "notify/peer-attach/request-auth/v1")
|
||||
buf = binary.BigEndian.AppendUint64(buf, req.Features)
|
||||
buf = appendPeerAttachAuthString(buf, req.PeerID)
|
||||
buf = appendPeerAttachAuthBytes(buf, req.ClientNonce)
|
||||
if supportsPeerAttachChannelBinding(req.Features) {
|
||||
buf = appendPeerAttachAuthBytes(buf, channelBinding)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
func peerAttachResponseAuthPayload(req peerAttachRequest, resp peerAttachResponse, channelBinding []byte) []byte {
|
||||
buf := make([]byte, 0, 160+len(req.PeerID)+len(resp.PeerID)+len(resp.Error)+len(channelBinding))
|
||||
buf = appendPeerAttachAuthString(buf, "notify/peer-attach/response-auth/v1")
|
||||
buf = binary.BigEndian.AppendUint64(buf, req.Features)
|
||||
buf = appendPeerAttachAuthString(buf, req.PeerID)
|
||||
buf = appendPeerAttachAuthBytes(buf, req.ClientNonce)
|
||||
buf = binary.BigEndian.AppendUint64(buf, resp.Features)
|
||||
buf = appendPeerAttachAuthString(buf, resp.PeerID)
|
||||
buf = appendPeerAttachAuthBool(buf, resp.Accepted)
|
||||
buf = appendPeerAttachAuthBool(buf, resp.Reused)
|
||||
buf = appendPeerAttachAuthString(buf, resp.Error)
|
||||
buf = appendPeerAttachAuthBytes(buf, resp.ServerNonce)
|
||||
if supportsPeerAttachChannelBinding(resp.Features) {
|
||||
buf = appendPeerAttachAuthBytes(buf, channelBinding)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
func signPeerAttachPayload(secretKey []byte, payload []byte) []byte {
|
||||
if len(secretKey) == 0 {
|
||||
return nil
|
||||
}
|
||||
mac := hmac.New(sha256.New, secretKey)
|
||||
_, _ = mac.Write(payload)
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func computePeerAttachRequestAuthTag(secretKey []byte, req peerAttachRequest, channelBinding []byte) []byte {
|
||||
return signPeerAttachPayload(secretKey, peerAttachRequestAuthPayload(req, channelBinding))
|
||||
}
|
||||
|
||||
func computePeerAttachResponseAuthTag(secretKey []byte, req peerAttachRequest, resp peerAttachResponse, channelBinding []byte) []byte {
|
||||
return signPeerAttachPayload(secretKey, peerAttachResponseAuthPayload(req, resp, channelBinding))
|
||||
}
|
||||
|
||||
func supportsExplicitPeerAttachAuth(features uint64) bool {
|
||||
return features&peerAttachFeatureExplicitAuth != 0
|
||||
}
|
||||
|
||||
func supportsPeerAttachChannelBinding(features uint64) bool {
|
||||
return features&peerAttachFeatureChannelBinding != 0
|
||||
}
|
||||
|
||||
func supportsPeerAttachForwardSecrecy(features uint64) bool {
|
||||
return features&peerAttachFeatureForwardSecrecy != 0
|
||||
}
|
||||
|
||||
func classifyPeerAttachRejectCounter(s *ServerCommon, err error) {
|
||||
if s == nil || err == nil {
|
||||
return
|
||||
}
|
||||
switch {
|
||||
case errors.Is(err, errPeerAttachReplayRejected):
|
||||
s.peerAttachReplay.rejects.Add(1)
|
||||
case errors.Is(err, errPeerAttachReplayWindowFull):
|
||||
s.peerAttachReplay.overflowRejects.Add(1)
|
||||
case errors.Is(err, errPeerAttachExplicitAuthRequired), errors.Is(err, errPeerAttachChannelBindingRequired), errors.Is(err, errPeerAttachForwardSecrecyRequired):
|
||||
s.peerAttachDowngradeRejectCount.Add(1)
|
||||
case errors.Is(err, errPeerAttachChannelBindingUnavailable):
|
||||
s.peerAttachBindingRejectCount.Add(1)
|
||||
default:
|
||||
s.peerAttachAuthRejectCount.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) shouldUseExplicitPeerAttachAuth() bool {
|
||||
if c == nil || !c.securityConfigured {
|
||||
return false
|
||||
}
|
||||
return c.securityAuthMode == AuthPSK && len(c.securityBootstrap.secretKey) != 0
|
||||
}
|
||||
|
||||
func resolvePeerAttachChannelBinding(provider PeerAttachChannelBindingProvider, role PeerAttachChannelBindingRole, peerID string, conn net.Conn) ([]byte, error) {
|
||||
if provider == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if conn == nil {
|
||||
return nil, errPeerAttachChannelBindingUnavailable
|
||||
}
|
||||
binding, err := provider(PeerAttachChannelBindingContext{
|
||||
Role: role,
|
||||
PeerID: peerID,
|
||||
Conn: conn,
|
||||
})
|
||||
if err != nil || len(binding) == 0 {
|
||||
return nil, errPeerAttachChannelBindingUnavailable
|
||||
}
|
||||
return bytes.Clone(binding), nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) buildPeerAttachRequest(peerID string) (peerAttachRequest, peerAttachRequestState, error) {
|
||||
cfg := c.peerAttachSecuritySnapshot()
|
||||
req := peerAttachRequest{
|
||||
PeerID: stringsTrimSpaceNoAlloc(peerID),
|
||||
}
|
||||
if !c.shouldUseExplicitPeerAttachAuth() {
|
||||
if c.clientRequiresForwardSecrecy() {
|
||||
return peerAttachRequest{}, peerAttachRequestState{}, errPeerAttachForwardSecrecyRequired
|
||||
}
|
||||
if cfg.requireExplicitAuth {
|
||||
return peerAttachRequest{}, peerAttachRequestState{}, errPeerAttachExplicitAuthRequired
|
||||
}
|
||||
return req, peerAttachRequestState{}, nil
|
||||
}
|
||||
nonce, err := newPeerAttachNonce()
|
||||
if err != nil {
|
||||
return peerAttachRequest{}, peerAttachRequestState{}, err
|
||||
}
|
||||
req.Features = peerAttachFeatureExplicitAuth
|
||||
requestState := peerAttachRequestState{}
|
||||
var channelBinding []byte
|
||||
if cfg.channelBinding != nil {
|
||||
channelBinding, err = resolvePeerAttachChannelBinding(cfg.channelBinding, PeerAttachChannelBindingRoleClient, req.PeerID, c.clientTransportConnSnapshot())
|
||||
if err != nil {
|
||||
return peerAttachRequest{}, peerAttachRequestState{}, err
|
||||
}
|
||||
}
|
||||
if len(channelBinding) != 0 {
|
||||
req.Features |= peerAttachFeatureChannelBinding
|
||||
}
|
||||
if cfg.requireChannelBinding && !supportsPeerAttachChannelBinding(req.Features) {
|
||||
return peerAttachRequest{}, peerAttachRequestState{}, errPeerAttachChannelBindingRequired
|
||||
}
|
||||
if c.clientSupportsForwardSecrecy() {
|
||||
requestState.forwardSecrecy, err = newPeerAttachForwardSecrecyClientState()
|
||||
if err != nil {
|
||||
return peerAttachRequest{}, peerAttachRequestState{}, err
|
||||
}
|
||||
req.Features |= peerAttachFeatureForwardSecrecy
|
||||
req.ClientECDHEPublicKey = bytes.Clone(requestState.forwardSecrecy.publicKey)
|
||||
}
|
||||
req.ClientNonce = nonce
|
||||
req.AuthTag = computePeerAttachRequestAuthTag(c.securityBootstrap.secretKey, req, channelBinding)
|
||||
return req, requestState, nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) verifyPeerAttachResponse(req peerAttachRequest, resp peerAttachResponse, requestState peerAttachRequestState) (peerAttachResponseVerifyResult, error) {
|
||||
cfg := c.peerAttachSecuritySnapshot()
|
||||
baseSteady := transportProtectionProfile{}
|
||||
if c != nil {
|
||||
baseSteady = c.securitySteady.clone().withForwardSecrecyFallback(false)
|
||||
}
|
||||
result := peerAttachResponseVerifyResult{steadyProfile: baseSteady}
|
||||
if c == nil || !c.shouldUseExplicitPeerAttachAuth() {
|
||||
if c.clientRequiresForwardSecrecy() {
|
||||
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyRequired
|
||||
}
|
||||
if cfg.requireExplicitAuth {
|
||||
return peerAttachResponseVerifyResult{}, errPeerAttachExplicitAuthRequired
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
if !supportsExplicitPeerAttachAuth(resp.Features) {
|
||||
if c.clientRequiresForwardSecrecy() {
|
||||
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyRequired
|
||||
}
|
||||
if cfg.requireExplicitAuth {
|
||||
return peerAttachResponseVerifyResult{}, errPeerAttachExplicitAuthRequired
|
||||
}
|
||||
if c.clientSupportsForwardSecrecy() {
|
||||
result.steadyProfile = result.steadyProfile.withForwardSecrecyFallback(true)
|
||||
}
|
||||
result.authFallback = true
|
||||
return result, nil
|
||||
}
|
||||
var channelBinding []byte
|
||||
if supportsPeerAttachChannelBinding(req.Features) {
|
||||
if cfg.channelBinding == nil {
|
||||
return peerAttachResponseVerifyResult{}, errPeerAttachChannelBindingUnavailable
|
||||
}
|
||||
if !supportsPeerAttachChannelBinding(resp.Features) {
|
||||
return peerAttachResponseVerifyResult{}, errPeerAttachChannelBindingRequired
|
||||
}
|
||||
var err error
|
||||
channelBinding, err = resolvePeerAttachChannelBinding(cfg.channelBinding, PeerAttachChannelBindingRoleClient, req.PeerID, c.clientTransportConnSnapshot())
|
||||
if err != nil {
|
||||
return peerAttachResponseVerifyResult{}, err
|
||||
}
|
||||
} else if cfg.requireChannelBinding {
|
||||
return peerAttachResponseVerifyResult{}, errPeerAttachChannelBindingRequired
|
||||
}
|
||||
if len(resp.ServerNonce) != peerAttachNonceSize {
|
||||
return peerAttachResponseVerifyResult{}, errPeerAttachAuthInvalid
|
||||
}
|
||||
expected := computePeerAttachResponseAuthTag(c.securityBootstrap.secretKey, req, peerAttachResponse{
|
||||
PeerID: resp.PeerID,
|
||||
Accepted: resp.Accepted,
|
||||
Reused: resp.Reused,
|
||||
Error: resp.Error,
|
||||
Features: resp.Features,
|
||||
ServerNonce: resp.ServerNonce,
|
||||
}, channelBinding)
|
||||
if !hmac.Equal(resp.AuthTag, expected) {
|
||||
return peerAttachResponseVerifyResult{}, errPeerAttachAuthInvalid
|
||||
}
|
||||
if requestState.forwardSecrecy == nil || !supportsPeerAttachForwardSecrecy(req.Features) {
|
||||
return result, nil
|
||||
}
|
||||
if !supportsPeerAttachForwardSecrecy(resp.Features) {
|
||||
if c.clientRequiresForwardSecrecy() {
|
||||
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyRequired
|
||||
}
|
||||
result.steadyProfile = result.steadyProfile.withForwardSecrecyFallback(true)
|
||||
return result, nil
|
||||
}
|
||||
if resp.KeyMode != "" && resp.KeyMode != peerAttachKeyModeECDHE {
|
||||
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyInvalid
|
||||
}
|
||||
profile, err := derivePeerAttachForwardSecrecyTransportProfile(c.securitySteady, c.securityBootstrap.secretKey, requestState.forwardSecrecy.privateKey, resp.ServerECDHEPublicKey, req, resp)
|
||||
if err != nil {
|
||||
return peerAttachResponseVerifyResult{}, err
|
||||
}
|
||||
result.steadyProfile = profile
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *ServerCommon) validatePeerAttachRequestAuth(logical *LogicalConn, transport net.Conn, req peerAttachRequest) (peerAttachAuthResult, error) {
|
||||
cfg := s.peerAttachSecuritySnapshot()
|
||||
if !supportsExplicitPeerAttachAuth(req.Features) {
|
||||
if s.serverRequiresForwardSecrecy() {
|
||||
return peerAttachAuthResult{}, errPeerAttachForwardSecrecyRequired
|
||||
}
|
||||
if cfg.requireExplicitAuth {
|
||||
return peerAttachAuthResult{}, errPeerAttachExplicitAuthRequired
|
||||
}
|
||||
return peerAttachAuthResult{fallback: true}, nil
|
||||
}
|
||||
if logical == nil {
|
||||
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
|
||||
}
|
||||
var channelBinding []byte
|
||||
if supportsPeerAttachChannelBinding(req.Features) {
|
||||
if cfg.channelBinding == nil {
|
||||
return peerAttachAuthResult{}, errPeerAttachChannelBindingUnavailable
|
||||
}
|
||||
var err error
|
||||
channelBinding, err = resolvePeerAttachChannelBinding(cfg.channelBinding, PeerAttachChannelBindingRoleServer, req.PeerID, transport)
|
||||
if err != nil {
|
||||
return peerAttachAuthResult{}, err
|
||||
}
|
||||
} else if cfg.requireChannelBinding {
|
||||
return peerAttachAuthResult{}, errPeerAttachChannelBindingRequired
|
||||
}
|
||||
secretKey := logical.secretKeySnapshot()
|
||||
if len(secretKey) == 0 {
|
||||
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
|
||||
}
|
||||
if len(req.ClientNonce) != peerAttachNonceSize {
|
||||
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
|
||||
}
|
||||
expected := computePeerAttachRequestAuthTag(secretKey, peerAttachRequest{
|
||||
PeerID: req.PeerID,
|
||||
Features: req.Features,
|
||||
ClientNonce: req.ClientNonce,
|
||||
}, channelBinding)
|
||||
if !hmac.Equal(req.AuthTag, expected) {
|
||||
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
|
||||
}
|
||||
if supportsPeerAttachForwardSecrecy(req.Features) {
|
||||
if len(req.ClientECDHEPublicKey) != peerAttachECDHEPublicKeySize {
|
||||
return peerAttachAuthResult{}, errPeerAttachForwardSecrecyInvalid
|
||||
}
|
||||
}
|
||||
if err := s.acceptPeerAttachReplay(req.PeerID, req.ClientNonce, time.Now(), cfg.replayWindow, cfg.replayCapacity); err != nil {
|
||||
return peerAttachAuthResult{}, err
|
||||
}
|
||||
serverNonce, err := newPeerAttachNonce()
|
||||
if err != nil {
|
||||
return peerAttachAuthResult{}, err
|
||||
}
|
||||
return peerAttachAuthResult{
|
||||
explicit: true,
|
||||
clientNonce: bytes.Clone(req.ClientNonce),
|
||||
serverNonce: serverNonce,
|
||||
channelBinding: channelBinding,
|
||||
clientECDHEPublicKey: bytes.Clone(req.ClientECDHEPublicKey),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ServerCommon) signPeerAttachResponse(logical *LogicalConn, req peerAttachRequest, resp *peerAttachResponse, auth peerAttachAuthResult) {
|
||||
if s == nil || logical == nil || resp == nil || !auth.explicit {
|
||||
return
|
||||
}
|
||||
secretKey := logical.secretKeySnapshot()
|
||||
if len(secretKey) == 0 {
|
||||
return
|
||||
}
|
||||
resp.Features |= peerAttachFeatureExplicitAuth
|
||||
if len(auth.channelBinding) != 0 {
|
||||
resp.Features |= peerAttachFeatureChannelBinding
|
||||
}
|
||||
resp.ServerNonce = bytes.Clone(auth.serverNonce)
|
||||
resp.AuthTag = computePeerAttachResponseAuthTag(secretKey, req, *resp, auth.channelBinding)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) preparePeerAttachSteadyTransportProfile(logical *LogicalConn, req peerAttachRequest, resp *peerAttachResponse, auth peerAttachAuthResult) (transportProtectionProfile, error) {
|
||||
if s == nil {
|
||||
return transportProtectionProfile{}, nil
|
||||
}
|
||||
profile := s.securitySteady.clone().withForwardSecrecyFallback(false)
|
||||
if resp != nil && auth.explicit {
|
||||
resp.Features |= peerAttachFeatureExplicitAuth
|
||||
if len(auth.channelBinding) != 0 {
|
||||
resp.Features |= peerAttachFeatureChannelBinding
|
||||
}
|
||||
resp.ServerNonce = bytes.Clone(auth.serverNonce)
|
||||
}
|
||||
if resp != nil && resp.KeyMode == "" {
|
||||
resp.KeyMode = profile.keyMode
|
||||
}
|
||||
if !s.serverSupportsForwardSecrecy() {
|
||||
return profile, nil
|
||||
}
|
||||
if !auth.explicit || !supportsPeerAttachForwardSecrecy(req.Features) {
|
||||
if s.serverRequiresForwardSecrecy() {
|
||||
return transportProtectionProfile{}, errPeerAttachForwardSecrecyRequired
|
||||
}
|
||||
return profile.withForwardSecrecyFallback(true), nil
|
||||
}
|
||||
fsState, err := newPeerAttachForwardSecrecyClientState()
|
||||
if err != nil {
|
||||
return transportProtectionProfile{}, err
|
||||
}
|
||||
if resp != nil {
|
||||
resp.Features |= peerAttachFeatureForwardSecrecy
|
||||
resp.KeyMode = peerAttachKeyModeECDHE
|
||||
resp.ServerECDHEPublicKey = bytes.Clone(fsState.publicKey)
|
||||
}
|
||||
return derivePeerAttachForwardSecrecyTransportProfile(s.securitySteady, logical.secretKeySnapshot(), fsState.privateKey, auth.clientECDHEPublicKey, req, *resp)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) acceptPeerAttachReplay(peerID string, nonce []byte, now time.Time, window time.Duration, capacity int) error {
|
||||
if s == nil || len(nonce) == 0 {
|
||||
return nil
|
||||
}
|
||||
cache := &s.peerAttachReplay
|
||||
key := peerID + "\x00" + string(nonce)
|
||||
expireBefore := now.Add(-window)
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
if cache.entries == nil {
|
||||
cache.entries = make(map[string]time.Time)
|
||||
}
|
||||
for replayKey, seenAt := range cache.entries {
|
||||
if seenAt.Before(expireBefore) {
|
||||
delete(cache.entries, replayKey)
|
||||
}
|
||||
}
|
||||
if seenAt, ok := cache.entries[key]; ok && !seenAt.Before(expireBefore) {
|
||||
return errPeerAttachReplayRejected
|
||||
}
|
||||
if capacity > 0 && len(cache.entries) >= capacity {
|
||||
return errPeerAttachReplayWindowFull
|
||||
}
|
||||
cache.entries[key] = now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerCommon) peerAttachReplayRejectCountSnapshot() int64 {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
return s.peerAttachReplay.rejects.Load()
|
||||
}
|
||||
|
||||
func (s *ServerCommon) peerAttachReplayOverflowRejectCountSnapshot() int64 {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
return s.peerAttachReplay.overflowRejects.Load()
|
||||
}
|
||||
|
||||
func (c *ClientCommon) markClientPeerAttachAuthenticated(fallback bool, at time.Time) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.peerAttachAuthenticated = true
|
||||
c.peerAttachAuthFallback = fallback
|
||||
c.peerAttachAt = at.UnixNano()
|
||||
}
|
||||
|
||||
func (c *ClientCommon) resetClientPeerAttachAuth() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.peerAttachAuthenticated = false
|
||||
c.peerAttachAuthFallback = false
|
||||
c.peerAttachAt = 0
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientPeerAttachAuthSnapshot() (bool, bool, time.Time) {
|
||||
if c == nil {
|
||||
return false, false, time.Time{}
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.peerAttachAt == 0 {
|
||||
return c.peerAttachAuthenticated, c.peerAttachAuthFallback, time.Time{}
|
||||
}
|
||||
return c.peerAttachAuthenticated, c.peerAttachAuthFallback, time.Unix(0, c.peerAttachAt)
|
||||
}
|
||||
|
||||
func stringsTrimSpaceNoAlloc(value string) string {
|
||||
start := 0
|
||||
for start < len(value) {
|
||||
switch value[start] {
|
||||
case ' ', '\t', '\n', '\r':
|
||||
start++
|
||||
default:
|
||||
goto endStart
|
||||
}
|
||||
}
|
||||
return ""
|
||||
endStart:
|
||||
end := len(value)
|
||||
for end > start {
|
||||
switch value[end-1] {
|
||||
case ' ', '\t', '\n', '\r':
|
||||
end--
|
||||
default:
|
||||
return value[start:end]
|
||||
}
|
||||
}
|
||||
return value[start:end]
|
||||
}
|
||||
237
peer_attach_auth_test.go
Normal file
237
peer_attach_auth_test.go
Normal file
@ -0,0 +1,237 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newPeerAttachAuthLogicalForTest(t *testing.T, server *ServerCommon) *LogicalConn {
|
||||
t.Helper()
|
||||
logical := newServerLogicalConn(server, "accepted-auth", nil)
|
||||
logical = server.registerAcceptedLogical(logical)
|
||||
if logical == nil {
|
||||
t.Fatal("registerAcceptedLogical returned nil")
|
||||
}
|
||||
return logical
|
||||
}
|
||||
|
||||
func TestPeerAttachExplicitAuthHelpersRoundTrip(t *testing.T) {
|
||||
secret := []byte("correct horse battery staple")
|
||||
client := NewClient().(*ClientCommon)
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
|
||||
logical := newPeerAttachAuthLogicalForTest(t, server)
|
||||
req, reqState, err := client.buildPeerAttachRequest("peer-explicit")
|
||||
if err != nil {
|
||||
t.Fatalf("buildPeerAttachRequest failed: %v", err)
|
||||
}
|
||||
if !supportsExplicitPeerAttachAuth(req.Features) {
|
||||
t.Fatalf("request features = %d, want explicit auth bit", req.Features)
|
||||
}
|
||||
if !supportsPeerAttachForwardSecrecy(req.Features) {
|
||||
t.Fatalf("request features = %d, want forward secrecy bit", req.Features)
|
||||
}
|
||||
if len(req.ClientNonce) != peerAttachNonceSize {
|
||||
t.Fatalf("client nonce length = %d, want %d", len(req.ClientNonce), peerAttachNonceSize)
|
||||
}
|
||||
|
||||
auth, err := server.validatePeerAttachRequestAuth(logical, nil, req)
|
||||
if err != nil {
|
||||
t.Fatalf("validatePeerAttachRequestAuth failed: %v", err)
|
||||
}
|
||||
if !auth.explicit || auth.fallback {
|
||||
t.Fatalf("auth result mismatch: %+v", auth)
|
||||
}
|
||||
|
||||
resp := peerAttachResponse{
|
||||
PeerID: req.PeerID,
|
||||
Accepted: true,
|
||||
}
|
||||
server.signPeerAttachResponse(logical, req, &resp, auth)
|
||||
if !supportsExplicitPeerAttachAuth(resp.Features) {
|
||||
t.Fatalf("response features = %d, want explicit auth bit", resp.Features)
|
||||
}
|
||||
if len(resp.ServerNonce) != peerAttachNonceSize {
|
||||
t.Fatalf("server nonce length = %d, want %d", len(resp.ServerNonce), peerAttachNonceSize)
|
||||
}
|
||||
|
||||
verifyResult, err := client.verifyPeerAttachResponse(req, resp, reqState)
|
||||
if err != nil {
|
||||
t.Fatalf("verifyPeerAttachResponse failed: %v", err)
|
||||
}
|
||||
if verifyResult.authFallback {
|
||||
t.Fatal("explicit response should not be marked as fallback")
|
||||
}
|
||||
if !verifyResult.steadyProfile.forwardSecrecyFallback {
|
||||
t.Fatal("response without fs extension should mark forward secrecy fallback")
|
||||
}
|
||||
|
||||
resp.AuthTag[0] ^= 0xff
|
||||
if verifyResult, err = client.verifyPeerAttachResponse(req, resp, reqState); !errors.Is(err, errPeerAttachAuthInvalid) {
|
||||
t.Fatalf("tampered response error = %v, want %v", err, errPeerAttachAuthInvalid)
|
||||
} else if verifyResult.authFallback {
|
||||
t.Fatal("tampered explicit response should not be treated as fallback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerAttachRequestAuthRejectsReplay(t *testing.T) {
|
||||
secret := []byte("correct horse battery staple")
|
||||
client := NewClient().(*ClientCommon)
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
|
||||
logical := newPeerAttachAuthLogicalForTest(t, server)
|
||||
req, _, err := client.buildPeerAttachRequest("peer-replay")
|
||||
if err != nil {
|
||||
t.Fatalf("buildPeerAttachRequest failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := server.validatePeerAttachRequestAuth(logical, nil, req); err != nil {
|
||||
t.Fatalf("first validatePeerAttachRequestAuth failed: %v", err)
|
||||
}
|
||||
if _, err := server.validatePeerAttachRequestAuth(logical, nil, req); !errors.Is(err, errPeerAttachReplayRejected) {
|
||||
t.Fatalf("second validatePeerAttachRequestAuth error = %v, want %v", err, errPeerAttachReplayRejected)
|
||||
}
|
||||
classifyPeerAttachRejectCounter(server, errPeerAttachReplayRejected)
|
||||
if got, want := server.peerAttachReplayRejectCountSnapshot(), int64(1); got != want {
|
||||
t.Fatalf("replay reject count = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerAttachAuthFallbackCompatibility(t *testing.T) {
|
||||
secret := []byte("correct horse battery staple")
|
||||
client := NewClient().(*ClientCommon)
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
|
||||
logical := newPeerAttachAuthLogicalForTest(t, server)
|
||||
auth, err := server.validatePeerAttachRequestAuth(logical, nil, peerAttachRequest{PeerID: "peer-fallback"})
|
||||
if err != nil {
|
||||
t.Fatalf("validatePeerAttachRequestAuth fallback failed: %v", err)
|
||||
}
|
||||
if auth.explicit || !auth.fallback {
|
||||
t.Fatalf("fallback auth result mismatch: %+v", auth)
|
||||
}
|
||||
|
||||
req, reqState, err := client.buildPeerAttachRequest("peer-fallback")
|
||||
if err != nil {
|
||||
t.Fatalf("buildPeerAttachRequest failed: %v", err)
|
||||
}
|
||||
verifyResult, err := client.verifyPeerAttachResponse(req, peerAttachResponse{
|
||||
PeerID: req.PeerID,
|
||||
Accepted: true,
|
||||
}, reqState)
|
||||
if err != nil {
|
||||
t.Fatalf("verifyPeerAttachResponse fallback failed: %v", err)
|
||||
}
|
||||
if !verifyResult.authFallback {
|
||||
t.Fatal("unsigned legacy response should be marked as fallback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerAttachForwardSecrecyNegotiatesDerivedSteadyProfile(t *testing.T) {
|
||||
secret := []byte("correct horse battery staple")
|
||||
client := NewClient().(*ClientCommon)
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
|
||||
logical := newPeerAttachAuthLogicalForTest(t, server)
|
||||
req, reqState, err := client.buildPeerAttachRequest("peer-fs")
|
||||
if err != nil {
|
||||
t.Fatalf("buildPeerAttachRequest failed: %v", err)
|
||||
}
|
||||
auth, err := server.validatePeerAttachRequestAuth(logical, nil, req)
|
||||
if err != nil {
|
||||
t.Fatalf("validatePeerAttachRequestAuth failed: %v", err)
|
||||
}
|
||||
|
||||
resp := peerAttachResponse{
|
||||
PeerID: req.PeerID,
|
||||
Accepted: true,
|
||||
}
|
||||
serverProfile, err := server.preparePeerAttachSteadyTransportProfile(logical, req, &resp, auth)
|
||||
if err != nil {
|
||||
t.Fatalf("preparePeerAttachSteadyTransportProfile failed: %v", err)
|
||||
}
|
||||
server.signPeerAttachResponse(logical, req, &resp, auth)
|
||||
verifyResult, err := client.verifyPeerAttachResponse(req, resp, reqState)
|
||||
if err != nil {
|
||||
t.Fatalf("verifyPeerAttachResponse failed: %v", err)
|
||||
}
|
||||
|
||||
if !supportsPeerAttachForwardSecrecy(resp.Features) {
|
||||
t.Fatalf("response features = %d, want forward secrecy bit", resp.Features)
|
||||
}
|
||||
if resp.KeyMode != peerAttachKeyModeECDHE {
|
||||
t.Fatalf("response key mode = %q, want %q", resp.KeyMode, peerAttachKeyModeECDHE)
|
||||
}
|
||||
if !verifyResult.steadyProfile.forwardSecrecy {
|
||||
t.Fatal("client steady profile should enable forward secrecy")
|
||||
}
|
||||
if !serverProfile.forwardSecrecy {
|
||||
t.Fatal("server steady profile should enable forward secrecy")
|
||||
}
|
||||
if len(verifyResult.steadyProfile.sessionID) == 0 {
|
||||
t.Fatal("client session id should be populated")
|
||||
}
|
||||
if !bytes.Equal(verifyResult.steadyProfile.secretKey, serverProfile.secretKey) {
|
||||
t.Fatal("client/server derived steady keys should match")
|
||||
}
|
||||
if !bytes.Equal(verifyResult.steadyProfile.sessionID, serverProfile.sessionID) {
|
||||
t.Fatal("client/server session ids should match")
|
||||
}
|
||||
if bytes.Equal(verifyResult.steadyProfile.secretKey, client.securityBootstrap.secretKey) {
|
||||
t.Fatal("derived steady key should differ from bootstrap key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerAttachForwardSecrecyStrictRejectsFallback(t *testing.T) {
|
||||
secret := []byte("correct horse battery staple")
|
||||
client := NewClient().(*ClientCommon)
|
||||
server := NewServer().(*ServerCommon)
|
||||
opts := testModernPSKOptions()
|
||||
opts.RequireForwardSecrecy = true
|
||||
if err := UseModernPSKClient(client, secret, opts); err != nil {
|
||||
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
if err := UseModernPSKServer(server, secret, opts); err != nil {
|
||||
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
|
||||
req, reqState, err := client.buildPeerAttachRequest("peer-fs-strict")
|
||||
if err != nil {
|
||||
t.Fatalf("buildPeerAttachRequest failed: %v", err)
|
||||
}
|
||||
_, err = client.verifyPeerAttachResponse(req, peerAttachResponse{
|
||||
PeerID: req.PeerID,
|
||||
Accepted: true,
|
||||
Features: peerAttachFeatureExplicitAuth,
|
||||
ServerNonce: make([]byte, peerAttachNonceSize),
|
||||
AuthTag: computePeerAttachResponseAuthTag(client.securityBootstrap.secretKey, req, peerAttachResponse{PeerID: req.PeerID, Accepted: true, Features: peerAttachFeatureExplicitAuth, ServerNonce: make([]byte, peerAttachNonceSize)}, nil),
|
||||
}, reqState)
|
||||
if !errors.Is(err, errPeerAttachForwardSecrecyRequired) {
|
||||
t.Fatalf("verifyPeerAttachResponse error = %v, want %v", err, errPeerAttachForwardSecrecyRequired)
|
||||
}
|
||||
}
|
||||
139
peer_attach_policy.go
Normal file
139
peer_attach_policy.go
Normal file
@ -0,0 +1,139 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultPeerAttachReplayCapacity = 4096
|
||||
|
||||
type PeerAttachChannelBindingRole string
|
||||
|
||||
const (
|
||||
PeerAttachChannelBindingRoleClient PeerAttachChannelBindingRole = "client"
|
||||
PeerAttachChannelBindingRoleServer PeerAttachChannelBindingRole = "server"
|
||||
)
|
||||
|
||||
type PeerAttachChannelBindingContext struct {
|
||||
Role PeerAttachChannelBindingRole
|
||||
PeerID string
|
||||
Conn net.Conn
|
||||
}
|
||||
|
||||
type PeerAttachChannelBindingProvider func(PeerAttachChannelBindingContext) ([]byte, error)
|
||||
|
||||
type PeerAttachSecurityConfig struct {
|
||||
RequireExplicitAuth bool
|
||||
RequireChannelBinding bool
|
||||
ReplayWindow time.Duration
|
||||
ReplayCapacity int
|
||||
ChannelBinding PeerAttachChannelBindingProvider
|
||||
}
|
||||
|
||||
type peerAttachSecurityState struct {
|
||||
requireExplicitAuth bool
|
||||
requireChannelBinding bool
|
||||
replayWindow time.Duration
|
||||
replayCapacity int
|
||||
channelBinding PeerAttachChannelBindingProvider
|
||||
}
|
||||
|
||||
var errPeerAttachChannelBindingProviderNil = errors.New("peer attach channel binding provider is nil")
|
||||
|
||||
func DefaultPeerAttachSecurityConfig() PeerAttachSecurityConfig {
|
||||
return PeerAttachSecurityConfig{
|
||||
ReplayWindow: peerAttachReplayTTL,
|
||||
ReplayCapacity: defaultPeerAttachReplayCapacity,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizePeerAttachSecurityConfig(cfg PeerAttachSecurityConfig) (peerAttachSecurityState, error) {
|
||||
if cfg.ReplayWindow <= 0 {
|
||||
cfg.ReplayWindow = peerAttachReplayTTL
|
||||
}
|
||||
if cfg.ReplayCapacity <= 0 {
|
||||
cfg.ReplayCapacity = defaultPeerAttachReplayCapacity
|
||||
}
|
||||
if cfg.RequireChannelBinding {
|
||||
cfg.RequireExplicitAuth = true
|
||||
if cfg.ChannelBinding == nil {
|
||||
return peerAttachSecurityState{}, errPeerAttachChannelBindingProviderNil
|
||||
}
|
||||
}
|
||||
return peerAttachSecurityState{
|
||||
requireExplicitAuth: cfg.RequireExplicitAuth,
|
||||
requireChannelBinding: cfg.RequireChannelBinding,
|
||||
replayWindow: cfg.ReplayWindow,
|
||||
replayCapacity: cfg.ReplayCapacity,
|
||||
channelBinding: cfg.ChannelBinding,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func peerAttachSecurityConfigFromState(state *peerAttachSecurityState) PeerAttachSecurityConfig {
|
||||
if state == nil {
|
||||
return DefaultPeerAttachSecurityConfig()
|
||||
}
|
||||
return PeerAttachSecurityConfig{
|
||||
RequireExplicitAuth: state.requireExplicitAuth,
|
||||
RequireChannelBinding: state.requireChannelBinding,
|
||||
ReplayWindow: state.replayWindow,
|
||||
ReplayCapacity: state.replayCapacity,
|
||||
ChannelBinding: state.channelBinding,
|
||||
}
|
||||
}
|
||||
|
||||
func defaultPeerAttachSecurityState() *peerAttachSecurityState {
|
||||
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
|
||||
return &cfg
|
||||
}
|
||||
|
||||
func (c *ClientCommon) peerAttachSecuritySnapshot() peerAttachSecurityState {
|
||||
if c == nil {
|
||||
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
|
||||
return cfg
|
||||
}
|
||||
if state := c.peerAttachSecurity.Load(); state != nil {
|
||||
return *state
|
||||
}
|
||||
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (s *ServerCommon) peerAttachSecuritySnapshot() peerAttachSecurityState {
|
||||
if s == nil {
|
||||
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
|
||||
return cfg
|
||||
}
|
||||
if state := s.peerAttachSecurity.Load(); state != nil {
|
||||
return *state
|
||||
}
|
||||
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SetPeerAttachSecurityConfig(cfg PeerAttachSecurityConfig) error {
|
||||
state, err := normalizePeerAttachSecurityConfig(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.peerAttachSecurity.Store(&state)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) PeerAttachSecurityConfig() PeerAttachSecurityConfig {
|
||||
return peerAttachSecurityConfigFromState(c.peerAttachSecurity.Load())
|
||||
}
|
||||
|
||||
func (s *ServerCommon) SetPeerAttachSecurityConfig(cfg PeerAttachSecurityConfig) error {
|
||||
state, err := normalizePeerAttachSecurityConfig(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.peerAttachSecurity.Store(&state)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerCommon) PeerAttachSecurityConfig() PeerAttachSecurityConfig {
|
||||
return peerAttachSecurityConfigFromState(s.peerAttachSecurity.Load())
|
||||
}
|
||||
221
peer_attach_policy_test.go
Normal file
221
peer_attach_policy_test.go
Normal file
@ -0,0 +1,221 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func staticPeerAttachChannelBindingProvider(material []byte) PeerAttachChannelBindingProvider {
|
||||
cloned := bytes.Clone(material)
|
||||
return func(PeerAttachChannelBindingContext) ([]byte, error) {
|
||||
return bytes.Clone(cloned), nil
|
||||
}
|
||||
}
|
||||
|
||||
func failingPeerAttachChannelBindingProvider(PeerAttachChannelBindingContext) ([]byte, error) {
|
||||
return nil, errors.New("binding unavailable")
|
||||
}
|
||||
|
||||
func TestSetPeerAttachSecurityConfigRejectsMissingChannelBindingProvider(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
server := NewServer().(*ServerCommon)
|
||||
cfg := PeerAttachSecurityConfig{RequireChannelBinding: true}
|
||||
|
||||
if err := client.SetPeerAttachSecurityConfig(cfg); !errors.Is(err, errPeerAttachChannelBindingProviderNil) {
|
||||
t.Fatalf("client SetPeerAttachSecurityConfig error = %v, want %v", err, errPeerAttachChannelBindingProviderNil)
|
||||
}
|
||||
if err := server.SetPeerAttachSecurityConfig(cfg); !errors.Is(err, errPeerAttachChannelBindingProviderNil) {
|
||||
t.Fatalf("server SetPeerAttachSecurityConfig error = %v, want %v", err, errPeerAttachChannelBindingProviderNil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerAttachRequireExplicitAuthRejectsFallbackClient(t *testing.T) {
|
||||
secret := []byte("correct horse battery staple")
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
|
||||
RequireExplicitAuth: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("SetPeerAttachSecurityConfig failed: %v", err)
|
||||
}
|
||||
|
||||
logical := newPeerAttachAuthLogicalForTest(t, server)
|
||||
if _, err := server.validatePeerAttachRequestAuth(logical, nil, peerAttachRequest{PeerID: "peer-fallback"}); !errors.Is(err, errPeerAttachExplicitAuthRequired) {
|
||||
t.Fatalf("validatePeerAttachRequestAuth error = %v, want %v", err, errPeerAttachExplicitAuthRequired)
|
||||
}
|
||||
classifyPeerAttachRejectCounter(server, errPeerAttachExplicitAuthRequired)
|
||||
|
||||
snapshot, snapErr := GetServerRuntimeSnapshot(server)
|
||||
if snapErr != nil {
|
||||
t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr)
|
||||
}
|
||||
if got, want := snapshot.PeerAttachDowngradeRejects, int64(1); got != want {
|
||||
t.Fatalf("PeerAttachDowngradeRejects = %d, want %d", got, want)
|
||||
}
|
||||
if got := snapshot.PeerAttachAuthFallbacks; got != 0 {
|
||||
t.Fatalf("PeerAttachAuthFallbacks = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerAttachChannelBindingRoundTrip(t *testing.T) {
|
||||
secret := []byte("correct horse battery staple")
|
||||
bindingProvider := staticPeerAttachChannelBindingProvider([]byte("tls-exporter:test"))
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
if err := UsePSKOverExternalTransportServer(server, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
|
||||
}
|
||||
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
|
||||
RequireChannelBinding: true,
|
||||
ChannelBinding: bindingProvider,
|
||||
}); err != nil {
|
||||
t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err)
|
||||
}
|
||||
server.SetLink("echo", func(msg *Message) {
|
||||
_ = msg.Reply([]byte("ack"))
|
||||
})
|
||||
})
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UsePSKOverExternalTransportClient(client, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
|
||||
}
|
||||
if err := client.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
|
||||
RequireChannelBinding: true,
|
||||
ChannelBinding: bindingProvider,
|
||||
}); err != nil {
|
||||
t.Fatalf("client SetPeerAttachSecurityConfig failed: %v", err)
|
||||
}
|
||||
|
||||
left, right := net.Pipe()
|
||||
defer right.Close()
|
||||
bootstrapPeerAttachLogicalForTest(t, server, right)
|
||||
if err := client.ConnectByConn(left); err != nil {
|
||||
t.Fatalf("ConnectByConn failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
client.setByeFromServer(true)
|
||||
_ = client.Stop()
|
||||
}()
|
||||
|
||||
reply, err := client.SendWait("echo", []byte("ping"), time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("SendWait failed: %v", err)
|
||||
}
|
||||
if got, want := string(reply.Value), "ack"; got != want {
|
||||
t.Fatalf("reply = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
serverSnapshot, err := GetServerRuntimeSnapshot(server)
|
||||
if err != nil {
|
||||
t.Fatalf("GetServerRuntimeSnapshot failed: %v", err)
|
||||
}
|
||||
if !serverSnapshot.PeerAttachRequireExplicitAuth || !serverSnapshot.PeerAttachRequireChannelBinding || !serverSnapshot.PeerAttachChannelBindingConfigured {
|
||||
t.Fatalf("unexpected server peer attach policy snapshot: %+v", serverSnapshot)
|
||||
}
|
||||
if got, want := serverSnapshot.PeerAttachExplicitAuth, int64(1); got != want {
|
||||
t.Fatalf("PeerAttachExplicitAuth = %d, want %d", got, want)
|
||||
}
|
||||
if serverSnapshot.PeerAttachAuthRejects != 0 || serverSnapshot.PeerAttachDowngradeRejects != 0 || serverSnapshot.PeerAttachBindingRejects != 0 {
|
||||
t.Fatalf("unexpected server reject counters: %+v", serverSnapshot)
|
||||
}
|
||||
|
||||
clientSnapshot, err := GetClientRuntimeSnapshot(client)
|
||||
if err != nil {
|
||||
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
|
||||
}
|
||||
if !clientSnapshot.PeerAttachRequireExplicitAuth || !clientSnapshot.PeerAttachRequireChannelBinding || !clientSnapshot.PeerAttachChannelBindingConfigured {
|
||||
t.Fatalf("unexpected client peer attach policy snapshot: %+v", clientSnapshot)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerAttachChannelBindingProviderFailureRejectsAttach(t *testing.T) {
|
||||
secret := []byte("correct horse battery staple")
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
|
||||
RequireChannelBinding: true,
|
||||
ChannelBinding: failingPeerAttachChannelBindingProvider,
|
||||
}); err != nil {
|
||||
t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err)
|
||||
}
|
||||
})
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
if err := client.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
|
||||
RequireChannelBinding: true,
|
||||
ChannelBinding: staticPeerAttachChannelBindingProvider([]byte("binding")),
|
||||
}); err != nil {
|
||||
t.Fatalf("client SetPeerAttachSecurityConfig failed: %v", err)
|
||||
}
|
||||
|
||||
left, right := net.Pipe()
|
||||
defer right.Close()
|
||||
bootstrapPeerAttachLogicalForTest(t, server, right)
|
||||
|
||||
err := client.ConnectByConn(left)
|
||||
if !errors.Is(err, errPeerAttachChannelBindingUnavailable) && (err == nil || err.Error() != errPeerAttachChannelBindingUnavailable.Error()) {
|
||||
t.Fatalf("ConnectByConn error = %v, want %v", err, errPeerAttachChannelBindingUnavailable)
|
||||
}
|
||||
|
||||
snapshot, snapErr := GetServerRuntimeSnapshot(server)
|
||||
if snapErr != nil {
|
||||
t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr)
|
||||
}
|
||||
if got, want := snapshot.PeerAttachBindingRejects, int64(1); got != want {
|
||||
t.Fatalf("PeerAttachBindingRejects = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerAttachReplayCapacityRejectsOverflow(t *testing.T) {
|
||||
secret := []byte("correct horse battery staple")
|
||||
client := NewClient().(*ClientCommon)
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
|
||||
ReplayCapacity: 1,
|
||||
}); err != nil {
|
||||
t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err)
|
||||
}
|
||||
|
||||
logical := newPeerAttachAuthLogicalForTest(t, server)
|
||||
first, _, err := client.buildPeerAttachRequest("peer-one")
|
||||
if err != nil {
|
||||
t.Fatalf("buildPeerAttachRequest(first) failed: %v", err)
|
||||
}
|
||||
second, _, err := client.buildPeerAttachRequest("peer-two")
|
||||
if err != nil {
|
||||
t.Fatalf("buildPeerAttachRequest(second) failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := server.validatePeerAttachRequestAuth(logical, nil, first); err != nil {
|
||||
t.Fatalf("validatePeerAttachRequestAuth(first) failed: %v", err)
|
||||
}
|
||||
if _, err := server.validatePeerAttachRequestAuth(logical, nil, second); !errors.Is(err, errPeerAttachReplayWindowFull) {
|
||||
t.Fatalf("validatePeerAttachRequestAuth(second) error = %v, want %v", err, errPeerAttachReplayWindowFull)
|
||||
}
|
||||
classifyPeerAttachRejectCounter(server, errPeerAttachReplayWindowFull)
|
||||
|
||||
snapshot, snapErr := GetServerRuntimeSnapshot(server)
|
||||
if snapErr != nil {
|
||||
t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr)
|
||||
}
|
||||
if got, want := snapshot.PeerAttachReplayCapacity, 1; got != want {
|
||||
t.Fatalf("PeerAttachReplayCapacity = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.PeerAttachReplayOverflowRejects, int64(1); got != want {
|
||||
t.Fatalf("PeerAttachReplayOverflowRejects = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
@ -15,14 +15,23 @@ const (
|
||||
)
|
||||
|
||||
type peerAttachRequest struct {
|
||||
PeerID string
|
||||
PeerID string
|
||||
Features uint64
|
||||
ClientNonce []byte
|
||||
ClientECDHEPublicKey []byte
|
||||
AuthTag []byte
|
||||
}
|
||||
|
||||
type peerAttachResponse struct {
|
||||
PeerID string
|
||||
Accepted bool
|
||||
Reused bool
|
||||
Error string
|
||||
PeerID string
|
||||
Accepted bool
|
||||
Reused bool
|
||||
Error string
|
||||
Features uint64
|
||||
KeyMode string
|
||||
ServerNonce []byte
|
||||
ServerECDHEPublicKey []byte
|
||||
AuthTag []byte
|
||||
}
|
||||
|
||||
func newClientPeerIdentity() string {
|
||||
@ -108,7 +117,11 @@ func (c *ClientCommon) announceClientPeerIdentity() error {
|
||||
if peerID == "" {
|
||||
return errors.New("peer identity is empty")
|
||||
}
|
||||
encoded, err := c.sequenceEn(peerAttachRequest{PeerID: peerID})
|
||||
req, requestState, err := c.buildPeerAttachRequest(peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
encoded, err := c.sequenceEn(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -133,6 +146,12 @@ func (c *ClientCommon) announceClientPeerIdentity() error {
|
||||
}
|
||||
return errors.New("peer attach rejected")
|
||||
}
|
||||
verifyResult, err := c.verifyPeerAttachResponse(req, resp, requestState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.setClientNegotiatedSteadyTransportProtection(verifyResult.steadyProfile)
|
||||
c.markClientPeerAttachAuthenticated(verifyResult.authFallback, time.Now())
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -188,7 +207,7 @@ func (s *ServerCommon) replyPeerAttach(client *LogicalConn, message Message, res
|
||||
Type: MSG_SYS_REPLY,
|
||||
}
|
||||
if message.inboundConn != nil {
|
||||
return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply)
|
||||
return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, messageInboundTransportProtectionSnapshot(&message), reply)
|
||||
}
|
||||
_, err = s.sendLogical(client, reply)
|
||||
return err
|
||||
@ -200,6 +219,10 @@ func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool {
|
||||
}
|
||||
message = hydrateServerMessagePeerFields(message)
|
||||
current := messageLogicalConnSnapshot(&message)
|
||||
transport := message.inboundConn
|
||||
if transport == nil && current != nil {
|
||||
transport = current.transportSnapshot()
|
||||
}
|
||||
req, err := decodePeerAttachRequest(s.sequenceDe, message.Value)
|
||||
if err != nil {
|
||||
if current != nil {
|
||||
@ -210,6 +233,18 @@ func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
auth, err := s.validatePeerAttachRequestAuth(current, transport, req)
|
||||
if err != nil {
|
||||
classifyPeerAttachRejectCounter(s, err)
|
||||
if current != nil {
|
||||
_ = s.replyPeerAttach(current, message, peerAttachResponse{
|
||||
PeerID: req.PeerID,
|
||||
Accepted: false,
|
||||
Error: err.Error(),
|
||||
})
|
||||
}
|
||||
return true
|
||||
}
|
||||
bound, reused, err := s.bindAcceptedClientIdentity(current, req.PeerID)
|
||||
if err != nil {
|
||||
if current != nil {
|
||||
@ -221,12 +256,37 @@ func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
if err := s.replyPeerAttach(bound, message, peerAttachResponse{
|
||||
resp := peerAttachResponse{
|
||||
PeerID: bound.ID(),
|
||||
Accepted: true,
|
||||
Reused: reused,
|
||||
}); err != nil && bound != nil {
|
||||
}
|
||||
steadyProfile, err := s.preparePeerAttachSteadyTransportProfile(bound, req, &resp, auth)
|
||||
if err != nil {
|
||||
if bound != nil {
|
||||
_ = s.replyPeerAttach(bound, message, peerAttachResponse{
|
||||
PeerID: req.PeerID,
|
||||
Accepted: false,
|
||||
Error: err.Error(),
|
||||
})
|
||||
}
|
||||
return true
|
||||
}
|
||||
s.signPeerAttachResponse(bound, req, &resp, auth)
|
||||
if bound != nil {
|
||||
bound.markPeerAttachAuthenticated(s.securityAuthMode, auth.fallback, time.Now())
|
||||
if auth.explicit {
|
||||
s.peerAttachExplicitCount.Add(1)
|
||||
} else if auth.fallback {
|
||||
s.peerAttachAuthFallbackCount.Add(1)
|
||||
}
|
||||
}
|
||||
if err := s.replyPeerAttach(bound, message, resp); err != nil && bound != nil {
|
||||
s.stopLogicalSession(bound, "peer attach reply failed", err)
|
||||
return true
|
||||
}
|
||||
if bound != nil && s.securityConfigured {
|
||||
bound.applyTransportProtectionProfile(steadyProfile)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@ -119,6 +119,7 @@ func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) {
|
||||
defer serverConn.Close()
|
||||
|
||||
logical := bootstrapPeerAttachLogicalForTest(t, server, serverConn)
|
||||
originalProfile := logical.transportProtectionProfileSnapshot()
|
||||
message := Message{
|
||||
NetType: NET_SERVER,
|
||||
LogicalConn: logical,
|
||||
@ -131,6 +132,13 @@ func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) {
|
||||
Time: time.Now(),
|
||||
inboundConn: serverConn,
|
||||
}
|
||||
message = hydrateServerMessagePeerFields(message)
|
||||
|
||||
alternate, err := deriveModernPSKProtectionProfile([]byte("notify-peer-attach-reply-alternate"), testModernPSKOptions(), ProtectionManaged)
|
||||
if err != nil {
|
||||
t.Fatalf("deriveModernPSKProtectionProfile failed: %v", err)
|
||||
}
|
||||
logical.applyTransportProtectionProfile(alternate)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
@ -140,7 +148,7 @@ func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) {
|
||||
})
|
||||
}()
|
||||
|
||||
env := readServerEnvelopeFromConn(t, server, logical, clientConn, time.Second)
|
||||
env := readServerEnvelopeFromConnWithProfile(t, server, originalProfile, clientConn, time.Second)
|
||||
if env.Kind != EnvelopeSignal {
|
||||
t.Fatalf("reply envelope kind = %v, want %v", env.Kind, EnvelopeSignal)
|
||||
}
|
||||
|
||||
139
security_forward_secrecy.go
Normal file
139
security_forward_secrecy.go
Normal file
@ -0,0 +1,139 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdh"
|
||||
"crypto/hmac"
|
||||
cryptorand "crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
const (
|
||||
peerAttachECDHEPublicKeySize = 32
|
||||
peerAttachSessionIDSize = 16
|
||||
|
||||
peerAttachKeyModeStatic = "psk-static"
|
||||
peerAttachKeyModeECDHE = "psk-ecdhe"
|
||||
transportKeyModeExternal = "external"
|
||||
)
|
||||
|
||||
var errPeerAttachForwardSecrecyInvalid = errors.New("peer attach forward secrecy is invalid")
|
||||
|
||||
type peerAttachRequestState struct {
|
||||
forwardSecrecy *peerAttachForwardSecrecyClientState
|
||||
}
|
||||
|
||||
type peerAttachForwardSecrecyClientState struct {
|
||||
privateKey *ecdh.PrivateKey
|
||||
publicKey []byte
|
||||
}
|
||||
|
||||
type peerAttachResponseVerifyResult struct {
|
||||
authFallback bool
|
||||
steadyProfile transportProtectionProfile
|
||||
}
|
||||
|
||||
func newPeerAttachForwardSecrecyClientState() (*peerAttachForwardSecrecyClientState, error) {
|
||||
curve := ecdh.X25519()
|
||||
privateKey, err := curve.GenerateKey(cryptorand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
publicKey := privateKey.PublicKey().Bytes()
|
||||
if len(publicKey) != peerAttachECDHEPublicKeySize {
|
||||
return nil, errPeerAttachForwardSecrecyInvalid
|
||||
}
|
||||
return &peerAttachForwardSecrecyClientState{
|
||||
privateKey: privateKey,
|
||||
publicKey: bytes.Clone(publicKey),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func derivePeerAttachForwardSecrecyTransportProfile(base transportProtectionProfile, bootstrapKey []byte, localPrivateKey *ecdh.PrivateKey, peerPublicKey []byte, req peerAttachRequest, resp peerAttachResponse) (transportProtectionProfile, error) {
|
||||
if len(bootstrapKey) == 0 || localPrivateKey == nil || len(peerPublicKey) != peerAttachECDHEPublicKeySize {
|
||||
return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid
|
||||
}
|
||||
curve := ecdh.X25519()
|
||||
publicKey, err := curve.NewPublicKey(peerPublicKey)
|
||||
if err != nil {
|
||||
return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid
|
||||
}
|
||||
sharedSecret, err := localPrivateKey.ECDH(publicKey)
|
||||
if err != nil {
|
||||
return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid
|
||||
}
|
||||
transcriptHash := peerAttachForwardSecrecyTranscriptHash(req, resp)
|
||||
ikm := make([]byte, 0, len(sharedSecret)+len(transcriptHash))
|
||||
ikm = append(ikm, sharedSecret...)
|
||||
ikm = append(ikm, transcriptHash...)
|
||||
prk := hkdfExtractSHA256(bootstrapKey, ikm)
|
||||
sessionKey := hkdfExpandSHA256(prk, []byte("notify/transport/session/v1"), 32)
|
||||
sessionID := hkdfExpandSHA256(prk, []byte("notify/session-id/v1"), peerAttachSessionIDSize)
|
||||
return deriveModernPSKSessionProtectionProfile(base, sessionKey, sessionID)
|
||||
}
|
||||
|
||||
func peerAttachForwardSecrecyTranscriptHash(req peerAttachRequest, resp peerAttachResponse) []byte {
|
||||
buf := make([]byte, 0, 256+len(req.PeerID)+len(resp.PeerID)+len(resp.Error)+len(req.ClientECDHEPublicKey)+len(resp.ServerECDHEPublicKey))
|
||||
buf = appendPeerAttachTranscriptString(buf, "notify/peer-attach/forward-secrecy/v1")
|
||||
buf = binary.BigEndian.AppendUint64(buf, req.Features)
|
||||
buf = appendPeerAttachTranscriptString(buf, req.PeerID)
|
||||
buf = appendPeerAttachTranscriptBytes(buf, req.ClientNonce)
|
||||
buf = appendPeerAttachTranscriptBytes(buf, req.ClientECDHEPublicKey)
|
||||
buf = binary.BigEndian.AppendUint64(buf, resp.Features)
|
||||
buf = appendPeerAttachTranscriptString(buf, resp.PeerID)
|
||||
buf = appendPeerAttachTranscriptBool(buf, resp.Accepted)
|
||||
buf = appendPeerAttachTranscriptBool(buf, resp.Reused)
|
||||
buf = appendPeerAttachTranscriptString(buf, resp.Error)
|
||||
buf = appendPeerAttachTranscriptBytes(buf, resp.ServerNonce)
|
||||
buf = appendPeerAttachTranscriptString(buf, resp.KeyMode)
|
||||
buf = appendPeerAttachTranscriptBytes(buf, resp.ServerECDHEPublicKey)
|
||||
sum := sha256.Sum256(buf)
|
||||
return sum[:]
|
||||
}
|
||||
|
||||
func appendPeerAttachTranscriptBytes(dst []byte, data []byte) []byte {
|
||||
dst = binary.BigEndian.AppendUint32(dst, uint32(len(data)))
|
||||
return append(dst, data...)
|
||||
}
|
||||
|
||||
func appendPeerAttachTranscriptString(dst []byte, value string) []byte {
|
||||
return appendPeerAttachTranscriptBytes(dst, []byte(value))
|
||||
}
|
||||
|
||||
func appendPeerAttachTranscriptBool(dst []byte, value bool) []byte {
|
||||
if value {
|
||||
return append(dst, 1)
|
||||
}
|
||||
return append(dst, 0)
|
||||
}
|
||||
|
||||
func hkdfExtractSHA256(salt []byte, ikm []byte) []byte {
|
||||
mac := hmac.New(sha256.New, salt)
|
||||
_, _ = mac.Write(ikm)
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func hkdfExpandSHA256(prk []byte, info []byte, size int) []byte {
|
||||
if size <= 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]byte, 0, size)
|
||||
var block []byte
|
||||
for counter := byte(1); len(out) < size; counter++ {
|
||||
mac := hmac.New(sha256.New, prk)
|
||||
if len(block) != 0 {
|
||||
_, _ = mac.Write(block)
|
||||
}
|
||||
_, _ = mac.Write(info)
|
||||
_, _ = mac.Write([]byte{counter})
|
||||
block = mac.Sum(nil)
|
||||
remaining := size - len(out)
|
||||
if remaining > len(block) {
|
||||
remaining = len(block)
|
||||
}
|
||||
out = append(out, block[:remaining]...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
408
security_profile.go
Normal file
408
security_profile.go
Normal file
@ -0,0 +1,408 @@
|
||||
package notify
|
||||
|
||||
import "bytes"
|
||||
|
||||
// AuthMode describes how notify authenticates the peer during bootstrap.
|
||||
type AuthMode int
|
||||
|
||||
const (
|
||||
AuthNone AuthMode = iota
|
||||
AuthPSK
|
||||
AuthExternalPeer
|
||||
)
|
||||
|
||||
// ProtectionMode describes how notify protects steady-state transport payloads.
|
||||
type ProtectionMode int
|
||||
|
||||
const (
|
||||
ProtectionManaged ProtectionMode = iota
|
||||
ProtectionExternal
|
||||
ProtectionNested
|
||||
)
|
||||
|
||||
// SecurityOptions describes the high-level auth/protection policy.
|
||||
//
|
||||
// The current implementation still exposes dedicated helper constructors such
|
||||
// as UseModernPSKClient/Server and UsePSKOverExternalTransportClient/Server.
|
||||
type SecurityOptions struct {
|
||||
AuthMode AuthMode
|
||||
ProtectionMode ProtectionMode
|
||||
SharedSecret []byte
|
||||
RequireForwardSecrecy bool
|
||||
}
|
||||
|
||||
func authModeName(mode AuthMode) string {
|
||||
switch mode {
|
||||
case AuthNone:
|
||||
return "none"
|
||||
case AuthPSK:
|
||||
return "psk"
|
||||
case AuthExternalPeer:
|
||||
return "external-peer"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func protectionModeName(mode ProtectionMode) string {
|
||||
switch mode {
|
||||
case ProtectionManaged:
|
||||
return "managed"
|
||||
case ProtectionExternal:
|
||||
return "external"
|
||||
case ProtectionNested:
|
||||
return "nested"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type transportProtectionProfile struct {
|
||||
mode ProtectionMode
|
||||
msgEn func([]byte, []byte) []byte
|
||||
msgDe func([]byte, []byte) []byte
|
||||
fastStreamEncode transportFastStreamEncoder
|
||||
fastBulkEncode transportFastBulkEncoder
|
||||
fastPlainEncode transportFastPlainEncoder
|
||||
runtime *modernPSKCodecRuntime
|
||||
secretKey []byte
|
||||
keyMode string
|
||||
sessionID []byte
|
||||
forwardSecrecy bool
|
||||
forwardSecrecyFallback bool
|
||||
}
|
||||
|
||||
func cloneTransportProtectionKey(src []byte) []byte {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
return bytes.Clone(src)
|
||||
}
|
||||
|
||||
func newTransportProtectionProfile(mode ProtectionMode, bundle modernPSKTransportBundle, runtime *modernPSKCodecRuntime, secretKey []byte) transportProtectionProfile {
|
||||
return transportProtectionProfile{
|
||||
mode: mode,
|
||||
msgEn: bundle.msgEn,
|
||||
msgDe: bundle.msgDe,
|
||||
fastStreamEncode: bundle.fastStreamEncode,
|
||||
fastBulkEncode: bundle.fastBulkEncode,
|
||||
fastPlainEncode: bundle.fastPlainEncode,
|
||||
runtime: runtime,
|
||||
secretKey: cloneTransportProtectionKey(secretKey),
|
||||
keyMode: defaultTransportKeyMode(mode, secretKey),
|
||||
}
|
||||
}
|
||||
|
||||
func defaultTransportKeyMode(mode ProtectionMode, secretKey []byte) string {
|
||||
if len(secretKey) == 0 {
|
||||
return ""
|
||||
}
|
||||
switch mode {
|
||||
case ProtectionManaged, ProtectionNested:
|
||||
return peerAttachKeyModeStatic
|
||||
case ProtectionExternal:
|
||||
return transportKeyModeExternal
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func cloneTransportSessionID(src []byte) []byte {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
return bytes.Clone(src)
|
||||
}
|
||||
|
||||
func (p transportProtectionProfile) clone() transportProtectionProfile {
|
||||
p.secretKey = cloneTransportProtectionKey(p.secretKey)
|
||||
p.sessionID = cloneTransportSessionID(p.sessionID)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p transportProtectionProfile) withForwardSecrecyFallback(fallback bool) transportProtectionProfile {
|
||||
p = p.clone()
|
||||
p.forwardSecrecy = false
|
||||
p.forwardSecrecyFallback = fallback
|
||||
p.sessionID = nil
|
||||
return p
|
||||
}
|
||||
|
||||
func passthroughTransportCodec(_ []byte, payload []byte) []byte {
|
||||
return payload
|
||||
}
|
||||
|
||||
func passthroughFastPlainEncode(_ []byte, plainLen int, fill func([]byte) error) ([]byte, error) {
|
||||
if plainLen < 0 {
|
||||
return nil, errTransportPayloadEncryptFailed
|
||||
}
|
||||
buf := make([]byte, plainLen)
|
||||
if fill != nil {
|
||||
if err := fill(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func buildExternalTransportBundle() modernPSKTransportBundle {
|
||||
fastPlainEncode := passthroughFastPlainEncode
|
||||
fastStreamEncode := func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
||||
return encodeStreamFastFramePayloadFast(fastPlainEncode, secretKey, streamFastDataFrame{
|
||||
DataID: dataID,
|
||||
Seq: seq,
|
||||
Payload: payload,
|
||||
})
|
||||
}
|
||||
fastBulkEncode := func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
||||
return encodeBulkFastFramePayloadFast(fastPlainEncode, secretKey, bulkFastFrame{
|
||||
Type: bulkFastPayloadTypeData,
|
||||
DataID: dataID,
|
||||
Seq: seq,
|
||||
Payload: payload,
|
||||
})
|
||||
}
|
||||
return modernPSKTransportBundle{
|
||||
msgEn: passthroughTransportCodec,
|
||||
msgDe: passthroughTransportCodec,
|
||||
fastStreamEncode: fastStreamEncode,
|
||||
fastBulkEncode: fastBulkEncode,
|
||||
fastPlainEncode: fastPlainEncode,
|
||||
}
|
||||
}
|
||||
|
||||
func defaultTransportProtectionProfile() transportProtectionProfile {
|
||||
return newTransportProtectionProfile(ProtectionManaged, defaultModernPSKTransportBundle(), nil, nil)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientTransportProtectionSnapshot() transportProtectionProfile {
|
||||
if c == nil {
|
||||
return transportProtectionProfile{}
|
||||
}
|
||||
if state := c.transportProtection.Load(); state != nil {
|
||||
return *state
|
||||
}
|
||||
return transportProtectionProfile{
|
||||
mode: ProtectionManaged,
|
||||
msgEn: c.msgEn,
|
||||
msgDe: c.msgDe,
|
||||
fastStreamEncode: c.fastStreamEncode,
|
||||
fastBulkEncode: c.fastBulkEncode,
|
||||
fastPlainEncode: c.fastPlainEncode,
|
||||
runtime: c.modernPSKRuntime,
|
||||
secretKey: c.SecretKey,
|
||||
keyMode: defaultTransportKeyMode(ProtectionManaged, c.SecretKey),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) setClientTransportProtectionProfile(profile transportProtectionProfile) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
profile.secretKey = cloneTransportProtectionKey(profile.secretKey)
|
||||
profile.sessionID = cloneTransportSessionID(profile.sessionID)
|
||||
c.msgEn = profile.msgEn
|
||||
c.msgDe = profile.msgDe
|
||||
c.fastStreamEncode = profile.fastStreamEncode
|
||||
c.fastBulkEncode = profile.fastBulkEncode
|
||||
c.fastPlainEncode = profile.fastPlainEncode
|
||||
c.modernPSKRuntime = profile.runtime
|
||||
c.SecretKey = profile.secretKey
|
||||
c.transportProtection.Store(&profile)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clearClientSecurityProfiles() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.securityConfigured = false
|
||||
c.securityAuthMode = AuthNone
|
||||
c.securityProtectionMode = ProtectionManaged
|
||||
c.securityBootstrap = transportProtectionProfile{}
|
||||
c.securitySteady = transportProtectionProfile{}
|
||||
c.securitySteadyNegotiated = transportProtectionProfile{}
|
||||
c.securityRequireForwardSecrecy = false
|
||||
c.peerAttachAuthenticated = false
|
||||
c.peerAttachAuthFallback = false
|
||||
c.peerAttachAt = 0
|
||||
}
|
||||
|
||||
func (c *ClientCommon) configureClientSecurityProfiles(authMode AuthMode, protectionMode ProtectionMode, bootstrap transportProtectionProfile, steady transportProtectionProfile, requireForwardSecrecy bool) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.securityConfigured = true
|
||||
c.securityAuthMode = authMode
|
||||
c.securityProtectionMode = protectionMode
|
||||
c.securityBootstrap = bootstrap.clone()
|
||||
c.securitySteady = steady.clone()
|
||||
c.securitySteadyNegotiated = steady.clone()
|
||||
c.securityRequireForwardSecrecy = requireForwardSecrecy
|
||||
c.setClientTransportProtectionProfile(bootstrap)
|
||||
c.securityReadyCheck = len(bootstrap.secretKey) != 0
|
||||
c.skipKeyExchange = true
|
||||
}
|
||||
|
||||
func (c *ClientCommon) activateClientBootstrapTransportProtection() {
|
||||
if c == nil || !c.securityConfigured {
|
||||
return
|
||||
}
|
||||
c.resetClientNegotiatedSteadyTransportProtection()
|
||||
c.setClientTransportProtectionProfile(c.securityBootstrap)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) activateClientSteadyTransportProtection() {
|
||||
if c == nil || !c.securityConfigured {
|
||||
return
|
||||
}
|
||||
c.setClientTransportProtectionProfile(c.securitySteadyNegotiated)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) resetClientNegotiatedSteadyTransportProtection() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.securitySteadyNegotiated = c.securitySteady.clone().withForwardSecrecyFallback(false)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) setClientNegotiatedSteadyTransportProtection(profile transportProtectionProfile) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.securitySteadyNegotiated = profile.clone()
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientSupportsForwardSecrecy() bool {
|
||||
if c == nil || !c.securityConfigured {
|
||||
return false
|
||||
}
|
||||
if c.securityAuthMode != AuthPSK {
|
||||
return false
|
||||
}
|
||||
if c.securityProtectionMode == ProtectionExternal {
|
||||
return false
|
||||
}
|
||||
return len(c.securityBootstrap.secretKey) != 0
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientRequiresForwardSecrecy() bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
return c.securityRequireForwardSecrecy
|
||||
}
|
||||
|
||||
func (s *ServerCommon) serverSupportsForwardSecrecy() bool {
|
||||
if s == nil || !s.securityConfigured {
|
||||
return false
|
||||
}
|
||||
if s.securityAuthMode != AuthPSK {
|
||||
return false
|
||||
}
|
||||
if s.securityProtectionMode == ProtectionExternal {
|
||||
return false
|
||||
}
|
||||
return len(s.securityBootstrap.secretKey) != 0
|
||||
}
|
||||
|
||||
func (s *ServerCommon) serverRequiresForwardSecrecy() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
return s.securityRequireForwardSecrecy
|
||||
}
|
||||
|
||||
func (s *ServerCommon) setServerDefaultTransportProtectionProfile(profile transportProtectionProfile) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
profile.secretKey = cloneTransportProtectionKey(profile.secretKey)
|
||||
profile.sessionID = cloneTransportSessionID(profile.sessionID)
|
||||
s.defaultMsgEn = profile.msgEn
|
||||
s.defaultMsgDe = profile.msgDe
|
||||
s.defaultFastStreamEncode = profile.fastStreamEncode
|
||||
s.defaultFastBulkEncode = profile.fastBulkEncode
|
||||
s.defaultFastPlainEncode = profile.fastPlainEncode
|
||||
s.defaultModernPSKRuntime = profile.runtime
|
||||
s.SecretKey = profile.secretKey
|
||||
}
|
||||
|
||||
func (s *ServerCommon) clearServerSecurityProfiles() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.securityConfigured = false
|
||||
s.securityAuthMode = AuthNone
|
||||
s.securityProtectionMode = ProtectionManaged
|
||||
s.securityBootstrap = transportProtectionProfile{}
|
||||
s.securitySteady = transportProtectionProfile{}
|
||||
s.securityRequireForwardSecrecy = false
|
||||
}
|
||||
|
||||
func (s *ServerCommon) configureServerSecurityProfiles(authMode AuthMode, protectionMode ProtectionMode, bootstrap transportProtectionProfile, steady transportProtectionProfile, requireForwardSecrecy bool) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.securityConfigured = true
|
||||
s.securityAuthMode = authMode
|
||||
s.securityProtectionMode = protectionMode
|
||||
s.securityBootstrap = bootstrap.clone()
|
||||
s.securitySteady = steady.clone()
|
||||
s.securityRequireForwardSecrecy = requireForwardSecrecy
|
||||
s.setServerDefaultTransportProtectionProfile(bootstrap)
|
||||
s.securityReadyCheck = len(bootstrap.secretKey) != 0
|
||||
}
|
||||
|
||||
func (s *ServerCommon) applyLogicalSteadyTransportProtection(logical *LogicalConn) {
|
||||
if s == nil || logical == nil || !s.securityConfigured {
|
||||
return
|
||||
}
|
||||
logical.applyTransportProtectionProfile(s.securitySteady)
|
||||
}
|
||||
|
||||
func transportProtectionProfileFromAttachmentState(state *clientConnAttachmentState) transportProtectionProfile {
|
||||
if state == nil {
|
||||
return transportProtectionProfile{}
|
||||
}
|
||||
return transportProtectionProfile{
|
||||
mode: state.protectionMode,
|
||||
msgEn: state.msgEn,
|
||||
msgDe: state.msgDe,
|
||||
fastStreamEncode: state.fastStreamEncode,
|
||||
fastBulkEncode: state.fastBulkEncode,
|
||||
fastPlainEncode: state.fastPlainEncode,
|
||||
runtime: state.modernPSKRuntime,
|
||||
secretKey: cloneTransportProtectionKey(state.secretKey),
|
||||
keyMode: state.keyMode,
|
||||
sessionID: cloneTransportSessionID(state.sessionID),
|
||||
forwardSecrecy: state.forwardSecrecy,
|
||||
forwardSecrecyFallback: state.forwardSecrecyFallback,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LogicalConn) transportProtectionProfileSnapshot() transportProtectionProfile {
|
||||
if c == nil {
|
||||
return transportProtectionProfile{}
|
||||
}
|
||||
return transportProtectionProfileFromAttachmentState(c.attachmentStateSnapshot())
|
||||
}
|
||||
|
||||
func (c *LogicalConn) applyTransportProtectionProfile(profile transportProtectionProfile) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.updateAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.protectionMode = profile.mode
|
||||
state.msgEn = profile.msgEn
|
||||
state.msgDe = profile.msgDe
|
||||
state.fastStreamEncode = profile.fastStreamEncode
|
||||
state.fastBulkEncode = profile.fastBulkEncode
|
||||
state.fastPlainEncode = profile.fastPlainEncode
|
||||
state.modernPSKRuntime = profile.runtime
|
||||
state.secretKey = cloneTransportProtectionKey(profile.secretKey)
|
||||
state.keyMode = profile.keyMode
|
||||
state.sessionID = cloneTransportSessionID(profile.sessionID)
|
||||
state.forwardSecrecy = profile.forwardSecrecy
|
||||
state.forwardSecrecyFallback = profile.forwardSecrecyFallback
|
||||
})
|
||||
}
|
||||
201
security_psk.go
201
security_psk.go
@ -15,9 +15,10 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty")
|
||||
errModernPSKPayload = errors.New("invalid modern psk payload")
|
||||
errModernPSKRequired = errors.New("modern psk is required: call UseModernPSKClient/UseModernPSKServer or set a transport key before Connect/Listen")
|
||||
errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty")
|
||||
errModernPSKPayload = errors.New("invalid modern psk payload")
|
||||
errModernPSKRequired = errors.New("transport security is required: call UseModernPSKClient/UseModernPSKServer, UsePSKOverExternalTransportClient/Server, or set a transport key before Connect/Listen")
|
||||
errModernPSKForwardSecrecyUnsupported = errors.New("forward secrecy is unsupported for external transport protection")
|
||||
)
|
||||
|
||||
var (
|
||||
@ -47,9 +48,10 @@ var modernPSKPayloadPool sync.Pool
|
||||
// The current profile derives a 32-byte transport key with Argon2id and uses
|
||||
// AES-GCM with a per-codec nonce prefix plus a per-message counter.
|
||||
type ModernPSKOptions struct {
|
||||
Salt []byte
|
||||
AAD []byte
|
||||
Argon2Params starcrypto.Argon2Params
|
||||
Salt []byte
|
||||
AAD []byte
|
||||
Argon2Params starcrypto.Argon2Params
|
||||
RequireForwardSecrecy bool
|
||||
}
|
||||
|
||||
// DefaultModernPSKOptions returns the recommended settings for the current
|
||||
@ -78,24 +80,17 @@ func defaultModernPSKTransportBundle() modernPSKTransportBundle {
|
||||
// Argon2id, and switches message protection to AES-GCM. Configure it before
|
||||
// calling Connect/ConnectTimeout.
|
||||
func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
|
||||
key, aad, err := deriveModernPSKKey(sharedSecret, opts)
|
||||
managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
transport := buildModernPSKTransportBundle(aad)
|
||||
runtime, err := newModernPSKCodecRuntime(key, aad)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.SetSecretKey(key)
|
||||
c.SetMsgEn(transport.msgEn)
|
||||
c.SetMsgDe(transport.msgDe)
|
||||
if client, ok := c.(*ClientCommon); ok {
|
||||
client.fastStreamEncode = transport.fastStreamEncode
|
||||
client.fastBulkEncode = transport.fastBulkEncode
|
||||
client.fastPlainEncode = transport.fastPlainEncode
|
||||
client.modernPSKRuntime = runtime
|
||||
client.configureClientSecurityProfiles(AuthPSK, ProtectionManaged, managed, managed, opts != nil && opts.RequireForwardSecrecy)
|
||||
return nil
|
||||
}
|
||||
c.SetSecretKey(managed.secretKey)
|
||||
c.SetMsgEn(managed.msgEn)
|
||||
c.SetMsgDe(managed.msgDe)
|
||||
c.SetSkipExchangeKey(true)
|
||||
return nil
|
||||
}
|
||||
@ -106,24 +101,95 @@ func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) e
|
||||
// It derives a transport key with Argon2id and switches message protection to
|
||||
// AES-GCM. Configure it before calling Listen.
|
||||
func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
|
||||
key, aad, err := deriveModernPSKKey(sharedSecret, opts)
|
||||
managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
transport := buildModernPSKTransportBundle(aad)
|
||||
runtime, err := newModernPSKCodecRuntime(key, aad)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.SetSecretKey(key)
|
||||
s.SetDefaultCommEncode(transport.msgEn)
|
||||
s.SetDefaultCommDecode(transport.msgDe)
|
||||
if server, ok := s.(*ServerCommon); ok {
|
||||
server.defaultFastStreamEncode = transport.fastStreamEncode
|
||||
server.defaultFastBulkEncode = transport.fastBulkEncode
|
||||
server.defaultFastPlainEncode = transport.fastPlainEncode
|
||||
server.defaultModernPSKRuntime = runtime
|
||||
server.configureServerSecurityProfiles(AuthPSK, ProtectionManaged, managed, managed, opts != nil && opts.RequireForwardSecrecy)
|
||||
return nil
|
||||
}
|
||||
s.SetSecretKey(managed.secretKey)
|
||||
s.SetDefaultCommEncode(managed.msgEn)
|
||||
s.SetDefaultCommDecode(managed.msgDe)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UsePSKOverExternalTransportClient authenticates bootstrap with PSK and then
|
||||
// trusts the external channel for steady-state payload protection.
|
||||
func UsePSKOverExternalTransportClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
|
||||
if opts != nil && opts.RequireForwardSecrecy {
|
||||
return errModernPSKForwardSecrecyUnsupported
|
||||
}
|
||||
bootstrap, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
steady := buildExternalProtectionProfile(bootstrap.secretKey)
|
||||
if client, ok := c.(*ClientCommon); ok {
|
||||
client.configureClientSecurityProfiles(AuthPSK, ProtectionExternal, bootstrap, steady, false)
|
||||
return nil
|
||||
}
|
||||
c.SetSecretKey(bootstrap.secretKey)
|
||||
c.SetMsgEn(bootstrap.msgEn)
|
||||
c.SetMsgDe(bootstrap.msgDe)
|
||||
c.SetSkipExchangeKey(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UsePSKOverExternalTransportServer authenticates bootstrap with PSK and then
|
||||
// trusts the external channel for steady-state payload protection.
|
||||
func UsePSKOverExternalTransportServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
|
||||
if opts != nil && opts.RequireForwardSecrecy {
|
||||
return errModernPSKForwardSecrecyUnsupported
|
||||
}
|
||||
bootstrap, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
steady := buildExternalProtectionProfile(bootstrap.secretKey)
|
||||
if server, ok := s.(*ServerCommon); ok {
|
||||
server.configureServerSecurityProfiles(AuthPSK, ProtectionExternal, bootstrap, steady, false)
|
||||
return nil
|
||||
}
|
||||
s.SetSecretKey(bootstrap.secretKey)
|
||||
s.SetDefaultCommEncode(bootstrap.msgEn)
|
||||
s.SetDefaultCommDecode(bootstrap.msgDe)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UseNestedSecurityClient keeps notify transport protection enabled even when
|
||||
// the physical connection is already protected by an outer trusted channel.
|
||||
func UseNestedSecurityClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
|
||||
managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionNested)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if client, ok := c.(*ClientCommon); ok {
|
||||
client.configureClientSecurityProfiles(AuthPSK, ProtectionNested, managed, managed, opts != nil && opts.RequireForwardSecrecy)
|
||||
return nil
|
||||
}
|
||||
c.SetSecretKey(managed.secretKey)
|
||||
c.SetMsgEn(managed.msgEn)
|
||||
c.SetMsgDe(managed.msgDe)
|
||||
c.SetSkipExchangeKey(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UseNestedSecurityServer keeps notify transport protection enabled even when
|
||||
// the physical connection is already protected by an outer trusted channel.
|
||||
func UseNestedSecurityServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
|
||||
managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionNested)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if server, ok := s.(*ServerCommon); ok {
|
||||
server.configureServerSecurityProfiles(AuthPSK, ProtectionNested, managed, managed, opts != nil && opts.RequireForwardSecrecy)
|
||||
return nil
|
||||
}
|
||||
s.SetSecretKey(managed.secretKey)
|
||||
s.SetDefaultCommEncode(managed.msgEn)
|
||||
s.SetDefaultCommDecode(managed.msgDe)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -132,15 +198,22 @@ func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) e
|
||||
//
|
||||
// It is kept only as an explicit fallback path for existing deployments.
|
||||
func UseLegacySecurityClient(c Client) {
|
||||
if client, ok := c.(*ClientCommon); ok {
|
||||
client.clearClientSecurityProfiles()
|
||||
client.setClientTransportProtectionProfile(transportProtectionProfile{
|
||||
mode: ProtectionManaged,
|
||||
msgEn: defaultMsgEn,
|
||||
msgDe: defaultMsgDe,
|
||||
secretKey: bytes.Clone(defaultAesKey),
|
||||
})
|
||||
client.securityReadyCheck = false
|
||||
client.skipKeyExchange = false
|
||||
client.handshakeRsaPubKey = bytes.Clone(defaultRsaPubKey)
|
||||
return
|
||||
}
|
||||
c.SetSecretKey(bytes.Clone(defaultAesKey))
|
||||
c.SetMsgEn(defaultMsgEn)
|
||||
c.SetMsgDe(defaultMsgDe)
|
||||
if client, ok := c.(*ClientCommon); ok {
|
||||
client.fastStreamEncode = nil
|
||||
client.fastBulkEncode = nil
|
||||
client.fastPlainEncode = nil
|
||||
client.modernPSKRuntime = nil
|
||||
}
|
||||
c.SetSkipExchangeKey(false)
|
||||
c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey))
|
||||
}
|
||||
@ -150,15 +223,21 @@ func UseLegacySecurityClient(c Client) {
|
||||
//
|
||||
// It is kept only as an explicit fallback path for existing deployments.
|
||||
func UseLegacySecurityServer(s Server) {
|
||||
if server, ok := s.(*ServerCommon); ok {
|
||||
server.clearServerSecurityProfiles()
|
||||
server.setServerDefaultTransportProtectionProfile(transportProtectionProfile{
|
||||
mode: ProtectionManaged,
|
||||
msgEn: defaultMsgEn,
|
||||
msgDe: defaultMsgDe,
|
||||
secretKey: bytes.Clone(defaultAesKey),
|
||||
})
|
||||
server.securityReadyCheck = false
|
||||
server.handshakeRsaKey = bytes.Clone(defaultRsaKey)
|
||||
return
|
||||
}
|
||||
s.SetSecretKey(bytes.Clone(defaultAesKey))
|
||||
s.SetDefaultCommEncode(defaultMsgEn)
|
||||
s.SetDefaultCommDecode(defaultMsgDe)
|
||||
if server, ok := s.(*ServerCommon); ok {
|
||||
server.defaultFastStreamEncode = nil
|
||||
server.defaultFastBulkEncode = nil
|
||||
server.defaultFastPlainEncode = nil
|
||||
server.defaultModernPSKRuntime = nil
|
||||
}
|
||||
s.SetRsaPrivKey(bytes.Clone(defaultRsaKey))
|
||||
}
|
||||
|
||||
@ -174,6 +253,40 @@ func deriveModernPSKKey(sharedSecret []byte, opts *ModernPSKOptions) ([]byte, []
|
||||
return key, cfg.AAD, nil
|
||||
}
|
||||
|
||||
func deriveModernPSKProtectionProfile(sharedSecret []byte, opts *ModernPSKOptions, mode ProtectionMode) (transportProtectionProfile, error) {
|
||||
key, aad, err := deriveModernPSKKey(sharedSecret, opts)
|
||||
if err != nil {
|
||||
return transportProtectionProfile{}, err
|
||||
}
|
||||
transport := buildModernPSKTransportBundle(aad)
|
||||
runtime, err := newModernPSKCodecRuntime(key, aad)
|
||||
if err != nil {
|
||||
return transportProtectionProfile{}, err
|
||||
}
|
||||
return newTransportProtectionProfile(mode, transport, runtime, key), nil
|
||||
}
|
||||
|
||||
func buildExternalProtectionProfile(secretKey []byte) transportProtectionProfile {
|
||||
return newTransportProtectionProfile(ProtectionExternal, buildExternalTransportBundle(), nil, secretKey)
|
||||
}
|
||||
|
||||
func deriveModernPSKSessionProtectionProfile(base transportProtectionProfile, sessionKey []byte, sessionID []byte) (transportProtectionProfile, error) {
|
||||
aad := bytes.Clone(defaultModernPSKAAD)
|
||||
if base.runtime != nil && len(base.runtime.aad) != 0 {
|
||||
aad = bytes.Clone(base.runtime.aad)
|
||||
}
|
||||
runtime, err := newModernPSKCodecRuntime(sessionKey, aad)
|
||||
if err != nil {
|
||||
return transportProtectionProfile{}, err
|
||||
}
|
||||
profile := newTransportProtectionProfile(base.mode, buildModernPSKTransportBundle(aad), runtime, sessionKey)
|
||||
profile.keyMode = peerAttachKeyModeECDHE
|
||||
profile.sessionID = cloneTransportSessionID(sessionID)
|
||||
profile.forwardSecrecy = true
|
||||
profile.forwardSecrecyFallback = false
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func normalizeModernPSKOptions(opts *ModernPSKOptions) ModernPSKOptions {
|
||||
cfg := DefaultModernPSKOptions()
|
||||
if opts == nil {
|
||||
|
||||
@ -3,8 +3,10 @@ package notify
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"b612.me/starcrypto"
|
||||
)
|
||||
@ -207,6 +209,17 @@ func TestUseModernPSKRejectsEmptySecret(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsePSKOverExternalTransportRejectsForwardSecrecyRequirement(t *testing.T) {
|
||||
opts := testModernPSKOptions()
|
||||
opts.RequireForwardSecrecy = true
|
||||
if err := UsePSKOverExternalTransportClient(NewClient(), []byte("secret"), opts); !errors.Is(err, errModernPSKForwardSecrecyUnsupported) {
|
||||
t.Fatalf("UsePSKOverExternalTransportClient error = %v, want %v", err, errModernPSKForwardSecrecyUnsupported)
|
||||
}
|
||||
if err := UsePSKOverExternalTransportServer(NewServer(), []byte("secret"), opts); !errors.Is(err, errModernPSKForwardSecrecyUnsupported) {
|
||||
t.Fatalf("UsePSKOverExternalTransportServer error = %v, want %v", err, errModernPSKForwardSecrecyUnsupported)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModernPSKCodecRejectsLegacyPayload(t *testing.T) {
|
||||
key, aad, err := deriveModernPSKKey([]byte("notify-legacy-reject"), testModernPSKOptions())
|
||||
if err != nil {
|
||||
@ -315,6 +328,166 @@ func TestModernPSKFastBulkEncodeRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalTransportFastStreamEncodeRoundTrip(t *testing.T) {
|
||||
transport := buildExternalTransportBundle()
|
||||
wire, err := transport.fastStreamEncode(nil, 23, 7, []byte("payload"))
|
||||
if err != nil {
|
||||
t.Fatalf("fastStreamEncode failed: %v", err)
|
||||
}
|
||||
plain := transport.msgDe(nil, wire)
|
||||
frame, matched, err := decodeStreamFastDataFrame(plain)
|
||||
if err != nil {
|
||||
t.Fatalf("decodeStreamFastDataFrame failed: %v", err)
|
||||
}
|
||||
if !matched {
|
||||
t.Fatal("decodeStreamFastDataFrame should match fast payload")
|
||||
}
|
||||
if frame.DataID != 23 || frame.Seq != 7 || !bytes.Equal(frame.Payload, []byte("payload")) {
|
||||
t.Fatalf("frame mismatch: %+v", frame)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalTransportFastBulkEncodeRoundTrip(t *testing.T) {
|
||||
transport := buildExternalTransportBundle()
|
||||
wire, err := transport.fastBulkEncode(nil, 41, 9, []byte("payload"))
|
||||
if err != nil {
|
||||
t.Fatalf("fastBulkEncode failed: %v", err)
|
||||
}
|
||||
plain := transport.msgDe(nil, wire)
|
||||
frame, matched, err := decodeBulkFastDataFrame(plain)
|
||||
if err != nil {
|
||||
t.Fatalf("decodeBulkFastDataFrame failed: %v", err)
|
||||
}
|
||||
if !matched {
|
||||
t.Fatal("decodeBulkFastDataFrame should match fast payload")
|
||||
}
|
||||
if frame.DataID != 41 || frame.Seq != 9 || !bytes.Equal(frame.Payload, []byte("payload")) {
|
||||
t.Fatalf("frame mismatch: %+v", frame)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptTransportPayloadCodecPooledExternalDefersRelease(t *testing.T) {
|
||||
payload := []byte("payload")
|
||||
released := false
|
||||
plain, release, err := decryptTransportPayloadCodecPooled(ProtectionExternal, nil, passthroughTransportCodec, nil, payload, func() {
|
||||
released = true
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("decryptTransportPayloadCodecPooled failed: %v", err)
|
||||
}
|
||||
if released {
|
||||
t.Fatal("release should not run before caller is done")
|
||||
}
|
||||
if !bytes.Equal(plain, payload) {
|
||||
t.Fatalf("plain mismatch: got %q want %q", plain, payload)
|
||||
}
|
||||
if release == nil {
|
||||
t.Fatal("release callback should be preserved for external mode")
|
||||
}
|
||||
release()
|
||||
if !released {
|
||||
t.Fatal("release callback should run when caller finishes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsePSKOverExternalTransportConnectByConnSwitchesToExternal(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
if err := UsePSKOverExternalTransportServer(server, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
|
||||
}
|
||||
server.SetLink("external-roundtrip", func(msg *Message) {
|
||||
_ = msg.Reply([]byte("ack:" + string(msg.Value)))
|
||||
})
|
||||
})
|
||||
if err := UsePSKOverExternalTransportClient(client, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
|
||||
}
|
||||
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionManaged {
|
||||
t.Fatalf("client bootstrap mode = %v, want %v", got, ProtectionManaged)
|
||||
}
|
||||
|
||||
left, right := net.Pipe()
|
||||
defer right.Close()
|
||||
bootstrapPeerAttachConnForTest(t, server, right)
|
||||
if err := client.ConnectByConn(left); err != nil {
|
||||
t.Fatalf("ConnectByConn failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
client.setByeFromServer(true)
|
||||
_ = client.Stop()
|
||||
}()
|
||||
|
||||
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionExternal {
|
||||
t.Fatalf("client steady mode = %v, want %v", got, ProtectionExternal)
|
||||
}
|
||||
|
||||
reply, err := client.SendWait("external-roundtrip", []byte("ping"), time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("SendWait failed: %v", err)
|
||||
}
|
||||
if got, want := string(reply.Value), "ack:ping"; got != want {
|
||||
t.Fatalf("reply mismatch: got %q want %q", got, want)
|
||||
}
|
||||
|
||||
list := server.GetLogicalConnList()
|
||||
if len(list) != 1 {
|
||||
t.Fatalf("logical conn count = %d, want 1", len(list))
|
||||
}
|
||||
if got := list[0].protectionModeSnapshot(); got != ProtectionExternal {
|
||||
t.Fatalf("server steady mode = %v, want %v", got, ProtectionExternal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseNestedSecurityConnectByConnKeepsNestedMode(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
if err := UseNestedSecurityServer(server, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseNestedSecurityServer failed: %v", err)
|
||||
}
|
||||
server.SetLink("nested-roundtrip", func(msg *Message) {
|
||||
_ = msg.Reply([]byte("ack:" + string(msg.Value)))
|
||||
})
|
||||
})
|
||||
if err := UseNestedSecurityClient(client, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseNestedSecurityClient failed: %v", err)
|
||||
}
|
||||
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionNested {
|
||||
t.Fatalf("client bootstrap mode = %v, want %v", got, ProtectionNested)
|
||||
}
|
||||
|
||||
left, right := net.Pipe()
|
||||
defer right.Close()
|
||||
bootstrapPeerAttachConnForTest(t, server, right)
|
||||
if err := client.ConnectByConn(left); err != nil {
|
||||
t.Fatalf("ConnectByConn failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
client.setByeFromServer(true)
|
||||
_ = client.Stop()
|
||||
}()
|
||||
|
||||
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionNested {
|
||||
t.Fatalf("client steady mode = %v, want %v", got, ProtectionNested)
|
||||
}
|
||||
|
||||
reply, err := client.SendWait("nested-roundtrip", []byte("ping"), time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("SendWait failed: %v", err)
|
||||
}
|
||||
if got, want := string(reply.Value), "ack:ping"; got != want {
|
||||
t.Fatalf("reply mismatch: got %q want %q", got, want)
|
||||
}
|
||||
|
||||
list := server.GetLogicalConnList()
|
||||
if len(list) != 1 {
|
||||
t.Fatalf("logical conn count = %d, want 1", len(list))
|
||||
}
|
||||
if got := list[0].protectionModeSnapshot(); got != ProtectionNested {
|
||||
t.Fatalf("server steady mode = %v, want %v", got, ProtectionNested)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseLegacySecurityRoundTrip(t *testing.T) {
|
||||
client := NewClient()
|
||||
server := NewServer()
|
||||
|
||||
107
server.go
107
server.go
@ -10,52 +10,65 @@ import (
|
||||
)
|
||||
|
||||
type ServerCommon struct {
|
||||
msgID uint64
|
||||
alive atomic.Value
|
||||
status Status
|
||||
sessionOwnerState atomic.Int32
|
||||
sessionRuntime atomic.Pointer[serverSessionRuntime]
|
||||
listener net.Listener
|
||||
udpListener *net.UDPConn
|
||||
queue *stario.StarQueue
|
||||
stopFn context.CancelFunc
|
||||
stopCtx context.Context
|
||||
maxReadTimeout time.Duration
|
||||
maxWriteTimeout time.Duration
|
||||
parallelNum int
|
||||
wg stario.WaitGroup
|
||||
peerRegistry *serverPeerRegistry
|
||||
mu sync.RWMutex
|
||||
handshakeRsaKey []byte
|
||||
SecretKey []byte
|
||||
defaultMsgEn func([]byte, []byte) []byte
|
||||
defaultMsgDe func([]byte, []byte) []byte
|
||||
defaultFastStreamEncode transportFastStreamEncoder
|
||||
defaultFastBulkEncode transportFastBulkEncoder
|
||||
defaultFastPlainEncode transportFastPlainEncoder
|
||||
defaultModernPSKRuntime *modernPSKCodecRuntime
|
||||
linkFns map[string]func(message *Message)
|
||||
defaultFns func(message *Message)
|
||||
noFinSyncMsgMaxKeepSeconds int64
|
||||
maxHeartbeatLostSeconds int64
|
||||
sequenceDe func([]byte) (interface{}, error)
|
||||
sequenceEn func(interface{}) ([]byte, error)
|
||||
logicalSession *logicalSessionState
|
||||
onFileEvent func(FileEvent)
|
||||
fileEventObserver func(FileEvent)
|
||||
fileTransferCfg fileTransferConfig
|
||||
signalReliableCfg signalReliabilityConfig
|
||||
streamRuntime *streamRuntime
|
||||
recordRuntime *recordRuntime
|
||||
bulkRuntime *bulkRuntime
|
||||
bulkOpenTuning BulkOpenTuning
|
||||
bulkDedicatedSidecarMu sync.Mutex
|
||||
bulkDedicatedSidecars map[*LogicalConn]map[uint32]*bulkDedicatedSidecar
|
||||
connectionRetryState *connectionRetryState
|
||||
detachedClientKeepSeconds int64
|
||||
securityReadyCheck bool
|
||||
showError bool
|
||||
debugMode bool
|
||||
msgID uint64
|
||||
alive atomic.Value
|
||||
status Status
|
||||
sessionOwnerState atomic.Int32
|
||||
sessionRuntime atomic.Pointer[serverSessionRuntime]
|
||||
listener net.Listener
|
||||
udpListener *net.UDPConn
|
||||
queue *stario.StarQueue
|
||||
stopFn context.CancelFunc
|
||||
stopCtx context.Context
|
||||
maxReadTimeout time.Duration
|
||||
maxWriteTimeout time.Duration
|
||||
parallelNum int
|
||||
wg stario.WaitGroup
|
||||
peerRegistry *serverPeerRegistry
|
||||
mu sync.RWMutex
|
||||
handshakeRsaKey []byte
|
||||
SecretKey []byte
|
||||
defaultMsgEn func([]byte, []byte) []byte
|
||||
defaultMsgDe func([]byte, []byte) []byte
|
||||
defaultFastStreamEncode transportFastStreamEncoder
|
||||
defaultFastBulkEncode transportFastBulkEncoder
|
||||
defaultFastPlainEncode transportFastPlainEncoder
|
||||
defaultModernPSKRuntime *modernPSKCodecRuntime
|
||||
peerAttachSecurity atomic.Pointer[peerAttachSecurityState]
|
||||
securityBootstrap transportProtectionProfile
|
||||
securitySteady transportProtectionProfile
|
||||
securityAuthMode AuthMode
|
||||
securityProtectionMode ProtectionMode
|
||||
securityRequireForwardSecrecy bool
|
||||
securityConfigured bool
|
||||
peerAttachReplay peerAttachReplayCache
|
||||
peerAttachExplicitCount atomic.Int64
|
||||
peerAttachAuthFallbackCount atomic.Int64
|
||||
peerAttachAuthRejectCount atomic.Int64
|
||||
peerAttachDowngradeRejectCount atomic.Int64
|
||||
peerAttachBindingRejectCount atomic.Int64
|
||||
linkFns map[string]func(message *Message)
|
||||
defaultFns func(message *Message)
|
||||
noFinSyncMsgMaxKeepSeconds int64
|
||||
maxHeartbeatLostSeconds int64
|
||||
sequenceDe func([]byte) (interface{}, error)
|
||||
sequenceEn func(interface{}) ([]byte, error)
|
||||
logicalSession *logicalSessionState
|
||||
onFileEvent func(FileEvent)
|
||||
fileEventObserver func(FileEvent)
|
||||
fileTransferCfg fileTransferConfig
|
||||
signalReliableCfg signalReliabilityConfig
|
||||
streamRuntime *streamRuntime
|
||||
recordRuntime *recordRuntime
|
||||
bulkRuntime *bulkRuntime
|
||||
bulkOpenTuning BulkOpenTuning
|
||||
bulkDedicatedSidecarMu sync.Mutex
|
||||
bulkDedicatedSidecars map[*LogicalConn]map[uint32]*bulkDedicatedSidecar
|
||||
connectionRetryState *connectionRetryState
|
||||
detachedClientKeepSeconds int64
|
||||
securityReadyCheck bool
|
||||
showError bool
|
||||
debugMode bool
|
||||
}
|
||||
|
||||
func NewServer() Server {
|
||||
@ -93,6 +106,8 @@ func NewServer() Server {
|
||||
server.defaultFns = func(message *Message) {
|
||||
return
|
||||
}
|
||||
server.setServerDefaultTransportProtectionProfile(defaultTransportProtectionProfile())
|
||||
server.peerAttachSecurity.Store(defaultPeerAttachSecurityState())
|
||||
server.sessionRuntime.Store(newServerSessionRuntimeBase(server.stopCtx, server.stopFn))
|
||||
bindServerStreamControl(&server)
|
||||
bindServerBulkControl(&server)
|
||||
|
||||
@ -413,7 +413,7 @@ func serverBulkWriteSender(s *ServerCommon, logical *LogicalConn, transport *Tra
|
||||
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return s.sendDedicatedBulkWrite(ctx, bulk.LogicalConn(), bulk, startSeq, payload)
|
||||
return s.sendDedicatedBulkWrite(ctx, bulk.LogicalConn(), bulk, startSeq, payload, payloadOwned)
|
||||
}
|
||||
if transport == nil {
|
||||
return 0, errBulkTransportNil
|
||||
|
||||
@ -31,22 +31,28 @@ func (s *ServerCommon) Stop() error {
|
||||
// Deprecated: SetDefaultCommEncode overrides the transport codec directly.
|
||||
// Prefer UseModernPSKServer or UseLegacySecurityServer.
|
||||
func (s *ServerCommon) SetDefaultCommEncode(fn func([]byte, []byte) []byte) {
|
||||
s.defaultMsgEn = fn
|
||||
s.defaultFastStreamEncode = nil
|
||||
s.defaultFastBulkEncode = nil
|
||||
s.defaultFastPlainEncode = nil
|
||||
s.defaultModernPSKRuntime = nil
|
||||
profile := transportProtectionProfile{
|
||||
mode: ProtectionManaged,
|
||||
msgEn: fn,
|
||||
msgDe: s.defaultMsgDe,
|
||||
secretKey: s.SecretKey,
|
||||
}
|
||||
s.setServerDefaultTransportProtectionProfile(profile)
|
||||
s.clearServerSecurityProfiles()
|
||||
s.securityReadyCheck = false
|
||||
}
|
||||
|
||||
// Deprecated: SetDefaultCommDecode overrides the transport codec directly.
|
||||
// Prefer UseModernPSKServer or UseLegacySecurityServer.
|
||||
func (s *ServerCommon) SetDefaultCommDecode(fn func([]byte, []byte) []byte) {
|
||||
s.defaultMsgDe = fn
|
||||
s.defaultFastStreamEncode = nil
|
||||
s.defaultFastBulkEncode = nil
|
||||
s.defaultFastPlainEncode = nil
|
||||
s.defaultModernPSKRuntime = nil
|
||||
profile := transportProtectionProfile{
|
||||
mode: ProtectionManaged,
|
||||
msgEn: s.defaultMsgEn,
|
||||
msgDe: fn,
|
||||
secretKey: s.SecretKey,
|
||||
}
|
||||
s.setServerDefaultTransportProtectionProfile(profile)
|
||||
s.clearServerSecurityProfiles()
|
||||
s.securityReadyCheck = false
|
||||
}
|
||||
|
||||
@ -98,14 +104,21 @@ func (s *ServerCommon) GetSecretKey() []byte {
|
||||
// Deprecated: SetSecretKey injects a raw transport key directly.
|
||||
// Prefer UseModernPSKServer or UseLegacySecurityServer.
|
||||
func (s *ServerCommon) SetSecretKey(key []byte) {
|
||||
s.SecretKey = key
|
||||
if len(key) == 0 {
|
||||
s.defaultModernPSKRuntime = nil
|
||||
} else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil {
|
||||
s.defaultModernPSKRuntime = runtime
|
||||
} else {
|
||||
s.defaultModernPSKRuntime = nil
|
||||
profile := transportProtectionProfile{
|
||||
mode: ProtectionManaged,
|
||||
msgEn: s.defaultMsgEn,
|
||||
msgDe: s.defaultMsgDe,
|
||||
secretKey: cloneTransportProtectionKey(key),
|
||||
}
|
||||
if len(key) == 0 {
|
||||
profile.runtime = nil
|
||||
} else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil {
|
||||
profile.runtime = runtime
|
||||
} else {
|
||||
profile.runtime = nil
|
||||
}
|
||||
s.setServerDefaultTransportProtectionProfile(profile)
|
||||
s.clearServerSecurityProfiles()
|
||||
s.securityReadyCheck = len(key) == 0
|
||||
}
|
||||
|
||||
|
||||
@ -3,12 +3,54 @@ package notify
|
||||
import (
|
||||
"b612.me/stario"
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func readServerEnvelopeFromConnWithProfile(t *testing.T, server *ServerCommon, profile transportProtectionProfile, conn net.Conn, timeout time.Duration) Envelope {
|
||||
t.Helper()
|
||||
|
||||
queue := stario.NewQueue()
|
||||
deadline := time.Now().Add(timeout)
|
||||
buf := make([]byte, 4096)
|
||||
for time.Now().Before(deadline) {
|
||||
if err := conn.SetReadDeadline(deadline); err != nil {
|
||||
t.Fatalf("SetReadDeadline failed: %v", err)
|
||||
}
|
||||
n, err := conn.Read(buf)
|
||||
if n > 0 {
|
||||
if parseErr := queue.ParseMessage(buf[:n], "server-inbound-profile"); parseErr != nil {
|
||||
t.Fatalf("ParseMessage failed: %v", parseErr)
|
||||
}
|
||||
select {
|
||||
case msg := <-queue.RestoreChan():
|
||||
plain, decErr := decryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, msg.Msg)
|
||||
if decErr != nil {
|
||||
t.Fatalf("decryptTransportPayloadCodec failed: %v", decErr)
|
||||
}
|
||||
env, decErr := server.decodeEnvelopePlain(plain)
|
||||
if decErr != nil {
|
||||
t.Fatalf("decodeEnvelopePlain failed: %v", decErr)
|
||||
}
|
||||
return env
|
||||
default:
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
break
|
||||
}
|
||||
t.Fatalf("conn Read failed: %v", err)
|
||||
}
|
||||
t.Fatal("timed out waiting for server envelope")
|
||||
return Envelope{}
|
||||
}
|
||||
|
||||
func TestMessageReplyUsesInboundConnForStaleTransport(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
UseLegacySecurityServer(server)
|
||||
@ -85,6 +127,64 @@ func TestMessageReplyUsesInboundConnForStaleTransport(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageReplyUsesInboundProtectionSnapshotAfterLogicalSwitch(t *testing.T) {
|
||||
secret := []byte("correct horse battery staple")
|
||||
alternate, err := deriveModernPSKProtectionProfile([]byte("notify-reply-snapshot-alternate"), testModernPSKOptions(), ProtectionManaged)
|
||||
if err != nil {
|
||||
t.Fatalf("deriveModernPSKProtectionProfile failed: %v", err)
|
||||
}
|
||||
handlerErr := make(chan error, 1)
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
server.SetLink("reply-snapshot", func(msg *Message) {
|
||||
if msg == nil || msg.LogicalConn == nil {
|
||||
select {
|
||||
case handlerErr <- errors.New("reply-snapshot logical is nil"):
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
msg.LogicalConn.applyTransportProtectionProfile(alternate)
|
||||
if err := msg.Reply([]byte("ack")); err != nil {
|
||||
select {
|
||||
case handlerErr <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
|
||||
left, right := net.Pipe()
|
||||
defer right.Close()
|
||||
bootstrapPeerAttachConnForTest(t, server, right)
|
||||
if err := client.ConnectByConn(left); err != nil {
|
||||
t.Fatalf("ConnectByConn failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
client.setByeFromServer(true)
|
||||
_ = client.Stop()
|
||||
}()
|
||||
|
||||
reply, err := client.SendWait("reply-snapshot", []byte("ping"), time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("SendWait failed: %v", err)
|
||||
}
|
||||
if got, want := string(reply.Value), "ack"; got != want {
|
||||
t.Fatalf("reply value = %q, want %q", got, want)
|
||||
}
|
||||
select {
|
||||
case err := <-handlerErr:
|
||||
t.Fatalf("reply-snapshot handler failed: %v", err)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerHandleReceivedSignalReliabilityUsesInboundConnForStaleTransport(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
UseLegacySecurityServer(server)
|
||||
|
||||
@ -93,26 +93,34 @@ func (s *ServerCommon) pushTransportPayloadSourceFast(payload []byte, release fu
|
||||
}
|
||||
return true
|
||||
}
|
||||
if s.tryDispatchBorrowedBulkTransportPayload(source, payload) {
|
||||
logical, transport := s.resolveInboundSource(source)
|
||||
if logical == nil {
|
||||
if release != nil {
|
||||
release()
|
||||
}
|
||||
return true
|
||||
}
|
||||
owned := append([]byte(nil), payload...)
|
||||
if release != nil {
|
||||
release()
|
||||
plain, plainRelease, err := s.decryptTransportPayloadLogicalPooled(logical, payload, release)
|
||||
if err != nil {
|
||||
if s.showError || s.debugMode {
|
||||
fmt.Println("server decode transport payload error", err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
inboundConn := serverInboundConn(source)
|
||||
if s.tryDispatchBorrowedTransportPlain(logical, transport, inboundConn, plain, plainRelease) {
|
||||
return true
|
||||
}
|
||||
owned := plain
|
||||
if plainRelease != nil {
|
||||
owned = append([]byte(nil), plain...)
|
||||
plainRelease()
|
||||
}
|
||||
s.wg.Add(1)
|
||||
if !dispatcher.Dispatch(serverInboundDispatchSource(source), func() {
|
||||
defer s.wg.Done()
|
||||
logical, transport := s.resolveInboundSource(source)
|
||||
if logical == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
inboundConn := serverInboundConn(source)
|
||||
if err := s.dispatchInboundTransportPayload(logical, transport, inboundConn, owned, now); err != nil && (s.showError || s.debugMode) {
|
||||
if err := s.dispatchInboundTransportPlain(logical, transport, inboundConn, owned, now); err != nil && (s.showError || s.debugMode) {
|
||||
fmt.Println("server decode envelope error", err)
|
||||
}
|
||||
}) {
|
||||
|
||||
@ -358,16 +358,20 @@ func (s *ServerCommon) sendEnvelopeTransport(transport *TransportConn, env Envel
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendEnvelopeInboundTransport(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope) error {
|
||||
return s.sendEnvelopeInboundTransportWithProfile(logical, transport, conn, nil, env)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendEnvelopeInboundTransportWithProfile(logical *LogicalConn, transport *TransportConn, conn net.Conn, profile *transportProtectionProfile, env Envelope) error {
|
||||
if logical == nil && transport != nil {
|
||||
logical = transport.logicalConnSnapshot()
|
||||
}
|
||||
if logical == nil {
|
||||
return transportDetachedErrorForPeer(logical, transport)
|
||||
}
|
||||
if logical.msgEnSnapshot() == nil {
|
||||
if profile == nil && logical.msgEnSnapshot() == nil {
|
||||
return transportDetachedErrorForPeer(logical, transport)
|
||||
}
|
||||
payload, err := s.encodeEnvelopePayloadLogical(logical, env)
|
||||
payload, err := s.encodeEnvelopePayloadInbound(logical, env, profile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -402,7 +406,18 @@ func (s *ServerCommon) writeControlEnvelopePayload(logical *LogicalConn, transpo
|
||||
return sender.submit(payload, writeDeadlineFromTimeout(logical.maxWriteTimeoutSnapshot()))
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, msg TransferMsg) error {
|
||||
func (s *ServerCommon) encodeEnvelopePayloadInbound(logical *LogicalConn, env Envelope, profile *transportProtectionProfile) ([]byte, error) {
|
||||
if profile == nil {
|
||||
return s.encodeEnvelopePayloadLogical(logical, env)
|
||||
}
|
||||
data, err := s.encodeEnvelopePlain(env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgEn, profile.secretKey, data)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, profile *transportProtectionProfile, msg TransferMsg) error {
|
||||
if logical == nil && transport != nil {
|
||||
logical = transport.logicalConnSnapshot()
|
||||
}
|
||||
@ -413,7 +428,7 @@ func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *Tran
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.sendEnvelopeInboundTransport(logical, transport, conn, env)
|
||||
return s.sendEnvelopeInboundTransportWithProfile(logical, transport, conn, profile, env)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) writeEnvelopePayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte) error {
|
||||
|
||||
@ -110,7 +110,17 @@ func (s *ServerCommon) registerAcceptedLogical(logical *LogicalConn) *LogicalCon
|
||||
}
|
||||
logical.setServer(s)
|
||||
logical.applyAttachmentProfile(s.maxReadTimeout, s.maxWriteTimeout, s.defaultMsgEn, s.defaultMsgDe, s.defaultFastStreamEncode, s.defaultFastBulkEncode, s.defaultFastPlainEncode, s.handshakeRsaKey, s.SecretKey)
|
||||
logical.setModernPSKRuntime(s.defaultModernPSKRuntime)
|
||||
logical.updateAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.authMode = s.securityAuthMode
|
||||
state.peerAttached = false
|
||||
state.peerAttachFallback = false
|
||||
state.peerAttachAt = 0
|
||||
})
|
||||
if s.securityConfigured {
|
||||
logical.applyTransportProtectionProfile(s.securityBootstrap)
|
||||
} else {
|
||||
logical.setModernPSKRuntime(s.defaultModernPSKRuntime)
|
||||
}
|
||||
logical.markHeartbeatNow()
|
||||
return s.getPeerRegistry().registerLogical(logical)
|
||||
}
|
||||
|
||||
@ -26,6 +26,8 @@ type Server interface {
|
||||
RecoverTransferSnapshots(context.Context) error
|
||||
SetBulkOpenTuning(BulkOpenTuning)
|
||||
BulkOpenTuning() BulkOpenTuning
|
||||
SetPeerAttachSecurityConfig(PeerAttachSecurityConfig) error
|
||||
PeerAttachSecurityConfig() PeerAttachSecurityConfig
|
||||
SetFileReceiveDir(dir string) error
|
||||
send(c *ClientConn, msg TransferMsg) (WaitMsg, error)
|
||||
sendEnvelope(c *ClientConn, env Envelope) error
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
@ -17,6 +18,19 @@ type ClientRuntimeSnapshot struct {
|
||||
ConnectNetwork string
|
||||
ConnectAddress string
|
||||
CanReconnect bool
|
||||
AuthMode string
|
||||
ProtectionMode string
|
||||
ProtectionKeyMode string
|
||||
ForwardSecrecyEnabled bool
|
||||
ForwardSecrecyFallback bool
|
||||
ForwardSecrecyRequired bool
|
||||
TransportSessionID string
|
||||
PeerAttachAuthenticated bool
|
||||
PeerAttachAuthFallback bool
|
||||
LastPeerAttachAt time.Time
|
||||
PeerAttachRequireExplicitAuth bool
|
||||
PeerAttachRequireChannelBinding bool
|
||||
PeerAttachChannelBindingConfigured bool
|
||||
BulkNetworkProfile string
|
||||
BulkDefaultMode string
|
||||
BulkChunkSize int
|
||||
@ -37,22 +51,38 @@ type ClientRuntimeSnapshot struct {
|
||||
}
|
||||
|
||||
type ServerRuntimeSnapshot struct {
|
||||
OwnerState string
|
||||
Alive bool
|
||||
ClientCount int
|
||||
DetachedClientCount int
|
||||
DetachedReattachableClientCount int
|
||||
DetachedExpiredClientCount int
|
||||
DetachedClientKeepSec int64
|
||||
TransportAttached bool
|
||||
HasRuntimeListener bool
|
||||
HasRuntimeUDPListener bool
|
||||
HasRuntimeQueue bool
|
||||
HasRuntimeStopCtx bool
|
||||
BulkChunkSize int
|
||||
BulkWindowBytes int
|
||||
BulkMaxInFlight int
|
||||
Retry ConnectionRetrySnapshot
|
||||
OwnerState string
|
||||
Alive bool
|
||||
ClientCount int
|
||||
DetachedClientCount int
|
||||
DetachedReattachableClientCount int
|
||||
DetachedExpiredClientCount int
|
||||
DetachedClientKeepSec int64
|
||||
TransportAttached bool
|
||||
HasRuntimeListener bool
|
||||
HasRuntimeUDPListener bool
|
||||
HasRuntimeQueue bool
|
||||
HasRuntimeStopCtx bool
|
||||
AuthMode string
|
||||
ProtectionMode string
|
||||
ForwardSecrecySupported bool
|
||||
ForwardSecrecyRequired bool
|
||||
PeerAttachRequireExplicitAuth bool
|
||||
PeerAttachRequireChannelBinding bool
|
||||
PeerAttachChannelBindingConfigured bool
|
||||
PeerAttachReplayWindow time.Duration
|
||||
PeerAttachReplayCapacity int
|
||||
PeerAttachExplicitAuth int64
|
||||
PeerAttachAuthFallbacks int64
|
||||
PeerAttachAuthRejects int64
|
||||
PeerAttachDowngradeRejects int64
|
||||
PeerAttachBindingRejects int64
|
||||
PeerAttachReplayRejects int64
|
||||
PeerAttachReplayOverflowRejects int64
|
||||
BulkChunkSize int
|
||||
BulkWindowBytes int
|
||||
BulkMaxInFlight int
|
||||
Retry ConnectionRetrySnapshot
|
||||
}
|
||||
|
||||
type ClientConnRuntimeSnapshot struct {
|
||||
@ -82,6 +112,15 @@ type ClientConnRuntimeSnapshot struct {
|
||||
TransportDetachRemaining time.Duration
|
||||
TransportDetachExpired bool
|
||||
ReattachEligible bool
|
||||
AuthMode string
|
||||
ProtectionMode string
|
||||
ProtectionKeyMode string
|
||||
ForwardSecrecyEnabled bool
|
||||
ForwardSecrecyFallback bool
|
||||
TransportSessionID string
|
||||
PeerAttachAuthenticated bool
|
||||
PeerAttachAuthFallback bool
|
||||
LastPeerAttachAt time.Time
|
||||
TransportBulkAdaptiveSoftPayloadBytes int
|
||||
TransportStreamAdaptiveSoftPayloadBytes int
|
||||
TransportStreamAdaptiveWaitThresholdBytes int
|
||||
@ -108,6 +147,19 @@ func (c *ClientCommon) clientRuntimeSnapshot() ClientRuntimeSnapshot {
|
||||
snapshot.ConnectAddress = source.addr
|
||||
snapshot.CanReconnect = source.canReconnect()
|
||||
}
|
||||
snapshot.AuthMode = authModeName(c.securityAuthMode)
|
||||
snapshot.ProtectionMode = protectionModeName(c.securityProtectionMode)
|
||||
protection := c.clientTransportProtectionSnapshot()
|
||||
snapshot.ProtectionKeyMode = protection.keyMode
|
||||
snapshot.ForwardSecrecyEnabled = protection.forwardSecrecy
|
||||
snapshot.ForwardSecrecyFallback = protection.forwardSecrecyFallback
|
||||
snapshot.ForwardSecrecyRequired = c.clientRequiresForwardSecrecy()
|
||||
snapshot.TransportSessionID = hex.EncodeToString(protection.sessionID)
|
||||
snapshot.PeerAttachAuthenticated, snapshot.PeerAttachAuthFallback, snapshot.LastPeerAttachAt = c.clientPeerAttachAuthSnapshot()
|
||||
peerAttachCfg := c.peerAttachSecuritySnapshot()
|
||||
snapshot.PeerAttachRequireExplicitAuth = peerAttachCfg.requireExplicitAuth
|
||||
snapshot.PeerAttachRequireChannelBinding = peerAttachCfg.requireChannelBinding
|
||||
snapshot.PeerAttachChannelBindingConfigured = peerAttachCfg.channelBinding != nil
|
||||
snapshot.BulkNetworkProfile = bulkNetworkProfileName(c.BulkNetworkProfile())
|
||||
snapshot.BulkDefaultMode = bulkOpenModeName(c.BulkDefaultOpenMode())
|
||||
tuning := c.BulkOpenTuning()
|
||||
@ -161,6 +213,23 @@ func (s *ServerCommon) serverRuntimeSnapshot() ServerRuntimeSnapshot {
|
||||
snapshot.HasRuntimeQueue = rt.queue != nil
|
||||
snapshot.HasRuntimeStopCtx = rt.stopCtx != nil
|
||||
}
|
||||
snapshot.AuthMode = authModeName(s.securityAuthMode)
|
||||
snapshot.ProtectionMode = protectionModeName(s.securityProtectionMode)
|
||||
snapshot.ForwardSecrecySupported = s.serverSupportsForwardSecrecy()
|
||||
snapshot.ForwardSecrecyRequired = s.serverRequiresForwardSecrecy()
|
||||
peerAttachCfg := s.peerAttachSecuritySnapshot()
|
||||
snapshot.PeerAttachRequireExplicitAuth = peerAttachCfg.requireExplicitAuth
|
||||
snapshot.PeerAttachRequireChannelBinding = peerAttachCfg.requireChannelBinding
|
||||
snapshot.PeerAttachChannelBindingConfigured = peerAttachCfg.channelBinding != nil
|
||||
snapshot.PeerAttachReplayWindow = peerAttachCfg.replayWindow
|
||||
snapshot.PeerAttachReplayCapacity = peerAttachCfg.replayCapacity
|
||||
snapshot.PeerAttachExplicitAuth = s.peerAttachExplicitCount.Load()
|
||||
snapshot.PeerAttachAuthFallbacks = s.peerAttachAuthFallbackCount.Load()
|
||||
snapshot.PeerAttachAuthRejects = s.peerAttachAuthRejectCount.Load()
|
||||
snapshot.PeerAttachDowngradeRejects = s.peerAttachDowngradeRejectCount.Load()
|
||||
snapshot.PeerAttachBindingRejects = s.peerAttachBindingRejectCount.Load()
|
||||
snapshot.PeerAttachReplayRejects = s.peerAttachReplayRejectCountSnapshot()
|
||||
snapshot.PeerAttachReplayOverflowRejects = s.peerAttachReplayOverflowRejectCountSnapshot()
|
||||
tuning := s.BulkOpenTuning()
|
||||
snapshot.BulkChunkSize = tuning.ChunkSize
|
||||
snapshot.BulkWindowBytes = tuning.WindowBytes
|
||||
@ -171,6 +240,7 @@ func (s *ServerCommon) serverRuntimeSnapshot() ServerRuntimeSnapshot {
|
||||
|
||||
func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot {
|
||||
status := c.clientConnStatusSnapshot()
|
||||
attachment := c.clientConnAttachmentStateSnapshot()
|
||||
now := time.Now()
|
||||
snapshot := ClientConnRuntimeSnapshot{
|
||||
ClientID: c.clientConnIDSnapshot(),
|
||||
@ -182,6 +252,17 @@ func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot {
|
||||
TransportAttachCount: c.clientConnTransportAttachCountSnapshot(),
|
||||
TransportDetachCount: c.clientConnTransportDetachCountSnapshot(),
|
||||
LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(),
|
||||
AuthMode: authModeName(attachment.authMode),
|
||||
ProtectionMode: protectionModeName(attachment.protectionMode),
|
||||
}
|
||||
snapshot.PeerAttachAuthenticated = attachment.peerAttached
|
||||
snapshot.PeerAttachAuthFallback = attachment.peerAttachFallback
|
||||
snapshot.ProtectionKeyMode = attachment.keyMode
|
||||
snapshot.ForwardSecrecyEnabled = attachment.forwardSecrecy
|
||||
snapshot.ForwardSecrecyFallback = attachment.forwardSecrecyFallback
|
||||
snapshot.TransportSessionID = hex.EncodeToString(attachment.sessionID)
|
||||
if attachment.peerAttachAt != 0 {
|
||||
snapshot.LastPeerAttachAt = time.Unix(0, attachment.peerAttachAt)
|
||||
}
|
||||
if status.Err != nil {
|
||||
snapshot.Error = status.Err.Error()
|
||||
|
||||
@ -37,6 +37,21 @@ func TestGetClientRuntimeSnapshotDefaults(t *testing.T) {
|
||||
if snapshot.ConnectSource != "" || snapshot.ConnectNetwork != "" || snapshot.ConnectAddress != "" || snapshot.CanReconnect {
|
||||
t.Fatalf("unexpected default connect source snapshot: %+v", snapshot)
|
||||
}
|
||||
if got, want := snapshot.AuthMode, "none"; got != want {
|
||||
t.Fatalf("AuthMode mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := snapshot.ProtectionMode, "managed"; got != want {
|
||||
t.Fatalf("ProtectionMode mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if snapshot.PeerAttachAuthenticated || snapshot.PeerAttachAuthFallback {
|
||||
t.Fatalf("unexpected default peer attach state: %+v", snapshot)
|
||||
}
|
||||
if !snapshot.LastPeerAttachAt.IsZero() {
|
||||
t.Fatalf("LastPeerAttachAt mismatch: got %v want zero", snapshot.LastPeerAttachAt)
|
||||
}
|
||||
if snapshot.PeerAttachRequireExplicitAuth || snapshot.PeerAttachRequireChannelBinding || snapshot.PeerAttachChannelBindingConfigured {
|
||||
t.Fatalf("unexpected default peer attach policy: %+v", snapshot)
|
||||
}
|
||||
if got, want := snapshot.BulkNetworkProfile, "default"; got != want {
|
||||
t.Fatalf("BulkNetworkProfile mismatch: got %q want %q", got, want)
|
||||
}
|
||||
@ -117,6 +132,24 @@ func TestGetServerRuntimeSnapshotDefaults(t *testing.T) {
|
||||
if !snapshot.HasRuntimeStopCtx {
|
||||
t.Fatalf("HasRuntimeStopCtx mismatch: got %v want true", snapshot.HasRuntimeStopCtx)
|
||||
}
|
||||
if got, want := snapshot.AuthMode, "none"; got != want {
|
||||
t.Fatalf("AuthMode mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := snapshot.ProtectionMode, "managed"; got != want {
|
||||
t.Fatalf("ProtectionMode mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if snapshot.PeerAttachRequireExplicitAuth || snapshot.PeerAttachRequireChannelBinding || snapshot.PeerAttachChannelBindingConfigured {
|
||||
t.Fatalf("unexpected default peer attach policy: %+v", snapshot)
|
||||
}
|
||||
if got, want := snapshot.PeerAttachReplayWindow, peerAttachReplayTTL; got != want {
|
||||
t.Fatalf("PeerAttachReplayWindow mismatch: got %s want %s", got, want)
|
||||
}
|
||||
if got, want := snapshot.PeerAttachReplayCapacity, defaultPeerAttachReplayCapacity; got != want {
|
||||
t.Fatalf("PeerAttachReplayCapacity mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if snapshot.PeerAttachExplicitAuth != 0 || snapshot.PeerAttachAuthFallbacks != 0 || snapshot.PeerAttachAuthRejects != 0 || snapshot.PeerAttachDowngradeRejects != 0 || snapshot.PeerAttachBindingRejects != 0 || snapshot.PeerAttachReplayRejects != 0 || snapshot.PeerAttachReplayOverflowRejects != 0 {
|
||||
t.Fatalf("unexpected default peer attach counters: %+v", snapshot)
|
||||
}
|
||||
if got, want := snapshot.BulkChunkSize, defaultBulkChunkSize; got != want {
|
||||
t.Fatalf("BulkChunkSize mismatch: got %d want %d", got, want)
|
||||
}
|
||||
@ -476,6 +509,134 @@ func TestGetClientConnRuntimeSnapshotExposesDetachState(t *testing.T) {
|
||||
if snapshot.LastHeartbeatAt.IsZero() {
|
||||
t.Fatal("LastHeartbeatAt should be recorded")
|
||||
}
|
||||
if got, want := snapshot.AuthMode, "none"; got != want {
|
||||
t.Fatalf("AuthMode mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := snapshot.ProtectionMode, "managed"; got != want {
|
||||
t.Fatalf("ProtectionMode mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if snapshot.PeerAttachAuthenticated || snapshot.PeerAttachAuthFallback {
|
||||
t.Fatalf("unexpected peer attach state: %+v", snapshot)
|
||||
}
|
||||
if !snapshot.LastPeerAttachAt.IsZero() {
|
||||
t.Fatalf("LastPeerAttachAt mismatch: got %v want zero", snapshot.LastPeerAttachAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRuntimeSnapshotsIncludePeerAttachSecurityState(t *testing.T) {
|
||||
secret := []byte("correct horse battery staple")
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
if err := UsePSKOverExternalTransportServer(server, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
|
||||
}
|
||||
})
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UsePSKOverExternalTransportClient(client, secret, testModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
|
||||
}
|
||||
|
||||
left, right := net.Pipe()
|
||||
defer right.Close()
|
||||
bootstrapPeerAttachLogicalForTest(t, server, right)
|
||||
if err := client.ConnectByConn(left); err != nil {
|
||||
t.Fatalf("ConnectByConn failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
client.setByeFromServer(true)
|
||||
_ = client.Stop()
|
||||
}()
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for {
|
||||
logical := server.GetLogicalConn(client.peerIdentity)
|
||||
if logical != nil {
|
||||
authenticated, fallback, _ := logical.peerAttachAuthenticatedSnapshot()
|
||||
if authenticated && !fallback && logical.protectionModeSnapshot() == ProtectionExternal && server.peerAttachExplicitCount.Load() == 1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("peer attach security state did not converge before snapshot")
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
clientSnapshot, err := GetClientRuntimeSnapshot(client)
|
||||
if err != nil {
|
||||
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
|
||||
}
|
||||
if got, want := clientSnapshot.AuthMode, "psk"; got != want {
|
||||
t.Fatalf("client AuthMode mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := clientSnapshot.ProtectionMode, "external"; got != want {
|
||||
t.Fatalf("client ProtectionMode mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if clientSnapshot.PeerAttachRequireExplicitAuth || clientSnapshot.PeerAttachRequireChannelBinding || clientSnapshot.PeerAttachChannelBindingConfigured {
|
||||
t.Fatalf("unexpected client peer attach policy snapshot: %+v", clientSnapshot)
|
||||
}
|
||||
if !clientSnapshot.PeerAttachAuthenticated || clientSnapshot.PeerAttachAuthFallback {
|
||||
t.Fatalf("unexpected client peer attach state: %+v", clientSnapshot)
|
||||
}
|
||||
if clientSnapshot.LastPeerAttachAt.IsZero() {
|
||||
t.Fatal("client LastPeerAttachAt should be recorded")
|
||||
}
|
||||
|
||||
serverSnapshot, err := GetServerRuntimeSnapshot(server)
|
||||
if err != nil {
|
||||
t.Fatalf("GetServerRuntimeSnapshot failed: %v", err)
|
||||
}
|
||||
if got, want := serverSnapshot.AuthMode, "psk"; got != want {
|
||||
t.Fatalf("server AuthMode mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := serverSnapshot.ProtectionMode, "external"; got != want {
|
||||
t.Fatalf("server ProtectionMode mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if serverSnapshot.PeerAttachRequireExplicitAuth || serverSnapshot.PeerAttachRequireChannelBinding || serverSnapshot.PeerAttachChannelBindingConfigured {
|
||||
t.Fatalf("unexpected server peer attach policy snapshot: %+v", serverSnapshot)
|
||||
}
|
||||
if got, want := serverSnapshot.PeerAttachExplicitAuth, int64(1); got != want {
|
||||
t.Fatalf("PeerAttachExplicitAuth mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if serverSnapshot.PeerAttachAuthFallbacks != 0 || serverSnapshot.PeerAttachAuthRejects != 0 || serverSnapshot.PeerAttachDowngradeRejects != 0 || serverSnapshot.PeerAttachBindingRejects != 0 || serverSnapshot.PeerAttachReplayRejects != 0 || serverSnapshot.PeerAttachReplayOverflowRejects != 0 {
|
||||
t.Fatalf("unexpected server peer attach counters: %+v", serverSnapshot)
|
||||
}
|
||||
|
||||
logical := server.GetLogicalConn(client.peerIdentity)
|
||||
if logical == nil {
|
||||
t.Fatal("server logical should exist after peer attach")
|
||||
}
|
||||
logicalSnapshot, err := GetLogicalConnRuntimeSnapshot(logical)
|
||||
if err != nil {
|
||||
t.Fatalf("GetLogicalConnRuntimeSnapshot failed: %v", err)
|
||||
}
|
||||
if got, want := logicalSnapshot.AuthMode, "psk"; got != want {
|
||||
t.Fatalf("logical AuthMode mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := logicalSnapshot.ProtectionMode, "external"; got != want {
|
||||
t.Fatalf("logical ProtectionMode mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if !logicalSnapshot.PeerAttachAuthenticated || logicalSnapshot.PeerAttachAuthFallback {
|
||||
t.Fatalf("unexpected logical peer attach state: %+v", logicalSnapshot)
|
||||
}
|
||||
if logicalSnapshot.LastPeerAttachAt.IsZero() {
|
||||
t.Fatal("logical LastPeerAttachAt should be recorded")
|
||||
}
|
||||
|
||||
clientConnSnapshot, err := GetClientConnRuntimeSnapshot(clientConnFromLogical(logical))
|
||||
if err != nil {
|
||||
t.Fatalf("GetClientConnRuntimeSnapshot failed: %v", err)
|
||||
}
|
||||
if got, want := clientConnSnapshot.AuthMode, "psk"; got != want {
|
||||
t.Fatalf("client conn AuthMode mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := clientConnSnapshot.ProtectionMode, "external"; got != want {
|
||||
t.Fatalf("client conn ProtectionMode mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if !clientConnSnapshot.PeerAttachAuthenticated || clientConnSnapshot.PeerAttachAuthFallback {
|
||||
t.Fatalf("unexpected client conn peer attach state: %+v", clientConnSnapshot)
|
||||
}
|
||||
if clientConnSnapshot.LastPeerAttachAt.IsZero() {
|
||||
t.Fatal("client conn LastPeerAttachAt should be recorded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetServerDetachedClientRuntimeSnapshotsFiltersAndSorts(t *testing.T) {
|
||||
|
||||
104
stream.go
104
stream.go
@ -89,6 +89,60 @@ type streamCloseSender func(context.Context, *streamHandle, bool) error
|
||||
type streamResetSender func(context.Context, *streamHandle, string) error
|
||||
type streamDataSender func(context.Context, *streamHandle, []byte) error
|
||||
|
||||
type streamReadChunk struct {
|
||||
data []byte
|
||||
release func()
|
||||
}
|
||||
|
||||
func (c *streamReadChunk) clear() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if c.release != nil {
|
||||
c.release()
|
||||
}
|
||||
c.data = nil
|
||||
c.release = nil
|
||||
}
|
||||
|
||||
type streamReadPayloadOwner struct {
|
||||
refs atomic.Int32
|
||||
release func()
|
||||
}
|
||||
|
||||
func newStreamReadPayloadOwner(release func()) *streamReadPayloadOwner {
|
||||
if release == nil {
|
||||
return nil
|
||||
}
|
||||
owner := &streamReadPayloadOwner{release: release}
|
||||
owner.refs.Store(1)
|
||||
return owner
|
||||
}
|
||||
|
||||
func (o *streamReadPayloadOwner) retainChunk() func() {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
o.refs.Add(1)
|
||||
return o.releaseChunk
|
||||
}
|
||||
|
||||
func (o *streamReadPayloadOwner) releaseChunk() {
|
||||
if o == nil {
|
||||
return
|
||||
}
|
||||
if o.refs.Add(-1) == 0 && o.release != nil {
|
||||
o.release()
|
||||
}
|
||||
}
|
||||
|
||||
func (o *streamReadPayloadOwner) done() {
|
||||
if o == nil {
|
||||
return
|
||||
}
|
||||
o.releaseChunk()
|
||||
}
|
||||
|
||||
type streamHandle struct {
|
||||
runtime *streamRuntime
|
||||
runtimeScope string
|
||||
@ -124,8 +178,8 @@ type streamHandle struct {
|
||||
remoteClosed bool
|
||||
peerReadClosed bool
|
||||
resetErr error
|
||||
readQueue [][]byte
|
||||
readBuf []byte
|
||||
readQueue []streamReadChunk
|
||||
readBuf streamReadChunk
|
||||
bufferedBytes int
|
||||
readNotify chan struct{}
|
||||
readDeadline time.Time
|
||||
@ -339,20 +393,23 @@ func (s *streamHandle) Read(p []byte) (int, error) {
|
||||
for {
|
||||
s.mu.Lock()
|
||||
localReadClosed := s.localReadClosed
|
||||
if len(s.readBuf) > 0 {
|
||||
n := copy(p, s.readBuf)
|
||||
s.readBuf = s.readBuf[n:]
|
||||
if len(s.readBuf.data) > 0 {
|
||||
n := copy(p, s.readBuf.data)
|
||||
s.readBuf.data = s.readBuf.data[n:]
|
||||
s.bufferedBytes -= n
|
||||
if s.bufferedBytes < 0 {
|
||||
s.bufferedBytes = 0
|
||||
}
|
||||
if len(s.readBuf.data) == 0 {
|
||||
s.readBuf.clear()
|
||||
}
|
||||
s.recordReadLocked(n, time.Now())
|
||||
s.mu.Unlock()
|
||||
return n, nil
|
||||
}
|
||||
if len(s.readQueue) > 0 {
|
||||
s.readBuf = s.readQueue[0]
|
||||
s.readQueue[0] = nil
|
||||
s.readQueue[0] = streamReadChunk{}
|
||||
s.readQueue = s.readQueue[1:]
|
||||
s.mu.Unlock()
|
||||
continue
|
||||
@ -824,43 +881,61 @@ func (s *streamHandle) pushOwnedChunk(chunk []byte) error {
|
||||
return s.pushChunkWithOwnership(chunk, true)
|
||||
}
|
||||
|
||||
func (s *streamHandle) pushOwnedChunkWithRelease(chunk []byte, release func()) error {
|
||||
return s.pushChunkWithOwnershipAndRelease(chunk, true, release)
|
||||
}
|
||||
|
||||
func (s *streamHandle) pushChunkWithOwnership(chunk []byte, owned bool) error {
|
||||
return s.pushChunkWithOwnershipAndRelease(chunk, owned, nil)
|
||||
}
|
||||
|
||||
func (s *streamHandle) pushChunkWithOwnershipAndRelease(chunk []byte, owned bool, release func()) error {
|
||||
if s == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
if len(chunk) == 0 {
|
||||
if release != nil {
|
||||
release()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
stored := chunk
|
||||
stored := streamReadChunk{data: chunk, release: release}
|
||||
if !owned {
|
||||
stored = append([]byte(nil), chunk...)
|
||||
stored.data = append([]byte(nil), chunk...)
|
||||
if stored.release != nil {
|
||||
stored.release()
|
||||
stored.release = nil
|
||||
}
|
||||
}
|
||||
s.mu.Lock()
|
||||
if s.resetErr != nil {
|
||||
err := s.resetErr
|
||||
s.mu.Unlock()
|
||||
stored.clear()
|
||||
return err
|
||||
}
|
||||
if s.inboundQueueLimit > 0 && s.bufferedChunkCountLocked() >= s.inboundQueueLimit {
|
||||
err := s.markResetLocked(errStreamBackpressureExceeded)
|
||||
s.mu.Unlock()
|
||||
stored.clear()
|
||||
s.notifyReadable()
|
||||
s.finalize()
|
||||
return err
|
||||
}
|
||||
if s.inboundBytesLimit > 0 && s.bufferedBytes+len(stored) > s.inboundBytesLimit {
|
||||
if s.inboundBytesLimit > 0 && s.bufferedBytes+len(stored.data) > s.inboundBytesLimit {
|
||||
err := s.markResetLocked(errStreamBackpressureExceeded)
|
||||
s.mu.Unlock()
|
||||
stored.clear()
|
||||
s.notifyReadable()
|
||||
s.finalize()
|
||||
return err
|
||||
}
|
||||
if len(s.readBuf) == 0 && len(s.readQueue) == 0 {
|
||||
if len(s.readBuf.data) == 0 && len(s.readQueue) == 0 {
|
||||
s.readBuf = stored
|
||||
} else {
|
||||
s.readQueue = append(s.readQueue, stored)
|
||||
}
|
||||
s.bufferedBytes += len(stored)
|
||||
s.bufferedBytes += len(stored.data)
|
||||
s.notifyReadableLocked()
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
@ -881,11 +956,12 @@ func (s *streamHandle) clearBufferedDataLocked() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.readBuf.clear()
|
||||
for i := range s.readQueue {
|
||||
s.readQueue[i] = nil
|
||||
s.readQueue[i].clear()
|
||||
}
|
||||
s.readQueue = nil
|
||||
s.readBuf = nil
|
||||
s.readBuf = streamReadChunk{}
|
||||
s.bufferedBytes = 0
|
||||
}
|
||||
|
||||
@ -894,7 +970,7 @@ func (s *streamHandle) bufferedChunkCountLocked() int {
|
||||
return 0
|
||||
}
|
||||
count := len(s.readQueue)
|
||||
if len(s.readBuf) > 0 {
|
||||
if len(s.readBuf.data) > 0 {
|
||||
count++
|
||||
}
|
||||
return count
|
||||
|
||||
@ -56,7 +56,59 @@ func BenchmarkStreamTCPThroughput(b *testing.B) {
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg)
|
||||
benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg, benchmarkTransportSecurityModernPSK)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStreamTCPThroughputTrustedRaw(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
payloadSize int
|
||||
cfg StreamConfig
|
||||
}{
|
||||
{
|
||||
name: "default_64KiB",
|
||||
payloadSize: 64 * 1024,
|
||||
},
|
||||
{
|
||||
name: "tuned_256KiB",
|
||||
payloadSize: 256 * 1024,
|
||||
cfg: StreamConfig{
|
||||
ChunkSize: 256 * 1024,
|
||||
InboundQueueLimit: 256,
|
||||
InboundBufferedBytesLimit: 32 * 1024 * 1024,
|
||||
OutboundWindowBytes: 8 * 1024 * 1024,
|
||||
OutboundMaxInFlightChunks: 32,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tuned_512KiB",
|
||||
payloadSize: 512 * 1024,
|
||||
cfg: StreamConfig{
|
||||
ChunkSize: 512 * 1024,
|
||||
InboundQueueLimit: 256,
|
||||
InboundBufferedBytesLimit: 64 * 1024 * 1024,
|
||||
OutboundWindowBytes: 16 * 1024 * 1024,
|
||||
OutboundMaxInFlightChunks: 32,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tuned_1MiB",
|
||||
payloadSize: 1024 * 1024,
|
||||
cfg: StreamConfig{
|
||||
ChunkSize: 1024 * 1024,
|
||||
InboundQueueLimit: 256,
|
||||
InboundBufferedBytesLimit: 64 * 1024 * 1024,
|
||||
OutboundWindowBytes: 16 * 1024 * 1024,
|
||||
OutboundMaxInFlightChunks: 32,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg, benchmarkTransportSecurityTrustedRaw)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -108,19 +160,69 @@ func BenchmarkStreamTCPThroughputConcurrent(b *testing.B) {
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg)
|
||||
benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg, benchmarkTransportSecurityModernPSK)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfig) {
|
||||
func BenchmarkStreamTCPThroughputConcurrentTrustedRaw(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
payloadSize int
|
||||
concurrency int
|
||||
cfg StreamConfig
|
||||
}{
|
||||
{
|
||||
name: "streams_2_512KiB",
|
||||
payloadSize: 512 * 1024,
|
||||
concurrency: 2,
|
||||
cfg: StreamConfig{
|
||||
ChunkSize: 512 * 1024,
|
||||
InboundQueueLimit: 512,
|
||||
InboundBufferedBytesLimit: 128 * 1024 * 1024,
|
||||
OutboundWindowBytes: 32 * 1024 * 1024,
|
||||
OutboundMaxInFlightChunks: 64,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "streams_4_512KiB",
|
||||
payloadSize: 512 * 1024,
|
||||
concurrency: 4,
|
||||
cfg: StreamConfig{
|
||||
ChunkSize: 512 * 1024,
|
||||
InboundQueueLimit: 1024,
|
||||
InboundBufferedBytesLimit: 256 * 1024 * 1024,
|
||||
OutboundWindowBytes: 64 * 1024 * 1024,
|
||||
OutboundMaxInFlightChunks: 128,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "streams_8_512KiB",
|
||||
payloadSize: 512 * 1024,
|
||||
concurrency: 8,
|
||||
cfg: StreamConfig{
|
||||
ChunkSize: 512 * 1024,
|
||||
InboundQueueLimit: 2048,
|
||||
InboundBufferedBytesLimit: 512 * 1024 * 1024,
|
||||
OutboundWindowBytes: 128 * 1024 * 1024,
|
||||
OutboundMaxInFlightChunks: 256,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg, benchmarkTransportSecurityTrustedRaw)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfig, securityMode benchmarkTransportSecurityMode) {
|
||||
b.Helper()
|
||||
|
||||
server := NewServer().(*ServerCommon)
|
||||
server.SetStreamConfig(cfg)
|
||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
b.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
benchmarkApplyServerTransportSecurity(b, server, securityMode)
|
||||
|
||||
acceptCh := make(chan StreamAcceptInfo, 1)
|
||||
server.SetStreamHandler(func(info StreamAcceptInfo) error {
|
||||
@ -137,9 +239,7 @@ func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfi
|
||||
|
||||
client := NewClient().(*ClientCommon)
|
||||
client.SetStreamConfig(cfg)
|
||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
b.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
benchmarkApplyClientTransportSecurity(b, client, securityMode)
|
||||
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
|
||||
b.Fatalf("client Connect failed: %v", err)
|
||||
}
|
||||
@ -198,7 +298,7 @@ func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfi
|
||||
_ = stream.Close()
|
||||
}
|
||||
|
||||
func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, cfg StreamConfig) {
|
||||
func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, cfg StreamConfig, securityMode benchmarkTransportSecurityMode) {
|
||||
b.Helper()
|
||||
if concurrency <= 0 {
|
||||
b.Fatal("concurrency must be > 0")
|
||||
@ -206,9 +306,7 @@ func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concu
|
||||
|
||||
server := NewServer().(*ServerCommon)
|
||||
server.SetStreamConfig(cfg)
|
||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
b.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
benchmarkApplyServerTransportSecurity(b, server, securityMode)
|
||||
|
||||
acceptCh := make(chan StreamAcceptInfo, concurrency*2)
|
||||
server.SetStreamHandler(func(info StreamAcceptInfo) error {
|
||||
@ -225,9 +323,7 @@ func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concu
|
||||
|
||||
client := NewClient().(*ClientCommon)
|
||||
client.SetStreamConfig(cfg)
|
||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
b.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
benchmarkApplyClientTransportSecurity(b, client, securityMode)
|
||||
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
|
||||
b.Fatalf("client Connect failed: %v", err)
|
||||
}
|
||||
|
||||
85
stream_buffer_release_test.go
Normal file
85
stream_buffer_release_test.go
Normal file
@ -0,0 +1,85 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestStreamOwnedChunkReleaseAfterRead(t *testing.T) {
|
||||
stream := newStreamHandle(context.Background(), newStreamRuntime("stream-buffer-release-read"), clientFileScope(), StreamOpenRequest{
|
||||
StreamID: "stream-buffer-release-read",
|
||||
DataID: 1,
|
||||
}, 0, nil, nil, 0, nil, nil, nil, streamConfig{})
|
||||
|
||||
released := 0
|
||||
if err := stream.pushOwnedChunkWithRelease([]byte("hello"), func() {
|
||||
released++
|
||||
}); err != nil {
|
||||
t.Fatalf("pushOwnedChunkWithRelease failed: %v", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 5)
|
||||
n, err := stream.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
if n != 5 || string(buf[:n]) != "hello" {
|
||||
t.Fatalf("Read = %d %q, want 5 hello", n, string(buf[:n]))
|
||||
}
|
||||
if released != 1 {
|
||||
t.Fatalf("release count = %d, want 1", released)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamOwnedChunkReleaseOnReset(t *testing.T) {
|
||||
stream := newStreamHandle(context.Background(), newStreamRuntime("stream-buffer-release-reset"), clientFileScope(), StreamOpenRequest{
|
||||
StreamID: "stream-buffer-release-reset",
|
||||
DataID: 1,
|
||||
}, 0, nil, nil, 0, nil, nil, nil, streamConfig{})
|
||||
|
||||
released := 0
|
||||
if err := stream.pushOwnedChunkWithRelease([]byte("hello"), func() {
|
||||
released++
|
||||
}); err != nil {
|
||||
t.Fatalf("pushOwnedChunkWithRelease failed: %v", err)
|
||||
}
|
||||
|
||||
stream.markReset(errors.New("boom"))
|
||||
if released != 1 {
|
||||
t.Fatalf("release count = %d, want 1", released)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientDispatchFastStreamDataWithOwnerReleasesAfterRead(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
runtime := client.getStreamRuntime()
|
||||
if runtime == nil {
|
||||
t.Fatal("client stream runtime should not be nil")
|
||||
}
|
||||
stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
|
||||
StreamID: "stream-owner",
|
||||
DataID: 23,
|
||||
Channel: StreamDataChannel,
|
||||
}, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot())
|
||||
if err := runtime.register(clientFileScope(), stream); err != nil {
|
||||
t.Fatalf("register stream failed: %v", err)
|
||||
}
|
||||
|
||||
released := 0
|
||||
owner := newStreamReadPayloadOwner(func() {
|
||||
released++
|
||||
})
|
||||
client.dispatchFastStreamDataWithOwner(streamFastDataFrame{
|
||||
DataID: 23,
|
||||
Seq: 1,
|
||||
Payload: []byte("fast-owner"),
|
||||
}, owner)
|
||||
owner.done()
|
||||
|
||||
readStreamExactly(t, stream, "fast-owner", 2*time.Second)
|
||||
if released != 1 {
|
||||
t.Fatalf("release count = %d, want 1", released)
|
||||
}
|
||||
}
|
||||
@ -83,6 +83,10 @@ func (s *ServerCommon) dispatchStreamEnvelope(logical *LogicalConn, transport *T
|
||||
}
|
||||
|
||||
func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) {
|
||||
c.dispatchFastStreamDataWithOwner(frame, nil)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) dispatchFastStreamDataWithOwner(frame streamFastDataFrame, owner *streamReadPayloadOwner) {
|
||||
if frame.DataID == 0 {
|
||||
return
|
||||
}
|
||||
@ -107,7 +111,13 @@ func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) {
|
||||
c.bestEffortRejectInboundStreamData(stream.ID(), frame.DataID, detachErr.Error())
|
||||
return
|
||||
}
|
||||
if err := stream.pushOwnedChunk(frame.Payload); err != nil {
|
||||
var err error
|
||||
if owner != nil {
|
||||
err = stream.pushOwnedChunkWithRelease(frame.Payload, owner.retainChunk())
|
||||
} else {
|
||||
err = stream.pushOwnedChunk(frame.Payload)
|
||||
}
|
||||
if err != nil {
|
||||
if c.showError || c.debugMode {
|
||||
fmt.Println("client stream push chunk error", err)
|
||||
}
|
||||
@ -118,6 +128,10 @@ func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) {
|
||||
}
|
||||
|
||||
func (s *ServerCommon) dispatchFastStreamData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame streamFastDataFrame) {
|
||||
s.dispatchFastStreamDataWithOwner(logical, transport, conn, frame, nil)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) dispatchFastStreamDataWithOwner(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame streamFastDataFrame, owner *streamReadPayloadOwner) {
|
||||
if logical == nil || frame.DataID == 0 {
|
||||
return
|
||||
}
|
||||
@ -141,7 +155,13 @@ func (s *ServerCommon) dispatchFastStreamData(logical *LogicalConn, transport *T
|
||||
s.bestEffortRejectInboundStreamData(logical, transport, conn, stream.ID(), frame.DataID, detachErr.Error())
|
||||
return
|
||||
}
|
||||
if err := stream.pushOwnedChunk(frame.Payload); err != nil {
|
||||
var err error
|
||||
if owner != nil {
|
||||
err = stream.pushOwnedChunkWithRelease(frame.Payload, owner.retainChunk())
|
||||
} else {
|
||||
err = stream.pushOwnedChunk(frame.Payload)
|
||||
}
|
||||
if err != nil {
|
||||
if s.showError || s.debugMode {
|
||||
fmt.Println("server stream push chunk error", err)
|
||||
}
|
||||
|
||||
@ -156,11 +156,12 @@ func decodeStreamFastDataFrame(payload []byte) (streamFastDataFrame, bool, error
|
||||
}
|
||||
|
||||
func (c *ClientCommon) encodeFastStreamPayload(frame streamFastDataFrame) ([]byte, error) {
|
||||
if c != nil && c.fastStreamEncode != nil && frame.Flags == 0 {
|
||||
return c.fastStreamEncode(c.SecretKey, frame.DataID, frame.Seq, frame.Payload)
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
if c != nil && profile.fastStreamEncode != nil && frame.Flags == 0 {
|
||||
return profile.fastStreamEncode(profile.secretKey, frame.DataID, frame.Seq, frame.Payload)
|
||||
}
|
||||
if c != nil && c.fastPlainEncode != nil {
|
||||
return encodeStreamFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame)
|
||||
if c != nil && profile.fastPlainEncode != nil {
|
||||
return encodeStreamFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame)
|
||||
}
|
||||
plain, err := encodeStreamFastFramePayload(frame)
|
||||
if err != nil {
|
||||
@ -181,8 +182,9 @@ func (c *ClientCommon) encodeFastStreamBatchPayload(frames []streamFastDataFrame
|
||||
if c == nil {
|
||||
return nil, errStreamClientNil
|
||||
}
|
||||
if c.fastPlainEncode != nil {
|
||||
return encodeStreamFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames)
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
if profile.fastPlainEncode != nil {
|
||||
return encodeStreamFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames)
|
||||
}
|
||||
plain, err := encodeStreamFastBatchPlain(frames)
|
||||
if err != nil {
|
||||
|
||||
@ -91,25 +91,24 @@ func writeStreamFastBatchPlain(dst []byte, frames []streamFastDataFrame) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeStreamFastBatchPlain(payload []byte) ([]streamFastDataFrame, bool, error) {
|
||||
func walkStreamFastBatchPlain(payload []byte, fn func(streamFastDataFrame) error) (bool, error) {
|
||||
if len(payload) < 4 || string(payload[:4]) != streamFastBatchMagic {
|
||||
return nil, false, nil
|
||||
return false, nil
|
||||
}
|
||||
if len(payload) < streamFastBatchHeaderLen {
|
||||
return nil, true, errStreamFastPayloadInvalid
|
||||
return true, errStreamFastPayloadInvalid
|
||||
}
|
||||
if payload[4] != streamFastBatchVersion {
|
||||
return nil, true, errStreamFastPayloadInvalid
|
||||
return true, errStreamFastPayloadInvalid
|
||||
}
|
||||
count := int(binary.BigEndian.Uint32(payload[8:12]))
|
||||
if count <= 0 {
|
||||
return nil, true, errStreamFastPayloadInvalid
|
||||
return true, errStreamFastPayloadInvalid
|
||||
}
|
||||
frames := make([]streamFastDataFrame, 0, count)
|
||||
offset := streamFastBatchHeaderLen
|
||||
for index := 0; index < count; index++ {
|
||||
if len(payload)-offset < streamFastBatchItemHeaderLen {
|
||||
return nil, true, errStreamFastPayloadInvalid
|
||||
return true, errStreamFastPayloadInvalid
|
||||
}
|
||||
flags := payload[offset]
|
||||
dataID := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
|
||||
@ -117,29 +116,62 @@ func decodeStreamFastBatchPlain(payload []byte) ([]streamFastDataFrame, bool, er
|
||||
payloadLen := int(binary.BigEndian.Uint32(payload[offset+20 : offset+24]))
|
||||
offset += streamFastBatchItemHeaderLen
|
||||
if dataID == 0 || payloadLen < 0 || len(payload)-offset < payloadLen {
|
||||
return nil, true, errStreamFastPayloadInvalid
|
||||
return true, errStreamFastPayloadInvalid
|
||||
}
|
||||
if fn != nil {
|
||||
if err := fn(streamFastDataFrame{
|
||||
Flags: flags,
|
||||
DataID: dataID,
|
||||
Seq: seq,
|
||||
Payload: payload[offset : offset+payloadLen],
|
||||
}); err != nil {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
frames = append(frames, streamFastDataFrame{
|
||||
Flags: flags,
|
||||
DataID: dataID,
|
||||
Seq: seq,
|
||||
Payload: payload[offset : offset+payloadLen],
|
||||
})
|
||||
offset += payloadLen
|
||||
}
|
||||
if offset != len(payload) {
|
||||
return nil, true, errStreamFastPayloadInvalid
|
||||
return true, errStreamFastPayloadInvalid
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func decodeStreamFastBatchPlain(payload []byte) ([]streamFastDataFrame, bool, error) {
|
||||
frames := make([]streamFastDataFrame, 0, 1)
|
||||
matched, err := walkStreamFastBatchPlain(payload, func(frame streamFastDataFrame) error {
|
||||
frames = append(frames, frame)
|
||||
return nil
|
||||
})
|
||||
if !matched || err != nil {
|
||||
return nil, matched, err
|
||||
}
|
||||
return frames, true, nil
|
||||
}
|
||||
|
||||
func decodeStreamFastDataFrames(payload []byte) ([]streamFastDataFrame, bool, error) {
|
||||
if frames, matched, err := decodeStreamFastBatchPlain(payload); matched {
|
||||
return frames, true, err
|
||||
func walkStreamFastFrames(payload []byte, fn func(streamFastDataFrame) error) (bool, error) {
|
||||
if matched, err := walkStreamFastBatchPlain(payload, fn); matched {
|
||||
return true, err
|
||||
}
|
||||
frame, matched, err := decodeStreamFastDataFrame(payload)
|
||||
if !matched || err != nil {
|
||||
return matched, err
|
||||
}
|
||||
if fn != nil {
|
||||
if err := fn(frame); err != nil {
|
||||
return true, err
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func decodeStreamFastDataFrames(payload []byte) ([]streamFastDataFrame, bool, error) {
|
||||
frames := make([]streamFastDataFrame, 0, 1)
|
||||
matched, err := walkStreamFastFrames(payload, func(frame streamFastDataFrame) error {
|
||||
frames = append(frames, frame)
|
||||
return nil
|
||||
})
|
||||
if !matched || err != nil {
|
||||
return nil, matched, err
|
||||
}
|
||||
return []streamFastDataFrame{frame}, true, nil
|
||||
return frames, true, nil
|
||||
}
|
||||
|
||||
@ -9,7 +9,10 @@ var (
|
||||
errTransportPayloadDecryptFailed = errors.New("transport payload decrypt failed")
|
||||
)
|
||||
|
||||
func encryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgEn func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) {
|
||||
func encryptTransportPayloadCodec(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgEn func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) {
|
||||
if mode == ProtectionExternal {
|
||||
return data, nil
|
||||
}
|
||||
if runtime != nil {
|
||||
encoded, err := runtime.sealPlainPayload(data)
|
||||
if err != nil {
|
||||
@ -27,7 +30,10 @@ func encryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgEn func([]b
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func decryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) {
|
||||
func decryptTransportPayloadCodec(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) {
|
||||
if mode == ProtectionExternal {
|
||||
return data, nil
|
||||
}
|
||||
if runtime != nil {
|
||||
plain, err := runtime.openPayload(data)
|
||||
if err != nil {
|
||||
@ -45,7 +51,10 @@ func decryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgDe func([]b
|
||||
return plain, nil
|
||||
}
|
||||
|
||||
func decryptTransportPayloadCodecPooled(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte, release func()) ([]byte, func(), error) {
|
||||
func decryptTransportPayloadCodecPooled(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte, release func()) ([]byte, func(), error) {
|
||||
if mode == ProtectionExternal {
|
||||
return data, release, nil
|
||||
}
|
||||
if runtime != nil {
|
||||
plain, plainRelease, err := runtime.openPayloadPooled(data, release)
|
||||
if err != nil {
|
||||
@ -69,7 +78,10 @@ func decryptTransportPayloadCodecPooled(runtime *modernPSKCodecRuntime, msgDe fu
|
||||
return plain, nil, nil
|
||||
}
|
||||
|
||||
func decryptTransportPayloadCodecOwnedPooled(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, func(), error) {
|
||||
func decryptTransportPayloadCodecOwnedPooled(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, func(), error) {
|
||||
if mode == ProtectionExternal {
|
||||
return data, nil, nil
|
||||
}
|
||||
if runtime != nil {
|
||||
plain, plainRelease, err := runtime.openPayloadOwnedPooled(data)
|
||||
if err != nil {
|
||||
@ -124,9 +136,14 @@ func (s *ServerCommon) encodeTransferMsg(c *ClientConn, msg TransferMsg) ([]byte
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logical := logicalConnFromClient(c)
|
||||
msgEn := c.clientConnMsgEnSnapshot()
|
||||
secretKey := c.clientConnSecretKeySnapshot()
|
||||
data, err = encryptTransportPayloadCodec(c.clientConnModernPSKRuntimeSnapshot(), msgEn, secretKey, data)
|
||||
mode := ProtectionManaged
|
||||
if logical != nil {
|
||||
mode = logical.protectionModeSnapshot()
|
||||
}
|
||||
data, err = encryptTransportPayloadCodec(mode, c.clientConnModernPSKRuntimeSnapshot(), msgEn, secretKey, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -140,7 +157,11 @@ func (s *ServerCommon) encodeTransferMsg(c *ClientConn, msg TransferMsg) ([]byte
|
||||
func (s *ServerCommon) decodeTransferMsg(c *ClientConn, data []byte) (TransferMsg, error) {
|
||||
msgDe := c.clientConnMsgDeSnapshot()
|
||||
secretKey := c.clientConnSecretKeySnapshot()
|
||||
plain, err := decryptTransportPayloadCodec(c.clientConnModernPSKRuntimeSnapshot(), msgDe, secretKey, data)
|
||||
mode := ProtectionManaged
|
||||
if logical := logicalConnFromClient(c); logical != nil {
|
||||
mode = logical.protectionModeSnapshot()
|
||||
}
|
||||
plain, err := decryptTransportPayloadCodec(mode, c.clientConnModernPSKRuntimeSnapshot(), msgDe, secretKey, data)
|
||||
if err != nil {
|
||||
return TransferMsg{}, err
|
||||
}
|
||||
@ -172,7 +193,8 @@ func (c *ClientCommon) encodeEnvelopePlain(env Envelope) ([]byte, error) {
|
||||
}
|
||||
|
||||
func (c *ClientCommon) encryptTransportPayload(data []byte) ([]byte, error) {
|
||||
return encryptTransportPayloadCodec(c.modernPSKRuntime, c.msgEn, c.SecretKey, data)
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
return encryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgEn, profile.secretKey, data)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) encodeEnvelope(env Envelope) ([]byte, error) {
|
||||
@ -196,7 +218,8 @@ func (c *ClientCommon) decodeEnvelope(data []byte) (Envelope, error) {
|
||||
}
|
||||
|
||||
func (c *ClientCommon) decryptTransportPayload(data []byte) ([]byte, error) {
|
||||
return decryptTransportPayloadCodec(c.modernPSKRuntime, c.msgDe, c.SecretKey, data)
|
||||
profile := c.clientTransportProtectionSnapshot()
|
||||
return decryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, data)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) decodeEnvelopePlain(data []byte) (Envelope, error) {
|
||||
@ -251,7 +274,7 @@ func (s *ServerCommon) encryptTransportPayloadLogical(logical *LogicalConn, data
|
||||
if msgEn == nil {
|
||||
return nil, errTransportDetached
|
||||
}
|
||||
return encryptTransportPayloadCodec(logical.modernPSKRuntimeSnapshot(), msgEn, secretKey, data)
|
||||
return encryptTransportPayloadCodec(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), msgEn, secretKey, data)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) encodeEnvelopeLogical(logical *LogicalConn, env Envelope) ([]byte, error) {
|
||||
@ -290,7 +313,7 @@ func (s *ServerCommon) decryptTransportPayloadLogical(logical *LogicalConn, data
|
||||
if msgDe == nil {
|
||||
return nil, errTransportDetached
|
||||
}
|
||||
return decryptTransportPayloadCodec(logical.modernPSKRuntimeSnapshot(), msgDe, secretKey, data)
|
||||
return decryptTransportPayloadCodec(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), msgDe, secretKey, data)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) decodeEnvelopePlain(data []byte) (Envelope, error) {
|
||||
|
||||
@ -143,6 +143,43 @@ func TestStreamBatchSenderRespectsBindingWriteDeadlineWhenReceiverStalls(t *test
|
||||
}
|
||||
}
|
||||
|
||||
func TestBulkBatchSenderFlushAggregatesAdaptivePayloadObservation(t *testing.T) {
|
||||
conn := &delayedWriteConn{delay: 20 * time.Millisecond}
|
||||
binding := newTransportBinding(conn, stario.NewQueue())
|
||||
sender := newTestBulkBatchSender(binding)
|
||||
payloadA := bytes.Repeat([]byte("a"), 128*1024)
|
||||
payloadB := bytes.Repeat([]byte("b"), 128*1024)
|
||||
|
||||
err := sender.flush([]bulkBatchRequest{
|
||||
{
|
||||
ctx: context.Background(),
|
||||
frames: []bulkFastFrame{{
|
||||
Type: bulkFastPayloadTypeData,
|
||||
DataID: 1,
|
||||
Seq: 1,
|
||||
Payload: payloadA,
|
||||
}},
|
||||
fastPathVersion: bulkFastPathVersionV1,
|
||||
},
|
||||
{
|
||||
ctx: context.Background(),
|
||||
frames: []bulkFastFrame{{
|
||||
Type: bulkFastPayloadTypeData,
|
||||
DataID: 2,
|
||||
Seq: 1,
|
||||
Payload: payloadB,
|
||||
}},
|
||||
fastPathVersion: bulkFastPathVersionV1,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("flush failed: %v", err)
|
||||
}
|
||||
if got, want := binding.bulkAdaptiveSoftPayloadBytesSnapshot(), bulkAdaptiveSoftPayloadMinBytes; got != want {
|
||||
t.Fatalf("adaptive bulk soft payload = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlBatchSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) {
|
||||
left, right := net.Pipe()
|
||||
defer left.Close()
|
||||
@ -255,6 +292,10 @@ func (c *vectoredShortWriteConn) WriteBuffers(bufs *net.Buffers) (int64, error)
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func (c *vectoredShortWriteConn) writeBuffers(bufs *net.Buffers) (int64, error) {
|
||||
return c.WriteBuffers(bufs)
|
||||
}
|
||||
|
||||
type unwrapVectoredConn struct {
|
||||
inner net.Conn
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user