package starssh

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"path"
	"path/filepath"
	"strings"
	"time"

	"github.com/pkg/sftp"
)

// Defaults applied when a transfer option is left unset.
const (
	defaultSFTPRetryCount          = 2
	defaultSFTPRetryInitialBackoff = 250 * time.Millisecond
	defaultSFTPTempSuffix          = ".starssh.tmp"
)

// preservedFileModeBits selects the permission, setuid, setgid and sticky bits
// of a file mode.
const preservedFileModeBits os.FileMode = os.ModePerm | os.ModeSetuid | os.ModeSetgid | os.ModeSticky

// SFTPTransferOptions customises a single SFTP transfer. Pointer fields
// distinguish "not set" (nil, default applies) from an explicit value; the
// SFTPBool, SFTPInt and SFTPDuration helpers build such pointers.
type SFTPTransferOptions struct {
	// BufferSize is the copy buffer size; values <= 0 select the default.
	BufferSize int
	// Progress, when non-nil, receives the completion percentage (0-100).
	Progress func(float64)
	// RetryCount is the number of retries after the first attempt; negative
	// values are treated as 0.
	RetryCount *int
	// RetryInitialBackoff is the wait before the first retry.
	RetryInitialBackoff *time.Duration
	// AtomicUpload uploads to a temporary remote path and renames it into place.
	AtomicUpload *bool
	// AtomicDownload downloads to a temporary local file and renames it into place.
	AtomicDownload *bool
	// VerifySize compares source and destination sizes after the copy.
	VerifySize *bool
	// VerifyChecksum compares SHA-256 digests after the copy.
	VerifyChecksum *bool
	// TempSuffix is appended to temporary file names; blank selects the default.
	TempSuffix string
}

// resolvedSFTPTransferOptions is SFTPTransferOptions with every default applied.
type resolvedSFTPTransferOptions struct {
	BufferSize          int
	Progress            func(float64)
	RetryCount          int
	RetryInitialBackoff time.Duration
	AtomicUpload        bool
	AtomicDownload      bool
	VerifySize          bool
	VerifyChecksum      bool
	TempSuffix          string
}

// SFTPErrorCategory tells whether a failed operation is worth retrying.
type SFTPErrorCategory string

const (
	// SFTPErrorRetryable marks errors that look transient.
	SFTPErrorRetryable SFTPErrorCategory = "retryable"
	// SFTPErrorPermanent marks errors that are not retried.
	SFTPErrorPermanent SFTPErrorCategory = "permanent"
)

// SFTPTransferError wraps the error of a failed SFTP operation together with
// the operation name, the paths involved, the zero-based attempt number and
// the retry classification.
type SFTPTransferError struct {
	Operation  string
	LocalPath  string
	RemotePath string
	Attempt    int
	Category   SFTPErrorCategory
	Err        error
}

// Error implements error. A nil receiver yields the empty string.
func (e *SFTPTransferError) Error() string {
	if e == nil {
		return ""
	}
	return fmt.Sprintf("%s failed [%s] (attempt=%d, local=%q, remote=%q): %v", e.Operation, e.Category, e.Attempt, e.LocalPath, e.RemotePath, e.Err)
}

// Unwrap returns the underlying error so errors.Is / errors.As can see it.
func (e *SFTPTransferError) Unwrap() error {
	if e == nil {
		return nil
	}
	return e.Err
}

// FS is the context-aware file system API implemented by SFTPFileSystem.
type FS interface {
	Stat(context.Context, string) (os.FileInfo, error)
	ReadDir(context.Context, string) ([]os.FileInfo, error)
	ReadFile(context.Context, string, *SFTPTransferOptions) ([]byte, error)
	WriteFile(context.Context, string, []byte, *SFTPTransferOptions) error
	MkdirAll(context.Context, string) error
	Remove(context.Context, string) error
	RemoveAll(context.Context, string) error
	Rename(context.Context, string, string) error
}

// SFTPFileSystem exposes remote file operations of a StarSSH connection.
type SFTPFileSystem struct {
	star *StarSSH
}

// atomicReplaceTarget records whether the destination of an atomic replace
// already exists and, if so, its file mode.
type atomicReplaceTarget struct {
	exists bool
	mode   os.FileMode
}

// Indirections for the copy/verify/hash steps of a transfer; the transfer
// routines call through these variables (presumably so they can be swapped
// out, e.g. in tests).
var (
	sftpCopyWithProgressFunc = copyWithProgressContext
	sftpVerifyRemoteSizeFunc = verifyRemoteSize
	sftpVerifyLocalSizeFunc  = verifyLocalSize
	sftpLocalFileSHA256Func  = localFileSHA256
	sftpRemoteFileSHA256Func = remoteFileSHA256
)

// DefaultSFTPTransferOptions returns the default options: two retries with a
// 250ms initial backoff, atomic upload and download, size verification on and
// checksum verification off.
func DefaultSFTPTransferOptions() SFTPTransferOptions {
	return SFTPTransferOptions{
		BufferSize:          defaultTransferBufferSize,
		RetryCount:          SFTPInt(defaultSFTPRetryCount),
		RetryInitialBackoff: SFTPDuration(defaultSFTPRetryInitialBackoff),
		AtomicUpload:        SFTPBool(true),
		AtomicDownload:      SFTPBool(true),
		VerifySize:          SFTPBool(true),
		VerifyChecksum:      SFTPBool(false),
		TempSuffix:          defaultSFTPTempSuffix,
	}
}

// SFTPBool returns a pointer to value, for use in SFTPTransferOptions.
func SFTPBool(value bool) *bool {
	return &value
}

// SFTPInt returns a pointer to value, for use in SFTPTransferOptions.
func SFTPInt(value int) *int {
	return &value
}

// SFTPDuration returns a pointer to value, for use in SFTPTransferOptions.
func SFTPDuration(value time.Duration) *time.Duration {
	return &value
}

// FS returns a file system view backed by this connection.
func (star *StarSSH) FS() *SFTPFileSystem {
	return &SFTPFileSystem{star: star}
}

// Stat is StatContext with context.Background().
func (star *StarSSH) Stat(remotePath string) (os.FileInfo, error) {
	return star.StatContext(context.Background(), remotePath)
}

// StatContext returns file information for remotePath.
func (star *StarSSH) StatContext(ctx context.Context, remotePath string) (os.FileInfo, error) {
	return star.FS().Stat(ctx, remotePath)
}

// ReadDir is ReadDirContext with context.Background().
func (star *StarSSH) ReadDir(remotePath string) ([]os.FileInfo, error) {
	return star.ReadDirContext(context.Background(), remotePath)
}

// ReadDirContext lists the entries of the remote directory remotePath.
func (star *StarSSH) ReadDirContext(ctx context.Context, remotePath string) ([]os.FileInfo, error) {
	return star.FS().ReadDir(ctx, remotePath)
}

// ReadFile is ReadFileContext with context.Background() and default options.
func (star *StarSSH) ReadFile(remotePath string) ([]byte, error) {
	return star.ReadFileContext(context.Background(), remotePath, nil)
}

