feat(sftp): 增加可选的并发传输策略
- 增加 SFTP client 级配置,支持 packet size、单文件并发请求数、并发读和并发写 - 将吞吐优化限制在 StarSSH 托管的 SFTP client 路径中 - 上传和下载在显式启用时使用并发快路径,同时保留原子传输生命周期 - 避免快路径失败前提前上报 100% 进度 - 补充安全校验、托管 client 配置、进度、取消和下载对齐的回归测试
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
@@ -36,6 +37,19 @@ type SFTPTransferOptions struct {
|
||||
VerifySize *bool
|
||||
VerifyChecksum *bool
|
||||
TempSuffix string
|
||||
Client SFTPClientOptions
|
||||
}
|
||||
|
||||
// SFTPClientOptions controls the underlying SFTP protocol client.
|
||||
//
|
||||
// These options only apply when StarSSH creates the SFTP client internally.
|
||||
// They are intentionally separate from BufferSize, which only controls the
|
||||
// local copy buffer used by transfer progress reporting.
|
||||
type SFTPClientOptions struct {
|
||||
MaxPacketSize int
|
||||
MaxConcurrentRequestsPerFile int
|
||||
ConcurrentReads *bool
|
||||
ConcurrentWrites *bool
|
||||
}
|
||||
|
||||
type resolvedSFTPTransferOptions struct {
|
||||
@@ -48,6 +62,38 @@ type resolvedSFTPTransferOptions struct {
|
||||
VerifySize bool
|
||||
VerifyChecksum bool
|
||||
TempSuffix string
|
||||
Client resolvedSFTPClientOptions
|
||||
}
|
||||
|
||||
type resolvedSFTPClientOptions struct {
|
||||
MaxPacketSize int
|
||||
MaxConcurrentRequestsPerFile int
|
||||
ConcurrentReads *bool
|
||||
ConcurrentWrites *bool
|
||||
}
|
||||
|
||||
type sftpConcurrentReaderFrom interface {
|
||||
ReadFromWithConcurrency(io.Reader, int) (int64, error)
|
||||
}
|
||||
|
||||
type sftpUploadProgressReader struct {
|
||||
ctx context.Context
|
||||
reader io.Reader
|
||||
total int64
|
||||
progress func(float64)
|
||||
|
||||
mu sync.Mutex
|
||||
copied int64
|
||||
}
|
||||
|
||||
type sftpDownloadProgressWriter struct {
|
||||
ctx context.Context
|
||||
writer io.Writer
|
||||
total int64
|
||||
progress func(float64)
|
||||
|
||||
mu sync.Mutex
|
||||
copied int64
|
||||
}
|
||||
|
||||
type SFTPErrorCategory string
|
||||
@@ -106,6 +152,7 @@ var (
|
||||
sftpVerifyLocalSizeFunc = verifyLocalSize
|
||||
sftpLocalFileSHA256Func = localFileSHA256
|
||||
sftpRemoteFileSHA256Func = remoteFileSHA256
|
||||
sftpNewClientFunc = sftp.NewClient
|
||||
)
|
||||
|
||||
func DefaultSFTPTransferOptions() SFTPTransferOptions {
|
||||
@@ -121,6 +168,16 @@ func DefaultSFTPTransferOptions() SFTPTransferOptions {
|
||||
}
|
||||
}
|
||||
|
||||
func ThroughputSFTPTransferOptions() SFTPTransferOptions {
|
||||
opts := DefaultSFTPTransferOptions()
|
||||
opts.Client = SFTPClientOptions{
|
||||
ConcurrentReads: SFTPBool(true),
|
||||
ConcurrentWrites: SFTPBool(true),
|
||||
MaxConcurrentRequestsPerFile: 32,
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func SFTPBool(value bool) *bool {
|
||||
return &value
|
||||
}
|
||||
@@ -316,22 +373,9 @@ func (fs *SFTPFileSystem) Rename(ctx context.Context, oldPath string, newPath st
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeSFTPTransferOptions(options *SFTPTransferOptions) resolvedSFTPTransferOptions {
|
||||
func normalizeSFTPTransferOptions(options *SFTPTransferOptions) (resolvedSFTPTransferOptions, error) {
|
||||
opts := DefaultSFTPTransferOptions()
|
||||
if options == nil {
|
||||
return resolvedSFTPTransferOptions{
|
||||
BufferSize: opts.BufferSize,
|
||||
Progress: opts.Progress,
|
||||
RetryCount: normalizeSFTPRetryCount(derefSFTPInt(opts.RetryCount, defaultSFTPRetryCount)),
|
||||
RetryInitialBackoff: derefSFTPDuration(opts.RetryInitialBackoff, defaultSFTPRetryInitialBackoff),
|
||||
AtomicUpload: derefSFTPBool(opts.AtomicUpload, true),
|
||||
AtomicDownload: derefSFTPBool(opts.AtomicDownload, true),
|
||||
VerifySize: derefSFTPBool(opts.VerifySize, true),
|
||||
VerifyChecksum: derefSFTPBool(opts.VerifyChecksum, false),
|
||||
TempSuffix: normalizeSFTPTempSuffix(opts.TempSuffix),
|
||||
}
|
||||
}
|
||||
|
||||
if options != nil {
|
||||
if options.BufferSize > 0 {
|
||||
opts.BufferSize = options.BufferSize
|
||||
}
|
||||
@@ -359,8 +403,10 @@ func normalizeSFTPTransferOptions(options *SFTPTransferOptions) resolvedSFTPTran
|
||||
if strings.TrimSpace(options.TempSuffix) != "" {
|
||||
opts.TempSuffix = options.TempSuffix
|
||||
}
|
||||
opts.Client = options.Client
|
||||
}
|
||||
|
||||
return resolvedSFTPTransferOptions{
|
||||
resolved := resolvedSFTPTransferOptions{
|
||||
BufferSize: opts.BufferSize,
|
||||
Progress: opts.Progress,
|
||||
RetryCount: normalizeSFTPRetryCount(derefSFTPInt(opts.RetryCount, defaultSFTPRetryCount)),
|
||||
@@ -371,6 +417,11 @@ func normalizeSFTPTransferOptions(options *SFTPTransferOptions) resolvedSFTPTran
|
||||
VerifyChecksum: derefSFTPBool(opts.VerifyChecksum, false),
|
||||
TempSuffix: normalizeSFTPTempSuffix(opts.TempSuffix),
|
||||
}
|
||||
resolved.Client = normalizeSFTPClientOptions(opts.Client)
|
||||
if err := validateResolvedSFTPTransferOptions(resolved); err != nil {
|
||||
return resolvedSFTPTransferOptions{}, err
|
||||
}
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
func derefSFTPBool(value *bool, fallback bool) bool {
|
||||
@@ -409,13 +460,56 @@ func normalizeSFTPRetryCount(value int) int {
|
||||
return value
|
||||
}
|
||||
|
||||
func normalizeSFTPClientOptions(options SFTPClientOptions) resolvedSFTPClientOptions {
|
||||
return resolvedSFTPClientOptions{
|
||||
MaxPacketSize: options.MaxPacketSize,
|
||||
MaxConcurrentRequestsPerFile: options.MaxConcurrentRequestsPerFile,
|
||||
ConcurrentReads: options.ConcurrentReads,
|
||||
ConcurrentWrites: options.ConcurrentWrites,
|
||||
}
|
||||
}
|
||||
|
||||
func validateResolvedSFTPTransferOptions(opts resolvedSFTPTransferOptions) error {
|
||||
return validateResolvedSFTPClientOptions(opts.Client)
|
||||
}
|
||||
|
||||
func validateResolvedSFTPClientOptions(opts resolvedSFTPClientOptions) error {
|
||||
if opts.MaxPacketSize < 0 {
|
||||
return errors.New("sftp max packet size must not be negative")
|
||||
}
|
||||
if opts.MaxConcurrentRequestsPerFile < 0 {
|
||||
return errors.New("sftp max concurrent requests per file must not be negative")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func rejectExternalSFTPClientOptions(opts resolvedSFTPClientOptions) error {
|
||||
if opts.MaxPacketSize != 0 ||
|
||||
opts.MaxConcurrentRequestsPerFile != 0 ||
|
||||
opts.ConcurrentReads != nil ||
|
||||
opts.ConcurrentWrites != nil {
|
||||
return errors.New("sftp client options require StarSSH-managed SFTP client")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSFTPUploadOptions(opts resolvedSFTPTransferOptions) error {
|
||||
if derefSFTPBool(opts.Client.ConcurrentWrites, false) && !opts.AtomicUpload {
|
||||
return errors.New("sftp concurrent writes require atomic upload")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (star *StarSSH) runSFTPClientOperation(ctx context.Context, operation string, remotePath string, fn func(*sftp.Client) error) error {
|
||||
if err := ensureContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
opts := normalizeSFTPTransferOptions(nil)
|
||||
opts, err := normalizeSFTPTransferOptions(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return executeSFTPRetry(ctx, operation, "", remotePath, opts, func(attempt int) error {
|
||||
return star.withIsolatedSFTPClient(ctx, fn)
|
||||
return star.withIsolatedSFTPClient(ctx, opts.Client, fn)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -423,23 +517,27 @@ func (star *StarSSH) runSFTPClientOperationNoRetry(ctx context.Context, fn func(
|
||||
if err := ensureContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return star.withIsolatedSFTPClient(ctx, fn)
|
||||
return star.withIsolatedSFTPClient(ctx, resolvedSFTPClientOptions{}, fn)
|
||||
}
|
||||
|
||||
func (star *StarSSH) CreateSftpClient() (*sftp.Client, error) {
|
||||
return star.createSFTPClientWithOptions(resolvedSFTPClientOptions{})
|
||||
}
|
||||
|
||||
func (star *StarSSH) createSFTPClientWithOptions(options resolvedSFTPClientOptions) (*sftp.Client, error) {
|
||||
client, err := star.requireSSHClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sftp.NewClient(client)
|
||||
return sftpNewClientFunc(client, buildSFTPClientOptions(options)...)
|
||||
}
|
||||
|
||||
func (star *StarSSH) withIsolatedSFTPClient(ctx context.Context, fn func(*sftp.Client) error) error {
|
||||
func (star *StarSSH) withIsolatedSFTPClient(ctx context.Context, options resolvedSFTPClientOptions, fn func(*sftp.Client) error) error {
|
||||
if err := ensureContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := star.CreateSftpClient()
|
||||
client, err := star.createSFTPClientWithOptions(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -448,6 +546,23 @@ func (star *StarSSH) withIsolatedSFTPClient(ctx context.Context, fn func(*sftp.C
|
||||
return fn(client)
|
||||
}
|
||||
|
||||
func buildSFTPClientOptions(options resolvedSFTPClientOptions) []sftp.ClientOption {
|
||||
clientOptions := make([]sftp.ClientOption, 0, 4)
|
||||
if options.MaxPacketSize > 0 {
|
||||
clientOptions = append(clientOptions, sftp.MaxPacketChecked(options.MaxPacketSize))
|
||||
}
|
||||
if options.MaxConcurrentRequestsPerFile > 0 {
|
||||
clientOptions = append(clientOptions, sftp.MaxConcurrentRequestsPerFile(options.MaxConcurrentRequestsPerFile))
|
||||
}
|
||||
if options.ConcurrentReads != nil {
|
||||
clientOptions = append(clientOptions, sftp.UseConcurrentReads(*options.ConcurrentReads))
|
||||
}
|
||||
if options.ConcurrentWrites != nil {
|
||||
clientOptions = append(clientOptions, sftp.UseConcurrentWrites(*options.ConcurrentWrites))
|
||||
}
|
||||
return clientOptions
|
||||
}
|
||||
|
||||
func (star *StarSSH) getReusableSFTPClient() (*sftp.Client, error) {
|
||||
if star == nil {
|
||||
return nil, errors.New("ssh client is nil")
|
||||
@@ -526,7 +641,7 @@ func (star *StarSSH) runSFTPWithRetry(
|
||||
fn func(context.Context, *sftp.Client, resolvedSFTPTransferOptions) error,
|
||||
) error {
|
||||
return executeSFTPRetry(ctx, operation, localPath, remotePath, opts, func(attempt int) error {
|
||||
return star.withIsolatedSFTPClient(ctx, func(client *sftp.Client) error {
|
||||
return star.withIsolatedSFTPClient(ctx, opts.Client, func(client *sftp.Client) error {
|
||||
return fn(ctx, client, opts)
|
||||
})
|
||||
})
|
||||
@@ -537,7 +652,13 @@ func (star *StarSSH) SftpTransferOut(localFilePath, remotePath string) error {
|
||||
}
|
||||
|
||||
func (star *StarSSH) SftpTransferOutContext(ctx context.Context, localFilePath, remotePath string, options *SFTPTransferOptions) error {
|
||||
opts := normalizeSFTPTransferOptions(options)
|
||||
opts, err := normalizeSFTPTransferOptions(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateSFTPUploadOptions(opts); err != nil {
|
||||
return err
|
||||
}
|
||||
return star.runSFTPWithRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
|
||||
return transferOutContext(ctx, client, localFilePath, remotePath, opts)
|
||||
})
|
||||
@@ -548,7 +669,16 @@ func SftpTransferOut(localFilePath, remotePath string, sftpClient *sftp.Client)
|
||||
}
|
||||
|
||||
func SftpTransferOutWithContext(ctx context.Context, localFilePath, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error {
|
||||
opts := normalizeSFTPTransferOptions(options)
|
||||
opts, err := normalizeSFTPTransferOptions(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateSFTPUploadOptions(opts); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := rejectExternalSFTPClientOptions(opts.Client); err != nil {
|
||||
return err
|
||||
}
|
||||
return executeSFTPRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(attempt int) error {
|
||||
return transferOutContext(ctx, sftpClient, localFilePath, remotePath, opts)
|
||||
})
|
||||
@@ -559,7 +689,13 @@ func (star *StarSSH) SftpTransferOutByte(localData []byte, remotePath string) er
|
||||
}
|
||||
|
||||
func (star *StarSSH) SftpTransferOutByteContext(ctx context.Context, localData []byte, remotePath string, options *SFTPTransferOptions) error {
|
||||
opts := normalizeSFTPTransferOptions(options)
|
||||
opts, err := normalizeSFTPTransferOptions(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateSFTPUploadOptions(opts); err != nil {
|
||||
return err
|
||||
}
|
||||
return star.runSFTPWithRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
|
||||
return transferOutByteContext(ctx, client, localData, remotePath, opts)
|
||||
})
|
||||
@@ -570,7 +706,16 @@ func SftpTransferOutByte(localData []byte, remotePath string, sftpClient *sftp.C
|
||||
}
|
||||
|
||||
func SftpTransferOutByteWithContext(ctx context.Context, localData []byte, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error {
|
||||
opts := normalizeSFTPTransferOptions(options)
|
||||
opts, err := normalizeSFTPTransferOptions(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateSFTPUploadOptions(opts); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := rejectExternalSFTPClientOptions(opts.Client); err != nil {
|
||||
return err
|
||||
}
|
||||
return executeSFTPRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(attempt int) error {
|
||||
return transferOutByteContext(ctx, sftpClient, localData, remotePath, opts)
|
||||
})
|
||||
@@ -595,10 +740,13 @@ func (star *StarSSH) SftpTransferInByte(remotePath string) ([]byte, error) {
|
||||
}
|
||||
|
||||
func (star *StarSSH) SftpTransferInByteContext(ctx context.Context, remotePath string, options *SFTPTransferOptions) ([]byte, error) {
|
||||
opts := normalizeSFTPTransferOptions(options)
|
||||
opts, err := normalizeSFTPTransferOptions(options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var data []byte
|
||||
err := star.runSFTPWithRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
|
||||
err = star.runSFTPWithRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
|
||||
out, runErr := transferInByteContext(ctx, client, remotePath, opts)
|
||||
if runErr != nil {
|
||||
return runErr
|
||||
@@ -617,10 +765,16 @@ func SftpTransferInByte(remotePath string, sftpClient *sftp.Client) ([]byte, err
|
||||
}
|
||||
|
||||
func SftpTransferInByteWithContext(ctx context.Context, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) ([]byte, error) {
|
||||
opts := normalizeSFTPTransferOptions(options)
|
||||
opts, err := normalizeSFTPTransferOptions(options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rejectExternalSFTPClientOptions(opts.Client); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var data []byte
|
||||
err := executeSFTPRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(attempt int) error {
|
||||
err = executeSFTPRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(attempt int) error {
|
||||
out, runErr := transferInByteContext(ctx, sftpClient, remotePath, opts)
|
||||
if runErr != nil {
|
||||
return runErr
|
||||
@@ -639,7 +793,10 @@ func (star *StarSSH) SftpTransferIn(src, dst string) error {
|
||||
}
|
||||
|
||||
func (star *StarSSH) SftpTransferInContext(ctx context.Context, src, dst string, options *SFTPTransferOptions) error {
|
||||
opts := normalizeSFTPTransferOptions(options)
|
||||
opts, err := normalizeSFTPTransferOptions(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return star.runSFTPWithRetry(ctx, "sftp_get_file", dst, src, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
|
||||
return transferInContext(ctx, client, src, dst, opts)
|
||||
})
|
||||
@@ -650,7 +807,13 @@ func SftpTransferIn(src, dst string, sftpClient *sftp.Client) error {
|
||||
}
|
||||
|
||||
func SftpTransferInWithContext(ctx context.Context, src, dst string, sftpClient *sftp.Client, options *SFTPTransferOptions) error {
|
||||
opts := normalizeSFTPTransferOptions(options)
|
||||
opts, err := normalizeSFTPTransferOptions(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := rejectExternalSFTPClientOptions(opts.Client); err != nil {
|
||||
return err
|
||||
}
|
||||
return executeSFTPRetry(ctx, "sftp_get_file", dst, src, opts, func(attempt int) error {
|
||||
return transferInContext(ctx, sftpClient, src, dst, opts)
|
||||
})
|
||||
@@ -712,7 +875,7 @@ func transferOutContext(ctx context.Context, sftpClient *sftp.Client, localFileP
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := sftpCopyWithProgressFunc(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil {
|
||||
if _, err := copyUploadWithProgressContext(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress, opts); err != nil {
|
||||
_ = dstFile.Close()
|
||||
return err
|
||||
}
|
||||
@@ -792,7 +955,7 @@ func transferOutByteContext(ctx context.Context, sftpClient *sftp.Client, localD
|
||||
}
|
||||
|
||||
reader := bytes.NewReader(localData)
|
||||
if _, err := sftpCopyWithProgressFunc(ctx, dstFile, reader, opts.BufferSize, int64(len(localData)), opts.Progress); err != nil {
|
||||
if _, err := copyUploadWithProgressContext(ctx, dstFile, reader, opts.BufferSize, int64(len(localData)), opts.Progress, opts); err != nil {
|
||||
_ = dstFile.Close()
|
||||
return err
|
||||
}
|
||||
@@ -877,7 +1040,7 @@ func transferInContext(ctx context.Context, sftpClient *sftp.Client, src, dst st
|
||||
}()
|
||||
}
|
||||
|
||||
if _, err := sftpCopyWithProgressFunc(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil {
|
||||
if _, err := copyDownloadWithProgressContext(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress, opts); err != nil {
|
||||
_ = dstFile.Close()
|
||||
return err
|
||||
}
|
||||
@@ -948,7 +1111,7 @@ func transferInByteContext(ctx context.Context, sftpClient *sftp.Client, remoteP
|
||||
}
|
||||
|
||||
var out bytes.Buffer
|
||||
if _, err := sftpCopyWithProgressFunc(ctx, &out, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil {
|
||||
if _, err := copyDownloadWithProgressContext(ctx, &out, srcFile, opts.BufferSize, stat.Size(), opts.Progress, opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1149,6 +1312,151 @@ func copyWithProgressContext(ctx context.Context, dst io.Writer, src io.Reader,
|
||||
return copied, nil
|
||||
}
|
||||
|
||||
func copyUploadWithProgressContext(
|
||||
ctx context.Context,
|
||||
dst io.Writer,
|
||||
src io.Reader,
|
||||
bufSize int,
|
||||
total int64,
|
||||
progress func(float64),
|
||||
opts resolvedSFTPTransferOptions,
|
||||
) (int64, error) {
|
||||
if !derefSFTPBool(opts.Client.ConcurrentWrites, false) {
|
||||
return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress)
|
||||
}
|
||||
|
||||
readerFrom, ok := dst.(sftpConcurrentReaderFrom)
|
||||
if !ok {
|
||||
return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress)
|
||||
}
|
||||
|
||||
if err := ensureContext(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if progress != nil && total > 0 {
|
||||
progress(0)
|
||||
}
|
||||
|
||||
wrappedSrc := &sftpUploadProgressReader{
|
||||
ctx: ctx,
|
||||
reader: src,
|
||||
total: total,
|
||||
progress: progress,
|
||||
}
|
||||
written, err := readerFrom.ReadFromWithConcurrency(wrappedSrc, opts.Client.MaxConcurrentRequestsPerFile)
|
||||
if err != nil {
|
||||
return written, err
|
||||
}
|
||||
if err := ensureContext(ctx); err != nil {
|
||||
return written, err
|
||||
}
|
||||
reportProgress(progress, written, total)
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func copyDownloadWithProgressContext(
|
||||
ctx context.Context,
|
||||
dst io.Writer,
|
||||
src io.Reader,
|
||||
bufSize int,
|
||||
total int64,
|
||||
progress func(float64),
|
||||
opts resolvedSFTPTransferOptions,
|
||||
) (int64, error) {
|
||||
if !derefSFTPBool(opts.Client.ConcurrentReads, false) {
|
||||
return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress)
|
||||
}
|
||||
|
||||
writerTo, ok := src.(io.WriterTo)
|
||||
if !ok {
|
||||
return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress)
|
||||
}
|
||||
|
||||
if err := ensureContext(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if progress != nil && total > 0 {
|
||||
progress(0)
|
||||
}
|
||||
|
||||
wrappedDst := &sftpDownloadProgressWriter{
|
||||
ctx: ctx,
|
||||
writer: dst,
|
||||
total: total,
|
||||
progress: progress,
|
||||
}
|
||||
written, err := writerTo.WriteTo(wrappedDst)
|
||||
if err != nil {
|
||||
return written, err
|
||||
}
|
||||
if err := ensureContext(ctx); err != nil {
|
||||
return written, err
|
||||
}
|
||||
reportProgress(progress, written, total)
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func (r *sftpUploadProgressReader) Read(p []byte) (int, error) {
|
||||
if err := ensureContext(r.ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if err := ensureContext(r.ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
n, err := r.reader.Read(p)
|
||||
if n > 0 {
|
||||
r.copied += int64(n)
|
||||
reportQueuedTransferProgress(r.progress, r.copied, r.total)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *sftpDownloadProgressWriter) Write(p []byte) (int, error) {
|
||||
if err := ensureContext(w.ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if err := ensureContext(w.ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
n, err := w.writer.Write(p)
|
||||
if n > 0 {
|
||||
w.copied += int64(n)
|
||||
reportQueuedTransferProgress(w.progress, w.copied, w.total)
|
||||
}
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if n != len(p) {
|
||||
return n, io.ErrShortWrite
|
||||
}
|
||||
if err := ensureContext(w.ctx); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func reportQueuedTransferProgress(progress func(float64), copied int64, total int64) {
|
||||
if progress == nil || total <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
percent := float64(copied) / float64(total) * 100
|
||||
if percent >= 100 {
|
||||
percent = 99
|
||||
}
|
||||
progress(percent)
|
||||
}
|
||||
|
||||
func reportProgress(progress func(float64), copied int64, total int64) {
|
||||
if progress == nil {
|
||||
return
|
||||
|
||||
+666
-14
@@ -1,6 +1,7 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -11,16 +12,521 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestNormalizeSFTPTransferOptionsDefaultsAtomicDownload(t *testing.T) {
|
||||
opts := normalizeSFTPTransferOptions(nil)
|
||||
opts := mustNormalizeSFTPTransferOptions(t, nil)
|
||||
if !opts.AtomicUpload {
|
||||
t.Fatal("expected atomic upload to default to enabled")
|
||||
}
|
||||
if !opts.AtomicDownload {
|
||||
t.Fatal("expected atomic download to default to enabled")
|
||||
}
|
||||
if opts.Client.ConcurrentWrites != nil {
|
||||
t.Fatal("expected concurrent writes to default to unset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestThroughputSFTPTransferOptionsEnablesExplicitConcurrentWrites(t *testing.T) {
|
||||
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||
if !opts.AtomicUpload {
|
||||
t.Fatal("expected throughput preset to keep atomic upload enabled")
|
||||
}
|
||||
if opts.Client.ConcurrentReads == nil || !*opts.Client.ConcurrentReads {
|
||||
t.Fatal("expected throughput preset to enable concurrent reads")
|
||||
}
|
||||
if opts.Client.ConcurrentWrites == nil || !*opts.Client.ConcurrentWrites {
|
||||
t.Fatal("expected throughput preset to enable concurrent writes")
|
||||
}
|
||||
if opts.Client.MaxConcurrentRequestsPerFile != 32 {
|
||||
t.Fatalf("unexpected max concurrent requests: got %d want 32", opts.Client.MaxConcurrentRequestsPerFile)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSFTPUploadOptionsRejectsConcurrentWritesWithoutAtomicUpload(t *testing.T) {
|
||||
options := ThroughputSFTPTransferOptions()
|
||||
options.AtomicUpload = SFTPBool(false)
|
||||
opts := mustNormalizeSFTPTransferOptions(t, &options)
|
||||
|
||||
err := validateSFTPUploadOptions(opts)
|
||||
if err == nil || !strings.Contains(err.Error(), "atomic upload") {
|
||||
t.Fatalf("expected atomic upload rejection, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeSFTPTransferOptionsRejectsNegativeClientValues(t *testing.T) {
|
||||
_, err := normalizeSFTPTransferOptions(&SFTPTransferOptions{
|
||||
Client: SFTPClientOptions{MaxPacketSize: -1},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "max packet") {
|
||||
t.Fatalf("expected max packet rejection, got %v", err)
|
||||
}
|
||||
|
||||
_, err = normalizeSFTPTransferOptions(&SFTPTransferOptions{
|
||||
Client: SFTPClientOptions{MaxConcurrentRequestsPerFile: -1},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "max concurrent") {
|
||||
t.Fatalf("expected max concurrent rejection, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSFTPClientOptionsRejectsUnsupportedCheckedPacketSize(t *testing.T) {
|
||||
options := buildSFTPClientOptions(resolvedSFTPClientOptions{MaxPacketSize: 32769})
|
||||
client := &sftp.Client{}
|
||||
|
||||
err := options[0](client)
|
||||
if err == nil || !strings.Contains(err.Error(), "32KB") {
|
||||
t.Fatalf("expected checked packet size rejection, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSftpTransferOutWithContextRejectsClientOptionsForExternalClient(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
localPath := filepath.Join(root, "local.txt")
|
||||
remotePath := filepath.Join(root, "remote.txt")
|
||||
if err := os.WriteFile(localPath, []byte("payload"), 0o644); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
|
||||
err := SftpTransferOutWithContext(context.Background(), localPath, remotePath, client, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||
if err == nil || !strings.Contains(err.Error(), "StarSSH-managed") {
|
||||
t.Fatalf("expected external client option rejection, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSftpTransferOutContextPassesClientOptionsToManagedClient(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
localPath := filepath.Join(root, "local.txt")
|
||||
remotePath := filepath.Join(root, "remote.txt")
|
||||
if err := os.WriteFile(localPath, []byte("payload"), 0o644); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
|
||||
client := newSFTPTestClient(t)
|
||||
options := ThroughputSFTPTransferOptions()
|
||||
var captured []sftp.ClientOption
|
||||
oldNewClient := sftpNewClientFunc
|
||||
sftpNewClientFunc = func(_ *ssh.Client, opts ...sftp.ClientOption) (*sftp.Client, error) {
|
||||
captured = append([]sftp.ClientOption(nil), opts...)
|
||||
return client, nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
sftpNewClientFunc = oldNewClient
|
||||
})
|
||||
|
||||
star := &StarSSH{Client: &ssh.Client{}}
|
||||
if err := star.SftpTransferOutContext(context.Background(), localPath, remotePath, &options); err != nil {
|
||||
t.Fatalf("transfer out: %v", err)
|
||||
}
|
||||
|
||||
if len(captured) == 0 {
|
||||
t.Fatal("expected managed SFTP client options to be passed to factory")
|
||||
}
|
||||
if got, want := len(captured), len(buildSFTPClientOptions(mustNormalizeSFTPTransferOptions(t, &options).Client)); got != want {
|
||||
t.Fatalf("unexpected client option count: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSftpTransferInWithContextRejectsClientOptionsForExternalClient(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
srcPath := filepath.Join(root, "remote.txt")
|
||||
dstPath := filepath.Join(root, "local.txt")
|
||||
if err := os.WriteFile(srcPath, []byte("payload"), 0o644); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
|
||||
err := SftpTransferInWithContext(context.Background(), srcPath, dstPath, client, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||
if err == nil || !strings.Contains(err.Error(), "StarSSH-managed") {
|
||||
t.Fatalf("expected external client option rejection, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSftpTransferInContextPassesClientOptionsToManagedClient(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
srcPath := filepath.Join(root, "remote.txt")
|
||||
dstPath := filepath.Join(root, "local.txt")
|
||||
if err := os.WriteFile(srcPath, []byte("payload"), 0o644); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
|
||||
client := newSFTPTestClient(t)
|
||||
options := ThroughputSFTPTransferOptions()
|
||||
var captured []sftp.ClientOption
|
||||
oldNewClient := sftpNewClientFunc
|
||||
sftpNewClientFunc = func(_ *ssh.Client, opts ...sftp.ClientOption) (*sftp.Client, error) {
|
||||
captured = append([]sftp.ClientOption(nil), opts...)
|
||||
return client, nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
sftpNewClientFunc = oldNewClient
|
||||
})
|
||||
|
||||
star := &StarSSH{Client: &ssh.Client{}}
|
||||
if err := star.SftpTransferInContext(context.Background(), srcPath, dstPath, &options); err != nil {
|
||||
t.Fatalf("transfer in: %v", err)
|
||||
}
|
||||
|
||||
if len(captured) == 0 {
|
||||
t.Fatal("expected managed SFTP client options to be passed to factory")
|
||||
}
|
||||
if got, want := len(captured), len(buildSFTPClientOptions(mustNormalizeSFTPTransferOptions(t, &options).Client)); got != want {
|
||||
t.Fatalf("unexpected client option count: got %d want %d", got, want)
|
||||
}
|
||||
assertFileContent(t, dstPath, "payload")
|
||||
}
|
||||
|
||||
func TestBuildSFTPClientOptionsCanTransfer(t *testing.T) {
|
||||
client := newSFTPTestClientWithOptions(t, buildSFTPClientOptions(resolvedSFTPClientOptions{
|
||||
MaxPacketSize: 4096,
|
||||
MaxConcurrentRequestsPerFile: 4,
|
||||
ConcurrentWrites: SFTPBool(true),
|
||||
}))
|
||||
root := t.TempDir()
|
||||
localPath := filepath.Join(root, "local.txt")
|
||||
remotePath := filepath.Join(root, "remote.txt")
|
||||
|
||||
if err := os.WriteFile(localPath, []byte(strings.Repeat("payload-", 2048)), 0o644); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
|
||||
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||
if err := transferOutContext(context.Background(), client, localPath, remotePath, opts); err != nil {
|
||||
t.Fatalf("transfer out with client options: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(remotePath)
|
||||
if err != nil {
|
||||
t.Fatalf("read remote file: %v", err)
|
||||
}
|
||||
if got := string(data); got != strings.Repeat("payload-", 2048) {
|
||||
t.Fatalf("unexpected remote payload length: got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSFTPClientOptionsCanDownload(t *testing.T) {
|
||||
client := newSFTPTestClientWithOptions(t, buildSFTPClientOptions(resolvedSFTPClientOptions{
|
||||
MaxPacketSize: 4096,
|
||||
MaxConcurrentRequestsPerFile: 4,
|
||||
ConcurrentReads: SFTPBool(true),
|
||||
}))
|
||||
root := t.TempDir()
|
||||
srcPath := filepath.Join(root, "remote.txt")
|
||||
dstPath := filepath.Join(root, "local.txt")
|
||||
|
||||
payload := strings.Repeat("payload-", 2048)
|
||||
if err := os.WriteFile(srcPath, []byte(payload), 0o644); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
|
||||
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||
if err := transferInContext(context.Background(), client, srcPath, dstPath, opts); err != nil {
|
||||
t.Fatalf("transfer in with client options: %v", err)
|
||||
}
|
||||
|
||||
assertFileContent(t, dstPath, payload)
|
||||
}
|
||||
|
||||
func TestCopyUploadWithProgressUsesConcurrentReadFromWhenEnabled(t *testing.T) {
|
||||
dst := &spyConcurrentReadFrom{}
|
||||
src := strings.NewReader("payload")
|
||||
var progress []float64
|
||||
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||
|
||||
written, err := copyUploadWithProgressContext(context.Background(), dst, src, 3, int64(len("payload")), func(value float64) {
|
||||
progress = append(progress, value)
|
||||
}, opts)
|
||||
if err != nil {
|
||||
t.Fatalf("copy upload: %v", err)
|
||||
}
|
||||
if written != int64(len("payload")) {
|
||||
t.Fatalf("unexpected written bytes: got %d", written)
|
||||
}
|
||||
if !dst.usedReadFrom {
|
||||
t.Fatal("expected concurrent ReadFrom path to be used")
|
||||
}
|
||||
if dst.concurrency != opts.Client.MaxConcurrentRequestsPerFile {
|
||||
t.Fatalf("unexpected concurrency: got %d want %d", dst.concurrency, opts.Client.MaxConcurrentRequestsPerFile)
|
||||
}
|
||||
if got := dst.buf.String(); got != "payload" {
|
||||
t.Fatalf("unexpected copied payload: got %q", got)
|
||||
}
|
||||
if len(progress) == 0 || progress[len(progress)-1] != 100 {
|
||||
t.Fatalf("expected final progress 100, got %v", progress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyUploadWithProgressReportsDuringConcurrentReadFrom(t *testing.T) {
|
||||
dst := &spyConcurrentReadFrom{}
|
||||
src := &chunkedReader{
|
||||
reader: strings.NewReader("payload"),
|
||||
chunkSize: 2,
|
||||
}
|
||||
var progress []float64
|
||||
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||
|
||||
written, err := copyUploadWithProgressContext(context.Background(), dst, src, 3, int64(len("payload")), func(value float64) {
|
||||
progress = append(progress, value)
|
||||
}, opts)
|
||||
if err != nil {
|
||||
t.Fatalf("copy upload: %v", err)
|
||||
}
|
||||
if written != int64(len("payload")) {
|
||||
t.Fatalf("unexpected written bytes: got %d", written)
|
||||
}
|
||||
|
||||
var sawIntermediate bool
|
||||
for _, value := range progress {
|
||||
if value > 0 && value < 100 {
|
||||
sawIntermediate = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !sawIntermediate {
|
||||
t.Fatalf("expected intermediate progress during concurrent readfrom, got %v", progress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyUploadWithProgressCancelsDuringConcurrentReadFrom(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
dst := &spyConcurrentReadFrom{}
|
||||
src := &cancelAfterReadReader{
|
||||
reader: strings.NewReader("payload"),
|
||||
chunkSize: 3,
|
||||
cancel: cancel,
|
||||
}
|
||||
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||
|
||||
written, err := copyUploadWithProgressContext(ctx, dst, src, 3, int64(len("payload")), nil, opts)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context cancellation, got written=%d err=%v", written, err)
|
||||
}
|
||||
if !dst.usedReadFrom {
|
||||
t.Fatal("expected concurrent ReadFrom path to be used")
|
||||
}
|
||||
if written <= 0 || written >= int64(len("payload")) {
|
||||
t.Fatalf("expected partial write before cancellation, got %d", written)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyUploadWithProgressDoesNotReportDoneBeforeConcurrentWriteError(t *testing.T) {
|
||||
copyErr := errors.New("write status failed")
|
||||
dst := &spyConcurrentReadFrom{errAfterRead: copyErr}
|
||||
var progress []float64
|
||||
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||
|
||||
written, err := copyUploadWithProgressContext(context.Background(), dst, strings.NewReader("payload"), 3, int64(len("payload")), func(value float64) {
|
||||
progress = append(progress, value)
|
||||
}, opts)
|
||||
if !errors.Is(err, copyErr) {
|
||||
t.Fatalf("expected concurrent write error, got written=%d err=%v", written, err)
|
||||
}
|
||||
if written != int64(len("payload")) {
|
||||
t.Fatalf("expected read byte count before write error, got %d", written)
|
||||
}
|
||||
for _, value := range progress {
|
||||
if value >= 100 {
|
||||
t.Fatalf("progress reported completion before write success: %v", progress)
|
||||
}
|
||||
}
|
||||
if len(progress) == 0 {
|
||||
t.Fatal("expected queued progress before write error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyUploadWithProgressReturnsConcurrentReadFromError(t *testing.T) {
|
||||
copyErr := errors.New("readfrom failed")
|
||||
dst := &spyConcurrentReadFrom{err: copyErr}
|
||||
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||
|
||||
_, err := copyUploadWithProgressContext(context.Background(), dst, strings.NewReader("payload"), 3, int64(len("payload")), nil, opts)
|
||||
if !errors.Is(err, copyErr) {
|
||||
t.Fatalf("expected concurrent readfrom error, got %v", err)
|
||||
}
|
||||
if !dst.usedReadFrom {
|
||||
t.Fatal("expected concurrent ReadFrom path to be used")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyUploadWithProgressDefaultsToSequentialCopy(t *testing.T) {
|
||||
oldCopy := sftpCopyWithProgressFunc
|
||||
called := false
|
||||
sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) {
|
||||
called = true
|
||||
return oldCopy(ctx, dst, src, bufSize, total, progress)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
sftpCopyWithProgressFunc = oldCopy
|
||||
})
|
||||
|
||||
var dst bytes.Buffer
|
||||
written, err := copyUploadWithProgressContext(context.Background(), &dst, strings.NewReader("payload"), 3, int64(len("payload")), nil, mustNormalizeSFTPTransferOptions(t, nil))
|
||||
if err != nil {
|
||||
t.Fatalf("copy upload: %v", err)
|
||||
}
|
||||
if written != int64(len("payload")) {
|
||||
t.Fatalf("unexpected written bytes: got %d", written)
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("expected default path to use existing copy helper")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyDownloadWithProgressUsesConcurrentWriteToWhenEnabled(t *testing.T) {
|
||||
src := &spyConcurrentWriteTo{payload: []byte("payload")}
|
||||
var dst bytes.Buffer
|
||||
var progress []float64
|
||||
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||
|
||||
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) {
|
||||
progress = append(progress, value)
|
||||
}, opts)
|
||||
if err != nil {
|
||||
t.Fatalf("copy download: %v", err)
|
||||
}
|
||||
if written != int64(len("payload")) {
|
||||
t.Fatalf("unexpected written bytes: got %d", written)
|
||||
}
|
||||
if !src.usedWriteTo {
|
||||
t.Fatal("expected concurrent WriteTo path to be used")
|
||||
}
|
||||
if got := dst.String(); got != "payload" {
|
||||
t.Fatalf("unexpected copied payload: got %q", got)
|
||||
}
|
||||
if len(progress) == 0 || progress[len(progress)-1] != 100 {
|
||||
t.Fatalf("expected final progress 100, got %v", progress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyDownloadWithProgressReportsDuringConcurrentWriteTo(t *testing.T) {
|
||||
src := &spyConcurrentWriteTo{
|
||||
payload: []byte("payload"),
|
||||
chunkSize: 2,
|
||||
}
|
||||
var dst bytes.Buffer
|
||||
var progress []float64
|
||||
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||
|
||||
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) {
|
||||
progress = append(progress, value)
|
||||
}, opts)
|
||||
if err != nil {
|
||||
t.Fatalf("copy download: %v", err)
|
||||
}
|
||||
if written != int64(len("payload")) {
|
||||
t.Fatalf("unexpected written bytes: got %d", written)
|
||||
}
|
||||
|
||||
var sawIntermediate bool
|
||||
for _, value := range progress {
|
||||
if value > 0 && value < 100 {
|
||||
sawIntermediate = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !sawIntermediate {
|
||||
t.Fatalf("expected intermediate progress during concurrent writeto, got %v", progress)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyDownloadWithProgressCancelsDuringConcurrentWriteTo(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
src := &spyConcurrentWriteTo{
|
||||
payload: []byte("payload"),
|
||||
chunkSize: 3,
|
||||
cancelAfterWrite: cancel,
|
||||
}
|
||||
var dst bytes.Buffer
|
||||
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||
|
||||
written, err := copyDownloadWithProgressContext(ctx, &dst, src, 3, int64(len("payload")), nil, opts)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context cancellation, got written=%d err=%v", written, err)
|
||||
}
|
||||
if !src.usedWriteTo {
|
||||
t.Fatal("expected concurrent WriteTo path to be used")
|
||||
}
|
||||
if written <= 0 || written >= int64(len("payload")) {
|
||||
t.Fatalf("expected partial write before cancellation, got %d", written)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyDownloadWithProgressDoesNotReportDoneBeforeConcurrentReadError(t *testing.T) {
|
||||
copyErr := errors.New("read status failed")
|
||||
src := &spyConcurrentWriteTo{
|
||||
payload: []byte("payload"),
|
||||
errAfterWrite: copyErr,
|
||||
}
|
||||
var dst bytes.Buffer
|
||||
var progress []float64
|
||||
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||
|
||||
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) {
|
||||
progress = append(progress, value)
|
||||
}, opts)
|
||||
if !errors.Is(err, copyErr) {
|
||||
t.Fatalf("expected concurrent read error, got written=%d err=%v", written, err)
|
||||
}
|
||||
if written != int64(len("payload")) {
|
||||
t.Fatalf("expected local byte count before read error, got %d", written)
|
||||
}
|
||||
for _, value := range progress {
|
||||
if value >= 100 {
|
||||
t.Fatalf("progress reported completion before download success: %v", progress)
|
||||
}
|
||||
}
|
||||
if len(progress) == 0 {
|
||||
t.Fatal("expected queued progress before read error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyDownloadWithProgressReturnsConcurrentWriteToError(t *testing.T) {
|
||||
copyErr := errors.New("writeto failed")
|
||||
src := &spyConcurrentWriteTo{err: copyErr}
|
||||
var dst bytes.Buffer
|
||||
opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions()))
|
||||
|
||||
_, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), nil, opts)
|
||||
if !errors.Is(err, copyErr) {
|
||||
t.Fatalf("expected concurrent writeto error, got %v", err)
|
||||
}
|
||||
if !src.usedWriteTo {
|
||||
t.Fatal("expected concurrent WriteTo path to be used")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyDownloadWithProgressDefaultsToSequentialCopy(t *testing.T) {
|
||||
oldCopy := sftpCopyWithProgressFunc
|
||||
called := false
|
||||
sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) {
|
||||
called = true
|
||||
return oldCopy(ctx, dst, src, bufSize, total, progress)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
sftpCopyWithProgressFunc = oldCopy
|
||||
})
|
||||
|
||||
var dst bytes.Buffer
|
||||
src := &chunkedReader{
|
||||
reader: strings.NewReader("payload"),
|
||||
chunkSize: 2,
|
||||
}
|
||||
written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), nil, mustNormalizeSFTPTransferOptions(t, nil))
|
||||
if err != nil {
|
||||
t.Fatalf("copy download: %v", err)
|
||||
}
|
||||
if written != int64(len("payload")) {
|
||||
t.Fatalf("unexpected written bytes: got %d", written)
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("expected default path to use existing copy helper")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) {
|
||||
@@ -47,7 +553,7 @@ func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) {
|
||||
sftpVerifyRemoteSizeFunc = oldVerifyRemoteSize
|
||||
})
|
||||
|
||||
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
|
||||
err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil))
|
||||
if !errors.Is(err, verifyErr) {
|
||||
t.Fatalf("expected verify failure, got %v", err)
|
||||
}
|
||||
@@ -82,7 +588,7 @@ func TestTransferOutContextRejectsRemoteSymlinkTarget(t *testing.T) {
|
||||
t.Skipf("symlink unsupported: %v", err)
|
||||
}
|
||||
|
||||
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
|
||||
err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil))
|
||||
if err == nil || !strings.Contains(err.Error(), "symlink") {
|
||||
t.Fatalf("expected symlink rejection, got %v", err)
|
||||
}
|
||||
@@ -118,7 +624,7 @@ func TestTransferOutContextRejectsRemoteDirectoryTarget(t *testing.T) {
|
||||
t.Fatalf("mkdir remote target: %v", err)
|
||||
}
|
||||
|
||||
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
|
||||
err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil))
|
||||
if err == nil || !strings.Contains(err.Error(), "directory") {
|
||||
t.Fatalf("expected directory rejection, got %v", err)
|
||||
}
|
||||
@@ -149,7 +655,7 @@ func TestTransferOutContextPreservesRemoteModeOnOverwrite(t *testing.T) {
|
||||
t.Fatalf("chmod remote file: %v", err)
|
||||
}
|
||||
|
||||
if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
|
||||
if err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
|
||||
t.Fatalf("transfer out: %v", err)
|
||||
}
|
||||
|
||||
@@ -170,7 +676,7 @@ func TestTransferOutContextAppliesLocalModeForNewRemoteFile(t *testing.T) {
|
||||
t.Fatalf("chmod local file: %v", err)
|
||||
}
|
||||
|
||||
if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
|
||||
if err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
|
||||
t.Fatalf("transfer out: %v", err)
|
||||
}
|
||||
|
||||
@@ -190,7 +696,7 @@ func TestTransferOutByteContextPreservesRemoteModeOnOverwrite(t *testing.T) {
|
||||
t.Fatalf("chmod remote file: %v", err)
|
||||
}
|
||||
|
||||
if err := transferOutByteContext(context.Background(), client, []byte("byte payload"), remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
|
||||
if err := transferOutByteContext(context.Background(), client, []byte("byte payload"), remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
|
||||
t.Fatalf("transfer out bytes: %v", err)
|
||||
}
|
||||
|
||||
@@ -222,7 +728,7 @@ func TestTransferInContextVerifyFailurePreservesLocalTarget(t *testing.T) {
|
||||
sftpVerifyLocalSizeFunc = oldVerifyLocalSize
|
||||
})
|
||||
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
|
||||
if !errors.Is(err, verifyErr) {
|
||||
t.Fatalf("expected verify failure, got %v", err)
|
||||
}
|
||||
@@ -257,7 +763,7 @@ func TestTransferInContextRejectsLocalSymlinkTarget(t *testing.T) {
|
||||
t.Skipf("symlink unsupported: %v", err)
|
||||
}
|
||||
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
|
||||
if err == nil || !strings.Contains(err.Error(), "symlink") {
|
||||
t.Fatalf("expected symlink rejection, got %v", err)
|
||||
}
|
||||
@@ -286,7 +792,7 @@ func TestTransferInContextRejectsLocalDirectoryTarget(t *testing.T) {
|
||||
t.Fatalf("mkdir local target: %v", err)
|
||||
}
|
||||
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
|
||||
if err == nil || !strings.Contains(err.Error(), "directory") {
|
||||
t.Fatalf("expected directory rejection, got %v", err)
|
||||
}
|
||||
@@ -317,7 +823,7 @@ func TestTransferInContextPreservesLocalModeOnOverwrite(t *testing.T) {
|
||||
t.Fatalf("chmod local file: %v", err)
|
||||
}
|
||||
|
||||
if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil {
|
||||
if err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
|
||||
t.Fatalf("transfer in: %v", err)
|
||||
}
|
||||
|
||||
@@ -338,7 +844,7 @@ func TestTransferInContextAppliesRemoteModeForNewLocalFile(t *testing.T) {
|
||||
t.Fatalf("chmod remote file: %v", err)
|
||||
}
|
||||
|
||||
if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil {
|
||||
if err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil {
|
||||
t.Fatalf("transfer in: %v", err)
|
||||
}
|
||||
|
||||
@@ -387,7 +893,7 @@ func TestTransferInContextCopyFailurePreservesLocalTarget(t *testing.T) {
|
||||
sftpCopyWithProgressFunc = oldCopy
|
||||
})
|
||||
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil))
|
||||
if !errors.Is(err, copyErr) {
|
||||
t.Fatalf("expected copy failure, got %v", err)
|
||||
}
|
||||
@@ -405,7 +911,153 @@ func TestTransferInContextCopyFailurePreservesLocalTarget(t *testing.T) {
|
||||
assertNoTransferTemps(t, dstPath)
|
||||
}
|
||||
|
||||
func mustNormalizeSFTPTransferOptions(t *testing.T, options *SFTPTransferOptions) resolvedSFTPTransferOptions {
|
||||
t.Helper()
|
||||
|
||||
opts, err := normalizeSFTPTransferOptions(options)
|
||||
if err != nil {
|
||||
t.Fatalf("normalize sftp transfer options: %v", err)
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func SFTPOptionsPtr(options SFTPTransferOptions) *SFTPTransferOptions {
|
||||
return &options
|
||||
}
|
||||
|
||||
type spyConcurrentReadFrom struct {
|
||||
buf bytes.Buffer
|
||||
concurrency int
|
||||
usedReadFrom bool
|
||||
err error
|
||||
errAfterRead error
|
||||
}
|
||||
|
||||
func (w *spyConcurrentReadFrom) Write(p []byte) (int, error) {
|
||||
return w.buf.Write(p)
|
||||
}
|
||||
|
||||
func (w *spyConcurrentReadFrom) ReadFromWithConcurrency(r io.Reader, concurrency int) (int64, error) {
|
||||
w.usedReadFrom = true
|
||||
w.concurrency = concurrency
|
||||
if w.err != nil {
|
||||
return 0, w.err
|
||||
}
|
||||
written, err := w.buf.ReadFrom(r)
|
||||
if err != nil {
|
||||
return written, err
|
||||
}
|
||||
if w.errAfterRead != nil {
|
||||
return written, w.errAfterRead
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
type spyConcurrentWriteTo struct {
|
||||
payload []byte
|
||||
chunkSize int
|
||||
usedWriteTo bool
|
||||
err error
|
||||
errAfterWrite error
|
||||
cancelAfterWrite context.CancelFunc
|
||||
}
|
||||
|
||||
func (r *spyConcurrentWriteTo) Read(p []byte) (int, error) {
|
||||
if len(r.payload) == 0 {
|
||||
if r.err != nil {
|
||||
return 0, r.err
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(p, r.payload)
|
||||
r.payload = r.payload[n:]
|
||||
if len(r.payload) == 0 {
|
||||
return n, io.EOF
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *spyConcurrentWriteTo) WriteTo(w io.Writer) (int64, error) {
|
||||
r.usedWriteTo = true
|
||||
if r.err != nil {
|
||||
return 0, r.err
|
||||
}
|
||||
|
||||
payload := r.payload
|
||||
if len(payload) == 0 {
|
||||
if r.errAfterWrite != nil {
|
||||
return 0, r.errAfterWrite
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
chunkSize := r.chunkSize
|
||||
if chunkSize <= 0 || chunkSize > len(payload) {
|
||||
chunkSize = len(payload)
|
||||
}
|
||||
|
||||
var written int64
|
||||
for len(payload) > 0 {
|
||||
size := chunkSize
|
||||
if size > len(payload) {
|
||||
size = len(payload)
|
||||
}
|
||||
n, err := w.Write(payload[:size])
|
||||
written += int64(n)
|
||||
if r.cancelAfterWrite != nil && n > 0 {
|
||||
r.cancelAfterWrite()
|
||||
r.cancelAfterWrite = nil
|
||||
}
|
||||
if err != nil {
|
||||
return written, err
|
||||
}
|
||||
if n != size {
|
||||
return written, io.ErrShortWrite
|
||||
}
|
||||
payload = payload[size:]
|
||||
}
|
||||
|
||||
if r.errAfterWrite != nil {
|
||||
return written, r.errAfterWrite
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
type chunkedReader struct {
|
||||
reader io.Reader
|
||||
chunkSize int
|
||||
}
|
||||
|
||||
func (r *chunkedReader) Read(p []byte) (int, error) {
|
||||
if r.chunkSize > 0 && len(p) > r.chunkSize {
|
||||
p = p[:r.chunkSize]
|
||||
}
|
||||
return r.reader.Read(p)
|
||||
}
|
||||
|
||||
type cancelAfterReadReader struct {
|
||||
reader io.Reader
|
||||
chunkSize int
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (r *cancelAfterReadReader) Read(p []byte) (int, error) {
|
||||
if r.chunkSize > 0 && len(p) > r.chunkSize {
|
||||
p = p[:r.chunkSize]
|
||||
}
|
||||
n, err := r.reader.Read(p)
|
||||
if n > 0 && r.cancel != nil {
|
||||
r.cancel()
|
||||
r.cancel = nil
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func newSFTPTestClient(t *testing.T) *sftp.Client {
|
||||
return newSFTPTestClientWithOptions(t, nil)
|
||||
}
|
||||
|
||||
func newSFTPTestClientWithOptions(t *testing.T, options []sftp.ClientOption) *sftp.Client {
|
||||
t.Helper()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
@@ -419,7 +1071,7 @@ func newSFTPTestClient(t *testing.T) *sftp.Client {
|
||||
serveErrCh <- server.Serve()
|
||||
}()
|
||||
|
||||
client, err := sftp.NewClientPipe(clientConn, clientConn)
|
||||
client, err := sftp.NewClientPipe(clientConn, clientConn, options...)
|
||||
if err != nil {
|
||||
_ = server.Close()
|
||||
t.Fatalf("create sftp client: %v", err)
|
||||
|
||||
Reference in New Issue
Block a user