package starssh

import (
	"context"
	"errors"
	"io"
	"net"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/pkg/sftp"
)

// TestNormalizeSFTPTransferOptionsDefaultsAtomicDownload checks that nil
// options normalize to both atomic uploads and atomic downloads enabled.
func TestNormalizeSFTPTransferOptionsDefaultsAtomicDownload(t *testing.T) {
	opts := normalizeSFTPTransferOptions(nil)
	if !opts.AtomicUpload {
		t.Fatal("expected atomic upload to default to enabled")
	}
	if !opts.AtomicDownload {
		t.Fatal("expected atomic download to default to enabled")
	}
}

// TestTransferOutContextVerifyFailurePreservesRemoteTarget checks that when the
// post-upload size verification fails, the verification error is returned,
// verification ran against a path other than the final target, the existing
// remote file is untouched, and no temp artifacts are left behind.
func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) {
	client := newSFTPTestClient(t)
	root := t.TempDir()
	localPath := filepath.Join(root, "local.txt")
	remotePath := filepath.Join(root, "remote.txt")
	if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
		t.Fatalf("write local file: %v", err)
	}
	if err := os.WriteFile(remotePath, []byte("original remote"), 0o644); err != nil {
		t.Fatalf("write remote file: %v", err)
	}

	verifyErr := errors.New("verify failed")
	var verifiedPath string
	oldVerifyRemoteSize := sftpVerifyRemoteSizeFunc
	// Parameter names deliberately differ from the test's client/remotePath
	// locals so the stub cannot accidentally read the wrong variable.
	sftpVerifyRemoteSizeFunc = func(_ *sftp.Client, path string, _ int64) error {
		verifiedPath = path
		return verifyErr
	}
	t.Cleanup(func() { sftpVerifyRemoteSizeFunc = oldVerifyRemoteSize })

	err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
	if !errors.Is(err, verifyErr) {
		t.Fatalf("expected verify failure, got %v", err)
	}
	if verifiedPath == remotePath {
		t.Fatal("expected upload verification to run against temp path before final rename")
	}
	data, err := os.ReadFile(remotePath)
	if err != nil {
		t.Fatalf("read remote file: %v", err)
	}
	if string(data) != "original remote" {
		t.Fatalf("remote target was replaced on verify failure: %q", string(data))
	}
	assertNoTransferTemps(t, remotePath)
}

// TestTransferOutContextRejectsRemoteSymlinkTarget checks that uploading onto a
// remote path that is a symlink fails with an error mentioning "symlink",
// leaves the link in place, and does not modify the file it points to.
func TestTransferOutContextRejectsRemoteSymlinkTarget(t *testing.T) {
	client := newSFTPTestClient(t)
	root := t.TempDir()
	localPath := filepath.Join(root, "local.txt")
	remoteRealPath := filepath.Join(root, "remote-real.txt")
	remotePath := filepath.Join(root, "remote-link.txt")
	if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
		t.Fatalf("write local file: %v", err)
	}
	if err := os.WriteFile(remoteRealPath, []byte("original remote"), 0o644); err != nil {
		t.Fatalf("write remote backing file: %v", err)
	}
	if err := os.Symlink(remoteRealPath, remotePath); err != nil {
		t.Skipf("symlink unsupported: %v", err)
	}

	err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
	if err == nil || !strings.Contains(err.Error(), "symlink") {
		t.Fatalf("expected symlink rejection, got %v", err)
	}
	info, err := os.Lstat(remotePath)
	if err != nil {
		t.Fatalf("lstat remote symlink: %v", err)
	}
	if info.Mode()&os.ModeSymlink == 0 {
		t.Fatal("expected remote target to remain a symlink")
	}
	data, err := os.ReadFile(remoteRealPath)
	if err != nil {
		t.Fatalf("read remote backing file: %v", err)
	}
	if string(data) != "original remote" {
		t.Fatalf("remote backing file changed unexpectedly: %q", string(data))
	}
	assertNoTransferTemps(t, remotePath)
}

