diff --git a/agent_forward.go b/agent_forward.go new file mode 100644 index 0000000..14a76e3 --- /dev/null +++ b/agent_forward.go @@ -0,0 +1,151 @@ +package starssh + +import ( + "errors" + "fmt" + "io" + "strings" + "time" + + "golang.org/x/crypto/ssh" + sshagent "golang.org/x/crypto/ssh/agent" +) + +var requestSSHAgentForwarding = func(session *ssh.Session) error { + return sshagent.RequestAgentForwarding(session) +} + +var routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { + return sshagent.ForwardToAgent(client, keyring) +} + +var newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { + conn, err := dialSSHAgent(timeout) + if err != nil { + return nil, nil, wrapSSHAgentForwardingUnavailable(err) + } + if conn == nil { + return nil, nil, wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection")) + } + return sshagent.NewClient(conn), conn, nil +} + +var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied") +var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable") + +func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error { + if s == nil { + return errors.New("ssh client is nil") + } + if session == nil { + return errors.New("ssh session is nil") + } + if err := s.ensureAgentForwarding(); err != nil { + return err + } + if err := requestSSHAgentForwarding(session); err != nil { + if isSSHAgentForwardingDeniedError(err) { + return fmt.Errorf("%w: %v", errSSHAgentForwardingDenied, err) + } + return err + } + return nil +} + +func (s *StarSSH) maybeRequestAgentForwarding(session *ssh.Session) error { + if s == nil || !s.LoginInfo.ForwardSSHAgent { + return nil + } + err := s.RequestAgentForwarding(session) + if isSSHAgentForwardingDeniedError(err) || isSSHAgentForwardingUnavailableError(err) { + return nil + } + return err +} + +func (s *StarSSH) ensureAgentForwarding() error { + if s == nil { + return errors.New("ssh client is nil") + } + + s.agentForwardMu.Lock() + defer s.agentForwardMu.Unlock() + + if s.agentForwarder != nil { + return nil + } + + client, err := s.requireSSHClient() + if err != nil { + return err + } + + keyring, closer, err := newSSHAgentForwarder(s.LoginInfo.Timeout) + if err != nil { + return wrapSSHAgentForwardingUnavailable(err) + } + if s.closing.Load() { + _ = closer.Close() + return errSSHClientClosing + } + if err := routeSSHAgentForwarding(client, keyring); err != nil { + _ = closer.Close() + return err + } + if !s.canAttachAgentForwarder(client) { + _ = closer.Close() + return errSSHClientClosing + } + s.agentForwarder = closer + return nil +} + +func (s *StarSSH) takeAgentForwarder() io.Closer { + if s == nil { + return nil + } + + s.agentForwardMu.Lock() + defer s.agentForwardMu.Unlock() + + closer := s.agentForwarder + s.agentForwarder = nil + return closer +} + +func isSSHAgentForwardingDeniedError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, errSSHAgentForwardingDenied) { + return true + } + message := strings.ToLower(err.Error()) + return strings.Contains(message, "forwarding request denied") || + strings.Contains(message, "agent forwarding disabled") +} + +func isSSHAgentForwardingUnavailableError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, errSSHAgentForwardingUnavailable) { + return true + } + message := strings.ToLower(err.Error()) + return strings.Contains(message, "ssh-agent forwarding unavailable") || + strings.Contains(message, "ssh-agent unavailable") +} + +func wrapSSHAgentForwardingUnavailable(err error) error { + if err == nil { + return nil + } + if errors.Is(err, errSSHAgentForwardingUnavailable) { + return err + } + if errors.Is(err, errSSHAgentUnavailable) { + return fmt.Errorf("%w: %w", errSSHAgentForwardingUnavailable, err) + } + return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err) +} diff --git a/agent_forward_test.go b/agent_forward_test.go new file mode 100644 index 0000000..ae07de4 --- /dev/null +++ b/agent_forward_test.go @@ -0,0 +1,417 @@ +package starssh + +import ( + "errors" + "io" + "sync/atomic" + "testing" + "time" + + "golang.org/x/crypto/ssh" + sshagent "golang.org/x/crypto/ssh/agent" +) + +type testCloser struct { + closed atomic.Int32 +} + +func (c *testCloser) Close() error { + c.closed.Add(1) + return nil +} + +func TestNewExecSessionEnablesAgentForwardingOnce(t *testing.T) { + oldNewSSHSession := newSSHSession + oldNewSSHAgentForwarder := newSSHAgentForwarder + oldRouteSSHAgentForwarding := routeSSHAgentForwarding + oldRequestSSHAgentForwarding := requestSSHAgentForwarding + oldCloseSSHClient := closeSSHClient + t.Cleanup(func() { + newSSHSession = oldNewSSHSession + newSSHAgentForwarder = oldNewSSHAgentForwarder + routeSSHAgentForwarding = oldRouteSSHAgentForwarding + requestSSHAgentForwarding = oldRequestSSHAgentForwarding + closeSSHClient = oldCloseSSHClient + }) + + baseClient := &ssh.Client{} + star := &StarSSH{ + LoginInfo: LoginInput{ + ForwardSSHAgent: true, + Timeout: time.Second, + }, + } + star.setTransport(baseClient, nil) + + newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { + if client != baseClient { + t.Fatalf("unexpected ssh client %p", client) + } + return &ssh.Session{}, nil + } + + var agentInitCalls atomic.Int32 + closer := &testCloser{} + newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { + agentInitCalls.Add(1) + if timeout != time.Second { + t.Fatalf("unexpected forwarding timeout: %v", timeout) + } + return sshagent.NewKeyring(), closer, nil + } + + var routeCalls atomic.Int32 + routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { + routeCalls.Add(1) + if client != baseClient { + t.Fatalf("unexpected routed client %p", client) + } + if keyring == nil { + t.Fatal("expected non-nil forwarded agent keyring") + } + return nil + } + + var requestCalls atomic.Int32 + requestSSHAgentForwarding = func(session *ssh.Session) error { + requestCalls.Add(1) + if session == nil { + t.Fatal("expected non-nil ssh session") + } + return nil + } + + if _, err := star.NewExecSession(); err != nil { + t.Fatalf("first exec session: %v", err) + } + if _, err := star.NewExecSession(); err != nil { + t.Fatalf("second exec session: %v", err) + } + + if agentInitCalls.Load() != 1 { + t.Fatalf("expected one agent forwarder init, got %d", agentInitCalls.Load()) + } + if routeCalls.Load() != 1 { + t.Fatalf("expected one agent route registration, got %d", routeCalls.Load()) + } + if requestCalls.Load() != 2 { + t.Fatalf("expected agent forwarding request on each session, got %d", requestCalls.Load()) + } + + closeSSHClient = func(client sshClientRequester) error { return nil } + if err := star.Close(); err != nil { + t.Fatalf("close starssh: %v", err) + } + if closer.closed.Load() != 1 { + t.Fatalf("expected forwarded agent closer to run once, got %d", closer.closed.Load()) + } +} + +func TestNewPTYSessionEnablesAgentForwardingWhenConfigured(t *testing.T) { + oldNewSSHSession := newSSHSession + oldRequestSessionPTY := requestSessionPTY + oldNewSSHAgentForwarder := newSSHAgentForwarder + oldRouteSSHAgentForwarding := routeSSHAgentForwarding + oldRequestSSHAgentForwarding := requestSSHAgentForwarding + t.Cleanup(func() { + newSSHSession = oldNewSSHSession + requestSessionPTY = oldRequestSessionPTY + newSSHAgentForwarder = oldNewSSHAgentForwarder + routeSSHAgentForwarding = oldRouteSSHAgentForwarding + requestSSHAgentForwarding = oldRequestSSHAgentForwarding + }) + + star := &StarSSH{ + LoginInfo: LoginInput{ + ForwardSSHAgent: true, + }, + } + star.setTransport(&ssh.Client{}, nil) + + newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { + return &ssh.Session{}, nil + } + + var ptyCalls atomic.Int32 + requestSessionPTY = func(session *ssh.Session, config TerminalConfig) error { + ptyCalls.Add(1) + return nil + } + + newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { + return sshagent.NewKeyring(), &testCloser{}, nil + } + routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil } + + var requestCalls atomic.Int32 + requestSSHAgentForwarding = func(session *ssh.Session) error { + requestCalls.Add(1) + return nil + } + + if _, err := star.NewPTYSession(nil); err != nil { + t.Fatalf("new pty session: %v", err) + } + if ptyCalls.Load() != 1 { + t.Fatalf("expected one PTY request, got %d", ptyCalls.Load()) + } + if requestCalls.Load() != 1 { + t.Fatalf("expected one agent forwarding request, got %d", requestCalls.Load()) + } +} + +func TestNewExecSessionSkipsAgentForwardingWhenDisabled(t *testing.T) { + oldNewSSHSession := newSSHSession + oldNewSSHAgentForwarder := newSSHAgentForwarder + oldRequestSSHAgentForwarding := requestSSHAgentForwarding + t.Cleanup(func() { + newSSHSession = oldNewSSHSession + newSSHAgentForwarder = oldNewSSHAgentForwarder + requestSSHAgentForwarding = oldRequestSSHAgentForwarding + }) + + star := &StarSSH{} + star.setTransport(&ssh.Client{}, nil) + + newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { + return &ssh.Session{}, nil + } + newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { + t.Fatal("agent forwarder should not initialize when disabled") + return nil, nil, nil + } + requestSSHAgentForwarding = func(session *ssh.Session) error { + t.Fatal("agent forwarding should not be requested when disabled") + return nil + } + + if _, err := star.NewExecSession(); err != nil { + t.Fatalf("new exec session without forwarding: %v", err) + } +} + +func TestRequestAgentForwardingReturnsUnavailableError(t *testing.T) { + oldNewSSHAgentForwarder := newSSHAgentForwarder + oldRequestSSHAgentForwarding := requestSSHAgentForwarding + t.Cleanup(func() { + newSSHAgentForwarder = oldNewSSHAgentForwarder + requestSSHAgentForwarding = oldRequestSSHAgentForwarding + }) + + star := &StarSSH{} + star.setTransport(&ssh.Client{}, nil) + + newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { + return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable") + } + requestSSHAgentForwarding = func(session *ssh.Session) error { + t.Fatal("session request should not run when agent forwarder init fails") + return nil + } + + err := star.RequestAgentForwarding(&ssh.Session{}) + if err == nil { + t.Fatal("expected agent forwarding init error") + } +} + +func TestRequestAgentForwardingWrapsSetupErrorAsUnavailable(t *testing.T) { + oldNewSSHAgentForwarder := newSSHAgentForwarder + t.Cleanup(func() { + newSSHAgentForwarder = oldNewSSHAgentForwarder + }) + + star := &StarSSH{} + star.setTransport(&ssh.Client{}, nil) + + newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { + return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: permission denied") + } + + err := star.RequestAgentForwarding(&ssh.Session{}) + if !isSSHAgentForwardingUnavailableError(err) { + t.Fatalf("expected unavailable error, got %v", err) + } +} + +func TestRequestAgentForwardingReturnsDeniedError(t *testing.T) { + oldNewSSHAgentForwarder := newSSHAgentForwarder + oldRouteSSHAgentForwarding := routeSSHAgentForwarding + oldRequestSSHAgentForwarding := requestSSHAgentForwarding + t.Cleanup(func() { + newSSHAgentForwarder = oldNewSSHAgentForwarder + routeSSHAgentForwarding = oldRouteSSHAgentForwarding + requestSSHAgentForwarding = oldRequestSSHAgentForwarding + }) + + star := &StarSSH{} + star.setTransport(&ssh.Client{}, nil) + + newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { + return sshagent.NewKeyring(), &testCloser{}, nil + } + routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil } + requestSSHAgentForwarding = func(session *ssh.Session) error { + return errors.New("forwarding request denied") + } + + err := star.RequestAgentForwarding(&ssh.Session{}) + if !isSSHAgentForwardingDeniedError(err) { + t.Fatalf("expected forwarding denied error, got %v", err) + } +} + +func TestNewExecSessionIgnoresAgentForwardingDenied(t *testing.T) { + oldNewSSHSession := newSSHSession + oldNewSSHAgentForwarder := newSSHAgentForwarder + oldRouteSSHAgentForwarding := routeSSHAgentForwarding + oldRequestSSHAgentForwarding := requestSSHAgentForwarding + t.Cleanup(func() { + newSSHSession = oldNewSSHSession + newSSHAgentForwarder = oldNewSSHAgentForwarder + routeSSHAgentForwarding = oldRouteSSHAgentForwarding + requestSSHAgentForwarding = oldRequestSSHAgentForwarding + }) + + star := &StarSSH{ + LoginInfo: LoginInput{ + ForwardSSHAgent: true, + }, + } + star.setTransport(&ssh.Client{}, nil) + + newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { + return &ssh.Session{}, nil + } + newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { + return sshagent.NewKeyring(), &testCloser{}, nil + } + routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { return nil } + requestSSHAgentForwarding = func(session *ssh.Session) error { + return errors.New("forwarding request denied") + } + + if _, err := star.NewExecSession(); err != nil { + t.Fatalf("new exec session should ignore denied agent forwarding: %v", err) + } +} + +func TestNewExecSessionIgnoresAgentForwardingUnavailable(t *testing.T) { + oldNewSSHSession := newSSHSession + oldNewSSHAgentForwarder := newSSHAgentForwarder + t.Cleanup(func() { + newSSHSession = oldNewSSHSession + newSSHAgentForwarder = oldNewSSHAgentForwarder + }) + + star := &StarSSH{ + LoginInfo: LoginInput{ + ForwardSSHAgent: true, + }, + } + star.setTransport(&ssh.Client{}, nil) + + newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { + return &ssh.Session{}, nil + } + newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { + return nil, nil, errors.New("ssh-agent forwarding unavailable: ssh-agent unavailable") + } + + if _, err := star.NewExecSession(); err != nil { + t.Fatalf("new exec session should ignore unavailable agent forwarding: %v", err) + } +} + +func TestNewExecSessionIgnoresAgentForwardingSetupError(t *testing.T) { + oldNewSSHSession := newSSHSession + oldNewSSHAgentForwarder := newSSHAgentForwarder + t.Cleanup(func() { + newSSHSession = oldNewSSHSession + newSSHAgentForwarder = oldNewSSHAgentForwarder + }) + + star := &StarSSH{ + LoginInfo: LoginInput{ + ForwardSSHAgent: true, + }, + } + star.setTransport(&ssh.Client{}, nil) + + newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { + return &ssh.Session{}, nil + } + newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { + return nil, nil, errors.New("dial unix /tmp/ssh-broken.sock: connect: connection refused") + } + + if _, err := star.NewExecSession(); err != nil { + t.Fatalf("new exec session should ignore agent setup error: %v", err) + } +} + +func TestEnsureAgentForwardingClosesNewForwarderWhenCloseStarts(t *testing.T) { + oldNewSSHAgentForwarder := newSSHAgentForwarder + oldRouteSSHAgentForwarding := routeSSHAgentForwarding + oldCloseSSHClient := closeSSHClient + t.Cleanup(func() { + newSSHAgentForwarder = oldNewSSHAgentForwarder + routeSSHAgentForwarding = oldRouteSSHAgentForwarding + closeSSHClient = oldCloseSSHClient + }) + + star := &StarSSH{ + LoginInfo: LoginInput{ + ForwardSSHAgent: true, + }, + } + star.setTransport(&ssh.Client{}, nil) + + started := make(chan struct{}) + release := make(chan struct{}) + closer := &testCloser{} + newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) { + close(started) + <-release + return sshagent.NewKeyring(), closer, nil + } + routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error { + return nil + } + closeSSHClient = func(client sshClientRequester) error { return nil } + + errCh := make(chan error, 1) + go func() { + errCh <- star.ensureAgentForwarding() + }() + + <-started + closeDone := make(chan struct{}) + go func() { + _ = star.Close() + close(closeDone) + }() + + deadline := time.Now().Add(time.Second) + for !star.closing.Load() { + if time.Now().After(deadline) { + t.Fatal("close did not enter closing state in time") + } + time.Sleep(time.Millisecond) + } + + close(release) + + err := <-errCh + if !errors.Is(err, errSSHClientClosing) { + t.Fatalf("expected closing error, got %v", err) + } + <-closeDone + + if closer.closed.Load() != 1 { + t.Fatalf("expected new forwarder closer to be closed once, got %d", closer.closed.Load()) + } + if got := star.takeAgentForwarder(); got != nil { + t.Fatal("expected no leaked agent forwarder after close race") + } +} diff --git a/forward.go b/forward.go index 57f9e10..9d7bc4f 100644 --- a/forward.go +++ b/forward.go @@ -3,21 +3,39 @@ package starssh import ( "context" "errors" + "fmt" "io" "net" + "os" "strconv" "strings" "sync" + "syscall" + "time" "golang.org/x/crypto/ssh" ) type ForwardRequest struct { + // Keep the exported shape compatible with older positional literals: + // ForwardRequest{listenAddr, targetAddr, dialContext}. + // + // Non-default networks can be encoded with an explicit scheme-like prefix: + // "tcp4://127.0.0.1:22", "tcp6://[::1]:22", "unix:///tmp/socket". + // Bare values default to the "tcp" network. ListenAddr string TargetAddr string DialContext DialContextFunc } +type normalizedForwardRequest struct { + ListenNetwork string + ListenAddr string + TargetNetwork string + TargetAddr string + DialContext DialContextFunc +} + type DynamicForwardRequest struct { ListenAddr string } @@ -41,10 +59,16 @@ type PortForwarder struct { cleanupFns []func() error } +const unixForwardProbeTimeout = 200 * time.Millisecond + var dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { return client.Dial(network, address) } +var listenSSHClient = func(client *ssh.Client, network, address string) (net.Listener, error) { + return client.Listen(network, address) +} + var newDetachedForwardClient = func(ctx context.Context, input LoginInput) (*StarSSH, error) { if ctx == nil { ctx = context.Background() @@ -64,6 +88,90 @@ func (s *StarSSH) DialTCPContextCloseOnCancel(ctx context.Context, network strin return s.dialTCPContext(ctx, network, address, s.Close) } +func (s *StarSSH) StartLocalTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) { + return s.StartLocalForward(ForwardRequest{ + ListenAddr: listenAddr, + TargetAddr: targetAddr, + }) +} + +func (s *StarSSH) StartLocalTCPForwardDetached(listenAddr string, targetAddr string) (*PortForwarder, error) { + return s.StartLocalForwardDetached(ForwardRequest{ + ListenAddr: listenAddr, + TargetAddr: targetAddr, + }) +} + +func (s *StarSSH) StartLocalTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) { + return s.StartLocalForward(ForwardRequest{ + ListenAddr: listenAddr, + TargetAddr: forwardEndpoint("unix", targetPath), + }) +} + +func (s *StarSSH) StartLocalTCPToUnixForwardDetached(listenAddr string, targetPath string) (*PortForwarder, error) { + return s.StartLocalForwardDetached(ForwardRequest{ + ListenAddr: listenAddr, + TargetAddr: forwardEndpoint("unix", targetPath), + }) +} + +func (s *StarSSH) StartLocalUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) { + return s.StartLocalForward(ForwardRequest{ + ListenAddr: forwardEndpoint("unix", listenPath), + TargetAddr: targetAddr, + }) +} + +func (s *StarSSH) StartLocalUnixForwardDetached(listenPath string, targetAddr string) (*PortForwarder, error) { + return s.StartLocalForwardDetached(ForwardRequest{ + ListenAddr: forwardEndpoint("unix", listenPath), + TargetAddr: targetAddr, + }) +} + +func (s *StarSSH) StartLocalUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) { + return s.StartLocalForward(ForwardRequest{ + ListenAddr: forwardEndpoint("unix", listenPath), + TargetAddr: forwardEndpoint("unix", targetPath), + }) +} + +func (s *StarSSH) StartLocalUnixToUnixForwardDetached(listenPath string, targetPath string) (*PortForwarder, error) { + return s.StartLocalForwardDetached(ForwardRequest{ + ListenAddr: forwardEndpoint("unix", listenPath), + TargetAddr: forwardEndpoint("unix", targetPath), + }) +} + +func (s *StarSSH) StartRemoteTCPForward(listenAddr string, targetAddr string) (*PortForwarder, error) { + return s.StartRemoteForward(ForwardRequest{ + ListenAddr: listenAddr, + TargetAddr: targetAddr, + }) +} + +func (s *StarSSH) StartRemoteTCPToUnixForward(listenAddr string, targetPath string) (*PortForwarder, error) { + return s.StartRemoteForward(ForwardRequest{ + ListenAddr: listenAddr, + TargetAddr: forwardEndpoint("unix", targetPath), + }) +} + +func (s *StarSSH) StartRemoteUnixForward(listenPath string, targetAddr string) (*PortForwarder, error) { + return s.StartRemoteForward(ForwardRequest{ + ListenAddr: forwardEndpoint("unix", listenPath), + TargetAddr: targetAddr, + }) +} + +func (s *StarSSH) StartRemoteUnixToUnixForward(listenPath string, targetPath string) (*PortForwarder, error) { + return s.StartRemoteForward(ForwardRequest{ + ListenAddr: forwardEndpoint("unix", listenPath), + TargetAddr: forwardEndpoint("unix", targetPath), + }) +} + func (s *StarSSH) dialTCPContext(ctx context.Context, network string, address string, onCancel func() error) (net.Conn, error) { if ctx == nil { ctx = context.Background() @@ -136,21 +244,22 @@ func (s *StarSSH) StartLocalForward(req ForwardRequest) (*PortForwarder, error) if _, err := s.requireSSHClient(); err != nil { return nil, err } - if strings.TrimSpace(req.ListenAddr) == "" { + normalizedReq, err := normalizeForwardRequest(req) + if err != nil { + return nil, err + } + if strings.TrimSpace(normalizedReq.ListenAddr) == "" { return nil, errors.New("local forward listen address is empty") } - if strings.TrimSpace(req.TargetAddr) == "" { - return nil, errors.New("local forward target address is empty") - } - - listener, err := net.Listen("tcp", req.ListenAddr) + listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr) if err != nil { return nil, err } forwarder := newPortForwarder(listener) + forwarder.addCleanup(cleanup) forwarder.serve(func(ctx context.Context) (net.Conn, error) { - return s.DialTCPContext(ctx, "tcp", req.TargetAddr) + return s.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr) }) return forwarder, nil } @@ -159,14 +268,12 @@ func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder, if _, err := s.requireSSHClient(); err != nil { return nil, err } - if strings.TrimSpace(req.ListenAddr) == "" { - return nil, errors.New("local forward listen address is empty") - } - if strings.TrimSpace(req.TargetAddr) == "" { - return nil, errors.New("local forward target address is empty") + normalizedReq, err := normalizeForwardRequest(req) + if err != nil { + return nil, err } - listener, err := net.Listen("tcp", req.ListenAddr) + listener, cleanup, err := prepareLocalForwardListener(normalizedReq.ListenNetwork, normalizedReq.ListenAddr) if err != nil { return nil, err } @@ -174,15 +281,19 @@ func (s *StarSSH) StartLocalForwardDetached(req ForwardRequest) (*PortForwarder, forwardClient, err := s.newForwardDialClient(context.Background()) if err != nil { _ = listener.Close() + if cleanup != nil { + _ = cleanup() + } return nil, err } forwarder := newPortForwarder(listener) + forwarder.addCleanup(cleanup) forwarder.addCleanup(func() error { return normalizeAlreadyClosedError(forwardClient.Close()) }) forwarder.serve(func(ctx context.Context) (net.Conn, error) { - return forwardClient.DialTCPContext(ctx, "tcp", req.TargetAddr) + return forwardClient.DialTCPContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr) }) return forwarder, nil } @@ -192,19 +303,17 @@ func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error) if err != nil { return nil, err } - if strings.TrimSpace(req.ListenAddr) == "" { - return nil, errors.New("remote forward listen address is empty") - } - if strings.TrimSpace(req.TargetAddr) == "" { - return nil, errors.New("remote forward target address is empty") - } - - listener, err := client.Listen("tcp", req.ListenAddr) + normalizedReq, err := normalizeForwardRequest(req) if err != nil { return nil, err } - dialContext := req.DialContext + listener, err := listenSSHClient(client, normalizedReq.ListenNetwork, normalizedReq.ListenAddr) + if err != nil { + return nil, err + } + + dialContext := normalizedReq.DialContext if dialContext == nil { dialer := &net.Dialer{ Timeout: defaultLoginTimeout, @@ -214,7 +323,7 @@ func (s *StarSSH) StartRemoteForward(req ForwardRequest) (*PortForwarder, error) forwarder := newPortForwarder(listener) forwarder.serve(func(ctx context.Context) (net.Conn, error) { - return dialContext(ctx, "tcp", req.TargetAddr) + return dialContext(ctx, normalizedReq.TargetNetwork, normalizedReq.TargetAddr) }) return forwarder, nil } @@ -239,6 +348,74 @@ func (s *StarSSH) StartDynamicForward(req DynamicForwardRequest) (*PortForwarder return forwarder, nil } +func normalizeForwardRequest(req ForwardRequest) (normalizedForwardRequest, error) { + normalized := normalizedForwardRequest{ + DialContext: req.DialContext, + } + + var err error + normalized.ListenNetwork, normalized.ListenAddr, err = parseForwardEndpoint(req.ListenAddr) + if err != nil { + return normalized, fmt.Errorf("normalize listen address: %w", err) + } + normalized.TargetNetwork, normalized.TargetAddr, err = parseForwardEndpoint(req.TargetAddr) + if err != nil { + return normalized, fmt.Errorf("normalize target address: %w", err) + } + if strings.TrimSpace(normalized.ListenAddr) == "" { + return normalized, errors.New("forward listen address is empty") + } + if strings.TrimSpace(normalized.TargetAddr) == "" { + return normalized, errors.New("forward target address is empty") + } + return normalized, nil +} + +func normalizeForwardNetwork(network string) string { + network = strings.ToLower(strings.TrimSpace(network)) + if network == "" { + return "tcp" + } + return network +} + +func isSupportedForwardNetwork(network string) bool { + switch network { + case "tcp", "tcp4", "tcp6", "unix": + return true + default: + return false + } +} + +func parseForwardEndpoint(value string) (network string, address string, err error) { + value = strings.TrimSpace(value) + if value == "" { + return "tcp", "", nil + } + + lowerValue := strings.ToLower(value) + for _, prefix := range []string{"tcp4://", "tcp6://", "tcp://", "unix://"} { + if strings.HasPrefix(lowerValue, prefix) { + network = normalizeForwardNetwork(strings.TrimSuffix(prefix, "://")) + address = value[len(prefix):] + if !isSupportedForwardNetwork(network) { + return "", "", fmt.Errorf("unsupported forward network %q", network) + } + return network, address, nil + } + } + return "tcp", value, nil +} + +func forwardEndpoint(network string, address string) string { + network = normalizeForwardNetwork(network) + if network == "tcp" { + return address + } + return network + "://" + address +} + func (s *StarSSH) StartDynamicForwardDetached(req DynamicForwardRequest) (*PortForwarder, error) { if _, err := s.requireSSHClient(); err != nil { return nil, err @@ -344,6 +521,87 @@ func (f *PortForwarder) addCleanup(fn func() error) { f.cleanupFns = append(f.cleanupFns, fn) } +func prepareLocalForwardListener(network string, address string) (net.Listener, func() error, error) { + network = normalizeForwardNetwork(network) + if network != "unix" { + listener, err := net.Listen(network, address) + return listener, nil, err + } + + if err := removeStaleUnixSocket(address); err != nil { + return nil, nil, err + } + + listener, err := net.Listen(network, address) + if err != nil { + return nil, nil, err + } + + cleanup, err := makeUnixSocketCleanup(address) + if err != nil { + _ = listener.Close() + _ = removeUnixSocketPath(address) + return nil, nil, err + } + return listener, cleanup, nil +} + +func removeStaleUnixSocket(path string) error { + info, err := os.Lstat(path) + if errors.Is(err, os.ErrNotExist) { + return nil + } + if err != nil { + return err + } + if info.Mode()&os.ModeSocket == 0 { + return fmt.Errorf("local unix forward path %q already exists and is not a socket", path) + } + + conn, err := net.DialTimeout("unix", path, unixForwardProbeTimeout) + if err == nil { + _ = conn.Close() + return fmt.Errorf("local unix forward path %q is already in use", path) + } + if !isStaleUnixSocketDialError(err) { + return fmt.Errorf("probe existing unix socket %q: %w", path, err) + } + return removeUnixSocketPath(path) +} + +func isStaleUnixSocketDialError(err error) bool { + return errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ENOENT) +} + +func makeUnixSocketCleanup(path string) (func() error, error) { + info, err := os.Lstat(path) + if err != nil { + return nil, err + } + + return func() error { + current, err := os.Lstat(path) + if errors.Is(err, os.ErrNotExist) { + return nil + } + if err != nil { + return err + } + if current.Mode()&os.ModeSocket == 0 || !os.SameFile(info, current) { + return nil + } + return removeUnixSocketPath(path) + }, nil +} + +func removeUnixSocketPath(path string) error { + err := os.Remove(path) + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err +} + func (f *PortForwarder) runCleanup() { if f == nil { return diff --git a/forward_test.go b/forward_test.go index ec38d5d..e942f28 100644 --- a/forward_test.go +++ b/forward_test.go @@ -2,8 +2,13 @@ package starssh import ( "context" + "errors" "io" "net" + "os" + "path/filepath" + "runtime" + "sync" "sync/atomic" "testing" "time" @@ -11,6 +16,59 @@ import ( "golang.org/x/crypto/ssh" ) +type stubListener struct { + addr net.Addr + acceptCh chan net.Conn + closeCh chan struct{} + closeOnce sync.Once +} + +type dialRecord struct { + network string + addr string +} + +func newStubListener(addr net.Addr) *stubListener { + return &stubListener{ + addr: addr, + acceptCh: make(chan net.Conn, 1), + closeCh: make(chan struct{}), + } +} + +func (l *stubListener) Accept() (net.Conn, error) { + select { + case conn, ok := <-l.acceptCh: + if !ok { + return nil, io.EOF + } + return conn, nil + case <-l.closeCh: + return nil, net.ErrClosed + } +} + +func (l *stubListener) Close() error { + l.closeOnce.Do(func() { + close(l.closeCh) + close(l.acceptCh) + }) + return nil +} + +func (l *stubListener) Addr() net.Addr { + return l.addr +} + +func (l *stubListener) Push(conn net.Conn) error { + select { + case <-l.closeCh: + return net.ErrClosed + case l.acceptCh <- conn: + return nil + } +} + func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) { oldDialSSHClient := dialSSHClient oldNewDetachedForwardClient := newDetachedForwardClient @@ -63,6 +121,64 @@ func TestStartLocalForwardUsesExistingConnectionByDefault(t *testing.T) { } } +func TestForwardRequestLegacyPositionalLiteralDefaultsToTCP(t *testing.T) { + dialer := func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, nil + } + + req, err := normalizeForwardRequest(ForwardRequest{ + "127.0.0.1:10022", + "example.internal:22", + dialer, + }) + if err != nil { + t.Fatalf("normalizeForwardRequest: %v", err) + } + if req.ListenNetwork != "tcp" { + t.Fatalf("ListenNetwork=%q want tcp", req.ListenNetwork) + } + if req.TargetNetwork != "tcp" { + t.Fatalf("TargetNetwork=%q want tcp", req.TargetNetwork) + } + if req.ListenAddr != "127.0.0.1:10022" || req.TargetAddr != "example.internal:22" { + t.Fatalf("unexpected normalized request: %+v", req) + } + if req.DialContext == nil { + t.Fatal("expected DialContext to be preserved") + } +} + +func TestParseForwardEndpointTreatsTCPPrefixLikePlainAddress(t *testing.T) { + network, address, err := parseForwardEndpoint("tcp:22") + if err != nil { + t.Fatalf("parseForwardEndpoint: %v", err) + } + if network != "tcp" { + t.Fatalf("network=%q want tcp", network) + } + if address != "tcp:22" { + t.Fatalf("address=%q want tcp:22", address) + } +} + +func TestParseForwardEndpointSupportsExplicitSchemes(t *testing.T) { + network, address, err := parseForwardEndpoint("unix:///tmp/test-forward.sock") + if err != nil { + t.Fatalf("parseForwardEndpoint unix: %v", err) + } + if network != "unix" || address != "/tmp/test-forward.sock" { + t.Fatalf("unexpected unix endpoint parse: network=%q address=%q", network, address) + } + + network, address, err = parseForwardEndpoint("tcp6://[::1]:2222") + if err != nil { + t.Fatalf("parseForwardEndpoint tcp6: %v", err) + } + if network != "tcp6" || address != "[::1]:2222" { + t.Fatalf("unexpected tcp6 endpoint parse: network=%q address=%q", network, address) + } +} + func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) { oldDialSSHClient := dialSSHClient oldNewDetachedForwardClient := newDetachedForwardClient @@ -132,6 +248,424 @@ func TestStartLocalForwardDetachedUsesSeparateConnection(t *testing.T) { } } +func TestStartRemoteForwardSupportsUnixListenAndTCPTarget(t *testing.T) { + oldListenSSHClient := listenSSHClient + t.Cleanup(func() { + listenSSHClient = oldListenSSHClient + }) + + baseClient := &ssh.Client{} + star := &StarSSH{} + star.setTransport(baseClient, nil) + + listener := newStubListener(&net.UnixAddr{ + Name: "/run/user/0/gnupg/S.gpg-agent", + Net: "unix", + }) + + var listenedNetwork string + var listenedAddr string + listenSSHClient = func(client *ssh.Client, network, address string) (net.Listener, error) { + if client != baseClient { + t.Fatalf("unexpected ssh client %p", client) + } + listenedNetwork = network + listenedAddr = address + return listener, nil + } + + var targetNetwork string + var targetAddr string + forwarder, err := star.StartRemoteForward(ForwardRequest{ + ListenAddr: forwardEndpoint("unix", "/run/user/0/gnupg/S.gpg-agent"), + TargetAddr: "127.0.0.1:4321", + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + targetNetwork = network + targetAddr = address + serverConn, clientConn := net.Pipe() + go echoForwardPipe(serverConn) + return clientConn, nil + }, + }) + if err != nil { + t.Fatalf("start remote unix forward: %v", err) + } + defer forwarder.Close() + + srcPeer, forwardedConn := net.Pipe() + defer srcPeer.Close() + if err := listener.Push(forwardedConn); err != nil { + t.Fatalf("push forwarded connection: %v", err) + } + + payload := []byte("unix-forward") + done := make(chan []byte, 1) + go func() { + reply := make([]byte, len(payload)) + _, _ = io.ReadFull(srcPeer, reply) + done <- reply + }() + + if _, err := srcPeer.Write(payload); err != nil { + t.Fatalf("write source payload: %v", err) + } + + select { + case reply := <-done: + if string(reply) != string(payload) { + t.Fatalf("unexpected remote unix forward reply: %q", string(reply)) + } + case <-time.After(2 * time.Second): + t.Fatal("remote unix forward did not relay payload") + } + + if listenedNetwork != "unix" || listenedAddr != "/run/user/0/gnupg/S.gpg-agent" { + t.Fatalf("unexpected remote listen request: network=%q addr=%q", listenedNetwork, listenedAddr) + } + if targetNetwork != "tcp" || targetAddr != "127.0.0.1:4321" { + t.Fatalf("unexpected local dial target: network=%q addr=%q", targetNetwork, targetAddr) + } +} + +func TestStartLocalUnixForwardUsesUnixListenerAndTCPTarget(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("unix socket smoke test is exercised in WSL/Linux CI path") + } + + oldDialSSHClient := dialSSHClient + t.Cleanup(func() { + dialSSHClient = oldDialSSHClient + }) + + baseClient := &ssh.Client{} + star := &StarSSH{} + star.setTransport(baseClient, nil) + + var targetNetwork string + var targetAddr string + dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { + if client != baseClient { + t.Fatalf("unexpected ssh client %p", client) + } + targetNetwork = network + targetAddr = address + serverConn, clientConn := net.Pipe() + go echoForwardPipe(serverConn) + return clientConn, nil + } + + socketPath := filepath.Join(t.TempDir(), "forward.sock") + forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321") + if err != nil { + t.Fatalf("start local unix forward: %v", err) + } + defer func() { + closeErr := forwarder.Close() + if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) { + t.Fatalf("close local unix forward: %v", closeErr) + } + }() + + conn, err := net.DialTimeout("unix", socketPath, time.Second) + if err != nil { + t.Fatalf("dial unix forward listener: %v", err) + } + defer conn.Close() + _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) + + payload := []byte("unix-local-forward") + if _, err := conn.Write(payload); err != nil { + t.Fatalf("write unix forward payload: %v", err) + } + reply := make([]byte, len(payload)) + if _, err := io.ReadFull(conn, reply); err != nil { + t.Fatalf("read unix forward reply: %v", err) + } + if string(reply) != string(payload) { + t.Fatalf("unexpected unix forward reply: %q", string(reply)) + } + if targetNetwork != "tcp" || targetAddr != "127.0.0.1:4321" { + t.Fatalf("unexpected remote dial target: network=%q addr=%q", targetNetwork, targetAddr) + } +} + +func TestStartLocalUnixForwardRemovesSocketOnClose(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("unix socket smoke test is exercised in WSL/Linux CI path") + } + + oldDialSSHClient := dialSSHClient + t.Cleanup(func() { + dialSSHClient = oldDialSSHClient + }) + + baseClient := &ssh.Client{} + star := &StarSSH{} + star.setTransport(baseClient, nil) + + dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { + serverConn, clientConn := net.Pipe() + go echoForwardPipe(serverConn) + return clientConn, nil + } + + socketPath := filepath.Join(t.TempDir(), "cleanup.sock") + forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321") + if err != nil { + t.Fatalf("start local unix forward: %v", err) + } + + if _, err := os.Lstat(socketPath); err != nil { + t.Fatalf("socket should exist while forward is running: %v", err) + } + if err := forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + t.Fatalf("close local unix forward: %v", err) + } + if _, err := os.Lstat(socketPath); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("socket path should be removed on close, got err=%v", err) + } +} + +func TestStartLocalUnixForwardReusesStaleSocketPath(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("unix socket smoke test is exercised in WSL/Linux CI path") + } + + oldDialSSHClient := dialSSHClient + t.Cleanup(func() { + dialSSHClient = oldDialSSHClient + }) + + baseClient := &ssh.Client{} + star := &StarSSH{} + star.setTransport(baseClient, nil) + + dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { + serverConn, clientConn := net.Pipe() + go echoForwardPipe(serverConn) + return clientConn, nil + } + + socketPath := filepath.Join(t.TempDir(), "stale.sock") + staleListener, err := net.ListenUnix("unix", &net.UnixAddr{ + Name: socketPath, + Net: "unix", + }) + if err != nil { + t.Fatalf("create stale unix socket: %v", err) + } + staleListener.SetUnlinkOnClose(false) + if err := staleListener.Close(); err != nil { + t.Fatalf("close stale unix socket listener: %v", err) + } + if _, err := os.Lstat(socketPath); err != nil { + t.Fatalf("expected stale unix socket path to remain after close: %v", err) + } + + forwarder, err := star.StartLocalUnixForward(socketPath, "127.0.0.1:4321") + if err != nil { + t.Fatalf("start local unix forward on stale socket path: %v", err) + } + defer func() { + closeErr := forwarder.Close() + if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) { + t.Fatalf("close local unix forward: %v", closeErr) + } + }() + + reply := make([]byte, len("stale-reuse")) + conn, err := net.DialTimeout("unix", socketPath, time.Second) + if err != nil { + t.Fatalf("dial reused unix forward listener: %v", err) + } + defer conn.Close() + _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) + if _, err := conn.Write([]byte("stale-reuse")); err != nil { + t.Fatalf("write reused unix forward payload: %v", err) + } + if _, err := io.ReadFull(conn, reply); err != nil { + t.Fatalf("read reused unix forward reply: %v", err) + } + if string(reply) != "stale-reuse" { + t.Fatalf("unexpected reply on reused unix forward: %q", string(reply)) + } +} + +func TestStartLocalUnixToUnixForwardUsesUnixTarget(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("unix socket smoke test is exercised in WSL/Linux CI path") + } + + oldDialSSHClient := dialSSHClient + t.Cleanup(func() { + dialSSHClient = oldDialSSHClient + }) + + baseClient := &ssh.Client{} + star := &StarSSH{} + star.setTransport(baseClient, nil) + + targetSocketPath := filepath.Join(t.TempDir(), "target.sock") + targetListener, err := net.Listen("unix", targetSocketPath) + if err != nil { + t.Fatalf("listen target unix socket: %v", err) + } + defer targetListener.Close() + + done := make(chan []byte, 1) + go func() { + conn, acceptErr := targetListener.Accept() + if acceptErr != nil { + done <- nil + return + } + defer conn.Close() + buf := make([]byte, 64) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + done <- buf[:n] + }() + + dialRecordCh := make(chan dialRecord, 1) + dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { + if client != baseClient { + t.Fatalf("unexpected ssh client %p", client) + } + dialRecordCh <- dialRecord{network: network, addr: address} + var dialer net.Dialer + return dialer.DialContext(ctx, network, address) + } + + listenSocketPath := filepath.Join(t.TempDir(), "listen.sock") + forwarder, err := star.StartLocalUnixToUnixForward(listenSocketPath, targetSocketPath) + if err != nil { + t.Fatalf("start local unix-to-unix forward: %v", err) + } + defer func() { + closeErr := forwarder.Close() + if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) { + t.Fatalf("close local unix-to-unix forward: %v", closeErr) + } + }() + + conn, err := net.DialTimeout("unix", listenSocketPath, time.Second) + if err != nil { + t.Fatalf("dial unix-to-unix listener: %v", err) + } + defer conn.Close() + _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) + + payload := []byte("unix-to-unix") + if _, err := conn.Write(payload); err != nil { + t.Fatalf("write unix-to-unix payload: %v", err) + } + reply := make([]byte, len(payload)) + if _, err := io.ReadFull(conn, reply); err != nil { + t.Fatalf("read unix-to-unix reply: %v", err) + } + if string(reply) != string(payload) { + t.Fatalf("unexpected unix-to-unix reply: %q", string(reply)) + } + + select { + case got := <-done: + if string(got) != string(payload) { + t.Fatalf("unexpected payload seen by target unix socket: %q", string(got)) + } + case <-time.After(2 * time.Second): + t.Fatal("target unix socket did not receive forwarded payload") + } + + select { + case got := <-dialRecordCh: + if got.network != "unix" || got.addr != targetSocketPath { + t.Fatalf("unexpected unix target dial: network=%q addr=%q", got.network, got.addr) + } + case <-time.After(2 * time.Second): + t.Fatal("did not observe unix target dial") + } +} + +func TestStartLocalTCPToUnixForwardUsesUnixTarget(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("unix socket smoke test is exercised in WSL/Linux CI path") + } + + oldDialSSHClient := dialSSHClient + t.Cleanup(func() { + dialSSHClient = oldDialSSHClient + }) + + baseClient := &ssh.Client{} + star := &StarSSH{} + star.setTransport(baseClient, nil) + + targetSocketPath := filepath.Join(t.TempDir(), "target-tcp-to-unix.sock") + targetListener, err := net.Listen("unix", targetSocketPath) + if err != nil { + t.Fatalf("listen target unix socket: %v", err) + } + defer targetListener.Close() + + done := make(chan []byte, 1) + go func() { + conn, acceptErr := targetListener.Accept() + if acceptErr != nil { + done <- nil + return + } + defer conn.Close() + buf := make([]byte, 64) + n, _ := conn.Read(buf) + _, _ = conn.Write(buf[:n]) + done <- buf[:n] + }() + + dialRecordCh := make(chan dialRecord, 1) + dialSSHClient = func(ctx context.Context, client *ssh.Client, network, address string) (net.Conn, error) { + if client != baseClient { + t.Fatalf("unexpected ssh client %p", client) + } + dialRecordCh <- dialRecord{network: network, addr: address} + var dialer net.Dialer + return dialer.DialContext(ctx, network, address) + } + + forwarder, err := star.StartLocalTCPToUnixForward("127.0.0.1:0", targetSocketPath) + if err != nil { + t.Fatalf("start local tcp-to-unix forward: %v", err) + } + defer func() { + closeErr := forwarder.Close() + if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) { + t.Fatalf("close local tcp-to-unix forward: %v", closeErr) + } + }() + + reply := exerciseForwarder(t, forwarder.Addr().String(), []byte("tcp-to-unix")) + if string(reply) != "tcp-to-unix" { + t.Fatalf("unexpected tcp-to-unix reply: %q", string(reply)) + } + + select { + case got := <-done: + if string(got) != "tcp-to-unix" { + t.Fatalf("unexpected payload seen by unix target: %q", string(got)) + } + case <-time.After(2 * time.Second): + t.Fatal("unix target did not receive forwarded tcp payload") + } + + select { + case got := <-dialRecordCh: + if got.network != "unix" || got.addr != targetSocketPath { + t.Fatalf("unexpected unix target dial: network=%q addr=%q", got.network, got.addr) + } + case <-time.After(2 * time.Second): + t.Fatal("did not observe unix target dial") + } +} + func echoForwardPipe(conn net.Conn) { defer conn.Close() buf := make([]byte, 4096) diff --git a/session.go b/session.go index 769b6bf..0114c8c 100644 --- a/session.go +++ b/session.go @@ -9,6 +9,14 @@ import ( "golang.org/x/crypto/ssh" ) +var newSSHSession = func(client *ssh.Client) (*ssh.Session, error) { + return client.NewSession() +} + +var requestSessionPTY = func(session *ssh.Session, config TerminalConfig) error { + return session.RequestPty(config.Term, config.Rows, config.Columns, config.Modes) +} + func (s *StarSSH) Close() error { return s.closeTransport(true) } @@ -22,7 +30,15 @@ func (s *StarSSH) NewExecSession() (*ssh.Session, error) { if err != nil { return nil, err } - return NewExecSession(client) + session, err := NewExecSession(client) + if err != nil { + return nil, err + } + if err := s.maybeRequestAgentForwarding(session); err != nil { + _ = session.Close() + return nil, err + } + return session, nil } func (s *StarSSH) NewPTYSession(config *TerminalConfig) (*ssh.Session, error) { @@ -30,7 +46,15 @@ func (s *StarSSH) NewPTYSession(config *TerminalConfig) (*ssh.Session, error) { if err != nil { return nil, err } - return NewPTYSession(client, config) + session, err := NewPTYSession(client, config) + if err != nil { + return nil, err + } + if err := s.maybeRequestAgentForwarding(session); err != nil { + _ = session.Close() + return nil, err + } + return session, nil } func NewTransferSession(client *ssh.Client) (*ssh.Session, error) { @@ -41,7 +65,7 @@ func NewExecSession(client *ssh.Client) (*ssh.Session, error) { if client == nil { return nil, errors.New("ssh client is nil") } - return client.NewSession() + return newSSHSession(client) } func NewSession(client *ssh.Client) (*ssh.Session, error) { @@ -53,13 +77,13 @@ func NewPTYSession(client *ssh.Client, config *TerminalConfig) (*ssh.Session, er return nil, errors.New("ssh client is nil") } - session, err := client.NewSession() + session, err := newSSHSession(client) if err != nil { return nil, err } cfg := normalizeTerminalConfig(config) - if err := session.RequestPty(cfg.Term, cfg.Rows, cfg.Columns, cfg.Modes); err != nil { + if err := requestSessionPTY(session, cfg); err != nil { _ = session.Close() return nil, err } diff --git a/state.go b/state.go index c05afd4..9f5afaf 100644 --- a/state.go +++ b/state.go @@ -6,6 +6,8 @@ import ( "golang.org/x/crypto/ssh" ) +var errSSHClientClosing = errors.New("ssh client is closing") + type sshClientRequester interface { SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) Close() error @@ -29,10 +31,19 @@ func (s *StarSSH) snapshotSSHClient() *ssh.Client { } func (s *StarSSH) requireSSHClient() (*ssh.Client, error) { + if s == nil { + return nil, errors.New("ssh client is nil") + } + if s.closing.Load() { + return nil, errSSHClientClosing + } client := s.snapshotSSHClient() if client == nil { return nil, errors.New("ssh client is nil") } + if s.closing.Load() { + return nil, errSSHClientClosing + } return client, nil } @@ -46,6 +57,7 @@ func (s *StarSSH) setTransport(client *ssh.Client, upstream *StarSSH) { s.Client = client s.upstream = upstream s.online = client != nil + s.closing.Store(false) } func (s *StarSSH) detachTransport() (*ssh.Client, *StarSSH) { @@ -84,7 +96,9 @@ func (s *StarSSH) closeTransport(waitKeepalive bool) error { return nil } + s.closing.Store(true) _ = s.closeReusableSFTPClient() + agentForwarder := s.takeAgentForwarder() client, upstream := s.detachTransport() stop, done := s.takeKeepaliveHandles() @@ -93,8 +107,13 @@ func (s *StarSSH) closeTransport(waitKeepalive bool) error { } var closeErr error + if agentForwarder != nil { + closeErr = normalizeAlreadyClosedError(agentForwarder.Close()) + } if client != nil { - closeErr = normalizeAlreadyClosedError(closeSSHClient(client)) + if err := normalizeAlreadyClosedError(closeSSHClient(client)); closeErr == nil { + closeErr = err + } } if waitKeepalive && done != nil { <-done @@ -104,3 +123,13 @@ func (s *StarSSH) closeTransport(waitKeepalive bool) error { } return closeErr } + +func (s *StarSSH) canAttachAgentForwarder(client *ssh.Client) bool { + if s == nil || client == nil || s.closing.Load() { + return false + } + + s.stateMu.RLock() + defer s.stateMu.RUnlock() + return !s.closing.Load() && s.Client == client +} diff --git a/types.go b/types.go index 7e2fe35..27d9028 100644 --- a/types.go +++ b/types.go @@ -6,6 +6,7 @@ import ( "io" "net" "sync" + "sync/atomic" "time" "github.com/pkg/sftp" @@ -58,21 +59,24 @@ const ( ) type StarSSH struct { - stateMu sync.RWMutex - Client *ssh.Client - PublicKey ssh.PublicKey - PubkeyBase64 string - Hostname string - RemoteAddr net.Addr - Banner string - LoginInfo LoginInput - online bool - upstream *StarSSH - sftpClient *sftp.Client - sftpMu sync.Mutex - keepaliveMu sync.Mutex - keepaliveStop chan struct{} - keepaliveDone chan struct{} + stateMu sync.RWMutex + Client *ssh.Client + PublicKey ssh.PublicKey + PubkeyBase64 string + Hostname string + RemoteAddr net.Addr + Banner string + LoginInfo LoginInput + online bool + upstream *StarSSH + sftpClient *sftp.Client + sftpMu sync.Mutex + agentForwardMu sync.Mutex + agentForwarder io.Closer + keepaliveMu sync.Mutex + keepaliveStop chan struct{} + keepaliveDone chan struct{} + closing atomic.Bool } type LoginInput struct { @@ -86,6 +90,7 @@ type LoginInput struct { Prikey string PrikeyPwd string DisableSSHAgent bool + ForwardSSHAgent bool AuthOrder []AuthMethodKind Addr string Port int