feat: 增强 starssh 的 agent forwarding 与 tcp/unix 转发能力

- 为 LoginInput 增加 ForwardSSHAgent 配置,并在 Exec/PTTY 会话创建时按需自动请求 agent forwarding
- 新增 agent_forward 运行时,封装本地 ssh-agent 建连、转发注册、显式请求与 unavailable/denied 语义
- 自动 agent forwarding 改为 best-effort:本地 agent 不可用、转发被拒绝或初始化失败时不再打断会话创建
- 为 StarSSH 增加 closing 状态与 agent forwarder 生命周期回收,避免 Close 与会话创建并发时泄漏资源
- 扩展 ForwardRequest 为带网络归一化的转发模型,支持 tcp/tcp4/tcp6/unix 端点组合
- 新增本地/远端 tcp<->unix、unix<->unix 及 detached helper,补齐 streamlocal 场景下的常用 API
- 将显式网络地址编码收口为 tcp4://、tcp6://、unix://,消除 tcp:22 一类值的解析歧义
- 为本地 unix listener 增加 stale socket 探测、复用与关闭清理,避免遗留 socket 导致重启失败
- 补充 agent forwarding、关闭竞态、remote unix forward、local unix forward、stale socket 复用与端点解析等回归测试
This commit is contained in:
兔子 2026-04-26 20:27:10 +08:00
parent f20eb653ae
commit b29246a9c4
Signed by: b612
GPG Key ID: 99DD2222B612B612
7 changed files with 1463 additions and 45 deletions

151
agent_forward.go Normal file
View File

@ -0,0 +1,151 @@
package starssh
import (
"errors"
"fmt"
"io"
"strings"
"time"
"golang.org/x/crypto/ssh"
sshagent "golang.org/x/crypto/ssh/agent"
)
var requestSSHAgentForwarding = func(session *ssh.Session) error {
return sshagent.RequestAgentForwarding(session)
}
var routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
return sshagent.ForwardToAgent(client, keyring)
}
var newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
conn, err := dialSSHAgent(timeout)
if err != nil {
return nil, nil, wrapSSHAgentForwardingUnavailable(err)
}
if conn == nil {
return nil, nil, wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
}
return sshagent.NewClient(conn), conn, nil
}
var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied")
var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable")
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
if s == nil {
return errors.New("ssh client is nil")
}
if session == nil {
return errors.New("ssh session is nil")
}
if err := s.ensureAgentForwarding(); err != nil {
return err
}
if err := requestSSHAgentForwarding(session); err != nil {
if isSSHAgentForwardingDeniedError(err) {
return fmt.Errorf("%w: %v", errSSHAgentForwardingDenied, err)
}
return err
}
return nil
}
func (s *StarSSH) maybeRequestAgentForwarding(session *ssh.Session) error {
if s == nil || !s.LoginInfo.ForwardSSHAgent {
return nil
}
err := s.RequestAgentForwarding(session)
if isSSHAgentForwardingDeniedError(err) || isSSHAgentForwardingUnavailableError(err) {
return nil
}
return err
}
func (s *StarSSH) ensureAgentForwarding() error {
if s == nil {
return errors.New("ssh client is nil")
}
s.agentForwardMu.Lock()
defer s.agentForwardMu.Unlock()
if s.agentForwarder != nil {
return nil
}
client, err := s.requireSSHClient()
if err != nil {
return err
}
keyring, closer, err := newSSHAgentForwarder(s.LoginInfo.Timeout)
if err != nil {
return wrapSSHAgentForwardingUnavailable(err)
}
if s.closing.Load() {
_ = closer.Close()
return errSSHClientClosing
}
if err := routeSSHAgentForwarding(client, keyring); err != nil {
_ = closer.Close()
return err
}
if !s.canAttachAgentForwarder(client) {
_ = closer.Close()
return errSSHClientClosing
}
s.agentForwarder = closer
return nil
}
func (s *StarSSH) takeAgentForwarder() io.Closer {
if s == nil {
return nil
}
s.agentForwardMu.Lock()
defer s.agentForwardMu.Unlock()
closer := s.agentForwarder
s.agentForwarder = nil
return closer
}
func isSSHAgentForwardingDeniedError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, errSSHAgentForwardingDenied) {
return true
}
message := strings.ToLower(err.Error())
return strings.Contains(message, "forwarding request denied") ||
strings.Contains(message, "agent forwarding disabled")
}
func isSSHAgentForwardingUnavailableError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, errSSHAgentForwardingUnavailable) {
return true
}
message := strings.ToLower(err.Error())
return strings.Contains(message, "ssh-agent forwarding unavailable") ||
strings.Contains(message, "ssh-agent unavailable")
}
func wrapSSHAgentForwardingUnavailable(err error) error {
if err == nil {
return nil
}
if errors.Is(err, errSSHAgentForwardingUnavailable) {
return err
}
if errors.Is(err, errSSHAgentUnavailable) {
return fmt.Errorf("%w: %w", errSSHAgentForwardingUnavailable, err)
}
return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err)
}

417
agent_forward_test.go Normal file
View File

