starssh/forward_test.go

699 lines
19 KiB
Go
Raw Permalink Normal View History

refactor: 重构 starssh 核心运行时并补强 ssh/exec/terminal/sftp 能力 - 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现 - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持 - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行 - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令 - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理 - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景 - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径 - 围绕新的 session/连接生命周期重做执行池与运行时支撑 - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义 - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性
2026-04-26 10:45:39 +08:00
package starssh
import (
"context"
"errors"
refactor: 重构 starssh 核心运行时并补强 ssh/exec/terminal/sftp 能力 - 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现 - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持 - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行 - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令 - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理 - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景 - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径 - 围绕新的 session/连接生命周期重做执行池与运行时支撑 - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义 - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性
2026-04-26 10:45:39 +08:00
"io"
"net"
"os"
"path/filepath"
"runtime"
"sync"
refactor: 重构 starssh 核心运行时并补强 ssh/exec/terminal/sftp 能力 - 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现 - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持 - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行 - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令 - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理 - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景 - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径 - 围绕新的 session/连接生命周期重做执行池与运行时支撑 - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义 - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性
2026-04-26 10:45:39 +08:00
"sync/atomic"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
type stubListener struct {
addr net.Addr
acceptCh chan net.Conn
closeCh chan struct{}
closeOnce sync.Once
}
type dialRecord struct {
network string
addr string
}
func newStubListener(addr net.Addr) *stubListener {
return &stubListener{
addr: addr,
acceptCh: make(chan net.Conn, 1),
closeCh: make(chan struct{}),
}
}
func (l *stubListener) Accept() (net.Conn, error) {
select {
case conn, ok := <-l.acceptCh:
if !ok {
return nil, io.EOF
}
return conn, nil
case <-l.closeCh:
return nil, net.ErrClosed
}
}
func (l *stubListener) Close() error {
l.closeOnce.Do(func() {
close(l.closeCh)
close(l.acceptCh)
})
return nil
}
func (l *stubListener) Addr() net.Addr {
return l.addr
}
func (l *stubListener) Push(conn net.Conn) error {
select {
case <-l.closeCh:
return net.ErrClosed
case l.acceptCh <- conn:
return nil
}
}
refactor: 重构 starssh 核心运行时并补强 ssh/exec/terminal/sftp 能力 - 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现 - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持 - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行 - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令 - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理 - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景 - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径 - 围绕新的 session/连接生命周期重做执行池与运行时支撑 - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义 - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性
2026-04-26 10:45:39 +08:00
func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) {
oldDialSSHClient := dialSSHClient
oldNewDetachedForwardClient := newDetachedForwardClient
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
newDetachedForwardClient = oldNewDetachedForwardClient
closeSSHClient = oldCloseSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
var detachedCalls atomic.Int32
newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
detachedCalls.Add(1)
return nil, nil
}
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
if client != baseClient {
t.Errorf("expected existing ssh client, got %p want %p", client, baseClient)
}
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
}
closeSSHClient = func(client sshClientRequester) error {
t.Fatal("default local forward should not close the main ssh client")
return nil
}
forwarder, err := star.StartLocalForward(ForwardRequest{
ListenAddr: "127.0.0.1:0",
TargetAddr: "example.internal:22",
})
if err != nil {
t.Fatalf("start local forward: %v", err)
}
defer forwarder.Close()
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("ping"))
if string(reply) != "ping" {
t.Fatalf("unexpected forwarded reply: %q", string(reply))
}
if detachedCalls.Load() != 0 {
t.Fatalf("default local forward should not create detached ssh client, calls=%d", detachedCalls.Load())
}
}
func TestForwardRequestLegacyPositionalLiteralDefaultsToTCP(t *testing.T) {
dialer := func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, nil
}
req, err := normalizeForwardRequest(ForwardRequest{
"127.0.0.1:10022",
"example.internal:22",
dialer,
})
if err != nil {
t.Fatalf("normalizeForwardRequest: %v", err)
}
if req.ListenNetwork != "tcp" {
t.Fatalf("ListenNetwork=%q want tcp", req.ListenNetwork)
}
if req.TargetNetwork != "tcp" {
t.Fatalf("TargetNetwork=%q want tcp", req.TargetNetwork)
}
if req.ListenAddr != "127.0.0.1:10022" || req.TargetAddr != "example.internal:22" {
t.Fatalf("unexpected normalized request: %+v", req)
}
if req.DialContext == nil {
t.Fatal("expected DialContext to be preserved")
}
}
func TestParseForwardEndpointTreatsTCPPrefixLikePlainAddress(t *testing.T) {
network, address, err := parseForwardEndpoint("tcp:22")
if err != nil {
t.Fatalf("parseForwardEndpoint: %v", err)
}
if network != "tcp" {
t.Fatalf("network=%q want tcp", network)
}
if address != "tcp:22" {
t.Fatalf("address=%q want tcp:22", address)
}
}
func TestParseForwardEndpointSupportsExplicitSchemes(t *testing.T) {
network, address, err := parseForwardEndpoint("unix:///tmp/test-forward.sock")
if err != nil {
t.Fatalf("parseForwardEndpoint unix: %v", err)
}
if network != "unix" || address != "/tmp/test-forward.sock" {
t.Fatalf("unexpected unix endpoint parse: network=%q address=%q", network, address)
}
network, address, err = parseForwardEndpoint("tcp6://[::1]:2222")
if err != nil {
t.Fatalf("parseForwardEndpoint tcp6: %v", err)
}
if network != "tcp6" || address != "[::1]:2222" {
t.Fatalf("unexpected tcp6 endpoint parse: network=%q address=%q", network, address)
}
}
refactor: 重构 starssh 核心运行时并补强 ssh/exec/terminal/sftp 能力 - 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现 - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持 - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行 - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令 - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理 - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景 - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径 - 围绕新的 session/连接生命周期重做执行池与运行时支撑 - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义 - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性
2026-04-26 10:45:39 +08:00
func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) {
oldDialSSHClient := dialSSHClient
oldNewDetachedForwardClient := newDetachedForwardClient
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
newDetachedForwardClient = oldNewDetachedForwardClient
closeSSHClient = oldCloseSSHClient
})
baseClient := &ssh.Client{}
detachedClient := &ssh.Client{}
star := &StarSSH{LoginInfo: LoginInput{User: "tester", Addr: "127.0.0.1"}}
star.setTransport(baseClient, nil)
forwardClient := &StarSSH{}
forwardClient.setTransport(detachedClient, nil)
var detachedCalls atomic.Int32
newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
detachedCalls.Add(1)
return forwardClient, nil
}
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
if client != detachedClient {
t.Errorf("expected detached ssh client, got %p want %p", client, detachedClient)
}
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
}
var closeCalls atomic.Int32
closeSSHClient = func(client sshClientRequester) error {
closeCalls.Add(1)
return nil
}
forwarder, err := star.StartLocalForwardDetached(ForwardRequest{
ListenAddr: "127.0.0.1:0",
TargetAddr: "example.internal:22",
})
if err != nil {
t.Fatalf("start detached local forward: %v", err)
}
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("pong"))
if string(reply) != "pong" {
t.Fatalf("unexpected detached forwarded reply: %q", string(reply))
}
if err := forwarder.Close(); err != nil {
t.Fatalf("close detached local forward: %v", err)
}
if detachedCalls.Load() != 1 {
t.Fatalf("expected one detached ssh login, got %d", detachedCalls.Load())
}
if closeCalls.Load() != 1 {
t.Fatalf("expected detached ssh client cleanup once, got %d", closeCalls.Load())
}
if got := star.snapshotSSHClient(); got != baseClient {
t.Fatal("detached local forward should not detach the main ssh client")
}
if got := forwardClient.snapshotSSHClient(); got != nil {
t.Fatal("detached local forward should close its detached ssh client")
}
}
func TestStartRemoteForwardSupportsUnixListenAndTCPTarget(t *testing.T) {
oldListenSSHClient := listenSSHClient
t.Cleanup(func() {
listenSSHClient = oldListenSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
listener := newStubListener(&net.UnixAddr{
Name: "/run/user/0/gnupg/S.gpg-agent",
Net: "unix",
})
var listenedNetwork string
var listenedAddr string
listenSSHClient = func(client *ssh.Client, network, address string) (net.Listener, error) {
if client != baseClient {
t.Fatalf("unexpected ssh client %p", client)
}
listenedNetwork = network
listenedAddr = address
return listener, nil
}
var targetNetwork string
var targetAddr string
forwarder, err := star.StartRemoteForward(ForwardRequest{
ListenAddr: forwardEndpoint("unix", "/run/user/0/gnupg/S.gpg-agent"),
TargetAddr: "127.0.0.1:4321",
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
targetNetwork = network
targetAddr = address
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
},
})
if err != nil {
t.Fatalf("start remote unix forward: %v", err)
}
defer forwarder.Close()
srcPeer, forwardedConn := net.Pipe()
defer srcPeer.Close()
if err := listener.Push(forwardedConn); err != nil {
t.Fatalf("push forwarded connection: %v", err)
}
payload := []byte("unix-forward")
done := make(chan []byte, 1)
go func() {
reply := make([]byte, len(payload))
_, _ = io.ReadFull(srcPeer, reply)
done <- reply
}()
if _, err := srcPeer.Write(payload); err != nil {
t.Fatalf("write source payload: %v", err)
}
select {
case reply := <-done:
if string(reply) != string(payload) {
t.Fatalf("unexpected remote unix forward reply: %q", string(reply))
}
case <-time.After(2 * time.Second):
t.Fatal("remote unix forward did not relay payload")
}
if listenedNetwork != "unix" || listenedAddr != "/run/user/0/gnupg/S.gpg-agent" {
t.Fatalf("unexpected remote listen request: network=%q addr=%q", listenedNetwork, listenedAddr)
}
if targetNetwork != "tcp" || targetAddr != "127.0.0.1:4321" {
t.Fatalf("unexpected local dial target: network=%q addr=%q", targetNetwork, targetAddr)
}
}
func TestStartLocalUnixForwardUsesUnixListenerAndTCPTarget(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
}
oldDialSSHClient := dialSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
var targetNetwork string
var targetAddr string
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
if client != baseClient {
t.Fatalf("unexpected ssh client %p", client)
}
targetNetwork = network
targetAddr = address
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
}
socketPath := filepath.Join(t.TempDir(), "forward.sock")
forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321")
if err != nil {
t.Fatalf("start local unix forward: %v", err)
}
defer func() {
closeErr := forwarder.Close()
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
t.Fatalf("close local unix forward: %v", closeErr)
}
}()
conn, err := net.DialTimeout("unix", socketPath, time.Second)
if err != nil {
t.Fatalf("dial unix forward listener: %v", err)
}
defer conn.Close()
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
payload := []byte("unix-local-forward")
if _, err := conn.Write(payload); err != nil {
t.Fatalf("write unix forward payload: %v", err)
}
reply := make([]byte, len(payload))
if _, err := io.ReadFull(conn, reply); err != nil {
t.Fatalf("read unix forward reply: %v", err)
}
if string(reply) != string(payload) {
t.Fatalf("unexpected unix forward reply: %q", string(reply))
}
if targetNetwork != "tcp" || targetAddr != "127.0.0.1:4321" {
t.Fatalf("unexpected remote dial target: network=%q addr=%q", targetNetwork, targetAddr)
}
}
func TestStartLocalUnixForwardRemovesSocketOnClose(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
}
oldDialSSHClient := dialSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
}
socketPath := filepath.Join(t.TempDir(), "cleanup.sock")
forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321")
if err != nil {
t.Fatalf("start local unix forward: %v", err)
}
if _, err := os.Lstat(socketPath); err != nil {
t.Fatalf("socket should exist while forward is running: %v", err)
}
if err := forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
t.Fatalf("close local unix forward: %v", err)
}
if _, err := os.Lstat(socketPath); !errors.Is(err, os.ErrNotExist) {
t.Fatalf("socket path should be removed on close, got err=%v", err)
}
}
func TestStartLocalUnixForwardReusesStaleSocketPath(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
}
oldDialSSHClient := dialSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
}
socketPath := filepath.Join(t.TempDir(), "stale.sock")
staleListener, err := net.ListenUnix("unix", &net.UnixAddr{
Name: socketPath,
Net: "unix",
})
if err != nil {
t.Fatalf("create stale unix socket: %v", err)
}
staleListener.SetUnlinkOnClose(false)
if err := staleListener.Close(); err != nil {
t.Fatalf("close stale unix socket listener: %v", err)
}
if _, err := os.Lstat(socketPath); err != nil {
t.Fatalf("expected stale unix socket path to remain after close: %v", err)
}
forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321")
if err != nil {
t.Fatalf("start local unix forward on stale socket path: %v", err)
}
defer func() {
closeErr := forwarder.Close()
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
t.Fatalf("close local unix forward: %v", closeErr)
}
}()
reply := make([]byte, len("stale-reuse"))
conn, err := net.DialTimeout("unix", socketPath, time.Second)
if err != nil {
t.Fatalf("dial reused unix forward listener: %v", err)
}
defer conn.Close()
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
if _, err := conn.Write([]byte("stale-reuse")); err != nil {
t.Fatalf("write reused unix forward payload: %v", err)
}
if _, err := io.ReadFull(conn, reply); err != nil {
t.Fatalf("read reused unix forward reply: %v", err)
}
if string(reply) != "stale-reuse" {
t.Fatalf("unexpected reply on reused unix forward: %q", string(reply))
}
}
func TestStartLocalUnixToUnixForwardUsesUnixTarget(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
}
oldDialSSHClient := dialSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
targetSocketPath := filepath.Join(t.TempDir(), "target.sock")
targetListener, err := net.Listen("unix", targetSocketPath)
if err != nil {
t.Fatalf("listen target unix socket: %v", err)
}
defer targetListener.Close()
done := make(chan []byte, 1)
go func() {
conn, acceptErr := targetListener.Accept()
if acceptErr != nil {
done <- nil
return
}
defer conn.Close()
buf := make([]byte, 64)
n, _ := conn.Read(buf)
_, _ = conn.Write(buf[:n])
done <- buf[:n]
}()
dialRecordCh := make(chan dialRecord, 1)
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
if client != baseClient {
t.Fatalf("unexpected ssh client %p", client)
}
dialRecordCh <- dialRecord{network: network, addr: address}
var dialer net.Dialer
return dialer.DialContext(ctx, network, address)
}
listenSocketPath := filepath.Join(t.TempDir(), "listen.sock")
forwarder, err := star.StartLocalUnixToUnixForward(listenSocketPath, targetSocketPath)
if err != nil {
t.Fatalf("start local unix-to-unix forward: %v", err)
}
defer func() {
closeErr := forwarder.Close()
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
t.Fatalf("close local unix-to-unix forward: %v", closeErr)
}
}()
conn, err := net.DialTimeout("unix", listenSocketPath, time.Second)
if err != nil {
t.Fatalf("dial unix-to-unix listener: %v", err)
}
defer conn.Close()
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
payload := []byte("unix-to-unix")
if _, err := conn.Write(payload); err != nil {
t.Fatalf("write unix-to-unix payload: %v", err)
}
reply := make([]byte, len(payload))
if _, err := io.ReadFull(conn, reply); err != nil {
t.Fatalf("read unix-to-unix reply: %v", err)
}
if string(reply) != string(payload) {
t.Fatalf("unexpected unix-to-unix reply: %q", string(reply))
}
select {
case got := <-done:
if string(got) != string(payload) {
t.Fatalf("unexpected payload seen by target unix socket: %q", string(got))
}
case <-time.After(2 * time.Second):
t.Fatal("target unix socket did not receive forwarded payload")
}
select {
case got := <-dialRecordCh:
if got.network != "unix" || got.addr != targetSocketPath {
t.Fatalf("unexpected unix target dial: network=%q addr=%q", got.network, got.addr)
}
case <-time.After(2 * time.Second):
t.Fatal("did not observe unix target dial")
}
}
func TestStartLocalTCPToUnixForwardUsesUnixTarget(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
}
oldDialSSHClient := dialSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
targetSocketPath := filepath.Join(t.TempDir(), "target-tcp-to-unix.sock")
targetListener, err := net.Listen("unix", targetSocketPath)
if err != nil {
t.Fatalf("listen target unix socket: %v", err)
}
defer targetListener.Close()
done := make(chan []byte, 1)
go func() {
conn, acceptErr := targetListener.Accept()
if acceptErr != nil {
done <- nil
return
}
defer conn.Close()
buf := make([]byte, 64)
n, _ := conn.Read(buf)
_, _ = conn.Write(buf[:n])
done <- buf[:n]
}()
dialRecordCh := make(chan dialRecord, 1)
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
if client != baseClient {
t.Fatalf("unexpected ssh client %p", client)
}
dialRecordCh <- dialRecord{network: network, addr: address}
var dialer net.Dialer
return dialer.DialContext(ctx, network, address)
}
forwarder, err := star.StartLocalTCPToUnixForward("127.0.0.1:0", targetSocketPath)
if err != nil {
t.Fatalf("start local tcp-to-unix forward: %v", err)
}
defer func() {
closeErr := forwarder.Close()
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
t.Fatalf("close local tcp-to-unix forward: %v", closeErr)
}
}()
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("tcp-to-unix"))
if string(reply) != "tcp-to-unix" {
t.Fatalf("unexpected tcp-to-unix reply: %q", string(reply))
}
select {
case got := <-done:
if string(got) != "tcp-to-unix" {
t.Fatalf("unexpected payload seen by unix target: %q", string(got))
}
case <-time.After(2 * time.Second):
t.Fatal("unix target did not receive forwarded tcp payload")
}
select {
case got := <-dialRecordCh:
if got.network != "unix" || got.addr != targetSocketPath {
t.Fatalf("unexpected unix target dial: network=%q addr=%q", got.network, got.addr)
}
case <-time.After(2 * time.Second):
t.Fatal("did not observe unix target dial")
}
}
refactor: 重构 starssh 核心运行时并补强 ssh/exec/terminal/sftp 能力 - 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现 - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持 - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行 - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令 - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理 - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景 - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径 - 围绕新的 session/连接生命周期重做执行池与运行时支撑 - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义 - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性
2026-04-26 10:45:39 +08:00
func echoForwardPipe(conn net.Conn) {
defer conn.Close()
buf := make([]byte, 4096)
n, err := conn.Read(buf)
if err != nil {
return
}
_, _ = conn.Write(buf[:n])
}
func exerciseForwarder(t *testing.T, addr string, payload []byte) []byte {
t.Helper()
conn, err := net.DialTimeout("tcp", addr, time.Second)
if err != nil {
t.Fatalf("dial forward listener: %v", err)
}
defer conn.Close()
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
if _, err := conn.Write(payload); err != nil {
t.Fatalf("write forwarded payload: %v", err)
}
reply := make([]byte, len(payload))
if _, err := io.ReadFull(conn, reply); err != nil {
t.Fatalf("read forwarded reply: %v", err)
}
return reply
}