// TestTransferOutContextRejectsRemoteDirectoryTarget checks that uploading onto
// a remote path that is a directory fails with an error mentioning "directory"
// and leaves the directory in place.
func TestTransferOutContextRejectsRemoteDirectoryTarget(t *testing.T) {
	client := newSFTPTestClient(t)
	root := t.TempDir()
	localPath := filepath.Join(root, "local.txt")
	remotePath := filepath.Join(root, "remote-dir")
	if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
		t.Fatalf("write local file: %v", err)
	}
	if err := os.Mkdir(remotePath, 0o755); err != nil {
		t.Fatalf("mkdir remote target: %v", err)
	}

	err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
	if err == nil || !strings.Contains(err.Error(), "directory") {
		t.Fatalf("expected directory rejection, got %v", err)
	}
	info, err := os.Stat(remotePath)
	if err != nil {
		t.Fatalf("stat remote directory: %v", err)
	}
	if !info.IsDir() {
		t.Fatal("expected remote target to remain a directory")
	}
	assertNoTransferTemps(t, remotePath)
}

// TestTransferOutContextPreservesRemoteModeOnOverwrite checks that overwriting
// an existing remote file replaces its content but keeps its permission bits.
func TestTransferOutContextPreservesRemoteModeOnOverwrite(t *testing.T) {
	client := newSFTPTestClient(t)
	root := t.TempDir()
	localPath := filepath.Join(root, "local.txt")
	remotePath := filepath.Join(root, "remote.txt")
	if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
		t.Fatalf("write local file: %v", err)
	}
	if err := os.WriteFile(remotePath, []byte("original remote"), 0o755); err != nil {
		t.Fatalf("write remote file: %v", err)
	}
	// os.WriteFile applies its perm before the umask; chmod pins the exact bits.
	if err := os.Chmod(remotePath, 0o755); err != nil {
		t.Fatalf("chmod remote file: %v", err)
	}

	if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
		t.Fatalf("transfer out: %v", err)
	}
	assertMode(t, remotePath, 0o755)
	assertFileContent(t, remotePath, "new payload")
}

// TestTransferOutContextAppliesLocalModeForNewRemoteFile checks that a newly
// created remote file receives the local source file's permission bits.
func TestTransferOutContextAppliesLocalModeForNewRemoteFile(t *testing.T) {
	client := newSFTPTestClient(t)
	root := t.TempDir()
	localPath := filepath.Join(root, "local.txt")
	remotePath := filepath.Join(root, "remote.txt")
	if err := os.WriteFile(localPath, []byte("new payload"), 0o751); err != nil {
		t.Fatalf("write local file: %v", err)
	}
	// os.WriteFile applies its perm before the umask; chmod pins the exact bits.
	if err := os.Chmod(localPath, 0o751); err != nil {
		t.Fatalf("chmod local file: %v", err)
	}

	if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
		t.Fatalf("transfer out: %v", err)
	}
	assertMode(t, remotePath, 0o751)
	assertFileContent(t, remotePath, "new payload")
}

// TestTransferOutByteContextPreservesRemoteModeOnOverwrite checks that an
// in-memory byte upload over an existing remote file keeps its permission bits.
func TestTransferOutByteContextPreservesRemoteModeOnOverwrite(t *testing.T) {
	client := newSFTPTestClient(t)
	root := t.TempDir()
	remotePath := filepath.Join(root, "remote.txt")
	if err := os.WriteFile(remotePath, []byte("original remote"), 0o755); err != nil {
		t.Fatalf("write remote file: %v", err)
	}
	// os.WriteFile applies its perm before the umask; chmod pins the exact bits.
	if err := os.Chmod(remotePath, 0o755); err != nil {
		t.Fatalf("chmod remote file: %v", err)
	}

	if err := transferOutByteContext(context.Background(), client, []byte("byte payload"), remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
		t.Fatalf("transfer out bytes: %v", err)
	}
	assertMode(t, remotePath, 0o755)
	assertFileContent(t, remotePath, "byte payload")
}

