package starssh import ( "errors" "fmt" "io" "net" "strings" "sync" "time" "golang.org/x/crypto/ssh" sshagent "golang.org/x/crypto/ssh/agent" ) var requestSSHAgentForwarding = func(session *ssh.Session) error { return sshagent.RequestAgentForwarding(session) } const sshAgentChannelType = "auth-agent@openssh.com" var routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) { return startSSHAgentForwardProxy(client, timeout) } var probeSSHAgentForwarding = func(timeout time.Duration) error { conn, err := dialSSHAgent(timeout) if err != nil { return wrapSSHAgentForwardingUnavailable(err) } if conn == nil { return wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection")) } return conn.Close() } var errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied") var errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable") type sshAgentForwardProxy struct { stopOnce sync.Once stopCh chan struct{} activeMu sync.Mutex active map[*sshAgentForwardBridge]struct{} } func (p *sshAgentForwardProxy) Close() error { if p == nil { return nil } p.stopOnce.Do(func() { close(p.stopCh) }) p.closeActive() return nil } type sshAgentForwardBridge struct { proxy *sshAgentForwardProxy channel ssh.Channel conn net.Conn closeOnce sync.Once } func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error { if s == nil { return errors.New("ssh client is nil") } if session == nil { return errors.New("ssh session is nil") } if err := s.ensureAgentForwarding(); err != nil { return err } if err := requestSSHAgentForwarding(session); err != nil { if isSSHAgentForwardingDeniedError(err) { return fmt.Errorf("%w: %v", errSSHAgentForwardingDenied, err) } return err } return nil } func (s *StarSSH) maybeRequestAgentForwarding(session *ssh.Session) error { if s == nil || !s.LoginInfo.ForwardSSHAgent { return nil } err := s.RequestAgentForwarding(session) if isSSHAgentForwardingDeniedError(err) || isSSHAgentForwardingUnavailableError(err) { return nil } return err } func (s *StarSSH) ensureAgentForwarding() error { if s == nil { return errors.New("ssh client is nil") } s.agentForwardMu.Lock() defer s.agentForwardMu.Unlock() if s.agentForwarder != nil { return nil } client, err := s.requireSSHClient() if err != nil { return err } timeout := effectiveDialTimeout(s.LoginInfo) if err := probeSSHAgentForwarding(timeout); err != nil { return wrapSSHAgentForwardingUnavailable(err) } if s.closing.Load() { return errSSHClientClosing } closer, err := routeSSHAgentForwarding(client, timeout) if err != nil { return err } if !s.canAttachAgentForwarder(client) { if closer != nil { _ = closer.Close() } return errSSHClientClosing } s.agentForwarder = closer return nil } func (s *StarSSH) takeAgentForwarder() io.Closer { if s == nil { return nil } s.agentForwardMu.Lock() defer s.agentForwardMu.Unlock() closer := s.agentForwarder s.agentForwarder = nil return closer } func isSSHAgentForwardingDeniedError(err error) bool { if err == nil { return false } if errors.Is(err, errSSHAgentForwardingDenied) { return true } message := strings.ToLower(err.Error()) return strings.Contains(message, "forwarding request denied") || strings.Contains(message, "agent forwarding disabled") } func isSSHAgentForwardingUnavailableError(err error) bool { if err == nil { return false } if errors.Is(err, errSSHAgentForwardingUnavailable) { return true } message := strings.ToLower(err.Error()) return strings.Contains(message, "ssh-agent forwarding unavailable") || strings.Contains(message, "ssh-agent unavailable") } func wrapSSHAgentForwardingUnavailable(err error) error { if err == nil { return nil } if errors.Is(err, errSSHAgentForwardingUnavailable) { return err } if errors.Is(err, errSSHAgentUnavailable) { return fmt.Errorf("%w: %w", errSSHAgentForwardingUnavailable, err) } return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err) } func startSSHAgentForwardProxy(client *ssh.Client, timeout time.Duration) (io.Closer, error) { if client == nil { return nil, errors.New("ssh client is nil") } channels := client.HandleChannelOpen(sshAgentChannelType) if channels == nil { return nil, errors.New("agent: already have handler for " + sshAgentChannelType) } proxy := &sshAgentForwardProxy{ stopCh: make(chan struct{}), active: make(map[*sshAgentForwardBridge]struct{}), } go func() { for { select { case <-proxy.stopCh: return case ch, ok := <-channels: if !ok { return } go handleSSHAgentForwardChannel(proxy, ch, timeout) } } }() return proxy, nil } func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeout time.Duration) { if ch == nil { return } conn, err := dialSSHAgent(timeout) if err != nil { _ = ch.Reject(ssh.ConnectionFailed, err.Error()) return } if conn == nil { _ = ch.Reject(ssh.ConnectionFailed, "ssh-agent connection unavailable") return } channel, reqs, err := ch.Accept() if err != nil { _ = conn.Close() return } go ssh.DiscardRequests(reqs) bridge := &sshAgentForwardBridge{ proxy: proxy, channel: channel, conn: conn, } if !proxy.registerBridge(bridge) { bridge.close() return } go bridge.run() } func proxySSHAgentChannel(channel ssh.Channel, conn net.Conn) { bridge := &sshAgentForwardBridge{ channel: channel, conn: conn, } bridge.run() } func (b *sshAgentForwardBridge) run() { if b == nil { return } defer b.unregister() var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() _, _ = io.Copy(b.channel, b.conn) b.close() }() go func() { defer wg.Done() _, _ = io.Copy(b.conn, b.channel) b.close() }() wg.Wait() } func (b *sshAgentForwardBridge) close() { if b == nil { return } b.closeOnce.Do(func() { closeWriter(b.channel) closeWriter(b.conn) if b.channel != nil { _ = b.channel.Close() } if b.conn != nil { _ = b.conn.Close() } }) } func (b *sshAgentForwardBridge) unregister() { if b == nil || b.proxy == nil { return } b.proxy.unregisterBridge(b) } func (p *sshAgentForwardProxy) registerBridge(bridge *sshAgentForwardBridge) bool { if p == nil || bridge == nil { return false } p.activeMu.Lock() defer p.activeMu.Unlock() select { case <-p.stopCh: return false default: } if p.active == nil { p.active = make(map[*sshAgentForwardBridge]struct{}) } p.active[bridge] = struct{}{} return true } func (p *sshAgentForwardProxy) unregisterBridge(bridge *sshAgentForwardBridge) { if p == nil || bridge == nil { return } p.activeMu.Lock() defer p.activeMu.Unlock() delete(p.active, bridge) } func (p *sshAgentForwardProxy) closeActive() { if p == nil { return } p.activeMu.Lock() active := make([]*sshAgentForwardBridge, 0, len(p.active)) for bridge := range p.active { active = append(active, bridge) } p.active = make(map[*sshAgentForwardBridge]struct{}) p.activeMu.Unlock() for _, bridge := range active { bridge.close() } } func closeWriter(value any) { type closeWriter interface { CloseWrite() error } if cw, ok := value.(closeWriter); ok { _ = cw.CloseWrite() } }