notify/security_psk_test.go
starainrt 09d972c7b7
feat(notify): 重构通信内核并补齐 stream/bulk/record/transfer 能力
- 引入 LogicalConn/TransportConn 分层,ClientConn 保留兼容适配层
  - 新增 Stream、Bulk、RecordStream 三条数据面能力及对应控制路径
  - 完成 transfer/file 传输内核与状态快照、诊断能力
  - 补齐 reconnect、inbound dispatcher、modern psk 等基础模块
  - 增加大规模回归、并发与基准测试覆盖
  - 更新依赖库
2026-04-15 15:24:36 +08:00

327 lines
9.9 KiB
Go

package notify
import (
"bytes"
"errors"
"reflect"
"testing"
"b612.me/starcrypto"
)
// testModernPSKOptions returns a fresh ModernPSKOptions for tests. It uses a
// fixed salt, a fixed AAD and small Argon2 cost parameters (presumably chosen
// to keep key derivation cheap in tests). KeyLen is 32, which is the
// derived-key length the round-trip tests assert.
func testModernPSKOptions() *ModernPSKOptions {
	params := starcrypto.Argon2Params{
		Time:    1,
		Memory:  8,
		Threads: 1,
		KeyLen:  32,
	}
	opts := &ModernPSKOptions{
		Salt:         []byte("notify-modern-psk-test-salt"),
		AAD:          []byte("notify-modern-psk-test-aad"),
		Argon2Params: params,
	}
	return opts
}
// TestUseModernPSKRoundTrip configures a client and a server with the same
// secret and options through UseModernPSKClient / UseModernPSKServer. It then
// checks that:
//   - the client skips the legacy key exchange,
//   - both sides hold the same 32-byte derived key, and
//   - a payload survives encoding and decoding in each direction.
func TestUseModernPSKRoundTrip(t *testing.T) {
	client, server := NewClient(), NewServer()
	secret := []byte("correct horse battery staple")
	opts := testModernPSKOptions()

	if err := UseModernPSKClient(client, secret, opts); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := UseModernPSKServer(server, secret, opts); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}

	// NewClient/NewServer return interface values; the checks below need the
	// concrete types to reach the key and codec fields.
	cli := client.(*ClientCommon)
	srv := server.(*ServerCommon)

	switch {
	case !cli.SkipExchangeKey():
		t.Fatal("client should skip legacy key exchange after UseModernPSKClient")
	case len(cli.SecretKey) != 32:
		t.Fatalf("client derived key length = %d, want 32", len(cli.SecretKey))
	case !bytes.Equal(cli.SecretKey, srv.SecretKey):
		t.Fatal("derived transport keys do not match")
	}

	plain := []byte("notify modern psk transport")

	// Client -> server.
	toServer := cli.msgEn(cli.SecretKey, plain)
	if got := srv.defaultMsgDe(srv.SecretKey, toServer); !bytes.Equal(got, plain) {
		t.Fatalf("server decode mismatch: got %q want %q", got, plain)
	}

	// Server -> client.
	toClient := srv.defaultMsgEn(srv.SecretKey, plain)
	if got := cli.msgDe(cli.SecretKey, toClient); !bytes.Equal(got, plain) {
		t.Fatalf("client decode mismatch: got %q want %q", got, plain)
	}
}
// TestNewClientConnectRequiresModernPSK verifies that Connect on a freshly
// constructed client, with no modern PSK configured, returns an error
// matching errModernPSKRequired.
func TestNewClientConnectRequiresModernPSK(t *testing.T) {
	const network, addr = "tcp", "127.0.0.1:1"
	if err := NewClient().Connect(network, addr); !errors.Is(err, errModernPSKRequired) {
		t.Fatalf("Connect error = %v, want %v", err, errModernPSKRequired)
	}
}
// TestNewServerListenRequiresModernPSK verifies that Listen on a freshly
// constructed server, with no modern PSK configured, returns an error
// matching errModernPSKRequired.
func TestNewServerListenRequiresModernPSK(t *testing.T) {
	const network, addr = "tcp", "127.0.0.1:1"
	if err := NewServer().Listen(network, addr); !errors.Is(err, errModernPSKRequired) {
		t.Fatalf("Listen error = %v, want %v", err, errModernPSKRequired)
	}
}
// TestDefaultConstructorsUseModernTransportAfterSetSecretKey checks that a
// client and a server built by the default constructors can exchange a
// payload in both directions once both share a key set via SetSecretKey.
func TestDefaultConstructorsUseModernTransportAfterSetSecretKey(t *testing.T) {
	client := NewClient().(*ClientCommon)
	server := NewServer().(*ServerCommon)
	sharedKey := []byte("0123456789abcdef0123456789abcdef")
	client.SetSecretKey(sharedKey)
	server.SetSecretKey(sharedKey)

	plain := []byte("notify default modern transport")

	// Client -> server.
	wire := client.msgEn(client.SecretKey, plain)
	if got := server.defaultMsgDe(server.SecretKey, wire); !bytes.Equal(got, plain) {
		t.Fatalf("server decode mismatch: got %q want %q", got, plain)
	}

	// Server -> client. This mirrors the other round-trip tests so that a
	// codec mismatch in only one direction is caught as well.
	replyWire := server.defaultMsgEn(server.SecretKey, plain)
	if reply := client.msgDe(client.SecretKey, replyWire); !bytes.Equal(reply, plain) {
		t.Fatalf("client decode mismatch: got %q want %q", reply, plain)
	}
}
// TestDefaultConstructorsDecodeSignalEnvelopeWithModernTransport encodes a
// TransferMsg on the client side (sequenceEn followed by msgEn). It then checks
// that the server's decodeEnvelope classifies it as EnvelopeSignal and that
// unwrapping it yields a message deeply equal to the original.
func TestDefaultConstructorsDecodeSignalEnvelopeWithModernTransport(t *testing.T) {
	cli := NewClient().(*ClientCommon)
	srv := NewServer().(*ServerCommon)
	key := []byte("0123456789abcdef0123456789abcdef")
	cli.SetSecretKey(key)
	srv.SetSecretKey(key)

	original := TransferMsg{
		ID:    42,
		Key:   "modern-signal",
		Value: MsgVal("payload"),
		Type:  MSG_ASYNC,
	}

	encoded, err := cli.sequenceEn(original)
	if err != nil {
		t.Fatalf("sequenceEn failed: %v", err)
	}
	sealed := cli.msgEn(cli.SecretKey, encoded)

	decoded, err := srv.decodeEnvelope(newServerCodecClientConnForTest(srv), sealed)
	if err != nil {
		t.Fatalf("decodeEnvelope failed: %v", err)
	}
	if decoded.Kind != EnvelopeSignal {
		t.Fatalf("envelope kind = %v, want %v", decoded.Kind, EnvelopeSignal)
	}

	roundTripped, err := unwrapTransferMsgEnvelope(decoded, srv.sequenceDe)
	if err != nil {
		t.Fatalf("unwrapTransferMsgEnvelope failed: %v", err)
	}
	if !reflect.DeepEqual(roundTripped, original) {
		t.Fatalf("signal mismatch: got %#v want %#v", roundTripped, original)
	}
}
// TestDefaultConstructorsDecodeFileEnvelopesWithModernTransport round-trips
// file-meta, file-chunk and file-ack envelopes. Each is sent from a default
// client (sequenceEn followed by msgEn) through the server's decodeEnvelope and
// must come back deeply equal to what was sent.
func TestDefaultConstructorsDecodeFileEnvelopesWithModernTransport(t *testing.T) {
	cli := NewClient().(*ClientCommon)
	srv := NewServer().(*ServerCommon)
	key := []byte("0123456789abcdef0123456789abcdef")
	cli.SetSecretKey(key)
	srv.SetSecretKey(key)

	cases := []struct {
		label    string
		envelope Envelope
	}{
		{label: "file-meta", envelope: newFileMetaEnvelope("file-1", "demo.txt", 64, "checksum", 0644, 123456789)},
		{label: "file-chunk", envelope: newFileChunkEnvelope("file-1", 32, []byte("chunk-data"))},
		{label: "file-ack", envelope: newFileAckEnvelope("file-1", "chunk", 32, "")},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			encoded, err := cli.sequenceEn(tc.envelope)
			if err != nil {
				t.Fatalf("sequenceEn failed: %v", err)
			}
			sealed := cli.msgEn(cli.SecretKey, encoded)

			decoded, err := srv.decodeEnvelope(newServerCodecClientConnForTest(srv), sealed)
			if err != nil {
				t.Fatalf("decodeEnvelope failed: %v", err)
			}
			if !reflect.DeepEqual(decoded, tc.envelope) {
				t.Fatalf("envelope mismatch: got %#v want %#v", decoded, tc.envelope)
			}
		})
	}
}
// TestUseModernPSKRejectsEmptySecret checks that both the client-side and the
// server-side modern PSK setup return an error when given a nil secret.
func TestUseModernPSKRejectsEmptySecret(t *testing.T) {
	clientErr := UseModernPSKClient(NewClient(), nil, testModernPSKOptions())
	if clientErr == nil {
		t.Fatal("UseModernPSKClient should reject empty secret")
	}

	serverErr := UseModernPSKServer(NewServer(), nil, testModernPSKOptions())
	if serverErr == nil {
		t.Fatal("UseModernPSKServer should reject empty secret")
	}
}
// TestModernPSKCodecRejectsLegacyPayload encodes a payload with the legacy
// defaultMsgEn codec. It then checks that the modern decoder, built with the
// same key, returns nil for it instead of a plaintext.
func TestModernPSKCodecRejectsLegacyPayload(t *testing.T) {
	key, aad, err := deriveModernPSKKey([]byte("notify-legacy-reject"), testModernPSKOptions())
	if err != nil {
		t.Fatalf("deriveModernPSKKey failed: %v", err)
	}

	_, decode := buildModernPSKCodecs(aad)
	legacy := defaultMsgEn(key, []byte("legacy payload"))

	got := decode(key, legacy)
	if got != nil {
		t.Fatalf("modern decoder should reject legacy payload, got %q", got)
	}
}
// TestModernPSKCodecUsesUniqueNoncePerMessage encrypts the same plaintext
// twice with the same key. It checks that:
//   - the two ciphertexts differ,
//   - both start with modernPSKMagic,
//   - the nonce bytes that follow the magic differ, and
//   - both ciphertexts decrypt back to the original plaintext.
func TestModernPSKCodecUsesUniqueNoncePerMessage(t *testing.T) {
	key, aad, err := deriveModernPSKKey([]byte("notify-unique-nonce"), testModernPSKOptions())
	if err != nil {
		t.Fatalf("deriveModernPSKKey failed: %v", err)
	}
	msgEn, msgDe := buildModernPSKCodecs(aad)
	first := msgEn(key, []byte("payload"))
	second := msgEn(key, []byte("payload"))
	if first == nil || second == nil {
		t.Fatal("modern msgEn should produce payload")
	}

	// Guard the header slicing below: a truncated payload must fail the test
	// with a diagnostic, not panic with an index-out-of-range error.
	headerLen := len(modernPSKMagic) + modernPSKNonceSize
	if len(first) < headerLen || len(second) < headerLen {
		t.Fatalf("modern payload too short: len(first)=%d len(second)=%d, want >= %d",
			len(first), len(second), headerLen)
	}

	if bytes.Equal(first, second) {
		t.Fatal("two modern payloads should not be byte-identical")
	}
	if !bytes.Equal(first[:len(modernPSKMagic)], modernPSKMagic) {
		t.Fatalf("first payload magic = %q, want %q", first[:len(modernPSKMagic)], modernPSKMagic)
	}
	if !bytes.Equal(second[:len(modernPSKMagic)], modernPSKMagic) {
		t.Fatalf("second payload magic = %q, want %q", second[:len(modernPSKMagic)], modernPSKMagic)
	}

	// The nonce immediately follows the magic prefix.
	firstNonce := first[len(modernPSKMagic):headerLen]
	secondNonce := second[len(modernPSKMagic):headerLen]
	if bytes.Equal(firstNonce, secondNonce) {
		t.Fatal("modern payload nonces should differ between messages")
	}

	if got := msgDe(key, first); !bytes.Equal(got, []byte("payload")) {
		t.Fatalf("first decode = %q, want %q", got, "payload")
	}
	if got := msgDe(key, second); !bytes.Equal(got, []byte("payload")) {
		t.Fatalf("second decode = %q, want %q", got, "payload")
	}
}
// TestModernPSKFastStreamEncodeRoundTrip encodes a stream data frame with the
// modern transport's fastStreamEncode and decrypts it with msgDe. It then checks
// that decodeStreamFastDataFrame recognizes the plaintext and recovers the
// original data ID, sequence number and payload.
func TestModernPSKFastStreamEncodeRoundTrip(t *testing.T) {
	const wantDataID, wantSeq = 23, 7

	key, aad, err := deriveModernPSKKey([]byte("notify-fast-stream"), testModernPSKOptions())
	if err != nil {
		t.Fatalf("deriveModernPSKKey failed: %v", err)
	}
	bundle := buildModernPSKTransportBundle(aad)

	sealed, err := bundle.fastStreamEncode(key, wantDataID, wantSeq, []byte("payload"))
	if err != nil {
		t.Fatalf("fastStreamEncode failed: %v", err)
	}
	opened := bundle.msgDe(key, sealed)
	if opened == nil {
		t.Fatal("msgDe returned nil")
	}

	frame, ok, err := decodeStreamFastDataFrame(opened)
	if err != nil {
		t.Fatalf("decodeStreamFastDataFrame failed: %v", err)
	}
	if !ok {
		t.Fatal("decodeStreamFastDataFrame should match fast payload")
	}

	switch {
	case frame.DataID != wantDataID:
		t.Fatalf("data id = %d, want %d", frame.DataID, wantDataID)
	case frame.Seq != wantSeq:
		t.Fatalf("seq = %d, want %d", frame.Seq, wantSeq)
	case !bytes.Equal(frame.Payload, []byte("payload")):
		t.Fatalf("payload = %q, want %q", frame.Payload, "payload")
	}
}
// TestModernPSKFastBulkEncodeRoundTrip encodes a bulk data frame with the
// modern transport's fastBulkEncode and decrypts it with msgDe. It then checks
// that decodeBulkFastDataFrame recognizes the plaintext and recovers the
// original data ID, sequence number and payload.
func TestModernPSKFastBulkEncodeRoundTrip(t *testing.T) {
	const wantDataID, wantSeq = 41, 9

	key, aad, err := deriveModernPSKKey([]byte("notify-fast-bulk"), testModernPSKOptions())
	if err != nil {
		t.Fatalf("deriveModernPSKKey failed: %v", err)
	}
	bundle := buildModernPSKTransportBundle(aad)

	sealed, err := bundle.fastBulkEncode(key, wantDataID, wantSeq, []byte("payload"))
	if err != nil {
		t.Fatalf("fastBulkEncode failed: %v", err)
	}
	opened := bundle.msgDe(key, sealed)
	if opened == nil {
		t.Fatal("msgDe returned nil")
	}

	frame, ok, err := decodeBulkFastDataFrame(opened)
	if err != nil {
		t.Fatalf("decodeBulkFastDataFrame failed: %v", err)
	}
	if !ok {
		t.Fatal("decodeBulkFastDataFrame should match fast payload")
	}

	switch {
	case frame.DataID != wantDataID:
		t.Fatalf("data id = %d, want %d", frame.DataID, wantDataID)
	case frame.Seq != wantSeq:
		t.Fatalf("seq = %d, want %d", frame.Seq, wantSeq)
	case !bytes.Equal(frame.Payload, []byte("payload")):
		t.Fatalf("payload = %q, want %q", frame.Payload, "payload")
	}
}
// TestUseLegacySecurityRoundTrip switches a client and a server to the legacy
// security setup. It checks that:
//   - the client keeps the legacy key exchange enabled,
//   - both sides hold the default AES key,
//   - the client has the default RSA public key and the server the default
//     RSA private key, and
//   - a payload round-trips in both directions.
func TestUseLegacySecurityRoundTrip(t *testing.T) {
	client, server := NewClient(), NewServer()
	UseLegacySecurityClient(client)
	UseLegacySecurityServer(server)

	cli := client.(*ClientCommon)
	srv := server.(*ServerCommon)

	switch {
	case cli.SkipExchangeKey():
		t.Fatal("legacy client should keep legacy exchange enabled")
	case !bytes.Equal(cli.SecretKey, defaultAesKey):
		t.Fatal("legacy client should restore the default AES key")
	case !bytes.Equal(srv.SecretKey, defaultAesKey):
		t.Fatal("legacy server should restore the default AES key")
	case !bytes.Equal(cli.RsaPubKey(), defaultRsaPubKey):
		t.Fatal("legacy client should restore the default RSA public key")
	case !bytes.Equal(srv.RsaPrivKey(), defaultRsaKey):
		t.Fatal("legacy server should restore the default RSA private key")
	}

	plain := []byte("notify legacy transport")

	// Client -> server.
	toServer := cli.msgEn(cli.SecretKey, plain)
	if got := srv.defaultMsgDe(srv.SecretKey, toServer); !bytes.Equal(got, plain) {
		t.Fatalf("legacy server decode mismatch: got %q want %q", got, plain)
	}

	// Server -> client.
	toClient := srv.defaultMsgEn(srv.SecretKey, plain)
	if got := cli.msgDe(cli.SecretKey, toClient); !bytes.Equal(got, plain) {
		t.Fatalf("legacy client decode mismatch: got %q want %q", got, plain)
	}
}