811 lines
24 KiB
Go
811 lines
24 KiB
Go
|
|
package notify
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"context"
|
||
|
|
"crypto/sha256"
|
||
|
|
"encoding/binary"
|
||
|
|
"encoding/hex"
|
||
|
|
"errors"
|
||
|
|
"io"
|
||
|
|
"net"
|
||
|
|
"strings"
|
||
|
|
"sync"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
itransfer "b612.me/notify/internal/transfer"
|
||
|
|
)
|
||
|
|
|
||
|
|
// transferBytesSource is an in-memory transfer source backed by a private
// byte slice. When failAtOffset is >= 0, any ReadAt at or beyond that
// offset returns failErr (or a generic injected error), enabling
// failure-injection tests such as resume-after-partial-failure.
type transferBytesSource struct {
	data         []byte
	failAtOffset int64 // -1 disables failure injection
	failErr      error // optional error returned at/after failAtOffset
}
|
||
|
|
|
||
|
|
// transferBlockingSource is a transfer source whose ReadAt blocks until
// release supplies a token, letting tests observe read scheduling:
// started counts reads begun, active/maxActive track concurrency, and
// startedCh signals each read start.
type transferBlockingSource struct {
	data      []byte
	releaseCh chan struct{} // one receive unblocks one in-flight ReadAt
	mu        sync.Mutex    // guards started/active/maxActive
	started   int           // total ReadAt calls begun
	active    int           // ReadAt calls currently in progress
	maxActive int           // high-water mark of active
	startedCh chan struct{} // signaled once per ReadAt start
}
|
||
|
|
|
||
|
|
// transferDiscardStream is a minimal no-op Stream: Read reports immediate
// EOF, Write accepts and discards all bytes, and every other method is a
// stub. It lets sendTransferSegments run without a real transport.
type transferDiscardStream struct{}

func (transferDiscardStream) Read([]byte) (int, error)      { return 0, io.EOF }
func (transferDiscardStream) Write(p []byte) (int, error)   { return len(p), nil }
func (transferDiscardStream) Close() error                  { return nil }
func (transferDiscardStream) ID() string                    { return "discard" }
func (transferDiscardStream) Channel() StreamChannel        { return StreamDataChannel }
func (transferDiscardStream) Metadata() StreamMetadata      { return nil }
func (transferDiscardStream) Context() context.Context      { return context.Background() }
func (transferDiscardStream) LogicalConn() *LogicalConn     { return nil }
func (transferDiscardStream) TransportConn() *TransportConn { return nil }
func (transferDiscardStream) TransportGeneration() uint64   { return 0 }
func (transferDiscardStream) LocalAddr() net.Addr           { return nil }
func (transferDiscardStream) RemoteAddr() net.Addr          { return nil }
func (transferDiscardStream) CloseWrite() error             { return nil }
func (transferDiscardStream) Reset(error) error             { return nil }
func (transferDiscardStream) SetDeadline(time.Time) error   { return nil }
func (transferDiscardStream) SetReadDeadline(time.Time) error {
	return nil
}
func (transferDiscardStream) SetWriteDeadline(time.Time) error {
	return nil
}
|
||
|
|
|
||
|
|
// transferWriteCountStream is an in-memory Stream that counts Write calls,
// letting tests verify frame aggregation (fewer writes than logical
// frames). All written data accumulates in buf for later inspection.
type transferWriteCountStream struct {
	buf    bytes.Buffer
	writes int // number of Write calls observed
}

func (s *transferWriteCountStream) Read(p []byte) (int, error) { return s.buf.Read(p) }

// Write counts the call before appending p to the buffer.
func (s *transferWriteCountStream) Write(p []byte) (int, error) { s.writes++; return s.buf.Write(p) }

func (s *transferWriteCountStream) Close() error                  { return nil }
func (s *transferWriteCountStream) ID() string                    { return "write-count" }
func (s *transferWriteCountStream) Channel() StreamChannel        { return StreamDataChannel }
func (s *transferWriteCountStream) Metadata() StreamMetadata      { return nil }
func (s *transferWriteCountStream) Context() context.Context      { return context.Background() }
func (s *transferWriteCountStream) LogicalConn() *LogicalConn     { return nil }
func (s *transferWriteCountStream) TransportConn() *TransportConn { return nil }
func (s *transferWriteCountStream) TransportGeneration() uint64   { return 0 }
func (s *transferWriteCountStream) LocalAddr() net.Addr           { return nil }
func (s *transferWriteCountStream) RemoteAddr() net.Addr          { return nil }
func (s *transferWriteCountStream) CloseWrite() error             { return nil }
func (s *transferWriteCountStream) Reset(error) error             { return nil }
func (s *transferWriteCountStream) SetDeadline(time.Time) error   { return nil }
func (s *transferWriteCountStream) SetReadDeadline(time.Time) error {
	return nil
}
func (s *transferWriteCountStream) SetWriteDeadline(time.Time) error {
	return nil
}

