- 引入 LogicalConn/TransportConn 分层,ClientConn 保留兼容适配层 - 新增 Stream、Bulk、RecordStream 三条数据面能力及对应控制路径 - 完成 transfer/file 传输内核与状态快照、诊断能力 - 补齐 reconnect、inbound dispatcher、modern psk 等基础模块 - 增加大规模回归、并发与基准测试覆盖 - 更新依赖库
545 lines
15 KiB
Go
545 lines
15 KiB
Go
package notify
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
var errInMemoryListenerClosed = errors.New("in-memory listener closed")
|
|
|
|
type inMemoryListener struct {
|
|
closed chan struct{}
|
|
once sync.Once
|
|
}
|
|
|
|
func newInMemoryListener() *inMemoryListener {
|
|
return &inMemoryListener{
|
|
closed: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
func (l *inMemoryListener) Accept() (net.Conn, error) {
|
|
<-l.closed
|
|
return nil, errInMemoryListenerClosed
|
|
}
|
|
|
|
func (l *inMemoryListener) Close() error {
|
|
l.once.Do(func() {
|
|
close(l.closed)
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (l *inMemoryListener) Addr() net.Addr {
|
|
return inMemoryAddr("in-memory-listener")
|
|
}
|
|
|
|
type inMemoryAddr string
|
|
|
|
func (a inMemoryAddr) Network() string { return "in-memory" }
|
|
func (a inMemoryAddr) String() string { return string(a) }
|
|
|
|
func TestConnectByConnRequiresModernPSK(t *testing.T) {
|
|
client := NewClient()
|
|
left, right := net.Pipe()
|
|
defer left.Close()
|
|
defer right.Close()
|
|
|
|
err := client.ConnectByConn(left)
|
|
if !errors.Is(err, errModernPSKRequired) {
|
|
t.Fatalf("ConnectByConn error = %v, want %v", err, errModernPSKRequired)
|
|
}
|
|
}
|
|
|
|
// TestConnectByConnWithConfiguredSecurity verifies that ConnectByConn
// succeeds when both sides share the same secret key, using a test server
// attached to the far end of a net.Pipe.
func TestConnectByConnWithConfiguredSecurity(t *testing.T) {
	client := NewClient().(*ClientCommon)
	secret := []byte("0123456789abcdef0123456789abcdef")
	left, right := net.Pipe()
	defer right.Close()

	// Stand up a running server and hand it the peer side of the pipe so the
	// client's connect handshake has a real counterpart.
	server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
		server.SetSecretKey(secret)
	})
	bootstrapPeerAttachConnForTest(t, server, right)

	client.SetSecretKey(secret)
	if err := client.ConnectByConn(left); err != nil {
		t.Fatalf("ConnectByConn failed: %v", err)
	}

	// Mark the session as already closed by the server so Stop does not
	// attempt a goodbye exchange over the pipe during teardown.
	client.setByeFromServer(true)
	if err := client.Stop(); err != nil {
		t.Fatalf("Stop failed: %v", err)
	}
}
|
|
|
|
func TestConnectByFactoryRequiresModernPSK(t *testing.T) {
|
|
client := NewClient()
|
|
called := false
|
|
|
|
err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
|
|
called = true
|
|
left, right := net.Pipe()
|
|
_ = right.Close()
|
|
return left, nil
|
|
})
|
|
if !errors.Is(err, errModernPSKRequired) {
|
|
t.Fatalf("ConnectByFactory error = %v, want %v", err, errModernPSKRequired)
|
|
}
|
|
if called {
|
|
t.Fatal("dialFn should not be called before security validation passes")
|
|
}
|
|
}
|
|
|
|
func TestConnectByFactoryRejectsNilDialFn(t *testing.T) {
|
|
client := NewClient().(*ClientCommon)
|
|
client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
|
|
|
|
err := client.ConnectByFactory(context.Background(), nil)
|
|
if err == nil || err.Error() != "dialFn is nil" {
|
|
t.Fatalf("ConnectByFactory nil dialFn error = %v, want dialFn is nil", err)
|
|
}
|
|
}
|
|
|
|
func TestConnectByFactoryPropagatesDialError(t *testing.T) {
|
|
client := NewClient().(*ClientCommon)
|
|
client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
|
|
wantErr := errors.New("dial failed")
|
|
|
|
err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
|
|
return nil, wantErr
|
|
})
|
|
if !errors.Is(err, wantErr) {
|
|
t.Fatalf("ConnectByFactory error = %v, want %v", err, wantErr)
|
|
}
|
|
}
|
|
|
|
// TestConnectByFactoryWithConfiguredSecurity verifies that ConnectByFactory
// succeeds with a shared secret, and that a nil caller context is normalized
// to a non-nil context before the dial function runs.
func TestConnectByFactoryWithConfiguredSecurity(t *testing.T) {
	client := NewClient().(*ClientCommon)
	secret := []byte("0123456789abcdef0123456789abcdef")
	left, right := net.Pipe()
	defer right.Close()

	// Attach a running server to the peer side of the pipe so the client's
	// connect handshake has a counterpart.
	server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
		server.SetSecretKey(secret)
	})
	bootstrapPeerAttachConnForTest(t, server, right)

	client.SetSecretKey(secret)
	// Deliberately pass a nil context: the dial function asserts it receives
	// a normalized, non-nil context.
	if err := client.ConnectByFactory(nil, func(ctx context.Context) (net.Conn, error) {
		if ctx == nil {
			t.Fatal("ConnectByFactory should normalize nil context")
		}
		return left, nil
	}); err != nil {
		t.Fatalf("ConnectByFactory failed: %v", err)
	}

	// Mark the session as already closed by the server so Stop does not
	// attempt a goodbye exchange over the pipe during teardown.
	client.setByeFromServer(true)
	if err := client.Stop(); err != nil {
		t.Fatalf("Stop failed: %v", err)
	}
}
|
|
|
|
// TestConnectByFactoryRejectsConcurrentStart verifies three things in
// sequence: a second ConnectByFactory issued while the first is still
// dialing fails with "client already run" without invoking its dial
// function; canceling the first attempt surfaces context.Canceled; and
// after that rollback a fresh ConnectByFactory reaches its dial function
// again.
func TestConnectByFactoryRejectsConcurrentStart(t *testing.T) {
	client := NewClient().(*ClientCommon)
	client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Buffered so the dial goroutine never blocks on the signal send.
	firstDialEntered := make(chan struct{}, 1)
	firstDone := make(chan error, 1)

	// First attempt: park inside dialFn until the context is canceled, so the
	// client stays in its "starting" state while we probe it.
	go func() {
		firstDone <- client.ConnectByFactory(ctx, func(ctx context.Context) (net.Conn, error) {
			firstDialEntered <- struct{}{}
			<-ctx.Done()
			return nil, ctx.Err()
		})
	}()

	// Wait until the first attempt is provably inside dialFn.
	select {
	case <-firstDialEntered:
	case <-time.After(time.Second):
		t.Fatal("first connect attempt did not enter dialFn")
	}

	// Second attempt must be rejected up front; its dialFn must never run.
	secondDialCalled := false
	err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
		secondDialCalled = true
		return nil, errors.New("second dial should not run")
	})
	if err == nil || err.Error() != "client already run" {
		t.Fatalf("concurrent ConnectByFactory error = %v, want client already run", err)
	}
	if secondDialCalled {
		t.Fatal("second dialFn should not be called during first connect start")
	}

	// Cancel the first attempt and confirm it unwinds with context.Canceled.
	cancel()
	select {
	case err = <-firstDone:
	case <-time.After(time.Second):
		t.Fatal("first ConnectByFactory did not finish after cancel")
	}
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("first ConnectByFactory error = %v, want %v", err, context.Canceled)
	}

	// After rollback the client must accept a new attempt: the dial error
	// propagating proves dialFn was reached this time.
	wantErr := errors.New("dial after rollback")
	err = client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
		return nil, wantErr
	})
	if !errors.Is(err, wantErr) {
		t.Fatalf("ConnectByFactory after rollback error = %v, want %v", err, wantErr)
	}
}
|
|
|
|
// TestConnectByConnReattachesDetachedAliveSession verifies that after the
// transport is detached from a live session, a second ConnectByConn
// reattaches the new conn to the SAME session runtime (same epoch, queue,
// and stopCtx) and that inbound messages flow over the new conn.
func TestConnectByConnReattachesDetachedAliveSession(t *testing.T) {
	client := NewClient().(*ClientCommon)
	secret := []byte("0123456789abcdef0123456789abcdef")
	client.SetSecretKey(secret)
	server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
		server.SetSecretKey(secret)
	})

	// First connection: establish the session.
	firstLeft, firstRight := net.Pipe()
	defer firstRight.Close()
	bootstrapPeerAttachConnForTest(t, server, firstRight)
	if err := client.ConnectByConn(firstLeft); err != nil {
		t.Fatalf("initial ConnectByConn failed: %v", err)
	}
	before := client.clientSessionRuntimeSnapshot()
	if before == nil {
		t.Fatal("runtime should exist after initial connect")
	}
	// Capture session identity so we can later prove reattach reused it.
	initialEpoch := before.epoch
	initialStopCtx := before.stopCtx
	initialQueue := before.queue

	// Detach the transport while the session itself stays alive.
	client.clearClientSessionRuntimeTransport()

	// Register a handler so we can observe dispatch over the new transport.
	recvCh := make(chan Message, 1)
	client.SetLink("reattach-public", func(message *Message) {
		recvCh <- *message
	})

	// Second connection: should reattach rather than start a new session.
	secondLeft, secondRight := net.Pipe()
	defer secondRight.Close()
	bootstrapPeerAttachConnForTest(t, server, secondRight)
	if err := client.ConnectByConn(secondLeft); err != nil {
		t.Fatalf("reattach ConnectByConn failed: %v", err)
	}

	// The runtime must carry the new conn but keep the original session
	// identity (queue, stopCtx, epoch) and report the transport attached.
	after := client.clientSessionRuntimeSnapshot()
	if after == nil {
		t.Fatal("runtime should exist after reattach")
	}
	if after.conn != secondLeft || after.queue != initialQueue || after.stopCtx != initialStopCtx || after.epoch != initialEpoch || !after.transportAttached {
		t.Fatalf("reattached runtime mismatch: %+v", after)
	}

	// Hand-craft a wire-level message and push it through the server side of
	// the new pipe to prove the reattached conn is actually being read.
	env, err := wrapTransferMsgEnvelope(TransferMsg{
		ID:    88,
		Key:   "reattach-public",
		Value: []byte("ok"),
		Type:  MSG_ASYNC,
	}, client.sequenceEn)
	if err != nil {
		t.Fatalf("wrapTransferMsgEnvelope failed: %v", err)
	}
	wire, err := client.encodeEnvelope(env)
	if err != nil {
		t.Fatalf("encodeEnvelope failed: %v", err)
	}
	if _, err := secondRight.Write(wire); err != nil {
		t.Fatalf("reattached conn write failed: %v", err)
	}

	// The handler registered above must receive exactly the injected message.
	select {
	case msg := <-recvCh:
		if got, want := msg.Key, "reattach-public"; got != want {
			t.Fatalf("message key mismatch: got %q want %q", got, want)
		}
		if got, want := string(msg.Value), "ok"; got != want {
			t.Fatalf("message value mismatch: got %q want %q", got, want)
		}
	case <-time.After(time.Second):
		t.Fatal("reattached public conn did not dispatch message")
	}

	// Skip the goodbye exchange during teardown.
	client.setByeFromServer(true)
	if err := client.Stop(); err != nil {
		t.Fatalf("final Stop failed: %v", err)
	}
}
|
|
|
|
// TestConnectByFactoryReattachesDetachedAliveSessionAndUpdatesSource
// verifies that ConnectByFactory can reattach a detached-but-alive session
// (dialing exactly once, keeping the session epoch) and that the runtime
// snapshot afterwards records the factory as the connect source and remains
// reconnectable.
func TestConnectByFactoryReattachesDetachedAliveSessionAndUpdatesSource(t *testing.T) {
	client := NewClient().(*ClientCommon)
	secret := []byte("0123456789abcdef0123456789abcdef")
	client.SetSecretKey(secret)
	server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
		server.SetSecretKey(secret)
	})

	// First connection: establish the session via ConnectByConn.
	firstLeft, firstRight := net.Pipe()
	defer firstRight.Close()
	bootstrapPeerAttachConnForTest(t, server, firstRight)
	if err := client.ConnectByConn(firstLeft); err != nil {
		t.Fatalf("initial ConnectByConn failed: %v", err)
	}
	before := client.clientSessionRuntimeSnapshot()
	if before == nil {
		t.Fatal("runtime should exist after initial connect")
	}
	initialEpoch := before.epoch

	// Detach the transport while the session itself stays alive.
	client.clearClientSessionRuntimeTransport()

	// Reattach through the factory path; count invocations to prove the dial
	// function runs exactly once.
	var dialCount atomic.Int32
	secondLeft, secondRight := net.Pipe()
	defer secondRight.Close()
	bootstrapPeerAttachConnForTest(t, server, secondRight)
	if err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
		dialCount.Add(1)
		return secondLeft, nil
	}); err != nil {
		t.Fatalf("reattach ConnectByFactory failed: %v", err)
	}
	if got, want := dialCount.Load(), int32(1); got != want {
		t.Fatalf("dial count mismatch: got %d want %d", got, want)
	}
	// Same epoch => same session; new conn must be attached.
	after := client.clientSessionRuntimeSnapshot()
	if after == nil {
		t.Fatal("runtime should exist after factory reattach")
	}
	if after.epoch != initialEpoch || after.conn != secondLeft || !after.transportAttached {
		t.Fatalf("reattached runtime mismatch: %+v", after)
	}
	// The public snapshot must reflect the factory connect source and allow
	// future reconnects.
	snapshot, err := GetClientRuntimeSnapshot(client)
	if err != nil {
		t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
	}
	if got, want := snapshot.ConnectSource, clientConnectSourceFactory; got != want {
		t.Fatalf("connect source mismatch: got %q want %q", got, want)
	}
	if !snapshot.CanReconnect {
		t.Fatalf("snapshot should be reconnectable after factory reattach: %+v", snapshot)
	}

	// Skip the goodbye exchange during teardown.
	client.setByeFromServer(true)
	if err := client.Stop(); err != nil {
		t.Fatalf("final Stop failed: %v", err)
	}
}
|
|
|
|
// TestConnectByConnFailureCleansRuntimeAndAllowsRetry verifies that a
// failed key exchange during ConnectByConn leaves the client dead with the
// failure recorded (but with its stop-monitor channel still open), and
// that a subsequent ConnectByConn attempt can succeed from that cleaned
// state.
func TestConnectByConnFailureCleansRuntimeAndAllowsRetry(t *testing.T) {
	client := NewClient().(*ClientCommon)
	UseLegacySecurityClient(client)
	// Force the key-exchange step to fail deterministically.
	failErr := errors.New("key exchange fail for test")
	client.keyExchangeFn = func(Client) error {
		return failErr
	}

	// First attempt: must surface the injected key-exchange error.
	left1, right1 := net.Pipe()
	defer right1.Close()
	err := client.ConnectByConn(left1)
	if !errors.Is(err, failErr) {
		t.Fatalf("ConnectByConn first error = %v, want %v", err, failErr)
	}
	status := client.Status()
	if status.Alive || status.Reason != "key exchange failed" || !errors.Is(status.Err, failErr) {
		t.Fatalf("unexpected status after failed key exchange: %+v", status)
	}
	// Cleanup after a failed connect must NOT close the stop-monitor channel;
	// a short timeout is enough to observe that no close signal arrives.
	select {
	case <-client.StopMonitorChan():
		t.Fatal("StopMonitorChan should remain open after failed connect cleanup")
	case <-time.After(20 * time.Millisecond):
	}

	// Second attempt: bypass the (still failing) key exchange and connect to
	// a real test server; this must succeed from the cleaned state.
	client.SetSkipExchangeKey(true)
	left2, right2 := net.Pipe()
	defer right2.Close()
	server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
		UseLegacySecurityServer(server)
	})
	bootstrapPeerAttachConnForTest(t, server, right2)
	if err := client.ConnectByConn(left2); err != nil {
		t.Fatalf("ConnectByConn second attempt failed: %v", err)
	}
	if !client.Status().Alive {
		t.Fatalf("client should be alive after second ConnectByConn: %+v", client.Status())
	}
	// Skip the goodbye exchange during teardown.
	client.setByeFromServer(true)
	if err := client.Stop(); err != nil {
		t.Fatalf("Stop failed: %v", err)
	}
}
|
|
|
|
func TestListenByListenerRequiresModernPSK(t *testing.T) {
|
|
server := NewServer()
|
|
listener := newInMemoryListener()
|
|
defer listener.Close()
|
|
|
|
err := server.ListenByListener(listener)
|
|
if !errors.Is(err, errModernPSKRequired) {
|
|
t.Fatalf("ListenByListener error = %v, want %v", err, errModernPSKRequired)
|
|
}
|
|
}
|
|
|
|
func TestListenByListenerWithConfiguredSecurity(t *testing.T) {
|
|
server := NewServer().(*ServerCommon)
|
|
listener := newInMemoryListener()
|
|
defer listener.Close()
|
|
|
|
server.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
|
|
if err := server.ListenByListener(listener); err != nil {
|
|
t.Fatalf("ListenByListener failed: %v", err)
|
|
}
|
|
if !server.Status().Alive {
|
|
t.Fatal("server should be alive after ListenByListener")
|
|
}
|
|
if err := server.Stop(); err != nil {
|
|
t.Fatalf("Stop failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestListenByListenerRejectsNil(t *testing.T) {
|
|
server := NewServer().(*ServerCommon)
|
|
server.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
|
|
err := server.ListenByListener(nil)
|
|
if err == nil || err.Error() != "listener is nil" {
|
|
t.Fatalf("ListenByListener nil error = %v, want listener is nil", err)
|
|
}
|
|
}
|
|
|
|
// TestClientReadMessagePreservesUserStopReason verifies that when the user
// calls Stop while readMessage is blocked, the loop exits and the status
// records the user-stop reason with no error — even though the peer conn
// is also closed afterwards.
func TestClientReadMessagePreservesUserStopReason(t *testing.T) {
	client := NewClient().(*ClientCommon)
	left, right := net.Pipe()
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()

	// Wire a minimal session by hand so readMessage can run without a full
	// connect handshake.
	client.conn = left
	client.stopCtx = stopCtx
	client.stopFn = stopFn
	client.markSessionStarted()

	// Run the read loop in the background; done closes when it returns.
	done := make(chan struct{})
	go func() {
		client.readMessage()
		close(done)
	}()

	// Stop first, then close the peer: the user-stop reason must win over
	// any read error caused by the closing pipe.
	if err := client.Stop(); err != nil {
		t.Fatalf("Stop failed: %v", err)
	}
	_ = right.Close()

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("readMessage should exit after user stop")
	}

	status := client.Status()
	if status.Alive || status.Reason != "recv stop signal from user" || status.Err != nil {
		t.Fatalf("unexpected status after user stop: %+v", status)
	}
}
|
|
|
|
// TestClientReadMessagePreservesServerStopReason verifies that a
// server-initiated session stop makes readMessage exit and leaves the
// server-stop reason (with no error) in the status, even though the peer
// conn is closed right afterwards.
func TestClientReadMessagePreservesServerStopReason(t *testing.T) {
	client := NewClient().(*ClientCommon)
	left, right := net.Pipe()
	stopCtx, stopFn := context.WithCancel(context.Background())
	defer stopFn()

	// Wire a minimal session by hand so readMessage can run without a full
	// connect handshake.
	client.conn = left
	client.stopCtx = stopCtx
	client.stopFn = stopFn
	client.markSessionStarted()

	// Run the read loop in the background; done closes when it returns.
	done := make(chan struct{})
	go func() {
		client.readMessage()
		close(done)
	}()

	// Stop from the server side first, then close the peer: the server-stop
	// reason must win over any read error caused by the closing pipe.
	client.stopClientSessionFromServer("recv stop signal from server", nil)
	_ = right.Close()

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("readMessage should exit after server stop")
	}

	status := client.Status()
	if status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil {
		t.Fatalf("unexpected status after server stop: %+v", status)
	}
}
|
|
|
|
func TestClientStopClientSessionFromServerDisablesGoodBye(t *testing.T) {
|
|
client := NewClient().(*ClientCommon)
|
|
client.markSessionStarted()
|
|
|
|
client.stopClientSessionFromServer("recv stop signal from server", nil)
|
|
|
|
if client.shouldSayGoodByeOnStop() {
|
|
t.Fatal("server stop should disable goodbye on stop")
|
|
}
|
|
status := client.Status()
|
|
if status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil {
|
|
t.Fatalf("unexpected status after server stop helper: %+v", status)
|
|
}
|
|
}
|
|
|
|
func TestClientStopClientSessionKeepsGoodByeEnabled(t *testing.T) {
|
|
client := NewClient().(*ClientCommon)
|
|
client.markSessionStarted()
|
|
|
|
client.stopClientSession("recv stop signal from user", nil)
|
|
|
|
if !client.shouldSayGoodByeOnStop() {
|
|
t.Fatal("local stop should keep goodbye enabled")
|
|
}
|
|
status := client.Status()
|
|
if status.Alive || status.Reason != "recv stop signal from user" || status.Err != nil {
|
|
t.Fatalf("unexpected status after local stop helper: %+v", status)
|
|
}
|
|
}
|
|
|
|
// TestClientReadMessageLoopUsesProvidedStopCtx verifies that
// readMessageLoop honors the ctx passed as its argument (here already
// canceled) rather than the client's stored stopCtx, and that the loop's
// conn is closed when the loop exits.
func TestClientReadMessageLoopUsesProvidedStopCtx(t *testing.T) {
	client := NewClient().(*ClientCommon)
	left, right := net.Pipe()
	defer right.Close()

	// Cancel the loop's context up front so the loop must exit immediately.
	loopCtx, loopCancel := context.WithCancel(context.Background())
	loopCancel()

	// Give the client a never-canceled stopCtx and no stored conn: if the
	// loop consulted client state instead of its arguments, it would hang.
	client.stopCtx = context.Background()
	client.conn = nil

	done := make(chan struct{})
	go func() {
		client.readMessageLoop(loopCtx, left, nil, 1)
		close(done)
	}()

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("readMessageLoop should exit when provided stopCtx is canceled")
	}

	// A write on the peer end must fail, proving the loop closed its conn.
	if _, err := right.Write([]byte("x")); err == nil {
		t.Fatal("peer conn should be closed when loop exits")
	}
}
|