152 lines
3.6 KiB
Go
152 lines
3.6 KiB
Go
|
|
package starssh
|
||
|
|
|
||
|
|
import (
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"io"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"golang.org/x/crypto/ssh"
|
||
|
|
sshagent "golang.org/x/crypto/ssh/agent"
|
||
|
|
)
|
||
|
|
|
||
|
|
var requestSSHAgentForwarding = func(session *ssh.Session) error {
|
||
|
|
return sshagent.RequestAgentForwarding(session)
|
||
|
|
}
|
||
|
|
|
||
|
|
var routeSSHAgentForwarding = func(client *ssh.Client, keyring sshagent.Agent) error {
|
||
|
|
return sshagent.ForwardToAgent(client, keyring)
|
||
|
|
}
|
||
|
|
|
||
|
|
var newSSHAgentForwarder = func(timeout time.Duration) (sshagent.Agent, io.Closer, error) {
|
||
|
|
conn, err := dialSSHAgent(timeout)
|
||
|
|
if err != nil {
|
||
|
|
return nil, nil, wrapSSHAgentForwardingUnavailable(err)
|
||
|
|
}
|
||
|
|
if conn == nil {
|
||
|
|
return nil, nil, wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
|
||
|
|
}
|
||
|
|
return sshagent.NewClient(conn), conn, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
var (
	// errSSHAgentForwardingDenied is wrapped around a forwarding request
	// failure that isSSHAgentForwardingDeniedError classifies as a denial.
	errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied")

	// errSSHAgentForwardingUnavailable is wrapped around failures to set up
	// forwarding locally, e.g. when no agent connection could be obtained.
	errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable")
)
|
||
|
|
|
||
|
|
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
|
||
|
|
if s == nil {
|
||
|
|
return errors.New("ssh client is nil")
|
||
|
|
}
|
||
|
|
if session == nil {
|
||
|
|
return errors.New("ssh session is nil")
|
||
|
|
}
|
||
|
|
if err := s.ensureAgentForwarding(); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if err := requestSSHAgentForwarding(session); err != nil {
|
||
|
|
if isSSHAgentForwardingDeniedError(err) {
|
||
|
|
return fmt.Errorf("%w: %v", errSSHAgentForwardingDenied, err)
|
||
|
|
}
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *StarSSH) maybeRequestAgentForwarding(session *ssh.Session) error {
|
||
|
|
if s == nil || !s.LoginInfo.ForwardSSHAgent {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
err := s.RequestAgentForwarding(session)
|
||
|
|
if isSSHAgentForwardingDeniedError(err) || isSSHAgentForwardingUnavailableError(err) {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
// ensureAgentForwarding attaches an ssh-agent forwarder to the SSH client held
// by s, unless one is already attached.
//
// While holding agentForwardMu it:
//  1. returns nil if s.agentForwarder is already set;
//  2. obtains the client via requireSSHClient;
//  3. opens an agent connection via newSSHAgentForwarder, with errors wrapped
//     as errSSHAgentForwardingUnavailable;
//  4. passes client and the agent to routeSSHAgentForwarding;
//  5. stores the agent connection's closer in s.agentForwarder.
//
// If s is closing (checked before step 4), or canAttachAgentForwarder rejects
// the client (checked after step 4), the agent connection is closed and
// errSSHClientClosing is returned. A routing error also closes the connection
// and is returned unchanged.
func (s *StarSSH) ensureAgentForwarding() error {
	if s == nil {
		return errors.New("ssh client is nil")
	}

	s.agentForwardMu.Lock()
	defer s.agentForwardMu.Unlock()

	// A forwarder is already attached; nothing to do.
	if s.agentForwarder != nil {
		return nil
	}

	client, err := s.requireSSHClient()
	if err != nil {
		return err
	}

	keyring, closer, err := newSSHAgentForwarder(s.LoginInfo.Timeout)
	if err != nil {
		return wrapSSHAgentForwardingUnavailable(err)
	}
	// Stop before touching the client if shutdown has already begun. The close
	// error is ignored on purpose: this is best-effort cleanup.
	if s.closing.Load() {
		_ = closer.Close()
		return errSSHClientClosing
	}
	if err := routeSSHAgentForwarding(client, keyring); err != nil {
		_ = closer.Close()
		return err
	}
	// Re-check after routing. NOTE(review): presumably this detects a client
	// that began closing (or was replaced) meanwhile; confirm the semantics
	// of canAttachAgentForwarder.
	if !s.canAttachAgentForwarder(client) {
		_ = closer.Close()
		return errSSHClientClosing
	}
	s.agentForwarder = closer
	return nil
}
|
||
|
|
|
||
|
|
func (s *StarSSH) takeAgentForwarder() io.Closer {
|
||
|
|
if s == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
s.agentForwardMu.Lock()
|
||
|
|
defer s.agentForwardMu.Unlock()
|
||
|
|
|
||
|
|
closer := s.agentForwarder
|
||
|
|
s.agentForwarder = nil
|
||
|
|
return closer
|
||
|
|
}
|
||
|
|
|
||
|
|
func isSSHAgentForwardingDeniedError(err error) bool {
|
||
|
|
if err == nil {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
if errors.Is(err, errSSHAgentForwardingDenied) {
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
message := strings.ToLower(err.Error())
|
||
|
|
return strings.Contains(message, "forwarding request denied") ||
|
||
|
|
strings.Contains(message, "agent forwarding disabled")
|
||
|
|
}
|
||
|
|
|
||
|
|
func isSSHAgentForwardingUnavailableError(err error) bool {
|
||
|
|
if err == nil {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
if errors.Is(err, errSSHAgentForwardingUnavailable) {
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
message := strings.ToLower(err.Error())
|
||
|
|
return strings.Contains(message, "ssh-agent forwarding unavailable") ||
|
||
|
|
strings.Contains(message, "ssh-agent unavailable")
|
||
|
|
}
|
||
|
|
|
||
|
|
func wrapSSHAgentForwardingUnavailable(err error) error {
|
||
|
|
if err == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
if errors.Is(err, errSSHAgentForwardingUnavailable) {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if errors.Is(err, errSSHAgentUnavailable) {
|
||
|
|
return fmt.Errorf("%w: %w", errSSHAgentForwardingUnavailable, err)
|
||
|
|
}
|
||
|
|
return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err)
|
||
|
|
}
|