diff --git a/sftp.go b/sftp.go index 3bd40a3..54c575b 100644 --- a/sftp.go +++ b/sftp.go @@ -13,6 +13,7 @@ import ( "path" "path/filepath" "strings" + "sync" "time" "github.com/pkg/sftp" @@ -36,6 +37,19 @@ type SFTPTransferOptions struct { VerifySize *bool VerifyChecksum *bool TempSuffix string + Client SFTPClientOptions +} + +// SFTPClientOptions controls the underlying SFTP protocol client. +// +// These options only apply when StarSSH creates the SFTP client internally. +// They are intentionally separate from BufferSize, which only controls the +// local copy buffer used by transfer progress reporting. +type SFTPClientOptions struct { + MaxPacketSize int + MaxConcurrentRequestsPerFile int + ConcurrentReads *bool + ConcurrentWrites *bool } type resolvedSFTPTransferOptions struct { @@ -48,6 +62,38 @@ type resolvedSFTPTransferOptions struct { VerifySize bool VerifyChecksum bool TempSuffix string + Client resolvedSFTPClientOptions +} + +type resolvedSFTPClientOptions struct { + MaxPacketSize int + MaxConcurrentRequestsPerFile int + ConcurrentReads *bool + ConcurrentWrites *bool +} + +type sftpConcurrentReaderFrom interface { + ReadFromWithConcurrency(io.Reader, int) (int64, error) +} + +type sftpUploadProgressReader struct { + ctx context.Context + reader io.Reader + total int64 + progress func(float64) + + mu sync.Mutex + copied int64 +} + +type sftpDownloadProgressWriter struct { + ctx context.Context + writer io.Writer + total int64 + progress func(float64) + + mu sync.Mutex + copied int64 } type SFTPErrorCategory string @@ -106,6 +152,7 @@ var ( sftpVerifyLocalSizeFunc = verifyLocalSize sftpLocalFileSHA256Func = localFileSHA256 sftpRemoteFileSHA256Func = remoteFileSHA256 + sftpNewClientFunc = sftp.NewClient ) func DefaultSFTPTransferOptions() SFTPTransferOptions { @@ -121,6 +168,16 @@ func DefaultSFTPTransferOptions() SFTPTransferOptions { } } +func ThroughputSFTPTransferOptions() SFTPTransferOptions { + opts := DefaultSFTPTransferOptions() + opts.Client = SFTPClientOptions{ + ConcurrentReads: SFTPBool(true), + ConcurrentWrites: SFTPBool(true), + MaxConcurrentRequestsPerFile: 32, + } + return opts +} + func SFTPBool(value bool) *bool { return &value } @@ -316,51 +373,40 @@ func (fs *SFTPFileSystem) Rename(ctx context.Context, oldPath string, newPath st }) } -func normalizeSFTPTransferOptions(options *SFTPTransferOptions) resolvedSFTPTransferOptions { +func normalizeSFTPTransferOptions(options *SFTPTransferOptions) (resolvedSFTPTransferOptions, error) { opts := DefaultSFTPTransferOptions() - if options == nil { - return resolvedSFTPTransferOptions{ - BufferSize: opts.BufferSize, - Progress: opts.Progress, - RetryCount: normalizeSFTPRetryCount(derefSFTPInt(opts.RetryCount, defaultSFTPRetryCount)), - RetryInitialBackoff: derefSFTPDuration(opts.RetryInitialBackoff, defaultSFTPRetryInitialBackoff), - AtomicUpload: derefSFTPBool(opts.AtomicUpload, true), - AtomicDownload: derefSFTPBool(opts.AtomicDownload, true), - VerifySize: derefSFTPBool(opts.VerifySize, true), - VerifyChecksum: derefSFTPBool(opts.VerifyChecksum, false), - TempSuffix: normalizeSFTPTempSuffix(opts.TempSuffix), + if options != nil { + if options.BufferSize > 0 { + opts.BufferSize = options.BufferSize } + if options.Progress != nil { + opts.Progress = options.Progress + } + if options.RetryCount != nil { + opts.RetryCount = options.RetryCount + } + if options.RetryInitialBackoff != nil { + opts.RetryInitialBackoff = options.RetryInitialBackoff + } + if options.AtomicUpload != nil { + opts.AtomicUpload = options.AtomicUpload + } + if options.AtomicDownload != nil { + opts.AtomicDownload = options.AtomicDownload + } + if options.VerifySize != nil { + opts.VerifySize = options.VerifySize + } + if options.VerifyChecksum != nil { + opts.VerifyChecksum = options.VerifyChecksum + } + if strings.TrimSpace(options.TempSuffix) != "" { + opts.TempSuffix = options.TempSuffix + } + opts.Client = options.Client } - if options.BufferSize > 0 { - opts.BufferSize = options.BufferSize - } - if options.Progress != nil { - opts.Progress = options.Progress - } - if options.RetryCount != nil { - opts.RetryCount = options.RetryCount - } - if options.RetryInitialBackoff != nil { - opts.RetryInitialBackoff = options.RetryInitialBackoff - } - if options.AtomicUpload != nil { - opts.AtomicUpload = options.AtomicUpload - } - if options.AtomicDownload != nil { - opts.AtomicDownload = options.AtomicDownload - } - if options.VerifySize != nil { - opts.VerifySize = options.VerifySize - } - if options.VerifyChecksum != nil { - opts.VerifyChecksum = options.VerifyChecksum - } - if strings.TrimSpace(options.TempSuffix) != "" { - opts.TempSuffix = options.TempSuffix - } - - return resolvedSFTPTransferOptions{ + resolved := resolvedSFTPTransferOptions{ BufferSize: opts.BufferSize, Progress: opts.Progress, RetryCount: normalizeSFTPRetryCount(derefSFTPInt(opts.RetryCount, defaultSFTPRetryCount)), @@ -371,6 +417,11 @@ func normalizeSFTPTransferOptions(options *SFTPTransferOptions) resolvedSFTPTran VerifyChecksum: derefSFTPBool(opts.VerifyChecksum, false), TempSuffix: normalizeSFTPTempSuffix(opts.TempSuffix), } + resolved.Client = normalizeSFTPClientOptions(opts.Client) + if err := validateResolvedSFTPTransferOptions(resolved); err != nil { + return resolvedSFTPTransferOptions{}, err + } + return resolved, nil } func derefSFTPBool(value *bool, fallback bool) bool { @@ -409,13 +460,56 @@ func normalizeSFTPRetryCount(value int) int { return value } +func normalizeSFTPClientOptions(options SFTPClientOptions) resolvedSFTPClientOptions { + return resolvedSFTPClientOptions{ + MaxPacketSize: options.MaxPacketSize, + MaxConcurrentRequestsPerFile: options.MaxConcurrentRequestsPerFile, + ConcurrentReads: options.ConcurrentReads, + ConcurrentWrites: options.ConcurrentWrites, + } +} + +func validateResolvedSFTPTransferOptions(opts resolvedSFTPTransferOptions) error { + return validateResolvedSFTPClientOptions(opts.Client) +} + +func validateResolvedSFTPClientOptions(opts resolvedSFTPClientOptions) error { + if opts.MaxPacketSize < 0 { + return errors.New("sftp max packet size must not be negative") + } + if opts.MaxConcurrentRequestsPerFile < 0 { + return errors.New("sftp max concurrent requests per file must not be negative") + } + return nil +} + +func rejectExternalSFTPClientOptions(opts resolvedSFTPClientOptions) error { + if opts.MaxPacketSize != 0 || + opts.MaxConcurrentRequestsPerFile != 0 || + opts.ConcurrentReads != nil || + opts.ConcurrentWrites != nil { + return errors.New("sftp client options require StarSSH-managed SFTP client") + } + return nil +} + +func validateSFTPUploadOptions(opts resolvedSFTPTransferOptions) error { + if derefSFTPBool(opts.Client.ConcurrentWrites, false) && !opts.AtomicUpload { + return errors.New("sftp concurrent writes require atomic upload") + } + return nil +} + func (star *StarSSH) runSFTPClientOperation(ctx context.Context, operation string, remotePath string, fn func(*sftp.Client) error) error { if err := ensureContext(ctx); err != nil { return err } - opts := normalizeSFTPTransferOptions(nil) + opts, err := normalizeSFTPTransferOptions(nil) + if err != nil { + return err + } return executeSFTPRetry(ctx, operation, "", remotePath, opts, func(attempt int) error { - return star.withIsolatedSFTPClient(ctx, fn) + return star.withIsolatedSFTPClient(ctx, opts.Client, fn) }) } @@ -423,23 +517,27 @@ func (star *StarSSH) runSFTPClientOperationNoRetry(ctx context.Context, fn func( if err := ensureContext(ctx); err != nil { return err } - return star.withIsolatedSFTPClient(ctx, fn) + return star.withIsolatedSFTPClient(ctx, resolvedSFTPClientOptions{}, fn) } func (star *StarSSH) CreateSftpClient() (*sftp.Client, error) { + return star.createSFTPClientWithOptions(resolvedSFTPClientOptions{}) +} + +func (star *StarSSH) createSFTPClientWithOptions(options resolvedSFTPClientOptions) (*sftp.Client, error) { client, err := star.requireSSHClient() if err != nil { return nil, err } - return sftp.NewClient(client) + return sftpNewClientFunc(client, buildSFTPClientOptions(options)...) } -func (star *StarSSH) withIsolatedSFTPClient(ctx context.Context, fn func(*sftp.Client) error) error { +func (star *StarSSH) withIsolatedSFTPClient(ctx context.Context, options resolvedSFTPClientOptions, fn func(*sftp.Client) error) error { if err := ensureContext(ctx); err != nil { return err } - client, err := star.CreateSftpClient() + client, err := star.createSFTPClientWithOptions(options) if err != nil { return err } @@ -448,6 +546,23 @@ func (star *StarSSH) withIsolatedSFTPClient(ctx context.Context, fn func(*sftp.C return fn(client) } +func buildSFTPClientOptions(options resolvedSFTPClientOptions) []sftp.ClientOption { + clientOptions := make([]sftp.ClientOption, 0, 4) + if options.MaxPacketSize > 0 { + clientOptions = append(clientOptions, sftp.MaxPacketChecked(options.MaxPacketSize)) + } + if options.MaxConcurrentRequestsPerFile > 0 { + clientOptions = append(clientOptions, sftp.MaxConcurrentRequestsPerFile(options.MaxConcurrentRequestsPerFile)) + } + if options.ConcurrentReads != nil { + clientOptions = append(clientOptions, sftp.UseConcurrentReads(*options.ConcurrentReads)) + } + if options.ConcurrentWrites != nil { + clientOptions = append(clientOptions, sftp.UseConcurrentWrites(*options.ConcurrentWrites)) + } + return clientOptions +} + func (star *StarSSH) getReusableSFTPClient() (*sftp.Client, error) { if star == nil { return nil, errors.New("ssh client is nil") @@ -526,7 +641,7 @@ func (star *StarSSH) runSFTPWithRetry( fn func(context.Context, *sftp.Client, resolvedSFTPTransferOptions) error, ) error { return executeSFTPRetry(ctx, operation, localPath, remotePath, opts, func(attempt int) error { - return star.withIsolatedSFTPClient(ctx, func(client *sftp.Client) error { + return star.withIsolatedSFTPClient(ctx, opts.Client, func(client *sftp.Client) error { return fn(ctx, client, opts) }) }) @@ -537,7 +652,13 @@ func (star *StarSSH) SftpTransferOut(localFilePath, remotePath string) error { } func (star *StarSSH) SftpTransferOutContext(ctx context.Context, localFilePath, remotePath string, options *SFTPTransferOptions) error { - opts := normalizeSFTPTransferOptions(options) + opts, err := normalizeSFTPTransferOptions(options) + if err != nil { + return err + } + if err := validateSFTPUploadOptions(opts); err != nil { + return err + } return star.runSFTPWithRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error { return transferOutContext(ctx, client, localFilePath, remotePath, opts) }) @@ -548,7 +669,16 @@ func SftpTransferOut(localFilePath, remotePath string, sftpClient *sftp.Client) } func SftpTransferOutWithContext(ctx context.Context, localFilePath, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error { - opts := normalizeSFTPTransferOptions(options) + opts, err := normalizeSFTPTransferOptions(options) + if err != nil { + return err + } + if err := validateSFTPUploadOptions(opts); err != nil { + return err + } + if err := rejectExternalSFTPClientOptions(opts.Client); err != nil { + return err + } return executeSFTPRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(attempt int) error { return transferOutContext(ctx, sftpClient, localFilePath, remotePath, opts) }) @@ -559,7 +689,13 @@ func (star *StarSSH) SftpTransferOutByte(localData []byte, remotePath string) er } func (star *StarSSH) SftpTransferOutByteContext(ctx context.Context, localData []byte, remotePath string, options *SFTPTransferOptions) error { - opts := normalizeSFTPTransferOptions(options) + opts, err := normalizeSFTPTransferOptions(options) + if err != nil { + return err + } + if err := validateSFTPUploadOptions(opts); err != nil { + return err + } return star.runSFTPWithRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error { return transferOutByteContext(ctx, client, localData, remotePath, opts) }) @@ -570,7 +706,16 @@ func SftpTransferOutByte(localData []byte, remotePath string, sftpClient *sftp.C } func SftpTransferOutByteWithContext(ctx context.Context, localData []byte, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error { - opts := normalizeSFTPTransferOptions(options) + opts, err := normalizeSFTPTransferOptions(options) + if err != nil { + return err + } + if err := validateSFTPUploadOptions(opts); err != nil { + return err + } + if err := rejectExternalSFTPClientOptions(opts.Client); err != nil { + return err + } return executeSFTPRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(attempt int) error { return transferOutByteContext(ctx, sftpClient, localData, remotePath, opts) }) @@ -595,10 +740,13 @@ func (star *StarSSH) SftpTransferInByte(remotePath string) ([]byte, error) { } func (star *StarSSH) SftpTransferInByteContext(ctx context.Context, remotePath string, options *SFTPTransferOptions) ([]byte, error) { - opts := normalizeSFTPTransferOptions(options) + opts, err := normalizeSFTPTransferOptions(options) + if err != nil { + return nil, err + } var data []byte - err := star.runSFTPWithRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error { + err = star.runSFTPWithRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error { out, runErr := transferInByteContext(ctx, client, remotePath, opts) if runErr != nil { return runErr @@ -617,10 +765,16 @@ func SftpTransferInByte(remotePath string, sftpClient *sftp.Client) ([]byte, err } func SftpTransferInByteWithContext(ctx context.Context, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) ([]byte, error) { - opts := normalizeSFTPTransferOptions(options) + opts, err := normalizeSFTPTransferOptions(options) + if err != nil { + return nil, err + } + if err := rejectExternalSFTPClientOptions(opts.Client); err != nil { + return nil, err + } var data []byte - err := executeSFTPRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(attempt int) error { + err = executeSFTPRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(attempt int) error { out, runErr := transferInByteContext(ctx, sftpClient, remotePath, opts) if runErr != nil { return runErr @@ -639,7 +793,10 @@ func (star *StarSSH) SftpTransferIn(src, dst string) error { } func (star *StarSSH) SftpTransferInContext(ctx context.Context, src, dst string, options *SFTPTransferOptions) error { - opts := normalizeSFTPTransferOptions(options) + opts, err := normalizeSFTPTransferOptions(options) + if err != nil { + return err + } return star.runSFTPWithRetry(ctx, "sftp_get_file", dst, src, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error { return transferInContext(ctx, client, src, dst, opts) }) @@ -650,7 +807,13 @@ func SftpTransferIn(src, dst string, sftpClient *sftp.Client) error { } func SftpTransferInWithContext(ctx context.Context, src, dst string, sftpClient *sftp.Client, options *SFTPTransferOptions) error { - opts := normalizeSFTPTransferOptions(options) + opts, err := normalizeSFTPTransferOptions(options) + if err != nil { + return err + } + if err := rejectExternalSFTPClientOptions(opts.Client); err != nil { + return err + } return executeSFTPRetry(ctx, "sftp_get_file", dst, src, opts, func(attempt int) error { return transferInContext(ctx, sftpClient, src, dst, opts) }) @@ -712,7 +875,7 @@ func transferOutContext(ctx context.Context, sftpClient *sftp.Client, localFileP return err } - if _, err := sftpCopyWithProgressFunc(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil { + if _, err := copyUploadWithProgressContext(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress, opts); err != nil { _ = dstFile.Close() return err } @@ -792,7 +955,7 @@ func transferOutByteContext(ctx context.Context, sftpClient *sftp.Client, localD } reader := bytes.NewReader(localData) - if _, err := sftpCopyWithProgressFunc(ctx, dstFile, reader, opts.BufferSize, int64(len(localData)), opts.Progress); err != nil { + if _, err := copyUploadWithProgressContext(ctx, dstFile, reader, opts.BufferSize, int64(len(localData)), opts.Progress, opts); err != nil { _ = dstFile.Close() return err } @@ -877,7 +1040,7 @@ func transferInContext(ctx context.Context, sftpClient *sftp.Client, src, dst st }() } - if _, err := sftpCopyWithProgressFunc(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil { + if _, err := copyDownloadWithProgressContext(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress, opts); err != nil { _ = dstFile.Close() return err } @@ -948,7 +1111,7 @@ func transferInByteContext(ctx context.Context, sftpClient *sftp.Client, remoteP } var out bytes.Buffer - if _, err := sftpCopyWithProgressFunc(ctx, &out, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil { + if _, err := copyDownloadWithProgressContext(ctx, &out, srcFile, opts.BufferSize, stat.Size(), opts.Progress, opts); err != nil { return nil, err } @@ -1149,6 +1312,151 @@ func copyWithProgressContext(ctx context.Context, dst io.Writer, src io.Reader, return copied, nil } +func copyUploadWithProgressContext( + ctx context.Context, + dst io.Writer, + src io.Reader, + bufSize int, + total int64, + progress func(float64), + opts resolvedSFTPTransferOptions, +) (int64, error) { + if !derefSFTPBool(opts.Client.ConcurrentWrites, false) { + return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress) + } + + readerFrom, ok := dst.(sftpConcurrentReaderFrom) + if !ok { + return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress) + } + + if err := ensureContext(ctx); err != nil { + return 0, err + } + if progress != nil && total > 0 { + progress(0) + } + + wrappedSrc := &sftpUploadProgressReader{ + ctx: ctx, + reader: src, + total: total, + progress: progress, + } + written, err := readerFrom.ReadFromWithConcurrency(wrappedSrc, opts.Client.MaxConcurrentRequestsPerFile) + if err != nil { + return written, err + } + if err := ensureContext(ctx); err != nil { + return written, err + } + reportProgress(progress, written, total) + return written, nil +} + +func copyDownloadWithProgressContext( + ctx context.Context, + dst io.Writer, + src io.Reader, + bufSize int, + total int64, + progress func(float64), + opts resolvedSFTPTransferOptions, +) (int64, error) { + if !derefSFTPBool(opts.Client.ConcurrentReads, false) { + return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress) + } + + writerTo, ok := src.(io.WriterTo) + if !ok { + return sftpCopyWithProgressFunc(ctx, dst, src, bufSize, total, progress) + } + + if err := ensureContext(ctx); err != nil { + return 0, err + } + if progress != nil && total > 0 { + progress(0) + } + + wrappedDst := &sftpDownloadProgressWriter{ + ctx: ctx, + writer: dst, + total: total, + progress: progress, + } + written, err := writerTo.WriteTo(wrappedDst) + if err != nil { + return written, err + } + if err := ensureContext(ctx); err != nil { + return written, err + } + reportProgress(progress, written, total) + return written, nil +} + +func (r *sftpUploadProgressReader) Read(p []byte) (int, error) { + if err := ensureContext(r.ctx); err != nil { + return 0, err + } + + r.mu.Lock() + defer r.mu.Unlock() + + if err := ensureContext(r.ctx); err != nil { + return 0, err + } + + n, err := r.reader.Read(p) + if n > 0 { + r.copied += int64(n) + reportQueuedTransferProgress(r.progress, r.copied, r.total) + } + return n, err +} + +func (w *sftpDownloadProgressWriter) Write(p []byte) (int, error) { + if err := ensureContext(w.ctx); err != nil { + return 0, err + } + + w.mu.Lock() + defer w.mu.Unlock() + + if err := ensureContext(w.ctx); err != nil { + return 0, err + } + + n, err := w.writer.Write(p) + if n > 0 { + w.copied += int64(n) + reportQueuedTransferProgress(w.progress, w.copied, w.total) + } + if err != nil { + return n, err + } + if n != len(p) { + return n, io.ErrShortWrite + } + if err := ensureContext(w.ctx); err != nil { + return n, err + } + return n, nil +} + +func reportQueuedTransferProgress(progress func(float64), copied int64, total int64) { + if progress == nil || total <= 0 { + return + } + + percent := float64(copied) / float64(total) * 100 + if percent >= 100 { + percent = 99 + } + progress(percent) +} + func reportProgress(progress func(float64), copied int64, total int64) { if progress == nil { return diff --git a/sftp_test.go b/sftp_test.go index 1170c6f..0f7a18d 100644 --- a/sftp_test.go +++ b/sftp_test.go @@ -1,6 +1,7 @@ package starssh import ( + "bytes" "context" "errors" "io" @@ -11,16 +12,521 @@ import ( "testing" "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" ) func TestNormalizeSFTPTransferOptionsDefaultsAtomicDownload(t *testing.T) { - opts := normalizeSFTPTransferOptions(nil) + opts := mustNormalizeSFTPTransferOptions(t, nil) if !opts.AtomicUpload { t.Fatal("expected atomic upload to default to enabled") } if !opts.AtomicDownload { t.Fatal("expected atomic download to default to enabled") } + if opts.Client.ConcurrentWrites != nil { + t.Fatal("expected concurrent writes to default to unset") + } +} + +func TestThroughputSFTPTransferOptionsEnablesExplicitConcurrentWrites(t *testing.T) { + opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) + if !opts.AtomicUpload { + t.Fatal("expected throughput preset to keep atomic upload enabled") + } + if opts.Client.ConcurrentReads == nil || !*opts.Client.ConcurrentReads { + t.Fatal("expected throughput preset to enable concurrent reads") + } + if opts.Client.ConcurrentWrites == nil || !*opts.Client.ConcurrentWrites { + t.Fatal("expected throughput preset to enable concurrent writes") + } + if opts.Client.MaxConcurrentRequestsPerFile != 32 { + t.Fatalf("unexpected max concurrent requests: got %d want 32", opts.Client.MaxConcurrentRequestsPerFile) + } +} + +func TestValidateSFTPUploadOptionsRejectsConcurrentWritesWithoutAtomicUpload(t *testing.T) { + options := ThroughputSFTPTransferOptions() + options.AtomicUpload = SFTPBool(false) + opts := mustNormalizeSFTPTransferOptions(t, &options) + + err := validateSFTPUploadOptions(opts) + if err == nil || !strings.Contains(err.Error(), "atomic upload") { + t.Fatalf("expected atomic upload rejection, got %v", err) + } +} + +func TestNormalizeSFTPTransferOptionsRejectsNegativeClientValues(t *testing.T) { + _, err := normalizeSFTPTransferOptions(&SFTPTransferOptions{ + Client: SFTPClientOptions{MaxPacketSize: -1}, + }) + if err == nil || !strings.Contains(err.Error(), "max packet") { + t.Fatalf("expected max packet rejection, got %v", err) + } + + _, err = normalizeSFTPTransferOptions(&SFTPTransferOptions{ + Client: SFTPClientOptions{MaxConcurrentRequestsPerFile: -1}, + }) + if err == nil || !strings.Contains(err.Error(), "max concurrent") { + t.Fatalf("expected max concurrent rejection, got %v", err) + } +} + +func TestBuildSFTPClientOptionsRejectsUnsupportedCheckedPacketSize(t *testing.T) { + options := buildSFTPClientOptions(resolvedSFTPClientOptions{MaxPacketSize: 32769}) + client := &sftp.Client{} + + err := options[0](client) + if err == nil || !strings.Contains(err.Error(), "32KB") { + t.Fatalf("expected checked packet size rejection, got %v", err) + } +} + +func TestSftpTransferOutWithContextRejectsClientOptionsForExternalClient(t *testing.T) { + client := newSFTPTestClient(t) + root := t.TempDir() + localPath := filepath.Join(root, "local.txt") + remotePath := filepath.Join(root, "remote.txt") + if err := os.WriteFile(localPath, []byte("payload"), 0o644); err != nil { + t.Fatalf("write local file: %v", err) + } + + err := SftpTransferOutWithContext(context.Background(), localPath, remotePath, client, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) + if err == nil || !strings.Contains(err.Error(), "StarSSH-managed") { + t.Fatalf("expected external client option rejection, got %v", err) + } +} + +func TestSftpTransferOutContextPassesClientOptionsToManagedClient(t *testing.T) { + root := t.TempDir() + localPath := filepath.Join(root, "local.txt") + remotePath := filepath.Join(root, "remote.txt") + if err := os.WriteFile(localPath, []byte("payload"), 0o644); err != nil { + t.Fatalf("write local file: %v", err) + } + + client := newSFTPTestClient(t) + options := ThroughputSFTPTransferOptions() + var captured []sftp.ClientOption + oldNewClient := sftpNewClientFunc + sftpNewClientFunc = func(_ *ssh.Client, opts ...sftp.ClientOption) (*sftp.Client, error) { + captured = append([]sftp.ClientOption(nil), opts...) + return client, nil + } + t.Cleanup(func() { + sftpNewClientFunc = oldNewClient + }) + + star := &StarSSH{Client: &ssh.Client{}} + if err := star.SftpTransferOutContext(context.Background(), localPath, remotePath, &options); err != nil { + t.Fatalf("transfer out: %v", err) + } + + if len(captured) == 0 { + t.Fatal("expected managed SFTP client options to be passed to factory") + } + if got, want := len(captured), len(buildSFTPClientOptions(mustNormalizeSFTPTransferOptions(t, &options).Client)); got != want { + t.Fatalf("unexpected client option count: got %d want %d", got, want) + } +} + +func TestSftpTransferInWithContextRejectsClientOptionsForExternalClient(t *testing.T) { + client := newSFTPTestClient(t) + root := t.TempDir() + srcPath := filepath.Join(root, "remote.txt") + dstPath := filepath.Join(root, "local.txt") + if err := os.WriteFile(srcPath, []byte("payload"), 0o644); err != nil { + t.Fatalf("write remote file: %v", err) + } + + err := SftpTransferInWithContext(context.Background(), srcPath, dstPath, client, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) + if err == nil || !strings.Contains(err.Error(), "StarSSH-managed") { + t.Fatalf("expected external client option rejection, got %v", err) + } +} + +func TestSftpTransferInContextPassesClientOptionsToManagedClient(t *testing.T) { + root := t.TempDir() + srcPath := filepath.Join(root, "remote.txt") + dstPath := filepath.Join(root, "local.txt") + if err := os.WriteFile(srcPath, []byte("payload"), 0o644); err != nil { + t.Fatalf("write remote file: %v", err) + } + + client := newSFTPTestClient(t) + options := ThroughputSFTPTransferOptions() + var captured []sftp.ClientOption + oldNewClient := sftpNewClientFunc + sftpNewClientFunc = func(_ *ssh.Client, opts ...sftp.ClientOption) (*sftp.Client, error) { + captured = append([]sftp.ClientOption(nil), opts...) + return client, nil + } + t.Cleanup(func() { + sftpNewClientFunc = oldNewClient + }) + + star := &StarSSH{Client: &ssh.Client{}} + if err := star.SftpTransferInContext(context.Background(), srcPath, dstPath, &options); err != nil { + t.Fatalf("transfer in: %v", err) + } + + if len(captured) == 0 { + t.Fatal("expected managed SFTP client options to be passed to factory") + } + if got, want := len(captured), len(buildSFTPClientOptions(mustNormalizeSFTPTransferOptions(t, &options).Client)); got != want { + t.Fatalf("unexpected client option count: got %d want %d", got, want) + } + assertFileContent(t, dstPath, "payload") +} + +func TestBuildSFTPClientOptionsCanTransfer(t *testing.T) { + client := newSFTPTestClientWithOptions(t, buildSFTPClientOptions(resolvedSFTPClientOptions{ + MaxPacketSize: 4096, + MaxConcurrentRequestsPerFile: 4, + ConcurrentWrites: SFTPBool(true), + })) + root := t.TempDir() + localPath := filepath.Join(root, "local.txt") + remotePath := filepath.Join(root, "remote.txt") + + if err := os.WriteFile(localPath, []byte(strings.Repeat("payload-", 2048)), 0o644); err != nil { + t.Fatalf("write local file: %v", err) + } + + opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) + if err := transferOutContext(context.Background(), client, localPath, remotePath, opts); err != nil { + t.Fatalf("transfer out with client options: %v", err) + } + + data, err := os.ReadFile(remotePath) + if err != nil { + t.Fatalf("read remote file: %v", err) + } + if got := string(data); got != strings.Repeat("payload-", 2048) { + t.Fatalf("unexpected remote payload length: got %d", len(got)) + } +} + +func TestBuildSFTPClientOptionsCanDownload(t *testing.T) { + client := newSFTPTestClientWithOptions(t, buildSFTPClientOptions(resolvedSFTPClientOptions{ + MaxPacketSize: 4096, + MaxConcurrentRequestsPerFile: 4, + ConcurrentReads: SFTPBool(true), + })) + root := t.TempDir() + srcPath := filepath.Join(root, "remote.txt") + dstPath := filepath.Join(root, "local.txt") + + payload := strings.Repeat("payload-", 2048) + if err := os.WriteFile(srcPath, []byte(payload), 0o644); err != nil { + t.Fatalf("write remote file: %v", err) + } + + opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) + if err := transferInContext(context.Background(), client, srcPath, dstPath, opts); err != nil { + t.Fatalf("transfer in with client options: %v", err) + } + + assertFileContent(t, dstPath, payload) +} + +func TestCopyUploadWithProgressUsesConcurrentReadFromWhenEnabled(t *testing.T) { + dst := &spyConcurrentReadFrom{} + src := strings.NewReader("payload") + var progress []float64 + opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) + + written, err := copyUploadWithProgressContext(context.Background(), dst, src, 3, int64(len("payload")), func(value float64) { + progress = append(progress, value) + }, opts) + if err != nil { + t.Fatalf("copy upload: %v", err) + } + if written != int64(len("payload")) { + t.Fatalf("unexpected written bytes: got %d", written) + } + if !dst.usedReadFrom { + t.Fatal("expected concurrent ReadFrom path to be used") + } + if dst.concurrency != opts.Client.MaxConcurrentRequestsPerFile { + t.Fatalf("unexpected concurrency: got %d want %d", dst.concurrency, opts.Client.MaxConcurrentRequestsPerFile) + } + if got := dst.buf.String(); got != "payload" { + t.Fatalf("unexpected copied payload: got %q", got) + } + if len(progress) == 0 || progress[len(progress)-1] != 100 { + t.Fatalf("expected final progress 100, got %v", progress) + } +} + +func TestCopyUploadWithProgressReportsDuringConcurrentReadFrom(t *testing.T) { + dst := &spyConcurrentReadFrom{} + src := &chunkedReader{ + reader: strings.NewReader("payload"), + chunkSize: 2, + } + var progress []float64 + opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) + + written, err := copyUploadWithProgressContext(context.Background(), dst, src, 3, int64(len("payload")), func(value float64) { + progress = append(progress, value) + }, opts) + if err != nil { + t.Fatalf("copy upload: %v", err) + } + if written != int64(len("payload")) { + t.Fatalf("unexpected written bytes: got %d", written) + } + + var sawIntermediate bool + for _, value := range progress { + if value > 0 && value < 100 { + sawIntermediate = true + break + } + } + if !sawIntermediate { + t.Fatalf("expected intermediate progress during concurrent readfrom, got %v", progress) + } +} + +func TestCopyUploadWithProgressCancelsDuringConcurrentReadFrom(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dst := &spyConcurrentReadFrom{} + src := &cancelAfterReadReader{ + reader: strings.NewReader("payload"), + chunkSize: 3, + cancel: cancel, + } + opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) + + written, err := copyUploadWithProgressContext(ctx, dst, src, 3, int64(len("payload")), nil, opts) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context cancellation, got written=%d err=%v", written, err) + } + if !dst.usedReadFrom { + t.Fatal("expected concurrent ReadFrom path to be used") + } + if written <= 0 || written >= int64(len("payload")) { + t.Fatalf("expected partial write before cancellation, got %d", written) + } +} + +func TestCopyUploadWithProgressDoesNotReportDoneBeforeConcurrentWriteError(t *testing.T) { + copyErr := errors.New("write status failed") + dst := &spyConcurrentReadFrom{errAfterRead: copyErr} + var progress []float64 + opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) + + written, err := copyUploadWithProgressContext(context.Background(), dst, strings.NewReader("payload"), 3, int64(len("payload")), func(value float64) { + progress = append(progress, value) + }, opts) + if !errors.Is(err, copyErr) { + t.Fatalf("expected concurrent write error, got written=%d err=%v", written, err) + } + if written != int64(len("payload")) { + t.Fatalf("expected read byte count before write error, got %d", written) + } + for _, value := range progress { + if value >= 100 { + t.Fatalf("progress reported completion before write success: %v", progress) + } + } + if len(progress) == 0 { + t.Fatal("expected queued progress before write error") + } +} + +func TestCopyUploadWithProgressReturnsConcurrentReadFromError(t *testing.T) { + copyErr := errors.New("readfrom failed") + dst := &spyConcurrentReadFrom{err: copyErr} + opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) + + _, err := copyUploadWithProgressContext(context.Background(), dst, strings.NewReader("payload"), 3, int64(len("payload")), nil, opts) + if !errors.Is(err, copyErr) { + t.Fatalf("expected concurrent readfrom error, got %v", err) + } + if !dst.usedReadFrom { + t.Fatal("expected concurrent ReadFrom path to be used") + } +} + +func TestCopyUploadWithProgressDefaultsToSequentialCopy(t *testing.T) { + oldCopy := sftpCopyWithProgressFunc + called := false + sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) { + called = true + return oldCopy(ctx, dst, src, bufSize, total, progress) + } + t.Cleanup(func() { + sftpCopyWithProgressFunc = oldCopy + }) + + var dst bytes.Buffer + written, err := copyUploadWithProgressContext(context.Background(), &dst, strings.NewReader("payload"), 3, int64(len("payload")), nil, mustNormalizeSFTPTransferOptions(t, nil)) + if err != nil { + t.Fatalf("copy upload: %v", err) + } + if written != int64(len("payload")) { + t.Fatalf("unexpected written bytes: got %d", written) + } + if !called { + t.Fatal("expected default path to use existing copy helper") + } +} + +func TestCopyDownloadWithProgressUsesConcurrentWriteToWhenEnabled(t *testing.T) { + src := &spyConcurrentWriteTo{payload: []byte("payload")} + var dst bytes.Buffer + var progress []float64 + opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) + + written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) { + progress = append(progress, value) + }, opts) + if err != nil { + t.Fatalf("copy download: %v", err) + } + if written != int64(len("payload")) { + t.Fatalf("unexpected written bytes: got %d", written) + } + if !src.usedWriteTo { + t.Fatal("expected concurrent WriteTo path to be used") + } + if got := dst.String(); got != "payload" { + t.Fatalf("unexpected copied payload: got %q", got) + } + if len(progress) == 0 || progress[len(progress)-1] != 100 { + t.Fatalf("expected final progress 100, got %v", progress) + } +} + +func TestCopyDownloadWithProgressReportsDuringConcurrentWriteTo(t *testing.T) { + src := &spyConcurrentWriteTo{ + payload: []byte("payload"), + chunkSize: 2, + } + var dst bytes.Buffer + var progress []float64 + opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) + + written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) { + progress = append(progress, value) + }, opts) + if err != nil { + t.Fatalf("copy download: %v", err) + } + if written != int64(len("payload")) { + t.Fatalf("unexpected written bytes: got %d", written) + } + + var sawIntermediate bool + for _, value := range progress { + if value > 0 && value < 100 { + sawIntermediate = true + break + } + } + if !sawIntermediate { + t.Fatalf("expected intermediate progress during concurrent writeto, got %v", progress) + } +} + +func TestCopyDownloadWithProgressCancelsDuringConcurrentWriteTo(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + src := &spyConcurrentWriteTo{ + payload: []byte("payload"), + chunkSize: 3, + cancelAfterWrite: cancel, + } + var dst bytes.Buffer + opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) + + written, err := copyDownloadWithProgressContext(ctx, &dst, src, 3, int64(len("payload")), nil, opts) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context cancellation, got written=%d err=%v", written, err) + } + if !src.usedWriteTo { + t.Fatal("expected concurrent WriteTo path to be used") + } + if written <= 0 || written >= int64(len("payload")) { + t.Fatalf("expected partial write before cancellation, got %d", written) + } +} + +func TestCopyDownloadWithProgressDoesNotReportDoneBeforeConcurrentReadError(t *testing.T) { + copyErr := errors.New("read status failed") + src := &spyConcurrentWriteTo{ + payload: []byte("payload"), + errAfterWrite: copyErr, + } + var dst bytes.Buffer + var progress []float64 + opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) + + written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) { + progress = append(progress, value) + }, opts) + if !errors.Is(err, copyErr) { + t.Fatalf("expected concurrent read error, got written=%d err=%v", written, err) + } + if written != int64(len("payload")) { + t.Fatalf("expected local byte count before read error, got %d", written) + } + for _, value := range progress { + if value >= 100 { + t.Fatalf("progress reported completion before download success: %v", progress) + } + } + if len(progress) == 0 { + t.Fatal("expected queued progress before read error") + } +} + +func TestCopyDownloadWithProgressReturnsConcurrentWriteToError(t *testing.T) { + copyErr := errors.New("writeto failed") + src := &spyConcurrentWriteTo{err: copyErr} + var dst bytes.Buffer + opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) + + _, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), nil, opts) + if !errors.Is(err, copyErr) { + t.Fatalf("expected concurrent writeto error, got %v", err) + } + if !src.usedWriteTo { + t.Fatal("expected concurrent WriteTo path to be used") + } +} + +func TestCopyDownloadWithProgressDefaultsToSequentialCopy(t *testing.T) { + oldCopy := sftpCopyWithProgressFunc + called := false + sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) { + called = true + return oldCopy(ctx, dst, src, bufSize, total, progress) + } + t.Cleanup(func() { + sftpCopyWithProgressFunc = oldCopy + }) + + var dst bytes.Buffer + src := &chunkedReader{ + reader: strings.NewReader("payload"), + chunkSize: 2, + } + written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), nil, mustNormalizeSFTPTransferOptions(t, nil)) + if err != nil { + t.Fatalf("copy download: %v", err) + } + if written != int64(len("payload")) { + t.Fatalf("unexpected written bytes: got %d", written) + } + if !called { + t.Fatal("expected default path to use existing copy helper") + } } func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) { @@ -47,7 +553,7 @@ func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) { sftpVerifyRemoteSizeFunc = oldVerifyRemoteSize }) - err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)) + err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)) if !errors.Is(err, verifyErr) { t.Fatalf("expected verify failure, got %v", err) } @@ -82,7 +588,7 @@ func TestTransferOutContextRejectsRemoteSymlinkTarget(t *testing.T) { t.Skipf("symlink unsupported: %v", err) } - err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)) + err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)) if err == nil || !strings.Contains(err.Error(), "symlink") { t.Fatalf("expected symlink rejection, got %v", err) } @@ -118,7 +624,7 @@ func TestTransferOutContextRejectsRemoteDirectoryTarget(t *testing.T) { t.Fatalf("mkdir remote target: %v", err) } - err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)) + err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)) if err == nil || !strings.Contains(err.Error(), "directory") { t.Fatalf("expected directory rejection, got %v", err) } @@ -149,7 +655,7 @@ func TestTransferOutContextPreservesRemoteModeOnOverwrite(t *testing.T) { t.Fatalf("chmod remote file: %v", err) } - if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil { + if err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil { t.Fatalf("transfer out: %v", err) } @@ -170,7 +676,7 @@ func TestTransferOutContextAppliesLocalModeForNewRemoteFile(t *testing.T) { t.Fatalf("chmod local file: %v", err) } - if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil { + if err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil { t.Fatalf("transfer out: %v", err) } @@ -190,7 +696,7 @@ func TestTransferOutByteContextPreservesRemoteModeOnOverwrite(t *testing.T) { t.Fatalf("chmod remote file: %v", err) } - if err := transferOutByteContext(context.Background(), client, []byte("byte payload"), remotePath, normalizeSFTPTransferOptions(nil)); err != nil { + if err := transferOutByteContext(context.Background(), client, []byte("byte payload"), remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil { t.Fatalf("transfer out bytes: %v", err) } @@ -222,7 +728,7 @@ func TestTransferInContextVerifyFailurePreservesLocalTarget(t *testing.T) { sftpVerifyLocalSizeFunc = oldVerifyLocalSize }) - err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)) + err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)) if !errors.Is(err, verifyErr) { t.Fatalf("expected verify failure, got %v", err) } @@ -257,7 +763,7 @@ func TestTransferInContextRejectsLocalSymlinkTarget(t *testing.T) { t.Skipf("symlink unsupported: %v", err) } - err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)) + err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)) if err == nil || !strings.Contains(err.Error(), "symlink") { t.Fatalf("expected symlink rejection, got %v", err) } @@ -286,7 +792,7 @@ func TestTransferInContextRejectsLocalDirectoryTarget(t *testing.T) { t.Fatalf("mkdir local target: %v", err) } - err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)) + err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)) if err == nil || !strings.Contains(err.Error(), "directory") { t.Fatalf("expected directory rejection, got %v", err) } @@ -317,7 +823,7 @@ func TestTransferInContextPreservesLocalModeOnOverwrite(t *testing.T) { t.Fatalf("chmod local file: %v", err) } - if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil { + if err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil { t.Fatalf("transfer in: %v", err) } @@ -338,7 +844,7 @@ func TestTransferInContextAppliesRemoteModeForNewLocalFile(t *testing.T) { t.Fatalf("chmod remote file: %v", err) } - if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil { + if err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil { t.Fatalf("transfer in: %v", err) } @@ -387,7 +893,7 @@ func TestTransferInContextCopyFailurePreservesLocalTarget(t *testing.T) { sftpCopyWithProgressFunc = oldCopy }) - err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)) + err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)) if !errors.Is(err, copyErr) { t.Fatalf("expected copy failure, got %v", err) } @@ -405,7 +911,153 @@ func TestTransferInContextCopyFailurePreservesLocalTarget(t *testing.T) { assertNoTransferTemps(t, dstPath) } +func mustNormalizeSFTPTransferOptions(t *testing.T, options *SFTPTransferOptions) resolvedSFTPTransferOptions { + t.Helper() + + opts, err := normalizeSFTPTransferOptions(options) + if err != nil { + t.Fatalf("normalize sftp transfer options: %v", err) + } + return opts +} + +func SFTPOptionsPtr(options SFTPTransferOptions) *SFTPTransferOptions { + return &options +} + +type spyConcurrentReadFrom struct { + buf bytes.Buffer + concurrency int + usedReadFrom bool + err error + errAfterRead error +} + +func (w *spyConcurrentReadFrom) Write(p []byte) (int, error) { + return w.buf.Write(p) +} + +func (w *spyConcurrentReadFrom) ReadFromWithConcurrency(r io.Reader, concurrency int) (int64, error) { + w.usedReadFrom = true + w.concurrency = concurrency + if w.err != nil { + return 0, w.err + } + written, err := w.buf.ReadFrom(r) + if err != nil { + return written, err + } + if w.errAfterRead != nil { + return written, w.errAfterRead + } + return written, nil +} + +type spyConcurrentWriteTo struct { + payload []byte + chunkSize int + usedWriteTo bool + err error + errAfterWrite error + cancelAfterWrite context.CancelFunc +} + +func (r *spyConcurrentWriteTo) Read(p []byte) (int, error) { + if len(r.payload) == 0 { + if r.err != nil { + return 0, r.err + } + return 0, io.EOF + } + n := copy(p, r.payload) + r.payload = r.payload[n:] + if len(r.payload) == 0 { + return n, io.EOF + } + return n, nil +} + +func (r *spyConcurrentWriteTo) WriteTo(w io.Writer) (int64, error) { + r.usedWriteTo = true + if r.err != nil { + return 0, r.err + } + + payload := r.payload + if len(payload) == 0 { + if r.errAfterWrite != nil { + return 0, r.errAfterWrite + } + return 0, nil + } + + chunkSize := r.chunkSize + if chunkSize <= 0 || chunkSize > len(payload) { + chunkSize = len(payload) + } + + var written int64 + for len(payload) > 0 { + size := chunkSize + if size > len(payload) { + size = len(payload) + } + n, err := w.Write(payload[:size]) + written += int64(n) + if r.cancelAfterWrite != nil && n > 0 { + r.cancelAfterWrite() + r.cancelAfterWrite = nil + } + if err != nil { + return written, err + } + if n != size { + return written, io.ErrShortWrite + } + payload = payload[size:] + } + + if r.errAfterWrite != nil { + return written, r.errAfterWrite + } + return written, nil +} + +type chunkedReader struct { + reader io.Reader + chunkSize int +} + +func (r *chunkedReader) Read(p []byte) (int, error) { + if r.chunkSize > 0 && len(p) > r.chunkSize { + p = p[:r.chunkSize] + } + return r.reader.Read(p) +} + +type cancelAfterReadReader struct { + reader io.Reader + chunkSize int + cancel context.CancelFunc +} + +func (r *cancelAfterReadReader) Read(p []byte) (int, error) { + if r.chunkSize > 0 && len(p) > r.chunkSize { + p = p[:r.chunkSize] + } + n, err := r.reader.Read(p) + if n > 0 && r.cancel != nil { + r.cancel() + r.cancel = nil + } + return n, err +} + func newSFTPTestClient(t *testing.T) *sftp.Client { + return newSFTPTestClientWithOptions(t, nil) +} + +func newSFTPTestClientWithOptions(t *testing.T, options []sftp.ClientOption) *sftp.Client { t.Helper() serverConn, clientConn := net.Pipe() @@ -419,7 +1071,7 @@ func newSFTPTestClient(t *testing.T) *sftp.Client { serveErrCh <- server.Serve() }() - client, err := sftp.NewClientPipe(clientConn, clientConn) + client, err := sftp.NewClientPipe(clientConn, clientConn, options...) if err != nil { _ = server.Close() t.Fatalf("create sftp client: %v", err)