// TestTransferInContextVerifyFailurePreservesLocalTarget checks that when the
// post-download size verification fails, the verification error is returned,
// verification ran against a path other than the final target, the existing
// local file is untouched, and no temp artifacts are left behind.
func TestTransferInContextVerifyFailurePreservesLocalTarget(t *testing.T) {
	client := newSFTPTestClient(t)
	root := t.TempDir()
	srcPath := filepath.Join(root, "remote.txt")
	dstPath := filepath.Join(root, "local.txt")
	if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
		t.Fatalf("write remote file: %v", err)
	}
	if err := os.WriteFile(dstPath, []byte("original local"), 0o644); err != nil {
		t.Fatalf("write local file: %v", err)
	}

	verifyErr := errors.New("verify local failed")
	var verifiedPath string
	oldVerifyLocalSize := sftpVerifyLocalSizeFunc
	sftpVerifyLocalSizeFunc = func(path string, _ int64) error {
		verifiedPath = path
		return verifyErr
	}
	t.Cleanup(func() { sftpVerifyLocalSizeFunc = oldVerifyLocalSize })

	err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
	if !errors.Is(err, verifyErr) {
		t.Fatalf("expected verify failure, got %v", err)
	}
	if verifiedPath == dstPath {
		t.Fatal("expected download verification to run against temp path before final rename")
	}
	data, err := os.ReadFile(dstPath)
	if err != nil {
		t.Fatalf("read local file: %v", err)
	}
	if string(data) != "original local" {
		t.Fatalf("local target was replaced on verify failure: %q", string(data))
	}
	assertNoTransferTemps(t, dstPath)
}

// TestTransferInContextRejectsLocalSymlinkTarget checks that downloading onto a
// local path that is a symlink fails with an error mentioning "symlink",
// leaves the link in place, and does not modify the file it points to.
func TestTransferInContextRejectsLocalSymlinkTarget(t *testing.T) {
	client := newSFTPTestClient(t)
	root := t.TempDir()
	srcPath := filepath.Join(root, "remote.txt")
	localRealPath := filepath.Join(root, "local-real.txt")
	dstPath := filepath.Join(root, "local-link.txt")
	if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
		t.Fatalf("write remote file: %v", err)
	}
	if err := os.WriteFile(localRealPath, []byte("original local"), 0o644); err != nil {
		t.Fatalf("write local backing file: %v", err)
	}
	if err := os.Symlink(localRealPath, dstPath); err != nil {
		t.Skipf("symlink unsupported: %v", err)
	}

	err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
	if err == nil || !strings.Contains(err.Error(), "symlink") {
		t.Fatalf("expected symlink rejection, got %v", err)
	}
	info, err := os.Lstat(dstPath)
	if err != nil {
		t.Fatalf("lstat local symlink: %v", err)
	}
	if info.Mode()&os.ModeSymlink == 0 {
		t.Fatal("expected local target to remain a symlink")
	}
	assertFileContent(t, localRealPath, "original local")
	assertNoTransferTemps(t, dstPath)
}

// TestTransferInContextRejectsLocalDirectoryTarget checks that downloading onto
// a local path that is a directory fails with an error mentioning "directory"
// and leaves the directory in place.
func TestTransferInContextRejectsLocalDirectoryTarget(t *testing.T) {
	client := newSFTPTestClient(t)
	root := t.TempDir()
	srcPath := filepath.Join(root, "remote.txt")
	dstPath := filepath.Join(root, "local-dir")
	if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
		t.Fatalf("write remote file: %v", err)
	}
	if err := os.Mkdir(dstPath, 0o755); err != nil {
		t.Fatalf("mkdir local target: %v", err)
	}

	err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
	if err == nil || !strings.Contains(err.Error(), "directory") {
		t.Fatalf("expected directory rejection, got %v", err)
	}
	info, err := os.Stat(dstPath)
	if err != nil {
		t.Fatalf("stat local directory: %v", err)
	}
	if !info.IsDir() {
		t.Fatal("expected local target to remain a directory")
	}
	assertNoTransferTemps(t, dstPath)
}

