package notify

import (
	"bytes"
	"errors"
	"reflect"
	"testing"

	"b612.me/starcrypto"
)

// testModernPSKOptions returns the ModernPSKOptions shared by the tests in
// this file: a fixed salt and AAD, and Argon2 parameters producing a 32-byte
// key. The cost parameters (Time 1, Memory 8, Threads 1) are deliberately
// minimal, presumably so key derivation stays fast under `go test`.
func testModernPSKOptions() *ModernPSKOptions {
	return &ModernPSKOptions{
		Salt: []byte("notify-modern-psk-test-salt"),
		AAD:  []byte("notify-modern-psk-test-aad"),
		Argon2Params: starcrypto.Argon2Params{
			Time:    1,
			Memory:  8,
			Threads: 1,
			KeyLen:  32,
		},
	}
}

// TestUseModernPSKRoundTrip checks that a client and server configured with
// the same secret and options:
//   - disable the legacy key exchange on the client,
//   - derive identical 32-byte transport keys, and
//   - round-trip a payload in both directions (client->server and
//     server->client).
func TestUseModernPSKRoundTrip(t *testing.T) {
	client := NewClient()
	server := NewServer()
	secret := []byte("correct horse battery staple")
	opts := testModernPSKOptions()

	if err := UseModernPSKClient(client, secret, opts); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := UseModernPSKServer(server, secret, opts); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}

	cc := client.(*ClientCommon)
	ss := server.(*ServerCommon)
	if !cc.SkipExchangeKey() {
		t.Fatal("client should skip legacy key exchange after UseModernPSKClient")
	}
	if len(cc.SecretKey) != 32 {
		t.Fatalf("client derived key length = %d, want 32", len(cc.SecretKey))
	}
	if !bytes.Equal(cc.SecretKey, ss.SecretKey) {
		t.Fatal("derived transport keys do not match")
	}

	plain := []byte("notify modern psk transport")

	wire := cc.msgEn(cc.SecretKey, plain)
	got := ss.defaultMsgDe(ss.SecretKey, wire)
	if !bytes.Equal(got, plain) {
		t.Fatalf("server decode mismatch: got %q want %q", got, plain)
	}

	replyWire := ss.defaultMsgEn(ss.SecretKey, plain)
	reply := cc.msgDe(cc.SecretKey, replyWire)
	if !bytes.Equal(reply, plain) {
		t.Fatalf("client decode mismatch: got %q want %q", reply, plain)
	}
}

// TestNewClientConnectRequiresModernPSK checks that Connect on a client with
// no PSK configured returns an error matching errModernPSKRequired. The
// target address is a placeholder; only the returned error is inspected.
func TestNewClientConnectRequiresModernPSK(t *testing.T) {
	client := NewClient()
	err := client.Connect("tcp", "127.0.0.1:1")
	if !errors.Is(err, errModernPSKRequired) {
		t.Fatalf("Connect error = %v, want %v", err, errModernPSKRequired)
	}
}

// TestNewServerListenRequiresModernPSK checks that Listen on a server with
// no PSK configured returns an error matching errModernPSKRequired. The
// listen address is a placeholder; only the returned error is inspected.
func TestNewServerListenRequiresModernPSK(t *testing.T) {
	server := NewServer()
	err := server.Listen("tcp", "127.0.0.1:1")
	if !errors.Is(err, errModernPSKRequired) {
		t.Fatalf("Listen error = %v, want %v", err, errModernPSKRequired)
	}
}

// TestDefaultConstructorsUseModernTransportAfterSetSecretKey checks that,
// after SetSecretKey with a shared key on default-constructed peers, a
// message encoded by the client's msgEn decodes with the server's
// defaultMsgDe.
func TestDefaultConstructorsUseModernTransportAfterSetSecretKey(t *testing.T) {
	client := NewClient().(*ClientCommon)
	server := NewServer().(*ServerCommon)
	sharedKey := []byte("0123456789abcdef0123456789abcdef")
	client.SetSecretKey(sharedKey)
	server.SetSecretKey(sharedKey)

	plain := []byte("notify default modern transport")
	wire := client.msgEn(client.SecretKey, plain)
	got := server.defaultMsgDe(server.SecretKey, wire)
	if !bytes.Equal(got, plain) {
		t.Fatalf("server decode mismatch: got %q want %q", got, plain)
	}
}

