notify/transfer_control_test.go

358 lines
13 KiB
Go
Raw Permalink Normal View History

package notify
import (
"context"
"errors"
"testing"
"time"
)
func TestBindTransferControlValidation(t *testing.T) {
if err := BindTransferControlClient(nil, TransferControlHandler{}); !errors.Is(err, errTransferControlClientNil) {
t.Fatalf("BindTransferControlClient nil error mismatch: %v", err)
}
if err := BindTransferControlServer(nil, TransferControlHandler{}); !errors.Is(err, errTransferControlServerNil) {
t.Fatalf("BindTransferControlServer nil error mismatch: %v", err)
}
client := NewClient()
if err := BindTransferControlClient(client, TransferControlHandler{}); !errors.Is(err, errTransferControlHandlerEmpty) {
t.Fatalf("BindTransferControlClient empty handler error mismatch: %v", err)
}
server := NewServer()
if err := BindTransferControlServer(server, TransferControlHandler{}); !errors.Is(err, errTransferControlHandlerEmpty) {
t.Fatalf("BindTransferControlServer empty handler error mismatch: %v", err)
}
if _, err := SendTransferBeginClient(context.Background(), nil, TransferBeginRequest{}); !errors.Is(err, errTransferControlClientNil) {
t.Fatalf("SendTransferBeginClient nil client error mismatch: %v", err)
}
if _, err := SendTransferBeginServer(context.Background(), nil, nil, TransferBeginRequest{}); !errors.Is(err, errTransferControlServerNil) {
t.Fatalf("SendTransferBeginServer nil server error mismatch: %v", err)
}
if _, err := SendTransferBeginServer(context.Background(), server, nil, TransferBeginRequest{}); !errors.Is(err, errTransferControlLogicalConnNil) {
t.Fatalf("SendTransferBeginServer nil conn error mismatch: %v", err)
}
}
// TestTransferControlRoundTripTCP exercises the transfer-control begin/commit
// exchange in both directions over a real TCP loopback connection. Both peers
// are configured with UseModernPSK* using integrationSharedSecret and
// integrationModernPSKOptions().
//
// The built-in transfer-control handling is disabled on both peers and
// replaced with test handlers. These handlers forward every request they
// receive to beginReqCh/commitReqCh, so the test can assert on what actually
// arrived at the remote side. After each direction's exchange, the test also
// checks the transfer snapshots each side records: direction, scope, state,
// stage, channel and byte counters.
func TestTransferControlRoundTripTCP(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
disableBuiltinTransferControlForServer(server)
// Both peers' handlers share these channels. The test drains the matching
// channel right after each Send* call returns.
beginReqCh := make(chan TransferBeginRequest, 1)
commitReqCh := make(chan TransferCommitRequest, 2)
// Server-side handlers, hit by the client->server half of the test. Begin
// replies with NextOffset=512 and one missing range, so the client can check
// that both fields round-trip.
if err := BindTransferControlServer(server, TransferControlHandler{
Begin: func(_ *Message, req TransferBeginRequest) (TransferBeginResponse, error) {
beginReqCh <- req
return TransferBeginResponse{
TransferID: req.TransferID,
Accepted: true,
NextOffset: 512,
Missing: []TransferRange{
{Offset: 768, Length: 128},
},
}, nil
},
Commit: func(_ *Message, req TransferCommitRequest) (TransferCommitResponse, error) {
commitReqCh <- req
return TransferCommitResponse{
TransferID: req.TransferID,
Accepted: true,
}, nil
},
}); err != nil {
t.Fatalf("BindTransferControlServer failed: %v", err)
}
// Port 0 lets the OS pick a free port; the real address is read back from
// the listener.
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
t.Fatalf("server Listen failed: %v", err)
}
defer func() {
_ = server.Stop()
}()
addr := server.listener.Addr().String()
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
disableBuiltinTransferControlForClient(client)
// Client-side handlers, hit by the server->client half of the test. Begin
// replies with NextOffset=256 and no missing ranges.
if err := BindTransferControlClient(client, TransferControlHandler{
Begin: func(_ *Message, req TransferBeginRequest) (TransferBeginResponse, error) {
beginReqCh <- req
return TransferBeginResponse{
TransferID: req.TransferID,
Accepted: true,
NextOffset: 256,
}, nil
},
Commit: func(_ *Message, req TransferCommitRequest) (TransferCommitResponse, error) {
commitReqCh <- req
return TransferCommitResponse{
TransferID: req.TransferID,
Accepted: true,
}, nil
},
}); err != nil {
t.Fatalf("BindTransferControlClient failed: %v", err)
}
if err := client.Connect("tcp", addr); err != nil {
t.Fatalf("client Connect failed: %v", err)
}
defer func() {
_ = client.Stop()
}()
// Client -> server: begin. Check the response the server handler built, then
// the request it received (including metadata).
beginResp, err := SendTransferBeginClient(context.Background(), client, TransferBeginRequest{
TransferID: "tx-client",
Channel: TransferChannelData,
Size: 1024,
Checksum: "sha256:demo",
Metadata: map[string]string{
"name": "demo.bin",
},
})
if err != nil {
t.Fatalf("SendTransferBeginClient failed: %v", err)
}
if !beginResp.Accepted || beginResp.TransferID != "tx-client" || beginResp.NextOffset != 512 {
t.Fatalf("begin response mismatch: %+v", beginResp)
}
if len(beginResp.Missing) != 1 || beginResp.Missing[0].Offset != 768 || beginResp.Missing[0].Length != 128 {
t.Fatalf("begin response missing mismatch: %+v", beginResp.Missing)
}
select {
case got := <-beginReqCh:
if got.TransferID != "tx-client" || got.Channel != TransferChannelData || got.Size != 1024 || got.Checksum != "sha256:demo" {
t.Fatalf("begin request mismatch: %+v", got)
}
if got.Metadata["name"] != "demo.bin" {
t.Fatalf("begin request metadata mismatch: %+v", got.Metadata)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for begin request")
}
// Client -> server: commit of the same transfer.
commitClientResp, err := SendTransferCommitClient(context.Background(), client, TransferCommitRequest{
TransferID: "tx-client",
Size: 1024,
Checksum: "sha256:demo",
})
if err != nil {
t.Fatalf("SendTransferCommitClient failed: %v", err)
}
if !commitClientResp.Accepted || commitClientResp.TransferID != "tx-client" {
t.Fatalf("client commit response mismatch: %+v", commitClientResp)
}
select {
case got := <-commitReqCh:
if got.TransferID != "tx-client" || got.Size != 1024 || got.Checksum != "sha256:demo" {
t.Fatalf("client commit request mismatch: %+v", got)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for client commit request")
}
// Resolve the server-side view of the connection. It is used both to compute
// the expected server scopes and as the target of server-initiated sends.
logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
conn := clientConnFromLogical(logical)
serverScope := serverFileScope(conn)
// Client send-side snapshot for "tx-client". The expected 512 acked bytes
// equals the NextOffset returned by the server's Begin handler.
clientSnapshot, ok, err := GetClientTransferSnapshotByID(client, "tx-client")
if err != nil {
t.Fatalf("GetClientTransferSnapshotByID failed: %v", err)
}
if !ok {
t.Fatal("client snapshot should exist")
}
if got, want := clientSnapshot.Direction, TransferDirectionSend; got != want {
t.Fatalf("client snapshot direction = %v, want %v", got, want)
}
if got, want := clientSnapshot.Scope, clientFileScope(); got != want {
t.Fatalf("client snapshot scope = %q, want %q", got, want)
}
if got, want := clientSnapshot.State, TransferStateDone; got != want {
t.Fatalf("client snapshot state = %v, want %v", got, want)
}
if got, want := clientSnapshot.AckedBytes, int64(512); got != want {
t.Fatalf("client snapshot acked bytes = %d, want %d", got, want)
}
if got, want := clientSnapshot.Stage, "commit"; got != want {
t.Fatalf("client snapshot stage = %q, want %q", got, want)
}
// Server receive-side snapshot for the same transfer, keyed by the
// connection-derived scopes computed above.
serverSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-client")
if err != nil {
t.Fatalf("GetServerTransferSnapshotByID failed: %v", err)
}
if !ok {
t.Fatal("server snapshot should exist")
}
if got, want := serverSnapshot.Direction, TransferDirectionReceive; got != want {
t.Fatalf("server snapshot direction = %v, want %v", got, want)
}
if got, want := serverSnapshot.Scope, serverScope; got != want {
t.Fatalf("server snapshot scope = %q, want %q", got, want)
}
if got, want := serverSnapshot.RuntimeScope, serverTransportScope(conn); got != want {
t.Fatalf("server snapshot runtime scope = %q, want %q", got, want)
}
// NOTE(review): generation 1 presumably reflects the single connection made
// in this test — confirm against the transport generation semantics.
if got, want := serverSnapshot.TransportGeneration, uint64(1); got != want {
t.Fatalf("server snapshot transport generation = %d, want %d", got, want)
}
if got, want := serverSnapshot.State, TransferStateDone; got != want {
t.Fatalf("server snapshot state = %v, want %v", got, want)
}
if got, want := serverSnapshot.ReceivedBytes, int64(512); got != want {
t.Fatalf("server snapshot received bytes = %d, want %d", got, want)
}
if got, want := serverSnapshot.Stage, "commit"; got != want {
t.Fatalf("server snapshot stage = %q, want %q", got, want)
}
// Server -> client: begin, addressed to the resolved conn. The response must
// carry the client handler's NextOffset=256.
beginServerResp, err := SendTransferBeginServer(context.Background(), server, conn, TransferBeginRequest{
TransferID: "tx-server",
Channel: TransferChannelControl,
Size: 512,
Checksum: "sha256:server",
})
if err != nil {
t.Fatalf("SendTransferBeginServer failed: %v", err)
}
if !beginServerResp.Accepted || beginServerResp.TransferID != "tx-server" || beginServerResp.NextOffset != 256 {
t.Fatalf("server begin response mismatch: %+v", beginServerResp)
}
select {
case got := <-beginReqCh:
if got.TransferID != "tx-server" || got.Channel != TransferChannelControl || got.Size != 512 || got.Checksum != "sha256:server" {
t.Fatalf("server begin request mismatch: %+v", got)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for server begin request")
}
// Server -> client: commit of the same transfer.
commitResp, err := SendTransferCommitServer(context.Background(), server, conn, TransferCommitRequest{
TransferID: "tx-server",
Size: 512,
Checksum: "sha256:server",
})
if err != nil {
t.Fatalf("SendTransferCommitServer failed: %v", err)
}
if !commitResp.Accepted || commitResp.TransferID != "tx-server" {
t.Fatalf("commit response mismatch: %+v", commitResp)
}
select {
case got := <-commitReqCh:
if got.TransferID != "tx-server" || got.Size != 512 || got.Checksum != "sha256:server" {
t.Fatalf("server commit request mismatch: %+v", got)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for server commit request")
}
// Server send-side snapshot for "tx-server". The expected 256 acked bytes
// equals the NextOffset returned by the client's Begin handler.
serverSendSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-server")
if err != nil {
t.Fatalf("GetServerTransferSnapshotByID server-send failed: %v", err)
}
if !ok {
t.Fatal("server send snapshot should exist")
}
if got, want := serverSendSnapshot.Direction, TransferDirectionSend; got != want {
t.Fatalf("server send snapshot direction = %v, want %v", got, want)
}
if got, want := serverSendSnapshot.Scope, serverScope; got != want {
t.Fatalf("server send snapshot scope = %q, want %q", got, want)
}
if got, want := serverSendSnapshot.RuntimeScope, serverTransportScope(conn); got != want {
t.Fatalf("server send snapshot runtime scope = %q, want %q", got, want)
}
if got, want := serverSendSnapshot.TransportGeneration, uint64(1); got != want {
t.Fatalf("server send snapshot transport generation = %d, want %d", got, want)
}
if got, want := serverSendSnapshot.Channel, TransferChannelControl; got != want {
t.Fatalf("server send snapshot channel = %q, want %q", got, want)
}
if got, want := serverSendSnapshot.State, TransferStateDone; got != want {
t.Fatalf("server send snapshot state = %v, want %v", got, want)
}
if got, want := serverSendSnapshot.AckedBytes, int64(256); got != want {
t.Fatalf("server send snapshot acked bytes = %d, want %d", got, want)
}
// Client receive-side snapshot for the server-initiated transfer.
clientRecvSnapshot, ok, err := GetClientTransferSnapshotByID(client, "tx-server")
if err != nil {
t.Fatalf("GetClientTransferSnapshotByID client-recv failed: %v", err)
}
if !ok {
t.Fatal("client receive snapshot should exist")
}
if got, want := clientRecvSnapshot.Direction, TransferDirectionReceive; got != want {
t.Fatalf("client receive snapshot direction = %v, want %v", got, want)
}
if got, want := clientRecvSnapshot.Scope, clientFileScope(); got != want {
t.Fatalf("client receive snapshot scope = %q, want %q", got, want)
}
if got, want := clientRecvSnapshot.Channel, TransferChannelControl; got != want {
t.Fatalf("client receive snapshot channel = %q, want %q", got, want)
}
if got, want := clientRecvSnapshot.State, TransferStateDone; got != want {
t.Fatalf("client receive snapshot state = %v, want %v", got, want)
}
if got, want := clientRecvSnapshot.ReceivedBytes, int64(256); got != want {
t.Fatalf("client receive snapshot received bytes = %d, want %d", got, want)
}
}
// waitForTransferControlLogicalConn polls server every 10ms and returns the
// first entry of its logical-connection list as soon as the list is non-empty
// and that entry is non-nil. If that does not happen before timeout elapses,
// the test is failed with t.Fatal.
func waitForTransferControlLogicalConn(t *testing.T, server *ServerCommon, timeout time.Duration) *LogicalConn {
	t.Helper()
	const pollEvery = 10 * time.Millisecond
	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(pollEvery) {
		conns := server.GetLogicalConnList()
		if len(conns) == 0 {
			continue
		}
		if first := conns[0]; first != nil {
			return first
		}
	}
	t.Fatal("timed out waiting for logical connection")
	return nil
}
// waitForTransferControlClientConn waits up to timeout for server to expose a
// logical connection (via waitForTransferControlLogicalConn) and returns the
// ClientConn derived from it with clientConnFromLogical. On timeout the test
// is failed with t.Fatal.
func waitForTransferControlClientConn(t *testing.T, server *ServerCommon, timeout time.Duration) *ClientConn {
	// Mark this wrapper as a helper too. Otherwise a timeout is reported at the
	// call inside this function instead of in the calling test.
	t.Helper()
	logical := waitForTransferControlLogicalConn(t, server, timeout)
	return clientConnFromLogical(logical)
}
// disableBuiltinTransferControlForClient switches off the client's built-in
// transfer control. While holding the transfer state's mutex, it clears
// controlEnabled and nils out both handler and builtinHandler. Tests call it
// before binding their own handlers. A nil client is ignored.
func disableBuiltinTransferControlForClient(client *ClientCommon) {
	if client != nil {
		st := client.getTransferState()
		st.mu.Lock()
		defer st.mu.Unlock()
		st.controlEnabled = false
		st.handler, st.builtinHandler = nil, nil
	}
}
// disableBuiltinTransferControlForServer switches off the server's built-in
// transfer control. While holding the transfer state's mutex, it clears
// controlEnabled and nils out both handler and builtinHandler. Tests call it
// before binding their own handlers. A nil server is ignored.
func disableBuiltinTransferControlForServer(server *ServerCommon) {
	if server != nil {
		st := server.getTransferState()
		st.mu.Lock()
		defer st.mu.Unlock()
		st.controlEnabled = false
		st.handler, st.builtinHandler = nil, nil
	}
}