package notify import ( "bytes" "errors" "net" "reflect" "testing" "time" "b612.me/starcrypto" ) func testModernPSKOptions() *ModernPSKOptions { return &ModernPSKOptions{ Salt: []byte("notify-modern-psk-test-salt"), AAD: []byte("notify-modern-psk-test-aad"), Argon2Params: starcrypto.Argon2Params{ Time: 1, Memory: 8, Threads: 1, KeyLen: 32, }, } } func TestUseModernPSKRoundTrip(t *testing.T) { client := NewClient() server := NewServer() secret := []byte("correct horse battery staple") opts := testModernPSKOptions() if err := UseModernPSKClient(client, secret, opts); err != nil { t.Fatalf("UseModernPSKClient failed: %v", err) } if err := UseModernPSKServer(server, secret, opts); err != nil { t.Fatalf("UseModernPSKServer failed: %v", err) } cc := client.(*ClientCommon) ss := server.(*ServerCommon) if !cc.SkipExchangeKey() { t.Fatal("client should skip legacy key exchange after UseModernPSKClient") } if len(cc.SecretKey) != 32 { t.Fatalf("client derived key length = %d, want 32", len(cc.SecretKey)) } if !bytes.Equal(cc.SecretKey, ss.SecretKey) { t.Fatal("derived transport keys do not match") } plain := []byte("notify modern psk transport") wire := cc.msgEn(cc.SecretKey, plain) got := ss.defaultMsgDe(ss.SecretKey, wire) if !bytes.Equal(got, plain) { t.Fatalf("server decode mismatch: got %q want %q", got, plain) } replyWire := ss.defaultMsgEn(ss.SecretKey, plain) reply := cc.msgDe(cc.SecretKey, replyWire) if !bytes.Equal(reply, plain) { t.Fatalf("client decode mismatch: got %q want %q", reply, plain) } } func TestNewClientConnectRequiresModernPSK(t *testing.T) { client := NewClient() err := client.Connect("tcp", "127.0.0.1:1") if !errors.Is(err, errModernPSKRequired) { t.Fatalf("Connect error = %v, want %v", err, errModernPSKRequired) } } func TestNewServerListenRequiresModernPSK(t *testing.T) { server := NewServer() err := server.Listen("tcp", "127.0.0.1:1") if !errors.Is(err, errModernPSKRequired) { t.Fatalf("Listen error = %v, want %v", err, errModernPSKRequired) } } func TestDefaultConstructorsUseModernTransportAfterSetSecretKey(t *testing.T) { client := NewClient().(*ClientCommon) server := NewServer().(*ServerCommon) sharedKey := []byte("0123456789abcdef0123456789abcdef") client.SetSecretKey(sharedKey) server.SetSecretKey(sharedKey) if client.modernPSKRuntime == nil { t.Fatal("client modernPSKRuntime should be installed after SetSecretKey") } if server.defaultModernPSKRuntime == nil { t.Fatal("server defaultModernPSKRuntime should be installed after SetSecretKey") } plain := []byte("notify default modern transport") wire := client.msgEn(client.SecretKey, plain) got := server.defaultMsgDe(server.SecretKey, wire) if !bytes.Equal(got, plain) { t.Fatalf("server decode mismatch: got %q want %q", got, plain) } } func TestCustomCodecOverridesClearModernRuntime(t *testing.T) { client := NewClient().(*ClientCommon) server := NewServer().(*ServerCommon) sharedKey := []byte("0123456789abcdef0123456789abcdef") client.SetSecretKey(sharedKey) server.SetSecretKey(sharedKey) if client.modernPSKRuntime == nil || server.defaultModernPSKRuntime == nil { t.Fatal("modern runtimes should be installed before override") } client.SetMsgEn(defaultMsgEn) client.SetMsgDe(defaultMsgDe) server.SetDefaultCommEncode(defaultMsgEn) server.SetDefaultCommDecode(defaultMsgDe) if client.modernPSKRuntime != nil { t.Fatal("client modernPSKRuntime should be cleared after custom codec override") } if server.defaultModernPSKRuntime != nil { t.Fatal("server defaultModernPSKRuntime should be cleared after custom codec override") } } func TestDefaultConstructorsDecodeSignalEnvelopeWithModernTransport(t *testing.T) { client := NewClient().(*ClientCommon) server := NewServer().(*ServerCommon) sharedKey := []byte("0123456789abcdef0123456789abcdef") client.SetSecretKey(sharedKey) server.SetSecretKey(sharedKey) want := TransferMsg{ ID: 42, Key: "modern-signal", Value: MsgVal("payload"), Type: MSG_ASYNC, } body, err := client.sequenceEn(want) if err != nil { t.Fatalf("sequenceEn failed: %v", err) } wire := client.msgEn(client.SecretKey, body) env, err := server.decodeEnvelope(newServerCodecClientConnForTest(server), wire) if err != nil { t.Fatalf("decodeEnvelope failed: %v", err) } if env.Kind != EnvelopeSignal { t.Fatalf("envelope kind = %v, want %v", env.Kind, EnvelopeSignal) } got, err := unwrapTransferMsgEnvelope(env, server.sequenceDe) if err != nil { t.Fatalf("unwrapTransferMsgEnvelope failed: %v", err) } if !reflect.DeepEqual(got, want) { t.Fatalf("signal mismatch: got %#v want %#v", got, want) } } func TestDefaultConstructorsDecodeFileEnvelopesWithModernTransport(t *testing.T) { client := NewClient().(*ClientCommon) server := NewServer().(*ServerCommon) sharedKey := []byte("0123456789abcdef0123456789abcdef") client.SetSecretKey(sharedKey) server.SetSecretKey(sharedKey) tests := []struct { name string env Envelope }{ { name: "file-meta", env: newFileMetaEnvelope("file-1", "demo.txt", 64, "checksum", 0644, 123456789), }, { name: "file-chunk", env: newFileChunkEnvelope("file-1", 32, []byte("chunk-data")), }, { name: "file-ack", env: newFileAckEnvelope("file-1", "chunk", 32, ""), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { body, err := client.sequenceEn(tt.env) if err != nil { t.Fatalf("sequenceEn failed: %v", err) } wire := client.msgEn(client.SecretKey, body) got, err := server.decodeEnvelope(newServerCodecClientConnForTest(server), wire) if err != nil { t.Fatalf("decodeEnvelope failed: %v", err) } if !reflect.DeepEqual(got, tt.env) { t.Fatalf("envelope mismatch: got %#v want %#v", got, tt.env) } }) } } func TestUseModernPSKRejectsEmptySecret(t *testing.T) { if err := UseModernPSKClient(NewClient(), nil, testModernPSKOptions()); err == nil { t.Fatal("UseModernPSKClient should reject empty secret") } if err := UseModernPSKServer(NewServer(), nil, testModernPSKOptions()); err == nil { t.Fatal("UseModernPSKServer should reject empty secret") } } func TestUsePSKOverExternalTransportRejectsForwardSecrecyRequirement(t *testing.T) { opts := testModernPSKOptions() opts.RequireForwardSecrecy = true if err := UsePSKOverExternalTransportClient(NewClient(), []byte("secret"), opts); !errors.Is(err, errModernPSKForwardSecrecyUnsupported) { t.Fatalf("UsePSKOverExternalTransportClient error = %v, want %v", err, errModernPSKForwardSecrecyUnsupported) } if err := UsePSKOverExternalTransportServer(NewServer(), []byte("secret"), opts); !errors.Is(err, errModernPSKForwardSecrecyUnsupported) { t.Fatalf("UsePSKOverExternalTransportServer error = %v, want %v", err, errModernPSKForwardSecrecyUnsupported) } } func TestModernPSKCodecRejectsLegacyPayload(t *testing.T) { key, aad, err := deriveModernPSKKey([]byte("notify-legacy-reject"), testModernPSKOptions()) if err != nil { t.Fatalf("deriveModernPSKKey failed: %v", err) } _, modernMsgDe := buildModernPSKCodecs(aad) legacyWire := defaultMsgEn(key, []byte("legacy payload")) if got := modernMsgDe(key, legacyWire); got != nil { t.Fatalf("modern decoder should reject legacy payload, got %q", got) } } func TestModernPSKCodecUsesUniqueNoncePerMessage(t *testing.T) { key, aad, err := deriveModernPSKKey([]byte("notify-unique-nonce"), testModernPSKOptions()) if err != nil { t.Fatalf("deriveModernPSKKey failed: %v", err) } msgEn, msgDe := buildModernPSKCodecs(aad) first := msgEn(key, []byte("payload")) second := msgEn(key, []byte("payload")) if first == nil || second == nil { t.Fatal("modern msgEn should produce payload") } if bytes.Equal(first, second) { t.Fatal("two modern payloads should not be byte-identical") } if !bytes.Equal(first[:len(modernPSKMagic)], modernPSKMagic) { t.Fatalf("first payload magic = %q, want %q", first[:len(modernPSKMagic)], modernPSKMagic) } if !bytes.Equal(second[:len(modernPSKMagic)], modernPSKMagic) { t.Fatalf("second payload magic = %q, want %q", second[:len(modernPSKMagic)], modernPSKMagic) } if bytes.Equal(first[len(modernPSKMagic):len(modernPSKMagic)+modernPSKNonceSize], second[len(modernPSKMagic):len(modernPSKMagic)+modernPSKNonceSize]) { t.Fatal("modern payload nonces should differ between messages") } if got := msgDe(key, first); !bytes.Equal(got, []byte("payload")) { t.Fatalf("first decode = %q, want %q", got, "payload") } if got := msgDe(key, second); !bytes.Equal(got, []byte("payload")) { t.Fatalf("second decode = %q, want %q", got, "payload") } } func TestModernPSKFastStreamEncodeRoundTrip(t *testing.T) { key, aad, err := deriveModernPSKKey([]byte("notify-fast-stream"), testModernPSKOptions()) if err != nil { t.Fatalf("deriveModernPSKKey failed: %v", err) } transport := buildModernPSKTransportBundle(aad) wire, err := transport.fastStreamEncode(key, 23, 7, []byte("payload")) if err != nil { t.Fatalf("fastStreamEncode failed: %v", err) } plain := transport.msgDe(key, wire) if plain == nil { t.Fatal("msgDe returned nil") } frame, matched, err := decodeStreamFastDataFrame(plain) if err != nil { t.Fatalf("decodeStreamFastDataFrame failed: %v", err) } if !matched { t.Fatal("decodeStreamFastDataFrame should match fast payload") } if frame.DataID != 23 { t.Fatalf("data id = %d, want %d", frame.DataID, 23) } if frame.Seq != 7 { t.Fatalf("seq = %d, want %d", frame.Seq, 7) } if !bytes.Equal(frame.Payload, []byte("payload")) { t.Fatalf("payload = %q, want %q", frame.Payload, "payload") } } func TestModernPSKFastBulkEncodeRoundTrip(t *testing.T) { key, aad, err := deriveModernPSKKey([]byte("notify-fast-bulk"), testModernPSKOptions()) if err != nil { t.Fatalf("deriveModernPSKKey failed: %v", err) } transport := buildModernPSKTransportBundle(aad) wire, err := transport.fastBulkEncode(key, 41, 9, []byte("payload")) if err != nil { t.Fatalf("fastBulkEncode failed: %v", err) } plain := transport.msgDe(key, wire) if plain == nil { t.Fatal("msgDe returned nil") } frame, matched, err := decodeBulkFastDataFrame(plain) if err != nil { t.Fatalf("decodeBulkFastDataFrame failed: %v", err) } if !matched { t.Fatal("decodeBulkFastDataFrame should match fast payload") } if frame.DataID != 41 { t.Fatalf("data id = %d, want %d", frame.DataID, 41) } if frame.Seq != 9 { t.Fatalf("seq = %d, want %d", frame.Seq, 9) } if !bytes.Equal(frame.Payload, []byte("payload")) { t.Fatalf("payload = %q, want %q", frame.Payload, "payload") } } func TestExternalTransportFastStreamEncodeRoundTrip(t *testing.T) { transport := buildExternalTransportBundle() wire, err := transport.fastStreamEncode(nil, 23, 7, []byte("payload")) if err != nil { t.Fatalf("fastStreamEncode failed: %v", err) } plain := transport.msgDe(nil, wire) frame, matched, err := decodeStreamFastDataFrame(plain) if err != nil { t.Fatalf("decodeStreamFastDataFrame failed: %v", err) } if !matched { t.Fatal("decodeStreamFastDataFrame should match fast payload") } if frame.DataID != 23 || frame.Seq != 7 || !bytes.Equal(frame.Payload, []byte("payload")) { t.Fatalf("frame mismatch: %+v", frame) } } func TestExternalTransportFastBulkEncodeRoundTrip(t *testing.T) { transport := buildExternalTransportBundle() wire, err := transport.fastBulkEncode(nil, 41, 9, []byte("payload")) if err != nil { t.Fatalf("fastBulkEncode failed: %v", err) } plain := transport.msgDe(nil, wire) frame, matched, err := decodeBulkFastDataFrame(plain) if err != nil { t.Fatalf("decodeBulkFastDataFrame failed: %v", err) } if !matched { t.Fatal("decodeBulkFastDataFrame should match fast payload") } if frame.DataID != 41 || frame.Seq != 9 || !bytes.Equal(frame.Payload, []byte("payload")) { t.Fatalf("frame mismatch: %+v", frame) } } func TestDecryptTransportPayloadCodecPooledExternalDefersRelease(t *testing.T) { payload := []byte("payload") released := false plain, release, err := decryptTransportPayloadCodecPooled(ProtectionExternal, nil, passthroughTransportCodec, nil, payload, func() { released = true }) if err != nil { t.Fatalf("decryptTransportPayloadCodecPooled failed: %v", err) } if released { t.Fatal("release should not run before caller is done") } if !bytes.Equal(plain, payload) { t.Fatalf("plain mismatch: got %q want %q", plain, payload) } if release == nil { t.Fatal("release callback should be preserved for external mode") } release() if !released { t.Fatal("release callback should run when caller finishes") } } func TestUsePSKOverExternalTransportConnectByConnSwitchesToExternal(t *testing.T) { client := NewClient().(*ClientCommon) server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { if err := UsePSKOverExternalTransportServer(server, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil { t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err) } server.SetLink("external-roundtrip", func(msg *Message) { _ = msg.Reply([]byte("ack:" + string(msg.Value))) }) }) if err := UsePSKOverExternalTransportClient(client, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil { t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err) } if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionManaged { t.Fatalf("client bootstrap mode = %v, want %v", got, ProtectionManaged) } left, right := net.Pipe() defer right.Close() bootstrapPeerAttachConnForTest(t, server, right) if err := client.ConnectByConn(left); err != nil { t.Fatalf("ConnectByConn failed: %v", err) } defer func() { client.setByeFromServer(true) _ = client.Stop() }() if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionExternal { t.Fatalf("client steady mode = %v, want %v", got, ProtectionExternal) } reply, err := client.SendWait("external-roundtrip", []byte("ping"), time.Second) if err != nil { t.Fatalf("SendWait failed: %v", err) } if got, want := string(reply.Value), "ack:ping"; got != want { t.Fatalf("reply mismatch: got %q want %q", got, want) } list := server.GetLogicalConnList() if len(list) != 1 { t.Fatalf("logical conn count = %d, want 1", len(list)) } if got := list[0].protectionModeSnapshot(); got != ProtectionExternal { t.Fatalf("server steady mode = %v, want %v", got, ProtectionExternal) } } func TestUseNestedSecurityConnectByConnKeepsNestedMode(t *testing.T) { client := NewClient().(*ClientCommon) server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { if err := UseNestedSecurityServer(server, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil { t.Fatalf("UseNestedSecurityServer failed: %v", err) } server.SetLink("nested-roundtrip", func(msg *Message) { _ = msg.Reply([]byte("ack:" + string(msg.Value))) }) }) if err := UseNestedSecurityClient(client, []byte("correct horse battery staple"), testModernPSKOptions()); err != nil { t.Fatalf("UseNestedSecurityClient failed: %v", err) } if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionNested { t.Fatalf("client bootstrap mode = %v, want %v", got, ProtectionNested) } left, right := net.Pipe() defer right.Close() bootstrapPeerAttachConnForTest(t, server, right) if err := client.ConnectByConn(left); err != nil { t.Fatalf("ConnectByConn failed: %v", err) } defer func() { client.setByeFromServer(true) _ = client.Stop() }() if got := client.clientTransportProtectionSnapshot().mode; got != ProtectionNested { t.Fatalf("client steady mode = %v, want %v", got, ProtectionNested) } reply, err := client.SendWait("nested-roundtrip", []byte("ping"), time.Second) if err != nil { t.Fatalf("SendWait failed: %v", err) } if got, want := string(reply.Value), "ack:ping"; got != want { t.Fatalf("reply mismatch: got %q want %q", got, want) } list := server.GetLogicalConnList() if len(list) != 1 { t.Fatalf("logical conn count = %d, want 1", len(list)) } if got := list[0].protectionModeSnapshot(); got != ProtectionNested { t.Fatalf("server steady mode = %v, want %v", got, ProtectionNested) } } func TestUseLegacySecurityRoundTrip(t *testing.T) { client := NewClient() server := NewServer() UseLegacySecurityClient(client) UseLegacySecurityServer(server) cc := client.(*ClientCommon) ss := server.(*ServerCommon) if cc.SkipExchangeKey() { t.Fatal("legacy client should keep legacy exchange enabled") } if !bytes.Equal(cc.SecretKey, defaultAesKey) { t.Fatal("legacy client should restore the default AES key") } if !bytes.Equal(ss.SecretKey, defaultAesKey) { t.Fatal("legacy server should restore the default AES key") } if !bytes.Equal(cc.RsaPubKey(), defaultRsaPubKey) { t.Fatal("legacy client should restore the default RSA public key") } if !bytes.Equal(ss.RsaPrivKey(), defaultRsaKey) { t.Fatal("legacy server should restore the default RSA private key") } plain := []byte("notify legacy transport") wire := cc.msgEn(cc.SecretKey, plain) got := ss.defaultMsgDe(ss.SecretKey, wire) if !bytes.Equal(got, plain) { t.Fatalf("legacy server decode mismatch: got %q want %q", got, plain) } replyWire := ss.defaultMsgEn(ss.SecretKey, plain) reply := cc.msgDe(cc.SecretKey, replyWire) if !bytes.Equal(reply, plain) { t.Fatalf("legacy client decode mismatch: got %q want %q", reply, plain) } }