// TestDefaultConstructorsDecodeSignalEnvelopeWithModernTransport checks the
// signal path end to end:
//  1. a TransferMsg is serialized with the client's sequenceEn;
//  2. it is encrypted with the client's msgEn;
//  3. the server's decodeEnvelope yields an EnvelopeSignal;
//  4. that envelope unwraps back to the original message.
func TestDefaultConstructorsDecodeSignalEnvelopeWithModernTransport(t *testing.T) {
	client := NewClient().(*ClientCommon)
	server := NewServer().(*ServerCommon)
	sharedKey := []byte("0123456789abcdef0123456789abcdef")
	client.SetSecretKey(sharedKey)
	server.SetSecretKey(sharedKey)

	want := TransferMsg{
		ID:    42,
		Key:   "modern-signal",
		Value: MsgVal("payload"),
		Type:  MSG_ASYNC,
	}
	body, err := client.sequenceEn(want)
	if err != nil {
		t.Fatalf("sequenceEn failed: %v", err)
	}
	wire := client.msgEn(client.SecretKey, body)

	env, err := server.decodeEnvelope(newServerCodecClientConnForTest(server), wire)
	if err != nil {
		t.Fatalf("decodeEnvelope failed: %v", err)
	}
	if env.Kind != EnvelopeSignal {
		t.Fatalf("envelope kind = %v, want %v", env.Kind, EnvelopeSignal)
	}

	got, err := unwrapTransferMsgEnvelope(env, server.sequenceDe)
	if err != nil {
		t.Fatalf("unwrapTransferMsgEnvelope failed: %v", err)
	}
	if !reflect.DeepEqual(got, want) {
		t.Fatalf("signal mismatch: got %#v want %#v", got, want)
	}
}

// TestDefaultConstructorsDecodeFileEnvelopesWithModernTransport checks that
// file meta, chunk, and ack envelopes survive a round trip unchanged. Each
// envelope goes through client sequenceEn, then msgEn, then the server's
// decodeEnvelope.
func TestDefaultConstructorsDecodeFileEnvelopesWithModernTransport(t *testing.T) {
	client := NewClient().(*ClientCommon)
	server := NewServer().(*ServerCommon)
	sharedKey := []byte("0123456789abcdef0123456789abcdef")
	client.SetSecretKey(sharedKey)
	server.SetSecretKey(sharedKey)

	tests := []struct {
		name string
		env  Envelope
	}{
		{
			name: "file-meta",
			env:  newFileMetaEnvelope("file-1", "demo.txt", 64, "checksum", 0644, 123456789),
		},
		{
			name: "file-chunk",
			env:  newFileChunkEnvelope("file-1", 32, []byte("chunk-data")),
		},
		{
			name: "file-ack",
			env:  newFileAckEnvelope("file-1", "chunk", 32, ""),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			body, err := client.sequenceEn(tt.env)
			if err != nil {
				t.Fatalf("sequenceEn failed: %v", err)
			}
			wire := client.msgEn(client.SecretKey, body)

			got, err := server.decodeEnvelope(newServerCodecClientConnForTest(server), wire)
			if err != nil {
				t.Fatalf("decodeEnvelope failed: %v", err)
			}
			if !reflect.DeepEqual(got, tt.env) {
				t.Fatalf("envelope mismatch: got %#v want %#v", got, tt.env)
			}
		})
	}
}

// TestUseModernPSKRejectsEmptySecret checks that both UseModernPSKClient and
// UseModernPSKServer return an error when given a nil secret.
func TestUseModernPSKRejectsEmptySecret(t *testing.T) {
	if err := UseModernPSKClient(NewClient(), nil, testModernPSKOptions()); err == nil {
		t.Fatal("UseModernPSKClient should reject empty secret")
	}
	if err := UseModernPSKServer(NewServer(), nil, testModernPSKOptions()); err == nil {
		t.Fatal("UseModernPSKServer should reject empty secret")
	}
}