// WriteCount reports the number of Write calls seen so far.
func (s *transferWriteCountStream) WriteCount() int { return s.writes }

// Bytes returns a copy of everything written to the stream.
func (s *transferWriteCountStream) Bytes() []byte { return append([]byte(nil), s.buf.Bytes()...) }
|
||
|
|
|
||
|
|
func newTransferBytesSource(data []byte) *transferBytesSource {
|
||
|
|
return &transferBytesSource{
|
||
|
|
data: append([]byte(nil), data...),
|
||
|
|
failAtOffset: -1,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func newTransferBlockingSource(data []byte) *transferBlockingSource {
|
||
|
|
return &transferBlockingSource{
|
||
|
|
data: append([]byte(nil), data...),
|
||
|
|
releaseCh: make(chan struct{}, len(data)+1),
|
||
|
|
startedCh: make(chan struct{}, len(data)+1),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *transferBytesSource) Size() int64 {
|
||
|
|
if s == nil {
|
||
|
|
return 0
|
||
|
|
}
|
||
|
|
return int64(len(s.data))
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *transferBytesSource) ReadAt(p []byte, off int64) (int, error) {
|
||
|
|
if s == nil {
|
||
|
|
return 0, io.EOF
|
||
|
|
}
|
||
|
|
if s.failAtOffset >= 0 && off >= s.failAtOffset {
|
||
|
|
if s.failErr != nil {
|
||
|
|
return 0, s.failErr
|
||
|
|
}
|
||
|
|
return 0, errors.New("injected transfer source failure")
|
||
|
|
}
|
||
|
|
if off >= int64(len(s.data)) {
|
||
|
|
return 0, io.EOF
|
||
|
|
}
|
||
|
|
n := copy(p, s.data[off:])
|
||
|
|
if int(off)+n >= len(s.data) {
|
||
|
|
return n, io.EOF
|
||
|
|
}
|
||
|
|
return n, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *transferBlockingSource) Size() int64 {
|
||
|
|
if s == nil {
|
||
|
|
return 0
|
||
|
|
}
|
||
|
|
return int64(len(s.data))
|
||
|
|
}
|
||
|
|
|
||
|
|
// ReadAt records that a read has started (updating started/active/maxActive
// under the lock), signals startedCh, then blocks on releaseCh until
// release supplies a token before serving data from the backing slice.
// The ordering — counters, then startedCh send, then releaseCh receive —
// is what waitStarted and the prefetch tests rely on.
func (s *transferBlockingSource) ReadAt(p []byte, off int64) (int, error) {
	if s == nil {
		return 0, io.EOF
	}
	s.mu.Lock()
	s.started++
	s.active++
	if s.active > s.maxActive {
		s.maxActive = s.active
	}
	s.mu.Unlock()
	s.startedCh <- struct{}{}
	// Block here until release() provides a token.
	<-s.releaseCh
	s.mu.Lock()
	s.active--
	s.mu.Unlock()
	if off >= int64(len(s.data)) {
		return 0, io.EOF
	}
	n := copy(p, s.data[off:])
	// Reaching the end of the data is reported as io.EOF alongside n.
	if int(off)+n >= len(s.data) {
		return n, io.EOF
	}
	return n, nil
}
|
||
|
|
|
||
|
|
func (s *transferBlockingSource) release(n int) {
|
||
|
|
if s == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
for i := 0; i < n; i++ {
|
||
|
|
s.releaseCh <- struct{}{}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *transferBlockingSource) waitStarted(t *testing.T, want int, timeout time.Duration) {
|
||
|
|
t.Helper()
|
||
|
|
|
||
|
|
deadline := time.After(timeout)
|
||
|
|
for {
|
||
|
|
s.mu.Lock()
|
||
|
|
started := s.started
|
||
|
|
s.mu.Unlock()
|
||
|
|
if started >= want {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
select {
|
||
|
|
case <-s.startedCh:
|
||
|
|
case <-deadline:
|
||
|
|
t.Fatalf("timed out waiting for %d blocking reads, got %d", want, started)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *transferBlockingSource) startedCount() int {
|
||
|
|
if s == nil {
|
||
|
|
return 0
|
||
|
|
}
|
||
|
|
s.mu.Lock()
|
||
|
|
defer s.mu.Unlock()
|
||
|
|
return s.started
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *transferBlockingSource) maxActiveCount() int {
|
||
|
|
if s == nil {
|
||
|
|
return 0
|
||
|
|
}
|
||
|
|
s.mu.Lock()
|
||
|
|
defer s.mu.Unlock()
|
||
|
|
return s.maxActive
|
||
|
|
}
|
||
|
|
|
||
|
|
// transferMemorySink is a concurrency-safe in-memory WriteAt/ReadAt sink
// used as the receiving end of test transfers.
type transferMemorySink struct {
	mu     sync.Mutex // guards data and closed
	data   []byte
	closed bool // set by Close; subsequent WriteAt calls fail
}
|
||
|
|
|
||
|
|
// WriteAt stores p at offset off, growing the buffer (zero-padded) when the
// write extends past the current end. It fails with io.ErrClosedPipe after
// Close and rejects negative offsets with errTransferSegmentOffset.
func (s *transferMemorySink) WriteAt(p []byte, off int64) (int, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.closed {
		return 0, io.ErrClosedPipe
	}
	if off < 0 {
		return 0, errTransferSegmentOffset
	}
	end := int(off) + len(p)
	if end > len(s.data) {
		// Grow to exactly the new end; the gap (if any) stays zeroed.
		grown := make([]byte, end)
		copy(grown, s.data)
		s.data = grown
	}
	copy(s.data[off:], p)
	return len(p), nil
}
|
||
|
|
|
||
|
|
func (s *transferMemorySink) ReadAt(p []byte, off int64) (int, error) {
|
||
|
|
s.mu.Lock()
|
||
|
|
defer s.mu.Unlock()
|
||
|
|
if off >= int64(len(s.data)) {
|
||
|
|
return 0, io.EOF
|
||
|
|
}
|
||
|
|
n := copy(p, s.data[off:])
|
||
|
|
if int(off)+n >= len(s.data) {
|
||
|
|
return n, io.EOF
|
||
|
|
}
|
||
|
|
return n, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *transferMemorySink) Close() error {
|
||
|
|
s.mu.Lock()
|
||
|
|
s.closed = true
|
||
|
|
s.mu.Unlock()
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *transferMemorySink) Bytes() []byte {
|
||
|
|
s.mu.Lock()
|
||
|
|
defer s.mu.Unlock()
|
||
|
|
return append([]byte(nil), s.data...)
|
||
|
|
}
|
||
|
|
|
||
|
|
// TestTransferClientToServerRoundTripTCP sends a checksummed transfer from
// client to server over a real TCP loopback connection and verifies the
// received bytes, the accept-handler metadata, and both sides' final
// transfer snapshots (state, byte counts, and tuning parameters).
func TestTransferClientToServerRoundTripTCP(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}

	// The handler captures accept info for later inspection and routes the
	// payload into an in-memory sink with checksum verification enabled.
	sink := &transferMemorySink{}
	acceptCh := make(chan TransferAcceptInfo, 1)
	server.SetTransferHandler(func(info TransferAcceptInfo) (TransferReceiveOptions, error) {
		acceptCh <- info
		return TransferReceiveOptions{
			Sink:           sink,
			VerifyChecksum: true,
		}, nil
	})

	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = server.Stop() }()

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = client.Stop() }()

	data := bytes.Repeat([]byte("client-transfer-roundtrip-"), 4096)
	checksum := transferTestChecksum(data)
	handle, err := client.SendTransfer(context.Background(), TransferSendOptions{
		Descriptor: TransferDescriptor{
			ID:       "tx-client-server",
			Channel:  TransferChannelData,
			Size:     int64(len(data)),
			Checksum: checksum,
			Metadata: map[string]string{"name": "client.bin"},
		},
		Source:           newTransferBytesSource(data),
		ChunkSize:        16 * 1024,
		Parallelism:      2,
		MaxInflightBytes: 64 * 1024,
	})
	if err != nil {
		t.Fatalf("SendTransfer failed: %v", err)
	}
	if err := handle.Wait(context.Background()); err != nil {
		t.Fatalf("transfer wait failed: %v", err)
	}

	// The server-side accept callback must have seen the descriptor we sent.
	select {
	case info := <-acceptCh:
		if info.Descriptor.ID != "tx-client-server" {
			t.Fatalf("accept descriptor id = %q, want %q", info.Descriptor.ID, "tx-client-server")
		}
		if info.Descriptor.Metadata["name"] != "client.bin" {
			t.Fatalf("accept metadata mismatch: %+v", info.Descriptor.Metadata)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for transfer accept info")
	}

	if got := sink.Bytes(); !bytes.Equal(got, data) {
		t.Fatalf("received data mismatch: got %d bytes, want %d", len(got), len(data))
	}

	// Sender-side snapshot: done, fully acked, and tuning options recorded.
	clientSnapshot, ok, err := GetClientTransferSnapshotByID(client, "tx-client-server")
	if err != nil {
		t.Fatalf("GetClientTransferSnapshotByID failed: %v", err)
	}
	if !ok {
		t.Fatal("client transfer snapshot missing")
	}
	if clientSnapshot.State != TransferStateDone || clientSnapshot.AckedBytes != int64(len(data)) {
		t.Fatalf("client snapshot mismatch: %+v", clientSnapshot)
	}
	if got, want := clientSnapshot.ChunkSize, 16*1024; got != want {
		t.Fatalf("client snapshot chunk size = %d, want %d", got, want)
	}
	if got, want := clientSnapshot.Parallelism, 2; got != want {
		t.Fatalf("client snapshot parallelism = %d, want %d", got, want)
	}
	if got, want := clientSnapshot.MaxInflightBytes, int64(64*1024); got != want {
		t.Fatalf("client snapshot max inflight bytes = %d, want %d", got, want)
	}

	// Receiver-side snapshot: done and all bytes received.
	serverSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-client-server")
	if err != nil {
		t.Fatalf("GetServerTransferSnapshotByID failed: %v", err)
	}
	if !ok {
		t.Fatal("server transfer snapshot missing")
	}
	if serverSnapshot.State != TransferStateDone || serverSnapshot.ReceivedBytes != int64(len(data)) {
		t.Fatalf("server snapshot mismatch: %+v", serverSnapshot)
	}
}
|
||
|
|
|
||
|
|
// TestTransferServerToClientRoundTripTCP exercises the reverse direction:
// the server pushes a transfer to the client over the control channel of an
// established logical connection, and both sides' snapshots are verified.
func TestTransferServerToClientRoundTripTCP(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = server.Stop() }()

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	// Here the client is the receiver, so the handler is installed on it.
	sink := &transferMemorySink{}
	client.SetTransferHandler(func(info TransferAcceptInfo) (TransferReceiveOptions, error) {
		return TransferReceiveOptions{
			Sink:           sink,
			VerifyChecksum: true,
		}, nil
	})
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = client.Stop() }()

	// Server-initiated sends need the logical connection for this client.
	logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
	data := bytes.Repeat([]byte("server-transfer-roundtrip-"), 3072)
	checksum := transferTestChecksum(data)

	handle, err := server.SendTransferLogical(context.Background(), logical, TransferSendOptions{
		Descriptor: TransferDescriptor{
			ID:       "tx-server-client",
			Channel:  TransferChannelControl,
			Size:     int64(len(data)),
			Checksum: checksum,
		},
		Source:    newTransferBytesSource(data),
		ChunkSize: 8 * 1024,
	})
	if err != nil {
		t.Fatalf("SendTransferLogical failed: %v", err)
	}
	if err := handle.Wait(context.Background()); err != nil {
		t.Fatalf("transfer wait failed: %v", err)
	}

	if got := sink.Bytes(); !bytes.Equal(got, data) {
		t.Fatalf("received data mismatch: got %d bytes, want %d", len(got), len(data))
	}

	// Sender (server) snapshot: done and fully acked.
	serverSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-server-client")
	if err != nil {
		t.Fatalf("GetServerTransferSnapshotByID failed: %v", err)
	}
	if !ok {
		t.Fatal("server transfer snapshot missing")
	}
	if serverSnapshot.State != TransferStateDone || serverSnapshot.AckedBytes != int64(len(data)) {
		t.Fatalf("server snapshot mismatch: %+v", serverSnapshot)
	}

	// Receiver (client) snapshot: done and all bytes received.
	clientSnapshot, ok, err := GetClientTransferSnapshotByID(client, "tx-server-client")
	if err != nil {
		t.Fatalf("GetClientTransferSnapshotByID failed: %v", err)
	}
	if !ok {
		t.Fatal("client transfer snapshot missing")
	}
	if clientSnapshot.State != TransferStateDone || clientSnapshot.ReceivedBytes != int64(len(data)) {
		t.Fatalf("client snapshot mismatch: %+v", clientSnapshot)
	}
}
|
||
|
|
|
||
|
|
// TestTransferResumeAfterPartialFailureTCP injects a source failure partway
// through a transfer, confirms the server records a partial (not done)
// snapshot, then re-sends the same transfer ID with a healthy source and
// verifies the resumed transfer completes with the full payload intact.
func TestTransferResumeAfterPartialFailureTCP(t *testing.T) {
	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	sink := &transferMemorySink{}
	server.SetTransferHandler(func(info TransferAcceptInfo) (TransferReceiveOptions, error) {
		return TransferReceiveOptions{
			Sink:           sink,
			VerifyChecksum: true,
		}, nil
	})
	if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	defer func() { _ = server.Stop() }()

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = client.Stop() }()

	// First attempt: source fails once reads reach the 32 KiB offset.
	data := bytes.Repeat([]byte("resume-transfer-"), 8192)
	checksum := transferTestChecksum(data)
	firstSourceErr := errors.New("injected transfer source failure")
	firstSource := newTransferBytesSource(data)
	firstSource.failAtOffset = 32 * 1024
	firstSource.failErr = firstSourceErr

	firstHandle, err := client.SendTransfer(context.Background(), TransferSendOptions{
		Descriptor: TransferDescriptor{
			ID:       "tx-resume",
			Channel:  TransferChannelData,
			Size:     int64(len(data)),
			Checksum: checksum,
		},
		Source:    firstSource,
		ChunkSize: 16 * 1024,
	})
	if err != nil {
		t.Fatalf("first SendTransfer failed: %v", err)
	}
	// The wait must surface the injected failure, not succeed silently.
	if err := firstHandle.Wait(context.Background()); err == nil || !strings.Contains(err.Error(), firstSourceErr.Error()) {
		t.Fatalf("first transfer wait error = %v, want injected source failure", err)
	}

	// The server should have a snapshot for the partial transfer that is
	// anything but done.
	partial := waitForTransferSnapshot(t, server, "tx-resume", 3*time.Second)
	if partial.State == TransferStateDone {
		t.Fatalf("partial snapshot unexpectedly done: %+v", partial)
	}

	// Second attempt: same ID and checksum, healthy source — should resume
	// and complete.
	secondHandle, err := client.SendTransfer(context.Background(), TransferSendOptions{
		Descriptor: TransferDescriptor{
			ID:       "tx-resume",
			Channel:  TransferChannelData,
			Size:     int64(len(data)),
			Checksum: checksum,
		},
		Source:    newTransferBytesSource(data),
		ChunkSize: 16 * 1024,
	})
	if err != nil {
		t.Fatalf("second SendTransfer failed: %v", err)
	}
	if err := secondHandle.Wait(context.Background()); err != nil {
		t.Fatalf("second transfer wait failed: %v", err)
	}

	if got := sink.Bytes(); !bytes.Equal(got, data) {
		t.Fatalf("received data mismatch after resume: got %d bytes, want %d", len(got), len(data))
	}

	// Both sides must converge on a done snapshot covering all bytes.
	clientSnapshot, ok := latestClientTransferSnapshotByID(t, client, "tx-resume")
	if !ok {
		t.Fatal("client transfer snapshot missing")
	}
	if clientSnapshot.State != TransferStateDone || clientSnapshot.AckedBytes != int64(len(data)) {
		t.Fatalf("client snapshot mismatch after resume: %+v", clientSnapshot)
	}

	serverSnapshot, ok := latestServerTransferSnapshotByID(t, server, "tx-resume")
	if !ok {
		t.Fatal("server transfer snapshot missing")
	}
	if serverSnapshot.State != TransferStateDone || serverSnapshot.ReceivedBytes != int64(len(data)) {
		t.Fatalf("server snapshot mismatch after resume: %+v", serverSnapshot)
	}
}
|
||
|
|
|
||
|
|
// TestSendTransferSegmentsDoesNotCallResumeDuringSteadyState verifies that
// a normal, uninterrupted send never invokes the resume path: the
// sendResume hook must not be called when segments flow without errors.
func TestSendTransferSegmentsDoesNotCallResumeDuringSteadyState(t *testing.T) {
	data := bytes.Repeat([]byte("steady-state-transfer-"), 4)
	resumeCalls := 0
	target := transferSendTarget{
		sequenceEn: func(value interface{}) ([]byte, error) {
			return []byte("segment"), nil
		},
		// Counting hook: any call here is a failure condition.
		sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) {
			resumeCalls++
			return TransferResumeResponse{TransferID: req.TransferID, Accepted: true}, nil
		},
	}
	opt := TransferSendOptions{
		Descriptor: TransferDescriptor{
			ID:      "tx-no-steady-resume",
			Channel: TransferChannelData,
			Size:    int64(len(data)),
		},
		Source:           newTransferBytesSource(data),
		ChunkSize:        16,
		MaxInflightBytes: 16,
	}

	if err := sendTransferSegments(context.Background(), transferDiscardStream{}, target, opt, 0, transferSendHooks{}); err != nil {
		t.Fatalf("sendTransferSegments failed: %v", err)
	}
	if resumeCalls != 0 {
		t.Fatalf("sendResume call count = %d, want 0", resumeCalls)
	}
}
|
||
|
|
|
||
|
|
// TestSendTransferSegmentsAggregatesSmallFrames checks that many small
// logical frames are coalesced into fewer stream writes: the Write-call
// count must be below the frame count, while the decoded frame count in
// the written bytes still matches the expected number of chunks.
func TestSendTransferSegmentsAggregatesSmallFrames(t *testing.T) {
	data := bytes.Repeat([]byte("aggregate-transfer-"), 16)
	chunkSize := 8
	stream := &transferWriteCountStream{}
	target := transferSendTarget{
		// Identity-style encoder: the frame payload is the segment payload.
		sequenceEn: func(value interface{}) ([]byte, error) {
			segment, ok := value.(itransfer.Segment)
			if !ok {
				t.Fatalf("encoded value type = %T, want itransfer.Segment", value)
			}
			return append([]byte(nil), segment.Payload...), nil
		},
	}
	opt := TransferSendOptions{
		Descriptor: TransferDescriptor{
			ID:      "tx-aggregate",
			Channel: TransferChannelData,
			Size:    int64(len(data)),
		},
		Source:    newTransferBytesSource(data),
		ChunkSize: chunkSize,
	}

	if err := sendTransferSegments(context.Background(), stream, target, opt, 0, transferSendHooks{}); err != nil {
		t.Fatalf("sendTransferSegments failed: %v", err)
	}

	// Ceiling division: number of chunks the payload splits into.
	expectedFrames := (len(data) + chunkSize - 1) / chunkSize
	if got := stream.WriteCount(); got >= expectedFrames {
		t.Fatalf("write count = %d, want less than frame count %d", got, expectedFrames)
	}
	if got := countTransferFrames(stream.Bytes()); got != expectedFrames {
		t.Fatalf("frame count = %d, want %d", got, expectedFrames)
	}
}
|
||
|
|
|
||
|
|
// TestSendTransferSegmentsUsesParallelReadPrefetch blocks every source read
// until released, waits for 4 reads to start concurrently (Parallelism: 4,
// window large enough not to throttle), then releases all reads and
// verifies at least 2 reads were in flight at once.
func TestSendTransferSegmentsUsesParallelReadPrefetch(t *testing.T) {
	data := bytes.Repeat([]byte("p"), 32)
	source := newTransferBlockingSource(data)
	target := transferSendTarget{
		sequenceEn: func(value interface{}) ([]byte, error) {
			segment, ok := value.(itransfer.Segment)
			if !ok {
				t.Fatalf("encoded value type = %T, want itransfer.Segment", value)
			}
			return append([]byte(nil), segment.Payload...), nil
		},
	}
	opt := TransferSendOptions{
		Descriptor: TransferDescriptor{
			ID:      "tx-parallel-prefetch",
			Channel: TransferChannelData,
			Size:    int64(len(data)),
		},
		Source:           source,
		ChunkSize:        8,
		Parallelism:      4,
		MaxInflightBytes: 64, // large enough for all 4 chunks at once
	}

	// Run the send in the background; it blocks inside source.ReadAt.
	errCh := make(chan error, 1)
	go func() {
		errCh <- sendTransferSegments(context.Background(), transferDiscardStream{}, target, opt, 0, transferSendHooks{})
	}()

	source.waitStarted(t, 4, time.Second)
	// Over-release so retried/extra reads cannot deadlock the sender.
	source.release(8)

	select {
	case err := <-errCh:
		if err != nil {
			t.Fatalf("sendTransferSegments failed: %v", err)
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for parallel sendTransferSegments")
	}

	if got := source.maxActiveCount(); got < 2 {
		t.Fatalf("max active reads = %d, want at least 2", got)
	}
}
|
||
|
|
|
||
|
|
// TestSendTransferSegmentsMaxInflightBytesCapsParallelReads verifies that
// the inflight-bytes window (16 bytes = 2 chunks of 8) throttles prefetch
// below the configured Parallelism of 4: only 2 reads may start before a
// release, and concurrency never exceeds 2 overall.
func TestSendTransferSegmentsMaxInflightBytesCapsParallelReads(t *testing.T) {
	data := bytes.Repeat([]byte("w"), 32)
	source := newTransferBlockingSource(data)
	target := transferSendTarget{
		sequenceEn: func(value interface{}) ([]byte, error) {
			segment, ok := value.(itransfer.Segment)
			if !ok {
				t.Fatalf("encoded value type = %T, want itransfer.Segment", value)
			}
			return append([]byte(nil), segment.Payload...), nil
		},
	}
	opt := TransferSendOptions{
		Descriptor: TransferDescriptor{
			ID:      "tx-window-prefetch",
			Channel: TransferChannelData,
			Size:    int64(len(data)),
		},
		Source:           source,
		ChunkSize:        8,
		Parallelism:      4,
		MaxInflightBytes: 16, // window admits only 2 chunks at a time
	}

	errCh := make(chan error, 1)
	go func() {
		errCh <- sendTransferSegments(context.Background(), transferDiscardStream{}, target, opt, 0, transferSendHooks{})
	}()

	source.waitStarted(t, 2, time.Second)
	// Grace period: a 3rd read starting here would indicate the window is
	// not being enforced.
	time.Sleep(50 * time.Millisecond)
	if got := source.startedCount(); got != 2 {
		t.Fatalf("started reads = %d, want 2 before release", got)
	}

	source.release(8)

	select {
	case err := <-errCh:
		if err != nil {
			t.Fatalf("sendTransferSegments failed: %v", err)
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for window-limited sendTransferSegments")
	}

	if got := source.maxActiveCount(); got > 2 {
		t.Fatalf("max active reads = %d, want at most 2", got)
	}
}
|
||
|
|
|
||
|
|
// TestSendTransferSegmentsNormalizesDirectOptions calls the send path with
// deliberately incomplete options (no Size, no ChunkSize, no inflight cap)
// and verifies it still normalizes the defaults and emits at least one
// frame rather than failing or sending nothing.
func TestSendTransferSegmentsNormalizesDirectOptions(t *testing.T) {
	data := bytes.Repeat([]byte("n"), 32)
	stream := &transferWriteCountStream{}
	target := transferSendTarget{
		sequenceEn: func(value interface{}) ([]byte, error) {
			segment, ok := value.(itransfer.Segment)
			if !ok {
				t.Fatalf("encoded value type = %T, want itransfer.Segment", value)
			}
			return append([]byte(nil), segment.Payload...), nil
		},
	}
	opt := TransferSendOptions{
		Descriptor: TransferDescriptor{
			ID:      "tx-direct-normalize",
			Channel: TransferChannelData,
			// Size intentionally omitted to exercise normalization.
		},
		Source:      newTransferBytesSource(data),
		Parallelism: 4,
	}

	if err := sendTransferSegments(context.Background(), stream, target, opt, 0, transferSendHooks{}); err != nil {
		t.Fatalf("sendTransferSegments failed: %v", err)
	}
	if got := countTransferFrames(stream.Bytes()); got == 0 {
		t.Fatal("frame count = 0, want at least 1")
	}
}
|
||
|
|
|
||
|
|
func waitForTransferReceivedBytes(t *testing.T, server *ServerCommon, transferID string, minBytes int64, timeout time.Duration) TransferSnapshot {
|
||
|
|
t.Helper()
|
||
|
|
|
||
|
|
deadline := time.Now().Add(timeout)
|
||
|
|
for time.Now().Before(deadline) {
|
||
|
|
snapshot, ok := latestServerTransferSnapshotByID(t, server, transferID)
|
||
|
|
if ok && snapshot.ReceivedBytes >= minBytes {
|
||
|
|
return snapshot
|
||
|
|
}
|
||
|
|
time.Sleep(10 * time.Millisecond)
|
||
|
|
}
|
||
|
|
t.Fatalf("timed out waiting for server transfer snapshot %q to reach %d bytes", transferID, minBytes)
|
||
|
|
return TransferSnapshot{}
|
||
|
|
}
|
||
|
|
|
||
|
|
func waitForTransferSnapshot(t *testing.T, server *ServerCommon, transferID string, timeout time.Duration) TransferSnapshot {
|
||
|
|
t.Helper()
|
||
|
|
|
||
|
|
deadline := time.Now().Add(timeout)
|
||
|
|
for time.Now().Before(deadline) {
|
||
|
|
snapshot, ok := latestServerTransferSnapshotByID(t, server, transferID)
|
||
|
|
if ok {
|
||
|
|
return snapshot
|
||
|
|
}
|
||
|
|
time.Sleep(10 * time.Millisecond)
|
||
|
|
}
|
||
|
|
t.Fatalf("timed out waiting for server transfer snapshot %q to appear", transferID)
|
||
|
|
return TransferSnapshot{}
|
||
|
|
}
|
||
|
|
|
||
|
|
func latestClientTransferSnapshotByID(t *testing.T, client *ClientCommon, transferID string) (TransferSnapshot, bool) {
|
||
|
|
t.Helper()
|
||
|
|
if client == nil || transferID == "" {
|
||
|
|
return TransferSnapshot{}, false
|
||
|
|
}
|
||
|
|
snapshots, err := GetClientTransferSnapshots(client)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("GetClientTransferSnapshots failed: %v", err)
|
||
|
|
}
|
||
|
|
return latestTransferSnapshotByID(snapshots, transferID)
|
||
|
|
}
|
||
|
|
|
||
|
|
func latestServerTransferSnapshotByID(t *testing.T, server *ServerCommon, transferID string) (TransferSnapshot, bool) {
|
||
|
|
t.Helper()
|
||
|
|
if server == nil || transferID == "" {
|
||
|
|
return TransferSnapshot{}, false
|
||
|
|
}
|
||
|
|
snapshots, err := GetServerTransferSnapshots(server)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("GetServerTransferSnapshots failed: %v", err)
|
||
|
|
}
|
||
|
|
return latestTransferSnapshotByID(snapshots, transferID)
|
||
|
|
}
|
||
|
|
|
||
|
|
func latestTransferSnapshotByID(snapshots []TransferSnapshot, transferID string) (TransferSnapshot, bool) {
|
||
|
|
var matched TransferSnapshot
|
||
|
|
found := false
|
||
|
|
for _, snapshot := range snapshots {
|
||
|
|
if snapshot.ID != transferID {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
if !found || snapshot.UpdatedAt.After(matched.UpdatedAt) || (snapshot.UpdatedAt.Equal(matched.UpdatedAt) && snapshot.ReceivedBytes > matched.ReceivedBytes) {
|
||
|
|
matched = snapshot
|
||
|
|
found = true
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return matched, found
|
||
|
|
}
|
||
|
|
|
||
|
|
// transferTestChecksum returns the lowercase hex SHA-256 digest of data.
func transferTestChecksum(data []byte) string {
	h := sha256.New()
	h.Write(data)
	return hex.EncodeToString(h.Sum(nil))
}
|
||
|
|
|
||
|
|
func countTransferFrames(data []byte) int {
|
||
|
|
count := 0
|
||
|
|
for len(data) > 0 {
|
||
|
|
if len(data) < transferFrameHeaderSize {
|
||
|
|
return count
|
||
|
|
}
|
||
|
|
size := int(binary.BigEndian.Uint32(data[:transferFrameHeaderSize]))
|
||
|
|
data = data[transferFrameHeaderSize:]
|
||
|
|
if size < 0 || len(data) < size {
|
||
|
|
return count
|
||
|
|
}
|
||
|
|
data = data[size:]
|
||
|
|
count++
|
||
|
|
}
|
||
|
|
return count
|
||
|
|
}
|