// ReadFileContext downloads remotePath into memory.
func (star *StarSSH) ReadFileContext(ctx context.Context, remotePath string, options *SFTPTransferOptions) ([]byte, error) {
	return star.FS().ReadFile(ctx, remotePath, options)
}

// WriteFile is WriteFileContext with context.Background() and default options.
func (star *StarSSH) WriteFile(remotePath string, data []byte) error {
	return star.WriteFileContext(context.Background(), remotePath, data, nil)
}

// WriteFileContext uploads data to remotePath.
func (star *StarSSH) WriteFileContext(ctx context.Context, remotePath string, data []byte, options *SFTPTransferOptions) error {
	return star.FS().WriteFile(ctx, remotePath, data, options)
}

// MkdirAll is MkdirAllContext with context.Background().
func (star *StarSSH) MkdirAll(remotePath string) error {
	return star.MkdirAllContext(context.Background(), remotePath)
}

// MkdirAllContext creates remotePath and any missing parents.
func (star *StarSSH) MkdirAllContext(ctx context.Context, remotePath string) error {
	return star.FS().MkdirAll(ctx, remotePath)
}

// Remove is RemoveContext with context.Background().
func (star *StarSSH) Remove(remotePath string) error {
	return star.RemoveContext(context.Background(), remotePath)
}

// RemoveContext removes remotePath.
func (star *StarSSH) RemoveContext(ctx context.Context, remotePath string) error {
	return star.FS().Remove(ctx, remotePath)
}

// RemoveAll is RemoveAllContext with context.Background().
func (star *StarSSH) RemoveAll(remotePath string) error {
	return star.RemoveAllContext(context.Background(), remotePath)
}

// RemoveAllContext removes remotePath via the package's removeRemoteAll helper.
func (star *StarSSH) RemoveAllContext(ctx context.Context, remotePath string) error {
	return star.FS().RemoveAll(ctx, remotePath)
}

// Rename is RenameContext with context.Background().
func (star *StarSSH) Rename(oldPath string, newPath string) error {
	return star.RenameContext(context.Background(), oldPath, newPath)
}

// RenameContext renames oldPath to newPath on the remote side.
func (star *StarSSH) RenameContext(ctx context.Context, oldPath string, newPath string) error {
	return star.FS().Rename(ctx, oldPath, newPath)
}

// Stat returns file information for remotePath. The operation is retried with
// the default retry settings.
func (fs *SFTPFileSystem) Stat(ctx context.Context, remotePath string) (os.FileInfo, error) {
	if fs == nil || fs.star == nil {
		return nil, errors.New("sftp filesystem is nil")
	}
	if err := validateRemotePath(remotePath); err != nil {
		return nil, err
	}
	var info os.FileInfo
	err := fs.star.runSFTPClientOperation(ctx, "sftp_stat", remotePath, func(client *sftp.Client) error {
		out, err := client.Stat(remotePath)
		if err != nil {
			return err
		}
		info = out
		return nil
	})
	if err != nil {
		return nil, err
	}
	return info, nil
}

// ReadDir lists the entries of remotePath. The operation is retried with the
// default retry settings.
func (fs *SFTPFileSystem) ReadDir(ctx context.Context, remotePath string) ([]os.FileInfo, error) {
	if fs == nil || fs.star == nil {
		return nil, errors.New("sftp filesystem is nil")
	}
	if err := validateRemotePath(remotePath); err != nil {
		return nil, err
	}
	var entries []os.FileInfo
	err := fs.star.runSFTPClientOperation(ctx, "sftp_readdir", remotePath, func(client *sftp.Client) error {
		out, err := client.ReadDir(remotePath)
		if err != nil {
			return err
		}
		entries = out
		return nil
	})
	if err != nil {
		return nil, err
	}
	return entries, nil
}

// ReadFile downloads remotePath into memory.
func (fs *SFTPFileSystem) ReadFile(ctx context.Context, remotePath string, options *SFTPTransferOptions) ([]byte, error) {
	if fs == nil || fs.star == nil {
		return nil, errors.New("sftp filesystem is nil")
	}
	return fs.star.SftpTransferInByteContext(ctx, remotePath, options)
}

// WriteFile uploads data to remotePath.
func (fs *SFTPFileSystem) WriteFile(ctx context.Context, remotePath string, data []byte, options *SFTPTransferOptions) error {
	if fs == nil || fs.star == nil {
		return errors.New("sftp filesystem is nil")
	}
	return fs.star.SftpTransferOutByteContext(ctx, data, remotePath, options)
}

// MkdirAll creates remotePath and any missing parents. The operation is
// retried with the default retry settings.
func (fs *SFTPFileSystem) MkdirAll(ctx context.Context, remotePath string) error {
	if fs == nil || fs.star == nil {
		return errors.New("sftp filesystem is nil")
	}
	if err := validateRemotePath(remotePath); err != nil {
		return err
	}
	return fs.star.runSFTPClientOperation(ctx, "sftp_mkdir_all", remotePath, func(client *sftp.Client) error {
		return client.MkdirAll(remotePath)
	})
}

// Remove removes remotePath. It is executed once, without retries.
func (fs *SFTPFileSystem) Remove(ctx context.Context, remotePath string) error {
	if fs == nil || fs.star == nil {
		return errors.New("sftp filesystem is nil")
	}
	if err := validateRemotePath(remotePath); err != nil {
		return err
	}
	return fs.star.runSFTPClientOperationNoRetry(ctx, func(client *sftp.Client) error {
		return removeRemotePath(client, remotePath)
	})
}

// RemoveAll removes remotePath via removeRemoteAll. The operation is retried
// with the default retry settings.
func (fs *SFTPFileSystem) RemoveAll(ctx context.Context, remotePath string) error {
	if fs == nil || fs.star == nil {
		return errors.New("sftp filesystem is nil")
	}
	if err := validateRemotePath(remotePath); err != nil {
		return err
	}
	return fs.star.runSFTPClientOperation(ctx, "sftp_remove_all", remotePath, func(client *sftp.Client) error {
		return removeRemoteAll(ctx, client, remotePath)
	})
}

// Rename renames oldPath to newPath. It is executed once, without retries.
func (fs *SFTPFileSystem) Rename(ctx context.Context, oldPath string, newPath string) error {
	if fs == nil || fs.star == nil {
		return errors.New("sftp filesystem is nil")
	}
	if err := validateRemotePath(oldPath); err != nil {
		return err
	}
	if err := validateRemotePath(newPath); err != nil {
		return err
	}
	return fs.star.runSFTPClientOperationNoRetry(ctx, func(client *sftp.Client) error {
		return renameRemoteAtomic(client, oldPath, newPath)
	})
}