// TestModernPSKCodecRejectsLegacyPayload checks that the modern decoder
// returns nil for a payload produced by the package-level defaultMsgEn,
// even when the same key is used.
func TestModernPSKCodecRejectsLegacyPayload(t *testing.T) {
	key, aad, err := deriveModernPSKKey([]byte("notify-legacy-reject"), testModernPSKOptions())
	if err != nil {
		t.Fatalf("deriveModernPSKKey failed: %v", err)
	}
	_, modernMsgDe := buildModernPSKCodecs(aad)

	legacyWire := defaultMsgEn(key, []byte("legacy payload"))
	if got := modernMsgDe(key, legacyWire); got != nil {
		t.Fatalf("modern decoder should reject legacy payload, got %q", got)
	}
}

// TestModernPSKCodecUsesUniqueNoncePerMessage checks that encoding the same
// plaintext twice yields different ciphertexts. Each payload starts with
// modernPSKMagic followed by a modernPSKNonceSize-byte nonce, and the
// nonces differ. Both payloads still decode to the original plaintext.
func TestModernPSKCodecUsesUniqueNoncePerMessage(t *testing.T) {
	key, aad, err := deriveModernPSKKey([]byte("notify-unique-nonce"), testModernPSKOptions())
	if err != nil {
		t.Fatalf("deriveModernPSKKey failed: %v", err)
	}
	msgEn, msgDe := buildModernPSKCodecs(aad)

	first := msgEn(key, []byte("payload"))
	second := msgEn(key, []byte("payload"))
	if first == nil || second == nil {
		t.Fatal("modern msgEn should produce payload")
	}
	if bytes.Equal(first, second) {
		t.Fatal("two modern payloads should not be byte-identical")
	}

	// Guard the header slicing below. A short payload would otherwise panic
	// with an index-out-of-range error instead of failing with a clear message.
	magicLen := len(modernPSKMagic)
	headerLen := magicLen + modernPSKNonceSize
	if len(first) < headerLen || len(second) < headerLen {
		t.Fatalf("modern payload lengths = %d and %d, want at least %d", len(first), len(second), headerLen)
	}

	if !bytes.Equal(first[:magicLen], modernPSKMagic) {
		t.Fatalf("first payload magic = %q, want %q", first[:magicLen], modernPSKMagic)
	}
	if !bytes.Equal(second[:magicLen], modernPSKMagic) {
		t.Fatalf("second payload magic = %q, want %q", second[:magicLen], modernPSKMagic)
	}

	firstNonce := first[magicLen:headerLen]
	secondNonce := second[magicLen:headerLen]
	if bytes.Equal(firstNonce, secondNonce) {
		t.Fatal("modern payload nonces should differ between messages")
	}

	if got := msgDe(key, first); !bytes.Equal(got, []byte("payload")) {
		t.Fatalf("first decode = %q, want %q", got, "payload")
	}
	if got := msgDe(key, second); !bytes.Equal(got, []byte("payload")) {
		t.Fatalf("second decode = %q, want %q", got, "payload")
	}
}

// TestModernPSKFastStreamEncodeRoundTrip checks that output from the
// transport bundle's fastStreamEncode decrypts with msgDe. The result must
// be a frame that decodeStreamFastDataFrame recognises, with DataID, Seq,
// and payload preserved.
func TestModernPSKFastStreamEncodeRoundTrip(t *testing.T) {
	key, aad, err := deriveModernPSKKey([]byte("notify-fast-stream"), testModernPSKOptions())
	if err != nil {
		t.Fatalf("deriveModernPSKKey failed: %v", err)
	}
	transport := buildModernPSKTransportBundle(aad)

	wire, err := transport.fastStreamEncode(key, 23, 7, []byte("payload"))
	if err != nil {
		t.Fatalf("fastStreamEncode failed: %v", err)
	}
	plain := transport.msgDe(key, wire)
	if plain == nil {
		t.Fatal("msgDe returned nil")
	}

	frame, matched, err := decodeStreamFastDataFrame(plain)
	if err != nil {
		t.Fatalf("decodeStreamFastDataFrame failed: %v", err)
	}
	if !matched {
		t.Fatal("decodeStreamFastDataFrame should match fast payload")
	}
	if frame.DataID != 23 {
		t.Fatalf("data id = %d, want %d", frame.DataID, 23)
	}
	if frame.Seq != 7 {
		t.Fatalf("seq = %d, want %d", frame.Seq, 7)
	}
	if !bytes.Equal(frame.Payload, []byte("payload")) {
		t.Fatalf("payload = %q, want %q", frame.Payload, "payload")
	}
}

