feat: 增强 starssh 的 agent forwarding 与 tcp/unix 转发能力
- 为 LoginInput 增加 ForwardSSHAgent 配置,并在 Exec/PTTY 会话创建时按需自动请求 agent forwarding - 新增 agent_forward 运行时,封装本地 ssh-agent 建连、转发注册、显式请求与 unavailable/denied 语义 - 自动 agent forwarding 改为 best-effort:本地 agent 不可用、转发被拒绝或初始化失败时不再打断会话创建 - 为 StarSSH 增加 closing 状态与 agent forwarder 生命周期回收,避免 Close 与会话创建并发时泄漏资源 - 扩展 ForwardRequest 为带网络归一化的转发模型,支持 tcp/tcp4/tcp6/unix 端点组合 - 新增本地/远端 tcp<->unix、unix<->unix 及 detached helper,补齐 streamlocal 场景下的常用 API - 将显式网络地址编码收口为 tcp4://、tcp6://、unix://,消除 tcp:22 一类值的解析歧义 - 为本地 unix listener 增加 stale socket 探测、复用与关闭清理,避免遗留 socket 导致重启失败 - 补充 agent forwarding、关闭竞态、remote unix forward、local unix forward、stale socket 复用与端点解析等回归测试
This commit is contained in:
@@ -0,0 +1,417 @@
|
||||
package starssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
sshagent "golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
type testCloser struct {
|
||||
closed atomic.Int32
|
||||
}
|
||||
|
||||
func (c *testCloser) Close() error {
|
||||
c.closed.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
baseClient := &ssh.Client{}
|
||||
star := &StarSSH{
|
||||
LoginInfo: LoginInput{
|
||||
ForwardSSHAgent: true,
|
||||
Timeout: time.Second,
|
||||
},
|
||||
}
|
||||
star.setTransport(baseClient, nil)
|
||||
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
if client != baseClient {
|
||||
t.Fatalf("unexpected ssh client %p", client)
|
||||
}
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
|
||||
var agentInitCalls atomic.Int32
|
||||
closer := &testCloser{}
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
agentInitCalls.Add(1)
|
||||
if timeout != time.Second {
|
||||
t.Fatalf("unexpected forwarding timeout: %v", timeout)
|
||||
}
|
||||
return sshagent.NewKeyring(), closer, nil
|
||||
}
|
||||
|
||||
var routeCalls atomic.Int32
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
|
||||
routeCalls.Add(1)
|
||||
if client != baseClient {
|
||||
t.Fatalf("unexpected routed client %p", client)
|
||||
}
|
||||
if keyring == nil {
|
||||
t.Fatal("expected non-nil forwarded agent keyring")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var requestCalls atomic.Int32
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
requestCalls.Add(1)
|
||||
if session == nil {
|
||||
t.Fatal("expected non-nil ssh session")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := star.NewExecSession(); err != nil {
|
||||
t.Fatalf("first exec session: %v", err)
|
||||
}
|
||||
if _, err := star.NewExecSession(); err != nil {
|
||||
t.Fatalf("second exec session: %v", err)
|
||||
}
|
||||
|
||||
if agentInitCalls.Load() != 1 {
|
||||
t.Fatalf("expected one agent forwarder init, got %d", agentInitCalls.Load())
|
||||
}
|
||||
if routeCalls.Load() != 1 {
|
||||
t.Fatalf("expected one agent route registration, got %d", routeCalls.Load())
|
||||
}
|
||||
if requestCalls.Load() != 2 {
|
||||
t.Fatalf("expected agent forwarding request on each session, got %d", requestCalls.Load())
|
||||
}
|
||||
|
||||
closeSSHClient = func(client sshClientRequester) error { return nil }
|
||||
if err := star.Close(); err != nil {
|
||||
t.Fatalf("close starssh: %v", err)
|
||||
}
|
||||
if closer.closed.Load() != 1 {
|
||||
t.Fatalf("expected forwarded agent closer to run once, got %d", closer.closed.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldRequestSessionPTY := requestSessionPTY
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
requestSessionPTY = oldRequestSessionPTY
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{
|
||||
LoginInfo: LoginInput{
|
||||
ForwardSSHAgent: true,
|
||||
},
|
||||
}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
|
||||
var ptyCalls atomic.Int32
|
||||
requestSessionPTY = func(session *ssh.Session, config TerminalConfig) error {
|
||||
ptyCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
return sshagent.NewKeyring(), &testCloser{}, nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
|
||||
|
||||
var requestCalls atomic.Int32
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
requestCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := star.NewPTYSession(nil); err != nil {
|
||||
t.Fatalf("new pty session: %v", err)
|
||||
}
|
||||
if ptyCalls.Load() != 1 {
|
||||
t.Fatalf("expected one PTY request, got %d", ptyCalls.Load())
|
||||
}
|
||||
if requestCalls.Load() != 1 {
|
||||
t.Fatalf("expected one agent forwarding request, got %d", requestCalls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
t.Fatal("agent forwarder should not initialize when disabled")
|
||||
return nil, nil, nil
|
||||
}
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
t.Fatal("agent forwarding should not be requested when disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := star.NewExecSession(); err != nil {
|
||||
t.Fatalf("new exec session without forwarding: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||
}
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
t.Fatal("session request should not run when agent forwarder init fails")
|
||||
return nil
|
||||
}
|
||||
|
||||
err := star.RequestAgentForwarding(&ssh.Session{})
|
||||
if err == nil {
|
||||
t.Fatal("expected agent forwarding init error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
t.Cleanup(func() {
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
})
|
||||
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
|
||||
}
|
||||
|
||||
err := star.RequestAgentForwarding(&ssh.Session{})
|
||||
if !isSSHAgentForwardingUnavailableError(err) {
|
||||
t.Fatalf("expected unavailable error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
return sshagent.NewKeyring(), &testCloser{}, nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
return errors.New("forwarding request denied")
|
||||
}
|
||||
|
||||
err := star.RequestAgentForwarding(&ssh.Session{})
|
||||
if !isSSHAgentForwardingDeniedError(err) {
|
||||
t.Fatalf("expected forwarding denied error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
|
||||
})
|
||||
|
||||
star := &StarSSH{
|
||||
LoginInfo: LoginInput{
|
||||
ForwardSSHAgent: true,
|
||||
},
|
||||
}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
return sshagent.NewKeyring(), &testCloser{}, nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
|
||||
requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||||
return errors.New("forwarding request denied")
|
||||
}
|
||||
|
||||
if _, err := star.NewExecSession(); err != nil {
|
||||
t.Fatalf("new exec session should ignore denied agent forwarding: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
})
|
||||
|
||||
star := &StarSSH{
|
||||
LoginInfo: LoginInput{
|
||||
ForwardSSHAgent: true,
|
||||
},
|
||||
}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
|
||||
}
|
||||
|
||||
if _, err := star.NewExecSession(); err != nil {
|
||||
t.Fatalf("new exec session should ignore unavailable agent forwarding: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
|
||||
oldNewSSHSession := newSSHSession
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
t.Cleanup(func() {
|
||||
newSSHSession = oldNewSSHSession
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
})
|
||||
|
||||
star := &StarSSH{
|
||||
LoginInfo: LoginInput{
|
||||
ForwardSSHAgent: true,
|
||||
},
|
||||
}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
|
||||
return &ssh.Session{}, nil
|
||||
}
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
|
||||
}
|
||||
|
||||
if _, err := star.NewExecSession(); err != nil {
|
||||
t.Fatalf("new exec session should ignore agent setup error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
|
||||
oldNewSSHAgentForwarder := newSSHAgentForwarder
|
||||
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
|
||||
oldCloseSSHClient := closeSSHClient
|
||||
t.Cleanup(func() {
|
||||
newSSHAgentForwarder = oldNewSSHAgentForwarder
|
||||
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
|
||||
closeSSHClient = oldCloseSSHClient
|
||||
})
|
||||
|
||||
star := &StarSSH{
|
||||
LoginInfo: LoginInput{
|
||||
ForwardSSHAgent: true,
|
||||
},
|
||||
}
|
||||
star.setTransport(&ssh.Client{}, nil)
|
||||
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
closer := &testCloser{}
|
||||
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||||
close(started)
|
||||
<-release
|
||||
return sshagent.NewKeyring(), closer, nil
|
||||
}
|
||||
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
|
||||
return nil
|
||||
}
|
||||
closeSSHClient = func(client sshClientRequester) error { return nil }
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- star.ensureAgentForwarding()
|
||||
}()
|
||||
|
||||
<-started
|
||||
closeDone := make(chan struct{})
|
||||
go func() {
|
||||
_ = star.Close()
|
||||
close(closeDone)
|
||||
}()
|
||||
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for !star.closing.Load() {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("close did not enter closing state in time")
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
close(release)
|
||||
|
||||
err := <-errCh
|
||||
if !errors.Is(err, errSSHClientClosing) {
|
||||
t.Fatalf("expected closing error, got %v", err)
|
||||
}
|
||||
<-closeDone
|
||||
|
||||
if closer.closed.Load() != 1 {
|
||||
t.Fatalf("expected new forwarder closer to be closed once, got %d", closer.closed.Load())
|
||||
}
|
||||
if got := star.takeAgentForwarder(); got != nil {
|
||||
t.Fatal("expected no leaked agent forwarder after close race")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user