refactor: 重构 starssh 核心运行时并补强 ssh/exec/terminal/sftp 能力
- 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现 - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持 - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行 - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令 - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理 - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景 - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径 - 围绕新的 session/连接生命周期重做执行池与运行时支撑 - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义 - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性
This commit is contained in:
+475
@@ -0,0 +1,475 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
)
|
||||
|
||||
func TestNormalizeSFTPTransferOptionsDefaultsAtomicDownload(t *testing.T) {
|
||||
opts := normalizeSFTPTransferOptions(nil)
|
||||
if !opts.AtomicUpload {
|
||||
t.Fatal("expected atomic upload to default to enabled")
|
||||
}
|
||||
if !opts.AtomicDownload {
|
||||
t.Fatal("expected atomic download to default to enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransferOutContextVerifyFailurePreservesRemoteTarget(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
localPath := filepath.Join(root, "local.txt")
|
||||
remotePath := filepath.Join(root, "remote.txt")
|
||||
|
||||
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(remotePath, []byte("original remote"), 0o644); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
|
||||
verifyErr := errors.New("verify failed")
|
||||
var verifiedPath string
|
||||
oldVerifyRemoteSize := sftpVerifyRemoteSizeFunc
|
||||
sftpVerifyRemoteSizeFunc = func(client *sftp.Client, remotePath string, expected int64) error {
|
||||
verifiedPath = remotePath
|
||||
return verifyErr
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
sftpVerifyRemoteSizeFunc = oldVerifyRemoteSize
|
||||
})
|
||||
|
||||
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
|
||||
if !errors.Is(err, verifyErr) {
|
||||
t.Fatalf("expected verify failure, got %v", err)
|
||||
}
|
||||
if verifiedPath == remotePath {
|
||||
t.Fatal("expected upload verification to run against temp path before final rename")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(remotePath)
|
||||
if err != nil {
|
||||
t.Fatalf("read remote file: %v", err)
|
||||
}
|
||||
if string(data) != "original remote" {
|
||||
t.Fatalf("remote target was replaced on verify failure: %q", string(data))
|
||||
}
|
||||
assertNoTransferTemps(t, remotePath)
|
||||
}
|
||||
|
||||
func TestTransferOutContextRejectsRemoteSymlinkTarget(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
localPath := filepath.Join(root, "local.txt")
|
||||
remoteRealPath := filepath.Join(root, "remote-real.txt")
|
||||
remotePath := filepath.Join(root, "remote-link.txt")
|
||||
|
||||
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(remoteRealPath, []byte("original remote"), 0o644); err != nil {
|
||||
t.Fatalf("write remote backing file: %v", err)
|
||||
}
|
||||
if err := os.Symlink(remoteRealPath, remotePath); err != nil {
|
||||
t.Skipf("symlink unsupported: %v", err)
|
||||
}
|
||||
|
||||
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
|
||||
if err == nil || !strings.Contains(err.Error(), "symlink") {
|
||||
t.Fatalf("expected symlink rejection, got %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Lstat(remotePath)
|
||||
if err != nil {
|
||||
t.Fatalf("lstat remote symlink: %v", err)
|
||||
}
|
||||
if info.Mode()&os.ModeSymlink == 0 {
|
||||
t.Fatal("expected remote target to remain a symlink")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(remoteRealPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read remote backing file: %v", err)
|
||||
}
|
||||
if string(data) != "original remote" {
|
||||
t.Fatalf("remote backing file changed unexpectedly: %q", string(data))
|
||||
}
|
||||
assertNoTransferTemps(t, remotePath)
|
||||
}
|
||||
|
||||
func TestTransferOutContextRejectsRemoteDirectoryTarget(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
localPath := filepath.Join(root, "local.txt")
|
||||
remotePath := filepath.Join(root, "remote-dir")
|
||||
|
||||
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
if err := os.Mkdir(remotePath, 0o755); err != nil {
|
||||
t.Fatalf("mkdir remote target: %v", err)
|
||||
}
|
||||
|
||||
err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil))
|
||||
if err == nil || !strings.Contains(err.Error(), "directory") {
|
||||
t.Fatalf("expected directory rejection, got %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(remotePath)
|
||||
if err != nil {
|
||||
t.Fatalf("stat remote directory: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Fatal("expected remote target to remain a directory")
|
||||
}
|
||||
assertNoTransferTemps(t, remotePath)
|
||||
}
|
||||
|
||||
func TestTransferOutContextPreservesRemoteModeOnOverwrite(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
localPath := filepath.Join(root, "local.txt")
|
||||
remotePath := filepath.Join(root, "remote.txt")
|
||||
|
||||
if err := os.WriteFile(localPath, []byte("new payload"), 0o644); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(remotePath, []byte("original remote"), 0o755); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.Chmod(remotePath, 0o755); err != nil {
|
||||
t.Fatalf("chmod remote file: %v", err)
|
||||
}
|
||||
|
||||
if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
|
||||
t.Fatalf("transfer out: %v", err)
|
||||
}
|
||||
|
||||
assertMode(t, remotePath, 0o755)
|
||||
assertFileContent(t, remotePath, "new payload")
|
||||
}
|
||||
|
||||
func TestTransferOutContextAppliesLocalModeForNewRemoteFile(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
localPath := filepath.Join(root, "local.txt")
|
||||
remotePath := filepath.Join(root, "remote.txt")
|
||||
|
||||
if err := os.WriteFile(localPath, []byte("new payload"), 0o751); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
if err := os.Chmod(localPath, 0o751); err != nil {
|
||||
t.Fatalf("chmod local file: %v", err)
|
||||
}
|
||||
|
||||
if err := transferOutContext(context.Background(), client, localPath, remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
|
||||
t.Fatalf("transfer out: %v", err)
|
||||
}
|
||||
|
||||
assertMode(t, remotePath, 0o751)
|
||||
assertFileContent(t, remotePath, "new payload")
|
||||
}
|
||||
|
||||
func TestTransferOutByteContextPreservesRemoteModeOnOverwrite(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
remotePath := filepath.Join(root, "remote.txt")
|
||||
|
||||
if err := os.WriteFile(remotePath, []byte("original remote"), 0o755); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.Chmod(remotePath, 0o755); err != nil {
|
||||
t.Fatalf("chmod remote file: %v", err)
|
||||
}
|
||||
|
||||
if err := transferOutByteContext(context.Background(), client, []byte("byte payload"), remotePath, normalizeSFTPTransferOptions(nil)); err != nil {
|
||||
t.Fatalf("transfer out bytes: %v", err)
|
||||
}
|
||||
|
||||
assertMode(t, remotePath, 0o755)
|
||||
assertFileContent(t, remotePath, "byte payload")
|
||||
}
|
||||
|
||||
func TestTransferInContextVerifyFailurePreservesLocalTarget(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
srcPath := filepath.Join(root, "remote.txt")
|
||||
dstPath := filepath.Join(root, "local.txt")
|
||||
|
||||
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(dstPath, []byte("original local"), 0o644); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
|
||||
verifyErr := errors.New("verify local failed")
|
||||
var verifiedPath string
|
||||
oldVerifyLocalSize := sftpVerifyLocalSizeFunc
|
||||
sftpVerifyLocalSizeFunc = func(localPath string, expected int64) error {
|
||||
verifiedPath = localPath
|
||||
return verifyErr
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
sftpVerifyLocalSizeFunc = oldVerifyLocalSize
|
||||
})
|
||||
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
||||
if !errors.Is(err, verifyErr) {
|
||||
t.Fatalf("expected verify failure, got %v", err)
|
||||
}
|
||||
if verifiedPath == dstPath {
|
||||
t.Fatal("expected download verification to run against temp path before final rename")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(dstPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read local file: %v", err)
|
||||
}
|
||||
if string(data) != "original local" {
|
||||
t.Fatalf("local target was replaced on verify failure: %q", string(data))
|
||||
}
|
||||
assertNoTransferTemps(t, dstPath)
|
||||
}
|
||||
|
||||
func TestTransferInContextRejectsLocalSymlinkTarget(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
srcPath := filepath.Join(root, "remote.txt")
|
||||
localRealPath := filepath.Join(root, "local-real.txt")
|
||||
dstPath := filepath.Join(root, "local-link.txt")
|
||||
|
||||
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(localRealPath, []byte("original local"), 0o644); err != nil {
|
||||
t.Fatalf("write local backing file: %v", err)
|
||||
}
|
||||
if err := os.Symlink(localRealPath, dstPath); err != nil {
|
||||
t.Skipf("symlink unsupported: %v", err)
|
||||
}
|
||||
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
||||
if err == nil || !strings.Contains(err.Error(), "symlink") {
|
||||
t.Fatalf("expected symlink rejection, got %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Lstat(dstPath)
|
||||
if err != nil {
|
||||
t.Fatalf("lstat local symlink: %v", err)
|
||||
}
|
||||
if info.Mode()&os.ModeSymlink == 0 {
|
||||
t.Fatal("expected local target to remain a symlink")
|
||||
}
|
||||
assertFileContent(t, localRealPath, "original local")
|
||||
assertNoTransferTemps(t, dstPath)
|
||||
}
|
||||
|
||||
func TestTransferInContextRejectsLocalDirectoryTarget(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
srcPath := filepath.Join(root, "remote.txt")
|
||||
dstPath := filepath.Join(root, "local-dir")
|
||||
|
||||
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.Mkdir(dstPath, 0o755); err != nil {
|
||||
t.Fatalf("mkdir local target: %v", err)
|
||||
}
|
||||
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
||||
if err == nil || !strings.Contains(err.Error(), "directory") {
|
||||
t.Fatalf("expected directory rejection, got %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(dstPath)
|
||||
if err != nil {
|
||||
t.Fatalf("stat local directory: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Fatal("expected local target to remain a directory")
|
||||
}
|
||||
assertNoTransferTemps(t, dstPath)
|
||||
}
|
||||
|
||||
func TestTransferInContextPreservesLocalModeOnOverwrite(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
srcPath := filepath.Join(root, "remote.txt")
|
||||
dstPath := filepath.Join(root, "local.sh")
|
||||
|
||||
if err := os.WriteFile(srcPath, []byte("#!/bin/sh\necho remote\n"), 0o644); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(dstPath, []byte("#!/bin/sh\necho local\n"), 0o755); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
if err := os.Chmod(dstPath, 0o755); err != nil {
|
||||
t.Fatalf("chmod local file: %v", err)
|
||||
}
|
||||
|
||||
if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil {
|
||||
t.Fatalf("transfer in: %v", err)
|
||||
}
|
||||
|
||||
assertMode(t, dstPath, 0o755)
|
||||
assertFileContent(t, dstPath, "#!/bin/sh\necho remote\n")
|
||||
}
|
||||
|
||||
func TestTransferInContextAppliesRemoteModeForNewLocalFile(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
srcPath := filepath.Join(root, "remote.sh")
|
||||
dstPath := filepath.Join(root, "local.sh")
|
||||
|
||||
if err := os.WriteFile(srcPath, []byte("#!/bin/sh\necho remote\n"), 0o751); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.Chmod(srcPath, 0o751); err != nil {
|
||||
t.Fatalf("chmod remote file: %v", err)
|
||||
}
|
||||
|
||||
if err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil)); err != nil {
|
||||
t.Fatalf("transfer in: %v", err)
|
||||
}
|
||||
|
||||
assertMode(t, dstPath, 0o751)
|
||||
assertFileContent(t, dstPath, "#!/bin/sh\necho remote\n")
|
||||
}
|
||||
|
||||
func TestTransferInContextCopyFailurePreservesLocalTarget(t *testing.T) {
|
||||
client := newSFTPTestClient(t)
|
||||
root := t.TempDir()
|
||||
srcPath := filepath.Join(root, "remote.txt")
|
||||
dstPath := filepath.Join(root, "local.txt")
|
||||
|
||||
if err := os.WriteFile(srcPath, []byte("fresh remote payload"), 0o644); err != nil {
|
||||
t.Fatalf("write remote file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(dstPath, []byte("original local"), 0o644); err != nil {
|
||||
t.Fatalf("write local file: %v", err)
|
||||
}
|
||||
|
||||
copyErr := errors.New("copy failed")
|
||||
var copyTargetPath string
|
||||
oldCopy := sftpCopyWithProgressFunc
|
||||
sftpCopyWithProgressFunc = func(ctx context.Context, dst io.Writer, src io.Reader, bufSize int, total int64, progress func(float64)) (int64, error) {
|
||||
file, ok := dst.(*os.File)
|
||||
if !ok {
|
||||
t.Fatalf("expected local temp file writer, got %T", dst)
|
||||
}
|
||||
copyTargetPath = file.Name()
|
||||
|
||||
buf := make([]byte, 8)
|
||||
n, readErr := src.Read(buf)
|
||||
if readErr != nil && !errors.Is(readErr, io.EOF) {
|
||||
return 0, readErr
|
||||
}
|
||||
if n > 0 {
|
||||
written, err := dst.Write(buf[:n])
|
||||
if err != nil {
|
||||
return int64(written), err
|
||||
}
|
||||
return int64(written), copyErr
|
||||
}
|
||||
return 0, copyErr
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
sftpCopyWithProgressFunc = oldCopy
|
||||
})
|
||||
|
||||
err := transferInContext(context.Background(), client, srcPath, dstPath, normalizeSFTPTransferOptions(nil))
|
||||
if !errors.Is(err, copyErr) {
|
||||
t.Fatalf("expected copy failure, got %v", err)
|
||||
}
|
||||
if copyTargetPath == dstPath {
|
||||
t.Fatal("expected partial download writes to stay on temp path")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(dstPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read local file: %v", err)
|
||||
}
|
||||
if string(data) != "original local" {
|
||||
t.Fatalf("local target was modified by partial download: %q", string(data))
|
||||
}
|
||||
assertNoTransferTemps(t, dstPath)
|
||||
}
|
||||
|
||||
func newSFTPTestClient(t *testing.T) *sftp.Client {
|
||||
t.Helper()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
server, err := sftp.NewServer(serverConn)
|
||||
if err != nil {
|
||||
t.Fatalf("create sftp server: %v", err)
|
||||
}
|
||||
|
||||
serveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
serveErrCh <- server.Serve()
|
||||
}()
|
||||
|
||||
client, err := sftp.NewClientPipe(clientConn, clientConn)
|
||||
if err != nil {
|
||||
_ = server.Close()
|
||||
t.Fatalf("create sftp client: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = client.Close()
|
||||
_ = server.Close()
|
||||
serveErr := <-serveErrCh
|
||||
if serveErr == nil || errors.Is(serveErr, io.EOF) || normalizeAlreadyClosedError(serveErr) == nil {
|
||||
return
|
||||
}
|
||||
t.Errorf("unexpected sftp server error: %v", serveErr)
|
||||
})
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func assertNoTransferTemps(t *testing.T, targetPath string) {
|
||||
t.Helper()
|
||||
|
||||
matches, err := filepath.Glob(targetPath + defaultSFTPTempSuffix + "*")
|
||||
if err != nil {
|
||||
t.Fatalf("glob temp files: %v", err)
|
||||
}
|
||||
if len(matches) != 0 {
|
||||
t.Fatalf("expected temp artifacts to be cleaned up, got %v", matches)
|
||||
}
|
||||
}
|
||||
|
||||
func assertMode(t *testing.T, targetPath string, want os.FileMode) {
|
||||
t.Helper()
|
||||
|
||||
info, err := os.Stat(targetPath)
|
||||
if err != nil {
|
||||
t.Fatalf("stat %q: %v", targetPath, err)
|
||||
}
|
||||
if got := info.Mode().Perm(); got != want {
|
||||
t.Fatalf("unexpected mode for %q: got %o want %o", targetPath, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func assertFileContent(t *testing.T, targetPath string, want string) {
|
||||
t.Helper()
|
||||
|
||||
data, err := os.ReadFile(targetPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read %q: %v", targetPath, err)
|
||||
}
|
||||
if string(data) != want {
|
||||
t.Fatalf("unexpected content for %q: got %q want %q", targetPath, string(data), want)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user