// normalizeSFTPTransferOptions overlays the explicitly set fields of options
// on top of DefaultSFTPTransferOptions and resolves every pointer field to a
// concrete value. A nil options yields the plain defaults.
func normalizeSFTPTransferOptions(options *SFTPTransferOptions) resolvedSFTPTransferOptions {
	opts := DefaultSFTPTransferOptions()
	if options != nil {
		if options.BufferSize > 0 {
			opts.BufferSize = options.BufferSize
		}
		if options.Progress != nil {
			opts.Progress = options.Progress
		}
		if options.RetryCount != nil {
			opts.RetryCount = options.RetryCount
		}
		if options.RetryInitialBackoff != nil {
			opts.RetryInitialBackoff = options.RetryInitialBackoff
		}
		if options.AtomicUpload != nil {
			opts.AtomicUpload = options.AtomicUpload
		}
		if options.AtomicDownload != nil {
			opts.AtomicDownload = options.AtomicDownload
		}
		if options.VerifySize != nil {
			opts.VerifySize = options.VerifySize
		}
		if options.VerifyChecksum != nil {
			opts.VerifyChecksum = options.VerifyChecksum
		}
		if strings.TrimSpace(options.TempSuffix) != "" {
			opts.TempSuffix = options.TempSuffix
		}
	}
	return resolvedSFTPTransferOptions{
		BufferSize:          opts.BufferSize,
		Progress:            opts.Progress,
		RetryCount:          normalizeSFTPRetryCount(derefSFTPInt(opts.RetryCount, defaultSFTPRetryCount)),
		RetryInitialBackoff: derefSFTPDuration(opts.RetryInitialBackoff, defaultSFTPRetryInitialBackoff),
		AtomicUpload:        derefSFTPBool(opts.AtomicUpload, true),
		AtomicDownload:      derefSFTPBool(opts.AtomicDownload, true),
		VerifySize:          derefSFTPBool(opts.VerifySize, true),
		VerifyChecksum:      derefSFTPBool(opts.VerifyChecksum, false),
		TempSuffix:          normalizeSFTPTempSuffix(opts.TempSuffix),
	}
}

// derefSFTPBool returns *value, or fallback when value is nil.
func derefSFTPBool(value *bool, fallback bool) bool {
	if value == nil {
		return fallback
	}
	return *value
}

// derefSFTPInt returns *value, or fallback when value is nil.
func derefSFTPInt(value *int, fallback int) int {
	if value == nil {
		return fallback
	}
	return *value
}

// derefSFTPDuration returns *value, or fallback when value is nil.
func derefSFTPDuration(value *time.Duration, fallback time.Duration) time.Duration {
	if value == nil {
		return fallback
	}
	return *value
}

// normalizeSFTPTempSuffix trims value and substitutes the default suffix when
// nothing is left.
func normalizeSFTPTempSuffix(value string) string {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return defaultSFTPTempSuffix
	}
	return trimmed
}

// normalizeSFTPRetryCount clamps negative retry counts to 0.
func normalizeSFTPRetryCount(value int) int {
	if value < 0 {
		return 0
	}
	return value
}

// runSFTPClientOperation runs fn against a fresh SFTP client, retrying with
// the default retry settings. Failures are reported as *SFTPTransferError
// tagged with operation and remotePath.
func (star *StarSSH) runSFTPClientOperation(ctx context.Context, operation string, remotePath string, fn func(*sftp.Client) error) error {
	if err := ensureContext(ctx); err != nil {
		return err
	}
	opts := normalizeSFTPTransferOptions(nil)
	return executeSFTPRetry(ctx, operation, "", remotePath, opts, func(attempt int) error {
		return star.withIsolatedSFTPClient(ctx, fn)
	})
}

// runSFTPClientOperationNoRetry runs fn once against a fresh SFTP client.
func (star *StarSSH) runSFTPClientOperationNoRetry(ctx context.Context, fn func(*sftp.Client) error) error {
	if err := ensureContext(ctx); err != nil {
		return err
	}
	return star.withIsolatedSFTPClient(ctx, fn)
}

// CreateSftpClient opens a new SFTP client on the underlying SSH connection.
// The caller owns the returned client and must close it.
func (star *StarSSH) CreateSftpClient() (*sftp.Client, error) {
	client, err := star.requireSSHClient()
	if err != nil {
		return nil, err
	}
	return sftp.NewClient(client)
}

// withIsolatedSFTPClient creates a dedicated SFTP client, passes it to fn and
// closes it afterwards.
func (star *StarSSH) withIsolatedSFTPClient(ctx context.Context, fn func(*sftp.Client) error) error {
	if err := ensureContext(ctx); err != nil {
		return err
	}
	client, err := star.CreateSftpClient()
	if err != nil {
		return err
	}
	defer client.Close()
	return fn(client)
}

// getReusableSFTPClient returns the cached SFTP client, creating and caching
// one under sftpMu when none exists yet.
func (star *StarSSH) getReusableSFTPClient() (*sftp.Client, error) {
	if star == nil {
		return nil, errors.New("ssh client is nil")
	}
	star.sftpMu.Lock()
	defer star.sftpMu.Unlock()
	if star.sftpClient != nil {
		return star.sftpClient, nil
	}
	sshClient, err := star.requireSSHClient()
	if err != nil {
		return nil, err
	}
	client, err := sftp.NewClient(sshClient)
	if err != nil {
		return nil, err
	}
	star.sftpClient = client
	return client, nil
}

// resetReusableSFTPClient closes and forgets the cached SFTP client, ignoring
// the close error.
func (star *StarSSH) resetReusableSFTPClient() {
	if star == nil {
		return
	}
	star.sftpMu.Lock()
	defer star.sftpMu.Unlock()
	if star.sftpClient != nil {
		_ = star.sftpClient.Close() // best-effort: the client is being discarded
		star.sftpClient = nil
	}
}

// closeReusableSFTPClient closes and forgets the cached SFTP client and
// returns the close error.
func (star *StarSSH) closeReusableSFTPClient() error {
	if star == nil {
		return nil
	}
	star.sftpMu.Lock()
	defer star.sftpMu.Unlock()
	if star.sftpClient == nil {
		return nil
	}
	err := star.sftpClient.Close()
	star.sftpClient = nil
	return err
}

// withReusableSFTPClient passes the cached SFTP client to fn. The client is
// left open afterwards.
func (star *StarSSH) withReusableSFTPClient(ctx context.Context, fn func(*sftp.Client) error) error {
	if err := ensureContext(ctx); err != nil {
		return err
	}
	client, err := star.getReusableSFTPClient()
	if err != nil {
		return err
	}
	return fn(client)
}

// runSFTPWithRetry runs fn under executeSFTPRetry, giving every attempt its
// own freshly created SFTP client.
func (star *StarSSH) runSFTPWithRetry(
	ctx context.Context,
	operation string,
	localPath string,
	remotePath string,
	opts resolvedSFTPTransferOptions,
	fn func(context.Context, *sftp.Client, resolvedSFTPTransferOptions) error,
) error {
	return executeSFTPRetry(ctx, operation, localPath, remotePath, opts, func(attempt int) error {
		return star.withIsolatedSFTPClient(ctx, func(client *sftp.Client) error {
			return fn(ctx, client, opts)
		})
	})
}

// SftpTransferOut uploads a local file with default options.
func (star *StarSSH) SftpTransferOut(localFilePath, remotePath string) error {
	return star.SftpTransferOutContext(context.Background(), localFilePath, remotePath, nil)
}

// SftpTransferOutContext uploads localFilePath to remotePath.
func (star *StarSSH) SftpTransferOutContext(ctx context.Context, localFilePath, remotePath string, options *SFTPTransferOptions) error {
	opts := normalizeSFTPTransferOptions(options)
	return star.runSFTPWithRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
		return transferOutContext(ctx, client, localFilePath, remotePath, opts)
	})
}

