672a111ec1
- 增加 SFTP client 级配置,支持 packet size、单文件并发请求数、并发读和并发写 - 将吞吐优化限制在 StarSSH 托管的 SFTP client 路径中 - 上传和下载在显式启用时使用并发快路径,同时保留原子传输生命周期 - 避免快路径失败前提前上报 100% 进度 - 补充安全校验、托管 client 配置、进度、取消和下载对齐的回归测试
1128 lines
35 KiB
Go
1128 lines
35 KiB
Go
package starssh
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/pkg/sftp"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
func TestNormalizeSFTPTransferOptionsDefaultsAtomicDownload(t *testing.T) {
|
|
opts := mustNormalizeSFTPTransferOptions(t, nil)
|
|
if !opts.AtomicUpload {
|
|
t.Fatal("expected atomic upload to default to enabled")
|
|
}
|
|
if !opts.AtomicDownload {
|
|
t.Fatal("expected atomic download to default to enabled")
|
|
}
|
|
if opts.Client.ConcurrentWrites != nil {
|
|
t.Fatal("expected concurrent writes to default to unset")
|
|
}
|
|
}
|
|
|
|
func TestThroughputSFTPTransferOptionsEnablesExplicitConcurrentWrites(t *testing.T) {
|
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
|
if !opts.AtomicUpload {
|
|
t.Fatal("expected throughput preset to keep atomic upload enabled")
|
|
}
|
|
if opts.Client.ConcurrentReads == nil || !*opts.Client.ConcurrentReads {
|
|
t.Fatal("expected throughput preset to enable concurrent reads")
|
|
}
|
|
if opts.Client.ConcurrentWrites == nil || !*opts.Client.ConcurrentWrites {
|
|
t.Fatal("expected throughput preset to enable concurrent writes")
|
|
}
|
|
if opts.Client.MaxConcurrentRequestsPerFile != 32 {
|
|
t.Fatalf("unexpected max concurrent requests: got %d want 32", opts.Client.MaxConcurrentRequestsPerFile)
|
|
}
|
|
}
|
|
|
|
func TestValidateSFTPUploadOptionsRejectsConcurrentWritesWithoutAtomicUpload(t *testing.T) {
|
|
options := ThroughputSFTPTransferOptions()
|
|
options.AtomicUpload = SFTPBool(false)
|
|
opts := mustNormalizeSFTPTransferOptions(t, &options)
|
|
|
|
err := validateSFTPUploadOptions(opts)
|
|
if err == nil || !strings.Contains(err.Error(), "atomic upload") {
|
|
t.Fatalf("expected atomic upload rejection, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestNormalizeSFTPTransferOptionsRejectsNegativeClientValues(t *testing.T) {
|
|
_, err := normalizeSFTPTransferOptions(&SFTPTransferOptions{
|
|
Client: SFTPClientOptions{MaxPacketSize: -1},
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "max packet") {
|
|
t.Fatalf("expected max packet rejection, got %v", err)
|
|
}
|
|
|
|
_, err = normalizeSFTPTransferOptions(&SFTPTransferOptions{
|
|
Client: SFTPClientOptions{MaxConcurrentRequestsPerFile: -1},
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "max concurrent") {
|
|
t.Fatalf("expected max concurrent rejection, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestBuildSFTPClientOptionsRejectsUnsupportedCheckedPacketSize(t *testing.T) {
|
|
options := buildSFTPClientOptions(resolvedSFTPClientOptions{MaxPacketSize: 32769})
|
|
client := &sftp.Client{}
|
|
|
|
err := options[0](client)
|
|
if err == nil || !strings.Contains(err.Error(), "32KB") {
|
|
t.Fatalf("expected checked packet size rejection, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSftpTransferOutWithContextRejectsClientOptionsForExternalClient(t *testing.T) {
|
|
client := newSFTPTestClient(t)
|
|
root := t.TempDir()
|
|
localPath := filepath.Join(root, "local.txt")
|
|
remotePath := filepath.Join(root, "remote.txt")
|
|
if err := os.WriteFile(localPath, []byte("payload"), 0o644); err != nil {
|
|
t.Fatalf("write local file: %v", err)
|
|
}
|
|
|
|
err := SftpTransferOutWithContext(context.Background(), localPath, remotePath, client, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
|
if err == nil || !strings.Contains(err.Error(), "StarSSH-managed") {
|
|
t.Fatalf("expected external client option rejection, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSftpTransferOutContextPassesClientOptionsToManagedClient(t *testing.T) {
|
|
root := t.TempDir()
|
|
localPath := filepath.Join(root, "local.txt")
|
|
remotePath := filepath.Join(root, "remote.txt")
|
|
if err := os.WriteFile(localPath, []byte("payload"), 0o644); err != nil {
|
|
t.Fatalf("write local file: %v", err)
|
|
}
|
|
|
|
client := newSFTPTestClient(t)
|
|
options := ThroughputSFTPTransferOptions()
|
|
var captured []sftp.ClientOption
|
|
oldNewClient := sftpNewClientFunc
|
|
sftpNewClientFunc = func(_ *ssh.Client, opts ...sftp.ClientOption) (*sftp.Client, error) {
|
|
captured = append([]sftp.ClientOption(nil), opts...)
|
|
return client, nil
|
|
}
|
|
t.Cleanup(func() {
|
|
sftpNewClientFunc = oldNewClient
|
|
})
|
|
|
|
star := &StarSSH{Client: &ssh.Client{}}
|
|
if err := star.SftpTransferOutContext(context.Background(), localPath, remotePath, &options); err != nil {
|
|
t.Fatalf("transfer out: %v", err)
|
|
}
|
|
|
|
if len(captured) == 0 {
|
|
t.Fatal("expected managed SFTP client options to be passed to factory")
|
|
}
|
|
if got, want := len(captured), len(buildSFTPClientOptions(mustNormalizeSFTPTransferOptions(t, &options).Client)); got != want {
|
|
t.Fatalf("unexpected client option count: got %d want %d", got, want)
|
|
}
|
|
}
|
|
|
|
func TestSftpTransferInWithContextRejectsClientOptionsForExternalClient(t *testing.T) {
|
|
client := newSFTPTestClient(t)
|
|
root := t.TempDir()
|
|
srcPath := filepath.Join(root, "remote.txt")
|
|
dstPath := filepath.Join(root, "local.txt")
|
|
if err := os.WriteFile(srcPath, []byte("payload"), 0o644); err != nil {
|
|
t.Fatalf("write remote file: %v", err)
|
|
}
|
|
|
|
err := SftpTransferInWithContext(context.Background(), srcPath, dstPath, client, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
|
if err == nil || !strings.Contains(err.Error(), "StarSSH-managed") {
|
|
t.Fatalf("expected external client option rejection, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSftpTransferInContextPassesClientOptionsToManagedClient(t *testing.T) {
|
|
root := t.TempDir()
|
|
srcPath := filepath.Join(root, "remote.txt")
|
|
dstPath := filepath.Join(root, "local.txt")
|
|
if err := os.WriteFile(srcPath, []byte("payload"), 0o644); err != nil {
|
|
t.Fatalf("write remote file: %v", err)
|
|
}
|
|
|
|
client := newSFTPTestClient(t)
|
|
options := ThroughputSFTPTransferOptions()
|
|
var captured []sftp.ClientOption
|
|
oldNewClient := sftpNewClientFunc
|
|
sftpNewClientFunc = func(_ *ssh.Client, opts ...sftp.ClientOption) (*sftp.Client, error) {
|
|
captured = append([]sftp.ClientOption(nil), opts...)
|
|
return client, nil
|
|
}
|
|
t.Cleanup(func() {
|
|
sftpNewClientFunc = oldNewClient
|
|
})
|
|
|
|
star := &StarSSH{Client: &ssh.Client{}}
|
|
if err := star.SftpTransferInContext(context.Background(), srcPath, dstPath, &options); err != nil {
|
|
t.Fatalf("transfer in: %v", err)
|
|
}
|
|
|
|
if len(captured) == 0 {
|
|
t.Fatal("expected managed SFTP client options to be passed to factory")
|
|
}
|
|
if got, want := len(captured), len(buildSFTPClientOptions(mustNormalizeSFTPTransferOptions(t, &options).Client)); got != want {
|
|
t.Fatalf("unexpected client option count: got %d want %d", got, want)
|
|
}
|
|
assertFileContent(t, dstPath, "payload")
|
|
}
|
|
|
|
func TestBuildSFTPClientOptionsCanTransfer(t *testing.T) {
|
|
client := newSFTPTestClientWithOptions(t, buildSFTPClientOptions(resolvedSFTPClientOptions{
|
|
MaxPacketSize: 4096,
|
|
MaxConcurrentRequestsPerFile: 4,
|
|
ConcurrentWrites: SFTPBool(true),
|
|
}))
|
|
root := t.TempDir()
|
|
localPath := filepath.Join(root, "local.txt")
|
|
remotePath := filepath.Join(root, "remote.txt")
|
|
|
|
if err := os.WriteFile(localPath, []byte(strings.Repeat("payload-", 2048)), 0o644); err != nil {
|
|
t.Fatalf("write local file: %v", err)
|
|
}
|
|
|
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
|
if err := transferOutContext(context.Background(), client, localPath, remotePath, opts); err != nil {
|
|
t.Fatalf("transfer out with client options: %v", err)
|
|
}
|
|
|
|
data, err := os.ReadFile(remotePath)
|
|
if err != nil {
|
|
t.Fatalf("read remote file: %v", err)
|
|
}
|
|
if got := string(data); got != strings.Repeat("payload-", 2048) {
|
|
t.Fatalf("unexpected remote payload length: got %d", len(got))
|
|
}
|
|
}
|
|
|
|
func TestBuildSFTPClientOptionsCanDownload(t *testing.T) {
|
|
client := newSFTPTestClientWithOptions(t, buildSFTPClientOptions(resolvedSFTPClientOptions{
|
|
MaxPacketSize: 4096,
|
|
MaxConcurrentRequestsPerFile: 4,
|
|
ConcurrentReads: SFTPBool(true),
|
|
}))
|
|
root := t.TempDir()
|
|
srcPath := filepath.Join(root, "remote.txt")
|
|
dstPath := filepath.Join(root, "local.txt")
|
|
|
|
payload := strings.Repeat("payload-", 2048)
|
|
if err := os.WriteFile(srcPath, []byte(payload), 0o644); err != nil {
|
|
t.Fatalf("write remote file: %v", err)
|
|
}
|
|
|
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
|
if err := transferInContext(context.Background(), client, srcPath, dstPath, opts); err != nil {
|
|
t.Fatalf("transfer in with client options: %v", err)
|
|
}
|
|
|
|
assertFileContent(t, dstPath, payload)
|
|
}
|
|
|
|
func TestCopyUploadWithProgressUsesConcurrentReadFromWhenEnabled(t *testing.T) {
|
|
dst := &spyConcurrentReadFrom{}
|
|
src := strings.NewReader("payload")
|
|
var progress []float64
|
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
|
|
|
written, err := copyUploadWithProgressContext(context.Background(), dst, src, 3, int64(len("payload")), func(value float64) {
|
|
progress = append(progress, value)
|
|
}, opts)
|
|
if err != nil {
|
|
t.Fatalf("copy upload: %v", err)
|
|
}
|
|
if written != int64(len("payload")) {
|
|
t.Fatalf("unexpected written bytes: got %d", written)
|
|
}
|
|
if !dst.usedReadFrom {
|
|
t.Fatal("expected concurrent ReadFrom path to be used")
|
|
}
|
|
if dst.concurrency != opts.Client.MaxConcurrentRequestsPerFile {
|
|
t.Fatalf("unexpected concurrency: got %d want %d", dst.concurrency, opts.Client.MaxConcurrentRequestsPerFile)
|
|
}
|
|
if got := dst.buf.String(); got != "payload" {
|
|
t.Fatalf("unexpected copied payload: got %q", got)
|
|
}
|
|
if len(progress) == 0 || progress[len(progress)-1] != 100 {
|
|
t.Fatalf("expected final progress 100, got %v", progress)
|
|
}
|
|
}
|
|
|
|
func TestCopyUploadWithProgressReportsDuringConcurrentReadFrom(t *testing.T) {
|
|
dst := &spyConcurrentReadFrom{}
|
|
src := &chunkedReader{
|
|
reader: strings.NewReader("payload"),
|
|
chunkSize: 2,
|
|
}
|
|
var progress []float64
|
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
|
|
|
written, err := copyUploadWithProgressContext(context.Background(), dst, src, 3, int64(len("payload")), func(value float64) {
|
|
progress = append(progress, value)
|
|
}, opts)
|
|
if err != nil {
|
|
t.Fatalf("copy upload: %v", err)
|
|
}
|
|
if written != int64(len("payload")) {
|
|
t.Fatalf("unexpected written bytes: got %d", written)
|
|
}
|
|
|
|
var sawIntermediate bool
|
|
for _, value := range progress {
|
|
if value > 0 && value < 100 {
|
|
sawIntermediate = true
|
|
break
|
|
}
|
|
}
|
|
if !sawIntermediate {
|
|
t.Fatalf("expected intermediate progress during concurrent readfrom, got %v", progress)
|
|
}
|
|
}
|
|
|
|
func TestCopyUploadWithProgressCancelsDuringConcurrentReadFrom(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
dst := &spyConcurrentReadFrom{}
|
|
src := &cancelAfterReadReader{
|
|
reader: strings.NewReader("payload"),
|
|
chunkSize: 3,
|
|
cancel: cancel,
|
|
}
|
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
|
|
|
written, err := copyUploadWithProgressContext(ctx, dst, src, 3, int64(len("payload")), nil, opts)
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Fatalf("expected context cancellation, got written=%d err=%v", written, err)
|
|
}
|
|
if !dst.usedReadFrom {
|
|
t.Fatal("expected concurrent ReadFrom path to be used")
|
|
}
|
|
if written <= 0 || written >= int64(len("payload")) {
|
|
t.Fatalf("expected partial write before cancellation, got %d", written)
|
|
}
|
|
}
|
|
|
|
func TestCopyUploadWithProgressDoesNotReportDoneBeforeConcurrentWriteError(t *testing.T) {
|
|
copyErr := errors.New("write status failed")
|
|
dst := &spyConcurrentReadFrom{errAfterRead: copyErr}
|
|
var progress []float64
|
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
|
|
|
written, err := copyUploadWithProgressContext(context.Background(), dst, strings.NewReader("payload"), 3, int64(len("payload")), func(value float64) {
|
|
progress = append(progress, value)
|
|
}, opts)
|
|
if !errors.Is(err, copyErr) {
|
|
t.Fatalf("expected concurrent write error, got written=%d err=%v", written, err)
|
|
}
|
|
if written != int64(len("payload")) {
|
|
t.Fatalf("expected read byte count before write error, got %d", written)
|
|
}
|
|
for _, value := range progress {
|
|
if value >= 100 {
|
|
t.Fatalf("progress reported completion before write success: %v", progress)
|
|
}
|
|
}
|
|
if len(progress) == 0 {
|
|
t.Fatal("expected queued progress before write error")
|
|
}
|
|
}
|
|
|
|
func TestCopyUploadWithProgressReturnsConcurrentReadFromError(t *testing.T) {
|
|
copyErr := errors.New("readfrom failed")
|
|
dst := &spyConcurrentReadFrom{err: copyErr}
|
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
|
|
|
_, err := copyUploadWithProgressContext(context.Background(), dst, strings.NewReader("payload"), 3, int64(len("payload")), nil, opts)
|
|
if !errors.Is(err, copyErr) {
|
|
t.Fatalf("expected concurrent readfrom error, got %v", err)
|
|
}
|
|
if !dst.usedReadFrom {
|
|
t.Fatal("expected concurrent ReadFrom path to be used")
|
|
}
|
|
}
|
|
|
|
func TestCopyUploadWithProgressDefaultsToSequentialCopy(t *testing.T) {
|
|
oldCopy := sftpCopyWithProgressFunc
|
|
called := false
|
|
sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) {
|
|
called = true
|
|
return oldCopy(ctx, dst, src, bufSize, total, progress)
|
|
}
|
|
t.Cleanup(func() {
|
|
sftpCopyWithProgressFunc = oldCopy
|
|
})
|
|
|
|
var dst bytes.Buffer
|
|
written, err := copyUploadWithProgressContext(context.Background(), &dst, strings.NewReader("payload"), 3, int64(len("payload")), nil, mustNormalizeSFTPTransferOptions(t, nil))
|
|
if err != nil {
|
|
t.Fatalf("copy upload: %v", err)
|
|
}
|
|
if written != int64(len("payload")) {
|
|
t.Fatalf("unexpected written bytes: got %d", written)
|
|
}
|
|
if !called {
|
|
t.Fatal("expected default path to use existing copy helper")
|
|
}
|
|
}
|
|
|
|
func TestCopyDownloadWithProgressUsesConcurrentWriteToWhenEnabled(t *testing.T) {
|
|
src := &spyConcurrentWriteTo{payload: []byte("payload")}
|
|
var dst bytes.Buffer
|
|
var progress []float64
|
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
|
|
|
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) {
|
|
progress = append(progress, value)
|
|
}, opts)
|
|
if err != nil {
|
|
t.Fatalf("copy download: %v", err)
|
|
}
|
|
if written != int64(len("payload")) {
|
|
t.Fatalf("unexpected written bytes: got %d", written)
|
|
}
|
|
if !src.usedWriteTo {
|
|
t.Fatal("expected concurrent WriteTo path to be used")
|
|
}
|
|
if got := dst.String(); got != "payload" {
|
|
t.Fatalf("unexpected copied payload: got %q", got)
|
|
}
|
|
if len(progress) == 0 || progress[len(progress)-1] != 100 {
|
|
t.Fatalf("expected final progress 100, got %v", progress)
|
|
}
|
|
}
|
|
|
|
func TestCopyDownloadWithProgressReportsDuringConcurrentWriteTo(t *testing.T) {
|
|
src := &spyConcurrentWriteTo{
|
|
payload: []byte("payload"),
|
|
chunkSize: 2,
|
|
}
|
|
var dst bytes.Buffer
|
|
var progress []float64
|
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
|
|
|
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) {
|
|
progress = append(progress, value)
|
|
}, opts)
|
|
if err != nil {
|
|
t.Fatalf("copy download: %v", err)
|
|
}
|
|
if written != int64(len("payload")) {
|
|
t.Fatalf("unexpected written bytes: got %d", written)
|
|
}
|
|
|
|
var sawIntermediate bool
|
|
for _, value := range progress {
|
|
if value > 0 && value < 100 {
|
|
sawIntermediate = true
|
|
break
|
|
}
|
|
}
|
|
if !sawIntermediate {
|
|
t.Fatalf("expected intermediate progress during concurrent writeto, got %v", progress)
|
|
}
|
|
}
|
|
|
|
func TestCopyDownloadWithProgressCancelsDuringConcurrentWriteTo(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
src := &spyConcurrentWriteTo{
|
|
payload: []byte("payload"),
|
|
chunkSize: 3,
|
|
cancelAfterWrite: cancel,
|
|
}
|
|
var dst bytes.Buffer
|
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
|
|
|
written, err := copyDownloadWithProgressContext(ctx, &dst, src, 3, int64(len("payload")), nil, opts)
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Fatalf("expected context cancellation, got written=%d err=%v", written, err)
|
|
}
|
|
if !src.usedWriteTo {
|
|
t.Fatal("expected concurrent WriteTo path to be used")
|
|
}
|
|
if written <= 0 || written >= int64(len("payload")) {
|
|
t.Fatalf("expected partial write before cancellation, got %d", written)
|
|
}
|
|
}
|
|
|
|
func TestCopyDownloadWithProgressDoesNotReportDoneBeforeConcurrentReadError(t *testing.T) {
|
|
copyErr := errors.New("read status failed")
|
|
src := &spyConcurrentWriteTo{
|
|
payload: []byte("payload"),
|
|
errAfterWrite: copyErr,
|
|
}
|
|
var dst bytes.Buffer
|
|
var progress []float64
|
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
|
|
|
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) {
|
|
progress = append(progress, value)
|
|
}, opts)
|
|
if !errors.Is(err, copyErr) {
|
|
t.Fatalf("expected concurrent read error, got written=%d err=%v", written, err)
|
|
}
|
|
if written != int64(len("payload")) {
|
|
t.Fatalf("expected local byte count before read error, got %d", written)
|
|
}
|
|
for _, value := range progress {
|
|
if value >= 100 {
|
|
t.Fatalf("progress reported completion before download success: %v", progress)
|
|
}
|
|
}
|
|
if len(progress) == 0 {
|
|
t.Fatal("expected queued progress before read error")
|
|
}
|
|
}
|
|
|
|
func TestCopyDownloadWithProgressReturnsConcurrentWriteToError(t *testing.T) {
|
|
copyErr := errors.New("writeto failed")
|
|
src := &spyConcurrentWriteTo{err: copyErr}
|
|
var dst bytes.Buffer
|
|
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
|
|
|
_, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), nil, opts)
|
|
if !errors.Is(err, copyErr) {
|
|
t.Fatalf("expected concurrent writeto error, got %v", err)
|
|
}
|
|
if !src.usedWriteTo {
|
|
t.Fatal("expected concurrent WriteTo path to be used")
|
|
}
|
|
}
|
|
|
|
func TestCopyDownloadWithProgressDefaultsToSequentialCopy(t *testing.T) {
|
|
oldCopy := sftpCopyWithProgressFunc
|
|
called := false
|
|
sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) {
|
|
called = true
|
|
return oldCopy(ctx, dst, src, bufSize, total, progress)
|
|
}
|
|
t.Cleanup(func() {
|
|
sftpCopyWithProgressFunc = oldCopy
|
|
})
|
|
|
|
var dst bytes.Buffer
|
|
src := &chunkedReader{
|
|
reader: strings.NewReader("payload"),
|
|
chunkSize: 2,
|
|
}
|
|
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), nil, mustNormalizeSFTPTransferOptions(t, nil))
|
|
if err != nil {
|
|
t.Fatalf("copy download: %v", err)
|
|
}
|
|
if written != int64(len("payload")) {
|
|
t.Fatalf("unexpected written bytes: got %d", written)
|
|
}
|
|
if !called {
|
|
t.Fatal("expected default path to use existing copy helper")
|
|
}
|
|
}
|
|
|
|
func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) {
|
|
client := newSFTPTestClient(t)
|
|
root := t.TempDir()
|
|
localPath := filepath.Join(root, "local.txt")
|
|
remotePath := filepath.Join(root, "remote.txt")
|
|
|
|
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
|
|
t.Fatalf("write local file: %v", err)
|
|
}
|
|
if err := os.WriteFile(remotePath, []byte("original remote"), 0o644); err != nil {
|
|
t.Fatalf("write remote file: %v", err)
|
|
}
|
|
|
|
verifyErr := errors.New("verify failed")
|
|
var verifiedPath string
|
|
oldVerifyRemoteSize := sftpVerifyRemoteSizeFunc
|
|
sftpVerifyRemoteSizeFunc = func(client *sftp.Client, remotePath string, expected int64) error {
|
|
verifiedPath = remotePath
|
|
return verifyErr
|
|
}
|
|
t.Cleanup(func() {
|
|
sftpVerifyRemoteSizeFunc = oldVerifyRemoteSize
|
|
})
|
|
|
|
err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil))
|
|
if !errors.Is(err, verifyErr) {
|
|
t.Fatalf("expected verify failure, got %v", err)
|
|
}
|
|
if verifiedPath == remotePath {
|
|
t.Fatal("expected upload verification to run against temp path before final rename")
|
|
}
|
|
|
|
data, err := os.ReadFile(remotePath)
|
|
if err != nil {
|
|
t.Fatalf("read remote file: %v", err)
|
|
}
|
|
if string(data) != "original remote" {
|
|
t.Fatalf("remote target was replaced on verify failure: %q", string(data))
|
|
}
|
|
assertNoTransferTemps(t, remotePath)
|
|
}
|
|
|
|
func TestTransferOutContextRejectsRemoteSymlinkTarget(t *testing.T) {
|
|
client := newSFTPTestClient(t)
|
|
root := t.TempDir()
|
|
localPath := filepath.Join(root, "local.txt")
|
|
remoteRealPath := filepath.Join(root, "remote-real.txt")
|
|
remotePath := filepath.Join(root, "remote-link.txt")
|
|
|
|
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
|
|
t.Fatalf("write local file: %v", err)
|
|
}
|
|
if err := os.WriteFile(remoteRealPath, []byte("original remote"), 0o644); err != nil {
|
|
t.Fatalf("write remote backing file: %v", err)
|
|
}
|
|
if err := os.Symlink(remoteRealPath, remotePath); err != nil {
|
|
t.Skipf("symlink unsupported: %v", err)
|
|
}
|
|
|
|
err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil))
|
|
if err == nil || !strings.Contains(err.Error(), "symlink") {
|
|
t.Fatalf("expected symlink rejection, got %v", err)
|
|
}
|
|
|
|
info, err := os.Lstat(remotePath)
|
|
if err != nil {
|
|
t.Fatalf("lstat remote symlink: %v", err)
|
|
}
|
|
if info.Mode()&os.ModeSymlink == 0 {
|
|
t.Fatal("expected remote target to remain a symlink")
|
|
}
|
|
|
|
data, err := os.ReadFile(remoteRealPath)
|
|
if err != nil {
|
|
t.Fatalf("read remote backing file: %v", err)
|
|
}
|
|
if string(data) != "original remote" {
|
|
t.Fatalf("remote backing file changed unexpectedly: %q", string(data))
|
|
}
|
|
assertNoTransferTemps(t, remotePath)
|
|
}
|
|
|
|
func TestTransferOutContextRejectsRemoteDirectoryTarget(t *testing.T) {
|
|
client := newSFTPTestClient(t)
|
|
root := t.TempDir()
|
|
localPath := filepath.Join(root, "local.txt")
|
|
remotePath := filepath.Join(root, "remote-dir")
|
|
|
|
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
|
|
t.Fatalf("write local file: %v", err)
|
|
}
|
|
if err := os.Mkdir(remotePath, 0o755); err != nil {
|
|
t.Fatalf("mkdir remote target: %v", err)
|
|
}
|
|
|
|
err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil))
|
|
if err == nil || !strings.Contains(err.Error(), "directory") {
|
|
t.Fatalf("expected directory rejection, got %v", err)
|
|
}
|
|
|
|
info, err := os.Stat(remotePath)
|
|
if err != nil {
|
|
t.Fatalf("stat remote directory: %v", err)
|
|
}
|
|
if !info.IsDir() {
|
|
t.Fatal("expected remote target to remain a directory")
|
|
}
|
|
assertNoTransferTemps(t, remotePath)
|
|
}
|
|
|
|
func TestTransferOutContextPreservesRemoteModeOnOverwrite(t *testing.T) {
|
|
client := newSFTPTestClient(t)
|
|
root := t.TempDir()
|
|
localPath := filepath.Join(root, "local.txt")
|
|
remotePath := filepath.Join(root, "remote.txt")
|
|
|
|
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
|
|
t.Fatalf("write local file: %v", err)
|
|
}
|
|
if err := os.WriteFile(remotePath, []byte("original remote"), 0o755); err != nil {
|
|
t.Fatalf("write remote file: %v", err)
|
|
}
|
|
if err := os.Chmod(remotePath, 0o755); err != nil {
|
|
t.Fatalf("chmod remote file: %v", err)
|
|
}
|
|
|
|
if err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
|
|
t.Fatalf("transfer out: %v", err)
|
|
}
|
|
|
|
assertMode(t, remotePath, 0o755)
|
|
assertFileContent(t, remotePath, "new payload")
|
|
}
|
|
|
|
func TestTransferOutContextAppliesLocalModeForNewRemoteFile(t *testing.T) {
|
|
client := newSFTPTestClient(t)
|
|
root := t.TempDir()
|
|
localPath := filepath.Join(root, "local.txt")
|
|
remotePath := filepath.Join(root, "remote.txt")
|
|
|
|
if err := os.WriteFile(localPath, []byte("new payload"), 0o751); err != nil {
|
|
t.Fatalf("write local file: %v", err)
|
|
}
|
|
if err := os.Chmod(localPath, 0o751); err != nil {
|
|
t.Fatalf("chmod local file: %v", err)
|
|
}
|
|
|
|
if err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
|
|
t.Fatalf("transfer out: %v", err)
|
|
}
|
|
|
|
assertMode(t, remotePath, 0o751)
|
|
assertFileContent(t, remotePath, "new payload")
|
|
}
|
|
|
|
func TestTransferOutByteContextPreservesRemoteModeOnOverwrite(t *testing.T) {
|
|
client := newSFTPTestClient(t)
|
|
root := t.TempDir()
|
|
remotePath := filepath.Join(root, "remote.txt")
|
|
|
|
if err := os.WriteFile(remotePath, []byte("original remote"), 0o755); err != nil {
|
|
t.Fatalf("write remote file: %v", err)
|
|
}
|
|
if err := os.Chmod(remotePath, 0o755); err != nil {
|
|
t.Fatalf("chmod remote file: %v", err)
|
|
}
|
|
|
|
if err := transferOutByteContext(context.Background(), client, []byte("byte payload"), remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
|
|
t.Fatalf("transfer out bytes: %v", err)
|
|
}
|
|
|
|
assertMode(t, remotePath, 0o755)
|
|
assertFileContent(t, remotePath, "byte payload")
|
|
}
|
|
|
|
func TestTransferInContextVerifyFailurePreservesLocalTarget(t *testing.T) {
|
|
client := newSFTPTestClient(t)
|
|
root := t.TempDir()
|
|
srcPath := filepath.Join(root, "remote.txt")
|
|
dstPath := filepath.Join(root, "local.txt")
|
|
|
|
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
|
|
t.Fatalf("write remote file: %v", err)
|
|
}
|
|
if err := os.WriteFile(dstPath, []byte("original local"), 0o644); err != nil {
|
|
t.Fatalf("write local file: %v", err)
|
|
}
|
|
|
|
verifyErr := errors.New("verify local failed")
|
|
var verifiedPath string
|
|
oldVerifyLocalSize := sftpVerifyLocalSizeFunc
|
|
sftpVerifyLocalSizeFunc = func(localPath string, expected int64) error {
|
|
verifiedPath = localPath
|
|
return verifyErr
|
|
}
|
|
t.Cleanup(func() {
|
|
sftpVerifyLocalSizeFunc = oldVerifyLocalSize
|
|
})
|
|
|
|
err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
|
|
if !errors.Is(err, verifyErr) {
|
|
t.Fatalf("expected verify failure, got %v", err)
|
|
}
|
|
if verifiedPath == dstPath {
|
|
t.Fatal("expected download verification to run against temp path before final rename")
|
|
}
|
|
|
|
data, err := os.ReadFile(dstPath)
|
|
if err != nil {
|
|
t.Fatalf("read local file: %v", err)
|
|
}
|
|
if string(data) != "original local" {
|
|
t.Fatalf("local target was replaced on verify failure: %q", string(data))
|
|
}
|
|
assertNoTransferTemps(t, dstPath)
|
|
}
|
|
|
|
func TestTransferInContextRejectsLocalSymlinkTarget(t *testing.T) {
|
|
client := newSFTPTestClient(t)
|
|
root := t.TempDir()
|
|
srcPath := filepath.Join(root, "remote.txt")
|
|
localRealPath := filepath.Join(root, "local-real.txt")
|
|
dstPath := filepath.Join(root, "local-link.txt")
|
|
|
|
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
|
|
t.Fatalf("write remote file: %v", err)
|
|
}
|
|
if err := os.WriteFile(localRealPath, []byte("original local"), 0o644); err != nil {
|
|
t.Fatalf("write local backing file: %v", err)
|
|
}
|
|
if err := os.Symlink(localRealPath, dstPath); err != nil {
|
|
t.Skipf("symlink unsupported: %v", err)
|
|
}
|
|
|
|
err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
|
|
if err == nil || !strings.Contains(err.Error(), "symlink") {
|
|
t.Fatalf("expected symlink rejection, got %v", err)
|
|
}
|
|
|
|
info, err := os.Lstat(dstPath)
|
|
if err != nil {
|
|
t.Fatalf("lstat local symlink: %v", err)
|
|
}
|
|
if info.Mode()&os.ModeSymlink == 0 {
|
|
t.Fatal("expected local target to remain a symlink")
|
|
}
|
|
assertFileContent(t, localRealPath, "original local")
|
|
assertNoTransferTemps(t, dstPath)
|
|
}
|
|
|
|
func TestTransferInContextRejectsLocalDirectoryTarget(t *testing.T) {
|
|
client := newSFTPTestClient(t)
|
|
root := t.TempDir()
|
|
srcPath := filepath.Join(root, "remote.txt")
|
|
dstPath := filepath.Join(root, "local-dir")
|
|
|
|
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
|
|
t.Fatalf("write remote file: %v", err)
|
|
}
|
|
if err := os.Mkdir(dstPath, 0o755); err != nil {
|
|
t.Fatalf("mkdir local target: %v", err)
|
|
}
|
|
|
|
err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
|
|
if err == nil || !strings.Contains(err.Error(), "directory") {
|
|
t.Fatalf("expected directory rejection, got %v", err)
|
|
}
|
|
|
|
info, err := os.Stat(dstPath)
|
|
if err != nil {
|
|
t.Fatalf("stat local directory: %v", err)
|
|
}
|
|
if !info.IsDir() {
|
|
t.Fatal("expected local target to remain a directory")
|
|
}
|
|
assertNoTransferTemps(t, dstPath)
|
|
}
|
|
|
|
func TestTransferInContextPreservesLocalModeOnOverwrite(t *testing.T) {
|
|
client := newSFTPTestClient(t)
|
|
root := t.TempDir()
|
|
srcPath := filepath.Join(root, "remote.txt")
|
|
dstPath := filepath.Join(root, "local.sh")
|
|
|
|
if err := os.WriteFile(srcPath, []byte("#!/bin/sh\necho remote\n"), 0o644); err != nil {
|
|
t.Fatalf("write remote file: %v", err)
|
|
}
|
|
if err := os.WriteFile(dstPath, []byte("#!/bin/sh\necho local\n"), 0o755); err != nil {
|
|
t.Fatalf("write local file: %v", err)
|
|
}
|
|
if err := os.Chmod(dstPath, 0o755); err != nil {
|
|
t.Fatalf("chmod local file: %v", err)
|
|
}
|
|
|
|
if err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
|
|
t.Fatalf("transfer in: %v", err)
|
|
}
|
|
|
|
assertMode(t, dstPath, 0o755)
|
|
assertFileContent(t, dstPath, "#!/bin/sh\necho remote\n")
|
|
}
|
|
|
|
func TestTransferInContextAppliesRemoteModeForNewLocalFile(t *testing.T) {
|
|
client := newSFTPTestClient(t)
|
|
root := t.TempDir()
|
|
srcPath := filepath.Join(root, "remote.sh")
|
|
dstPath := filepath.Join(root, "local.sh")
|
|
|
|
if err := os.WriteFile(srcPath, []byte("#!/bin/sh\necho remote\n"), 0o751); err != nil {
|
|
t.Fatalf("write remote file: %v", err)
|
|
}
|
|
if err := os.Chmod(srcPath, 0o751); err != nil {
|
|
t.Fatalf("chmod remote file: %v", err)
|
|
}
|
|
|
|
if err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
|
|
t.Fatalf("transfer in: %v", err)
|
|
}
|
|
|
|
assertMode(t, dstPath, 0o751)
|
|
assertFileContent(t, dstPath, "#!/bin/sh\necho remote\n")
|
|
}
|
|
|
|
func TestTransferInContextCopyFailurePreservesLocalTarget(t *testing.T) {
|
|
client := newSFTPTestClient(t)
|
|
root := t.TempDir()
|
|
srcPath := filepath.Join(root, "remote.txt")
|
|
dstPath := filepath.Join(root, "local.txt")
|
|
|
|
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
|
|
t.Fatalf("write remote file: %v", err)
|
|
}
|
|
if err := os.WriteFile(dstPath, []byte("original local"), 0o644); err != nil {
|
|
t.Fatalf("write local file: %v", err)
|
|
}
|
|
|
|
copyErr := errors.New("copy failed")
|
|
var copyTargetPath string
|
|
oldCopy := sftpCopyWithProgressFunc
|
|
sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) {
|
|
file, ok := dst.(*os.File)
|
|
if !ok {
|
|
t.Fatalf("expected local temp file writer, got %T", dst)
|
|
}
|
|
copyTargetPath = file.Name()
|
|
|
|
buf := make([]byte, 8)
|
|
n, readErr := src.Read(buf)
|
|
if readErr != nil && !errors.Is(readErr, io.EOF) {
|
|
return 0, readErr
|
|
}
|
|
if n > 0 {
|
|
written, err := dst.Write(buf[:n])
|
|
if err != nil {
|
|
return int64(written), err
|
|
}
|
|
return int64(written), copyErr
|
|
}
|
|
return 0, copyErr
|
|
}
|
|
t.Cleanup(func() {
|
|
sftpCopyWithProgressFunc = oldCopy
|
|
})
|
|
|
|
err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
|
|
if !errors.Is(err, copyErr) {
|
|
t.Fatalf("expected copy failure, got %v", err)
|
|
}
|
|
if copyTargetPath == dstPath {
|
|
t.Fatal("expected partial download writes to stay on temp path")
|
|
}
|
|
|
|
data, err := os.ReadFile(dstPath)
|
|
if err != nil {
|
|
t.Fatalf("read local file: %v", err)
|
|
}
|
|
if string(data) != "original local" {
|
|
t.Fatalf("local target was modified by partial download: %q", string(data))
|
|
}
|
|
assertNoTransferTemps(t, dstPath)
|
|
}
|
|
|
|
func mustNormalizeSFTPTransferOptions(t *testing.T, options *SFTPTransferOptions) resolvedSFTPTransferOptions {
|
|
t.Helper()
|
|
|
|
opts, err := normalizeSFTPTransferOptions(options)
|
|
if err != nil {
|
|
t.Fatalf("normalize sftp transfer options: %v", err)
|
|
}
|
|
return opts
|
|
}
|
|
|
|
func SFTPOptionsPtr(options SFTPTransferOptions) *SFTPTransferOptions {
|
|
return &options
|
|
}
|
|
|
|
// spyConcurrentReadFrom is a writer test double that also offers a
// ReadFromWithConcurrency fast path. Everything it receives is appended to
// buf. It records whether the fast path was used and with what concurrency.
type spyConcurrentReadFrom struct {
	buf          bytes.Buffer
	concurrency  int
	usedReadFrom bool

	// err fails ReadFromWithConcurrency before anything is read;
	// errAfterRead fails it only after the source has been drained into buf.
	err, errAfterRead error
}

