package notify

import (
	"context"
	"errors"
	"testing"
	"time"
)

// TestBindTransferControlValidation checks the argument validation of the
// transfer-control bind and send helpers. A nil client or server, an empty
// TransferControlHandler, and a nil logical connection must each be rejected
// with the matching sentinel error (compared via errors.Is).
func TestBindTransferControlValidation(t *testing.T) {
	if err := BindTransferControlClient(nil, TransferControlHandler{}); !errors.Is(err, errTransferControlClientNil) {
		t.Fatalf("BindTransferControlClient nil error mismatch: %v", err)
	}
	if err := BindTransferControlServer(nil, TransferControlHandler{}); !errors.Is(err, errTransferControlServerNil) {
		t.Fatalf("BindTransferControlServer nil error mismatch: %v", err)
	}

	client := NewClient()
	if err := BindTransferControlClient(client, TransferControlHandler{}); !errors.Is(err, errTransferControlHandlerEmpty) {
		t.Fatalf("BindTransferControlClient empty handler error mismatch: %v", err)
	}
	server := NewServer()
	if err := BindTransferControlServer(server, TransferControlHandler{}); !errors.Is(err, errTransferControlHandlerEmpty) {
		t.Fatalf("BindTransferControlServer empty handler error mismatch: %v", err)
	}

	if _, err := SendTransferBeginClient(context.Background(), nil, TransferBeginRequest{}); !errors.Is(err, errTransferControlClientNil) {
		t.Fatalf("SendTransferBeginClient nil client error mismatch: %v", err)
	}
	if _, err := SendTransferBeginServer(context.Background(), nil, nil, TransferBeginRequest{}); !errors.Is(err, errTransferControlServerNil) {
		t.Fatalf("SendTransferBeginServer nil server error mismatch: %v", err)
	}
	// A valid server but a nil logical connection must be rejected separately.
	if _, err := SendTransferBeginServer(context.Background(), server, nil, TransferBeginRequest{}); !errors.Is(err, errTransferControlLogicalConnNil) {
		t.Fatalf("SendTransferBeginServer nil conn error mismatch: %v", err)
	}
}

