package notify import ( "bytes" "errors" "testing" ) func newPeerAttachAuthLogicalForTest(t *testing.T, server *ServerCommon) *LogicalConn { t.Helper() logical := newServerLogicalConn(server, "accepted-auth", nil) logical = server.registerAcceptedLogical(logical) if logical == nil { t.Fatal("registerAcceptedLogical returned nil") } return logical } func TestPeerAttachExplicitAuthHelpersRoundTrip(t *testing.T) { secret := []byte("correct horse battery staple") client := NewClient().(*ClientCommon) server := NewServer().(*ServerCommon) if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil { t.Fatalf("UseModernPSKClient failed: %v", err) } if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil { t.Fatalf("UseModernPSKServer failed: %v", err) } logical := newPeerAttachAuthLogicalForTest(t, server) req, reqState, err := client.buildPeerAttachRequest("peer-explicit") if err != nil { t.Fatalf("buildPeerAttachRequest failed: %v", err) } if !supportsExplicitPeerAttachAuth(req.Features) { t.Fatalf("request features = %d, want explicit auth bit", req.Features) } if !supportsPeerAttachForwardSecrecy(req.Features) { t.Fatalf("request features = %d, want forward secrecy bit", req.Features) } if len(req.ClientNonce) != peerAttachNonceSize { t.Fatalf("client nonce length = %d, want %d", len(req.ClientNonce), peerAttachNonceSize) } auth, err := server.validatePeerAttachRequestAuth(logical, nil, req) if err != nil { t.Fatalf("validatePeerAttachRequestAuth failed: %v", err) } if !auth.explicit || auth.fallback { t.Fatalf("auth result mismatch: %+v", auth) } resp := peerAttachResponse{ PeerID: req.PeerID, Accepted: true, } server.signPeerAttachResponse(logical, req, &resp, auth) if !supportsExplicitPeerAttachAuth(resp.Features) { t.Fatalf("response features = %d, want explicit auth bit", resp.Features) } if len(resp.ServerNonce) != peerAttachNonceSize { t.Fatalf("server nonce length = %d, want %d", len(resp.ServerNonce), peerAttachNonceSize) } verifyResult, err := client.verifyPeerAttachResponse(req, resp, reqState) if err != nil { t.Fatalf("verifyPeerAttachResponse failed: %v", err) } if verifyResult.authFallback { t.Fatal("explicit response should not be marked as fallback") } if !verifyResult.steadyProfile.forwardSecrecyFallback { t.Fatal("response without fs extension should mark forward secrecy fallback") } resp.AuthTag[0] ^= 0xff if verifyResult, err = client.verifyPeerAttachResponse(req, resp, reqState); !errors.Is(err, errPeerAttachAuthInvalid) { t.Fatalf("tampered response error = %v, want %v", err, errPeerAttachAuthInvalid) } else if verifyResult.authFallback { t.Fatal("tampered explicit response should not be treated as fallback") } } func TestPeerAttachRequestAuthRejectsReplay(t *testing.T) { secret := []byte("correct horse battery staple") client := NewClient().(*ClientCommon) server := NewServer().(*ServerCommon) if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil { t.Fatalf("UseModernPSKClient failed: %v", err) } if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil { t.Fatalf("UseModernPSKServer failed: %v", err) } logical := newPeerAttachAuthLogicalForTest(t, server) req, _, err := client.buildPeerAttachRequest("peer-replay") if err != nil { t.Fatalf("buildPeerAttachRequest failed: %v", err) } if _, err := server.validatePeerAttachRequestAuth(logical, nil, req); err != nil { t.Fatalf("first validatePeerAttachRequestAuth failed: %v", err) } if _, err := server.validatePeerAttachRequestAuth(logical, nil, req); !errors.Is(err, errPeerAttachReplayRejected) { t.Fatalf("second validatePeerAttachRequestAuth error = %v, want %v", err, errPeerAttachReplayRejected) } classifyPeerAttachRejectCounter(server, errPeerAttachReplayRejected) if got, want := server.peerAttachReplayRejectCountSnapshot(), int64(1); got != want { t.Fatalf("replay reject count = %d, want %d", got, want) } } func TestPeerAttachAuthFallbackCompatibility(t *testing.T) { secret := []byte("correct horse battery staple") client := NewClient().(*ClientCommon) server := NewServer().(*ServerCommon) if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil { t.Fatalf("UseModernPSKClient failed: %v", err) } if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil { t.Fatalf("UseModernPSKServer failed: %v", err) } logical := newPeerAttachAuthLogicalForTest(t, server) auth, err := server.validatePeerAttachRequestAuth(logical, nil, peerAttachRequest{PeerID: "peer-fallback"}) if err != nil { t.Fatalf("validatePeerAttachRequestAuth fallback failed: %v", err) } if auth.explicit || !auth.fallback { t.Fatalf("fallback auth result mismatch: %+v", auth) } req, reqState, err := client.buildPeerAttachRequest("peer-fallback") if err != nil { t.Fatalf("buildPeerAttachRequest failed: %v", err) } verifyResult, err := client.verifyPeerAttachResponse(req, peerAttachResponse{ PeerID: req.PeerID, Accepted: true, }, reqState) if err != nil { t.Fatalf("verifyPeerAttachResponse fallback failed: %v", err) } if !verifyResult.authFallback { t.Fatal("unsigned legacy response should be marked as fallback") } } func TestPeerAttachForwardSecrecyNegotiatesDerivedSteadyProfile(t *testing.T) { secret := []byte("correct horse battery staple") client := NewClient().(*ClientCommon) server := NewServer().(*ServerCommon) if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil { t.Fatalf("UseModernPSKClient failed: %v", err) } if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil { t.Fatalf("UseModernPSKServer failed: %v", err) } logical := newPeerAttachAuthLogicalForTest(t, server) req, reqState, err := client.buildPeerAttachRequest("peer-fs") if err != nil { t.Fatalf("buildPeerAttachRequest failed: %v", err) } auth, err := server.validatePeerAttachRequestAuth(logical, nil, req) if err != nil { t.Fatalf("validatePeerAttachRequestAuth failed: %v", err) } resp := peerAttachResponse{ PeerID: req.PeerID, Accepted: true, } serverProfile, err := server.preparePeerAttachSteadyTransportProfile(logical, req, &resp, auth) if err != nil { t.Fatalf("preparePeerAttachSteadyTransportProfile failed: %v", err) } server.signPeerAttachResponse(logical, req, &resp, auth) verifyResult, err := client.verifyPeerAttachResponse(req, resp, reqState) if err != nil { t.Fatalf("verifyPeerAttachResponse failed: %v", err) } if !supportsPeerAttachForwardSecrecy(resp.Features) { t.Fatalf("response features = %d, want forward secrecy bit", resp.Features) } if resp.KeyMode != peerAttachKeyModeECDHE { t.Fatalf("response key mode = %q, want %q", resp.KeyMode, peerAttachKeyModeECDHE) } if !verifyResult.steadyProfile.forwardSecrecy { t.Fatal("client steady profile should enable forward secrecy") } if !serverProfile.forwardSecrecy { t.Fatal("server steady profile should enable forward secrecy") } if len(verifyResult.steadyProfile.sessionID) == 0 { t.Fatal("client session id should be populated") } if !bytes.Equal(verifyResult.steadyProfile.secretKey, serverProfile.secretKey) { t.Fatal("client/server derived steady keys should match") } if !bytes.Equal(verifyResult.steadyProfile.sessionID, serverProfile.sessionID) { t.Fatal("client/server session ids should match") } if bytes.Equal(verifyResult.steadyProfile.secretKey, client.securityBootstrap.secretKey) { t.Fatal("derived steady key should differ from bootstrap key") } } func TestPeerAttachForwardSecrecyStrictRejectsFallback(t *testing.T) { secret := []byte("correct horse battery staple") client := NewClient().(*ClientCommon) server := NewServer().(*ServerCommon) opts := testModernPSKOptions() opts.RequireForwardSecrecy = true if err := UseModernPSKClient(client, secret, opts); err != nil { t.Fatalf("UseModernPSKClient failed: %v", err) } if err := UseModernPSKServer(server, secret, opts); err != nil { t.Fatalf("UseModernPSKServer failed: %v", err) } req, reqState, err := client.buildPeerAttachRequest("peer-fs-strict") if err != nil { t.Fatalf("buildPeerAttachRequest failed: %v", err) } _, err = client.verifyPeerAttachResponse(req, peerAttachResponse{ PeerID: req.PeerID, Accepted: true, Features: peerAttachFeatureExplicitAuth, ServerNonce: make([]byte, peerAttachNonceSize), AuthTag: computePeerAttachResponseAuthTag(client.securityBootstrap.secretKey, req, peerAttachResponse{PeerID: req.PeerID, Accepted: true, Features: peerAttachFeatureExplicitAuth, ServerNonce: make([]byte, peerAttachNonceSize)}, nil), }, reqState) if !errors.Is(err, errPeerAttachForwardSecrecyRequired) { t.Fatalf("verifyPeerAttachResponse error = %v, want %v", err, errPeerAttachForwardSecrecyRequired) } }