// Write appends p to the spy's buffer.
func (w *spyConcurrentReadFrom) Write(p []byte) (int, error) {
	n, err := w.buf.Write(p)
	return n, err
}

// ReadFromWithConcurrency records the call and its concurrency, then drains r
// into buf unless an injected error short-circuits it. It returns the number
// of bytes buffered.
func (w *spyConcurrentReadFrom) ReadFromWithConcurrency(r io.Reader, concurrency int) (int64, error) {
	w.usedReadFrom, w.concurrency = true, concurrency

	if injected := w.err; injected != nil {
		return 0, injected
	}

	n, err := w.buf.ReadFrom(r)
	switch {
	case err != nil:
		return n, err
	case w.errAfterRead != nil:
		return n, w.errAfterRead
	default:
		return n, nil
	}
}
|
|
|
|
// spyConcurrentWriteTo is a reader test double exposing both Read and WriteTo.
// WriteTo pushes payload to the destination in chunkSize pieces (the whole
// payload at once when chunkSize <= 0) without consuming it. Read consumes
// payload.
type spyConcurrentWriteTo struct {
	payload   []byte
	chunkSize int

	usedWriteTo bool

	// err fails WriteTo before any write, and is returned by Read instead of
	// io.EOF once payload is exhausted. errAfterWrite fails WriteTo only
	// after the whole payload has been written.
	err, errAfterWrite error

	// cancelAfterWrite, when set, is invoked by WriteTo right after the first
	// write that reports n > 0, and is then cleared.
	cancelAfterWrite context.CancelFunc
}