@ -0,0 +1,417 @@
package starssh
import (
"errors"
"io"
"sync/atomic"
"testing"
"time"
"golang.org/x/crypto/ssh"
sshagent "golang.org/x/crypto/ssh/agent"
)
type testCloser struct {
closed atomic.Int32
}
func (c *testCloser) Close() error {
c.closed.Add(1)
return nil
}
func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) {
oldNewSSHSession := newSSHSession
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
newSSHAgentForwarder = oldNewSSHAgentForwarder
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
closeSSHClient = oldCloseSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{
LoginInfo: LoginInput{
ForwardSSHAgent: true,
Timeout: time.Second,
},
}
star.setTransport(baseClient, nil)
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
if client != baseClient {
t.Fatalf("unexpected ssh client %p", client)
}
return &ssh.Session{}, nil
}
var agentInitCalls atomic.Int32
closer := &testCloser{}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
agentInitCalls.Add(1)
if timeout != time.Second {
t.Fatalf("unexpected forwarding timeout: %v", timeout)
}
return sshagent.NewKeyring(), closer, nil
}
var routeCalls atomic.Int32
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
routeCalls.Add(1)
if client != baseClient {
t.Fatalf("unexpected routed client %p", client)
}
if keyring == nil {
t.Fatal("expected non-nil forwarded agent keyring")
}
return nil
}
var requestCalls atomic.Int32
requestSSHAgentForwarding = func(session *ssh.Session) error {
requestCalls.Add(1)
if session == nil {
t.Fatal("expected non-nil ssh session")
}
return nil
}
if _, err := star.NewExecSession(); err != nil {
t.Fatalf("first exec session: %v", err)
}
if _, err := star.NewExecSession(); err != nil {
t.Fatalf("second exec session: %v", err)
}
if agentInitCalls.Load() != 1 {
t.Fatalf("expected one agent forwarder init, got %d", agentInitCalls.Load())
}
if routeCalls.Load() != 1 {
t.Fatalf("expected one agent route registration, got %d", routeCalls.Load())
}
if requestCalls.Load() != 2 {
t.Fatalf("expected agent forwarding request on each session, got %d", requestCalls.Load())
}
closeSSHClient = func(client sshClientRequester) error { return nil }
if err := star.Close(); err != nil {
t.Fatalf("close starssh: %v", err)
}
if closer.closed.Load() != 1 {
t.Fatalf("expected forwarded agent closer to run once, got %d", closer.closed.Load())
}
}
func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) {
oldNewSSHSession := newSSHSession
oldRequestSessionPTY := requestSessionPTY
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
requestSessionPTY = oldRequestSessionPTY
newSSHAgentForwarder = oldNewSSHAgentForwarder
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
})
star := &StarSSH{
LoginInfo: LoginInput{
ForwardSSHAgent: true,
},
}
star.setTransport(&ssh.Client{}, nil)
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
var ptyCalls atomic.Int32
requestSessionPTY = func(session *ssh.Session, config TerminalConfig) error {
ptyCalls.Add(1)
return nil
}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
return sshagent.NewKeyring(), &testCloser{}, nil
}
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
var requestCalls atomic.Int32
requestSSHAgentForwarding = func(session *ssh.Session) error {
requestCalls.Add(1)
return nil
}
if _, err := star.NewPTYSession(nil); err != nil {
t.Fatalf("new pty session: %v", err)
}
if ptyCalls.Load() != 1 {
t.Fatalf("expected one PTY request, got %d", ptyCalls.Load())
}
if requestCalls.Load() != 1 {
t.Fatalf("expected one agent forwarding request, got %d", requestCalls.Load())
}
}
func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) {
oldNewSSHSession := newSSHSession
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
newSSHAgentForwarder = oldNewSSHAgentForwarder
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
})
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
t.Fatal("agent forwarder should not initialize when disabled")
return nil, nil, nil
}
requestSSHAgentForwarding = func(session *ssh.Session) error {
t.Fatal("agent forwarding should not be requested when disabled")
return nil
}
if _, err := star.NewExecSession(); err != nil {
t.Fatalf("new exec session without forwarding: %v", err)
}
}
func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) {
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() {
newSSHAgentForwarder = oldNewSSHAgentForwarder
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
})
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
}
requestSSHAgentForwarding = func(session *ssh.Session) error {
t.Fatal("session request should not run when agent forwarder init fails")
return nil
}
err := star.RequestAgentForwarding(&ssh.Session{})
if err == nil {
t.Fatal("expected agent forwarding init error")
}
}
func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) {
oldNewSSHAgentForwarder := newSSHAgentForwarder
t.Cleanup(func() {
newSSHAgentForwarder = oldNewSSHAgentForwarder
})
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied")
}
err := star.RequestAgentForwarding(&ssh.Session{})
if !isSSHAgentForwardingUnavailableError(err) {
t.Fatalf("expected unavailable error, got %v", err)
}
}
func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) {
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() {
newSSHAgentForwarder = oldNewSSHAgentForwarder
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
})
star := &StarSSH{}
star.setTransport(&ssh.Client{}, nil)
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
return sshagent.NewKeyring(), &testCloser{}, nil
}
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
requestSSHAgentForwarding = func(session *ssh.Session) error {
return errors.New("forwarding request denied")
}
err := star.RequestAgentForwarding(&ssh.Session{})
if !isSSHAgentForwardingDeniedError(err) {
t.Fatalf("expected forwarding denied error, got %v", err)
}
}
func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) {
oldNewSSHSession := newSSHSession
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldRequestSSHAgentForwarding := requestSSHAgentForwarding
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
newSSHAgentForwarder = oldNewSSHAgentForwarder
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
requestSSHAgentForwarding = oldRequestSSHAgentForwarding
})
star := &StarSSH{
LoginInfo: LoginInput{
ForwardSSHAgent: true,
},
}
star.setTransport(&ssh.Client{}, nil)
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
return sshagent.NewKeyring(), &testCloser{}, nil
}
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil }
requestSSHAgentForwarding = func(session *ssh.Session) error {
return errors.New("forwarding request denied")
}
if _, err := star.NewExecSession(); err != nil {
t.Fatalf("new exec session should ignore denied agent forwarding: %v", err)
}
}
func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) {
oldNewSSHSession := newSSHSession
oldNewSSHAgentForwarder := newSSHAgentForwarder
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
newSSHAgentForwarder = oldNewSSHAgentForwarder
})
star := &StarSSH{
LoginInfo: LoginInput{
ForwardSSHAgent: true,
},
}
star.setTransport(&ssh.Client{}, nil)
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable")
}
if _, err := star.NewExecSession(); err != nil {
t.Fatalf("new exec session should ignore unavailable agent forwarding: %v", err)
}
}
func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) {
oldNewSSHSession := newSSHSession
oldNewSSHAgentForwarder := newSSHAgentForwarder
t.Cleanup(func() {
newSSHSession = oldNewSSHSession
newSSHAgentForwarder = oldNewSSHAgentForwarder
})
star := &StarSSH{
LoginInfo: LoginInput{
ForwardSSHAgent: true,
},
}
star.setTransport(&ssh.Client{}, nil)
newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return &ssh.Session{}, nil
}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused")
}
if _, err := star.NewExecSession(); err != nil {
t.Fatalf("new exec session should ignore agent setup error: %v", err)
}
}
func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) {
oldNewSSHAgentForwarder := newSSHAgentForwarder
oldRouteSSHAgentForwarding := routeSSHAgentForwarding
oldCloseSSHClient := closeSSHClient
t.Cleanup(func() {
newSSHAgentForwarder = oldNewSSHAgentForwarder
routeSSHAgentForwarding = oldRouteSSHAgentForwarding
closeSSHClient = oldCloseSSHClient
})
star := &StarSSH{
LoginInfo: LoginInput{
ForwardSSHAgent: true,
},
}
star.setTransport(&ssh.Client{}, nil)
started := make(chan struct{})
release := make(chan struct{})
closer := &testCloser{}
newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
close(started)
<-release
return sshagent.NewKeyring(), closer, nil
}
routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
return nil
}
closeSSHClient = func(client sshClientRequester) error { return nil }
errCh := make(chan error, 1)
go func() {
errCh <- star.ensureAgentForwarding()
}()
<-started
closeDone := make(chan struct{})
go func() {
_ = star.Close()
close(closeDone)
}()
deadline := time.Now().Add(time.Second)
for !star.closing.Load() {
if time.Now().After(deadline) {
t.Fatal("close did not enter closing state in time")
}
time.Sleep(time.Millisecond)
}
close(release)
err := <-errCh
if !errors.Is(err, errSSHClientClosing) {
t.Fatalf("expected closing error, got %v", err)
}
<-closeDone
if closer.closed.Load() != 1 {
t.Fatalf("expected new forwarder closer to be closed once, got %d", closer.closed.Load())
}
if got := star.takeAgentForwarder(); got != nil {
t.Fatal("expected no leaked agent forwarder after close race")
}
}

