feat(sftp): 增加可选的并发传输策略

- 增加 SFTP client 级配置,支持 packet size、单文件并发请求数、并发读和并发写
- 将吞吐优化限制在 StarSSH 托管的 SFTP client 路径中
- 上传和下载在显式启用时使用并发快路径,同时保留原子传输生命周期
- 避免快路径失败前提前上报 100% 进度
- 补充安全校验、托管 client 配置、进度、取消和下载对齐的回归测试
This commit is contained in:
2026-06-22 03:26:47 +08:00
parent 0c23e7d4bf
commit 672a111ec1
2 changed files with 1036 additions and 76 deletions
+345 -37
View File
@@ -13,6 +13,7 @@ import (
"path" "path"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"time" "time"
"github.com/pkg/sftp" "github.com/pkg/sftp"
@@ -36,6 +37,19 @@ type SFTPTransferOptions struct {
VerifySize *bool VerifySize *bool
VerifyChecksum *bool VerifyChecksum *bool
TempSuffix string TempSuffix string
Client SFTPClientOptions
}
// SFTPClientOptions controls the underlying SFTP protocol client.
//
// These options only apply when StarSSH creates the SFTP client internally.
// They are intentionally separate from BufferSize, which only controls the
// local copy buffer used by transfer progress reporting.
type SFTPClientOptions struct {
MaxPacketSize int
MaxConcurrentRequestsPerFile int
ConcurrentReads *bool
ConcurrentWrites *bool
} }
type resolvedSFTPTransferOptions struct { type resolvedSFTPTransferOptions struct {
@@ -48,6 +62,38 @@ type resolvedSFTPTransferOptions struct {
VerifySize bool VerifySize bool
VerifyChecksum bool VerifyChecksum bool
TempSuffix string TempSuffix string
Client resolvedSFTPClientOptions
}
type resolvedSFTPClientOptions struct {
MaxPacketSize int
MaxConcurrentRequestsPerFile int
ConcurrentReads *bool
ConcurrentWrites *bool
}
type sftpConcurrentReaderFrom interface {
ReadFromWithConcurrency(io.Reader, int) (int64, error)
}
type sftpUploadProgressReader struct {
ctx context.Context
reader io.Reader
total int64
progress func(float64)
mu sync.Mutex
copied int64
}
type sftpDownloadProgressWriter struct {
ctx context.Context
writer io.Writer
total int64
progress func(float64)
mu sync.Mutex
copied int64
} }
type SFTPErrorCategory string type SFTPErrorCategory string
@@ -106,6 +152,7 @@ var (
sftpVerifyLocalSizeFunc = verifyLocalSize sftpVerifyLocalSizeFunc = verifyLocalSize
sftpLocalFileSHA256Func = localFileSHA256 sftpLocalFileSHA256Func = localFileSHA256
sftpRemoteFileSHA256Func = remoteFileSHA256 sftpRemoteFileSHA256Func = remoteFileSHA256
sftpNewClientFunc = sftp.NewClient
) )
func DefaultSFTPTransferOptions() SFTPTransferOptions { func DefaultSFTPTransferOptions() SFTPTransferOptions {
@@ -121,6 +168,16 @@ func DefaultSFTPTransferOptions() SFTPTransferOptions {
} }
} }
func ThroughputSFTPTransferOptions() SFTPTransferOptions {
opts := DefaultSFTPTransferOptions()
opts.Client = SFTPClientOptions{
ConcurrentReads: SFTPBool(true),
ConcurrentWrites: SFTPBool(true),
MaxConcurrentRequestsPerFile: 32,
}
return opts
}
func SFTPBool(value bool) *bool { func SFTPBool(value bool) *bool {
return &value return &value
} }
@@ -316,22 +373,9 @@ func (fs *SFTPFileSystem) Rename(ctx context.Context, oldPath string, newPath st
}) })
} }
func normalizeSFTPTransferOptions(options *SFTPTransferOptions) resolvedSFTPTransferOptions { func normalizeSFTPTransferOptions(options *SFTPTransferOptions) (resolvedSFTPTransferOptions, error) {
opts := DefaultSFTPTransferOptions() opts := DefaultSFTPTransferOptions()
if options == nil { if options != nil {
return resolvedSFTPTransferOptions{
BufferSize: opts.BufferSize,
Progress: opts.Progress,
RetryCount: normalizeSFTPRetryCount(derefSFTPInt(opts.RetryCount, defaultSFTPRetryCount)),
RetryInitialBackoff: derefSFTPDuration(opts.RetryInitialBackoff, defaultSFTPRetryInitialBackoff),
AtomicUpload: derefSFTPBool(opts.AtomicUpload, true),
AtomicDownload: derefSFTPBool(opts.AtomicDownload, true),
VerifySize: derefSFTPBool(opts.VerifySize, true),
VerifyChecksum: derefSFTPBool(opts.VerifyChecksum, false),
TempSuffix: normalizeSFTPTempSuffix(opts.TempSuffix),
}
}
if options.BufferSize > 0 { if options.BufferSize > 0 {
opts.BufferSize = options.BufferSize opts.BufferSize = options.BufferSize
} }
@@ -359,8 +403,10 @@ func normalizeSFTPTransferOptions(options *SFTPTransferOptions) resolvedSFTPTran
if strings.TrimSpace(options.TempSuffix) != "" { if strings.TrimSpace(options.TempSuffix) != "" {
opts.TempSuffix = options.TempSuffix opts.TempSuffix = options.TempSuffix
} }
opts.Client = options.Client
}
return resolvedSFTPTransferOptions{ resolved := resolvedSFTPTransferOptions{
BufferSize: opts.BufferSize, BufferSize: opts.BufferSize,
Progress: opts.Progress, Progress: opts.Progress,
RetryCount: normalizeSFTPRetryCount(derefSFTPInt(opts.RetryCount, defaultSFTPRetryCount)), RetryCount: normalizeSFTPRetryCount(derefSFTPInt(opts.RetryCount, defaultSFTPRetryCount)),
@@ -371,6 +417,11 @@ func normalizeSFTPTransferOptions(options *SFTPTransferOptions) resolvedSFTPTran
VerifyChecksum: derefSFTPBool(opts.VerifyChecksum, false), VerifyChecksum: derefSFTPBool(opts.VerifyChecksum, false),
TempSuffix: normalizeSFTPTempSuffix(opts.TempSuffix), TempSuffix: normalizeSFTPTempSuffix(opts.TempSuffix),
} }
resolved.Client = normalizeSFTPClientOptions(opts.Client)
if err := validateResolvedSFTPTransferOptions(resolved); err != nil {
return resolvedSFTPTransferOptions{}, err
}
return resolved, nil
} }
func derefSFTPBool(value *bool, fallback bool) bool { func derefSFTPBool(value *bool, fallback bool) bool {
@@ -409,13 +460,56 @@ func normalizeSFTPRetryCount(value int) int {
return value return value
} }
func normalizeSFTPClientOptions(options SFTPClientOptions) resolvedSFTPClientOptions {
return resolvedSFTPClientOptions{
MaxPacketSize: options.MaxPacketSize,
MaxConcurrentRequestsPerFile: options.MaxConcurrentRequestsPerFile,
ConcurrentReads: options.ConcurrentReads,
ConcurrentWrites: options.ConcurrentWrites,
}
}
func validateResolvedSFTPTransferOptions(opts resolvedSFTPTransferOptions) error {
return validateResolvedSFTPClientOptions(opts.Client)
}
func validateResolvedSFTPClientOptions(opts resolvedSFTPClientOptions) error {
if opts.MaxPacketSize < 0 {
return errors.New("sftp max packet size must not be negative")
}
if opts.MaxConcurrentRequestsPerFile < 0 {
return errors.New("sftp max concurrent requests per file must not be negative")
}
return nil
}
func rejectExternalSFTPClientOptions(opts resolvedSFTPClientOptions) error {
if opts.MaxPacketSize != 0 ||
opts.MaxConcurrentRequestsPerFile != 0 ||
opts.ConcurrentReads != nil ||
opts.ConcurrentWrites != nil {
return errors.New("sftp client options require StarSSH-managed SFTP client")
}
return nil
}
func validateSFTPUploadOptions(opts resolvedSFTPTransferOptions) error {
if derefSFTPBool(opts.Client.ConcurrentWrites, false) && !opts.AtomicUpload {
return errors.New("sftp concurrent writes require atomic upload")
}
return nil
}
func (star *StarSSH) runSFTPClientOperation(ctx context.Context, operation string, remotePath string, fn func(*sftp.Client) error) error { func (star *StarSSH) runSFTPClientOperation(ctx context.Context, operation string, remotePath string, fn func(*sftp.Client) error) error {
if err := ensureContext(ctx); err != nil { if err := ensureContext(ctx); err != nil {
return err return err
} }
opts := normalizeSFTPTransferOptions(nil) opts, err := normalizeSFTPTransferOptions(nil)
if err != nil {
return err
}
return executeSFTPRetry(ctx, operation, "", remotePath, opts, func(attempt int) error { return executeSFTPRetry(ctx, operation, "", remotePath, opts, func(attempt int) error {
return star.withIsolatedSFTPClient(ctx, fn) return star.withIsolatedSFTPClient(ctx, opts.Client, fn)
}) })
} }
@@ -423,23 +517,27 @@ func (star *StarSSH) runSFTPClientOperationNoRetry(ctx context.Context, fn func(
if err := ensureContext(ctx); err != nil { if err := ensureContext(ctx); err != nil {
return err return err
} }
return star.withIsolatedSFTPClient(ctx, fn) return star.withIsolatedSFTPClient(ctx, resolvedSFTPClientOptions{}, fn)
} }
func (star *StarSSH) CreateSftpClient() (*sftp.Client, error) { func (star *StarSSH) CreateSftpClient() (*sftp.Client, error) {
return star.createSFTPClientWithOptions(resolvedSFTPClientOptions{})
}
func (star *StarSSH) createSFTPClientWithOptions(options resolvedSFTPClientOptions) (*sftp.Client, error) {
client, err := star.requireSSHClient() client, err := star.requireSSHClient()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return sftp.NewClient(client) return sftpNewClientFunc(client, buildSFTPClientOptions(options)...)
} }
func (star *StarSSH) withIsolatedSFTPClient(ctx context.Context, fn func(*sftp.Client) error) error { func (star *StarSSH) withIsolatedSFTPClient(ctx context.Context, options resolvedSFTPClientOptions, fn func(*sftp.Client) error) error {
if err := ensureContext(ctx); err != nil { if err := ensureContext(ctx); err != nil {
return err return err
} }
client, err := star.CreateSftpClient() client, err := star.createSFTPClientWithOptions(options)
if err != nil { if err != nil {
return err return err
} }
@@ -448,6 +546,23 @@ func (star *StarSSH) withIsolatedSFTPClient(ctx context.Context, fn func(*sftp.C
return fn(client) return fn(client)
} }
func buildSFTPClientOptions(options resolvedSFTPClientOptions) []sftp.ClientOption {
clientOptions := make([]sftp.ClientOption, 0, 4)
if options.MaxPacketSize > 0 {
clientOptions = append(clientOptions, sftp.MaxPacketChecked(options.MaxPacketSize))
}
if options.MaxConcurrentRequestsPerFile > 0 {
clientOptions = append(clientOptions, sftp.MaxConcurrentRequestsPerFile(options.MaxConcurrentRequestsPerFile))
}
if options.ConcurrentReads != nil {
clientOptions = append(clientOptions, sftp.UseConcurrentReads(*options.ConcurrentReads))
}
if options.ConcurrentWrites != nil {
clientOptions = append(clientOptions, sftp.UseConcurrentWrites(*options.ConcurrentWrites))
}
return clientOptions
}
func (star *StarSSH) getReusableSFTPClient() (*sftp.Client, error) { func (star *StarSSH) getReusableSFTPClient() (*sftp.Client, error) {
if star == nil { if star == nil {
return nil, errors.New("ssh client is nil") return nil, errors.New("ssh client is nil")
@@ -526,7 +641,7 @@ func (star *StarSSH) runSFTPWithRetry(
fn func(context.Context, *sftp.Client, resolvedSFTPTransferOptions) error, fn func(context.Context, *sftp.Client, resolvedSFTPTransferOptions) error,
) error { ) error {
return executeSFTPRetry(ctx, operation, localPath, remotePath, opts, func(attempt int) error { return executeSFTPRetry(ctx, operation, localPath, remotePath, opts, func(attempt int) error {
return star.withIsolatedSFTPClient(ctx, func(client *sftp.Client) error { return star.withIsolatedSFTPClient(ctx, opts.Client, func(client *sftp.Client) error {
return fn(ctx, client, opts) return fn(ctx, client, opts)
}) })
}) })
@@ -537,7 +652,13 @@ func (star *StarSSH) SftpTransferOut(localFilePath, remotePath string) error {
} }
func (star *StarSSH) SftpTransferOutContext(ctx context.Context, localFilePath, remotePath string, options *SFTPTransferOptions) error { func (star *StarSSH) SftpTransferOutContext(ctx context.Context, localFilePath, remotePath string, options *SFTPTransferOptions) error {
opts := normalizeSFTPTransferOptions(options) opts, err := normalizeSFTPTransferOptions(options)
if err != nil {
return err
}
if err := validateSFTPUploadOptions(opts); err != nil {
return err
}
return star.runSFTPWithRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error { return star.runSFTPWithRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
return transferOutContext(ctx, client, localFilePath, remotePath, opts) return transferOutContext(ctx, client, localFilePath, remotePath, opts)
}) })
@@ -548,7 +669,16 @@ func SftpTransferOut(localFilePath, remotePath string, sftpClient *sftp.Client)
} }
func SftpTransferOutWithContext(ctx context.Context, localFilePath, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error { func SftpTransferOutWithContext(ctx context.Context, localFilePath, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error {
opts := normalizeSFTPTransferOptions(options) opts, err := normalizeSFTPTransferOptions(options)
if err != nil {
return err
}
if err := validateSFTPUploadOptions(opts); err != nil {
return err
}
if err := rejectExternalSFTPClientOptions(opts.Client); err != nil {
return err
}
return executeSFTPRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(attempt int) error { return executeSFTPRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(attempt int) error {
return transferOutContext(ctx, sftpClient, localFilePath, remotePath, opts) return transferOutContext(ctx, sftpClient, localFilePath, remotePath, opts)
}) })
@@ -559,7 +689,13 @@ func (star *StarSSH) SftpTransferOutByte(localData []byte, remotePath string) er
} }
func (star *StarSSH) SftpTransferOutByteContext(ctx context.Context, localData []byte, remotePath string, options *SFTPTransferOptions) error { func (star *StarSSH) SftpTransferOutByteContext(ctx context.Context, localData []byte, remotePath string, options *SFTPTransferOptions) error {
opts := normalizeSFTPTransferOptions(options) opts, err := normalizeSFTPTransferOptions(options)
if err != nil {
return err
}
if err := validateSFTPUploadOptions(opts); err != nil {
return err
}
return star.runSFTPWithRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error { return star.runSFTPWithRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
return transferOutByteContext(ctx, client, localData, remotePath, opts) return transferOutByteContext(ctx, client, localData, remotePath, opts)
}) })
@@ -570,7 +706,16 @@ func SftpTransferOutByte(localData []byte, remotePath string, sftpClient *sftp.C
} }
func SftpTransferOutByteWithContext(ctx context.Context, localData []byte, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error { func SftpTransferOutByteWithContext(ctx context.Context, localData []byte, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error {
opts := normalizeSFTPTransferOptions(options) opts, err := normalizeSFTPTransferOptions(options)
if err != nil {
return err
}
if err := validateSFTPUploadOptions(opts); err != nil {
return err
}
if err := rejectExternalSFTPClientOptions(opts.Client); err != nil {
return err
}
return executeSFTPRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(attempt int) error { return executeSFTPRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(attempt int) error {
return transferOutByteContext(ctx, sftpClient, localData, remotePath, opts) return transferOutByteContext(ctx, sftpClient, localData, remotePath, opts)
}) })
@@ -595,10 +740,13 @@ func (star *StarSSH) SftpTransferInByte(remotePath string) ([]byte, error) {
} }
func (star *StarSSH) SftpTransferInByteContext(ctx context.Context, remotePath string, options *SFTPTransferOptions) ([]byte, error) { func (star *StarSSH) SftpTransferInByteContext(ctx context.Context, remotePath string, options *SFTPTransferOptions) ([]byte, error) {
opts := normalizeSFTPTransferOptions(options) opts, err := normalizeSFTPTransferOptions(options)
if err != nil {
return nil, err
}
var data []byte var data []byte
err := star.runSFTPWithRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error { err = star.runSFTPWithRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
out, runErr := transferInByteContext(ctx, client, remotePath, opts) out, runErr := transferInByteContext(ctx, client, remotePath, opts)
if runErr != nil { if runErr != nil {
return runErr return runErr
@@ -617,10 +765,16 @@ func SftpTransferInByte(remotePath string, sftpClient *sftp.Client) ([]byte, err
} }
func SftpTransferInByteWithContext(ctx context.Context, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) ([]byte, error) { func SftpTransferInByteWithContext(ctx context.Context, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) ([]byte, error) {
opts := normalizeSFTPTransferOptions(options) opts, err := normalizeSFTPTransferOptions(options)
if err != nil {
return nil, err
}
if err := rejectExternalSFTPClientOptions(opts.Client); err != nil {
return nil, err
}
var data []byte var data []byte
err := executeSFTPRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(attempt int) error { err = executeSFTPRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(attempt int) error {
out, runErr := transferInByteContext(ctx, sftpClient, remotePath, opts) out, runErr := transferInByteContext(ctx, sftpClient, remotePath, opts)
if runErr != nil { if runErr != nil {
return runErr return runErr
@@ -639,7 +793,10 @@ func (star *StarSSH) SftpTransferIn(src, dst string) error {
} }
func (star *StarSSH) SftpTransferInContext(ctx context.Context, src, dst string, options *SFTPTransferOptions) error { func (star *StarSSH) SftpTransferInContext(ctx context.Context, src, dst string, options *SFTPTransferOptions) error {
opts := normalizeSFTPTransferOptions(options) opts, err := normalizeSFTPTransferOptions(options)
if err != nil {
return err
}
return star.runSFTPWithRetry(ctx, "sftp_get_file", dst, src, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error { return star.runSFTPWithRetry(ctx, "sftp_get_file", dst, src, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
return transferInContext(ctx, client, src, dst, opts) return transferInContext(ctx, client, src, dst, opts)
}) })
@@ -650,7 +807,13 @@ func SftpTransferIn(src, dst string, sftpClient *sftp.Client) error {
} }
func SftpTransferInWithContext(ctx context.Context, src, dst string, sftpClient *sftp.Client, options *SFTPTransferOptions) error { func SftpTransferInWithContext(ctx context.Context, src, dst string, sftpClient *sftp.Client, options *SFTPTransferOptions) error {
opts := normalizeSFTPTransferOptions(options) opts, err := normalizeSFTPTransferOptions(options)
if err != nil {
return err
}
if err := rejectExternalSFTPClientOptions(opts.Client); err != nil {
return err
}
return executeSFTPRetry(ctx, "sftp_get_file", dst, src, opts, func(attempt int) error { return executeSFTPRetry(ctx, "sftp_get_file", dst, src, opts, func(attempt int) error {
return transferInContext(ctx, sftpClient, src, dst, opts) return transferInContext(ctx, sftpClient, src, dst, opts)
}) })
@@ -712,7 +875,7 @@ func transferOutContext(ctx context.Context, sftpClient *sftp.Client, localFileP
return err return err
} }
if _, err := sftpCopyWithProgressFunc(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil { if _, err := copyUploadWithProgressContext(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress, opts); err != nil {
_ = dstFile.Close() _ = dstFile.Close()
return err return err
} }
@@ -792,7 +955,7 @@ func transferOutByteContext(ctx context.Context, sftpClient *sftp.Client, localD
} }
reader := bytes.NewReader(localData) reader := bytes.NewReader(localData)
if _, err := sftpCopyWithProgressFunc(ctx, dstFile, reader, opts.BufferSize, int64(len(localData)), opts.Progress); err != nil { if _, err := copyUploadWithProgressContext(ctx, dstFile, reader, opts.BufferSize, int64(len(localData)), opts.Progress, opts); err != nil {
_ = dstFile.Close() _ = dstFile.Close()
return err return err
} }
@@ -877,7 +1040,7 @@ func transferInContext(ctx context.Context, sftpClient *sftp.Client, src, dst st
}() }()
} }
if _, err := sftpCopyWithProgressFunc(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil { if _, err := copyDownloadWithProgressContext(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress, opts); err != nil {
_ = dstFile.Close() _ = dstFile.Close()
return err return err
} }
@@ -948,7 +1111,7 @@ func transferInByteContext(ctx context.Context, sftpClient *sftp.Client, remoteP
} }
var out bytes.Buffer var out bytes.Buffer
if _, err := sftpCopyWithProgressFunc(ctx, &out, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil { if _, err := copyDownloadWithProgressContext(ctx, &out, srcFile, opts.BufferSize, stat.Size(), opts.Progress, opts); err != nil {
return nil, err return nil, err
} }
@@ -1149,6 +1312,151 @@ func copyWithProgressContext(ctx context.Context, dst io.Writer, src io.Reader,
return copied, nil return copied, nil
} }
func copyUploadWithProgressContext(
ctx context.Context,
dst io.Writer,
src io.Reader,
bufSize int,
total int64,
progress func(float64),
opts resolvedSFTPTransferOptions,
) (int64, error) {
if !derefSFTPBool(opts.Client.ConcurrentWrites, false) {
return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress)
}
readerFrom, ok := dst.(sftpConcurrentReaderFrom)
if !ok {
return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress)
}
if err := ensureContext(ctx); err != nil {
return 0, err
}
if progress != nil && total > 0 {
progress(0)
}
wrappedSrc := &sftpUploadProgressReader{
ctx: ctx,
reader: src,
total: total,
progress: progress,
}
written, err := readerFrom.ReadFromWithConcurrency(wrappedSrc, opts.Client.MaxConcurrentRequestsPerFile)
if err != nil {
return written, err
}
if err := ensureContext(ctx); err != nil {
return written, err
}
reportProgress(progress, written, total)
return written, nil
}
func copyDownloadWithProgressContext(
ctx context.Context,
dst io.Writer,
src io.Reader,
bufSize int,
total int64,
progress func(float64),
opts resolvedSFTPTransferOptions,
) (int64, error) {
if !derefSFTPBool(opts.Client.ConcurrentReads, false) {
return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress)
}
writerTo, ok := src.(io.WriterTo)
if !ok {
return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress)
}
if err := ensureContext(ctx); err != nil {
return 0, err
}
if progress != nil && total > 0 {
progress(0)
}
wrappedDst := &sftpDownloadProgressWriter{
ctx: ctx,
writer: dst,
total: total,
progress: progress,
}
written, err := writerTo.WriteTo(wrappedDst)
if err != nil {
return written, err
}
if err := ensureContext(ctx); err != nil {
return written, err
}
reportProgress(progress, written, total)
return written, nil
}
func (r *sftpUploadProgressReader) Read(p []byte) (int, error) {
if err := ensureContext(r.ctx); err != nil {
return 0, err
}
r.mu.Lock()
defer r.mu.Unlock()
if err := ensureContext(r.ctx); err != nil {
return 0, err
}
n, err := r.reader.Read(p)
if n > 0 {
r.copied += int64(n)
reportQueuedTransferProgress(r.progress, r.copied, r.total)
}
return n, err
}
func (w *sftpDownloadProgressWriter) Write(p []byte) (int, error) {
if err := ensureContext(w.ctx); err != nil {
return 0, err
}
w.mu.Lock()
defer w.mu.Unlock()
if err := ensureContext(w.ctx); err != nil {
return 0, err
}
n, err := w.writer.Write(p)
if n > 0 {
w.copied += int64(n)
reportQueuedTransferProgress(w.progress, w.copied, w.total)
}
if err != nil {
return n, err
}
if n != len(p) {
return n, io.ErrShortWrite
}
if err := ensureContext(w.ctx); err != nil {
return n, err
}
return n, nil
}
func reportQueuedTransferProgress(progress func(float64), copied int64, total int64) {
if progress == nil || total <= 0 {
return
}
percent := float64(copied) / float64(total) * 100
if percent >= 100 {
percent = 99
}
progress(percent)
}
func reportProgress(progress func(float64), copied int64, total int64) { func reportProgress(progress func(float64), copied int64, total int64) {
if progress == nil { if progress == nil {
return return
+666 -14
View File
@@ -1,6 +1,7 @@
package starssh package starssh
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"io" "io"
@@ -11,16 +12,521 @@ import (
"testing" "testing"
"github.com/pkg/sftp" "github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
) )
func TestNormalizeSFTPTransferOptionsDefaultsAtomicDownload(t *testing.T) { func TestNormalizeSFTPTransferOptionsDefaultsAtomicDownload(t *testing.T) {
opts := normalizeSFTPTransferOptions(nil) opts := mustNormalizeSFTPTransferOptions(t, nil)
if !opts.AtomicUpload { if !opts.AtomicUpload {
t.Fatal("expected atomic upload to default to enabled") t.Fatal("expected atomic upload to default to enabled")
} }
if !opts.AtomicDownload { if !opts.AtomicDownload {
t.Fatal("expected atomic download to default to enabled") t.Fatal("expected atomic download to default to enabled")
} }
if opts.Client.ConcurrentWrites != nil {
t.Fatal("expected concurrent writes to default to unset")
}
}
func TestThroughputSFTPTransferOptionsEnablesExplicitConcurrentWrites(t *testing.T) {
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
if !opts.AtomicUpload {
t.Fatal("expected throughput preset to keep atomic upload enabled")
}
if opts.Client.ConcurrentReads == nil || !*opts.Client.ConcurrentReads {
t.Fatal("expected throughput preset to enable concurrent reads")
}
if opts.Client.ConcurrentWrites == nil || !*opts.Client.ConcurrentWrites {
t.Fatal("expected throughput preset to enable concurrent writes")
}
if opts.Client.MaxConcurrentRequestsPerFile != 32 {
t.Fatalf("unexpected max concurrent requests: got %d want 32", opts.Client.MaxConcurrentRequestsPerFile)
}
}
func TestValidateSFTPUploadOptionsRejectsConcurrentWritesWithoutAtomicUpload(t *testing.T) {
options := ThroughputSFTPTransferOptions()
options.AtomicUpload = SFTPBool(false)
opts := mustNormalizeSFTPTransferOptions(t, &options)
err := validateSFTPUploadOptions(opts)
if err == nil || !strings.Contains(err.Error(), "atomic upload") {
t.Fatalf("expected atomic upload rejection, got %v", err)
}
}
func TestNormalizeSFTPTransferOptionsRejectsNegativeClientValues(t *testing.T) {
_, err := normalizeSFTPTransferOptions(&SFTPTransferOptions{
Client: SFTPClientOptions{MaxPacketSize: -1},
})
if err == nil || !strings.Contains(err.Error(), "max packet") {
t.Fatalf("expected max packet rejection, got %v", err)
}
_, err = normalizeSFTPTransferOptions(&SFTPTransferOptions{
Client: SFTPClientOptions{MaxConcurrentRequestsPerFile: -1},
})
if err == nil || !strings.Contains(err.Error(), "max concurrent") {
t.Fatalf("expected max concurrent rejection, got %v", err)
}
}
func TestBuildSFTPClientOptionsRejectsUnsupportedCheckedPacketSize(t *testing.T) {
options := buildSFTPClientOptions(resolvedSFTPClientOptions{MaxPacketSize: 32769})
client := &sftp.Client{}
err := options[0](client)
if err == nil || !strings.Contains(err.Error(), "32KB") {
t.Fatalf("expected checked packet size rejection, got %v", err)
}
}
func TestSftpTransferOutWithContextRejectsClientOptionsForExternalClient(t *testing.T) {
client := newSFTPTestClient(t)
root := t.TempDir()
localPath := filepath.Join(root, "local.txt")
remotePath := filepath.Join(root, "remote.txt")
if err := os.WriteFile(localPath, []byte("payload"), 0o644); err != nil {
t.Fatalf("write local file: %v", err)
}
err := SftpTransferOutWithContext(context.Background(), localPath, remotePath, client, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
if err == nil || !strings.Contains(err.Error(), "StarSSH-managed") {
t.Fatalf("expected external client option rejection, got %v", err)
}
}
func TestSftpTransferOutContextPassesClientOptionsToManagedClient(t *testing.T) {
root := t.TempDir()
localPath := filepath.Join(root, "local.txt")
remotePath := filepath.Join(root, "remote.txt")
if err := os.WriteFile(localPath, []byte("payload"), 0o644); err != nil {
t.Fatalf("write local file: %v", err)
}
client := newSFTPTestClient(t)
options := ThroughputSFTPTransferOptions()
var captured []sftp.ClientOption
oldNewClient := sftpNewClientFunc
sftpNewClientFunc = func(_ *ssh.Client, opts ...sftp.ClientOption) (*sftp.Client, error) {
captured = append([]sftp.ClientOption(nil), opts...)
return client, nil
}
t.Cleanup(func() {
sftpNewClientFunc = oldNewClient
})
star := &StarSSH{Client: &ssh.Client{}}
if err := star.SftpTransferOutContext(context.Background(), localPath, remotePath, &options); err != nil {
t.Fatalf("transfer out: %v", err)
}
if len(captured) == 0 {
t.Fatal("expected managed SFTP client options to be passed to factory")
}
if got, want := len(captured), len(buildSFTPClientOptions(mustNormalizeSFTPTransferOptions(t, &options).Client)); got != want {
t.Fatalf("unexpected client option count: got %d want %d", got, want)
}
}
func TestSftpTransferInWithContextRejectsClientOptionsForExternalClient(t *testing.T) {
client := newSFTPTestClient(t)
root := t.TempDir()
srcPath := filepath.Join(root, "remote.txt")
dstPath := filepath.Join(root, "local.txt")
if err := os.WriteFile(srcPath, []byte("payload"), 0o644); err != nil {
t.Fatalf("write remote file: %v", err)
}
err := SftpTransferInWithContext(context.Background(), srcPath, dstPath, client, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
if err == nil || !strings.Contains(err.Error(), "StarSSH-managed") {
t.Fatalf("expected external client option rejection, got %v", err)
}
}
func TestSftpTransferInContextPassesClientOptionsToManagedClient(t *testing.T) {
root := t.TempDir()
srcPath := filepath.Join(root, "remote.txt")
dstPath := filepath.Join(root, "local.txt")
if err := os.WriteFile(srcPath, []byte("payload"), 0o644); err != nil {
t.Fatalf("write remote file: %v", err)
}
client := newSFTPTestClient(t)
options := ThroughputSFTPTransferOptions()
var captured []sftp.ClientOption
oldNewClient := sftpNewClientFunc
sftpNewClientFunc = func(_ *ssh.Client, opts ...sftp.ClientOption) (*sftp.Client, error) {
captured = append([]sftp.ClientOption(nil), opts...)
return client, nil
}
t.Cleanup(func() {
sftpNewClientFunc = oldNewClient
})
star := &StarSSH{Client: &ssh.Client{}}
if err := star.SftpTransferInContext(context.Background(), srcPath, dstPath, &options); err != nil {
t.Fatalf("transfer in: %v", err)
}
if len(captured) == 0 {
t.Fatal("expected managed SFTP client options to be passed to factory")
}
if got, want := len(captured), len(buildSFTPClientOptions(mustNormalizeSFTPTransferOptions(t, &options).Client)); got != want {
t.Fatalf("unexpected client option count: got %d want %d", got, want)
}
assertFileContent(t, dstPath, "payload")
}
func TestBuildSFTPClientOptionsCanTransfer(t *testing.T) {
client := newSFTPTestClientWithOptions(t, buildSFTPClientOptions(resolvedSFTPClientOptions{
MaxPacketSize: 4096,
MaxConcurrentRequestsPerFile: 4,
ConcurrentWrites: SFTPBool(true),
}))
root := t.TempDir()
localPath := filepath.Join(root, "local.txt")
remotePath := filepath.Join(root, "remote.txt")
if err := os.WriteFile(localPath, []byte(strings.Repeat("payload-", 2048)), 0o644); err != nil {
t.Fatalf("write local file: %v", err)
}
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
if err := transferOutContext(context.Background(), client, localPath, remotePath, opts); err != nil {
t.Fatalf("transfer out with client options: %v", err)
}
data, err := os.ReadFile(remotePath)
if err != nil {
t.Fatalf("read remote file: %v", err)
}
if got := string(data); got != strings.Repeat("payload-", 2048) {
t.Fatalf("unexpected remote payload length: got %d", len(got))
}
}
func TestBuildSFTPClientOptionsCanDownload(t *testing.T) {
client := newSFTPTestClientWithOptions(t, buildSFTPClientOptions(resolvedSFTPClientOptions{
MaxPacketSize: 4096,
MaxConcurrentRequestsPerFile: 4,
ConcurrentReads: SFTPBool(true),
}))
root := t.TempDir()
srcPath := filepath.Join(root, "remote.txt")
dstPath := filepath.Join(root, "local.txt")
payload := strings.Repeat("payload-", 2048)
if err := os.WriteFile(srcPath, []byte(payload), 0o644); err != nil {
t.Fatalf("write remote file: %v", err)
}
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
if err := transferInContext(context.Background(), client, srcPath, dstPath, opts); err != nil {
t.Fatalf("transfer in with client options: %v", err)
}
assertFileContent(t, dstPath, payload)
}
func TestCopyUploadWithProgressUsesConcurrentReadFromWhenEnabled(t *testing.T) {
dst := &spyConcurrentReadFrom{}
src := strings.NewReader("payload")
var progress []float64
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
written, err := copyUploadWithProgressContext(context.Background(), dst, src, 3, int64(len("payload")), func(value float64) {
progress = append(progress, value)
}, opts)
if err != nil {
t.Fatalf("copy upload: %v", err)
}
if written != int64(len("payload")) {
t.Fatalf("unexpected written bytes: got %d", written)
}
if !dst.usedReadFrom {
t.Fatal("expected concurrent ReadFrom path to be used")
}
if dst.concurrency != opts.Client.MaxConcurrentRequestsPerFile {
t.Fatalf("unexpected concurrency: got %d want %d", dst.concurrency, opts.Client.MaxConcurrentRequestsPerFile)
}
if got := dst.buf.String(); got != "payload" {
t.Fatalf("unexpected copied payload: got %q", got)
}
if len(progress) == 0 || progress[len(progress)-1] != 100 {
t.Fatalf("expected final progress 100, got %v", progress)
}
}
func TestCopyUploadWithProgressReportsDuringConcurrentReadFrom(t *testing.T) {
dst := &spyConcurrentReadFrom{}
src := &chunkedReader{
reader: strings.NewReader("payload"),
chunkSize: 2,
}
var progress []float64
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
written, err := copyUploadWithProgressContext(context.Background(), dst, src, 3, int64(len("payload")), func(value float64) {
progress = append(progress, value)
}, opts)
if err != nil {
t.Fatalf("copy upload: %v", err)
}
if written != int64(len("payload")) {
t.Fatalf("unexpected written bytes: got %d", written)
}
var sawIntermediate bool
for _, value := range progress {
if value > 0 && value < 100 {
sawIntermediate = true
break
}
}
if !sawIntermediate {
t.Fatalf("expected intermediate progress during concurrent readfrom, got %v", progress)
}
}
func TestCopyUploadWithProgressCancelsDuringConcurrentReadFrom(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
dst := &spyConcurrentReadFrom{}
src := &cancelAfterReadReader{
reader: strings.NewReader("payload"),
chunkSize: 3,
cancel: cancel,
}
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
written, err := copyUploadWithProgressContext(ctx, dst, src, 3, int64(len("payload")), nil, opts)
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context cancellation, got written=%d err=%v", written, err)
}
if !dst.usedReadFrom {
t.Fatal("expected concurrent ReadFrom path to be used")
}
if written <= 0 || written >= int64(len("payload")) {
t.Fatalf("expected partial write before cancellation, got %d", written)
}
}
func TestCopyUploadWithProgressDoesNotReportDoneBeforeConcurrentWriteError(t *testing.T) {
copyErr := errors.New("write status failed")
dst := &spyConcurrentReadFrom{errAfterRead: copyErr}
var progress []float64
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
written, err := copyUploadWithProgressContext(context.Background(), dst, strings.NewReader("payload"), 3, int64(len("payload")), func(value float64) {
progress = append(progress, value)
}, opts)
if !errors.Is(err, copyErr) {
t.Fatalf("expected concurrent write error, got written=%d err=%v", written, err)
}
if written != int64(len("payload")) {
t.Fatalf("expected read byte count before write error, got %d", written)
}
for _, value := range progress {
if value >= 100 {
t.Fatalf("progress reported completion before write success: %v", progress)
}
}
if len(progress) == 0 {
t.Fatal("expected queued progress before write error")
}
}
func TestCopyUploadWithProgressReturnsConcurrentReadFromError(t *testing.T) {
copyErr := errors.New("readfrom failed")
dst := &spyConcurrentReadFrom{err: copyErr}
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
_, err := copyUploadWithProgressContext(context.Background(), dst, strings.NewReader("payload"), 3, int64(len("payload")), nil, opts)
if !errors.Is(err, copyErr) {
t.Fatalf("expected concurrent readfrom error, got %v", err)
}
if !dst.usedReadFrom {
t.Fatal("expected concurrent ReadFrom path to be used")
}
}
func TestCopyUploadWithProgressDefaultsToSequentialCopy(t *testing.T) {
oldCopy := sftpCopyWithProgressFunc
called := false
sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) {
called = true
return oldCopy(ctx, dst, src, bufSize, total, progress)
}
t.Cleanup(func() {
sftpCopyWithProgressFunc = oldCopy
})
var dst bytes.Buffer
written, err := copyUploadWithProgressContext(context.Background(), &dst, strings.NewReader("payload"), 3, int64(len("payload")), nil, mustNormalizeSFTPTransferOptions(t, nil))
if err != nil {
t.Fatalf("copy upload: %v", err)
}
if written != int64(len("payload")) {
t.Fatalf("unexpected written bytes: got %d", written)
}
if !called {
t.Fatal("expected default path to use existing copy helper")
}
}
func TestCopyDownloadWithProgressUsesConcurrentWriteToWhenEnabled(t *testing.T) {
src := &spyConcurrentWriteTo{payload: []byte("payload")}
var dst bytes.Buffer
var progress []float64
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) {
progress = append(progress, value)
}, opts)
if err != nil {
t.Fatalf("copy download: %v", err)
}
if written != int64(len("payload")) {
t.Fatalf("unexpected written bytes: got %d", written)
}
if !src.usedWriteTo {
t.Fatal("expected concurrent WriteTo path to be used")
}
if got := dst.String(); got != "payload" {
t.Fatalf("unexpected copied payload: got %q", got)
}
if len(progress) == 0 || progress[len(progress)-1] != 100 {
t.Fatalf("expected final progress 100, got %v", progress)
}
}
func TestCopyDownloadWithProgressReportsDuringConcurrentWriteTo(t *testing.T) {
src := &spyConcurrentWriteTo{
payload: []byte("payload"),
chunkSize: 2,
}
var dst bytes.Buffer
var progress []float64
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) {
progress = append(progress, value)
}, opts)
if err != nil {
t.Fatalf("copy download: %v", err)
}
if written != int64(len("payload")) {
t.Fatalf("unexpected written bytes: got %d", written)
}
var sawIntermediate bool
for _, value := range progress {
if value > 0 && value < 100 {
sawIntermediate = true
break
}
}
if !sawIntermediate {
t.Fatalf("expected intermediate progress during concurrent writeto, got %v", progress)
}
}
func TestCopyDownloadWithProgressCancelsDuringConcurrentWriteTo(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
src := &spyConcurrentWriteTo{
payload: []byte("payload"),
chunkSize: 3,
cancelAfterWrite: cancel,
}
var dst bytes.Buffer
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
written, err := copyDownloadWithProgressContext(ctx, &dst, src, 3, int64(len("payload")), nil, opts)
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context cancellation, got written=%d err=%v", written, err)
}
if !src.usedWriteTo {
t.Fatal("expected concurrent WriteTo path to be used")
}
if written <= 0 || written >= int64(len("payload")) {
t.Fatalf("expected partial write before cancellation, got %d", written)
}
}
func TestCopyDownloadWithProgressDoesNotReportDoneBeforeConcurrentReadError(t *testing.T) {
copyErr := errors.New("read status failed")
src := &spyConcurrentWriteTo{
payload: []byte("payload"),
errAfterWrite: copyErr,
}
var dst bytes.Buffer
var progress []float64
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) {
progress = append(progress, value)
}, opts)
if !errors.Is(err, copyErr) {
t.Fatalf("expected concurrent read error, got written=%d err=%v", written, err)
}
if written != int64(len("payload")) {
t.Fatalf("expected local byte count before read error, got %d", written)
}
for _, value := range progress {
if value >= 100 {
t.Fatalf("progress reported completion before download success: %v", progress)
}
}
if len(progress) == 0 {
t.Fatal("expected queued progress before read error")
}
}
func TestCopyDownloadWithProgressReturnsConcurrentWriteToError(t *testing.T) {
copyErr := errors.New("writeto failed")
src := &spyConcurrentWriteTo{err: copyErr}
var dst bytes.Buffer
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
_, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), nil, opts)
if !errors.Is(err, copyErr) {
t.Fatalf("expected concurrent writeto error, got %v", err)
}
if !src.usedWriteTo {
t.Fatal("expected concurrent WriteTo path to be used")
}
}
func TestCopyDownloadWithProgressDefaultsToSequentialCopy(t *testing.T) {
oldCopy := sftpCopyWithProgressFunc
called := false
sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) {
called = true
return oldCopy(ctx, dst, src, bufSize, total, progress)
}
t.Cleanup(func() {
sftpCopyWithProgressFunc = oldCopy
})
var dst bytes.Buffer
src := &chunkedReader{
reader: strings.NewReader("payload"),
chunkSize: 2,
}
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), nil, mustNormalizeSFTPTransferOptions(t, nil))
if err != nil {
t.Fatalf("copy download: %v", err)
}
if written != int64(len("payload")) {
t.Fatalf("unexpected written bytes: got %d", written)
}
if !called {
t.Fatal("expected default path to use existing copy helper")
}
} }
func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) { func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) {
@@ -47,7 +553,7 @@ func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) {
sftpVerifyRemoteSizeFunc = oldVerifyRemoteSize sftpVerifyRemoteSizeFunc = oldVerifyRemoteSize
}) })
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)) err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil))
if !errors.Is(err, verifyErr) { if !errors.Is(err, verifyErr) {
t.Fatalf("expected verify failure, got %v", err) t.Fatalf("expected verify failure, got %v", err)
} }
@@ -82,7 +588,7 @@ func TestTransferOutContextRejectsRemoteSymlinkTarget(t *testing.T) {
t.Skipf("symlink unsupported: %v", err) t.Skipf("symlink unsupported: %v", err)
} }
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)) err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil))
if err == nil || !strings.Contains(err.Error(), "symlink") { if err == nil || !strings.Contains(err.Error(), "symlink") {
t.Fatalf("expected symlink rejection, got %v", err) t.Fatalf("expected symlink rejection, got %v", err)
} }
@@ -118,7 +624,7 @@ func TestTransferOutContextRejectsRemoteDirectoryTarget(t *testing.T) {
t.Fatalf("mkdir remote target: %v", err) t.Fatalf("mkdir remote target: %v", err)
} }
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)) err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil))
if err == nil || !strings.Contains(err.Error(), "directory") { if err == nil || !strings.Contains(err.Error(), "directory") {
t.Fatalf("expected directory rejection, got %v", err) t.Fatalf("expected directory rejection, got %v", err)
} }
@@ -149,7 +655,7 @@ func TestTransferOutContextPreservesRemoteModeOnOverwrite(t *testing.T) {
t.Fatalf("chmod remote file: %v", err) t.Fatalf("chmod remote file: %v", err)
} }
if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil { if err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
t.Fatalf("transfer out: %v", err) t.Fatalf("transfer out: %v", err)
} }
@@ -170,7 +676,7 @@ func TestTransferOutContextAppliesLocalModeForNewRemoteFile(t *testing.T) {
t.Fatalf("chmod local file: %v", err) t.Fatalf("chmod local file: %v", err)
} }
if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil { if err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
t.Fatalf("transfer out: %v", err) t.Fatalf("transfer out: %v", err)
} }
@@ -190,7 +696,7 @@ func TestTransferOutByteContextPreservesRemoteModeOnOverwrite(t *testing.T) {
t.Fatalf("chmod remote file: %v", err) t.Fatalf("chmod remote file: %v", err)
} }
if err := transferOutByteContext(context.Background(), client, []byte("byte payload"), remotePath, normalizeSFTPTransferOptions(nil)); err != nil { if err := transferOutByteContext(context.Background(), client, []byte("byte payload"), remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
t.Fatalf("transfer out bytes: %v", err) t.Fatalf("transfer out bytes: %v", err)
} }
@@ -222,7 +728,7 @@ func TestTransferInContextVerifyFailurePreservesLocalTarget(t *testing.T) {
sftpVerifyLocalSizeFunc = oldVerifyLocalSize sftpVerifyLocalSizeFunc = oldVerifyLocalSize
}) })
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)) err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
if !errors.Is(err, verifyErr) { if !errors.Is(err, verifyErr) {
t.Fatalf("expected verify failure, got %v", err) t.Fatalf("expected verify failure, got %v", err)
} }
@@ -257,7 +763,7 @@ func TestTransferInContextRejectsLocalSymlinkTarget(t *testing.T) {
t.Skipf("symlink unsupported: %v", err) t.Skipf("symlink unsupported: %v", err)
} }
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)) err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
if err == nil || !strings.Contains(err.Error(), "symlink") { if err == nil || !strings.Contains(err.Error(), "symlink") {
t.Fatalf("expected symlink rejection, got %v", err) t.Fatalf("expected symlink rejection, got %v", err)
} }
@@ -286,7 +792,7 @@ func TestTransferInContextRejectsLocalDirectoryTarget(t *testing.T) {
t.Fatalf("mkdir local target: %v", err) t.Fatalf("mkdir local target: %v", err)
} }
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)) err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
if err == nil || !strings.Contains(err.Error(), "directory") { if err == nil || !strings.Contains(err.Error(), "directory") {
t.Fatalf("expected directory rejection, got %v", err) t.Fatalf("expected directory rejection, got %v", err)
} }
@@ -317,7 +823,7 @@ func TestTransferInContextPreservesLocalModeOnOverwrite(t *testing.T) {
t.Fatalf("chmod local file: %v", err) t.Fatalf("chmod local file: %v", err)
} }
if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil { if err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
t.Fatalf("transfer in: %v", err) t.Fatalf("transfer in: %v", err)
} }
@@ -338,7 +844,7 @@ func TestTransferInContextAppliesRemoteModeForNewLocalFile(t *testing.T) {
t.Fatalf("chmod remote file: %v", err) t.Fatalf("chmod remote file: %v", err)
} }
if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil { if err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
t.Fatalf("transfer in: %v", err) t.Fatalf("transfer in: %v", err)
} }
@@ -387,7 +893,7 @@ func TestTransferInContextCopyFailurePreservesLocalTarget(t *testing.T) {
sftpCopyWithProgressFunc = oldCopy sftpCopyWithProgressFunc = oldCopy
}) })
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)) err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
if !errors.Is(err, copyErr) { if !errors.Is(err, copyErr) {
t.Fatalf("expected copy failure, got %v", err) t.Fatalf("expected copy failure, got %v", err)
} }
@@ -405,7 +911,153 @@ func TestTransferInContextCopyFailurePreservesLocalTarget(t *testing.T) {
assertNoTransferTemps(t, dstPath) assertNoTransferTemps(t, dstPath)
} }
func mustNormalizeSFTPTransferOptions(t *testing.T, options *SFTPTransferOptions) resolvedSFTPTransferOptions {
t.Helper()
opts, err := normalizeSFTPTransferOptions(options)
if err != nil {
t.Fatalf("normalize sftp transfer options: %v", err)
}
return opts
}
func SFTPOptionsPtr(options SFTPTransferOptions) *SFTPTransferOptions {
return &options
}
type spyConcurrentReadFrom struct {
buf bytes.Buffer
concurrency int
usedReadFrom bool
err error
errAfterRead error
}
func (w *spyConcurrentReadFrom) Write(p []byte) (int, error) {
return w.buf.Write(p)
}
func (w *spyConcurrentReadFrom) ReadFromWithConcurrency(r io.Reader, concurrency int) (int64, error) {
w.usedReadFrom = true
w.concurrency = concurrency
if w.err != nil {
return 0, w.err
}
written, err := w.buf.ReadFrom(r)
if err != nil {
return written, err
}
if w.errAfterRead != nil {
return written, w.errAfterRead
}
return written, nil
}
type spyConcurrentWriteTo struct {
payload []byte
chunkSize int
usedWriteTo bool
err error
errAfterWrite error
cancelAfterWrite context.CancelFunc
}
func (r *spyConcurrentWriteTo) Read(p []byte) (int, error) {
if len(r.payload) == 0 {
if r.err != nil {
return 0, r.err
}
return 0, io.EOF
}
n := copy(p, r.payload)
r.payload = r.payload[n:]
if len(r.payload) == 0 {
return n, io.EOF
}
return n, nil
}
func (r *spyConcurrentWriteTo) WriteTo(w io.Writer) (int64, error) {
r.usedWriteTo = true
if r.err != nil {
return 0, r.err
}
payload := r.payload
if len(payload) == 0 {
if r.errAfterWrite != nil {
return 0, r.errAfterWrite
}
return 0, nil
}
chunkSize := r.chunkSize
if chunkSize <= 0 || chunkSize > len(payload) {
chunkSize = len(payload)
}
var written int64
for len(payload) > 0 {
size := chunkSize
if size > len(payload) {
size = len(payload)
}
n, err := w.Write(payload[:size])
written += int64(n)
if r.cancelAfterWrite != nil && n > 0 {
r.cancelAfterWrite()
r.cancelAfterWrite = nil
}
if err != nil {
return written, err
}
if n != size {
return written, io.ErrShortWrite
}
payload = payload[size:]
}
if r.errAfterWrite != nil {
return written, r.errAfterWrite
}
return written, nil
}
type chunkedReader struct {
reader io.Reader
chunkSize int
}
func (r *chunkedReader) Read(p []byte) (int, error) {
if r.chunkSize > 0 && len(p) > r.chunkSize {
p = p[:r.chunkSize]
}
return r.reader.Read(p)
}
type cancelAfterReadReader struct {
reader io.Reader
chunkSize int
cancel context.CancelFunc
}
func (r *cancelAfterReadReader) Read(p []byte) (int, error) {
if r.chunkSize > 0 && len(p) > r.chunkSize {
p = p[:r.chunkSize]
}
n, err := r.reader.Read(p)
if n > 0 && r.cancel != nil {
r.cancel()
r.cancel = nil
}
return n, err
}
func newSFTPTestClient(t *testing.T) *sftp.Client { func newSFTPTestClient(t *testing.T) *sftp.Client {
return newSFTPTestClientWithOptions(t, nil)
}
func newSFTPTestClientWithOptions(t *testing.T, options []sftp.ClientOption) *sftp.Client {
t.Helper() t.Helper()
serverConn, clientConn := net.Pipe() serverConn, clientConn := net.Pipe()
@@ -419,7 +1071,7 @@ func newSFTPTestClient(t *testing.T) *sftp.Client {
serveErrCh <- server.Serve() serveErrCh <- server.Serve()
}() }()
client, err := sftp.NewClientPipe(clientConn, clientConn) client, err := sftp.NewClientPipe(clientConn, clientConn, options...)
if err != nil { if err != nil {
_ = server.Close() _ = server.Close()
t.Fatalf("create sftp client: %v", err) t.Fatalf("create sftp client: %v", err)