package starssh import ( "bytes" "context" "errors" "io" "net" "os" "path/filepath" "strings" "testing" "github.com/pkg/sftp" "golang.org/x/crypto/ssh" ) func TestNormalizeSFTPTransferOptionsDefaultsAtomicDownload(t *testing.T) { opts := mustNormalizeSFTPTransferOptions(t, nil) if !opts.AtomicUpload { t.Fatal("expected atomic upload to default to enabled") } if !opts.AtomicDownload { t.Fatal("expected atomic download to default to enabled") } if opts.Client.ConcurrentWrites != nil { t.Fatal("expected concurrent writes to default to unset") } } func TestThroughputSFTPTransferOptionsEnablesExplicitConcurrentWrites(t *testing.T) { opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) if !opts.AtomicUpload { t.Fatal("expected throughput preset to keep atomic upload enabled") } if opts.Client.ConcurrentReads == nil || !*opts.Client.ConcurrentReads { t.Fatal("expected throughput preset to enable concurrent reads") } if opts.Client.ConcurrentWrites == nil || !*opts.Client.ConcurrentWrites { t.Fatal("expected throughput preset to enable concurrent writes") } if opts.Client.MaxConcurrentRequestsPerFile != 32 { t.Fatalf("unexpected max concurrent requests: got %d want 32", opts.Client.MaxConcurrentRequestsPerFile) } } func TestValidateSFTPUploadOptionsRejectsConcurrentWritesWithoutAtomicUpload(t *testing.T) { options := ThroughputSFTPTransferOptions() options.AtomicUpload = SFTPBool(false) opts := mustNormalizeSFTPTransferOptions(t, &options) err := validateSFTPUploadOptions(opts) if err == nil || !strings.Contains(err.Error(), "atomic upload") { t.Fatalf("expected atomic upload rejection, got %v", err) } } func TestNormalizeSFTPTransferOptionsRejectsNegativeClientValues(t *testing.T) { _, err := normalizeSFTPTransferOptions(&SFTPTransferOptions{ Client: SFTPClientOptions{MaxPacketSize: -1}, }) if err == nil || !strings.Contains(err.Error(), "max packet") { t.Fatalf("expected max packet rejection, got %v", err) } _, err = normalizeSFTPTransferOptions(&SFTPTransferOptions{ Client: SFTPClientOptions{MaxConcurrentRequestsPerFile: -1}, }) if err == nil || !strings.Contains(err.Error(), "max concurrent") { t.Fatalf("expected max concurrent rejection, got %v", err) } } func TestBuildSFTPClientOptionsRejectsUnsupportedCheckedPacketSize(t *testing.T) { options := buildSFTPClientOptions(resolvedSFTPClientOptions{MaxPacketSize: 32769}) client := &sftp.Client{} err := options[0](client) if err == nil || !strings.Contains(err.Error(), "32KB") { t.Fatalf("expected checked packet size rejection, got %v", err) } } func TestSftpTransferOutWithContextRejectsClientOptionsForExternalClient(t *testing.T) { client := newSFTPTestClient(t) root := t.TempDir() localPath := filepath.Join(root, "local.txt") remotePath := filepath.Join(root, "remote.txt") if err := os.WriteFile(localPath, []byte("payload"), 0o644); err != nil { t.Fatalf("write local file: %v", err) } err := SftpTransferOutWithContext(context.Background(), localPath, remotePath, client, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) if err == nil || !strings.Contains(err.Error(), "StarSSH-managed") { t.Fatalf("expected external client option rejection, got %v", err) } } func TestSftpTransferOutContextPassesClientOptionsToManagedClient(t *testing.T) { root := t.TempDir() localPath := filepath.Join(root, "local.txt") remotePath := filepath.Join(root, "remote.txt") if err := os.WriteFile(localPath, []byte("payload"), 0o644); err != nil { t.Fatalf("write local file: %v", err) } client := newSFTPTestClient(t) options := ThroughputSFTPTransferOptions() var captured []sftp.ClientOption oldNewClient := sftpNewClientFunc sftpNewClientFunc = func(_ *ssh.Client, opts ...sftp.ClientOption) (*sftp.Client, error) { captured = append([]sftp.ClientOption(nil), opts...) return client, nil } t.Cleanup(func() { sftpNewClientFunc = oldNewClient }) star := &StarSSH{Client: &ssh.Client{}} if err := star.SftpTransferOutContext(context.Background(), localPath, remotePath, &options); err != nil { t.Fatalf("transfer out: %v", err) } if len(captured) == 0 { t.Fatal("expected managed SFTP client options to be passed to factory") } if got, want := len(captured), len(buildSFTPClientOptions(mustNormalizeSFTPTransferOptions(t, &options).Client)); got != want { t.Fatalf("unexpected client option count: got %d want %d", got, want) } } func TestSftpTransferInWithContextRejectsClientOptionsForExternalClient(t *testing.T) { client := newSFTPTestClient(t) root := t.TempDir() srcPath := filepath.Join(root, "remote.txt") dstPath := filepath.Join(root, "local.txt") if err := os.WriteFile(srcPath, []byte("payload"), 0o644); err != nil { t.Fatalf("write remote file: %v", err) } err := SftpTransferInWithContext(context.Background(), srcPath, dstPath, client, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) if err == nil || !strings.Contains(err.Error(), "StarSSH-managed") { t.Fatalf("expected external client option rejection, got %v", err) } } func TestSftpTransferInContextPassesClientOptionsToManagedClient(t *testing.T) { root := t.TempDir() srcPath := filepath.Join(root, "remote.txt") dstPath := filepath.Join(root, "local.txt") if err := os.WriteFile(srcPath, []byte("payload"), 0o644); err != nil { t.Fatalf("write remote file: %v", err) } client := newSFTPTestClient(t) options := ThroughputSFTPTransferOptions() var captured []sftp.ClientOption oldNewClient := sftpNewClientFunc sftpNewClientFunc = func(_ *ssh.Client, opts ...sftp.ClientOption) (*sftp.Client, error) { captured = append([]sftp.ClientOption(nil), opts...) return client, nil } t.Cleanup(func() { sftpNewClientFunc = oldNewClient }) star := &StarSSH{Client: &ssh.Client{}} if err := star.SftpTransferInContext(context.Background(), srcPath, dstPath, &options); err != nil { t.Fatalf("transfer in: %v", err) } if len(captured) == 0 { t.Fatal("expected managed SFTP client options to be passed to factory") } if got, want := len(captured), len(buildSFTPClientOptions(mustNormalizeSFTPTransferOptions(t, &options).Client)); got != want { t.Fatalf("unexpected client option count: got %d want %d", got, want) } assertFileContent(t, dstPath, "payload") } func TestBuildSFTPClientOptionsCanTransfer(t *testing.T) { client := newSFTPTestClientWithOptions(t, buildSFTPClientOptions(resolvedSFTPClientOptions{ MaxPacketSize: 4096, MaxConcurrentRequestsPerFile: 4, ConcurrentWrites: SFTPBool(true), })) root := t.TempDir() localPath := filepath.Join(root, "local.txt") remotePath := filepath.Join(root, "remote.txt") if err := os.WriteFile(localPath, []byte(strings.Repeat("payload-", 2048)), 0o644); err != nil { t.Fatalf("write local file: %v", err) } opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) if err := transferOutContext(context.Background(), client, localPath, remotePath, opts); err != nil { t.Fatalf("transfer out with client options: %v", err) } data, err := os.ReadFile(remotePath) if err != nil { t.Fatalf("read remote file: %v", err) } if got := string(data); got != strings.Repeat("payload-", 2048) { t.Fatalf("unexpected remote payload length: got %d", len(got)) } } func TestBuildSFTPClientOptionsCanDownload(t *testing.T) { client := newSFTPTestClientWithOptions(t, buildSFTPClientOptions(resolvedSFTPClientOptions{ MaxPacketSize: 4096, MaxConcurrentRequestsPerFile: 4, ConcurrentReads: SFTPBool(true), })) root := t.TempDir() srcPath := filepath.Join(root, "remote.txt") dstPath := filepath.Join(root, "local.txt") payload := strings.Repeat("payload-", 2048) if err := os.WriteFile(srcPath, []byte(payload), 0o644); err != nil { t.Fatalf("write remote file: %v", err) } opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) if err := transferInContext(context.Background(), client, srcPath, dstPath, opts); err != nil { t.Fatalf("transfer in with client options: %v", err) } assertFileContent(t, dstPath, payload) } func TestCopyUploadWithProgressUsesConcurrentReadFromWhenEnabled(t *testing.T) { dst := &spyConcurrentReadFrom{} src := strings.NewReader("payload") var progress []float64 opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) written, err := copyUploadWithProgressContext(context.Background(), dst, src, 3, int64(len("payload")), func(value float64) { progress = append(progress, value) }, opts) if err != nil { t.Fatalf("copy upload: %v", err) } if written != int64(len("payload")) { t.Fatalf("unexpected written bytes: got %d", written) } if !dst.usedReadFrom { t.Fatal("expected concurrent ReadFrom path to be used") } if dst.concurrency != opts.Client.MaxConcurrentRequestsPerFile { t.Fatalf("unexpected concurrency: got %d want %d", dst.concurrency, opts.Client.MaxConcurrentRequestsPerFile) } if got := dst.buf.String(); got != "payload" { t.Fatalf("unexpected copied payload: got %q", got) } if len(progress) == 0 || progress[len(progress)-1] != 100 { t.Fatalf("expected final progress 100, got %v", progress) } } func TestCopyUploadWithProgressReportsDuringConcurrentReadFrom(t *testing.T) { dst := &spyConcurrentReadFrom{} src := &chunkedReader{ reader: strings.NewReader("payload"), chunkSize: 2, } var progress []float64 opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) written, err := copyUploadWithProgressContext(context.Background(), dst, src, 3, int64(len("payload")), func(value float64) { progress = append(progress, value) }, opts) if err != nil { t.Fatalf("copy upload: %v", err) } if written != int64(len("payload")) { t.Fatalf("unexpected written bytes: got %d", written) } var sawIntermediate bool for _, value := range progress { if value > 0 && value < 100 { sawIntermediate = true break } } if !sawIntermediate { t.Fatalf("expected intermediate progress during concurrent readfrom, got %v", progress) } } func TestCopyUploadWithProgressCancelsDuringConcurrentReadFrom(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() dst := &spyConcurrentReadFrom{} src := &cancelAfterReadReader{ reader: strings.NewReader("payload"), chunkSize: 3, cancel: cancel, } opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) written, err := copyUploadWithProgressContext(ctx, dst, src, 3, int64(len("payload")), nil, opts) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context cancellation, got written=%d err=%v", written, err) } if !dst.usedReadFrom { t.Fatal("expected concurrent ReadFrom path to be used") } if written <= 0 || written >= int64(len("payload")) { t.Fatalf("expected partial write before cancellation, got %d", written) } } func TestCopyUploadWithProgressDoesNotReportDoneBeforeConcurrentWriteError(t *testing.T) { copyErr := errors.New("write status failed") dst := &spyConcurrentReadFrom{errAfterRead: copyErr} var progress []float64 opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) written, err := copyUploadWithProgressContext(context.Background(), dst, strings.NewReader("payload"), 3, int64(len("payload")), func(value float64) { progress = append(progress, value) }, opts) if !errors.Is(err, copyErr) { t.Fatalf("expected concurrent write error, got written=%d err=%v", written, err) } if written != int64(len("payload")) { t.Fatalf("expected read byte count before write error, got %d", written) } for _, value := range progress { if value >= 100 { t.Fatalf("progress reported completion before write success: %v", progress) } } if len(progress) == 0 { t.Fatal("expected queued progress before write error") } } func TestCopyUploadWithProgressReturnsConcurrentReadFromError(t *testing.T) { copyErr := errors.New("readfrom failed") dst := &spyConcurrentReadFrom{err: copyErr} opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) _, err := copyUploadWithProgressContext(context.Background(), dst, strings.NewReader("payload"), 3, int64(len("payload")), nil, opts) if !errors.Is(err, copyErr) { t.Fatalf("expected concurrent readfrom error, got %v", err) } if !dst.usedReadFrom { t.Fatal("expected concurrent ReadFrom path to be used") } } func TestCopyUploadWithProgressDefaultsToSequentialCopy(t *testing.T) { oldCopy := sftpCopyWithProgressFunc called := false sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) { called = true return oldCopy(ctx, dst, src, bufSize, total, progress) } t.Cleanup(func() { sftpCopyWithProgressFunc = oldCopy }) var dst bytes.Buffer written, err := copyUploadWithProgressContext(context.Background(), &dst, strings.NewReader("payload"), 3, int64(len("payload")), nil, mustNormalizeSFTPTransferOptions(t, nil)) if err != nil { t.Fatalf("copy upload: %v", err) } if written != int64(len("payload")) { t.Fatalf("unexpected written bytes: got %d", written) } if !called { t.Fatal("expected default path to use existing copy helper") } } func TestCopyDownloadWithProgressUsesConcurrentWriteToWhenEnabled(t *testing.T) { src := &spyConcurrentWriteTo{payload: []byte("payload")} var dst bytes.Buffer var progress []float64 opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) { progress = append(progress, value) }, opts) if err != nil { t.Fatalf("copy download: %v", err) } if written != int64(len("payload")) { t.Fatalf("unexpected written bytes: got %d", written) } if !src.usedWriteTo { t.Fatal("expected concurrent WriteTo path to be used") } if got := dst.String(); got != "payload" { t.Fatalf("unexpected copied payload: got %q", got) } if len(progress) == 0 || progress[len(progress)-1] != 100 { t.Fatalf("expected final progress 100, got %v", progress) } } func TestCopyDownloadWithProgressReportsDuringConcurrentWriteTo(t *testing.T) { src := &spyConcurrentWriteTo{ payload: []byte("payload"), chunkSize: 2, } var dst bytes.Buffer var progress []float64 opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) { progress = append(progress, value) }, opts) if err != nil { t.Fatalf("copy download: %v", err) } if written != int64(len("payload")) { t.Fatalf("unexpected written bytes: got %d", written) } var sawIntermediate bool for _, value := range progress { if value > 0 && value < 100 { sawIntermediate = true break } } if !sawIntermediate { t.Fatalf("expected intermediate progress during concurrent writeto, got %v", progress) } } func TestCopyDownloadWithProgressCancelsDuringConcurrentWriteTo(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() src := &spyConcurrentWriteTo{ payload: []byte("payload"), chunkSize: 3, cancelAfterWrite: cancel, } var dst bytes.Buffer opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) written, err := copyDownloadWithProgressContext(ctx, &dst, src, 3, int64(len("payload")), nil, opts) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context cancellation, got written=%d err=%v", written, err) } if !src.usedWriteTo { t.Fatal("expected concurrent WriteTo path to be used") } if written <= 0 || written >= int64(len("payload")) { t.Fatalf("expected partial write before cancellation, got %d", written) } } func TestCopyDownloadWithProgressDoesNotReportDoneBeforeConcurrentReadError(t *testing.T) { copyErr := errors.New("read status failed") src := &spyConcurrentWriteTo{ payload: []byte("payload"), errAfterWrite: copyErr, } var dst bytes.Buffer var progress []float64 opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), func(value float64) { progress = append(progress, value) }, opts) if !errors.Is(err, copyErr) { t.Fatalf("expected concurrent read error, got written=%d err=%v", written, err) } if written != int64(len("payload")) { t.Fatalf("expected local byte count before read error, got %d", written) } for _, value := range progress { if value >= 100 { t.Fatalf("progress reported completion before download success: %v", progress) } } if len(progress) == 0 { t.Fatal("expected queued progress before read error") } } func TestCopyDownloadWithProgressReturnsConcurrentWriteToError(t *testing.T) { copyErr := errors.New("writeto failed") src := &spyConcurrentWriteTo{err: copyErr} var dst bytes.Buffer opts := mustNormalizeSFTPTransferOptions(t, SFTPOptionsPtr(ThroughputSFTPTransferOptions())) _, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), nil, opts) if !errors.Is(err, copyErr) { t.Fatalf("expected concurrent writeto error, got %v", err) } if !src.usedWriteTo { t.Fatal("expected concurrent WriteTo path to be used") } } func TestCopyDownloadWithProgressDefaultsToSequentialCopy(t *testing.T) { oldCopy := sftpCopyWithProgressFunc called := false sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) { called = true return oldCopy(ctx, dst, src, bufSize, total, progress) } t.Cleanup(func() { sftpCopyWithProgressFunc = oldCopy }) var dst bytes.Buffer src := &chunkedReader{ reader: strings.NewReader("payload"), chunkSize: 2, } written, err := copyDownloadWithProgressContext(context.Background(), &dst, src, 3, int64(len("payload")), nil, mustNormalizeSFTPTransferOptions(t, nil)) if err != nil { t.Fatalf("copy download: %v", err) } if written != int64(len("payload")) { t.Fatalf("unexpected written bytes: got %d", written) } if !called { t.Fatal("expected default path to use existing copy helper") } } func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) { client := newSFTPTestClient(t) root := t.TempDir() localPath := filepath.Join(root, "local.txt") remotePath := filepath.Join(root, "remote.txt") if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil { t.Fatalf("write local file: %v", err) } if err := os.WriteFile(remotePath, []byte("original remote"), 0o644); err != nil { t.Fatalf("write remote file: %v", err) } verifyErr := errors.New("verify failed") var verifiedPath string oldVerifyRemoteSize := sftpVerifyRemoteSizeFunc sftpVerifyRemoteSizeFunc = func(client *sftp.Client, remotePath string, expected int64) error { verifiedPath = remotePath return verifyErr } t.Cleanup(func() { sftpVerifyRemoteSizeFunc = oldVerifyRemoteSize }) err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)) if !errors.Is(err, verifyErr) { t.Fatalf("expected verify failure, got %v", err) } if verifiedPath == remotePath { t.Fatal("expected upload verification to run against temp path before final rename") } data, err := os.ReadFile(remotePath) if err != nil { t.Fatalf("read remote file: %v", err) } if string(data) != "original remote" { t.Fatalf("remote target was replaced on verify failure: %q", string(data)) } assertNoTransferTemps(t, remotePath) } func TestTransferOutContextRejectsRemoteSymlinkTarget(t *testing.T) { client := newSFTPTestClient(t) root := t.TempDir() localPath := filepath.Join(root, "local.txt") remoteRealPath := filepath.Join(root, "remote-real.txt") remotePath := filepath.Join(root, "remote-link.txt") if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil { t.Fatalf("write local file: %v", err) } if err := os.WriteFile(remoteRealPath, []byte("original remote"), 0o644); err != nil { t.Fatalf("write remote backing file: %v", err) } if err := os.Symlink(remoteRealPath, remotePath); err != nil { t.Skipf("symlink unsupported: %v", err) } err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)) if err == nil || !strings.Contains(err.Error(), "symlink") { t.Fatalf("expected symlink rejection, got %v", err) } info, err := os.Lstat(remotePath) if err != nil { t.Fatalf("lstat remote symlink: %v", err) } if info.Mode()&os.ModeSymlink == 0 { t.Fatal("expected remote target to remain a symlink") } data, err := os.ReadFile(remoteRealPath) if err != nil { t.Fatalf("read remote backing file: %v", err) } if string(data) != "original remote" { t.Fatalf("remote backing file changed unexpectedly: %q", string(data)) } assertNoTransferTemps(t, remotePath) } func TestTransferOutContextRejectsRemoteDirectoryTarget(t *testing.T) { client := newSFTPTestClient(t) root := t.TempDir() localPath := filepath.Join(root, "local.txt") remotePath := filepath.Join(root, "remote-dir") if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil { t.Fatalf("write local file: %v", err) } if err := os.Mkdir(remotePath, 0o755); err != nil { t.Fatalf("mkdir remote target: %v", err) } err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)) if err == nil || !strings.Contains(err.Error(), "directory") { t.Fatalf("expected directory rejection, got %v", err) } info, err := os.Stat(remotePath) if err != nil { t.Fatalf("stat remote directory: %v", err) } if !info.IsDir() { t.Fatal("expected remote target to remain a directory") } assertNoTransferTemps(t, remotePath) } func TestTransferOutContextPreservesRemoteModeOnOverwrite(t *testing.T) { client := newSFTPTestClient(t) root := t.TempDir() localPath := filepath.Join(root, "local.txt") remotePath := filepath.Join(root, "remote.txt") if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil { t.Fatalf("write local file: %v", err) } if err := os.WriteFile(remotePath, []byte("original remote"), 0o755); err != nil { t.Fatalf("write remote file: %v", err) } if err := os.Chmod(remotePath, 0o755); err != nil { t.Fatalf("chmod remote file: %v", err) } if err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil { t.Fatalf("transfer out: %v", err) } assertMode(t, remotePath, 0o755) assertFileContent(t, remotePath, "new payload") } func TestTransferOutContextAppliesLocalModeForNewRemoteFile(t *testing.T) { client := newSFTPTestClient(t) root := t.TempDir() localPath := filepath.Join(root, "local.txt") remotePath := filepath.Join(root, "remote.txt") if err := os.WriteFile(localPath, []byte("new payload"), 0o751); err != nil { t.Fatalf("write local file: %v", err) } if err := os.Chmod(localPath, 0o751); err != nil { t.Fatalf("chmod local file: %v", err) } if err := transferOutContext(context.Background(), client, localPath, remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil { t.Fatalf("transfer out: %v", err) } assertMode(t, remotePath, 0o751) assertFileContent(t, remotePath, "new payload") } func TestTransferOutByteContextPreservesRemoteModeOnOverwrite(t *testing.T) { client := newSFTPTestClient(t) root := t.TempDir() remotePath := filepath.Join(root, "remote.txt") if err := os.WriteFile(remotePath, []byte("original remote"), 0o755); err != nil { t.Fatalf("write remote file: %v", err) } if err := os.Chmod(remotePath, 0o755); err != nil { t.Fatalf("chmod remote file: %v", err) } if err := transferOutByteContext(context.Background(), client, []byte("byte payload"), remotePath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil { t.Fatalf("transfer out bytes: %v", err) } assertMode(t, remotePath, 0o755) assertFileContent(t, remotePath, "byte payload") } func TestTransferInContextVerifyFailurePreservesLocalTarget(t *testing.T) { client := newSFTPTestClient(t) root := t.TempDir() srcPath := filepath.Join(root, "remote.txt") dstPath := filepath.Join(root, "local.txt") if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil { t.Fatalf("write remote file: %v", err) } if err := os.WriteFile(dstPath, []byte("original local"), 0o644); err != nil { t.Fatalf("write local file: %v", err) } verifyErr := errors.New("verify local failed") var verifiedPath string oldVerifyLocalSize := sftpVerifyLocalSizeFunc sftpVerifyLocalSizeFunc = func(localPath string, expected int64) error { verifiedPath = localPath return verifyErr } t.Cleanup(func() { sftpVerifyLocalSizeFunc = oldVerifyLocalSize }) err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)) if !errors.Is(err, verifyErr) { t.Fatalf("expected verify failure, got %v", err) } if verifiedPath == dstPath { t.Fatal("expected download verification to run against temp path before final rename") } data, err := os.ReadFile(dstPath) if err != nil { t.Fatalf("read local file: %v", err) } if string(data) != "original local" { t.Fatalf("local target was replaced on verify failure: %q", string(data)) } assertNoTransferTemps(t, dstPath) } func TestTransferInContextRejectsLocalSymlinkTarget(t *testing.T) { client := newSFTPTestClient(t) root := t.TempDir() srcPath := filepath.Join(root, "remote.txt") localRealPath := filepath.Join(root, "local-real.txt") dstPath := filepath.Join(root, "local-link.txt") if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil { t.Fatalf("write remote file: %v", err) } if err := os.WriteFile(localRealPath, []byte("original local"), 0o644); err != nil { t.Fatalf("write local backing file: %v", err) } if err := os.Symlink(localRealPath, dstPath); err != nil { t.Skipf("symlink unsupported: %v", err) } err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)) if err == nil || !strings.Contains(err.Error(), "symlink") { t.Fatalf("expected symlink rejection, got %v", err) } info, err := os.Lstat(dstPath) if err != nil { t.Fatalf("lstat local symlink: %v", err) } if info.Mode()&os.ModeSymlink == 0 { t.Fatal("expected local target to remain a symlink") } assertFileContent(t, localRealPath, "original local") assertNoTransferTemps(t, dstPath) } func TestTransferInContextRejectsLocalDirectoryTarget(t *testing.T) { client := newSFTPTestClient(t) root := t.TempDir() srcPath := filepath.Join(root, "remote.txt") dstPath := filepath.Join(root, "local-dir") if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil { t.Fatalf("write remote file: %v", err) } if err := os.Mkdir(dstPath, 0o755); err != nil { t.Fatalf("mkdir local target: %v", err) } err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)) if err == nil || !strings.Contains(err.Error(), "directory") { t.Fatalf("expected directory rejection, got %v", err) } info, err := os.Stat(dstPath) if err != nil { t.Fatalf("stat local directory: %v", err) } if !info.IsDir() { t.Fatal("expected local target to remain a directory") } assertNoTransferTemps(t, dstPath) } func TestTransferInContextPreservesLocalModeOnOverwrite(t *testing.T) { client := newSFTPTestClient(t) root := t.TempDir() srcPath := filepath.Join(root, "remote.txt") dstPath := filepath.Join(root, "local.sh") if err := os.WriteFile(srcPath, []byte("#!/bin/sh\necho remote\n"), 0o644); err != nil { t.Fatalf("write remote file: %v", err) } if err := os.WriteFile(dstPath, []byte("#!/bin/sh\necho local\n"), 0o755); err != nil { t.Fatalf("write local file: %v", err) } if err := os.Chmod(dstPath, 0o755); err != nil { t.Fatalf("chmod local file: %v", err) } if err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil { t.Fatalf("transfer in: %v", err) } assertMode(t, dstPath, 0o755) assertFileContent(t, dstPath, "#!/bin/sh\necho remote\n") } func TestTransferInContextAppliesRemoteModeForNewLocalFile(t *testing.T) { client := newSFTPTestClient(t) root := t.TempDir() srcPath := filepath.Join(root, "remote.sh") dstPath := filepath.Join(root, "local.sh") if err := os.WriteFile(srcPath, []byte("#!/bin/sh\necho remote\n"), 0o751); err != nil { t.Fatalf("write remote file: %v", err) } if err := os.Chmod(srcPath, 0o751); err != nil { t.Fatalf("chmod remote file: %v", err) } if err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)); err != nil { t.Fatalf("transfer in: %v", err) } assertMode(t, dstPath, 0o751) assertFileContent(t, dstPath, "#!/bin/sh\necho remote\n") } func TestTransferInContextCopyFailurePreservesLocalTarget(t *testing.T) { client := newSFTPTestClient(t) root := t.TempDir() srcPath := filepath.Join(root, "remote.txt") dstPath := filepath.Join(root, "local.txt") if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil { t.Fatalf("write remote file: %v", err) } if err := os.WriteFile(dstPath, []byte("original local"), 0o644); err != nil { t.Fatalf("write local file: %v", err) } copyErr := errors.New("copy failed") var copyTargetPath string oldCopy := sftpCopyWithProgressFunc sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) { file, ok := dst.(*os.File) if !ok { t.Fatalf("expected local temp file writer, got %T", dst) } copyTargetPath = file.Name() buf := make([]byte, 8) n, readErr := src.Read(buf) if readErr != nil && !errors.Is(readErr, io.EOF) { return 0, readErr } if n > 0 { written, err := dst.Write(buf[:n]) if err != nil { return int64(written), err } return int64(written), copyErr } return 0, copyErr } t.Cleanup(func() { sftpCopyWithProgressFunc = oldCopy }) err := transferInContext(context.Background(), client, srcPath, dstPath, mustNormalizeSFTPTransferOptions(t, nil)) if !errors.Is(err, copyErr) { t.Fatalf("expected copy failure, got %v", err) } if copyTargetPath == dstPath { t.Fatal("expected partial download writes to stay on temp path") } data, err := os.ReadFile(dstPath) if err != nil { t.Fatalf("read local file: %v", err) } if string(data) != "original local" { t.Fatalf("local target was modified by partial download: %q", string(data)) } assertNoTransferTemps(t, dstPath) } func mustNormalizeSFTPTransferOptions(t *testing.T, options *SFTPTransferOptions) resolvedSFTPTransferOptions { t.Helper() opts, err := normalizeSFTPTransferOptions(options) if err != nil { t.Fatalf("normalize sftp transfer options: %v", err) } return opts } func SFTPOptionsPtr(options SFTPTransferOptions) *SFTPTransferOptions { return &options } type spyConcurrentReadFrom struct { buf bytes.Buffer concurrency int usedReadFrom bool err error errAfterRead error } func (w *spyConcurrentReadFrom) Write(p []byte) (int, error) { return w.buf.Write(p) } func (w *spyConcurrentReadFrom) ReadFromWithConcurrency(r io.Reader, concurrency int) (int64, error) { w.usedReadFrom = true w.concurrency = concurrency if w.err != nil { return 0, w.err } written, err := w.buf.ReadFrom(r) if err != nil { return written, err } if w.errAfterRead != nil { return written, w.errAfterRead } return written, nil } type spyConcurrentWriteTo struct { payload []byte chunkSize int usedWriteTo bool err error errAfterWrite error cancelAfterWrite context.CancelFunc } func (r *spyConcurrentWriteTo) Read(p []byte) (int, error) { if len(r.payload) == 0 { if r.err != nil { return 0, r.err } return 0, io.EOF } n := copy(p, r.payload) r.payload = r.payload[n:] if len(r.payload) == 0 { return n, io.EOF } return n, nil } func (r *spyConcurrentWriteTo) WriteTo(w io.Writer) (int64, error) { r.usedWriteTo = true if r.err != nil { return 0, r.err } payload := r.payload if len(payload) == 0 { if r.errAfterWrite != nil { return 0, r.errAfterWrite } return 0, nil } chunkSize := r.chunkSize if chunkSize <= 0 || chunkSize > len(payload) { chunkSize = len(payload) } var written int64 for len(payload) > 0 { size := chunkSize if size > len(payload) { size = len(payload) } n, err := w.Write(payload[:size]) written += int64(n) if r.cancelAfterWrite != nil && n > 0 { r.cancelAfterWrite() r.cancelAfterWrite = nil } if err != nil { return written, err } if n != size { return written, io.ErrShortWrite } payload = payload[size:] } if r.errAfterWrite != nil { return written, r.errAfterWrite } return written, nil } type chunkedReader struct { reader io.Reader chunkSize int } func (r *chunkedReader) Read(p []byte) (int, error) { if r.chunkSize > 0 && len(p) > r.chunkSize { p = p[:r.chunkSize] } return r.reader.Read(p) } type cancelAfterReadReader struct { reader io.Reader chunkSize int cancel context.CancelFunc } func (r *cancelAfterReadReader) Read(p []byte) (int, error) { if r.chunkSize > 0 && len(p) > r.chunkSize { p = p[:r.chunkSize] } n, err := r.reader.Read(p) if n > 0 && r.cancel != nil { r.cancel() r.cancel = nil } return n, err } func newSFTPTestClient(t *testing.T) *sftp.Client { return newSFTPTestClientWithOptions(t, nil) } func newSFTPTestClientWithOptions(t *testing.T, options []sftp.ClientOption) *sftp.Client { t.Helper() serverConn, clientConn := net.Pipe() server, err := sftp.NewServer(serverConn) if err != nil { t.Fatalf("create sftp server: %v", err) } serveErrCh := make(chan error, 1) go func() { serveErrCh <- server.Serve() }() client, err := sftp.NewClientPipe(clientConn, clientConn, options...) if err != nil { _ = server.Close() t.Fatalf("create sftp client: %v", err) } t.Cleanup(func() { _ = client.Close() _ = server.Close() serveErr := <-serveErrCh if serveErr == nil || errors.Is(serveErr, io.EOF) || normalizeAlreadyClosedError(serveErr) == nil { return } t.Errorf("unexpected sftp server error: %v", serveErr) }) return client } func assertNoTransferTemps(t *testing.T, targetPath string) { t.Helper() matches, err := filepath.Glob(targetPath + defaultSFTPTempSuffix + "*") if err != nil { t.Fatalf("glob temp files: %v", err) } if len(matches) != 0 { t.Fatalf("expected temp artifacts to be cleaned up, got %v", matches) } } func assertMode(t *testing.T, targetPath string, want os.FileMode) { t.Helper() info, err := os.Stat(targetPath) if err != nil { t.Fatalf("stat %q: %v", targetPath, err) } if got := info.Mode().Perm(); got != want { t.Fatalf("unexpected mode for %q: got %o want %o", targetPath, got, want) } } func assertFileContent(t *testing.T, targetPath string, want string) { t.Helper() data, err := os.ReadFile(targetPath) if err != nil { t.Fatalf("read %q: %v", targetPath, err) } if string(data) != want { t.Fatalf("unexpected content for %q: got %q want %q", targetPath, string(data), want) } }