358 lines
13 KiB
Go
358 lines
13 KiB
Go
|
|
package notify
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestBindTransferControlValidation(t *testing.T) {
|
||
|
|
if err := BindTransferControlClient(nil, TransferControlHandler{}); !errors.Is(err, errTransferControlClientNil) {
|
||
|
|
t.Fatalf("BindTransferControlClient nil error mismatch: %v", err)
|
||
|
|
}
|
||
|
|
if err := BindTransferControlServer(nil, TransferControlHandler{}); !errors.Is(err, errTransferControlServerNil) {
|
||
|
|
t.Fatalf("BindTransferControlServer nil error mismatch: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
client := NewClient()
|
||
|
|
if err := BindTransferControlClient(client, TransferControlHandler{}); !errors.Is(err, errTransferControlHandlerEmpty) {
|
||
|
|
t.Fatalf("BindTransferControlClient empty handler error mismatch: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
server := NewServer()
|
||
|
|
if err := BindTransferControlServer(server, TransferControlHandler{}); !errors.Is(err, errTransferControlHandlerEmpty) {
|
||
|
|
t.Fatalf("BindTransferControlServer empty handler error mismatch: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if _, err := SendTransferBeginClient(context.Background(), nil, TransferBeginRequest{}); !errors.Is(err, errTransferControlClientNil) {
|
||
|
|
t.Fatalf("SendTransferBeginClient nil client error mismatch: %v", err)
|
||
|
|
}
|
||
|
|
if _, err := SendTransferBeginServer(context.Background(), nil, nil, TransferBeginRequest{}); !errors.Is(err, errTransferControlServerNil) {
|
||
|
|
t.Fatalf("SendTransferBeginServer nil server error mismatch: %v", err)
|
||
|
|
}
|
||
|
|
if _, err := SendTransferBeginServer(context.Background(), server, nil, TransferBeginRequest{}); !errors.Is(err, errTransferControlLogicalConnNil) {
|
||
|
|
t.Fatalf("SendTransferBeginServer nil conn error mismatch: %v", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestTransferControlRoundTripTCP(t *testing.T) {
|
||
|
|
server := NewServer().(*ServerCommon)
|
||
|
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||
|
|
}
|
||
|
|
disableBuiltinTransferControlForServer(server)
|
||
|
|
|
||
|
|
beginReqCh := make(chan TransferBeginRequest, 1)
|
||
|
|
commitReqCh := make(chan TransferCommitRequest, 2)
|
||
|
|
if err := BindTransferControlServer(server, TransferControlHandler{
|
||
|
|
Begin: func(_ *Message, req TransferBeginRequest) (TransferBeginResponse, error) {
|
||
|
|
beginReqCh <- req
|
||
|
|
return TransferBeginResponse{
|
||
|
|
TransferID: req.TransferID,
|
||
|
|
Accepted: true,
|
||
|
|
NextOffset: 512,
|
||
|
|
Missing: []TransferRange{
|
||
|
|
{Offset: 768, Length: 128},
|
||
|
|
},
|
||
|
|
}, nil
|
||
|
|
},
|
||
|
|
Commit: func(_ *Message, req TransferCommitRequest) (TransferCommitResponse, error) {
|
||
|
|
commitReqCh <- req
|
||
|
|
return TransferCommitResponse{
|
||
|
|
TransferID: req.TransferID,
|
||
|
|
Accepted: true,
|
||
|
|
}, nil
|
||
|
|
},
|
||
|
|
}); err != nil {
|
||
|
|
t.Fatalf("BindTransferControlServer failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||
|
|
t.Fatalf("server Listen failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = server.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
addr := server.listener.Addr().String()
|
||
|
|
client := NewClient().(*ClientCommon)
|
||
|
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||
|
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||
|
|
}
|
||
|
|
disableBuiltinTransferControlForClient(client)
|
||
|
|
|
||
|
|
if err := BindTransferControlClient(client, TransferControlHandler{
|
||
|
|
Begin: func(_ *Message, req TransferBeginRequest) (TransferBeginResponse, error) {
|
||
|
|
beginReqCh <- req
|
||
|
|
return TransferBeginResponse{
|
||
|
|
TransferID: req.TransferID,
|
||
|
|
Accepted: true,
|
||
|
|
NextOffset: 256,
|
||
|
|
}, nil
|
||
|
|
},
|
||
|
|
Commit: func(_ *Message, req TransferCommitRequest) (TransferCommitResponse, error) {
|
||
|
|
commitReqCh <- req
|
||
|
|
return TransferCommitResponse{
|
||
|
|
TransferID: req.TransferID,
|
||
|
|
Accepted: true,
|
||
|
|
}, nil
|
||
|
|
},
|
||
|
|
}); err != nil {
|
||
|
|
t.Fatalf("BindTransferControlClient failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := client.Connect("tcp", addr); err != nil {
|
||
|
|
t.Fatalf("client Connect failed: %v", err)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = client.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
beginResp, err := SendTransferBeginClient(context.Background(), client, TransferBeginRequest{
|
||
|
|
TransferID: "tx-client",
|
||
|
|
Channel: TransferChannelData,
|
||
|
|
Size: 1024,
|
||
|
|
Checksum: "sha256:demo",
|
||
|
|
Metadata: map[string]string{
|
||
|
|
"name": "demo.bin",
|
||
|
|
},
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("SendTransferBeginClient failed: %v", err)
|
||
|
|
}
|
||
|
|
if !beginResp.Accepted || beginResp.TransferID != "tx-client" || beginResp.NextOffset != 512 {
|
||
|
|
t.Fatalf("begin response mismatch: %+v", beginResp)
|
||
|
|
}
|
||
|
|
if len(beginResp.Missing) != 1 || beginResp.Missing[0].Offset != 768 || beginResp.Missing[0].Length != 128 {
|
||
|
|
t.Fatalf("begin response missing mismatch: %+v", beginResp.Missing)
|
||
|
|
}
|
||
|
|
|
||
|
|
select {
|
||
|
|
case got := <-beginReqCh:
|
||
|
|
if got.TransferID != "tx-client" || got.Channel != TransferChannelData || got.Size != 1024 || got.Checksum != "sha256:demo" {
|
||
|
|
t.Fatalf("begin request mismatch: %+v", got)
|
||
|
|
}
|
||
|
|
if got.Metadata["name"] != "demo.bin" {
|
||
|
|
t.Fatalf("begin request metadata mismatch: %+v", got.Metadata)
|
||
|
|
}
|
||
|
|
case <-time.After(2 * time.Second):
|
||
|
|
t.Fatal("timed out waiting for begin request")
|
||
|
|
}
|
||
|
|
|
||
|
|
commitClientResp, err := SendTransferCommitClient(context.Background(), client, TransferCommitRequest{
|
||
|
|
TransferID: "tx-client",
|
||
|
|
Size: 1024,
|
||
|
|
Checksum: "sha256:demo",
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("SendTransferCommitClient failed: %v", err)
|
||
|
|
}
|
||
|
|
if !commitClientResp.Accepted || commitClientResp.TransferID != "tx-client" {
|
||
|
|
t.Fatalf("client commit response mismatch: %+v", commitClientResp)
|
||
|
|
}
|
||
|
|
|
||
|
|
select {
|
||
|
|
case got := <-commitReqCh:
|
||
|
|
if got.TransferID != "tx-client" || got.Size != 1024 || got.Checksum != "sha256:demo" {
|
||
|
|
t.Fatalf("client commit request mismatch: %+v", got)
|
||
|
|
}
|
||
|
|
case <-time.After(2 * time.Second):
|
||
|
|
t.Fatal("timed out waiting for client commit request")
|
||
|
|
}
|
||
|
|
|
||
|
|
logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
|
||
|
|
conn := clientConnFromLogical(logical)
|
||
|
|
serverScope := serverFileScope(conn)
|
||
|
|
|
||
|
|
clientSnapshot, ok, err := GetClientTransferSnapshotByID(client, "tx-client")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("GetClientTransferSnapshotByID failed: %v", err)
|
||
|
|
}
|
||
|
|
if !ok {
|
||
|
|
t.Fatal("client snapshot should exist")
|
||
|
|
}
|
||
|
|
if got, want := clientSnapshot.Direction, TransferDirectionSend; got != want {
|
||
|
|
t.Fatalf("client snapshot direction = %v, want %v", got, want)
|
||
|
|
}
|
||
|
|
if got, want := clientSnapshot.Scope, clientFileScope(); got != want {
|
||
|
|
t.Fatalf("client snapshot scope = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := clientSnapshot.State, TransferStateDone; got != want {
|
||
|
|
t.Fatalf("client snapshot state = %v, want %v", got, want)
|
||
|
|
}
|
||
|
|
if got, want := clientSnapshot.AckedBytes, int64(512); got != want {
|
||
|
|
t.Fatalf("client snapshot acked bytes = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got, want := clientSnapshot.Stage, "commit"; got != want {
|
||
|
|
t.Fatalf("client snapshot stage = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
|
||
|
|
serverSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-client")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("GetServerTransferSnapshotByID failed: %v", err)
|
||
|
|
}
|
||
|
|
if !ok {
|
||
|
|
t.Fatal("server snapshot should exist")
|
||
|
|
}
|
||
|
|
if got, want := serverSnapshot.Direction, TransferDirectionReceive; got != want {
|
||
|
|
t.Fatalf("server snapshot direction = %v, want %v", got, want)
|
||
|
|
}
|
||
|
|
if got, want := serverSnapshot.Scope, serverScope; got != want {
|
||
|
|
t.Fatalf("server snapshot scope = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := serverSnapshot.RuntimeScope, serverTransportScope(conn); got != want {
|
||
|
|
t.Fatalf("server snapshot runtime scope = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := serverSnapshot.TransportGeneration, uint64(1); got != want {
|
||
|
|
t.Fatalf("server snapshot transport generation = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got, want := serverSnapshot.State, TransferStateDone; got != want {
|
||
|
|
t.Fatalf("server snapshot state = %v, want %v", got, want)
|
||
|
|
}
|
||
|
|
if got, want := serverSnapshot.ReceivedBytes, int64(512); got != want {
|
||
|
|
t.Fatalf("server snapshot received bytes = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got, want := serverSnapshot.Stage, "commit"; got != want {
|
||
|
|
t.Fatalf("server snapshot stage = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
|
||
|
|
beginServerResp, err := SendTransferBeginServer(context.Background(), server, conn, TransferBeginRequest{
|
||
|
|
TransferID: "tx-server",
|
||
|
|
Channel: TransferChannelControl,
|
||
|
|
Size: 512,
|
||
|
|
Checksum: "sha256:server",
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("SendTransferBeginServer failed: %v", err)
|
||
|
|
}
|
||
|
|
if !beginServerResp.Accepted || beginServerResp.TransferID != "tx-server" || beginServerResp.NextOffset != 256 {
|
||
|
|
t.Fatalf("server begin response mismatch: %+v", beginServerResp)
|
||
|
|
}
|
||
|
|
|
||
|
|
select {
|
||
|
|
case got := <-beginReqCh:
|
||
|
|
if got.TransferID != "tx-server" || got.Channel != TransferChannelControl || got.Size != 512 || got.Checksum != "sha256:server" {
|
||
|
|
t.Fatalf("server begin request mismatch: %+v", got)
|
||
|
|
}
|
||
|
|
case <-time.After(2 * time.Second):
|
||
|
|
t.Fatal("timed out waiting for server begin request")
|
||
|
|
}
|
||
|
|
|
||
|
|
commitResp, err := SendTransferCommitServer(context.Background(), server, conn, TransferCommitRequest{
|
||
|
|
TransferID: "tx-server",
|
||
|
|
Size: 512,
|
||
|
|
Checksum: "sha256:server",
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("SendTransferCommitServer failed: %v", err)
|
||
|
|
}
|
||
|
|
if !commitResp.Accepted || commitResp.TransferID != "tx-server" {
|
||
|
|
t.Fatalf("commit response mismatch: %+v", commitResp)
|
||
|
|
}
|
||
|
|
|
||
|
|
select {
|
||
|
|
case got := <-commitReqCh:
|
||
|
|
if got.TransferID != "tx-server" || got.Size != 512 || got.Checksum != "sha256:server" {
|
||
|
|
t.Fatalf("server commit request mismatch: %+v", got)
|
||
|
|
}
|
||
|
|
case <-time.After(2 * time.Second):
|
||
|
|
t.Fatal("timed out waiting for server commit request")
|
||
|
|
}
|
||
|
|
|
||
|
|
serverSendSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-server")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("GetServerTransferSnapshotByID server-send failed: %v", err)
|
||
|
|
}
|
||
|
|
if !ok {
|
||
|
|
t.Fatal("server send snapshot should exist")
|
||
|
|
}
|
||
|
|
if got, want := serverSendSnapshot.Direction, TransferDirectionSend; got != want {
|
||
|
|
t.Fatalf("server send snapshot direction = %v, want %v", got, want)
|
||
|
|
}
|
||
|
|
if got, want := serverSendSnapshot.Scope, serverScope; got != want {
|
||
|
|
t.Fatalf("server send snapshot scope = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := serverSendSnapshot.RuntimeScope, serverTransportScope(conn); got != want {
|
||
|
|
t.Fatalf("server send snapshot runtime scope = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := serverSendSnapshot.TransportGeneration, uint64(1); got != want {
|
||
|
|
t.Fatalf("server send snapshot transport generation = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got, want := serverSendSnapshot.Channel, TransferChannelControl; got != want {
|
||
|
|
t.Fatalf("server send snapshot channel = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := serverSendSnapshot.State, TransferStateDone; got != want {
|
||
|
|
t.Fatalf("server send snapshot state = %v, want %v", got, want)
|
||
|
|
}
|
||
|
|
if got, want := serverSendSnapshot.AckedBytes, int64(256); got != want {
|
||
|
|
t.Fatalf("server send snapshot acked bytes = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
|
||
|
|
clientRecvSnapshot, ok, err := GetClientTransferSnapshotByID(client, "tx-server")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("GetClientTransferSnapshotByID client-recv failed: %v", err)
|
||
|
|
}
|
||
|
|
if !ok {
|
||
|
|
t.Fatal("client receive snapshot should exist")
|
||
|
|
}
|
||
|
|
if got, want := clientRecvSnapshot.Direction, TransferDirectionReceive; got != want {
|
||
|
|
t.Fatalf("client receive snapshot direction = %v, want %v", got, want)
|
||
|
|
}
|
||
|
|
if got, want := clientRecvSnapshot.Scope, clientFileScope(); got != want {
|
||
|
|
t.Fatalf("client receive snapshot scope = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := clientRecvSnapshot.Channel, TransferChannelControl; got != want {
|
||
|
|
t.Fatalf("client receive snapshot channel = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := clientRecvSnapshot.State, TransferStateDone; got != want {
|
||
|
|
t.Fatalf("client receive snapshot state = %v, want %v", got, want)
|
||
|
|
}
|
||
|
|
if got, want := clientRecvSnapshot.ReceivedBytes, int64(256); got != want {
|
||
|
|
t.Fatalf("client receive snapshot received bytes = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func waitForTransferControlLogicalConn(t *testing.T, server *ServerCommon, timeout time.Duration) *LogicalConn {
|
||
|
|
t.Helper()
|
||
|
|
|
||
|
|
deadline := time.Now().Add(timeout)
|
||
|
|
for time.Now().Before(deadline) {
|
||
|
|
logicals := server.GetLogicalConnList()
|
||
|
|
if len(logicals) > 0 && logicals[0] != nil {
|
||
|
|
return logicals[0]
|
||
|
|
}
|
||
|
|
time.Sleep(10 * time.Millisecond)
|
||
|
|
}
|
||
|
|
t.Fatal("timed out waiting for logical connection")
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func waitForTransferControlClientConn(t *testing.T, server *ServerCommon, timeout time.Duration) *ClientConn {
|
||
|
|
return clientConnFromLogical(waitForTransferControlLogicalConn(t, server, timeout))
|
||
|
|
}
|
||
|
|
|
||
|
|
func disableBuiltinTransferControlForClient(client *ClientCommon) {
|
||
|
|
if client == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
state := client.getTransferState()
|
||
|
|
state.mu.Lock()
|
||
|
|
state.controlEnabled = false
|
||
|
|
state.handler = nil
|
||
|
|
state.builtinHandler = nil
|
||
|
|
state.mu.Unlock()
|
||
|
|
}
|
||
|
|
|
||
|
|
func disableBuiltinTransferControlForServer(server *ServerCommon) {
|
||
|
|
if server == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
state := server.getTransferState()
|
||
|
|
state.mu.Lock()
|
||
|
|
state.controlEnabled = false
|
||
|
|
state.handler = nil
|
||
|
|
state.builtinHandler = nil
|
||
|
|
state.mu.Unlock()
|
||
|
|
}
|