// TestTransferInContextPreservesLocalModeOnOverwrite checks that overwriting an
// existing local file replaces its content but keeps its permission bits.
func TestTransferInContextPreservesLocalModeOnOverwrite(t *testing.T) {
	client := newSFTPTestClient(t)
	root := t.TempDir()
	srcPath := filepath.Join(root, "remote.txt")
	dstPath := filepath.Join(root, "local.sh")
	if err := os.WriteFile(srcPath, []byte("#!/bin/sh\necho remote\n"), 0o644); err != nil {
		t.Fatalf("write remote file: %v", err)
	}
	if err := os.WriteFile(dstPath, []byte("#!/bin/sh\necho local\n"), 0o755); err != nil {
		t.Fatalf("write local file: %v", err)
	}
	// os.WriteFile applies its perm before the umask; chmod pins the exact bits.
	if err := os.Chmod(dstPath, 0o755); err != nil {
		t.Fatalf("chmod local file: %v", err)
	}

	if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil {
		t.Fatalf("transfer in: %v", err)
	}
	assertMode(t, dstPath, 0o755)
	assertFileContent(t, dstPath, "#!/bin/sh\necho remote\n")
}

// TestTransferInContextAppliesRemoteModeForNewLocalFile checks that a newly
// created local file receives the remote source file's permission bits.
func TestTransferInContextAppliesRemoteModeForNewLocalFile(t *testing.T) {
	client := newSFTPTestClient(t)
	root := t.TempDir()
	srcPath := filepath.Join(root, "remote.sh")
	dstPath := filepath.Join(root, "local.sh")
	if err := os.WriteFile(srcPath, []byte("#!/bin/sh\necho remote\n"), 0o751); err != nil {
		t.Fatalf("write remote file: %v", err)
	}
	// os.WriteFile applies its perm before the umask; chmod pins the exact bits.
	if err := os.Chmod(srcPath, 0o751); err != nil {
		t.Fatalf("chmod remote file: %v", err)
	}

	if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil {
		t.Fatalf("transfer in: %v", err)
	}
	assertMode(t, dstPath, 0o751)
	assertFileContent(t, dstPath, "#!/bin/sh\necho remote\n")
}

// TestTransferInContextCopyFailurePreservesLocalTarget checks that when the
// copy step fails after a partial write, the error is returned, the partial
// bytes were written to a file other than the final target, the existing local
// file is untouched, and no temp artifacts are left behind.
func TestTransferInContextCopyFailurePreservesLocalTarget(t *testing.T) {
	client := newSFTPTestClient(t)
	root := t.TempDir()
	srcPath := filepath.Join(root, "remote.txt")
	dstPath := filepath.Join(root, "local.txt")
	if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
		t.Fatalf("write remote file: %v", err)
	}
	if err := os.WriteFile(dstPath, []byte("original local"), 0o644); err != nil {
		t.Fatalf("write local file: %v", err)
	}

	copyErr := errors.New("copy failed")
	var copyTargetPath string
	oldCopy := sftpCopyWithProgressFunc
	sftpCopyWithProgressFunc = func(_ context.Context, dst io.Writer, src io.Reader, _ int, _ int64, _ func(float64)) (int64, error) {
		file, ok := dst.(*os.File)
		if !ok {
			// This hook runs inside the code under test, possibly off the test
			// goroutine, where t.Fatalf is not allowed; t.Errorf is safe anywhere.
			t.Errorf("expected local temp file writer, got %T", dst)
			return 0, copyErr
		}
		copyTargetPath = file.Name()

		// Write a short prefix of the payload, then fail, to simulate a
		// download interrupted part-way through.
		buf := make([]byte, 8)
		n, readErr := src.Read(buf)
		if readErr != nil && !errors.Is(readErr, io.EOF) {
			return 0, readErr
		}
		if n > 0 {
			written, err := dst.Write(buf[:n])
			if err != nil {
				return int64(written), err
			}
			return int64(written), copyErr
		}
		return 0, copyErr
	}
	t.Cleanup(func() { sftpCopyWithProgressFunc = oldCopy })

	err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
	if !errors.Is(err, copyErr) {
		t.Fatalf("expected copy failure, got %v", err)
	}
	if copyTargetPath == dstPath {
		t.Fatal("expected partial download writes to stay on temp path")
	}
	data, err := os.ReadFile(dstPath)
	if err != nil {
		t.Fatalf("read local file: %v", err)
	}
	if string(data) != "original local" {
		t.Fatalf("local target was modified by partial download: %q", string(data))
	}
	assertNoTransferTemps(t, dstPath)
}

