feat(transport): 完成安全架构拆分并收口 stream/bulk 传输优化
- 新增 managed/external/nested 三种传输保护模式 - 新增 peer attach 显式认证、抗重放、channel binding 和可选前向保密协商 - 明确单连接注入与可重拨连接源的语义边界 - 禁止 ConnectByConn 场景下 dedicated bulk 走 sidecar,auto 模式自动回退 shared - 修正 dedicated attach 在 bootstrap/steady profile 切换下的处理逻辑 - 优化 shared bulk super-batch 与批量 framed write 路径 - 降低 stream/bulk fast path 的复制和分发损耗 - 补齐 benchmark、回归测试、运行时快照和 README 文档
This commit is contained in:
parent
f038a89771
commit
98ef9e7fcc
54
README.md
54
README.md
@ -25,6 +25,21 @@
|
|||||||
|
|
||||||
未配置时会返回 `errModernPSKRequired`。
|
未配置时会返回 `errModernPSKRequired`。
|
||||||
|
|
||||||
|
## 安全模式选择
|
||||||
|
|
||||||
|
- `UseModernPSKClient` / `UseModernPSKServer`
|
||||||
|
- bootstrap 和稳态传输都由 `notify` 自己保护
|
||||||
|
- 适合默认场景
|
||||||
|
- 支持 peer attach 显式认证、抗重放,以及在需要时协商前向保密
|
||||||
|
- `UsePSKOverExternalTransportClient` / `UsePSKOverExternalTransportServer`
|
||||||
|
- bootstrap 仍用 PSK 做认证
|
||||||
|
- 稳态阶段信任外部物理通道,不再做 `notify` 内层加密
|
||||||
|
- 适合 `tls.Conn` 或调用方自认可信的外部通道
|
||||||
|
- 不支持 `RequireForwardSecrecy`
|
||||||
|
- `UseNestedSecurityClient` / `UseNestedSecurityServer`
|
||||||
|
- 外层已有可信通道,但仍保留 `notify` 内层保护
|
||||||
|
- 适合需要“外层可信 + 内层独立保护”的场景
|
||||||
|
|
||||||
## 快速开始
|
## 快速开始
|
||||||
|
|
||||||
服务端:
|
服务端:
|
||||||
@ -83,6 +98,42 @@ func main() {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 连接入口与物理连接语义
|
||||||
|
|
||||||
|
- `Connect` / `ConnectTimeout`
|
||||||
|
- 由 `notify` 自己拨号
|
||||||
|
- 支持重连,也支持 dedicated bulk 额外 sidecar 连接
|
||||||
|
- `ConnectByFactory`
|
||||||
|
- 调用方提供 `dialFn`
|
||||||
|
- `notify` 会在需要时再次调用 `dialFn`,因此仍支持重连和 dedicated bulk
|
||||||
|
- `ConnectByConn`
|
||||||
|
- 调用方注入一个已经建立好的 `net.Conn`
|
||||||
|
- 该模式被视为“单物理连接模式”
|
||||||
|
- `OpenDedicatedBulk` 会直接返回错误
|
||||||
|
- `OpenBulk` 使用 `auto` 模式时会自动回退到 `shared`
|
||||||
|
- `ListenByListener`
|
||||||
|
- 服务端复用调用方提供的 `net.Listener`
|
||||||
|
- 适合需要和现有 listener 栈整合的场景
|
||||||
|
|
||||||
|
`dedicated bulk` 依赖额外物理连接,因此只适用于可再次拨号的 transport source。
|
||||||
|
|
||||||
|
## Peer Attach 安全策略
|
||||||
|
|
||||||
|
可通过 `SetPeerAttachSecurityConfig` 配置逻辑会话 attach 阶段的额外保护。
|
||||||
|
|
||||||
|
- `RequireExplicitAuth`
|
||||||
|
- 要求 peer attach 使用显式认证
|
||||||
|
- `RequireChannelBinding`
|
||||||
|
- 要求 attach 绑定到底层可信通道
|
||||||
|
- 启用后会隐式要求显式认证
|
||||||
|
- `ChannelBinding`
|
||||||
|
- 由调用方提供 channel binding 提取函数
|
||||||
|
- 适合外层 TLS 或其他可信通道整合
|
||||||
|
- `ReplayWindow` / `ReplayCapacity`
|
||||||
|
- 控制 attach 抗重放窗口和缓存容量
|
||||||
|
|
||||||
|
如果你选择 `UsePSKOverExternalTransport*`,并且希望 attach 阶段显式绑定到外层可信信道,建议同时配置 channel binding。
|
||||||
|
|
||||||
## RecordStream 说明
|
## RecordStream 说明
|
||||||
|
|
||||||
`RecordStream` 构建在 `Stream` 之上,适合“有边界的顺序记录”场景。
|
`RecordStream` 构建在 `Stream` 之上,适合“有边界的顺序记录”场景。
|
||||||
@ -146,6 +197,9 @@ func main() {
|
|||||||
- 共享密钥派生(Argon2id)
|
- 共享密钥派生(Argon2id)
|
||||||
- 消息层加密(AES-GCM)
|
- 消息层加密(AES-GCM)
|
||||||
- `stream` / `bulk` fast path 复用现代编码栈
|
- `stream` / `bulk` fast path 复用现代编码栈
|
||||||
|
- peer attach 显式认证 / 抗重放
|
||||||
|
- 可选 channel binding
|
||||||
|
- 可选前向保密(`UseModernPSK*` / `UseNestedSecurity*`)
|
||||||
|
|
||||||
兼容入口仍保留,但属于历史路径:
|
兼容入口仍保留,但属于历史路径:
|
||||||
|
|
||||||
|
|||||||
42
benchmark_transport_security_test.go
Normal file
42
benchmark_transport_security_test.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
type benchmarkTransportSecurityMode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
benchmarkTransportSecurityModernPSK benchmarkTransportSecurityMode = "modern_psk"
|
||||||
|
benchmarkTransportSecurityTrustedRaw benchmarkTransportSecurityMode = "trusted_raw"
|
||||||
|
)
|
||||||
|
|
||||||
|
func benchmarkApplyServerTransportSecurity(tb testing.TB, server *ServerCommon, mode benchmarkTransportSecurityMode) {
|
||||||
|
tb.Helper()
|
||||||
|
switch mode {
|
||||||
|
case benchmarkTransportSecurityModernPSK:
|
||||||
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
tb.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
case benchmarkTransportSecurityTrustedRaw:
|
||||||
|
if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
tb.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
tb.Fatalf("unsupported benchmark transport security mode %q", mode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkApplyClientTransportSecurity(tb testing.TB, client *ClientCommon, mode benchmarkTransportSecurityMode) {
|
||||||
|
tb.Helper()
|
||||||
|
switch mode {
|
||||||
|
case benchmarkTransportSecurityModernPSK:
|
||||||
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
tb.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
case benchmarkTransportSecurityTrustedRaw:
|
||||||
|
if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
tb.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
tb.Fatalf("unsupported benchmark transport security mode %q", mode)
|
||||||
|
}
|
||||||
|
}
|
||||||
4
bulk.go
4
bulk.go
@ -161,6 +161,7 @@ var (
|
|||||||
errBulkRangeInvalid = errors.New("bulk range is invalid")
|
errBulkRangeInvalid = errors.New("bulk range is invalid")
|
||||||
errBulkBackpressureExceeded = errors.New("bulk inbound backpressure exceeded")
|
errBulkBackpressureExceeded = errors.New("bulk inbound backpressure exceeded")
|
||||||
errBulkDedicatedStreamOnly = errors.New("dedicated bulk requires stream transport")
|
errBulkDedicatedStreamOnly = errors.New("dedicated bulk requires stream transport")
|
||||||
|
errBulkDedicatedSingleConn = errors.New("dedicated bulk requires a dialable additional connection source; ConnectByConn only supports shared transport")
|
||||||
errBulkDedicatedActiveLimit = errors.New("dedicated bulk active limit reached")
|
errBulkDedicatedActiveLimit = errors.New("dedicated bulk active limit reached")
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -174,6 +175,9 @@ func clientDedicatedBulkSupportError(c *ClientCommon) error {
|
|||||||
if source := c.clientConnectSourceSnapshot(); source != nil && source.isUDP() {
|
if source := c.clientConnectSourceSnapshot(); source != nil && source.isUDP() {
|
||||||
return errBulkDedicatedStreamOnly
|
return errBulkDedicatedStreamOnly
|
||||||
}
|
}
|
||||||
|
if source := c.clientConnectSourceSnapshot(); source != nil && !source.supportsAdditionalConn() {
|
||||||
|
return errBulkDedicatedSingleConn
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -418,18 +418,18 @@ func (s *bulkBatchSender) flush(requests []bulkBatchRequest) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
writeTimeout := s.transportWriteTimeout()
|
writeTimeout := s.transportWriteTimeout()
|
||||||
|
frames := make([][]byte, 0, len(payloads))
|
||||||
|
payloadBytes := 0
|
||||||
for _, payload := range payloads {
|
for _, payload := range payloads {
|
||||||
frame := payload.payload
|
frames = append(frames, payload.payload)
|
||||||
|
payloadBytes += len(payload.payload)
|
||||||
|
}
|
||||||
started := time.Now()
|
started := time.Now()
|
||||||
err := s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error {
|
err = s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error {
|
||||||
return writeFramedPayloadUnlocked(conn, queue, frame)
|
return writeFramedPayloadBatchUnlocked(conn, queue, frames)
|
||||||
})
|
})
|
||||||
s.binding.observeBulkAdaptivePayloadWrite(len(frame), time.Since(started), writeTimeout, err)
|
s.binding.observeBulkAdaptivePayloadWrite(payloadBytes, time.Since(started), writeTimeout, err)
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *bulkBatchSender) encodeRequests(requests []bulkBatchRequest) ([]bulkBatchEncodedPayload, error) {
|
func (s *bulkBatchSender) encodeRequests(requests []bulkBatchRequest) ([]bulkBatchEncodedPayload, error) {
|
||||||
|
|||||||
@ -38,7 +38,25 @@ func BenchmarkBulkTCPThroughput(b *testing.B) {
|
|||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
b.Run(tc.name, func(b *testing.B) {
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
benchmarkBulkTCPThroughput(b, tc.payloadSize, false)
|
benchmarkBulkTCPThroughput(b, tc.payloadSize, false, benchmarkTransportSecurityModernPSK)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBulkTCPThroughputTrustedRaw(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
payloadSize int
|
||||||
|
}{
|
||||||
|
{name: "chunk_256KiB", payloadSize: 256 * 1024},
|
||||||
|
{name: "chunk_512KiB", payloadSize: 512 * 1024},
|
||||||
|
{name: "chunk_768KiB", payloadSize: 768 * 1024},
|
||||||
|
{name: "chunk_1MiB", payloadSize: 1024 * 1024},
|
||||||
|
{name: "chunk_2MiB", payloadSize: 2 * 1024 * 1024},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
|
benchmarkBulkTCPThroughput(b, tc.payloadSize, false, benchmarkTransportSecurityTrustedRaw)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -72,7 +90,25 @@ func BenchmarkBulkTCPThroughputDedicated(b *testing.B) {
|
|||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
b.Run(tc.name, func(b *testing.B) {
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
benchmarkBulkTCPThroughput(b, tc.payloadSize, true)
|
benchmarkBulkTCPThroughput(b, tc.payloadSize, true, benchmarkTransportSecurityModernPSK)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBulkTCPThroughputDedicatedTrustedRaw(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
payloadSize int
|
||||||
|
}{
|
||||||
|
{name: "chunk_256KiB", payloadSize: 256 * 1024},
|
||||||
|
{name: "chunk_512KiB", payloadSize: 512 * 1024},
|
||||||
|
{name: "chunk_768KiB", payloadSize: 768 * 1024},
|
||||||
|
{name: "chunk_1MiB", payloadSize: 1024 * 1024},
|
||||||
|
{name: "chunk_2MiB", payloadSize: 2 * 1024 * 1024},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
|
benchmarkBulkTCPThroughput(b, tc.payloadSize, true, benchmarkTransportSecurityTrustedRaw)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -107,7 +143,25 @@ func BenchmarkBulkTCPThroughputConcurrent(b *testing.B) {
|
|||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
b.Run(tc.name, func(b *testing.B) {
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false)
|
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false, benchmarkTransportSecurityModernPSK)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBulkTCPThroughputConcurrentTrustedRaw(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
payloadSize int
|
||||||
|
concurrency int
|
||||||
|
}{
|
||||||
|
{name: "bulks_2_512KiB", payloadSize: 512 * 1024, concurrency: 2},
|
||||||
|
{name: "bulks_4_512KiB", payloadSize: 512 * 1024, concurrency: 4},
|
||||||
|
{name: "bulks_2_1MiB", payloadSize: 1024 * 1024, concurrency: 2},
|
||||||
|
{name: "bulks_4_1MiB", payloadSize: 1024 * 1024, concurrency: 4},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
|
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false, benchmarkTransportSecurityTrustedRaw)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -142,18 +196,34 @@ func BenchmarkBulkTCPThroughputConcurrentDedicated(b *testing.B) {
|
|||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
b.Run(tc.name, func(b *testing.B) {
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true)
|
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true, benchmarkTransportSecurityModernPSK)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
|
func BenchmarkBulkTCPThroughputConcurrentDedicatedTrustedRaw(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
payloadSize int
|
||||||
|
concurrency int
|
||||||
|
}{
|
||||||
|
{name: "bulks_2_512KiB", payloadSize: 512 * 1024, concurrency: 2},
|
||||||
|
{name: "bulks_4_512KiB", payloadSize: 512 * 1024, concurrency: 4},
|
||||||
|
{name: "bulks_2_1MiB", payloadSize: 1024 * 1024, concurrency: 2},
|
||||||
|
{name: "bulks_4_1MiB", payloadSize: 1024 * 1024, concurrency: 4},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
|
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true, benchmarkTransportSecurityTrustedRaw)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool, securityMode benchmarkTransportSecurityMode) {
|
||||||
b.Helper()
|
b.Helper()
|
||||||
|
|
||||||
server := NewServer().(*ServerCommon)
|
server := NewServer().(*ServerCommon)
|
||||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
benchmarkApplyServerTransportSecurity(b, server, securityMode)
|
||||||
b.Fatalf("UseModernPSKServer failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
acceptCh := make(chan BulkAcceptInfo, 1)
|
acceptCh := make(chan BulkAcceptInfo, 1)
|
||||||
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||||
@ -169,9 +239,7 @@ func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
client := NewClient().(*ClientCommon)
|
client := NewClient().(*ClientCommon)
|
||||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
benchmarkApplyClientTransportSecurity(b, client, securityMode)
|
||||||
b.Fatalf("UseModernPSKClient failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
|
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
|
||||||
b.Fatalf("client Connect failed: %v", err)
|
b.Fatalf("client Connect failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -241,16 +309,14 @@ func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
|
|||||||
_ = bulk.Close()
|
_ = bulk.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, dedicated bool) {
|
func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, dedicated bool, securityMode benchmarkTransportSecurityMode) {
|
||||||
b.Helper()
|
b.Helper()
|
||||||
if concurrency <= 0 {
|
if concurrency <= 0 {
|
||||||
b.Fatal("concurrency must be > 0")
|
b.Fatal("concurrency must be > 0")
|
||||||
}
|
}
|
||||||
|
|
||||||
server := NewServer().(*ServerCommon)
|
server := NewServer().(*ServerCommon)
|
||||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
benchmarkApplyServerTransportSecurity(b, server, securityMode)
|
||||||
b.Fatalf("UseModernPSKServer failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
acceptCh := make(chan BulkAcceptInfo, concurrency*2)
|
acceptCh := make(chan BulkAcceptInfo, concurrency*2)
|
||||||
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||||
@ -266,9 +332,7 @@ func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurr
|
|||||||
})
|
})
|
||||||
|
|
||||||
client := NewClient().(*ClientCommon)
|
client := NewClient().(*ClientCommon)
|
||||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
benchmarkApplyClientTransportSecurity(b, client, securityMode)
|
||||||
b.Fatalf("UseModernPSKClient failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
|
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
|
||||||
b.Fatalf("client Connect failed: %v", err)
|
b.Fatalf("client Connect failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -268,6 +268,9 @@ func readBulkDedicatedRecordPooled(conn net.Conn) ([]byte, func(), error) {
|
|||||||
func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context, timeout time.Duration) (net.Conn, error) {
|
func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context, timeout time.Duration) (net.Conn, error) {
|
||||||
source := c.clientConnectSourceSnapshot()
|
source := c.clientConnectSourceSnapshot()
|
||||||
if source != nil {
|
if source != nil {
|
||||||
|
if !source.supportsAdditionalConn() {
|
||||||
|
return nil, errBulkDedicatedSingleConn
|
||||||
|
}
|
||||||
if source.network != "" && source.addr != "" {
|
if source.network != "" && source.addr != "" {
|
||||||
if timeout > 0 {
|
if timeout > 0 {
|
||||||
return transport.DialTimeout(source.network, source.addr, timeout)
|
return transport.DialTimeout(source.network, source.addr, timeout)
|
||||||
@ -277,6 +280,7 @@ func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context, timeout time.D
|
|||||||
if source.canReconnect() {
|
if source.canReconnect() {
|
||||||
return source.dial(ctx)
|
return source.dial(ctx)
|
||||||
}
|
}
|
||||||
|
return nil, errClientReconnectSourceUnavailable
|
||||||
}
|
}
|
||||||
conn := c.clientTransportConnSnapshot()
|
conn := c.clientTransportConnSnapshot()
|
||||||
if conn == nil || conn.RemoteAddr() == nil {
|
if conn == nil || conn.RemoteAddr() == nil {
|
||||||
@ -661,7 +665,8 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn
|
|||||||
Value: reqPayload,
|
Value: reqPayload,
|
||||||
Type: MSG_SYS_WAIT,
|
Type: MSG_SYS_WAIT,
|
||||||
}
|
}
|
||||||
frame, err := encodeDirectSignalFrame(stario.NewQueue(), c.sequenceEn, c.msgEn, c.SecretKey, msg)
|
attachProfile := c.clientDedicatedBulkAttachTransportProtectionProfile()
|
||||||
|
frame, err := encodeDirectSignalFrame(stario.NewQueue(), c.sequenceEn, attachProfile.msgEn, attachProfile.secretKey, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return bulkAttachResponse{}, err
|
return bulkAttachResponse{}, err
|
||||||
}
|
}
|
||||||
@ -675,7 +680,7 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return bulkAttachResponse{}, err
|
return bulkAttachResponse{}, err
|
||||||
}
|
}
|
||||||
transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, replyPayload)
|
transfer, err := decodeDirectSignalPayload(c.sequenceDe, attachProfile.msgDe, attachProfile.secretKey, replyPayload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return bulkAttachResponse{}, err
|
return bulkAttachResponse{}, err
|
||||||
}
|
}
|
||||||
@ -685,6 +690,16 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn
|
|||||||
return decodeBulkAttachResponse(c.sequenceDe, transfer.Value)
|
return decodeBulkAttachResponse(c.sequenceDe, transfer.Value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientDedicatedBulkAttachTransportProtectionProfile() transportProtectionProfile {
|
||||||
|
if c == nil {
|
||||||
|
return transportProtectionProfile{}
|
||||||
|
}
|
||||||
|
if c.securityConfigured && c.securityBootstrap.msgEn != nil && c.securityBootstrap.msgDe != nil {
|
||||||
|
return c.securityBootstrap.clone()
|
||||||
|
}
|
||||||
|
return c.clientTransportProtectionSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ClientCommon) readDedicatedSidecarLoop(sidecar *bulkDedicatedSidecar) {
|
func (c *ClientCommon) readDedicatedSidecarLoop(sidecar *bulkDedicatedSidecar) {
|
||||||
if c == nil || sidecar == nil || sidecar.conn == nil {
|
if c == nil || sidecar == nil || sidecar.conn == nil {
|
||||||
return
|
return
|
||||||
@ -695,7 +710,8 @@ func (c *ClientCommon) readDedicatedSidecarLoop(sidecar *bulkDedicatedSidecar) {
|
|||||||
c.handleClientDedicatedSidecarFailure(sidecar, err)
|
c.handleClientDedicatedSidecarFailure(sidecar, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
plain, plainRelease, err := decryptTransportPayloadCodecPooled(c.modernPSKRuntime, c.msgDe, c.SecretKey, payload, payloadRelease)
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
|
plain, plainRelease, err := decryptTransportPayloadCodecPooled(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, payload, payloadRelease)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.handleClientDedicatedSidecarFailure(sidecar, err)
|
c.handleClientDedicatedSidecarFailure(sidecar, err)
|
||||||
return
|
return
|
||||||
@ -1023,7 +1039,7 @@ func (s *ServerCommon) replyDedicatedBulkAttach(client *LogicalConn, message Mes
|
|||||||
Type: MSG_SYS_REPLY,
|
Type: MSG_SYS_REPLY,
|
||||||
}
|
}
|
||||||
if message.inboundConn != nil {
|
if message.inboundConn != nil {
|
||||||
return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply)
|
return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, messageInboundTransportProtectionSnapshot(&message), reply)
|
||||||
}
|
}
|
||||||
_, err = s.sendLogical(client, reply)
|
_, err = s.sendLogical(client, reply)
|
||||||
return err
|
return err
|
||||||
@ -1041,7 +1057,7 @@ func (s *ServerCommon) readDedicatedSidecarLoop(logical *LogicalConn, sidecar *b
|
|||||||
s.handleServerDedicatedSidecarFailure(logical, sidecar, err)
|
s.handleServerDedicatedSidecarFailure(logical, sidecar, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
plain, plainRelease, err := decryptTransportPayloadCodecPooled(logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload, payloadRelease)
|
plain, plainRelease, err := decryptTransportPayloadCodecPooled(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload, payloadRelease)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.handleServerDedicatedSidecarFailure(logical, sidecar, err)
|
s.handleServerDedicatedSidecarFailure(logical, sidecar, err)
|
||||||
return
|
return
|
||||||
@ -1171,7 +1187,8 @@ func (c *ClientCommon) dedicatedBulkLaneSender(bulk *bulkHandle) (*bulkDedicated
|
|||||||
return nil, transportDetachedError("dedicated bulk sidecar not attached", nil)
|
return nil, transportDetachedError("dedicated bulk sidecar not attached", nil)
|
||||||
}
|
}
|
||||||
sender := sidecar.laneSenderWithFactory(func(conn net.Conn) *bulkDedicatedLaneSender {
|
sender := sidecar.laneSenderWithFactory(func(conn net.Conn) *bulkDedicatedLaneSender {
|
||||||
laneRuntime := c.modernPSKRuntime
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
|
laneRuntime := profile.runtime
|
||||||
if forked, err := forkDedicatedLaneModernPSKRuntime(laneRuntime); err == nil && forked != nil {
|
if forked, err := forkDedicatedLaneModernPSKRuntime(laneRuntime); err == nil && forked != nil {
|
||||||
laneRuntime = forked
|
laneRuntime = forked
|
||||||
}
|
}
|
||||||
@ -1198,7 +1215,7 @@ func (c *ClientCommon) sendDedicatedBulkData(ctx context.Context, bulk *bulkHand
|
|||||||
return sender.submitData(ctx, bulk.dataIDSnapshot(), bulk.nextOutboundDataSeq(), chunk)
|
return sender.submitData(ctx, bulk.dataIDSnapshot(), bulk.nextOutboundDataSeq(), chunk)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte) (int, error) {
|
func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) {
|
||||||
if c == nil || bulk == nil {
|
if c == nil || bulk == nil {
|
||||||
return 0, errBulkClientNil
|
return 0, errBulkClientNil
|
||||||
}
|
}
|
||||||
@ -1206,7 +1223,7 @@ func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHan
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize)
|
return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize, payloadOwned)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientCommon) sendDedicatedBulkClose(ctx context.Context, bulk *bulkHandle, full bool) error {
|
func (c *ClientCommon) sendDedicatedBulkClose(ctx context.Context, bulk *bulkHandle, full bool) error {
|
||||||
@ -1345,7 +1362,7 @@ func (s *ServerCommon) sendDedicatedBulkData(ctx context.Context, logical *Logic
|
|||||||
return sender.submitData(ctx, bulk.dataIDSnapshot(), bulk.nextOutboundDataSeq(), chunk)
|
return sender.submitData(ctx, bulk.dataIDSnapshot(), bulk.nextOutboundDataSeq(), chunk)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, startSeq uint64, payload []byte) (int, error) {
|
func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) {
|
||||||
if s == nil || bulk == nil {
|
if s == nil || bulk == nil {
|
||||||
return 0, errBulkServerNil
|
return 0, errBulkServerNil
|
||||||
}
|
}
|
||||||
@ -1353,7 +1370,7 @@ func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *Logi
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize)
|
return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize, payloadOwned)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerCommon) sendDedicatedBulkClose(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, full bool) error {
|
func (s *ServerCommon) sendDedicatedBulkClose(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, full bool) error {
|
||||||
@ -1419,13 +1436,14 @@ func (c *ClientCommon) encodeDedicatedBulkBatchPayload(dataID uint64, items []bu
|
|||||||
if c == nil {
|
if c == nil {
|
||||||
return nil, errBulkClientNil
|
return nil, errBulkClientNil
|
||||||
}
|
}
|
||||||
if runtime := c.modernPSKRuntime; runtime != nil {
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
|
if runtime := profile.runtime; runtime != nil {
|
||||||
return runtime.sealFilledPayload(bulkDedicatedBatchPlainLen(items), func(dst []byte) error {
|
return runtime.sealFilledPayload(bulkDedicatedBatchPlainLen(items), func(dst []byte) error {
|
||||||
return writeBulkDedicatedBatchPlain(dst, dataID, items)
|
return writeBulkDedicatedBatchPlain(dst, dataID, items)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if c.fastPlainEncode != nil {
|
if profile.fastPlainEncode != nil {
|
||||||
return encodeBulkDedicatedBatchPayloadFast(c.fastPlainEncode, c.SecretKey, dataID, items)
|
return encodeBulkDedicatedBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, dataID, items)
|
||||||
}
|
}
|
||||||
plain, err := encodeBulkDedicatedBatchPlain(dataID, items)
|
plain, err := encodeBulkDedicatedBatchPlain(dataID, items)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1453,8 +1471,9 @@ func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooledWithRuntime(runtim
|
|||||||
return writeBulkDedicatedBatchesPlain(dst, batches)
|
return writeBulkDedicatedBatchesPlain(dst, batches)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if c.fastPlainEncode != nil {
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
payload, err := encodeBulkDedicatedBatchesPayloadFast(c.fastPlainEncode, c.SecretKey, batches)
|
if profile.fastPlainEncode != nil {
|
||||||
|
payload, err := encodeBulkDedicatedBatchesPayloadFast(profile.fastPlainEncode, profile.secretKey, batches)
|
||||||
return payload, nil, err
|
return payload, nil, err
|
||||||
}
|
}
|
||||||
plain, err := encodeBulkDedicatedBatchesPlain(batches)
|
plain, err := encodeBulkDedicatedBatchesPlain(batches)
|
||||||
@ -1466,7 +1485,7 @@ func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooledWithRuntime(runtim
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooled(batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) {
|
func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooled(batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) {
|
||||||
return c.encodeDedicatedBulkBatchesPayloadPooledWithRuntime(c.modernPSKRuntime, batches)
|
return c.encodeDedicatedBulkBatchesPayloadPooledWithRuntime(c.clientTransportProtectionSnapshot().runtime, batches)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientCommon) encodeDedicatedBulkBatchPayloadPooled(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, func(), error) {
|
func (c *ClientCommon) encodeDedicatedBulkBatchPayloadPooled(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, func(), error) {
|
||||||
|
|||||||
@ -115,6 +115,76 @@ func TestSendDedicatedBulkAttachRequestKeepsCoalescedDedicatedPayloadUnread(t *t
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSendDedicatedBulkAttachRequestUsesBootstrapProtectionEvenAfterSteadySwitch(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
client.msgID = 100
|
||||||
|
|
||||||
|
alternate, err := deriveModernPSKProtectionProfile([]byte("notify-dedicated-attach-other-secret"), integrationModernPSKOptions(), ProtectionManaged)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("deriveModernPSKProtectionProfile(alternate) failed: %v", err)
|
||||||
|
}
|
||||||
|
client.setClientTransportProtectionProfile(alternate)
|
||||||
|
|
||||||
|
bulk := newBulkHandle(context.Background(), newBulkRuntime("dedicated-attach-bootstrap-test"), clientFileScope(), BulkOpenRequest{
|
||||||
|
BulkID: "bulk-attach-bootstrap-test",
|
||||||
|
DataID: 1,
|
||||||
|
Dedicated: true,
|
||||||
|
AttachToken: "attach-token",
|
||||||
|
}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
|
bootstrap := client.clientDedicatedBulkAttachTransportProtectionProfile()
|
||||||
|
encodedResp, err := client.sequenceEn(bulkAttachResponse{Accepted: true})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encode bulkAttachResponse failed: %v", err)
|
||||||
|
}
|
||||||
|
replyFrame, err := encodeDirectSignalFrame(stario.NewQueue(), client.sequenceEn, bootstrap.msgEn, bootstrap.secretKey, TransferMsg{
|
||||||
|
ID: 101,
|
||||||
|
Key: systemBulkAttachKey,
|
||||||
|
Value: encodedResp,
|
||||||
|
Type: MSG_SYS_REPLY,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encode attach reply frame failed: %v", err)
|
||||||
|
}
|
||||||
|
conn := newBulkAttachScriptConn(replyFrame)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
defer cancel()
|
||||||
|
resp, err := client.sendDedicatedBulkAttachRequest(ctx, conn, bulk)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sendDedicatedBulkAttachRequest failed: %v", err)
|
||||||
|
}
|
||||||
|
if !resp.Accepted {
|
||||||
|
t.Fatalf("bulk attach response = %+v, want accepted", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedReq := stario.NewQueue()
|
||||||
|
var reqMsg TransferMsg
|
||||||
|
if err := parsedReq.ParseMessageOwned(conn.writtenBytes(), "attach-request", func(msgq stario.MsgQueue) error {
|
||||||
|
transfer, err := decodeDirectSignalPayload(client.sequenceDe, bootstrap.msgDe, bootstrap.secretKey, msgq.Msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
reqMsg = transfer
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("parse written attach request with bootstrap profile failed: %v", err)
|
||||||
|
}
|
||||||
|
if reqMsg.Key != systemBulkAttachKey || reqMsg.Type != MSG_SYS_WAIT {
|
||||||
|
t.Fatalf("attach request message mismatch: %+v", reqMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := parsedReq.ParseMessageOwned(conn.writtenBytes(), "attach-request-current", func(msgq stario.MsgQueue) error {
|
||||||
|
_, err := decodeDirectSignalPayload(client.sequenceDe, alternate.msgDe, alternate.secretKey, msgq.Msg)
|
||||||
|
return err
|
||||||
|
}); !errors.Is(err, errTransportPayloadDecryptFailed) {
|
||||||
|
t.Fatalf("decode written attach request with current steady profile error = %v, want %v", err, errTransportPayloadDecryptFailed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandleBulkAttachSystemMessageAcceptedWritesDirectReplyBeforeDedicatedHandoff(t *testing.T) {
|
func TestHandleBulkAttachSystemMessageAcceptedWritesDirectReplyBeforeDedicatedHandoff(t *testing.T) {
|
||||||
server := NewServer().(*ServerCommon)
|
server := NewServer().(*ServerCommon)
|
||||||
UseLegacySecurityServer(server)
|
UseLegacySecurityServer(server)
|
||||||
|
|||||||
@ -83,7 +83,7 @@ func (r *bulkDedicatedLaneBatchRequest) reset() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *bulkDedicatedLaneBatchRequest) prepare(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool) {
|
func (r *bulkDedicatedLaneBatchRequest) prepare(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool, borrowItems bool) {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -94,6 +94,10 @@ func (r *bulkDedicatedLaneBatchRequest) prepare(ctx context.Context, dataID uint
|
|||||||
if deadline, ok := ctx.Deadline(); ok {
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
r.Deadline = deadline
|
r.Deadline = deadline
|
||||||
}
|
}
|
||||||
|
if borrowItems {
|
||||||
|
r.Items = items
|
||||||
|
return
|
||||||
|
}
|
||||||
if cap(r.Items) < len(items) {
|
if cap(r.Items) < len(items) {
|
||||||
r.Items = make([]bulkDedicatedSendRequest, len(items))
|
r.Items = make([]bulkDedicatedSendRequest, len(items))
|
||||||
} else {
|
} else {
|
||||||
@ -119,10 +123,10 @@ func (s *bulkDedicatedLaneSender) submitData(ctx context.Context, dataID uint64,
|
|||||||
Seq: seq,
|
Seq: seq,
|
||||||
Payload: append([]byte(nil), payload...),
|
Payload: append([]byte(nil), payload...),
|
||||||
}}
|
}}
|
||||||
return s.submitBatch(ctx, dataID, items, false)
|
return s.submitBatch(ctx, dataID, items, false, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int) (int, error) {
|
func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int, payloadOwned bool) (int, error) {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return 0, errTransportDetached
|
return 0, errTransportDetached
|
||||||
}
|
}
|
||||||
@ -132,6 +136,9 @@ func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64
|
|||||||
if chunkSize <= 0 {
|
if chunkSize <= 0 {
|
||||||
chunkSize = defaultBulkChunkSize
|
chunkSize = defaultBulkChunkSize
|
||||||
}
|
}
|
||||||
|
if submitted, written, err := s.tryDirectSubmitWrite(ctx, dataID, startSeq, payload, chunkSize); submitted {
|
||||||
|
return written, err
|
||||||
|
}
|
||||||
written := 0
|
written := 0
|
||||||
seq := startSeq
|
seq := startSeq
|
||||||
for written < len(payload) {
|
for written < len(payload) {
|
||||||
@ -170,7 +177,7 @@ func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64
|
|||||||
seq++
|
seq++
|
||||||
written = end
|
written = end
|
||||||
}
|
}
|
||||||
if err := s.submitWriteBatch(ctx, dataID, items); err != nil {
|
if err := s.submitWriteBatch(ctx, dataID, items, payloadOwned); err != nil {
|
||||||
return start, err
|
return start, err
|
||||||
}
|
}
|
||||||
start = written
|
start = written
|
||||||
@ -178,7 +185,7 @@ func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64
|
|||||||
return written, nil
|
return written, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *bulkDedicatedLaneSender) submitWriteBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest) error {
|
func (s *bulkDedicatedLaneSender) submitWriteBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, _ bool) error {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return errTransportDetached
|
return errTransportDetached
|
||||||
}
|
}
|
||||||
@ -188,8 +195,7 @@ func (s *bulkDedicatedLaneSender) submitWriteBatch(ctx context.Context, dataID u
|
|||||||
if submitted, err := s.tryDirectSubmitBatch(ctx, dataID, items); submitted {
|
if submitted, err := s.tryDirectSubmitBatch(ctx, dataID, items); submitted {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
queuedItems := copyBulkDedicatedSendRequests(items)
|
return s.submitBatch(ctx, dataID, items, true, true)
|
||||||
return s.submitBatch(ctx, dataID, queuedItems, true)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *bulkDedicatedLaneSender) submitControl(ctx context.Context, dataID uint64, frameType uint8, flags uint8, seq uint64, payload []byte) error {
|
func (s *bulkDedicatedLaneSender) submitControl(ctx context.Context, dataID uint64, frameType uint8, flags uint8, seq uint64, payload []byte) error {
|
||||||
@ -204,10 +210,10 @@ func (s *bulkDedicatedLaneSender) submitControl(ctx context.Context, dataID uint
|
|||||||
if len(payload) > 0 {
|
if len(payload) > 0 {
|
||||||
items[0].Payload = append([]byte(nil), payload...)
|
items[0].Payload = append([]byte(nil), payload...)
|
||||||
}
|
}
|
||||||
return s.submitBatch(ctx, dataID, items, true)
|
return s.submitBatch(ctx, dataID, items, true, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool) error {
|
func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool, borrowItems bool) error {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return errTransportDetached
|
return errTransportDetached
|
||||||
}
|
}
|
||||||
@ -218,7 +224,7 @@ func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
req := getBulkDedicatedLaneBatchRequest()
|
req := getBulkDedicatedLaneBatchRequest()
|
||||||
req.prepare(ctx, dataID, items, wait)
|
req.prepare(ctx, dataID, items, wait, borrowItems)
|
||||||
s.queued.Add(1)
|
s.queued.Add(1)
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@ -237,6 +243,104 @@ func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedLaneSender) tryDirectSubmitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int) (bool, int, error) {
|
||||||
|
if s == nil {
|
||||||
|
return true, 0, errTransportDetached
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if len(payload) == 0 {
|
||||||
|
return true, 0, nil
|
||||||
|
}
|
||||||
|
if chunkSize <= 0 {
|
||||||
|
chunkSize = defaultBulkChunkSize
|
||||||
|
}
|
||||||
|
if err := s.errSnapshot(); err != nil {
|
||||||
|
return true, 0, err
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return true, 0, normalizeStreamDeadlineError(ctx.Err())
|
||||||
|
case <-s.stopCh:
|
||||||
|
return true, 0, s.stoppedErr()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if s.queued.Load() != 0 {
|
||||||
|
return false, 0, nil
|
||||||
|
}
|
||||||
|
if !s.flushMu.TryLock() {
|
||||||
|
return false, 0, nil
|
||||||
|
}
|
||||||
|
defer s.flushMu.Unlock()
|
||||||
|
if s.queued.Load() != 0 {
|
||||||
|
return false, 0, nil
|
||||||
|
}
|
||||||
|
if err := s.errSnapshot(); err != nil {
|
||||||
|
return true, 0, err
|
||||||
|
}
|
||||||
|
written := 0
|
||||||
|
seq := startSeq
|
||||||
|
deadline, _ := ctx.Deadline()
|
||||||
|
for written < len(payload) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return true, written, normalizeStreamDeadlineError(ctx.Err())
|
||||||
|
case <-s.stopCh:
|
||||||
|
return true, written, s.stoppedErr()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
var itemBuf [bulkDedicatedBatchMaxItems]bulkDedicatedSendRequest
|
||||||
|
items := itemBuf[:0]
|
||||||
|
batchBytes := bulkDedicatedBatchHeaderLen
|
||||||
|
start := written
|
||||||
|
for written < len(payload) && len(items) < bulkDedicatedBatchMaxItems {
|
||||||
|
end := written + chunkSize
|
||||||
|
if end > len(payload) {
|
||||||
|
end = len(payload)
|
||||||
|
}
|
||||||
|
itemLen := bulkDedicatedSendRequestLenFromPayloadLen(end - written)
|
||||||
|
if len(items) > 0 && batchBytes+itemLen > bulkDedicatedBatchMaxPlainBytes {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
items = append(items, bulkDedicatedSendRequest{
|
||||||
|
Type: bulkFastPayloadTypeData,
|
||||||
|
Seq: seq,
|
||||||
|
Payload: payload[written:end],
|
||||||
|
})
|
||||||
|
batchBytes += itemLen
|
||||||
|
seq++
|
||||||
|
written = end
|
||||||
|
}
|
||||||
|
if len(items) == 0 {
|
||||||
|
end := written + chunkSize
|
||||||
|
if end > len(payload) {
|
||||||
|
end = len(payload)
|
||||||
|
}
|
||||||
|
items = append(items, bulkDedicatedSendRequest{
|
||||||
|
Type: bulkFastPayloadTypeData,
|
||||||
|
Seq: seq,
|
||||||
|
Payload: payload[written:end],
|
||||||
|
})
|
||||||
|
seq++
|
||||||
|
written = end
|
||||||
|
}
|
||||||
|
if err := s.flush([]bulkDedicatedOutboundBatch{{
|
||||||
|
DataID: dataID,
|
||||||
|
Items: items,
|
||||||
|
}}, deadline); err != nil {
|
||||||
|
err = normalizeDedicatedBulkSendError(err)
|
||||||
|
s.setErr(err)
|
||||||
|
s.failPending(err)
|
||||||
|
if s.fail != nil {
|
||||||
|
go s.fail(err)
|
||||||
|
}
|
||||||
|
return true, start, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true, written, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *bulkDedicatedLaneSender) tryDirectSubmitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest) (bool, error) {
|
func (s *bulkDedicatedLaneSender) tryDirectSubmitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest) (bool, error) {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return true, errTransportDetached
|
return true, errTransportDetached
|
||||||
|
|||||||
@ -7,6 +7,57 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestBulkDedicatedLaneBatchRequestPrepareBorrowedSharesItems(t *testing.T) {
|
||||||
|
req := getBulkDedicatedLaneBatchRequest()
|
||||||
|
defer req.recycle()
|
||||||
|
|
||||||
|
items := []bulkDedicatedSendRequest{{
|
||||||
|
Type: bulkFastPayloadTypeData,
|
||||||
|
Seq: 7,
|
||||||
|
Payload: []byte("hello"),
|
||||||
|
}}
|
||||||
|
req.prepare(context.Background(), 11, items, true, true)
|
||||||
|
|
||||||
|
if got, want := len(req.Items), 1; got != want {
|
||||||
|
t.Fatalf("prepared item count = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if &req.Items[0] != &items[0] {
|
||||||
|
t.Fatal("prepare with borrowed items should share request items")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBulkDedicatedLaneSenderTryDirectSubmitWriteFlushesWholePayload(t *testing.T) {
|
||||||
|
conn := &shortWriteBulkRecordConn{maxPerWrite: 1024}
|
||||||
|
encodeCalls := 0
|
||||||
|
sender := &bulkDedicatedLaneSender{
|
||||||
|
conn: conn,
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
encode: func(batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) {
|
||||||
|
encodeCalls++
|
||||||
|
payload, err := encodeBulkDedicatedBatchesPlain(batches)
|
||||||
|
return payload, nil, err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := bytes.Repeat([]byte("a"), 3*defaultBulkChunkSize)
|
||||||
|
submitted, written, err := sender.tryDirectSubmitWrite(context.Background(), 9, 1, payload, defaultBulkChunkSize)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tryDirectSubmitWrite error = %v", err)
|
||||||
|
}
|
||||||
|
if !submitted {
|
||||||
|
t.Fatal("tryDirectSubmitWrite should submit directly")
|
||||||
|
}
|
||||||
|
if got, want := written, len(payload); got != want {
|
||||||
|
t.Fatalf("written = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if encodeCalls == 0 {
|
||||||
|
t.Fatal("encode should be called at least once")
|
||||||
|
}
|
||||||
|
if got := sender.queued.Load(); got != 0 {
|
||||||
|
t.Fatalf("queued requests = %d, want 0", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBulkDedicatedLaneSenderCollectBatchRequestsBatchesAcrossDataIDs(t *testing.T) {
|
func TestBulkDedicatedLaneSenderCollectBatchRequestsBatchesAcrossDataIDs(t *testing.T) {
|
||||||
sender := &bulkDedicatedLaneSender{
|
sender := &bulkDedicatedLaneSender{
|
||||||
reqCh: make(chan *bulkDedicatedLaneBatchRequest, 3),
|
reqCh: make(chan *bulkDedicatedLaneBatchRequest, 3),
|
||||||
|
|||||||
@ -162,7 +162,8 @@ func (c *ClientCommon) tryDispatchBorrowedBulkTransportPayload(payload []byte) b
|
|||||||
if c == nil || len(payload) == 0 {
|
if c == nil || len(payload) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(c.modernPSKRuntime, c.msgDe, c.SecretKey, payload)
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
|
plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if c.showError || c.debugMode {
|
if c.showError || c.debugMode {
|
||||||
fmt.Println("client decode transport payload error", err)
|
fmt.Println("client decode transport payload error", err)
|
||||||
@ -194,7 +195,7 @@ func (s *ServerCommon) tryDispatchBorrowedBulkTransportPayload(source interface{
|
|||||||
if logical == nil {
|
if logical == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload)
|
plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s.showError || s.debugMode {
|
if s.showError || s.debugMode {
|
||||||
fmt.Println("server decode transport payload error", err)
|
fmt.Println("server decode transport payload error", err)
|
||||||
|
|||||||
165
bulk_fastpath.go
165
bulk_fastpath.go
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -132,8 +133,9 @@ func (c *ClientCommon) encodeBulkFastPayload(frame bulkFastFrame) ([]byte, error
|
|||||||
if c == nil {
|
if c == nil {
|
||||||
return nil, errBulkClientNil
|
return nil, errBulkClientNil
|
||||||
}
|
}
|
||||||
if c.fastPlainEncode != nil {
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
return encodeBulkFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame)
|
if profile.fastPlainEncode != nil {
|
||||||
|
return encodeBulkFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame)
|
||||||
}
|
}
|
||||||
plain, err := encodeBulkFastFramePayload(frame)
|
plain, err := encodeBulkFastFramePayload(frame)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -146,8 +148,9 @@ func (c *ClientCommon) encodeBulkFastBatchPayload(frames []bulkFastFrame) ([]byt
|
|||||||
if c == nil {
|
if c == nil {
|
||||||
return nil, errBulkClientNil
|
return nil, errBulkClientNil
|
||||||
}
|
}
|
||||||
if c.fastPlainEncode != nil {
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
return encodeBulkFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames)
|
if profile.fastPlainEncode != nil {
|
||||||
|
return encodeBulkFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames)
|
||||||
}
|
}
|
||||||
plain, err := encodeBulkFastBatchPlain(frames)
|
plain, err := encodeBulkFastBatchPlain(frames)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -160,11 +163,12 @@ func (c *ClientCommon) encodeBulkFastPayloadPooled(frame bulkFastFrame) ([]byte,
|
|||||||
if c == nil {
|
if c == nil {
|
||||||
return nil, nil, errBulkClientNil
|
return nil, nil, errBulkClientNil
|
||||||
}
|
}
|
||||||
if runtime := c.modernPSKRuntime; runtime != nil {
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
|
if runtime := profile.runtime; runtime != nil {
|
||||||
return encodeBulkFastFramePayloadPooled(runtime, frame)
|
return encodeBulkFastFramePayloadPooled(runtime, frame)
|
||||||
}
|
}
|
||||||
if c.fastPlainEncode != nil {
|
if profile.fastPlainEncode != nil {
|
||||||
payload, err := encodeBulkFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame)
|
payload, err := encodeBulkFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame)
|
||||||
return payload, nil, err
|
return payload, nil, err
|
||||||
}
|
}
|
||||||
plain, err := encodeBulkFastFramePayload(frame)
|
plain, err := encodeBulkFastFramePayload(frame)
|
||||||
@ -179,11 +183,12 @@ func (c *ClientCommon) encodeBulkFastBatchPayloadPooled(frames []bulkFastFrame)
|
|||||||
if c == nil {
|
if c == nil {
|
||||||
return nil, nil, errBulkClientNil
|
return nil, nil, errBulkClientNil
|
||||||
}
|
}
|
||||||
if runtime := c.modernPSKRuntime; runtime != nil {
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
|
if runtime := profile.runtime; runtime != nil {
|
||||||
return encodeBulkFastBatchPayloadPooled(runtime, frames)
|
return encodeBulkFastBatchPayloadPooled(runtime, frames)
|
||||||
}
|
}
|
||||||
if c.fastPlainEncode != nil {
|
if profile.fastPlainEncode != nil {
|
||||||
payload, err := encodeBulkFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames)
|
payload, err := encodeBulkFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames)
|
||||||
return payload, nil, err
|
return payload, nil, err
|
||||||
}
|
}
|
||||||
plain, err := encodeBulkFastBatchPlain(frames)
|
plain, err := encodeBulkFastBatchPlain(frames)
|
||||||
@ -460,28 +465,126 @@ func putBulkFastFrameScratch(buf []byte) {
|
|||||||
bulkFastFrameScratchPool.Put(buf[:0])
|
bulkFastFrameScratchPool.Put(buf[:0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func transportFastPayloadMagic(payload []byte) string {
|
||||||
|
if len(payload) < 4 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(payload[:4])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) decryptTransportPayloadPooled(payload []byte, release func()) ([]byte, func(), error) {
|
||||||
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
|
return decryptTransportPayloadCodecPooled(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, payload, release)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) decryptTransportPayloadLogicalPooled(logical *LogicalConn, payload []byte, release func()) ([]byte, func(), error) {
|
||||||
|
if logical == nil {
|
||||||
|
if release != nil {
|
||||||
|
release()
|
||||||
|
}
|
||||||
|
return nil, nil, errTransportDetached
|
||||||
|
}
|
||||||
|
return decryptTransportPayloadCodecPooled(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload, release)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) tryDispatchBorrowedTransportPlain(plain []byte, release func()) bool {
|
||||||
|
switch transportFastPayloadMagic(plain) {
|
||||||
|
case bulkFastPayloadMagic, bulkFastBatchMagic:
|
||||||
|
owner := newBulkReadPayloadOwner(release)
|
||||||
|
matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
|
||||||
|
c.dispatchFastBulkFrameWithOwner(frame, owner)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if owner != nil {
|
||||||
|
owner.done()
|
||||||
|
}
|
||||||
|
if !matched {
|
||||||
|
walkErr = errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
if walkErr != nil && (c.showError || c.debugMode) {
|
||||||
|
fmt.Println("client decode bulk fast payload error", walkErr)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
case streamFastPayloadMagic, streamFastBatchMagic:
|
||||||
|
owner := newStreamReadPayloadOwner(release)
|
||||||
|
matched, walkErr := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
|
||||||
|
c.dispatchFastStreamDataWithOwner(frame, owner)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if owner != nil {
|
||||||
|
owner.done()
|
||||||
|
}
|
||||||
|
if !matched {
|
||||||
|
walkErr = errStreamFastPayloadInvalid
|
||||||
|
}
|
||||||
|
if walkErr != nil && (c.showError || c.debugMode) {
|
||||||
|
fmt.Println("client decode stream fast payload error", walkErr)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) tryDispatchBorrowedTransportPlain(logical *LogicalConn, transport *TransportConn, conn net.Conn, plain []byte, release func()) bool {
|
||||||
|
switch transportFastPayloadMagic(plain) {
|
||||||
|
case bulkFastPayloadMagic, bulkFastBatchMagic:
|
||||||
|
owner := newBulkReadPayloadOwner(release)
|
||||||
|
matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
|
||||||
|
s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, owner)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if owner != nil {
|
||||||
|
owner.done()
|
||||||
|
}
|
||||||
|
if !matched {
|
||||||
|
walkErr = errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
if walkErr != nil && (s.showError || s.debugMode) {
|
||||||
|
fmt.Println("server decode bulk fast payload error", walkErr)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
case streamFastPayloadMagic, streamFastBatchMagic:
|
||||||
|
owner := newStreamReadPayloadOwner(release)
|
||||||
|
matched, walkErr := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
|
||||||
|
s.dispatchFastStreamDataWithOwner(logical, transport, conn, frame, owner)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if owner != nil {
|
||||||
|
owner.done()
|
||||||
|
}
|
||||||
|
if !matched {
|
||||||
|
walkErr = errStreamFastPayloadInvalid
|
||||||
|
}
|
||||||
|
if walkErr != nil && (s.showError || s.debugMode) {
|
||||||
|
fmt.Println("server decode stream fast payload error", walkErr)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error {
|
func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error {
|
||||||
plain, err := c.decryptTransportPayload(payload)
|
plain, err := c.decryptTransportPayload(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if frames, matched, err := decodeBulkFastFrames(plain); matched {
|
return c.dispatchInboundTransportPlain(plain, now)
|
||||||
if err != nil {
|
}
|
||||||
return err
|
|
||||||
}
|
func (c *ClientCommon) dispatchInboundTransportPlain(plain []byte, now time.Time) error {
|
||||||
for _, frame := range frames {
|
if matched, err := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
|
||||||
c.dispatchFastBulkFrame(frame)
|
c.dispatchFastBulkFrame(frame)
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}); matched {
|
||||||
if frames, matched, err := decodeStreamFastDataFrames(plain); matched {
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, frame := range frames {
|
if matched, err := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
|
||||||
c.dispatchFastStreamData(frame)
|
c.dispatchFastStreamData(frame)
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
|
}); matched {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
env, err := c.decodeEnvelopePlain(plain)
|
env, err := c.decodeEnvelopePlain(plain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -502,23 +605,21 @@ func (s *ServerCommon) dispatchInboundTransportPayload(logical *LogicalConn, tra
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if frames, matched, err := decodeBulkFastFrames(plain); matched {
|
return s.dispatchInboundTransportPlain(logical, transport, conn, plain, now)
|
||||||
if err != nil {
|
}
|
||||||
return err
|
|
||||||
}
|
func (s *ServerCommon) dispatchInboundTransportPlain(logical *LogicalConn, transport *TransportConn, conn net.Conn, plain []byte, now time.Time) error {
|
||||||
for _, frame := range frames {
|
if matched, err := walkBulkFastFrames(plain, func(frame bulkFastFrame) error {
|
||||||
s.dispatchFastBulkFrame(logical, transport, conn, frame)
|
s.dispatchFastBulkFrame(logical, transport, conn, frame)
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}); matched {
|
||||||
if frames, matched, err := decodeStreamFastDataFrames(plain); matched {
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, frame := range frames {
|
if matched, err := walkStreamFastFrames(plain, func(frame streamFastDataFrame) error {
|
||||||
s.dispatchFastStreamData(logical, transport, conn, frame)
|
s.dispatchFastStreamData(logical, transport, conn, frame)
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
|
}); matched {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
env, err := s.decodeEnvelopePlain(plain)
|
env, err := s.decodeEnvelopePlain(plain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -100,6 +100,176 @@ func TestBulkOpenAutoUDPFallsBackToShared(t *testing.T) {
|
|||||||
_ = accepted.Bulk.Close()
|
_ = accepted.Bulk.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenDedicatedBulkConnectByConnRejectedAsSingleConnMode(t *testing.T) {
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
|
||||||
|
}
|
||||||
|
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer right.Close()
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, right)
|
||||||
|
if err := client.ConnectByConn(left); err != nil {
|
||||||
|
t.Fatalf("ConnectByConn failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
_ = client.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{
|
||||||
|
Range: BulkRange{
|
||||||
|
Offset: 0,
|
||||||
|
Length: 128,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if !errors.Is(err, errBulkDedicatedSingleConn) {
|
||||||
|
t.Fatalf("client OpenDedicatedBulk over ConnectByConn error = %v, want %v", err, errBulkDedicatedSingleConn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenBulkAutoConnectByConnFallsBackToShared(t *testing.T) {
|
||||||
|
acceptCh := make(chan BulkAcceptInfo, 2)
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
|
||||||
|
}
|
||||||
|
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||||
|
acceptCh <- info
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer right.Close()
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, right)
|
||||||
|
if err := client.ConnectByConn(left); err != nil {
|
||||||
|
t.Fatalf("ConnectByConn failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
_ = client.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
|
||||||
|
Mode: BulkOpenModeAuto,
|
||||||
|
Range: BulkRange{
|
||||||
|
Offset: 0,
|
||||||
|
Length: 128,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client OpenBulk auto over ConnectByConn failed: %v", err)
|
||||||
|
}
|
||||||
|
if bulk.Snapshot().Dedicated {
|
||||||
|
t.Fatal("client OpenBulk auto over ConnectByConn should fall back to shared")
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = bulk.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second)
|
||||||
|
if accepted.Dedicated {
|
||||||
|
t.Fatal("server accepted bulk should be shared after ConnectByConn auto fallback")
|
||||||
|
}
|
||||||
|
if _, err := bulk.Write([]byte("shared-over-single-conn")); err != nil {
|
||||||
|
t.Fatalf("client bulk Write failed: %v", err)
|
||||||
|
}
|
||||||
|
readBulkExactly(t, accepted.Bulk, "shared-over-single-conn", 2*time.Second)
|
||||||
|
select {
|
||||||
|
case extra := <-acceptCh:
|
||||||
|
t.Fatalf("unexpected extra server bulk accept: %+v", extra)
|
||||||
|
case <-time.After(300 * time.Millisecond):
|
||||||
|
}
|
||||||
|
_ = accepted.Bulk.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenDedicatedBulkExternalTransportDialableSourceSucceeds(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
if err := UsePSKOverExternalTransportServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
|
||||||
|
}
|
||||||
|
acceptCh := make(chan BulkAcceptInfo, 2)
|
||||||
|
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||||
|
acceptCh <- info
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||||||
|
t.Fatalf("server Listen failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = server.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UsePSKOverExternalTransportClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||||||
|
t.Fatalf("client Connect failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
_ = client.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
bulk, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{
|
||||||
|
ID: "external-dedicated-dialable",
|
||||||
|
Range: BulkRange{
|
||||||
|
Offset: 0,
|
||||||
|
Length: 64,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client OpenDedicatedBulk failed: %v", err)
|
||||||
|
}
|
||||||
|
if !bulk.Snapshot().Dedicated {
|
||||||
|
t.Fatal("client OpenDedicatedBulk over dialable external transport should stay dedicated")
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = bulk.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
accepted := waitAcceptedBulkByID(t, acceptCh, bulk.ID(), 2*time.Second)
|
||||||
|
if !accepted.Dedicated {
|
||||||
|
t.Fatal("server accepted bulk should stay dedicated over dialable external transport")
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = accepted.Bulk.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientHandle := bulk.(*bulkHandle)
|
||||||
|
if clientHandle.dedicatedConnSnapshot() == nil {
|
||||||
|
t.Fatal("client dedicated sidecar conn should be attached")
|
||||||
|
}
|
||||||
|
if mainConn := client.clientTransportConnSnapshot(); mainConn != nil && clientHandle.dedicatedConnSnapshot() == mainConn {
|
||||||
|
t.Fatal("client dedicated sidecar should use an additional physical connection")
|
||||||
|
}
|
||||||
|
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionExternal {
|
||||||
|
t.Fatalf("client steady protection mode = %v, want %v", got, ProtectionExternal)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := "external-dedicated-sidecar"
|
||||||
|
if _, err := bulk.Write([]byte(payload)); err != nil {
|
||||||
|
t.Fatalf("client bulk Write failed: %v", err)
|
||||||
|
}
|
||||||
|
readBulkExactly(t, accepted.Bulk, payload, 2*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenDedicatedBulkWaitsForActiveSlotUntilContextDeadline(t *testing.T) {
|
func TestOpenDedicatedBulkWaitsForActiveSlotUntilContextDeadline(t *testing.T) {
|
||||||
server := NewServer().(*ServerCommon)
|
server := NewServer().(*ServerCommon)
|
||||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
|||||||
14
client.go
14
client.go
@ -38,6 +38,18 @@ type ClientCommon struct {
|
|||||||
modernPSKRuntime *modernPSKCodecRuntime
|
modernPSKRuntime *modernPSKCodecRuntime
|
||||||
handshakeRsaPubKey []byte
|
handshakeRsaPubKey []byte
|
||||||
SecretKey []byte
|
SecretKey []byte
|
||||||
|
transportProtection atomic.Pointer[transportProtectionProfile]
|
||||||
|
peerAttachSecurity atomic.Pointer[peerAttachSecurityState]
|
||||||
|
securityBootstrap transportProtectionProfile
|
||||||
|
securitySteady transportProtectionProfile
|
||||||
|
securitySteadyNegotiated transportProtectionProfile
|
||||||
|
securityAuthMode AuthMode
|
||||||
|
securityProtectionMode ProtectionMode
|
||||||
|
securityRequireForwardSecrecy bool
|
||||||
|
securityConfigured bool
|
||||||
|
peerAttachAuthenticated bool
|
||||||
|
peerAttachAuthFallback bool
|
||||||
|
peerAttachAt int64
|
||||||
noFinSyncMsgMaxKeepSeconds int
|
noFinSyncMsgMaxKeepSeconds int
|
||||||
lastHeartbeat int64
|
lastHeartbeat int64
|
||||||
heartbeatPeriod time.Duration
|
heartbeatPeriod time.Duration
|
||||||
@ -134,6 +146,8 @@ func NewClient() Client {
|
|||||||
client.fileEventObserver = normalizeFileEventCallback(nil)
|
client.fileEventObserver = normalizeFileEventCallback(nil)
|
||||||
client.stopCtx, client.stopFn = context.WithCancel(context.Background())
|
client.stopCtx, client.stopFn = context.WithCancel(context.Background())
|
||||||
client.sessionRuntime.Store(newClientSessionRuntimeBase(client.stopCtx, client.stopFn))
|
client.sessionRuntime.Store(newClientSessionRuntimeBase(client.stopCtx, client.stopFn))
|
||||||
|
client.setClientTransportProtectionProfile(defaultTransportProtectionProfile())
|
||||||
|
client.peerAttachSecurity.Store(defaultPeerAttachSecurityState())
|
||||||
bindClientStreamControl(&client)
|
bindClientStreamControl(&client)
|
||||||
bindClientBulkControl(&client)
|
bindClientBulkControl(&client)
|
||||||
client.getTransferState().setBuiltinHandler(client.builtinFileTransferHandler)
|
client.getTransferState().setBuiltinHandler(client.builtinFileTransferHandler)
|
||||||
|
|||||||
@ -294,7 +294,7 @@ func clientBulkWriteSender(c *ClientCommon, epoch uint64) bulkWriteSender {
|
|||||||
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return c.sendDedicatedBulkWrite(ctx, bulk, startSeq, payload)
|
return c.sendDedicatedBulkWrite(ctx, bulk, startSeq, payload, payloadOwned)
|
||||||
}
|
}
|
||||||
if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) {
|
if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) {
|
||||||
return 0, errTransportDetached
|
return 0, errTransportDetached
|
||||||
|
|||||||
@ -65,32 +65,40 @@ func (c *ClientCommon) RecoverTransferSnapshots(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte {
|
func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte {
|
||||||
return c.msgEn
|
return c.clientTransportProtectionSnapshot().msgEn
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: SetMsgEn overrides the transport codec directly.
|
// Deprecated: SetMsgEn overrides the transport codec directly.
|
||||||
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||||
func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) {
|
func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) {
|
||||||
c.msgEn = fn
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
c.fastStreamEncode = nil
|
profile.mode = ProtectionManaged
|
||||||
c.fastBulkEncode = nil
|
profile.msgEn = fn
|
||||||
c.fastPlainEncode = nil
|
profile.fastStreamEncode = nil
|
||||||
c.modernPSKRuntime = nil
|
profile.fastBulkEncode = nil
|
||||||
|
profile.fastPlainEncode = nil
|
||||||
|
profile.runtime = nil
|
||||||
|
c.setClientTransportProtectionProfile(profile)
|
||||||
|
c.clearClientSecurityProfiles()
|
||||||
c.securityReadyCheck = false
|
c.securityReadyCheck = false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte {
|
func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte {
|
||||||
return c.msgDe
|
return c.clientTransportProtectionSnapshot().msgDe
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: SetMsgDe overrides the transport codec directly.
|
// Deprecated: SetMsgDe overrides the transport codec directly.
|
||||||
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||||
func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) {
|
func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) {
|
||||||
c.msgDe = fn
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
c.fastStreamEncode = nil
|
profile.mode = ProtectionManaged
|
||||||
c.fastBulkEncode = nil
|
profile.msgDe = fn
|
||||||
c.fastPlainEncode = nil
|
profile.fastStreamEncode = nil
|
||||||
c.modernPSKRuntime = nil
|
profile.fastBulkEncode = nil
|
||||||
|
profile.fastPlainEncode = nil
|
||||||
|
profile.runtime = nil
|
||||||
|
c.setClientTransportProtectionProfile(profile)
|
||||||
|
c.clearClientSecurityProfiles()
|
||||||
c.securityReadyCheck = false
|
c.securityReadyCheck = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -103,20 +111,24 @@ func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientCommon) GetSecretKey() []byte {
|
func (c *ClientCommon) GetSecretKey() []byte {
|
||||||
return c.SecretKey
|
return c.clientTransportProtectionSnapshot().secretKey
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: SetSecretKey injects a raw transport key directly.
|
// Deprecated: SetSecretKey injects a raw transport key directly.
|
||||||
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||||
func (c *ClientCommon) SetSecretKey(key []byte) {
|
func (c *ClientCommon) SetSecretKey(key []byte) {
|
||||||
c.SecretKey = key
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
|
profile.mode = ProtectionManaged
|
||||||
|
profile.secretKey = cloneTransportProtectionKey(key)
|
||||||
if len(key) == 0 {
|
if len(key) == 0 {
|
||||||
c.modernPSKRuntime = nil
|
profile.runtime = nil
|
||||||
} else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil {
|
} else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil {
|
||||||
c.modernPSKRuntime = runtime
|
profile.runtime = runtime
|
||||||
} else {
|
} else {
|
||||||
c.modernPSKRuntime = nil
|
profile.runtime = nil
|
||||||
}
|
}
|
||||||
|
c.setClientTransportProtectionProfile(profile)
|
||||||
|
c.clearClientSecurityProfiles()
|
||||||
c.securityReadyCheck = len(key) == 0
|
c.securityReadyCheck = len(key) == 0
|
||||||
c.skipKeyExchange = true
|
c.skipKeyExchange = true
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,6 +8,8 @@ import (
|
|||||||
type clientConnAttachmentState struct {
|
type clientConnAttachmentState struct {
|
||||||
maxReadTimeout time.Duration
|
maxReadTimeout time.Duration
|
||||||
maxWriteTimeout time.Duration
|
maxWriteTimeout time.Duration
|
||||||
|
authMode AuthMode
|
||||||
|
protectionMode ProtectionMode
|
||||||
msgEn func([]byte, []byte) []byte
|
msgEn func([]byte, []byte) []byte
|
||||||
msgDe func([]byte, []byte) []byte
|
msgDe func([]byte, []byte) []byte
|
||||||
fastStreamEncode transportFastStreamEncoder
|
fastStreamEncode transportFastStreamEncoder
|
||||||
@ -16,6 +18,13 @@ type clientConnAttachmentState struct {
|
|||||||
modernPSKRuntime *modernPSKCodecRuntime
|
modernPSKRuntime *modernPSKCodecRuntime
|
||||||
handshakeRsaKey []byte
|
handshakeRsaKey []byte
|
||||||
secretKey []byte
|
secretKey []byte
|
||||||
|
keyMode string
|
||||||
|
sessionID []byte
|
||||||
|
forwardSecrecy bool
|
||||||
|
forwardSecrecyFallback bool
|
||||||
|
peerAttached bool
|
||||||
|
peerAttachFallback bool
|
||||||
|
peerAttachAt int64
|
||||||
lastHeartBeat int64
|
lastHeartBeat int64
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -26,6 +35,7 @@ func cloneClientConnAttachmentState(src *clientConnAttachmentState) *clientConnA
|
|||||||
cloned := *src
|
cloned := *src
|
||||||
cloned.handshakeRsaKey = cloneClientConnAttachmentBytes(src.handshakeRsaKey)
|
cloned.handshakeRsaKey = cloneClientConnAttachmentBytes(src.handshakeRsaKey)
|
||||||
cloned.secretKey = cloneClientConnAttachmentBytes(src.secretKey)
|
cloned.secretKey = cloneClientConnAttachmentBytes(src.secretKey)
|
||||||
|
cloned.sessionID = cloneClientConnAttachmentBytes(src.sessionID)
|
||||||
return &cloned
|
return &cloned
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -153,6 +163,7 @@ func (c *ClientConn) applyClientConnAttachmentProfile(maxReadTimeout time.Durati
|
|||||||
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
state.maxReadTimeout = maxReadTimeout
|
state.maxReadTimeout = maxReadTimeout
|
||||||
state.maxWriteTimeout = maxWriteTimeout
|
state.maxWriteTimeout = maxWriteTimeout
|
||||||
|
state.protectionMode = ProtectionManaged
|
||||||
state.msgEn = msgEn
|
state.msgEn = msgEn
|
||||||
state.msgDe = msgDe
|
state.msgDe = msgDe
|
||||||
state.modernPSKRuntime = nil
|
state.modernPSKRuntime = nil
|
||||||
@ -206,6 +217,7 @@ func (c *ClientConn) clientConnMsgEnSnapshot() func([]byte, []byte) []byte {
|
|||||||
|
|
||||||
func (c *ClientConn) setClientConnMsgEn(fn func([]byte, []byte) []byte) {
|
func (c *ClientConn) setClientConnMsgEn(fn func([]byte, []byte) []byte) {
|
||||||
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
|
state.protectionMode = ProtectionManaged
|
||||||
state.msgEn = fn
|
state.msgEn = fn
|
||||||
state.fastStreamEncode = nil
|
state.fastStreamEncode = nil
|
||||||
state.fastBulkEncode = nil
|
state.fastBulkEncode = nil
|
||||||
@ -223,6 +235,7 @@ func (c *ClientConn) clientConnMsgDeSnapshot() func([]byte, []byte) []byte {
|
|||||||
|
|
||||||
func (c *ClientConn) setClientConnMsgDe(fn func([]byte, []byte) []byte) {
|
func (c *ClientConn) setClientConnMsgDe(fn func([]byte, []byte) []byte) {
|
||||||
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
|
state.protectionMode = ProtectionManaged
|
||||||
state.msgDe = fn
|
state.msgDe = fn
|
||||||
state.fastStreamEncode = nil
|
state.fastStreamEncode = nil
|
||||||
state.fastBulkEncode = nil
|
state.fastBulkEncode = nil
|
||||||
@ -319,6 +332,39 @@ func (c *LogicalConn) modernPSKRuntimeSnapshot() *modernPSKCodecRuntime {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) protectionModeSnapshot() ProtectionMode {
|
||||||
|
if state := c.attachmentStateRaw(); state != nil {
|
||||||
|
return state.protectionMode
|
||||||
|
}
|
||||||
|
return ProtectionManaged
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) authModeSnapshot() AuthMode {
|
||||||
|
if state := c.attachmentStateRaw(); state != nil {
|
||||||
|
return state.authMode
|
||||||
|
}
|
||||||
|
return AuthNone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) peerAttachAuthenticatedSnapshot() (bool, bool, time.Time) {
|
||||||
|
if state := c.attachmentStateRaw(); state != nil {
|
||||||
|
if state.peerAttachAt == 0 {
|
||||||
|
return state.peerAttached, state.peerAttachFallback, time.Time{}
|
||||||
|
}
|
||||||
|
return state.peerAttached, state.peerAttachFallback, time.Unix(0, state.peerAttachAt)
|
||||||
|
}
|
||||||
|
return false, false, time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) markPeerAttachAuthenticated(authMode AuthMode, fallback bool, at time.Time) {
|
||||||
|
c.updateAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
|
state.authMode = authMode
|
||||||
|
state.peerAttached = true
|
||||||
|
state.peerAttachFallback = fallback
|
||||||
|
state.peerAttachAt = at.UnixNano()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (c *LogicalConn) setModernPSKRuntime(runtime *modernPSKCodecRuntime) {
|
func (c *LogicalConn) setModernPSKRuntime(runtime *modernPSKCodecRuntime) {
|
||||||
c.updateAttachmentState(func(state *clientConnAttachmentState) {
|
c.updateAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
state.modernPSKRuntime = runtime
|
state.modernPSKRuntime = runtime
|
||||||
|
|||||||
@ -22,10 +22,14 @@ type clientConnectSource struct {
|
|||||||
network string
|
network string
|
||||||
addr string
|
addr string
|
||||||
dialFn func(context.Context) (net.Conn, error)
|
dialFn func(context.Context) (net.Conn, error)
|
||||||
|
supportsAdditional bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientConnConnectSource(conn net.Conn) *clientConnectSource {
|
func newClientConnConnectSource(conn net.Conn) *clientConnectSource {
|
||||||
source := &clientConnectSource{kind: clientConnectSourceConn}
|
source := &clientConnectSource{
|
||||||
|
kind: clientConnectSourceConn,
|
||||||
|
supportsAdditional: false,
|
||||||
|
}
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return source
|
return source
|
||||||
}
|
}
|
||||||
@ -46,6 +50,7 @@ func newClientNetworkConnectSource(network string, addr string) *clientConnectSo
|
|||||||
kind: clientConnectSourceNetwork,
|
kind: clientConnectSourceNetwork,
|
||||||
network: network,
|
network: network,
|
||||||
addr: addr,
|
addr: addr,
|
||||||
|
supportsAdditional: true,
|
||||||
dialFn: func(context.Context) (net.Conn, error) {
|
dialFn: func(context.Context) (net.Conn, error) {
|
||||||
return transport.Dial(network, addr)
|
return transport.Dial(network, addr)
|
||||||
},
|
},
|
||||||
@ -57,6 +62,7 @@ func newClientTimeoutConnectSource(network string, addr string, timeout time.Dur
|
|||||||
kind: clientConnectSourceTimeout,
|
kind: clientConnectSourceTimeout,
|
||||||
network: network,
|
network: network,
|
||||||
addr: addr,
|
addr: addr,
|
||||||
|
supportsAdditional: true,
|
||||||
dialFn: func(context.Context) (net.Conn, error) {
|
dialFn: func(context.Context) (net.Conn, error) {
|
||||||
return transport.DialTimeout(network, addr, timeout)
|
return transport.DialTimeout(network, addr, timeout)
|
||||||
},
|
},
|
||||||
@ -67,6 +73,7 @@ func newClientFactoryConnectSource(dialFn func(context.Context) (net.Conn, error
|
|||||||
return &clientConnectSource{
|
return &clientConnectSource{
|
||||||
kind: clientConnectSourceFactory,
|
kind: clientConnectSourceFactory,
|
||||||
dialFn: dialFn,
|
dialFn: dialFn,
|
||||||
|
supportsAdditional: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -82,6 +89,10 @@ func (s *clientConnectSource) canReconnect() bool {
|
|||||||
return s != nil && s.dialFn != nil
|
return s != nil && s.dialFn != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *clientConnectSource) supportsAdditionalConn() bool {
|
||||||
|
return s != nil && s.supportsAdditional
|
||||||
|
}
|
||||||
|
|
||||||
func (s *clientConnectSource) isUDP() bool {
|
func (s *clientConnectSource) isUDP() bool {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@ -31,7 +31,11 @@ func (c *ClientCommon) ExchangeKey(newKey []byte) error {
|
|||||||
if string(data.Value) != "success" {
|
if string(data.Value) != "success" {
|
||||||
return errors.New("cannot exchange new aes-key")
|
return errors.New("cannot exchange new aes-key")
|
||||||
}
|
}
|
||||||
c.SecretKey = newKey
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
|
profile.mode = ProtectionManaged
|
||||||
|
profile.secretKey = cloneTransportProtectionKey(newKey)
|
||||||
|
profile.runtime = nil
|
||||||
|
c.setClientTransportProtectionProfile(profile)
|
||||||
time.Sleep(time.Millisecond * 100)
|
time.Sleep(time.Millisecond * 100)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -350,6 +350,8 @@ func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime,
|
|||||||
if rt == nil {
|
if rt == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
c.resetClientPeerAttachAuth()
|
||||||
|
c.activateClientBootstrapTransportProtection()
|
||||||
if runKeyExchange && !c.skipKeyExchange {
|
if runKeyExchange && !c.skipKeyExchange {
|
||||||
if err := c.keyExchangeFn(c); err != nil {
|
if err := c.keyExchangeFn(c); err != nil {
|
||||||
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "key exchange failed", err)
|
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "key exchange failed", err)
|
||||||
@ -358,6 +360,7 @@ func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime,
|
|||||||
if err := c.announceClientPeerIdentity(); err != nil {
|
if err := c.announceClientPeerIdentity(); err != nil {
|
||||||
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "peer attach failed", err)
|
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "peer attach failed", err)
|
||||||
}
|
}
|
||||||
|
c.activateClientSteadyTransportProtection()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -174,28 +174,37 @@ func (c *ClientCommon) dispatchTransportPayloadFast(payload []byte, release func
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if c.tryDispatchBorrowedBulkTransportPayload(payload) {
|
plain, plainRelease, err := c.decryptTransportPayloadPooled(payload, release)
|
||||||
if release != nil {
|
if err != nil {
|
||||||
release()
|
if c.showError || c.debugMode {
|
||||||
|
fmt.Println("client decode transport payload error", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
owned := append([]byte(nil), payload...)
|
if c.tryDispatchBorrowedTransportPlain(plain, plainRelease) {
|
||||||
if release != nil {
|
return
|
||||||
release()
|
|
||||||
}
|
}
|
||||||
if dispatcher == nil {
|
if dispatcher == nil {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
if err := c.dispatchInboundTransportPayload(owned, now); err != nil && (c.showError || c.debugMode) {
|
err := c.dispatchInboundTransportPlain(plain, now)
|
||||||
|
if plainRelease != nil {
|
||||||
|
plainRelease()
|
||||||
|
}
|
||||||
|
if err != nil && (c.showError || c.debugMode) {
|
||||||
fmt.Println("client decode envelope error", err)
|
fmt.Println("client decode envelope error", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
owned := plain
|
||||||
|
if plainRelease != nil {
|
||||||
|
owned = append([]byte(nil), plain...)
|
||||||
|
plainRelease()
|
||||||
|
}
|
||||||
c.wg.Add(1)
|
c.wg.Add(1)
|
||||||
if !dispatcher.Dispatch(clientInboundDispatchSource(), func() {
|
if !dispatcher.Dispatch(clientInboundDispatchSource(), func() {
|
||||||
defer c.wg.Done()
|
defer c.wg.Done()
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
if err := c.dispatchInboundTransportPayload(owned, now); err != nil && (c.showError || c.debugMode) {
|
if err := c.dispatchInboundTransportPlain(owned, now); err != nil && (c.showError || c.debugMode) {
|
||||||
fmt.Println("client decode envelope error", err)
|
fmt.Println("client decode envelope error", err)
|
||||||
}
|
}
|
||||||
}) {
|
}) {
|
||||||
|
|||||||
@ -26,6 +26,8 @@ type Client interface {
|
|||||||
BulkOpenTuning() BulkOpenTuning
|
BulkOpenTuning() BulkOpenTuning
|
||||||
SetBulkDedicatedAttachConfig(BulkDedicatedAttachConfig)
|
SetBulkDedicatedAttachConfig(BulkDedicatedAttachConfig)
|
||||||
BulkDedicatedAttachConfig() BulkDedicatedAttachConfig
|
BulkDedicatedAttachConfig() BulkDedicatedAttachConfig
|
||||||
|
SetPeerAttachSecurityConfig(PeerAttachSecurityConfig) error
|
||||||
|
PeerAttachSecurityConfig() PeerAttachSecurityConfig
|
||||||
SetFileReceiveDir(dir string) error
|
SetFileReceiveDir(dir string) error
|
||||||
send(msg TransferMsg) (WaitMsg, error)
|
send(msg TransferMsg) (WaitMsg, error)
|
||||||
sendEnvelope(env Envelope) error
|
sendEnvelope(env Envelope) error
|
||||||
|
|||||||
@ -404,6 +404,7 @@ func (c *LogicalConn) applyAttachmentProfile(maxReadTimeout time.Duration, maxWr
|
|||||||
c.updateAttachmentState(func(state *clientConnAttachmentState) {
|
c.updateAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
state.maxReadTimeout = maxReadTimeout
|
state.maxReadTimeout = maxReadTimeout
|
||||||
state.maxWriteTimeout = maxWriteTimeout
|
state.maxWriteTimeout = maxWriteTimeout
|
||||||
|
state.protectionMode = ProtectionManaged
|
||||||
state.msgEn = msgEn
|
state.msgEn = msgEn
|
||||||
state.msgDe = msgDe
|
state.msgDe = msgDe
|
||||||
state.fastStreamEncode = fastStreamEncode
|
state.fastStreamEncode = fastStreamEncode
|
||||||
@ -525,6 +526,7 @@ func (c *LogicalConn) runtimeSnapshot() ClientConnRuntimeSnapshot {
|
|||||||
return ClientConnRuntimeSnapshot{}
|
return ClientConnRuntimeSnapshot{}
|
||||||
}
|
}
|
||||||
status := c.Status()
|
status := c.Status()
|
||||||
|
authenticated, fallback, attachAt := c.peerAttachAuthenticatedSnapshot()
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
snapshot := ClientConnRuntimeSnapshot{
|
snapshot := ClientConnRuntimeSnapshot{
|
||||||
ClientID: c.clientIDSnapshot(),
|
ClientID: c.clientIDSnapshot(),
|
||||||
@ -536,6 +538,11 @@ func (c *LogicalConn) runtimeSnapshot() ClientConnRuntimeSnapshot {
|
|||||||
TransportAttachCount: c.clientConnTransportAttachCountSnapshot(),
|
TransportAttachCount: c.clientConnTransportAttachCountSnapshot(),
|
||||||
TransportDetachCount: c.clientConnTransportDetachCountSnapshot(),
|
TransportDetachCount: c.clientConnTransportDetachCountSnapshot(),
|
||||||
LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(),
|
LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(),
|
||||||
|
AuthMode: authModeName(c.authModeSnapshot()),
|
||||||
|
ProtectionMode: protectionModeName(c.protectionModeSnapshot()),
|
||||||
|
PeerAttachAuthenticated: authenticated,
|
||||||
|
PeerAttachAuthFallback: fallback,
|
||||||
|
LastPeerAttachAt: attachAt,
|
||||||
}
|
}
|
||||||
if status.Err != nil {
|
if status.Err != nil {
|
||||||
snapshot.Error = status.Err.Error()
|
snapshot.Error = status.Err.Error()
|
||||||
|
|||||||
@ -43,6 +43,20 @@ func TestHydrateServerMessagePeerFieldsFromLogicalConn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHydrateServerMessagePeerFieldsZeroValueDoesNotPanic(t *testing.T) {
|
||||||
|
message := hydrateServerMessagePeerFields(Message{})
|
||||||
|
|
||||||
|
if message.LogicalConn != nil {
|
||||||
|
t.Fatal("zero-value message should not hydrate logical conn")
|
||||||
|
}
|
||||||
|
if message.ClientConn != nil {
|
||||||
|
t.Fatal("zero-value message should not hydrate client conn")
|
||||||
|
}
|
||||||
|
if message.TransportConn != nil {
|
||||||
|
t.Fatal("zero-value message should not hydrate transport conn")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServerPublishSendFileEventHydratesLogicalConnAlias(t *testing.T) {
|
func TestServerPublishSendFileEventHydratesLogicalConnAlias(t *testing.T) {
|
||||||
server := NewServer().(*ServerCommon)
|
server := NewServer().(*ServerCommon)
|
||||||
left, right := net.Pipe()
|
left, right := net.Pipe()
|
||||||
|
|||||||
33
msg.go
33
msg.go
@ -40,6 +40,7 @@ type Message struct {
|
|||||||
ClientConn *ClientConn
|
ClientConn *ClientConn
|
||||||
TransportConn *TransportConn
|
TransportConn *TransportConn
|
||||||
ServerConn Client
|
ServerConn Client
|
||||||
|
inboundTransportProfile *transportProtectionProfile
|
||||||
TransferMsg
|
TransferMsg
|
||||||
Time time.Time
|
Time time.Time
|
||||||
inboundConn net.Conn
|
inboundConn net.Conn
|
||||||
@ -58,7 +59,7 @@ type messageLogicalTransferSender interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type messageInboundTransferSender interface {
|
type messageInboundTransferSender interface {
|
||||||
sendTransferInbound(*LogicalConn, *TransportConn, net.Conn, TransferMsg) error
|
sendTransferInbound(*LogicalConn, *TransportConn, net.Conn, *transportProtectionProfile, TransferMsg) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Message) Reply(value MsgVal) (err error) {
|
func (m *Message) Reply(value MsgVal) (err error) {
|
||||||
@ -86,7 +87,7 @@ func (m *Message) Reply(value MsgVal) (err error) {
|
|||||||
if sender == nil {
|
if sender == nil {
|
||||||
return transportDetachedErrorForPeer(logical, transport)
|
return transportDetachedErrorForPeer(logical, transport)
|
||||||
}
|
}
|
||||||
return sender.sendTransferInbound(logical, transport, m.inboundConn, reply)
|
return sender.sendTransferInbound(logical, transport, m.inboundConn, messageInboundTransportProtectionSnapshot(m), reply)
|
||||||
}
|
}
|
||||||
if transport != nil {
|
if transport != nil {
|
||||||
_, err = transport.sendTransfer(reply)
|
_, err = transport.sendTransfer(reply)
|
||||||
@ -123,12 +124,19 @@ func hydrateServerMessagePeerFields(message Message) Message {
|
|||||||
if message.LogicalConn == nil {
|
if message.LogicalConn == nil {
|
||||||
message.LogicalConn = logicalConnFromClient(message.ClientConn)
|
message.LogicalConn = logicalConnFromClient(message.ClientConn)
|
||||||
}
|
}
|
||||||
if message.ClientConn == nil {
|
if message.LogicalConn == nil && message.TransportConn != nil {
|
||||||
|
message.LogicalConn = message.TransportConn.logicalConnSnapshot()
|
||||||
|
}
|
||||||
|
if message.ClientConn == nil && message.LogicalConn != nil {
|
||||||
message.ClientConn = message.LogicalConn.compatClientConn()
|
message.ClientConn = message.LogicalConn.compatClientConn()
|
||||||
}
|
}
|
||||||
if message.TransportConn == nil && message.LogicalConn != nil {
|
if message.TransportConn == nil && message.LogicalConn != nil {
|
||||||
message.TransportConn = message.LogicalConn.CurrentTransportConn()
|
message.TransportConn = message.LogicalConn.CurrentTransportConn()
|
||||||
}
|
}
|
||||||
|
if message.inboundConn != nil && message.inboundTransportProfile == nil && message.LogicalConn != nil {
|
||||||
|
profile := message.LogicalConn.transportProtectionProfileSnapshot()
|
||||||
|
message.inboundTransportProfile = &profile
|
||||||
|
}
|
||||||
return message
|
return message
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -155,3 +163,22 @@ func messageTransportConnSnapshot(message *Message) *TransportConn {
|
|||||||
}
|
}
|
||||||
return logical.CurrentTransportConn()
|
return logical.CurrentTransportConn()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func messageInboundTransportProtectionSnapshot(message *Message) *transportProtectionProfile {
|
||||||
|
if message == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if message.inboundTransportProfile != nil {
|
||||||
|
return message.inboundTransportProfile
|
||||||
|
}
|
||||||
|
if message.inboundConn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
logical := messageLogicalConnSnapshot(message)
|
||||||
|
if logical == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
profile := logical.transportProtectionProfileSnapshot()
|
||||||
|
message.inboundTransportProfile = &profile
|
||||||
|
return message.inboundTransportProfile
|
||||||
|
}
|
||||||
|
|||||||
517
peer_attach_auth.go
Normal file
517
peer_attach_auth.go
Normal file
@ -0,0 +1,517 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/hmac"
|
||||||
|
cryptorand "crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
peerAttachFeatureExplicitAuth uint64 = 1 << iota
|
||||||
|
peerAttachFeatureChannelBinding
|
||||||
|
peerAttachFeatureForwardSecrecy
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
peerAttachNonceSize = 16
|
||||||
|
peerAttachReplayTTL = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errPeerAttachAuthInvalid = errors.New("peer attach auth invalid")
|
||||||
|
errPeerAttachReplayRejected = errors.New("peer attach replay rejected")
|
||||||
|
errPeerAttachReplayWindowFull = errors.New("peer attach replay window full")
|
||||||
|
errPeerAttachExplicitAuthRequired = errors.New("peer attach explicit auth required")
|
||||||
|
errPeerAttachChannelBindingRequired = errors.New("peer attach channel binding required")
|
||||||
|
errPeerAttachChannelBindingUnavailable = errors.New("peer attach channel binding unavailable")
|
||||||
|
errPeerAttachForwardSecrecyRequired = errors.New("peer attach forward secrecy required")
|
||||||
|
)
|
||||||
|
|
||||||
|
type peerAttachAuthResult struct {
|
||||||
|
explicit bool
|
||||||
|
fallback bool
|
||||||
|
clientNonce []byte
|
||||||
|
serverNonce []byte
|
||||||
|
channelBinding []byte
|
||||||
|
clientECDHEPublicKey []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type peerAttachReplayCache struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
entries map[string]time.Time
|
||||||
|
rejects atomic.Int64
|
||||||
|
overflowRejects atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPeerAttachNonce() ([]byte, error) {
|
||||||
|
buf := make([]byte, peerAttachNonceSize)
|
||||||
|
if _, err := cryptorand.Read(buf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendPeerAttachAuthBytes(dst []byte, data []byte) []byte {
|
||||||
|
dst = binary.BigEndian.AppendUint32(dst, uint32(len(data)))
|
||||||
|
return append(dst, data...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendPeerAttachAuthString(dst []byte, value string) []byte {
|
||||||
|
return appendPeerAttachAuthBytes(dst, []byte(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendPeerAttachAuthBool(dst []byte, value bool) []byte {
|
||||||
|
if value {
|
||||||
|
return append(dst, 1)
|
||||||
|
}
|
||||||
|
return append(dst, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func peerAttachRequestAuthPayload(req peerAttachRequest, channelBinding []byte) []byte {
|
||||||
|
buf := make([]byte, 0, 96+len(req.PeerID)+len(channelBinding))
|
||||||
|
buf = appendPeerAttachAuthString(buf, "notify/peer-attach/request-auth/v1")
|
||||||
|
buf = binary.BigEndian.AppendUint64(buf, req.Features)
|
||||||
|
buf = appendPeerAttachAuthString(buf, req.PeerID)
|
||||||
|
buf = appendPeerAttachAuthBytes(buf, req.ClientNonce)
|
||||||
|
if supportsPeerAttachChannelBinding(req.Features) {
|
||||||
|
buf = appendPeerAttachAuthBytes(buf, channelBinding)
|
||||||
|
}
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func peerAttachResponseAuthPayload(req peerAttachRequest, resp peerAttachResponse, channelBinding []byte) []byte {
|
||||||
|
buf := make([]byte, 0, 160+len(req.PeerID)+len(resp.PeerID)+len(resp.Error)+len(channelBinding))
|
||||||
|
buf = appendPeerAttachAuthString(buf, "notify/peer-attach/response-auth/v1")
|
||||||
|
buf = binary.BigEndian.AppendUint64(buf, req.Features)
|
||||||
|
buf = appendPeerAttachAuthString(buf, req.PeerID)
|
||||||
|
buf = appendPeerAttachAuthBytes(buf, req.ClientNonce)
|
||||||
|
buf = binary.BigEndian.AppendUint64(buf, resp.Features)
|
||||||
|
buf = appendPeerAttachAuthString(buf, resp.PeerID)
|
||||||
|
buf = appendPeerAttachAuthBool(buf, resp.Accepted)
|
||||||
|
buf = appendPeerAttachAuthBool(buf, resp.Reused)
|
||||||
|
buf = appendPeerAttachAuthString(buf, resp.Error)
|
||||||
|
buf = appendPeerAttachAuthBytes(buf, resp.ServerNonce)
|
||||||
|
if supportsPeerAttachChannelBinding(resp.Features) {
|
||||||
|
buf = appendPeerAttachAuthBytes(buf, channelBinding)
|
||||||
|
}
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func signPeerAttachPayload(secretKey []byte, payload []byte) []byte {
|
||||||
|
if len(secretKey) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mac := hmac.New(sha256.New, secretKey)
|
||||||
|
_, _ = mac.Write(payload)
|
||||||
|
return mac.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func computePeerAttachRequestAuthTag(secretKey []byte, req peerAttachRequest, channelBinding []byte) []byte {
|
||||||
|
return signPeerAttachPayload(secretKey, peerAttachRequestAuthPayload(req, channelBinding))
|
||||||
|
}
|
||||||
|
|
||||||
|
func computePeerAttachResponseAuthTag(secretKey []byte, req peerAttachRequest, resp peerAttachResponse, channelBinding []byte) []byte {
|
||||||
|
return signPeerAttachPayload(secretKey, peerAttachResponseAuthPayload(req, resp, channelBinding))
|
||||||
|
}
|
||||||
|
|
||||||
|
func supportsExplicitPeerAttachAuth(features uint64) bool {
|
||||||
|
return features&peerAttachFeatureExplicitAuth != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func supportsPeerAttachChannelBinding(features uint64) bool {
|
||||||
|
return features&peerAttachFeatureChannelBinding != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func supportsPeerAttachForwardSecrecy(features uint64) bool {
|
||||||
|
return features&peerAttachFeatureForwardSecrecy != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func classifyPeerAttachRejectCounter(s *ServerCommon, err error) {
|
||||||
|
if s == nil || err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, errPeerAttachReplayRejected):
|
||||||
|
s.peerAttachReplay.rejects.Add(1)
|
||||||
|
case errors.Is(err, errPeerAttachReplayWindowFull):
|
||||||
|
s.peerAttachReplay.overflowRejects.Add(1)
|
||||||
|
case errors.Is(err, errPeerAttachExplicitAuthRequired), errors.Is(err, errPeerAttachChannelBindingRequired), errors.Is(err, errPeerAttachForwardSecrecyRequired):
|
||||||
|
s.peerAttachDowngradeRejectCount.Add(1)
|
||||||
|
case errors.Is(err, errPeerAttachChannelBindingUnavailable):
|
||||||
|
s.peerAttachBindingRejectCount.Add(1)
|
||||||
|
default:
|
||||||
|
s.peerAttachAuthRejectCount.Add(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) shouldUseExplicitPeerAttachAuth() bool {
|
||||||
|
if c == nil || !c.securityConfigured {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return c.securityAuthMode == AuthPSK && len(c.securityBootstrap.secretKey) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolvePeerAttachChannelBinding(provider PeerAttachChannelBindingProvider, role PeerAttachChannelBindingRole, peerID string, conn net.Conn) ([]byte, error) {
|
||||||
|
if provider == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return nil, errPeerAttachChannelBindingUnavailable
|
||||||
|
}
|
||||||
|
binding, err := provider(PeerAttachChannelBindingContext{
|
||||||
|
Role: role,
|
||||||
|
PeerID: peerID,
|
||||||
|
Conn: conn,
|
||||||
|
})
|
||||||
|
if err != nil || len(binding) == 0 {
|
||||||
|
return nil, errPeerAttachChannelBindingUnavailable
|
||||||
|
}
|
||||||
|
return bytes.Clone(binding), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) buildPeerAttachRequest(peerID string) (peerAttachRequest, peerAttachRequestState, error) {
|
||||||
|
cfg := c.peerAttachSecuritySnapshot()
|
||||||
|
req := peerAttachRequest{
|
||||||
|
PeerID: stringsTrimSpaceNoAlloc(peerID),
|
||||||
|
}
|
||||||
|
if !c.shouldUseExplicitPeerAttachAuth() {
|
||||||
|
if c.clientRequiresForwardSecrecy() {
|
||||||
|
return peerAttachRequest{}, peerAttachRequestState{}, errPeerAttachForwardSecrecyRequired
|
||||||
|
}
|
||||||
|
if cfg.requireExplicitAuth {
|
||||||
|
return peerAttachRequest{}, peerAttachRequestState{}, errPeerAttachExplicitAuthRequired
|
||||||
|
}
|
||||||
|
return req, peerAttachRequestState{}, nil
|
||||||
|
}
|
||||||
|
nonce, err := newPeerAttachNonce()
|
||||||
|
if err != nil {
|
||||||
|
return peerAttachRequest{}, peerAttachRequestState{}, err
|
||||||
|
}
|
||||||
|
req.Features = peerAttachFeatureExplicitAuth
|
||||||
|
requestState := peerAttachRequestState{}
|
||||||
|
var channelBinding []byte
|
||||||
|
if cfg.channelBinding != nil {
|
||||||
|
channelBinding, err = resolvePeerAttachChannelBinding(cfg.channelBinding, PeerAttachChannelBindingRoleClient, req.PeerID, c.clientTransportConnSnapshot())
|
||||||
|
if err != nil {
|
||||||
|
return peerAttachRequest{}, peerAttachRequestState{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(channelBinding) != 0 {
|
||||||
|
req.Features |= peerAttachFeatureChannelBinding
|
||||||
|
}
|
||||||
|
if cfg.requireChannelBinding && !supportsPeerAttachChannelBinding(req.Features) {
|
||||||
|
return peerAttachRequest{}, peerAttachRequestState{}, errPeerAttachChannelBindingRequired
|
||||||
|
}
|
||||||
|
if c.clientSupportsForwardSecrecy() {
|
||||||
|
requestState.forwardSecrecy, err = newPeerAttachForwardSecrecyClientState()
|
||||||
|
if err != nil {
|
||||||
|
return peerAttachRequest{}, peerAttachRequestState{}, err
|
||||||
|
}
|
||||||
|
req.Features |= peerAttachFeatureForwardSecrecy
|
||||||
|
req.ClientECDHEPublicKey = bytes.Clone(requestState.forwardSecrecy.publicKey)
|
||||||
|
}
|
||||||
|
req.ClientNonce = nonce
|
||||||
|
req.AuthTag = computePeerAttachRequestAuthTag(c.securityBootstrap.secretKey, req, channelBinding)
|
||||||
|
return req, requestState, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) verifyPeerAttachResponse(req peerAttachRequest, resp peerAttachResponse, requestState peerAttachRequestState) (peerAttachResponseVerifyResult, error) {
|
||||||
|
cfg := c.peerAttachSecuritySnapshot()
|
||||||
|
baseSteady := transportProtectionProfile{}
|
||||||
|
if c != nil {
|
||||||
|
baseSteady = c.securitySteady.clone().withForwardSecrecyFallback(false)
|
||||||
|
}
|
||||||
|
result := peerAttachResponseVerifyResult{steadyProfile: baseSteady}
|
||||||
|
if c == nil || !c.shouldUseExplicitPeerAttachAuth() {
|
||||||
|
if c.clientRequiresForwardSecrecy() {
|
||||||
|
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyRequired
|
||||||
|
}
|
||||||
|
if cfg.requireExplicitAuth {
|
||||||
|
return peerAttachResponseVerifyResult{}, errPeerAttachExplicitAuthRequired
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
if !supportsExplicitPeerAttachAuth(resp.Features) {
|
||||||
|
if c.clientRequiresForwardSecrecy() {
|
||||||
|
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyRequired
|
||||||
|
}
|
||||||
|
if cfg.requireExplicitAuth {
|
||||||
|
return peerAttachResponseVerifyResult{}, errPeerAttachExplicitAuthRequired
|
||||||
|
}
|
||||||
|
if c.clientSupportsForwardSecrecy() {
|
||||||
|
result.steadyProfile = result.steadyProfile.withForwardSecrecyFallback(true)
|
||||||
|
}
|
||||||
|
result.authFallback = true
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
var channelBinding []byte
|
||||||
|
if supportsPeerAttachChannelBinding(req.Features) {
|
||||||
|
if cfg.channelBinding == nil {
|
||||||
|
return peerAttachResponseVerifyResult{}, errPeerAttachChannelBindingUnavailable
|
||||||
|
}
|
||||||
|
if !supportsPeerAttachChannelBinding(resp.Features) {
|
||||||
|
return peerAttachResponseVerifyResult{}, errPeerAttachChannelBindingRequired
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
channelBinding, err = resolvePeerAttachChannelBinding(cfg.channelBinding, PeerAttachChannelBindingRoleClient, req.PeerID, c.clientTransportConnSnapshot())
|
||||||
|
if err != nil {
|
||||||
|
return peerAttachResponseVerifyResult{}, err
|
||||||
|
}
|
||||||
|
} else if cfg.requireChannelBinding {
|
||||||
|
return peerAttachResponseVerifyResult{}, errPeerAttachChannelBindingRequired
|
||||||
|
}
|
||||||
|
if len(resp.ServerNonce) != peerAttachNonceSize {
|
||||||
|
return peerAttachResponseVerifyResult{}, errPeerAttachAuthInvalid
|
||||||
|
}
|
||||||
|
expected := computePeerAttachResponseAuthTag(c.securityBootstrap.secretKey, req, peerAttachResponse{
|
||||||
|
PeerID: resp.PeerID,
|
||||||
|
Accepted: resp.Accepted,
|
||||||
|
Reused: resp.Reused,
|
||||||
|
Error: resp.Error,
|
||||||
|
Features: resp.Features,
|
||||||
|
ServerNonce: resp.ServerNonce,
|
||||||
|
}, channelBinding)
|
||||||
|
if !hmac.Equal(resp.AuthTag, expected) {
|
||||||
|
return peerAttachResponseVerifyResult{}, errPeerAttachAuthInvalid
|
||||||
|
}
|
||||||
|
if requestState.forwardSecrecy == nil || !supportsPeerAttachForwardSecrecy(req.Features) {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
if !supportsPeerAttachForwardSecrecy(resp.Features) {
|
||||||
|
if c.clientRequiresForwardSecrecy() {
|
||||||
|
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyRequired
|
||||||
|
}
|
||||||
|
result.steadyProfile = result.steadyProfile.withForwardSecrecyFallback(true)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
if resp.KeyMode != "" && resp.KeyMode != peerAttachKeyModeECDHE {
|
||||||
|
return peerAttachResponseVerifyResult{}, errPeerAttachForwardSecrecyInvalid
|
||||||
|
}
|
||||||
|
profile, err := derivePeerAttachForwardSecrecyTransportProfile(c.securitySteady, c.securityBootstrap.secretKey, requestState.forwardSecrecy.privateKey, resp.ServerECDHEPublicKey, req, resp)
|
||||||
|
if err != nil {
|
||||||
|
return peerAttachResponseVerifyResult{}, err
|
||||||
|
}
|
||||||
|
result.steadyProfile = profile
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) validatePeerAttachRequestAuth(logical *LogicalConn, transport net.Conn, req peerAttachRequest) (peerAttachAuthResult, error) {
|
||||||
|
cfg := s.peerAttachSecuritySnapshot()
|
||||||
|
if !supportsExplicitPeerAttachAuth(req.Features) {
|
||||||
|
if s.serverRequiresForwardSecrecy() {
|
||||||
|
return peerAttachAuthResult{}, errPeerAttachForwardSecrecyRequired
|
||||||
|
}
|
||||||
|
if cfg.requireExplicitAuth {
|
||||||
|
return peerAttachAuthResult{}, errPeerAttachExplicitAuthRequired
|
||||||
|
}
|
||||||
|
return peerAttachAuthResult{fallback: true}, nil
|
||||||
|
}
|
||||||
|
if logical == nil {
|
||||||
|
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
|
||||||
|
}
|
||||||
|
var channelBinding []byte
|
||||||
|
if supportsPeerAttachChannelBinding(req.Features) {
|
||||||
|
if cfg.channelBinding == nil {
|
||||||
|
return peerAttachAuthResult{}, errPeerAttachChannelBindingUnavailable
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
channelBinding, err = resolvePeerAttachChannelBinding(cfg.channelBinding, PeerAttachChannelBindingRoleServer, req.PeerID, transport)
|
||||||
|
if err != nil {
|
||||||
|
return peerAttachAuthResult{}, err
|
||||||
|
}
|
||||||
|
} else if cfg.requireChannelBinding {
|
||||||
|
return peerAttachAuthResult{}, errPeerAttachChannelBindingRequired
|
||||||
|
}
|
||||||
|
secretKey := logical.secretKeySnapshot()
|
||||||
|
if len(secretKey) == 0 {
|
||||||
|
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
|
||||||
|
}
|
||||||
|
if len(req.ClientNonce) != peerAttachNonceSize {
|
||||||
|
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
|
||||||
|
}
|
||||||
|
expected := computePeerAttachRequestAuthTag(secretKey, peerAttachRequest{
|
||||||
|
PeerID: req.PeerID,
|
||||||
|
Features: req.Features,
|
||||||
|
ClientNonce: req.ClientNonce,
|
||||||
|
}, channelBinding)
|
||||||
|
if !hmac.Equal(req.AuthTag, expected) {
|
||||||
|
return peerAttachAuthResult{}, errPeerAttachAuthInvalid
|
||||||
|
}
|
||||||
|
if supportsPeerAttachForwardSecrecy(req.Features) {
|
||||||
|
if len(req.ClientECDHEPublicKey) != peerAttachECDHEPublicKeySize {
|
||||||
|
return peerAttachAuthResult{}, errPeerAttachForwardSecrecyInvalid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := s.acceptPeerAttachReplay(req.PeerID, req.ClientNonce, time.Now(), cfg.replayWindow, cfg.replayCapacity); err != nil {
|
||||||
|
return peerAttachAuthResult{}, err
|
||||||
|
}
|
||||||
|
serverNonce, err := newPeerAttachNonce()
|
||||||
|
if err != nil {
|
||||||
|
return peerAttachAuthResult{}, err
|
||||||
|
}
|
||||||
|
return peerAttachAuthResult{
|
||||||
|
explicit: true,
|
||||||
|
clientNonce: bytes.Clone(req.ClientNonce),
|
||||||
|
serverNonce: serverNonce,
|
||||||
|
channelBinding: channelBinding,
|
||||||
|
clientECDHEPublicKey: bytes.Clone(req.ClientECDHEPublicKey),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) signPeerAttachResponse(logical *LogicalConn, req peerAttachRequest, resp *peerAttachResponse, auth peerAttachAuthResult) {
|
||||||
|
if s == nil || logical == nil || resp == nil || !auth.explicit {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
secretKey := logical.secretKeySnapshot()
|
||||||
|
if len(secretKey) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.Features |= peerAttachFeatureExplicitAuth
|
||||||
|
if len(auth.channelBinding) != 0 {
|
||||||
|
resp.Features |= peerAttachFeatureChannelBinding
|
||||||
|
}
|
||||||
|
resp.ServerNonce = bytes.Clone(auth.serverNonce)
|
||||||
|
resp.AuthTag = computePeerAttachResponseAuthTag(secretKey, req, *resp, auth.channelBinding)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) preparePeerAttachSteadyTransportProfile(logical *LogicalConn, req peerAttachRequest, resp *peerAttachResponse, auth peerAttachAuthResult) (transportProtectionProfile, error) {
|
||||||
|
if s == nil {
|
||||||
|
return transportProtectionProfile{}, nil
|
||||||
|
}
|
||||||
|
profile := s.securitySteady.clone().withForwardSecrecyFallback(false)
|
||||||
|
if resp != nil && auth.explicit {
|
||||||
|
resp.Features |= peerAttachFeatureExplicitAuth
|
||||||
|
if len(auth.channelBinding) != 0 {
|
||||||
|
resp.Features |= peerAttachFeatureChannelBinding
|
||||||
|
}
|
||||||
|
resp.ServerNonce = bytes.Clone(auth.serverNonce)
|
||||||
|
}
|
||||||
|
if resp != nil && resp.KeyMode == "" {
|
||||||
|
resp.KeyMode = profile.keyMode
|
||||||
|
}
|
||||||
|
if !s.serverSupportsForwardSecrecy() {
|
||||||
|
return profile, nil
|
||||||
|
}
|
||||||
|
if !auth.explicit || !supportsPeerAttachForwardSecrecy(req.Features) {
|
||||||
|
if s.serverRequiresForwardSecrecy() {
|
||||||
|
return transportProtectionProfile{}, errPeerAttachForwardSecrecyRequired
|
||||||
|
}
|
||||||
|
return profile.withForwardSecrecyFallback(true), nil
|
||||||
|
}
|
||||||
|
fsState, err := newPeerAttachForwardSecrecyClientState()
|
||||||
|
if err != nil {
|
||||||
|
return transportProtectionProfile{}, err
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
resp.Features |= peerAttachFeatureForwardSecrecy
|
||||||
|
resp.KeyMode = peerAttachKeyModeECDHE
|
||||||
|
resp.ServerECDHEPublicKey = bytes.Clone(fsState.publicKey)
|
||||||
|
}
|
||||||
|
return derivePeerAttachForwardSecrecyTransportProfile(s.securitySteady, logical.secretKeySnapshot(), fsState.privateKey, auth.clientECDHEPublicKey, req, *resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) acceptPeerAttachReplay(peerID string, nonce []byte, now time.Time, window time.Duration, capacity int) error {
|
||||||
|
if s == nil || len(nonce) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cache := &s.peerAttachReplay
|
||||||
|
key := peerID + "\x00" + string(nonce)
|
||||||
|
expireBefore := now.Add(-window)
|
||||||
|
cache.mu.Lock()
|
||||||
|
defer cache.mu.Unlock()
|
||||||
|
if cache.entries == nil {
|
||||||
|
cache.entries = make(map[string]time.Time)
|
||||||
|
}
|
||||||
|
for replayKey, seenAt := range cache.entries {
|
||||||
|
if seenAt.Before(expireBefore) {
|
||||||
|
delete(cache.entries, replayKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if seenAt, ok := cache.entries[key]; ok && !seenAt.Before(expireBefore) {
|
||||||
|
return errPeerAttachReplayRejected
|
||||||
|
}
|
||||||
|
if capacity > 0 && len(cache.entries) >= capacity {
|
||||||
|
return errPeerAttachReplayWindowFull
|
||||||
|
}
|
||||||
|
cache.entries[key] = now
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) peerAttachReplayRejectCountSnapshot() int64 {
|
||||||
|
if s == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return s.peerAttachReplay.rejects.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) peerAttachReplayOverflowRejectCountSnapshot() int64 {
|
||||||
|
if s == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return s.peerAttachReplay.overflowRejects.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) markClientPeerAttachAuthenticated(fallback bool, at time.Time) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.peerAttachAuthenticated = true
|
||||||
|
c.peerAttachAuthFallback = fallback
|
||||||
|
c.peerAttachAt = at.UnixNano()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) resetClientPeerAttachAuth() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.peerAttachAuthenticated = false
|
||||||
|
c.peerAttachAuthFallback = false
|
||||||
|
c.peerAttachAt = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientPeerAttachAuthSnapshot() (bool, bool, time.Time) {
|
||||||
|
if c == nil {
|
||||||
|
return false, false, time.Time{}
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if c.peerAttachAt == 0 {
|
||||||
|
return c.peerAttachAuthenticated, c.peerAttachAuthFallback, time.Time{}
|
||||||
|
}
|
||||||
|
return c.peerAttachAuthenticated, c.peerAttachAuthFallback, time.Unix(0, c.peerAttachAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringsTrimSpaceNoAlloc(value string) string {
|
||||||
|
start := 0
|
||||||
|
for start < len(value) {
|
||||||
|
switch value[start] {
|
||||||
|
case ' ', '\t', '\n', '\r':
|
||||||
|
start++
|
||||||
|
default:
|
||||||
|
goto endStart
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
endStart:
|
||||||
|
end := len(value)
|
||||||
|
for end > start {
|
||||||
|
switch value[end-1] {
|
||||||
|
case ' ', '\t', '\n', '\r':
|
||||||
|
end--
|
||||||
|
default:
|
||||||
|
return value[start:end]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return value[start:end]
|
||||||
|
}
|
||||||
237
peer_attach_auth_test.go
Normal file
237
peer_attach_auth_test.go
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newPeerAttachAuthLogicalForTest(t *testing.T, server *ServerCommon) *LogicalConn {
|
||||||
|
t.Helper()
|
||||||
|
logical := newServerLogicalConn(server, "accepted-auth", nil)
|
||||||
|
logical = server.registerAcceptedLogical(logical)
|
||||||
|
if logical == nil {
|
||||||
|
t.Fatal("registerAcceptedLogical returned nil")
|
||||||
|
}
|
||||||
|
return logical
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerAttachExplicitAuthHelpersRoundTrip(t *testing.T) {
|
||||||
|
secret := []byte("correct horse battery staple")
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logical := newPeerAttachAuthLogicalForTest(t, server)
|
||||||
|
req, reqState, err := client.buildPeerAttachRequest("peer-explicit")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildPeerAttachRequest failed: %v", err)
|
||||||
|
}
|
||||||
|
if !supportsExplicitPeerAttachAuth(req.Features) {
|
||||||
|
t.Fatalf("request features = %d, want explicit auth bit", req.Features)
|
||||||
|
}
|
||||||
|
if !supportsPeerAttachForwardSecrecy(req.Features) {
|
||||||
|
t.Fatalf("request features = %d, want forward secrecy bit", req.Features)
|
||||||
|
}
|
||||||
|
if len(req.ClientNonce) != peerAttachNonceSize {
|
||||||
|
t.Fatalf("client nonce length = %d, want %d", len(req.ClientNonce), peerAttachNonceSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth, err := server.validatePeerAttachRequestAuth(logical, nil, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("validatePeerAttachRequestAuth failed: %v", err)
|
||||||
|
}
|
||||||
|
if !auth.explicit || auth.fallback {
|
||||||
|
t.Fatalf("auth result mismatch: %+v", auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := peerAttachResponse{
|
||||||
|
PeerID: req.PeerID,
|
||||||
|
Accepted: true,
|
||||||
|
}
|
||||||
|
server.signPeerAttachResponse(logical, req, &resp, auth)
|
||||||
|
if !supportsExplicitPeerAttachAuth(resp.Features) {
|
||||||
|
t.Fatalf("response features = %d, want explicit auth bit", resp.Features)
|
||||||
|
}
|
||||||
|
if len(resp.ServerNonce) != peerAttachNonceSize {
|
||||||
|
t.Fatalf("server nonce length = %d, want %d", len(resp.ServerNonce), peerAttachNonceSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
verifyResult, err := client.verifyPeerAttachResponse(req, resp, reqState)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("verifyPeerAttachResponse failed: %v", err)
|
||||||
|
}
|
||||||
|
if verifyResult.authFallback {
|
||||||
|
t.Fatal("explicit response should not be marked as fallback")
|
||||||
|
}
|
||||||
|
if !verifyResult.steadyProfile.forwardSecrecyFallback {
|
||||||
|
t.Fatal("response without fs extension should mark forward secrecy fallback")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.AuthTag[0] ^= 0xff
|
||||||
|
if verifyResult, err = client.verifyPeerAttachResponse(req, resp, reqState); !errors.Is(err, errPeerAttachAuthInvalid) {
|
||||||
|
t.Fatalf("tampered response error = %v, want %v", err, errPeerAttachAuthInvalid)
|
||||||
|
} else if verifyResult.authFallback {
|
||||||
|
t.Fatal("tampered explicit response should not be treated as fallback")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerAttachRequestAuthRejectsReplay(t *testing.T) {
|
||||||
|
secret := []byte("correct horse battery staple")
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logical := newPeerAttachAuthLogicalForTest(t, server)
|
||||||
|
req, _, err := client.buildPeerAttachRequest("peer-replay")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildPeerAttachRequest failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := server.validatePeerAttachRequestAuth(logical, nil, req); err != nil {
|
||||||
|
t.Fatalf("first validatePeerAttachRequestAuth failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := server.validatePeerAttachRequestAuth(logical, nil, req); !errors.Is(err, errPeerAttachReplayRejected) {
|
||||||
|
t.Fatalf("second validatePeerAttachRequestAuth error = %v, want %v", err, errPeerAttachReplayRejected)
|
||||||
|
}
|
||||||
|
classifyPeerAttachRejectCounter(server, errPeerAttachReplayRejected)
|
||||||
|
if got, want := server.peerAttachReplayRejectCountSnapshot(), int64(1); got != want {
|
||||||
|
t.Fatalf("replay reject count = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerAttachAuthFallbackCompatibility(t *testing.T) {
|
||||||
|
secret := []byte("correct horse battery staple")
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logical := newPeerAttachAuthLogicalForTest(t, server)
|
||||||
|
auth, err := server.validatePeerAttachRequestAuth(logical, nil, peerAttachRequest{PeerID: "peer-fallback"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("validatePeerAttachRequestAuth fallback failed: %v", err)
|
||||||
|
}
|
||||||
|
if auth.explicit || !auth.fallback {
|
||||||
|
t.Fatalf("fallback auth result mismatch: %+v", auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, reqState, err := client.buildPeerAttachRequest("peer-fallback")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildPeerAttachRequest failed: %v", err)
|
||||||
|
}
|
||||||
|
verifyResult, err := client.verifyPeerAttachResponse(req, peerAttachResponse{
|
||||||
|
PeerID: req.PeerID,
|
||||||
|
Accepted: true,
|
||||||
|
}, reqState)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("verifyPeerAttachResponse fallback failed: %v", err)
|
||||||
|
}
|
||||||
|
if !verifyResult.authFallback {
|
||||||
|
t.Fatal("unsigned legacy response should be marked as fallback")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerAttachForwardSecrecyNegotiatesDerivedSteadyProfile(t *testing.T) {
|
||||||
|
secret := []byte("correct horse battery staple")
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logical := newPeerAttachAuthLogicalForTest(t, server)
|
||||||
|
req, reqState, err := client.buildPeerAttachRequest("peer-fs")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildPeerAttachRequest failed: %v", err)
|
||||||
|
}
|
||||||
|
auth, err := server.validatePeerAttachRequestAuth(logical, nil, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("validatePeerAttachRequestAuth failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := peerAttachResponse{
|
||||||
|
PeerID: req.PeerID,
|
||||||
|
Accepted: true,
|
||||||
|
}
|
||||||
|
serverProfile, err := server.preparePeerAttachSteadyTransportProfile(logical, req, &resp, auth)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("preparePeerAttachSteadyTransportProfile failed: %v", err)
|
||||||
|
}
|
||||||
|
server.signPeerAttachResponse(logical, req, &resp, auth)
|
||||||
|
verifyResult, err := client.verifyPeerAttachResponse(req, resp, reqState)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("verifyPeerAttachResponse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !supportsPeerAttachForwardSecrecy(resp.Features) {
|
||||||
|
t.Fatalf("response features = %d, want forward secrecy bit", resp.Features)
|
||||||
|
}
|
||||||
|
if resp.KeyMode != peerAttachKeyModeECDHE {
|
||||||
|
t.Fatalf("response key mode = %q, want %q", resp.KeyMode, peerAttachKeyModeECDHE)
|
||||||
|
}
|
||||||
|
if !verifyResult.steadyProfile.forwardSecrecy {
|
||||||
|
t.Fatal("client steady profile should enable forward secrecy")
|
||||||
|
}
|
||||||
|
if !serverProfile.forwardSecrecy {
|
||||||
|
t.Fatal("server steady profile should enable forward secrecy")
|
||||||
|
}
|
||||||
|
if len(verifyResult.steadyProfile.sessionID) == 0 {
|
||||||
|
t.Fatal("client session id should be populated")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(verifyResult.steadyProfile.secretKey, serverProfile.secretKey) {
|
||||||
|
t.Fatal("client/server derived steady keys should match")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(verifyResult.steadyProfile.sessionID, serverProfile.sessionID) {
|
||||||
|
t.Fatal("client/server session ids should match")
|
||||||
|
}
|
||||||
|
if bytes.Equal(verifyResult.steadyProfile.secretKey, client.securityBootstrap.secretKey) {
|
||||||
|
t.Fatal("derived steady key should differ from bootstrap key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerAttachForwardSecrecyStrictRejectsFallback(t *testing.T) {
|
||||||
|
secret := []byte("correct horse battery staple")
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
opts := testModernPSKOptions()
|
||||||
|
opts.RequireForwardSecrecy = true
|
||||||
|
if err := UseModernPSKClient(client, secret, opts); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := UseModernPSKServer(server, secret, opts); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, reqState, err := client.buildPeerAttachRequest("peer-fs-strict")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildPeerAttachRequest failed: %v", err)
|
||||||
|
}
|
||||||
|
_, err = client.verifyPeerAttachResponse(req, peerAttachResponse{
|
||||||
|
PeerID: req.PeerID,
|
||||||
|
Accepted: true,
|
||||||
|
Features: peerAttachFeatureExplicitAuth,
|
||||||
|
ServerNonce: make([]byte, peerAttachNonceSize),
|
||||||
|
AuthTag: computePeerAttachResponseAuthTag(client.securityBootstrap.secretKey, req, peerAttachResponse{PeerID: req.PeerID, Accepted: true, Features: peerAttachFeatureExplicitAuth, ServerNonce: make([]byte, peerAttachNonceSize)}, nil),
|
||||||
|
}, reqState)
|
||||||
|
if !errors.Is(err, errPeerAttachForwardSecrecyRequired) {
|
||||||
|
t.Fatalf("verifyPeerAttachResponse error = %v, want %v", err, errPeerAttachForwardSecrecyRequired)
|
||||||
|
}
|
||||||
|
}
|
||||||
139
peer_attach_policy.go
Normal file
139
peer_attach_policy.go
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultPeerAttachReplayCapacity = 4096
|
||||||
|
|
||||||
|
type PeerAttachChannelBindingRole string
|
||||||
|
|
||||||
|
const (
|
||||||
|
PeerAttachChannelBindingRoleClient PeerAttachChannelBindingRole = "client"
|
||||||
|
PeerAttachChannelBindingRoleServer PeerAttachChannelBindingRole = "server"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PeerAttachChannelBindingContext struct {
|
||||||
|
Role PeerAttachChannelBindingRole
|
||||||
|
PeerID string
|
||||||
|
Conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
type PeerAttachChannelBindingProvider func(PeerAttachChannelBindingContext) ([]byte, error)
|
||||||
|
|
||||||
|
type PeerAttachSecurityConfig struct {
|
||||||
|
RequireExplicitAuth bool
|
||||||
|
RequireChannelBinding bool
|
||||||
|
ReplayWindow time.Duration
|
||||||
|
ReplayCapacity int
|
||||||
|
ChannelBinding PeerAttachChannelBindingProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
type peerAttachSecurityState struct {
|
||||||
|
requireExplicitAuth bool
|
||||||
|
requireChannelBinding bool
|
||||||
|
replayWindow time.Duration
|
||||||
|
replayCapacity int
|
||||||
|
channelBinding PeerAttachChannelBindingProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
var errPeerAttachChannelBindingProviderNil = errors.New("peer attach channel binding provider is nil")
|
||||||
|
|
||||||
|
func DefaultPeerAttachSecurityConfig() PeerAttachSecurityConfig {
|
||||||
|
return PeerAttachSecurityConfig{
|
||||||
|
ReplayWindow: peerAttachReplayTTL,
|
||||||
|
ReplayCapacity: defaultPeerAttachReplayCapacity,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizePeerAttachSecurityConfig(cfg PeerAttachSecurityConfig) (peerAttachSecurityState, error) {
|
||||||
|
if cfg.ReplayWindow <= 0 {
|
||||||
|
cfg.ReplayWindow = peerAttachReplayTTL
|
||||||
|
}
|
||||||
|
if cfg.ReplayCapacity <= 0 {
|
||||||
|
cfg.ReplayCapacity = defaultPeerAttachReplayCapacity
|
||||||
|
}
|
||||||
|
if cfg.RequireChannelBinding {
|
||||||
|
cfg.RequireExplicitAuth = true
|
||||||
|
if cfg.ChannelBinding == nil {
|
||||||
|
return peerAttachSecurityState{}, errPeerAttachChannelBindingProviderNil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return peerAttachSecurityState{
|
||||||
|
requireExplicitAuth: cfg.RequireExplicitAuth,
|
||||||
|
requireChannelBinding: cfg.RequireChannelBinding,
|
||||||
|
replayWindow: cfg.ReplayWindow,
|
||||||
|
replayCapacity: cfg.ReplayCapacity,
|
||||||
|
channelBinding: cfg.ChannelBinding,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func peerAttachSecurityConfigFromState(state *peerAttachSecurityState) PeerAttachSecurityConfig {
|
||||||
|
if state == nil {
|
||||||
|
return DefaultPeerAttachSecurityConfig()
|
||||||
|
}
|
||||||
|
return PeerAttachSecurityConfig{
|
||||||
|
RequireExplicitAuth: state.requireExplicitAuth,
|
||||||
|
RequireChannelBinding: state.requireChannelBinding,
|
||||||
|
ReplayWindow: state.replayWindow,
|
||||||
|
ReplayCapacity: state.replayCapacity,
|
||||||
|
ChannelBinding: state.channelBinding,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultPeerAttachSecurityState() *peerAttachSecurityState {
|
||||||
|
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
|
||||||
|
return &cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) peerAttachSecuritySnapshot() peerAttachSecurityState {
|
||||||
|
if c == nil {
|
||||||
|
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
if state := c.peerAttachSecurity.Load(); state != nil {
|
||||||
|
return *state
|
||||||
|
}
|
||||||
|
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) peerAttachSecuritySnapshot() peerAttachSecurityState {
|
||||||
|
if s == nil {
|
||||||
|
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
if state := s.peerAttachSecurity.Load(); state != nil {
|
||||||
|
return *state
|
||||||
|
}
|
||||||
|
cfg, _ := normalizePeerAttachSecurityConfig(DefaultPeerAttachSecurityConfig())
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) SetPeerAttachSecurityConfig(cfg PeerAttachSecurityConfig) error {
|
||||||
|
state, err := normalizePeerAttachSecurityConfig(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.peerAttachSecurity.Store(&state)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) PeerAttachSecurityConfig() PeerAttachSecurityConfig {
|
||||||
|
return peerAttachSecurityConfigFromState(c.peerAttachSecurity.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) SetPeerAttachSecurityConfig(cfg PeerAttachSecurityConfig) error {
|
||||||
|
state, err := normalizePeerAttachSecurityConfig(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.peerAttachSecurity.Store(&state)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) PeerAttachSecurityConfig() PeerAttachSecurityConfig {
|
||||||
|
return peerAttachSecurityConfigFromState(s.peerAttachSecurity.Load())
|
||||||
|
}
|
||||||
221
peer_attach_policy_test.go
Normal file
221
peer_attach_policy_test.go
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func staticPeerAttachChannelBindingProvider(material []byte) PeerAttachChannelBindingProvider {
|
||||||
|
cloned := bytes.Clone(material)
|
||||||
|
return func(PeerAttachChannelBindingContext) ([]byte, error) {
|
||||||
|
return bytes.Clone(cloned), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func failingPeerAttachChannelBindingProvider(PeerAttachChannelBindingContext) ([]byte, error) {
|
||||||
|
return nil, errors.New("binding unavailable")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetPeerAttachSecurityConfigRejectsMissingChannelBindingProvider(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
cfg := PeerAttachSecurityConfig{RequireChannelBinding: true}
|
||||||
|
|
||||||
|
if err := client.SetPeerAttachSecurityConfig(cfg); !errors.Is(err, errPeerAttachChannelBindingProviderNil) {
|
||||||
|
t.Fatalf("client SetPeerAttachSecurityConfig error = %v, want %v", err, errPeerAttachChannelBindingProviderNil)
|
||||||
|
}
|
||||||
|
if err := server.SetPeerAttachSecurityConfig(cfg); !errors.Is(err, errPeerAttachChannelBindingProviderNil) {
|
||||||
|
t.Fatalf("server SetPeerAttachSecurityConfig error = %v, want %v", err, errPeerAttachChannelBindingProviderNil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerAttachRequireExplicitAuthRejectsFallbackClient(t *testing.T) {
|
||||||
|
secret := []byte("correct horse battery staple")
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
|
||||||
|
RequireExplicitAuth: true,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("SetPeerAttachSecurityConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logical := newPeerAttachAuthLogicalForTest(t, server)
|
||||||
|
if _, err := server.validatePeerAttachRequestAuth(logical, nil, peerAttachRequest{PeerID: "peer-fallback"}); !errors.Is(err, errPeerAttachExplicitAuthRequired) {
|
||||||
|
t.Fatalf("validatePeerAttachRequestAuth error = %v, want %v", err, errPeerAttachExplicitAuthRequired)
|
||||||
|
}
|
||||||
|
classifyPeerAttachRejectCounter(server, errPeerAttachExplicitAuthRequired)
|
||||||
|
|
||||||
|
snapshot, snapErr := GetServerRuntimeSnapshot(server)
|
||||||
|
if snapErr != nil {
|
||||||
|
t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.PeerAttachDowngradeRejects, int64(1); got != want {
|
||||||
|
t.Fatalf("PeerAttachDowngradeRejects = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got := snapshot.PeerAttachAuthFallbacks; got != 0 {
|
||||||
|
t.Fatalf("PeerAttachAuthFallbacks = %d, want 0", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerAttachChannelBindingRoundTrip(t *testing.T) {
|
||||||
|
secret := []byte("correct horse battery staple")
|
||||||
|
bindingProvider := staticPeerAttachChannelBindingProvider([]byte("tls-exporter:test"))
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
if err := UsePSKOverExternalTransportServer(server, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
|
||||||
|
RequireChannelBinding: true,
|
||||||
|
ChannelBinding: bindingProvider,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
server.SetLink("echo", func(msg *Message) {
|
||||||
|
_ = msg.Reply([]byte("ack"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UsePSKOverExternalTransportClient(client, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := client.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
|
||||||
|
RequireChannelBinding: true,
|
||||||
|
ChannelBinding: bindingProvider,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("client SetPeerAttachSecurityConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer right.Close()
|
||||||
|
bootstrapPeerAttachLogicalForTest(t, server, right)
|
||||||
|
if err := client.ConnectByConn(left); err != nil {
|
||||||
|
t.Fatalf("ConnectByConn failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
_ = client.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
reply, err := client.SendWait("echo", []byte("ping"), time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SendWait failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := string(reply.Value), "ack"; got != want {
|
||||||
|
t.Fatalf("reply = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
serverSnapshot, err := GetServerRuntimeSnapshot(server)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetServerRuntimeSnapshot failed: %v", err)
|
||||||
|
}
|
||||||
|
if !serverSnapshot.PeerAttachRequireExplicitAuth || !serverSnapshot.PeerAttachRequireChannelBinding || !serverSnapshot.PeerAttachChannelBindingConfigured {
|
||||||
|
t.Fatalf("unexpected server peer attach policy snapshot: %+v", serverSnapshot)
|
||||||
|
}
|
||||||
|
if got, want := serverSnapshot.PeerAttachExplicitAuth, int64(1); got != want {
|
||||||
|
t.Fatalf("PeerAttachExplicitAuth = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if serverSnapshot.PeerAttachAuthRejects != 0 || serverSnapshot.PeerAttachDowngradeRejects != 0 || serverSnapshot.PeerAttachBindingRejects != 0 {
|
||||||
|
t.Fatalf("unexpected server reject counters: %+v", serverSnapshot)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientSnapshot, err := GetClientRuntimeSnapshot(client)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
|
||||||
|
}
|
||||||
|
if !clientSnapshot.PeerAttachRequireExplicitAuth || !clientSnapshot.PeerAttachRequireChannelBinding || !clientSnapshot.PeerAttachChannelBindingConfigured {
|
||||||
|
t.Fatalf("unexpected client peer attach policy snapshot: %+v", clientSnapshot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerAttachChannelBindingProviderFailureRejectsAttach(t *testing.T) {
|
||||||
|
secret := []byte("correct horse battery staple")
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
|
||||||
|
RequireChannelBinding: true,
|
||||||
|
ChannelBinding: failingPeerAttachChannelBindingProvider,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := client.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
|
||||||
|
RequireChannelBinding: true,
|
||||||
|
ChannelBinding: staticPeerAttachChannelBindingProvider([]byte("binding")),
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("client SetPeerAttachSecurityConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer right.Close()
|
||||||
|
bootstrapPeerAttachLogicalForTest(t, server, right)
|
||||||
|
|
||||||
|
err := client.ConnectByConn(left)
|
||||||
|
if !errors.Is(err, errPeerAttachChannelBindingUnavailable) && (err == nil || err.Error() != errPeerAttachChannelBindingUnavailable.Error()) {
|
||||||
|
t.Fatalf("ConnectByConn error = %v, want %v", err, errPeerAttachChannelBindingUnavailable)
|
||||||
|
}
|
||||||
|
|
||||||
|
snapshot, snapErr := GetServerRuntimeSnapshot(server)
|
||||||
|
if snapErr != nil {
|
||||||
|
t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.PeerAttachBindingRejects, int64(1); got != want {
|
||||||
|
t.Fatalf("PeerAttachBindingRejects = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerAttachReplayCapacityRejectsOverflow(t *testing.T) {
|
||||||
|
secret := []byte("correct horse battery staple")
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
|
||||||
|
ReplayCapacity: 1,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logical := newPeerAttachAuthLogicalForTest(t, server)
|
||||||
|
first, _, err := client.buildPeerAttachRequest("peer-one")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildPeerAttachRequest(first) failed: %v", err)
|
||||||
|
}
|
||||||
|
second, _, err := client.buildPeerAttachRequest("peer-two")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildPeerAttachRequest(second) failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := server.validatePeerAttachRequestAuth(logical, nil, first); err != nil {
|
||||||
|
t.Fatalf("validatePeerAttachRequestAuth(first) failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := server.validatePeerAttachRequestAuth(logical, nil, second); !errors.Is(err, errPeerAttachReplayWindowFull) {
|
||||||
|
t.Fatalf("validatePeerAttachRequestAuth(second) error = %v, want %v", err, errPeerAttachReplayWindowFull)
|
||||||
|
}
|
||||||
|
classifyPeerAttachRejectCounter(server, errPeerAttachReplayWindowFull)
|
||||||
|
|
||||||
|
snapshot, snapErr := GetServerRuntimeSnapshot(server)
|
||||||
|
if snapErr != nil {
|
||||||
|
t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.PeerAttachReplayCapacity, 1; got != want {
|
||||||
|
t.Fatalf("PeerAttachReplayCapacity = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.PeerAttachReplayOverflowRejects, int64(1); got != want {
|
||||||
|
t.Fatalf("PeerAttachReplayOverflowRejects = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -16,6 +16,10 @@ const (
|
|||||||
|
|
||||||
type peerAttachRequest struct {
|
type peerAttachRequest struct {
|
||||||
PeerID string
|
PeerID string
|
||||||
|
Features uint64
|
||||||
|
ClientNonce []byte
|
||||||
|
ClientECDHEPublicKey []byte
|
||||||
|
AuthTag []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type peerAttachResponse struct {
|
type peerAttachResponse struct {
|
||||||
@ -23,6 +27,11 @@ type peerAttachResponse struct {
|
|||||||
Accepted bool
|
Accepted bool
|
||||||
Reused bool
|
Reused bool
|
||||||
Error string
|
Error string
|
||||||
|
Features uint64
|
||||||
|
KeyMode string
|
||||||
|
ServerNonce []byte
|
||||||
|
ServerECDHEPublicKey []byte
|
||||||
|
AuthTag []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClientPeerIdentity() string {
|
func newClientPeerIdentity() string {
|
||||||
@ -108,7 +117,11 @@ func (c *ClientCommon) announceClientPeerIdentity() error {
|
|||||||
if peerID == "" {
|
if peerID == "" {
|
||||||
return errors.New("peer identity is empty")
|
return errors.New("peer identity is empty")
|
||||||
}
|
}
|
||||||
encoded, err := c.sequenceEn(peerAttachRequest{PeerID: peerID})
|
req, requestState, err := c.buildPeerAttachRequest(peerID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
encoded, err := c.sequenceEn(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -133,6 +146,12 @@ func (c *ClientCommon) announceClientPeerIdentity() error {
|
|||||||
}
|
}
|
||||||
return errors.New("peer attach rejected")
|
return errors.New("peer attach rejected")
|
||||||
}
|
}
|
||||||
|
verifyResult, err := c.verifyPeerAttachResponse(req, resp, requestState)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.setClientNegotiatedSteadyTransportProtection(verifyResult.steadyProfile)
|
||||||
|
c.markClientPeerAttachAuthenticated(verifyResult.authFallback, time.Now())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -188,7 +207,7 @@ func (s *ServerCommon) replyPeerAttach(client *LogicalConn, message Message, res
|
|||||||
Type: MSG_SYS_REPLY,
|
Type: MSG_SYS_REPLY,
|
||||||
}
|
}
|
||||||
if message.inboundConn != nil {
|
if message.inboundConn != nil {
|
||||||
return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply)
|
return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, messageInboundTransportProtectionSnapshot(&message), reply)
|
||||||
}
|
}
|
||||||
_, err = s.sendLogical(client, reply)
|
_, err = s.sendLogical(client, reply)
|
||||||
return err
|
return err
|
||||||
@ -200,6 +219,10 @@ func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool {
|
|||||||
}
|
}
|
||||||
message = hydrateServerMessagePeerFields(message)
|
message = hydrateServerMessagePeerFields(message)
|
||||||
current := messageLogicalConnSnapshot(&message)
|
current := messageLogicalConnSnapshot(&message)
|
||||||
|
transport := message.inboundConn
|
||||||
|
if transport == nil && current != nil {
|
||||||
|
transport = current.transportSnapshot()
|
||||||
|
}
|
||||||
req, err := decodePeerAttachRequest(s.sequenceDe, message.Value)
|
req, err := decodePeerAttachRequest(s.sequenceDe, message.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if current != nil {
|
if current != nil {
|
||||||
@ -210,6 +233,18 @@ func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool {
|
|||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
auth, err := s.validatePeerAttachRequestAuth(current, transport, req)
|
||||||
|
if err != nil {
|
||||||
|
classifyPeerAttachRejectCounter(s, err)
|
||||||
|
if current != nil {
|
||||||
|
_ = s.replyPeerAttach(current, message, peerAttachResponse{
|
||||||
|
PeerID: req.PeerID,
|
||||||
|
Accepted: false,
|
||||||
|
Error: err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
bound, reused, err := s.bindAcceptedClientIdentity(current, req.PeerID)
|
bound, reused, err := s.bindAcceptedClientIdentity(current, req.PeerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if current != nil {
|
if current != nil {
|
||||||
@ -221,12 +256,37 @@ func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool {
|
|||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if err := s.replyPeerAttach(bound, message, peerAttachResponse{
|
resp := peerAttachResponse{
|
||||||
PeerID: bound.ID(),
|
PeerID: bound.ID(),
|
||||||
Accepted: true,
|
Accepted: true,
|
||||||
Reused: reused,
|
Reused: reused,
|
||||||
}); err != nil && bound != nil {
|
}
|
||||||
|
steadyProfile, err := s.preparePeerAttachSteadyTransportProfile(bound, req, &resp, auth)
|
||||||
|
if err != nil {
|
||||||
|
if bound != nil {
|
||||||
|
_ = s.replyPeerAttach(bound, message, peerAttachResponse{
|
||||||
|
PeerID: req.PeerID,
|
||||||
|
Accepted: false,
|
||||||
|
Error: err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
s.signPeerAttachResponse(bound, req, &resp, auth)
|
||||||
|
if bound != nil {
|
||||||
|
bound.markPeerAttachAuthenticated(s.securityAuthMode, auth.fallback, time.Now())
|
||||||
|
if auth.explicit {
|
||||||
|
s.peerAttachExplicitCount.Add(1)
|
||||||
|
} else if auth.fallback {
|
||||||
|
s.peerAttachAuthFallbackCount.Add(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := s.replyPeerAttach(bound, message, resp); err != nil && bound != nil {
|
||||||
s.stopLogicalSession(bound, "peer attach reply failed", err)
|
s.stopLogicalSession(bound, "peer attach reply failed", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if bound != nil && s.securityConfigured {
|
||||||
|
bound.applyTransportProtectionProfile(steadyProfile)
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@ -119,6 +119,7 @@ func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) {
|
|||||||
defer serverConn.Close()
|
defer serverConn.Close()
|
||||||
|
|
||||||
logical := bootstrapPeerAttachLogicalForTest(t, server, serverConn)
|
logical := bootstrapPeerAttachLogicalForTest(t, server, serverConn)
|
||||||
|
originalProfile := logical.transportProtectionProfileSnapshot()
|
||||||
message := Message{
|
message := Message{
|
||||||
NetType: NET_SERVER,
|
NetType: NET_SERVER,
|
||||||
LogicalConn: logical,
|
LogicalConn: logical,
|
||||||
@ -131,6 +132,13 @@ func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) {
|
|||||||
Time: time.Now(),
|
Time: time.Now(),
|
||||||
inboundConn: serverConn,
|
inboundConn: serverConn,
|
||||||
}
|
}
|
||||||
|
message = hydrateServerMessagePeerFields(message)
|
||||||
|
|
||||||
|
alternate, err := deriveModernPSKProtectionProfile([]byte("notify-peer-attach-reply-alternate"), testModernPSKOptions(), ProtectionManaged)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("deriveModernPSKProtectionProfile failed: %v", err)
|
||||||
|
}
|
||||||
|
logical.applyTransportProtectionProfile(alternate)
|
||||||
|
|
||||||
done := make(chan error, 1)
|
done := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
@ -140,7 +148,7 @@ func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}()
|
}()
|
||||||
|
|
||||||
env := readServerEnvelopeFromConn(t, server, logical, clientConn, time.Second)
|
env := readServerEnvelopeFromConnWithProfile(t, server, originalProfile, clientConn, time.Second)
|
||||||
if env.Kind != EnvelopeSignal {
|
if env.Kind != EnvelopeSignal {
|
||||||
t.Fatalf("reply envelope kind = %v, want %v", env.Kind, EnvelopeSignal)
|
t.Fatalf("reply envelope kind = %v, want %v", env.Kind, EnvelopeSignal)
|
||||||
}
|
}
|
||||||
|
|||||||
139
security_forward_secrecy.go
Normal file
139
security_forward_secrecy.go
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/ecdh"
|
||||||
|
"crypto/hmac"
|
||||||
|
cryptorand "crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
peerAttachECDHEPublicKeySize = 32
|
||||||
|
peerAttachSessionIDSize = 16
|
||||||
|
|
||||||
|
peerAttachKeyModeStatic = "psk-static"
|
||||||
|
peerAttachKeyModeECDHE = "psk-ecdhe"
|
||||||
|
transportKeyModeExternal = "external"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errPeerAttachForwardSecrecyInvalid = errors.New("peer attach forward secrecy is invalid")
|
||||||
|
|
||||||
|
type peerAttachRequestState struct {
|
||||||
|
forwardSecrecy *peerAttachForwardSecrecyClientState
|
||||||
|
}
|
||||||
|
|
||||||
|
type peerAttachForwardSecrecyClientState struct {
|
||||||
|
privateKey *ecdh.PrivateKey
|
||||||
|
publicKey []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type peerAttachResponseVerifyResult struct {
|
||||||
|
authFallback bool
|
||||||
|
steadyProfile transportProtectionProfile
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPeerAttachForwardSecrecyClientState() (*peerAttachForwardSecrecyClientState, error) {
|
||||||
|
curve := ecdh.X25519()
|
||||||
|
privateKey, err := curve.GenerateKey(cryptorand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
publicKey := privateKey.PublicKey().Bytes()
|
||||||
|
if len(publicKey) != peerAttachECDHEPublicKeySize {
|
||||||
|
return nil, errPeerAttachForwardSecrecyInvalid
|
||||||
|
}
|
||||||
|
return &peerAttachForwardSecrecyClientState{
|
||||||
|
privateKey: privateKey,
|
||||||
|
publicKey: bytes.Clone(publicKey),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func derivePeerAttachForwardSecrecyTransportProfile(base transportProtectionProfile, bootstrapKey []byte, localPrivateKey *ecdh.PrivateKey, peerPublicKey []byte, req peerAttachRequest, resp peerAttachResponse) (transportProtectionProfile, error) {
|
||||||
|
if len(bootstrapKey) == 0 || localPrivateKey == nil || len(peerPublicKey) != peerAttachECDHEPublicKeySize {
|
||||||
|
return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid
|
||||||
|
}
|
||||||
|
curve := ecdh.X25519()
|
||||||
|
publicKey, err := curve.NewPublicKey(peerPublicKey)
|
||||||
|
if err != nil {
|
||||||
|
return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid
|
||||||
|
}
|
||||||
|
sharedSecret, err := localPrivateKey.ECDH(publicKey)
|
||||||
|
if err != nil {
|
||||||
|
return transportProtectionProfile{}, errPeerAttachForwardSecrecyInvalid
|
||||||
|
}
|
||||||
|
transcriptHash := peerAttachForwardSecrecyTranscriptHash(req, resp)
|
||||||
|
ikm := make([]byte, 0, len(sharedSecret)+len(transcriptHash))
|
||||||
|
ikm = append(ikm, sharedSecret...)
|
||||||
|
ikm = append(ikm, transcriptHash...)
|
||||||
|
prk := hkdfExtractSHA256(bootstrapKey, ikm)
|
||||||
|
sessionKey := hkdfExpandSHA256(prk, []byte("notify/transport/session/v1"), 32)
|
||||||
|
sessionID := hkdfExpandSHA256(prk, []byte("notify/session-id/v1"), peerAttachSessionIDSize)
|
||||||
|
return deriveModernPSKSessionProtectionProfile(base, sessionKey, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func peerAttachForwardSecrecyTranscriptHash(req peerAttachRequest, resp peerAttachResponse) []byte {
|
||||||
|
buf := make([]byte, 0, 256+len(req.PeerID)+len(resp.PeerID)+len(resp.Error)+len(req.ClientECDHEPublicKey)+len(resp.ServerECDHEPublicKey))
|
||||||
|
buf = appendPeerAttachTranscriptString(buf, "notify/peer-attach/forward-secrecy/v1")
|
||||||
|
buf = binary.BigEndian.AppendUint64(buf, req.Features)
|
||||||
|
buf = appendPeerAttachTranscriptString(buf, req.PeerID)
|
||||||
|
buf = appendPeerAttachTranscriptBytes(buf, req.ClientNonce)
|
||||||
|
buf = appendPeerAttachTranscriptBytes(buf, req.ClientECDHEPublicKey)
|
||||||
|
buf = binary.BigEndian.AppendUint64(buf, resp.Features)
|
||||||
|
buf = appendPeerAttachTranscriptString(buf, resp.PeerID)
|
||||||
|
buf = appendPeerAttachTranscriptBool(buf, resp.Accepted)
|
||||||
|
buf = appendPeerAttachTranscriptBool(buf, resp.Reused)
|
||||||
|
buf = appendPeerAttachTranscriptString(buf, resp.Error)
|
||||||
|
buf = appendPeerAttachTranscriptBytes(buf, resp.ServerNonce)
|
||||||
|
buf = appendPeerAttachTranscriptString(buf, resp.KeyMode)
|
||||||
|
buf = appendPeerAttachTranscriptBytes(buf, resp.ServerECDHEPublicKey)
|
||||||
|
sum := sha256.Sum256(buf)
|
||||||
|
return sum[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendPeerAttachTranscriptBytes(dst []byte, data []byte) []byte {
|
||||||
|
dst = binary.BigEndian.AppendUint32(dst, uint32(len(data)))
|
||||||
|
return append(dst, data...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendPeerAttachTranscriptString(dst []byte, value string) []byte {
|
||||||
|
return appendPeerAttachTranscriptBytes(dst, []byte(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendPeerAttachTranscriptBool(dst []byte, value bool) []byte {
|
||||||
|
if value {
|
||||||
|
return append(dst, 1)
|
||||||
|
}
|
||||||
|
return append(dst, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func hkdfExtractSHA256(salt []byte, ikm []byte) []byte {
|
||||||
|
mac := hmac.New(sha256.New, salt)
|
||||||
|
_, _ = mac.Write(ikm)
|
||||||
|
return mac.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func hkdfExpandSHA256(prk []byte, info []byte, size int) []byte {
|
||||||
|
if size <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]byte, 0, size)
|
||||||
|
var block []byte
|
||||||
|
for counter := byte(1); len(out) < size; counter++ {
|
||||||
|
mac := hmac.New(sha256.New, prk)
|
||||||
|
if len(block) != 0 {
|
||||||
|
_, _ = mac.Write(block)
|
||||||
|
}
|
||||||
|
_, _ = mac.Write(info)
|
||||||
|
_, _ = mac.Write([]byte{counter})
|
||||||
|
block = mac.Sum(nil)
|
||||||
|
remaining := size - len(out)
|
||||||
|
if remaining > len(block) {
|
||||||
|
remaining = len(block)
|
||||||
|
}
|
||||||
|
out = append(out, block[:remaining]...)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
408
security_profile.go
Normal file
408
security_profile.go
Normal file
@ -0,0 +1,408 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import "bytes"
|
||||||
|
|
||||||
|
// AuthMode describes how notify authenticates the peer during bootstrap.
|
||||||
|
type AuthMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
AuthNone AuthMode = iota
|
||||||
|
AuthPSK
|
||||||
|
AuthExternalPeer
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProtectionMode describes how notify protects steady-state transport payloads.
|
||||||
|
type ProtectionMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProtectionManaged ProtectionMode = iota
|
||||||
|
ProtectionExternal
|
||||||
|
ProtectionNested
|
||||||
|
)
|
||||||
|
|
||||||
|
// SecurityOptions describes the high-level auth/protection policy.
|
||||||
|
//
|
||||||
|
// The current implementation still exposes dedicated helper constructors such
|
||||||
|
// as UseModernPSKClient/Server and UsePSKOverExternalTransportClient/Server.
|
||||||
|
type SecurityOptions struct {
|
||||||
|
AuthMode AuthMode
|
||||||
|
ProtectionMode ProtectionMode
|
||||||
|
SharedSecret []byte
|
||||||
|
RequireForwardSecrecy bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func authModeName(mode AuthMode) string {
|
||||||
|
switch mode {
|
||||||
|
case AuthNone:
|
||||||
|
return "none"
|
||||||
|
case AuthPSK:
|
||||||
|
return "psk"
|
||||||
|
case AuthExternalPeer:
|
||||||
|
return "external-peer"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func protectionModeName(mode ProtectionMode) string {
|
||||||
|
switch mode {
|
||||||
|
case ProtectionManaged:
|
||||||
|
return "managed"
|
||||||
|
case ProtectionExternal:
|
||||||
|
return "external"
|
||||||
|
case ProtectionNested:
|
||||||
|
return "nested"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type transportProtectionProfile struct {
|
||||||
|
mode ProtectionMode
|
||||||
|
msgEn func([]byte, []byte) []byte
|
||||||
|
msgDe func([]byte, []byte) []byte
|
||||||
|
fastStreamEncode transportFastStreamEncoder
|
||||||
|
fastBulkEncode transportFastBulkEncoder
|
||||||
|
fastPlainEncode transportFastPlainEncoder
|
||||||
|
runtime *modernPSKCodecRuntime
|
||||||
|
secretKey []byte
|
||||||
|
keyMode string
|
||||||
|
sessionID []byte
|
||||||
|
forwardSecrecy bool
|
||||||
|
forwardSecrecyFallback bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneTransportProtectionKey(src []byte) []byte {
|
||||||
|
if len(src) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return bytes.Clone(src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTransportProtectionProfile(mode ProtectionMode, bundle modernPSKTransportBundle, runtime *modernPSKCodecRuntime, secretKey []byte) transportProtectionProfile {
|
||||||
|
return transportProtectionProfile{
|
||||||
|
mode: mode,
|
||||||
|
msgEn: bundle.msgEn,
|
||||||
|
msgDe: bundle.msgDe,
|
||||||
|
fastStreamEncode: bundle.fastStreamEncode,
|
||||||
|
fastBulkEncode: bundle.fastBulkEncode,
|
||||||
|
fastPlainEncode: bundle.fastPlainEncode,
|
||||||
|
runtime: runtime,
|
||||||
|
secretKey: cloneTransportProtectionKey(secretKey),
|
||||||
|
keyMode: defaultTransportKeyMode(mode, secretKey),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultTransportKeyMode(mode ProtectionMode, secretKey []byte) string {
|
||||||
|
if len(secretKey) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
switch mode {
|
||||||
|
case ProtectionManaged, ProtectionNested:
|
||||||
|
return peerAttachKeyModeStatic
|
||||||
|
case ProtectionExternal:
|
||||||
|
return transportKeyModeExternal
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneTransportSessionID(src []byte) []byte {
|
||||||
|
if len(src) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return bytes.Clone(src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p transportProtectionProfile) clone() transportProtectionProfile {
|
||||||
|
p.secretKey = cloneTransportProtectionKey(p.secretKey)
|
||||||
|
p.sessionID = cloneTransportSessionID(p.sessionID)
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p transportProtectionProfile) withForwardSecrecyFallback(fallback bool) transportProtectionProfile {
|
||||||
|
p = p.clone()
|
||||||
|
p.forwardSecrecy = false
|
||||||
|
p.forwardSecrecyFallback = fallback
|
||||||
|
p.sessionID = nil
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func passthroughTransportCodec(_ []byte, payload []byte) []byte {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
func passthroughFastPlainEncode(_ []byte, plainLen int, fill func([]byte) error) ([]byte, error) {
|
||||||
|
if plainLen < 0 {
|
||||||
|
return nil, errTransportPayloadEncryptFailed
|
||||||
|
}
|
||||||
|
buf := make([]byte, plainLen)
|
||||||
|
if fill != nil {
|
||||||
|
if err := fill(buf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildExternalTransportBundle() modernPSKTransportBundle {
|
||||||
|
fastPlainEncode := passthroughFastPlainEncode
|
||||||
|
fastStreamEncode := func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
||||||
|
return encodeStreamFastFramePayloadFast(fastPlainEncode, secretKey, streamFastDataFrame{
|
||||||
|
DataID: dataID,
|
||||||
|
Seq: seq,
|
||||||
|
Payload: payload,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
fastBulkEncode := func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
||||||
|
return encodeBulkFastFramePayloadFast(fastPlainEncode, secretKey, bulkFastFrame{
|
||||||
|
Type: bulkFastPayloadTypeData,
|
||||||
|
DataID: dataID,
|
||||||
|
Seq: seq,
|
||||||
|
Payload: payload,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return modernPSKTransportBundle{
|
||||||
|
msgEn: passthroughTransportCodec,
|
||||||
|
msgDe: passthroughTransportCodec,
|
||||||
|
fastStreamEncode: fastStreamEncode,
|
||||||
|
fastBulkEncode: fastBulkEncode,
|
||||||
|
fastPlainEncode: fastPlainEncode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultTransportProtectionProfile() transportProtectionProfile {
|
||||||
|
return newTransportProtectionProfile(ProtectionManaged, defaultModernPSKTransportBundle(), nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientTransportProtectionSnapshot() transportProtectionProfile {
|
||||||
|
if c == nil {
|
||||||
|
return transportProtectionProfile{}
|
||||||
|
}
|
||||||
|
if state := c.transportProtection.Load(); state != nil {
|
||||||
|
return *state
|
||||||
|
}
|
||||||
|
return transportProtectionProfile{
|
||||||
|
mode: ProtectionManaged,
|
||||||
|
msgEn: c.msgEn,
|
||||||
|
msgDe: c.msgDe,
|
||||||
|
fastStreamEncode: c.fastStreamEncode,
|
||||||
|
fastBulkEncode: c.fastBulkEncode,
|
||||||
|
fastPlainEncode: c.fastPlainEncode,
|
||||||
|
runtime: c.modernPSKRuntime,
|
||||||
|
secretKey: c.SecretKey,
|
||||||
|
keyMode: defaultTransportKeyMode(ProtectionManaged, c.SecretKey),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) setClientTransportProtectionProfile(profile transportProtectionProfile) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
profile.secretKey = cloneTransportProtectionKey(profile.secretKey)
|
||||||
|
profile.sessionID = cloneTransportSessionID(profile.sessionID)
|
||||||
|
c.msgEn = profile.msgEn
|
||||||
|
c.msgDe = profile.msgDe
|
||||||
|
c.fastStreamEncode = profile.fastStreamEncode
|
||||||
|
c.fastBulkEncode = profile.fastBulkEncode
|
||||||
|
c.fastPlainEncode = profile.fastPlainEncode
|
||||||
|
c.modernPSKRuntime = profile.runtime
|
||||||
|
c.SecretKey = profile.secretKey
|
||||||
|
c.transportProtection.Store(&profile)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clearClientSecurityProfiles() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.securityConfigured = false
|
||||||
|
c.securityAuthMode = AuthNone
|
||||||
|
c.securityProtectionMode = ProtectionManaged
|
||||||
|
c.securityBootstrap = transportProtectionProfile{}
|
||||||
|
c.securitySteady = transportProtectionProfile{}
|
||||||
|
c.securitySteadyNegotiated = transportProtectionProfile{}
|
||||||
|
c.securityRequireForwardSecrecy = false
|
||||||
|
c.peerAttachAuthenticated = false
|
||||||
|
c.peerAttachAuthFallback = false
|
||||||
|
c.peerAttachAt = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) configureClientSecurityProfiles(authMode AuthMode, protectionMode ProtectionMode, bootstrap transportProtectionProfile, steady transportProtectionProfile, requireForwardSecrecy bool) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.securityConfigured = true
|
||||||
|
c.securityAuthMode = authMode
|
||||||
|
c.securityProtectionMode = protectionMode
|
||||||
|
c.securityBootstrap = bootstrap.clone()
|
||||||
|
c.securitySteady = steady.clone()
|
||||||
|
c.securitySteadyNegotiated = steady.clone()
|
||||||
|
c.securityRequireForwardSecrecy = requireForwardSecrecy
|
||||||
|
c.setClientTransportProtectionProfile(bootstrap)
|
||||||
|
c.securityReadyCheck = len(bootstrap.secretKey) != 0
|
||||||
|
c.skipKeyExchange = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) activateClientBootstrapTransportProtection() {
|
||||||
|
if c == nil || !c.securityConfigured {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.resetClientNegotiatedSteadyTransportProtection()
|
||||||
|
c.setClientTransportProtectionProfile(c.securityBootstrap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) activateClientSteadyTransportProtection() {
|
||||||
|
if c == nil || !c.securityConfigured {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.setClientTransportProtectionProfile(c.securitySteadyNegotiated)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) resetClientNegotiatedSteadyTransportProtection() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.securitySteadyNegotiated = c.securitySteady.clone().withForwardSecrecyFallback(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) setClientNegotiatedSteadyTransportProtection(profile transportProtectionProfile) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.securitySteadyNegotiated = profile.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientSupportsForwardSecrecy() bool {
|
||||||
|
if c == nil || !c.securityConfigured {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if c.securityAuthMode != AuthPSK {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if c.securityProtectionMode == ProtectionExternal {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return len(c.securityBootstrap.secretKey) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientRequiresForwardSecrecy() bool {
|
||||||
|
if c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return c.securityRequireForwardSecrecy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) serverSupportsForwardSecrecy() bool {
|
||||||
|
if s == nil || !s.securityConfigured {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if s.securityAuthMode != AuthPSK {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if s.securityProtectionMode == ProtectionExternal {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return len(s.securityBootstrap.secretKey) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) serverRequiresForwardSecrecy() bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.securityRequireForwardSecrecy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) setServerDefaultTransportProtectionProfile(profile transportProtectionProfile) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
profile.secretKey = cloneTransportProtectionKey(profile.secretKey)
|
||||||
|
profile.sessionID = cloneTransportSessionID(profile.sessionID)
|
||||||
|
s.defaultMsgEn = profile.msgEn
|
||||||
|
s.defaultMsgDe = profile.msgDe
|
||||||
|
s.defaultFastStreamEncode = profile.fastStreamEncode
|
||||||
|
s.defaultFastBulkEncode = profile.fastBulkEncode
|
||||||
|
s.defaultFastPlainEncode = profile.fastPlainEncode
|
||||||
|
s.defaultModernPSKRuntime = profile.runtime
|
||||||
|
s.SecretKey = profile.secretKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) clearServerSecurityProfiles() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.securityConfigured = false
|
||||||
|
s.securityAuthMode = AuthNone
|
||||||
|
s.securityProtectionMode = ProtectionManaged
|
||||||
|
s.securityBootstrap = transportProtectionProfile{}
|
||||||
|
s.securitySteady = transportProtectionProfile{}
|
||||||
|
s.securityRequireForwardSecrecy = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) configureServerSecurityProfiles(authMode AuthMode, protectionMode ProtectionMode, bootstrap transportProtectionProfile, steady transportProtectionProfile, requireForwardSecrecy bool) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.securityConfigured = true
|
||||||
|
s.securityAuthMode = authMode
|
||||||
|
s.securityProtectionMode = protectionMode
|
||||||
|
s.securityBootstrap = bootstrap.clone()
|
||||||
|
s.securitySteady = steady.clone()
|
||||||
|
s.securityRequireForwardSecrecy = requireForwardSecrecy
|
||||||
|
s.setServerDefaultTransportProtectionProfile(bootstrap)
|
||||||
|
s.securityReadyCheck = len(bootstrap.secretKey) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) applyLogicalSteadyTransportProtection(logical *LogicalConn) {
|
||||||
|
if s == nil || logical == nil || !s.securityConfigured {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logical.applyTransportProtectionProfile(s.securitySteady)
|
||||||
|
}
|
||||||
|
|
||||||
|
func transportProtectionProfileFromAttachmentState(state *clientConnAttachmentState) transportProtectionProfile {
|
||||||
|
if state == nil {
|
||||||
|
return transportProtectionProfile{}
|
||||||
|
}
|
||||||
|
return transportProtectionProfile{
|
||||||
|
mode: state.protectionMode,
|
||||||
|
msgEn: state.msgEn,
|
||||||
|
msgDe: state.msgDe,
|
||||||
|
fastStreamEncode: state.fastStreamEncode,
|
||||||
|
fastBulkEncode: state.fastBulkEncode,
|
||||||
|
fastPlainEncode: state.fastPlainEncode,
|
||||||
|
runtime: state.modernPSKRuntime,
|
||||||
|
secretKey: cloneTransportProtectionKey(state.secretKey),
|
||||||
|
keyMode: state.keyMode,
|
||||||
|
sessionID: cloneTransportSessionID(state.sessionID),
|
||||||
|
forwardSecrecy: state.forwardSecrecy,
|
||||||
|
forwardSecrecyFallback: state.forwardSecrecyFallback,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) transportProtectionProfileSnapshot() transportProtectionProfile {
|
||||||
|
if c == nil {
|
||||||
|
return transportProtectionProfile{}
|
||||||
|
}
|
||||||
|
return transportProtectionProfileFromAttachmentState(c.attachmentStateSnapshot())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) applyTransportProtectionProfile(profile transportProtectionProfile) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.updateAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
|
state.protectionMode = profile.mode
|
||||||
|
state.msgEn = profile.msgEn
|
||||||
|
state.msgDe = profile.msgDe
|
||||||
|
state.fastStreamEncode = profile.fastStreamEncode
|
||||||
|
state.fastBulkEncode = profile.fastBulkEncode
|
||||||
|
state.fastPlainEncode = profile.fastPlainEncode
|
||||||
|
state.modernPSKRuntime = profile.runtime
|
||||||
|
state.secretKey = cloneTransportProtectionKey(profile.secretKey)
|
||||||
|
state.keyMode = profile.keyMode
|
||||||
|
state.sessionID = cloneTransportSessionID(profile.sessionID)
|
||||||
|
state.forwardSecrecy = profile.forwardSecrecy
|
||||||
|
state.forwardSecrecyFallback = profile.forwardSecrecyFallback
|
||||||
|
})
|
||||||
|
}
|
||||||
191
security_psk.go
191
security_psk.go
@ -17,7 +17,8 @@ import (
|
|||||||
var (
|
var (
|
||||||
errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty")
|
errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty")
|
||||||
errModernPSKPayload = errors.New("invalid modern psk payload")
|
errModernPSKPayload = errors.New("invalid modern psk payload")
|
||||||
errModernPSKRequired = errors.New("modern psk is required: call UseModernPSKClient/UseModernPSKServer or set a transport key before Connect/Listen")
|
errModernPSKRequired = errors.New("transport security is required: call UseModernPSKClient/UseModernPSKServer, UsePSKOverExternalTransportClient/Server, or set a transport key before Connect/Listen")
|
||||||
|
errModernPSKForwardSecrecyUnsupported = errors.New("forward secrecy is unsupported for external transport protection")
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -50,6 +51,7 @@ type ModernPSKOptions struct {
|
|||||||
Salt []byte
|
Salt []byte
|
||||||
AAD []byte
|
AAD []byte
|
||||||
Argon2Params starcrypto.Argon2Params
|
Argon2Params starcrypto.Argon2Params
|
||||||
|
RequireForwardSecrecy bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultModernPSKOptions returns the recommended settings for the current
|
// DefaultModernPSKOptions returns the recommended settings for the current
|
||||||
@ -78,24 +80,17 @@ func defaultModernPSKTransportBundle() modernPSKTransportBundle {
|
|||||||
// Argon2id, and switches message protection to AES-GCM. Configure it before
|
// Argon2id, and switches message protection to AES-GCM. Configure it before
|
||||||
// calling Connect/ConnectTimeout.
|
// calling Connect/ConnectTimeout.
|
||||||
func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
|
func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
|
||||||
key, aad, err := deriveModernPSKKey(sharedSecret, opts)
|
managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
transport := buildModernPSKTransportBundle(aad)
|
|
||||||
runtime, err := newModernPSKCodecRuntime(key, aad)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.SetSecretKey(key)
|
|
||||||
c.SetMsgEn(transport.msgEn)
|
|
||||||
c.SetMsgDe(transport.msgDe)
|
|
||||||
if client, ok := c.(*ClientCommon); ok {
|
if client, ok := c.(*ClientCommon); ok {
|
||||||
client.fastStreamEncode = transport.fastStreamEncode
|
client.configureClientSecurityProfiles(AuthPSK, ProtectionManaged, managed, managed, opts != nil && opts.RequireForwardSecrecy)
|
||||||
client.fastBulkEncode = transport.fastBulkEncode
|
return nil
|
||||||
client.fastPlainEncode = transport.fastPlainEncode
|
|
||||||
client.modernPSKRuntime = runtime
|
|
||||||
}
|
}
|
||||||
|
c.SetSecretKey(managed.secretKey)
|
||||||
|
c.SetMsgEn(managed.msgEn)
|
||||||
|
c.SetMsgDe(managed.msgDe)
|
||||||
c.SetSkipExchangeKey(true)
|
c.SetSkipExchangeKey(true)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -106,24 +101,95 @@ func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) e
|
|||||||
// It derives a transport key with Argon2id and switches message protection to
|
// It derives a transport key with Argon2id and switches message protection to
|
||||||
// AES-GCM. Configure it before calling Listen.
|
// AES-GCM. Configure it before calling Listen.
|
||||||
func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
|
func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
|
||||||
key, aad, err := deriveModernPSKKey(sharedSecret, opts)
|
managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
transport := buildModernPSKTransportBundle(aad)
|
|
||||||
runtime, err := newModernPSKCodecRuntime(key, aad)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.SetSecretKey(key)
|
|
||||||
s.SetDefaultCommEncode(transport.msgEn)
|
|
||||||
s.SetDefaultCommDecode(transport.msgDe)
|
|
||||||
if server, ok := s.(*ServerCommon); ok {
|
if server, ok := s.(*ServerCommon); ok {
|
||||||
server.defaultFastStreamEncode = transport.fastStreamEncode
|
server.configureServerSecurityProfiles(AuthPSK, ProtectionManaged, managed, managed, opts != nil && opts.RequireForwardSecrecy)
|
||||||
server.defaultFastBulkEncode = transport.fastBulkEncode
|
return nil
|
||||||
server.defaultFastPlainEncode = transport.fastPlainEncode
|
|
||||||
server.defaultModernPSKRuntime = runtime
|
|
||||||
}
|
}
|
||||||
|
s.SetSecretKey(managed.secretKey)
|
||||||
|
s.SetDefaultCommEncode(managed.msgEn)
|
||||||
|
s.SetDefaultCommDecode(managed.msgDe)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UsePSKOverExternalTransportClient authenticates bootstrap with PSK and then
|
||||||
|
// trusts the external channel for steady-state payload protection.
|
||||||
|
func UsePSKOverExternalTransportClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
|
||||||
|
if opts != nil && opts.RequireForwardSecrecy {
|
||||||
|
return errModernPSKForwardSecrecyUnsupported
|
||||||
|
}
|
||||||
|
bootstrap, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
steady := buildExternalProtectionProfile(bootstrap.secretKey)
|
||||||
|
if client, ok := c.(*ClientCommon); ok {
|
||||||
|
client.configureClientSecurityProfiles(AuthPSK, ProtectionExternal, bootstrap, steady, false)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c.SetSecretKey(bootstrap.secretKey)
|
||||||
|
c.SetMsgEn(bootstrap.msgEn)
|
||||||
|
c.SetMsgDe(bootstrap.msgDe)
|
||||||
|
c.SetSkipExchangeKey(true)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UsePSKOverExternalTransportServer authenticates bootstrap with PSK and then
|
||||||
|
// trusts the external channel for steady-state payload protection.
|
||||||
|
func UsePSKOverExternalTransportServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
|
||||||
|
if opts != nil && opts.RequireForwardSecrecy {
|
||||||
|
return errModernPSKForwardSecrecyUnsupported
|
||||||
|
}
|
||||||
|
bootstrap, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionManaged)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
steady := buildExternalProtectionProfile(bootstrap.secretKey)
|
||||||
|
if server, ok := s.(*ServerCommon); ok {
|
||||||
|
server.configureServerSecurityProfiles(AuthPSK, ProtectionExternal, bootstrap, steady, false)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.SetSecretKey(bootstrap.secretKey)
|
||||||
|
s.SetDefaultCommEncode(bootstrap.msgEn)
|
||||||
|
s.SetDefaultCommDecode(bootstrap.msgDe)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UseNestedSecurityClient keeps notify transport protection enabled even when
|
||||||
|
// the physical connection is already protected by an outer trusted channel.
|
||||||
|
func UseNestedSecurityClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error {
|
||||||
|
managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionNested)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if client, ok := c.(*ClientCommon); ok {
|
||||||
|
client.configureClientSecurityProfiles(AuthPSK, ProtectionNested, managed, managed, opts != nil && opts.RequireForwardSecrecy)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c.SetSecretKey(managed.secretKey)
|
||||||
|
c.SetMsgEn(managed.msgEn)
|
||||||
|
c.SetMsgDe(managed.msgDe)
|
||||||
|
c.SetSkipExchangeKey(true)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UseNestedSecurityServer keeps notify transport protection enabled even when
|
||||||
|
// the physical connection is already protected by an outer trusted channel.
|
||||||
|
func UseNestedSecurityServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error {
|
||||||
|
managed, err := deriveModernPSKProtectionProfile(sharedSecret, opts, ProtectionNested)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if server, ok := s.(*ServerCommon); ok {
|
||||||
|
server.configureServerSecurityProfiles(AuthPSK, ProtectionNested, managed, managed, opts != nil && opts.RequireForwardSecrecy)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.SetSecretKey(managed.secretKey)
|
||||||
|
s.SetDefaultCommEncode(managed.msgEn)
|
||||||
|
s.SetDefaultCommDecode(managed.msgDe)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -132,15 +198,22 @@ func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) e
|
|||||||
//
|
//
|
||||||
// It is kept only as an explicit fallback path for existing deployments.
|
// It is kept only as an explicit fallback path for existing deployments.
|
||||||
func UseLegacySecurityClient(c Client) {
|
func UseLegacySecurityClient(c Client) {
|
||||||
|
if client, ok := c.(*ClientCommon); ok {
|
||||||
|
client.clearClientSecurityProfiles()
|
||||||
|
client.setClientTransportProtectionProfile(transportProtectionProfile{
|
||||||
|
mode: ProtectionManaged,
|
||||||
|
msgEn: defaultMsgEn,
|
||||||
|
msgDe: defaultMsgDe,
|
||||||
|
secretKey: bytes.Clone(defaultAesKey),
|
||||||
|
})
|
||||||
|
client.securityReadyCheck = false
|
||||||
|
client.skipKeyExchange = false
|
||||||
|
client.handshakeRsaPubKey = bytes.Clone(defaultRsaPubKey)
|
||||||
|
return
|
||||||
|
}
|
||||||
c.SetSecretKey(bytes.Clone(defaultAesKey))
|
c.SetSecretKey(bytes.Clone(defaultAesKey))
|
||||||
c.SetMsgEn(defaultMsgEn)
|
c.SetMsgEn(defaultMsgEn)
|
||||||
c.SetMsgDe(defaultMsgDe)
|
c.SetMsgDe(defaultMsgDe)
|
||||||
if client, ok := c.(*ClientCommon); ok {
|
|
||||||
client.fastStreamEncode = nil
|
|
||||||
client.fastBulkEncode = nil
|
|
||||||
client.fastPlainEncode = nil
|
|
||||||
client.modernPSKRuntime = nil
|
|
||||||
}
|
|
||||||
c.SetSkipExchangeKey(false)
|
c.SetSkipExchangeKey(false)
|
||||||
c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey))
|
c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey))
|
||||||
}
|
}
|
||||||
@ -150,15 +223,21 @@ func UseLegacySecurityClient(c Client) {
|
|||||||
//
|
//
|
||||||
// It is kept only as an explicit fallback path for existing deployments.
|
// It is kept only as an explicit fallback path for existing deployments.
|
||||||
func UseLegacySecurityServer(s Server) {
|
func UseLegacySecurityServer(s Server) {
|
||||||
|
if server, ok := s.(*ServerCommon); ok {
|
||||||
|
server.clearServerSecurityProfiles()
|
||||||
|
server.setServerDefaultTransportProtectionProfile(transportProtectionProfile{
|
||||||
|
mode: ProtectionManaged,
|
||||||
|
msgEn: defaultMsgEn,
|
||||||
|
msgDe: defaultMsgDe,
|
||||||
|
secretKey: bytes.Clone(defaultAesKey),
|
||||||
|
})
|
||||||
|
server.securityReadyCheck = false
|
||||||
|
server.handshakeRsaKey = bytes.Clone(defaultRsaKey)
|
||||||
|
return
|
||||||
|
}
|
||||||
s.SetSecretKey(bytes.Clone(defaultAesKey))
|
s.SetSecretKey(bytes.Clone(defaultAesKey))
|
||||||
s.SetDefaultCommEncode(defaultMsgEn)
|
s.SetDefaultCommEncode(defaultMsgEn)
|
||||||
s.SetDefaultCommDecode(defaultMsgDe)
|
s.SetDefaultCommDecode(defaultMsgDe)
|
||||||
if server, ok := s.(*ServerCommon); ok {
|
|
||||||
server.defaultFastStreamEncode = nil
|
|
||||||
server.defaultFastBulkEncode = nil
|
|
||||||
server.defaultFastPlainEncode = nil
|
|
||||||
server.defaultModernPSKRuntime = nil
|
|
||||||
}
|
|
||||||
s.SetRsaPrivKey(bytes.Clone(defaultRsaKey))
|
s.SetRsaPrivKey(bytes.Clone(defaultRsaKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -174,6 +253,40 @@ func deriveModernPSKKey(sharedSecret []byte, opts *ModernPSKOptions) ([]byte, []
|
|||||||
return key, cfg.AAD, nil
|
return key, cfg.AAD, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func deriveModernPSKProtectionProfile(sharedSecret []byte, opts *ModernPSKOptions, mode ProtectionMode) (transportProtectionProfile, error) {
|
||||||
|
key, aad, err := deriveModernPSKKey(sharedSecret, opts)
|
||||||
|
if err != nil {
|
||||||
|
return transportProtectionProfile{}, err
|
||||||
|
}
|
||||||
|
transport := buildModernPSKTransportBundle(aad)
|
||||||
|
runtime, err := newModernPSKCodecRuntime(key, aad)
|
||||||
|
if err != nil {
|
||||||
|
return transportProtectionProfile{}, err
|
||||||
|
}
|
||||||
|
return newTransportProtectionProfile(mode, transport, runtime, key), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildExternalProtectionProfile(secretKey []byte) transportProtectionProfile {
|
||||||
|
return newTransportProtectionProfile(ProtectionExternal, buildExternalTransportBundle(), nil, secretKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func deriveModernPSKSessionProtectionProfile(base transportProtectionProfile, sessionKey []byte, sessionID []byte) (transportProtectionProfile, error) {
|
||||||
|
aad := bytes.Clone(defaultModernPSKAAD)
|
||||||
|
if base.runtime != nil && len(base.runtime.aad) != 0 {
|
||||||
|
aad = bytes.Clone(base.runtime.aad)
|
||||||
|
}
|
||||||
|
runtime, err := newModernPSKCodecRuntime(sessionKey, aad)
|
||||||
|
if err != nil {
|
||||||
|
return transportProtectionProfile{}, err
|
||||||
|
}
|
||||||
|
profile := newTransportProtectionProfile(base.mode, buildModernPSKTransportBundle(aad), runtime, sessionKey)
|
||||||
|
profile.keyMode = peerAttachKeyModeECDHE
|
||||||
|
profile.sessionID = cloneTransportSessionID(sessionID)
|
||||||
|
profile.forwardSecrecy = true
|
||||||
|
profile.forwardSecrecyFallback = false
|
||||||
|
return profile, nil
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeModernPSKOptions(opts *ModernPSKOptions) ModernPSKOptions {
|
func normalizeModernPSKOptions(opts *ModernPSKOptions) ModernPSKOptions {
|
||||||
cfg := DefaultModernPSKOptions()
|
cfg := DefaultModernPSKOptions()
|
||||||
if opts == nil {
|
if opts == nil {
|
||||||
|
|||||||
@ -3,8 +3,10 @@ package notify
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"b612.me/starcrypto"
|
"b612.me/starcrypto"
|
||||||
)
|
)
|
||||||
@ -207,6 +209,17 @@ func TestUseModernPSKRejectsEmptySecret(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUsePSKOverExternalTransportRejectsForwardSecrecyRequirement(t *testing.T) {
|
||||||
|
opts := testModernPSKOptions()
|
||||||
|
opts.RequireForwardSecrecy = true
|
||||||
|
if err := UsePSKOverExternalTransportClient(NewClient(), []byte("secret"), opts); !errors.Is(err, errModernPSKForwardSecrecyUnsupported) {
|
||||||
|
t.Fatalf("UsePSKOverExternalTransportClient error = %v, want %v", err, errModernPSKForwardSecrecyUnsupported)
|
||||||
|
}
|
||||||
|
if err := UsePSKOverExternalTransportServer(NewServer(), []byte("secret"), opts); !errors.Is(err, errModernPSKForwardSecrecyUnsupported) {
|
||||||
|
t.Fatalf("UsePSKOverExternalTransportServer error = %v, want %v", err, errModernPSKForwardSecrecyUnsupported)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestModernPSKCodecRejectsLegacyPayload(t *testing.T) {
|
func TestModernPSKCodecRejectsLegacyPayload(t *testing.T) {
|
||||||
key, aad, err := deriveModernPSKKey([]byte("notify-legacy-reject"), testModernPSKOptions())
|
key, aad, err := deriveModernPSKKey([]byte("notify-legacy-reject"), testModernPSKOptions())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -315,6 +328,166 @@ func TestModernPSKFastBulkEncodeRoundTrip(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExternalTransportFastStreamEncodeRoundTrip(t *testing.T) {
|
||||||
|
transport := buildExternalTransportBundle()
|
||||||
|
wire, err := transport.fastStreamEncode(nil, 23, 7, []byte("payload"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("fastStreamEncode failed: %v", err)
|
||||||
|
}
|
||||||
|
plain := transport.msgDe(nil, wire)
|
||||||
|
frame, matched, err := decodeStreamFastDataFrame(plain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decodeStreamFastDataFrame failed: %v", err)
|
||||||
|
}
|
||||||
|
if !matched {
|
||||||
|
t.Fatal("decodeStreamFastDataFrame should match fast payload")
|
||||||
|
}
|
||||||
|
if frame.DataID != 23 || frame.Seq != 7 || !bytes.Equal(frame.Payload, []byte("payload")) {
|
||||||
|
t.Fatalf("frame mismatch: %+v", frame)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExternalTransportFastBulkEncodeRoundTrip(t *testing.T) {
|
||||||
|
transport := buildExternalTransportBundle()
|
||||||
|
wire, err := transport.fastBulkEncode(nil, 41, 9, []byte("payload"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("fastBulkEncode failed: %v", err)
|
||||||
|
}
|
||||||
|
plain := transport.msgDe(nil, wire)
|
||||||
|
frame, matched, err := decodeBulkFastDataFrame(plain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decodeBulkFastDataFrame failed: %v", err)
|
||||||
|
}
|
||||||
|
if !matched {
|
||||||
|
t.Fatal("decodeBulkFastDataFrame should match fast payload")
|
||||||
|
}
|
||||||
|
if frame.DataID != 41 || frame.Seq != 9 || !bytes.Equal(frame.Payload, []byte("payload")) {
|
||||||
|
t.Fatalf("frame mismatch: %+v", frame)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptTransportPayloadCodecPooledExternalDefersRelease(t *testing.T) {
|
||||||
|
payload := []byte("payload")
|
||||||
|
released := false
|
||||||
|
plain, release, err := decryptTransportPayloadCodecPooled(ProtectionExternal, nil, passthroughTransportCodec, nil, payload, func() {
|
||||||
|
released = true
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decryptTransportPayloadCodecPooled failed: %v", err)
|
||||||
|
}
|
||||||
|
if released {
|
||||||
|
t.Fatal("release should not run before caller is done")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(plain, payload) {
|
||||||
|
t.Fatalf("plain mismatch: got %q want %q", plain, payload)
|
||||||
|
}
|
||||||
|
if release == nil {
|
||||||
|
t.Fatal("release callback should be preserved for external mode")
|
||||||
|
}
|
||||||
|
release()
|
||||||
|
if !released {
|
||||||
|
t.Fatal("release callback should run when caller finishes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsePSKOverExternalTransportConnectByConnSwitchesToExternal(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
if err := UsePSKOverExternalTransportServer(server, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
|
||||||
|
}
|
||||||
|
server.SetLink("external-roundtrip", func(msg *Message) {
|
||||||
|
_ = msg.Reply([]byte("ack:" + string(msg.Value)))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
if err := UsePSKOverExternalTransportClient(client, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionManaged {
|
||||||
|
t.Fatalf("client bootstrap mode = %v, want %v", got, ProtectionManaged)
|
||||||
|
}
|
||||||
|
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer right.Close()
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, right)
|
||||||
|
if err := client.ConnectByConn(left); err != nil {
|
||||||
|
t.Fatalf("ConnectByConn failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
_ = client.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionExternal {
|
||||||
|
t.Fatalf("client steady mode = %v, want %v", got, ProtectionExternal)
|
||||||
|
}
|
||||||
|
|
||||||
|
reply, err := client.SendWait("external-roundtrip", []byte("ping"), time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SendWait failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := string(reply.Value), "ack:ping"; got != want {
|
||||||
|
t.Fatalf("reply mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
list := server.GetLogicalConnList()
|
||||||
|
if len(list) != 1 {
|
||||||
|
t.Fatalf("logical conn count = %d, want 1", len(list))
|
||||||
|
}
|
||||||
|
if got := list[0].protectionModeSnapshot(); got != ProtectionExternal {
|
||||||
|
t.Fatalf("server steady mode = %v, want %v", got, ProtectionExternal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUseNestedSecurityConnectByConnKeepsNestedMode(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
if err := UseNestedSecurityServer(server, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseNestedSecurityServer failed: %v", err)
|
||||||
|
}
|
||||||
|
server.SetLink("nested-roundtrip", func(msg *Message) {
|
||||||
|
_ = msg.Reply([]byte("ack:" + string(msg.Value)))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
if err := UseNestedSecurityClient(client, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseNestedSecurityClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionNested {
|
||||||
|
t.Fatalf("client bootstrap mode = %v, want %v", got, ProtectionNested)
|
||||||
|
}
|
||||||
|
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer right.Close()
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, right)
|
||||||
|
if err := client.ConnectByConn(left); err != nil {
|
||||||
|
t.Fatalf("ConnectByConn failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
_ = client.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionNested {
|
||||||
|
t.Fatalf("client steady mode = %v, want %v", got, ProtectionNested)
|
||||||
|
}
|
||||||
|
|
||||||
|
reply, err := client.SendWait("nested-roundtrip", []byte("ping"), time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SendWait failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := string(reply.Value), "ack:ping"; got != want {
|
||||||
|
t.Fatalf("reply mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
list := server.GetLogicalConnList()
|
||||||
|
if len(list) != 1 {
|
||||||
|
t.Fatalf("logical conn count = %d, want 1", len(list))
|
||||||
|
}
|
||||||
|
if got := list[0].protectionModeSnapshot(); got != ProtectionNested {
|
||||||
|
t.Fatalf("server steady mode = %v, want %v", got, ProtectionNested)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUseLegacySecurityRoundTrip(t *testing.T) {
|
func TestUseLegacySecurityRoundTrip(t *testing.T) {
|
||||||
client := NewClient()
|
client := NewClient()
|
||||||
server := NewServer()
|
server := NewServer()
|
||||||
|
|||||||
15
server.go
15
server.go
@ -34,6 +34,19 @@ type ServerCommon struct {
|
|||||||
defaultFastBulkEncode transportFastBulkEncoder
|
defaultFastBulkEncode transportFastBulkEncoder
|
||||||
defaultFastPlainEncode transportFastPlainEncoder
|
defaultFastPlainEncode transportFastPlainEncoder
|
||||||
defaultModernPSKRuntime *modernPSKCodecRuntime
|
defaultModernPSKRuntime *modernPSKCodecRuntime
|
||||||
|
peerAttachSecurity atomic.Pointer[peerAttachSecurityState]
|
||||||
|
securityBootstrap transportProtectionProfile
|
||||||
|
securitySteady transportProtectionProfile
|
||||||
|
securityAuthMode AuthMode
|
||||||
|
securityProtectionMode ProtectionMode
|
||||||
|
securityRequireForwardSecrecy bool
|
||||||
|
securityConfigured bool
|
||||||
|
peerAttachReplay peerAttachReplayCache
|
||||||
|
peerAttachExplicitCount atomic.Int64
|
||||||
|
peerAttachAuthFallbackCount atomic.Int64
|
||||||
|
peerAttachAuthRejectCount atomic.Int64
|
||||||
|
peerAttachDowngradeRejectCount atomic.Int64
|
||||||
|
peerAttachBindingRejectCount atomic.Int64
|
||||||
linkFns map[string]func(message *Message)
|
linkFns map[string]func(message *Message)
|
||||||
defaultFns func(message *Message)
|
defaultFns func(message *Message)
|
||||||
noFinSyncMsgMaxKeepSeconds int64
|
noFinSyncMsgMaxKeepSeconds int64
|
||||||
@ -93,6 +106,8 @@ func NewServer() Server {
|
|||||||
server.defaultFns = func(message *Message) {
|
server.defaultFns = func(message *Message) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
server.setServerDefaultTransportProtectionProfile(defaultTransportProtectionProfile())
|
||||||
|
server.peerAttachSecurity.Store(defaultPeerAttachSecurityState())
|
||||||
server.sessionRuntime.Store(newServerSessionRuntimeBase(server.stopCtx, server.stopFn))
|
server.sessionRuntime.Store(newServerSessionRuntimeBase(server.stopCtx, server.stopFn))
|
||||||
bindServerStreamControl(&server)
|
bindServerStreamControl(&server)
|
||||||
bindServerBulkControl(&server)
|
bindServerBulkControl(&server)
|
||||||
|
|||||||
@ -413,7 +413,7 @@ func serverBulkWriteSender(s *ServerCommon, logical *LogicalConn, transport *Tra
|
|||||||
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return s.sendDedicatedBulkWrite(ctx, bulk.LogicalConn(), bulk, startSeq, payload)
|
return s.sendDedicatedBulkWrite(ctx, bulk.LogicalConn(), bulk, startSeq, payload, payloadOwned)
|
||||||
}
|
}
|
||||||
if transport == nil {
|
if transport == nil {
|
||||||
return 0, errBulkTransportNil
|
return 0, errBulkTransportNil
|
||||||
|
|||||||
@ -31,22 +31,28 @@ func (s *ServerCommon) Stop() error {
|
|||||||
// Deprecated: SetDefaultCommEncode overrides the transport codec directly.
|
// Deprecated: SetDefaultCommEncode overrides the transport codec directly.
|
||||||
// Prefer UseModernPSKServer or UseLegacySecurityServer.
|
// Prefer UseModernPSKServer or UseLegacySecurityServer.
|
||||||
func (s *ServerCommon) SetDefaultCommEncode(fn func([]byte, []byte) []byte) {
|
func (s *ServerCommon) SetDefaultCommEncode(fn func([]byte, []byte) []byte) {
|
||||||
s.defaultMsgEn = fn
|
profile := transportProtectionProfile{
|
||||||
s.defaultFastStreamEncode = nil
|
mode: ProtectionManaged,
|
||||||
s.defaultFastBulkEncode = nil
|
msgEn: fn,
|
||||||
s.defaultFastPlainEncode = nil
|
msgDe: s.defaultMsgDe,
|
||||||
s.defaultModernPSKRuntime = nil
|
secretKey: s.SecretKey,
|
||||||
|
}
|
||||||
|
s.setServerDefaultTransportProtectionProfile(profile)
|
||||||
|
s.clearServerSecurityProfiles()
|
||||||
s.securityReadyCheck = false
|
s.securityReadyCheck = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: SetDefaultCommDecode overrides the transport codec directly.
|
// Deprecated: SetDefaultCommDecode overrides the transport codec directly.
|
||||||
// Prefer UseModernPSKServer or UseLegacySecurityServer.
|
// Prefer UseModernPSKServer or UseLegacySecurityServer.
|
||||||
func (s *ServerCommon) SetDefaultCommDecode(fn func([]byte, []byte) []byte) {
|
func (s *ServerCommon) SetDefaultCommDecode(fn func([]byte, []byte) []byte) {
|
||||||
s.defaultMsgDe = fn
|
profile := transportProtectionProfile{
|
||||||
s.defaultFastStreamEncode = nil
|
mode: ProtectionManaged,
|
||||||
s.defaultFastBulkEncode = nil
|
msgEn: s.defaultMsgEn,
|
||||||
s.defaultFastPlainEncode = nil
|
msgDe: fn,
|
||||||
s.defaultModernPSKRuntime = nil
|
secretKey: s.SecretKey,
|
||||||
|
}
|
||||||
|
s.setServerDefaultTransportProtectionProfile(profile)
|
||||||
|
s.clearServerSecurityProfiles()
|
||||||
s.securityReadyCheck = false
|
s.securityReadyCheck = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -98,14 +104,21 @@ func (s *ServerCommon) GetSecretKey() []byte {
|
|||||||
// Deprecated: SetSecretKey injects a raw transport key directly.
|
// Deprecated: SetSecretKey injects a raw transport key directly.
|
||||||
// Prefer UseModernPSKServer or UseLegacySecurityServer.
|
// Prefer UseModernPSKServer or UseLegacySecurityServer.
|
||||||
func (s *ServerCommon) SetSecretKey(key []byte) {
|
func (s *ServerCommon) SetSecretKey(key []byte) {
|
||||||
s.SecretKey = key
|
profile := transportProtectionProfile{
|
||||||
if len(key) == 0 {
|
mode: ProtectionManaged,
|
||||||
s.defaultModernPSKRuntime = nil
|
msgEn: s.defaultMsgEn,
|
||||||
} else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil {
|
msgDe: s.defaultMsgDe,
|
||||||
s.defaultModernPSKRuntime = runtime
|
secretKey: cloneTransportProtectionKey(key),
|
||||||
} else {
|
|
||||||
s.defaultModernPSKRuntime = nil
|
|
||||||
}
|
}
|
||||||
|
if len(key) == 0 {
|
||||||
|
profile.runtime = nil
|
||||||
|
} else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil {
|
||||||
|
profile.runtime = runtime
|
||||||
|
} else {
|
||||||
|
profile.runtime = nil
|
||||||
|
}
|
||||||
|
s.setServerDefaultTransportProtectionProfile(profile)
|
||||||
|
s.clearServerSecurityProfiles()
|
||||||
s.securityReadyCheck = len(key) == 0
|
s.securityReadyCheck = len(key) == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -3,12 +3,54 @@ package notify
|
|||||||
import (
|
import (
|
||||||
"b612.me/stario"
|
"b612.me/stario"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func readServerEnvelopeFromConnWithProfile(t *testing.T, server *ServerCommon, profile transportProtectionProfile, conn net.Conn, timeout time.Duration) Envelope {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
queue := stario.NewQueue()
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
buf := make([]byte, 4096)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if err := conn.SetReadDeadline(deadline); err != nil {
|
||||||
|
t.Fatalf("SetReadDeadline failed: %v", err)
|
||||||
|
}
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
if parseErr := queue.ParseMessage(buf[:n], "server-inbound-profile"); parseErr != nil {
|
||||||
|
t.Fatalf("ParseMessage failed: %v", parseErr)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case msg := <-queue.RestoreChan():
|
||||||
|
plain, decErr := decryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, msg.Msg)
|
||||||
|
if decErr != nil {
|
||||||
|
t.Fatalf("decryptTransportPayloadCodec failed: %v", decErr)
|
||||||
|
}
|
||||||
|
env, decErr := server.decodeEnvelopePlain(plain)
|
||||||
|
if decErr != nil {
|
||||||
|
t.Fatalf("decodeEnvelopePlain failed: %v", decErr)
|
||||||
|
}
|
||||||
|
return env
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
t.Fatalf("conn Read failed: %v", err)
|
||||||
|
}
|
||||||
|
t.Fatal("timed out waiting for server envelope")
|
||||||
|
return Envelope{}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMessageReplyUsesInboundConnForStaleTransport(t *testing.T) {
|
func TestMessageReplyUsesInboundConnForStaleTransport(t *testing.T) {
|
||||||
server := NewServer().(*ServerCommon)
|
server := NewServer().(*ServerCommon)
|
||||||
UseLegacySecurityServer(server)
|
UseLegacySecurityServer(server)
|
||||||
@ -85,6 +127,64 @@ func TestMessageReplyUsesInboundConnForStaleTransport(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMessageReplyUsesInboundProtectionSnapshotAfterLogicalSwitch(t *testing.T) {
|
||||||
|
secret := []byte("correct horse battery staple")
|
||||||
|
alternate, err := deriveModernPSKProtectionProfile([]byte("notify-reply-snapshot-alternate"), testModernPSKOptions(), ProtectionManaged)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("deriveModernPSKProtectionProfile failed: %v", err)
|
||||||
|
}
|
||||||
|
handlerErr := make(chan error, 1)
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
server.SetLink("reply-snapshot", func(msg *Message) {
|
||||||
|
if msg == nil || msg.LogicalConn == nil {
|
||||||
|
select {
|
||||||
|
case handlerErr <- errors.New("reply-snapshot logical is nil"):
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msg.LogicalConn.applyTransportProtectionProfile(alternate)
|
||||||
|
if err := msg.Reply([]byte("ack")); err != nil {
|
||||||
|
select {
|
||||||
|
case handlerErr <- err:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer right.Close()
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, right)
|
||||||
|
if err := client.ConnectByConn(left); err != nil {
|
||||||
|
t.Fatalf("ConnectByConn failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
_ = client.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
reply, err := client.SendWait("reply-snapshot", []byte("ping"), time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SendWait failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := string(reply.Value), "ack"; got != want {
|
||||||
|
t.Fatalf("reply value = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case err := <-handlerErr:
|
||||||
|
t.Fatalf("reply-snapshot handler failed: %v", err)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServerHandleReceivedSignalReliabilityUsesInboundConnForStaleTransport(t *testing.T) {
|
func TestServerHandleReceivedSignalReliabilityUsesInboundConnForStaleTransport(t *testing.T) {
|
||||||
server := NewServer().(*ServerCommon)
|
server := NewServer().(*ServerCommon)
|
||||||
UseLegacySecurityServer(server)
|
UseLegacySecurityServer(server)
|
||||||
|
|||||||
@ -93,26 +93,34 @@ func (s *ServerCommon) pushTransportPayloadSourceFast(payload []byte, release fu
|
|||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if s.tryDispatchBorrowedBulkTransportPayload(source, payload) {
|
logical, transport := s.resolveInboundSource(source)
|
||||||
|
if logical == nil {
|
||||||
if release != nil {
|
if release != nil {
|
||||||
release()
|
release()
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
owned := append([]byte(nil), payload...)
|
plain, plainRelease, err := s.decryptTransportPayloadLogicalPooled(logical, payload, release)
|
||||||
if release != nil {
|
if err != nil {
|
||||||
release()
|
if s.showError || s.debugMode {
|
||||||
|
fmt.Println("server decode transport payload error", err)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
inboundConn := serverInboundConn(source)
|
||||||
|
if s.tryDispatchBorrowedTransportPlain(logical, transport, inboundConn, plain, plainRelease) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
owned := plain
|
||||||
|
if plainRelease != nil {
|
||||||
|
owned = append([]byte(nil), plain...)
|
||||||
|
plainRelease()
|
||||||
}
|
}
|
||||||
s.wg.Add(1)
|
s.wg.Add(1)
|
||||||
if !dispatcher.Dispatch(serverInboundDispatchSource(source), func() {
|
if !dispatcher.Dispatch(serverInboundDispatchSource(source), func() {
|
||||||
defer s.wg.Done()
|
defer s.wg.Done()
|
||||||
logical, transport := s.resolveInboundSource(source)
|
|
||||||
if logical == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
inboundConn := serverInboundConn(source)
|
if err := s.dispatchInboundTransportPlain(logical, transport, inboundConn, owned, now); err != nil && (s.showError || s.debugMode) {
|
||||||
if err := s.dispatchInboundTransportPayload(logical, transport, inboundConn, owned, now); err != nil && (s.showError || s.debugMode) {
|
|
||||||
fmt.Println("server decode envelope error", err)
|
fmt.Println("server decode envelope error", err)
|
||||||
}
|
}
|
||||||
}) {
|
}) {
|
||||||
|
|||||||
@ -358,16 +358,20 @@ func (s *ServerCommon) sendEnvelopeTransport(transport *TransportConn, env Envel
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerCommon) sendEnvelopeInboundTransport(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope) error {
|
func (s *ServerCommon) sendEnvelopeInboundTransport(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope) error {
|
||||||
|
return s.sendEnvelopeInboundTransportWithProfile(logical, transport, conn, nil, env)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendEnvelopeInboundTransportWithProfile(logical *LogicalConn, transport *TransportConn, conn net.Conn, profile *transportProtectionProfile, env Envelope) error {
|
||||||
if logical == nil && transport != nil {
|
if logical == nil && transport != nil {
|
||||||
logical = transport.logicalConnSnapshot()
|
logical = transport.logicalConnSnapshot()
|
||||||
}
|
}
|
||||||
if logical == nil {
|
if logical == nil {
|
||||||
return transportDetachedErrorForPeer(logical, transport)
|
return transportDetachedErrorForPeer(logical, transport)
|
||||||
}
|
}
|
||||||
if logical.msgEnSnapshot() == nil {
|
if profile == nil && logical.msgEnSnapshot() == nil {
|
||||||
return transportDetachedErrorForPeer(logical, transport)
|
return transportDetachedErrorForPeer(logical, transport)
|
||||||
}
|
}
|
||||||
payload, err := s.encodeEnvelopePayloadLogical(logical, env)
|
payload, err := s.encodeEnvelopePayloadInbound(logical, env, profile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -402,7 +406,18 @@ func (s *ServerCommon) writeControlEnvelopePayload(logical *LogicalConn, transpo
|
|||||||
return sender.submit(payload, writeDeadlineFromTimeout(logical.maxWriteTimeoutSnapshot()))
|
return sender.submit(payload, writeDeadlineFromTimeout(logical.maxWriteTimeoutSnapshot()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, msg TransferMsg) error {
|
func (s *ServerCommon) encodeEnvelopePayloadInbound(logical *LogicalConn, env Envelope, profile *transportProtectionProfile) ([]byte, error) {
|
||||||
|
if profile == nil {
|
||||||
|
return s.encodeEnvelopePayloadLogical(logical, env)
|
||||||
|
}
|
||||||
|
data, err := s.encodeEnvelopePlain(env)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return encryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgEn, profile.secretKey, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, profile *transportProtectionProfile, msg TransferMsg) error {
|
||||||
if logical == nil && transport != nil {
|
if logical == nil && transport != nil {
|
||||||
logical = transport.logicalConnSnapshot()
|
logical = transport.logicalConnSnapshot()
|
||||||
}
|
}
|
||||||
@ -413,7 +428,7 @@ func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *Tran
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.sendEnvelopeInboundTransport(logical, transport, conn, env)
|
return s.sendEnvelopeInboundTransportWithProfile(logical, transport, conn, profile, env)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerCommon) writeEnvelopePayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte) error {
|
func (s *ServerCommon) writeEnvelopePayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte) error {
|
||||||
|
|||||||
@ -110,7 +110,17 @@ func (s *ServerCommon) registerAcceptedLogical(logical *LogicalConn) *LogicalCon
|
|||||||
}
|
}
|
||||||
logical.setServer(s)
|
logical.setServer(s)
|
||||||
logical.applyAttachmentProfile(s.maxReadTimeout, s.maxWriteTimeout, s.defaultMsgEn, s.defaultMsgDe, s.defaultFastStreamEncode, s.defaultFastBulkEncode, s.defaultFastPlainEncode, s.handshakeRsaKey, s.SecretKey)
|
logical.applyAttachmentProfile(s.maxReadTimeout, s.maxWriteTimeout, s.defaultMsgEn, s.defaultMsgDe, s.defaultFastStreamEncode, s.defaultFastBulkEncode, s.defaultFastPlainEncode, s.handshakeRsaKey, s.SecretKey)
|
||||||
|
logical.updateAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
|
state.authMode = s.securityAuthMode
|
||||||
|
state.peerAttached = false
|
||||||
|
state.peerAttachFallback = false
|
||||||
|
state.peerAttachAt = 0
|
||||||
|
})
|
||||||
|
if s.securityConfigured {
|
||||||
|
logical.applyTransportProtectionProfile(s.securityBootstrap)
|
||||||
|
} else {
|
||||||
logical.setModernPSKRuntime(s.defaultModernPSKRuntime)
|
logical.setModernPSKRuntime(s.defaultModernPSKRuntime)
|
||||||
|
}
|
||||||
logical.markHeartbeatNow()
|
logical.markHeartbeatNow()
|
||||||
return s.getPeerRegistry().registerLogical(logical)
|
return s.getPeerRegistry().registerLogical(logical)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -26,6 +26,8 @@ type Server interface {
|
|||||||
RecoverTransferSnapshots(context.Context) error
|
RecoverTransferSnapshots(context.Context) error
|
||||||
SetBulkOpenTuning(BulkOpenTuning)
|
SetBulkOpenTuning(BulkOpenTuning)
|
||||||
BulkOpenTuning() BulkOpenTuning
|
BulkOpenTuning() BulkOpenTuning
|
||||||
|
SetPeerAttachSecurityConfig(PeerAttachSecurityConfig) error
|
||||||
|
PeerAttachSecurityConfig() PeerAttachSecurityConfig
|
||||||
SetFileReceiveDir(dir string) error
|
SetFileReceiveDir(dir string) error
|
||||||
send(c *ClientConn, msg TransferMsg) (WaitMsg, error)
|
send(c *ClientConn, msg TransferMsg) (WaitMsg, error)
|
||||||
sendEnvelope(c *ClientConn, env Envelope) error
|
sendEnvelope(c *ClientConn, env Envelope) error
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package notify
|
package notify
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -17,6 +18,19 @@ type ClientRuntimeSnapshot struct {
|
|||||||
ConnectNetwork string
|
ConnectNetwork string
|
||||||
ConnectAddress string
|
ConnectAddress string
|
||||||
CanReconnect bool
|
CanReconnect bool
|
||||||
|
AuthMode string
|
||||||
|
ProtectionMode string
|
||||||
|
ProtectionKeyMode string
|
||||||
|
ForwardSecrecyEnabled bool
|
||||||
|
ForwardSecrecyFallback bool
|
||||||
|
ForwardSecrecyRequired bool
|
||||||
|
TransportSessionID string
|
||||||
|
PeerAttachAuthenticated bool
|
||||||
|
PeerAttachAuthFallback bool
|
||||||
|
LastPeerAttachAt time.Time
|
||||||
|
PeerAttachRequireExplicitAuth bool
|
||||||
|
PeerAttachRequireChannelBinding bool
|
||||||
|
PeerAttachChannelBindingConfigured bool
|
||||||
BulkNetworkProfile string
|
BulkNetworkProfile string
|
||||||
BulkDefaultMode string
|
BulkDefaultMode string
|
||||||
BulkChunkSize int
|
BulkChunkSize int
|
||||||
@ -49,6 +63,22 @@ type ServerRuntimeSnapshot struct {
|
|||||||
HasRuntimeUDPListener bool
|
HasRuntimeUDPListener bool
|
||||||
HasRuntimeQueue bool
|
HasRuntimeQueue bool
|
||||||
HasRuntimeStopCtx bool
|
HasRuntimeStopCtx bool
|
||||||
|
AuthMode string
|
||||||
|
ProtectionMode string
|
||||||
|
ForwardSecrecySupported bool
|
||||||
|
ForwardSecrecyRequired bool
|
||||||
|
PeerAttachRequireExplicitAuth bool
|
||||||
|
PeerAttachRequireChannelBinding bool
|
||||||
|
PeerAttachChannelBindingConfigured bool
|
||||||
|
PeerAttachReplayWindow time.Duration
|
||||||
|
PeerAttachReplayCapacity int
|
||||||
|
PeerAttachExplicitAuth int64
|
||||||
|
PeerAttachAuthFallbacks int64
|
||||||
|
PeerAttachAuthRejects int64
|
||||||
|
PeerAttachDowngradeRejects int64
|
||||||
|
PeerAttachBindingRejects int64
|
||||||
|
PeerAttachReplayRejects int64
|
||||||
|
PeerAttachReplayOverflowRejects int64
|
||||||
BulkChunkSize int
|
BulkChunkSize int
|
||||||
BulkWindowBytes int
|
BulkWindowBytes int
|
||||||
BulkMaxInFlight int
|
BulkMaxInFlight int
|
||||||
@ -82,6 +112,15 @@ type ClientConnRuntimeSnapshot struct {
|
|||||||
TransportDetachRemaining time.Duration
|
TransportDetachRemaining time.Duration
|
||||||
TransportDetachExpired bool
|
TransportDetachExpired bool
|
||||||
ReattachEligible bool
|
ReattachEligible bool
|
||||||
|
AuthMode string
|
||||||
|
ProtectionMode string
|
||||||
|
ProtectionKeyMode string
|
||||||
|
ForwardSecrecyEnabled bool
|
||||||
|
ForwardSecrecyFallback bool
|
||||||
|
TransportSessionID string
|
||||||
|
PeerAttachAuthenticated bool
|
||||||
|
PeerAttachAuthFallback bool
|
||||||
|
LastPeerAttachAt time.Time
|
||||||
TransportBulkAdaptiveSoftPayloadBytes int
|
TransportBulkAdaptiveSoftPayloadBytes int
|
||||||
TransportStreamAdaptiveSoftPayloadBytes int
|
TransportStreamAdaptiveSoftPayloadBytes int
|
||||||
TransportStreamAdaptiveWaitThresholdBytes int
|
TransportStreamAdaptiveWaitThresholdBytes int
|
||||||
@ -108,6 +147,19 @@ func (c *ClientCommon) clientRuntimeSnapshot() ClientRuntimeSnapshot {
|
|||||||
snapshot.ConnectAddress = source.addr
|
snapshot.ConnectAddress = source.addr
|
||||||
snapshot.CanReconnect = source.canReconnect()
|
snapshot.CanReconnect = source.canReconnect()
|
||||||
}
|
}
|
||||||
|
snapshot.AuthMode = authModeName(c.securityAuthMode)
|
||||||
|
snapshot.ProtectionMode = protectionModeName(c.securityProtectionMode)
|
||||||
|
protection := c.clientTransportProtectionSnapshot()
|
||||||
|
snapshot.ProtectionKeyMode = protection.keyMode
|
||||||
|
snapshot.ForwardSecrecyEnabled = protection.forwardSecrecy
|
||||||
|
snapshot.ForwardSecrecyFallback = protection.forwardSecrecyFallback
|
||||||
|
snapshot.ForwardSecrecyRequired = c.clientRequiresForwardSecrecy()
|
||||||
|
snapshot.TransportSessionID = hex.EncodeToString(protection.sessionID)
|
||||||
|
snapshot.PeerAttachAuthenticated, snapshot.PeerAttachAuthFallback, snapshot.LastPeerAttachAt = c.clientPeerAttachAuthSnapshot()
|
||||||
|
peerAttachCfg := c.peerAttachSecuritySnapshot()
|
||||||
|
snapshot.PeerAttachRequireExplicitAuth = peerAttachCfg.requireExplicitAuth
|
||||||
|
snapshot.PeerAttachRequireChannelBinding = peerAttachCfg.requireChannelBinding
|
||||||
|
snapshot.PeerAttachChannelBindingConfigured = peerAttachCfg.channelBinding != nil
|
||||||
snapshot.BulkNetworkProfile = bulkNetworkProfileName(c.BulkNetworkProfile())
|
snapshot.BulkNetworkProfile = bulkNetworkProfileName(c.BulkNetworkProfile())
|
||||||
snapshot.BulkDefaultMode = bulkOpenModeName(c.BulkDefaultOpenMode())
|
snapshot.BulkDefaultMode = bulkOpenModeName(c.BulkDefaultOpenMode())
|
||||||
tuning := c.BulkOpenTuning()
|
tuning := c.BulkOpenTuning()
|
||||||
@ -161,6 +213,23 @@ func (s *ServerCommon) serverRuntimeSnapshot() ServerRuntimeSnapshot {
|
|||||||
snapshot.HasRuntimeQueue = rt.queue != nil
|
snapshot.HasRuntimeQueue = rt.queue != nil
|
||||||
snapshot.HasRuntimeStopCtx = rt.stopCtx != nil
|
snapshot.HasRuntimeStopCtx = rt.stopCtx != nil
|
||||||
}
|
}
|
||||||
|
snapshot.AuthMode = authModeName(s.securityAuthMode)
|
||||||
|
snapshot.ProtectionMode = protectionModeName(s.securityProtectionMode)
|
||||||
|
snapshot.ForwardSecrecySupported = s.serverSupportsForwardSecrecy()
|
||||||
|
snapshot.ForwardSecrecyRequired = s.serverRequiresForwardSecrecy()
|
||||||
|
peerAttachCfg := s.peerAttachSecuritySnapshot()
|
||||||
|
snapshot.PeerAttachRequireExplicitAuth = peerAttachCfg.requireExplicitAuth
|
||||||
|
snapshot.PeerAttachRequireChannelBinding = peerAttachCfg.requireChannelBinding
|
||||||
|
snapshot.PeerAttachChannelBindingConfigured = peerAttachCfg.channelBinding != nil
|
||||||
|
snapshot.PeerAttachReplayWindow = peerAttachCfg.replayWindow
|
||||||
|
snapshot.PeerAttachReplayCapacity = peerAttachCfg.replayCapacity
|
||||||
|
snapshot.PeerAttachExplicitAuth = s.peerAttachExplicitCount.Load()
|
||||||
|
snapshot.PeerAttachAuthFallbacks = s.peerAttachAuthFallbackCount.Load()
|
||||||
|
snapshot.PeerAttachAuthRejects = s.peerAttachAuthRejectCount.Load()
|
||||||
|
snapshot.PeerAttachDowngradeRejects = s.peerAttachDowngradeRejectCount.Load()
|
||||||
|
snapshot.PeerAttachBindingRejects = s.peerAttachBindingRejectCount.Load()
|
||||||
|
snapshot.PeerAttachReplayRejects = s.peerAttachReplayRejectCountSnapshot()
|
||||||
|
snapshot.PeerAttachReplayOverflowRejects = s.peerAttachReplayOverflowRejectCountSnapshot()
|
||||||
tuning := s.BulkOpenTuning()
|
tuning := s.BulkOpenTuning()
|
||||||
snapshot.BulkChunkSize = tuning.ChunkSize
|
snapshot.BulkChunkSize = tuning.ChunkSize
|
||||||
snapshot.BulkWindowBytes = tuning.WindowBytes
|
snapshot.BulkWindowBytes = tuning.WindowBytes
|
||||||
@ -171,6 +240,7 @@ func (s *ServerCommon) serverRuntimeSnapshot() ServerRuntimeSnapshot {
|
|||||||
|
|
||||||
func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot {
|
func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot {
|
||||||
status := c.clientConnStatusSnapshot()
|
status := c.clientConnStatusSnapshot()
|
||||||
|
attachment := c.clientConnAttachmentStateSnapshot()
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
snapshot := ClientConnRuntimeSnapshot{
|
snapshot := ClientConnRuntimeSnapshot{
|
||||||
ClientID: c.clientConnIDSnapshot(),
|
ClientID: c.clientConnIDSnapshot(),
|
||||||
@ -182,6 +252,17 @@ func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot {
|
|||||||
TransportAttachCount: c.clientConnTransportAttachCountSnapshot(),
|
TransportAttachCount: c.clientConnTransportAttachCountSnapshot(),
|
||||||
TransportDetachCount: c.clientConnTransportDetachCountSnapshot(),
|
TransportDetachCount: c.clientConnTransportDetachCountSnapshot(),
|
||||||
LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(),
|
LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(),
|
||||||
|
AuthMode: authModeName(attachment.authMode),
|
||||||
|
ProtectionMode: protectionModeName(attachment.protectionMode),
|
||||||
|
}
|
||||||
|
snapshot.PeerAttachAuthenticated = attachment.peerAttached
|
||||||
|
snapshot.PeerAttachAuthFallback = attachment.peerAttachFallback
|
||||||
|
snapshot.ProtectionKeyMode = attachment.keyMode
|
||||||
|
snapshot.ForwardSecrecyEnabled = attachment.forwardSecrecy
|
||||||
|
snapshot.ForwardSecrecyFallback = attachment.forwardSecrecyFallback
|
||||||
|
snapshot.TransportSessionID = hex.EncodeToString(attachment.sessionID)
|
||||||
|
if attachment.peerAttachAt != 0 {
|
||||||
|
snapshot.LastPeerAttachAt = time.Unix(0, attachment.peerAttachAt)
|
||||||
}
|
}
|
||||||
if status.Err != nil {
|
if status.Err != nil {
|
||||||
snapshot.Error = status.Err.Error()
|
snapshot.Error = status.Err.Error()
|
||||||
|
|||||||
@ -37,6 +37,21 @@ func TestGetClientRuntimeSnapshotDefaults(t *testing.T) {
|
|||||||
if snapshot.ConnectSource != "" || snapshot.ConnectNetwork != "" || snapshot.ConnectAddress != "" || snapshot.CanReconnect {
|
if snapshot.ConnectSource != "" || snapshot.ConnectNetwork != "" || snapshot.ConnectAddress != "" || snapshot.CanReconnect {
|
||||||
t.Fatalf("unexpected default connect source snapshot: %+v", snapshot)
|
t.Fatalf("unexpected default connect source snapshot: %+v", snapshot)
|
||||||
}
|
}
|
||||||
|
if got, want := snapshot.AuthMode, "none"; got != want {
|
||||||
|
t.Fatalf("AuthMode mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.ProtectionMode, "managed"; got != want {
|
||||||
|
t.Fatalf("ProtectionMode mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if snapshot.PeerAttachAuthenticated || snapshot.PeerAttachAuthFallback {
|
||||||
|
t.Fatalf("unexpected default peer attach state: %+v", snapshot)
|
||||||
|
}
|
||||||
|
if !snapshot.LastPeerAttachAt.IsZero() {
|
||||||
|
t.Fatalf("LastPeerAttachAt mismatch: got %v want zero", snapshot.LastPeerAttachAt)
|
||||||
|
}
|
||||||
|
if snapshot.PeerAttachRequireExplicitAuth || snapshot.PeerAttachRequireChannelBinding || snapshot.PeerAttachChannelBindingConfigured {
|
||||||
|
t.Fatalf("unexpected default peer attach policy: %+v", snapshot)
|
||||||
|
}
|
||||||
if got, want := snapshot.BulkNetworkProfile, "default"; got != want {
|
if got, want := snapshot.BulkNetworkProfile, "default"; got != want {
|
||||||
t.Fatalf("BulkNetworkProfile mismatch: got %q want %q", got, want)
|
t.Fatalf("BulkNetworkProfile mismatch: got %q want %q", got, want)
|
||||||
}
|
}
|
||||||
@ -117,6 +132,24 @@ func TestGetServerRuntimeSnapshotDefaults(t *testing.T) {
|
|||||||
if !snapshot.HasRuntimeStopCtx {
|
if !snapshot.HasRuntimeStopCtx {
|
||||||
t.Fatalf("HasRuntimeStopCtx mismatch: got %v want true", snapshot.HasRuntimeStopCtx)
|
t.Fatalf("HasRuntimeStopCtx mismatch: got %v want true", snapshot.HasRuntimeStopCtx)
|
||||||
}
|
}
|
||||||
|
if got, want := snapshot.AuthMode, "none"; got != want {
|
||||||
|
t.Fatalf("AuthMode mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.ProtectionMode, "managed"; got != want {
|
||||||
|
t.Fatalf("ProtectionMode mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if snapshot.PeerAttachRequireExplicitAuth || snapshot.PeerAttachRequireChannelBinding || snapshot.PeerAttachChannelBindingConfigured {
|
||||||
|
t.Fatalf("unexpected default peer attach policy: %+v", snapshot)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.PeerAttachReplayWindow, peerAttachReplayTTL; got != want {
|
||||||
|
t.Fatalf("PeerAttachReplayWindow mismatch: got %s want %s", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.PeerAttachReplayCapacity, defaultPeerAttachReplayCapacity; got != want {
|
||||||
|
t.Fatalf("PeerAttachReplayCapacity mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if snapshot.PeerAttachExplicitAuth != 0 || snapshot.PeerAttachAuthFallbacks != 0 || snapshot.PeerAttachAuthRejects != 0 || snapshot.PeerAttachDowngradeRejects != 0 || snapshot.PeerAttachBindingRejects != 0 || snapshot.PeerAttachReplayRejects != 0 || snapshot.PeerAttachReplayOverflowRejects != 0 {
|
||||||
|
t.Fatalf("unexpected default peer attach counters: %+v", snapshot)
|
||||||
|
}
|
||||||
if got, want := snapshot.BulkChunkSize, defaultBulkChunkSize; got != want {
|
if got, want := snapshot.BulkChunkSize, defaultBulkChunkSize; got != want {
|
||||||
t.Fatalf("BulkChunkSize mismatch: got %d want %d", got, want)
|
t.Fatalf("BulkChunkSize mismatch: got %d want %d", got, want)
|
||||||
}
|
}
|
||||||
@ -476,6 +509,134 @@ func TestGetClientConnRuntimeSnapshotExposesDetachState(t *testing.T) {
|
|||||||
if snapshot.LastHeartbeatAt.IsZero() {
|
if snapshot.LastHeartbeatAt.IsZero() {
|
||||||
t.Fatal("LastHeartbeatAt should be recorded")
|
t.Fatal("LastHeartbeatAt should be recorded")
|
||||||
}
|
}
|
||||||
|
if got, want := snapshot.AuthMode, "none"; got != want {
|
||||||
|
t.Fatalf("AuthMode mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.ProtectionMode, "managed"; got != want {
|
||||||
|
t.Fatalf("ProtectionMode mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if snapshot.PeerAttachAuthenticated || snapshot.PeerAttachAuthFallback {
|
||||||
|
t.Fatalf("unexpected peer attach state: %+v", snapshot)
|
||||||
|
}
|
||||||
|
if !snapshot.LastPeerAttachAt.IsZero() {
|
||||||
|
t.Fatalf("LastPeerAttachAt mismatch: got %v want zero", snapshot.LastPeerAttachAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRuntimeSnapshotsIncludePeerAttachSecurityState(t *testing.T) {
|
||||||
|
secret := []byte("correct horse battery staple")
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
if err := UsePSKOverExternalTransportServer(server, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UsePSKOverExternalTransportClient(client, secret, testModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer right.Close()
|
||||||
|
bootstrapPeerAttachLogicalForTest(t, server, right)
|
||||||
|
if err := client.ConnectByConn(left); err != nil {
|
||||||
|
t.Fatalf("ConnectByConn failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
_ = client.Stop()
|
||||||
|
}()
|
||||||
|
deadline := time.Now().Add(time.Second)
|
||||||
|
for {
|
||||||
|
logical := server.GetLogicalConn(client.peerIdentity)
|
||||||
|
if logical != nil {
|
||||||
|
authenticated, fallback, _ := logical.peerAttachAuthenticatedSnapshot()
|
||||||
|
if authenticated && !fallback && logical.protectionModeSnapshot() == ProtectionExternal && server.peerAttachExplicitCount.Load() == 1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
t.Fatal("peer attach security state did not converge before snapshot")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientSnapshot, err := GetClientRuntimeSnapshot(client)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := clientSnapshot.AuthMode, "psk"; got != want {
|
||||||
|
t.Fatalf("client AuthMode mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := clientSnapshot.ProtectionMode, "external"; got != want {
|
||||||
|
t.Fatalf("client ProtectionMode mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if clientSnapshot.PeerAttachRequireExplicitAuth || clientSnapshot.PeerAttachRequireChannelBinding || clientSnapshot.PeerAttachChannelBindingConfigured {
|
||||||
|
t.Fatalf("unexpected client peer attach policy snapshot: %+v", clientSnapshot)
|
||||||
|
}
|
||||||
|
if !clientSnapshot.PeerAttachAuthenticated || clientSnapshot.PeerAttachAuthFallback {
|
||||||
|
t.Fatalf("unexpected client peer attach state: %+v", clientSnapshot)
|
||||||
|
}
|
||||||
|
if clientSnapshot.LastPeerAttachAt.IsZero() {
|
||||||
|
t.Fatal("client LastPeerAttachAt should be recorded")
|
||||||
|
}
|
||||||
|
|
||||||
|
serverSnapshot, err := GetServerRuntimeSnapshot(server)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetServerRuntimeSnapshot failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := serverSnapshot.AuthMode, "psk"; got != want {
|
||||||
|
t.Fatalf("server AuthMode mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := serverSnapshot.ProtectionMode, "external"; got != want {
|
||||||
|
t.Fatalf("server ProtectionMode mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if serverSnapshot.PeerAttachRequireExplicitAuth || serverSnapshot.PeerAttachRequireChannelBinding || serverSnapshot.PeerAttachChannelBindingConfigured {
|
||||||
|
t.Fatalf("unexpected server peer attach policy snapshot: %+v", serverSnapshot)
|
||||||
|
}
|
||||||
|
if got, want := serverSnapshot.PeerAttachExplicitAuth, int64(1); got != want {
|
||||||
|
t.Fatalf("PeerAttachExplicitAuth mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if serverSnapshot.PeerAttachAuthFallbacks != 0 || serverSnapshot.PeerAttachAuthRejects != 0 || serverSnapshot.PeerAttachDowngradeRejects != 0 || serverSnapshot.PeerAttachBindingRejects != 0 || serverSnapshot.PeerAttachReplayRejects != 0 || serverSnapshot.PeerAttachReplayOverflowRejects != 0 {
|
||||||
|
t.Fatalf("unexpected server peer attach counters: %+v", serverSnapshot)
|
||||||
|
}
|
||||||
|
|
||||||
|
logical := server.GetLogicalConn(client.peerIdentity)
|
||||||
|
if logical == nil {
|
||||||
|
t.Fatal("server logical should exist after peer attach")
|
||||||
|
}
|
||||||
|
logicalSnapshot, err := GetLogicalConnRuntimeSnapshot(logical)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetLogicalConnRuntimeSnapshot failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := logicalSnapshot.AuthMode, "psk"; got != want {
|
||||||
|
t.Fatalf("logical AuthMode mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := logicalSnapshot.ProtectionMode, "external"; got != want {
|
||||||
|
t.Fatalf("logical ProtectionMode mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if !logicalSnapshot.PeerAttachAuthenticated || logicalSnapshot.PeerAttachAuthFallback {
|
||||||
|
t.Fatalf("unexpected logical peer attach state: %+v", logicalSnapshot)
|
||||||
|
}
|
||||||
|
if logicalSnapshot.LastPeerAttachAt.IsZero() {
|
||||||
|
t.Fatal("logical LastPeerAttachAt should be recorded")
|
||||||
|
}
|
||||||
|
|
||||||
|
clientConnSnapshot, err := GetClientConnRuntimeSnapshot(clientConnFromLogical(logical))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetClientConnRuntimeSnapshot failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := clientConnSnapshot.AuthMode, "psk"; got != want {
|
||||||
|
t.Fatalf("client conn AuthMode mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := clientConnSnapshot.ProtectionMode, "external"; got != want {
|
||||||
|
t.Fatalf("client conn ProtectionMode mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if !clientConnSnapshot.PeerAttachAuthenticated || clientConnSnapshot.PeerAttachAuthFallback {
|
||||||
|
t.Fatalf("unexpected client conn peer attach state: %+v", clientConnSnapshot)
|
||||||
|
}
|
||||||
|
if clientConnSnapshot.LastPeerAttachAt.IsZero() {
|
||||||
|
t.Fatal("client conn LastPeerAttachAt should be recorded")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetServerDetachedClientRuntimeSnapshotsFiltersAndSorts(t *testing.T) {
|
func TestGetServerDetachedClientRuntimeSnapshotsFiltersAndSorts(t *testing.T) {
|
||||||
|
|||||||
104
stream.go
104
stream.go
@ -89,6 +89,60 @@ type streamCloseSender func(context.Context, *streamHandle, bool) error
|
|||||||
type streamResetSender func(context.Context, *streamHandle, string) error
|
type streamResetSender func(context.Context, *streamHandle, string) error
|
||||||
type streamDataSender func(context.Context, *streamHandle, []byte) error
|
type streamDataSender func(context.Context, *streamHandle, []byte) error
|
||||||
|
|
||||||
|
type streamReadChunk struct {
|
||||||
|
data []byte
|
||||||
|
release func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *streamReadChunk) clear() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c.release != nil {
|
||||||
|
c.release()
|
||||||
|
}
|
||||||
|
c.data = nil
|
||||||
|
c.release = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type streamReadPayloadOwner struct {
|
||||||
|
refs atomic.Int32
|
||||||
|
release func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStreamReadPayloadOwner(release func()) *streamReadPayloadOwner {
|
||||||
|
if release == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
owner := &streamReadPayloadOwner{release: release}
|
||||||
|
owner.refs.Store(1)
|
||||||
|
return owner
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *streamReadPayloadOwner) retainChunk() func() {
|
||||||
|
if o == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
o.refs.Add(1)
|
||||||
|
return o.releaseChunk
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *streamReadPayloadOwner) releaseChunk() {
|
||||||
|
if o == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if o.refs.Add(-1) == 0 && o.release != nil {
|
||||||
|
o.release()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *streamReadPayloadOwner) done() {
|
||||||
|
if o == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
o.releaseChunk()
|
||||||
|
}
|
||||||
|
|
||||||
type streamHandle struct {
|
type streamHandle struct {
|
||||||
runtime *streamRuntime
|
runtime *streamRuntime
|
||||||
runtimeScope string
|
runtimeScope string
|
||||||
@ -124,8 +178,8 @@ type streamHandle struct {
|
|||||||
remoteClosed bool
|
remoteClosed bool
|
||||||
peerReadClosed bool
|
peerReadClosed bool
|
||||||
resetErr error
|
resetErr error
|
||||||
readQueue [][]byte
|
readQueue []streamReadChunk
|
||||||
readBuf []byte
|
readBuf streamReadChunk
|
||||||
bufferedBytes int
|
bufferedBytes int
|
||||||
readNotify chan struct{}
|
readNotify chan struct{}
|
||||||
readDeadline time.Time
|
readDeadline time.Time
|
||||||
@ -339,20 +393,23 @@ func (s *streamHandle) Read(p []byte) (int, error) {
|
|||||||
for {
|
for {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
localReadClosed := s.localReadClosed
|
localReadClosed := s.localReadClosed
|
||||||
if len(s.readBuf) > 0 {
|
if len(s.readBuf.data) > 0 {
|
||||||
n := copy(p, s.readBuf)
|
n := copy(p, s.readBuf.data)
|
||||||
s.readBuf = s.readBuf[n:]
|
s.readBuf.data = s.readBuf.data[n:]
|
||||||
s.bufferedBytes -= n
|
s.bufferedBytes -= n
|
||||||
if s.bufferedBytes < 0 {
|
if s.bufferedBytes < 0 {
|
||||||
s.bufferedBytes = 0
|
s.bufferedBytes = 0
|
||||||
}
|
}
|
||||||
|
if len(s.readBuf.data) == 0 {
|
||||||
|
s.readBuf.clear()
|
||||||
|
}
|
||||||
s.recordReadLocked(n, time.Now())
|
s.recordReadLocked(n, time.Now())
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
if len(s.readQueue) > 0 {
|
if len(s.readQueue) > 0 {
|
||||||
s.readBuf = s.readQueue[0]
|
s.readBuf = s.readQueue[0]
|
||||||
s.readQueue[0] = nil
|
s.readQueue[0] = streamReadChunk{}
|
||||||
s.readQueue = s.readQueue[1:]
|
s.readQueue = s.readQueue[1:]
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
continue
|
continue
|
||||||
@ -824,43 +881,61 @@ func (s *streamHandle) pushOwnedChunk(chunk []byte) error {
|
|||||||
return s.pushChunkWithOwnership(chunk, true)
|
return s.pushChunkWithOwnership(chunk, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *streamHandle) pushOwnedChunkWithRelease(chunk []byte, release func()) error {
|
||||||
|
return s.pushChunkWithOwnershipAndRelease(chunk, true, release)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *streamHandle) pushChunkWithOwnership(chunk []byte, owned bool) error {
|
func (s *streamHandle) pushChunkWithOwnership(chunk []byte, owned bool) error {
|
||||||
|
return s.pushChunkWithOwnershipAndRelease(chunk, owned, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *streamHandle) pushChunkWithOwnershipAndRelease(chunk []byte, owned bool, release func()) error {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return io.ErrClosedPipe
|
return io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
if len(chunk) == 0 {
|
if len(chunk) == 0 {
|
||||||
|
if release != nil {
|
||||||
|
release()
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
stored := chunk
|
stored := streamReadChunk{data: chunk, release: release}
|
||||||
if !owned {
|
if !owned {
|
||||||
stored = append([]byte(nil), chunk...)
|
stored.data = append([]byte(nil), chunk...)
|
||||||
|
if stored.release != nil {
|
||||||
|
stored.release()
|
||||||
|
stored.release = nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if s.resetErr != nil {
|
if s.resetErr != nil {
|
||||||
err := s.resetErr
|
err := s.resetErr
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
stored.clear()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if s.inboundQueueLimit > 0 && s.bufferedChunkCountLocked() >= s.inboundQueueLimit {
|
if s.inboundQueueLimit > 0 && s.bufferedChunkCountLocked() >= s.inboundQueueLimit {
|
||||||
err := s.markResetLocked(errStreamBackpressureExceeded)
|
err := s.markResetLocked(errStreamBackpressureExceeded)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
stored.clear()
|
||||||
s.notifyReadable()
|
s.notifyReadable()
|
||||||
s.finalize()
|
s.finalize()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if s.inboundBytesLimit > 0 && s.bufferedBytes+len(stored) > s.inboundBytesLimit {
|
if s.inboundBytesLimit > 0 && s.bufferedBytes+len(stored.data) > s.inboundBytesLimit {
|
||||||
err := s.markResetLocked(errStreamBackpressureExceeded)
|
err := s.markResetLocked(errStreamBackpressureExceeded)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
stored.clear()
|
||||||
s.notifyReadable()
|
s.notifyReadable()
|
||||||
s.finalize()
|
s.finalize()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(s.readBuf) == 0 && len(s.readQueue) == 0 {
|
if len(s.readBuf.data) == 0 && len(s.readQueue) == 0 {
|
||||||
s.readBuf = stored
|
s.readBuf = stored
|
||||||
} else {
|
} else {
|
||||||
s.readQueue = append(s.readQueue, stored)
|
s.readQueue = append(s.readQueue, stored)
|
||||||
}
|
}
|
||||||
s.bufferedBytes += len(stored)
|
s.bufferedBytes += len(stored.data)
|
||||||
s.notifyReadableLocked()
|
s.notifyReadableLocked()
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
@ -881,11 +956,12 @@ func (s *streamHandle) clearBufferedDataLocked() {
|
|||||||
if s == nil {
|
if s == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
s.readBuf.clear()
|
||||||
for i := range s.readQueue {
|
for i := range s.readQueue {
|
||||||
s.readQueue[i] = nil
|
s.readQueue[i].clear()
|
||||||
}
|
}
|
||||||
s.readQueue = nil
|
s.readQueue = nil
|
||||||
s.readBuf = nil
|
s.readBuf = streamReadChunk{}
|
||||||
s.bufferedBytes = 0
|
s.bufferedBytes = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -894,7 +970,7 @@ func (s *streamHandle) bufferedChunkCountLocked() int {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
count := len(s.readQueue)
|
count := len(s.readQueue)
|
||||||
if len(s.readBuf) > 0 {
|
if len(s.readBuf.data) > 0 {
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
return count
|
return count
|
||||||
|
|||||||
@ -56,7 +56,59 @@ func BenchmarkStreamTCPThroughput(b *testing.B) {
|
|||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
b.Run(tc.name, func(b *testing.B) {
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg)
|
benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg, benchmarkTransportSecurityModernPSK)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkStreamTCPThroughputTrustedRaw(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
payloadSize int
|
||||||
|
cfg StreamConfig
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default_64KiB",
|
||||||
|
payloadSize: 64 * 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tuned_256KiB",
|
||||||
|
payloadSize: 256 * 1024,
|
||||||
|
cfg: StreamConfig{
|
||||||
|
ChunkSize: 256 * 1024,
|
||||||
|
InboundQueueLimit: 256,
|
||||||
|
InboundBufferedBytesLimit: 32 * 1024 * 1024,
|
||||||
|
OutboundWindowBytes: 8 * 1024 * 1024,
|
||||||
|
OutboundMaxInFlightChunks: 32,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tuned_512KiB",
|
||||||
|
payloadSize: 512 * 1024,
|
||||||
|
cfg: StreamConfig{
|
||||||
|
ChunkSize: 512 * 1024,
|
||||||
|
InboundQueueLimit: 256,
|
||||||
|
InboundBufferedBytesLimit: 64 * 1024 * 1024,
|
||||||
|
OutboundWindowBytes: 16 * 1024 * 1024,
|
||||||
|
OutboundMaxInFlightChunks: 32,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tuned_1MiB",
|
||||||
|
payloadSize: 1024 * 1024,
|
||||||
|
cfg: StreamConfig{
|
||||||
|
ChunkSize: 1024 * 1024,
|
||||||
|
InboundQueueLimit: 256,
|
||||||
|
InboundBufferedBytesLimit: 64 * 1024 * 1024,
|
||||||
|
OutboundWindowBytes: 16 * 1024 * 1024,
|
||||||
|
OutboundMaxInFlightChunks: 32,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
|
benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg, benchmarkTransportSecurityTrustedRaw)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -108,19 +160,69 @@ func BenchmarkStreamTCPThroughputConcurrent(b *testing.B) {
|
|||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
b.Run(tc.name, func(b *testing.B) {
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg)
|
benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg, benchmarkTransportSecurityModernPSK)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfig) {
|
func BenchmarkStreamTCPThroughputConcurrentTrustedRaw(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
payloadSize int
|
||||||
|
concurrency int
|
||||||
|
cfg StreamConfig
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "streams_2_512KiB",
|
||||||
|
payloadSize: 512 * 1024,
|
||||||
|
concurrency: 2,
|
||||||
|
cfg: StreamConfig{
|
||||||
|
ChunkSize: 512 * 1024,
|
||||||
|
InboundQueueLimit: 512,
|
||||||
|
InboundBufferedBytesLimit: 128 * 1024 * 1024,
|
||||||
|
OutboundWindowBytes: 32 * 1024 * 1024,
|
||||||
|
OutboundMaxInFlightChunks: 64,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streams_4_512KiB",
|
||||||
|
payloadSize: 512 * 1024,
|
||||||
|
concurrency: 4,
|
||||||
|
cfg: StreamConfig{
|
||||||
|
ChunkSize: 512 * 1024,
|
||||||
|
InboundQueueLimit: 1024,
|
||||||
|
InboundBufferedBytesLimit: 256 * 1024 * 1024,
|
||||||
|
OutboundWindowBytes: 64 * 1024 * 1024,
|
||||||
|
OutboundMaxInFlightChunks: 128,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streams_8_512KiB",
|
||||||
|
payloadSize: 512 * 1024,
|
||||||
|
concurrency: 8,
|
||||||
|
cfg: StreamConfig{
|
||||||
|
ChunkSize: 512 * 1024,
|
||||||
|
InboundQueueLimit: 2048,
|
||||||
|
InboundBufferedBytesLimit: 512 * 1024 * 1024,
|
||||||
|
OutboundWindowBytes: 128 * 1024 * 1024,
|
||||||
|
OutboundMaxInFlightChunks: 256,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
|
benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg, benchmarkTransportSecurityTrustedRaw)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfig, securityMode benchmarkTransportSecurityMode) {
|
||||||
b.Helper()
|
b.Helper()
|
||||||
|
|
||||||
server := NewServer().(*ServerCommon)
|
server := NewServer().(*ServerCommon)
|
||||||
server.SetStreamConfig(cfg)
|
server.SetStreamConfig(cfg)
|
||||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
benchmarkApplyServerTransportSecurity(b, server, securityMode)
|
||||||
b.Fatalf("UseModernPSKServer failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
acceptCh := make(chan StreamAcceptInfo, 1)
|
acceptCh := make(chan StreamAcceptInfo, 1)
|
||||||
server.SetStreamHandler(func(info StreamAcceptInfo) error {
|
server.SetStreamHandler(func(info StreamAcceptInfo) error {
|
||||||
@ -137,9 +239,7 @@ func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfi
|
|||||||
|
|
||||||
client := NewClient().(*ClientCommon)
|
client := NewClient().(*ClientCommon)
|
||||||
client.SetStreamConfig(cfg)
|
client.SetStreamConfig(cfg)
|
||||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
benchmarkApplyClientTransportSecurity(b, client, securityMode)
|
||||||
b.Fatalf("UseModernPSKClient failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
|
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
|
||||||
b.Fatalf("client Connect failed: %v", err)
|
b.Fatalf("client Connect failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -198,7 +298,7 @@ func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfi
|
|||||||
_ = stream.Close()
|
_ = stream.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, cfg StreamConfig) {
|
func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, cfg StreamConfig, securityMode benchmarkTransportSecurityMode) {
|
||||||
b.Helper()
|
b.Helper()
|
||||||
if concurrency <= 0 {
|
if concurrency <= 0 {
|
||||||
b.Fatal("concurrency must be > 0")
|
b.Fatal("concurrency must be > 0")
|
||||||
@ -206,9 +306,7 @@ func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concu
|
|||||||
|
|
||||||
server := NewServer().(*ServerCommon)
|
server := NewServer().(*ServerCommon)
|
||||||
server.SetStreamConfig(cfg)
|
server.SetStreamConfig(cfg)
|
||||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
benchmarkApplyServerTransportSecurity(b, server, securityMode)
|
||||||
b.Fatalf("UseModernPSKServer failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
acceptCh := make(chan StreamAcceptInfo, concurrency*2)
|
acceptCh := make(chan StreamAcceptInfo, concurrency*2)
|
||||||
server.SetStreamHandler(func(info StreamAcceptInfo) error {
|
server.SetStreamHandler(func(info StreamAcceptInfo) error {
|
||||||
@ -225,9 +323,7 @@ func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concu
|
|||||||
|
|
||||||
client := NewClient().(*ClientCommon)
|
client := NewClient().(*ClientCommon)
|
||||||
client.SetStreamConfig(cfg)
|
client.SetStreamConfig(cfg)
|
||||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
benchmarkApplyClientTransportSecurity(b, client, securityMode)
|
||||||
b.Fatalf("UseModernPSKClient failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
|
if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil {
|
||||||
b.Fatalf("client Connect failed: %v", err)
|
b.Fatalf("client Connect failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
85
stream_buffer_release_test.go
Normal file
85
stream_buffer_release_test.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStreamOwnedChunkReleaseAfterRead(t *testing.T) {
|
||||||
|
stream := newStreamHandle(context.Background(), newStreamRuntime("stream-buffer-release-read"), clientFileScope(), StreamOpenRequest{
|
||||||
|
StreamID: "stream-buffer-release-read",
|
||||||
|
DataID: 1,
|
||||||
|
}, 0, nil, nil, 0, nil, nil, nil, streamConfig{})
|
||||||
|
|
||||||
|
released := 0
|
||||||
|
if err := stream.pushOwnedChunkWithRelease([]byte("hello"), func() {
|
||||||
|
released++
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("pushOwnedChunkWithRelease failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 5)
|
||||||
|
n, err := stream.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Read failed: %v", err)
|
||||||
|
}
|
||||||
|
if n != 5 || string(buf[:n]) != "hello" {
|
||||||
|
t.Fatalf("Read = %d %q, want 5 hello", n, string(buf[:n]))
|
||||||
|
}
|
||||||
|
if released != 1 {
|
||||||
|
t.Fatalf("release count = %d, want 1", released)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamOwnedChunkReleaseOnReset(t *testing.T) {
|
||||||
|
stream := newStreamHandle(context.Background(), newStreamRuntime("stream-buffer-release-reset"), clientFileScope(), StreamOpenRequest{
|
||||||
|
StreamID: "stream-buffer-release-reset",
|
||||||
|
DataID: 1,
|
||||||
|
}, 0, nil, nil, 0, nil, nil, nil, streamConfig{})
|
||||||
|
|
||||||
|
released := 0
|
||||||
|
if err := stream.pushOwnedChunkWithRelease([]byte("hello"), func() {
|
||||||
|
released++
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("pushOwnedChunkWithRelease failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.markReset(errors.New("boom"))
|
||||||
|
if released != 1 {
|
||||||
|
t.Fatalf("release count = %d, want 1", released)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientDispatchFastStreamDataWithOwnerReleasesAfterRead(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
runtime := client.getStreamRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
t.Fatal("client stream runtime should not be nil")
|
||||||
|
}
|
||||||
|
stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{
|
||||||
|
StreamID: "stream-owner",
|
||||||
|
DataID: 23,
|
||||||
|
Channel: StreamDataChannel,
|
||||||
|
}, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot())
|
||||||
|
if err := runtime.register(clientFileScope(), stream); err != nil {
|
||||||
|
t.Fatalf("register stream failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
released := 0
|
||||||
|
owner := newStreamReadPayloadOwner(func() {
|
||||||
|
released++
|
||||||
|
})
|
||||||
|
client.dispatchFastStreamDataWithOwner(streamFastDataFrame{
|
||||||
|
DataID: 23,
|
||||||
|
Seq: 1,
|
||||||
|
Payload: []byte("fast-owner"),
|
||||||
|
}, owner)
|
||||||
|
owner.done()
|
||||||
|
|
||||||
|
readStreamExactly(t, stream, "fast-owner", 2*time.Second)
|
||||||
|
if released != 1 {
|
||||||
|
t.Fatalf("release count = %d, want 1", released)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -83,6 +83,10 @@ func (s *ServerCommon) dispatchStreamEnvelope(logical *LogicalConn, transport *T
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) {
|
func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) {
|
||||||
|
c.dispatchFastStreamDataWithOwner(frame, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) dispatchFastStreamDataWithOwner(frame streamFastDataFrame, owner *streamReadPayloadOwner) {
|
||||||
if frame.DataID == 0 {
|
if frame.DataID == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -107,7 +111,13 @@ func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) {
|
|||||||
c.bestEffortRejectInboundStreamData(stream.ID(), frame.DataID, detachErr.Error())
|
c.bestEffortRejectInboundStreamData(stream.ID(), frame.DataID, detachErr.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := stream.pushOwnedChunk(frame.Payload); err != nil {
|
var err error
|
||||||
|
if owner != nil {
|
||||||
|
err = stream.pushOwnedChunkWithRelease(frame.Payload, owner.retainChunk())
|
||||||
|
} else {
|
||||||
|
err = stream.pushOwnedChunk(frame.Payload)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
if c.showError || c.debugMode {
|
if c.showError || c.debugMode {
|
||||||
fmt.Println("client stream push chunk error", err)
|
fmt.Println("client stream push chunk error", err)
|
||||||
}
|
}
|
||||||
@ -118,6 +128,10 @@ func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerCommon) dispatchFastStreamData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame streamFastDataFrame) {
|
func (s *ServerCommon) dispatchFastStreamData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame streamFastDataFrame) {
|
||||||
|
s.dispatchFastStreamDataWithOwner(logical, transport, conn, frame, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) dispatchFastStreamDataWithOwner(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame streamFastDataFrame, owner *streamReadPayloadOwner) {
|
||||||
if logical == nil || frame.DataID == 0 {
|
if logical == nil || frame.DataID == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -141,7 +155,13 @@ func (s *ServerCommon) dispatchFastStreamData(logical *LogicalConn, transport *T
|
|||||||
s.bestEffortRejectInboundStreamData(logical, transport, conn, stream.ID(), frame.DataID, detachErr.Error())
|
s.bestEffortRejectInboundStreamData(logical, transport, conn, stream.ID(), frame.DataID, detachErr.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := stream.pushOwnedChunk(frame.Payload); err != nil {
|
var err error
|
||||||
|
if owner != nil {
|
||||||
|
err = stream.pushOwnedChunkWithRelease(frame.Payload, owner.retainChunk())
|
||||||
|
} else {
|
||||||
|
err = stream.pushOwnedChunk(frame.Payload)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
if s.showError || s.debugMode {
|
if s.showError || s.debugMode {
|
||||||
fmt.Println("server stream push chunk error", err)
|
fmt.Println("server stream push chunk error", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -156,11 +156,12 @@ func decodeStreamFastDataFrame(payload []byte) (streamFastDataFrame, bool, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientCommon) encodeFastStreamPayload(frame streamFastDataFrame) ([]byte, error) {
|
func (c *ClientCommon) encodeFastStreamPayload(frame streamFastDataFrame) ([]byte, error) {
|
||||||
if c != nil && c.fastStreamEncode != nil && frame.Flags == 0 {
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
return c.fastStreamEncode(c.SecretKey, frame.DataID, frame.Seq, frame.Payload)
|
if c != nil && profile.fastStreamEncode != nil && frame.Flags == 0 {
|
||||||
|
return profile.fastStreamEncode(profile.secretKey, frame.DataID, frame.Seq, frame.Payload)
|
||||||
}
|
}
|
||||||
if c != nil && c.fastPlainEncode != nil {
|
if c != nil && profile.fastPlainEncode != nil {
|
||||||
return encodeStreamFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame)
|
return encodeStreamFastFramePayloadFast(profile.fastPlainEncode, profile.secretKey, frame)
|
||||||
}
|
}
|
||||||
plain, err := encodeStreamFastFramePayload(frame)
|
plain, err := encodeStreamFastFramePayload(frame)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -181,8 +182,9 @@ func (c *ClientCommon) encodeFastStreamBatchPayload(frames []streamFastDataFrame
|
|||||||
if c == nil {
|
if c == nil {
|
||||||
return nil, errStreamClientNil
|
return nil, errStreamClientNil
|
||||||
}
|
}
|
||||||
if c.fastPlainEncode != nil {
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
return encodeStreamFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames)
|
if profile.fastPlainEncode != nil {
|
||||||
|
return encodeStreamFastBatchPayloadFast(profile.fastPlainEncode, profile.secretKey, frames)
|
||||||
}
|
}
|
||||||
plain, err := encodeStreamFastBatchPlain(frames)
|
plain, err := encodeStreamFastBatchPlain(frames)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -91,25 +91,24 @@ func writeStreamFastBatchPlain(dst []byte, frames []streamFastDataFrame) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeStreamFastBatchPlain(payload []byte) ([]streamFastDataFrame, bool, error) {
|
func walkStreamFastBatchPlain(payload []byte, fn func(streamFastDataFrame) error) (bool, error) {
|
||||||
if len(payload) < 4 || string(payload[:4]) != streamFastBatchMagic {
|
if len(payload) < 4 || string(payload[:4]) != streamFastBatchMagic {
|
||||||
return nil, false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
if len(payload) < streamFastBatchHeaderLen {
|
if len(payload) < streamFastBatchHeaderLen {
|
||||||
return nil, true, errStreamFastPayloadInvalid
|
return true, errStreamFastPayloadInvalid
|
||||||
}
|
}
|
||||||
if payload[4] != streamFastBatchVersion {
|
if payload[4] != streamFastBatchVersion {
|
||||||
return nil, true, errStreamFastPayloadInvalid
|
return true, errStreamFastPayloadInvalid
|
||||||
}
|
}
|
||||||
count := int(binary.BigEndian.Uint32(payload[8:12]))
|
count := int(binary.BigEndian.Uint32(payload[8:12]))
|
||||||
if count <= 0 {
|
if count <= 0 {
|
||||||
return nil, true, errStreamFastPayloadInvalid
|
return true, errStreamFastPayloadInvalid
|
||||||
}
|
}
|
||||||
frames := make([]streamFastDataFrame, 0, count)
|
|
||||||
offset := streamFastBatchHeaderLen
|
offset := streamFastBatchHeaderLen
|
||||||
for index := 0; index < count; index++ {
|
for index := 0; index < count; index++ {
|
||||||
if len(payload)-offset < streamFastBatchItemHeaderLen {
|
if len(payload)-offset < streamFastBatchItemHeaderLen {
|
||||||
return nil, true, errStreamFastPayloadInvalid
|
return true, errStreamFastPayloadInvalid
|
||||||
}
|
}
|
||||||
flags := payload[offset]
|
flags := payload[offset]
|
||||||
dataID := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
|
dataID := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
|
||||||
@ -117,29 +116,62 @@ func decodeStreamFastBatchPlain(payload []byte) ([]streamFastDataFrame, bool, er
|
|||||||
payloadLen := int(binary.BigEndian.Uint32(payload[offset+20 : offset+24]))
|
payloadLen := int(binary.BigEndian.Uint32(payload[offset+20 : offset+24]))
|
||||||
offset += streamFastBatchItemHeaderLen
|
offset += streamFastBatchItemHeaderLen
|
||||||
if dataID == 0 || payloadLen < 0 || len(payload)-offset < payloadLen {
|
if dataID == 0 || payloadLen < 0 || len(payload)-offset < payloadLen {
|
||||||
return nil, true, errStreamFastPayloadInvalid
|
return true, errStreamFastPayloadInvalid
|
||||||
}
|
}
|
||||||
frames = append(frames, streamFastDataFrame{
|
if fn != nil {
|
||||||
|
if err := fn(streamFastDataFrame{
|
||||||
Flags: flags,
|
Flags: flags,
|
||||||
DataID: dataID,
|
DataID: dataID,
|
||||||
Seq: seq,
|
Seq: seq,
|
||||||
Payload: payload[offset : offset+payloadLen],
|
Payload: payload[offset : offset+payloadLen],
|
||||||
})
|
}); err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
}
|
||||||
offset += payloadLen
|
offset += payloadLen
|
||||||
}
|
}
|
||||||
if offset != len(payload) {
|
if offset != len(payload) {
|
||||||
return nil, true, errStreamFastPayloadInvalid
|
return true, errStreamFastPayloadInvalid
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeStreamFastBatchPlain(payload []byte) ([]streamFastDataFrame, bool, error) {
|
||||||
|
frames := make([]streamFastDataFrame, 0, 1)
|
||||||
|
matched, err := walkStreamFastBatchPlain(payload, func(frame streamFastDataFrame) error {
|
||||||
|
frames = append(frames, frame)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if !matched || err != nil {
|
||||||
|
return nil, matched, err
|
||||||
}
|
}
|
||||||
return frames, true, nil
|
return frames, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeStreamFastDataFrames(payload []byte) ([]streamFastDataFrame, bool, error) {
|
func walkStreamFastFrames(payload []byte, fn func(streamFastDataFrame) error) (bool, error) {
|
||||||
if frames, matched, err := decodeStreamFastBatchPlain(payload); matched {
|
if matched, err := walkStreamFastBatchPlain(payload, fn); matched {
|
||||||
return frames, true, err
|
return true, err
|
||||||
}
|
}
|
||||||
frame, matched, err := decodeStreamFastDataFrame(payload)
|
frame, matched, err := decodeStreamFastDataFrame(payload)
|
||||||
|
if !matched || err != nil {
|
||||||
|
return matched, err
|
||||||
|
}
|
||||||
|
if fn != nil {
|
||||||
|
if err := fn(frame); err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeStreamFastDataFrames(payload []byte) ([]streamFastDataFrame, bool, error) {
|
||||||
|
frames := make([]streamFastDataFrame, 0, 1)
|
||||||
|
matched, err := walkStreamFastFrames(payload, func(frame streamFastDataFrame) error {
|
||||||
|
frames = append(frames, frame)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
if !matched || err != nil {
|
if !matched || err != nil {
|
||||||
return nil, matched, err
|
return nil, matched, err
|
||||||
}
|
}
|
||||||
return []streamFastDataFrame{frame}, true, nil
|
return frames, true, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,7 +9,10 @@ var (
|
|||||||
errTransportPayloadDecryptFailed = errors.New("transport payload decrypt failed")
|
errTransportPayloadDecryptFailed = errors.New("transport payload decrypt failed")
|
||||||
)
|
)
|
||||||
|
|
||||||
func encryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgEn func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) {
|
func encryptTransportPayloadCodec(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgEn func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) {
|
||||||
|
if mode == ProtectionExternal {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
if runtime != nil {
|
if runtime != nil {
|
||||||
encoded, err := runtime.sealPlainPayload(data)
|
encoded, err := runtime.sealPlainPayload(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -27,7 +30,10 @@ func encryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgEn func([]b
|
|||||||
return encoded, nil
|
return encoded, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) {
|
func decryptTransportPayloadCodec(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) {
|
||||||
|
if mode == ProtectionExternal {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
if runtime != nil {
|
if runtime != nil {
|
||||||
plain, err := runtime.openPayload(data)
|
plain, err := runtime.openPayload(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -45,7 +51,10 @@ func decryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgDe func([]b
|
|||||||
return plain, nil
|
return plain, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decryptTransportPayloadCodecPooled(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte, release func()) ([]byte, func(), error) {
|
func decryptTransportPayloadCodecPooled(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte, release func()) ([]byte, func(), error) {
|
||||||
|
if mode == ProtectionExternal {
|
||||||
|
return data, release, nil
|
||||||
|
}
|
||||||
if runtime != nil {
|
if runtime != nil {
|
||||||
plain, plainRelease, err := runtime.openPayloadPooled(data, release)
|
plain, plainRelease, err := runtime.openPayloadPooled(data, release)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -69,7 +78,10 @@ func decryptTransportPayloadCodecPooled(runtime *modernPSKCodecRuntime, msgDe fu
|
|||||||
return plain, nil, nil
|
return plain, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decryptTransportPayloadCodecOwnedPooled(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, func(), error) {
|
func decryptTransportPayloadCodecOwnedPooled(mode ProtectionMode, runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, func(), error) {
|
||||||
|
if mode == ProtectionExternal {
|
||||||
|
return data, nil, nil
|
||||||
|
}
|
||||||
if runtime != nil {
|
if runtime != nil {
|
||||||
plain, plainRelease, err := runtime.openPayloadOwnedPooled(data)
|
plain, plainRelease, err := runtime.openPayloadOwnedPooled(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -124,9 +136,14 @@ func (s *ServerCommon) encodeTransferMsg(c *ClientConn, msg TransferMsg) ([]byte
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
logical := logicalConnFromClient(c)
|
||||||
msgEn := c.clientConnMsgEnSnapshot()
|
msgEn := c.clientConnMsgEnSnapshot()
|
||||||
secretKey := c.clientConnSecretKeySnapshot()
|
secretKey := c.clientConnSecretKeySnapshot()
|
||||||
data, err = encryptTransportPayloadCodec(c.clientConnModernPSKRuntimeSnapshot(), msgEn, secretKey, data)
|
mode := ProtectionManaged
|
||||||
|
if logical != nil {
|
||||||
|
mode = logical.protectionModeSnapshot()
|
||||||
|
}
|
||||||
|
data, err = encryptTransportPayloadCodec(mode, c.clientConnModernPSKRuntimeSnapshot(), msgEn, secretKey, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -140,7 +157,11 @@ func (s *ServerCommon) encodeTransferMsg(c *ClientConn, msg TransferMsg) ([]byte
|
|||||||
func (s *ServerCommon) decodeTransferMsg(c *ClientConn, data []byte) (TransferMsg, error) {
|
func (s *ServerCommon) decodeTransferMsg(c *ClientConn, data []byte) (TransferMsg, error) {
|
||||||
msgDe := c.clientConnMsgDeSnapshot()
|
msgDe := c.clientConnMsgDeSnapshot()
|
||||||
secretKey := c.clientConnSecretKeySnapshot()
|
secretKey := c.clientConnSecretKeySnapshot()
|
||||||
plain, err := decryptTransportPayloadCodec(c.clientConnModernPSKRuntimeSnapshot(), msgDe, secretKey, data)
|
mode := ProtectionManaged
|
||||||
|
if logical := logicalConnFromClient(c); logical != nil {
|
||||||
|
mode = logical.protectionModeSnapshot()
|
||||||
|
}
|
||||||
|
plain, err := decryptTransportPayloadCodec(mode, c.clientConnModernPSKRuntimeSnapshot(), msgDe, secretKey, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return TransferMsg{}, err
|
return TransferMsg{}, err
|
||||||
}
|
}
|
||||||
@ -172,7 +193,8 @@ func (c *ClientCommon) encodeEnvelopePlain(env Envelope) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientCommon) encryptTransportPayload(data []byte) ([]byte, error) {
|
func (c *ClientCommon) encryptTransportPayload(data []byte) ([]byte, error) {
|
||||||
return encryptTransportPayloadCodec(c.modernPSKRuntime, c.msgEn, c.SecretKey, data)
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
|
return encryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgEn, profile.secretKey, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientCommon) encodeEnvelope(env Envelope) ([]byte, error) {
|
func (c *ClientCommon) encodeEnvelope(env Envelope) ([]byte, error) {
|
||||||
@ -196,7 +218,8 @@ func (c *ClientCommon) decodeEnvelope(data []byte) (Envelope, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientCommon) decryptTransportPayload(data []byte) ([]byte, error) {
|
func (c *ClientCommon) decryptTransportPayload(data []byte) ([]byte, error) {
|
||||||
return decryptTransportPayloadCodec(c.modernPSKRuntime, c.msgDe, c.SecretKey, data)
|
profile := c.clientTransportProtectionSnapshot()
|
||||||
|
return decryptTransportPayloadCodec(profile.mode, profile.runtime, profile.msgDe, profile.secretKey, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientCommon) decodeEnvelopePlain(data []byte) (Envelope, error) {
|
func (c *ClientCommon) decodeEnvelopePlain(data []byte) (Envelope, error) {
|
||||||
@ -251,7 +274,7 @@ func (s *ServerCommon) encryptTransportPayloadLogical(logical *LogicalConn, data
|
|||||||
if msgEn == nil {
|
if msgEn == nil {
|
||||||
return nil, errTransportDetached
|
return nil, errTransportDetached
|
||||||
}
|
}
|
||||||
return encryptTransportPayloadCodec(logical.modernPSKRuntimeSnapshot(), msgEn, secretKey, data)
|
return encryptTransportPayloadCodec(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), msgEn, secretKey, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerCommon) encodeEnvelopeLogical(logical *LogicalConn, env Envelope) ([]byte, error) {
|
func (s *ServerCommon) encodeEnvelopeLogical(logical *LogicalConn, env Envelope) ([]byte, error) {
|
||||||
@ -290,7 +313,7 @@ func (s *ServerCommon) decryptTransportPayloadLogical(logical *LogicalConn, data
|
|||||||
if msgDe == nil {
|
if msgDe == nil {
|
||||||
return nil, errTransportDetached
|
return nil, errTransportDetached
|
||||||
}
|
}
|
||||||
return decryptTransportPayloadCodec(logical.modernPSKRuntimeSnapshot(), msgDe, secretKey, data)
|
return decryptTransportPayloadCodec(logical.protectionModeSnapshot(), logical.modernPSKRuntimeSnapshot(), msgDe, secretKey, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerCommon) decodeEnvelopePlain(data []byte) (Envelope, error) {
|
func (s *ServerCommon) decodeEnvelopePlain(data []byte) (Envelope, error) {
|
||||||
|
|||||||
@ -143,6 +143,43 @@ func TestStreamBatchSenderRespectsBindingWriteDeadlineWhenReceiverStalls(t *test
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBulkBatchSenderFlushAggregatesAdaptivePayloadObservation(t *testing.T) {
|
||||||
|
conn := &delayedWriteConn{delay: 20 * time.Millisecond}
|
||||||
|
binding := newTransportBinding(conn, stario.NewQueue())
|
||||||
|
sender := newTestBulkBatchSender(binding)
|
||||||
|
payloadA := bytes.Repeat([]byte("a"), 128*1024)
|
||||||
|
payloadB := bytes.Repeat([]byte("b"), 128*1024)
|
||||||
|
|
||||||
|
err := sender.flush([]bulkBatchRequest{
|
||||||
|
{
|
||||||
|
ctx: context.Background(),
|
||||||
|
frames: []bulkFastFrame{{
|
||||||
|
Type: bulkFastPayloadTypeData,
|
||||||
|
DataID: 1,
|
||||||
|
Seq: 1,
|
||||||
|
Payload: payloadA,
|
||||||
|
}},
|
||||||
|
fastPathVersion: bulkFastPathVersionV1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ctx: context.Background(),
|
||||||
|
frames: []bulkFastFrame{{
|
||||||
|
Type: bulkFastPayloadTypeData,
|
||||||
|
DataID: 2,
|
||||||
|
Seq: 1,
|
||||||
|
Payload: payloadB,
|
||||||
|
}},
|
||||||
|
fastPathVersion: bulkFastPathVersionV1,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("flush failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := binding.bulkAdaptiveSoftPayloadBytesSnapshot(), bulkAdaptiveSoftPayloadMinBytes; got != want {
|
||||||
|
t.Fatalf("adaptive bulk soft payload = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestControlBatchSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) {
|
func TestControlBatchSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) {
|
||||||
left, right := net.Pipe()
|
left, right := net.Pipe()
|
||||||
defer left.Close()
|
defer left.Close()
|
||||||
@ -255,6 +292,10 @@ func (c *vectoredShortWriteConn) WriteBuffers(bufs *net.Buffers) (int64, error)
|
|||||||
return written, nil
|
return written, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *vectoredShortWriteConn) writeBuffers(bufs *net.Buffers) (int64, error) {
|
||||||
|
return c.WriteBuffers(bufs)
|
||||||
|
}
|
||||||
|
|
||||||
type unwrapVectoredConn struct {
|
type unwrapVectoredConn struct {
|
||||||
inner net.Conn
|
inner net.Conn
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user