// Read copies as much of the remaining payload as fits into p. It reports
// io.EOF together with the final bytes. Once the payload is empty, it returns
// r.err if set, otherwise io.EOF.
func (r *spyConcurrentWriteTo) Read(p []byte) (int, error) {
	if len(r.payload) == 0 {
		if r.err == nil {
			return 0, io.EOF
		}
		return 0, r.err
	}

	n := copy(p, r.payload)
	r.payload = r.payload[n:]

	var err error
	if len(r.payload) == 0 {
		err = io.EOF
	}
	return n, err
}

// WriteTo records that the fast path was taken and writes payload to w chunk
// by chunk. It stops at the first write error, or with io.ErrShortWrite when w
// accepts a different byte count than it was offered. It returns the total
// number of bytes w reported as written.
func (r *spyConcurrentWriteTo) WriteTo(w io.Writer) (int64, error) {
	r.usedWriteTo = true
	if r.err != nil {
		return 0, r.err
	}

	remaining := r.payload
	step := r.chunkSize
	if step <= 0 || step > len(remaining) {
		step = len(remaining)
	}

	var total int64
	for len(remaining) > 0 {
		chunk := remaining
		if len(chunk) > step {
			chunk = chunk[:step]
		}

		n, werr := w.Write(chunk)
		total += int64(n)

		if n > 0 && r.cancelAfterWrite != nil {
			r.cancelAfterWrite()
			r.cancelAfterWrite = nil
		}

		switch {
		case werr != nil:
			return total, werr
		case n != len(chunk):
			return total, io.ErrShortWrite
		}
		remaining = remaining[len(chunk):]
	}

	// An empty payload skips the loop and lands here with total == 0.
	if r.errAfterWrite != nil {
		return total, r.errAfterWrite
	}
	return total, nil
}
|
|
|
|
// chunkedReader wraps reader and caps every Read at chunkSize bytes, so
// consumers see short reads. A chunkSize <= 0 disables the cap.
type chunkedReader struct {
	reader    io.Reader
	chunkSize int
}

