package notify

import (
	"bytes"
	"errors"
	"net"
	"testing"
	"time"
)

// staticPeerAttachChannelBindingProvider returns a provider that always
// yields the given binding material. The material is copied once when the
// provider is built and again on every call, so neither the caller's slice
// nor a consumer mutating a returned slice can change what later calls see.
func staticPeerAttachChannelBindingProvider(material []byte) PeerAttachChannelBindingProvider {
	cloned := bytes.Clone(material)
	return func(PeerAttachChannelBindingContext) ([]byte, error) {
		return bytes.Clone(cloned), nil
	}
}

// failingPeerAttachChannelBindingProvider is a provider that always returns
// an error, used to exercise the path where binding material is unavailable.
func failingPeerAttachChannelBindingProvider(PeerAttachChannelBindingContext) ([]byte, error) {
	return nil, errors.New("binding unavailable")
}

// TestSetPeerAttachSecurityConfigRejectsMissingChannelBindingProvider checks
// that both client and server refuse a config that sets RequireChannelBinding
// without supplying a ChannelBinding provider.
func TestSetPeerAttachSecurityConfigRejectsMissingChannelBindingProvider(t *testing.T) {
	client := NewClient().(*ClientCommon)
	server := NewServer().(*ServerCommon)
	cfg := PeerAttachSecurityConfig{RequireChannelBinding: true}

	if err := client.SetPeerAttachSecurityConfig(cfg); !errors.Is(err, errPeerAttachChannelBindingProviderNil) {
		t.Fatalf("client SetPeerAttachSecurityConfig error = %v, want %v", err, errPeerAttachChannelBindingProviderNil)
	}
	if err := server.SetPeerAttachSecurityConfig(cfg); !errors.Is(err, errPeerAttachChannelBindingProviderNil) {
		t.Fatalf("server SetPeerAttachSecurityConfig error = %v, want %v", err, errPeerAttachChannelBindingProviderNil)
	}
}

// TestPeerAttachRequireExplicitAuthRejectsFallbackClient checks that, with
// RequireExplicitAuth enabled, an attach request carrying only a PeerID is
// rejected with errPeerAttachExplicitAuthRequired. It also checks that the
// rejection is counted as a downgrade reject rather than an auth fallback.
func TestPeerAttachRequireExplicitAuthRejectsFallbackClient(t *testing.T) {
	secret := []byte("correct horse battery staple")
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
		RequireExplicitAuth: true,
	}); err != nil {
		t.Fatalf("SetPeerAttachSecurityConfig failed: %v", err)
	}

	logical := newPeerAttachAuthLogicalForTest(t, server)
	if _, err := server.validatePeerAttachRequestAuth(logical, nil, peerAttachRequest{PeerID: "peer-fallback"}); !errors.Is(err, errPeerAttachExplicitAuthRequired) {
		t.Fatalf("validatePeerAttachRequestAuth error = %v, want %v", err, errPeerAttachExplicitAuthRequired)
	}
	// validatePeerAttachRequestAuth is called directly here, so the reject is
	// classified by hand (presumably the normal attach path does this itself).
	classifyPeerAttachRejectCounter(server, errPeerAttachExplicitAuthRequired)

	snapshot, snapErr := GetServerRuntimeSnapshot(server)
	if snapErr != nil {
		t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr)
	}
	if got, want := snapshot.PeerAttachDowngradeRejects, int64(1); got != want {
		t.Fatalf("PeerAttachDowngradeRejects = %d, want %d", got, want)
	}
	if got := snapshot.PeerAttachAuthFallbacks; got != 0 {
		t.Fatalf("PeerAttachAuthFallbacks = %d, want 0", got)
	}
}