// TestTransferControlRoundTripTCP drives transfer-control Begin/Commit
// exchanges in both directions over a real TCP connection set up with the
// modern PSK configuration. It checks the responses each sender got back, the
// requests each receiving handler observed, and the transfer snapshots that
// both peers recorded afterwards.
//
// Phase 1: the client begins and commits "tx-client" (client sends, server
// receives). Phase 2: the server begins and commits "tx-server" over the
// accepted connection (server sends, client receives).
func TestTransferControlRoundTripTCP(t *testing.T) {
	// Upper bound for every wait on a handler-side observation or connection.
	const waitTimeout = 2 * time.Second

	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	// Clear the built-in transfer-control state; presumably so that only the
	// handler bound below answers Begin/Commit requests.
	disableBuiltinTransferControlForServer(server)

	// Both peers' handlers publish the requests they receive into these shared
	// channels; the assertions below tell the directions apart by TransferID.
	beginReqCh := make(chan TransferBeginRequest, 1)
	commitReqCh := make(chan TransferCommitRequest, 2)
	if err := BindTransferControlServer(server, TransferControlHandler{
		Begin: func(_ *Message, req TransferBeginRequest) (TransferBeginResponse, error) {
			beginReqCh <- req
			return TransferBeginResponse{
				TransferID: req.TransferID,
				Accepted:   true,
				NextOffset: 512,
				Missing: []TransferRange{
					{Offset: 768, Length: 128},
				},
			}, nil
		},
		Commit: func(_ *Message, req TransferCommitRequest) (TransferCommitResponse, error) {
			commitReqCh <- req
			return TransferCommitResponse{
				TransferID: req.TransferID,
				Accepted:   true,
			}, nil
		},
	}); err != nil {
		t.Fatalf("BindTransferControlServer failed: %v", err)
	}
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	// Best-effort teardown: shutdown errors are irrelevant to the assertions.
	defer func() { _ = server.Stop() }()
	// Port 0 was requested, so read back the port actually bound.
	addr := server.listener.Addr().String()

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	disableBuiltinTransferControlForClient(client)
	if err := BindTransferControlClient(client, TransferControlHandler{
		Begin: func(_ *Message, req TransferBeginRequest) (TransferBeginResponse, error) {
			beginReqCh <- req
			return TransferBeginResponse{
				TransferID: req.TransferID,
				Accepted:   true,
				NextOffset: 256,
			}, nil
		},
		Commit: func(_ *Message, req TransferCommitRequest) (TransferCommitResponse, error) {
			commitReqCh <- req
			return TransferCommitResponse{
				TransferID: req.TransferID,
				Accepted:   true,
			}, nil
		},
	}); err != nil {
		t.Fatalf("BindTransferControlClient failed: %v", err)
	}
	if err := client.Connect("tcp", addr); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	// Best-effort teardown, as for the server above.
	defer func() { _ = client.Stop() }()

	// Phase 1: client -> server transfer "tx-client".
	beginResp, err := SendTransferBeginClient(context.Background(), client, TransferBeginRequest{
		TransferID: "tx-client",
		Channel:    TransferChannelData,
		Size:       1024,
		Checksum:   "sha256:demo",
		Metadata: map[string]string{
			"name": "demo.bin",
		},
	})
	if err != nil {
		t.Fatalf("SendTransferBeginClient failed: %v", err)
	}
	if !beginResp.Accepted || beginResp.TransferID != "tx-client" || beginResp.NextOffset != 512 {
		t.Fatalf("begin response mismatch: %+v", beginResp)
	}
	if len(beginResp.Missing) != 1 || beginResp.Missing[0].Offset != 768 || beginResp.Missing[0].Length != 128 {
		t.Fatalf("begin response missing mismatch: %+v", beginResp.Missing)
	}
	select {
	case got := <-beginReqCh:
		if got.TransferID != "tx-client" || got.Channel != TransferChannelData || got.Size != 1024 || got.Checksum != "sha256:demo" {
			t.Fatalf("begin request mismatch: %+v", got)
		}
		if got.Metadata["name"] != "demo.bin" {
			t.Fatalf("begin request metadata mismatch: %+v", got.Metadata)
		}
	case <-time.After(waitTimeout):
		t.Fatal("timed out waiting for begin request")
	}

	commitClientResp, err := SendTransferCommitClient(context.Background(), client, TransferCommitRequest{
		TransferID: "tx-client",
		Size:       1024,
		Checksum:   "sha256:demo",
	})
	if err != nil {
		t.Fatalf("SendTransferCommitClient failed: %v", err)
	}
	if !commitClientResp.Accepted || commitClientResp.TransferID != "tx-client" {
		t.Fatalf("client commit response mismatch: %+v", commitClientResp)
	}
	select {
	case got := <-commitReqCh:
		if got.TransferID != "tx-client" || got.Size != 1024 || got.Checksum != "sha256:demo" {
			t.Fatalf("client commit request mismatch: %+v", got)
		}
	case <-time.After(waitTimeout):
		t.Fatal("timed out waiting for client commit request")
	}

	// Resolve the server-side view of the client's connection; the server
	// snapshots are scoped to it and phase 2 sends over it.
	conn := waitForTransferControlClientConn(t, server, waitTimeout)
	serverScope := serverFileScope(conn)

	// The expected byte counts below equal the NextOffset the receiving
	// peer's Begin handler returned (512 here, 256 in phase 2).
	clientSnapshot, ok, err := GetClientTransferSnapshotByID(client, "tx-client")
	if err != nil {
		t.Fatalf("GetClientTransferSnapshotByID failed: %v", err)
	}
	if !ok {
		t.Fatal("client snapshot should exist")
	}
	if got, want := clientSnapshot.Direction, TransferDirectionSend; got != want {
		t.Fatalf("client snapshot direction = %v, want %v", got, want)
	}
	if got, want := clientSnapshot.Scope, clientFileScope(); got != want {
		t.Fatalf("client snapshot scope = %q, want %q", got, want)
	}
	if got, want := clientSnapshot.State, TransferStateDone; got != want {
		t.Fatalf("client snapshot state = %v, want %v", got, want)
	}
	if got, want := clientSnapshot.AckedBytes, int64(512); got != want {
		t.Fatalf("client snapshot acked bytes = %d, want %d", got, want)
	}
	if got, want := clientSnapshot.Stage, "commit"; got != want {
		t.Fatalf("client snapshot stage = %q, want %q", got, want)
	}

	serverSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-client")
	if err != nil {
		t.Fatalf("GetServerTransferSnapshotByID failed: %v", err)
	}
	if !ok {
		t.Fatal("server snapshot should exist")
	}
	if got, want := serverSnapshot.Direction, TransferDirectionReceive; got != want {
		t.Fatalf("server snapshot direction = %v, want %v", got, want)
	}
	if got, want := serverSnapshot.Scope, serverScope; got != want {
		t.Fatalf("server snapshot scope = %q, want %q", got, want)
	}
	if got, want := serverSnapshot.RuntimeScope, serverTransportScope(conn); got != want {
		t.Fatalf("server snapshot runtime scope = %q, want %q", got, want)
	}
	// Presumably generation 1 is the first (and only) transport of this connection.
	if got, want := serverSnapshot.TransportGeneration, uint64(1); got != want {
		t.Fatalf("server snapshot transport generation = %d, want %d", got, want)
	}
	if got, want := serverSnapshot.State, TransferStateDone; got != want {
		t.Fatalf("server snapshot state = %v, want %v", got, want)
	}
	if got, want := serverSnapshot.ReceivedBytes, int64(512); got != want {
		t.Fatalf("server snapshot received bytes = %d, want %d", got, want)
	}
	if got, want := serverSnapshot.Stage, "commit"; got != want {
		t.Fatalf("server snapshot stage = %q, want %q", got, want)
	}

	// Phase 2: server -> client transfer "tx-server" over the accepted connection.
	beginServerResp, err := SendTransferBeginServer(context.Background(), server, conn, TransferBeginRequest{
		TransferID: "tx-server",
		Channel:    TransferChannelControl,
		Size:       512,
		Checksum:   "sha256:server",
	})
	if err != nil {
		t.Fatalf("SendTransferBeginServer failed: %v", err)
	}
	if !beginServerResp.Accepted || beginServerResp.TransferID != "tx-server" || beginServerResp.NextOffset != 256 {
		t.Fatalf("server begin response mismatch: %+v", beginServerResp)
	}
	select {
	case got := <-beginReqCh:
		if got.TransferID != "tx-server" || got.Channel != TransferChannelControl || got.Size != 512 || got.Checksum != "sha256:server" {
			t.Fatalf("server begin request mismatch: %+v", got)
		}
	case <-time.After(waitTimeout):
		t.Fatal("timed out waiting for server begin request")
	}

	commitResp, err := SendTransferCommitServer(context.Background(), server, conn, TransferCommitRequest{
		TransferID: "tx-server",
		Size:       512,
		Checksum:   "sha256:server",
	})
	if err != nil {
		t.Fatalf("SendTransferCommitServer failed: %v", err)
	}
	if !commitResp.Accepted || commitResp.TransferID != "tx-server" {
		t.Fatalf("commit response mismatch: %+v", commitResp)
	}
	select {
	case got := <-commitReqCh:
		if got.TransferID != "tx-server" || got.Size != 512 || got.Checksum != "sha256:server" {
			t.Fatalf("server commit request mismatch: %+v", got)
		}
	case <-time.After(waitTimeout):
		t.Fatal("timed out waiting for server commit request")
	}

	serverSendSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-server")
	if err != nil {
		t.Fatalf("GetServerTransferSnapshotByID server-send failed: %v", err)
	}
	if !ok {
		t.Fatal("server send snapshot should exist")
	}
	if got, want := serverSendSnapshot.Direction, TransferDirectionSend; got != want {
		t.Fatalf("server send snapshot direction = %v, want %v", got, want)
	}
	if got, want := serverSendSnapshot.Scope, serverScope; got != want {
		t.Fatalf("server send snapshot scope = %q, want %q", got, want)
	}
	if got, want := serverSendSnapshot.RuntimeScope, serverTransportScope(conn); got != want {
		t.Fatalf("server send snapshot runtime scope = %q, want %q", got, want)
	}
	if got, want := serverSendSnapshot.TransportGeneration, uint64(1); got != want {
		t.Fatalf("server send snapshot transport generation = %d, want %d", got, want)
	}
	if got, want := serverSendSnapshot.Channel, TransferChannelControl; got != want {
		t.Fatalf("server send snapshot channel = %q, want %q", got, want)
	}
	if got, want := serverSendSnapshot.State, TransferStateDone; got != want {
		t.Fatalf("server send snapshot state = %v, want %v", got, want)
	}
	if got, want := serverSendSnapshot.AckedBytes, int64(256); got != want {
		t.Fatalf("server send snapshot acked bytes = %d, want %d", got, want)
	}

	clientRecvSnapshot, ok, err := GetClientTransferSnapshotByID(client, "tx-server")
	if err != nil {
		t.Fatalf("GetClientTransferSnapshotByID client-recv failed: %v", err)
	}
	if !ok {
		t.Fatal("client receive snapshot should exist")
	}
	if got, want := clientRecvSnapshot.Direction, TransferDirectionReceive; got != want {
		t.Fatalf("client receive snapshot direction = %v, want %v", got, want)
	}
	if got, want := clientRecvSnapshot.Scope, clientFileScope(); got != want {
		t.Fatalf("client receive snapshot scope = %q, want %q", got, want)
	}
	if got, want := clientRecvSnapshot.Channel, TransferChannelControl; got != want {
		t.Fatalf("client receive snapshot channel = %q, want %q", got, want)
	}
	if got, want := clientRecvSnapshot.State, TransferStateDone; got != want {
		t.Fatalf("client receive snapshot state = %v, want %v", got, want)
	}
	if got, want := clientRecvSnapshot.ReceivedBytes, int64(256); got != want {
		t.Fatalf("client receive snapshot received bytes = %d, want %d", got, want)
	}
}