// Read forwards to the wrapped reader with p truncated to at most chunkSize
// bytes.
func (r *chunkedReader) Read(p []byte) (int, error) {
	limit := len(p)
	if r.chunkSize > 0 && limit > r.chunkSize {
		limit = r.chunkSize
	}
	return r.reader.Read(p[:limit])
}
|
|
|
|
// cancelAfterReadReader wraps reader, caps each Read at chunkSize bytes (no
// cap when chunkSize <= 0), and invokes cancel once, right after the first
// Read that returns data.
type cancelAfterReadReader struct {
	reader    io.Reader
	chunkSize int
	cancel    context.CancelFunc
}

// Read forwards to the wrapped reader. After the first call that yields
// n > 0, it fires cancel and clears it so that cancel runs only once.
func (r *cancelAfterReadReader) Read(p []byte) (int, error) {
	if limit := r.chunkSize; limit > 0 && len(p) > limit {
		p = p[:limit]
	}

	n, err := r.reader.Read(p)
	if n == 0 || r.cancel == nil {
		return n, err
	}
	r.cancel()
	r.cancel = nil
	return n, err
}
|
|
|
|
func newSFTPTestClient(t *testing.T) *sftp.Client {
|
|
return newSFTPTestClientWithOptions(t, nil)
|
|
}
|
|
|
|
func newSFTPTestClientWithOptions(t *testing.T, options []sftp.ClientOption) *sftp.Client {
|
|
t.Helper()
|
|
|
|
serverConn, clientConn := net.Pipe()
|
|
server, err := sftp.NewServer(serverConn)
|
|
if err != nil {
|
|
t.Fatalf("create sftp server: %v", err)
|
|
}
|
|
|
|
serveErrCh := make(chan error, 1)
|
|
go func() {
|
|
serveErrCh <- server.Serve()
|
|
}()
|
|
|
|
client, err := sftp.NewClientPipe(clientConn, clientConn, options...)
|
|
if err != nil {
|
|
_ = server.Close()
|
|
t.Fatalf("create sftp client: %v", err)
|
|
}
|
|
|
|
t.Cleanup(func() {
|
|
_ = client.Close()
|
|
_ = server.Close()
|
|
serveErr := <-serveErrCh
|
|
if serveErr == nil || errors.Is(serveErr, io.EOF) || normalizeAlreadyClosedError(serveErr) == nil {
|
|
return
|
|
}
|
|
t.Errorf("unexpected sftp server error: %v", serveErr)
|
|
})
|
|
|
|
return client
|
|
}
|
|
|
|
func assertNoTransferTemps(t *testing.T, targetPath string) {
|
|
t.Helper()
|
|
|
|
matches, err := filepath.Glob(targetPath + defaultSFTPTempSuffix + "*")
|
|
if err != nil {
|
|
t.Fatalf("glob temp files: %v", err)
|
|
}
|
|
if len(matches) != 0 {
|
|
t.Fatalf("expected temp artifacts to be cleaned up, got %v", matches)
|
|
}
|
|
}
|
|
|
|
// assertMode fails the test unless the permission bits of targetPath equal
// want. Only Perm() bits are compared, so file-type bits are ignored.
func assertMode(t *testing.T, targetPath string, want os.FileMode) {
	t.Helper()

	fi, statErr := os.Stat(targetPath)
	if statErr != nil {
		t.Fatalf("stat %q: %v", targetPath, statErr)
	}

	perm := fi.Mode().Perm()
	if perm == want {
		return
	}
	t.Fatalf("unexpected mode for %q: got %o want %o", targetPath, perm, want)
}
|
|
|
|
// assertFileContent fails the test unless the file at targetPath exists and
// its bytes, read as a string, equal want exactly.
func assertFileContent(t *testing.T, targetPath string, want string) {
	t.Helper()

	raw, readErr := os.ReadFile(targetPath)
	if readErr != nil {
		t.Fatalf("read %q: %v", targetPath, readErr)
	}
	if got := string(raw); got != want {
		t.Fatalf("unexpected content for %q: got %q want %q", targetPath, got, want)
	}
}
|