starssh/agent_forward.go

356 lines
7.2 KiB
Go
Raw Permalink Normal View History

package starssh
import (
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
sshagent "golang.org/x/crypto/ssh/agent"
)
var requestSSHAgentForwarding = func(session *ssh.Session) error {
return sshagent.RequestAgentForwarding(session)
}
// sshAgentChannelType is the SSH channel type the server uses when it opens a
// channel back to this client for agent forwarding; startSSHAgentForwardProxy
// registers a client handler for exactly this type.
const sshAgentChannelType = "auth-agent@openssh.com"
var routeSSHAgentForwarding = func(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
return startSSHAgentForwardProxy(client, timeout)
}
// probeSSHAgentForwarding checks that a local ssh-agent can be dialed within
// timeout by opening and immediately closing a connection. Dial failures and a
// nil connection are reported wrapped in errSSHAgentForwardingUnavailable; on
// success the result of closing the probe connection is returned.
var probeSSHAgentForwarding = func(timeout time.Duration) error {
	agentConn, dialErr := dialSSHAgent(timeout)
	switch {
	case dialErr != nil:
		return wrapSSHAgentForwardingUnavailable(dialErr)
	case agentConn == nil:
		return wrapSSHAgentForwardingUnavailable(errors.New("empty agent connection"))
	default:
		return agentConn.Close()
	}
}
// Sentinel errors for agent forwarding; compare against them with errors.Is.
var (
	// errSSHAgentForwardingDenied wraps a refusal of the forwarding request.
	errSSHAgentForwardingDenied = errors.New("ssh-agent forwarding request denied")
	// errSSHAgentForwardingUnavailable wraps failures to reach the local agent.
	errSSHAgentForwardingUnavailable = errors.New("ssh-agent forwarding unavailable")
)
// sshAgentForwardProxy hands incoming agent-forwarding channels to
// per-channel bridges and tracks the bridges that are still running, so Close
// can stop new registrations and tear down the active ones.
type sshAgentForwardProxy struct {
	// stopOnce guards closing stopCh so Close may be called repeatedly.
	stopOnce sync.Once
	// stopCh is closed by Close; registerBridge refuses new bridges once it is closed.
	stopCh chan struct{}
	// activeMu guards active.
	activeMu sync.Mutex
	// active is the set of registered bridges that have not unregistered yet.
	active map[*sshAgentForwardBridge]struct{}
}
// Close stops the proxy from registering new bridges and closes every bridge
// that is still active. Repeated calls and a nil receiver are harmless; the
// returned error is always nil.
func (p *sshAgentForwardProxy) Close() error {
	if p != nil {
		p.stopOnce.Do(func() { close(p.stopCh) })
		p.closeActive()
	}
	return nil
}
// sshAgentForwardBridge copies bytes between one accepted agent channel from
// the SSH server and one connection to the local ssh-agent.
type sshAgentForwardBridge struct {
	// proxy tracks this bridge while it runs; nil when the bridge is untracked.
	proxy *sshAgentForwardProxy
	// channel is the accepted SSH channel on the remote side of the bridge.
	channel ssh.Channel
	// conn is the connection to the local agent (from dialSSHAgent on the proxy path).
	conn net.Conn
	// closeOnce makes close idempotent.
	closeOnce sync.Once
}
// RequestAgentForwarding makes sure the client-side agent proxy is running and
// then asks the server to enable agent forwarding on session.
//
// A server refusal is returned wrapped in errSSHAgentForwardingDenied; errors
// from setting up the proxy and any other request failure are returned as-is.
func (s *StarSSH) RequestAgentForwarding(session *ssh.Session) error {
	switch {
	case s == nil:
		return errors.New("ssh client is nil")
	case session == nil:
		return errors.New("ssh session is nil")
	}
	if setupErr := s.ensureAgentForwarding(); setupErr != nil {
		return setupErr
	}
	reqErr := requestSSHAgentForwarding(session)
	if reqErr == nil {
		return nil
	}
	if !isSSHAgentForwardingDeniedError(reqErr) {
		return reqErr
	}
	return fmt.Errorf("%w: %v", errSSHAgentForwardingDenied, reqErr)
}
// maybeRequestAgentForwarding requests agent forwarding only when the login
// settings enable ForwardSSHAgent. Denied and unavailable forwarding are
// tolerated (nil is returned); any other error is passed through.
func (s *StarSSH) maybeRequestAgentForwarding(session *ssh.Session) error {
	if s == nil {
		return nil
	}
	if !s.LoginInfo.ForwardSSHAgent {
		return nil
	}
	switch err := s.RequestAgentForwarding(session); {
	case isSSHAgentForwardingDeniedError(err), isSSHAgentForwardingUnavailableError(err):
		return nil
	default:
		return err
	}
}
// ensureAgentForwarding lazily installs the agent-forwarding proxy on the
// current SSH client, at most once until takeAgentForwarder clears it.
//
// Under agentForwardMu it:
//  1. returns early if a forwarder is already attached;
//  2. obtains the client;
//  3. probes the local agent (failures are wrapped as unavailable);
//  4. refuses to continue with errSSHClientClosing once closing is set;
//  5. routes forwarding and attaches the resulting closer.
//
// If canAttachAgentForwarder rejects the client after routing, the new closer
// is released and errSSHClientClosing is returned.
func (s *StarSSH) ensureAgentForwarding() error {
	if s == nil {
		return errors.New("ssh client is nil")
	}
	s.agentForwardMu.Lock()
	defer s.agentForwardMu.Unlock()

	if s.agentForwarder != nil {
		return nil
	}
	client, clientErr := s.requireSSHClient()
	if clientErr != nil {
		return clientErr
	}
	dialTimeout := effectiveDialTimeout(s.LoginInfo)
	if probeErr := probeSSHAgentForwarding(dialTimeout); probeErr != nil {
		return wrapSSHAgentForwardingUnavailable(probeErr)
	}
	if s.closing.Load() {
		return errSSHClientClosing
	}
	forwarder, routeErr := routeSSHAgentForwarding(client, dialTimeout)
	if routeErr != nil {
		return routeErr
	}
	if s.canAttachAgentForwarder(client) {
		s.agentForwarder = forwarder
		return nil
	}
	// The client may no longer accept a forwarder (the returned error suggests
	// it is closing); do not leak the one just created.
	if forwarder != nil {
		_ = forwarder.Close()
	}
	return errSSHClientClosing
}
// takeAgentForwarder detaches and returns the current agent forwarder (nil if
// none), leaving the field cleared so the caller owns closing it.
func (s *StarSSH) takeAgentForwarder() io.Closer {
	if s == nil {
		return nil
	}
	s.agentForwardMu.Lock()
	taken := s.agentForwarder
	s.agentForwarder = nil
	s.agentForwardMu.Unlock()
	return taken
}
// isSSHAgentForwardingDeniedError reports whether err means the forwarding
// request was refused: either it wraps errSSHAgentForwardingDenied or its text
// (case-insensitively) contains one of the known refusal phrases.
func isSSHAgentForwardingDeniedError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, errSSHAgentForwardingDenied):
		return true
	}
	lowered := strings.ToLower(err.Error())
	for _, marker := range []string{"forwarding request denied", "agent forwarding disabled"} {
		if strings.Contains(lowered, marker) {
			return true
		}
	}
	return false
}
// isSSHAgentForwardingUnavailableError reports whether err means the local
// agent could not be used: either it wraps errSSHAgentForwardingUnavailable or
// its text (case-insensitively) contains one of the known "unavailable" phrases.
func isSSHAgentForwardingUnavailableError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, errSSHAgentForwardingUnavailable):
		return true
	}
	lowered := strings.ToLower(err.Error())
	for _, marker := range []string{"ssh-agent forwarding unavailable", "ssh-agent unavailable"} {
		if strings.Contains(lowered, marker) {
			return true
		}
	}
	return false
}
// wrapSSHAgentForwardingUnavailable tags err with errSSHAgentForwardingUnavailable.
// nil stays nil and already-tagged errors are returned unchanged. Only when err
// wraps errSSHAgentUnavailable is the original error kept in the chain;
// otherwise just its text is embedded.
func wrapSSHAgentForwardingUnavailable(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, errSSHAgentForwardingUnavailable):
		return err
	case errors.Is(err, errSSHAgentUnavailable):
		return fmt.Errorf("%w: %w", errSSHAgentForwardingUnavailable, err)
	default:
		return fmt.Errorf("%w: %v", errSSHAgentForwardingUnavailable, err)
	}
}
// startSSHAgentForwardProxy registers client's handler for server-initiated
// agent-forwarding channels (sshAgentChannelType). Each incoming channel is
// served on its own goroutine by bridging it to the local ssh-agent, dialed
// with timeout. The returned io.Closer stops the proxy and closes active
// bridges.
//
// x/crypto/ssh cannot unregister a channel-open handler, and the client blocks
// delivering all further channel opens once this handler's buffered channel is
// full. So after the proxy is closed, the serving goroutine keeps draining the
// handler channel and rejects each request until the client closes it. It must
// not simply return, which would leave the server's requests unanswered and
// eventually stall the client.
func startSSHAgentForwardProxy(client *ssh.Client, timeout time.Duration) (io.Closer, error) {
	if client == nil {
		return nil, errors.New("ssh client is nil")
	}
	channels := client.HandleChannelOpen(sshAgentChannelType)
	if channels == nil {
		// HandleChannelOpen returns nil when this channel type already has a handler.
		return nil, errors.New("agent: already have handler for " + sshAgentChannelType)
	}
	proxy := &sshAgentForwardProxy{
		stopCh: make(chan struct{}),
		active: make(map[*sshAgentForwardBridge]struct{}),
	}
	go func() {
		for {
			select {
			case <-proxy.stopCh:
				// Keep consuming (and refusing) until the client closes the
				// handler channel so the client's dispatcher never blocks on us.
				for ch := range channels {
					if ch != nil {
						_ = ch.Reject(ssh.Prohibited, "ssh-agent forwarding stopped")
					}
				}
				return
			case ch, ok := <-channels:
				if !ok {
					return
				}
				go handleSSHAgentForwardChannel(proxy, ch, timeout)
			}
		}
	}()
	return proxy, nil
}
// handleSSHAgentForwardChannel serves one forwarded agent channel. It dials the
// local agent first and rejects the channel with ConnectionFailed if that fails.
// Otherwise it accepts the channel, discards its out-of-band requests, and
// starts a bridge. A bridge the proxy refuses to register (proxy stopped or
// nil) is closed immediately.
func handleSSHAgentForwardChannel(proxy *sshAgentForwardProxy, ch ssh.NewChannel, timeout time.Duration) {
	if ch == nil {
		return
	}
	agentConn, dialErr := dialSSHAgent(timeout)
	switch {
	case dialErr != nil:
		_ = ch.Reject(ssh.ConnectionFailed, dialErr.Error())
		return
	case agentConn == nil:
		_ = ch.Reject(ssh.ConnectionFailed, "ssh-agent connection unavailable")
		return
	}
	remote, requests, acceptErr := ch.Accept()
	if acceptErr != nil {
		_ = agentConn.Close()
		return
	}
	go ssh.DiscardRequests(requests)
	bridge := &sshAgentForwardBridge{proxy: proxy, channel: remote, conn: agentConn}
	if proxy.registerBridge(bridge) {
		go bridge.run()
		return
	}
	bridge.close()
}
// proxySSHAgentChannel bridges channel and conn synchronously, returning once
// the bridge has finished. The bridge has no proxy, so it is not tracked for
// Close-driven teardown.
func proxySSHAgentChannel(channel ssh.Channel, conn net.Conn) {
	(&sshAgentForwardBridge{channel: channel, conn: conn}).run()
}
// run copies bytes between the forwarded SSH channel and the local agent
// connection until both directions have finished. It then closes both
// endpoints and removes the bridge from its proxy (if any).
//
// When one direction ends, only the write side of its destination is shut down
// (CloseWrite). A reply still travelling the other way, such as an agent
// response to a request whose sender half-closed, is therefore delivered
// instead of being cut off. A destination that cannot half-close, or whose
// half-close fails, triggers a full close so the opposite copy always unblocks.
func (b *sshAgentForwardBridge) run() {
	if b == nil {
		return
	}
	defer b.unregister()
	// Deferred calls run LIFO, so the bridge is fully closed before it is
	// unregistered; close is idempotent.
	defer b.close()

	// halfClose propagates EOF to dst, falling back to a full teardown.
	halfClose := func(dst any) {
		if cw, ok := dst.(interface{ CloseWrite() error }); ok && cw.CloseWrite() == nil {
			return
		}
		b.close()
	}

	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		_, _ = io.Copy(b.channel, b.conn) // agent -> remote
		halfClose(b.channel)
	}()
	go func() {
		defer wg.Done()
		_, _ = io.Copy(b.conn, b.channel) // remote -> agent
		halfClose(b.conn)
	}()
	wg.Wait()
}
// close tears the bridge down exactly once. It first half-closes each endpoint
// that supports CloseWrite, then fully closes the non-nil channel and
// connection. Errors are ignored; a nil receiver is a no-op.
func (b *sshAgentForwardBridge) close() {
	if b == nil {
		return
	}
	b.closeOnce.Do(func() {
		closeWriter(b.channel)
		closeWriter(b.conn)
		for _, endpoint := range []io.Closer{b.channel, b.conn} {
			if endpoint != nil {
				_ = endpoint.Close()
			}
		}
	})
}
// unregister removes the bridge from its proxy's active set; bridges without a
// proxy (and a nil receiver) are ignored.
func (b *sshAgentForwardBridge) unregister() {
	if b != nil && b.proxy != nil {
		b.proxy.unregisterBridge(b)
	}
}
// registerBridge adds bridge to the active set and reports success. It returns
// false for a nil proxy or bridge, or once stopCh has been closed. The stop
// check and the insertion happen under activeMu, so closeActive cannot miss a
// bridge registered concurrently with Close.
func (p *sshAgentForwardProxy) registerBridge(bridge *sshAgentForwardBridge) bool {
	if p == nil || bridge == nil {
		return false
	}
	p.activeMu.Lock()
	defer p.activeMu.Unlock()
	select {
	case <-p.stopCh:
		return false
	default:
		if p.active == nil {
			p.active = make(map[*sshAgentForwardBridge]struct{})
		}
		p.active[bridge] = struct{}{}
		return true
	}
}
// unregisterBridge drops bridge from the active set; unknown bridges, a nil
// bridge, and a nil proxy are all no-ops.
func (p *sshAgentForwardProxy) unregisterBridge(bridge *sshAgentForwardBridge) {
	if p == nil || bridge == nil {
		return
	}
	p.activeMu.Lock()
	delete(p.active, bridge)
	p.activeMu.Unlock()
}
// closeActive closes every currently registered bridge and resets the active
// set to empty. The old set is swapped out under activeMu and the bridges are
// closed after the lock is released. Nothing else references the detached map,
// so iterating it unlocked is safe, and bridge teardown never runs while
// holding activeMu.
func (p *sshAgentForwardProxy) closeActive() {
	if p == nil {
		return
	}
	p.activeMu.Lock()
	detached := p.active
	p.active = make(map[*sshAgentForwardBridge]struct{})
	p.activeMu.Unlock()

	for bridge := range detached {
		bridge.close()
	}
}
// closeWriter half-closes value when it has a CloseWrite() error method,
// ignoring the result. Values without that method (including nil) are left
// untouched.
func closeWriter(value any) {
	halfCloser, ok := value.(interface{ CloseWrite() error })
	if !ok {
		return
	}
	_ = halfCloser.CloseWrite()
}