// SftpTransferOut uploads a local file through sftpClient with default options.
func SftpTransferOut(localFilePath, remotePath string, sftpClient *sftp.Client) error {
	return SftpTransferOutWithContext(context.Background(), localFilePath, remotePath, sftpClient, nil)
}

// SftpTransferOutWithContext uploads localFilePath to remotePath through the
// caller-supplied sftpClient, which is reused across retries.
func SftpTransferOutWithContext(ctx context.Context, localFilePath, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error {
	opts := normalizeSFTPTransferOptions(options)
	return executeSFTPRetry(ctx, "sftp_put_file", localFilePath, remotePath, opts, func(attempt int) error {
		return transferOutContext(ctx, sftpClient, localFilePath, remotePath, opts)
	})
}

// SftpTransferOutByte uploads in-memory data with default options.
func (star *StarSSH) SftpTransferOutByte(localData []byte, remotePath string) error {
	return star.SftpTransferOutByteContext(context.Background(), localData, remotePath, nil)
}

// SftpTransferOutByteContext uploads localData to remotePath.
func (star *StarSSH) SftpTransferOutByteContext(ctx context.Context, localData []byte, remotePath string, options *SFTPTransferOptions) error {
	opts := normalizeSFTPTransferOptions(options)
	return star.runSFTPWithRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
		return transferOutByteContext(ctx, client, localData, remotePath, opts)
	})
}

// SftpTransferOutByte uploads in-memory data through sftpClient with default
// options.
func SftpTransferOutByte(localData []byte, remotePath string, sftpClient *sftp.Client) error {
	return SftpTransferOutByteWithContext(context.Background(), localData, remotePath, sftpClient, nil)
}

// SftpTransferOutByteWithContext uploads localData to remotePath through the
// caller-supplied sftpClient, which is reused across retries.
func SftpTransferOutByteWithContext(ctx context.Context, localData []byte, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) error {
	opts := normalizeSFTPTransferOptions(options)
	return executeSFTPRetry(ctx, "sftp_put_bytes", "", remotePath, opts, func(attempt int) error {
		return transferOutByteContext(ctx, sftpClient, localData, remotePath, opts)
	})
}

// SftpTransferOutFunc uploads a local file using buffer size bufcap and
// reporting progress to rtefunc.
func (star *StarSSH) SftpTransferOutFunc(localFilePath, remotePath string, bufcap int, rtefunc func(float64)) error {
	return star.SftpTransferOutContext(context.Background(), localFilePath, remotePath, &SFTPTransferOptions{
		BufferSize: bufcap,
		Progress:   rtefunc,
	})
}

// SftpTransferOutFunc uploads a local file through sftpClient using buffer
// size bufcap and reporting progress to rtefunc.
func SftpTransferOutFunc(localFilePath, remotePath string, bufcap int, rtefunc func(float64), sftpClient *sftp.Client) error {
	return SftpTransferOutWithContext(context.Background(), localFilePath, remotePath, sftpClient, &SFTPTransferOptions{
		BufferSize: bufcap,
		Progress:   rtefunc,
	})
}

// SftpTransferInByte downloads a remote file into memory with default options.
func (star *StarSSH) SftpTransferInByte(remotePath string) ([]byte, error) {
	return star.SftpTransferInByteContext(context.Background(), remotePath, nil)
}

// SftpTransferInByteContext downloads remotePath into memory.
func (star *StarSSH) SftpTransferInByteContext(ctx context.Context, remotePath string, options *SFTPTransferOptions) ([]byte, error) {
	opts := normalizeSFTPTransferOptions(options)
	var data []byte
	err := star.runSFTPWithRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
		out, runErr := transferInByteContext(ctx, client, remotePath, opts)
		if runErr != nil {
			return runErr
		}
		data = out
		return nil
	})
	if err != nil {
		return nil, err
	}
	return data, nil
}

// SftpTransferInByte downloads a remote file into memory through sftpClient
// with default options.
func SftpTransferInByte(remotePath string, sftpClient *sftp.Client) ([]byte, error) {
	return SftpTransferInByteWithContext(context.Background(), remotePath, sftpClient, nil)
}

// SftpTransferInByteWithContext downloads remotePath into memory through the
// caller-supplied sftpClient, which is reused across retries.
func SftpTransferInByteWithContext(ctx context.Context, remotePath string, sftpClient *sftp.Client, options *SFTPTransferOptions) ([]byte, error) {
	opts := normalizeSFTPTransferOptions(options)
	var data []byte
	err := executeSFTPRetry(ctx, "sftp_get_bytes", "", remotePath, opts, func(attempt int) error {
		out, runErr := transferInByteContext(ctx, sftpClient, remotePath, opts)
		if runErr != nil {
			return runErr
		}
		data = out
		return nil
	})
	if err != nil {
		return nil, err
	}
	return data, nil
}

// SftpTransferIn downloads remote file src to local path dst with default
// options.
func (star *StarSSH) SftpTransferIn(src, dst string) error {
	return star.SftpTransferInContext(context.Background(), src, dst, nil)
}

// SftpTransferInContext downloads remote file src to local path dst.
func (star *StarSSH) SftpTransferInContext(ctx context.Context, src, dst string, options *SFTPTransferOptions) error {
	opts := normalizeSFTPTransferOptions(options)
	return star.runSFTPWithRetry(ctx, "sftp_get_file", dst, src, opts, func(ctx context.Context, client *sftp.Client, opts resolvedSFTPTransferOptions) error {
		return transferInContext(ctx, client, src, dst, opts)
	})
}

// SftpTransferIn downloads remote file src to local path dst through
// sftpClient with default options.
func SftpTransferIn(src, dst string, sftpClient *sftp.Client) error {
	return SftpTransferInWithContext(context.Background(), src, dst, sftpClient, nil)
}

// SftpTransferInWithContext downloads remote file src to local path dst
// through the caller-supplied sftpClient, which is reused across retries.
func SftpTransferInWithContext(ctx context.Context, src, dst string, sftpClient *sftp.Client, options *SFTPTransferOptions) error {
	opts := normalizeSFTPTransferOptions(options)
	return executeSFTPRetry(ctx, "sftp_get_file", dst, src, opts, func(attempt int) error {
		return transferInContext(ctx, sftpClient, src, dst, opts)
	})
}

// SftpTransferInFunc downloads src to dst using buffer size bufcap and
// reporting progress to rtefunc.
func (star *StarSSH) SftpTransferInFunc(src, dst string, bufcap int, rtefunc func(float64)) error {
	return star.SftpTransferInContext(context.Background(), src, dst, &SFTPTransferOptions{
		BufferSize: bufcap,
		Progress:   rtefunc,
	})
}

// SftpTransferInFunc downloads src to dst through sftpClient using buffer size
// bufcap and reporting progress to rtefunc.
func SftpTransferInFunc(src, dst string, bufcap int, rtefunc func(float64), sftpClient *sftp.Client) error {
	return SftpTransferInWithContext(context.Background(), src, dst, sftpClient, &SFTPTransferOptions{
		BufferSize: bufcap,
		Progress:   rtefunc,
	})
}