View File

@ -3,21 +3,39 @@ package starssh
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"strconv"
"strings"
"sync"
"syscall"
"time"
"golang.org/x/crypto/ssh"
)
type ForwardRequest struct {
// Keep the exported shape compatible with older positional literals:
// ForwardRequest{listenAddr, targetAddr, dialContext}.
//
// Non-default networks can be encoded with an explicit scheme-like prefix:
// "tcp4://127.0.0.1:22", "tcp6://[::1]:22", "unix:///tmp/socket".
// Bare values default to the "tcp" network.
ListenAddr string
TargetAddr string
DialContext DialContextFunc
}
type normalizedForwardRequest struct {
ListenNetwork string
ListenAddr string
TargetNetwork string
TargetAddr string
DialContext DialContextFunc
}
type DynamicForwardRequest struct {
ListenAddr string
}
@ -41,10 +59,16 @@ type PortForwarder struct {
cleanupFns []func() error
}
const unixForwardProbeTimeout = 200 * time.Millisecond
var dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
return client.Dial(network, address)
}
var listenSSHClient = func(client *ssh.Client, network, address string) (net.Listener, error) {
return client.Listen(network, address)
}
var newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) {
if ctx == nil {
ctx = context.Background()
@ -64,6 +88,90 @@ func (s *StarSSH) DialTCPContextCloseOnCancel(ctx context.Context, network strin
return s.dialTCPContext(ctx, network, address, s.Close)
}
func (s *StarSSH) StartLocalTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) {
return s.StartLocalForward(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartLocalTCPForwardDetached(listenAddr string, targetAddr string) (*PortForwarder, error) {
return s.StartLocalForwardDetached(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartLocalTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) {
return s.StartLocalForward(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) StartLocalTCPToUnixForwardDetached(listenAddr string, targetPath string) (*PortForwarder, error) {
return s.StartLocalForwardDetached(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) StartLocalUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) {
return s.StartLocalForward(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartLocalUnixForwardDetached(listenPath string, targetAddr string) (*PortForwarder, error) {
return s.StartLocalForwardDetached(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartLocalUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) {
return s.StartLocalForward(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) StartLocalUnixToUnixForwardDetached(listenPath string, targetPath string) (*PortForwarder, error) {
return s.StartLocalForwardDetached(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) StartRemoteTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) {
return s.StartRemoteForward(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartRemoteTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) {
return s.StartRemoteForward(ForwardRequest{
ListenAddr: listenAddr,
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) StartRemoteUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) {
return s.StartRemoteForward(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: targetAddr,
})
}
func (s *StarSSH) StartRemoteUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) {
return s.StartRemoteForward(ForwardRequest{
ListenAddr: forwardEndpoint("unix", listenPath),
TargetAddr: forwardEndpoint("unix", targetPath),
})
}
func (s *StarSSH) dialTCPContext(ctx context.Context, network string, address string, onCancel func() error) (net.Conn, error) {
if ctx == nil {
ctx = context.Background()
@ -136,21 +244,22 @@ func (s *StarSSH) StartLocalForward(req ForwardRequest) (*PortForwarder, error)
if _, err := s.requireSSHClient(); err != nil {
return nil, err
}
if strings.TrimSpace(req.ListenAddr) == "" {
normalizedReq, err := normalizeForwardRequest(req)
if err != nil {
return nil, err
}
if strings.TrimSpace(normalizedReq.ListenAddr) == "" {
return nil, errors.New("local forward listen address is empty")
}
if strings.TrimSpace(req.TargetAddr) == "" {
return nil, errors.New("local forward target address is empty")
}
listener, err := net.Listen("tcp", req.ListenAddr)
listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
if err != nil {
return nil, err
}
forwarder := newPortForwarder(listener)
forwarder.addCleanup(cleanup)
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
return s.DialTCPContext(ctx, "tcp", req.TargetAddr)
return s.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
})
return forwarder, nil
}
@ -159,14 +268,12 @@ func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder,
if _, err := s.requireSSHClient(); err != nil {
return nil, err
}
if strings.TrimSpace(req.ListenAddr) == "" {
return nil, errors.New("local forward listen address is empty")
}
if strings.TrimSpace(req.TargetAddr) == "" {
return nil, errors.New("local forward target address is empty")
normalizedReq, err := normalizeForwardRequest(req)
if err != nil {
return nil, err
}
listener, err := net.Listen("tcp", req.ListenAddr)
listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
if err != nil {
return nil, err
}
@ -174,15 +281,19 @@ func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder,
forwardClient, err := s.newForwardDialClient(context.Background())
if err != nil {
_ = listener.Close()
if cleanup != nil {
_ = cleanup()
}
return nil, err
}
forwarder := newPortForwarder(listener)
forwarder.addCleanup(cleanup)
forwarder.addCleanup(func() error {
return normalizeAlreadyClosedError(forwardClient.Close())
})
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
return forwardClient.DialTCPContext(ctx, "tcp", req.TargetAddr)
return forwardClient.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
})
return forwarder, nil
}
@ -192,19 +303,17 @@ func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error)
if err != nil {
return nil, err
}
if strings.TrimSpace(req.ListenAddr) == "" {
return nil, errors.New("remote forward listen address is empty")
}
if strings.TrimSpace(req.TargetAddr) == "" {
return nil, errors.New("remote forward target address is empty")
}
listener, err := client.Listen("tcp", req.ListenAddr)
normalizedReq, err := normalizeForwardRequest(req)
if err != nil {
return nil, err
}
dialContext := req.DialContext
listener, err := listenSSHClient(client, normalizedReq.ListenNetwork, normalizedReq.ListenAddr)
if err != nil {
return nil, err
}
dialContext := normalizedReq.DialContext
if dialContext == nil {
dialer := &net.Dialer{
Timeout: defaultLoginTimeout,
@ -214,7 +323,7 @@ func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error)
forwarder := newPortForwarder(listener)
forwarder.serve(func(ctx context.Context) (net.Conn, error) {
return dialContext(ctx, "tcp", req.TargetAddr)
return dialContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr)
})
return forwarder, nil
}
@ -239,6 +348,74 @@ func (s *StarSSH) StartDynamicForward(req DynamicForwardRequest) (*PortForwarder
return forwarder, nil
}
func normalizeForwardRequest(req ForwardRequest) (normalizedForwardRequest, error) {
normalized := normalizedForwardRequest{
DialContext: req.DialContext,
}
var err error
normalized.ListenNetwork, normalized.ListenAddr, err = parseForwardEndpoint(req.ListenAddr)
if err != nil {
return normalized, fmt.Errorf("normalize listen address: %w", err)
}
normalized.TargetNetwork, normalized.TargetAddr, err = parseForwardEndpoint(req.TargetAddr)
if err != nil {
return normalized, fmt.Errorf("normalize target address: %w", err)
}
if strings.TrimSpace(normalized.ListenAddr) == "" {
return normalized, errors.New("forward listen address is empty")
}
if strings.TrimSpace(normalized.TargetAddr) == "" {
return normalized, errors.New("forward target address is empty")
}
return normalized, nil
}
func normalizeForwardNetwork(network string) string {
network = strings.ToLower(strings.TrimSpace(network))
if network == "" {
return "tcp"
}
return network
}
func isSupportedForwardNetwork(network string) bool {
switch network {
case "tcp", "tcp4", "tcp6", "unix":
return true
default:
return false
}
}
func parseForwardEndpoint(value string) (network string, address string, err error) {
value = strings.TrimSpace(value)
if value == "" {
return "tcp", "", nil
}
lowerValue := strings.ToLower(value)
for _, prefix := range []string{"tcp4://", "tcp6://", "tcp://", "unix://"} {
if strings.HasPrefix(lowerValue, prefix) {
network = normalizeForwardNetwork(strings.TrimSuffix(prefix, "://"))
address = value[len(prefix):]
if !isSupportedForwardNetwork(network) {
return "", "", fmt.Errorf("unsupported forward network %q", network)
}
return network, address, nil
}
}
return "tcp", value, nil
}
func forwardEndpoint(network string, address string) string {
network = normalizeForwardNetwork(network)
if network == "tcp" {
return address
}
return network + "://" + address
}
func (s *StarSSH) StartDynamicForwardDetached(req DynamicForwardRequest) (*PortForwarder, error) {
if _, err := s.requireSSHClient(); err != nil {
return nil, err
@ -344,6 +521,87 @@ func (f *PortForwarder) addCleanup(fn func() error) {
f.cleanupFns = append(f.cleanupFns, fn)
}
func prepareLocalForwardListener(network string, address string) (net.Listener, func() error, error) {
network = normalizeForwardNetwork(network)
if network != "unix" {
listener, err := net.Listen(network, address)
return listener, nil, err
}
if err := removeStaleUnixSocket(address); err != nil {
return nil, nil, err
}
listener, err := net.Listen(network, address)
if err != nil {
return nil, nil, err
}
cleanup, err := makeUnixSocketCleanup(address)
if err != nil {
_ = listener.Close()
_ = removeUnixSocketPath(address)
return nil, nil, err
}
return listener, cleanup, nil
}
func removeStaleUnixSocket(path string) error {
info, err := os.Lstat(path)
if errors.Is(err, os.ErrNotExist) {
return nil
}
if err != nil {
return err
}
if info.Mode()&os.ModeSocket == 0 {
return fmt.Errorf("local unix forward path %q already exists and is not a socket", path)
}
conn, err := net.DialTimeout("unix", path, unixForwardProbeTimeout)
if err == nil {
_ = conn.Close()
return fmt.Errorf("local unix forward path %q is already in use", path)
}
if !isStaleUnixSocketDialError(err) {
return fmt.Errorf("probe existing unix socket %q: %w", path, err)
}
return removeUnixSocketPath(path)
}
func isStaleUnixSocketDialError(err error) bool {
return errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ENOENT)
}
func makeUnixSocketCleanup(path string) (func() error, error) {
info, err := os.Lstat(path)
if err != nil {
return nil, err
}
return func() error {
current, err := os.Lstat(path)
if errors.Is(err, os.ErrNotExist) {
return nil
}
if err != nil {
return err
}
if current.Mode()&os.ModeSocket == 0 || !os.SameFile(info, current) {
return nil
}
return removeUnixSocketPath(path)
}, nil
}
func removeUnixSocketPath(path string) error {
err := os.Remove(path)
if errors.Is(err, os.ErrNotExist) {
return nil
}
return err
}
func (f *PortForwarder) runCleanup() {
if f == nil {
return

View File

@ -2,8 +2,13 @@ package starssh
import (
"context"
"errors"
"io"
"net"
"os"
"path/filepath"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
@ -11,6 +16,59 @@ import (
"golang.org/x/crypto/ssh"
)
type stubListener struct {
addr net.Addr
acceptCh chan net.Conn
closeCh chan struct{}
closeOnce sync.Once
}
type dialRecord struct {
network string
addr string
}
func newStubListener(addr net.Addr) *stubListener {
return &stubListener{
addr: addr,
acceptCh: make(chan net.Conn, 1),
closeCh: make(chan struct{}),
}
}
func (l *stubListener) Accept() (net.Conn, error) {
select {
case conn, ok := <-l.acceptCh:
if !ok {
return nil, io.EOF
}
return conn, nil
case <-l.closeCh:
return nil, net.ErrClosed
}
}
func (l *stubListener) Close() error {
l.closeOnce.Do(func() {
close(l.closeCh)
close(l.acceptCh)
})
return nil
}
func (l *stubListener) Addr() net.Addr {
return l.addr
}
func (l *stubListener) Push(conn net.Conn) error {
select {
case <-l.closeCh:
return net.ErrClosed
case l.acceptCh <- conn:
return nil
}
}
func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) {
oldDialSSHClient := dialSSHClient
oldNewDetachedForwardClient := newDetachedForwardClient
@ -63,6 +121,64 @@ func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) {
}
}
func TestForwardRequestLegacyPositionalLiteralDefaultsToTCP(t *testing.T) {
dialer := func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, nil
}
req, err := normalizeForwardRequest(ForwardRequest{
"127.0.0.1:10022",
"example.internal:22",
dialer,
})
if err != nil {
t.Fatalf("normalizeForwardRequest: %v", err)
}
if req.ListenNetwork != "tcp" {
t.Fatalf("ListenNetwork=%q want tcp", req.ListenNetwork)
}
if req.TargetNetwork != "tcp" {
t.Fatalf("TargetNetwork=%q want tcp", req.TargetNetwork)
}
if req.ListenAddr != "127.0.0.1:10022" || req.TargetAddr != "example.internal:22" {
t.Fatalf("unexpected normalized request: %+v", req)
}
if req.DialContext == nil {
t.Fatal("expected DialContext to be preserved")
}
}
func TestParseForwardEndpointTreatsTCPPrefixLikePlainAddress(t *testing.T) {
network, address, err := parseForwardEndpoint("tcp:22")
if err != nil {
t.Fatalf("parseForwardEndpoint: %v", err)
}
if network != "tcp" {
t.Fatalf("network=%q want tcp", network)
}
if address != "tcp:22" {
t.Fatalf("address=%q want tcp:22", address)
}
}
func TestParseForwardEndpointSupportsExplicitSchemes(t *testing.T) {
network, address, err := parseForwardEndpoint("unix:///tmp/test-forward.sock")
if err != nil {
t.Fatalf("parseForwardEndpoint unix: %v", err)
}
if network != "unix" || address != "/tmp/test-forward.sock" {
t.Fatalf("unexpected unix endpoint parse: network=%q address=%q", network, address)
}
network, address, err = parseForwardEndpoint("tcp6://[::1]:2222")
if err != nil {
t.Fatalf("parseForwardEndpoint tcp6: %v", err)
}
if network != "tcp6" || address != "[::1]:2222" {
t.Fatalf("unexpected tcp6 endpoint parse: network=%q address=%q", network, address)
}
}
func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) {
oldDialSSHClient := dialSSHClient
oldNewDetachedForwardClient := newDetachedForwardClient
@ -132,6 +248,424 @@ func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) {
}
}
func TestStartRemoteForwardSupportsUnixListenAndTCPTarget(t *testing.T) {
oldListenSSHClient := listenSSHClient
t.Cleanup(func() {
listenSSHClient = oldListenSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
listener := newStubListener(&net.UnixAddr{
Name: "/run/user/0/gnupg/S.gpg-agent",
Net: "unix",
})
var listenedNetwork string
var listenedAddr string
listenSSHClient = func(client *ssh.Client, network, address string) (net.Listener, error) {
if client != baseClient {
t.Fatalf("unexpected ssh client %p", client)
}
listenedNetwork = network
listenedAddr = address
return listener, nil
}
var targetNetwork string
var targetAddr string
forwarder, err := star.StartRemoteForward(ForwardRequest{
ListenAddr: forwardEndpoint("unix", "/run/user/0/gnupg/S.gpg-agent"),
TargetAddr: "127.0.0.1:4321",
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
targetNetwork = network
targetAddr = address
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
},
})
if err != nil {
t.Fatalf("start remote unix forward: %v", err)
}
defer forwarder.Close()
srcPeer, forwardedConn := net.Pipe()
defer srcPeer.Close()
if err := listener.Push(forwardedConn); err != nil {
t.Fatalf("push forwarded connection: %v", err)
}
payload := []byte("unix-forward")
done := make(chan []byte, 1)
go func() {
reply := make([]byte, len(payload))
_, _ = io.ReadFull(srcPeer, reply)
done <- reply
}()
if _, err := srcPeer.Write(payload); err != nil {
t.Fatalf("write source payload: %v", err)
}
select {
case reply := <-done:
if string(reply) != string(payload) {
t.Fatalf("unexpected remote unix forward reply: %q", string(reply))
}
case <-time.After(2 * time.Second):
t.Fatal("remote unix forward did not relay payload")
}
if listenedNetwork != "unix" || listenedAddr != "/run/user/0/gnupg/S.gpg-agent" {
t.Fatalf("unexpected remote listen request: network=%q addr=%q", listenedNetwork, listenedAddr)
}
if targetNetwork != "tcp" || targetAddr != "127.0.0.1:4321" {
t.Fatalf("unexpected local dial target: network=%q addr=%q", targetNetwork, targetAddr)
}
}
func TestStartLocalUnixForwardUsesUnixListenerAndTCPTarget(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
}
oldDialSSHClient := dialSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
var targetNetwork string
var targetAddr string
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
if client != baseClient {
t.Fatalf("unexpected ssh client %p", client)
}
targetNetwork = network
targetAddr = address
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
}
socketPath := filepath.Join(t.TempDir(), "forward.sock")
forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321")
if err != nil {
t.Fatalf("start local unix forward: %v", err)
}
defer func() {
closeErr := forwarder.Close()
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
t.Fatalf("close local unix forward: %v", closeErr)
}
}()
conn, err := net.DialTimeout("unix", socketPath, time.Second)
if err != nil {
t.Fatalf("dial unix forward listener: %v", err)
}
defer conn.Close()
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
payload := []byte("unix-local-forward")
if _, err := conn.Write(payload); err != nil {
t.Fatalf("write unix forward payload: %v", err)
}
reply := make([]byte, len(payload))
if _, err := io.ReadFull(conn, reply); err != nil {
t.Fatalf("read unix forward reply: %v", err)
}
if string(reply) != string(payload) {
t.Fatalf("unexpected unix forward reply: %q", string(reply))
}
if targetNetwork != "tcp" || targetAddr != "127.0.0.1:4321" {
t.Fatalf("unexpected remote dial target: network=%q addr=%q", targetNetwork, targetAddr)
}
}
func TestStartLocalUnixForwardRemovesSocketOnClose(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
}
oldDialSSHClient := dialSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
}
socketPath := filepath.Join(t.TempDir(), "cleanup.sock")
forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321")
if err != nil {
t.Fatalf("start local unix forward: %v", err)
}
if _, err := os.Lstat(socketPath); err != nil {
t.Fatalf("socket should exist while forward is running: %v", err)
}
if err := forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
t.Fatalf("close local unix forward: %v", err)
}
if _, err := os.Lstat(socketPath); !errors.Is(err, os.ErrNotExist) {
t.Fatalf("socket path should be removed on close, got err=%v", err)
}
}
func TestStartLocalUnixForwardReusesStaleSocketPath(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
}
oldDialSSHClient := dialSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
serverConn, clientConn := net.Pipe()
go echoForwardPipe(serverConn)
return clientConn, nil
}
socketPath := filepath.Join(t.TempDir(), "stale.sock")
staleListener, err := net.ListenUnix("unix", &net.UnixAddr{
Name: socketPath,
Net: "unix",
})
if err != nil {
t.Fatalf("create stale unix socket: %v", err)
}
staleListener.SetUnlinkOnClose(false)
if err := staleListener.Close(); err != nil {
t.Fatalf("close stale unix socket listener: %v", err)
}
if _, err := os.Lstat(socketPath); err != nil {
t.Fatalf("expected stale unix socket path to remain after close: %v", err)
}
forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321")
if err != nil {
t.Fatalf("start local unix forward on stale socket path: %v", err)
}
defer func() {
closeErr := forwarder.Close()
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
t.Fatalf("close local unix forward: %v", closeErr)
}
}()
reply := make([]byte, len("stale-reuse"))
conn, err := net.DialTimeout("unix", socketPath, time.Second)
if err != nil {
t.Fatalf("dial reused unix forward listener: %v", err)
}
defer conn.Close()
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
if _, err := conn.Write([]byte("stale-reuse")); err != nil {
t.Fatalf("write reused unix forward payload: %v", err)
}
if _, err := io.ReadFull(conn, reply); err != nil {
t.Fatalf("read reused unix forward reply: %v", err)
}
if string(reply) != "stale-reuse" {
t.Fatalf("unexpected reply on reused unix forward: %q", string(reply))
}
}
func TestStartLocalUnixToUnixForwardUsesUnixTarget(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
}
oldDialSSHClient := dialSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
targetSocketPath := filepath.Join(t.TempDir(), "target.sock")
targetListener, err := net.Listen("unix", targetSocketPath)
if err != nil {
t.Fatalf("listen target unix socket: %v", err)
}
defer targetListener.Close()
done := make(chan []byte, 1)
go func() {
conn, acceptErr := targetListener.Accept()
if acceptErr != nil {
done <- nil
return
}
defer conn.Close()
buf := make([]byte, 64)
n, _ := conn.Read(buf)
_, _ = conn.Write(buf[:n])
done <- buf[:n]
}()
dialRecordCh := make(chan dialRecord, 1)
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
if client != baseClient {
t.Fatalf("unexpected ssh client %p", client)
}
dialRecordCh <- dialRecord{network: network, addr: address}
var dialer net.Dialer
return dialer.DialContext(ctx, network, address)
}
listenSocketPath := filepath.Join(t.TempDir(), "listen.sock")
forwarder, err := star.StartLocalUnixToUnixForward(listenSocketPath, targetSocketPath)
if err != nil {
t.Fatalf("start local unix-to-unix forward: %v", err)
}
defer func() {
closeErr := forwarder.Close()
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
t.Fatalf("close local unix-to-unix forward: %v", closeErr)
}
}()
conn, err := net.DialTimeout("unix", listenSocketPath, time.Second)
if err != nil {
t.Fatalf("dial unix-to-unix listener: %v", err)
}
defer conn.Close()
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
payload := []byte("unix-to-unix")
if _, err := conn.Write(payload); err != nil {
t.Fatalf("write unix-to-unix payload: %v", err)
}
reply := make([]byte, len(payload))
if _, err := io.ReadFull(conn, reply); err != nil {
t.Fatalf("read unix-to-unix reply: %v", err)
}
if string(reply) != string(payload) {
t.Fatalf("unexpected unix-to-unix reply: %q", string(reply))
}
select {
case got := <-done:
if string(got) != string(payload) {
t.Fatalf("unexpected payload seen by target unix socket: %q", string(got))
}
case <-time.After(2 * time.Second):
t.Fatal("target unix socket did not receive forwarded payload")
}
select {
case got := <-dialRecordCh:
if got.network != "unix" || got.addr != targetSocketPath {
t.Fatalf("unexpected unix target dial: network=%q addr=%q", got.network, got.addr)
}
case <-time.After(2 * time.Second):
t.Fatal("did not observe unix target dial")
}
}
func TestStartLocalTCPToUnixForwardUsesUnixTarget(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix socket smoke test is exercised in WSL/Linux CI path")
}
oldDialSSHClient := dialSSHClient
t.Cleanup(func() {
dialSSHClient = oldDialSSHClient
})
baseClient := &ssh.Client{}
star := &StarSSH{}
star.setTransport(baseClient, nil)
targetSocketPath := filepath.Join(t.TempDir(), "target-tcp-to-unix.sock")
targetListener, err := net.Listen("unix", targetSocketPath)
if err != nil {
t.Fatalf("listen target unix socket: %v", err)
}
defer targetListener.Close()
done := make(chan []byte, 1)
go func() {
conn, acceptErr := targetListener.Accept()
if acceptErr != nil {
done <- nil
return
}
defer conn.Close()
buf := make([]byte, 64)
n, _ := conn.Read(buf)
_, _ = conn.Write(buf[:n])
done <- buf[:n]
}()
dialRecordCh := make(chan dialRecord, 1)
dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) {
if client != baseClient {
t.Fatalf("unexpected ssh client %p", client)
}
dialRecordCh <- dialRecord{network: network, addr: address}
var dialer net.Dialer
return dialer.DialContext(ctx, network, address)
}
forwarder, err := star.StartLocalTCPToUnixForward("127.0.0.1:0", targetSocketPath)
if err != nil {
t.Fatalf("start local tcp-to-unix forward: %v", err)
}
defer func() {
closeErr := forwarder.Close()
if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
t.Fatalf("close local tcp-to-unix forward: %v", closeErr)
}
}()
reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("tcp-to-unix"))
if string(reply) != "tcp-to-unix" {
t.Fatalf("unexpected tcp-to-unix reply: %q", string(reply))
}
select {
case got := <-done:
if string(got) != "tcp-to-unix" {
t.Fatalf("unexpected payload seen by unix target: %q", string(got))
}
case <-time.After(2 * time.Second):
t.Fatal("unix target did not receive forwarded tcp payload")
}
select {
case got := <-dialRecordCh:
if got.network != "unix" || got.addr != targetSocketPath {
t.Fatalf("unexpected unix target dial: network=%q addr=%q", got.network, got.addr)
}
case <-time.After(2 * time.Second):
t.Fatal("did not observe unix target dial")
}
}
func echoForwardPipe(conn net.Conn) {
defer conn.Close()
buf := make([]byte, 4096)

View File

@ -9,6 +9,14 @@ import (
"golang.org/x/crypto/ssh"
)
var newSSHSession = func(client *ssh.Client) (*ssh.Session, error) {
return client.NewSession()
}
var requestSessionPTY = func(session *ssh.Session, config TerminalConfig) error {
return session.RequestPty(config.Term, config.Rows, config.Columns, config.Modes)
}
func (s *StarSSH) Close() error {
return s.closeTransport(true)
}
@ -22,7 +30,15 @@ func (s *StarSSH) NewExecSession() (*ssh.Session, error) {
if err != nil {
return nil, err
}
return NewExecSession(client)
session, err := NewExecSession(client)
if err != nil {
return nil, err
}
if err := s.maybeRequestAgentForwarding(session); err != nil {
_ = session.Close()
return nil, err
}
return session, nil
}
func (s *StarSSH) NewPTYSession(config *TerminalConfig) (*ssh.Session, error) {
@ -30,7 +46,15 @@ func (s *StarSSH) NewPTYSession(config *TerminalConfig) (*ssh.Session, error) {
if err != nil {
return nil, err
}
return NewPTYSession(client, config)
session, err := NewPTYSession(client, config)
if err != nil {
return nil, err
}
if err := s.maybeRequestAgentForwarding(session); err != nil {
_ = session.Close()
return nil, err
}
return session, nil
}
func NewTransferSession(client *ssh.Client) (*ssh.Session, error) {
@ -41,7 +65,7 @@ func NewExecSession(client *ssh.Client) (*ssh.Session, error) {
if client == nil {
return nil, errors.New("ssh client is nil")
}
return client.NewSession()
return newSSHSession(client)
}
func NewSession(client *ssh.Client) (*ssh.Session, error) {
@ -53,13 +77,13 @@ func NewPTYSession(client *ssh.Client, config *TerminalConfig) (*ssh.Session, er
return nil, errors.New("ssh client is nil")
}
session, err := client.NewSession()
session, err := newSSHSession(client)
if err != nil {
return nil, err
}
cfg := normalizeTerminalConfig(config)
if err := session.RequestPty(cfg.Term, cfg.Rows, cfg.Columns, cfg.Modes); err != nil {
if err := requestSessionPTY(session, cfg); err != nil {
_ = session.Close()
return nil, err
}

View File

@ -6,6 +6,8 @@ import (
"golang.org/x/crypto/ssh"
)
var errSSHClientClosing = errors.New("ssh client is closing")
type sshClientRequester interface {
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
Close() error
@ -29,10 +31,19 @@ func (s *StarSSH) snapshotSSHClient() *ssh.Client {
}
func (s *StarSSH) requireSSHClient() (*ssh.Client, error) {
if s == nil {
return nil, errors.New("ssh client is nil")
}
if s.closing.Load() {
return nil, errSSHClientClosing
}
client := s.snapshotSSHClient()
if client == nil {
return nil, errors.New("ssh client is nil")
}
if s.closing.Load() {
return nil, errSSHClientClosing
}
return client, nil
}
@ -46,6 +57,7 @@ func (s *StarSSH) setTransport(client *ssh.Client, upstream *StarSSH) {
s.Client = client
s.upstream = upstream
s.online = client != nil
s.closing.Store(false)
}
func (s *StarSSH) detachTransport() (*ssh.Client, *StarSSH) {
@ -84,7 +96,9 @@ func (s *StarSSH) closeTransport(waitKeepalive bool) error {
return nil
}
s.closing.Store(true)
_ = s.closeReusableSFTPClient()
agentForwarder := s.takeAgentForwarder()
client, upstream := s.detachTransport()
stop, done := s.takeKeepaliveHandles()
@ -93,8 +107,13 @@ func (s *StarSSH) closeTransport(waitKeepalive bool) error {
}
var closeErr error
if agentForwarder != nil {
closeErr = normalizeAlreadyClosedError(agentForwarder.Close())
}
if client != nil {
closeErr = normalizeAlreadyClosedError(closeSSHClient(client))
if err := normalizeAlreadyClosedError(closeSSHClient(client)); closeErr == nil {
closeErr = err
}
}
if waitKeepalive && done != nil {
<-done
@ -104,3 +123,13 @@ func (s *StarSSH) closeTransport(waitKeepalive bool) error {
}
return closeErr
}
func (s *StarSSH) canAttachAgentForwarder(client *ssh.Client) bool {
if s == nil || client == nil || s.closing.Load() {
return false
}
s.stateMu.RLock()
defer s.stateMu.RUnlock()
return !s.closing.Load() && s.Client == client
}

View File

@ -6,6 +6,7 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/pkg/sftp"
@ -58,21 +59,24 @@ const (
)
type StarSSH struct {
stateMu sync.RWMutex
Client *ssh.Client
PublicKey ssh.PublicKey
PubkeyBase64 string
Hostname string
RemoteAddr net.Addr
Banner string
LoginInfo LoginInput
online bool
upstream *StarSSH
sftpClient *sftp.Client
sftpMu sync.Mutex
keepaliveMu sync.Mutex
keepaliveStop chan struct{}
keepaliveDone chan struct{}
stateMu sync.RWMutex
Client *ssh.Client
PublicKey ssh.PublicKey
PubkeyBase64 string
Hostname string
RemoteAddr net.Addr
Banner string
LoginInfo LoginInput
online bool
upstream *StarSSH
sftpClient *sftp.Client
sftpMu sync.Mutex
agentForwardMu sync.Mutex
agentForwarder io.Closer
keepaliveMu sync.Mutex
keepaliveStop chan struct{}
keepaliveDone chan struct{}
closing atomic.Bool
}
type LoginInput struct {
@ -86,6 +90,7 @@ type LoginInput struct {
Prikey string
PrikeyPwd string
DisableSSHAgent bool
ForwardSSHAgent bool
AuthOrder []AuthMethodKind
Addr string
Port int