// TestModernPSKFastBulkEncodeRoundTrip is the bulk-frame counterpart of
// TestModernPSKFastStreamEncodeRoundTrip. Output from fastBulkEncode must
// decrypt with msgDe and decode via decodeBulkFastDataFrame with DataID,
// Seq, and payload preserved.
func TestModernPSKFastBulkEncodeRoundTrip(t *testing.T) {
	key, aad, err := deriveModernPSKKey([]byte("notify-fast-bulk"), testModernPSKOptions())
	if err != nil {
		t.Fatalf("deriveModernPSKKey failed: %v", err)
	}
	transport := buildModernPSKTransportBundle(aad)

	wire, err := transport.fastBulkEncode(key, 41, 9, []byte("payload"))
	if err != nil {
		t.Fatalf("fastBulkEncode failed: %v", err)
	}
	plain := transport.msgDe(key, wire)
	if plain == nil {
		t.Fatal("msgDe returned nil")
	}

	frame, matched, err := decodeBulkFastDataFrame(plain)
	if err != nil {
		t.Fatalf("decodeBulkFastDataFrame failed: %v", err)
	}
	if !matched {
		t.Fatal("decodeBulkFastDataFrame should match fast payload")
	}
	if frame.DataID != 41 {
		t.Fatalf("data id = %d, want %d", frame.DataID, 41)
	}
	if frame.Seq != 9 {
		t.Fatalf("seq = %d, want %d", frame.Seq, 9)
	}
	if !bytes.Equal(frame.Payload, []byte("payload")) {
		t.Fatalf("payload = %q, want %q", frame.Payload, "payload")
	}
}

// TestUseLegacySecurityRoundTrip checks that UseLegacySecurityClient and
// UseLegacySecurityServer:
//   - keep the client's legacy key exchange enabled,
//   - restore the default AES key on both sides,
//   - restore the default RSA public key (client) and RSA private key
//     (server), and
//   - round-trip a payload in both directions.
func TestUseLegacySecurityRoundTrip(t *testing.T) {
	client := NewClient()
	server := NewServer()
	UseLegacySecurityClient(client)
	UseLegacySecurityServer(server)

	cc := client.(*ClientCommon)
	ss := server.(*ServerCommon)
	if cc.SkipExchangeKey() {
		t.Fatal("legacy client should keep legacy exchange enabled")
	}
	if !bytes.Equal(cc.SecretKey, defaultAesKey) {
		t.Fatal("legacy client should restore the default AES key")
	}
	if !bytes.Equal(ss.SecretKey, defaultAesKey) {
		t.Fatal("legacy server should restore the default AES key")
	}
	if !bytes.Equal(cc.RsaPubKey(), defaultRsaPubKey) {
		t.Fatal("legacy client should restore the default RSA public key")
	}
	if !bytes.Equal(ss.RsaPrivKey(), defaultRsaKey) {
		t.Fatal("legacy server should restore the default RSA private key")
	}

	plain := []byte("notify legacy transport")

	wire := cc.msgEn(cc.SecretKey, plain)
	got := ss.defaultMsgDe(ss.SecretKey, wire)
	if !bytes.Equal(got, plain) {
		t.Fatalf("legacy server decode mismatch: got %q want %q", got, plain)
	}

	replyWire := ss.defaultMsgEn(ss.SecretKey, plain)
	reply := cc.msgDe(cc.SecretKey, replyWire)
	if !bytes.Equal(reply, plain) {
		t.Fatalf("legacy client decode mismatch: got %q want %q", reply, plain)
	}
}