// transferOutContext performs one upload attempt of localFilePath to
// remotePath. With AtomicUpload the data is written to a temporary remote
// path, verified there, given the mode chosen by determineAtomicReplaceMode
// and finally renamed over remotePath; the temporary file is removed on any
// failure.
func transferOutContext(ctx context.Context, sftpClient *sftp.Client, localFilePath string, remotePath string, opts resolvedSFTPTransferOptions) error {
	if err := ensureContext(ctx); err != nil {
		return err
	}
	if err := validateSFTPClient(sftpClient); err != nil {
		return err
	}
	if strings.TrimSpace(localFilePath) == "" || strings.TrimSpace(remotePath) == "" {
		return errors.New("local path and remote path must not be empty")
	}

	srcFile, err := os.Open(localFilePath)
	if err != nil {
		return err
	}
	defer srcFile.Close()
	stat, err := srcFile.Stat()
	if err != nil {
		return err
	}

	tempPath, targetPath := buildUploadTargetPath(remotePath, opts)
	targetInfo := atomicReplaceTarget{}
	if tempPath != "" {
		out, err := inspectRemoteAtomicTarget(sftpClient, remotePath)
		if err != nil {
			return err
		}
		targetInfo = out
		defer func() {
			// tempPath is cleared once the rename succeeded; only a leftover
			// temporary file needs (best-effort) removal.
			if tempPath != "" {
				_ = sftpClient.Remove(tempPath)
			}
		}()
	}

	dstFile, err := sftpClient.Create(targetPath)
	if err != nil {
		return err
	}
	if _, err := sftpCopyWithProgressFunc(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil {
		_ = dstFile.Close() // the copy error is the one worth reporting
		return err
	}
	if err := dstFile.Close(); err != nil {
		return err
	}

	// Verify whatever was actually written: the temp file in atomic mode.
	verifyPath := remotePath
	if tempPath != "" {
		verifyPath = tempPath
	}
	if opts.VerifySize {
		if err := sftpVerifyRemoteSizeFunc(sftpClient, verifyPath, stat.Size()); err != nil {
			return err
		}
	}
	if opts.VerifyChecksum {
		localHash, err := sftpLocalFileSHA256Func(ctx, localFilePath)
		if err != nil {
			return err
		}
		remoteHash, err := sftpRemoteFileSHA256Func(ctx, sftpClient, verifyPath)
		if err != nil {
			return err
		}
		if localHash != remoteHash {
			return fmt.Errorf("checksum mismatch after upload: local=%s remote=%s", localHash, remoteHash)
		}
	}

	if tempPath != "" {
		mode := stat.Mode()
		if desiredMode, ok := determineAtomicReplaceMode(targetInfo, &mode); ok {
			if err := applyRemoteFileMode(sftpClient, tempPath, desiredMode); err != nil {
				return err
			}
		}
		if err := renameRemoteAtomic(sftpClient, tempPath, remotePath); err != nil {
			return err
		}
		tempPath = "" // disarm the deferred cleanup
	}
	return nil
}

// transferOutByteContext performs one upload attempt of localData to
// remotePath. It mirrors transferOutContext, except that no source file mode
// is available (determineAtomicReplaceMode receives nil).
func transferOutByteContext(ctx context.Context, sftpClient *sftp.Client, localData []byte, remotePath string, opts resolvedSFTPTransferOptions) error {
	if err := ensureContext(ctx); err != nil {
		return err
	}
	if err := validateSFTPClient(sftpClient); err != nil {
		return err
	}
	if strings.TrimSpace(remotePath) == "" {
		return errors.New("remote path must not be empty")
	}

	tempPath, targetPath := buildUploadTargetPath(remotePath, opts)
	targetInfo := atomicReplaceTarget{}
	if tempPath != "" {
		out, err := inspectRemoteAtomicTarget(sftpClient, remotePath)
		if err != nil {
			return err
		}
		targetInfo = out
		defer func() {
			// tempPath is cleared once the rename succeeded; only a leftover
			// temporary file needs (best-effort) removal.
			if tempPath != "" {
				_ = sftpClient.Remove(tempPath)
			}
		}()
	}

	dstFile, err := sftpClient.Create(targetPath)
	if err != nil {
		return err
	}
	reader := bytes.NewReader(localData)
	if _, err := sftpCopyWithProgressFunc(ctx, dstFile, reader, opts.BufferSize, int64(len(localData)), opts.Progress); err != nil {
		_ = dstFile.Close() // the copy error is the one worth reporting
		return err
	}
	if err := dstFile.Close(); err != nil {
		return err
	}

	// Verify whatever was actually written: the temp file in atomic mode.
	verifyPath := remotePath
	if tempPath != "" {
		verifyPath = tempPath
	}
	if opts.VerifySize {
		if err := sftpVerifyRemoteSizeFunc(sftpClient, verifyPath, int64(len(localData))); err != nil {
			return err
		}
	}
	if opts.VerifyChecksum {
		localHash := checksumBytes(localData)
		remoteHash, err := sftpRemoteFileSHA256Func(ctx, sftpClient, verifyPath)
		if err != nil {
			return err
		}
		if localHash != remoteHash {
			return fmt.Errorf("checksum mismatch after upload: local=%s remote=%s", localHash, remoteHash)
		}
	}

	if tempPath != "" {
		if desiredMode, ok := determineAtomicReplaceMode(targetInfo, nil); ok {
			if err := applyRemoteFileMode(sftpClient, tempPath, desiredMode); err != nil {
				return err
			}
		}
		if err := renameRemoteAtomic(sftpClient, tempPath, remotePath); err != nil {
			return err
		}
		tempPath = "" // disarm the deferred cleanup
	}
	return nil
}

// transferInContext performs one download attempt of remote file src to local
// path dst. With AtomicDownload the data is written to a temporary file next
// to dst, verified there, given the mode chosen by determineAtomicReplaceMode
// and finally renamed over dst; the temporary file is removed on any failure.
func transferInContext(ctx context.Context, sftpClient *sftp.Client, src, dst string, opts resolvedSFTPTransferOptions) error {
	if err := ensureContext(ctx); err != nil {
		return err
	}
	if err := validateSFTPClient(sftpClient); err != nil {
		return err
	}
	if strings.TrimSpace(src) == "" || strings.TrimSpace(dst) == "" {
		return errors.New("source path and destination path must not be empty")
	}

	srcFile, err := sftpClient.Open(src)
	if err != nil {
		return err
	}
	defer srcFile.Close()
	stat, err := srcFile.Stat()
	if err != nil {
		return err
	}

	targetInfo := atomicReplaceTarget{}
	if opts.AtomicDownload {
		out, err := inspectLocalAtomicTarget(dst)
		if err != nil {
			return err
		}
		targetInfo = out
	}

	dstFile, tempPath, err := createLocalTransferFile(dst, opts)
	if err != nil {
		return err
	}
	if tempPath != "" {
		defer func() {
			// tempPath is cleared once the rename succeeded; only a leftover
			// temporary file needs (best-effort) removal.
			if tempPath != "" {
				_ = os.Remove(tempPath)
			}
		}()
	}

	if _, err := sftpCopyWithProgressFunc(ctx, dstFile, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil {
		_ = dstFile.Close() // the copy error is the one worth reporting
		return err
	}
	if err := dstFile.Close(); err != nil {
		return err
	}

	// Verify whatever was actually written: the temp file in atomic mode.
	verifyPath := dst
	if tempPath != "" {
		verifyPath = tempPath
	}
	if opts.VerifySize {
		if err := sftpVerifyLocalSizeFunc(verifyPath, stat.Size()); err != nil {
			return err
		}
	}
	if opts.VerifyChecksum {
		localHash, err := sftpLocalFileSHA256Func(ctx, verifyPath)
		if err != nil {
			return err
		}
		remoteHash, err := sftpRemoteFileSHA256Func(ctx, sftpClient, src)
		if err != nil {
			return err
		}
		if localHash != remoteHash {
			return fmt.Errorf("checksum mismatch after download: local=%s remote=%s", localHash, remoteHash)
		}
	}

	if tempPath != "" {
		mode := stat.Mode()
		if desiredMode, ok := determineAtomicReplaceMode(targetInfo, &mode); ok {
			if err := applyLocalFileMode(tempPath, desiredMode); err != nil {
				return err
			}
		}
		if err := renameLocalAtomic(tempPath, dst); err != nil {
			return err
		}
		tempPath = "" // disarm the deferred cleanup
	}
	return nil
}

// transferInByteContext performs one download attempt of remotePath into
// memory, optionally verifying the size and SHA-256 of the received data
// against the remote file.
func transferInByteContext(ctx context.Context, sftpClient *sftp.Client, remotePath string, opts resolvedSFTPTransferOptions) ([]byte, error) {
	if err := ensureContext(ctx); err != nil {
		return nil, err
	}
	if err := validateSFTPClient(sftpClient); err != nil {
		return nil, err
	}
	if strings.TrimSpace(remotePath) == "" {
		return nil, errors.New("remote path must not be empty")
	}

	srcFile, err := sftpClient.Open(remotePath)
	if err != nil {
		return nil, err
	}
	defer srcFile.Close()
	stat, err := srcFile.Stat()
	if err != nil {
		return nil, err
	}

	var out bytes.Buffer
	if _, err := sftpCopyWithProgressFunc(ctx, &out, srcFile, opts.BufferSize, stat.Size(), opts.Progress); err != nil {
		return nil, err
	}
	data := out.Bytes()
	if opts.VerifySize && int64(len(data)) != stat.Size() {
		return nil, fmt.Errorf("download size mismatch: local=%d remote=%d", len(data), stat.Size())
	}
	if opts.VerifyChecksum {
		localHash := checksumBytes(data)
		remoteHash, err := sftpRemoteFileSHA256Func(ctx, sftpClient, remotePath)
		if err != nil {
			return nil, err
		}
		if localHash != remoteHash {
			return nil, fmt.Errorf("checksum mismatch after download: local=%s remote=%s", localHash, remoteHash)
		}
	}
	return data, nil
}

// executeSFTPRetry calls fn with a zero-based attempt number until it
// succeeds, fails with a non-retryable error, the retry budget
// (opts.RetryCount retries after the first attempt) is exhausted, or ctx is
// done. Between attempts it waits for the current backoff, which starts at
// opts.RetryInitialBackoff and doubles after each wait while it is below four
// seconds. Every returned error is a *SFTPTransferError (or already wraps one).
//
// A nil ctx is treated as context.Background(), and fn is always invoked at
// least once even if opts.RetryCount is negative.
func executeSFTPRetry(
	ctx context.Context,
	operation string,
	localPath string,
	remotePath string,
	opts resolvedSFTPTransferOptions,
	fn func(attempt int) error,
) error {
	if ctx == nil {
		// ensureContext tolerates nil, but the select below would panic on it.
		ctx = context.Background()
	}
	backoff := opts.RetryInitialBackoff
	if backoff <= 0 {
		backoff = defaultSFTPRetryInitialBackoff
	}
	retries := normalizeSFTPRetryCount(opts.RetryCount)

	for attempt := 0; ; attempt++ {
		if err := ensureContext(ctx); err != nil {
			return wrapSFTPTransferError(operation, localPath, remotePath, attempt, SFTPErrorPermanent, err)
		}
		err := fn(attempt)
		if err == nil {
			return nil
		}
		category := classifySFTPError(err)
		wrappedErr := wrapSFTPTransferError(operation, localPath, remotePath, attempt, category, err)
		if category != SFTPErrorRetryable || attempt >= retries {
			return wrappedErr
		}

		timer := time.NewTimer(backoff)
		select {
		case <-ctx.Done():
			// The timer is discarded, so there is no need to drain its channel.
			timer.Stop()
			return wrapSFTPTransferError(operation, localPath, remotePath, attempt, SFTPErrorPermanent, ctx.Err())
		case <-timer.C:
		}
		if backoff < 4*time.Second {
			backoff *= 2
		}
	}
}

// wrapSFTPTransferError wraps err in a *SFTPTransferError carrying the given
// metadata. A nil err yields nil, and an err that already contains a
// *SFTPTransferError is returned unchanged.
func wrapSFTPTransferError(operation, localPath, remotePath string, attempt int, category SFTPErrorCategory, err error) error {
	if err == nil {
		return nil
	}
	var transferErr *SFTPTransferError
	if errors.As(err, &transferErr) {
		return err
	}
	return &SFTPTransferError{
		Operation:  operation,
		LocalPath:  localPath,
		RemotePath: remotePath,
		Attempt:    attempt,
		Category:   category,
		Err:        err,
	}
}

// classifySFTPError maps err to a retry category.
func classifySFTPError(err error) SFTPErrorCategory {
	if isRetryableTransferError(err) {
		return SFTPErrorRetryable
	}
	return SFTPErrorPermanent
}

// isRetryableTransferError reports whether err looks transient. Context
// cancellation/deadline and missing files are never retryable; net.Error
// timeouts/temporary errors are; otherwise the decision is made by matching
// well-known substrings of the lower-cased error text.
func isRetryableTransferError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return false
	}
	if errors.Is(err, os.ErrNotExist) {
		return false
	}
	var netErr net.Error
	if errors.As(err, &netErr) {
		// Temporary is deprecated in the standard library but kept here so the
		// classification of existing errors does not change.
		if netErr.Timeout() || netErr.Temporary() {
			return true
		}
	}
	errText := strings.ToLower(err.Error())
	if strings.Contains(errText, "permission denied") || strings.Contains(errText, "no such file") {
		return false
	}
	retryableHints := []string{
		"connection reset",
		"broken pipe",
		"connection aborted",
		"connection refused",
		"connection lost",
		"timeout",
		"timed out",
		"unexpected eof",
		"use of closed network connection",
		"transport is closing",
	}
	for _, hint := range retryableHints {
		if strings.Contains(errText, hint) {
			return true
		}
	}
	return false
}