// TestPeerAttachChannelBindingRoundTrip attaches a client to a running server
// over net.Pipe with PSK-over-external-transport auth and identical static
// channel-binding providers on both sides. It then round-trips a message
// through an "echo" link and checks the runtime snapshots for:
//   - the expected policy flags on client and server;
//   - one explicit auth on the server;
//   - zero rejects of any kind.
func TestPeerAttachChannelBindingRoundTrip(t *testing.T) {
	secret := []byte("correct horse battery staple")
	bindingProvider := staticPeerAttachChannelBindingProvider([]byte("tls-exporter:test"))

	server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
		if err := UsePSKOverExternalTransportServer(server, secret, testModernPSKOptions()); err != nil {
			t.Fatalf("UsePSKOverExternalTransportServer failed: %v", err)
		}
		if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
			RequireChannelBinding: true,
			ChannelBinding:        bindingProvider,
		}); err != nil {
			t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err)
		}
		server.SetLink("echo", func(msg *Message) {
			// Reply errors are not checked here; a missing reply is expected
			// to surface as a SendWait failure below.
			_ = msg.Reply([]byte("ack"))
		})
	})

	client := NewClient().(*ClientCommon)
	if err := UsePSKOverExternalTransportClient(client, secret, testModernPSKOptions()); err != nil {
		t.Fatalf("UsePSKOverExternalTransportClient failed: %v", err)
	}
	if err := client.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
		RequireChannelBinding: true,
		ChannelBinding:        bindingProvider,
	}); err != nil {
		t.Fatalf("client SetPeerAttachSecurityConfig failed: %v", err)
	}

	left, right := net.Pipe()
	defer right.Close()
	bootstrapPeerAttachLogicalForTest(t, server, right)

	if err := client.ConnectByConn(left); err != nil {
		t.Fatalf("ConnectByConn failed: %v", err)
	}
	defer func() {
		client.setByeFromServer(true)
		_ = client.Stop()
	}()

	reply, err := client.SendWait("echo", []byte("ping"), time.Second)
	if err != nil {
		t.Fatalf("SendWait failed: %v", err)
	}
	if got, want := string(reply.Value), "ack"; got != want {
		t.Fatalf("reply = %q, want %q", got, want)
	}

	// NOTE(review): neither config sets RequireExplicitAuth, yet both snapshots
	// are expected to report it; presumably RequireChannelBinding implies
	// explicit auth — confirm against SetPeerAttachSecurityConfig.
	serverSnapshot, err := GetServerRuntimeSnapshot(server)
	if err != nil {
		t.Fatalf("GetServerRuntimeSnapshot failed: %v", err)
	}
	if !serverSnapshot.PeerAttachRequireExplicitAuth || !serverSnapshot.PeerAttachRequireChannelBinding || !serverSnapshot.PeerAttachChannelBindingConfigured {
		t.Fatalf("unexpected server peer attach policy snapshot: %+v", serverSnapshot)
	}
	if got, want := serverSnapshot.PeerAttachExplicitAuth, int64(1); got != want {
		t.Fatalf("PeerAttachExplicitAuth = %d, want %d", got, want)
	}
	if serverSnapshot.PeerAttachAuthRejects != 0 || serverSnapshot.PeerAttachDowngradeRejects != 0 || serverSnapshot.PeerAttachBindingRejects != 0 {
		t.Fatalf("unexpected server reject counters: %+v", serverSnapshot)
	}

	clientSnapshot, err := GetClientRuntimeSnapshot(client)
	if err != nil {
		t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
	}
	if !clientSnapshot.PeerAttachRequireExplicitAuth || !clientSnapshot.PeerAttachRequireChannelBinding || !clientSnapshot.PeerAttachChannelBindingConfigured {
		t.Fatalf("unexpected client peer attach policy snapshot: %+v", clientSnapshot)
	}
}