// newSFTPTestClient returns an sftp client connected over an in-memory
// net.Pipe to an in-process sftp.Server. The tests rely on that server
// operating on this machine's filesystem, so the same temp directory serves as
// both the "local" and "remote" side. Cleanup closes both ends, waits for
// Serve to return, and reports any Serve error that is not a normal shutdown
// (nil, io.EOF, or one normalizeAlreadyClosedError maps to nil).
func newSFTPTestClient(t *testing.T) *sftp.Client {
	t.Helper()

	serverConn, clientConn := net.Pipe()
	server, err := sftp.NewServer(serverConn)
	if err != nil {
		// Nothing owns the pipe yet; close both ends so they do not leak.
		_ = serverConn.Close()
		_ = clientConn.Close()
		t.Fatalf("create sftp server: %v", err)
	}

	// Buffered so the Serve goroutine can always deliver its result and exit,
	// even on the failure path below where nobody receives it.
	serveErrCh := make(chan error, 1)
	go func() { serveErrCh <- server.Serve() }()

	client, err := sftp.NewClientPipe(clientConn, clientConn)
	if err != nil {
		_ = clientConn.Close()
		_ = server.Close()
		t.Fatalf("create sftp client: %v", err)
	}

	t.Cleanup(func() {
		_ = client.Close()
		_ = server.Close()
		serveErr := <-serveErrCh
		if serveErr == nil || errors.Is(serveErr, io.EOF) || normalizeAlreadyClosedError(serveErr) == nil {
			return
		}
		t.Errorf("unexpected sftp server error: %v", serveErr)
	})
	return client
}

// assertNoTransferTemps fails the test if the directory containing targetPath
// holds any entry whose name starts with the target's base name followed by
// defaultSFTPTempSuffix. It lists the directory and does a literal prefix
// check instead of calling filepath.Glob. With Glob, metacharacters in the
// temp-dir path (e.g. from TMPDIR) could make the pattern invalid or match
// nothing, and the check would pass vacuously.
func assertNoTransferTemps(t *testing.T, targetPath string) {
	t.Helper()
	dir := filepath.Dir(targetPath)
	prefix := filepath.Base(targetPath) + defaultSFTPTempSuffix

	entries, err := os.ReadDir(dir)
	if err != nil {
		t.Fatalf("read dir %q: %v", dir, err)
	}
	var leftovers []string
	for _, entry := range entries {
		if strings.HasPrefix(entry.Name(), prefix) {
			leftovers = append(leftovers, filepath.Join(dir, entry.Name()))
		}
	}
	if len(leftovers) != 0 {
		t.Fatalf("expected temp artifacts to be cleaned up, got %v", leftovers)
	}
}

// assertMode fails the test unless the permission bits of targetPath (after
// following symlinks) equal want.
func assertMode(t *testing.T, targetPath string, want os.FileMode) {
	t.Helper()
	info, err := os.Stat(targetPath)
	if err != nil {
		t.Fatalf("stat %q: %v", targetPath, err)
	}
	if got := info.Mode().Perm(); got != want {
		t.Fatalf("unexpected mode for %q: got %o want %o", targetPath, got, want)
	}
}

// assertFileContent fails the test unless targetPath reads back exactly as want.
func assertFileContent(t *testing.T, targetPath string, want string) {
	t.Helper()
	data, err := os.ReadFile(targetPath)
	if err != nil {
		t.Fatalf("read %q: %v", targetPath, err)
	}
	if string(data) != want {
		t.Fatalf("unexpected content for %q: got %q want %q", targetPath, string(data), want)
	}
}