// validateSFTPClient rejects a nil client.
func validateSFTPClient(client *sftp.Client) error {
	if client == nil {
		return errors.New("sftp client is nil")
	}
	return nil
}

// ensureContext returns ctx.Err() if ctx is already done. A nil ctx is
// accepted and treated as never done.
func ensureContext(ctx context.Context) error {
	if ctx == nil {
		return nil
	}
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
		return nil
	}
}

// copyWithProgressContext copies src to dst in chunks of the normalised
// bufSize, checking ctx before every read and write. It returns the number of
// bytes written. When progress is non-nil it is called with 0 up front (only
// if total > 0), after every chunk, and with 100 on completion.
func copyWithProgressContext(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) {
	buffer := make([]byte, normalizeBufferSize(bufSize))
	var copied int64
	if progress != nil && total > 0 {
		progress(0)
	}
	for {
		if err := ensureContext(ctx); err != nil {
			return copied, err
		}
		n, readErr := src.Read(buffer)
		if n > 0 {
			if err := ensureContext(ctx); err != nil {
				return copied, err
			}
			written, writeErr := dst.Write(buffer[:n])
			copied += int64(written)
			if writeErr != nil {
				return copied, writeErr
			}
			if written != n {
				return copied, io.ErrShortWrite
			}
			reportProgress(progress, copied, total)
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return copied, readErr
		}
	}
	if progress != nil {
		progress(100)
	}
	return copied, nil
}

