package starssh

import (
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"golang.org/x/crypto/ssh"
	sshagent "golang.org/x/crypto/ssh/agent"
)

// Package-level indirections over the ssh-agent primitives used below.
// NOTE(review): presumably kept as variables so they can be substituted
// (e.g. in tests) — confirm against the callers.
var (
	// requestSSHAgentForwarding asks the remote side to enable agent
	// forwarding for a single session.
	requestSSHAgentForwarding = sshagent.RequestAgentForwarding

	// routeSSHAgentForwarding registers keyring as the agent that serves
	// forwarded agent channels opened on client.
	routeSSHAgentForwarding = sshagent.ForwardToAgent

	// newSSHAgentForwarder dials the local ssh-agent and returns an agent
	// client together with the closer for the underlying connection. Any
	// dial failure, or a nil connection, is reported as an
	// errSSHAgentForwardingUnavailable error.
	newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
		conn, err := dialSSHAgent(timeout)
		switch {
		case err != nil:
			return nil, nil, wrapSSHAgentForwardingUnavailable(err)
		case conn == nil:
			return nil, nil, wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
		}
		return sshagent.NewClient(conn), conn, nil
	}
)

// Sentinel errors describing why agent forwarding could not be set up.
var (
	errSSHAgentForwardingDenied      = errors.New("ssh-agent forwarding request denied")
	errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable")
)

// RequestAgentForwarding makes sure the client has an agent forwarder attached
// and then requests agent forwarding for session.
//
// It returns an error when the receiver or session is nil, when the forwarder
// cannot be set up, or when the request itself fails. A request that was
// denied is returned wrapped in errSSHAgentForwardingDenied.
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
	switch {
	case s == nil:
		return errors.New("ssh client is nil")
	case session == nil:
		return errors.New("ssh session is nil")
	}

	if err := s.ensureAgentForwarding(); err != nil {
		return err
	}

	reqErr := requestSSHAgentForwarding(session)
	if reqErr == nil {
		return nil
	}
	if isSSHAgentForwardingDeniedError(reqErr) {
		return fmt.Errorf("%w: %v", errSSHAgentForwardingDenied, reqErr)
	}
	return reqErr
}

// maybeRequestAgentForwarding requests agent forwarding only when
// LoginInfo.ForwardSSHAgent is set. "Denied" and "unavailable" outcomes are
// swallowed so that forwarding stays best-effort; any other error is returned.
func (s *StarSSH) maybeRequestAgentForwarding(session *ssh.Session) error {
	if s == nil || !s.LoginInfo.ForwardSSHAgent {
		return nil
	}

	switch err := s.RequestAgentForwarding(session); {
	case isSSHAgentForwardingDeniedError(err), isSSHAgentForwardingUnavailableError(err):
		return nil
	default:
		return err
	}
}

// ensureAgentForwarding attaches an agent forwarder to the client if none is
// attached yet. The whole operation runs under agentForwardMu.
//
// The freshly dialed agent connection is closed again on every failure path:
// when the client is closing, when routing to the agent fails, or when the
// forwarder can no longer be attached to client.
func (s *StarSSH) ensureAgentForwarding() error {
	if s == nil {
		return errors.New("ssh client is nil")
	}

	s.agentForwardMu.Lock()
	defer s.agentForwardMu.Unlock()

	// Already attached: nothing to do.
	if s.agentForwarder != nil {
		return nil
	}

	client, err := s.requireSSHClient()
	if err != nil {
		return err
	}

	keyring, forwarder, err := newSSHAgentForwarder(s.LoginInfo.Timeout)
	if err != nil {
		return wrapSSHAgentForwardingUnavailable(err)
	}

	// release drops the agent connection; a close error is deliberately
	// ignored because the caller is already returning a more relevant error.
	release := func() { _ = forwarder.Close() }

	if s.closing.Load() {
		release()
		return errSSHClientClosing
	}

	if routeErr := routeSSHAgentForwarding(client, keyring); routeErr != nil {
		release()
		return routeErr
	}

	if !s.canAttachAgentForwarder(client) {
		release()
		return errSSHClientClosing
	}

	s.agentForwarder = forwarder
	return nil
}

// takeAgentForwarder detaches and returns the current agent forwarder, leaving
// the client without one. It returns nil for a nil receiver.
func (s *StarSSH) takeAgentForwarder() io.Closer {
	if s == nil {
		return nil
	}

	s.agentForwardMu.Lock()
	defer s.agentForwardMu.Unlock()

	taken := s.agentForwarder
	s.agentForwarder = nil
	return taken
}

// isSSHAgentForwardingDeniedError reports whether err is, wraps, or textually
// looks like a denied agent-forwarding request. The text match is
// case-insensitive.
func isSSHAgentForwardingDeniedError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, errSSHAgentForwardingDenied):
		return true
	}

	text := strings.ToLower(err.Error())
	for _, marker := range []string{"forwarding request denied", "agent forwarding disabled"} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}

// isSSHAgentForwardingUnavailableError reports whether err is, wraps, or
// textually looks like an "agent forwarding unavailable" condition. The text
// match is case-insensitive.
func isSSHAgentForwardingUnavailableError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, errSSHAgentForwardingUnavailable):
		return true
	}

	text := strings.ToLower(err.Error())
	for _, marker := range []string{"ssh-agent forwarding unavailable", "ssh-agent unavailable"} {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}

// wrapSSHAgentForwardingUnavailable marks err as errSSHAgentForwardingUnavailable.
//
// A nil error stays nil and an error that already carries the sentinel is
// returned unchanged. An errSSHAgentUnavailable error is wrapped with %w so it
// remains matchable via errors.Is; any other error is only embedded as text.
func wrapSSHAgentForwardingUnavailable(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, errSSHAgentForwardingUnavailable):
		return err
	case errors.Is(err, errSSHAgentUnavailable):
		return fmt.Errorf("%w: %w", errSSHAgentForwardingUnavailable, err)
	default:
		return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err)
	}
}