- 引入 LogicalConn/TransportConn 分层,ClientConn 保留兼容适配层 - 新增 Stream、Bulk、RecordStream 三条数据面能力及对应控制路径 - 完成 transfer/file 传输内核与状态快照、诊断能力 - 补齐 reconnect、inbound dispatcher、modern psk 等基础模块 - 增加大规模回归、并发与基准测试覆盖 - 更新依赖库
327 lines
9.9 KiB
Go
327 lines
9.9 KiB
Go
package notify
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"b612.me/starcrypto"
|
|
)
|
|
|
|
func testModernPSKOptions() *ModernPSKOptions {
|
|
return &ModernPSKOptions{
|
|
Salt: []byte("notify-modern-psk-test-salt"),
|
|
AAD: []byte("notify-modern-psk-test-aad"),
|
|
Argon2Params: starcrypto.Argon2Params{
|
|
Time: 1,
|
|
Memory: 8,
|
|
Threads: 1,
|
|
KeyLen: 32,
|
|
},
|
|
}
|
|
}
|
|
|
|
func TestUseModernPSKRoundTrip(t *testing.T) {
|
|
client := NewClient()
|
|
server := NewServer()
|
|
secret := []byte("correct horse battery staple")
|
|
opts := testModernPSKOptions()
|
|
|
|
if err := UseModernPSKClient(client, secret, opts); err != nil {
|
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
|
}
|
|
if err := UseModernPSKServer(server, secret, opts); err != nil {
|
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
|
}
|
|
|
|
cc := client.(*ClientCommon)
|
|
ss := server.(*ServerCommon)
|
|
if !cc.SkipExchangeKey() {
|
|
t.Fatal("client should skip legacy key exchange after UseModernPSKClient")
|
|
}
|
|
if len(cc.SecretKey) != 32 {
|
|
t.Fatalf("client derived key length = %d, want 32", len(cc.SecretKey))
|
|
}
|
|
if !bytes.Equal(cc.SecretKey, ss.SecretKey) {
|
|
t.Fatal("derived transport keys do not match")
|
|
}
|
|
|
|
plain := []byte("notify modern psk transport")
|
|
wire := cc.msgEn(cc.SecretKey, plain)
|
|
got := ss.defaultMsgDe(ss.SecretKey, wire)
|
|
if !bytes.Equal(got, plain) {
|
|
t.Fatalf("server decode mismatch: got %q want %q", got, plain)
|
|
}
|
|
|
|
replyWire := ss.defaultMsgEn(ss.SecretKey, plain)
|
|
reply := cc.msgDe(cc.SecretKey, replyWire)
|
|
if !bytes.Equal(reply, plain) {
|
|
t.Fatalf("client decode mismatch: got %q want %q", reply, plain)
|
|
}
|
|
}
|
|
|
|
func TestNewClientConnectRequiresModernPSK(t *testing.T) {
|
|
client := NewClient()
|
|
err := client.Connect("tcp", "127.0.0.1:1")
|
|
if !errors.Is(err, errModernPSKRequired) {
|
|
t.Fatalf("Connect error = %v, want %v", err, errModernPSKRequired)
|
|
}
|
|
}
|
|
|
|
func TestNewServerListenRequiresModernPSK(t *testing.T) {
|
|
server := NewServer()
|
|
err := server.Listen("tcp", "127.0.0.1:1")
|
|
if !errors.Is(err, errModernPSKRequired) {
|
|
t.Fatalf("Listen error = %v, want %v", err, errModernPSKRequired)
|
|
}
|
|
}
|
|
|
|
func TestDefaultConstructorsUseModernTransportAfterSetSecretKey(t *testing.T) {
|
|
client := NewClient().(*ClientCommon)
|
|
server := NewServer().(*ServerCommon)
|
|
sharedKey := []byte("0123456789abcdef0123456789abcdef")
|
|
client.SetSecretKey(sharedKey)
|
|
server.SetSecretKey(sharedKey)
|
|
|
|
plain := []byte("notify default modern transport")
|
|
wire := client.msgEn(client.SecretKey, plain)
|
|
got := server.defaultMsgDe(server.SecretKey, wire)
|
|
if !bytes.Equal(got, plain) {
|
|
t.Fatalf("server decode mismatch: got %q want %q", got, plain)
|
|
}
|
|
}
|
|
|
|
func TestDefaultConstructorsDecodeSignalEnvelopeWithModernTransport(t *testing.T) {
|
|
client := NewClient().(*ClientCommon)
|
|
server := NewServer().(*ServerCommon)
|
|
sharedKey := []byte("0123456789abcdef0123456789abcdef")
|
|
client.SetSecretKey(sharedKey)
|
|
server.SetSecretKey(sharedKey)
|
|
|
|
want := TransferMsg{
|
|
ID: 42,
|
|
Key: "modern-signal",
|
|
Value: MsgVal("payload"),
|
|
Type: MSG_ASYNC,
|
|
}
|
|
body, err := client.sequenceEn(want)
|
|
if err != nil {
|
|
t.Fatalf("sequenceEn failed: %v", err)
|
|
}
|
|
wire := client.msgEn(client.SecretKey, body)
|
|
env, err := server.decodeEnvelope(newServerCodecClientConnForTest(server), wire)
|
|
if err != nil {
|
|
t.Fatalf("decodeEnvelope failed: %v", err)
|
|
}
|
|
if env.Kind != EnvelopeSignal {
|
|
t.Fatalf("envelope kind = %v, want %v", env.Kind, EnvelopeSignal)
|
|
}
|
|
got, err := unwrapTransferMsgEnvelope(env, server.sequenceDe)
|
|
if err != nil {
|
|
t.Fatalf("unwrapTransferMsgEnvelope failed: %v", err)
|
|
}
|
|
if !reflect.DeepEqual(got, want) {
|
|
t.Fatalf("signal mismatch: got %#v want %#v", got, want)
|
|
}
|
|
}
|
|
|
|
func TestDefaultConstructorsDecodeFileEnvelopesWithModernTransport(t *testing.T) {
|
|
client := NewClient().(*ClientCommon)
|
|
server := NewServer().(*ServerCommon)
|
|
sharedKey := []byte("0123456789abcdef0123456789abcdef")
|
|
client.SetSecretKey(sharedKey)
|
|
server.SetSecretKey(sharedKey)
|
|
|
|
tests := []struct {
|
|
name string
|
|
env Envelope
|
|
}{
|
|
{
|
|
name: "file-meta",
|
|
env: newFileMetaEnvelope("file-1", "demo.txt", 64, "checksum", 0644, 123456789),
|
|
},
|
|
{
|
|
name: "file-chunk",
|
|
env: newFileChunkEnvelope("file-1", 32, []byte("chunk-data")),
|
|
},
|
|
{
|
|
name: "file-ack",
|
|
env: newFileAckEnvelope("file-1", "chunk", 32, ""),
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
body, err := client.sequenceEn(tt.env)
|
|
if err != nil {
|
|
t.Fatalf("sequenceEn failed: %v", err)
|
|
}
|
|
wire := client.msgEn(client.SecretKey, body)
|
|
got, err := server.decodeEnvelope(newServerCodecClientConnForTest(server), wire)
|
|
if err != nil {
|
|
t.Fatalf("decodeEnvelope failed: %v", err)
|
|
}
|
|
if !reflect.DeepEqual(got, tt.env) {
|
|
t.Fatalf("envelope mismatch: got %#v want %#v", got, tt.env)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUseModernPSKRejectsEmptySecret(t *testing.T) {
|
|
if err := UseModernPSKClient(NewClient(), nil, testModernPSKOptions()); err == nil {
|
|
t.Fatal("UseModernPSKClient should reject empty secret")
|
|
}
|
|
if err := UseModernPSKServer(NewServer(), nil, testModernPSKOptions()); err == nil {
|
|
t.Fatal("UseModernPSKServer should reject empty secret")
|
|
}
|
|
}
|
|
|
|
func TestModernPSKCodecRejectsLegacyPayload(t *testing.T) {
|
|
key, aad, err := deriveModernPSKKey([]byte("notify-legacy-reject"), testModernPSKOptions())
|
|
if err != nil {
|
|
t.Fatalf("deriveModernPSKKey failed: %v", err)
|
|
}
|
|
_, modernMsgDe := buildModernPSKCodecs(aad)
|
|
legacyWire := defaultMsgEn(key, []byte("legacy payload"))
|
|
if got := modernMsgDe(key, legacyWire); got != nil {
|
|
t.Fatalf("modern decoder should reject legacy payload, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestModernPSKCodecUsesUniqueNoncePerMessage(t *testing.T) {
|
|
key, aad, err := deriveModernPSKKey([]byte("notify-unique-nonce"), testModernPSKOptions())
|
|
if err != nil {
|
|
t.Fatalf("deriveModernPSKKey failed: %v", err)
|
|
}
|
|
msgEn, msgDe := buildModernPSKCodecs(aad)
|
|
|
|
first := msgEn(key, []byte("payload"))
|
|
second := msgEn(key, []byte("payload"))
|
|
if first == nil || second == nil {
|
|
t.Fatal("modern msgEn should produce payload")
|
|
}
|
|
if bytes.Equal(first, second) {
|
|
t.Fatal("two modern payloads should not be byte-identical")
|
|
}
|
|
if !bytes.Equal(first[:len(modernPSKMagic)], modernPSKMagic) {
|
|
t.Fatalf("first payload magic = %q, want %q", first[:len(modernPSKMagic)], modernPSKMagic)
|
|
}
|
|
if !bytes.Equal(second[:len(modernPSKMagic)], modernPSKMagic) {
|
|
t.Fatalf("second payload magic = %q, want %q", second[:len(modernPSKMagic)], modernPSKMagic)
|
|
}
|
|
if bytes.Equal(first[len(modernPSKMagic):len(modernPSKMagic)+modernPSKNonceSize], second[len(modernPSKMagic):len(modernPSKMagic)+modernPSKNonceSize]) {
|
|
t.Fatal("modern payload nonces should differ between messages")
|
|
}
|
|
if got := msgDe(key, first); !bytes.Equal(got, []byte("payload")) {
|
|
t.Fatalf("first decode = %q, want %q", got, "payload")
|
|
}
|
|
if got := msgDe(key, second); !bytes.Equal(got, []byte("payload")) {
|
|
t.Fatalf("second decode = %q, want %q", got, "payload")
|
|
}
|
|
}
|
|
|
|
func TestModernPSKFastStreamEncodeRoundTrip(t *testing.T) {
|
|
key, aad, err := deriveModernPSKKey([]byte("notify-fast-stream"), testModernPSKOptions())
|
|
if err != nil {
|
|
t.Fatalf("deriveModernPSKKey failed: %v", err)
|
|
}
|
|
transport := buildModernPSKTransportBundle(aad)
|
|
wire, err := transport.fastStreamEncode(key, 23, 7, []byte("payload"))
|
|
if err != nil {
|
|
t.Fatalf("fastStreamEncode failed: %v", err)
|
|
}
|
|
plain := transport.msgDe(key, wire)
|
|
if plain == nil {
|
|
t.Fatal("msgDe returned nil")
|
|
}
|
|
frame, matched, err := decodeStreamFastDataFrame(plain)
|
|
if err != nil {
|
|
t.Fatalf("decodeStreamFastDataFrame failed: %v", err)
|
|
}
|
|
if !matched {
|
|
t.Fatal("decodeStreamFastDataFrame should match fast payload")
|
|
}
|
|
if frame.DataID != 23 {
|
|
t.Fatalf("data id = %d, want %d", frame.DataID, 23)
|
|
}
|
|
if frame.Seq != 7 {
|
|
t.Fatalf("seq = %d, want %d", frame.Seq, 7)
|
|
}
|
|
if !bytes.Equal(frame.Payload, []byte("payload")) {
|
|
t.Fatalf("payload = %q, want %q", frame.Payload, "payload")
|
|
}
|
|
}
|
|
|
|
func TestModernPSKFastBulkEncodeRoundTrip(t *testing.T) {
|
|
key, aad, err := deriveModernPSKKey([]byte("notify-fast-bulk"), testModernPSKOptions())
|
|
if err != nil {
|
|
t.Fatalf("deriveModernPSKKey failed: %v", err)
|
|
}
|
|
transport := buildModernPSKTransportBundle(aad)
|
|
wire, err := transport.fastBulkEncode(key, 41, 9, []byte("payload"))
|
|
if err != nil {
|
|
t.Fatalf("fastBulkEncode failed: %v", err)
|
|
}
|
|
plain := transport.msgDe(key, wire)
|
|
if plain == nil {
|
|
t.Fatal("msgDe returned nil")
|
|
}
|
|
frame, matched, err := decodeBulkFastDataFrame(plain)
|
|
if err != nil {
|
|
t.Fatalf("decodeBulkFastDataFrame failed: %v", err)
|
|
}
|
|
if !matched {
|
|
t.Fatal("decodeBulkFastDataFrame should match fast payload")
|
|
}
|
|
if frame.DataID != 41 {
|
|
t.Fatalf("data id = %d, want %d", frame.DataID, 41)
|
|
}
|
|
if frame.Seq != 9 {
|
|
t.Fatalf("seq = %d, want %d", frame.Seq, 9)
|
|
}
|
|
if !bytes.Equal(frame.Payload, []byte("payload")) {
|
|
t.Fatalf("payload = %q, want %q", frame.Payload, "payload")
|
|
}
|
|
}
|
|
|
|
func TestUseLegacySecurityRoundTrip(t *testing.T) {
|
|
client := NewClient()
|
|
server := NewServer()
|
|
|
|
UseLegacySecurityClient(client)
|
|
UseLegacySecurityServer(server)
|
|
|
|
cc := client.(*ClientCommon)
|
|
ss := server.(*ServerCommon)
|
|
if cc.SkipExchangeKey() {
|
|
t.Fatal("legacy client should keep legacy exchange enabled")
|
|
}
|
|
if !bytes.Equal(cc.SecretKey, defaultAesKey) {
|
|
t.Fatal("legacy client should restore the default AES key")
|
|
}
|
|
if !bytes.Equal(ss.SecretKey, defaultAesKey) {
|
|
t.Fatal("legacy server should restore the default AES key")
|
|
}
|
|
if !bytes.Equal(cc.RsaPubKey(), defaultRsaPubKey) {
|
|
t.Fatal("legacy client should restore the default RSA public key")
|
|
}
|
|
if !bytes.Equal(ss.RsaPrivKey(), defaultRsaKey) {
|
|
t.Fatal("legacy server should restore the default RSA private key")
|
|
}
|
|
|
|
plain := []byte("notify legacy transport")
|
|
wire := cc.msgEn(cc.SecretKey, plain)
|
|
got := ss.defaultMsgDe(ss.SecretKey, wire)
|
|
if !bytes.Equal(got, plain) {
|
|
t.Fatalf("legacy server decode mismatch: got %q want %q", got, plain)
|
|
}
|
|
|
|
replyWire := ss.defaultMsgEn(ss.SecretKey, plain)
|
|
reply := cc.msgDe(cc.SecretKey, replyWire)
|
|
if !bytes.Equal(reply, plain) {
|
|
t.Fatalf("legacy client decode mismatch: got %q want %q", reply, plain)
|
|
}
|
|
}
|