// reportProgress forwards copied/total as a percentage capped at 100. With an
// unknown total (<= 0) it reports 100.
func reportProgress(progress func(float64), copied int64, total int64) {
	if progress == nil {
		return
	}
	if total <= 0 {
		progress(100)
		return
	}
	percent := float64(copied) / float64(total) * 100
	if percent > 100 {
		percent = 100
	}
	progress(percent)
}

// buildUploadTargetPath decides where an upload is written. Without
// AtomicUpload it returns ("", remotePath). With AtomicUpload both results are
// the same temporary path of the form remotePath + suffix + "." + nonce.
func buildUploadTargetPath(remotePath string, opts resolvedSFTPTransferOptions) (tempPath string, targetPath string) {
	targetPath = remotePath
	if !opts.AtomicUpload {
		return "", targetPath
	}
	suffix := strings.TrimSpace(opts.TempSuffix)
	if suffix == "" {
		suffix = defaultSFTPTempSuffix
	}
	tempPath = fmt.Sprintf("%s%s.%s", remotePath, suffix, newNonce(4))
	return tempPath, tempPath
}

// createLocalTransferFile opens the local file a download is written to.
// Without AtomicDownload it creates localPath itself and returns an empty temp
// path; otherwise it creates a temporary file in the same directory and
// returns its name as the temp path.
func createLocalTransferFile(localPath string, opts resolvedSFTPTransferOptions) (*os.File, string, error) {
	if !opts.AtomicDownload {
		file, err := os.Create(localPath)
		if err != nil {
			return nil, "", err
		}
		return file, "", nil
	}
	dir := filepath.Dir(localPath)
	pattern := fmt.Sprintf("%s%s.*", filepath.Base(localPath), normalizeSFTPTempSuffix(opts.TempSuffix))
	file, err := os.CreateTemp(dir, pattern)
	if err != nil {
		return nil, "", err
	}
	return file, file.Name(), nil
}

// renameRemoteAtomic moves from onto to on the remote side. It tries a POSIX
// rename first (when the client offers one), then a plain rename. If both fail
// and the target exists, the target is moved to a backup path, the rename is
// retried, and the backup is either removed (success) or moved back (failure).
func renameRemoteAtomic(client *sftp.Client, from, to string) error {
	if from == to {
		return nil
	}
	if _, err := inspectRemoteAtomicTarget(client, to); err != nil {
		return err
	}

	type posixRenamer interface {
		PosixRename(string, string) error
	}
	if renamer, ok := any(client).(posixRenamer); ok {
		if err := renamer.PosixRename(from, to); err == nil {
			return nil
		}
		// Fall through: the server may not support the POSIX rename extension.
	}

	renameErr := client.Rename(from, to)
	if renameErr == nil {
		return nil
	}
	targetInfo, err := inspectRemoteAtomicTarget(client, to)
	if err != nil {
		return errors.Join(renameErr, err)
	}
	if !targetInfo.exists {
		return renameErr
	}

	backupPath := buildRenameBackupPath(to)
	if err := client.Rename(to, backupPath); err != nil {
		return errors.Join(renameErr, fmt.Errorf("backup existing target %q failed: %w", to, err))
	}
	if err := client.Rename(from, to); err != nil {
		restoreErr := client.Rename(backupPath, to)
		if restoreErr != nil {
			return errors.Join(renameErr, err, fmt.Errorf("restore original target %q failed: %w", to, restoreErr))
		}
		return errors.Join(renameErr, err)
	}
	if err := removeRemotePath(client, backupPath); err != nil && !isNotExistError(err) {
		return fmt.Errorf("rename succeeded but backup cleanup %q failed: %w", backupPath, err)
	}
	return nil
}

// renameLocalAtomic moves from onto to on the local file system. If the plain
// rename fails and the target exists, the target is moved to a backup path,
// the rename is retried, and the backup is either removed (success) or moved
// back (failure).
func renameLocalAtomic(from, to string) error {
	if from == to {
		return nil
	}
	if _, err := inspectLocalAtomicTarget(to); err != nil {
		return err
	}

	renameErr := os.Rename(from, to)
	if renameErr == nil {
		return nil
	}
	targetInfo, err := inspectLocalAtomicTarget(to)
	if err != nil {
		return errors.Join(renameErr, err)
	}
	if !targetInfo.exists {
		return renameErr
	}

	backupPath := buildLocalRenameBackupPath(to)
	if err := os.Rename(to, backupPath); err != nil {
		return errors.Join(renameErr, fmt.Errorf("backup existing local target %q failed: %w", to, err))
	}
	if err := os.Rename(from, to); err != nil {
		restoreErr := os.Rename(backupPath, to)
		if restoreErr != nil {
			return errors.Join(renameErr, err, fmt.Errorf("restore original local target %q failed: %w", to, restoreErr))
		}
		return errors.Join(renameErr, err)
	}
	if err := os.Remove(backupPath); err != nil && !errors.Is(err, os.ErrNotExist) {
		return fmt.Errorf("rename succeeded but local backup cleanup %q failed: %w", backupPath, err)
	}
	return nil
}

// buildRenameBackupPath returns a nonce-suffixed remote backup path for
// targetPath.
func buildRenameBackupPath(targetPath string) string {
	return fmt.Sprintf("%s%s.rename-backup.%s", targetPath, defaultSFTPTempSuffix, newNonce(4))
}

// buildLocalRenameBackupPath returns a nonce-suffixed local backup path in the
// same directory as targetPath.
func buildLocalRenameBackupPath(targetPath string) string {
	dir := filepath.Dir(targetPath)
	name := fmt.Sprintf("%s%s.rename-backup.%s", filepath.Base(targetPath), defaultSFTPTempSuffix, newNonce(4))
	return filepath.Join(dir, name)
}

