package starssh import ( "errors" "fmt" "io" "net" "strings" "sync" "time" "golang.org/x/crypto/ssh" sshagent "golang.org/x/crypto/ssh/agent" ) var requestSSHAgentForwarding = func(session *ssh.Session) error { return sshagent.RequestAgentForwarding(session) } const sshAgentChannelType = "auth-agent@openssh.com" var routeSSHAgentForwarding = func(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) { return startSSHAgentForwardProxy(client, timeouts) } var probeSSHAgentForwarding = func(timeouts sshAgentTimeouts) error { conn, _, err := dialSSHAgentWithDebug("forward-probe", timeouts) if err != nil { return wrapSSHAgentForwardingUnavailable(err) } if conn == nil { return wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection")) } return conn.Close() } var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied") var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable") type sshAgentForwardProxy struct { stopOnce sync.Once stopCh chan struct{} activeMu sync.Mutex active map[*sshAgentForwardBridge]struct{} } func (p *sshAgentForwardProxy) Close() error { if p == nil { return nil } p.stopOnce.Do(func() { close(p.stopCh) }) p.closeActive() return nil } type sshAgentForwardBridge struct { proxy *sshAgentForwardProxy channel ssh.Channel conn net.Conn idleTimeout time.Duration closeOnce sync.Once signalOnce sync.Once done chan struct{} activity chan struct{} } func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error { if s == nil { return errors.New("ssh client is nil") } if session == nil { return errors.New("ssh session is nil") } if err := s.ensureAgentForwarding(); err != nil { return err } if err := requestSSHAgentForwarding(session); err != nil { if isSSHAgentForwardingDeniedError(err) { return fmt.Errorf("%w: %v", errSSHAgentForwardingDenied, err) } return err } return nil } func (s *StarSSH) maybeRequestAgentForwarding(session *ssh.Session) error { if s == nil || !s.LoginInfo.ForwardSSHAgent { return nil } err := s.RequestAgentForwarding(session) if isSSHAgentForwardingDeniedError(err) || isSSHAgentForwardingUnavailableError(err) { return nil } return err } func (s *StarSSH) ensureAgentForwarding() error { if s == nil { return errors.New("ssh client is nil") } s.agentForwardMu.Lock() defer s.agentForwardMu.Unlock() if s.agentForwarder != nil { return nil } client, err := s.requireSSHClient() if err != nil { return err } timeouts := effectiveSSHAgentTimeouts(s.LoginInfo) if err := probeSSHAgentForwarding(timeouts); err != nil { return wrapSSHAgentForwardingUnavailable(err) } if s.closing.Load() { return errSSHClientClosing } closer, err := routeSSHAgentForwarding(client, timeouts) if err != nil { return err } if !s.canAttachAgentForwarder(client) { if closer != nil { _ = closer.Close() } return errSSHClientClosing } s.agentForwarder = closer return nil } func (s *StarSSH) takeAgentForwarder() io.Closer { if s == nil { return nil } s.agentForwardMu.Lock() defer s.agentForwardMu.Unlock() closer := s.agentForwarder s.agentForwarder = nil return closer } func isSSHAgentForwardingDeniedError(err error) bool { if err == nil { return false } if errors.Is(err, errSSHAgentForwardingDenied) { return true } message := strings.ToLower(err.Error()) return strings.Contains(message, "forwarding request denied") || strings.Contains(message, "agent forwarding disabled") } func isSSHAgentForwardingUnavailableError(err error) bool { if err == nil { return false } if errors.Is(err, errSSHAgentForwardingUnavailable) { return true } message := strings.ToLower(err.Error()) return strings.Contains(message, "ssh-agent forwarding unavailable") || strings.Contains(message, "ssh-agent unavailable") } func wrapSSHAgentForwardingUnavailable(err error) error { if err == nil { return nil } if errors.Is(err, errSSHAgentForwardingUnavailable) { return err } if errors.Is(err, errSSHAgentUnavailable) { return fmt.Errorf("%w: %w", errSSHAgentForwardingUnavailable, err) } return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err) } func startSSHAgentForwardProxy(client *ssh.Client, timeouts sshAgentTimeouts) (io.Closer, error) { if client == nil { return nil, errors.New("ssh client is nil") } channels := client.HandleChannelOpen(sshAgentChannelType) if channels == nil { return nil, errors.New("agent: already have handler for " + sshAgentChannelType) } proxy := &sshAgentForwardProxy{ stopCh: make(chan struct{}), active: make(map[*sshAgentForwardBridge]struct{}), } go func() { for { select { case <-proxy.stopCh: return case ch, ok := <-channels: if !ok { return } go handleSSHAgentForwardChannel(proxy, ch, timeouts) } } }() return proxy, nil } func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeouts sshAgentTimeouts) { if ch == nil { return } conn, _, err := dialSSHAgentWithDebug("forward-channel", timeouts) if err != nil { _ = ch.Reject(ssh.ConnectionFailed, err.Error()) return } if conn == nil { _ = ch.Reject(ssh.ConnectionFailed, "ssh-agent connection unavailable") return } channel, reqs, err := ch.Accept() if err != nil { _ = conn.Close() return } go ssh.DiscardRequests(reqs) bridge := &sshAgentForwardBridge{ proxy: proxy, channel: channel, conn: conn, idleTimeout: timeouts.Forward, } if !proxy.registerBridge(bridge) { bridge.close() return } go bridge.run() } func proxySSHAgentChannel(channel ssh.Channel, conn net.Conn) { bridge := &sshAgentForwardBridge{ channel: channel, conn: conn, } bridge.run() } func (b *sshAgentForwardBridge) run() { if b == nil { return } b.ensureSignals() stopWatchdog := b.startIdleWatchdog() defer stopWatchdog() defer b.unregister() var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() _, _ = io.Copy( sshAgentForwardActivityWriter{Writer: b.channel, touch: b.touch}, sshAgentForwardActivityReader{Reader: b.conn, touch: b.touch}, ) b.close() }() go func() { defer wg.Done() _, _ = io.Copy( sshAgentForwardActivityWriter{Writer: b.conn, touch: b.touch}, sshAgentForwardActivityReader{Reader: b.channel, touch: b.touch}, ) b.close() }() wg.Wait() } func (b *sshAgentForwardBridge) close() { if b == nil { return } b.closeOnce.Do(func() { b.ensureSignals() close(b.done) closeWriter(b.channel) closeWriter(b.conn) if b.channel != nil { _ = b.channel.Close() } if b.conn != nil { _ = b.conn.Close() } }) } func (b *sshAgentForwardBridge) ensureSignals() { if b == nil { return } b.signalOnce.Do(func() { b.done = make(chan struct{}) b.activity = make(chan struct{}, 1) }) } func (b *sshAgentForwardBridge) startIdleWatchdog() func() { if b == nil || b.idleTimeout <= 0 { return func() {} } b.ensureSignals() timer := time.NewTimer(b.idleTimeout) stopped := make(chan struct{}) go func() { defer timer.Stop() for { select { case <-timer.C: b.close() return case <-b.activity: resetTimer(timer, b.idleTimeout) case <-b.done: return case <-stopped: return } } }() return func() { close(stopped) } } func (b *sshAgentForwardBridge) touch() { if b == nil || b.idleTimeout <= 0 || b.activity == nil { return } select { case b.activity <- struct{}{}: default: } } type sshAgentForwardActivityReader struct { io.Reader touch func() } func (r sshAgentForwardActivityReader) Read(p []byte) (int, error) { n, err := r.Reader.Read(p) if n > 0 && r.touch != nil { r.touch() } return n, err } type sshAgentForwardActivityWriter struct { io.Writer touch func() } func (w sshAgentForwardActivityWriter) Write(p []byte) (int, error) { n, err := w.Writer.Write(p) if n > 0 && w.touch != nil { w.touch() } return n, err } func resetTimer(timer *time.Timer, timeout time.Duration) { if !timer.Stop() { select { case <-timer.C: default: } } timer.Reset(timeout) } func (b *sshAgentForwardBridge) unregister() { if b == nil || b.proxy == nil { return } b.proxy.unregisterBridge(b) } func (p *sshAgentForwardProxy) registerBridge(bridge *sshAgentForwardBridge) bool { if p == nil || bridge == nil { return false } p.activeMu.Lock() defer p.activeMu.Unlock() select { case <-p.stopCh: return false default: } if p.active == nil { p.active = make(map[*sshAgentForwardBridge]struct{}) } p.active[bridge] = struct{}{} return true } func (p *sshAgentForwardProxy) unregisterBridge(bridge *sshAgentForwardBridge) { if p == nil || bridge == nil { return } p.activeMu.Lock() defer p.activeMu.Unlock() delete(p.active, bridge) } func (p *sshAgentForwardProxy) closeActive() { if p == nil { return } p.activeMu.Lock() active := make([]*sshAgentForwardBridge, 0, len(p.active)) for bridge := range p.active { active = append(active, bridge) } p.active = make(map[*sshAgentForwardBridge]struct{}) p.activeMu.Unlock() for _, bridge := range active { bridge.close() } } func closeWriter(value any) { type closeWriter interface { CloseWrite() error } if cw, ok := value.(closeWriter); ok { _ = cw.CloseWrite() } }