// waitForTransferControlLogicalConn polls server every 10ms until its logical
// connection list has a non-nil first entry, and returns that entry. The test
// fails if no such entry appears within timeout.
func waitForTransferControlLogicalConn(t *testing.T, server *ServerCommon, timeout time.Duration) *LogicalConn {
	t.Helper()
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		logicals := server.GetLogicalConnList()
		if len(logicals) > 0 && logicals[0] != nil {
			return logicals[0]
		}
		time.Sleep(10 * time.Millisecond)
	}
	t.Fatal("timed out waiting for logical connection")
	return nil // unreachable: t.Fatal stops the test goroutine; required by the compiler
}

// waitForTransferControlClientConn waits like waitForTransferControlLogicalConn
// and converts the resulting logical connection with clientConnFromLogical.
func waitForTransferControlClientConn(t *testing.T, server *ServerCommon, timeout time.Duration) *ClientConn {
	// Mark as helper so a timeout is reported at the calling test's line,
	// not inside this wrapper.
	t.Helper()
	return clientConnFromLogical(waitForTransferControlLogicalConn(t, server, timeout))
}

// disableBuiltinTransferControlForClient clears the client's transfer-control
// state while holding its mutex: control is disabled and both the custom and
// built-in handlers are dropped. A nil client is a no-op.
func disableBuiltinTransferControlForClient(client *ClientCommon) {
	if client == nil {
		return
	}
	state := client.getTransferState()
	state.mu.Lock()
	defer state.mu.Unlock()
	state.controlEnabled = false
	state.handler = nil
	state.builtinHandler = nil
}

// disableBuiltinTransferControlForServer is the server-side counterpart of
// disableBuiltinTransferControlForClient. It clears controlEnabled, handler
// and builtinHandler under the state mutex. A nil server is a no-op.
func disableBuiltinTransferControlForServer(server *ServerCommon) {
	if server == nil {
		return
	}
	state := server.getTransferState()
	state.mu.Lock()
	defer state.mu.Unlock()
	state.controlEnabled = false
	state.handler = nil
	state.builtinHandler = nil
}