// remotePathExists reports whether remotePath exists, using Lstat so that a
// symlink itself counts as existing.
func remotePathExists(client *sftp.Client, remotePath string) (bool, error) {
	if err := validateSFTPClient(client); err != nil {
		return false, err
	}
	_, err := client.Lstat(remotePath)
	if err == nil {
		return true, nil
	}
	if isNotExistError(err) {
		return false, nil
	}
	return false, err
}

// localPathExists reports whether localPath exists, using Lstat so that a
// symlink itself counts as existing.
func localPathExists(localPath string) (bool, error) {
	_, err := os.Lstat(localPath)
	if err == nil {
		return true, nil
	}
	if errors.Is(err, os.ErrNotExist) {
		return false, nil
	}
	return false, err
}

// verifyRemoteSize fails unless the remote file has exactly expected bytes.
func verifyRemoteSize(client *sftp.Client, remotePath string, expected int64) error {
	info, err := client.Stat(remotePath)
	if err != nil {
		return err
	}
	if info.Size() != expected {
		return fmt.Errorf("remote size mismatch: got %d want %d", info.Size(), expected)
	}
	return nil
}

// verifyLocalSize fails unless the local file has exactly expected bytes.
func verifyLocalSize(localPath string, expected int64) error {
	info, err := os.Stat(localPath)
	if err != nil {
		return err
	}
	if info.Size() != expected {
		return fmt.Errorf("local size mismatch: got %d want %d", info.Size(), expected)
	}
	return nil
}

// localFileSHA256 returns the hex-encoded SHA-256 of the local file at path.
func localFileSHA256(ctx context.Context, path string) (string, error) {
	file, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer file.Close()
	return readerSHA256(ctx, file)
}

// remoteFileSHA256 returns the hex-encoded SHA-256 of the remote file,
// computed by reading it through client.
func remoteFileSHA256(ctx context.Context, client *sftp.Client, remotePath string) (string, error) {
	file, err := client.Open(remotePath)
	if err != nil {
		return "", err
	}
	defer file.Close()
	return readerSHA256(ctx, file)
}

// checksumBytes returns the hex-encoded SHA-256 of data.
func checksumBytes(data []byte) string {
	sum := sha256.Sum256(data)
	return hex.EncodeToString(sum[:])
}

// readerSHA256 returns the hex-encoded SHA-256 of everything read from reader,
// checking ctx before each read.
func readerSHA256(ctx context.Context, reader io.Reader) (string, error) {
	hasher := sha256.New()
	buf := make([]byte, normalizeBufferSize(defaultTransferBufferSize))
	for {
		if err := ensureContext(ctx); err != nil {
			return "", err
		}
		n, err := reader.Read(buf)
		if n > 0 {
			if _, writeErr := hasher.Write(buf[:n]); writeErr != nil {
				return "", writeErr
			}
		}
		if err == io.EOF {
			break
		}
		if err != nil {
			return "", err
		}
	}
	return hex.EncodeToString(hasher.Sum(nil)), nil
}

// isNotExistError reports whether err means "file does not exist", either per
// os.IsNotExist or because its text contains "no such file".
func isNotExistError(err error) bool {
	if err == nil {
		return false
	}
	if os.IsNotExist(err) {
		return true
	}
	return strings.Contains(strings.ToLower(err.Error()), "no such file")
}

// validateRemotePath rejects empty or whitespace-only remote paths.
func validateRemotePath(remotePath string) error {
	if strings.TrimSpace(remotePath) == "" {
		return errors.New("remote path must not be empty")
	}
	return nil
}

func inspectRemoteAtomicTarget(client *sftp.Client, remotePath string) (atomicReplaceTarget, error) {
	if err := validateSFTPClient(client); err != nil {
		return atomicReplaceTarget{}, err
	}
	info, err := client.Lstat(remotePath)
	if err != nil {
		if isNotExistError(err) {
			return atomicReplaceTarget{}, nil
		}
		return atomicReplaceTarget{}, err
	}
	if err := validateAtomicReplaceTarget(remotePath, info); err != nil {
		return atomicReplaceTarget{},
err } return atomicReplaceTarget{ exists: true, mode: info.Mode(), }, nil } func inspectLocalAtomicTarget(localPath string) (atomicReplaceTarget, error) { info, err := os.Lstat(localPath) if err != nil { if errors.Is(err, os.ErrNotExist) { return atomicReplaceTarget{}, nil } return atomicReplaceTarget{}, err } if err := validateAtomicReplaceTarget(localPath, info); err != nil { return atomicReplaceTarget{}, err } return atomicReplaceTarget{ exists: true, mode: info.Mode(), }, nil } func validateAtomicReplaceTarget(targetPath string, info os.FileInfo) error { if info == nil { return nil } mode := info.Mode() switch { case mode&os.ModeSymlink != 0: return fmt.Errorf("atomic overwrite target %q is a symlink", targetPath) case mode.IsRegular(): return nil default: return fmt.Errorf("atomic overwrite target %q is %s", targetPath, describeFileInfoType(info)) } } func describeFileInfoType(info os.FileInfo) string { if info == nil { return "unknown" } mode := info.Mode() switch { case mode&os.ModeSymlink != 0: return "a symlink" case mode.IsDir(): return "a directory" case mode&os.ModeNamedPipe != 0: return "a named pipe" case mode&os.ModeSocket != 0: return "a socket" case mode&os.ModeDevice != 0 && mode&os.ModeCharDevice != 0: return "a character device" case mode&os.ModeDevice != 0: return "a block device" default: return "not a regular file" } } func determineAtomicReplaceMode(target atomicReplaceTarget, sourceMode *os.FileMode) (os.FileMode, bool) { if target.exists { return normalizePreservedFileMode(target.mode), true } if sourceMode == nil { return 0, false } return normalizePreservedFileMode(*sourceMode), true } func normalizePreservedFileMode(mode os.FileMode) os.FileMode { return mode & preservedFileModeBits } func applyLocalFileMode(localPath string, mode os.FileMode) error { return os.Chmod(localPath, normalizePreservedFileMode(mode)) } func applyRemoteFileMode(client *sftp.Client, remotePath string, mode os.FileMode) error { if err := 
validateSFTPClient(client); err != nil { return err } return client.Chmod(remotePath, normalizePreservedFileMode(mode)) } func removeRemotePath(client *sftp.Client, remotePath string) error { info, err := client.Lstat(remotePath) if err != nil { return err } if info.Mode()&os.ModeSymlink != 0 { return client.Remove(remotePath) } if info.IsDir() { return client.RemoveDirectory(remotePath) } return client.Remove(remotePath) } func removeRemoteAll(ctx context.Context, client *sftp.Client, remotePath string) error { if err := ensureContext(ctx); err != nil { return err } info, err := client.Lstat(remotePath) if err != nil { if isNotExistError(err) { return nil } return err } if info.Mode()&os.ModeSymlink != 0 || !info.IsDir() { if err := client.Remove(remotePath); err != nil && !isNotExistError(err) { return err } return nil } entries, err := client.ReadDir(remotePath) if err != nil { return err } for _, entry := range entries { childPath := path.Join(remotePath, entry.Name()) if err := removeRemoteAll(ctx, client, childPath); err != nil { return err } } if err := client.RemoveDirectory(remotePath); err != nil && !isNotExistError(err) { return err } return nil }