// TestPeerAttachChannelBindingProviderFailureRejectsAttach checks that an
// attach fails with errPeerAttachChannelBindingUnavailable when the server's
// channel-binding provider returns an error. It also checks that the server
// counts exactly one binding reject.
func TestPeerAttachChannelBindingProviderFailureRejectsAttach(t *testing.T) {
	secret := []byte("correct horse battery staple")
	server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
		if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
			t.Fatalf("UseModernPSKServer failed: %v", err)
		}
		if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
			RequireChannelBinding: true,
			ChannelBinding:        failingPeerAttachChannelBindingProvider,
		}); err != nil {
			t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err)
		}
	})

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
		RequireChannelBinding: true,
		ChannelBinding:        staticPeerAttachChannelBindingProvider([]byte("binding")),
	}); err != nil {
		t.Fatalf("client SetPeerAttachSecurityConfig failed: %v", err)
	}

	left, right := net.Pipe()
	// The attach is expected to fail, so nothing else is guaranteed to close
	// the client end; net.Pipe's Close is idempotent, so this is safe even if
	// the client already closed it.
	defer left.Close()
	defer right.Close()
	bootstrapPeerAttachLogicalForTest(t, server, right)

	err := client.ConnectByConn(left)
	if err == nil {
		// Tear the unexpectedly connected client down before failing so it
		// does not outlive the test.
		client.setByeFromServer(true)
		_ = client.Stop()
		t.Fatalf("ConnectByConn succeeded, want error %v", errPeerAttachChannelBindingUnavailable)
	}
	// The client may receive a reconstructed error that only carries the
	// sentinel's message rather than the sentinel itself, so accept either.
	if !errors.Is(err, errPeerAttachChannelBindingUnavailable) && err.Error() != errPeerAttachChannelBindingUnavailable.Error() {
		t.Fatalf("ConnectByConn error = %v, want %v", err, errPeerAttachChannelBindingUnavailable)
	}

	snapshot, snapErr := GetServerRuntimeSnapshot(server)
	if snapErr != nil {
		t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr)
	}
	if got, want := snapshot.PeerAttachBindingRejects, int64(1); got != want {
		t.Fatalf("PeerAttachBindingRejects = %d, want %d", got, want)
	}
}

// TestPeerAttachReplayCapacityRejectsOverflow checks the replay window with
// ReplayCapacity set to 1:
//   - the first valid attach request is accepted;
//   - a second, distinct request is rejected with errPeerAttachReplayWindowFull;
//   - the rejection is reflected in the runtime snapshot.
func TestPeerAttachReplayCapacityRejectsOverflow(t *testing.T) {
	secret := []byte("correct horse battery staple")
	client := NewClient().(*ClientCommon)
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKClient(client, secret, testModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := UseModernPSKServer(server, secret, testModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	if err := server.SetPeerAttachSecurityConfig(PeerAttachSecurityConfig{
		ReplayCapacity: 1,
	}); err != nil {
		t.Fatalf("server SetPeerAttachSecurityConfig failed: %v", err)
	}

	logical := newPeerAttachAuthLogicalForTest(t, server)
	first, _, err := client.buildPeerAttachRequest("peer-one")
	if err != nil {
		t.Fatalf("buildPeerAttachRequest(first) failed: %v", err)
	}
	second, _, err := client.buildPeerAttachRequest("peer-two")
	if err != nil {
		t.Fatalf("buildPeerAttachRequest(second) failed: %v", err)
	}

	if _, err := server.validatePeerAttachRequestAuth(logical, nil, first); err != nil {
		t.Fatalf("validatePeerAttachRequestAuth(first) failed: %v", err)
	}
	if _, err := server.validatePeerAttachRequestAuth(logical, nil, second); !errors.Is(err, errPeerAttachReplayWindowFull) {
		t.Fatalf("validatePeerAttachRequestAuth(second) error = %v, want %v", err, errPeerAttachReplayWindowFull)
	}
	// validatePeerAttachRequestAuth is called directly here, so the reject is
	// classified by hand before reading the counters.
	classifyPeerAttachRejectCounter(server, errPeerAttachReplayWindowFull)

	snapshot, snapErr := GetServerRuntimeSnapshot(server)
	if snapErr != nil {
		t.Fatalf("GetServerRuntimeSnapshot failed: %v", snapErr)
	}
	if got, want := snapshot.PeerAttachReplayCapacity, 1; got != want {
		t.Fatalf("PeerAttachReplayCapacity = %d, want %d", got, want)
	}
	if got, want := snapshot.PeerAttachReplayOverflowRejects, int64(1); got != want {
		t.Fatalf("PeerAttachReplayOverflowRejects = %d, want %d", got, want)
	}
}